diff --git a/src/control_backend/agents/bdi/bdi_program_manager.py b/src/control_backend/agents/bdi/bdi_program_manager.py index 54c9983..6e8a594 100644 --- a/src/control_backend/agents/bdi/bdi_program_manager.py +++ b/src/control_backend/agents/bdi/bdi_program_manager.py @@ -198,9 +198,9 @@ class BDIProgramManager(BaseAgent): :return: All goals within and including the given goal. """ goals: list[Goal] = [goal] - for plan in goal.plan: - if isinstance(plan, Goal): - goals.extend(BDIProgramManager._extract_goals_from_goal(plan)) + for step in goal.plan.steps: + if isinstance(step, Goal): + goals.extend(BDIProgramManager._extract_goals_from_goal(step)) return goals def _extract_current_goals(self) -> list[Goal]: diff --git a/src/control_backend/agents/user_interrupt/user_interrupt_agent.py b/src/control_backend/agents/user_interrupt/user_interrupt_agent.py index e2b2d87..117f83c 100644 --- a/src/control_backend/agents/user_interrupt/user_interrupt_agent.py +++ b/src/control_backend/agents/user_interrupt/user_interrupt_agent.py @@ -336,7 +336,6 @@ class UserInterruptAgent(BaseAgent): else: self.logger.warning("Tried to send belief with unknown type") return - belief = Belief(name=belief_name, arguments=None) self.logger.debug(f"Sending belief to BDI Core: {belief_name}") # Conditional norms are unachieved by removing the belief diff --git a/test/unit/agents/bdi/test_bdi_program_manager.py b/test/unit/agents/bdi/test_bdi_program_manager.py index 5771451..646075b 100644 --- a/test/unit/agents/bdi/test_bdi_program_manager.py +++ b/test/unit/agents/bdi/test_bdi_program_manager.py @@ -8,7 +8,17 @@ import pytest from control_backend.agents.bdi.bdi_program_manager import BDIProgramManager from control_backend.core.agent_system import InternalMessage -from control_backend.schemas.program import BasicNorm, Goal, Phase, Plan, Program +from control_backend.schemas.program import ( + BasicNorm, + ConditionalNorm, + Goal, + InferredBelief, + KeywordBelief, + Phase, + Plan, + Program, + Trigger, +) # Fix Windows Proactor loop for zmq if sys.platform.startswith("win"): @@ -295,3 +305,98 @@ async def test_setup(mock_settings): # 3. Adds behavior manager.add_behavior.assert_called() + + +@pytest.mark.asyncio +async def test_send_program_to_user_interrupt(mock_settings): + """Test directly sending the program to the user interrupt agent.""" + mock_settings.agent_settings.user_interrupt_name = "user_interrupt_agent" + + manager = BDIProgramManager(name="program_manager_test") + manager.send = AsyncMock() + + program = Program.model_validate_json(make_valid_program_json()) + + await manager._send_program_to_user_interrupt(program) + + assert manager.send.await_count == 1 + msg = manager.send.await_args[0][0] + assert msg.to == "user_interrupt_agent" + assert msg.thread == "new_program" + assert "Basic Phase" in msg.body + + +@pytest.mark.asyncio +async def test_complex_program_extraction(): + manager = BDIProgramManager(name="program_manager_test") + + # 1. Create Complex Components + + # Inferred Belief (A & B) + belief_left = KeywordBelief(id=uuid.uuid4(), name="b1", keyword="hot") + belief_right = KeywordBelief(id=uuid.uuid4(), name="b2", keyword="sunny") + inferred_belief = InferredBelief( + id=uuid.uuid4(), name="b_inf", operator="AND", left=belief_left, right=belief_right + ) + + # Conditional Norm + cond_norm = ConditionalNorm( + id=uuid.uuid4(), name="norm_cond", norm="wear_hat", condition=inferred_belief + ) + + # Trigger with Inferred Belief condition + dummy_plan = Plan(id=uuid.uuid4(), name="dummy_plan", steps=[]) + trigger = Trigger(id=uuid.uuid4(), name="trigger_1", condition=inferred_belief, plan=dummy_plan) + + # Nested Goal + sub_goal = Goal( + id=uuid.uuid4(), + name="sub_goal", + description="desc", + plan=Plan(id=uuid.uuid4(), name="empty", steps=[]), + can_fail=True, + ) + + parent_goal = Goal( + id=uuid.uuid4(), + name="parent_goal", + description="desc", + # The plan contains the sub_goal as a step + plan=Plan(id=uuid.uuid4(), name="parent_plan", steps=[sub_goal]), + can_fail=False, + ) + + # 2. Assemble Program + phase = Phase( + id=uuid.uuid4(), + name="Complex Phase", + norms=[cond_norm], + goals=[parent_goal], + triggers=[trigger], + ) + program = Program(phases=[phase]) + + # 3. Initialize Internal State (Triggers _populate_goal_mapping -> Nested Goal logic) + manager._initialize_internal_state(program) + + # Assertion for Line 53-54 (Mapping population) + # Both parent and sub-goal should be mapped + assert str(parent_goal.id) in manager._goal_mapping + assert str(sub_goal.id) in manager._goal_mapping + + # 4. Test Belief Extraction (Triggers lines 132-133, 142-146) + beliefs = manager._extract_current_beliefs() + + # Should extract recursive beliefs from cond_norm and trigger + # Inferred belief splits into Left + Right. Since we use it twice, we get duplicates + # checking existence is enough. + belief_names = [b.name for b in beliefs] + assert "b1" in belief_names + assert "b2" in belief_names + + # 5. Test Goal Extraction (Triggers lines 173, 185) + goals = manager._extract_current_goals() + + goal_names = [g.name for g in goals] + assert "parent_goal" in goal_names + assert "sub_goal" in goal_names diff --git a/test/unit/agents/user_interrupt/test_user_interrupt.py b/test/unit/agents/user_interrupt/test_user_interrupt.py index 7a71891..a69a830 100644 --- a/test/unit/agents/user_interrupt/test_user_interrupt.py +++ b/test/unit/agents/user_interrupt/test_user_interrupt.py @@ -1,12 +1,13 @@ import asyncio import json -from unittest.mock import AsyncMock, MagicMock +from unittest.mock import AsyncMock, MagicMock, patch import pytest from control_backend.agents.user_interrupt.user_interrupt_agent import UserInterruptAgent from control_backend.core.agent_system import InternalMessage from control_backend.core.config import settings +from control_backend.schemas.belief_message import BeliefMessage from control_backend.schemas.program import ( ConditionalNorm, Goal, @@ -309,3 +310,220 @@ async def test_send_pause_command(agent): m for m in agent.send.call_args_list if m.args[0].to == settings.agent_settings.vad_name ).args[0] assert vad_msg.body == "RESUME" + + +@pytest.mark.asyncio +async def test_setup(agent): + """Test the setup method initializes sockets correctly.""" + with patch("control_backend.agents.user_interrupt.user_interrupt_agent.Context") as MockContext: + mock_ctx_instance = MagicMock() + MockContext.instance.return_value = mock_ctx_instance + + mock_sub = MagicMock() + mock_pub = MagicMock() + mock_ctx_instance.socket.side_effect = [mock_sub, mock_pub] + + # MOCK add_behavior so we don't rely on internal attributes + agent.add_behavior = MagicMock() + + await agent.setup() + + # Check sockets + mock_sub.connect.assert_called_with(settings.zmq_settings.internal_sub_address) + mock_pub.connect.assert_called_with(settings.zmq_settings.internal_pub_address) + + # Verify add_behavior was called + agent.add_behavior.assert_called_once() + + +@pytest.mark.asyncio +async def test_receive_loop_json_error(agent): + """Verify that malformed JSON is caught and logged without crashing the loop.""" + agent.sub_socket.recv_multipart.side_effect = [ + (b"topic", b"INVALID{JSON"), + asyncio.CancelledError, + ] + + try: + await agent._receive_button_event() + except asyncio.CancelledError: + pass + + agent.logger.error.assert_called_with("Received invalid JSON payload on topic %s", b"topic") + + +@pytest.mark.asyncio +async def test_receive_loop_override_trigger(agent): + """Verify routing 'override' to a Trigger.""" + agent._trigger_map["101"] = "trigger_slug" + payload = json.dumps({"type": "override", "context": "101"}).encode() + + agent.sub_socket.recv_multipart.side_effect = [(b"topic", payload), asyncio.CancelledError] + agent._send_to_bdi = AsyncMock() + + try: + await agent._receive_button_event() + except asyncio.CancelledError: + pass + + agent._send_to_bdi.assert_awaited_once_with("force_trigger", "trigger_slug") + + +@pytest.mark.asyncio +async def test_receive_loop_override_norm(agent): + """Verify routing 'override' to a Conditional Norm.""" + agent._cond_norm_map["202"] = "norm_slug" + payload = json.dumps({"type": "override", "context": "202"}).encode() + + agent.sub_socket.recv_multipart.side_effect = [(b"topic", payload), asyncio.CancelledError] + agent._send_to_bdi_belief = AsyncMock() + + try: + await agent._receive_button_event() + except asyncio.CancelledError: + pass + + agent._send_to_bdi_belief.assert_awaited_once_with("norm_slug", "cond_norm") + + +@pytest.mark.asyncio +async def test_receive_loop_override_missing(agent): + """Verify warning log when an override ID is not found in any map.""" + payload = json.dumps({"type": "override", "context": "999"}).encode() + + agent.sub_socket.recv_multipart.side_effect = [(b"topic", payload), asyncio.CancelledError] + + try: + await agent._receive_button_event() + except asyncio.CancelledError: + pass + + agent.logger.warning.assert_called_with("Could not determine which element to override.") + + +@pytest.mark.asyncio +async def test_receive_loop_unachieve_logic(agent): + """Verify success and failure paths for override_unachieve.""" + agent._cond_norm_map["202"] = "norm_slug" + + success_payload = json.dumps({"type": "override_unachieve", "context": "202"}).encode() + fail_payload = json.dumps({"type": "override_unachieve", "context": "999"}).encode() + + agent.sub_socket.recv_multipart.side_effect = [ + (b"topic", success_payload), + (b"topic", fail_payload), + asyncio.CancelledError, + ] + agent._send_to_bdi_belief = AsyncMock() + + try: + await agent._receive_button_event() + except asyncio.CancelledError: + pass + + # Assert success call (True flag for unachieve) + agent._send_to_bdi_belief.assert_any_call("norm_slug", "cond_norm", True) + # Assert failure log + agent.logger.warning.assert_called_with( + "Could not determine which conditional norm to unachieve." + ) + + +@pytest.mark.asyncio +async def test_receive_loop_pause_resume(agent): + """Verify pause and resume toggle logic and logging.""" + pause_payload = json.dumps({"type": "pause", "context": "true"}).encode() + resume_payload = json.dumps({"type": "pause", "context": ""}).encode() + + agent.sub_socket.recv_multipart.side_effect = [ + (b"topic", pause_payload), + (b"topic", resume_payload), + asyncio.CancelledError, + ] + agent._send_pause_command = AsyncMock() + + try: + await agent._receive_button_event() + except asyncio.CancelledError: + pass + + agent._send_pause_command.assert_any_call("true") + agent._send_pause_command.assert_any_call("") + agent.logger.info.assert_any_call("Sent pause command.") + agent.logger.info.assert_any_call("Sent resume command.") + + +@pytest.mark.asyncio +async def test_receive_loop_phase_control(agent): + """Verify experiment flow control (next_phase).""" + payload = json.dumps({"type": "next_phase", "context": ""}).encode() + + agent.sub_socket.recv_multipart.side_effect = [(b"topic", payload), asyncio.CancelledError] + agent._send_experiment_control_to_bdi_core = AsyncMock() + + try: + await agent._receive_button_event() + except asyncio.CancelledError: + pass + + agent._send_experiment_control_to_bdi_core.assert_awaited_once_with("next_phase") + + +@pytest.mark.asyncio +async def test_handle_message_unknown_thread(agent): + """Test handling of an unknown message thread (lines 213-214).""" + msg = InternalMessage(to="me", thread="unknown_thread", body="test") + await agent.handle_message(msg) + + agent.logger.debug.assert_called_with( + "Received internal message on unhandled thread: unknown_thread" + ) + + +@pytest.mark.asyncio +async def test_send_to_bdi_belief_edge_cases(agent): + """ + Covers: + - Unknown asl_type warning (lines 326-328) + - unachieve=True logic (lines 334-337) + """ + # 1. Unknown Type + await agent._send_to_bdi_belief("slug", "unknown_type") + + agent.logger.warning.assert_called_with("Tried to send belief with unknown type") + agent.send.assert_not_called() + + # Reset mock for part 2 + agent.send.reset_mock() + + # 2. Unachieve = True + await agent._send_to_bdi_belief("slug", "cond_norm", unachieve=True) + + agent.send.assert_awaited() + sent_msg = agent.send.call_args.args[0] + + # Verify it is a delete operation + body_obj = BeliefMessage.model_validate_json(sent_msg.body) + + # Verify 'delete' has content + assert body_obj.delete is not None + assert len(body_obj.delete) == 1 + assert body_obj.delete[0].name == "force_slug" + + # Verify 'create' is empty (handling both None and []) + assert not body_obj.create + + +@pytest.mark.asyncio +async def test_send_experiment_control_unknown(agent): + """Test sending an unknown experiment control type (lines 366-367).""" + await agent._send_experiment_control_to_bdi_core("invalid_command") + + agent.logger.warning.assert_called_with( + "Received unknown experiment control type '%s' to send to BDI Core.", "invalid_command" + ) + + # Ensure it still sends an empty message (as per code logic, though thread is empty) + agent.send.assert_awaited() + msg = agent.send.call_args[0][0] + assert msg.thread == "" diff --git a/test/unit/api/v1/endpoints/test_user_interact.py b/test/unit/api/v1/endpoints/test_user_interact.py index ddb9932..9785eec 100644 --- a/test/unit/api/v1/endpoints/test_user_interact.py +++ b/test/unit/api/v1/endpoints/test_user_interact.py @@ -94,3 +94,55 @@ async def test_experiment_stream_direct_call(): mock_socket.connect.assert_called() mock_socket.subscribe.assert_called_with(b"experiment") mock_socket.close.assert_called() + + +@pytest.mark.asyncio +async def test_status_stream_direct_call(): + """ + Test the status stream, ensuring it handles messages and sends pings on timeout. + """ + mock_socket = AsyncMock() + + # Define the sequence of events for the socket: + # 1. Successfully receive a message + # 2. Timeout (which should trigger the ': ping' yield) + # 3. Another message (which won't be reached because we'll simulate disconnect) + mock_socket.recv_multipart.side_effect = [ + (b"topic", b"status_update"), + TimeoutError(), + (b"topic", b"ignored_msg"), + ] + + mock_socket.close = MagicMock() + mock_socket.connect = MagicMock() + mock_socket.subscribe = MagicMock() + + mock_context = MagicMock() + mock_context.socket.return_value = mock_socket + + # Mock the ZMQ Context to return our mock_socket + with patch( + "control_backend.api.v1.endpoints.user_interact.Context.instance", return_value=mock_context + ): + mock_request = AsyncMock() + + # is_disconnected sequence: + # 1. False -> Process "status_update" + # 2. False -> Process TimeoutError (yield ping) + # 3. True -> Break loop (client disconnected) + mock_request.is_disconnected.side_effect = [False, False, True] + + # Call the status_stream function explicitly + response = await user_interact.status_stream(mock_request) + + lines = [] + async for line in response.body_iterator: + lines.append(line) + + # Assertions + assert "data: status_update\n\n" in lines + assert ": ping\n\n" in lines # Verify lines 91-92 (ping logic) + + mock_socket.connect.assert_called() + mock_socket.subscribe.assert_called_with(b"status") + mock_socket.close.assert_called()