range_search_test.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341
  1. // Licensed to the LF AI & Data foundation under one
  2. // or more contributor license agreements. See the NOTICE file
  3. // distributed with this work for additional information
  4. // regarding copyright ownership. The ASF licenses this file
  5. // to you under the Apache License, Version 2.0 (the
  6. // "License"); you may not use this file except in compliance
  7. // with the License. You may obtain a copy of the License at
  8. //
  9. // http://www.apache.org/licenses/LICENSE-2.0
  10. //
  11. // Unless required by applicable law or agreed to in writing, software
  12. // distributed under the License is distributed on an "AS IS" BASIS,
  13. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. // See the License for the specific language governing permissions and
  15. // limitations under the License.
  16. package rangesearch
  17. import (
  18. "context"
  19. "fmt"
  20. "testing"
  21. "github.com/stretchr/testify/suite"
  22. "go.uber.org/zap"
  23. "google.golang.org/protobuf/proto"
  24. "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
  25. "github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
  26. "github.com/milvus-io/milvus/pkg/common"
  27. "github.com/milvus-io/milvus/pkg/log"
  28. "github.com/milvus-io/milvus/pkg/util/funcutil"
  29. "github.com/milvus-io/milvus/pkg/util/merr"
  30. "github.com/milvus-io/milvus/pkg/util/metric"
  31. "github.com/milvus-io/milvus/tests/integration"
  32. )
  33. type RangeSearchSuite struct {
  34. integration.MiniClusterSuite
  35. }
  36. func (s *RangeSearchSuite) TestRangeSearchIP() {
  37. c := s.Cluster
  38. ctx, cancel := context.WithCancel(c.GetContext())
  39. defer cancel()
  40. prefix := "TestRangeSearchIP"
  41. dbName := ""
  42. collectionName := prefix + funcutil.GenRandomStr()
  43. dim := 128
  44. rowNum := 3000
  45. schema := integration.ConstructSchema(collectionName, dim, true)
  46. marshaledSchema, err := proto.Marshal(schema)
  47. s.NoError(err)
  48. createCollectionStatus, err := c.Proxy.CreateCollection(ctx, &milvuspb.CreateCollectionRequest{
  49. DbName: dbName,
  50. CollectionName: collectionName,
  51. Schema: marshaledSchema,
  52. ShardsNum: common.DefaultShardsNum,
  53. })
  54. s.NoError(err)
  55. err = merr.Error(createCollectionStatus)
  56. if err != nil {
  57. log.Warn("createCollectionStatus fail reason", zap.Error(err))
  58. }
  59. log.Info("CreateCollection result", zap.Any("createCollectionStatus", createCollectionStatus))
  60. showCollectionsResp, err := c.Proxy.ShowCollections(ctx, &milvuspb.ShowCollectionsRequest{})
  61. s.NoError(err)
  62. s.True(merr.Ok(showCollectionsResp.GetStatus()))
  63. log.Info("ShowCollections result", zap.Any("showCollectionsResp", showCollectionsResp))
  64. fVecColumn := integration.NewFloatVectorFieldData(integration.FloatVecField, rowNum, dim)
  65. hashKeys := integration.GenerateHashKeys(rowNum)
  66. insertResult, err := c.Proxy.Insert(ctx, &milvuspb.InsertRequest{
  67. DbName: dbName,
  68. CollectionName: collectionName,
  69. FieldsData: []*schemapb.FieldData{fVecColumn},
  70. HashKeys: hashKeys,
  71. NumRows: uint32(rowNum),
  72. })
  73. s.NoError(err)
  74. s.True(merr.Ok(insertResult.GetStatus()))
  75. // flush
  76. flushResp, err := c.Proxy.Flush(ctx, &milvuspb.FlushRequest{
  77. DbName: dbName,
  78. CollectionNames: []string{collectionName},
  79. })
  80. s.NoError(err)
  81. segmentIDs, has := flushResp.GetCollSegIDs()[collectionName]
  82. ids := segmentIDs.GetData()
  83. s.Require().NotEmpty(segmentIDs)
  84. s.Require().True(has)
  85. flushTs, has := flushResp.GetCollFlushTs()[collectionName]
  86. s.True(has)
  87. segments, err := c.MetaWatcher.ShowSegments()
  88. s.NoError(err)
  89. s.NotEmpty(segments)
  90. for _, segment := range segments {
  91. log.Info("ShowSegments result", zap.String("segment", segment.String()))
  92. }
  93. s.WaitForFlush(ctx, ids, flushTs, dbName, collectionName)
  94. // create index
  95. createIndexStatus, err := c.Proxy.CreateIndex(ctx, &milvuspb.CreateIndexRequest{
  96. CollectionName: collectionName,
  97. FieldName: integration.FloatVecField,
  98. IndexName: "_default",
  99. ExtraParams: integration.ConstructIndexParam(dim, integration.IndexFaissIvfFlat, metric.IP),
  100. })
  101. s.NoError(err)
  102. err = merr.Error(createIndexStatus)
  103. if err != nil {
  104. log.Warn("createIndexStatus fail reason", zap.Error(err))
  105. }
  106. s.WaitForIndexBuilt(ctx, collectionName, integration.FloatVecField)
  107. // load
  108. loadStatus, err := c.Proxy.LoadCollection(ctx, &milvuspb.LoadCollectionRequest{
  109. DbName: dbName,
  110. CollectionName: collectionName,
  111. })
  112. s.NoError(err)
  113. err = merr.Error(loadStatus)
  114. if err != nil {
  115. log.Warn("LoadCollection fail reason", zap.Error(err))
  116. }
  117. s.WaitForLoad(ctx, collectionName)
  118. // search
  119. expr := fmt.Sprintf("%s > 0", integration.Int64Field)
  120. nq := 10
  121. topk := 10
  122. roundDecimal := -1
  123. radius := 10
  124. filter := 20
  125. params := integration.GetSearchParams(integration.IndexFaissIvfFlat, metric.IP)
  126. // only pass in radius when range search
  127. params["radius"] = radius
  128. searchReq := integration.ConstructSearchRequest("", collectionName, expr,
  129. integration.FloatVecField, schemapb.DataType_FloatVector, nil, metric.IP, params, nq, dim, topk, roundDecimal)
  130. searchResult, _ := c.Proxy.Search(ctx, searchReq)
  131. err = merr.Error(searchResult.GetStatus())
  132. if err != nil {
  133. log.Warn("searchResult fail reason", zap.Error(err))
  134. }
  135. s.NoError(err)
  136. // pass in radius and range_filter when range search
  137. params["range_filter"] = filter
  138. searchReq = integration.ConstructSearchRequest("", collectionName, expr,
  139. integration.FloatVecField, schemapb.DataType_FloatVector, nil, metric.IP, params, nq, dim, topk, roundDecimal)
  140. searchResult, _ = c.Proxy.Search(ctx, searchReq)
  141. err = merr.Error(searchResult.GetStatus())
  142. if err != nil {
  143. log.Warn("searchResult fail reason", zap.Error(err))
  144. }
  145. s.NoError(err)
  146. // pass in illegal radius and range_filter when range search
  147. params["radius"] = filter
  148. params["range_filter"] = radius
  149. searchReq = integration.ConstructSearchRequest("", collectionName, expr,
  150. integration.FloatVecField, schemapb.DataType_FloatVector, nil, metric.IP, params, nq, dim, topk, roundDecimal)
  151. searchResult, _ = c.Proxy.Search(ctx, searchReq)
  152. err = merr.Error(searchResult.GetStatus())
  153. if err != nil {
  154. log.Warn("searchResult fail reason", zap.Error(err))
  155. }
  156. s.Error(err)
  157. log.Info("=========================")
  158. log.Info("=========================")
  159. log.Info("TestRangeSearchIP succeed")
  160. log.Info("=========================")
  161. log.Info("=========================")
  162. }
  163. func (s *RangeSearchSuite) TestRangeSearchL2() {
  164. c := s.Cluster
  165. ctx, cancel := context.WithCancel(c.GetContext())
  166. defer cancel()
  167. prefix := "TestRangeSearchL2"
  168. dbName := ""
  169. collectionName := prefix + funcutil.GenRandomStr()
  170. dim := 128
  171. rowNum := 3000
  172. schema := integration.ConstructSchema(collectionName, dim, true)
  173. marshaledSchema, err := proto.Marshal(schema)
  174. s.NoError(err)
  175. createCollectionStatus, err := c.Proxy.CreateCollection(ctx, &milvuspb.CreateCollectionRequest{
  176. DbName: dbName,
  177. CollectionName: collectionName,
  178. Schema: marshaledSchema,
  179. ShardsNum: common.DefaultShardsNum,
  180. })
  181. s.NoError(err)
  182. err = merr.Error(createCollectionStatus)
  183. if err != nil {
  184. log.Warn("createCollectionStatus fail reason", zap.Error(err))
  185. }
  186. log.Info("CreateCollection result", zap.Any("createCollectionStatus", createCollectionStatus))
  187. showCollectionsResp, err := c.Proxy.ShowCollections(ctx, &milvuspb.ShowCollectionsRequest{})
  188. s.NoError(err)
  189. s.True(merr.Ok(showCollectionsResp.GetStatus()))
  190. log.Info("ShowCollections result", zap.Any("showCollectionsResp", showCollectionsResp))
  191. fVecColumn := integration.NewFloatVectorFieldData(integration.FloatVecField, rowNum, dim)
  192. hashKeys := integration.GenerateHashKeys(rowNum)
  193. insertResult, err := c.Proxy.Insert(ctx, &milvuspb.InsertRequest{
  194. DbName: dbName,
  195. CollectionName: collectionName,
  196. FieldsData: []*schemapb.FieldData{fVecColumn},
  197. HashKeys: hashKeys,
  198. NumRows: uint32(rowNum),
  199. })
  200. s.NoError(err)
  201. s.True(merr.Ok(insertResult.GetStatus()))
  202. // flush
  203. flushResp, err := c.Proxy.Flush(ctx, &milvuspb.FlushRequest{
  204. DbName: dbName,
  205. CollectionNames: []string{collectionName},
  206. })
  207. s.NoError(err)
  208. segmentIDs, has := flushResp.GetCollSegIDs()[collectionName]
  209. ids := segmentIDs.GetData()
  210. s.Require().NotEmpty(segmentIDs)
  211. s.Require().True(has)
  212. flushTs, has := flushResp.GetCollFlushTs()[collectionName]
  213. s.True(has)
  214. segments, err := c.MetaWatcher.ShowSegments()
  215. s.NoError(err)
  216. s.NotEmpty(segments)
  217. for _, segment := range segments {
  218. log.Info("ShowSegments result", zap.String("segment", segment.String()))
  219. }
  220. s.WaitForFlush(ctx, ids, flushTs, dbName, collectionName)
  221. // create index
  222. createIndexStatus, err := c.Proxy.CreateIndex(ctx, &milvuspb.CreateIndexRequest{
  223. CollectionName: collectionName,
  224. FieldName: integration.FloatVecField,
  225. IndexName: "_default",
  226. ExtraParams: integration.ConstructIndexParam(dim, integration.IndexFaissIvfFlat, metric.L2),
  227. })
  228. s.NoError(err)
  229. err = merr.Error(createIndexStatus)
  230. if err != nil {
  231. log.Warn("createIndexStatus fail reason", zap.Error(err))
  232. }
  233. s.WaitForIndexBuilt(ctx, collectionName, integration.FloatVecField)
  234. // load
  235. loadStatus, err := c.Proxy.LoadCollection(ctx, &milvuspb.LoadCollectionRequest{
  236. DbName: dbName,
  237. CollectionName: collectionName,
  238. })
  239. s.NoError(err)
  240. err = merr.Error(loadStatus)
  241. if err != nil {
  242. log.Warn("LoadCollection fail reason", zap.Error(err))
  243. }
  244. s.WaitForLoad(ctx, collectionName)
  245. // search
  246. expr := fmt.Sprintf("%s > 0", integration.Int64Field)
  247. nq := 10
  248. topk := 10
  249. roundDecimal := -1
  250. radius := 20
  251. filter := 10
  252. params := integration.GetSearchParams(integration.IndexFaissIvfFlat, metric.L2)
  253. // only pass in radius when range search
  254. params["radius"] = radius
  255. searchReq := integration.ConstructSearchRequest("", collectionName, expr,
  256. integration.FloatVecField, schemapb.DataType_FloatVector, nil, metric.L2, params, nq, dim, topk, roundDecimal)
  257. searchResult, _ := c.Proxy.Search(ctx, searchReq)
  258. err = merr.Error(searchResult.GetStatus())
  259. if err != nil {
  260. log.Warn("searchResult fail reason", zap.Error(err))
  261. }
  262. s.NoError(err)
  263. // pass in radius and range_filter when range search
  264. params["range_filter"] = filter
  265. searchReq = integration.ConstructSearchRequest("", collectionName, expr,
  266. integration.FloatVecField, schemapb.DataType_FloatVector, nil, metric.L2, params, nq, dim, topk, roundDecimal)
  267. searchResult, _ = c.Proxy.Search(ctx, searchReq)
  268. err = merr.Error(searchResult.GetStatus())
  269. if err != nil {
  270. log.Warn("searchResult fail reason", zap.Error(err))
  271. }
  272. s.NoError(err)
  273. // pass in illegal radius and range_filter when range search
  274. params["radius"] = filter
  275. params["range_filter"] = radius
  276. searchReq = integration.ConstructSearchRequest("", collectionName, expr,
  277. integration.FloatVecField, schemapb.DataType_FloatVector, nil, metric.L2, params, nq, dim, topk, roundDecimal)
  278. searchResult, _ = c.Proxy.Search(ctx, searchReq)
  279. err = merr.Error(searchResult.GetStatus())
  280. if err != nil {
  281. log.Warn("searchResult fail reason", zap.Error(err))
  282. }
  283. s.Error(err)
  284. log.Info("=========================")
  285. log.Info("=========================")
  286. log.Info("TestRangeSearchL2 succeed")
  287. log.Info("=========================")
  288. log.Info("=========================")
  289. }
  290. func TestRangeSearch(t *testing.T) {
  291. suite.Run(t, new(RangeSearchSuite))
  292. }