restore_1_of_n_agents_from_checkpoint.py

  1. """Simple example of how to restore only one of n agents from a trained
  2. multi-agent Trainer using Ray tune.
  3. The trick/workaround is to use an intermediate trainer that loads the
  4. trained checkpoint into all policies and then reverts those policies
  5. that we don't want to restore, then saves a new checkpoint, from which
  6. tune can pick up training.
  7. Control the number of agents and policies via --num-agents and --num-policies.
  8. """

import argparse
import gym
import os
import random

import ray
from ray import tune
from ray.rllib.agents.ppo import PPOTrainer
from ray.rllib.examples.env.multi_agent import MultiAgentCartPole
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.test_utils import check_learning_achieved

tf1, tf, tfv = try_import_tf()

parser = argparse.ArgumentParser()
parser.add_argument("--num-agents", type=int, default=4)
parser.add_argument("--num-policies", type=int, default=2)
parser.add_argument("--pre-training-iters", type=int, default=5)
parser.add_argument("--num-cpus", type=int, default=0)
parser.add_argument(
    "--framework",
    choices=["tf", "tf2", "tfe", "torch"],
    default="tf",
    help="The DL framework specifier.")
parser.add_argument(
    "--as-test",
    action="store_true",
    help="Whether this script should be run as a test: --stop-reward must "
    "be achieved within --stop-timesteps AND --stop-iters.")
parser.add_argument(
    "--stop-iters",
    type=int,
    default=200,
    help="Number of iterations to train.")
parser.add_argument(
    "--stop-timesteps",
    type=int,
    default=100000,
    help="Number of timesteps to train.")
parser.add_argument(
    "--stop-reward",
    type=float,
    default=150.0,
    help="Reward at which we stop training.")

if __name__ == "__main__":
    args = parser.parse_args()

    ray.init(num_cpus=args.num_cpus or None)

    # Get obs- and action Spaces.
    single_env = gym.make("CartPole-v0")
    obs_space = single_env.observation_space
    act_space = single_env.action_space

    # Setup PPO with an ensemble of `num_policies` different policies.
    policies = {
        f"policy_{i}": (None, obs_space, act_space, {})
        for i in range(args.num_policies)
    }
    policy_ids = list(policies.keys())
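
    # Note: Each policy spec above is a (policy_cls, obs_space, act_space,
    # config_overrides) tuple; `None` as the class means "use the Trainer's
    # default policy" (PPO's, in this case). Agents are bound to policies by
    # the mapping function below, which here picks a policy at random.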
    def policy_mapping_fn(agent_id, episode, worker, **kwargs):
        pol_id = random.choice(policy_ids)
        return pol_id

    config = {
        "env": MultiAgentCartPole,
        "env_config": {
            "num_agents": args.num_agents,
        },
        # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
        "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
        "num_sgd_iter": 10,
        "multiagent": {
            "policies": policies,
            "policy_mapping_fn": policy_mapping_fn,
        },
        "framework": args.framework,
    }
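
    # The very same config is reused for the pre-training run, the
    # intermediate "fix-up" Trainer, and the final run below, so all three
    # agree on policy IDs and spaces; only `policies_to_train` is added
    # before the final run.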

    # Do some training and store the checkpoint.
    results = tune.run(
        "PPO",
        config=config,
        stop={"training_iteration": args.pre_training_iters},
        verbose=1,
        checkpoint_freq=1,
        checkpoint_at_end=True,
    )
    print("Pre-training done.")

    best_checkpoint = results.get_best_checkpoint(
        results.trials[0], mode="max")
    print(f".. best checkpoint was: {best_checkpoint}")

    # Create a new dummy Trainer to "fix" our checkpoint.
    new_trainer = PPOTrainer(config=config)
    # Get the untrained weights for all policies.
    untrained_weights = new_trainer.get_weights()
    # Restore all policies from the checkpoint.
    new_trainer.restore(best_checkpoint)
    # Set all weights (except for the 1st policy, "policy_0") back to their
    # original, untrained values.
    new_trainer.set_weights(
        {pid: w
         for pid, w in untrained_weights.items() if pid != "policy_0"})
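
    # At this point, policy_0 still carries the trained weights loaded from
    # the checkpoint, while every other policy is back at its untrained
    # initialization.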

    # Create the checkpoint from which tune can pick up the experiment.
    new_checkpoint = new_trainer.save()
    new_trainer.stop()
    print(".. checkpoint to restore from (all policies reset, "
          f"except policy_0): {new_checkpoint}")

    print("Starting new tune.run")

    # Start our actual experiment.
    stop = {
        "episode_reward_mean": args.stop_reward,
        "timesteps_total": args.stop_timesteps,
        "training_iteration": args.stop_iters,
    }

    # Freeze the restored policy_0 by excluding it from `policies_to_train`;
    # only the freshly reset policies keep learning.
    config["multiagent"]["policies_to_train"] = [
        pid for pid in policy_ids if pid != "policy_0"
    ]
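
    # `restore=new_checkpoint` makes Tune pick the experiment up from the
    # "fixed" checkpoint instead of starting from scratch.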
    results = tune.run(
        "PPO",
        stop=stop,
        config=config,
        verbose=1,
        restore=new_checkpoint,
    )

    if args.as_test:
        check_learning_achieved(results, args.stop_reward)

    ray.shutdown()