master.go 15 KB


  1. // Copyright 2020 gorse Project Authors
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. package master
  15. import (
  16. "context"
  17. "fmt"
  18. "math"
  19. "math/rand"
  20. "net"
  21. "net/http"
  22. "sync"
  23. "time"
  24. "github.com/ReneKroon/ttlcache/v2"
  25. "github.com/emicklei/go-restful/v3"
  26. "github.com/juju/errors"
  27. "github.com/zhenghaoz/gorse/base"
  28. "github.com/zhenghaoz/gorse/base/encoding"
  29. "github.com/zhenghaoz/gorse/base/log"
  30. "github.com/zhenghaoz/gorse/base/parallel"
  31. "github.com/zhenghaoz/gorse/base/progress"
  32. "github.com/zhenghaoz/gorse/base/sizeof"
  33. "github.com/zhenghaoz/gorse/base/task"
  34. "github.com/zhenghaoz/gorse/config"
  35. "github.com/zhenghaoz/gorse/model"
  36. "github.com/zhenghaoz/gorse/model/click"
  37. "github.com/zhenghaoz/gorse/model/ranking"
  38. "github.com/zhenghaoz/gorse/protocol"
  39. "github.com/zhenghaoz/gorse/server"
  40. "github.com/zhenghaoz/gorse/storage/cache"
  41. "github.com/zhenghaoz/gorse/storage/data"
  42. "go.opentelemetry.io/otel"
  43. "go.opentelemetry.io/otel/propagation"
  44. "go.uber.org/zap"
  45. "google.golang.org/grpc"
  46. )
  // ScheduleState describes the manually triggered (managed-mode) task round.
  // It is serialized to JSON with the snake_case keys given in the field tags.
  type ScheduleState struct {
  	// IsRunning is true while a triggered round of tasks is executing.
  	IsRunning bool `json:"is_running"`
  	// SearchModel asks the next triggered round to also run the ragtag tasks
  	// (cache garbage collection and model search); it is cleared after the round.
  	SearchModel bool `json:"search_model"`
  	// StartTime is when the current round started; it is the zero time when idle.
  	StartTime time.Time `json:"start_time"`
  }
  52. // Master is the master node.
  53. type Master struct {
  54. protocol.UnimplementedMasterServer
  55. server.RestServer
  56. grpcServer *grpc.Server
  57. tracer *progress.Tracer
  58. remoteProgress sync.Map
  59. jobsScheduler *task.JobsScheduler
  60. cacheFile string
  61. managedMode bool
  62. // cluster meta cache
  63. ttlCache *ttlcache.Cache
  64. nodesInfo map[string]*Node
  65. nodesInfoMutex sync.RWMutex
  66. // ranking dataset
  67. rankingTrainSet *ranking.DataSet
  68. rankingTestSet *ranking.DataSet
  69. rankingDataMutex sync.RWMutex
  70. // click dataset
  71. clickTrainSet *click.Dataset
  72. clickTestSet *click.Dataset
  73. clickDataMutex sync.RWMutex
  74. // ranking model
  75. rankingModelName string
  76. rankingScore ranking.Score
  77. rankingModelMutex sync.RWMutex
  78. rankingModelSearcher *ranking.ModelSearcher
  79. // click model
  80. clickScore click.Score
  81. clickModelMutex sync.RWMutex
  82. clickModelSearcher *click.ModelSearcher
  83. localCache *LocalCache
  84. // events
  85. fitTicker *time.Ticker
  86. importedChan *parallel.ConditionChannel // feedback inserted events
  87. loadDataChan *parallel.ConditionChannel // dataset loaded events
  88. triggerChan *parallel.ConditionChannel // manually trigger events
  89. scheduleState ScheduleState
  90. workerScheduleHandler http.HandlerFunc
  91. }
  92. // NewMaster creates a master node.
  93. func NewMaster(cfg *config.Config, cacheFile string, managedMode bool) *Master {
  94. rand.Seed(time.Now().UnixNano())
  95. // setup trace provider
  96. tp, err := cfg.Tracing.NewTracerProvider()
  97. if err != nil {
  98. log.Logger().Fatal("failed to create trace provider", zap.Error(err))
  99. }
  100. otel.SetTracerProvider(tp)
  101. otel.SetErrorHandler(log.GetErrorHandler())
  102. otel.SetTextMapPropagator(propagation.NewCompositeTextMapPropagator(propagation.TraceContext{}, propagation.Baggage{}))
  103. m := &Master{
  104. nodesInfo: make(map[string]*Node),
  105. // create task monitor
  106. cacheFile: cacheFile,
  107. managedMode: managedMode,
  108. jobsScheduler: task.NewJobsScheduler(cfg.Master.NumJobs),
  109. tracer: progress.NewTracer("master"),
  110. // default ranking model
  111. rankingModelName: "bpr",
  112. rankingModelSearcher: ranking.NewModelSearcher(
  113. cfg.Recommend.Collaborative.ModelSearchEpoch,
  114. cfg.Recommend.Collaborative.ModelSearchTrials,
  115. cfg.Recommend.Collaborative.EnableModelSizeSearch,
  116. ),
  117. // default click model
  118. clickModelSearcher: click.NewModelSearcher(
  119. cfg.Recommend.Collaborative.ModelSearchEpoch,
  120. cfg.Recommend.Collaborative.ModelSearchTrials,
  121. cfg.Recommend.Collaborative.EnableModelSizeSearch,
  122. ),
  123. RestServer: server.RestServer{
  124. Settings: &config.Settings{
  125. Config: cfg,
  126. CacheClient: cache.NoDatabase{},
  127. DataClient: data.NoDatabase{},
  128. RankingModel: ranking.NewBPR(nil),
  129. ClickModel: click.NewFM(click.FMClassification, nil),
  130. // init versions
  131. RankingModelVersion: rand.Int63(),
  132. ClickModelVersion: rand.Int63(),
  133. },
  134. HttpHost: cfg.Master.HttpHost,
  135. HttpPort: cfg.Master.HttpPort,
  136. WebService: new(restful.WebService),
  137. },
  138. fitTicker: time.NewTicker(cfg.Recommend.Collaborative.ModelFitPeriod),
  139. importedChan: parallel.NewConditionChannel(),
  140. loadDataChan: parallel.NewConditionChannel(),
  141. triggerChan: parallel.NewConditionChannel(),
  142. }
  143. // enable deep learning
  144. if cfg.Experimental.EnableDeepLearning {
  145. log.Logger().Debug("enable deep learning")
  146. m.ClickModel = click.NewDeepFM(model.Params{
  147. model.BatchSize: cfg.Experimental.DeepLearningBatchSize,
  148. })
  149. }
  150. return m
  151. }
  152. // Serve starts the master node.
  153. func (m *Master) Serve() {
  154. // load local cached model
  155. var err error
  156. m.localCache, err = LoadLocalCache(m.cacheFile)
  157. if err != nil {
  158. if errors.Is(err, errors.NotFound) {
  159. log.Logger().Info("no local cache found, create a new one", zap.String("path", m.cacheFile))
  160. } else {
  161. log.Logger().Error("failed to load local cache", zap.String("path", m.cacheFile), zap.Error(err))
  162. }
  163. }
  164. if m.localCache.RankingModel != nil {
  165. log.Logger().Info("load cached ranking model",
  166. zap.String("model_name", m.localCache.RankingModelName),
  167. zap.String("model_version", encoding.Hex(m.localCache.RankingModelVersion)),
  168. zap.Float32("model_score", m.localCache.RankingModelScore.NDCG),
  169. zap.Any("params", m.localCache.RankingModel.GetParams()))
  170. m.RankingModel = m.localCache.RankingModel
  171. m.rankingModelName = m.localCache.RankingModelName
  172. m.RankingModelVersion = m.localCache.RankingModelVersion
  173. m.rankingScore = m.localCache.RankingModelScore
  174. CollaborativeFilteringPrecision10.Set(float64(m.rankingScore.Precision))
  175. CollaborativeFilteringRecall10.Set(float64(m.rankingScore.Recall))
  176. CollaborativeFilteringNDCG10.Set(float64(m.rankingScore.NDCG))
  177. MemoryInUseBytesVec.WithLabelValues("collaborative_filtering_model").Set(float64(m.RankingModel.Bytes()))
  178. }
  179. if m.localCache.ClickModel != nil {
  180. log.Logger().Info("load cached click model",
  181. zap.String("model_version", encoding.Hex(m.localCache.ClickModelVersion)),
  182. zap.Float32("model_score", m.localCache.ClickModelScore.Precision),
  183. zap.Any("params", m.localCache.ClickModel.GetParams()))
  184. m.ClickModel = m.localCache.ClickModel
  185. m.clickScore = m.localCache.ClickModelScore
  186. m.ClickModelVersion = m.localCache.ClickModelVersion
  187. RankingPrecision.Set(float64(m.clickScore.Precision))
  188. RankingRecall.Set(float64(m.clickScore.Recall))
  189. RankingAUC.Set(float64(m.clickScore.AUC))
  190. MemoryInUseBytesVec.WithLabelValues("ranking_model").Set(float64(sizeof.DeepSize(m.ClickModel)))
  191. }
  192. // create cluster meta cache
  193. m.ttlCache = ttlcache.NewCache()
  194. m.ttlCache.SetExpirationCallback(m.nodeDown)
  195. m.ttlCache.SetNewItemCallback(m.nodeUp)
  196. if err = m.ttlCache.SetTTL(m.Config.Master.MetaTimeout + 10*time.Second); err != nil {
  197. log.Logger().Fatal("failed to set TTL", zap.Error(err))
  198. }
  199. // connect data database
  200. m.DataClient, err = data.Open(m.Config.Database.DataStore, m.Config.Database.DataTablePrefix)
  201. if err != nil {
  202. log.Logger().Fatal("failed to connect data database", zap.Error(err),
  203. zap.String("database", log.RedactDBURL(m.Config.Database.DataStore)))
  204. }
  205. if err = m.DataClient.Init(); err != nil {
  206. log.Logger().Fatal("failed to init database", zap.Error(err))
  207. }
  208. // connect cache database
  209. m.CacheClient, err = cache.Open(m.Config.Database.CacheStore, m.Config.Database.CacheTablePrefix)
  210. if err != nil {
  211. log.Logger().Fatal("failed to connect cache database", zap.Error(err),
  212. zap.String("database", log.RedactDBURL(m.Config.Database.CacheStore)))
  213. }
  214. if err = m.CacheClient.Init(); err != nil {
  215. log.Logger().Fatal("failed to init database", zap.Error(err))
  216. }
  217. if m.managedMode {
  218. go m.RunManagedTasksLoop()
  219. } else {
  220. go m.RunPrivilegedTasksLoop()
  221. log.Logger().Info("start model fit", zap.Duration("period", m.Config.Recommend.Collaborative.ModelFitPeriod))
  222. go m.RunRagtagTasksLoop()
  223. log.Logger().Info("start model searcher", zap.Duration("period", m.Config.Recommend.Collaborative.ModelSearchPeriod))
  224. }
  225. // start rpc server
  226. go func() {
  227. log.Logger().Info("start rpc server",
  228. zap.String("host", m.Config.Master.Host),
  229. zap.Int("port", m.Config.Master.Port))
  230. lis, err := net.Listen("tcp", fmt.Sprintf("%s:%d", m.Config.Master.Host, m.Config.Master.Port))
  231. if err != nil {
  232. log.Logger().Fatal("failed to listen", zap.Error(err))
  233. }
  234. m.grpcServer = grpc.NewServer(grpc.MaxSendMsgSize(math.MaxInt))
  235. protocol.RegisterMasterServer(m.grpcServer, m)
  236. if err = m.grpcServer.Serve(lis); err != nil {
  237. log.Logger().Fatal("failed to start rpc server", zap.Error(err))
  238. }
  239. }()
  240. // start http server
  241. m.StartHttpServer()
  242. }
  243. func (m *Master) Shutdown() {
  244. // stop http server
  245. err := m.HttpServer.Shutdown(context.TODO())
  246. if err != nil {
  247. log.Logger().Error("failed to shutdown http server", zap.Error(err))
  248. }
  249. // stop grpc server
  250. m.grpcServer.GracefulStop()
  251. }
  252. func (m *Master) RunPrivilegedTasksLoop() {
  253. defer base.CheckPanic()
  254. var (
  255. err error
  256. tasks = []Task{
  257. NewFitClickModelTask(m),
  258. NewFitRankingModelTask(m),
  259. NewFindUserNeighborsTask(m),
  260. NewFindItemNeighborsTask(m),
  261. }
  262. firstLoop = true
  263. )
  264. go func() {
  265. m.importedChan.Signal()
  266. for {
  267. if m.checkDataImported() {
  268. m.importedChan.Signal()
  269. }
  270. time.Sleep(time.Second)
  271. }
  272. }()
  273. for {
  274. select {
  275. case <-m.fitTicker.C:
  276. case <-m.importedChan.C:
  277. }
  278. // download dataset
  279. err = m.runLoadDatasetTask()
  280. if err != nil {
  281. log.Logger().Error("failed to load ranking dataset", zap.Error(err))
  282. continue
  283. }
  284. if m.rankingTrainSet.UserCount() == 0 && m.rankingTrainSet.ItemCount() == 0 && m.rankingTrainSet.Count() == 0 {
  285. log.Logger().Warn("empty ranking dataset",
  286. zap.Strings("positive_feedback_type", m.Config.Recommend.DataSource.PositiveFeedbackTypes))
  287. continue
  288. }
  289. if firstLoop {
  290. m.loadDataChan.Signal()
  291. firstLoop = false
  292. }
  293. var registeredTask []Task
  294. for _, t := range tasks {
  295. if m.jobsScheduler.Register(t.name(), t.priority(), true) {
  296. registeredTask = append(registeredTask, t)
  297. }
  298. }
  299. for _, t := range registeredTask {
  300. go func(task Task) {
  301. j := m.jobsScheduler.GetJobsAllocator(task.name())
  302. defer m.jobsScheduler.Unregister(task.name())
  303. j.Init()
  304. if err := task.run(context.Background(), j); err != nil {
  305. log.Logger().Error("failed to run task", zap.String("task", task.name()), zap.Error(err))
  306. return
  307. }
  308. }(t)
  309. }
  310. }
  311. }
  312. // RunRagtagTasksLoop searches optimal recommendation model in background. It never modifies variables other than
  313. // rankingModelSearcher, clickSearchedModel and clickSearchedScore.
  314. func (m *Master) RunRagtagTasksLoop() {
  315. defer base.CheckPanic()
  316. <-m.loadDataChan.C
  317. var (
  318. err error
  319. tasks = []Task{
  320. NewCacheGarbageCollectionTask(m),
  321. NewSearchRankingModelTask(m),
  322. NewSearchClickModelTask(m),
  323. }
  324. )
  325. for {
  326. if m.rankingTrainSet == nil || m.clickTrainSet == nil {
  327. time.Sleep(time.Second)
  328. continue
  329. }
  330. var registeredTask []Task
  331. for _, t := range tasks {
  332. if m.jobsScheduler.Register(t.name(), t.priority(), false) {
  333. registeredTask = append(registeredTask, t)
  334. }
  335. }
  336. for _, t := range registeredTask {
  337. go func(task Task) {
  338. defer m.jobsScheduler.Unregister(task.name())
  339. j := m.jobsScheduler.GetJobsAllocator(task.name())
  340. j.Init()
  341. if err = task.run(context.Background(), j); err != nil {
  342. log.Logger().Error("failed to run task", zap.String("task", task.name()), zap.Error(err))
  343. }
  344. }(t)
  345. }
  346. time.Sleep(m.Config.Recommend.Collaborative.ModelSearchPeriod)
  347. }
  348. }
  349. func (m *Master) RunManagedTasksLoop() {
  350. var (
  351. privilegedTasks = []Task{
  352. NewFitClickModelTask(m),
  353. NewFitRankingModelTask(m),
  354. NewFindUserNeighborsTask(m),
  355. NewFindItemNeighborsTask(m),
  356. }
  357. ragtagTasks = []Task{
  358. NewCacheGarbageCollectionTask(m),
  359. NewSearchRankingModelTask(m),
  360. NewSearchClickModelTask(m),
  361. }
  362. )
  363. for range m.triggerChan.C {
  364. func() {
  365. defer base.CheckPanic()
  366. searchModel := m.scheduleState.SearchModel
  367. m.scheduleState.IsRunning = true
  368. m.scheduleState.StartTime = time.Now()
  369. defer func() {
  370. m.scheduleState.IsRunning = false
  371. m.scheduleState.SearchModel = false
  372. m.scheduleState.StartTime = time.Time{}
  373. }()
  374. _ = searchModel
  375. // download dataset
  376. if err := m.runLoadDatasetTask(); err != nil {
  377. log.Logger().Error("failed to load ranking dataset", zap.Error(err))
  378. return
  379. }
  380. if m.rankingTrainSet.UserCount() == 0 && m.rankingTrainSet.ItemCount() == 0 && m.rankingTrainSet.Count() == 0 {
  381. log.Logger().Warn("empty ranking dataset",
  382. zap.Strings("positive_feedback_type", m.Config.Recommend.DataSource.PositiveFeedbackTypes))
  383. return
  384. }
  385. var registeredTask []Task
  386. for _, t := range privilegedTasks {
  387. if m.jobsScheduler.Register(t.name(), t.priority(), true) {
  388. registeredTask = append(registeredTask, t)
  389. }
  390. }
  391. if searchModel {
  392. for _, t := range ragtagTasks {
  393. if m.jobsScheduler.Register(t.name(), t.priority(), false) {
  394. registeredTask = append(registeredTask, t)
  395. }
  396. }
  397. }
  398. var wg sync.WaitGroup
  399. wg.Add(len(registeredTask))
  400. for _, t := range registeredTask {
  401. go func(task Task) {
  402. j := m.jobsScheduler.GetJobsAllocator(task.name())
  403. defer m.jobsScheduler.Unregister(task.name())
  404. defer wg.Done()
  405. j.Init()
  406. if err := task.run(context.Background(), j); err != nil {
  407. log.Logger().Error("failed to run task", zap.String("task", task.name()), zap.Error(err))
  408. return
  409. }
  410. }(t)
  411. }
  412. wg.Wait()
  413. }()
  414. }
  415. }
  416. func (m *Master) checkDataImported() bool {
  417. ctx := context.Background()
  418. isDataImported, err := m.CacheClient.Get(ctx, cache.Key(cache.GlobalMeta, cache.DataImported)).Integer()
  419. if err != nil {
  420. if !errors.Is(err, errors.NotFound) {
  421. log.Logger().Error("failed to read meta", zap.Error(err))
  422. }
  423. return false
  424. }
  425. if isDataImported > 0 {
  426. err = m.CacheClient.Set(ctx, cache.Integer(cache.Key(cache.GlobalMeta, cache.DataImported), 0))
  427. if err != nil {
  428. log.Logger().Error("failed to write meta", zap.Error(err))
  429. }
  430. return true
  431. }
  432. return false
  433. }
  434. func (m *Master) notifyDataImported() {
  435. ctx := context.Background()
  436. err := m.CacheClient.Set(ctx, cache.Integer(cache.Key(cache.GlobalMeta, cache.DataImported), 1))
  437. if err != nil {
  438. log.Logger().Error("failed to write meta", zap.Error(err))
  439. }
  440. }