From a2a04740e51c978a4baf863abf05a7bfc315b056 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bj=C3=B6rn=20Otgaar?= Date: Thu, 23 Oct 2025 16:45:41 +0200 Subject: [PATCH] chore: add unit test for router and implement command router ref: N25B-205 --- .../agents/ri_command_agent.py | 6 +- src/control_backend/agents/test_agent.py | 37 ----------- .../api/v1/endpoints/command.py | 19 +++--- src/control_backend/api/v1/router.py | 7 ++- .../api/endpoints/test_command_endpoint.py | 62 +++++++++++++++++++ 5 files changed, 79 insertions(+), 52 deletions(-) delete mode 100644 src/control_backend/agents/test_agent.py create mode 100644 test/unit/api/endpoints/test_command_endpoint.py diff --git a/src/control_backend/agents/ri_command_agent.py b/src/control_backend/agents/ri_command_agent.py index b11ba01..7ca1bf9 100644 --- a/src/control_backend/agents/ri_command_agent.py +++ b/src/control_backend/agents/ri_command_agent.py @@ -32,9 +32,9 @@ class RICommandAgent(Agent): # Try to get body try: - message_json = json.loads(body.decode("utf-8")) - message = SpeechCommand.model_validate(message_json) - + message = SpeechCommand.model_validate(body) + + # Send to the robot. await self.agent.pubsocket.send_json(message) except Exception as e: diff --git a/src/control_backend/agents/test_agent.py b/src/control_backend/agents/test_agent.py deleted file mode 100644 index 749c96b..0000000 --- a/src/control_backend/agents/test_agent.py +++ /dev/null @@ -1,37 +0,0 @@ -import json -import logging -from spade.agent import Agent -from spade.behaviour import CyclicBehaviour -import zmq - -from control_backend.core.config import settings -from control_backend.core.zmq_context import context -from control_backend.schemas.message import Message - -logger = logging.getLogger(__name__) - -class TestAgent(Agent): - socket: zmq.Socket - - class ListenBehaviour(CyclicBehaviour): - async def run(self): - assert self.agent is not None - topic, body = await self.agent.socket.recv_multipart() - - try: - message_json = json.loads(body.decode("utf-8")) - message = Message.model_validate(message_json) - logger.info("Received message \"%s\"", message.message) - except Exception as e: - logger.error("Error processing message: %s", e) - - async def setup(self): - logger.info("Setting up %s", self.jid) - self.socket = context.socket(zmq.SUB) - self.socket.connect(settings.zmq_settings.internal_comm_address) - self.socket.setsockopt(zmq.SUBSCRIBE, b"message") - - b = self.ListenBehaviour() - self.add_behaviour(b) - - logger.info("Finished setting up %s", self.jid) diff --git a/src/control_backend/api/v1/endpoints/command.py b/src/control_backend/api/v1/endpoints/command.py index fef07b8..60cdf46 100644 --- a/src/control_backend/api/v1/endpoints/command.py +++ b/src/control_backend/api/v1/endpoints/command.py @@ -3,21 +3,18 @@ import logging from zmq import Socket -from control_backend.schemas.message import Message +from control_backend.schemas.ri_message import SpeechCommand, RIEndpoint logger = logging.getLogger(__name__) router = APIRouter() -@router.post("/message", status_code=202) -async def receive_message(message: Message, request: Request): - logger.info("Received message: %s", message.message) - - topic = b"message" - body = message.model_dump_json().encode("utf-8") - +@router.post("/command", status_code=202) +async def receive_command(command: SpeechCommand, request: Request): + # Validate and retrieve data. + SpeechCommand.model_validate(command) + topic = b"command" pub_socket: Socket = request.app.state.internal_comm_socket + pub_socket.send_multipart([topic, command]) - pub_socket.send_multipart([topic, body]) - - return {"status": "Message received"} + return {"status": "Command received"} diff --git a/src/control_backend/api/v1/router.py b/src/control_backend/api/v1/router.py index 2a17ab5..b7a6d5f 100644 --- a/src/control_backend/api/v1/router.py +++ b/src/control_backend/api/v1/router.py @@ -1,6 +1,6 @@ from fastapi.routing import APIRouter -from control_backend.api.v1.endpoints import message, sse +from control_backend.api.v1.endpoints import message, sse, command api_router = APIRouter() @@ -12,4 +12,9 @@ api_router.include_router( api_router.include_router( sse.router, tags=["SSE"] +) + +api_router.include_router( + command.router, + tags=["Commands"] ) \ No newline at end of file diff --git a/test/unit/api/endpoints/test_command_endpoint.py b/test/unit/api/endpoints/test_command_endpoint.py new file mode 100644 index 0000000..3ab1be3 --- /dev/null +++ b/test/unit/api/endpoints/test_command_endpoint.py @@ -0,0 +1,62 @@ +import pytest +from fastapi import FastAPI +from fastapi.testclient import TestClient +from unittest.mock import MagicMock + +from control_backend.api.v1.endpoints import command +from control_backend.schemas.ri_message import SpeechCommand + +@pytest.fixture +def app(): + """ + Creates a FastAPI test app and attaches the router under test. + Also sets up a mock internal_comm_socket. + """ + app = FastAPI() + app.include_router(command.router) + app.state.internal_comm_socket = MagicMock() # mock ZMQ socket + return app + + +@pytest.fixture +def client(app): + """Create a test client for the app.""" + return TestClient(app) + + +def test_receive_command_endpoint(client, app): + """ + Test that a POST to /command sends the right multipart message + and returns a 202 with the expected JSON body. + """ + mock_socket = app.state.internal_comm_socket + + # Prepare test payload that matches SpeechCommand + payload = {"endpoint": "actuate/speech", "data": "yooo"} + + # Send POST request + response = client.post("/command", json=payload) + + # Check response + assert response.status_code == 202 + assert response.json() == {"status": "Command received"} + + # Verify that the socket was called with the correct data + assert mock_socket.send_multipart.called, "Socket should be used to send data" + + args, kwargs = mock_socket.send_multipart.call_args + sent_data = args[0] + + assert sent_data[0] == b"command" + # Check JSON encoding roughly matches + assert isinstance(sent_data[1], SpeechCommand) + + +def test_receive_command_invalid_payload(client): + """ + Test invalid data handling (schema validation). + """ + # Missing required field(s) + bad_payload = {"invalid": "data"} + response = client.post("/command", json=bad_payload) + assert response.status_code == 422 # validation error \ No newline at end of file