# Files
# pepperplus-cb/src/control_backend/agents/perception/transcription_agent/transcription_agent.py
# 2026-01-29 15:36:28 +01:00
#
# 155 lines
# 5.6 KiB
# Python

"""
This program has been developed by students from the bachelor Computer Science at Utrecht
University within the Software Project course.
© Copyright Utrecht University (Department of Information and Computing Sciences)
"""
import asyncio
import logging
import numpy as np
import zmq
import zmq.asyncio as azmq
from control_backend.agents import BaseAgent
from control_backend.core.agent_system import InternalMessage
from control_backend.core.config import settings
from .speech_recognizer import SpeechRecognizer
# Logger for experiment records. Its name is read from the project settings;
# TranscriptionAgent._share_transcription emits final user transcriptions through it.
experiment_logger = logging.getLogger(settings.logging_settings.experiment_logger_name)
class TranscriptionAgent(BaseAgent):
    """
    Transcription Agent.

    This agent listens to audio fragments (containing speech) on a ZMQ SUB socket,
    transcribes them using the configured :class:`SpeechRecognizer`, and sends the
    resulting text to other agents (e.g., the Text Belief Extractor).

    It uses an internal semaphore to limit the number of concurrent transcription tasks.

    :ivar audio_in_address: The ZMQ address to receive audio from (usually from VAD Agent).
    :ivar audio_in_socket: The ZMQ SUB socket instance, or ``None`` while not connected.
    :ivar speech_recognizer: The speech recognition engine instance (set in :meth:`setup`).
    :ivar _concurrency: Semaphore to limit concurrent transcriptions (set in :meth:`setup`).
    :ivar _current_speech_reference: The reference of the current user utterance, for
        synchronising experiment logs.
    """

    def __init__(self, audio_in_address: str):
        """
        Initialize the Transcription Agent.

        No sockets are opened and no model is loaded here; that happens in :meth:`setup`.

        :param audio_in_address: The ZMQ address of the audio source (e.g., VAD output).
        """
        super().__init__(settings.agent_settings.transcription_name)
        self.audio_in_address = audio_in_address
        self.audio_in_socket: azmq.Socket | None = None
        self.speech_recognizer: SpeechRecognizer | None = None
        self._concurrency: asyncio.Semaphore | None = None
        self._current_speech_reference: str | None = None

    async def setup(self):
        """
        Initialize the agent resources.

        1. Connects to the audio input ZMQ socket.
        2. Initializes the :class:`SpeechRecognizer` (choosing the best available backend).
        3. Starts the background transcription loop.
        """
        self.logger.info("Setting up %s", self.name)
        self._connect_audio_in_socket()

        # Initialize recognizer and semaphore
        max_concurrent_tasks = settings.behaviour_settings.transcription_max_concurrent_tasks
        self._concurrency = asyncio.Semaphore(max_concurrent_tasks)
        self.speech_recognizer = SpeechRecognizer.best_type()
        # Warmup: load the model now, before the transcription loop is started.
        self.speech_recognizer.load_model()

        # Start background loop
        self.add_behavior(self._transcribing_loop())
        self.logger.info("Finished setting up %s", self.name)

    async def handle_message(self, msg: InternalMessage):
        """
        Handle an incoming internal message.

        Messages on the ``"voice_activity"`` thread update the current speech reference
        with the message body; messages on any other thread are ignored.

        :param msg: The received internal message.
        """
        if msg.thread == "voice_activity":
            self._current_speech_reference = msg.body

    async def stop(self):
        """
        Stop the agent and close sockets.

        Safe to call when the socket was never opened or has already been closed, so the
        base-class shutdown always runs.

        :return: Whatever the base class ``stop()`` returns.
        """
        if self.audio_in_socket is not None:
            self.audio_in_socket.close()
            # Dropping the reference also tells the transcription loop to exit.
            self.audio_in_socket = None
        return await super().stop()

    def _connect_audio_in_socket(self):
        """
        Connects the ZMQ SUB socket for receiving audio data.

        The socket subscribes with an empty prefix, i.e. to every message published on
        :attr:`audio_in_address`.
        """
        self.audio_in_socket = azmq.Context.instance().socket(zmq.SUB)
        self.audio_in_socket.setsockopt_string(zmq.SUBSCRIBE, "")
        self.audio_in_socket.connect(self.audio_in_address)

    async def _transcribe(self, audio: np.ndarray) -> str:
        """
        Run the speech recognition on the audio data.

        This runs in a separate thread (via `asyncio.to_thread`) to avoid blocking the event loop,
        constrained by the concurrency semaphore.

        :param audio: The audio data as a numpy array.
        :return: The transcribed text string.
        """
        # Both attributes are only assigned in setup(); fail loudly if it has not run.
        assert self._concurrency is not None and self.speech_recognizer is not None
        async with self._concurrency:
            return await asyncio.to_thread(self.speech_recognizer.recognize_speech, audio)

    async def _share_transcription(self, transcription: str):
        """
        Share a transcription to the other agents that depend on it, and to experiment logs.

        Currently sends to:

        - :attr:`settings.agent_settings.text_belief_extractor_name`
        - The UI via the experiment logger

        The experiment log entry is tagged with the speech reference that is current at
        the moment of sharing.

        :param transcription: The transcribed text.
        """
        experiment_logger.chat(
            transcription,
            extra={"role": "user", "reference": self._current_speech_reference, "partial": False},
        )
        message = InternalMessage(
            to=settings.agent_settings.text_belief_extractor_name,
            sender=self.name,
            body=transcription,
        )
        await self.send(message)

    async def _transcribing_loop(self) -> None:
        """
        The main loop for receiving audio and triggering transcription.

        Receives audio chunks from ZMQ, decodes them to float32, and calls :meth:`_transcribe`.
        If speech is found, it calls :meth:`_share_transcription`.

        Each fragment is handled to completion before the next one is received. The loop
        ends when the agent stops running or when the audio socket has been released by
        :meth:`stop`. Any other error is logged and the loop carries on with the next
        fragment.
        """
        while self._running:
            socket = self.audio_in_socket
            if socket is None:
                # Without this exit the loop would fail and retry forever without ever
                # awaiting, which starves the event loop.
                self.logger.debug("Audio socket is closed; stopping transcription loop.")
                break

            try:
                audio_data = await socket.recv()
                audio = np.frombuffer(audio_data, dtype=np.float32)
                speech = await self._transcribe(audio)
                if not speech:
                    self.logger.debug("Nothing transcribed.")
                    continue
                await self._share_transcription(speech)
            except Exception as e:
                # Keep the loop alive on per-fragment failures, but keep the traceback.
                self.logger.error("Error in transcription loop: %s", e, exc_info=True)
                # Yield once so an error that recurs immediately (e.g. a socket closed
                # elsewhere) cannot monopolise the event loop.
                await asyncio.sleep(0)