# apex.py
  1. # This workload tests running APEX
  2. import ray
  3. from ray.tune import run_experiments
  4. from ray.tune.utils.release_test_util import ProgressCallback
  5. num_redis_shards = 5
  6. redis_max_memory = 10**8
  7. object_store_memory = 10**9
  8. num_nodes = 3
  9. message = (
  10. "Make sure there is enough memory on this machine to run this "
  11. "workload. We divide the system memory by 2 to provide a buffer."
  12. )
  13. assert (
  14. num_nodes * object_store_memory + num_redis_shards * redis_max_memory
  15. < ray._private.utils.get_system_memory() / 2
  16. ), message
  17. # Simulate a cluster on one machine.
  18. # cluster = Cluster()
  19. # for i in range(num_nodes):
  20. # cluster.add_node(redis_port=6379 if i == 0 else None,
  21. # num_redis_shards=num_redis_shards if i == 0 else None,
  22. # num_cpus=20,
  23. # num_gpus=0,
  24. # resources={str(i): 2},
  25. # object_store_memory=object_store_memory,
  26. # redis_max_memory=redis_max_memory,
  27. # dashboard_host="0.0.0.0")
  28. # ray.init(address=cluster.address)
  29. ray.init()
  30. # Run the workload.
  31. run_experiments(
  32. {
  33. "apex": {
  34. "run": "APEX",
  35. "env": "ALE/Pong-v5",
  36. "config": {
  37. "num_workers": 3,
  38. "num_gpus": 0,
  39. "replay_buffer_config": {
  40. "capacity": 10000,
  41. },
  42. "num_steps_sampled_before_learning_starts": 0,
  43. "rollout_fragment_length": "auto",
  44. "train_batch_size": 1,
  45. "min_time_s_per_iteration": 10,
  46. "min_sample_timesteps_per_iteration": 10,
  47. },
  48. }
  49. },
  50. callbacks=[ProgressCallback()],
  51. )