234 lines
7.8 KiB
Python
234 lines
7.8 KiB
Python
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
import numpy as np
|
|
import pytest
|
|
import zmq
|
|
|
|
from control_backend.agents.perception.vad_agent import VADAgent
|
|
from control_backend.core.config import settings
|
|
|
|
|
|
# We don't want to use real ZMQ in unit tests, for example because it can give errors when sockets
|
|
# aren't closed properly.
|
|
@pytest.fixture(autouse=True)
def mock_zmq():
    """
    Replace ``zmq.asyncio.Context`` with a mock for the duration of every test.

    The patched class is yielded so tests can inspect it; calling ``instance()`` on it
    returns a plain ``MagicMock``.
    """
    with patch("zmq.asyncio.Context") as context_cls:
        fake_context = MagicMock()
        context_cls.instance.return_value = fake_context
        yield context_cls
|
|
|
|
|
|
@pytest.fixture
def audio_out_socket():
    """
    Provide an ``AsyncMock`` that tests install as the agent's audio output socket.
    """
    socket = AsyncMock()
    return socket
|
|
|
|
|
|
@pytest.fixture
def vad_agent(audio_out_socket):
    """
    Build a ``VADAgent`` whose internal publish socket is replaced by an ``AsyncMock``.

    The ``audio_out_socket`` fixture is requested but not attached here; tests that need it
    assign it to the agent themselves.
    """
    instance = VADAgent("tcp://localhost:5555", False)
    instance._internal_pub_socket = AsyncMock()
    return instance
|
|
|
|
|
|
@pytest.fixture(autouse=True)
def patch_settings(monkeypatch):
    """
    Pin the VAD-related settings to fixed values for every test.

    ``raising=False`` lets the attribute be set even when it does not exist yet.
    """
    # Patch the settings that vad_agent.run() reads
    from control_backend.agents.perception import vad_agent

    behaviour = vad_agent.settings.behaviour_settings
    behaviour_overrides = {
        "vad_prob_threshold": 0.5,
        "vad_non_speech_patience_chunks": 2,
        "vad_initial_since_speech": 0,
    }
    for name, value in behaviour_overrides.items():
        monkeypatch.setattr(behaviour, name, value, raising=False)

    monkeypatch.setattr(vad_agent.settings.vad_settings, "sample_rate_hz", 16_000, raising=False)
|
|
|
|
|
|
@pytest.fixture(autouse=True)
def mock_experiment_logger():
    """
    Swap the ``experiment_logger`` referenced by the vad_agent module for a mock in every test.
    """
    target = "control_backend.agents.perception.vad_agent.experiment_logger"
    with patch(target) as patched_logger:
        yield patched_logger
|
|
|
|
|
|
async def simulate_streaming_with_probabilities(streaming, probabilities: list[float]):
    """
    Simulates a streaming scenario with given VAD model probabilities for testing purposes.

    The component's ``model`` is replaced by a mock whose ``.item()`` yields the given
    probabilities one per call, and its ``audio_in_poller`` by a poller that hands out one
    audio chunk per probability. When the chunks run out, the poller sets
    ``streaming._running`` to ``False`` and returns ``None`` so the loop can end.

    :param streaming: The streaming component to be tested.
    :param probabilities: A list of probabilities representing the outputs of the VAD model.
    """
    model_item = MagicMock()
    model_item.item.side_effect = probabilities
    streaming.model = MagicMock(return_value=model_item)

    # Prepare deterministic audio chunks and a poller that stops the loop when exhausted.
    # np.zeros (not np.empty) is used on purpose: np.empty returns uninitialised memory, so the
    # chunk content would be arbitrary and could differ between runs.
    chunk_bytes = np.zeros(shape=512, dtype=np.float32).tobytes()
    chunks = [chunk_bytes for _ in probabilities]

    class DummyPoller:
        def __init__(self, data, agent):
            self.data = data
            self.agent = agent

        async def poll(self, timeout_ms=None):
            if self.data:
                return self.data.pop(0)
            # Stop the loop cleanly once we've consumed all chunks
            self.agent._running = False
            return None

    streaming.audio_in_poller = DummyPoller(chunks, streaming)
    streaming._ready = AsyncMock()
    streaming._running = True

    await streaming._streaming_loop()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_voice_activity_detected(audio_out_socket, vad_agent):
    """
    Test a scenario where a single stretch of voice activity sits between two silences.

    Exactly one message must be sent, holding the speech chunks plus the configured number of
    leading-silence chunks.
    """
    n_speech = 5
    n_padding = settings.behaviour_settings.vad_begin_silence_chunks
    leading_silence = [0.0] * 15
    speech = [1.0] * n_speech
    trailing_silence = [0.0] * 5
    vad_agent.audio_out_socket = audio_out_socket

    await simulate_streaming_with_probabilities(
        vad_agent, leading_silence + speech + trailing_silence
    )

    audio_out_socket.send.assert_called_once()
    payload = audio_out_socket.send.call_args.args[0]
    assert isinstance(payload, bytes)
    # Each simulated chunk is 512 float32 samples, i.e. 512 * 4 bytes
    expected_chunks = n_padding + n_speech
    assert len(payload) == 512 * 4 * expected_chunks
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_voice_activity_short_pause(audio_out_socket, vad_agent):
    """
    Test a scenario where two stretches of speech are separated by a one-chunk pause, checking
    that the pause is ignored and everything is sent as a single message.
    """
    n_speech = 5
    n_padding = settings.behaviour_settings.vad_begin_silence_chunks
    speech_burst = [1.0] * n_speech
    probabilities = [0.0] * 15 + speech_burst + [0.0] + speech_burst + [0.0] * 5
    vad_agent.audio_out_socket = audio_out_socket

    await simulate_streaming_with_probabilities(vad_agent, probabilities)

    audio_out_socket.send.assert_called_once()
    payload = audio_out_socket.send.call_args.args[0]
    assert isinstance(payload, bytes)
    # Expected content: both speech bursts, the single pause chunk between them, and the
    # configured begin-silence chunks as padding
    expected_chunks = n_speech * 2 + 1 + n_padding
    assert len(payload) == 512 * 4 * expected_chunks
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_no_data(audio_out_socket, vad_agent):
    """
    Test a scenario where the poller never delivers data. The loop must finish without errors,
    send nothing, and leave the audio buffer empty.
    """

    class SilentPoller:
        async def poll(self, timeout_ms=None):
            # Report "no data" and ask the loop to stop right away
            vad_agent._running = False
            return None

    vad_agent.audio_out_socket = audio_out_socket
    vad_agent.audio_in_poller = SilentPoller()
    vad_agent._ready = AsyncMock()
    vad_agent._running = True

    await vad_agent._streaming_loop()

    audio_out_socket.send.assert_not_called()
    assert len(vad_agent.audio_buffer) == 0
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_streaming_loop_reset_needed(audio_out_socket, vad_agent):
    """
    Test that the ``_reset_needed`` branch clears the flag, empties the audio buffer and
    restores ``i_since_speech`` to its configured initial value.
    """
    vad_agent._reset_needed = True
    vad_agent._ready.set()
    vad_agent._paused.set()
    vad_agent._running = True
    vad_agent.audio_buffer = np.array([1.0], dtype=np.float32)
    vad_agent.i_since_speech = 0

    # Stand-in for _reset_stream that ends the loop by clearing _running
    async def stop_instead_of_reset():
        vad_agent._running = False

    vad_agent._reset_stream = stop_instead_of_reset

    # Needs a poller to avoid AssertionError
    idle_poller = AsyncMock()
    idle_poller.poll.return_value = None
    vad_agent.audio_in_poller = idle_poller

    await vad_agent._streaming_loop()

    assert vad_agent._reset_needed is False
    assert len(vad_agent.audio_buffer) == 0
    expected_since_speech = settings.behaviour_settings.vad_initial_since_speech
    assert vad_agent.i_since_speech == expected_since_speech
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_streaming_loop_no_data_clears_buffer(audio_out_socket, vad_agent):
    """
    Test that a non-empty audio buffer is cleared when ``poll`` returns ``None``, and that
    ``i_since_speech`` ends at its configured initial value.
    """
    vad_agent.audio_buffer = np.array([1.0], dtype=np.float32)
    vad_agent._ready.set()
    vad_agent._paused.set()
    vad_agent._running = True

    class SinglePollPoller:
        async def poll(self, timeout_ms=None):
            # Stop after one poll
            vad_agent._running = False
            return None

    vad_agent.audio_in_poller = SinglePollPoller()

    await vad_agent._streaming_loop()

    assert len(vad_agent.audio_buffer) == 0
    expected_since_speech = settings.behaviour_settings.vad_initial_since_speech
    assert vad_agent.i_since_speech == expected_since_speech
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_vad_model_load_failure_stops_agent(vad_agent):
    """
    Test that when loading the VAD model raises an Exception, ``setup`` does not propagate it
    and the agent's ``stop`` is awaited exactly once.
    """
    # Replace stop with an AsyncMock so we can check it was awaited
    stop_mock = AsyncMock()
    vad_agent.stop = stop_mock

    # Make torch.hub.load raise while setup runs
    failing_load = patch(
        "control_backend.agents.perception.vad_agent.torch.hub.load",
        side_effect=Exception("model fail"),
    )
    with failing_load:
        await vad_agent.setup()

    vad_agent.stop.assert_awaited_once()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_audio_out_bind_failure_sets_none_and_logs(vad_agent, caplog):
    """
    Test that if binding the output socket raises ZMQBindError,
    ``_connect_audio_out_socket`` returns None and leaves ``audio_out_socket`` set to None.
    """
    failing_socket = MagicMock()
    failing_socket.bind.side_effect = zmq.ZMQBindError()

    instance_target = "control_backend.agents.perception.vad_agent.azmq.Context.instance"
    with patch(instance_target) as instance_mock:
        instance_mock.return_value.socket.return_value = failing_socket

        with caplog.at_level("ERROR"):
            returned_port = vad_agent._connect_audio_out_socket()

    assert returned_port is None
    assert vad_agent.audio_out_socket is None
    # NOTE(review): caplog.text is a string, so this check cannot fail and does not prove an
    # error was logged. Consider asserting on caplog.records instead -- first confirm that the
    # agent's logger propagates to caplog.
    assert caplog.text is not None
|