feat: apply new agent naming standards

Expand abbreviations to remove ambiguity and simplify agent names to reduce repetition.

ref: N25B-257
This commit is contained in:
Twirre Meulenbelt
2025-11-19 15:56:09 +01:00
parent f4dbca5b94
commit efe49c219c
39 changed files with 218 additions and 222 deletions

View File

@@ -0,0 +1,4 @@
from .transcription_agent.transcription_agent import (
    TranscriptionAgent as TranscriptionAgent,
)
from .vad_agent import VADAgent as VADAgent

View File

@@ -0,0 +1,107 @@
import abc
import sys

if sys.platform == "darwin":
    import mlx.core as mx
    import mlx_whisper
    from mlx_whisper.transcribe import ModelHolder

import numpy as np
import torch
import whisper


class SpeechRecognizer(abc.ABC):
    def __init__(self, limit_output_length=True):
        """
        :param limit_output_length: When `True`, the length of the generated transcription will be
            limited by the length of the input audio and some heuristics.
        """
        self.limit_output_length = limit_output_length

    @abc.abstractmethod
    def load_model(self): ...

    @abc.abstractmethod
    def recognize_speech(self, audio: np.ndarray) -> str:
        """
        Recognize speech from the given audio sample.

        :param audio: A full utterance sample. Audio must be 16 kHz, mono, np.float32, values in
            the range [-1.0, 1.0].
        :return: Recognized speech.
        """

    @staticmethod
    def _estimate_max_tokens(audio: np.ndarray) -> int:
        """
        Estimate the maximum length of a given audio sample in tokens. Assumes a maximum speaking
        rate of 450 words per minute (3x average), and assumes that 3 words is 4 tokens.

        :param audio: The audio sample (16 kHz) to use for length estimation.
        :return: The estimated length of the transcribed audio in tokens.
        """
        length_seconds = len(audio) / 16_000
        length_minutes = length_seconds / 60
        word_count = length_minutes * 450
        token_count = word_count / 3 * 4
        return int(token_count) + 10

    def _get_decode_options(self, audio: np.ndarray) -> dict:
        """
        :param audio: The audio sample (16 kHz) to use to determine options like max decode length.
        :return: A dict that can be used to construct `whisper.DecodingOptions`.
        """
        options = {}
        if self.limit_output_length:
            options["sample_len"] = self._estimate_max_tokens(audio)
        return options

    @staticmethod
    def best_type():
        """Get the best type of SpeechRecognizer based on system capabilities."""
        if torch.mps.is_available():
            print("Choosing MLX Whisper model.")
            return MLXWhisperSpeechRecognizer()
        else:
            print("Choosing reference Whisper model.")
            return OpenAIWhisperSpeechRecognizer()


class MLXWhisperSpeechRecognizer(SpeechRecognizer):
    def __init__(self, limit_output_length=True):
        super().__init__(limit_output_length)
        self.was_loaded = False
        self.model_name = "mlx-community/whisper-small.en-mlx"

    def load_model(self):
        if self.was_loaded:
            return
        # There appears to be no dedicated mechanism to preload a model, but this `get_model` does
        # store it in memory for later usage.
        ModelHolder.get_model(self.model_name, mx.float16)
        self.was_loaded = True

    def recognize_speech(self, audio: np.ndarray) -> str:
        self.load_model()
        return mlx_whisper.transcribe(
            audio,
            path_or_hf_repo=self.model_name,
            **self._get_decode_options(audio),
        )["text"].strip()


class OpenAIWhisperSpeechRecognizer(SpeechRecognizer):
    def __init__(self, limit_output_length=True):
        super().__init__(limit_output_length)
        self.model = None

    def load_model(self):
        if self.model is not None:
            return
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.model = whisper.load_model("small.en", device=device)

    def recognize_speech(self, audio: np.ndarray) -> str:
        self.load_model()
        return whisper.transcribe(self.model, audio, **self._get_decode_options(audio))["text"]
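
For orientation, a minimal usage sketch (not part of this commit; `utterance` stands in for a 16 kHz mono float32 NumPy array):

    recognizer = SpeechRecognizer.best_type()
    recognizer.load_model()  # optional warm-up; recognize_speech() also loads lazily
    text = recognizer.recognize_speech(utterance)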

View File

@@ -0,0 +1,86 @@
import asyncio

import numpy as np
import zmq
import zmq.asyncio as azmq
from spade.behaviour import CyclicBehaviour
from spade.message import Message

from control_backend.agents import BaseAgent
from control_backend.core.config import settings

from .speech_recognizer import SpeechRecognizer


class TranscriptionAgent(BaseAgent):
    """
    An agent which listens to audio fragments containing speech, transcribes them, and sends the
    transcription to other agents.
    """

    def __init__(self, audio_in_address: str):
        jid = settings.agent_settings.transcription_name + "@" + settings.agent_settings.host
        super().__init__(jid, settings.agent_settings.transcription_name)
        self.audio_in_address = audio_in_address
        self.audio_in_socket: azmq.Socket | None = None

    class TranscribingBehaviour(CyclicBehaviour):
        def __init__(self, audio_in_socket: azmq.Socket):
            super().__init__()
            self.audio_in_socket = audio_in_socket
            self.speech_recognizer = SpeechRecognizer.best_type()
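            # Allow at most three transcriptions to run concurrently in worker threads.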
            self._concurrency = asyncio.Semaphore(3)

        def warmup(self):
            """Load the transcription model into memory to speed up the first transcription."""
            self.speech_recognizer.load_model()

        async def _transcribe(self, audio: np.ndarray) -> str:
            async with self._concurrency:
                return await asyncio.to_thread(self.speech_recognizer.recognize_speech, audio)

        async def _share_transcription(self, transcription: str):
            """Share a transcription with the other agents that depend on it."""
            receiver_jids = [
                settings.agent_settings.texbdi_text_belief_agent_name
                + "@"
                + settings.agent_settings.host,
            ]  # Set message receivers here
            for receiver_jid in receiver_jids:
                message = Message(to=receiver_jid, body=transcription)
                await self.send(message)

        async def run(self) -> None:
            audio = await self.audio_in_socket.recv()
            audio = np.frombuffer(audio, dtype=np.float32)
            speech = await self._transcribe(audio)
            if not speech:
                self.agent.logger.info("Nothing transcribed.")
                return
            self.agent.logger.info("Transcribed speech: %s", speech)
            await self._share_transcription(speech)

    async def stop(self):
        self.audio_in_socket.close()
        self.audio_in_socket = None
        return await super().stop()

    def _connect_audio_in_socket(self):
        self.audio_in_socket = azmq.Context.instance().socket(zmq.SUB)
        self.audio_in_socket.setsockopt_string(zmq.SUBSCRIBE, "")
        self.audio_in_socket.connect(self.audio_in_address)

    async def setup(self):
        self.logger.info("Setting up %s", self.jid)
        self._connect_audio_in_socket()
        transcribing = self.TranscribingBehaviour(self.audio_in_socket)
        transcribing.warmup()
        self.add_behaviour(transcribing)
        self.logger.info("Finished setting up %s", self.jid)

View File

@@ -0,0 +1,172 @@
import numpy as np
import torch
import zmq
import zmq.asyncio as azmq
from spade.behaviour import CyclicBehaviour

from control_backend.agents import BaseAgent
from control_backend.core.config import settings

from .transcription_agent.transcription_agent import TranscriptionAgent


class SocketPoller[T]:
    """
    Convenience class for polling a socket for data with a timeout, persisting a zmq.Poller for
    multiple usages.
    """

    def __init__(self, socket: azmq.Socket, timeout_ms: int = 100):
        """
        :param socket: The socket to poll and get data from.
        :param timeout_ms: A timeout in milliseconds to wait for data.
        """
        self.socket = socket
        # Use the asyncio-aware poller so waiting for data does not block the event loop.
        self.poller = azmq.Poller()
        self.poller.register(self.socket, zmq.POLLIN)
        self.timeout_ms = timeout_ms

    async def poll(self, timeout_ms: int | None = None) -> T | None:
        """
        Get data from the socket, or None if the timeout is reached.

        :param timeout_ms: If given, the timeout. Otherwise, `self.timeout_ms` is used.
        :return: Data from the socket or None.
        """
        timeout_ms = timeout_ms or self.timeout_ms
        socks = dict(await self.poller.poll(timeout_ms))
        if socks.get(self.socket) == zmq.POLLIN:
            return await self.socket.recv()
        return None


class StreamingBehaviour(CyclicBehaviour):
    def __init__(self, audio_in_socket: azmq.Socket, audio_out_socket: azmq.Socket):
        super().__init__()
        self.audio_in_poller = SocketPoller[bytes](audio_in_socket)
        self.model, _ = torch.hub.load(
            repo_or_dir="snakers4/silero-vad", model="silero_vad", force_reload=False
        )
        self.audio_out_socket = audio_out_socket
        self.audio_buffer = np.array([], dtype=np.float32)
        # Counts chunks since speech was last detected; used to allow small pauses in speech.
        self.i_since_speech = 100
        self._ready = False

    async def reset(self):
        """Clears the ZeroMQ queue and tells this behaviour to start."""
        discarded = 0
        while await self.audio_in_poller.poll(1) is not None:
            discarded += 1
        self.agent.logger.info("Discarded %d audio packets before starting.", discarded)
        self._ready = True

    async def run(self) -> None:
        if not self._ready:
            return
        data = await self.audio_in_poller.poll()
        if data is None:
            if len(self.audio_buffer) > 0:
                self.agent.logger.debug(
                    "No audio data received. Discarding buffer until new data arrives."
                )
                self.audio_buffer = np.array([], dtype=np.float32)
                self.i_since_speech = 100
            return
        # Copy: np.frombuffer returns a read-only view, and torch needs a writable array.
        chunk = np.frombuffer(data, dtype=np.float32).copy()
        prob = self.model(torch.from_numpy(chunk), 16000).item()
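        # Silero VAD returns a speech probability for the chunk; 0.5 is its commonly used threshold.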
        if prob > 0.5:
            if self.i_since_speech > 3:
                self.agent.logger.debug("Speech started.")
            self.audio_buffer = np.append(self.audio_buffer, chunk)
            self.i_since_speech = 0
            return
        self.i_since_speech += 1
        # prob <= 0.5, so speech may have ended. Wait a few more chunks to be more certain.
        if self.i_since_speech <= 3:
            self.audio_buffer = np.append(self.audio_buffer, chunk)
            return
        # Speech probably ended. Make sure we have a usable amount of data.
        if len(self.audio_buffer) >= 3 * len(chunk):
            self.agent.logger.debug("Speech ended.")
            await self.audio_out_socket.send(self.audio_buffer[: -2 * len(chunk)].tobytes())
        # At this point, we know that the speech has ended.
        # Keep the current (non-speech) chunk as the start of the next buffer, for a smoother onset.
        self.audio_buffer = chunk


class VADAgent(BaseAgent):
    """
    An agent which listens to an audio stream, does Voice Activity Detection (VAD), and sends
    fragments with detected speech to other agents over ZeroMQ.
    """

    def __init__(self, audio_in_address: str, audio_in_bind: bool):
        jid = settings.agent_settings.vad_name + "@" + settings.agent_settings.host
        super().__init__(jid, settings.agent_settings.vad_name)
        self.audio_in_address = audio_in_address
        self.audio_in_bind = audio_in_bind
        self.audio_in_socket: azmq.Socket | None = None
        self.audio_out_socket: azmq.Socket | None = None
        self.streaming_behaviour: StreamingBehaviour | None = None

    async def stop(self):
        """
        Stop listening to audio, stop publishing audio, close sockets.
        """
        if self.audio_in_socket is not None:
            self.audio_in_socket.close()
            self.audio_in_socket = None
        if self.audio_out_socket is not None:
            self.audio_out_socket.close()
            self.audio_out_socket = None
        return await super().stop()

    def _connect_audio_in_socket(self):
        self.audio_in_socket = azmq.Context.instance().socket(zmq.SUB)
        self.audio_in_socket.setsockopt_string(zmq.SUBSCRIBE, "")
        if self.audio_in_bind:
            self.audio_in_socket.bind(self.audio_in_address)
        else:
            self.audio_in_socket.connect(self.audio_in_address)
        self.audio_in_poller = SocketPoller[bytes](self.audio_in_socket)

    def _connect_audio_out_socket(self) -> int | None:
        """Returns the port bound, or None if binding failed."""
        try:
            self.audio_out_socket = azmq.Context.instance().socket(zmq.PUB)
            return self.audio_out_socket.bind_to_random_port("tcp://*", max_tries=100)
        except zmq.ZMQBindError:
            self.logger.error("Failed to bind an audio output socket after 100 tries.")
            self.audio_out_socket = None
            return None

    async def setup(self):
        self.logger.info("Setting up %s", self.jid)
        self._connect_audio_in_socket()
        audio_out_port = self._connect_audio_out_socket()
        if audio_out_port is None:
            await self.stop()
            return
        audio_out_address = f"tcp://localhost:{audio_out_port}"
        self.streaming_behaviour = StreamingBehaviour(self.audio_in_socket, self.audio_out_socket)
        self.add_behaviour(self.streaming_behaviour)
        # Start agents dependent on the output audio fragments here
        transcriber = TranscriptionAgent(audio_out_address)
        await transcriber.start()
        self.logger.info("Finished setting up %s", self.jid)