hnsw.go 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479
  1. // Copyright 2022 gorse Project Authors
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. package search
  15. import (
  16. "context"
  17. "math/rand"
  18. "runtime"
  19. "sync"
  20. "time"
  21. "github.com/chewxy/math32"
  22. mapset "github.com/deckarep/golang-set/v2"
  23. "github.com/zhenghaoz/gorse/base"
  24. "github.com/zhenghaoz/gorse/base/heap"
  25. "github.com/zhenghaoz/gorse/base/log"
  26. "github.com/zhenghaoz/gorse/base/parallel"
  27. "github.com/zhenghaoz/gorse/base/progress"
  28. "go.uber.org/zap"
  29. "modernc.org/mathutil"
  30. )
  31. var _ VectorIndex = &HNSW{}
  32. // HNSW is a vector index based on Hierarchical Navigable Small Worlds.
  33. type HNSW struct {
  34. vectors []Vector
  35. bottomNeighbors []*heap.PriorityQueue
  36. upperNeighbors []sync.Map
  37. enterPoint int32
  38. nodeMutexes []sync.RWMutex
  39. globalMutex sync.RWMutex
  40. initOnce sync.Once
  41. levelFactor float32
  42. maxConnection int // maximum number of connections for each element per layer
  43. maxConnection0 int
  44. ef int
  45. efConstruction int
  46. numJobs int
  47. }
  48. // HNSWConfig is the configuration function for HNSW.
  49. type HNSWConfig func(*HNSW)
  50. // SetHNSWNumJobs sets the number of jobs for building index.
  51. func SetHNSWNumJobs(numJobs int) HNSWConfig {
  52. return func(h *HNSW) {
  53. h.numJobs = numJobs
  54. }
  55. }
  56. // SetMaxConnection sets the number of connections in HNSW.
  57. func SetMaxConnection(maxConnection int) HNSWConfig {
  58. return func(h *HNSW) {
  59. h.levelFactor = 1.0 / math32.Log(float32(maxConnection))
  60. h.maxConnection = maxConnection
  61. h.maxConnection0 = maxConnection * 2
  62. }
  63. }
  64. // SetEFConstruction sets efConstruction in HNSW.
  65. func SetEFConstruction(efConstruction int) HNSWConfig {
  66. return func(h *HNSW) {
  67. h.efConstruction = efConstruction
  68. }
  69. }
  70. // SetEF sets the EF search value in HNSW.
  71. // By default ef for search is the same as efConstruction. To return it to this default behavior, set it to 0.
  72. func SetEF(ef int) HNSWConfig {
  73. return func(h *HNSW) {
  74. h.ef = ef
  75. }
  76. }
  77. // NewHNSW builds a vector index based on Hierarchical Navigable Small Worlds.
  78. func NewHNSW(vectors []Vector, configs ...HNSWConfig) *HNSW {
  79. h := &HNSW{
  80. vectors: vectors,
  81. levelFactor: 1.0 / math32.Log(48),
  82. maxConnection: 48,
  83. maxConnection0: 96,
  84. efConstruction: 100,
  85. numJobs: runtime.NumCPU(),
  86. }
  87. for _, config := range configs {
  88. config(h)
  89. }
  90. return h
  91. }
  92. // Search a vector in Hierarchical Navigable Small Worlds.
  93. func (h *HNSW) Search(q Vector, n int, prune0 bool) (values []int32, scores []float32) {
  94. w := h.knnSearch(q, n, h.efSearchValue(n))
  95. for w.Len() > 0 {
  96. value, score := w.Pop()
  97. if !prune0 || score < 0 {
  98. values = append(values, value)
  99. scores = append(scores, score)
  100. }
  101. }
  102. return
  103. }
  104. func (h *HNSW) knnSearch(q Vector, k, ef int) *heap.PriorityQueue {
  105. var (
  106. w *heap.PriorityQueue // set for the current the nearest element
  107. enterPoints = h.distance(q, []int32{h.enterPoint}) // get enter point for hnsw
  108. topLayer = len(h.upperNeighbors) // top layer for hnsw
  109. )
  110. for currentLayer := topLayer; currentLayer > 0; currentLayer-- {
  111. w = h.searchLayer(q, enterPoints, 1, currentLayer)
  112. enterPoints = heap.NewPriorityQueue(false)
  113. enterPoints.Push(w.Peek())
  114. }
  115. w = h.searchLayer(q, enterPoints, ef, 0)
  116. return h.selectNeighbors(q, w, k)
  117. }
  118. // Build a vector index on data.
  119. func (h *HNSW) Build(ctx context.Context) {
  120. completed := make(chan struct{}, h.numJobs)
  121. go func() {
  122. defer base.CheckPanic()
  123. completedCount, previousCount := 0, 0
  124. ticker := time.NewTicker(10 * time.Second)
  125. _, span := progress.Start(ctx, "HNSW.Build", len(h.vectors))
  126. for {
  127. select {
  128. case _, ok := <-completed:
  129. if !ok {
  130. span.End()
  131. return
  132. }
  133. completedCount++
  134. case <-ticker.C:
  135. throughput := completedCount - previousCount
  136. previousCount = completedCount
  137. if throughput > 0 {
  138. span.Add(throughput)
  139. log.Logger().Info("building index",
  140. zap.Int("n_indexed_vectors", completedCount),
  141. zap.Int("n_vectors", len(h.vectors)),
  142. zap.Int("throughput", throughput))
  143. }
  144. }
  145. }
  146. }()
  147. h.bottomNeighbors = make([]*heap.PriorityQueue, len(h.vectors))
  148. h.nodeMutexes = make([]sync.RWMutex, len(h.vectors))
  149. _ = parallel.Parallel(len(h.vectors), h.numJobs, func(_, jobId int) error {
  150. h.insert(int32(jobId))
  151. completed <- struct{}{}
  152. return nil
  153. })
  154. close(completed)
  155. }
  156. // insert i-th vector into the vector index.
  157. func (h *HNSW) insert(q int32) {
  158. // insert first point
  159. var isFirstPoint bool
  160. h.initOnce.Do(func() {
  161. if h.upperNeighbors == nil {
  162. h.bottomNeighbors[q] = heap.NewPriorityQueue(false)
  163. h.upperNeighbors = make([]sync.Map, 0)
  164. h.enterPoint = q
  165. isFirstPoint = true
  166. return
  167. }
  168. })
  169. if isFirstPoint {
  170. return
  171. }
  172. h.globalMutex.RLock()
  173. var (
  174. w *heap.PriorityQueue // list for the currently found nearest elements
  175. enterPoints = h.distance(h.vectors[q], []int32{h.enterPoint}) // get enter point for hnsw
  176. l = int(math32.Floor(-math32.Log(rand.Float32()) * h.levelFactor))
  177. topLayer = len(h.upperNeighbors)
  178. )
  179. h.globalMutex.RUnlock()
  180. if l > topLayer {
  181. h.globalMutex.Lock()
  182. defer h.globalMutex.Unlock()
  183. }
  184. for currentLayer := topLayer; currentLayer >= l+1; currentLayer-- {
  185. w = h.searchLayer(h.vectors[q], enterPoints, 1, currentLayer)
  186. enterPoints = h.selectNeighbors(h.vectors[q], w, 1)
  187. }
  188. h.nodeMutexes[q].Lock()
  189. for currentLayer := mathutil.Min(topLayer, l); currentLayer >= 0; currentLayer-- {
  190. w = h.searchLayer(h.vectors[q], enterPoints, h.efConstruction, currentLayer)
  191. neighbors := h.selectNeighbors(h.vectors[q], w, h.maxConnection)
  192. // add bidirectional connections from upperNeighbors to q at layer l_c
  193. h.setNeighbourhood(q, currentLayer, neighbors)
  194. for _, e := range neighbors.Elems() {
  195. h.nodeMutexes[e.Value].Lock()
  196. h.getNeighbourhood(e.Value, currentLayer).Push(q, e.Weight)
  197. connections := h.getNeighbourhood(e.Value, currentLayer)
  198. var currentMaxConnection int
  199. if currentLayer == 0 {
  200. currentMaxConnection = h.maxConnection0
  201. } else {
  202. currentMaxConnection = h.maxConnection
  203. }
  204. if connections.Len() > currentMaxConnection {
  205. // shrink connections of e if lc = 0 then M_max = M_max0
  206. newConnections := h.selectNeighbors(h.vectors[q], connections, h.maxConnection)
  207. h.setNeighbourhood(e.Value, currentLayer, newConnections)
  208. }
  209. h.nodeMutexes[e.Value].Unlock()
  210. }
  211. enterPoints = w
  212. }
  213. h.nodeMutexes[q].Unlock()
  214. if l > topLayer {
  215. // set enter point for hnsw to q
  216. h.enterPoint = q
  217. h.upperNeighbors = append(h.upperNeighbors, sync.Map{})
  218. h.setNeighbourhood(q, topLayer+1, heap.NewPriorityQueue(false))
  219. }
  220. }
  221. func (h *HNSW) searchLayer(q Vector, enterPoints *heap.PriorityQueue, ef, currentLayer int) *heap.PriorityQueue {
  222. var (
  223. v = mapset.NewSet(enterPoints.Values()...) // set of visited elements
  224. candidates = enterPoints.Clone() // set of candidates
  225. w = enterPoints.Reverse() // dynamic list of found nearest upperNeighbors
  226. )
  227. for candidates.Len() > 0 {
  228. // extract nearest element from candidates to q
  229. c, cq := candidates.Pop()
  230. // get the furthest element from w to q
  231. _, fq := w.Peek()
  232. if cq > fq {
  233. break // all elements in w are evaluated
  234. }
  235. // update candidates and w
  236. h.nodeMutexes[c].RLock()
  237. neighbors := h.getNeighbourhood(c, currentLayer).Values()
  238. h.nodeMutexes[c].RUnlock()
  239. for _, e := range neighbors {
  240. if !v.Contains(e) {
  241. v.Add(e)
  242. // get the furthest element from w to q
  243. _, fq = w.Peek()
  244. if eq := h.vectors[e].Distance(q); eq < fq || w.Len() < ef {
  245. candidates.Push(e, eq)
  246. w.Push(e, eq)
  247. if w.Len() > ef {
  248. // remove the furthest element from w to q
  249. w.Pop()
  250. }
  251. }
  252. }
  253. }
  254. }
  255. return w.Reverse()
  256. }
  257. func (h *HNSW) setNeighbourhood(e int32, currentLayer int, connections *heap.PriorityQueue) {
  258. if currentLayer == 0 {
  259. h.bottomNeighbors[e] = connections
  260. } else {
  261. h.upperNeighbors[currentLayer-1].Store(e, connections)
  262. }
  263. }
  264. func (h *HNSW) getNeighbourhood(e int32, currentLayer int) *heap.PriorityQueue {
  265. if currentLayer == 0 {
  266. return h.bottomNeighbors[e]
  267. } else {
  268. temp, _ := h.upperNeighbors[currentLayer-1].Load(e)
  269. return temp.(*heap.PriorityQueue)
  270. }
  271. }
  272. func (h *HNSW) selectNeighbors(_ Vector, candidates *heap.PriorityQueue, m int) *heap.PriorityQueue {
  273. pq := candidates.Reverse()
  274. for pq.Len() > m {
  275. pq.Pop()
  276. }
  277. return pq.Reverse()
  278. }
  279. func (h *HNSW) distance(q Vector, points []int32) *heap.PriorityQueue {
  280. pq := heap.NewPriorityQueue(false)
  281. for _, point := range points {
  282. pq.Push(point, h.vectors[point].Distance(q))
  283. }
  284. return pq
  285. }
  286. type HNSWBuilder struct {
  287. bruteForce *Bruteforce
  288. data []Vector
  289. testSize int
  290. k int
  291. rng base.RandomGenerator
  292. numJobs int
  293. }
  294. func NewHNSWBuilder(data []Vector, k, numJobs int) *HNSWBuilder {
  295. b := &HNSWBuilder{
  296. bruteForce: NewBruteforce(data),
  297. data: data,
  298. testSize: DefaultTestSize,
  299. k: k,
  300. rng: base.NewRandomGenerator(0),
  301. numJobs: numJobs,
  302. }
  303. b.bruteForce.Build(context.Background())
  304. return b
  305. }
  306. func recall(expected, actual []int32) float32 {
  307. var result float32
  308. truth := mapset.NewSet(expected...)
  309. for _, v := range actual {
  310. if truth.Contains(v) {
  311. result++
  312. }
  313. }
  314. if result == 0 {
  315. return 0
  316. }
  317. return result / float32(len(actual))
  318. }
  319. func (b *HNSWBuilder) evaluate(idx *HNSW, prune0 bool) float32 {
  320. testSize := mathutil.Min(b.testSize, len(b.data))
  321. samples := b.rng.Sample(0, len(b.data), testSize)
  322. var result, count float32
  323. var mu sync.Mutex
  324. _ = parallel.Parallel(len(samples), idx.numJobs, func(_, i int) error {
  325. sample := samples[i]
  326. expected, _ := b.bruteForce.Search(b.data[sample], b.k, prune0)
  327. if len(expected) > 0 {
  328. actual, _ := idx.Search(b.data[sample], b.k, prune0)
  329. mu.Lock()
  330. defer mu.Unlock()
  331. result += recall(expected, actual)
  332. count++
  333. }
  334. return nil
  335. })
  336. if count == 0 {
  337. return 0
  338. }
  339. return result / count
  340. }
  341. func (b *HNSWBuilder) Build(ctx context.Context, recall float32, trials int, prune0 bool) (idx *HNSW, score float32) {
  342. ef := 1 << int(math32.Ceil(math32.Log2(float32(b.k))))
  343. newCtx, span := progress.Start(ctx, "HNSWBuilder.Build", trials)
  344. defer span.End()
  345. for i := 0; i < trials; i++ {
  346. start := time.Now()
  347. idx = NewHNSW(b.data,
  348. SetEFConstruction(ef),
  349. SetHNSWNumJobs(b.numJobs))
  350. idx.Build(newCtx)
  351. buildTime := time.Since(start)
  352. score = b.evaluate(idx, prune0)
  353. span.Add(1)
  354. log.Logger().Info("try to build vector index",
  355. zap.String("index_type", "HNSW"),
  356. zap.Int("ef_construction", ef),
  357. zap.Float32("recall", score),
  358. zap.String("build_time", buildTime.String()))
  359. if score > recall {
  360. return
  361. } else {
  362. ef <<= 1
  363. }
  364. }
  365. return
  366. }
  367. func (b *HNSWBuilder) evaluateTermSearch(idx *HNSW, prune0 bool, term string) float32 {
  368. testSize := mathutil.Min(b.testSize, len(b.data))
  369. samples := b.rng.Sample(0, len(b.data), testSize)
  370. var result, count float32
  371. var mu sync.Mutex
  372. _ = parallel.Parallel(len(samples), runtime.NumCPU(), func(_, i int) error {
  373. sample := samples[i]
  374. expected, _ := b.bruteForce.MultiSearch(b.data[sample], []string{term}, b.k, prune0)
  375. if len(expected) > 0 {
  376. actual, _ := idx.MultiSearch(b.data[sample], []string{term}, b.k, prune0)
  377. mu.Lock()
  378. defer mu.Unlock()
  379. result += recall(expected[term], actual[term])
  380. count++
  381. }
  382. return nil
  383. })
  384. return result / count
  385. }
  386. func (h *HNSW) MultiSearch(q Vector, terms []string, n int, prune0 bool) (values map[string][]int32, scores map[string][]float32) {
  387. values = make(map[string][]int32)
  388. scores = make(map[string][]float32)
  389. for _, term := range terms {
  390. values[term] = make([]int32, 0, n)
  391. scores[term] = make([]float32, 0, n)
  392. }
  393. w := h.efSearch(q, h.efSearchValue(n))
  394. for w.Len() > 0 {
  395. value, score := w.Pop()
  396. if !prune0 || score < 0 {
  397. if len(values[""]) < n {
  398. values[""] = append(values[""], value)
  399. scores[""] = append(scores[""], score)
  400. }
  401. for _, term := range h.vectors[value].Terms() {
  402. if _, exist := values[term]; exist && len(values[term]) < n {
  403. values[term] = append(values[term], value)
  404. scores[term] = append(scores[term], score)
  405. }
  406. }
  407. }
  408. }
  409. return
  410. }
  411. func (h *HNSW) efSearch(q Vector, ef int) *heap.PriorityQueue {
  412. var (
  413. w *heap.PriorityQueue // set for the current the nearest element
  414. enterPoints = h.distance(q, []int32{h.enterPoint}) // get enter point for hnsw
  415. topLayer = len(h.upperNeighbors) // top layer for hnsw
  416. )
  417. for currentLayer := topLayer; currentLayer > 0; currentLayer-- {
  418. w = h.searchLayer(q, enterPoints, 1, currentLayer)
  419. enterPoints = heap.NewPriorityQueue(false)
  420. enterPoints.Push(w.Peek())
  421. }
  422. w = h.searchLayer(q, enterPoints, ef, 0)
  423. return w
  424. }
  425. // efSearchValue returns the efSearch value to use, given the current number of elements desired.
  426. func (h *HNSW) efSearchValue(n int) int {
  427. if h.ef > 0 {
  428. return mathutil.Max(h.ef, n)
  429. }
  430. return mathutil.Max(h.efConstruction, n)
  431. }
  432. func EstimateHNSWBuilderComplexity(dataSize, trials int) int {
  433. // build index
  434. complexity := dataSize * dataSize
  435. // evaluate
  436. complexity += DefaultTestSize * dataSize
  437. // with trials
  438. return complexity * trials
  439. }