import scipy
from scipy import linalg
from torch.nn import functional as F
import torch
from torch import nn
import numpy as np
from modules.audio2motion.transformer_models import FFTBlocks
import modules.audio2motion.utils as utils
from modules.audio2motion.flow_base import Glow, WN, ResidualCouplingBlock
import torch.distributions as dist
from modules.audio2motion.cnn_models import LambdaLayer, LayerNorm
from vector_quantize_pytorch import VectorQuantize


class FVAEEncoder(nn.Module):
    def __init__(self, in_channels, hidden_channels, latent_channels, kernel_size,
                 n_layers, gin_channels=0, p_dropout=0, strides=[4]):
        super().__init__()
        self.strides = strides
        self.hidden_size = hidden_channels
        # Strided convolutions compress the time axis by prod(strides).
        self.pre_net = nn.Sequential(*[
            nn.Conv1d(in_channels, hidden_channels, kernel_size=s * 2, stride=s, padding=s // 2)
            if i == 0 else
            nn.Conv1d(hidden_channels, hidden_channels, kernel_size=s * 2, stride=s, padding=s // 2)
            for i, s in enumerate(strides)
        ])
        self.wn = WN(hidden_channels, kernel_size, 1, n_layers, gin_channels, p_dropout)
        self.out_proj = nn.Conv1d(hidden_channels, latent_channels * 2, 1)
        self.latent_channels = latent_channels

    def forward(self, x, x_mask, g):
        x = self.pre_net(x)
        # Downsample the mask to the compressed frame rate and crop it to the conv output length.
        x_mask = x_mask[:, :, ::np.prod(self.strides)][:, :, :x.shape[-1]]
        x = x * x_mask
        x = self.wn(x, x_mask, g) * x_mask
        x = self.out_proj(x)
        # Split into mean and log-std, then sample with the reparameterization trick.
        m, logs = torch.split(x, self.latent_channels, dim=1)
        z = (m + torch.randn_like(m) * torch.exp(logs))
        return z, m, logs, x_mask


class FVAEDecoder(nn.Module):
    def __init__(self, latent_channels, hidden_channels, out_channels, kernel_size,
                 n_layers, gin_channels=0, p_dropout=0,
                 strides=[4]):
        super().__init__()
        self.strides = strides
        self.hidden_size = hidden_channels
        self.pre_net = nn.Sequential(*[
            nn.ConvTranspose1d(latent_channels, hidden_channels, kernel_size=s, stride=s)
            if i == 0 else
            nn.ConvTranspose1d(hidden_channels, hidden_channels, kernel_size=s, stride=s)
            for i, s in enumerate(strides)
        ])
        self.wn = WN(hidden_channels, kernel_size, 1, n_layers, gin_channels, p_dropout)
        self.out_proj = nn.Conv1d(hidden_channels, out_channels, 1)

    def forward(self, x, x_mask, g):
        x = self.pre_net(x)
        x = x * x_mask
        x = self.wn(x, x_mask, g) * x_mask
        x = self.out_proj(x)
        return x
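

# Shape sketch (illustrative sizes, not from the original file): with the default
# strides=[4], FVAEEncoder compresses the time axis by prod(strides) = 4, and
# FVAEDecoder's transposed convolutions undo that compression, so the pair maps
# [B, C_in, T] -> latent [B, C_latent, T // 4] -> [B, C_out, T]. As wired up in
# VQVAE below, the encoder is conditioned at the compressed rate and the decoder
# at the full rate:
#
#   enc = FVAEEncoder(in_channels=64, hidden_channels=256, latent_channels=256,
#                     kernel_size=3, n_layers=5, gin_channels=80)
#   dec = FVAEDecoder(latent_channels=256, hidden_channels=256, out_channels=64,
#                     kernel_size=3, n_layers=5, gin_channels=80)
#   x, x_mask = torch.randn([2, 64, 32]), torch.ones([2, 1, 32])
#   z, m, logs, mask_sqz = enc(x, x_mask, torch.randn([2, 80, 8]))   # z: [2, 256, 8]
#   x_hat = dec(z, x_mask, torch.randn([2, 80, 32]))                 # x_hat: [2, 64, 32]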


class VQVAE(nn.Module):
    def __init__(self,
                 in_out_channels=64, hidden_channels=256, latent_size=16,
                 kernel_size=3, enc_n_layers=5, dec_n_layers=5, gin_channels=80, strides=[4,],
                 sqz_prior=False):
        super().__init__()
        self.in_out_channels = in_out_channels
        self.strides = strides
        self.hidden_size = hidden_channels
        self.latent_size = latent_size
        # Downsample the condition g to the same frame rate as the encoder latent.
        self.g_pre_net = nn.Sequential(*[
            nn.Conv1d(gin_channels, gin_channels, kernel_size=s * 2, stride=s, padding=s // 2)
            for i, s in enumerate(strides)
        ])
        self.encoder = FVAEEncoder(in_out_channels, hidden_channels, hidden_channels, kernel_size,
                                   enc_n_layers, gin_channels, strides=strides)
        # if use_prior_glow:
        #     self.prior_flow = ResidualCouplingBlock(
        #         latent_size, glow_hidden, glow_kernel_size, 1, glow_n_blocks, 4, gin_channels=gin_channels)
        self.vq = VectorQuantize(dim=hidden_channels, codebook_size=256, codebook_dim=16)
        self.decoder = FVAEDecoder(hidden_channels, hidden_channels, in_out_channels, kernel_size,
                                   dec_n_layers, gin_channels, strides=strides)
        self.prior_dist = dist.Normal(0, 1)
        self.sqz_prior = sqz_prior

    def forward(self, x=None, x_mask=None, g=None, infer=False, **kwargs):
        """
        :param x: [B, T, C_in_out]
        :param x_mask: [B, T]
        :param g: [B, T, C_g]
        :return: (x_recon, commit_loss, z_p, m_q, logs_q) when infer is False,
                 (x_recon, z_p) when infer is True
        """
        x_mask = x_mask[:, None, :]  # [B, 1, T]
        g = g.transpose(1, 2)  # [B, C_g, T]
        g_for_sqz = g
        g_sqz = self.g_pre_net(g_for_sqz)
        if not infer:
            # Training: encode the target, vector-quantize the latent, then reconstruct.
            x = x.transpose(1, 2)  # [B, C, T]
            z_q, m_q, logs_q, x_mask_sqz = self.encoder(x, x_mask, g_sqz)
            if self.sqz_prior:
                z_q = F.interpolate(z_q, scale_factor=1 / 8)
            z_p, idx, commit_loss = self.vq(z_q.transpose(1, 2))
            if self.sqz_prior:
                z_p = F.interpolate(z_p.transpose(1, 2), scale_factor=8).transpose(1, 2)
            x_recon = self.decoder(z_p.transpose(1, 2), x_mask, g)
            return x_recon.transpose(1, 2), commit_loss, z_p.transpose(1, 2), m_q.transpose(1, 2), logs_q.transpose(1, 2)
        else:
            # Inference: draw random codebook entries as the prior and decode them.
            bs, t = g_sqz.shape[0], g_sqz.shape[2]
            if self.sqz_prior:
                t = t // 8
            latent_shape = [int(bs * t)]
            latent_idx = torch.randint(0, 256, latent_shape).to(self.vq.codebook.device)
            # latent_idx = torch.ones_like(latent_idx, dtype=torch.long)
            # z_p = torch.gather(self.vq.codebook, 0, latent_idx)  # self.vq.codebook[latent_idx]
            z_p = self.vq.codebook[latent_idx]
            z_p = z_p.reshape([bs, t, -1])
            z_p = self.vq.project_out(z_p)
            if self.sqz_prior:
                z_p = F.interpolate(z_p.transpose(1, 2), scale_factor=8).transpose(1, 2)
            x_recon = self.decoder(z_p.transpose(1, 2), 1, g)  # no masking at inference time
            return x_recon.transpose(1, 2), z_p.transpose(1, 2)
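

# Usage sketch (illustrative, not part of the original training code): inputs are
# laid out [B, T, C] and the condition g is expected to have gin_channels channels
# (80 with the defaults above). With infer=False the target x is encoded,
# vector-quantized and reconstructed; with infer=True the latent is sampled from
# random codebook entries, so x can be omitted:
#
#   vqvae = VQVAE()                               # defaults: C_in_out=64, C_g=80
#   x = torch.randn([2, 32, 64])                  # [B, T, C_in_out]
#   x_mask = torch.ones([2, 32])                  # [B, T]
#   g = torch.randn([2, 32, 80])                  # [B, T, C_g]
#   x_recon, commit_loss, z_p, m_q, logs_q = vqvae(x, x_mask, g, infer=False)
#   x_sampled, z_p = vqvae(x_mask=x_mask, g=g, infer=True)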


class VQVAEModel(nn.Module):
    def __init__(self, in_out_dim=71, sqz_prior=False, enc_no_cond=False):
        super().__init__()
        self.mel_encoder = nn.Sequential(*[
            nn.Conv1d(80, 64, 3, 1, 1, bias=False),
            nn.BatchNorm1d(64),
            nn.GELU(),
            nn.Conv1d(64, 64, 3, 1, 1, bias=False)
        ])
        self.in_dim, self.out_dim = in_out_dim, in_out_dim
        self.sqz_prior = sqz_prior
        self.enc_no_cond = enc_no_cond
        self.vae = VQVAE(in_out_channels=in_out_dim, hidden_channels=256, latent_size=16, kernel_size=5,
                         enc_n_layers=8, dec_n_layers=4, gin_channels=64, strides=[4,], sqz_prior=sqz_prior)
        # Halve the mel frame rate before it is encoded into the condition features.
        self.downsampler = LambdaLayer(lambda x: F.interpolate(x.transpose(1, 2), scale_factor=0.5, mode='nearest').transpose(1, 2))

    @property
    def device(self):
        return next(self.vae.parameters()).device

    def forward(self, batch, ret, log_dict=None, train=True):
        infer = not train
        mask = batch['y_mask'].to(self.device)
        mel = batch['mel'].to(self.device)
        mel = self.downsampler(mel)
        mel_feat = self.mel_encoder(mel.transpose(1, 2)).transpose(1, 2)
        if not infer:
            exp = batch['exp'].to(self.device)
            pose = batch['pose'].to(self.device)
            if self.in_dim == 71:
                x = torch.cat([exp, pose], dim=-1)  # [B, T, C=64 + 7]
            elif self.in_dim == 64:
                x = exp
            elif self.in_dim == 7:
                x = pose
            if self.enc_no_cond:
                # enc_no_cond: replace the audio condition with zeros.
                x_recon, loss_commit, z_p, m_q, logs_q = self.vae(x=x, x_mask=mask, g=torch.zeros_like(mel_feat), infer=False)
            else:
                x_recon, loss_commit, z_p, m_q, logs_q = self.vae(x=x, x_mask=mask, g=mel_feat, infer=False)
            loss_commit = loss_commit.reshape([])
            ret['pred'] = x_recon
            ret['mask'] = mask
            ret['loss_commit'] = loss_commit
            return x_recon, loss_commit, m_q, logs_q
        else:
            x_recon, z_p = self.vae(x=None, x_mask=mask, g=mel_feat, infer=True)
            return x_recon

    # def __get_feat(self, exp, pose):
    #     diff_exp = exp[:-1, :] - exp[1:, :]
    #     exp_std = (np.std(exp, axis=0) - self.exp_std_mean) / self.exp_std_std
    #     diff_exp_std = (np.std(diff_exp, axis=0) - self.exp_diff_std_mean) / self.exp_diff_std_std
    #     diff_pose = pose[:-1, :] - pose[1:, :]
    #     diff_pose_std = (np.std(diff_pose, axis=0) - self.pose_diff_std_mean) / self.pose_diff_std_std
    #     return np.concatenate((exp_std, diff_exp_std, diff_pose_std))

    def num_params(self, model, print_out=True, model_name="model"):
        parameters = filter(lambda p: p.requires_grad, model.parameters())
        parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000
        if print_out:
            print(f'| {model_name} Trainable Parameters: {parameters:.3f}M')
        return parameters
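

# A minimal smoke test, assuming the modules.audio2motion dependencies and
# vector_quantize_pytorch are installed, and that the mel condition holds twice as
# many frames as the motion sequences (the downsampler halves the mel frame rate).
# The batch keys and tensor layouts mirror VQVAEModel.forward; the concrete sizes
# (B=2, T=32 motion frames, 2*T mel frames) are illustrative only.
if __name__ == '__main__':
    B, T = 2, 32
    model = VQVAEModel(in_out_dim=71)
    batch = {
        'y_mask': torch.ones([B, T]),        # [B, T] frame validity mask
        'mel': torch.randn([B, T * 2, 80]),  # [B, 2T, 80] mel-spectrogram
        'exp': torch.randn([B, T, 64]),      # [B, T, 64] expression coefficients
        'pose': torch.randn([B, T, 7]),      # [B, T, 7] head pose
    }
    ret = {}
    # Training-style pass: reconstruct the target motion and get the commitment loss.
    x_recon, loss_commit, m_q, logs_q = model(batch, ret, train=True)
    print(x_recon.shape, float(loss_commit))  # expected: [B, T, 71] and a scalar
    # Inference-style pass: decode randomly sampled codebook entries.
    with torch.no_grad():
        x_gen = model(batch, ret, train=False)
    print(x_gen.shape)  # expected: [B, T, 71]
    model.num_params(model.vae, model_name='VQVAE')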