- """Example of using two different training methods at once in multi-agent.
- Here we create a number of CartPole agents, some of which are trained with
- DQN, and some of which are trained with PPO. We periodically sync weights
- between the two trainers (note that no such syncing is needed when using just
- a single training method).
- For a simpler example, see also: multiagent_cartpole.py
- """
import argparse
import gym
import os

import ray
from ray.rllib.agents.dqn import DQNTrainer, DQNTFPolicy, DQNTorchPolicy
from ray.rllib.agents.ppo import PPOTrainer, PPOTFPolicy, PPOTorchPolicy
from ray.rllib.examples.env.multi_agent import MultiAgentCartPole
from ray.tune.logger import pretty_print
from ray.tune.registry import register_env

parser = argparse.ArgumentParser()
# The DL framework to use (shared by both policies).
parser.add_argument(
    "--framework",
    choices=["tf", "tf2", "tfe", "torch"],
    default="tf",
    help="The DL framework specifier.")
parser.add_argument(
    "--as-test",
    action="store_true",
    help="Whether this script should be run as a test: --stop-reward must "
    "be achieved within --stop-timesteps AND --stop-iters.")
parser.add_argument(
    "--stop-iters",
    type=int,
    default=20,
    help="Number of iterations to train.")
parser.add_argument(
    "--stop-timesteps",
    type=int,
    default=100000,
    help="Number of timesteps to train.")
parser.add_argument(
    "--stop-reward",
    type=float,
    default=50.0,
    help="Reward at which we stop training.")

if __name__ == "__main__":
    args = parser.parse_args()
    ray.init()

    # Simple environment with 4 independent cartpole entities.
    register_env("multi_agent_cartpole",
                 lambda _: MultiAgentCartPole({"num_agents": 4}))
    single_dummy_env = gym.make("CartPole-v0")
    obs_space = single_dummy_env.observation_space
    act_space = single_dummy_env.action_space
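
    # Each agent inside MultiAgentCartPole is an ordinary CartPole-v0, so the
    # single dummy env above provides the correct obs/action spaces for both
    # policies.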

    # You can also have multiple policies per trainer, but here we just
    # show one each for PPO and DQN.
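    # Each policy spec is a tuple of
    # (policy_cls, observation_space, action_space, extra_config).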
    policies = {
        "ppo_policy": (PPOTorchPolicy if args.framework == "torch" else
                       PPOTFPolicy, obs_space, act_space, {}),
        "dqn_policy": (DQNTorchPolicy if args.framework == "torch" else
                       DQNTFPolicy, obs_space, act_space, {}),
    }
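
    # Agent IDs in MultiAgentCartPole are the integers 0..num_agents-1, so
    # the modulo below routes even-numbered agents to PPO and odd-numbered
    # ones to DQN.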
    def policy_mapping_fn(agent_id, episode, worker, **kwargs):
        if agent_id % 2 == 0:
            return "ppo_policy"
        else:
            return "dqn_policy"

    ppo_trainer = PPOTrainer(
        env="multi_agent_cartpole",
        config={
            "multiagent": {
                "policies": policies,
                "policy_mapping_fn": policy_mapping_fn,
                "policies_to_train": ["ppo_policy"],
            },
            "model": {
                "vf_share_layers": True,
            },
            "num_sgd_iter": 6,
            "vf_loss_coeff": 0.01,
            # Disable filters; otherwise we would need to synchronize those
            # to the DQN trainer as well.
            "observation_filter": "NoFilter",
            # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
            "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
            "framework": args.framework,
        })

    dqn_trainer = DQNTrainer(
        env="multi_agent_cartpole",
        config={
            "multiagent": {
                "policies": policies,
                "policy_mapping_fn": policy_mapping_fn,
                "policies_to_train": ["dqn_policy"],
            },
            "model": {
                "vf_share_layers": True,
            },
            "gamma": 0.95,
            "n_step": 3,
            # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
            "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
            "framework": args.framework,
        })
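
    # Note: Both trainers instantiate both policies (they need them to
    # compute actions for all agents in the env), but each one only trains
    # the policy listed in its `policies_to_train`. The other policy stays
    # frozen between the weight syncs performed at the end of the loop below.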

    # You should see both the printed X and Y approach 200 as this trains:
    # info:
    #   policy_reward_mean:
    #     dqn_policy: X
    #     ppo_policy: Y
    for i in range(args.stop_iters):
        print("== Iteration", i, "==")

        # Improve the DQN policy.
        print("-- DQN --")
        result_dqn = dqn_trainer.train()
        print(pretty_print(result_dqn))

        # Improve the PPO policy.
        print("-- PPO --")
        result_ppo = ppo_trainer.train()
        print(pretty_print(result_ppo))

        # Test passed gracefully.
        if args.as_test and \
                result_dqn["episode_reward_mean"] > args.stop_reward and \
                result_ppo["episode_reward_mean"] > args.stop_reward:
            print("test passed (both agents above requested reward)")
            quit(0)

        # Respect the timestep budget promised in the --as-test help text:
        # give up once either trainer has gone past --stop-timesteps.
        if result_dqn["timesteps_total"] > args.stop_timesteps or \
                result_ppo["timesteps_total"] > args.stop_timesteps:
            break
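
        # `get_weights([ids])` returns a dict mapping each requested policy
        # ID to its weights; `set_weights()` updates only those policies in
        # the receiving trainer, keeping each trainer's frozen copy of the
        # other's policy in sync with its latest trained version.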
        # Swap weights to synchronize.
        dqn_trainer.set_weights(ppo_trainer.get_weights(["ppo_policy"]))
        ppo_trainer.set_weights(dqn_trainer.get_weights(["dqn_policy"]))

    # Desired reward not reached.
    if args.as_test:
        raise ValueError("Desired reward ({}) not reached!".format(
            args.stop_reward))