import abc
import sys

if sys.platform == "darwin":
    import mlx.core as mx
    import mlx_whisper
    from mlx_whisper.transcribe import ModelHolder

import numpy as np
import torch
import whisper

from control_backend.core.config import settings


class SpeechRecognizer(abc.ABC):
    """
    Abstract base class for speech recognition backends.

    Provides a common interface for loading models and transcribing audio, as well as
    heuristics for estimating token counts to optimize decoding.

    :ivar limit_output_length: If True, limits the generated text length based on audio duration.
    """

    def __init__(self, limit_output_length=True):
        """
        :param limit_output_length: When ``True``, the length of the generated transcription will
            be limited by the length of the input audio and some heuristics.
        """
        self.limit_output_length = limit_output_length

    @abc.abstractmethod
    def load_model(self):
        """
        Load the speech recognition model into memory.
        """
        ...

    @abc.abstractmethod
    def recognize_speech(self, audio: np.ndarray) -> str:
        """
        Recognize speech from the given audio sample.

        :param audio: A full utterance sample. Audio must be 16 kHz, mono, ``np.float32``, with
            values in the range [-1.0, 1.0].
        :return: The recognized speech text.
        """

    @staticmethod
    def _estimate_max_tokens(audio: np.ndarray) -> int:
        """
        Estimate the maximum length of a given audio sample in tokens.

        Assumes a maximum speaking rate of 450 words per minute (3x average), and that 3 words
        are approx. 4 tokens.

        :param audio: The audio sample (16 kHz) to use for length estimation.
        :return: The estimated length of the transcribed audio in tokens.
        """
        length_seconds = len(audio) / settings.vad_settings.sample_rate_hz
        length_minutes = length_seconds / 60
        word_count = length_minutes * settings.behaviour_settings.transcription_words_per_minute
        token_count = word_count / settings.behaviour_settings.transcription_words_per_token
        return int(token_count) + settings.behaviour_settings.transcription_token_buffer

    def _get_decode_options(self, audio: np.ndarray) -> dict:
        """
        Construct decoding options for the Whisper model.

        :param audio: The audio sample (16 kHz) used to determine options like the maximum
            decode length.
        :return: A dict that can be used to construct ``whisper.DecodingOptions`` (or equivalent).
        """
        options = {}
        if self.limit_output_length:
            options["sample_len"] = self._estimate_max_tokens(audio)
        return options

    @staticmethod
    def best_type():
        """
        Factory method to get the best available :class:`SpeechRecognizer`.

        :return: An instance of :class:`MLXWhisperSpeechRecognizer` if on macOS with Apple
            Silicon, otherwise :class:`OpenAIWhisperSpeechRecognizer`.
        """
        if torch.backends.mps.is_available():
            print("Choosing MLX Whisper model.")
            return MLXWhisperSpeechRecognizer()
        else:
            print("Choosing reference Whisper model.")
            return OpenAIWhisperSpeechRecognizer()


class MLXWhisperSpeechRecognizer(SpeechRecognizer):
    """
    Speech recognizer using the MLX framework (optimized for Apple Silicon).
    """

    def __init__(self, limit_output_length=True):
        super().__init__(limit_output_length)
        self.was_loaded = False
        self.model_name = settings.speech_model_settings.mlx_model_name

    def load_model(self):
        """
        Ensures the model is downloaded and cached. MLX loads dynamically, so this pre-fetches
        the model.
""" if self.was_loaded: return # There appears to be no dedicated mechanism to preload a model, but this `get_model` does # store it in memory for later usage ModelHolder.get_model(self.model_name, mx.float16) self.was_loaded = True def recognize_speech(self, audio: np.ndarray) -> str: self.load_model() return mlx_whisper.transcribe( audio, path_or_hf_repo=self.model_name, **self._get_decode_options(audio), )["text"].strip() class OpenAIWhisperSpeechRecognizer(SpeechRecognizer): """ Speech recognizer using the standard OpenAI Whisper library (PyTorch). """ def __init__(self, limit_output_length=True): super().__init__(limit_output_length) self.model = None def load_model(self): """ Loads the OpenAI Whisper model onto the available device (CUDA or CPU). """ if self.model is not None: return device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") self.model = whisper.load_model( settings.speech_model_settings.openai_model_name, device=device ) def recognize_speech(self, audio: np.ndarray) -> str: self.load_model() return whisper.transcribe(self.model, audio, **self._get_decode_options(audio))[ "text" ].strip()