util_query.go 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282
  1. // Licensed to the LF AI & Data foundation under one
  2. // or more contributor license agreements. See the NOTICE file
  3. // distributed with this work for additional information
  4. // regarding copyright ownership. The ASF licenses this file
  5. // to you under the Apache License, Version 2.0 (the
  6. // "License"); you may not use this file except in compliance
  7. // with the License. You may obtain a copy of the License at
  8. //
  9. // http://www.apache.org/licenses/LICENSE-2.0
  10. //
  11. // Unless required by applicable law or agreed to in writing, software
  12. // distributed under the License is distributed on an "AS IS" BASIS,
  13. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. // See the License for the specific language governing permissions and
  15. // limitations under the License.
  16. package integration
  17. import (
  18. "bytes"
  19. "context"
  20. "encoding/binary"
  21. "encoding/json"
  22. "math/rand"
  23. "strconv"
  24. "time"
  25. "google.golang.org/protobuf/proto"
  26. "github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
  27. "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
  28. "github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
  29. "github.com/milvus-io/milvus/pkg/common"
  30. "github.com/milvus-io/milvus/pkg/util/testutils"
  31. )
  32. const (
  33. AnnsFieldKey = "anns_field"
  34. TopKKey = "topk"
  35. NQKey = "nq"
  36. MetricTypeKey = common.MetricTypeKey
  37. SearchParamsKey = common.IndexParamsKey
  38. RoundDecimalKey = "round_decimal"
  39. OffsetKey = "offset"
  40. LimitKey = "limit"
  41. )
  42. func (s *MiniClusterSuite) WaitForLoadWithDB(ctx context.Context, dbName, collection string) {
  43. s.waitForLoadInternal(ctx, dbName, collection)
  44. }
  45. func (s *MiniClusterSuite) WaitForLoad(ctx context.Context, collection string) {
  46. s.waitForLoadInternal(ctx, "", collection)
  47. }
  48. func (s *MiniClusterSuite) waitForLoadInternal(ctx context.Context, dbName, collection string) {
  49. cluster := s.Cluster
  50. getLoadingProgress := func() *milvuspb.GetLoadingProgressResponse {
  51. loadProgress, err := cluster.Proxy.GetLoadingProgress(ctx, &milvuspb.GetLoadingProgressRequest{
  52. DbName: dbName,
  53. CollectionName: collection,
  54. })
  55. if err != nil {
  56. panic("GetLoadingProgress fail")
  57. }
  58. return loadProgress
  59. }
  60. for getLoadingProgress().GetProgress() != 100 {
  61. select {
  62. case <-ctx.Done():
  63. s.FailNow("failed to wait for load")
  64. return
  65. default:
  66. time.Sleep(500 * time.Millisecond)
  67. }
  68. }
  69. }
  70. func (s *MiniClusterSuite) WaitForLoadRefresh(ctx context.Context, dbName, collection string) {
  71. cluster := s.Cluster
  72. getLoadingProgress := func() *milvuspb.GetLoadingProgressResponse {
  73. loadProgress, err := cluster.Proxy.GetLoadingProgress(ctx, &milvuspb.GetLoadingProgressRequest{
  74. DbName: dbName,
  75. CollectionName: collection,
  76. })
  77. if err != nil {
  78. panic("GetLoadingProgress fail")
  79. }
  80. return loadProgress
  81. }
  82. for getLoadingProgress().GetRefreshProgress() != 100 {
  83. select {
  84. case <-ctx.Done():
  85. s.FailNow("failed to wait for load (refresh)")
  86. return
  87. default:
  88. time.Sleep(500 * time.Millisecond)
  89. }
  90. }
  91. }
  92. func ConstructSearchRequest(
  93. dbName, collectionName string,
  94. expr string,
  95. vecField string,
  96. vectorType schemapb.DataType,
  97. outputFields []string,
  98. metricType string,
  99. params map[string]any,
  100. nq, dim int, topk, roundDecimal int,
  101. ) *milvuspb.SearchRequest {
  102. b, err := json.Marshal(params)
  103. if err != nil {
  104. panic(err)
  105. }
  106. plg := constructPlaceholderGroup(nq, dim, vectorType)
  107. plgBs, err := proto.Marshal(plg)
  108. if err != nil {
  109. panic(err)
  110. }
  111. return &milvuspb.SearchRequest{
  112. Base: nil,
  113. DbName: dbName,
  114. CollectionName: collectionName,
  115. PartitionNames: nil,
  116. Dsl: expr,
  117. PlaceholderGroup: plgBs,
  118. DslType: commonpb.DslType_BoolExprV1,
  119. OutputFields: outputFields,
  120. SearchParams: []*commonpb.KeyValuePair{
  121. {
  122. Key: common.MetricTypeKey,
  123. Value: metricType,
  124. },
  125. {
  126. Key: SearchParamsKey,
  127. Value: string(b),
  128. },
  129. {
  130. Key: AnnsFieldKey,
  131. Value: vecField,
  132. },
  133. {
  134. Key: common.TopKKey,
  135. Value: strconv.Itoa(topk),
  136. },
  137. {
  138. Key: RoundDecimalKey,
  139. Value: strconv.Itoa(roundDecimal),
  140. },
  141. },
  142. TravelTimestamp: 0,
  143. GuaranteeTimestamp: 0,
  144. Nq: int64(nq),
  145. }
  146. }
  147. func ConstructSearchRequestWithConsistencyLevel(
  148. dbName, collectionName string,
  149. expr string,
  150. vecField string,
  151. vectorType schemapb.DataType,
  152. outputFields []string,
  153. metricType string,
  154. params map[string]any,
  155. nq, dim int, topk, roundDecimal int,
  156. useDefaultConsistency bool,
  157. consistencyLevel commonpb.ConsistencyLevel,
  158. ) *milvuspb.SearchRequest {
  159. b, err := json.Marshal(params)
  160. if err != nil {
  161. panic(err)
  162. }
  163. plg := constructPlaceholderGroup(nq, dim, vectorType)
  164. plgBs, err := proto.Marshal(plg)
  165. if err != nil {
  166. panic(err)
  167. }
  168. return &milvuspb.SearchRequest{
  169. Base: nil,
  170. DbName: dbName,
  171. CollectionName: collectionName,
  172. PartitionNames: nil,
  173. Dsl: expr,
  174. PlaceholderGroup: plgBs,
  175. DslType: commonpb.DslType_BoolExprV1,
  176. OutputFields: outputFields,
  177. SearchParams: []*commonpb.KeyValuePair{
  178. {
  179. Key: common.MetricTypeKey,
  180. Value: metricType,
  181. },
  182. {
  183. Key: SearchParamsKey,
  184. Value: string(b),
  185. },
  186. {
  187. Key: AnnsFieldKey,
  188. Value: vecField,
  189. },
  190. {
  191. Key: common.TopKKey,
  192. Value: strconv.Itoa(topk),
  193. },
  194. {
  195. Key: RoundDecimalKey,
  196. Value: strconv.Itoa(roundDecimal),
  197. },
  198. },
  199. TravelTimestamp: 0,
  200. GuaranteeTimestamp: 0,
  201. UseDefaultConsistency: useDefaultConsistency,
  202. ConsistencyLevel: consistencyLevel,
  203. }
  204. }
  205. func constructPlaceholderGroup(nq, dim int, vectorType schemapb.DataType) *commonpb.PlaceholderGroup {
  206. values := make([][]byte, 0, nq)
  207. var placeholderType commonpb.PlaceholderType
  208. switch vectorType {
  209. case schemapb.DataType_FloatVector:
  210. placeholderType = commonpb.PlaceholderType_FloatVector
  211. for i := 0; i < nq; i++ {
  212. bs := make([]byte, 0, dim*4)
  213. for j := 0; j < dim; j++ {
  214. var buffer bytes.Buffer
  215. f := rand.Float32()
  216. err := binary.Write(&buffer, common.Endian, f)
  217. if err != nil {
  218. panic(err)
  219. }
  220. bs = append(bs, buffer.Bytes()...)
  221. }
  222. values = append(values, bs)
  223. }
  224. case schemapb.DataType_BinaryVector:
  225. placeholderType = commonpb.PlaceholderType_BinaryVector
  226. for i := 0; i < nq; i++ {
  227. total := dim / 8
  228. ret := make([]byte, total)
  229. _, err := rand.Read(ret)
  230. if err != nil {
  231. panic(err)
  232. }
  233. values = append(values, ret)
  234. }
  235. case schemapb.DataType_Float16Vector:
  236. placeholderType = commonpb.PlaceholderType_Float16Vector
  237. data := testutils.GenerateFloat16Vectors(nq, dim)
  238. for i := 0; i < nq; i++ {
  239. rowBytes := dim * 2
  240. values = append(values, data[rowBytes*i:rowBytes*(i+1)])
  241. }
  242. case schemapb.DataType_BFloat16Vector:
  243. placeholderType = commonpb.PlaceholderType_BFloat16Vector
  244. data := testutils.GenerateBFloat16Vectors(nq, dim)
  245. for i := 0; i < nq; i++ {
  246. rowBytes := dim * 2
  247. values = append(values, data[rowBytes*i:rowBytes*(i+1)])
  248. }
  249. case schemapb.DataType_SparseFloatVector:
  250. // for sparse, all query rows are encoded in a single byte array
  251. values = make([][]byte, 0, 1)
  252. placeholderType = commonpb.PlaceholderType_SparseFloatVector
  253. sparseVecs := GenerateSparseFloatArray(nq)
  254. values = append(values, sparseVecs.Contents...)
  255. default:
  256. panic("invalid vector data type")
  257. }
  258. return &commonpb.PlaceholderGroup{
  259. Placeholders: []*commonpb.PlaceholderValue{
  260. {
  261. Tag: "$0",
  262. Type: placeholderType,
  263. Values: values,
  264. },
  265. },
  266. }
  267. }