- """Example showing how one can implement a simple self-play training workflow.
- Uses the open spiel adapter of RLlib with the "connect_four" game and
- a multi-agent setup with a "main" policy and n "main_v[x]" policies
- (x=version number), which are all at-some-point-frozen copies of
- "main". At the very beginning, "main" plays against RandomPolicy.
- Checks for the training progress after each training update via a custom
- callback. We simply measure the win rate of "main" vs the opponent
- ("main_v[x]" or RandomPolicy at the beginning) by looking through the
- achieved rewards in the episodes in the train batch. If this win rate
- reaches some configurable threshold, we add a new policy to
- the policy map (a frozen copy of the current "main" one) and change the
- policy_mapping_fn to make new matches of "main" vs any of the previous
- versions of "main" (including the just added one).
- After training for n iterations, a configurable number of episodes can
- be played by the user against the "main" agent on the command line.
- """
import argparse
import numpy as np
import os
import pyspiel
from open_spiel.python.rl_environment import Environment
import sys

import ray
from ray import tune
from ray.rllib.agents.callbacks import DefaultCallbacks
from ray.rllib.agents.ppo import PPOTrainer
from ray.rllib.examples.policy.random_policy import RandomPolicy
from ray.rllib.env.wrappers.open_spiel import OpenSpielEnv
from ray.rllib.policy.policy import PolicySpec
from ray.tune import CLIReporter, register_env
parser = argparse.ArgumentParser()
parser.add_argument(
    "--framework",
    choices=["tf", "tf2", "tfe", "torch"],
    default="tf",
    help="The DL framework specifier.")
parser.add_argument("--num-cpus", type=int, default=0)
parser.add_argument("--num-workers", type=int, default=2)
parser.add_argument(
    "--from-checkpoint",
    type=str,
    default=None,
    help="Full path to a checkpoint file for restoring a previously saved "
    "Trainer state.")
parser.add_argument(
    "--env",
    type=str,
    default="connect_four",
    choices=["markov_soccer", "connect_four"])
parser.add_argument(
    "--stop-iters",
    type=int,
    default=200,
    help="Number of iterations to train.")
parser.add_argument(
    "--stop-timesteps",
    type=int,
    default=10000000,
    help="Number of timesteps to train.")
parser.add_argument(
    "--win-rate-threshold",
    type=float,
    default=0.95,
    help="Win-rate at which we set up another opponent by freezing the "
    "current main policy and playing against a uniform distribution "
    "of previously frozen 'main's from here on.")
parser.add_argument(
    "--num-episodes-human-play",
    type=int,
    default=10,
    help="How many episodes to play against the user on the command "
    "line after training has finished.")

args = parser.parse_args()
def ask_user_for_action(time_step):
    """Asks the user for a valid action on the command line and returns it.

    Re-queries the user until she picks a valid one.

    Args:
        time_step: The open spiel Environment time-step object.
    """
    pid = time_step.observations["current_player"]
    legal_moves = time_step.observations["legal_actions"][pid]
    choice = -1
    while choice not in legal_moves:
        print("Choose an action from {}:".format(legal_moves))
        sys.stdout.flush()
        choice_str = input()
        try:
            choice = int(choice_str)
        except ValueError:
            continue
    return choice
class SelfPlayCallback(DefaultCallbacks):
    def __init__(self):
        super().__init__()
        # 0=RandomPolicy, 1=1st main policy snapshot,
        # 2=2nd main policy snapshot, etc.
        self.current_opponent = 0

    def on_train_result(self, *, trainer, result, **kwargs):
        # Get the win rate for the train batch.
        # Note that normally, one should set up a proper evaluation config,
        # such that evaluation always happens on the already updated policy,
        # instead of on the already used train_batch.
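        # A minimal sketch of such a setup (standard RLlib config keys; the
        # exact values are illustrative only):
        #   config["evaluation_interval"] = 1
        #   config["evaluation_num_episodes"] = 10
        #   config["evaluation_config"] = {"explore": False}
        # The win rate would then be read from `result["evaluation"]` rather
        # than from the train batch's hist_stats below.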
        main_rew = result["hist_stats"].pop("policy_main_reward")
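        # After popping "policy_main_reward", take the first remaining
        # per-policy reward list in hist_stats as the opponent's episode
        # rewards for this train batch.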
        opponent_rew = list(result["hist_stats"].values())[0]
        assert len(main_rew) == len(opponent_rew)
        won = 0
        for r_main, r_opponent in zip(main_rew, opponent_rew):
            if r_main > r_opponent:
                won += 1
        win_rate = won / len(main_rew)
        result["win_rate"] = win_rate
        print(f"Iter={trainer.iteration} win-rate={win_rate} -> ", end="")
        # If win rate is good -> Snapshot current policy and play against
        # it next, keeping the snapshot fixed and only improving the "main"
        # policy.
        if win_rate > args.win_rate_threshold:
            self.current_opponent += 1
            new_pol_id = f"main_v{self.current_opponent}"
            print(f"adding new opponent to the mix ({new_pol_id}).")

            # Re-define the mapping function, such that "main" is forced
            # to play against any of the previously played policies
            # (excluding "random").
            def policy_mapping_fn(agent_id, episode, worker, **kwargs):
                # agent_id = [0|1] -> policy depends on episode ID.
                # This way, we make sure that both policies sometimes play
                # agent0 (start player) and sometimes agent1 (player to
                # move 2nd).
                return "main" if episode.episode_id % 2 == agent_id \
                    else "main_v{}".format(np.random.choice(
                        list(range(1, self.current_opponent + 1))))

            new_policy = trainer.add_policy(
                policy_id=new_pol_id,
                policy_cls=type(trainer.get_policy("main")),
                policy_mapping_fn=policy_mapping_fn,
            )

            # Set the weights of the new policy to the main policy.
            # We'll keep training the main policy, whereas `new_pol_id` will
            # remain fixed.
            main_state = trainer.get_policy("main").get_state()
            new_policy.set_state(main_state)
            # We need to sync the just copied local weights (from main policy)
            # to all the remote workers as well.
            trainer.workers.sync_weights()
        else:
            print("not good enough; will keep learning ...")

        # +2 = main + random.
        result["league_size"] = self.current_opponent + 2
if __name__ == "__main__":
    ray.init(num_cpus=args.num_cpus or None, include_dashboard=False)

    register_env("open_spiel_env",
                 lambda _: OpenSpielEnv(pyspiel.load_game(args.env)))

    def policy_mapping_fn(agent_id, episode, worker, **kwargs):
        # agent_id = [0|1] -> policy depends on episode ID.
        # This way, we make sure that both policies sometimes play agent0
        # (start player) and sometimes agent1 (player to move 2nd).
        return "main" if episode.episode_id % 2 == agent_id else "random"
    config = {
        "env": "open_spiel_env",
        "callbacks": SelfPlayCallback,
        "model": {
            "fcnet_hiddens": [512, 512],
        },
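        # Run 20 SGD passes over each collected train batch (PPO-specific)
        # and sample from 5 vectorized sub-environments per rollout worker.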
- "num_sgd_iter": 20,
- "num_envs_per_worker": 5,
- "multiagent": {
- # Initial policy map: Random and PPO. This will be expanded
- # to more policy snapshots taken from "main" against which "main"
- # will then play (instead of "random"). This is done in the
- # custom callback defined above (`SelfPlayCallback`).
- "policies": {
- # Our main policy, we'd like to optimize.
- "main": PolicySpec(),
- # An initial random opponent to play against.
- "random": PolicySpec(policy_class=RandomPolicy),
- },
- # Assign agent 0 and 1 randomly to the "main" policy or
- # to the opponent ("random" at first). Make sure (via episode_id)
- # that "main" always plays against "random" (and not against
- # another "main").
- "policy_mapping_fn": policy_mapping_fn,
- # Always just train the "main" policy.
- "policies_to_train": ["main"],
- },
- "num_workers": args.num_workers,
- # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
- "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
- "framework": args.framework,
- }
- stop = {
- "timesteps_total": args.stop_timesteps,
- "training_iteration": args.stop_iters,
- }
- # Train the "main" policy to play really well using self-play.
- results = None
- if not args.from_checkpoint:
- results = tune.run(
- "PPO",
- config=config,
- stop=stop,
- checkpoint_at_end=True,
- checkpoint_freq=10,
- verbose=2,
- progress_reporter=CLIReporter(
- metric_columns={
- "training_iteration": "iter",
- "time_total_s": "time_total_s",
- "timesteps_total": "ts",
- "episodes_this_iter": "train_episodes",
- "policy_reward_mean/main": "reward",
- "win_rate": "win_rate",
- "league_size": "league_size",
- },
- sort_by_metric=True,
- ),
- )
    # Restore trained trainer (set to non-explore behavior) and play against
    # human on command line.
    if args.num_episodes_human_play > 0:
        num_episodes = 0
        trainer = PPOTrainer(config=dict(config, **{"explore": False}))
        if args.from_checkpoint:
            trainer.restore(args.from_checkpoint)
        else:
            checkpoint = results.get_last_checkpoint()
            if not checkpoint:
                raise ValueError("No last checkpoint found in results!")
            trainer.restore(checkpoint)

        # Play from the command line against the trained agent
        # in an actual (non-RLlib-wrapped) open-spiel env.
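        # The human starts as player 1 ("o"); sides are swapped after every
        # episode (see `human_player = 1 - human_player` at the bottom of
        # the loop).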
        human_player = 1
        env = Environment(args.env)

        while num_episodes < args.num_episodes_human_play:
            print("You play as {}".format("o" if human_player else "x"))
            time_step = env.reset()
            while not time_step.last():
                player_id = time_step.observations["current_player"]
                if player_id == human_player:
                    action = ask_user_for_action(time_step)
                else:
                    obs = np.array(
                        time_step.observations["info_state"][player_id])
                    action = trainer.compute_single_action(
                        obs, policy_id="main")
                    # In case computer chooses an invalid action, pick a
                    # random one.
                    legal = time_step.observations["legal_actions"][player_id]
                    if action not in legal:
                        action = np.random.choice(legal)
                time_step = env.step([action])
                print(f"\n{env.get_state}")

            print(f"\n{env.get_state}")
            print("End of game!")
            if time_step.rewards[human_player] > 0:
                print("You win")
            elif time_step.rewards[human_player] < 0:
                print("You lose")
            else:
                print("Draw")
            # Switch order of players.
            human_player = 1 - human_player
            num_episodes += 1

    ray.shutdown()