index_test.go 2.6 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283
  1. // Copyright 2022 gorse Project Authors
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. package search
  15. import (
  16. "context"
  17. "math/big"
  18. "runtime"
  19. "testing"
  20. "github.com/stretchr/testify/assert"
  21. "github.com/zhenghaoz/gorse/base/task"
  22. "github.com/zhenghaoz/gorse/model"
  23. "github.com/zhenghaoz/gorse/model/ranking"
  24. )
  25. func TestHNSW_InnerProduct(t *testing.T) {
  26. // load dataset
  27. trainSet, testSet, err := ranking.LoadDataFromBuiltIn("ml-100k")
  28. assert.NoError(t, err)
  29. m := ranking.NewBPR(model.Params{
  30. model.NFactors: 8,
  31. model.Reg: 0.01,
  32. model.Lr: 0.05,
  33. model.NEpochs: 30,
  34. model.InitMean: 0,
  35. model.InitStdDev: 0.001,
  36. })
  37. fitConfig := ranking.NewFitConfig().SetVerbose(1).SetJobsAllocator(task.NewConstantJobsAllocator(runtime.NumCPU()))
  38. m.Fit(context.Background(), trainSet, testSet, fitConfig)
  39. var vectors []Vector
  40. for i, itemFactor := range m.ItemFactor {
  41. var terms []string
  42. if big.NewInt(int64(i)).ProbablyPrime(0) {
  43. terms = append(terms, "prime")
  44. }
  45. vectors = append(vectors, NewDenseVector(itemFactor, terms, false))
  46. }
  47. // build vector index
  48. builder := NewHNSWBuilder(vectors, 10, runtime.NumCPU())
  49. idx, recall := builder.Build(context.Background(), 0.9, 5, false)
  50. assert.Greater(t, recall, float32(0.9))
  51. recall = builder.evaluateTermSearch(idx, true, "prime")
  52. assert.Greater(t, recall, float32(0.8))
  53. }
  54. func TestIVF_Cosine(t *testing.T) {
  55. // load dataset
  56. trainSet, _, err := ranking.LoadDataFromBuiltIn("ml-100k")
  57. assert.NoError(t, err)
  58. values := make([]float32, trainSet.UserCount())
  59. for i := range values {
  60. values[i] = 1
  61. }
  62. var vectors []Vector
  63. for i, feedback := range trainSet.ItemFeedback {
  64. var terms []string
  65. if big.NewInt(int64(i)).ProbablyPrime(0) {
  66. terms = append(terms, "prime")
  67. }
  68. vectors = append(vectors, NewDictionaryVector(feedback, values, terms, false))
  69. }
  70. // build vector index
  71. builder := NewIVFBuilder(vectors, 10)
  72. idx, recall := builder.Build(0.9, 5, true)
  73. assert.Greater(t, recall, float32(0.9))
  74. recall = builder.evaluateTermSearch(idx, true, "prime")
  75. assert.Greater(t, recall, float32(0.8))
  76. }