123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172 |
- """
- Finite Scalar Quantization: VQ-VAE Made Simple - https://arxiv.org/abs/2309.15505
- Code adapted from Jax version in Appendix A.1
- """
from typing import List, Tuple

import torch
import torch.nn as nn
from torch import Tensor, int32
def round_ste(z: Tensor) -> Tensor:
    """Round `z` to the nearest integer while passing gradients straight through.

    The forward value equals `z.round()`; the rounding residual is detached,
    so the gradient of the output with respect to `z` is the identity.
    """
    residual = torch.round(z) - z
    return z + residual.detach()
class FSQ(nn.Module):
    """Finite Scalar Quantization layer.

    Each of the `len(levels)` channels of the input is squashed into a bounded
    range and rounded to one of `levels[i]` integer values. The product of all
    levels gives the size of the implicit codebook (`n_codes`).

    Buffers registered on the module:
        _levels: int32 tensor holding the number of levels per channel.
        _basis: int32 tensor of mixed-radix place values
            (`cumprod([1] + levels[:-1])`) used to flatten a code to an index.
        implicit_codebook: tensor of shape (n_codes, dim) holding the
            normalized code for every index.
    """

    def __init__(self, levels: List[int]):
        """Build the quantizer.

        Args:
            levels: number of quantization levels per channel. Any sequence of
                integers is accepted; every entry must be at least 2.

        Raises:
            ValueError: if `levels` is empty or contains an entry below 2.
        """
        super().__init__()
        # Normalize to a plain list so tuples (or other sequences) work with
        # the list concatenation used to build the basis below.
        levels = [int(level) for level in levels]
        if not levels:
            raise ValueError("`levels` must contain at least one entry")
        if min(levels) < 2:
            # A level of 1 makes `levels // 2` zero, which would produce
            # NaN/inf in `bound`, `quantize` and the implicit codebook.
            raise ValueError(f"every level must be >= 2, got {levels}")

        _levels = torch.tensor(levels, dtype=int32)
        self.register_buffer("_levels", _levels)

        _basis = torch.cumprod(torch.tensor([1] + levels[:-1]), dim=0, dtype=int32)
        self.register_buffer("_basis", _basis)

        self.dim = len(levels)
        self.n_codes = int(self._levels.prod().item())

        implicit_codebook = self.indices_to_codes(torch.arange(self.n_codes))
        self.register_buffer("implicit_codebook", implicit_codebook)

    def forward(self, z: Tensor) -> Tuple[Tensor, Tensor]:
        """Quantize `z` and return `(zhat, indices)`.

        `zhat` has the same shape as `z`; `indices` drops the last dimension
        and holds the int32 codebook index of every code.
        """
        zhat = self.quantize(z)
        indices = self.codes_to_indices(zhat)
        return zhat, indices

    def bound(self, z: Tensor, eps: float = 1e-3) -> Tensor:
        """Bound `z`, an array of shape (..., d).

        The output of `tanh` is scaled by `half_l` and shifted by `offset`, so
        that subsequent rounding yields `levels[i]` distinct integers in
        channel `i`. Channels with an even number of levels get a 0.5 offset.
        """
        half_l = (self._levels - 1) * (1 - eps) / 2
        offset = torch.where(self._levels % 2 == 0, 0.5, 0.0)
        # NOTE(review): `tan` is kept as-is; `atanh` would be the exact inverse
        # of the `tanh` below - confirm against the reference before changing.
        shift = (offset / half_l).tan()
        return (z + shift).tanh() * half_l - offset

    def quantize(self, z: Tensor) -> Tensor:
        """Quantizes z, returns quantized zhat, same shape as z."""
        quantized = round_ste(self.bound(z))
        half_width = self._levels // 2  # Renormalize to roughly [-1, 1].
        return quantized / half_width

    def _scale_and_shift(self, zhat_normalized: Tensor) -> Tensor:
        """Map normalized codes to non-negative per-channel level numbers."""
        half_width = self._levels // 2
        return (zhat_normalized * half_width) + half_width

    def _scale_and_shift_inverse(self, zhat: Tensor) -> Tensor:
        """Inverse of `_scale_and_shift`."""
        half_width = self._levels // 2
        return (zhat - half_width) / half_width

    def codes_to_indices(self, zhat: Tensor) -> Tensor:
        """Converts a `code` to an index in the codebook.

        Args:
            zhat: normalized codes of shape (..., dim).

        Returns:
            int32 tensor of shape (...) with the flattened codebook index.
        """
        assert zhat.shape[-1] == self.dim, (
            f"expected last dimension {self.dim}, got {zhat.shape[-1]}"
        )
        zhat = self._scale_and_shift(zhat)
        flat = (zhat * self._basis).sum(dim=-1)
        # Round before the integer cast: the cast truncates toward zero, so a
        # value that is a hair below an integer because of floating-point
        # error would otherwise map to the previous index.
        return flat.round().to(int32)

    def indices_to_codes(self, indices: Tensor) -> Tensor:
        """Inverse of `codes_to_indices`.

        Args:
            indices: integer tensor of codebook indices, shape (...).

        Returns:
            Normalized codes of shape (..., dim).
        """
        indices = indices.unsqueeze(-1)
        # Mixed-radix decomposition: digit i is (index // basis[i]) % levels[i].
        codes_non_centered = (indices // self._basis) % self._levels
        return self._scale_and_shift_inverse(codes_non_centered)

    def get_codebook_entry(self, encoding_indices):
        """Alias of `indices_to_codes`."""
        return self.indices_to_codes(encoding_indices)
|