chatTTS_handler.py

import logging

import ChatTTS
import librosa
import numpy as np
import torch
from rich.console import Console

from baseHandler import BaseHandler

logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)

console = Console()


class ChatTTSHandler(BaseHandler):
    def setup(
        self,
        should_listen,
        device="cuda",
        gen_kwargs={},  # Unused
        stream=True,
        chunk_size=512,
    ):
        self.should_listen = should_listen
        self.device = device
        self.model = ChatTTS.Chat()
        self.model.load(compile=False)  # compile=True doesn't work for me
        self.chunk_size = chunk_size
        self.stream = stream
        # Sample one random speaker embedding and reuse it for every utterance,
        # so the assistant keeps a consistent voice.
        rnd_spk_emb = self.model.sample_random_speaker()
        self.params_infer_code = ChatTTS.Chat.InferCodeParams(
            spk_emb=rnd_spk_emb,
        )
        self.warmup()
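
    # --- Added sketch (not in the original handler) ---
    # setup() samples a new random voice on every run. If a stable voice is
    # wanted across restarts, one option is to persist the value returned by
    # sample_random_speaker() (a string in recent ChatTTS releases; older
    # versions returned a tensor, which this sketch does not handle) and reuse
    # it. The method name and the spk_emb_path default are hypothetical.
    def load_or_sample_speaker(self, spk_emb_path="speaker_emb.txt"):
        import os

        if os.path.exists(spk_emb_path):
            with open(spk_emb_path) as f:
                return f.read().strip()
        spk_emb = self.model.sample_random_speaker()
        with open(spk_emb_path, "w") as f:
            f.write(spk_emb)
        return spk_emb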

    def warmup(self):
        logger.info(f"Warming up {self.__class__.__name__}")
        # One throwaway generation so the first real request doesn't pay
        # model-initialization latency.
        _ = self.model.infer("text")

    def process(self, llm_sentence):
        console.print(f"[green]ASSISTANT: {llm_sentence}")
        if self.device == "mps":
            import time

            start = time.time()
            torch.mps.synchronize()  # Waits for all kernels in all streams on the MPS device to complete.
            torch.mps.empty_cache()  # Frees all memory allocated by the MPS device.
            _ = (
                time.time() - start
            )  # Removing this line makes it fail more often. I'm looking into it.

        wavs_gen = self.model.infer(
            llm_sentence, params_infer_code=self.params_infer_code, stream=self.stream
        )
        if self.stream:
            for gen in wavs_gen:
                if gen[0] is None or len(gen[0]) == 0:
                    # Empty generation: hand control back to the listener and stop.
                    self.should_listen.set()
                    return
                # ChatTTS outputs 24 kHz float audio; downsample to the 16 kHz
                # int16 format the pipeline expects. The trailing [0] drops the
                # batch dimension.
                audio_chunk = librosa.resample(gen[0], orig_sr=24000, target_sr=16000)
                audio_chunk = (audio_chunk * 32768).astype(np.int16)[0]
                while len(audio_chunk) > self.chunk_size:
                    yield audio_chunk[: self.chunk_size]  # Yield the first chunk_size samples
                    audio_chunk = audio_chunk[self.chunk_size :]  # Drop the samples already yielded
                # Zero-pad the tail so every yielded chunk has the same length.
                yield np.pad(audio_chunk, (0, self.chunk_size - len(audio_chunk)))
        else:
            wavs = wavs_gen
            if len(wavs[0]) == 0:
                self.should_listen.set()
                return
            audio_chunk = librosa.resample(wavs[0], orig_sr=24000, target_sr=16000)
            audio_chunk = (audio_chunk * 32768).astype(np.int16)
            # Yield fixed-size chunks, zero-padding the last one.
            for i in range(0, len(audio_chunk), self.chunk_size):
                yield np.pad(
                    audio_chunk[i : i + self.chunk_size],
                    (0, self.chunk_size - len(audio_chunk[i : i + self.chunk_size])),
                )
        # Signal the listener to resume once all audio has been yielded.
        self.should_listen.set()
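

# --- Added usage sketch (not in the original file) ---
# Minimal driver showing how the generator interface can be exercised on its
# own: process() yields fixed-size int16 chunks at 16 kHz, so they can be
# concatenated and written out with the stdlib wave module. BaseHandler's
# constructor signature isn't shown here, so this sketch bypasses __init__ and
# calls setup() directly; treat it as an illustration, not the project's
# intended entry point. Running it downloads/loads the ChatTTS models.
if __name__ == "__main__":
    import threading
    import wave

    should_listen = threading.Event()
    handler = ChatTTSHandler.__new__(ChatTTSHandler)  # skip BaseHandler.__init__
    handler.setup(should_listen, device="cpu", stream=True, chunk_size=512)

    chunks = list(handler.process("Hello from ChatTTS!"))
    audio = np.concatenate(chunks) if chunks else np.array([], dtype=np.int16)

    with wave.open("demo.wav", "wb") as f:
        f.setnchannels(1)      # mono
        f.setsampwidth(2)      # int16
        f.setframerate(16000)  # process() resamples to 16 kHz
        f.writeframes(audio.tobytes())
    print(f"wrote {len(audio)} samples; should_listen set: {should_listen.is_set()}")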