vqvae_lfq_y.py

  1. """
  2. Lookup Free Quantization
  3. Proposed in https://arxiv.org/abs/2310.05737
  4. basically a 2-level FSQ (Finite Scalar Quantization) with entropy loss
  5. https://arxiv.org/abs/2309.15505
  6. """

import torch
from einops import rearrange
from torch.nn import Module


# entropy

def binary_entropy(prob):
    return -prob * log(prob) - (1 - prob) * log(1 - prob)


# tensor helpers

def log(t, eps=1e-20):
    return t.clamp(min=eps).log()

# convert to bit representations and back

def decimal_to_bits(x: torch.LongTensor, bits: int) -> torch.FloatTensor:
    # [b, ...] {0, 1, ..., max - 1} -> [b, ..., d] {-1, 1}
    mask = 2 ** torch.arange(bits).to(x)            # [d]
    bits = ((x.unsqueeze(-1) & mask) != 0).float()  # [b, ..., d] {0, 1}
    return bits * 2 - 1                             # {0, 1} -> {-1, 1}


def bits_to_decimal(x: torch.FloatTensor) -> torch.LongTensor:
    # [b, ..., d] {-1, 1} -> [b, ...] {0, 1, ..., max - 1}
    x = (x > 0).long()                          # {-1, 1} -> {0, 1}, [b, ..., d]
    mask = 2 ** torch.arange(x.size(-1)).to(x)  # [d]
    dec = (x * mask).sum(-1)                    # [b, ...]
    return dec
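
# Illustrative round trip (a small sketch; the example values are not from the original file):
#   decimal_to_bits(torch.tensor([5]), bits=3)      -> tensor([[ 1., -1.,  1.]])  # 5 = 0b101, least significant bit first
#   bits_to_decimal(torch.tensor([[1., -1., 1.]]))  -> tensor([5])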


# class

class LFQY(Module):
    def __init__(self, dim, entropy_loss_weight=0.1, diversity_gamma=1.0):
        super().__init__()
        self.dim = dim
        self.diversity_gamma = diversity_gamma
        self.entropy_loss_weight = entropy_loss_weight

    def indices_to_codes(self, indices):
        codes = decimal_to_bits(indices, self.dim)
        # codes = rearrange(codes, 'b ... d -> b d ...')
        return codes

    def forward(self, x, mask=None, inv_temperature=1.):
        """
        einstein notation
        b - batch
        n - sequence (or flattened spatial dimensions)
        d - feature dimension, which is also log2(codebook size)
        """
        # x = rearrange(x, 'b d ... -> b ... d')
        assert x.shape[-1] == self.dim

        z = torch.tanh(x / inv_temperature)  # (-1, 1)

        # quantize by eq 3.
        quantized = torch.sign(x)  # {-1, 1}
        z = z + (quantized - z).detach()  # straight-through estimator

        # calculate indices
        indices = bits_to_decimal(z)

        # entropy aux loss
        if self.training:
            prob = torch.sigmoid(x / inv_temperature)  # [b, ..., d]

            bit_entropy = binary_entropy(prob).sum(-1).mean()
            # E[H(q)] = avg(sum(H(q_i)))

            avg_prob = prob.flatten(0, -2).mean(0)  # [b, ..., d] -> [n, d] -> [d]
            codebook_entropy = binary_entropy(avg_prob).sum()
            # H(E[q]) = sum(H(avg(q_i)))

            """
            1. bit entropy is nudged to be low, so each scalar commits
            to one of the two binary values.
            2. codebook entropy is nudged to be high, to encourage
            all codes to be used uniformly.
            """

            entropy_aux_loss = bit_entropy - self.diversity_gamma * codebook_entropy
        else:
            # if not training, just return dummy 0
            entropy_aux_loss = torch.zeros(1).to(z)

        entropy_aux_loss = entropy_aux_loss * self.entropy_loss_weight

        # reconstitute image or video dimensions
        # z = rearrange(z, 'b ... d -> b d ...')

        # bits to decimal for the codebook indices
        return z, entropy_aux_loss, indices

    def get_codebook_entry(self, encoding_indices):
        return self.indices_to_codes(encoding_indices)
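

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only; the shapes and seed below are arbitrary, not from the original file)
    torch.manual_seed(0)

    quantizer = LFQY(dim=8)    # implicit codebook size 2 ** 8 = 256
    x = torch.randn(2, 16, 8)  # [b, n, d] pre-quantization features

    quantizer.train()
    z, entropy_aux_loss, indices = quantizer(x)

    print(z.shape)           # torch.Size([2, 16, 8]), entries in {-1, 1}
    print(entropy_aux_loss)  # scalar regularizer, to be added to the main loss
    print(indices.shape)     # torch.Size([2, 16]), integer codes in [0, 2 ** 8 - 1]

    # codes reconstructed from the indices match the quantized values
    assert torch.equal(quantizer.indices_to_codes(indices), torch.sign(z).detach())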