ddpg_torch_model.py

import numpy as np
import gym
from typing import List, Dict, Union, Optional

from ray.rllib.models.torch.misc import SlimFC
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.models.utils import get_activation_fn
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.typing import ModelConfigDict, TensorType

torch, nn = try_import_torch()


class DDPGTorchModel(TorchModelV2, nn.Module):
    """Extension of standard TorchModelV2 for DDPG.

    Data flow:
        obs -> forward() -> model_out
        model_out -> get_policy_output() -> pi(s)
        model_out, actions -> get_q_values() -> Q(s, a)
        model_out, actions -> get_twin_q_values() -> Q_twin(s, a)

    Note that this class by itself is not a valid model unless you
    implement forward() in a subclass.
    """

    def __init__(
            self,
            obs_space: gym.spaces.Space,
            action_space: gym.spaces.Space,
            num_outputs: int,
            model_config: ModelConfigDict,
            name: str,
            # Extra DDPGActionModel args:
            actor_hiddens: Optional[List[int]] = None,
            actor_hidden_activation: str = "relu",
            critic_hiddens: Optional[List[int]] = None,
            critic_hidden_activation: str = "relu",
            twin_q: bool = False,
            add_layer_norm: bool = False):
        """Initialize variables of this model.

        Extra model kwargs:
            actor_hidden_activation (str): activation for actor network
            actor_hiddens (list): hidden layer sizes for actor network
            critic_hidden_activation (str): activation for critic network
            critic_hiddens (list): hidden layer sizes for critic network
            twin_q (bool): build twin Q networks.
            add_layer_norm (bool): Enable layer norm (for param noise).

        Note that the core layers for forward() are not defined here, this
        only defines the layers for the output heads. Those layers for
        forward() should be defined in subclasses of DDPGTorchModel.
        """
        if actor_hiddens is None:
            actor_hiddens = [256, 256]
        if critic_hiddens is None:
            critic_hiddens = [256, 256]

        nn.Module.__init__(self)
        super(DDPGTorchModel, self).__init__(obs_space, action_space,
                                             num_outputs, model_config, name)

        self.bounded = np.logical_and(self.action_space.bounded_above,
                                      self.action_space.bounded_below).any()
        self.action_dim = int(np.prod(self.action_space.shape))

        # Build the policy network.
        self.policy_model = nn.Sequential()
        ins = num_outputs
        self.obs_ins = ins
        activation = get_activation_fn(
            actor_hidden_activation, framework="torch")
        for i, n in enumerate(actor_hiddens):
            self.policy_model.add_module(
                "action_{}".format(i),
                SlimFC(
                    ins,
                    n,
                    initializer=torch.nn.init.xavier_uniform_,
                    activation_fn=activation))
            # Add LayerNorm after each Dense.
            if add_layer_norm:
                self.policy_model.add_module("LayerNorm_A_{}".format(i),
                                             nn.LayerNorm(n))
            ins = n
        self.policy_model.add_module(
            "action_out",
            SlimFC(
                ins,
                self.action_dim,
                initializer=torch.nn.init.xavier_uniform_,
                activation_fn=None))

        # Use sigmoid to scale to [0,1], but also double magnitude of input to
        # emulate behaviour of tanh activation used in DDPG and TD3 papers.
        # After sigmoid squashing, re-scale to env action space bounds.
        class _Lambda(nn.Module):
            def __init__(self_):
                super().__init__()
                low_action = nn.Parameter(
                    torch.from_numpy(self.action_space.low).float())
                low_action.requires_grad = False
                self_.register_parameter("low_action", low_action)
                action_range = nn.Parameter(
                    torch.from_numpy(self.action_space.high -
                                     self.action_space.low).float())
                action_range.requires_grad = False
                self_.register_parameter("action_range", action_range)

            def forward(self_, x):
                sigmoid_out = nn.Sigmoid()(2.0 * x)
                squashed = self_.action_range * sigmoid_out + self_.low_action
                return squashed

        # Only squash if we have bounded actions.
        if self.bounded:
            self.policy_model.add_module("action_out_squashed", _Lambda())
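
        # For example (illustrative numbers only): with a Box(-2.0, 2.0)
        # action space, action_range is 4.0 and low_action is -2.0, so a raw
        # policy output x is mapped to 4.0 * sigmoid(2.0 * x) - 2.0, i.e.
        # into the open interval (-2.0, 2.0).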

        # Build the Q-net(s), including target Q-net(s).
        def build_q_net(name_):
            activation = get_activation_fn(
                critic_hidden_activation, framework="torch")
            # For continuous actions: Feed obs and actions (concatenated)
            # through the NN. For discrete actions, only obs.
            q_net = nn.Sequential()
            ins = self.obs_ins + self.action_dim
            for i, n in enumerate(critic_hiddens):
                q_net.add_module(
                    "{}_hidden_{}".format(name_, i),
                    SlimFC(
                        ins,
                        n,
                        initializer=torch.nn.init.xavier_uniform_,
                        activation_fn=activation))
                ins = n
            q_net.add_module(
                "{}_out".format(name_),
                SlimFC(
                    ins,
                    1,
                    initializer=torch.nn.init.xavier_uniform_,
                    activation_fn=None))
            return q_net

        self.q_model = build_q_net("q")
        if twin_q:
            self.twin_q_model = build_q_net("twin_q")
        else:
            self.twin_q_model = None
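
        # Note: the optional second Q-head built above is the kind of twin
        # Q-network used for TD3-style clipped double-Q learning; this class
        # only builds the heads, the corresponding loss is defined outside
        # this model.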

    def get_q_values(self, model_out: TensorType,
                     actions: TensorType) -> TensorType:
        """Return the Q estimates for the most recent forward pass.

        This implements Q(s, a).

        Args:
            model_out (Tensor): obs embeddings from the model layers, of shape
                [BATCH_SIZE, num_outputs].
            actions (Tensor): Actions to return the Q-values for.
                Shape: [BATCH_SIZE, action_dim].

        Returns:
            tensor of shape [BATCH_SIZE].
        """
        return self.q_model(torch.cat([model_out, actions], -1))

    def get_twin_q_values(self, model_out: TensorType,
                          actions: TensorType) -> TensorType:
        """Same as get_q_values but using the twin Q net.

        This implements the twin Q(s, a).

        Args:
            model_out (Tensor): obs embeddings from the model layers, of shape
                [BATCH_SIZE, num_outputs].
            actions (Tensor): Actions to return the Q-values for.
                Shape: [BATCH_SIZE, action_dim].

        Returns:
            tensor of shape [BATCH_SIZE].
        """
        return self.twin_q_model(torch.cat([model_out, actions], -1))

    def get_policy_output(self, model_out: TensorType) -> TensorType:
        """Return the action output for the most recent forward pass.

        This outputs the support for pi(s). For continuous action spaces, this
        is the action directly. For discrete, it is the mean / std dev.

        Args:
            model_out (Tensor): obs embeddings from the model layers, of shape
                [BATCH_SIZE, num_outputs].

        Returns:
            tensor of shape [BATCH_SIZE, action_out_size]
        """
        return self.policy_model(model_out)

    def policy_variables(self, as_dict: bool = False
                         ) -> Union[List[TensorType], Dict[str, TensorType]]:
        """Return the list of variables for the policy net."""
        if as_dict:
            return self.policy_model.state_dict()
        return list(self.policy_model.parameters())

    def q_variables(self, as_dict: bool = False
                    ) -> Union[List[TensorType], Dict[str, TensorType]]:
        """Return the list of variables for Q / twin Q nets."""
        if as_dict:
            return {
                **self.q_model.state_dict(),
                **(self.twin_q_model.state_dict()
                   if self.twin_q_model else {})
            }
        return list(self.q_model.parameters()) + \
            (list(self.twin_q_model.parameters())
             if self.twin_q_model else [])
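

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): DDPGTorchModel defines just the policy and
# Q output heads, so a subclass must implement forward() to produce the
# `model_out` embedding consumed by get_policy_output() / get_q_values().
# The subclass name, the single-layer encoder and the call pattern below are
# assumptions made for this sketch, not RLlib-provided code.
class MyDDPGModel(DDPGTorchModel):
    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name, **kwargs):
        super().__init__(obs_space, action_space, num_outputs, model_config,
                         name, **kwargs)
        # Hypothetical core layer: maps the (flattened) obs to the
        # [BATCH_SIZE, num_outputs] embedding expected by the output heads.
        self.encoder = SlimFC(
            int(np.prod(obs_space.shape)),
            num_outputs,
            initializer=torch.nn.init.xavier_uniform_,
            activation_fn=get_activation_fn("relu", framework="torch"))

    def forward(self, input_dict, state, seq_lens):
        # Standard TorchModelV2 forward(): returns (model_out, state).
        obs = input_dict["obs"].float()
        model_out = self.encoder(obs.reshape(obs.shape[0], -1))
        return model_out, state


# Typical data flow through the heads (mirrors the class docstring):
#   model_out, _ = model.forward({"obs": obs_batch}, [], None)
#   actions = model.get_policy_output(model_out)        # pi(s)
#   q_values = model.get_q_values(model_out, actions)   # Q(s, a)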