utils.py 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157
  1. import random
  2. import time
  3. import random
  4. import string
  5. from faker import Faker
  6. import numpy as np
  7. from sklearn import preprocessing
  8. import requests
  9. from loguru import logger
  10. import datetime
# Module-level Faker instance; get_random_json_data draws names, addresses,
# text, emails and phone numbers from it.
fake = Faker()
  12. def random_string(length=8):
  13. letters = string.ascii_letters
  14. return ''.join(random.choice(letters) for _ in range(length))
  15. def gen_collection_name(prefix="test_collection", length=8):
  16. name = f'{prefix}_' + datetime.datetime.now().strftime("%Y_%m_%d_%H_%M_%S_%f") + random_string(length=length)
  17. return name
  18. def admin_password():
  19. return "Milvus"
  20. def invalid_cluster_name():
  21. res = [
  22. "demo" * 100,
  23. "demo" + "!",
  24. "demo" + "@",
  25. ]
  26. return res
  27. def wait_cluster_be_ready(cluster_id, client, timeout=120):
  28. t0 = time.time()
  29. while True and time.time() - t0 < timeout:
  30. rsp = client.cluster_describe(cluster_id)
  31. if rsp['code'] == 200:
  32. if rsp['data']['status'] == "RUNNING":
  33. return time.time() - t0
  34. time.sleep(1)
  35. logger.debug("wait cluster to be ready, cost time: %s" % (time.time() - t0))
  36. return -1
  37. def gen_data_by_type(field):
  38. data_type = field["type"]
  39. if data_type == "bool":
  40. return random.choice([True, False])
  41. if data_type == "int8":
  42. return random.randint(-128, 127)
  43. if data_type == "int16":
  44. return random.randint(-32768, 32767)
  45. if data_type == "int32":
  46. return random.randint(-2147483648, 2147483647)
  47. if data_type == "int64":
  48. return random.randint(-9223372036854775808, 9223372036854775807)
  49. if data_type == "float32":
  50. return np.float64(random.random()) # Object of type float32 is not JSON serializable, so set it as float64
  51. if data_type == "float64":
  52. return np.float64(random.random())
  53. if "varchar" in data_type:
  54. length = int(data_type.split("(")[1].split(")")[0])
  55. return "".join([chr(random.randint(97, 122)) for _ in range(length)])
  56. if "floatVector" in data_type:
  57. dim = int(data_type.split("(")[1].split(")")[0])
  58. return preprocessing.normalize([np.array([random.random() for i in range(dim)])])[0].tolist()
  59. return None
  60. def get_data_by_fields(fields, nb):
  61. # logger.info(f"fields: {fields}")
  62. fields_not_auto_id = []
  63. for field in fields:
  64. if not field.get("autoId", False):
  65. fields_not_auto_id.append(field)
  66. # logger.info(f"fields_not_auto_id: {fields_not_auto_id}")
  67. data = []
  68. for i in range(nb):
  69. tmp = {}
  70. for field in fields_not_auto_id:
  71. tmp[field["name"]] = gen_data_by_type(field)
  72. data.append(tmp)
  73. return data
  74. def get_random_json_data(uid=None):
  75. # gen random dict data
  76. if uid is None:
  77. uid = 0
  78. data = {"uid": uid, "name": fake.name(), "address": fake.address(), "text": fake.text(), "email": fake.email(),
  79. "phone_number": fake.phone_number(),
  80. "array_int_dynamic": [random.randint(1, 100_000) for i in range(random.randint(1, 10))],
  81. "array_varchar_dynamic": [fake.name() for i in range(random.randint(1, 10))],
  82. "json": {
  83. "name": fake.name(),
  84. "address": fake.address()
  85. }
  86. }
  87. for i in range(random.randint(1, 10)):
  88. data["key" + str(random.randint(1, 100_000))] = "value" + str(random.randint(1, 100_000))
  89. return data
  90. def get_data_by_payload(payload, nb=100):
  91. dim = payload.get("dimension", 128)
  92. vector_field = payload.get("vectorField", "vector")
  93. data = []
  94. if nb == 1:
  95. data = [{
  96. vector_field: preprocessing.normalize([np.array([random.random() for i in range(dim)])])[0].tolist(),
  97. **get_random_json_data()
  98. }]
  99. else:
  100. for i in range(nb):
  101. data.append({
  102. vector_field: preprocessing.normalize([np.array([random.random() for i in range(dim)])])[0].tolist(),
  103. **get_random_json_data(uid=i)
  104. })
  105. return data
  106. def get_common_fields_by_data(data, exclude_fields=None):
  107. fields = set()
  108. if isinstance(data, dict):
  109. data = [data]
  110. if not isinstance(data, list):
  111. raise Exception("data must be list or dict")
  112. common_fields = set(data[0].keys())
  113. for d in data:
  114. keys = set(d.keys())
  115. common_fields = common_fields.intersection(keys)
  116. if exclude_fields is not None:
  117. exclude_fields = set(exclude_fields)
  118. common_fields = common_fields.difference(exclude_fields)
  119. return list(common_fields)
  120. def get_all_fields_by_data(data, exclude_fields=None):
  121. fields = set()
  122. for d in data:
  123. keys = list(d.keys())
  124. fields.union(keys)
  125. if exclude_fields is not None:
  126. exclude_fields = set(exclude_fields)
  127. fields = fields.difference(exclude_fields)
  128. return list(fields)