123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282 |
- // Licensed to the LF AI & Data foundation under one
- // or more contributor license agreements. See the NOTICE file
- // distributed with this work for additional information
- // regarding copyright ownership. The ASF licenses this file
- // to you under the Apache License, Version 2.0 (the
- // "License"); you may not use this file except in compliance
- // with the License. You may obtain a copy of the License at
- //
- // http://www.apache.org/licenses/LICENSE-2.0
- //
- // Unless required by applicable law or agreed to in writing, software
- // distributed under the License is distributed on an "AS IS" BASIS,
- // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- // See the License for the specific language governing permissions and
- // limitations under the License.
- package integration
- import (
- "bytes"
- "context"
- "encoding/binary"
- "encoding/json"
- "math/rand"
- "strconv"
- "time"
- "google.golang.org/protobuf/proto"
- "github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
- "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
- "github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
- "github.com/milvus-io/milvus/pkg/common"
- "github.com/milvus-io/milvus/pkg/util/testutils"
- )
- const (
- AnnsFieldKey = "anns_field"
- TopKKey = "topk"
- NQKey = "nq"
- MetricTypeKey = common.MetricTypeKey
- SearchParamsKey = common.IndexParamsKey
- RoundDecimalKey = "round_decimal"
- OffsetKey = "offset"
- LimitKey = "limit"
- )
- func (s *MiniClusterSuite) WaitForLoadWithDB(ctx context.Context, dbName, collection string) {
- s.waitForLoadInternal(ctx, dbName, collection)
- }
- func (s *MiniClusterSuite) WaitForLoad(ctx context.Context, collection string) {
- s.waitForLoadInternal(ctx, "", collection)
- }
- func (s *MiniClusterSuite) waitForLoadInternal(ctx context.Context, dbName, collection string) {
- cluster := s.Cluster
- getLoadingProgress := func() *milvuspb.GetLoadingProgressResponse {
- loadProgress, err := cluster.Proxy.GetLoadingProgress(ctx, &milvuspb.GetLoadingProgressRequest{
- DbName: dbName,
- CollectionName: collection,
- })
- if err != nil {
- panic("GetLoadingProgress fail")
- }
- return loadProgress
- }
- for getLoadingProgress().GetProgress() != 100 {
- select {
- case <-ctx.Done():
- s.FailNow("failed to wait for load")
- return
- default:
- time.Sleep(500 * time.Millisecond)
- }
- }
- }
- func (s *MiniClusterSuite) WaitForLoadRefresh(ctx context.Context, dbName, collection string) {
- cluster := s.Cluster
- getLoadingProgress := func() *milvuspb.GetLoadingProgressResponse {
- loadProgress, err := cluster.Proxy.GetLoadingProgress(ctx, &milvuspb.GetLoadingProgressRequest{
- DbName: dbName,
- CollectionName: collection,
- })
- if err != nil {
- panic("GetLoadingProgress fail")
- }
- return loadProgress
- }
- for getLoadingProgress().GetRefreshProgress() != 100 {
- select {
- case <-ctx.Done():
- s.FailNow("failed to wait for load (refresh)")
- return
- default:
- time.Sleep(500 * time.Millisecond)
- }
- }
- }
- func ConstructSearchRequest(
- dbName, collectionName string,
- expr string,
- vecField string,
- vectorType schemapb.DataType,
- outputFields []string,
- metricType string,
- params map[string]any,
- nq, dim int, topk, roundDecimal int,
- ) *milvuspb.SearchRequest {
- b, err := json.Marshal(params)
- if err != nil {
- panic(err)
- }
- plg := constructPlaceholderGroup(nq, dim, vectorType)
- plgBs, err := proto.Marshal(plg)
- if err != nil {
- panic(err)
- }
- return &milvuspb.SearchRequest{
- Base: nil,
- DbName: dbName,
- CollectionName: collectionName,
- PartitionNames: nil,
- Dsl: expr,
- PlaceholderGroup: plgBs,
- DslType: commonpb.DslType_BoolExprV1,
- OutputFields: outputFields,
- SearchParams: []*commonpb.KeyValuePair{
- {
- Key: common.MetricTypeKey,
- Value: metricType,
- },
- {
- Key: SearchParamsKey,
- Value: string(b),
- },
- {
- Key: AnnsFieldKey,
- Value: vecField,
- },
- {
- Key: common.TopKKey,
- Value: strconv.Itoa(topk),
- },
- {
- Key: RoundDecimalKey,
- Value: strconv.Itoa(roundDecimal),
- },
- },
- TravelTimestamp: 0,
- GuaranteeTimestamp: 0,
- Nq: int64(nq),
- }
- }
- func ConstructSearchRequestWithConsistencyLevel(
- dbName, collectionName string,
- expr string,
- vecField string,
- vectorType schemapb.DataType,
- outputFields []string,
- metricType string,
- params map[string]any,
- nq, dim int, topk, roundDecimal int,
- useDefaultConsistency bool,
- consistencyLevel commonpb.ConsistencyLevel,
- ) *milvuspb.SearchRequest {
- b, err := json.Marshal(params)
- if err != nil {
- panic(err)
- }
- plg := constructPlaceholderGroup(nq, dim, vectorType)
- plgBs, err := proto.Marshal(plg)
- if err != nil {
- panic(err)
- }
- return &milvuspb.SearchRequest{
- Base: nil,
- DbName: dbName,
- CollectionName: collectionName,
- PartitionNames: nil,
- Dsl: expr,
- PlaceholderGroup: plgBs,
- DslType: commonpb.DslType_BoolExprV1,
- OutputFields: outputFields,
- SearchParams: []*commonpb.KeyValuePair{
- {
- Key: common.MetricTypeKey,
- Value: metricType,
- },
- {
- Key: SearchParamsKey,
- Value: string(b),
- },
- {
- Key: AnnsFieldKey,
- Value: vecField,
- },
- {
- Key: common.TopKKey,
- Value: strconv.Itoa(topk),
- },
- {
- Key: RoundDecimalKey,
- Value: strconv.Itoa(roundDecimal),
- },
- },
- TravelTimestamp: 0,
- GuaranteeTimestamp: 0,
- UseDefaultConsistency: useDefaultConsistency,
- ConsistencyLevel: consistencyLevel,
- }
- }
- func constructPlaceholderGroup(nq, dim int, vectorType schemapb.DataType) *commonpb.PlaceholderGroup {
- values := make([][]byte, 0, nq)
- var placeholderType commonpb.PlaceholderType
- switch vectorType {
- case schemapb.DataType_FloatVector:
- placeholderType = commonpb.PlaceholderType_FloatVector
- for i := 0; i < nq; i++ {
- bs := make([]byte, 0, dim*4)
- for j := 0; j < dim; j++ {
- var buffer bytes.Buffer
- f := rand.Float32()
- err := binary.Write(&buffer, common.Endian, f)
- if err != nil {
- panic(err)
- }
- bs = append(bs, buffer.Bytes()...)
- }
- values = append(values, bs)
- }
- case schemapb.DataType_BinaryVector:
- placeholderType = commonpb.PlaceholderType_BinaryVector
- for i := 0; i < nq; i++ {
- total := dim / 8
- ret := make([]byte, total)
- _, err := rand.Read(ret)
- if err != nil {
- panic(err)
- }
- values = append(values, ret)
- }
- case schemapb.DataType_Float16Vector:
- placeholderType = commonpb.PlaceholderType_Float16Vector
- data := testutils.GenerateFloat16Vectors(nq, dim)
- for i := 0; i < nq; i++ {
- rowBytes := dim * 2
- values = append(values, data[rowBytes*i:rowBytes*(i+1)])
- }
- case schemapb.DataType_BFloat16Vector:
- placeholderType = commonpb.PlaceholderType_BFloat16Vector
- data := testutils.GenerateBFloat16Vectors(nq, dim)
- for i := 0; i < nq; i++ {
- rowBytes := dim * 2
- values = append(values, data[rowBytes*i:rowBytes*(i+1)])
- }
- case schemapb.DataType_SparseFloatVector:
- // for sparse, all query rows are encoded in a single byte array
- values = make([][]byte, 0, 1)
- placeholderType = commonpb.PlaceholderType_SparseFloatVector
- sparseVecs := GenerateSparseFloatArray(nq)
- values = append(values, sparseVecs.Contents...)
- default:
- panic("invalid vector data type")
- }
- return &commonpb.PlaceholderGroup{
- Placeholders: []*commonpb.PlaceholderValue{
- {
- Tag: "$0",
- Type: placeholderType,
- Values: values,
- },
- },
- }
- }
|