# vqvae_fsq.py
"""
Finite Scalar Quantization: VQ-VAE Made Simple - https://arxiv.org/abs/2309.15505
Code adapted from Jax version in Appendix A.1
"""
from typing import List, Tuple

import torch
import torch.nn as nn
from torch import Tensor, int32
  9. def round_ste(z: Tensor) -> Tensor:
  10. """Round with straight through gradients."""
  11. zhat = z.round()
  12. return z + (zhat - z).detach()
  13. class FSQ(nn.Module):
  14. def __init__(self, levels: List[int]):
  15. super().__init__()
  16. _levels = torch.tensor(levels, dtype=int32)
  17. self.register_buffer("_levels", _levels)
  18. _basis = torch.cumprod(torch.tensor([1] + levels[:-1]), dim=0, dtype=int32)
  19. self.register_buffer("_basis", _basis)
  20. self.dim = len(levels)
  21. self.n_codes = self._levels.prod().item()
  22. implicit_codebook = self.indices_to_codes(torch.arange(self.n_codes))
  23. self.register_buffer("implicit_codebook", implicit_codebook)
  24. def forward(self, z: Tensor) -> Tensor:
  25. zhat = self.quantize(z)
  26. indices = self.codes_to_indices(zhat)
  27. return zhat, indices
  28. def bound(self, z: Tensor, eps: float = 1e-3) -> Tensor:
  29. """Bound `z`, an array of shape (..., d)."""
  30. half_l = (self._levels - 1) * (1 - eps) / 2
  31. offset = torch.where(self._levels % 2 == 0, 0.5, 0.0)
  32. shift = (offset / half_l).tan()
  33. return (z + shift).tanh() * half_l - offset
  34. def quantize(self, z: Tensor) -> Tensor:
  35. """Quantizes z, returns quantized zhat, same shape as z."""
  36. quantized = round_ste(self.bound(z))
  37. half_width = self._levels // 2 # Renormalize to [-1, 1].
  38. return quantized / half_width
  39. def _scale_and_shift(self, zhat_normalized: Tensor) -> Tensor:
  40. half_width = self._levels // 2
  41. return (zhat_normalized * half_width) + half_width
  42. def _scale_and_shift_inverse(self, zhat: Tensor) -> Tensor:
  43. half_width = self._levels // 2
  44. return (zhat - half_width) / half_width
  45. def codes_to_indices(self, zhat: Tensor) -> Tensor:
  46. """Converts a `code` to an index in the codebook."""
  47. assert zhat.shape[-1] == self.dim
  48. zhat = self._scale_and_shift(zhat)
  49. return (zhat * self._basis).sum(dim=-1).to(int32)
  50. def indices_to_codes(self, indices: Tensor) -> Tensor:
  51. """Inverse of `codes_to_indices`."""
  52. indices = indices.unsqueeze(-1)
  53. codes_non_centered = (indices // self._basis) % self._levels
  54. return self._scale_and_shift_inverse(codes_non_centered)
  55. def get_codebook_entry(self, encoding_indices):
  56. return self.indices_to_codes(encoding_indices)