whisper_stt_handler.py

from copy import copy
from time import perf_counter
import logging

import torch
from rich.console import Console
from transformers import (
    AutoProcessor,
    AutoModelForSpeechSeq2Seq,
)

from baseHandler import BaseHandler

logger = logging.getLogger(__name__)
console = Console()
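
# Language codes this pipeline accepts from Whisper's language detection;
# a detection outside this list is retried with the last accepted language
# (see process() below).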
SUPPORTED_LANGUAGES = [
    "en",
    "fr",
    "es",
    "zh",
    "ja",
    "ko",
    "hi",
]


class WhisperSTTHandler(BaseHandler):
    """
    Handles the Speech To Text generation using a Whisper model.
    """

    def setup(
        self,
        model_name="distil-whisper/distil-large-v3",
        device="cuda",
        torch_dtype="float16",
        compile_mode=None,
        language=None,
        gen_kwargs={},
    ):
        self.device = device
        self.torch_dtype = getattr(torch, torch_dtype)
        self.compile_mode = compile_mode
        self.gen_kwargs = gen_kwargs
        self.start_language = language
        self.last_language = language if language != "auto" else None
        if self.last_language is not None:
            self.gen_kwargs["language"] = self.last_language

        self.processor = AutoProcessor.from_pretrained(model_name)
        self.model = AutoModelForSpeechSeq2Seq.from_pretrained(
            model_name,
            torch_dtype=self.torch_dtype,
        ).to(device)

        # compile the forward pass; fullgraph compilation needs a static KV cache
        if self.compile_mode:
            self.model.generation_config.cache_implementation = "static"
            self.model.forward = torch.compile(
                self.model.forward, mode=self.compile_mode, fullgraph=True
            )
        self.warmup()

    def prepare_model_inputs(self, spoken_prompt):
        # spoken_prompt is raw audio sampled at 16 kHz, as Whisper expects
        input_features = self.processor(
            spoken_prompt, sampling_rate=16000, return_tensors="pt"
        ).input_features
        input_features = input_features.to(self.device, dtype=self.torch_dtype)
        return input_features

    def warmup(self):
        logger.info(f"Warming up {self.__class__.__name__}")

        # 2 warmup steps for the uncompiled path and for compile modes that
        # capture CUDA graphs; the "default" compile mode only needs 1
        n_steps = 1 if self.compile_mode == "default" else 2
        dummy_input = torch.randn(
            (1, self.model.config.num_mel_bins, 3000),
            dtype=self.torch_dtype,
            device=self.device,
        )
        if self.compile_mode not in (None, "default"):
            # generating more tokens than previously seen triggers a fresh CUDA
            # graphs capture, so warm up with at least as many generated tokens
            # as any subsequent generation will request: force min_new_tokens up
            # to max_new_tokens (unpacked first so these overrides take effect)
            warmup_gen_kwargs = {
                **self.gen_kwargs,
                "min_new_tokens": self.gen_kwargs["max_new_tokens"],  # yes, max assigned to min
                "max_new_tokens": self.gen_kwargs["max_new_tokens"],
            }
        else:
            warmup_gen_kwargs = self.gen_kwargs

        if self.device == "cuda":
            start_event = torch.cuda.Event(enable_timing=True)
            end_event = torch.cuda.Event(enable_timing=True)
            torch.cuda.synchronize()
            start_event.record()

        for _ in range(n_steps):
            _ = self.model.generate(dummy_input, **warmup_gen_kwargs)

        if self.device == "cuda":
            end_event.record()
            torch.cuda.synchronize()
            logger.info(
                f"{self.__class__.__name__}: warmed up! time: {start_event.elapsed_time(end_event) * 1e-3:.3f} s"
            )
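
    # process() exploits Whisper's output layout: generation begins with
    # special tokens, and the language token (e.g. "<|en|>") sits at
    # position 1 of the sequence, which is how the detected language is
    # recovered below.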
    def process(self, spoken_prompt):
        logger.debug("inferring whisper...")

        global pipeline_start
        pipeline_start = perf_counter()  # module global, presumably read elsewhere to time the full pipeline

        input_features = self.prepare_model_inputs(spoken_prompt)
        pred_ids = self.model.generate(input_features, **self.gen_kwargs)
        language_code = self.processor.tokenizer.decode(pred_ids[0, 1])[2:-2]  # strip "<|" and "|>"

        if language_code not in SUPPORTED_LANGUAGES:  # reprocess with the last language
            logger.warning(f"Whisper detected unsupported language: {language_code}")
            gen_kwargs = copy(self.gen_kwargs)
            gen_kwargs["language"] = self.last_language
            language_code = self.last_language
            pred_ids = self.model.generate(input_features, **gen_kwargs)
        else:
            self.last_language = language_code

        pred_text = self.processor.batch_decode(
            pred_ids, skip_special_tokens=True, decode_with_timestamps=False
        )[0]
        # re-read the language token: if last_language was still None above,
        # the retry ran without a forced language and may have detected a new one
        language_code = self.processor.tokenizer.decode(pred_ids[0, 1])[2:-2]  # strip "<|" and "|>"

        logger.debug("finished whisper inference")
        console.print(f"[yellow]USER: {pred_text}")
        logger.debug(f"Language Code Whisper: {language_code}")

        if self.start_language == "auto":
            language_code += "-auto"

        yield (pred_text, language_code)
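
# Usage sketch, assuming BaseHandler can be instantiated without arguments
# (its definition is not shown here) and that the caller supplies 16 kHz
# mono float32 audio:
#
#     import numpy as np
#
#     handler = WhisperSTTHandler()
#     handler.setup(device="cuda", gen_kwargs={"max_new_tokens": 128})
#     audio = np.zeros(16000, dtype=np.float32)  # one second of silence
#     for text, language in handler.process(audio):
#         print(language, text)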