From 2ed2a84f130017af23882423f279d6b6486cf93d Mon Sep 17 00:00:00 2001 From: JobvAlewijk Date: Wed, 12 Nov 2025 18:04:39 +0100 Subject: [PATCH] style: compacted program and reworked tests ref: N25B-198 --- .../api/v1/endpoints/program.py | 37 ++---- .../api/endpoints/test_program_endpoint.py | 106 ++++++++++++------ 2 files changed, 83 insertions(+), 60 deletions(-) diff --git a/src/control_backend/api/v1/endpoints/program.py b/src/control_backend/api/v1/endpoints/program.py index b711a58..e9812ea 100644 --- a/src/control_backend/api/v1/endpoints/program.py +++ b/src/control_backend/api/v1/endpoints/program.py @@ -1,10 +1,10 @@ -import json import logging from fastapi import APIRouter, HTTPException, Request +from pydantic import ValidationError from control_backend.schemas.message import Message -from control_backend.schemas.program import Phase +from control_backend.schemas.program import Program logger = logging.getLogger(__name__) router = APIRouter() @@ -16,37 +16,20 @@ async def receive_message(program: Message, request: Request): Receives a BehaviorProgram as a stringified JSON list inside `message`. Converts it into real Phase objects. """ - logger.info("Received raw program: ") - logger.debug("%s", program) + logger.debug("Received raw program: %s", program) raw_str = program.message # This is the JSON string - # Convert Json into dict. + # Validate program try: - program_list = json.loads(raw_str) - except json.JSONDecodeError as e: - logger.error("Failed to decode program JSON: %s", e) - raise HTTPException(status_code=400, detail="Undecodeable Json string") from None - - # Validate Phases - try: - phases: list[Phase] = [Phase(**phase) for phase in program_list] - except Exception as e: - logger.error("Failed to convert to Phase objects: %s", e) - raise HTTPException(status_code=400, detail="Non-Phase String") from None - - logger.info(f"Succesfully recieved {len(phases)} Phase(s).") - for p in phases: - logger.info( - f"Phase {p.id}: " - f"{len(p.phaseData.norms)} norms, " - f"{len(p.phaseData.goals)} goals, " - f"{len(p.phaseData.triggers) if hasattr(p.phaseData, 'triggers') else 0} triggers" - ) + program = Program.model_validate_json(raw_str) + except ValidationError as e: + logger.error("Failed to validate program JSON: %s", e) + raise HTTPException(status_code=400, detail="Not a valid program") from None # send away topic = b"program" - body = json.dumps([p.model_dump() for p in phases]).encode("utf-8") + body = program.model_dump_json().encode() pub_socket = request.app.state.endpoints_pub_socket await pub_socket.send_multipart([topic, body]) - return {"status": "Program parsed", "phase_count": len(phases)} + return {"status": "Program parsed"} diff --git a/test/integration/api/endpoints/test_program_endpoint.py b/test/integration/api/endpoints/test_program_endpoint.py index 689961f..05ce63c 100644 --- a/test/integration/api/endpoints/test_program_endpoint.py +++ b/test/integration/api/endpoints/test_program_endpoint.py @@ -5,8 +5,9 @@ import pytest from fastapi import FastAPI from fastapi.testclient import TestClient -from control_backend.api.v1.endpoints import program # <-- import your router +from control_backend.api.v1.endpoints import program from control_backend.schemas.message import Message +from control_backend.schemas.program import Program @pytest.fixture @@ -23,30 +24,40 @@ def client(app): return TestClient(app) -def make_valid_phase_dict(): - """Helper to create a valid Phase JSON structure.""" +def make_valid_program_dict(): + """Helper to create a valid Program JSON structure.""" return { - "id": "phase1", - "name": "basephase", - "nextPhaseId": "phase2", - "phaseData": { - "norms": [{"id": "n1", "name": "norm", "value": "be nice"}], - "goals": [{"id": "g1", "name": "goal", "description": "test goal", "achieved": False}], - "triggers": [ - {"id": "t1", "label": "trigger", "type": "keyword", "value": ["stop", "exit"]} - ], - }, + "phases": [ + { + "id": "phase1", + "name": "basephase", + "nextPhaseId": "phase2", + "phaseData": { + "norms": [{"id": "n1", "name": "norm", "value": "be nice"}], + "goals": [ + {"id": "g1", "name": "goal", "description": "test goal", "achieved": False} + ], + "triggers": [ + { + "id": "t1", + "label": "trigger", + "type": "keyword", + "value": ["stop", "exit"], + } + ], + }, + } + ] } def test_receive_program_success(client): - """Valid program JSON string should parse and be sent via the socket.""" - # Arrange + """Valid Program JSON should be parsed and sent through the socket.""" mock_pub_socket = AsyncMock() client.app.state.endpoints_pub_socket = mock_pub_socket - phases_list = [make_valid_phase_dict()] - message_body = json.dumps(phases_list) + program_dict = make_valid_program_dict() + message_body = json.dumps(program_dict) msg = Message(message=message_body) # Act @@ -54,38 +65,67 @@ def test_receive_program_success(client): # Assert assert response.status_code == 202 - assert response.json() == {"status": "Program parsed", "phase_count": 1} + assert response.json() == {"status": "Program parsed"} - # Check the mocked socket - expected_body = json.dumps(phases_list).encode("utf-8") - mock_pub_socket.send_multipart.assert_awaited_once_with([b"program", expected_body]) + # Verify socket call (don't compare raw JSON string) + mock_pub_socket.send_multipart.assert_awaited_once() + args, kwargs = mock_pub_socket.send_multipart.await_args + + assert args[0][0] == b"program" + sent_bytes = args[0][1] + + # Decode sent bytes and compare actual structures + sent_obj = json.loads(sent_bytes.decode()) + expected_obj = Program.model_validate_json(message_body).model_dump() + + assert sent_obj == expected_obj def test_receive_program_invalid_json(client): - """Malformed JSON string should return 400 with 'Undecodeable Json string'.""" + """Invalid JSON string (not parseable) should trigger HTTP 400.""" mock_pub_socket = AsyncMock() client.app.state.endpoints_pub_socket = mock_pub_socket - # Not valid JSON - bad_message = Message(message="{not valid json}") - response = client.post("/program", json=bad_message.model_dump()) + bad_json_str = "{invalid json}" + msg = Message(message=bad_json_str) + + response = client.post("/program", json=msg.model_dump()) assert response.status_code == 400 - assert response.json()["detail"] == "Undecodeable Json string" + assert response.json()["detail"] == "Not a valid program" mock_pub_socket.send_multipart.assert_not_called() -def test_receive_program_invalid_phase(client): - """Decodable JSON but invalid Phase structure should return 400 with 'Non-Phase String'.""" +def test_receive_program_invalid_deep_structure(client): + """Valid JSON shape but invalid deep nested data should still raise 400.""" mock_pub_socket = AsyncMock() client.app.state.endpoints_pub_socket = mock_pub_socket - # Missing required Phase fields - invalid_phase = [{"id": "only_id"}] - bad_message = Message(message=json.dumps(invalid_phase)) + # Structurally correct Program, but with missing elements + bad_program = { + "phases": [ + { + "id": "phase1", + "name": "deepfail", + "nextPhaseId": "phase2", + "phaseData": { + "norms": [ + {"id": "n1", "name": "norm"} # Missing "value" + ], + "goals": [ + {"id": "g1", "name": "goal", "description": "desc", "achieved": False} + ], + "triggers": [ + {"id": "t1", "label": "trigger", "type": "keyword", "value": ["start"]} + ], + }, + } + ] + } - response = client.post("/program", json=bad_message.model_dump()) + msg = Message(message=json.dumps(bad_program)) + response = client.post("/program", json=msg.model_dump()) assert response.status_code == 400 - assert response.json()["detail"] == "Non-Phase String" + assert response.json()["detail"] == "Not a valid program" mock_pub_socket.send_multipart.assert_not_called()