relative_multi_head_attention.py

from typing import Union

from ray.rllib.models.torch.misc import SlimFC
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.torch_utils import sequence_mask
from ray.rllib.utils.typing import TensorType

torch, nn = try_import_torch()

class RelativePositionEmbedding(nn.Module):
    """Creates the standard sinusoid encoding matrix for rel. pos encoding.

    Denoted as Phi in [2] and [3].

    Args:
        out_dim (int): The number of nodes to go into the first Transformer
            layer with.
    """

    def __init__(self, out_dim, **kwargs):
        super().__init__()
        self.out_dim = out_dim

        out_range = torch.arange(0, self.out_dim, 2.0)
        inverse_freq = 1 / (10000**(out_range / self.out_dim))
        # Register as a buffer so the frequencies move with the module
        # across devices without being treated as trainable parameters.
        self.register_buffer("inverse_freq", inverse_freq)

    def forward(self, seq_length):
        """Returns the encoding matrix Phi of shape [seq_length, 1, out_dim].

        Args:
            seq_length (int): The max. sequence length (time axis).
        """
        # Positions are enumerated backwards: seq_length - 1, ..., 1, 0.
        pos_input = torch.arange(
            seq_length - 1, -1, -1.0,
            dtype=torch.float).to(self.inverse_freq.device)
        sinusoid_input = torch.einsum("i,j->ij", pos_input, self.inverse_freq)
        pos_embeddings = torch.cat(
            [torch.sin(sinusoid_input),
             torch.cos(sinusoid_input)], dim=-1)
        return pos_embeddings[:, None, :]
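
# Shape sketch (hypothetical numbers, not part of the original file):
#
#     >>> phi = RelativePositionEmbedding(out_dim=64)
#     >>> phi(seq_length=20).shape
#     torch.Size([20, 1, 64])
#
# One sinusoid row per (reversed) position; RelativeMultiHeadAttention below
# projects this matrix per head via `_pos_proj` before it enters the
# relative attention score.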


class RelativeMultiHeadAttention(nn.Module):
    """A RelativeMultiHeadAttention layer as described in [3].

    Uses segment level recurrence with state reuse.
    """

    def __init__(self,
                 in_dim: int,
                 out_dim: int,
                 num_heads: int,
                 head_dim: int,
                 input_layernorm: bool = False,
                 output_activation: Union[str, callable] = None,
                 **kwargs):
        """Initializes a RelativeMultiHeadAttention nn.Module object.

        Args:
            in_dim (int): The input dimension of this module (size of the
                last axis of the incoming tensors).
            out_dim (int): The output dimension of this module. Also known as
                "attention dim".
            num_heads (int): The number of attention heads to use.
                Denoted `H` in [2].
            head_dim (int): The dimension of a single(!) attention head.
                Denoted `D` in [2].
            input_layernorm (bool): Whether to prepend a LayerNorm before
                everything else. Should be True for building a GTrXL.
            output_activation (Union[str, callable]): Optional activation
                function or activation function specifier (str).
                Should be "relu" for GTrXL.
        """
        super().__init__(**kwargs)

        # No bias or non-linearity.
        self._num_heads = num_heads
        self._head_dim = head_dim

        # 3=Query, key, and value inputs.
        self._qkv_layer = SlimFC(
            in_size=in_dim, out_size=3 * num_heads * head_dim, use_bias=False)

        self._linear_layer = SlimFC(
            in_size=num_heads * head_dim,
            out_size=out_dim,
            use_bias=False,
            activation_fn=output_activation)

        # Global learnable content- and position-bias vectors (u and v in [3]).
        self._uvar = nn.Parameter(torch.zeros(num_heads, head_dim))
        self._vvar = nn.Parameter(torch.zeros(num_heads, head_dim))
        nn.init.xavier_uniform_(self._uvar)
        nn.init.xavier_uniform_(self._vvar)
        self.register_parameter("_uvar", self._uvar)
        self.register_parameter("_vvar", self._vvar)

        # Projection of the relative position embeddings into per-head space.
        self._pos_proj = SlimFC(
            in_size=in_dim, out_size=num_heads * head_dim, use_bias=False)
        self._rel_pos_embedding = RelativePositionEmbedding(out_dim)

        self._input_layernorm = None
        if input_layernorm:
            self._input_layernorm = torch.nn.LayerNorm(in_dim)

    def forward(self, inputs: TensorType,
                memory: TensorType = None) -> TensorType:
        T = list(inputs.size())[1]  # length of segment (time)
        H = self._num_heads  # number of attention heads
        d = self._head_dim  # attention head dimension

        # Add previous memory chunk (as const, w/o gradient) to input.
        # Tau (number of (prev) time slices in each memory chunk). A missing
        # memory simply means no extra context timesteps.
        Tau = list(memory.shape)[1] if memory is not None else 0
        if memory is not None:
            inputs = torch.cat((memory.detach(), inputs), dim=1)

        # Apply the Layer-Norm.
        if self._input_layernorm is not None:
            inputs = self._input_layernorm(inputs)

        qkv = self._qkv_layer(inputs)
        queries, keys, values = torch.chunk(input=qkv, chunks=3, dim=-1)
        # Cut out Tau memory timesteps from query.
        queries = queries[:, -T:]

        queries = torch.reshape(queries, [-1, T, H, d])
        keys = torch.reshape(keys, [-1, Tau + T, H, d])
        values = torch.reshape(values, [-1, Tau + T, H, d])

        R = self._pos_proj(self._rel_pos_embedding(Tau + T))
        R = torch.reshape(R, [Tau + T, H, d])

        # b=batch
        # i=query time index (over the T input timesteps)
        # j=key time index (over all Tau + T timesteps)
        # h=head
        # d=head-dim (over which we will reduce-sum)
        score = torch.einsum("bihd,bjhd->bijh", queries + self._uvar, keys)
        pos_score = torch.einsum("bihd,jhd->bijh", queries + self._vvar, R)
        score = score + self.rel_shift(pos_score)
        score = score / d**0.5

        # Causal mask of the same length as the sequence.
        mask = sequence_mask(
            torch.arange(Tau + 1, Tau + T + 1),
            dtype=score.dtype).to(score.device)
        mask = mask[None, :, :, None]

        masked_score = score * mask + 1e30 * (mask.float() - 1.)
        wmat = nn.functional.softmax(masked_score, dim=2)

        out = torch.einsum("bijh,bjhd->bihd", wmat, values)
        shape = list(out.shape)[:2] + [H * d]
        out = torch.reshape(out, shape)
        return self._linear_layer(out)

    @staticmethod
    def rel_shift(x: TensorType) -> TensorType:
        # Transposed version of the shift approach described in [3].
        # https://github.com/kimiyoung/transformer-xl/blob/
        # 44781ed21dbaec88b280f74d9ae2877f52b492a5/tf/model.py#L31
        # Pad one zero slot at the front of the relative-position (j) axis,
        # then reshape and slice so that each successive query row ends up
        # shifted by one position before reshaping back to the input shape.
        x_size = list(x.shape)

        x = torch.nn.functional.pad(x, (0, 0, 1, 0, 0, 0, 0, 0))
        x = torch.reshape(x, [x_size[0], x_size[2] + 1, x_size[1], x_size[3]])
        x = x[:, 1:, :, :]
        x = torch.reshape(x, x_size)

        return x
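

# Usage sketch (illustrative only, not part of the original module; all
# shapes and hyperparameters below are assumed for the example). Runs a
# single RelativeMultiHeadAttention layer over a dummy segment plus a dummy
# memory chunk from the "previous" segment.
if __name__ == "__main__":
    B, T, Tau, dim = 2, 5, 3, 32  # batch, segment len, memory len, model dim

    attn = RelativeMultiHeadAttention(
        in_dim=dim,
        out_dim=dim,
        num_heads=4,
        head_dim=8,
        input_layernorm=True,
        output_activation=nn.ReLU)

    inputs = torch.randn(B, T, dim)
    memory = torch.randn(B, Tau, dim)  # stored activations of prior segment

    out = attn(inputs, memory)
    print(out.shape)  # -> torch.Size([2, 5, 32])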