import json
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
import zmq.asyncio
from fastapi import FastAPI
from fastapi.testclient import TestClient

from control_backend.api.v1.endpoints import robot
from control_backend.schemas.ri_message import GestureCommand, SpeechCommand


@pytest.fixture
def app():
    """
    Create a FastAPI test app with the robot router under test attached.

    No sockets are configured here; tests that exercise publishing assign a
    mock to ``app.state.endpoints_pub_socket`` themselves.
    """
    app = FastAPI()
    app.include_router(robot.router)
    return app


@pytest.fixture
def client(app):
    """Create a test client for the app."""
    return TestClient(app)


@pytest.fixture
def mock_zmq_context():
    """Patch ``robot.Context.instance`` and yield the MagicMock context it returns."""
    with patch("control_backend.api.v1.endpoints.robot.Context.instance") as mock_context:
        context_instance = MagicMock()
        mock_context.return_value = context_instance
        yield context_instance


@pytest.fixture
def mock_sockets(mock_zmq_context):
    """
    Create spec'd mock ZMQ sockets.

    Only the ``"sub"`` mock is returned by ``mock_zmq_context.socket()``; the
    ``"pub"`` mock is not wired to anything and must be attached by the test
    (e.g. to ``app.state.endpoints_pub_socket``) when needed.
    """
    mock_sub_socket = AsyncMock(spec=zmq.asyncio.Socket)
    mock_pub_socket = AsyncMock(spec=zmq.asyncio.Socket)
    mock_zmq_context.socket.return_value = mock_sub_socket
    return {"sub": mock_sub_socket, "pub": mock_pub_socket}


def _assert_command_forwarded(client, command_data, command):
    """
    POST ``command_data`` to /command and check it is accepted and published.

    Installs a fresh AsyncMock as ``app.state.endpoints_pub_socket``, then
    asserts a 202 response with the expected body and that ``command`` (the
    schema object built from ``command_data``) was published once, as JSON,
    on the ``b"command"`` topic.
    """
    mock_pub_socket = AsyncMock()
    client.app.state.endpoints_pub_socket = mock_pub_socket

    response = client.post("/command", json=command_data)

    assert response.status_code == 202
    assert response.json() == {"status": "Command received"}
    mock_pub_socket.send_multipart.assert_awaited_once_with(
        [b"command", command.model_dump_json().encode()]
    )


def test_receive_speech_command_success(client):
    """
    A valid speech command is accepted with 202 and forwarded over ZeroMQ
    via ``send_multipart`` with the serialized SpeechCommand.
    """
    command_data = {"endpoint": "actuate/speech", "data": "This is a test"}
    _assert_command_forwarded(client, command_data, SpeechCommand(**command_data))


def test_receive_gesture_command_success(client):
    """
    A valid gesture command is accepted with 202 and forwarded over ZeroMQ
    via ``send_multipart`` with the serialized GestureCommand.
    """
    command_data = {"endpoint": "actuate/gesture/tag", "data": "happy"}
    _assert_command_forwarded(client, command_data, GestureCommand(**command_data))


def test_receive_command_invalid_payload(client):
    """A payload missing the required fields is rejected by schema validation."""
    bad_payload = {"invalid": "data"}
    response = client.post("/command", json=bad_payload)
    assert response.status_code == 422  # validation error


def test_ping_check_returns_none(client):
    """Ensure /ping_check returns 200 and None (currently unimplemented)."""
    response = client.get("/ping_check")
    assert response.status_code == 200
    assert response.json() is None


# TODO: Convert these mock sockets to the fixture.
# Address the gesture tests configure as settings.zmq_settings.internal_sub_address.
_GESTURE_SUB_ADDRESS = "tcp://localhost:5555"


def _install_mock_sub_socket(monkeypatch, recv_multipart):
    """
    Patch ``robot.Context.instance`` so the endpoint receives a mock SUB socket.

    ``recv_multipart`` is installed as the socket's ``recv_multipart``;
    ``connect`` and ``setsockopt`` are plain MagicMocks so their synchronous
    calls can be asserted. Returns the mock socket.
    """
    mock_sub_socket = AsyncMock()
    mock_sub_socket.connect = MagicMock()
    mock_sub_socket.setsockopt = MagicMock()
    mock_sub_socket.recv_multipart = recv_multipart

    mock_context = MagicMock()
    mock_context.socket.return_value = mock_sub_socket
    monkeypatch.setattr(robot.Context, "instance", lambda: mock_context)
    return mock_sub_socket


def _event_text(event):
    """Decode a streamed SSE chunk, which may arrive as bytes or str."""
    return event.decode() if isinstance(event, bytes) else str(event)


def _assert_subscribed_to_ping(mock_sub_socket):
    """Check the SUB socket was connected, subscribed to b"ping", and read from."""
    mock_sub_socket.connect.assert_called_once()
    mock_sub_socket.setsockopt.assert_called_once_with(robot.zmq.SUBSCRIBE, b"ping")
    mock_sub_socket.recv_multipart.assert_awaited()


@pytest.mark.asyncio
async def test_ping_stream_yields_ping_event(monkeypatch):
    """ping_stream yields an SSE message for a received ping and ends on disconnect."""
    mock_sub_socket = _install_mock_sub_socket(
        monkeypatch, AsyncMock(return_value=[b"ping", b"true"])
    )
    mock_request = AsyncMock()
    mock_request.is_disconnected = AsyncMock(side_effect=[False, True])

    response = await robot.ping_stream(mock_request)
    generator = aiter(response.body_iterator)

    event = await anext(generator)
    assert _event_text(event).strip() == "data: true"

    with pytest.raises(StopAsyncIteration):
        await anext(generator)

    _assert_subscribed_to_ping(mock_sub_socket)


@pytest.mark.asyncio
async def test_ping_stream_handles_timeout(monkeypatch):
    """
    A TimeoutError from recv_multipart does not propagate out of ping_stream;
    the stream ends cleanly once the client is reported as disconnected.
    """
    mock_sub_socket = _install_mock_sub_socket(
        monkeypatch, AsyncMock(side_effect=TimeoutError())
    )
    mock_request = AsyncMock()
    mock_request.is_disconnected = AsyncMock(return_value=True)

    response = await robot.ping_stream(mock_request)
    generator = aiter(response.body_iterator)

    with pytest.raises(StopAsyncIteration):
        await anext(generator)

    _assert_subscribed_to_ping(mock_sub_socket)


@pytest.mark.asyncio
async def test_ping_stream_yields_json_values(monkeypatch):
    """Ensure ping_stream correctly parses and yields JSON body values."""
    mock_sub_socket = _install_mock_sub_socket(
        monkeypatch,
        AsyncMock(return_value=[b"ping", json.dumps({"connected": True}).encode()]),
    )
    mock_request = AsyncMock()
    mock_request.is_disconnected = AsyncMock(side_effect=[False, True])

    response = await robot.ping_stream(mock_request)
    generator = aiter(response.body_iterator)

    event_text = _event_text(await anext(generator))
    assert "connected" in event_text
    assert "true" in event_text

    _assert_subscribed_to_ping(mock_sub_socket)


# Tests for the get_available_gesture_tags endpoint


def _gesture_reply(body):
    """
    Build a recv_multipart AsyncMock that answers ``[b"get_gestures", body]``.

    ``body`` may be raw bytes (sent as-is) or a JSON-serializable object.
    """
    if not isinstance(body, bytes):
        body = json.dumps(body).encode()
    return AsyncMock(return_value=[b"get_gestures", body])


def _install_gesture_mocks(monkeypatch, client, recv_multipart):
    """
    Wire every mock GET /get_available_gesture_tags touches.

    Installs a mock SUB socket whose ``recv_multipart`` is ``recv_multipart``,
    a mock PUB socket on ``app.state.endpoints_pub_socket``, settings whose
    internal SUB address is ``_GESTURE_SUB_ADDRESS``, and a MagicMock in place
    of ``robot.logger.debug`` (so nothing is actually logged).

    Returns ``(sub_socket, pub_socket, debug_logger)``.
    """
    mock_sub_socket = _install_mock_sub_socket(monkeypatch, recv_multipart)

    mock_pub_socket = AsyncMock()
    client.app.state.endpoints_pub_socket = mock_pub_socket

    mock_settings = MagicMock()
    mock_settings.zmq_settings.internal_sub_address = _GESTURE_SUB_ADDRESS
    monkeypatch.setattr(robot, "settings", mock_settings)

    mock_debug = MagicMock()
    monkeypatch.setattr(robot.logger, "debug", mock_debug)
    return mock_sub_socket, mock_pub_socket, mock_debug


def _assert_gesture_roundtrip(mock_sub_socket, mock_pub_socket):
    """Check the endpoint subscribed, sent the gesture request, and read one reply."""
    mock_sub_socket.connect.assert_called_once_with(_GESTURE_SUB_ADDRESS)
    mock_sub_socket.setsockopt.assert_called_once_with(robot.zmq.SUBSCRIBE, b"get_gestures")
    mock_pub_socket.send_multipart.assert_awaited_once_with([b"send_gestures", b""])
    mock_sub_socket.recv_multipart.assert_awaited_once()


# The gesture tests drive the endpoint through the synchronous TestClient and
# never await anything themselves, so they are plain (non-async) tests.


def test_get_available_gesture_tags_success(client, monkeypatch):
    """The endpoint sends a gesture request over ZMQ and returns the received tags."""
    mock_sub_socket, mock_pub_socket, _ = _install_gesture_mocks(
        monkeypatch, client, _gesture_reply({"tags": ["wave", "nod", "point", "dance"]})
    )

    response = client.get("/get_available_gesture_tags")

    assert response.status_code == 200
    assert response.json() == {"available_gesture_tags": ["wave", "nod", "point", "dance"]}
    _assert_gesture_roundtrip(mock_sub_socket, mock_pub_socket)


def test_get_available_gesture_tags_with_amount(client, monkeypatch):
    """
    Pin current behavior while the endpoint has no "amount" parameter.

    The endpoint's TODO about letting the UI request a certain amount of tags
    is not implemented, so the request payload is still empty bytes and the
    tags that come back are returned unchanged.
    """
    _, mock_pub_socket, _ = _install_gesture_mocks(
        monkeypatch, client, _gesture_reply({"tags": ["wave", "nod"]})
    )

    response = client.get("/get_available_gesture_tags")

    assert response.status_code == 200
    assert response.json() == {"available_gesture_tags": ["wave", "nod"]}
    mock_pub_socket.send_multipart.assert_awaited_once_with([b"send_gestures", b""])


def test_get_available_gesture_tags_timeout(client, monkeypatch):
    """A TimeoutError while awaiting the reply is logged and yields an empty tag list."""
    mock_sub_socket, mock_pub_socket, mock_debug = _install_gesture_mocks(
        monkeypatch, client, AsyncMock(side_effect=TimeoutError)
    )

    response = client.get("/get_available_gesture_tags")

    # The timeout must not surface as a server error: the endpoint answers 200
    # with an empty tag list.
    assert response.status_code == 200
    assert response.json() == {"available_gesture_tags": []}
    mock_debug.assert_called_once_with("got timeout error fetching gestures")
    _assert_gesture_roundtrip(mock_sub_socket, mock_pub_socket)


def test_get_available_gesture_tags_empty_response(client, monkeypatch):
    """A reply containing an empty tag list is returned as an empty list."""
    _install_gesture_mocks(monkeypatch, client, _gesture_reply({"tags": []}))

    response = client.get("/get_available_gesture_tags")

    assert response.status_code == 200
    assert response.json() == {"available_gesture_tags": []}


def test_get_available_gesture_tags_missing_tags_key(client, monkeypatch):
    """A reply JSON without a 'tags' key yields an empty tag list."""
    _install_gesture_mocks(monkeypatch, client, _gesture_reply({"some_other_key": "value"}))

    response = client.get("/get_available_gesture_tags")

    assert response.status_code == 200
    assert response.json() == {"available_gesture_tags": []}


def test_get_available_gesture_tags_invalid_json(client, monkeypatch):
    """A reply body that is not valid JSON yields an empty tag list."""
    _install_gesture_mocks(monkeypatch, client, _gesture_reply(b"invalid json"))

    response = client.get("/get_available_gesture_tags")

    # Invalid JSON must not surface as a server error; the endpoint answers
    # 200 with an empty tag list instead.
    assert response.status_code == 200
    assert response.json() == {"available_gesture_tags": []}