write_options.go 11 KB


  1. // Licensed to the LF AI & Data foundation under one
  2. // or more contributor license agreements. See the NOTICE file
  3. // distributed with this work for additional information
  4. // regarding copyright ownership. The ASF licenses this file
  5. // to you under the Apache License, Version 2.0 (the
  6. // "License"); you may not use this file except in compliance
  7. // with the License. You may obtain a copy of the License at
  8. //
  9. // http://www.apache.org/licenses/LICENSE-2.0
  10. //
  11. // Unless required by applicable law or agreed to in writing, software
  12. // distributed under the License is distributed on an "AS IS" BASIS,
  13. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. // See the License for the specific language governing permissions and
  15. // limitations under the License.
  16. package client
  17. import (
  18. "encoding/json"
  19. "fmt"
  20. "strings"
  21. "github.com/cockroachdb/errors"
  22. "github.com/samber/lo"
  23. "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
  24. "github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
  25. "github.com/milvus-io/milvus/client/v2/column"
  26. "github.com/milvus-io/milvus/client/v2/entity"
  27. "github.com/milvus-io/milvus/client/v2/row"
  28. )
  29. type InsertOption interface {
  30. InsertRequest(coll *entity.Collection) (*milvuspb.InsertRequest, error)
  31. CollectionName() string
  32. WriteBackPKs(schema *entity.Schema, pks column.Column) error
  33. }
  34. type UpsertOption interface {
  35. UpsertRequest(coll *entity.Collection) (*milvuspb.UpsertRequest, error)
  36. CollectionName() string
  37. }
  38. var (
  39. _ UpsertOption = (*columnBasedDataOption)(nil)
  40. _ InsertOption = (*columnBasedDataOption)(nil)
  41. )
  42. type columnBasedDataOption struct {
  43. collName string
  44. partitionName string
  45. columns []column.Column
  46. }
  47. func (opt *columnBasedDataOption) WriteBackPKs(_ *entity.Schema, _ column.Column) error {
  48. // column based data option need not write back pk
  49. return nil
  50. }
  51. func (opt *columnBasedDataOption) processInsertColumns(colSchema *entity.Schema, columns ...column.Column) ([]*schemapb.FieldData, int, error) {
  52. // setup dynamic related var
  53. isDynamic := colSchema.EnableDynamicField
  54. // check columns and field matches
  55. var rowSize int
  56. mNameField := make(map[string]*entity.Field)
  57. for _, field := range colSchema.Fields {
  58. mNameField[field.Name] = field
  59. }
  60. mNameColumn := make(map[string]column.Column)
  61. var dynamicColumns []column.Column
  62. for _, col := range columns {
  63. _, dup := mNameColumn[col.Name()]
  64. if dup {
  65. return nil, 0, fmt.Errorf("duplicated column %s found", col.Name())
  66. }
  67. l := col.Len()
  68. if rowSize == 0 {
  69. rowSize = l
  70. } else if rowSize != l {
  71. return nil, 0, errors.New("column size not match")
  72. }
  73. field, has := mNameField[col.Name()]
  74. if !has {
  75. if !isDynamic {
  76. return nil, 0, fmt.Errorf("field %s does not exist in collection %s", col.Name(), colSchema.CollectionName)
  77. }
  78. // add to dynamic column list for further processing
  79. dynamicColumns = append(dynamicColumns, col)
  80. continue
  81. }
  82. mNameColumn[col.Name()] = col
  83. if col.Type() != field.DataType {
  84. return nil, 0, fmt.Errorf("param column %s has type %v but collection field definition is %v", col.Name(), col.Type(), field.DataType)
  85. }
  86. if field.DataType == entity.FieldTypeFloatVector || field.DataType == entity.FieldTypeBinaryVector {
  87. dim := 0
  88. switch column := col.(type) {
  89. case *column.ColumnFloatVector:
  90. dim = column.Dim()
  91. case *column.ColumnBinaryVector:
  92. dim = column.Dim()
  93. }
  94. if fmt.Sprintf("%d", dim) != field.TypeParams[entity.TypeParamDim] {
  95. return nil, 0, fmt.Errorf("params column %s vector dim %d not match collection definition, which has dim of %s", field.Name, dim, field.TypeParams[entity.TypeParamDim])
  96. }
  97. }
  98. }
  99. // check all fixed field pass value
  100. for _, field := range colSchema.Fields {
  101. _, has := mNameColumn[field.Name]
  102. if !has &&
  103. !field.AutoID && !field.IsDynamic {
  104. return nil, 0, fmt.Errorf("field %s not passed", field.Name)
  105. }
  106. }
  107. fieldsData := make([]*schemapb.FieldData, 0, len(mNameColumn)+1)
  108. for _, fixedColumn := range mNameColumn {
  109. fieldsData = append(fieldsData, fixedColumn.FieldData())
  110. }
  111. if len(dynamicColumns) > 0 {
  112. // use empty column name here
  113. col, err := opt.mergeDynamicColumns("", rowSize, dynamicColumns)
  114. if err != nil {
  115. return nil, 0, err
  116. }
  117. fieldsData = append(fieldsData, col)
  118. }
  119. return fieldsData, rowSize, nil
  120. }
  121. func (opt *columnBasedDataOption) mergeDynamicColumns(dynamicName string, rowSize int, columns []column.Column) (*schemapb.FieldData, error) {
  122. values := make([][]byte, 0, rowSize)
  123. for i := 0; i < rowSize; i++ {
  124. m := make(map[string]interface{})
  125. for _, column := range columns {
  126. // range guaranteed
  127. m[column.Name()], _ = column.Get(i)
  128. }
  129. bs, err := json.Marshal(m)
  130. if err != nil {
  131. return nil, err
  132. }
  133. values = append(values, bs)
  134. }
  135. return &schemapb.FieldData{
  136. Type: schemapb.DataType_JSON,
  137. FieldName: dynamicName,
  138. Field: &schemapb.FieldData_Scalars{
  139. Scalars: &schemapb.ScalarField{
  140. Data: &schemapb.ScalarField_JsonData{
  141. JsonData: &schemapb.JSONArray{
  142. Data: values,
  143. },
  144. },
  145. },
  146. },
  147. IsDynamic: true,
  148. }, nil
  149. }
  150. func (opt *columnBasedDataOption) WithColumns(columns ...column.Column) *columnBasedDataOption {
  151. opt.columns = append(opt.columns, columns...)
  152. return opt
  153. }
  154. func (opt *columnBasedDataOption) WithBoolColumn(colName string, data []bool) *columnBasedDataOption {
  155. column := column.NewColumnBool(colName, data)
  156. return opt.WithColumns(column)
  157. }
  158. func (opt *columnBasedDataOption) WithInt8Column(colName string, data []int8) *columnBasedDataOption {
  159. column := column.NewColumnInt8(colName, data)
  160. return opt.WithColumns(column)
  161. }
  162. func (opt *columnBasedDataOption) WithInt16Column(colName string, data []int16) *columnBasedDataOption {
  163. column := column.NewColumnInt16(colName, data)
  164. return opt.WithColumns(column)
  165. }
  166. func (opt *columnBasedDataOption) WithInt32Column(colName string, data []int32) *columnBasedDataOption {
  167. column := column.NewColumnInt32(colName, data)
  168. return opt.WithColumns(column)
  169. }
  170. func (opt *columnBasedDataOption) WithInt64Column(colName string, data []int64) *columnBasedDataOption {
  171. column := column.NewColumnInt64(colName, data)
  172. return opt.WithColumns(column)
  173. }
  174. func (opt *columnBasedDataOption) WithVarcharColumn(colName string, data []string) *columnBasedDataOption {
  175. column := column.NewColumnVarChar(colName, data)
  176. return opt.WithColumns(column)
  177. }
  178. func (opt *columnBasedDataOption) WithFloatVectorColumn(colName string, dim int, data [][]float32) *columnBasedDataOption {
  179. column := column.NewColumnFloatVector(colName, dim, data)
  180. return opt.WithColumns(column)
  181. }
  182. func (opt *columnBasedDataOption) WithBinaryVectorColumn(colName string, dim int, data [][]byte) *columnBasedDataOption {
  183. column := column.NewColumnBinaryVector(colName, dim, data)
  184. return opt.WithColumns(column)
  185. }
  186. func (opt *columnBasedDataOption) WithPartition(partitionName string) *columnBasedDataOption {
  187. opt.partitionName = partitionName
  188. return opt
  189. }
  190. func (opt *columnBasedDataOption) CollectionName() string {
  191. return opt.collName
  192. }
  193. func (opt *columnBasedDataOption) InsertRequest(coll *entity.Collection) (*milvuspb.InsertRequest, error) {
  194. fieldsData, rowNum, err := opt.processInsertColumns(coll.Schema, opt.columns...)
  195. if err != nil {
  196. return nil, err
  197. }
  198. return &milvuspb.InsertRequest{
  199. CollectionName: opt.collName,
  200. PartitionName: opt.partitionName,
  201. FieldsData: fieldsData,
  202. NumRows: uint32(rowNum),
  203. }, nil
  204. }
  205. func (opt *columnBasedDataOption) UpsertRequest(coll *entity.Collection) (*milvuspb.UpsertRequest, error) {
  206. fieldsData, rowNum, err := opt.processInsertColumns(coll.Schema, opt.columns...)
  207. if err != nil {
  208. return nil, err
  209. }
  210. return &milvuspb.UpsertRequest{
  211. CollectionName: opt.collName,
  212. PartitionName: opt.partitionName,
  213. FieldsData: fieldsData,
  214. NumRows: uint32(rowNum),
  215. }, nil
  216. }
  217. func NewColumnBasedInsertOption(collName string, columns ...column.Column) *columnBasedDataOption {
  218. return &columnBasedDataOption{
  219. columns: columns,
  220. collName: collName,
  221. // leave partition name empty, using default partition
  222. }
  223. }
  224. type rowBasedDataOption struct {
  225. *columnBasedDataOption
  226. rows []any
  227. }
  228. func NewRowBasedInsertOption(collName string, rows ...any) *rowBasedDataOption {
  229. return &rowBasedDataOption{
  230. columnBasedDataOption: &columnBasedDataOption{
  231. collName: collName,
  232. },
  233. rows: rows,
  234. }
  235. }
  236. func (opt *rowBasedDataOption) InsertRequest(coll *entity.Collection) (*milvuspb.InsertRequest, error) {
  237. columns, err := row.AnyToColumns(opt.rows, coll.Schema)
  238. if err != nil {
  239. return nil, err
  240. }
  241. opt.columnBasedDataOption.columns = columns
  242. fieldsData, rowNum, err := opt.processInsertColumns(coll.Schema, opt.columns...)
  243. if err != nil {
  244. return nil, err
  245. }
  246. return &milvuspb.InsertRequest{
  247. CollectionName: opt.collName,
  248. PartitionName: opt.partitionName,
  249. FieldsData: fieldsData,
  250. NumRows: uint32(rowNum),
  251. }, nil
  252. }
  253. func (opt *rowBasedDataOption) UpsertRequest(coll *entity.Collection) (*milvuspb.UpsertRequest, error) {
  254. columns, err := row.AnyToColumns(opt.rows, coll.Schema)
  255. if err != nil {
  256. return nil, err
  257. }
  258. opt.columnBasedDataOption.columns = columns
  259. fieldsData, rowNum, err := opt.processInsertColumns(coll.Schema, opt.columns...)
  260. if err != nil {
  261. return nil, err
  262. }
  263. return &milvuspb.UpsertRequest{
  264. CollectionName: opt.collName,
  265. PartitionName: opt.partitionName,
  266. FieldsData: fieldsData,
  267. NumRows: uint32(rowNum),
  268. }, nil
  269. }
  270. func (opt *rowBasedDataOption) WriteBackPKs(sch *entity.Schema, pks column.Column) error {
  271. pkField := sch.PKField()
  272. // not auto id, return
  273. if pkField == nil || !pkField.AutoID {
  274. return nil
  275. }
  276. if len(opt.rows) != pks.Len() {
  277. return errors.New("input row count is not equal to result pk length")
  278. }
  279. for i, r := range opt.rows {
  280. // index range checked
  281. v, _ := pks.Get(i)
  282. err := row.SetField(r, pkField.Name, v)
  283. if err != nil {
  284. return err
  285. }
  286. }
  287. return nil
  288. }
  289. type DeleteOption interface {
  290. Request() *milvuspb.DeleteRequest
  291. }
  // deleteOption collects the parameters of a delete request.
  type deleteOption struct {
  	// collectionName and partitionName select the target; expr is the
  	// expression string copied into the request.
  	collectionName, partitionName, expr string
  }
  297. func (opt *deleteOption) Request() *milvuspb.DeleteRequest {
  298. return &milvuspb.DeleteRequest{
  299. CollectionName: opt.collectionName,
  300. PartitionName: opt.partitionName,
  301. Expr: opt.expr,
  302. }
  303. }
  304. func (opt *deleteOption) WithExpr(expr string) *deleteOption {
  305. opt.expr = expr
  306. return opt
  307. }
  308. func (opt *deleteOption) WithInt64IDs(fieldName string, ids []int64) *deleteOption {
  309. opt.expr = fmt.Sprintf("%s in %s", fieldName, strings.Join(strings.Fields(fmt.Sprint(ids)), ","))
  310. return opt
  311. }
  312. func (opt *deleteOption) WithStringIDs(fieldName string, ids []string) *deleteOption {
  313. opt.expr = fmt.Sprintf("%s in [%s]", fieldName, strings.Join(lo.Map(ids, func(id string, _ int) string { return fmt.Sprintf("\"%s\"", id) }), ","))
  314. return opt
  315. }
  316. func (opt *deleteOption) WithPartition(partitionName string) *deleteOption {
  317. opt.partitionName = partitionName
  318. return opt
  319. }
  320. func NewDeleteOption(collectionName string) *deleteOption {
  321. return &deleteOption{collectionName: collectionName}
  322. }