# sac_torch_model.py

import gym
from gym.spaces import Box, Discrete
import numpy as np
import tree  # pip install dm_tree
from typing import Dict, List, Optional

from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.utils import force_list
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.spaces.simplex import Simplex
from ray.rllib.utils.typing import ModelConfigDict, TensorType

torch, nn = try_import_torch()


class SACTorchModel(TorchModelV2, nn.Module):
    """Extension of the standard TorchModelV2 for SAC.

    To customize, do one of the following:
    - sub-class SACTorchModel and override one or more of its methods.
    - Use SAC's `Q_model` and `policy_model` keys to tweak the default model
      behaviors (e.g. fcnet_hiddens, conv_filters, etc.).
    - Use SAC's `Q_model->custom_model` and `policy_model->custom_model` keys
      to specify your own custom Q-model(s) and policy-models, which will be
      created within this SACTorchModel (see `build_policy_model` and
      `build_q_model`).

    Note: It is not recommended to override the `forward` method for SAC. This
    would lead to shared weights (between policy and Q-nets), which will then
    not be optimized by either of the critic- or actor-optimizers!

    Data flow:
        `obs` -> forward() (should stay a noop method!) -> `model_out`
        `model_out` -> get_policy_output() -> pi(actions|obs)
        `model_out`, `actions` -> get_q_values() -> Q(s, a)
        `model_out`, `actions` -> get_twin_q_values() -> Q_twin(s, a)
    """

    def __init__(self,
                 obs_space: gym.spaces.Space,
                 action_space: gym.spaces.Space,
                 num_outputs: Optional[int],
                 model_config: ModelConfigDict,
                 name: str,
                 policy_model_config: ModelConfigDict = None,
                 q_model_config: ModelConfigDict = None,
                 twin_q: bool = False,
                 initial_alpha: float = 1.0,
                 target_entropy: Optional[float] = None):
  44. """Initializes a SACTorchModel instance.
  45. 7
  46. Args:
  47. policy_model_config (ModelConfigDict): The config dict for the
  48. policy network.
  49. q_model_config (ModelConfigDict): The config dict for the
  50. Q-network(s) (2 if twin_q=True).
  51. twin_q (bool): Build twin Q networks (Q-net and target) for more
  52. stable Q-learning.
  53. initial_alpha (float): The initial value for the to-be-optimized
  54. alpha parameter (default: 1.0).
  55. target_entropy (Optional[float]): A target entropy value for
  56. the to-be-optimized alpha parameter. If None, will use the
  57. defaults described in the papers for SAC (and discrete SAC).
  58. Note that the core layers for forward() are not defined here, this
  59. only defines the layers for the output heads. Those layers for
  60. forward() should be defined in subclasses of SACModel.
  61. """
        nn.Module.__init__(self)
        super(SACTorchModel, self).__init__(obs_space, action_space,
                                            num_outputs, model_config, name)

        if isinstance(action_space, Discrete):
            self.action_dim = action_space.n
            self.discrete = True
            action_outs = q_outs = self.action_dim
        elif isinstance(action_space, Box):
            self.action_dim = np.prod(action_space.shape)
            self.discrete = False
            action_outs = 2 * self.action_dim
            q_outs = 1
        else:
            assert isinstance(action_space, Simplex)
            self.action_dim = np.prod(action_space.shape)
            self.discrete = False
            action_outs = self.action_dim
            q_outs = 1
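
        # Output-head sizes chosen above:
        # - Discrete: one policy logit and one Q-value per action.
        # - Box: 2 * action_dim policy outputs (mean/stddev inputs for the
        #   squashed Gaussian) and a single Q(s, a) output.
        # - Simplex: one policy output per action dimension and a single
        #   Q(s, a) output.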

        # Build the policy network.
        self.action_model = self.build_policy_model(
            self.obs_space, action_outs, policy_model_config, "policy_model")

        # Build the Q-network(s).
        self.q_net = self.build_q_model(self.obs_space, self.action_space,
                                        q_outs, q_model_config, "q")
        if twin_q:
            self.twin_q_net = self.build_q_model(self.obs_space,
                                                 self.action_space, q_outs,
                                                 q_model_config, "twin_q")
        else:
            self.twin_q_net = None
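
        # The entropy coefficient alpha is learned in log space so that
        # alpha = exp(log_alpha) stays strictly positive during optimization.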
        log_alpha = nn.Parameter(
            torch.from_numpy(np.array([np.log(initial_alpha)])).float())
        self.register_parameter("log_alpha", log_alpha)

        # Auto-calculate the target entropy.
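        # Heuristics from the SAC papers: for discrete actions, target 98% of
        # the maximum (uniform-policy) entropy log(n); for continuous actions,
        # target -|A| (the negative number of action dimensions).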
        if target_entropy is None or target_entropy == "auto":
            # See hyperparams in [2] (README.md).
            if self.discrete:
                target_entropy = 0.98 * np.array(
                    -np.log(1.0 / action_space.n), dtype=np.float32)
            # See [1] (README.md).
            else:
                target_entropy = -np.prod(action_space.shape)

        target_entropy = nn.Parameter(
            torch.from_numpy(np.array([target_entropy])).float(),
            requires_grad=False)
        self.register_parameter("target_entropy", target_entropy)

    @override(TorchModelV2)
    def forward(self, input_dict: Dict[str, TensorType],
                state: List[TensorType],
                seq_lens: TensorType) -> (TensorType, List[TensorType]):
        """The common (Q-net and policy-net) forward pass.

        NOTE: It is not(!) recommended to override this method as it would
        introduce a shared pre-network, which would be updated by both
        actor- and critic optimizers.
        """
        return input_dict["obs"], state

    def build_policy_model(self, obs_space, num_outputs, policy_model_config,
                           name):
        """Builds the policy model used by this SAC.

        Override this method in a sub-class of SACTorchModel to implement your
        own policy net. Alternatively, simply set `custom_model` within the
        top level SAC `policy_model` config key to make this default
        implementation of `build_policy_model` use your custom policy network.

        Returns:
            TorchModelV2: The TorchModelV2 policy sub-model.
        """
        model = ModelCatalog.get_model_v2(
            obs_space,
            self.action_space,
            num_outputs,
            policy_model_config,
            framework="torch",
            name=name)
        return model

    def build_q_model(self, obs_space, action_space, num_outputs,
                      q_model_config, name):
        """Builds one of the (twin) Q-nets used by this SAC.

        Override this method in a sub-class of SACTorchModel to implement your
        own Q-nets. Alternatively, simply set `custom_model` within the
        top level SAC `Q_model` config key to make this default implementation
        of `build_q_model` use your custom Q-nets.

        Returns:
            TorchModelV2: The TorchModelV2 Q-net sub-model.
        """
        self.concat_obs_and_actions = False
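        # Build the Q-net's input space: discrete Q-nets take only the
        # observation (and output one Q-value per action); continuous Q-nets
        # take observation and action, either concatenated into one flat Box
        # (simple 1D obs) or combined into a Tuple space (nested obs).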
        if self.discrete:
            input_space = obs_space
        else:
            orig_space = getattr(obs_space, "original_space", obs_space)
            if isinstance(orig_space, Box) and len(orig_space.shape) == 1:
                input_space = Box(
                    float("-inf"),
                    float("inf"),
                    shape=(orig_space.shape[0] + action_space.shape[0], ))
                self.concat_obs_and_actions = True
            else:
                if isinstance(orig_space, gym.spaces.Tuple):
                    spaces = list(orig_space.spaces)
                elif isinstance(orig_space, gym.spaces.Dict):
                    spaces = list(orig_space.spaces.values())
                else:
                    spaces = [obs_space]
                input_space = gym.spaces.Tuple(spaces + [action_space])

        model = ModelCatalog.get_model_v2(
            input_space,
            action_space,
            num_outputs,
            q_model_config,
            framework="torch",
            name=name)
        return model

    def get_q_values(self,
                     model_out: TensorType,
                     actions: Optional[TensorType] = None) -> TensorType:
        """Returns Q-values, given the output of self.__call__().

        This implements Q(s, a) -> [single Q-value] for the continuous case
        and Q(s) -> [Q-values for all actions] for the discrete case.

        Args:
            model_out (TensorType): Feature outputs from the model layers
                (result of doing `self.__call__(obs)`).
            actions (Optional[TensorType]): Continuous action batch to return
                Q-values for. Shape: [BATCH_SIZE, action_dim]. If None
                (discrete action case), return Q-values for all actions.

        Returns:
            TensorType: Q-values tensor of shape [BATCH_SIZE, 1].
        """
        return self._get_q_value(model_out, actions, self.q_net)

    def get_twin_q_values(self,
                          model_out: TensorType,
                          actions: Optional[TensorType] = None) -> TensorType:
        """Same as get_q_values but using the twin Q net.

        This implements the twin Q(s, a).

        Args:
            model_out (TensorType): Feature outputs from the model layers
                (result of doing `self.__call__(obs)`).
            actions (Optional[TensorType]): Actions to return the Q-values
                for. Shape: [BATCH_SIZE, action_dim]. If None (discrete action
                case), return Q-values for all actions.

        Returns:
            TensorType: Q-values tensor of shape [BATCH_SIZE, 1].
        """
        return self._get_q_value(model_out, actions, self.twin_q_net)

    def _get_q_value(self, model_out, actions, net):
        # Model outs may come as original Tuple observations, concat them
        # here if this is the case.
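        # If the Q-net was built on a flat Box input space, Tuple/Dict model
        # outs are concatenated into a single tensor; otherwise they are kept
        # as a list to match the Tuple input space built in `build_q_model`.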
        if isinstance(net.obs_space, Box):
            if isinstance(model_out, (list, tuple)):
                model_out = torch.cat(model_out, dim=-1)
            elif isinstance(model_out, dict):
                model_out = torch.cat(list(model_out.values()), dim=-1)
        elif isinstance(model_out, dict):
            model_out = list(model_out.values())

        # Continuous case -> concat actions to model_out.
        if actions is not None:
            if self.concat_obs_and_actions:
                input_dict = {"obs": torch.cat([model_out, actions], dim=-1)}
            else:
                # TODO(junogng): SampleBatch doesn't support list columns yet.
                #  Use ModelInputDict.
                input_dict = {"obs": force_list(model_out) + [actions]}
        # Discrete case -> return q-vals for all actions.
        else:
            input_dict = {"obs": model_out}

        # Switch on training mode (when getting Q-values, we are usually in
        # training).
        input_dict["is_training"] = True

        out, _ = net(input_dict, [], None)
        return out

    def get_policy_output(self, model_out: TensorType) -> TensorType:
        """Returns policy outputs, given the output of self.__call__().

        For continuous action spaces, these will be the mean/stddev
        distribution inputs for the (SquashedGaussian) action distribution.
        For discrete action spaces, these will be the logits for a categorical
        distribution.

        Args:
            model_out (TensorType): Feature outputs from the model layers
                (result of doing `self.__call__(obs)`).

        Returns:
            TensorType: Distribution inputs for sampling actions.
        """
        # Model outs may come as original Tuple observations, concat them
        # here if this is the case.
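        # (Rank-1 Dict components are expanded to [B, 1] below so that they
        # can be concatenated along the last axis.)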
        if isinstance(self.action_model.obs_space, Box):
            if isinstance(model_out, (list, tuple)):
                model_out = torch.cat(model_out, dim=-1)
            elif isinstance(model_out, dict):
                model_out = torch.cat(
                    [
                        torch.unsqueeze(val, 1) if len(val.shape) == 1 else val
                        for val in tree.flatten(model_out.values())
                    ],
                    dim=-1)

        out, _ = self.action_model({"obs": model_out}, [], None)
        return out

    def policy_variables(self):
        """Return the list of variables for the policy net."""

        return self.action_model.variables()

    def q_variables(self):
        """Return the list of variables for Q / twin Q nets."""

        return self.q_net.variables() + (self.twin_q_net.variables()
                                         if self.twin_q_net else [])
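

# Usage sketch (illustrative only; `model`, `obs_batch`, and `action_batch`
# are placeholder names, and the policy/Q sub-model configs are assumed to be
# complete RLlib model configs). It follows the data flow described in the
# class docstring:
#
#     model_out, _ = model({"obs": obs_batch}, [], None)  # forward() is a noop
#     action_dist_inputs = model.get_policy_output(model_out)
#     q_values = model.get_q_values(model_out, action_batch)          # [B, 1]
#     twin_q_values = model.get_twin_q_values(model_out, action_batch)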