diff --git a/src/control_backend/agents/bdi/__init__.py b/src/control_backend/agents/bdi/__init__.py index 8d45440..d6f5124 100644 --- a/src/control_backend/agents/bdi/__init__.py +++ b/src/control_backend/agents/bdi/__init__.py @@ -1,8 +1,5 @@ from control_backend.agents.bdi.bdi_core_agent import BDICoreAgent as BDICoreAgent -from .belief_collector_agent import ( - BDIBeliefCollectorAgent as BDIBeliefCollectorAgent, -) from .text_belief_extractor_agent import ( TextBeliefExtractorAgent as TextBeliefExtractorAgent, ) diff --git a/src/control_backend/agents/bdi/agentspeak_ast.py b/src/control_backend/agents/bdi/agentspeak_ast.py index 188b4f3..68be531 100644 --- a/src/control_backend/agents/bdi/agentspeak_ast.py +++ b/src/control_backend/agents/bdi/agentspeak_ast.py @@ -77,7 +77,7 @@ class AstTerm(AstExpression, ABC): return AstBinaryOp(self, BinaryOperatorType.NOT_EQUALS, _coalesce_expr(other)) -@dataclass +@dataclass(eq=False) class AstAtom(AstTerm): """ Grounded expression in all lowercase. @@ -89,7 +89,7 @@ class AstAtom(AstTerm): return self.value.lower() -@dataclass +@dataclass(eq=False) class AstVar(AstTerm): """ Ungrounded variable expression. First letter capitalized. @@ -101,7 +101,7 @@ class AstVar(AstTerm): return self.name.capitalize() -@dataclass +@dataclass(eq=False) class AstNumber(AstTerm): value: int | float @@ -109,7 +109,7 @@ class AstNumber(AstTerm): return str(self.value) -@dataclass +@dataclass(eq=False) class AstString(AstTerm): value: str @@ -117,7 +117,7 @@ class AstString(AstTerm): return f'"{self.value}"' -@dataclass +@dataclass(eq=False) class AstLiteral(AstTerm): functor: str terms: list[AstTerm] = field(default_factory=list) diff --git a/src/control_backend/agents/bdi/belief_collector_agent.py b/src/control_backend/agents/bdi/belief_collector_agent.py deleted file mode 100644 index ac0e2e5..0000000 --- a/src/control_backend/agents/bdi/belief_collector_agent.py +++ /dev/null @@ -1,152 +0,0 @@ -import json - -from pydantic import ValidationError - -from control_backend.agents.base import BaseAgent -from control_backend.core.agent_system import InternalMessage -from control_backend.core.config import settings -from control_backend.schemas.belief_message import Belief, BeliefMessage - - -class BDIBeliefCollectorAgent(BaseAgent): - """ - BDI Belief Collector Agent. - - This agent acts as a central aggregator for beliefs derived from various sources (e.g., text, - emotion, vision). It receives raw extracted data from other agents, - normalizes them into valid :class:`Belief` objects, and forwards them as a unified packet to the - BDI Core Agent. - - It serves as a funnel to ensure the BDI agent receives a consistent stream of beliefs. - """ - - async def setup(self): - """ - Initialize the agent. - """ - self.logger.info("Setting up %s", self.name) - - async def handle_message(self, msg: InternalMessage): - """ - Handle incoming messages from other extractor agents. - - Routes the message to specific handlers based on the 'type' field in the JSON body. - Supported types: - - ``belief_extraction_text``: Handled by :meth:`_handle_belief_text` - - ``emotion_extraction_text``: Handled by :meth:`_handle_emo_text` - - :param msg: The received internal message. - """ - sender_node = msg.sender - - # Parse JSON payload - try: - payload = json.loads(msg.body) - except Exception as e: - self.logger.warning( - "BeliefCollector: failed to parse JSON from %s. Body=%r Error=%s", - sender_node, - msg.body, - e, - ) - return - - msg_type = payload.get("type") - - # Prefer explicit 'type' field - if msg_type == "belief_extraction_text": - self.logger.debug("Message routed to _handle_belief_text (sender=%s)", sender_node) - await self._handle_belief_text(payload, sender_node) - # This is not implemented yet, but we keep the structure for future use - elif msg_type == "emotion_extraction_text": - self.logger.debug("Message routed to _handle_emo_text (sender=%s)", sender_node) - await self._handle_emo_text(payload, sender_node) - else: - self.logger.warning( - "Unrecognized message (sender=%s, type=%r). Ignoring.", sender_node, msg_type - ) - - async def _handle_belief_text(self, payload: dict, origin: str): - """ - Process text-based belief extraction payloads. - - Expected payload format:: - - { - "type": "belief_extraction_text", - "beliefs": { - "user_said": ["Can you help me?"], - "intention": ["ask_help"] - } - } - - Validates and converts the dictionary items into :class:`Belief` objects. - - :param payload: The dictionary payload containing belief data. - :param origin: The name of the sender agent. - """ - beliefs = payload.get("beliefs", {}) - - if not beliefs: - self.logger.debug("Received empty beliefs set.") - return - - def try_create_belief(name, arguments) -> Belief | None: - """ - Create a belief object from name and arguments, or return None silently if the input is - not correct. - - :param name: The name of the belief. - :param arguments: The arguments of the belief. - :return: A Belief object if the input is valid or None. - """ - try: - return Belief(name=name, arguments=arguments) - except ValidationError: - return None - - beliefs = [ - belief - for name, arguments in beliefs.items() - if (belief := try_create_belief(name, arguments)) is not None - ] - - self.logger.debug("Forwarding %d beliefs.", len(beliefs)) - for belief in beliefs: - for argument in belief.arguments: - self.logger.debug(" - %s %s", belief.name, argument) - - await self._send_beliefs_to_bdi(beliefs, origin=origin) - - async def _handle_emo_text(self, payload: dict, origin: str): - """ - Process emotion extraction payloads. - - **TODO**: Implement this method once emotion recognition is integrated. - - :param payload: The dictionary payload containing emotion data. - :param origin: The name of the sender agent. - """ - pass - - async def _send_beliefs_to_bdi(self, beliefs: list[Belief], origin: str | None = None): - """ - Send a list of aggregated beliefs to the BDI Core Agent. - - Wraps the beliefs in a :class:`BeliefMessage` and sends it via the 'beliefs' thread. - - :param beliefs: The list of Belief objects to send. - :param origin: (Optional) The original source of the beliefs (unused currently). - """ - if not beliefs: - return - - msg = InternalMessage( - to=settings.agent_settings.bdi_core_name, - sender=self.name, - body=BeliefMessage(create=beliefs).model_dump_json(), - thread="beliefs", - ) - - await self.send(msg) - self.logger.info("Sent %d belief(s) to BDI core.", len(beliefs)) diff --git a/src/control_backend/agents/communication/ri_communication_agent.py b/src/control_backend/agents/communication/ri_communication_agent.py index 719053c..252502d 100644 --- a/src/control_backend/agents/communication/ri_communication_agent.py +++ b/src/control_backend/agents/communication/ri_communication_agent.py @@ -324,7 +324,7 @@ class RICommunicationAgent(BaseAgent): async def handle_message(self, msg: InternalMessage): try: pause_command = PauseCommand.model_validate_json(msg.body) - self._req_socket.send_json(pause_command.model_dump()) - self.logger.debug(self._req_socket.recv_json()) + await self._req_socket.send_json(pause_command.model_dump()) + self.logger.debug(await self._req_socket.recv_json()) except ValidationError: self.logger.warning("Incorrect message format for PauseCommand.") diff --git a/src/control_backend/main.py b/src/control_backend/main.py index d20cc66..ec93b1e 100644 --- a/src/control_backend/main.py +++ b/src/control_backend/main.py @@ -172,6 +172,8 @@ async def lifespan(app: FastAPI): await endpoints_pub_socket.send_multipart([PROGRAM_STATUS, ProgramStatus.STOPPING.value]) # Additional shutdown logic goes here + for agent in agents: + await agent.stop() logger.info("Application shutdown complete.") diff --git a/test/unit/agents/actuation/test_robot_gesture_agent.py b/test/unit/agents/actuation/test_robot_gesture_agent.py index fe051a6..225278d 100644 --- a/test/unit/agents/actuation/test_robot_gesture_agent.py +++ b/test/unit/agents/actuation/test_robot_gesture_agent.py @@ -28,7 +28,11 @@ async def test_setup_bind(zmq_context, mocker): settings = mocker.patch("control_backend.agents.actuation.robot_gesture_agent.settings") settings.zmq_settings.internal_sub_address = "tcp://internal:1234" - agent.add_behavior = MagicMock() + def close_coro(coro): + coro.close() + return MagicMock() + + agent.add_behavior = MagicMock(side_effect=close_coro) await agent.setup() @@ -55,7 +59,11 @@ async def test_setup_connect(zmq_context, mocker): settings = mocker.patch("control_backend.agents.actuation.robot_gesture_agent.settings") settings.zmq_settings.internal_sub_address = "tcp://internal:1234" - agent.add_behavior = MagicMock() + def close_coro(coro): + coro.close() + return MagicMock() + + agent.add_behavior = MagicMock(side_effect=close_coro) await agent.setup() @@ -119,6 +127,65 @@ async def test_handle_message_rejects_invalid_gesture_tag(): pubsocket.send_json.assert_not_awaited() +@pytest.mark.asyncio +async def test_handle_message_sends_valid_single_gesture_command(): + """Internal message with valid single gesture is forwarded.""" + pubsocket = AsyncMock() + agent = RobotGestureAgent("robot_gesture", single_gesture_data=["wave", "point"], address="") + agent.pubsocket = pubsocket + + payload = { + "endpoint": RIEndpoint.GESTURE_SINGLE, + "data": "wave", + } + msg = InternalMessage(to="robot", sender="tester", body=json.dumps(payload)) + + await agent.handle_message(msg) + + pubsocket.send_json.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_handle_message_rejects_invalid_single_gesture(): + """Internal message with invalid single gesture is not forwarded.""" + pubsocket = AsyncMock() + agent = RobotGestureAgent("robot_gesture", single_gesture_data=["wave", "point"], address="") + agent.pubsocket = pubsocket + + payload = { + "endpoint": RIEndpoint.GESTURE_SINGLE, + "data": "dance", + } + msg = InternalMessage(to="robot", sender="tester", body=json.dumps(payload)) + + await agent.handle_message(msg) + + pubsocket.send_json.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_zmq_command_loop_valid_single_gesture_payload(): + """UI command with valid single gesture is read from SUB and published.""" + command = {"endpoint": RIEndpoint.GESTURE_SINGLE, "data": "wave"} + fake_socket = AsyncMock() + + async def recv_once(): + agent._running = False + return b"command", json.dumps(command).encode("utf-8") + + fake_socket.recv_multipart = recv_once + fake_socket.send_json = AsyncMock() + + agent = RobotGestureAgent("robot_gesture", single_gesture_data=["wave", "point"], address="") + agent.subsocket = fake_socket + agent.pubsocket = fake_socket + agent._running = True + + await agent._zmq_command_loop() + + fake_socket.send_json.assert_awaited_once() + + @pytest.mark.asyncio async def test_handle_message_invalid_payload(): """Invalid payload is caught and does not send.""" diff --git a/test/unit/agents/actuation/test_robot_speech_agent.py b/test/unit/agents/actuation/test_robot_speech_agent.py index d95f66a..e5a664d 100644 --- a/test/unit/agents/actuation/test_robot_speech_agent.py +++ b/test/unit/agents/actuation/test_robot_speech_agent.py @@ -30,7 +30,11 @@ async def test_setup_bind(zmq_context, mocker): settings = mocker.patch("control_backend.agents.actuation.robot_speech_agent.settings") settings.zmq_settings.internal_sub_address = "tcp://internal:1234" - agent.add_behavior = MagicMock() + def close_coro(coro): + coro.close() + return MagicMock() + + agent.add_behavior = MagicMock(side_effect=close_coro) await agent.setup() @@ -48,7 +52,11 @@ async def test_setup_connect(zmq_context, mocker): settings = mocker.patch("control_backend.agents.actuation.robot_speech_agent.settings") settings.zmq_settings.internal_sub_address = "tcp://internal:1234" - agent.add_behavior = MagicMock() + def close_coro(coro): + coro.close() + return MagicMock() + + agent.add_behavior = MagicMock(side_effect=close_coro) await agent.setup() diff --git a/test/unit/agents/bdi/test_agentspeak_ast.py b/test/unit/agents/bdi/test_agentspeak_ast.py new file mode 100644 index 0000000..8d3bdf0 --- /dev/null +++ b/test/unit/agents/bdi/test_agentspeak_ast.py @@ -0,0 +1,186 @@ +import pytest + +from control_backend.agents.bdi.agentspeak_ast import ( + AstAtom, + AstBinaryOp, + AstLiteral, + AstLogicalExpression, + AstNumber, + AstPlan, + AstProgram, + AstRule, + AstStatement, + AstString, + AstVar, + BinaryOperatorType, + StatementType, + TriggerType, + _coalesce_expr, +) + + +def test_ast_atom(): + atom = AstAtom("test") + assert str(atom) == "test" + assert atom._to_agentspeak() == "test" + + +def test_ast_var(): + var = AstVar("Variable") + assert str(var) == "Variable" + assert var._to_agentspeak() == "Variable" + + +def test_ast_number(): + num = AstNumber(42) + assert str(num) == "42" + num_float = AstNumber(3.14) + assert str(num_float) == "3.14" + + +def test_ast_string(): + s = AstString("hello") + assert str(s) == '"hello"' + + +def test_ast_literal(): + lit = AstLiteral("functor", [AstAtom("atom"), AstNumber(1)]) + assert str(lit) == "functor(atom, 1)" + lit_empty = AstLiteral("functor") + assert str(lit_empty) == "functor" + + +def test_ast_binary_op(): + left = AstNumber(1) + right = AstNumber(2) + op = AstBinaryOp(left, BinaryOperatorType.GREATER_THAN, right) + assert str(op) == "1 > 2" + + # Test logical wrapper + assert isinstance(op.left, AstLogicalExpression) + assert isinstance(op.right, AstLogicalExpression) + + +def test_ast_binary_op_parens(): + # 1 > 2 + inner = AstBinaryOp(AstNumber(1), BinaryOperatorType.GREATER_THAN, AstNumber(2)) + # (1 > 2) & 3 + outer = AstBinaryOp(inner, BinaryOperatorType.AND, AstNumber(3)) + assert str(outer) == "(1 > 2) & 3" + + # 3 & (1 > 2) + outer_right = AstBinaryOp(AstNumber(3), BinaryOperatorType.AND, inner) + assert str(outer_right) == "3 & (1 > 2)" + + +def test_ast_binary_op_parens_negated(): + inner = AstLogicalExpression(AstAtom("foo"), negated=True) + outer = AstBinaryOp(inner, BinaryOperatorType.AND, AstAtom("bar")) + # The current implementation checks `if self.left.negated: l_str = f"({l_str})"` + # str(inner) is "not foo" + # so we expect "(not foo) & bar" + assert str(outer) == "(not foo) & bar" + + outer_right = AstBinaryOp(AstAtom("bar"), BinaryOperatorType.AND, inner) + assert str(outer_right) == "bar & (not foo)" + + +def test_ast_logical_expression_negation(): + expr = AstLogicalExpression(AstAtom("true"), negated=True) + assert str(expr) == "not true" + + expr_neg_neg = ~expr + assert str(expr_neg_neg) == "true" + assert not expr_neg_neg.negated + + # Invert a non-logical expression (wraps it) + term = AstAtom("true") + inverted = ~term + assert isinstance(inverted, AstLogicalExpression) + assert inverted.negated + assert str(inverted) == "not true" + + +def test_ast_logical_expression_no_negation(): + # _as_logical on already logical expression + expr = AstLogicalExpression(AstAtom("x")) + # Doing binary op will call _as_logical + op = AstBinaryOp(expr, BinaryOperatorType.AND, AstAtom("y")) + assert isinstance(op.left, AstLogicalExpression) + assert op.left is expr # Should reuse instance + + +def test_ast_operators(): + t1 = AstAtom("a") + t2 = AstAtom("b") + + assert str(t1 & t2) == "a & b" + assert str(t1 | t2) == "a | b" + assert str(t1 >= t2) == "a >= b" + assert str(t1 > t2) == "a > b" + assert str(t1 <= t2) == "a <= b" + assert str(t1 < t2) == "a < b" + assert str(t1 == t2) == "a == b" + assert str(t1 != t2) == r"a \== b" + + +def test_coalesce_expr(): + t = AstAtom("a") + assert str(t & "b") == 'a & "b"' + assert str(t & 1) == "a & 1" + assert str(t & 1.5) == "a & 1.5" + + with pytest.raises(TypeError): + _coalesce_expr(None) + + +def test_ast_statement(): + stmt = AstStatement(StatementType.DO_ACTION, AstLiteral("action")) + assert str(stmt) == ".action" + + +def test_ast_rule(): + # Rule with condition + rule = AstRule(AstLiteral("head"), AstLiteral("body")) + assert str(rule) == "head :- body." + + # Rule without condition + rule_simple = AstRule(AstLiteral("fact")) + assert str(rule_simple) == "fact." + + +def test_ast_plan(): + plan = AstPlan( + TriggerType.ADDED_GOAL, + AstLiteral("goal"), + [AstLiteral("context")], + [AstStatement(StatementType.DO_ACTION, AstLiteral("action"))], + ) + output = str(plan) + # verify parts exist + assert "+!goal" in output + assert ": context" in output + assert "<- .action." in output + + +def test_ast_plan_no_context(): + plan = AstPlan( + TriggerType.ADDED_GOAL, + AstLiteral("goal"), + [], + [AstStatement(StatementType.DO_ACTION, AstLiteral("action"))], + ) + output = str(plan) + assert "+!goal" in output + assert ": " not in output + assert "<- .action." in output + + +def test_ast_program(): + prog = AstProgram( + rules=[AstRule(AstLiteral("fact"))], + plans=[AstPlan(TriggerType.ADDED_BELIEF, AstLiteral("b"), [], [])], + ) + output = str(prog) + assert "fact." in output + assert "+b" in output diff --git a/test/unit/agents/bdi/test_agentspeak_generator.py b/test/unit/agents/bdi/test_agentspeak_generator.py new file mode 100644 index 0000000..5a3a849 --- /dev/null +++ b/test/unit/agents/bdi/test_agentspeak_generator.py @@ -0,0 +1,187 @@ +import uuid + +import pytest + +from control_backend.agents.bdi.agentspeak_ast import AstProgram +from control_backend.agents.bdi.agentspeak_generator import AgentSpeakGenerator +from control_backend.schemas.program import ( + BasicNorm, + ConditionalNorm, + Gesture, + GestureAction, + Goal, + InferredBelief, + KeywordBelief, + LLMAction, + LogicalOperator, + Phase, + Plan, + Program, + SemanticBelief, + SpeechAction, + Trigger, +) + + +@pytest.fixture +def generator(): + return AgentSpeakGenerator() + + +def test_generate_empty_program(generator): + prog = Program(phases=[]) + code = generator.generate(prog) + assert 'phase("end").' in code + assert "!notify_cycle" in code + + +def test_generate_basic_norm(generator): + norm = BasicNorm(id=uuid.uuid4(), name="n1", norm="be nice") + phase = Phase(id=uuid.uuid4(), norms=[norm], goals=[], triggers=[]) + prog = Program(phases=[phase]) + + code = generator.generate(prog) + assert f'norm("be nice") :- phase("{phase.id}").' in code + + +def test_generate_critical_norm(generator): + norm = BasicNorm(id=uuid.uuid4(), name="n1", norm="safety", critical=True) + phase = Phase(id=uuid.uuid4(), norms=[norm], goals=[], triggers=[]) + prog = Program(phases=[phase]) + + code = generator.generate(prog) + assert f'critical_norm("safety") :- phase("{phase.id}").' in code + + +def test_generate_conditional_norm(generator): + cond = KeywordBelief(id=uuid.uuid4(), name="k1", keyword="please") + norm = ConditionalNorm(id=uuid.uuid4(), name="n1", norm="help", condition=cond) + phase = Phase(id=uuid.uuid4(), norms=[norm], goals=[], triggers=[]) + prog = Program(phases=[phase]) + + code = generator.generate(prog) + assert 'norm("help")' in code + assert 'keyword_said("please")' in code + assert f"force_norm_{generator._slugify_str(norm.norm)}" in code + + +def test_generate_goal_and_plan(generator): + action = SpeechAction(id=uuid.uuid4(), name="s1", text="hello") + plan = Plan(id=uuid.uuid4(), name="p1", steps=[action]) + # IMPORTANT: can_fail must be False for +achieved_ belief to be added + goal = Goal(id=uuid.uuid4(), name="g1", description="desc", plan=plan, can_fail=False) + phase = Phase(id=uuid.uuid4(), norms=[], goals=[goal], triggers=[]) + prog = Program(phases=[phase]) + + code = generator.generate(prog) + # Check trigger for goal + goal_slug = generator._slugify_str(goal.name) + assert f"+!{goal_slug}" in code + assert f'phase("{phase.id}")' in code + assert '!say("hello")' in code + + # Check success belief addition + assert f"+achieved_{goal_slug}" in code + + +def test_generate_subgoal(generator): + subplan = Plan(id=uuid.uuid4(), name="p2", steps=[]) + subgoal = Goal(id=uuid.uuid4(), name="sub1", description="sub", plan=subplan) + + plan = Plan(id=uuid.uuid4(), name="p1", steps=[subgoal]) + goal = Goal(id=uuid.uuid4(), name="g1", description="main", plan=plan) + phase = Phase(id=uuid.uuid4(), norms=[], goals=[goal], triggers=[]) + prog = Program(phases=[phase]) + + code = generator.generate(prog) + subgoal_slug = generator._slugify_str(subgoal.name) + # Main goal calls subgoal + assert f"!{subgoal_slug}" in code + # Subgoal plan exists + assert f"+!{subgoal_slug}" in code + + +def test_generate_trigger(generator): + cond = SemanticBelief(id=uuid.uuid4(), name="s1", description="desc") + plan = Plan(id=uuid.uuid4(), name="p1", steps=[]) + trigger = Trigger(id=uuid.uuid4(), name="t1", condition=cond, plan=plan) + phase = Phase(id=uuid.uuid4(), norms=[], goals=[], triggers=[trigger]) + prog = Program(phases=[phase]) + + code = generator.generate(prog) + # Trigger logic is added to check_triggers + assert f"{generator.slugify(cond)}" in code + assert f'notify_trigger_start("{generator.slugify(trigger)}")' in code + assert f'notify_trigger_end("{generator.slugify(trigger)}")' in code + + +def test_phase_transition(generator): + phase1 = Phase(id=uuid.uuid4(), name="p1", norms=[], goals=[], triggers=[]) + phase2 = Phase(id=uuid.uuid4(), name="p2", norms=[], goals=[], triggers=[]) + prog = Program(phases=[phase1, phase2]) + + code = generator.generate(prog) + assert "transition_phase" in code + assert f'phase("{phase1.id}")' in code + assert f'phase("{phase2.id}")' in code + assert "force_transition_phase" in code + + +def test_astify_gesture(generator): + gesture = Gesture(type="single", name="wave") + action = GestureAction(id=uuid.uuid4(), name="g1", gesture=gesture) + ast = generator._astify(action) + assert str(ast) == 'gesture("single", "wave")' + + +def test_astify_llm_action(generator): + action = LLMAction(id=uuid.uuid4(), name="l1", goal="be funny") + ast = generator._astify(action) + assert str(ast) == 'reply_with_goal("be funny")' + + +def test_astify_inferred_belief_and(generator): + left = KeywordBelief(id=uuid.uuid4(), name="k1", keyword="a") + right = KeywordBelief(id=uuid.uuid4(), name="k2", keyword="b") + inf = InferredBelief( + id=uuid.uuid4(), name="i1", operator=LogicalOperator.AND, left=left, right=right + ) + + ast = generator._astify(inf) + assert 'keyword_said("a") & keyword_said("b")' == str(ast) + + +def test_astify_inferred_belief_or(generator): + left = KeywordBelief(id=uuid.uuid4(), name="k1", keyword="a") + right = KeywordBelief(id=uuid.uuid4(), name="k2", keyword="b") + inf = InferredBelief( + id=uuid.uuid4(), name="i1", operator=LogicalOperator.OR, left=left, right=right + ) + + ast = generator._astify(inf) + assert 'keyword_said("a") | keyword_said("b")' == str(ast) + + +def test_astify_semantic_belief(generator): + sb = SemanticBelief(id=uuid.uuid4(), name="s1", description="desc") + ast = generator._astify(sb) + assert str(ast) == f"semantic_{generator._slugify_str(sb.name)}" + + +def test_slugify_not_implemented(generator): + with pytest.raises(NotImplementedError): + generator.slugify("not a program element") + + +def test_astify_not_implemented(generator): + with pytest.raises(NotImplementedError): + generator._astify("not a program element") + + +def test_process_phase_transition_from_none(generator): + # Initialize AstProgram manually as we are bypassing generate() + generator._asp = AstProgram() + # Should safely return doing nothing + generator._add_phase_transition(None, None) + + assert len(generator._asp.plans) == 0 diff --git a/test/unit/agents/bdi/test_bdi_core_agent.py b/test/unit/agents/bdi/test_bdi_core_agent.py index 64f2ca7..152d901 100644 --- a/test/unit/agents/bdi/test_bdi_core_agent.py +++ b/test/unit/agents/bdi/test_bdi_core_agent.py @@ -57,11 +57,22 @@ async def test_handle_belief_collector_message(agent, mock_settings): await agent.handle_message(msg) - # Expect bdi_agent.call to be triggered to add belief - args = agent.bdi_agent.call.call_args.args - assert args[0] == agentspeak.Trigger.addition - assert args[1] == agentspeak.GoalType.belief - assert args[2] == agentspeak.Literal("user_said", (agentspeak.Literal("Hello"),)) + # Check for the specific call we expect among all calls + # bdi_agent.call is called multiple times (for transition_phase, check_triggers) + # We want to confirm the belief addition call exists + found_call = False + for call in agent.bdi_agent.call.call_args_list: + args = call.args + if ( + args[0] == agentspeak.Trigger.addition + and args[1] == agentspeak.GoalType.belief + and args[2].functor == "user_said" + and args[2].args[0].functor == "Hello" + ): + found_call = True + break + + assert found_call, "Expected belief addition call not found in bdi_agent.call history" @pytest.mark.asyncio @@ -77,11 +88,19 @@ async def test_handle_delete_belief_message(agent, mock_settings): ) await agent.handle_message(msg) - # Expect bdi_agent.call to be triggered to remove belief - args = agent.bdi_agent.call.call_args.args - assert args[0] == agentspeak.Trigger.removal - assert args[1] == agentspeak.GoalType.belief - assert args[2] == agentspeak.Literal("user_said", (agentspeak.Literal("Hello"),)) + found_call = False + for call in agent.bdi_agent.call.call_args_list: + args = call.args + if ( + args[0] == agentspeak.Trigger.removal + and args[1] == agentspeak.GoalType.belief + and args[2].functor == "user_said" + and args[2].args[0].functor == "Hello" + ): + found_call = True + break + + assert found_call @pytest.mark.asyncio @@ -171,7 +190,11 @@ def test_remove_belief_success_wakes_loop(agent): agent._remove_belief("remove_me", ["x"]) assert agent.bdi_agent.call.called - trigger, goaltype, literal, *_ = agent.bdi_agent.call.call_args.args + + call_args = agent.bdi_agent.call.call_args.args + trigger = call_args[0] + goaltype = call_args[1] + literal = call_args[2] assert trigger == agentspeak.Trigger.removal assert goaltype == agentspeak.GoalType.belief @@ -288,3 +311,216 @@ async def test_deadline_sleep_branch(agent): duration = time.time() - start_time assert duration >= 0.004 # loop slept until deadline + + +@pytest.mark.asyncio +async def test_handle_new_program(agent): + agent._load_asl = AsyncMock() + agent.add_behavior = MagicMock() + # Mock existing loop task so it can be cancelled + mock_task = MagicMock() + mock_task.cancel = MagicMock() + agent._bdi_loop_task = mock_task + + def close_coro(coro): + coro.close() + return MagicMock() + + agent.add_behavior = MagicMock(side_effect=close_coro) + + msg = InternalMessage(to="bdi_agent", thread="new_program", body="path/to/asl.asl") + + await agent.handle_message(msg) + + mock_task.cancel.assert_called_once() + agent._load_asl.assert_awaited_once_with("path/to/asl.asl") + agent.add_behavior.assert_called() + + +@pytest.mark.asyncio +async def test_handle_user_interrupts(agent, mock_settings): + mock_settings.agent_settings.user_interrupt_name = "user_interrupt_agent" + # force_phase_transition + agent._set_goal = MagicMock() + msg = InternalMessage( + to="bdi_agent", + sender=mock_settings.agent_settings.user_interrupt_name, + thread="force_phase_transition", + body="", + ) + await agent.handle_message(msg) + agent._set_goal.assert_called_with("transition_phase") + + # force_trigger + agent._force_trigger = MagicMock() + msg.thread = "force_trigger" + msg.body = "trigger_x" + await agent.handle_message(msg) + agent._force_trigger.assert_called_with("trigger_x") + + # force_norm + agent._force_norm = MagicMock() + msg.thread = "force_norm" + msg.body = "norm_y" + await agent.handle_message(msg) + agent._force_norm.assert_called_with("norm_y") + + # force_next_phase + agent._force_next_phase = MagicMock() + msg.thread = "force_next_phase" + msg.body = "" + await agent.handle_message(msg) + agent._force_next_phase.assert_called_once() + + # unknown interrupt + agent.logger = MagicMock() + msg.thread = "unknown_thing" + await agent.handle_message(msg) + agent.logger.warning.assert_called() + + +@pytest.mark.asyncio +async def test_custom_action_reply_with_goal(agent): + agent._send_to_llm = MagicMock(side_effect=agent.send) + agent._add_custom_actions() + action_fn = agent.actions.actions[(".reply_with_goal", 3)] + + mock_term = MagicMock(args=["msg", "norms", "goal"]) + gen = action_fn(agent, mock_term, MagicMock()) + next(gen) + agent._send_to_llm.assert_called_with("msg", "norms", "goal") + + +@pytest.mark.asyncio +async def test_custom_action_notify_norms(agent): + agent._add_custom_actions() + action_fn = agent.actions.actions[(".notify_norms", 1)] + + mock_term = MagicMock(args=["norms_list"]) + gen = action_fn(agent, mock_term, MagicMock()) + next(gen) + + agent.send.assert_called() + msg = agent.send.call_args[0][0] + assert msg.thread == "active_norms_update" + assert msg.body == "norms_list" + + +@pytest.mark.asyncio +async def test_custom_action_say(agent): + agent._add_custom_actions() + action_fn = agent.actions.actions[(".say", 1)] + + mock_term = MagicMock(args=["hello"]) + gen = action_fn(agent, mock_term, MagicMock()) + next(gen) + + assert agent.send.call_count == 2 + msgs = [c[0][0] for c in agent.send.call_args_list] + assert any(m.to == settings.agent_settings.robot_speech_name for m in msgs) + assert any( + m.to == settings.agent_settings.llm_name and m.thread == "assistant_message" for m in msgs + ) + + +@pytest.mark.asyncio +async def test_custom_action_gesture(agent): + agent._add_custom_actions() + # Test single + action_fn = agent.actions.actions[(".gesture", 2)] + mock_term = MagicMock(args=["single", "wave"]) + gen = action_fn(agent, mock_term, MagicMock()) + next(gen) + msg = agent.send.call_args[0][0] + assert "actuate/gesture/single" in msg.body + + # Test tag + mock_term.args = ["tag", "happy"] + gen = action_fn(agent, mock_term, MagicMock()) + next(gen) + msg = agent.send.call_args[0][0] + assert "actuate/gesture/tag" in msg.body + + +@pytest.mark.asyncio +async def test_custom_action_notify_user_said(agent): + agent._add_custom_actions() + action_fn = agent.actions.actions[(".notify_user_said", 1)] + mock_term = MagicMock(args=["hello"]) + gen = action_fn(agent, mock_term, MagicMock()) + next(gen) + msg = agent.send.call_args[0][0] + assert msg.to == settings.agent_settings.llm_name + assert msg.thread == "user_message" + + +@pytest.mark.asyncio +async def test_custom_action_notify_trigger_start_end(agent): + agent._add_custom_actions() + # Start + action_fn = agent.actions.actions[(".notify_trigger_start", 1)] + gen = action_fn(agent, MagicMock(args=["t1"]), MagicMock()) + next(gen) + assert agent.send.call_args[0][0].thread == "trigger_start" + + # End + action_fn = agent.actions.actions[(".notify_trigger_end", 1)] + gen = action_fn(agent, MagicMock(args=["t1"]), MagicMock()) + next(gen) + assert agent.send.call_args[0][0].thread == "trigger_end" + + +@pytest.mark.asyncio +async def test_custom_action_notify_goal_start(agent): + agent._add_custom_actions() + action_fn = agent.actions.actions[(".notify_goal_start", 1)] + gen = action_fn(agent, MagicMock(args=["g1"]), MagicMock()) + next(gen) + assert agent.send.call_args[0][0].thread == "goal_start" + + +@pytest.mark.asyncio +async def test_custom_action_notify_transition_phase(agent): + agent._add_custom_actions() + action_fn = agent.actions.actions[(".notify_transition_phase", 2)] + gen = action_fn(agent, MagicMock(args=["old", "new"]), MagicMock()) + next(gen) + msg = agent.send.call_args[0][0] + assert msg.thread == "transition_phase" + assert "old" in msg.body and "new" in msg.body + + +def test_remove_belief_no_args(agent): + agent._wake_bdi_loop = MagicMock() + agent.bdi_agent.call.return_value = True + agent._remove_belief("fact", None) + assert agent.bdi_agent.call.called + + +def test_set_goal_with_args(agent): + agent._wake_bdi_loop = MagicMock() + agent._set_goal("goal", ["arg1", "arg2"]) + assert agent.bdi_agent.call.called + + +def test_format_belief_string(): + assert BDICoreAgent.format_belief_string("b") == "b" + assert BDICoreAgent.format_belief_string("b", ["a1", "a2"]) == "b(a1,a2)" + + +def test_force_norm(agent): + agent._add_belief = MagicMock() + agent._force_norm("be_polite") + agent._add_belief.assert_called_with("force_be_polite") + + +def test_force_trigger(agent): + agent._set_goal = MagicMock() + agent._force_trigger("trig") + agent._set_goal.assert_called_with("trig") + + +def test_force_next_phase(agent): + agent._set_goal = MagicMock() + agent._force_next_phase() + agent._set_goal.assert_called_with("force_transition_phase") diff --git a/test/unit/agents/bdi/test_bdi_program_manager.py b/test/unit/agents/bdi/test_bdi_program_manager.py index 2bed2a7..540a172 100644 --- a/test/unit/agents/bdi/test_bdi_program_manager.py +++ b/test/unit/agents/bdi/test_bdi_program_manager.py @@ -1,13 +1,13 @@ import asyncio +import json import sys import uuid -from unittest.mock import AsyncMock +from unittest.mock import AsyncMock, MagicMock, mock_open, patch import pytest from control_backend.agents.bdi.bdi_program_manager import BDIProgramManager from control_backend.core.agent_system import InternalMessage -from control_backend.schemas.belief_message import BeliefMessage from control_backend.schemas.program import BasicNorm, Goal, Phase, Plan, Program # Fix Windows Proactor loop for zmq @@ -48,24 +48,26 @@ def make_valid_program_json(norm="N1", goal="G1") -> str: ).model_dump_json() -@pytest.mark.skip(reason="Functionality being rebuilt.") @pytest.mark.asyncio -async def test_send_to_bdi(): +async def test_create_agentspeak_and_send_to_bdi(mock_settings): manager = BDIProgramManager(name="program_manager_test") manager.send = AsyncMock() program = Program.model_validate_json(make_valid_program_json()) - await manager._create_agentspeak_and_send_to_bdi(program) + + with patch("builtins.open", mock_open()) as mock_file: + await manager._create_agentspeak_and_send_to_bdi(program) + + # Check file writing + mock_file.assert_called_with("src/control_backend/agents/bdi/agentspeak.asl", "w") + handle = mock_file() + handle.write.assert_called() assert manager.send.await_count == 1 msg: InternalMessage = manager.send.await_args[0][0] - assert msg.thread == "beliefs" - - beliefs = BeliefMessage.model_validate_json(msg.body) - names = {b.name: b.arguments for b in beliefs.beliefs} - - assert "norms" in names and names["norms"] == ["N1"] - assert "goals" in names and names["goals"] == ["G1"] + assert msg.thread == "new_program" + assert msg.to == mock_settings.agent_settings.bdi_core_name + assert msg.body == "src/control_backend/agents/bdi/agentspeak.asl" @pytest.mark.asyncio @@ -81,6 +83,9 @@ async def test_receive_programs_valid_and_invalid(): manager.sub_socket = sub manager._create_agentspeak_and_send_to_bdi = AsyncMock() manager._send_clear_llm_history = AsyncMock() + manager._send_program_to_user_interrupt = AsyncMock() + manager._send_beliefs_to_semantic_belief_extractor = AsyncMock() + manager._send_goals_to_semantic_belief_extractor = AsyncMock() try: # Will give StopAsyncIteration when the predefined `sub.recv_multipart` side-effects run out @@ -94,10 +99,9 @@ async def test_receive_programs_valid_and_invalid(): assert forwarded.phases[0].norms[0].name == "N1" assert forwarded.phases[0].goals[0].name == "G1" - # Verify history clear was triggered - assert ( - manager._send_clear_llm_history.await_count == 2 - ) # first sends program to UserInterrupt, then clears LLM + # Verify history clear was triggered exactly once (for the valid program) + # The invalid program loop `continue`s before calling _send_clear_llm_history + assert manager._send_clear_llm_history.await_count == 1 @pytest.mark.asyncio @@ -115,4 +119,179 @@ async def test_send_clear_llm_history(mock_settings): # Verify the content and recipient assert msg.body == "clear_history" - assert msg.to == "llm_agent" + + +@pytest.mark.asyncio +async def test_handle_message_transition_phase(mock_settings): + mock_settings.agent_settings.user_interrupt_name = "user_interrupt_agent" + manager = BDIProgramManager(name="program_manager_test") + manager.send = AsyncMock() + + # Setup state + prog = Program.model_validate_json(make_valid_program_json(norm="N1", goal="G1")) + manager._initialize_internal_state(prog) + + # Test valid transition (to same phase for simplicity, or we need 2 phases) + # Let's create a program with 2 phases + phase2_id = uuid.uuid4() + phase2 = Phase(id=phase2_id, name="Phase 2", norms=[], goals=[], triggers=[]) + prog.phases.append(phase2) + manager._initialize_internal_state(prog) + + current_phase_id = str(prog.phases[0].id) + next_phase_id = str(phase2_id) + + payload = json.dumps({"old": current_phase_id, "new": next_phase_id}) + msg = InternalMessage(to="me", sender="bdi", body=payload, thread="transition_phase") + + await manager.handle_message(msg) + + assert str(manager._phase.id) == next_phase_id + + # Allow background tasks to run (add_behavior) + await asyncio.sleep(0) + + # Check notifications sent + # 1. beliefs to extractor + # 2. goals to extractor + # 3. notification to user interrupt + + assert manager.send.await_count >= 3 + + # Verify user interrupt notification + calls = manager.send.await_args_list + ui_msgs = [ + c[0][0] for c in calls if c[0][0].to == mock_settings.agent_settings.user_interrupt_name + ] + assert len(ui_msgs) > 0 + assert ui_msgs[-1].body == next_phase_id + + +@pytest.mark.asyncio +async def test_handle_message_transition_phase_desync(): + manager = BDIProgramManager(name="program_manager_test") + manager.logger = MagicMock() + + prog = Program.model_validate_json(make_valid_program_json()) + manager._initialize_internal_state(prog) + + current_phase_id = str(prog.phases[0].id) + + # Request transition from WRONG old phase + payload = json.dumps({"old": "wrong_id", "new": "some_new_id"}) + msg = InternalMessage(to="me", sender="bdi", body=payload, thread="transition_phase") + + await manager.handle_message(msg) + + # Should warn and do nothing + manager.logger.warning.assert_called_once() + assert "Phase transition desync detected" in manager.logger.warning.call_args[0][0] + assert str(manager._phase.id) == current_phase_id + + +@pytest.mark.asyncio +async def test_handle_message_transition_phase_end(mock_settings): + mock_settings.agent_settings.user_interrupt_name = "user_interrupt_agent" + manager = BDIProgramManager(name="program_manager_test") + manager.send = AsyncMock() + + prog = Program.model_validate_json(make_valid_program_json()) + manager._initialize_internal_state(prog) + current_phase_id = str(prog.phases[0].id) + + payload = json.dumps({"old": current_phase_id, "new": "end"}) + msg = InternalMessage(to="me", sender="bdi", body=payload, thread="transition_phase") + + await manager.handle_message(msg) + + assert manager._phase is None + + # Allow background tasks to run (add_behavior) + await asyncio.sleep(0) + + # Verify notification to user interrupt + assert manager.send.await_count == 1 + msg_sent = manager.send.await_args[0][0] + assert msg_sent.to == mock_settings.agent_settings.user_interrupt_name + assert msg_sent.body == "end" + + +@pytest.mark.asyncio +async def test_handle_message_achieve_goal(mock_settings): + mock_settings.agent_settings.text_belief_extractor_name = "text_belief_extractor_agent" + manager = BDIProgramManager(name="program_manager_test") + manager.send = AsyncMock() + + prog = Program.model_validate_json(make_valid_program_json(goal="TargetGoal")) + manager._initialize_internal_state(prog) + + goal_id = str(prog.phases[0].goals[0].id) + + msg = InternalMessage(to="me", sender="ui", body=goal_id, thread="achieve_goal") + + await manager.handle_message(msg) + + # Should send achieved goals to text extractor + assert manager.send.await_count == 1 + msg_sent = manager.send.await_args[0][0] + assert msg_sent.to == mock_settings.agent_settings.text_belief_extractor_name + assert msg_sent.thread == "achieved_goals" + + # Verify body + from control_backend.schemas.belief_list import GoalList + + gl = GoalList.model_validate_json(msg_sent.body) + assert len(gl.goals) == 1 + assert gl.goals[0].name == "TargetGoal" + + +@pytest.mark.asyncio +async def test_handle_message_achieve_goal_not_found(): + manager = BDIProgramManager(name="program_manager_test") + manager.send = AsyncMock() + manager.logger = MagicMock() + + prog = Program.model_validate_json(make_valid_program_json()) + manager._initialize_internal_state(prog) + + msg = InternalMessage(to="me", sender="ui", body="non_existent_id", thread="achieve_goal") + + await manager.handle_message(msg) + + manager.send.assert_not_called() + manager.logger.debug.assert_called() + + +@pytest.mark.asyncio +async def test_setup(mock_settings): + manager = BDIProgramManager(name="program_manager_test") + manager.send = AsyncMock() + + def close_coro(coro): + coro.close() + return MagicMock() + + manager.add_behavior = MagicMock(side_effect=close_coro) + + mock_context = MagicMock() + mock_sub = MagicMock() + mock_context.socket.return_value = mock_sub + + with patch( + "control_backend.agents.bdi.bdi_program_manager.Context.instance", return_value=mock_context + ): + # We also need to mock file writing in _create_agentspeak_and_send_to_bdi + with patch("builtins.open", new_callable=MagicMock): + await manager.setup() + + # Check logic + # 1. Sends default empty program to BDI + assert manager.send.await_count == 1 + assert manager.send.await_args[0][0].to == mock_settings.agent_settings.bdi_core_name + + # 2. Connects SUB socket + mock_sub.connect.assert_called_with(mock_settings.zmq_settings.internal_sub_address) + mock_sub.subscribe.assert_called_with("program") + + # 3. Adds behavior + manager.add_behavior.assert_called() diff --git a/test/unit/agents/bdi/test_belief_collector.py b/test/unit/agents/bdi/test_belief_collector.py deleted file mode 100644 index 69db269..0000000 --- a/test/unit/agents/bdi/test_belief_collector.py +++ /dev/null @@ -1,135 +0,0 @@ -import json -from unittest.mock import AsyncMock - -import pytest - -from control_backend.agents.bdi import ( - BDIBeliefCollectorAgent, -) -from control_backend.core.agent_system import InternalMessage -from control_backend.core.config import settings -from control_backend.schemas.belief_message import Belief - - -@pytest.fixture -def agent(): - agent = BDIBeliefCollectorAgent("belief_collector_agent") - return agent - - -def make_msg(body: dict, sender: str = "sender"): - return InternalMessage(to="collector", sender=sender, body=json.dumps(body)) - - -@pytest.mark.asyncio -async def test_handle_message_routes_belief_text(agent, mocker): - """ - Test that when a message is received, _handle_belief_text is called with that message. - """ - payload = {"type": "belief_extraction_text", "beliefs": {"user_said": [["hi"]]}} - spy = mocker.patch.object(agent, "_handle_belief_text", new_callable=AsyncMock) - - await agent.handle_message(make_msg(payload)) - - spy.assert_awaited_once_with(payload, "sender") - - -@pytest.mark.asyncio -async def test_handle_message_routes_emotion(agent, mocker): - payload = {"type": "emotion_extraction_text"} - spy = mocker.patch.object(agent, "_handle_emo_text", new_callable=AsyncMock) - - await agent.handle_message(make_msg(payload)) - - spy.assert_awaited_once_with(payload, "sender") - - -@pytest.mark.asyncio -async def test_handle_message_bad_json(agent, mocker): - agent._handle_belief_text = AsyncMock() - bad_msg = InternalMessage(to="collector", sender="sender", body="not json") - - await agent.handle_message(bad_msg) - - agent._handle_belief_text.assert_not_awaited() - - -@pytest.mark.asyncio -async def test_handle_belief_text_sends_when_beliefs_exist(agent, mocker): - payload = {"type": "belief_extraction_text", "beliefs": {"user_said": ["hello"]}} - spy = mocker.patch.object(agent, "_send_beliefs_to_bdi", new_callable=AsyncMock) - expected = [Belief(name="user_said", arguments=["hello"])] - - await agent._handle_belief_text(payload, "origin") - - spy.assert_awaited_once_with(expected, origin="origin") - - -@pytest.mark.asyncio -async def test_handle_belief_text_no_send_when_empty(agent, mocker): - payload = {"type": "belief_extraction_text", "beliefs": {}} - spy = mocker.patch.object(agent, "_send_beliefs_to_bdi", new_callable=AsyncMock) - - await agent._handle_belief_text(payload, "origin") - - spy.assert_not_awaited() - - -@pytest.mark.asyncio -async def test_send_beliefs_to_bdi(agent): - agent.send = AsyncMock() - beliefs = [Belief(name="user_said", arguments=["hello", "world"])] - - await agent._send_beliefs_to_bdi(beliefs, origin="origin") - - agent.send.assert_awaited_once() - sent: InternalMessage = agent.send.call_args.args[0] - assert sent.to == settings.agent_settings.bdi_core_name - assert sent.thread == "beliefs" - assert json.loads(sent.body)["create"] == [belief.model_dump() for belief in beliefs] - - -@pytest.mark.asyncio -async def test_setup_executes(agent): - """Covers setup and asserts the agent has a name.""" - await agent.setup() - assert agent.name == "belief_collector_agent" # simple property assertion - - -@pytest.mark.asyncio -async def test_handle_message_unrecognized_type_executes(agent): - """Covers the else branch for unrecognized message type.""" - payload = {"type": "unknown_type"} - msg = make_msg(payload, sender="tester") - # Wrap send to ensure nothing is sent - agent.send = AsyncMock() - await agent.handle_message(msg) - # Assert no messages were sent - agent.send.assert_not_awaited() - - -@pytest.mark.asyncio -async def test_handle_emo_text_executes(agent): - """Covers the _handle_emo_text method.""" - # The method does nothing, but we can assert it returns None - result = await agent._handle_emo_text({}, "origin") - assert result is None - - -@pytest.mark.asyncio -async def test_send_beliefs_to_bdi_empty_executes(agent): - """Covers early return when beliefs are empty.""" - agent.send = AsyncMock() - await agent._send_beliefs_to_bdi({}) - # Assert that nothing was sent - agent.send.assert_not_awaited() - - -@pytest.mark.asyncio -async def test_handle_belief_text_invalid_returns_none(agent, mocker): - payload = {"type": "belief_extraction_text", "beliefs": {"user_said": "invalid-argument"}} - - result = await agent._handle_belief_text(payload, "origin") - - # The method itself returns None - assert result is None diff --git a/test/unit/agents/bdi/test_text_belief_extractor.py b/test/unit/agents/bdi/test_text_belief_extractor.py index 6782ba1..0d7dc00 100644 --- a/test/unit/agents/bdi/test_text_belief_extractor.py +++ b/test/unit/agents/bdi/test_text_belief_extractor.py @@ -14,6 +14,7 @@ from control_backend.schemas.belief_message import Belief as InternalBelief from control_backend.schemas.belief_message import BeliefMessage from control_backend.schemas.chat_history import ChatHistory, ChatMessage from control_backend.schemas.program import ( + BaseGoal, # Changed from Goal ConditionalNorm, KeywordBelief, LLMAction, @@ -28,7 +29,8 @@ from control_backend.schemas.program import ( @pytest.fixture def llm(): llm = TextBeliefExtractorAgent.LLM(MagicMock(), 4) - llm._query_llm = AsyncMock() + # We must ensure _query_llm returns a dictionary so iterating it doesn't fail + llm._query_llm = AsyncMock(return_value={}) return llm @@ -374,3 +376,155 @@ async def test_llm_failure_handling(agent, llm, sample_program): assert len(belief_changes.true) == 0 assert len(belief_changes.false) == 0 + + +def test_belief_state_bool(): + # Empty + bs = BeliefState() + assert not bs + + # True set + bs_true = BeliefState(true={InternalBelief(name="a", arguments=None)}) + assert bs_true + + # False set + bs_false = BeliefState(false={InternalBelief(name="a", arguments=None)}) + assert bs_false + + +@pytest.mark.asyncio +async def test_handle_beliefs_message_validation_error(agent, mock_settings): + # Invalid JSON + mock_settings.agent_settings.bdi_program_manager_name = "bdi_program_manager_agent" + msg = InternalMessage( + to="me", + sender=mock_settings.agent_settings.bdi_program_manager_name, + thread="beliefs", + body="invalid json", + ) + # Should log warning and return + agent.logger = MagicMock() + await agent.handle_message(msg) + agent.logger.warning.assert_called() + + # Invalid Model + msg.body = json.dumps({"beliefs": [{"invalid": "obj"}]}) + await agent.handle_message(msg) + agent.logger.warning.assert_called() + + +@pytest.mark.asyncio +async def test_handle_goals_message_validation_error(agent, mock_settings): + mock_settings.agent_settings.bdi_program_manager_name = "bdi_program_manager_agent" + msg = InternalMessage( + to="me", + sender=mock_settings.agent_settings.bdi_program_manager_name, + thread="goals", + body="invalid json", + ) + agent.logger = MagicMock() + await agent.handle_message(msg) + agent.logger.warning.assert_called() + + +@pytest.mark.asyncio +async def test_handle_goal_achieved_message_validation_error(agent, mock_settings): + mock_settings.agent_settings.bdi_program_manager_name = "bdi_program_manager_agent" + msg = InternalMessage( + to="me", + sender=mock_settings.agent_settings.bdi_program_manager_name, + thread="achieved_goals", + body="invalid json", + ) + agent.logger = MagicMock() + await agent.handle_message(msg) + agent.logger.warning.assert_called() + + +@pytest.mark.asyncio +async def test_goal_inferrer_infer_from_conversation(agent, llm): + # Setup goals + # Use BaseGoal object as typically received by the extractor + g1 = BaseGoal(id=uuid.uuid4(), name="g1", description="desc", can_fail=True) + + # Use real GoalAchievementInferrer + from control_backend.agents.bdi.text_belief_extractor_agent import GoalAchievementInferrer + + inferrer = GoalAchievementInferrer(llm) + inferrer.goals = {g1} + + # Mock LLM response + llm._query_llm.return_value = True + + completions = await inferrer.infer_from_conversation(ChatHistory(messages=[])) + assert completions + # slugify uses slugify library, hard to predict exact string without it, + # but we can check values + assert list(completions.values())[0] is True + + +def test_apply_conversation_message_limit(agent): + with patch("control_backend.agents.bdi.text_belief_extractor_agent.settings") as mock_s: + mock_s.behaviour_settings.conversation_history_length_limit = 2 + agent.conversation.messages = [] + + agent._apply_conversation_message(ChatMessage(role="user", content="1")) + agent._apply_conversation_message(ChatMessage(role="assistant", content="2")) + agent._apply_conversation_message(ChatMessage(role="user", content="3")) + + assert len(agent.conversation.messages) == 2 + assert agent.conversation.messages[0].content == "2" + assert agent.conversation.messages[1].content == "3" + + +@pytest.mark.asyncio +async def test_handle_program_manager_reset(agent): + with patch("control_backend.agents.bdi.text_belief_extractor_agent.settings") as mock_s: + mock_s.agent_settings.bdi_program_manager_name = "pm" + agent.conversation.messages = [ChatMessage(role="user", content="hi")] + agent.belief_inferrer.available_beliefs = [ + SemanticBelief(id=uuid.uuid4(), name="b", description="d") + ] + + msg = InternalMessage(to="me", sender="pm", thread="conversation_history", body="reset") + await agent.handle_message(msg) + + assert len(agent.conversation.messages) == 0 + assert len(agent.belief_inferrer.available_beliefs) == 0 + + +def test_split_into_chunks(): + from control_backend.agents.bdi.text_belief_extractor_agent import SemanticBeliefInferrer + + items = [1, 2, 3, 4, 5] + chunks = SemanticBeliefInferrer._split_into_chunks(items, 2) + assert len(chunks) == 2 + assert len(chunks[0]) + len(chunks[1]) == 5 + + +@pytest.mark.asyncio +async def test_infer_beliefs_call(agent, llm): + from control_backend.agents.bdi.text_belief_extractor_agent import SemanticBeliefInferrer + + inferrer = SemanticBeliefInferrer(llm) + sb = SemanticBelief(id=uuid.uuid4(), name="is_happy", description="User is happy") + + llm.query = AsyncMock(return_value={"is_happy": True}) + + res = await inferrer._infer_beliefs(ChatHistory(messages=[]), [sb]) + assert res == {"is_happy": True} + llm.query.assert_called_once() + + +@pytest.mark.asyncio +async def test_infer_goal_call(agent, llm): + from control_backend.agents.bdi.text_belief_extractor_agent import GoalAchievementInferrer + + inferrer = GoalAchievementInferrer(llm) + goal = BaseGoal(id=uuid.uuid4(), name="g1", description="d") + + llm.query = AsyncMock(return_value=True) + + res = await inferrer._infer_goal(ChatHistory(messages=[]), goal) + assert res is True + llm.query.assert_called_once() diff --git a/test/unit/agents/communication/test_ri_communication_agent.py b/test/unit/agents/communication/test_ri_communication_agent.py index 06d8766..a678907 100644 --- a/test/unit/agents/communication/test_ri_communication_agent.py +++ b/test/unit/agents/communication/test_ri_communication_agent.py @@ -4,6 +4,8 @@ from unittest.mock import ANY, AsyncMock, MagicMock, patch import pytest from control_backend.agents.communication.ri_communication_agent import RICommunicationAgent +from control_backend.core.agent_system import InternalMessage +from control_backend.schemas.ri_message import PauseCommand, RIEndpoint def speech_agent_path(): @@ -53,7 +55,11 @@ async def test_setup_success_connects_and_starts_robot(zmq_context): MockGesture.return_value.start = AsyncMock() agent = RICommunicationAgent("ri_comm", address="tcp://localhost:5555", bind=False) - agent.add_behavior = MagicMock() + def close_coro(coro): + coro.close() + return MagicMock() + + agent.add_behavior = MagicMock(side_effect=close_coro) await agent.setup() @@ -83,7 +89,11 @@ async def test_setup_binds_when_requested(zmq_context): agent = RICommunicationAgent("ri_comm", address="tcp://localhost:5555", bind=True) - agent.add_behavior = MagicMock() + def close_coro(coro): + coro.close() + return MagicMock() + + agent.add_behavior = MagicMock(side_effect=close_coro) with ( patch(speech_agent_path(), autospec=True) as MockSpeech, @@ -151,6 +161,7 @@ async def test_handle_negotiation_response_updates_req_socket(zmq_context): @pytest.mark.asyncio async def test_handle_disconnection_publishes_and_reconnects(): pub_socket = AsyncMock() + pub_socket.close = MagicMock() agent = RICommunicationAgent("ri_comm") agent.pub_socket = pub_socket agent.connected = True @@ -233,6 +244,25 @@ async def test_handle_negotiation_response_unhandled_id(): ) +@pytest.mark.asyncio +async def test_handle_negotiation_response_audio(zmq_context): + agent = RICommunicationAgent("ri_comm") + + with patch( + "control_backend.agents.communication.ri_communication_agent.VADAgent", autospec=True + ) as MockVAD: + MockVAD.return_value.start = AsyncMock() + + await agent._handle_negotiation_response( + {"data": [{"id": "audio", "port": 7000, "bind": False}]} + ) + + MockVAD.assert_called_once_with( + audio_in_address="tcp://localhost:7000", audio_in_bind=False + ) + MockVAD.return_value.start.assert_awaited_once() + + @pytest.mark.asyncio async def test_stop_closes_sockets(): req = MagicMock() @@ -323,6 +353,7 @@ async def test_listen_loop_generic_exception(): @pytest.mark.asyncio async def test_handle_disconnection_timeout(monkeypatch): pub = AsyncMock() + pub.close = MagicMock() pub.send_multipart = AsyncMock(side_effect=TimeoutError) agent = RICommunicationAgent("ri_comm") @@ -365,3 +396,38 @@ async def test_negotiate_req_socket_none_causes_retry(zmq_context): result = await agent._negotiate_connection(max_retries=1) assert result is False + + +@pytest.mark.asyncio +async def test_handle_message_pause_command(zmq_context): + """Test handle_message with a valid PauseCommand.""" + agent = RICommunicationAgent("ri_comm") + agent._req_socket = AsyncMock() + agent.logger = MagicMock() + + agent._req_socket.recv_json.return_value = {"status": "ok"} + + pause_cmd = PauseCommand(data=True) + msg = InternalMessage(to="ri_comm", sender="user_int", body=pause_cmd.model_dump_json()) + + await agent.handle_message(msg) + + agent._req_socket.send_json.assert_awaited_once() + args = agent._req_socket.send_json.await_args[0][0] + assert args["endpoint"] == RIEndpoint.PAUSE.value + assert args["data"] is True + + +@pytest.mark.asyncio +async def test_handle_message_invalid_pause_command(zmq_context): + """Test handle_message with invalid JSON.""" + agent = RICommunicationAgent("ri_comm") + agent._req_socket = AsyncMock() + agent.logger = MagicMock() + + msg = InternalMessage(to="ri_comm", sender="user_int", body="invalid json") + + await agent.handle_message(msg) + + agent.logger.warning.assert_called_with("Incorrect message format for PauseCommand.") + agent._req_socket.send_json.assert_not_called() diff --git a/test/unit/agents/llm/test_llm_agent.py b/test/unit/agents/llm/test_llm_agent.py index 5fc07f2..a1cc297 100644 --- a/test/unit/agents/llm/test_llm_agent.py +++ b/test/unit/agents/llm/test_llm_agent.py @@ -58,17 +58,20 @@ async def test_llm_processing_success(mock_httpx_client, mock_settings): to="llm_agent", sender=mock_settings.agent_settings.bdi_core_name, body=prompt.model_dump_json(), + thread="prompt_message", # REQUIRED: thread must match handle_message logic ) await agent.handle_message(msg) # Verification # "Hello world." constitutes one sentence/chunk based on punctuation split - # The agent should call send once with the full sentence + # The agent should call send once with the full sentence, PLUS once more for full reply assert agent.send.called - args = agent.send.call_args_list[0][0][0] - assert args.to == mock_settings.agent_settings.bdi_core_name - assert "Hello world." in args.body + + # Check args. We expect at least one call sending "Hello world." + calls = agent.send.call_args_list + bodies = [c[0][0].body for c in calls] + assert any("Hello world." in b for b in bodies) @pytest.mark.asyncio @@ -80,18 +83,23 @@ async def test_llm_processing_errors(mock_httpx_client, mock_settings): to="llm", sender=mock_settings.agent_settings.bdi_core_name, body=prompt.model_dump_json(), + thread="prompt_message", ) - # HTTP Error + # HTTP Error: stream method RAISES exception immediately mock_httpx_client.stream = MagicMock(side_effect=httpx.HTTPError("Fail")) + await agent.handle_message(msg) - assert "LLM service unavailable." in agent.send.call_args[0][0].body + + # Check that error message was sent + assert agent.send.called + assert "LLM service unavailable." in agent.send.call_args_list[0][0][0].body # General Exception agent.send.reset_mock() mock_httpx_client.stream = MagicMock(side_effect=Exception("Boom")) await agent.handle_message(msg) - assert "Error processing the request." in agent.send.call_args[0][0].body + assert "Error processing the request." in agent.send.call_args_list[0][0][0].body @pytest.mark.asyncio @@ -113,16 +121,19 @@ async def test_llm_json_error(mock_httpx_client, mock_settings): agent = LLMAgent("llm_agent") agent.send = AsyncMock() + # Ensure logger is mocked + agent.logger = MagicMock() - with patch.object(agent.logger, "error") as log: - prompt = LLMPromptMessage(text="Hi", norms=[], goals=[]) - msg = InternalMessage( - to="llm", - sender=mock_settings.agent_settings.bdi_core_name, - body=prompt.model_dump_json(), - ) - await agent.handle_message(msg) - log.assert_called() # Should log JSONDecodeError + prompt = LLMPromptMessage(text="Hi", norms=[], goals=[]) + msg = InternalMessage( + to="llm", + sender=mock_settings.agent_settings.bdi_core_name, + body=prompt.model_dump_json(), + thread="prompt_message", + ) + await agent.handle_message(msg) + + agent.logger.error.assert_called() # Should log JSONDecodeError def test_llm_instructions(): @@ -157,6 +168,7 @@ async def test_handle_message_validation_error_branch_no_send(mock_httpx_client, to="llm_agent", sender=mock_settings.agent_settings.bdi_core_name, body=invalid_json, + thread="prompt_message", ) await agent.handle_message(msg) @@ -285,3 +297,28 @@ async def test_clear_history_command(mock_settings): ) await agent.handle_message(msg) assert len(agent.history) == 0 + + +@pytest.mark.asyncio +async def test_handle_assistant_and_user_messages(mock_settings): + agent = LLMAgent("llm_agent") + + # Assistant message + msg_ast = InternalMessage( + to="llm_agent", + sender=mock_settings.agent_settings.bdi_core_name, + thread="assistant_message", + body="I said this", + ) + await agent.handle_message(msg_ast) + assert agent.history[-1] == {"role": "assistant", "content": "I said this"} + + # User message + msg_usr = InternalMessage( + to="llm_agent", + sender=mock_settings.agent_settings.bdi_core_name, + thread="user_message", + body="User said this", + ) + await agent.handle_message(msg_usr) + assert agent.history[-1] == {"role": "user", "content": "User said this"} diff --git a/test/unit/agents/perception/transcription_agent/test_transcription_agent.py b/test/unit/agents/perception/transcription_agent/test_transcription_agent.py index ccdaa7f..57875ca 100644 --- a/test/unit/agents/perception/transcription_agent/test_transcription_agent.py +++ b/test/unit/agents/perception/transcription_agent/test_transcription_agent.py @@ -36,7 +36,12 @@ async def test_transcription_agent_flow(mock_zmq_context): agent.send = AsyncMock() agent._running = True - agent.add_behavior = AsyncMock() + + def close_coro(coro): + coro.close() + return MagicMock() + + agent.add_behavior = MagicMock(side_effect=close_coro) await agent.setup() @@ -143,7 +148,12 @@ async def test_transcription_loop_continues_after_error(mock_zmq_context): agent = TranscriptionAgent("tcp://in") agent._running = True # ← REQUIRED to enter the loop agent.send = AsyncMock() # should never be called - agent.add_behavior = AsyncMock() # match other tests + + def close_coro(coro): + coro.close() + return MagicMock() + + agent.add_behavior = MagicMock(side_effect=close_coro) # match other tests await agent.setup() @@ -180,7 +190,12 @@ async def test_transcription_continue_branch_when_empty(mock_zmq_context): # Make loop runnable agent._running = True agent.send = AsyncMock() - agent.add_behavior = AsyncMock() + + def close_coro(coro): + coro.close() + return MagicMock() + + agent.add_behavior = MagicMock(side_effect=close_coro) await agent.setup() diff --git a/test/unit/agents/perception/vad_agent/test_vad_agent.py b/test/unit/agents/perception/vad_agent/test_vad_agent.py new file mode 100644 index 0000000..fe65545 --- /dev/null +++ b/test/unit/agents/perception/vad_agent/test_vad_agent.py @@ -0,0 +1,153 @@ +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from control_backend.agents.perception.vad_agent import VADAgent +from control_backend.core.agent_system import InternalMessage +from control_backend.schemas.program_status import PROGRAM_STATUS, ProgramStatus + + +@pytest.fixture(autouse=True) +def mock_zmq(): + with patch("zmq.asyncio.Context") as mock: + mock.instance.return_value = MagicMock() + yield mock + + +@pytest.fixture +def agent(): + return VADAgent("tcp://localhost:5555", False) + + +@pytest.mark.asyncio +async def test_handle_message_pause(agent): + agent._paused = MagicMock() + # It starts set (not paused) + + msg = InternalMessage(to="vad", sender="user_interrupt_agent", body="PAUSE") + + # We need to mock settings to match sender name + with patch("control_backend.agents.perception.vad_agent.settings") as mock_settings: + mock_settings.agent_settings.user_interrupt_name = "user_interrupt_agent" + + await agent.handle_message(msg) + + agent._paused.clear.assert_called_once() + assert agent._reset_needed is True + + +@pytest.mark.asyncio +async def test_handle_message_resume(agent): + agent._paused = MagicMock() + msg = InternalMessage(to="vad", sender="user_interrupt_agent", body="RESUME") + + with patch("control_backend.agents.perception.vad_agent.settings") as mock_settings: + mock_settings.agent_settings.user_interrupt_name = "user_interrupt_agent" + + await agent.handle_message(msg) + + agent._paused.set.assert_called_once() + + +@pytest.mark.asyncio +async def test_handle_message_unknown_command(agent): + agent._paused = MagicMock() + msg = InternalMessage(to="vad", sender="user_interrupt_agent", body="UNKNOWN") + + with patch("control_backend.agents.perception.vad_agent.settings") as mock_settings: + mock_settings.agent_settings.user_interrupt_name = "user_interrupt_agent" + agent.logger = MagicMock() + + await agent.handle_message(msg) + + agent.logger.warning.assert_called() + agent._paused.clear.assert_not_called() + agent._paused.set.assert_not_called() + + +@pytest.mark.asyncio +async def test_handle_message_unknown_sender(agent): + agent._paused = MagicMock() + msg = InternalMessage(to="vad", sender="other_agent", body="PAUSE") + + with patch("control_backend.agents.perception.vad_agent.settings") as mock_settings: + mock_settings.agent_settings.user_interrupt_name = "user_interrupt_agent" + + await agent.handle_message(msg) + + agent._paused.clear.assert_not_called() + + +@pytest.mark.asyncio +async def test_status_loop_waits_for_running(agent): + agent._running = True + agent.program_sub_socket = AsyncMock() + agent.program_sub_socket.close = MagicMock() + agent._reset_stream = AsyncMock() + + # Sequence of messages: + # 1. Wrong topic + # 2. Right topic, wrong status (STARTING) + # 3. Right topic, RUNNING -> Should break loop + + agent.program_sub_socket.recv_multipart.side_effect = [ + (b"wrong_topic", b"whatever"), + (PROGRAM_STATUS, ProgramStatus.STARTING.value), + (PROGRAM_STATUS, ProgramStatus.RUNNING.value), + ] + + await agent._status_loop() + + assert agent._reset_stream.await_count == 1 + agent.program_sub_socket.close.assert_called_once() + + +@pytest.mark.asyncio +async def test_setup_success(agent, mock_zmq): + def close_coro(coro): + coro.close() + return MagicMock() + + agent.add_behavior = MagicMock(side_effect=close_coro) + + mock_context = mock_zmq.instance.return_value + mock_sub = MagicMock() + mock_pub = MagicMock() + + # We expect multiple socket calls: + # 1. audio_in (SUB) + # 2. audio_out (PUB) + # 3. program_sub (SUB) + mock_context.socket.side_effect = [mock_sub, mock_pub, mock_sub] + + with patch("control_backend.agents.perception.vad_agent.torch.hub.load") as mock_load: + mock_load.return_value = (MagicMock(), None) + + with patch("control_backend.agents.perception.vad_agent.TranscriptionAgent") as MockTrans: + mock_trans_instance = MockTrans.return_value + mock_trans_instance.start = AsyncMock() + + await agent.setup() + + mock_trans_instance.start.assert_awaited_once() + + assert agent.add_behavior.call_count == 2 # streaming_loop + status_loop + assert agent.audio_in_socket is not None + assert agent.audio_out_socket is not None + assert agent.program_sub_socket is not None + + +@pytest.mark.asyncio +async def test_reset_stream(agent): + mock_poller = MagicMock() + agent.audio_in_poller = mock_poller + + # poll(1) returns not None twice, then None + mock_poller.poll = AsyncMock(side_effect=[b"data", b"data", None]) + + agent._ready = MagicMock() + + await agent._reset_stream() + + assert mock_poller.poll.await_count == 3 + agent._ready.set.assert_called_once() diff --git a/test/unit/agents/perception/vad_agent/test_vad_streaming.py b/test/unit/agents/perception/vad_agent/test_vad_streaming.py index 166919f..349fab2 100644 --- a/test/unit/agents/perception/vad_agent/test_vad_streaming.py +++ b/test/unit/agents/perception/vad_agent/test_vad_streaming.py @@ -5,6 +5,7 @@ import pytest import zmq from control_backend.agents.perception.vad_agent import VADAgent +from control_backend.core.config import settings # We don't want to use real ZMQ in unit tests, for example because it can give errors when sockets @@ -135,6 +136,54 @@ async def test_no_data(audio_out_socket, vad_agent): assert len(vad_agent.audio_buffer) == 0 +@pytest.mark.asyncio +async def test_streaming_loop_reset_needed(audio_out_socket, vad_agent): + """Test that _reset_needed branch works as expected.""" + vad_agent._reset_needed = True + vad_agent._ready.set() + vad_agent._paused.set() + vad_agent._running = True + vad_agent.audio_buffer = np.array([1.0], dtype=np.float32) + vad_agent.i_since_speech = 0 + + # Mock _reset_stream to stop the loop by setting _running=False + async def mock_reset(): + vad_agent._running = False + + vad_agent._reset_stream = mock_reset + + # Needs a poller to avoid AssertionError + vad_agent.audio_in_poller = AsyncMock() + vad_agent.audio_in_poller.poll.return_value = None + + await vad_agent._streaming_loop() + + assert vad_agent._reset_needed is False + assert len(vad_agent.audio_buffer) == 0 + assert vad_agent.i_since_speech == settings.behaviour_settings.vad_initial_since_speech + + +@pytest.mark.asyncio +async def test_streaming_loop_no_data_clears_buffer(audio_out_socket, vad_agent): + """Test that if poll returns None, buffer is cleared if not empty.""" + vad_agent.audio_buffer = np.array([1.0], dtype=np.float32) + vad_agent._ready.set() + vad_agent._paused.set() + vad_agent._running = True + + class MockPoller: + async def poll(self, timeout_ms=None): + vad_agent._running = False # stop after one poll + return None + + vad_agent.audio_in_poller = MockPoller() + + await vad_agent._streaming_loop() + + assert len(vad_agent.audio_buffer) == 0 + assert vad_agent.i_since_speech == settings.behaviour_settings.vad_initial_since_speech + + @pytest.mark.asyncio async def test_vad_model_load_failure_stops_agent(vad_agent): """ diff --git a/test/unit/agents/test_base.py b/test/unit/agents/test_base.py new file mode 100644 index 0000000..0579ada --- /dev/null +++ b/test/unit/agents/test_base.py @@ -0,0 +1,24 @@ +import logging + +from control_backend.agents.base import BaseAgent + + +class MyAgent(BaseAgent): + async def setup(self): + pass + + async def handle_message(self, msg): + pass + + +def test_base_agent_logger_init(): + # When defining a subclass, __init_subclass__ runs + # The BaseAgent in agents/base.py sets the logger + assert hasattr(MyAgent, "logger") + assert isinstance(MyAgent.logger, logging.Logger) + # The logger name depends on the package. + # Since this test file is running as a module, __package__ might be None or the test package. + # In 'src/control_backend/agents/base.py', it uses __package__ of base.py which is + # 'control_backend.agents'. + # So logger name should be control_backend.agents.MyAgent + assert MyAgent.logger.name == "control_backend.agents.MyAgent" diff --git a/test/unit/agents/user_interrupt/test_user_interrupt.py b/test/unit/agents/user_interrupt/test_user_interrupt.py index 7e3e700..7c38a05 100644 --- a/test/unit/agents/user_interrupt/test_user_interrupt.py +++ b/test/unit/agents/user_interrupt/test_user_interrupt.py @@ -7,6 +7,15 @@ import pytest from control_backend.agents.user_interrupt.user_interrupt_agent import UserInterruptAgent from control_backend.core.agent_system import InternalMessage from control_backend.core.config import settings +from control_backend.schemas.program import ( + ConditionalNorm, + Goal, + KeywordBelief, + Phase, + Plan, + Program, + Trigger, +) from control_backend.schemas.ri_message import RIEndpoint @@ -16,6 +25,7 @@ def agent(): agent.send = AsyncMock() agent.logger = MagicMock() agent.sub_socket = AsyncMock() + agent.pub_socket = AsyncMock() return agent @@ -49,21 +59,18 @@ async def test_send_to_gesture_agent(agent): @pytest.mark.asyncio -async def test_send_to_program_manager(agent): +async def test_send_to_bdi_belief(agent): """Verify belief update format.""" - context_str = "2" + context_str = "some_goal" - await agent._send_to_program_manager(context_str) + await agent._send_to_bdi_belief(context_str) - agent.send.assert_awaited_once() - sent_msg: InternalMessage = agent.send.call_args.args[0] + assert agent.send.await_count == 1 + sent_msg = agent.send.call_args.args[0] - assert sent_msg.to == settings.agent_settings.bdi_program_manager_name - assert sent_msg.thread == "belief_override_id" - - body = json.loads(sent_msg.body) - - assert body["belief"] == context_str + assert sent_msg.to == settings.agent_settings.bdi_core_name + assert sent_msg.thread == "beliefs" + assert "achieved_some_goal" in sent_msg.body @pytest.mark.asyncio @@ -77,6 +84,10 @@ async def test_receive_loop_routing_success(agent): # Prepare JSON payloads as bytes payload_speech = json.dumps({"type": "speech", "context": "Hello Speech"}).encode() payload_gesture = json.dumps({"type": "gesture", "context": "Hello Gesture"}).encode() + # override calls _send_to_bdi (for trigger/norm) OR _send_to_bdi_belief (for goal). + + # To test routing, we need to populate the maps + agent._goal_map["Hello Override"] = "some_goal_slug" payload_override = json.dumps({"type": "override", "context": "Hello Override"}).encode() agent.sub_socket.recv_multipart.side_effect = [ @@ -88,7 +99,7 @@ async def test_receive_loop_routing_success(agent): agent._send_to_speech_agent = AsyncMock() agent._send_to_gesture_agent = AsyncMock() - agent._send_to_program_manager = AsyncMock() + agent._send_to_bdi_belief = AsyncMock() try: await agent._receive_button_event() @@ -103,12 +114,12 @@ async def test_receive_loop_routing_success(agent): # Gesture agent._send_to_gesture_agent.assert_awaited_once_with("Hello Gesture") - # Override - agent._send_to_program_manager.assert_awaited_once_with("Hello Override") + # Override (since we mapped it to a goal) + agent._send_to_bdi_belief.assert_awaited_once_with("some_goal_slug") assert agent._send_to_speech_agent.await_count == 1 assert agent._send_to_gesture_agent.await_count == 1 - assert agent._send_to_program_manager.await_count == 1 + assert agent._send_to_bdi_belief.await_count == 1 @pytest.mark.asyncio @@ -125,7 +136,6 @@ async def test_receive_loop_unknown_type(agent): agent._send_to_speech_agent = AsyncMock() agent._send_to_gesture_agent = AsyncMock() - agent._send_to_belief_collector = AsyncMock() try: await agent._receive_button_event() @@ -137,10 +147,165 @@ async def test_receive_loop_unknown_type(agent): # Ensure no handlers were called agent._send_to_speech_agent.assert_not_called() agent._send_to_gesture_agent.assert_not_called() - agent._send_to_belief_collector.assert_not_called() - agent.logger.warning.assert_called_with( - "Received button press with unknown type '%s' (context: '%s').", - "unknown_thing", - "some_data", - ) + agent.logger.warning.assert_called() + + +@pytest.mark.asyncio +async def test_create_mapping(agent): + # Create a program with a trigger, goal, and conditional norm + import uuid + + trigger_id = uuid.uuid4() + goal_id = uuid.uuid4() + norm_id = uuid.uuid4() + + cond = KeywordBelief(id=uuid.uuid4(), name="k1", keyword="key") + plan = Plan(id=uuid.uuid4(), name="p1", steps=[]) + + trigger = Trigger(id=trigger_id, name="my_trigger", condition=cond, plan=plan) + goal = Goal(id=goal_id, name="my_goal", description="desc", plan=plan) + + cn = ConditionalNorm(id=norm_id, name="my_norm", norm="be polite", condition=cond) + + phase = Phase(id=uuid.uuid4(), name="phase1", norms=[cn], goals=[goal], triggers=[trigger]) + prog = Program(phases=[phase]) + + # Call create_mapping via handle_message + msg = InternalMessage(to="me", thread="new_program", body=prog.model_dump_json()) + await agent.handle_message(msg) + + # Check maps + assert str(trigger_id) in agent._trigger_map + assert agent._trigger_map[str(trigger_id)] == "trigger_my_trigger" + + assert str(goal_id) in agent._goal_map + assert agent._goal_map[str(goal_id)] == "my_goal" + + assert str(norm_id) in agent._cond_norm_map + assert agent._cond_norm_map[str(norm_id)] == "norm_be_polite" + + +@pytest.mark.asyncio +async def test_create_mapping_invalid_json(agent): + # Pass invalid json to handle_message thread "new_program" + msg = InternalMessage(to="me", thread="new_program", body="invalid json") + await agent.handle_message(msg) + + # Should log error and maps should remain empty or cleared + agent.logger.error.assert_called() + + +@pytest.mark.asyncio +async def test_handle_message_trigger_start(agent): + # Setup reverse map manually + agent._trigger_reverse_map["trigger_slug"] = "ui_id_123" + + msg = InternalMessage(to="me", thread="trigger_start", body="trigger_slug") + await agent.handle_message(msg) + + agent.pub_socket.send_multipart.assert_awaited_once() + args = agent.pub_socket.send_multipart.call_args[0][0] + assert args[0] == b"experiment" + payload = json.loads(args[1]) + assert payload["type"] == "trigger_update" + assert payload["id"] == "ui_id_123" + assert payload["achieved"] is True + + +@pytest.mark.asyncio +async def test_handle_message_trigger_end(agent): + agent._trigger_reverse_map["trigger_slug"] = "ui_id_123" + + msg = InternalMessage(to="me", thread="trigger_end", body="trigger_slug") + await agent.handle_message(msg) + + agent.pub_socket.send_multipart.assert_awaited_once() + payload = json.loads(agent.pub_socket.send_multipart.call_args[0][0][1]) + assert payload["type"] == "trigger_update" + assert payload["achieved"] is False + + +@pytest.mark.asyncio +async def test_handle_message_transition_phase(agent): + msg = InternalMessage(to="me", thread="transition_phase", body="phase_id_123") + await agent.handle_message(msg) + + agent.pub_socket.send_multipart.assert_awaited_once() + payload = json.loads(agent.pub_socket.send_multipart.call_args[0][0][1]) + assert payload["type"] == "phase_update" + assert payload["id"] == "phase_id_123" + + +@pytest.mark.asyncio +async def test_handle_message_goal_start(agent): + agent._goal_reverse_map["goal_slug"] = "goal_id_123" + + msg = InternalMessage(to="me", thread="goal_start", body="goal_slug") + await agent.handle_message(msg) + + agent.pub_socket.send_multipart.assert_awaited_once() + payload = json.loads(agent.pub_socket.send_multipart.call_args[0][0][1]) + assert payload["type"] == "goal_update" + assert payload["id"] == "goal_id_123" + assert payload["active"] is True + + +@pytest.mark.asyncio +async def test_handle_message_active_norms_update(agent): + agent._cond_norm_reverse_map["norm_active"] = "id_1" + agent._cond_norm_reverse_map["norm_inactive"] = "id_2" + + # Body is like: "('norm_active', 'other')" + # The split logic handles quotes etc. + msg = InternalMessage(to="me", thread="active_norms_update", body="'norm_active', 'other'") + await agent.handle_message(msg) + + agent.pub_socket.send_multipart.assert_awaited_once() + payload = json.loads(agent.pub_socket.send_multipart.call_args[0][0][1]) + assert payload["type"] == "cond_norms_state_update" + norms = {n["id"]: n["active"] for n in payload["norms"]} + assert norms["id_1"] is True + assert norms["id_2"] is False + + +@pytest.mark.asyncio +async def test_send_experiment_control(agent): + # Test next_phase + await agent._send_experiment_control_to_bdi_core("next_phase") + agent.send.assert_awaited() + msg = agent.send.call_args[0][0] + assert msg.thread == "force_next_phase" + + # Test reset_phase + await agent._send_experiment_control_to_bdi_core("reset_phase") + msg = agent.send.call_args[0][0] + assert msg.thread == "reset_current_phase" + + # Test reset_experiment + await agent._send_experiment_control_to_bdi_core("reset_experiment") + msg = agent.send.call_args[0][0] + assert msg.thread == "reset_experiment" + + +@pytest.mark.asyncio +async def test_send_pause_command(agent): + await agent._send_pause_command("true") + # Sends to RI and VAD + assert agent.send.await_count == 2 + msgs = [call.args[0] for call in agent.send.call_args_list] + + ri_msg = next(m for m in msgs if m.to == settings.agent_settings.ri_communication_name) + assert json.loads(ri_msg.body)["endpoint"] == "" # PAUSE endpoint + assert json.loads(ri_msg.body)["data"] is True + + vad_msg = next(m for m in msgs if m.to == settings.agent_settings.vad_name) + assert vad_msg.body == "PAUSE" + + agent.send.reset_mock() + await agent._send_pause_command("false") + assert agent.send.await_count == 2 + vad_msg = next( + m for m in agent.send.call_args_list if m.args[0].to == settings.agent_settings.vad_name + ).args[0] + assert vad_msg.body == "RESUME" diff --git a/test/unit/api/v1/endpoints/test_user_interact.py b/test/unit/api/v1/endpoints/test_user_interact.py new file mode 100644 index 0000000..ddb9932 --- /dev/null +++ b/test/unit/api/v1/endpoints/test_user_interact.py @@ -0,0 +1,96 @@ +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from fastapi import FastAPI +from fastapi.testclient import TestClient + +from control_backend.api.v1.endpoints import user_interact + + +@pytest.fixture +def app(): + app = FastAPI() + app.include_router(user_interact.router) + return app + + +@pytest.fixture +def client(app): + return TestClient(app) + + +@pytest.mark.asyncio +async def test_receive_button_event(client): + mock_pub_socket = AsyncMock() + client.app.state.endpoints_pub_socket = mock_pub_socket + + payload = {"type": "speech", "context": "hello"} + response = client.post("/button_pressed", json=payload) + + assert response.status_code == 202 + assert response.json() == {"status": "Event received"} + + mock_pub_socket.send_multipart.assert_awaited_once() + args = mock_pub_socket.send_multipart.call_args[0][0] + assert args[0] == b"button_pressed" + assert "speech" in args[1].decode() + + +@pytest.mark.asyncio +async def test_receive_button_event_invalid_payload(client): + mock_pub_socket = AsyncMock() + client.app.state.endpoints_pub_socket = mock_pub_socket + + # Missing context + payload = {"type": "speech"} + response = client.post("/button_pressed", json=payload) + + assert response.status_code == 422 + mock_pub_socket.send_multipart.assert_not_called() + + +@pytest.mark.asyncio +async def test_experiment_stream_direct_call(): + """ + Directly calling the endpoint function to test the streaming logic + without dealing with TestClient streaming limitations. + """ + mock_socket = AsyncMock() + # 1. recv data + # 2. recv timeout + # 3. disconnect (request.is_disconnected returns True) + mock_socket.recv_multipart.side_effect = [ + (b"topic", b"message1"), + TimeoutError(), + (b"topic", b"message2"), # Should not be reached if disconnect checks work + ] + mock_socket.close = MagicMock() + mock_socket.connect = MagicMock() + mock_socket.subscribe = MagicMock() + + mock_context = MagicMock() + mock_context.socket.return_value = mock_socket + + with patch( + "control_backend.api.v1.endpoints.user_interact.Context.instance", return_value=mock_context + ): + mock_request = AsyncMock() + # is_disconnected sequence: + # 1. False (before first recv) -> reads message1 + # 2. False (before second recv) -> triggers TimeoutError, continues + # 3. True (before third recv) -> break loop + mock_request.is_disconnected.side_effect = [False, False, True] + + response = await user_interact.experiment_stream(mock_request) + + lines = [] + # Consume the generator + async for line in response.body_iterator: + lines.append(line) + + assert "data: message1\n\n" in lines + assert len(lines) == 1 + + mock_socket.connect.assert_called() + mock_socket.subscribe.assert_called_with(b"experiment") + mock_socket.close.assert_called() diff --git a/test/unit/test_main_sockets.py b/test/unit/test_main_sockets.py new file mode 100644 index 0000000..662147a --- /dev/null +++ b/test/unit/test_main_sockets.py @@ -0,0 +1,40 @@ +from unittest.mock import MagicMock, patch + +import zmq + +from control_backend.main import setup_sockets + + +def test_setup_sockets_proxy(): + mock_context = MagicMock() + mock_pub = MagicMock() + mock_sub = MagicMock() + + mock_context.socket.side_effect = [mock_pub, mock_sub] + + with patch("zmq.asyncio.Context.instance", return_value=mock_context): + with patch("zmq.proxy") as mock_proxy: + setup_sockets() + + mock_pub.bind.assert_called() + mock_sub.bind.assert_called() + mock_proxy.assert_called_with(mock_sub, mock_pub) + + # Check cleanup + mock_pub.close.assert_called() + mock_sub.close.assert_called() + + +def test_setup_sockets_proxy_error(): + mock_context = MagicMock() + mock_pub = MagicMock() + mock_sub = MagicMock() + mock_context.socket.side_effect = [mock_pub, mock_sub] + + with patch("zmq.asyncio.Context.instance", return_value=mock_context): + with patch("zmq.proxy", side_effect=zmq.ZMQError): + with patch("control_backend.main.logger") as mock_logger: + setup_sockets() + mock_logger.warning.assert_called() + mock_pub.close.assert_called() + mock_sub.close.assert_called()