hierarchical_training.py

  1. """Example of hierarchical training using the multi-agent API.
  2. The example env is that of a "windy maze". The agent observes the current wind
  3. direction and can either choose to stand still, or move in that direction.
  4. You can try out the env directly with:
  5. $ python hierarchical_training.py --flat
  6. A simple hierarchical formulation involves a high-level agent that issues goals
  7. (i.e., go north / south / east / west), and a low-level agent that executes
  8. these goals over a number of time-steps. This can be implemented as a
  9. multi-agent environment with a top-level agent and low-level agents spawned
  10. for each higher-level action. The lower level agent is rewarded for moving
  11. in the right direction.
  12. You can try this formulation with:
  13. $ python hierarchical_training.py # gets ~100 rew after ~100k timesteps
  14. Note that the hierarchical formulation actually converges slightly slower than
  15. using --flat in this example.
  16. """

import argparse
from gym.spaces import Discrete, Tuple
import logging
import os

import ray
from ray import tune
from ray.tune import function
from ray.rllib.examples.env.windy_maze_env import WindyMazeEnv, \
    HierarchicalWindyMazeEnv
from ray.rllib.utils.test_utils import check_learning_achieved

parser = argparse.ArgumentParser()
parser.add_argument("--flat", action="store_true")
parser.add_argument(
    "--framework",
    choices=["tf", "tf2", "tfe", "torch"],
    default="tf",
    help="The DL framework specifier.")
parser.add_argument(
    "--as-test",
    action="store_true",
    help="Whether this script should be run as a test: --stop-reward must "
    "be achieved within --stop-timesteps AND --stop-iters.")
parser.add_argument(
    "--stop-iters",
    type=int,
    default=200,
    help="Number of iterations to train.")
parser.add_argument(
    "--stop-timesteps",
    type=int,
    default=100000,
    help="Number of timesteps to train.")
parser.add_argument(
    "--stop-reward",
    type=float,
    default=0.0,
    help="Reward at which we stop training.")

logger = logging.getLogger(__name__)

if __name__ == "__main__":
    args = parser.parse_args()
    ray.init()
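
    # Stopping criteria for Tune: the trial stops as soon as any one of these
    # thresholds (iterations, env timesteps, or mean episode reward) is hit.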
    stop = {
        "training_iteration": args.stop_iters,
        "timesteps_total": args.stop_timesteps,
        "episode_reward_mean": args.stop_reward,
    }
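
    # --flat: train a single PPO policy directly on the flat WindyMazeEnv
    # (no hierarchy) as a baseline.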
    if args.flat:
        results = tune.run(
            "PPO",
            stop=stop,
            config={
                "env": WindyMazeEnv,
                "num_workers": 0,
                "framework": args.framework,
            },
        )
    else:
        maze = WindyMazeEnv(None)
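
        # WindyMazeEnv is instantiated above only so its observation/action
        # spaces can be read off for the policy specs below.
        # Route agent IDs to policies: low-level agents are spawned with IDs
        # prefixed "low_level_"; any other ID maps to the high-level policy.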
        def policy_mapping_fn(agent_id, episode, worker, **kwargs):
            if agent_id.startswith("low_level_"):
                return "low_level_policy"
            else:
                return "high_level_policy"
        config = {
            "env": HierarchicalWindyMazeEnv,
            "num_workers": 0,
            "entropy_coeff": 0.01,
            "multiagent": {
                "policies": {
                    "high_level_policy": (None, maze.observation_space,
                                          Discrete(4), {
                                              "gamma": 0.9
                                          }),
                    "low_level_policy": (None,
                                         Tuple([
                                             maze.observation_space,
                                             Discrete(4)
                                         ]), maze.action_space, {
                                             "gamma": 0.0
                                         }),
                },
                "policy_mapping_fn": function(policy_mapping_fn),
            },
            "framework": args.framework,
            # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
            "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
        }

        results = tune.run("PPO", stop=stop, config=config, verbose=1)
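
    # When run as a test, fail loudly unless --stop-reward was actually
    # reached within the given budget.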
    if args.as_test:
        check_learning_achieved(results, args.stop_reward)

    ray.shutdown()