import asyncio

import numpy as np
import zmq
import zmq.asyncio as azmq

from control_backend.agents import BaseAgent
from control_backend.core.agent_system import InternalMessage
from control_backend.core.config import settings

from .speech_recognizer import SpeechRecognizer


class TranscriptionAgent(BaseAgent):
    """
    Transcription Agent.

    This agent listens to audio fragments (containing speech) on a ZMQ SUB socket,
    transcribes them using the configured :class:`SpeechRecognizer`, and sends the
    resulting text to other agents (e.g., the Text Belief Extractor).

    It uses an internal semaphore to limit the number of concurrent transcription tasks.

    :ivar audio_in_address: The ZMQ address to receive audio from (usually from the VAD Agent).
    :ivar audio_in_socket: The ZMQ SUB socket instance.
    :ivar speech_recognizer: The speech recognition engine instance.
    :ivar _concurrency: Semaphore to limit concurrent transcriptions.
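
    Example (a minimal usage sketch from an async context; the address is made up
    and the surrounding runtime depends on how :class:`BaseAgent` agents are run)::

        agent = TranscriptionAgent(audio_in_address="tcp://127.0.0.1:5556")
        await agent.setup()  # connects the SUB socket and starts the loop
        ...
        await agent.stop()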
    """

    def __init__(self, audio_in_address: str):
        """
        Initialize the Transcription Agent.

        :param audio_in_address: The ZMQ address of the audio source (e.g., VAD output).
        """
        super().__init__(settings.agent_settings.transcription_name)

        self.audio_in_address = audio_in_address
        self.audio_in_socket: azmq.Socket | None = None
        self.speech_recognizer: SpeechRecognizer | None = None
        self._concurrency: asyncio.Semaphore | None = None

    async def setup(self):
        """
        Initialize the agent's resources.

        1. Connects to the audio input ZMQ socket.
        2. Initializes the :class:`SpeechRecognizer` (choosing the best available backend).
        3. Starts the background transcription loop.
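
        The concurrency limit is read from the application settings. A hypothetical
        pydantic-style entry in ``control_backend.core.config`` might look like::

            class BehaviourSettings(BaseSettings):
                transcription_max_concurrent_tasks: int = 2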
        """
        self.logger.info("Setting up %s", self.name)

        self._connect_audio_in_socket()

        # Initialize the recognizer and the semaphore bounding concurrent transcriptions.
        max_concurrent_tasks = settings.behaviour_settings.transcription_max_concurrent_tasks
        self._concurrency = asyncio.Semaphore(max_concurrent_tasks)
        self.speech_recognizer = SpeechRecognizer.best_type()
        self.speech_recognizer.load_model()  # Warm up so the first transcription is not delayed.

        # Run the receive/transcribe loop as a background behavior.
        self.add_behavior(self._transcribing_loop())

        self.logger.info("Finished setting up %s", self.name)

    async def stop(self):
        """
        Stop the agent and close sockets.
        """
        assert self.audio_in_socket is not None
        self.audio_in_socket.close()
        self.audio_in_socket = None
        return await super().stop()

    def _connect_audio_in_socket(self):
        """
        Helper to connect the ZMQ SUB socket for audio input.
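
        The upstream publisher (presumably the VAD Agent) is expected to send raw
        float32 audio bytes on a matching PUB socket. A minimal sketch of such a
        publisher, with a made-up address and sample rate::

            import numpy as np
            import zmq

            pub = zmq.Context.instance().socket(zmq.PUB)
            pub.bind("tcp://127.0.0.1:5556")           # hypothetical address
            chunk = np.zeros(16000, dtype=np.float32)  # 1 s of silence at 16 kHz
            pub.send(chunk.tobytes())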
        """
        self.audio_in_socket = azmq.Context.instance().socket(zmq.SUB)
        # Subscribe to all topics (empty prefix) before connecting to the source.
        self.audio_in_socket.setsockopt_string(zmq.SUBSCRIBE, "")
        self.audio_in_socket.connect(self.audio_in_address)

    async def _transcribe(self, audio: np.ndarray) -> str:
        """
        Run speech recognition on the audio data.

        The blocking recognizer call runs in a worker thread (via :func:`asyncio.to_thread`)
        so it does not block the event loop; the concurrency semaphore bounds how many of
        these threads run at once.

        :param audio: The audio data as a numpy array.
        :return: The transcribed text string.
        """
        assert self._concurrency is not None and self.speech_recognizer is not None
        async with self._concurrency:
            return await asyncio.to_thread(self.speech_recognizer.recognize_speech, audio)

    async def _share_transcription(self, transcription: str):
        """
        Share a transcription with the other agents that depend on it.

        Currently sends to:

        - :attr:`settings.agent_settings.text_belief_extractor_name`

        :param transcription: The transcribed text.
        """
        receiver_names = [
            settings.agent_settings.text_belief_extractor_name,
        ]

        for receiver_name in receiver_names:
            message = InternalMessage(
                to=receiver_name,
                sender=self.name,
                body=transcription,
            )
            await self.send(message)

    async def _transcribing_loop(self) -> None:
        """
        The main loop for receiving audio and triggering transcription.

        Receives audio chunks from ZMQ, decodes them to float32, and calls :meth:`_transcribe`.
        If speech is found, it calls :meth:`_share_transcription`.
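
        Each ZMQ message is assumed to be a raw buffer of float32 samples in native
        byte order, which is what ``np.frombuffer`` reconstructs. A sketch of the
        round trip (the sample rate is an assumption)::

            chunk = np.zeros(16000, dtype=np.float32)  # 1 s at an assumed 16 kHz
            wire = chunk.tobytes()                     # what the publisher sends
            restored = np.frombuffer(wire, dtype=np.float32)
            assert np.array_equal(chunk, restored)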
        """
        while self._running:
            try:
                assert self.audio_in_socket is not None
                audio_data = await self.audio_in_socket.recv()
                audio = np.frombuffer(audio_data, dtype=np.float32)
                speech = await self._transcribe(audio)
                if not speech:
                    self.logger.info("Nothing transcribed.")
                    continue

                self.logger.info("Transcribed speech: %s", speech)
                await self._share_transcription(speech)
            except Exception as e:
                # Use lazy %-formatting, consistent with the rest of the module.
                self.logger.error("Error in transcription loop: %s", e)