# remote_envs_with_inference_done_on_main_node.py

"""
This script demonstrates how one can specify n (vectorized) envs
as Ray remote actors, such that stepping through them occurs in parallel.
Also, the actions for each env step will be computed on the "main" node.
This can be useful if the "main" node is a GPU machine and we would like to
speed up batched action calculations, similar to DeepMind's SEED
architecture, described here:
https://ai.googleblog.com/2020/03/massively-scaling-reinforcement.html

A minimal plain-Ray sketch of this pattern (for illustration only) follows
the imports below.
"""
import argparse
import os
from typing import Union

import ray
from ray import air, tune
from ray.rllib.algorithms.ppo import PPO, PPOConfig
from ray.rllib.algorithms.algorithm import Algorithm
from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
from ray.rllib.utils.annotations import override
from ray.rllib.utils.test_utils import check_learning_achieved
from ray.rllib.utils.typing import PartialAlgorithmConfigDict
from ray.tune import PlacementGroupFactory
from ray.tune.logger import pretty_print
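

# A minimal plain-Ray sketch of the pattern described in the module docstring
# (illustration only: `_EnvActor` and `_step_envs_in_parallel` are hypothetical
# helpers that the rest of this script does not use, and the sketch assumes
# `gymnasium` is available, as it is with recent RLlib versions). Each sub-env
# lives in its own Ray actor, `step()` calls return object refs immediately,
# and a single `ray.get()` gathers all results, so the n envs advance in
# parallel while the batched action computation stays on the main process.
@ray.remote
class _EnvActor:
    """One sub-environment running in its own Ray actor process."""

    def __init__(self, env_name: str = "CartPole-v1"):
        import gymnasium as gym

        self.env = gym.make(env_name)
        self.obs, _ = self.env.reset()

    def step(self, action):
        # Step the wrapped env and auto-reset on episode end.
        self.obs, reward, terminated, truncated, _ = self.env.step(action)
        if terminated or truncated:
            self.obs, _ = self.env.reset()
        return self.obs, reward


def _step_envs_in_parallel(env_actors, actions):
    # Issue all (non-blocking) `step()` calls first, then block once to gather
    # the results -> all sub-envs step concurrently.
    return ray.get([env.step.remote(a) for env, a in zip(env_actors, actions)])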


def get_cli_args():
    """Create CLI parser and return parsed arguments"""
    parser = argparse.ArgumentParser()

    # example-specific args
    # This should be >1, otherwise, remote envs make no sense.
    parser.add_argument("--num-envs-per-worker", type=int, default=4)

    # general args
    parser.add_argument(
        "--framework",
        choices=["tf", "tf2", "torch"],
        default="torch",
        help="The DL framework specifier.",
    )
    parser.add_argument(
        "--as-test",
        action="store_true",
        help="Whether this script should be run as a test: --stop-reward must "
        "be achieved within --stop-timesteps AND --stop-iters.",
    )
    parser.add_argument(
        "--stop-iters", type=int, default=50, help="Number of iterations to train."
    )
    parser.add_argument(
        "--stop-timesteps",
        type=int,
        default=100000,
        help="Number of timesteps to train.",
    )
    parser.add_argument(
        "--stop-reward",
        type=float,
        default=150.0,
        help="Reward at which we stop training.",
    )
    parser.add_argument(
        "--no-tune",
        action="store_true",
        help="Run without Tune using a manual train loop instead. Here, "
        "there is no TensorBoard support.",
    )
    parser.add_argument(
        "--local-mode",
        action="store_true",
        help="Init Ray in local mode for easier debugging.",
    )
    args = parser.parse_args()
    print(f"Running with following CLI args: {args}")
    return args


# The modified Algorithm class we will use:
# Subclassing from PPO, our algo will only modify `default_resource_request`,
# telling Ray Tune that it's ok (not mandatory) to place our n remote envs on a
# different node (each env using 1 CPU).
class PPORemoteInference(PPO):
    @classmethod
    @override(Algorithm)
    def default_resource_request(
        cls,
        config: Union[AlgorithmConfig, PartialAlgorithmConfigDict],
    ):
        if isinstance(config, AlgorithmConfig):
            cf = config
        else:
            cf = cls.get_default_config().update_from_dict(config)

        # Return a PlacementGroupFactory containing all needed resources
        # (already properly defined as device bundles).
        return PlacementGroupFactory(
            bundles=[
                {
                    # Single CPU for the local worker. This CPU hosts the
                    # main model in this example (num_rollout_workers=0).
                    "CPU": 1,
                    # Possibly add n GPUs to this.
                    "GPU": cf.num_gpus,
                },
                {
                    # Different bundle (meaning: possibly different node)
                    # for your n "remote" envs (set remote_worker_envs=True).
                    "CPU": cf.num_envs_per_worker,
                },
            ],
            strategy=cf.placement_strategy,
        )
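
# With this example's CLI defaults (4 envs per worker, no GPUs), the request
# above resolves to roughly:
#   PlacementGroupFactory(bundles=[{"CPU": 1, "GPU": 0}, {"CPU": 4}], strategy="PACK")
# ("PACK" is RLlib's default `placement_strategy`; Ray may still co-locate both
# bundles on a single node if it has enough free CPUs.)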


if __name__ == "__main__":
    args = get_cli_args()

    ray.init(num_cpus=6, local_mode=args.local_mode)

    config = (
        PPOConfig()
        .environment("CartPole-v1")
        .framework(args.framework)
        .rollouts(
            # Force sub-envs to be ray.actor.ActorHandles, so we can step
            # through them in parallel.
            remote_worker_envs=True,
            num_envs_per_worker=args.num_envs_per_worker,
            # Use a single RolloutWorker; its n parallelized remote envs may,
            # however, run on another node.
            # Action computations will occur on the "main" (GPU?) node, while
            # the envs run on one or more CPU node(s).
            num_rollout_workers=0,
        )
        .resources(
            # Use GPUs iff the `RLLIB_NUM_GPUS` env var is set to > 0.
            num_gpus=int(os.environ.get("RLLIB_NUM_GPUS", "0")),
            # Set the number of CPUs used by the (local) worker, aka the
            # "driver", to match the number of Ray remote envs (plus one).
            num_cpus_for_local_worker=args.num_envs_per_worker + 1,
        )
    )
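
    # With this config, the single (local) RolloutWorker creates
    # `num_envs_per_worker` sub-envs as Ray actors and computes batched
    # actions for all of them on this process. If the remote envs are slow or
    # uneven, `.rollouts(remote_env_batch_wait_ms=...)` can optionally be used
    # to poll only the sub-envs that are ready instead of waiting for all of
    # them (availability of this setting may depend on your RLlib version).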

    # Run as manual training loop.
    if args.no_tune:
        # Manual training loop using PPO and manually keeping track of state.
        algo = PPORemoteInference(config=config)
        # Run manual training loop and print results after each iteration.
        for _ in range(args.stop_iters):
            result = algo.train()
            print(pretty_print(result))
            # Stop training if the target train steps or reward are reached.
            if (
                result["timesteps_total"] >= args.stop_timesteps
                or result["episode_reward_mean"] >= args.stop_reward
            ):
                break

    # Run with Tune for auto env and algorithm creation and TensorBoard.
    else:
        stop = {
            "training_iteration": args.stop_iters,
            "timesteps_total": args.stop_timesteps,
            "episode_reward_mean": args.stop_reward,
        }

        results = tune.Tuner(
            PPORemoteInference,
            param_space=config,
            run_config=air.RunConfig(stop=stop, verbose=1),
        ).fit()

        if args.as_test:
            check_learning_achieved(results, args.stop_reward)

    ray.shutdown()