glow_modules.py

# Glow normalizing-flow building blocks: activation normalization (ActNorm),
# invertible 1x1 convolutions (InvConvNear, InvConv) and WaveNet-conditioned affine
# coupling (CouplingBlock), stacked into the Glow module at the bottom of this file.

import scipy.linalg
from torch.nn import functional as F
import torch
from torch import nn
import numpy as np

from modules.commons.wavenet import WN
from modules.tts.glow import utils


class ActNorm(nn.Module):
    """Per-channel affine normalization with optional data-dependent initialization (ddi)."""

    def __init__(self, channels, ddi=False, **kwargs):
        super().__init__()
        self.channels = channels
        self.initialized = not ddi

        self.logs = nn.Parameter(torch.zeros(1, channels, 1))
        self.bias = nn.Parameter(torch.zeros(1, channels, 1))

    def forward(self, x, x_mask=None, reverse=False, **kwargs):
        if x_mask is None:
            x_mask = torch.ones(x.size(0), 1, x.size(2)).to(device=x.device, dtype=x.dtype)
        x_len = torch.sum(x_mask, [1, 2])
        if not self.initialized:
            self.initialize(x, x_mask)
            self.initialized = True

        if reverse:
            z = (x - self.bias) * torch.exp(-self.logs) * x_mask
            logdet = torch.sum(-self.logs) * x_len
        else:
            z = (self.bias + torch.exp(self.logs) * x) * x_mask
            logdet = torch.sum(self.logs) * x_len  # [b]
        return z, logdet

    def store_inverse(self):
        pass

    def set_ddi(self, ddi):
        self.initialized = not ddi

    def initialize(self, x, x_mask):
        # data-dependent init: estimate per-channel mean/variance over the masked input
        with torch.no_grad():
            denom = torch.sum(x_mask, [0, 2])
            m = torch.sum(x * x_mask, [0, 2]) / denom
            m_sq = torch.sum(x * x * x_mask, [0, 2]) / denom
            v = m_sq - (m ** 2)
            logs = 0.5 * torch.log(torch.clamp_min(v, 1e-6))

            bias_init = (-m * torch.exp(-logs)).view(*self.bias.shape).to(dtype=self.bias.dtype)
            logs_init = (-logs).view(*self.logs.shape).to(dtype=self.logs.dtype)

            self.bias.data.copy_(bias_init)
            self.logs.data.copy_(logs_init)
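
# Illustrative sketch (not part of the original file; the helper name is hypothetical):
# with ddi=True the first forward pass calls initialize(), which stores
# logs = -0.5 * log(var) and bias = -mean / std per channel, so the initial output of
# ActNorm is roughly zero-mean / unit-variance regardless of the input statistics.
def _actnorm_ddi_sketch():
    torch.manual_seed(0)
    x = torch.randn(4, 8, 50) * 3.0 + 1.5            # [B, C, T] with non-trivial statistics
    layer = ActNorm(channels=8, ddi=True)
    z, logdet = layer(x)                             # first call triggers data-dependent init
    # per-channel mean/std of z should be close to 0 and 1; logdet has shape [B]
    print(z.mean(dim=(0, 2)), z.std(dim=(0, 2)), logdet.shape)
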
class InvConvNear(nn.Module):
    def __init__(self, channels, n_split=4, no_jacobian=False, lu=True, n_sqz=2, **kwargs):
        super().__init__()
        assert (n_split % 2 == 0)
        self.channels = channels
        self.n_split = n_split
        self.n_sqz = n_sqz
        self.no_jacobian = no_jacobian

        w_init = torch.qr(torch.FloatTensor(self.n_split, self.n_split).normal_())[0]
        if torch.det(w_init) < 0:
            w_init[:, 0] = -1 * w_init[:, 0]
        self.lu = lu
        if lu:
            # LU decomposition can slightly speed up the inverse
            np_p, np_l, np_u = scipy.linalg.lu(w_init)
            np_s = np.diag(np_u)
            np_sign_s = np.sign(np_s)
            np_log_s = np.log(np.abs(np_s))
            np_u = np.triu(np_u, k=1)
            l_mask = np.tril(np.ones(w_init.shape, dtype=float), -1)
            eye = np.eye(*w_init.shape, dtype=float)

            self.register_buffer('p', torch.Tensor(np_p.astype(float)))
            self.register_buffer('sign_s', torch.Tensor(np_sign_s.astype(float)))
            self.l = nn.Parameter(torch.Tensor(np_l.astype(float)), requires_grad=True)
            self.log_s = nn.Parameter(torch.Tensor(np_log_s.astype(float)), requires_grad=True)
            self.u = nn.Parameter(torch.Tensor(np_u.astype(float)), requires_grad=True)
            self.register_buffer('l_mask', torch.Tensor(l_mask))
            self.register_buffer('eye', torch.Tensor(eye))
        else:
            self.weight = nn.Parameter(w_init)

    def forward(self, x, x_mask=None, reverse=False, **kwargs):
        b, c, t = x.size()
        assert (c % self.n_split == 0)
        if x_mask is None:
            x_mask = 1
            x_len = torch.ones((b,), dtype=x.dtype, device=x.device) * t
        else:
            x_len = torch.sum(x_mask, [1, 2])

        x = x.view(b, self.n_sqz, c // self.n_split, self.n_split // self.n_sqz, t)
        x = x.permute(0, 1, 3, 2, 4).contiguous().view(b, self.n_split, c // self.n_split, t)

        if self.lu:
            self.weight, log_s = self._get_weight()
            logdet = log_s.sum()
            logdet = logdet * (c / self.n_split) * x_len
        else:
            logdet = torch.logdet(self.weight) * (c / self.n_split) * x_len  # [b]

        if reverse:
            if hasattr(self, "weight_inv"):
                weight = self.weight_inv
            else:
                weight = torch.inverse(self.weight.float()).to(dtype=self.weight.dtype)
            logdet = -logdet
        else:
            weight = self.weight
            if self.no_jacobian:
                logdet = 0

        weight = weight.view(self.n_split, self.n_split, 1, 1)
        z = F.conv2d(x, weight)

        z = z.view(b, self.n_sqz, self.n_split // self.n_sqz, c // self.n_split, t)
        z = z.permute(0, 1, 3, 2, 4).contiguous().view(b, c, t) * x_mask
        return z, logdet

    def _get_weight(self):
        l, log_s, u = self.l, self.log_s, self.u
        l = l * self.l_mask + self.eye
        u = u * self.l_mask.transpose(0, 1).contiguous() + torch.diag(self.sign_s * torch.exp(log_s))
        weight = torch.matmul(self.p, torch.matmul(l, u))
        return weight, log_s

    def store_inverse(self):
        weight, _ = self._get_weight()
        self.weight_inv = torch.inverse(weight.float()).to(next(self.parameters()).device)
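
# Illustrative sketch (hypothetical helper, not called anywhere): the LU parameterisation
# above composes W = P @ L @ U with L unit-lower-triangular and the diagonal of U equal to
# sign_s * exp(log_s), so log|det W| = sum(log_s); a forward/reverse pass through the layer
# should therefore be (numerically) the identity.
def _invconvnear_sketch():
    torch.manual_seed(0)
    conv = InvConvNear(channels=8, n_split=4, n_sqz=2)
    w, log_s = conv._get_weight()
    print(torch.allclose(torch.slogdet(w)[1], log_s.sum(), atol=1e-5))   # expect True
    x = torch.randn(2, 8, 20)
    z, _ = conv(x)                                   # mix channels with the 1x1 conv
    x_rec, _ = conv(z, reverse=True)                 # invert it again
    print((x - x_rec).abs().max())                   # expect ~0
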
class InvConv(nn.Module):
    def __init__(self, channels, no_jacobian=False, lu=True, **kwargs):
        super().__init__()
        w_shape = [channels, channels]
        w_init = np.linalg.qr(np.random.randn(*w_shape))[0].astype(float)
        LU_decomposed = lu
        if not LU_decomposed:
            # Sample a random orthogonal matrix:
            self.register_parameter("weight", nn.Parameter(torch.Tensor(w_init)))
        else:
            np_p, np_l, np_u = scipy.linalg.lu(w_init)
            np_s = np.diag(np_u)
            np_sign_s = np.sign(np_s)
            np_log_s = np.log(np.abs(np_s))
            np_u = np.triu(np_u, k=1)
            l_mask = np.tril(np.ones(w_shape, dtype=float), -1)
            eye = np.eye(*w_shape, dtype=float)

            self.register_buffer('p', torch.Tensor(np_p.astype(float)))
            self.register_buffer('sign_s', torch.Tensor(np_sign_s.astype(float)))
            self.l = nn.Parameter(torch.Tensor(np_l.astype(float)))
            self.log_s = nn.Parameter(torch.Tensor(np_log_s.astype(float)))
            self.u = nn.Parameter(torch.Tensor(np_u.astype(float)))
            self.l_mask = torch.Tensor(l_mask)
            self.eye = torch.Tensor(eye)
        self.w_shape = w_shape
        self.LU = LU_decomposed
        self.weight = None

    def get_weight(self, device, reverse):
        w_shape = self.w_shape
        self.p = self.p.to(device)
        self.sign_s = self.sign_s.to(device)
        self.l_mask = self.l_mask.to(device)
        self.eye = self.eye.to(device)
        l = self.l * self.l_mask + self.eye
        u = self.u * self.l_mask.transpose(0, 1).contiguous() + torch.diag(self.sign_s * torch.exp(self.log_s))
        dlogdet = self.log_s.sum()
        if not reverse:
            w = torch.matmul(self.p, torch.matmul(l, u))
        else:
            l = torch.inverse(l.double()).float()
            u = torch.inverse(u.double()).float()
            w = torch.matmul(u, torch.matmul(l, self.p.inverse()))
        return w.view(w_shape[0], w_shape[1], 1), dlogdet

    def forward(self, x, x_mask=None, reverse=False, **kwargs):
        """
        log-det = log|abs(|W|)| * pixels
        """
        b, c, t = x.size()
        if x_mask is None:
            x_len = torch.ones((b,), dtype=x.dtype, device=x.device) * t
        else:
            x_len = torch.sum(x_mask, [1, 2])
        logdet = 0
        if not reverse:
            weight, dlogdet = self.get_weight(x.device, reverse)
            z = F.conv1d(x, weight)
            if logdet is not None:
                logdet = logdet + dlogdet * x_len
            return z, logdet
        else:
            if self.weight is None:
                weight, dlogdet = self.get_weight(x.device, reverse)
            else:
                weight, dlogdet = self.weight, self.dlogdet
            z = F.conv1d(x, weight)
            if logdet is not None:
                logdet = logdet - dlogdet * x_len
            return z, logdet

    def store_inverse(self):
        self.weight, self.dlogdet = self.get_weight('cuda', reverse=True)
class CouplingBlock(nn.Module):
    """Affine coupling layer: the first half of the channels drives a WaveNet that
    predicts an affine transform (m, logs) of the second half."""

    def __init__(self, in_channels, hidden_channels, kernel_size, dilation_rate, n_layers,
                 gin_channels=0, p_dropout=0, sigmoid_scale=False, wn=None):
        super().__init__()
        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.gin_channels = gin_channels
        self.p_dropout = p_dropout
        self.sigmoid_scale = sigmoid_scale

        start = torch.nn.Conv1d(in_channels // 2, hidden_channels, 1)
        start = torch.nn.utils.weight_norm(start)
        self.start = start
        # Initializing last layer to 0 makes the affine coupling layers
        # do nothing at first. This helps with training stability
        end = torch.nn.Conv1d(hidden_channels, in_channels, 1)
        end.weight.data.zero_()
        end.bias.data.zero_()
        self.end = end

        self.wn = WN(hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels, p_dropout)
        if wn is not None:
            self.wn.in_layers = wn.in_layers
            self.wn.res_skip_layers = wn.res_skip_layers

    def forward(self, x, x_mask=None, reverse=False, g=None, **kwargs):
        if x_mask is None:
            x_mask = 1
        x_0, x_1 = x[:, :self.in_channels // 2], x[:, self.in_channels // 2:]

        x = self.start(x_0) * x_mask
        x = self.wn(x, x_mask, g)
        out = self.end(x)

        z_0 = x_0
        m = out[:, :self.in_channels // 2, :]
        logs = out[:, self.in_channels // 2:, :]
        if self.sigmoid_scale:
            logs = torch.log(1e-6 + torch.sigmoid(logs + 2))
        if reverse:
            z_1 = (x_1 - m) * torch.exp(-logs) * x_mask
            logdet = torch.sum(-logs * x_mask, [1, 2])
        else:
            z_1 = (m + torch.exp(logs) * x_1) * x_mask
            logdet = torch.sum(logs * x_mask, [1, 2])
        z = torch.cat([z_0, z_1], 1)
        return z, logdet

    def store_inverse(self):
        self.wn.remove_weight_norm()
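
# Illustrative sketch (hypothetical helper; assumes the repo's WN module is importable and
# accepts the call signature used above): because `end` is zero-initialised, a freshly
# constructed CouplingBlock predicts m = 0 and logs = 0, i.e. it acts as the identity map
# with zero log-determinant at initialisation, which is what stabilises early training.
def _coupling_identity_sketch():
    torch.manual_seed(0)
    block = CouplingBlock(in_channels=8, hidden_channels=16, kernel_size=3,
                          dilation_rate=1, n_layers=2)
    x = torch.randn(2, 8, 20)
    z, logdet = block(x)
    print(torch.allclose(z, x), logdet.abs().max())  # expect True and ~0
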
class Glow(nn.Module):
    """A stack of n_blocks flow steps (ActNorm -> invertible 1x1 conv -> coupling),
    applied after squeezing the input by n_sqz along the time axis."""

    def __init__(self,
                 in_channels,
                 hidden_channels,
                 kernel_size,
                 dilation_rate,
                 n_blocks,
                 n_layers,
                 p_dropout=0.,
                 n_split=4,
                 n_sqz=2,
                 sigmoid_scale=False,
                 gin_channels=0,
                 inv_conv_type='near',
                 share_cond_layers=False,
                 share_wn_layers=0,
                 ):
        super().__init__()
        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_blocks = n_blocks
        self.n_layers = n_layers
        self.p_dropout = p_dropout
        self.n_split = n_split
        self.n_sqz = n_sqz
        self.sigmoid_scale = sigmoid_scale
        self.gin_channels = gin_channels
        self.share_cond_layers = share_cond_layers
        if gin_channels != 0 and share_cond_layers:
            cond_layer = torch.nn.Conv1d(gin_channels * n_sqz, 2 * hidden_channels * n_layers, 1)
            self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name='weight')
        wn = None
        self.flows = nn.ModuleList()
        for b in range(n_blocks):
            self.flows.append(ActNorm(channels=in_channels * n_sqz))
            if inv_conv_type == 'near':
                self.flows.append(InvConvNear(channels=in_channels * n_sqz, n_split=n_split, n_sqz=n_sqz))
            if inv_conv_type == 'invconv':
                self.flows.append(InvConv(channels=in_channels * n_sqz))
            if share_wn_layers > 0:
                if b % share_wn_layers == 0:
                    wn = WN(hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels * n_sqz,
                            p_dropout, share_cond_layers)
            self.flows.append(
                CouplingBlock(
                    in_channels * n_sqz,
                    hidden_channels,
                    kernel_size=kernel_size,
                    dilation_rate=dilation_rate,
                    n_layers=n_layers,
                    gin_channels=gin_channels * n_sqz,
                    p_dropout=p_dropout,
                    sigmoid_scale=sigmoid_scale,
                    wn=wn
                ))

    def forward(self, x, x_mask=None, g=None, reverse=False, return_hiddens=False):
        logdet_tot = 0
        if not reverse:
            flows = self.flows
        else:
            flows = reversed(self.flows)
        if return_hiddens:
            hs = []
        if self.n_sqz > 1:
            x, x_mask_ = utils.squeeze(x, x_mask, self.n_sqz)
            if g is not None:
                g, _ = utils.squeeze(g, x_mask, self.n_sqz)
            x_mask = x_mask_
        if self.share_cond_layers and g is not None:
            g = self.cond_layer(g)
        for f in flows:
            x, logdet = f(x, x_mask, g=g, reverse=reverse)
            if return_hiddens:
                hs.append(x)
            logdet_tot += logdet
        if self.n_sqz > 1:
            x, x_mask = utils.unsqueeze(x, x_mask, self.n_sqz)
        if return_hiddens:
            return x, logdet_tot, hs
        return x, logdet_tot

    def store_inverse(self):
        def remove_weight_norm(m):
            try:
                nn.utils.remove_weight_norm(m)
            except ValueError:  # this module didn't have weight norm
                return
        self.apply(remove_weight_norm)
        for f in self.flows:
            f.store_inverse()
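
# Illustrative sketch (hypothetical __main__ guard, not part of the original file; assumes
# modules.commons.wavenet.WN and modules.tts.glow.utils behave as the calls above imply):
# build a small Glow, map data to latents, then run the flow in reverse and check the
# round trip. The time dimension must be divisible by n_sqz for the squeeze step;
# glow.store_inverse() can optionally be called before decoding to strip weight norm and
# cache the 1x1-conv inverses.
if __name__ == '__main__':
    torch.manual_seed(0)
    glow = Glow(in_channels=8, hidden_channels=16, kernel_size=3, dilation_rate=1,
                n_blocks=2, n_layers=2, n_split=4, n_sqz=2)
    x = torch.randn(2, 8, 40)
    x_mask = torch.ones(2, 1, 40)
    z, logdet = glow(x, x_mask)                      # data -> latent, logdet has shape [B]
    x_rec, _ = glow(z, x_mask, reverse=True)         # latent -> data
    print(logdet.shape, (x - x_rec).abs().max())     # expect [2] and ~0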