From 1f8d7697626adea4db8d9ae17ecac7bfece645eb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bj=C3=B6rn=20Otgaar?= Date: Thu, 23 Oct 2025 12:54:53 +0200 Subject: [PATCH] chore: adjust message/command structure and write unit tests ref: N25B-205 --- .../agents/ri_command_agent.py | 7 ++-- .../agents/ri_communication_agent.py | 3 +- .../api/v1/endpoints/command.py | 23 ++++++++++++ src/control_backend/schemas/ri_message.py | 20 +++++++++++ test/unit/agents/test_ri_commands_agent.py | 12 ++++--- .../agents/test_ri_communication_agent.py | 2 +- test/unit/schemas/test_ri_message.py | 35 +++++++++++++++++++ 7 files changed, 90 insertions(+), 12 deletions(-) create mode 100644 src/control_backend/api/v1/endpoints/command.py create mode 100644 src/control_backend/schemas/ri_message.py create mode 100644 test/unit/schemas/test_ri_message.py diff --git a/src/control_backend/agents/ri_command_agent.py b/src/control_backend/agents/ri_command_agent.py index a5aeda3..b11ba01 100644 --- a/src/control_backend/agents/ri_command_agent.py +++ b/src/control_backend/agents/ri_command_agent.py @@ -6,7 +6,7 @@ import zmq from control_backend.core.config import settings from control_backend.core.zmq_context import context -from control_backend.schemas.message import Message +from control_backend.schemas.ri_message import SpeechCommand logger = logging.getLogger(__name__) @@ -33,9 +33,8 @@ class RICommandAgent(Agent): # Try to get body try: message_json = json.loads(body.decode("utf-8")) - message = Message.model_validate(message_json) - logger.info("Received message \"%s\"", message.message) - + message = SpeechCommand.model_validate(message_json) + # Send to the robot. await self.agent.pubsocket.send_json(message) except Exception as e: diff --git a/src/control_backend/agents/ri_communication_agent.py b/src/control_backend/agents/ri_communication_agent.py index e9374a6..2033857 100644 --- a/src/control_backend/agents/ri_communication_agent.py +++ b/src/control_backend/agents/ri_communication_agent.py @@ -44,8 +44,7 @@ class RICommunicationAgent(Agent): logger.info("No ping retrieved in 3 seconds, killing myself.") self.kill() - # message = Message.model_validate(message) - logger.info("Received message \"%s\"", message) + logger.debug("Received message \"%s\"", message) if "endpoint" not in message: logger.error("No received endpoint in message, excepted ping endpoint.") return diff --git a/src/control_backend/api/v1/endpoints/command.py b/src/control_backend/api/v1/endpoints/command.py new file mode 100644 index 0000000..fef07b8 --- /dev/null +++ b/src/control_backend/api/v1/endpoints/command.py @@ -0,0 +1,23 @@ +from fastapi import APIRouter, Request +import logging + +from zmq import Socket + +from control_backend.schemas.message import Message + +logger = logging.getLogger(__name__) + +router = APIRouter() + +@router.post("/message", status_code=202) +async def receive_message(message: Message, request: Request): + logger.info("Received message: %s", message.message) + + topic = b"message" + body = message.model_dump_json().encode("utf-8") + + pub_socket: Socket = request.app.state.internal_comm_socket + + pub_socket.send_multipart([topic, body]) + + return {"status": "Message received"} diff --git a/src/control_backend/schemas/ri_message.py b/src/control_backend/schemas/ri_message.py new file mode 100644 index 0000000..b369703 --- /dev/null +++ b/src/control_backend/schemas/ri_message.py @@ -0,0 +1,20 @@ +from enum import Enum +from typing import Any, Literal + +from pydantic import BaseModel, Field, ValidationError + + +class RIEndpoint(str, Enum): + SPEECH = "actuate/speech" + PING = "ping" + NEGOTIATE_PORTS = "negotiate/ports" + + +class RIMessage(BaseModel): + endpoint: RIEndpoint + data: Any + + +class SpeechCommand(RIMessage): + endpoint: RIEndpoint = RIEndpoint(RIEndpoint.SPEECH) + data: str diff --git a/test/unit/agents/test_ri_commands_agent.py b/test/unit/agents/test_ri_commands_agent.py index fc5f4aa..4ed8dc1 100644 --- a/test/unit/agents/test_ri_commands_agent.py +++ b/test/unit/agents/test_ri_commands_agent.py @@ -4,7 +4,7 @@ import json import pytest from unittest.mock import AsyncMock, MagicMock, patch from control_backend.agents.ri_command_agent import RICommandAgent -from control_backend.schemas.message import Message +from control_backend.schemas.ri_message import SpeechCommand @pytest.mark.asyncio async def test_setup_bind(monkeypatch): @@ -41,7 +41,7 @@ async def test_setup_connect(monkeypatch): fake_socket.connect.assert_any_call("tcp://localhost:5555") @pytest.mark.asyncio -async def test_send_commands_behaviour_valid_message(caplog): +async def test_send_commands_behaviour_valid_message(): """Test behaviour with valid JSON message""" fake_socket = AsyncMock() message_dict = {"message": "hello"} @@ -55,12 +55,14 @@ async def test_send_commands_behaviour_valid_message(caplog): behaviour = agent.SendCommandsBehaviour() behaviour.agent = agent - with caplog.at_level("INFO"): + with patch('control_backend.agents.ri_command_agent.SpeechCommand') as MockSpeechCommand: + mock_message = MagicMock() + MockSpeechCommand.model_validate.return_value = mock_message + await behaviour.run() fake_socket.recv_multipart.assert_awaited() - fake_socket.send_json.assert_awaited() - assert "Received message" in caplog.text + fake_socket.send_json.assert_awaited_with(mock_message) @pytest.mark.asyncio async def test_send_commands_behaviour_invalid_message(caplog): diff --git a/test/unit/agents/test_ri_communication_agent.py b/test/unit/agents/test_ri_communication_agent.py index c14a6d8..8228608 100644 --- a/test/unit/agents/test_ri_communication_agent.py +++ b/test/unit/agents/test_ri_communication_agent.py @@ -358,7 +358,7 @@ async def test_listen_behaviour_ping_correct(caplog): agent.add_behaviour(behaviour) # Run once (CyclicBehaviour normally loops) - with caplog.at_level("INFO"): + with caplog.at_level("DEBUG"): await behaviour.run() fake_socket.send_json.assert_awaited() diff --git a/test/unit/schemas/test_ri_message.py b/test/unit/schemas/test_ri_message.py new file mode 100644 index 0000000..b840f97 --- /dev/null +++ b/test/unit/schemas/test_ri_message.py @@ -0,0 +1,35 @@ +import pytest +from control_backend.schemas.ri_message import RIMessage, RIEndpoint, SpeechCommand +from pydantic import ValidationError + +def valid_command_1(): + return SpeechCommand(data="Hallo?") + +def invalid_command_1(): + return RIMessage(endpoint=RIEndpoint.PING, data="Hello again.") + +def test_valid_speech_command_1(): + command = valid_command_1() + try: + RIMessage.model_validate(command) + SpeechCommand.model_validate(command) + assert True + except ValidationError: + assert False + + +def test_invalid_speech_command_1(): + command = invalid_command_1() + passed_ri_message_validation = False + try: + # Should succeed, still. + RIMessage.model_validate(command) + passed_ri_message_validation = True + + # Should fail. + SpeechCommand.model_validate(command) + assert False + except ValidationError: + assert passed_ri_message_validation + + \ No newline at end of file