config.go 27 KB


  1. // Copyright 2021 gorse Project Authors
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. package config
  15. import (
  16. "context"
  17. "crypto/md5"
  18. "encoding/hex"
  19. "fmt"
  20. "os"
  21. "reflect"
  22. "strings"
  23. "sync"
  24. "time"
  25. "github.com/go-playground/locales/en"
  26. ut "github.com/go-playground/universal-translator"
  27. "github.com/go-playground/validator/v10"
  28. en_translations "github.com/go-playground/validator/v10/translations/en"
  29. "github.com/juju/errors"
  30. "github.com/samber/lo"
  31. "github.com/spf13/viper"
  32. "github.com/zhenghaoz/gorse/base/log"
  33. "github.com/zhenghaoz/gorse/storage"
  34. "go.opentelemetry.io/otel/exporters/jaeger"
  35. "go.opentelemetry.io/otel/exporters/otlp/otlptrace"
  36. "go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc"
  37. "go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp"
  38. "go.opentelemetry.io/otel/exporters/zipkin"
  39. "go.opentelemetry.io/otel/sdk/resource"
  40. tracesdk "go.opentelemetry.io/otel/sdk/trace"
  41. semconv "go.opentelemetry.io/otel/semconv/v1.8.0"
  42. "go.opentelemetry.io/otel/trace"
  43. "go.uber.org/zap"
  44. )
  45. const (
  46. NeighborTypeAuto = "auto"
  47. NeighborTypeSimilar = "similar"
  48. NeighborTypeRelated = "related"
  49. )
// Config is the configuration for the engine.
//
// LoadConfig fills it by unmarshalling the config file through viper, using
// the mapstructure tags below as section names. GetDefaultConfig returns the
// built-in defaults. Field order matters to the validator: Validate reports
// only the first failing field.
type Config struct {
	Database     DatabaseConfig     `mapstructure:"database"`
	Master       MasterConfig       `mapstructure:"master"`
	Server       ServerConfig       `mapstructure:"server"`
	Recommend    RecommendConfig    `mapstructure:"recommend"`
	Tracing      TracingConfig      `mapstructure:"tracing"`
	Experimental ExperimentalConfig `mapstructure:"experimental"`
}
// DatabaseConfig is the configuration for the database.
//
// DataStore and CacheStore are connection strings whose URL prefixes are
// checked by the custom "data_store"/"cache_store" validators registered in
// Validate. LoadConfig copies TablePrefix into DataTablePrefix and
// CacheTablePrefix when those are left empty.
type DatabaseConfig struct {
	DataStore        string `mapstructure:"data_store" validate:"required,data_store"`   // database for data store
	CacheStore       string `mapstructure:"cache_store" validate:"required,cache_store"` // database for cache store
	TablePrefix      string `mapstructure:"table_prefix"`                                // shared fallback prefix for both stores
	DataTablePrefix  string `mapstructure:"data_table_prefix"`                           // table prefix for the data store (falls back to TablePrefix)
	CacheTablePrefix string `mapstructure:"cache_table_prefix"`                          // table prefix for the cache store (falls back to TablePrefix)
}
// MasterConfig is the configuration for the master.
//
// Defaults are set in GetDefaultConfig. Several fields can also be supplied
// through GORSE_* environment variables bound in LoadConfig.
type MasterConfig struct {
	Port                int           `mapstructure:"port" validate:"gte=0"`        // master port (default 8086)
	Host                string        `mapstructure:"host"`                         // master host (default "0.0.0.0")
	HttpPort            int           `mapstructure:"http_port" validate:"gte=0"`   // HTTP port (default 8088)
	HttpHost            string        `mapstructure:"http_host"`                    // HTTP host (default "0.0.0.0")
	HttpCorsDomains     []string      `mapstructure:"http_cors_domains"`            // allowed CORS domains (default [".*"])
	HttpCorsMethods     []string      `mapstructure:"http_cors_methods"`            // allowed CORS methods (default GET, POST, PUT, DELETE, PATCH)
	NumJobs             int           `mapstructure:"n_jobs" validate:"gt=0"`       // number of working jobs (default 1)
	MetaTimeout         time.Duration `mapstructure:"meta_timeout" validate:"gt=0"` // cluster meta timeout; a time.Duration (default 10s), not a count of seconds
	DashboardUserName   string        `mapstructure:"dashboard_user_name"`          // dashboard user name
	DashboardPassword   string        `mapstructure:"dashboard_password"`           // dashboard password
	DashboardAuthServer string        `mapstructure:"dashboard_auth_server"`        // dashboard auth server
	DashboardRedacted   bool          `mapstructure:"dashboard_redacted"`           // dashboard redaction flag; its effect is defined by the dashboard code, not here
	AdminAPIKey         string        `mapstructure:"admin_api_key"`                // admin API key
}
// ServerConfig is the configuration for the server.
type ServerConfig struct {
	APIKey         string        `mapstructure:"api_key"`                      // secret key for RESTful APIs (SSL required)
	DefaultN       int           `mapstructure:"default_n" validate:"gt=0"`    // default number of returned items
	ClockError     time.Duration `mapstructure:"clock_error" validate:"gte=0"` // clock error in the cluster, as a time.Duration (default 5s); Config.Now adds it to the current time
	AutoInsertUser bool          `mapstructure:"auto_insert_user"`             // insert new users while inserting feedback
	AutoInsertItem bool          `mapstructure:"auto_insert_item"`             // insert new items while inserting feedback
	CacheExpire    time.Duration `mapstructure:"cache_expire" validate:"gt=0"` // server-side cache expire time
}
// RecommendConfig is the configuration of recommendation setup.
//
// It groups the per-stage sections below; defaults live in GetDefaultConfig.
type RecommendConfig struct {
	CacheSize     int                 `mapstructure:"cache_size" validate:"gt=0"`         // cache size (default 100); NOTE(review): presumably items per cached list — confirm
	CacheExpire   time.Duration       `mapstructure:"cache_expire" validate:"gt=0"`       // cache expiration (default 72h)
	ActiveUserTTL int                 `mapstructure:"active_user_ttl" validate:"gte=0"`   // active-user TTL; unit not visible here — TODO confirm
	DataSource    DataSourceConfig    `mapstructure:"data_source"`
	Popular       PopularConfig       `mapstructure:"popular"`
	UserNeighbors NeighborsConfig     `mapstructure:"user_neighbors"`
	ItemNeighbors NeighborsConfig     `mapstructure:"item_neighbors"`
	Collaborative CollaborativeConfig `mapstructure:"collaborative"`
	Replacement   ReplacementConfig   `mapstructure:"replacement"`
	Offline       OfflineConfig       `mapstructure:"offline"`
	Online        OnlineConfig        `mapstructure:"online"`
}
// DataSourceConfig holds the data-source settings of RecommendConfig.
type DataSourceConfig struct {
	PositiveFeedbackTypes []string `mapstructure:"positive_feedback_types"`                // positive feedback types; also folded into the neighbor digests for "auto"/"related" neighbor types
	ReadFeedbackTypes     []string `mapstructure:"read_feedback_types"`                    // feedback types for read events
	PositiveFeedbackTTL   uint     `mapstructure:"positive_feedback_ttl" validate:"gte=0"` // time-to-live of positive feedbacks; unit not visible here — TODO confirm (gte=0 always holds for uint)
	ItemTTL               uint     `mapstructure:"item_ttl" validate:"gte=0"`              // time-to-live of items; unit not visible here — TODO confirm
}
// PopularConfig holds the settings for popular-item recommendation.
type PopularConfig struct {
	// PopularWindow is the popularity time window (default 180 days). It is
	// folded into OfflineRecommendDigest when popular recommendation is enabled.
	PopularWindow time.Duration `mapstructure:"popular_window" validate:"gte=0"`
}
// NeighborsConfig configures user or item neighbor computation.
//
// IndexRecall and IndexFitEpoch contribute to the neighbor digest only when
// EnableIndex is true (see UserNeighborDigest / ItemNeighborDigest).
type NeighborsConfig struct {
	NeighborType  string  `mapstructure:"neighbor_type" validate:"oneof=auto similar related ''"` // one of the NeighborType* constants, or empty (default "auto")
	EnableIndex   bool    `mapstructure:"enable_index"`                                        // enable the neighbor index (default true)
	IndexRecall   float32 `mapstructure:"index_recall" validate:"gt=0"`                        // index recall setting (default 0.8)
	IndexFitEpoch int     `mapstructure:"index_fit_epoch" validate:"gt=0"`                     // index fit epochs (default 3)
}
// CollaborativeConfig holds the collaborative-filtering model settings.
//
// EnableIndex, and when it is true IndexRecall and IndexFitEpoch, are folded
// into OfflineRecommendDigest whenever collaborative recommendation is enabled.
type CollaborativeConfig struct {
	ModelFitPeriod        time.Duration `mapstructure:"model_fit_period" validate:"gt=0"`    // model fit period (default 60m)
	ModelSearchPeriod     time.Duration `mapstructure:"model_search_period" validate:"gt=0"` // model search period (default 180m)
	ModelSearchEpoch      int           `mapstructure:"model_search_epoch" validate:"gt=0"`  // epochs per model search (default 100)
	ModelSearchTrials     int           `mapstructure:"model_search_trials" validate:"gt=0"` // model search trials (default 10)
	EnableModelSizeSearch bool          `mapstructure:"enable_model_size_search"`            // enable model size search (default false)
	EnableIndex           bool          `mapstructure:"enable_index"`                        // enable the index (default true)
	IndexRecall           float32       `mapstructure:"index_recall" validate:"gt=0"`        // index recall setting (default 0.9)
	IndexFitEpoch         int           `mapstructure:"index_fit_epoch" validate:"gt=0"`     // index fit epochs (default 3)
}
// ReplacementConfig holds the replacement settings.
//
// Both decay values are folded into OfflineRecommendDigest only when
// EnableReplacement is true.
type ReplacementConfig struct {
	EnableReplacement        bool    `mapstructure:"enable_replacement"`                         // default false
	PositiveReplacementDecay float64 `mapstructure:"positive_replacement_decay" validate:"gt=0"` // decay for positive feedback (default 0.8)
	ReadReplacementDecay     float64 `mapstructure:"read_replacement_decay" validate:"gt=0"`     // decay for read feedback (default 0.6)
}
// OfflineConfig holds the offline recommendation settings.
//
// ExploreRecommend is guarded by exploreRecommendLock: GetExploreRecommend
// reads it under the lock, and Lock/UnLock expose the lock to callers. Because
// the struct embeds a sync.RWMutex, it must not be copied after first use.
type OfflineConfig struct {
	CheckRecommendPeriod         time.Duration      `mapstructure:"check_recommend_period" validate:"gt=0"`   // default 1m
	RefreshRecommendPeriod       time.Duration      `mapstructure:"refresh_recommend_period" validate:"gt=0"` // default 120h
	ExploreRecommend             map[string]float64 `mapstructure:"explore_recommend"`                        // NOTE(review): presumably recommender name -> exploration ratio — confirm
	EnableLatestRecommend        bool               `mapstructure:"enable_latest_recommend"`
	EnablePopularRecommend       bool               `mapstructure:"enable_popular_recommend"`
	EnableUserBasedRecommend     bool               `mapstructure:"enable_user_based_recommend"`
	EnableItemBasedRecommend     bool               `mapstructure:"enable_item_based_recommend"`
	EnableColRecommend           bool               `mapstructure:"enable_collaborative_recommend"` // the only recommender enabled by default
	EnableClickThroughPrediction bool               `mapstructure:"enable_click_through_prediction"`
	exploreRecommendLock         sync.RWMutex       // guards ExploreRecommend
}
// OnlineConfig holds the online recommendation settings.
type OnlineConfig struct {
	FallbackRecommend            []string `mapstructure:"fallback_recommend"`                                // fallback recommenders (default ["latest"])
	NumFeedbackFallbackItemBased int      `mapstructure:"num_feedback_fallback_item_based" validate:"gt=0"` // default 10
}
// TracingConfig is the OpenTelemetry tracing configuration consumed by
// NewTracerProvider.
type TracingConfig struct {
	EnableTracing     bool    `mapstructure:"enable_tracing"`                                      // when false, NewTracerProvider returns a no-op provider
	Exporter          string  `mapstructure:"exporter" validate:"oneof=jaeger zipkin otlp otlphttp"` // span exporter (default "jaeger")
	CollectorEndpoint string  `mapstructure:"collector_endpoint"`                                  // endpoint handed to the chosen exporter
	Sampler           string  `mapstructure:"sampler"`                                             // "always", "never" or "ratio" (default "always"); checked in NewTracerProvider, not by Validate
	Ratio             float64 `mapstructure:"ratio"`                                               // sampling ratio, used only when Sampler is "ratio"
}
// ExperimentalConfig holds settings for experimental features.
type ExperimentalConfig struct {
	EnableDeepLearning    bool `mapstructure:"enable_deep_learning"`     // default false
	DeepLearningBatchSize int  `mapstructure:"deep_learning_batch_size"` // default 128
}
  163. func GetDefaultConfig() *Config {
  164. return &Config{
  165. Master: MasterConfig{
  166. Port: 8086,
  167. Host: "0.0.0.0",
  168. HttpPort: 8088,
  169. HttpHost: "0.0.0.0",
  170. HttpCorsDomains: []string{".*"},
  171. HttpCorsMethods: []string{"GET", "POST", "PUT", "DELETE", "PATCH"},
  172. NumJobs: 1,
  173. MetaTimeout: 10 * time.Second,
  174. },
  175. Server: ServerConfig{
  176. DefaultN: 10,
  177. ClockError: 5 * time.Second,
  178. AutoInsertUser: true,
  179. AutoInsertItem: true,
  180. CacheExpire: 10 * time.Second,
  181. },
  182. Recommend: RecommendConfig{
  183. CacheSize: 100,
  184. CacheExpire: 72 * time.Hour,
  185. Popular: PopularConfig{
  186. PopularWindow: 180 * 24 * time.Hour,
  187. },
  188. UserNeighbors: NeighborsConfig{
  189. NeighborType: "auto",
  190. EnableIndex: true,
  191. IndexRecall: 0.8,
  192. IndexFitEpoch: 3,
  193. },
  194. ItemNeighbors: NeighborsConfig{
  195. NeighborType: "auto",
  196. EnableIndex: true,
  197. IndexRecall: 0.8,
  198. IndexFitEpoch: 3,
  199. },
  200. Collaborative: CollaborativeConfig{
  201. ModelFitPeriod: 60 * time.Minute,
  202. ModelSearchPeriod: 180 * time.Minute,
  203. ModelSearchEpoch: 100,
  204. ModelSearchTrials: 10,
  205. EnableIndex: true,
  206. IndexRecall: 0.9,
  207. IndexFitEpoch: 3,
  208. },
  209. Replacement: ReplacementConfig{
  210. EnableReplacement: false,
  211. PositiveReplacementDecay: 0.8,
  212. ReadReplacementDecay: 0.6,
  213. },
  214. Offline: OfflineConfig{
  215. CheckRecommendPeriod: time.Minute,
  216. RefreshRecommendPeriod: 120 * time.Hour,
  217. EnableLatestRecommend: false,
  218. EnablePopularRecommend: false,
  219. EnableUserBasedRecommend: false,
  220. EnableItemBasedRecommend: false,
  221. EnableColRecommend: true,
  222. EnableClickThroughPrediction: false,
  223. },
  224. Online: OnlineConfig{
  225. FallbackRecommend: []string{"latest"},
  226. NumFeedbackFallbackItemBased: 10,
  227. },
  228. },
  229. Tracing: TracingConfig{
  230. Exporter: "jaeger",
  231. Sampler: "always",
  232. },
  233. Experimental: ExperimentalConfig{
  234. DeepLearningBatchSize: 128,
  235. },
  236. }
  237. }
  238. func (config *Config) Now() *time.Time {
  239. return lo.ToPtr(time.Now().Add(config.Server.ClockError))
  240. }
  241. func (config *Config) UserNeighborDigest() string {
  242. var builder strings.Builder
  243. builder.WriteString(fmt.Sprintf("%v-%v", config.Recommend.UserNeighbors.NeighborType, config.Recommend.UserNeighbors.EnableIndex))
  244. // feedback option
  245. if lo.Contains([]string{"auto", "related"}, config.Recommend.UserNeighbors.NeighborType) {
  246. builder.WriteString(fmt.Sprintf("-%s", strings.Join(config.Recommend.DataSource.PositiveFeedbackTypes, "-")))
  247. } else {
  248. builder.WriteString("-")
  249. }
  250. // index option
  251. if config.Recommend.UserNeighbors.EnableIndex {
  252. builder.WriteString(fmt.Sprintf("-%v-%v", config.Recommend.UserNeighbors.IndexRecall, config.Recommend.UserNeighbors.IndexFitEpoch))
  253. } else {
  254. builder.WriteString("--")
  255. }
  256. digest := md5.Sum([]byte(builder.String()))
  257. return hex.EncodeToString(digest[:])
  258. }
  259. func (config *Config) ItemNeighborDigest() string {
  260. var builder strings.Builder
  261. builder.WriteString(fmt.Sprintf("%v-%v", config.Recommend.ItemNeighbors.NeighborType, config.Recommend.ItemNeighbors.EnableIndex))
  262. // feedback option
  263. if lo.Contains([]string{"auto", "related"}, config.Recommend.ItemNeighbors.NeighborType) {
  264. builder.WriteString(fmt.Sprintf("-%s", strings.Join(config.Recommend.DataSource.PositiveFeedbackTypes, "-")))
  265. } else {
  266. builder.WriteString("-")
  267. }
  268. // index option
  269. if config.Recommend.ItemNeighbors.EnableIndex {
  270. builder.WriteString(fmt.Sprintf("-%v-%v", config.Recommend.ItemNeighbors.IndexRecall, config.Recommend.ItemNeighbors.IndexFitEpoch))
  271. } else {
  272. builder.WriteString("--")
  273. }
  274. digest := md5.Sum([]byte(builder.String()))
  275. return hex.EncodeToString(digest[:])
  276. }
  277. type digestOptions struct {
  278. userNeighborDigest string
  279. itemNeighborDigest string
  280. enableCollaborative bool
  281. enableRanking bool
  282. }
  283. type DigestOption func(option *digestOptions)
  284. func WithUserNeighborDigest(digest string) DigestOption {
  285. return func(option *digestOptions) {
  286. option.userNeighborDigest = digest
  287. }
  288. }
  289. func WithItemNeighborDigest(digest string) DigestOption {
  290. return func(option *digestOptions) {
  291. option.itemNeighborDigest = digest
  292. }
  293. }
  294. func WithCollaborative(v bool) DigestOption {
  295. return func(option *digestOptions) {
  296. option.enableCollaborative = v
  297. }
  298. }
  299. func WithRanking(v bool) DigestOption {
  300. return func(option *digestOptions) {
  301. option.enableRanking = v
  302. }
  303. }
  304. func (config *Config) OfflineRecommendDigest(option ...DigestOption) string {
  305. options := digestOptions{
  306. userNeighborDigest: config.UserNeighborDigest(),
  307. itemNeighborDigest: config.ItemNeighborDigest(),
  308. enableCollaborative: config.Recommend.Offline.EnableColRecommend,
  309. enableRanking: config.Recommend.Offline.EnableClickThroughPrediction,
  310. }
  311. lo.ForEach(option, func(opt DigestOption, _ int) {
  312. opt(&options)
  313. })
  314. var builder strings.Builder
  315. config.Recommend.Offline.Lock()
  316. builder.WriteString(fmt.Sprintf("%v-%v-%v-%v-%v-%v-%v-%v",
  317. config.Recommend.Offline.ExploreRecommend,
  318. config.Recommend.Offline.EnableLatestRecommend,
  319. config.Recommend.Offline.EnablePopularRecommend,
  320. config.Recommend.Offline.EnableUserBasedRecommend,
  321. config.Recommend.Offline.EnableItemBasedRecommend,
  322. options.enableCollaborative,
  323. options.enableRanking,
  324. config.Recommend.Replacement.EnableReplacement,
  325. ))
  326. config.Recommend.Offline.UnLock()
  327. if config.Recommend.Offline.EnablePopularRecommend {
  328. builder.WriteString(fmt.Sprintf("-%v", config.Recommend.Popular.PopularWindow))
  329. }
  330. if config.Recommend.Offline.EnableUserBasedRecommend {
  331. builder.WriteString(fmt.Sprintf("-%v", options.userNeighborDigest))
  332. }
  333. if config.Recommend.Offline.EnableItemBasedRecommend {
  334. builder.WriteString(fmt.Sprintf("-%v", options.itemNeighborDigest))
  335. }
  336. if options.enableCollaborative {
  337. builder.WriteString(fmt.Sprintf("-%v", config.Recommend.Collaborative.EnableIndex))
  338. if config.Recommend.Collaborative.EnableIndex {
  339. builder.WriteString(fmt.Sprintf("-%v-%v",
  340. config.Recommend.Collaborative.IndexRecall, config.Recommend.Collaborative.IndexFitEpoch))
  341. }
  342. }
  343. if config.Recommend.Replacement.EnableReplacement {
  344. builder.WriteString(fmt.Sprintf("-%v-%v",
  345. config.Recommend.Replacement.PositiveReplacementDecay, config.Recommend.Replacement.ReadReplacementDecay))
  346. }
  347. digest := md5.Sum([]byte(builder.String()))
  348. return hex.EncodeToString(digest[:])
  349. }
  350. func (config *OfflineConfig) Lock() {
  351. config.exploreRecommendLock.Lock()
  352. }
  353. func (config *OfflineConfig) UnLock() {
  354. config.exploreRecommendLock.Unlock()
  355. }
  356. func (config *OfflineConfig) GetExploreRecommend(key string) (value float64, exist bool) {
  357. if config == nil {
  358. return 0.0, false
  359. }
  360. config.exploreRecommendLock.RLock()
  361. defer config.exploreRecommendLock.RUnlock()
  362. value, exist = config.ExploreRecommend[key]
  363. return
  364. }
  365. func (config *TracingConfig) NewTracerProvider() (trace.TracerProvider, error) {
  366. if !config.EnableTracing {
  367. return trace.NewNoopTracerProvider(), nil
  368. }
  369. var exporter tracesdk.SpanExporter
  370. var err error
  371. switch config.Exporter {
  372. case "jaeger":
  373. exporter, err = jaeger.New(jaeger.WithCollectorEndpoint(jaeger.WithEndpoint(config.CollectorEndpoint)))
  374. if err != nil {
  375. return nil, errors.Trace(err)
  376. }
  377. case "zipkin":
  378. exporter, err = zipkin.New(config.CollectorEndpoint)
  379. if err != nil {
  380. return nil, errors.Trace(err)
  381. }
  382. case "otlp":
  383. client := otlptracegrpc.NewClient(otlptracegrpc.WithInsecure(), otlptracegrpc.WithEndpoint(config.CollectorEndpoint))
  384. exporter, err = otlptrace.New(context.TODO(), client)
  385. if err != nil {
  386. return nil, errors.Trace(err)
  387. }
  388. case "otlphttp":
  389. client := otlptracehttp.NewClient(otlptracehttp.WithInsecure(), otlptracehttp.WithEndpoint(config.CollectorEndpoint))
  390. exporter, err = otlptrace.New(context.TODO(), client)
  391. if err != nil {
  392. return nil, errors.Trace(err)
  393. }
  394. default:
  395. return nil, errors.NotSupportedf("exporter %s", config.Exporter)
  396. }
  397. var sampler tracesdk.Sampler
  398. switch config.Sampler {
  399. case "always":
  400. sampler = tracesdk.AlwaysSample()
  401. case "never":
  402. sampler = tracesdk.NeverSample()
  403. case "ratio":
  404. sampler = tracesdk.TraceIDRatioBased(config.Ratio)
  405. default:
  406. return nil, errors.NotSupportedf("sampler %s", config.Sampler)
  407. }
  408. return tracesdk.NewTracerProvider(
  409. tracesdk.WithSampler(sampler),
  410. tracesdk.WithBatcher(exporter),
  411. tracesdk.WithResource(resource.NewWithAttributes(
  412. semconv.SchemaURL,
  413. semconv.ServiceNameKey.String("gorse"),
  414. )),
  415. ), nil
  416. }
  417. func (config *TracingConfig) Equal(other TracingConfig) bool {
  418. if config == nil {
  419. return false
  420. }
  421. return config.EnableTracing == other.EnableTracing &&
  422. config.Exporter == other.Exporter &&
  423. config.CollectorEndpoint == other.CollectorEndpoint &&
  424. config.Sampler == other.Sampler &&
  425. config.Ratio == other.Ratio
  426. }
  427. func setDefault() {
  428. defaultConfig := GetDefaultConfig()
  429. // [master]
  430. viper.SetDefault("master.port", defaultConfig.Master.Port)
  431. viper.SetDefault("master.host", defaultConfig.Master.Host)
  432. viper.SetDefault("master.http_port", defaultConfig.Master.HttpPort)
  433. viper.SetDefault("master.http_host", defaultConfig.Master.HttpHost)
  434. viper.SetDefault("master.http_cors_domains", defaultConfig.Master.HttpCorsDomains)
  435. viper.SetDefault("master.http_cors_methods", defaultConfig.Master.HttpCorsMethods)
  436. viper.SetDefault("master.n_jobs", defaultConfig.Master.NumJobs)
  437. viper.SetDefault("master.meta_timeout", defaultConfig.Master.MetaTimeout)
  438. // [server]
  439. viper.SetDefault("server.api_key", defaultConfig.Server.APIKey)
  440. viper.SetDefault("server.default_n", defaultConfig.Server.DefaultN)
  441. viper.SetDefault("server.clock_error", defaultConfig.Server.ClockError)
  442. viper.SetDefault("server.auto_insert_user", defaultConfig.Server.AutoInsertUser)
  443. viper.SetDefault("server.auto_insert_item", defaultConfig.Server.AutoInsertItem)
  444. viper.SetDefault("server.cache_expire", defaultConfig.Server.CacheExpire)
  445. // [recommend]
  446. viper.SetDefault("recommend.cache_size", defaultConfig.Recommend.CacheSize)
  447. viper.SetDefault("recommend.cache_expire", defaultConfig.Recommend.CacheExpire)
  448. // [recommend.popular]
  449. viper.SetDefault("recommend.popular.popular_window", defaultConfig.Recommend.Popular.PopularWindow)
  450. // [recommend.user_neighbors]
  451. viper.SetDefault("recommend.user_neighbors.neighbor_type", defaultConfig.Recommend.UserNeighbors.NeighborType)
  452. viper.SetDefault("recommend.user_neighbors.enable_index", defaultConfig.Recommend.UserNeighbors.EnableIndex)
  453. viper.SetDefault("recommend.user_neighbors.index_recall", defaultConfig.Recommend.UserNeighbors.IndexRecall)
  454. viper.SetDefault("recommend.user_neighbors.index_fit_epoch", defaultConfig.Recommend.UserNeighbors.IndexFitEpoch)
  455. // [recommend.item_neighbors]
  456. viper.SetDefault("recommend.item_neighbors.neighbor_type", defaultConfig.Recommend.ItemNeighbors.NeighborType)
  457. viper.SetDefault("recommend.item_neighbors.enable_index", defaultConfig.Recommend.ItemNeighbors.EnableIndex)
  458. viper.SetDefault("recommend.item_neighbors.index_recall", defaultConfig.Recommend.ItemNeighbors.IndexRecall)
  459. viper.SetDefault("recommend.item_neighbors.index_fit_epoch", defaultConfig.Recommend.ItemNeighbors.IndexFitEpoch)
  460. // [recommend.collaborative]
  461. viper.SetDefault("recommend.collaborative.model_fit_period", defaultConfig.Recommend.Collaborative.ModelFitPeriod)
  462. viper.SetDefault("recommend.collaborative.model_search_period", defaultConfig.Recommend.Collaborative.ModelSearchPeriod)
  463. viper.SetDefault("recommend.collaborative.model_search_epoch", defaultConfig.Recommend.Collaborative.ModelSearchEpoch)
  464. viper.SetDefault("recommend.collaborative.model_search_trials", defaultConfig.Recommend.Collaborative.ModelSearchTrials)
  465. viper.SetDefault("recommend.collaborative.enable_index", defaultConfig.Recommend.Collaborative.EnableIndex)
  466. viper.SetDefault("recommend.collaborative.index_recall", defaultConfig.Recommend.Collaborative.IndexRecall)
  467. viper.SetDefault("recommend.collaborative.index_fit_epoch", defaultConfig.Recommend.Collaborative.IndexFitEpoch)
  468. // [recommend.replacement]
  469. viper.SetDefault("recommend.replacement.enable_replacement", defaultConfig.Recommend.Replacement.EnableReplacement)
  470. viper.SetDefault("recommend.replacement.positive_replacement_decay", defaultConfig.Recommend.Replacement.PositiveReplacementDecay)
  471. viper.SetDefault("recommend.replacement.read_replacement_decay", defaultConfig.Recommend.Replacement.ReadReplacementDecay)
  472. // [recommend.offline]
  473. viper.SetDefault("recommend.offline.check_recommend_period", defaultConfig.Recommend.Offline.CheckRecommendPeriod)
  474. viper.SetDefault("recommend.offline.refresh_recommend_period", defaultConfig.Recommend.Offline.RefreshRecommendPeriod)
  475. viper.SetDefault("recommend.offline.enable_latest_recommend", defaultConfig.Recommend.Offline.EnableLatestRecommend)
  476. viper.SetDefault("recommend.offline.enable_popular_recommend", defaultConfig.Recommend.Offline.EnablePopularRecommend)
  477. viper.SetDefault("recommend.offline.enable_user_based_recommend", defaultConfig.Recommend.Offline.EnableUserBasedRecommend)
  478. viper.SetDefault("recommend.offline.enable_item_based_recommend", defaultConfig.Recommend.Offline.EnableItemBasedRecommend)
  479. viper.SetDefault("recommend.offline.enable_collaborative_recommend", defaultConfig.Recommend.Offline.EnableColRecommend)
  480. viper.SetDefault("recommend.offline.enable_click_through_prediction", defaultConfig.Recommend.Offline.EnableClickThroughPrediction)
  481. // [recommend.online]
  482. viper.SetDefault("recommend.online.fallback_recommend", defaultConfig.Recommend.Online.FallbackRecommend)
  483. viper.SetDefault("recommend.online.num_feedback_fallback_item_based", defaultConfig.Recommend.Online.NumFeedbackFallbackItemBased)
  484. // [tracing]
  485. viper.SetDefault("tracing.exporter", defaultConfig.Tracing.Exporter)
  486. viper.SetDefault("tracing.sampler", defaultConfig.Tracing.Sampler)
  487. // [experimental]
  488. viper.SetDefault("experimental.deep_learning_batch_size", defaultConfig.Experimental.DeepLearningBatchSize)
  489. }
  490. type configBinding struct {
  491. key string
  492. env string
  493. }
  494. // LoadConfig loads configuration from toml file.
  495. func LoadConfig(path string, oneModel bool) (*Config, error) {
  496. // set default config
  497. setDefault()
  498. // bind environment bindings
  499. bindings := []configBinding{
  500. {"database.cache_store", "GORSE_CACHE_STORE"},
  501. {"database.data_store", "GORSE_DATA_STORE"},
  502. {"database.table_prefix", "GORSE_TABLE_PREFIX"},
  503. {"database.cache_table_prefix", "GORSE_CACHE_TABLE_PREFIX"},
  504. {"database.data_table_prefix", "GORSE_DATA_TABLE_PREFIX"},
  505. {"master.port", "GORSE_MASTER_PORT"},
  506. {"master.host", "GORSE_MASTER_HOST"},
  507. {"master.http_port", "GORSE_MASTER_HTTP_PORT"},
  508. {"master.http_host", "GORSE_MASTER_HTTP_HOST"},
  509. {"master.n_jobs", "GORSE_MASTER_JOBS"},
  510. {"master.dashboard_user_name", "GORSE_DASHBOARD_USER_NAME"},
  511. {"master.dashboard_password", "GORSE_DASHBOARD_PASSWORD"},
  512. {"master.dashboard_auth_server", "GORSE_DASHBOARD_AUTH_SERVER"},
  513. {"master.dashboard_redacted", "GORSE_DASHBOARD_REDACTED"},
  514. {"master.admin_api_key", "GORSE_ADMIN_API_KEY"},
  515. {"server.api_key", "GORSE_SERVER_API_KEY"},
  516. }
  517. for _, binding := range bindings {
  518. err := viper.BindEnv(binding.key, binding.env)
  519. if err != nil {
  520. log.Logger().Fatal("failed to bind a Viper key to a ENV variable", zap.Error(err))
  521. }
  522. }
  523. // check if file exist
  524. if _, err := os.Stat(path); err != nil {
  525. return nil, errors.Trace(err)
  526. }
  527. // load config file
  528. viper.SetConfigFile(path)
  529. if err := viper.ReadInConfig(); err != nil {
  530. return nil, errors.Trace(err)
  531. }
  532. // unmarshal config file
  533. var conf Config
  534. if err := viper.Unmarshal(&conf); err != nil {
  535. return nil, errors.Trace(err)
  536. }
  537. // validate config file
  538. if err := conf.Validate(oneModel); err != nil {
  539. return nil, errors.Trace(err)
  540. }
  541. // apply table prefix
  542. if conf.Database.CacheTablePrefix == "" {
  543. conf.Database.CacheTablePrefix = conf.Database.TablePrefix
  544. }
  545. if conf.Database.DataTablePrefix == "" {
  546. conf.Database.DataTablePrefix = conf.Database.TablePrefix
  547. }
  548. return &conf, nil
  549. }
  550. func (config *Config) Validate(oneModel bool) error {
  551. validate := validator.New()
  552. if err := validate.RegisterValidation("data_store", func(fl validator.FieldLevel) bool {
  553. prefixes := []string{
  554. storage.MongoPrefix,
  555. storage.MongoSrvPrefix,
  556. storage.MySQLPrefix,
  557. storage.PostgresPrefix,
  558. storage.PostgreSQLPrefix,
  559. }
  560. if oneModel {
  561. prefixes = append(prefixes, storage.SQLitePrefix)
  562. }
  563. for _, prefix := range prefixes {
  564. if strings.HasPrefix(fl.Field().String(), prefix) {
  565. return true
  566. }
  567. }
  568. return false
  569. }); err != nil {
  570. return errors.Trace(err)
  571. }
  572. if err := validate.RegisterValidation("cache_store", func(fl validator.FieldLevel) bool {
  573. prefixes := []string{
  574. storage.RedisPrefix,
  575. storage.RedissPrefix,
  576. storage.MongoPrefix,
  577. storage.MongoSrvPrefix,
  578. storage.MySQLPrefix,
  579. storage.PostgresPrefix,
  580. storage.PostgreSQLPrefix,
  581. }
  582. if oneModel {
  583. prefixes = append(prefixes, storage.SQLitePrefix)
  584. }
  585. for _, prefix := range prefixes {
  586. if strings.HasPrefix(fl.Field().String(), prefix) {
  587. return true
  588. }
  589. }
  590. return false
  591. }); err != nil {
  592. return errors.Trace(err)
  593. }
  594. validate.RegisterTagNameFunc(func(fld reflect.StructField) string {
  595. return strings.SplitN(fld.Tag.Get("mapstructure"), ",", 2)[0]
  596. })
  597. err := validate.Struct(config)
  598. if err != nil {
  599. // translate errors
  600. trans := ut.New(en.New()).GetFallback()
  601. if err := en_translations.RegisterDefaultTranslations(validate, trans); err != nil {
  602. return errors.Trace(err)
  603. }
  604. if err := validate.RegisterTranslation("data_store", trans, func(ut ut.Translator) error {
  605. return ut.Add("data_store", "unsupported data storage backend", true) // see universal-translator for details
  606. }, func(ut ut.Translator, fe validator.FieldError) string {
  607. t, _ := ut.T("data_store", fe.Field())
  608. return t
  609. }); err != nil {
  610. return errors.Trace(err)
  611. }
  612. if err := validate.RegisterTranslation("cache_store", trans, func(ut ut.Translator) error {
  613. return ut.Add("cache_store", "unsupported cache storage backend", true) // see universal-translator for details
  614. }, func(ut ut.Translator, fe validator.FieldError) string {
  615. t, _ := ut.T("cache_store", fe.Field())
  616. return t
  617. }); err != nil {
  618. return errors.Trace(err)
  619. }
  620. errs := err.(validator.ValidationErrors)
  621. for _, e := range errs {
  622. return errors.New(e.Translate(trans))
  623. }
  624. }
  625. return nil
  626. }