from threading import Thread
from time import perf_counter

import logging

import librosa
import numpy as np
import torch
from rich.console import Console
from transformers import AutoTokenizer
from transformers.utils.import_utils import is_flash_attn_2_available
from parler_tts import ParlerTTSForConditionalGeneration, ParlerTTSStreamer

from baseHandler import BaseHandler
from utils.utils import next_power_of_2

torch._inductor.config.fx_graph_cache = True
# Mind this parameter! It should be >= 2 * the number of padded prompt sizes
# compiled for TTS, or torch.compile hits the cache limit and falls back to eager.
torch._dynamo.config.cache_size_limit = 15

logger = logging.getLogger(__name__)
console = Console()

if not is_flash_attn_2_available() and torch.cuda.is_available():
    logger.warning(
        "Parler TTS works best with flash attention 2, but it is not installed. "
        "Since CUDA is available on this system, you can install it with "
        "`uv pip install flash-attn --no-build-isolation`."
    )


class ParlerTTSHandler(BaseHandler):
    def setup(
        self,
        should_listen,
        model_name="ylacombe/parler-tts-mini-jenny-30H",
        device="cuda",
        torch_dtype="float16",
        compile_mode=None,
        gen_kwargs={},
        max_prompt_pad_length=8,
        description=(
            "A female speaker with a slightly low-pitched voice delivers her words "
            "quite expressively, in a very confined sounding environment with clear "
            "audio quality. She speaks very fast."
        ),
        play_steps_s=1,
        blocksize=512,
    ):
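        """
        Load the Parler-TTS model and both tokenizers, optionally compile the
        forward pass with a static KV cache, then run `warmup` so the first
        real request does not pay compilation cost. `should_listen` is expected
        to be a threading.Event-like object that `process` sets once an
        utterance has finished streaming.
        """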
        self.should_listen = should_listen
        self.device = device
        self.torch_dtype = getattr(torch, torch_dtype)
        self.gen_kwargs = gen_kwargs
        self.compile_mode = compile_mode
        self.max_prompt_pad_length = max_prompt_pad_length
        self.description = description

        self.description_tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.prompt_tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = ParlerTTSForConditionalGeneration.from_pretrained(
            model_name, torch_dtype=self.torch_dtype
        ).to(device)

        framerate = self.model.audio_encoder.config.frame_rate
        self.play_steps = int(framerate * play_steps_s)
        self.blocksize = blocksize

        if self.compile_mode not in (None, "default"):
            logger.warning(
                "Torch compilation modes that capture CUDA graphs are not yet "
                "compatible with the TTS part. Reverting to 'default'."
            )
            self.compile_mode = "default"

        if self.compile_mode:
            self.model.generation_config.cache_implementation = "static"
            self.model.forward = torch.compile(
                self.model.forward, mode=self.compile_mode, fullgraph=True
            )

        self.warmup()

    def prepare_model_inputs(
        self,
        prompt,
        max_length_prompt=50,
        pad=False,
    ):
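        """
        Tokenize the fixed voice description and the text prompt, move both to
        the target device, and merge them with `self.gen_kwargs` into a single
        kwargs dict for `model.generate`. When `pad` is True the prompt is
        padded to `max_length_prompt` so compiled graph shapes can be reused.
        """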
        pad_args_prompt = (
            {"padding": "max_length", "max_length": max_length_prompt} if pad else {}
        )

        tokenized_description = self.description_tokenizer(
            self.description, return_tensors="pt"
        )
        input_ids = tokenized_description.input_ids.to(self.device)
        attention_mask = tokenized_description.attention_mask.to(self.device)

        tokenized_prompt = self.prompt_tokenizer(
            prompt, return_tensors="pt", **pad_args_prompt
        )
        prompt_input_ids = tokenized_prompt.input_ids.to(self.device)
        prompt_attention_mask = tokenized_prompt.attention_mask.to(self.device)

        gen_kwargs = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "prompt_input_ids": prompt_input_ids,
            "prompt_attention_mask": prompt_attention_mask,
            **self.gen_kwargs,
        }

        return gen_kwargs

    def warmup(self):
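        """
        Run dummy generations so every padded prompt shape used at inference
        time is already compiled and warmed, and log the total warmup time
        when running on CUDA.
        """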
        logger.info(f"Warming up {self.__class__.__name__}")

        if self.device == "cuda":
            start_event = torch.cuda.Event(enable_timing=True)
            end_event = torch.cuda.Event(enable_timing=True)

        # One step is enough for the 'default' compile mode; two are used when
        # running uncompiled or with a compile mode that captures CUDA graphs.
        n_steps = 1 if self.compile_mode == "default" else 2

        if self.device == "cuda":
            torch.cuda.synchronize()
            start_event.record()

        if self.compile_mode:
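            # Compile/warm one graph per power-of-two prompt length, from the
            # largest pad length down to the smallest.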
            pad_lengths = [2**i for i in range(2, self.max_prompt_pad_length)]
            for pad_length in pad_lengths[::-1]:
                model_kwargs = self.prepare_model_inputs(
                    "dummy prompt", max_length_prompt=pad_length, pad=True
                )
                for _ in range(n_steps):
                    _ = self.model.generate(**model_kwargs)
                logger.info(f"Warmed up length {pad_length} tokens!")
        else:
            model_kwargs = self.prepare_model_inputs("dummy prompt")
            for _ in range(n_steps):
                _ = self.model.generate(**model_kwargs)

        if self.device == "cuda":
            end_event.record()
            torch.cuda.synchronize()
            logger.info(
                f"{self.__class__.__name__}: warmed up! time: "
                f"{start_event.elapsed_time(end_event) * 1e-3:.3f} s"
            )

    def process(self, llm_sentence):
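        """
        Synthesize one LLM sentence and stream it as fixed-size int16 blocks
        at 16 kHz, followed by a b"END" sentinel. Re-enables listening via
        `should_listen` once the utterance is done.
        """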
        if isinstance(llm_sentence, tuple):
            llm_sentence, _ = llm_sentence

        console.print(f"[green]ASSISTANT: {llm_sentence}")

        nb_tokens = len(self.prompt_tokenizer(llm_sentence).input_ids)
        pad_args = {}
        if self.compile_mode:
            # Pad to the closest upper power of two so the prompt hits one of
            # the graph shapes compiled during warmup.
            pad_length = next_power_of_2(nb_tokens)
            logger.debug(f"padding to {pad_length}")
            pad_args["pad"] = True
            pad_args["max_length_prompt"] = pad_length

        tts_gen_kwargs = self.prepare_model_inputs(
            llm_sentence,
            **pad_args,
        )

        streamer = ParlerTTSStreamer(
            self.model, device=self.device, play_steps=self.play_steps
        )
        tts_gen_kwargs = {"streamer": streamer, **tts_gen_kwargs}
        torch.manual_seed(0)
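        # Generate in a background thread; the streamer yields audio chunks as
        # soon as `play_steps` new frames have been decoded, so playback can
        # start before the full utterance is finished.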
        thread = Thread(target=self.model.generate, kwargs=tts_gen_kwargs)
        thread.start()

        for i, audio_chunk in enumerate(streamer):
            global pipeline_start
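            # Note: `pipeline_start` is expected to be set elsewhere in the
            # pipeline (if at all); when present, report time-to-first-audio
            # on the first chunk.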
            if i == 0 and "pipeline_start" in globals():
                logger.info(
                    f"Time to first audio: {perf_counter() - pipeline_start:.3f}"
                )
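            # Resample from 44.1 kHz (as hardcoded below) to the 16 kHz format
            # the rest of the pipeline consumes, then convert from float audio
            # in [-1, 1] to int16 PCM.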
            audio_chunk = librosa.resample(audio_chunk, orig_sr=44100, target_sr=16000)
            audio_chunk = (audio_chunk * 32768).astype(np.int16)
            for j in range(0, len(audio_chunk), self.blocksize):
                yield np.pad(
                    audio_chunk[j : j + self.blocksize],
                    (0, self.blocksize - len(audio_chunk[j : j + self.blocksize])),
                )

        self.should_listen.set()
        yield b"END"
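

# --- Illustrative usage sketch (added; not part of the original pipeline) ---
# A minimal way to exercise this handler in isolation, assuming only what the
# code above shows: `setup` takes an Event-like `should_listen`, and `process`
# yields int16 numpy blocks followed by a b"END" sentinel. BaseHandler's
# constructor is deliberately bypassed since its signature is not shown here.
if __name__ == "__main__":
    from threading import Event

    handler = ParlerTTSHandler.__new__(ParlerTTSHandler)  # skip BaseHandler.__init__
    handler.setup(
        should_listen=Event(),
        device="cuda" if torch.cuda.is_available() else "cpu",
        torch_dtype="float16" if torch.cuda.is_available() else "float32",
    )
    for block in handler.process("Hello! This is a streaming synthesis test."):
        if isinstance(block, bytes):  # b"END" marks the end of the utterance
            break
        print(f"got {len(block)} samples of 16 kHz int16 audio")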