From 041fc4ab6e01183512345d83887e9df5e4d58c17 Mon Sep 17 00:00:00 2001 From: Pim Hutting Date: Thu, 15 Jan 2026 09:02:52 +0100 Subject: [PATCH 1/7] chore: cond_norms unachieve and via belief msg --- .../user_interrupt/user_interrupt_agent.py | 140 ++++++++++-------- 1 file changed, 80 insertions(+), 60 deletions(-) diff --git a/src/control_backend/agents/user_interrupt/user_interrupt_agent.py b/src/control_backend/agents/user_interrupt/user_interrupt_agent.py index deddbba..0bde563 100644 --- a/src/control_backend/agents/user_interrupt/user_interrupt_agent.py +++ b/src/control_backend/agents/user_interrupt/user_interrupt_agent.py @@ -26,7 +26,7 @@ class UserInterruptAgent(BaseAgent): - Send a prioritized message to the `RobotSpeechAgent` - Send a prioritized gesture to the `RobotGestureAgent` - - Send a belief override to the `BDIProgramManager`in order to activate a + - Send a belief override to the `BDI Core` in order to activate a trigger/conditional norm or complete a goal. Prioritized actions clear the current RI queue before inserting the new item, @@ -75,7 +75,9 @@ class UserInterruptAgent(BaseAgent): These are the different types and contexts: - type: "speech", context: string that the robot has to say. - type: "gesture", context: single gesture name that the robot has to perform. - - type: "override", context: belief_id that overrides the goal/trigger/conditional norm. + - type: "override", context: id that belongs to the goal/trigger/conditional norm. + - type: "override_unachieve", context: id that belongs to the conditional norm to unachieve. + - type: "next_phase", context: None, indicates to the BDI Core to - type: "pause", context: boolean indicating whether to pause - type: "reset_phase", context: None, indicates to the BDI Core to - type: "reset_experiment", context: None, indicates to the BDI Core to @@ -93,68 +95,82 @@ class UserInterruptAgent(BaseAgent): self.logger.debug("Received event type %s", event_type) - if event_type == "speech": - await self._send_to_speech_agent(event_context) - self.logger.info( - "Forwarded button press (speech) with context '%s' to RobotSpeechAgent.", - event_context, - ) - elif event_type == "gesture": - await self._send_to_gesture_agent(event_context) - self.logger.info( - "Forwarded button press (gesture) with context '%s' to RobotGestureAgent.", - event_context, - ) - elif event_type == "override": - ui_id = str(event_context) - if asl_trigger := self._trigger_map.get(ui_id): - await self._send_to_bdi("force_trigger", asl_trigger) + match event_type: + case "speech": + await self._send_to_speech_agent(event_context) self.logger.info( - "Forwarded button press (override) with context '%s' to BDI Core.", + "Forwarded button press (speech) with context '%s' to RobotSpeechAgent.", event_context, ) - elif asl_cond_norm := self._cond_norm_map.get(ui_id): - await self._send_to_bdi("force_norm", asl_cond_norm) + case "gesture": + await self._send_to_gesture_agent(event_context) self.logger.info( - "Forwarded button press (override) with context '%s' to BDIProgramManager.", + "Forwarded button press (gesture) with context '%s' to RobotGestureAgent.", event_context, ) - elif asl_goal := self._goal_map.get(ui_id): - await self._send_to_bdi_belief(asl_goal) - self.logger.info( - "Forwarded button press (override) with context '%s' to BDI Core.", + case "override": + ui_id = str(event_context) + if asl_trigger := self._trigger_map.get(ui_id): + await self._send_to_bdi("force_trigger", asl_trigger) + self.logger.info( + "Forwarded button press (override) with context '%s' to BDI Core.", + event_context, + ) + elif asl_cond_norm := self._cond_norm_map.get(ui_id): + await self._send_to_bdi_belief(asl_cond_norm) + self.logger.info( + "Forwarded button press (override) with context '%s' to BDI Core.", + event_context, + ) + elif asl_goal := self._goal_map.get(ui_id): + await self._send_to_bdi_belief(asl_goal) + self.logger.info( + "Forwarded button press (override) with context '%s' to BDI Core.", + event_context, + ) + # Send achieve_goal to program manager to update semantic belief extractor + goal_achieve_msg = InternalMessage( + to=settings.agent_settings.bdi_program_manager_name, + thread="achieve_goal", + body=ui_id, + ) + + await self.send(goal_achieve_msg) + else: + self.logger.warning("Could not determine which element to override.") + case "override_unachieve": + ui_id = str(event_context) + if asl_cond_norm := self._cond_norm_map.get(ui_id): + await self._send_to_bdi_belief(asl_cond_norm, True) + self.logger.info( + "Forwarded button press (override_unachieve)" + "with context '%s' to BDI Core.", + event_context, + ) + else: + self.logger.warning( + "Could not determine which conditional norm to unachieve." + ) + + case "pause": + self.logger.debug( + "Received pause/resume button press with context '%s'.", event_context + ) + await self._send_pause_command(event_context) + if event_context: + self.logger.info("Sent pause command.") + else: + self.logger.info("Sent resume command.") + + case "next_phase" | "reset_phase" | "reset_experiment": + await self._send_experiment_control_to_bdi_core(event_type) + case _: + self.logger.warning( + "Received button press with unknown type '%s' (context: '%s').", + event_type, event_context, ) - goal_achieve_msg = InternalMessage( - to=settings.agent_settings.bdi_program_manager_name, - thread="achieve_goal", - body=ui_id, - ) - - await self.send(goal_achieve_msg) - else: - self.logger.warning("Could not determine which element to override.") - - elif event_type == "pause": - self.logger.debug( - "Received pause/resume button press with context '%s'.", event_context - ) - await self._send_pause_command(event_context) - if event_context: - self.logger.info("Sent pause command.") - else: - self.logger.info("Sent resume command.") - - elif event_type in ["next_phase", "reset_phase", "reset_experiment"]: - await self._send_experiment_control_to_bdi_core(event_type) - else: - self.logger.warning( - "Received button press with unknown type '%s' (context: '%s').", - event_type, - event_context, - ) - async def handle_message(self, msg: InternalMessage): """ Handle commands received from other internal Python agents. @@ -195,9 +211,10 @@ class UserInterruptAgent(BaseAgent): await self._send_experiment_update(payload) self.logger.info(f"UI Update: Goal {goal_name} started (ID: {ui_id})") case "active_norms_update": - norm_list = [s.strip("() '\",") for s in msg.body.split(",") if s.strip("() '\",")] - - await self._broadcast_cond_norms(norm_list) + active_norms_asl = [ + s.strip("() '\",") for s in msg.body.split(",") if s.strip("() '\",") + ] + await self._broadcast_cond_norms(active_norms_asl) case _: self.logger.debug(f"Received internal message on unhandled thread: {msg.thread}") @@ -308,12 +325,15 @@ class UserInterruptAgent(BaseAgent): await self.send(msg) self.logger.info(f"Directly forced {thread} in BDI: {body}") - async def _send_to_bdi_belief(self, asl_goal: str): + async def _send_to_bdi_belief(self, asl_goal: str, unachieve: bool = False): """Send belief to BDI Core""" belief_name = f"achieved_{asl_goal}" belief = Belief(name=belief_name, arguments=None) self.logger.debug(f"Sending belief to BDI Core: {belief_name}") - belief_message = BeliefMessage(create=[belief]) + # Conditional norms are unachieved by removing the belief + belief_message = ( + BeliefMessage(delete=[belief]) if unachieve else BeliefMessage(create=[belief]) + ) msg = InternalMessage( to=settings.agent_settings.bdi_core_name, thread="beliefs", From b1c18abffd2d15cfa3473a56ebb2198c690c79d7 Mon Sep 17 00:00:00 2001 From: Kasper Marinus Date: Fri, 16 Jan 2026 13:11:41 +0100 Subject: [PATCH 2/7] test: bunch of tests Written with AI, still need to check them ref: N25B-449 --- src/control_backend/agents/bdi/__init__.py | 3 - .../agents/bdi/agentspeak_ast.py | 10 +- .../agents/bdi/belief_collector_agent.py | 152 ----------- .../communication/ri_communication_agent.py | 4 +- src/control_backend/main.py | 2 + .../actuation/test_robot_gesture_agent.py | 71 ++++- .../actuation/test_robot_speech_agent.py | 12 +- test/unit/agents/bdi/test_agentspeak_ast.py | 186 +++++++++++++ .../agents/bdi/test_agentspeak_generator.py | 187 +++++++++++++ test/unit/agents/bdi/test_bdi_core_agent.py | 258 +++++++++++++++++- .../agents/bdi/test_bdi_program_manager.py | 213 +++++++++++++-- test/unit/agents/bdi/test_belief_collector.py | 135 --------- .../agents/bdi/test_text_belief_extractor.py | 156 ++++++++++- .../test_ri_communication_agent.py | 70 ++++- test/unit/agents/llm/test_llm_agent.py | 69 +++-- .../test_transcription_agent.py | 21 +- .../perception/vad_agent/test_vad_agent.py | 153 +++++++++++ .../vad_agent/test_vad_streaming.py | 49 ++++ test/unit/agents/test_base.py | 24 ++ .../user_interrupt/test_user_interrupt.py | 209 ++++++++++++-- .../api/v1/endpoints/test_user_interact.py | 96 +++++++ test/unit/test_main_sockets.py | 40 +++ 22 files changed, 1747 insertions(+), 373 deletions(-) delete mode 100644 src/control_backend/agents/bdi/belief_collector_agent.py create mode 100644 test/unit/agents/bdi/test_agentspeak_ast.py create mode 100644 test/unit/agents/bdi/test_agentspeak_generator.py delete mode 100644 test/unit/agents/bdi/test_belief_collector.py create mode 100644 test/unit/agents/perception/vad_agent/test_vad_agent.py create mode 100644 test/unit/agents/test_base.py create mode 100644 test/unit/api/v1/endpoints/test_user_interact.py create mode 100644 test/unit/test_main_sockets.py diff --git a/src/control_backend/agents/bdi/__init__.py b/src/control_backend/agents/bdi/__init__.py index 8d45440..d6f5124 100644 --- a/src/control_backend/agents/bdi/__init__.py +++ b/src/control_backend/agents/bdi/__init__.py @@ -1,8 +1,5 @@ from control_backend.agents.bdi.bdi_core_agent import BDICoreAgent as BDICoreAgent -from .belief_collector_agent import ( - BDIBeliefCollectorAgent as BDIBeliefCollectorAgent, -) from .text_belief_extractor_agent import ( TextBeliefExtractorAgent as TextBeliefExtractorAgent, ) diff --git a/src/control_backend/agents/bdi/agentspeak_ast.py b/src/control_backend/agents/bdi/agentspeak_ast.py index 188b4f3..68be531 100644 --- a/src/control_backend/agents/bdi/agentspeak_ast.py +++ b/src/control_backend/agents/bdi/agentspeak_ast.py @@ -77,7 +77,7 @@ class AstTerm(AstExpression, ABC): return AstBinaryOp(self, BinaryOperatorType.NOT_EQUALS, _coalesce_expr(other)) -@dataclass +@dataclass(eq=False) class AstAtom(AstTerm): """ Grounded expression in all lowercase. @@ -89,7 +89,7 @@ class AstAtom(AstTerm): return self.value.lower() -@dataclass +@dataclass(eq=False) class AstVar(AstTerm): """ Ungrounded variable expression. First letter capitalized. @@ -101,7 +101,7 @@ class AstVar(AstTerm): return self.name.capitalize() -@dataclass +@dataclass(eq=False) class AstNumber(AstTerm): value: int | float @@ -109,7 +109,7 @@ class AstNumber(AstTerm): return str(self.value) -@dataclass +@dataclass(eq=False) class AstString(AstTerm): value: str @@ -117,7 +117,7 @@ class AstString(AstTerm): return f'"{self.value}"' -@dataclass +@dataclass(eq=False) class AstLiteral(AstTerm): functor: str terms: list[AstTerm] = field(default_factory=list) diff --git a/src/control_backend/agents/bdi/belief_collector_agent.py b/src/control_backend/agents/bdi/belief_collector_agent.py deleted file mode 100644 index ac0e2e5..0000000 --- a/src/control_backend/agents/bdi/belief_collector_agent.py +++ /dev/null @@ -1,152 +0,0 @@ -import json - -from pydantic import ValidationError - -from control_backend.agents.base import BaseAgent -from control_backend.core.agent_system import InternalMessage -from control_backend.core.config import settings -from control_backend.schemas.belief_message import Belief, BeliefMessage - - -class BDIBeliefCollectorAgent(BaseAgent): - """ - BDI Belief Collector Agent. - - This agent acts as a central aggregator for beliefs derived from various sources (e.g., text, - emotion, vision). It receives raw extracted data from other agents, - normalizes them into valid :class:`Belief` objects, and forwards them as a unified packet to the - BDI Core Agent. - - It serves as a funnel to ensure the BDI agent receives a consistent stream of beliefs. - """ - - async def setup(self): - """ - Initialize the agent. - """ - self.logger.info("Setting up %s", self.name) - - async def handle_message(self, msg: InternalMessage): - """ - Handle incoming messages from other extractor agents. - - Routes the message to specific handlers based on the 'type' field in the JSON body. - Supported types: - - ``belief_extraction_text``: Handled by :meth:`_handle_belief_text` - - ``emotion_extraction_text``: Handled by :meth:`_handle_emo_text` - - :param msg: The received internal message. - """ - sender_node = msg.sender - - # Parse JSON payload - try: - payload = json.loads(msg.body) - except Exception as e: - self.logger.warning( - "BeliefCollector: failed to parse JSON from %s. Body=%r Error=%s", - sender_node, - msg.body, - e, - ) - return - - msg_type = payload.get("type") - - # Prefer explicit 'type' field - if msg_type == "belief_extraction_text": - self.logger.debug("Message routed to _handle_belief_text (sender=%s)", sender_node) - await self._handle_belief_text(payload, sender_node) - # This is not implemented yet, but we keep the structure for future use - elif msg_type == "emotion_extraction_text": - self.logger.debug("Message routed to _handle_emo_text (sender=%s)", sender_node) - await self._handle_emo_text(payload, sender_node) - else: - self.logger.warning( - "Unrecognized message (sender=%s, type=%r). Ignoring.", sender_node, msg_type - ) - - async def _handle_belief_text(self, payload: dict, origin: str): - """ - Process text-based belief extraction payloads. - - Expected payload format:: - - { - "type": "belief_extraction_text", - "beliefs": { - "user_said": ["Can you help me?"], - "intention": ["ask_help"] - } - } - - Validates and converts the dictionary items into :class:`Belief` objects. - - :param payload: The dictionary payload containing belief data. - :param origin: The name of the sender agent. - """ - beliefs = payload.get("beliefs", {}) - - if not beliefs: - self.logger.debug("Received empty beliefs set.") - return - - def try_create_belief(name, arguments) -> Belief | None: - """ - Create a belief object from name and arguments, or return None silently if the input is - not correct. - - :param name: The name of the belief. - :param arguments: The arguments of the belief. - :return: A Belief object if the input is valid or None. - """ - try: - return Belief(name=name, arguments=arguments) - except ValidationError: - return None - - beliefs = [ - belief - for name, arguments in beliefs.items() - if (belief := try_create_belief(name, arguments)) is not None - ] - - self.logger.debug("Forwarding %d beliefs.", len(beliefs)) - for belief in beliefs: - for argument in belief.arguments: - self.logger.debug(" - %s %s", belief.name, argument) - - await self._send_beliefs_to_bdi(beliefs, origin=origin) - - async def _handle_emo_text(self, payload: dict, origin: str): - """ - Process emotion extraction payloads. - - **TODO**: Implement this method once emotion recognition is integrated. - - :param payload: The dictionary payload containing emotion data. - :param origin: The name of the sender agent. - """ - pass - - async def _send_beliefs_to_bdi(self, beliefs: list[Belief], origin: str | None = None): - """ - Send a list of aggregated beliefs to the BDI Core Agent. - - Wraps the beliefs in a :class:`BeliefMessage` and sends it via the 'beliefs' thread. - - :param beliefs: The list of Belief objects to send. - :param origin: (Optional) The original source of the beliefs (unused currently). - """ - if not beliefs: - return - - msg = InternalMessage( - to=settings.agent_settings.bdi_core_name, - sender=self.name, - body=BeliefMessage(create=beliefs).model_dump_json(), - thread="beliefs", - ) - - await self.send(msg) - self.logger.info("Sent %d belief(s) to BDI core.", len(beliefs)) diff --git a/src/control_backend/agents/communication/ri_communication_agent.py b/src/control_backend/agents/communication/ri_communication_agent.py index 719053c..252502d 100644 --- a/src/control_backend/agents/communication/ri_communication_agent.py +++ b/src/control_backend/agents/communication/ri_communication_agent.py @@ -324,7 +324,7 @@ class RICommunicationAgent(BaseAgent): async def handle_message(self, msg: InternalMessage): try: pause_command = PauseCommand.model_validate_json(msg.body) - self._req_socket.send_json(pause_command.model_dump()) - self.logger.debug(self._req_socket.recv_json()) + await self._req_socket.send_json(pause_command.model_dump()) + self.logger.debug(await self._req_socket.recv_json()) except ValidationError: self.logger.warning("Incorrect message format for PauseCommand.") diff --git a/src/control_backend/main.py b/src/control_backend/main.py index d20cc66..ec93b1e 100644 --- a/src/control_backend/main.py +++ b/src/control_backend/main.py @@ -172,6 +172,8 @@ async def lifespan(app: FastAPI): await endpoints_pub_socket.send_multipart([PROGRAM_STATUS, ProgramStatus.STOPPING.value]) # Additional shutdown logic goes here + for agent in agents: + await agent.stop() logger.info("Application shutdown complete.") diff --git a/test/unit/agents/actuation/test_robot_gesture_agent.py b/test/unit/agents/actuation/test_robot_gesture_agent.py index fe051a6..225278d 100644 --- a/test/unit/agents/actuation/test_robot_gesture_agent.py +++ b/test/unit/agents/actuation/test_robot_gesture_agent.py @@ -28,7 +28,11 @@ async def test_setup_bind(zmq_context, mocker): settings = mocker.patch("control_backend.agents.actuation.robot_gesture_agent.settings") settings.zmq_settings.internal_sub_address = "tcp://internal:1234" - agent.add_behavior = MagicMock() + def close_coro(coro): + coro.close() + return MagicMock() + + agent.add_behavior = MagicMock(side_effect=close_coro) await agent.setup() @@ -55,7 +59,11 @@ async def test_setup_connect(zmq_context, mocker): settings = mocker.patch("control_backend.agents.actuation.robot_gesture_agent.settings") settings.zmq_settings.internal_sub_address = "tcp://internal:1234" - agent.add_behavior = MagicMock() + def close_coro(coro): + coro.close() + return MagicMock() + + agent.add_behavior = MagicMock(side_effect=close_coro) await agent.setup() @@ -119,6 +127,65 @@ async def test_handle_message_rejects_invalid_gesture_tag(): pubsocket.send_json.assert_not_awaited() +@pytest.mark.asyncio +async def test_handle_message_sends_valid_single_gesture_command(): + """Internal message with valid single gesture is forwarded.""" + pubsocket = AsyncMock() + agent = RobotGestureAgent("robot_gesture", single_gesture_data=["wave", "point"], address="") + agent.pubsocket = pubsocket + + payload = { + "endpoint": RIEndpoint.GESTURE_SINGLE, + "data": "wave", + } + msg = InternalMessage(to="robot", sender="tester", body=json.dumps(payload)) + + await agent.handle_message(msg) + + pubsocket.send_json.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_handle_message_rejects_invalid_single_gesture(): + """Internal message with invalid single gesture is not forwarded.""" + pubsocket = AsyncMock() + agent = RobotGestureAgent("robot_gesture", single_gesture_data=["wave", "point"], address="") + agent.pubsocket = pubsocket + + payload = { + "endpoint": RIEndpoint.GESTURE_SINGLE, + "data": "dance", + } + msg = InternalMessage(to="robot", sender="tester", body=json.dumps(payload)) + + await agent.handle_message(msg) + + pubsocket.send_json.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_zmq_command_loop_valid_single_gesture_payload(): + """UI command with valid single gesture is read from SUB and published.""" + command = {"endpoint": RIEndpoint.GESTURE_SINGLE, "data": "wave"} + fake_socket = AsyncMock() + + async def recv_once(): + agent._running = False + return b"command", json.dumps(command).encode("utf-8") + + fake_socket.recv_multipart = recv_once + fake_socket.send_json = AsyncMock() + + agent = RobotGestureAgent("robot_gesture", single_gesture_data=["wave", "point"], address="") + agent.subsocket = fake_socket + agent.pubsocket = fake_socket + agent._running = True + + await agent._zmq_command_loop() + + fake_socket.send_json.assert_awaited_once() + + @pytest.mark.asyncio async def test_handle_message_invalid_payload(): """Invalid payload is caught and does not send.""" diff --git a/test/unit/agents/actuation/test_robot_speech_agent.py b/test/unit/agents/actuation/test_robot_speech_agent.py index d95f66a..e5a664d 100644 --- a/test/unit/agents/actuation/test_robot_speech_agent.py +++ b/test/unit/agents/actuation/test_robot_speech_agent.py @@ -30,7 +30,11 @@ async def test_setup_bind(zmq_context, mocker): settings = mocker.patch("control_backend.agents.actuation.robot_speech_agent.settings") settings.zmq_settings.internal_sub_address = "tcp://internal:1234" - agent.add_behavior = MagicMock() + def close_coro(coro): + coro.close() + return MagicMock() + + agent.add_behavior = MagicMock(side_effect=close_coro) await agent.setup() @@ -48,7 +52,11 @@ async def test_setup_connect(zmq_context, mocker): settings = mocker.patch("control_backend.agents.actuation.robot_speech_agent.settings") settings.zmq_settings.internal_sub_address = "tcp://internal:1234" - agent.add_behavior = MagicMock() + def close_coro(coro): + coro.close() + return MagicMock() + + agent.add_behavior = MagicMock(side_effect=close_coro) await agent.setup() diff --git a/test/unit/agents/bdi/test_agentspeak_ast.py b/test/unit/agents/bdi/test_agentspeak_ast.py new file mode 100644 index 0000000..8d3bdf0 --- /dev/null +++ b/test/unit/agents/bdi/test_agentspeak_ast.py @@ -0,0 +1,186 @@ +import pytest + +from control_backend.agents.bdi.agentspeak_ast import ( + AstAtom, + AstBinaryOp, + AstLiteral, + AstLogicalExpression, + AstNumber, + AstPlan, + AstProgram, + AstRule, + AstStatement, + AstString, + AstVar, + BinaryOperatorType, + StatementType, + TriggerType, + _coalesce_expr, +) + + +def test_ast_atom(): + atom = AstAtom("test") + assert str(atom) == "test" + assert atom._to_agentspeak() == "test" + + +def test_ast_var(): + var = AstVar("Variable") + assert str(var) == "Variable" + assert var._to_agentspeak() == "Variable" + + +def test_ast_number(): + num = AstNumber(42) + assert str(num) == "42" + num_float = AstNumber(3.14) + assert str(num_float) == "3.14" + + +def test_ast_string(): + s = AstString("hello") + assert str(s) == '"hello"' + + +def test_ast_literal(): + lit = AstLiteral("functor", [AstAtom("atom"), AstNumber(1)]) + assert str(lit) == "functor(atom, 1)" + lit_empty = AstLiteral("functor") + assert str(lit_empty) == "functor" + + +def test_ast_binary_op(): + left = AstNumber(1) + right = AstNumber(2) + op = AstBinaryOp(left, BinaryOperatorType.GREATER_THAN, right) + assert str(op) == "1 > 2" + + # Test logical wrapper + assert isinstance(op.left, AstLogicalExpression) + assert isinstance(op.right, AstLogicalExpression) + + +def test_ast_binary_op_parens(): + # 1 > 2 + inner = AstBinaryOp(AstNumber(1), BinaryOperatorType.GREATER_THAN, AstNumber(2)) + # (1 > 2) & 3 + outer = AstBinaryOp(inner, BinaryOperatorType.AND, AstNumber(3)) + assert str(outer) == "(1 > 2) & 3" + + # 3 & (1 > 2) + outer_right = AstBinaryOp(AstNumber(3), BinaryOperatorType.AND, inner) + assert str(outer_right) == "3 & (1 > 2)" + + +def test_ast_binary_op_parens_negated(): + inner = AstLogicalExpression(AstAtom("foo"), negated=True) + outer = AstBinaryOp(inner, BinaryOperatorType.AND, AstAtom("bar")) + # The current implementation checks `if self.left.negated: l_str = f"({l_str})"` + # str(inner) is "not foo" + # so we expect "(not foo) & bar" + assert str(outer) == "(not foo) & bar" + + outer_right = AstBinaryOp(AstAtom("bar"), BinaryOperatorType.AND, inner) + assert str(outer_right) == "bar & (not foo)" + + +def test_ast_logical_expression_negation(): + expr = AstLogicalExpression(AstAtom("true"), negated=True) + assert str(expr) == "not true" + + expr_neg_neg = ~expr + assert str(expr_neg_neg) == "true" + assert not expr_neg_neg.negated + + # Invert a non-logical expression (wraps it) + term = AstAtom("true") + inverted = ~term + assert isinstance(inverted, AstLogicalExpression) + assert inverted.negated + assert str(inverted) == "not true" + + +def test_ast_logical_expression_no_negation(): + # _as_logical on already logical expression + expr = AstLogicalExpression(AstAtom("x")) + # Doing binary op will call _as_logical + op = AstBinaryOp(expr, BinaryOperatorType.AND, AstAtom("y")) + assert isinstance(op.left, AstLogicalExpression) + assert op.left is expr # Should reuse instance + + +def test_ast_operators(): + t1 = AstAtom("a") + t2 = AstAtom("b") + + assert str(t1 & t2) == "a & b" + assert str(t1 | t2) == "a | b" + assert str(t1 >= t2) == "a >= b" + assert str(t1 > t2) == "a > b" + assert str(t1 <= t2) == "a <= b" + assert str(t1 < t2) == "a < b" + assert str(t1 == t2) == "a == b" + assert str(t1 != t2) == r"a \== b" + + +def test_coalesce_expr(): + t = AstAtom("a") + assert str(t & "b") == 'a & "b"' + assert str(t & 1) == "a & 1" + assert str(t & 1.5) == "a & 1.5" + + with pytest.raises(TypeError): + _coalesce_expr(None) + + +def test_ast_statement(): + stmt = AstStatement(StatementType.DO_ACTION, AstLiteral("action")) + assert str(stmt) == ".action" + + +def test_ast_rule(): + # Rule with condition + rule = AstRule(AstLiteral("head"), AstLiteral("body")) + assert str(rule) == "head :- body." + + # Rule without condition + rule_simple = AstRule(AstLiteral("fact")) + assert str(rule_simple) == "fact." + + +def test_ast_plan(): + plan = AstPlan( + TriggerType.ADDED_GOAL, + AstLiteral("goal"), + [AstLiteral("context")], + [AstStatement(StatementType.DO_ACTION, AstLiteral("action"))], + ) + output = str(plan) + # verify parts exist + assert "+!goal" in output + assert ": context" in output + assert "<- .action." in output + + +def test_ast_plan_no_context(): + plan = AstPlan( + TriggerType.ADDED_GOAL, + AstLiteral("goal"), + [], + [AstStatement(StatementType.DO_ACTION, AstLiteral("action"))], + ) + output = str(plan) + assert "+!goal" in output + assert ": " not in output + assert "<- .action." in output + + +def test_ast_program(): + prog = AstProgram( + rules=[AstRule(AstLiteral("fact"))], + plans=[AstPlan(TriggerType.ADDED_BELIEF, AstLiteral("b"), [], [])], + ) + output = str(prog) + assert "fact." in output + assert "+b" in output diff --git a/test/unit/agents/bdi/test_agentspeak_generator.py b/test/unit/agents/bdi/test_agentspeak_generator.py new file mode 100644 index 0000000..5a3a849 --- /dev/null +++ b/test/unit/agents/bdi/test_agentspeak_generator.py @@ -0,0 +1,187 @@ +import uuid + +import pytest + +from control_backend.agents.bdi.agentspeak_ast import AstProgram +from control_backend.agents.bdi.agentspeak_generator import AgentSpeakGenerator +from control_backend.schemas.program import ( + BasicNorm, + ConditionalNorm, + Gesture, + GestureAction, + Goal, + InferredBelief, + KeywordBelief, + LLMAction, + LogicalOperator, + Phase, + Plan, + Program, + SemanticBelief, + SpeechAction, + Trigger, +) + + +@pytest.fixture +def generator(): + return AgentSpeakGenerator() + + +def test_generate_empty_program(generator): + prog = Program(phases=[]) + code = generator.generate(prog) + assert 'phase("end").' in code + assert "!notify_cycle" in code + + +def test_generate_basic_norm(generator): + norm = BasicNorm(id=uuid.uuid4(), name="n1", norm="be nice") + phase = Phase(id=uuid.uuid4(), norms=[norm], goals=[], triggers=[]) + prog = Program(phases=[phase]) + + code = generator.generate(prog) + assert f'norm("be nice") :- phase("{phase.id}").' in code + + +def test_generate_critical_norm(generator): + norm = BasicNorm(id=uuid.uuid4(), name="n1", norm="safety", critical=True) + phase = Phase(id=uuid.uuid4(), norms=[norm], goals=[], triggers=[]) + prog = Program(phases=[phase]) + + code = generator.generate(prog) + assert f'critical_norm("safety") :- phase("{phase.id}").' in code + + +def test_generate_conditional_norm(generator): + cond = KeywordBelief(id=uuid.uuid4(), name="k1", keyword="please") + norm = ConditionalNorm(id=uuid.uuid4(), name="n1", norm="help", condition=cond) + phase = Phase(id=uuid.uuid4(), norms=[norm], goals=[], triggers=[]) + prog = Program(phases=[phase]) + + code = generator.generate(prog) + assert 'norm("help")' in code + assert 'keyword_said("please")' in code + assert f"force_norm_{generator._slugify_str(norm.norm)}" in code + + +def test_generate_goal_and_plan(generator): + action = SpeechAction(id=uuid.uuid4(), name="s1", text="hello") + plan = Plan(id=uuid.uuid4(), name="p1", steps=[action]) + # IMPORTANT: can_fail must be False for +achieved_ belief to be added + goal = Goal(id=uuid.uuid4(), name="g1", description="desc", plan=plan, can_fail=False) + phase = Phase(id=uuid.uuid4(), norms=[], goals=[goal], triggers=[]) + prog = Program(phases=[phase]) + + code = generator.generate(prog) + # Check trigger for goal + goal_slug = generator._slugify_str(goal.name) + assert f"+!{goal_slug}" in code + assert f'phase("{phase.id}")' in code + assert '!say("hello")' in code + + # Check success belief addition + assert f"+achieved_{goal_slug}" in code + + +def test_generate_subgoal(generator): + subplan = Plan(id=uuid.uuid4(), name="p2", steps=[]) + subgoal = Goal(id=uuid.uuid4(), name="sub1", description="sub", plan=subplan) + + plan = Plan(id=uuid.uuid4(), name="p1", steps=[subgoal]) + goal = Goal(id=uuid.uuid4(), name="g1", description="main", plan=plan) + phase = Phase(id=uuid.uuid4(), norms=[], goals=[goal], triggers=[]) + prog = Program(phases=[phase]) + + code = generator.generate(prog) + subgoal_slug = generator._slugify_str(subgoal.name) + # Main goal calls subgoal + assert f"!{subgoal_slug}" in code + # Subgoal plan exists + assert f"+!{subgoal_slug}" in code + + +def test_generate_trigger(generator): + cond = SemanticBelief(id=uuid.uuid4(), name="s1", description="desc") + plan = Plan(id=uuid.uuid4(), name="p1", steps=[]) + trigger = Trigger(id=uuid.uuid4(), name="t1", condition=cond, plan=plan) + phase = Phase(id=uuid.uuid4(), norms=[], goals=[], triggers=[trigger]) + prog = Program(phases=[phase]) + + code = generator.generate(prog) + # Trigger logic is added to check_triggers + assert f"{generator.slugify(cond)}" in code + assert f'notify_trigger_start("{generator.slugify(trigger)}")' in code + assert f'notify_trigger_end("{generator.slugify(trigger)}")' in code + + +def test_phase_transition(generator): + phase1 = Phase(id=uuid.uuid4(), name="p1", norms=[], goals=[], triggers=[]) + phase2 = Phase(id=uuid.uuid4(), name="p2", norms=[], goals=[], triggers=[]) + prog = Program(phases=[phase1, phase2]) + + code = generator.generate(prog) + assert "transition_phase" in code + assert f'phase("{phase1.id}")' in code + assert f'phase("{phase2.id}")' in code + assert "force_transition_phase" in code + + +def test_astify_gesture(generator): + gesture = Gesture(type="single", name="wave") + action = GestureAction(id=uuid.uuid4(), name="g1", gesture=gesture) + ast = generator._astify(action) + assert str(ast) == 'gesture("single", "wave")' + + +def test_astify_llm_action(generator): + action = LLMAction(id=uuid.uuid4(), name="l1", goal="be funny") + ast = generator._astify(action) + assert str(ast) == 'reply_with_goal("be funny")' + + +def test_astify_inferred_belief_and(generator): + left = KeywordBelief(id=uuid.uuid4(), name="k1", keyword="a") + right = KeywordBelief(id=uuid.uuid4(), name="k2", keyword="b") + inf = InferredBelief( + id=uuid.uuid4(), name="i1", operator=LogicalOperator.AND, left=left, right=right + ) + + ast = generator._astify(inf) + assert 'keyword_said("a") & keyword_said("b")' == str(ast) + + +def test_astify_inferred_belief_or(generator): + left = KeywordBelief(id=uuid.uuid4(), name="k1", keyword="a") + right = KeywordBelief(id=uuid.uuid4(), name="k2", keyword="b") + inf = InferredBelief( + id=uuid.uuid4(), name="i1", operator=LogicalOperator.OR, left=left, right=right + ) + + ast = generator._astify(inf) + assert 'keyword_said("a") | keyword_said("b")' == str(ast) + + +def test_astify_semantic_belief(generator): + sb = SemanticBelief(id=uuid.uuid4(), name="s1", description="desc") + ast = generator._astify(sb) + assert str(ast) == f"semantic_{generator._slugify_str(sb.name)}" + + +def test_slugify_not_implemented(generator): + with pytest.raises(NotImplementedError): + generator.slugify("not a program element") + + +def test_astify_not_implemented(generator): + with pytest.raises(NotImplementedError): + generator._astify("not a program element") + + +def test_process_phase_transition_from_none(generator): + # Initialize AstProgram manually as we are bypassing generate() + generator._asp = AstProgram() + # Should safely return doing nothing + generator._add_phase_transition(None, None) + + assert len(generator._asp.plans) == 0 diff --git a/test/unit/agents/bdi/test_bdi_core_agent.py b/test/unit/agents/bdi/test_bdi_core_agent.py index 64f2ca7..152d901 100644 --- a/test/unit/agents/bdi/test_bdi_core_agent.py +++ b/test/unit/agents/bdi/test_bdi_core_agent.py @@ -57,11 +57,22 @@ async def test_handle_belief_collector_message(agent, mock_settings): await agent.handle_message(msg) - # Expect bdi_agent.call to be triggered to add belief - args = agent.bdi_agent.call.call_args.args - assert args[0] == agentspeak.Trigger.addition - assert args[1] == agentspeak.GoalType.belief - assert args[2] == agentspeak.Literal("user_said", (agentspeak.Literal("Hello"),)) + # Check for the specific call we expect among all calls + # bdi_agent.call is called multiple times (for transition_phase, check_triggers) + # We want to confirm the belief addition call exists + found_call = False + for call in agent.bdi_agent.call.call_args_list: + args = call.args + if ( + args[0] == agentspeak.Trigger.addition + and args[1] == agentspeak.GoalType.belief + and args[2].functor == "user_said" + and args[2].args[0].functor == "Hello" + ): + found_call = True + break + + assert found_call, "Expected belief addition call not found in bdi_agent.call history" @pytest.mark.asyncio @@ -77,11 +88,19 @@ async def test_handle_delete_belief_message(agent, mock_settings): ) await agent.handle_message(msg) - # Expect bdi_agent.call to be triggered to remove belief - args = agent.bdi_agent.call.call_args.args - assert args[0] == agentspeak.Trigger.removal - assert args[1] == agentspeak.GoalType.belief - assert args[2] == agentspeak.Literal("user_said", (agentspeak.Literal("Hello"),)) + found_call = False + for call in agent.bdi_agent.call.call_args_list: + args = call.args + if ( + args[0] == agentspeak.Trigger.removal + and args[1] == agentspeak.GoalType.belief + and args[2].functor == "user_said" + and args[2].args[0].functor == "Hello" + ): + found_call = True + break + + assert found_call @pytest.mark.asyncio @@ -171,7 +190,11 @@ def test_remove_belief_success_wakes_loop(agent): agent._remove_belief("remove_me", ["x"]) assert agent.bdi_agent.call.called - trigger, goaltype, literal, *_ = agent.bdi_agent.call.call_args.args + + call_args = agent.bdi_agent.call.call_args.args + trigger = call_args[0] + goaltype = call_args[1] + literal = call_args[2] assert trigger == agentspeak.Trigger.removal assert goaltype == agentspeak.GoalType.belief @@ -288,3 +311,216 @@ async def test_deadline_sleep_branch(agent): duration = time.time() - start_time assert duration >= 0.004 # loop slept until deadline + + +@pytest.mark.asyncio +async def test_handle_new_program(agent): + agent._load_asl = AsyncMock() + agent.add_behavior = MagicMock() + # Mock existing loop task so it can be cancelled + mock_task = MagicMock() + mock_task.cancel = MagicMock() + agent._bdi_loop_task = mock_task + + def close_coro(coro): + coro.close() + return MagicMock() + + agent.add_behavior = MagicMock(side_effect=close_coro) + + msg = InternalMessage(to="bdi_agent", thread="new_program", body="path/to/asl.asl") + + await agent.handle_message(msg) + + mock_task.cancel.assert_called_once() + agent._load_asl.assert_awaited_once_with("path/to/asl.asl") + agent.add_behavior.assert_called() + + +@pytest.mark.asyncio +async def test_handle_user_interrupts(agent, mock_settings): + mock_settings.agent_settings.user_interrupt_name = "user_interrupt_agent" + # force_phase_transition + agent._set_goal = MagicMock() + msg = InternalMessage( + to="bdi_agent", + sender=mock_settings.agent_settings.user_interrupt_name, + thread="force_phase_transition", + body="", + ) + await agent.handle_message(msg) + agent._set_goal.assert_called_with("transition_phase") + + # force_trigger + agent._force_trigger = MagicMock() + msg.thread = "force_trigger" + msg.body = "trigger_x" + await agent.handle_message(msg) + agent._force_trigger.assert_called_with("trigger_x") + + # force_norm + agent._force_norm = MagicMock() + msg.thread = "force_norm" + msg.body = "norm_y" + await agent.handle_message(msg) + agent._force_norm.assert_called_with("norm_y") + + # force_next_phase + agent._force_next_phase = MagicMock() + msg.thread = "force_next_phase" + msg.body = "" + await agent.handle_message(msg) + agent._force_next_phase.assert_called_once() + + # unknown interrupt + agent.logger = MagicMock() + msg.thread = "unknown_thing" + await agent.handle_message(msg) + agent.logger.warning.assert_called() + + +@pytest.mark.asyncio +async def test_custom_action_reply_with_goal(agent): + agent._send_to_llm = MagicMock(side_effect=agent.send) + agent._add_custom_actions() + action_fn = agent.actions.actions[(".reply_with_goal", 3)] + + mock_term = MagicMock(args=["msg", "norms", "goal"]) + gen = action_fn(agent, mock_term, MagicMock()) + next(gen) + agent._send_to_llm.assert_called_with("msg", "norms", "goal") + + +@pytest.mark.asyncio +async def test_custom_action_notify_norms(agent): + agent._add_custom_actions() + action_fn = agent.actions.actions[(".notify_norms", 1)] + + mock_term = MagicMock(args=["norms_list"]) + gen = action_fn(agent, mock_term, MagicMock()) + next(gen) + + agent.send.assert_called() + msg = agent.send.call_args[0][0] + assert msg.thread == "active_norms_update" + assert msg.body == "norms_list" + + +@pytest.mark.asyncio +async def test_custom_action_say(agent): + agent._add_custom_actions() + action_fn = agent.actions.actions[(".say", 1)] + + mock_term = MagicMock(args=["hello"]) + gen = action_fn(agent, mock_term, MagicMock()) + next(gen) + + assert agent.send.call_count == 2 + msgs = [c[0][0] for c in agent.send.call_args_list] + assert any(m.to == settings.agent_settings.robot_speech_name for m in msgs) + assert any( + m.to == settings.agent_settings.llm_name and m.thread == "assistant_message" for m in msgs + ) + + +@pytest.mark.asyncio +async def test_custom_action_gesture(agent): + agent._add_custom_actions() + # Test single + action_fn = agent.actions.actions[(".gesture", 2)] + mock_term = MagicMock(args=["single", "wave"]) + gen = action_fn(agent, mock_term, MagicMock()) + next(gen) + msg = agent.send.call_args[0][0] + assert "actuate/gesture/single" in msg.body + + # Test tag + mock_term.args = ["tag", "happy"] + gen = action_fn(agent, mock_term, MagicMock()) + next(gen) + msg = agent.send.call_args[0][0] + assert "actuate/gesture/tag" in msg.body + + +@pytest.mark.asyncio +async def test_custom_action_notify_user_said(agent): + agent._add_custom_actions() + action_fn = agent.actions.actions[(".notify_user_said", 1)] + mock_term = MagicMock(args=["hello"]) + gen = action_fn(agent, mock_term, MagicMock()) + next(gen) + msg = agent.send.call_args[0][0] + assert msg.to == settings.agent_settings.llm_name + assert msg.thread == "user_message" + + +@pytest.mark.asyncio +async def test_custom_action_notify_trigger_start_end(agent): + agent._add_custom_actions() + # Start + action_fn = agent.actions.actions[(".notify_trigger_start", 1)] + gen = action_fn(agent, MagicMock(args=["t1"]), MagicMock()) + next(gen) + assert agent.send.call_args[0][0].thread == "trigger_start" + + # End + action_fn = agent.actions.actions[(".notify_trigger_end", 1)] + gen = action_fn(agent, MagicMock(args=["t1"]), MagicMock()) + next(gen) + assert agent.send.call_args[0][0].thread == "trigger_end" + + +@pytest.mark.asyncio +async def test_custom_action_notify_goal_start(agent): + agent._add_custom_actions() + action_fn = agent.actions.actions[(".notify_goal_start", 1)] + gen = action_fn(agent, MagicMock(args=["g1"]), MagicMock()) + next(gen) + assert agent.send.call_args[0][0].thread == "goal_start" + + +@pytest.mark.asyncio +async def test_custom_action_notify_transition_phase(agent): + agent._add_custom_actions() + action_fn = agent.actions.actions[(".notify_transition_phase", 2)] + gen = action_fn(agent, MagicMock(args=["old", "new"]), MagicMock()) + next(gen) + msg = agent.send.call_args[0][0] + assert msg.thread == "transition_phase" + assert "old" in msg.body and "new" in msg.body + + +def test_remove_belief_no_args(agent): + agent._wake_bdi_loop = MagicMock() + agent.bdi_agent.call.return_value = True + agent._remove_belief("fact", None) + assert agent.bdi_agent.call.called + + +def test_set_goal_with_args(agent): + agent._wake_bdi_loop = MagicMock() + agent._set_goal("goal", ["arg1", "arg2"]) + assert agent.bdi_agent.call.called + + +def test_format_belief_string(): + assert BDICoreAgent.format_belief_string("b") == "b" + assert BDICoreAgent.format_belief_string("b", ["a1", "a2"]) == "b(a1,a2)" + + +def test_force_norm(agent): + agent._add_belief = MagicMock() + agent._force_norm("be_polite") + agent._add_belief.assert_called_with("force_be_polite") + + +def test_force_trigger(agent): + agent._set_goal = MagicMock() + agent._force_trigger("trig") + agent._set_goal.assert_called_with("trig") + + +def test_force_next_phase(agent): + agent._set_goal = MagicMock() + agent._force_next_phase() + agent._set_goal.assert_called_with("force_transition_phase") diff --git a/test/unit/agents/bdi/test_bdi_program_manager.py b/test/unit/agents/bdi/test_bdi_program_manager.py index 2bed2a7..540a172 100644 --- a/test/unit/agents/bdi/test_bdi_program_manager.py +++ b/test/unit/agents/bdi/test_bdi_program_manager.py @@ -1,13 +1,13 @@ import asyncio +import json import sys import uuid -from unittest.mock import AsyncMock +from unittest.mock import AsyncMock, MagicMock, mock_open, patch import pytest from control_backend.agents.bdi.bdi_program_manager import BDIProgramManager from control_backend.core.agent_system import InternalMessage -from control_backend.schemas.belief_message import BeliefMessage from control_backend.schemas.program import BasicNorm, Goal, Phase, Plan, Program # Fix Windows Proactor loop for zmq @@ -48,24 +48,26 @@ def make_valid_program_json(norm="N1", goal="G1") -> str: ).model_dump_json() -@pytest.mark.skip(reason="Functionality being rebuilt.") @pytest.mark.asyncio -async def test_send_to_bdi(): +async def test_create_agentspeak_and_send_to_bdi(mock_settings): manager = BDIProgramManager(name="program_manager_test") manager.send = AsyncMock() program = Program.model_validate_json(make_valid_program_json()) - await manager._create_agentspeak_and_send_to_bdi(program) + + with patch("builtins.open", mock_open()) as mock_file: + await manager._create_agentspeak_and_send_to_bdi(program) + + # Check file writing + mock_file.assert_called_with("src/control_backend/agents/bdi/agentspeak.asl", "w") + handle = mock_file() + handle.write.assert_called() assert manager.send.await_count == 1 msg: InternalMessage = manager.send.await_args[0][0] - assert msg.thread == "beliefs" - - beliefs = BeliefMessage.model_validate_json(msg.body) - names = {b.name: b.arguments for b in beliefs.beliefs} - - assert "norms" in names and names["norms"] == ["N1"] - assert "goals" in names and names["goals"] == ["G1"] + assert msg.thread == "new_program" + assert msg.to == mock_settings.agent_settings.bdi_core_name + assert msg.body == "src/control_backend/agents/bdi/agentspeak.asl" @pytest.mark.asyncio @@ -81,6 +83,9 @@ async def test_receive_programs_valid_and_invalid(): manager.sub_socket = sub manager._create_agentspeak_and_send_to_bdi = AsyncMock() manager._send_clear_llm_history = AsyncMock() + manager._send_program_to_user_interrupt = AsyncMock() + manager._send_beliefs_to_semantic_belief_extractor = AsyncMock() + manager._send_goals_to_semantic_belief_extractor = AsyncMock() try: # Will give StopAsyncIteration when the predefined `sub.recv_multipart` side-effects run out @@ -94,10 +99,9 @@ async def test_receive_programs_valid_and_invalid(): assert forwarded.phases[0].norms[0].name == "N1" assert forwarded.phases[0].goals[0].name == "G1" - # Verify history clear was triggered - assert ( - manager._send_clear_llm_history.await_count == 2 - ) # first sends program to UserInterrupt, then clears LLM + # Verify history clear was triggered exactly once (for the valid program) + # The invalid program loop `continue`s before calling _send_clear_llm_history + assert manager._send_clear_llm_history.await_count == 1 @pytest.mark.asyncio @@ -115,4 +119,179 @@ async def test_send_clear_llm_history(mock_settings): # Verify the content and recipient assert msg.body == "clear_history" - assert msg.to == "llm_agent" + + +@pytest.mark.asyncio +async def test_handle_message_transition_phase(mock_settings): + mock_settings.agent_settings.user_interrupt_name = "user_interrupt_agent" + manager = BDIProgramManager(name="program_manager_test") + manager.send = AsyncMock() + + # Setup state + prog = Program.model_validate_json(make_valid_program_json(norm="N1", goal="G1")) + manager._initialize_internal_state(prog) + + # Test valid transition (to same phase for simplicity, or we need 2 phases) + # Let's create a program with 2 phases + phase2_id = uuid.uuid4() + phase2 = Phase(id=phase2_id, name="Phase 2", norms=[], goals=[], triggers=[]) + prog.phases.append(phase2) + manager._initialize_internal_state(prog) + + current_phase_id = str(prog.phases[0].id) + next_phase_id = str(phase2_id) + + payload = json.dumps({"old": current_phase_id, "new": next_phase_id}) + msg = InternalMessage(to="me", sender="bdi", body=payload, thread="transition_phase") + + await manager.handle_message(msg) + + assert str(manager._phase.id) == next_phase_id + + # Allow background tasks to run (add_behavior) + await asyncio.sleep(0) + + # Check notifications sent + # 1. beliefs to extractor + # 2. goals to extractor + # 3. notification to user interrupt + + assert manager.send.await_count >= 3 + + # Verify user interrupt notification + calls = manager.send.await_args_list + ui_msgs = [ + c[0][0] for c in calls if c[0][0].to == mock_settings.agent_settings.user_interrupt_name + ] + assert len(ui_msgs) > 0 + assert ui_msgs[-1].body == next_phase_id + + +@pytest.mark.asyncio +async def test_handle_message_transition_phase_desync(): + manager = BDIProgramManager(name="program_manager_test") + manager.logger = MagicMock() + + prog = Program.model_validate_json(make_valid_program_json()) + manager._initialize_internal_state(prog) + + current_phase_id = str(prog.phases[0].id) + + # Request transition from WRONG old phase + payload = json.dumps({"old": "wrong_id", "new": "some_new_id"}) + msg = InternalMessage(to="me", sender="bdi", body=payload, thread="transition_phase") + + await manager.handle_message(msg) + + # Should warn and do nothing + manager.logger.warning.assert_called_once() + assert "Phase transition desync detected" in manager.logger.warning.call_args[0][0] + assert str(manager._phase.id) == current_phase_id + + +@pytest.mark.asyncio +async def test_handle_message_transition_phase_end(mock_settings): + mock_settings.agent_settings.user_interrupt_name = "user_interrupt_agent" + manager = BDIProgramManager(name="program_manager_test") + manager.send = AsyncMock() + + prog = Program.model_validate_json(make_valid_program_json()) + manager._initialize_internal_state(prog) + current_phase_id = str(prog.phases[0].id) + + payload = json.dumps({"old": current_phase_id, "new": "end"}) + msg = InternalMessage(to="me", sender="bdi", body=payload, thread="transition_phase") + + await manager.handle_message(msg) + + assert manager._phase is None + + # Allow background tasks to run (add_behavior) + await asyncio.sleep(0) + + # Verify notification to user interrupt + assert manager.send.await_count == 1 + msg_sent = manager.send.await_args[0][0] + assert msg_sent.to == mock_settings.agent_settings.user_interrupt_name + assert msg_sent.body == "end" + + +@pytest.mark.asyncio +async def test_handle_message_achieve_goal(mock_settings): + mock_settings.agent_settings.text_belief_extractor_name = "text_belief_extractor_agent" + manager = BDIProgramManager(name="program_manager_test") + manager.send = AsyncMock() + + prog = Program.model_validate_json(make_valid_program_json(goal="TargetGoal")) + manager._initialize_internal_state(prog) + + goal_id = str(prog.phases[0].goals[0].id) + + msg = InternalMessage(to="me", sender="ui", body=goal_id, thread="achieve_goal") + + await manager.handle_message(msg) + + # Should send achieved goals to text extractor + assert manager.send.await_count == 1 + msg_sent = manager.send.await_args[0][0] + assert msg_sent.to == mock_settings.agent_settings.text_belief_extractor_name + assert msg_sent.thread == "achieved_goals" + + # Verify body + from control_backend.schemas.belief_list import GoalList + + gl = GoalList.model_validate_json(msg_sent.body) + assert len(gl.goals) == 1 + assert gl.goals[0].name == "TargetGoal" + + +@pytest.mark.asyncio +async def test_handle_message_achieve_goal_not_found(): + manager = BDIProgramManager(name="program_manager_test") + manager.send = AsyncMock() + manager.logger = MagicMock() + + prog = Program.model_validate_json(make_valid_program_json()) + manager._initialize_internal_state(prog) + + msg = InternalMessage(to="me", sender="ui", body="non_existent_id", thread="achieve_goal") + + await manager.handle_message(msg) + + manager.send.assert_not_called() + manager.logger.debug.assert_called() + + +@pytest.mark.asyncio +async def test_setup(mock_settings): + manager = BDIProgramManager(name="program_manager_test") + manager.send = AsyncMock() + + def close_coro(coro): + coro.close() + return MagicMock() + + manager.add_behavior = MagicMock(side_effect=close_coro) + + mock_context = MagicMock() + mock_sub = MagicMock() + mock_context.socket.return_value = mock_sub + + with patch( + "control_backend.agents.bdi.bdi_program_manager.Context.instance", return_value=mock_context + ): + # We also need to mock file writing in _create_agentspeak_and_send_to_bdi + with patch("builtins.open", new_callable=MagicMock): + await manager.setup() + + # Check logic + # 1. Sends default empty program to BDI + assert manager.send.await_count == 1 + assert manager.send.await_args[0][0].to == mock_settings.agent_settings.bdi_core_name + + # 2. Connects SUB socket + mock_sub.connect.assert_called_with(mock_settings.zmq_settings.internal_sub_address) + mock_sub.subscribe.assert_called_with("program") + + # 3. Adds behavior + manager.add_behavior.assert_called() diff --git a/test/unit/agents/bdi/test_belief_collector.py b/test/unit/agents/bdi/test_belief_collector.py deleted file mode 100644 index 69db269..0000000 --- a/test/unit/agents/bdi/test_belief_collector.py +++ /dev/null @@ -1,135 +0,0 @@ -import json -from unittest.mock import AsyncMock - -import pytest - -from control_backend.agents.bdi import ( - BDIBeliefCollectorAgent, -) -from control_backend.core.agent_system import InternalMessage -from control_backend.core.config import settings -from control_backend.schemas.belief_message import Belief - - -@pytest.fixture -def agent(): - agent = BDIBeliefCollectorAgent("belief_collector_agent") - return agent - - -def make_msg(body: dict, sender: str = "sender"): - return InternalMessage(to="collector", sender=sender, body=json.dumps(body)) - - -@pytest.mark.asyncio -async def test_handle_message_routes_belief_text(agent, mocker): - """ - Test that when a message is received, _handle_belief_text is called with that message. - """ - payload = {"type": "belief_extraction_text", "beliefs": {"user_said": [["hi"]]}} - spy = mocker.patch.object(agent, "_handle_belief_text", new_callable=AsyncMock) - - await agent.handle_message(make_msg(payload)) - - spy.assert_awaited_once_with(payload, "sender") - - -@pytest.mark.asyncio -async def test_handle_message_routes_emotion(agent, mocker): - payload = {"type": "emotion_extraction_text"} - spy = mocker.patch.object(agent, "_handle_emo_text", new_callable=AsyncMock) - - await agent.handle_message(make_msg(payload)) - - spy.assert_awaited_once_with(payload, "sender") - - -@pytest.mark.asyncio -async def test_handle_message_bad_json(agent, mocker): - agent._handle_belief_text = AsyncMock() - bad_msg = InternalMessage(to="collector", sender="sender", body="not json") - - await agent.handle_message(bad_msg) - - agent._handle_belief_text.assert_not_awaited() - - -@pytest.mark.asyncio -async def test_handle_belief_text_sends_when_beliefs_exist(agent, mocker): - payload = {"type": "belief_extraction_text", "beliefs": {"user_said": ["hello"]}} - spy = mocker.patch.object(agent, "_send_beliefs_to_bdi", new_callable=AsyncMock) - expected = [Belief(name="user_said", arguments=["hello"])] - - await agent._handle_belief_text(payload, "origin") - - spy.assert_awaited_once_with(expected, origin="origin") - - -@pytest.mark.asyncio -async def test_handle_belief_text_no_send_when_empty(agent, mocker): - payload = {"type": "belief_extraction_text", "beliefs": {}} - spy = mocker.patch.object(agent, "_send_beliefs_to_bdi", new_callable=AsyncMock) - - await agent._handle_belief_text(payload, "origin") - - spy.assert_not_awaited() - - -@pytest.mark.asyncio -async def test_send_beliefs_to_bdi(agent): - agent.send = AsyncMock() - beliefs = [Belief(name="user_said", arguments=["hello", "world"])] - - await agent._send_beliefs_to_bdi(beliefs, origin="origin") - - agent.send.assert_awaited_once() - sent: InternalMessage = agent.send.call_args.args[0] - assert sent.to == settings.agent_settings.bdi_core_name - assert sent.thread == "beliefs" - assert json.loads(sent.body)["create"] == [belief.model_dump() for belief in beliefs] - - -@pytest.mark.asyncio -async def test_setup_executes(agent): - """Covers setup and asserts the agent has a name.""" - await agent.setup() - assert agent.name == "belief_collector_agent" # simple property assertion - - -@pytest.mark.asyncio -async def test_handle_message_unrecognized_type_executes(agent): - """Covers the else branch for unrecognized message type.""" - payload = {"type": "unknown_type"} - msg = make_msg(payload, sender="tester") - # Wrap send to ensure nothing is sent - agent.send = AsyncMock() - await agent.handle_message(msg) - # Assert no messages were sent - agent.send.assert_not_awaited() - - -@pytest.mark.asyncio -async def test_handle_emo_text_executes(agent): - """Covers the _handle_emo_text method.""" - # The method does nothing, but we can assert it returns None - result = await agent._handle_emo_text({}, "origin") - assert result is None - - -@pytest.mark.asyncio -async def test_send_beliefs_to_bdi_empty_executes(agent): - """Covers early return when beliefs are empty.""" - agent.send = AsyncMock() - await agent._send_beliefs_to_bdi({}) - # Assert that nothing was sent - agent.send.assert_not_awaited() - - -@pytest.mark.asyncio -async def test_handle_belief_text_invalid_returns_none(agent, mocker): - payload = {"type": "belief_extraction_text", "beliefs": {"user_said": "invalid-argument"}} - - result = await agent._handle_belief_text(payload, "origin") - - # The method itself returns None - assert result is None diff --git a/test/unit/agents/bdi/test_text_belief_extractor.py b/test/unit/agents/bdi/test_text_belief_extractor.py index 6782ba1..0d7dc00 100644 --- a/test/unit/agents/bdi/test_text_belief_extractor.py +++ b/test/unit/agents/bdi/test_text_belief_extractor.py @@ -14,6 +14,7 @@ from control_backend.schemas.belief_message import Belief as InternalBelief from control_backend.schemas.belief_message import BeliefMessage from control_backend.schemas.chat_history import ChatHistory, ChatMessage from control_backend.schemas.program import ( + BaseGoal, # Changed from Goal ConditionalNorm, KeywordBelief, LLMAction, @@ -28,7 +29,8 @@ from control_backend.schemas.program import ( @pytest.fixture def llm(): llm = TextBeliefExtractorAgent.LLM(MagicMock(), 4) - llm._query_llm = AsyncMock() + # We must ensure _query_llm returns a dictionary so iterating it doesn't fail + llm._query_llm = AsyncMock(return_value={}) return llm @@ -374,3 +376,155 @@ async def test_llm_failure_handling(agent, llm, sample_program): assert len(belief_changes.true) == 0 assert len(belief_changes.false) == 0 + + +def test_belief_state_bool(): + # Empty + bs = BeliefState() + assert not bs + + # True set + bs_true = BeliefState(true={InternalBelief(name="a", arguments=None)}) + assert bs_true + + # False set + bs_false = BeliefState(false={InternalBelief(name="a", arguments=None)}) + assert bs_false + + +@pytest.mark.asyncio +async def test_handle_beliefs_message_validation_error(agent, mock_settings): + # Invalid JSON + mock_settings.agent_settings.bdi_program_manager_name = "bdi_program_manager_agent" + msg = InternalMessage( + to="me", + sender=mock_settings.agent_settings.bdi_program_manager_name, + thread="beliefs", + body="invalid json", + ) + # Should log warning and return + agent.logger = MagicMock() + await agent.handle_message(msg) + agent.logger.warning.assert_called() + + # Invalid Model + msg.body = json.dumps({"beliefs": [{"invalid": "obj"}]}) + await agent.handle_message(msg) + agent.logger.warning.assert_called() + + +@pytest.mark.asyncio +async def test_handle_goals_message_validation_error(agent, mock_settings): + mock_settings.agent_settings.bdi_program_manager_name = "bdi_program_manager_agent" + msg = InternalMessage( + to="me", + sender=mock_settings.agent_settings.bdi_program_manager_name, + thread="goals", + body="invalid json", + ) + agent.logger = MagicMock() + await agent.handle_message(msg) + agent.logger.warning.assert_called() + + +@pytest.mark.asyncio +async def test_handle_goal_achieved_message_validation_error(agent, mock_settings): + mock_settings.agent_settings.bdi_program_manager_name = "bdi_program_manager_agent" + msg = InternalMessage( + to="me", + sender=mock_settings.agent_settings.bdi_program_manager_name, + thread="achieved_goals", + body="invalid json", + ) + agent.logger = MagicMock() + await agent.handle_message(msg) + agent.logger.warning.assert_called() + + +@pytest.mark.asyncio +async def test_goal_inferrer_infer_from_conversation(agent, llm): + # Setup goals + # Use BaseGoal object as typically received by the extractor + g1 = BaseGoal(id=uuid.uuid4(), name="g1", description="desc", can_fail=True) + + # Use real GoalAchievementInferrer + from control_backend.agents.bdi.text_belief_extractor_agent import GoalAchievementInferrer + + inferrer = GoalAchievementInferrer(llm) + inferrer.goals = {g1} + + # Mock LLM response + llm._query_llm.return_value = True + + completions = await inferrer.infer_from_conversation(ChatHistory(messages=[])) + assert completions + # slugify uses slugify library, hard to predict exact string without it, + # but we can check values + assert list(completions.values())[0] is True + + +def test_apply_conversation_message_limit(agent): + with patch("control_backend.agents.bdi.text_belief_extractor_agent.settings") as mock_s: + mock_s.behaviour_settings.conversation_history_length_limit = 2 + agent.conversation.messages = [] + + agent._apply_conversation_message(ChatMessage(role="user", content="1")) + agent._apply_conversation_message(ChatMessage(role="assistant", content="2")) + agent._apply_conversation_message(ChatMessage(role="user", content="3")) + + assert len(agent.conversation.messages) == 2 + assert agent.conversation.messages[0].content == "2" + assert agent.conversation.messages[1].content == "3" + + +@pytest.mark.asyncio +async def test_handle_program_manager_reset(agent): + with patch("control_backend.agents.bdi.text_belief_extractor_agent.settings") as mock_s: + mock_s.agent_settings.bdi_program_manager_name = "pm" + agent.conversation.messages = [ChatMessage(role="user", content="hi")] + agent.belief_inferrer.available_beliefs = [ + SemanticBelief(id=uuid.uuid4(), name="b", description="d") + ] + + msg = InternalMessage(to="me", sender="pm", thread="conversation_history", body="reset") + await agent.handle_message(msg) + + assert len(agent.conversation.messages) == 0 + assert len(agent.belief_inferrer.available_beliefs) == 0 + + +def test_split_into_chunks(): + from control_backend.agents.bdi.text_belief_extractor_agent import SemanticBeliefInferrer + + items = [1, 2, 3, 4, 5] + chunks = SemanticBeliefInferrer._split_into_chunks(items, 2) + assert len(chunks) == 2 + assert len(chunks[0]) + len(chunks[1]) == 5 + + +@pytest.mark.asyncio +async def test_infer_beliefs_call(agent, llm): + from control_backend.agents.bdi.text_belief_extractor_agent import SemanticBeliefInferrer + + inferrer = SemanticBeliefInferrer(llm) + sb = SemanticBelief(id=uuid.uuid4(), name="is_happy", description="User is happy") + + llm.query = AsyncMock(return_value={"is_happy": True}) + + res = await inferrer._infer_beliefs(ChatHistory(messages=[]), [sb]) + assert res == {"is_happy": True} + llm.query.assert_called_once() + + +@pytest.mark.asyncio +async def test_infer_goal_call(agent, llm): + from control_backend.agents.bdi.text_belief_extractor_agent import GoalAchievementInferrer + + inferrer = GoalAchievementInferrer(llm) + goal = BaseGoal(id=uuid.uuid4(), name="g1", description="d") + + llm.query = AsyncMock(return_value=True) + + res = await inferrer._infer_goal(ChatHistory(messages=[]), goal) + assert res is True + llm.query.assert_called_once() diff --git a/test/unit/agents/communication/test_ri_communication_agent.py b/test/unit/agents/communication/test_ri_communication_agent.py index 06d8766..a678907 100644 --- a/test/unit/agents/communication/test_ri_communication_agent.py +++ b/test/unit/agents/communication/test_ri_communication_agent.py @@ -4,6 +4,8 @@ from unittest.mock import ANY, AsyncMock, MagicMock, patch import pytest from control_backend.agents.communication.ri_communication_agent import RICommunicationAgent +from control_backend.core.agent_system import InternalMessage +from control_backend.schemas.ri_message import PauseCommand, RIEndpoint def speech_agent_path(): @@ -53,7 +55,11 @@ async def test_setup_success_connects_and_starts_robot(zmq_context): MockGesture.return_value.start = AsyncMock() agent = RICommunicationAgent("ri_comm", address="tcp://localhost:5555", bind=False) - agent.add_behavior = MagicMock() + def close_coro(coro): + coro.close() + return MagicMock() + + agent.add_behavior = MagicMock(side_effect=close_coro) await agent.setup() @@ -83,7 +89,11 @@ async def test_setup_binds_when_requested(zmq_context): agent = RICommunicationAgent("ri_comm", address="tcp://localhost:5555", bind=True) - agent.add_behavior = MagicMock() + def close_coro(coro): + coro.close() + return MagicMock() + + agent.add_behavior = MagicMock(side_effect=close_coro) with ( patch(speech_agent_path(), autospec=True) as MockSpeech, @@ -151,6 +161,7 @@ async def test_handle_negotiation_response_updates_req_socket(zmq_context): @pytest.mark.asyncio async def test_handle_disconnection_publishes_and_reconnects(): pub_socket = AsyncMock() + pub_socket.close = MagicMock() agent = RICommunicationAgent("ri_comm") agent.pub_socket = pub_socket agent.connected = True @@ -233,6 +244,25 @@ async def test_handle_negotiation_response_unhandled_id(): ) +@pytest.mark.asyncio +async def test_handle_negotiation_response_audio(zmq_context): + agent = RICommunicationAgent("ri_comm") + + with patch( + "control_backend.agents.communication.ri_communication_agent.VADAgent", autospec=True + ) as MockVAD: + MockVAD.return_value.start = AsyncMock() + + await agent._handle_negotiation_response( + {"data": [{"id": "audio", "port": 7000, "bind": False}]} + ) + + MockVAD.assert_called_once_with( + audio_in_address="tcp://localhost:7000", audio_in_bind=False + ) + MockVAD.return_value.start.assert_awaited_once() + + @pytest.mark.asyncio async def test_stop_closes_sockets(): req = MagicMock() @@ -323,6 +353,7 @@ async def test_listen_loop_generic_exception(): @pytest.mark.asyncio async def test_handle_disconnection_timeout(monkeypatch): pub = AsyncMock() + pub.close = MagicMock() pub.send_multipart = AsyncMock(side_effect=TimeoutError) agent = RICommunicationAgent("ri_comm") @@ -365,3 +396,38 @@ async def test_negotiate_req_socket_none_causes_retry(zmq_context): result = await agent._negotiate_connection(max_retries=1) assert result is False + + +@pytest.mark.asyncio +async def test_handle_message_pause_command(zmq_context): + """Test handle_message with a valid PauseCommand.""" + agent = RICommunicationAgent("ri_comm") + agent._req_socket = AsyncMock() + agent.logger = MagicMock() + + agent._req_socket.recv_json.return_value = {"status": "ok"} + + pause_cmd = PauseCommand(data=True) + msg = InternalMessage(to="ri_comm", sender="user_int", body=pause_cmd.model_dump_json()) + + await agent.handle_message(msg) + + agent._req_socket.send_json.assert_awaited_once() + args = agent._req_socket.send_json.await_args[0][0] + assert args["endpoint"] == RIEndpoint.PAUSE.value + assert args["data"] is True + + +@pytest.mark.asyncio +async def test_handle_message_invalid_pause_command(zmq_context): + """Test handle_message with invalid JSON.""" + agent = RICommunicationAgent("ri_comm") + agent._req_socket = AsyncMock() + agent.logger = MagicMock() + + msg = InternalMessage(to="ri_comm", sender="user_int", body="invalid json") + + await agent.handle_message(msg) + + agent.logger.warning.assert_called_with("Incorrect message format for PauseCommand.") + agent._req_socket.send_json.assert_not_called() diff --git a/test/unit/agents/llm/test_llm_agent.py b/test/unit/agents/llm/test_llm_agent.py index 5fc07f2..a1cc297 100644 --- a/test/unit/agents/llm/test_llm_agent.py +++ b/test/unit/agents/llm/test_llm_agent.py @@ -58,17 +58,20 @@ async def test_llm_processing_success(mock_httpx_client, mock_settings): to="llm_agent", sender=mock_settings.agent_settings.bdi_core_name, body=prompt.model_dump_json(), + thread="prompt_message", # REQUIRED: thread must match handle_message logic ) await agent.handle_message(msg) # Verification # "Hello world." constitutes one sentence/chunk based on punctuation split - # The agent should call send once with the full sentence + # The agent should call send once with the full sentence, PLUS once more for full reply assert agent.send.called - args = agent.send.call_args_list[0][0][0] - assert args.to == mock_settings.agent_settings.bdi_core_name - assert "Hello world." in args.body + + # Check args. We expect at least one call sending "Hello world." + calls = agent.send.call_args_list + bodies = [c[0][0].body for c in calls] + assert any("Hello world." in b for b in bodies) @pytest.mark.asyncio @@ -80,18 +83,23 @@ async def test_llm_processing_errors(mock_httpx_client, mock_settings): to="llm", sender=mock_settings.agent_settings.bdi_core_name, body=prompt.model_dump_json(), + thread="prompt_message", ) - # HTTP Error + # HTTP Error: stream method RAISES exception immediately mock_httpx_client.stream = MagicMock(side_effect=httpx.HTTPError("Fail")) + await agent.handle_message(msg) - assert "LLM service unavailable." in agent.send.call_args[0][0].body + + # Check that error message was sent + assert agent.send.called + assert "LLM service unavailable." in agent.send.call_args_list[0][0][0].body # General Exception agent.send.reset_mock() mock_httpx_client.stream = MagicMock(side_effect=Exception("Boom")) await agent.handle_message(msg) - assert "Error processing the request." in agent.send.call_args[0][0].body + assert "Error processing the request." in agent.send.call_args_list[0][0][0].body @pytest.mark.asyncio @@ -113,16 +121,19 @@ async def test_llm_json_error(mock_httpx_client, mock_settings): agent = LLMAgent("llm_agent") agent.send = AsyncMock() + # Ensure logger is mocked + agent.logger = MagicMock() - with patch.object(agent.logger, "error") as log: - prompt = LLMPromptMessage(text="Hi", norms=[], goals=[]) - msg = InternalMessage( - to="llm", - sender=mock_settings.agent_settings.bdi_core_name, - body=prompt.model_dump_json(), - ) - await agent.handle_message(msg) - log.assert_called() # Should log JSONDecodeError + prompt = LLMPromptMessage(text="Hi", norms=[], goals=[]) + msg = InternalMessage( + to="llm", + sender=mock_settings.agent_settings.bdi_core_name, + body=prompt.model_dump_json(), + thread="prompt_message", + ) + await agent.handle_message(msg) + + agent.logger.error.assert_called() # Should log JSONDecodeError def test_llm_instructions(): @@ -157,6 +168,7 @@ async def test_handle_message_validation_error_branch_no_send(mock_httpx_client, to="llm_agent", sender=mock_settings.agent_settings.bdi_core_name, body=invalid_json, + thread="prompt_message", ) await agent.handle_message(msg) @@ -285,3 +297,28 @@ async def test_clear_history_command(mock_settings): ) await agent.handle_message(msg) assert len(agent.history) == 0 + + +@pytest.mark.asyncio +async def test_handle_assistant_and_user_messages(mock_settings): + agent = LLMAgent("llm_agent") + + # Assistant message + msg_ast = InternalMessage( + to="llm_agent", + sender=mock_settings.agent_settings.bdi_core_name, + thread="assistant_message", + body="I said this", + ) + await agent.handle_message(msg_ast) + assert agent.history[-1] == {"role": "assistant", "content": "I said this"} + + # User message + msg_usr = InternalMessage( + to="llm_agent", + sender=mock_settings.agent_settings.bdi_core_name, + thread="user_message", + body="User said this", + ) + await agent.handle_message(msg_usr) + assert agent.history[-1] == {"role": "user", "content": "User said this"} diff --git a/test/unit/agents/perception/transcription_agent/test_transcription_agent.py b/test/unit/agents/perception/transcription_agent/test_transcription_agent.py index ccdaa7f..57875ca 100644 --- a/test/unit/agents/perception/transcription_agent/test_transcription_agent.py +++ b/test/unit/agents/perception/transcription_agent/test_transcription_agent.py @@ -36,7 +36,12 @@ async def test_transcription_agent_flow(mock_zmq_context): agent.send = AsyncMock() agent._running = True - agent.add_behavior = AsyncMock() + + def close_coro(coro): + coro.close() + return MagicMock() + + agent.add_behavior = MagicMock(side_effect=close_coro) await agent.setup() @@ -143,7 +148,12 @@ async def test_transcription_loop_continues_after_error(mock_zmq_context): agent = TranscriptionAgent("tcp://in") agent._running = True # ← REQUIRED to enter the loop agent.send = AsyncMock() # should never be called - agent.add_behavior = AsyncMock() # match other tests + + def close_coro(coro): + coro.close() + return MagicMock() + + agent.add_behavior = MagicMock(side_effect=close_coro) # match other tests await agent.setup() @@ -180,7 +190,12 @@ async def test_transcription_continue_branch_when_empty(mock_zmq_context): # Make loop runnable agent._running = True agent.send = AsyncMock() - agent.add_behavior = AsyncMock() + + def close_coro(coro): + coro.close() + return MagicMock() + + agent.add_behavior = MagicMock(side_effect=close_coro) await agent.setup() diff --git a/test/unit/agents/perception/vad_agent/test_vad_agent.py b/test/unit/agents/perception/vad_agent/test_vad_agent.py new file mode 100644 index 0000000..fe65545 --- /dev/null +++ b/test/unit/agents/perception/vad_agent/test_vad_agent.py @@ -0,0 +1,153 @@ +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from control_backend.agents.perception.vad_agent import VADAgent +from control_backend.core.agent_system import InternalMessage +from control_backend.schemas.program_status import PROGRAM_STATUS, ProgramStatus + + +@pytest.fixture(autouse=True) +def mock_zmq(): + with patch("zmq.asyncio.Context") as mock: + mock.instance.return_value = MagicMock() + yield mock + + +@pytest.fixture +def agent(): + return VADAgent("tcp://localhost:5555", False) + + +@pytest.mark.asyncio +async def test_handle_message_pause(agent): + agent._paused = MagicMock() + # It starts set (not paused) + + msg = InternalMessage(to="vad", sender="user_interrupt_agent", body="PAUSE") + + # We need to mock settings to match sender name + with patch("control_backend.agents.perception.vad_agent.settings") as mock_settings: + mock_settings.agent_settings.user_interrupt_name = "user_interrupt_agent" + + await agent.handle_message(msg) + + agent._paused.clear.assert_called_once() + assert agent._reset_needed is True + + +@pytest.mark.asyncio +async def test_handle_message_resume(agent): + agent._paused = MagicMock() + msg = InternalMessage(to="vad", sender="user_interrupt_agent", body="RESUME") + + with patch("control_backend.agents.perception.vad_agent.settings") as mock_settings: + mock_settings.agent_settings.user_interrupt_name = "user_interrupt_agent" + + await agent.handle_message(msg) + + agent._paused.set.assert_called_once() + + +@pytest.mark.asyncio +async def test_handle_message_unknown_command(agent): + agent._paused = MagicMock() + msg = InternalMessage(to="vad", sender="user_interrupt_agent", body="UNKNOWN") + + with patch("control_backend.agents.perception.vad_agent.settings") as mock_settings: + mock_settings.agent_settings.user_interrupt_name = "user_interrupt_agent" + agent.logger = MagicMock() + + await agent.handle_message(msg) + + agent.logger.warning.assert_called() + agent._paused.clear.assert_not_called() + agent._paused.set.assert_not_called() + + +@pytest.mark.asyncio +async def test_handle_message_unknown_sender(agent): + agent._paused = MagicMock() + msg = InternalMessage(to="vad", sender="other_agent", body="PAUSE") + + with patch("control_backend.agents.perception.vad_agent.settings") as mock_settings: + mock_settings.agent_settings.user_interrupt_name = "user_interrupt_agent" + + await agent.handle_message(msg) + + agent._paused.clear.assert_not_called() + + +@pytest.mark.asyncio +async def test_status_loop_waits_for_running(agent): + agent._running = True + agent.program_sub_socket = AsyncMock() + agent.program_sub_socket.close = MagicMock() + agent._reset_stream = AsyncMock() + + # Sequence of messages: + # 1. Wrong topic + # 2. Right topic, wrong status (STARTING) + # 3. Right topic, RUNNING -> Should break loop + + agent.program_sub_socket.recv_multipart.side_effect = [ + (b"wrong_topic", b"whatever"), + (PROGRAM_STATUS, ProgramStatus.STARTING.value), + (PROGRAM_STATUS, ProgramStatus.RUNNING.value), + ] + + await agent._status_loop() + + assert agent._reset_stream.await_count == 1 + agent.program_sub_socket.close.assert_called_once() + + +@pytest.mark.asyncio +async def test_setup_success(agent, mock_zmq): + def close_coro(coro): + coro.close() + return MagicMock() + + agent.add_behavior = MagicMock(side_effect=close_coro) + + mock_context = mock_zmq.instance.return_value + mock_sub = MagicMock() + mock_pub = MagicMock() + + # We expect multiple socket calls: + # 1. audio_in (SUB) + # 2. audio_out (PUB) + # 3. program_sub (SUB) + mock_context.socket.side_effect = [mock_sub, mock_pub, mock_sub] + + with patch("control_backend.agents.perception.vad_agent.torch.hub.load") as mock_load: + mock_load.return_value = (MagicMock(), None) + + with patch("control_backend.agents.perception.vad_agent.TranscriptionAgent") as MockTrans: + mock_trans_instance = MockTrans.return_value + mock_trans_instance.start = AsyncMock() + + await agent.setup() + + mock_trans_instance.start.assert_awaited_once() + + assert agent.add_behavior.call_count == 2 # streaming_loop + status_loop + assert agent.audio_in_socket is not None + assert agent.audio_out_socket is not None + assert agent.program_sub_socket is not None + + +@pytest.mark.asyncio +async def test_reset_stream(agent): + mock_poller = MagicMock() + agent.audio_in_poller = mock_poller + + # poll(1) returns not None twice, then None + mock_poller.poll = AsyncMock(side_effect=[b"data", b"data", None]) + + agent._ready = MagicMock() + + await agent._reset_stream() + + assert mock_poller.poll.await_count == 3 + agent._ready.set.assert_called_once() diff --git a/test/unit/agents/perception/vad_agent/test_vad_streaming.py b/test/unit/agents/perception/vad_agent/test_vad_streaming.py index 166919f..349fab2 100644 --- a/test/unit/agents/perception/vad_agent/test_vad_streaming.py +++ b/test/unit/agents/perception/vad_agent/test_vad_streaming.py @@ -5,6 +5,7 @@ import pytest import zmq from control_backend.agents.perception.vad_agent import VADAgent +from control_backend.core.config import settings # We don't want to use real ZMQ in unit tests, for example because it can give errors when sockets @@ -135,6 +136,54 @@ async def test_no_data(audio_out_socket, vad_agent): assert len(vad_agent.audio_buffer) == 0 +@pytest.mark.asyncio +async def test_streaming_loop_reset_needed(audio_out_socket, vad_agent): + """Test that _reset_needed branch works as expected.""" + vad_agent._reset_needed = True + vad_agent._ready.set() + vad_agent._paused.set() + vad_agent._running = True + vad_agent.audio_buffer = np.array([1.0], dtype=np.float32) + vad_agent.i_since_speech = 0 + + # Mock _reset_stream to stop the loop by setting _running=False + async def mock_reset(): + vad_agent._running = False + + vad_agent._reset_stream = mock_reset + + # Needs a poller to avoid AssertionError + vad_agent.audio_in_poller = AsyncMock() + vad_agent.audio_in_poller.poll.return_value = None + + await vad_agent._streaming_loop() + + assert vad_agent._reset_needed is False + assert len(vad_agent.audio_buffer) == 0 + assert vad_agent.i_since_speech == settings.behaviour_settings.vad_initial_since_speech + + +@pytest.mark.asyncio +async def test_streaming_loop_no_data_clears_buffer(audio_out_socket, vad_agent): + """Test that if poll returns None, buffer is cleared if not empty.""" + vad_agent.audio_buffer = np.array([1.0], dtype=np.float32) + vad_agent._ready.set() + vad_agent._paused.set() + vad_agent._running = True + + class MockPoller: + async def poll(self, timeout_ms=None): + vad_agent._running = False # stop after one poll + return None + + vad_agent.audio_in_poller = MockPoller() + + await vad_agent._streaming_loop() + + assert len(vad_agent.audio_buffer) == 0 + assert vad_agent.i_since_speech == settings.behaviour_settings.vad_initial_since_speech + + @pytest.mark.asyncio async def test_vad_model_load_failure_stops_agent(vad_agent): """ diff --git a/test/unit/agents/test_base.py b/test/unit/agents/test_base.py new file mode 100644 index 0000000..0579ada --- /dev/null +++ b/test/unit/agents/test_base.py @@ -0,0 +1,24 @@ +import logging + +from control_backend.agents.base import BaseAgent + + +class MyAgent(BaseAgent): + async def setup(self): + pass + + async def handle_message(self, msg): + pass + + +def test_base_agent_logger_init(): + # When defining a subclass, __init_subclass__ runs + # The BaseAgent in agents/base.py sets the logger + assert hasattr(MyAgent, "logger") + assert isinstance(MyAgent.logger, logging.Logger) + # The logger name depends on the package. + # Since this test file is running as a module, __package__ might be None or the test package. + # In 'src/control_backend/agents/base.py', it uses __package__ of base.py which is + # 'control_backend.agents'. + # So logger name should be control_backend.agents.MyAgent + assert MyAgent.logger.name == "control_backend.agents.MyAgent" diff --git a/test/unit/agents/user_interrupt/test_user_interrupt.py b/test/unit/agents/user_interrupt/test_user_interrupt.py index 7e3e700..7c38a05 100644 --- a/test/unit/agents/user_interrupt/test_user_interrupt.py +++ b/test/unit/agents/user_interrupt/test_user_interrupt.py @@ -7,6 +7,15 @@ import pytest from control_backend.agents.user_interrupt.user_interrupt_agent import UserInterruptAgent from control_backend.core.agent_system import InternalMessage from control_backend.core.config import settings +from control_backend.schemas.program import ( + ConditionalNorm, + Goal, + KeywordBelief, + Phase, + Plan, + Program, + Trigger, +) from control_backend.schemas.ri_message import RIEndpoint @@ -16,6 +25,7 @@ def agent(): agent.send = AsyncMock() agent.logger = MagicMock() agent.sub_socket = AsyncMock() + agent.pub_socket = AsyncMock() return agent @@ -49,21 +59,18 @@ async def test_send_to_gesture_agent(agent): @pytest.mark.asyncio -async def test_send_to_program_manager(agent): +async def test_send_to_bdi_belief(agent): """Verify belief update format.""" - context_str = "2" + context_str = "some_goal" - await agent._send_to_program_manager(context_str) + await agent._send_to_bdi_belief(context_str) - agent.send.assert_awaited_once() - sent_msg: InternalMessage = agent.send.call_args.args[0] + assert agent.send.await_count == 1 + sent_msg = agent.send.call_args.args[0] - assert sent_msg.to == settings.agent_settings.bdi_program_manager_name - assert sent_msg.thread == "belief_override_id" - - body = json.loads(sent_msg.body) - - assert body["belief"] == context_str + assert sent_msg.to == settings.agent_settings.bdi_core_name + assert sent_msg.thread == "beliefs" + assert "achieved_some_goal" in sent_msg.body @pytest.mark.asyncio @@ -77,6 +84,10 @@ async def test_receive_loop_routing_success(agent): # Prepare JSON payloads as bytes payload_speech = json.dumps({"type": "speech", "context": "Hello Speech"}).encode() payload_gesture = json.dumps({"type": "gesture", "context": "Hello Gesture"}).encode() + # override calls _send_to_bdi (for trigger/norm) OR _send_to_bdi_belief (for goal). + + # To test routing, we need to populate the maps + agent._goal_map["Hello Override"] = "some_goal_slug" payload_override = json.dumps({"type": "override", "context": "Hello Override"}).encode() agent.sub_socket.recv_multipart.side_effect = [ @@ -88,7 +99,7 @@ async def test_receive_loop_routing_success(agent): agent._send_to_speech_agent = AsyncMock() agent._send_to_gesture_agent = AsyncMock() - agent._send_to_program_manager = AsyncMock() + agent._send_to_bdi_belief = AsyncMock() try: await agent._receive_button_event() @@ -103,12 +114,12 @@ async def test_receive_loop_routing_success(agent): # Gesture agent._send_to_gesture_agent.assert_awaited_once_with("Hello Gesture") - # Override - agent._send_to_program_manager.assert_awaited_once_with("Hello Override") + # Override (since we mapped it to a goal) + agent._send_to_bdi_belief.assert_awaited_once_with("some_goal_slug") assert agent._send_to_speech_agent.await_count == 1 assert agent._send_to_gesture_agent.await_count == 1 - assert agent._send_to_program_manager.await_count == 1 + assert agent._send_to_bdi_belief.await_count == 1 @pytest.mark.asyncio @@ -125,7 +136,6 @@ async def test_receive_loop_unknown_type(agent): agent._send_to_speech_agent = AsyncMock() agent._send_to_gesture_agent = AsyncMock() - agent._send_to_belief_collector = AsyncMock() try: await agent._receive_button_event() @@ -137,10 +147,165 @@ async def test_receive_loop_unknown_type(agent): # Ensure no handlers were called agent._send_to_speech_agent.assert_not_called() agent._send_to_gesture_agent.assert_not_called() - agent._send_to_belief_collector.assert_not_called() - agent.logger.warning.assert_called_with( - "Received button press with unknown type '%s' (context: '%s').", - "unknown_thing", - "some_data", - ) + agent.logger.warning.assert_called() + + +@pytest.mark.asyncio +async def test_create_mapping(agent): + # Create a program with a trigger, goal, and conditional norm + import uuid + + trigger_id = uuid.uuid4() + goal_id = uuid.uuid4() + norm_id = uuid.uuid4() + + cond = KeywordBelief(id=uuid.uuid4(), name="k1", keyword="key") + plan = Plan(id=uuid.uuid4(), name="p1", steps=[]) + + trigger = Trigger(id=trigger_id, name="my_trigger", condition=cond, plan=plan) + goal = Goal(id=goal_id, name="my_goal", description="desc", plan=plan) + + cn = ConditionalNorm(id=norm_id, name="my_norm", norm="be polite", condition=cond) + + phase = Phase(id=uuid.uuid4(), name="phase1", norms=[cn], goals=[goal], triggers=[trigger]) + prog = Program(phases=[phase]) + + # Call create_mapping via handle_message + msg = InternalMessage(to="me", thread="new_program", body=prog.model_dump_json()) + await agent.handle_message(msg) + + # Check maps + assert str(trigger_id) in agent._trigger_map + assert agent._trigger_map[str(trigger_id)] == "trigger_my_trigger" + + assert str(goal_id) in agent._goal_map + assert agent._goal_map[str(goal_id)] == "my_goal" + + assert str(norm_id) in agent._cond_norm_map + assert agent._cond_norm_map[str(norm_id)] == "norm_be_polite" + + +@pytest.mark.asyncio +async def test_create_mapping_invalid_json(agent): + # Pass invalid json to handle_message thread "new_program" + msg = InternalMessage(to="me", thread="new_program", body="invalid json") + await agent.handle_message(msg) + + # Should log error and maps should remain empty or cleared + agent.logger.error.assert_called() + + +@pytest.mark.asyncio +async def test_handle_message_trigger_start(agent): + # Setup reverse map manually + agent._trigger_reverse_map["trigger_slug"] = "ui_id_123" + + msg = InternalMessage(to="me", thread="trigger_start", body="trigger_slug") + await agent.handle_message(msg) + + agent.pub_socket.send_multipart.assert_awaited_once() + args = agent.pub_socket.send_multipart.call_args[0][0] + assert args[0] == b"experiment" + payload = json.loads(args[1]) + assert payload["type"] == "trigger_update" + assert payload["id"] == "ui_id_123" + assert payload["achieved"] is True + + +@pytest.mark.asyncio +async def test_handle_message_trigger_end(agent): + agent._trigger_reverse_map["trigger_slug"] = "ui_id_123" + + msg = InternalMessage(to="me", thread="trigger_end", body="trigger_slug") + await agent.handle_message(msg) + + agent.pub_socket.send_multipart.assert_awaited_once() + payload = json.loads(agent.pub_socket.send_multipart.call_args[0][0][1]) + assert payload["type"] == "trigger_update" + assert payload["achieved"] is False + + +@pytest.mark.asyncio +async def test_handle_message_transition_phase(agent): + msg = InternalMessage(to="me", thread="transition_phase", body="phase_id_123") + await agent.handle_message(msg) + + agent.pub_socket.send_multipart.assert_awaited_once() + payload = json.loads(agent.pub_socket.send_multipart.call_args[0][0][1]) + assert payload["type"] == "phase_update" + assert payload["id"] == "phase_id_123" + + +@pytest.mark.asyncio +async def test_handle_message_goal_start(agent): + agent._goal_reverse_map["goal_slug"] = "goal_id_123" + + msg = InternalMessage(to="me", thread="goal_start", body="goal_slug") + await agent.handle_message(msg) + + agent.pub_socket.send_multipart.assert_awaited_once() + payload = json.loads(agent.pub_socket.send_multipart.call_args[0][0][1]) + assert payload["type"] == "goal_update" + assert payload["id"] == "goal_id_123" + assert payload["active"] is True + + +@pytest.mark.asyncio +async def test_handle_message_active_norms_update(agent): + agent._cond_norm_reverse_map["norm_active"] = "id_1" + agent._cond_norm_reverse_map["norm_inactive"] = "id_2" + + # Body is like: "('norm_active', 'other')" + # The split logic handles quotes etc. + msg = InternalMessage(to="me", thread="active_norms_update", body="'norm_active', 'other'") + await agent.handle_message(msg) + + agent.pub_socket.send_multipart.assert_awaited_once() + payload = json.loads(agent.pub_socket.send_multipart.call_args[0][0][1]) + assert payload["type"] == "cond_norms_state_update" + norms = {n["id"]: n["active"] for n in payload["norms"]} + assert norms["id_1"] is True + assert norms["id_2"] is False + + +@pytest.mark.asyncio +async def test_send_experiment_control(agent): + # Test next_phase + await agent._send_experiment_control_to_bdi_core("next_phase") + agent.send.assert_awaited() + msg = agent.send.call_args[0][0] + assert msg.thread == "force_next_phase" + + # Test reset_phase + await agent._send_experiment_control_to_bdi_core("reset_phase") + msg = agent.send.call_args[0][0] + assert msg.thread == "reset_current_phase" + + # Test reset_experiment + await agent._send_experiment_control_to_bdi_core("reset_experiment") + msg = agent.send.call_args[0][0] + assert msg.thread == "reset_experiment" + + +@pytest.mark.asyncio +async def test_send_pause_command(agent): + await agent._send_pause_command("true") + # Sends to RI and VAD + assert agent.send.await_count == 2 + msgs = [call.args[0] for call in agent.send.call_args_list] + + ri_msg = next(m for m in msgs if m.to == settings.agent_settings.ri_communication_name) + assert json.loads(ri_msg.body)["endpoint"] == "" # PAUSE endpoint + assert json.loads(ri_msg.body)["data"] is True + + vad_msg = next(m for m in msgs if m.to == settings.agent_settings.vad_name) + assert vad_msg.body == "PAUSE" + + agent.send.reset_mock() + await agent._send_pause_command("false") + assert agent.send.await_count == 2 + vad_msg = next( + m for m in agent.send.call_args_list if m.args[0].to == settings.agent_settings.vad_name + ).args[0] + assert vad_msg.body == "RESUME" diff --git a/test/unit/api/v1/endpoints/test_user_interact.py b/test/unit/api/v1/endpoints/test_user_interact.py new file mode 100644 index 0000000..ddb9932 --- /dev/null +++ b/test/unit/api/v1/endpoints/test_user_interact.py @@ -0,0 +1,96 @@ +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from fastapi import FastAPI +from fastapi.testclient import TestClient + +from control_backend.api.v1.endpoints import user_interact + + +@pytest.fixture +def app(): + app = FastAPI() + app.include_router(user_interact.router) + return app + + +@pytest.fixture +def client(app): + return TestClient(app) + + +@pytest.mark.asyncio +async def test_receive_button_event(client): + mock_pub_socket = AsyncMock() + client.app.state.endpoints_pub_socket = mock_pub_socket + + payload = {"type": "speech", "context": "hello"} + response = client.post("/button_pressed", json=payload) + + assert response.status_code == 202 + assert response.json() == {"status": "Event received"} + + mock_pub_socket.send_multipart.assert_awaited_once() + args = mock_pub_socket.send_multipart.call_args[0][0] + assert args[0] == b"button_pressed" + assert "speech" in args[1].decode() + + +@pytest.mark.asyncio +async def test_receive_button_event_invalid_payload(client): + mock_pub_socket = AsyncMock() + client.app.state.endpoints_pub_socket = mock_pub_socket + + # Missing context + payload = {"type": "speech"} + response = client.post("/button_pressed", json=payload) + + assert response.status_code == 422 + mock_pub_socket.send_multipart.assert_not_called() + + +@pytest.mark.asyncio +async def test_experiment_stream_direct_call(): + """ + Directly calling the endpoint function to test the streaming logic + without dealing with TestClient streaming limitations. + """ + mock_socket = AsyncMock() + # 1. recv data + # 2. recv timeout + # 3. disconnect (request.is_disconnected returns True) + mock_socket.recv_multipart.side_effect = [ + (b"topic", b"message1"), + TimeoutError(), + (b"topic", b"message2"), # Should not be reached if disconnect checks work + ] + mock_socket.close = MagicMock() + mock_socket.connect = MagicMock() + mock_socket.subscribe = MagicMock() + + mock_context = MagicMock() + mock_context.socket.return_value = mock_socket + + with patch( + "control_backend.api.v1.endpoints.user_interact.Context.instance", return_value=mock_context + ): + mock_request = AsyncMock() + # is_disconnected sequence: + # 1. False (before first recv) -> reads message1 + # 2. False (before second recv) -> triggers TimeoutError, continues + # 3. True (before third recv) -> break loop + mock_request.is_disconnected.side_effect = [False, False, True] + + response = await user_interact.experiment_stream(mock_request) + + lines = [] + # Consume the generator + async for line in response.body_iterator: + lines.append(line) + + assert "data: message1\n\n" in lines + assert len(lines) == 1 + + mock_socket.connect.assert_called() + mock_socket.subscribe.assert_called_with(b"experiment") + mock_socket.close.assert_called() diff --git a/test/unit/test_main_sockets.py b/test/unit/test_main_sockets.py new file mode 100644 index 0000000..662147a --- /dev/null +++ b/test/unit/test_main_sockets.py @@ -0,0 +1,40 @@ +from unittest.mock import MagicMock, patch + +import zmq + +from control_backend.main import setup_sockets + + +def test_setup_sockets_proxy(): + mock_context = MagicMock() + mock_pub = MagicMock() + mock_sub = MagicMock() + + mock_context.socket.side_effect = [mock_pub, mock_sub] + + with patch("zmq.asyncio.Context.instance", return_value=mock_context): + with patch("zmq.proxy") as mock_proxy: + setup_sockets() + + mock_pub.bind.assert_called() + mock_sub.bind.assert_called() + mock_proxy.assert_called_with(mock_sub, mock_pub) + + # Check cleanup + mock_pub.close.assert_called() + mock_sub.close.assert_called() + + +def test_setup_sockets_proxy_error(): + mock_context = MagicMock() + mock_pub = MagicMock() + mock_sub = MagicMock() + mock_context.socket.side_effect = [mock_pub, mock_sub] + + with patch("zmq.asyncio.Context.instance", return_value=mock_context): + with patch("zmq.proxy", side_effect=zmq.ZMQError): + with patch("control_backend.main.logger") as mock_logger: + setup_sockets() + mock_logger.warning.assert_called() + mock_pub.close.assert_called() + mock_sub.close.assert_called() From 6d03ba8a4153015cda8c8ec4ba1372233c7084af Mon Sep 17 00:00:00 2001 From: Pim Hutting Date: Fri, 16 Jan 2026 14:28:27 +0100 Subject: [PATCH 3/7] feat: added extra endpoint for norm pings also made sure that you cannot skip phase on end phase ref: N25B-400 --- .../agents/bdi/agentspeak_generator.py | 10 ++++++ .../user_interrupt/user_interrupt_agent.py | 29 ++++++++++------- .../api/v1/endpoints/user_interact.py | 31 +++++++++++++++++-- 3 files changed, 56 insertions(+), 14 deletions(-) diff --git a/src/control_backend/agents/bdi/agentspeak_generator.py b/src/control_backend/agents/bdi/agentspeak_generator.py index ed6f787..524f980 100644 --- a/src/control_backend/agents/bdi/agentspeak_generator.py +++ b/src/control_backend/agents/bdi/agentspeak_generator.py @@ -424,6 +424,16 @@ class AgentSpeakGenerator: ) ) + # Force phase transition fallback + self._asp.plans.append( + AstPlan( + TriggerType.ADDED_GOAL, + AstLiteral("force_transition_phase"), + [], + [AstStatement(StatementType.EMPTY, AstLiteral("true"))], + ) + ) + @singledispatchmethod def _astify(self, element: ProgramElement) -> AstExpression: raise NotImplementedError(f"Cannot convert element {element} to an AgentSpeak expression.") diff --git a/src/control_backend/agents/user_interrupt/user_interrupt_agent.py b/src/control_backend/agents/user_interrupt/user_interrupt_agent.py index 0bde563..9ba8409 100644 --- a/src/control_backend/agents/user_interrupt/user_interrupt_agent.py +++ b/src/control_backend/agents/user_interrupt/user_interrupt_agent.py @@ -117,13 +117,13 @@ class UserInterruptAgent(BaseAgent): event_context, ) elif asl_cond_norm := self._cond_norm_map.get(ui_id): - await self._send_to_bdi_belief(asl_cond_norm) + await self._send_to_bdi_belief(asl_cond_norm, "cond_norm") self.logger.info( "Forwarded button press (override) with context '%s' to BDI Core.", event_context, ) elif asl_goal := self._goal_map.get(ui_id): - await self._send_to_bdi_belief(asl_goal) + await self._send_to_bdi_belief(asl_goal, "goal") self.logger.info( "Forwarded button press (override) with context '%s' to BDI Core.", event_context, @@ -141,7 +141,7 @@ class UserInterruptAgent(BaseAgent): case "override_unachieve": ui_id = str(event_context) if asl_cond_norm := self._cond_norm_map.get(ui_id): - await self._send_to_bdi_belief(asl_cond_norm, True) + await self._send_to_bdi_belief(asl_cond_norm, "cond_norm", True) self.logger.info( "Forwarded button press (override_unachieve)" "with context '%s' to BDI Core.", @@ -187,11 +187,9 @@ class UserInterruptAgent(BaseAgent): payload = {"type": "trigger_update", "id": ui_id, "achieved": True} await self._send_experiment_update(payload) self.logger.info(f"UI Update: Trigger {asl_slug} started (ID: {ui_id})") - case "trigger_end": asl_slug = msg.body ui_id = self._trigger_reverse_map.get(asl_slug) - if ui_id: payload = {"type": "trigger_update", "id": ui_id, "achieved": False} await self._send_experiment_update(payload) @@ -207,7 +205,7 @@ class UserInterruptAgent(BaseAgent): goal_name = msg.body ui_id = self._goal_reverse_map.get(goal_name) if ui_id: - payload = {"type": "goal_update", "id": ui_id, "active": True} + payload = {"type": "goal_update", "id": ui_id} await self._send_experiment_update(payload) self.logger.info(f"UI Update: Goal {goal_name} started (ID: {ui_id})") case "active_norms_update": @@ -224,15 +222,17 @@ class UserInterruptAgent(BaseAgent): :param active_slugs: A list of slugs (strings) currently active in the BDI core. """ updates = [] - for asl_slug, ui_id in self._cond_norm_reverse_map.items(): is_active = asl_slug in active_slugs - updates.append({"id": ui_id, "name": asl_slug, "active": is_active}) + updates.append({"id": ui_id, "active": is_active}) payload = {"type": "cond_norms_state_update", "norms": updates} - await self._send_experiment_update(payload, should_log=False) - # self.logger.debug(f"Broadcasted state for {len(updates)} conditional norms.") + if self.pub_socket: + topic = b"status" + body = json.dumps(payload).encode("utf-8") + await self.pub_socket.send_multipart([topic, body]) + # self.logger.info(f"UI Update: Active norms {updates}") def _create_mapping(self, program_json: str): """ @@ -325,9 +325,14 @@ class UserInterruptAgent(BaseAgent): await self.send(msg) self.logger.info(f"Directly forced {thread} in BDI: {body}") - async def _send_to_bdi_belief(self, asl_goal: str, unachieve: bool = False): + async def _send_to_bdi_belief(self, asl: str, asl_type: str, unachieve: bool = False): """Send belief to BDI Core""" - belief_name = f"achieved_{asl_goal}" + if asl_type == "goal": + belief_name = f"achieved_{asl}" + elif asl_type == "cond_norm": + belief_name = f"force_{asl}" + else: + self.logger.warning("Tried to send belief with unknown type") belief = Belief(name=belief_name, arguments=None) self.logger.debug(f"Sending belief to BDI Core: {belief_name}") # Conditional norms are unachieved by removing the belief diff --git a/src/control_backend/api/v1/endpoints/user_interact.py b/src/control_backend/api/v1/endpoints/user_interact.py index 3d3406e..eb70f35 100644 --- a/src/control_backend/api/v1/endpoints/user_interact.py +++ b/src/control_backend/api/v1/endpoints/user_interact.py @@ -52,11 +52,11 @@ async def experiment_stream(request: Request): while True: # Check if client closed the tab if await request.is_disconnected(): - logger.info("Client disconnected from experiment stream.") + logger.error("Client disconnected from experiment stream.") break try: - parts = await asyncio.wait_for(socket.recv_multipart(), timeout=1.0) + parts = await asyncio.wait_for(socket.recv_multipart(), timeout=10.0) _, message = parts yield f"data: {message.decode().strip()}\n\n" except TimeoutError: @@ -65,3 +65,30 @@ async def experiment_stream(request: Request): socket.close() return StreamingResponse(gen(), media_type="text/event-stream") + + +@router.get("/status_stream") +async def status_stream(request: Request): + context = Context.instance() + socket = context.socket(zmq.SUB) + socket.connect(settings.zmq_settings.internal_sub_address) + + socket.subscribe(b"status") + + async def gen(): + try: + while True: + if await request.is_disconnected(): + break + try: + # Shorter timeout since this is frequent + parts = await asyncio.wait_for(socket.recv_multipart(), timeout=0.5) + _, message = parts + yield f"data: {message.decode().strip()}\n\n" + except TimeoutError: + yield ": ping\n\n" # Keep the connection alive + continue + finally: + socket.close() + + return StreamingResponse(gen(), media_type="text/event-stream") From 7c10c50336ebb57739d04782bf907b00c63cd866 Mon Sep 17 00:00:00 2001 From: Pim Hutting Date: Fri, 16 Jan 2026 14:29:46 +0100 Subject: [PATCH 4/7] chore: removed resetExperiment from backened now it happens in UI ref: N25B-400 --- .../agents/user_interrupt/user_interrupt_agent.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/control_backend/agents/user_interrupt/user_interrupt_agent.py b/src/control_backend/agents/user_interrupt/user_interrupt_agent.py index 9ba8409..cf72ce5 100644 --- a/src/control_backend/agents/user_interrupt/user_interrupt_agent.py +++ b/src/control_backend/agents/user_interrupt/user_interrupt_agent.py @@ -80,7 +80,6 @@ class UserInterruptAgent(BaseAgent): - type: "next_phase", context: None, indicates to the BDI Core to - type: "pause", context: boolean indicating whether to pause - type: "reset_phase", context: None, indicates to the BDI Core to - - type: "reset_experiment", context: None, indicates to the BDI Core to """ while True: topic, body = await self.sub_socket.recv_multipart() @@ -162,7 +161,7 @@ class UserInterruptAgent(BaseAgent): else: self.logger.info("Sent resume command.") - case "next_phase" | "reset_phase" | "reset_experiment": + case "next_phase" | "reset_phase": await self._send_experiment_control_to_bdi_core(event_type) case _: self.logger.warning( @@ -359,8 +358,6 @@ class UserInterruptAgent(BaseAgent): thread = "force_next_phase" case "reset_phase": thread = "reset_current_phase" - case "reset_experiment": - thread = "reset_experiment" case _: self.logger.warning( "Received unknown experiment control type '%s' to send to BDI Core.", From 8506c0d9effe74889cbe2080786f948e89194caa Mon Sep 17 00:00:00 2001 From: Kasper Marinus Date: Fri, 16 Jan 2026 15:07:44 +0100 Subject: [PATCH 5/7] chore: remove belief collector and small tweaks --- .../agents/bdi/bdi_core_agent.py | 2 +- src/control_backend/core/config.py | 2 -- src/control_backend/main.py | 7 ------ .../actuation/test_robot_gesture_agent.py | 3 +-- test/unit/agents/bdi/test_bdi_core_agent.py | 10 ++++---- .../agents/bdi/test_text_belief_extractor.py | 24 +++++++++++++++++++ .../test_speech_recognizer.py | 4 +++- .../perception/vad_agent/test_vad_agent.py | 1 - test/unit/conftest.py | 1 - 9 files changed, 34 insertions(+), 20 deletions(-) diff --git a/src/control_backend/agents/bdi/bdi_core_agent.py b/src/control_backend/agents/bdi/bdi_core_agent.py index 0c217dc..628bb53 100644 --- a/src/control_backend/agents/bdi/bdi_core_agent.py +++ b/src/control_backend/agents/bdi/bdi_core_agent.py @@ -167,7 +167,7 @@ class BDICoreAgent(BaseAgent): case "force_next_phase": self._force_next_phase() case _: - self.logger.warning("Received unknow user interruption: %s", msg) + self.logger.warning("Received unknown user interruption: %s", msg) def _apply_belief_changes(self, belief_changes: BeliefMessage): """ diff --git a/src/control_backend/core/config.py b/src/control_backend/core/config.py index 6deb1b8..329a246 100644 --- a/src/control_backend/core/config.py +++ b/src/control_backend/core/config.py @@ -35,7 +35,6 @@ class AgentSettings(BaseModel): Names of the various agents in the system. These names are used for routing messages. :ivar bdi_core_name: Name of the BDI Core Agent. - :ivar bdi_belief_collector_name: Name of the Belief Collector Agent. :ivar bdi_program_manager_name: Name of the BDI Program Manager Agent. :ivar text_belief_extractor_name: Name of the Text Belief Extractor Agent. :ivar vad_name: Name of the Voice Activity Detection (VAD) Agent. @@ -50,7 +49,6 @@ class AgentSettings(BaseModel): # agent names bdi_core_name: str = "bdi_core_agent" - bdi_belief_collector_name: str = "belief_collector_agent" bdi_program_manager_name: str = "bdi_program_manager_agent" text_belief_extractor_name: str = "text_belief_extractor_agent" vad_name: str = "vad_agent" diff --git a/src/control_backend/main.py b/src/control_backend/main.py index ec93b1e..a0136bd 100644 --- a/src/control_backend/main.py +++ b/src/control_backend/main.py @@ -26,7 +26,6 @@ from zmq.asyncio import Context # BDI agents from control_backend.agents.bdi import ( - BDIBeliefCollectorAgent, BDICoreAgent, TextBeliefExtractorAgent, ) @@ -122,12 +121,6 @@ async def lifespan(app: FastAPI): "name": settings.agent_settings.bdi_core_name, }, ), - "BeliefCollectorAgent": ( - BDIBeliefCollectorAgent, - { - "name": settings.agent_settings.bdi_belief_collector_name, - }, - ), "TextBeliefExtractorAgent": ( TextBeliefExtractorAgent, { diff --git a/test/unit/agents/actuation/test_robot_gesture_agent.py b/test/unit/agents/actuation/test_robot_gesture_agent.py index 225278d..1e6fd8a 100644 --- a/test/unit/agents/actuation/test_robot_gesture_agent.py +++ b/test/unit/agents/actuation/test_robot_gesture_agent.py @@ -478,8 +478,7 @@ async def test_stop_closes_sockets(): pubsocket.close.assert_called_once() subsocket.close.assert_called_once() - # Note: repsocket is not closed in stop() method, but you might want to add it - # repsocket.close.assert_called_once() + repsocket.close.assert_called_once() @pytest.mark.asyncio diff --git a/test/unit/agents/bdi/test_bdi_core_agent.py b/test/unit/agents/bdi/test_bdi_core_agent.py index 152d901..6245d5b 100644 --- a/test/unit/agents/bdi/test_bdi_core_agent.py +++ b/test/unit/agents/bdi/test_bdi_core_agent.py @@ -45,12 +45,12 @@ async def test_setup_no_asl(mock_agentspeak_env, agent): @pytest.mark.asyncio -async def test_handle_belief_collector_message(agent, mock_settings): +async def test_handle_belief_message(agent, mock_settings): """Test that incoming beliefs are added to the BDI agent""" beliefs = [Belief(name="user_said", arguments=["Hello"])] msg = InternalMessage( to="bdi_agent", - sender=mock_settings.agent_settings.bdi_belief_collector_name, + sender=mock_settings.agent_settings.text_belief_extractor_name, body=BeliefMessage(create=beliefs).model_dump_json(), thread="beliefs", ) @@ -82,7 +82,7 @@ async def test_handle_delete_belief_message(agent, mock_settings): msg = InternalMessage( to="bdi_agent", - sender=mock_settings.agent_settings.bdi_belief_collector_name, + sender=mock_settings.agent_settings.text_belief_extractor_name, body=BeliefMessage(delete=beliefs).model_dump_json(), thread="beliefs", ) @@ -104,11 +104,11 @@ async def test_handle_delete_belief_message(agent, mock_settings): @pytest.mark.asyncio -async def test_incorrect_belief_collector_message(agent, mock_settings): +async def test_incorrect_belief_message(agent, mock_settings): """Test that incorrect message format triggers an exception.""" msg = InternalMessage( to="bdi_agent", - sender=mock_settings.agent_settings.bdi_belief_collector_name, + sender=mock_settings.agent_settings.text_belief_extractor_name, body=json.dumps({"bad_format": "bad_format"}), thread="beliefs", ) diff --git a/test/unit/agents/bdi/test_text_belief_extractor.py b/test/unit/agents/bdi/test_text_belief_extractor.py index 0d7dc00..353b718 100644 --- a/test/unit/agents/bdi/test_text_belief_extractor.py +++ b/test/unit/agents/bdi/test_text_belief_extractor.py @@ -359,6 +359,30 @@ async def test_simulated_real_turn_remove_belief(agent, llm, sample_program): assert any(b.name == "no_more_booze" for b in agent._current_beliefs.false) +@pytest.mark.asyncio +async def test_infer_goal_completions_sends_beliefs(agent, llm): + """Test that inferred goal completions are sent to the BDI core.""" + goal = BaseGoal( + id=uuid.uuid4(), name="Say Hello", description="The user said hello", can_fail=True + ) + agent.goal_inferrer.goals = {goal} + + # Mock goal inference: goal is achieved + llm.query = AsyncMock(return_value=True) + + await agent._infer_goal_completions() + + # Should send belief change to BDI core + agent.send.assert_awaited_once() + sent: InternalMessage = agent.send.call_args.args[0] + assert sent.to == settings.agent_settings.bdi_core_name + assert sent.thread == "beliefs" + + parsed = BeliefMessage.model_validate_json(sent.body) + assert len(parsed.create) == 1 + assert parsed.create[0].name == "achieved_say_hello" + + @pytest.mark.asyncio async def test_llm_failure_handling(agent, llm, sample_program): """ diff --git a/test/unit/agents/perception/transcription_agent/test_speech_recognizer.py b/test/unit/agents/perception/transcription_agent/test_speech_recognizer.py index 47443a9..518d189 100644 --- a/test/unit/agents/perception/transcription_agent/test_speech_recognizer.py +++ b/test/unit/agents/perception/transcription_agent/test_speech_recognizer.py @@ -55,4 +55,6 @@ def test_get_decode_options(): assert isinstance(options["sample_len"], int) # When disabled, it should not limit output length based on input size - assert "sample_rate" not in options + recognizer = OpenAIWhisperSpeechRecognizer(limit_output_length=False) + options = recognizer._get_decode_options(audio) + assert "sample_len" not in options diff --git a/test/unit/agents/perception/vad_agent/test_vad_agent.py b/test/unit/agents/perception/vad_agent/test_vad_agent.py index fe65545..3e6b0ad 100644 --- a/test/unit/agents/perception/vad_agent/test_vad_agent.py +++ b/test/unit/agents/perception/vad_agent/test_vad_agent.py @@ -60,7 +60,6 @@ async def test_handle_message_unknown_command(agent): await agent.handle_message(msg) - agent.logger.warning.assert_called() agent._paused.clear.assert_not_called() agent._paused.set.assert_not_called() diff --git a/test/unit/conftest.py b/test/unit/conftest.py index 6ab989e..d5f06e5 100644 --- a/test/unit/conftest.py +++ b/test/unit/conftest.py @@ -25,7 +25,6 @@ def mock_settings(): mock.zmq_settings.internal_sub_address = "tcp://localhost:5561" mock.zmq_settings.ri_command_address = "tcp://localhost:0000" mock.agent_settings.bdi_core_name = "bdi_core_agent" - mock.agent_settings.bdi_belief_collector_name = "belief_collector_agent" mock.agent_settings.llm_name = "llm_agent" mock.agent_settings.robot_speech_name = "robot_speech_agent" mock.agent_settings.transcription_name = "transcription_agent" From 7f7e0c542ee5bb61bbff2ed94662c669ff75e0ce Mon Sep 17 00:00:00 2001 From: Kasper Marinus Date: Fri, 16 Jan 2026 15:35:41 +0100 Subject: [PATCH 6/7] docs: add missing docs ref: N25B-115 --- src/control_backend/agents/__init__.py | 4 + .../agents/actuation/__init__.py | 4 + src/control_backend/agents/bdi/__init__.py | 5 ++ .../agents/bdi/agentspeak_ast.py | 36 ++++++++- .../agents/bdi/agentspeak_generator.py | 14 ++++ .../agents/bdi/text_belief_extractor_agent.py | 12 ++- .../agents/communication/__init__.py | 4 + src/control_backend/agents/llm/__init__.py | 4 + .../agents/perception/__init__.py | 5 ++ .../transcription_agent.py | 2 +- .../user_interrupt/user_interrupt_agent.py | 45 ++++++----- src/control_backend/api/v1/endpoints/sse.py | 12 --- src/control_backend/core/agent_system.py | 12 +++ src/control_backend/schemas/belief_list.py | 6 ++ src/control_backend/schemas/chat_history.py | 13 ++++ src/control_backend/schemas/events.py | 8 ++ src/control_backend/schemas/program.py | 78 ++++++++++--------- 17 files changed, 191 insertions(+), 73 deletions(-) delete mode 100644 src/control_backend/api/v1/endpoints/sse.py diff --git a/src/control_backend/agents/__init__.py b/src/control_backend/agents/__init__.py index 1618d55..85f4aad 100644 --- a/src/control_backend/agents/__init__.py +++ b/src/control_backend/agents/__init__.py @@ -1 +1,5 @@ +""" +This package contains all agent implementations for the PepperPlus Control Backend. +""" + from .base import BaseAgent as BaseAgent diff --git a/src/control_backend/agents/actuation/__init__.py b/src/control_backend/agents/actuation/__init__.py index 8ff7e7f..9a8d81b 100644 --- a/src/control_backend/agents/actuation/__init__.py +++ b/src/control_backend/agents/actuation/__init__.py @@ -1,2 +1,6 @@ +""" +Agents responsible for controlling the robot's physical actions, such as speech and gestures. +""" + from .robot_gesture_agent import RobotGestureAgent as RobotGestureAgent from .robot_speech_agent import RobotSpeechAgent as RobotSpeechAgent diff --git a/src/control_backend/agents/bdi/__init__.py b/src/control_backend/agents/bdi/__init__.py index d6f5124..2f7d976 100644 --- a/src/control_backend/agents/bdi/__init__.py +++ b/src/control_backend/agents/bdi/__init__.py @@ -1,3 +1,8 @@ +""" +Agents and utilities for the BDI (Belief-Desire-Intention) reasoning system, +implementing AgentSpeak(L) logic. +""" + from control_backend.agents.bdi.bdi_core_agent import BDICoreAgent as BDICoreAgent from .text_belief_extractor_agent import ( diff --git a/src/control_backend/agents/bdi/agentspeak_ast.py b/src/control_backend/agents/bdi/agentspeak_ast.py index 68be531..19f48e2 100644 --- a/src/control_backend/agents/bdi/agentspeak_ast.py +++ b/src/control_backend/agents/bdi/agentspeak_ast.py @@ -80,7 +80,7 @@ class AstTerm(AstExpression, ABC): @dataclass(eq=False) class AstAtom(AstTerm): """ - Grounded expression in all lowercase. + Represents a grounded atom in AgentSpeak (e.g., lowercase constants). """ value: str @@ -92,7 +92,7 @@ class AstAtom(AstTerm): @dataclass(eq=False) class AstVar(AstTerm): """ - Ungrounded variable expression. First letter capitalized. + Represents an ungrounded variable in AgentSpeak (e.g., capitalized names). """ name: str @@ -103,6 +103,10 @@ class AstVar(AstTerm): @dataclass(eq=False) class AstNumber(AstTerm): + """ + Represents a numeric constant in AgentSpeak. + """ + value: int | float def _to_agentspeak(self) -> str: @@ -111,6 +115,10 @@ class AstNumber(AstTerm): @dataclass(eq=False) class AstString(AstTerm): + """ + Represents a string literal in AgentSpeak. + """ + value: str def _to_agentspeak(self) -> str: @@ -119,6 +127,10 @@ class AstString(AstTerm): @dataclass(eq=False) class AstLiteral(AstTerm): + """ + Represents a literal (functor and terms) in AgentSpeak. + """ + functor: str terms: list[AstTerm] = field(default_factory=list) @@ -142,6 +154,10 @@ class BinaryOperatorType(StrEnum): @dataclass class AstBinaryOp(AstExpression): + """ + Represents a binary logical or relational operation in AgentSpeak. + """ + left: AstExpression operator: BinaryOperatorType right: AstExpression @@ -167,6 +183,10 @@ class AstBinaryOp(AstExpression): @dataclass class AstLogicalExpression(AstExpression): + """ + Represents a logical expression, potentially negated, in AgentSpeak. + """ + expression: AstExpression negated: bool = False @@ -208,6 +228,10 @@ class AstStatement(AstNode): @dataclass class AstRule(AstNode): + """ + Represents an inference rule in AgentSpeak. If there is no condition, it always holds. + """ + result: AstExpression condition: AstExpression | None = None @@ -231,6 +255,10 @@ class TriggerType(StrEnum): @dataclass class AstPlan(AstNode): + """ + Represents a plan in AgentSpeak, consisting of a trigger, context, and body. + """ + type: TriggerType trigger_literal: AstExpression context: list[AstExpression] @@ -260,6 +288,10 @@ class AstPlan(AstNode): @dataclass class AstProgram(AstNode): + """ + Represents a full AgentSpeak program, consisting of rules and plans. + """ + rules: list[AstRule] = field(default_factory=list) plans: list[AstPlan] = field(default_factory=list) diff --git a/src/control_backend/agents/bdi/agentspeak_generator.py b/src/control_backend/agents/bdi/agentspeak_generator.py index 524f980..2fe12e3 100644 --- a/src/control_backend/agents/bdi/agentspeak_generator.py +++ b/src/control_backend/agents/bdi/agentspeak_generator.py @@ -40,9 +40,23 @@ from control_backend.schemas.program import ( class AgentSpeakGenerator: + """ + Generator class that translates a high-level :class:`~control_backend.schemas.program.Program` + into AgentSpeak(L) source code. + + It handles the conversion of phases, norms, goals, and triggers into AgentSpeak rules and plans, + ensuring the robot follows the defined behavioral logic. + """ + _asp: AstProgram def generate(self, program: Program) -> str: + """ + Translates a Program object into an AgentSpeak source string. + + :param program: The behavior program to translate. + :return: The generated AgentSpeak code as a string. + """ self._asp = AstProgram() if program.phases: diff --git a/src/control_backend/agents/bdi/text_belief_extractor_agent.py b/src/control_backend/agents/bdi/text_belief_extractor_agent.py index b5fd266..362dfbf 100644 --- a/src/control_backend/agents/bdi/text_belief_extractor_agent.py +++ b/src/control_backend/agents/bdi/text_belief_extractor_agent.py @@ -18,6 +18,12 @@ type JSONLike = None | bool | int | float | str | list["JSONLike"] | dict[str, " class BeliefState(BaseModel): + """ + Represents the state of inferred semantic beliefs. + + Maintains sets of beliefs that are currently considered true or false. + """ + true: set[InternalBelief] = set() false: set[InternalBelief] = set() @@ -338,7 +344,7 @@ class TextBeliefExtractorAgent(BaseAgent): class SemanticBeliefInferrer: """ - Class that handles only prompting an LLM for semantic beliefs. + Infers semantic beliefs from conversation history using an LLM. """ def __init__( @@ -464,6 +470,10 @@ Respond with a JSON similar to the following, but with the property names as giv class GoalAchievementInferrer(SemanticBeliefInferrer): + """ + Infers whether specific conversational goals have been achieved using an LLM. + """ + def __init__(self, llm: TextBeliefExtractorAgent.LLM): super().__init__(llm) self.goals: set[BaseGoal] = set() diff --git a/src/control_backend/agents/communication/__init__.py b/src/control_backend/agents/communication/__init__.py index 2aa1535..3dde6cf 100644 --- a/src/control_backend/agents/communication/__init__.py +++ b/src/control_backend/agents/communication/__init__.py @@ -1 +1,5 @@ +""" +Agents responsible for external communication and service discovery. +""" + from .ri_communication_agent import RICommunicationAgent as RICommunicationAgent diff --git a/src/control_backend/agents/llm/__init__.py b/src/control_backend/agents/llm/__init__.py index e12ff29..519812f 100644 --- a/src/control_backend/agents/llm/__init__.py +++ b/src/control_backend/agents/llm/__init__.py @@ -1 +1,5 @@ +""" +Agents that interface with Large Language Models for natural language processing and generation. +""" + from .llm_agent import LLMAgent as LLMAgent diff --git a/src/control_backend/agents/perception/__init__.py b/src/control_backend/agents/perception/__init__.py index e18361a..5a46671 100644 --- a/src/control_backend/agents/perception/__init__.py +++ b/src/control_backend/agents/perception/__init__.py @@ -1,3 +1,8 @@ +""" +Agents responsible for processing sensory input, such as audio transcription and voice activity +detection. +""" + from .transcription_agent.transcription_agent import ( TranscriptionAgent as TranscriptionAgent, ) diff --git a/src/control_backend/agents/perception/transcription_agent/transcription_agent.py b/src/control_backend/agents/perception/transcription_agent/transcription_agent.py index 765d7ac..795623d 100644 --- a/src/control_backend/agents/perception/transcription_agent/transcription_agent.py +++ b/src/control_backend/agents/perception/transcription_agent/transcription_agent.py @@ -74,7 +74,7 @@ class TranscriptionAgent(BaseAgent): def _connect_audio_in_socket(self): """ - Helper to connect the ZMQ SUB socket for audio input. + Connects the ZMQ SUB socket for receiving audio data. """ self.audio_in_socket = azmq.Context.instance().socket(zmq.SUB) self.audio_in_socket.setsockopt_string(zmq.SUBSCRIBE, "") diff --git a/src/control_backend/agents/user_interrupt/user_interrupt_agent.py b/src/control_backend/agents/user_interrupt/user_interrupt_agent.py index cf72ce5..6a4c9b0 100644 --- a/src/control_backend/agents/user_interrupt/user_interrupt_agent.py +++ b/src/control_backend/agents/user_interrupt/user_interrupt_agent.py @@ -50,10 +50,8 @@ class UserInterruptAgent(BaseAgent): async def setup(self): """ - Initialize the agent. - - Connects the internal ZMQ SUB socket and subscribes to the 'button_pressed' topic. - Starts the background behavior to receive the user interrupts. + Initialize the agent by setting up ZMQ sockets for receiving button events and + publishing updates. """ context = Context.instance() @@ -68,18 +66,15 @@ class UserInterruptAgent(BaseAgent): async def _receive_button_event(self): """ - The behaviour of the UserInterruptAgent. - Continuous loop that receives button_pressed events from the button_pressed HTTP endpoint. - These events contain a type and a context. + Main loop to receive and process button press events from the UI. - These are the different types and contexts: - - type: "speech", context: string that the robot has to say. - - type: "gesture", context: single gesture name that the robot has to perform. - - type: "override", context: id that belongs to the goal/trigger/conditional norm. - - type: "override_unachieve", context: id that belongs to the conditional norm to unachieve. - - type: "next_phase", context: None, indicates to the BDI Core to - - type: "pause", context: boolean indicating whether to pause - - type: "reset_phase", context: None, indicates to the BDI Core to + Handles different event types: + - `speech`: Triggers immediate robot speech. + - `gesture`: Triggers an immediate robot gesture. + - `override`: Forces a belief, trigger, or goal completion in the BDI core. + - `override_unachieve`: Removes a belief from the BDI core. + - `pause`: Toggles the system's pause state. + - `next_phase` / `reset_phase`: Controls experiment flow. """ while True: topic, body = await self.sub_socket.recv_multipart() @@ -172,7 +167,10 @@ class UserInterruptAgent(BaseAgent): async def handle_message(self, msg: InternalMessage): """ - Handle commands received from other internal Python agents. + Handles internal messages from other agents, such as program updates or trigger + notifications. + + :param msg: The incoming :class:`~control_backend.core.agent_system.InternalMessage`. """ match msg.thread: case "new_program": @@ -217,8 +215,9 @@ class UserInterruptAgent(BaseAgent): async def _broadcast_cond_norms(self, active_slugs: list[str]): """ - Sends the current state of all conditional norms to the UI. - :param active_slugs: A list of slugs (strings) currently active in the BDI core. + Broadcasts the current activation state of all conditional norms to the UI. + + :param active_slugs: A list of sluggified norm names currently active in the BDI core. """ updates = [] for asl_slug, ui_id in self._cond_norm_reverse_map.items(): @@ -235,7 +234,9 @@ class UserInterruptAgent(BaseAgent): def _create_mapping(self, program_json: str): """ - Create mappings between UI IDs and ASL slugs for triggers, goals, and conditional norms + Creates a bidirectional mapping between UI identifiers and AgentSpeak slugs. + + :param program_json: The JSON representation of the behavioral program. """ try: program = Program.model_validate_json(program_json) @@ -277,8 +278,10 @@ class UserInterruptAgent(BaseAgent): async def _send_experiment_update(self, data, should_log: bool = True): """ - Sends an update to the 'experiment' topic. - The SSE endpoint will pick this up and push it to the UI. + Publishes an experiment state update to the internal ZMQ bus for the UI. + + :param data: The update payload. + :param should_log: Whether to log the update. """ if self.pub_socket: topic = b"experiment" diff --git a/src/control_backend/api/v1/endpoints/sse.py b/src/control_backend/api/v1/endpoints/sse.py deleted file mode 100644 index c660aa5..0000000 --- a/src/control_backend/api/v1/endpoints/sse.py +++ /dev/null @@ -1,12 +0,0 @@ -from fastapi import APIRouter, Request - -router = APIRouter() - - -# TODO: implement -@router.get("/sse") -async def sse(request: Request): - """ - Placeholder for future Server-Sent Events endpoint. - """ - pass diff --git a/src/control_backend/core/agent_system.py b/src/control_backend/core/agent_system.py index e3c8dc4..267f072 100644 --- a/src/control_backend/core/agent_system.py +++ b/src/control_backend/core/agent_system.py @@ -22,10 +22,22 @@ class AgentDirectory: @staticmethod def register(name: str, agent: "BaseAgent"): + """ + Registers an agent instance with a unique name. + + :param name: The name of the agent. + :param agent: The :class:`BaseAgent` instance. + """ _agent_directory[name] = agent @staticmethod def get(name: str) -> "BaseAgent | None": + """ + Retrieves a registered agent instance by name. + + :param name: The name of the agent to retrieve. + :return: The :class:`BaseAgent` instance, or None if not found. + """ return _agent_directory.get(name) diff --git a/src/control_backend/schemas/belief_list.py b/src/control_backend/schemas/belief_list.py index f3d6818..841a4ed 100644 --- a/src/control_backend/schemas/belief_list.py +++ b/src/control_backend/schemas/belief_list.py @@ -16,4 +16,10 @@ class BeliefList(BaseModel): class GoalList(BaseModel): + """ + Represents a list of goals, used for communicating multiple goals between agents. + + :ivar goals: The list of goals. + """ + goals: list[BaseGoal] diff --git a/src/control_backend/schemas/chat_history.py b/src/control_backend/schemas/chat_history.py index 52fc224..8fd1e72 100644 --- a/src/control_backend/schemas/chat_history.py +++ b/src/control_backend/schemas/chat_history.py @@ -2,9 +2,22 @@ from pydantic import BaseModel class ChatMessage(BaseModel): + """ + Represents a single message in a conversation. + + :ivar role: The role of the speaker (e.g., 'user', 'assistant'). + :ivar content: The text content of the message. + """ + role: str content: str class ChatHistory(BaseModel): + """ + Represents a sequence of chat messages, forming a conversation history. + + :ivar messages: An ordered list of :class:`ChatMessage` objects. + """ + messages: list[ChatMessage] diff --git a/src/control_backend/schemas/events.py b/src/control_backend/schemas/events.py index 46967f7..a01b668 100644 --- a/src/control_backend/schemas/events.py +++ b/src/control_backend/schemas/events.py @@ -2,5 +2,13 @@ from pydantic import BaseModel class ButtonPressedEvent(BaseModel): + """ + Represents a button press event from the UI. + + :ivar type: The type of event (e.g., 'speech', 'gesture', 'override'). + :ivar context: Additional data associated with the event (e.g., speech text, gesture name, + or ID). + """ + type: str context: str diff --git a/src/control_backend/schemas/program.py b/src/control_backend/schemas/program.py index d04abbb..283e17d 100644 --- a/src/control_backend/schemas/program.py +++ b/src/control_backend/schemas/program.py @@ -20,6 +20,10 @@ class ProgramElement(BaseModel): class LogicalOperator(Enum): + """ + Logical operators for combining beliefs. + """ + AND = "AND" OR = "OR" @@ -30,9 +34,9 @@ type BasicBelief = KeywordBelief | SemanticBelief class KeywordBelief(ProgramElement): """ - Represents a belief that is set when the user spoken text contains a certain keyword. + Represents a belief that is activated when a specific keyword is detected in the user's speech. - :ivar keyword: The keyword on which this belief gets set. + :ivar keyword: The string to look for in the transcription. """ name: str = "" @@ -41,9 +45,11 @@ class KeywordBelief(ProgramElement): class SemanticBelief(ProgramElement): """ - Represents a belief that is set by semantic LLM validation. + Represents a belief whose truth value is determined by an LLM analyzing the conversation + context. - :ivar description: Description of how to form the belief, used by the LLM. + :ivar description: A natural language description of what this belief represents, + used as a prompt for the LLM. """ description: str @@ -51,13 +57,11 @@ class SemanticBelief(ProgramElement): class InferredBelief(ProgramElement): """ - Represents a belief that gets formed by combining two beliefs with a logical AND or OR. + Represents a belief derived from other beliefs using logical operators. - These beliefs can also be :class:`InferredBelief`, leading to arbitrarily deep nesting. - - :ivar operator: The logical operator to apply. - :ivar left: The left part of the logical expression. - :ivar right: The right part of the logical expression. + :ivar operator: The :class:`LogicalOperator` (AND/OR) to apply. + :ivar left: The left operand (another belief). + :ivar right: The right operand (another belief). """ name: str = "" @@ -67,6 +71,13 @@ class InferredBelief(ProgramElement): class Norm(ProgramElement): + """ + Base class for behavioral norms that guide the robot's interactions. + + :ivar norm: The textual description of the norm. + :ivar critical: Whether this norm is considered critical and should be strictly enforced. + """ + name: str = "" norm: str critical: bool = False @@ -74,10 +85,7 @@ class Norm(ProgramElement): class BasicNorm(Norm): """ - Represents a behavioral norm. - - :ivar norm: The actual norm text describing the behavior. - :ivar critical: When true, this norm should absolutely not be violated (checked separately). + A simple behavioral norm that is always considered for activation when its phase is active. """ pass @@ -85,9 +93,9 @@ class BasicNorm(Norm): class ConditionalNorm(Norm): """ - Represents a norm that is only active when a condition is met (i.e., a certain belief holds). + A behavioral norm that is only active when a specific condition (belief) is met. - :ivar condition: When to activate this norm. + :ivar condition: The :class:`Belief` that must hold for this norm to be active. """ condition: Belief @@ -140,9 +148,9 @@ type Action = SpeechAction | GestureAction | LLMAction class SpeechAction(ProgramElement): """ - Represents the action of the robot speaking a literal text. + An action where the robot speaks a predefined literal text. - :ivar text: The text to speak. + :ivar text: The text content to be spoken. """ name: str = "" @@ -151,11 +159,10 @@ class SpeechAction(ProgramElement): class Gesture(BaseModel): """ - Represents a gesture to be performed. Can be either a single gesture, - or a random gesture from a category (tag). + Defines a physical gesture for the robot to perform. - :ivar type: The type of the gesture, "tag" or "single". - :ivar name: The name of the single gesture or tag. + :ivar type: Whether to use a specific "single" gesture or a random one from a "tag" category. + :ivar name: The identifier for the gesture or tag. """ type: Literal["tag", "single"] @@ -164,9 +171,9 @@ class Gesture(BaseModel): class GestureAction(ProgramElement): """ - Represents the action of the robot performing a gesture. + An action where the robot performs a physical gesture. - :ivar gesture: The gesture to perform. + :ivar gesture: The :class:`Gesture` definition. """ name: str = "" @@ -175,10 +182,9 @@ class GestureAction(ProgramElement): class LLMAction(ProgramElement): """ - Represents the action of letting an LLM generate a reply based on its chat history - and an additional goal added in the prompt. + An action that triggers an LLM-generated conversational response. - :ivar goal: The extra (temporary) goal to add to the LLM. + :ivar goal: A temporary conversational goal to guide the LLM's response generation. """ name: str = "" @@ -187,10 +193,10 @@ class LLMAction(ProgramElement): class Trigger(ProgramElement): """ - Represents a belief-based trigger. When a belief is set, the corresponding plan is executed. + Defines a reactive behavior: when the condition (belief) is met, the plan is executed. - :ivar condition: When to activate the trigger. - :ivar plan: The plan to execute. + :ivar condition: The :class:`Belief` that triggers this behavior. + :ivar plan: The :class:`Plan` to execute upon activation. """ condition: Belief @@ -199,11 +205,11 @@ class Trigger(ProgramElement): class Phase(ProgramElement): """ - A distinct phase within a program, containing norms, goals, and triggers. + A logical stage in the interaction program, grouping norms, goals, and triggers. - :ivar norms: List of norms active in this phase. - :ivar goals: List of goals to pursue in this phase. - :ivar triggers: List of triggers that define transitions out of this phase. + :ivar norms: List of norms active during this phase. + :ivar goals: List of goals the robot pursues in this phase. + :ivar triggers: List of reactive behaviors defined for this phase. """ name: str = "" @@ -214,9 +220,9 @@ class Phase(ProgramElement): class Program(BaseModel): """ - Represents a complete interaction program, consisting of a sequence or set of phases. + The top-level container for a complete robot behavior definition. - :ivar phases: The list of phases that make up the program. + :ivar phases: An ordered list of :class:`Phase` objects defining the interaction flow. """ phases: list[Phase] From db64eaeb0b03e683e23950a13d8b4a271f17136b Mon Sep 17 00:00:00 2001 From: Kasper Marinus Date: Fri, 16 Jan 2026 16:18:36 +0100 Subject: [PATCH 7/7] fix: failing tests and warnings ref: N25B-449 --- pyproject.toml | 1 + .../user_interrupt/user_interrupt_agent.py | 4 +++- src/control_backend/api/v1/router.py | 4 +--- .../perception/vad_agent/test_vad_agent.py | 7 +++--- ...st_vad_agent.py => test_vad_agent_unit.py} | 0 .../user_interrupt/test_user_interrupt.py | 4 ++-- test/unit/api/v1/endpoints/test_router.py | 1 - .../api/v1/endpoints/test_sse_endpoint.py | 24 ------------------- uv.lock | 4 +++- 9 files changed, 14 insertions(+), 35 deletions(-) rename test/unit/agents/perception/vad_agent/{test_vad_agent.py => test_vad_agent_unit.py} (100%) delete mode 100644 test/unit/api/v1/endpoints/test_sse_endpoint.py diff --git a/pyproject.toml b/pyproject.toml index cdc2ce3..5de7daa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,6 +48,7 @@ test = [ "pytest-asyncio>=1.2.0", "pytest-cov>=7.0.0", "pytest-mock>=3.15.1", + "python-slugify>=8.0.4", "pyyaml>=6.0.3", "pyzmq>=27.1.0", "soundfile>=0.13.1", diff --git a/src/control_backend/agents/user_interrupt/user_interrupt_agent.py b/src/control_backend/agents/user_interrupt/user_interrupt_agent.py index 6a4c9b0..a42861a 100644 --- a/src/control_backend/agents/user_interrupt/user_interrupt_agent.py +++ b/src/control_backend/agents/user_interrupt/user_interrupt_agent.py @@ -202,7 +202,7 @@ class UserInterruptAgent(BaseAgent): goal_name = msg.body ui_id = self._goal_reverse_map.get(goal_name) if ui_id: - payload = {"type": "goal_update", "id": ui_id} + payload = {"type": "goal_update", "id": ui_id, "active": True} await self._send_experiment_update(payload) self.logger.info(f"UI Update: Goal {goal_name} started (ID: {ui_id})") case "active_norms_update": @@ -361,6 +361,8 @@ class UserInterruptAgent(BaseAgent): thread = "force_next_phase" case "reset_phase": thread = "reset_current_phase" + case "reset_experiment": + thread = "reset_experiment" case _: self.logger.warning( "Received unknown experiment control type '%s' to send to BDI Core.", diff --git a/src/control_backend/api/v1/router.py b/src/control_backend/api/v1/router.py index c130ad3..b46df5f 100644 --- a/src/control_backend/api/v1/router.py +++ b/src/control_backend/api/v1/router.py @@ -1,13 +1,11 @@ from fastapi.routing import APIRouter -from control_backend.api.v1.endpoints import logs, message, program, robot, sse, user_interact +from control_backend.api.v1.endpoints import logs, message, program, robot, user_interact api_router = APIRouter() api_router.include_router(message.router, tags=["Messages"]) -api_router.include_router(sse.router, tags=["SSE"]) - api_router.include_router(robot.router, prefix="/robot", tags=["Pings", "Commands"]) api_router.include_router(logs.router, tags=["Logs"]) diff --git a/test/integration/agents/perception/vad_agent/test_vad_agent.py b/test/integration/agents/perception/vad_agent/test_vad_agent.py index 668d1ce..3cde755 100644 --- a/test/integration/agents/perception/vad_agent/test_vad_agent.py +++ b/test/integration/agents/perception/vad_agent/test_vad_agent.py @@ -40,7 +40,7 @@ async def test_normal_setup(per_transcription_agent): per_vad_agent = VADAgent("tcp://localhost:12345", False) per_vad_agent._streaming_loop = AsyncMock() - async def swallow_background_task(coro): + def swallow_background_task(coro): coro.close() per_vad_agent.add_behavior = swallow_background_task @@ -106,7 +106,7 @@ async def test_out_socket_creation_failure(zmq_context): per_vad_agent._streaming_loop = AsyncMock() per_vad_agent._connect_audio_out_socket = MagicMock(return_value=None) - async def swallow_background_task(coro): + def swallow_background_task(coro): coro.close() per_vad_agent.add_behavior = swallow_background_task @@ -126,7 +126,7 @@ async def test_stop(zmq_context, per_transcription_agent): per_vad_agent._reset_stream = AsyncMock() per_vad_agent._streaming_loop = AsyncMock() - async def swallow_background_task(coro): + def swallow_background_task(coro): coro.close() per_vad_agent.add_behavior = swallow_background_task @@ -150,6 +150,7 @@ async def test_application_startup_complete(zmq_context): vad_agent._running = True vad_agent._reset_stream = AsyncMock() vad_agent.program_sub_socket = AsyncMock() + vad_agent.program_sub_socket.close = MagicMock() vad_agent.program_sub_socket.recv_multipart.side_effect = [ (PROGRAM_STATUS, ProgramStatus.RUNNING.value), ] diff --git a/test/unit/agents/perception/vad_agent/test_vad_agent.py b/test/unit/agents/perception/vad_agent/test_vad_agent_unit.py similarity index 100% rename from test/unit/agents/perception/vad_agent/test_vad_agent.py rename to test/unit/agents/perception/vad_agent/test_vad_agent_unit.py diff --git a/test/unit/agents/user_interrupt/test_user_interrupt.py b/test/unit/agents/user_interrupt/test_user_interrupt.py index 7c38a05..7a71891 100644 --- a/test/unit/agents/user_interrupt/test_user_interrupt.py +++ b/test/unit/agents/user_interrupt/test_user_interrupt.py @@ -63,7 +63,7 @@ async def test_send_to_bdi_belief(agent): """Verify belief update format.""" context_str = "some_goal" - await agent._send_to_bdi_belief(context_str) + await agent._send_to_bdi_belief(context_str, "goal") assert agent.send.await_count == 1 sent_msg = agent.send.call_args.args[0] @@ -115,7 +115,7 @@ async def test_receive_loop_routing_success(agent): agent._send_to_gesture_agent.assert_awaited_once_with("Hello Gesture") # Override (since we mapped it to a goal) - agent._send_to_bdi_belief.assert_awaited_once_with("some_goal_slug") + agent._send_to_bdi_belief.assert_awaited_once_with("some_goal_slug", "goal") assert agent._send_to_speech_agent.await_count == 1 assert agent._send_to_gesture_agent.await_count == 1 diff --git a/test/unit/api/v1/endpoints/test_router.py b/test/unit/api/v1/endpoints/test_router.py index 7303d9c..dd93d8d 100644 --- a/test/unit/api/v1/endpoints/test_router.py +++ b/test/unit/api/v1/endpoints/test_router.py @@ -11,6 +11,5 @@ def test_router_includes_expected_paths(): # Ensure at least one route under each prefix exists assert any(p.startswith("/robot") for p in paths) assert any(p.startswith("/message") for p in paths) - assert any(p.startswith("/sse") for p in paths) assert any(p.startswith("/logs") for p in paths) assert any(p.startswith("/program") for p in paths) diff --git a/test/unit/api/v1/endpoints/test_sse_endpoint.py b/test/unit/api/v1/endpoints/test_sse_endpoint.py deleted file mode 100644 index 75a4555..0000000 --- a/test/unit/api/v1/endpoints/test_sse_endpoint.py +++ /dev/null @@ -1,24 +0,0 @@ -import pytest -from fastapi import FastAPI -from fastapi.testclient import TestClient - -from control_backend.api.v1.endpoints import sse - - -@pytest.fixture -def app(): - app = FastAPI() - app.include_router(sse.router) - return app - - -@pytest.fixture -def client(app): - return TestClient(app) - - -def test_sse_route_exists(client): - """Minimal smoke test to ensure /sse route exists and responds.""" - response = client.get("/sse") - # Since implementation is not done, we only assert it doesn't crash - assert response.status_code == 200 diff --git a/uv.lock b/uv.lock index ce46ceb..ea39c17 100644 --- a/uv.lock +++ b/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 3 +revision = 2 requires-python = ">=3.13" resolution-markers = [ "python_full_version >= '3.14'", @@ -1030,6 +1030,7 @@ test = [ { name = "pytest-asyncio" }, { name = "pytest-cov" }, { name = "pytest-mock" }, + { name = "python-slugify" }, { name = "pyyaml" }, { name = "pyzmq" }, { name = "soundfile" }, @@ -1080,6 +1081,7 @@ test = [ { name = "pytest-asyncio", specifier = ">=1.2.0" }, { name = "pytest-cov", specifier = ">=7.0.0" }, { name = "pytest-mock", specifier = ">=3.15.1" }, + { name = "python-slugify", specifier = ">=8.0.4" }, { name = "pyyaml", specifier = ">=6.0.3" }, { name = "pyzmq", specifier = ">=27.1.0" }, { name = "soundfile", specifier = ">=0.13.1" },