evaluator.go 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198
  1. // Copyright 2020 gorse Project Authors
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. package ranking
  15. import (
  16. "github.com/chewxy/math32"
  17. mapset "github.com/deckarep/golang-set/v2"
  18. "github.com/thoas/go-funk"
  19. "github.com/zhenghaoz/gorse/base/copier"
  20. "github.com/zhenghaoz/gorse/base/floats"
  21. "github.com/zhenghaoz/gorse/base/heap"
  22. "github.com/zhenghaoz/gorse/base/parallel"
  23. )
/* Evaluate Item Ranking */

// Metric is used by evaluators in personalized ranking tasks. It scores a
// ranked recommendation list (rankList) against the set of relevant items
// for the user (targetSet) and returns the metric value.
type Metric func(targetSet mapset.Set[int32], rankList []int32) float32
// Evaluate evaluates a model in top-n tasks.
//
// For every user with at least one feedback item in testSet, the user's test
// items plus pre-sampled negatives are ranked by the estimator, and each
// scorer is applied to the resulting top-K list. The returned slice holds one
// value per scorer: the mean of that metric over all evaluated users.
func Evaluate(estimator MatrixFactorization, testSet, trainSet *DataSet, topK, numCandidates, nJobs int, scorers ...Metric) []float32 {
	// Per-worker accumulators to avoid locking: partSum[w][s] is worker w's
	// running sum for scorer s, partCount[w] is how many users worker w scored.
	partSum := make([][]float32, nJobs)
	partCount := make([]float32, nJobs)
	for i := 0; i < nJobs; i++ {
		partSum[i] = make([]float32, len(scorers))
	}
	//rng := NewRandomGenerator(0)
	// For all UserFeedback
	// Negatives are sampled once up front so workers only read shared state.
	negatives := testSet.NegativeSample(trainSet, numCandidates)
	_ = parallel.Parallel(testSet.UserCount(), nJobs, func(workerId, userIndex int) error {
		// Find top-n ItemFeedback in test set
		targetSet := mapset.NewSet(testSet.UserFeedback[userIndex]...)
		if targetSet.Cardinality() > 0 {
			// Sample negative samples
			//userTrainSet := NewSet(trainSet.UserFeedback[userIndex])
			negativeSample := negatives[userIndex]
			// Candidates = the user's positive test items + sampled negatives.
			candidates := make([]int32, 0, targetSet.Cardinality()+len(negativeSample))
			candidates = append(candidates, testSet.UserFeedback[userIndex]...)
			candidates = append(candidates, negativeSample...)
			// Find top-n ItemFeedback in predictions
			rankList, _ := Rank(estimator, int32(userIndex), candidates, topK)
			partCount[workerId]++
			for i, metric := range scorers {
				partSum[workerId][i] += metric(targetSet, rankList)
			}
		}
		return nil
	})
	// Reduce the per-worker partial sums, then normalize by the number of
	// evaluated users to obtain per-scorer means.
	sum := make([]float32, len(scorers))
	for i := 0; i < nJobs; i++ {
		for j := range partSum[i] {
			sum[j] += partSum[i][j]
		}
	}
	count := funk.SumFloat32(partCount)
	// NOTE(review): if no user has test feedback, count is 0 and the result is
	// NaN — confirm callers always pass a non-empty test set.
	floats.MulConst(sum, 1/count)
	return sum
}
  66. // NDCG means Normalized Discounted Cumulative Gain.
  67. func NDCG(targetSet mapset.Set[int32], rankList []int32) float32 {
  68. // IDCG = \sum^{|REL|}_{i=1} \frac {1} {\log_2(i+1)}
  69. idcg := float32(0)
  70. for i := 0; i < targetSet.Cardinality() && i < len(rankList); i++ {
  71. idcg += 1.0 / math32.Log2(float32(i)+2.0)
  72. }
  73. // DCG = \sum^{N}_{i=1} \frac {2^{rel_i}-1} {\log_2(i+1)}
  74. dcg := float32(0)
  75. for i, itemId := range rankList {
  76. if targetSet.Contains(itemId) {
  77. dcg += 1.0 / math32.Log2(float32(i)+2.0)
  78. }
  79. }
  80. return dcg / idcg
  81. }
  82. // Precision is the fraction of relevant ItemFeedback among the recommended ItemFeedback.
  83. //
  84. // \frac{|relevant documents| \cap |retrieved documents|} {|{retrieved documents}|}
  85. func Precision(targetSet mapset.Set[int32], rankList []int32) float32 {
  86. hit := float32(0)
  87. for _, itemId := range rankList {
  88. if targetSet.Contains(itemId) {
  89. hit++
  90. }
  91. }
  92. return hit / float32(len(rankList))
  93. }
  94. // Recall is the fraction of relevant ItemFeedback that have been recommended over the total
  95. // amount of relevant ItemFeedback.
  96. //
  97. // \frac{|relevant documents| \cap |retrieved documents|} {|{relevant documents}|}
  98. func Recall(targetSet mapset.Set[int32], rankList []int32) float32 {
  99. hit := 0
  100. for _, itemId := range rankList {
  101. if targetSet.Contains(itemId) {
  102. hit++
  103. }
  104. }
  105. return float32(hit) / float32(targetSet.Cardinality())
  106. }
  107. // HR means Hit Ratio.
  108. func HR(targetSet mapset.Set[int32], rankList []int32) float32 {
  109. for _, itemId := range rankList {
  110. if targetSet.Contains(itemId) {
  111. return 1
  112. }
  113. }
  114. return 0
  115. }
  116. // MAP means Mean Average Precision.
  117. // mAP: http://sdsawtelle.github.io/blog/output/mean-average-precision-MAP-for-recommender-systems.html
  118. func MAP(targetSet mapset.Set[int32], rankList []int32) float32 {
  119. sumPrecision := float32(0)
  120. hit := 0
  121. for i, itemId := range rankList {
  122. if targetSet.Contains(itemId) {
  123. hit++
  124. sumPrecision += float32(hit) / float32(i+1)
  125. }
  126. }
  127. return sumPrecision / float32(targetSet.Cardinality())
  128. }
  129. // MRR means Mean Reciprocal Rank.
  130. //
  131. // The mean reciprocal rank is a statistic measure for evaluating any process
  132. // that produces a list of possible responses to a sample of queries, ordered
  133. // by probability of correctness. The reciprocal rank of a query response is
  134. // the multiplicative inverse of the rank of the first correct answer: 1 for
  135. // first place, ​1⁄2 for second place, ​1⁄3 for third place and so on. The
  136. // mean reciprocal rank is the average of the reciprocal ranks of results for
  137. // a sample of queries Q:
  138. //
  139. // MRR = \frac{1}{Q} \sum^{|Q|}_{i=1} \frac{1}{rank_i}
  140. func MRR(targetSet mapset.Set[int32], rankList []int32) float32 {
  141. for i, itemId := range rankList {
  142. if targetSet.Contains(itemId) {
  143. return 1 / float32(i+1)
  144. }
  145. }
  146. return 0
  147. }
  148. func Rank(model MatrixFactorization, userId int32, candidates []int32, topN int) ([]int32, []float32) {
  149. // Get top-n list
  150. itemsHeap := heap.NewTopKFilter[int32, float32](topN)
  151. for _, itemId := range candidates {
  152. itemsHeap.Push(itemId, model.InternalPredict(userId, itemId))
  153. }
  154. elem, scores := itemsHeap.PopAll()
  155. recommends := make([]int32, len(elem))
  156. for i := range recommends {
  157. recommends[i] = elem[i]
  158. }
  159. return recommends, scores
  160. }
// SnapshotManger manages the best snapshot.
type SnapshotManger struct {
	BestWeights []interface{} // weights associated with the best score seen so far
	BestScore   Score         // best score seen so far (compared by NDCG)
}
  166. // AddSnapshot adds a copied snapshot.
  167. func (sm *SnapshotManger) AddSnapshot(score Score, weights ...interface{}) {
  168. if sm.BestWeights == nil || score.NDCG > sm.BestScore.NDCG {
  169. sm.BestScore = score
  170. if err := copier.Copy(&sm.BestWeights, weights); err != nil {
  171. panic(err)
  172. }
  173. }
  174. }
  175. // AddSnapshotNoCopy adds a snapshot without copy.
  176. func (sm *SnapshotManger) AddSnapshotNoCopy(score Score, weights ...interface{}) {
  177. if sm.BestWeights == nil || score.NDCG > sm.BestScore.NDCG {
  178. sm.BestScore = score
  179. if err := copier.Copy(&sm.BestWeights, weights); err != nil {
  180. panic(err)
  181. }
  182. }
  183. }