# Conflicts: # src/control_backend/agents/ri_communication_agent.py # src/control_backend/core/config.py # src/control_backend/main.py
122 lines
3.8 KiB
Python
122 lines
3.8 KiB
Python
from unittest.mock import AsyncMock, MagicMock
|
|
|
|
import numpy as np
|
|
import pytest
|
|
|
|
from control_backend.agents.perception.vad_agent import StreamingBehaviour
|
|
|
|
|
|
@pytest.fixture
def audio_in_socket():
    """Async mock passed to ``StreamingBehaviour`` as its input audio socket."""
    in_socket = AsyncMock()
    return in_socket
|
|
|
|
|
|
@pytest.fixture
def audio_out_socket():
    """Async mock passed to ``StreamingBehaviour`` as its output socket; tests inspect ``.send``."""
    out_socket = AsyncMock()
    return out_socket
|
|
|
|
|
|
@pytest.fixture
def mock_agent(mocker):
    """Stand-in for the owning BDIAgent: a ``MagicMock`` whose ``jid`` is ``"vad_agent@test"``.

    ``mocker`` is requested but not used by this fixture.
    """
    # Mock(**kwargs) sets each keyword as an attribute on the new mock.
    return MagicMock(jid="vad_agent@test")
|
|
|
|
|
|
@pytest.fixture
def streaming(audio_in_socket, audio_out_socket, mock_agent):
    """Build a ``StreamingBehaviour`` wired to the mock sockets, marked ready, with a mock agent."""
    import torch

    # NOTE(review): setting ``return_value`` only does anything if ``torch.hub.load`` is
    # already a mock (e.g. patched in a conftest); on the real function it is a plain
    # attribute assignment. Confirm where torch is mocked.
    torch.hub.load.return_value = (..., ...)  # Mock

    behaviour = StreamingBehaviour(audio_in_socket, audio_out_socket)
    # Presumably bypasses whatever readiness gate run() checks — confirm against vad_agent.
    behaviour._ready = True
    behaviour.agent = mock_agent
    return behaviour
|
|
|
|
|
|
@pytest.fixture(autouse=True)
def patch_settings(monkeypatch):
    """Pin the settings that ``vad_agent`` reads to fixed test values.

    Runs automatically for every test in this module (``autouse=True``). ``raising=False``
    allows each attribute to be set even if it does not already exist.
    """
    from control_backend.agents.perception import vad_agent

    # Applied in insertion order, matching the order of the individual setattr calls.
    behaviour_overrides = {
        "vad_prob_threshold": 0.5,
        "vad_non_speech_patience_chunks": 2,
        "vad_initial_since_speech": 0,
    }
    for attr_name, value in behaviour_overrides.items():
        monkeypatch.setattr(
            vad_agent.settings.behaviour_settings, attr_name, value, raising=False
        )

    monkeypatch.setattr(vad_agent.settings.vad_settings, "sample_rate_hz", 16_000, raising=False)
|
|
|
|
|
|
async def simulate_streaming_with_probabilities(streaming, probabilities: list[float]):
    """
    Simulates a streaming scenario with given VAD model probabilities for testing purposes.

    ``streaming.model`` is replaced by a mock whose result's ``.item()`` returns the given
    probabilities in order. ``streaming.audio_in_poller.poll()`` always yields one
    512-sample float32 chunk. ``streaming.run()`` is then awaited once per probability.

    The chunk is all zeros rather than ``np.empty``: ``np.empty`` returns uninitialised
    memory, which would make the forwarded audio bytes nondeterministic (and possibly
    NaN/inf).

    :param streaming: The streaming component to be tested; its ``model`` and
        ``audio_in_poller`` attributes are overwritten.
    :param probabilities: A list of probabilities representing the outputs of the VAD model,
        one per simulated chunk.
    """
    model_item = MagicMock()
    # One probability is consumed per run() call; a side_effect list raises once exhausted.
    model_item.item.side_effect = probabilities
    streaming.model = MagicMock()
    streaming.model.return_value = model_item

    audio_in_poller = AsyncMock()
    audio_in_poller.poll.return_value = np.zeros(shape=512, dtype=np.float32)
    streaming.audio_in_poller = audio_in_poller

    for _ in probabilities:
        await streaming.run()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_voice_activity_detected(audio_in_socket, audio_out_socket, streaming):
    """
    Speech surrounded by silence should be forwarded as exactly one bytes payload.
    """
    speech_chunk_count = 5
    silence = [0.0] * 5
    speech = [1.0] * speech_chunk_count

    await simulate_streaming_with_probabilities(streaming, silence + speech + silence)

    audio_out_socket.send.assert_called_once()
    payload = audio_out_socket.send.call_args.args[0]
    assert isinstance(payload, bytes)

    # Each chunk is 512 float32 samples (4 bytes each); expect the speech chunks plus one
    # trailing padding chunk.
    bytes_per_chunk = 512 * 4
    assert len(payload) == bytes_per_chunk * (speech_chunk_count + 1)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_voice_activity_short_pause(audio_in_socket, audio_out_socket, streaming):
    """
    Test a scenario where there is a short pause between speech, checking whether it ignores the
    short pause.

    A single non-speech chunk between two speech runs must not split the utterance: exactly one
    payload is sent, and it contains both speech runs plus the pause chunk. Presumably this holds
    because the pause is shorter than the patched ``vad_non_speech_patience_chunks`` — confirm.
    """
    speech_chunk_count = 5
    probabilities = (
        [0.0] * 5 + [1.0] * speech_chunk_count + [0.0] + [1.0] * speech_chunk_count + [0.0] * 5
    )
    await simulate_streaming_with_probabilities(streaming, probabilities)

    audio_out_socket.send.assert_called_once()
    data = audio_out_socket.send.call_args[0][0]
    assert isinstance(data, bytes)
    # Expecting 12 chunks (2*5 with speech, 1 pause between, 1 as padding); each chunk is
    # 512 float32 samples, i.e. 512 * 4 bytes.
    assert len(data) == 512 * 4 * (speech_chunk_count * 2 + 1 + 1)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_no_data(audio_in_socket, audio_out_socket, streaming):
    """
    A poll that yields ``None`` must not raise: nothing is sent and the audio buffer stays empty.
    """
    empty_poller = AsyncMock()
    empty_poller.poll.return_value = None
    streaming.audio_in_poller = empty_poller

    await streaming.run()

    audio_out_socket.send.assert_not_called()
    buffered = len(streaming.audio_buffer)
    assert buffered == 0
|