conv.py

import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from modules.commons.layers import LayerNorm, Embedding


class LambdaLayer(nn.Module):
    """Wraps an arbitrary callable so it can be used inside nn.Sequential."""

    def __init__(self, lambd):
        super(LambdaLayer, self).__init__()
        self.lambd = lambd

    def forward(self, x):
        return self.lambd(x)


def init_weights_func(m):
    """Applies Xavier-uniform initialization to the weights of every Conv1d module."""
    classname = m.__class__.__name__
    if classname.find("Conv1d") != -1:
        torch.nn.init.xavier_uniform_(m.weight)


class ResidualBlock(nn.Module):
    """Implements (norm -> conv -> GELU -> 1x1 conv) n times, each with a residual connection."""

    def __init__(self, channels, kernel_size, dilation, n=2, norm_type='bn', dropout=0.0,
                 c_multiple=2, ln_eps=1e-12, left_pad=False):
        super(ResidualBlock, self).__init__()

        if norm_type == 'bn':
            norm_builder = lambda: nn.BatchNorm1d(channels)
        elif norm_type == 'in':
            norm_builder = lambda: nn.InstanceNorm1d(channels, affine=True)
        elif norm_type == 'gn':
            norm_builder = lambda: nn.GroupNorm(8, channels)
        elif norm_type == 'ln':
            norm_builder = lambda: LayerNorm(channels, dim=1, eps=ln_eps)
        else:
            norm_builder = lambda: nn.Identity()

        if left_pad:
            # Causal variant: pad only on the left so the conv never sees future frames.
            self.blocks = [
                nn.Sequential(
                    norm_builder(),
                    nn.ConstantPad1d(((dilation * (kernel_size - 1)) // 2 * 2, 0), 0),
                    nn.Conv1d(channels, c_multiple * channels, kernel_size, dilation=dilation, padding=0),
                    LambdaLayer(lambda x: x * kernel_size ** -0.5),
                    nn.GELU(),
                    nn.Conv1d(c_multiple * channels, channels, 1, dilation=dilation, padding_mode='reflect'),
                )
                for i in range(n)
            ]
        else:
            # Non-causal variant: symmetric reflect padding keeps the sequence length unchanged.
            self.blocks = [
                nn.Sequential(
                    norm_builder(),
                    nn.Conv1d(channels, c_multiple * channels, kernel_size, dilation=dilation,
                              padding=(dilation * (kernel_size - 1)) // 2, padding_mode='reflect'),
                    LambdaLayer(lambda x: x * kernel_size ** -0.5),
                    nn.GELU(),
                    nn.Conv1d(c_multiple * channels, channels, 1, dilation=dilation, padding_mode='reflect'),
                )
                for i in range(n)
            ]
        self.blocks = nn.ModuleList(self.blocks)
        self.dropout = dropout

    def forward(self, x):
        # x: [B, C, T]; frames whose channels are all zero are treated as padding.
        nonpadding = (x.abs().sum(1) > 0).float()[:, None, :]
        for b in self.blocks:
            x_ = b(x)
            if self.dropout > 0 and self.training:
                x_ = F.dropout(x_, self.dropout, training=self.training)
            x = x + x_
            x = x * nonpadding
        return x
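
# Usage sketch for ResidualBlock (illustrative only; the channel count, kernel size,
# and batch/time sizes below are assumptions, not values from this repo's configs):
#
#   block = ResidualBlock(channels=192, kernel_size=5, dilation=1, norm_type='ln')
#   h = torch.randn(4, 192, 100)   # [B, C, T]; all-zero frames stay zero (padding)
#   h = block(h)                   # [B, C, T], shape preserved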


class ConvBlocks(nn.Module):
    """Decodes the expanded phoneme encoding into spectrograms."""

    def __init__(self, hidden_size, out_dims, dilations, kernel_size,
                 norm_type='ln', layers_in_block=2, c_multiple=2,
                 dropout=0.0, ln_eps=1e-5,
                 init_weights=True, is_BTC=True, num_layers=None, post_net_kernel=3,
                 left_pad=False, c_in=None):
        super(ConvBlocks, self).__init__()
        self.is_BTC = is_BTC
        if num_layers is not None:
            # num_layers overrides dilations with a stack of non-dilated blocks.
            dilations = [1] * num_layers
        self.res_blocks = nn.Sequential(
            *[ResidualBlock(hidden_size, kernel_size, d,
                            n=layers_in_block, norm_type=norm_type, c_multiple=c_multiple,
                            dropout=dropout, ln_eps=ln_eps, left_pad=left_pad)
              for d in dilations],
        )
        if norm_type == 'bn':
            norm = nn.BatchNorm1d(hidden_size)
        elif norm_type == 'in':
            norm = nn.InstanceNorm1d(hidden_size, affine=True)
        elif norm_type == 'gn':
            norm = nn.GroupNorm(8, hidden_size)
        elif norm_type == 'ln':
            norm = LayerNorm(hidden_size, dim=1, eps=ln_eps)
        else:
            norm = nn.Identity()  # fall back like ResidualBlock instead of leaving `norm` unbound
        self.last_norm = norm
        if left_pad:
            self.post_net1 = nn.Sequential(
                nn.ConstantPad1d((post_net_kernel // 2 * 2, 0), 0),
                nn.Conv1d(hidden_size, out_dims, kernel_size=post_net_kernel, padding=0),
            )
        else:
            self.post_net1 = nn.Conv1d(hidden_size, out_dims, kernel_size=post_net_kernel,
                                       padding=post_net_kernel // 2, padding_mode='reflect')
        self.c_in = c_in
        if c_in is not None:
            # Optional 1x1 conv to project an arbitrary input dimension to hidden_size.
            self.in_conv = nn.Conv1d(c_in, hidden_size, kernel_size=1, padding_mode='reflect')
        if init_weights:
            self.apply(init_weights_func)

    def forward(self, x, nonpadding=None):
        """
        :param x: [B, T, H] if is_BTC else [B, H, T]
        :return: [B, T, out_dims] if is_BTC else [B, out_dims, T]
        """
        if self.is_BTC:
            x = x.transpose(1, 2)
        if self.c_in is not None:
            x = self.in_conv(x)
        if nonpadding is None:
            nonpadding = (x.abs().sum(1) > 0).float()[:, None, :]
        elif self.is_BTC:
            nonpadding = nonpadding.transpose(1, 2)
        x = self.res_blocks(x) * nonpadding
        x = self.last_norm(x) * nonpadding
        x = self.post_net1(x) * nonpadding
        if self.is_BTC:
            x = x.transpose(1, 2)
        return x
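
# Usage sketch for ConvBlocks (illustrative only; hidden size, dilation pattern,
# and mel dimension below are assumptions, not values from this repo's configs):
#
#   net = ConvBlocks(hidden_size=192, out_dims=80, dilations=[1, 2, 4] * 2, kernel_size=5)
#   x = torch.randn(4, 100, 192)   # [B, T, H], since is_BTC=True by default
#   mel = net(x)                   # [B, T, 80]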


class TextConvEncoder(ConvBlocks):
    """ConvBlocks stack preceded by a scaled token-embedding lookup."""

    def __init__(self, dict_size, hidden_size, out_dims, dilations, kernel_size,
                 norm_type='ln', layers_in_block=2, c_multiple=2,
                 dropout=0.0, ln_eps=1e-5, init_weights=True, num_layers=None, post_net_kernel=3):
        super().__init__(hidden_size, out_dims, dilations, kernel_size,
                         norm_type, layers_in_block, c_multiple,
                         dropout, ln_eps, init_weights, num_layers=num_layers,
                         post_net_kernel=post_net_kernel)
        self.dict_size = dict_size
        if dict_size > 0:
            self.embed_tokens = Embedding(dict_size, hidden_size, 0)
            self.embed_scale = math.sqrt(hidden_size)

    def forward(self, txt_tokens, other_embeds=0):
        """
        :param txt_tokens: [B, T]
        :return: [B, T, C]
        """
        if self.dict_size > 0:
            x = self.embed_scale * self.embed_tokens(txt_tokens)
        else:
            x = txt_tokens
        x = x + other_embeds
        # Token id 0 is the padding index, so it also defines the non-padding mask.
        return super().forward(x, nonpadding=(txt_tokens > 0).float()[..., None])


class ConditionalConvBlocks(ConvBlocks):
    """ConvBlocks whose input is modulated by a conditioning sequence via a conv pre-net."""

    def __init__(self, hidden_size, c_cond, c_out, dilations, kernel_size,
                 norm_type='ln', layers_in_block=2, c_multiple=2,
                 dropout=0.0, ln_eps=1e-5, init_weights=True, is_BTC=True, num_layers=None):
        super().__init__(hidden_size, c_out, dilations, kernel_size,
                         norm_type, layers_in_block, c_multiple,
                         dropout, ln_eps, init_weights, is_BTC=False, num_layers=num_layers)
        self.g_prenet = nn.Conv1d(c_cond, hidden_size, 3, padding=1, padding_mode='reflect')
        self.is_BTC_ = is_BTC
        if init_weights:
            self.g_prenet.apply(init_weights_func)

    def forward(self, x, cond, nonpadding=None):
        if self.is_BTC_:
            x = x.transpose(1, 2)
            cond = cond.transpose(1, 2)
            if nonpadding is not None:
                nonpadding = nonpadding.transpose(1, 2)
        if nonpadding is None:
            # Binary padding mask, consistent with ResidualBlock and ConvBlocks.
            nonpadding = (x.abs().sum(1) > 0).float()[:, None]
        x = x + self.g_prenet(cond)
        x = x * nonpadding
        x = super(ConditionalConvBlocks, self).forward(x)  # parent was built with is_BTC=False, so x is [B, C, T] here
        if self.is_BTC_:
            x = x.transpose(1, 2)
        return x
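

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the original module: it assumes the
    # repo's LayerNorm/Embedding imports resolve, and all hyperparameters below
    # are made up purely to illustrate the expected tensor shapes.
    B, T, H = 2, 50, 192

    enc = TextConvEncoder(dict_size=100, hidden_size=H, out_dims=H,
                          dilations=[1, 2, 4], kernel_size=5)
    tokens = torch.randint(1, 100, (B, T))      # token id 0 is reserved for padding
    enc_out = enc(tokens)                       # [B, T, H]

    dec = ConditionalConvBlocks(hidden_size=H, c_cond=H, c_out=80,
                                dilations=[1, 2, 4], kernel_size=5)
    frames = torch.randn(B, T, H)               # [B, T, H] content to be conditioned
    mel = dec(frames, cond=enc_out)             # [B, T, 80]
    print(enc_out.shape, mel.shape)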