create_test_data.py 1.3 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950
  1. import argparse
  2. import numpy as np
  3. import os
  4. from xgboost_ray.tests.utils import create_parquet
  5. if __name__ == "__main__":
  6. if "OMP_NUM_THREADS" in os.environ:
  7. del os.environ["OMP_NUM_THREADS"]
  8. parser = argparse.ArgumentParser(description="Create fake data.")
  9. parser.add_argument(
  10. "filename", type=str, default="/data/parted.parquet/", help="ray/dask"
  11. )
  12. parser.add_argument(
  13. "-r", "--num-rows", required=False, type=int, default=1e8, help="num rows"
  14. )
  15. parser.add_argument(
  16. "-p",
  17. "--num-partitions",
  18. required=False,
  19. type=int,
  20. default=100,
  21. help="num partitions",
  22. )
  23. parser.add_argument(
  24. "-c",
  25. "--num-cols",
  26. required=False,
  27. type=int,
  28. default=4,
  29. help="num columns (features)",
  30. )
  31. parser.add_argument(
  32. "-C", "--num-classes", required=False, type=int, default=2, help="num classes"
  33. )
  34. parser.add_argument(
  35. "-s", "--seed", required=False, type=int, default=1234, help="random seed"
  36. )
  37. args = parser.parse_args()
  38. np.random.seed(args.seed)
  39. create_parquet(
  40. args.filename,
  41. num_rows=int(args.num_rows),
  42. num_partitions=int(args.num_partitions),
  43. num_features=int(args.num_cols),
  44. num_classes=int(args.num_classes),
  45. )