  1. """Example showing how one can restore a connector enabled TF policy
  2. checkpoint for a new self-play PyTorch training job.
  3. The checkpointed policy may be trained with a different algorithm too.
  4. """
import argparse
from functools import partial
import os
import tempfile

import ray
from ray import air, tune
from ray.rllib.algorithms.callbacks import DefaultCallbacks
from ray.rllib.algorithms.sac import SACConfig
from ray.rllib.env.utils import try_import_pyspiel
from ray.rllib.env.wrappers.open_spiel import OpenSpielEnv
from ray.rllib.examples.connectors.prepare_checkpoint import (
    create_open_spiel_checkpoint,
)
from ray.rllib.policy.policy import Policy
from ray.tune import CLIReporter, register_env

pyspiel = try_import_pyspiel(error=True)

register_env(
    "open_spiel_env", lambda _: OpenSpielEnv(pyspiel.load_game("connect_four"))
)

parser = argparse.ArgumentParser()
parser.add_argument(
    "--train_iteration",
    type=int,
    default=10,
    help="Number of iterations to train.",
)
args = parser.parse_args()
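
# Example invocation (any iteration count works; 20 is just for illustration):
#   python self_play_with_policy_checkpoint.py --train_iteration=20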

MAIN_POLICY_ID = "main"
OPPONENT_POLICY_ID = "opponent"


class AddPolicyCallback(DefaultCallbacks):
    def __init__(self, checkpoint_dir):
        self._checkpoint_dir = checkpoint_dir
        super().__init__()

    def on_algorithm_init(self, *, algorithm, **kwargs):
        # Restore the "opponent" policy from the given checkpoint.
        policy = Policy.from_checkpoint(
            self._checkpoint_dir, policy_ids=[OPPONENT_POLICY_ID]
        )

        # Add the restored policy to the Algorithm.
        # Note that this policy doesn't have to be trained with the same
        # algorithm as the training stack. You can even mix up TF policies
        # with a Torch stack.
        algorithm.add_policy(
            policy_id=OPPONENT_POLICY_ID,
            policy=policy,
            evaluation_workers=True,
        )
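
# Note: The restored "opponent" policy is only used to compute actions during
# rollouts; it is never updated, since `policies_to_train` in the config below
# only lists the "main" policy.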


def policy_mapping_fn(agent_id, episode, worker, **kwargs):
    # "main" policy plays against the "opponent" policy.
    return (
        MAIN_POLICY_ID
        if episode.episode_id % 2 == agent_id
        else OPPONENT_POLICY_ID
    )
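
# For example, in an episode with episode_id=7 (odd), agent 1 is mapped to
# "main" (7 % 2 == 1 == agent_id) and agent 0 to "opponent"; in even-numbered
# episodes the assignment flips, so "main" plays both seats over time.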


def main(checkpoint_dir):
    config = (
        SACConfig()
        .environment("open_spiel_env")
        .framework("torch")
        .callbacks(partial(AddPolicyCallback, checkpoint_dir))
        .rollouts(
            num_rollout_workers=1,
            num_envs_per_worker=5,
            # We will be restoring a TF2 policy. So tell the RolloutWorkers
            # to enable TF eager exec as well, even if framework is set to
            # torch.
            enable_tf1_exec_eagerly=True,
        )
        .training(model={"fcnet_hiddens": [512, 512]})
        .multi_agent(
            # The initial policy map only contains the trainable "main"
            # policy. The restored "opponent" policy gets added via
            # `AddPolicyCallback.on_algorithm_init` above.
            policies={MAIN_POLICY_ID},
            # Assign agents 0 and 1 to "main" or "opponent" based on the
            # episode ID, making sure that "main" always plays against
            # "opponent" (and never against another "main").
            policy_mapping_fn=policy_mapping_fn,
            # Always just train the "main" policy.
            policies_to_train=[MAIN_POLICY_ID],
        )
    )

    stop = {
        "training_iteration": args.train_iteration,
    }

    # Train the "main" policy to play really well using self-play.
    tuner = tune.Tuner(
        "SAC",
        param_space=config.to_dict(),
        run_config=air.RunConfig(
            stop=stop,
            checkpoint_config=air.CheckpointConfig(
                checkpoint_at_end=True,
                checkpoint_frequency=10,
            ),
            verbose=2,
            progress_reporter=CLIReporter(
                metric_columns={
                    "training_iteration": "iter",
                    "time_total_s": "time_total_s",
                    "timesteps_total": "ts",
                    "episodes_this_iter": "train_episodes",
                    "policy_reward_mean/main": "reward_main",
                },
                sort_by_metric=True,
            ),
        ),
    )
    tuner.fit()


if __name__ == "__main__":
    ray.init()

    with tempfile.TemporaryDirectory() as tmpdir:
        # First, write a connector-enabled TF policy checkpoint into the
        # temporary directory.
        create_open_spiel_checkpoint(tmpdir)

        # Path of the "opponent" policy's sub-directory within that
        # checkpoint.
        policy_checkpoint_path = os.path.join(
            tmpdir,
            "checkpoint_000000",
            "policies",
            OPPONENT_POLICY_ID,
        )

        main(policy_checkpoint_path)

    ray.shutdown()
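
# A minimal sketch of how the trained "main" policy could later be restored
# from one of the algorithm checkpoints written by Tune above, mirroring how
# the "opponent" was loaded in `AddPolicyCallback`. `algo_checkpoint_dir` is a
# placeholder for wherever Tune stored the checkpoint:
#
#   main_policy = Policy.from_checkpoint(
#       os.path.join(algo_checkpoint_dir, "policies", MAIN_POLICY_ID)
#   )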