docs: add docs to CB

Pretty much every class and method should have documentation now.

ref: N25B-295
2025-11-24 21:58:22 +01:00
parent 54502e441c
commit 129d3c4420
26 changed files with 757 additions and 80 deletions


@@ -14,15 +14,28 @@ from control_backend.core.config import settings
 class SpeechRecognizer(abc.ABC):
     """
     Abstract base class for speech recognition backends.
+    Provides a common interface for loading models and transcribing audio,
+    as well as heuristics for estimating token counts to optimize decoding.
+    :ivar limit_output_length: If True, limits the generated text length based on audio duration.
     """
     def __init__(self, limit_output_length=True):
         """
-        :param limit_output_length: When `True`, the length of the generated speech will be limited
-            by the length of the input audio and some heuristics.
+        :param limit_output_length: When ``True``, the length of the generated speech will be
+            limited by the length of the input audio and some heuristics.
         """
         self.limit_output_length = limit_output_length
     @abc.abstractmethod
-    def load_model(self): ...
+    def load_model(self):
+        """
+        Load the speech recognition model into memory.
+        """
+        ...
     @abc.abstractmethod
     def recognize_speech(self, audio: np.ndarray) -> str:
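The two abstract methods above, `load_model` and `recognize_speech`, define the full backend contract. A minimal sketch of a hypothetical subclass, purely illustrative and not part of this changeset:

```python
import numpy as np

# Hypothetical backend; "EchoRecognizer" and its behaviour are illustrative only.
class EchoRecognizer(SpeechRecognizer):
    def load_model(self):
        # Nothing to load for this dummy backend.
        pass

    def recognize_speech(self, audio: np.ndarray) -> str:
        # A real backend would run inference here; this stub just reports duration.
        return f"<{len(audio) / 16_000:.1f}s of audio>"
```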
@@ -30,15 +43,17 @@ class SpeechRecognizer(abc.ABC):
         Recognize speech from the given audio sample.
         :param audio: A full utterance sample. Audio must be 16 kHz, mono, np.float32, values in the
-        range [-1.0, 1.0].
-        :return: Recognized speech.
+            range [-1.0, 1.0].
+        :return: The recognized speech text.
         """
     @staticmethod
     def _estimate_max_tokens(audio: np.ndarray) -> int:
         """
-        Estimate the maximum length of a given audio sample in tokens. Assumes a maximum speaking
-        rate of 450 words per minute (3x average), and assumes that 3 words is 4 tokens.
+        Estimate the maximum length of a given audio sample in tokens.
+        Assumes a maximum speaking rate of 450 words per minute (3x average), and assumes that
+        3 words is approx. 4 tokens.
         :param audio: The audio sample (16 kHz) to use for length estimation.
         :return: The estimated length of the transcribed audio in tokens.
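As a rough sanity check of the heuristic documented above: 450 words per minute is 7.5 words per second, and 3 words ≈ 4 tokens works out to 10 tokens per second of audio. An illustrative re-derivation (not the project's implementation):

```python
import numpy as np

SAMPLE_RATE = 16_000  # audio is assumed to be 16 kHz

def estimate_max_tokens_sketch(audio: np.ndarray) -> int:
    """Illustrative re-derivation of the documented heuristic."""
    seconds = len(audio) / SAMPLE_RATE
    words = seconds * 450 / 60      # max speaking rate: 450 words per minute
    tokens = words * 4 / 3          # 3 words ~ 4 tokens
    return int(np.ceil(tokens))

# e.g. 6 seconds of audio -> 45 words -> 60 tokens
```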
@@ -51,8 +66,10 @@ class SpeechRecognizer(abc.ABC):
     def _get_decode_options(self, audio: np.ndarray) -> dict:
         """
         Construct decoding options for the Whisper model.
         :param audio: The audio sample (16 kHz) to use to determine options like max decode length.
-        :return: A dict that can be used to construct `whisper.DecodingOptions`.
+        :return: A dict that can be used to construct ``whisper.DecodingOptions`` (or equivalent).
         """
         options = {}
         if self.limit_output_length:
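The returned dict is meant to be splatted into Whisper's decoding configuration. A hedged sketch of caller-side usage with the openai-whisper package, assuming the dict maps onto `whisper.DecodingOptions` fields such as `sample_len` (the actual keys are cut off in this hunk):

```python
import whisper

# Assumes `recognizer` is an OpenAIWhisperSpeechRecognizer with a loaded `model`,
# and `audio` is a 16 kHz mono float32 array. The key name below is an assumption.
opts = recognizer._get_decode_options(audio)        # e.g. {"sample_len": 60}
decoding_options = whisper.DecodingOptions(**opts, fp16=False)
mel = whisper.log_mel_spectrogram(whisper.pad_or_trim(audio)).to(recognizer.model.device)
result = whisper.decode(recognizer.model, mel, decoding_options)
print(result.text)
```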
@@ -61,7 +78,12 @@ class SpeechRecognizer(abc.ABC):
     @staticmethod
     def best_type():
-        """Get the best type of SpeechRecognizer based on system capabilities."""
+        """
+        Factory method to get the best available `SpeechRecognizer`.
+        :return: An instance of :class:`MLXWhisperSpeechRecognizer` if on macOS with Apple Silicon,
+            otherwise :class:`OpenAIWhisperSpeechRecognizer`.
+        """
         if torch.mps.is_available():
             print("Choosing MLX Whisper model.")
             return MLXWhisperSpeechRecognizer()
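Typical caller-side usage of the factory documented above (the `audio` array is assumed to already meet the 16 kHz mono float32 contract from `recognize_speech`):

```python
import numpy as np

recognizer = SpeechRecognizer.best_type()   # MLX on Apple Silicon, PyTorch Whisper otherwise
recognizer.load_model()

audio = np.zeros(16_000, dtype=np.float32)  # placeholder: one second of silence
text = recognizer.recognize_speech(audio)
print(text)
```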
@@ -71,12 +93,20 @@ class SpeechRecognizer(abc.ABC):
 class MLXWhisperSpeechRecognizer(SpeechRecognizer):
+    """
+    Speech recognizer using the MLX framework (optimized for Apple Silicon).
+    """
     def __init__(self, limit_output_length=True):
         super().__init__(limit_output_length)
         self.was_loaded = False
         self.model_name = settings.speech_model_settings.mlx_model_name
     def load_model(self):
+        """
+        Ensures the model is downloaded and cached. MLX loads dynamically, so this
+        pre-fetches the model.
+        """
         if self.was_loaded:
             return
         # There appears to be no dedicated mechanism to preload a model, but this `get_model` does
@@ -94,11 +124,18 @@ class MLXWhisperSpeechRecognizer(SpeechRecognizer):
 class OpenAIWhisperSpeechRecognizer(SpeechRecognizer):
+    """
+    Speech recognizer using the standard OpenAI Whisper library (PyTorch).
+    """
     def __init__(self, limit_output_length=True):
         super().__init__(limit_output_length)
         self.model = None
     def load_model(self):
+        """
+        Loads the OpenAI Whisper model onto the available device (CUDA or CPU).
+        """
         if self.model is not None:
             return
         device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
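The hunk ends at the device selection; a hedged sketch of how the load typically continues with the openai-whisper API (the model name below is a placeholder, since the configured PyTorch model name is not shown in this diff):

```python
import torch
import whisper

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# "base" is a placeholder; the project would use its configured model name here.
model = whisper.load_model("base", device=device)
```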