123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479 |
- // Copyright 2020 gorse Project Authors
- //
- // Licensed under the Apache License, Version 2.0 (the "License");
- // you may not use this file except in compliance with the License.
- // You may obtain a copy of the License at
- //
- // http://www.apache.org/licenses/LICENSE-2.0
- //
- // Unless required by applicable law or agreed to in writing, software
- // distributed under the License is distributed on an "AS IS" BASIS,
- // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- // See the License for the specific language governing permissions and
- // limitations under the License.
- package master
- import (
- "context"
- "fmt"
- "math"
- "math/rand"
- "net"
- "net/http"
- "sync"
- "time"
- "github.com/ReneKroon/ttlcache/v2"
- "github.com/emicklei/go-restful/v3"
- "github.com/juju/errors"
- "github.com/zhenghaoz/gorse/base"
- "github.com/zhenghaoz/gorse/base/encoding"
- "github.com/zhenghaoz/gorse/base/log"
- "github.com/zhenghaoz/gorse/base/parallel"
- "github.com/zhenghaoz/gorse/base/progress"
- "github.com/zhenghaoz/gorse/base/sizeof"
- "github.com/zhenghaoz/gorse/base/task"
- "github.com/zhenghaoz/gorse/config"
- "github.com/zhenghaoz/gorse/model"
- "github.com/zhenghaoz/gorse/model/click"
- "github.com/zhenghaoz/gorse/model/ranking"
- "github.com/zhenghaoz/gorse/protocol"
- "github.com/zhenghaoz/gorse/server"
- "github.com/zhenghaoz/gorse/storage/cache"
- "github.com/zhenghaoz/gorse/storage/data"
- "go.opentelemetry.io/otel"
- "go.opentelemetry.io/otel/propagation"
- "go.uber.org/zap"
- "google.golang.org/grpc"
- )
// ScheduleState describes the current managed-mode round driven by
// RunManagedTasksLoop. The JSON tags define its serialized form (presumably
// exposed by an HTTP handler elsewhere in the package — confirm).
type ScheduleState struct {
	// IsRunning is set to true while a managed round is executing and reset to
	// false when the round finishes.
	IsRunning bool `json:"is_running"`
	// SearchModel, read at the start of a round, additionally schedules the
	// model-search (ragtag) tasks; it is reset to false after every round.
	SearchModel bool `json:"search_model"`
	// StartTime is when the current round started; zero when no round runs.
	StartTime time.Time `json:"start_time"`
}
// Master is the master node.
//
// It embeds protocol.UnimplementedMasterServer (the gRPC service registered in
// Serve) and server.RestServer (the HTTP server together with the shared
// config.Settings: configuration, database clients, current models and their
// versions).
type Master struct {
	protocol.UnimplementedMasterServer
	server.RestServer
	// gRPC server started by Serve and gracefully stopped by Shutdown.
	grpcServer *grpc.Server
	// Progress tracer named "master" (created in NewMaster).
	tracer *progress.Tracer
	// NOTE(review): not used in this file; presumably progress reported by
	// remote nodes — confirm.
	remoteProgress sync.Map
	// Scheduler the task loops register tasks with and obtain job allocators from.
	jobsScheduler *task.JobsScheduler
	// Path of the local model cache loaded by Serve.
	cacheFile string
	// When true, Serve runs RunManagedTasksLoop instead of the periodic loops.
	managedMode bool
	// cluster meta cache
	// TTL cache created in Serve: new items invoke nodeUp, expirations nodeDown.
	ttlCache *ttlcache.Cache
	// NOTE(review): presumably registered nodes keyed by name/address, guarded
	// by nodesInfoMutex — confirm against nodeUp/nodeDown.
	nodesInfo      map[string]*Node
	nodesInfoMutex sync.RWMutex
	// ranking dataset
	// Populated by runLoadDatasetTask (not shown); task loops check the train set.
	rankingTrainSet  *ranking.DataSet
	rankingTestSet   *ranking.DataSet
	rankingDataMutex sync.RWMutex
	// click dataset
	clickTrainSet  *click.Dataset
	clickTestSet   *click.Dataset
	clickDataMutex sync.RWMutex
	// ranking model
	// Name and score of the current ranking model; Serve restores them from
	// the local cache when available.
	rankingModelName     string
	rankingScore         ranking.Score
	rankingModelMutex    sync.RWMutex
	rankingModelSearcher *ranking.ModelSearcher
	// click model
	clickScore         click.Score
	clickModelMutex    sync.RWMutex
	clickModelSearcher *click.ModelSearcher
	// Local model cache loaded from cacheFile in Serve.
	localCache *LocalCache
	// events
	// Fires every Recommend.Collaborative.ModelFitPeriod to wake
	// RunPrivilegedTasksLoop.
	fitTicker    *time.Ticker
	importedChan *parallel.ConditionChannel // feedback inserted events
	loadDataChan *parallel.ConditionChannel // dataset loaded events
	triggerChan  *parallel.ConditionChannel // manually trigger events
	// State of the current managed-mode round (see RunManagedTasksLoop).
	scheduleState ScheduleState
	// NOTE(review): not used in this file; presumably the HTTP handler for
	// worker scheduling — confirm.
	workerScheduleHandler http.HandlerFunc
}
- // NewMaster creates a master node.
- func NewMaster(cfg *config.Config, cacheFile string, managedMode bool) *Master {
- rand.Seed(time.Now().UnixNano())
- // setup trace provider
- tp, err := cfg.Tracing.NewTracerProvider()
- if err != nil {
- log.Logger().Fatal("failed to create trace provider", zap.Error(err))
- }
- otel.SetTracerProvider(tp)
- otel.SetErrorHandler(log.GetErrorHandler())
- otel.SetTextMapPropagator(propagation.NewCompositeTextMapPropagator(propagation.TraceContext{}, propagation.Baggage{}))
- m := &Master{
- nodesInfo: make(map[string]*Node),
- // create task monitor
- cacheFile: cacheFile,
- managedMode: managedMode,
- jobsScheduler: task.NewJobsScheduler(cfg.Master.NumJobs),
- tracer: progress.NewTracer("master"),
- // default ranking model
- rankingModelName: "bpr",
- rankingModelSearcher: ranking.NewModelSearcher(
- cfg.Recommend.Collaborative.ModelSearchEpoch,
- cfg.Recommend.Collaborative.ModelSearchTrials,
- cfg.Recommend.Collaborative.EnableModelSizeSearch,
- ),
- // default click model
- clickModelSearcher: click.NewModelSearcher(
- cfg.Recommend.Collaborative.ModelSearchEpoch,
- cfg.Recommend.Collaborative.ModelSearchTrials,
- cfg.Recommend.Collaborative.EnableModelSizeSearch,
- ),
- RestServer: server.RestServer{
- Settings: &config.Settings{
- Config: cfg,
- CacheClient: cache.NoDatabase{},
- DataClient: data.NoDatabase{},
- RankingModel: ranking.NewBPR(nil),
- ClickModel: click.NewFM(click.FMClassification, nil),
- // init versions
- RankingModelVersion: rand.Int63(),
- ClickModelVersion: rand.Int63(),
- },
- HttpHost: cfg.Master.HttpHost,
- HttpPort: cfg.Master.HttpPort,
- WebService: new(restful.WebService),
- },
- fitTicker: time.NewTicker(cfg.Recommend.Collaborative.ModelFitPeriod),
- importedChan: parallel.NewConditionChannel(),
- loadDataChan: parallel.NewConditionChannel(),
- triggerChan: parallel.NewConditionChannel(),
- }
- // enable deep learning
- if cfg.Experimental.EnableDeepLearning {
- log.Logger().Debug("enable deep learning")
- m.ClickModel = click.NewDeepFM(model.Params{
- model.BatchSize: cfg.Experimental.DeepLearningBatchSize,
- })
- }
- return m
- }
- // Serve starts the master node.
- func (m *Master) Serve() {
- // load local cached model
- var err error
- m.localCache, err = LoadLocalCache(m.cacheFile)
- if err != nil {
- if errors.Is(err, errors.NotFound) {
- log.Logger().Info("no local cache found, create a new one", zap.String("path", m.cacheFile))
- } else {
- log.Logger().Error("failed to load local cache", zap.String("path", m.cacheFile), zap.Error(err))
- }
- }
- if m.localCache.RankingModel != nil {
- log.Logger().Info("load cached ranking model",
- zap.String("model_name", m.localCache.RankingModelName),
- zap.String("model_version", encoding.Hex(m.localCache.RankingModelVersion)),
- zap.Float32("model_score", m.localCache.RankingModelScore.NDCG),
- zap.Any("params", m.localCache.RankingModel.GetParams()))
- m.RankingModel = m.localCache.RankingModel
- m.rankingModelName = m.localCache.RankingModelName
- m.RankingModelVersion = m.localCache.RankingModelVersion
- m.rankingScore = m.localCache.RankingModelScore
- CollaborativeFilteringPrecision10.Set(float64(m.rankingScore.Precision))
- CollaborativeFilteringRecall10.Set(float64(m.rankingScore.Recall))
- CollaborativeFilteringNDCG10.Set(float64(m.rankingScore.NDCG))
- MemoryInUseBytesVec.WithLabelValues("collaborative_filtering_model").Set(float64(m.RankingModel.Bytes()))
- }
- if m.localCache.ClickModel != nil {
- log.Logger().Info("load cached click model",
- zap.String("model_version", encoding.Hex(m.localCache.ClickModelVersion)),
- zap.Float32("model_score", m.localCache.ClickModelScore.Precision),
- zap.Any("params", m.localCache.ClickModel.GetParams()))
- m.ClickModel = m.localCache.ClickModel
- m.clickScore = m.localCache.ClickModelScore
- m.ClickModelVersion = m.localCache.ClickModelVersion
- RankingPrecision.Set(float64(m.clickScore.Precision))
- RankingRecall.Set(float64(m.clickScore.Recall))
- RankingAUC.Set(float64(m.clickScore.AUC))
- MemoryInUseBytesVec.WithLabelValues("ranking_model").Set(float64(sizeof.DeepSize(m.ClickModel)))
- }
- // create cluster meta cache
- m.ttlCache = ttlcache.NewCache()
- m.ttlCache.SetExpirationCallback(m.nodeDown)
- m.ttlCache.SetNewItemCallback(m.nodeUp)
- if err = m.ttlCache.SetTTL(m.Config.Master.MetaTimeout + 10*time.Second); err != nil {
- log.Logger().Fatal("failed to set TTL", zap.Error(err))
- }
- // connect data database
- m.DataClient, err = data.Open(m.Config.Database.DataStore, m.Config.Database.DataTablePrefix)
- if err != nil {
- log.Logger().Fatal("failed to connect data database", zap.Error(err),
- zap.String("database", log.RedactDBURL(m.Config.Database.DataStore)))
- }
- if err = m.DataClient.Init(); err != nil {
- log.Logger().Fatal("failed to init database", zap.Error(err))
- }
- // connect cache database
- m.CacheClient, err = cache.Open(m.Config.Database.CacheStore, m.Config.Database.CacheTablePrefix)
- if err != nil {
- log.Logger().Fatal("failed to connect cache database", zap.Error(err),
- zap.String("database", log.RedactDBURL(m.Config.Database.CacheStore)))
- }
- if err = m.CacheClient.Init(); err != nil {
- log.Logger().Fatal("failed to init database", zap.Error(err))
- }
- if m.managedMode {
- go m.RunManagedTasksLoop()
- } else {
- go m.RunPrivilegedTasksLoop()
- log.Logger().Info("start model fit", zap.Duration("period", m.Config.Recommend.Collaborative.ModelFitPeriod))
- go m.RunRagtagTasksLoop()
- log.Logger().Info("start model searcher", zap.Duration("period", m.Config.Recommend.Collaborative.ModelSearchPeriod))
- }
- // start rpc server
- go func() {
- log.Logger().Info("start rpc server",
- zap.String("host", m.Config.Master.Host),
- zap.Int("port", m.Config.Master.Port))
- lis, err := net.Listen("tcp", fmt.Sprintf("%s:%d", m.Config.Master.Host, m.Config.Master.Port))
- if err != nil {
- log.Logger().Fatal("failed to listen", zap.Error(err))
- }
- m.grpcServer = grpc.NewServer(grpc.MaxSendMsgSize(math.MaxInt))
- protocol.RegisterMasterServer(m.grpcServer, m)
- if err = m.grpcServer.Serve(lis); err != nil {
- log.Logger().Fatal("failed to start rpc server", zap.Error(err))
- }
- }()
- // start http server
- m.StartHttpServer()
- }
- func (m *Master) Shutdown() {
- // stop http server
- err := m.HttpServer.Shutdown(context.TODO())
- if err != nil {
- log.Logger().Error("failed to shutdown http server", zap.Error(err))
- }
- // stop grpc server
- m.grpcServer.GracefulStop()
- }
- func (m *Master) RunPrivilegedTasksLoop() {
- defer base.CheckPanic()
- var (
- err error
- tasks = []Task{
- NewFitClickModelTask(m),
- NewFitRankingModelTask(m),
- NewFindUserNeighborsTask(m),
- NewFindItemNeighborsTask(m),
- }
- firstLoop = true
- )
- go func() {
- m.importedChan.Signal()
- for {
- if m.checkDataImported() {
- m.importedChan.Signal()
- }
- time.Sleep(time.Second)
- }
- }()
- for {
- select {
- case <-m.fitTicker.C:
- case <-m.importedChan.C:
- }
- // download dataset
- err = m.runLoadDatasetTask()
- if err != nil {
- log.Logger().Error("failed to load ranking dataset", zap.Error(err))
- continue
- }
- if m.rankingTrainSet.UserCount() == 0 && m.rankingTrainSet.ItemCount() == 0 && m.rankingTrainSet.Count() == 0 {
- log.Logger().Warn("empty ranking dataset",
- zap.Strings("positive_feedback_type", m.Config.Recommend.DataSource.PositiveFeedbackTypes))
- continue
- }
- if firstLoop {
- m.loadDataChan.Signal()
- firstLoop = false
- }
- var registeredTask []Task
- for _, t := range tasks {
- if m.jobsScheduler.Register(t.name(), t.priority(), true) {
- registeredTask = append(registeredTask, t)
- }
- }
- for _, t := range registeredTask {
- go func(task Task) {
- j := m.jobsScheduler.GetJobsAllocator(task.name())
- defer m.jobsScheduler.Unregister(task.name())
- j.Init()
- if err := task.run(context.Background(), j); err != nil {
- log.Logger().Error("failed to run task", zap.String("task", task.name()), zap.Error(err))
- return
- }
- }(t)
- }
- }
- }
- // RunRagtagTasksLoop searches optimal recommendation model in background. It never modifies variables other than
- // rankingModelSearcher, clickSearchedModel and clickSearchedScore.
- func (m *Master) RunRagtagTasksLoop() {
- defer base.CheckPanic()
- <-m.loadDataChan.C
- var (
- err error
- tasks = []Task{
- NewCacheGarbageCollectionTask(m),
- NewSearchRankingModelTask(m),
- NewSearchClickModelTask(m),
- }
- )
- for {
- if m.rankingTrainSet == nil || m.clickTrainSet == nil {
- time.Sleep(time.Second)
- continue
- }
- var registeredTask []Task
- for _, t := range tasks {
- if m.jobsScheduler.Register(t.name(), t.priority(), false) {
- registeredTask = append(registeredTask, t)
- }
- }
- for _, t := range registeredTask {
- go func(task Task) {
- defer m.jobsScheduler.Unregister(task.name())
- j := m.jobsScheduler.GetJobsAllocator(task.name())
- j.Init()
- if err = task.run(context.Background(), j); err != nil {
- log.Logger().Error("failed to run task", zap.String("task", task.name()), zap.Error(err))
- }
- }(t)
- }
- time.Sleep(m.Config.Recommend.Collaborative.ModelSearchPeriod)
- }
- }
- func (m *Master) RunManagedTasksLoop() {
- var (
- privilegedTasks = []Task{
- NewFitClickModelTask(m),
- NewFitRankingModelTask(m),
- NewFindUserNeighborsTask(m),
- NewFindItemNeighborsTask(m),
- }
- ragtagTasks = []Task{
- NewCacheGarbageCollectionTask(m),
- NewSearchRankingModelTask(m),
- NewSearchClickModelTask(m),
- }
- )
- for range m.triggerChan.C {
- func() {
- defer base.CheckPanic()
- searchModel := m.scheduleState.SearchModel
- m.scheduleState.IsRunning = true
- m.scheduleState.StartTime = time.Now()
- defer func() {
- m.scheduleState.IsRunning = false
- m.scheduleState.SearchModel = false
- m.scheduleState.StartTime = time.Time{}
- }()
- _ = searchModel
- // download dataset
- if err := m.runLoadDatasetTask(); err != nil {
- log.Logger().Error("failed to load ranking dataset", zap.Error(err))
- return
- }
- if m.rankingTrainSet.UserCount() == 0 && m.rankingTrainSet.ItemCount() == 0 && m.rankingTrainSet.Count() == 0 {
- log.Logger().Warn("empty ranking dataset",
- zap.Strings("positive_feedback_type", m.Config.Recommend.DataSource.PositiveFeedbackTypes))
- return
- }
- var registeredTask []Task
- for _, t := range privilegedTasks {
- if m.jobsScheduler.Register(t.name(), t.priority(), true) {
- registeredTask = append(registeredTask, t)
- }
- }
- if searchModel {
- for _, t := range ragtagTasks {
- if m.jobsScheduler.Register(t.name(), t.priority(), false) {
- registeredTask = append(registeredTask, t)
- }
- }
- }
- var wg sync.WaitGroup
- wg.Add(len(registeredTask))
- for _, t := range registeredTask {
- go func(task Task) {
- j := m.jobsScheduler.GetJobsAllocator(task.name())
- defer m.jobsScheduler.Unregister(task.name())
- defer wg.Done()
- j.Init()
- if err := task.run(context.Background(), j); err != nil {
- log.Logger().Error("failed to run task", zap.String("task", task.name()), zap.Error(err))
- return
- }
- }(t)
- }
- wg.Wait()
- }()
- }
- }
- func (m *Master) checkDataImported() bool {
- ctx := context.Background()
- isDataImported, err := m.CacheClient.Get(ctx, cache.Key(cache.GlobalMeta, cache.DataImported)).Integer()
- if err != nil {
- if !errors.Is(err, errors.NotFound) {
- log.Logger().Error("failed to read meta", zap.Error(err))
- }
- return false
- }
- if isDataImported > 0 {
- err = m.CacheClient.Set(ctx, cache.Integer(cache.Key(cache.GlobalMeta, cache.DataImported), 0))
- if err != nil {
- log.Logger().Error("failed to write meta", zap.Error(err))
- }
- return true
- }
- return false
- }
- func (m *Master) notifyDataImported() {
- ctx := context.Background()
- err := m.CacheClient.Set(ctx, cache.Integer(cache.Key(cache.GlobalMeta, cache.DataImported), 1))
- if err != nil {
- log.Logger().Error("failed to write meta", zap.Error(err))
- }
- }
|