attention_net.py

"""
Example of using an RL agent (default: PPO) with an AttentionNet model,
which is useful for environments where state is important but not explicitly
part of the observations.

For example, in the "repeat after me" environment (default here), the agent
needs to repeat an observation from n timesteps before.
AttentionNet keeps state of previous observations and uses transformers to
learn a policy that successfully repeats previous observations.
Without attention, the RL agent only "sees" the last observation, not the one
from n timesteps ago, and so cannot learn to repeat that earlier observation.

AttentionNet paper: https://arxiv.org/abs/1506.07704

This example script also shows how to train and test a PPO agent with an
AttentionNet model manually, i.e., without using Tune.

---
Run this example with defaults (using Tune and AttentionNet on the "repeat
after me" environment):
$ python attention_net.py
Then run again without attention:
$ python attention_net.py --no-attention
Compare the learning curves on TensorBoard:
$ cd ~/ray_results/; tensorboard --logdir .
There will be a huge difference between the runs with and without attention!

Other options for running this example:
$ python attention_net.py --help
"""
import argparse
import os

import numpy as np

import ray
from ray import tune
from ray.rllib.agents import ppo
from ray.rllib.examples.env.look_and_push import LookAndPush, OneHot
from ray.rllib.examples.env.repeat_after_me_env import RepeatAfterMeEnv
from ray.rllib.examples.env.repeat_initial_obs_env import RepeatInitialObsEnv
from ray.rllib.examples.env.stateless_cartpole import StatelessCartPole
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.test_utils import check_learning_achieved
from ray.tune import registry
from ray.tune.logger import pretty_print

tf1, tf, tfv = try_import_tf()

SUPPORTED_ENVS = [
    "RepeatAfterMeEnv", "RepeatInitialObsEnv", "LookAndPush",
    "StatelessCartPole"
]
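# Note: These env names are registered as custom envs in the `__main__` block
# below and double as the valid choices for the `--env` CLI flag.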


def get_cli_args():
    """Create CLI parser and return parsed arguments"""
    parser = argparse.ArgumentParser()

    # example-specific args
    parser.add_argument(
        "--no-attention",
        action="store_true",
        help="Do NOT use attention. For comparison: The agent will not learn.")
    parser.add_argument(
        "--env", choices=SUPPORTED_ENVS, default="RepeatAfterMeEnv")

    # general args
    parser.add_argument(
        "--run", default="PPO", help="The RLlib-registered algorithm to use.")
    parser.add_argument("--num-cpus", type=int, default=3)
    parser.add_argument(
        "--framework",
        choices=["tf", "tf2", "tfe", "torch"],
        default="tf",
        help="The DL framework specifier.")
    parser.add_argument(
        "--stop-iters",
        type=int,
        default=200,
        help="Number of iterations to train.")
    parser.add_argument(
        "--stop-timesteps",
        type=int,
        default=500000,
        help="Number of timesteps to train.")
    parser.add_argument(
        "--stop-reward",
        type=float,
        default=80.0,
        help="Reward at which we stop training.")
    parser.add_argument(
        "--as-test",
        action="store_true",
        help="Whether this script should be run as a test: --stop-reward must "
        "be achieved within --stop-timesteps AND --stop-iters.")
    parser.add_argument(
        "--no-tune",
        action="store_true",
        help="Run without Tune, using a manual train loop instead. Here, "
        "there is no TensorBoard support.")
    parser.add_argument(
        "--local-mode",
        action="store_true",
        help="Init Ray in local mode for easier debugging.")

    args = parser.parse_args()
    print(f"Running with following CLI args: {args}")
    return args


if __name__ == "__main__":
    args = get_cli_args()

    ray.init(num_cpus=args.num_cpus or None, local_mode=args.local_mode)

    # register custom environments
    registry.register_env("RepeatAfterMeEnv", lambda c: RepeatAfterMeEnv(c))
    registry.register_env("RepeatInitialObsEnv",
                          lambda _: RepeatInitialObsEnv())
    registry.register_env("LookAndPush", lambda _: OneHot(LookAndPush()))
    registry.register_env("StatelessCartPole", lambda _: StatelessCartPole())
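    # Note: Each env creator above receives the Trainer's "env_config" dict;
    # only RepeatAfterMeEnv actually makes use of it here.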

    # main part: RLlib config with AttentionNet model
    config = {
        "env": args.env,
        # This env_config is only used for the RepeatAfterMeEnv env.
        "env_config": {
            "repeat_delay": 2,
        },
        "gamma": 0.99,
        # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
        "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", 0)),
        "num_envs_per_worker": 20,
        "entropy_coeff": 0.001,
        "num_sgd_iter": 10,
        "vf_loss_coeff": 1e-5,
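        # The `attention_*` keys below configure RLlib's transformer-based
        # attention wrapper: `attention_memory_inference` and
        # `attention_memory_training` set how many past timesteps each
        # transformer unit can attend over, while `attention_dim`,
        # `attention_num_heads` and `attention_head_dim` size the
        # transformer itself.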
        "model": {
            # Attention net wrapping (for tf) can already use the native keras
            # model versions. For torch, this will have no effect.
            "_use_default_native_models": True,
            "use_attention": not args.no_attention,
            "max_seq_len": 10,
            "attention_num_transformer_units": 1,
            "attention_dim": 32,
            "attention_memory_inference": 10,
            "attention_memory_training": 10,
            "attention_num_heads": 1,
            "attention_head_dim": 32,
            "attention_position_wise_mlp_dim": 32,
        },
        "framework": args.framework,
    }

    stop = {
        "training_iteration": args.stop_iters,
        "timesteps_total": args.stop_timesteps,
        "episode_reward_mean": args.stop_reward,
    }
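    # Note: Tune stops a trial as soon as ANY of the above criteria is met
    # (logical OR); the manual loop below applies the same thresholds.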

    # Manual training loop (no Ray tune).
    if args.no_tune:
        # manual training loop using PPO and manually keeping track of state
        if args.run != "PPO":
            raise ValueError("Only --run PPO is supported with --no-tune.")
        ppo_config = ppo.DEFAULT_CONFIG.copy()
        ppo_config.update(config)
        trainer = ppo.PPOTrainer(config=ppo_config, env=args.env)

        # run manual training loop and print results after each iteration
        for _ in range(args.stop_iters):
            result = trainer.train()
            print(pretty_print(result))
            # stop training if the target train steps or reward are reached
            if result["timesteps_total"] >= args.stop_timesteps or \
                    result["episode_reward_mean"] >= args.stop_reward:
                break

        # Run manual test loop (only for RepeatAfterMe env).
        if args.env == "RepeatAfterMeEnv":
            print("Finished training. Running manual test/inference loop.")
            # prepare env
            env = RepeatAfterMeEnv(config["env_config"])
            obs = env.reset()
            done = False
            total_reward = 0
            # start with all zeros as state
            num_transformers = config["model"][
                "attention_num_transformer_units"]
            init_state = state = [
                np.zeros([100, 32], np.float32)
                for _ in range(num_transformers)
            ]
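            # Each memory tensor has shape [memory_size, attention_dim]
            # (here 100 x 32); compare the `attention_memory_inference` and
            # `attention_dim` model settings above.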

            # run one iteration until done
            print(f"RepeatAfterMeEnv with {config['env_config']}")
            while not done:
                action, state_out, _ = trainer.compute_single_action(
                    obs, state)
                next_obs, reward, done, _ = env.step(action)
                print(f"Obs: {obs}, Action: {action}, Reward: {reward}")
                obs = next_obs
                total_reward += reward
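                # Roll the attention memory forward: append the newest
                # attention output and drop the oldest row for each
                # transformer unit.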
                state = [
                    np.concatenate([state[i], [state_out[i]]], axis=0)[1:]
                    for i in range(num_transformers)
                ]
            print(f"Total reward in test episode: {total_reward}")

    # Run with Tune for auto env and trainer creation and TensorBoard.
    else:
        results = tune.run(args.run, config=config, stop=stop, verbose=2)

        if args.as_test:
            print("Checking if learning goals were achieved")
            check_learning_achieved(results, args.stop_reward)

    ray.shutdown()