85 lines
3.0 KiB
Python
85 lines
3.0 KiB
Python
import asyncio
|
|
import logging
|
|
|
|
import numpy as np
|
|
import zmq
|
|
import zmq.asyncio as azmq
|
|
from spade.agent import Agent
|
|
from spade.behaviour import CyclicBehaviour
|
|
from spade.message import Message
|
|
|
|
from control_backend.agents.transcription.speech_recognizer import SpeechRecognizer
|
|
from control_backend.core.config import settings
|
|
from control_backend.core.zmq_context import context as zmq_context
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class TranscriptionAgent(Agent):
    """
    An agent which listens to audio fragments with voice, transcribes them, and sends the
    transcription to other agents.

    Audio is received on a ZeroMQ SUB socket connected to ``audio_in_address``. Each received
    message is interpreted as a raw buffer of ``float32`` samples.
    """

    def __init__(self, audio_in_address: str):
        """
        Create the agent; the audio socket is only opened later, in ``setup``.

        :param audio_in_address: ZeroMQ address the audio SUB socket connects to.
        """
        agent_name = settings.agent_settings.transcription_agent_name
        jid = f"{agent_name}@{settings.agent_settings.host}"
        # The agent name is passed as the second constructor argument as well.
        super().__init__(jid, agent_name)

        self.audio_in_address = audio_in_address
        # None until ``setup`` connects the socket, and again after ``stop``.
        self.audio_in_socket: azmq.Socket | None = None

    class Transcribing(CyclicBehaviour):
        """Behaviour that receives audio, transcribes it, and forwards the resulting text."""

        def __init__(self, audio_in_socket: azmq.Socket):
            """
            :param audio_in_socket: Already-connected socket from which audio messages are read.
            """
            super().__init__()
            self.audio_in_socket = audio_in_socket
            self.speech_recognizer = SpeechRecognizer.best_type()
            # Allows at most three recognitions to run in worker threads at the same time.
            self._concurrency = asyncio.Semaphore(3)

        def warmup(self):
            """Load the transcription model into memory to speed up the first transcription."""
            self.speech_recognizer.load_model()

        async def _transcribe(self, audio: np.ndarray) -> str:
            """
            Run speech recognition on ``audio`` in a worker thread and return the result.

            The call is made through ``asyncio.to_thread`` so that the event loop is not
            blocked while the recognizer is working.
            """
            async with self._concurrency:
                return await asyncio.to_thread(self.speech_recognizer.recognize_speech, audio)

        async def _share_transcription(self, transcription: str):
            """Share a transcription to the other agents that depend on it."""
            host = settings.agent_settings.host
            receiver_jids = [
                f"{settings.agent_settings.text_belief_extractor_agent_name}@{host}",
            ]  # Set message receivers here

            for receiver_jid in receiver_jids:
                message = Message(to=receiver_jid, body=transcription)
                await self.send(message)

        async def run(self) -> None:
            """Receive one audio message, transcribe it, and share the transcription."""
            raw_audio = await self.audio_in_socket.recv()
            try:
                audio = np.frombuffer(raw_audio, dtype=np.float32)
            except ValueError:
                # np.frombuffer rejects buffers whose size is not a multiple of the item size
                # (4 bytes for float32). Drop such a message instead of letting the exception
                # escape from the behaviour.
                logger.warning(
                    "Dropping malformed audio message of %d bytes", len(raw_audio)
                )
                return

            speech = await self._transcribe(audio)
            logger.info("Transcribed speech: %s", speech)

            await self._share_transcription(speech)

    async def stop(self):
        """
        Close the audio socket, if one is open, and stop the agent.

        Safe to call before ``setup`` has run or more than once: the socket is only closed
        when it actually exists.
        """
        if self.audio_in_socket is not None:
            self.audio_in_socket.close()
            self.audio_in_socket = None
        return await super().stop()

    def _connect_audio_in_socket(self):
        """Create the SUB socket, subscribe with an empty prefix, and connect it."""
        self.audio_in_socket = zmq_context.socket(zmq.SUB)
        # An empty subscription prefix matches every incoming message.
        self.audio_in_socket.setsockopt_string(zmq.SUBSCRIBE, "")
        self.audio_in_socket.connect(self.audio_in_address)

    async def setup(self):
        """Connect the audio socket, warm up the recognizer, and start transcribing."""
        logger.info("Setting up %s", self.jid)

        self._connect_audio_in_socket()

        transcribing = self.Transcribing(self.audio_in_socket)
        # NOTE(review): warmup() is a plain synchronous call, so it runs on the event loop
        # until the model has been loaded.
        transcribing.warmup()
        self.add_behaviour(transcribing)

        logger.info("Finished setting up %s", self.jid)
|