# attention_net.py
"""
[1] - Attention Is All You Need - Vaswani, Shazeer, Parmar, Uszkoreit, Jones,
Gomez, Kaiser, Polosukhin - Google Brain/Research, U Toronto - 2017.
https://arxiv.org/pdf/1706.03762.pdf
[2] - Stabilizing Transformers for Reinforcement Learning - E. Parisotto
et al. - DeepMind - 2019. https://arxiv.org/pdf/1910.06764.pdf
[3] - Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context.
Z. Dai, Z. Yang, et al. - Carnegie Mellon U - 2019.
https://www.aclweb.org/anthology/P19-1285.pdf
"""
import gym
from gym.spaces import Box, Discrete, MultiDiscrete
import numpy as np
import tree  # pip install dm_tree
from typing import Dict, Optional, Union

from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.torch.misc import SlimFC
from ray.rllib.models.torch.modules import GRUGate, \
    RelativeMultiHeadAttention, SkipConnection
from ray.rllib.models.torch.recurrent_net import RecurrentNetwork
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.view_requirement import ViewRequirement
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.spaces.space_utils import get_base_struct_from_space
from ray.rllib.utils.torch_utils import flatten_inputs_to_1d_tensor, one_hot
from ray.rllib.utils.typing import ModelConfigDict, TensorType, List

torch, nn = try_import_torch()


class GTrXLNet(RecurrentNetwork, nn.Module):
    """A GTrXL net Model described in [2].

    This is still in an experimental phase.
    Can be used as a drop-in replacement for LSTMs in PPO and IMPALA.
    For an example script, see: `ray/rllib/examples/attention_net.py`.

    To use this network as a replacement for an RNN, configure your Trainer
    as follows:

    Examples:
        >> config["model"]["custom_model"] = GTrXLNet
        >> config["model"]["max_seq_len"] = 10
        >> config["model"]["custom_model_config"] = {
        >>     "num_transformer_units": 1,
        >>     "attention_dim": 32,
        >>     "num_heads": 2,
        >>     "memory_inference": 50,
        >>     "memory_training": 50,
        >> }
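
        A sketch of registering the model under a string name instead of
        passing the class directly (the name "GTrXL" below is an illustrative
        assumption, any unique name works):

        >> from ray.rllib.models import ModelCatalog
        >> ModelCatalog.register_custom_model("GTrXL", GTrXLNet)
        >> config["model"]["custom_model"] = "GTrXL"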
    """

    def __init__(self,
                 observation_space: gym.spaces.Space,
                 action_space: gym.spaces.Space,
                 num_outputs: Optional[int],
                 model_config: ModelConfigDict,
                 name: str,
                 *,
                 num_transformer_units: int = 1,
                 attention_dim: int = 64,
                 num_heads: int = 2,
                 memory_inference: int = 50,
                 memory_training: int = 50,
                 head_dim: int = 32,
                 position_wise_mlp_dim: int = 32,
                 init_gru_gate_bias: float = 2.0):
        """Initializes a GTrXLNet.

        Args:
            num_transformer_units (int): The number of Transformer repeats to
                use (denoted L in [2]).
            attention_dim (int): The input and output dimensions of one
                Transformer unit.
            num_heads (int): The number of attention heads to use in parallel.
                Denoted as `H` in [3].
            memory_inference (int): The number of timesteps to concat (time
                axis) and feed into the next transformer unit as inference
                input. The first transformer unit will receive this number of
                past observations (plus the current one), instead.
            memory_training (int): The number of timesteps to concat (time
                axis) and feed into the next transformer unit as training
                input (plus the actual input sequence of len=max_seq_len).
                The first transformer unit will receive this number of
                past observations (plus the input sequence), instead.
            head_dim (int): The dimension of a single(!) attention head within
                a multi-head attention unit. Denoted as `d` in [3].
            position_wise_mlp_dim (int): The dimension of the hidden layer
                within the position-wise MLP (after the multi-head attention
                block within one Transformer unit). This is the size of the
                first of the two layers within the PositionwiseFeedforward.
                The second layer always has size=`attention_dim`.
            init_gru_gate_bias (float): Initial bias values for the GRU gates
                (two GRUs per Transformer unit, one after the MHA, one after
                the position-wise MLP).
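
        Example:
            With `max_seq_len=10` and `memory_training=50`, each Transformer
            unit sees the 10-timestep input sequence plus 50 timesteps of
            memory (60 timesteps in total) during a training pass.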
        """
        super().__init__(observation_space, action_space, num_outputs,
                         model_config, name)

        nn.Module.__init__(self)

        self.num_transformer_units = num_transformer_units
        self.attention_dim = attention_dim
        self.num_heads = num_heads
        self.memory_inference = memory_inference
        self.memory_training = memory_training
        self.head_dim = head_dim
        self.max_seq_len = model_config["max_seq_len"]
        self.obs_dim = observation_space.shape[0]

        self.linear_layer = SlimFC(
            in_size=self.obs_dim, out_size=self.attention_dim)

        self.layers = [self.linear_layer]

        attention_layers = []
        # 2) Create L Transformer blocks according to [2].
        for i in range(self.num_transformer_units):
            # RelativeMultiHeadAttention part.
            MHA_layer = SkipConnection(
                RelativeMultiHeadAttention(
                    in_dim=self.attention_dim,
                    out_dim=self.attention_dim,
                    num_heads=num_heads,
                    head_dim=head_dim,
                    input_layernorm=True,
                    output_activation=nn.ReLU),
                fan_in_layer=GRUGate(self.attention_dim, init_gru_gate_bias))

            # Position-wise MultiLayerPerceptron part.
            E_layer = SkipConnection(
                nn.Sequential(
                    torch.nn.LayerNorm(self.attention_dim),
                    SlimFC(
                        in_size=self.attention_dim,
                        out_size=position_wise_mlp_dim,
                        use_bias=False,
                        activation_fn=nn.ReLU),
                    SlimFC(
                        in_size=position_wise_mlp_dim,
                        out_size=self.attention_dim,
                        use_bias=False,
                        activation_fn=nn.ReLU)),
                fan_in_layer=GRUGate(self.attention_dim, init_gru_gate_bias))
            # Build a list of all attention layers in order.
            attention_layers.extend([MHA_layer, E_layer])

        # Create a Sequential such that all parameters inside the attention
        # layers are automatically registered with this top-level model.
        self.attention_layers = nn.Sequential(*attention_layers)
        self.layers.extend(attention_layers)

        # Final layers if num_outputs not None.
        self.logits = None
        self.values_out = None
        # Last value output.
        self._value_out = None
        # Postprocess GTrXL output with another hidden layer.
        if self.num_outputs is not None:
            self.logits = SlimFC(
                in_size=self.attention_dim,
                out_size=self.num_outputs,
                activation_fn=nn.ReLU)

            # Value function used by all RLlib Torch RL implementations.
            self.values_out = SlimFC(
                in_size=self.attention_dim, out_size=1, activation_fn=None)
        else:
            self.num_outputs = self.attention_dim

        # Setup trajectory views (`memory-inference` x past memory outs).
        for i in range(self.num_transformer_units):
            space = Box(-1.0, 1.0, shape=(self.attention_dim, ))
            self.view_requirements["state_in_{}".format(i)] = \
                ViewRequirement(
                    "state_out_{}".format(i),
                    shift="-{}:-1".format(self.memory_inference),
                    # Repeat the incoming state every max-seq-len times.
                    batch_repeat_value=self.max_seq_len,
                    space=space)
            self.view_requirements["state_out_{}".format(i)] = \
                ViewRequirement(
                    space=space,
                    used_for_training=False)

    @override(ModelV2)
    def forward(self, input_dict, state: List[TensorType],
                seq_lens: TensorType) -> (TensorType, List[TensorType]):
        assert seq_lens is not None
        # Observations arrive flattened along batch and time ([B*T, ...]).
        observations = input_dict[SampleBatch.OBS]
        # Add the time dim to observations.
        B = len(seq_lens)
        T = observations.shape[0] // B
        observations = torch.reshape(observations,
                                     [-1, T] + list(observations.shape[1:]))

        all_out = observations
        memory_outs = []
        for i in range(len(self.layers)):
            # MHA layers which need memory passed in.
            if i % 2 == 1:
                all_out = self.layers[i](all_out, memory=state[i // 2])
            # Either self.linear_layer (initial obs -> attn. dim layer) or
            # MultiLayerPerceptrons. The output of these layers is always the
            # memory for the next forward pass.
            else:
                all_out = self.layers[i](all_out)
                memory_outs.append(all_out)

        # Discard last output (not needed as a memory since it's the last
        # layer).
        memory_outs = memory_outs[:-1]

        if self.logits is not None:
            out = self.logits(all_out)
            self._value_out = self.values_out(all_out)
            out_dim = self.num_outputs
        else:
            out = all_out
            out_dim = self.attention_dim

        return torch.reshape(out, [-1, out_dim]), [
            torch.reshape(m, [-1, self.attention_dim]) for m in memory_outs
        ]

    # TODO: (sven) Deprecate this once trajectory view API has fully matured.
    @override(RecurrentNetwork)
    def get_initial_state(self) -> List[np.ndarray]:
        return []

    @override(ModelV2)
    def value_function(self) -> TensorType:
        assert self._value_out is not None, \
            "Must call forward first AND must have value branch!"
        return torch.reshape(self._value_out, [-1])


class AttentionWrapper(TorchModelV2, nn.Module):
    """GTrXL wrapper serving as interface for ModelV2s that set use_attention.
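
    This wrapper kicks in when `config["model"]["use_attention"]` is True; it
    sizes its GTrXL sub-module from the `attention_*` model-config keys read
    in `__init__` below. A rough config sketch (the values shown are
    illustrative, not required):

    Examples:
        >> config["model"]["use_attention"] = True
        >> config["model"]["attention_num_transformer_units"] = 1
        >> config["model"]["attention_dim"] = 64
        >> config["model"]["attention_num_heads"] = 2
        >> config["model"]["attention_head_dim"] = 32
        >> config["model"]["attention_memory_inference"] = 50
        >> config["model"]["attention_memory_training"] = 50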
    """

    def __init__(self, obs_space: gym.spaces.Space,
                 action_space: gym.spaces.Space, num_outputs: int,
                 model_config: ModelConfigDict, name: str):

        nn.Module.__init__(self)
        super().__init__(obs_space, action_space, None, model_config, name)

        self.use_n_prev_actions = model_config["attention_use_n_prev_actions"]
        self.use_n_prev_rewards = model_config["attention_use_n_prev_rewards"]

        self.action_space_struct = get_base_struct_from_space(
            self.action_space)
        self.action_dim = 0

        for space in tree.flatten(self.action_space_struct):
            if isinstance(space, Discrete):
                self.action_dim += space.n
            elif isinstance(space, MultiDiscrete):
                self.action_dim += np.sum(space.nvec)
            elif space.shape is not None:
                self.action_dim += int(np.product(space.shape))
            else:
                self.action_dim += int(len(space))
        # Add prev-action/reward nodes to the input to the GTrXL.
        if self.use_n_prev_actions:
            self.num_outputs += self.use_n_prev_actions * self.action_dim
        if self.use_n_prev_rewards:
            self.num_outputs += self.use_n_prev_rewards

        cfg = model_config

        self.attention_dim = cfg["attention_dim"]

        if self.num_outputs is not None:
            in_space = gym.spaces.Box(
                float("-inf"),
                float("inf"),
                shape=(self.num_outputs, ),
                dtype=np.float32)
        else:
            in_space = obs_space

        # Construct GTrXL sub-module w/ num_outputs=None (so it does not
        # create a logits/value output; we'll do this ourselves in this
        # wrapper here).
        self.gtrxl = GTrXLNet(
            in_space,
            action_space,
            None,
            model_config,
            "gtrxl",
            num_transformer_units=cfg["attention_num_transformer_units"],
            attention_dim=self.attention_dim,
            num_heads=cfg["attention_num_heads"],
            head_dim=cfg["attention_head_dim"],
            memory_inference=cfg["attention_memory_inference"],
            memory_training=cfg["attention_memory_training"],
            position_wise_mlp_dim=cfg["attention_position_wise_mlp_dim"],
            init_gru_gate_bias=cfg["attention_init_gru_gate_bias"],
        )

        # Set final num_outputs to correct value (depending on action space).
        self.num_outputs = num_outputs

        # Postprocess GTrXL output with another hidden layer and compute
        # values.
        self._logits_branch = SlimFC(
            in_size=self.attention_dim,
            out_size=self.num_outputs,
            activation_fn=None,
            initializer=torch.nn.init.xavier_uniform_)
        self._value_branch = SlimFC(
            in_size=self.attention_dim,
            out_size=1,
            activation_fn=None,
            initializer=torch.nn.init.xavier_uniform_)

        self.view_requirements = self.gtrxl.view_requirements
        self.view_requirements["obs"].space = self.obs_space

        # Add prev-a/r to this model's view, if required.
        if self.use_n_prev_actions:
            self.view_requirements[SampleBatch.PREV_ACTIONS] = \
                ViewRequirement(
                    SampleBatch.ACTIONS,
                    space=self.action_space,
                    shift="-{}:-1".format(self.use_n_prev_actions))
        if self.use_n_prev_rewards:
            self.view_requirements[SampleBatch.PREV_REWARDS] = \
                ViewRequirement(
                    SampleBatch.REWARDS,
                    shift="-{}:-1".format(self.use_n_prev_rewards))

    @override(RecurrentNetwork)
    def forward(self, input_dict: Dict[str, TensorType],
                state: List[TensorType],
                seq_lens: TensorType) -> (TensorType, List[TensorType]):
        assert seq_lens is not None
        # Push obs through "unwrapped" net's `forward()` first.
        wrapped_out, _ = self._wrapped_forward(input_dict, [], None)

        # Concat. prev-action/reward if required.
        prev_a_r = []

        # Prev actions.
        if self.use_n_prev_actions:
            prev_n_actions = input_dict[SampleBatch.PREV_ACTIONS]
            # If actions are not processed yet (i.e. still in their original
            # form, as they were sent to the environment):
            # Flatten/one-hot them into a 1D tensor.
            if self.model_config["_disable_action_flattening"]:
                # Merge prev n actions into flat tensor.
                flat = flatten_inputs_to_1d_tensor(
                    prev_n_actions,
                    spaces_struct=self.action_space_struct,
                    time_axis=True,
                )
                # Fold time-axis into flattened data.
                flat = torch.reshape(flat, [flat.shape[0], -1])
                prev_a_r.append(flat)
            # If actions are already flattened (but not one-hot'd yet!),
            # one-hot discrete/multi-discrete actions here and concatenate the
            # n most recent actions together.
            else:
                if isinstance(self.action_space, Discrete):
                    for i in range(self.use_n_prev_actions):
                        prev_a_r.append(
                            one_hot(
                                prev_n_actions[:, i].float(),
                                space=self.action_space))
                elif isinstance(self.action_space, MultiDiscrete):
                    for i in range(0, self.use_n_prev_actions,
                                   self.action_space.shape[0]):
                        prev_a_r.append(
                            one_hot(
                                prev_n_actions[:, i:i +
                                               self.action_space.shape[0]]
                                .float(),
                                space=self.action_space))
                else:
                    prev_a_r.append(
                        torch.reshape(
                            prev_n_actions.float(),
                            [-1, self.use_n_prev_actions * self.action_dim]))

        # Prev rewards.
        if self.use_n_prev_rewards:
            prev_a_r.append(
                torch.reshape(input_dict[SampleBatch.PREV_REWARDS].float(),
                              [-1, self.use_n_prev_rewards]))

        # Concat prev. actions + rewards to the "main" input.
        if prev_a_r:
            wrapped_out = torch.cat([wrapped_out] + prev_a_r, dim=1)

        # Then through our GTrXL.
        input_dict["obs_flat"] = input_dict["obs"] = wrapped_out

        self._features, memory_outs = self.gtrxl(input_dict, state, seq_lens)
        model_out = self._logits_branch(self._features)
        return model_out, memory_outs

    @override(ModelV2)
    def get_initial_state(self) -> Union[List[np.ndarray], List[TensorType]]:
        return []

    @override(ModelV2)
    def value_function(self) -> TensorType:
        assert self._features is not None, "Must call forward() first!"
        return torch.reshape(self._value_branch(self._features), [-1])
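

if __name__ == "__main__":
    # Minimal usage sketch (added for illustration; not part of the original
    # module): enable the AttentionWrapper purely via model-config keys.
    # The PPO/CartPole choices below are assumptions; any torch-framework
    # algorithm and env supporting the trajectory view API should work.
    import ray
    from ray.rllib.agents.ppo import PPOTrainer

    ray.init()
    config = {
        "env": "CartPole-v1",
        "framework": "torch",
        "model": {
            "use_attention": True,
            "attention_num_transformer_units": 1,
            "attention_dim": 64,
            "attention_num_heads": 2,
            "attention_head_dim": 32,
            "attention_memory_inference": 50,
            "attention_memory_training": 50,
        },
    }
    trainer = PPOTrainer(config=config)
    # Run a single training iteration and print the result dict.
    print(trainer.train())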