Merge remote-tracking branch 'origin/dev' into fix/bdi-correct-belief-management

This commit is contained in:
2025-10-29 13:25:58 +01:00
25 changed files with 1618 additions and 6 deletions

View File

@@ -0,0 +1,46 @@
from unittest.mock import AsyncMock, MagicMock
import pytest
import zmq
from control_backend.agents.vad_agent import SocketPoller
@pytest.fixture
def socket():
return AsyncMock()
@pytest.mark.asyncio
async def test_socket_poller_with_data(socket, mocker):
socket_data = b"test"
socket.recv.return_value = socket_data
mock_poller: MagicMock = mocker.patch("control_backend.agents.vad_agent.zmq.Poller")
mock_poller.return_value.poll.return_value = [(socket, zmq.POLLIN)]
poller = SocketPoller(socket)
# Calling `poll` twice to be able to check that the poller is reused
await poller.poll()
data = await poller.poll()
assert data == socket_data
# Ensure that the poller was reused
mock_poller.assert_called_once_with()
mock_poller.return_value.register.assert_called_once_with(socket, zmq.POLLIN)
assert socket.recv.call_count == 2
@pytest.mark.asyncio
async def test_socket_poller_no_data(socket, mocker):
mock_poller: MagicMock = mocker.patch("control_backend.agents.vad_agent.zmq.Poller")
mock_poller.return_value.poll.return_value = []
poller = SocketPoller(socket)
data = await poller.poll()
assert data is None
socket.recv.assert_not_called()

View File

@@ -0,0 +1,95 @@
from unittest.mock import AsyncMock, MagicMock
import numpy as np
import pytest
from control_backend.agents.vad_agent import Streaming
@pytest.fixture
def audio_in_socket():
return AsyncMock()
@pytest.fixture
def audio_out_socket():
return AsyncMock()
@pytest.fixture
def streaming(audio_in_socket, audio_out_socket):
import torch
torch.hub.load.return_value = (..., ...) # Mock
return Streaming(audio_in_socket, audio_out_socket)
async def simulate_streaming_with_probabilities(streaming, probabilities: list[float]):
"""
Simulates a streaming scenario with given VAD model probabilities for testing purposes.
:param streaming: The streaming component to be tested.
:param probabilities: A list of probabilities representing the outputs of the VAD model.
"""
model_item = MagicMock()
model_item.item.side_effect = probabilities
streaming.model = MagicMock()
streaming.model.return_value = model_item
audio_in_poller = AsyncMock()
audio_in_poller.poll.return_value = np.empty(shape=512, dtype=np.float32)
streaming.audio_in_poller = audio_in_poller
for _ in probabilities:
await streaming.run()
@pytest.mark.asyncio
async def test_voice_activity_detected(audio_in_socket, audio_out_socket, streaming):
"""
Test a scenario where there is voice activity detected between silences.
:return:
"""
speech_chunk_count = 5
probabilities = [0.0] * 5 + [1.0] * speech_chunk_count + [0.0] * 5
await simulate_streaming_with_probabilities(streaming, probabilities)
audio_out_socket.send.assert_called_once()
data = audio_out_socket.send.call_args[0][0]
assert isinstance(data, bytes)
# each sample has 512 frames of 4 bytes, expecting 7 chunks (5 with speech, 2 as padding)
assert len(data) == 512 * 4 * (speech_chunk_count + 2)
@pytest.mark.asyncio
async def test_voice_activity_short_pause(audio_in_socket, audio_out_socket, streaming):
"""
Test a scenario where there is a short pause between speech, checking whether it ignores the
short pause.
"""
speech_chunk_count = 5
probabilities = (
[0.0] * 5 + [1.0] * speech_chunk_count + [0.0] + [1.0] * speech_chunk_count + [0.0] * 5
)
await simulate_streaming_with_probabilities(streaming, probabilities)
audio_out_socket.send.assert_called_once()
data = audio_out_socket.send.call_args[0][0]
assert isinstance(data, bytes)
# Expecting 13 chunks (2*5 with speech, 1 pause between, 2 as padding)
assert len(data) == 512 * 4 * (speech_chunk_count * 2 + 1 + 2)
@pytest.mark.asyncio
async def test_no_data(audio_in_socket, audio_out_socket, streaming):
"""
Test a scenario where there is no data received. This should not cause errors.
"""
audio_in_poller = AsyncMock()
audio_in_poller.poll.return_value = None
streaming.audio_in_poller = audio_in_poller
await streaming.run()
audio_out_socket.send.assert_not_called()
assert len(streaming.audio_buffer) == 0

45
test/unit/conftest.py Normal file
View File

@@ -0,0 +1,45 @@
import sys
from unittest.mock import MagicMock
def pytest_configure(config):
"""
This hook runs at the start of the pytest session, before any tests are
collected. It mocks heavy or unavailable modules to prevent ImportErrors.
"""
# --- Mock spade and spade-bdi ---
mock_spade = MagicMock()
mock_spade.agent = MagicMock()
mock_spade.behaviour = MagicMock()
mock_spade_bdi = MagicMock()
mock_spade_bdi.bdi = MagicMock()
mock_spade.agent.Message = MagicMock()
mock_spade.behaviour.CyclicBehaviour = type("CyclicBehaviour", (object,), {})
mock_spade_bdi.bdi.BDIAgent = type("BDIAgent", (object,), {})
sys.modules["spade"] = mock_spade
sys.modules["spade.agent"] = mock_spade.agent
sys.modules["spade.behaviour"] = mock_spade.behaviour
sys.modules["spade_bdi"] = mock_spade_bdi
sys.modules["spade_bdi.bdi"] = mock_spade_bdi.bdi
# --- Mock the config module to prevent Pydantic ImportError ---
mock_config_module = MagicMock()
# The code under test does `from ... import settings`, so our mock module
# must have a `settings` attribute. We'll make it a MagicMock so we can
# configure it later in our tests using mocker.patch.
mock_config_module.settings = MagicMock()
sys.modules["control_backend.core.config"] = mock_config_module
# --- Mock torch and zmq for VAD ---
mock_torch = MagicMock()
mock_zmq = MagicMock()
mock_zmq.asyncio = mock_zmq
# In individual tests, these can be imported and the return values changed
sys.modules["torch"] = mock_torch
sys.modules["zmq"] = mock_zmq
sys.modules["zmq.asyncio"] = mock_zmq.asyncio