"""Tests for the ``/program`` endpoint router.

Each test mounts ``program.router`` on a fresh FastAPI app, installs an
``AsyncMock`` as ``app.state.endpoints_pub_socket`` and then checks both the
HTTP response and whether/what was sent through that mocked socket.
"""

import json
from unittest.mock import AsyncMock

import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient

from control_backend.api.v1.endpoints import program
from control_backend.schemas.message import Message


@pytest.fixture
def app():
    """Create a FastAPI app with only the /program router mounted."""
    application = FastAPI()
    application.include_router(program.router)
    return application


@pytest.fixture
def client(app):
    """Create a TestClient bound to the ``app`` fixture."""
    return TestClient(app)


@pytest.fixture
def mock_pub_socket(app):
    """Install an AsyncMock as ``app.state.endpoints_pub_socket`` and return it.

    Uses the same ``app`` instance as the ``client`` fixture within a test, so
    the endpoint under test sees this mock.
    """
    socket = AsyncMock()
    app.state.endpoints_pub_socket = socket
    return socket


def make_valid_phase_dict():
    """Return a dict shaped like a valid Phase JSON structure."""
    return {
        "id": "phase1",
        "name": "basephase",
        "nextPhaseId": "phase2",
        "phaseData": {
            "norms": [{"id": "n1", "name": "norm", "value": "be nice"}],
            "goals": [
                {"id": "g1", "name": "goal", "description": "test goal", "achieved": False}
            ],
            "triggers": [
                {"id": "t1", "label": "trigger", "type": "keyword", "value": ["stop", "exit"]}
            ],
        },
    }


def test_receive_program_success(client, mock_pub_socket):
    """A valid program JSON string is accepted (202) and sent via the socket."""
    # Arrange
    phases_list = [make_valid_phase_dict()]
    message_body = json.dumps(phases_list)
    msg = Message(message=message_body)

    # Act
    response = client.post("/program", json=msg.model_dump())

    # Assert
    assert response.status_code == 202
    assert response.json() == {"status": "Program parsed", "phase_count": 1}
    # The published frames must be the b"program" topic followed by the
    # UTF-8 bytes of the exact string that was submitted.
    mock_pub_socket.send_multipart.assert_awaited_once_with(
        [b"program", message_body.encode("utf-8")]
    )


def test_receive_program_invalid_json(client, mock_pub_socket):
    """A malformed JSON string returns 400 'Undecodeable Json string' and sends nothing."""
    bad_message = Message(message="{not valid json}")

    response = client.post("/program", json=bad_message.model_dump())

    assert response.status_code == 400
    assert response.json()["detail"] == "Undecodeable Json string"
    mock_pub_socket.send_multipart.assert_not_called()


def test_receive_program_invalid_phase(client, mock_pub_socket):
    """Decodable JSON that is not a valid Phase list returns 400 'Non-Phase String'."""
    # Decodes fine, but is missing the required Phase fields.
    invalid_phase = [{"id": "only_id"}]
    bad_message = Message(message=json.dumps(invalid_phase))

    response = client.post("/program", json=bad_message.model_dump())

    assert response.status_code == 400
    assert response.json()["detail"] == "Non-Phase String"
    mock_pub_socket.send_multipart.assert_not_called()