From ca5e59d0290df875e1035bd770366257f422ec83 Mon Sep 17 00:00:00 2001 From: Twirre Meulenbelt <43213592+TwirreM@users.noreply.github.com> Date: Thu, 23 Oct 2025 17:40:47 +0200 Subject: [PATCH] test: add first unit test for VAD agent Mocking audio input probabilities, checking whether it publishes audio data on the output socket. ref: N25B-213 --- src/control_backend/agents/vad_agent.py | 111 ++++++++++++------------ src/control_backend/core/config.py | 3 - test/unit/agents/test_vad_streaming.py | 45 ++++++++++ test/{ => unit}/conftest.py | 0 4 files changed, 102 insertions(+), 57 deletions(-) create mode 100644 test/unit/agents/test_vad_streaming.py rename test/{ => unit}/conftest.py (100%) diff --git a/src/control_backend/agents/vad_agent.py b/src/control_backend/agents/vad_agent.py index 10e1d1e..f0325c2 100644 --- a/src/control_backend/agents/vad_agent.py +++ b/src/control_backend/agents/vad_agent.py @@ -3,6 +3,7 @@ import logging import numpy as np import torch import zmq +import zmq.asyncio as azmq from spade.agent import Agent from spade.behaviour import CyclicBehaviour @@ -13,7 +14,7 @@ logger = logging.getLogger(__name__) class SocketPoller[T]: - def __init__(self, socket: zmq.Socket[T]): + def __init__(self, socket: azmq.Socket): self.socket = socket self.poller = zmq.Poller() self.poller.register(self.socket, zmq.POLLIN) @@ -31,6 +32,56 @@ class SocketPoller[T]: return None +class Streaming(CyclicBehaviour): + def __init__(self, audio_in_socket: azmq.Socket, audio_out_socket: azmq.Socket): + super().__init__() + self.audio_in_poller = SocketPoller[bytes](audio_in_socket) + self.model, _ = torch.hub.load(repo_or_dir="snakers4/silero-vad", + model="silero_vad", + force_reload=False) + self.audio_out_socket = audio_out_socket + + self.audio_buffer = np.array([], dtype=np.float32) # TODO: Consider using a Tensor + self.i_since_data = 0 # Used to avoid logging every cycle if audio input stops + self.i_since_speech = 0 # Used to allow small pauses in speech + + async def run(self) -> None: + timeout_ms = 100 + data = await self.audio_in_poller.poll(timeout_ms) + if data is None: + if self.i_since_data % 10 == 0: + logger.debug("Failed to receive audio from socket for %d ms.", + timeout_ms*self.i_since_data) + self.i_since_data += 1 + return + self.i_since_data = 0 + + # copy otherwise Torch will be sad that it's immutable + chunk = np.frombuffer(data, dtype=np.float32).copy() + prob = self.model(torch.from_numpy(chunk), 16000).item() + + if prob > 0.5: + if self.i_since_speech > 3: logger.debug("Speech started.") + self.audio_buffer = np.append(self.audio_buffer, chunk) + self.i_since_speech = 0 + return + self.i_since_speech += 1 + + # prob < 0.5, so speech maybe ended. Wait a bit more before to be more certain + if self.i_since_speech <= 3: + self.audio_buffer = np.append(self.audio_buffer, chunk) + return + + # Speech probably ended. Make sure we have a usable amount of data. + if len(self.audio_buffer) >= 3*len(chunk): + logger.debug("Speech ended.") + await self.audio_out_socket.send(self.audio_buffer.tobytes()) + + # At this point, we know that the speech has ended. + # Prepend the last chunk that had no speech, for a more fluent boundary + self.audio_buffer = chunk + + class VADAgent(Agent): """ An agent which listens to an audio stream, does Voice Activity Detection (VAD), and sends @@ -43,57 +94,8 @@ class VADAgent(Agent): self.audio_in_address = audio_in_address self.audio_in_bind = audio_in_bind - self.audio_in_socket: zmq.Socket | None = None - self.audio_out_socket: zmq.Socket | None = None - - class Stream(CyclicBehaviour): - def __init__(self, audio_in_socket: zmq.Socket, audio_out_socket: zmq.Socket): - super().__init__() - self.audio_in_poller = SocketPoller[bytes](audio_in_socket) - self.model, _ = torch.hub.load(repo_or_dir="snakers4/silero-vad", - model="silero_vad", - force_reload=False) - self.audio_out_socket = audio_out_socket - - self.audio_buffer = np.array([], dtype=np.float32) # TODO: Consider using a Tensor - self.i_since_data = 0 # Used to avoid logging every cycle if audio input stops - self.i_since_speech = 0 # Used to allow small pauses in speech - - async def run(self) -> None: - timeout_ms = 100 - data = await self.audio_in_poller.poll(timeout_ms) - if data is None: - if self.i_since_data % 10 == 0: - logger.debug("Failed to receive audio from socket for %d ms.", - timeout_ms*self.i_since_data) - self.i_since_data += 1 - return - self.i_since_data = 0 - - # copy otherwise Torch will be sad that it's immutable - chunk = np.frombuffer(data, dtype=np.float32).copy() - prob = self.model(torch.from_numpy(chunk), 16000).item() - - if prob > 0.5: - if self.i_since_speech > 3: logger.debug("Speech started.") - self.audio_buffer = np.append(self.audio_buffer, chunk) - self.i_since_speech = 0 - return - self.i_since_speech += 1 - - # prob < 0.5, so speech maybe ended. Wait a bit more before to be more certain - if self.i_since_speech <= 3: - self.audio_buffer = np.append(self.audio_buffer, chunk) - return - - # Speech probably ended. Make sure we have a usable amount of data. - if len(self.audio_buffer) >= 3*len(chunk): - logger.debug("Speech ended.") - self.audio_out_socket.send(self.audio_buffer) - - # At this point, we know that the speech has ended. - # Prepend the last chunk that had no speech, for a more fluent boundary - self.audio_buffer = chunk + self.audio_in_socket: azmq.Socket | None = None + self.audio_out_socket: azmq.Socket | None = None async def stop(self): """ @@ -133,9 +135,10 @@ class VADAgent(Agent): if audio_out_port is None: await self.stop() return + audio_out_address = f"tcp://localhost:{audio_out_port}" - stream = self.Stream(self.audio_in_socket, self.audio_out_socket) - self.add_behaviour(stream) + streaming = Streaming(self.audio_in_socket, self.audio_out_socket) + self.add_behaviour(streaming) # ... start agents dependent on the output audio fragments here diff --git a/src/control_backend/core/config.py b/src/control_backend/core/config.py index 147c6aa..093a64e 100644 --- a/src/control_backend/core/config.py +++ b/src/control_backend/core/config.py @@ -6,9 +6,6 @@ from pydantic_settings import BaseSettings, SettingsConfigDict class ZMQSettings(BaseModel): internal_comm_address: str = "tcp://localhost:5560" - audio_fragment_port: int = 5561 - audio_fragment_address: str = f"tcp://localhost:{audio_fragment_port}" - class AgentSettings(BaseModel): host: str = "localhost" diff --git a/test/unit/agents/test_vad_streaming.py b/test/unit/agents/test_vad_streaming.py new file mode 100644 index 0000000..c48626d --- /dev/null +++ b/test/unit/agents/test_vad_streaming.py @@ -0,0 +1,45 @@ +from unittest.mock import AsyncMock, MagicMock + +import numpy as np +import pytest + +from control_backend.agents.vad_agent import Streaming + + +@pytest.fixture +def audio_in_socket(): + return AsyncMock() + + +@pytest.fixture +def audio_out_socket(): + return AsyncMock() + + +@pytest.fixture +def streaming(audio_in_socket, audio_out_socket): + return Streaming(audio_in_socket, audio_out_socket) + + +@pytest.mark.asyncio +async def test_voice_activity_detected(audio_in_socket, audio_out_socket, streaming): + # After three chunks of audio with speech probability of 1.0, then four chunks of audio with + # speech probability of 0.0, it should send a message over the audio out socket + probabilities = [1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0] + model_item = MagicMock() + model_item.item.side_effect = probabilities + streaming.model = MagicMock() + streaming.model.return_value = model_item + + audio_in_poller = AsyncMock() + audio_in_poller.poll.return_value = np.empty(shape=512, dtype=np.float32) + streaming.audio_in_poller = audio_in_poller + + for _ in probabilities: + await streaming.run() + + audio_out_socket.send.assert_called_once() + data = audio_out_socket.send.call_args[0][0] + assert isinstance(data, bytes) + # each sample has 512 frames of 4 bytes, expecting 5 chunks (3 with speech, 2 as padding) + assert len(data) == 512*4*5 diff --git a/test/conftest.py b/test/unit/conftest.py similarity index 100% rename from test/conftest.py rename to test/unit/conftest.py