diff --git a/test/integration/agents/actuation/test_robot_speech_agent.py b/test/integration/agents/actuation/test_robot_speech_agent.py index 327415c..b5dd166 100644 --- a/test/integration/agents/actuation/test_robot_speech_agent.py +++ b/test/integration/agents/actuation/test_robot_speech_agent.py @@ -1,16 +1,17 @@ import json -from unittest.mock import AsyncMock, MagicMock, patch +from unittest.mock import AsyncMock, MagicMock import pytest import zmq from control_backend.agents.actuation.robot_speech_agent import RobotSpeechAgent +from control_backend.core.agent_system import InternalMessage @pytest.fixture def zmq_context(mocker): mock_context = mocker.patch( - "control_backend.agents.actuation.robot_speech_agent.zmq.Context.instance" + "control_backend.agents.actuation.robot_speech_agent.azmq.Context.instance" ) mock_context.return_value = MagicMock() return mock_context @@ -18,81 +19,140 @@ def zmq_context(mocker): @pytest.mark.asyncio async def test_setup_bind(zmq_context, mocker): - """Test setup with bind=True""" + """Setup binds and subscribes to internal commands.""" fake_socket = zmq_context.return_value.socket.return_value - - agent = RobotSpeechAgent("test@server", "password", address="tcp://localhost:5555", bind=True) + agent = RobotSpeechAgent("robot_speech", address="tcp://localhost:5555", bind=True) settings = mocker.patch("control_backend.agents.actuation.robot_speech_agent.settings") settings.zmq_settings.internal_sub_address = "tcp://internal:1234" + # Swallow background task coroutines to avoid un-awaited warnings + class Swallow: + def __init__(self): + self.calls = 0 + + async def __call__(self, coro): + self.calls += 1 + coro.close() + + swallow = Swallow() + agent.add_background_task = swallow + await agent.setup() fake_socket.bind.assert_any_call("tcp://localhost:5555") fake_socket.connect.assert_any_call("tcp://internal:1234") fake_socket.setsockopt.assert_any_call(zmq.SUBSCRIBE, b"command") - - # Ensure behaviour attached - assert any(isinstance(b, agent.SendZMQCommandsBehaviour) for b in agent.behaviours) + assert swallow.calls == 1 @pytest.mark.asyncio async def test_setup_connect(zmq_context, mocker): - """Test setup with bind=False""" + """Setup connects when bind=False.""" fake_socket = zmq_context.return_value.socket.return_value - - agent = RobotSpeechAgent("test@server", "password", address="tcp://localhost:5555", bind=False) + agent = RobotSpeechAgent("robot_speech", address="tcp://localhost:5555", bind=False) settings = mocker.patch("control_backend.agents.actuation.robot_speech_agent.settings") settings.zmq_settings.internal_sub_address = "tcp://internal:1234" + class Swallow: + def __init__(self): + self.calls = 0 + + async def __call__(self, coro): + self.calls += 1 + coro.close() + + swallow = Swallow() + agent.add_background_task = swallow + await agent.setup() fake_socket.connect.assert_any_call("tcp://localhost:5555") + fake_socket.connect.assert_any_call("tcp://internal:1234") + assert swallow.calls == 1 @pytest.mark.asyncio -async def test_send_commands_behaviour_valid_message(): - """Test behaviour with valid JSON message""" - fake_socket = AsyncMock() - message_dict = {"message": "hello"} - fake_socket.recv_multipart = AsyncMock( - return_value=(b"command", json.dumps(message_dict).encode("utf-8")) - ) - fake_socket.send_json = AsyncMock() +async def test_handle_message_sends_command(): + """Internal message is forwarded to robot pub socket as JSON.""" + pubsocket = AsyncMock() + agent = RobotSpeechAgent("robot_speech") + agent.pubsocket = pubsocket - agent = RobotSpeechAgent("test@server", "password") - agent.subsocket = fake_socket - agent.pubsocket = fake_socket + payload = {"endpoint": "actuate/speech", "data": "hello"} + msg = InternalMessage(to="robot", sender="tester", body=json.dumps(payload)) - behaviour = agent.SendZMQCommandsBehaviour() - behaviour.agent = agent + await agent.handle_message(msg) - with patch( - "control_backend.agents.actuation.robot_speech_agent.SpeechCommand" - ) as MockSpeechCommand: - mock_message = MagicMock() - MockSpeechCommand.model_validate.return_value = mock_message - - await behaviour.run() - - fake_socket.recv_multipart.assert_awaited() - fake_socket.send_json.assert_awaited_with(mock_message.model_dump()) + pubsocket.send_json.assert_awaited_once_with(payload) @pytest.mark.asyncio -async def test_send_commands_behaviour_invalid_message(): - """Test behaviour with invalid JSON message triggers error logging""" +async def test_zmq_command_loop_valid_payload(zmq_context): + """UI command is read from SUB and published.""" + command = {"endpoint": "actuate/speech", "data": "hello"} fake_socket = AsyncMock() - fake_socket.recv_multipart = AsyncMock(return_value=(b"command", b"{invalid_json}")) - fake_socket.send_json = AsyncMock() - agent = RobotSpeechAgent("test@server", "password") + async def recv_once(): + # stop after first iteration + agent._running = False + return (b"command", json.dumps(command).encode("utf-8")) + + fake_socket.recv_multipart = recv_once + fake_socket.send_json = AsyncMock() + agent = RobotSpeechAgent("robot_speech") agent.subsocket = fake_socket agent.pubsocket = fake_socket + agent._running = True - behaviour = agent.SendZMQCommandsBehaviour() - behaviour.agent = agent + await agent._zmq_command_loop() - await behaviour.run() + fake_socket.send_json.assert_awaited_once_with(command) + + +@pytest.mark.asyncio +async def test_zmq_command_loop_invalid_json(): + """Invalid JSON is ignored without sending.""" + fake_socket = AsyncMock() + + async def recv_once(): + agent._running = False + return (b"command", b"{not_json}") + + fake_socket.recv_multipart = recv_once + fake_socket.send_json = AsyncMock() + agent = RobotSpeechAgent("robot_speech") + agent.subsocket = fake_socket + agent.pubsocket = fake_socket + agent._running = True + + await agent._zmq_command_loop() - fake_socket.recv_multipart.assert_awaited() fake_socket.send_json.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_handle_message_invalid_payload(): + """Invalid payload is caught and does not send.""" + pubsocket = AsyncMock() + agent = RobotSpeechAgent("robot_speech") + agent.pubsocket = pubsocket + + msg = InternalMessage(to="robot", sender="tester", body=json.dumps({"bad": "data"})) + + await agent.handle_message(msg) + + pubsocket.send_json.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_stop_closes_sockets(): + pubsocket = MagicMock() + subsocket = MagicMock() + agent = RobotSpeechAgent("robot_speech") + agent.pubsocket = pubsocket + agent.subsocket = subsocket + + await agent.stop() + + pubsocket.close.assert_called_once() + subsocket.close.assert_called_once() diff --git a/test/integration/agents/communication/test_ri_communication_agent.py b/test/integration/agents/communication/test_ri_communication_agent.py index b82234b..6f0492b 100644 --- a/test/integration/agents/communication/test_ri_communication_agent.py +++ b/test/integration/agents/communication/test_ri_communication_agent.py @@ -10,558 +10,345 @@ def speech_agent_path(): return "control_backend.agents.communication.ri_communication_agent.RobotSpeechAgent" -def fake_json_correct_negototiate_1(): - return AsyncMock( - return_value={ - "endpoint": "negotiate/ports", - "data": [ - {"id": "main", "port": 5555, "bind": False}, - {"id": "actuation", "port": 5556, "bind": True}, - ], - } - ) - - -def fake_json_correct_negototiate_2(): - return AsyncMock( - return_value={ - "endpoint": "negotiate/ports", - "data": [ - {"id": "main", "port": 5555, "bind": False}, - {"id": "actuation", "port": 5557, "bind": True}, - ], - } - ) - - -def fake_json_correct_negototiate_3(): - return AsyncMock( - return_value={ - "endpoint": "negotiate/ports", - "data": [ - {"id": "main", "port": 5555, "bind": True}, - {"id": "actuation", "port": 5557, "bind": True}, - ], - } - ) - - -def fake_json_correct_negototiate_4(): - # Different port, do bind - return AsyncMock( - return_value={ - "endpoint": "negotiate/ports", - "data": [ - {"id": "main", "port": 4555, "bind": True}, - {"id": "actuation", "port": 5557, "bind": True}, - ], - } - ) - - -def fake_json_correct_negototiate_5(): - # Different port, dont bind. - return AsyncMock( - return_value={ - "endpoint": "negotiate/ports", - "data": [ - {"id": "main", "port": 4555, "bind": False}, - {"id": "actuation", "port": 5557, "bind": True}, - ], - } - ) - - -def fake_json_wrong_negototiate_1(): - return AsyncMock(return_value={"endpoint": "ping", "data": ""}) - - -def fake_json_invalid_id_negototiate(): - return AsyncMock( - return_value={ - "endpoint": "negotiate/ports", - "data": [ - {"id": "banana", "port": 4555, "bind": False}, - {"id": "tomato", "port": 5557, "bind": True}, - ], - } - ) - - @pytest.fixture def zmq_context(mocker): mock_context = mocker.patch( - "control_backend.agents.communication.ri_communication_agent.zmq.Context.instance" + "control_backend.agents.communication.ri_communication_agent.Context.instance" ) mock_context.return_value = MagicMock() return mock_context +def negotiation_message( + actuation_port: int = 5556, + bind_main: bool = False, + bind_actuation: bool = True, + main_port: int = 5555, +): + return { + "endpoint": "negotiate/ports", + "data": [ + {"id": "main", "port": main_port, "bind": bind_main}, + {"id": "actuation", "port": actuation_port, "bind": bind_actuation}, + ], + } + + @pytest.mark.asyncio -async def test_setup_creates_socket_and_negotiate_1(zmq_context): - """ - Test the setup of the communication agent - """ - # --- Arrange --- +async def test_setup_success_connects_and_starts_robot(zmq_context): fake_socket = zmq_context.return_value.socket.return_value fake_socket.send_json = AsyncMock() - fake_socket.recv_json = fake_json_correct_negototiate_1() + fake_socket.recv_json = AsyncMock(return_value=negotiation_message()) fake_socket.send_multipart = AsyncMock() - # Mock ActSpeechAgent agent startup - with patch(speech_agent_path(), autospec=True) as MockCommandAgent: - fake_agent_instance = MockCommandAgent.return_value - fake_agent_instance.start = AsyncMock() + with patch(speech_agent_path(), autospec=True) as MockRobot: + robot_instance = MockRobot.return_value + robot_instance.start = AsyncMock() + agent = RICommunicationAgent("ri_comm", address="tcp://localhost:5555", bind=False) + + class Swallow: + def __init__(self): + self.calls = 0 + + async def __call__(self, coro): + self.calls += 1 + coro.close() + + swallow = Swallow() + agent.add_background_task = swallow - # --- Act --- - agent = RICommunicationAgent( - "test@server", - "password", - address="tcp://localhost:5555", - bind=False, - ) await agent.setup() - # --- Assert --- - fake_socket.connect.assert_any_call("tcp://localhost:5555") - fake_socket.send_json.assert_any_call({"endpoint": "negotiate/ports", "data": {}}) - fake_socket.recv_json.assert_awaited() - fake_agent_instance.start.assert_awaited() - MockCommandAgent.assert_called_once_with( - ANY, # Server Name - ANY, # Server Password - address="tcp://*:5556", # derived from the 'port' value in negotiation - bind=True, - ) - # Ensure the agent attached a ListenBehaviour - assert any(isinstance(b, agent.ListenBehaviour) for b in agent.behaviours) + fake_socket.connect.assert_any_call("tcp://localhost:5555") + fake_socket.send_json.assert_any_call({"endpoint": "negotiate/ports", "data": {}}) + robot_instance.start.assert_awaited_once() + MockRobot.assert_called_once_with(ANY, address="tcp://*:5556", bind=True) + assert swallow.calls == 1 + assert agent.connected is True @pytest.mark.asyncio -async def test_setup_creates_socket_and_negotiate_2(zmq_context): - """ - Test the setup of the communication agent - """ - # --- Arrange --- +async def test_setup_binds_when_requested(zmq_context): fake_socket = zmq_context.return_value.socket.return_value fake_socket.send_json = AsyncMock() - fake_socket.recv_json = fake_json_correct_negototiate_2() + fake_socket.recv_json = AsyncMock(return_value=negotiation_message(bind_main=True)) fake_socket.send_multipart = AsyncMock() - # Mock ActSpeechAgent agent startup - with patch(speech_agent_path(), autospec=True) as MockCommandAgent: - fake_agent_instance = MockCommandAgent.return_value - fake_agent_instance.start = AsyncMock() + agent = RICommunicationAgent("ri_comm", address="tcp://localhost:5555", bind=True) - # --- Act --- - agent = RICommunicationAgent( - "test@server", - "password", - address="tcp://localhost:5555", - bind=False, - ) + class Swallow: + def __init__(self): + self.calls = 0 + + async def __call__(self, coro): + self.calls += 1 + coro.close() + + swallow = Swallow() + agent.add_background_task = swallow + + with patch(speech_agent_path(), autospec=True) as MockRobot: + MockRobot.return_value.start = AsyncMock() await agent.setup() - # --- Assert --- - fake_socket.connect.assert_any_call("tcp://localhost:5555") - fake_socket.send_json.assert_any_call({"endpoint": "negotiate/ports", "data": {}}) - fake_socket.recv_json.assert_awaited() - fake_agent_instance.start.assert_awaited() - MockCommandAgent.assert_called_once_with( - ANY, # Server Name - ANY, # Server Password - address="tcp://*:5557", # derived from the 'port' value in negotiation - bind=True, - ) - # Ensure the agent attached a ListenBehaviour - assert any(isinstance(b, agent.ListenBehaviour) for b in agent.behaviours) + fake_socket.bind.assert_any_call("tcp://localhost:5555") + assert swallow.calls == 1 @pytest.mark.asyncio -async def test_setup_creates_socket_and_negotiate_3(zmq_context): - """ - Test the functionality of setup with incorrect negotiation message - """ - # --- Arrange --- +async def test_negotiate_invalid_endpoint_retries(zmq_context): fake_socket = zmq_context.return_value.socket.return_value fake_socket.send_json = AsyncMock() - fake_socket.recv_json = fake_json_wrong_negototiate_1() - fake_socket.send_multipart = AsyncMock() - - # Mock ActSpeechAgent agent startup - - # We are sending wrong negotiation info to the communication agent, - # so we should retry and expect a better response, within a limited time. - with patch(speech_agent_path(), autospec=True) as MockCommandAgent: - fake_agent_instance = MockCommandAgent.return_value - fake_agent_instance.start = AsyncMock() - # --- Act --- - - agent = RICommunicationAgent( - "test@server", - "password", - address="tcp://localhost:5555", - bind=False, - ) - await agent.setup(max_retries=1) - - # --- Assert --- - fake_socket.connect.assert_any_call("tcp://localhost:5555") - fake_socket.recv_json.assert_awaited() - - # Since it failed, there should not be any command agent. - fake_agent_instance.start.assert_not_awaited() - - # Ensure the agent did not attach a ListenBehaviour - assert not any(isinstance(b, agent.ListenBehaviour) for b in agent.behaviours) - - -@pytest.mark.asyncio -async def test_setup_creates_socket_and_negotiate_4(zmq_context): - """ - Test the setup of the communication agent with different bind value - """ - # --- Arrange --- - fake_socket = zmq_context.return_value.socket.return_value - fake_socket.send_json = AsyncMock() - fake_socket.recv_json = fake_json_correct_negototiate_3() - fake_socket.send_multipart = AsyncMock() - - # Mock ActSpeechAgent agent startup - with patch(speech_agent_path(), autospec=True) as MockCommandAgent: - fake_agent_instance = MockCommandAgent.return_value - fake_agent_instance.start = AsyncMock() - # --- Act --- - agent = RICommunicationAgent( - "test@server", - "password", - address="tcp://localhost:5555", - bind=True, - ) - await agent.setup() - - # --- Assert --- - fake_socket.bind.assert_any_call("tcp://localhost:5555") - fake_socket.send_json.assert_any_call({"endpoint": "negotiate/ports", "data": {}}) - fake_socket.recv_json.assert_awaited() - fake_agent_instance.start.assert_awaited() - MockCommandAgent.assert_called_once_with( - ANY, # Server Name - ANY, # Server Password - address="tcp://*:5557", # derived from the 'port' value in negotiation - bind=True, - ) - # Ensure the agent attached a ListenBehaviour - assert any(isinstance(b, agent.ListenBehaviour) for b in agent.behaviours) - - -@pytest.mark.asyncio -async def test_setup_creates_socket_and_negotiate_5(zmq_context): - """ - Test the setup of the communication agent - """ - # --- Arrange --- - fake_socket = zmq_context.return_value.socket.return_value - fake_socket.send_json = AsyncMock() - fake_socket.recv_json = fake_json_correct_negototiate_4() - fake_socket.send_multipart = AsyncMock() - - # Mock ActSpeechAgent agent startup - with patch(speech_agent_path(), autospec=True) as MockCommandAgent: - fake_agent_instance = MockCommandAgent.return_value - fake_agent_instance.start = AsyncMock() - # --- Act --- - agent = RICommunicationAgent( - "test@server", - "password", - address="tcp://localhost:5555", - bind=False, - ) - await agent.setup() - - # --- Assert --- - fake_socket.connect.assert_any_call("tcp://localhost:5555") - fake_socket.send_json.assert_any_call({"endpoint": "negotiate/ports", "data": {}}) - fake_socket.recv_json.assert_awaited() - fake_agent_instance.start.assert_awaited() - MockCommandAgent.assert_called_once_with( - ANY, # Server Name - ANY, # Server Password - address="tcp://*:5557", # derived from the 'port' value in negotiation - bind=True, - ) - # Ensure the agent attached a ListenBehaviour - assert any(isinstance(b, agent.ListenBehaviour) for b in agent.behaviours) - - -@pytest.mark.asyncio -async def test_setup_creates_socket_and_negotiate_6(zmq_context): - """ - Test the setup of the communication agent - """ - # --- Arrange --- - fake_socket = zmq_context.return_value.socket.return_value - fake_socket.send_json = AsyncMock() - fake_socket.recv_json = fake_json_correct_negototiate_5() - fake_socket.send_multipart = AsyncMock() - - # Mock ActSpeechAgent agent startup - with patch(speech_agent_path(), autospec=True) as MockCommandAgent: - fake_agent_instance = MockCommandAgent.return_value - fake_agent_instance.start = AsyncMock() - # --- Act --- - agent = RICommunicationAgent( - "test@server", - "password", - address="tcp://localhost:5555", - bind=False, - ) - await agent.setup() - - # --- Assert --- - fake_socket.connect.assert_any_call("tcp://localhost:5555") - fake_socket.send_json.assert_any_call({"endpoint": "negotiate/ports", "data": {}}) - fake_socket.recv_json.assert_awaited() - fake_agent_instance.start.assert_awaited() - MockCommandAgent.assert_called_once_with( - ANY, # Server Name - ANY, # Server Password - address="tcp://*:5557", # derived from the 'port' value in negotiation - bind=True, - ) - # Ensure the agent attached a ListenBehaviour - assert any(isinstance(b, agent.ListenBehaviour) for b in agent.behaviours) - - -@pytest.mark.asyncio -async def test_setup_creates_socket_and_negotiate_7(zmq_context): - """ - Test the functionality of setup with incorrect id - """ - # --- Arrange --- - fake_socket = zmq_context.return_value.socket.return_value - fake_socket.send_json = AsyncMock() - fake_socket.recv_json = fake_json_invalid_id_negototiate() - fake_socket.send_multipart = AsyncMock() - - # Mock ActSpeechAgent agent startup - - # We are sending wrong negotiation info to the communication agent, - # so we should retry and expect a better response, within a limited time. - with patch(speech_agent_path(), autospec=True) as MockCommandAgent: - fake_agent_instance = MockCommandAgent.return_value - fake_agent_instance.start = AsyncMock() - - # --- Act --- - - agent = RICommunicationAgent( - "test@server", - "password", - address="tcp://localhost:5555", - bind=False, - ) - await agent.setup(max_retries=1) - - # --- Assert --- - fake_socket.connect.assert_any_call("tcp://localhost:5555") - fake_socket.recv_json.assert_awaited() - - # Since it failed, there should not be any command agent. - fake_agent_instance.start.assert_not_awaited() - - -@pytest.mark.asyncio -async def test_setup_creates_socket_and_negotiate_timeout(zmq_context): - """ - Test the functionality of setup with incorrect negotiation message - """ - # --- Arrange --- - fake_socket = zmq_context.return_value.socket.return_value - fake_socket.send_json = AsyncMock() - fake_socket.recv_json = AsyncMock(side_effect=asyncio.TimeoutError) - fake_socket.send_multipart = AsyncMock() - - with patch(speech_agent_path(), autospec=True) as MockCommandAgent: - fake_agent_instance = MockCommandAgent.return_value - fake_agent_instance.start = AsyncMock() - - # --- Act --- - - agent = RICommunicationAgent( - "test@server", - "password", - address="tcp://localhost:5555", - bind=False, - ) - await agent.setup(max_retries=1) - - # --- Assert --- - fake_socket.connect.assert_any_call("tcp://localhost:5555") - - # Since it failed, there should not be any command agent. - fake_agent_instance.start.assert_not_awaited() - - # Ensure the agent did not attach a ListenBehaviour - assert not any(isinstance(b, agent.ListenBehaviour) for b in agent.behaviours) - - -@pytest.mark.asyncio -async def test_listen_behaviour_ping_correct(): - fake_socket = AsyncMock() - fake_socket.send_json = AsyncMock() fake_socket.recv_json = AsyncMock(return_value={"endpoint": "ping", "data": {}}) fake_socket.send_multipart = AsyncMock() - agent = RICommunicationAgent("test@server", "password") + agent = RICommunicationAgent("ri_comm", address="tcp://localhost:5555", bind=False) agent._req_socket = fake_socket - agent.connected = True - behaviour = agent.ListenBehaviour() - agent.add_behaviour(behaviour) + success = await agent._negotiate_connection(max_retries=1) - await behaviour.run() - - fake_socket.send_json.assert_awaited() - fake_socket.recv_json.assert_awaited() + assert success is False @pytest.mark.asyncio -async def test_listen_behaviour_ping_wrong_endpoint(): - """ - Test if our listen behaviour can work with wrong messages (wrong endpoint) - """ - fake_socket = AsyncMock() - fake_socket.send_json = AsyncMock() - - # This is a message for the wrong endpoint >:( - fake_socket.recv_json = AsyncMock( - return_value={ - "endpoint": "negotiate/ports", - "data": [ - {"id": "main", "port": 5555, "bind": False}, - {"id": "actuation", "port": 5556, "bind": True}, - ], - } - ) - fake_pub_socket = AsyncMock() - - agent = RICommunicationAgent("test@server", "password", fake_pub_socket) - agent._req_socket = fake_socket - agent.connected = True - - behaviour = agent.ListenBehaviour() - agent.add_behaviour(behaviour) - - # Run once (CyclicBehaviour normally loops) - - await behaviour.run() - - fake_socket.send_json.assert_awaited() - fake_socket.recv_json.assert_awaited() - - -@pytest.mark.asyncio -async def test_listen_behaviour_timeout(zmq_context): +async def test_negotiate_timeout(zmq_context): fake_socket = zmq_context.return_value.socket.return_value fake_socket.send_json = AsyncMock() - # recv_json will never resolve, simulate timeout fake_socket.recv_json = AsyncMock(side_effect=asyncio.TimeoutError) fake_socket.send_multipart = AsyncMock() - agent = RICommunicationAgent("test@server", "password") + agent = RICommunicationAgent("ri_comm", address="tcp://localhost:5555", bind=False) agent._req_socket = fake_socket - agent.connected = True - behaviour = agent.ListenBehaviour() - agent.add_behaviour(behaviour) + success = await agent._negotiate_connection(max_retries=1) - await behaviour.run() - assert not any(isinstance(b, agent.ListenBehaviour) for b in agent.behaviours) - assert not agent.connected + assert success is False @pytest.mark.asyncio -async def test_listen_behaviour_ping_no_endpoint(): - """ - Test if our listen behaviour can work with wrong messages (wrong endpoint) - """ - fake_socket = AsyncMock() - fake_socket.send_json = AsyncMock() - fake_socket.send_multipart = AsyncMock() - - # This is a message without endpoint >:( - fake_socket.recv_json = AsyncMock( - return_value={ - "data": "I dont have an endpoint >:)", - } - ) - - agent = RICommunicationAgent("test@server", "password") +async def test_handle_negotiation_response_updates_req_socket(zmq_context): + fake_socket = zmq_context.return_value.socket.return_value + agent = RICommunicationAgent("ri_comm", address="tcp://localhost:5555", bind=False) agent._req_socket = fake_socket - agent.connected = True - - behaviour = agent.ListenBehaviour() - agent.add_behaviour(behaviour) - - await behaviour.run() - - fake_socket.send_json.assert_awaited() - fake_socket.recv_json.assert_awaited() - - -@pytest.mark.asyncio -async def test_setup_unexpected_exception(zmq_context): - fake_socket = zmq_context.return_value.socket.return_value - fake_socket.send_json = AsyncMock() - # Simulate unexpected exception during recv_json() - fake_socket.recv_json = AsyncMock(side_effect=Exception("boom!")) - fake_socket.send_multipart = AsyncMock() - - agent = RICommunicationAgent( - "test@server", - "password", - address="tcp://localhost:5555", - bind=False, - ) - - await agent.setup(max_retries=1) - - assert not any(isinstance(b, agent.ListenBehaviour) for b in agent.behaviours) - assert not agent.connected - - -@pytest.mark.asyncio -async def test_setup_unpacking_exception(zmq_context): - # --- Arrange --- - fake_socket = zmq_context.return_value.socket.return_value - fake_socket.send_json = AsyncMock() - fake_socket.send_multipart = AsyncMock() - - # Make recv_json return malformed negotiation data to trigger unpacking exception - malformed_data = { - "endpoint": "negotiate/ports", - "data": [{"id": "main"}], - } # missing 'port' and 'bind' - fake_socket.recv_json = AsyncMock(return_value=malformed_data) - - # Patch ActSpeechAgent so it won't actually start - with patch(speech_agent_path(), autospec=True) as MockCommandAgent: - fake_agent_instance = MockCommandAgent.return_value - fake_agent_instance.start = AsyncMock() - - agent = RICommunicationAgent( - "test@server", - "password", - address="tcp://localhost:5555", - bind=False, + with patch(speech_agent_path(), autospec=True) as MockRobot: + MockRobot.return_value.start = AsyncMock() + await agent._handle_negotiation_response( + negotiation_message( + main_port=6000, + actuation_port=6001, + bind_main=False, + bind_actuation=False, + ) ) - # --- Act & Assert --- + fake_socket.connect.assert_any_call("tcp://localhost:6000") - await agent.setup(max_retries=1) - # Ensure no command agent was started - fake_agent_instance.start.assert_not_awaited() +@pytest.mark.asyncio +async def test_handle_disconnection_publishes_and_reconnects(): + pub_socket = AsyncMock() + agent = RICommunicationAgent("ri_comm") + agent.pub_socket = pub_socket + agent.connected = True + agent._negotiate_connection = AsyncMock(return_value=True) - # Ensure no behaviour was attached - assert not any(isinstance(b, agent.ListenBehaviour) for b in agent.behaviours) + await agent._handle_disconnection() + + pub_socket.send_multipart.assert_awaited() + assert agent.connected is True + + +@pytest.mark.asyncio +async def test_listen_loop_handles_non_ping(zmq_context): + fake_socket = zmq_context.return_value.socket.return_value + fake_socket.send_json = AsyncMock() + + async def recv_once(): + agent._running = False + return {"endpoint": "negotiate/ports", "data": {}} + + fake_socket.recv_json = recv_once + agent = RICommunicationAgent("ri_comm") + agent._req_socket = fake_socket + agent.pub_socket = AsyncMock() + agent.connected = True + agent._running = True + + await agent._listen_loop() + + fake_socket.send_json.assert_called() + + +@pytest.mark.asyncio +async def test_negotiate_unexpected_error(zmq_context): + fake_socket = zmq_context.return_value.socket.return_value + fake_socket.send_json = AsyncMock() + fake_socket.recv_json = AsyncMock(side_effect=Exception("boom")) + agent = RICommunicationAgent("ri_comm") + agent._req_socket = fake_socket + + assert await agent._negotiate_connection(max_retries=1) is False + + +@pytest.mark.asyncio +async def test_negotiate_handle_response_error(zmq_context): + fake_socket = zmq_context.return_value.socket.return_value + fake_socket.send_json = AsyncMock() + fake_socket.recv_json = AsyncMock(return_value=negotiation_message()) + + agent = RICommunicationAgent("ri_comm") + agent._req_socket = fake_socket + agent._handle_negotiation_response = AsyncMock(side_effect=Exception("bad response")) + + assert await agent._negotiate_connection(max_retries=1) is False + + +@pytest.mark.asyncio +async def test_setup_warns_on_failed_negotiate(zmq_context, mocker): + fake_socket = zmq_context.return_value.socket.return_value + fake_socket.send_json = AsyncMock() + fake_socket.recv_json = AsyncMock() + agent = RICommunicationAgent("ri_comm") + + async def swallow(coro): + coro.close() + + agent.add_background_task = swallow + agent._negotiate_connection = AsyncMock(return_value=False) + + await agent.setup() + + assert agent.connected is False + + +@pytest.mark.asyncio +async def test_handle_negotiation_response_unhandled_id(): + agent = RICommunicationAgent("ri_comm") + + await agent._handle_negotiation_response( + {"data": [{"id": "other", "port": 5000, "bind": False}]} + ) + + +@pytest.mark.asyncio +async def test_stop_closes_sockets(): + req = MagicMock() + pub = MagicMock() + agent = RICommunicationAgent("ri_comm") + agent._req_socket = req + agent.pub_socket = pub + + await agent.stop() + + req.close.assert_called_once() + pub.close.assert_called_once() + + +@pytest.mark.asyncio +async def test_listen_loop_not_connected(monkeypatch): + agent = RICommunicationAgent("ri_comm") + agent._running = True + agent.connected = False + agent._req_socket = AsyncMock() + + async def fake_sleep(duration): + agent._running = False + + monkeypatch.setattr("asyncio.sleep", fake_sleep) + + await agent._listen_loop() + + +@pytest.mark.asyncio +async def test_listen_loop_send_and_recv_timeout(): + req = AsyncMock() + req.send_json = AsyncMock(side_effect=TimeoutError) + req.recv_json = AsyncMock(side_effect=TimeoutError) + + agent = RICommunicationAgent("ri_comm") + agent._req_socket = req + agent.pub_socket = AsyncMock() + agent.connected = True + agent._running = True + + async def stop_run(): + agent._running = False + + agent._handle_disconnection = AsyncMock(side_effect=stop_run) + + await agent._listen_loop() + + agent._handle_disconnection.assert_awaited() + + +@pytest.mark.asyncio +async def test_listen_loop_missing_endpoint(monkeypatch): + req = AsyncMock() + req.send_json = AsyncMock() + + async def recv_once(): + agent._running = False + return {"data": {}} + + req.recv_json = recv_once + + agent = RICommunicationAgent("ri_comm") + agent._req_socket = req + agent.pub_socket = AsyncMock() + agent.connected = True + agent._running = True + + await agent._listen_loop() + + +@pytest.mark.asyncio +async def test_listen_loop_generic_exception(): + req = AsyncMock() + req.send_json = AsyncMock() + req.recv_json = AsyncMock(side_effect=ValueError("boom")) + + agent = RICommunicationAgent("ri_comm") + agent._req_socket = req + agent.pub_socket = AsyncMock() + agent.connected = True + agent._running = True + + with pytest.raises(ValueError): + await agent._listen_loop() + + +@pytest.mark.asyncio +async def test_handle_disconnection_timeout(monkeypatch): + pub = AsyncMock() + pub.send_multipart = AsyncMock(side_effect=TimeoutError) + + agent = RICommunicationAgent("ri_comm") + agent.pub_socket = pub + agent._negotiate_connection = AsyncMock(return_value=False) + + await agent._handle_disconnection() + + pub.send_multipart.assert_awaited() + + +@pytest.mark.asyncio +async def test_listen_loop_ping_sends_internal(zmq_context): + fake_socket = zmq_context.return_value.socket.return_value + fake_socket.send_json = AsyncMock() + pub_socket = AsyncMock() + + agent = RICommunicationAgent("ri_comm") + agent._req_socket = fake_socket + agent.pub_socket = pub_socket + agent.connected = True + agent._running = True + + async def recv_once(): + agent._running = False + return {"endpoint": "ping", "data": {}} + + fake_socket.recv_json = recv_once + + await agent._listen_loop() + + pub_socket.send_multipart.assert_awaited()