# maddpg_policy.py
import ray
from ray.rllib.agents.dqn.dqn_tf_policy import minimize_and_clip
from ray.rllib.evaluation.postprocessing import adjust_nstep
from ray.rllib.models import ModelCatalog
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.annotations import override
from ray.rllib.utils.error import UnsupportedSpaceException
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.tf_policy import TFPolicy
from ray.rllib.utils.framework import try_import_tf, try_import_tfp
from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY

import logging
from gym.spaces import Box, Discrete
import numpy as np

logger = logging.getLogger(__name__)

tf1, tf, tfv = try_import_tf()
tfp = try_import_tfp()

class MADDPGPostprocessing:
    """Implements agentwise termination signal and n-step learning."""

    @override(Policy)
    def postprocess_trajectory(self,
                               sample_batch,
                               other_agent_batches=None,
                               episode=None):
        # FIXME: Get done from info is required since agentwise done is not
        # supported now.
        sample_batch[SampleBatch.DONES] = self.get_done_from_info(
            sample_batch[SampleBatch.INFOS])

        # N-step Q adjustments
        if self.config["n_step"] > 1:
            adjust_nstep(self.config["n_step"], self.config["gamma"],
                         sample_batch)

        return sample_batch

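
# MADDPG trains one centralized critic per agent: the critic sees the joint
# observations and actions of all agents, while the actor conditions only on
# this agent's own observation (centralized training, decentralized
# execution).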
class MADDPGTFPolicy(MADDPGPostprocessing, TFPolicy):
    def __init__(self, obs_space, act_space, config):
        # _____ Initial Configuration
        config = dict(ray.rllib.contrib.maddpg.DEFAULT_CONFIG, **config)
        self.config = config
        self.global_step = tf1.train.get_or_create_global_step()

        # FIXME: Get done from info is required since agentwise done is not
        # supported now.
        self.get_done_from_info = np.vectorize(
            lambda info: info.get("done", False))

        agent_id = config["agent_id"]
        if agent_id is None:
            raise ValueError("Must set `agent_id` in the policy config.")
        if type(agent_id) is not int:
            raise ValueError("Agent ids must be integers for MADDPG.")

        # _____ Environment Setting
        def _make_continuous_space(space):
            if isinstance(space, Box):
                return space
            elif isinstance(space, Discrete):
                return Box(
                    low=np.zeros((space.n, )), high=np.ones((space.n, )))
            else:
                raise UnsupportedSpaceException(
                    "Space {} is not supported.".format(space))

        obs_space_n = [
            _make_continuous_space(space)
            for _, (_, space, _,
                    _) in config["multiagent"]["policies"].items()
        ]
        act_space_n = [
            _make_continuous_space(space)
            for _, (_, _, space,
                    _) in config["multiagent"]["policies"].items()
        ]

        # _____ Placeholders
        # Placeholders for policy evaluation and updates
        def _make_ph_n(space_n, name=""):
            return [
                tf1.placeholder(
                    tf.float32,
                    shape=(None, ) + space.shape,
                    name=name + "_%d" % i) for i, space in enumerate(space_n)
            ]

        obs_ph_n = _make_ph_n(obs_space_n, SampleBatch.OBS)
        act_ph_n = _make_ph_n(act_space_n, SampleBatch.ACTIONS)
        new_obs_ph_n = _make_ph_n(obs_space_n, SampleBatch.NEXT_OBS)
        new_act_ph_n = _make_ph_n(act_space_n, "new_actions")
        rew_ph = tf1.placeholder(
            tf.float32, shape=None, name="rewards_{}".format(agent_id))
        done_ph = tf1.placeholder(
            tf.float32, shape=None, name="dones_{}".format(agent_id))
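
        # With `use_local_critic` the joint lists collapse to this agent's
        # own space and placeholders, so the critic degenerates to a purely
        # local (independent DDPG-style) critic; agent_id is reset to 0 to
        # index the single remaining entry.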
        if config["use_local_critic"]:
            obs_space_n, act_space_n = [obs_space_n[agent_id]], [
                act_space_n[agent_id]
            ]
            obs_ph_n, act_ph_n = [obs_ph_n[agent_id]], [act_ph_n[agent_id]]
            new_obs_ph_n, new_act_ph_n = [new_obs_ph_n[agent_id]], [
                new_act_ph_n[agent_id]
            ]
            agent_id = 0

        # _____ Value Network
        # Build critic network for t.
        critic, _, critic_model_n, critic_vars = self._build_critic_network(
            obs_ph_n,
            act_ph_n,
            obs_space_n,
            act_space_n,
            config["use_state_preprocessor"],
            config["critic_hiddens"],
            getattr(tf.nn, config["critic_hidden_activation"]),
            scope="critic")

        # Build critic network for t + 1.
        target_critic, _, _, target_critic_vars = self._build_critic_network(
            new_obs_ph_n,
            new_act_ph_n,
            obs_space_n,
            act_space_n,
            config["use_state_preprocessor"],
            config["critic_hiddens"],
            getattr(tf.nn, config["critic_hidden_activation"]),
            scope="target_critic")

        # Build critic loss.
        td_error = tf.subtract(
            tf.stop_gradient(
                rew_ph + (1.0 - done_ph) *
                (config["gamma"]**config["n_step"]) * target_critic[:, 0]),
            critic[:, 0])
        critic_loss = tf.reduce_mean(td_error**2)
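        # The stop-gradient term above is the n-step TD target
        #     y = r + (1 - done) * gamma**n_step * Q_target(o', a'),
        # and the critic minimizes the mean squared TD error mean((y - Q)**2).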

        # _____ Policy Network
        # Build actor network for t.
        act_sampler, actor_feature, actor_model, actor_vars = (
            self._build_actor_network(
                obs_ph_n[agent_id],
                obs_space_n[agent_id],
                act_space_n[agent_id],
                config["use_state_preprocessor"],
                config["actor_hiddens"],
                getattr(tf.nn, config["actor_hidden_activation"]),
                scope="actor"))

        # Build actor network for t + 1.
        self.new_obs_ph = new_obs_ph_n[agent_id]
        self.target_act_sampler, _, _, target_actor_vars = (
            self._build_actor_network(
                self.new_obs_ph,
                obs_space_n[agent_id],
                act_space_n[agent_id],
                config["use_state_preprocessor"],
                config["actor_hiddens"],
                getattr(tf.nn, config["actor_hidden_activation"]),
                scope="target_actor"))

        # Build actor loss.
        act_n = act_ph_n.copy()
        act_n[agent_id] = act_sampler
        critic, _, _, _ = self._build_critic_network(
            obs_ph_n,
            act_n,
            obs_space_n,
            act_space_n,
            config["use_state_preprocessor"],
            config["critic_hiddens"],
            getattr(tf.nn, config["critic_hidden_activation"]),
            scope="critic")
        actor_loss = -tf.reduce_mean(critic)
        if config["actor_feature_reg"] is not None:
            actor_loss += config["actor_feature_reg"] * tf.reduce_mean(
                actor_feature**2)
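        # The actor loss is the deterministic-policy-gradient objective: this
        # agent's sampled action is substituted into the joint action and the
        # centralized critic's Q value (same weights via the reused "critic"
        # scope) is maximized, i.e. actor_loss = -mean(Q). `actor_feature_reg`
        # adds an L2 penalty on the pre-softmax logits.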

        # _____ Losses
        self.losses = {"critic": critic_loss, "actor": actor_loss}

        # _____ Optimizers
        self.optimizers = {
            "critic": tf1.train.AdamOptimizer(config["critic_lr"]),
            "actor": tf1.train.AdamOptimizer(config["actor_lr"])
        }

        # _____ Build variable update ops.
        self.tau = tf1.placeholder_with_default(
            config["tau"], shape=(), name="tau")

        def _make_target_update_op(vs, target_vs, tau):
            return [
                target_v.assign(tau * v + (1.0 - tau) * target_v)
                for v, target_v in zip(vs, target_vs)
            ]

        self.update_target_vars = _make_target_update_op(
            critic_vars + actor_vars, target_critic_vars + target_actor_vars,
            self.tau)
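        # Soft target updates: target_w <- tau * w + (1 - tau) * target_w.
        # With a small tau the target networks trail the online networks
        # slowly; update_target(1.0) at the end of __init__ performs a hard
        # copy instead.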

        def _make_set_weight_op(variables):
            vs = list()
            for v in variables.values():
                vs += v
            phs = [
                tf1.placeholder(
                    tf.float32,
                    shape=v.get_shape(),
                    name=v.name.split(":")[0] + "_ph") for v in vs
            ]
            return tf.group(*[v.assign(ph) for v, ph in zip(vs, phs)]), phs

        self.vars = {
            "critic": critic_vars,
            "actor": actor_vars,
            "target_critic": target_critic_vars,
            "target_actor": target_actor_vars
        }
        self.update_vars, self.vars_ph = _make_set_weight_op(self.vars)

        # _____ TensorFlow Initialization
        sess = tf1.get_default_session()
        assert sess

        def _make_loss_inputs(placeholders):
            return [(ph.name.split("/")[-1].split(":")[0], ph)
                    for ph in placeholders]

        loss_inputs = _make_loss_inputs(obs_ph_n + act_ph_n + new_obs_ph_n +
                                        new_act_ph_n + [rew_ph, done_ph])
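        # Each loss input is a (column_name, placeholder) pair; the names are
        # derived from the placeholder names (e.g. "obs_0", "actions_1",
        # "rewards_<agent_id>"), which presumably match the column names the
        # MADDPG trainer uses when it assembles the joint training batch.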

        TFPolicy.__init__(
            self,
            obs_space,
            act_space,
            config=config,
            sess=sess,
            obs_input=obs_ph_n[agent_id],
            sampled_action=act_sampler,
            loss=actor_loss + critic_loss,
            loss_inputs=loss_inputs,
            dist_inputs=actor_feature)

        del self.view_requirements["prev_actions"]
        del self.view_requirements["prev_rewards"]

        self.get_session().run(tf1.global_variables_initializer())

        # Hard initial update
        self.update_target(1.0)

    @override(TFPolicy)
    def optimizer(self):
        return None

    @override(TFPolicy)
    def gradients(self, optimizer, loss):
        self.gvs = {
            k: minimize_and_clip(optimizer, self.losses[k], self.vars[k],
                                 self.config["grad_norm_clipping"])
            for k, optimizer in self.optimizers.items()
        }
        return self.gvs["critic"] + self.gvs["actor"]
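
    # Per-loss gradients are clipped and collected in `gradients()` above;
    # the apply op built below increments the global step and, via control
    # dependencies, applies the critic update before the actor update.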
    @override(TFPolicy)
    def build_apply_op(self, optimizer, grads_and_vars):
        critic_apply_op = self.optimizers["critic"].apply_gradients(
            self.gvs["critic"])

        with tf1.control_dependencies([tf1.assign_add(self.global_step, 1)]):
            with tf1.control_dependencies([critic_apply_op]):
                actor_apply_op = self.optimizers["actor"].apply_gradients(
                    self.gvs["actor"])

        return actor_apply_op

    @override(TFPolicy)
    def extra_compute_action_feed_dict(self):
        return {}

    @override(TFPolicy)
    def extra_compute_grad_fetches(self):
        return {LEARNER_STATS_KEY: {}}

    @override(TFPolicy)
    def get_weights(self):
        var_list = []
        for var in self.vars.values():
            var_list += var
        return {"_state": self.get_session().run(var_list)}

    @override(TFPolicy)
    def set_weights(self, weights):
        self.get_session().run(
            self.update_vars,
            feed_dict=dict(zip(self.vars_ph, weights["_state"])))

    @override(Policy)
    def get_state(self):
        return TFPolicy.get_state(self)

    @override(Policy)
    def set_state(self, state):
        TFPolicy.set_state(self, state)
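
    # The centralized critic concatenates the (optionally preprocessed)
    # observations of all agents with all agents' actions and maps them
    # through an MLP to a scalar Q value. `tf1.AUTO_REUSE` lets the "critic"
    # scope (and thus the same weights) be reused for both the TD error and
    # the actor loss above.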
    def _build_critic_network(self,
                              obs_n,
                              act_n,
                              obs_space_n,
                              act_space_n,
                              use_state_preprocessor,
                              hiddens,
                              activation=None,
                              scope=None):
        with tf1.variable_scope(scope, reuse=tf1.AUTO_REUSE) as scope:
            if use_state_preprocessor:
                model_n = [
                    ModelCatalog.get_model({
                        SampleBatch.OBS: obs,
                        "is_training": self._get_is_training_placeholder(),
                    }, obs_space, act_space, 1, self.config["model"])
                    for obs, obs_space, act_space in zip(
                        obs_n, obs_space_n, act_space_n)
                ]
                out_n = [model.last_layer for model in model_n]
                out = tf.concat(out_n + act_n, axis=1)
            else:
                model_n = [None] * len(obs_n)
                out = tf.concat(obs_n + act_n, axis=1)

            for hidden in hiddens:
                out = tf1.layers.dense(
                    out, units=hidden, activation=activation)
            feature = out
            out = tf1.layers.dense(feature, units=1, activation=None)

        return out, feature, model_n, tf1.global_variables(scope.name)
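
    # The actor maps the agent's own observation to action logits and samples
    # a Gumbel-Softmax (RelaxedOneHotCategorical, temperature 1.0) relaxation
    # of the discrete action, so the sampled action stays differentiable and
    # gradients from the centralized critic can flow back into the actor.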
    def _build_actor_network(self,
                             obs,
                             obs_space,
                             act_space,
                             use_state_preprocessor,
                             hiddens,
                             activation=None,
                             scope=None):
        with tf1.variable_scope(scope, reuse=tf1.AUTO_REUSE) as scope:
            if use_state_preprocessor:
                model = ModelCatalog.get_model({
                    SampleBatch.OBS: obs,
                    "is_training": self._get_is_training_placeholder(),
                }, obs_space, act_space, 1, self.config["model"])
                out = model.last_layer
            else:
                model = None
                out = obs

            for hidden in hiddens:
                out = tf1.layers.dense(
                    out, units=hidden, activation=activation)
            feature = tf1.layers.dense(
                out, units=act_space.shape[0], activation=None)
            sampler = tfp.distributions.RelaxedOneHotCategorical(
                temperature=1.0, logits=feature).sample()

        return sampler, feature, model, tf1.global_variables(scope.name)

    def update_target(self, tau=None):
        if tau is not None:
            self.get_session().run(self.update_target_vars, {self.tau: tau})
        else:
            self.get_session().run(self.update_target_vars)