index.go 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190
  1. // Copyright 2022 gorse Project Authors
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. package search
  15. import (
  16. "context"
  17. "fmt"
  18. "reflect"
  19. "sort"
  20. "github.com/chewxy/math32"
  21. "github.com/zhenghaoz/gorse/base/floats"
  22. "github.com/zhenghaoz/gorse/base/log"
  23. "go.uber.org/zap"
  24. "modernc.org/sortutil"
  25. )
  26. type Vector interface {
  27. Distance(vector Vector) float32
  28. Terms() []string
  29. IsHidden() bool
  30. Centroid(vectors []Vector, indices []int32) CentroidVector
  31. }
  32. type DenseVector struct {
  33. data []float32
  34. terms []string
  35. isHidden bool
  36. }
  37. func NewDenseVector(data []float32, terms []string, isHidden bool) *DenseVector {
  38. return &DenseVector{
  39. data: data,
  40. terms: terms,
  41. isHidden: isHidden,
  42. }
  43. }
  44. func (v *DenseVector) Distance(vector Vector) float32 {
  45. feedbackVector, isFeedback := vector.(*DenseVector)
  46. if !isFeedback {
  47. log.Logger().Fatal("vector type mismatch",
  48. zap.String("expect", reflect.TypeOf(v).String()),
  49. zap.String("actual", reflect.TypeOf(vector).String()))
  50. }
  51. return -floats.Dot(v.data, feedbackVector.data)
  52. }
  53. func (v *DenseVector) Terms() []string {
  54. return v.terms
  55. }
  56. func (v *DenseVector) IsHidden() bool {
  57. return v.isHidden
  58. }
  59. func (v *DenseVector) Centroid(_ []Vector, _ []int32) CentroidVector {
  60. panic("not implemented")
  61. }
  62. type DictionaryVector struct {
  63. isHidden bool
  64. terms []string
  65. indices []int32
  66. values []float32
  67. norm float32
  68. }
  69. func NewDictionaryVector(indices []int32, values []float32, terms []string, isHidden bool) *DictionaryVector {
  70. sort.Sort(sortutil.Int32Slice(indices))
  71. var norm float32
  72. for _, i := range indices {
  73. norm += values[i]
  74. }
  75. norm = math32.Sqrt(norm)
  76. return &DictionaryVector{
  77. isHidden: isHidden,
  78. terms: terms,
  79. indices: indices,
  80. values: values,
  81. norm: norm,
  82. }
  83. }
  84. func (v *DictionaryVector) Dot(vector *DictionaryVector) (float32, float32) {
  85. i, j, sum, common := 0, 0, float32(0), float32(0)
  86. for i < len(v.indices) && j < len(vector.indices) {
  87. if v.indices[i] == vector.indices[j] {
  88. sum += v.values[v.indices[i]]
  89. common++
  90. i++
  91. j++
  92. } else if v.indices[i] < vector.indices[j] {
  93. i++
  94. } else if v.indices[i] > vector.indices[j] {
  95. j++
  96. }
  97. }
  98. return sum, common
  99. }
  100. const similarityShrink = 100
  101. func (v *DictionaryVector) Distance(vector Vector) float32 {
  102. var score float32
  103. if dictVec, isDictVec := vector.(*DictionaryVector); !isDictVec {
  104. panic(fmt.Sprintf("unexpected vector type: %v", reflect.TypeOf(vector)))
  105. } else {
  106. dot, common := v.Dot(dictVec)
  107. if dot > 0 {
  108. score = -dot / v.norm / dictVec.norm * common / (common + similarityShrink)
  109. }
  110. }
  111. return score
  112. }
  113. func (v *DictionaryVector) Terms() []string {
  114. return v.terms
  115. }
  116. func (v *DictionaryVector) IsHidden() bool {
  117. return v.isHidden
  118. }
  119. type CentroidVector interface {
  120. Distance(vector Vector) float32
  121. }
  122. type DictionaryCentroidVector struct {
  123. data map[int32]float32
  124. norm float32
  125. }
  126. func (v DictionaryVector) Centroid(vectors []Vector, indices []int32) CentroidVector {
  127. data := make(map[int32]float32)
  128. for _, i := range indices {
  129. vector, isDictVector := vectors[i].(*DictionaryVector)
  130. if !isDictVector {
  131. panic(fmt.Sprintf("unexpected vector type: %v", reflect.TypeOf(vector)))
  132. }
  133. for _, i := range vector.indices {
  134. data[i] += math32.Sqrt(vector.values[i])
  135. }
  136. }
  137. var norm float32
  138. for _, val := range data {
  139. norm += val * val
  140. }
  141. norm = math32.Sqrt(norm)
  142. for i := range data {
  143. data[i] /= norm
  144. }
  145. return &DictionaryCentroidVector{
  146. data: data,
  147. norm: norm,
  148. }
  149. }
  150. func (v *DictionaryCentroidVector) Distance(vector Vector) float32 {
  151. var sum, common float32
  152. if dictVector, isDictVec := vector.(*DictionaryVector); !isDictVec {
  153. panic(fmt.Sprintf("unexpected vector type: %v", reflect.TypeOf(vector)))
  154. } else {
  155. for _, i := range dictVector.indices {
  156. if val, exist := v.data[i]; exist {
  157. sum += val * math32.Sqrt(v.data[i])
  158. common++
  159. }
  160. }
  161. }
  162. return -sum * common / (common + similarityShrink)
  163. }
  164. type VectorIndex interface {
  165. Build(ctx context.Context)
  166. Search(q Vector, n int, prune0 bool) ([]int32, []float32)
  167. MultiSearch(q Vector, terms []string, n int, prune0 bool) (map[string][]int32, map[string][]float32)
  168. }