"""Unit tests for ``UserInterruptAgent``: outgoing message formats and button-event routing."""

import asyncio
import json
from unittest.mock import AsyncMock, MagicMock

import pytest

from control_backend.agents.user_interrupt.user_interrupt_agent import UserInterruptAgent
from control_backend.core.agent_system import InternalMessage
from control_backend.core.config import settings
from control_backend.schemas.ri_message import RIEndpoint


@pytest.fixture
def agent():
    """Build a ``UserInterruptAgent`` whose I/O collaborators are replaced by mocks.

    ``send`` and ``sub_socket`` are ``AsyncMock`` instances so tests can inspect
    outgoing messages and script incoming ZMQ frames; ``logger`` is a
    ``MagicMock`` so log calls can be asserted.
    """
    agent = UserInterruptAgent(name="user_interrupt_agent")
    agent.send = AsyncMock()
    agent.logger = MagicMock()
    agent.sub_socket = AsyncMock()
    return agent


@pytest.mark.asyncio
async def test_send_to_speech_agent(agent):
    """Verify speech command format."""
    await agent._send_to_speech_agent("Hello World")

    agent.send.assert_awaited_once()
    sent_msg: InternalMessage = agent.send.call_args.args[0]

    assert sent_msg.to == settings.agent_settings.robot_speech_name
    body = json.loads(sent_msg.body)
    assert body["data"] == "Hello World"
    assert body["is_priority"] is True


@pytest.mark.asyncio
async def test_send_to_gesture_agent(agent):
    """Verify gesture command format."""
    await agent._send_to_gesture_agent("wave_hand")

    agent.send.assert_awaited_once()
    sent_msg: InternalMessage = agent.send.call_args.args[0]

    assert sent_msg.to == settings.agent_settings.robot_gesture_name
    body = json.loads(sent_msg.body)
    assert body["data"] == "wave_hand"
    assert body["is_priority"] is True
    assert body["endpoint"] == RIEndpoint.GESTURE_SINGLE.value


@pytest.mark.asyncio
async def test_send_to_program_manager(agent):
    """Verify belief update format."""
    context_str = "2"

    await agent._send_to_program_manager(context_str)

    agent.send.assert_awaited_once()
    sent_msg: InternalMessage = agent.send.call_args.args[0]

    assert sent_msg.to == settings.agent_settings.bdi_program_manager_name
    assert sent_msg.thread == "belief_override_id"

    body = json.loads(sent_msg.body)
    assert body["belief"] == context_str


@pytest.mark.asyncio
async def test_receive_loop_routing_success(agent):
    """
    Test that the loop correctly:
    1. Receives 'button_pressed' topic from ZMQ
    2. Parses the JSON payload to find 'type' and 'context'
    3. Calls the correct handler method based on 'type'
    """
    # Prepare JSON payloads as bytes
    payload_speech = json.dumps({"type": "speech", "context": "Hello Speech"}).encode()
    payload_gesture = json.dumps({"type": "gesture", "context": "Hello Gesture"}).encode()
    payload_override = json.dumps({"type": "override", "context": "Hello Override"}).encode()

    agent.sub_socket.recv_multipart.side_effect = [
        (b"button_pressed", payload_speech),
        (b"button_pressed", payload_gesture),
        (b"button_pressed", payload_override),
        asyncio.CancelledError,  # Stop the infinite loop
    ]

    agent._send_to_speech_agent = AsyncMock()
    agent._send_to_gesture_agent = AsyncMock()
    agent._send_to_program_manager = AsyncMock()

    # The receive loop may either propagate the CancelledError raised by the
    # mocked socket or handle it itself; both count as a clean stop here.
    try:
        await agent._receive_button_event()
    except asyncio.CancelledError:
        pass

    # Yield once to the event loop so any work the loop scheduled can finish
    # before asserting.
    await asyncio.sleep(0)

    # Speech
    agent._send_to_speech_agent.assert_awaited_once_with("Hello Speech")

    # Gesture
    agent._send_to_gesture_agent.assert_awaited_once_with("Hello Gesture")

    # Override
    agent._send_to_program_manager.assert_awaited_once_with("Hello Override")

    assert agent._send_to_speech_agent.await_count == 1
    assert agent._send_to_gesture_agent.await_count == 1
    assert agent._send_to_program_manager.await_count == 1


@pytest.mark.asyncio
async def test_receive_loop_unknown_type(agent):
    """Test that unknown 'type' values in the JSON log a warning and do not crash."""
    # Prepare a payload with an unknown type
    payload_unknown = json.dumps({"type": "unknown_thing", "context": "some_data"}).encode()

    agent.sub_socket.recv_multipart.side_effect = [
        (b"button_pressed", payload_unknown),
        asyncio.CancelledError,
    ]

    # Mock every handler that a known type routes to (see the routing test), so
    # an unknown type that is wrongly dispatched to any of them is detected.
    agent._send_to_speech_agent = AsyncMock()
    agent._send_to_gesture_agent = AsyncMock()
    agent._send_to_program_manager = AsyncMock()

    # The receive loop may either propagate the CancelledError raised by the
    # mocked socket or handle it itself; both count as a clean stop here.
    try:
        await agent._receive_button_event()
    except asyncio.CancelledError:
        pass

    # Yield once to the event loop so any work the loop scheduled can finish
    # before asserting.
    await asyncio.sleep(0)

    # Ensure no handlers were called
    agent._send_to_speech_agent.assert_not_called()
    agent._send_to_gesture_agent.assert_not_called()
    agent._send_to_program_manager.assert_not_called()

    agent.logger.warning.assert_called_with(
        "Received button press with unknown type '%s' (context: '%s').",
        "unknown_thing",
        "some_data",
    )