custom_env.py

  1. """
  2. Example of a custom gym environment. Run this example for a demo.
  3. This example shows the usage of:
  4. - a custom environment
  5. - Ray Tune for grid search to try different learning rates
  6. You can visualize experiment results in ~/ray_results using TensorBoard.
  7. Run example with defaults:
  8. $ python custom_env.py
  9. For CLI options:
  10. $ python custom_env.py --help
  11. """
import argparse
import gymnasium as gym
from gymnasium.spaces import Discrete, Box
import numpy as np
import os
import random

import ray
from ray import air, tune
from ray.rllib.env.env_context import EnvContext
from ray.rllib.utils.framework import try_import_tf, try_import_torch
from ray.rllib.utils.test_utils import check_learning_achieved
from ray.tune.logger import pretty_print
from ray.tune.registry import get_trainable_cls

tf1, tf, tfv = try_import_tf()
torch, nn = try_import_torch()

parser = argparse.ArgumentParser()
parser.add_argument(
    "--run", type=str, default="PPO", help="The RLlib-registered algorithm to use."
)
parser.add_argument(
    "--framework",
    choices=["tf", "tf2", "torch"],
    default="torch",
    help="The DL framework specifier.",
)
parser.add_argument(
    "--as-test",
    action="store_true",
    help="Whether this script should be run as a test: --stop-reward must "
    "be achieved within --stop-timesteps AND --stop-iters.",
)
parser.add_argument(
    "--stop-iters", type=int, default=50, help="Number of iterations to train."
)
parser.add_argument(
    "--stop-timesteps", type=int, default=100000, help="Number of timesteps to train."
)
parser.add_argument(
    "--stop-reward", type=float, default=0.1, help="Reward at which we stop training."
)
parser.add_argument(
    "--no-tune",
    action="store_true",
    help="Run without Tune using a manual train loop instead. In this case, "
    "use PPO without grid search and no TensorBoard.",
)
parser.add_argument(
    "--local-mode",
    action="store_true",
    help="Init Ray in local mode for easier debugging.",
)


class SimpleCorridor(gym.Env):
    """Example of a custom env in which you have to walk down a corridor.

    You can configure the length of the corridor via the env config."""

    def __init__(self, config: EnvContext):
        self.end_pos = config["corridor_length"]
        self.cur_pos = 0
        self.action_space = Discrete(2)
        self.observation_space = Box(0.0, self.end_pos, shape=(1,), dtype=np.float32)
        # Set the seed. This is only used for the final (reach goal) reward.
        self.reset(seed=config.worker_index * config.num_workers)

    def reset(self, *, seed=None, options=None):
        random.seed(seed)
        self.cur_pos = 0
        return [self.cur_pos], {}

    def step(self, action):
        assert action in [0, 1], action
        if action == 0 and self.cur_pos > 0:
            self.cur_pos -= 1
        elif action == 1:
            self.cur_pos += 1
        done = truncated = self.cur_pos >= self.end_pos
        # Produce a random reward when we reach the goal.
        return (
            [self.cur_pos],
            random.random() * 2 if done else -0.1,
            done,
            truncated,
            {},
        )
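

# A minimal sketch (not part of the original script) of exercising the env by
# hand. RLlib normally constructs the `EnvContext` itself; the `worker_index`
# and `num_workers` values below are assumptions made only so that `__init__`
# can compute its seed:
#
#     env = SimpleCorridor(
#         EnvContext({"corridor_length": 5}, worker_index=0, num_workers=1)
#     )
#     obs, info = env.reset()
#     while True:
#         obs, reward, done, truncated, info = env.step(1)  # always walk right
#         if done or truncated:
#             break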


if __name__ == "__main__":
    args = parser.parse_args()
    print(f"Running with following CLI options: {args}")

    ray.init(local_mode=args.local_mode)

    # Can also register the env creator function explicitly with:
    # register_env("corridor", lambda config: SimpleCorridor(config))
    config = (
        get_trainable_cls(args.run)
        .get_default_config()
        # or "corridor" if registered above
        .environment(SimpleCorridor, env_config={"corridor_length": 5})
        .framework(args.framework)
        .rollouts(num_rollout_workers=1)
        # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
        .resources(num_gpus=int(os.environ.get("RLLIB_NUM_GPUS", "0")))
    )
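
    # Stopping criteria: a Tune trial stops as soon as any one of these
    # thresholds is reached; the manual loop below checks the timestep and
    # reward values itself.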
    stop = {
        "training_iteration": args.stop_iters,
        "timesteps_total": args.stop_timesteps,
        "episode_reward_mean": args.stop_reward,
    }

    if args.no_tune:
        # Manual training loop with PPO and a fixed learning rate.
        if args.run != "PPO":
            raise ValueError("Only support --run PPO with --no-tune.")
        print("Running manual train loop without Ray Tune.")
        # Use a fixed learning rate instead of grid search (which needs Tune).
        config.lr = 1e-3
        algo = config.build()
        # Run the manual training loop and print results after each iteration.
        for _ in range(args.stop_iters):
            result = algo.train()
            print(pretty_print(result))
            # Stop training if the target train steps or reward are reached.
            if (
                result["timesteps_total"] >= args.stop_timesteps
                or result["episode_reward_mean"] >= args.stop_reward
            ):
                break
        algo.stop()
    else:
        # Automated run with Tune, grid search, and TensorBoard.
        print("Training automatically with Ray Tune")
        tuner = tune.Tuner(
            args.run,
            param_space=config.to_dict(),
            run_config=air.RunConfig(stop=stop),
        )
        results = tuner.fit()

        if args.as_test:
            print("Checking if learning goals were achieved")
            check_learning_achieved(results, args.stop_reward)

    ray.shutdown()
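
# Example invocations (all flags are defined above; the values are only
# illustrative):
#   python custom_env.py --framework torch --stop-iters 20
#   python custom_env.py --no-tune --local-mode
#   python custom_env.py --as-test --stop-reward 0.4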