From 0493d390e37d561fa548c8f38453749ec64bc506 Mon Sep 17 00:00:00 2001 From: Twirre Meulenbelt <43213592+TwirreM@users.noreply.github.com> Date: Thu, 20 Nov 2025 16:33:12 +0100 Subject: [PATCH] test: make VAD tests work again ref: N25B-301 --- .../perception/vad_agent/test_vad_agent.py | 60 +++++++++----- .../vad_agent/test_vad_with_audio.py | 56 +++++++++++-- .../behaviours/test_belief_from_text.py | 3 +- .../vad_agent/test_vad_socket_poller.py | 8 +- .../vad_agent/test_vad_streaming.py | 82 ++++++++++--------- 5 files changed, 137 insertions(+), 72 deletions(-) diff --git a/test/integration/agents/perception/vad_agent/test_vad_agent.py b/test/integration/agents/perception/vad_agent/test_vad_agent.py index ecf9634..20a388c 100644 --- a/test/integration/agents/perception/vad_agent/test_vad_agent.py +++ b/test/integration/agents/perception/vad_agent/test_vad_agent.py @@ -1,9 +1,8 @@ import random -from unittest.mock import AsyncMock, MagicMock, patch +from unittest.mock import AsyncMock, MagicMock import pytest import zmq -from spade.agent import Agent from control_backend.agents.perception.vad_agent import VADAgent @@ -15,11 +14,6 @@ def zmq_context(mocker): return mock_context -@pytest.fixture -def streaming(mocker): - return mocker.patch("control_backend.agents.perception.vad_agent.StreamingBehaviour") - - @pytest.fixture def per_transcription_agent(mocker): return mocker.patch( @@ -27,21 +21,36 @@ def per_transcription_agent(mocker): ) +@pytest.fixture(autouse=True) +def torch_load(mocker): + mock_torch = mocker.patch("control_backend.agents.perception.vad_agent.torch") + model = MagicMock() + mock_torch.hub.load.return_value = (model, None) + mock_torch.from_numpy.side_effect = lambda arr: arr + return mock_torch + + @pytest.mark.asyncio -async def test_normal_setup(streaming, per_transcription_agent): +async def test_normal_setup(per_transcription_agent): """ Test that during normal setup, the VAD agent creates a Streaming behavior and creates audio sockets, and starts the TranscriptionAgent without loading real models. """ per_vad_agent = VADAgent("tcp://localhost:12345", False) - per_vad_agent.add_behaviour = MagicMock() + per_vad_agent._streaming_loop = AsyncMock() + + async def swallow_background_task(coro): + coro.close() + + per_vad_agent.add_background_task = swallow_background_task + per_vad_agent.reset_stream = AsyncMock() await per_vad_agent.setup() - streaming.assert_called_once() - per_vad_agent.add_behaviour.assert_called_once_with(streaming.return_value) per_transcription_agent.assert_called_once() per_transcription_agent.return_value.start.assert_called_once() + per_vad_agent._streaming_loop.assert_called_once() + per_vad_agent.reset_stream.assert_called_once() assert per_vad_agent.audio_in_socket is not None assert per_vad_agent.audio_out_socket is not None @@ -91,16 +100,22 @@ async def test_out_socket_creation_failure(zmq_context): """ Test setup failure when the audio output socket cannot be created. """ - with patch.object(Agent, "stop", new_callable=AsyncMock) as mock_super_stop: - zmq_context.return_value.socket.return_value.bind_to_random_port.side_effect = ( - zmq.ZMQBindError - ) - per_vad_agent = VADAgent("tcp://localhost:12345", False) + zmq_context.return_value.socket.return_value.bind_to_random_port.side_effect = zmq.ZMQBindError + per_vad_agent = VADAgent("tcp://localhost:12345", False) + per_vad_agent.stop = AsyncMock() + per_vad_agent.reset_stream = AsyncMock() + per_vad_agent._streaming_loop = AsyncMock() + per_vad_agent._connect_audio_out_socket = MagicMock(return_value=None) - await per_vad_agent.setup() + async def swallow_background_task(coro): + coro.close() - assert per_vad_agent.audio_out_socket is None - mock_super_stop.assert_called_once() + per_vad_agent.add_background_task = swallow_background_task + + await per_vad_agent.setup() + + assert per_vad_agent.audio_out_socket is None + per_vad_agent.stop.assert_called_once() @pytest.mark.asyncio @@ -109,6 +124,13 @@ async def test_stop(zmq_context, per_transcription_agent): Test that when the VAD agent is stopped, the sockets are closed correctly. """ per_vad_agent = VADAgent("tcp://localhost:12345", False) + per_vad_agent.reset_stream = AsyncMock() + per_vad_agent._streaming_loop = AsyncMock() + + async def swallow_background_task(coro): + coro.close() + + per_vad_agent.add_background_task = swallow_background_task zmq_context.return_value.socket.return_value.bind_to_random_port.return_value = random.randint( 1000, 10000, diff --git a/test/integration/agents/perception/vad_agent/test_vad_with_audio.py b/test/integration/agents/perception/vad_agent/test_vad_with_audio.py index b197c31..ab10b5f 100644 --- a/test/integration/agents/perception/vad_agent/test_vad_with_audio.py +++ b/test/integration/agents/perception/vad_agent/test_vad_with_audio.py @@ -5,7 +5,24 @@ import pytest import soundfile as sf import zmq -from control_backend.agents.perception.vad_agent import StreamingBehaviour +from control_backend.agents.perception.vad_agent import VADAgent + + +@pytest.fixture(autouse=True) +def patch_settings(): + from control_backend.agents.perception import vad_agent + + vad_agent.settings.behaviour_settings.vad_prob_threshold = 0.5 + vad_agent.settings.behaviour_settings.vad_non_speech_patience_chunks = 3 + vad_agent.settings.behaviour_settings.vad_initial_since_speech = 0 + vad_agent.settings.vad_settings.sample_rate_hz = 16_000 + + +@pytest.fixture(autouse=True) +def mock_torch(mocker): + mock_torch = mocker.patch("control_backend.agents.perception.vad_agent.torch") + mock_torch.from_numpy.side_effect = lambda arr: arr + return mock_torch def get_audio_chunks() -> list[bytes]: @@ -42,16 +59,39 @@ async def test_real_audio(mocker): audio_in_socket = AsyncMock() audio_in_socket.recv.side_effect = audio_chunks - mock_poller: MagicMock = mocker.patch("control_backend.agents.perception.vad_agent.zmq.Poller") - mock_poller.return_value.poll.return_value = [(audio_in_socket, zmq.POLLIN)] + mock_poller: MagicMock = mocker.patch("control_backend.agents.perception.vad_agent.azmq.Poller") + mock_poller.return_value.poll = AsyncMock(return_value=[(audio_in_socket, zmq.POLLIN)]) audio_out_socket = AsyncMock() - vad_streamer = StreamingBehaviour(audio_in_socket, audio_out_socket) - vad_streamer._ready = True - vad_streamer.agent = MagicMock() - for _ in audio_chunks: - await vad_streamer.run() + vad_agent = VADAgent("tcp://localhost:12345", False) + vad_agent.audio_out_socket = audio_out_socket + + # Use a fake model that marks most chunks as speech and ends with a few silences + silence_padding = 5 + probabilities = [1.0] * len(audio_chunks) + [0.0] * silence_padding + chunk_bytes = audio_chunks + [b"\x00" * len(audio_chunks[0])] * silence_padding + model_item = MagicMock() + model_item.item.side_effect = probabilities + vad_agent.model = MagicMock(return_value=model_item) + + class DummyPoller: + def __init__(self, data, agent): + self.data = data + self.agent = agent + + async def poll(self, timeout_ms=None): + if self.data: + return self.data.pop(0) + self.agent._running = False + return None + + vad_agent.audio_in_poller = DummyPoller(chunk_bytes, vad_agent) + vad_agent._ready = True + vad_agent._running = True + vad_agent.i_since_speech = 0 + + await vad_agent._streaming_loop() audio_out_socket.send.assert_called() for args in audio_out_socket.send.call_args_list: diff --git a/test/unit/agents/bdi/text_belief_agent/behaviours/test_belief_from_text.py b/test/unit/agents/bdi/text_belief_agent/behaviours/test_belief_from_text.py index 294f00d..92e1716 100644 --- a/test/unit/agents/bdi/text_belief_agent/behaviours/test_belief_from_text.py +++ b/test/unit/agents/bdi/text_belief_agent/behaviours/test_belief_from_text.py @@ -2,11 +2,10 @@ import json from unittest.mock import AsyncMock, MagicMock, patch import pytest -from spade.message import Message - from control_backend.agents.bdi.text_belief_extractor_agent.behaviours.text_belief_extractor_behaviour import ( # noqa: E501, We can't shorten this import. TextBeliefExtractorBehaviour, ) +from spade.message import Message @pytest.fixture diff --git a/test/unit/agents/perception/vad_agent/test_vad_socket_poller.py b/test/unit/agents/perception/vad_agent/test_vad_socket_poller.py index 6ac074f..2a4ae62 100644 --- a/test/unit/agents/perception/vad_agent/test_vad_socket_poller.py +++ b/test/unit/agents/perception/vad_agent/test_vad_socket_poller.py @@ -16,8 +16,8 @@ async def test_socket_poller_with_data(socket, mocker): socket_data = b"test" socket.recv.return_value = socket_data - mock_poller: MagicMock = mocker.patch("control_backend.agents.perception.vad_agent.zmq.Poller") - mock_poller.return_value.poll.return_value = [(socket, zmq.POLLIN)] + mock_poller: MagicMock = mocker.patch("control_backend.agents.perception.vad_agent.azmq.Poller") + mock_poller.return_value.poll = AsyncMock(return_value=[(socket, zmq.POLLIN)]) poller = SocketPoller(socket) # Calling `poll` twice to be able to check that the poller is reused @@ -35,8 +35,8 @@ async def test_socket_poller_with_data(socket, mocker): @pytest.mark.asyncio async def test_socket_poller_no_data(socket, mocker): - mock_poller: MagicMock = mocker.patch("control_backend.agents.perception.vad_agent.zmq.Poller") - mock_poller.return_value.poll.return_value = [] + mock_poller: MagicMock = mocker.patch("control_backend.agents.perception.vad_agent.azmq.Poller") + mock_poller.return_value.poll = AsyncMock(return_value=[]) poller = SocketPoller(socket) data = await poller.poll() diff --git a/test/unit/agents/perception/vad_agent/test_vad_streaming.py b/test/unit/agents/perception/vad_agent/test_vad_streaming.py index 13b3f23..84fc71e 100644 --- a/test/unit/agents/perception/vad_agent/test_vad_streaming.py +++ b/test/unit/agents/perception/vad_agent/test_vad_streaming.py @@ -3,12 +3,7 @@ from unittest.mock import AsyncMock, MagicMock import numpy as np import pytest -from control_backend.agents.perception.vad_agent import StreamingBehaviour - - -@pytest.fixture -def audio_in_socket(): - return AsyncMock() +from control_backend.agents.perception.vad_agent import VADAgent @pytest.fixture @@ -17,22 +12,8 @@ def audio_out_socket(): @pytest.fixture -def mock_agent(mocker): - """Fixture to create a mock BDIAgent.""" - agent = MagicMock() - agent.jid = "vad_agent@test" - return agent - - -@pytest.fixture -def streaming(audio_in_socket, audio_out_socket, mock_agent): - import torch - - torch.hub.load.return_value = (..., ...) # Mock - streaming = StreamingBehaviour(audio_in_socket, audio_out_socket) - streaming._ready = True - streaming.agent = mock_agent - return streaming +def vad_agent(audio_out_socket): + return VADAgent("tcp://localhost:5555", False) @pytest.fixture(autouse=True) @@ -61,25 +42,40 @@ async def simulate_streaming_with_probabilities(streaming, probabilities: list[f """ model_item = MagicMock() model_item.item.side_effect = probabilities - streaming.model = MagicMock() - streaming.model.return_value = model_item + streaming.model = MagicMock(return_value=model_item) - audio_in_poller = AsyncMock() - audio_in_poller.poll.return_value = np.empty(shape=512, dtype=np.float32) - streaming.audio_in_poller = audio_in_poller + # Prepare deterministic audio chunks and a poller that stops the loop when exhausted + chunk_bytes = np.empty(shape=512, dtype=np.float32).tobytes() + chunks = [chunk_bytes for _ in probabilities] - for _ in probabilities: - await streaming.run() + class DummyPoller: + def __init__(self, data, agent): + self.data = data + self.agent = agent + + async def poll(self, timeout_ms=None): + if self.data: + return self.data.pop(0) + # Stop the loop cleanly once we've consumed all chunks + self.agent._running = False + return None + + streaming.audio_in_poller = DummyPoller(chunks, streaming) + streaming._ready = True + streaming._running = True + + await streaming._streaming_loop() @pytest.mark.asyncio -async def test_voice_activity_detected(audio_in_socket, audio_out_socket, streaming): +async def test_voice_activity_detected(audio_out_socket, vad_agent): """ Test a scenario where there is voice activity detected between silences. """ speech_chunk_count = 5 probabilities = [0.0] * 5 + [1.0] * speech_chunk_count + [0.0] * 5 - await simulate_streaming_with_probabilities(streaming, probabilities) + vad_agent.audio_out_socket = audio_out_socket + await simulate_streaming_with_probabilities(vad_agent, probabilities) audio_out_socket.send.assert_called_once() data = audio_out_socket.send.call_args[0][0] @@ -88,7 +84,7 @@ async def test_voice_activity_detected(audio_in_socket, audio_out_socket, stream @pytest.mark.asyncio -async def test_voice_activity_short_pause(audio_in_socket, audio_out_socket, streaming): +async def test_voice_activity_short_pause(audio_out_socket, vad_agent): """ Test a scenario where there is a short pause between speech, checking whether it ignores the short pause. @@ -97,7 +93,8 @@ async def test_voice_activity_short_pause(audio_in_socket, audio_out_socket, str probabilities = ( [0.0] * 5 + [1.0] * speech_chunk_count + [0.0] + [1.0] * speech_chunk_count + [0.0] * 5 ) - await simulate_streaming_with_probabilities(streaming, probabilities) + vad_agent.audio_out_socket = audio_out_socket + await simulate_streaming_with_probabilities(vad_agent, probabilities) audio_out_socket.send.assert_called_once() data = audio_out_socket.send.call_args[0][0] @@ -107,15 +104,22 @@ async def test_voice_activity_short_pause(audio_in_socket, audio_out_socket, str @pytest.mark.asyncio -async def test_no_data(audio_in_socket, audio_out_socket, streaming): +async def test_no_data(audio_out_socket, vad_agent): """ Test a scenario where there is no data received. This should not cause errors. """ - audio_in_poller = AsyncMock() - audio_in_poller.poll.return_value = None - streaming.audio_in_poller = audio_in_poller - await streaming.run() + class DummyPoller: + async def poll(self, timeout_ms=None): + vad_agent._running = False + return None + + vad_agent.audio_out_socket = audio_out_socket + vad_agent.audio_in_poller = DummyPoller() + vad_agent._ready = True + vad_agent._running = True + + await vad_agent._streaming_loop() audio_out_socket.send.assert_not_called() - assert len(streaming.audio_buffer) == 0 + assert len(vad_agent.audio_buffer) == 0