feat: ui program to cb connection #26

Merged
2584433 merged 5 commits from feat/recieve-programs-ui into dev 2025-11-19 14:21:01 +00:00
2 changed files with 83 additions and 60 deletions
Showing only changes of commit 2ed2a84f13 - Show all commits

View File

@@ -1,10 +1,10 @@
import json
import logging import logging
from fastapi import APIRouter, HTTPException, Request from fastapi import APIRouter, HTTPException, Request
from pydantic import ValidationError
from control_backend.schemas.message import Message from control_backend.schemas.message import Message
from control_backend.schemas.program import Phase from control_backend.schemas.program import Program
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
router = APIRouter() router = APIRouter()
@@ -16,37 +16,20 @@ async def receive_message(program: Message, request: Request):
Receives a BehaviorProgram as a stringified JSON list inside `message`. Receives a BehaviorProgram as a stringified JSON list inside `message`.
Converts it into real Phase objects. Converts it into real Phase objects.
""" """
logger.info("Received raw program: ") logger.debug("Received raw program: %s", program)
logger.debug("%s", program)
raw_str = program.message # This is the JSON string raw_str = program.message # This is the JSON string
# Convert Json into dict. # Validate program
try: try:
program_list = json.loads(raw_str) program = Program.model_validate_json(raw_str)
except json.JSONDecodeError as e: except ValidationError as e:
logger.error("Failed to decode program JSON: %s", e) logger.error("Failed to validate program JSON: %s", e)
raise HTTPException(status_code=400, detail="Undecodeable Json string") from None raise HTTPException(status_code=400, detail="Not a valid program") from None
# Validate Phases
try:
phases: list[Phase] = [Phase(**phase) for phase in program_list]
except Exception as e:
logger.error("Failed to convert to Phase objects: %s", e)
raise HTTPException(status_code=400, detail="Non-Phase String") from None
logger.info(f"Succesfully recieved {len(phases)} Phase(s).")
for p in phases:
logger.info(
f"Phase {p.id}: "
f"{len(p.phaseData.norms)} norms, "
f"{len(p.phaseData.goals)} goals, "
f"{len(p.phaseData.triggers) if hasattr(p.phaseData, 'triggers') else 0} triggers"
)
# send away # send away
topic = b"program" topic = b"program"
body = json.dumps([p.model_dump() for p in phases]).encode("utf-8") body = program.model_dump_json().encode()
pub_socket = request.app.state.endpoints_pub_socket pub_socket = request.app.state.endpoints_pub_socket
await pub_socket.send_multipart([topic, body]) await pub_socket.send_multipart([topic, body])
return {"status": "Program parsed", "phase_count": len(phases)} return {"status": "Program parsed"}

View File

@@ -5,8 +5,9 @@ import pytest
from fastapi import FastAPI from fastapi import FastAPI
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
from control_backend.api.v1.endpoints import program # <-- import your router from control_backend.api.v1.endpoints import program
from control_backend.schemas.message import Message from control_backend.schemas.message import Message
from control_backend.schemas.program import Program
@pytest.fixture @pytest.fixture
@@ -23,30 +24,40 @@ def client(app):
return TestClient(app) return TestClient(app)
def make_valid_phase_dict(): def make_valid_program_dict():
"""Helper to create a valid Phase JSON structure.""" """Helper to create a valid Program JSON structure."""
return { return {
"phases": [
{
"id": "phase1", "id": "phase1",
"name": "basephase", "name": "basephase",
"nextPhaseId": "phase2", "nextPhaseId": "phase2",
"phaseData": { "phaseData": {
"norms": [{"id": "n1", "name": "norm", "value": "be nice"}], "norms": [{"id": "n1", "name": "norm", "value": "be nice"}],
"goals": [{"id": "g1", "name": "goal", "description": "test goal", "achieved": False}], "goals": [
{"id": "g1", "name": "goal", "description": "test goal", "achieved": False}
],
"triggers": [ "triggers": [
{"id": "t1", "label": "trigger", "type": "keyword", "value": ["stop", "exit"]} {
"id": "t1",
"label": "trigger",
"type": "keyword",
"value": ["stop", "exit"],
}
], ],
}, },
} }
]
}
def test_receive_program_success(client): def test_receive_program_success(client):
"""Valid program JSON string should parse and be sent via the socket.""" """Valid Program JSON should be parsed and sent through the socket."""
# Arrange
mock_pub_socket = AsyncMock() mock_pub_socket = AsyncMock()
client.app.state.endpoints_pub_socket = mock_pub_socket client.app.state.endpoints_pub_socket = mock_pub_socket
phases_list = [make_valid_phase_dict()] program_dict = make_valid_program_dict()
message_body = json.dumps(phases_list) message_body = json.dumps(program_dict)
msg = Message(message=message_body) msg = Message(message=message_body)
# Act # Act
@@ -54,38 +65,67 @@ def test_receive_program_success(client):
# Assert # Assert
assert response.status_code == 202 assert response.status_code == 202
assert response.json() == {"status": "Program parsed", "phase_count": 1} assert response.json() == {"status": "Program parsed"}
# Check the mocked socket # Verify socket call (don't compare raw JSON string)
expected_body = json.dumps(phases_list).encode("utf-8") mock_pub_socket.send_multipart.assert_awaited_once()
mock_pub_socket.send_multipart.assert_awaited_once_with([b"program", expected_body]) args, kwargs = mock_pub_socket.send_multipart.await_args
assert args[0][0] == b"program"
sent_bytes = args[0][1]
# Decode sent bytes and compare actual structures
sent_obj = json.loads(sent_bytes.decode())
expected_obj = Program.model_validate_json(message_body).model_dump()
assert sent_obj == expected_obj
def test_receive_program_invalid_json(client): def test_receive_program_invalid_json(client):
"""Malformed JSON string should return 400 with 'Undecodeable Json string'.""" """Invalid JSON string (not parseable) should trigger HTTP 400."""
mock_pub_socket = AsyncMock() mock_pub_socket = AsyncMock()
client.app.state.endpoints_pub_socket = mock_pub_socket client.app.state.endpoints_pub_socket = mock_pub_socket
# Not valid JSON bad_json_str = "{invalid json}"
bad_message = Message(message="{not valid json}") msg = Message(message=bad_json_str)
response = client.post("/program", json=bad_message.model_dump())
response = client.post("/program", json=msg.model_dump())
assert response.status_code == 400 assert response.status_code == 400
assert response.json()["detail"] == "Undecodeable Json string" assert response.json()["detail"] == "Not a valid program"
mock_pub_socket.send_multipart.assert_not_called() mock_pub_socket.send_multipart.assert_not_called()
def test_receive_program_invalid_phase(client): def test_receive_program_invalid_deep_structure(client):
"""Decodable JSON but invalid Phase structure should return 400 with 'Non-Phase String'.""" """Valid JSON shape but invalid deep nested data should still raise 400."""
mock_pub_socket = AsyncMock() mock_pub_socket = AsyncMock()
client.app.state.endpoints_pub_socket = mock_pub_socket client.app.state.endpoints_pub_socket = mock_pub_socket
# Missing required Phase fields # Structurally correct Program, but with missing elements
invalid_phase = [{"id": "only_id"}] bad_program = {
bad_message = Message(message=json.dumps(invalid_phase)) "phases": [
{
"id": "phase1",
"name": "deepfail",
"nextPhaseId": "phase2",
"phaseData": {
"norms": [
{"id": "n1", "name": "norm"} # Missing "value"
],
"goals": [
{"id": "g1", "name": "goal", "description": "desc", "achieved": False}
],
"triggers": [
{"id": "t1", "label": "trigger", "type": "keyword", "value": ["start"]}
],
},
}
]
}
response = client.post("/program", json=bad_message.model_dump()) msg = Message(message=json.dumps(bad_program))
response = client.post("/program", json=msg.model_dump())
assert response.status_code == 400 assert response.status_code == 400
assert response.json()["detail"] == "Non-Phase String" assert response.json()["detail"] == "Not a valid program"
mock_pub_socket.send_multipart.assert_not_called() mock_pub_socket.send_multipart.assert_not_called()