# data_benchmark.py -- Ray Datasets bulk-ingest benchmark.
import argparse
import json
import os
import time

import ray
from ray.air.config import DatasetConfig, ScalingConfig
from ray.air.util.check_ingest import DummyTrainer
from ray.data.preprocessors import BatchMapper
  9. GiB = 1024 * 1024 * 1024
  10. def make_ds(size_gb: int):
  11. # Dataset of 10KiB tensor records.
  12. total_size = GiB * size_gb
  13. record_dim = 1280
  14. record_size = record_dim * 8
  15. num_records = int(total_size / record_size)
  16. dataset = ray.data.range_tensor(num_records, shape=(record_dim,))
  17. print("Created dataset", dataset, "of size", dataset.size_bytes())
  18. return dataset
  19. def run_ingest_bulk(dataset, num_workers, num_cpus_per_worker):
  20. dummy_prep = BatchMapper(lambda df: df * 2, batch_format="pandas")
  21. trainer = DummyTrainer(
  22. scaling_config=ScalingConfig(
  23. num_workers=num_workers,
  24. trainer_resources={"CPU": 0},
  25. resources_per_worker={"CPU": num_cpus_per_worker},
  26. _max_cpu_fraction_per_node=0.1,
  27. ),
  28. datasets={"train": dataset},
  29. preprocessor=dummy_prep,
  30. num_epochs=1,
  31. prefetch_batches=1,
  32. dataset_config={"train": DatasetConfig(split=True)},
  33. )
  34. trainer.fit()
  35. if __name__ == "__main__":
  36. parser = argparse.ArgumentParser()
  37. parser.add_argument("--num-workers", type=int, default=4)
  38. parser.add_argument(
  39. "--num-cpus-per-worker",
  40. type=int,
  41. default=1,
  42. help="Number of CPUs for each training worker.",
  43. )
  44. parser.add_argument("--dataset-size-gb", type=int, default=200)
  45. args = parser.parse_args()
  46. ds = make_ds(args.dataset_size_gb)
  47. start = time.time()
  48. run_ingest_bulk(ds, args.num_workers, args.num_cpus_per_worker)
  49. end = time.time()
  50. time_taken = end - start
  51. result = {"time_taken_s": time_taken}
  52. print("Results:", result)
  53. test_output_json = os.environ.get("TEST_OUTPUT_JSON", "/tmp/result.json")
  54. with open(test_output_json, "wt") as f:
  55. json.dump(result, f)