# centralized_critic.py

  1. """An example of customizing PPO to leverage a centralized critic.
  2. Here the model and policy are hard-coded to implement a centralized critic
  3. for TwoStepGame, but you can adapt this for your own use cases.
  4. Compared to simply running `rllib/examples/two_step_game.py --run=PPO`,
  5. this centralized critic version reaches vf_explained_variance=1.0 more stably
  6. since it takes into account the opponent actions as well as the policy's.
  7. Note that this is also using two independent policies instead of weight-sharing
  8. with one.
  9. See also: centralized_critic_2.py for a simpler approach that instead
  10. modifies the environment.
  11. """
import argparse
import numpy as np
from gymnasium.spaces import Discrete
import os

import ray
from ray import air, tune
from ray.rllib.algorithms.ppo.ppo import PPO, PPOConfig
from ray.rllib.algorithms.ppo.ppo_tf_policy import (
    PPOTF1Policy,
    PPOTF2Policy,
)
from ray.rllib.algorithms.ppo.ppo_torch_policy import PPOTorchPolicy
from ray.rllib.evaluation.postprocessing import compute_advantages, Postprocessing
from ray.rllib.examples.env.two_step_game import TwoStepGame
from ray.rllib.examples.models.centralized_critic_models import (
    CentralizedCriticModel,
    TorchCentralizedCriticModel,
)
from ray.rllib.models import ModelCatalog
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_tf, try_import_torch
from ray.rllib.utils.numpy import convert_to_numpy
from ray.rllib.utils.test_utils import check_learning_achieved
from ray.rllib.utils.tf_utils import explained_variance, make_tf_callable
from ray.rllib.utils.torch_utils import convert_to_torch_tensor

tf1, tf, tfv = try_import_tf()
torch, nn = try_import_torch()

OPPONENT_OBS = "opponent_obs"
OPPONENT_ACTION = "opponent_action"

parser = argparse.ArgumentParser()
parser.add_argument(
    "--framework",
    choices=["tf", "tf2", "torch"],
    default="torch",
    help="The DL framework specifier.",
)
parser.add_argument(
    "--as-test",
    action="store_true",
    help="Whether this script should be run as a test: --stop-reward must "
    "be achieved within --stop-timesteps AND --stop-iters.",
)
parser.add_argument(
    "--stop-iters", type=int, default=100, help="Number of iterations to train."
)
parser.add_argument(
    "--stop-timesteps", type=int, default=100000, help="Number of timesteps to train."
)
parser.add_argument(
    "--stop-reward", type=float, default=7.99, help="Reward at which we stop training."
)

class CentralizedValueMixin:
    """Add method to evaluate the central value function from the model."""

    def __init__(self):
        if self.config["framework"] != "torch":
            self.compute_central_vf = make_tf_callable(self.get_session())(
                self.model.central_value_function
            )
        else:
            self.compute_central_vf = self.model.central_value_function
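
# Note on the mixin above: for tf/tf2, `make_tf_callable` wraps the model's
# central-value op so it can be called like a plain function on numpy inputs
# during postprocessing; for torch, the model method is used directly and the
# inputs are converted to tensors explicitly below.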

# Grabs the opponent obs/act and includes it in the experience train_batch,
# and computes GAE using the central vf predictions.
def centralized_critic_postprocessing(
    policy, sample_batch, other_agent_batches=None, episode=None
):
    pytorch = policy.config["framework"] == "torch"
    if (pytorch and hasattr(policy, "compute_central_vf")) or (
        not pytorch and policy.loss_initialized()
    ):
        assert other_agent_batches is not None
        if policy.config["enable_connectors"]:
            [(_, _, opponent_batch)] = list(other_agent_batches.values())
        else:
            [(_, opponent_batch)] = list(other_agent_batches.values())

        # Also record the opponent obs and actions in the trajectory.
        sample_batch[OPPONENT_OBS] = opponent_batch[SampleBatch.CUR_OBS]
        sample_batch[OPPONENT_ACTION] = opponent_batch[SampleBatch.ACTIONS]

        # Overwrite default VF prediction with the central VF.
        if pytorch:
            sample_batch[SampleBatch.VF_PREDS] = (
                policy.compute_central_vf(
                    convert_to_torch_tensor(
                        sample_batch[SampleBatch.CUR_OBS], policy.device
                    ),
                    convert_to_torch_tensor(sample_batch[OPPONENT_OBS], policy.device),
                    convert_to_torch_tensor(
                        sample_batch[OPPONENT_ACTION], policy.device
                    ),
                )
                .cpu()
                .detach()
                .numpy()
            )
        else:
            sample_batch[SampleBatch.VF_PREDS] = convert_to_numpy(
                policy.compute_central_vf(
                    sample_batch[SampleBatch.CUR_OBS],
                    sample_batch[OPPONENT_OBS],
                    sample_batch[OPPONENT_ACTION],
                )
            )
    else:
        # Policy hasn't been initialized yet, use zeros.
        sample_batch[OPPONENT_OBS] = np.zeros_like(sample_batch[SampleBatch.CUR_OBS])
        sample_batch[OPPONENT_ACTION] = np.zeros_like(sample_batch[SampleBatch.ACTIONS])
        sample_batch[SampleBatch.VF_PREDS] = np.zeros_like(
            sample_batch[SampleBatch.REWARDS], dtype=np.float32
        )

    completed = sample_batch[SampleBatch.TERMINATEDS][-1]
    if completed:
        last_r = 0.0
    else:
        last_r = sample_batch[SampleBatch.VF_PREDS][-1]

    train_batch = compute_advantages(
        sample_batch,
        last_r,
        policy.config["gamma"],
        policy.config["lambda"],
        use_gae=policy.config["use_gae"],
    )
    return train_batch
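
# Note: the destructuring assignments above assume exactly one opponent in
# `other_agent_batches` (true for the two-agent TwoStepGame). With more
# agents, you would have to decide how to combine the opponents' observations
# and actions before feeding them to the central value function.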

# Copied from PPO but optimizing the central value function.
def loss_with_central_critic(policy, base_policy, model, dist_class, train_batch):
    # Save original value function.
    vf_saved = model.value_function

    # Calculate loss with a custom value function.
    model.value_function = lambda: policy.model.central_value_function(
        train_batch[SampleBatch.CUR_OBS],
        train_batch[OPPONENT_OBS],
        train_batch[OPPONENT_ACTION],
    )
    policy._central_value_out = model.value_function()
    loss = base_policy.loss(model, dist_class, train_batch)

    # Restore original value function.
    model.value_function = vf_saved

    return loss
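
# The temporary swap of `model.value_function` above lets the unmodified PPO
# surrogate loss (including its built-in value-loss term) train the central
# critic, without re-implementing the PPO objective here.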

def central_vf_stats(policy, train_batch):
    # Report the explained variance of the central value function.
    return {
        "vf_explained_var": explained_variance(
            train_batch[Postprocessing.VALUE_TARGETS], policy._central_value_out
        )
    }
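
# This reproduces PPO's usual `vf_explained_var` stat, but computes it against
# the central critic's predictions (saved as `policy._central_value_out` in
# the loss above); this is the quantity the module docstring refers to.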

def get_ccppo_policy(base):
    class CCPPOTFPolicy(CentralizedValueMixin, base):
        def __init__(self, observation_space, action_space, config):
            base.__init__(self, observation_space, action_space, config)
            CentralizedValueMixin.__init__(self)

        @override(base)
        def loss(self, model, dist_class, train_batch):
            # Use super() to get to the base PPO policy.
            # This special loss function utilizes a shared
            # value function defined on self, and the loss function
            # defined on PPO policies.
            return loss_with_central_critic(
                self, super(), model, dist_class, train_batch
            )

        @override(base)
        def postprocess_trajectory(
            self, sample_batch, other_agent_batches=None, episode=None
        ):
            return centralized_critic_postprocessing(
                self, sample_batch, other_agent_batches, episode
            )

        @override(base)
        def stats_fn(self, train_batch: SampleBatch):
            stats = super().stats_fn(train_batch)
            stats.update(central_vf_stats(self, train_batch))
            return stats

    return CCPPOTFPolicy


CCPPOStaticGraphTFPolicy = get_ccppo_policy(PPOTF1Policy)
CCPPOEagerTFPolicy = get_ccppo_policy(PPOTF2Policy)
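
# The two classes above cover the static-graph ("tf") and eager ("tf2")
# frameworks; the torch counterpart is defined explicitly below.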

class CCPPOTorchPolicy(CentralizedValueMixin, PPOTorchPolicy):
    def __init__(self, observation_space, action_space, config):
        PPOTorchPolicy.__init__(self, observation_space, action_space, config)
        CentralizedValueMixin.__init__(self)

    @override(PPOTorchPolicy)
    def loss(self, model, dist_class, train_batch):
        return loss_with_central_critic(self, super(), model, dist_class, train_batch)

    @override(PPOTorchPolicy)
    def postprocess_trajectory(
        self, sample_batch, other_agent_batches=None, episode=None
    ):
        return centralized_critic_postprocessing(
            self, sample_batch, other_agent_batches, episode
        )

class CentralizedCritic(PPO):
    @classmethod
    @override(PPO)
    def get_default_policy_class(cls, config):
        if config["framework"] == "torch":
            return CCPPOTorchPolicy
        elif config["framework"] == "tf":
            return CCPPOStaticGraphTFPolicy
        else:
            return CCPPOEagerTFPolicy
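
# `CentralizedCritic` is vanilla PPO except for the policy class it returns:
# the centralized-critic variant matching the configured framework. Sampling
# and the training loop are inherited from PPO unchanged.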

if __name__ == "__main__":
    ray.init(local_mode=True)
    args = parser.parse_args()

    ModelCatalog.register_custom_model(
        "cc_model",
        TorchCentralizedCriticModel
        if args.framework == "torch"
        else CentralizedCriticModel,
    )

    config = (
        PPOConfig()
        .environment(TwoStepGame)
        .framework(args.framework)
        .rollouts(batch_mode="complete_episodes", num_rollout_workers=0)
        # TODO (Kourosh): Lift this example to the new RLModule stack, and enable it.
        .training(model={"custom_model": "cc_model"}, _enable_learner_api=False)
        .multi_agent(
            policies={
                "pol1": (
                    None,
                    Discrete(6),
                    TwoStepGame.action_space,
                    # `framework` would also be ok here.
                    PPOConfig.overrides(framework_str=args.framework),
                ),
                "pol2": (
                    None,
                    Discrete(6),
                    TwoStepGame.action_space,
                    # `framework` would also be ok here.
                    PPOConfig.overrides(framework_str=args.framework),
                ),
            },
            policy_mapping_fn=lambda agent_id, episode, worker, **kwargs: "pol1"
            if agent_id == 0
            else "pol2",
        )
        # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
        .resources(num_gpus=int(os.environ.get("RLLIB_NUM_GPUS", "0")))
        .rl_module(_enable_rl_module_api=False)
    )

    stop = {
        "training_iteration": args.stop_iters,
        "timesteps_total": args.stop_timesteps,
        "episode_reward_mean": args.stop_reward,
    }

    tuner = tune.Tuner(
        CentralizedCritic,
        param_space=config.to_dict(),
        run_config=air.RunConfig(stop=stop, verbose=1),
    )
    results = tuner.fit()

    if args.as_test:
        check_learning_achieved(results, args.stop_reward)
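
# Example invocations (the filename is an assumption; adjust to wherever this
# script is saved):
#
#   python centralized_critic.py --framework=torch
#   python centralized_critic.py --framework=tf2 --as-test
#
# With `--as-test`, the result is checked against `--stop-reward` (default
# 7.99) via `check_learning_achieved`, which raises an error if the target
# was not reached.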