Merge remote-tracking branch 'origin/dev' into fix/bdi-correct-belief-management

This commit is contained in:
2025-10-29 13:25:58 +01:00
25 changed files with 1618 additions and 6 deletions

View File

View File

@@ -0,0 +1,102 @@
import asyncio
import zmq
import json
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from control_backend.agents.ri_command_agent import RICommandAgent
from control_backend.schemas.ri_message import SpeechCommand
@pytest.mark.asyncio
async def test_setup_bind(monkeypatch):
"""Test setup with bind=True"""
fake_socket = MagicMock()
monkeypatch.setattr(
"control_backend.agents.ri_command_agent.context.socket", lambda _: fake_socket
)
agent = RICommandAgent("test@server", "password", address="tcp://localhost:5555", bind=True)
monkeypatch.setattr(
"control_backend.agents.ri_command_agent.settings",
MagicMock(zmq_settings=MagicMock(internal_comm_address="tcp://internal:1234")),
)
await agent.setup()
# Ensure PUB socket bound
fake_socket.bind.assert_any_call("tcp://localhost:5555")
# Ensure SUB socket connected to internal address and subscribed
fake_socket.connect.assert_any_call("tcp://internal:1234")
fake_socket.setsockopt.assert_any_call(zmq.SUBSCRIBE, b"command")
# Ensure behaviour attached
assert any(isinstance(b, agent.SendCommandsBehaviour) for b in agent.behaviours)
@pytest.mark.asyncio
async def test_setup_connect(monkeypatch):
"""Test setup with bind=False"""
fake_socket = MagicMock()
monkeypatch.setattr(
"control_backend.agents.ri_command_agent.context.socket", lambda _: fake_socket
)
agent = RICommandAgent("test@server", "password", address="tcp://localhost:5555", bind=False)
monkeypatch.setattr(
"control_backend.agents.ri_command_agent.settings",
MagicMock(zmq_settings=MagicMock(internal_comm_address="tcp://internal:1234")),
)
await agent.setup()
# Ensure PUB socket connected
fake_socket.connect.assert_any_call("tcp://localhost:5555")
@pytest.mark.asyncio
async def test_send_commands_behaviour_valid_message():
"""Test behaviour with valid JSON message"""
fake_socket = AsyncMock()
message_dict = {"message": "hello"}
fake_socket.recv_multipart = AsyncMock(
return_value=(b"command", json.dumps(message_dict).encode("utf-8"))
)
fake_socket.send_json = AsyncMock()
agent = RICommandAgent("test@server", "password")
agent.subsocket = fake_socket
agent.pubsocket = fake_socket
behaviour = agent.SendCommandsBehaviour()
behaviour.agent = agent
with patch("control_backend.agents.ri_command_agent.SpeechCommand") as MockSpeechCommand:
mock_message = MagicMock()
MockSpeechCommand.model_validate.return_value = mock_message
await behaviour.run()
fake_socket.recv_multipart.assert_awaited()
fake_socket.send_json.assert_awaited_with(mock_message.model_dump())
@pytest.mark.asyncio
async def test_send_commands_behaviour_invalid_message(caplog):
"""Test behaviour with invalid JSON message triggers error logging"""
fake_socket = AsyncMock()
fake_socket.recv_multipart = AsyncMock(return_value=(b"command", b"{invalid_json}"))
fake_socket.send_json = AsyncMock()
agent = RICommandAgent("test@server", "password")
agent.subsocket = fake_socket
agent.pubsocket = fake_socket
behaviour = agent.SendCommandsBehaviour()
behaviour.agent = agent
with caplog.at_level("ERROR"):
await behaviour.run()
fake_socket.recv_multipart.assert_awaited()
fake_socket.send_json.assert_not_awaited()
assert "Error processing message" in caplog.text

View File

@@ -0,0 +1,591 @@
import asyncio
import pytest
from unittest.mock import AsyncMock, MagicMock, patch, ANY
from control_backend.agents.ri_communication_agent import RICommunicationAgent
def fake_json_correct_negototiate_1():
return AsyncMock(
return_value={
"endpoint": "negotiate/ports",
"data": [
{"id": "main", "port": 5555, "bind": False},
{"id": "actuation", "port": 5556, "bind": True},
],
}
)
def fake_json_correct_negototiate_2():
return AsyncMock(
return_value={
"endpoint": "negotiate/ports",
"data": [
{"id": "main", "port": 5555, "bind": False},
{"id": "actuation", "port": 5557, "bind": True},
],
}
)
def fake_json_correct_negototiate_3():
return AsyncMock(
return_value={
"endpoint": "negotiate/ports",
"data": [
{"id": "main", "port": 5555, "bind": True},
{"id": "actuation", "port": 5557, "bind": True},
],
}
)
def fake_json_correct_negototiate_4():
# Different port, do bind
return AsyncMock(
return_value={
"endpoint": "negotiate/ports",
"data": [
{"id": "main", "port": 4555, "bind": True},
{"id": "actuation", "port": 5557, "bind": True},
],
}
)
def fake_json_correct_negototiate_5():
# Different port, dont bind.
return AsyncMock(
return_value={
"endpoint": "negotiate/ports",
"data": [
{"id": "main", "port": 4555, "bind": False},
{"id": "actuation", "port": 5557, "bind": True},
],
}
)
def fake_json_wrong_negototiate_1():
return AsyncMock(return_value={"endpoint": "ping", "data": ""})
def fake_json_invalid_id_negototiate():
return AsyncMock(
return_value={
"endpoint": "negotiate/ports",
"data": [
{"id": "banana", "port": 4555, "bind": False},
{"id": "tomato", "port": 5557, "bind": True},
],
}
)
@pytest.mark.asyncio
async def test_setup_creates_socket_and_negotiate_1(monkeypatch):
"""
Test the setup of the communication agent
"""
# --- Arrange ---
fake_socket = MagicMock()
fake_socket.send_json = AsyncMock()
fake_socket.recv_json = fake_json_correct_negototiate_1()
# Mock context.socket to return our fake socket
monkeypatch.setattr(
"control_backend.agents.ri_communication_agent.context.socket", lambda _: fake_socket
)
# Mock RICommandAgent agent startup
with patch(
"control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True
) as MockCommandAgent:
fake_agent_instance = MockCommandAgent.return_value
fake_agent_instance.start = AsyncMock()
# --- Act ---
agent = RICommunicationAgent(
"test@server", "password", address="tcp://localhost:5555", bind=False
)
await agent.setup()
# --- Assert ---
fake_socket.connect.assert_any_call("tcp://localhost:5555")
fake_socket.send_json.assert_any_call({"endpoint": "negotiate/ports", "data": None})
fake_socket.recv_json.assert_awaited()
fake_agent_instance.start.assert_awaited()
MockCommandAgent.assert_called_once_with(
ANY, # Server Name
ANY, # Server Password
address="tcp://*:5556", # derived from the 'port' value in negotiation
bind=True,
)
# Ensure the agent attached a ListenBehaviour
assert any(isinstance(b, agent.ListenBehaviour) for b in agent.behaviours)
@pytest.mark.asyncio
async def test_setup_creates_socket_and_negotiate_2(monkeypatch):
"""
Test the setup of the communication agent
"""
# --- Arrange ---
fake_socket = MagicMock()
fake_socket.send_json = AsyncMock()
fake_socket.recv_json = fake_json_correct_negototiate_2()
# Mock context.socket to return our fake socket
monkeypatch.setattr(
"control_backend.agents.ri_communication_agent.context.socket", lambda _: fake_socket
)
# Mock RICommandAgent agent startup
with patch(
"control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True
) as MockCommandAgent:
fake_agent_instance = MockCommandAgent.return_value
fake_agent_instance.start = AsyncMock()
# --- Act ---
agent = RICommunicationAgent(
"test@server", "password", address="tcp://localhost:5555", bind=False
)
await agent.setup()
# --- Assert ---
fake_socket.connect.assert_any_call("tcp://localhost:5555")
fake_socket.send_json.assert_any_call({"endpoint": "negotiate/ports", "data": None})
fake_socket.recv_json.assert_awaited()
fake_agent_instance.start.assert_awaited()
MockCommandAgent.assert_called_once_with(
ANY, # Server Name
ANY, # Server Password
address="tcp://*:5557", # derived from the 'port' value in negotiation
bind=True,
)
# Ensure the agent attached a ListenBehaviour
assert any(isinstance(b, agent.ListenBehaviour) for b in agent.behaviours)
@pytest.mark.asyncio
async def test_setup_creates_socket_and_negotiate_3(monkeypatch, caplog):
"""
Test the functionality of setup with incorrect negotiation message
"""
# --- Arrange ---
fake_socket = MagicMock()
fake_socket.send_json = AsyncMock()
fake_socket.recv_json = fake_json_wrong_negototiate_1()
# Mock context.socket to return our fake socket
monkeypatch.setattr(
"control_backend.agents.ri_communication_agent.context.socket", lambda _: fake_socket
)
# Mock RICommandAgent agent startup
# We are sending wrong negotiation info to the communication agent, so we should retry and expect a
# better response, within a limited time.
with patch(
"control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True
) as MockCommandAgent:
fake_agent_instance = MockCommandAgent.return_value
fake_agent_instance.start = AsyncMock()
# --- Act ---
with caplog.at_level("ERROR"):
agent = RICommunicationAgent(
"test@server", "password", address="tcp://localhost:5555", bind=False
)
await agent.setup(max_retries=1)
# --- Assert ---
fake_socket.connect.assert_any_call("tcp://localhost:5555")
fake_socket.recv_json.assert_awaited()
# Since it failed, there should not be any command agent.
fake_agent_instance.start.assert_not_awaited()
assert "Failed to set up RICommunicationAgent" in caplog.text
# Ensure the agent did not attach a ListenBehaviour
assert not any(isinstance(b, agent.ListenBehaviour) for b in agent.behaviours)
@pytest.mark.asyncio
async def test_setup_creates_socket_and_negotiate_4(monkeypatch):
"""
Test the setup of the communication agent with different bind value
"""
# --- Arrange ---
fake_socket = MagicMock()
fake_socket.send_json = AsyncMock()
fake_socket.recv_json = fake_json_correct_negototiate_3()
# Mock context.socket to return our fake socket
monkeypatch.setattr(
"control_backend.agents.ri_communication_agent.context.socket", lambda _: fake_socket
)
# Mock RICommandAgent agent startup
with patch(
"control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True
) as MockCommandAgent:
fake_agent_instance = MockCommandAgent.return_value
fake_agent_instance.start = AsyncMock()
# --- Act ---
agent = RICommunicationAgent(
"test@server", "password", address="tcp://localhost:5555", bind=True
)
await agent.setup()
# --- Assert ---
fake_socket.bind.assert_any_call("tcp://localhost:5555")
fake_socket.send_json.assert_any_call({"endpoint": "negotiate/ports", "data": None})
fake_socket.recv_json.assert_awaited()
fake_agent_instance.start.assert_awaited()
MockCommandAgent.assert_called_once_with(
ANY, # Server Name
ANY, # Server Password
address="tcp://*:5557", # derived from the 'port' value in negotiation
bind=True,
)
# Ensure the agent attached a ListenBehaviour
assert any(isinstance(b, agent.ListenBehaviour) for b in agent.behaviours)
@pytest.mark.asyncio
async def test_setup_creates_socket_and_negotiate_5(monkeypatch):
"""
Test the setup of the communication agent
"""
# --- Arrange ---
fake_socket = MagicMock()
fake_socket.send_json = AsyncMock()
fake_socket.recv_json = fake_json_correct_negototiate_4()
# Mock context.socket to return our fake socket
monkeypatch.setattr(
"control_backend.agents.ri_communication_agent.context.socket", lambda _: fake_socket
)
# Mock RICommandAgent agent startup
with patch(
"control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True
) as MockCommandAgent:
fake_agent_instance = MockCommandAgent.return_value
fake_agent_instance.start = AsyncMock()
# --- Act ---
agent = RICommunicationAgent(
"test@server", "password", address="tcp://localhost:5555", bind=False
)
await agent.setup()
# --- Assert ---
fake_socket.connect.assert_any_call("tcp://localhost:5555")
fake_socket.send_json.assert_any_call({"endpoint": "negotiate/ports", "data": None})
fake_socket.recv_json.assert_awaited()
fake_agent_instance.start.assert_awaited()
MockCommandAgent.assert_called_once_with(
ANY, # Server Name
ANY, # Server Password
address="tcp://*:5557", # derived from the 'port' value in negotiation
bind=True,
)
# Ensure the agent attached a ListenBehaviour
assert any(isinstance(b, agent.ListenBehaviour) for b in agent.behaviours)
@pytest.mark.asyncio
async def test_setup_creates_socket_and_negotiate_6(monkeypatch):
"""
Test the setup of the communication agent
"""
# --- Arrange ---
fake_socket = MagicMock()
fake_socket.send_json = AsyncMock()
fake_socket.recv_json = fake_json_correct_negototiate_5()
# Mock context.socket to return our fake socket
monkeypatch.setattr(
"control_backend.agents.ri_communication_agent.context.socket", lambda _: fake_socket
)
# Mock RICommandAgent agent startup
with patch(
"control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True
) as MockCommandAgent:
fake_agent_instance = MockCommandAgent.return_value
fake_agent_instance.start = AsyncMock()
# --- Act ---
agent = RICommunicationAgent(
"test@server", "password", address="tcp://localhost:5555", bind=False
)
await agent.setup()
# --- Assert ---
fake_socket.connect.assert_any_call("tcp://localhost:5555")
fake_socket.send_json.assert_any_call({"endpoint": "negotiate/ports", "data": None})
fake_socket.recv_json.assert_awaited()
fake_agent_instance.start.assert_awaited()
MockCommandAgent.assert_called_once_with(
ANY, # Server Name
ANY, # Server Password
address="tcp://*:5557", # derived from the 'port' value in negotiation
bind=True,
)
# Ensure the agent attached a ListenBehaviour
assert any(isinstance(b, agent.ListenBehaviour) for b in agent.behaviours)
@pytest.mark.asyncio
async def test_setup_creates_socket_and_negotiate_7(monkeypatch, caplog):
"""
Test the functionality of setup with incorrect id
"""
# --- Arrange ---
fake_socket = MagicMock()
fake_socket.send_json = AsyncMock()
fake_socket.recv_json = fake_json_invalid_id_negototiate()
# Mock context.socket to return our fake socket
monkeypatch.setattr(
"control_backend.agents.ri_communication_agent.context.socket", lambda _: fake_socket
)
# Mock RICommandAgent agent startup
# We are sending wrong negotiation info to the communication agent, so we should retry and expect a
# better response, within a limited time.
with patch(
"control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True
) as MockCommandAgent:
fake_agent_instance = MockCommandAgent.return_value
fake_agent_instance.start = AsyncMock()
# --- Act ---
with caplog.at_level("WARNING"):
agent = RICommunicationAgent(
"test@server", "password", address="tcp://localhost:5555", bind=False
)
await agent.setup(max_retries=1)
# --- Assert ---
fake_socket.connect.assert_any_call("tcp://localhost:5555")
fake_socket.recv_json.assert_awaited()
# Since it failed, there should not be any command agent.
fake_agent_instance.start.assert_not_awaited()
assert "Unhandled negotiation id:" in caplog.text
@pytest.mark.asyncio
async def test_setup_creates_socket_and_negotiate_timeout(monkeypatch, caplog):
"""
Test the functionality of setup with incorrect negotiation message
"""
# --- Arrange ---
fake_socket = MagicMock()
fake_socket.send_json = AsyncMock()
fake_socket.recv_json = AsyncMock(side_effect=asyncio.TimeoutError)
# Mock context.socket to return our fake socket
monkeypatch.setattr(
"control_backend.agents.ri_communication_agent.context.socket", lambda _: fake_socket
)
with patch(
"control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True
) as MockCommandAgent:
fake_agent_instance = MockCommandAgent.return_value
fake_agent_instance.start = AsyncMock()
# --- Act ---
with caplog.at_level("WARNING"):
agent = RICommunicationAgent(
"test@server", "password", address="tcp://localhost:5555", bind=False
)
await agent.setup(max_retries=1)
# --- Assert ---
fake_socket.connect.assert_any_call("tcp://localhost:5555")
# Since it failed, there should not be any command agent.
fake_agent_instance.start.assert_not_awaited()
assert "No connection established in 20 seconds" in caplog.text
# Ensure the agent did not attach a ListenBehaviour
assert not any(isinstance(b, agent.ListenBehaviour) for b in agent.behaviours)
@pytest.mark.asyncio
async def test_listen_behaviour_ping_correct(caplog):
fake_socket = AsyncMock()
fake_socket.send_json = AsyncMock()
fake_socket.recv_json = AsyncMock(return_value={"endpoint": "ping", "data": {}})
# TODO: Integration test between actual server and password needed for spade agents
agent = RICommunicationAgent("test@server", "password")
agent.req_socket = fake_socket
behaviour = agent.ListenBehaviour()
agent.add_behaviour(behaviour)
# Run once (CyclicBehaviour normally loops)
with caplog.at_level("DEBUG"):
await behaviour.run()
fake_socket.send_json.assert_awaited()
fake_socket.recv_json.assert_awaited()
assert "Received message" in caplog.text
@pytest.mark.asyncio
async def test_listen_behaviour_ping_wrong_endpoint(caplog):
"""
Test if our listen behaviour can work with wrong messages (wrong endpoint)
"""
fake_socket = AsyncMock()
fake_socket.send_json = AsyncMock()
# This is a message for the wrong endpoint >:(
fake_socket.recv_json = AsyncMock(
return_value={
"endpoint": "negotiate/ports",
"data": [
{"id": "main", "port": 5555, "bind": False},
{"id": "actuation", "port": 5556, "bind": True},
],
}
)
agent = RICommunicationAgent("test@server", "password")
agent.req_socket = fake_socket
behaviour = agent.ListenBehaviour()
agent.add_behaviour(behaviour)
# Run once (CyclicBehaviour normally loops)
with caplog.at_level("INFO"):
await behaviour.run()
assert "Received message with topic different than ping, while ping expected." in caplog.text
fake_socket.send_json.assert_awaited()
fake_socket.recv_json.assert_awaited()
@pytest.mark.asyncio
async def test_listen_behaviour_timeout(caplog):
fake_socket = AsyncMock()
fake_socket.send_json = AsyncMock()
# recv_json will never resolve, simulate timeout
fake_socket.recv_json = AsyncMock(side_effect=asyncio.TimeoutError)
agent = RICommunicationAgent("test@server", "password")
agent.req_socket = fake_socket
behaviour = agent.ListenBehaviour()
agent.add_behaviour(behaviour)
with caplog.at_level("INFO"):
await behaviour.run()
assert "No ping retrieved in 3 seconds" in caplog.text
@pytest.mark.asyncio
async def test_listen_behaviour_ping_no_endpoint(caplog):
"""
Test if our listen behaviour can work with wrong messages (wrong endpoint)
"""
fake_socket = AsyncMock()
fake_socket.send_json = AsyncMock()
# This is a message without endpoint >:(
fake_socket.recv_json = AsyncMock(
return_value={
"data": "I dont have an endpoint >:)",
}
)
agent = RICommunicationAgent("test@server", "password")
agent.req_socket = fake_socket
behaviour = agent.ListenBehaviour()
agent.add_behaviour(behaviour)
# Run once (CyclicBehaviour normally loops)
with caplog.at_level("ERROR"):
await behaviour.run()
assert "No received endpoint in message, excepted ping endpoint." in caplog.text
fake_socket.send_json.assert_awaited()
fake_socket.recv_json.assert_awaited()
@pytest.mark.asyncio
async def test_setup_unexpected_exception(monkeypatch, caplog):
fake_socket = MagicMock()
fake_socket.send_json = AsyncMock()
# Simulate unexpected exception during recv_json()
fake_socket.recv_json = AsyncMock(side_effect=Exception("boom!"))
monkeypatch.setattr(
"control_backend.agents.ri_communication_agent.context.socket", lambda _: fake_socket
)
agent = RICommunicationAgent(
"test@server", "password", address="tcp://localhost:5555", bind=False
)
with caplog.at_level("ERROR"):
await agent.setup(max_retries=1)
# Ensure that the error was logged
assert "Unexpected error during negotiation: boom!" in caplog.text
@pytest.mark.asyncio
async def test_setup_unpacking_exception(monkeypatch, caplog):
# --- Arrange ---
fake_socket = MagicMock()
fake_socket.send_json = AsyncMock()
# Make recv_json return malformed negotiation data to trigger unpacking exception
malformed_data = {
"endpoint": "negotiate/ports",
"data": [{"id": "main"}],
} # missing 'port' and 'bind'
fake_socket.recv_json = AsyncMock(return_value=malformed_data)
# Patch context.socket
monkeypatch.setattr(
"control_backend.agents.ri_communication_agent.context.socket", lambda _: fake_socket
)
# Patch RICommandAgent so it won't actually start
with patch(
"control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True
) as MockCommandAgent:
fake_agent_instance = MockCommandAgent.return_value
fake_agent_instance.start = AsyncMock()
agent = RICommunicationAgent(
"test@server", "password", address="tcp://localhost:5555", bind=False
)
# --- Act & Assert ---
with caplog.at_level("ERROR"):
await agent.setup(max_retries=1)
# Ensure the unpacking exception was logged
assert "Error unpacking negotiation data" in caplog.text
# Ensure no command agent was started
fake_agent_instance.start.assert_not_awaited()
# Ensure no behaviour was attached
assert not any(isinstance(b, agent.ListenBehaviour) for b in agent.behaviours)

View File

@@ -0,0 +1,99 @@
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
import zmq
from spade.agent import Agent
from control_backend.agents.vad_agent import VADAgent
@pytest.fixture
def zmq_context(mocker):
return mocker.patch("control_backend.agents.vad_agent.zmq_context")
@pytest.fixture
def streaming(mocker):
return mocker.patch("control_backend.agents.vad_agent.Streaming")
@pytest.mark.asyncio
async def test_normal_setup(streaming):
"""
Test that during normal setup, the VAD agent creates a Streaming behavior and creates audio
sockets.
"""
vad_agent = VADAgent("tcp://localhost:12345", False)
vad_agent.add_behaviour = MagicMock()
await vad_agent.setup()
streaming.assert_called_once()
vad_agent.add_behaviour.assert_called_once_with(streaming.return_value)
assert vad_agent.audio_in_socket is not None
assert vad_agent.audio_out_socket is not None
@pytest.mark.parametrize("do_bind", [True, False])
def test_in_socket_creation(zmq_context, do_bind: bool):
"""
Test that the VAD agent creates an audio input socket, differentiating between binding and
connecting.
"""
vad_agent = VADAgent(f"tcp://{'*' if do_bind else 'localhost'}:12345", do_bind)
vad_agent._connect_audio_in_socket()
assert vad_agent.audio_in_socket is not None
zmq_context.socket.assert_called_once_with(zmq.SUB)
zmq_context.socket.return_value.setsockopt_string.assert_called_once_with(zmq.SUBSCRIBE, "")
if do_bind:
zmq_context.socket.return_value.bind.assert_called_once_with("tcp://*:12345")
else:
zmq_context.socket.return_value.connect.assert_called_once_with("tcp://localhost:12345")
def test_out_socket_creation(zmq_context):
"""
Test that the VAD agent creates an audio output socket correctly.
"""
vad_agent = VADAgent("tcp://localhost:12345", False)
vad_agent._connect_audio_out_socket()
assert vad_agent.audio_out_socket is not None
zmq_context.socket.assert_called_once_with(zmq.PUB)
zmq_context.socket.return_value.bind_to_random_port.assert_called_once()
@pytest.mark.asyncio
async def test_out_socket_creation_failure(zmq_context):
"""
Test setup failure when the audio output socket cannot be created.
"""
with patch.object(Agent, "stop", new_callable=AsyncMock) as mock_super_stop:
zmq_context.socket.return_value.bind_to_random_port.side_effect = zmq.ZMQBindError
vad_agent = VADAgent("tcp://localhost:12345", False)
await vad_agent.setup()
assert vad_agent.audio_out_socket is None
mock_super_stop.assert_called_once()
@pytest.mark.asyncio
async def test_stop(zmq_context):
"""
Test that when the VAD agent is stopped, the sockets are closed correctly.
"""
vad_agent = VADAgent("tcp://localhost:12345", False)
await vad_agent.setup()
await vad_agent.stop()
assert zmq_context.socket.return_value.close.call_count == 2
assert vad_agent.audio_in_socket is None
assert vad_agent.audio_out_socket is None

View File

@@ -0,0 +1,57 @@
import os
from unittest.mock import AsyncMock, MagicMock
import pytest
import soundfile as sf
import zmq
from control_backend.agents.vad_agent import Streaming
def get_audio_chunks() -> list[bytes]:
curr_file = os.path.realpath(__file__)
curr_dir = os.path.dirname(curr_file)
file = f"{curr_dir}/speech_with_pauses_16k_1c_float32.wav"
chunk_size = 512
chunks = []
with sf.SoundFile(file, "r") as f:
assert f.samplerate == 16000
assert f.channels == 1
assert f.subtype == "FLOAT"
while True:
data = f.read(chunk_size, dtype="float32")
if len(data) != chunk_size:
break
chunks.append(data.tobytes())
return chunks
@pytest.mark.asyncio
async def test_real_audio(mocker):
"""
Test the VAD agent with only input and output mocked. Using the real model, using real audio as
input. Ensure that it outputs some fragments with audio.
"""
audio_chunks = get_audio_chunks()
audio_in_socket = AsyncMock()
audio_in_socket.recv.side_effect = audio_chunks
mock_poller: MagicMock = mocker.patch("control_backend.agents.vad_agent.zmq.Poller")
mock_poller.return_value.poll.return_value = [(audio_in_socket, zmq.POLLIN)]
audio_out_socket = AsyncMock()
vad_streamer = Streaming(audio_in_socket, audio_out_socket)
for _ in audio_chunks:
await vad_streamer.run()
audio_out_socket.send.assert_called()
for args in audio_out_socket.send.call_args_list:
assert isinstance(args[0][0], bytes)
assert len(args[0][0]) >= 512 * 4 * 3 # Should be at least 3 chunks of audio

View File

@@ -0,0 +1,63 @@
import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
from unittest.mock import MagicMock
from control_backend.api.v1.endpoints import command
from control_backend.schemas.ri_message import SpeechCommand
@pytest.fixture
def app():
"""
Creates a FastAPI test app and attaches the router under test.
Also sets up a mock internal_comm_socket.
"""
app = FastAPI()
app.include_router(command.router)
app.state.internal_comm_socket = MagicMock() # mock ZMQ socket
return app
@pytest.fixture
def client(app):
"""Create a test client for the app."""
return TestClient(app)
def test_receive_command_endpoint(client, app):
"""
Test that a POST to /command sends the right multipart message
and returns a 202 with the expected JSON body.
"""
mock_socket = app.state.internal_comm_socket
# Prepare test payload that matches SpeechCommand
payload = {"endpoint": "actuate/speech", "data": "yooo"}
# Send POST request
response = client.post("/command", json=payload)
# Check response
assert response.status_code == 202
assert response.json() == {"status": "Command received"}
# Verify that the socket was called with the correct data
assert mock_socket.send_multipart.called, "Socket should be used to send data"
args, kwargs = mock_socket.send_multipart.call_args
sent_data = args[0]
assert sent_data[0] == b"command"
# Check JSON encoding roughly matches
assert isinstance(SpeechCommand.model_validate_json(sent_data[1].decode()), SpeechCommand)
def test_receive_command_invalid_payload(client):
"""
Test invalid data handling (schema validation).
"""
# Missing required field(s)
bad_payload = {"invalid": "data"}
response = client.post("/command", json=bad_payload)
assert response.status_code == 422 # validation error

View File

@@ -0,0 +1,36 @@
import pytest
from control_backend.schemas.ri_message import RIMessage, RIEndpoint, SpeechCommand
from pydantic import ValidationError
def valid_command_1():
return SpeechCommand(data="Hallo?")
def invalid_command_1():
return RIMessage(endpoint=RIEndpoint.PING, data="Hello again.")
def test_valid_speech_command_1():
command = valid_command_1()
try:
RIMessage.model_validate(command)
SpeechCommand.model_validate(command)
assert True
except ValidationError:
assert False
def test_invalid_speech_command_1():
command = invalid_command_1()
passed_ri_message_validation = False
try:
# Should succeed, still.
RIMessage.model_validate(command)
passed_ri_message_validation = True
# Should fail.
SpeechCommand.model_validate(command)
assert False
except ValidationError:
assert passed_ri_message_validation

View File

@@ -0,0 +1,46 @@
from unittest.mock import AsyncMock, MagicMock
import pytest
import zmq
from control_backend.agents.vad_agent import SocketPoller
@pytest.fixture
def socket():
return AsyncMock()
@pytest.mark.asyncio
async def test_socket_poller_with_data(socket, mocker):
socket_data = b"test"
socket.recv.return_value = socket_data
mock_poller: MagicMock = mocker.patch("control_backend.agents.vad_agent.zmq.Poller")
mock_poller.return_value.poll.return_value = [(socket, zmq.POLLIN)]
poller = SocketPoller(socket)
# Calling `poll` twice to be able to check that the poller is reused
await poller.poll()
data = await poller.poll()
assert data == socket_data
# Ensure that the poller was reused
mock_poller.assert_called_once_with()
mock_poller.return_value.register.assert_called_once_with(socket, zmq.POLLIN)
assert socket.recv.call_count == 2
@pytest.mark.asyncio
async def test_socket_poller_no_data(socket, mocker):
mock_poller: MagicMock = mocker.patch("control_backend.agents.vad_agent.zmq.Poller")
mock_poller.return_value.poll.return_value = []
poller = SocketPoller(socket)
data = await poller.poll()
assert data is None
socket.recv.assert_not_called()

View File

@@ -0,0 +1,95 @@
from unittest.mock import AsyncMock, MagicMock
import numpy as np
import pytest
from control_backend.agents.vad_agent import Streaming
@pytest.fixture
def audio_in_socket():
return AsyncMock()
@pytest.fixture
def audio_out_socket():
return AsyncMock()
@pytest.fixture
def streaming(audio_in_socket, audio_out_socket):
import torch
torch.hub.load.return_value = (..., ...) # Mock
return Streaming(audio_in_socket, audio_out_socket)
async def simulate_streaming_with_probabilities(streaming, probabilities: list[float]):
"""
Simulates a streaming scenario with given VAD model probabilities for testing purposes.
:param streaming: The streaming component to be tested.
:param probabilities: A list of probabilities representing the outputs of the VAD model.
"""
model_item = MagicMock()
model_item.item.side_effect = probabilities
streaming.model = MagicMock()
streaming.model.return_value = model_item
audio_in_poller = AsyncMock()
audio_in_poller.poll.return_value = np.empty(shape=512, dtype=np.float32)
streaming.audio_in_poller = audio_in_poller
for _ in probabilities:
await streaming.run()
@pytest.mark.asyncio
async def test_voice_activity_detected(audio_in_socket, audio_out_socket, streaming):
"""
Test a scenario where there is voice activity detected between silences.
:return:
"""
speech_chunk_count = 5
probabilities = [0.0] * 5 + [1.0] * speech_chunk_count + [0.0] * 5
await simulate_streaming_with_probabilities(streaming, probabilities)
audio_out_socket.send.assert_called_once()
data = audio_out_socket.send.call_args[0][0]
assert isinstance(data, bytes)
# each sample has 512 frames of 4 bytes, expecting 7 chunks (5 with speech, 2 as padding)
assert len(data) == 512 * 4 * (speech_chunk_count + 2)
@pytest.mark.asyncio
async def test_voice_activity_short_pause(audio_in_socket, audio_out_socket, streaming):
"""
Test a scenario where there is a short pause between speech, checking whether it ignores the
short pause.
"""
speech_chunk_count = 5
probabilities = (
[0.0] * 5 + [1.0] * speech_chunk_count + [0.0] + [1.0] * speech_chunk_count + [0.0] * 5
)
await simulate_streaming_with_probabilities(streaming, probabilities)
audio_out_socket.send.assert_called_once()
data = audio_out_socket.send.call_args[0][0]
assert isinstance(data, bytes)
# Expecting 13 chunks (2*5 with speech, 1 pause between, 2 as padding)
assert len(data) == 512 * 4 * (speech_chunk_count * 2 + 1 + 2)
@pytest.mark.asyncio
async def test_no_data(audio_in_socket, audio_out_socket, streaming):
"""
Test a scenario where there is no data received. This should not cause errors.
"""
audio_in_poller = AsyncMock()
audio_in_poller.poll.return_value = None
streaming.audio_in_poller = audio_in_poller
await streaming.run()
audio_out_socket.send.assert_not_called()
assert len(streaming.audio_buffer) == 0

View File

@@ -33,3 +33,13 @@ def pytest_configure(config):
mock_config_module.settings = MagicMock()
sys.modules["control_backend.core.config"] = mock_config_module
# --- Mock torch and zmq for VAD ---
mock_torch = MagicMock()
mock_zmq = MagicMock()
mock_zmq.asyncio = mock_zmq
# In individual tests, these can be imported and the return values changed
sys.modules["torch"] = mock_torch
sys.modules["zmq"] = mock_zmq
sys.modules["zmq.asyncio"] = mock_zmq.asyncio