many_ppo.py

# This workload tests running many instances of PPO (many actors).
# This covers https://github.com/ray-project/ray/pull/12148
import ray
from ray.tune import run_experiments
from ray.tune.utils.release_test_util import ProgressCallback
from ray._private.test_utils import monitor_memory_usage
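
# Per-node memory limits for the simulated cluster; the assertion below
# checks that they fit within half of the machine's system memory.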
num_redis_shards = 5
redis_max_memory = 10**8
object_store_memory = 10**9
num_nodes = 3

message = (
    "Make sure there is enough memory on this machine to run this "
    "workload. We divide the system memory by 2 to provide a buffer."
)
assert (
    num_nodes * object_store_memory + num_redis_shards * redis_max_memory
    < ray._private.utils.get_system_memory() / 2
), message

# Simulate a cluster on one machine.
ray.init(address="auto")
monitor_actor = monitor_memory_usage()

# Run the workload.
run_experiments(
    {
        "ppo": {
            "run": "PPO",
            "env": "CartPole-v0",
            "num_samples": 10000,
            "config": {
                "framework": "torch",
                "num_workers": 7,
                "num_gpus": 0,
                "num_sgd_iter": 1,
            },
            "stop": {
                "timesteps_total": 1,
            },
        }
    },
    callbacks=[ProgressCallback()],
)
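
# Stop the memory monitor and report peak usage.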
ray.get(monitor_actor.stop_run.remote())
used_gb, usage = ray.get(monitor_actor.get_peak_memory_info.remote())
print(f"Peak memory usage: {round(used_gb, 2)}GB")
print(f"Peak memory usage per process:\n {usage}")