test: make VAD tests work again
ref: N25B-301
This commit is contained in:
@@ -3,12 +3,7 @@ from unittest.mock import AsyncMock, MagicMock
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from control_backend.agents.perception.vad_agent import StreamingBehaviour
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def audio_in_socket():
|
||||
return AsyncMock()
|
||||
from control_backend.agents.perception.vad_agent import VADAgent
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -17,22 +12,8 @@ def audio_out_socket():
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_agent(mocker):
|
||||
"""Fixture to create a mock BDIAgent."""
|
||||
agent = MagicMock()
|
||||
agent.jid = "vad_agent@test"
|
||||
return agent
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def streaming(audio_in_socket, audio_out_socket, mock_agent):
|
||||
import torch
|
||||
|
||||
torch.hub.load.return_value = (..., ...) # Mock
|
||||
streaming = StreamingBehaviour(audio_in_socket, audio_out_socket)
|
||||
streaming._ready = True
|
||||
streaming.agent = mock_agent
|
||||
return streaming
|
||||
def vad_agent(audio_out_socket):
|
||||
return VADAgent("tcp://localhost:5555", False)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
@@ -61,25 +42,40 @@ async def simulate_streaming_with_probabilities(streaming, probabilities: list[f
|
||||
"""
|
||||
model_item = MagicMock()
|
||||
model_item.item.side_effect = probabilities
|
||||
streaming.model = MagicMock()
|
||||
streaming.model.return_value = model_item
|
||||
streaming.model = MagicMock(return_value=model_item)
|
||||
|
||||
audio_in_poller = AsyncMock()
|
||||
audio_in_poller.poll.return_value = np.empty(shape=512, dtype=np.float32)
|
||||
streaming.audio_in_poller = audio_in_poller
|
||||
# Prepare deterministic audio chunks and a poller that stops the loop when exhausted
|
||||
chunk_bytes = np.empty(shape=512, dtype=np.float32).tobytes()
|
||||
chunks = [chunk_bytes for _ in probabilities]
|
||||
|
||||
for _ in probabilities:
|
||||
await streaming.run()
|
||||
class DummyPoller:
|
||||
def __init__(self, data, agent):
|
||||
self.data = data
|
||||
self.agent = agent
|
||||
|
||||
async def poll(self, timeout_ms=None):
|
||||
if self.data:
|
||||
return self.data.pop(0)
|
||||
# Stop the loop cleanly once we've consumed all chunks
|
||||
self.agent._running = False
|
||||
return None
|
||||
|
||||
streaming.audio_in_poller = DummyPoller(chunks, streaming)
|
||||
streaming._ready = True
|
||||
streaming._running = True
|
||||
|
||||
await streaming._streaming_loop()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_voice_activity_detected(audio_in_socket, audio_out_socket, streaming):
|
||||
async def test_voice_activity_detected(audio_out_socket, vad_agent):
|
||||
"""
|
||||
Test a scenario where there is voice activity detected between silences.
|
||||
"""
|
||||
speech_chunk_count = 5
|
||||
probabilities = [0.0] * 5 + [1.0] * speech_chunk_count + [0.0] * 5
|
||||
await simulate_streaming_with_probabilities(streaming, probabilities)
|
||||
vad_agent.audio_out_socket = audio_out_socket
|
||||
await simulate_streaming_with_probabilities(vad_agent, probabilities)
|
||||
|
||||
audio_out_socket.send.assert_called_once()
|
||||
data = audio_out_socket.send.call_args[0][0]
|
||||
@@ -88,7 +84,7 @@ async def test_voice_activity_detected(audio_in_socket, audio_out_socket, stream
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_voice_activity_short_pause(audio_in_socket, audio_out_socket, streaming):
|
||||
async def test_voice_activity_short_pause(audio_out_socket, vad_agent):
|
||||
"""
|
||||
Test a scenario where there is a short pause between speech, checking whether it ignores the
|
||||
short pause.
|
||||
@@ -97,7 +93,8 @@ async def test_voice_activity_short_pause(audio_in_socket, audio_out_socket, str
|
||||
probabilities = (
|
||||
[0.0] * 5 + [1.0] * speech_chunk_count + [0.0] + [1.0] * speech_chunk_count + [0.0] * 5
|
||||
)
|
||||
await simulate_streaming_with_probabilities(streaming, probabilities)
|
||||
vad_agent.audio_out_socket = audio_out_socket
|
||||
await simulate_streaming_with_probabilities(vad_agent, probabilities)
|
||||
|
||||
audio_out_socket.send.assert_called_once()
|
||||
data = audio_out_socket.send.call_args[0][0]
|
||||
@@ -107,15 +104,22 @@ async def test_voice_activity_short_pause(audio_in_socket, audio_out_socket, str
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_data(audio_in_socket, audio_out_socket, streaming):
|
||||
async def test_no_data(audio_out_socket, vad_agent):
|
||||
"""
|
||||
Test a scenario where there is no data received. This should not cause errors.
|
||||
"""
|
||||
audio_in_poller = AsyncMock()
|
||||
audio_in_poller.poll.return_value = None
|
||||
streaming.audio_in_poller = audio_in_poller
|
||||
|
||||
await streaming.run()
|
||||
class DummyPoller:
|
||||
async def poll(self, timeout_ms=None):
|
||||
vad_agent._running = False
|
||||
return None
|
||||
|
||||
vad_agent.audio_out_socket = audio_out_socket
|
||||
vad_agent.audio_in_poller = DummyPoller()
|
||||
vad_agent._ready = True
|
||||
vad_agent._running = True
|
||||
|
||||
await vad_agent._streaming_loop()
|
||||
|
||||
audio_out_socket.send.assert_not_called()
|
||||
assert len(streaming.audio_buffer) == 0
|
||||
assert len(vad_agent.audio_buffer) == 0
|
||||
|
||||
Reference in New Issue
Block a user