"""Long-running placement-group workload.

Creates one placement group holding a CPU+GPU bundle plus
``NUM_CPU_BUNDLES`` CPU-only bundles, starts actors inside it, and keeps
invoking them until ``RUNTIME`` seconds have elapsed.
"""
import os
import time
import json

import ray
from ray.util.placement_group import placement_group

# Tests are supposed to run for 10 minutes.
RUNTIME = 600
NUM_CPU_BUNDLES = 30


@ray.remote(num_cpus=1)
class Worker(object):
    """Actor requesting one CPU; each ``work`` call sleeps 0.1 s and prints."""

    def __init__(self, i):
        # Index used only to label this actor's output.
        self.i = i

    def work(self):
        """Sleep for 0.1 s, then print this worker's index."""
        time.sleep(0.1)
        print("work ", self.i)


@ray.remote(num_cpus=1, num_gpus=1)
class Trainer(object):
    """Actor requesting one CPU and one GPU; ``train`` sleeps 0.2 s and prints."""

    def __init__(self, i):
        # Index used only to label this actor's output.
        self.i = i

    def train(self):
        """Sleep for 0.2 s, then print this trainer's index."""
        time.sleep(0.2)
        print("train ", self.i)


def main():
    """Run the placement-group workload for at least ``RUNTIME`` seconds.

    Steps:
      1. Call ``ray.init(address="auto")``.
      2. Create a ``PACK`` placement group with one ``{"CPU": 1, "GPU": 1}``
         bundle and ``NUM_CPU_BUNDLES`` ``{"CPU": 1}`` bundles, and block
         until it is ready.
      3. Start ``NUM_CPU_BUNDLES`` ``Worker`` actors and one ``Trainer``
         actor in that placement group.
      4. Repeatedly run one round of work on every worker followed by one
         training step. The elapsed time is checked after each round, so at
         least one full round always executes.
      5. If the ``TEST_OUTPUT_JSON`` environment variable is set, write an
         empty JSON object to the file it names.
    """
    ray.init(address="auto")

    bundles = [{"CPU": 1, "GPU": 1}]
    bundles += [{"CPU": 1} for _ in range(NUM_CPU_BUNDLES)]

    pg = placement_group(bundles, strategy="PACK")
    # Block until every bundle has been reserved.
    ray.get(pg.ready())

    workers = [
        Worker.options(placement_group=pg).remote(i)
        for i in range(NUM_CPU_BUNDLES)
    ]
    trainer = Trainer.options(placement_group=pg).remote(0)

    start = time.time()
    while True:
        # Launch all worker calls first, then wait for the whole batch.
        ray.get([worker.work.remote() for worker in workers])
        ray.get(trainer.train.remote())
        if time.time() - start > RUNTIME:
            break

    if "TEST_OUTPUT_JSON" in os.environ:
        results = {}
        # Context manager guarantees the file is flushed and closed.
        with open(os.environ["TEST_OUTPUT_JSON"], "w") as out_file:
            json.dump(results, out_file)


if __name__ == "__main__":
    main()