rpc.go 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253
  1. // Copyright 2021 gorse Project Authors
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. package master
import (
	"context"
	"encoding/json"
	"io"
	"net"
	"strings"

	"github.com/juju/errors"
	"github.com/zhenghaoz/gorse/base/log"
	"github.com/zhenghaoz/gorse/model/click"
	"github.com/zhenghaoz/gorse/model/ranking"
	"github.com/zhenghaoz/gorse/protocol"
	"go.uber.org/zap"
	"google.golang.org/grpc/peer"
)
  28. // Node could be worker node for server node.
  29. type Node struct {
  30. Name string
  31. Type string
  32. IP string
  33. HttpPort int64
  34. BinaryVersion string
  35. }
  36. const (
  37. ServerNode = "Server"
  38. WorkerNode = "Worker"
  39. )
  40. // NewNode creates a node from Context and NodeInfo.
  41. func NewNode(ctx context.Context, nodeInfo *protocol.NodeInfo) *Node {
  42. node := new(Node)
  43. node.Name = nodeInfo.NodeName
  44. node.HttpPort = nodeInfo.HttpPort
  45. node.BinaryVersion = nodeInfo.BinaryVersion
  46. // read address
  47. p, _ := peer.FromContext(ctx)
  48. hostAndPort := p.Addr.String()
  49. node.IP = strings.Split(hostAndPort, ":")[0]
  50. // read type
  51. switch nodeInfo.NodeType {
  52. case protocol.NodeType_ServerNode:
  53. node.Type = ServerNode
  54. case protocol.NodeType_WorkerNode:
  55. node.Type = WorkerNode
  56. }
  57. return node
  58. }
  59. // GetMeta returns latest configuration.
  60. func (m *Master) GetMeta(ctx context.Context, nodeInfo *protocol.NodeInfo) (*protocol.Meta, error) {
  61. // register node
  62. node := NewNode(ctx, nodeInfo)
  63. if node.Type != "" {
  64. if err := m.ttlCache.Set(nodeInfo.NodeName, node); err != nil {
  65. log.Logger().Error("failed to set ttl cache", zap.Error(err))
  66. return nil, err
  67. }
  68. }
  69. // marshall config
  70. s, err := json.Marshal(m.Config)
  71. if err != nil {
  72. return nil, err
  73. }
  74. // save ranking model version
  75. m.rankingModelMutex.RLock()
  76. var rankingModelVersion int64
  77. if m.RankingModel != nil && !m.RankingModel.Invalid() {
  78. rankingModelVersion = m.RankingModelVersion
  79. }
  80. m.rankingModelMutex.RUnlock()
  81. // save click model version
  82. m.clickModelMutex.RLock()
  83. var clickModelVersion int64
  84. if m.ClickModel != nil && !m.ClickModel.Invalid() {
  85. clickModelVersion = m.ClickModelVersion
  86. }
  87. m.clickModelMutex.RUnlock()
  88. // collect nodes
  89. workers := make([]string, 0)
  90. servers := make([]string, 0)
  91. m.nodesInfoMutex.RLock()
  92. for name, info := range m.nodesInfo {
  93. switch info.Type {
  94. case WorkerNode:
  95. workers = append(workers, name)
  96. case ServerNode:
  97. servers = append(servers, name)
  98. }
  99. }
  100. m.nodesInfoMutex.RUnlock()
  101. return &protocol.Meta{
  102. Config: string(s),
  103. RankingModelVersion: rankingModelVersion,
  104. ClickModelVersion: clickModelVersion,
  105. Me: nodeInfo.NodeName,
  106. Workers: workers,
  107. Servers: servers,
  108. }, nil
  109. }
  110. // GetRankingModel returns latest ranking model.
  111. func (m *Master) GetRankingModel(version *protocol.VersionInfo, sender protocol.Master_GetRankingModelServer) error {
  112. m.rankingModelMutex.RLock()
  113. defer m.rankingModelMutex.RUnlock()
  114. // skip empty model
  115. if m.RankingModel == nil || m.RankingModel.Invalid() {
  116. return errors.New("no valid model found")
  117. }
  118. // check model version
  119. if m.RankingModelVersion != version.Version {
  120. return errors.New("model version mismatch")
  121. }
  122. // encode model
  123. reader, writer := io.Pipe()
  124. var encoderError error
  125. go func() {
  126. defer func(writer *io.PipeWriter) {
  127. err := writer.Close()
  128. if err != nil {
  129. log.Logger().Error("fail to close pipe", zap.Error(err))
  130. }
  131. }(writer)
  132. err := ranking.MarshalModel(writer, m.RankingModel)
  133. if err != nil {
  134. log.Logger().Error("fail to marshal ranking model", zap.Error(err))
  135. encoderError = err
  136. return
  137. }
  138. }()
  139. // send model
  140. for {
  141. buf := make([]byte, batchSize)
  142. n, err := reader.Read(buf)
  143. if err == io.EOF {
  144. log.Logger().Debug("complete sending ranking model")
  145. break
  146. } else if err != nil {
  147. return err
  148. }
  149. err = sender.Send(&protocol.Fragment{Data: buf[:n]})
  150. if err != nil {
  151. return err
  152. }
  153. }
  154. return encoderError
  155. }
  156. // GetClickModel returns latest click model.
  157. func (m *Master) GetClickModel(version *protocol.VersionInfo, sender protocol.Master_GetClickModelServer) error {
  158. m.clickModelMutex.RLock()
  159. defer m.clickModelMutex.RUnlock()
  160. // skip empty model
  161. if m.ClickModel == nil || m.ClickModel.Invalid() {
  162. return errors.New("no valid model found")
  163. }
  164. // check empty model
  165. if m.ClickModelVersion != version.Version {
  166. return errors.New("model version mismatch")
  167. }
  168. // encode model
  169. reader, writer := io.Pipe()
  170. var encoderError error
  171. go func() {
  172. defer func(writer *io.PipeWriter) {
  173. err := writer.Close()
  174. if err != nil {
  175. log.Logger().Error("fail to close pipe", zap.Error(err))
  176. }
  177. }(writer)
  178. err := click.MarshalModel(writer, m.ClickModel)
  179. if err != nil {
  180. log.Logger().Error("fail to marshal click model", zap.Error(err))
  181. encoderError = err
  182. return
  183. }
  184. }()
  185. // send model
  186. for {
  187. buf := make([]byte, batchSize)
  188. n, err := reader.Read(buf)
  189. if err == io.EOF {
  190. log.Logger().Debug("complete sending click model")
  191. break
  192. } else if err != nil {
  193. return err
  194. }
  195. err = sender.Send(&protocol.Fragment{Data: buf[:n]})
  196. if err != nil {
  197. return err
  198. }
  199. }
  200. return encoderError
  201. }
  202. // nodeUp handles node information inserted events.
  203. func (m *Master) nodeUp(key string, value interface{}) {
  204. node := value.(*Node)
  205. log.Logger().Info("node up",
  206. zap.String("node_name", key),
  207. zap.String("node_ip", node.IP),
  208. zap.String("node_type", node.Type))
  209. m.nodesInfoMutex.Lock()
  210. defer m.nodesInfoMutex.Unlock()
  211. m.nodesInfo[key] = node
  212. }
  213. // nodeDown handles node information timeout events.
  214. func (m *Master) nodeDown(key string, value interface{}) {
  215. node := value.(*Node)
  216. log.Logger().Info("node down",
  217. zap.String("node_name", key),
  218. zap.String("node_ip", node.IP),
  219. zap.String("node_type", node.Type))
  220. m.nodesInfoMutex.Lock()
  221. defer m.nodesInfoMutex.Unlock()
  222. delete(m.nodesInfo, key)
  223. }
  224. func (m *Master) PushProgress(
  225. _ context.Context,
  226. in *protocol.PushProgressRequest) (*protocol.PushProgressResponse, error) {
  227. // check empty progress
  228. if len(in.Progress) == 0 {
  229. return &protocol.PushProgressResponse{}, nil
  230. }
  231. // check tracers
  232. tracer := in.Progress[0].Tracer
  233. for _, p := range in.Progress {
  234. if p.Tracer != tracer {
  235. return nil, errors.Errorf("tracers must be the same, expect %v, got %v", tracer, p.Tracer)
  236. }
  237. }
  238. // store progress
  239. m.remoteProgress.Store(tracer, protocol.DecodeProgress(in))
  240. return &protocol.PushProgressResponse{}, nil
  241. }