maddpg.py

  1. """Contributed port of MADDPG from OpenAI baselines.
  2. The implementation has a couple assumptions:
  3. - The number of agents is fixed and known upfront.
  4. - Each agent is bound to a policy of the same name.
  5. - Discrete actions are sent as logits (pre-softmax).
  6. For a minimal example, see rllib/examples/two_step_game.py,
  7. and the README for how to run with the multi-agent particle envs.
  8. """

import logging
from typing import Type

from ray.rllib.agents.trainer import COMMON_CONFIG, with_common_config
from ray.rllib.agents.dqn.dqn import DQNTrainer
from ray.rllib.contrib.maddpg.maddpg_policy import MADDPGTFPolicy
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.sample_batch import SampleBatch, MultiAgentBatch
from ray.rllib.utils import merge_dicts
from ray.rllib.utils.annotations import override
from ray.rllib.utils.deprecation import DEPRECATED_VALUE
from ray.rllib.utils.typing import TrainerConfigDict

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# yapf: disable
# __sphinx_doc_begin__
DEFAULT_CONFIG = with_common_config({
    # === Framework to run the algorithm ===
    "framework": "tf",

    # === Settings for each individual policy ===
    # ID of the agent controlled by this policy.
    "agent_id": None,
    # Use a local critic for this policy.
    "use_local_critic": False,

    # === Evaluation ===
    # Evaluation interval.
    "evaluation_interval": None,
    # Number of episodes to run per evaluation period.
    "evaluation_duration": 10,

    # === Model ===
    # Apply a state preprocessor with spec given by the "model" config option
    # (like other RL algorithms). This is mostly useful if you have a weird
    # observation shape, like an image. Disabled by default.
    "use_state_preprocessor": False,
    # Postprocess the policy network model output with these hidden layers. If
    # use_state_preprocessor is False, then these will be the *only* hidden
    # layers in the network.
    "actor_hiddens": [64, 64],
    # Hidden layers activation of the postprocessing stage of the policy
    # network.
    "actor_hidden_activation": "relu",
    # Postprocess the critic network model output with these hidden layers;
    # again, if use_state_preprocessor is True, then the state will be
    # preprocessed by the model specified with the "model" config option first.
    "critic_hiddens": [64, 64],
    # Hidden layers activation of the postprocessing stage of the critic.
    "critic_hidden_activation": "relu",
    # N-step Q learning.
    "n_step": 1,
    # Algorithm for good policies.
    "good_policy": "maddpg",
    # Algorithm for adversary policies.
    "adv_policy": "maddpg",

    # === Replay buffer ===
    # Size of the replay buffer (deprecated; use `replay_buffer_config` below).
    # Note that if async_updates is set, then each worker will have a replay
    # buffer of this size.
    "buffer_size": DEPRECATED_VALUE,
    "replay_buffer_config": {
        "type": "MultiAgentReplayBuffer",
        "capacity": int(1e6),
    },
    # Observation compression. Note that compression makes simulation slow in
    # MPE.
    "compress_observations": False,
    # If set, this fixes the ratio of replayed (learned-on) timesteps to
    # sampled (stored-in-the-replay-buffer) timesteps. Otherwise, replay
    # proceeds at the native ratio determined by
    # (train_batch_size / rollout_fragment_length), e.g. 1024 / 100 = 10.24
    # with the defaults below.
    "training_intensity": None,

    # Force lockstep replay mode for MADDPG.
    "multiagent": merge_dicts(COMMON_CONFIG["multiagent"], {
        "replay_mode": "lockstep",
    }),

    # === Optimization ===
    # Learning rate for the critic (Q-function) optimizer.
    "critic_lr": 1e-2,
    # Learning rate for the actor (policy) optimizer.
    "actor_lr": 1e-2,
    # Update the target network every `target_network_update_freq` steps.
    "target_network_update_freq": 0,
    # Update the target by \tau * policy + (1 - \tau) * target_policy.
    "tau": 0.01,
    # Weights for feature regularization for the actor.
    "actor_feature_reg": 0.001,
    # If not None, clip gradients during optimization at this value.
    "grad_norm_clipping": 0.5,
    # How many steps of the model to sample before learning starts.
    "learning_starts": 1024 * 25,
    # Update the replay buffer with this many samples at once. Note that this
    # setting applies per-worker if num_workers > 1.
    "rollout_fragment_length": 100,
    # Size of a batch sampled from the replay buffer for training. Note that
    # if async_updates is set, then each worker returns gradients for a
    # batch of this size.
    "train_batch_size": 1024,
    # Number of env steps to optimize for before returning.
    "timesteps_per_iteration": 0,

    # === Parallelism ===
    # Number of workers for collecting samples with. This only makes sense
    # to increase if your environment is particularly slow to sample, or if
    # you're using the Async or Ape-X optimizers.
    "num_workers": 1,
    # Prevent iterations from going lower than this time span.
    "min_time_s_per_reporting": 0,
})
# __sphinx_doc_end__
# yapf: enable
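

# The helper below builds the joint training batch that MADDPG's centralized
# critics need. A rough sketch of the resulting layout (illustrative only,
# assuming two agents with `agent_id` 0 and 1): every per-policy SampleBatch
# column gets its agent id appended, target actions are added as
# "new_actions_<id>", and every policy then receives the same joint batch,
# e.g.
#
#     {"obs_0": ..., "actions_0": ..., "new_obs_0": ..., "new_actions_0": ...,
#      "obs_1": ..., "actions_1": ..., "new_obs_1": ..., "new_actions_1": ...}
#
# so each critic can condition on all agents' observations and actions.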
def before_learn_on_batch(multi_agent_batch, policies, train_batch_size):
    samples = {}

    # Modify keys.
    for pid, p in policies.items():
        i = p.config["agent_id"]
        keys = multi_agent_batch.policy_batches[pid].keys()
        keys = ["_".join([k, str(i)]) for k in keys]
        samples.update(
            dict(zip(keys, multi_agent_batch.policy_batches[pid].values())))

    # Make ops and feed_dict to get "new_obs" from target action sampler.
    new_obs_ph_n = [p.new_obs_ph for p in policies.values()]
    new_obs_n = list()
    for k, v in samples.items():
        if "new_obs" in k:
            new_obs_n.append(v)

    for i, p in enumerate(policies.values()):
        feed_dict = {new_obs_ph_n[i]: new_obs_n[i]}
        new_act = p.get_session().run(p.target_act_sampler, feed_dict)
        samples.update({"new_actions_%d" % i: new_act})

    # Share samples among agents.
    policy_batches = {pid: SampleBatch(samples) for pid in policies.keys()}
    return MultiAgentBatch(policy_batches, train_batch_size)


class MADDPGTrainer(DQNTrainer):
    @classmethod
    @override(DQNTrainer)
    def get_default_config(cls) -> TrainerConfigDict:
        return DEFAULT_CONFIG

    @override(DQNTrainer)
    def validate_config(self, config: TrainerConfigDict) -> None:
        """Adds the `before_learn_on_batch` hook to the config.

        This hook is called explicitly prior to TrainOneStep() in the
        execution setups for DQN and APEX.
        """
        # Call super's validation method.
        super().validate_config(config)

        def f(batch, workers, config):
            policies = dict(workers.local_worker()
                            .foreach_trainable_policy(lambda p, i: (i, p)))
            return before_learn_on_batch(batch, policies,
                                         config["train_batch_size"])

        config["before_learn_on_batch"] = f

    @override(DQNTrainer)
    def get_default_policy_class(self,
                                 config: TrainerConfigDict) -> Type[Policy]:
        return MADDPGTFPolicy
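

# How the hook installed in `validate_config` gets consumed (a paraphrased
# sketch of the parent DQN execution plan, not the actual RLlib code): batches
# sampled from the replay buffer are passed through `before_learn_on_batch`
# right before the learner step, roughly:
#
#     post_fn = config.get("before_learn_on_batch") or (lambda b, w, c: b)
#     replay_op = Replay(local_buffer=replay_buffer) \
#         .for_each(lambda batch: post_fn(batch, workers, config)) \
#         .for_each(TrainOneStep(workers))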