chore: adjust message/command structure and write unit tests
ref: N25B-205
This commit is contained in:
@@ -6,7 +6,7 @@ import zmq
|
|||||||
|
|
||||||
from control_backend.core.config import settings
|
from control_backend.core.config import settings
|
||||||
from control_backend.core.zmq_context import context
|
from control_backend.core.zmq_context import context
|
||||||
from control_backend.schemas.message import Message
|
from control_backend.schemas.ri_message import SpeechCommand
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -33,9 +33,8 @@ class RICommandAgent(Agent):
|
|||||||
# Try to get body
|
# Try to get body
|
||||||
try:
|
try:
|
||||||
message_json = json.loads(body.decode("utf-8"))
|
message_json = json.loads(body.decode("utf-8"))
|
||||||
message = Message.model_validate(message_json)
|
message = SpeechCommand.model_validate(message_json)
|
||||||
logger.info("Received message \"%s\"", message.message)
|
|
||||||
|
|
||||||
# Send to the robot.
|
# Send to the robot.
|
||||||
await self.agent.pubsocket.send_json(message)
|
await self.agent.pubsocket.send_json(message)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@@ -44,8 +44,7 @@ class RICommunicationAgent(Agent):
|
|||||||
logger.info("No ping retrieved in 3 seconds, killing myself.")
|
logger.info("No ping retrieved in 3 seconds, killing myself.")
|
||||||
self.kill()
|
self.kill()
|
||||||
|
|
||||||
# message = Message.model_validate(message)
|
logger.debug("Received message \"%s\"", message)
|
||||||
logger.info("Received message \"%s\"", message)
|
|
||||||
if "endpoint" not in message:
|
if "endpoint" not in message:
|
||||||
logger.error("No received endpoint in message, excepted ping endpoint.")
|
logger.error("No received endpoint in message, excepted ping endpoint.")
|
||||||
return
|
return
|
||||||
|
|||||||
23
src/control_backend/api/v1/endpoints/command.py
Normal file
23
src/control_backend/api/v1/endpoints/command.py
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
from fastapi import APIRouter, Request
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from zmq import Socket
|
||||||
|
|
||||||
|
from control_backend.schemas.message import Message
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
@router.post("/message", status_code=202)
|
||||||
|
async def receive_message(message: Message, request: Request):
|
||||||
|
logger.info("Received message: %s", message.message)
|
||||||
|
|
||||||
|
topic = b"message"
|
||||||
|
body = message.model_dump_json().encode("utf-8")
|
||||||
|
|
||||||
|
pub_socket: Socket = request.app.state.internal_comm_socket
|
||||||
|
|
||||||
|
pub_socket.send_multipart([topic, body])
|
||||||
|
|
||||||
|
return {"status": "Message received"}
|
||||||
20
src/control_backend/schemas/ri_message.py
Normal file
20
src/control_backend/schemas/ri_message.py
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
from enum import Enum
|
||||||
|
from typing import Any, Literal
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field, ValidationError
|
||||||
|
|
||||||
|
|
||||||
|
class RIEndpoint(str, Enum):
|
||||||
|
SPEECH = "actuate/speech"
|
||||||
|
PING = "ping"
|
||||||
|
NEGOTIATE_PORTS = "negotiate/ports"
|
||||||
|
|
||||||
|
|
||||||
|
class RIMessage(BaseModel):
|
||||||
|
endpoint: RIEndpoint
|
||||||
|
data: Any
|
||||||
|
|
||||||
|
|
||||||
|
class SpeechCommand(RIMessage):
|
||||||
|
endpoint: RIEndpoint = RIEndpoint(RIEndpoint.SPEECH)
|
||||||
|
data: str
|
||||||
@@ -4,7 +4,7 @@ import json
|
|||||||
import pytest
|
import pytest
|
||||||
from unittest.mock import AsyncMock, MagicMock, patch
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
from control_backend.agents.ri_command_agent import RICommandAgent
|
from control_backend.agents.ri_command_agent import RICommandAgent
|
||||||
from control_backend.schemas.message import Message
|
from control_backend.schemas.ri_message import SpeechCommand
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_setup_bind(monkeypatch):
|
async def test_setup_bind(monkeypatch):
|
||||||
@@ -41,7 +41,7 @@ async def test_setup_connect(monkeypatch):
|
|||||||
fake_socket.connect.assert_any_call("tcp://localhost:5555")
|
fake_socket.connect.assert_any_call("tcp://localhost:5555")
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_send_commands_behaviour_valid_message(caplog):
|
async def test_send_commands_behaviour_valid_message():
|
||||||
"""Test behaviour with valid JSON message"""
|
"""Test behaviour with valid JSON message"""
|
||||||
fake_socket = AsyncMock()
|
fake_socket = AsyncMock()
|
||||||
message_dict = {"message": "hello"}
|
message_dict = {"message": "hello"}
|
||||||
@@ -55,12 +55,14 @@ async def test_send_commands_behaviour_valid_message(caplog):
|
|||||||
behaviour = agent.SendCommandsBehaviour()
|
behaviour = agent.SendCommandsBehaviour()
|
||||||
behaviour.agent = agent
|
behaviour.agent = agent
|
||||||
|
|
||||||
with caplog.at_level("INFO"):
|
with patch('control_backend.agents.ri_command_agent.SpeechCommand') as MockSpeechCommand:
|
||||||
|
mock_message = MagicMock()
|
||||||
|
MockSpeechCommand.model_validate.return_value = mock_message
|
||||||
|
|
||||||
await behaviour.run()
|
await behaviour.run()
|
||||||
|
|
||||||
fake_socket.recv_multipart.assert_awaited()
|
fake_socket.recv_multipart.assert_awaited()
|
||||||
fake_socket.send_json.assert_awaited()
|
fake_socket.send_json.assert_awaited_with(mock_message)
|
||||||
assert "Received message" in caplog.text
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_send_commands_behaviour_invalid_message(caplog):
|
async def test_send_commands_behaviour_invalid_message(caplog):
|
||||||
|
|||||||
@@ -358,7 +358,7 @@ async def test_listen_behaviour_ping_correct(caplog):
|
|||||||
agent.add_behaviour(behaviour)
|
agent.add_behaviour(behaviour)
|
||||||
|
|
||||||
# Run once (CyclicBehaviour normally loops)
|
# Run once (CyclicBehaviour normally loops)
|
||||||
with caplog.at_level("INFO"):
|
with caplog.at_level("DEBUG"):
|
||||||
await behaviour.run()
|
await behaviour.run()
|
||||||
|
|
||||||
fake_socket.send_json.assert_awaited()
|
fake_socket.send_json.assert_awaited()
|
||||||
|
|||||||
35
test/unit/schemas/test_ri_message.py
Normal file
35
test/unit/schemas/test_ri_message.py
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
import pytest
|
||||||
|
from control_backend.schemas.ri_message import RIMessage, RIEndpoint, SpeechCommand
|
||||||
|
from pydantic import ValidationError
|
||||||
|
|
||||||
|
def valid_command_1():
|
||||||
|
return SpeechCommand(data="Hallo?")
|
||||||
|
|
||||||
|
def invalid_command_1():
|
||||||
|
return RIMessage(endpoint=RIEndpoint.PING, data="Hello again.")
|
||||||
|
|
||||||
|
def test_valid_speech_command_1():
|
||||||
|
command = valid_command_1()
|
||||||
|
try:
|
||||||
|
RIMessage.model_validate(command)
|
||||||
|
SpeechCommand.model_validate(command)
|
||||||
|
assert True
|
||||||
|
except ValidationError:
|
||||||
|
assert False
|
||||||
|
|
||||||
|
|
||||||
|
def test_invalid_speech_command_1():
|
||||||
|
command = invalid_command_1()
|
||||||
|
passed_ri_message_validation = False
|
||||||
|
try:
|
||||||
|
# Should succeed, still.
|
||||||
|
RIMessage.model_validate(command)
|
||||||
|
passed_ri_message_validation = True
|
||||||
|
|
||||||
|
# Should fail.
|
||||||
|
SpeechCommand.model_validate(command)
|
||||||
|
assert False
|
||||||
|
except ValidationError:
|
||||||
|
assert passed_ri_message_validation
|
||||||
|
|
||||||
|
|
||||||
Reference in New Issue
Block a user