bm25_function.go 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159
  1. /*
  2. * # Licensed to the LF AI & Data foundation under one
  3. * # or more contributor license agreements. See the NOTICE file
  4. * # distributed with this work for additional information
  5. * # regarding copyright ownership. The ASF licenses this file
  6. * # to you under the Apache License, Version 2.0 (the
  7. * # "License"); you may not use this file except in compliance
  8. * # with the License. You may obtain a copy of the License at
  9. * #
  10. * # http://www.apache.org/licenses/LICENSE-2.0
  11. * #
  12. * # Unless required by applicable law or agreed to in writing, software
  13. * # distributed under the License is distributed on an "AS IS" BASIS,
  14. * # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. * # See the License for the specific language governing permissions and
  16. * # limitations under the License.
  17. */
  18. package function
  19. import (
  20. "fmt"
  21. "sync"
  22. "github.com/samber/lo"
  23. "github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
  24. "github.com/milvus-io/milvus/internal/util/ctokenizer"
  25. "github.com/milvus-io/milvus/internal/util/tokenizerapi"
  26. "github.com/milvus-io/milvus/pkg/util/typeutil"
  27. )
  28. // BM25 Runner
  29. // Input: string
  30. // Output: map[uint32]float32
  31. type BM25FunctionRunner struct {
  32. tokenizer tokenizerapi.Tokenizer
  33. schema *schemapb.FunctionSchema
  34. outputField *schemapb.FieldSchema
  35. concurrency int
  36. }
  37. func NewBM25FunctionRunner(coll *schemapb.CollectionSchema, schema *schemapb.FunctionSchema) (*BM25FunctionRunner, error) {
  38. if len(schema.GetOutputFieldIds()) != 1 {
  39. return nil, fmt.Errorf("bm25 function should only have one output field, but now %d", len(schema.GetOutputFieldIds()))
  40. }
  41. runner := &BM25FunctionRunner{
  42. schema: schema,
  43. concurrency: 8,
  44. }
  45. for _, field := range coll.GetFields() {
  46. if field.GetFieldID() == schema.GetOutputFieldIds()[0] {
  47. runner.outputField = field
  48. break
  49. }
  50. }
  51. if runner.outputField == nil {
  52. return nil, fmt.Errorf("no output field")
  53. }
  54. tokenizer, err := ctokenizer.NewTokenizer(map[string]string{})
  55. if err != nil {
  56. return nil, err
  57. }
  58. runner.tokenizer = tokenizer
  59. return runner, nil
  60. }
  61. func (v *BM25FunctionRunner) run(data []string, dst []map[uint32]float32) error {
  62. // TODO AOIASD Support single Tokenizer concurrency
  63. tokenizer, err := ctokenizer.NewTokenizer(map[string]string{})
  64. if err != nil {
  65. return err
  66. }
  67. defer tokenizer.Destroy()
  68. for i := 0; i < len(data); i++ {
  69. embeddingMap := map[uint32]float32{}
  70. tokenStream := tokenizer.NewTokenStream(data[i])
  71. defer tokenStream.Destroy()
  72. for tokenStream.Advance() {
  73. token := tokenStream.Token()
  74. // TODO More Hash Option
  75. hash := typeutil.HashString2Uint32(token)
  76. embeddingMap[hash] += 1
  77. }
  78. dst[i] = embeddingMap
  79. }
  80. return nil
  81. }
  82. func (v *BM25FunctionRunner) BatchRun(inputs ...any) ([]any, error) {
  83. if len(inputs) > 1 {
  84. return nil, fmt.Errorf("BM25 function receieve more than one input")
  85. }
  86. text, ok := inputs[0].([]string)
  87. if !ok {
  88. return nil, fmt.Errorf("BM25 function batch input not string list")
  89. }
  90. rowNum := len(text)
  91. embedData := make([]map[uint32]float32, rowNum)
  92. wg := sync.WaitGroup{}
  93. errCh := make(chan error, v.concurrency)
  94. for i, j := 0, 0; i < v.concurrency && j < rowNum; i++ {
  95. start := j
  96. end := start + rowNum/v.concurrency
  97. if i < rowNum%v.concurrency {
  98. end += 1
  99. }
  100. wg.Add(1)
  101. go func() {
  102. defer wg.Done()
  103. err := v.run(text[start:end], embedData[start:end])
  104. if err != nil {
  105. errCh <- err
  106. return
  107. }
  108. }()
  109. j = end
  110. }
  111. wg.Wait()
  112. close(errCh)
  113. for err := range errCh {
  114. if err != nil {
  115. return nil, err
  116. }
  117. }
  118. return []any{buildSparseFloatArray(embedData)}, nil
  119. }
  120. func (v *BM25FunctionRunner) GetSchema() *schemapb.FunctionSchema {
  121. return v.schema
  122. }
  123. func (v *BM25FunctionRunner) GetOutputFields() []*schemapb.FieldSchema {
  124. return []*schemapb.FieldSchema{v.outputField}
  125. }
  126. func buildSparseFloatArray(mapdata []map[uint32]float32) *schemapb.SparseFloatArray {
  127. dim := 0
  128. bytes := lo.Map(mapdata, func(sparseMap map[uint32]float32, _ int) []byte {
  129. if len(sparseMap) > dim {
  130. dim = len(sparseMap)
  131. }
  132. return typeutil.CreateAndSortSparseFloatRow(sparseMap)
  133. })
  134. return &schemapb.SparseFloatArray{
  135. Contents: bytes,
  136. Dim: int64(dim),
  137. }
  138. }