- """Example showing how one can restore a connector enabled TF policy
- checkpoint for a new self-play PyTorch training job.
- The checkpointed policy may be trained with a different algorithm too.
- """

import argparse
from functools import partial
import os
import tempfile

import ray
from ray import air, tune
from ray.rllib.algorithms.callbacks import DefaultCallbacks
from ray.rllib.algorithms.sac import SACConfig
from ray.rllib.env.utils import try_import_pyspiel
from ray.rllib.env.wrappers.open_spiel import OpenSpielEnv
from ray.rllib.examples.connectors.prepare_checkpoint import (
    create_open_spiel_checkpoint,
)
from ray.rllib.policy.policy import Policy
from ray.tune import CLIReporter, register_env

pyspiel = try_import_pyspiel(error=True)

# Register the OpenSpiel Connect Four game as an RLlib environment.
register_env(
    "open_spiel_env", lambda _: OpenSpielEnv(pyspiel.load_game("connect_four"))
)
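
# Connect Four is a two-player zero-sum game; the OpenSpielEnv wrapper exposes
# the two seats as integer agent IDs 0 and 1, which the `policy_mapping_fn`
# defined below relies on when alternating policy assignments per episode.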

parser = argparse.ArgumentParser()
parser.add_argument(
    "--train_iteration",
    type=int,
    default=10,
    help="Number of iterations to train.",
)
args = parser.parse_args()

MAIN_POLICY_ID = "main"
OPPONENT_POLICY_ID = "opponent"


class AddPolicyCallback(DefaultCallbacks):
    """Restores the "opponent" policy from a checkpoint on algorithm init."""

    def __init__(self, checkpoint_dir):
        self._checkpoint_dir = checkpoint_dir
        super().__init__()

    def on_algorithm_init(self, *, algorithm, **kwargs):
        policy = Policy.from_checkpoint(
            self._checkpoint_dir, policy_ids=[OPPONENT_POLICY_ID]
        )

        # Add the restored policy to the Algorithm.
        # Note that this policy doesn't have to be trained with the same
        # algorithm as the rest of the training stack. You can even mix TF
        # policies into a Torch stack.
        algorithm.add_policy(
            policy_id=OPPONENT_POLICY_ID,
            policy=policy,
            evaluation_workers=True,
        )
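
# Note: `Policy.from_checkpoint()` can also restore from a full Algorithm
# checkpoint, in which case it returns a dict mapping policy IDs to the
# restored policies. A minimal sketch, assuming `algo_checkpoint_dir` points
# at an Algorithm-level checkpoint directory:
#
#   policies = Policy.from_checkpoint(algo_checkpoint_dir)
#   opponent = policies[OPPONENT_POLICY_ID]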


def policy_mapping_fn(agent_id, episode, worker, **kwargs):
    # The "main" policy plays against the "opponent" policy.
    return (
        MAIN_POLICY_ID
        if episode.episode_id % 2 == agent_id
        else OPPONENT_POLICY_ID
    )
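
# A worked example of the mapping above: in an episode with an even
# `episode_id`, agent 0 is assigned to "main" (0 == 0) and agent 1 to
# "opponent"; with an odd `episode_id`, the seats are swapped. "main" thus
# alternates between playing first and second, but never plays against itself.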


def main(checkpoint_dir):
    config = (
        SACConfig()
        .environment("open_spiel_env")
        .framework("torch")
        .callbacks(partial(AddPolicyCallback, checkpoint_dir))
        .rollouts(
            num_rollout_workers=1,
            num_envs_per_worker=5,
            # We will be restoring a TF2 policy, so tell the RolloutWorkers
            # to enable TF eager execution as well, even though the framework
            # is set to torch.
            enable_tf1_exec_eagerly=True,
        )
        .training(model={"fcnet_hiddens": [512, 512]})
        .multi_agent(
            # The initial policy map only contains the "main" policy we'd
            # like to optimize. The restored "opponent" policy is added on
            # algorithm init by the `AddPolicyCallback` defined above.
            policies={MAIN_POLICY_ID},
            # Assign agents 0 and 1 to the "main" and "opponent" policies.
            # The episode_id based mapping ensures that "main" always plays
            # against "opponent" (and not against another "main").
            policy_mapping_fn=policy_mapping_fn,
            # Always just train the "main" policy.
            policies_to_train=[MAIN_POLICY_ID],
        )
    )

    stop = {
        "training_iteration": args.train_iteration,
    }

    # Train the "main" policy to play really well using self-play.
    tuner = tune.Tuner(
        "SAC",
        param_space=config.to_dict(),
        run_config=air.RunConfig(
            stop=stop,
            checkpoint_config=air.CheckpointConfig(
                checkpoint_at_end=True,
                checkpoint_frequency=10,
            ),
            verbose=2,
            progress_reporter=CLIReporter(
                metric_columns={
                    "training_iteration": "iter",
                    "time_total_s": "time_total_s",
                    "timesteps_total": "ts",
                    "episodes_this_iter": "train_episodes",
                    "policy_reward_mean/main": "reward_main",
                },
                sort_by_metric=True,
            ),
        ),
    )
    tuner.fit()
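
# Note: `tuner.fit()` returns a `ray.tune.ResultGrid`. If you capture it
# (e.g. `results = tuner.fit()`), you can inspect the best trial afterwards.
# A minimal sketch, assuming the metric column configured above:
#
#   best = results.get_best_result(metric="policy_reward_mean/main", mode="max")
#   print(best.metrics)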


if __name__ == "__main__":
    ray.init()

    with tempfile.TemporaryDirectory() as tmpdir:
        # Create a temporary, connector-enabled TF policy checkpoint to
        # restore the "opponent" policy from.
        create_open_spiel_checkpoint(tmpdir)

        policy_checkpoint_path = os.path.join(
            tmpdir,
            "checkpoint_000000",
            "policies",
            OPPONENT_POLICY_ID,
        )

        main(policy_checkpoint_path)

    ray.shutdown()