action_masking.py

  1. """Example showing how to use "action masking" in RLlib.
  2. "Action masking" allows the agent to select actions based on the current
  3. observation. This is useful in many practical scenarios, where different
  4. actions are available in different time steps.
  5. Blog post explaining action masking: https://boring-guy.sh/posts/masking-rl/
  6. RLlib supports action masking, i.e., disallowing these actions based on the
  7. observation, by slightly adjusting the environment and the model as shown in
  8. this example.
  9. Here, the ActionMaskEnv wraps an underlying environment (here, RandomEnv),
  10. defining only a subset of all actions as valid based on the environment's
  11. observations. If an invalid action is selected, the environment raises an error
  12. - this must not happen!
  13. The environment constructs Dict observations, where obs["observations"] holds
  14. the original observations and obs["action_mask"] holds the valid actions.
  15. To avoid selection invalid actions, the ActionMaskModel is used. This model
  16. takes the original observations, computes the logits of the corresponding
  17. actions and then sets the logits of all invalid actions to zero, thus disabling
  18. them. This only works with discrete actions.
  19. ---
  20. Run this example with defaults (using Tune and action masking):
  21. $ python action_masking.py
  22. Then run again without action masking, which will likely lead to errors due to
  23. invalid actions being selected (ValueError "Invalid action sent to env!"):
  24. $ python action_masking.py --no-masking
  25. Other options for running this example:
  26. $ python action_masking.py --help
  27. """
import argparse
import os

from gym.spaces import Box, Discrete
import ray
from ray import tune
from ray.rllib.agents import ppo
from ray.rllib.examples.env.action_mask_env import ActionMaskEnv
from ray.rllib.examples.models.action_mask_model import \
    ActionMaskModel, TorchActionMaskModel
from ray.tune.logger import pretty_print
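

# --- Illustrative sketch only (NOT RLlib's actual ActionMaskModel code). ---
# The masking trick described in the module docstring boils down to pushing
# the logits of invalid actions to -inf before sampling (e.g. by adding
# log(action_mask) to the raw logits): after the softmax, masked actions get
# zero probability, while valid actions are unaffected. The helper below is a
# toy, framework-agnostic version of that idea and is not used by the script.
def _masked_logits_sketch(logits, action_mask):
    """Toy helper: entries with mask == 0 get a logit of -inf."""
    return [
        logit if valid else float("-inf")
        for logit, valid in zip(logits, action_mask)
    ]
# e.g. _masked_logits_sketch([1.2, 0.3, -0.5], [1, 0, 1])
#      -> [1.2, -inf, -0.5]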


def get_cli_args():
    """Create CLI parser and return parsed arguments"""
    parser = argparse.ArgumentParser()

    # example-specific args
    parser.add_argument(
        "--no-masking",
        action="store_true",
        help="Do NOT mask invalid actions. This will likely lead to errors.")

    # general args
    parser.add_argument(
        "--run",
        type=str,
        default="APPO",
        help="The RLlib-registered algorithm to use.")
    parser.add_argument("--num-cpus", type=int, default=0)
    parser.add_argument(
        "--framework",
        choices=["tf", "tf2", "tfe", "torch"],
        default="tf",
        help="The DL framework specifier.")
    parser.add_argument("--eager-tracing", action="store_true")
    parser.add_argument(
        "--stop-iters",
        type=int,
        default=10,
        help="Number of iterations to train.")
    parser.add_argument(
        "--stop-timesteps",
        type=int,
        default=10000,
        help="Number of timesteps to train.")
    parser.add_argument(
        "--stop-reward",
        type=float,
        default=80.0,
        help="Reward at which we stop training.")
    parser.add_argument(
        "--no-tune",
        action="store_true",
        help="Run without Tune using a manual train loop instead. Here, "
        "there is no TensorBoard support.")
    parser.add_argument(
        "--local-mode",
        action="store_true",
        help="Init Ray in local mode for easier debugging.")

    args = parser.parse_args()
    print(f"Running with following CLI args: {args}")
    return args


if __name__ == "__main__":
    args = get_cli_args()

    ray.init(num_cpus=args.num_cpus or None, local_mode=args.local_mode)

    # main part: configure the ActionMaskEnv and ActionMaskModel
    config = {
        # random env with 100 discrete actions and 5x [-1,1] observations
        # some actions are declared invalid and lead to errors
        "env": ActionMaskEnv,
        "env_config": {
            "action_space": Discrete(100),
            "observation_space": Box(-1.0, 1.0, (5, )),
        },
        # the ActionMaskModel retrieves the invalid actions and avoids them
        "model": {
            "custom_model": ActionMaskModel
            if args.framework != "torch" else TorchActionMaskModel,
            # disable action masking according to CLI
            "custom_model_config": {
                "no_masking": args.no_masking
            }
        },
        # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
        "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
        "framework": args.framework,
        # Run with tracing enabled for tfe/tf2?
        "eager_tracing": args.eager_tracing,
    }
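
    # NOTE (illustrative, based on the module docstring): the ActionMaskEnv is
    # expected to wrap these two spaces into a Dict observation space roughly
    # like
    #   Dict({
    #       "action_mask": Box(0.0, 1.0, (100, )),   # 1 = valid, 0 = invalid
    #       "observations": Box(-1.0, 1.0, (5, )),   # the original obs
    #   })
    # so the ActionMaskModel can read the mask from obs["action_mask"].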

    stop = {
        "training_iteration": args.stop_iters,
        "timesteps_total": args.stop_timesteps,
        "episode_reward_mean": args.stop_reward,
    }
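    # Tune stops a trial as soon as ANY one of these criteria is satisfied,
    # whichever comes first.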

    # manual training loop (no Ray tune)
    if args.no_tune:
        if args.run not in {"APPO", "PPO"}:
            raise ValueError("This example only supports APPO and PPO.")
        ppo_config = ppo.DEFAULT_CONFIG.copy()
        ppo_config.update(config)
        trainer = ppo.PPOTrainer(config=ppo_config, env=ActionMaskEnv)

        # run manual training loop and print results after each iteration
        for _ in range(args.stop_iters):
            result = trainer.train()
            print(pretty_print(result))
            # stop training if the target train steps or reward are reached
            if result["timesteps_total"] >= args.stop_timesteps or \
                    result["episode_reward_mean"] >= args.stop_reward:
                break
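
        # In the inference loop below, `compute_single_action` receives the
        # full Dict obs (original observations plus action mask), so the
        # custom model keeps masking out invalid actions while acting.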
        # manual test loop
        print("Finished training. Running manual test/inference loop.")
        # prepare environment with max 10 steps
        config["env_config"]["max_episode_len"] = 10
        env = ActionMaskEnv(config["env_config"])
        obs = env.reset()
        done = False
        # run one iteration until done
        print(f"ActionMaskEnv with {config['env_config']}")
        while not done:
            action = trainer.compute_single_action(obs)
            next_obs, reward, done, _ = env.step(action)
            # observations contain original observations and the action mask
            # reward is random and irrelevant here and therefore not printed
            print(f"Obs: {obs}, Action: {action}")
            obs = next_obs

    # run with tune for auto trainer creation, stopping, TensorBoard, etc.
    else:
        results = tune.run(args.run, config=config, stop=stop, verbose=2)

    print("Finished successfully without selecting invalid actions.")
    ray.shutdown()