Files
pepperplus-cb/test/unit/agents/communication/test_ri_communication_agent.py
2026-01-30 16:53:15 +00:00

403 lines
12 KiB
Python

"""
This program has been developed by students from the bachelor Computer Science at Utrecht
University within the Software Project course.
© Copyright Utrecht University (Department of Information and Computing Sciences)
"""
import asyncio
from unittest.mock import ANY, AsyncMock, MagicMock, patch
import pytest
from control_backend.agents.communication.ri_communication_agent import RICommunicationAgent
def speech_agent_path():
    """Return the dotted path under which the module under test references ``RobotSpeechAgent``.

    Tests hand this to ``unittest.mock.patch`` so that ``RICommunicationAgent``
    sees the mock class instead of the real agent.
    """
    module = "control_backend.agents.communication.ri_communication_agent"
    return f"{module}.RobotSpeechAgent"
def gesture_agent_path():
    """Return the dotted path under which the module under test references ``RobotGestureAgent``."""
    parts = (
        "control_backend",
        "agents",
        "communication",
        "ri_communication_agent",
        "RobotGestureAgent",
    )
    return ".".join(parts)
@pytest.fixture
def zmq_context(mocker):
    """Patch ``Context.instance`` in the module under test.

    Returns the patched ``Context.instance`` mock. Tests reach the socket mock
    through ``zmq_context.return_value.socket.return_value``.
    """
    patched_instance = mocker.patch(
        "control_backend.agents.communication.ri_communication_agent.Context.instance"
    )
    # A fixed return value means every ``Context.instance()`` call yields the
    # same configurable context object.
    patched_instance.return_value = MagicMock()
    return patched_instance
def negotiation_message(
    actuation_port: int = 5556,
    bind_main: bool = False,
    bind_actuation: bool = False,
    main_port: int = 5555,
):
    """Build a fake ``negotiate/ports`` reply in the shape the agent's negotiation code consumes.

    Args:
        actuation_port: Port advertised for the ``actuation`` channel.
        bind_main: ``bind`` flag for the ``main`` channel.
        bind_actuation: ``bind`` flag for the ``actuation`` channel.
        main_port: Port advertised for the ``main`` channel.

    Returns:
        A dict with an ``endpoint`` key and a ``data`` list containing one
        ``{"id", "port", "bind"}`` entry for ``main`` followed by one for
        ``actuation``.
    """
    channels = (
        ("main", main_port, bind_main),
        ("actuation", actuation_port, bind_actuation),
    )
    return {
        "endpoint": "negotiate/ports",
        "data": [{"id": name, "port": port, "bind": bind} for name, port, bind in channels],
    }
@pytest.mark.asyncio
async def test_setup_success_connects_and_starts_robot(zmq_context):
    """A successful negotiation connects the socket and starts speech and gesture agents."""
    sock = zmq_context.return_value.socket.return_value
    sock.send_json = AsyncMock()
    sock.recv_json = AsyncMock(return_value=negotiation_message())
    sock.send_multipart = AsyncMock()

    speech_patch = patch(speech_agent_path(), autospec=True)
    gesture_patch = patch(gesture_agent_path(), autospec=True)
    with speech_patch as speech_cls, gesture_patch as gesture_cls:
        speech_cls.return_value.start = AsyncMock()
        gesture_cls.return_value.start = AsyncMock()

        agent = RICommunicationAgent("ri_comm", address="tcp://localhost:5555", bind=False)

        def discard_behavior(coroutine):
            # Close the coroutine instead of scheduling it, so the behavior
            # never runs during the test.
            coroutine.close()
            return MagicMock()

        agent.add_behavior = MagicMock(side_effect=discard_behavior)
        await agent.setup()

    sock.connect.assert_any_call("tcp://localhost:5555")
    sock.send_json.assert_any_call({"endpoint": "negotiate/ports", "data": {}})

    speech_cls.assert_called_once_with(ANY, address="tcp://localhost:5556", bind=False)
    gesture_cls.assert_called_once_with(
        ANY,
        address="tcp://localhost:5556",
        bind=False,
        gesture_data=[],
        single_gesture_data=[],
    )
    speech_cls.return_value.start.assert_awaited_once()
    gesture_cls.return_value.start.assert_awaited_once()

    agent.add_behavior.assert_called_once()
    assert agent.connected is True
@pytest.mark.asyncio
async def test_setup_binds_when_requested(zmq_context):
    """With ``bind=True`` the agent binds its socket to the configured address."""
    sock = zmq_context.return_value.socket.return_value
    sock.send_json = AsyncMock()
    sock.recv_json = AsyncMock(return_value=negotiation_message(bind_main=True))
    sock.send_multipart = AsyncMock()

    agent = RICommunicationAgent("ri_comm", address="tcp://localhost:5555", bind=True)

    def discard_behavior(coroutine):
        # Close the coroutine rather than running it as a background behavior.
        coroutine.close()
        return MagicMock()

    agent.add_behavior = MagicMock(side_effect=discard_behavior)

    with (
        patch(speech_agent_path(), autospec=True) as speech_cls,
        patch(gesture_agent_path(), autospec=True) as gesture_cls,
    ):
        for agent_cls in (speech_cls, gesture_cls):
            agent_cls.return_value.start = AsyncMock()
        await agent.setup()

    sock.bind.assert_any_call("tcp://localhost:5555")
    agent.add_behavior.assert_called_once()
@pytest.mark.asyncio
async def test_negotiate_invalid_endpoint_retries(zmq_context):
    """A reply on an unexpected endpoint makes ``_negotiate_connection(max_retries=1)`` return False."""
    sock = zmq_context.return_value.socket.return_value
    sock.send_json = AsyncMock()
    sock.recv_json = AsyncMock(return_value={"endpoint": "ping", "data": {}})
    sock.send_multipart = AsyncMock()

    agent = RICommunicationAgent("ri_comm", address="tcp://localhost:5555", bind=False)
    agent._req_socket = sock

    assert await agent._negotiate_connection(max_retries=1) is False
@pytest.mark.asyncio
async def test_negotiate_timeout(zmq_context):
    """A receive timeout during negotiation makes ``_negotiate_connection`` return False."""
    req = zmq_context.return_value.socket.return_value
    req.send_json = AsyncMock()
    req.recv_json = AsyncMock(side_effect=asyncio.TimeoutError)
    req.send_multipart = AsyncMock()

    agent = RICommunicationAgent("ri_comm", address="tcp://localhost:5555", bind=False)
    agent._req_socket = req

    outcome = await agent._negotiate_connection(max_retries=1)

    assert outcome is False
@pytest.mark.asyncio
async def test_handle_negotiation_response_updates_req_socket(zmq_context):
    """The ``main`` entry of a negotiation reply makes the REQ socket connect to the advertised port."""
    req = zmq_context.return_value.socket.return_value
    agent = RICommunicationAgent("ri_comm", address="tcp://localhost:5555", bind=False)
    agent._req_socket = req

    reply = negotiation_message(
        main_port=6000,
        actuation_port=6001,
        bind_main=False,
        bind_actuation=False,
    )

    with (
        patch(speech_agent_path(), autospec=True) as speech_cls,
        patch(gesture_agent_path(), autospec=True) as gesture_cls,
    ):
        speech_cls.return_value.start = AsyncMock()
        gesture_cls.return_value.start = AsyncMock()
        await agent._handle_negotiation_response(reply)

    req.connect.assert_any_call("tcp://localhost:6000")
@pytest.mark.asyncio
async def test_handle_disconnection_publishes_and_reconnects():
    """``_handle_disconnection`` publishes on the pub socket and attempts to renegotiate.

    ``connected`` starts out True, so checking ``connected is True`` afterwards
    cannot tell on its own whether a reconnect happened. The awaited
    ``_negotiate_connection`` mock is what pins the reconnect attempt.
    """
    pub_socket = AsyncMock()
    pub_socket.close = MagicMock()

    agent = RICommunicationAgent("ri_comm")
    agent.pub_socket = pub_socket
    agent.connected = True
    agent._negotiate_connection = AsyncMock(return_value=True)

    await agent._handle_disconnection()

    pub_socket.send_multipart.assert_awaited()
    agent._negotiate_connection.assert_awaited()
    assert agent.connected is True
@pytest.mark.asyncio
async def test_listen_loop_handles_non_ping(zmq_context):
    """A non-ping reply does not break the listen loop, which exits once stopped.

    ``send_json`` is an ``AsyncMock``. ``assert_awaited`` (rather than
    ``assert_called``) verifies the request coroutine was actually awaited,
    not merely created.
    """
    fake_socket = zmq_context.return_value.socket.return_value
    fake_socket.send_json = AsyncMock()

    async def recv_once():
        # Stop the loop after this single iteration.
        agent._running = False
        return {"endpoint": "negotiate/ports", "data": {}}

    fake_socket.recv_json = recv_once

    agent = RICommunicationAgent("ri_comm")
    agent._req_socket = fake_socket
    agent.pub_socket = AsyncMock()
    agent.connected = True
    agent._running = True

    await agent._listen_loop()

    fake_socket.send_json.assert_awaited()
@pytest.mark.asyncio
async def test_negotiate_unexpected_error(zmq_context):
    """An arbitrary exception while receiving the negotiation reply yields False."""
    req = zmq_context.return_value.socket.return_value
    req.send_json = AsyncMock()
    req.recv_json = AsyncMock(side_effect=Exception("boom"))

    agent = RICommunicationAgent("ri_comm")
    agent._req_socket = req

    result = await agent._negotiate_connection(max_retries=1)
    assert result is False
@pytest.mark.asyncio
async def test_negotiate_handle_response_error(zmq_context):
    """If processing a well-formed reply raises, ``_negotiate_connection`` reports failure."""
    req = zmq_context.return_value.socket.return_value
    req.send_json = AsyncMock()
    req.recv_json = AsyncMock(return_value=negotiation_message())

    agent = RICommunicationAgent("ri_comm")
    agent._req_socket = req
    agent._handle_negotiation_response = AsyncMock(side_effect=Exception("bad response"))

    negotiated = await agent._negotiate_connection(max_retries=1)
    assert negotiated is False
@pytest.mark.asyncio
async def test_setup_warns_on_failed_negotiate(zmq_context):
    """When negotiation fails, ``setup`` completes without raising and leaves the agent disconnected.

    The warning log itself is not asserted here. The unused ``mocker`` fixture
    parameter was dropped (``zmq_context`` already requests it).
    """
    fake_socket = zmq_context.return_value.socket.return_value
    fake_socket.send_json = AsyncMock()
    fake_socket.recv_json = AsyncMock()

    agent = RICommunicationAgent("ri_comm")

    def swallow(coro):
        # Discard scheduled behaviors without running them.
        coro.close()

    agent.add_behavior = swallow
    agent._negotiate_connection = AsyncMock(return_value=False)

    await agent.setup()

    assert agent.connected is False
@pytest.mark.asyncio
async def test_handle_negotiation_response_unhandled_id():
    """An entry with an unknown ``id`` is tolerated: handling it must not raise.

    This is a smoke test only; it makes no assertion about side effects.
    """
    agent = RICommunicationAgent("ri_comm")
    unknown_entry = {"id": "other", "port": 5000, "bind": False}
    await agent._handle_negotiation_response({"data": [unknown_entry]})
@pytest.mark.asyncio
async def test_handle_negotiation_response_audio(zmq_context):
    """An ``audio`` entry creates and starts a ``VADAgent`` pointed at the advertised port."""
    agent = RICommunicationAgent("ri_comm")
    vad_path = "control_backend.agents.communication.ri_communication_agent.VADAgent"
    reply = {"data": [{"id": "audio", "port": 7000, "bind": False}]}

    with patch(vad_path, autospec=True) as vad_cls:
        vad_cls.return_value.start = AsyncMock()
        await agent._handle_negotiation_response(reply)

    vad_cls.assert_called_once_with(
        audio_in_address="tcp://localhost:7000",
        audio_in_bind=False,
    )
    vad_cls.return_value.start.assert_awaited_once()
@pytest.mark.asyncio
async def test_stop_closes_sockets():
    """``stop`` closes both the REQ and the PUB socket exactly once."""
    # Keep local handles: ``stop`` may reset the agent's socket attributes.
    req_mock = MagicMock()
    pub_mock = MagicMock()

    agent = RICommunicationAgent("ri_comm")
    agent._req_socket = req_mock
    agent.pub_socket = pub_mock

    await agent.stop()

    for sock in (req_mock, pub_mock):
        sock.close.assert_called_once()
@pytest.mark.asyncio
async def test_listen_loop_not_connected(monkeypatch):
    """While disconnected, the listen loop waits via ``asyncio.sleep``.

    ``asyncio.sleep`` is replaced by a fake that records the call and stops the
    loop. The recorded call is asserted so the test checks the back-off
    explicitly rather than only "did not hang".
    """
    agent = RICommunicationAgent("ri_comm")
    agent._running = True
    agent.connected = False
    agent._req_socket = AsyncMock()

    sleep_calls = []

    async def fake_sleep(duration):
        # Record the wait and stop the loop so the test terminates.
        sleep_calls.append(duration)
        agent._running = False

    monkeypatch.setattr("asyncio.sleep", fake_sleep)
    await agent._listen_loop()

    assert sleep_calls, "listen loop should back off with asyncio.sleep while disconnected"
@pytest.mark.asyncio
async def test_listen_loop_send_and_recv_timeout():
    """A ``TimeoutError`` from the REQ socket makes the listen loop call ``_handle_disconnection``."""
    request_socket = AsyncMock()
    request_socket.send_json = AsyncMock(side_effect=TimeoutError)
    request_socket.recv_json = AsyncMock(side_effect=TimeoutError)

    agent = RICommunicationAgent("ri_comm")
    agent._req_socket = request_socket
    agent.pub_socket = AsyncMock()
    agent.connected = True
    agent._running = True

    async def halt_loop():
        # Stand-in disconnection handler: just stop the loop.
        agent._running = False

    agent._handle_disconnection = AsyncMock(side_effect=halt_loop)
    await agent._listen_loop()

    agent._handle_disconnection.assert_awaited()
@pytest.mark.asyncio
async def test_listen_loop_missing_endpoint():
    """A reply without an ``endpoint`` key does not crash the listen loop.

    The unused ``monkeypatch`` fixture parameter was removed. The test now also
    checks that the request was actually awaited before the reply was read.
    """
    req = AsyncMock()
    req.send_json = AsyncMock()

    async def recv_once():
        # Stop the loop after this single iteration.
        agent._running = False
        return {"data": {}}

    req.recv_json = recv_once

    agent = RICommunicationAgent("ri_comm")
    agent._req_socket = req
    agent.pub_socket = AsyncMock()
    agent.connected = True
    agent._running = True

    await agent._listen_loop()

    req.send_json.assert_awaited()
@pytest.mark.asyncio
async def test_listen_loop_generic_exception():
    """A non-timeout error from ``recv_json`` propagates out of ``_listen_loop``."""
    request_socket = AsyncMock()
    request_socket.send_json = AsyncMock()
    request_socket.recv_json = AsyncMock(side_effect=ValueError("boom"))

    agent = RICommunicationAgent("ri_comm")
    agent._req_socket = request_socket
    agent.pub_socket = AsyncMock()
    agent.connected = True
    agent._running = True

    with pytest.raises(ValueError):
        await agent._listen_loop()
@pytest.mark.asyncio
async def test_handle_disconnection_timeout():
    """A ``TimeoutError`` while publishing does not escape ``_handle_disconnection``.

    The unused ``monkeypatch`` fixture parameter was removed.
    """
    pub = AsyncMock()
    pub.close = MagicMock()
    pub.send_multipart = AsyncMock(side_effect=TimeoutError)

    agent = RICommunicationAgent("ri_comm")
    agent.pub_socket = pub
    agent._negotiate_connection = AsyncMock(return_value=False)

    await agent._handle_disconnection()

    pub.send_multipart.assert_awaited()
@pytest.mark.asyncio
async def test_listen_loop_ping_sends_internal(zmq_context):
    """A ``ping`` reply causes a message to be published on the agent's PUB socket."""
    req = zmq_context.return_value.socket.return_value
    req.send_json = AsyncMock()
    publisher = AsyncMock()

    agent = RICommunicationAgent("ri_comm")
    agent._req_socket = req
    agent.pub_socket = publisher
    agent.connected = True
    agent._running = True

    async def reply_with_ping():
        # Single iteration: stop the loop, then hand back a ping.
        agent._running = False
        return {"endpoint": "ping", "data": {}}

    req.recv_json = reply_with_ping
    await agent._listen_loop()

    publisher.send_multipart.assert_awaited()
@pytest.mark.asyncio
async def test_negotiate_req_socket_none_causes_retry(zmq_context):
    """Without a REQ socket, ``_negotiate_connection(max_retries=1)`` returns False."""
    agent = RICommunicationAgent("ri_comm")
    agent._req_socket = None

    assert await agent._negotiate_connection(max_retries=1) is False