import numpy as np
import torch
import zmq
import zmq.asyncio as azmq
from spade.behaviour import CyclicBehaviour

from control_backend.agents import BaseAgent
from control_backend.core.config import settings

from .transcription.transcription_agent import TranscriptionAgent


class SocketPoller[T]:
    """
    Convenience class for polling a socket for data with a timeout, persisting a
    zmq Poller across multiple usages.
    """

    def __init__(self, socket: azmq.Socket, timeout_ms: int = 100):
        """
        :param socket: The socket to poll and get data from.
        :param timeout_ms: A timeout in milliseconds to wait for data.
        """
        self.socket = socket
        # Use the asyncio poller so polling does not block the event loop.
        self.poller = azmq.Poller()
        self.poller.register(self.socket, zmq.POLLIN)
        self.timeout_ms = timeout_ms

    async def poll(self, timeout_ms: int | None = None) -> T | None:
        """
        Get data from the socket, or None if the timeout is reached.

        :param timeout_ms: If given, the timeout. Otherwise, `self.timeout_ms` is used.
        :return: Data from the socket, or None.
        """
        if timeout_ms is None:
            timeout_ms = self.timeout_ms
        socks = dict(await self.poller.poll(timeout_ms))
        if socks.get(self.socket) == zmq.POLLIN:
            return await self.socket.recv()
        return None


class Streaming(CyclicBehaviour):
    """
    Cyclic behaviour that reads audio chunks from the input socket, runs Silero VAD
    on each chunk, and publishes buffered speech fragments on the output socket once
    speech has ended.
    """

    def __init__(self, audio_in_socket: azmq.Socket, audio_out_socket: azmq.Socket):
        super().__init__()
        self.audio_in_poller = SocketPoller[bytes](audio_in_socket)
        self.model, _ = torch.hub.load(
            repo_or_dir="snakers4/silero-vad", model="silero_vad", force_reload=False
        )
        self.audio_out_socket = audio_out_socket
        self.audio_buffer = np.array([], dtype=np.float32)
        # Number of chunks since speech was last detected. The large initial value
        # means "no ongoing speech"; small values allow brief pauses within speech.
        self.i_since_speech = 100
        self._ready = False

    async def reset(self):
        """Clear the ZeroMQ queue and tell this behaviour to start."""
        discarded = 0
        while await self.audio_in_poller.poll(1) is not None:
            discarded += 1
        self.agent.logger.info(f"Discarded {discarded} audio packets before starting.")
        self._ready = True

    async def run(self) -> None:
        if not self._ready:
            return

        data = await self.audio_in_poller.poll()
        if data is None:
            if len(self.audio_buffer) > 0:
                self.agent.logger.debug(
                    "No audio data received. Discarding buffer until new data arrives."
                )
                self.audio_buffer = np.array([], dtype=np.float32)
                self.i_since_speech = 100
            return

        # Copy: np.frombuffer returns a read-only view, and torch.from_numpy
        # needs a writable array.
        chunk = np.frombuffer(data, dtype=np.float32).copy()
        prob = self.model(torch.from_numpy(chunk), 16000).item()

        if prob > 0.5:
            if self.i_since_speech > 3:
                self.agent.logger.debug("Speech started.")
            self.audio_buffer = np.append(self.audio_buffer, chunk)
            self.i_since_speech = 0
            return

        self.i_since_speech += 1

        # prob <= 0.5, so speech may have ended. Wait a few more chunks to be certain.
        if self.i_since_speech <= 3:
            self.audio_buffer = np.append(self.audio_buffer, chunk)
            return

        # Speech probably ended. Make sure we have a usable amount of data,
        # and drop the trailing non-speech chunks before sending.
        if len(self.audio_buffer) >= 3 * len(chunk):
            self.agent.logger.debug("Speech ended.")
            await self.audio_out_socket.send(self.audio_buffer[: -2 * len(chunk)].tobytes())

        # At this point, speech has ended. Keep only the last (non-speech) chunk
        # so the next fragment starts with a more fluent boundary.
        self.audio_buffer = chunk


class VADAgent(BaseAgent):
    """
    An agent which listens to an audio stream, does Voice Activity Detection (VAD),
    and sends fragments with detected speech to other agents over ZeroMQ.
    """

    def __init__(self, audio_in_address: str, audio_in_bind: bool):
        jid = settings.agent_settings.vad_agent_name + "@" + settings.agent_settings.host
        super().__init__(jid, settings.agent_settings.vad_agent_name)
        self.audio_in_address = audio_in_address
        self.audio_in_bind = audio_in_bind
        self.audio_in_socket: azmq.Socket | None = None
        self.audio_out_socket: azmq.Socket | None = None
        self.streaming_behaviour: Streaming | None = None

    async def stop(self):
        """
        Stop listening to audio, stop publishing audio, and close the sockets.
        """
        if self.audio_in_socket is not None:
            self.audio_in_socket.close()
            self.audio_in_socket = None
        if self.audio_out_socket is not None:
            self.audio_out_socket.close()
            self.audio_out_socket = None
        return await super().stop()

    def _connect_audio_in_socket(self):
        self.audio_in_socket = azmq.Context.instance().socket(zmq.SUB)
        self.audio_in_socket.setsockopt_string(zmq.SUBSCRIBE, "")
        if self.audio_in_bind:
            self.audio_in_socket.bind(self.audio_in_address)
        else:
            self.audio_in_socket.connect(self.audio_in_address)
        self.audio_in_poller = SocketPoller[bytes](self.audio_in_socket)

    def _connect_audio_out_socket(self) -> int | None:
        """Return the bound port, or None if binding failed."""
        try:
            self.audio_out_socket = azmq.Context.instance().socket(zmq.PUB)
            return self.audio_out_socket.bind_to_random_port("tcp://*", max_tries=100)
        except zmq.ZMQBindError:
            self.logger.error("Failed to bind an audio output socket after 100 tries.")
            self.audio_out_socket = None
            return None

    async def setup(self):
        self.logger.info("Setting up %s", self.jid)

        self._connect_audio_in_socket()
        audio_out_port = self._connect_audio_out_socket()
        if audio_out_port is None:
            await self.stop()
            return
        audio_out_address = f"tcp://localhost:{audio_out_port}"

        self.streaming_behaviour = Streaming(self.audio_in_socket, self.audio_out_socket)
        self.add_behaviour(self.streaming_behaviour)

        # Start agents that depend on the output audio fragments here.
        transcriber = TranscriptionAgent(audio_out_address)
        await transcriber.start()

        self.logger.info("Finished setting up %s", self.jid)