espnet_positional_embedding.py

import math

import torch


class PositionalEncoding(torch.nn.Module):
    """Positional encoding.

    Args:
        d_model (int): Embedding dimension.
        dropout_rate (float): Dropout rate.
        max_len (int): Maximum input length.
        reverse (bool): Whether to reverse the input position.

    """

    def __init__(self, d_model, dropout_rate, max_len=5000, reverse=False):
        """Construct a PositionalEncoding object."""
        super(PositionalEncoding, self).__init__()
        self.d_model = d_model
        self.reverse = reverse
        self.xscale = math.sqrt(self.d_model)
        self.dropout = torch.nn.Dropout(p=dropout_rate)
        self.pe = None
        # Pre-compute the table for max_len positions; extend_pe only rebuilds it
        # if a longer input is seen later.
        self.extend_pe(torch.tensor(0.0).expand(1, max_len))

    def extend_pe(self, x):
        """Reset the positional encodings."""
        if self.pe is not None:
            if self.pe.size(1) >= x.size(1):
                if self.pe.dtype != x.dtype or self.pe.device != x.device:
                    self.pe = self.pe.to(dtype=x.dtype, device=x.device)
                return
        pe = torch.zeros(x.size(1), self.d_model)
        if self.reverse:
            position = torch.arange(
                x.size(1) - 1, -1, -1.0, dtype=torch.float32
            ).unsqueeze(1)
        else:
            position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, self.d_model, 2, dtype=torch.float32)
            * -(math.log(10000.0) / self.d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)  # even dimensions: sine
        pe[:, 1::2] = torch.cos(position * div_term)  # odd dimensions: cosine
        pe = pe.unsqueeze(0)
        self.pe = pe.to(device=x.device, dtype=x.dtype)

    def forward(self, x: torch.Tensor):
        """Add positional encoding.

        Args:
            x (torch.Tensor): Input tensor (batch, time, `*`).

        Returns:
            torch.Tensor: Encoded tensor (batch, time, `*`).

        """
        self.extend_pe(x)
        x = x * self.xscale + self.pe[:, : x.size(1)]
        return self.dropout(x)
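

# Usage sketch (added for illustration; not part of the original ESPnet file).
# Shows how the absolute sinusoidal encoding is applied. The helper name and the
# d_model / dropout_rate / tensor-shape values below are arbitrary assumptions.
def _demo_positional_encoding():
    pos_enc = PositionalEncoding(d_model=256, dropout_rate=0.1)
    x = torch.randn(2, 100, 256)  # (batch, time, d_model)
    # Output keeps the same shape: x is scaled by sqrt(d_model), the first 100
    # rows of the precomputed table are added, then dropout is applied.
    return pos_enc(x)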


class ScaledPositionalEncoding(PositionalEncoding):
    """Scaled positional encoding module.

    See Sec. 3.2 https://arxiv.org/abs/1809.08895

    Args:
        d_model (int): Embedding dimension.
        dropout_rate (float): Dropout rate.
        max_len (int): Maximum input length.

    """

    def __init__(self, d_model, dropout_rate, max_len=5000):
        """Initialize class."""
        super().__init__(d_model=d_model, dropout_rate=dropout_rate, max_len=max_len)
        self.alpha = torch.nn.Parameter(torch.tensor(1.0))

    def reset_parameters(self):
        """Reset parameters."""
        self.alpha.data = torch.tensor(1.0)

    def forward(self, x):
        """Add positional encoding.

        Args:
            x (torch.Tensor): Input tensor (batch, time, `*`).

        Returns:
            torch.Tensor: Encoded tensor (batch, time, `*`).

        """
        self.extend_pe(x)
        x = x + self.alpha * self.pe[:, : x.size(1)]
        return self.dropout(x)
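

# Usage sketch (added for illustration; not part of the original ESPnet file).
# Unlike PositionalEncoding, the input is not scaled by sqrt(d_model); instead the
# learnable scalar ``alpha`` scales the encoding itself and is trained with the
# rest of the model. The helper name and values below are arbitrary assumptions.
def _demo_scaled_positional_encoding():
    pos_enc = ScaledPositionalEncoding(d_model=256, dropout_rate=0.1)
    x = torch.randn(2, 100, 256)  # (batch, time, d_model)
    y = pos_enc(x)  # x + alpha * pe[:, :100], then dropout
    return y, pos_enc.alpha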


class RelPositionalEncoding(PositionalEncoding):
    """Relative positional encoding module.

    See: Appendix B in https://arxiv.org/abs/1901.02860

    Args:
        d_model (int): Embedding dimension.
        dropout_rate (float): Dropout rate.
        max_len (int): Maximum input length.

    """

    def __init__(self, d_model, dropout_rate, max_len=5000):
        """Initialize class."""
        super().__init__(d_model, dropout_rate, max_len, reverse=True)

    def forward(self, x):
        """Compute positional encoding.

        Args:
            x (torch.Tensor): Input tensor (batch, time, `*`).

        Returns:
            torch.Tensor: Encoded tensor (batch, time, `*`).
            torch.Tensor: Positional embedding tensor (1, time, `*`).

        """
        self.extend_pe(x)
        x = x * self.xscale
        pos_emb = self.pe[:, : x.size(1)]
        return self.dropout(x), self.dropout(pos_emb)
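

# Usage sketch (added for illustration; not part of the original ESPnet file).
# The relative variant returns two tensors -- the scaled input and the positional
# embedding table built with reversed position indices -- which a relative
# self-attention layer (https://arxiv.org/abs/1901.02860) consumes separately.
# The helper name and values below are arbitrary assumptions.
def _demo_rel_positional_encoding():
    pos_enc = RelPositionalEncoding(d_model=256, dropout_rate=0.1)
    x = torch.randn(2, 100, 256)  # (batch, time, d_model)
    x_scaled, pos_emb = pos_enc(x)  # shapes: (2, 100, 256) and (1, 100, 256)
    return x_scaled, pos_emb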