docs: add docs to CB
Pretty much every class and method should have documentation now. ref: N25B-295
@@ -14,15 +14,28 @@ from control_backend.core.config import settings
 
 
 class SpeechRecognizer(abc.ABC):
+    """
+    Abstract base class for speech recognition backends.
+
+    Provides a common interface for loading models and transcribing audio,
+    as well as heuristics for estimating token counts to optimize decoding.
+
+    :ivar limit_output_length: If True, limits the generated text length based on audio duration.
+    """
+
     def __init__(self, limit_output_length=True):
         """
-        :param limit_output_length: When `True`, the length of the generated speech will be limited
-            by the length of the input audio and some heuristics.
+        :param limit_output_length: When ``True``, the length of the generated speech will be
+            limited by the length of the input audio and some heuristics.
         """
         self.limit_output_length = limit_output_length
 
     @abc.abstractmethod
-    def load_model(self): ...
+    def load_model(self):
+        """
+        Load the speech recognition model into memory.
+        """
+        ...
 
     @abc.abstractmethod
     def recognize_speech(self, audio: np.ndarray) -> str:
@@ -30,15 +43,17 @@ class SpeechRecognizer(abc.ABC):
         Recognize speech from the given audio sample.
 
         :param audio: A full utterance sample. Audio must be 16 kHz, mono, np.float32, values in the
-            range [-1.0, 1.0].
-        :return: Recognized speech.
+            range [-1.0, 1.0].
+        :return: The recognized speech text.
         """
 
     @staticmethod
     def _estimate_max_tokens(audio: np.ndarray) -> int:
         """
-        Estimate the maximum length of a given audio sample in tokens. Assumes a maximum speaking
-        rate of 450 words per minute (3x average), and assumes that 3 words is 4 tokens.
+        Estimate the maximum length of a given audio sample in tokens.
+
+        Assumes a maximum speaking rate of 450 words per minute (3x average), and assumes that
+        3 words is approx. 4 tokens.
 
         :param audio: The audio sample (16 kHz) to use for length estimation.
         :return: The estimated length of the transcribed audio in tokens.
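
For reference, the heuristic described in that docstring works out as follows. This is a minimal standalone sketch (not the code in this commit), assuming 16 kHz mono input as required by recognize_speech:

    import math

    import numpy as np

    SAMPLE_RATE = 16_000          # Hz, as required by recognize_speech
    MAX_WORDS_PER_MINUTE = 450    # ~3x the average speaking rate
    TOKENS_PER_WORD = 4 / 3       # assumes 3 words ~= 4 tokens


    def estimate_max_tokens(audio: np.ndarray) -> int:
        """Upper bound on transcript length in tokens for a 16 kHz sample."""
        duration_min = len(audio) / SAMPLE_RATE / 60
        max_words = duration_min * MAX_WORDS_PER_MINUTE
        return math.ceil(max_words * TOKENS_PER_WORD)

Worked through: 8 seconds of audio gives (8/60) * 450 = 60 words, so a cap of 80 tokens.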
@@ -51,8 +66,10 @@ class SpeechRecognizer(abc.ABC):
 
     def _get_decode_options(self, audio: np.ndarray) -> dict:
         """
+        Construct decoding options for the Whisper model.
+
         :param audio: The audio sample (16 kHz) to use to determine options like max decode length.
-        :return: A dict that can be used to construct `whisper.DecodingOptions`.
+        :return: A dict that can be used to construct ``whisper.DecodingOptions`` (or equivalent).
         """
         options = {}
         if self.limit_output_length:
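
The keys this method actually sets are truncated in the hunk. One plausible shape, assuming Whisper's ``sample_len`` option is what caps the number of sampled tokens, and reusing the estimator sketched above:

    def get_decode_options(audio, limit_output_length=True) -> dict:
        options = {}
        if limit_output_length:
            # Cap how many tokens Whisper may sample for this utterance.
            options["sample_len"] = estimate_max_tokens(audio)
        return options

    # Possible call site (openai-whisper API):
    # result = whisper.decode(model, mel, whisper.DecodingOptions(**options))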
@@ -61,7 +78,12 @@ class SpeechRecognizer(abc.ABC):
 
     @staticmethod
     def best_type():
-        """Get the best type of SpeechRecognizer based on system capabilities."""
+        """
+        Factory method to get the best available `SpeechRecognizer`.
+
+        :return: An instance of :class:`MLXWhisperSpeechRecognizer` if on macOS with Apple Silicon,
+            otherwise :class:`OpenAIWhisperSpeechRecognizer`.
+        """
         if torch.mps.is_available():
             print("Choosing MLX Whisper model.")
             return MLXWhisperSpeechRecognizer()
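
A typical call site for the factory, as a usage sketch (the placeholder audio is illustrative):

    import numpy as np

    # Pick the best backend for this machine, then transcribe one utterance.
    recognizer = SpeechRecognizer.best_type()
    recognizer.load_model()

    utterance = np.zeros(16_000, dtype=np.float32)  # 1 s of silence as a placeholder
    text = recognizer.recognize_speech(utterance)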
@@ -71,12 +93,20 @@ class SpeechRecognizer(abc.ABC):
 
 
 class MLXWhisperSpeechRecognizer(SpeechRecognizer):
+    """
+    Speech recognizer using the MLX framework (optimized for Apple Silicon).
+    """
+
     def __init__(self, limit_output_length=True):
         super().__init__(limit_output_length)
         self.was_loaded = False
         self.model_name = settings.speech_model_settings.mlx_model_name
 
     def load_model(self):
+        """
+        Ensures the model is downloaded and cached. MLX loads dynamically, so this
+        pre-fetches the model.
+        """
         if self.was_loaded:
             return
         # There appears to be no dedicated mechanism to preload a model, but this `get_model` does
@@ -94,11 +124,18 @@ class MLXWhisperSpeechRecognizer(SpeechRecognizer):
 
 
 class OpenAIWhisperSpeechRecognizer(SpeechRecognizer):
+    """
+    Speech recognizer using the standard OpenAI Whisper library (PyTorch).
+    """
+
     def __init__(self, limit_output_length=True):
         super().__init__(limit_output_length)
         self.model = None
 
     def load_model(self):
+        """
+        Loads the OpenAI Whisper model onto the available device (CUDA or CPU).
+        """
         if self.model is not None:
             return
         device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
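
The rest of this load path is truncated. A plausible continuation, assuming the standard openai-whisper API; the model name here is illustrative, not the project's configured value:

    import torch
    import whisper

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # Loads weights onto the selected device; "base" stands in for the
    # configured model name, which is not visible in this diff.
    model = whisper.load_model("base", device=device)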
@@ -13,11 +13,26 @@ from .speech_recognizer import SpeechRecognizer
 
 class TranscriptionAgent(BaseAgent):
     """
-    An agent which listens to audio fragments with voice, transcribes them, and sends the
-    transcription to other agents.
+    Transcription Agent.
+
+    This agent listens to audio fragments (containing speech) on a ZMQ SUB socket,
+    transcribes them using the configured :class:`SpeechRecognizer`, and sends the
+    resulting text to other agents (e.g., the Text Belief Extractor).
+
+    It uses an internal semaphore to limit the number of concurrent transcription tasks.
+
+    :ivar audio_in_address: The ZMQ address to receive audio from (usually from VAD Agent).
+    :ivar audio_in_socket: The ZMQ SUB socket instance.
+    :ivar speech_recognizer: The speech recognition engine instance.
+    :ivar _concurrency: Semaphore to limit concurrent transcriptions.
     """
 
     def __init__(self, audio_in_address: str):
+        """
+        Initialize the Transcription Agent.
+
+        :param audio_in_address: The ZMQ address of the audio source (e.g., VAD output).
+        """
         super().__init__(settings.agent_settings.transcription_name)
 
         self.audio_in_address = audio_in_address
@@ -26,6 +41,13 @@ class TranscriptionAgent(BaseAgent):
         self._concurrency = None
 
     async def setup(self):
+        """
+        Initialize the agent resources.
+
+        1. Connects to the audio input ZMQ socket.
+        2. Initializes the :class:`SpeechRecognizer` (choosing the best available backend).
+        3. Starts the background transcription loop.
+        """
         self.logger.info("Setting up %s", self.name)
 
         self._connect_audio_in_socket()
@@ -42,23 +64,45 @@ class TranscriptionAgent(BaseAgent):
         self.logger.info("Finished setting up %s", self.name)
 
     async def stop(self):
+        """
+        Stop the agent and close sockets.
+        """
         assert self.audio_in_socket is not None
         self.audio_in_socket.close()
         self.audio_in_socket = None
         return await super().stop()
 
     def _connect_audio_in_socket(self):
+        """
+        Helper to connect the ZMQ SUB socket for audio input.
+        """
         self.audio_in_socket = azmq.Context.instance().socket(zmq.SUB)
         self.audio_in_socket.setsockopt_string(zmq.SUBSCRIBE, "")
         self.audio_in_socket.connect(self.audio_in_address)
 
     async def _transcribe(self, audio: np.ndarray) -> str:
+        """
+        Run the speech recognition on the audio data.
+
+        This runs in a separate thread (via `asyncio.to_thread`) to avoid blocking the event loop,
+        constrained by the concurrency semaphore.
+
+        :param audio: The audio data as a numpy array.
+        :return: The transcribed text string.
+        """
         assert self._concurrency is not None and self.speech_recognizer is not None
         async with self._concurrency:
             return await asyncio.to_thread(self.speech_recognizer.recognize_speech, audio)
 
     async def _share_transcription(self, transcription: str):
-        """Share a transcription to the other agents that depend on it."""
+        """
+        Share a transcription to the other agents that depend on it.
+
+        Currently sends to:
+        - :attr:`settings.agent_settings.text_belief_extractor_name`
+
+        :param transcription: The transcribed text.
+        """
         receiver_names = [
             settings.agent_settings.text_belief_extractor_name,
         ]
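
The `_transcribe` pattern above (thread offload gated by a semaphore) in isolation, as a minimal self-contained sketch; the recognizer stub and the limit of two are illustrative:

    import asyncio

    import numpy as np


    def blocking_recognize(audio: np.ndarray) -> str:
        """Stand-in for SpeechRecognizer.recognize_speech (CPU/GPU-bound)."""
        return "hello world"


    async def transcribe(audio: np.ndarray, concurrency: asyncio.Semaphore) -> str:
        # The semaphore bounds how many recognizer threads run at once,
        # while to_thread keeps the event loop responsive.
        async with concurrency:
            return await asyncio.to_thread(blocking_recognize, audio)


    async def main():
        sem = asyncio.Semaphore(2)  # at most two concurrent transcriptions
        audio = np.zeros(16_000, dtype=np.float32)
        results = await asyncio.gather(*(transcribe(audio, sem) for _ in range(5)))
        print(results)


    asyncio.run(main())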
@@ -72,6 +116,12 @@ class TranscriptionAgent(BaseAgent):
         await self.send(message)
 
     async def _transcribing_loop(self) -> None:
+        """
+        The main loop for receiving audio and triggering transcription.
+
+        Receives audio chunks from ZMQ, decodes them to float32, and calls :meth:`_transcribe`.
+        If speech is found, it calls :meth:`_share_transcription`.
+        """
         while self._running:
             try:
                 assert self.audio_in_socket is not None
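
The loop body is truncated here. Decoding raw ZMQ bytes back into float32 samples might look like this sketch; the wire format is an assumption, not confirmed by the diff:

    import numpy as np

    def decode_audio(frame: bytes) -> np.ndarray:
        # Assumes the VAD agent publishes raw little-endian float32 PCM at 16 kHz.
        return np.frombuffer(frame, dtype=np.float32)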
@@ -15,6 +15,8 @@ class SocketPoller[T]:
     """
     Convenience class for polling a socket for data with a timeout, persisting a zmq.Poller for
     multiple usages.
+
+    :param T: The type of data returned by the socket.
     """
 
     def __init__(
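
Underneath `SocketPoller`, the persistent-poller idea is plain pyzmq. A minimal sketch of polling one socket with a timeout (the address is a placeholder):

    import zmq

    ctx = zmq.Context.instance()
    sock = ctx.socket(zmq.SUB)
    sock.setsockopt_string(zmq.SUBSCRIBE, "")
    sock.connect("tcp://localhost:5555")  # placeholder address

    poller = zmq.Poller()
    poller.register(sock, zmq.POLLIN)     # registered once, reused across calls

    def receive(timeout_ms: int = 100) -> bytes | None:
        """Return one message, or None if nothing arrives within the timeout."""
        events = dict(poller.poll(timeout_ms))
        if sock in events:
            return sock.recv()
        return None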
@@ -35,7 +37,7 @@ class SocketPoller[T]:
         """
         Get data from the socket, or None if the timeout is reached.
 
-        :param timeout_ms: If given, the timeout. Otherwise, `self.timeout_ms` is used.
+        :param timeout_ms: If given, the timeout. Otherwise, ``self.timeout_ms`` is used.
         :return: Data from the socket or None.
         """
         timeout_ms = timeout_ms or self.timeout_ms
@@ -47,11 +49,27 @@ class SocketPoller[T]:
 
 class VADAgent(BaseAgent):
     """
-    An agent which listens to an audio stream, does Voice Activity Detection (VAD), and sends
-    fragments with detected speech to other agents over ZeroMQ.
+    Voice Activity Detection (VAD) Agent.
+
+    This agent:
+    1. Receives an audio stream (via ZMQ).
+    2. Processes the audio using the Silero VAD model to detect speech.
+    3. Buffers potential speech segments.
+    4. Publishes valid speech fragments (containing speech plus small buffer) to a ZMQ PUB socket.
+    5. Instantiates and starts agents (like :class:`TranscriptionAgent`) that use this output.
+
+    :ivar audio_in_address: Address of the input audio stream.
+    :ivar audio_in_bind: Whether to bind or connect to the input address.
+    :ivar audio_out_socket: ZMQ PUB socket for sending speech fragments.
    """
 
     def __init__(self, audio_in_address: str, audio_in_bind: bool):
+        """
+        Initialize the VAD Agent.
+
+        :param audio_in_address: ZMQ address for input audio.
+        :param audio_in_bind: True if this agent should bind to the input address, False to connect.
+        """
         super().__init__(settings.agent_settings.vad_name)
 
         self.audio_in_address = audio_in_address
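
For context, loading Silero VAD from Torch Hub and scoring one chunk generally looks like the sketch below. The chunk size, sample rate handling, and threshold are assumptions; the agent's exact chunking is not shown in the diff:

    import torch

    # Silero VAD expects 16 kHz mono float32 chunks (512 samples in recent versions).
    model, _utils = torch.hub.load("snakers4/silero-vad", "silero_vad")

    chunk = torch.zeros(512)                   # placeholder audio chunk
    speech_prob = model(chunk, 16_000).item()  # probability that the chunk is speech
    is_speech = speech_prob > 0.5              # threshold is illustrative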
@@ -67,6 +85,15 @@ class VADAgent(BaseAgent):
         self.model = None
 
     async def setup(self):
+        """
+        Initialize resources.
+
+        1. Connects audio input socket.
+        2. Binds audio output socket (random port).
+        3. Loads VAD model from Torch Hub.
+        4. Starts the streaming loop.
+        5. Instantiates and starts the :class:`TranscriptionAgent` with the output address.
+        """
         self.logger.info("Setting up %s", self.name)
 
         self._connect_audio_in_socket()
@@ -123,7 +150,9 @@ class VADAgent(BaseAgent):
         self.audio_in_poller = SocketPoller[bytes](self.audio_in_socket)
 
     def _connect_audio_out_socket(self) -> int | None:
-        """Returns the port bound, or None if binding failed."""
+        """
+        Returns the port bound, or None if binding failed.
+        """
         try:
             self.audio_out_socket = azmq.Context.instance().socket(zmq.PUB)
             return self.audio_out_socket.bind_to_random_port("tcp://localhost", max_tries=100)
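
`bind_to_random_port` is standard pyzmq; a small usage sketch of the bind-then-share pattern this method implements:

    import zmq

    pub = zmq.Context.instance().socket(zmq.PUB)
    # Let the OS pick a free port; raises ZMQBindError after max_tries failures.
    port = pub.bind_to_random_port("tcp://127.0.0.1", max_tries=100)
    address = f"tcp://127.0.0.1:{port}"  # hand this to downstream subscribers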
@@ -144,6 +173,15 @@ class VADAgent(BaseAgent):
         self._ready.set()
 
     async def _streaming_loop(self):
+        """
+        Main loop for processing audio stream.
+
+        1. Polls for new audio chunks.
+        2. Passes chunk to VAD model.
+        3. Manages `i_since_speech` counter to determine start/end of speech.
+        4. Buffers speech + context.
+        5. Sends complete speech segment to output socket when silence is detected.
+        """
         await self._ready.wait()
         while self._running:
             assert self.audio_in_poller is not None
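
The `i_since_speech` bookkeeping in step 3 is a classic hangover counter. A minimal sketch of the segmentation logic under assumed thresholds (the real loop also keeps pre-speech context, which is omitted here):

    class SpeechSegmenter:
        """Hangover-style segmenter: ends a segment after N silent chunks."""

        def __init__(self, silence_chunks_to_end: int = 10):  # threshold is illustrative
            self.silence_chunks_to_end = silence_chunks_to_end
            self.i_since_speech: int | None = None  # None = not inside a segment
            self.buffer: list = []

        def process(self, chunk, is_speech: bool):
            """Feed one VAD-scored chunk; return a finished segment or None."""
            self.buffer.append(chunk)
            if is_speech:
                self.i_since_speech = 0  # (re)start counting from the last speech
            elif self.i_since_speech is not None:
                self.i_since_speech += 1
                if self.i_since_speech >= self.silence_chunks_to_end:
                    segment = self.buffer          # speech plus trailing buffer
                    self.buffer = []
                    self.i_since_speech = None
                    return segment
            return None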