test: add first unit test for VAD agent

Mocks the audio input and the VAD model's speech probabilities, then checks that the agent publishes the audio data on the output socket.

ref: N25B-213
This commit is contained in:
Twirre Meulenbelt
2025-10-23 17:40:47 +02:00
parent 6391af883a
commit ca5e59d029
4 changed files with 102 additions and 57 deletions

View File

@@ -3,6 +3,7 @@ import logging
import numpy as np import numpy as np
import torch import torch
import zmq import zmq
import zmq.asyncio as azmq
from spade.agent import Agent from spade.agent import Agent
from spade.behaviour import CyclicBehaviour from spade.behaviour import CyclicBehaviour
@@ -13,7 +14,7 @@ logger = logging.getLogger(__name__)
class SocketPoller[T]: class SocketPoller[T]:
def __init__(self, socket: zmq.Socket[T]): def __init__(self, socket: azmq.Socket):
self.socket = socket self.socket = socket
self.poller = zmq.Poller() self.poller = zmq.Poller()
self.poller.register(self.socket, zmq.POLLIN) self.poller.register(self.socket, zmq.POLLIN)
@@ -31,6 +32,56 @@ class SocketPoller[T]:
return None return None
class Streaming(CyclicBehaviour):
    """
    Cyclic behaviour that runs Voice Activity Detection (VAD) on incoming audio chunks and
    publishes every detected speech fragment on the output socket.

    Each cycle polls one chunk of float32 samples from the input socket and asks the Silero VAD
    model for its speech probability. Speech chunks, plus up to ``MAX_PAUSE_CHUNKS`` non-speech
    chunks that follow them, are accumulated in ``audio_buffer``. Once the pause grows longer
    than that, the buffer is sent as raw bytes and a new buffer is started.
    """

    # Probability above which a chunk is treated as speech.
    SPEECH_THRESHOLD = 0.5
    # Consecutive non-speech chunks that still count as a pause inside one fragment.
    MAX_PAUSE_CHUNKS = 3
    # Minimum buffered length, expressed in chunks, for a fragment to be published.
    MIN_FRAGMENT_CHUNKS = 3
    # Sample rate handed to the VAD model together with every chunk.
    SAMPLE_RATE = 16000
    # How long a single cycle waits for input before giving up.
    POLL_TIMEOUT_MS = 100
    # Only every N-th consecutive timeout is logged, to keep the log readable.
    LOG_EVERY_N_TIMEOUTS = 10

    def __init__(self, audio_in_socket: azmq.Socket, audio_out_socket: azmq.Socket):
        """
        :param audio_in_socket: Socket from which raw float32 audio chunks are polled.
        :param audio_out_socket: Socket on which finished speech fragments are sent.
        """
        super().__init__()
        self.audio_in_poller = SocketPoller[bytes](audio_in_socket)
        self.model, _ = torch.hub.load(repo_or_dir="snakers4/silero-vad",
                                       model="silero_vad",
                                       force_reload=False)
        self.audio_out_socket = audio_out_socket

        self.audio_buffer = np.array([], dtype=np.float32)  # TODO: Consider using a Tensor
        self.i_since_data = 0  # Used to avoid logging every cycle if audio input stops
        # Used to allow small pauses in speech. Starts above the pause limit, because no speech
        # has been heard yet: starting at 0 would buffer the first silent chunks as if they were
        # a pause in speech and publish them as a fragment of pure silence.
        self.i_since_speech = self.MAX_PAUSE_CHUNKS + 1

    async def run(self) -> None:
        """Process a single audio chunk, publishing the buffered fragment when speech ended."""
        data = await self.audio_in_poller.poll(self.POLL_TIMEOUT_MS)
        if data is None:
            if self.i_since_data % self.LOG_EVERY_N_TIMEOUTS == 0:
                logger.debug("Failed to receive audio from socket for %d ms.",
                             self.POLL_TIMEOUT_MS * self.i_since_data)
            self.i_since_data += 1
            return
        self.i_since_data = 0

        # copy otherwise Torch will be sad that it's immutable
        chunk = np.frombuffer(data, dtype=np.float32).copy()
        prob = self.model(torch.from_numpy(chunk), self.SAMPLE_RATE).item()

        if prob > self.SPEECH_THRESHOLD:
            if self.i_since_speech > self.MAX_PAUSE_CHUNKS:
                logger.debug("Speech started.")
            self.audio_buffer = np.append(self.audio_buffer, chunk)
            self.i_since_speech = 0
            return

        # prob <= threshold, so speech may have ended. Keep buffering for a few more chunks
        # before deciding, so short pauses do not split a fragment.
        self.i_since_speech += 1
        if self.i_since_speech <= self.MAX_PAUSE_CHUNKS:
            self.audio_buffer = np.append(self.audio_buffer, chunk)
            return

        # Speech probably ended. Make sure we have a usable amount of data.
        if len(self.audio_buffer) >= self.MIN_FRAGMENT_CHUNKS * len(chunk):
            logger.debug("Speech ended.")
            await self.audio_out_socket.send(self.audio_buffer.tobytes())

        # At this point, we know that the speech has ended.
        # Start the next buffer with the last non-speech chunk, for a more fluent boundary.
        self.audio_buffer = chunk
class VADAgent(Agent): class VADAgent(Agent):
""" """
An agent which listens to an audio stream, does Voice Activity Detection (VAD), and sends An agent which listens to an audio stream, does Voice Activity Detection (VAD), and sends
@@ -43,57 +94,8 @@ class VADAgent(Agent):
self.audio_in_address = audio_in_address self.audio_in_address = audio_in_address
self.audio_in_bind = audio_in_bind self.audio_in_bind = audio_in_bind
self.audio_in_socket: zmq.Socket | None = None self.audio_in_socket: azmq.Socket | None = None
self.audio_out_socket: zmq.Socket | None = None self.audio_out_socket: azmq.Socket | None = None
class Stream(CyclicBehaviour):
    """
    Cyclic behaviour that runs Voice Activity Detection (VAD) on incoming audio chunks and
    publishes every detected speech fragment on the output socket.

    Each cycle polls one chunk of float32 samples from the input socket and asks the Silero VAD
    model for its speech probability. Speech chunks, plus up to ``MAX_PAUSE_CHUNKS`` non-speech
    chunks that follow them, are accumulated in ``audio_buffer``. Once the pause grows longer
    than that, the buffer is sent and a new buffer is started.
    """

    # Probability above which a chunk is treated as speech.
    SPEECH_THRESHOLD = 0.5
    # Consecutive non-speech chunks that still count as a pause inside one fragment.
    MAX_PAUSE_CHUNKS = 3
    # Minimum buffered length, expressed in chunks, for a fragment to be published.
    MIN_FRAGMENT_CHUNKS = 3
    # Sample rate handed to the VAD model together with every chunk.
    SAMPLE_RATE = 16000
    # How long a single cycle waits for input before giving up.
    POLL_TIMEOUT_MS = 100
    # Only every N-th consecutive timeout is logged, to keep the log readable.
    LOG_EVERY_N_TIMEOUTS = 10

    def __init__(self, audio_in_socket: zmq.Socket, audio_out_socket: zmq.Socket):
        """
        :param audio_in_socket: Socket from which raw float32 audio chunks are polled.
        :param audio_out_socket: Socket on which finished speech fragments are sent.
        """
        super().__init__()
        self.audio_in_poller = SocketPoller[bytes](audio_in_socket)
        self.model, _ = torch.hub.load(repo_or_dir="snakers4/silero-vad",
                                       model="silero_vad",
                                       force_reload=False)
        self.audio_out_socket = audio_out_socket

        self.audio_buffer = np.array([], dtype=np.float32)  # TODO: Consider using a Tensor
        self.i_since_data = 0  # Used to avoid logging every cycle if audio input stops
        # Used to allow small pauses in speech. Starts above the pause limit, because no speech
        # has been heard yet: starting at 0 would buffer the first silent chunks as if they were
        # a pause in speech and publish them as a fragment of pure silence.
        self.i_since_speech = self.MAX_PAUSE_CHUNKS + 1

    async def run(self) -> None:
        """Process a single audio chunk, publishing the buffered fragment when speech ended."""
        data = await self.audio_in_poller.poll(self.POLL_TIMEOUT_MS)
        if data is None:
            if self.i_since_data % self.LOG_EVERY_N_TIMEOUTS == 0:
                logger.debug("Failed to receive audio from socket for %d ms.",
                             self.POLL_TIMEOUT_MS * self.i_since_data)
            self.i_since_data += 1
            return
        self.i_since_data = 0

        # copy otherwise Torch will be sad that it's immutable
        chunk = np.frombuffer(data, dtype=np.float32).copy()
        prob = self.model(torch.from_numpy(chunk), self.SAMPLE_RATE).item()

        if prob > self.SPEECH_THRESHOLD:
            if self.i_since_speech > self.MAX_PAUSE_CHUNKS:
                logger.debug("Speech started.")
            self.audio_buffer = np.append(self.audio_buffer, chunk)
            self.i_since_speech = 0
            return

        # prob <= threshold, so speech may have ended. Keep buffering for a few more chunks
        # before deciding, so short pauses do not split a fragment.
        self.i_since_speech += 1
        if self.i_since_speech <= self.MAX_PAUSE_CHUNKS:
            self.audio_buffer = np.append(self.audio_buffer, chunk)
            return

        # Speech probably ended. Make sure we have a usable amount of data.
        if len(self.audio_buffer) >= self.MIN_FRAGMENT_CHUNKS * len(chunk):
            logger.debug("Speech ended.")
            # NOTE(review): send() is neither awaited nor given bytes here. That is only correct
            # if this socket is a synchronous zmq.Socket that accepts buffer objects; with an
            # asyncio socket the returned awaitable would be dropped -- TODO confirm.
            self.audio_out_socket.send(self.audio_buffer)

        # At this point, we know that the speech has ended.
        # Start the next buffer with the last non-speech chunk, for a more fluent boundary.
        self.audio_buffer = chunk
async def stop(self): async def stop(self):
""" """
@@ -133,9 +135,10 @@ class VADAgent(Agent):
if audio_out_port is None: if audio_out_port is None:
await self.stop() await self.stop()
return return
audio_out_address = f"tcp://localhost:{audio_out_port}"
stream = self.Stream(self.audio_in_socket, self.audio_out_socket) streaming = Streaming(self.audio_in_socket, self.audio_out_socket)
self.add_behaviour(stream) self.add_behaviour(streaming)
# ... start agents dependent on the output audio fragments here # ... start agents dependent on the output audio fragments here

View File

@@ -6,9 +6,6 @@ from pydantic_settings import BaseSettings, SettingsConfigDict
class ZMQSettings(BaseModel): class ZMQSettings(BaseModel):
internal_comm_address: str = "tcp://localhost:5560" internal_comm_address: str = "tcp://localhost:5560"
audio_fragment_port: int = 5561
audio_fragment_address: str = f"tcp://localhost:{audio_fragment_port}"
class AgentSettings(BaseModel): class AgentSettings(BaseModel):
host: str = "localhost" host: str = "localhost"

View File

@@ -0,0 +1,45 @@
from unittest.mock import AsyncMock, MagicMock
import numpy as np
import pytest
from control_backend.agents.vad_agent import Streaming
@pytest.fixture
def audio_in_socket():
    """Provide an ``AsyncMock`` standing in for the audio input socket."""
    mock_socket = AsyncMock()
    return mock_socket
@pytest.fixture
def audio_out_socket():
    """Provide an ``AsyncMock`` standing in for the audio output socket."""
    mock_socket = AsyncMock()
    return mock_socket
@pytest.fixture
def streaming(audio_in_socket, audio_out_socket):
    """
    Build the ``Streaming`` behaviour under test on top of the mocked sockets.

    ``Streaming.__init__`` loads the Silero VAD model through ``torch.hub.load``. That call is
    patched while the behaviour is constructed, so the unit test neither downloads nor loads the
    real model. Tests are expected to assign their own ``model`` mock afterwards.
    """
    from unittest.mock import patch  # local import: only this fixture needs it

    # torch.hub.load is unpacked into (model, utils) by the constructor, hence the 2-tuple.
    with patch("torch.hub.load", return_value=(MagicMock(), None)):
        return Streaming(audio_in_socket, audio_out_socket)
@pytest.mark.asyncio
async def test_voice_activity_detected(audio_in_socket, audio_out_socket, streaming):
    """
    Three speech chunks followed by four silent chunks must publish exactly one fragment.

    The first three silent chunks are still buffered as a possible pause in speech; the fourth
    one exceeds the allowed pause and triggers the send.
    """
    probabilities = [1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0]
    chunk_samples = 512
    bytes_per_sample = np.dtype(np.float32).itemsize

    # The model returns an object whose .item() yields the next scripted probability.
    model_item = MagicMock()
    model_item.item.side_effect = probabilities
    streaming.model = MagicMock()
    streaming.model.return_value = model_item

    # Every poll yields one chunk; its contents are irrelevant because the model is mocked.
    audio_in_poller = AsyncMock()
    audio_in_poller.poll.return_value = np.empty(shape=chunk_samples, dtype=np.float32)
    streaming.audio_in_poller = audio_in_poller

    for _ in probabilities:
        await streaming.run()

    audio_out_socket.send.assert_called_once()
    data = audio_out_socket.send.call_args[0][0]
    assert isinstance(data, bytes)

    # Expecting 6 chunks: 3 with speech plus the 3 silent chunks that are buffered as padding
    # before the fourth silent chunk ends the fragment.
    expected_chunks = 6
    assert len(data) == chunk_samples * bytes_per_sample * expected_chunks