diff --git a/src/control_backend/agents/ri_command_agent.py b/src/control_backend/agents/ri_command_agent.py index ac561ed..030e5f9 100644 --- a/src/control_backend/agents/ri_command_agent.py +++ b/src/control_backend/agents/ri_command_agent.py @@ -72,7 +72,7 @@ class RICommandAgent(BaseAgent): # To the robot self.pubsocket = context.socket(zmq.PUB) - if self.bind: + if self.bind: # TODO: Should this ever be the case? self.pubsocket.bind(self.address) else: self.pubsocket.connect(self.address) diff --git a/src/control_backend/agents/ri_communication_agent.py b/src/control_backend/agents/ri_communication_agent.py index 76d6431..93fbf6c 100644 --- a/src/control_backend/agents/ri_communication_agent.py +++ b/src/control_backend/agents/ri_communication_agent.py @@ -1,6 +1,7 @@ import asyncio +import json -import zmq +import zmq.asyncio from spade.behaviour import CyclicBehaviour from zmq.asyncio import Context @@ -14,6 +15,7 @@ class RICommunicationAgent(BaseAgent): req_socket: zmq.Socket _address = "" _bind = True + connected = False def __init__( self, @@ -27,6 +29,8 @@ class RICommunicationAgent(BaseAgent): super().__init__(jid, password, port, verify_security) self._address = address self._bind = bind + self._req_socket: zmq.asyncio.Socket | None = None + self.pub_socket: zmq.asyncio.Socket | None = None class ListenBehaviour(CyclicBehaviour): async def run(self): @@ -35,59 +39,128 @@ class RICommunicationAgent(BaseAgent): """ assert self.agent is not None + if not self.agent.connected: + await asyncio.sleep(1) + return + # We need to listen and sent pings. message = {"endpoint": "ping", "data": {"id": "e.g. some reference id"}} - await self.agent.req_socket.send_json(message) - - # Wait up to three seconds for a reply:) + seconds_to_wait_total = 1.0 try: - message = await asyncio.wait_for(self.agent.req_socket.recv_json(), timeout=3.0) - - # We didnt get a reply :( + await asyncio.wait_for( + self.agent._req_socket.send_json(message), timeout=seconds_to_wait_total / 2 + ) except TimeoutError: - self.agent.logger.info("No ping retrieved in 3 seconds, killing myself.") - self.kill() + self.agent.logger.debug( + "Waited too long to send message - " + "we probably dont have any receivers... but let's check!" + ) - self.agent.logger.debug('Received message "%s"', message) + # Wait up to {seconds_to_wait_total/2} seconds for a reply + try: + message = await asyncio.wait_for( + self.agent._req_socket.recv_json(), timeout=seconds_to_wait_total / 2 + ) + + # We didnt get a reply + except TimeoutError: + self.agent.logger.info( + f"No ping retrieved in {seconds_to_wait_total} seconds, " + "sending UI disconnection event and attempting to restart." + ) + + # Make sure we dont retry receiving messages untill we're setup. + self.agent.connected = False + self.agent.remove_behaviour(self) + + # Tell UI we're disconnected. + topic = b"ping" + data = json.dumps(False).encode() + if self.agent.pub_socket is None: + self.agent.logger.warning( + "Communication agent pub socket not correctly initialized." + ) + else: + try: + await asyncio.wait_for( + self.agent.pub_socket.send_multipart([topic, data]), 5 + ) + except TimeoutError: + self.agent.logger.warning( + "Initial connection ping for router timed" + " out in ri_communication_agent." + ) + + # Try to reboot. + self.agent.logger.debug("Restarting communication agent.") + await self.agent.setup() + + self.agent.logger.debug(f'Received message "{message}" from RI.') if "endpoint" not in message: - self.agent.logger.error("No received endpoint in message, excepted ping endpoint.") + self.agent.logger.warning( + "No received endpoint in message, expected ping endpoint." + ) return # See what endpoint we received match message["endpoint"]: case "ping": + topic = b"ping" + data = json.dumps(True).encode() + if self.agent.pub_socket is not None: + await self.agent.pub_socket.send_multipart([topic, data]) await asyncio.sleep(1) case _: - self.agent.logger.info( + self.agent.logger.debug( "Received message with topic different than ping, while ping expected." ) - async def setup(self, max_retries: int = 5): + async def setup_sockets(self, force=False): + """ + Sets up request socket for communication agent. + """ + # Bind request socket + if self._req_socket is None or force: + self._req_socket = Context.instance().socket(zmq.REQ) + if self._bind: + self._req_socket.bind(self._address) + else: + self._req_socket.connect(self._address) + + if self.pub_socket is None or force: + self.pub_socket = Context.instance().socket(zmq.PUB) + self.pub_socket.connect(settings.zmq_settings.internal_pub_address) + + async def setup(self, max_retries: int = 100): """ Try to setup the communication agent, we have 5 retries in case we dont have a response yet. """ self.logger.info("Setting up %s", self.jid) - retries = 0 + # Bind request socket + await self.setup_sockets() + + retries = 0 # Let's try a certain amount of times before failing connection while retries < max_retries: - # Bind request socket - self.req_socket = Context.instance().socket(zmq.REQ) - if self._bind: - self.req_socket.bind(self._address) - else: - self.req_socket.connect(self._address) + # Make sure the socket is properly setup. + if self._req_socket is None: + continue - # Send our message and receive one back:) - message = {"endpoint": "negotiate/ports", "data": None} - await self.req_socket.send_json(message) + # Send our message and receive one back + message = {"endpoint": "negotiate/ports", "data": {}} + await self._req_socket.send_json(message) + retry_frequency = 1.0 try: - received_message = await asyncio.wait_for(self.req_socket.recv_json(), timeout=20.0) + received_message = await asyncio.wait_for( + self._req_socket.recv_json(), timeout=retry_frequency + ) except TimeoutError: self.logger.warning( - "No connection established in 20 seconds (attempt %d/%d)", + "No connection established in %d seconds (attempt %d/%d)", + retries * retry_frequency, retries + 1, max_retries, ) @@ -95,21 +168,21 @@ class RICommunicationAgent(BaseAgent): continue except Exception as e: - self.logger.error("Unexpected error during negotiation: %s", e) + self.logger.warning("Unexpected error during negotiation: %s", e) retries += 1 continue # Validate endpoint endpoint = received_message.get("endpoint") if endpoint != "negotiate/ports": - # TODO: Should this send a message back? - self.logger.error( + self.logger.warning( "Invalid endpoint '%s' received (attempt %d/%d)", endpoint, retries + 1, max_retries, ) retries += 1 + await asyncio.sleep(1) continue # At this point, we have a valid response @@ -128,9 +201,9 @@ class RICommunicationAgent(BaseAgent): case "main": if addr != self._address: if not bind: - self.req_socket.connect(addr) + self._req_socket.connect(addr) else: - self.req_socket.bind(addr) + self._req_socket.bind(addr) case "actuation": ri_commands_agent = RICommandAgent( settings.agent_settings.ri_command_agent_name @@ -145,18 +218,35 @@ class RICommunicationAgent(BaseAgent): self.logger.warning("Unhandled negotiation id: %s", id) except Exception as e: - self.logger.error("Error unpacking negotiation data: %s", e) + self.logger.warning("Error unpacking negotiation data: %s", e) retries += 1 + await asyncio.sleep(1) continue # setup succeeded break else: - self.logger.error("Failed to set up RICommunicationAgent after %d retries", max_retries) + self.logger.warning("Failed to set up %s after %d retries", self.name, max_retries) return # Set up ping behaviour listen_behaviour = self.ListenBehaviour() self.add_behaviour(listen_behaviour) + + # Let UI know that we're connected + topic = b"ping" + data = json.dumps(True).encode() + if self.pub_socket is None: + self.logger.warning("Communication agent pub socket not correctly initialized.") + else: + try: + await asyncio.wait_for(self.pub_socket.send_multipart([topic, data]), 5) + except TimeoutError: + self.logger.warning( + "Initial connection ping for router timed out in ri_communication_agent." + ) + + # Make sure to start listening now that we're connected. + self.connected = True self.logger.info("Finished setting up %s", self.jid) diff --git a/src/control_backend/api/v1/endpoints/command.py b/src/control_backend/api/v1/endpoints/command.py deleted file mode 100644 index d7f963b..0000000 --- a/src/control_backend/api/v1/endpoints/command.py +++ /dev/null @@ -1,20 +0,0 @@ -import logging - -from fastapi import APIRouter, Request - -from control_backend.schemas.ri_message import SpeechCommand - -logger = logging.getLogger(__name__) - -router = APIRouter() - - -@router.post("/command", status_code=202) -async def receive_command(command: SpeechCommand, request: Request): - # Validate and retrieve data. - SpeechCommand.model_validate(command) - topic = b"command" - pub_socket = request.app.state.endpoints_pub_socket - await pub_socket.send_multipart([topic, command.model_dump_json().encode()]) - - return {"status": "Command received"} diff --git a/src/control_backend/api/v1/endpoints/robot.py b/src/control_backend/api/v1/endpoints/robot.py new file mode 100644 index 0000000..eb67b0e --- /dev/null +++ b/src/control_backend/api/v1/endpoints/robot.py @@ -0,0 +1,71 @@ +import asyncio +import json +import logging + +import zmq.asyncio +from fastapi import APIRouter, Request +from fastapi.responses import StreamingResponse +from zmq.asyncio import Context, Socket + +from control_backend.core.config import settings +from control_backend.schemas.ri_message import SpeechCommand + +logger = logging.getLogger(__name__) + +router = APIRouter() + + +@router.post("/command", status_code=202) +async def receive_command(command: SpeechCommand, request: Request): + # Validate and retrieve data. + SpeechCommand.model_validate(command) + topic = b"command" + + pub_socket: Socket = request.app.state.endpoints_pub_socket + await pub_socket.send_multipart([topic, command.model_dump_json().encode()]) + + return {"status": "Command received"} + + +@router.get("/ping_check") +async def ping(request: Request): + pass + + +@router.get("/ping_stream") +async def ping_stream(request: Request): + """Stream live updates whenever the device state changes.""" + + async def event_stream(): + # Set up internal socket to receive ping updates + + sub_socket = Context.instance().socket(zmq.SUB) + sub_socket.connect(settings.zmq_settings.internal_sub_address) + sub_socket.setsockopt(zmq.SUBSCRIBE, b"ping") + connected = False + + ping_frequency = 2 + + # Even though its most likely the updates should alternate + # (So, True - False - True - False for connectivity), + # let's still check. + while True: + try: + topic, body = await asyncio.wait_for( + sub_socket.recv_multipart(), timeout=ping_frequency + ) + connected = json.loads(body) + except TimeoutError: + logger.debug("got timeout error in ping loop in ping router") + connected = False + + # Stop if client disconnected + if await request.is_disconnected(): + logger.info("Client disconnected from SSE") + break + + logger.debug(f"Yielded new connection event in robot ping router: {str(connected)}") + connectedJson = json.dumps(connected) + yield (f"data: {connectedJson}\n\n") + + return StreamingResponse(event_stream(), media_type="text/event-stream") diff --git a/src/control_backend/api/v1/router.py b/src/control_backend/api/v1/router.py index f11dc9c..115cd26 100644 --- a/src/control_backend/api/v1/router.py +++ b/src/control_backend/api/v1/router.py @@ -1,6 +1,6 @@ from fastapi.routing import APIRouter -from control_backend.api.v1.endpoints import command, logs, message, sse +from control_backend.api.v1.endpoints import logs, message, robot, sse api_router = APIRouter() @@ -8,6 +8,6 @@ api_router.include_router(message.router, tags=["Messages"]) api_router.include_router(sse.router, tags=["SSE"]) -api_router.include_router(command.router, tags=["Commands"]) +api_router.include_router(robot.router, prefix="/robot", tags=["Pings", "Commands"]) api_router.include_router(logs.router, tags=["Logs"]) diff --git a/test/integration/agents/test_ri_commands_agent.py b/test/integration/agents/test_ri_commands_agent.py index 00edcb1..477ab78 100644 --- a/test/integration/agents/test_ri_commands_agent.py +++ b/test/integration/agents/test_ri_commands_agent.py @@ -25,13 +25,9 @@ async def test_setup_bind(zmq_context, mocker): await agent.setup() - # Ensure PUB socket bound fake_socket.bind.assert_any_call("tcp://localhost:5555") - # Ensure SUB socket connected to internal address and subscribed fake_socket.connect.assert_any_call("tcp://internal:1234") fake_socket.setsockopt.assert_any_call(zmq.SUBSCRIBE, b"command") - - # Ensure behaviour attached assert any(isinstance(b, agent.SendCommandsBehaviour) for b in agent.behaviours) @@ -46,7 +42,6 @@ async def test_setup_connect(zmq_context, mocker): await agent.setup() - # Ensure PUB socket connected fake_socket.connect.assert_any_call("tcp://localhost:5555") @@ -78,7 +73,7 @@ async def test_send_commands_behaviour_valid_message(): @pytest.mark.asyncio -async def test_send_commands_behaviour_invalid_message(caplog): +async def test_send_commands_behaviour_invalid_message(): """Test behaviour with invalid JSON message triggers error logging""" fake_socket = AsyncMock() fake_socket.recv_multipart = AsyncMock(return_value=(b"command", b"{invalid_json}")) @@ -91,9 +86,7 @@ async def test_send_commands_behaviour_invalid_message(caplog): behaviour = agent.SendCommandsBehaviour() behaviour.agent = agent - with caplog.at_level("ERROR"): - await behaviour.run() + await behaviour.run() fake_socket.recv_multipart.assert_awaited() fake_socket.send_json.assert_not_awaited() - assert "Error processing message" in caplog.text diff --git a/test/integration/agents/test_ri_communication_agent.py b/test/integration/agents/test_ri_communication_agent.py index 443d609..6e29340 100644 --- a/test/integration/agents/test_ri_communication_agent.py +++ b/test/integration/agents/test_ri_communication_agent.py @@ -100,6 +100,7 @@ async def test_setup_creates_socket_and_negotiate_1(zmq_context): fake_socket = zmq_context.return_value.socket.return_value fake_socket.send_json = AsyncMock() fake_socket.recv_json = fake_json_correct_negototiate_1() + fake_socket.send_multipart = AsyncMock() # Mock RICommandAgent agent startup with patch( @@ -110,13 +111,16 @@ async def test_setup_creates_socket_and_negotiate_1(zmq_context): # --- Act --- agent = RICommunicationAgent( - "test@server", "password", address="tcp://localhost:5555", bind=False + "test@server", + "password", + address="tcp://localhost:5555", + bind=False, ) await agent.setup() # --- Assert --- fake_socket.connect.assert_any_call("tcp://localhost:5555") - fake_socket.send_json.assert_any_call({"endpoint": "negotiate/ports", "data": None}) + fake_socket.send_json.assert_any_call({"endpoint": "negotiate/ports", "data": {}}) fake_socket.recv_json.assert_awaited() fake_agent_instance.start.assert_awaited() MockCommandAgent.assert_called_once_with( @@ -138,6 +142,7 @@ async def test_setup_creates_socket_and_negotiate_2(zmq_context): fake_socket = zmq_context.return_value.socket.return_value fake_socket.send_json = AsyncMock() fake_socket.recv_json = fake_json_correct_negototiate_2() + fake_socket.send_multipart = AsyncMock() # Mock RICommandAgent agent startup with patch( @@ -148,13 +153,16 @@ async def test_setup_creates_socket_and_negotiate_2(zmq_context): # --- Act --- agent = RICommunicationAgent( - "test@server", "password", address="tcp://localhost:5555", bind=False + "test@server", + "password", + address="tcp://localhost:5555", + bind=False, ) await agent.setup() # --- Assert --- fake_socket.connect.assert_any_call("tcp://localhost:5555") - fake_socket.send_json.assert_any_call({"endpoint": "negotiate/ports", "data": None}) + fake_socket.send_json.assert_any_call({"endpoint": "negotiate/ports", "data": {}}) fake_socket.recv_json.assert_awaited() fake_agent_instance.start.assert_awaited() MockCommandAgent.assert_called_once_with( @@ -168,7 +176,7 @@ async def test_setup_creates_socket_and_negotiate_2(zmq_context): @pytest.mark.asyncio -async def test_setup_creates_socket_and_negotiate_3(zmq_context, caplog): +async def test_setup_creates_socket_and_negotiate_3(zmq_context): """ Test the functionality of setup with incorrect negotiation message """ @@ -176,6 +184,7 @@ async def test_setup_creates_socket_and_negotiate_3(zmq_context, caplog): fake_socket = zmq_context.return_value.socket.return_value fake_socket.send_json = AsyncMock() fake_socket.recv_json = fake_json_wrong_negototiate_1() + fake_socket.send_multipart = AsyncMock() # Mock RICommandAgent agent startup @@ -186,13 +195,15 @@ async def test_setup_creates_socket_and_negotiate_3(zmq_context, caplog): ) as MockCommandAgent: fake_agent_instance = MockCommandAgent.return_value fake_agent_instance.start = AsyncMock() - # --- Act --- - with caplog.at_level("ERROR"): - agent = RICommunicationAgent( - "test@server", "password", address="tcp://localhost:5555", bind=False - ) - await agent.setup(max_retries=1) + + agent = RICommunicationAgent( + "test@server", + "password", + address="tcp://localhost:5555", + bind=False, + ) + await agent.setup(max_retries=1) # --- Assert --- fake_socket.connect.assert_any_call("tcp://localhost:5555") @@ -200,7 +211,6 @@ async def test_setup_creates_socket_and_negotiate_3(zmq_context, caplog): # Since it failed, there should not be any command agent. fake_agent_instance.start.assert_not_awaited() - assert "Failed to set up RICommunicationAgent" in caplog.text # Ensure the agent did not attach a ListenBehaviour assert not any(isinstance(b, agent.ListenBehaviour) for b in agent.behaviours) @@ -215,6 +225,7 @@ async def test_setup_creates_socket_and_negotiate_4(zmq_context): fake_socket = zmq_context.return_value.socket.return_value fake_socket.send_json = AsyncMock() fake_socket.recv_json = fake_json_correct_negototiate_3() + fake_socket.send_multipart = AsyncMock() # Mock RICommandAgent agent startup with patch( @@ -222,16 +233,18 @@ async def test_setup_creates_socket_and_negotiate_4(zmq_context): ) as MockCommandAgent: fake_agent_instance = MockCommandAgent.return_value fake_agent_instance.start = AsyncMock() - # --- Act --- agent = RICommunicationAgent( - "test@server", "password", address="tcp://localhost:5555", bind=True + "test@server", + "password", + address="tcp://localhost:5555", + bind=True, ) await agent.setup() # --- Assert --- fake_socket.bind.assert_any_call("tcp://localhost:5555") - fake_socket.send_json.assert_any_call({"endpoint": "negotiate/ports", "data": None}) + fake_socket.send_json.assert_any_call({"endpoint": "negotiate/ports", "data": {}}) fake_socket.recv_json.assert_awaited() fake_agent_instance.start.assert_awaited() MockCommandAgent.assert_called_once_with( @@ -253,6 +266,7 @@ async def test_setup_creates_socket_and_negotiate_5(zmq_context): fake_socket = zmq_context.return_value.socket.return_value fake_socket.send_json = AsyncMock() fake_socket.recv_json = fake_json_correct_negototiate_4() + fake_socket.send_multipart = AsyncMock() # Mock RICommandAgent agent startup with patch( @@ -260,16 +274,18 @@ async def test_setup_creates_socket_and_negotiate_5(zmq_context): ) as MockCommandAgent: fake_agent_instance = MockCommandAgent.return_value fake_agent_instance.start = AsyncMock() - # --- Act --- agent = RICommunicationAgent( - "test@server", "password", address="tcp://localhost:5555", bind=False + "test@server", + "password", + address="tcp://localhost:5555", + bind=False, ) await agent.setup() # --- Assert --- fake_socket.connect.assert_any_call("tcp://localhost:5555") - fake_socket.send_json.assert_any_call({"endpoint": "negotiate/ports", "data": None}) + fake_socket.send_json.assert_any_call({"endpoint": "negotiate/ports", "data": {}}) fake_socket.recv_json.assert_awaited() fake_agent_instance.start.assert_awaited() MockCommandAgent.assert_called_once_with( @@ -291,6 +307,7 @@ async def test_setup_creates_socket_and_negotiate_6(zmq_context): fake_socket = zmq_context.return_value.socket.return_value fake_socket.send_json = AsyncMock() fake_socket.recv_json = fake_json_correct_negototiate_5() + fake_socket.send_multipart = AsyncMock() # Mock RICommandAgent agent startup with patch( @@ -298,16 +315,18 @@ async def test_setup_creates_socket_and_negotiate_6(zmq_context): ) as MockCommandAgent: fake_agent_instance = MockCommandAgent.return_value fake_agent_instance.start = AsyncMock() - # --- Act --- agent = RICommunicationAgent( - "test@server", "password", address="tcp://localhost:5555", bind=False + "test@server", + "password", + address="tcp://localhost:5555", + bind=False, ) await agent.setup() # --- Assert --- fake_socket.connect.assert_any_call("tcp://localhost:5555") - fake_socket.send_json.assert_any_call({"endpoint": "negotiate/ports", "data": None}) + fake_socket.send_json.assert_any_call({"endpoint": "negotiate/ports", "data": {}}) fake_socket.recv_json.assert_awaited() fake_agent_instance.start.assert_awaited() MockCommandAgent.assert_called_once_with( @@ -321,7 +340,7 @@ async def test_setup_creates_socket_and_negotiate_6(zmq_context): @pytest.mark.asyncio -async def test_setup_creates_socket_and_negotiate_7(zmq_context, caplog): +async def test_setup_creates_socket_and_negotiate_7(zmq_context): """ Test the functionality of setup with incorrect id """ @@ -329,6 +348,7 @@ async def test_setup_creates_socket_and_negotiate_7(zmq_context, caplog): fake_socket = zmq_context.return_value.socket.return_value fake_socket.send_json = AsyncMock() fake_socket.recv_json = fake_json_invalid_id_negototiate() + fake_socket.send_multipart = AsyncMock() # Mock RICommandAgent agent startup @@ -341,11 +361,14 @@ async def test_setup_creates_socket_and_negotiate_7(zmq_context, caplog): fake_agent_instance.start = AsyncMock() # --- Act --- - with caplog.at_level("WARNING"): - agent = RICommunicationAgent( - "test@server", "password", address="tcp://localhost:5555", bind=False - ) - await agent.setup(max_retries=1) + + agent = RICommunicationAgent( + "test@server", + "password", + address="tcp://localhost:5555", + bind=False, + ) + await agent.setup(max_retries=1) # --- Assert --- fake_socket.connect.assert_any_call("tcp://localhost:5555") @@ -353,11 +376,10 @@ async def test_setup_creates_socket_and_negotiate_7(zmq_context, caplog): # Since it failed, there should not be any command agent. fake_agent_instance.start.assert_not_awaited() - assert "Unhandled negotiation id:" in caplog.text @pytest.mark.asyncio -async def test_setup_creates_socket_and_negotiate_timeout(zmq_context, caplog): +async def test_setup_creates_socket_and_negotiate_timeout(zmq_context): """ Test the functionality of setup with incorrect negotiation message """ @@ -365,6 +387,7 @@ async def test_setup_creates_socket_and_negotiate_timeout(zmq_context, caplog): fake_socket = zmq_context.return_value.socket.return_value fake_socket.send_json = AsyncMock() fake_socket.recv_json = AsyncMock(side_effect=asyncio.TimeoutError) + fake_socket.send_multipart = AsyncMock() with patch( "control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True @@ -373,47 +396,47 @@ async def test_setup_creates_socket_and_negotiate_timeout(zmq_context, caplog): fake_agent_instance.start = AsyncMock() # --- Act --- - with caplog.at_level("WARNING"): - agent = RICommunicationAgent( - "test@server", "password", address="tcp://localhost:5555", bind=False - ) - await agent.setup(max_retries=1) + + agent = RICommunicationAgent( + "test@server", + "password", + address="tcp://localhost:5555", + bind=False, + ) + await agent.setup(max_retries=1) # --- Assert --- fake_socket.connect.assert_any_call("tcp://localhost:5555") # Since it failed, there should not be any command agent. fake_agent_instance.start.assert_not_awaited() - assert "No connection established in 20 seconds" in caplog.text # Ensure the agent did not attach a ListenBehaviour assert not any(isinstance(b, agent.ListenBehaviour) for b in agent.behaviours) @pytest.mark.asyncio -async def test_listen_behaviour_ping_correct(caplog): +async def test_listen_behaviour_ping_correct(): fake_socket = AsyncMock() fake_socket.send_json = AsyncMock() fake_socket.recv_json = AsyncMock(return_value={"endpoint": "ping", "data": {}}) + fake_socket.send_multipart = AsyncMock() - # TODO: Integration test between actual server and password needed for spade agents agent = RICommunicationAgent("test@server", "password") - agent.req_socket = fake_socket + agent._req_socket = fake_socket + agent.connected = True behaviour = agent.ListenBehaviour() agent.add_behaviour(behaviour) - # Run once (CyclicBehaviour normally loops) - with caplog.at_level("DEBUG"): - await behaviour.run() + await behaviour.run() fake_socket.send_json.assert_awaited() fake_socket.recv_json.assert_awaited() - assert "Received message" in caplog.text @pytest.mark.asyncio -async def test_listen_behaviour_ping_wrong_endpoint(caplog): +async def test_listen_behaviour_ping_wrong_endpoint(): """ Test if our listen behaviour can work with wrong messages (wrong endpoint) """ @@ -430,48 +453,51 @@ async def test_listen_behaviour_ping_wrong_endpoint(caplog): ], } ) + fake_pub_socket = AsyncMock() - agent = RICommunicationAgent("test@server", "password") - agent.req_socket = fake_socket + agent = RICommunicationAgent("test@server", "password", fake_pub_socket) + agent._req_socket = fake_socket + agent.connected = True behaviour = agent.ListenBehaviour() agent.add_behaviour(behaviour) # Run once (CyclicBehaviour normally loops) - with caplog.at_level("INFO"): - await behaviour.run() - assert "Received message with topic different than ping, while ping expected." in caplog.text + await behaviour.run() + fake_socket.send_json.assert_awaited() fake_socket.recv_json.assert_awaited() @pytest.mark.asyncio -async def test_listen_behaviour_timeout(zmq_context, caplog): +async def test_listen_behaviour_timeout(zmq_context): fake_socket = zmq_context.return_value.socket.return_value fake_socket.send_json = AsyncMock() # recv_json will never resolve, simulate timeout fake_socket.recv_json = AsyncMock(side_effect=asyncio.TimeoutError) + fake_socket.send_multipart = AsyncMock() agent = RICommunicationAgent("test@server", "password") - agent.req_socket = fake_socket + agent._req_socket = fake_socket + agent.connected = True behaviour = agent.ListenBehaviour() agent.add_behaviour(behaviour) - with caplog.at_level("INFO"): - await behaviour.run() - - assert "No ping retrieved in 3 seconds" in caplog.text + await behaviour.run() + assert not any(isinstance(b, agent.ListenBehaviour) for b in agent.behaviours) + assert not agent.connected @pytest.mark.asyncio -async def test_listen_behaviour_ping_no_endpoint(caplog): +async def test_listen_behaviour_ping_no_endpoint(): """ Test if our listen behaviour can work with wrong messages (wrong endpoint) """ fake_socket = AsyncMock() fake_socket.send_json = AsyncMock() + fake_socket.send_multipart = AsyncMock() # This is a message without endpoint >:( fake_socket.recv_json = AsyncMock( @@ -481,43 +507,45 @@ async def test_listen_behaviour_ping_no_endpoint(caplog): ) agent = RICommunicationAgent("test@server", "password") - agent.req_socket = fake_socket + agent._req_socket = fake_socket + agent.connected = True behaviour = agent.ListenBehaviour() agent.add_behaviour(behaviour) - # Run once (CyclicBehaviour normally loops) - with caplog.at_level("ERROR"): - await behaviour.run() + await behaviour.run() - assert "No received endpoint in message, excepted ping endpoint." in caplog.text fake_socket.send_json.assert_awaited() fake_socket.recv_json.assert_awaited() @pytest.mark.asyncio -async def test_setup_unexpected_exception(zmq_context, caplog): +async def test_setup_unexpected_exception(zmq_context): fake_socket = zmq_context.return_value.socket.return_value fake_socket.send_json = AsyncMock() # Simulate unexpected exception during recv_json() fake_socket.recv_json = AsyncMock(side_effect=Exception("boom!")) + fake_socket.send_multipart = AsyncMock() agent = RICommunicationAgent( - "test@server", "password", address="tcp://localhost:5555", bind=False + "test@server", + "password", + address="tcp://localhost:5555", + bind=False, ) - with caplog.at_level("ERROR"): - await agent.setup(max_retries=1) + await agent.setup(max_retries=1) - # Ensure that the error was logged - assert "Unexpected error during negotiation: boom!" in caplog.text + assert not any(isinstance(b, agent.ListenBehaviour) for b in agent.behaviours) + assert not agent.connected @pytest.mark.asyncio -async def test_setup_unpacking_exception(zmq_context, caplog): +async def test_setup_unpacking_exception(zmq_context): # --- Arrange --- fake_socket = zmq_context.return_value.socket.return_value fake_socket.send_json = AsyncMock() + fake_socket.send_multipart = AsyncMock() # Make recv_json return malformed negotiation data to trigger unpacking exception malformed_data = { @@ -534,15 +562,15 @@ async def test_setup_unpacking_exception(zmq_context, caplog): fake_agent_instance.start = AsyncMock() agent = RICommunicationAgent( - "test@server", "password", address="tcp://localhost:5555", bind=False + "test@server", + "password", + address="tcp://localhost:5555", + bind=False, ) # --- Act & Assert --- - with caplog.at_level("ERROR"): - await agent.setup(max_retries=1) - # Ensure the unpacking exception was logged - assert "Error unpacking negotiation data" in caplog.text + await agent.setup(max_retries=1) # Ensure no command agent was started fake_agent_instance.start.assert_not_awaited() diff --git a/test/integration/api/endpoints/test_command_endpoint.py b/test/integration/api/endpoints/test_command_endpoint.py deleted file mode 100644 index 1c9213a..0000000 --- a/test/integration/api/endpoints/test_command_endpoint.py +++ /dev/null @@ -1,61 +0,0 @@ -from unittest.mock import AsyncMock - -import pytest -from fastapi import FastAPI -from fastapi.testclient import TestClient - -from control_backend.api.v1.endpoints import command -from control_backend.schemas.ri_message import SpeechCommand - - -@pytest.fixture -def app(): - """ - Creates a FastAPI test app and attaches the router under test. - Also sets up a mock internal_comm_socket. - """ - app = FastAPI() - app.include_router(command.router) - return app - - -@pytest.fixture -def client(app): - """Create a test client for the app.""" - return TestClient(app) - - -def test_receive_command_success(client): - """ - Test for successful reception of a command. Ensures the status code is 202 and the response body - is correct. It also verifies that the ZeroMQ socket's send_multipart method is called with the - expected data. - """ - # Arrange - mock_pub_socket = AsyncMock() - client.app.state.endpoints_pub_socket = mock_pub_socket - - command_data = {"endpoint": "actuate/speech", "data": "This is a test"} - speech_command = SpeechCommand(**command_data) - - # Act - response = client.post("/command", json=command_data) - - # Assert - assert response.status_code == 202 - assert response.json() == {"status": "Command received"} - - # Verify that the ZMQ socket was used correctly - mock_pub_socket.send_multipart.assert_awaited_once_with( - [b"command", speech_command.model_dump_json().encode()] - ) - - -def test_receive_command_invalid_payload(client): - """ - Test invalid data handling (schema validation). - """ - # Missing required field(s) - bad_payload = {"invalid": "data"} - response = client.post("/command", json=bad_payload) - assert response.status_code == 422 # validation error diff --git a/test/integration/api/endpoints/test_robot_endpoint.py b/test/integration/api/endpoints/test_robot_endpoint.py new file mode 100644 index 0000000..0f71951 --- /dev/null +++ b/test/integration/api/endpoints/test_robot_endpoint.py @@ -0,0 +1,156 @@ +import json +from unittest.mock import AsyncMock, MagicMock + +import pytest +from fastapi import FastAPI +from fastapi.testclient import TestClient + +from control_backend.api.v1.endpoints import robot +from control_backend.schemas.ri_message import SpeechCommand + + +@pytest.fixture +def app(): + """ + Creates a FastAPI test app and attaches the router under test. + Also sets up a mock internal_comm_socket. + """ + app = FastAPI() + app.include_router(robot.router) + return app + + +@pytest.fixture +def client(app): + """Create a test client for the app.""" + return TestClient(app) + + +def test_receive_command_success(client): + """ + Test for successful reception of a command. Ensures the status code is 202 and the response body + is correct. It also verifies that the ZeroMQ socket's send_multipart method is called with the + expected data. + """ + # Arrange + mock_pub_socket = AsyncMock() + client.app.state.endpoints_pub_socket = mock_pub_socket + + command_data = {"endpoint": "actuate/speech", "data": "This is a test"} + speech_command = SpeechCommand(**command_data) + + # Act + response = client.post("/command", json=command_data) + + # Assert + assert response.status_code == 202 + assert response.json() == {"status": "Command received"} + + # Verify that the ZMQ socket was used correctly + mock_pub_socket.send_multipart.assert_awaited_once_with( + [b"command", speech_command.model_dump_json().encode()] + ) + + +def test_receive_command_invalid_payload(client): + """ + Test invalid data handling (schema validation). + """ + # Missing required field(s) + bad_payload = {"invalid": "data"} + response = client.post("/command", json=bad_payload) + assert response.status_code == 422 # validation error + + +def test_ping_check_returns_none(client): + """Ensure /ping_check returns 200 and None (currently unimplemented).""" + response = client.get("/ping_check") + assert response.status_code == 200 + assert response.json() is None + + +@pytest.mark.asyncio +async def test_ping_stream_yields_ping_event(monkeypatch): + """Test that ping_stream yields a proper SSE message when a ping is received.""" + mock_sub_socket = AsyncMock() + mock_sub_socket.connect = MagicMock() + mock_sub_socket.setsockopt = MagicMock() + mock_sub_socket.recv_multipart = AsyncMock(return_value=[b"ping", b"true"]) + + mock_context = MagicMock() + mock_context.socket.return_value = mock_sub_socket + monkeypatch.setattr(robot.Context, "instance", lambda: mock_context) + + mock_request = AsyncMock() + mock_request.is_disconnected = AsyncMock(side_effect=[False, True]) + + response = await robot.ping_stream(mock_request) + generator = aiter(response.body_iterator) + + event = await anext(generator) + event_text = event.decode() if isinstance(event, bytes) else str(event) + assert event_text.strip() == "data: true" + + with pytest.raises(StopAsyncIteration): + await anext(generator) + + mock_sub_socket.connect.assert_called_once() + mock_sub_socket.setsockopt.assert_called_once_with(robot.zmq.SUBSCRIBE, b"ping") + mock_sub_socket.recv_multipart.assert_awaited() + + +@pytest.mark.asyncio +async def test_ping_stream_handles_timeout(monkeypatch): + """Test that ping_stream continues looping on TimeoutError.""" + mock_sub_socket = AsyncMock() + mock_sub_socket.connect = MagicMock() + mock_sub_socket.setsockopt = MagicMock() + mock_sub_socket.recv_multipart.side_effect = TimeoutError() + + mock_context = MagicMock() + mock_context.socket.return_value = mock_sub_socket + monkeypatch.setattr(robot.Context, "instance", lambda: mock_context) + + mock_request = AsyncMock() + mock_request.is_disconnected = AsyncMock(return_value=True) + + response = await robot.ping_stream(mock_request) + generator = aiter(response.body_iterator) + + with pytest.raises(StopAsyncIteration): + await anext(generator) + + mock_sub_socket.connect.assert_called_once() + mock_sub_socket.setsockopt.assert_called_once_with(robot.zmq.SUBSCRIBE, b"ping") + mock_sub_socket.recv_multipart.assert_awaited() + + +@pytest.mark.asyncio +async def test_ping_stream_yields_json_values(monkeypatch): + """Ensure ping_stream correctly parses and yields JSON body values.""" + mock_sub_socket = AsyncMock() + mock_sub_socket.connect = MagicMock() + mock_sub_socket.setsockopt = MagicMock() + mock_sub_socket.recv_multipart = AsyncMock( + return_value=[b"ping", json.dumps({"connected": True}).encode()] + ) + + mock_context = MagicMock() + mock_context.socket.return_value = mock_sub_socket + monkeypatch.setattr(robot.Context, "instance", lambda: mock_context) + + mock_request = AsyncMock() + mock_request.is_disconnected = AsyncMock(side_effect=[False, True]) + + response = await robot.ping_stream(mock_request) + generator = aiter(response.body_iterator) + + event = await anext(generator) + event_text = event.decode() if isinstance(event, bytes) else str(event) + + assert "connected" in event_text + assert "true" in event_text + + mock_sub_socket.connect.assert_called_once() + mock_sub_socket.setsockopt.assert_called_once_with(robot.zmq.SUBSCRIBE, b"ping") + mock_sub_socket.recv_multipart.assert_awaited()