read_test.go 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155
  1. package client
  2. import (
  3. "context"
  4. "fmt"
  5. "math/rand"
  6. "testing"
  7. "github.com/samber/lo"
  8. "github.com/stretchr/testify/mock"
  9. "github.com/stretchr/testify/suite"
  10. "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
  11. "github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
  12. "github.com/milvus-io/milvus/client/v2/entity"
  13. "github.com/milvus-io/milvus/pkg/util/merr"
  14. )
  15. type ReadSuite struct {
  16. MockSuiteBase
  17. schema *entity.Schema
  18. schemaDyn *entity.Schema
  19. }
  20. func (s *ReadSuite) SetupSuite() {
  21. s.MockSuiteBase.SetupSuite()
  22. s.schema = entity.NewSchema().
  23. WithField(entity.NewField().WithName("ID").WithDataType(entity.FieldTypeInt64).WithIsPrimaryKey(true)).
  24. WithField(entity.NewField().WithName("Vector").WithDataType(entity.FieldTypeFloatVector).WithDim(128))
  25. s.schemaDyn = entity.NewSchema().WithDynamicFieldEnabled(true).
  26. WithField(entity.NewField().WithName("ID").WithDataType(entity.FieldTypeInt64).WithIsPrimaryKey(true)).
  27. WithField(entity.NewField().WithName("vector").WithDataType(entity.FieldTypeFloatVector).WithDim(128))
  28. }
  29. func (s *ReadSuite) TestSearch() {
  30. ctx, cancel := context.WithCancel(context.Background())
  31. defer cancel()
  32. s.Run("success", func() {
  33. collectionName := fmt.Sprintf("coll_%s", s.randString(6))
  34. partitionName := fmt.Sprintf("part_%s", s.randString(6))
  35. s.setupCache(collectionName, s.schema)
  36. s.mock.EXPECT().Search(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, sr *milvuspb.SearchRequest) (*milvuspb.SearchResults, error) {
  37. s.Equal(collectionName, sr.GetCollectionName())
  38. s.ElementsMatch([]string{partitionName}, sr.GetPartitionNames())
  39. return &milvuspb.SearchResults{
  40. Status: merr.Success(),
  41. Results: &schemapb.SearchResultData{
  42. NumQueries: 1,
  43. TopK: 10,
  44. FieldsData: []*schemapb.FieldData{
  45. s.getInt64FieldData("ID", []int64{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}),
  46. },
  47. Ids: &schemapb.IDs{
  48. IdField: &schemapb.IDs_IntId{
  49. IntId: &schemapb.LongArray{
  50. Data: []int64{1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
  51. },
  52. },
  53. },
  54. Scores: make([]float32, 10),
  55. Topks: []int64{10},
  56. },
  57. }, nil
  58. }).Once()
  59. _, err := s.client.Search(ctx, NewSearchOption(collectionName, 10, []entity.Vector{
  60. entity.FloatVector(lo.RepeatBy(128, func(_ int) float32 {
  61. return rand.Float32()
  62. })),
  63. }).WithPartitions(partitionName))
  64. s.NoError(err)
  65. })
  66. s.Run("dynamic_schema", func() {
  67. collectionName := fmt.Sprintf("coll_%s", s.randString(6))
  68. partitionName := fmt.Sprintf("part_%s", s.randString(6))
  69. s.setupCache(collectionName, s.schemaDyn)
  70. s.mock.EXPECT().Search(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, sr *milvuspb.SearchRequest) (*milvuspb.SearchResults, error) {
  71. return &milvuspb.SearchResults{
  72. Status: merr.Success(),
  73. Results: &schemapb.SearchResultData{
  74. NumQueries: 1,
  75. TopK: 2,
  76. FieldsData: []*schemapb.FieldData{
  77. s.getInt64FieldData("ID", []int64{1, 2}),
  78. s.getJSONBytesFieldData("$meta", [][]byte{
  79. []byte(`{"A": 123, "B": "456"}`),
  80. []byte(`{"B": "abc", "A": 456}`),
  81. }, true),
  82. },
  83. Ids: &schemapb.IDs{
  84. IdField: &schemapb.IDs_IntId{
  85. IntId: &schemapb.LongArray{
  86. Data: []int64{1, 2},
  87. },
  88. },
  89. },
  90. Scores: make([]float32, 2),
  91. Topks: []int64{2},
  92. },
  93. }, nil
  94. }).Once()
  95. _, err := s.client.Search(ctx, NewSearchOption(collectionName, 10, []entity.Vector{
  96. entity.FloatVector(lo.RepeatBy(128, func(_ int) float32 {
  97. return rand.Float32()
  98. })),
  99. }).WithPartitions(partitionName))
  100. s.NoError(err)
  101. })
  102. s.Run("failure", func() {
  103. collectionName := fmt.Sprintf("coll_%s", s.randString(6))
  104. s.setupCache(collectionName, s.schemaDyn)
  105. s.mock.EXPECT().Search(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, sr *milvuspb.SearchRequest) (*milvuspb.SearchResults, error) {
  106. return nil, merr.WrapErrServiceInternal("mocked")
  107. }).Once()
  108. _, err := s.client.Search(ctx, NewSearchOption(collectionName, 10, []entity.Vector{
  109. entity.FloatVector(lo.RepeatBy(128, func(_ int) float32 {
  110. return rand.Float32()
  111. })),
  112. }))
  113. s.Error(err)
  114. })
  115. }
  116. func (s *ReadSuite) TestQuery() {
  117. ctx, cancel := context.WithCancel(context.Background())
  118. defer cancel()
  119. s.Run("success", func() {
  120. collectionName := fmt.Sprintf("coll_%s", s.randString(6))
  121. partitionName := fmt.Sprintf("part_%s", s.randString(6))
  122. s.setupCache(collectionName, s.schema)
  123. s.mock.EXPECT().Query(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, qr *milvuspb.QueryRequest) (*milvuspb.QueryResults, error) {
  124. s.Equal(collectionName, qr.GetCollectionName())
  125. return &milvuspb.QueryResults{}, nil
  126. }).Once()
  127. _, err := s.client.Query(ctx, NewQueryOption(collectionName).WithPartitions(partitionName))
  128. s.NoError(err)
  129. })
  130. }
  131. func TestRead(t *testing.T) {
  132. suite.Run(t, new(ReadSuite))
  133. }