centralized_critic_2.py

  1. """An example of implementing a centralized critic with ObservationFunction.
  2. The advantage of this approach is that it's very simple and you don't have to
  3. change the algorithm at all -- just use callbacks and a custom model.
  4. However, it is a bit less principled in that you have to change the agent
  5. observation spaces to include data that is only used at train time.
  6. See also: centralized_critic.py for an alternative approach that instead
  7. modifies the policy to add a centralized value function.
  8. """

import numpy as np
from gym.spaces import Dict, Discrete
import argparse
import os

from ray import tune
from ray.rllib.agents.callbacks import DefaultCallbacks
from ray.rllib.examples.models.centralized_critic_models import \
    YetAnotherCentralizedCriticModel, YetAnotherTorchCentralizedCriticModel
from ray.rllib.examples.env.two_step_game import TwoStepGame
from ray.rllib.models import ModelCatalog
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.test_utils import check_learning_achieved

parser = argparse.ArgumentParser()
parser.add_argument(
    "--framework",
    choices=["tf", "tf2", "tfe", "torch"],
    default="tf",
    help="The DL framework specifier.")
parser.add_argument(
    "--as-test",
    action="store_true",
    help="Whether this script should be run as a test: --stop-reward must "
    "be achieved within --stop-timesteps AND --stop-iters.")
parser.add_argument(
    "--stop-iters",
    type=int,
    default=100,
    help="Number of iterations to train.")
parser.add_argument(
    "--stop-timesteps",
    type=int,
    default=100000,
    help="Number of timesteps to train.")
parser.add_argument(
    "--stop-reward",
    type=float,
    default=7.99,
    help="Reward at which we stop training.")
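

# How the two pieces below fit together: central_critic_observer (further
# down) adds the opponent's observation and a placeholder opponent action to
# each agent's observation at rollout time, and this callback overwrites the
# placeholder with the opponent's actual action during trajectory
# postprocessing.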
class FillInActions(DefaultCallbacks):
    """Fills in the opponent actions info in the training batches."""

    def on_postprocess_trajectory(self, worker, episode, agent_id, policy_id,
                                  policies, postprocessed_batch,
                                  original_batches, **kwargs):
        to_update = postprocessed_batch[SampleBatch.CUR_OBS]
        other_id = 1 if agent_id == 0 else 0
        action_encoder = ModelCatalog.get_preprocessor_for_space(Discrete(2))

        # set the opponent actions into the observation
        _, opponent_batch = original_batches[other_id]
        opponent_actions = np.array([
            action_encoder.transform(a)
            for a in opponent_batch[SampleBatch.ACTIONS]
        ])
        to_update[:, -2:] = opponent_actions
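
# Note: ModelCatalog.get_preprocessor_for_space(Discrete(2)) returns RLlib's
# default preprocessor for the space (one-hot for Discrete), so each opponent
# action is encoded as a length-2 vector (e.g. action 1 -> [0., 1.]) before
# being written over the last two columns of the batched observations above.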


def central_critic_observer(agent_obs, **kw):
    """Rewrites the agent obs to include opponent data for training."""

    new_obs = {
        0: {
            "own_obs": agent_obs[0],
            "opponent_obs": agent_obs[1],
            "opponent_action": 0,  # filled in by FillInActions
        },
        1: {
            "own_obs": agent_obs[1],
            "opponent_obs": agent_obs[0],
            "opponent_action": 0,  # filled in by FillInActions
        },
    }
    return new_obs
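
# Illustrative example (hypothetical raw observations): given agent_obs
# {0: 3, 1: 5}, the observer returns
#   {0: {"own_obs": 3, "opponent_obs": 5, "opponent_action": 0},
#    1: {"own_obs": 5, "opponent_obs": 3, "opponent_action": 0}},
# i.e. each agent additionally sees its opponent's observation plus a
# placeholder action slot that FillInActions overwrites at training time.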


if __name__ == "__main__":
    args = parser.parse_args()

    ModelCatalog.register_custom_model(
        "cc_model", YetAnotherTorchCentralizedCriticModel
        if args.framework == "torch" else YetAnotherCentralizedCriticModel)

    action_space = Discrete(2)
    observer_space = Dict({
        "own_obs": Discrete(6),
        # These two fields are filled in by the CentralCriticObserver, and are
        # not used for inference, only for training.
        "opponent_obs": Discrete(6),
        "opponent_action": Discrete(2),
    })
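
    # RLlib's preprocessor flattens this Dict space (one-hot encoding each
    # Discrete component) into a single vector; that flattened vector is what
    # FillInActions indexes into above.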

    config = {
        "env": TwoStepGame,
        "batch_mode": "complete_episodes",
        "callbacks": FillInActions,
        # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
        "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
        "num_workers": 0,
        "multiagent": {
            "policies": {
                "pol1": (None, observer_space, action_space, {}),
                "pol2": (None, observer_space, action_space, {}),
            },
            "policy_mapping_fn": (
                lambda aid, **kwargs: "pol1" if aid == 0 else "pol2"),
            "observation_fn": central_critic_observer,
        },
        "model": {
            "custom_model": "cc_model",
        },
        "framework": args.framework,
    }
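
    # "complete_episodes" keeps whole episodes together in each sample batch,
    # which is presumably what lets FillInActions line up both agents'
    # trajectories in original_batches when filling in the opponent actions.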

    stop = {
        "training_iteration": args.stop_iters,
        "timesteps_total": args.stop_timesteps,
        "episode_reward_mean": args.stop_reward,
    }

    results = tune.run("PPO", config=config, stop=stop, verbose=1)

    if args.as_test:
        check_learning_achieved(results, args.stop_reward)