# sac_tf_model.py

import gym
from gym.spaces import Box, Discrete
import numpy as np
import tree  # pip install dm_tree
from typing import Dict, List, Optional

from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
from ray.rllib.utils import force_list
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.spaces.simplex import Simplex
from ray.rllib.utils.typing import ModelConfigDict, TensorType

tf1, tf, tfv = try_import_tf()


class SACTFModel(TFModelV2):
    """Extension of the standard TFModelV2 for SAC.

    To customize, do one of the following:
    - Sub-class SACTFModel and override one or more of its methods.
    - Use SAC's `Q_model` and `policy_model` keys to tweak the default model
      behaviors (e.g. fcnet_hiddens, conv_filters, etc.).
    - Use SAC's `Q_model->custom_model` and `policy_model->custom_model` keys
      to specify your own custom Q-model(s) and policy-model, which will be
      created within this SACTFModel (see `build_policy_model` and
      `build_q_model`).

    Note: It is not recommended to override the `forward` method for SAC. This
    would lead to shared weights (between policy and Q-nets), which will then
    not be optimized by either of the critic- or actor-optimizers!

    Data flow:
        `obs` -> forward() (should stay a noop method!) -> `model_out`
        `model_out` -> get_policy_output() -> pi(actions|obs)
        `model_out`, `actions` -> get_q_values() -> Q(s, a)
        `model_out`, `actions` -> get_twin_q_values() -> Q_twin(s, a)
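
    Examples:
        A minimal usage sketch of the data flow above (purely illustrative;
        `model` is an already built SACTFModel, `obs` and `actions` are
        batched tensors):

        >>> model_out, _ = model({"obs": obs}, [], None)  # no-op forward()
        >>> action_inputs = model.get_policy_output(model_out)
        >>> q_values = model.get_q_values(model_out, actions)
        >>> twin_q_values = model.get_twin_q_values(model_out, actions)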
    """

    def __init__(self,
                 obs_space: gym.spaces.Space,
                 action_space: gym.spaces.Space,
                 num_outputs: Optional[int],
                 model_config: ModelConfigDict,
                 name: str,
                 policy_model_config: ModelConfigDict = None,
                 q_model_config: ModelConfigDict = None,
                 twin_q: bool = False,
                 initial_alpha: float = 1.0,
                 target_entropy: Optional[float] = None):
        """Initializes a SACTFModel instance.

        Args:
            policy_model_config (ModelConfigDict): The config dict for the
                policy network.
            q_model_config (ModelConfigDict): The config dict for the
                Q-network(s) (2 if twin_q=True).
            twin_q (bool): Build twin Q networks (Q-net and twin Q-net) for
                more stable Q-learning.
            initial_alpha (float): The initial value for the to-be-optimized
                alpha parameter (default: 1.0).
            target_entropy (Optional[float]): A target entropy value for
                the to-be-optimized alpha parameter. If None, will use the
                defaults described in the papers for SAC (and discrete SAC).

        Note that the core layers for forward() are not defined here; this
        class only defines the layers for the output heads. Those layers for
        forward() should be defined in subclasses of SACTFModel.
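
        Examples:
            Default target_entropy values as auto-computed below (purely for
            illustration): for a Box action space of shape (3,), the default
            is -3.0; for Discrete(4), it is 0.98 * -log(1/4) ~= 1.36.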
        """
        super(SACTFModel, self).__init__(obs_space, action_space, num_outputs,
                                         model_config, name)
        if isinstance(action_space, Discrete):
            self.action_dim = action_space.n
            self.discrete = True
            action_outs = q_outs = self.action_dim
        elif isinstance(action_space, Box):
            self.action_dim = np.product(action_space.shape)
            self.discrete = False
            action_outs = 2 * self.action_dim
            q_outs = 1
        else:
            assert isinstance(action_space, Simplex)
            self.action_dim = np.product(action_space.shape)
            self.discrete = False
            action_outs = self.action_dim
            q_outs = 1

        self.action_model = self.build_policy_model(
            self.obs_space, action_outs, policy_model_config, "policy_model")

        self.q_net = self.build_q_model(self.obs_space, self.action_space,
                                        q_outs, q_model_config, "q")
        if twin_q:
            self.twin_q_net = self.build_q_model(self.obs_space,
                                                 self.action_space, q_outs,
                                                 q_model_config, "twin_q")
        else:
            self.twin_q_net = None

        self.log_alpha = tf.Variable(
            np.log(initial_alpha), dtype=tf.float32, name="log_alpha")
        self.alpha = tf.exp(self.log_alpha)

        # Auto-calculate the target entropy.
        if target_entropy is None or target_entropy == "auto":
            # See hyperparams in [2] (README.md).
            if self.discrete:
                target_entropy = 0.98 * np.array(
                    -np.log(1.0 / action_space.n), dtype=np.float32)
            # See [1] (README.md).
            else:
                target_entropy = -np.prod(action_space.shape)
        self.target_entropy = target_entropy

    @override(TFModelV2)
    def forward(self, input_dict: Dict[str, TensorType],
                state: List[TensorType],
                seq_lens: TensorType) -> (TensorType, List[TensorType]):
        """The common (Q-net and policy-net) forward pass.

        NOTE: It is not(!) recommended to override this method as it would
        introduce a shared pre-network, which would be updated by both
        actor- and critic optimizers.
        """
        return input_dict["obs"], state

    def build_policy_model(self, obs_space, num_outputs, policy_model_config,
                           name):
        """Builds the policy model used by this SAC.

        Override this method in a sub-class of SACTFModel to implement your
        own policy net. Alternatively, simply set `custom_model` within the
        top level SAC `policy_model` config key to make this default
        implementation of `build_policy_model` use your custom policy network.

        Returns:
            TFModelV2: The TFModelV2 policy sub-model.
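
        Examples:
            A minimal config sketch of the `custom_model` route (the name
            "my_policy_net" and the class behind it are assumptions for
            illustration only):

            >>> from ray.rllib.models import ModelCatalog
            >>> ModelCatalog.register_custom_model("my_policy_net",
            ...                                    MyPolicyNet)
            >>> config["policy_model"] = {"custom_model": "my_policy_net"}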
        """
        model = ModelCatalog.get_model_v2(
            obs_space,
            self.action_space,
            num_outputs,
            policy_model_config,
            framework="tf",
            name=name)
        return model

    def build_q_model(self, obs_space, action_space, num_outputs,
                      q_model_config, name):
        """Builds one of the (twin) Q-nets used by this SAC.

        Override this method in a sub-class of SACTFModel to implement your
        own Q-nets. Alternatively, simply set `custom_model` within the
        top level SAC `Q_model` config key to make this default implementation
        of `build_q_model` use your custom Q-nets.

        Returns:
            TFModelV2: The TFModelV2 Q-net sub-model.
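
        Examples:
            A minimal config sketch for tweaking the default Q-nets via the
            `Q_model` key (layer sizes are illustrative only):

            >>> config["Q_model"] = {
            ...     "fcnet_hiddens": [256, 256],
            ...     "fcnet_activation": "relu",
            ... }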
        """
        self.concat_obs_and_actions = False
        if self.discrete:
            input_space = obs_space
        else:
            orig_space = getattr(obs_space, "original_space", obs_space)
            if isinstance(orig_space, Box) and len(orig_space.shape) == 1:
                input_space = Box(
                    float("-inf"),
                    float("inf"),
                    shape=(orig_space.shape[0] + action_space.shape[0], ))
                self.concat_obs_and_actions = True
            else:
                if isinstance(orig_space, gym.spaces.Tuple):
                    spaces = list(orig_space.spaces)
                elif isinstance(orig_space, gym.spaces.Dict):
                    spaces = list(orig_space.spaces.values())
                else:
                    spaces = [obs_space]
                input_space = gym.spaces.Tuple(spaces + [action_space])

        model = ModelCatalog.get_model_v2(
            input_space,
            action_space,
            num_outputs,
            q_model_config,
            framework="tf",
            name=name)
        return model

    def get_q_values(self,
                     model_out: TensorType,
                     actions: Optional[TensorType] = None) -> TensorType:
        """Returns Q-values, given the output of self.__call__().

        This implements Q(s, a) -> [single Q-value] for the continuous case
        and Q(s) -> [Q-values for all actions] for the discrete case.

        Args:
            model_out (TensorType): Feature outputs from the model layers
                (result of doing `self.__call__(obs)`).
            actions (Optional[TensorType]): Continuous action batch to return
                Q-values for. Shape: [BATCH_SIZE, action_dim]. If None
                (discrete action case), return Q-values for all actions.

        Returns:
            TensorType: Q-values tensor of shape [BATCH_SIZE, 1] (continuous
                case) or [BATCH_SIZE, num_actions] (discrete case).
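
        Examples:
            A shape-only sketch (assumes a batch size of 32 and `model_out`
            obtained from `self.__call__()`; names are illustrative):

            >>> q = model.get_q_values(model_out, actions)  # -> [32, 1]
            >>> q_all = model.get_q_values(model_out)  # discrete -> [32, n]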
        """
        return self._get_q_value(model_out, actions, self.q_net)

    def get_twin_q_values(self,
                          model_out: TensorType,
                          actions: Optional[TensorType] = None) -> TensorType:
        """Same as get_q_values, but using the twin Q net.

        This implements the twin Q(s, a).

        Args:
            model_out (TensorType): Feature outputs from the model layers
                (result of doing `self.__call__(obs)`).
            actions (Optional[TensorType]): Actions to return the Q-values
                for. Shape: [BATCH_SIZE, action_dim]. If None (discrete
                action case), return Q-values for all actions.

        Returns:
            TensorType: Q-values tensor of shape [BATCH_SIZE, 1] (continuous
                case) or [BATCH_SIZE, num_actions] (discrete case).
        """
        return self._get_q_value(model_out, actions, self.twin_q_net)

    def _get_q_value(self, model_out, actions, net):
        # Model outs may come as original Tuple/Dict observations, concat
        # them here if this is the case.
        if isinstance(net.obs_space, Box):
            if isinstance(model_out, (list, tuple)):
                model_out = tf.concat(model_out, axis=-1)
            elif isinstance(model_out, dict):
                model_out = tf.concat(list(model_out.values()), axis=-1)
        elif isinstance(model_out, dict):
            model_out = list(model_out.values())

        # Continuous case -> concat actions to model_out.
        if actions is not None:
            if self.concat_obs_and_actions:
                input_dict = {"obs": tf.concat([model_out, actions], axis=-1)}
            else:
                # TODO(junogng): SampleBatch doesn't support list columns yet.
                #  Use ModelInputDict.
                input_dict = {"obs": force_list(model_out) + [actions]}
        # Discrete case -> return q-vals for all actions.
        else:
            input_dict = {"obs": model_out}

        # Switch on training mode (when getting Q-values, we are usually in
        # training).
        input_dict["is_training"] = True

        out, _ = net(input_dict, [], None)
        return out

    def get_policy_output(self, model_out: TensorType) -> TensorType:
        """Returns policy outputs, given the output of self.__call__().

        For continuous action spaces, these will be the mean/stddev
        distribution inputs for the (SquashedGaussian) action distribution.
        For discrete action spaces, these will be the logits for a categorical
        distribution.

        Args:
            model_out (TensorType): Feature outputs from the model layers
                (result of doing `self.__call__(obs)`).

        Returns:
            TensorType: Distribution inputs for sampling actions.
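
        Examples:
            A shape-only sketch (batch size 32 is illustrative): for a
            3-dim Box action space, the output shape is [32, 6] (means and
            log-stddevs); for Discrete(4), it is [32, 4] (logits).

            >>> action_inputs = model.get_policy_output(model_out)  # [32, 6]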
        """
        # Model outs may come as original Tuple/Dict observations, concat
        # them here if this is the case.
        if isinstance(self.action_model.obs_space, Box):
            if isinstance(model_out, (list, tuple)):
                model_out = tf.concat(model_out, axis=-1)
            elif isinstance(model_out, dict):
                model_out = tf.concat(
                    [
                        tf.expand_dims(val, 1) if len(val.shape) == 1 else val
                        for val in tree.flatten(model_out.values())
                    ],
                    axis=-1)
        out, _ = self.action_model({"obs": model_out}, [], None)
        return out

    def policy_variables(self):
        """Return the list of variables for the policy net."""
        return self.action_model.variables()

    def q_variables(self):
        """Return the list of variables for Q / twin Q nets."""
        return self.q_net.variables() + (self.twin_q_net.variables()
                                         if self.twin_q_net else [])