custom_env.py

  1. """
  2. Example of a custom gym environment and model. Run this for a demo.
  3. This example shows:
  4. - using a custom environment
  5. - using a custom model
  6. - using Tune for grid search to try different learning rates
  7. You can visualize experiment results in ~/ray_results using TensorBoard.
  8. Run example with defaults:
  9. $ python custom_env.py
  10. For CLI options:
  11. $ python custom_env.py --help
  12. """
import argparse
import gym
from gym.spaces import Discrete, Box
import numpy as np
import os
import random

import ray
from ray import tune
from ray.rllib.agents import ppo
from ray.rllib.env.env_context import EnvContext
from ray.rllib.models import ModelCatalog
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
from ray.rllib.models.tf.fcnet import FullyConnectedNetwork
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.models.torch.fcnet import FullyConnectedNetwork as TorchFC
from ray.rllib.utils.framework import try_import_tf, try_import_torch
from ray.rllib.utils.test_utils import check_learning_achieved
from ray.tune.logger import pretty_print

tf1, tf, tfv = try_import_tf()
torch, nn = try_import_torch()
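
# try_import_tf() / try_import_torch() let this example run with whichever of
# TensorFlow or PyTorch is actually installed, instead of failing at import
# time for the framework that is missing.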

parser = argparse.ArgumentParser()
parser.add_argument(
    "--run",
    type=str,
    default="PPO",
    help="The RLlib-registered algorithm to use.")
parser.add_argument(
    "--framework",
    choices=["tf", "tf2", "tfe", "torch"],
    default="tf",
    help="The DL framework specifier.")
parser.add_argument(
    "--as-test",
    action="store_true",
    help="Whether this script should be run as a test: --stop-reward must "
    "be achieved within --stop-timesteps AND --stop-iters.")
parser.add_argument(
    "--stop-iters",
    type=int,
    default=50,
    help="Number of iterations to train.")
parser.add_argument(
    "--stop-timesteps",
    type=int,
    default=100000,
    help="Number of timesteps to train.")
parser.add_argument(
    "--stop-reward",
    type=float,
    default=0.1,
    help="Reward at which we stop training.")
parser.add_argument(
    "--no-tune",
    action="store_true",
    help="Run without Tune using a manual train loop instead. In this case, "
    "use PPO without grid search and no TensorBoard.")
parser.add_argument(
    "--local-mode",
    action="store_true",
    help="Init Ray in local mode for easier debugging.")


class SimpleCorridor(gym.Env):
    """Example of a custom env in which you have to walk down a corridor.

    You can configure the length of the corridor via the env config."""

    def __init__(self, config: EnvContext):
        self.end_pos = config["corridor_length"]
        self.cur_pos = 0
        self.action_space = Discrete(2)
        self.observation_space = Box(
            0.0, self.end_pos, shape=(1, ), dtype=np.float32)
        # Set the seed. This is only used for the final (reach goal) reward.
        self.seed(config.worker_index * config.num_workers)

    def reset(self):
        self.cur_pos = 0
        return [self.cur_pos]

    def step(self, action):
        assert action in [0, 1], action
        if action == 0 and self.cur_pos > 0:
            self.cur_pos -= 1
        elif action == 1:
            self.cur_pos += 1
        done = self.cur_pos >= self.end_pos
        # Produce a random reward when we reach the goal.
        return [self.cur_pos], \
            random.random() * 2 if done else -0.1, done, {}

    def seed(self, seed=None):
        random.seed(seed)
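
# A minimal manual sanity check of the env, kept as a comment so the script is
# unchanged (the EnvContext arguments below are illustrative, not taken from
# the original example):
#
#   env = SimpleCorridor(
#       EnvContext({"corridor_length": 5}, worker_index=0, num_workers=1))
#   obs, done = env.reset(), False
#   while not done:
#       obs, reward, done, _ = env.step(1)  # action 1 = move toward the goal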


class CustomModel(TFModelV2):
    """Example of a keras custom model that just delegates to an fc-net."""

    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        super(CustomModel, self).__init__(obs_space, action_space,
                                          num_outputs, model_config, name)
        self.model = FullyConnectedNetwork(obs_space, action_space,
                                           num_outputs, model_config, name)

    def forward(self, input_dict, state, seq_lens):
        return self.model.forward(input_dict, state, seq_lens)

    def value_function(self):
        return self.model.value_function()


class TorchCustomModel(TorchModelV2, nn.Module):
    """Example of a PyTorch custom model that just delegates to a fc-net."""

    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs,
                              model_config, name)
        nn.Module.__init__(self)

        self.torch_sub_model = TorchFC(obs_space, action_space, num_outputs,
                                       model_config, name)

    def forward(self, input_dict, state, seq_lens):
        input_dict["obs"] = input_dict["obs"].float()
        fc_out, _ = self.torch_sub_model(input_dict, state, seq_lens)
        return fc_out, []

    def value_function(self):
        return torch.reshape(self.torch_sub_model.value_function(), [-1])
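
# Note that both custom models above simply delegate to RLlib's built-in fully
# connected networks; they are here to illustrate the ModelV2 interface
# (__init__, forward(), value_function()) rather than a new architecture.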


if __name__ == "__main__":
    args = parser.parse_args()
    print(f"Running with following CLI options: {args}")

    ray.init(local_mode=args.local_mode)

    # Can also register the env creator function explicitly with:
    # register_env("corridor", lambda config: SimpleCorridor(config))
    ModelCatalog.register_custom_model(
        "my_model", TorchCustomModel
        if args.framework == "torch" else CustomModel)

    config = {
        "env": SimpleCorridor,  # or "corridor" if registered above
        "env_config": {
            "corridor_length": 5,
        },
        # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
        "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
        "model": {
            "custom_model": "my_model",
            "vf_share_layers": True,
        },
        "num_workers": 1,  # parallelism
        "framework": args.framework,
    }

    stop = {
        "training_iteration": args.stop_iters,
        "timesteps_total": args.stop_timesteps,
        "episode_reward_mean": args.stop_reward,
    }
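    # With a dict `stop`, Tune ends a trial as soon as ANY of the criteria
    # above is met; the manual loop below mimics that behavior.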

    if args.no_tune:
        # Manual training loop using PPO and a fixed learning rate.
        if args.run != "PPO":
            raise ValueError("Only --run PPO is supported with --no-tune.")
        print("Running manual train loop without Ray Tune.")
        ppo_config = ppo.DEFAULT_CONFIG.copy()
        ppo_config.update(config)
        # Use a fixed learning rate instead of grid search (needs Tune).
        ppo_config["lr"] = 1e-3
        trainer = ppo.PPOTrainer(config=ppo_config, env=SimpleCorridor)
        # Run the manual training loop and print results after each iteration.
        for _ in range(args.stop_iters):
            result = trainer.train()
            print(pretty_print(result))
            # Stop training if the target train steps or reward are reached.
            if result["timesteps_total"] >= args.stop_timesteps or \
                    result["episode_reward_mean"] >= args.stop_reward:
                break
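
        # Optionally (not part of the original example), the trained policy
        # could be checkpointed and queried once, e.g. via Trainer.save() and
        # Trainer.compute_action():
        #
        #   checkpoint_path = trainer.save()
        #   print("Checkpoint saved at", checkpoint_path)
        #   print("Action for obs [0.0]:", trainer.compute_action([0.0]))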
    else:
        # Automated run with Tune, grid search, and TensorBoard.
        print("Training automatically with Ray Tune")
        results = tune.run(args.run, config=config, stop=stop)

        if args.as_test:
            print("Checking if learning goals were achieved")
            check_learning_achieved(results, args.stop_reward)

    ray.shutdown()