test: make VAD tests work again

ref: N25B-301
This commit is contained in:
Twirre Meulenbelt
2025-11-20 16:33:12 +01:00
parent 610c4b526d
commit 0493d390e3
5 changed files with 137 additions and 72 deletions

View File

@@ -2,11 +2,10 @@ import json
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from spade.message import Message
from control_backend.agents.bdi.text_belief_extractor_agent.behaviours.text_belief_extractor_behaviour import ( # noqa: E501, We can't shorten this import.
TextBeliefExtractorBehaviour,
)
from spade.message import Message
@pytest.fixture

View File

@@ -16,8 +16,8 @@ async def test_socket_poller_with_data(socket, mocker):
socket_data = b"test"
socket.recv.return_value = socket_data
mock_poller: MagicMock = mocker.patch("control_backend.agents.perception.vad_agent.zmq.Poller")
mock_poller.return_value.poll.return_value = [(socket, zmq.POLLIN)]
mock_poller: MagicMock = mocker.patch("control_backend.agents.perception.vad_agent.azmq.Poller")
mock_poller.return_value.poll = AsyncMock(return_value=[(socket, zmq.POLLIN)])
poller = SocketPoller(socket)
# Calling `poll` twice to be able to check that the poller is reused
@@ -35,8 +35,8 @@ async def test_socket_poller_with_data(socket, mocker):
@pytest.mark.asyncio
async def test_socket_poller_no_data(socket, mocker):
mock_poller: MagicMock = mocker.patch("control_backend.agents.perception.vad_agent.zmq.Poller")
mock_poller.return_value.poll.return_value = []
mock_poller: MagicMock = mocker.patch("control_backend.agents.perception.vad_agent.azmq.Poller")
mock_poller.return_value.poll = AsyncMock(return_value=[])
poller = SocketPoller(socket)
data = await poller.poll()

View File

@@ -3,12 +3,7 @@ from unittest.mock import AsyncMock, MagicMock
import numpy as np
import pytest
from control_backend.agents.perception.vad_agent import StreamingBehaviour
@pytest.fixture
def audio_in_socket():
return AsyncMock()
from control_backend.agents.perception.vad_agent import VADAgent
@pytest.fixture
@@ -17,22 +12,8 @@ def audio_out_socket():
@pytest.fixture
def mock_agent(mocker):
"""Fixture to create a mock BDIAgent."""
agent = MagicMock()
agent.jid = "vad_agent@test"
return agent
@pytest.fixture
def streaming(audio_in_socket, audio_out_socket, mock_agent):
import torch
torch.hub.load.return_value = (..., ...) # Mock
streaming = StreamingBehaviour(audio_in_socket, audio_out_socket)
streaming._ready = True
streaming.agent = mock_agent
return streaming
def vad_agent(audio_out_socket):
return VADAgent("tcp://localhost:5555", False)
@pytest.fixture(autouse=True)
@@ -61,25 +42,40 @@ async def simulate_streaming_with_probabilities(streaming, probabilities: list[f
"""
model_item = MagicMock()
model_item.item.side_effect = probabilities
streaming.model = MagicMock()
streaming.model.return_value = model_item
streaming.model = MagicMock(return_value=model_item)
audio_in_poller = AsyncMock()
audio_in_poller.poll.return_value = np.empty(shape=512, dtype=np.float32)
streaming.audio_in_poller = audio_in_poller
# Prepare deterministic audio chunks and a poller that stops the loop when exhausted
chunk_bytes = np.empty(shape=512, dtype=np.float32).tobytes()
chunks = [chunk_bytes for _ in probabilities]
for _ in probabilities:
await streaming.run()
class DummyPoller:
def __init__(self, data, agent):
self.data = data
self.agent = agent
async def poll(self, timeout_ms=None):
if self.data:
return self.data.pop(0)
# Stop the loop cleanly once we've consumed all chunks
self.agent._running = False
return None
streaming.audio_in_poller = DummyPoller(chunks, streaming)
streaming._ready = True
streaming._running = True
await streaming._streaming_loop()
@pytest.mark.asyncio
async def test_voice_activity_detected(audio_in_socket, audio_out_socket, streaming):
async def test_voice_activity_detected(audio_out_socket, vad_agent):
"""
Test a scenario where there is voice activity detected between silences.
"""
speech_chunk_count = 5
probabilities = [0.0] * 5 + [1.0] * speech_chunk_count + [0.0] * 5
await simulate_streaming_with_probabilities(streaming, probabilities)
vad_agent.audio_out_socket = audio_out_socket
await simulate_streaming_with_probabilities(vad_agent, probabilities)
audio_out_socket.send.assert_called_once()
data = audio_out_socket.send.call_args[0][0]
@@ -88,7 +84,7 @@ async def test_voice_activity_detected(audio_in_socket, audio_out_socket, stream
@pytest.mark.asyncio
async def test_voice_activity_short_pause(audio_in_socket, audio_out_socket, streaming):
async def test_voice_activity_short_pause(audio_out_socket, vad_agent):
"""
Test a scenario where there is a short pause between speech, checking whether it ignores the
short pause.
@@ -97,7 +93,8 @@ async def test_voice_activity_short_pause(audio_in_socket, audio_out_socket, str
probabilities = (
[0.0] * 5 + [1.0] * speech_chunk_count + [0.0] + [1.0] * speech_chunk_count + [0.0] * 5
)
await simulate_streaming_with_probabilities(streaming, probabilities)
vad_agent.audio_out_socket = audio_out_socket
await simulate_streaming_with_probabilities(vad_agent, probabilities)
audio_out_socket.send.assert_called_once()
data = audio_out_socket.send.call_args[0][0]
@@ -107,15 +104,22 @@ async def test_voice_activity_short_pause(audio_in_socket, audio_out_socket, str
@pytest.mark.asyncio
async def test_no_data(audio_in_socket, audio_out_socket, streaming):
async def test_no_data(audio_out_socket, vad_agent):
"""
Test a scenario where there is no data received. This should not cause errors.
"""
audio_in_poller = AsyncMock()
audio_in_poller.poll.return_value = None
streaming.audio_in_poller = audio_in_poller
await streaming.run()
class DummyPoller:
async def poll(self, timeout_ms=None):
vad_agent._running = False
return None
vad_agent.audio_out_socket = audio_out_socket
vad_agent.audio_in_poller = DummyPoller()
vad_agent._ready = True
vad_agent._running = True
await vad_agent._streaming_loop()
audio_out_socket.send.assert_not_called()
assert len(streaming.audio_buffer) == 0
assert len(vad_agent.audio_buffer) == 0