- """An example of customizing PPO to leverage a centralized critic.
- Here the model and policy are hard-coded to implement a centralized critic
- for TwoStepGame, but you can adapt this for your own use cases.
- Compared to simply running `rllib/examples/two_step_game.py --run=PPO`,
- this centralized critic version reaches vf_explained_variance=1.0 more stably
- since it takes into account the opponent actions as well as the policy's.
- Note that this is also using two independent policies instead of weight-sharing
- with one.
- See also: centralized_critic_2.py for a simpler approach that instead
- modifies the environment.
- """

import argparse
import os

import numpy as np
from gymnasium.spaces import Discrete

import ray
from ray import air, tune
from ray.rllib.algorithms.ppo.ppo import PPO, PPOConfig
from ray.rllib.algorithms.ppo.ppo_tf_policy import (
    PPOTF1Policy,
    PPOTF2Policy,
)
from ray.rllib.algorithms.ppo.ppo_torch_policy import PPOTorchPolicy
from ray.rllib.evaluation.postprocessing import compute_advantages, Postprocessing
from ray.rllib.examples.env.two_step_game import TwoStepGame
from ray.rllib.examples.models.centralized_critic_models import (
    CentralizedCriticModel,
    TorchCentralizedCriticModel,
)
from ray.rllib.models import ModelCatalog
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_tf, try_import_torch
from ray.rllib.utils.numpy import convert_to_numpy
from ray.rllib.utils.test_utils import check_learning_achieved
from ray.rllib.utils.tf_utils import explained_variance, make_tf_callable
from ray.rllib.utils.torch_utils import convert_to_torch_tensor

tf1, tf, tfv = try_import_tf()
torch, nn = try_import_torch()

OPPONENT_OBS = "opponent_obs"
OPPONENT_ACTION = "opponent_action"

parser = argparse.ArgumentParser()
parser.add_argument(
    "--framework",
    choices=["tf", "tf2", "torch"],
    default="torch",
    help="The DL framework specifier.",
)
parser.add_argument(
    "--as-test",
    action="store_true",
    help="Whether this script should be run as a test: --stop-reward must "
    "be achieved within --stop-timesteps AND --stop-iters.",
)
parser.add_argument(
    "--stop-iters", type=int, default=100, help="Number of iterations to train."
)
parser.add_argument(
    "--stop-timesteps", type=int, default=100000, help="Number of timesteps to train."
)
parser.add_argument(
    "--stop-reward", type=float, default=7.99, help="Reward at which we stop training."
)


class CentralizedValueMixin:
    """Add method to evaluate the central value function from the model."""

    def __init__(self):
        if self.config["framework"] != "torch":
            self.compute_central_vf = make_tf_callable(self.get_session())(
                self.model.central_value_function
            )
        else:
            self.compute_central_vf = self.model.central_value_function
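
# NOTE: The mixin assumes the custom model exposes a
# `central_value_function(obs, opponent_obs, opponent_actions)` method, as the
# CentralizedCriticModel / TorchCentralizedCriticModel classes imported above do.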


# Grabs the opponent obs/act and includes it in the experience train_batch,
# and computes GAE using the central vf predictions.
def centralized_critic_postprocessing(
    policy, sample_batch, other_agent_batches=None, episode=None
):
    pytorch = policy.config["framework"] == "torch"
    if (pytorch and hasattr(policy, "compute_central_vf")) or (
        not pytorch and policy.loss_initialized()
    ):
        assert other_agent_batches is not None
        if policy.config["enable_connectors"]:
            [(_, _, opponent_batch)] = list(other_agent_batches.values())
        else:
            [(_, opponent_batch)] = list(other_agent_batches.values())

        # Also record the opponent obs and actions in the trajectory.
        sample_batch[OPPONENT_OBS] = opponent_batch[SampleBatch.CUR_OBS]
        sample_batch[OPPONENT_ACTION] = opponent_batch[SampleBatch.ACTIONS]

        # Overwrite the default VF prediction with the central VF.
        if pytorch:
            sample_batch[SampleBatch.VF_PREDS] = (
                policy.compute_central_vf(
                    convert_to_torch_tensor(
                        sample_batch[SampleBatch.CUR_OBS], policy.device
                    ),
                    convert_to_torch_tensor(
                        sample_batch[OPPONENT_OBS], policy.device
                    ),
                    convert_to_torch_tensor(
                        sample_batch[OPPONENT_ACTION], policy.device
                    ),
                )
                .cpu()
                .detach()
                .numpy()
            )
        else:
            sample_batch[SampleBatch.VF_PREDS] = convert_to_numpy(
                policy.compute_central_vf(
                    sample_batch[SampleBatch.CUR_OBS],
                    sample_batch[OPPONENT_OBS],
                    sample_batch[OPPONENT_ACTION],
                )
            )
    else:
        # Policy hasn't been initialized yet, so use zeros.
        sample_batch[OPPONENT_OBS] = np.zeros_like(sample_batch[SampleBatch.CUR_OBS])
        sample_batch[OPPONENT_ACTION] = np.zeros_like(
            sample_batch[SampleBatch.ACTIONS]
        )
        sample_batch[SampleBatch.VF_PREDS] = np.zeros_like(
            sample_batch[SampleBatch.REWARDS], dtype=np.float32
        )

    completed = sample_batch[SampleBatch.TERMINATEDS][-1]
    if completed:
        last_r = 0.0
    else:
        last_r = sample_batch[SampleBatch.VF_PREDS][-1]

    train_batch = compute_advantages(
        sample_batch,
        last_r,
        policy.config["gamma"],
        policy.config["lambda"],
        use_gae=policy.config["use_gae"],
    )
    return train_batch
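
# After postprocessing, each trajectory batch carries (at least) these extra
# columns, which the centralized loss below relies on:
#   OPPONENT_OBS / OPPONENT_ACTION: the other agent's observations and actions.
#   SampleBatch.VF_PREDS: central value predictions, used above to compute GAE.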


# Copied from PPO but optimizing the central value function.
def loss_with_central_critic(policy, base_policy, model, dist_class, train_batch):
    # Save original value function.
    vf_saved = model.value_function

    # Calculate loss with a custom value function.
    model.value_function = lambda: policy.model.central_value_function(
        train_batch[SampleBatch.CUR_OBS],
        train_batch[OPPONENT_OBS],
        train_batch[OPPONENT_ACTION],
    )
    policy._central_value_out = model.value_function()
    loss = base_policy.loss(model, dist_class, train_batch)

    # Restore original value function.
    model.value_function = vf_saved

    return loss
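
# The swap above works because the base PPO loss obtains its value estimates by
# calling `model.value_function()`: temporarily pointing that method at the
# central critic lets the otherwise unchanged PPO loss optimize the central VF.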


def central_vf_stats(policy, train_batch):
    # Report the explained variance of the central value function.
    return {
        "vf_explained_var": explained_variance(
            train_batch[Postprocessing.VALUE_TARGETS], policy._central_value_out
        )
    }


def get_ccppo_policy(base):
    class CCPPOTFPolicy(CentralizedValueMixin, base):
        def __init__(self, observation_space, action_space, config):
            base.__init__(self, observation_space, action_space, config)
            CentralizedValueMixin.__init__(self)

        @override(base)
        def loss(self, model, dist_class, train_batch):
            # Use super() to get to the base PPO policy.
            # This special loss function utilizes a shared
            # value function defined on self, and the loss function
            # defined on PPO policies.
            return loss_with_central_critic(
                self, super(), model, dist_class, train_batch
            )

        @override(base)
        def postprocess_trajectory(
            self, sample_batch, other_agent_batches=None, episode=None
        ):
            return centralized_critic_postprocessing(
                self, sample_batch, other_agent_batches, episode
            )

        @override(base)
        def stats_fn(self, train_batch: SampleBatch):
            stats = super().stats_fn(train_batch)
            stats.update(central_vf_stats(self, train_batch))
            return stats

    return CCPPOTFPolicy


CCPPOStaticGraphTFPolicy = get_ccppo_policy(PPOTF1Policy)
CCPPOEagerTFPolicy = get_ccppo_policy(PPOTF2Policy)


class CCPPOTorchPolicy(CentralizedValueMixin, PPOTorchPolicy):
    def __init__(self, observation_space, action_space, config):
        PPOTorchPolicy.__init__(self, observation_space, action_space, config)
        CentralizedValueMixin.__init__(self)

    @override(PPOTorchPolicy)
    def loss(self, model, dist_class, train_batch):
        return loss_with_central_critic(self, super(), model, dist_class, train_batch)

    @override(PPOTorchPolicy)
    def postprocess_trajectory(
        self, sample_batch, other_agent_batches=None, episode=None
    ):
        return centralized_critic_postprocessing(
            self, sample_batch, other_agent_batches, episode
        )


class CentralizedCritic(PPO):
    @classmethod
    @override(PPO)
    def get_default_policy_class(cls, config):
        if config["framework"] == "torch":
            return CCPPOTorchPolicy
        elif config["framework"] == "tf":
            return CCPPOStaticGraphTFPolicy
        else:
            return CCPPOEagerTFPolicy
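
# Optional (a sketch, not used by this script): the custom Algorithm class could
# also be registered under a string name for lookup by Tune, e.g.:
#   from ray.tune.registry import register_trainable
#   register_trainable("CCPPO", CentralizedCritic)
# Below, the class object is passed to the Tuner directly instead.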


if __name__ == "__main__":
    ray.init(local_mode=True)
    args = parser.parse_args()

    ModelCatalog.register_custom_model(
        "cc_model",
        TorchCentralizedCriticModel
        if args.framework == "torch"
        else CentralizedCriticModel,
    )

    config = (
        PPOConfig()
        .environment(TwoStepGame)
        .framework(args.framework)
        .rollouts(batch_mode="complete_episodes", num_rollout_workers=0)
        # TODO (Kourosh): Lift this example to the new RLModule stack, and enable it.
        .training(model={"custom_model": "cc_model"}, _enable_learner_api=False)
        .multi_agent(
            policies={
                "pol1": (
                    None,
                    Discrete(6),
                    TwoStepGame.action_space,
                    # `framework` would also be ok here.
                    PPOConfig.overrides(framework_str=args.framework),
                ),
                "pol2": (
                    None,
                    Discrete(6),
                    TwoStepGame.action_space,
                    # `framework` would also be ok here.
                    PPOConfig.overrides(framework_str=args.framework),
                ),
            },
            policy_mapping_fn=lambda agent_id, episode, worker, **kwargs: "pol1"
            if agent_id == 0
            else "pol2",
        )
        # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
        .resources(num_gpus=int(os.environ.get("RLLIB_NUM_GPUS", "0")))
        .rl_module(_enable_rl_module_api=False)
    )

    stop = {
        "training_iteration": args.stop_iters,
        "timesteps_total": args.stop_timesteps,
        "episode_reward_mean": args.stop_reward,
    }

    tuner = tune.Tuner(
        CentralizedCritic,
        param_space=config.to_dict(),
        run_config=air.RunConfig(stop=stop, verbose=1),
    )
    results = tuner.fit()

    if args.as_test:
        check_learning_achieved(results, args.stop_reward)