dct.py 1.5 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758
  1. import numpy as np
  2. import torch
  3. def dct(x, norm=None):
  4. x_shape = x.shape
  5. N = x_shape[-1]
  6. x = x.contiguous().view(-1, N)
  7. v = torch.cat([x[:, ::2], x[:, 1::2].flip([1])], dim=1)
  8. Vc = torch.view_as_real(torch.fft.fft(v, dim=1)) # add this line
  9. k = - torch.arange(N, dtype=x.dtype, device=x.device)[None, :] * np.pi / (2 * N)
  10. W_r = torch.cos(k)
  11. W_i = torch.sin(k)
  12. V = Vc[:, :, 0] * W_r - Vc[:, :, 1] * W_i
  13. if norm == 'ortho':
  14. V[:, 0] /= np.sqrt(N) * 2
  15. V[:, 1:] /= np.sqrt(N / 2) * 2
  16. V = 2 * V.view(*x_shape)
  17. return V
  18. def idct(X, norm=None):
  19. x_shape = X.shape
  20. N = x_shape[-1]
  21. X_v = X.contiguous().view(-1, x_shape[-1]) / 2
  22. if norm == 'ortho':
  23. X_v[:, 0] *= np.sqrt(N) * 2
  24. X_v[:, 1:] *= np.sqrt(N / 2) * 2
  25. k = torch.arange(x_shape[-1], dtype=X.dtype, device=X.device)[None, :] * np.pi / (2 * N)
  26. W_r = torch.cos(k)
  27. W_i = torch.sin(k)
  28. V_t_r = X_v
  29. V_t_i = torch.cat([X_v[:, :1] * 0, -X_v.flip([1])[:, :-1]], dim=1)
  30. V_r = V_t_r * W_r - V_t_i * W_i
  31. V_i = V_t_r * W_i + V_t_i * W_r
  32. V = torch.cat([V_r.unsqueeze(2), V_i.unsqueeze(2)], dim=2)
  33. # v = torch.irfft(V, 1, onesided=False) # comment this line
  34. v = torch.fft.irfft(torch.view_as_complex(V), n=V.shape[1], dim=1) # add this line
  35. x = v.new_zeros(v.shape)
  36. x[:, ::2] += v[:, :N - (N // 2)]
  37. x[:, 1::2] += v.flip([1])[:, :N // 2]
  38. return x.view(*x_shape)