read_options.go 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265
  1. // Licensed to the LF AI & Data foundation under one
  2. // or more contributor license agreements. See the NOTICE file
  3. // distributed with this work for additional information
  4. // regarding copyright ownership. The ASF licenses this file
  5. // to you under the Apache License, Version 2.0 (the
  6. // "License"); you may not use this file except in compliance
  7. // with the License. You may obtain a copy of the License at
  8. //
  9. // http://www.apache.org/licenses/LICENSE-2.0
  10. //
  11. // Unless required by applicable law or agreed to in writing, software
  12. // distributed under the License is distributed on an "AS IS" BASIS,
  13. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. // See the License for the specific language governing permissions and
  15. // limitations under the License.
  16. package client
  17. import (
  18. "encoding/json"
  19. "strconv"
  20. "google.golang.org/protobuf/proto"
  21. "github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
  22. "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
  23. "github.com/milvus-io/milvus/client/v2/entity"
  24. )
  25. const (
  26. spAnnsField = `anns_field`
  27. spTopK = `topk`
  28. spOffset = `offset`
  29. spLimit = `limit`
  30. spParams = `params`
  31. spMetricsType = `metric_type`
  32. spRoundDecimal = `round_decimal`
  33. spIgnoreGrowing = `ignore_growing`
  34. spGroupBy = `group_by_field`
  35. )
  36. type SearchOption interface {
  37. Request() *milvuspb.SearchRequest
  38. }
  39. var _ SearchOption = (*searchOption)(nil)
  40. type searchOption struct {
  41. collectionName string
  42. partitionNames []string
  43. topK int
  44. offset int
  45. outputFields []string
  46. consistencyLevel entity.ConsistencyLevel
  47. useDefaultConsistencyLevel bool
  48. ignoreGrowing bool
  49. expr string
  50. // normal search request
  51. request *annRequest
  52. // TODO add sub request when support hybrid search
  53. }
  54. type annRequest struct {
  55. vectors []entity.Vector
  56. annField string
  57. metricsType entity.MetricType
  58. searchParam map[string]string
  59. groupByField string
  60. }
  61. func (opt *searchOption) Request() *milvuspb.SearchRequest {
  62. // TODO check whether search is hybrid after logic merged
  63. return opt.prepareSearchRequest(opt.request)
  64. }
  65. func (opt *searchOption) prepareSearchRequest(annRequest *annRequest) *milvuspb.SearchRequest {
  66. request := &milvuspb.SearchRequest{
  67. CollectionName: opt.collectionName,
  68. PartitionNames: opt.partitionNames,
  69. Dsl: opt.expr,
  70. DslType: commonpb.DslType_BoolExprV1,
  71. ConsistencyLevel: commonpb.ConsistencyLevel(opt.consistencyLevel),
  72. OutputFields: opt.outputFields,
  73. }
  74. if annRequest != nil {
  75. // nq
  76. request.Nq = int64(len(annRequest.vectors))
  77. // search param
  78. bs, _ := json.Marshal(annRequest.searchParam)
  79. params := map[string]string{
  80. spAnnsField: annRequest.annField,
  81. spTopK: strconv.Itoa(opt.topK),
  82. spOffset: strconv.Itoa(opt.offset),
  83. spParams: string(bs),
  84. spMetricsType: string(annRequest.metricsType),
  85. spRoundDecimal: "-1",
  86. spIgnoreGrowing: strconv.FormatBool(opt.ignoreGrowing),
  87. }
  88. if annRequest.groupByField != "" {
  89. params[spGroupBy] = annRequest.groupByField
  90. }
  91. request.SearchParams = entity.MapKvPairs(params)
  92. // placeholder group
  93. request.PlaceholderGroup = vector2PlaceholderGroupBytes(annRequest.vectors)
  94. }
  95. return request
  96. }
  97. func (opt *searchOption) WithFilter(expr string) *searchOption {
  98. opt.expr = expr
  99. return opt
  100. }
  101. func (opt *searchOption) WithOffset(offset int) *searchOption {
  102. opt.offset = offset
  103. return opt
  104. }
  105. func (opt *searchOption) WithOutputFields(fieldNames ...string) *searchOption {
  106. opt.outputFields = fieldNames
  107. return opt
  108. }
  109. func (opt *searchOption) WithConsistencyLevel(consistencyLevel entity.ConsistencyLevel) *searchOption {
  110. opt.consistencyLevel = consistencyLevel
  111. opt.useDefaultConsistencyLevel = false
  112. return opt
  113. }
  114. func (opt *searchOption) WithANNSField(annsField string) *searchOption {
  115. opt.request.annField = annsField
  116. return opt
  117. }
  118. func (opt *searchOption) WithPartitions(partitionNames ...string) *searchOption {
  119. opt.partitionNames = partitionNames
  120. return opt
  121. }
  122. func NewSearchOption(collectionName string, limit int, vectors []entity.Vector) *searchOption {
  123. return &searchOption{
  124. collectionName: collectionName,
  125. topK: limit,
  126. request: &annRequest{
  127. vectors: vectors,
  128. },
  129. useDefaultConsistencyLevel: true,
  130. consistencyLevel: entity.ClBounded,
  131. }
  132. }
  133. func vector2PlaceholderGroupBytes(vectors []entity.Vector) []byte {
  134. phg := &commonpb.PlaceholderGroup{
  135. Placeholders: []*commonpb.PlaceholderValue{
  136. vector2Placeholder(vectors),
  137. },
  138. }
  139. bs, _ := proto.Marshal(phg)
  140. return bs
  141. }
  142. func vector2Placeholder(vectors []entity.Vector) *commonpb.PlaceholderValue {
  143. var placeHolderType commonpb.PlaceholderType
  144. ph := &commonpb.PlaceholderValue{
  145. Tag: "$0",
  146. Values: make([][]byte, 0, len(vectors)),
  147. }
  148. if len(vectors) == 0 {
  149. return ph
  150. }
  151. switch vectors[0].(type) {
  152. case entity.FloatVector:
  153. placeHolderType = commonpb.PlaceholderType_FloatVector
  154. case entity.BinaryVector:
  155. placeHolderType = commonpb.PlaceholderType_BinaryVector
  156. case entity.BFloat16Vector:
  157. placeHolderType = commonpb.PlaceholderType_BFloat16Vector
  158. case entity.Float16Vector:
  159. placeHolderType = commonpb.PlaceholderType_Float16Vector
  160. case entity.SparseEmbedding:
  161. placeHolderType = commonpb.PlaceholderType_SparseFloatVector
  162. }
  163. ph.Type = placeHolderType
  164. for _, vector := range vectors {
  165. ph.Values = append(ph.Values, vector.Serialize())
  166. }
  167. return ph
  168. }
  169. type QueryOption interface {
  170. Request() *milvuspb.QueryRequest
  171. }
  172. type queryOption struct {
  173. collectionName string
  174. partitionNames []string
  175. queryParams map[string]string
  176. outputFields []string
  177. consistencyLevel entity.ConsistencyLevel
  178. useDefaultConsistencyLevel bool
  179. expr string
  180. }
  181. func (opt *queryOption) Request() *milvuspb.QueryRequest {
  182. return &milvuspb.QueryRequest{
  183. CollectionName: opt.collectionName,
  184. PartitionNames: opt.partitionNames,
  185. OutputFields: opt.outputFields,
  186. Expr: opt.expr,
  187. QueryParams: entity.MapKvPairs(opt.queryParams),
  188. ConsistencyLevel: opt.consistencyLevel.CommonConsistencyLevel(),
  189. }
  190. }
  191. func (opt *queryOption) WithFilter(expr string) *queryOption {
  192. opt.expr = expr
  193. return opt
  194. }
  195. func (opt *queryOption) WithOffset(offset int) *queryOption {
  196. if opt.queryParams == nil {
  197. opt.queryParams = make(map[string]string)
  198. }
  199. opt.queryParams[spOffset] = strconv.Itoa(offset)
  200. return opt
  201. }
  202. func (opt *queryOption) WithLimit(limit int) *queryOption {
  203. if opt.queryParams == nil {
  204. opt.queryParams = make(map[string]string)
  205. }
  206. opt.queryParams[spLimit] = strconv.Itoa(limit)
  207. return opt
  208. }
  209. func (opt *queryOption) WithOutputFields(fieldNames ...string) *queryOption {
  210. opt.outputFields = fieldNames
  211. return opt
  212. }
  213. func (opt *queryOption) WithConsistencyLevel(consistencyLevel entity.ConsistencyLevel) *queryOption {
  214. opt.consistencyLevel = consistencyLevel
  215. opt.useDefaultConsistencyLevel = false
  216. return opt
  217. }
  218. func (opt *queryOption) WithPartitions(partitionNames ...string) *queryOption {
  219. opt.partitionNames = partitionNames
  220. return opt
  221. }
  222. func NewQueryOption(collectionName string) *queryOption {
  223. return &queryOption{
  224. collectionName: collectionName,
  225. useDefaultConsistencyLevel: true,
  226. consistencyLevel: entity.ClBounded,
  227. }
  228. }