local_cache.go 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168
  1. // Copyright 2021 gorse Project Authors
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. package master
  15. import (
  16. "encoding/binary"
  17. std_errors "errors"
  18. "github.com/juju/errors"
  19. "github.com/zhenghaoz/gorse/base/encoding"
  20. "github.com/zhenghaoz/gorse/base/log"
  21. "github.com/zhenghaoz/gorse/model/click"
  22. "github.com/zhenghaoz/gorse/model/ranking"
  23. "go.uber.org/zap"
  24. "os"
  25. "path/filepath"
  26. )
  27. // LocalCache is local cache for the master node.
  28. type LocalCache struct {
  29. path string
  30. RankingModelName string
  31. RankingModelVersion int64
  32. RankingModel ranking.MatrixFactorization
  33. RankingModelScore ranking.Score
  34. ClickModelVersion int64
  35. ClickModelScore click.Score
  36. ClickModel click.FactorizationMachine
  37. }
  38. // LoadLocalCache loads local cache from a file.
  39. // If the ranking model is invalid, RankingModel == nil.
  40. // If the click model is invalid, ClickModel == nil.
  41. func LoadLocalCache(path string) (*LocalCache, error) {
  42. log.Logger().Info("load cache", zap.String("path", path))
  43. state := &LocalCache{path: path}
  44. // check if file exists
  45. if _, err := os.Stat(path); err != nil {
  46. if std_errors.Is(err, os.ErrNotExist) {
  47. return state, errors.NotFoundf("cache folder %s", path)
  48. }
  49. return state, errors.Trace(err)
  50. }
  51. // open file
  52. f, err := os.Open(state.GetFilePath(ModelFile))
  53. if err != nil {
  54. return state, errors.Trace(err)
  55. }
  56. defer func(f *os.File) {
  57. err = f.Close()
  58. if err != nil {
  59. log.Logger().Error("fail to close file", zap.Error(err))
  60. }
  61. }(f)
  62. // 1. ranking model name
  63. state.RankingModelName, err = encoding.ReadString(f)
  64. if err != nil {
  65. return state, errors.Trace(err)
  66. }
  67. // 2. ranking model version
  68. err = binary.Read(f, binary.LittleEndian, &state.RankingModelVersion)
  69. if err != nil {
  70. return state, errors.Trace(err)
  71. }
  72. // 3. ranking model
  73. state.RankingModel, err = ranking.UnmarshalModel(f)
  74. if err != nil {
  75. return state, errors.Trace(err)
  76. }
  77. // 4. ranking model score
  78. err = encoding.ReadGob(f, &state.RankingModelScore)
  79. if err != nil {
  80. return state, errors.Trace(err)
  81. }
  82. // 7. click model version
  83. err = binary.Read(f, binary.LittleEndian, &state.ClickModelVersion)
  84. if err != nil {
  85. return state, errors.Trace(err)
  86. }
  87. // 8. click model score
  88. err = encoding.ReadGob(f, &state.ClickModelScore)
  89. if err != nil {
  90. return state, errors.Trace(err)
  91. }
  92. // 9. click model
  93. state.ClickModel, err = click.UnmarshalModel(f)
  94. if err != nil {
  95. return state, errors.Trace(err)
  96. }
  97. return state, nil
  98. }
  99. // WriteLocalCache writes local cache to a file.
  100. func (c *LocalCache) WriteLocalCache() error {
  101. // create parent folder if not exists
  102. if _, err := os.Stat(c.path); os.IsNotExist(err) {
  103. err = os.MkdirAll(c.path, os.ModePerm)
  104. if err != nil {
  105. return errors.Trace(err)
  106. }
  107. }
  108. // create file
  109. f, err := os.Create(c.GetFilePath(ModelFile))
  110. if err != nil {
  111. return errors.Trace(err)
  112. }
  113. defer func(f *os.File) {
  114. err = f.Close()
  115. if err != nil {
  116. log.Logger().Error("fail to close file", zap.Error(err))
  117. }
  118. }(f)
  119. // 1. ranking model name
  120. err = encoding.WriteString(f, c.RankingModelName)
  121. if err != nil {
  122. return errors.Trace(err)
  123. }
  124. // 2. ranking model version
  125. err = binary.Write(f, binary.LittleEndian, c.RankingModelVersion)
  126. if err != nil {
  127. return errors.Trace(err)
  128. }
  129. // 3. ranking model
  130. err = ranking.MarshalModel(f, c.RankingModel)
  131. if err != nil {
  132. return errors.Trace(err)
  133. }
  134. // 4. ranking model score
  135. err = encoding.WriteGob(f, c.RankingModelScore)
  136. if err != nil {
  137. return errors.Trace(err)
  138. }
  139. // 7. click model version
  140. err = binary.Write(f, binary.LittleEndian, c.ClickModelVersion)
  141. if err != nil {
  142. return errors.Trace(err)
  143. }
  144. // 8. click model score
  145. err = encoding.WriteGob(f, c.ClickModelScore)
  146. if err != nil {
  147. return errors.Trace(err)
  148. }
  149. // 9. click model
  150. err = click.MarshalModel(f, c.ClickModel)
  151. if err != nil {
  152. return errors.Trace(err)
  153. }
  154. return nil
  155. }
  156. const (
  157. ModelFile = "model.bin"
  158. )
  159. func (c *LocalCache) GetFilePath(file string) string {
  160. return filepath.Join(c.path, file)
  161. }