# rnn.py

import torch
from torch import nn
import torch.nn.functional as F


class PreNet(nn.Module):
    def __init__(self, in_dims, fc1_dims=256, fc2_dims=128, dropout=0.5):
        super().__init__()
        self.fc1 = nn.Linear(in_dims, fc1_dims)
        self.fc2 = nn.Linear(fc1_dims, fc2_dims)
        self.p = dropout

    def forward(self, x):
        x = self.fc1(x)
        x = F.relu(x)
        x = F.dropout(x, self.p, training=self.training)
        x = self.fc2(x)
        x = F.relu(x)
        x = F.dropout(x, self.p, training=self.training)
        return x
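

# Illustrative shape check (not part of the original module): PreNet acts on
# the last axis, so a (batch, time, in_dims) input comes out as
# (batch, time, fc2_dims). The sizes below are hypothetical.
def _demo_prenet():
    net = PreNet(in_dims=80)
    x = torch.randn(2, 50, 80)       # (batch, time, features)
    y = net(x)
    assert y.shape == (2, 50, 128)   # fc2_dims defaults to 128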


class HighwayNetwork(nn.Module):
    def __init__(self, size):
        super().__init__()
        self.W1 = nn.Linear(size, size)
        self.W2 = nn.Linear(size, size)
        self.W1.bias.data.fill_(0.)

    def forward(self, x):
        x1 = self.W1(x)
        x2 = self.W2(x)
        g = torch.sigmoid(x2)
        y = g * F.relu(x1) + (1. - g) * x
        return y
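

# Hedged reading of the highway layer (demo names are made up): the sigmoid
# gate g mixes the transformed path relu(W1 x) with the untouched input x,
# so the layer can learn to fall back to an identity mapping.
def _demo_highway():
    hn = HighwayNetwork(size=64)
    x = torch.randn(2, 10, 64)
    assert hn(x).shape == x.shape    # highway layers preserve shape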


class BatchNormConv(nn.Module):
    def __init__(self, in_channels, out_channels, kernel, relu=True):
        super().__init__()
        self.conv = nn.Conv1d(in_channels, out_channels, kernel,
                              stride=1, padding=kernel // 2, bias=False)
        self.bnorm = nn.BatchNorm1d(out_channels)
        self.relu = relu

    def forward(self, x):
        x = self.conv(x)
        x = F.relu(x) if self.relu else x
        return self.bnorm(x)
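

# Length note, shown as a sketch with made-up sizes: padding=kernel // 2
# preserves the time axis for odd kernels but adds one frame for even
# kernels (T -> T + 1), which is why CBHG below trims back to seq_len.
def _demo_batchnormconv_lengths():
    x = torch.randn(2, 8, 20)
    assert BatchNormConv(8, 8, kernel=3)(x).shape[-1] == 20
    assert BatchNormConv(8, 8, kernel=4)(x).shape[-1] == 21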


class ConvNorm(torch.nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1,
                 padding=None, dilation=1, bias=True, w_init_gain='linear'):
        super().__init__()
        if padding is None:
            assert kernel_size % 2 == 1
            padding = int(dilation * (kernel_size - 1) / 2)
        self.conv = torch.nn.Conv1d(in_channels, out_channels,
                                    kernel_size=kernel_size, stride=stride,
                                    padding=padding, dilation=dilation,
                                    bias=bias)
        torch.nn.init.xavier_uniform_(
            self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain))

    def forward(self, signal):
        return self.conv(signal)


class CBHG(nn.Module):
    def __init__(self, K, in_channels, channels, proj_channels, num_highways):
        super().__init__()

        # List of all rnns to call `flatten_parameters()` on
        self._to_flatten = []

        self.bank_kernels = list(range(1, K + 1))
        self.conv1d_bank = nn.ModuleList()
        for k in self.bank_kernels:
            conv = BatchNormConv(in_channels, channels, k)
            self.conv1d_bank.append(conv)

        self.maxpool = nn.MaxPool1d(kernel_size=2, stride=1, padding=1)

        self.conv_project1 = BatchNormConv(len(self.bank_kernels) * channels, proj_channels[0], 3)
        self.conv_project2 = BatchNormConv(proj_channels[0], proj_channels[1], 3, relu=False)

        # Fix the highway input if necessary
        if proj_channels[-1] != channels:
            self.highway_mismatch = True
            self.pre_highway = nn.Linear(proj_channels[-1], channels, bias=False)
        else:
            self.highway_mismatch = False

        self.highways = nn.ModuleList()
        for i in range(num_highways):
            hn = HighwayNetwork(channels)
            self.highways.append(hn)

        self.rnn = nn.GRU(channels, channels, batch_first=True, bidirectional=True)
        self._to_flatten.append(self.rnn)

        # Avoid fragmentation of RNN parameters and associated warning
        self._flatten_parameters()

    def forward(self, x):
        # Although we `_flatten_parameters()` on init, when using DataParallel
        # the model gets replicated, making it no longer guaranteed that the
        # weights are contiguous in GPU memory. Hence, we must call it again.
        self._flatten_parameters()

        # Save these for later
        residual = x
        seq_len = x.size(-1)
        conv_bank = []

        # Convolution bank: kernels 1..K applied to the same input
        for conv in self.conv1d_bank:
            c = conv(x)
            conv_bank.append(c[:, :, :seq_len])  # trim the extra frame from even kernels

        # Stack along the channel axis
        conv_bank = torch.cat(conv_bank, dim=1)

        # Drop the trailing padding frame so the length matches the residual
        x = self.maxpool(conv_bank)[:, :, :seq_len]

        # Conv1d projections
        x = self.conv_project1(x)
        x = self.conv_project2(x)

        # Residual connection
        x = x + residual

        # Through the highways
        x = x.transpose(1, 2)
        if self.highway_mismatch:
            x = self.pre_highway(x)
        for h in self.highways:
            x = h(x)

        # And then the RNN
        x, _ = self.rnn(x)
        return x

    def _flatten_parameters(self):
        """Calls `flatten_parameters` on all the RNNs in this module. Used
        to improve efficiency and avoid PyTorch yelling at us."""
        for m in self._to_flatten:
            m.flatten_parameters()
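

# Illustrative end-to-end shape check for CBHG with hypothetical sizes: the
# bidirectional GRU doubles the channel count, so a (B, channels, T) input
# yields a (B, T, 2 * channels) output. The residual addition requires
# proj_channels[-1] == in_channels.
def _demo_cbhg():
    cbhg = CBHG(K=8, in_channels=128, channels=128,
                proj_channels=[256, 128], num_highways=4)
    x = torch.randn(2, 128, 40)      # (batch, channels, time)
    assert cbhg(x).shape == (2, 40, 256)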


class TacotronEncoder(nn.Module):
    def __init__(self, embed_dims, num_chars, cbhg_channels, K, num_highways, dropout):
        super().__init__()
        self.embedding = nn.Embedding(num_chars, embed_dims)
        self.pre_net = PreNet(embed_dims, embed_dims, embed_dims, dropout=dropout)
        self.cbhg = CBHG(K=K, in_channels=cbhg_channels, channels=cbhg_channels,
                         proj_channels=[cbhg_channels, cbhg_channels],
                         num_highways=num_highways)
        self.proj_out = nn.Linear(cbhg_channels * 2, cbhg_channels)

    def forward(self, x):
        x = self.embedding(x)
        x = self.pre_net(x)
        x.transpose_(1, 2)  # (B, T, C) -> (B, C, T) for the conv bank
        x = self.cbhg(x)
        x = self.proj_out(x)
        return x
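

# Hypothetical usage sketch: character IDs in, one encoder vector per input
# step out. Note that embed_dims must match cbhg_channels, because the CBHG
# adds its input back as a residual.
def _demo_tacotron_encoder():
    enc = TacotronEncoder(embed_dims=128, num_chars=70, cbhg_channels=128,
                          K=8, num_highways=4, dropout=0.5)
    tokens = torch.randint(0, 70, (2, 30))   # (batch, time) character IDs
    assert enc(tokens).shape == (2, 30, 128)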


class RNNEncoder(nn.Module):
    def __init__(self, num_chars, embedding_dim, n_convolutions=3, kernel_size=5):
        super().__init__()
        self.embedding = nn.Embedding(num_chars, embedding_dim, padding_idx=0)
        convolutions = []
        for _ in range(n_convolutions):
            conv_layer = nn.Sequential(
                ConvNorm(embedding_dim,
                         embedding_dim,
                         kernel_size=kernel_size, stride=1,
                         padding=int((kernel_size - 1) / 2),
                         dilation=1, w_init_gain='relu'),
                nn.BatchNorm1d(embedding_dim))
            convolutions.append(conv_layer)
        self.convolutions = nn.ModuleList(convolutions)
        self.lstm = nn.LSTM(embedding_dim, embedding_dim // 2, 1,
                            batch_first=True, bidirectional=True)

    def forward(self, x):
        # Token ID 0 is padding, so the non-zero count gives each length
        input_lengths = (x > 0).sum(-1)
        # pack_padded_sequence expects lengths on the CPU, hence the conversion
        input_lengths = input_lengths.cpu().numpy()

        x = self.embedding(x)
        x = x.transpose(1, 2)  # [B, H, T]
        for conv in self.convolutions:
            x = F.dropout(F.relu(conv(x)), 0.5, self.training) + x  # residual conv block
        x = x.transpose(1, 2)  # [B, T, H]

        x = nn.utils.rnn.pack_padded_sequence(x, input_lengths, batch_first=True, enforce_sorted=False)
        self.lstm.flatten_parameters()
        outputs, _ = self.lstm(x)
        outputs, _ = nn.utils.rnn.pad_packed_sequence(outputs, batch_first=True)
        return outputs
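

# Illustrative padded-batch sketch with made-up sizes: token ID 0 marks
# padding, lengths are inferred from the non-zero entries, and the packed
# LSTM skips each padded tail before the output is padded back out.
def _demo_rnn_encoder():
    enc = RNNEncoder(num_chars=70, embedding_dim=128)
    x = torch.randint(1, 70, (2, 30))
    x[1, 20:] = 0                    # second sequence has length 20
    assert enc(x).shape == (2, 30, 128)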


class DecoderRNN(torch.nn.Module):
    def __init__(self, hidden_size, decoder_rnn_dim, dropout):
        super().__init__()
        self.in_conv1d = nn.Sequential(
            torch.nn.Conv1d(
                in_channels=hidden_size,
                out_channels=hidden_size,
                kernel_size=9, padding=4,
            ),
            torch.nn.ReLU(),
            torch.nn.Conv1d(
                in_channels=hidden_size,
                out_channels=hidden_size,
                kernel_size=9, padding=4,
            ),
        )
        self.ln = nn.LayerNorm(hidden_size)
        if decoder_rnn_dim == 0:
            decoder_rnn_dim = hidden_size * 2
        # Note: with num_layers=1 the LSTM `dropout` argument has no effect
        # (it only applies between stacked layers) and PyTorch warns about it.
        self.rnn = torch.nn.LSTM(
            input_size=hidden_size,
            hidden_size=decoder_rnn_dim,
            num_layers=1,
            batch_first=True,
            bidirectional=True,
            dropout=dropout,
        )
        self.rnn.flatten_parameters()
        self.conv1d = torch.nn.Conv1d(
            in_channels=decoder_rnn_dim * 2,
            out_channels=hidden_size,
            kernel_size=3,
            padding=1,
        )

    def forward(self, x):
        # Padded frames are all-zero vectors; the mask has shape [B, T, 1]
        input_masks = x.abs().sum(-1).ne(0).data[:, :, None]
        input_lengths = input_masks.sum([-1, -2])
        input_lengths = input_lengths.cpu().numpy()

        x = self.in_conv1d(x.transpose(1, 2)).transpose(1, 2)
        x = self.ln(x)
        x = nn.utils.rnn.pack_padded_sequence(x, input_lengths, batch_first=True, enforce_sorted=False)
        self.rnn.flatten_parameters()
        x, _ = self.rnn(x)  # [B, T, C]
        x, _ = nn.utils.rnn.pad_packed_sequence(x, batch_first=True)
        x = x * input_masks
        pre_mel = self.conv1d(x.transpose(1, 2)).transpose(1, 2)  # [B, T, C]
        pre_mel = pre_mel * input_masks
        return pre_mel
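

# Hypothetical smoke test: with decoder_rnn_dim=0 the LSTM width defaults to
# hidden_size * 2, and the closing Conv1d projects back to hidden_size, so
# the module is shape-preserving over (batch, time, hidden_size) inputs.
def _demo_decoder_rnn():
    dec = DecoderRNN(hidden_size=128, decoder_rnn_dim=0, dropout=0.0)
    x = torch.randn(2, 30, 128)      # non-zero frames, so the mask is all ones
    assert dec(x).shape == (2, 30, 128)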