write_test.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370
  1. // Licensed to the LF AI & Data foundation under one
  2. // or more contributor license agreements. See the NOTICE file
  3. // distributed with this work for additional information
  4. // regarding copyright ownership. The ASF licenses this file
  5. // to you under the Apache License, Version 2.0 (the
  6. // "License"); you may not use this file except in compliance
  7. // with the License. You may obtain a copy of the License at
  8. //
  9. // http://www.apache.org/licenses/LICENSE-2.0
  10. //
  11. // Unless required by applicable law or agreed to in writing, software
  12. // distributed under the License is distributed on an "AS IS" BASIS,
  13. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. // See the License for the specific language governing permissions and
  15. // limitations under the License.
  16. package client
  17. import (
  18. "context"
  19. "fmt"
  20. "math/rand"
  21. "testing"
  22. "github.com/samber/lo"
  23. "github.com/stretchr/testify/mock"
  24. "github.com/stretchr/testify/suite"
  25. "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
  26. "github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
  27. "github.com/milvus-io/milvus/client/v2/entity"
  28. "github.com/milvus-io/milvus/pkg/util/merr"
  29. )
  30. type WriteSuite struct {
  31. MockSuiteBase
  32. schema *entity.Schema
  33. schemaDyn *entity.Schema
  34. }
  35. func (s *WriteSuite) SetupSuite() {
  36. s.MockSuiteBase.SetupSuite()
  37. s.schema = entity.NewSchema().
  38. WithField(entity.NewField().WithName("id").WithDataType(entity.FieldTypeInt64).WithIsPrimaryKey(true)).
  39. WithField(entity.NewField().WithName("vector").WithDataType(entity.FieldTypeFloatVector).WithDim(128))
  40. s.schemaDyn = entity.NewSchema().WithDynamicFieldEnabled(true).
  41. WithField(entity.NewField().WithName("id").WithDataType(entity.FieldTypeInt64).WithIsPrimaryKey(true)).
  42. WithField(entity.NewField().WithName("vector").WithDataType(entity.FieldTypeFloatVector).WithDim(128))
  43. }
  44. func (s *WriteSuite) TestInsert() {
  45. ctx, cancel := context.WithCancel(context.Background())
  46. defer cancel()
  47. s.Run("success", func() {
  48. collName := fmt.Sprintf("coll_%s", s.randString(6))
  49. partName := fmt.Sprintf("part_%s", s.randString(6))
  50. s.setupCache(collName, s.schema)
  51. s.mock.EXPECT().Insert(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, ir *milvuspb.InsertRequest) (*milvuspb.MutationResult, error) {
  52. s.Equal(collName, ir.GetCollectionName())
  53. s.Equal(partName, ir.GetPartitionName())
  54. s.Require().Len(ir.GetFieldsData(), 2)
  55. s.EqualValues(3, ir.GetNumRows())
  56. return &milvuspb.MutationResult{
  57. Status: merr.Success(),
  58. InsertCnt: 3,
  59. IDs: &schemapb.IDs{
  60. IdField: &schemapb.IDs_IntId{
  61. IntId: &schemapb.LongArray{
  62. Data: []int64{1, 2, 3},
  63. },
  64. },
  65. },
  66. }, nil
  67. }).Once()
  68. result, err := s.client.Insert(ctx, NewColumnBasedInsertOption(collName).
  69. WithFloatVectorColumn("vector", 128, lo.RepeatBy(3, func(i int) []float32 {
  70. return lo.RepeatBy(128, func(i int) float32 { return rand.Float32() })
  71. })).
  72. WithInt64Column("id", []int64{1, 2, 3}).WithPartition(partName))
  73. s.NoError(err)
  74. s.EqualValues(3, result.InsertCount)
  75. })
  76. s.Run("dynamic_schema", func() {
  77. collName := fmt.Sprintf("coll_%s", s.randString(6))
  78. partName := fmt.Sprintf("part_%s", s.randString(6))
  79. s.setupCache(collName, s.schemaDyn)
  80. s.mock.EXPECT().Insert(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, ir *milvuspb.InsertRequest) (*milvuspb.MutationResult, error) {
  81. s.Equal(collName, ir.GetCollectionName())
  82. s.Equal(partName, ir.GetPartitionName())
  83. s.Require().Len(ir.GetFieldsData(), 3)
  84. s.EqualValues(3, ir.GetNumRows())
  85. return &milvuspb.MutationResult{
  86. Status: merr.Success(),
  87. InsertCnt: 3,
  88. IDs: &schemapb.IDs{
  89. IdField: &schemapb.IDs_IntId{
  90. IntId: &schemapb.LongArray{
  91. Data: []int64{1, 2, 3},
  92. },
  93. },
  94. },
  95. }, nil
  96. }).Once()
  97. result, err := s.client.Insert(ctx, NewColumnBasedInsertOption(collName).
  98. WithFloatVectorColumn("vector", 128, lo.RepeatBy(3, func(i int) []float32 {
  99. return lo.RepeatBy(128, func(i int) float32 { return rand.Float32() })
  100. })).
  101. WithVarcharColumn("extra", []string{"a", "b", "c"}).
  102. WithInt64Column("id", []int64{1, 2, 3}).WithPartition(partName))
  103. s.NoError(err)
  104. s.EqualValues(3, result.InsertCount)
  105. })
  106. s.Run("bad_input", func() {
  107. collName := fmt.Sprintf("coll_%s", s.randString(6))
  108. s.setupCache(collName, s.schema)
  109. type badCase struct {
  110. tag string
  111. input InsertOption
  112. }
  113. cases := []badCase{
  114. {
  115. tag: "missing_column",
  116. input: NewColumnBasedInsertOption(collName).WithInt64Column("id", []int64{1}),
  117. },
  118. {
  119. tag: "row_count_not_match",
  120. input: NewColumnBasedInsertOption(collName).WithInt64Column("id", []int64{1}).
  121. WithFloatVectorColumn("vector", 128, lo.RepeatBy(3, func(i int) []float32 {
  122. return lo.RepeatBy(128, func(i int) float32 { return rand.Float32() })
  123. })),
  124. },
  125. {
  126. tag: "duplicated_columns",
  127. input: NewColumnBasedInsertOption(collName).
  128. WithInt64Column("id", []int64{1}).
  129. WithInt64Column("id", []int64{2}).
  130. WithFloatVectorColumn("vector", 128, lo.RepeatBy(1, func(i int) []float32 {
  131. return lo.RepeatBy(128, func(i int) float32 { return rand.Float32() })
  132. })),
  133. },
  134. {
  135. tag: "different_data_type",
  136. input: NewColumnBasedInsertOption(collName).
  137. WithVarcharColumn("id", []string{"1"}).
  138. WithFloatVectorColumn("vector", 128, lo.RepeatBy(1, func(i int) []float32 {
  139. return lo.RepeatBy(128, func(i int) float32 { return rand.Float32() })
  140. })),
  141. },
  142. }
  143. for _, tc := range cases {
  144. s.Run(tc.tag, func() {
  145. _, err := s.client.Insert(ctx, tc.input)
  146. s.Error(err)
  147. })
  148. }
  149. })
  150. s.Run("failure", func() {
  151. collName := fmt.Sprintf("coll_%s", s.randString(6))
  152. s.setupCache(collName, s.schema)
  153. s.mock.EXPECT().Insert(mock.Anything, mock.Anything).Return(nil, merr.WrapErrServiceInternal("mocked")).Once()
  154. _, err := s.client.Insert(ctx, NewColumnBasedInsertOption(collName).
  155. WithFloatVectorColumn("vector", 128, lo.RepeatBy(3, func(i int) []float32 {
  156. return lo.RepeatBy(128, func(i int) float32 { return rand.Float32() })
  157. })).
  158. WithInt64Column("id", []int64{1, 2, 3}))
  159. s.Error(err)
  160. })
  161. }
  162. func (s *WriteSuite) TestUpsert() {
  163. ctx, cancel := context.WithCancel(context.Background())
  164. defer cancel()
  165. s.Run("success", func() {
  166. collName := fmt.Sprintf("coll_%s", s.randString(6))
  167. partName := fmt.Sprintf("part_%s", s.randString(6))
  168. s.setupCache(collName, s.schema)
  169. s.mock.EXPECT().Upsert(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, ur *milvuspb.UpsertRequest) (*milvuspb.MutationResult, error) {
  170. s.Equal(collName, ur.GetCollectionName())
  171. s.Equal(partName, ur.GetPartitionName())
  172. s.Require().Len(ur.GetFieldsData(), 2)
  173. s.EqualValues(3, ur.GetNumRows())
  174. return &milvuspb.MutationResult{
  175. Status: merr.Success(),
  176. UpsertCnt: 3,
  177. IDs: &schemapb.IDs{
  178. IdField: &schemapb.IDs_IntId{
  179. IntId: &schemapb.LongArray{
  180. Data: []int64{1, 2, 3},
  181. },
  182. },
  183. },
  184. }, nil
  185. }).Once()
  186. result, err := s.client.Upsert(ctx, NewColumnBasedInsertOption(collName).
  187. WithFloatVectorColumn("vector", 128, lo.RepeatBy(3, func(i int) []float32 {
  188. return lo.RepeatBy(128, func(i int) float32 { return rand.Float32() })
  189. })).
  190. WithInt64Column("id", []int64{1, 2, 3}).WithPartition(partName))
  191. s.NoError(err)
  192. s.EqualValues(3, result.UpsertCount)
  193. })
  194. s.Run("dynamic_schema", func() {
  195. collName := fmt.Sprintf("coll_%s", s.randString(6))
  196. partName := fmt.Sprintf("part_%s", s.randString(6))
  197. s.setupCache(collName, s.schemaDyn)
  198. s.mock.EXPECT().Upsert(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, ur *milvuspb.UpsertRequest) (*milvuspb.MutationResult, error) {
  199. s.Equal(collName, ur.GetCollectionName())
  200. s.Equal(partName, ur.GetPartitionName())
  201. s.Require().Len(ur.GetFieldsData(), 3)
  202. s.EqualValues(3, ur.GetNumRows())
  203. return &milvuspb.MutationResult{
  204. Status: merr.Success(),
  205. UpsertCnt: 3,
  206. IDs: &schemapb.IDs{
  207. IdField: &schemapb.IDs_IntId{
  208. IntId: &schemapb.LongArray{
  209. Data: []int64{1, 2, 3},
  210. },
  211. },
  212. },
  213. }, nil
  214. }).Once()
  215. result, err := s.client.Upsert(ctx, NewColumnBasedInsertOption(collName).
  216. WithFloatVectorColumn("vector", 128, lo.RepeatBy(3, func(i int) []float32 {
  217. return lo.RepeatBy(128, func(i int) float32 { return rand.Float32() })
  218. })).
  219. WithVarcharColumn("extra", []string{"a", "b", "c"}).
  220. WithInt64Column("id", []int64{1, 2, 3}).WithPartition(partName))
  221. s.NoError(err)
  222. s.EqualValues(3, result.UpsertCount)
  223. })
  224. s.Run("bad_input", func() {
  225. collName := fmt.Sprintf("coll_%s", s.randString(6))
  226. s.setupCache(collName, s.schema)
  227. type badCase struct {
  228. tag string
  229. input UpsertOption
  230. }
  231. cases := []badCase{
  232. {
  233. tag: "missing_column",
  234. input: NewColumnBasedInsertOption(collName).WithInt64Column("id", []int64{1}),
  235. },
  236. {
  237. tag: "row_count_not_match",
  238. input: NewColumnBasedInsertOption(collName).WithInt64Column("id", []int64{1}).
  239. WithFloatVectorColumn("vector", 128, lo.RepeatBy(3, func(i int) []float32 {
  240. return lo.RepeatBy(128, func(i int) float32 { return rand.Float32() })
  241. })),
  242. },
  243. {
  244. tag: "duplicated_columns",
  245. input: NewColumnBasedInsertOption(collName).
  246. WithInt64Column("id", []int64{1}).
  247. WithInt64Column("id", []int64{2}).
  248. WithFloatVectorColumn("vector", 128, lo.RepeatBy(1, func(i int) []float32 {
  249. return lo.RepeatBy(128, func(i int) float32 { return rand.Float32() })
  250. })),
  251. },
  252. {
  253. tag: "different_data_type",
  254. input: NewColumnBasedInsertOption(collName).
  255. WithVarcharColumn("id", []string{"1"}).
  256. WithFloatVectorColumn("vector", 128, lo.RepeatBy(1, func(i int) []float32 {
  257. return lo.RepeatBy(128, func(i int) float32 { return rand.Float32() })
  258. })),
  259. },
  260. }
  261. for _, tc := range cases {
  262. s.Run(tc.tag, func() {
  263. _, err := s.client.Upsert(ctx, tc.input)
  264. s.Error(err)
  265. })
  266. }
  267. })
  268. s.Run("failure", func() {
  269. collName := fmt.Sprintf("coll_%s", s.randString(6))
  270. s.setupCache(collName, s.schema)
  271. s.mock.EXPECT().Upsert(mock.Anything, mock.Anything).Return(nil, merr.WrapErrServiceInternal("mocked")).Once()
  272. _, err := s.client.Upsert(ctx, NewColumnBasedInsertOption(collName).
  273. WithFloatVectorColumn("vector", 128, lo.RepeatBy(3, func(i int) []float32 {
  274. return lo.RepeatBy(128, func(i int) float32 { return rand.Float32() })
  275. })).
  276. WithInt64Column("id", []int64{1, 2, 3}))
  277. s.Error(err)
  278. })
  279. }
  280. func (s *WriteSuite) TestDelete() {
  281. ctx, cancel := context.WithCancel(context.Background())
  282. defer cancel()
  283. s.Run("success", func() {
  284. collName := fmt.Sprintf("coll_%s", s.randString(6))
  285. partName := fmt.Sprintf("part_%s", s.randString(6))
  286. type testCase struct {
  287. tag string
  288. input DeleteOption
  289. expectExpr string
  290. }
  291. cases := []testCase{
  292. {
  293. tag: "raw_expr",
  294. input: NewDeleteOption(collName).WithPartition(partName).WithExpr("id > 100"),
  295. expectExpr: "id > 100",
  296. },
  297. {
  298. tag: "int_ids",
  299. input: NewDeleteOption(collName).WithPartition(partName).WithInt64IDs("id", []int64{1, 2, 3}),
  300. expectExpr: "id in [1,2,3]",
  301. },
  302. {
  303. tag: "str_ids",
  304. input: NewDeleteOption(collName).WithPartition(partName).WithStringIDs("id", []string{"a", "b", "c"}),
  305. expectExpr: `id in ["a","b","c"]`,
  306. },
  307. }
  308. for _, tc := range cases {
  309. s.Run(tc.tag, func() {
  310. s.mock.EXPECT().Delete(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, dr *milvuspb.DeleteRequest) (*milvuspb.MutationResult, error) {
  311. s.Equal(collName, dr.GetCollectionName())
  312. s.Equal(partName, dr.GetPartitionName())
  313. s.Equal(tc.expectExpr, dr.GetExpr())
  314. return &milvuspb.MutationResult{
  315. Status: merr.Success(),
  316. DeleteCnt: 100,
  317. }, nil
  318. }).Once()
  319. result, err := s.client.Delete(ctx, tc.input)
  320. s.NoError(err)
  321. s.EqualValues(100, result.DeleteCount)
  322. })
  323. }
  324. })
  325. }
  326. func TestWrite(t *testing.T) {
  327. suite.Run(t, new(WriteSuite))
  328. }