test: add first unit test for VAD agent

Mocking audio input probabilities, checking whether it publishes audio data on the output socket.

ref: N25B-213
This commit is contained in:
Twirre Meulenbelt
2025-10-23 17:40:47 +02:00
parent 6391af883a
commit ca5e59d029
4 changed files with 102 additions and 57 deletions

View File

@@ -3,6 +3,7 @@ import logging
import numpy as np
import torch
import zmq
import zmq.asyncio as azmq
from spade.agent import Agent
from spade.behaviour import CyclicBehaviour
@@ -13,7 +14,7 @@ logger = logging.getLogger(__name__)
class SocketPoller[T]:
def __init__(self, socket: zmq.Socket[T]):
def __init__(self, socket: azmq.Socket):
self.socket = socket
self.poller = zmq.Poller()
self.poller.register(self.socket, zmq.POLLIN)
@@ -31,6 +32,56 @@ class SocketPoller[T]:
return None
class Streaming(CyclicBehaviour):
def __init__(self, audio_in_socket: azmq.Socket, audio_out_socket: azmq.Socket):
super().__init__()
self.audio_in_poller = SocketPoller[bytes](audio_in_socket)
self.model, _ = torch.hub.load(repo_or_dir="snakers4/silero-vad",
model="silero_vad",
force_reload=False)
self.audio_out_socket = audio_out_socket
self.audio_buffer = np.array([], dtype=np.float32) # TODO: Consider using a Tensor
self.i_since_data = 0 # Used to avoid logging every cycle if audio input stops
self.i_since_speech = 0 # Used to allow small pauses in speech
async def run(self) -> None:
timeout_ms = 100
data = await self.audio_in_poller.poll(timeout_ms)
if data is None:
if self.i_since_data % 10 == 0:
logger.debug("Failed to receive audio from socket for %d ms.",
timeout_ms*self.i_since_data)
self.i_since_data += 1
return
self.i_since_data = 0
# copy otherwise Torch will be sad that it's immutable
chunk = np.frombuffer(data, dtype=np.float32).copy()
prob = self.model(torch.from_numpy(chunk), 16000).item()
if prob > 0.5:
if self.i_since_speech > 3: logger.debug("Speech started.")
self.audio_buffer = np.append(self.audio_buffer, chunk)
self.i_since_speech = 0
return
self.i_since_speech += 1
# prob < 0.5, so speech maybe ended. Wait a bit more before to be more certain
if self.i_since_speech <= 3:
self.audio_buffer = np.append(self.audio_buffer, chunk)
return
# Speech probably ended. Make sure we have a usable amount of data.
if len(self.audio_buffer) >= 3*len(chunk):
logger.debug("Speech ended.")
await self.audio_out_socket.send(self.audio_buffer.tobytes())
# At this point, we know that the speech has ended.
# Prepend the last chunk that had no speech, for a more fluent boundary
self.audio_buffer = chunk
class VADAgent(Agent):
"""
An agent which listens to an audio stream, does Voice Activity Detection (VAD), and sends
@@ -43,57 +94,8 @@ class VADAgent(Agent):
self.audio_in_address = audio_in_address
self.audio_in_bind = audio_in_bind
self.audio_in_socket: zmq.Socket | None = None
self.audio_out_socket: zmq.Socket | None = None
class Stream(CyclicBehaviour):
def __init__(self, audio_in_socket: zmq.Socket, audio_out_socket: zmq.Socket):
super().__init__()
self.audio_in_poller = SocketPoller[bytes](audio_in_socket)
self.model, _ = torch.hub.load(repo_or_dir="snakers4/silero-vad",
model="silero_vad",
force_reload=False)
self.audio_out_socket = audio_out_socket
self.audio_buffer = np.array([], dtype=np.float32) # TODO: Consider using a Tensor
self.i_since_data = 0 # Used to avoid logging every cycle if audio input stops
self.i_since_speech = 0 # Used to allow small pauses in speech
async def run(self) -> None:
timeout_ms = 100
data = await self.audio_in_poller.poll(timeout_ms)
if data is None:
if self.i_since_data % 10 == 0:
logger.debug("Failed to receive audio from socket for %d ms.",
timeout_ms*self.i_since_data)
self.i_since_data += 1
return
self.i_since_data = 0
# copy otherwise Torch will be sad that it's immutable
chunk = np.frombuffer(data, dtype=np.float32).copy()
prob = self.model(torch.from_numpy(chunk), 16000).item()
if prob > 0.5:
if self.i_since_speech > 3: logger.debug("Speech started.")
self.audio_buffer = np.append(self.audio_buffer, chunk)
self.i_since_speech = 0
return
self.i_since_speech += 1
# prob < 0.5, so speech maybe ended. Wait a bit more before to be more certain
if self.i_since_speech <= 3:
self.audio_buffer = np.append(self.audio_buffer, chunk)
return
# Speech probably ended. Make sure we have a usable amount of data.
if len(self.audio_buffer) >= 3*len(chunk):
logger.debug("Speech ended.")
self.audio_out_socket.send(self.audio_buffer)
# At this point, we know that the speech has ended.
# Prepend the last chunk that had no speech, for a more fluent boundary
self.audio_buffer = chunk
self.audio_in_socket: azmq.Socket | None = None
self.audio_out_socket: azmq.Socket | None = None
async def stop(self):
"""
@@ -133,9 +135,10 @@ class VADAgent(Agent):
if audio_out_port is None:
await self.stop()
return
audio_out_address = f"tcp://localhost:{audio_out_port}"
stream = self.Stream(self.audio_in_socket, self.audio_out_socket)
self.add_behaviour(stream)
streaming = Streaming(self.audio_in_socket, self.audio_out_socket)
self.add_behaviour(streaming)
# ... start agents dependent on the output audio fragments here

View File

@@ -6,9 +6,6 @@ from pydantic_settings import BaseSettings, SettingsConfigDict
class ZMQSettings(BaseModel):
internal_comm_address: str = "tcp://localhost:5560"
audio_fragment_port: int = 5561
audio_fragment_address: str = f"tcp://localhost:{audio_fragment_port}"
class AgentSettings(BaseModel):
host: str = "localhost"