model.go 29 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894
  1. // Copyright 2021 gorse Project Authors
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. package ranking
  15. import (
  16. "context"
  17. "fmt"
  18. "io"
  19. "reflect"
  20. "time"
  21. "github.com/bits-and-blooms/bitset"
  22. "github.com/chewxy/math32"
  23. mapset "github.com/deckarep/golang-set/v2"
  24. "github.com/juju/errors"
  25. "github.com/samber/lo"
  26. "github.com/zhenghaoz/gorse/base"
  27. "github.com/zhenghaoz/gorse/base/copier"
  28. "github.com/zhenghaoz/gorse/base/encoding"
  29. "github.com/zhenghaoz/gorse/base/floats"
  30. "github.com/zhenghaoz/gorse/base/log"
  31. "github.com/zhenghaoz/gorse/base/parallel"
  32. "github.com/zhenghaoz/gorse/base/progress"
  33. "github.com/zhenghaoz/gorse/base/task"
  34. "github.com/zhenghaoz/gorse/model"
  35. "go.uber.org/zap"
  36. )
  37. type Score struct {
  38. NDCG float32
  39. Precision float32
  40. Recall float32
  41. }
  42. type FitConfig struct {
  43. *task.JobsAllocator
  44. Verbose int
  45. Candidates int
  46. TopK int
  47. }
  48. func NewFitConfig() *FitConfig {
  49. return &FitConfig{
  50. Verbose: 10,
  51. Candidates: 100,
  52. TopK: 10,
  53. }
  54. }
  55. func (config *FitConfig) SetVerbose(verbose int) *FitConfig {
  56. config.Verbose = verbose
  57. return config
  58. }
  59. func (config *FitConfig) SetJobsAllocator(allocator *task.JobsAllocator) *FitConfig {
  60. config.JobsAllocator = allocator
  61. return config
  62. }
  63. func (config *FitConfig) LoadDefaultIfNil() *FitConfig {
  64. if config == nil {
  65. return NewFitConfig()
  66. }
  67. return config
  68. }
  69. type Model interface {
  70. model.Model
  71. // Fit a model with a train set and parameters.
  72. Fit(ctx context.Context, trainSet *DataSet, validateSet *DataSet, config *FitConfig) Score
  73. // GetItemIndex returns item index.
  74. GetItemIndex() base.Index
  75. // Marshal model into byte stream.
  76. Marshal(w io.Writer) error
  77. // Unmarshal model from byte stream.
  78. Unmarshal(r io.Reader) error
  79. // GetUserFactor returns latent factor of a user.
  80. GetUserFactor(userIndex int32) []float32
  81. // GetItemFactor returns latent factor of an item.
  82. GetItemFactor(itemIndex int32) []float32
  83. }
  84. type MatrixFactorization interface {
  85. Model
  86. // Predict the rating given by a user (userId) to a item (itemId).
  87. Predict(userId, itemId string) float32
  88. // InternalPredict predicts rating given by a user index and a item index
  89. InternalPredict(userIndex, itemIndex int32) float32
  90. // GetUserIndex returns user index.
  91. GetUserIndex() base.Index
  92. // GetItemIndex returns item index.
  93. GetItemIndex() base.Index
  94. // IsUserPredictable returns false if user has no feedback and its embedding vector never be trained.
  95. IsUserPredictable(userIndex int32) bool
  96. // IsItemPredictable returns false if item has no feedback and its embedding vector never be trained.
  97. IsItemPredictable(itemIndex int32) bool
  98. // Marshal model into byte stream.
  99. Marshal(w io.Writer) error
  100. // Unmarshal model from byte stream.
  101. Unmarshal(r io.Reader) error
  102. // Bytes returns used memory.
  103. Bytes() int
  104. }
  105. type BaseMatrixFactorization struct {
  106. model.BaseModel
  107. UserIndex base.Index
  108. ItemIndex base.Index
  109. UserPredictable *bitset.BitSet
  110. ItemPredictable *bitset.BitSet
  111. // Model parameters
  112. UserFactor [][]float32 // p_u
  113. ItemFactor [][]float32 // q_i
  114. }
  115. func (baseModel *BaseMatrixFactorization) Bytes() int {
  116. bytes := reflect.TypeOf(baseModel).Elem().Size()
  117. bytes += encoding.ArrayBytes(baseModel.UserPredictable.Bytes())
  118. bytes += encoding.ArrayBytes(baseModel.ItemPredictable.Bytes())
  119. bytes += encoding.MatrixBytes(baseModel.UserFactor)
  120. bytes += encoding.MatrixBytes(baseModel.ItemFactor)
  121. return int(bytes) + baseModel.UserIndex.Bytes() + baseModel.ItemIndex.Bytes()
  122. }
  123. func (baseModel *BaseMatrixFactorization) Init(trainSet *DataSet) {
  124. baseModel.UserIndex = trainSet.UserIndex
  125. baseModel.ItemIndex = trainSet.ItemIndex
  126. // set user trained flags
  127. baseModel.UserPredictable = bitset.New(uint(trainSet.UserIndex.Len()))
  128. for userIndex := range baseModel.UserIndex.GetNames() {
  129. if len(trainSet.UserFeedback[userIndex]) > 0 {
  130. baseModel.UserPredictable.Set(uint(userIndex))
  131. }
  132. }
  133. // set item trained flags
  134. baseModel.ItemPredictable = bitset.New(uint(trainSet.ItemIndex.Len()))
  135. for itemIndex := range baseModel.ItemIndex.GetNames() {
  136. if len(trainSet.ItemFeedback[itemIndex]) > 0 {
  137. baseModel.ItemPredictable.Set(uint(itemIndex))
  138. }
  139. }
  140. }
  141. func (baseModel *BaseMatrixFactorization) GetUserIndex() base.Index {
  142. return baseModel.UserIndex
  143. }
  144. func (baseModel *BaseMatrixFactorization) GetItemIndex() base.Index {
  145. return baseModel.ItemIndex
  146. }
  147. // IsUserPredictable returns false if user has no feedback and its embedding vector never be trained.
  148. func (baseModel *BaseMatrixFactorization) IsUserPredictable(userIndex int32) bool {
  149. if userIndex >= baseModel.UserIndex.Len() {
  150. return false
  151. }
  152. return baseModel.UserPredictable.Test(uint(userIndex))
  153. }
  154. // IsItemPredictable returns false if item has no feedback and its embedding vector never be trained.
  155. func (baseModel *BaseMatrixFactorization) IsItemPredictable(itemIndex int32) bool {
  156. if itemIndex >= baseModel.ItemIndex.Len() {
  157. return false
  158. }
  159. return baseModel.ItemPredictable.Test(uint(itemIndex))
  160. }
  161. // Marshal model into byte stream.
  162. func (baseModel *BaseMatrixFactorization) Marshal(w io.Writer) error {
  163. // write params
  164. err := encoding.WriteGob(w, baseModel.Params)
  165. if err != nil {
  166. return errors.Trace(err)
  167. }
  168. // write user index
  169. err = base.MarshalIndex(w, baseModel.UserIndex)
  170. if err != nil {
  171. return errors.Trace(err)
  172. }
  173. // write item index
  174. err = base.MarshalIndex(w, baseModel.ItemIndex)
  175. if err != nil {
  176. return errors.Trace(err)
  177. }
  178. // write user predictable
  179. _, err = baseModel.UserPredictable.WriteTo(w)
  180. if err != nil {
  181. return errors.Trace(err)
  182. }
  183. // write item predictable
  184. _, err = baseModel.ItemPredictable.WriteTo(w)
  185. return errors.Trace(err)
  186. }
  187. // Unmarshal model from byte stream.
  188. func (baseModel *BaseMatrixFactorization) Unmarshal(r io.Reader) error {
  189. // read params
  190. err := encoding.ReadGob(r, &baseModel.Params)
  191. if err != nil {
  192. return errors.Trace(err)
  193. }
  194. // read user index
  195. baseModel.UserIndex, err = base.UnmarshalIndex(r)
  196. if err != nil {
  197. return errors.Trace(err)
  198. }
  199. // read item index
  200. baseModel.ItemIndex, err = base.UnmarshalIndex(r)
  201. if err != nil {
  202. return errors.Trace(err)
  203. }
  204. // read user predictable
  205. baseModel.UserPredictable = &bitset.BitSet{}
  206. _, err = baseModel.UserPredictable.ReadFrom(r)
  207. if err != nil {
  208. return errors.Trace(err)
  209. }
  210. // read item predictable
  211. baseModel.ItemPredictable = &bitset.BitSet{}
  212. _, err = baseModel.ItemPredictable.ReadFrom(r)
  213. return errors.Trace(err)
  214. }
  215. // Clone a model with deep copy.
  216. func Clone(m MatrixFactorization) MatrixFactorization {
  217. var copied MatrixFactorization
  218. if err := copier.Copy(&copied, m); err != nil {
  219. panic(err)
  220. } else {
  221. copied.SetParams(copied.GetParams())
  222. return copied
  223. }
  224. }
  225. const (
  226. CollaborativeBPR = "bpr"
  227. CollaborativeCCD = "ccd"
  228. )
  229. func GetModelName(m Model) string {
  230. switch m.(type) {
  231. case *BPR:
  232. return CollaborativeBPR
  233. case *CCD:
  234. return CollaborativeCCD
  235. default:
  236. return reflect.TypeOf(m).String()
  237. }
  238. }
  239. func MarshalModel(w io.Writer, m Model) error {
  240. if err := encoding.WriteString(w, GetModelName(m)); err != nil {
  241. return errors.Trace(err)
  242. }
  243. if err := m.Marshal(w); err != nil {
  244. return errors.Trace(err)
  245. }
  246. return nil
  247. }
  248. func UnmarshalModel(r io.Reader) (MatrixFactorization, error) {
  249. name, err := encoding.ReadString(r)
  250. if err != nil {
  251. return nil, errors.Trace(err)
  252. }
  253. switch name {
  254. case "bpr":
  255. var bpr BPR
  256. if err := bpr.Unmarshal(r); err != nil {
  257. return nil, errors.Trace(err)
  258. }
  259. return &bpr, nil
  260. case "ccd":
  261. var ccd CCD
  262. if err := ccd.Unmarshal(r); err != nil {
  263. return nil, errors.Trace(err)
  264. }
  265. return &ccd, nil
  266. }
  267. return nil, fmt.Errorf("unknown model %v", name)
  268. }
  269. // BPR means Bayesian Personal Ranking, is a pairwise learning algorithm for matrix factorization
  270. // model with implicit feedback. The pairwise ranking between item i and j for user u is estimated
  271. // by:
  272. //
  273. // p(i >_u j) = \sigma( p_u^T (q_i - q_j) )
  274. //
  275. // Hyper-parameters:
  276. //
  277. // Reg - The regularization parameter of the cost function that is
  278. // optimized. Default is 0.01.
  279. // Lr - The learning rate of SGD. Default is 0.05.
  280. // nFactors - The number of latent factors. Default is 10.
  281. // NEpochs - The number of iteration of the SGD procedure. Default is 100.
  282. // InitMean - The mean of initial random latent factors. Default is 0.
  283. // InitStdDev - The standard deviation of initial random latent factors. Default is 0.001.
  284. type BPR struct {
  285. BaseMatrixFactorization
  286. // Hyper parameters
  287. nFactors int
  288. nEpochs int
  289. lr float32
  290. reg float32
  291. initMean float32
  292. initStdDev float32
  293. }
  294. // NewBPR creates a BPR model.
  295. func NewBPR(params model.Params) *BPR {
  296. bpr := new(BPR)
  297. bpr.SetParams(params)
  298. return bpr
  299. }
  300. // GetUserFactor returns the latent factor of a user.
  301. func (bpr *BPR) GetUserFactor(userIndex int32) []float32 {
  302. return bpr.UserFactor[userIndex]
  303. }
  304. // GetItemFactor returns the latent factor of an item.
  305. func (bpr *BPR) GetItemFactor(itemIndex int32) []float32 {
  306. return bpr.ItemFactor[itemIndex]
  307. }
  308. // SetParams sets hyper-parameters of the BPR model.
  309. func (bpr *BPR) SetParams(params model.Params) {
  310. bpr.BaseMatrixFactorization.SetParams(params)
  311. // Setup hyper-parameters
  312. bpr.nFactors = bpr.Params.GetInt(model.NFactors, 16)
  313. bpr.nEpochs = bpr.Params.GetInt(model.NEpochs, 100)
  314. bpr.lr = bpr.Params.GetFloat32(model.Lr, 0.05)
  315. bpr.reg = bpr.Params.GetFloat32(model.Reg, 0.01)
  316. bpr.initMean = bpr.Params.GetFloat32(model.InitMean, 0)
  317. bpr.initStdDev = bpr.Params.GetFloat32(model.InitStdDev, 0.001)
  318. }
  319. func (bpr *BPR) GetParamsGrid(withSize bool) model.ParamsGrid {
  320. return model.ParamsGrid{
  321. model.NFactors: lo.If(withSize, []interface{}{8, 16, 32, 64}).Else([]interface{}{16}),
  322. model.Lr: []interface{}{0.001, 0.005, 0.01, 0.05, 0.1},
  323. model.Reg: []interface{}{0.001, 0.005, 0.01, 0.05, 0.1},
  324. model.InitMean: []interface{}{0},
  325. model.InitStdDev: []interface{}{0.001, 0.005, 0.01, 0.05, 0.1},
  326. }
  327. }
  328. // Predict by the BPR model.
  329. func (bpr *BPR) Predict(userId, itemId string) float32 {
  330. // Convert sparse Names to dense Names
  331. userIndex := bpr.UserIndex.ToNumber(userId)
  332. itemIndex := bpr.ItemIndex.ToNumber(itemId)
  333. if userIndex == base.NotId {
  334. log.Logger().Warn("unknown user", zap.String("user_id", userId))
  335. }
  336. if itemIndex == base.NotId {
  337. log.Logger().Warn("unknown item", zap.String("item_id", itemId))
  338. }
  339. return bpr.InternalPredict(userIndex, itemIndex)
  340. }
  341. func (bpr *BPR) InternalPredict(userIndex, itemIndex int32) float32 {
  342. ret := float32(0.0)
  343. // + q_i^Tp_u
  344. if itemIndex != base.NotId && userIndex != base.NotId {
  345. ret += floats.Dot(bpr.UserFactor[userIndex], bpr.ItemFactor[itemIndex])
  346. } else {
  347. log.Logger().Warn("unknown user or item")
  348. }
  349. return ret
  350. }
  351. // Fit the BPR model. Its task complexity is O(bpr.nEpochs).
  352. func (bpr *BPR) Fit(ctx context.Context, trainSet, valSet *DataSet, config *FitConfig) Score {
  353. config = config.LoadDefaultIfNil()
  354. log.Logger().Info("fit bpr",
  355. zap.Int("train_set_size", trainSet.Count()),
  356. zap.Int("test_set_size", valSet.Count()),
  357. zap.Any("params", bpr.GetParams()),
  358. zap.Any("config", config))
  359. bpr.Init(trainSet)
  360. // Create buffers
  361. maxJobs := config.MaxJobs()
  362. temp := base.NewMatrix32(maxJobs, bpr.nFactors)
  363. userFactor := base.NewMatrix32(maxJobs, bpr.nFactors)
  364. positiveItemFactor := base.NewMatrix32(maxJobs, bpr.nFactors)
  365. negativeItemFactor := base.NewMatrix32(maxJobs, bpr.nFactors)
  366. rng := make([]base.RandomGenerator, maxJobs)
  367. for i := 0; i < maxJobs; i++ {
  368. rng[i] = base.NewRandomGenerator(bpr.GetRandomGenerator().Int63())
  369. }
  370. // Convert array to hashmap
  371. userFeedback := make([]mapset.Set[int32], trainSet.UserCount())
  372. for u := range userFeedback {
  373. userFeedback[u] = mapset.NewSet[int32]()
  374. for _, i := range trainSet.UserFeedback[u] {
  375. userFeedback[u].Add(i)
  376. }
  377. }
  378. snapshots := SnapshotManger{}
  379. evalStart := time.Now()
  380. scores := Evaluate(bpr, valSet, trainSet, config.TopK, config.Candidates, config.AvailableJobs(), NDCG, Precision, Recall)
  381. evalTime := time.Since(evalStart)
  382. log.Logger().Debug(fmt.Sprintf("fit bpr %v/%v", 0, bpr.nEpochs),
  383. zap.String("eval_time", evalTime.String()),
  384. zap.Float32(fmt.Sprintf("NDCG@%v", config.TopK), scores[0]),
  385. zap.Float32(fmt.Sprintf("Precision@%v", config.TopK), scores[1]),
  386. zap.Float32(fmt.Sprintf("Recall@%v", config.TopK), scores[2]))
  387. snapshots.AddSnapshot(Score{NDCG: scores[0], Precision: scores[1], Recall: scores[2]}, bpr.UserFactor, bpr.ItemFactor)
  388. // Training
  389. _, span := progress.Start(ctx, "BPR.Fit", bpr.nEpochs)
  390. for epoch := 1; epoch <= bpr.nEpochs; epoch++ {
  391. fitStart := time.Now()
  392. // Training epoch
  393. numJobs := config.AvailableJobs()
  394. cost := make([]float32, numJobs)
  395. _ = parallel.Parallel(trainSet.Count(), numJobs, func(workerId, _ int) error {
  396. // Select a user
  397. var userIndex int32
  398. var ratingCount int
  399. for {
  400. userIndex = rng[workerId].Int31n(int32(trainSet.UserCount()))
  401. ratingCount = len(trainSet.UserFeedback[userIndex])
  402. if ratingCount > 0 {
  403. break
  404. }
  405. }
  406. posIndex := trainSet.UserFeedback[userIndex][rng[workerId].Intn(ratingCount)]
  407. // Select a negative sample
  408. negIndex := int32(-1)
  409. for {
  410. temp := rng[workerId].Int31n(int32(trainSet.ItemCount()))
  411. if !userFeedback[userIndex].Contains(temp) {
  412. negIndex = temp
  413. break
  414. }
  415. }
  416. diff := bpr.InternalPredict(userIndex, posIndex) - bpr.InternalPredict(userIndex, negIndex)
  417. cost[workerId] += math32.Log1p(math32.Exp(-diff))
  418. grad := math32.Exp(-diff) / (1.0 + math32.Exp(-diff))
  419. // Pairwise update
  420. copy(userFactor[workerId], bpr.UserFactor[userIndex])
  421. copy(positiveItemFactor[workerId], bpr.ItemFactor[posIndex])
  422. copy(negativeItemFactor[workerId], bpr.ItemFactor[negIndex])
  423. // Update positive item latent factor: +w_u
  424. floats.MulConstTo(userFactor[workerId], grad, temp[workerId])
  425. floats.MulConstAddTo(positiveItemFactor[workerId], -bpr.reg, temp[workerId])
  426. floats.MulConstAddTo(temp[workerId], bpr.lr, bpr.ItemFactor[posIndex])
  427. // Update negative item latent factor: -w_u
  428. floats.MulConstTo(userFactor[workerId], -grad, temp[workerId])
  429. floats.MulConstAddTo(negativeItemFactor[workerId], -bpr.reg, temp[workerId])
  430. floats.MulConstAddTo(temp[workerId], bpr.lr, bpr.ItemFactor[negIndex])
  431. // Update user latent factor: h_i-h_j
  432. floats.SubTo(positiveItemFactor[workerId], negativeItemFactor[workerId], temp[workerId])
  433. floats.MulConst(temp[workerId], grad)
  434. floats.MulConstAddTo(userFactor[workerId], -bpr.reg, temp[workerId])
  435. floats.MulConstAddTo(temp[workerId], bpr.lr, bpr.UserFactor[userIndex])
  436. return nil
  437. })
  438. fitTime := time.Since(fitStart)
  439. // Cross validation
  440. if epoch%config.Verbose == 0 || epoch == bpr.nEpochs {
  441. evalStart = time.Now()
  442. scores = Evaluate(bpr, valSet, trainSet, config.TopK, config.Candidates, config.AvailableJobs(), NDCG, Precision, Recall)
  443. evalTime = time.Since(evalStart)
  444. log.Logger().Debug(fmt.Sprintf("fit bpr %v/%v", epoch, bpr.nEpochs),
  445. zap.String("fit_time", fitTime.String()),
  446. zap.String("eval_time", evalTime.String()),
  447. zap.Float32(fmt.Sprintf("NDCG@%v", config.TopK), scores[0]),
  448. zap.Float32(fmt.Sprintf("Precision@%v", config.TopK), scores[1]),
  449. zap.Float32(fmt.Sprintf("Recall@%v", config.TopK), scores[2]))
  450. snapshots.AddSnapshot(Score{NDCG: scores[0], Precision: scores[1], Recall: scores[2]}, bpr.UserFactor, bpr.ItemFactor)
  451. }
  452. span.Add(1)
  453. }
  454. span.End()
  455. // restore best snapshot
  456. bpr.UserFactor = snapshots.BestWeights[0].([][]float32)
  457. bpr.ItemFactor = snapshots.BestWeights[1].([][]float32)
  458. log.Logger().Info("fit bpr complete",
  459. zap.Float32(fmt.Sprintf("NDCG@%v", config.TopK), snapshots.BestScore.NDCG),
  460. zap.Float32(fmt.Sprintf("Precision@%v", config.TopK), snapshots.BestScore.Precision),
  461. zap.Float32(fmt.Sprintf("Recall@%v", config.TopK), snapshots.BestScore.Recall))
  462. return snapshots.BestScore
  463. }
  464. func (bpr *BPR) Clear() {
  465. bpr.UserIndex = nil
  466. bpr.ItemIndex = nil
  467. bpr.UserFactor = nil
  468. bpr.ItemFactor = nil
  469. }
  470. func (bpr *BPR) Invalid() bool {
  471. return bpr == nil ||
  472. bpr.UserIndex == nil ||
  473. bpr.ItemIndex == nil ||
  474. bpr.UserFactor == nil ||
  475. bpr.ItemFactor == nil
  476. }
  477. func (bpr *BPR) Init(trainSet *DataSet) {
  478. // Initialize parameters
  479. newUserFactor := bpr.GetRandomGenerator().NormalMatrix(trainSet.UserCount(), bpr.nFactors, bpr.initMean, bpr.initStdDev)
  480. newItemFactor := bpr.GetRandomGenerator().NormalMatrix(trainSet.ItemCount(), bpr.nFactors, bpr.initMean, bpr.initStdDev)
  481. // Relocate parameters
  482. if bpr.UserIndex != nil {
  483. for _, userId := range trainSet.UserIndex.GetNames() {
  484. oldIndex := bpr.UserIndex.ToNumber(userId)
  485. newIndex := trainSet.UserIndex.ToNumber(userId)
  486. if oldIndex != base.NotId {
  487. newUserFactor[newIndex] = bpr.UserFactor[oldIndex]
  488. }
  489. }
  490. }
  491. if bpr.ItemIndex != nil {
  492. for _, itemId := range trainSet.ItemIndex.GetNames() {
  493. oldIndex := bpr.ItemIndex.ToNumber(itemId)
  494. newIndex := trainSet.ItemIndex.ToNumber(itemId)
  495. if oldIndex != base.NotId {
  496. newItemFactor[newIndex] = bpr.ItemFactor[oldIndex]
  497. }
  498. }
  499. }
  500. // Initialize base
  501. bpr.UserFactor = newUserFactor
  502. bpr.ItemFactor = newItemFactor
  503. bpr.BaseMatrixFactorization.Init(trainSet)
  504. }
  505. // Marshal model into byte stream.
  506. func (bpr *BPR) Marshal(w io.Writer) error {
  507. // write base
  508. err := bpr.BaseMatrixFactorization.Marshal(w)
  509. if err != nil {
  510. return errors.Trace(err)
  511. }
  512. // write user factors
  513. err = encoding.WriteMatrix(w, bpr.UserFactor)
  514. if err != nil {
  515. return errors.Trace(err)
  516. }
  517. // write item factors
  518. err = encoding.WriteMatrix(w, bpr.ItemFactor)
  519. if err != nil {
  520. return errors.Trace(err)
  521. }
  522. return nil
  523. }
  524. // Unmarshal model from byte stream.
  525. func (bpr *BPR) Unmarshal(r io.Reader) error {
  526. // read base
  527. var err error
  528. err = bpr.BaseMatrixFactorization.Unmarshal(r)
  529. if err != nil {
  530. return errors.Trace(err)
  531. }
  532. bpr.SetParams(bpr.Params)
  533. // read user factors
  534. bpr.UserFactor = base.NewMatrix32(int(bpr.UserIndex.Len()), bpr.nFactors)
  535. err = encoding.ReadMatrix(r, bpr.UserFactor)
  536. if err != nil {
  537. return errors.Trace(err)
  538. }
  539. // read item factors
  540. bpr.ItemFactor = base.NewMatrix32(int(bpr.ItemIndex.Len()), bpr.nFactors)
  541. err = encoding.ReadMatrix(r, bpr.ItemFactor)
  542. if err != nil {
  543. return errors.Trace(err)
  544. }
  545. return nil
  546. }
  547. type CCD struct {
  548. BaseMatrixFactorization
  549. // Hyper parameters
  550. nFactors int
  551. nEpochs int
  552. reg float32
  553. initMean float32
  554. initStdDev float32
  555. weight float32
  556. }
  557. // NewCCD creates a eALS model.
  558. func NewCCD(params model.Params) *CCD {
  559. fast := new(CCD)
  560. fast.SetParams(params)
  561. return fast
  562. }
  563. // GetUserFactor returns latent factor of a user.
  564. func (ccd *CCD) GetUserFactor(userIndex int32) []float32 {
  565. return ccd.UserFactor[userIndex]
  566. }
  567. // GetItemFactor returns latent factor of an item.
  568. func (ccd *CCD) GetItemFactor(itemIndex int32) []float32 {
  569. return ccd.ItemFactor[itemIndex]
  570. }
  571. // SetParams sets hyper-parameters for the ALS model.
  572. func (ccd *CCD) SetParams(params model.Params) {
  573. ccd.BaseMatrixFactorization.SetParams(params)
  574. ccd.nFactors = ccd.Params.GetInt(model.NFactors, 16)
  575. ccd.nEpochs = ccd.Params.GetInt(model.NEpochs, 50)
  576. ccd.initMean = ccd.Params.GetFloat32(model.InitMean, 0)
  577. ccd.initStdDev = ccd.Params.GetFloat32(model.InitStdDev, 0.1)
  578. ccd.reg = ccd.Params.GetFloat32(model.Reg, 0.06)
  579. ccd.weight = ccd.Params.GetFloat32(model.Alpha, 0.001)
  580. }
  581. func (ccd *CCD) GetParamsGrid(withSize bool) model.ParamsGrid {
  582. return model.ParamsGrid{
  583. model.NFactors: lo.If(withSize, []interface{}{8, 16, 32, 64}).Else([]interface{}{16}),
  584. model.InitMean: []interface{}{0},
  585. model.InitStdDev: []interface{}{0.001, 0.005, 0.01, 0.05, 0.1},
  586. model.Reg: []interface{}{0.001, 0.005, 0.01, 0.05, 0.1},
  587. model.Alpha: []interface{}{0.001, 0.005, 0.01, 0.05, 0.1},
  588. }
  589. }
  590. // Predict by the ALS model.
  591. func (ccd *CCD) Predict(userId, itemId string) float32 {
  592. userIndex := ccd.UserIndex.ToNumber(userId)
  593. itemIndex := ccd.ItemIndex.ToNumber(itemId)
  594. if userIndex == base.NotId {
  595. log.Logger().Info("unknown user:", zap.String("user_id", userId))
  596. return 0
  597. }
  598. if itemIndex == base.NotId {
  599. log.Logger().Info("unknown item:", zap.String("item_id", itemId))
  600. return 0
  601. }
  602. return ccd.InternalPredict(userIndex, itemIndex)
  603. }
  604. func (ccd *CCD) InternalPredict(userIndex, itemIndex int32) float32 {
  605. ret := float32(0.0)
  606. if itemIndex != base.NotId && userIndex != base.NotId {
  607. ret = floats.Dot(ccd.UserFactor[userIndex], ccd.ItemFactor[itemIndex])
  608. } else {
  609. log.Logger().Warn("unknown user or item")
  610. }
  611. return ret
  612. }
  613. func (ccd *CCD) Clear() {
  614. ccd.UserIndex = nil
  615. ccd.ItemIndex = nil
  616. ccd.ItemFactor = nil
  617. ccd.UserFactor = nil
  618. }
  619. func (ccd *CCD) Invalid() bool {
  620. return ccd == nil ||
  621. ccd.UserIndex == nil ||
  622. ccd.ItemIndex == nil ||
  623. ccd.ItemFactor == nil ||
  624. ccd.UserFactor == nil
  625. }
  626. func (ccd *CCD) Init(trainSet *DataSet) {
  627. // Initialize
  628. newUserFactor := ccd.GetRandomGenerator().NormalMatrix(trainSet.UserCount(), ccd.nFactors, ccd.initMean, ccd.initStdDev)
  629. newItemFactor := ccd.GetRandomGenerator().NormalMatrix(trainSet.ItemCount(), ccd.nFactors, ccd.initMean, ccd.initStdDev)
  630. // Relocate parameters
  631. if ccd.UserIndex != nil {
  632. for _, userId := range trainSet.UserIndex.GetNames() {
  633. oldIndex := ccd.UserIndex.ToNumber(userId)
  634. newIndex := trainSet.UserIndex.ToNumber(userId)
  635. if oldIndex != base.NotId {
  636. newUserFactor[newIndex] = ccd.UserFactor[oldIndex]
  637. }
  638. }
  639. }
  640. if ccd.ItemIndex != nil {
  641. for _, itemId := range trainSet.ItemIndex.GetNames() {
  642. oldIndex := ccd.ItemIndex.ToNumber(itemId)
  643. newIndex := trainSet.ItemIndex.ToNumber(itemId)
  644. if oldIndex != base.NotId {
  645. newItemFactor[newIndex] = ccd.ItemFactor[oldIndex]
  646. }
  647. }
  648. }
  649. // Initialize base
  650. ccd.UserFactor = newUserFactor
  651. ccd.ItemFactor = newItemFactor
  652. ccd.BaseMatrixFactorization.Init(trainSet)
  653. }
  654. // Fit the CCD model. Its task complexity is O(ccd.nEpochs).
  655. func (ccd *CCD) Fit(ctx context.Context, trainSet, valSet *DataSet, config *FitConfig) Score {
  656. config = config.LoadDefaultIfNil()
  657. log.Logger().Info("fit ccd",
  658. zap.Int("train_set_size", trainSet.Count()),
  659. zap.Int("test_set_size", valSet.Count()),
  660. zap.Any("params", ccd.GetParams()),
  661. zap.Any("config", config))
  662. ccd.Init(trainSet)
  663. // Create temporary matrix
  664. maxJobs := config.MaxJobs()
  665. s := base.NewMatrix32(ccd.nFactors, ccd.nFactors)
  666. userPredictions := make([][]float32, maxJobs)
  667. itemPredictions := make([][]float32, maxJobs)
  668. userRes := make([][]float32, maxJobs)
  669. itemRes := make([][]float32, maxJobs)
  670. for i := 0; i < maxJobs; i++ {
  671. userPredictions[i] = make([]float32, trainSet.ItemCount())
  672. itemPredictions[i] = make([]float32, trainSet.UserCount())
  673. userRes[i] = make([]float32, trainSet.ItemCount())
  674. itemRes[i] = make([]float32, trainSet.UserCount())
  675. }
  676. // evaluate initial model
  677. snapshots := SnapshotManger{}
  678. evalStart := time.Now()
  679. scores := Evaluate(ccd, valSet, trainSet, config.TopK, config.Candidates, config.AvailableJobs(), NDCG, Precision, Recall)
  680. evalTime := time.Since(evalStart)
  681. log.Logger().Debug(fmt.Sprintf("fit ccd %v/%v", 0, ccd.nEpochs),
  682. zap.String("eval_time", evalTime.String()),
  683. zap.Float32(fmt.Sprintf("NDCG@%v", config.TopK), scores[0]),
  684. zap.Float32(fmt.Sprintf("Precision@%v", config.TopK), scores[1]),
  685. zap.Float32(fmt.Sprintf("Recall@%v", config.TopK), scores[2]))
  686. snapshots.AddSnapshot(Score{NDCG: scores[0], Precision: scores[1], Recall: scores[2]}, ccd.UserFactor, ccd.ItemFactor)
  687. _, span := progress.Start(ctx, "CCD.Fit", ccd.nEpochs)
  688. for ep := 1; ep <= ccd.nEpochs; ep++ {
  689. fitStart := time.Now()
  690. // Update user factors
  691. // S^q <- \sum^N_{itemIndex=1} c_i q_i q_i^T
  692. floats.MatZero(s)
  693. for itemIndex := 0; itemIndex < trainSet.ItemCount(); itemIndex++ {
  694. if len(trainSet.ItemFeedback[itemIndex]) > 0 {
  695. for i := 0; i < ccd.nFactors; i++ {
  696. for j := 0; j < ccd.nFactors; j++ {
  697. s[i][j] += ccd.ItemFactor[itemIndex][i] * ccd.ItemFactor[itemIndex][j]
  698. }
  699. }
  700. }
  701. }
  702. _ = parallel.Parallel(trainSet.UserCount(), config.AvailableJobs(), func(workerId, userIndex int) error {
  703. userFeedback := trainSet.UserFeedback[userIndex]
  704. for _, i := range userFeedback {
  705. userPredictions[workerId][i] = ccd.InternalPredict(int32(userIndex), i)
  706. }
  707. for f := 0; f < ccd.nFactors; f++ {
  708. // for itemIndex \in R_u do \hat_{r}^f_{ui} <- \hat_{r}_{ui} - p_{uf]q_{if}
  709. for _, i := range userFeedback {
  710. userRes[workerId][i] = userPredictions[workerId][i] - ccd.UserFactor[userIndex][f]*ccd.ItemFactor[i][f]
  711. }
  712. // p_{uf} <-
  713. a, b, c := float32(0), float32(0), float32(0)
  714. for _, i := range userFeedback {
  715. a += (1 - (1-ccd.weight)*userRes[workerId][i]) * ccd.ItemFactor[i][f]
  716. c += (1 - ccd.weight) * ccd.ItemFactor[i][f] * ccd.ItemFactor[i][f]
  717. }
  718. for k := 0; k < ccd.nFactors; k++ {
  719. if k != f {
  720. b += ccd.weight * ccd.UserFactor[userIndex][k] * s[k][f]
  721. }
  722. }
  723. ccd.UserFactor[userIndex][f] = (a - b) / (c + ccd.weight*s[f][f] + ccd.reg)
  724. // for itemIndex \in R_u do \hat_{r}_{ui} <- \hat_{r}^f_{ui} - p_{uf]q_{if}
  725. for _, i := range userFeedback {
  726. userPredictions[workerId][i] = userRes[workerId][i] + ccd.UserFactor[userIndex][f]*ccd.ItemFactor[i][f]
  727. }
  728. }
  729. return nil
  730. })
  731. // Update item factors
  732. // S^p <- P^T P
  733. floats.MatZero(s)
  734. for userIndex := 0; userIndex < trainSet.UserCount(); userIndex++ {
  735. if len(trainSet.UserFeedback[userIndex]) > 0 {
  736. for i := 0; i < ccd.nFactors; i++ {
  737. for j := 0; j < ccd.nFactors; j++ {
  738. s[i][j] += ccd.UserFactor[userIndex][i] * ccd.UserFactor[userIndex][j]
  739. }
  740. }
  741. }
  742. }
  743. _ = parallel.Parallel(trainSet.ItemCount(), config.AvailableJobs(), func(workerId, itemIndex int) error {
  744. itemFeedback := trainSet.ItemFeedback[itemIndex]
  745. for _, u := range itemFeedback {
  746. itemPredictions[workerId][u] = ccd.InternalPredict(u, int32(itemIndex))
  747. }
  748. for f := 0; f < ccd.nFactors; f++ {
  749. // for itemIndex \in R_u do \hat_{r}^f_{ui} <- \hat_{r}_{ui} - p_{uf]q_{if}
  750. for _, u := range itemFeedback {
  751. itemRes[workerId][u] = itemPredictions[workerId][u] - ccd.UserFactor[u][f]*ccd.ItemFactor[itemIndex][f]
  752. }
  753. // q_{if} <-
  754. a, b, c := float32(0), float32(0), float32(0)
  755. for _, u := range itemFeedback {
  756. a += (1 - (1-ccd.weight)*itemRes[workerId][u]) * ccd.UserFactor[u][f]
  757. c += (1 - ccd.weight) * ccd.UserFactor[u][f] * ccd.UserFactor[u][f]
  758. }
  759. for k := 0; k < ccd.nFactors; k++ {
  760. if k != f {
  761. b += ccd.weight * ccd.ItemFactor[itemIndex][k] * s[k][f]
  762. }
  763. }
  764. ccd.ItemFactor[itemIndex][f] = (a - b) / (c + ccd.weight*s[f][f] + ccd.reg)
  765. // for itemIndex \in R_u do \hat_{r}_{ui} <- \hat_{r}^f_{ui} - p_{uf]q_{if}
  766. for _, u := range itemFeedback {
  767. itemPredictions[workerId][u] = itemRes[workerId][u] + ccd.UserFactor[u][f]*ccd.ItemFactor[itemIndex][f]
  768. }
  769. }
  770. return nil
  771. })
  772. fitTime := time.Since(fitStart)
  773. // Cross validation
  774. if ep%config.Verbose == 0 || ep == ccd.nEpochs {
  775. evalStart = time.Now()
  776. scores = Evaluate(ccd, valSet, trainSet, config.TopK, config.Candidates, config.AvailableJobs(), NDCG, Precision, Recall)
  777. evalTime = time.Since(evalStart)
  778. log.Logger().Debug(fmt.Sprintf("fit ccd %v/%v", ep, ccd.nEpochs),
  779. zap.String("fit_time", fitTime.String()),
  780. zap.String("eval_time", evalTime.String()),
  781. zap.Float32(fmt.Sprintf("NDCG@%v", config.TopK), scores[0]),
  782. zap.Float32(fmt.Sprintf("Precision@%v", config.TopK), scores[1]),
  783. zap.Float32(fmt.Sprintf("Recall@%v", config.TopK), scores[2]))
  784. snapshots.AddSnapshot(Score{NDCG: scores[0], Precision: scores[1], Recall: scores[2]}, ccd.UserFactor, ccd.ItemFactor)
  785. }
  786. span.Add(1)
  787. }
  788. span.End()
  789. // restore best snapshot
  790. ccd.UserFactor = snapshots.BestWeights[0].([][]float32)
  791. ccd.ItemFactor = snapshots.BestWeights[1].([][]float32)
  792. log.Logger().Info("fit ccd complete",
  793. zap.Float32(fmt.Sprintf("NDCG@%v", config.TopK), snapshots.BestScore.NDCG),
  794. zap.Float32(fmt.Sprintf("Precision@%v", config.TopK), snapshots.BestScore.Precision),
  795. zap.Float32(fmt.Sprintf("Recall@%v", config.TopK), snapshots.BestScore.Recall))
  796. return snapshots.BestScore
  797. }
  798. // Marshal model into byte stream.
  799. func (ccd *CCD) Marshal(w io.Writer) error {
  800. // write params
  801. err := ccd.BaseMatrixFactorization.Marshal(w)
  802. if err != nil {
  803. return errors.Trace(err)
  804. }
  805. // write user factors
  806. err = encoding.WriteMatrix(w, ccd.UserFactor)
  807. if err != nil {
  808. return errors.Trace(err)
  809. }
  810. // write item factors
  811. err = encoding.WriteMatrix(w, ccd.ItemFactor)
  812. if err != nil {
  813. return errors.Trace(err)
  814. }
  815. return nil
  816. }
  817. // Unmarshal model from byte stream.
  818. func (ccd *CCD) Unmarshal(r io.Reader) error {
  819. // read params
  820. var err error
  821. err = ccd.BaseMatrixFactorization.Unmarshal(r)
  822. if err != nil {
  823. return errors.Trace(err)
  824. }
  825. ccd.SetParams(ccd.Params)
  826. // read user factors
  827. ccd.UserFactor = base.NewMatrix32(int(ccd.UserIndex.Len()), ccd.nFactors)
  828. err = encoding.ReadMatrix(r, ccd.UserFactor)
  829. if err != nil {
  830. return errors.Trace(err)
  831. }
  832. // read item factors
  833. ccd.ItemFactor = base.NewMatrix32(int(ccd.ItemIndex.Len()), ccd.nFactors)
  834. err = encoding.ReadMatrix(r, ccd.ItemFactor)
  835. if err != nil {
  836. return errors.Trace(err)
  837. }
  838. return nil
  839. }