import asyncio import numpy as np import torch import zmq import zmq.asyncio as azmq from control_backend.agents import BaseAgent from control_backend.core.config import settings from control_backend.schemas.internal_message import InternalMessage from ...schemas.program_status import PROGRAM_STATUS, ProgramStatus from .transcription_agent.transcription_agent import TranscriptionAgent class SocketPoller[T]: """ Convenience class for polling a socket for data with a timeout, persisting a zmq.Poller for multiple usages. :param T: The type of data returned by the socket. """ def __init__( self, socket: azmq.Socket, timeout_ms: int = settings.behaviour_settings.socket_poller_timeout_ms, ): """ :param socket: The socket to poll and get data from. :param timeout_ms: A timeout in milliseconds to wait for data. """ self.socket = socket self.poller = azmq.Poller() self.poller.register(self.socket, zmq.POLLIN) self.timeout_ms = timeout_ms async def poll(self, timeout_ms: int | None = None) -> T | None: """ Get data from the socket, or None if the timeout is reached. :param timeout_ms: If given, the timeout. Otherwise, ``self.timeout_ms`` is used. :return: Data from the socket or None. """ timeout_ms = timeout_ms or self.timeout_ms socks = dict(await self.poller.poll(timeout_ms)) if socks.get(self.socket) == zmq.POLLIN: return await self.socket.recv() return None class VADAgent(BaseAgent): """ Voice Activity Detection (VAD) Agent. This agent: 1. Receives an audio stream (via ZMQ). 2. Processes the audio using the Silero VAD model to detect speech. 3. Buffers potential speech segments. 4. Publishes valid speech fragments (containing speech plus small buffer) to a ZMQ PUB socket. 5. Instantiates and starts agents (like :class:`TranscriptionAgent`) that use this output. :ivar audio_in_address: Address of the input audio stream. :ivar audio_in_bind: Whether to bind or connect to the input address. :ivar audio_out_socket: ZMQ PUB socket for sending speech fragments. :ivar program_sub_socket: ZMQ SUB socket for receiving program status updates. """ def __init__(self, audio_in_address: str, audio_in_bind: bool): """ Initialize the VAD Agent. :param audio_in_address: ZMQ address for input audio. :param audio_in_bind: True if this agent should bind to the input address, False to connect. """ super().__init__(settings.agent_settings.vad_name) self.audio_in_address = audio_in_address self.audio_in_bind = audio_in_bind self.audio_in_socket: azmq.Socket | None = None self.audio_out_socket: azmq.Socket | None = None self.audio_in_poller: SocketPoller | None = None self.program_sub_socket: azmq.Socket | None = None self.audio_buffer = np.array([], dtype=np.float32) self.i_since_speech = settings.behaviour_settings.vad_initial_since_speech self._ready = asyncio.Event() # Pause control self._reset_needed = False self._paused = asyncio.Event() self._paused.set() # Not paused at start self.model = None async def setup(self): """ Initialize resources. 1. Connects audio input socket. 2. Binds audio output socket (random port). 3. Connects to program communication socket. 4. Loads VAD model from Torch Hub. 5. Starts the streaming loop. 6. Instantiates and starts the :class:`TranscriptionAgent` with the output address. """ self.logger.info("Setting up %s", self.name) self._connect_audio_in_socket() audio_out_port = self._connect_audio_out_socket() if audio_out_port is None: self.logger.error("Could not bind output socket, stopping.") await self.stop() return audio_out_address = f"tcp://localhost:{audio_out_port}" # Connect to internal communication socket self.program_sub_socket = azmq.Context.instance().socket(zmq.SUB) self.program_sub_socket.connect(settings.zmq_settings.internal_sub_address) self.program_sub_socket.subscribe(PROGRAM_STATUS) # Initialize VAD model try: self.model, _ = torch.hub.load( repo_or_dir=settings.vad_settings.repo_or_dir, model=settings.vad_settings.model_name, force_reload=False, ) except Exception: self.logger.exception("Failed to load VAD model.") await self.stop() return self.add_behavior(self._streaming_loop()) self.add_behavior(self._status_loop()) # Start agents dependent on the output audio fragments here transcriber = TranscriptionAgent(audio_out_address) await transcriber.start() self.logger.info("Finished setting up %s", self.name) async def stop(self): """ Stop listening to audio, stop publishing audio, close sockets. """ if self.audio_in_socket is not None: self.audio_in_socket.close() self.audio_in_socket = None if self.audio_out_socket is not None: self.audio_out_socket.close() self.audio_out_socket = None await super().stop() def _connect_audio_in_socket(self): """ Connects (or binds) the socket for listening to audio from RI. :return: """ self.audio_in_socket = azmq.Context.instance().socket(zmq.SUB) self.audio_in_socket.setsockopt_string(zmq.SUBSCRIBE, "") if self.audio_in_bind: self.audio_in_socket.bind(self.audio_in_address) else: self.audio_in_socket.connect(self.audio_in_address) self.audio_in_poller = SocketPoller[bytes](self.audio_in_socket) def _connect_audio_out_socket(self) -> int | None: """ Returns the port bound, or None if binding failed. """ try: self.audio_out_socket = azmq.Context.instance().socket(zmq.PUB) return self.audio_out_socket.bind_to_random_port("tcp://localhost", max_tries=100) except zmq.ZMQBindError: self.logger.error("Failed to bind an audio output socket after 100 tries.") self.audio_out_socket = None return None async def _reset_stream(self): """ Clears the ZeroMQ queue and sets ready state. """ discarded = 0 assert self.audio_in_poller is not None while await self.audio_in_poller.poll(1) is not None: discarded += 1 self.logger.info(f"Discarded {discarded} audio packets before starting.") self._ready.set() async def _status_loop(self): """Loop for checking program status. Only start listening if program is RUNNING.""" while self._running: topic, body = await self.program_sub_socket.recv_multipart() if topic != PROGRAM_STATUS: continue if body != ProgramStatus.RUNNING.value: continue # Program is now running, we can start our stream await self._reset_stream() # We don't care about further status updates self.program_sub_socket.close() break async def _streaming_loop(self): """ Main loop for processing audio stream. 1. Polls for new audio chunks. 2. Passes chunk to VAD model. 3. Manages `i_since_speech` counter to determine start/end of speech. 4. Buffers speech + context. 5. Sends complete speech segment to output socket when silence is detected. """ await self._ready.wait() while self._running: await self._paused.wait() # After being unpaused, reset stream and buffers if self._reset_needed: self.logger.debug("Resuming: resetting stream and buffers.") await self._reset_stream() self.audio_buffer = np.array([], dtype=np.float32) self.i_since_speech = settings.behaviour_settings.vad_initial_since_speech self._reset_needed = False assert self.audio_in_poller is not None data = await self.audio_in_poller.poll() if data is None: if len(self.audio_buffer) > 0: self.logger.debug( "No audio data received. Discarding buffer until new data arrives." ) self.audio_buffer = np.array([], dtype=np.float32) self.i_since_speech = settings.behaviour_settings.vad_initial_since_speech continue # copy otherwise Torch will be sad that it's immutable chunk = np.frombuffer(data, dtype=np.float32).copy() assert self.model is not None prob = self.model(torch.from_numpy(chunk), settings.vad_settings.sample_rate_hz).item() non_speech_patience = settings.behaviour_settings.vad_non_speech_patience_chunks prob_threshold = settings.behaviour_settings.vad_prob_threshold if prob > prob_threshold: if self.i_since_speech > non_speech_patience: self.logger.debug("Speech started.") self.audio_buffer = np.append(self.audio_buffer, chunk) self.i_since_speech = 0 continue self.i_since_speech += 1 # prob < threshold, so speech maybe ended. Wait a bit more before to be more certain if self.i_since_speech <= non_speech_patience: self.audio_buffer = np.append(self.audio_buffer, chunk) continue # Speech probably ended. Make sure we have a usable amount of data. if len(self.audio_buffer) >= 3 * len(chunk): self.logger.debug("Speech ended.") assert self.audio_out_socket is not None await self.audio_out_socket.send(self.audio_buffer[: -2 * len(chunk)].tobytes()) # At this point, we know that the speech has ended. # Prepend the last chunk that had no speech, for a more fluent boundary self.audio_buffer = chunk async def handle_message(self, msg: InternalMessage): """ Handle incoming messages. Expects messages to pause or resume the VAD processing from User Interrupt Agent. :param msg: The received internal message. """ sender = msg.sender if sender == settings.agent_settings.user_interrupt_name: if msg.body == "PAUSE": self.logger.info("Pausing VAD processing.") self._paused.clear() # If the robot needs to pick up speaking where it left off, do not set _reset_needed self._reset_needed = True elif msg.body == "RESUME": self.logger.info("Resuming VAD processing.") self._paused.set() else: self.logger.warning(f"Unknown command from User Interrupt Agent: {msg.body}") else: self.logger.debug(f"Ignoring message from unknown sender: {sender}")