# pg_run.py
import json
import os
import time

import ray
from ray.util.placement_group import placement_group
  6. # Tests are supposed to run for 10 minutes.
  7. RUNTIME = 600
  8. NUM_CPU_BUNDLES = 30
  9. @ray.remote(num_cpus=1)
  10. class Worker(object):
  11. def __init__(self, i):
  12. self.i = i
  13. def work(self):
  14. time.sleep(0.1)
  15. print("work ", self.i)
  16. @ray.remote(num_cpus=1, num_gpus=1)
  17. class Trainer(object):
  18. def __init__(self, i):
  19. self.i = i
  20. def train(self):
  21. time.sleep(0.2)
  22. print("train ", self.i)
  23. def main():
  24. ray.init(address="auto")
  25. bundles = [{"CPU": 1, "GPU": 1}]
  26. bundles += [{"CPU": 1} for _ in range(NUM_CPU_BUNDLES)]
  27. pg = placement_group(bundles, strategy="PACK")
  28. ray.get(pg.ready())
  29. workers = [
  30. Worker.options(placement_group=pg).remote(i) for i in range(NUM_CPU_BUNDLES)
  31. ]
  32. trainer = Trainer.options(placement_group=pg).remote(0)
  33. start = time.time()
  34. while True:
  35. ray.get([workers[i].work.remote() for i in range(NUM_CPU_BUNDLES)])
  36. ray.get(trainer.train.remote())
  37. end = time.time()
  38. if end - start > RUNTIME:
  39. break
  40. if "TEST_OUTPUT_JSON" in os.environ:
  41. out_file = open(os.environ["TEST_OUTPUT_JSON"], "w")
  42. results = {}
  43. json.dump(results, out_file)
  44. if __name__ == "__main__":
  45. main()