  1. """
  2. Example of interfacing with an environment that produces 2D observations.
  3. This example shows how turning 2D observations with shape (A, B) into a 3D
  4. observations with shape (C, D, 1) can enable usage of RLlib's default models.
  5. RLlib's default Catalog class does not provide default models for 2D observation
  6. spaces, but it does so for 3D observations.
  7. Therefore, one can either write a custom model or transform the 2D observations into 3D
  8. observations. This enables RLlib to use one of the default CNN filters, even though the
  9. original observation space of the environment does not fit them.
  10. This simple example should reach rewards of 50 within 150k timesteps.
  11. """
import argparse

from numpy import float32
from pettingzoo.butterfly import pistonball_v6
from supersuit import (
    color_reduction_v0,
    dtype_v0,
    normalize_obs_v0,
    reshape_v0,
    resize_v1,
)

from ray import air, tune
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.env import PettingZooEnv
from ray.tune.registry import register_env

parser = argparse.ArgumentParser()
parser.add_argument(
    "--framework",
    choices=["tf2", "torch"],
    default="torch",
    help="The DL framework specifier.",
)
parser.add_argument(
    "--as-test",
    action="store_true",
    help="Whether this script should be run as a compilation test.",
)
parser.add_argument(
    "--stop-iters", type=int, default=150, help="Number of iterations to train."
)
parser.add_argument(
    "--stop-timesteps", type=int, default=1000000, help="Number of timesteps to train."
)
parser.add_argument(
    "--stop-reward", type=float, default=50, help="Reward at which we stop training."
)

args = parser.parse_args()

# The space we down-sample and transform the greyscale pistonball images to.
# Other spaces supported by RLlib can be chosen here.
TRANSFORMED_OBS_SPACE = (42, 42, 1)
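

# A minimal sketch (a hypothetical helper, not part of the training run) of the
# reshape trick the docstring describes: appending a trailing channel axis
# turns a 2D greyscale image of shape (A, B) into a 3D observation of shape
# (A, B, 1) that RLlib's default CNN filters accept. Dummy data only.
def _reshape_sketch():
    import numpy as np

    obs_2d = np.zeros(TRANSFORMED_OBS_SPACE[:2], dtype=np.float32)  # (42, 42)
    obs_3d = obs_2d[..., np.newaxis]  # (42, 42, 1)
    assert obs_3d.shape == TRANSFORMED_OBS_SPACE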


def env_creator(config):
    env = pistonball_v6.env(n_pistons=5)
    # Cast observations to float32 so they can be normalized below.
    env = dtype_v0(env, dtype=float32)
    # This gives us greyscale images for the color red.
    env = color_reduction_v0(env, mode="R")
    env = normalize_obs_v0(env)
    # This gives us images that are down-sampled to the number of pixels in
    # the default CNN filter.
    env = resize_v1(
        env, x_size=TRANSFORMED_OBS_SPACE[0], y_size=TRANSFORMED_OBS_SPACE[1]
    )
    # This gives us 3D images for which we have default filters.
    env = reshape_v0(env, shape=TRANSFORMED_OBS_SPACE)
    return env
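

# A quick sanity check (a sketch; assumes the standard PettingZoo AEC API,
# i.e. `possible_agents` and `observation_space(agent)`): after wrapping, the
# env should report 3D observations matching TRANSFORMED_OBS_SPACE.
def _check_transformed_obs_space():
    env = env_creator({})
    first_agent = env.possible_agents[0]
    assert env.observation_space(first_agent).shape == TRANSFORMED_OBS_SPACE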


# Register the env. PettingZooEnv wraps the PettingZoo AEC env so that RLlib
# can treat it as a multi-agent environment.
register_env("pistonball", lambda config: PettingZooEnv(env_creator(config)))

config = (
    PPOConfig()
    .environment("pistonball", env_config={"local_ratio": 0.5}, clip_rewards=True)
    .rollouts(
        num_rollout_workers=15 if not args.as_test else 2,
        num_envs_per_worker=1,
        observation_filter="NoFilter",
        rollout_fragment_length="auto",
    )
    .framework(args.framework)
    .training(
        entropy_coeff=0.01,
        vf_loss_coeff=0.1,
        clip_param=0.1,
        vf_clip_param=10.0,
        num_sgd_iter=10,
        kl_coeff=0.5,
        lr=0.0001,
        grad_clip=100,
        sgd_minibatch_size=500,
        train_batch_size=5000 if not args.as_test else 1000,
        model={"vf_share_layers": True},
    )
    .resources(num_gpus=1 if not args.as_test else 0)
    .reporting(min_time_s_per_iteration=30)
)

results = tune.Tuner(
    "PPO",
    param_space=config.to_dict(),
    run_config=air.RunConfig(
        stop={
            "training_iteration": args.stop_iters,
            "timesteps_total": args.stop_timesteps,
            "episode_reward_mean": args.stop_reward,
        },
        verbose=2,
    ),
).fit()
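
# Hedged test hook (an assumption beyond the original script): when run with
# --as-test, check that training actually reached the stop reward. Assumes
# Tune's ResultGrid API (`get_best_result`) and RLlib's `episode_reward_mean`
# metric.
if args.as_test:
    best_reward = results.get_best_result(
        metric="episode_reward_mean", mode="max"
    ).metrics["episode_reward_mean"]
    assert (
        best_reward >= args.stop_reward
    ), f"Did not reach stop reward: {best_reward} < {args.stop_reward}"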