test: make VAD tests work again

ref: N25B-301
This commit is contained in:
Twirre Meulenbelt
2025-11-20 16:33:12 +01:00
parent 610c4b526d
commit 0493d390e3
5 changed files with 137 additions and 72 deletions

View File

@@ -1,9 +1,8 @@
import random
from unittest.mock import AsyncMock, MagicMock, patch
from unittest.mock import AsyncMock, MagicMock
import pytest
import zmq
from spade.agent import Agent
from control_backend.agents.perception.vad_agent import VADAgent
@@ -15,11 +14,6 @@ def zmq_context(mocker):
return mock_context
@pytest.fixture
def streaming(mocker):
return mocker.patch("control_backend.agents.perception.vad_agent.StreamingBehaviour")
@pytest.fixture
def per_transcription_agent(mocker):
return mocker.patch(
@@ -27,21 +21,36 @@ def per_transcription_agent(mocker):
)
@pytest.fixture(autouse=True)
def torch_load(mocker):
mock_torch = mocker.patch("control_backend.agents.perception.vad_agent.torch")
model = MagicMock()
mock_torch.hub.load.return_value = (model, None)
mock_torch.from_numpy.side_effect = lambda arr: arr
return mock_torch
@pytest.mark.asyncio
async def test_normal_setup(streaming, per_transcription_agent):
async def test_normal_setup(per_transcription_agent):
"""
Test that during normal setup, the VAD agent creates a Streaming behavior and creates audio
sockets, and starts the TranscriptionAgent without loading real models.
"""
per_vad_agent = VADAgent("tcp://localhost:12345", False)
per_vad_agent.add_behaviour = MagicMock()
per_vad_agent._streaming_loop = AsyncMock()
async def swallow_background_task(coro):
coro.close()
per_vad_agent.add_background_task = swallow_background_task
per_vad_agent.reset_stream = AsyncMock()
await per_vad_agent.setup()
streaming.assert_called_once()
per_vad_agent.add_behaviour.assert_called_once_with(streaming.return_value)
per_transcription_agent.assert_called_once()
per_transcription_agent.return_value.start.assert_called_once()
per_vad_agent._streaming_loop.assert_called_once()
per_vad_agent.reset_stream.assert_called_once()
assert per_vad_agent.audio_in_socket is not None
assert per_vad_agent.audio_out_socket is not None
@@ -91,16 +100,22 @@ async def test_out_socket_creation_failure(zmq_context):
"""
Test setup failure when the audio output socket cannot be created.
"""
with patch.object(Agent, "stop", new_callable=AsyncMock) as mock_super_stop:
zmq_context.return_value.socket.return_value.bind_to_random_port.side_effect = (
zmq.ZMQBindError
)
per_vad_agent = VADAgent("tcp://localhost:12345", False)
zmq_context.return_value.socket.return_value.bind_to_random_port.side_effect = zmq.ZMQBindError
per_vad_agent = VADAgent("tcp://localhost:12345", False)
per_vad_agent.stop = AsyncMock()
per_vad_agent.reset_stream = AsyncMock()
per_vad_agent._streaming_loop = AsyncMock()
per_vad_agent._connect_audio_out_socket = MagicMock(return_value=None)
await per_vad_agent.setup()
async def swallow_background_task(coro):
coro.close()
assert per_vad_agent.audio_out_socket is None
mock_super_stop.assert_called_once()
per_vad_agent.add_background_task = swallow_background_task
await per_vad_agent.setup()
assert per_vad_agent.audio_out_socket is None
per_vad_agent.stop.assert_called_once()
@pytest.mark.asyncio
@@ -109,6 +124,13 @@ async def test_stop(zmq_context, per_transcription_agent):
Test that when the VAD agent is stopped, the sockets are closed correctly.
"""
per_vad_agent = VADAgent("tcp://localhost:12345", False)
per_vad_agent.reset_stream = AsyncMock()
per_vad_agent._streaming_loop = AsyncMock()
async def swallow_background_task(coro):
coro.close()
per_vad_agent.add_background_task = swallow_background_task
zmq_context.return_value.socket.return_value.bind_to_random_port.return_value = random.randint(
1000,
10000,

View File

@@ -5,7 +5,24 @@ import pytest
import soundfile as sf
import zmq
from control_backend.agents.perception.vad_agent import StreamingBehaviour
from control_backend.agents.perception.vad_agent import VADAgent
@pytest.fixture(autouse=True)
def patch_settings():
from control_backend.agents.perception import vad_agent
vad_agent.settings.behaviour_settings.vad_prob_threshold = 0.5
vad_agent.settings.behaviour_settings.vad_non_speech_patience_chunks = 3
vad_agent.settings.behaviour_settings.vad_initial_since_speech = 0
vad_agent.settings.vad_settings.sample_rate_hz = 16_000
@pytest.fixture(autouse=True)
def mock_torch(mocker):
mock_torch = mocker.patch("control_backend.agents.perception.vad_agent.torch")
mock_torch.from_numpy.side_effect = lambda arr: arr
return mock_torch
def get_audio_chunks() -> list[bytes]:
@@ -42,16 +59,39 @@ async def test_real_audio(mocker):
audio_in_socket = AsyncMock()
audio_in_socket.recv.side_effect = audio_chunks
mock_poller: MagicMock = mocker.patch("control_backend.agents.perception.vad_agent.zmq.Poller")
mock_poller.return_value.poll.return_value = [(audio_in_socket, zmq.POLLIN)]
mock_poller: MagicMock = mocker.patch("control_backend.agents.perception.vad_agent.azmq.Poller")
mock_poller.return_value.poll = AsyncMock(return_value=[(audio_in_socket, zmq.POLLIN)])
audio_out_socket = AsyncMock()
vad_streamer = StreamingBehaviour(audio_in_socket, audio_out_socket)
vad_streamer._ready = True
vad_streamer.agent = MagicMock()
for _ in audio_chunks:
await vad_streamer.run()
vad_agent = VADAgent("tcp://localhost:12345", False)
vad_agent.audio_out_socket = audio_out_socket
# Use a fake model that marks most chunks as speech and ends with a few silences
silence_padding = 5
probabilities = [1.0] * len(audio_chunks) + [0.0] * silence_padding
chunk_bytes = audio_chunks + [b"\x00" * len(audio_chunks[0])] * silence_padding
model_item = MagicMock()
model_item.item.side_effect = probabilities
vad_agent.model = MagicMock(return_value=model_item)
class DummyPoller:
def __init__(self, data, agent):
self.data = data
self.agent = agent
async def poll(self, timeout_ms=None):
if self.data:
return self.data.pop(0)
self.agent._running = False
return None
vad_agent.audio_in_poller = DummyPoller(chunk_bytes, vad_agent)
vad_agent._ready = True
vad_agent._running = True
vad_agent.i_since_speech = 0
await vad_agent._streaming_loop()
audio_out_socket.send.assert_called()
for args in audio_out_socket.send.call_args_list:

View File

@@ -2,11 +2,10 @@ import json
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from spade.message import Message
from control_backend.agents.bdi.text_belief_extractor_agent.behaviours.text_belief_extractor_behaviour import ( # noqa: E501, We can't shorten this import.
TextBeliefExtractorBehaviour,
)
from spade.message import Message
@pytest.fixture

View File

@@ -16,8 +16,8 @@ async def test_socket_poller_with_data(socket, mocker):
socket_data = b"test"
socket.recv.return_value = socket_data
mock_poller: MagicMock = mocker.patch("control_backend.agents.perception.vad_agent.zmq.Poller")
mock_poller.return_value.poll.return_value = [(socket, zmq.POLLIN)]
mock_poller: MagicMock = mocker.patch("control_backend.agents.perception.vad_agent.azmq.Poller")
mock_poller.return_value.poll = AsyncMock(return_value=[(socket, zmq.POLLIN)])
poller = SocketPoller(socket)
# Calling `poll` twice to be able to check that the poller is reused
@@ -35,8 +35,8 @@ async def test_socket_poller_with_data(socket, mocker):
@pytest.mark.asyncio
async def test_socket_poller_no_data(socket, mocker):
mock_poller: MagicMock = mocker.patch("control_backend.agents.perception.vad_agent.zmq.Poller")
mock_poller.return_value.poll.return_value = []
mock_poller: MagicMock = mocker.patch("control_backend.agents.perception.vad_agent.azmq.Poller")
mock_poller.return_value.poll = AsyncMock(return_value=[])
poller = SocketPoller(socket)
data = await poller.poll()

View File

@@ -3,12 +3,7 @@ from unittest.mock import AsyncMock, MagicMock
import numpy as np
import pytest
from control_backend.agents.perception.vad_agent import StreamingBehaviour
@pytest.fixture
def audio_in_socket():
return AsyncMock()
from control_backend.agents.perception.vad_agent import VADAgent
@pytest.fixture
@@ -17,22 +12,8 @@ def audio_out_socket():
@pytest.fixture
def mock_agent(mocker):
"""Fixture to create a mock BDIAgent."""
agent = MagicMock()
agent.jid = "vad_agent@test"
return agent
@pytest.fixture
def streaming(audio_in_socket, audio_out_socket, mock_agent):
import torch
torch.hub.load.return_value = (..., ...) # Mock
streaming = StreamingBehaviour(audio_in_socket, audio_out_socket)
streaming._ready = True
streaming.agent = mock_agent
return streaming
def vad_agent(audio_out_socket):
return VADAgent("tcp://localhost:5555", False)
@pytest.fixture(autouse=True)
@@ -61,25 +42,40 @@ async def simulate_streaming_with_probabilities(streaming, probabilities: list[f
"""
model_item = MagicMock()
model_item.item.side_effect = probabilities
streaming.model = MagicMock()
streaming.model.return_value = model_item
streaming.model = MagicMock(return_value=model_item)
audio_in_poller = AsyncMock()
audio_in_poller.poll.return_value = np.empty(shape=512, dtype=np.float32)
streaming.audio_in_poller = audio_in_poller
# Prepare deterministic audio chunks and a poller that stops the loop when exhausted
chunk_bytes = np.empty(shape=512, dtype=np.float32).tobytes()
chunks = [chunk_bytes for _ in probabilities]
for _ in probabilities:
await streaming.run()
class DummyPoller:
def __init__(self, data, agent):
self.data = data
self.agent = agent
async def poll(self, timeout_ms=None):
if self.data:
return self.data.pop(0)
# Stop the loop cleanly once we've consumed all chunks
self.agent._running = False
return None
streaming.audio_in_poller = DummyPoller(chunks, streaming)
streaming._ready = True
streaming._running = True
await streaming._streaming_loop()
@pytest.mark.asyncio
async def test_voice_activity_detected(audio_in_socket, audio_out_socket, streaming):
async def test_voice_activity_detected(audio_out_socket, vad_agent):
"""
Test a scenario where there is voice activity detected between silences.
"""
speech_chunk_count = 5
probabilities = [0.0] * 5 + [1.0] * speech_chunk_count + [0.0] * 5
await simulate_streaming_with_probabilities(streaming, probabilities)
vad_agent.audio_out_socket = audio_out_socket
await simulate_streaming_with_probabilities(vad_agent, probabilities)
audio_out_socket.send.assert_called_once()
data = audio_out_socket.send.call_args[0][0]
@@ -88,7 +84,7 @@ async def test_voice_activity_detected(audio_in_socket, audio_out_socket, stream
@pytest.mark.asyncio
async def test_voice_activity_short_pause(audio_in_socket, audio_out_socket, streaming):
async def test_voice_activity_short_pause(audio_out_socket, vad_agent):
"""
Test a scenario where there is a short pause between speech, checking whether it ignores the
short pause.
@@ -97,7 +93,8 @@ async def test_voice_activity_short_pause(audio_in_socket, audio_out_socket, str
probabilities = (
[0.0] * 5 + [1.0] * speech_chunk_count + [0.0] + [1.0] * speech_chunk_count + [0.0] * 5
)
await simulate_streaming_with_probabilities(streaming, probabilities)
vad_agent.audio_out_socket = audio_out_socket
await simulate_streaming_with_probabilities(vad_agent, probabilities)
audio_out_socket.send.assert_called_once()
data = audio_out_socket.send.call_args[0][0]
@@ -107,15 +104,22 @@ async def test_voice_activity_short_pause(audio_in_socket, audio_out_socket, str
@pytest.mark.asyncio
async def test_no_data(audio_in_socket, audio_out_socket, streaming):
async def test_no_data(audio_out_socket, vad_agent):
"""
Test a scenario where there is no data received. This should not cause errors.
"""
audio_in_poller = AsyncMock()
audio_in_poller.poll.return_value = None
streaming.audio_in_poller = audio_in_poller
await streaming.run()
class DummyPoller:
async def poll(self, timeout_ms=None):
vad_agent._running = False
return None
vad_agent.audio_out_socket = audio_out_socket
vad_agent.audio_in_poller = DummyPoller()
vad_agent._ready = True
vad_agent._running = True
await vad_agent._streaming_loop()
audio_out_socket.send.assert_not_called()
assert len(streaming.audio_buffer) == 0
assert len(vad_agent.audio_buffer) == 0