""" Lookup Free Quantization Proposed in https://arxiv.org/abs/2310.05737 basically a 2-level FSQ (Finite Scalar Quantization) with entropy loss https://arxiv.org/abs/2309.15505 """ import torch from einops import rearrange from torch.nn import Module # entropy def binary_entropy(prob): return -prob * log(prob) - (1 - prob) * log(1 - prob) # tensor helpers def log(t, eps=1e-20): return t.clamp(min=eps).log() # convert to bit representations and back def decimal_to_bits(x: torch.LongTensor, bits: int) -> torch.FloatTensor: # [b, ...] {0, 1, ..., max - 1} -> [b, ..., d] {-1, 1} mask = 2 ** torch.arange(bits).to(x) # [d] bits = ((x.unsqueeze(-1) & mask) != 0).float() # [b, n, d] {0, 1} return bits * 2 - 1 # {0, 1} -> {-1, 1} def bits_to_decimal(x: torch.FloatTensor) -> torch.LongTensor: # [b, ..., d] {-1, 1} -> [b, ...] {0, 1, ..., max - 1} x = (x > 0).long() # {-1, 1} -> {0, 1}, [b, ..., d] mask = 2 ** torch.arange(x.size(-1)).to(x) # [d] dec = (x * mask).sum(-1) # [b, ...] return dec # class class LFQY(Module): def __init__(self, dim, entropy_loss_weight=0.1, diversity_gamma=1.0): super().__init__() self.dim = dim self.diversity_gamma = diversity_gamma self.entropy_loss_weight = entropy_loss_weight def indices_to_codes(self, indices): codes = decimal_to_bits(indices, self.dim) # codes = rearrange(codes, 'b ... d -> b d ...') return codes def forward(self, x, mask=None, inv_temperature=1.): """ einstein notation b - batch n - sequence (or flattened spatial dimensions) d - feature dimension, which is also log2(codebook size) """ # x = rearrange(x, 'b d ... -> b ... d') assert x.shape[-1] == self.dim z = torch.tanh(x / inv_temperature) # (-1, 1) # quantize by eq 3. quantized = torch.sign(x) # {-1, 1} z = z + (quantized - z).detach() # calculate indices indices = bits_to_decimal(z) # entropy aux loss if self.training: prob = torch.sigmoid(x / inv_temperature) # [b, ..., d] bit_entropy = binary_entropy(prob).sum(-1).mean() # E[H(q)] = avg(sum(H(q_i))) avg_prob = prob.flatten(0, -2).mean(0) # [b, ..., d] -> [n, d] -> [d] codebook_entropy = binary_entropy(avg_prob).sum() # H(E[q]) = sum(H(avg(q_i))) """ 1. entropy will be nudged to be low for each bit, so each scalar commits to one latent binary bit or the other. 2. codebook entropy will be nudged to be high, to encourage all codes to be uniformly used. """ entropy_aux_loss = bit_entropy - self.diversity_gamma * codebook_entropy else: # if not training, just return dummy 0 entropy_aux_loss = torch.zeros(1).to(z) entropy_aux_loss = entropy_aux_loss * self.entropy_loss_weight # reconstitute image or video dimensions # z = rearrange(z, 'b ... d -> b d ...') # bits to decimal for the codebook indices return z, entropy_aux_loss, indices def get_codebook_entry(self, encoding_indices): return self.indices_to_codes(encoding_indices)