feat: implement transcriber agent

Consumes speech fragments from the VAD agent, transcribes them, and emits the transcribed text as a SPADE message over the default communication channel; the message has no recipient for now.

ref: N25B-209
commit 2bb008994b
parent 7085719d2f
Author: Twirre Meulenbelt
Date:   2025-10-28 21:57:25 +01:00

5 changed files with 138 additions and 1 deletion

control_backend/agents/transcription/__init__.py

@@ -0,0 +1 @@
from .transcription_agent import TranscriptionAgent as TranscriptionAgent

control_backend/agents/transcription/speech_recognizer.py

@@ -0,0 +1,62 @@
import abc
import sys

# MLX and mlx-whisper are only importable on Apple platforms.
if sys.platform == "darwin":
    import mlx.core as mx
    import mlx_whisper
    from mlx_whisper.transcribe import ModelHolder

import numpy as np
import torch
import whisper


class SpeechRecognizer(abc.ABC):
    """Common interface over the available Whisper implementations."""

    @abc.abstractmethod
    def load_model(self): ...

    @abc.abstractmethod
    def recognize_speech(self, audio: np.ndarray) -> str: ...

    @staticmethod
    def best_type() -> "SpeechRecognizer":
        """Pick the fastest implementation available on this machine."""
        if torch.backends.mps.is_available():
            print("Choosing MLX Whisper model.")
            return MLXWhisperSpeechRecognizer()
        else:
            print("Choosing reference Whisper model.")
            return OpenAIWhisperSpeechRecognizer()


class MLXWhisperSpeechRecognizer(SpeechRecognizer):
    """Whisper running on Apple silicon via MLX."""

    def __init__(self):
        super().__init__()
        self.model = None
        self.model_name = "mlx-community/whisper-small.en-mlx"

    def load_model(self):
        if self.model is not None:
            return
        # ModelHolder caches the weights, so transcribe() below reuses them.
        self.model = ModelHolder.get_model(self.model_name, mx.float16)

    def recognize_speech(self, audio: np.ndarray) -> str:
        self.load_model()
        return mlx_whisper.transcribe(audio, path_or_hf_repo=self.model_name)["text"]


class OpenAIWhisperSpeechRecognizer(SpeechRecognizer):
    """Reference OpenAI Whisper, on CUDA when available, otherwise on CPU."""

    def __init__(self):
        super().__init__()
        self.model = None

    def load_model(self):
        if self.model is not None:
            return
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.model = whisper.load_model("small.en", device=device)

    def recognize_speech(self, audio: np.ndarray) -> str:
        self.load_model()
        return self.model.transcribe(audio)["text"]
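
A quick way to sanity-check the recognizer in isolation, sketched under the assumption that Whisper expects mono float32 audio sampled at 16 kHz; the silent buffer is a stand-in for a real VAD fragment:

import numpy as np

from control_backend.agents.transcription.speech_recognizer import SpeechRecognizer

# Picks MLX Whisper on Apple silicon, the reference implementation elsewhere.
recognizer = SpeechRecognizer.best_type()

# Stand-in input: one second of silence at Whisper's 16 kHz sample rate.
audio = np.zeros(16_000, dtype=np.float32)
print(recognizer.recognize_speech(audio))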

control_backend/agents/transcription/transcription_agent.py

@@ -0,0 +1,69 @@
import asyncio
import logging

import numpy as np
import zmq
import zmq.asyncio as azmq
from spade.agent import Agent
from spade.behaviour import CyclicBehaviour
from spade.message import Message

from control_backend.agents.transcription.speech_recognizer import SpeechRecognizer
from control_backend.core.config import settings
from control_backend.core.zmq_context import context as zmq_context

logger = logging.getLogger(__name__)


class TranscriptionAgent(Agent):
    """
    An agent which listens to audio fragments containing voice, transcribes them, and
    sends the transcription to other agents.
    """

    def __init__(self, audio_in_address: str):
        jid = settings.agent_settings.transcription_agent_name + "@" + settings.agent_settings.host
        super().__init__(jid, settings.agent_settings.transcription_agent_name)
        self.audio_in_address = audio_in_address
        self.audio_in_socket: azmq.Socket | None = None

    class Transcribing(CyclicBehaviour):
        def __init__(self, audio_in_socket: azmq.Socket):
            super().__init__()
            self.audio_in_socket = audio_in_socket
            self.speech_recognizer = SpeechRecognizer.best_type()
            # Bound the number of Whisper calls running in worker threads at once.
            self._concurrency = asyncio.Semaphore(3)

        async def _transcribe(self, audio: np.ndarray) -> str:
            async with self._concurrency:
                # Whisper is blocking; run it off the event loop.
                return await asyncio.to_thread(self.speech_recognizer.recognize_speech, audio)

        async def run(self) -> None:
            audio = await self.audio_in_socket.recv()
            audio = np.frombuffer(audio, dtype=np.float32)
            speech = await self._transcribe(audio)
            logger.info("Transcribed speech: %s", speech)
            # Sent over SPADE's default channel with no recipient for now.
            message = Message(body=speech)
            await self.send(message)

    async def stop(self):
        self.audio_in_socket.close()
        self.audio_in_socket = None
        return await super().stop()

    def _connect_audio_in_socket(self):
        self.audio_in_socket = zmq_context.socket(zmq.SUB)
        self.audio_in_socket.setsockopt_string(zmq.SUBSCRIBE, "")
        self.audio_in_socket.connect(self.audio_in_address)

    async def setup(self):
        logger.info("Setting up %s", self.jid)
        self._connect_audio_in_socket()
        transcribing = self.Transcribing(self.audio_in_socket)
        self.add_behaviour(transcribing)
        logger.info("Finished setting up %s", self.jid)

control_backend/agents/vad/vad_agent.py

@@ -7,6 +7,7 @@ import zmq.asyncio as azmq
 from spade.agent import Agent
 from spade.behaviour import CyclicBehaviour
+from control_backend.agents.transcription import TranscriptionAgent
 from control_backend.core.config import settings
 from control_backend.core.zmq_context import context as zmq_context

@@ -145,10 +146,13 @@ class VADAgent(Agent):
         if audio_out_port is None:
             await self.stop()
             return
+        audio_out_address = f"tcp://localhost:{audio_out_port}"
         streaming = Streaming(self.audio_in_socket, self.audio_out_socket)
         self.add_behaviour(streaming)
-        # ... start agents dependent on the output audio fragments here
+        # Start agents dependent on the output audio fragments here
+        transcriber = TranscriptionAgent(audio_out_address)
+        await transcriber.start()
         logger.info("Finished setting up %s", self.jid)

control_backend/core/config.py

@@ -11,6 +11,7 @@ class AgentSettings(BaseModel):
     bdi_core_agent_name: str = "bdi_core"
     belief_collector_agent_name: str = "belief_collector"
     vad_agent_name: str = "vad_agent"
+    transcription_agent_name: str = "transcription_agent"
     ri_communication_agent_name: str = "ri_communication_agent"
     ri_command_agent_name: str = "ri_command_agent"
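
For reference, the new setting feeds the agent's JID, which TranscriptionAgent.__init__ assembles as name@host; a minimal sketch, where the printed value assumes a hypothetical host of "localhost":

from control_backend.core.config import settings

# As assembled in TranscriptionAgent.__init__:
jid = settings.agent_settings.transcription_agent_name + "@" + settings.agent_settings.host
print(jid)  # e.g. "transcription_agent@localhost"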