123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253 |
- // Copyright 2021 gorse Project Authors
- //
- // Licensed under the Apache License, Version 2.0 (the "License");
- // you may not use this file except in compliance with the License.
- // You may obtain a copy of the License at
- //
- // http://www.apache.org/licenses/LICENSE-2.0
- //
- // Unless required by applicable law or agreed to in writing, software
- // distributed under the License is distributed on an "AS IS" BASIS,
- // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- // See the License for the specific language governing permissions and
- // limitations under the License.
- package master
- import (
- "context"
- "encoding/json"
- "io"
- "strings"
- "github.com/juju/errors"
- "github.com/zhenghaoz/gorse/base/log"
- "github.com/zhenghaoz/gorse/model/click"
- "github.com/zhenghaoz/gorse/model/ranking"
- "github.com/zhenghaoz/gorse/protocol"
- "go.uber.org/zap"
- "google.golang.org/grpc/peer"
- )
// Node describes a node registered with the master. It is either a worker
// node or a server node.
type Node struct {
	// Name is the node name reported by the node itself.
	Name string
	// Type is ServerNode or WorkerNode; it stays empty for any other node type.
	Type string
	// IP is the host part of the node's peer address.
	IP string
	// HttpPort is the HTTP port reported by the node.
	HttpPort int64
	// BinaryVersion is the binary version reported by the node.
	BinaryVersion string
}
// Values stored in Node.Type.
const (
	// ServerNode marks a node as a server node.
	ServerNode = "Server"
	// WorkerNode marks a node as a worker node.
	WorkerNode = "Worker"
)
- // NewNode creates a node from Context and NodeInfo.
- func NewNode(ctx context.Context, nodeInfo *protocol.NodeInfo) *Node {
- node := new(Node)
- node.Name = nodeInfo.NodeName
- node.HttpPort = nodeInfo.HttpPort
- node.BinaryVersion = nodeInfo.BinaryVersion
- // read address
- p, _ := peer.FromContext(ctx)
- hostAndPort := p.Addr.String()
- node.IP = strings.Split(hostAndPort, ":")[0]
- // read type
- switch nodeInfo.NodeType {
- case protocol.NodeType_ServerNode:
- node.Type = ServerNode
- case protocol.NodeType_WorkerNode:
- node.Type = WorkerNode
- }
- return node
- }
- // GetMeta returns latest configuration.
- func (m *Master) GetMeta(ctx context.Context, nodeInfo *protocol.NodeInfo) (*protocol.Meta, error) {
- // register node
- node := NewNode(ctx, nodeInfo)
- if node.Type != "" {
- if err := m.ttlCache.Set(nodeInfo.NodeName, node); err != nil {
- log.Logger().Error("failed to set ttl cache", zap.Error(err))
- return nil, err
- }
- }
- // marshall config
- s, err := json.Marshal(m.Config)
- if err != nil {
- return nil, err
- }
- // save ranking model version
- m.rankingModelMutex.RLock()
- var rankingModelVersion int64
- if m.RankingModel != nil && !m.RankingModel.Invalid() {
- rankingModelVersion = m.RankingModelVersion
- }
- m.rankingModelMutex.RUnlock()
- // save click model version
- m.clickModelMutex.RLock()
- var clickModelVersion int64
- if m.ClickModel != nil && !m.ClickModel.Invalid() {
- clickModelVersion = m.ClickModelVersion
- }
- m.clickModelMutex.RUnlock()
- // collect nodes
- workers := make([]string, 0)
- servers := make([]string, 0)
- m.nodesInfoMutex.RLock()
- for name, info := range m.nodesInfo {
- switch info.Type {
- case WorkerNode:
- workers = append(workers, name)
- case ServerNode:
- servers = append(servers, name)
- }
- }
- m.nodesInfoMutex.RUnlock()
- return &protocol.Meta{
- Config: string(s),
- RankingModelVersion: rankingModelVersion,
- ClickModelVersion: clickModelVersion,
- Me: nodeInfo.NodeName,
- Workers: workers,
- Servers: servers,
- }, nil
- }
- // GetRankingModel returns latest ranking model.
- func (m *Master) GetRankingModel(version *protocol.VersionInfo, sender protocol.Master_GetRankingModelServer) error {
- m.rankingModelMutex.RLock()
- defer m.rankingModelMutex.RUnlock()
- // skip empty model
- if m.RankingModel == nil || m.RankingModel.Invalid() {
- return errors.New("no valid model found")
- }
- // check model version
- if m.RankingModelVersion != version.Version {
- return errors.New("model version mismatch")
- }
- // encode model
- reader, writer := io.Pipe()
- var encoderError error
- go func() {
- defer func(writer *io.PipeWriter) {
- err := writer.Close()
- if err != nil {
- log.Logger().Error("fail to close pipe", zap.Error(err))
- }
- }(writer)
- err := ranking.MarshalModel(writer, m.RankingModel)
- if err != nil {
- log.Logger().Error("fail to marshal ranking model", zap.Error(err))
- encoderError = err
- return
- }
- }()
- // send model
- for {
- buf := make([]byte, batchSize)
- n, err := reader.Read(buf)
- if err == io.EOF {
- log.Logger().Debug("complete sending ranking model")
- break
- } else if err != nil {
- return err
- }
- err = sender.Send(&protocol.Fragment{Data: buf[:n]})
- if err != nil {
- return err
- }
- }
- return encoderError
- }
- // GetClickModel returns latest click model.
- func (m *Master) GetClickModel(version *protocol.VersionInfo, sender protocol.Master_GetClickModelServer) error {
- m.clickModelMutex.RLock()
- defer m.clickModelMutex.RUnlock()
- // skip empty model
- if m.ClickModel == nil || m.ClickModel.Invalid() {
- return errors.New("no valid model found")
- }
- // check empty model
- if m.ClickModelVersion != version.Version {
- return errors.New("model version mismatch")
- }
- // encode model
- reader, writer := io.Pipe()
- var encoderError error
- go func() {
- defer func(writer *io.PipeWriter) {
- err := writer.Close()
- if err != nil {
- log.Logger().Error("fail to close pipe", zap.Error(err))
- }
- }(writer)
- err := click.MarshalModel(writer, m.ClickModel)
- if err != nil {
- log.Logger().Error("fail to marshal click model", zap.Error(err))
- encoderError = err
- return
- }
- }()
- // send model
- for {
- buf := make([]byte, batchSize)
- n, err := reader.Read(buf)
- if err == io.EOF {
- log.Logger().Debug("complete sending click model")
- break
- } else if err != nil {
- return err
- }
- err = sender.Send(&protocol.Fragment{Data: buf[:n]})
- if err != nil {
- return err
- }
- }
- return encoderError
- }
- // nodeUp handles node information inserted events.
- func (m *Master) nodeUp(key string, value interface{}) {
- node := value.(*Node)
- log.Logger().Info("node up",
- zap.String("node_name", key),
- zap.String("node_ip", node.IP),
- zap.String("node_type", node.Type))
- m.nodesInfoMutex.Lock()
- defer m.nodesInfoMutex.Unlock()
- m.nodesInfo[key] = node
- }
- // nodeDown handles node information timeout events.
- func (m *Master) nodeDown(key string, value interface{}) {
- node := value.(*Node)
- log.Logger().Info("node down",
- zap.String("node_name", key),
- zap.String("node_ip", node.IP),
- zap.String("node_type", node.Type))
- m.nodesInfoMutex.Lock()
- defer m.nodesInfoMutex.Unlock()
- delete(m.nodesInfo, key)
- }
- func (m *Master) PushProgress(
- _ context.Context,
- in *protocol.PushProgressRequest) (*protocol.PushProgressResponse, error) {
- // check empty progress
- if len(in.Progress) == 0 {
- return &protocol.PushProgressResponse{}, nil
- }
- // check tracers
- tracer := in.Progress[0].Tracer
- for _, p := range in.Progress {
- if p.Tracer != tracer {
- return nil, errors.Errorf("tracers must be the same, expect %v, got %v", tracer, p.Tracer)
- }
- }
- // store progress
- m.remoteProgress.Store(tracer, protocol.DecodeProgress(in))
- return &protocol.PushProgressResponse{}, nil
- }
|