style: compacted program and reworked tests
ref: N25B-198
This commit is contained in:
@@ -1,10 +1,10 @@
|
|||||||
import json
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from fastapi import APIRouter, HTTPException, Request
|
from fastapi import APIRouter, HTTPException, Request
|
||||||
|
from pydantic import ValidationError
|
||||||
|
|
||||||
from control_backend.schemas.message import Message
|
from control_backend.schemas.message import Message
|
||||||
from control_backend.schemas.program import Phase
|
from control_backend.schemas.program import Program
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
@@ -16,37 +16,20 @@ async def receive_message(program: Message, request: Request):
|
|||||||
Receives a BehaviorProgram as a stringified JSON list inside `message`.
|
Receives a BehaviorProgram as a stringified JSON list inside `message`.
|
||||||
Converts it into real Phase objects.
|
Converts it into real Phase objects.
|
||||||
"""
|
"""
|
||||||
logger.info("Received raw program: ")
|
logger.debug("Received raw program: %s", program)
|
||||||
logger.debug("%s", program)
|
|
||||||
raw_str = program.message # This is the JSON string
|
raw_str = program.message # This is the JSON string
|
||||||
|
|
||||||
# Convert Json into dict.
|
# Validate program
|
||||||
try:
|
try:
|
||||||
program_list = json.loads(raw_str)
|
program = Program.model_validate_json(raw_str)
|
||||||
except json.JSONDecodeError as e:
|
except ValidationError as e:
|
||||||
logger.error("Failed to decode program JSON: %s", e)
|
logger.error("Failed to validate program JSON: %s", e)
|
||||||
raise HTTPException(status_code=400, detail="Undecodeable Json string") from None
|
raise HTTPException(status_code=400, detail="Not a valid program") from None
|
||||||
|
|
||||||
# Validate Phases
|
|
||||||
try:
|
|
||||||
phases: list[Phase] = [Phase(**phase) for phase in program_list]
|
|
||||||
except Exception as e:
|
|
||||||
logger.error("Failed to convert to Phase objects: %s", e)
|
|
||||||
raise HTTPException(status_code=400, detail="Non-Phase String") from None
|
|
||||||
|
|
||||||
logger.info(f"Succesfully recieved {len(phases)} Phase(s).")
|
|
||||||
for p in phases:
|
|
||||||
logger.info(
|
|
||||||
f"Phase {p.id}: "
|
|
||||||
f"{len(p.phaseData.norms)} norms, "
|
|
||||||
f"{len(p.phaseData.goals)} goals, "
|
|
||||||
f"{len(p.phaseData.triggers) if hasattr(p.phaseData, 'triggers') else 0} triggers"
|
|
||||||
)
|
|
||||||
|
|
||||||
# send away
|
# send away
|
||||||
topic = b"program"
|
topic = b"program"
|
||||||
body = json.dumps([p.model_dump() for p in phases]).encode("utf-8")
|
body = program.model_dump_json().encode()
|
||||||
pub_socket = request.app.state.endpoints_pub_socket
|
pub_socket = request.app.state.endpoints_pub_socket
|
||||||
await pub_socket.send_multipart([topic, body])
|
await pub_socket.send_multipart([topic, body])
|
||||||
|
|
||||||
return {"status": "Program parsed", "phase_count": len(phases)}
|
return {"status": "Program parsed"}
|
||||||
|
|||||||
@@ -5,8 +5,9 @@ import pytest
|
|||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
from fastapi.testclient import TestClient
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
from control_backend.api.v1.endpoints import program # <-- import your router
|
from control_backend.api.v1.endpoints import program
|
||||||
from control_backend.schemas.message import Message
|
from control_backend.schemas.message import Message
|
||||||
|
from control_backend.schemas.program import Program
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@@ -23,30 +24,40 @@ def client(app):
|
|||||||
return TestClient(app)
|
return TestClient(app)
|
||||||
|
|
||||||
|
|
||||||
def make_valid_phase_dict():
|
def make_valid_program_dict():
|
||||||
"""Helper to create a valid Phase JSON structure."""
|
"""Helper to create a valid Program JSON structure."""
|
||||||
return {
|
return {
|
||||||
"id": "phase1",
|
"phases": [
|
||||||
"name": "basephase",
|
{
|
||||||
"nextPhaseId": "phase2",
|
"id": "phase1",
|
||||||
"phaseData": {
|
"name": "basephase",
|
||||||
"norms": [{"id": "n1", "name": "norm", "value": "be nice"}],
|
"nextPhaseId": "phase2",
|
||||||
"goals": [{"id": "g1", "name": "goal", "description": "test goal", "achieved": False}],
|
"phaseData": {
|
||||||
"triggers": [
|
"norms": [{"id": "n1", "name": "norm", "value": "be nice"}],
|
||||||
{"id": "t1", "label": "trigger", "type": "keyword", "value": ["stop", "exit"]}
|
"goals": [
|
||||||
],
|
{"id": "g1", "name": "goal", "description": "test goal", "achieved": False}
|
||||||
},
|
],
|
||||||
|
"triggers": [
|
||||||
|
{
|
||||||
|
"id": "t1",
|
||||||
|
"label": "trigger",
|
||||||
|
"type": "keyword",
|
||||||
|
"value": ["stop", "exit"],
|
||||||
|
}
|
||||||
|
],
|
||||||
|
},
|
||||||
|
}
|
||||||
|
]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def test_receive_program_success(client):
|
def test_receive_program_success(client):
|
||||||
"""Valid program JSON string should parse and be sent via the socket."""
|
"""Valid Program JSON should be parsed and sent through the socket."""
|
||||||
# Arrange
|
|
||||||
mock_pub_socket = AsyncMock()
|
mock_pub_socket = AsyncMock()
|
||||||
client.app.state.endpoints_pub_socket = mock_pub_socket
|
client.app.state.endpoints_pub_socket = mock_pub_socket
|
||||||
|
|
||||||
phases_list = [make_valid_phase_dict()]
|
program_dict = make_valid_program_dict()
|
||||||
message_body = json.dumps(phases_list)
|
message_body = json.dumps(program_dict)
|
||||||
msg = Message(message=message_body)
|
msg = Message(message=message_body)
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
@@ -54,38 +65,67 @@ def test_receive_program_success(client):
|
|||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert response.status_code == 202
|
assert response.status_code == 202
|
||||||
assert response.json() == {"status": "Program parsed", "phase_count": 1}
|
assert response.json() == {"status": "Program parsed"}
|
||||||
|
|
||||||
# Check the mocked socket
|
# Verify socket call (don't compare raw JSON string)
|
||||||
expected_body = json.dumps(phases_list).encode("utf-8")
|
mock_pub_socket.send_multipart.assert_awaited_once()
|
||||||
mock_pub_socket.send_multipart.assert_awaited_once_with([b"program", expected_body])
|
args, kwargs = mock_pub_socket.send_multipart.await_args
|
||||||
|
|
||||||
|
assert args[0][0] == b"program"
|
||||||
|
sent_bytes = args[0][1]
|
||||||
|
|
||||||
|
# Decode sent bytes and compare actual structures
|
||||||
|
sent_obj = json.loads(sent_bytes.decode())
|
||||||
|
expected_obj = Program.model_validate_json(message_body).model_dump()
|
||||||
|
|
||||||
|
assert sent_obj == expected_obj
|
||||||
|
|
||||||
|
|
||||||
def test_receive_program_invalid_json(client):
|
def test_receive_program_invalid_json(client):
|
||||||
"""Malformed JSON string should return 400 with 'Undecodeable Json string'."""
|
"""Invalid JSON string (not parseable) should trigger HTTP 400."""
|
||||||
mock_pub_socket = AsyncMock()
|
mock_pub_socket = AsyncMock()
|
||||||
client.app.state.endpoints_pub_socket = mock_pub_socket
|
client.app.state.endpoints_pub_socket = mock_pub_socket
|
||||||
|
|
||||||
# Not valid JSON
|
bad_json_str = "{invalid json}"
|
||||||
bad_message = Message(message="{not valid json}")
|
msg = Message(message=bad_json_str)
|
||||||
response = client.post("/program", json=bad_message.model_dump())
|
|
||||||
|
response = client.post("/program", json=msg.model_dump())
|
||||||
|
|
||||||
assert response.status_code == 400
|
assert response.status_code == 400
|
||||||
assert response.json()["detail"] == "Undecodeable Json string"
|
assert response.json()["detail"] == "Not a valid program"
|
||||||
mock_pub_socket.send_multipart.assert_not_called()
|
mock_pub_socket.send_multipart.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
def test_receive_program_invalid_phase(client):
|
def test_receive_program_invalid_deep_structure(client):
|
||||||
"""Decodable JSON but invalid Phase structure should return 400 with 'Non-Phase String'."""
|
"""Valid JSON shape but invalid deep nested data should still raise 400."""
|
||||||
mock_pub_socket = AsyncMock()
|
mock_pub_socket = AsyncMock()
|
||||||
client.app.state.endpoints_pub_socket = mock_pub_socket
|
client.app.state.endpoints_pub_socket = mock_pub_socket
|
||||||
|
|
||||||
# Missing required Phase fields
|
# Structurally correct Program, but with missing elements
|
||||||
invalid_phase = [{"id": "only_id"}]
|
bad_program = {
|
||||||
bad_message = Message(message=json.dumps(invalid_phase))
|
"phases": [
|
||||||
|
{
|
||||||
|
"id": "phase1",
|
||||||
|
"name": "deepfail",
|
||||||
|
"nextPhaseId": "phase2",
|
||||||
|
"phaseData": {
|
||||||
|
"norms": [
|
||||||
|
{"id": "n1", "name": "norm"} # Missing "value"
|
||||||
|
],
|
||||||
|
"goals": [
|
||||||
|
{"id": "g1", "name": "goal", "description": "desc", "achieved": False}
|
||||||
|
],
|
||||||
|
"triggers": [
|
||||||
|
{"id": "t1", "label": "trigger", "type": "keyword", "value": ["start"]}
|
||||||
|
],
|
||||||
|
},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
response = client.post("/program", json=bad_message.model_dump())
|
msg = Message(message=json.dumps(bad_program))
|
||||||
|
response = client.post("/program", json=msg.model_dump())
|
||||||
|
|
||||||
assert response.status_code == 400
|
assert response.status_code == 400
|
||||||
assert response.json()["detail"] == "Non-Phase String"
|
assert response.json()["detail"] == "Not a valid program"
|
||||||
mock_pub_socket.send_multipart.assert_not_called()
|
mock_pub_socket.send_multipart.assert_not_called()
|
||||||
|
|||||||
Reference in New Issue
Block a user