337 lines
9.9 KiB
Python
337 lines
9.9 KiB
Python
import asyncio
|
|
from unittest.mock import ANY, AsyncMock, MagicMock, patch
|
|
|
|
import pytest
|
|
|
|
from control_backend.agents.communication.ri_communication_agent import RICommunicationAgent
|
|
|
|
|
|
def speech_agent_path():
    """Return the dotted path of ``RobotSpeechAgent`` inside the module under test.

    Used as the target for ``unittest.mock.patch`` so the name the agent
    module resolves is the one that gets replaced.
    """
    target = "control_backend.agents.communication.ri_communication_agent.RobotSpeechAgent"
    return target
|
|
|
|
|
|
@pytest.fixture
def zmq_context(mocker):
    """Patch ``Context.instance`` in the agent module and hand back the patch mock.

    Tests reach the socket double via
    ``zmq_context.return_value.socket.return_value``. Because a MagicMock's
    ``return_value`` is a single shared object, every ``socket()`` call the
    agent makes yields that same double.
    """
    patched_instance = mocker.patch(
        "control_backend.agents.communication.ri_communication_agent.Context.instance"
    )
    # Give the patched ``Context.instance()`` an explicit MagicMock context object.
    patched_instance.return_value = MagicMock()
    return patched_instance
|
|
|
|
|
|
def negotiation_message(
    actuation_port: int = 5556,
    bind_main: bool = False,
    bind_actuation: bool = True,
    main_port: int = 5555,
):
    """Build a fake ``negotiate/ports`` message for the agent to receive.

    Args:
        actuation_port: Port advertised for the ``"actuation"`` channel.
        bind_main: ``bind`` flag for the ``"main"`` channel.
        bind_actuation: ``bind`` flag for the ``"actuation"`` channel.
        main_port: Port advertised for the ``"main"`` channel.

    Returns:
        A dict with an ``"endpoint"`` key and a ``"data"`` list holding one
        ``{"id", "port", "bind"}`` entry per channel, ``"main"`` first.
    """
    channels = (
        ("main", main_port, bind_main),
        ("actuation", actuation_port, bind_actuation),
    )
    entries = [{"id": ident, "port": port, "bind": bind} for ident, port, bind in channels]
    return {"endpoint": "negotiate/ports", "data": entries}
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_setup_success_connects_and_starts_robot(zmq_context):
    """A successful negotiation connects to the configured address, creates and
    starts the speech agent on ``tcp://*:5556`` (bind=True), and leaves the
    agent marked as connected."""
    sock = zmq_context.return_value.socket.return_value
    sock.configure_mock(
        send_json=AsyncMock(),
        recv_json=AsyncMock(return_value=negotiation_message()),
        send_multipart=AsyncMock(),
    )

    with patch(speech_agent_path(), autospec=True) as robot_cls:
        speech_agent = robot_cls.return_value
        speech_agent.start = AsyncMock()
        agent = RICommunicationAgent("ri_comm", address="tcp://localhost:5555", bind=False)
        agent.add_behavior = MagicMock()
        await agent.setup()

    # Mocks keep their recorded calls after the patch has been undone.
    sock.connect.assert_any_call("tcp://localhost:5555")
    sock.send_json.assert_any_call({"endpoint": "negotiate/ports", "data": {}})
    speech_agent.start.assert_awaited_once()
    robot_cls.assert_called_once_with(ANY, address="tcp://*:5556", bind=True)
    agent.add_behavior.assert_called_once()
    assert agent.connected is True
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_setup_binds_when_requested(zmq_context):
    """With ``bind=True`` the agent binds its configured address during setup."""
    sock = zmq_context.return_value.socket.return_value
    sock.configure_mock(
        send_json=AsyncMock(),
        recv_json=AsyncMock(return_value=negotiation_message(bind_main=True)),
        send_multipart=AsyncMock(),
    )

    agent = RICommunicationAgent("ri_comm", address="tcp://localhost:5555", bind=True)
    agent.add_behavior = MagicMock()

    with patch(speech_agent_path(), autospec=True) as robot_cls:
        robot_cls.return_value.start = AsyncMock()
        await agent.setup()

    sock.bind.assert_any_call("tcp://localhost:5555")
    agent.add_behavior.assert_called_once()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_negotiate_invalid_endpoint_retries(zmq_context):
    """A reply on an unexpected endpoint (``ping``) with only one attempt
    allowed makes ``_negotiate_connection`` return False."""
    sock = zmq_context.return_value.socket.return_value
    sock.configure_mock(
        send_json=AsyncMock(),
        recv_json=AsyncMock(return_value={"endpoint": "ping", "data": {}}),
        send_multipart=AsyncMock(),
    )

    agent = RICommunicationAgent("ri_comm", address="tcp://localhost:5555", bind=False)
    agent._req_socket = sock

    assert await agent._negotiate_connection(max_retries=1) is False
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_negotiate_timeout(zmq_context):
    """A receive timeout during negotiation makes ``_negotiate_connection`` return False."""
    sock = zmq_context.return_value.socket.return_value
    sock.configure_mock(
        send_json=AsyncMock(),
        recv_json=AsyncMock(side_effect=asyncio.TimeoutError),
        send_multipart=AsyncMock(),
    )

    agent = RICommunicationAgent("ri_comm", address="tcp://localhost:5555", bind=False)
    agent._req_socket = sock

    outcome = await agent._negotiate_connection(max_retries=1)
    assert outcome is False
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_handle_negotiation_response_updates_req_socket(zmq_context):
    """Handling a reply with ``bind_main=False`` on port 6000 issues a
    ``connect`` to ``tcp://localhost:6000``."""
    sock = zmq_context.return_value.socket.return_value
    agent = RICommunicationAgent("ri_comm", address="tcp://localhost:5555", bind=False)
    agent._req_socket = sock

    reply = negotiation_message(
        main_port=6000,
        actuation_port=6001,
        bind_main=False,
        bind_actuation=False,
    )
    with patch(speech_agent_path(), autospec=True) as robot_cls:
        robot_cls.return_value.start = AsyncMock()
        await agent._handle_negotiation_response(reply)

    sock.connect.assert_any_call("tcp://localhost:6000")
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_handle_disconnection_publishes_and_reconnects():
    """On disconnection the agent publishes on ``pub_socket`` and renegotiates.

    ``connected`` starts out True, so the final ``connected is True`` check
    alone would pass even if no reconnection were attempted. The
    renegotiation call is therefore asserted explicitly as well.
    """
    pub_socket = AsyncMock()
    agent = RICommunicationAgent("ri_comm")
    agent.pub_socket = pub_socket
    agent.connected = True
    agent._negotiate_connection = AsyncMock(return_value=True)

    await agent._handle_disconnection()

    pub_socket.send_multipart.assert_awaited()
    agent._negotiate_connection.assert_awaited()
    assert agent.connected is True
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_listen_loop_handles_non_ping(zmq_context):
    """A reply on a non-ping endpoint is tolerated by ``_listen_loop``, and the
    loop still sends on the request socket."""
    sock = zmq_context.return_value.socket.return_value
    sock.send_json = AsyncMock()

    async def reply_then_stop():
        # Clear the run flag so the loop ends after this single iteration.
        agent._running = False
        return {"endpoint": "negotiate/ports", "data": {}}

    sock.recv_json = reply_then_stop

    agent = RICommunicationAgent("ri_comm")
    agent._req_socket = sock
    agent.pub_socket = AsyncMock()
    agent.connected = True
    agent._running = True

    await agent._listen_loop()

    sock.send_json.assert_called()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_negotiate_unexpected_error(zmq_context):
    """A generic exception while receiving makes negotiation return False."""
    sock = zmq_context.return_value.socket.return_value
    sock.configure_mock(
        send_json=AsyncMock(),
        recv_json=AsyncMock(side_effect=Exception("boom")),
    )

    agent = RICommunicationAgent("ri_comm")
    agent._req_socket = sock

    result = await agent._negotiate_connection(max_retries=1)
    assert result is False
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_negotiate_handle_response_error(zmq_context):
    """If processing a well-formed negotiation reply raises, negotiation
    returns False instead of propagating the error."""
    sock = zmq_context.return_value.socket.return_value
    sock.configure_mock(
        send_json=AsyncMock(),
        recv_json=AsyncMock(return_value=negotiation_message()),
    )

    agent = RICommunicationAgent("ri_comm")
    agent._req_socket = sock
    failing_handler = AsyncMock(side_effect=Exception("bad response"))
    agent._handle_negotiation_response = failing_handler

    result = await agent._negotiate_connection(max_retries=1)
    assert result is False
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_setup_warns_on_failed_negotiate(zmq_context):
    """When negotiation fails, ``setup`` completes but leaves ``connected`` False."""
    fake_socket = zmq_context.return_value.socket.return_value
    fake_socket.send_json = AsyncMock()
    fake_socket.recv_json = AsyncMock()
    agent = RICommunicationAgent("ri_comm")

    def swallow(coro):
        # ``setup`` calls ``add_behavior`` without awaiting it; the other setup
        # tests replace it with a plain MagicMock and check ``assert_called_once``.
        # The stand-in must therefore be synchronous. An ``async def`` here would
        # itself yield a never-awaited coroutine and would never close ``coro``,
        # which triggers "coroutine was never awaited" RuntimeWarnings.
        coro.close()

    agent.add_behavior = swallow
    agent._negotiate_connection = AsyncMock(return_value=False)

    await agent.setup()

    assert agent.connected is False
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_handle_negotiation_response_unhandled_id():
    """A negotiation entry with an unrecognised ``id`` must not make the handler raise."""
    agent = RICommunicationAgent("ri_comm")
    unknown_entry = {"id": "other", "port": 5000, "bind": False}

    await agent._handle_negotiation_response({"data": [unknown_entry]})
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_stop_closes_sockets():
    """``stop`` closes both the request socket and the publish socket exactly once."""
    req_sock, pub_sock = MagicMock(), MagicMock()
    agent = RICommunicationAgent("ri_comm")
    agent._req_socket = req_sock
    agent.pub_socket = pub_sock

    await agent.stop()

    for sock in (req_sock, pub_sock):
        sock.close.assert_called_once()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_listen_loop_not_connected(monkeypatch):
    """While disconnected, ``_listen_loop`` is expected to wait via
    ``asyncio.sleep``. The patched sleep clears ``_running`` so the loop
    returns instead of spinning forever."""
    agent = RICommunicationAgent("ri_comm")
    agent._running = True
    agent.connected = False
    agent._req_socket = AsyncMock()

    async def fake_sleep(duration):
        agent._running = False

    monkeypatch.setattr("asyncio.sleep", fake_sleep)

    await agent._listen_loop()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_listen_loop_send_and_recv_timeout():
    """A timeout on the request socket makes ``_listen_loop`` await ``_handle_disconnection``."""
    req = AsyncMock()
    req.configure_mock(
        send_json=AsyncMock(side_effect=TimeoutError),
        recv_json=AsyncMock(side_effect=TimeoutError),
    )

    agent = RICommunicationAgent("ri_comm")
    agent._req_socket = req
    agent.pub_socket = AsyncMock()
    agent.connected = True
    agent._running = True

    async def end_loop():
        # Clearing the flag lets the loop exit after the disconnection is handled.
        agent._running = False

    agent._handle_disconnection = AsyncMock(side_effect=end_loop)

    await agent._listen_loop()

    agent._handle_disconnection.assert_awaited()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_listen_loop_missing_endpoint():
    """A reply that lacks an ``endpoint`` key must not crash ``_listen_loop``.

    The unused ``monkeypatch`` fixture parameter was dropped; nothing here is patched.
    """
    req = AsyncMock()
    req.send_json = AsyncMock()

    async def recv_once():
        # Clear the run flag so the loop ends after this single iteration.
        agent._running = False
        return {"data": {}}

    req.recv_json = recv_once

    agent = RICommunicationAgent("ri_comm")
    agent._req_socket = req
    agent.pub_socket = AsyncMock()
    agent.connected = True
    agent._running = True

    await agent._listen_loop()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_listen_loop_generic_exception():
    """A ``ValueError`` raised while receiving propagates out of ``_listen_loop``."""
    req = AsyncMock()
    req.configure_mock(
        send_json=AsyncMock(),
        recv_json=AsyncMock(side_effect=ValueError("boom")),
    )

    agent = RICommunicationAgent("ri_comm")
    agent._req_socket = req
    agent.pub_socket = AsyncMock()
    agent.connected = True
    agent._running = True

    with pytest.raises(ValueError):
        await agent._listen_loop()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_handle_disconnection_timeout():
    """A timeout while publishing on ``pub_socket`` does not escape
    ``_handle_disconnection``, and the publish is still attempted.

    The unused ``monkeypatch`` fixture parameter was dropped; nothing here is patched.
    """
    pub = AsyncMock()
    pub.send_multipart = AsyncMock(side_effect=TimeoutError)

    agent = RICommunicationAgent("ri_comm")
    agent.pub_socket = pub
    agent._negotiate_connection = AsyncMock(return_value=False)

    await agent._handle_disconnection()

    pub.send_multipart.assert_awaited()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_listen_loop_ping_sends_internal(zmq_context):
    """A ``ping`` reply makes ``_listen_loop`` publish on ``pub_socket``."""
    sock = zmq_context.return_value.socket.return_value
    sock.send_json = AsyncMock()
    publisher = AsyncMock()

    agent = RICommunicationAgent("ri_comm")
    agent._req_socket = sock
    agent.pub_socket = publisher
    agent.connected = True
    agent._running = True

    async def ping_then_stop():
        # Clear the run flag so the loop ends after handling this ping.
        agent._running = False
        return {"endpoint": "ping", "data": {}}

    sock.recv_json = ping_then_stop

    await agent._listen_loop()

    publisher.send_multipart.assert_awaited()
|