"""Unit tests for ``VADAgent``: the streaming loop, model loading and output-socket binding."""

import logging
from unittest.mock import AsyncMock, MagicMock, patch

import numpy as np
import pytest
import zmq

from control_backend.agents.perception.vad_agent import VADAgent
from control_backend.core.config import settings

# Number of float32 samples in each simulated audio chunk fed to the streaming loop.
CHUNK_SAMPLES = 512
# Size in bytes of one simulated chunk; expected output sizes are multiples of this.
CHUNK_BYTES = CHUNK_SAMPLES * np.dtype(np.float32).itemsize


# We don't want to use real ZMQ in unit tests, for example because it can give errors when sockets
# aren't closed properly.
@pytest.fixture(autouse=True)
def mock_zmq():
    """Replace ``zmq.asyncio.Context`` with a mock for every test."""
    with patch("zmq.asyncio.Context") as mock:
        mock.instance.return_value = MagicMock()
        yield mock


@pytest.fixture
def audio_out_socket():
    """An async mock standing in for the agent's audio output socket."""
    return AsyncMock()


@pytest.fixture
def vad_agent():
    """A ``VADAgent`` whose internal publish socket is an async mock."""
    agent = VADAgent("tcp://localhost:5555", False)
    agent._internal_pub_socket = AsyncMock()
    return agent


@pytest.fixture(autouse=True)
def patch_settings(monkeypatch):
    """Pin the VAD-related settings so tests don't depend on the configured values."""
    # Aliased so the module name does not shadow the ``vad_agent`` fixture.
    from control_backend.agents.perception import vad_agent as vad_agent_module

    behaviour_settings = vad_agent_module.settings.behaviour_settings
    monkeypatch.setattr(behaviour_settings, "vad_prob_threshold", 0.5, raising=False)
    monkeypatch.setattr(behaviour_settings, "vad_non_speech_patience_chunks", 2, raising=False)
    monkeypatch.setattr(behaviour_settings, "vad_initial_since_speech", 0, raising=False)
    monkeypatch.setattr(
        vad_agent_module.settings.vad_settings, "sample_rate_hz", 16_000, raising=False
    )


@pytest.fixture(autouse=True)
def mock_experiment_logger():
    """Replace the module's ``experiment_logger`` with a mock for every test."""
    with patch("control_backend.agents.perception.vad_agent.experiment_logger") as logger:
        yield logger


async def simulate_streaming_with_probabilities(streaming, probabilities: list[float]):
    """
    Simulates a streaming scenario with given VAD model probabilities for testing purposes.

    :param streaming: The streaming component to be tested.
    :param probabilities: A list of probabilities representing the outputs of the VAD model,
        one per simulated chunk.
    """
    model_item = MagicMock()
    model_item.item.side_effect = probabilities
    streaming.model = MagicMock(return_value=model_item)

    # Prepare deterministic audio chunks and a poller that stops the loop when exhausted.
    # np.zeros (not np.empty) so the chunk contents are defined rather than uninitialized memory.
    chunk_bytes = np.zeros(shape=CHUNK_SAMPLES, dtype=np.float32).tobytes()
    chunks = [chunk_bytes for _ in probabilities]

    class DummyPoller:
        def __init__(self, data, agent):
            self.data = data
            self.agent = agent

        async def poll(self, timeout_ms=None):
            if self.data:
                return self.data.pop(0)

            # Stop the loop cleanly once we've consumed all chunks
            self.agent._running = False
            return None

    streaming.audio_in_poller = DummyPoller(chunks, streaming)
    streaming._ready = AsyncMock()
    streaming._running = True

    await streaming._streaming_loop()


@pytest.mark.asyncio
async def test_voice_activity_detected(audio_out_socket, vad_agent):
    """
    Test a scenario where there is voice activity detected between silences.
    """
    speech_chunk_count = 5
    begin_silence_chunks = settings.behaviour_settings.vad_begin_silence_chunks
    probabilities = [0.0] * 15 + [1.0] * speech_chunk_count + [0.0] * 5

    vad_agent.audio_out_socket = audio_out_socket
    await simulate_streaming_with_probabilities(vad_agent, probabilities)

    audio_out_socket.send.assert_called_once()
    data = audio_out_socket.send.call_args[0][0]
    assert isinstance(data, bytes)
    assert len(data) == CHUNK_BYTES * (begin_silence_chunks + speech_chunk_count)


@pytest.mark.asyncio
async def test_voice_activity_short_pause(audio_out_socket, vad_agent):
    """
    Test a scenario where there is a short pause between speech, checking whether it ignores the
    short pause.
    """
    speech_chunk_count = 5
    begin_silence_chunks = settings.behaviour_settings.vad_begin_silence_chunks
    probabilities = (
        [0.0] * 15
        + [1.0] * speech_chunk_count
        + [0.0]
        + [1.0] * speech_chunk_count
        + [0.0] * 5
    )

    vad_agent.audio_out_socket = audio_out_socket
    await simulate_streaming_with_probabilities(vad_agent, probabilities)

    audio_out_socket.send.assert_called_once()
    data = audio_out_socket.send.call_args[0][0]
    assert isinstance(data, bytes)
    # Expecting 2 * speech_chunk_count speech chunks, 1 pause chunk between them, and
    # begin_silence_chunks chunks of leading padding.
    assert len(data) == CHUNK_BYTES * (speech_chunk_count * 2 + 1 + begin_silence_chunks)


@pytest.mark.asyncio
async def test_no_data(audio_out_socket, vad_agent):
    """
    Test a scenario where there is no data received. This should not cause errors.
    """

    class DummyPoller:
        async def poll(self, timeout_ms=None):
            vad_agent._running = False
            return None

    vad_agent.audio_out_socket = audio_out_socket
    vad_agent.audio_in_poller = DummyPoller()
    vad_agent._ready = AsyncMock()
    vad_agent._running = True

    await vad_agent._streaming_loop()

    audio_out_socket.send.assert_not_called()
    assert len(vad_agent.audio_buffer) == 0


@pytest.mark.asyncio
async def test_streaming_loop_reset_needed(audio_out_socket, vad_agent):
    """Test that _reset_needed branch works as expected."""
    vad_agent._reset_needed = True
    vad_agent._ready.set()
    vad_agent._paused.set()
    vad_agent._running = True
    vad_agent.audio_buffer = np.array([1.0], dtype=np.float32)
    vad_agent.i_since_speech = 0

    # Mock _reset_stream to stop the loop by setting _running=False
    async def mock_reset():
        vad_agent._running = False

    vad_agent._reset_stream = mock_reset

    # Needs a poller to avoid AssertionError
    vad_agent.audio_in_poller = AsyncMock()
    vad_agent.audio_in_poller.poll.return_value = None

    await vad_agent._streaming_loop()

    assert vad_agent._reset_needed is False
    assert len(vad_agent.audio_buffer) == 0
    assert vad_agent.i_since_speech == settings.behaviour_settings.vad_initial_since_speech


@pytest.mark.asyncio
async def test_streaming_loop_no_data_clears_buffer(audio_out_socket, vad_agent):
    """Test that if poll returns None, buffer is cleared if not empty."""
    vad_agent.audio_buffer = np.array([1.0], dtype=np.float32)
    vad_agent._ready.set()
    vad_agent._paused.set()
    vad_agent._running = True

    class MockPoller:
        async def poll(self, timeout_ms=None):
            vad_agent._running = False  # stop after one poll
            return None

    vad_agent.audio_in_poller = MockPoller()

    await vad_agent._streaming_loop()

    assert len(vad_agent.audio_buffer) == 0
    assert vad_agent.i_since_speech == settings.behaviour_settings.vad_initial_since_speech


@pytest.mark.asyncio
async def test_vad_model_load_failure_stops_agent(vad_agent):
    """
    Test that if loading the VAD model raises an Exception, it is caught,
    the agent logs an exception, stops itself, and setup returns.
    """
    # Patch torch.hub.load to raise an exception
    with patch(
        "control_backend.agents.perception.vad_agent.torch.hub.load",
        side_effect=Exception("model fail"),
    ):
        # Patch stop to an AsyncMock so we can check it was awaited
        vad_agent.stop = AsyncMock()

        await vad_agent.setup()

    # Assert stop was called
    vad_agent.stop.assert_awaited_once()


@pytest.mark.asyncio
async def test_audio_out_bind_failure_sets_none_and_logs(vad_agent, caplog):
    """
    Test that if binding the output socket raises ZMQBindError,
    audio_out_socket is set to None, None is returned, and an error is logged.
    """
    mock_socket = MagicMock()
    mock_socket.bind.side_effect = zmq.ZMQBindError()

    with patch("control_backend.agents.perception.vad_agent.azmq.Context.instance") as mock_ctx:
        mock_ctx.return_value.socket.return_value = mock_socket

        with caplog.at_level("ERROR"):
            port = vad_agent._connect_audio_out_socket()

    assert port is None
    assert vad_agent.audio_out_socket is None
    # caplog.text is always a str (never None), so check that an ERROR record was actually emitted.
    assert any(record.levelno >= logging.ERROR for record in caplog.records)