import abc
import sys

if sys.platform == "darwin":
    import mlx.core as mx
    import mlx_whisper
    from mlx_whisper.transcribe import ModelHolder

import numpy as np
import torch
import whisper

from control_backend.core.config import settings


class SpeechRecognizer(abc.ABC):
    """
    Abstract base class for speech recognition backends.

    Provides a common interface for loading models and transcribing audio,
    as well as heuristics for estimating token counts to optimize decoding.

    :ivar limit_output_length: If True, limits the generated text length based on audio duration.
    """

    def __init__(self, limit_output_length=True):
        """
        :param limit_output_length: When ``True``, the length of the generated text is
            limited based on the duration of the input audio and some heuristics.
        """
        self.limit_output_length = limit_output_length

    @abc.abstractmethod
    def load_model(self):
        """
        Load the speech recognition model into memory.
        """
        ...

    @abc.abstractmethod
    def recognize_speech(self, audio: np.ndarray) -> str:
        """
        Recognize speech from the given audio sample.

        :param audio: A full utterance sample. Audio must be 16 kHz, mono, ``np.float32``,
            with values in the range [-1.0, 1.0].
        :return: The recognized speech text.
        """

    @staticmethod
    def _estimate_max_tokens(audio: np.ndarray) -> int:
        """
        Estimate the maximum length of a given audio sample in tokens.

        Assumes a maximum speaking rate of 450 words per minute (3x the average), and that
        3 words correspond to approx. 4 tokens.

        :param audio: The audio sample (16 kHz) to use for length estimation.
        :return: The estimated length of the transcribed audio in tokens.
        """
        length_seconds = len(audio) / settings.vad_settings.sample_rate_hz
        length_minutes = length_seconds / 60
        word_count = length_minutes * settings.behaviour_settings.transcription_words_per_minute
        token_count = word_count / settings.behaviour_settings.transcription_words_per_token
        return int(token_count) + settings.behaviour_settings.transcription_token_buffer
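
    # Worked example of the heuristic above, assuming a 16 kHz sample rate,
    # transcription_words_per_minute = 450 and transcription_words_per_token = 0.75
    # (3 words ~= 4 tokens); the buffer value depends on configuration:
    #   a 10 s utterance -> 160_000 samples -> 1/6 min -> 75 words at 450 wpm
    #   -> 75 / 0.75 = 100 tokens, plus transcription_token_buffer.
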
    def _get_decode_options(self, audio: np.ndarray) -> dict:
        """
        Construct decoding options for the Whisper model.

        :param audio: The audio sample (16 kHz) used to determine options like max decode length.
        :return: A dict that can be used to construct ``whisper.DecodingOptions`` (or equivalent).
        """
        options = {}
        if self.limit_output_length:
            options["sample_len"] = self._estimate_max_tokens(audio)
        return options
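
    # For the 10 s example above, this yields {"sample_len": 100 + buffer}; when
    # limit_output_length is False, the dict stays empty and Whisper's default
    # decoding limits apply.
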
    @staticmethod
    def best_type():
        """
        Factory method to get the best available `SpeechRecognizer`.

        :return: An instance of :class:`MLXWhisperSpeechRecognizer` if on macOS with Apple Silicon,
            otherwise :class:`OpenAIWhisperSpeechRecognizer`.
        """
        # MPS availability implies an Apple Silicon Mac, which is what MLX requires.
        if torch.backends.mps.is_available():
            print("Choosing MLX Whisper model.")
            return MLXWhisperSpeechRecognizer()
        else:
            print("Choosing reference Whisper model.")
            return OpenAIWhisperSpeechRecognizer()


class MLXWhisperSpeechRecognizer(SpeechRecognizer):
    """
    Speech recognizer using the MLX framework (optimized for Apple Silicon).
    """

    def __init__(self, limit_output_length=True):
        super().__init__(limit_output_length)
        self.was_loaded = False
        self.model_name = settings.speech_model_settings.mlx_model_name

    def load_model(self):
        """
        Ensure the model is downloaded and cached. MLX loads lazily, so this
        pre-fetches the model.
        """
        if self.was_loaded:
            return
        # There appears to be no dedicated mechanism to preload a model, but `get_model`
        # does store it in memory for later use.
        ModelHolder.get_model(self.model_name, mx.float16)
        self.was_loaded = True

    def recognize_speech(self, audio: np.ndarray) -> str:
        self.load_model()
        return mlx_whisper.transcribe(
            audio,
            path_or_hf_repo=self.model_name,
            **self._get_decode_options(audio),
        )["text"].strip()


class OpenAIWhisperSpeechRecognizer(SpeechRecognizer):
    """
    Speech recognizer using the standard OpenAI Whisper library (PyTorch).
    """

    def __init__(self, limit_output_length=True):
        super().__init__(limit_output_length)
        self.model = None

    def load_model(self):
        """
        Load the OpenAI Whisper model onto the available device (CUDA or CPU).
        """
        if self.model is not None:
            return
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.model = whisper.load_model(
            settings.speech_model_settings.openai_model_name, device=device
        )

    def recognize_speech(self, audio: np.ndarray) -> str:
        self.load_model()
        return whisper.transcribe(
            self.model, audio, **self._get_decode_options(audio)
        )["text"].strip()
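

# Minimal usage sketch (not part of the original interface): pick the best backend
# for this machine and transcribe a local file. "sample.wav" is a placeholder path;
# whisper.load_audio resamples to 16 kHz mono float32, which matches the input
# format recognize_speech expects.
if __name__ == "__main__":
    recognizer = SpeechRecognizer.best_type()
    recognizer.load_model()
    audio = whisper.load_audio("sample.wav")  # placeholder input file
    print(recognizer.recognize_speech(audio))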