From 2bb008994b5b175702b3e5d6c494ab2226143378 Mon Sep 17 00:00:00 2001
From: Twirre Meulenbelt <43213592+TwirreM@users.noreply.github.com>
Date: Tue, 28 Oct 2025 21:57:25 +0100
Subject: [PATCH] feat: implement transcriber agent

Consumes the speech fragments published by the VAD agent and emits the
transcribed text over SPADE's default communication channel; for now
the message is sent without a recipient.

ref: N25B-209
---
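Note (review-only, stripped by git am): a minimal sketch of how the new
agent can be exercised in isolation. It assumes a running XMPP server and
the defaults from core/config.py; the PUB socket stands in for the VAD
agent's output, and the port and the silent audio payload are made up for
illustration.

    import asyncio

    import numpy as np
    import zmq
    import zmq.asyncio as azmq

    from control_backend.agents.transcription import TranscriptionAgent

    async def main():
        address = "tcp://127.0.0.1:5556"  # hypothetical port, any free one works
        pub = azmq.Context.instance().socket(zmq.PUB)
        pub.bind(address)

        agent = TranscriptionAgent(address)
        await agent.start()
        await asyncio.sleep(1)  # give the agent's SUB socket time to connect

        # One second of silence at Whisper's expected 16 kHz sample rate;
        # a real fragment would be float32 speech from the VAD agent.
        await pub.send(np.zeros(16000, dtype=np.float32).tobytes())

        await asyncio.sleep(10)  # wait for the transcription to be logged
        await agent.stop()
        pub.close()

    asyncio.run(main())
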
 .../agents/transcription/__init__.py          |  1 +
 .../agents/transcription/speech_recognizer.py | 62 +++++++++++++++++
 .../transcription/transcription_agent.py      | 69 +++++++++++++++++++
 src/control_backend/agents/vad_agent.py       |  6 +-
 src/control_backend/core/config.py            |  1 +
 5 files changed, 138 insertions(+), 1 deletion(-)
 create mode 100644 src/control_backend/agents/transcription/__init__.py
 create mode 100644 src/control_backend/agents/transcription/speech_recognizer.py
 create mode 100644 src/control_backend/agents/transcription/transcription_agent.py

diff --git a/src/control_backend/agents/transcription/__init__.py b/src/control_backend/agents/transcription/__init__.py
new file mode 100644
index 0000000..3e87e70
--- /dev/null
+++ b/src/control_backend/agents/transcription/__init__.py
@@ -0,0 +1 @@
+from .transcription_agent import TranscriptionAgent as TranscriptionAgent
diff --git a/src/control_backend/agents/transcription/speech_recognizer.py b/src/control_backend/agents/transcription/speech_recognizer.py
new file mode 100644
index 0000000..58523a4
--- /dev/null
+++ b/src/control_backend/agents/transcription/speech_recognizer.py
@@ -0,0 +1,62 @@
+import abc
+import sys
+
+if sys.platform == "darwin":
+    import mlx.core as mx
+    import mlx_whisper
+    from mlx_whisper.transcribe import ModelHolder
+
+import numpy as np
+import torch
+import whisper
+
+
+class SpeechRecognizer(abc.ABC):
+    @abc.abstractmethod
+    def load_model(self): ...
+
+    @abc.abstractmethod
+    def recognize_speech(self, audio: np.ndarray) -> str: ...
+
+    @staticmethod
+    def best_type() -> "SpeechRecognizer":
+        if torch.mps.is_available():
+            print("Choosing MLX Whisper model.")
+            return MLXWhisperSpeechRecognizer()
+        else:
+            print("Choosing reference Whisper model.")
+            return OpenAIWhisperSpeechRecognizer()
+
+
+class MLXWhisperSpeechRecognizer(SpeechRecognizer):
+    def __init__(self):
+        super().__init__()
+        self.model = None
+        self.model_name = "mlx-community/whisper-small.en-mlx"
+
+    def load_model(self):
+        if self.model is not None:
+            return
+        self.model = ModelHolder.get_model(
+            self.model_name, mx.float16
+        )  # ModelHolder also caches the weights for the transcribe() call
+
+    def recognize_speech(self, audio: np.ndarray) -> str:
+        self.load_model()
+        return mlx_whisper.transcribe(audio, path_or_hf_repo=self.model_name)["text"]
+
+
+class OpenAIWhisperSpeechRecognizer(SpeechRecognizer):
+    def __init__(self):
+        super().__init__()
+        self.model = None
+
+    def load_model(self):
+        if self.model is not None:
+            return
+        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+        self.model = whisper.load_model("small.en", device=device)
+
+    def recognize_speech(self, audio: np.ndarray) -> str:
+        self.load_model()
+        return self.model.transcribe(audio)["text"]
diff --git a/src/control_backend/agents/transcription/transcription_agent.py b/src/control_backend/agents/transcription/transcription_agent.py
new file mode 100644
index 0000000..b572f5e
--- /dev/null
+++ b/src/control_backend/agents/transcription/transcription_agent.py
@@ -0,0 +1,69 @@
+import asyncio
+import logging
+
+import numpy as np
+import zmq
+import zmq.asyncio as azmq
+from spade.agent import Agent
+from spade.behaviour import CyclicBehaviour
+from spade.message import Message
+
+from control_backend.agents.transcription.speech_recognizer import SpeechRecognizer
+from control_backend.core.config import settings
+from control_backend.core.zmq_context import context as zmq_context
+
+logger = logging.getLogger(__name__)
+
+
+class TranscriptionAgent(Agent):
+    """
+    An agent which listens to audio fragments containing voice, transcribes them, and
+    sends the transcription to other agents.
+    """
+
+    def __init__(self, audio_in_address: str):
+        jid = settings.agent_settings.transcription_agent_name + "@" + settings.agent_settings.host
+        super().__init__(jid, settings.agent_settings.transcription_agent_name)
+
+        self.audio_in_address = audio_in_address
+        self.audio_in_socket: azmq.Socket | None = None
+
+    class Transcribing(CyclicBehaviour):
+        def __init__(self, audio_in_socket: azmq.Socket):
+            super().__init__()
+            self.audio_in_socket = audio_in_socket
+            self.speech_recognizer = SpeechRecognizer.best_type()
+            self._concurrency = asyncio.Semaphore(3)  # at most 3 transcriptions in flight
+
+        async def _transcribe(self, audio: np.ndarray) -> str:
+            async with self._concurrency:
+                return await asyncio.to_thread(self.speech_recognizer.recognize_speech, audio)
+
+        async def run(self) -> None:
+            audio = await self.audio_in_socket.recv()
+            audio = np.frombuffer(audio, dtype=np.float32)
+            speech = await self._transcribe(audio)
+            logger.info("Transcribed speech: %s", speech)
+
+            message = Message(body=speech)
+            await self.send(message)
+
+    async def stop(self):
+        self.audio_in_socket.close()
+        self.audio_in_socket = None
+        return await super().stop()
+
+    def _connect_audio_in_socket(self):
+        self.audio_in_socket = zmq_context.socket(zmq.SUB)
+        self.audio_in_socket.setsockopt_string(zmq.SUBSCRIBE, "")  # subscribe to everything
+        self.audio_in_socket.connect(self.audio_in_address)
+
+    async def setup(self):
+        logger.info("Setting up %s", self.jid)
+
+        self._connect_audio_in_socket()
+
+        transcribing = self.Transcribing(self.audio_in_socket)
+        self.add_behaviour(transcribing)
+
+        logger.info("Finished setting up %s", self.jid)
diff --git a/src/control_backend/agents/vad_agent.py b/src/control_backend/agents/vad_agent.py
index 4fef563..fc60e48 100644
--- a/src/control_backend/agents/vad_agent.py
+++ b/src/control_backend/agents/vad_agent.py
@@ -7,6 +7,7 @@ import zmq.asyncio as azmq
 from spade.agent import Agent
 from spade.behaviour import CyclicBehaviour
 
+from control_backend.agents.transcription import TranscriptionAgent
 from control_backend.core.config import settings
 from control_backend.core.zmq_context import context as zmq_context
 
@@ -145,10 +146,13 @@ class VADAgent(Agent):
         if audio_out_port is None:
             await self.stop()
             return
+        audio_out_address = f"tcp://localhost:{audio_out_port}"
 
         streaming = Streaming(self.audio_in_socket, self.audio_out_socket)
         self.add_behaviour(streaming)
 
-        # ... start agents dependent on the output audio fragments here
+        # Start agents dependent on the output audio fragments here
+        transcriber = TranscriptionAgent(audio_out_address)
+        await transcriber.start()
 
         logger.info("Finished setting up %s", self.jid)
diff --git a/src/control_backend/core/config.py b/src/control_backend/core/config.py
index 4758618..ea362ce 100644
--- a/src/control_backend/core/config.py
+++ b/src/control_backend/core/config.py
@@ -11,6 +11,7 @@ class AgentSettings(BaseModel):
     bdi_core_agent_name: str = "bdi_core"
     belief_collector_agent_name: str = "belief_collector"
     vad_agent_name: str = "vad_agent"
+    transcription_agent_name: str = "transcription_agent"
     ri_communication_agent_name: str = "ri_communication_agent"
     ri_command_agent_name: str = "ri_command_agent"
 