Merge remote-tracking branch 'origin/dev' into refactor/config-file

# Conflicts:
#	src/control_backend/agents/ri_communication_agent.py
#	src/control_backend/core/config.py
#	src/control_backend/main.py
This commit is contained in:
Twirre Meulenbelt
2025-11-19 17:30:48 +01:00
46 changed files with 1207 additions and 651 deletions

View File

@@ -0,0 +1,58 @@
import numpy as np
import pytest
from control_backend.agents.perception.transcription_agent.speech_recognizer import (
OpenAIWhisperSpeechRecognizer,
SpeechRecognizer,
)
@pytest.fixture(autouse=True)
def patch_sr_settings(monkeypatch):
# Patch the *module-local* settings that SpeechRecognizer imported
from control_backend.agents.perception.transcription_agent import speech_recognizer as sr
# Provide real numbers for everything _estimate_max_tokens() reads
monkeypatch.setattr(sr.settings.vad_settings, "sample_rate_hz", 16_000, raising=False)
monkeypatch.setattr(
sr.settings.behaviour_settings, "transcription_words_per_minute", 450, raising=False
)
monkeypatch.setattr(
sr.settings.behaviour_settings, "transcription_words_per_token", 0.75, raising=False
)
monkeypatch.setattr(
sr.settings.behaviour_settings, "transcription_token_buffer", 10, raising=False
)
def test_estimate_max_tokens():
"""Inputting one minute of audio, assuming 450 words per minute and adding a 10 token padding,
expecting 610 tokens."""
audio = np.empty(shape=(60 * 16_000), dtype=np.float32)
actual = SpeechRecognizer._estimate_max_tokens(audio)
assert actual == 610
assert isinstance(actual, int)
def test_get_decode_options():
"""Check whether the right decode options are given under different scenarios."""
audio = np.empty(shape=(60 * 16_000), dtype=np.float32)
# With the defaults, it should limit output length based on input size
recognizer = OpenAIWhisperSpeechRecognizer()
options = recognizer._get_decode_options(audio)
assert "sample_len" in options
assert isinstance(options["sample_len"], int)
# When explicitly enabled, it should limit output length based on input size
recognizer = OpenAIWhisperSpeechRecognizer(limit_output_length=True)
options = recognizer._get_decode_options(audio)
assert "sample_len" in options
assert isinstance(options["sample_len"], int)
# When disabled, it should not limit output length based on input size
assert "sample_rate" not in options

View File

@@ -0,0 +1,46 @@
from unittest.mock import AsyncMock, MagicMock
import pytest
import zmq
from control_backend.agents.perception.vad_agent import SocketPoller
@pytest.fixture
def socket():
return AsyncMock()
@pytest.mark.asyncio
async def test_socket_poller_with_data(socket, mocker):
socket_data = b"test"
socket.recv.return_value = socket_data
mock_poller: MagicMock = mocker.patch("control_backend.agents.perception.vad_agent.zmq.Poller")
mock_poller.return_value.poll.return_value = [(socket, zmq.POLLIN)]
poller = SocketPoller(socket)
# Calling `poll` twice to be able to check that the poller is reused
await poller.poll()
data = await poller.poll()
assert data == socket_data
# Ensure that the poller was reused
mock_poller.assert_called_once_with()
mock_poller.return_value.register.assert_called_once_with(socket, zmq.POLLIN)
assert socket.recv.call_count == 2
@pytest.mark.asyncio
async def test_socket_poller_no_data(socket, mocker):
mock_poller: MagicMock = mocker.patch("control_backend.agents.perception.vad_agent.zmq.Poller")
mock_poller.return_value.poll.return_value = []
poller = SocketPoller(socket)
data = await poller.poll()
assert data is None
socket.recv.assert_not_called()

View File

@@ -0,0 +1,121 @@
from unittest.mock import AsyncMock, MagicMock
import numpy as np
import pytest
from control_backend.agents.perception.vad_agent import StreamingBehaviour
@pytest.fixture
def audio_in_socket():
return AsyncMock()
@pytest.fixture
def audio_out_socket():
return AsyncMock()
@pytest.fixture
def mock_agent(mocker):
"""Fixture to create a mock BDIAgent."""
agent = MagicMock()
agent.jid = "vad_agent@test"
return agent
@pytest.fixture
def streaming(audio_in_socket, audio_out_socket, mock_agent):
import torch
torch.hub.load.return_value = (..., ...) # Mock
streaming = StreamingBehaviour(audio_in_socket, audio_out_socket)
streaming._ready = True
streaming.agent = mock_agent
return streaming
@pytest.fixture(autouse=True)
def patch_settings(monkeypatch):
# Patch the settings that vad_agent.run() reads
from control_backend.agents.perception import vad_agent
monkeypatch.setattr(
vad_agent.settings.behaviour_settings, "vad_prob_threshold", 0.5, raising=False
)
monkeypatch.setattr(
vad_agent.settings.behaviour_settings, "vad_non_speech_patience_chunks", 2, raising=False
)
monkeypatch.setattr(
vad_agent.settings.behaviour_settings, "vad_initial_since_speech", 0, raising=False
)
monkeypatch.setattr(vad_agent.settings.vad_settings, "sample_rate_hz", 16_000, raising=False)
async def simulate_streaming_with_probabilities(streaming, probabilities: list[float]):
"""
Simulates a streaming scenario with given VAD model probabilities for testing purposes.
:param streaming: The streaming component to be tested.
:param probabilities: A list of probabilities representing the outputs of the VAD model.
"""
model_item = MagicMock()
model_item.item.side_effect = probabilities
streaming.model = MagicMock()
streaming.model.return_value = model_item
audio_in_poller = AsyncMock()
audio_in_poller.poll.return_value = np.empty(shape=512, dtype=np.float32)
streaming.audio_in_poller = audio_in_poller
for _ in probabilities:
await streaming.run()
@pytest.mark.asyncio
async def test_voice_activity_detected(audio_in_socket, audio_out_socket, streaming):
"""
Test a scenario where there is voice activity detected between silences.
"""
speech_chunk_count = 5
probabilities = [0.0] * 5 + [1.0] * speech_chunk_count + [0.0] * 5
await simulate_streaming_with_probabilities(streaming, probabilities)
audio_out_socket.send.assert_called_once()
data = audio_out_socket.send.call_args[0][0]
assert isinstance(data, bytes)
assert len(data) == 512 * 4 * (speech_chunk_count + 1)
@pytest.mark.asyncio
async def test_voice_activity_short_pause(audio_in_socket, audio_out_socket, streaming):
"""
Test a scenario where there is a short pause between speech, checking whether it ignores the
short pause.
"""
speech_chunk_count = 5
probabilities = (
[0.0] * 5 + [1.0] * speech_chunk_count + [0.0] + [1.0] * speech_chunk_count + [0.0] * 5
)
await simulate_streaming_with_probabilities(streaming, probabilities)
audio_out_socket.send.assert_called_once()
data = audio_out_socket.send.call_args[0][0]
assert isinstance(data, bytes)
# Expecting 13 chunks (2*5 with speech, 1 pause between, 1 as padding)
assert len(data) == 512 * 4 * (speech_chunk_count * 2 + 1 + 1)
@pytest.mark.asyncio
async def test_no_data(audio_in_socket, audio_out_socket, streaming):
"""
Test a scenario where there is no data received. This should not cause errors.
"""
audio_in_poller = AsyncMock()
audio_in_poller.poll.return_value = None
streaming.audio_in_poller = audio_in_poller
await streaming.run()
audio_out_socket.send.assert_not_called()
assert len(streaming.audio_buffer) == 0