123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190 |
- // Copyright 2022 gorse Project Authors
- //
- // Licensed under the Apache License, Version 2.0 (the "License");
- // you may not use this file except in compliance with the License.
- // You may obtain a copy of the License at
- //
- // http://www.apache.org/licenses/LICENSE-2.0
- //
- // Unless required by applicable law or agreed to in writing, software
- // distributed under the License is distributed on an "AS IS" BASIS,
- // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- // See the License for the specific language governing permissions and
- // limitations under the License.
- package search
- import (
- "context"
- "fmt"
- "reflect"
- "sort"
- "github.com/chewxy/math32"
- "github.com/zhenghaoz/gorse/base/floats"
- "github.com/zhenghaoz/gorse/base/log"
- "go.uber.org/zap"
- "modernc.org/sortutil"
- )
- type Vector interface {
- Distance(vector Vector) float32
- Terms() []string
- IsHidden() bool
- Centroid(vectors []Vector, indices []int32) CentroidVector
- }
// DenseVector is a Vector backed by a dense slice of float32 components.
type DenseVector struct {
	data     []float32 // components, used as a dot-product operand by Distance
	terms    []string  // returned by Terms
	isHidden bool      // returned by IsHidden
}
- func NewDenseVector(data []float32, terms []string, isHidden bool) *DenseVector {
- return &DenseVector{
- data: data,
- terms: terms,
- isHidden: isHidden,
- }
- }
- func (v *DenseVector) Distance(vector Vector) float32 {
- feedbackVector, isFeedback := vector.(*DenseVector)
- if !isFeedback {
- log.Logger().Fatal("vector type mismatch",
- zap.String("expect", reflect.TypeOf(v).String()),
- zap.String("actual", reflect.TypeOf(vector).String()))
- }
- return -floats.Dot(v.data, feedbackVector.data)
- }
- func (v *DenseVector) Terms() []string {
- return v.terms
- }
- func (v *DenseVector) IsHidden() bool {
- return v.isHidden
- }
- func (v *DenseVector) Centroid(_ []Vector, _ []int32) CentroidVector {
- panic("not implemented")
- }
// DictionaryVector is a sparse Vector. indices lists the entries that are
// present, and the weight of entry i is looked up as values[i].
type DictionaryVector struct {
	isHidden bool      // returned by IsHidden
	terms    []string  // returned by Terms
	indices  []int32   // entry ids, sorted ascending by NewDictionaryVector
	values   []float32 // weight table addressed by the entries of indices
	norm     float32   // square root of the summed weights of indices
}
- func NewDictionaryVector(indices []int32, values []float32, terms []string, isHidden bool) *DictionaryVector {
- sort.Sort(sortutil.Int32Slice(indices))
- var norm float32
- for _, i := range indices {
- norm += values[i]
- }
- norm = math32.Sqrt(norm)
- return &DictionaryVector{
- isHidden: isHidden,
- terms: terms,
- indices: indices,
- values: values,
- norm: norm,
- }
- }
- func (v *DictionaryVector) Dot(vector *DictionaryVector) (float32, float32) {
- i, j, sum, common := 0, 0, float32(0), float32(0)
- for i < len(v.indices) && j < len(vector.indices) {
- if v.indices[i] == vector.indices[j] {
- sum += v.values[v.indices[i]]
- common++
- i++
- j++
- } else if v.indices[i] < vector.indices[j] {
- i++
- } else if v.indices[i] > vector.indices[j] {
- j++
- }
- }
- return sum, common
- }
// similarityShrink is added to the shared-entry count in the denominator of
// the distance formulas, which damps scores that rest on few shared entries.
const similarityShrink = 100
- func (v *DictionaryVector) Distance(vector Vector) float32 {
- var score float32
- if dictVec, isDictVec := vector.(*DictionaryVector); !isDictVec {
- panic(fmt.Sprintf("unexpected vector type: %v", reflect.TypeOf(vector)))
- } else {
- dot, common := v.Dot(dictVec)
- if dot > 0 {
- score = -dot / v.norm / dictVec.norm * common / (common + similarityShrink)
- }
- }
- return score
- }
- func (v *DictionaryVector) Terms() []string {
- return v.terms
- }
- func (v *DictionaryVector) IsHidden() bool {
- return v.isHidden
- }
- type CentroidVector interface {
- Distance(vector Vector) float32
- }
// DictionaryCentroidVector is the CentroidVector produced by
// DictionaryVector.Centroid.
type DictionaryCentroidVector struct {
	data map[int32]float32 // normalised weight per entry, keyed by entry index
	norm float32           // Euclidean norm of the weights before normalisation
}
- func (v DictionaryVector) Centroid(vectors []Vector, indices []int32) CentroidVector {
- data := make(map[int32]float32)
- for _, i := range indices {
- vector, isDictVector := vectors[i].(*DictionaryVector)
- if !isDictVector {
- panic(fmt.Sprintf("unexpected vector type: %v", reflect.TypeOf(vector)))
- }
- for _, i := range vector.indices {
- data[i] += math32.Sqrt(vector.values[i])
- }
- }
- var norm float32
- for _, val := range data {
- norm += val * val
- }
- norm = math32.Sqrt(norm)
- for i := range data {
- data[i] /= norm
- }
- return &DictionaryCentroidVector{
- data: data,
- norm: norm,
- }
- }
- func (v *DictionaryCentroidVector) Distance(vector Vector) float32 {
- var sum, common float32
- if dictVector, isDictVec := vector.(*DictionaryVector); !isDictVec {
- panic(fmt.Sprintf("unexpected vector type: %v", reflect.TypeOf(vector)))
- } else {
- for _, i := range dictVector.indices {
- if val, exist := v.data[i]; exist {
- sum += val * math32.Sqrt(v.data[i])
- common++
- }
- }
- }
- return -sum * common / (common + similarityShrink)
- }
- type VectorIndex interface {
- Build(ctx context.Context)
- Search(q Vector, n int, prune0 bool) ([]int32, []float32)
- MultiSearch(q Vector, terms []string, n int, prune0 bool) (map[string][]int32, map[string][]float32)
- }
|