chore: adjust message/command structure and write unit tests

ref: N25B-205
This commit is contained in:
Björn Otgaar
2025-10-23 12:54:53 +02:00
parent 530fc42c50
commit 1f8d769762
7 changed files with 90 additions and 12 deletions

View File

@@ -6,7 +6,7 @@ import zmq
from control_backend.core.config import settings from control_backend.core.config import settings
from control_backend.core.zmq_context import context from control_backend.core.zmq_context import context
from control_backend.schemas.message import Message from control_backend.schemas.ri_message import SpeechCommand
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -33,8 +33,7 @@ class RICommandAgent(Agent):
# Try to get body # Try to get body
try: try:
message_json = json.loads(body.decode("utf-8")) message_json = json.loads(body.decode("utf-8"))
message = Message.model_validate(message_json) message = SpeechCommand.model_validate(message_json)
logger.info("Received message \"%s\"", message.message)
# Send to the robot. # Send to the robot.
await self.agent.pubsocket.send_json(message) await self.agent.pubsocket.send_json(message)

View File

@@ -44,8 +44,7 @@ class RICommunicationAgent(Agent):
logger.info("No ping retrieved in 3 seconds, killing myself.") logger.info("No ping retrieved in 3 seconds, killing myself.")
self.kill() self.kill()
# message = Message.model_validate(message) logger.debug("Received message \"%s\"", message)
logger.info("Received message \"%s\"", message)
if "endpoint" not in message: if "endpoint" not in message:
logger.error("No received endpoint in message, excepted ping endpoint.") logger.error("No received endpoint in message, excepted ping endpoint.")
return return

View File

@@ -0,0 +1,23 @@
from fastapi import APIRouter, Request
import logging
from zmq import Socket
from control_backend.schemas.message import Message
logger = logging.getLogger(__name__)
router = APIRouter()
@router.post("/message", status_code=202)
async def receive_message(message: Message, request: Request):
logger.info("Received message: %s", message.message)
topic = b"message"
body = message.model_dump_json().encode("utf-8")
pub_socket: Socket = request.app.state.internal_comm_socket
pub_socket.send_multipart([topic, body])
return {"status": "Message received"}

View File

@@ -0,0 +1,20 @@
from enum import Enum
from typing import Any, Literal
from pydantic import BaseModel, Field, ValidationError
class RIEndpoint(str, Enum):
SPEECH = "actuate/speech"
PING = "ping"
NEGOTIATE_PORTS = "negotiate/ports"
class RIMessage(BaseModel):
endpoint: RIEndpoint
data: Any
class SpeechCommand(RIMessage):
endpoint: RIEndpoint = RIEndpoint(RIEndpoint.SPEECH)
data: str

View File

@@ -4,7 +4,7 @@ import json
import pytest import pytest
from unittest.mock import AsyncMock, MagicMock, patch from unittest.mock import AsyncMock, MagicMock, patch
from control_backend.agents.ri_command_agent import RICommandAgent from control_backend.agents.ri_command_agent import RICommandAgent
from control_backend.schemas.message import Message from control_backend.schemas.ri_message import SpeechCommand
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_setup_bind(monkeypatch): async def test_setup_bind(monkeypatch):
@@ -41,7 +41,7 @@ async def test_setup_connect(monkeypatch):
fake_socket.connect.assert_any_call("tcp://localhost:5555") fake_socket.connect.assert_any_call("tcp://localhost:5555")
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_send_commands_behaviour_valid_message(caplog): async def test_send_commands_behaviour_valid_message():
"""Test behaviour with valid JSON message""" """Test behaviour with valid JSON message"""
fake_socket = AsyncMock() fake_socket = AsyncMock()
message_dict = {"message": "hello"} message_dict = {"message": "hello"}
@@ -55,12 +55,14 @@ async def test_send_commands_behaviour_valid_message(caplog):
behaviour = agent.SendCommandsBehaviour() behaviour = agent.SendCommandsBehaviour()
behaviour.agent = agent behaviour.agent = agent
with caplog.at_level("INFO"): with patch('control_backend.agents.ri_command_agent.SpeechCommand') as MockSpeechCommand:
mock_message = MagicMock()
MockSpeechCommand.model_validate.return_value = mock_message
await behaviour.run() await behaviour.run()
fake_socket.recv_multipart.assert_awaited() fake_socket.recv_multipart.assert_awaited()
fake_socket.send_json.assert_awaited() fake_socket.send_json.assert_awaited_with(mock_message)
assert "Received message" in caplog.text
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_send_commands_behaviour_invalid_message(caplog): async def test_send_commands_behaviour_invalid_message(caplog):

View File

@@ -358,7 +358,7 @@ async def test_listen_behaviour_ping_correct(caplog):
agent.add_behaviour(behaviour) agent.add_behaviour(behaviour)
# Run once (CyclicBehaviour normally loops) # Run once (CyclicBehaviour normally loops)
with caplog.at_level("INFO"): with caplog.at_level("DEBUG"):
await behaviour.run() await behaviour.run()
fake_socket.send_json.assert_awaited() fake_socket.send_json.assert_awaited()

View File

@@ -0,0 +1,35 @@
import pytest
from control_backend.schemas.ri_message import RIMessage, RIEndpoint, SpeechCommand
from pydantic import ValidationError
def valid_command_1():
return SpeechCommand(data="Hallo?")
def invalid_command_1():
return RIMessage(endpoint=RIEndpoint.PING, data="Hello again.")
def test_valid_speech_command_1():
command = valid_command_1()
try:
RIMessage.model_validate(command)
SpeechCommand.model_validate(command)
assert True
except ValidationError:
assert False
def test_invalid_speech_command_1():
command = invalid_command_1()
passed_ri_message_validation = False
try:
# Should succeed, still.
RIMessage.model_validate(command)
passed_ri_message_validation = True
# Should fail.
SpeechCommand.model_validate(command)
assert False
except ValidationError:
assert passed_ri_message_validation