diff --git a/src/control_backend/agents/vad_agent.py b/src/control_backend/agents/vad_agent.py new file mode 100644 index 0000000..10e1d1e --- /dev/null +++ b/src/control_backend/agents/vad_agent.py @@ -0,0 +1,142 @@ +import logging + +import numpy as np +import torch +import zmq +from spade.agent import Agent +from spade.behaviour import CyclicBehaviour + +from control_backend.core.config import settings +from control_backend.core.zmq_context import context as zmq_context + +logger = logging.getLogger(__name__) + + +class SocketPoller[T]: + def __init__(self, socket: zmq.Socket[T]): + self.socket = socket + self.poller = zmq.Poller() + self.poller.register(self.socket, zmq.POLLIN) + + async def poll(self, timeout_ms: int) -> T | None: + """ + Get data from the socket, or None if the timeout is reached. + + :param timeout_ms: The number of milliseconds to wait for the socket. + :return: Data from the socket or None. + """ + socks = dict(self.poller.poll(timeout_ms)) + if socks.get(self.socket) == zmq.POLLIN: + return await self.socket.recv() + return None + + +class VADAgent(Agent): + """ + An agent which listens to an audio stream, does Voice Activity Detection (VAD), and sends + fragments with detected speech to other agents over ZeroMQ. + """ + def __init__(self, audio_in_address: str, audio_in_bind: bool): + jid = settings.agent_settings.vad_agent_name + '@' + settings.agent_settings.host + super().__init__(jid, settings.agent_settings.vad_agent_name) + + self.audio_in_address = audio_in_address + self.audio_in_bind = audio_in_bind + + self.audio_in_socket: zmq.Socket | None = None + self.audio_out_socket: zmq.Socket | None = None + + class Stream(CyclicBehaviour): + def __init__(self, audio_in_socket: zmq.Socket, audio_out_socket: zmq.Socket): + super().__init__() + self.audio_in_poller = SocketPoller[bytes](audio_in_socket) + self.model, _ = torch.hub.load(repo_or_dir="snakers4/silero-vad", + model="silero_vad", + force_reload=False) + self.audio_out_socket = audio_out_socket + + self.audio_buffer = np.array([], dtype=np.float32) # TODO: Consider using a Tensor + self.i_since_data = 0 # Used to avoid logging every cycle if audio input stops + self.i_since_speech = 0 # Used to allow small pauses in speech + + async def run(self) -> None: + timeout_ms = 100 + data = await self.audio_in_poller.poll(timeout_ms) + if data is None: + if self.i_since_data % 10 == 0: + logger.debug("Failed to receive audio from socket for %d ms.", + timeout_ms*self.i_since_data) + self.i_since_data += 1 + return + self.i_since_data = 0 + + # copy otherwise Torch will be sad that it's immutable + chunk = np.frombuffer(data, dtype=np.float32).copy() + prob = self.model(torch.from_numpy(chunk), 16000).item() + + if prob > 0.5: + if self.i_since_speech > 3: logger.debug("Speech started.") + self.audio_buffer = np.append(self.audio_buffer, chunk) + self.i_since_speech = 0 + return + self.i_since_speech += 1 + + # prob < 0.5, so speech maybe ended. Wait a bit more before to be more certain + if self.i_since_speech <= 3: + self.audio_buffer = np.append(self.audio_buffer, chunk) + return + + # Speech probably ended. Make sure we have a usable amount of data. + if len(self.audio_buffer) >= 3*len(chunk): + logger.debug("Speech ended.") + self.audio_out_socket.send(self.audio_buffer) + + # At this point, we know that the speech has ended. + # Prepend the last chunk that had no speech, for a more fluent boundary + self.audio_buffer = chunk + + async def stop(self): + """ + Stop listening to audio, stop publishing audio, close sockets. + """ + self.audio_in_socket.close() + self.audio_in_socket = None + self.audio_out_socket.close() + self.audio_out_socket = None + return await super().stop() + + def _connect_audio_in_socket(self): + self.audio_in_socket = zmq_context.socket(zmq.SUB) + self.audio_in_socket.setsockopt_string(zmq.SUBSCRIBE, "") + if self.audio_in_bind: + self.audio_in_socket.bind(self.audio_in_address) + else: + self.audio_in_socket.connect(self.audio_in_address) + self.audio_in_poller = SocketPoller[bytes](self.audio_in_socket) + + def _connect_audio_out_socket(self) -> int | None: + """Returns the port bound, or None if binding failed.""" + try: + self.audio_out_socket = zmq_context.socket(zmq.PUB) + return self.audio_out_socket.bind_to_random_port("tcp://*", max_tries=100) + except zmq.ZMQBindError: + logger.error("Failed to bind an audio output socket after 100 tries.") + self.audio_out_socket = None + return None + + async def setup(self): + logger.info("Setting up %s", self.jid) + + self._connect_audio_in_socket() + + audio_out_port = self._connect_audio_out_socket() + if audio_out_port is None: + await self.stop() + return + + stream = self.Stream(self.audio_in_socket, self.audio_out_socket) + self.add_behaviour(stream) + + # ... start agents dependent on the output audio fragments here + + logger.info("Finished setting up %s", self.jid) diff --git a/src/control_backend/core/config.py b/src/control_backend/core/config.py index 07a828d..147c6aa 100644 --- a/src/control_backend/core/config.py +++ b/src/control_backend/core/config.py @@ -2,14 +2,20 @@ from re import L from pydantic import BaseModel from pydantic_settings import BaseSettings, SettingsConfigDict + class ZMQSettings(BaseModel): internal_comm_address: str = "tcp://localhost:5560" + audio_fragment_port: int = 5561 + audio_fragment_address: str = f"tcp://localhost:{audio_fragment_port}" + + class AgentSettings(BaseModel): host: str = "localhost" bdi_core_agent_name: str = "bdi_core" belief_collector_agent_name: str = "belief_collector" - test_agent_name: str = "test_agent" + vad_agent_name: str = "vad_agent" + class Settings(BaseSettings): app_title: str = "PepperPlus" @@ -21,5 +27,6 @@ class Settings(BaseSettings): agent_settings: AgentSettings = AgentSettings() model_config = SettingsConfigDict(env_file=".env") - + + settings = Settings() diff --git a/src/control_backend/main.py b/src/control_backend/main.py index 1f377c4..8b1e9e3 100644 --- a/src/control_backend/main.py +++ b/src/control_backend/main.py @@ -13,6 +13,7 @@ import zmq # Internal imports from control_backend.agents.bdi.bdi_core import BDICoreAgent +from control_backend.agents.vad_agent import VADAgent from control_backend.api.v1.router import api_router from control_backend.core.config import AgentSettings, settings from control_backend.core.zmq_context import context @@ -34,6 +35,9 @@ async def lifespan(app: FastAPI): # Initiate agents bdi_core = BDICoreAgent(settings.agent_settings.bdi_core_agent_name + '@' + settings.agent_settings.host, settings.agent_settings.bdi_core_agent_name, "src/control_backend/agents/bdi/rules.asl") await bdi_core.start() + + _temp_vad_agent = VADAgent("tcp://localhost:5558", False) + await _temp_vad_agent.start() yield