bruteforce.go 2.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102
  1. // Copyright 2022 gorse Project Authors
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. package search
  15. import (
  16. "context"
  17. "github.com/zhenghaoz/gorse/base/heap"
  18. )
  19. var _ VectorIndex = &Bruteforce{}
  20. // Bruteforce is a naive implementation of vector index.
  21. type Bruteforce struct {
  22. vectors []Vector
  23. }
  24. // Build a vector index on data.
  25. func (b *Bruteforce) Build(_ context.Context) {}
  26. // NewBruteforce creates a Bruteforce vector index.
  27. func NewBruteforce(vectors []Vector) *Bruteforce {
  28. return &Bruteforce{
  29. vectors: vectors,
  30. }
  31. }
  32. // Search top-k similar vectors.
  33. func (b *Bruteforce) Search(q Vector, n int, prune0 bool) (values []int32, scores []float32) {
  34. pq := heap.NewPriorityQueue(true)
  35. for i, vec := range b.vectors {
  36. if vec != q {
  37. pq.Push(int32(i), q.Distance(vec))
  38. if pq.Len() > n {
  39. pq.Pop()
  40. }
  41. }
  42. }
  43. pq = pq.Reverse()
  44. for pq.Len() > 0 {
  45. value, score := pq.Pop()
  46. if !prune0 || score < 0 {
  47. values = append(values, value)
  48. scores = append(scores, score)
  49. }
  50. }
  51. return
  52. }
  53. func (b *Bruteforce) MultiSearch(q Vector, terms []string, n int, prune0 bool) (values map[string][]int32, scores map[string][]float32) {
  54. // create priority queues
  55. queues := make(map[string]*heap.PriorityQueue)
  56. queues[""] = heap.NewPriorityQueue(true)
  57. for _, term := range terms {
  58. queues[term] = heap.NewPriorityQueue(true)
  59. }
  60. // search with terms
  61. for i, vec := range b.vectors {
  62. if vec != q {
  63. queues[""].Push(int32(i), q.Distance(vec))
  64. if queues[""].Len() > n {
  65. queues[""].Pop()
  66. }
  67. for _, term := range vec.Terms() {
  68. if _, match := queues[term]; match {
  69. queues[term].Push(int32(i), q.Distance(vec))
  70. if queues[term].Len() > n {
  71. queues[term].Pop()
  72. }
  73. }
  74. }
  75. }
  76. }
  77. // retrieve results
  78. values = make(map[string][]int32)
  79. scores = make(map[string][]float32)
  80. for term, pq := range queues {
  81. pq = pq.Reverse()
  82. for pq.Len() > 0 {
  83. value, score := pq.Pop()
  84. if !prune0 || score < 0 {
  85. values[term] = append(values[term], value)
  86. scores[term] = append(scores[term], score)
  87. }
  88. }
  89. }
  90. return
  91. }