import asyncio

import numpy as np
import zmq
import zmq.asyncio as azmq
from spade.behaviour import CyclicBehaviour
from spade.message import Message

from control_backend.agents import BaseAgent
from control_backend.core.config import settings

from .speech_recognizer import SpeechRecognizer


class TranscriptionAgent(BaseAgent):
    """
    An agent which listens to audio fragments with voice, transcribes them, and sends the
    transcription to other agents.
    """

    def __init__(self, audio_in_address: str):
        """
        Create the agent without opening any sockets yet.

        :param audio_in_address: ZMQ address the audio SUB socket connects to during setup.
        """
        jid = settings.agent_settings.transcription_agent_name + "@" + settings.agent_settings.host
        super().__init__(jid, settings.agent_settings.transcription_agent_name)

        self.audio_in_address = audio_in_address
        # Stays None until `setup` connects it, and is reset to None by `stop`.
        self.audio_in_socket: azmq.Socket | None = None

    class Transcribing(CyclicBehaviour):
        """Behaviour that receives audio fragments, transcribes them and forwards the text."""

        def __init__(self, audio_in_socket: azmq.Socket):
            """
            :param audio_in_socket: Socket from which raw audio fragments are received.
            """
            super().__init__()
            self.audio_in_socket = audio_in_socket
            self.speech_recognizer = SpeechRecognizer.best_type()
            # Caps the number of transcriptions running in worker threads at once.
            self._concurrency = asyncio.Semaphore(3)

        def warmup(self):
            """Load the transcription model into memory to speed up the first transcription."""
            self.speech_recognizer.load_model()

        async def _transcribe(self, audio: np.ndarray) -> str:
            """
            Transcribe one audio fragment in a worker thread.

            :param audio: Audio samples to hand to the speech recognizer.
            :return: Whatever the recognizer returns for this fragment.
            """
            async with self._concurrency:
                # The recognizer call is synchronous; run it off the event loop.
                return await asyncio.to_thread(self.speech_recognizer.recognize_speech, audio)

        async def _share_transcription(self, transcription: str):
            """Share a transcription to the other agents that depend on it."""
            receiver_jids = [
                settings.agent_settings.text_belief_extractor_agent_name
                + "@"
                + settings.agent_settings.host,
            ]  # Set message receivers here

            for receiver_jid in receiver_jids:
                message = Message(to=receiver_jid, body=transcription)
                await self.send(message)

        async def run(self) -> None:
            """
            Receive one audio fragment, transcribe it and share any resulting text.

            The received bytes are interpreted as float32 samples. A fragment whose byte
            length is not a whole number of float32 values is logged and dropped, so a
            single malformed frame does not raise out of this method.
            """
            payload = await self.audio_in_socket.recv()
            try:
                audio = np.frombuffer(payload, dtype=np.float32)
            except ValueError:
                # np.frombuffer rejects buffers whose size is not a multiple of the
                # element size (4 bytes for float32).
                self.agent.logger.warning(
                    "Dropping malformed audio fragment of %d bytes.", len(payload)
                )
                return

            speech = await self._transcribe(audio)
            if not speech:
                self.agent.logger.info("Nothing transcribed.")
                return

            self.agent.logger.info("Transcribed speech: %s", speech)
            await self._share_transcription(speech)

    async def stop(self):
        """
        Close the audio socket, if one is open, and stop the agent.

        Safe to call before `setup` has run or more than once: the socket is only closed
        when it is actually present.
        """
        if self.audio_in_socket is not None:
            self.audio_in_socket.close()
            self.audio_in_socket = None
        return await super().stop()

    def _connect_audio_in_socket(self):
        """Create a SUB socket subscribed to every message and connect it to the audio address."""
        self.audio_in_socket = azmq.Context.instance().socket(zmq.SUB)
        # An empty subscription prefix means no messages are filtered out.
        self.audio_in_socket.setsockopt_string(zmq.SUBSCRIBE, "")
        self.audio_in_socket.connect(self.audio_in_address)

    async def setup(self):
        """Connect the audio socket, warm up the recognizer and register the behaviour."""
        self.logger.info("Setting up %s", self.jid)

        self._connect_audio_in_socket()

        transcribing = self.Transcribing(self.audio_in_socket)
        # NOTE(review): warmup is a synchronous call made from this coroutine, so the
        # event loop waits while the model loads. Confirm that is acceptable at startup.
        transcribing.warmup()
        self.add_behaviour(transcribing)

        self.logger.info("Finished setting up %s", self.jid)