"""Example showing how one can implement a simple self-play training workflow. Uses the open spiel adapter of RLlib with the "connect_four" game and a multi-agent setup with a "main" policy and n "main_v[x]" policies (x=version number), which are all at-some-point-frozen copies of "main". At the very beginning, "main" plays against RandomPolicy. Checks for the training progress after each training update via a custom callback. We simply measure the win rate of "main" vs the opponent ("main_v[x]" or RandomPolicy at the beginning) by looking through the achieved rewards in the episodes in the train batch. If this win rate reaches some configurable threshold, we add a new policy to the policy map (a frozen copy of the current "main" one) and change the policy_mapping_fn to make new matches of "main" vs any of the previous versions of "main" (including the just added one). After training for n iterations, a configurable number of episodes can be played by the user against the "main" agent on the command line. """ import argparse import numpy as np import os import pyspiel from open_spiel.python.rl_environment import Environment import sys import ray from ray import tune from ray.rllib.agents.callbacks import DefaultCallbacks from ray.rllib.agents.ppo import PPOTrainer from ray.rllib.examples.policy.random_policy import RandomPolicy from ray.rllib.env.wrappers.open_spiel import OpenSpielEnv from ray.rllib.policy.policy import PolicySpec from ray.tune import CLIReporter, register_env parser = argparse.ArgumentParser() parser.add_argument( "--framework", choices=["tf", "tf2", "tfe", "torch"], default="tf", help="The DL framework specifier.") parser.add_argument("--num-cpus", type=int, default=0) parser.add_argument("--num-workers", type=int, default=2) parser.add_argument( "--from-checkpoint", type=str, default=None, help="Full path to a checkpoint file for restoring a previously saved " "Trainer state.") parser.add_argument( "--env", type=str, default="connect_four", choices=["markov_soccer", "connect_four"]) parser.add_argument( "--stop-iters", type=int, default=200, help="Number of iterations to train.") parser.add_argument( "--stop-timesteps", type=int, default=10000000, help="Number of timesteps to train.") parser.add_argument( "--win-rate-threshold", type=float, default=0.95, help="Win-rate at which we setup another opponent by freezing the " "current main policy and playing against a uniform distribution " "of previously frozen 'main's from here on.") parser.add_argument( "--num-episodes-human-play", type=int, default=10, help="How many episodes to play against the user on the command " "line after training has finished.") args = parser.parse_args() def ask_user_for_action(time_step): """Asks the user for a valid action on the command line and returns it. Re-queries the user until she picks a valid one. Args: time_step: The open spiel Environment time-step object. """ pid = time_step.observations["current_player"] legal_moves = time_step.observations["legal_actions"][pid] choice = -1 while choice not in legal_moves: print("Choose an action from {}:".format(legal_moves)) sys.stdout.flush() choice_str = input() try: choice = int(choice_str) except ValueError: continue return choice class SelfPlayCallback(DefaultCallbacks): def __init__(self): super().__init__() # 0=RandomPolicy, 1=1st main policy snapshot, # 2=2nd main policy snapshot, etc.. self.current_opponent = 0 def on_train_result(self, *, trainer, result, **kwargs): # Get the win rate for the train batch. 
        main_rew = result["hist_stats"].pop("policy_main_reward")
        opponent_rew = list(result["hist_stats"].values())[0]
        assert len(main_rew) == len(opponent_rew)
        won = 0
        for r_main, r_opponent in zip(main_rew, opponent_rew):
            if r_main > r_opponent:
                won += 1
        win_rate = won / len(main_rew)
        result["win_rate"] = win_rate
        print(f"Iter={trainer.iteration} win-rate={win_rate} -> ", end="")
        # If win rate is good -> Snapshot current policy and play against
        # it next, keeping the snapshot fixed and only improving the "main"
        # policy.
        if win_rate > args.win_rate_threshold:
            self.current_opponent += 1
            new_pol_id = f"main_v{self.current_opponent}"
            print(f"adding new opponent to the mix ({new_pol_id}).")

            # Re-define the mapping function, such that "main" is forced
            # to play against any of the previously played policies
            # (excluding "random").
            def policy_mapping_fn(agent_id, episode, worker, **kwargs):
                # agent_id = [0|1] -> policy depends on episode ID
                # This way, we make sure that both policies sometimes play
                # agent0 (start player) and sometimes agent1 (player to
                # move 2nd).
                return "main" if episode.episode_id % 2 == agent_id \
                    else "main_v{}".format(np.random.choice(
                        list(range(1, self.current_opponent + 1))))

            new_policy = trainer.add_policy(
                policy_id=new_pol_id,
                policy_cls=type(trainer.get_policy("main")),
                policy_mapping_fn=policy_mapping_fn,
            )

            # Set the weights of the new policy to the main policy.
            # We'll keep training the main policy, whereas `new_pol_id` will
            # remain fixed.
            main_state = trainer.get_policy("main").get_state()
            new_policy.set_state(main_state)
            # We need to sync the just copied local weights (from main policy)
            # to all the remote workers as well.
            trainer.workers.sync_weights()
        else:
            print("not good enough; will keep learning ...")

        # +2 = main + random
        result["league_size"] = self.current_opponent + 2


if __name__ == "__main__":
    ray.init(num_cpus=args.num_cpus or None, include_dashboard=False)

    register_env("open_spiel_env",
                 lambda _: OpenSpielEnv(pyspiel.load_game(args.env)))

    def policy_mapping_fn(agent_id, episode, worker, **kwargs):
        # agent_id = [0|1] -> policy depends on episode ID
        # This way, we make sure that both policies sometimes play agent0
        # (start player) and sometimes agent1 (player to move 2nd).
        return "main" if episode.episode_id % 2 == agent_id else "random"

    config = {
        "env": "open_spiel_env",
        "callbacks": SelfPlayCallback,
        "model": {
            "fcnet_hiddens": [512, 512],
        },
        "num_sgd_iter": 20,
        "num_envs_per_worker": 5,
        "multiagent": {
            # Initial policy map: Random and PPO. This will be expanded
            # to more policy snapshots taken from "main" against which "main"
            # will then play (instead of "random"). This is done in the
            # custom callback defined above (`SelfPlayCallback`).
            "policies": {
                # Our main policy, we'd like to optimize.
                "main": PolicySpec(),
                # An initial random opponent to play against.
                "random": PolicySpec(policy_class=RandomPolicy),
            },
            # Assign agent 0 and 1 randomly to the "main" policy or
            # to the opponent ("random" at first). Make sure (via episode_id)
            # that "main" always plays against "random" (and not against
            # another "main").
            "policy_mapping_fn": policy_mapping_fn,
            # Always just train the "main" policy.
            "policies_to_train": ["main"],
        },
        "num_workers": args.num_workers,
        # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
        "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
        "framework": args.framework,
    }
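
    # Tune stops the run as soon as either one of the criteria below is
    # reached (whichever comes first).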
"num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")), "framework": args.framework, } stop = { "timesteps_total": args.stop_timesteps, "training_iteration": args.stop_iters, } # Train the "main" policy to play really well using self-play. results = None if not args.from_checkpoint: results = tune.run( "PPO", config=config, stop=stop, checkpoint_at_end=True, checkpoint_freq=10, verbose=2, progress_reporter=CLIReporter( metric_columns={ "training_iteration": "iter", "time_total_s": "time_total_s", "timesteps_total": "ts", "episodes_this_iter": "train_episodes", "policy_reward_mean/main": "reward", "win_rate": "win_rate", "league_size": "league_size", }, sort_by_metric=True, ), ) # Restore trained trainer (set to non-explore behavior) and play against # human on command line. if args.num_episodes_human_play > 0: num_episodes = 0 trainer = PPOTrainer(config=dict(config, **{"explore": False})) if args.from_checkpoint: trainer.restore(args.from_checkpoint) else: checkpoint = results.get_last_checkpoint() if not checkpoint: raise ValueError("No last checkpoint found in results!") trainer.restore(checkpoint) # Play from the command line against the trained agent # in an actual (non-RLlib-wrapped) open-spiel env. human_player = 1 env = Environment(args.env) while num_episodes < args.num_episodes_human_play: print("You play as {}".format("o" if human_player else "x")) time_step = env.reset() while not time_step.last(): player_id = time_step.observations["current_player"] if player_id == human_player: action = ask_user_for_action(time_step) else: obs = np.array( time_step.observations["info_state"][player_id]) action = trainer.compute_single_action( obs, policy_id="main") # In case computer chooses an invalid action, pick a # random one. legal = time_step.observations["legal_actions"][player_id] if action not in legal: action = np.random.choice(legal) time_step = env.step([action]) print(f"\n{env.get_state}") print(f"\n{env.get_state}") print("End of game!") if time_step.rewards[human_player] > 0: print("You win") elif time_step.rewards[human_player] < 0: print("You lose") else: print("Draw") # Switch order of players human_player = 1 - human_player num_episodes += 1 ray.shutdown()