self_play_with_open_spiel.py

  1. """Example showing how one can implement a simple self-play training workflow.
  2. Uses the open spiel adapter of RLlib with the "connect_four" game and
  3. a multi-agent setup with a "main" policy and n "main_v[x]" policies
  4. (x=version number), which are all at-some-point-frozen copies of
  5. "main". At the very beginning, "main" plays against RandomPolicy.
  6. Checks for the training progress after each training update via a custom
  7. callback. We simply measure the win rate of "main" vs the opponent
  8. ("main_v[x]" or RandomPolicy at the beginning) by looking through the
  9. achieved rewards in the episodes in the train batch. If this win rate
  10. reaches some configurable threshold, we add a new policy to
  11. the policy map (a frozen copy of the current "main" one) and change the
  12. policy_mapping_fn to make new matches of "main" vs any of the previous
  13. versions of "main" (including the just added one).
  14. After training for n iterations, a configurable number of episodes can
  15. be played by the user against the "main" agent on the command line.
  16. """

import argparse
import numpy as np
import os
import pyspiel
from open_spiel.python.rl_environment import Environment
import sys

import ray
from ray import tune
from ray.rllib.agents.callbacks import DefaultCallbacks
from ray.rllib.agents.ppo import PPOTrainer
from ray.rllib.examples.policy.random_policy import RandomPolicy
from ray.rllib.env.wrappers.open_spiel import OpenSpielEnv
from ray.rllib.policy.policy import PolicySpec
from ray.tune import CLIReporter, register_env

parser = argparse.ArgumentParser()
parser.add_argument(
    "--framework",
    choices=["tf", "tf2", "tfe", "torch"],
    default="tf",
    help="The DL framework specifier.")
parser.add_argument("--num-cpus", type=int, default=0)
parser.add_argument("--num-workers", type=int, default=2)
parser.add_argument(
    "--from-checkpoint",
    type=str,
    default=None,
    help="Full path to a checkpoint file for restoring a previously saved "
    "Trainer state.")
parser.add_argument(
    "--env",
    type=str,
    default="connect_four",
    choices=["markov_soccer", "connect_four"])
parser.add_argument(
    "--stop-iters",
    type=int,
    default=200,
    help="Number of iterations to train.")
parser.add_argument(
    "--stop-timesteps",
    type=int,
    default=10000000,
    help="Number of timesteps to train.")
parser.add_argument(
    "--win-rate-threshold",
    type=float,
    default=0.95,
    help="Win-rate at which we set up another opponent by freezing the "
    "current main policy and playing against a uniform distribution "
    "of previously frozen 'main's from here on.")
parser.add_argument(
    "--num-episodes-human-play",
    type=int,
    default=10,
    help="How many episodes to play against the user on the command "
    "line after training has finished.")

args = parser.parse_args()


def ask_user_for_action(time_step):
    """Asks the user for a valid action on the command line and returns it.

    Re-queries the user until she picks a valid one.

    Args:
        time_step: The open spiel Environment time-step object.
    """
    pid = time_step.observations["current_player"]
    legal_moves = time_step.observations["legal_actions"][pid]
    choice = -1
    while choice not in legal_moves:
        print("Choose an action from {}:".format(legal_moves))
        sys.stdout.flush()
        choice_str = input()
        try:
            choice = int(choice_str)
        except ValueError:
            continue
    return choice
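
# Illustrative interaction for the default "connect_four" env: the legal
# actions are typically the seven column indices, so a prompt may look like
# this (the exact list depends on the board state):
#   Choose an action from [0, 1, 2, 3, 4, 5, 6]:
#   3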


class SelfPlayCallback(DefaultCallbacks):
    def __init__(self):
        super().__init__()
        # 0=RandomPolicy, 1=1st main policy snapshot,
        # 2=2nd main policy snapshot, etc.
        self.current_opponent = 0

    def on_train_result(self, *, trainer, result, **kwargs):
        # Get the win rate for the train batch.
        # Note that normally, one should set up a proper evaluation config,
        # such that evaluation always happens on the already updated policy,
        # instead of on the already used train_batch.
        main_rew = result["hist_stats"].pop("policy_main_reward")
        opponent_rew = list(result["hist_stats"].values())[0]
        assert len(main_rew) == len(opponent_rew)
        won = 0
        for r_main, r_opponent in zip(main_rew, opponent_rew):
            if r_main > r_opponent:
                won += 1
        win_rate = won / len(main_rew)
        result["win_rate"] = win_rate
        print(f"Iter={trainer.iteration} win-rate={win_rate} -> ", end="")

        # If win rate is good -> Snapshot current policy and play against
        # it next, keeping the snapshot fixed and only improving the "main"
        # policy.
        if win_rate > args.win_rate_threshold:
            self.current_opponent += 1
            new_pol_id = f"main_v{self.current_opponent}"
            print(f"adding new opponent to the mix ({new_pol_id}).")

            # Re-define the mapping function, such that "main" is forced
            # to play against any of the previously frozen policies
            # (excluding "random").
            def policy_mapping_fn(agent_id, episode, worker, **kwargs):
                # agent_id = [0|1] -> policy depends on episode ID.
                # This way, we make sure that "main" sometimes plays agent0
                # (start player) and sometimes agent1 (player to move 2nd).
                return "main" if episode.episode_id % 2 == agent_id \
                    else "main_v{}".format(np.random.choice(
                        list(range(1, self.current_opponent + 1))))

            new_policy = trainer.add_policy(
                policy_id=new_pol_id,
                policy_cls=type(trainer.get_policy("main")),
                policy_mapping_fn=policy_mapping_fn,
            )

            # Set the weights of the new policy to the main policy.
            # We'll keep training the main policy, whereas `new_pol_id` will
            # remain fixed.
            main_state = trainer.get_policy("main").get_state()
            new_policy.set_state(main_state)
            # We need to sync the just copied local weights (from main policy)
            # to all the remote workers as well.
            trainer.workers.sync_weights()
        else:
            print("not good enough; will keep learning ...")

        # +2 = main + random
        result["league_size"] = self.current_opponent + 2


if __name__ == "__main__":
    ray.init(num_cpus=args.num_cpus or None, include_dashboard=False)

    register_env("open_spiel_env",
                 lambda _: OpenSpielEnv(pyspiel.load_game(args.env)))

    def policy_mapping_fn(agent_id, episode, worker, **kwargs):
        # agent_id = [0|1] -> policy depends on episode ID.
        # This way, we make sure that both policies sometimes play agent0
        # (start player) and sometimes agent1 (player to move 2nd).
        return "main" if episode.episode_id % 2 == agent_id else "random"

    config = {
        "env": "open_spiel_env",
        "callbacks": SelfPlayCallback,
        "model": {
            "fcnet_hiddens": [512, 512],
        },
        "num_sgd_iter": 20,
        "num_envs_per_worker": 5,
        "multiagent": {
            # Initial policy map: Random and PPO. This will be expanded
            # to more policy snapshots taken from "main" against which "main"
            # will then play (instead of "random"). This is done in the
            # custom callback defined above (`SelfPlayCallback`).
            "policies": {
                # Our main policy, the one we'd like to optimize.
                "main": PolicySpec(),
                # An initial random opponent to play against.
                "random": PolicySpec(policy_class=RandomPolicy),
            },
            # Assign agent 0 and 1 randomly to the "main" policy or
            # to the opponent ("random" at first). Make sure (via episode_id)
            # that "main" always plays against "random" (and not against
            # another "main").
            "policy_mapping_fn": policy_mapping_fn,
            # Always just train the "main" policy.
            "policies_to_train": ["main"],
        },
        "num_workers": args.num_workers,
        # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
        "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
        "framework": args.framework,
    }

    stop = {
        "timesteps_total": args.stop_timesteps,
        "training_iteration": args.stop_iters,
    }

    # Train the "main" policy to play really well using self-play.
    results = None
    if not args.from_checkpoint:
        results = tune.run(
            "PPO",
            config=config,
            stop=stop,
            checkpoint_at_end=True,
            checkpoint_freq=10,
            verbose=2,
            progress_reporter=CLIReporter(
                metric_columns={
                    "training_iteration": "iter",
                    "time_total_s": "time_total_s",
                    "timesteps_total": "ts",
                    "episodes_this_iter": "train_episodes",
                    "policy_reward_mean/main": "reward",
                    "win_rate": "win_rate",
                    "league_size": "league_size",
                },
                sort_by_metric=True,
            ),
        )

    # Restore trained trainer (set to non-explore behavior) and play against
    # human on command line.
    if args.num_episodes_human_play > 0:
        num_episodes = 0
        trainer = PPOTrainer(config=dict(config, **{"explore": False}))
        if args.from_checkpoint:
            trainer.restore(args.from_checkpoint)
        else:
            checkpoint = results.get_last_checkpoint()
            if not checkpoint:
                raise ValueError("No last checkpoint found in results!")
            trainer.restore(checkpoint)

        # Play from the command line against the trained agent
        # in an actual (non-RLlib-wrapped) open-spiel env.
        human_player = 1
        env = Environment(args.env)

        while num_episodes < args.num_episodes_human_play:
            print("You play as {}".format("o" if human_player else "x"))
            time_step = env.reset()
            while not time_step.last():
                player_id = time_step.observations["current_player"]
                if player_id == human_player:
                    action = ask_user_for_action(time_step)
                else:
                    obs = np.array(
                        time_step.observations["info_state"][player_id])
                    action = trainer.compute_single_action(
                        obs, policy_id="main")
                    # In case computer chooses an invalid action, pick a
                    # random one.
                    legal = time_step.observations["legal_actions"][player_id]
                    if action not in legal:
                        action = np.random.choice(legal)
                time_step = env.step([action])
                print(f"\n{env.get_state}")

            print(f"\n{env.get_state}")

            print("End of game!")
            if time_step.rewards[human_player] > 0:
                print("You win")
            elif time_step.rewards[human_player] < 0:
                print("You lose")
            else:
                print("Draw")
            # Switch order of players.
            human_player = 1 - human_player
            num_episodes += 1

    ray.shutdown()