multi_agent_cartpole.py

  1. """Simple example of setting up a multi-agent policy mapping.
  2. Control the number of agents and policies via --num-agents and --num-policies.
  3. This works with hundreds of agents and policies, but note that initializing
  4. many TF policies will take some time.
  5. Also, TF evals might slow down with large numbers of policies. To debug TF
  6. execution, set the TF_TIMELINE_DIR environment variable.
  7. """

import argparse
import os
import random

import ray
from ray import tune
from ray.rllib.examples.env.multi_agent import MultiAgentCartPole
from ray.rllib.examples.models.shared_weights_model import \
    SharedWeightsModel1, SharedWeightsModel2, TF2SharedWeightsModel, \
    TorchSharedWeightsModel
from ray.rllib.models import ModelCatalog
from ray.rllib.policy.policy import PolicySpec
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.test_utils import check_learning_achieved
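
# `try_import_tf()` yields the tf.compat.v1 module, the tf module, and the
# installed TF version; per RLlib's framework utils it does not hard-fail
# when TensorFlow is absent (e.g. in a torch-only setup).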
tf1, tf, tfv = try_import_tf()

parser = argparse.ArgumentParser()
parser.add_argument("--num-agents", type=int, default=4)
parser.add_argument("--num-policies", type=int, default=2)
parser.add_argument("--num-cpus", type=int, default=0)
parser.add_argument(
    "--framework",
    choices=["tf", "tf2", "tfe", "torch"],
    default="tf",
    help="The DL framework specifier.")
parser.add_argument(
    "--as-test",
    action="store_true",
    help="Whether this script should be run as a test: --stop-reward must "
    "be achieved within --stop-timesteps AND --stop-iters.")
parser.add_argument(
    "--stop-iters",
    type=int,
    default=200,
    help="Number of iterations to train.")
parser.add_argument(
    "--stop-timesteps",
    type=int,
    default=100000,
    help="Number of timesteps to train.")
parser.add_argument(
    "--stop-reward",
    type=float,
    default=150.0,
    help="Reward at which we stop training.")

if __name__ == "__main__":
    args = parser.parse_args()

    ray.init(num_cpus=args.num_cpus or None)

    # Register the models to use.
    if args.framework == "torch":
        mod1 = mod2 = TorchSharedWeightsModel
    elif args.framework in ["tfe", "tf2"]:
        mod1 = mod2 = TF2SharedWeightsModel
    else:
        mod1 = SharedWeightsModel1
        mod2 = SharedWeightsModel2
    ModelCatalog.register_custom_model("model1", mod1)
    ModelCatalog.register_custom_model("model2", mod2)

    # Each policy can have a different configuration (including custom model).
    def gen_policy(i):
        config = {
            "model": {
                "custom_model": ["model1", "model2"][i % 2],
            },
            "gamma": random.choice([0.95, 0.99]),
        }
        return PolicySpec(config=config)
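
    # Leaving `policy_class`, `observation_space`, and `action_space` unset in
    # the PolicySpec above lets RLlib fill them in from the algorithm's default
    # policy and the env's spaces (an assumption about RLlib's defaults here).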

    # Set up PPO with an ensemble of `num_policies` different policies.
    policies = {
        "policy_{}".format(i): gen_policy(i)
        for i in range(args.num_policies)
    }
    policy_ids = list(policies.keys())
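
    # RLlib calls this function to decide which policy controls which agent;
    # here, each agent gets bound to a randomly chosen policy from the
    # ensemble above.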
    def policy_mapping_fn(agent_id, episode, worker, **kwargs):
        pol_id = random.choice(policy_ids)
        return pol_id

    config = {
        "env": MultiAgentCartPole,
        "env_config": {
            "num_agents": args.num_agents,
        },
        # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
        "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
        "num_sgd_iter": 10,
        "multiagent": {
            "policies": policies,
            "policy_mapping_fn": policy_mapping_fn,
        },
        "framework": args.framework,
    }

    stop = {
        "episode_reward_mean": args.stop_reward,
        "timesteps_total": args.stop_timesteps,
        "training_iteration": args.stop_iters,
    }

    results = tune.run("PPO", stop=stop, config=config, verbose=1)
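
    # `check_learning_achieved` raises an error if the best trial did not reach
    # `--stop-reward`, which is what makes --as-test runs fail loudly.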
    if args.as_test:
        check_learning_achieved(results, args.stop_reward)
    ray.shutdown()