# griffin_lim.py
  1. import librosa
  2. import numpy as np
  3. import torch
  4. import torch.nn.functional as F
  def _stft(y, hop_size, win_size, fft_size):
      """Compute the STFT of waveform ``y`` via ``librosa.stft``.

      :param y: input waveform, forwarded to librosa as ``y``
      :param hop_size: hop length between frames, in samples
      :param win_size: analysis window length, in samples
      :param fft_size: FFT size
      :return: whatever ``librosa.stft`` returns (a complex spectrogram)
      """
      stft_kwargs = dict(
          n_fft=fft_size,
          hop_length=hop_size,
          win_length=win_size,
          # Pad the signal edges with a constant value instead of librosa's
          # default padding mode.
          pad_mode='constant',
      )
      return librosa.stft(y=y, **stft_kwargs)
  def _istft(y, hop_size, win_size):
      """Invert a complex STFT matrix to a waveform via ``librosa.istft``.

      :param y: complex STFT matrix (note: a spectrogram, not a waveform)
      :param hop_size: hop length between frames, in samples
      :param win_size: synthesis window length, in samples
      :return: the waveform returned by ``librosa.istft``
      """
      waveform = librosa.istft(y, hop_length=hop_size, win_length=win_size)
      return waveform
  def griffin_lim(S, hop_size, win_size, fft_size, angles=None, n_iters=30):
      """Reconstruct a waveform from a magnitude spectrogram with Griffin-Lim (NumPy).

      Starting from an initial phase estimate, the signal is repeatedly
      re-synthesised from ``|S|`` combined with the phase of the STFT of the
      previous estimate.

      :param S: spectrogram; only its magnitude ``np.abs(S)`` is used
      :param hop_size: hop length between frames, in samples
      :param win_size: window length, in samples
      :param fft_size: FFT size used when re-analysing the estimate
      :param angles: optional initial complex phase factors, multiplied
          element-wise with ``|S|``; random unit phasors when ``None``
      :param n_iters: number of refinement iterations after the initial inversion
      :return: the reconstructed waveform from the final ``_istft`` call
      """
      if angles is None:
          # Random unit-magnitude phasors exp(j * 2*pi * u), u ~ U[0, 1).
          angles = np.exp(2j * np.pi * np.random.rand(*S.shape))

      # ``np.complex`` (an alias of the builtin ``complex``) was removed in
      # NumPy 1.24; ``np.complex128`` is the equivalent concrete dtype.
      S_complex = np.abs(S).astype(np.complex128)

      y = _istft(S_complex * angles, hop_size, win_size)
      for _ in range(n_iters):
          # Keep the target magnitude, adopt the phase of the current estimate.
          angles = np.exp(1j * np.angle(_stft(y, hop_size, win_size, fft_size)))
          y = _istft(S_complex * angles, hop_size, win_size)
      return y
  def istft(amp, ang, hop_size, win_size, fft_size, pad=False, window=None):
      """Inverse STFT from separate magnitude and phase tensors (PyTorch).

      :param amp: magnitude spectrogram; frames are on the last axis
      :param ang: phase in radians, same shape as ``amp``
      :param hop_size: hop length between frames, in samples
      :param win_size: window length, in samples
      :param fft_size: FFT size
      :param pad: if True, append one extra frame that mirrors the
          second-to-last frame (reflect padding on the time axis) before inverting
      :param window: synthesis window; a Hann window of ``win_size`` on
          ``amp``'s device is used when ``None``
      :return: the waveform returned by ``torch.istft``
      """
      # Complex spectrogram amp * e^{j * ang}. torch.istft expects a complex
      # tensor; the stacked (real, imag) layout is no longer accepted.
      spec = amp * torch.exp(1j * ang)

      if window is None:
          window = torch.hann_window(win_size, device=amp.device)
      else:
          window = window.to(amp.device)

      if pad:
          # Reflect-pad one frame on the right of the time axis: the new frame
          # equals the second-to-last one (the edge frame itself is not repeated).
          spec = torch.cat([spec, spec[..., -2:-1]], dim=-1)

      # The window must be passed explicitly; otherwise torch.istft falls back
      # to a rectangular window, which does not match a Hann-windowed analysis
      # and yields a wrongly scaled waveform.
      wav = torch.istft(
          spec,
          n_fft=fft_size,
          hop_length=hop_size,
          win_length=win_size,
          window=window,
      )
      return wav
  def griffin_lim_torch(S, hop_size, win_size, fft_size, angles=None, n_iters=30):
      """
      Reconstruct waveforms from magnitude spectrograms with Griffin-Lim (PyTorch).

      Examples:
          >>> x_stft = librosa.stft(wav, n_fft=fft_size, hop_length=hop_size, win_length=win_size, pad_mode="constant")
          >>> x_stft = x_stft[None, ...]
          >>> amp = torch.FloatTensor(np.abs(x_stft))
          >>> wav = griffin_lim_torch(amp, hop_size, win_size, fft_size)

      :param S: magnitude spectrogram, [B, n_freq, T]
      :param hop_size: hop length between frames, in samples
      :param win_size: window length, in samples (a Hann window is used)
      :param fft_size: FFT size
      :param angles: optional initial phase in radians, same shape as ``S``;
          a complex tensor of phasors is also accepted and converted to
          radians. Random phase is used when ``None``.
      :param n_iters: number of refinement iterations after the initial inversion
      :return: [B, T_wav]
      """
      if angles is None:
          # ``istft`` expects phase in radians (it applies exp(1j * ang) itself),
          # so draw real angles, and draw them on S's device.
          angles = 2 * np.pi * torch.rand(S.shape, device=S.device)
      elif torch.is_tensor(angles) and angles.is_complex():
          # Accept e^{j*theta} phasors for backward compatibility.
          angles = torch.angle(angles)

      window = torch.hann_window(win_size).to(S.device)
      y = istft(S, angles, hop_size, win_size, fft_size, window=window)
      for _ in range(n_iters):
          # return_complex=True is required for real inputs on current PyTorch
          # and gives the complex STFT directly.
          x_stft = torch.stft(y, fft_size, hop_size, win_size, window, return_complex=True)
          # Keep the target magnitude S, adopt the phase of the current estimate.
          angles = torch.angle(x_stft)
          y = istft(S, angles, hop_size, win_size, fft_size, window=window)
      return y
  # Conversions
  # Module-level caches, filled lazily on first use: the mel filterbank
  # (set by _linear_to_mel) and its pseudo-inverse (set by _mel_to_linear).
  _mel_basis = _inv_mel_basis = None
  def _build_mel_basis(audio_sample_rate, fft_size, audio_num_mel_bins, fmin, fmax):
      """Build the mel filterbank matrix with ``librosa.filters.mel``.

      :param audio_sample_rate: sampling rate in Hz
      :param fft_size: FFT size
      :param audio_num_mel_bins: number of mel bands
      :param fmin: lowest filter frequency in Hz
      :param fmax: highest filter frequency in Hz; must not exceed
          ``audio_sample_rate // 2``
      :return: the filterbank returned by ``librosa.filters.mel``
      :raises AssertionError: if ``fmax`` is above ``audio_sample_rate // 2``
      """
      assert fmax <= audio_sample_rate // 2, "fmax must not exceed audio_sample_rate // 2"
      # ``sr`` and ``n_fft`` are keyword-only in librosa >= 0.10; passing them
      # by name also works on older releases.
      return librosa.filters.mel(
          sr=audio_sample_rate,
          n_fft=fft_size,
          n_mels=audio_num_mel_bins,
          fmin=fmin,
          fmax=fmax,
      )
  def _linear_to_mel(spectogram, audio_sample_rate, fft_size, audio_num_mel_bins, fmin, fmax):
      """Project a linear-frequency spectrogram onto the mel filterbank.

      The filterbank is cached in the module-level ``_mel_basis`` and rebuilt
      whenever this function is called with different filterbank parameters,
      so a stale basis from an earlier configuration is never reused.

      :param spectogram: linear spectrogram, multiplied as ``mel_basis @ spectogram``
      :param audio_sample_rate: sampling rate in Hz
      :param fft_size: FFT size
      :param audio_num_mel_bins: number of mel bands
      :param fmin: lowest filter frequency in Hz
      :param fmax: highest filter frequency in Hz
      :return: ``np.dot(mel_basis, spectogram)``
      """
      global _mel_basis
      key = (audio_sample_rate, fft_size, audio_num_mel_bins, fmin, fmax)
      # If no key was recorded yet (e.g. the basis was set externally), trust
      # the existing basis; otherwise rebuild when the parameters changed.
      cached_key = getattr(_linear_to_mel, "_basis_key", key)
      if _mel_basis is None or cached_key != key:
          _mel_basis = _build_mel_basis(audio_sample_rate, fft_size, audio_num_mel_bins, fmin, fmax)
          _linear_to_mel._basis_key = key
      return np.dot(_mel_basis, spectogram)
  def _mel_to_linear(mel_spectrogram, audio_sample_rate, fft_size, audio_num_mel_bins, fmin, fmax):
      """Approximate a linear-frequency spectrogram from a mel spectrogram.

      Uses the pseudo-inverse of the mel filterbank, cached in the module-level
      ``_inv_mel_basis`` and rebuilt whenever this function is called with
      different filterbank parameters. The result is floored at ``1e-10``.

      :param mel_spectrogram: mel spectrogram, multiplied as
          ``inv_mel_basis @ mel_spectrogram``
      :param audio_sample_rate: sampling rate in Hz
      :param fft_size: FFT size
      :param audio_num_mel_bins: number of mel bands
      :param fmin: lowest filter frequency in Hz
      :param fmax: highest filter frequency in Hz
      :return: element-wise ``max(1e-10, inv_mel_basis @ mel_spectrogram)``
      """
      global _inv_mel_basis
      key = (audio_sample_rate, fft_size, audio_num_mel_bins, fmin, fmax)
      # If no key was recorded yet (e.g. the inverse was set externally), trust
      # the existing matrix; otherwise rebuild when the parameters changed.
      cached_key = getattr(_mel_to_linear, "_basis_key", key)
      if _inv_mel_basis is None or cached_key != key:
          mel_basis = _build_mel_basis(audio_sample_rate, fft_size, audio_num_mel_bins, fmin, fmax)
          _inv_mel_basis = np.linalg.pinv(mel_basis)
          _mel_to_linear._basis_key = key
      # The pseudo-inverse can produce non-positive values; clamp them to a
      # small positive floor.
      return np.maximum(1e-10, np.dot(_inv_mel_basis, mel_spectrogram))