docs: add docs to CB

Pretty much every class and method should have documentation now.

ref: N25B-295
This commit is contained in:
2025-11-24 21:58:22 +01:00
parent 54502e441c
commit 129d3c4420
26 changed files with 757 additions and 80 deletions

View File

@@ -14,15 +14,28 @@ from control_backend.core.config import settings
class SpeechRecognizer(abc.ABC):
"""
Abstract base class for speech recognition backends.
Provides a common interface for loading models and transcribing audio,
as well as heuristics for estimating token counts to optimize decoding.
:ivar limit_output_length: If True, limits the generated text length based on audio duration.
"""
def __init__(self, limit_output_length=True):
"""
:param limit_output_length: When `True`, the length of the generated speech will be limited
by the length of the input audio and some heuristics.
:param limit_output_length: When ``True``, the length of the generated speech will be
limited by the length of the input audio and some heuristics.
"""
self.limit_output_length = limit_output_length
@abc.abstractmethod
def load_model(self): ...
def load_model(self):
"""
Load the speech recognition model into memory.
"""
...
@abc.abstractmethod
def recognize_speech(self, audio: np.ndarray) -> str:
@@ -30,15 +43,17 @@ class SpeechRecognizer(abc.ABC):
Recognize speech from the given audio sample.
:param audio: A full utterance sample. Audio must be 16 kHz, mono, np.float32, values in the
range [-1.0, 1.0].
:return: Recognized speech.
range [-1.0, 1.0].
:return: The recognized speech text.
"""
@staticmethod
def _estimate_max_tokens(audio: np.ndarray) -> int:
"""
Estimate the maximum length of a given audio sample in tokens. Assumes a maximum speaking
rate of 450 words per minute (3x average), and assumes that 3 words is 4 tokens.
Estimate the maximum length of a given audio sample in tokens.
Assumes a maximum speaking rate of 450 words per minute (3x average), and assumes that
3 words is approx. 4 tokens.
:param audio: The audio sample (16 kHz) to use for length estimation.
:return: The estimated length of the transcribed audio in tokens.
@@ -51,8 +66,10 @@ class SpeechRecognizer(abc.ABC):
def _get_decode_options(self, audio: np.ndarray) -> dict:
"""
Construct decoding options for the Whisper model.
:param audio: The audio sample (16 kHz) to use to determine options like max decode length.
:return: A dict that can be used to construct `whisper.DecodingOptions`.
:return: A dict that can be used to construct ``whisper.DecodingOptions`` (or equivalent).
"""
options = {}
if self.limit_output_length:
@@ -61,7 +78,12 @@ class SpeechRecognizer(abc.ABC):
@staticmethod
def best_type():
"""Get the best type of SpeechRecognizer based on system capabilities."""
"""
Factory method to get the best available `SpeechRecognizer`.
:return: An instance of :class:`MLXWhisperSpeechRecognizer` if on macOS with Apple Silicon,
otherwise :class:`OpenAIWhisperSpeechRecognizer`.
"""
if torch.mps.is_available():
print("Choosing MLX Whisper model.")
return MLXWhisperSpeechRecognizer()
@@ -71,12 +93,20 @@ class SpeechRecognizer(abc.ABC):
class MLXWhisperSpeechRecognizer(SpeechRecognizer):
"""
Speech recognizer using the MLX framework (optimized for Apple Silicon).
"""
def __init__(self, limit_output_length=True):
super().__init__(limit_output_length)
self.was_loaded = False
self.model_name = settings.speech_model_settings.mlx_model_name
def load_model(self):
"""
Ensures the model is downloaded and cached. MLX loads dynamically, so this
pre-fetches the model.
"""
if self.was_loaded:
return
# There appears to be no dedicated mechanism to preload a model, but this `get_model` does
@@ -94,11 +124,18 @@ class MLXWhisperSpeechRecognizer(SpeechRecognizer):
class OpenAIWhisperSpeechRecognizer(SpeechRecognizer):
"""
Speech recognizer using the standard OpenAI Whisper library (PyTorch).
"""
def __init__(self, limit_output_length=True):
super().__init__(limit_output_length)
self.model = None
def load_model(self):
"""
Loads the OpenAI Whisper model onto the available device (CUDA or CPU).
"""
if self.model is not None:
return
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

View File

@@ -13,11 +13,26 @@ from .speech_recognizer import SpeechRecognizer
class TranscriptionAgent(BaseAgent):
"""
An agent which listens to audio fragments with voice, transcribes them, and sends the
transcription to other agents.
Transcription Agent.
This agent listens to audio fragments (containing speech) on a ZMQ SUB socket,
transcribes them using the configured :class:`SpeechRecognizer`, and sends the
resulting text to other agents (e.g., the Text Belief Extractor).
It uses an internal semaphore to limit the number of concurrent transcription tasks.
:ivar audio_in_address: The ZMQ address to receive audio from (usually from VAD Agent).
:ivar audio_in_socket: The ZMQ SUB socket instance.
:ivar speech_recognizer: The speech recognition engine instance.
:ivar _concurrency: Semaphore to limit concurrent transcriptions.
"""
def __init__(self, audio_in_address: str):
"""
Initialize the Transcription Agent.
:param audio_in_address: The ZMQ address of the audio source (e.g., VAD output).
"""
super().__init__(settings.agent_settings.transcription_name)
self.audio_in_address = audio_in_address
@@ -26,6 +41,13 @@ class TranscriptionAgent(BaseAgent):
self._concurrency = None
async def setup(self):
"""
Initialize the agent resources.
1. Connects to the audio input ZMQ socket.
2. Initializes the :class:`SpeechRecognizer` (choosing the best available backend).
3. Starts the background transcription loop.
"""
self.logger.info("Setting up %s", self.name)
self._connect_audio_in_socket()
@@ -42,23 +64,45 @@ class TranscriptionAgent(BaseAgent):
self.logger.info("Finished setting up %s", self.name)
async def stop(self):
"""
Stop the agent and close sockets.
"""
assert self.audio_in_socket is not None
self.audio_in_socket.close()
self.audio_in_socket = None
return await super().stop()
def _connect_audio_in_socket(self):
"""
Helper to connect the ZMQ SUB socket for audio input.
"""
self.audio_in_socket = azmq.Context.instance().socket(zmq.SUB)
self.audio_in_socket.setsockopt_string(zmq.SUBSCRIBE, "")
self.audio_in_socket.connect(self.audio_in_address)
async def _transcribe(self, audio: np.ndarray) -> str:
"""
Run the speech recognition on the audio data.
This runs in a separate thread (via `asyncio.to_thread`) to avoid blocking the event loop,
constrained by the concurrency semaphore.
:param audio: The audio data as a numpy array.
:return: The transcribed text string.
"""
assert self._concurrency is not None and self.speech_recognizer is not None
async with self._concurrency:
return await asyncio.to_thread(self.speech_recognizer.recognize_speech, audio)
async def _share_transcription(self, transcription: str):
"""Share a transcription to the other agents that depend on it."""
"""
Share a transcription to the other agents that depend on it.
Currently sends to:
- :attr:`settings.agent_settings.text_belief_extractor_name`
:param transcription: The transcribed text.
"""
receiver_names = [
settings.agent_settings.text_belief_extractor_name,
]
@@ -72,6 +116,12 @@ class TranscriptionAgent(BaseAgent):
await self.send(message)
async def _transcribing_loop(self) -> None:
"""
The main loop for receiving audio and triggering transcription.
Receives audio chunks from ZMQ, decodes them to float32, and calls :meth:`_transcribe`.
If speech is found, it calls :meth:`_share_transcription`.
"""
while self._running:
try:
assert self.audio_in_socket is not None

View File

@@ -15,6 +15,8 @@ class SocketPoller[T]:
"""
Convenience class for polling a socket for data with a timeout, persisting a zmq.Poller for
multiple usages.
:param T: The type of data returned by the socket.
"""
def __init__(
@@ -35,7 +37,7 @@ class SocketPoller[T]:
"""
Get data from the socket, or None if the timeout is reached.
:param timeout_ms: If given, the timeout. Otherwise, `self.timeout_ms` is used.
:param timeout_ms: If given, the timeout. Otherwise, ``self.timeout_ms`` is used.
:return: Data from the socket or None.
"""
timeout_ms = timeout_ms or self.timeout_ms
@@ -47,11 +49,27 @@ class SocketPoller[T]:
class VADAgent(BaseAgent):
"""
An agent which listens to an audio stream, does Voice Activity Detection (VAD), and sends
fragments with detected speech to other agents over ZeroMQ.
Voice Activity Detection (VAD) Agent.
This agent:
1. Receives an audio stream (via ZMQ).
2. Processes the audio using the Silero VAD model to detect speech.
3. Buffers potential speech segments.
4. Publishes valid speech fragments (containing speech plus small buffer) to a ZMQ PUB socket.
5. Instantiates and starts agents (like :class:`TranscriptionAgent`) that use this output.
:ivar audio_in_address: Address of the input audio stream.
:ivar audio_in_bind: Whether to bind or connect to the input address.
:ivar audio_out_socket: ZMQ PUB socket for sending speech fragments.
"""
def __init__(self, audio_in_address: str, audio_in_bind: bool):
"""
Initialize the VAD Agent.
:param audio_in_address: ZMQ address for input audio.
:param audio_in_bind: True if this agent should bind to the input address, False to connect.
"""
super().__init__(settings.agent_settings.vad_name)
self.audio_in_address = audio_in_address
@@ -67,6 +85,15 @@ class VADAgent(BaseAgent):
self.model = None
async def setup(self):
"""
Initialize resources.
1. Connects audio input socket.
2. Binds audio output socket (random port).
3. Loads VAD model from Torch Hub.
4. Starts the streaming loop.
5. Instantiates and starts the :class:`TranscriptionAgent` with the output address.
"""
self.logger.info("Setting up %s", self.name)
self._connect_audio_in_socket()
@@ -123,7 +150,9 @@ class VADAgent(BaseAgent):
self.audio_in_poller = SocketPoller[bytes](self.audio_in_socket)
def _connect_audio_out_socket(self) -> int | None:
"""Returns the port bound, or None if binding failed."""
"""
Returns the port bound, or None if binding failed.
"""
try:
self.audio_out_socket = azmq.Context.instance().socket(zmq.PUB)
return self.audio_out_socket.bind_to_random_port("tcp://localhost", max_tries=100)
@@ -144,6 +173,15 @@ class VADAgent(BaseAgent):
self._ready.set()
async def _streaming_loop(self):
"""
Main loop for processing audio stream.
1. Polls for new audio chunks.
2. Passes chunk to VAD model.
3. Manages `i_since_speech` counter to determine start/end of speech.
4. Buffers speech + context.
5. Sends complete speech segment to output socket when silence is detected.
"""
await self._ready.wait()
while self._running:
assert self.audio_in_poller is not None