res_flow.py 2.1 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061
  1. import torch
  2. from torch import nn
  3. from modules.commons.conv import ConditionalConvBlocks
  4. from modules.commons.wavenet import WN
  class FlipLayer(nn.Module):
      """Reverse the order of the channels (dim 1) of the input.

      Flipping is its own inverse, so the same operation serves both the
      forward and the reverse pass. Any extra positional/keyword arguments
      (such as a padding mask, ``cond`` or ``reverse``) are accepted and
      ignored, which lets this layer share a call signature with the
      coupling layers it is interleaved with.
      """

      def forward(self, x, *args, **kwargs):
          flipped = torch.flip(x, dims=(1,))
          return flipped
  class CouplingLayer(nn.Module):
      """Additive coupling layer over the channel axis (dim 1).

      The input is split into ``x0`` (the first ``c_in // 2`` channels) and
      ``x1`` (the remaining channels). ``x0`` passes through unchanged and
      drives an inner network that predicts an additive shift ``m`` for
      ``x1``. The forward direction computes ``x1 + m`` and the reverse
      direction computes ``x1 - m``, so the layer is exactly invertible given
      the same ``nonpadding`` and ``cond``. Only the transformed tensor is
      returned; a pure additive shift contributes no log-determinant.

      Args:
          c_in: number of input channels; must be at least 2.
          hidden_size: channel width of the inner network.
          kernel_size: kernel size forwarded to the inner network.
          n_layers: depth forwarded to the inner network.
          p_dropout: dropout forwarded to ``WN``; not used when
              ``nn_type='conv'``.
          c_in_g: number of conditioning channels forwarded to the inner
              network (presumably 0 means unconditioned; confirm against
              ``WN`` / ``ConditionalConvBlocks``).
          nn_type: ``'wn'`` to use ``WN`` or ``'conv'`` to use
              ``ConditionalConvBlocks`` as the inner network.

      Raises:
          ValueError: if ``c_in < 2`` or ``nn_type`` is not recognised.
      """

      def __init__(self, c_in, hidden_size, kernel_size, n_layers, p_dropout=0, c_in_g=0, nn_type='wn'):
          super().__init__()
          # Validate up front: an unknown nn_type used to leave self.enc unset
          # and only failed later, inside forward().
          if c_in < 2:
              raise ValueError(f"CouplingLayer needs at least 2 channels, got c_in={c_in}")
          if nn_type not in ('wn', 'conv'):
              raise ValueError(f"Unknown nn_type {nn_type!r}; expected 'wn' or 'conv'")
          self.channels = c_in
          self.hidden_size = hidden_size
          self.kernel_size = kernel_size
          self.n_layers = n_layers
          self.c_half = c_in // 2
          # x1 holds the remaining channels; this is c_half + 1 when c_in is odd.
          c_rest = c_in - self.c_half
          self.pre = nn.Conv1d(self.c_half, hidden_size, 1)
          if nn_type == 'wn':
              self.enc = WN(hidden_size, kernel_size, 1, n_layers, p_dropout=p_dropout,
                            c_cond=c_in_g)
          else:
              self.enc = ConditionalConvBlocks(
                  hidden_size, c_in_g, hidden_size, None, kernel_size,
                  layers_in_block=1, is_BTC=False, num_layers=n_layers)
          # The shift must have exactly as many channels as x1. Sizing this to
          # c_half broke (or silently broadcast) for odd c_in.
          self.post = nn.Conv1d(hidden_size, c_rest, 1)

      def forward(self, x, nonpadding, cond=None, reverse=False):
          """Apply the coupling (``reverse=False``) or its inverse (``reverse=True``).

          Args:
              x: input with channels on dim 1, presumably (batch, c_in, time);
                  confirm with callers.
              nonpadding: mask multiplied into the hidden state and the output.
                  Presumably (batch, 1, time) with 1 on valid frames; confirm.
              cond: optional conditioning passed to the inner network.
              reverse: if True, subtract the predicted shift instead of adding it.

          Returns:
              The recombined ``[x0, x1']`` tensor (``c_in`` channels),
              multiplied by ``nonpadding``.
          """
          x0, x1 = x[:, :self.c_half], x[:, self.c_half:]
          h = self.pre(x0) * nonpadding
          h = self.enc(h, nonpadding=nonpadding, cond=cond)
          m = self.post(h)
          x1 = x1 - m if reverse else x1 + m
          return torch.cat([x0, x1], 1) * nonpadding
  class ResFlow(nn.Module):
      """Invertible flow built from coupling layers interleaved with channel flips.

      Each of the ``n_flow_steps`` steps appends a ``CouplingLayer``
      followed by a ``FlipLayer``. The flip reverses the channel order so
      that, for even ``c_in``, the next coupling transforms the half the
      previous one passed through unchanged. Running with ``reverse=True``
      walks the same layers in the opposite order, each in its inverse
      direction.

      Args:
          c_in: number of channels of the flowed tensor.
          hidden_size: hidden width of every coupling layer.
          kernel_size: kernel size of every coupling layer's inner network.
          n_flow_layers: depth of every coupling layer's inner network.
          n_flow_steps: number of (coupling, flip) pairs.
          c_cond: conditioning channels, forwarded as ``c_in_g``.
          nn_type: inner network type forwarded to ``CouplingLayer``.
          p_dropout: dropout forwarded to every ``CouplingLayer``. Defaults to
              0, matching the previously hard-coded value.
      """

      def __init__(self,
                   c_in,
                   hidden_size,
                   kernel_size,
                   n_flow_layers,
                   n_flow_steps=4,
                   c_cond=0,
                   nn_type='wn',
                   p_dropout=0):
          super().__init__()
          self.flows = nn.ModuleList()
          for _ in range(n_flow_steps):
              self.flows.append(
                  CouplingLayer(c_in, hidden_size, kernel_size, n_flow_layers,
                                p_dropout=p_dropout, c_in_g=c_cond, nn_type=nn_type))
              self.flows.append(FlipLayer())

      def forward(self, x, nonpadding, cond=None, reverse=False):
          """Run the flow forward, or invert it when ``reverse`` is True.

          ``nonpadding`` and ``cond`` are passed unchanged to every layer.
          The result has the same layout as ``x``.
          """
          flows = reversed(self.flows) if reverse else self.flows
          for flow in flows:
              x = flow(x, nonpadding, cond=cond, reverse=reverse)
          return x