recurrent_net.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270
  1. import numpy as np
  2. import gym
  3. from gym.spaces import Discrete, MultiDiscrete
  4. import tree # pip install dm_tree
  5. from typing import Dict, List, Union
  6. from ray.rllib.models.modelv2 import ModelV2
  7. from ray.rllib.models.torch.misc import SlimFC
  8. from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
  9. from ray.rllib.policy.rnn_sequencing import add_time_dimension
  10. from ray.rllib.policy.sample_batch import SampleBatch
  11. from ray.rllib.policy.view_requirement import ViewRequirement
  12. from ray.rllib.utils.annotations import override, DeveloperAPI
  13. from ray.rllib.utils.framework import try_import_torch
  14. from ray.rllib.utils.spaces.space_utils import get_base_struct_from_space
  15. from ray.rllib.utils.torch_utils import flatten_inputs_to_1d_tensor, one_hot
  16. from ray.rllib.utils.typing import ModelConfigDict, TensorType
  17. torch, nn = try_import_torch()
  18. @DeveloperAPI
  19. class RecurrentNetwork(TorchModelV2):
  20. """Helper class to simplify implementing RNN models with TorchModelV2.
  21. Instead of implementing forward(), you can implement forward_rnn() which
  22. takes batches with the time dimension added already.
  23. Here is an example implementation for a subclass
  24. ``MyRNNClass(RecurrentNetwork, nn.Module)``::
  25. def __init__(self, obs_space, num_outputs):
  26. nn.Module.__init__(self)
  27. super().__init__(obs_space, action_space, num_outputs,
  28. model_config, name)
  29. self.obs_size = _get_size(obs_space)
  30. self.rnn_hidden_dim = model_config["lstm_cell_size"]
  31. self.fc1 = nn.Linear(self.obs_size, self.rnn_hidden_dim)
  32. self.rnn = nn.GRUCell(self.rnn_hidden_dim, self.rnn_hidden_dim)
  33. self.fc2 = nn.Linear(self.rnn_hidden_dim, num_outputs)
  34. self.value_branch = nn.Linear(self.rnn_hidden_dim, 1)
  35. self._cur_value = None
  36. @override(ModelV2)
  37. def get_initial_state(self):
  38. # Place hidden states on same device as model.
  39. h = [self.fc1.weight.new(
  40. 1, self.rnn_hidden_dim).zero_().squeeze(0)]
  41. return h
  42. @override(ModelV2)
  43. def value_function(self):
  44. assert self._cur_value is not None, "must call forward() first"
  45. return self._cur_value
  46. @override(RecurrentNetwork)
  47. def forward_rnn(self, input_dict, state, seq_lens):
  48. x = nn.functional.relu(self.fc1(input_dict["obs_flat"].float()))
  49. h_in = state[0].reshape(-1, self.rnn_hidden_dim)
  50. h = self.rnn(x, h_in)
  51. q = self.fc2(h)
  52. self._cur_value = self.value_branch(h).squeeze(1)
  53. return q, [h]
  54. """
  55. @override(ModelV2)
  56. def forward(self, input_dict: Dict[str, TensorType],
  57. state: List[TensorType],
  58. seq_lens: TensorType) -> (TensorType, List[TensorType]):
  59. """Adds time dimension to batch before sending inputs to forward_rnn().
  60. You should implement forward_rnn() in your subclass."""
  61. flat_inputs = input_dict["obs_flat"].float()
  62. if isinstance(seq_lens, np.ndarray):
  63. seq_lens = torch.Tensor(seq_lens).int()
  64. max_seq_len = flat_inputs.shape[0] // seq_lens.shape[0]
  65. self.time_major = self.model_config.get("_time_major", False)
  66. inputs = add_time_dimension(
  67. flat_inputs,
  68. max_seq_len=max_seq_len,
  69. framework="torch",
  70. time_major=self.time_major,
  71. )
  72. output, new_state = self.forward_rnn(inputs, state, seq_lens)
  73. output = torch.reshape(output, [-1, self.num_outputs])
  74. return output, new_state
  75. def forward_rnn(self, inputs: TensorType, state: List[TensorType],
  76. seq_lens: TensorType) -> (TensorType, List[TensorType]):
  77. """Call the model with the given input tensors and state.
  78. Args:
  79. inputs (dict): Observation tensor with shape [B, T, obs_size].
  80. state (list): List of state tensors, each with shape [B, size].
  81. seq_lens (Tensor): 1D tensor holding input sequence lengths.
  82. Note: len(seq_lens) == B.
  83. Returns:
  84. (outputs, new_state): The model output tensor of shape
  85. [B, T, num_outputs] and the list of new state tensors each with
  86. shape [B, size].
  87. Examples:
  88. def forward_rnn(self, inputs, state, seq_lens):
  89. model_out, h, c = self.rnn_model([inputs, seq_lens] + state)
  90. return model_out, [h, c]
  91. """
  92. raise NotImplementedError("You must implement this for an RNN model")
  93. class LSTMWrapper(RecurrentNetwork, nn.Module):
  94. """An LSTM wrapper serving as an interface for ModelV2s that set use_lstm.
  95. """
  96. def __init__(self, obs_space: gym.spaces.Space,
  97. action_space: gym.spaces.Space, num_outputs: int,
  98. model_config: ModelConfigDict, name: str):
  99. nn.Module.__init__(self)
  100. super(LSTMWrapper, self).__init__(obs_space, action_space, None,
  101. model_config, name)
  102. # At this point, self.num_outputs is the number of nodes coming
  103. # from the wrapped (underlying) model. In other words, self.num_outputs
  104. # is the input size for the LSTM layer.
  105. # If None, set it to the observation space.
  106. if self.num_outputs is None:
  107. self.num_outputs = int(np.product(self.obs_space.shape))
  108. self.cell_size = model_config["lstm_cell_size"]
  109. self.time_major = model_config.get("_time_major", False)
  110. self.use_prev_action = model_config["lstm_use_prev_action"]
  111. self.use_prev_reward = model_config["lstm_use_prev_reward"]
  112. self.action_space_struct = get_base_struct_from_space(
  113. self.action_space)
  114. self.action_dim = 0
  115. for space in tree.flatten(self.action_space_struct):
  116. if isinstance(space, Discrete):
  117. self.action_dim += space.n
  118. elif isinstance(space, MultiDiscrete):
  119. self.action_dim += np.sum(space.nvec)
  120. elif space.shape is not None:
  121. self.action_dim += int(np.product(space.shape))
  122. else:
  123. self.action_dim += int(len(space))
  124. # Add prev-action/reward nodes to input to LSTM.
  125. if self.use_prev_action:
  126. self.num_outputs += self.action_dim
  127. if self.use_prev_reward:
  128. self.num_outputs += 1
  129. # Define actual LSTM layer (with num_outputs being the nodes coming
  130. # from the wrapped (underlying) layer).
  131. self.lstm = nn.LSTM(
  132. self.num_outputs, self.cell_size, batch_first=not self.time_major)
  133. # Set self.num_outputs to the number of output nodes desired by the
  134. # caller of this constructor.
  135. self.num_outputs = num_outputs
  136. # Postprocess LSTM output with another hidden layer and compute values.
  137. self._logits_branch = SlimFC(
  138. in_size=self.cell_size,
  139. out_size=self.num_outputs,
  140. activation_fn=None,
  141. initializer=torch.nn.init.xavier_uniform_)
  142. self._value_branch = SlimFC(
  143. in_size=self.cell_size,
  144. out_size=1,
  145. activation_fn=None,
  146. initializer=torch.nn.init.xavier_uniform_)
  147. # __sphinx_doc_begin__
  148. # Add prev-a/r to this model's view, if required.
  149. if model_config["lstm_use_prev_action"]:
  150. self.view_requirements[SampleBatch.PREV_ACTIONS] = \
  151. ViewRequirement(SampleBatch.ACTIONS, space=self.action_space,
  152. shift=-1)
  153. if model_config["lstm_use_prev_reward"]:
  154. self.view_requirements[SampleBatch.PREV_REWARDS] = \
  155. ViewRequirement(SampleBatch.REWARDS, shift=-1)
  156. # __sphinx_doc_end__
  157. @override(RecurrentNetwork)
  158. def forward(self, input_dict: Dict[str, TensorType],
  159. state: List[TensorType],
  160. seq_lens: TensorType) -> (TensorType, List[TensorType]):
  161. assert seq_lens is not None
  162. # Push obs through "unwrapped" net's `forward()` first.
  163. wrapped_out, _ = self._wrapped_forward(input_dict, [], None)
  164. # Concat. prev-action/reward if required.
  165. prev_a_r = []
  166. # Prev actions.
  167. if self.model_config["lstm_use_prev_action"]:
  168. prev_a = input_dict[SampleBatch.PREV_ACTIONS]
  169. # If actions are not processed yet (in their original form as
  170. # have been sent to environment):
  171. # Flatten/one-hot into 1D array.
  172. if self.model_config["_disable_action_flattening"]:
  173. prev_a_r.append(
  174. flatten_inputs_to_1d_tensor(
  175. prev_a,
  176. spaces_struct=self.action_space_struct,
  177. time_axis=False))
  178. # If actions are already flattened (but not one-hot'd yet!),
  179. # one-hot discrete/multi-discrete actions here.
  180. else:
  181. if isinstance(self.action_space, (Discrete, MultiDiscrete)):
  182. prev_a = one_hot(prev_a.float(), self.action_space)
  183. else:
  184. prev_a = prev_a.float()
  185. prev_a_r.append(torch.reshape(prev_a, [-1, self.action_dim]))
  186. # Prev rewards.
  187. if self.model_config["lstm_use_prev_reward"]:
  188. prev_a_r.append(
  189. torch.reshape(input_dict[SampleBatch.PREV_REWARDS].float(),
  190. [-1, 1]))
  191. # Concat prev. actions + rewards to the "main" input.
  192. if prev_a_r:
  193. wrapped_out = torch.cat([wrapped_out] + prev_a_r, dim=1)
  194. # Push everything through our LSTM.
  195. input_dict["obs_flat"] = wrapped_out
  196. return super().forward(input_dict, state, seq_lens)
  197. @override(RecurrentNetwork)
  198. def forward_rnn(self, inputs: TensorType, state: List[TensorType],
  199. seq_lens: TensorType) -> (TensorType, List[TensorType]):
  200. # Don't show paddings to RNN(?)
  201. # TODO: (sven) For now, only allow, iff time_major=True to not break
  202. # anything retrospectively (time_major not supported previously).
  203. # max_seq_len = inputs.shape[0]
  204. # time_major = self.model_config["_time_major"]
  205. # if time_major and max_seq_len > 1:
  206. # inputs = torch.nn.utils.rnn.pack_padded_sequence(
  207. # inputs, seq_lens,
  208. # batch_first=not time_major, enforce_sorted=False)
  209. self._features, [h, c] = self.lstm(
  210. inputs,
  211. [torch.unsqueeze(state[0], 0),
  212. torch.unsqueeze(state[1], 0)])
  213. # Re-apply paddings.
  214. # if time_major and max_seq_len > 1:
  215. # self._features, _ = torch.nn.utils.rnn.pad_packed_sequence(
  216. # self._features,
  217. # batch_first=not time_major)
  218. model_out = self._logits_branch(self._features)
  219. return model_out, [torch.squeeze(h, 0), torch.squeeze(c, 0)]
  220. @override(ModelV2)
  221. def get_initial_state(self) -> Union[List[np.ndarray], List[TensorType]]:
  222. # Place hidden states on same device as model.
  223. linear = next(self._logits_branch._model.children())
  224. h = [
  225. linear.weight.new(1, self.cell_size).zero_().squeeze(0),
  226. linear.weight.new(1, self.cell_size).zero_().squeeze(0)
  227. ]
  228. return h
  229. @override(ModelV2)
  230. def value_function(self) -> TensorType:
  231. assert self._features is not None, "must call forward() first"
  232. return torch.reshape(self._value_branch(self._features), [-1])