chore: adjust message/command structure and write unit tests
ref: N25B-205
This commit is contained in:
@@ -6,7 +6,7 @@ import zmq
|
||||
|
||||
from control_backend.core.config import settings
|
||||
from control_backend.core.zmq_context import context
|
||||
from control_backend.schemas.message import Message
|
||||
from control_backend.schemas.ri_message import SpeechCommand
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -33,8 +33,7 @@ class RICommandAgent(Agent):
|
||||
# Try to get body
|
||||
try:
|
||||
message_json = json.loads(body.decode("utf-8"))
|
||||
message = Message.model_validate(message_json)
|
||||
logger.info("Received message \"%s\"", message.message)
|
||||
message = SpeechCommand.model_validate(message_json)
|
||||
|
||||
# Send to the robot.
|
||||
await self.agent.pubsocket.send_json(message)
|
||||
|
||||
@@ -44,8 +44,7 @@ class RICommunicationAgent(Agent):
|
||||
logger.info("No ping retrieved in 3 seconds, killing myself.")
|
||||
self.kill()
|
||||
|
||||
# message = Message.model_validate(message)
|
||||
logger.info("Received message \"%s\"", message)
|
||||
logger.debug("Received message \"%s\"", message)
|
||||
if "endpoint" not in message:
|
||||
logger.error("No received endpoint in message, excepted ping endpoint.")
|
||||
return
|
||||
|
||||
23
src/control_backend/api/v1/endpoints/command.py
Normal file
23
src/control_backend/api/v1/endpoints/command.py
Normal file
@@ -0,0 +1,23 @@
|
||||
from fastapi import APIRouter, Request
|
||||
import logging
|
||||
|
||||
from zmq import Socket
|
||||
|
||||
from control_backend.schemas.message import Message
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@router.post("/message", status_code=202)
|
||||
async def receive_message(message: Message, request: Request):
|
||||
logger.info("Received message: %s", message.message)
|
||||
|
||||
topic = b"message"
|
||||
body = message.model_dump_json().encode("utf-8")
|
||||
|
||||
pub_socket: Socket = request.app.state.internal_comm_socket
|
||||
|
||||
pub_socket.send_multipart([topic, body])
|
||||
|
||||
return {"status": "Message received"}
|
||||
20
src/control_backend/schemas/ri_message.py
Normal file
20
src/control_backend/schemas/ri_message.py
Normal file
@@ -0,0 +1,20 @@
|
||||
from enum import Enum
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel, Field, ValidationError
|
||||
|
||||
|
||||
class RIEndpoint(str, Enum):
|
||||
SPEECH = "actuate/speech"
|
||||
PING = "ping"
|
||||
NEGOTIATE_PORTS = "negotiate/ports"
|
||||
|
||||
|
||||
class RIMessage(BaseModel):
|
||||
endpoint: RIEndpoint
|
||||
data: Any
|
||||
|
||||
|
||||
class SpeechCommand(RIMessage):
|
||||
endpoint: RIEndpoint = RIEndpoint(RIEndpoint.SPEECH)
|
||||
data: str
|
||||
@@ -4,7 +4,7 @@ import json
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from control_backend.agents.ri_command_agent import RICommandAgent
|
||||
from control_backend.schemas.message import Message
|
||||
from control_backend.schemas.ri_message import SpeechCommand
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_setup_bind(monkeypatch):
|
||||
@@ -41,7 +41,7 @@ async def test_setup_connect(monkeypatch):
|
||||
fake_socket.connect.assert_any_call("tcp://localhost:5555")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_commands_behaviour_valid_message(caplog):
|
||||
async def test_send_commands_behaviour_valid_message():
|
||||
"""Test behaviour with valid JSON message"""
|
||||
fake_socket = AsyncMock()
|
||||
message_dict = {"message": "hello"}
|
||||
@@ -55,12 +55,14 @@ async def test_send_commands_behaviour_valid_message(caplog):
|
||||
behaviour = agent.SendCommandsBehaviour()
|
||||
behaviour.agent = agent
|
||||
|
||||
with caplog.at_level("INFO"):
|
||||
with patch('control_backend.agents.ri_command_agent.SpeechCommand') as MockSpeechCommand:
|
||||
mock_message = MagicMock()
|
||||
MockSpeechCommand.model_validate.return_value = mock_message
|
||||
|
||||
await behaviour.run()
|
||||
|
||||
fake_socket.recv_multipart.assert_awaited()
|
||||
fake_socket.send_json.assert_awaited()
|
||||
assert "Received message" in caplog.text
|
||||
fake_socket.send_json.assert_awaited_with(mock_message)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_commands_behaviour_invalid_message(caplog):
|
||||
|
||||
@@ -358,7 +358,7 @@ async def test_listen_behaviour_ping_correct(caplog):
|
||||
agent.add_behaviour(behaviour)
|
||||
|
||||
# Run once (CyclicBehaviour normally loops)
|
||||
with caplog.at_level("INFO"):
|
||||
with caplog.at_level("DEBUG"):
|
||||
await behaviour.run()
|
||||
|
||||
fake_socket.send_json.assert_awaited()
|
||||
|
||||
35
test/unit/schemas/test_ri_message.py
Normal file
35
test/unit/schemas/test_ri_message.py
Normal file
@@ -0,0 +1,35 @@
|
||||
import pytest
|
||||
from control_backend.schemas.ri_message import RIMessage, RIEndpoint, SpeechCommand
|
||||
from pydantic import ValidationError
|
||||
|
||||
def valid_command_1():
|
||||
return SpeechCommand(data="Hallo?")
|
||||
|
||||
def invalid_command_1():
|
||||
return RIMessage(endpoint=RIEndpoint.PING, data="Hello again.")
|
||||
|
||||
def test_valid_speech_command_1():
|
||||
command = valid_command_1()
|
||||
try:
|
||||
RIMessage.model_validate(command)
|
||||
SpeechCommand.model_validate(command)
|
||||
assert True
|
||||
except ValidationError:
|
||||
assert False
|
||||
|
||||
|
||||
def test_invalid_speech_command_1():
|
||||
command = invalid_command_1()
|
||||
passed_ri_message_validation = False
|
||||
try:
|
||||
# Should succeed, still.
|
||||
RIMessage.model_validate(command)
|
||||
passed_ri_message_validation = True
|
||||
|
||||
# Should fail.
|
||||
SpeechCommand.model_validate(command)
|
||||
assert False
|
||||
except ValidationError:
|
||||
assert passed_ri_message_validation
|
||||
|
||||
|
||||
Reference in New Issue
Block a user