parler_handler.py
from threading import Thread
from time import perf_counter
from baseHandler import BaseHandler
import numpy as np
import torch
from transformers import (
    AutoTokenizer,
)
from parler_tts import ParlerTTSForConditionalGeneration, ParlerTTSStreamer
import librosa
import logging
from rich.console import Console
from utils.utils import next_power_of_2
from transformers.utils.import_utils import (
    is_flash_attn_2_available,
)

torch._inductor.config.fx_graph_cache = True
# Mind this parameter: it should be >= 2 * the number of padded prompt
# sizes used for TTS warmup (see warmup() below).
torch._dynamo.config.cache_size_limit = 15
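# With the default max_prompt_pad_length of 8, warmup() compiles for the
# padded prompt lengths 4, 8, 16, 32, 64 and 128 (six shapes), so
# 2 * 6 = 12 cached graphs fit within the limit of 15 with some headroom.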

logger = logging.getLogger(__name__)

console = Console()

if not is_flash_attn_2_available() and torch.cuda.is_available():
    logger.warning(
        """Parler TTS works best with flash attention 2, but it is not installed.
        Since CUDA is available on this system, you can install flash attention 2 with `uv pip install flash-attn --no-build-isolation`."""
    )


class ParlerTTSHandler(BaseHandler):
    def setup(
        self,
        should_listen,
        model_name="ylacombe/parler-tts-mini-jenny-30H",
        device="cuda",
        torch_dtype="float16",
        compile_mode=None,
        gen_kwargs={},
        max_prompt_pad_length=8,
        description=(
            "A female speaker with a slightly low-pitched voice delivers her words quite expressively, in a very confined sounding environment with clear audio quality. "
            "She speaks very fast."
        ),
        play_steps_s=1,
        blocksize=512,
    ):
        self.should_listen = should_listen
        self.device = device
        self.torch_dtype = getattr(torch, torch_dtype)
        self.gen_kwargs = gen_kwargs
        self.compile_mode = compile_mode
        self.max_prompt_pad_length = max_prompt_pad_length
        self.description = description

        self.description_tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.prompt_tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = ParlerTTSForConditionalGeneration.from_pretrained(
            model_name, torch_dtype=self.torch_dtype
        ).to(device)

        framerate = self.model.audio_encoder.config.frame_rate
        self.play_steps = int(framerate * play_steps_s)
        self.blocksize = blocksize

        if self.compile_mode not in (None, "default"):
            logger.warning(
                "Torch compilation modes that capture CUDA graphs are not yet compatible with the TTS part. Reverting to 'default'."
            )
            self.compile_mode = "default"

        if self.compile_mode:
            self.model.generation_config.cache_implementation = "static"
            self.model.forward = torch.compile(
                self.model.forward, mode=self.compile_mode, fullgraph=True
            )

        self.warmup()
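    # Note: compiling with fullgraph=True and a static KV cache means each
    # generate() call must see fixed tensor shapes, which is why prompts are
    # padded to powers of two in warmup() and process() below.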

    def prepare_model_inputs(
        self,
        prompt,
        max_length_prompt=50,
        pad=False,
    ):
        pad_args_prompt = (
            {"padding": "max_length", "max_length": max_length_prompt} if pad else {}
        )

        tokenized_description = self.description_tokenizer(
            self.description, return_tensors="pt"
        )
        input_ids = tokenized_description.input_ids.to(self.device)
        attention_mask = tokenized_description.attention_mask.to(self.device)

        tokenized_prompt = self.prompt_tokenizer(
            prompt, return_tensors="pt", **pad_args_prompt
        )
        prompt_input_ids = tokenized_prompt.input_ids.to(self.device)
        prompt_attention_mask = tokenized_prompt.attention_mask.to(self.device)

        gen_kwargs = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "prompt_input_ids": prompt_input_ids,
            "prompt_attention_mask": prompt_attention_mask,
            **self.gen_kwargs,
        }

        return gen_kwargs
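    # Example (illustrative): prepare_model_inputs("Hello there", max_length_prompt=8, pad=True)
    # tokenizes the fixed voice description and the prompt, pads the prompt to
    # exactly 8 tokens, and returns a dict that can be unpacked straight into
    # self.model.generate(**kwargs), as warmup() and process() do.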

    def warmup(self):
        logger.info(f"Warming up {self.__class__.__name__}")

        if self.device == "cuda":
            start_event = torch.cuda.Event(enable_timing=True)
            end_event = torch.cuda.Event(enable_timing=True)

        # Two warmup steps when not compiling (or for compile modes that
        # capture CUDA graphs); one step suffices for 'default' compilation.
        n_steps = 1 if self.compile_mode == "default" else 2

        if self.device == "cuda":
            torch.cuda.synchronize()
            start_event.record()

        if self.compile_mode:
            pad_lengths = [2**i for i in range(2, self.max_prompt_pad_length)]
            for pad_length in pad_lengths[::-1]:
                model_kwargs = self.prepare_model_inputs(
                    "dummy prompt", max_length_prompt=pad_length, pad=True
                )
                for _ in range(n_steps):
                    _ = self.model.generate(**model_kwargs)
                logger.info(f"Warmed up length {pad_length} tokens!")
        else:
            model_kwargs = self.prepare_model_inputs("dummy prompt")
            for _ in range(n_steps):
                _ = self.model.generate(**model_kwargs)

        if self.device == "cuda":
            end_event.record()
            torch.cuda.synchronize()
            logger.info(
                f"{self.__class__.__name__}: warmed up! time: {start_event.elapsed_time(end_event) * 1e-3:.3f} s"
            )
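    # torch.cuda.Event.elapsed_time returns milliseconds, hence the 1e-3
    # factor above to report seconds.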

    def process(self, llm_sentence):
        if isinstance(llm_sentence, tuple):
            llm_sentence, _ = llm_sentence

        console.print(f"[green]ASSISTANT: {llm_sentence}")
        nb_tokens = len(self.prompt_tokenizer(llm_sentence).input_ids)

        pad_args = {}
        if self.compile_mode:
            # Pad to the closest upper power of two so an already-compiled
            # graph for that shape can be reused.
            pad_length = next_power_of_2(nb_tokens)
            logger.debug(f"padding to {pad_length}")
            pad_args["pad"] = True
            pad_args["max_length_prompt"] = pad_length

        tts_gen_kwargs = self.prepare_model_inputs(
            llm_sentence,
            **pad_args,
        )

        streamer = ParlerTTSStreamer(
            self.model, device=self.device, play_steps=self.play_steps
        )
        tts_gen_kwargs = {"streamer": streamer, **tts_gen_kwargs}
        torch.manual_seed(0)
        thread = Thread(target=self.model.generate, kwargs=tts_gen_kwargs)
        thread.start()

        for i, audio_chunk in enumerate(streamer):
            # pipeline_start is set by the main pipeline script; it is absent
            # when this handler runs standalone.
            global pipeline_start
            if i == 0 and "pipeline_start" in globals():
                logger.info(
                    f"Time to first audio: {perf_counter() - pipeline_start:.3f} s"
                )
            # The model outputs 44.1 kHz float audio here; downsample to
            # 16 kHz and convert to 16-bit PCM for the rest of the pipeline.
            audio_chunk = librosa.resample(audio_chunk, orig_sr=44100, target_sr=16000)
            audio_chunk = (audio_chunk * 32768).astype(np.int16)
            for j in range(0, len(audio_chunk), self.blocksize):
                yield np.pad(
                    audio_chunk[j : j + self.blocksize],
                    (0, self.blocksize - len(audio_chunk[j : j + self.blocksize])),
                )

        self.should_listen.set()
        yield b"END"
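

# --- Minimal usage sketch (illustrative; not part of the original file) ---
# It assumes BaseHandler.__init__(stop_event, queue_in, queue_out,
# setup_kwargs=...) forwards setup_kwargs to setup(), as in this project's
# baseHandler.py; verify the actual signature before relying on this.
if __name__ == "__main__":
    from queue import Queue
    from threading import Event

    should_listen = Event()
    handler = ParlerTTSHandler(
        Event(),  # stop_event
        Queue(),  # queue_in
        Queue(),  # queue_out
        setup_kwargs={"should_listen": should_listen},
    )
    # Drive process() directly: it yields fixed-size int16 blocks at 16 kHz,
    # followed by the b"END" sentinel.
    for chunk in handler.process("Hello! This is a short test sentence."):
        if isinstance(chunk, bytes):  # the b"END" sentinel
            break
        print(f"got audio block of {len(chunk)} samples")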