# conformer.py

import torch
from torch import nn
from .espnet_positional_embedding import RelPositionalEncoding
from .espnet_transformer_attn import RelPositionMultiHeadedAttention, MultiHeadedAttention
from .layers import Swish, ConvolutionModule, EncoderLayer, MultiLayeredConv1d
from ..layers import Embedding


def sequence_mask(length, max_length=None):
    if max_length is None:
        max_length = length.max()
    x = torch.arange(max_length, dtype=length.dtype, device=length.device)
    return x.unsqueeze(0) < length.unsqueeze(1)
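
# Illustrative example (not from the original file): for lengths [2, 3] and
# max_length 4, sequence_mask yields one boolean row per sequence:
#   sequence_mask(torch.tensor([2, 3]), 4)
#   -> [[True, True, False, False],
#       [True, True, True,  False]]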

class ConformerLayers(nn.Module):
    def __init__(self, hidden_size, num_layers, kernel_size=9, dropout=0.0, num_heads=4, use_last_norm=True):
        super().__init__()
        self.use_last_norm = use_last_norm
        positionwise_layer = MultiLayeredConv1d
        positionwise_layer_args = (hidden_size, hidden_size * 4, 1, dropout)
        # Each EncoderLayer follows the Conformer "macaron" layout:
        # feed-forward / self-attention / convolution / feed-forward.
        self.encoder_layers = nn.ModuleList([EncoderLayer(
            hidden_size,
            MultiHeadedAttention(num_heads, hidden_size, 0.0),
            positionwise_layer(*positionwise_layer_args),
            positionwise_layer(*positionwise_layer_args),
            ConvolutionModule(hidden_size, kernel_size, Swish()),
            dropout,
        ) for _ in range(num_layers)])
        if self.use_last_norm:
            self.layer_norm = nn.LayerNorm(hidden_size)
        else:
            self.layer_norm = nn.Linear(hidden_size, hidden_size)

    def forward(self, x, x_mask):
        """
        :param x: [B, T, H]
        :param x_mask: [B, T, 1], 1 for valid positions, 0 for padding
        :return: [B, T, H]
        """
        for layer in self.encoder_layers:
            x, _ = layer(x, x_mask)
        x = self.layer_norm(x) * x_mask
        return x
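
# Minimal usage sketch (illustrative; the sizes are arbitrary, and
# EncoderLayer is assumed to return an (x, mask) pair as consumed above):
#   layers = ConformerLayers(hidden_size=256, num_layers=4)
#   x = torch.randn(2, 100, 256)    # [B, T, H]
#   x_mask = torch.ones(2, 100, 1)  # 1 = valid frame, 0 = padding
#   y = layers(x, x_mask)           # [2, 100, 256]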

class ConformerEncoder(ConformerLayers):
    def __init__(self, hidden_size, dict_size=0, in_size=0, strides=(2, 2), num_layers=None):
        conformer_enc_kernel_size = 9
        # NOTE: num_layers has no usable default; callers must pass it explicitly.
        super().__init__(hidden_size, num_layers, conformer_enc_kernel_size)
        self.dict_size = dict_size
        if dict_size != 0:
            # Token-id inputs: embed into the hidden space (0 is the padding id).
            self.embed = Embedding(dict_size, hidden_size, padding_idx=0)
        else:
            # Continuous feature inputs: project into and back out of the hidden space.
            self.seq_proj_in = torch.nn.Linear(in_size, hidden_size)
            self.seq_proj_out = torch.nn.Linear(hidden_size, in_size)
        self.mel_in = torch.nn.Linear(160, hidden_size)
        # Strided conv stack that downsamples the timbre mels in time
        # (each stride s roughly divides T' by s).
        self.mel_pre_net = torch.nn.Sequential(*[
            torch.nn.Conv1d(hidden_size, hidden_size, kernel_size=s * 2, stride=s, padding=s // 2)
            for s in strides
        ])

    def forward(self, seq_out, mels_timbre, other_embeds=0):
        """
        :param seq_out: [B, T] token ids when dict_size != 0, else [B, T, in_size] features
        :param mels_timbre: [B, T', 160] reference mel frames that provide the timbre
        :param other_embeds: optional extra embeddings added to the input projection
        :return: [B, T, H] when dict_size != 0, else [B, T, in_size]
        """
        x_lengths = (seq_out > 0).long().sum(-1)
        x = seq_out
        if self.dict_size != 0:
            x = self.embed(x) + other_embeds  # [B, T, H]
        else:
            x = self.seq_proj_in(x) + other_embeds  # [B, T, H]
        mels_timbre = self.mel_in(mels_timbre).transpose(1, 2)
        mels_timbre = self.mel_pre_net(mels_timbre).transpose(1, 2)
        T_out = x.size(1)
        if self.dict_size != 0:
            # Valid-length mask over the concatenated [timbre | tokens] sequence.
            x_mask = torch.unsqueeze(sequence_mask(x_lengths + mels_timbre.size(1), x.size(1) + mels_timbre.size(1)), 2).to(x.dtype)
        else:
            # Timbre frames are always valid; feature steps with any non-zero
            # activation are treated as valid.
            x_mask = torch.cat((torch.ones(x.size(0), mels_timbre.size(1), 1).to(x.device), (x.abs().sum(2) > 0).float()[:, :, None]), dim=1)
        x = torch.cat((mels_timbre, x), 1)
        x = super().forward(x, x_mask)
        # Keep only the last T_out steps, i.e. drop the prepended timbre frames.
        if self.dict_size != 0:
            x = x[:, -T_out:, :]
        else:
            x = self.seq_proj_out(x[:, -T_out:, :])
        return x
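
# Shape walkthrough (illustrative, assuming dict_size != 0 and strides=(2, 2)):
#   tokens      [B, T]        --embed-->   [B, T, H]
#   mels_timbre [B, T', 160]  --mel_in-->  [B, T', H]  --mel_pre_net--> [B, ~T'/4, H]
#   concat along time -> [B, ~T'/4 + T, H] --Conformer stack--> same shape
#   slice the last T steps -> [B, T, H]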

class ConformerDecoder(ConformerLayers):
    def __init__(self, hidden_size, num_layers):
        conformer_dec_kernel_size = 9
        super().__init__(hidden_size, num_layers, conformer_dec_kernel_size)
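
# Usage sketch (illustrative): the decoder is the plain Conformer stack with
# kernel size 9, so it is driven exactly like ConformerLayers:
#   dec = ConformerDecoder(hidden_size=256, num_layers=4)
#   y = dec(x, x_mask)  # [B, T, H] in, [B, T, H] out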