"""Unit tests for ``UserInterruptAgent``.

Covers outbound command formatting (speech, gesture, BDI beliefs, experiment
control, pause), routing of ``button_pressed`` events read from the sub socket,
program-to-UI id mapping, and the UI state updates published on the pub socket.
"""

import asyncio
import contextlib
import json
import uuid
from unittest.mock import AsyncMock, MagicMock

import pytest

from control_backend.agents.user_interrupt.user_interrupt_agent import UserInterruptAgent
from control_backend.core.agent_system import InternalMessage
from control_backend.core.config import settings
from control_backend.schemas.program import (
    ConditionalNorm,
    Goal,
    KeywordBelief,
    Phase,
    Plan,
    Program,
    Trigger,
)
from control_backend.schemas.ri_message import RIEndpoint


@pytest.fixture
def agent():
    """Return a ``UserInterruptAgent`` whose send path, logger and sub/pub sockets are mocks."""
    agent = UserInterruptAgent(name="user_interrupt_agent")
    agent.send = AsyncMock()
    agent.logger = MagicMock()
    agent.sub_socket = AsyncMock()
    agent.pub_socket = AsyncMock()
    return agent


async def _run_receive_loop(agent) -> None:
    """Drive ``_receive_button_event`` until the mocked socket raises ``CancelledError``."""
    with contextlib.suppress(asyncio.CancelledError):
        await agent._receive_button_event()
    # Yield to the event loop once, presumably so anything the loop scheduled can run.
    await asyncio.sleep(0)


def _published_payload(agent) -> dict:
    """Decode the JSON payload frame (index 1) of the last ``pub_socket.send_multipart`` call."""
    frames = agent.pub_socket.send_multipart.call_args.args[0]
    return json.loads(frames[1])


def _sent_to(agent, recipient: str):
    """Return the first message passed to ``agent.send`` whose ``to`` equals *recipient*.

    Raises ``StopIteration`` (failing the test) when no such message was sent.
    """
    return next(c.args[0] for c in agent.send.call_args_list if c.args[0].to == recipient)


@pytest.mark.asyncio
async def test_send_to_speech_agent(agent):
    """Verify speech command format."""
    await agent._send_to_speech_agent("Hello World")

    agent.send.assert_awaited_once()
    sent_msg: InternalMessage = agent.send.call_args.args[0]

    assert sent_msg.to == settings.agent_settings.robot_speech_name
    body = json.loads(sent_msg.body)
    assert body["data"] == "Hello World"
    assert body["is_priority"] is True


@pytest.mark.asyncio
async def test_send_to_gesture_agent(agent):
    """Verify gesture command format."""
    await agent._send_to_gesture_agent("wave_hand")

    agent.send.assert_awaited_once()
    sent_msg: InternalMessage = agent.send.call_args.args[0]

    assert sent_msg.to == settings.agent_settings.robot_gesture_name
    body = json.loads(sent_msg.body)
    assert body["data"] == "wave_hand"
    assert body["is_priority"] is True
    assert body["endpoint"] == RIEndpoint.GESTURE_SINGLE.value


@pytest.mark.asyncio
async def test_send_to_bdi_belief(agent):
    """Verify belief update format."""
    context_str = "some_goal"

    await agent._send_to_bdi_belief(context_str, "goal")

    assert agent.send.await_count == 1
    sent_msg = agent.send.call_args.args[0]

    assert sent_msg.to == settings.agent_settings.bdi_core_name
    assert sent_msg.thread == "beliefs"
    assert "achieved_some_goal" in sent_msg.body


@pytest.mark.asyncio
async def test_receive_loop_routing_success(agent):
    """
    Test that the loop correctly:
    1. Receives 'button_pressed' topic from ZMQ
    2. Parses the JSON payload to find 'type' and 'context'
    3. Calls the correct handler method based on 'type'
    """
    # Prepare JSON payloads as bytes
    payload_speech = json.dumps({"type": "speech", "context": "Hello Speech"}).encode()
    payload_gesture = json.dumps({"type": "gesture", "context": "Hello Gesture"}).encode()
    # override calls _send_to_bdi (for trigger/norm) OR _send_to_bdi_belief (for goal).
    # To test routing, we need to populate the maps
    agent._goal_map["Hello Override"] = "some_goal_slug"
    payload_override = json.dumps({"type": "override", "context": "Hello Override"}).encode()

    agent.sub_socket.recv_multipart.side_effect = [
        (b"button_pressed", payload_speech),
        (b"button_pressed", payload_gesture),
        (b"button_pressed", payload_override),
        asyncio.CancelledError,  # Stop the infinite loop
    ]

    agent._send_to_speech_agent = AsyncMock()
    agent._send_to_gesture_agent = AsyncMock()
    agent._send_to_bdi_belief = AsyncMock()

    await _run_receive_loop(agent)

    # Each handler must be awaited exactly once, with the routed context.
    agent._send_to_speech_agent.assert_awaited_once_with("Hello Speech")
    agent._send_to_gesture_agent.assert_awaited_once_with("Hello Gesture")
    # Override resolves through _goal_map, so it becomes a goal belief.
    agent._send_to_bdi_belief.assert_awaited_once_with("some_goal_slug", "goal")


@pytest.mark.asyncio
async def test_receive_loop_unknown_type(agent):
    """Test that unknown 'type' values in the JSON log a warning and do not crash."""
    # Prepare a payload with an unknown type
    payload_unknown = json.dumps({"type": "unknown_thing", "context": "some_data"}).encode()

    agent.sub_socket.recv_multipart.side_effect = [
        (b"button_pressed", payload_unknown),
        asyncio.CancelledError,
    ]

    agent._send_to_speech_agent = AsyncMock()
    agent._send_to_gesture_agent = AsyncMock()

    await _run_receive_loop(agent)

    # Ensure no handlers were called
    agent._send_to_speech_agent.assert_not_called()
    agent._send_to_gesture_agent.assert_not_called()

    agent.logger.warning.assert_called()


@pytest.mark.asyncio
async def test_create_mapping(agent):
    """A ``new_program`` message maps UI ids of triggers, goals and conditional norms to slugs."""
    trigger_id = uuid.uuid4()
    goal_id = uuid.uuid4()
    norm_id = uuid.uuid4()

    cond = KeywordBelief(id=uuid.uuid4(), name="k1", keyword="key")
    plan = Plan(id=uuid.uuid4(), name="p1", steps=[])

    trigger = Trigger(id=trigger_id, name="my_trigger", condition=cond, plan=plan)
    goal = Goal(id=goal_id, name="my_goal", description="desc", plan=plan)
    cn = ConditionalNorm(id=norm_id, name="my_norm", norm="be polite", condition=cond)

    phase = Phase(id=uuid.uuid4(), name="phase1", norms=[cn], goals=[goal], triggers=[trigger])
    prog = Program(phases=[phase])

    # Call create_mapping via handle_message
    msg = InternalMessage(to="me", thread="new_program", body=prog.model_dump_json())
    await agent.handle_message(msg)

    # Check maps
    assert str(trigger_id) in agent._trigger_map
    assert agent._trigger_map[str(trigger_id)] == "trigger_my_trigger"

    assert str(goal_id) in agent._goal_map
    assert agent._goal_map[str(goal_id)] == "my_goal"

    assert str(norm_id) in agent._cond_norm_map
    assert agent._cond_norm_map[str(norm_id)] == "norm_be_polite"


@pytest.mark.asyncio
async def test_create_mapping_invalid_json(agent):
    """An unparseable ``new_program`` body is logged as an error and leaves the maps empty."""
    msg = InternalMessage(to="me", thread="new_program", body="invalid json")

    await agent.handle_message(msg)

    agent.logger.error.assert_called()
    # Starting from a fresh agent, a failed parse must not populate any mapping.
    assert not agent._trigger_map
    assert not agent._goal_map
    assert not agent._cond_norm_map


@pytest.mark.asyncio
async def test_handle_message_trigger_start(agent):
    """``trigger_start`` publishes an achieved ``trigger_update`` on the ``experiment`` topic."""
    # Setup reverse map manually
    agent._trigger_reverse_map["trigger_slug"] = "ui_id_123"

    msg = InternalMessage(to="me", thread="trigger_start", body="trigger_slug")
    await agent.handle_message(msg)

    agent.pub_socket.send_multipart.assert_awaited_once()
    frames = agent.pub_socket.send_multipart.call_args.args[0]
    assert frames[0] == b"experiment"

    payload = _published_payload(agent)
    assert payload["type"] == "trigger_update"
    assert payload["id"] == "ui_id_123"
    assert payload["achieved"] is True


@pytest.mark.asyncio
async def test_handle_message_trigger_end(agent):
    """``trigger_end`` publishes a ``trigger_update`` with ``achieved`` False."""
    agent._trigger_reverse_map["trigger_slug"] = "ui_id_123"

    msg = InternalMessage(to="me", thread="trigger_end", body="trigger_slug")
    await agent.handle_message(msg)

    agent.pub_socket.send_multipart.assert_awaited_once()
    payload = _published_payload(agent)
    assert payload["type"] == "trigger_update"
    assert payload["achieved"] is False


@pytest.mark.asyncio
async def test_handle_message_transition_phase(agent):
    """``transition_phase`` publishes a ``phase_update`` carrying the body as the phase id."""
    msg = InternalMessage(to="me", thread="transition_phase", body="phase_id_123")
    await agent.handle_message(msg)

    agent.pub_socket.send_multipart.assert_awaited_once()
    payload = _published_payload(agent)
    assert payload["type"] == "phase_update"
    assert payload["id"] == "phase_id_123"


@pytest.mark.asyncio
async def test_handle_message_goal_start(agent):
    """``goal_start`` publishes an active ``goal_update`` for the mapped UI goal id."""
    agent._goal_reverse_map["goal_slug"] = "goal_id_123"

    msg = InternalMessage(to="me", thread="goal_start", body="goal_slug")
    await agent.handle_message(msg)

    agent.pub_socket.send_multipart.assert_awaited_once()
    payload = _published_payload(agent)
    assert payload["type"] == "goal_update"
    assert payload["id"] == "goal_id_123"
    assert payload["active"] is True


@pytest.mark.asyncio
async def test_handle_message_active_norms_update(agent):
    """Mapped norms named in the body are reported active; other mapped norms inactive."""
    agent._cond_norm_reverse_map["norm_active"] = "id_1"
    agent._cond_norm_reverse_map["norm_inactive"] = "id_2"

    # Body is a comma-separated list of single-quoted norm slugs (no parentheses);
    # 'other' has no entry in the reverse map.
    msg = InternalMessage(to="me", thread="active_norms_update", body="'norm_active', 'other'")
    await agent.handle_message(msg)

    agent.pub_socket.send_multipart.assert_awaited_once()
    payload = _published_payload(agent)
    assert payload["type"] == "cond_norms_state_update"
    norms = {n["id"]: n["active"] for n in payload["norms"]}
    assert norms["id_1"] is True
    assert norms["id_2"] is False


@pytest.mark.asyncio
@pytest.mark.parametrize(
    ("command", "expected_thread"),
    [
        ("next_phase", "force_next_phase"),
        ("reset_phase", "reset_current_phase"),
        ("reset_experiment", "reset_experiment"),
    ],
)
async def test_send_experiment_control(agent, command, expected_thread):
    """Each experiment-control command is sent on its corresponding message thread."""
    await agent._send_experiment_control_to_bdi_core(command)

    agent.send.assert_awaited()
    assert agent.send.call_args.args[0].thread == expected_thread


@pytest.mark.asyncio
async def test_send_pause_command(agent):
    """Pause/resume sends one message to the RI communication agent and one to the VAD agent."""
    await agent._send_pause_command("true")

    # Sends to RI and VAD
    assert agent.send.await_count == 2

    ri_body = json.loads(_sent_to(agent, settings.agent_settings.ri_communication_name).body)
    # NOTE(review): consider asserting against the RIEndpoint member instead of a bare literal.
    assert ri_body["endpoint"] == ""  # PAUSE endpoint
    assert ri_body["data"] is True

    assert _sent_to(agent, settings.agent_settings.vad_name).body == "PAUSE"

    agent.send.reset_mock()
    await agent._send_pause_command("false")

    assert agent.send.await_count == 2
    assert _sent_to(agent, settings.agent_settings.vad_name).body == "RESUME"