416 lines
14 KiB
Python
416 lines
14 KiB
Python
# tests/test_robot_endpoints.py
|
|
import json
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
import pytest
|
|
import zmq.asyncio
|
|
from fastapi import FastAPI
|
|
from fastapi.testclient import TestClient
|
|
|
|
from control_backend.api.v1.endpoints import robot
|
|
from control_backend.schemas.ri_message import GestureCommand, SpeechCommand
|
|
|
|
|
|
@pytest.fixture
def app():
    """
    Create a bare FastAPI app with the robot router under test mounted.

    No sockets are attached here. Tests that exercise publishing assign
    ``app.state.endpoints_pub_socket`` themselves, and tests that exercise the
    ZMQ request/subscribe paths patch ``robot.Context.instance``.
    """
    app = FastAPI()
    app.include_router(robot.router)
    return app
|
|
|
|
|
|
@pytest.fixture
def client(app):
    """Wrap the ``app`` fixture in a TestClient used to issue HTTP requests."""
    test_client = TestClient(app)
    return test_client
|
|
|
|
|
|
@pytest.fixture
def mock_zmq_context():
    """
    Patch ``robot.Context.instance`` for the duration of a test.

    Yields the MagicMock returned by the patched ``Context.instance()`` so a
    test can configure ``.socket(...)`` on it.
    """
    fake_context = MagicMock()
    with patch("control_backend.api.v1.endpoints.robot.Context.instance") as instance_factory:
        instance_factory.return_value = fake_context
        yield fake_context
|
|
|
|
|
|
@pytest.fixture
def mock_sockets(mock_zmq_context):
    """
    Provide spec'd async mocks for a SUB socket and a REQ socket.

    Only the SUB mock is wired into the patched context:
    ``Context.instance().socket(...)`` returns ``sockets["sub"]`` regardless of
    the socket type requested. ``sockets["req"]`` is NOT returned by the
    context; a test that needs it must wire it up itself, e.g.
    ``mock_zmq_context.socket.return_value = sockets["req"]``.

    Returns:
        dict with keys ``"sub"`` and ``"req"``, each an
        ``AsyncMock(spec=zmq.asyncio.Socket)``.
    """
    mock_sub_socket = AsyncMock(spec=zmq.asyncio.Socket)
    mock_req_socket = AsyncMock(spec=zmq.asyncio.Socket)

    # Every context.socket(...) call yields the SUB mock; the REQ mock is
    # only handed back to the caller.
    mock_zmq_context.socket.return_value = mock_sub_socket

    return {"sub": mock_sub_socket, "req": mock_req_socket}
|
|
|
|
|
|
def test_receive_speech_command_success(client):
    """
    POST /command/speech with a valid body is accepted (202, status message)
    and the command is published exactly once on the endpoints PUB socket as
    the frames ``[b"command", <SpeechCommand JSON>]``.
    """
    publisher = AsyncMock()
    client.app.state.endpoints_pub_socket = publisher

    payload = {"endpoint": "actuate/speech", "data": "This is a test"}
    expected_frames = [b"command", SpeechCommand(**payload).model_dump_json().encode()]

    reply = client.post("/command/speech", json=payload)

    assert reply.status_code == 202
    assert reply.json() == {"status": "Speech command received"}

    # The serialized command must reach the ZMQ bus unchanged.
    publisher.send_multipart.assert_awaited_once_with(expected_frames)
|
|
|
|
|
|
def test_receive_gesture_command_success(client):
    """
    POST /command/gesture with a valid body is accepted (202, status message)
    and the command is published exactly once on the endpoints PUB socket as
    the frames ``[b"command", <GestureCommand JSON>]``.
    """
    publisher = AsyncMock()
    client.app.state.endpoints_pub_socket = publisher

    payload = {"endpoint": "actuate/gesture/tag", "data": "happy"}
    expected_frames = [b"command", GestureCommand(**payload).model_dump_json().encode()]

    reply = client.post("/command/gesture", json=payload)

    assert reply.status_code == 202
    assert reply.json() == {"status": "Gesture command received"}

    # The serialized command must reach the ZMQ bus unchanged.
    publisher.send_multipart.assert_awaited_once_with(expected_frames)
|
|
|
|
|
|
def test_receive_speech_command_invalid_payload(client):
    """
    A body missing the required SpeechCommand fields is rejected by schema
    validation (422), and nothing is published on the endpoints PUB socket.
    """
    mock_pub_socket = AsyncMock()
    client.app.state.endpoints_pub_socket = mock_pub_socket

    # Missing required field(s)
    bad_payload = {"invalid": "data"}
    response = client.post("/command/speech", json=bad_payload)

    assert response.status_code == 422  # validation error
    # A rejected request must not leak a command onto the ZMQ bus.
    mock_pub_socket.send_multipart.assert_not_called()
|
|
|
|
|
|
def test_receive_gesture_command_invalid_payload(client):
    """
    A body missing the required GestureCommand fields is rejected by schema
    validation (422), and nothing is published on the endpoints PUB socket.
    """
    mock_pub_socket = AsyncMock()
    client.app.state.endpoints_pub_socket = mock_pub_socket

    # Missing required field(s)
    bad_payload = {"invalid": "data"}
    response = client.post("/command/gesture", json=bad_payload)

    assert response.status_code == 422  # validation error
    # A rejected request must not leak a command onto the ZMQ bus.
    mock_pub_socket.send_multipart.assert_not_called()
|
|
|
|
|
|
def test_ping_check_returns_none(client):
    """GET /ping_check answers 200 with a JSON ``null`` body (the endpoint is still a stub)."""
    reply = client.get("/ping_check")

    assert reply.status_code == 200
    body = reply.json()
    assert body is None
|
|
|
|
|
|
# ----------------------------
|
|
# ping_stream tests (SSE stream fed by the internal ZMQ SUB socket)
|
|
# ----------------------------
|
|
@pytest.mark.asyncio
async def test_ping_stream_yields_ping_event(monkeypatch):
    """
    A ``[b"ping", b"true"]`` message on the SUB socket becomes a single SSE
    ``data: true`` event. The stream is then exhausted once the request
    reports a disconnect.
    """
    sub = AsyncMock()
    sub.connect = MagicMock()
    sub.setsockopt = MagicMock()
    sub.recv_multipart = AsyncMock(return_value=[b"ping", b"true"])

    fake_context = MagicMock()
    fake_context.socket.return_value = sub
    monkeypatch.setattr(robot.Context, "instance", lambda: fake_context)

    # ping_stream reads the SUB address from the module-level settings object.
    fake_settings = MagicMock()
    fake_settings.zmq_settings.internal_sub_address = "tcp://localhost:5555"
    monkeypatch.setattr(robot, "settings", fake_settings)

    # First check: still connected; second check: client went away.
    request = AsyncMock()
    request.is_disconnected = AsyncMock(side_effect=[False, True])

    streaming = await robot.ping_stream(request)
    stream = aiter(streaming.body_iterator)

    first = await anext(stream)
    if isinstance(first, bytes):
        text = first.decode()
    else:
        text = str(first)
    assert text.strip() == "data: true"

    with pytest.raises(StopAsyncIteration):
        await anext(stream)

    sub.connect.assert_called_once_with("tcp://localhost:5555")
    sub.setsockopt.assert_called_once_with(robot.zmq.SUBSCRIBE, b"ping")
    sub.recv_multipart.assert_awaited()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_ping_stream_handles_timeout(monkeypatch):
    """
    A TimeoutError from ``recv_multipart`` must not escape ping_stream.

    With the request already reporting a disconnect, the stream is expected to
    end cleanly (StopAsyncIteration rather than TimeoutError) after at least
    one receive attempt. This test does not demonstrate repeated looping; it
    only pins that the timeout is swallowed.
    """
    mock_sub_socket = AsyncMock()
    mock_sub_socket.connect = MagicMock()
    mock_sub_socket.setsockopt = MagicMock()
    mock_sub_socket.recv_multipart.side_effect = TimeoutError()

    mock_context = MagicMock()
    mock_context.socket.return_value = mock_sub_socket
    monkeypatch.setattr(robot.Context, "instance", lambda: mock_context)

    mock_settings = MagicMock()
    mock_settings.zmq_settings.internal_sub_address = "tcp://localhost:5555"
    monkeypatch.setattr(robot, "settings", mock_settings)

    mock_request = AsyncMock()
    mock_request.is_disconnected = AsyncMock(return_value=True)

    response = await robot.ping_stream(mock_request)
    generator = aiter(response.body_iterator)

    # No event is produced and the TimeoutError is not propagated.
    with pytest.raises(StopAsyncIteration):
        await anext(generator)

    mock_sub_socket.connect.assert_called_once_with("tcp://localhost:5555")
    mock_sub_socket.setsockopt.assert_called_once_with(robot.zmq.SUBSCRIBE, b"ping")
    mock_sub_socket.recv_multipart.assert_awaited()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_ping_stream_yields_json_values(monkeypatch):
    """
    Ensure ping_stream parses a JSON ping body and yields its value as an SSE
    event, then ends once the request reports a disconnect.
    """
    mock_sub_socket = AsyncMock()
    mock_sub_socket.connect = MagicMock()
    mock_sub_socket.setsockopt = MagicMock()
    mock_sub_socket.recv_multipart = AsyncMock(
        return_value=[b"ping", json.dumps({"connected": True}).encode()]
    )

    mock_context = MagicMock()
    mock_context.socket.return_value = mock_sub_socket
    monkeypatch.setattr(robot.Context, "instance", lambda: mock_context)

    mock_settings = MagicMock()
    mock_settings.zmq_settings.internal_sub_address = "tcp://localhost:5555"
    monkeypatch.setattr(robot, "settings", mock_settings)

    # First check: still connected; second check: client went away.
    mock_request = AsyncMock()
    mock_request.is_disconnected = AsyncMock(side_effect=[False, True])

    response = await robot.ping_stream(mock_request)
    generator = aiter(response.body_iterator)

    event = await anext(generator)
    event_text = event.decode() if isinstance(event, bytes) else str(event)

    assert "connected" in event_text
    assert "true" in event_text

    # Drain the stream (as test_ping_stream_yields_ping_event does) so the
    # generator is not left suspended and termination is actually verified.
    with pytest.raises(StopAsyncIteration):
        await anext(generator)

    mock_sub_socket.connect.assert_called_once_with("tcp://localhost:5555")
    mock_sub_socket.setsockopt.assert_called_once_with(robot.zmq.SUBSCRIBE, b"ping")
    mock_sub_socket.recv_multipart.assert_awaited()
|
|
|
|
|
|
# ----------------------------
|
|
# get_available_gesture_tags tests (REQ socket on tcp://localhost:7788)
|
|
# ----------------------------
|
|
def test_get_available_gesture_tags_success(client, monkeypatch):
    """
    Test successful retrieval of available gesture tags using a REQ socket.

    Plain ``def`` on purpose: the test awaits nothing itself and drives the
    app through the synchronous TestClient, which is meant for non-async tests.
    """
    # Arrange
    mock_req_socket = AsyncMock(spec=zmq.asyncio.Socket)
    mock_req_socket.connect = MagicMock()
    mock_req_socket.send = AsyncMock()
    response_data = {"tags": ["wave", "nod", "point", "dance"]}
    mock_req_socket.recv = AsyncMock(return_value=json.dumps(response_data).encode())

    mock_context = MagicMock()
    mock_context.socket.return_value = mock_req_socket
    monkeypatch.setattr(robot.Context, "instance", lambda: mock_context)

    # Replace logger methods to avoid noisy logs in tests
    monkeypatch.setattr(robot.logger, "debug", MagicMock())
    monkeypatch.setattr(robot.logger, "error", MagicMock())

    # Act
    response = client.get("/commands/gesture/tags")

    # Assert
    assert response.status_code == 200
    assert response.json() == {"available_gesture_tags": ["wave", "nod", "point", "dance"]}

    # Verify ZeroMQ REQ interactions
    mock_req_socket.connect.assert_called_once_with("tcp://localhost:7788")
    mock_req_socket.send.assert_awaited_once_with(b"None")
    mock_req_socket.recv.assert_awaited_once()
|
|
|
|
|
|
def test_get_available_gesture_tags_with_amount(client, monkeypatch):
    """
    The endpoint currently ignores the 'amount' TODO, so behavior is the same as 'success'.

    This test asserts that the endpoint still sends b"None" and returns the tags.
    Plain ``def`` because it drives the app through the synchronous TestClient.
    """
    # Arrange
    mock_req_socket = AsyncMock(spec=zmq.asyncio.Socket)
    mock_req_socket.connect = MagicMock()
    mock_req_socket.send = AsyncMock()
    response_data = {"tags": ["wave", "nod"]}
    mock_req_socket.recv = AsyncMock(return_value=json.dumps(response_data).encode())

    mock_context = MagicMock()
    mock_context.socket.return_value = mock_req_socket
    monkeypatch.setattr(robot.Context, "instance", lambda: mock_context)

    monkeypatch.setattr(robot.logger, "debug", MagicMock())
    monkeypatch.setattr(robot.logger, "error", MagicMock())

    # Act
    response = client.get("/commands/gesture/tags")

    # Assert
    assert response.status_code == 200
    assert response.json() == {"available_gesture_tags": ["wave", "nod"]}

    mock_req_socket.connect.assert_called_once_with("tcp://localhost:7788")
    mock_req_socket.send.assert_awaited_once_with(b"None")
    mock_req_socket.recv.assert_awaited_once()
|
|
|
|
|
|
def test_get_available_gesture_tags_timeout(client, monkeypatch):
    """
    Test timeout scenario when fetching gesture tags. Endpoint should handle TimeoutError
    and return an empty list while logging the timeout.

    Plain ``def`` because it drives the app through the synchronous TestClient.
    """
    # Arrange
    mock_req_socket = AsyncMock(spec=zmq.asyncio.Socket)
    mock_req_socket.connect = MagicMock()
    mock_req_socket.send = AsyncMock()
    mock_req_socket.recv = AsyncMock(side_effect=TimeoutError)

    mock_context = MagicMock()
    mock_context.socket.return_value = mock_req_socket
    monkeypatch.setattr(robot.Context, "instance", lambda: mock_context)

    # Patch logger.debug so we can assert it was called with the expected message
    mock_debug = MagicMock()
    monkeypatch.setattr(robot.logger, "debug", mock_debug)
    monkeypatch.setattr(robot.logger, "error", MagicMock())

    # Act
    response = client.get("/commands/gesture/tags")

    # Assert
    assert response.status_code == 200
    assert response.json() == {"available_gesture_tags": []}

    # Verify the timeout was logged using the exact string from the endpoint code
    mock_debug.assert_called_once_with("Got timeout error fetching gestures.")

    mock_req_socket.connect.assert_called_once_with("tcp://localhost:7788")
    mock_req_socket.send.assert_awaited_once_with(b"None")
    mock_req_socket.recv.assert_awaited_once()
|
|
|
|
|
|
def test_get_available_gesture_tags_empty_response(client, monkeypatch):
    """
    Test scenario when response contains an empty 'tags' list.

    Plain ``def`` because it drives the app through the synchronous TestClient.
    """
    # Arrange
    mock_req_socket = AsyncMock(spec=zmq.asyncio.Socket)
    mock_req_socket.connect = MagicMock()
    mock_req_socket.send = AsyncMock()
    response_data = {"tags": []}
    mock_req_socket.recv = AsyncMock(return_value=json.dumps(response_data).encode())

    mock_context = MagicMock()
    mock_context.socket.return_value = mock_req_socket
    monkeypatch.setattr(robot.Context, "instance", lambda: mock_context)

    monkeypatch.setattr(robot.logger, "debug", MagicMock())
    monkeypatch.setattr(robot.logger, "error", MagicMock())

    # Act
    response = client.get("/commands/gesture/tags")

    # Assert
    assert response.status_code == 200
    assert response.json() == {"available_gesture_tags": []}

    # The empty list must come from the REQ round-trip, not a short-circuit.
    mock_req_socket.connect.assert_called_once_with("tcp://localhost:7788")
    mock_req_socket.send.assert_awaited_once_with(b"None")
    mock_req_socket.recv.assert_awaited_once()
|
|
|
|
|
|
def test_get_available_gesture_tags_missing_tags_key(client, monkeypatch):
    """
    Test scenario when response JSON doesn't contain 'tags' key.

    Plain ``def`` because it drives the app through the synchronous TestClient.
    """
    # Arrange
    mock_req_socket = AsyncMock(spec=zmq.asyncio.Socket)
    mock_req_socket.connect = MagicMock()
    mock_req_socket.send = AsyncMock()
    response_data = {"some_other_key": "value"}
    mock_req_socket.recv = AsyncMock(return_value=json.dumps(response_data).encode())

    mock_context = MagicMock()
    mock_context.socket.return_value = mock_req_socket
    monkeypatch.setattr(robot.Context, "instance", lambda: mock_context)

    monkeypatch.setattr(robot.logger, "debug", MagicMock())
    monkeypatch.setattr(robot.logger, "error", MagicMock())

    # Act
    response = client.get("/commands/gesture/tags")

    # Assert
    assert response.status_code == 200
    assert response.json() == {"available_gesture_tags": []}

    # The fallback must follow a real REQ round-trip.
    mock_req_socket.connect.assert_called_once_with("tcp://localhost:7788")
    mock_req_socket.send.assert_awaited_once_with(b"None")
    mock_req_socket.recv.assert_awaited_once()
|
|
|
|
|
|
def test_get_available_gesture_tags_invalid_json(client, monkeypatch):
    """
    Test scenario when response contains invalid JSON. Endpoint should log the error
    and return an empty list.

    Plain ``def`` because it drives the app through the synchronous TestClient.
    """
    # Arrange
    mock_req_socket = AsyncMock(spec=zmq.asyncio.Socket)
    mock_req_socket.connect = MagicMock()
    mock_req_socket.send = AsyncMock()
    mock_req_socket.recv = AsyncMock(return_value=b"invalid json")

    mock_context = MagicMock()
    mock_context.socket.return_value = mock_req_socket
    monkeypatch.setattr(robot.Context, "instance", lambda: mock_context)

    mock_error = MagicMock()
    monkeypatch.setattr(robot.logger, "error", mock_error)
    monkeypatch.setattr(robot.logger, "debug", MagicMock())

    # Act
    response = client.get("/commands/gesture/tags")

    # Assert - invalid JSON should lead to empty list and error log invocation
    assert response.status_code == 200
    assert response.json() == {"available_gesture_tags": []}
    mock_error.assert_called_once()

    mock_req_socket.connect.assert_called_once_with("tcp://localhost:7788")
    mock_req_socket.send.assert_awaited_once_with(b"None")
    mock_req_socket.recv.assert_awaited_once()
|