from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient

from control_backend.api.v1.endpoints import user_interact

# Patch target for the ZMQ context factory used by the streaming endpoints.
_CONTEXT_INSTANCE_PATH = "control_backend.api.v1.endpoints.user_interact.Context.instance"


@pytest.fixture
def app():
    """Build a minimal FastAPI app that mounts only the user_interact router."""
    app = FastAPI()
    app.include_router(user_interact.router)
    return app


@pytest.fixture
def client(app):
    """Return a synchronous TestClient bound to the ``app`` fixture."""
    return TestClient(app)


def _make_zmq_mocks(recv_side_effect):
    """Build a mocked ZMQ context whose ``socket()`` returns a mocked socket.

    ``recv_multipart`` is awaitable and replays ``recv_side_effect`` in order
    (exception instances in the sequence are raised instead of returned).
    ``connect``, ``subscribe`` and ``close`` are plain synchronous mocks.

    Args:
        recv_side_effect: Sequence of ``(topic, payload)`` tuples and/or
            exception instances fed to ``recv_multipart``.

    Returns:
        A ``(mock_context, mock_socket)`` tuple.
    """
    mock_socket = AsyncMock()
    mock_socket.recv_multipart.side_effect = recv_side_effect
    mock_socket.close = MagicMock()
    mock_socket.connect = MagicMock()
    mock_socket.subscribe = MagicMock()

    mock_context = MagicMock()
    mock_context.socket.return_value = mock_socket
    return mock_context, mock_socket


def test_receive_button_event(client):
    """A valid button event is accepted and published on the pub socket."""
    # Plain sync test: TestClient is synchronous, so no asyncio marker is needed.
    mock_pub_socket = AsyncMock()
    client.app.state.endpoints_pub_socket = mock_pub_socket

    payload = {"type": "speech", "context": "hello"}
    response = client.post("/button_pressed", json=payload)

    assert response.status_code == 202
    assert response.json() == {"status": "Event received"}

    mock_pub_socket.send_multipart.assert_awaited_once()
    args = mock_pub_socket.send_multipart.call_args[0][0]
    assert args[0] == b"button_pressed"
    assert "speech" in args[1].decode()


def test_receive_button_event_invalid_payload(client):
    """A payload missing ``context`` is rejected with 422 and nothing is published."""
    mock_pub_socket = AsyncMock()
    client.app.state.endpoints_pub_socket = mock_pub_socket

    # Missing "context" field.
    payload = {"type": "speech"}
    response = client.post("/button_pressed", json=payload)

    assert response.status_code == 422
    mock_pub_socket.send_multipart.assert_not_called()


@pytest.mark.asyncio
async def test_experiment_stream_direct_call():
    """
    Call the experiment stream endpoint directly to exercise its streaming
    loop without dealing with TestClient streaming limitations.
    """
    # recv_multipart sequence:
    # 1. a message -> emitted as a "data:" line
    # 2. TimeoutError -> loop continues without emitting anything
    # 3. a message that must never be read, because the client disconnects first
    mock_context, mock_socket = _make_zmq_mocks(
        [
            (b"topic", b"message1"),
            TimeoutError(),
            (b"topic", b"message2"),
        ]
    )

    with patch(_CONTEXT_INSTANCE_PATH, return_value=mock_context):
        mock_request = AsyncMock()
        # is_disconnected sequence:
        # 1. False (before first recv) -> reads message1
        # 2. False (before second recv) -> TimeoutError, loop continues
        # 3. True (before third recv) -> loop breaks
        mock_request.is_disconnected.side_effect = [False, False, True]

        response = await user_interact.experiment_stream(mock_request)

        # Consume the generator while the patch is active, in case the socket
        # is only created once the generator starts running.
        lines = [line async for line in response.body_iterator]

    assert "data: message1\n\n" in lines
    # Exactly one line: the timeout emitted nothing and message2 was never read.
    assert len(lines) == 1

    mock_socket.connect.assert_called()
    mock_socket.subscribe.assert_called_with(b"experiment")
    mock_socket.close.assert_called()


@pytest.mark.asyncio
async def test_status_stream_direct_call():
    """
    Call the status stream endpoint directly, checking that it forwards
    messages and emits a ping line when a receive times out.
    """
    # recv_multipart sequence:
    # 1. a message -> emitted as a "data:" line
    # 2. TimeoutError -> the stream should yield ": ping"
    # 3. a message that must never be read, because the client disconnects first
    mock_context, mock_socket = _make_zmq_mocks(
        [
            (b"topic", b"status_update"),
            TimeoutError(),
            (b"topic", b"ignored_msg"),
        ]
    )

    with patch(_CONTEXT_INSTANCE_PATH, return_value=mock_context):
        mock_request = AsyncMock()
        # is_disconnected sequence:
        # 1. False -> process "status_update"
        # 2. False -> process TimeoutError (yield ping)
        # 3. True -> break loop (client disconnected)
        mock_request.is_disconnected.side_effect = [False, False, True]

        response = await user_interact.status_stream(mock_request)

        # Consume the generator while the patch is active, in case the socket
        # is only created once the generator starts running.
        lines = [line async for line in response.body_iterator]

    assert "data: status_update\n\n" in lines
    # A receive timeout must produce a ping line.
    assert ": ping\n\n" in lines
    # The client disconnected before the third receive, so this is never forwarded.
    assert "data: ignored_msg\n\n" not in lines

    mock_socket.connect.assert_called()
    mock_socket.subscribe.assert_called_with(b"status")
    mock_socket.close.assert_called()