ivf.go 8.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339
  1. // Copyright 2022 gorse Project Authors
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. package search
  15. import (
  16. "context"
  17. "math"
  18. "math/rand"
  19. "sync"
  20. "time"
  21. "github.com/chewxy/math32"
  22. "github.com/zhenghaoz/gorse/base"
  23. "github.com/zhenghaoz/gorse/base/heap"
  24. "github.com/zhenghaoz/gorse/base/log"
  25. "github.com/zhenghaoz/gorse/base/parallel"
  26. "github.com/zhenghaoz/gorse/base/task"
  27. "go.uber.org/atomic"
  28. "go.uber.org/zap"
  29. "modernc.org/mathutil"
  30. )
  // Defaults used by the IVF index and by IVFBuilder.
  const (
  	// DefaultTestSize is the sample size NewIVFBuilder configures for
  	// recall evaluation.
  	DefaultTestSize = 1000
  	// DefaultMaxIter is the iteration limit NewIVF configures for the
  	// clustering loop unless SetMaxIteration overrides it.
  	DefaultMaxIter = 100
  )
  35. var _ VectorIndex = &IVF{}
  36. type IVF struct {
  37. clusters []ivfCluster
  38. data []Vector
  39. k int
  40. errorRate float32
  41. maxIter int
  42. numProbe int
  43. jobsAlloc *task.JobsAllocator
  44. }
  45. type IVFConfig func(ivf *IVF)
  46. func SetNumProbe(numProbe int) IVFConfig {
  47. return func(ivf *IVF) {
  48. ivf.numProbe = numProbe
  49. }
  50. }
  51. func SetClusterErrorRate(errorRate float32) IVFConfig {
  52. return func(ivf *IVF) {
  53. ivf.errorRate = errorRate
  54. }
  55. }
  56. func SetIVFJobsAllocator(jobsAlloc *task.JobsAllocator) IVFConfig {
  57. return func(ivf *IVF) {
  58. ivf.jobsAlloc = jobsAlloc
  59. }
  60. }
  61. func SetMaxIteration(maxIter int) IVFConfig {
  62. return func(ivf *IVF) {
  63. ivf.maxIter = maxIter
  64. }
  65. }
  66. type ivfCluster struct {
  67. centroid CentroidVector
  68. observations []int32
  69. mu sync.Mutex
  70. }
  71. func NewIVF(vectors []Vector, configs ...IVFConfig) *IVF {
  72. idx := &IVF{
  73. data: vectors,
  74. k: int(math32.Sqrt(float32(len(vectors)))),
  75. errorRate: 0.05,
  76. maxIter: DefaultMaxIter,
  77. numProbe: 1,
  78. }
  79. for _, config := range configs {
  80. config(idx)
  81. }
  82. return idx
  83. }
  84. func (idx *IVF) Search(q Vector, n int, prune0 bool) (values []int32, scores []float32) {
  85. cq := heap.NewTopKFilter[int, float32](idx.numProbe)
  86. for c := range idx.clusters {
  87. d := idx.clusters[c].centroid.Distance(q)
  88. cq.Push(c, -d)
  89. }
  90. pq := heap.NewPriorityQueue(true)
  91. clusters, _ := cq.PopAll()
  92. for _, c := range clusters {
  93. for _, i := range idx.clusters[c].observations {
  94. if idx.data[i] != q {
  95. pq.Push(i, q.Distance(idx.data[i]))
  96. if pq.Len() > n {
  97. pq.Pop()
  98. }
  99. }
  100. }
  101. }
  102. pq = pq.Reverse()
  103. for pq.Len() > 0 {
  104. value, score := pq.Pop()
  105. if !prune0 || score < 0 {
  106. values = append(values, value)
  107. scores = append(scores, score)
  108. }
  109. }
  110. return
  111. }
  112. func (idx *IVF) MultiSearch(q Vector, terms []string, n int, prune0 bool) (values map[string][]int32, scores map[string][]float32) {
  113. cq := heap.NewTopKFilter[int, float32](idx.numProbe)
  114. for c := range idx.clusters {
  115. d := idx.clusters[c].centroid.Distance(q)
  116. cq.Push(c, -d)
  117. }
  118. // create priority queues
  119. queues := make(map[string]*heap.PriorityQueue)
  120. queues[""] = heap.NewPriorityQueue(true)
  121. for _, term := range terms {
  122. queues[term] = heap.NewPriorityQueue(true)
  123. }
  124. // search with terms
  125. clusters, _ := cq.PopAll()
  126. for _, c := range clusters {
  127. for _, i := range idx.clusters[c].observations {
  128. if idx.data[i] != q {
  129. vec := idx.data[i]
  130. queues[""].Push(i, q.Distance(vec))
  131. if queues[""].Len() > n {
  132. queues[""].Pop()
  133. }
  134. for _, term := range vec.Terms() {
  135. if _, match := queues[term]; match {
  136. queues[term].Push(i, q.Distance(vec))
  137. if queues[term].Len() > n {
  138. queues[term].Pop()
  139. }
  140. }
  141. }
  142. }
  143. }
  144. }
  145. // retrieve results
  146. values = make(map[string][]int32)
  147. scores = make(map[string][]float32)
  148. for term, pq := range queues {
  149. pq = pq.Reverse()
  150. for pq.Len() > 0 {
  151. value, score := pq.Pop()
  152. if !prune0 || score < 0 {
  153. values[term] = append(values[term], value)
  154. scores[term] = append(scores[term], score)
  155. }
  156. }
  157. }
  158. return
  159. }
  160. func (idx *IVF) Build(_ context.Context) {
  161. if idx.k > len(idx.data) {
  162. panic("the size of the observations set must greater than or equal to k")
  163. } else if len(idx.data) == 0 {
  164. log.Logger().Warn("no vectors for building IVF")
  165. return
  166. }
  167. // initialize clusters
  168. clusters := make([]ivfCluster, idx.k)
  169. assignments := make([]int, len(idx.data))
  170. for i := range idx.data {
  171. if !idx.data[i].IsHidden() {
  172. c := rand.Intn(idx.k)
  173. clusters[c].observations = append(clusters[c].observations, int32(i))
  174. assignments[i] = c
  175. }
  176. }
  177. for c := range clusters {
  178. clusters[c].centroid = idx.data[0].Centroid(idx.data, clusters[c].observations)
  179. }
  180. for it := 0; it < idx.maxIter; it++ {
  181. errorCount := atomic.NewInt32(0)
  182. // reassign clusters
  183. nextClusters := make([]ivfCluster, idx.k)
  184. _ = parallel.Parallel(len(idx.data), idx.jobsAlloc.AvailableJobs(), func(_, i int) error {
  185. if !idx.data[i].IsHidden() {
  186. nextCluster, nextDistance := -1, float32(math32.MaxFloat32)
  187. for c := range clusters {
  188. d := clusters[c].centroid.Distance(idx.data[i])
  189. if d < nextDistance {
  190. nextCluster = c
  191. nextDistance = d
  192. }
  193. }
  194. if nextCluster == -1 {
  195. return nil
  196. }
  197. if nextCluster != assignments[i] {
  198. errorCount.Inc()
  199. }
  200. nextClusters[nextCluster].mu.Lock()
  201. defer nextClusters[nextCluster].mu.Unlock()
  202. nextClusters[nextCluster].observations = append(nextClusters[nextCluster].observations, int32(i))
  203. assignments[i] = nextCluster
  204. }
  205. return nil
  206. })
  207. log.Logger().Debug("spatial k means clustering",
  208. zap.Int32("changes", errorCount.Load()))
  209. if float32(errorCount.Load())/float32(len(idx.data)) < idx.errorRate {
  210. idx.clusters = clusters
  211. break
  212. }
  213. for c := range clusters {
  214. nextClusters[c].centroid = idx.data[0].Centroid(idx.data, nextClusters[c].observations)
  215. }
  216. clusters = nextClusters
  217. }
  218. }
  219. type IVFBuilder struct {
  220. bruteForce *Bruteforce
  221. data []Vector
  222. testSize int
  223. k int
  224. rng base.RandomGenerator
  225. configs []IVFConfig
  226. }
  227. func NewIVFBuilder(data []Vector, k int, configs ...IVFConfig) *IVFBuilder {
  228. b := &IVFBuilder{
  229. bruteForce: NewBruteforce(data),
  230. data: data,
  231. testSize: DefaultTestSize,
  232. k: k,
  233. rng: base.NewRandomGenerator(0),
  234. configs: configs,
  235. }
  236. b.bruteForce.Build(context.Background())
  237. return b
  238. }
  239. func (b *IVFBuilder) evaluate(idx *IVF, prune0 bool) float32 {
  240. testSize := mathutil.Min(b.testSize, len(b.data))
  241. samples := b.rng.Sample(0, len(b.data), testSize)
  242. var result, count float32
  243. var mu sync.Mutex
  244. _ = parallel.Parallel(len(samples), idx.jobsAlloc.AvailableJobs(), func(_, i int) error {
  245. sample := samples[i]
  246. expected, _ := b.bruteForce.Search(b.data[sample], b.k, prune0)
  247. if len(expected) > 0 {
  248. actual, _ := idx.Search(b.data[sample], b.k, prune0)
  249. mu.Lock()
  250. defer mu.Unlock()
  251. result += recall(expected, actual)
  252. count++
  253. }
  254. return nil
  255. })
  256. if count == 0 {
  257. return 0
  258. }
  259. return result / count
  260. }
  261. func (b *IVFBuilder) Build(recall float32, numEpoch int, prune0 bool) (idx *IVF, score float32) {
  262. idx = NewIVF(b.data, b.configs...)
  263. start := time.Now()
  264. idx.Build(context.Background())
  265. buildTime := time.Since(start)
  266. idx.numProbe = int(math32.Ceil(float32(b.k) / math32.Sqrt(float32(len(b.data)))))
  267. for i := 0; i < numEpoch; i++ {
  268. score = b.evaluate(idx, prune0)
  269. log.Logger().Info("try to build vector index",
  270. zap.String("index_type", "IVF"),
  271. zap.Int("num_probe", idx.numProbe),
  272. zap.Float32("recall", score),
  273. zap.String("build_time", buildTime.String()))
  274. if score >= recall {
  275. return
  276. } else {
  277. idx.numProbe <<= 1
  278. }
  279. }
  280. return
  281. }
  282. func (b *IVFBuilder) evaluateTermSearch(idx *IVF, prune0 bool, term string) float32 {
  283. testSize := mathutil.Min(b.testSize, len(b.data))
  284. samples := b.rng.Sample(0, len(b.data), testSize)
  285. var result, count float32
  286. var mu sync.Mutex
  287. _ = parallel.Parallel(len(samples), idx.jobsAlloc.AvailableJobs(), func(_, i int) error {
  288. sample := samples[i]
  289. expected, _ := b.bruteForce.MultiSearch(b.data[sample], []string{term}, b.k, prune0)
  290. if len(expected) > 0 {
  291. actual, _ := idx.MultiSearch(b.data[sample], []string{term}, b.k, prune0)
  292. mu.Lock()
  293. defer mu.Unlock()
  294. result += recall(expected[term], actual[term])
  295. count++
  296. }
  297. return nil
  298. })
  299. return result / count
  300. }
  301. func EstimateIVFBuilderComplexity(num, numEpoch int) int {
  302. // clustering complexity
  303. complexity := DefaultMaxIter * num * int(math.Sqrt(float64(num)))
  304. // search complexity
  305. complexity += num * DefaultTestSize * numEpoch
  306. return complexity
  307. }