149 lines
4.6 KiB
Python
149 lines
4.6 KiB
Python
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
import pytest
|
|
from fastapi import FastAPI
|
|
from fastapi.testclient import TestClient
|
|
|
|
from control_backend.api.v1.endpoints import user_interact
|
|
|
|
|
|
@pytest.fixture
def app():
    """Return a bare FastAPI application with only the ``user_interact`` router.

    Mounting a single router keeps these tests independent of whatever else the
    real application may register.
    """
    test_app = FastAPI()
    test_app.include_router(user_interact.router)
    return test_app
|
|
|
|
|
|
@pytest.fixture
def client(app):
    """Return a synchronous ``TestClient`` wrapping the ``app`` fixture."""
    test_client = TestClient(app)
    return test_client
|
|
|
|
|
|
def test_receive_button_event(client):
    """A valid button event is accepted (202) and published on the pub socket.

    ``TestClient`` is synchronous and runs the app in its own event loop, so
    this test does not need to be a coroutine (nothing here is awaited). The
    endpoint itself still awaits the mocked ``send_multipart``.
    """
    mock_pub_socket = AsyncMock()
    client.app.state.endpoints_pub_socket = mock_pub_socket

    payload = {"type": "speech", "context": "hello"}
    response = client.post("/button_pressed", json=payload)

    assert response.status_code == 202
    assert response.json() == {"status": "Event received"}

    mock_pub_socket.send_multipart.assert_awaited_once()
    # send_multipart gets one positional argument: the sequence of frames.
    frames = mock_pub_socket.send_multipart.call_args.args[0]
    assert frames[0] == b"button_pressed"
    assert "speech" in frames[1].decode()
|
|
|
|
|
|
def test_receive_button_event_invalid_payload(client):
    """A payload missing ``context`` is rejected with 422 and nothing is published.

    Driven through the synchronous ``TestClient``, so no asyncio marker or
    coroutine is needed.
    """
    mock_pub_socket = AsyncMock()
    client.app.state.endpoints_pub_socket = mock_pub_socket

    # "context" is deliberately omitted so the request fails validation.
    payload = {"type": "speech"}
    response = client.post("/button_pressed", json=payload)

    assert response.status_code == 422
    # A rejected request must not reach the pub socket at all.
    mock_pub_socket.send_multipart.assert_not_called()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_experiment_stream_direct_call():
    """Call ``experiment_stream`` directly to exercise its streaming loop.

    Calling the endpoint function (rather than going through ``TestClient``)
    avoids TestClient's limitations with long-lived streaming responses.

    Scripted interaction:
      1. ``is_disconnected`` -> False; ``recv_multipart`` -> message1 (emitted).
      2. ``is_disconnected`` -> False; ``recv_multipart`` -> TimeoutError
         (the stream continues without emitting anything).
      3. ``is_disconnected`` -> True; the loop exits and message2 is never read.
    """
    mock_socket = AsyncMock()
    mock_socket.recv_multipart.side_effect = [
        (b"topic", b"message1"),
        TimeoutError(),
        (b"topic", b"message2"),  # must not be consumed once the client disconnects
    ]
    # These socket calls are made without ``await`` (as with pyzmq asyncio
    # sockets), so replace the AsyncMock children with plain mocks.
    mock_socket.close = MagicMock()
    mock_socket.connect = MagicMock()
    mock_socket.subscribe = MagicMock()

    mock_context = MagicMock()
    mock_context.socket.return_value = mock_socket

    with patch(
        "control_backend.api.v1.endpoints.user_interact.Context.instance",
        return_value=mock_context,
    ):
        mock_request = AsyncMock()
        mock_request.is_disconnected.side_effect = [False, False, True]

        response = await user_interact.experiment_stream(mock_request)
        # Drain the generator while the Context patch is still in effect.
        lines = [line async for line in response.body_iterator]

    # Exactly one data line: the timeout emits nothing and message2 is never read.
    assert lines == ["data: message1\n\n"]

    mock_socket.connect.assert_called()
    mock_socket.subscribe.assert_called_with(b"experiment")
    mock_socket.close.assert_called()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_status_stream_direct_call():
    """Call ``status_stream`` directly; it relays messages and pings on timeout.

    Scripted interaction:
      1. ``is_disconnected`` -> False; ``recv_multipart`` -> status_update (emitted).
      2. ``is_disconnected`` -> False; ``recv_multipart`` -> TimeoutError
         (the stream emits a ``: ping`` keep-alive comment).
      3. ``is_disconnected`` -> True; the loop exits and ignored_msg is never read.
    """
    mock_socket = AsyncMock()
    mock_socket.recv_multipart.side_effect = [
        (b"topic", b"status_update"),
        TimeoutError(),
        (b"topic", b"ignored_msg"),  # must not be consumed once the client disconnects
    ]
    # These socket calls are made without ``await`` (as with pyzmq asyncio
    # sockets), so replace the AsyncMock children with plain mocks.
    mock_socket.close = MagicMock()
    mock_socket.connect = MagicMock()
    mock_socket.subscribe = MagicMock()

    mock_context = MagicMock()
    mock_context.socket.return_value = mock_socket

    # Route the endpoint's ZMQ Context to our mock socket.
    with patch(
        "control_backend.api.v1.endpoints.user_interact.Context.instance",
        return_value=mock_context,
    ):
        mock_request = AsyncMock()
        mock_request.is_disconnected.side_effect = [False, False, True]

        response = await user_interact.status_stream(mock_request)
        # Drain the generator while the Context patch is still in effect.
        lines = [line async for line in response.body_iterator]

    data_line = "data: status_update\n\n"
    ping_line = ": ping\n\n"
    assert data_line in lines
    # The timeout branch must yield a keep-alive ping after the data line.
    assert ping_line in lines
    assert lines.index(data_line) < lines.index(ping_line)
    # After the disconnect the third queued message is never relayed.
    assert "data: ignored_msg\n\n" not in lines

    mock_socket.connect.assert_called()
    mock_socket.subscribe.assert_called_with(b"status")
    mock_socket.close.assert_called()
|