multi_agent_two_trainers.py

  1. """Example of using two different training methods at once in multi-agent.
  2. Here we create a number of CartPole agents, some of which are trained with
  3. DQN, and some of which are trained with PPO. We periodically sync weights
  4. between the two trainers (note that no such syncing is needed when using just
  5. a single training method).
  6. For a simpler example, see also: multiagent_cartpole.py
  7. """

import argparse
import gym
import os

import ray
from ray.rllib.agents.dqn import DQNTrainer, DQNTFPolicy, DQNTorchPolicy
from ray.rllib.agents.ppo import PPOTrainer, PPOTFPolicy, PPOTorchPolicy
from ray.rllib.examples.env.multi_agent import MultiAgentCartPole
from ray.tune.logger import pretty_print
from ray.tune.registry import register_env

parser = argparse.ArgumentParser()
parser.add_argument(
    "--framework",
    choices=["tf", "tf2", "tfe", "torch"],
    default="tf",
    help="The DL framework specifier (applies to both trainers).")
parser.add_argument(
    "--as-test",
    action="store_true",
    help="Whether this script should be run as a test: --stop-reward must "
    "be achieved within --stop-timesteps AND --stop-iters.")
parser.add_argument(
    "--stop-iters",
    type=int,
    default=20,
    help="Number of iterations to train.")
parser.add_argument(
    "--stop-timesteps",
    type=int,
    default=100000,
    help="Number of timesteps to train.")
parser.add_argument(
    "--stop-reward",
    type=float,
    default=50.0,
    help="Reward at which we stop training.")

if __name__ == "__main__":
    args = parser.parse_args()

    ray.init()

    # Simple environment with 4 independent cartpole entities.
    register_env("multi_agent_cartpole",
                 lambda _: MultiAgentCartPole({"num_agents": 4}))

    single_dummy_env = gym.make("CartPole-v0")
    obs_space = single_dummy_env.observation_space
    act_space = single_dummy_env.action_space

    # You can also have multiple policies per trainer, but here we just
    # show one each for PPO and DQN.
    policies = {
        "ppo_policy": (PPOTorchPolicy if args.framework == "torch" else
                       PPOTFPolicy, obs_space, act_space, {}),
        "dqn_policy": (DQNTorchPolicy if args.framework == "torch" else
                       DQNTFPolicy, obs_space, act_space, {}),
    }
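
    # Each policy spec is a tuple of
    # (policy_class, observation_space, action_space, extra_config).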

    def policy_mapping_fn(agent_id, episode, worker, **kwargs):
        if agent_id % 2 == 0:
            return "ppo_policy"
        else:
            return "dqn_policy"

    ppo_trainer = PPOTrainer(
        env="multi_agent_cartpole",
        config={
            "multiagent": {
                "policies": policies,
                "policy_mapping_fn": policy_mapping_fn,
                "policies_to_train": ["ppo_policy"],
            },
            "model": {
                "vf_share_layers": True,
            },
            "num_sgd_iter": 6,
            "vf_loss_coeff": 0.01,
            # Disable observation filtering; otherwise we would also need to
            # synchronize the filter state to the DQN trainer.
            "observation_filter": "NoFilter",
            # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
            "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
            "framework": args.framework,
        })

    dqn_trainer = DQNTrainer(
        env="multi_agent_cartpole",
        config={
            "multiagent": {
                "policies": policies,
                "policy_mapping_fn": policy_mapping_fn,
                "policies_to_train": ["dqn_policy"],
            },
            "model": {
                "vf_share_layers": True,
            },
            "gamma": 0.95,
            "n_step": 3,
            # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
            "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
            "framework": args.framework,
        })
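
    # Both trainers get the full `policies` dict, but each one only trains
    # the policy listed in its `policies_to_train`; the other policy is used
    # for acting only and is refreshed via the weight sync in the loop below.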

    # You should see both the printed X and Y approach 200 as this trains:
    # info:
    #   policy_reward_mean:
    #     dqn_policy: X
    #     ppo_policy: Y
    for i in range(args.stop_iters):
        print("== Iteration", i, "==")

        # Improve the DQN policy.
        print("-- DQN --")
        result_dqn = dqn_trainer.train()
        print(pretty_print(result_dqn))

        # Improve the PPO policy.
        print("-- PPO --")
        result_ppo = ppo_trainer.train()
        print(pretty_print(result_ppo))

        # Test passed gracefully.
        if args.as_test and \
                result_dqn["episode_reward_mean"] > args.stop_reward and \
                result_ppo["episode_reward_mean"] > args.stop_reward:
            print("test passed (both agents above requested reward)")
            quit(0)

        # Swap weights to synchronize the two trainers:
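        # `get_weights(["ppo_policy"])` returns only that policy's weights
        # (a dict keyed by policy ID), and `set_weights()` copies them into
        # the receiving trainer's otherwise untrained copy of that policy.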
        dqn_trainer.set_weights(ppo_trainer.get_weights(["ppo_policy"]))
        ppo_trainer.set_weights(dqn_trainer.get_weights(["dqn_policy"]))

    # Desired reward not reached.
    if args.as_test:
        raise ValueError("Desired reward ({}) not reached!".format(
            args.stop_reward))