vae.py

import math
import torch
from torch import nn
from torch.nn import functional as F
import torch.distributions as dist
import numpy as np
import copy
from modules.audio2motion.flow_base import Glow, WN, ResidualCouplingBlock
from modules.audio2motion.transformer_base import Embedding
from utils.commons.pitch_utils import f0_to_coarse
from utils.commons.hparams import hparams


class LambdaLayer(nn.Module):
    def __init__(self, lambd):
        super(LambdaLayer, self).__init__()
        self.lambd = lambd

    def forward(self, x):
        return self.lambd(x)


def make_positions(tensor, padding_idx):
    """Replace non-padding symbols with their position numbers.

    Position numbers begin at padding_idx+1. Padding symbols are ignored.
    """
    # The series of casts and type-conversions here are carefully
    # balanced to both work with ONNX export and XLA. In particular XLA
    # prefers ints, cumsum defaults to output longs, and ONNX doesn't know
    # how to handle the dtype kwarg in cumsum.
    mask = tensor.ne(padding_idx).int()
    return (
        torch.cumsum(mask, dim=1).type_as(mask) * mask
    ).long() + padding_idx
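
# A minimal worked example (hypothetical values) of what make_positions returns,
# assuming padding_idx=0:
#   tokens = torch.tensor([[7, 9, 0, 0]])
#   tokens.ne(0).int()               -> [[1, 1, 0, 0]]
#   cumsum(mask) * mask + 0          -> [[1, 2, 0, 0]]
#   make_positions(tokens, 0)        -> tensor([[1, 2, 0, 0]])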


class SinusoidalPositionalEmbedding(nn.Module):
    """This module produces sinusoidal positional embeddings of any length.

    Padding symbols are ignored.
    """

    def __init__(self, embedding_dim, padding_idx, init_size=1024):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.padding_idx = padding_idx
        self.weights = SinusoidalPositionalEmbedding.get_embedding(
            init_size,
            embedding_dim,
            padding_idx,
        )
        self.register_buffer('_float_tensor', torch.FloatTensor(1))

    @staticmethod
    def get_embedding(num_embeddings, embedding_dim, padding_idx=None):
        """Build sinusoidal embeddings.

        This matches the implementation in tensor2tensor, but differs slightly
        from the description in Section 3.5 of "Attention Is All You Need".
        """
        half_dim = embedding_dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb)
        emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0)
        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1)
        if embedding_dim % 2 == 1:
            # zero pad
            emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
        if padding_idx is not None:
            emb[padding_idx, :] = 0
        return emb

    def forward(self, input, incremental_state=None, timestep=None, positions=None, **kwargs):
        """Input is expected to be of size [bsz x seqlen]."""
        bsz, seq_len = input.shape[:2]
        max_pos = self.padding_idx + 1 + seq_len
        if self.weights is None or max_pos > self.weights.size(0):
            # recompute/expand embeddings if needed
            self.weights = SinusoidalPositionalEmbedding.get_embedding(
                max_pos,
                self.embedding_dim,
                self.padding_idx,
            )
        self.weights = self.weights.to(self._float_tensor)
        if incremental_state is not None:
            # positions is the same for every token when decoding a single step
            pos = timestep.view(-1)[0] + 1 if timestep is not None else seq_len
            return self.weights[self.padding_idx + pos, :].expand(bsz, 1, -1)
        positions = make_positions(input, self.padding_idx) if positions is None else positions
        return self.weights.index_select(0, positions.view(-1)).view(bsz, seq_len, -1).detach()

    def max_positions(self):
        """Maximum number of supported positions."""
        return int(1e4)  # an arbitrary large number
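
# Sketch of how this module is queried later in this file (shapes are illustrative
# assumptions, not enforced by the class): FVAE passes a [B, T] tensor whose non-zero
# entries mark valid frames and receives a detached [B, T, embedding_dim] sinusoid table.
#   pos_emb = SinusoidalPositionalEmbedding(embedding_dim=16, padding_idx=0, init_size=2001)
#   frames = torch.ones([2, 100])      # [B, T]; any non-zero value counts as non-padding
#   pe = pos_emb(frames)               # -> [2, 100, 16]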


class FVAEEncoder(nn.Module):
    def __init__(self, in_channels, hidden_channels, latent_channels, kernel_size,
                 n_layers, gin_channels=0, p_dropout=0, strides=[4]):
        super().__init__()
        self.strides = strides
        self.hidden_size = hidden_channels
        self.pre_net = nn.Sequential(*[
            nn.Conv1d(in_channels, hidden_channels, kernel_size=s * 2, stride=s, padding=s // 2)
            if i == 0 else
            nn.Conv1d(hidden_channels, hidden_channels, kernel_size=s * 2, stride=s, padding=s // 2)
            for i, s in enumerate(strides)
        ])
        self.wn = WN(hidden_channels, kernel_size, 1, n_layers, gin_channels, p_dropout)
        self.out_proj = nn.Conv1d(hidden_channels, latent_channels * 2, 1)
        self.latent_channels = latent_channels

    def forward(self, x, x_mask, g):
        x = self.pre_net(x)
        x_mask = x_mask[:, :, ::np.prod(self.strides)][:, :, :x.shape[-1]]
        x = x * x_mask
        x = self.wn(x, x_mask, g) * x_mask
        x = self.out_proj(x)
        m, logs = torch.split(x, self.latent_channels, dim=1)
        z = (m + torch.randn_like(m) * torch.exp(logs))
        return z, m, logs, x_mask
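
# Shape sketch for FVAEEncoder with strides=[4] (example sizes, not enforced by the code):
#   x [B, C_in, T] --pre_net(stride 4)--> [B, H, T//4] --WN--> [B, H, T//4]
#   out_proj -> [B, 2*latent, T//4], split into (m, logs); z = m + eps * exp(logs)
#   i.e. the posterior is sampled at 1/4 of the input frame rate.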


class FVAEDecoder(nn.Module):
    def __init__(self, latent_channels, hidden_channels, out_channels, kernel_size,
                 n_layers, gin_channels=0, p_dropout=0, strides=[4]):
        super().__init__()
        self.strides = strides
        self.hidden_size = hidden_channels
        self.pre_net = nn.Sequential(*[
            nn.ConvTranspose1d(latent_channels, hidden_channels, kernel_size=s, stride=s)
            if i == 0 else
            nn.ConvTranspose1d(hidden_channels, hidden_channels, kernel_size=s, stride=s)
            for i, s in enumerate(strides)
        ])
        self.wn = WN(hidden_channels, kernel_size, 1, n_layers, gin_channels, p_dropout)
        self.out_proj = nn.Conv1d(hidden_channels, out_channels, 1)

    def forward(self, x, x_mask, g):
        x = self.pre_net(x)
        x = x * x_mask
        x = self.wn(x, x_mask, g) * x_mask
        x = self.out_proj(x)
        return x


class FVAE(nn.Module):
    def __init__(self,
                 in_out_channels=64, hidden_channels=256, latent_size=16,
                 kernel_size=3, enc_n_layers=5, dec_n_layers=5, gin_channels=80, strides=[4],
                 use_prior_glow=True, glow_hidden=256, glow_kernel_size=3, glow_n_blocks=5,
                 sqz_prior=False, use_pos_emb=False):
        super(FVAE, self).__init__()
        self.in_out_channels = in_out_channels
        self.strides = strides
        self.hidden_size = hidden_channels
        self.latent_size = latent_size
        self.use_prior_glow = use_prior_glow
        self.sqz_prior = sqz_prior
        self.g_pre_net = nn.Sequential(*[
            nn.Conv1d(gin_channels, gin_channels, kernel_size=s * 2, stride=s, padding=s // 2)
            for i, s in enumerate(strides)
        ])
        self.encoder = FVAEEncoder(in_out_channels, hidden_channels, latent_size, kernel_size,
                                   enc_n_layers, gin_channels, strides=strides)
        if use_prior_glow:
            self.prior_flow = ResidualCouplingBlock(
                latent_size, glow_hidden, glow_kernel_size, 1, glow_n_blocks, 4, gin_channels=gin_channels)
        self.use_pos_embed = use_pos_emb
        if sqz_prior:
            self.query_proj = nn.Linear(latent_size, latent_size)
            self.key_proj = nn.Linear(latent_size, latent_size)
            self.value_proj = nn.Linear(latent_size, hidden_channels)
            if self.in_out_channels in [7, 64]:
                self.decoder = FVAEDecoder(hidden_channels, hidden_channels, in_out_channels, kernel_size,
                                           dec_n_layers, gin_channels, strides=strides)
            elif self.in_out_channels == 71:
                self.exp_decoder = FVAEDecoder(hidden_channels, hidden_channels, 64, kernel_size,
                                               dec_n_layers, gin_channels, strides=strides)
                self.pose_decoder = FVAEDecoder(hidden_channels, hidden_channels, 7, kernel_size,
                                                dec_n_layers, gin_channels, strides=strides)
            if self.use_pos_embed:
                self.embed_positions = SinusoidalPositionalEmbedding(self.latent_size, 0, init_size=2000 + 1)
        else:
            self.decoder = FVAEDecoder(latent_size, hidden_channels, in_out_channels, kernel_size,
                                       dec_n_layers, gin_channels, strides=strides)
        self.prior_dist = dist.Normal(0, 1)

    def forward(self, x=None, x_mask=None, g=None, infer=False, temperature=1., **kwargs):
        """
        :param x: [B, T, C_in_out]
        :param x_mask: [B, T]
        :param g: [B, T, C_g]
        :return:
        """
        x_mask = x_mask[:, None, :]  # [B, 1, T]
        g = g.transpose(1, 2)  # [B, C_g, T]
        g_for_sqz = g
        g_sqz = self.g_pre_net(g_for_sqz)
        if not infer:
            x = x.transpose(1, 2)  # [B, C, T]
            z_q, m_q, logs_q, x_mask_sqz = self.encoder(x, x_mask, g_sqz)
            if self.sqz_prior:
                z = z_q
                if self.use_pos_embed:
                    position = self.embed_positions(z.transpose(1, 2).abs().sum(-1)).transpose(1, 2)
                    z = z + position
                q = self.query_proj(z.mean(dim=-1, keepdim=True).transpose(1, 2))  # [B, 1, C=16]
                k = self.key_proj(z.transpose(1, 2))  # [B, T, C=16]
                v = self.value_proj(z.transpose(1, 2))  # [B, T, C=256]
                attn = torch.bmm(q, k.transpose(1, 2))  # [B, 1, T]
                attn = F.softmax(attn, dim=-1)
                out = torch.bmm(attn, v)  # [B, 1, C=256]
                style_encoding = out.repeat([1, z_q.shape[-1], 1]).transpose(1, 2)  # [B, C=256, T]
                if self.in_out_channels == 71:
                    x_recon = torch.cat([self.exp_decoder(style_encoding, x_mask, g),
                                         self.pose_decoder(style_encoding, x_mask, g)], dim=1)
                else:
                    x_recon = self.decoder(style_encoding, x_mask, g)
            else:
                if self.in_out_channels == 71:
                    x_recon = torch.cat([self.exp_decoder(z_q, x_mask, g),
                                         self.pose_decoder(z_q, x_mask, g)], dim=1)
                else:
                    x_recon = self.decoder(z_q, x_mask, g)
            q_dist = dist.Normal(m_q, logs_q.exp())
            if self.use_prior_glow:
                logqx = q_dist.log_prob(z_q)
                z_p = self.prior_flow(z_q, x_mask_sqz, g_sqz)
                logpx = self.prior_dist.log_prob(z_p)
                loss_kl = ((logqx - logpx) * x_mask_sqz).sum() / x_mask_sqz.sum() / logqx.shape[1]
            else:
                loss_kl = torch.distributions.kl_divergence(q_dist, self.prior_dist)
                loss_kl = (loss_kl * x_mask_sqz).sum() / x_mask_sqz.sum() / z_q.shape[1]
                z_p = z_q
            return x_recon.transpose(1, 2), loss_kl, z_p.transpose(1, 2), m_q.transpose(1, 2), logs_q.transpose(1, 2)
        else:
            latent_shape = [g_sqz.shape[0], self.latent_size, g_sqz.shape[2]]
            z_p = self.prior_dist.sample(latent_shape).to(g.device) * temperature  # [B, latent_size, T_sqz]
            if self.use_prior_glow:
                z_p = self.prior_flow(z_p, 1, g_sqz, reverse=True)
            if self.sqz_prior:
                z = z_p
                if self.use_pos_embed:
                    # mirror the training branch: positions are computed per time step
                    position = self.embed_positions(z.transpose(1, 2).abs().sum(-1)).transpose(1, 2)
                    z = z + position
                q = self.query_proj(z.mean(dim=-1, keepdim=True).transpose(1, 2))  # [B, 1, C=16]
                k = self.key_proj(z.transpose(1, 2))  # [B, T, C=16]
                v = self.value_proj(z.transpose(1, 2))  # [B, T, C=256]
                attn = torch.bmm(q, k.transpose(1, 2))  # [B, 1, T]
                attn = F.softmax(attn, dim=-1)
                out = torch.bmm(attn, v)  # [B, 1, C=256]
                style_encoding = out.repeat([1, z_p.shape[-1], 1]).transpose(1, 2)  # [B, C=256, T]
                if self.in_out_channels == 71:
                    x_recon = torch.cat([self.exp_decoder(style_encoding, 1, g),
                                         self.pose_decoder(style_encoding, 1, g)], dim=1)
                else:
                    x_recon = self.decoder(style_encoding, 1, g)
            else:
                if self.in_out_channels == 71:
                    x_recon = torch.cat([self.exp_decoder(z_p, 1, g), self.pose_decoder(z_p, 1, g)], dim=1)
                else:
                    x_recon = self.decoder(z_p, 1, g)
            return x_recon.transpose(1, 2), z_p.transpose(1, 2)
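
# Return-signature sketch for FVAE.forward (matching the two branches above),
# where T_sqz ~ T / prod(strides):
#   training (infer=False): x_recon [B, T, C_out], loss_kl (scalar),
#                           z_p [B, T_sqz, latent], m_q [B, T_sqz, latent], logs_q [B, T_sqz, latent]
#   inference (infer=True): x_recon [B, T, C_out], z_p [B, T_sqz, latent]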


class VAEModel(nn.Module):
    def __init__(self, in_out_dim=64, audio_in_dim=1024, sqz_prior=False, cond_drop=False, use_prior_flow=True):
        super().__init__()
        feat_dim = 64
        self.blink_embed = nn.Embedding(2, feat_dim)
        self.audio_in_dim = audio_in_dim
        cond_dim = feat_dim
        self.mel_encoder = nn.Sequential(*[
            nn.Conv1d(audio_in_dim, 64, 3, 1, 1, bias=False),
            nn.BatchNorm1d(64),
            nn.GELU(),
            nn.Conv1d(64, feat_dim, 3, 1, 1, bias=False)
        ])
        self.cond_drop = cond_drop
        if self.cond_drop:
            self.dropout = nn.Dropout(0.5)
        self.in_dim, self.out_dim = in_out_dim, in_out_dim
        self.sqz_prior = sqz_prior
        self.use_prior_flow = use_prior_flow
        self.vae = FVAE(in_out_channels=in_out_dim, hidden_channels=256, latent_size=16, kernel_size=5,
                        enc_n_layers=8, dec_n_layers=4, gin_channels=cond_dim, strides=[4],
                        use_prior_glow=self.use_prior_flow, glow_hidden=64, glow_kernel_size=3, glow_n_blocks=4,
                        sqz_prior=sqz_prior)
        self.downsampler = LambdaLayer(lambda x: F.interpolate(x.transpose(1, 2), scale_factor=0.5, mode='linear').transpose(1, 2))
        # self.downsampler = LambdaLayer(lambda x: F.interpolate(x.transpose(1, 2), scale_factor=0.5, mode='nearest').transpose(1, 2))

    def num_params(self, model, print_out=True, model_name="model"):
        parameters = filter(lambda p: p.requires_grad, model.parameters())
        parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000
        if print_out:
            print(f'| {model_name} Trainable Parameters: {parameters:.3f}M')
        return parameters

    @property
    def device(self):
        return next(self.vae.parameters()).device

    def forward(self, batch, ret, train=True, return_latent=False, temperature=1.):
        infer = not train
        mask = batch['y_mask'].to(self.device)
        mel = batch['audio'].to(self.device)
        mel = self.downsampler(mel)
        cond_feat = self.mel_encoder(mel.transpose(1, 2)).transpose(1, 2)
        if self.cond_drop:
            cond_feat = self.dropout(cond_feat)
        if not infer:
            exp = batch['y'].to(self.device)
            x = exp
            x_recon, loss_kl, z_p, m_q, logs_q = self.vae(x=x, x_mask=mask, g=cond_feat, infer=False)
            x_recon = x_recon * mask.unsqueeze(-1)
            ret['pred'] = x_recon
            ret['mask'] = mask
            ret['loss_kl'] = loss_kl
            if return_latent:
                ret['m_q'] = m_q
                ret['z_p'] = z_p
            return x_recon, loss_kl, m_q, logs_q
        else:
            x_recon, z_p = self.vae(x=None, x_mask=mask, g=cond_feat, infer=True, temperature=temperature)
            x_recon = x_recon * mask.unsqueeze(-1)
            ret['pred'] = x_recon
            ret['mask'] = mask
            return x_recon
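
# Hypothetical batch layout implied by VAEModel.forward (a sketch, not the original
# training pipeline): 'audio' is expected at twice the motion frame rate, since the
# downsampler halves it before conditioning the FVAE.
#   model = VAEModel(in_out_dim=64, audio_in_dim=1024)
#   batch = {'y_mask': torch.ones([2, 200]),           # [B, T] motion-rate mask
#            'audio': torch.randn([2, 400, 1024]),     # [B, 2T, audio_in_dim]
#            'y': torch.randn([2, 200, 64])}           # [B, T, in_out_dim]
#   ret = {}
#   x_recon, loss_kl, m_q, logs_q = model(batch, ret, train=True, return_latent=True)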


class PitchContourVAEModel(nn.Module):
    def __init__(self, hparams, in_out_dim=64, audio_in_dim=1024, sqz_prior=False, cond_drop=False, use_prior_flow=True):
        super().__init__()
        self.hparams = copy.deepcopy(hparams)
        feat_dim = 128
        self.audio_in_dim = audio_in_dim
        self.blink_embed = nn.Embedding(2, feat_dim)
        self.mel_encoder = nn.Sequential(*[
            nn.Conv1d(audio_in_dim, feat_dim, 3, 1, 1, bias=False),
            nn.BatchNorm1d(feat_dim),
            nn.GELU(),
            nn.Conv1d(feat_dim, feat_dim, 3, 1, 1, bias=False)
        ])
        self.pitch_embed = Embedding(300, feat_dim, None)
        self.pitch_encoder = nn.Sequential(*[
            nn.Conv1d(feat_dim, feat_dim, 3, 1, 1, bias=False),
            nn.BatchNorm1d(feat_dim),
            nn.GELU(),
            nn.Conv1d(feat_dim, feat_dim, 3, 1, 1, bias=False)
        ])
        cond_dim = feat_dim + feat_dim + feat_dim
        if hparams.get('use_mouth_amp_embed', False):
            self.mouth_amp_embed = nn.Parameter(torch.randn(feat_dim))
            cond_dim += feat_dim
        if hparams.get('use_eye_amp_embed', False):
            self.eye_amp_embed = nn.Parameter(torch.randn(feat_dim))
            cond_dim += feat_dim
        self.cond_proj = nn.Linear(cond_dim, feat_dim, bias=True)
        self.cond_drop = cond_drop
        if self.cond_drop:
            self.dropout = nn.Dropout(0.5)
        self.in_dim, self.out_dim = in_out_dim, in_out_dim
        self.sqz_prior = sqz_prior
        self.use_prior_flow = use_prior_flow
        self.vae = FVAE(in_out_channels=in_out_dim, hidden_channels=256, latent_size=16, kernel_size=5,
                        enc_n_layers=8, dec_n_layers=4, gin_channels=feat_dim, strides=[4],
                        use_prior_glow=self.use_prior_flow, glow_hidden=64, glow_kernel_size=3, glow_n_blocks=4,
                        sqz_prior=sqz_prior)
        self.downsampler = LambdaLayer(lambda x: F.interpolate(x.transpose(1, 2), scale_factor=0.5, mode='nearest').transpose(1, 2))

    def num_params(self, model, print_out=True, model_name="model"):
        parameters = filter(lambda p: p.requires_grad, model.parameters())
        parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000
        if print_out:
            print(f'| {model_name} Trainable Parameters: {parameters:.3f}M')
        return parameters

    @property
    def device(self):
        return next(self.vae.parameters()).device

    def forward(self, batch, ret, train=True, return_latent=False, temperature=1.):
        infer = not train
        hparams = self.hparams
        mask = batch['y_mask'].to(self.device)
        mel = batch['audio'].to(self.device)
        f0 = batch['f0'].to(self.device)  # [b, t]
        if 'blink' not in batch:
            batch['blink'] = torch.zeros([f0.shape[0], f0.shape[1], 1], dtype=torch.long, device=f0.device)
        blink = batch['blink'].to(self.device)
        blink_feat = self.blink_embed(blink.squeeze(2))
        blink_feat = self.downsampler(blink_feat)
        mel = self.downsampler(mel)
        f0 = self.downsampler(f0.unsqueeze(-1)).squeeze(-1)
        f0_coarse = f0_to_coarse(f0)
        pitch_emb = self.pitch_embed(f0_coarse)
        cond_feat = self.mel_encoder(mel.transpose(1, 2)).transpose(1, 2)
        pitch_feat = self.pitch_encoder(pitch_emb.transpose(1, 2)).transpose(1, 2)
        cond_feats = [cond_feat, pitch_feat, blink_feat]
        if hparams.get('use_mouth_amp_embed', False):
            mouth_amp = batch.get('mouth_amp', torch.ones([f0.shape[0], 1], device=f0.device) * 0.4)
            mouth_amp_feat = mouth_amp.unsqueeze(1) * self.mouth_amp_embed.unsqueeze(0)
            mouth_amp_feat = mouth_amp_feat.repeat([1, cond_feat.shape[1], 1])
            cond_feats.append(mouth_amp_feat)
        if hparams.get('use_eye_amp_embed', False):
            eye_amp = batch.get('eye_amp', torch.ones([f0.shape[0], 1], device=f0.device) * 0.4)
            eye_amp_feat = eye_amp.unsqueeze(1) * self.eye_amp_embed.unsqueeze(0)
            eye_amp_feat = eye_amp_feat.repeat([1, cond_feat.shape[1], 1])
            cond_feats.append(eye_amp_feat)
        cond_feat = torch.cat(cond_feats, dim=-1)
        cond_feat = self.cond_proj(cond_feat)
        if self.cond_drop:
            cond_feat = self.dropout(cond_feat)
        if not infer:
            exp = batch['y'].to(self.device)
            x = exp
            x_recon, loss_kl, z_p, m_q, logs_q = self.vae(x=x, x_mask=mask, g=cond_feat, infer=False)
            x_recon = x_recon * mask.unsqueeze(-1)
            ret['pred'] = x_recon
            ret['mask'] = mask
            ret['loss_kl'] = loss_kl
            if return_latent:
                ret['m_q'] = m_q
                ret['z_p'] = z_p
            return x_recon, loss_kl, m_q, logs_q
        else:
            x_recon, z_p = self.vae(x=None, x_mask=mask, g=cond_feat, infer=True, temperature=temperature)
            x_recon = x_recon * mask.unsqueeze(-1)
            ret['pred'] = x_recon
            ret['mask'] = mask
            return x_recon
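
# Same sketch for PitchContourVAEModel.forward (hypothetical values): the constructor
# takes an hparams dict, and the batch additionally carries a frame-level 'f0' track plus
# optional 'blink' / 'mouth_amp' / 'eye_amp' entries, all at the audio rate before the
# 0.5x downsampling.
#   model = PitchContourVAEModel(hparams={}, in_out_dim=64, audio_in_dim=1024)
#   batch = {'y_mask': torch.ones([2, 200]),
#            'audio': torch.randn([2, 400, 1024]),
#            'f0': torch.rand([2, 400]) * 300,          # Hz, quantized via f0_to_coarse
#            'y': torch.randn([2, 200, 64])}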


if __name__ == '__main__':
    model = FVAE(in_out_channels=64, hidden_channels=128, latent_size=32, kernel_size=3, enc_n_layers=6, dec_n_layers=2,
                 gin_channels=80, strides=[4], use_prior_glow=False, glow_hidden=128, glow_kernel_size=3, glow_n_blocks=3)
    # FVAE.forward expects [B, T, C] inputs and a [B, T] mask (see its docstring).
    x = torch.rand([8, 1000, 64])
    x_mask = torch.ones([8, 1000])
    g = torch.rand([8, 1000, 80])
    train_out = model(x, x_mask, g, infer=False)
    x_recon, loss_kl, z_p, m_q, logs_q = train_out
    print('train:', x_recon.shape, float(loss_kl), z_p.shape)
    infer_out = model(x, x_mask, g, infer=True)
    x_recon, z_p = infer_out
    print('infer:', x_recon.shape, z_p.shape)
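    # Additional sanity check: sampling with a lower temperature only rescales the
    # prior draw, so the output shapes stay the same.
    x_recon_cool, z_p_cool = model(x, x_mask, g, infer=True, temperature=0.7)
    print('infer (temperature=0.7):', x_recon_cool.shape, z_p_cool.shape)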