"""Unit tests for ``UserInterruptAgent``.

Every test works on an agent whose ``send`` method, logger and ZMQ sockets are
replaced by mocks (see the ``agent`` fixture), so no real messaging takes place.
"""

import asyncio
import contextlib
import json
import uuid
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from control_backend.agents.user_interrupt.user_interrupt_agent import UserInterruptAgent
from control_backend.core.agent_system import InternalMessage
from control_backend.core.config import settings
from control_backend.schemas.belief_message import BeliefMessage
from control_backend.schemas.program import (
    ConditionalNorm,
    Goal,
    KeywordBelief,
    Phase,
    Plan,
    Program,
    Trigger,
)
from control_backend.schemas.ri_message import RIEndpoint


async def _run_receive_loop(agent):
    """Run ``agent._receive_button_event()`` and swallow ``asyncio.CancelledError``.

    The tests queue ``asyncio.CancelledError`` as the last ``side_effect`` of the
    mocked ``sub_socket.recv_multipart`` to stop the receive loop; whether the
    loop re-raises it or not, the test should simply continue afterwards.
    """
    with contextlib.suppress(asyncio.CancelledError):
        await agent._receive_button_event()


@pytest.fixture
def agent():
    """Return a ``UserInterruptAgent`` with ``send``, logger and sockets mocked out."""
    agent = UserInterruptAgent(name="user_interrupt_agent")
    agent.send = AsyncMock()
    agent.logger = MagicMock()
    agent.sub_socket = AsyncMock()
    agent.pub_socket = AsyncMock()
    return agent


@pytest.fixture(autouse=True)
def mock_experiment_logger():
    """Patch the module-level ``experiment_logger`` of the agent module for every test."""
    with patch(
        "control_backend.agents.user_interrupt.user_interrupt_agent.experiment_logger"
    ) as logger:
        yield logger


@pytest.mark.asyncio
async def test_send_to_speech_agent(agent):
    """Verify speech command format."""
    await agent._send_to_speech_agent("Hello World")

    agent.send.assert_awaited_once()
    sent_msg: InternalMessage = agent.send.call_args.args[0]

    assert sent_msg.to == settings.agent_settings.robot_speech_name
    body = json.loads(sent_msg.body)
    assert body["data"] == "Hello World"
    assert body["is_priority"] is True


@pytest.mark.asyncio
async def test_send_to_gesture_agent(agent):
    """Verify gesture command format."""
    await agent._send_to_gesture_agent("wave_hand")

    agent.send.assert_awaited_once()
    sent_msg: InternalMessage = agent.send.call_args.args[0]

    assert sent_msg.to == settings.agent_settings.robot_gesture_name
    body = json.loads(sent_msg.body)
    assert body["data"] == "wave_hand"
    assert body["is_priority"] is True
    assert body["endpoint"] == RIEndpoint.GESTURE_SINGLE.value


@pytest.mark.asyncio
async def test_send_to_bdi_belief(agent):
    """Verify belief update format."""
    context_str = "some_goal"

    await agent._send_to_bdi_belief(context_str, "goal")

    assert agent.send.await_count == 1
    sent_msg = agent.send.call_args.args[0]

    assert sent_msg.to == settings.agent_settings.bdi_core_name
    assert sent_msg.thread == "beliefs"
    assert "achieved_some_goal" in sent_msg.body


@pytest.mark.asyncio
async def test_receive_loop_routing_success(agent):
    """
    Test that the loop correctly:
    1. Receives 'button_pressed' topic from ZMQ
    2. Parses the JSON payload to find 'type' and 'context'
    3. Calls the correct handler method based on 'type'
    """
    # Prepare JSON payloads as bytes
    payload_speech = json.dumps({"type": "speech", "context": "Hello Speech"}).encode()
    payload_gesture = json.dumps({"type": "gesture", "context": "Hello Gesture"}).encode()

    # An override is routed by looking its context up in the id -> slug maps:
    # triggers go through _send_to_bdi, goals and conditional norms through
    # _send_to_bdi_belief. To test the goal route we need to populate the goal map.
    agent._goal_map["Hello Override"] = "some_goal_slug"
    payload_override = json.dumps({"type": "override", "context": "Hello Override"}).encode()

    agent.sub_socket.recv_multipart.side_effect = [
        (b"button_pressed", payload_speech),
        (b"button_pressed", payload_gesture),
        (b"button_pressed", payload_override),
        asyncio.CancelledError,  # Stop the infinite loop
    ]

    agent._send_to_speech_agent = AsyncMock()
    agent._send_to_gesture_agent = AsyncMock()
    agent._send_to_bdi_belief = AsyncMock()

    await _run_receive_loop(agent)

    await asyncio.sleep(0)

    # Speech
    agent._send_to_speech_agent.assert_awaited_once_with("Hello Speech")

    # Gesture
    agent._send_to_gesture_agent.assert_awaited_once_with("Hello Gesture")

    # Override (since we mapped it to a goal)
    agent._send_to_bdi_belief.assert_awaited_once_with("some_goal_slug", "goal")

    assert agent._send_to_speech_agent.await_count == 1
    assert agent._send_to_gesture_agent.await_count == 1
    assert agent._send_to_bdi_belief.await_count == 1


@pytest.mark.asyncio
async def test_receive_loop_unknown_type(agent):
    """Test that unknown 'type' values in the JSON log a warning and do not crash."""
    # Prepare a payload with an unknown type
    payload_unknown = json.dumps({"type": "unknown_thing", "context": "some_data"}).encode()

    agent.sub_socket.recv_multipart.side_effect = [
        (b"button_pressed", payload_unknown),
        asyncio.CancelledError,
    ]

    agent._send_to_speech_agent = AsyncMock()
    agent._send_to_gesture_agent = AsyncMock()

    await _run_receive_loop(agent)

    await asyncio.sleep(0)

    # Ensure no handlers were called
    agent._send_to_speech_agent.assert_not_called()
    agent._send_to_gesture_agent.assert_not_called()

    agent.logger.warning.assert_called()


@pytest.mark.asyncio
async def test_create_mapping(agent):
    """Verify that a new program fills the trigger, goal and conditional-norm maps."""
    # Create a program with a trigger, goal, and conditional norm
    trigger_id = uuid.uuid4()
    goal_id = uuid.uuid4()
    norm_id = uuid.uuid4()

    cond = KeywordBelief(id=uuid.uuid4(), name="k1", keyword="key")
    plan = Plan(id=uuid.uuid4(), name="p1", steps=[])

    trigger = Trigger(id=trigger_id, name="my_trigger", condition=cond, plan=plan)
    goal = Goal(id=goal_id, name="my_goal", description="desc", plan=plan)

    cn = ConditionalNorm(id=norm_id, name="my_norm", norm="be polite", condition=cond)

    phase = Phase(id=uuid.uuid4(), name="phase1", norms=[cn], goals=[goal], triggers=[trigger])
    prog = Program(phases=[phase])

    # Call create_mapping via handle_message
    msg = InternalMessage(to="me", thread="new_program", body=prog.model_dump_json())
    await agent.handle_message(msg)

    # Check maps
    assert str(trigger_id) in agent._trigger_map
    assert agent._trigger_map[str(trigger_id)] == "trigger_my_trigger"

    assert str(goal_id) in agent._goal_map
    assert agent._goal_map[str(goal_id)] == "my_goal"

    assert str(norm_id) in agent._cond_norm_map
    assert agent._cond_norm_map[str(norm_id)] == "norm_be_polite"


@pytest.mark.asyncio
async def test_create_mapping_invalid_json(agent):
    """Verify that an unparsable program body is reported through the error log."""
    # Pass invalid json to handle_message thread "new_program"
    msg = InternalMessage(to="me", thread="new_program", body="invalid json")
    await agent.handle_message(msg)

    # Should log error and maps should remain empty or cleared
    agent.logger.error.assert_called()


@pytest.mark.asyncio
async def test_handle_message_trigger_start(agent):
    """Verify that 'trigger_start' publishes a trigger_update with achieved=True."""
    # Setup reverse map manually
    agent._trigger_reverse_map["trigger_slug"] = "ui_id_123"

    msg = InternalMessage(to="me", thread="trigger_start", body="trigger_slug")
    await agent.handle_message(msg)

    agent.pub_socket.send_multipart.assert_awaited_once()
    args = agent.pub_socket.send_multipart.call_args[0][0]
    assert args[0] == b"experiment"
    payload = json.loads(args[1])
    assert payload["type"] == "trigger_update"
    assert payload["id"] == "ui_id_123"
    assert payload["achieved"] is True


@pytest.mark.asyncio
async def test_handle_message_trigger_end(agent):
    """Verify that 'trigger_end' publishes a trigger_update with achieved=False."""
    agent._trigger_reverse_map["trigger_slug"] = "ui_id_123"

    msg = InternalMessage(to="me", thread="trigger_end", body="trigger_slug")
    await agent.handle_message(msg)

    agent.pub_socket.send_multipart.assert_awaited_once()
    payload = json.loads(agent.pub_socket.send_multipart.call_args[0][0][1])
    assert payload["type"] == "trigger_update"
    assert payload["achieved"] is False


@pytest.mark.asyncio
async def test_handle_message_transition_phase(agent):
    """Verify that 'transition_phase' publishes a phase_update carrying the phase id."""
    msg = InternalMessage(to="me", thread="transition_phase", body="phase_id_123")
    await agent.handle_message(msg)

    agent.pub_socket.send_multipart.assert_awaited_once()
    payload = json.loads(agent.pub_socket.send_multipart.call_args[0][0][1])
    assert payload["type"] == "phase_update"
    assert payload["id"] == "phase_id_123"


@pytest.mark.asyncio
async def test_handle_message_goal_start(agent):
    """Verify that 'goal_start' publishes a goal_update with active=True."""
    agent._goal_reverse_map["goal_slug"] = "goal_id_123"

    msg = InternalMessage(to="me", thread="goal_start", body="goal_slug")
    await agent.handle_message(msg)

    agent.pub_socket.send_multipart.assert_awaited_once()
    payload = json.loads(agent.pub_socket.send_multipart.call_args[0][0][1])
    assert payload["type"] == "goal_update"
    assert payload["id"] == "goal_id_123"
    assert payload["active"] is True


@pytest.mark.asyncio
async def test_handle_message_active_norms_update(agent):
    """Verify that 'active_norms_update' reports every known norm as active or inactive."""
    agent._cond_norm_reverse_map["norm_active"] = "id_1"
    agent._cond_norm_reverse_map["norm_inactive"] = "id_2"

    # The body is a comma-separated list of quoted slugs, e.g. "'norm_active', 'other'".
    # The agent is expected to split it and strip the quotes itself.
    msg = InternalMessage(to="me", thread="active_norms_update", body="'norm_active', 'other'")
    await agent.handle_message(msg)

    agent.pub_socket.send_multipart.assert_awaited_once()
    payload = json.loads(agent.pub_socket.send_multipart.call_args[0][0][1])
    assert payload["type"] == "cond_norms_state_update"
    norms = {n["id"]: n["active"] for n in payload["norms"]}
    assert norms["id_1"] is True
    assert norms["id_2"] is False


@pytest.mark.asyncio
async def test_send_experiment_control(agent):
    """Verify the thread name used for each known experiment control type."""
    # Test next_phase
    await agent._send_experiment_control_to_bdi_core("next_phase")
    agent.send.assert_awaited()
    msg = agent.send.call_args[0][0]
    assert msg.thread == "force_next_phase"

    # Test reset_phase
    await agent._send_experiment_control_to_bdi_core("reset_phase")
    msg = agent.send.call_args[0][0]
    assert msg.thread == "reset_current_phase"

    # Test reset_experiment
    await agent._send_experiment_control_to_bdi_core("reset_experiment")
    msg = agent.send.call_args[0][0]
    assert msg.thread == "reset_experiment"


@pytest.mark.asyncio
async def test_send_pause_command(agent):
    """Verify that pause/resume is sent to both the RI communication agent and the VAD agent."""
    await agent._send_pause_command("true")

    # Sends to RI and VAD
    assert agent.send.await_count == 2
    msgs = [call.args[0] for call in agent.send.call_args_list]

    ri_msg = next(m for m in msgs if m.to == settings.agent_settings.ri_communication_name)
    ri_body = json.loads(ri_msg.body)
    assert ri_body["endpoint"] == ""  # PAUSE endpoint
    assert ri_body["data"] is True

    vad_msg = next(m for m in msgs if m.to == settings.agent_settings.vad_name)
    assert vad_msg.body == "PAUSE"

    agent.send.reset_mock()
    await agent._send_pause_command("false")
    assert agent.send.await_count == 2
    vad_msg = next(
        m for m in agent.send.call_args_list if m.args[0].to == settings.agent_settings.vad_name
    ).args[0]
    assert vad_msg.body == "RESUME"


@pytest.mark.asyncio
async def test_setup(agent):
    """Test the setup method initializes sockets correctly."""
    with patch("control_backend.agents.user_interrupt.user_interrupt_agent.Context") as MockContext:
        mock_ctx_instance = MagicMock()
        MockContext.instance.return_value = mock_ctx_instance

        mock_sub = MagicMock()
        mock_pub = MagicMock()
        # The first socket() call yields the SUB socket, the second the PUB socket.
        mock_ctx_instance.socket.side_effect = [mock_sub, mock_pub]

        # MOCK add_behavior so we don't rely on internal attributes
        agent.add_behavior = MagicMock()

        await agent.setup()

        # Check sockets
        mock_sub.connect.assert_called_with(settings.zmq_settings.internal_sub_address)
        mock_pub.connect.assert_called_with(settings.zmq_settings.internal_pub_address)

        # Verify add_behavior was called
        agent.add_behavior.assert_called_once()


@pytest.mark.asyncio
async def test_receive_loop_json_error(agent):
    """Verify that malformed JSON is caught and logged without crashing the loop."""
    agent.sub_socket.recv_multipart.side_effect = [
        (b"topic", b"INVALID{JSON"),
        asyncio.CancelledError,
    ]

    await _run_receive_loop(agent)

    agent.logger.error.assert_called_with("Received invalid JSON payload on topic %s", b"topic")


@pytest.mark.asyncio
async def test_receive_loop_override_trigger(agent):
    """Verify routing 'override' to a Trigger."""
    agent._trigger_map["101"] = "trigger_slug"
    payload = json.dumps({"type": "override", "context": "101"}).encode()

    agent.sub_socket.recv_multipart.side_effect = [(b"topic", payload), asyncio.CancelledError]
    agent._send_to_bdi = AsyncMock()

    await _run_receive_loop(agent)

    agent._send_to_bdi.assert_awaited_once_with("force_trigger", "trigger_slug")


@pytest.mark.asyncio
async def test_receive_loop_override_norm(agent):
    """Verify routing 'override' to a Conditional Norm."""
    agent._cond_norm_map["202"] = "norm_slug"
    payload = json.dumps({"type": "override", "context": "202"}).encode()

    agent.sub_socket.recv_multipart.side_effect = [(b"topic", payload), asyncio.CancelledError]
    agent._send_to_bdi_belief = AsyncMock()

    await _run_receive_loop(agent)

    agent._send_to_bdi_belief.assert_awaited_once_with("norm_slug", "cond_norm")


@pytest.mark.asyncio
async def test_receive_loop_override_missing(agent):
    """Verify warning log when an override ID is not found in any map."""
    payload = json.dumps({"type": "override", "context": "999"}).encode()

    agent.sub_socket.recv_multipart.side_effect = [(b"topic", payload), asyncio.CancelledError]

    await _run_receive_loop(agent)

    agent.logger.warning.assert_called_with("Could not determine which element to override.")


@pytest.mark.asyncio
async def test_receive_loop_unachieve_logic(agent):
    """Verify success and failure paths for override_unachieve."""
    agent._cond_norm_map["202"] = "norm_slug"

    success_payload = json.dumps({"type": "override_unachieve", "context": "202"}).encode()
    fail_payload = json.dumps({"type": "override_unachieve", "context": "999"}).encode()

    agent.sub_socket.recv_multipart.side_effect = [
        (b"topic", success_payload),
        (b"topic", fail_payload),
        asyncio.CancelledError,
    ]
    agent._send_to_bdi_belief = AsyncMock()

    await _run_receive_loop(agent)

    # Assert success call (True flag for unachieve)
    agent._send_to_bdi_belief.assert_any_call("norm_slug", "cond_norm", True)
    # Assert failure log
    agent.logger.warning.assert_called_with(
        "Could not determine which conditional norm to unachieve."
    )


@pytest.mark.asyncio
async def test_receive_loop_pause_resume(agent):
    """Verify pause and resume toggle logic and logging."""
    pause_payload = json.dumps({"type": "pause", "context": "true"}).encode()
    resume_payload = json.dumps({"type": "pause", "context": ""}).encode()

    agent.sub_socket.recv_multipart.side_effect = [
        (b"topic", pause_payload),
        (b"topic", resume_payload),
        asyncio.CancelledError,
    ]
    agent._send_pause_command = AsyncMock()

    await _run_receive_loop(agent)

    agent._send_pause_command.assert_any_call("true")
    agent._send_pause_command.assert_any_call("")
    agent.logger.info.assert_any_call("Sent pause command.")
    agent.logger.info.assert_any_call("Sent resume command.")


@pytest.mark.asyncio
async def test_receive_loop_phase_control(agent):
    """Verify experiment flow control (next_phase)."""
    payload = json.dumps({"type": "next_phase", "context": ""}).encode()

    agent.sub_socket.recv_multipart.side_effect = [(b"topic", payload), asyncio.CancelledError]
    agent._send_experiment_control_to_bdi_core = AsyncMock()

    await _run_receive_loop(agent)

    agent._send_experiment_control_to_bdi_core.assert_awaited_once_with("next_phase")


@pytest.mark.asyncio
async def test_handle_message_unknown_thread(agent):
    """Test that a message on an unknown thread is only reported at debug level."""
    msg = InternalMessage(to="me", thread="unknown_thread", body="test")

    await agent.handle_message(msg)

    agent.logger.debug.assert_called_with(
        "Received internal message on unhandled thread: unknown_thread"
    )


@pytest.mark.asyncio
async def test_send_to_bdi_belief_edge_cases(agent):
    """
    Covers:
    - Unknown asl_type: a warning is logged and nothing is sent
    - unachieve=True: the belief is sent as a delete instead of a create
    """
    # 1. Unknown Type
    await agent._send_to_bdi_belief("slug", "unknown_type")

    agent.logger.warning.assert_called_with("Tried to send belief with unknown type")
    agent.send.assert_not_called()

    # Reset mock for part 2
    agent.send.reset_mock()

    # 2. Unachieve = True
    await agent._send_to_bdi_belief("slug", "cond_norm", unachieve=True)

    agent.send.assert_awaited()
    sent_msg = agent.send.call_args.args[0]

    # Verify it is a delete operation
    body_obj = BeliefMessage.model_validate_json(sent_msg.body)

    # Verify 'delete' has content
    assert body_obj.delete is not None
    assert len(body_obj.delete) == 1
    assert body_obj.delete[0].name == "force_slug"

    # Verify 'create' is empty (handling both None and [])
    assert not body_obj.create


@pytest.mark.asyncio
async def test_send_experiment_control_unknown(agent):
    """Test sending an unknown experiment control type."""
    await agent._send_experiment_control_to_bdi_core("invalid_command")

    agent.logger.warning.assert_called_with(
        "Received unknown experiment control type '%s' to send to BDI Core.", "invalid_command"
    )

    # Ensure it still sends an empty message (as per code logic, though thread is empty)
    agent.send.assert_awaited()
    msg = agent.send.call_args[0][0]
    assert msg.thread == ""


@pytest.mark.asyncio
async def test_create_mapping_recursive_goals(agent):
    """Verify that nested subgoals are correctly registered in the mapping."""
    # 1. Setup IDs
    parent_goal_id = uuid.uuid4()
    child_goal_id = uuid.uuid4()

    # 2. Create the child goal
    child_goal = Goal(
        id=child_goal_id,
        name="child_goal",
        description="I am a subgoal",
        plan=Plan(id=uuid.uuid4(), name="p_child", steps=[]),
    )

    # 3. Create the parent goal and put the child goal inside its plan steps
    parent_goal = Goal(
        id=parent_goal_id,
        name="parent_goal",
        description="I am a parent",
        plan=Plan(id=uuid.uuid4(), name="p_parent", steps=[child_goal]),  # Nested here
    )

    # 4. Build the program
    phase = Phase(
        id=uuid.uuid4(),
        name="phase1",
        norms=[],
        goals=[parent_goal],  # Only the parent is top-level
        triggers=[],
    )
    prog = Program(phases=[phase])

    # 5. Execute mapping
    msg = InternalMessage(to="me", thread="new_program", body=prog.model_dump_json())
    await agent.handle_message(msg)

    # 6. Assertions
    # Check parent
    assert str(parent_goal_id) in agent._goal_map
    assert agent._goal_map[str(parent_goal_id)] == "parent_goal"

    # Check child (This confirms the recursion worked)
    assert str(child_goal_id) in agent._goal_map
    assert agent._goal_map[str(child_goal_id)] == "child_goal"
    assert agent._goal_reverse_map["child_goal"] == str(child_goal_id)


@pytest.mark.asyncio
async def test_receive_loop_advanced_scenarios(agent):
    """
    Feeds one payload per branch through a single run of the receive loop.

    Covers:
    - Invalid JSON payload
    - Override: Trigger found
    - Override: Norm found
    - Override: Nothing found
    - Override Unachieve: Success & Fail
    - Pause: Context true/false logs
    - Next Phase
    """
    # 1. Setup Data Maps
    agent._trigger_map["101"] = "trigger_slug"
    agent._cond_norm_map["202"] = "norm_slug"

    # 2. Define Payloads
    # A. Invalid JSON
    bad_json = b"INVALID{JSON"

    # B. Override -> Trigger
    override_trigger = json.dumps({"type": "override", "context": "101"}).encode()

    # C. Override -> Norm
    override_norm = json.dumps({"type": "override", "context": "202"}).encode()

    # D. Override -> Unknown
    override_fail = json.dumps({"type": "override", "context": "999"}).encode()

    # E. Unachieve -> Success
    unachieve_success = json.dumps({"type": "override_unachieve", "context": "202"}).encode()

    # F. Unachieve -> Fail
    unachieve_fail = json.dumps({"type": "override_unachieve", "context": "999"}).encode()

    # G. Pause (True)
    pause_true = json.dumps({"type": "pause", "context": "true"}).encode()

    # H. Pause (False/Resume)
    pause_false = json.dumps({"type": "pause", "context": ""}).encode()

    # I. Next Phase
    next_phase = json.dumps({"type": "next_phase", "context": ""}).encode()

    # 3. Setup Socket
    agent.sub_socket.recv_multipart.side_effect = [
        (b"topic", bad_json),
        (b"topic", override_trigger),
        (b"topic", override_norm),
        (b"topic", override_fail),
        (b"topic", unachieve_success),
        (b"topic", unachieve_fail),
        (b"topic", pause_true),
        (b"topic", pause_false),
        (b"topic", next_phase),
        asyncio.CancelledError,  # End loop
    ]

    # Mock internal helpers to verify calls
    agent._send_to_bdi = AsyncMock()
    agent._send_to_bdi_belief = AsyncMock()
    agent._send_pause_command = AsyncMock()
    agent._send_experiment_control_to_bdi_core = AsyncMock()

    # 4. Run Loop
    await _run_receive_loop(agent)

    # 5. Assertions

    # JSON Error
    agent.logger.error.assert_called_with("Received invalid JSON payload on topic %s", b"topic")

    # Override Trigger
    agent._send_to_bdi.assert_awaited_with("force_trigger", "trigger_slug")

    # Override Norm
    # We expect _send_to_bdi_belief to be called for the norm
    # Note: The loop calls _send_to_bdi_belief(asl_cond_norm, "cond_norm")
    agent._send_to_bdi_belief.assert_any_call("norm_slug", "cond_norm")

    # Override Fail (Warning log)
    agent.logger.warning.assert_any_call("Could not determine which element to override.")

    # Unachieve Success
    # Loop calls _send_to_bdi_belief(asl_cond_norm, "cond_norm", True)
    agent._send_to_bdi_belief.assert_any_call("norm_slug", "cond_norm", True)

    # Unachieve Fail
    agent.logger.warning.assert_any_call("Could not determine which conditional norm to unachieve.")

    # Pause Logic
    agent._send_pause_command.assert_any_call("true")
    agent.logger.info.assert_any_call("Sent pause command.")

    # Resume Logic
    agent._send_pause_command.assert_any_call("")
    agent.logger.info.assert_any_call("Sent resume command.")

    # Next Phase
    agent._send_experiment_control_to_bdi_core.assert_awaited_with("next_phase")