vqvae_cvq.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch import einsum
from einops import rearrange
import torch.distributed as dist
from utils.commons.hparams import hparams


class ClusteringVectorQuantiser(nn.Module):
    """
    Improved version of the vector quantiser, with dynamic reinitialisation
    of unoptimised "dead" codebook entries.
    num_embed: number of codebook entries
    embed_dim: dimensionality of each codebook entry
    beta: weight of the commitment loss
    distance: distance metric used to look up the closest code ('l2' or 'cos')
    anchor: anchor sampling method ('closest', 'random', or 'probrandom')
    first_batch: if True, use the offline version of the model (the codebook is only reinitialised on the first batch)
    contras_loss: if True, add a contrastive loss to further improve performance
    """

    def __init__(self, num_embed=1024, embed_dim=512, beta=0.25, distance='l2',
                 anchor='closest', first_batch=False, contras_loss=True):
        super().__init__()
        self.num_embed = num_embed
        self.embed_dim = embed_dim
        self.beta = beta
        self.distance = distance
        self.anchor = anchor
        self.first_batch = first_batch
        self.contras_loss = contras_loss
        self.decay = 0.99
        self.init = False
        self.pool = FeaturePool(self.num_embed, self.embed_dim)
        self.embedding = nn.Embedding(self.num_embed, self.embed_dim)
        self.embedding.weight.data.uniform_(-1.0 / self.num_embed, 1.0 / self.num_embed)
        self.register_buffer("embed_prob", torch.zeros(self.num_embed))

    def forward(self, z, mask=None, temp=None, rescale_logits=False, return_logits=False):
        if mask is not None:
            assert mask.shape[:2] == z.shape[:2], (mask.shape, z.shape)
            assert mask.shape[-1] == 1, (mask.shape,)
            z = z * mask
        assert temp is None or temp == 1.0, "Only for interface compatibility with Gumbel"
        assert rescale_logits == False, "Only for interface compatibility with Gumbel"
        assert return_logits == False, "Only for interface compatibility with Gumbel"
        # reshape z -> (batch, height, width, channel) and flatten
        # z = rearrange(z, 'b c h w -> b h w c').contiguous()
        assert z.shape[-1] == self.embed_dim
        z_flattened = z.view(-1, self.embed_dim)
        # calculate the distance
        if self.distance == 'l2':
            # (negative) l2 distance from z to embeddings e_j: (z - e)^2 = z^2 + e^2 - 2 e * z
            d = - torch.sum(z_flattened.detach() ** 2, dim=1, keepdim=True) - \
                torch.sum(self.embedding.weight ** 2, dim=1) + \
                2 * torch.einsum('bd, dn-> bn', z_flattened.detach(), rearrange(self.embedding.weight, 'n d-> d n'))
        elif self.distance == 'cos':
            # cosine similarity from z to embeddings e_j
            normed_z_flattened = F.normalize(z_flattened, dim=1).detach()
            normed_codebook = F.normalize(self.embedding.weight, dim=1)
            d = torch.einsum('bd,dn->bn', normed_z_flattened, rearrange(normed_codebook, 'n d -> d n'))
        # encoding: d holds similarities, so the largest entry per row is the closest code
        sort_distance, indices = d.sort(dim=1)
        # look up the closest point for the indices
        encoding_indices = indices[:, -1]
        encodings = torch.zeros(encoding_indices.unsqueeze(1).shape[0], self.num_embed, device=z.device)
        encodings.scatter_(1, encoding_indices.unsqueeze(1), 1)
        # quantise and unflatten
        z_q = torch.matmul(encodings, self.embedding.weight).view(z.shape)
        # compute loss for embedding
        loss = self.beta * (z_q.detach() - z) ** 2 + (z_q - z.detach()) ** 2
        if mask is not None:
            loss = (loss * mask).sum() / mask.sum() / self.embed_dim
        else:
            loss = loss.mean()
        # loss = self.beta * torch.mean((z_q.detach()-z)**2) + torch.mean((z_q - z.detach()) ** 2)
        # preserve gradients
        z_q = z + (z_q - z).detach()
        # reshape back to match original input shape
        # z_q = rearrange(z_q, 'b h w c -> b c h w').contiguous()
        # count the average usage of each code in this batch
        avg_probs = torch.mean(encodings, dim=0)
        # perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))
        # min_encodings = encodings
        # online clustered reinitialisation for unoptimised points
        if self.training:
            # calculate the average usage of code entries
            self.embed_prob.mul_(self.decay).add_(avg_probs, alpha=1 - self.decay)
            # running average updates
            if self.anchor in ['closest', 'random', 'probrandom'] and (not self.init):
                # closest sampling
                if self.anchor == 'closest':
                    sort_distance, indices = d.sort(dim=0)
                    random_feat = z_flattened.detach()[indices[-1, :]]
                # feature pool based random sampling
                elif self.anchor == 'random':
                    random_feat = self.pool.query(z_flattened.detach())
                # probabilistic random sampling
                elif self.anchor == 'probrandom':
                    norm_distance = F.softmax(d.t(), dim=1)
                    prob = torch.multinomial(norm_distance, num_samples=1).view(-1)
                    random_feat = z_flattened.detach()[prob]
                # decay parameter based on the average usage
                decay = torch.exp(-(self.embed_prob * self.num_embed * 10) / (1 - self.decay) - 1e-3).unsqueeze(1).repeat(1, self.embed_dim)
                if hparams.get('reduce_cvq_embed') and dist.is_initialized():
                    # keep the embedding weights synchronised across all GPUs
                    dist.all_reduce(random_feat.data, op=dist.ReduceOp.SUM)
                    random_feat.data /= dist.get_world_size()
                self.embedding.weight.data = self.embedding.weight.data * (1 - decay) + random_feat * decay
                if self.first_batch:
                    self.init = True
            # contrastive loss
            if self.contras_loss:
                sort_distance, indices = d.sort(dim=0)
                dis_pos = sort_distance[-max(1, int(sort_distance.size(0) / self.num_embed)):, :].mean(dim=0, keepdim=True)
                dis_neg = sort_distance[:int(sort_distance.size(0) * 1 / 2), :]
                dis = torch.cat([dis_pos, dis_neg], dim=0).t() / 0.07
                contra_loss = F.cross_entropy(dis, torch.zeros((dis.size(0),), dtype=torch.long, device=dis.device))
                loss += contra_loss
        encoding_indices = encoding_indices.reshape(z.shape[:-1])
        return z_q, loss, encoding_indices

    def get_codebook_entry(self, encoding_indices):
        # get quantised latent vectors
        z_q = self.embedding(encoding_indices)
        return z_q


class FeaturePool():
    """
    This class implements a feature buffer that stores previously encoded features.
    The buffer enables us to initialise the codebook from a history of generated features
    rather than only those produced by the latest encoder.
    """

    def __init__(self, pool_size, dim=64):
        """
        Initialize the FeaturePool class
        Parameters:
            pool_size(int) -- the size of the feature buffer
            dim(int) -- the dimensionality of each feature
        """
        self.pool_size = pool_size
        if self.pool_size > 0:
            self.nums_features = 0
            self.features = (torch.rand((pool_size, dim)) * 2 - 1) / pool_size

    def query(self, features):
        """
        Return features from the pool.
        """
        self.features = self.features.to(features.device)
        if self.nums_features < self.pool_size:
            if features.size(0) > self.pool_size:
                # if the batch is large enough, directly update the whole codebook
                random_feat_id = torch.randint(0, features.size(0), (int(self.pool_size),))
                self.features = features[random_feat_id]
                self.nums_features = self.pool_size
            else:
                # if the mini-batch is not large enough, just store it for the next update
                num = self.nums_features + features.size(0)
                self.features[self.nums_features:num] = features
                self.nums_features = num
        else:
            if features.size(0) > int(self.pool_size):
                random_feat_id = torch.randint(0, features.size(0), (int(self.pool_size),))
                self.features = features[random_feat_id]
            else:
                random_id = torch.randperm(self.pool_size)
                self.features[random_id[:features.size(0)]] = features
        return self.features
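

# Minimal usage sketch (illustrative, not part of the original module). It assumes a
# (batch, time, embed_dim) input with an optional (batch, time, 1) padding mask, and
# that the project's global `hparams` dict has been populated elsewhere; the shapes
# below are arbitrary.
if __name__ == "__main__":
    quantiser = ClusteringVectorQuantiser(num_embed=1024, embed_dim=512)
    z = torch.randn(2, 50, 512)        # dummy encoder output: (batch, time, embed_dim)
    mask = torch.ones(2, 50, 1)        # keep all positions
    # the module is in training mode by default, so the codebook reinitialisation
    # and contrastive-loss branches also run here
    z_q, loss, indices = quantiser(z, mask=mask)
    print(z_q.shape, loss.item(), indices.shape)   # (2, 50, 512), scalar, (2, 50)
    # map indices back to codebook vectors, e.g. on the decoder side
    z_rec = quantiser.get_codebook_entry(indices)
    print(z_rec.shape)                 # (2, 50, 512)
    # the FeaturePool used by the 'random' anchor can also be exercised on its own
    pool = FeaturePool(pool_size=1024, dim=512)
    print(pool.query(z.view(-1, 512)).shape)       # (1024, 512)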