# sac_tf_policy.py
  1. """
  2. TensorFlow policy class used for SAC.
  3. """
  4. import gym
  5. from gym.spaces import Box, Discrete
  6. from functools import partial
  7. import logging
  8. from typing import Dict, List, Optional, Tuple, Type, Union
  9. import ray
  10. import ray.experimental.tf_utils
  11. from ray.rllib.agents.ddpg.ddpg_tf_policy import ComputeTDErrorMixin, \
  12. TargetNetworkMixin
  13. from ray.rllib.agents.dqn.dqn_tf_policy import postprocess_nstep_and_prio, \
  14. PRIO_WEIGHTS
  15. from ray.rllib.agents.sac.sac_tf_model import SACTFModel
  16. from ray.rllib.agents.sac.sac_torch_model import SACTorchModel
  17. from ray.rllib.evaluation.episode import Episode
  18. from ray.rllib.models import ModelCatalog, MODEL_DEFAULTS
  19. from ray.rllib.models.modelv2 import ModelV2
  20. from ray.rllib.models.tf.tf_action_dist import Beta, Categorical, \
  21. DiagGaussian, Dirichlet, SquashedGaussian, TFActionDistribution
  22. from ray.rllib.policy.policy import Policy
  23. from ray.rllib.policy.sample_batch import SampleBatch
  24. from ray.rllib.policy.tf_policy_template import build_tf_policy
  25. from ray.rllib.utils.error import UnsupportedSpaceException
  26. from ray.rllib.utils.framework import get_variable, try_import_tf
  27. from ray.rllib.utils.spaces.simplex import Simplex
  28. from ray.rllib.utils.tf_utils import huber_loss
  29. from ray.rllib.utils.typing import AgentID, LocalOptimizer, ModelGradients, \
  30. TensorType, TrainerConfigDict
  31. tf1, tf, tfv = try_import_tf()
  32. logger = logging.getLogger(__name__)


def build_sac_model(policy: Policy, obs_space: gym.spaces.Space,
                    action_space: gym.spaces.Space,
                    config: TrainerConfigDict) -> ModelV2:
    """Constructs the necessary ModelV2 for the Policy and returns it.

    Args:
        policy (Policy): The TFPolicy that will use the models.
        obs_space (gym.spaces.Space): The observation space.
        action_space (gym.spaces.Space): The action space.
        config (TrainerConfigDict): The SAC trainer's config dict.

    Returns:
        ModelV2: The ModelV2 to be used by the Policy. Note: An additional
            target model will be created in this function and assigned to
            `policy.target_model`.
    """
    # Force-ignore any additionally provided hidden layer sizes.
    # Everything should be configured using SAC's "Q_model" and "policy_model"
    # settings.
    policy_model_config = MODEL_DEFAULTS.copy()
    policy_model_config.update(config["policy_model"])
    q_model_config = MODEL_DEFAULTS.copy()
    q_model_config.update(config["Q_model"])

    default_model_cls = SACTorchModel if config["framework"] == "torch" \
        else SACTFModel

    model = ModelCatalog.get_model_v2(
        obs_space=obs_space,
        action_space=action_space,
        num_outputs=None,
        model_config=config["model"],
        framework=config["framework"],
        default_model=default_model_cls,
        name="sac_model",
        policy_model_config=policy_model_config,
        q_model_config=q_model_config,
        twin_q=config["twin_q"],
        initial_alpha=config["initial_alpha"],
        target_entropy=config["target_entropy"])

    assert isinstance(model, default_model_cls)

    # Create an exact copy of the model and store it in `policy.target_model`.
    # This will be used for tau-synched Q-target models that run behind the
    # actual Q-networks and are used for target q-value calculations in the
    # loss terms.
    policy.target_model = ModelCatalog.get_model_v2(
        obs_space=obs_space,
        action_space=action_space,
        num_outputs=None,
        model_config=config["model"],
        framework=config["framework"],
        default_model=default_model_cls,
        name="target_sac_model",
        policy_model_config=policy_model_config,
        q_model_config=q_model_config,
        twin_q=config["twin_q"],
        initial_alpha=config["initial_alpha"],
        target_entropy=config["target_entropy"])

    assert isinstance(policy.target_model, default_model_cls)

    return model
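
# For reference, a minimal sketch of the config entries consumed above (the
# exact defaults live in ray/rllib/agents/sac/sac.py and may differ between
# RLlib versions; the values below are illustrative only):
#
#   config = {
#       "framework": "tf",
#       "twin_q": True,
#       "initial_alpha": 1.0,
#       "target_entropy": "auto",
#       "Q_model": {"fcnet_hiddens": [256, 256], "fcnet_activation": "relu"},
#       "policy_model": {"fcnet_hiddens": [256, 256],
#                        "fcnet_activation": "relu"},
#   }
#
# `build_sac_model` merges "Q_model"/"policy_model" over MODEL_DEFAULTS and
# hands them to the SAC model class, which builds the policy net and the
# (optionally twin) Q-nets.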


def postprocess_trajectory(
        policy: Policy,
        sample_batch: SampleBatch,
        other_agent_batches: Optional[Dict[AgentID, SampleBatch]] = None,
        episode: Optional[Episode] = None) -> SampleBatch:
    """Postprocesses a trajectory and returns the processed trajectory.

    The trajectory contains only data from one episode and from one agent.
    - If `config.batch_mode=truncate_episodes` (default), sample_batch may
      contain a truncated (at-the-end) episode, in case the
      `config.rollout_fragment_length` was reached by the sampler.
    - If `config.batch_mode=complete_episodes`, sample_batch will contain
      exactly one episode (no matter how long).

    New columns can be added to sample_batch and existing ones may be altered.

    Args:
        policy (Policy): The Policy used to generate the trajectory
            (`sample_batch`).
        sample_batch (SampleBatch): The SampleBatch to postprocess.
        other_agent_batches (Optional[Dict[AgentID, SampleBatch]]): Optional
            dict of AgentIDs mapping to other agents' trajectory data (from
            the same episode). NOTE: The other agents use the same policy.
        episode (Optional[Episode]): Optional multi-agent episode
            object in which the agents operated.

    Returns:
        SampleBatch: The postprocessed, modified SampleBatch (or a new one).
    """
    return postprocess_nstep_and_prio(policy, sample_batch)
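
# Note: `postprocess_nstep_and_prio` is shared with DQN. Roughly, it adjusts
# the batch for n-step returns when config["n_step"] > 1 and, if a prioritized
# replay buffer is used, computes initial TD-errors that serve as priorities.
# See ray/rllib/agents/dqn/dqn_tf_policy.py for the exact behavior.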


def _get_dist_class(policy: Policy,
                    config: TrainerConfigDict,
                    action_space: gym.spaces.Space) -> \
        Type[TFActionDistribution]:
    """Helper function to return a dist class based on config and action space.

    Args:
        policy (Policy): The policy for which to return the action
            dist class.
        config (TrainerConfigDict): The Trainer's config dict.
        action_space (gym.spaces.Space): The action space used.

    Returns:
        Type[TFActionDistribution]: A TF distribution class.
    """
    if hasattr(policy, "dist_class") and policy.dist_class is not None:
        return policy.dist_class
    elif config["model"].get("custom_action_dist"):
        action_dist_class, _ = ModelCatalog.get_action_dist(
            action_space, config["model"], framework="tf")
        return action_dist_class
    elif isinstance(action_space, Discrete):
        return Categorical
    elif isinstance(action_space, Simplex):
        return Dirichlet
    else:
        assert isinstance(action_space, Box)
        if config["normalize_actions"]:
            return SquashedGaussian if \
                not config["_use_beta_distribution"] else Beta
        else:
            return DiagGaussian


def get_distribution_inputs_and_class(
        policy: Policy,
        model: ModelV2,
        obs_batch: TensorType,
        *,
        explore: bool = True,
        **kwargs) \
        -> Tuple[TensorType, Type[TFActionDistribution], List[TensorType]]:
    """The action distribution function to be used by the algorithm.

    An action distribution function is used to customize the choice of action
    distribution class and the resulting action distribution inputs (to
    parameterize the distribution object).
    After parameterizing the distribution, a `sample()` call
    will be made on it to generate actions.

    Args:
        policy (Policy): The Policy being queried for actions and calling this
            function.
        model (SACTFModel): The SAC specific Model to use to generate the
            distribution inputs (see sac_tf|torch_model.py). Must support the
            `get_policy_output` method.
        obs_batch (TensorType): The observations to be used as inputs to the
            model.
        explore (bool): Whether to activate exploration or not.

    Returns:
        Tuple[TensorType, Type[TFActionDistribution], List[TensorType]]: The
            dist inputs, dist class, and a list of internal state outputs
            (in the RNN case).
    """
    # Get base-model (forward) output (this should be a noop call).
    forward_out, state_out = model(
        SampleBatch(
            obs=obs_batch, _is_training=policy._get_is_training_placeholder()),
        [], None)
    # Use the base output to get the policy outputs from the SAC model's
    # policy components.
    distribution_inputs = model.get_policy_output(forward_out)
    # Get a distribution class to be used with the just calculated dist-inputs.
    action_dist_class = _get_dist_class(policy, policy.config,
                                        policy.action_space)

    return distribution_inputs, action_dist_class, state_out
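
# Implementation note (see sac_tf_model.py): for the default SACTFModel, the
# base `model()` call above is essentially a pass-through of the preprocessed
# observation, and `get_policy_output()` then feeds that output into the
# separate policy ("actor") sub-network. The Q-networks are not evaluated on
# this action-sampling path.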


def sac_actor_critic_loss(
        policy: Policy, model: ModelV2, dist_class: Type[TFActionDistribution],
        train_batch: SampleBatch) -> Union[TensorType, List[TensorType]]:
    """Constructs the loss for the Soft Actor Critic.

    Args:
        policy (Policy): The Policy to calculate the loss for.
        model (ModelV2): The Model to calculate the loss for.
        dist_class (Type[ActionDistribution]): The action distr. class.
        train_batch (SampleBatch): The training data.

    Returns:
        Union[TensorType, List[TensorType]]: A single loss tensor or a list
            of loss tensors.
    """
    # Should be True only for debugging purposes (e.g. test cases)!
    deterministic = policy.config["_deterministic_loss"]
    _is_training = policy._get_is_training_placeholder()

    # Get the base model output from the train batch.
    model_out_t, _ = model(
        SampleBatch(
            obs=train_batch[SampleBatch.CUR_OBS], _is_training=_is_training),
        [], None)

    # Get the base model output from the next observations in the train batch.
    model_out_tp1, _ = model(
        SampleBatch(
            obs=train_batch[SampleBatch.NEXT_OBS], _is_training=_is_training),
        [], None)

    # Get the target model's base outputs from the next observations in the
    # train batch.
    target_model_out_tp1, _ = policy.target_model(
        SampleBatch(
            obs=train_batch[SampleBatch.NEXT_OBS], _is_training=_is_training),
        [], None)

    # Discrete actions case.
    if model.discrete:
        # Get all action probs directly from pi and form their logp.
        log_pis_t = tf.nn.log_softmax(model.get_policy_output(model_out_t), -1)
        policy_t = tf.math.exp(log_pis_t)
        log_pis_tp1 = tf.nn.log_softmax(
            model.get_policy_output(model_out_tp1), -1)
        policy_tp1 = tf.math.exp(log_pis_tp1)
        # Q-values.
        q_t = model.get_q_values(model_out_t)
        # Target Q-values.
        q_tp1 = policy.target_model.get_q_values(target_model_out_tp1)
        if policy.config["twin_q"]:
            twin_q_t = model.get_twin_q_values(model_out_t)
            twin_q_tp1 = policy.target_model.get_twin_q_values(
                target_model_out_tp1)
            q_tp1 = tf.reduce_min((q_tp1, twin_q_tp1), axis=0)
        q_tp1 -= model.alpha * log_pis_tp1

        # Actually selected Q-values (from the actions batch).
        one_hot = tf.one_hot(
            train_batch[SampleBatch.ACTIONS], depth=q_t.shape.as_list()[-1])
        q_t_selected = tf.reduce_sum(q_t * one_hot, axis=-1)
        if policy.config["twin_q"]:
            twin_q_t_selected = tf.reduce_sum(twin_q_t * one_hot, axis=-1)
        # Discrete case: "Best" means weighted by the policy (prob) outputs.
        q_tp1_best = tf.reduce_sum(tf.multiply(policy_tp1, q_tp1), axis=-1)
        q_tp1_best_masked = \
            (1.0 - tf.cast(train_batch[SampleBatch.DONES], tf.float32)) * \
            q_tp1_best
    # Continuous actions case.
    else:
        # Sample single actions from the distribution.
        action_dist_class = _get_dist_class(policy, policy.config,
                                            policy.action_space)
        action_dist_t = action_dist_class(
            model.get_policy_output(model_out_t), policy.model)
        policy_t = action_dist_t.sample() if not deterministic else \
            action_dist_t.deterministic_sample()
        log_pis_t = tf.expand_dims(action_dist_t.logp(policy_t), -1)
        action_dist_tp1 = action_dist_class(
            model.get_policy_output(model_out_tp1), policy.model)
        policy_tp1 = action_dist_tp1.sample() if not deterministic else \
            action_dist_tp1.deterministic_sample()
        log_pis_tp1 = tf.expand_dims(action_dist_tp1.logp(policy_tp1), -1)

        # Q-values for the actually selected actions.
        q_t = model.get_q_values(
            model_out_t, tf.cast(train_batch[SampleBatch.ACTIONS], tf.float32))
        if policy.config["twin_q"]:
            twin_q_t = model.get_twin_q_values(
                model_out_t,
                tf.cast(train_batch[SampleBatch.ACTIONS], tf.float32))

        # Q-values for current policy in given current state.
        q_t_det_policy = model.get_q_values(model_out_t, policy_t)
        if policy.config["twin_q"]:
            twin_q_t_det_policy = model.get_twin_q_values(
                model_out_t, policy_t)
            q_t_det_policy = tf.reduce_min(
                (q_t_det_policy, twin_q_t_det_policy), axis=0)

        # target q network evaluation
        q_tp1 = policy.target_model.get_q_values(target_model_out_tp1,
                                                 policy_tp1)
        if policy.config["twin_q"]:
            twin_q_tp1 = policy.target_model.get_twin_q_values(
                target_model_out_tp1, policy_tp1)
            # Take min over both twin-NNs.
            q_tp1 = tf.reduce_min((q_tp1, twin_q_tp1), axis=0)

        q_t_selected = tf.squeeze(q_t, axis=len(q_t.shape) - 1)
        if policy.config["twin_q"]:
            twin_q_t_selected = tf.squeeze(twin_q_t, axis=len(q_t.shape) - 1)
        q_tp1 -= model.alpha * log_pis_tp1

        q_tp1_best = tf.squeeze(input=q_tp1, axis=len(q_tp1.shape) - 1)
        q_tp1_best_masked = (1.0 - tf.cast(train_batch[SampleBatch.DONES],
                                           tf.float32)) * q_tp1_best

    # Compute RHS of bellman equation for the Q-loss (critic(s)).
    q_t_selected_target = tf.stop_gradient(
        tf.cast(train_batch[SampleBatch.REWARDS], tf.float32) +
        policy.config["gamma"]**policy.config["n_step"] * q_tp1_best_masked)

    # Compute the TD-error (potentially clipped).
    base_td_error = tf.math.abs(q_t_selected - q_t_selected_target)
    if policy.config["twin_q"]:
        twin_td_error = tf.math.abs(twin_q_t_selected - q_t_selected_target)
        td_error = 0.5 * (base_td_error + twin_td_error)
    else:
        td_error = base_td_error

    # Calculate one or two critic losses (2 in the twin_q case).
    prio_weights = tf.cast(train_batch[PRIO_WEIGHTS], tf.float32)
    critic_loss = [tf.reduce_mean(prio_weights * huber_loss(base_td_error))]
    if policy.config["twin_q"]:
        critic_loss.append(
            tf.reduce_mean(prio_weights * huber_loss(twin_td_error)))

    # Alpha- and actor losses.
    # Note: In the papers, alpha is used directly, here we take the log.
    # Discrete case: Multiply the action probs as weights with the original
    # loss terms (no expectations needed).
    if model.discrete:
        alpha_loss = tf.reduce_mean(
            tf.reduce_sum(
                tf.multiply(
                    tf.stop_gradient(policy_t), -model.log_alpha *
                    tf.stop_gradient(log_pis_t + model.target_entropy)),
                axis=-1))
        actor_loss = tf.reduce_mean(
            tf.reduce_sum(
                tf.multiply(
                    # NOTE: No stop_grad around policy output here
                    # (compare with q_t_det_policy for continuous case).
                    policy_t,
                    model.alpha * log_pis_t - tf.stop_gradient(q_t)),
                axis=-1))
    else:
        alpha_loss = -tf.reduce_mean(
            model.log_alpha *
            tf.stop_gradient(log_pis_t + model.target_entropy))
        actor_loss = tf.reduce_mean(model.alpha * log_pis_t - q_t_det_policy)

    # Save for stats function.
    policy.policy_t = policy_t
    policy.q_t = q_t
    policy.td_error = td_error
    policy.actor_loss = actor_loss
    policy.critic_loss = critic_loss
    policy.alpha_loss = alpha_loss
    policy.alpha_value = model.alpha
    policy.target_entropy = model.target_entropy

    # In a custom apply op we handle the losses separately, but return them
    # combined in one loss here.
    return actor_loss + tf.math.add_n(critic_loss) + alpha_loss
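
# Summary of the loss terms constructed above (continuous-action case,
# following the SAC papers; the discrete case replaces the expectations over
# sampled actions with policy-probability-weighted sums):
#
#   y           = r + gamma^n_step * (1 - done) *
#                 (min_i Q'_i(s', a') - alpha * log pi(a'|s')),  a' ~ pi(.|s')
#   critic_loss = E[ prio_weight * huber(Q_i(s, a) - y) ]    (one per Q-net)
#   actor_loss  = E[ alpha * log pi(a|s) - min_i Q_i(s, a) ],  a ~ pi(.|s)
#   alpha_loss  = E[ -log_alpha * stop_grad(log pi(a|s) + target_entropy) ]
#
# The three terms are optimized by three separate optimizers (see
# `ActorCriticOptimizerMixin` below); their sum is only returned here so the
# policy template has a single scalar loss to report.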


def compute_and_clip_gradients(policy: Policy, optimizer: LocalOptimizer,
                               loss: TensorType) -> ModelGradients:
    """Gradients computing function (from loss tensor, using local optimizer).

    Note: For SAC, optimizer and loss are ignored b/c we have 3
    losses and 3 local optimizers (all stored in policy).
    `optimizer` will be used, though, in the tf-eager case b/c it is then a
    fake optimizer (OptimizerWrapper) object with a `tape` property to
    generate a GradientTape object for gradient recording.

    Args:
        policy (Policy): The Policy object that generated the loss tensor and
            that holds the given local optimizer.
        optimizer (LocalOptimizer): The tf (local) optimizer object to
            calculate the gradients with.
        loss (TensorType): The loss tensor for which gradients should be
            calculated.

    Returns:
        ModelGradients: List of the (possibly clipped) gradient and variable
            tuples.
    """
    # Eager: Use GradientTape (which is a property of the `optimizer` object
    # (an OptimizerWrapper): see rllib/policy/eager_tf_policy.py).
    if policy.config["framework"] in ["tf2", "tfe"]:
        tape = optimizer.tape
        pol_weights = policy.model.policy_variables()
        actor_grads_and_vars = list(
            zip(tape.gradient(policy.actor_loss, pol_weights), pol_weights))
        q_weights = policy.model.q_variables()
        if policy.config["twin_q"]:
            half_cutoff = len(q_weights) // 2
            grads_1 = tape.gradient(policy.critic_loss[0],
                                    q_weights[:half_cutoff])
            grads_2 = tape.gradient(policy.critic_loss[1],
                                    q_weights[half_cutoff:])
            critic_grads_and_vars = \
                list(zip(grads_1, q_weights[:half_cutoff])) + \
                list(zip(grads_2, q_weights[half_cutoff:]))
        else:
            critic_grads_and_vars = list(
                zip(
                    tape.gradient(policy.critic_loss[0], q_weights),
                    q_weights))

        alpha_vars = [policy.model.log_alpha]
        alpha_grads_and_vars = list(
            zip(tape.gradient(policy.alpha_loss, alpha_vars), alpha_vars))
    # Tf1.x: Use optimizer.compute_gradients()
    else:
        actor_grads_and_vars = policy._actor_optimizer.compute_gradients(
            policy.actor_loss, var_list=policy.model.policy_variables())

        q_weights = policy.model.q_variables()
        if policy.config["twin_q"]:
            half_cutoff = len(q_weights) // 2
            base_q_optimizer, twin_q_optimizer = policy._critic_optimizer
            critic_grads_and_vars = base_q_optimizer.compute_gradients(
                policy.critic_loss[0], var_list=q_weights[:half_cutoff]
            ) + twin_q_optimizer.compute_gradients(
                policy.critic_loss[1], var_list=q_weights[half_cutoff:])
        else:
            critic_grads_and_vars = policy._critic_optimizer[
                0].compute_gradients(
                    policy.critic_loss[0], var_list=q_weights)
        alpha_grads_and_vars = policy._alpha_optimizer.compute_gradients(
            policy.alpha_loss, var_list=[policy.model.log_alpha])

    # Clip if necessary.
    if policy.config["grad_clip"]:
        clip_func = partial(
            tf.clip_by_norm, clip_norm=policy.config["grad_clip"])
    else:
        clip_func = tf.identity

    # Save grads and vars for later use in `build_apply_op`.
    policy._actor_grads_and_vars = [(clip_func(g), v)
                                    for (g, v) in actor_grads_and_vars
                                    if g is not None]
    policy._critic_grads_and_vars = [(clip_func(g), v)
                                     for (g, v) in critic_grads_and_vars
                                     if g is not None]
    policy._alpha_grads_and_vars = [(clip_func(g), v)
                                    for (g, v) in alpha_grads_and_vars
                                    if g is not None]

    grads_and_vars = (
        policy._actor_grads_and_vars + policy._critic_grads_and_vars +
        policy._alpha_grads_and_vars)
    return grads_and_vars
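
# Note: when config["grad_clip"] is set, each gradient tensor is clipped
# individually via tf.clip_by_norm (per-tensor norm), not jointly via
# tf.clip_by_global_norm. Gradients that come back as None (variables not
# touched by a particular loss) are dropped before being applied.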


def apply_gradients(
        policy: Policy, optimizer: LocalOptimizer,
        grads_and_vars: ModelGradients) -> Union["tf.Operation", None]:
    """Gradients applying function (from list of "grad_and_var" tuples).

    Note: For SAC, optimizer and grads_and_vars are ignored b/c we have 3
    losses and optimizers (stored in policy).

    Args:
        policy (Policy): The Policy object whose Model(s) the given gradients
            should be applied to.
        optimizer (LocalOptimizer): The tf (local) optimizer object through
            which to apply the gradients.
        grads_and_vars (ModelGradients): The list of grad_and_var tuples to
            apply via the given optimizer.

    Returns:
        Union[tf.Operation, None]: The tf op to be used to run the apply
            operation. None for eager mode.
    """
    actor_apply_ops = policy._actor_optimizer.apply_gradients(
        policy._actor_grads_and_vars)

    cgrads = policy._critic_grads_and_vars
    half_cutoff = len(cgrads) // 2
    if policy.config["twin_q"]:
        critic_apply_ops = [
            policy._critic_optimizer[0].apply_gradients(cgrads[:half_cutoff]),
            policy._critic_optimizer[1].apply_gradients(cgrads[half_cutoff:])
        ]
    else:
        critic_apply_ops = [
            policy._critic_optimizer[0].apply_gradients(cgrads)
        ]

    # Eager mode -> Just apply and return None.
    if policy.config["framework"] in ["tf2", "tfe"]:
        policy._alpha_optimizer.apply_gradients(policy._alpha_grads_and_vars)
        return
    # Tf static graph -> Return op.
    else:
        alpha_apply_ops = policy._alpha_optimizer.apply_gradients(
            policy._alpha_grads_and_vars,
            global_step=tf1.train.get_or_create_global_step())
        return tf.group([actor_apply_ops, alpha_apply_ops] + critic_apply_ops)


def stats(policy: Policy, train_batch: SampleBatch) -> Dict[str, TensorType]:
    """Stats function for SAC. Returns a dict with important loss stats.

    Args:
        policy (Policy): The Policy to generate stats for.
        train_batch (SampleBatch): The SampleBatch (already) used for training.

    Returns:
        Dict[str, TensorType]: The stats dict.
    """
    return {
        "mean_td_error": tf.reduce_mean(policy.td_error),
        "actor_loss": tf.reduce_mean(policy.actor_loss),
        "critic_loss": tf.reduce_mean(policy.critic_loss),
        "alpha_loss": tf.reduce_mean(policy.alpha_loss),
        "alpha_value": tf.reduce_mean(policy.alpha_value),
        "target_entropy": tf.constant(policy.target_entropy),
        "mean_q": tf.reduce_mean(policy.q_t),
        "max_q": tf.reduce_max(policy.q_t),
        "min_q": tf.reduce_min(policy.q_t),
    }


class ActorCriticOptimizerMixin:
    """Mixin class to generate the necessary optimizers for actor-critic algos.

    - Creates global step for counting the number of update operations.
    - Creates separate optimizers for actor, critic, and alpha.
    """

    def __init__(self, config):
        # Eager mode.
        if config["framework"] in ["tf2", "tfe"]:
            self.global_step = get_variable(0, tf_name="global_step")
            self._actor_optimizer = tf.keras.optimizers.Adam(
                learning_rate=config["optimization"]["actor_learning_rate"])
            self._critic_optimizer = [
                tf.keras.optimizers.Adam(learning_rate=config["optimization"][
                    "critic_learning_rate"])
            ]
            if config["twin_q"]:
                self._critic_optimizer.append(
                    tf.keras.optimizers.Adam(learning_rate=config[
                        "optimization"]["critic_learning_rate"]))
            self._alpha_optimizer = tf.keras.optimizers.Adam(
                learning_rate=config["optimization"]["entropy_learning_rate"])
        # Static graph mode.
        else:
            self.global_step = tf1.train.get_or_create_global_step()
            self._actor_optimizer = tf1.train.AdamOptimizer(
                learning_rate=config["optimization"]["actor_learning_rate"])
            self._critic_optimizer = [
                tf1.train.AdamOptimizer(learning_rate=config["optimization"][
                    "critic_learning_rate"])
            ]
            if config["twin_q"]:
                self._critic_optimizer.append(
                    tf1.train.AdamOptimizer(learning_rate=config[
                        "optimization"]["critic_learning_rate"]))
            self._alpha_optimizer = tf1.train.AdamOptimizer(
                learning_rate=config["optimization"]["entropy_learning_rate"])
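
# The three learning rates above are read from config["optimization"], e.g.
# (an illustrative sketch; the actual defaults are defined in sac.py and may
# differ between RLlib versions):
#
#   "optimization": {
#       "actor_learning_rate": 3e-4,
#       "critic_learning_rate": 3e-4,
#       "entropy_learning_rate": 3e-4,
#   }
#
# With twin_q=True, a second critic optimizer is appended so that each Q-net
# keeps its own Adam state.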


def setup_early_mixins(policy: Policy, obs_space: gym.spaces.Space,
                       action_space: gym.spaces.Space,
                       config: TrainerConfigDict) -> None:
    """Call mixin classes' constructors before Policy's initialization.

    Adds the necessary optimizers to the given Policy.

    Args:
        policy (Policy): The Policy object.
        obs_space (gym.spaces.Space): The Policy's observation space.
        action_space (gym.spaces.Space): The Policy's action space.
        config (TrainerConfigDict): The Policy's config.
    """
    ActorCriticOptimizerMixin.__init__(policy, config)


def setup_mid_mixins(policy: Policy, obs_space: gym.spaces.Space,
                     action_space: gym.spaces.Space,
                     config: TrainerConfigDict) -> None:
    """Call mixin classes' constructors before Policy's loss initialization.

    Adds the `compute_td_error` method to the given policy.
    Calling `compute_td_error` with batch data will re-calculate the loss
    on that batch AND return the per-batch-item TD-error for prioritized
    replay buffer record weight updating (in case a prioritized replay buffer
    is used).

    Args:
        policy (Policy): The Policy object.
        obs_space (gym.spaces.Space): The Policy's observation space.
        action_space (gym.spaces.Space): The Policy's action space.
        config (TrainerConfigDict): The Policy's config.
    """
    ComputeTDErrorMixin.__init__(policy, sac_actor_critic_loss)


def setup_late_mixins(policy: Policy, obs_space: gym.spaces.Space,
                      action_space: gym.spaces.Space,
                      config: TrainerConfigDict) -> None:
    """Call mixin classes' constructors after Policy initialization.

    Adds the `update_target` method to the given policy.
    Calling `update_target` updates all target Q-networks' weights from their
    respective "main" Q-networks, based on tau (smooth, partial updating).

    Args:
        policy (Policy): The Policy object.
        obs_space (gym.spaces.Space): The Policy's observation space.
        action_space (gym.spaces.Space): The Policy's action space.
        config (TrainerConfigDict): The Policy's config.
    """
    TargetNetworkMixin.__init__(policy, config)


def validate_spaces(policy: Policy, observation_space: gym.spaces.Space,
                    action_space: gym.spaces.Space,
                    config: TrainerConfigDict) -> None:
    """Validates the observation- and action spaces used for the Policy.

    Args:
        policy (Policy): The policy, whose spaces are being validated.
        observation_space (gym.spaces.Space): The observation space to
            validate.
        action_space (gym.spaces.Space): The action space to validate.
        config (TrainerConfigDict): The Policy's config dict.

    Raises:
        UnsupportedSpaceException: If one of the spaces is not supported.
    """
    # Only support single Box or single Discrete spaces.
    if not isinstance(action_space, (Box, Discrete, Simplex)):
        raise UnsupportedSpaceException(
            "Action space ({}) of {} is not supported for "
            "SAC. Must be [Box|Discrete|Simplex].".format(
                action_space, policy))
    # If Box, make sure it's a 1D vector space.
    elif isinstance(action_space,
                    (Box, Simplex)) and len(action_space.shape) > 1:
        raise UnsupportedSpaceException(
            "Action space ({}) of {} has multiple dimensions "
            "{}. ".format(action_space, policy, action_space.shape) +
            "Consider reshaping this into a single dimension, "
            "using a Tuple action space, or the multi-agent API.")


# Build a child class of `DynamicTFPolicy`, given the custom functions defined
# above.
SACTFPolicy = build_tf_policy(
    name="SACTFPolicy",
    get_default_config=lambda: ray.rllib.agents.sac.sac.DEFAULT_CONFIG,
    make_model=build_sac_model,
    postprocess_fn=postprocess_trajectory,
    action_distribution_fn=get_distribution_inputs_and_class,
    loss_fn=sac_actor_critic_loss,
    stats_fn=stats,
    compute_gradients_fn=compute_and_clip_gradients,
    apply_gradients_fn=apply_gradients,
    extra_learn_fetches_fn=lambda policy: {"td_error": policy.td_error},
    mixins=[
        TargetNetworkMixin, ActorCriticOptimizerMixin, ComputeTDErrorMixin
    ],
    validate_spaces=validate_spaces,
    before_init=setup_early_mixins,
    before_loss_init=setup_mid_mixins,
    after_init=setup_late_mixins,
)
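
# Example usage (a sketch; assumes Ray and TensorFlow are installed and an
# env id that matches your installed gym version, e.g. "Pendulum-v0" or
# "Pendulum-v1"):
#
#   import ray
#   from ray.rllib.agents.sac import SACTrainer
#
#   ray.init()
#   trainer = SACTrainer(env="Pendulum-v1", config={"framework": "tf"})
#   print(type(trainer.get_policy()))  # -> SACTFPolicy
#   print(trainer.train()["episode_reward_mean"])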