diff --git a/.env.example b/.env.example index d498054..41a382a 100644 --- a/.env.example +++ b/.env.example @@ -10,7 +10,7 @@ LLM_SETTINGS__LOCAL_LLM_URL="http://localhost:1234/v1/chat/completions" LLM_SETTINGS__LOCAL_LLM_MODEL="gpt-oss" # Number of non-speech chunks to wait before speech ended. A chunk is approximately 31 ms. Increasing this number allows longer pauses in speech, but also increases response time. -BEHAVIOUR_SETTINGS__VAD_NON_SPEECH_PATIENCE_CHUNKS=3 +BEHAVIOUR_SETTINGS__VAD_NON_SPEECH_PATIENCE_CHUNKS=15 # Timeout in milliseconds for socket polling. Increase this number if network latency/jitter is high, often the case when using Wi-Fi. Perhaps 500 ms. A symptom of this issue is transcriptions getting cut off. BEHAVIOUR_SETTINGS__SOCKET_POLLER_TIMEOUT_MS=100 diff --git a/.gitignore b/.gitignore index 47ef46d..41b7458 100644 --- a/.gitignore +++ b/.gitignore @@ -223,6 +223,7 @@ docs/* !docs/conf.py # Generated files +*.asl experiment-*.log @@ -274,6 +275,5 @@ experiment-*.log - diff --git a/pyproject.toml b/pyproject.toml index e57a03c..5de7daa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,6 +15,7 @@ dependencies = [ "pydantic>=2.12.0", "pydantic-settings>=2.11.0", "python-json-logger>=4.0.0", + "python-slugify>=8.0.4", "pyyaml>=6.0.3", "pyzmq>=27.1.0", "silero-vad>=6.0.0", @@ -47,6 +48,7 @@ test = [ "pytest-asyncio>=1.2.0", "pytest-cov>=7.0.0", "pytest-mock>=3.15.1", + "python-slugify>=8.0.4", "pyyaml>=6.0.3", "pyzmq>=27.1.0", "soundfile>=0.13.1", diff --git a/src/control_backend/agents/__init__.py b/src/control_backend/agents/__init__.py index 1618d55..85f4aad 100644 --- a/src/control_backend/agents/__init__.py +++ b/src/control_backend/agents/__init__.py @@ -1 +1,5 @@ +""" +This package contains all agent implementations for the PepperPlus Control Backend. 
+""" + from .base import BaseAgent as BaseAgent diff --git a/src/control_backend/agents/actuation/__init__.py b/src/control_backend/agents/actuation/__init__.py index 8ff7e7f..9a8d81b 100644 --- a/src/control_backend/agents/actuation/__init__.py +++ b/src/control_backend/agents/actuation/__init__.py @@ -1,2 +1,6 @@ +""" +Agents responsible for controlling the robot's physical actions, such as speech and gestures. +""" + from .robot_gesture_agent import RobotGestureAgent as RobotGestureAgent from .robot_speech_agent import RobotSpeechAgent as RobotSpeechAgent diff --git a/src/control_backend/agents/actuation/robot_gesture_agent.py b/src/control_backend/agents/actuation/robot_gesture_agent.py index 3b264d2..997b684 100644 --- a/src/control_backend/agents/actuation/robot_gesture_agent.py +++ b/src/control_backend/agents/actuation/robot_gesture_agent.py @@ -83,6 +83,8 @@ class RobotGestureAgent(BaseAgent): self.subsocket.close() if self.pubsocket: self.pubsocket.close() + if self.repsocket: + self.repsocket.close() await super().stop() async def handle_message(self, msg: InternalMessage): diff --git a/src/control_backend/agents/bdi/__init__.py b/src/control_backend/agents/bdi/__init__.py index 8d45440..2f7d976 100644 --- a/src/control_backend/agents/bdi/__init__.py +++ b/src/control_backend/agents/bdi/__init__.py @@ -1,8 +1,10 @@ +""" +Agents and utilities for the BDI (Belief-Desire-Intention) reasoning system, +implementing AgentSpeak(L) logic. 
+""" + from control_backend.agents.bdi.bdi_core_agent import BDICoreAgent as BDICoreAgent -from .belief_collector_agent import ( - BDIBeliefCollectorAgent as BDIBeliefCollectorAgent, -) from .text_belief_extractor_agent import ( TextBeliefExtractorAgent as TextBeliefExtractorAgent, ) diff --git a/src/control_backend/agents/bdi/agentspeak_ast.py b/src/control_backend/agents/bdi/agentspeak_ast.py new file mode 100644 index 0000000..12c7947 --- /dev/null +++ b/src/control_backend/agents/bdi/agentspeak_ast.py @@ -0,0 +1,570 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from enum import StrEnum + + +class AstNode(ABC): + """ + Abstract base class for all elements of an AgentSpeak program. + + This class serves as the foundation for all AgentSpeak abstract syntax tree (AST) nodes. + It defines the core interface that all AST nodes must implement to generate AgentSpeak code. + """ + + @abstractmethod + def _to_agentspeak(self) -> str: + """ + Generates the AgentSpeak code string. + + This method converts the AST node into its corresponding + AgentSpeak source code representation. + + :return: The AgentSpeak code string representation of this node. + """ + pass + + def __str__(self) -> str: + """ + Returns the string representation of this AST node. + + This method provides a convenient way to get the AgentSpeak code representation + by delegating to the _to_agentspeak method. + + :return: The AgentSpeak code string representation of this node. + """ + return self._to_agentspeak() + + +class AstExpression(AstNode, ABC): + """ + Intermediate class for anything that can be used in a logical expression. + + This class extends AstNode to provide common functionality for all expressions + that can be used in logical operations within AgentSpeak programs. + """ + + def __and__(self, other: ExprCoalescible) -> AstBinaryOp: + """ + Creates a logical AND operation between this expression and another. 
+ + This method allows for operator overloading of the & operator to create + binary logical operations in a more intuitive syntax. + + :param other: The right-hand side expression to combine with this one. + :return: A new AstBinaryOp representing the logical AND operation. + """ + return AstBinaryOp(self, BinaryOperatorType.AND, _coalesce_expr(other)) + + def __or__(self, other: ExprCoalescible) -> AstBinaryOp: + """ + Creates a logical OR operation between this expression and another. + + This method allows for operator overloading of the | operator to create + binary logical operations in a more intuitive syntax. + + :param other: The right-hand side expression to combine with this one. + :return: A new AstBinaryOp representing the logical OR operation. + """ + return AstBinaryOp(self, BinaryOperatorType.OR, _coalesce_expr(other)) + + def __invert__(self) -> AstLogicalExpression: + """ + Creates a logical negation of this expression. + + This method allows for operator overloading of the ~ operator to create + negated expressions. If the expression is already a logical expression, + it toggles the negation flag. Otherwise, it wraps the expression in a + new AstLogicalExpression with negation set to True. + + :return: An AstLogicalExpression representing the negated form of this expression. + """ + if isinstance(self, AstLogicalExpression): + self.negated = not self.negated + return self + return AstLogicalExpression(self, negated=True) + + +type ExprCoalescible = AstExpression | str | int | float + + +def _coalesce_expr(value: ExprCoalescible) -> AstExpression: + if isinstance(value, AstExpression): + return value + if isinstance(value, str): + return AstString(value) + if isinstance(value, (int, float)): + return AstNumber(value) + raise TypeError(f"Cannot coalesce type {type(value)} into an AstTerm.") + + +@dataclass +class AstTerm(AstExpression, ABC): + """ + Base class for terms appearing inside literals. 
+ """ + + def __ge__(self, other: ExprCoalescible) -> AstBinaryOp: + return AstBinaryOp(self, BinaryOperatorType.GREATER_EQUALS, _coalesce_expr(other)) + + def __gt__(self, other: ExprCoalescible) -> AstBinaryOp: + return AstBinaryOp(self, BinaryOperatorType.GREATER_THAN, _coalesce_expr(other)) + + def __le__(self, other: ExprCoalescible) -> AstBinaryOp: + return AstBinaryOp(self, BinaryOperatorType.LESS_EQUALS, _coalesce_expr(other)) + + def __lt__(self, other: ExprCoalescible) -> AstBinaryOp: + return AstBinaryOp(self, BinaryOperatorType.LESS_THAN, _coalesce_expr(other)) + + def __eq__(self, other: ExprCoalescible) -> AstBinaryOp: + return AstBinaryOp(self, BinaryOperatorType.EQUALS, _coalesce_expr(other)) + + def __ne__(self, other: ExprCoalescible) -> AstBinaryOp: + return AstBinaryOp(self, BinaryOperatorType.NOT_EQUALS, _coalesce_expr(other)) + + +@dataclass(eq=False) +class AstAtom(AstTerm): + """ + Represents a grounded atom in AgentSpeak (e.g., lowercase constants). + + Atoms are the simplest form of terms in AgentSpeak, representing concrete, + unchanging values. They are typically used as constants in logical expressions. + + :ivar value: The string value of this atom, which will be converted to lowercase + in the AgentSpeak representation. + """ + + value: str + + def _to_agentspeak(self) -> str: + """ + Converts this atom to its AgentSpeak string representation. + + Atoms are represented in lowercase in AgentSpeak to distinguish them + from variables (which are capitalized). + + :return: The lowercase string representation of this atom. + """ + return self.value.lower() + + +@dataclass(eq=False) +class AstVar(AstTerm): + """ + Represents an ungrounded variable in AgentSpeak (e.g., capitalized names). + + Variables in AgentSpeak are placeholders that can be bound to specific values + during execution. They are distinguished from atoms by their capitalization. 
+ + :ivar name: The name of this variable, which will be capitalized in the + AgentSpeak representation. + """ + + name: str + + def _to_agentspeak(self) -> str: + """ + Converts this variable to its AgentSpeak string representation. + + Variables are represented with capitalized names in AgentSpeak to distinguish + them from atoms (which are lowercase). + + :return: The capitalized string representation of this variable. + """ + return self.name.capitalize() + + +@dataclass(eq=False) +class AstNumber(AstTerm): + """ + Represents a numeric constant in AgentSpeak. + + Numeric constants can be either integers or floating-point numbers and are + used in logical expressions and comparisons. + + :ivar value: The numeric value of this constant (can be int or float). + """ + + value: int | float + + def _to_agentspeak(self) -> str: + """ + Converts this numeric constant to its AgentSpeak string representation. + + :return: The string representation of the numeric value. + """ + return str(self.value) + + +@dataclass(eq=False) +class AstString(AstTerm): + """ + Represents a string literal in AgentSpeak. + + String literals are used to represent textual data and are enclosed in + double quotes in the AgentSpeak representation. + + :ivar value: The string content of this literal. + """ + + value: str + + def _to_agentspeak(self) -> str: + """ + Converts this string literal to its AgentSpeak string representation. + + String literals are enclosed in double quotes in AgentSpeak. + + :return: The string literal enclosed in double quotes. + """ + return f'"{self.value}"' + + +@dataclass(eq=False) +class AstLiteral(AstTerm): + """ + Represents a literal (functor and terms) in AgentSpeak. + + Literals are the fundamental building blocks of AgentSpeak programs, consisting + of a functor (predicate name) and an optional list of terms (arguments). + + :ivar functor: The name of the predicate or function. + :ivar terms: A list of terms (arguments) for this literal. 
Defaults to an empty list. + """ + + functor: str + terms: list[AstTerm] = field(default_factory=list) + + def _to_agentspeak(self) -> str: + """ + Converts this literal to its AgentSpeak string representation. + + If the literal has no terms, it returns just the functor name. + Otherwise, it returns the functor followed by the terms in parentheses. + + :return: The AgentSpeak string representation of this literal. + """ + if not self.terms: + return self.functor + args = ", ".join(map(str, self.terms)) + return f"{self.functor}({args})" + + +class BinaryOperatorType(StrEnum): + """ + Enumeration of binary operator types used in AgentSpeak expressions. + + These operators are used to create binary operations between expressions, + including logical operations (AND, OR) and comparison operations. + """ + + AND = "&" + OR = "|" + GREATER_THAN = ">" + LESS_THAN = "<" + EQUALS = "==" + NOT_EQUALS = "\\==" + GREATER_EQUALS = ">=" + LESS_EQUALS = "<=" + + +@dataclass +class AstBinaryOp(AstExpression): + """ + Represents a binary logical or relational operation in AgentSpeak. + + Binary operations combine two expressions using a logical or comparison operator. + They are used to create complex logical conditions in AgentSpeak programs. + + :ivar left: The left-hand side expression of the operation. + :ivar operator: The binary operator type (AND, OR, comparison operators, etc.). + :ivar right: The right-hand side expression of the operation. + """ + + left: AstExpression + operator: BinaryOperatorType + right: AstExpression + + def __post_init__(self): + """ + Post-initialization processing to ensure proper expression types. + + This method wraps the left and right expressions in AstLogicalExpression + instances if they aren't already, ensuring consistent handling throughout + the AST. 
+ """ + self.left = _as_logical(self.left) + self.right = _as_logical(self.right) + + def _to_agentspeak(self) -> str: + """ + Converts this binary operation to its AgentSpeak string representation. + + The method handles proper parenthesization of sub-expressions to maintain + correct operator precedence and readability. + + :return: The AgentSpeak string representation of this binary operation. + """ + l_str = str(self.left) + r_str = str(self.right) + + assert isinstance(self.left, AstLogicalExpression) + assert isinstance(self.right, AstLogicalExpression) + + if isinstance(self.left.expression, AstBinaryOp) or self.left.negated: + l_str = f"({l_str})" + if isinstance(self.right.expression, AstBinaryOp) or self.right.negated: + r_str = f"({r_str})" + + return f"{l_str} {self.operator.value} {r_str}" + + +@dataclass +class AstLogicalExpression(AstExpression): + """ + Represents a logical expression, potentially negated, in AgentSpeak. + + Logical expressions can be either positive or negated and form the basis + of conditions and beliefs in AgentSpeak programs. + + :ivar expression: The underlying expression being evaluated. + :ivar negated: Boolean flag indicating whether this expression is negated. + """ + + expression: AstExpression + negated: bool = False + + def _to_agentspeak(self) -> str: + """ + Converts this logical expression to its AgentSpeak string representation. + + If the expression is negated, it prepends 'not ' to the expression string. + For complex expressions (binary operations), it adds parentheses when negated + to maintain correct logical interpretation. + + :return: The AgentSpeak string representation of this logical expression. 
+ """ + expr_str = str(self.expression) + if isinstance(self.expression, AstBinaryOp) and self.negated: + expr_str = f"({expr_str})" + return f"{'not ' if self.negated else ''}{expr_str}" + + +def _as_logical(expr: AstExpression) -> AstLogicalExpression: + """ + Converts an expression to a logical expression if it isn't already. + + This helper function ensures that expressions are properly wrapped in + AstLogicalExpression instances, which is necessary for consistent handling + of logical operations in the AST. + + :param expr: The expression to convert. + :return: The expression wrapped in an AstLogicalExpression if it wasn't already. + """ + if isinstance(expr, AstLogicalExpression): + return expr + return AstLogicalExpression(expr) + + +class StatementType(StrEnum): + """ + Enumeration of statement types that can appear in AgentSpeak plans. + + These statement types define the different kinds of actions and operations + that can be performed within the body of an AgentSpeak plan. + """ + + EMPTY = "" + """Empty statement (no operation, used when evaluating a plan to true).""" + + DO_ACTION = "." + """Execute an action defined in Python.""" + + ACHIEVE_GOAL = "!" + """Achieve a goal (add a goal to be accomplished).""" + + TEST_GOAL = "?" + """Test a goal (check if a goal can be achieved).""" + + ADD_BELIEF = "+" + """Add a belief to the belief base.""" + + REMOVE_BELIEF = "-" + """Remove a belief from the belief base.""" + + REPLACE_BELIEF = "-+" + """Replace a belief in the belief base.""" + + +@dataclass +class AstStatement(AstNode): + """ + A statement that can appear inside a plan. + + Statements are the executable units within AgentSpeak plans. They consist + of a statement type (defining the operation) and an expression (defining + what to operate on). + + :ivar type: The type of statement (action, goal, belief operation, etc.). + :ivar expression: The expression that this statement operates on. 
+ """ + + type: StatementType + expression: AstExpression + + def _to_agentspeak(self) -> str: + """ + Converts this statement to its AgentSpeak string representation. + + The representation consists of the statement type prefix followed by + the expression. + + :return: The AgentSpeak string representation of this statement. + """ + return f"{self.type.value}{self.expression}" + + +@dataclass +class AstRule(AstNode): + """ + Represents an inference rule in AgentSpeak. If there is no condition, it always holds. + + Rules define logical implications in AgentSpeak programs. They consist of a + result (conclusion) and an optional condition (premise). When the condition + holds, the result is inferred to be true. + + :ivar result: The conclusion or result of this rule. + :ivar condition: The premise or condition for this rule (optional). + """ + + result: AstExpression + condition: AstExpression | None = None + + def __post_init__(self): + """ + Post-initialization processing to ensure proper expression types. + + If a condition is provided, this method wraps it in an AstLogicalExpression + to ensure consistent handling throughout the AST. + """ + if self.condition is not None: + self.condition = _as_logical(self.condition) + + def _to_agentspeak(self) -> str: + """ + Converts this rule to its AgentSpeak string representation. + + If no condition is specified, the rule is represented as a simple fact. + If a condition is specified, it's represented as an implication (result :- condition). + + :return: The AgentSpeak string representation of this rule. + """ + if not self.condition: + return f"{self.result}." + return f"{self.result} :- {self.condition}." + + +class TriggerType(StrEnum): + """ + Enumeration of trigger types for AgentSpeak plans. + + Trigger types define what kind of events can activate an AgentSpeak plan. + Currently, the system supports triggers for added beliefs and added goals. 
+ """ + + ADDED_BELIEF = "+" + """Trigger when a belief is added to the belief base.""" + + # REMOVED_BELIEF = "-" # TODO + # MODIFIED_BELIEF = "^" # TODO + + ADDED_GOAL = "+!" + """Trigger when a goal is added to be achieved.""" + + # REMOVED_GOAL = "-!" # TODO + + +@dataclass +class AstPlan(AstNode): + """ + Represents a plan in AgentSpeak, consisting of a trigger, context, and body. + + Plans define the reactive behavior of agents in AgentSpeak. They specify what + actions to take when certain conditions are met (trigger and context). + + :ivar type: The type of trigger that activates this plan. + :ivar trigger_literal: The specific event or condition that triggers this plan. + :ivar context: A list of conditions that must hold for this plan to be applicable. + :ivar body: A list of statements to execute when this plan is triggered. + """ + + type: TriggerType + trigger_literal: AstExpression + context: list[AstExpression] + body: list[AstStatement] + + def _to_agentspeak(self) -> str: + """ + Converts this plan to its AgentSpeak string representation. + + The representation follows the standard AgentSpeak plan format: + trigger_type + trigger_literal + : context_conditions + <- body_statements. + + :return: The AgentSpeak string representation of this plan. + """ + assert isinstance(self.trigger_literal, AstLiteral) + + indent = " " * 6 + colon = " : " + arrow = " <- " + + lines = [] + + lines.append(f"{self.type.value}{self.trigger_literal}") + + if self.context: + lines.append(colon + f" &\n{indent}".join(str(c) for c in self.context)) + + if self.body: + lines.append(arrow + f";\n{indent}".join(str(s) for s in self.body) + ".") + + lines.append("") + + return "\n".join(lines) + + +@dataclass +class AstProgram(AstNode): + """ + Represents a full AgentSpeak program, consisting of rules and plans. + + This is the root node of the AgentSpeak AST, containing all the rules + and plans that define the agent's behavior. 
+ + :ivar rules: A list of inference rules in this program. + :ivar plans: A list of reactive plans in this program. + """ + + rules: list[AstRule] = field(default_factory=list) + plans: list[AstPlan] = field(default_factory=list) + + def _to_agentspeak(self) -> str: + """ + Converts this program to its AgentSpeak string representation. + + The representation consists of all rules followed by all plans, + separated by blank lines for readability. + + :return: The complete AgentSpeak source code for this program. + """ + lines = [] + lines.extend(map(str, self.rules)) + + lines.extend(["", ""]) + lines.extend(map(str, self.plans)) + + return "\n".join(lines) diff --git a/src/control_backend/agents/bdi/agentspeak_generator.py b/src/control_backend/agents/bdi/agentspeak_generator.py new file mode 100644 index 0000000..7c9d8f0 --- /dev/null +++ b/src/control_backend/agents/bdi/agentspeak_generator.py @@ -0,0 +1,881 @@ +from functools import singledispatchmethod + +from slugify import slugify + +from control_backend.agents.bdi.agentspeak_ast import ( + AstAtom, + AstBinaryOp, + AstExpression, + AstLiteral, + AstNumber, + AstPlan, + AstProgram, + AstRule, + AstStatement, + AstString, + AstVar, + BinaryOperatorType, + StatementType, + TriggerType, +) +from control_backend.core.config import settings +from control_backend.schemas.program import ( + BaseGoal, + BasicNorm, + ConditionalNorm, + GestureAction, + Goal, + InferredBelief, + KeywordBelief, + LLMAction, + LogicalOperator, + Norm, + Phase, + PlanElement, + Program, + ProgramElement, + SemanticBelief, + SpeechAction, + Trigger, +) + + +class AgentSpeakGenerator: + """ + Generator class that translates a high-level :class:`~control_backend.schemas.program.Program` + into AgentSpeak(L) source code. + + It handles the conversion of phases, norms, goals, and triggers into AgentSpeak rules and plans, + ensuring the robot follows the defined behavioral logic. + + The generator follows a systematic approach: + 1. 
Sets up initial phase and cycle notification rules + 2. Adds keyword inference capabilities for natural language processing + 3. Creates default plans for common operations + 4. Processes each phase with its norms, goals, and triggers + 5. Adds fallback plans for robust execution + + :ivar _asp: The internal AgentSpeak program representation being built. + """ + + _asp: AstProgram + + def generate(self, program: Program) -> str: + """ + Translates a Program object into an AgentSpeak source string. + + This is the main entry point for the code generation process. It initializes + the AgentSpeak program structure and orchestrates the conversion of all + program elements into their AgentSpeak representations. + + :param program: The behavior program to translate. + :return: The generated AgentSpeak code as a string. + """ + self._asp = AstProgram() + + if program.phases: + self._asp.rules.append(AstRule(self._astify(program.phases[0]))) + else: + self._asp.rules.append(AstRule(AstLiteral("phase", [AstString("end")]))) + + self._asp.rules.append(AstRule(AstLiteral("!notify_cycle"))) + + self._add_keyword_inference() + self._add_default_plans() + + self._process_phases(program.phases) + + self._add_fallbacks() + + return str(self._asp) + + def _add_keyword_inference(self) -> None: + """ + Adds inference rules for keyword detection in user messages. + + This method creates rules that allow the system to detect when specific + keywords are mentioned in user messages. It uses string operations to + check if a keyword is a substring of the user's message. + + The generated rule has the form: + keyword_said(Keyword) :- user_said(Message) & .substring(Keyword, Message, Pos) & Pos >= 0 + + This enables the system to trigger behaviors based on keyword detection. 
+ """ + keyword = AstVar("Keyword") + message = AstVar("Message") + position = AstVar("Pos") + + self._asp.rules.append( + AstRule( + AstLiteral("keyword_said", [keyword]), + AstLiteral("user_said", [message]) + & AstLiteral(".substring", [keyword, message, position]) + & (position >= 0), + ) + ) + + def _add_default_plans(self): + """ + Adds default plans for common operations. + + This method sets up the standard plans that handle fundamental operations + like replying with goals, simple speech actions, general replies, and + cycle notifications. These plans provide the basic infrastructure for + the agent's reactive behavior. + """ + self._add_reply_with_goal_plan() + self._add_say_plan() + self._add_reply_plan() + self._add_notify_cycle_plan() + + def _add_reply_with_goal_plan(self): + """ + Adds a plan for replying with a specific conversational goal. + + This plan handles the case where the agent needs to respond to user input + while pursuing a specific conversational goal. It: + 1. Marks that the agent has responded this turn + 2. Gathers all active norms + 3. Generates a reply that considers both the user message and the goal + + Trigger: +!reply_with_goal(Goal) + Context: user_said(Message) + """ + self._asp.plans.append( + AstPlan( + TriggerType.ADDED_GOAL, + AstLiteral("reply_with_goal", [AstVar("Goal")]), + [AstLiteral("user_said", [AstVar("Message")])], + [ + AstStatement(StatementType.ADD_BELIEF, AstLiteral("responded_this_turn")), + AstStatement( + StatementType.DO_ACTION, + AstLiteral( + "findall", + [AstVar("Norm"), AstLiteral("norm", [AstVar("Norm")]), AstVar("Norms")], + ), + ), + AstStatement( + StatementType.DO_ACTION, + AstLiteral( + "reply_with_goal", [AstVar("Message"), AstVar("Norms"), AstVar("Goal")] + ), + ), + ], + ) + ) + + def _add_say_plan(self): + """ + Adds a plan for simple speech actions. + + This plan handles direct speech actions where the agent needs to say + a specific text. It: + 1. 
Marks that the agent has responded this turn + 2. Executes the speech action + + Trigger: +!say(Text) + Context: None (can be executed anytime) + """ + self._asp.plans.append( + AstPlan( + TriggerType.ADDED_GOAL, + AstLiteral("say", [AstVar("Text")]), + [], + [ + AstStatement(StatementType.ADD_BELIEF, AstLiteral("responded_this_turn")), + AstStatement(StatementType.DO_ACTION, AstLiteral("say", [AstVar("Text")])), + ], + ) + ) + + def _add_reply_plan(self): + """ + Adds a plan for general reply actions. + + This plan handles general reply actions where the agent needs to respond + to user input without a specific conversational goal. It: + 1. Marks that the agent has responded this turn + 2. Gathers all active norms + 3. Generates a reply based on the user message and norms + + Trigger: +!reply + Context: user_said(Message) + """ + self._asp.plans.append( + AstPlan( + TriggerType.ADDED_GOAL, + AstLiteral("reply"), + [AstLiteral("user_said", [AstVar("Message")])], + [ + AstStatement(StatementType.ADD_BELIEF, AstLiteral("responded_this_turn")), + AstStatement( + StatementType.DO_ACTION, + AstLiteral( + "findall", + [AstVar("Norm"), AstLiteral("norm", [AstVar("Norm")]), AstVar("Norms")], + ), + ), + AstStatement( + StatementType.DO_ACTION, + AstLiteral("reply", [AstVar("Message"), AstVar("Norms")]), + ), + ], + ) + ) + + def _add_notify_cycle_plan(self): + """ + Adds a plan for cycle notification. + + This plan handles the periodic notification cycle that allows the system + to monitor and report on the current state. It: + 1. Gathers all active norms + 2. Notifies the system about the current norms + 3. Waits briefly to allow processing + 4. 
Recursively triggers the next cycle + + Trigger: +!notify_cycle + Context: None (can be executed anytime) + """ + self._asp.plans.append( + AstPlan( + TriggerType.ADDED_GOAL, + AstLiteral("notify_cycle"), + [], + [ + AstStatement( + StatementType.DO_ACTION, + AstLiteral( + "findall", + [AstVar("Norm"), AstLiteral("norm", [AstVar("Norm")]), AstVar("Norms")], + ), + ), + AstStatement( + StatementType.DO_ACTION, AstLiteral("notify_norms", [AstVar("Norms")]) + ), + AstStatement(StatementType.DO_ACTION, AstLiteral("wait", [AstNumber(100)])), + AstStatement(StatementType.ACHIEVE_GOAL, AstLiteral("notify_cycle")), + ], + ) + ) + + def _process_phases(self, phases: list[Phase]) -> None: + """ + Processes all phases in the program and their transitions. + + This method iterates through each phase and: + 1. Processes the current phase (norms, goals, triggers) + 2. Sets up transitions between phases + 3. Adds special handling for the end phase + + :param phases: The list of phases to process. + """ + for curr_phase, next_phase in zip([None] + phases, phases + [None], strict=True): + if curr_phase: + self._process_phase(curr_phase) + self._add_phase_transition(curr_phase, next_phase) + + # End phase behavior + # When deleting this, the entire `reply` plan and action can be deleted + self._asp.plans.append( + AstPlan( + type=TriggerType.ADDED_BELIEF, + trigger_literal=AstLiteral("user_said", [AstVar("Message")]), + context=[AstLiteral("phase", [AstString("end")])], + body=[ + AstStatement( + StatementType.DO_ACTION, AstLiteral("notify_user_said", [AstVar("Message")]) + ), + AstStatement(StatementType.ACHIEVE_GOAL, AstLiteral("reply")), + ], + ) + ) + + def _process_phase(self, phase: Phase) -> None: + """ + Processes a single phase, including its norms, goals, and triggers. + + This method handles the complete processing of a phase by: + 1. Processing all norms in the phase + 2. Setting up the default execution loop for the phase + 3. Processing all goals in sequence + 4. 
Processing all triggers for reactive behavior + + :param phase: The phase to process. + """ + for norm in phase.norms: + self._process_norm(norm, phase) + + self._add_default_loop(phase) + + previous_goal = None + for goal in phase.goals: + self._process_goal(goal, phase, previous_goal, main_goal=True) + previous_goal = goal + + for trigger in phase.triggers: + self._process_trigger(trigger, phase) + + def _add_phase_transition(self, from_phase: Phase | None, to_phase: Phase | None) -> None: + """ + Adds plans for transitioning between phases. + + This method creates two plans for each phase transition: + 1. A check plan that verifies if transition conditions are met + 2. A force plan that actually performs the transition (can be forced externally) + + The transition involves: + - Notifying the system about the phase change + - Removing the current phase belief + - Adding the next phase belief + + :param from_phase: The phase being transitioned from (or None for initial setup). + :param to_phase: The phase being transitioned to (or None for end phase). 
+ """ + if from_phase is None: + return + from_phase_ast = self._astify(from_phase) + to_phase_ast = ( + self._astify(to_phase) if to_phase else AstLiteral("phase", [AstString("end")]) + ) + + check_context = [from_phase_ast] + if from_phase: + for goal in from_phase.goals: + check_context.append(self._astify(goal, achieved=True)) + + force_context = [from_phase_ast] + + body = [ + AstStatement( + StatementType.DO_ACTION, + AstLiteral( + "notify_transition_phase", + [ + AstString(str(from_phase.id)), + AstString(str(to_phase.id) if to_phase else "end"), + ], + ), + ), + AstStatement(StatementType.REMOVE_BELIEF, from_phase_ast), + AstStatement(StatementType.ADD_BELIEF, to_phase_ast), + ] + + # Check + self._asp.plans.append( + AstPlan( + TriggerType.ADDED_GOAL, + AstLiteral("transition_phase"), + check_context, + [ + AstStatement(StatementType.ACHIEVE_GOAL, AstLiteral("force_transition_phase")), + ], + ) + ) + + # Force + self._asp.plans.append( + AstPlan( + TriggerType.ADDED_GOAL, AstLiteral("force_transition_phase"), force_context, body + ) + ) + + def _process_norm(self, norm: Norm, phase: Phase) -> None: + """ + Processes a norm and adds it as an inference rule. + + This method converts norms into AgentSpeak rules that define when + the norm should be active. It handles both basic norms (always active + in their phase) and conditional norms (active only when their condition + is met). + + :param norm: The norm to process. + :param phase: The phase this norm belongs to. + """ + rule: AstRule | None = None + + match norm: + case ConditionalNorm(condition=cond): + rule = AstRule( + self._astify(norm), + self._astify(phase) & self._astify(cond) + | AstAtom(f"force_{self.slugify(norm)}"), + ) + case BasicNorm(): + rule = AstRule(self._astify(norm), self._astify(phase)) + + if not rule: + return + + self._asp.rules.append(rule) + + def _add_default_loop(self, phase: Phase) -> None: + """ + Adds the default execution loop for a phase. 
+ + This method creates the main reactive loop that runs when the agent + receives user input during a phase. The loop: + 1. Notifies the system about the user input + 2. Resets the response tracking + 3. Executes all phase goals + 4. Attempts phase transition + + :param phase: The phase to create the loop for. + """ + actions = [] + + actions.append( + AstStatement( + StatementType.DO_ACTION, AstLiteral("notify_user_said", [AstVar("Message")]) + ) + ) + actions.append(AstStatement(StatementType.REMOVE_BELIEF, AstLiteral("responded_this_turn"))) + + for goal in phase.goals: + actions.append(AstStatement(StatementType.ACHIEVE_GOAL, self._astify(goal))) + + actions.append(AstStatement(StatementType.ACHIEVE_GOAL, AstLiteral("transition_phase"))) + + self._asp.plans.append( + AstPlan( + TriggerType.ADDED_BELIEF, + AstLiteral("user_said", [AstVar("Message")]), + [self._astify(phase)], + actions, + ) + ) + + def _process_goal( + self, + goal: Goal, + phase: Phase, + previous_goal: Goal | None = None, + continues_response: bool = False, + main_goal: bool = False, + ) -> None: + """ + Processes a goal and creates plans for achieving it. + + This method creates two plans for each goal: + 1. A main plan that executes the goal's steps when conditions are met + 2. A fallback plan that provides a default empty implementation (prevents crashes) + + The method also recursively processes any subgoals contained within + the goal's plan. + + :param goal: The goal to process. + :param phase: The phase this goal belongs to. + :param previous_goal: The previous goal in sequence (for dependency tracking). + :param continues_response: Whether this goal continues an existing response. + :param main_goal: Whether this is a main goal (for UI notification purposes). 
+ """ + context: list[AstExpression] = [self._astify(phase)] + context.append(~self._astify(goal, achieved=True)) + if previous_goal and previous_goal.can_fail: + context.append(self._astify(previous_goal, achieved=True)) + if not continues_response: + context.append(~AstLiteral("responded_this_turn")) + + body = [] + if main_goal: # UI only needs to know about the main goals + body.append( + AstStatement( + StatementType.DO_ACTION, + AstLiteral("notify_goal_start", [AstString(self.slugify(goal))]), + ) + ) + + subgoals = [] + for step in goal.plan.steps: + body.append(self._step_to_statement(step)) + if isinstance(step, Goal): + subgoals.append(step) + + if not goal.can_fail and not continues_response: + body.append(AstStatement(StatementType.ADD_BELIEF, self._astify(goal, achieved=True))) + + self._asp.plans.append(AstPlan(TriggerType.ADDED_GOAL, self._astify(goal), context, body)) + + self._asp.plans.append( + AstPlan( + TriggerType.ADDED_GOAL, + self._astify(goal), + context=[], + body=[AstStatement(StatementType.EMPTY, AstLiteral("true"))], + ) + ) + + prev_goal = None + for subgoal in subgoals: + self._process_goal(subgoal, phase, prev_goal) + prev_goal = subgoal + + def _step_to_statement(self, step: PlanElement) -> AstStatement: + """ + Converts a plan step to an AgentSpeak statement. + + This method transforms different types of plan elements into their + corresponding AgentSpeak statements. Goals and speech-related actions + become achieve-goal statements, while gesture actions become do-action + statements. + + :param step: The plan element to convert. + :return: The corresponding AgentSpeak statement. 
+ """ + match step: + # Note that SpeechAction gets included in the ACHIEVE_GOAL, since it's a goal internally + case Goal() | SpeechAction() | LLMAction() as a: + return AstStatement(StatementType.ACHIEVE_GOAL, self._astify(a)) + case GestureAction() as a: + return AstStatement(StatementType.DO_ACTION, self._astify(a)) + + def _process_trigger(self, trigger: Trigger, phase: Phase) -> None: + """ + Processes a trigger and creates plans for its execution. + + This method creates plans that execute when trigger conditions are met. + It handles both automatic triggering (when conditions are detected) and + manual forcing (from UI). The trigger execution includes: + 1. Notifying the system about trigger start + 2. Executing all trigger steps + 3. Waiting briefly for UI display + 4. Notifying the system about trigger end + + :param trigger: The trigger to process. + :param phase: The phase this trigger belongs to. + """ + body = [] + subgoals = [] + + body.append( + AstStatement( + StatementType.DO_ACTION, + AstLiteral("notify_trigger_start", [AstString(self.slugify(trigger))]), + ) + ) + for step in trigger.plan.steps: + body.append(self._step_to_statement(step)) + if isinstance(step, Goal): + step.can_fail = False # triggers are continuous sequence + subgoals.append(step) + + # Arbitrary wait for UI to display nicely + body.append( + AstStatement( + StatementType.DO_ACTION, + AstLiteral("wait", [AstNumber(settings.behaviour_settings.trigger_time_to_wait)]), + ) + ) + + body.append( + AstStatement( + StatementType.DO_ACTION, + AstLiteral("notify_trigger_end", [AstString(self.slugify(trigger))]), + ) + ) + + self._asp.plans.append( + AstPlan( + TriggerType.ADDED_GOAL, + AstLiteral("check_triggers"), + [self._astify(phase), self._astify(trigger.condition)], + body, + ) + ) + + # Force trigger (from UI) + self._asp.plans.append(AstPlan(TriggerType.ADDED_GOAL, self._astify(trigger), [], body)) + + for subgoal in subgoals: + self._process_goal(subgoal, phase, 
continues_response=True) + + def _add_fallbacks(self): + """ + Adds fallback plans for robust execution, preventing crashes. + + This method creates fallback plans that provide default empty implementations + for key goals. These fallbacks ensure that the system can continue execution + even when no specific plans are applicable, preventing crashes. + + The fallbacks are created for: + - check_triggers: When no triggers are applicable + - transition_phase: When phase transition conditions aren't met + - force_transition_phase: When forced transitions aren't possible + """ + # Trigger fallback + self._asp.plans.append( + AstPlan( + TriggerType.ADDED_GOAL, + AstLiteral("check_triggers"), + [], + [AstStatement(StatementType.EMPTY, AstLiteral("true"))], + ) + ) + + # Phase transition fallback + self._asp.plans.append( + AstPlan( + TriggerType.ADDED_GOAL, + AstLiteral("transition_phase"), + [], + [AstStatement(StatementType.EMPTY, AstLiteral("true"))], + ) + ) + + # Force phase transition fallback + self._asp.plans.append( + AstPlan( + TriggerType.ADDED_GOAL, + AstLiteral("force_transition_phase"), + [], + [AstStatement(StatementType.EMPTY, AstLiteral("true"))], + ) + ) + + @singledispatchmethod + def _astify(self, element: ProgramElement) -> AstExpression: + """ + Converts program elements to AgentSpeak expressions (base method). + + This is the base method for the singledispatch mechanism that handles + conversion of different program element types to their AgentSpeak + representations. Specific implementations are provided for each + element type through registered methods. + + :param element: The program element to convert. + :return: The corresponding AgentSpeak expression. + :raises NotImplementedError: If no specific implementation exists for the element type. 
+ """ + raise NotImplementedError(f"Cannot convert element {element} to an AgentSpeak expression.") + + @_astify.register + def _(self, kwb: KeywordBelief) -> AstExpression: + """ + Converts a KeywordBelief to an AgentSpeak expression. + + Keyword beliefs are converted to keyword_said literals that check + if the keyword was mentioned in user input. + + :param kwb: The KeywordBelief to convert. + :return: An AstLiteral representing the keyword detection. + """ + return AstLiteral("keyword_said", [AstString(kwb.keyword)]) + + @_astify.register + def _(self, sb: SemanticBelief) -> AstExpression: + """ + Converts a SemanticBelief to an AgentSpeak expression. + + Semantic beliefs are converted to literals using their slugified names, + which are used for LLM-based belief evaluation. + + :param sb: The SemanticBelief to convert. + :return: An AstLiteral representing the semantic belief. + """ + return AstLiteral(self.slugify(sb)) + + @_astify.register + def _(self, ib: InferredBelief) -> AstExpression: + """ + Converts an InferredBelief to an AgentSpeak expression. + + Inferred beliefs are converted to binary operations that combine + their left and right operands using the appropriate logical operator. + + :param ib: The InferredBelief to convert. + :return: An AstBinaryOp representing the logical combination. + """ + return AstBinaryOp( + self._astify(ib.left), + BinaryOperatorType.AND if ib.operator == LogicalOperator.AND else BinaryOperatorType.OR, + self._astify(ib.right), + ) + + @_astify.register + def _(self, norm: Norm) -> AstExpression: + """ + Converts a Norm to an AgentSpeak expression. + + Norms are converted to literals with either 'norm' or 'critical_norm' + functors depending on their critical flag, with the norm text as an argument. + + Note that currently, critical norms are not yet functionally supported. They are possible + to astify for future use. + + :param norm: The Norm to convert. + :return: An AstLiteral representing the norm. 
+ """ + functor = "critical_norm" if norm.critical else "norm" + return AstLiteral(functor, [AstString(norm.norm)]) + + @_astify.register + def _(self, phase: Phase) -> AstExpression: + """ + Converts a Phase to an AgentSpeak expression. + + Phases are converted to phase literals with their unique identifier + as an argument, which is used for phase tracking and transitions. + + :param phase: The Phase to convert. + :return: An AstLiteral representing the phase. + """ + return AstLiteral("phase", [AstString(str(phase.id))]) + + @_astify.register + def _(self, goal: Goal, achieved: bool = False) -> AstExpression: + """ + Converts a Goal to an AgentSpeak expression. + + Goals are converted to literals using their slugified names. If the + achieved parameter is True, the literal is prefixed with 'achieved_'. + + :param goal: The Goal to convert. + :param achieved: Whether to represent this as an achieved goal. + :return: An AstLiteral representing the goal. + """ + return AstLiteral(f"{'achieved_' if achieved else ''}{self._slugify_str(goal.name)}") + + @_astify.register + def _(self, trigger: Trigger) -> AstExpression: + """ + Converts a Trigger to an AgentSpeak expression. + + Triggers are converted to literals using their slugified names, + which are used to identify and execute trigger plans. + + :param trigger: The Trigger to convert. + :return: An AstLiteral representing the trigger. + """ + return AstLiteral(self.slugify(trigger)) + + @_astify.register + def _(self, sa: SpeechAction) -> AstExpression: + """ + Converts a SpeechAction to an AgentSpeak expression. + + Speech actions are converted to say literals with the text content + as an argument, which are used for direct speech output. + + :param sa: The SpeechAction to convert. + :return: An AstLiteral representing the speech action. 
+ """ + return AstLiteral("say", [AstString(sa.text)]) + + @_astify.register + def _(self, ga: GestureAction) -> AstExpression: + """ + Converts a GestureAction to an AgentSpeak expression. + + Gesture actions are converted to gesture literals with the gesture + type and name as arguments, which are used for physical robot gestures. + + :param ga: The GestureAction to convert. + :return: An AstLiteral representing the gesture action. + """ + gesture = ga.gesture + return AstLiteral("gesture", [AstString(gesture.type), AstString(gesture.name)]) + + @_astify.register + def _(self, la: LLMAction) -> AstExpression: + """ + Converts an LLMAction to an AgentSpeak expression. + + LLM actions are converted to reply_with_goal literals with the + conversational goal as an argument, which are used for LLM-generated + responses guided by specific goals. + + :param la: The LLMAction to convert. + :return: An AstLiteral representing the LLM action. + """ + return AstLiteral("reply_with_goal", [AstString(la.goal)]) + + @singledispatchmethod + @staticmethod + def slugify(element: ProgramElement) -> str: + """ + Converts program elements to slugs (base method). + + This is the base method for the singledispatch mechanism that handles + conversion of different program element types to their slug representations. + Specific implementations are provided for each element type through + registered methods. + + Slugs are used outside of AgentSpeak, mostly for identifying what to send to the AgentSpeak + program as beliefs. + + :param element: The program element to convert to a slug. + :return: The slug string representation. + :raises NotImplementedError: If no specific implementation exists for the element type. + """ + raise NotImplementedError(f"Cannot convert element {element} to a slug.") + + @slugify.register + @staticmethod + def _(n: Norm) -> str: + """ + Converts a Norm to a slug. + + Norms are converted to slugs with the 'norm_' prefix followed by + the slugified norm text. 
+ + :param n: The Norm to convert. + :return: The slug string representation. + """ + return f"norm_{AgentSpeakGenerator._slugify_str(n.norm)}" + + @slugify.register + @staticmethod + def _(sb: SemanticBelief) -> str: + """ + Converts a SemanticBelief to a slug. + + Semantic beliefs are converted to slugs with the 'semantic_' prefix + followed by the slugified belief name. + + :param sb: The SemanticBelief to convert. + :return: The slug string representation. + """ + return f"semantic_{AgentSpeakGenerator._slugify_str(sb.name)}" + + @slugify.register + @staticmethod + def _(g: BaseGoal) -> str: + """ + Converts a BaseGoal to a slug. + + Goals are converted to slugs using their slugified names directly. + + :param g: The BaseGoal to convert. + :return: The slug string representation. + """ + return AgentSpeakGenerator._slugify_str(g.name) + + @slugify.register + @staticmethod + def _(t: Trigger) -> str: + """ + Converts a Trigger to a slug. + + Triggers are converted to slugs with the 'trigger_' prefix followed by + the slugified trigger name. + + :param t: The Trigger to convert. + :return: The slug string representation. + """ + return f"trigger_{AgentSpeakGenerator._slugify_str(t.name)}" + + @staticmethod + def _slugify_str(text: str) -> str: + """ + Converts a text string to a slug. + + This helper method converts arbitrary text to a URL-friendly slug format + by converting to lowercase, removing special characters, and replacing + spaces with underscores. It also removes common stopwords. + + :param text: The text string to convert. + :return: The slugified string. 
+ """ + return slugify(text, separator="_", stopwords=["a", "an", "the", "we", "you", "I"]) diff --git a/src/control_backend/agents/bdi/bdi_core_agent.py b/src/control_backend/agents/bdi/bdi_core_agent.py index f056e09..628bb53 100644 --- a/src/control_backend/agents/bdi/bdi_core_agent.py +++ b/src/control_backend/agents/bdi/bdi_core_agent.py @@ -1,5 +1,6 @@ import asyncio import copy +import json import time from collections.abc import Iterable @@ -11,9 +12,9 @@ from pydantic import ValidationError from control_backend.agents.base import BaseAgent from control_backend.core.agent_system import InternalMessage from control_backend.core.config import settings -from control_backend.schemas.belief_message import Belief, BeliefMessage +from control_backend.schemas.belief_message import BeliefMessage from control_backend.schemas.llm_prompt_message import LLMPromptMessage -from control_backend.schemas.ri_message import SpeechCommand +from control_backend.schemas.ri_message import GestureCommand, RIEndpoint, SpeechCommand DELIMITER = ";\n" # TODO: temporary until we support lists in AgentSpeak @@ -42,13 +43,13 @@ class BDICoreAgent(BaseAgent): bdi_agent: agentspeak.runtime.Agent - def __init__(self, name: str, asl: str): + def __init__(self, name: str): super().__init__(name) - self.asl_file = asl self.env = agentspeak.runtime.Environment() # Deep copy because we don't actually want to modify the standard actions globally self.actions = copy.deepcopy(agentspeak.stdlib.actions) self._wake_bdi_loop = asyncio.Event() + self._bdi_loop_task = None async def setup(self) -> None: """ @@ -65,19 +66,22 @@ class BDICoreAgent(BaseAgent): await self._load_asl() # Start the BDI cycle loop - self.add_behavior(self._bdi_loop()) + self._bdi_loop_task = self.add_behavior(self._bdi_loop()) self._wake_bdi_loop.set() self.logger.debug("Setup complete.") - async def _load_asl(self): + async def _load_asl(self, file_name: str | None = None) -> None: """ Load and parse the AgentSpeak source 
file. """ + file_name = file_name or "src/control_backend/agents/bdi/default_behavior.asl" + try: - with open(self.asl_file) as source: + with open(file_name) as source: self.bdi_agent = self.env.build_agent(source, self.actions) + self.logger.info(f"Loaded new ASL from {file_name}.") except FileNotFoundError: - self.logger.warning(f"Could not find the specified ASL file at {self.asl_file}.") + self.logger.warning(f"Could not find the specified ASL file at {file_name}.") self.bdi_agent = agentspeak.runtime.Agent(self.env, self.name) async def _bdi_loop(self): @@ -97,14 +101,12 @@ class BDICoreAgent(BaseAgent): maybe_more_work = True while maybe_more_work: maybe_more_work = False - self.logger.debug("Stepping BDI.") if self.bdi_agent.step(): maybe_more_work = True if not maybe_more_work: deadline = self.bdi_agent.shortest_deadline() if deadline: - self.logger.debug("Sleeping until %s", deadline) await asyncio.sleep(deadline - time.time()) maybe_more_work = True else: @@ -116,6 +118,7 @@ class BDICoreAgent(BaseAgent): Handle incoming messages. - **Beliefs**: Updates the internal belief base. + - **Program**: Updates the internal agentspeak file to match the current program. - **LLM Responses**: Forwards the generated text to the Robot Speech Agent (actuation). :param msg: The received internal message. 
@@ -124,12 +127,19 @@ class BDICoreAgent(BaseAgent): if msg.thread == "beliefs": try: - beliefs = BeliefMessage.model_validate_json(msg.body).beliefs - self._apply_beliefs(beliefs) + belief_changes = BeliefMessage.model_validate_json(msg.body) + self._apply_belief_changes(belief_changes) except ValidationError: self.logger.exception("Error processing belief.") return + # New agentspeak file + if msg.thread == "new_program": + if self._bdi_loop_task: + self._bdi_loop_task.cancel() + await self._load_asl(msg.body) + self.add_behavior(self._bdi_loop()) + # The message was not a belief, handle special cases based on sender match msg.sender: case settings.agent_settings.llm_name: @@ -144,23 +154,44 @@ class BDICoreAgent(BaseAgent): body=cmd.model_dump_json(), ) await self.send(out_msg) + case settings.agent_settings.user_interrupt_name: + self.logger.debug("Received user interruption: %s", msg) - def _apply_beliefs(self, beliefs: list[Belief]): + match msg.thread: + case "force_phase_transition": + self._set_goal("transition_phase") + case "force_trigger": + self._force_trigger(msg.body) + case "force_norm": + self._force_norm(msg.body) + case "force_next_phase": + self._force_next_phase() + case _: + self.logger.warning("Received unknown user interruption: %s", msg) + + def _apply_belief_changes(self, belief_changes: BeliefMessage): """ Update the belief base with a list of new beliefs. - If ``replace=True`` is set on a belief, it removes all existing beliefs with that name - before adding the new one. + For beliefs in ``belief_changes.replace``, it removes all existing beliefs with that name + before adding the new one. + + :param belief_changes: The changes in beliefs to apply. 
""" - if not beliefs: + if not belief_changes.create and not belief_changes.replace and not belief_changes.delete: return - for belief in beliefs: - if belief.replace: - self._remove_all_with_name(belief.name) + for belief in belief_changes.create: self._add_belief(belief.name, belief.arguments) - def _add_belief(self, name: str, args: Iterable[str] = []): + for belief in belief_changes.replace: + self._remove_all_with_name(belief.name) + self._add_belief(belief.name, belief.arguments) + + for belief in belief_changes.delete: + self._remove_belief(belief.name, belief.arguments) + + def _add_belief(self, name: str, args: list[str] = None): """ Add a single belief to the BDI agent. @@ -168,9 +199,13 @@ class BDICoreAgent(BaseAgent): :param args: Arguments for the belief. """ # new_args = (agentspeak.Literal(arg) for arg in args) # TODO: Eventually support multiple - merged_args = DELIMITER.join(arg for arg in args) - new_args = (agentspeak.Literal(merged_args),) - term = agentspeak.Literal(name, new_args) + args = args or [] + if args: + merged_args = DELIMITER.join(arg for arg in args) + new_args = (agentspeak.Literal(merged_args),) + term = agentspeak.Literal(name, new_args) + else: + term = agentspeak.Literal(name) self.bdi_agent.call( agentspeak.Trigger.addition, @@ -179,16 +214,35 @@ class BDICoreAgent(BaseAgent): agentspeak.runtime.Intention(), ) + # Check for transitions + self.bdi_agent.call( + agentspeak.Trigger.addition, + agentspeak.GoalType.achievement, + agentspeak.Literal("transition_phase"), + agentspeak.runtime.Intention(), + ) + + # Check triggers + self.bdi_agent.call( + agentspeak.Trigger.addition, + agentspeak.GoalType.achievement, + agentspeak.Literal("check_triggers"), + agentspeak.runtime.Intention(), + ) + self._wake_bdi_loop.set() self.logger.debug(f"Added belief {self.format_belief_string(name, args)}") - def _remove_belief(self, name: str, args: Iterable[str]): + def _remove_belief(self, name: str, args: Iterable[str] | None): """ Removes a 
specific belief (with arguments), if it exists. """ - new_args = (agentspeak.Literal(arg) for arg in args) - term = agentspeak.Literal(name, new_args) + if args is None: + term = agentspeak.Literal(name) + else: + new_args = (agentspeak.Literal(arg) for arg in args) + term = agentspeak.Literal(name, new_args) result = self.bdi_agent.call( agentspeak.Trigger.removal, @@ -228,6 +282,43 @@ class BDICoreAgent(BaseAgent): self.logger.debug(f"Removed {removed_count} beliefs.") + def _set_goal(self, name: str, args: Iterable[str] | None = None): + args = args or [] + + if args: + merged_args = DELIMITER.join(arg for arg in args) + new_args = (agentspeak.Literal(merged_args),) + term = agentspeak.Literal(name, new_args) + else: + term = agentspeak.Literal(name) + + self.bdi_agent.call( + agentspeak.Trigger.addition, + agentspeak.GoalType.achievement, + term, + agentspeak.runtime.Intention(), + ) + + self._wake_bdi_loop.set() + + self.logger.debug(f"Set goal !{self.format_belief_string(name, args)}.") + + def _force_trigger(self, name: str): + self._set_goal(name) + + self.logger.info("Manually forced trigger %s.", name) + + # TODO: make this compatible for critical norms + def _force_norm(self, name: str): + self._add_belief(f"force_{name}") + + self.logger.info("Manually forced norm %s.", name) + + def _force_next_phase(self): + self._set_goal("force_transition_phase") + + self.logger.info("Manually forced phase transition.") + def _add_custom_actions(self) -> None: """ Add any custom actions here. Inside `@self.actions.add()`, the first argument is @@ -235,43 +326,213 @@ class BDICoreAgent(BaseAgent): the function expects (which will be located in `term.args`). """ - @self.actions.add(".reply", 3) - def _reply(agent: "BDICoreAgent", term, intention): + @self.actions.add(".reply", 2) + def _reply(agent, term, intention): """ - Sends text to the LLM (AgentSpeak action). 
- Example: .reply("Hello LLM!", "Some norm", "Some goal") + Let the LLM generate a response to a user's utterance with the current norms and goals. """ message_text = agentspeak.grounded(term.args[0], intention.scope) norms = agentspeak.grounded(term.args[1], intention.scope) - goals = agentspeak.grounded(term.args[2], intention.scope) - self.logger.debug("Norms: %s", norms) - self.logger.debug("Goals: %s", goals) - self.logger.debug("User text: %s", message_text) - - asyncio.create_task(self._send_to_llm(str(message_text), str(norms), str(goals))) + self.add_behavior(self._send_to_llm(str(message_text), str(norms), "")) yield - async def _send_to_llm(self, text: str, norms: str = None, goals: str = None): + @self.actions.add(".reply_with_goal", 3) + def _reply_with_goal(agent: "BDICoreAgent", term, intention): + """ + Let the LLM generate a response to a user's utterance with the current norms and a + specific goal. + """ + message_text = agentspeak.grounded(term.args[0], intention.scope) + norms = agentspeak.grounded(term.args[1], intention.scope) + goal = agentspeak.grounded(term.args[2], intention.scope) + self.add_behavior(self._send_to_llm(str(message_text), str(norms), str(goal))) + yield + + @self.actions.add(".notify_norms", 1) + def _notify_norms(agent, term, intention): + norms = agentspeak.grounded(term.args[0], intention.scope) + + norm_update_message = InternalMessage( + to=settings.agent_settings.user_interrupt_name, + thread="active_norms_update", + body=str(norms), + ) + + self.add_behavior(self.send(norm_update_message, should_log=False)) + yield + + @self.actions.add(".say", 1) + def _say(agent, term, intention): + """ + Make the robot say the given text instantly. 
+ """ + message_text = agentspeak.grounded(term.args[0], intention.scope) + + self.logger.debug('"say" action called with text=%s', message_text) + + speech_command = SpeechCommand(data=message_text) + speech_message = InternalMessage( + to=settings.agent_settings.robot_speech_name, + sender=settings.agent_settings.bdi_core_name, + body=speech_command.model_dump_json(), + ) + + self.add_behavior(self.send(speech_message)) + + chat_history_message = InternalMessage( + to=settings.agent_settings.llm_name, + thread="assistant_message", + body=str(message_text), + ) + + self.add_behavior(self.send(chat_history_message)) + + yield + + @self.actions.add(".gesture", 2) + def _gesture(agent, term, intention): + """ + Make the robot perform the given gesture instantly. + """ + gesture_type = agentspeak.grounded(term.args[0], intention.scope) + gesture_name = agentspeak.grounded(term.args[1], intention.scope) + + self.logger.debug( + '"gesture" action called with type=%s, name=%s', + gesture_type, + gesture_name, + ) + + if str(gesture_type) == "single": + endpoint = RIEndpoint.GESTURE_SINGLE + elif str(gesture_type) == "tag": + endpoint = RIEndpoint.GESTURE_TAG + else: + self.logger.warning("Gesture type %s could not be resolved.", gesture_type) + endpoint = RIEndpoint.GESTURE_SINGLE + + gesture_command = GestureCommand(endpoint=endpoint, data=gesture_name) + gesture_message = InternalMessage( + to=settings.agent_settings.robot_gesture_name, + sender=settings.agent_settings.bdi_core_name, + body=gesture_command.model_dump_json(), + ) + self.add_behavior(self.send(gesture_message)) + yield + + @self.actions.add(".notify_user_said", 1) + def _notify_user_said(agent, term, intention): + user_said = agentspeak.grounded(term.args[0], intention.scope) + + msg = InternalMessage( + to=settings.agent_settings.llm_name, thread="user_message", body=str(user_said) + ) + + self.add_behavior(self.send(msg)) + + yield + + @self.actions.add(".notify_trigger_start", 1) + def 
_notify_trigger_start(agent, term, intention): + """ + Notify the UI about the trigger we just started doing. + """ + trigger_name = agentspeak.grounded(term.args[0], intention.scope) + + self.logger.debug("Started trigger %s", trigger_name) + + msg = InternalMessage( + to=settings.agent_settings.user_interrupt_name, + sender=self.name, + thread="trigger_start", + body=str(trigger_name), + ) + + # TODO: check with Pim + self.add_behavior(self.send(msg)) + + yield + + @self.actions.add(".notify_trigger_end", 1) + def _notify_trigger_end(agent, term, intention): + """ + Notify the UI about the trigger we just finished doing. + """ + trigger_name = agentspeak.grounded(term.args[0], intention.scope) + + self.logger.debug("Finished trigger %s", trigger_name) + + msg = InternalMessage( + to=settings.agent_settings.user_interrupt_name, + sender=self.name, + thread="trigger_end", + body=str(trigger_name), + ) + + self.add_behavior(self.send(msg)) + + yield + + @self.actions.add(".notify_goal_start", 1) + def _notify_goal_start(agent, term, intention): + """ + Notify the UI about the goal we just started chasing. + """ + goal_name = agentspeak.grounded(term.args[0], intention.scope) + + self.logger.debug("Started chasing goal %s", goal_name) + + msg = InternalMessage( + to=settings.agent_settings.user_interrupt_name, + sender=self.name, + thread="goal_start", + body=str(goal_name), + ) + + self.add_behavior(self.send(msg)) + + yield + + @self.actions.add(".notify_transition_phase", 2) + def _notify_transition_phase(agent, term, intention): + """ + Notify the BDI program manager about a phase transition. 
+ """ + old = agentspeak.grounded(term.args[0], intention.scope) + new = agentspeak.grounded(term.args[1], intention.scope) + + msg = InternalMessage( + to=settings.agent_settings.bdi_program_manager_name, + thread="transition_phase", + body=json.dumps({"old": str(old), "new": str(new)}), + ) + + self.add_behavior(self.send(msg)) + + yield + + @self.actions.add(".notify_ui", 0) + def _notify_ui(agent, term, intention): + pass + + async def _send_to_llm(self, text: str, norms: str, goals: str): """ Sends a text query to the LLM agent asynchronously. """ - prompt = LLMPromptMessage( - text=text, - norms=norms.split("\n") if norms else [], - goals=goals.split("\n") if norms else [], - ) + prompt = LLMPromptMessage(text=text, norms=norms.split("\n"), goals=goals.split("\n")) msg = InternalMessage( to=settings.agent_settings.llm_name, sender=self.name, body=prompt.model_dump_json(), + thread="prompt_message", ) await self.send(msg) self.logger.info("Message sent to LLM agent: %s", text) @staticmethod - def format_belief_string(name: str, args: Iterable[str] = []): + def format_belief_string(name: str, args: Iterable[str] | None = []): """ Given a belief's name and its args, return a string of the form "name(*args)" """ - return f"{name}{'(' if args else ''}{','.join(args)}{')' if args else ''}" + return f"{name}{'(' if args else ''}{','.join(args or [])}{')' if args else ''}" diff --git a/src/control_backend/agents/bdi/bdi_program_manager.py b/src/control_backend/agents/bdi/bdi_program_manager.py index 2f4f850..54c9983 100644 --- a/src/control_backend/agents/bdi/bdi_program_manager.py +++ b/src/control_backend/agents/bdi/bdi_program_manager.py @@ -1,12 +1,23 @@ +import asyncio +import json + import zmq from pydantic import ValidationError from zmq.asyncio import Context from control_backend.agents import BaseAgent -from control_backend.core.agent_system import InternalMessage +from control_backend.agents.bdi.agentspeak_generator import AgentSpeakGenerator from 
control_backend.core.config import settings -from control_backend.schemas.belief_message import Belief, BeliefMessage -from control_backend.schemas.program import Program +from control_backend.schemas.belief_list import BeliefList, GoalList +from control_backend.schemas.internal_message import InternalMessage +from control_backend.schemas.program import ( + Belief, + ConditionalNorm, + Goal, + InferredBelief, + Phase, + Program, +) class BDIProgramManager(BaseAgent): @@ -14,51 +25,236 @@ class BDIProgramManager(BaseAgent): BDI Program Manager Agent. This agent is responsible for receiving high-level programs (sequences of instructions/goals) - from the external HTTP API (via ZMQ) and translating them into core beliefs (norms and goals) - for the BDI Core Agent. In the future, it will be responsible for determining when goals are - met, and passing on new norms and goals accordingly. + from the external HTTP API (via ZMQ), transforming it into an AgentSpeak program, sharing the + program and its components to other agents, and keeping agents informed of the current state. :ivar sub_socket: The ZMQ SUB socket used to receive program updates. + :ivar _program: The current Program. + :ivar _phase: The current Phase. + :ivar _goal_mapping: A mapping of goal IDs to goals. """ + _program: Program + _phase: Phase | None + def __init__(self, **kwargs): super().__init__(**kwargs) self.sub_socket = None + self._goal_mapping: dict[str, Goal] = {} - async def _send_to_bdi(self, program: Program): + def _initialize_internal_state(self, program: Program): """ - Convert a received program into BDI beliefs and send them to the BDI Core Agent. + Initialize the state of the program manager given a new Program. Reset the tracking of the + current phase to the first phase, make a mapping of goal IDs to goals, used during the life + of the program. + :param program: The new program. 
+ """ + self._program = program + self._phase = program.phases[0] # start in first phase + self._goal_mapping = {} + for phase in program.phases: + for goal in phase.goals: + self._populate_goal_mapping_with_goal(goal) - Currently, it takes the **first phase** of the program and extracts: - - **Norms**: Constraints or rules the agent must follow. - - **Goals**: Objectives the agent must achieve. + def _populate_goal_mapping_with_goal(self, goal: Goal): + """ + Recurse through the given goal and its subgoals and add all goals found to the + ``self._goal_mapping``. + :param goal: The goal to add to the ``self._goal_mapping``, including subgoals. + """ + self._goal_mapping[str(goal.id)] = goal + for step in goal.plan.steps: + if isinstance(step, Goal): + self._populate_goal_mapping_with_goal(step) - These are sent as a ``BeliefMessage`` with ``replace=True``, meaning they will - overwrite any existing norms/goals of the same name in the BDI agent. + async def _create_agentspeak_and_send_to_bdi(self, program: Program): + """ + Convert a received program into an AgentSpeak file and send it to the BDI Core Agent. :param program: The program object received from the API. 
""" - first_phase = program.phases[0] - norms_belief = Belief( - name="norms", - arguments=[norm.norm for norm in first_phase.norms], - replace=True, + asg = AgentSpeakGenerator() + + asl_str = asg.generate(program) + + file_name = settings.behaviour_settings.agentspeak_file + + with open(file_name, "w") as f: + f.write(asl_str) + + msg = InternalMessage( + sender=self.name, + to=settings.agent_settings.bdi_core_name, + body=file_name, + thread="new_program", ) - goals_belief = Belief( - name="goals", - arguments=[goal.description for goal in first_phase.goals], - replace=True, + + await self.send(msg) + + async def handle_message(self, msg: InternalMessage): + match msg.thread: + case "transition_phase": + phases = json.loads(msg.body) + + await self._transition_phase(phases["old"], phases["new"]) + case "achieve_goal": + goal_id = msg.body + await self._send_achieved_goal_to_semantic_belief_extractor(goal_id) + + async def _transition_phase(self, old: str, new: str): + """ + When receiving a signal from the BDI core that the phase has changed, apply this change to + the current state and inform other agents about the change. + + :param old: The ID of the old phase. + :param new: The ID of the new phase. + """ + if self._phase is None: + return + + if old != str(self._phase.id): + self.logger.warning( + f"Phase transition desync detected! ASL requested move from '{old}', " + f"but Python is currently in '{self._phase.id}'. Request ignored." 
+ ) + return + + if new == "end": + self._phase = None + # Notify user interaction agent + msg = InternalMessage( + to=settings.agent_settings.user_interrupt_name, + thread="transition_phase", + body="end", + ) + self.logger.info("Transitioned to end phase, notifying UserInterruptAgent.") + + self.add_behavior(self.send(msg)) + return + + for phase in self._program.phases: + if str(phase.id) == new: + self._phase = phase + + await self._send_beliefs_to_semantic_belief_extractor() + await self._send_goals_to_semantic_belief_extractor() + + # Notify user interaction agent + msg = InternalMessage( + to=settings.agent_settings.user_interrupt_name, + thread="transition_phase", + body=str(self._phase.id), ) - program_beliefs = BeliefMessage(beliefs=[norms_belief, goals_belief]) + self.logger.info(f"Transitioned to phase {new}, notifying UserInterruptAgent.") + + self.add_behavior(self.send(msg)) + + def _extract_current_beliefs(self) -> list[Belief]: + """Extract beliefs from the current phase.""" + assert self._phase is not None, ( + "Invalid state, no phase set. Call this method only when " + "a program has been received and the end-phase has not " + "been reached." 
+ ) + + beliefs: list[Belief] = [] + + for norm in self._phase.norms: + if isinstance(norm, ConditionalNorm): + beliefs += self._extract_beliefs_from_belief(norm.condition) + + for trigger in self._phase.triggers: + beliefs += self._extract_beliefs_from_belief(trigger.condition) + + return beliefs + + @staticmethod + def _extract_beliefs_from_belief(belief: Belief) -> list[Belief]: + """Recursively extract beliefs from the given belief.""" + if isinstance(belief, InferredBelief): + return BDIProgramManager._extract_beliefs_from_belief( + belief.left + ) + BDIProgramManager._extract_beliefs_from_belief(belief.right) + return [belief] + + async def _send_beliefs_to_semantic_belief_extractor(self): + """Extract beliefs from the program and send them to the Semantic Belief Extractor Agent.""" + beliefs = BeliefList(beliefs=self._extract_current_beliefs()) message = InternalMessage( - to=settings.agent_settings.bdi_core_name, + to=settings.agent_settings.text_belief_extractor_name, sender=self.name, - body=program_beliefs.model_dump_json(), + body=beliefs.model_dump_json(), thread="beliefs", ) + + await self.send(message) + + @staticmethod + def _extract_goals_from_goal(goal: Goal) -> list[Goal]: + """ + Extract all goals from a given goal, that is: the goal itself and any subgoals. + + :return: All goals within and including the given goal. + """ + goals: list[Goal] = [goal] + for plan in goal.plan: + if isinstance(plan, Goal): + goals.extend(BDIProgramManager._extract_goals_from_goal(plan)) + return goals + + def _extract_current_goals(self) -> list[Goal]: + """ + Extract all goals from the program, including subgoals. + + :return: A list of Goal objects. + """ + assert self._phase is not None, ( + "Invalid state, no phase set. Call this method only when " + "a program has been received and the end-phase has not " + "been reached." 
+ ) + + goals: list[Goal] = [] + + for goal in self._phase.goals: + goals.extend(self._extract_goals_from_goal(goal)) + + return goals + + async def _send_goals_to_semantic_belief_extractor(self): + """ + Extract goals for the current phase and send them to the Semantic Belief Extractor Agent. + """ + goals = GoalList(goals=self._extract_current_goals()) + + message = InternalMessage( + to=settings.agent_settings.text_belief_extractor_name, + sender=self.name, + body=goals.model_dump_json(), + thread="goals", + ) + + await self.send(message) + + async def _send_achieved_goal_to_semantic_belief_extractor(self, achieved_goal_id: str): + """ + Inform the semantic belief extractor when a goal is marked achieved. + + :param achieved_goal_id: The id of the achieved goal. + """ + goal = self._goal_mapping.get(achieved_goal_id) + if goal is None: + self.logger.debug(f"Goal with ID {achieved_goal_id} marked achieved but was not found.") + return + + goals = self._extract_goals_from_goal(goal) + message = InternalMessage( + to=settings.agent_settings.text_belief_extractor_name, + body=GoalList(goals=goals).model_dump_json(), + thread="achieved_goals", + ) await self.send(message) - self.logger.debug("Sent new norms and goals to the BDI agent.") async def _send_clear_llm_history(self): """ @@ -68,13 +264,19 @@ class BDIProgramManager(BaseAgent): """ message = InternalMessage( to=settings.agent_settings.llm_name, - sender=self.name, body="clear_history", - threads="clear history message", ) await self.send(message) self.logger.debug("Sent message to LLM agent to clear history.") + extractor_msg = InternalMessage( + to=settings.agent_settings.text_belief_extractor_name, + thread="conversation_history", + body="reset", + ) + await self.send(extractor_msg) + self.logger.debug("Sent message to extractor agent to clear history.") + async def _receive_programs(self): """ Continuous loop that receives program updates from the HTTP endpoint. 
@@ -88,20 +290,44 @@ class BDIProgramManager(BaseAgent): try: program = Program.model_validate_json(body) - await self._send_to_bdi(program) - await self._send_clear_llm_history() - except ValidationError: - self.logger.exception("Received an invalid program.") + self.logger.warning("Received an invalid program.") continue + self._initialize_internal_state(program) + await self._send_program_to_user_interrupt(program) + await self._send_clear_llm_history() + + await asyncio.gather( + self._create_agentspeak_and_send_to_bdi(program), + self._send_beliefs_to_semantic_belief_extractor(), + self._send_goals_to_semantic_belief_extractor(), + ) + + async def _send_program_to_user_interrupt(self, program: Program): + """ + Send the received program to the User Interrupt Agent. + + :param program: The program object received from the API. + """ + msg = InternalMessage( + sender=self.name, + to=settings.agent_settings.user_interrupt_name, + body=program.model_dump_json(), + thread="new_program", + ) + + await self.send(msg) + async def setup(self): """ Initialize the agent. Connects the internal ZMQ SUB socket and subscribes to the 'program' topic. - Starts the background behavior to receive programs. + Starts the background behavior to receive programs. Initializes a default program. 
""" + await self._create_agentspeak_and_send_to_bdi(Program(phases=[])) + context = Context.instance() self.sub_socket = context.socket(zmq.SUB) diff --git a/src/control_backend/agents/bdi/belief_collector_agent.py b/src/control_backend/agents/bdi/belief_collector_agent.py deleted file mode 100644 index 788cff1..0000000 --- a/src/control_backend/agents/bdi/belief_collector_agent.py +++ /dev/null @@ -1,152 +0,0 @@ -import json - -from pydantic import ValidationError - -from control_backend.agents.base import BaseAgent -from control_backend.core.agent_system import InternalMessage -from control_backend.core.config import settings -from control_backend.schemas.belief_message import Belief, BeliefMessage - - -class BDIBeliefCollectorAgent(BaseAgent): - """ - BDI Belief Collector Agent. - - This agent acts as a central aggregator for beliefs derived from various sources (e.g., text, - emotion, vision). It receives raw extracted data from other agents, - normalizes them into valid :class:`Belief` objects, and forwards them as a unified packet to the - BDI Core Agent. - - It serves as a funnel to ensure the BDI agent receives a consistent stream of beliefs. - """ - - async def setup(self): - """ - Initialize the agent. - """ - self.logger.info("Setting up %s", self.name) - - async def handle_message(self, msg: InternalMessage): - """ - Handle incoming messages from other extractor agents. - - Routes the message to specific handlers based on the 'type' field in the JSON body. - Supported types: - - ``belief_extraction_text``: Handled by :meth:`_handle_belief_text` - - ``emotion_extraction_text``: Handled by :meth:`_handle_emo_text` - - :param msg: The received internal message. - """ - sender_node = msg.sender - - # Parse JSON payload - try: - payload = json.loads(msg.body) - except Exception as e: - self.logger.warning( - "BeliefCollector: failed to parse JSON from %s. 
Body=%r Error=%s", - sender_node, - msg.body, - e, - ) - return - - msg_type = payload.get("type") - - # Prefer explicit 'type' field - if msg_type == "belief_extraction_text": - self.logger.debug("Message routed to _handle_belief_text (sender=%s)", sender_node) - await self._handle_belief_text(payload, sender_node) - # This is not implemented yet, but we keep the structure for future use - elif msg_type == "emotion_extraction_text": - self.logger.debug("Message routed to _handle_emo_text (sender=%s)", sender_node) - await self._handle_emo_text(payload, sender_node) - else: - self.logger.warning( - "Unrecognized message (sender=%s, type=%r). Ignoring.", sender_node, msg_type - ) - - async def _handle_belief_text(self, payload: dict, origin: str): - """ - Process text-based belief extraction payloads. - - Expected payload format:: - - { - "type": "belief_extraction_text", - "beliefs": { - "user_said": ["Can you help me?"], - "intention": ["ask_help"] - } - } - - Validates and converts the dictionary items into :class:`Belief` objects. - - :param payload: The dictionary payload containing belief data. - :param origin: The name of the sender agent. - """ - beliefs = payload.get("beliefs", {}) - - if not beliefs: - self.logger.debug("Received empty beliefs set.") - return - - def try_create_belief(name, arguments) -> Belief | None: - """ - Create a belief object from name and arguments, or return None silently if the input is - not correct. - - :param name: The name of the belief. - :param arguments: The arguments of the belief. - :return: A Belief object if the input is valid or None. 
- """ - try: - return Belief(name=name, arguments=arguments) - except ValidationError: - return None - - beliefs = [ - belief - for name, arguments in beliefs.items() - if (belief := try_create_belief(name, arguments)) is not None - ] - - self.logger.debug("Forwarding %d beliefs.", len(beliefs)) - for belief in beliefs: - for argument in belief.arguments: - self.logger.debug(" - %s %s", belief.name, argument) - - await self._send_beliefs_to_bdi(beliefs, origin=origin) - - async def _handle_emo_text(self, payload: dict, origin: str): - """ - Process emotion extraction payloads. - - **TODO**: Implement this method once emotion recognition is integrated. - - :param payload: The dictionary payload containing emotion data. - :param origin: The name of the sender agent. - """ - pass - - async def _send_beliefs_to_bdi(self, beliefs: list[Belief], origin: str | None = None): - """ - Send a list of aggregated beliefs to the BDI Core Agent. - - Wraps the beliefs in a :class:`BeliefMessage` and sends it via the 'beliefs' thread. - - :param beliefs: The list of Belief objects to send. - :param origin: (Optional) The original source of the beliefs (unused currently). - """ - if not beliefs: - return - - msg = InternalMessage( - to=settings.agent_settings.bdi_core_name, - sender=self.name, - body=BeliefMessage(beliefs=beliefs).model_dump_json(), - thread="beliefs", - ) - - await self.send(msg) - self.logger.info("Sent %d belief(s) to BDI core.", len(beliefs)) diff --git a/src/control_backend/agents/bdi/default_behavior.asl b/src/control_backend/agents/bdi/default_behavior.asl new file mode 100644 index 0000000..b4d6682 --- /dev/null +++ b/src/control_backend/agents/bdi/default_behavior.asl @@ -0,0 +1,34 @@ +phase("end"). +keyword_said(Keyword) :- (user_said(Message) & .substring(Keyword, Message, Pos)) & (Pos >= 0). + + ++!reply_with_goal(Goal) + : user_said(Message) + <- +responded_this_turn; + .findall(Norm, norm(Norm), Norms); + .reply_with_goal(Message, Norms, Goal). 
+ ++!say(Text) + <- +responded_this_turn; + .say(Text). + ++!reply + : user_said(Message) + <- +responded_this_turn; + .findall(Norm, norm(Norm), Norms); + .reply(Message, Norms). + ++!notify_cycle + <- .notify_ui; + .wait(1). + ++user_said(Message) + : phase("end") + <- .notify_user_said(Message); + !reply. + ++!check_triggers + <- true. + ++!transition_phase + <- true. diff --git a/src/control_backend/agents/bdi/rules.asl b/src/control_backend/agents/bdi/rules.asl deleted file mode 100644 index cc9b4ef..0000000 --- a/src/control_backend/agents/bdi/rules.asl +++ /dev/null @@ -1,6 +0,0 @@ -norms(""). -goals(""). - -+user_said(Message) : norms(Norms) & goals(Goals) <- - -user_said(Message); - .reply(Message, Norms, Goals). diff --git a/src/control_backend/agents/bdi/text_belief_extractor_agent.py b/src/control_backend/agents/bdi/text_belief_extractor_agent.py index 0f2db01..bdbc2a7 100644 --- a/src/control_backend/agents/bdi/text_belief_extractor_agent.py +++ b/src/control_backend/agents/bdi/text_belief_extractor_agent.py @@ -1,8 +1,52 @@ +import asyncio import json +import httpx +from pydantic import BaseModel, ValidationError + from control_backend.agents.base import BaseAgent +from control_backend.agents.bdi.agentspeak_generator import AgentSpeakGenerator from control_backend.core.agent_system import InternalMessage from control_backend.core.config import settings +from control_backend.schemas.belief_list import BeliefList, GoalList +from control_backend.schemas.belief_message import Belief as InternalBelief +from control_backend.schemas.belief_message import BeliefMessage +from control_backend.schemas.chat_history import ChatHistory, ChatMessage +from control_backend.schemas.program import BaseGoal, SemanticBelief + +type JSONLike = None | bool | int | float | str | list["JSONLike"] | dict[str, "JSONLike"] + + +class BeliefState(BaseModel): + """ + Represents the state of inferred semantic beliefs. 
+ + Maintains sets of beliefs that are currently considered true or false. + """ + + true: set[InternalBelief] = set() + false: set[InternalBelief] = set() + + def difference(self, other: "BeliefState") -> "BeliefState": + return BeliefState( + true=self.true - other.true, + false=self.false - other.false, + ) + + def union(self, other: "BeliefState") -> "BeliefState": + return BeliefState( + true=self.true | other.true, + false=self.false | other.false, + ) + + def __sub__(self, other): + return self.difference(other) + + def __or__(self, other): + return self.union(other) + + def __bool__(self): + return bool(self.true) or bool(self.false) class TextBeliefExtractorAgent(BaseAgent): @@ -12,54 +56,488 @@ class TextBeliefExtractorAgent(BaseAgent): This agent is responsible for processing raw text (e.g., from speech transcription) and extracting semantic beliefs from it. - In the current demonstration version, it performs a simple wrapping of the user's input - into a ``user_said`` belief. In a full implementation, this agent would likely interact - with an LLM or NLU engine to extract intent, entities, and other structured information. + It uses the available beliefs received from the program manager to try to extract beliefs from a + user's message, sends and updated beliefs to the BDI core, and forms a ``user_said`` belief from + the message itself. """ + def __init__(self, name: str): + super().__init__(name) + self._llm = self.LLM(self, settings.llm_settings.n_parallel) + self.belief_inferrer = SemanticBeliefInferrer(self._llm) + self.goal_inferrer = GoalAchievementInferrer(self._llm) + self._current_beliefs = BeliefState() + self._current_goal_completions: dict[str, bool] = {} + self._force_completed_goals: set[BaseGoal] = set() + self.conversation = ChatHistory(messages=[]) + async def setup(self): """ Initialize the agent and its resources. 
""" - self.logger.info("Settting up %s.", self.name) - # Setup LLM belief context if needed (currently demo is just passthrough) - self.beliefs = {"mood": ["X"], "car": ["Y"]} + self.logger.info("Setting up %s.", self.name) async def handle_message(self, msg: InternalMessage): """ - Handle incoming messages, primarily from the Transcription Agent. + Handle incoming messages. Expect messages from the Transcriber agent, LLM agent, and the + Program manager agent. - :param msg: The received message containing transcribed text. + :param msg: The received message. """ sender = msg.sender - if sender == settings.agent_settings.transcription_name: - self.logger.debug("Received text from transcriber: %s", msg.body) - await self._process_transcription_demo(msg.body) - else: - self.logger.info("Discarding message from %s", sender) - async def _process_transcription_demo(self, txt: str): + match sender: + case settings.agent_settings.transcription_name: + self.logger.debug("Received text from transcriber: %s", msg.body) + self._apply_conversation_message(ChatMessage(role="user", content=msg.body)) + await self._user_said(msg.body) + await self._infer_new_beliefs() + await self._infer_goal_completions() + case settings.agent_settings.llm_name: + self.logger.debug("Received text from LLM: %s", msg.body) + self._apply_conversation_message(ChatMessage(role="assistant", content=msg.body)) + case settings.agent_settings.bdi_program_manager_name: + await self._handle_program_manager_message(msg) + case _: + self.logger.info("Discarding message from %s", sender) + return + + def _apply_conversation_message(self, message: ChatMessage): """ - Process the transcribed text and generate beliefs. + Save the chat message to our conversation history, taking into account the conversation + length limit. - **Demo Implementation:** - Currently, this method takes the raw text ``txt`` and wraps it into a belief structure: - ``user_said("txt")``. 
- - This belief is then sent to the :class:`BDIBeliefCollectorAgent`. - - :param txt: The raw transcribed text string. + :param message: The chat message to add to the conversation history. """ - # For demo, just wrapping user text as user_said belief - belief = {"beliefs": {"user_said": [txt]}, "type": "belief_extraction_text"} - payload = json.dumps(belief) + length_limit = settings.behaviour_settings.conversation_history_length_limit + self.conversation.messages = (self.conversation.messages + [message])[-length_limit:] - belief_msg = InternalMessage( - to=settings.agent_settings.bdi_belief_collector_name, - sender=self.name, - body=payload, - thread="beliefs", + async def _handle_program_manager_message(self, msg: InternalMessage): + """ + Handle a message from the program manager: extract available beliefs and goals from it. + + :param msg: The received message from the program manager. + """ + match msg.thread: + case "beliefs": + self._handle_beliefs_message(msg) + await self._infer_new_beliefs() + case "goals": + self._handle_goals_message(msg) + await self._infer_goal_completions() + case "achieved_goals": + self._handle_goal_achieved_message(msg) + case "conversation_history": + if msg.body == "reset": + self._reset_phase() + case _: + self.logger.warning("Received unexpected message from %s", msg.sender) + + def _reset_phase(self): + """ + Delete all state about the current phase, such as what beliefs exist and which ones are + true. + """ + self.conversation = ChatHistory(messages=[]) + self.belief_inferrer.available_beliefs.clear() + self._current_beliefs = BeliefState() + self.goal_inferrer.goals.clear() + self._current_goal_completions = {} + + def _handle_beliefs_message(self, msg: InternalMessage): + """ + Handle the message from the Program Manager agent containing the beliefs that exist for this + phase. + :param msg: A list of beliefs. 
+ """ + try: + belief_list = BeliefList.model_validate_json(msg.body) + except ValidationError: + self.logger.warning( + "Received message from program manager but it is not a valid list of beliefs." + ) + return + + available_beliefs = [b for b in belief_list.beliefs if isinstance(b, SemanticBelief)] + self.belief_inferrer.available_beliefs = available_beliefs + self.logger.debug( + "Received %d semantic beliefs from the program manager: %s", + len(available_beliefs), + ", ".join(b.name for b in available_beliefs), ) + def _handle_goals_message(self, msg: InternalMessage): + """ + Handle the message from the Program Manager agent containing the goals that exist for this + phase. + :param msg: A list of goals. + """ + try: + goals_list = GoalList.model_validate_json(msg.body) + except ValidationError: + self.logger.warning( + "Received message from program manager but it is not a valid list of goals." + ) + return + + # Use only goals that can fail, as the others are always assumed to be completed + available_goals = {g for g in goals_list.goals if g.can_fail} + available_goals -= self._force_completed_goals + self.goal_inferrer.goals = available_goals + self.logger.debug( + "Received %d failable goals from the program manager: %s", + len(available_goals), + ", ".join(g.name for g in available_goals), + ) + + def _handle_goal_achieved_message(self, msg: InternalMessage): + """ + Handle message that gets sent when goals are marked achieved from a user interrupt. This + goal should then not be changed by this agent anymore. + :param msg: List of goals that are marked achieved. + """ + # NOTE: When goals can be marked unachieved, remember to re-add them to the goal_inferrer + try: + goals_list = GoalList.model_validate_json(msg.body) + except ValidationError: + self.logger.warning( + "Received goal achieved message from the program manager, " + "but it is not a valid list of goals." 
+ ) + return + + for goal in goals_list.goals: + self._force_completed_goals.add(goal) + self._current_goal_completions[f"achieved_{AgentSpeakGenerator.slugify(goal)}"] = True + + self.goal_inferrer.goals -= self._force_completed_goals + + async def _user_said(self, text: str): + """ + Create a belief for the user's full speech. + + :param text: User's transcribed text. + """ + belief_msg = InternalMessage( + to=settings.agent_settings.bdi_core_name, + sender=self.name, + body=BeliefMessage( + replace=[InternalBelief(name="user_said", arguments=[text])], + ).model_dump_json(), + thread="beliefs", + ) await self.send(belief_msg) - self.logger.info("Sent %d beliefs to the belief collector.", len(belief["beliefs"])) + + async def _infer_new_beliefs(self): + """ + Determine which beliefs hold and do not hold for the current conversation state. When + beliefs change, a message is sent to the BDI core. + """ + conversation_beliefs = await self.belief_inferrer.infer_from_conversation(self.conversation) + + new_beliefs = conversation_beliefs - self._current_beliefs + if not new_beliefs: + self.logger.debug("No new beliefs detected.") + return + + self._current_beliefs |= new_beliefs + + belief_changes = BeliefMessage( + create=list(new_beliefs.true), + delete=list(new_beliefs.false), + ) + + message = InternalMessage( + to=settings.agent_settings.bdi_core_name, + sender=self.name, + body=belief_changes.model_dump_json(), + thread="beliefs", + ) + await self.send(message) + + async def _infer_goal_completions(self): + """ + Determine which goals have been achieved given the current conversation state. When + a goal's achieved state changes, a message is sent to the BDI core. 
+ """ + goal_completions = await self.goal_inferrer.infer_from_conversation(self.conversation) + + new_achieved = [ + InternalBelief(name=goal, arguments=None) + for goal, achieved in goal_completions.items() + if achieved and self._current_goal_completions.get(goal) != achieved + ] + new_not_achieved = [ + InternalBelief(name=goal, arguments=None) + for goal, achieved in goal_completions.items() + if not achieved and self._current_goal_completions.get(goal) != achieved + ] + for goal, achieved in goal_completions.items(): + self._current_goal_completions[goal] = achieved + + if not new_achieved and not new_not_achieved: + self.logger.debug("No goal achievement changes detected.") + return + + belief_changes = BeliefMessage( + create=new_achieved, + delete=new_not_achieved, + ) + message = InternalMessage( + to=settings.agent_settings.bdi_core_name, + sender=self.name, + body=belief_changes.model_dump_json(), + thread="beliefs", + ) + await self.send(message) + + class LLM: + """ + Class that handles sending structured generation requests to an LLM. + """ + + def __init__(self, agent: "TextBeliefExtractorAgent", n_parallel: int): + self._agent = agent + self._semaphore = asyncio.Semaphore(n_parallel) + + async def query(self, prompt: str, schema: dict, tries: int = 3) -> JSONLike | None: + """ + Query the LLM with the given prompt and schema, return an instance of a dict conforming + to this schema. Try ``tries`` times, or return None. + + :param prompt: Prompt to be queried. + :param schema: Schema to be queried. + :param tries: Number of times to try to query the LLM. + :return: An instance of a dict conforming to this schema, or None if failed. 
+ """ + try_count = 0 + while try_count < tries: + try_count += 1 + + try: + return await self._query_llm(prompt, schema) + except (httpx.HTTPError, json.JSONDecodeError, KeyError) as e: + if try_count < tries: + continue + self._agent.logger.exception( + "Failed to get LLM response after %d tries.", + try_count, + exc_info=e, + ) + + return None + + async def _query_llm(self, prompt: str, schema: dict) -> JSONLike: + """ + Query an LLM with the given prompt and schema, return an instance of a dict conforming + to that schema. + + :param prompt: The prompt to be queried. + :param schema: Schema to use during response. + :return: A dict conforming to this schema. + :raises httpx.HTTPStatusError: If the LLM server responded with an error. + :raises json.JSONDecodeError: If the LLM response was not valid JSON. May happen if the + response was cut off early due to length limitations. + :raises KeyError: If the LLM server responded with no error, but the response was + invalid. + """ + async with self._semaphore: + async with httpx.AsyncClient() as client: + response = await client.post( + settings.llm_settings.local_llm_url, + json={ + "model": settings.llm_settings.local_llm_model, + "messages": [{"role": "user", "content": prompt}], + "response_format": { + "type": "json_schema", + "json_schema": { + "name": "Beliefs", + "strict": True, + "schema": schema, + }, + }, + "reasoning_effort": "low", + "temperature": settings.llm_settings.code_temperature, + "stream": False, + }, + timeout=30.0, + ) + response.raise_for_status() + + response_json = response.json() + json_message = response_json["choices"][0]["message"]["content"] + return json.loads(json_message) + + +class SemanticBeliefInferrer: + """ + Infers semantic beliefs from conversation history using an LLM. 
+ """ + + def __init__( + self, + llm: "TextBeliefExtractorAgent.LLM", + available_beliefs: list[SemanticBelief] | None = None, + ): + self._llm = llm + self.available_beliefs: list[SemanticBelief] = available_beliefs or [] + + async def infer_from_conversation(self, conversation: ChatHistory) -> BeliefState: + """ + Process conversation history to extract beliefs, semantically. The result is an object that + describes all beliefs that hold or don't hold based on the full conversation. + + :param conversation: The conversation history to be processed. + :return: An object that describes beliefs. + """ + # Return instantly if there are no beliefs to infer + if not self.available_beliefs: + return BeliefState() + + n_parallel = max(1, min(settings.llm_settings.n_parallel - 1, len(self.available_beliefs))) + all_beliefs: list[dict[str, bool | None] | None] = await asyncio.gather( + *[ + self._infer_beliefs(conversation, beliefs) + for beliefs in self._split_into_chunks(self.available_beliefs, n_parallel) + ] + ) + new_beliefs = BeliefState() + # Collect beliefs from all parallel calls + for beliefs in all_beliefs: + if beliefs is None: + continue + # For each, convert them to InternalBeliefs + for belief_name, belief_holds in beliefs.items(): + # Skip beliefs that were marked not possible to determine + if belief_holds is None: + continue + belief = InternalBelief(name=belief_name, arguments=None) + if belief_holds: + new_beliefs.true.add(belief) + else: + new_beliefs.false.add(belief) + return new_beliefs + + @staticmethod + def _split_into_chunks[T](items: list[T], n: int) -> list[list[T]]: + """ + Split a list into ``n`` chunks, making each chunk approximately ``len(items) / n`` long. + + :param items: The list of items to split. + :param n: The number of desired chunks. + :return: A list of chunks each approximately ``len(items) / n`` long. 
+ """ + k, m = divmod(len(items), n) + return [items[i * k + min(i, m) : (i + 1) * k + min(i + 1, m)] for i in range(n)] + + async def _infer_beliefs( + self, + conversation: ChatHistory, + beliefs: list[SemanticBelief], + ) -> dict[str, bool | None] | None: + """ + Infer given beliefs based on the given conversation. + :param conversation: The conversation to infer beliefs from. + :param beliefs: The beliefs to infer. + :return: A dict containing belief names and a boolean whether they hold, or None if the + belief cannot be inferred based on the given conversation. + """ + example = { + "example_belief": True, + } + + prompt = f"""{self._format_conversation(conversation)} + +Given the above conversation, what beliefs can be inferred? +If there is no relevant information about a belief belief, give null. +In case messages conflict, prefer using the most recent messages for inference. + +Choose from the following list of beliefs, formatted as `- : `: +{self._format_beliefs(beliefs)} + +Respond with a JSON similar to the following, but with the property names as given above: +{json.dumps(example, indent=2)} +""" + + schema = self._create_beliefs_schema(beliefs) + + return await self._llm.query(prompt, schema) + + @staticmethod + def _create_belief_schema(belief: SemanticBelief) -> tuple[str, dict]: + return AgentSpeakGenerator.slugify(belief), { + "type": ["boolean", "null"], + "description": belief.description, + } + + @staticmethod + def _create_beliefs_schema(beliefs: list[SemanticBelief]) -> dict: + belief_schemas = [ + SemanticBeliefInferrer._create_belief_schema(belief) for belief in beliefs + ] + + return { + "type": "object", + "properties": dict(belief_schemas), + "required": [name for name, _ in belief_schemas], + } + + @staticmethod + def _format_message(message: ChatMessage): + return f"{message.role.upper()}:\n{message.content}" + + @staticmethod + def _format_conversation(conversation: ChatHistory): + return "\n\n".join( + 
[SemanticBeliefInferrer._format_message(message) for message in conversation.messages] + ) + + @staticmethod + def _format_beliefs(beliefs: list[SemanticBelief]): + return "\n".join( + [f"- {AgentSpeakGenerator.slugify(belief)}: {belief.description}" for belief in beliefs] + ) + + +class GoalAchievementInferrer(SemanticBeliefInferrer): + """ + Infers whether specific conversational goals have been achieved using an LLM. + """ + + def __init__(self, llm: TextBeliefExtractorAgent.LLM): + super().__init__(llm) + self.goals: set[BaseGoal] = set() + + async def infer_from_conversation(self, conversation: ChatHistory) -> dict[str, bool]: + """ + Determine which goals have been achieved based on the given conversation. + + :param conversation: The conversation to infer goal completion from. + :return: A mapping of goals and a boolean whether they have been achieved. + """ + if not self.goals: + return {} + + goals_achieved = await asyncio.gather( + *[self._infer_goal(conversation, g) for g in self.goals] + ) + return { + f"achieved_{AgentSpeakGenerator.slugify(goal)}": achieved + for goal, achieved in zip(self.goals, goals_achieved, strict=True) + } + + async def _infer_goal(self, conversation: ChatHistory, goal: BaseGoal) -> bool: + prompt = f"""{self._format_conversation(conversation)} + +Given the above conversation, what has the following goal been achieved? + +The name of the goal: {goal.name} +Description of the goal: {goal.description} + +Answer with literally only `true` or `false` (without backticks).""" + + schema = { + "type": "boolean", + } + + return await self._llm.query(prompt, schema) diff --git a/src/control_backend/agents/communication/__init__.py b/src/control_backend/agents/communication/__init__.py index 2aa1535..3dde6cf 100644 --- a/src/control_backend/agents/communication/__init__.py +++ b/src/control_backend/agents/communication/__init__.py @@ -1 +1,5 @@ +""" +Agents responsible for external communication and service discovery. 
+""" + from .ri_communication_agent import RICommunicationAgent as RICommunicationAgent diff --git a/src/control_backend/agents/communication/ri_communication_agent.py b/src/control_backend/agents/communication/ri_communication_agent.py index 5c6ca77..5df5a13 100644 --- a/src/control_backend/agents/communication/ri_communication_agent.py +++ b/src/control_backend/agents/communication/ri_communication_agent.py @@ -3,11 +3,14 @@ import json import zmq import zmq.asyncio as azmq +from pydantic import ValidationError from zmq.asyncio import Context from control_backend.agents import BaseAgent from control_backend.agents.actuation.robot_gesture_agent import RobotGestureAgent from control_backend.core.config import settings +from control_backend.schemas.internal_message import InternalMessage +from control_backend.schemas.ri_message import PauseCommand from ..actuation.robot_speech_agent import RobotSpeechAgent from ..perception import VADAgent @@ -47,6 +50,8 @@ class RICommunicationAgent(BaseAgent): self._req_socket: azmq.Socket | None = None self.pub_socket: azmq.Socket | None = None self.connected = False + self.gesture_agent: RobotGestureAgent | None = None + self.speech_agent: RobotSpeechAgent | None = None async def setup(self): """ @@ -140,6 +145,7 @@ class RICommunicationAgent(BaseAgent): # At this point, we have a valid response try: + self.logger.debug("Negotiation successful.") await self._handle_negotiation_response(received_message) # Let UI know that we're connected topic = b"ping" @@ -188,6 +194,7 @@ class RICommunicationAgent(BaseAgent): address=addr, bind=bind, ) + self.speech_agent = robot_speech_agent robot_gesture_agent = RobotGestureAgent( settings.agent_settings.robot_gesture_name, address=addr, @@ -195,6 +202,7 @@ class RICommunicationAgent(BaseAgent): gesture_data=gesture_data, single_gesture_data=single_gesture_data, ) + self.gesture_agent = robot_gesture_agent await robot_speech_agent.start() await asyncio.sleep(0.1) # Small delay await 
robot_gesture_agent.start() @@ -225,6 +233,7 @@ class RICommunicationAgent(BaseAgent): while self._running: if not self.connected: await asyncio.sleep(settings.behaviour_settings.sleep_s) + self.logger.debug("Not connected, skipping ping loop iteration.") continue # We need to listen and send pings. @@ -248,7 +257,8 @@ class RICommunicationAgent(BaseAgent): self._req_socket.recv_json(), timeout=seconds_to_wait_total / 2 ) - self.logger.debug(f'Received message "{message}" from RI.') + if "endpoint" in message and message["endpoint"] != "ping": + self.logger.debug(f'Received message "{message}" from RI.') if "endpoint" not in message: self.logger.warning("No received endpoint in message, expected ping endpoint.") continue @@ -288,13 +298,33 @@ class RICommunicationAgent(BaseAgent): # Tell UI we're disconnected. topic = b"ping" data = json.dumps(False).encode() + self.logger.debug("1") if self.pub_socket: try: + self.logger.debug("2") await asyncio.wait_for(self.pub_socket.send_multipart([topic, data]), 5) except TimeoutError: + self.logger.debug("3") self.logger.warning("Connection ping for router timed out.") # Try to reboot/renegotiate + if self.gesture_agent is not None: + await self.gesture_agent.stop() + + if self.speech_agent is not None: + await self.speech_agent.stop() + + if self.pub_socket is not None: + self.pub_socket.close() + self.logger.debug("Restarting communication negotiation.") - if await self._negotiate_connection(max_retries=1): + if await self._negotiate_connection(max_retries=2): self.connected = True + + async def handle_message(self, msg: InternalMessage): + try: + pause_command = PauseCommand.model_validate_json(msg.body) + await self._req_socket.send_json(pause_command.model_dump()) + self.logger.debug(await self._req_socket.recv_json()) + except ValidationError: + self.logger.warning("Incorrect message format for PauseCommand.") diff --git a/src/control_backend/agents/llm/__init__.py b/src/control_backend/agents/llm/__init__.py index 
e12ff29..519812f 100644 --- a/src/control_backend/agents/llm/__init__.py +++ b/src/control_backend/agents/llm/__init__.py @@ -1 +1,5 @@ +""" +Agents that interface with Large Language Models for natural language processing and generation. +""" + from .llm_agent import LLMAgent as LLMAgent diff --git a/src/control_backend/agents/llm/llm_agent.py b/src/control_backend/agents/llm/llm_agent.py index 60c585f..1c72dfc 100644 --- a/src/control_backend/agents/llm/llm_agent.py +++ b/src/control_backend/agents/llm/llm_agent.py @@ -46,18 +46,23 @@ class LLMAgent(BaseAgent): :param msg: The received internal message. """ if msg.sender == settings.agent_settings.bdi_core_name: - self.logger.debug("Processing message from BDI core.") - try: - prompt_message = LLMPromptMessage.model_validate_json(msg.body) - await self._process_bdi_message(prompt_message) - except ValidationError: - self.logger.debug("Prompt message from BDI core is invalid.") + match msg.thread: + case "prompt_message": + try: + prompt_message = LLMPromptMessage.model_validate_json(msg.body) + await self._process_bdi_message(prompt_message) + except ValidationError: + self.logger.debug("Prompt message from BDI core is invalid.") + case "assistant_message": + self.history.append({"role": "assistant", "content": msg.body}) + case "user_message": + self.history.append({"role": "user", "content": msg.body}) elif msg.sender == settings.agent_settings.bdi_program_manager_name: if msg.body == "clear_history": self.logger.debug("Clearing conversation history.") self.history.clear() else: - self.logger.debug("Message ignored (not from BDI core.") + self.logger.debug("Message ignored.") async def _process_bdi_message(self, message: LLMPromptMessage): """ @@ -68,11 +73,12 @@ class LLMAgent(BaseAgent): :param message: The parsed prompt message containing text, norms, and goals. 
""" + full_message = "" async for chunk in self._query_llm(message.text, message.norms, message.goals): await self._send_reply(chunk) - self.logger.debug( - "Finished processing BDI message. Response sent in chunks to BDI core." - ) + full_message += chunk + self.logger.debug("Finished processing BDI message. Response sent in chunks to BDI core.") + await self._send_full_reply(full_message) async def _send_reply(self, msg: str): """ @@ -87,6 +93,19 @@ class LLMAgent(BaseAgent): ) await self.send(reply) + async def _send_full_reply(self, msg: str): + """ + Sends a response message (full) to agents that need it. + + :param msg: The text content of the message. + """ + message = InternalMessage( + to=settings.agent_settings.text_belief_extractor_name, + sender=self.name, + body=msg, + ) + await self.send(message) + async def _query_llm( self, prompt: str, norms: list[str], goals: list[str] ) -> AsyncGenerator[str]: @@ -104,13 +123,6 @@ class LLMAgent(BaseAgent): :param goals: Goals the LLM should achieve. :yield: Fragments of the LLM-generated content (e.g., sentences/phrases). """ - self.history.append( - { - "role": "user", - "content": prompt, - } - ) - instructions = LLMInstructions(norms if norms else None, goals if goals else None) messages = [ { @@ -176,7 +188,7 @@ class LLMAgent(BaseAgent): json={ "model": settings.llm_settings.local_llm_model, "messages": messages, - "temperature": 0.3, + "temperature": settings.llm_settings.chat_temperature, "stream": True, }, ) as response: diff --git a/src/control_backend/agents/perception/__init__.py b/src/control_backend/agents/perception/__init__.py index e18361a..5a46671 100644 --- a/src/control_backend/agents/perception/__init__.py +++ b/src/control_backend/agents/perception/__init__.py @@ -1,3 +1,8 @@ +""" +Agents responsible for processing sensory input, such as audio transcription and voice activity +detection. 
+""" + from .transcription_agent.transcription_agent import ( TranscriptionAgent as TranscriptionAgent, ) diff --git a/src/control_backend/agents/perception/transcription_agent/transcription_agent.py b/src/control_backend/agents/perception/transcription_agent/transcription_agent.py index 765d7ac..795623d 100644 --- a/src/control_backend/agents/perception/transcription_agent/transcription_agent.py +++ b/src/control_backend/agents/perception/transcription_agent/transcription_agent.py @@ -74,7 +74,7 @@ class TranscriptionAgent(BaseAgent): def _connect_audio_in_socket(self): """ - Helper to connect the ZMQ SUB socket for audio input. + Connects the ZMQ SUB socket for receiving audio data. """ self.audio_in_socket = azmq.Context.instance().socket(zmq.SUB) self.audio_in_socket.setsockopt_string(zmq.SUBSCRIBE, "") diff --git a/src/control_backend/agents/perception/vad_agent.py b/src/control_backend/agents/perception/vad_agent.py index 70fa9e1..2b333f5 100644 --- a/src/control_backend/agents/perception/vad_agent.py +++ b/src/control_backend/agents/perception/vad_agent.py @@ -7,6 +7,7 @@ import zmq.asyncio as azmq from control_backend.agents import BaseAgent from control_backend.core.config import settings +from control_backend.schemas.internal_message import InternalMessage from ...schemas.program_status import PROGRAM_STATUS, ProgramStatus from .transcription_agent.transcription_agent import TranscriptionAgent @@ -86,6 +87,12 @@ class VADAgent(BaseAgent): self.audio_buffer = np.array([], dtype=np.float32) self.i_since_speech = settings.behaviour_settings.vad_initial_since_speech self._ready = asyncio.Event() + + # Pause control + self._reset_needed = False + self._paused = asyncio.Event() + self._paused.set() # Not paused at start + self.model = None async def setup(self): @@ -213,6 +220,16 @@ class VADAgent(BaseAgent): """ await self._ready.wait() while self._running: + await self._paused.wait() + + # After being unpaused, reset stream and buffers + if 
self._reset_needed: + self.logger.debug("Resuming: resetting stream and buffers.") + await self._reset_stream() + self.audio_buffer = np.array([], dtype=np.float32) + self.i_since_speech = settings.behaviour_settings.vad_initial_since_speech + self._reset_needed = False + assert self.audio_in_poller is not None data = await self.audio_in_poller.poll() if data is None: @@ -229,10 +246,11 @@ class VADAgent(BaseAgent): assert self.model is not None prob = self.model(torch.from_numpy(chunk), settings.vad_settings.sample_rate_hz).item() non_speech_patience = settings.behaviour_settings.vad_non_speech_patience_chunks + begin_silence_length = settings.behaviour_settings.vad_begin_silence_chunks prob_threshold = settings.behaviour_settings.vad_prob_threshold if prob > prob_threshold: - if self.i_since_speech > non_speech_patience: + if self.i_since_speech > non_speech_patience + begin_silence_length: self.logger.debug("Speech started.") self.audio_buffer = np.append(self.audio_buffer, chunk) self.i_since_speech = 0 @@ -246,7 +264,7 @@ class VADAgent(BaseAgent): continue # Speech probably ended. Make sure we have a usable amount of data. - if len(self.audio_buffer) >= 3 * len(chunk): + if len(self.audio_buffer) > begin_silence_length * len(chunk): self.logger.debug("Speech ended.") assert self.audio_out_socket is not None await self.audio_out_socket.send(self.audio_buffer[: -2 * len(chunk)].tobytes()) @@ -254,3 +272,27 @@ class VADAgent(BaseAgent): # At this point, we know that the speech has ended. # Prepend the last chunk that had no speech, for a more fluent boundary self.audio_buffer = chunk + + async def handle_message(self, msg: InternalMessage): + """ + Handle incoming messages. + + Expects messages to pause or resume the VAD processing from User Interrupt Agent. + + :param msg: The received internal message. 
+ """ + sender = msg.sender + + if sender == settings.agent_settings.user_interrupt_name: + if msg.body == "PAUSE": + self.logger.info("Pausing VAD processing.") + self._paused.clear() + # If the robot needs to pick up speaking where it left off, do not set _reset_needed + self._reset_needed = True + elif msg.body == "RESUME": + self.logger.info("Resuming VAD processing.") + self._paused.set() + else: + self.logger.warning(f"Unknown command from User Interrupt Agent: {msg.body}") + else: + self.logger.debug(f"Ignoring message from unknown sender: {sender}") diff --git a/src/control_backend/agents/user_interrupt/user_interrupt_agent.py b/src/control_backend/agents/user_interrupt/user_interrupt_agent.py index b2efc41..e2b2d87 100644 --- a/src/control_backend/agents/user_interrupt/user_interrupt_agent.py +++ b/src/control_backend/agents/user_interrupt/user_interrupt_agent.py @@ -4,9 +4,17 @@ import zmq from zmq.asyncio import Context from control_backend.agents import BaseAgent +from control_backend.agents.bdi.agentspeak_generator import AgentSpeakGenerator from control_backend.core.agent_system import InternalMessage from control_backend.core.config import settings -from control_backend.schemas.ri_message import GestureCommand, RIEndpoint, SpeechCommand +from control_backend.schemas.belief_message import Belief, BeliefMessage +from control_backend.schemas.program import ConditionalNorm, Program +from control_backend.schemas.ri_message import ( + GestureCommand, + PauseCommand, + RIEndpoint, + SpeechCommand, +) class UserInterruptAgent(BaseAgent): @@ -18,29 +26,55 @@ class UserInterruptAgent(BaseAgent): - Send a prioritized message to the `RobotSpeechAgent` - Send a prioritized gesture to the `RobotGestureAgent` - - Send a belief override to the `BDIProgramManager`in order to activate a + - Send a belief override to the `BDI Core` in order to activate a trigger/conditional norm or complete a goal. 
Prioritized actions clear the current RI queue before inserting the new item, ensuring they are executed immediately after Pepper's current action has been fulfilled. - :ivar sub_socket: The ZMQ SUB socket used to receive user intterupts. + :ivar sub_socket: The ZMQ SUB socket used to receive user interrupts. """ def __init__(self, **kwargs): super().__init__(**kwargs) self.sub_socket = None + self.pub_socket = None + self._trigger_map = {} + self._trigger_reverse_map = {} + + self._goal_map = {} # id -> sluggified goal + self._goal_reverse_map = {} # sluggified goal -> id + + self._cond_norm_map = {} # id -> sluggified cond norm + self._cond_norm_reverse_map = {} # sluggified cond norm -> id + + async def setup(self): + """ + Initialize the agent by setting up ZMQ sockets for receiving button events and + publishing updates. + """ + context = Context.instance() + + self.sub_socket = context.socket(zmq.SUB) + self.sub_socket.connect(settings.zmq_settings.internal_sub_address) + self.sub_socket.subscribe("button_pressed") + + self.pub_socket = context.socket(zmq.PUB) + self.pub_socket.connect(settings.zmq_settings.internal_pub_address) + + self.add_behavior(self._receive_button_event()) async def _receive_button_event(self): """ - The behaviour of the UserInterruptAgent. - Continuous loop that receives button_pressed events from the button_pressed HTTP endpoint. - These events contain a type and a context. + Main loop to receive and process button press events from the UI. - These are the different types and contexts: - - type: "speech", context: string that the robot has to say. - - type: "gesture", context: single gesture name that the robot has to perform. - - type: "override", context: belief_id that overrides the goal/trigger/conditional norm. + Handles different event types: + - `speech`: Triggers immediate robot speech. + - `gesture`: Triggers an immediate robot gesture. + - `override`: Forces a belief, trigger, or goal completion in the BDI core. 
+ - `override_unachieve`: Removes a belief from the BDI core. + - `pause`: Toggles the system's pause state. + - `next_phase` / `reset_phase`: Controls experiment flow. """ while True: topic, body = await self.sub_socket.recv_multipart() @@ -53,30 +87,208 @@ class UserInterruptAgent(BaseAgent): self.logger.error("Received invalid JSON payload on topic %s", topic) continue - if event_type == "speech": - await self._send_to_speech_agent(event_context) - self.logger.info( - "Forwarded button press (speech) with context '%s' to RobotSpeechAgent.", - event_context, - ) - elif event_type == "gesture": - await self._send_to_gesture_agent(event_context) - self.logger.info( - "Forwarded button press (gesture) with context '%s' to RobotGestureAgent.", - event_context, - ) - elif event_type == "override": - await self._send_to_program_manager(event_context) - self.logger.info( - "Forwarded button press (override) with context '%s' to BDIProgramManager.", - event_context, - ) - else: - self.logger.warning( - "Received button press with unknown type '%s' (context: '%s').", - event_type, - event_context, - ) + self.logger.debug("Received event type %s", event_type) + + match event_type: + case "speech": + await self._send_to_speech_agent(event_context) + self.logger.info( + "Forwarded button press (speech) with context '%s' to RobotSpeechAgent.", + event_context, + ) + case "gesture": + await self._send_to_gesture_agent(event_context) + self.logger.info( + "Forwarded button press (gesture) with context '%s' to RobotGestureAgent.", + event_context, + ) + case "override": + ui_id = str(event_context) + if asl_trigger := self._trigger_map.get(ui_id): + await self._send_to_bdi("force_trigger", asl_trigger) + self.logger.info( + "Forwarded button press (override) with context '%s' to BDI Core.", + event_context, + ) + elif asl_cond_norm := self._cond_norm_map.get(ui_id): + await self._send_to_bdi_belief(asl_cond_norm, "cond_norm") + self.logger.info( + "Forwarded button press 
(override) with context '%s' to BDI Core.", + event_context, + ) + elif asl_goal := self._goal_map.get(ui_id): + await self._send_to_bdi_belief(asl_goal, "goal") + self.logger.info( + "Forwarded button press (override) with context '%s' to BDI Core.", + event_context, + ) + # Send achieve_goal to program manager to update semantic belief extractor + goal_achieve_msg = InternalMessage( + to=settings.agent_settings.bdi_program_manager_name, + thread="achieve_goal", + body=ui_id, + ) + + await self.send(goal_achieve_msg) + else: + self.logger.warning("Could not determine which element to override.") + case "override_unachieve": + ui_id = str(event_context) + if asl_cond_norm := self._cond_norm_map.get(ui_id): + await self._send_to_bdi_belief(asl_cond_norm, "cond_norm", True) + self.logger.info( + "Forwarded button press (override_unachieve)" + "with context '%s' to BDI Core.", + event_context, + ) + else: + self.logger.warning( + "Could not determine which conditional norm to unachieve." + ) + + case "pause": + self.logger.debug( + "Received pause/resume button press with context '%s'.", event_context + ) + await self._send_pause_command(event_context) + if event_context: + self.logger.info("Sent pause command.") + else: + self.logger.info("Sent resume command.") + + case "next_phase" | "reset_phase": + await self._send_experiment_control_to_bdi_core(event_type) + case _: + self.logger.warning( + "Received button press with unknown type '%s' (context: '%s').", + event_type, + event_context, + ) + + async def handle_message(self, msg: InternalMessage): + """ + Handles internal messages from other agents, such as program updates or trigger + notifications. + + :param msg: The incoming :class:`~control_backend.core.agent_system.InternalMessage`. 
+ """ + match msg.thread: + case "new_program": + self._create_mapping(msg.body) + case "trigger_start": + # msg.body is the sluggified trigger + asl_slug = msg.body + ui_id = self._trigger_reverse_map.get(asl_slug) + + if ui_id: + payload = {"type": "trigger_update", "id": ui_id, "achieved": True} + await self._send_experiment_update(payload) + self.logger.info(f"UI Update: Trigger {asl_slug} started (ID: {ui_id})") + case "trigger_end": + asl_slug = msg.body + ui_id = self._trigger_reverse_map.get(asl_slug) + if ui_id: + payload = {"type": "trigger_update", "id": ui_id, "achieved": False} + await self._send_experiment_update(payload) + self.logger.info(f"UI Update: Trigger {asl_slug} ended (ID: {ui_id})") + case "transition_phase": + new_phase_id = msg.body + self.logger.info(f"Phase transition detected: {new_phase_id}") + + payload = {"type": "phase_update", "id": new_phase_id} + + await self._send_experiment_update(payload) + case "goal_start": + goal_name = msg.body + ui_id = self._goal_reverse_map.get(goal_name) + if ui_id: + payload = {"type": "goal_update", "id": ui_id, "active": True} + await self._send_experiment_update(payload) + self.logger.info(f"UI Update: Goal {goal_name} started (ID: {ui_id})") + case "active_norms_update": + active_norms_asl = [ + s.strip("() '\",") for s in msg.body.split(",") if s.strip("() '\",") + ] + await self._broadcast_cond_norms(active_norms_asl) + case _: + self.logger.debug(f"Received internal message on unhandled thread: {msg.thread}") + + async def _broadcast_cond_norms(self, active_slugs: list[str]): + """ + Broadcasts the current activation state of all conditional norms to the UI. + + :param active_slugs: A list of sluggified norm names currently active in the BDI core. 
+ """ + updates = [] + for asl_slug, ui_id in self._cond_norm_reverse_map.items(): + is_active = asl_slug in active_slugs + updates.append({"id": ui_id, "active": is_active}) + + payload = {"type": "cond_norms_state_update", "norms": updates} + + if self.pub_socket: + topic = b"status" + body = json.dumps(payload).encode("utf-8") + await self.pub_socket.send_multipart([topic, body]) + # self.logger.info(f"UI Update: Active norms {updates}") + + def _create_mapping(self, program_json: str): + """ + Creates a bidirectional mapping between UI identifiers and AgentSpeak slugs. + + :param program_json: The JSON representation of the behavioral program. + """ + try: + program = Program.model_validate_json(program_json) + self._trigger_map = {} + self._trigger_reverse_map = {} + self._goal_map = {} + self._cond_norm_map = {} + self._cond_norm_reverse_map = {} + + for phase in program.phases: + for trigger in phase.triggers: + slug = AgentSpeakGenerator.slugify(trigger) + self._trigger_map[str(trigger.id)] = slug + self._trigger_reverse_map[slug] = str(trigger.id) + + for goal in phase.goals: + self._goal_map[str(goal.id)] = AgentSpeakGenerator.slugify(goal) + self._goal_reverse_map[AgentSpeakGenerator.slugify(goal)] = str(goal.id) + + for goal, id in self._goal_reverse_map.items(): + self.logger.debug(f"Goal mapping: UI ID {goal} -> {id}") + + for norm in phase.norms: + if isinstance(norm, ConditionalNorm): + asl_slug = AgentSpeakGenerator.slugify(norm) + + norm_id = str(norm.id) + + self._cond_norm_map[norm_id] = asl_slug + self._cond_norm_reverse_map[norm.norm] = norm_id + self.logger.debug("Added conditional norm %s", asl_slug) + + self.logger.info( + f"Mapped {len(self._trigger_map)} triggers and {len(self._goal_map)} goals " + f"and {len(self._cond_norm_map)} conditional norms for UserInterruptAgent." 
+ ) + except Exception as e: + self.logger.error(f"Mapping failed: {e}") + + async def _send_experiment_update(self, data, should_log: bool = True): + """ + Publishes an experiment state update to the internal ZMQ bus for the UI. + + :param data: The update payload. + :param should_log: Whether to log the update. + """ + if self.pub_socket: + topic = b"experiment" + body = json.dumps(data).encode("utf-8") + await self.pub_socket.send_multipart([topic, body]) + if should_log: + self.logger.debug(f"Sent experiment update: {data}") async def _send_to_speech_agent(self, text_to_say: str): """ @@ -109,38 +321,93 @@ class UserInterruptAgent(BaseAgent): ) await self.send(out_msg) - async def _send_to_program_manager(self, belief_id: str): - """ - Send a button_override belief to the BDIProgramManager. + async def _send_to_bdi(self, thread: str, body: str): + """Send slug of trigger to BDI""" + msg = InternalMessage(to=settings.agent_settings.bdi_core_name, thread=thread, body=body) + await self.send(msg) + self.logger.info(f"Directly forced {thread} in BDI: {body}") - :param belief_id: The belief_id that overrides the goal/trigger/conditional norm. - this id can belong to a basic belief or an inferred belief. 
- See also: https://utrechtuniversity.youtrack.cloud/articles/N25B-A-27/UI-components + async def _send_to_bdi_belief(self, asl: str, asl_type: str, unachieve: bool = False): + """Send belief to BDI Core""" + if asl_type == "goal": + belief_name = f"achieved_{asl}" + elif asl_type == "cond_norm": + belief_name = f"force_{asl}" + else: + self.logger.warning("Tried to send belief with unknown type") + return + + belief = Belief(name=belief_name, arguments=None) + self.logger.debug(f"Sending belief to BDI Core: {belief_name}") + # Conditional norms are unachieved by removing the belief + belief_message = ( + BeliefMessage(delete=[belief]) if unachieve else BeliefMessage(create=[belief]) + ) + msg = InternalMessage( + to=settings.agent_settings.bdi_core_name, + thread="beliefs", + body=belief_message.model_dump_json(), + ) + await self.send(msg) + + async def _send_experiment_control_to_bdi_core(self, type): """ - data = {"belief": belief_id} - message = InternalMessage( - to=settings.agent_settings.bdi_program_manager_name, + method to send experiment control buttons to bdi core. + + :param type: the type of control button we should send to the bdi core. + """ + # Switch which thread we should send to bdi core + thread = "" + match type: + case "next_phase": + thread = "force_next_phase" + case "reset_phase": + thread = "reset_current_phase" + case "reset_experiment": + thread = "reset_experiment" + case _: + self.logger.warning( + "Received unknown experiment control type '%s' to send to BDI Core.", + type, + ) + + out_msg = InternalMessage( + to=settings.agent_settings.bdi_core_name, sender=self.name, - body=json.dumps(data), - thread="belief_override_id", + thread=thread, + body="", + ) + self.logger.debug("Sending experiment control '%s' to BDI Core.", thread) + await self.send(out_msg) + + async def _send_pause_command(self, pause): + """ + Send a pause command to the Robot Interface via the RI Communication Agent. 
+ Send a pause command to the other internal agents; for now just VAD agent. + """ + cmd = PauseCommand(data=pause) + message = InternalMessage( + to=settings.agent_settings.ri_communication_name, + sender=self.name, + body=cmd.model_dump_json(), ) await self.send(message) - self.logger.info( - "Sent button_override belief with id '%s' to Program manager.", - belief_id, - ) - async def setup(self): - """ - Initialize the agent. - - Connects the internal ZMQ SUB socket and subscribes to the 'button_pressed' topic. - Starts the background behavior to receive the user interrupts. - """ - context = Context.instance() - - self.sub_socket = context.socket(zmq.SUB) - self.sub_socket.connect(settings.zmq_settings.internal_sub_address) - self.sub_socket.subscribe("button_pressed") - - self.add_behavior(self._receive_button_event()) + if pause == "true": + # Send pause to VAD agent + vad_message = InternalMessage( + to=settings.agent_settings.vad_name, + sender=self.name, + body="PAUSE", + ) + await self.send(vad_message) + self.logger.info("Sent pause command to VAD Agent and RI Communication Agent.") + else: + # Send resume to VAD agent + vad_message = InternalMessage( + to=settings.agent_settings.vad_name, + sender=self.name, + body="RESUME", + ) + await self.send(vad_message) + self.logger.info("Sent resume command to VAD Agent and RI Communication Agent.") diff --git a/src/control_backend/api/v1/endpoints/button_pressed.py b/src/control_backend/api/v1/endpoints/button_pressed.py deleted file mode 100644 index 5a94a53..0000000 --- a/src/control_backend/api/v1/endpoints/button_pressed.py +++ /dev/null @@ -1,31 +0,0 @@ -import logging - -from fastapi import APIRouter, Request - -from control_backend.schemas.events import ButtonPressedEvent - -logger = logging.getLogger(__name__) -router = APIRouter() - - -@router.post("/button_pressed", status_code=202) -async def receive_button_event(event: ButtonPressedEvent, request: Request): - """ - Endpoint to handle external button 
press events. - - Validates the event payload and publishes it to the internal 'button_pressed' topic. - Subscribers (in this case user_interrupt_agent) will pick this up to trigger - specific behaviors or state changes. - - :param event: The parsed ButtonPressedEvent object. - :param request: The FastAPI request object. - """ - logger.debug("Received button event: %s | %s", event.type, event.context) - - topic = b"button_pressed" - body = event.model_dump_json().encode() - - pub_socket = request.app.state.endpoints_pub_socket - await pub_socket.send_multipart([topic, body]) - - return {"status": "Event received"} diff --git a/src/control_backend/api/v1/endpoints/robot.py b/src/control_backend/api/v1/endpoints/robot.py index afbf1ac..95a9c40 100644 --- a/src/control_backend/api/v1/endpoints/robot.py +++ b/src/control_backend/api/v1/endpoints/robot.py @@ -137,7 +137,6 @@ async def ping_stream(request: Request): logger.info("Client disconnected from SSE") break - logger.debug(f"Yielded new connection event in robot ping router: {str(connected)}") connectedJson = json.dumps(connected) yield (f"data: {connectedJson}\n\n") diff --git a/src/control_backend/api/v1/endpoints/sse.py b/src/control_backend/api/v1/endpoints/sse.py deleted file mode 100644 index c660aa5..0000000 --- a/src/control_backend/api/v1/endpoints/sse.py +++ /dev/null @@ -1,12 +0,0 @@ -from fastapi import APIRouter, Request - -router = APIRouter() - - -# TODO: implement -@router.get("/sse") -async def sse(request: Request): - """ - Placeholder for future Server-Sent Events endpoint. 
- """ - pass diff --git a/src/control_backend/api/v1/endpoints/user_interact.py b/src/control_backend/api/v1/endpoints/user_interact.py new file mode 100644 index 0000000..eb70f35 --- /dev/null +++ b/src/control_backend/api/v1/endpoints/user_interact.py @@ -0,0 +1,94 @@ +import asyncio +import logging + +import zmq +import zmq.asyncio +from fastapi import APIRouter, Request +from fastapi.responses import StreamingResponse +from zmq.asyncio import Context + +from control_backend.core.config import settings +from control_backend.schemas.events import ButtonPressedEvent + +logger = logging.getLogger(__name__) +router = APIRouter() + + +@router.post("/button_pressed", status_code=202) +async def receive_button_event(event: ButtonPressedEvent, request: Request): + """ + Endpoint to handle external button press events. + + Validates the event payload and publishes it to the internal 'button_pressed' topic. + Subscribers (in this case user_interrupt_agent) will pick this up to trigger + specific behaviors or state changes. + + :param event: The parsed ButtonPressedEvent object. + :param request: The FastAPI request object. 
+ """ + logger.debug("Received button event: %s | %s", event.type, event.context) + + topic = b"button_pressed" + body = event.model_dump_json().encode() + + pub_socket = request.app.state.endpoints_pub_socket + await pub_socket.send_multipart([topic, body]) + + return {"status": "Event received"} + + +@router.get("/experiment_stream") +async def experiment_stream(request: Request): + # Use the asyncio-compatible context + context = Context.instance() + socket = context.socket(zmq.SUB) + + # Connect and subscribe + socket.connect(settings.zmq_settings.internal_sub_address) + socket.subscribe(b"experiment") + + async def gen(): + try: + while True: + # Check if client closed the tab + if await request.is_disconnected(): + logger.error("Client disconnected from experiment stream.") + break + + try: + parts = await asyncio.wait_for(socket.recv_multipart(), timeout=10.0) + _, message = parts + yield f"data: {message.decode().strip()}\n\n" + except TimeoutError: + continue + finally: + socket.close() + + return StreamingResponse(gen(), media_type="text/event-stream") + + +@router.get("/status_stream") +async def status_stream(request: Request): + context = Context.instance() + socket = context.socket(zmq.SUB) + socket.connect(settings.zmq_settings.internal_sub_address) + + socket.subscribe(b"status") + + async def gen(): + try: + while True: + if await request.is_disconnected(): + break + try: + # Shorter timeout since this is frequent + parts = await asyncio.wait_for(socket.recv_multipart(), timeout=0.5) + _, message = parts + yield f"data: {message.decode().strip()}\n\n" + except TimeoutError: + yield ": ping\n\n" # Keep the connection alive + continue + finally: + socket.close() + + return StreamingResponse(gen(), media_type="text/event-stream") diff --git a/src/control_backend/api/v1/router.py b/src/control_backend/api/v1/router.py index ebba0db..b46df5f 100644 --- a/src/control_backend/api/v1/router.py +++ b/src/control_backend/api/v1/router.py @@ -1,17 +1,15 @@ 
from fastapi.routing import APIRouter -from control_backend.api.v1.endpoints import button_pressed, logs, message, program, robot, sse +from control_backend.api.v1.endpoints import logs, message, program, robot, user_interact api_router = APIRouter() api_router.include_router(message.router, tags=["Messages"]) -api_router.include_router(sse.router, tags=["SSE"]) - api_router.include_router(robot.router, prefix="/robot", tags=["Pings", "Commands"]) api_router.include_router(logs.router, tags=["Logs"]) api_router.include_router(program.router, tags=["Program"]) -api_router.include_router(button_pressed.router, tags=["Button Pressed Events"]) +api_router.include_router(user_interact.router, tags=["Button Pressed Events"]) diff --git a/src/control_backend/core/agent_system.py b/src/control_backend/core/agent_system.py index 5b2ea7e..267f072 100644 --- a/src/control_backend/core/agent_system.py +++ b/src/control_backend/core/agent_system.py @@ -22,10 +22,22 @@ class AgentDirectory: @staticmethod def register(name: str, agent: "BaseAgent"): + """ + Registers an agent instance with a unique name. + + :param name: The name of the agent. + :param agent: The :class:`BaseAgent` instance. + """ _agent_directory[name] = agent @staticmethod def get(name: str) -> "BaseAgent | None": + """ + Retrieves a registered agent instance by name. + + :param name: The name of the agent to retrieve. + :return: The :class:`BaseAgent` instance, or None if not found. + """ return _agent_directory.get(name) @@ -120,7 +132,7 @@ class BaseAgent(ABC): task.cancel() self.logger.info(f"Agent {self.name} stopped") - async def send(self, message: InternalMessage): + async def send(self, message: InternalMessage, should_log: bool = True): """ Send a message to another agent. 
@@ -142,13 +154,17 @@ class BaseAgent(ABC): if target: await target.inbox.put(message) - self.logger.debug(f"Sent message {message.body} to {message.to} via regular inbox.") + if should_log: + self.logger.debug( + f"Sent message {message.body} to {message.to} via regular inbox." + ) else: # Apparently target agent is on a different process, send via ZMQ topic = f"internal/{receiver}".encode() body = message.model_dump_json().encode() await self._internal_pub_socket.send_multipart([topic, body]) - self.logger.debug(f"Sent message {message.body} to {message.to} via ZMQ.") + if should_log: + self.logger.debug(f"Sent message {message.body} to {message.to} via ZMQ.") async def _process_inbox(self): """ @@ -158,7 +174,6 @@ class BaseAgent(ABC): """ while self._running: msg = await self.inbox.get() - self.logger.debug(f"Received message from {msg.sender}.") await self.handle_message(msg) async def _receive_internal_zmq_loop(self): @@ -201,7 +216,16 @@ class BaseAgent(ABC): :param coro: The coroutine to execute as a task. """ - task = asyncio.create_task(coro) + + async def try_coro(coro_: Coroutine): + try: + await coro_ + except asyncio.CancelledError: + self.logger.debug("A behavior was canceled successfully: %s", coro_) + except Exception: + self.logger.warning("An exception occurred in a behavior.", exc_info=True) + + task = asyncio.create_task(try_coro(coro)) self._tasks.add(task) task.add_done_callback(self._tasks.discard) return task diff --git a/src/control_backend/core/config.py b/src/control_backend/core/config.py index 2dbde02..8ef30cb 100644 --- a/src/control_backend/core/config.py +++ b/src/control_backend/core/config.py @@ -35,7 +35,6 @@ class AgentSettings(BaseModel): Names of the various agents in the system. These names are used for routing messages. :ivar bdi_core_name: Name of the BDI Core Agent. - :ivar bdi_belief_collector_name: Name of the Belief Collector Agent. :ivar bdi_program_manager_name: Name of the BDI Program Manager Agent. 
:ivar text_belief_extractor_name: Name of the Text Belief Extractor Agent. :ivar vad_name: Name of the Voice Activity Detection (VAD) Agent. @@ -50,7 +49,6 @@ class AgentSettings(BaseModel): # agent names bdi_core_name: str = "bdi_core_agent" - bdi_belief_collector_name: str = "belief_collector_agent" bdi_program_manager_name: str = "bdi_program_manager_agent" text_belief_extractor_name: str = "text_belief_extractor_agent" vad_name: str = "vad_agent" @@ -73,10 +71,14 @@ class BehaviourSettings(BaseModel): :ivar vad_prob_threshold: Probability threshold for Voice Activity Detection. :ivar vad_initial_since_speech: Initial value for 'since speech' counter in VAD. :ivar vad_non_speech_patience_chunks: Number of non-speech chunks to wait before speech ended. + :ivar vad_begin_silence_chunks: The number of chunks of silence to prepend to speech chunks. :ivar transcription_max_concurrent_tasks: Maximum number of concurrent transcription tasks. :ivar transcription_words_per_minute: Estimated words per minute for transcription timing. :ivar transcription_words_per_token: Estimated words per token for transcription timing. :ivar transcription_token_buffer: Buffer for transcription tokens. + :ivar conversation_history_length_limit: The maximum amount of messages to extract beliefs from. + :ivar trigger_time_to_wait: Amount of milliseconds to wait before informing the UI about trigger + completion. 
""" # ATTENTION: When adding/removing settings, make sure to update the .env.example file @@ -88,7 +90,8 @@ class BehaviourSettings(BaseModel): # VAD settings vad_prob_threshold: float = 0.5 vad_initial_since_speech: int = 100 - vad_non_speech_patience_chunks: int = 3 + vad_non_speech_patience_chunks: int = 15 + vad_begin_silence_chunks: int = 6 # transcription behaviour transcription_max_concurrent_tasks: int = 3 @@ -96,6 +99,13 @@ class BehaviourSettings(BaseModel): transcription_words_per_token: float = 0.75 # (3 words = 4 tokens) transcription_token_buffer: int = 10 + # Text belief extractor settings + conversation_history_length_limit: int = 10 + + # AgentSpeak related settings + trigger_time_to_wait: int = 2000 + agentspeak_file: str = "src/control_backend/agents/bdi/agentspeak.asl" + class LLMSettings(BaseModel): """ @@ -103,12 +113,19 @@ class LLMSettings(BaseModel): :ivar local_llm_url: URL for the local LLM API. :ivar local_llm_model: Name of the local LLM model to use. + :ivar chat_temperature: The temperature to use while generating chat responses. + :ivar code_temperature: The temperature to use while generating code-like responses like during + belief inference. + :ivar n_parallel: The number of parallel calls allowed to be made to the LLM. 
""" # ATTENTION: When adding/removing settings, make sure to update the .env.example file local_llm_url: str = "http://localhost:1234/v1/chat/completions" local_llm_model: str = "gpt-oss" + chat_temperature: float = 1.0 + code_temperature: float = 0.3 + n_parallel: int = 4 class VADSettings(BaseModel): diff --git a/src/control_backend/main.py b/src/control_backend/main.py index 3509cbc..a0136bd 100644 --- a/src/control_backend/main.py +++ b/src/control_backend/main.py @@ -26,7 +26,6 @@ from zmq.asyncio import Context # BDI agents from control_backend.agents.bdi import ( - BDIBeliefCollectorAgent, BDICoreAgent, TextBeliefExtractorAgent, ) @@ -120,13 +119,6 @@ async def lifespan(app: FastAPI): BDICoreAgent, { "name": settings.agent_settings.bdi_core_name, - "asl": "src/control_backend/agents/bdi/rules.asl", - }, - ), - "BeliefCollectorAgent": ( - BDIBeliefCollectorAgent, - { - "name": settings.agent_settings.bdi_belief_collector_name, }, ), "TextBeliefExtractorAgent": ( @@ -173,6 +165,8 @@ async def lifespan(app: FastAPI): await endpoints_pub_socket.send_multipart([PROGRAM_STATUS, ProgramStatus.STOPPING.value]) # Additional shutdown logic goes here + for agent in agents: + await agent.stop() logger.info("Application shutdown complete.") diff --git a/src/control_backend/schemas/belief_list.py b/src/control_backend/schemas/belief_list.py new file mode 100644 index 0000000..841a4ed --- /dev/null +++ b/src/control_backend/schemas/belief_list.py @@ -0,0 +1,25 @@ +from pydantic import BaseModel + +from control_backend.schemas.program import BaseGoal +from control_backend.schemas.program import Belief as ProgramBelief + + +class BeliefList(BaseModel): + """ + Represents a list of beliefs, separated from a program. Useful in agents which need to + communicate beliefs. + + :ivar: beliefs: The list of beliefs. + """ + + beliefs: list[ProgramBelief] + + +class GoalList(BaseModel): + """ + Represents a list of goals, used for communicating multiple goals between agents. 
+ + :ivar goals: The list of goals. + """ + + goals: list[BaseGoal] diff --git a/src/control_backend/schemas/belief_message.py b/src/control_backend/schemas/belief_message.py index deb1152..226833e 100644 --- a/src/control_backend/schemas/belief_message.py +++ b/src/control_backend/schemas/belief_message.py @@ -6,18 +6,30 @@ class Belief(BaseModel): Represents a single belief in the BDI system. :ivar name: The functor or name of the belief (e.g., 'user_said'). - :ivar arguments: A list of string arguments for the belief. - :ivar replace: If True, existing beliefs with this name should be replaced by this one. + :ivar arguments: A list of string arguments for the belief, or None if the belief has no + arguments. """ name: str - arguments: list[str] - replace: bool = False + arguments: list[str] | None = None + + # To make it hashable + model_config = {"frozen": True} class BeliefMessage(BaseModel): """ - A container for transporting a list of beliefs between agents. + A container for communicating beliefs between agents. + + :ivar create: Beliefs to create. + :ivar delete: Beliefs to delete. + :ivar replace: Beliefs to replace. Deletes all beliefs with the same name, replacing them with + one new belief. """ - beliefs: list[Belief] + create: list[Belief] = [] + delete: list[Belief] = [] + replace: list[Belief] = [] + + def has_values(self) -> bool: + return len(self.create) > 0 or len(self.delete) > 0 or len(self.replace) > 0 diff --git a/src/control_backend/schemas/chat_history.py b/src/control_backend/schemas/chat_history.py new file mode 100644 index 0000000..8fd1e72 --- /dev/null +++ b/src/control_backend/schemas/chat_history.py @@ -0,0 +1,23 @@ +from pydantic import BaseModel + + +class ChatMessage(BaseModel): + """ + Represents a single message in a conversation. + + :ivar role: The role of the speaker (e.g., 'user', 'assistant'). + :ivar content: The text content of the message. 
+ """ + + role: str + content: str + + +class ChatHistory(BaseModel): + """ + Represents a sequence of chat messages, forming a conversation history. + + :ivar messages: An ordered list of :class:`ChatMessage` objects. + """ + + messages: list[ChatMessage] diff --git a/src/control_backend/schemas/events.py b/src/control_backend/schemas/events.py index 46967f7..a01b668 100644 --- a/src/control_backend/schemas/events.py +++ b/src/control_backend/schemas/events.py @@ -2,5 +2,13 @@ from pydantic import BaseModel class ButtonPressedEvent(BaseModel): + """ + Represents a button press event from the UI. + + :ivar type: The type of event (e.g., 'speech', 'gesture', 'override'). + :ivar context: Additional data associated with the event (e.g., speech text, gesture name, + or ID). + """ + type: str context: str diff --git a/src/control_backend/schemas/internal_message.py b/src/control_backend/schemas/internal_message.py index 758c085..afe2908 100644 --- a/src/control_backend/schemas/internal_message.py +++ b/src/control_backend/schemas/internal_message.py @@ -14,6 +14,6 @@ class InternalMessage(BaseModel): """ to: str | Iterable[str] - sender: str + sender: str | None = None body: str thread: str | None = None diff --git a/src/control_backend/schemas/program.py b/src/control_backend/schemas/program.py index 28969b9..3fb0a19 100644 --- a/src/control_backend/schemas/program.py +++ b/src/control_backend/schemas/program.py @@ -1,71 +1,311 @@ -from pydantic import BaseModel +from enum import Enum +from typing import Literal + +from pydantic import UUID4, BaseModel -class Norm(BaseModel): +class ProgramElement(BaseModel): """ - Represents a behavioral norm. + Represents a basic element of our behavior program. + :ivar name: The researcher-assigned name of the element. :ivar id: Unique identifier. - :ivar label: Human-readable label. - :ivar norm: The actual norm text describing the behavior. 
""" - id: str - label: str - norm: str + name: str + id: UUID4 + + # To make program elements hashable + model_config = {"frozen": True} -class Goal(BaseModel): +class LogicalOperator(Enum): """ - Represents an objective to be achieved. + Logical operators for combining beliefs. - :ivar id: Unique identifier. - :ivar label: Human-readable label. - :ivar description: Detailed description of the goal. - :ivar achieved: Status flag indicating if the goal has been met. + These operators define how beliefs can be combined to form more complex + logical conditions. They are used in inferred beliefs to create compound + beliefs from simpler ones. + + AND: Both operands must be true for the result to be true. + OR: At least one operand must be true for the result to be true. """ - id: str - label: str - description: str - achieved: bool + AND = "AND" + OR = "OR" -class TriggerKeyword(BaseModel): - id: str +type Belief = KeywordBelief | SemanticBelief | InferredBelief +type BasicBelief = KeywordBelief | SemanticBelief + + +class KeywordBelief(ProgramElement): + """ + Represents a belief that is activated when a specific keyword is detected in the user's speech. + + Keyword beliefs provide a simple but effective way to detect specific topics + or intentions in user speech. They are triggered when the exact keyword + string appears in the transcribed user input. + + :ivar keyword: The string to look for in the transcription. + + Example: + A keyword belief with keyword="robot" would be activated when the user + says "I like the robot" or "Tell me about robots". + """ + + name: str = "" keyword: str -class KeywordTrigger(BaseModel): - id: str - label: str - type: str - keywords: list[TriggerKeyword] - - -class Phase(BaseModel): +class SemanticBelief(ProgramElement): """ - A distinct phase within a program, containing norms, goals, and triggers. + Represents a belief whose truth value is determined by an LLM analyzing the conversation + context. - :ivar id: Unique identifier. 
- :ivar label: Human-readable label. - :ivar norms: List of norms active in this phase. - :ivar goals: List of goals to pursue in this phase. - :ivar triggers: List of triggers that define transitions out of this phase. + Semantic beliefs provide more sophisticated belief detection by using + an LLM to analyze the conversation context and determine + if the belief should be considered true. This allows for more nuanced + and context-aware belief evaluation. + + :ivar description: A natural language description of what this belief represents, + used as a prompt for the LLM. + + Example: + A semantic belief with description="The user is expressing frustration" + would be activated when the LLM determines that the user's statements + indicate frustration, even if no specific keywords are used. """ - id: str - label: str - norms: list[Norm] + description: str + + +class InferredBelief(ProgramElement): + """ + Represents a belief derived from other beliefs using logical operators. + + Inferred beliefs allow for the creation of complex belief structures by + combining simpler beliefs using logical operators. This enables the + representation of sophisticated conditions and relationships between + different aspects of the conversation or context. + + :ivar operator: The :class:`LogicalOperator` (AND/OR) to apply. + :ivar left: The left operand (another belief). + :ivar right: The right operand (another belief). + """ + + name: str = "" + operator: LogicalOperator + left: Belief + right: Belief + + +class Norm(ProgramElement): + """ + Base class for behavioral norms that guide the robot's interactions. + + Norms represent guidelines, principles, or rules that should govern the + robot's behavior during interactions. They can be either basic (always + active in their phase) or conditional (active only when specific beliefs + are true). + + :ivar norm: The textual description of the norm. 
+ :ivar critical: Whether this norm is considered critical and should be strictly enforced. + + Critical norms are currently not supported yet, but are intended for norms that should + ABSOLUTELY NOT be violated, possible cheched by additional validator agents. + """ + + name: str = "" + norm: str + critical: bool = False + + +class BasicNorm(Norm): + """ + A simple behavioral norm that is always considered for activation when its phase is active. + + Basic norms are the most straightforward type of norms. They are active + throughout their assigned phase and provide consistent behavioral guidance + without any additional conditions. + + These norms are suitable for general principles that should always apply + during a particular interaction phase. + """ + + pass + + +class ConditionalNorm(Norm): + """ + A behavioral norm that is only active when a specific condition (belief) is met. + + Conditional norms provide context-sensitive behavioral guidance. They are + only active and considered for activation when their associated condition + (belief) is true. This allows for more nuanced and adaptive behavior that + responds to the specific context of the interaction. + + An important note, is that the current implementation of these norms for keyword-based beliefs + is that they only hold for 1 turn, as keyword-based conditions often express temporary + conditions. + + :ivar condition: The :class:`Belief` that must hold for this norm to be active. + + Example: + A conditional norm with the condition "user is frustrated" might specify + that the robot should use more empathetic language and avoid complex topics. + """ + + condition: Belief + + +type PlanElement = Goal | Action + + +class Plan(ProgramElement): + """ + Represents a list of steps to execute. Each of these steps can be a goal (with its own plan) + or a simple action. + + Plans define sequences of actions and subgoals that the robot should execute + to achieve a particular objective. 
They form the procedural knowledge of + the behavior program, specifying what the robot should do in different + situations. + + :ivar steps: The actions or subgoals to execute, in order. + """ + + name: str = "" + steps: list[PlanElement] + + +class BaseGoal(ProgramElement): + """ + Represents an objective to be achieved. This base version does not include a plan to achieve + this goal, and is used in semantic belief extraction. + + :ivar description: A description of the goal, used to determine if it has been achieved. + :ivar can_fail: Whether we can fail to achieve the goal after executing the plan. + + The can_fail attribute determines whether goal achievement is binary + (success/failure) or whether it can be determined through conversation + analysis. + """ + + description: str = "" + can_fail: bool = True + + +class Goal(BaseGoal): + """ + Represents an objective to be achieved. To reach the goal, we should execute the corresponding + plan. It inherits from the BaseGoal a variable `can_fail`, which, if true, will cause the + completion to be determined based on the conversation. + + Goals extend base goals by including a specific plan to achieve the objective. + They form the core of the robot's proactive behavior, defining both what + should be accomplished and how to accomplish it. + + Instances of this goal are not hashable because a plan is not hashable. + + :ivar plan: The plan to execute. + """ + + plan: Plan + + +type Action = SpeechAction | GestureAction | LLMAction + + +class SpeechAction(ProgramElement): + """ + An action where the robot speaks a predefined literal text. + + :ivar text: The text content to be spoken. + """ + + name: str = "" + text: str + + +class Gesture(BaseModel): + """ + Defines a physical gesture for the robot to perform. + + :ivar type: Whether to use a specific "single" gesture or a random one from a "tag" category. + :ivar name: The identifier for the gesture or tag. 
+ + The type field determines how the gesture is selected: + - "single": Use the specific gesture identified by name + - "tag": Select a random gesture from the category identified by name + """ + + type: Literal["tag", "single"] + name: str + + +class GestureAction(ProgramElement): + """ + An action where the robot performs a physical gesture. + + :ivar gesture: The :class:`Gesture` definition. + """ + + name: str = "" + gesture: Gesture + + +class LLMAction(ProgramElement): + """ + An action that triggers an LLM-generated conversational response. + + :ivar goal: A temporary conversational goal to guide the LLM's response generation. + + The goal parameter provides high-level guidance to the LLM about what + the response should aim to achieve, while allowing the LLM flexibility + in how to express it. + """ + + name: str = "" + goal: str + + +class Trigger(ProgramElement): + """ + Defines a reactive behavior: when the condition (belief) is met, the plan is executed. + + :ivar condition: The :class:`Belief` that triggers this behavior. + :ivar plan: The :class:`Plan` to execute upon activation. + """ + + condition: Belief + plan: Plan + + +class Phase(ProgramElement): + """ + A logical stage in the interaction program, grouping norms, goals, and triggers. + + :ivar norms: List of norms active during this phase. + :ivar goals: List of goals the robot pursues in this phase. + :ivar triggers: List of reactive behaviors defined for this phase. + """ + + name: str = "" + norms: list[BasicNorm | ConditionalNorm] goals: list[Goal] - triggers: list[KeywordTrigger] + triggers: list[Trigger] class Program(BaseModel): """ - Represents a complete interaction program, consisting of a sequence or set of phases. + The top-level container for a complete robot behavior definition. - :ivar phases: The list of phases that make up the program. + The Program class represents the complete specification of a robot's + behavioral logic. 
It contains all the phases, norms, goals, triggers, + and actions that define how the robot should behave during interactions. + + :ivar phases: An ordered list of :class:`Phase` objects defining the interaction flow. """ phases: list[Phase] diff --git a/src/control_backend/schemas/ri_message.py b/src/control_backend/schemas/ri_message.py index a48dec6..e6eafa3 100644 --- a/src/control_backend/schemas/ri_message.py +++ b/src/control_backend/schemas/ri_message.py @@ -14,6 +14,7 @@ class RIEndpoint(str, Enum): GESTURE_TAG = "actuate/gesture/tag" PING = "ping" NEGOTIATE_PORTS = "negotiate/ports" + PAUSE = "" class RIMessage(BaseModel): @@ -64,3 +65,15 @@ class GestureCommand(RIMessage): if self.endpoint not in allowed: raise ValueError("endpoint must be GESTURE_SINGLE or GESTURE_TAG") return self + + +class PauseCommand(RIMessage): + """ + A specific command to pause or unpause the robot's actions. + + :ivar endpoint: Fixed to ``RIEndpoint.PAUSE``. + :ivar data: A boolean indicating whether to pause (True) or unpause (False). 
+ """ + + endpoint: RIEndpoint = RIEndpoint(RIEndpoint.PAUSE) + data: bool diff --git a/test/integration/agents/perception/vad_agent/test_vad_agent.py b/test/integration/agents/perception/vad_agent/test_vad_agent.py index 668d1ce..3cde755 100644 --- a/test/integration/agents/perception/vad_agent/test_vad_agent.py +++ b/test/integration/agents/perception/vad_agent/test_vad_agent.py @@ -40,7 +40,7 @@ async def test_normal_setup(per_transcription_agent): per_vad_agent = VADAgent("tcp://localhost:12345", False) per_vad_agent._streaming_loop = AsyncMock() - async def swallow_background_task(coro): + def swallow_background_task(coro): coro.close() per_vad_agent.add_behavior = swallow_background_task @@ -106,7 +106,7 @@ async def test_out_socket_creation_failure(zmq_context): per_vad_agent._streaming_loop = AsyncMock() per_vad_agent._connect_audio_out_socket = MagicMock(return_value=None) - async def swallow_background_task(coro): + def swallow_background_task(coro): coro.close() per_vad_agent.add_behavior = swallow_background_task @@ -126,7 +126,7 @@ async def test_stop(zmq_context, per_transcription_agent): per_vad_agent._reset_stream = AsyncMock() per_vad_agent._streaming_loop = AsyncMock() - async def swallow_background_task(coro): + def swallow_background_task(coro): coro.close() per_vad_agent.add_behavior = swallow_background_task @@ -150,6 +150,7 @@ async def test_application_startup_complete(zmq_context): vad_agent._running = True vad_agent._reset_stream = AsyncMock() vad_agent.program_sub_socket = AsyncMock() + vad_agent.program_sub_socket.close = MagicMock() vad_agent.program_sub_socket.recv_multipart.side_effect = [ (PROGRAM_STATUS, ProgramStatus.RUNNING.value), ] diff --git a/test/unit/agents/actuation/test_robot_gesture_agent.py b/test/unit/agents/actuation/test_robot_gesture_agent.py index fe051a6..1e6fd8a 100644 --- a/test/unit/agents/actuation/test_robot_gesture_agent.py +++ b/test/unit/agents/actuation/test_robot_gesture_agent.py @@ -28,7 +28,11 @@ async 
def test_setup_bind(zmq_context, mocker): settings = mocker.patch("control_backend.agents.actuation.robot_gesture_agent.settings") settings.zmq_settings.internal_sub_address = "tcp://internal:1234" - agent.add_behavior = MagicMock() + def close_coro(coro): + coro.close() + return MagicMock() + + agent.add_behavior = MagicMock(side_effect=close_coro) await agent.setup() @@ -55,7 +59,11 @@ async def test_setup_connect(zmq_context, mocker): settings = mocker.patch("control_backend.agents.actuation.robot_gesture_agent.settings") settings.zmq_settings.internal_sub_address = "tcp://internal:1234" - agent.add_behavior = MagicMock() + def close_coro(coro): + coro.close() + return MagicMock() + + agent.add_behavior = MagicMock(side_effect=close_coro) await agent.setup() @@ -119,6 +127,65 @@ async def test_handle_message_rejects_invalid_gesture_tag(): pubsocket.send_json.assert_not_awaited() +@pytest.mark.asyncio +async def test_handle_message_sends_valid_single_gesture_command(): + """Internal message with valid single gesture is forwarded.""" + pubsocket = AsyncMock() + agent = RobotGestureAgent("robot_gesture", single_gesture_data=["wave", "point"], address="") + agent.pubsocket = pubsocket + + payload = { + "endpoint": RIEndpoint.GESTURE_SINGLE, + "data": "wave", + } + msg = InternalMessage(to="robot", sender="tester", body=json.dumps(payload)) + + await agent.handle_message(msg) + + pubsocket.send_json.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_handle_message_rejects_invalid_single_gesture(): + """Internal message with invalid single gesture is not forwarded.""" + pubsocket = AsyncMock() + agent = RobotGestureAgent("robot_gesture", single_gesture_data=["wave", "point"], address="") + agent.pubsocket = pubsocket + + payload = { + "endpoint": RIEndpoint.GESTURE_SINGLE, + "data": "dance", + } + msg = InternalMessage(to="robot", sender="tester", body=json.dumps(payload)) + + await agent.handle_message(msg) + + pubsocket.send_json.assert_not_awaited() + 
+ +@pytest.mark.asyncio +async def test_zmq_command_loop_valid_single_gesture_payload(): + """UI command with valid single gesture is read from SUB and published.""" + command = {"endpoint": RIEndpoint.GESTURE_SINGLE, "data": "wave"} + fake_socket = AsyncMock() + + async def recv_once(): + agent._running = False + return b"command", json.dumps(command).encode("utf-8") + + fake_socket.recv_multipart = recv_once + fake_socket.send_json = AsyncMock() + + agent = RobotGestureAgent("robot_gesture", single_gesture_data=["wave", "point"], address="") + agent.subsocket = fake_socket + agent.pubsocket = fake_socket + agent._running = True + + await agent._zmq_command_loop() + + fake_socket.send_json.assert_awaited_once() + + @pytest.mark.asyncio async def test_handle_message_invalid_payload(): """Invalid payload is caught and does not send.""" @@ -411,8 +478,7 @@ async def test_stop_closes_sockets(): pubsocket.close.assert_called_once() subsocket.close.assert_called_once() - # Note: repsocket is not closed in stop() method, but you might want to add it - # repsocket.close.assert_called_once() + repsocket.close.assert_called_once() @pytest.mark.asyncio diff --git a/test/unit/agents/actuation/test_robot_speech_agent.py b/test/unit/agents/actuation/test_robot_speech_agent.py index d95f66a..e5a664d 100644 --- a/test/unit/agents/actuation/test_robot_speech_agent.py +++ b/test/unit/agents/actuation/test_robot_speech_agent.py @@ -30,7 +30,11 @@ async def test_setup_bind(zmq_context, mocker): settings = mocker.patch("control_backend.agents.actuation.robot_speech_agent.settings") settings.zmq_settings.internal_sub_address = "tcp://internal:1234" - agent.add_behavior = MagicMock() + def close_coro(coro): + coro.close() + return MagicMock() + + agent.add_behavior = MagicMock(side_effect=close_coro) await agent.setup() @@ -48,7 +52,11 @@ async def test_setup_connect(zmq_context, mocker): settings = mocker.patch("control_backend.agents.actuation.robot_speech_agent.settings") 
settings.zmq_settings.internal_sub_address = "tcp://internal:1234" - agent.add_behavior = MagicMock() + def close_coro(coro): + coro.close() + return MagicMock() + + agent.add_behavior = MagicMock(side_effect=close_coro) await agent.setup() diff --git a/test/unit/agents/bdi/test_agentspeak_ast.py b/test/unit/agents/bdi/test_agentspeak_ast.py new file mode 100644 index 0000000..8d3bdf0 --- /dev/null +++ b/test/unit/agents/bdi/test_agentspeak_ast.py @@ -0,0 +1,186 @@ +import pytest + +from control_backend.agents.bdi.agentspeak_ast import ( + AstAtom, + AstBinaryOp, + AstLiteral, + AstLogicalExpression, + AstNumber, + AstPlan, + AstProgram, + AstRule, + AstStatement, + AstString, + AstVar, + BinaryOperatorType, + StatementType, + TriggerType, + _coalesce_expr, +) + + +def test_ast_atom(): + atom = AstAtom("test") + assert str(atom) == "test" + assert atom._to_agentspeak() == "test" + + +def test_ast_var(): + var = AstVar("Variable") + assert str(var) == "Variable" + assert var._to_agentspeak() == "Variable" + + +def test_ast_number(): + num = AstNumber(42) + assert str(num) == "42" + num_float = AstNumber(3.14) + assert str(num_float) == "3.14" + + +def test_ast_string(): + s = AstString("hello") + assert str(s) == '"hello"' + + +def test_ast_literal(): + lit = AstLiteral("functor", [AstAtom("atom"), AstNumber(1)]) + assert str(lit) == "functor(atom, 1)" + lit_empty = AstLiteral("functor") + assert str(lit_empty) == "functor" + + +def test_ast_binary_op(): + left = AstNumber(1) + right = AstNumber(2) + op = AstBinaryOp(left, BinaryOperatorType.GREATER_THAN, right) + assert str(op) == "1 > 2" + + # Test logical wrapper + assert isinstance(op.left, AstLogicalExpression) + assert isinstance(op.right, AstLogicalExpression) + + +def test_ast_binary_op_parens(): + # 1 > 2 + inner = AstBinaryOp(AstNumber(1), BinaryOperatorType.GREATER_THAN, AstNumber(2)) + # (1 > 2) & 3 + outer = AstBinaryOp(inner, BinaryOperatorType.AND, AstNumber(3)) + assert str(outer) == "(1 > 2) & 3" + 
+ # 3 & (1 > 2) + outer_right = AstBinaryOp(AstNumber(3), BinaryOperatorType.AND, inner) + assert str(outer_right) == "3 & (1 > 2)" + + +def test_ast_binary_op_parens_negated(): + inner = AstLogicalExpression(AstAtom("foo"), negated=True) + outer = AstBinaryOp(inner, BinaryOperatorType.AND, AstAtom("bar")) + # The current implementation checks `if self.left.negated: l_str = f"({l_str})"` + # str(inner) is "not foo" + # so we expect "(not foo) & bar" + assert str(outer) == "(not foo) & bar" + + outer_right = AstBinaryOp(AstAtom("bar"), BinaryOperatorType.AND, inner) + assert str(outer_right) == "bar & (not foo)" + + +def test_ast_logical_expression_negation(): + expr = AstLogicalExpression(AstAtom("true"), negated=True) + assert str(expr) == "not true" + + expr_neg_neg = ~expr + assert str(expr_neg_neg) == "true" + assert not expr_neg_neg.negated + + # Invert a non-logical expression (wraps it) + term = AstAtom("true") + inverted = ~term + assert isinstance(inverted, AstLogicalExpression) + assert inverted.negated + assert str(inverted) == "not true" + + +def test_ast_logical_expression_no_negation(): + # _as_logical on already logical expression + expr = AstLogicalExpression(AstAtom("x")) + # Doing binary op will call _as_logical + op = AstBinaryOp(expr, BinaryOperatorType.AND, AstAtom("y")) + assert isinstance(op.left, AstLogicalExpression) + assert op.left is expr # Should reuse instance + + +def test_ast_operators(): + t1 = AstAtom("a") + t2 = AstAtom("b") + + assert str(t1 & t2) == "a & b" + assert str(t1 | t2) == "a | b" + assert str(t1 >= t2) == "a >= b" + assert str(t1 > t2) == "a > b" + assert str(t1 <= t2) == "a <= b" + assert str(t1 < t2) == "a < b" + assert str(t1 == t2) == "a == b" + assert str(t1 != t2) == r"a \== b" + + +def test_coalesce_expr(): + t = AstAtom("a") + assert str(t & "b") == 'a & "b"' + assert str(t & 1) == "a & 1" + assert str(t & 1.5) == "a & 1.5" + + with pytest.raises(TypeError): + _coalesce_expr(None) + + +def test_ast_statement(): 
+ stmt = AstStatement(StatementType.DO_ACTION, AstLiteral("action")) + assert str(stmt) == ".action" + + +def test_ast_rule(): + # Rule with condition + rule = AstRule(AstLiteral("head"), AstLiteral("body")) + assert str(rule) == "head :- body." + + # Rule without condition + rule_simple = AstRule(AstLiteral("fact")) + assert str(rule_simple) == "fact." + + +def test_ast_plan(): + plan = AstPlan( + TriggerType.ADDED_GOAL, + AstLiteral("goal"), + [AstLiteral("context")], + [AstStatement(StatementType.DO_ACTION, AstLiteral("action"))], + ) + output = str(plan) + # verify parts exist + assert "+!goal" in output + assert ": context" in output + assert "<- .action." in output + + +def test_ast_plan_no_context(): + plan = AstPlan( + TriggerType.ADDED_GOAL, + AstLiteral("goal"), + [], + [AstStatement(StatementType.DO_ACTION, AstLiteral("action"))], + ) + output = str(plan) + assert "+!goal" in output + assert ": " not in output + assert "<- .action." in output + + +def test_ast_program(): + prog = AstProgram( + rules=[AstRule(AstLiteral("fact"))], + plans=[AstPlan(TriggerType.ADDED_BELIEF, AstLiteral("b"), [], [])], + ) + output = str(prog) + assert "fact." 
in output + assert "+b" in output diff --git a/test/unit/agents/bdi/test_agentspeak_generator.py b/test/unit/agents/bdi/test_agentspeak_generator.py new file mode 100644 index 0000000..5a3a849 --- /dev/null +++ b/test/unit/agents/bdi/test_agentspeak_generator.py @@ -0,0 +1,187 @@ +import uuid + +import pytest + +from control_backend.agents.bdi.agentspeak_ast import AstProgram +from control_backend.agents.bdi.agentspeak_generator import AgentSpeakGenerator +from control_backend.schemas.program import ( + BasicNorm, + ConditionalNorm, + Gesture, + GestureAction, + Goal, + InferredBelief, + KeywordBelief, + LLMAction, + LogicalOperator, + Phase, + Plan, + Program, + SemanticBelief, + SpeechAction, + Trigger, +) + + +@pytest.fixture +def generator(): + return AgentSpeakGenerator() + + +def test_generate_empty_program(generator): + prog = Program(phases=[]) + code = generator.generate(prog) + assert 'phase("end").' in code + assert "!notify_cycle" in code + + +def test_generate_basic_norm(generator): + norm = BasicNorm(id=uuid.uuid4(), name="n1", norm="be nice") + phase = Phase(id=uuid.uuid4(), norms=[norm], goals=[], triggers=[]) + prog = Program(phases=[phase]) + + code = generator.generate(prog) + assert f'norm("be nice") :- phase("{phase.id}").' in code + + +def test_generate_critical_norm(generator): + norm = BasicNorm(id=uuid.uuid4(), name="n1", norm="safety", critical=True) + phase = Phase(id=uuid.uuid4(), norms=[norm], goals=[], triggers=[]) + prog = Program(phases=[phase]) + + code = generator.generate(prog) + assert f'critical_norm("safety") :- phase("{phase.id}").' 
in code + + +def test_generate_conditional_norm(generator): + cond = KeywordBelief(id=uuid.uuid4(), name="k1", keyword="please") + norm = ConditionalNorm(id=uuid.uuid4(), name="n1", norm="help", condition=cond) + phase = Phase(id=uuid.uuid4(), norms=[norm], goals=[], triggers=[]) + prog = Program(phases=[phase]) + + code = generator.generate(prog) + assert 'norm("help")' in code + assert 'keyword_said("please")' in code + assert f"force_norm_{generator._slugify_str(norm.norm)}" in code + + +def test_generate_goal_and_plan(generator): + action = SpeechAction(id=uuid.uuid4(), name="s1", text="hello") + plan = Plan(id=uuid.uuid4(), name="p1", steps=[action]) + # IMPORTANT: can_fail must be False for +achieved_ belief to be added + goal = Goal(id=uuid.uuid4(), name="g1", description="desc", plan=plan, can_fail=False) + phase = Phase(id=uuid.uuid4(), norms=[], goals=[goal], triggers=[]) + prog = Program(phases=[phase]) + + code = generator.generate(prog) + # Check trigger for goal + goal_slug = generator._slugify_str(goal.name) + assert f"+!{goal_slug}" in code + assert f'phase("{phase.id}")' in code + assert '!say("hello")' in code + + # Check success belief addition + assert f"+achieved_{goal_slug}" in code + + +def test_generate_subgoal(generator): + subplan = Plan(id=uuid.uuid4(), name="p2", steps=[]) + subgoal = Goal(id=uuid.uuid4(), name="sub1", description="sub", plan=subplan) + + plan = Plan(id=uuid.uuid4(), name="p1", steps=[subgoal]) + goal = Goal(id=uuid.uuid4(), name="g1", description="main", plan=plan) + phase = Phase(id=uuid.uuid4(), norms=[], goals=[goal], triggers=[]) + prog = Program(phases=[phase]) + + code = generator.generate(prog) + subgoal_slug = generator._slugify_str(subgoal.name) + # Main goal calls subgoal + assert f"!{subgoal_slug}" in code + # Subgoal plan exists + assert f"+!{subgoal_slug}" in code + + +def test_generate_trigger(generator): + cond = SemanticBelief(id=uuid.uuid4(), name="s1", description="desc") + plan = Plan(id=uuid.uuid4(), 
name="p1", steps=[]) + trigger = Trigger(id=uuid.uuid4(), name="t1", condition=cond, plan=plan) + phase = Phase(id=uuid.uuid4(), norms=[], goals=[], triggers=[trigger]) + prog = Program(phases=[phase]) + + code = generator.generate(prog) + # Trigger logic is added to check_triggers + assert f"{generator.slugify(cond)}" in code + assert f'notify_trigger_start("{generator.slugify(trigger)}")' in code + assert f'notify_trigger_end("{generator.slugify(trigger)}")' in code + + +def test_phase_transition(generator): + phase1 = Phase(id=uuid.uuid4(), name="p1", norms=[], goals=[], triggers=[]) + phase2 = Phase(id=uuid.uuid4(), name="p2", norms=[], goals=[], triggers=[]) + prog = Program(phases=[phase1, phase2]) + + code = generator.generate(prog) + assert "transition_phase" in code + assert f'phase("{phase1.id}")' in code + assert f'phase("{phase2.id}")' in code + assert "force_transition_phase" in code + + +def test_astify_gesture(generator): + gesture = Gesture(type="single", name="wave") + action = GestureAction(id=uuid.uuid4(), name="g1", gesture=gesture) + ast = generator._astify(action) + assert str(ast) == 'gesture("single", "wave")' + + +def test_astify_llm_action(generator): + action = LLMAction(id=uuid.uuid4(), name="l1", goal="be funny") + ast = generator._astify(action) + assert str(ast) == 'reply_with_goal("be funny")' + + +def test_astify_inferred_belief_and(generator): + left = KeywordBelief(id=uuid.uuid4(), name="k1", keyword="a") + right = KeywordBelief(id=uuid.uuid4(), name="k2", keyword="b") + inf = InferredBelief( + id=uuid.uuid4(), name="i1", operator=LogicalOperator.AND, left=left, right=right + ) + + ast = generator._astify(inf) + assert 'keyword_said("a") & keyword_said("b")' == str(ast) + + +def test_astify_inferred_belief_or(generator): + left = KeywordBelief(id=uuid.uuid4(), name="k1", keyword="a") + right = KeywordBelief(id=uuid.uuid4(), name="k2", keyword="b") + inf = InferredBelief( + id=uuid.uuid4(), name="i1", operator=LogicalOperator.OR, 
left=left, right=right + ) + + ast = generator._astify(inf) + assert 'keyword_said("a") | keyword_said("b")' == str(ast) + + +def test_astify_semantic_belief(generator): + sb = SemanticBelief(id=uuid.uuid4(), name="s1", description="desc") + ast = generator._astify(sb) + assert str(ast) == f"semantic_{generator._slugify_str(sb.name)}" + + +def test_slugify_not_implemented(generator): + with pytest.raises(NotImplementedError): + generator.slugify("not a program element") + + +def test_astify_not_implemented(generator): + with pytest.raises(NotImplementedError): + generator._astify("not a program element") + + +def test_process_phase_transition_from_none(generator): + # Initialize AstProgram manually as we are bypassing generate() + generator._asp = AstProgram() + # Should safely return doing nothing + generator._add_phase_transition(None, None) + + assert len(generator._asp.plans) == 0 diff --git a/test/unit/agents/bdi/test_bdi_core_agent.py b/test/unit/agents/bdi/test_bdi_core_agent.py index 8d004fc..6245d5b 100644 --- a/test/unit/agents/bdi/test_bdi_core_agent.py +++ b/test/unit/agents/bdi/test_bdi_core_agent.py @@ -20,7 +20,7 @@ def mock_agentspeak_env(): @pytest.fixture def agent(): - agent = BDICoreAgent("bdi_agent", "dummy.asl") + agent = BDICoreAgent("bdi_agent") agent.send = AsyncMock() agent.bdi_agent = MagicMock() return agent @@ -45,31 +45,70 @@ async def test_setup_no_asl(mock_agentspeak_env, agent): @pytest.mark.asyncio -async def test_handle_belief_collector_message(agent, mock_settings): +async def test_handle_belief_message(agent, mock_settings): """Test that incoming beliefs are added to the BDI agent""" beliefs = [Belief(name="user_said", arguments=["Hello"])] msg = InternalMessage( to="bdi_agent", - sender=mock_settings.agent_settings.bdi_belief_collector_name, - body=BeliefMessage(beliefs=beliefs).model_dump_json(), + sender=mock_settings.agent_settings.text_belief_extractor_name, + body=BeliefMessage(create=beliefs).model_dump_json(), 
thread="beliefs", ) await agent.handle_message(msg) - # Expect bdi_agent.call to be triggered to add belief - args = agent.bdi_agent.call.call_args.args - assert args[0] == agentspeak.Trigger.addition - assert args[1] == agentspeak.GoalType.belief - assert args[2] == agentspeak.Literal("user_said", (agentspeak.Literal("Hello"),)) + # Check for the specific call we expect among all calls + # bdi_agent.call is called multiple times (for transition_phase, check_triggers) + # We want to confirm the belief addition call exists + found_call = False + for call in agent.bdi_agent.call.call_args_list: + args = call.args + if ( + args[0] == agentspeak.Trigger.addition + and args[1] == agentspeak.GoalType.belief + and args[2].functor == "user_said" + and args[2].args[0].functor == "Hello" + ): + found_call = True + break + + assert found_call, "Expected belief addition call not found in bdi_agent.call history" @pytest.mark.asyncio -async def test_incorrect_belief_collector_message(agent, mock_settings): +async def test_handle_delete_belief_message(agent, mock_settings): + """Test that incoming beliefs to be deleted are removed from the BDI agent""" + beliefs = [Belief(name="user_said", arguments=["Hello"])] + + msg = InternalMessage( + to="bdi_agent", + sender=mock_settings.agent_settings.text_belief_extractor_name, + body=BeliefMessage(delete=beliefs).model_dump_json(), + thread="beliefs", + ) + await agent.handle_message(msg) + + found_call = False + for call in agent.bdi_agent.call.call_args_list: + args = call.args + if ( + args[0] == agentspeak.Trigger.removal + and args[1] == agentspeak.GoalType.belief + and args[2].functor == "user_said" + and args[2].args[0].functor == "Hello" + ): + found_call = True + break + + assert found_call + + +@pytest.mark.asyncio +async def test_incorrect_belief_message(agent, mock_settings): """Test that incorrect message format triggers an exception.""" msg = InternalMessage( to="bdi_agent", - 
sender=mock_settings.agent_settings.bdi_belief_collector_name, + sender=mock_settings.agent_settings.text_belief_extractor_name, body=json.dumps({"bad_format": "bad_format"}), thread="beliefs", ) @@ -113,14 +152,14 @@ async def test_custom_actions(agent): # Invoke action mock_term = MagicMock() - mock_term.args = ["Hello", "Norm", "Goal"] + mock_term.args = ["Hello", "Norm"] mock_intention = MagicMock() # Run generator gen = action_fn(agent, mock_term, mock_intention) next(gen) # Execute - agent._send_to_llm.assert_called_with("Hello", "Norm", "Goal") + agent._send_to_llm.assert_called_with("Hello", "Norm", "") def test_add_belief_sets_event(agent): @@ -128,7 +167,8 @@ def test_add_belief_sets_event(agent): agent._wake_bdi_loop = MagicMock() belief = Belief(name="test_belief", arguments=["a", "b"]) - agent._apply_beliefs([belief]) + belief_changes = BeliefMessage(replace=[belief]) + agent._apply_belief_changes(belief_changes) assert agent.bdi_agent.call.called agent._wake_bdi_loop.set.assert_called() @@ -137,7 +177,7 @@ def test_add_belief_sets_event(agent): def test_apply_beliefs_empty_returns(agent): """Line: if not beliefs: return""" agent._wake_bdi_loop = MagicMock() - agent._apply_beliefs([]) + agent._apply_belief_changes(BeliefMessage()) agent.bdi_agent.call.assert_not_called() agent._wake_bdi_loop.set.assert_not_called() @@ -150,7 +190,11 @@ def test_remove_belief_success_wakes_loop(agent): agent._remove_belief("remove_me", ["x"]) assert agent.bdi_agent.call.called - trigger, goaltype, literal, *_ = agent.bdi_agent.call.call_args.args + + call_args = agent.bdi_agent.call.call_args.args + trigger = call_args[0] + goaltype = call_args[1] + literal = call_args[2] assert trigger == agentspeak.Trigger.removal assert goaltype == agentspeak.GoalType.belief @@ -220,8 +264,9 @@ def test_replace_belief_calls_remove_all(agent): agent._remove_all_with_name = MagicMock() agent._wake_bdi_loop = MagicMock() - belief = Belief(name="user_said", arguments=["Hello"], 
replace=True) - agent._apply_beliefs([belief]) + belief = Belief(name="user_said", arguments=["Hello"]) + belief_changes = BeliefMessage(replace=[belief]) + agent._apply_belief_changes(belief_changes) agent._remove_all_with_name.assert_called_with("user_said") @@ -266,3 +311,216 @@ async def test_deadline_sleep_branch(agent): duration = time.time() - start_time assert duration >= 0.004 # loop slept until deadline + + +@pytest.mark.asyncio +async def test_handle_new_program(agent): + agent._load_asl = AsyncMock() + agent.add_behavior = MagicMock() + # Mock existing loop task so it can be cancelled + mock_task = MagicMock() + mock_task.cancel = MagicMock() + agent._bdi_loop_task = mock_task + + def close_coro(coro): + coro.close() + return MagicMock() + + agent.add_behavior = MagicMock(side_effect=close_coro) + + msg = InternalMessage(to="bdi_agent", thread="new_program", body="path/to/asl.asl") + + await agent.handle_message(msg) + + mock_task.cancel.assert_called_once() + agent._load_asl.assert_awaited_once_with("path/to/asl.asl") + agent.add_behavior.assert_called() + + +@pytest.mark.asyncio +async def test_handle_user_interrupts(agent, mock_settings): + mock_settings.agent_settings.user_interrupt_name = "user_interrupt_agent" + # force_phase_transition + agent._set_goal = MagicMock() + msg = InternalMessage( + to="bdi_agent", + sender=mock_settings.agent_settings.user_interrupt_name, + thread="force_phase_transition", + body="", + ) + await agent.handle_message(msg) + agent._set_goal.assert_called_with("transition_phase") + + # force_trigger + agent._force_trigger = MagicMock() + msg.thread = "force_trigger" + msg.body = "trigger_x" + await agent.handle_message(msg) + agent._force_trigger.assert_called_with("trigger_x") + + # force_norm + agent._force_norm = MagicMock() + msg.thread = "force_norm" + msg.body = "norm_y" + await agent.handle_message(msg) + agent._force_norm.assert_called_with("norm_y") + + # force_next_phase + agent._force_next_phase = MagicMock() 
+ msg.thread = "force_next_phase" + msg.body = "" + await agent.handle_message(msg) + agent._force_next_phase.assert_called_once() + + # unknown interrupt + agent.logger = MagicMock() + msg.thread = "unknown_thing" + await agent.handle_message(msg) + agent.logger.warning.assert_called() + + +@pytest.mark.asyncio +async def test_custom_action_reply_with_goal(agent): + agent._send_to_llm = MagicMock(side_effect=agent.send) + agent._add_custom_actions() + action_fn = agent.actions.actions[(".reply_with_goal", 3)] + + mock_term = MagicMock(args=["msg", "norms", "goal"]) + gen = action_fn(agent, mock_term, MagicMock()) + next(gen) + agent._send_to_llm.assert_called_with("msg", "norms", "goal") + + +@pytest.mark.asyncio +async def test_custom_action_notify_norms(agent): + agent._add_custom_actions() + action_fn = agent.actions.actions[(".notify_norms", 1)] + + mock_term = MagicMock(args=["norms_list"]) + gen = action_fn(agent, mock_term, MagicMock()) + next(gen) + + agent.send.assert_called() + msg = agent.send.call_args[0][0] + assert msg.thread == "active_norms_update" + assert msg.body == "norms_list" + + +@pytest.mark.asyncio +async def test_custom_action_say(agent): + agent._add_custom_actions() + action_fn = agent.actions.actions[(".say", 1)] + + mock_term = MagicMock(args=["hello"]) + gen = action_fn(agent, mock_term, MagicMock()) + next(gen) + + assert agent.send.call_count == 2 + msgs = [c[0][0] for c in agent.send.call_args_list] + assert any(m.to == settings.agent_settings.robot_speech_name for m in msgs) + assert any( + m.to == settings.agent_settings.llm_name and m.thread == "assistant_message" for m in msgs + ) + + +@pytest.mark.asyncio +async def test_custom_action_gesture(agent): + agent._add_custom_actions() + # Test single + action_fn = agent.actions.actions[(".gesture", 2)] + mock_term = MagicMock(args=["single", "wave"]) + gen = action_fn(agent, mock_term, MagicMock()) + next(gen) + msg = agent.send.call_args[0][0] + assert "actuate/gesture/single" in 
msg.body + + # Test tag + mock_term.args = ["tag", "happy"] + gen = action_fn(agent, mock_term, MagicMock()) + next(gen) + msg = agent.send.call_args[0][0] + assert "actuate/gesture/tag" in msg.body + + +@pytest.mark.asyncio +async def test_custom_action_notify_user_said(agent): + agent._add_custom_actions() + action_fn = agent.actions.actions[(".notify_user_said", 1)] + mock_term = MagicMock(args=["hello"]) + gen = action_fn(agent, mock_term, MagicMock()) + next(gen) + msg = agent.send.call_args[0][0] + assert msg.to == settings.agent_settings.llm_name + assert msg.thread == "user_message" + + +@pytest.mark.asyncio +async def test_custom_action_notify_trigger_start_end(agent): + agent._add_custom_actions() + # Start + action_fn = agent.actions.actions[(".notify_trigger_start", 1)] + gen = action_fn(agent, MagicMock(args=["t1"]), MagicMock()) + next(gen) + assert agent.send.call_args[0][0].thread == "trigger_start" + + # End + action_fn = agent.actions.actions[(".notify_trigger_end", 1)] + gen = action_fn(agent, MagicMock(args=["t1"]), MagicMock()) + next(gen) + assert agent.send.call_args[0][0].thread == "trigger_end" + + +@pytest.mark.asyncio +async def test_custom_action_notify_goal_start(agent): + agent._add_custom_actions() + action_fn = agent.actions.actions[(".notify_goal_start", 1)] + gen = action_fn(agent, MagicMock(args=["g1"]), MagicMock()) + next(gen) + assert agent.send.call_args[0][0].thread == "goal_start" + + +@pytest.mark.asyncio +async def test_custom_action_notify_transition_phase(agent): + agent._add_custom_actions() + action_fn = agent.actions.actions[(".notify_transition_phase", 2)] + gen = action_fn(agent, MagicMock(args=["old", "new"]), MagicMock()) + next(gen) + msg = agent.send.call_args[0][0] + assert msg.thread == "transition_phase" + assert "old" in msg.body and "new" in msg.body + + +def test_remove_belief_no_args(agent): + agent._wake_bdi_loop = MagicMock() + agent.bdi_agent.call.return_value = True + agent._remove_belief("fact", 
None) + assert agent.bdi_agent.call.called + + +def test_set_goal_with_args(agent): + agent._wake_bdi_loop = MagicMock() + agent._set_goal("goal", ["arg1", "arg2"]) + assert agent.bdi_agent.call.called + + +def test_format_belief_string(): + assert BDICoreAgent.format_belief_string("b") == "b" + assert BDICoreAgent.format_belief_string("b", ["a1", "a2"]) == "b(a1,a2)" + + +def test_force_norm(agent): + agent._add_belief = MagicMock() + agent._force_norm("be_polite") + agent._add_belief.assert_called_with("force_be_polite") + + +def test_force_trigger(agent): + agent._set_goal = MagicMock() + agent._force_trigger("trig") + agent._set_goal.assert_called_with("trig") + + +def test_force_next_phase(agent): + agent._set_goal = MagicMock() + agent._force_next_phase() + agent._set_goal.assert_called_with("force_transition_phase") diff --git a/test/unit/agents/bdi/test_bdi_program_manager.py b/test/unit/agents/bdi/test_bdi_program_manager.py index 573524e..5771451 100644 --- a/test/unit/agents/bdi/test_bdi_program_manager.py +++ b/test/unit/agents/bdi/test_bdi_program_manager.py @@ -1,55 +1,73 @@ import asyncio import json import sys -from unittest.mock import AsyncMock +import uuid +from unittest.mock import AsyncMock, MagicMock, mock_open, patch import pytest from control_backend.agents.bdi.bdi_program_manager import BDIProgramManager from control_backend.core.agent_system import InternalMessage -from control_backend.schemas.belief_message import BeliefMessage -from control_backend.schemas.program import Program +from control_backend.schemas.program import BasicNorm, Goal, Phase, Plan, Program # Fix Windows Proactor loop for zmq if sys.platform.startswith("win"): asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) -def make_valid_program_json(norm="N1", goal="G1"): - return json.dumps( - { - "phases": [ - { - "id": "phase1", - "label": "Phase 1", - "triggers": [], - "norms": [{"id": "n1", "label": "Norm 1", "norm": norm}], - "goals": [ - {"id": "g1", 
"label": "Goal 1", "description": goal, "achieved": False} - ], - } - ] - } - ) +def make_valid_program_json(norm="N1", goal="G1") -> str: + return Program( + phases=[ + Phase( + id=uuid.uuid4(), + name="Basic Phase", + norms=[ + BasicNorm( + id=uuid.uuid4(), + name=norm, + norm=norm, + ), + ], + goals=[ + Goal( + id=uuid.uuid4(), + name=goal, + description="This description can be used to determine whether the goal " + "has been achieved.", + plan=Plan( + id=uuid.uuid4(), + name="Goal Plan", + steps=[], + ), + can_fail=False, + ), + ], + triggers=[], + ), + ], + ).model_dump_json() @pytest.mark.asyncio -async def test_send_to_bdi(): +async def test_create_agentspeak_and_send_to_bdi(mock_settings): manager = BDIProgramManager(name="program_manager_test") manager.send = AsyncMock() program = Program.model_validate_json(make_valid_program_json()) - await manager._send_to_bdi(program) + + with patch("builtins.open", mock_open()) as mock_file: + await manager._create_agentspeak_and_send_to_bdi(program) + + # Check file writing + mock_file.assert_called_with(mock_settings.behaviour_settings.agentspeak_file, "w") + handle = mock_file() + handle.write.assert_called() assert manager.send.await_count == 1 msg: InternalMessage = manager.send.await_args[0][0] - assert msg.thread == "beliefs" - - beliefs = BeliefMessage.model_validate_json(msg.body) - names = {b.name: b.arguments for b in beliefs.beliefs} - - assert "norms" in names and names["norms"] == ["N1"] - assert "goals" in names and names["goals"] == ["G1"] + assert msg.thread == "new_program" + assert msg.to == mock_settings.agent_settings.bdi_core_name + assert msg.body == mock_settings.behaviour_settings.agentspeak_file @pytest.mark.asyncio @@ -61,9 +79,13 @@ async def test_receive_programs_valid_and_invalid(): ] manager = BDIProgramManager(name="program_manager_test") + manager._internal_pub_socket = AsyncMock() manager.sub_socket = sub - manager._send_to_bdi = AsyncMock() + 
manager._create_agentspeak_and_send_to_bdi = AsyncMock() manager._send_clear_llm_history = AsyncMock() + manager._send_program_to_user_interrupt = AsyncMock() + manager._send_beliefs_to_semantic_belief_extractor = AsyncMock() + manager._send_goals_to_semantic_belief_extractor = AsyncMock() try: # Will give StopAsyncIteration when the predefined `sub.recv_multipart` side-effects run out @@ -72,12 +94,13 @@ async def test_receive_programs_valid_and_invalid(): pass # Only valid Program should have triggered _send_to_bdi - assert manager._send_to_bdi.await_count == 1 - forwarded: Program = manager._send_to_bdi.await_args[0][0] - assert forwarded.phases[0].norms[0].norm == "N1" - assert forwarded.phases[0].goals[0].description == "G1" + assert manager._create_agentspeak_and_send_to_bdi.await_count == 1 + forwarded: Program = manager._create_agentspeak_and_send_to_bdi.await_args[0][0] + assert forwarded.phases[0].norms[0].name == "N1" + assert forwarded.phases[0].goals[0].name == "G1" - # Verify history clear was triggered + # Verify history clear was triggered exactly once (for the valid program) + # The invalid program loop `continue`s before calling _send_clear_llm_history assert manager._send_clear_llm_history.await_count == 1 @@ -91,9 +114,184 @@ async def test_send_clear_llm_history(mock_settings): await manager._send_clear_llm_history() - assert manager.send.await_count == 1 - msg: InternalMessage = manager.send.await_args[0][0] + assert manager.send.await_count == 2 + msg: InternalMessage = manager.send.await_args_list[0][0][0] # Verify the content and recipient assert msg.body == "clear_history" - assert msg.to == "llm_agent" + + +@pytest.mark.asyncio +async def test_handle_message_transition_phase(mock_settings): + mock_settings.agent_settings.user_interrupt_name = "user_interrupt_agent" + manager = BDIProgramManager(name="program_manager_test") + manager.send = AsyncMock() + + # Setup state + prog = 
Program.model_validate_json(make_valid_program_json(norm="N1", goal="G1")) + manager._initialize_internal_state(prog) + + # Test valid transition (to same phase for simplicity, or we need 2 phases) + # Let's create a program with 2 phases + phase2_id = uuid.uuid4() + phase2 = Phase(id=phase2_id, name="Phase 2", norms=[], goals=[], triggers=[]) + prog.phases.append(phase2) + manager._initialize_internal_state(prog) + + current_phase_id = str(prog.phases[0].id) + next_phase_id = str(phase2_id) + + payload = json.dumps({"old": current_phase_id, "new": next_phase_id}) + msg = InternalMessage(to="me", sender="bdi", body=payload, thread="transition_phase") + + await manager.handle_message(msg) + + assert str(manager._phase.id) == next_phase_id + + # Allow background tasks to run (add_behavior) + await asyncio.sleep(0) + + # Check notifications sent + # 1. beliefs to extractor + # 2. goals to extractor + # 3. notification to user interrupt + + assert manager.send.await_count >= 3 + + # Verify user interrupt notification + calls = manager.send.await_args_list + ui_msgs = [ + c[0][0] for c in calls if c[0][0].to == mock_settings.agent_settings.user_interrupt_name + ] + assert len(ui_msgs) > 0 + assert ui_msgs[-1].body == next_phase_id + + +@pytest.mark.asyncio +async def test_handle_message_transition_phase_desync(): + manager = BDIProgramManager(name="program_manager_test") + manager.logger = MagicMock() + + prog = Program.model_validate_json(make_valid_program_json()) + manager._initialize_internal_state(prog) + + current_phase_id = str(prog.phases[0].id) + + # Request transition from WRONG old phase + payload = json.dumps({"old": "wrong_id", "new": "some_new_id"}) + msg = InternalMessage(to="me", sender="bdi", body=payload, thread="transition_phase") + + await manager.handle_message(msg) + + # Should warn and do nothing + manager.logger.warning.assert_called_once() + assert "Phase transition desync detected" in manager.logger.warning.call_args[0][0] + assert 
str(manager._phase.id) == current_phase_id + + +@pytest.mark.asyncio +async def test_handle_message_transition_phase_end(mock_settings): + mock_settings.agent_settings.user_interrupt_name = "user_interrupt_agent" + manager = BDIProgramManager(name="program_manager_test") + manager.send = AsyncMock() + + prog = Program.model_validate_json(make_valid_program_json()) + manager._initialize_internal_state(prog) + current_phase_id = str(prog.phases[0].id) + + payload = json.dumps({"old": current_phase_id, "new": "end"}) + msg = InternalMessage(to="me", sender="bdi", body=payload, thread="transition_phase") + + await manager.handle_message(msg) + + assert manager._phase is None + + # Allow background tasks to run (add_behavior) + await asyncio.sleep(0) + + # Verify notification to user interrupt + assert manager.send.await_count == 1 + msg_sent = manager.send.await_args[0][0] + assert msg_sent.to == mock_settings.agent_settings.user_interrupt_name + assert msg_sent.body == "end" + + +@pytest.mark.asyncio +async def test_handle_message_achieve_goal(mock_settings): + mock_settings.agent_settings.text_belief_extractor_name = "text_belief_extractor_agent" + manager = BDIProgramManager(name="program_manager_test") + manager.send = AsyncMock() + + prog = Program.model_validate_json(make_valid_program_json(goal="TargetGoal")) + manager._initialize_internal_state(prog) + + goal_id = str(prog.phases[0].goals[0].id) + + msg = InternalMessage(to="me", sender="ui", body=goal_id, thread="achieve_goal") + + await manager.handle_message(msg) + + # Should send achieved goals to text extractor + assert manager.send.await_count == 1 + msg_sent = manager.send.await_args[0][0] + assert msg_sent.to == mock_settings.agent_settings.text_belief_extractor_name + assert msg_sent.thread == "achieved_goals" + + # Verify body + from control_backend.schemas.belief_list import GoalList + + gl = GoalList.model_validate_json(msg_sent.body) + assert len(gl.goals) == 1 + assert gl.goals[0].name == 
"TargetGoal" + + +@pytest.mark.asyncio +async def test_handle_message_achieve_goal_not_found(): + manager = BDIProgramManager(name="program_manager_test") + manager.send = AsyncMock() + manager.logger = MagicMock() + + prog = Program.model_validate_json(make_valid_program_json()) + manager._initialize_internal_state(prog) + + msg = InternalMessage(to="me", sender="ui", body="non_existent_id", thread="achieve_goal") + + await manager.handle_message(msg) + + manager.send.assert_not_called() + manager.logger.debug.assert_called() + + +@pytest.mark.asyncio +async def test_setup(mock_settings): + manager = BDIProgramManager(name="program_manager_test") + manager.send = AsyncMock() + + def close_coro(coro): + coro.close() + return MagicMock() + + manager.add_behavior = MagicMock(side_effect=close_coro) + + mock_context = MagicMock() + mock_sub = MagicMock() + mock_context.socket.return_value = mock_sub + + with patch( + "control_backend.agents.bdi.bdi_program_manager.Context.instance", return_value=mock_context + ): + # We also need to mock file writing in _create_agentspeak_and_send_to_bdi + with patch("builtins.open", new_callable=MagicMock): + await manager.setup() + + # Check logic + # 1. Sends default empty program to BDI + assert manager.send.await_count == 1 + assert manager.send.await_args[0][0].to == mock_settings.agent_settings.bdi_core_name + + # 2. Connects SUB socket + mock_sub.connect.assert_called_with(mock_settings.zmq_settings.internal_sub_address) + mock_sub.subscribe.assert_called_with("program") + + # 3. 
Adds behavior + manager.add_behavior.assert_called() diff --git a/test/unit/agents/bdi/test_belief_collector.py b/test/unit/agents/bdi/test_belief_collector.py deleted file mode 100644 index 67b2ed5..0000000 --- a/test/unit/agents/bdi/test_belief_collector.py +++ /dev/null @@ -1,135 +0,0 @@ -import json -from unittest.mock import AsyncMock - -import pytest - -from control_backend.agents.bdi import ( - BDIBeliefCollectorAgent, -) -from control_backend.core.agent_system import InternalMessage -from control_backend.core.config import settings -from control_backend.schemas.belief_message import Belief - - -@pytest.fixture -def agent(): - agent = BDIBeliefCollectorAgent("belief_collector_agent") - return agent - - -def make_msg(body: dict, sender: str = "sender"): - return InternalMessage(to="collector", sender=sender, body=json.dumps(body)) - - -@pytest.mark.asyncio -async def test_handle_message_routes_belief_text(agent, mocker): - """ - Test that when a message is received, _handle_belief_text is called with that message. 
- """ - payload = {"type": "belief_extraction_text", "beliefs": {"user_said": [["hi"]]}} - spy = mocker.patch.object(agent, "_handle_belief_text", new_callable=AsyncMock) - - await agent.handle_message(make_msg(payload)) - - spy.assert_awaited_once_with(payload, "sender") - - -@pytest.mark.asyncio -async def test_handle_message_routes_emotion(agent, mocker): - payload = {"type": "emotion_extraction_text"} - spy = mocker.patch.object(agent, "_handle_emo_text", new_callable=AsyncMock) - - await agent.handle_message(make_msg(payload)) - - spy.assert_awaited_once_with(payload, "sender") - - -@pytest.mark.asyncio -async def test_handle_message_bad_json(agent, mocker): - agent._handle_belief_text = AsyncMock() - bad_msg = InternalMessage(to="collector", sender="sender", body="not json") - - await agent.handle_message(bad_msg) - - agent._handle_belief_text.assert_not_awaited() - - -@pytest.mark.asyncio -async def test_handle_belief_text_sends_when_beliefs_exist(agent, mocker): - payload = {"type": "belief_extraction_text", "beliefs": {"user_said": ["hello"]}} - spy = mocker.patch.object(agent, "_send_beliefs_to_bdi", new_callable=AsyncMock) - expected = [Belief(name="user_said", arguments=["hello"])] - - await agent._handle_belief_text(payload, "origin") - - spy.assert_awaited_once_with(expected, origin="origin") - - -@pytest.mark.asyncio -async def test_handle_belief_text_no_send_when_empty(agent, mocker): - payload = {"type": "belief_extraction_text", "beliefs": {}} - spy = mocker.patch.object(agent, "_send_beliefs_to_bdi", new_callable=AsyncMock) - - await agent._handle_belief_text(payload, "origin") - - spy.assert_not_awaited() - - -@pytest.mark.asyncio -async def test_send_beliefs_to_bdi(agent): - agent.send = AsyncMock() - beliefs = [Belief(name="user_said", arguments=["hello", "world"])] - - await agent._send_beliefs_to_bdi(beliefs, origin="origin") - - agent.send.assert_awaited_once() - sent: InternalMessage = agent.send.call_args.args[0] - assert sent.to == 
settings.agent_settings.bdi_core_name - assert sent.thread == "beliefs" - assert json.loads(sent.body)["beliefs"] == [belief.model_dump() for belief in beliefs] - - -@pytest.mark.asyncio -async def test_setup_executes(agent): - """Covers setup and asserts the agent has a name.""" - await agent.setup() - assert agent.name == "belief_collector_agent" # simple property assertion - - -@pytest.mark.asyncio -async def test_handle_message_unrecognized_type_executes(agent): - """Covers the else branch for unrecognized message type.""" - payload = {"type": "unknown_type"} - msg = make_msg(payload, sender="tester") - # Wrap send to ensure nothing is sent - agent.send = AsyncMock() - await agent.handle_message(msg) - # Assert no messages were sent - agent.send.assert_not_awaited() - - -@pytest.mark.asyncio -async def test_handle_emo_text_executes(agent): - """Covers the _handle_emo_text method.""" - # The method does nothing, but we can assert it returns None - result = await agent._handle_emo_text({}, "origin") - assert result is None - - -@pytest.mark.asyncio -async def test_send_beliefs_to_bdi_empty_executes(agent): - """Covers early return when beliefs are empty.""" - agent.send = AsyncMock() - await agent._send_beliefs_to_bdi({}) - # Assert that nothing was sent - agent.send.assert_not_awaited() - - -@pytest.mark.asyncio -async def test_handle_belief_text_invalid_returns_none(agent, mocker): - payload = {"type": "belief_extraction_text", "beliefs": {"user_said": "invalid-argument"}} - - result = await agent._handle_belief_text(payload, "origin") - - # The method itself returns None - assert result is None diff --git a/test/unit/agents/bdi/test_text_belief_extractor.py b/test/unit/agents/bdi/test_text_belief_extractor.py new file mode 100644 index 0000000..353b718 --- /dev/null +++ b/test/unit/agents/bdi/test_text_belief_extractor.py @@ -0,0 +1,554 @@ +import json +import uuid +from unittest.mock import AsyncMock, MagicMock, patch + +import httpx +import pytest + +from 
control_backend.agents.bdi import TextBeliefExtractorAgent +from control_backend.agents.bdi.text_belief_extractor_agent import BeliefState +from control_backend.core.agent_system import InternalMessage +from control_backend.core.config import settings +from control_backend.schemas.belief_list import BeliefList +from control_backend.schemas.belief_message import Belief as InternalBelief +from control_backend.schemas.belief_message import BeliefMessage +from control_backend.schemas.chat_history import ChatHistory, ChatMessage +from control_backend.schemas.program import ( + BaseGoal, # Changed from Goal + ConditionalNorm, + KeywordBelief, + LLMAction, + Phase, + Plan, + Program, + SemanticBelief, + Trigger, +) + + +@pytest.fixture +def llm(): + llm = TextBeliefExtractorAgent.LLM(MagicMock(), 4) + # We must ensure _query_llm returns a dictionary so iterating it doesn't fail + llm._query_llm = AsyncMock(return_value={}) + return llm + + +@pytest.fixture +def agent(llm): + with patch( + "control_backend.agents.bdi.text_belief_extractor_agent.TextBeliefExtractorAgent.LLM", + return_value=llm, + ): + agent = TextBeliefExtractorAgent("text_belief_agent") + agent.send = AsyncMock() + return agent + + +@pytest.fixture +def sample_program(): + return Program( + phases=[ + Phase( + name="Some phase", + id=uuid.uuid4(), + norms=[ + ConditionalNorm( + name="Some norm", + id=uuid.uuid4(), + norm="Use nautical terms.", + critical=False, + condition=SemanticBelief( + name="is_pirate", + id=uuid.uuid4(), + description="The user is a pirate. 
Perhaps because they say " + "they are, or because they speak like a pirate " + 'with terms like "arr".', + ), + ), + ], + goals=[], + triggers=[ + Trigger( + name="Some trigger", + id=uuid.uuid4(), + condition=SemanticBelief( + name="no_more_booze", + id=uuid.uuid4(), + description="There is no more alcohol.", + ), + plan=Plan( + name="Some plan", + id=uuid.uuid4(), + steps=[ + LLMAction( + name="Some action", + id=uuid.uuid4(), + goal="Suggest eating chocolate instead.", + ), + ], + ), + ), + ], + ), + ], + ) + + +def make_msg(sender: str, body: str, thread: str | None = None) -> InternalMessage: + return InternalMessage(to="unused", sender=sender, body=body, thread=thread) + + +@pytest.mark.asyncio +async def test_handle_message_ignores_other_agents(agent): + msg = make_msg("unknown", "some data", None) + + await agent.handle_message(msg) + + agent.send.assert_not_called() # noqa # `agent.send` has no such property, but we mock it. + + +@pytest.mark.asyncio +async def test_handle_message_from_transcriber(agent, mock_settings): + transcription = "hello world" + msg = make_msg(mock_settings.agent_settings.transcription_name, transcription, None) + + await agent.handle_message(msg) + + agent.send.assert_awaited_once() # noqa # `agent.send` has no such property, but we mock it. 
+ sent: InternalMessage = agent.send.call_args.args[0] # noqa + assert sent.to == mock_settings.agent_settings.bdi_core_name + assert sent.thread == "beliefs" + parsed = BeliefMessage.model_validate_json(sent.body) + replaced_last = parsed.replace.pop() + assert replaced_last.name == "user_said" + assert replaced_last.arguments == [transcription] + + +@pytest.mark.asyncio +async def test_query_llm(): + mock_response = MagicMock() + mock_response.json.return_value = { + "choices": [ + { + "message": { + "content": "null", + } + } + ] + } + mock_client = AsyncMock() + mock_client.post.return_value = mock_response + mock_async_client = MagicMock() + mock_async_client.__aenter__.return_value = mock_client + mock_async_client.__aexit__.return_value = None + + with patch( + "control_backend.agents.bdi.text_belief_extractor_agent.httpx.AsyncClient", + return_value=mock_async_client, + ): + llm = TextBeliefExtractorAgent.LLM(MagicMock(), 4) + + res = await llm._query_llm("hello world", {"type": "null"}) + # Response content was set as "null", so should be deserialized as None + assert res is None + + +@pytest.mark.asyncio +async def test_retry_query_llm_success(llm): + llm._query_llm.return_value = None + res = await llm.query("hello world", {"type": "null"}) + + llm._query_llm.assert_called_once() + assert res is None + + +@pytest.mark.asyncio +async def test_retry_query_llm_success_after_failure(llm): + llm._query_llm.side_effect = [KeyError(), "real value"] + res = await llm.query("hello world", {"type": "string"}) + + assert llm._query_llm.call_count == 2 + assert res == "real value" + + +@pytest.mark.asyncio +async def test_retry_query_llm_failures(llm): + llm._query_llm.side_effect = [KeyError(), KeyError(), KeyError(), "real value"] + res = await llm.query("hello world", {"type": "string"}) + + assert llm._query_llm.call_count == 3 + assert res is None + + +@pytest.mark.asyncio +async def test_retry_query_llm_fail_immediately(llm): + llm._query_llm.side_effect = 
[KeyError(), "real value"] + res = await llm.query("hello world", {"type": "string"}, tries=1) + + assert llm._query_llm.call_count == 1 + assert res is None + + +@pytest.mark.asyncio +async def test_extracting_semantic_beliefs(agent): + """ + The Program Manager sends beliefs to this agent. Test whether the agent handles them correctly. + """ + assert len(agent.belief_inferrer.available_beliefs) == 0 + beliefs = BeliefList( + beliefs=[ + KeywordBelief( + id=uuid.uuid4(), + name="keyword_hello", + keyword="hello", + ), + SemanticBelief( + id=uuid.uuid4(), name="semantic_hello_1", description="Some semantic belief 1" + ), + SemanticBelief( + id=uuid.uuid4(), name="semantic_hello_2", description="Some semantic belief 2" + ), + ] + ) + await agent.handle_message( + InternalMessage( + to=settings.agent_settings.text_belief_extractor_name, + sender=settings.agent_settings.bdi_program_manager_name, + body=beliefs.model_dump_json(), + thread="beliefs", + ), + ) + assert len(agent.belief_inferrer.available_beliefs) == 2 + + +@pytest.mark.asyncio +async def test_handle_invalid_beliefs(agent, sample_program): + agent.belief_inferrer.available_beliefs.append(sample_program.phases[0].norms[0].condition) + agent.belief_inferrer.available_beliefs.append(sample_program.phases[0].triggers[0].condition) + assert len(agent.belief_inferrer.available_beliefs) == 2 + + await agent.handle_message( + InternalMessage( + to=settings.agent_settings.text_belief_extractor_name, + sender=settings.agent_settings.bdi_program_manager_name, + body=json.dumps({"phases": "Invalid"}), + thread="beliefs", + ), + ) + + assert len(agent.belief_inferrer.available_beliefs) == 2 + + +@pytest.mark.asyncio +async def test_handle_robot_response(agent): + initial_length = len(agent.conversation.messages) + response = "Hi, I'm Pepper. What's your name?" 
+ + await agent.handle_message( + InternalMessage( + to=settings.agent_settings.text_belief_extractor_name, + sender=settings.agent_settings.llm_name, + body=response, + ), + ) + + assert len(agent.conversation.messages) == initial_length + 1 + assert agent.conversation.messages[-1].role == "assistant" + assert agent.conversation.messages[-1].content == response + + +@pytest.mark.asyncio +async def test_simulated_real_turn_with_beliefs(agent, llm, sample_program): + """Test sending user message to extract beliefs from.""" + agent.belief_inferrer.available_beliefs.append(sample_program.phases[0].norms[0].condition) + agent.belief_inferrer.available_beliefs.append(sample_program.phases[0].triggers[0].condition) + + # Send a user message with the belief that there's no more booze + llm._query_llm.return_value = {"is_pirate": None, "no_more_booze": True} + assert len(agent.conversation.messages) == 0 + await agent.handle_message( + InternalMessage( + to=settings.agent_settings.text_belief_extractor_name, + sender=settings.agent_settings.transcription_name, + body="We're all out of schnaps.", + ), + ) + assert len(agent.conversation.messages) == 1 + + # There should be a belief set and sent to the BDI core, as well as the user_said belief + assert agent.send.call_count == 2 + + # The second send call should carry the inferred beliefs message + message: InternalMessage = agent.send.call_args_list[1].args[0] + beliefs = BeliefMessage.model_validate_json(message.body) + assert len(beliefs.create) == 1 + assert beliefs.create[0].name == "no_more_booze" + + +@pytest.mark.asyncio +async def test_simulated_real_turn_no_beliefs(agent, llm, sample_program): + """Test a user message to extract beliefs from, but no beliefs are formed.""" + agent.belief_inferrer.available_beliefs.append(sample_program.phases[0].norms[0].condition) + agent.belief_inferrer.available_beliefs.append(sample_program.phases[0].triggers[0].condition) + + # Send a user message with no new beliefs + llm._query_llm.return_value = 
{"is_pirate": None, "no_more_booze": None} + await agent.handle_message( + InternalMessage( + to=settings.agent_settings.text_belief_extractor_name, + sender=settings.agent_settings.transcription_name, + body="Hello there!", + ), + ) + + # Only the user_said belief should've been sent + agent.send.assert_called_once() + + +@pytest.mark.asyncio +async def test_simulated_real_turn_no_new_beliefs(agent, llm, sample_program): + """ + Test a user message to extract beliefs from, but no new beliefs are formed because they already + existed. + """ + agent.belief_inferrer.available_beliefs.append(sample_program.phases[0].norms[0].condition) + agent.belief_inferrer.available_beliefs.append(sample_program.phases[0].triggers[0].condition) + agent._current_beliefs = BeliefState(true={InternalBelief(name="is_pirate", arguments=None)}) + + # Send a user message with the belief the user is a pirate, still + llm._query_llm.return_value = {"is_pirate": True, "no_more_booze": None} + await agent.handle_message( + InternalMessage( + to=settings.agent_settings.text_belief_extractor_name, + sender=settings.agent_settings.transcription_name, + body="Arr, nice to meet you, matey.", + ), + ) + + # Only the user_said belief should've been sent, as no beliefs have changed + agent.send.assert_called_once() + + +@pytest.mark.asyncio +async def test_simulated_real_turn_remove_belief(agent, llm, sample_program): + """ + Test a user message to extract beliefs from, but an existing belief is determined no longer to + hold. 
+ """ + agent.belief_inferrer.available_beliefs.append(sample_program.phases[0].norms[0].condition) + agent.belief_inferrer.available_beliefs.append(sample_program.phases[0].triggers[0].condition) + agent._current_beliefs = BeliefState( + true={InternalBelief(name="no_more_booze", arguments=None)}, + ) + + # Send a user message implying that the no_more_booze belief no longer holds + llm._query_llm.return_value = {"is_pirate": None, "no_more_booze": False} + await agent.handle_message( + InternalMessage( + to=settings.agent_settings.text_belief_extractor_name, + sender=settings.agent_settings.transcription_name, + body="I found an untouched barrel of wine!", + ), + ) + + # Both user_said and belief change should've been sent + assert agent.send.call_count == 2 + + # Agent's current beliefs should've changed + assert any(b.name == "no_more_booze" for b in agent._current_beliefs.false) + + +@pytest.mark.asyncio +async def test_infer_goal_completions_sends_beliefs(agent, llm): + """Test that inferred goal completions are sent to the BDI core.""" + goal = BaseGoal( + id=uuid.uuid4(), name="Say Hello", description="The user said hello", can_fail=True + ) + agent.goal_inferrer.goals = {goal} + + # Mock goal inference: goal is achieved + llm.query = AsyncMock(return_value=True) + + await agent._infer_goal_completions() + + # Should send belief change to BDI core + agent.send.assert_awaited_once() + sent: InternalMessage = agent.send.call_args.args[0] + assert sent.to == settings.agent_settings.bdi_core_name + assert sent.thread == "beliefs" + + parsed = BeliefMessage.model_validate_json(sent.body) + assert len(parsed.create) == 1 + assert parsed.create[0].name == "achieved_say_hello" + + +@pytest.mark.asyncio +async def test_llm_failure_handling(agent, llm, sample_program): + """ + Check that the agent handles failures gracefully without crashing. 
+ """ + llm._query_llm.side_effect = httpx.HTTPError("") + agent.belief_inferrer.available_beliefs.append(sample_program.phases[0].norms[0].condition) + agent.belief_inferrer.available_beliefs.append(sample_program.phases[0].triggers[0].condition) + + belief_changes = await agent.belief_inferrer.infer_from_conversation( + ChatHistory( + messages=[ChatMessage(role="user", content="Good day!")], + ), + ) + + assert len(belief_changes.true) == 0 + assert len(belief_changes.false) == 0 + + +def test_belief_state_bool(): + # Empty + bs = BeliefState() + assert not bs + + # True set + bs_true = BeliefState(true={InternalBelief(name="a", arguments=None)}) + assert bs_true + + # False set + bs_false = BeliefState(false={InternalBelief(name="a", arguments=None)}) + assert bs_false + + +@pytest.mark.asyncio +async def test_handle_beliefs_message_validation_error(agent, mock_settings): + # Invalid JSON + mock_settings.agent_settings.bdi_program_manager_name = "bdi_program_manager_agent" + msg = InternalMessage( + to="me", + sender=mock_settings.agent_settings.bdi_program_manager_name, + thread="beliefs", + body="invalid json", + ) + # Should log warning and return + agent.logger = MagicMock() + await agent.handle_message(msg) + agent.logger.warning.assert_called() + + # Invalid Model + msg.body = json.dumps({"beliefs": [{"invalid": "obj"}]}) + await agent.handle_message(msg) + agent.logger.warning.assert_called() + + +@pytest.mark.asyncio +async def test_handle_goals_message_validation_error(agent, mock_settings): + mock_settings.agent_settings.bdi_program_manager_name = "bdi_program_manager_agent" + msg = InternalMessage( + to="me", + sender=mock_settings.agent_settings.bdi_program_manager_name, + thread="goals", + body="invalid json", + ) + agent.logger = MagicMock() + await agent.handle_message(msg) + agent.logger.warning.assert_called() + + +@pytest.mark.asyncio +async def test_handle_goal_achieved_message_validation_error(agent, mock_settings): + 
mock_settings.agent_settings.bdi_program_manager_name = "bdi_program_manager_agent" + msg = InternalMessage( + to="me", + sender=mock_settings.agent_settings.bdi_program_manager_name, + thread="achieved_goals", + body="invalid json", + ) + agent.logger = MagicMock() + await agent.handle_message(msg) + agent.logger.warning.assert_called() + + +@pytest.mark.asyncio +async def test_goal_inferrer_infer_from_conversation(agent, llm): + # Setup goals + # Use BaseGoal object as typically received by the extractor + g1 = BaseGoal(id=uuid.uuid4(), name="g1", description="desc", can_fail=True) + + # Use real GoalAchievementInferrer + from control_backend.agents.bdi.text_belief_extractor_agent import GoalAchievementInferrer + + inferrer = GoalAchievementInferrer(llm) + inferrer.goals = {g1} + + # Mock LLM response + llm._query_llm.return_value = True + + completions = await inferrer.infer_from_conversation(ChatHistory(messages=[])) + assert completions + # slugify uses slugify library, hard to predict exact string without it, + # but we can check values + assert list(completions.values())[0] is True + + +def test_apply_conversation_message_limit(agent): + with patch("control_backend.agents.bdi.text_belief_extractor_agent.settings") as mock_s: + mock_s.behaviour_settings.conversation_history_length_limit = 2 + agent.conversation.messages = [] + + agent._apply_conversation_message(ChatMessage(role="user", content="1")) + agent._apply_conversation_message(ChatMessage(role="assistant", content="2")) + agent._apply_conversation_message(ChatMessage(role="user", content="3")) + + assert len(agent.conversation.messages) == 2 + assert agent.conversation.messages[0].content == "2" + assert agent.conversation.messages[1].content == "3" + + +@pytest.mark.asyncio +async def test_handle_program_manager_reset(agent): + with patch("control_backend.agents.bdi.text_belief_extractor_agent.settings") as mock_s: + mock_s.agent_settings.bdi_program_manager_name = "pm" + agent.conversation.messages 
= [ChatMessage(role="user", content="hi")] + agent.belief_inferrer.available_beliefs = [ + SemanticBelief(id=uuid.uuid4(), name="b", description="d") + ] + + msg = InternalMessage(to="me", sender="pm", thread="conversation_history", body="reset") + await agent.handle_message(msg) + + assert len(agent.conversation.messages) == 0 + assert len(agent.belief_inferrer.available_beliefs) == 0 + + +def test_split_into_chunks(): + from control_backend.agents.bdi.text_belief_extractor_agent import SemanticBeliefInferrer + + items = [1, 2, 3, 4, 5] + chunks = SemanticBeliefInferrer._split_into_chunks(items, 2) + assert len(chunks) == 2 + assert len(chunks[0]) + len(chunks[1]) == 5 + + +@pytest.mark.asyncio +async def test_infer_beliefs_call(agent, llm): + from control_backend.agents.bdi.text_belief_extractor_agent import SemanticBeliefInferrer + + inferrer = SemanticBeliefInferrer(llm) + sb = SemanticBelief(id=uuid.uuid4(), name="is_happy", description="User is happy") + + llm.query = AsyncMock(return_value={"is_happy": True}) + + res = await inferrer._infer_beliefs(ChatHistory(messages=[]), [sb]) + assert res == {"is_happy": True} + llm.query.assert_called_once() + + +@pytest.mark.asyncio +async def test_infer_goal_call(agent, llm): + from control_backend.agents.bdi.text_belief_extractor_agent import GoalAchievementInferrer + + inferrer = GoalAchievementInferrer(llm) + goal = BaseGoal(id=uuid.uuid4(), name="g1", description="d") + + llm.query = AsyncMock(return_value=True) + + res = await inferrer._infer_goal(ChatHistory(messages=[]), goal) + assert res is True + llm.query.assert_called_once() diff --git a/test/unit/agents/bdi/test_text_extractor.py b/test/unit/agents/bdi/test_text_extractor.py deleted file mode 100644 index 895fef0..0000000 --- a/test/unit/agents/bdi/test_text_extractor.py +++ /dev/null @@ -1,65 +0,0 @@ -import json -from unittest.mock import AsyncMock - -import pytest - -from control_backend.agents.bdi import ( - TextBeliefExtractorAgent, -) -from 
control_backend.core.agent_system import InternalMessage - - -@pytest.fixture -def agent(): - agent = TextBeliefExtractorAgent("text_belief_agent") - agent.send = AsyncMock() - return agent - - -def make_msg(sender: str, body: str, thread: str | None = None) -> InternalMessage: - return InternalMessage(to="unused", sender=sender, body=body, thread=thread) - - -@pytest.mark.asyncio -async def test_handle_message_ignores_other_agents(agent): - msg = make_msg("unknown", "some data", None) - - await agent.handle_message(msg) - - agent.send.assert_not_called() # noqa # `agent.send` has no such property, but we mock it. - - -@pytest.mark.asyncio -async def test_handle_message_from_transcriber(agent, mock_settings): - transcription = "hello world" - msg = make_msg(mock_settings.agent_settings.transcription_name, transcription, None) - - await agent.handle_message(msg) - - agent.send.assert_awaited_once() # noqa # `agent.send` has no such property, but we mock it. - sent: InternalMessage = agent.send.call_args.args[0] # noqa - assert sent.to == mock_settings.agent_settings.bdi_belief_collector_name - assert sent.thread == "beliefs" - parsed = json.loads(sent.body) - assert parsed == {"beliefs": {"user_said": [transcription]}, "type": "belief_extraction_text"} - - -@pytest.mark.asyncio -async def test_process_transcription_demo(agent, mock_settings): - transcription = "this is a test" - - await agent._process_transcription_demo(transcription) - - agent.send.assert_awaited_once() # noqa # `agent.send` has no such property, but we mock it. 
- sent: InternalMessage = agent.send.call_args.args[0] # noqa - assert sent.to == mock_settings.agent_settings.bdi_belief_collector_name - assert sent.thread == "beliefs" - parsed = json.loads(sent.body) - assert parsed["beliefs"]["user_said"] == [transcription] - - -@pytest.mark.asyncio -async def test_setup_initializes_beliefs(agent): - """Covers the setup method and ensures beliefs are initialized.""" - await agent.setup() - assert agent.beliefs == {"mood": ["X"], "car": ["Y"]} diff --git a/test/unit/agents/communication/test_ri_communication_agent.py b/test/unit/agents/communication/test_ri_communication_agent.py index 06d8766..a678907 100644 --- a/test/unit/agents/communication/test_ri_communication_agent.py +++ b/test/unit/agents/communication/test_ri_communication_agent.py @@ -4,6 +4,8 @@ from unittest.mock import ANY, AsyncMock, MagicMock, patch import pytest from control_backend.agents.communication.ri_communication_agent import RICommunicationAgent +from control_backend.core.agent_system import InternalMessage +from control_backend.schemas.ri_message import PauseCommand, RIEndpoint def speech_agent_path(): @@ -53,7 +55,11 @@ async def test_setup_success_connects_and_starts_robot(zmq_context): MockGesture.return_value.start = AsyncMock() agent = RICommunicationAgent("ri_comm", address="tcp://localhost:5555", bind=False) - agent.add_behavior = MagicMock() + def close_coro(coro): + coro.close() + return MagicMock() + + agent.add_behavior = MagicMock(side_effect=close_coro) await agent.setup() @@ -83,7 +89,11 @@ async def test_setup_binds_when_requested(zmq_context): agent = RICommunicationAgent("ri_comm", address="tcp://localhost:5555", bind=True) - agent.add_behavior = MagicMock() + def close_coro(coro): + coro.close() + return MagicMock() + + agent.add_behavior = MagicMock(side_effect=close_coro) with ( patch(speech_agent_path(), autospec=True) as MockSpeech, @@ -151,6 +161,7 @@ async def test_handle_negotiation_response_updates_req_socket(zmq_context): 
@pytest.mark.asyncio async def test_handle_disconnection_publishes_and_reconnects(): pub_socket = AsyncMock() + pub_socket.close = MagicMock() agent = RICommunicationAgent("ri_comm") agent.pub_socket = pub_socket agent.connected = True @@ -233,6 +244,25 @@ async def test_handle_negotiation_response_unhandled_id(): ) +@pytest.mark.asyncio +async def test_handle_negotiation_response_audio(zmq_context): + agent = RICommunicationAgent("ri_comm") + + with patch( + "control_backend.agents.communication.ri_communication_agent.VADAgent", autospec=True + ) as MockVAD: + MockVAD.return_value.start = AsyncMock() + + await agent._handle_negotiation_response( + {"data": [{"id": "audio", "port": 7000, "bind": False}]} + ) + + MockVAD.assert_called_once_with( + audio_in_address="tcp://localhost:7000", audio_in_bind=False + ) + MockVAD.return_value.start.assert_awaited_once() + + @pytest.mark.asyncio async def test_stop_closes_sockets(): req = MagicMock() @@ -323,6 +353,7 @@ async def test_listen_loop_generic_exception(): @pytest.mark.asyncio async def test_handle_disconnection_timeout(monkeypatch): pub = AsyncMock() + pub.close = MagicMock() pub.send_multipart = AsyncMock(side_effect=TimeoutError) agent = RICommunicationAgent("ri_comm") @@ -365,3 +396,38 @@ async def test_negotiate_req_socket_none_causes_retry(zmq_context): result = await agent._negotiate_connection(max_retries=1) assert result is False + + +@pytest.mark.asyncio +async def test_handle_message_pause_command(zmq_context): + """Test handle_message with a valid PauseCommand.""" + agent = RICommunicationAgent("ri_comm") + agent._req_socket = AsyncMock() + agent.logger = MagicMock() + + agent._req_socket.recv_json.return_value = {"status": "ok"} + + pause_cmd = PauseCommand(data=True) + msg = InternalMessage(to="ri_comm", sender="user_int", body=pause_cmd.model_dump_json()) + + await agent.handle_message(msg) + + agent._req_socket.send_json.assert_awaited_once() + args = agent._req_socket.send_json.await_args[0][0] + 
assert args["endpoint"] == RIEndpoint.PAUSE.value + assert args["data"] is True + + +@pytest.mark.asyncio +async def test_handle_message_invalid_pause_command(zmq_context): + """Test handle_message with invalid JSON.""" + agent = RICommunicationAgent("ri_comm") + agent._req_socket = AsyncMock() + agent.logger = MagicMock() + + msg = InternalMessage(to="ri_comm", sender="user_int", body="invalid json") + + await agent.handle_message(msg) + + agent.logger.warning.assert_called_with("Incorrect message format for PauseCommand.") + agent._req_socket.send_json.assert_not_called() diff --git a/test/unit/agents/llm/test_llm_agent.py b/test/unit/agents/llm/test_llm_agent.py index ef8a3bf..a1cc297 100644 --- a/test/unit/agents/llm/test_llm_agent.py +++ b/test/unit/agents/llm/test_llm_agent.py @@ -58,17 +58,20 @@ async def test_llm_processing_success(mock_httpx_client, mock_settings): to="llm_agent", sender=mock_settings.agent_settings.bdi_core_name, body=prompt.model_dump_json(), + thread="prompt_message", # REQUIRED: thread must match handle_message logic ) await agent.handle_message(msg) # Verification # "Hello world." constitutes one sentence/chunk based on punctuation split - # The agent should call send once with the full sentence + # The agent should call send once with the full sentence, PLUS once more for full reply assert agent.send.called - args = agent.send.call_args[0][0] - assert args.to == mock_settings.agent_settings.bdi_core_name - assert "Hello world." in args.body + + # Check args. We expect at least one call sending "Hello world." + calls = agent.send.call_args_list + bodies = [c[0][0].body for c in calls] + assert any("Hello world." 
in b for b in bodies) @pytest.mark.asyncio @@ -80,18 +83,23 @@ async def test_llm_processing_errors(mock_httpx_client, mock_settings): to="llm", sender=mock_settings.agent_settings.bdi_core_name, body=prompt.model_dump_json(), + thread="prompt_message", ) - # HTTP Error + # HTTP Error: stream method RAISES exception immediately mock_httpx_client.stream = MagicMock(side_effect=httpx.HTTPError("Fail")) + await agent.handle_message(msg) - assert "LLM service unavailable." in agent.send.call_args[0][0].body + + # Check that error message was sent + assert agent.send.called + assert "LLM service unavailable." in agent.send.call_args_list[0][0][0].body # General Exception agent.send.reset_mock() mock_httpx_client.stream = MagicMock(side_effect=Exception("Boom")) await agent.handle_message(msg) - assert "Error processing the request." in agent.send.call_args[0][0].body + assert "Error processing the request." in agent.send.call_args_list[0][0][0].body @pytest.mark.asyncio @@ -113,16 +121,19 @@ async def test_llm_json_error(mock_httpx_client, mock_settings): agent = LLMAgent("llm_agent") agent.send = AsyncMock() + # Ensure logger is mocked + agent.logger = MagicMock() - with patch.object(agent.logger, "error") as log: - prompt = LLMPromptMessage(text="Hi", norms=[], goals=[]) - msg = InternalMessage( - to="llm", - sender=mock_settings.agent_settings.bdi_core_name, - body=prompt.model_dump_json(), - ) - await agent.handle_message(msg) - log.assert_called() # Should log JSONDecodeError + prompt = LLMPromptMessage(text="Hi", norms=[], goals=[]) + msg = InternalMessage( + to="llm", + sender=mock_settings.agent_settings.bdi_core_name, + body=prompt.model_dump_json(), + thread="prompt_message", + ) + await agent.handle_message(msg) + + agent.logger.error.assert_called() # Should log JSONDecodeError def test_llm_instructions(): @@ -157,6 +168,7 @@ async def test_handle_message_validation_error_branch_no_send(mock_httpx_client, to="llm_agent", 
sender=mock_settings.agent_settings.bdi_core_name, body=invalid_json, + thread="prompt_message", ) await agent.handle_message(msg) @@ -285,3 +297,28 @@ async def test_clear_history_command(mock_settings): ) await agent.handle_message(msg) assert len(agent.history) == 0 + + +@pytest.mark.asyncio +async def test_handle_assistant_and_user_messages(mock_settings): + agent = LLMAgent("llm_agent") + + # Assistant message + msg_ast = InternalMessage( + to="llm_agent", + sender=mock_settings.agent_settings.bdi_core_name, + thread="assistant_message", + body="I said this", + ) + await agent.handle_message(msg_ast) + assert agent.history[-1] == {"role": "assistant", "content": "I said this"} + + # User message + msg_usr = InternalMessage( + to="llm_agent", + sender=mock_settings.agent_settings.bdi_core_name, + thread="user_message", + body="User said this", + ) + await agent.handle_message(msg_usr) + assert agent.history[-1] == {"role": "user", "content": "User said this"} diff --git a/test/unit/agents/perception/transcription_agent/test_speech_recognizer.py b/test/unit/agents/perception/transcription_agent/test_speech_recognizer.py index 47443a9..518d189 100644 --- a/test/unit/agents/perception/transcription_agent/test_speech_recognizer.py +++ b/test/unit/agents/perception/transcription_agent/test_speech_recognizer.py @@ -55,4 +55,6 @@ def test_get_decode_options(): assert isinstance(options["sample_len"], int) # When disabled, it should not limit output length based on input size - assert "sample_rate" not in options + recognizer = OpenAIWhisperSpeechRecognizer(limit_output_length=False) + options = recognizer._get_decode_options(audio) + assert "sample_len" not in options diff --git a/test/unit/agents/perception/transcription_agent/test_transcription_agent.py b/test/unit/agents/perception/transcription_agent/test_transcription_agent.py index ccdaa7f..57875ca 100644 --- a/test/unit/agents/perception/transcription_agent/test_transcription_agent.py +++ 
b/test/unit/agents/perception/transcription_agent/test_transcription_agent.py @@ -36,7 +36,12 @@ async def test_transcription_agent_flow(mock_zmq_context): agent.send = AsyncMock() agent._running = True - agent.add_behavior = AsyncMock() + + def close_coro(coro): + coro.close() + return MagicMock() + + agent.add_behavior = MagicMock(side_effect=close_coro) await agent.setup() @@ -143,7 +148,12 @@ async def test_transcription_loop_continues_after_error(mock_zmq_context): agent = TranscriptionAgent("tcp://in") agent._running = True # ← REQUIRED to enter the loop agent.send = AsyncMock() # should never be called - agent.add_behavior = AsyncMock() # match other tests + + def close_coro(coro): + coro.close() + return MagicMock() + + agent.add_behavior = MagicMock(side_effect=close_coro) # match other tests await agent.setup() @@ -180,7 +190,12 @@ async def test_transcription_continue_branch_when_empty(mock_zmq_context): # Make loop runnable agent._running = True agent.send = AsyncMock() - agent.add_behavior = AsyncMock() + + def close_coro(coro): + coro.close() + return MagicMock() + + agent.add_behavior = MagicMock(side_effect=close_coro) await agent.setup() diff --git a/test/unit/agents/perception/vad_agent/test_vad_agent_unit.py b/test/unit/agents/perception/vad_agent/test_vad_agent_unit.py new file mode 100644 index 0000000..3e6b0ad --- /dev/null +++ b/test/unit/agents/perception/vad_agent/test_vad_agent_unit.py @@ -0,0 +1,152 @@ +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from control_backend.agents.perception.vad_agent import VADAgent +from control_backend.core.agent_system import InternalMessage +from control_backend.schemas.program_status import PROGRAM_STATUS, ProgramStatus + + +@pytest.fixture(autouse=True) +def mock_zmq(): + with patch("zmq.asyncio.Context") as mock: + mock.instance.return_value = MagicMock() + yield mock + + +@pytest.fixture +def agent(): + return VADAgent("tcp://localhost:5555", False) + + 
+@pytest.mark.asyncio +async def test_handle_message_pause(agent): + agent._paused = MagicMock() + # It starts set (not paused) + + msg = InternalMessage(to="vad", sender="user_interrupt_agent", body="PAUSE") + + # We need to mock settings to match sender name + with patch("control_backend.agents.perception.vad_agent.settings") as mock_settings: + mock_settings.agent_settings.user_interrupt_name = "user_interrupt_agent" + + await agent.handle_message(msg) + + agent._paused.clear.assert_called_once() + assert agent._reset_needed is True + + +@pytest.mark.asyncio +async def test_handle_message_resume(agent): + agent._paused = MagicMock() + msg = InternalMessage(to="vad", sender="user_interrupt_agent", body="RESUME") + + with patch("control_backend.agents.perception.vad_agent.settings") as mock_settings: + mock_settings.agent_settings.user_interrupt_name = "user_interrupt_agent" + + await agent.handle_message(msg) + + agent._paused.set.assert_called_once() + + +@pytest.mark.asyncio +async def test_handle_message_unknown_command(agent): + agent._paused = MagicMock() + msg = InternalMessage(to="vad", sender="user_interrupt_agent", body="UNKNOWN") + + with patch("control_backend.agents.perception.vad_agent.settings") as mock_settings: + mock_settings.agent_settings.user_interrupt_name = "user_interrupt_agent" + agent.logger = MagicMock() + + await agent.handle_message(msg) + + agent._paused.clear.assert_not_called() + agent._paused.set.assert_not_called() + + +@pytest.mark.asyncio +async def test_handle_message_unknown_sender(agent): + agent._paused = MagicMock() + msg = InternalMessage(to="vad", sender="other_agent", body="PAUSE") + + with patch("control_backend.agents.perception.vad_agent.settings") as mock_settings: + mock_settings.agent_settings.user_interrupt_name = "user_interrupt_agent" + + await agent.handle_message(msg) + + agent._paused.clear.assert_not_called() + + +@pytest.mark.asyncio +async def test_status_loop_waits_for_running(agent): + agent._running = 
True + agent.program_sub_socket = AsyncMock() + agent.program_sub_socket.close = MagicMock() + agent._reset_stream = AsyncMock() + + # Sequence of messages: + # 1. Wrong topic + # 2. Right topic, wrong status (STARTING) + # 3. Right topic, RUNNING -> Should break loop + + agent.program_sub_socket.recv_multipart.side_effect = [ + (b"wrong_topic", b"whatever"), + (PROGRAM_STATUS, ProgramStatus.STARTING.value), + (PROGRAM_STATUS, ProgramStatus.RUNNING.value), + ] + + await agent._status_loop() + + assert agent._reset_stream.await_count == 1 + agent.program_sub_socket.close.assert_called_once() + + +@pytest.mark.asyncio +async def test_setup_success(agent, mock_zmq): + def close_coro(coro): + coro.close() + return MagicMock() + + agent.add_behavior = MagicMock(side_effect=close_coro) + + mock_context = mock_zmq.instance.return_value + mock_sub = MagicMock() + mock_pub = MagicMock() + + # We expect multiple socket calls: + # 1. audio_in (SUB) + # 2. audio_out (PUB) + # 3. program_sub (SUB) + mock_context.socket.side_effect = [mock_sub, mock_pub, mock_sub] + + with patch("control_backend.agents.perception.vad_agent.torch.hub.load") as mock_load: + mock_load.return_value = (MagicMock(), None) + + with patch("control_backend.agents.perception.vad_agent.TranscriptionAgent") as MockTrans: + mock_trans_instance = MockTrans.return_value + mock_trans_instance.start = AsyncMock() + + await agent.setup() + + mock_trans_instance.start.assert_awaited_once() + + assert agent.add_behavior.call_count == 2 # streaming_loop + status_loop + assert agent.audio_in_socket is not None + assert agent.audio_out_socket is not None + assert agent.program_sub_socket is not None + + +@pytest.mark.asyncio +async def test_reset_stream(agent): + mock_poller = MagicMock() + agent.audio_in_poller = mock_poller + + # poll(1) returns not None twice, then None + mock_poller.poll = AsyncMock(side_effect=[b"data", b"data", None]) + + agent._ready = MagicMock() + + await agent._reset_stream() + + assert 
mock_poller.poll.await_count == 3 + agent._ready.set.assert_called_once() diff --git a/test/unit/agents/perception/vad_agent/test_vad_streaming.py b/test/unit/agents/perception/vad_agent/test_vad_streaming.py index 166919f..349fab2 100644 --- a/test/unit/agents/perception/vad_agent/test_vad_streaming.py +++ b/test/unit/agents/perception/vad_agent/test_vad_streaming.py @@ -5,6 +5,7 @@ import pytest import zmq from control_backend.agents.perception.vad_agent import VADAgent +from control_backend.core.config import settings # We don't want to use real ZMQ in unit tests, for example because it can give errors when sockets @@ -135,6 +136,54 @@ async def test_no_data(audio_out_socket, vad_agent): assert len(vad_agent.audio_buffer) == 0 +@pytest.mark.asyncio +async def test_streaming_loop_reset_needed(audio_out_socket, vad_agent): + """Test that _reset_needed branch works as expected.""" + vad_agent._reset_needed = True + vad_agent._ready.set() + vad_agent._paused.set() + vad_agent._running = True + vad_agent.audio_buffer = np.array([1.0], dtype=np.float32) + vad_agent.i_since_speech = 0 + + # Mock _reset_stream to stop the loop by setting _running=False + async def mock_reset(): + vad_agent._running = False + + vad_agent._reset_stream = mock_reset + + # Needs a poller to avoid AssertionError + vad_agent.audio_in_poller = AsyncMock() + vad_agent.audio_in_poller.poll.return_value = None + + await vad_agent._streaming_loop() + + assert vad_agent._reset_needed is False + assert len(vad_agent.audio_buffer) == 0 + assert vad_agent.i_since_speech == settings.behaviour_settings.vad_initial_since_speech + + +@pytest.mark.asyncio +async def test_streaming_loop_no_data_clears_buffer(audio_out_socket, vad_agent): + """Test that if poll returns None, buffer is cleared if not empty.""" + vad_agent.audio_buffer = np.array([1.0], dtype=np.float32) + vad_agent._ready.set() + vad_agent._paused.set() + vad_agent._running = True + + class MockPoller: + async def poll(self, timeout_ms=None): 
+ vad_agent._running = False # stop after one poll + return None + + vad_agent.audio_in_poller = MockPoller() + + await vad_agent._streaming_loop() + + assert len(vad_agent.audio_buffer) == 0 + assert vad_agent.i_since_speech == settings.behaviour_settings.vad_initial_since_speech + + @pytest.mark.asyncio async def test_vad_model_load_failure_stops_agent(vad_agent): """ diff --git a/test/unit/agents/test_base.py b/test/unit/agents/test_base.py new file mode 100644 index 0000000..0579ada --- /dev/null +++ b/test/unit/agents/test_base.py @@ -0,0 +1,24 @@ +import logging + +from control_backend.agents.base import BaseAgent + + +class MyAgent(BaseAgent): + async def setup(self): + pass + + async def handle_message(self, msg): + pass + + +def test_base_agent_logger_init(): + # When defining a subclass, __init_subclass__ runs + # The BaseAgent in agents/base.py sets the logger + assert hasattr(MyAgent, "logger") + assert isinstance(MyAgent.logger, logging.Logger) + # The logger name depends on the package. + # Since this test file is running as a module, __package__ might be None or the test package. + # In 'src/control_backend/agents/base.py', it uses __package__ of base.py which is + # 'control_backend.agents'. 
+ # So logger name should be control_backend.agents.MyAgent + assert MyAgent.logger.name == "control_backend.agents.MyAgent" diff --git a/test/unit/agents/user_interrupt/test_user_interrupt.py b/test/unit/agents/user_interrupt/test_user_interrupt.py index 7e3e700..7a71891 100644 --- a/test/unit/agents/user_interrupt/test_user_interrupt.py +++ b/test/unit/agents/user_interrupt/test_user_interrupt.py @@ -7,6 +7,15 @@ import pytest from control_backend.agents.user_interrupt.user_interrupt_agent import UserInterruptAgent from control_backend.core.agent_system import InternalMessage from control_backend.core.config import settings +from control_backend.schemas.program import ( + ConditionalNorm, + Goal, + KeywordBelief, + Phase, + Plan, + Program, + Trigger, +) from control_backend.schemas.ri_message import RIEndpoint @@ -16,6 +25,7 @@ def agent(): agent.send = AsyncMock() agent.logger = MagicMock() agent.sub_socket = AsyncMock() + agent.pub_socket = AsyncMock() return agent @@ -49,21 +59,18 @@ async def test_send_to_gesture_agent(agent): @pytest.mark.asyncio -async def test_send_to_program_manager(agent): +async def test_send_to_bdi_belief(agent): """Verify belief update format.""" - context_str = "2" + context_str = "some_goal" - await agent._send_to_program_manager(context_str) + await agent._send_to_bdi_belief(context_str, "goal") - agent.send.assert_awaited_once() - sent_msg: InternalMessage = agent.send.call_args.args[0] + assert agent.send.await_count == 1 + sent_msg = agent.send.call_args.args[0] - assert sent_msg.to == settings.agent_settings.bdi_program_manager_name - assert sent_msg.thread == "belief_override_id" - - body = json.loads(sent_msg.body) - - assert body["belief"] == context_str + assert sent_msg.to == settings.agent_settings.bdi_core_name + assert sent_msg.thread == "beliefs" + assert "achieved_some_goal" in sent_msg.body @pytest.mark.asyncio @@ -77,6 +84,10 @@ async def test_receive_loop_routing_success(agent): # Prepare JSON payloads as bytes 
payload_speech = json.dumps({"type": "speech", "context": "Hello Speech"}).encode() payload_gesture = json.dumps({"type": "gesture", "context": "Hello Gesture"}).encode() + # override calls _send_to_bdi (for trigger/norm) OR _send_to_bdi_belief (for goal). + + # To test routing, we need to populate the maps + agent._goal_map["Hello Override"] = "some_goal_slug" payload_override = json.dumps({"type": "override", "context": "Hello Override"}).encode() agent.sub_socket.recv_multipart.side_effect = [ @@ -88,7 +99,7 @@ async def test_receive_loop_routing_success(agent): agent._send_to_speech_agent = AsyncMock() agent._send_to_gesture_agent = AsyncMock() - agent._send_to_program_manager = AsyncMock() + agent._send_to_bdi_belief = AsyncMock() try: await agent._receive_button_event() @@ -103,12 +114,12 @@ async def test_receive_loop_routing_success(agent): # Gesture agent._send_to_gesture_agent.assert_awaited_once_with("Hello Gesture") - # Override - agent._send_to_program_manager.assert_awaited_once_with("Hello Override") + # Override (since we mapped it to a goal) + agent._send_to_bdi_belief.assert_awaited_once_with("some_goal_slug", "goal") assert agent._send_to_speech_agent.await_count == 1 assert agent._send_to_gesture_agent.await_count == 1 - assert agent._send_to_program_manager.await_count == 1 + assert agent._send_to_bdi_belief.await_count == 1 @pytest.mark.asyncio @@ -125,7 +136,6 @@ async def test_receive_loop_unknown_type(agent): agent._send_to_speech_agent = AsyncMock() agent._send_to_gesture_agent = AsyncMock() - agent._send_to_belief_collector = AsyncMock() try: await agent._receive_button_event() @@ -137,10 +147,165 @@ async def test_receive_loop_unknown_type(agent): # Ensure no handlers were called agent._send_to_speech_agent.assert_not_called() agent._send_to_gesture_agent.assert_not_called() - agent._send_to_belief_collector.assert_not_called() - agent.logger.warning.assert_called_with( - "Received button press with unknown type '%s' (context: '%s').", - 
"unknown_thing", - "some_data", - ) + agent.logger.warning.assert_called() + + +@pytest.mark.asyncio +async def test_create_mapping(agent): + # Create a program with a trigger, goal, and conditional norm + import uuid + + trigger_id = uuid.uuid4() + goal_id = uuid.uuid4() + norm_id = uuid.uuid4() + + cond = KeywordBelief(id=uuid.uuid4(), name="k1", keyword="key") + plan = Plan(id=uuid.uuid4(), name="p1", steps=[]) + + trigger = Trigger(id=trigger_id, name="my_trigger", condition=cond, plan=plan) + goal = Goal(id=goal_id, name="my_goal", description="desc", plan=plan) + + cn = ConditionalNorm(id=norm_id, name="my_norm", norm="be polite", condition=cond) + + phase = Phase(id=uuid.uuid4(), name="phase1", norms=[cn], goals=[goal], triggers=[trigger]) + prog = Program(phases=[phase]) + + # Call create_mapping via handle_message + msg = InternalMessage(to="me", thread="new_program", body=prog.model_dump_json()) + await agent.handle_message(msg) + + # Check maps + assert str(trigger_id) in agent._trigger_map + assert agent._trigger_map[str(trigger_id)] == "trigger_my_trigger" + + assert str(goal_id) in agent._goal_map + assert agent._goal_map[str(goal_id)] == "my_goal" + + assert str(norm_id) in agent._cond_norm_map + assert agent._cond_norm_map[str(norm_id)] == "norm_be_polite" + + +@pytest.mark.asyncio +async def test_create_mapping_invalid_json(agent): + # Pass invalid json to handle_message thread "new_program" + msg = InternalMessage(to="me", thread="new_program", body="invalid json") + await agent.handle_message(msg) + + # Should log error and maps should remain empty or cleared + agent.logger.error.assert_called() + + +@pytest.mark.asyncio +async def test_handle_message_trigger_start(agent): + # Setup reverse map manually + agent._trigger_reverse_map["trigger_slug"] = "ui_id_123" + + msg = InternalMessage(to="me", thread="trigger_start", body="trigger_slug") + await agent.handle_message(msg) + + agent.pub_socket.send_multipart.assert_awaited_once() + args = 
agent.pub_socket.send_multipart.call_args[0][0] + assert args[0] == b"experiment" + payload = json.loads(args[1]) + assert payload["type"] == "trigger_update" + assert payload["id"] == "ui_id_123" + assert payload["achieved"] is True + + +@pytest.mark.asyncio +async def test_handle_message_trigger_end(agent): + agent._trigger_reverse_map["trigger_slug"] = "ui_id_123" + + msg = InternalMessage(to="me", thread="trigger_end", body="trigger_slug") + await agent.handle_message(msg) + + agent.pub_socket.send_multipart.assert_awaited_once() + payload = json.loads(agent.pub_socket.send_multipart.call_args[0][0][1]) + assert payload["type"] == "trigger_update" + assert payload["achieved"] is False + + +@pytest.mark.asyncio +async def test_handle_message_transition_phase(agent): + msg = InternalMessage(to="me", thread="transition_phase", body="phase_id_123") + await agent.handle_message(msg) + + agent.pub_socket.send_multipart.assert_awaited_once() + payload = json.loads(agent.pub_socket.send_multipart.call_args[0][0][1]) + assert payload["type"] == "phase_update" + assert payload["id"] == "phase_id_123" + + +@pytest.mark.asyncio +async def test_handle_message_goal_start(agent): + agent._goal_reverse_map["goal_slug"] = "goal_id_123" + + msg = InternalMessage(to="me", thread="goal_start", body="goal_slug") + await agent.handle_message(msg) + + agent.pub_socket.send_multipart.assert_awaited_once() + payload = json.loads(agent.pub_socket.send_multipart.call_args[0][0][1]) + assert payload["type"] == "goal_update" + assert payload["id"] == "goal_id_123" + assert payload["active"] is True + + +@pytest.mark.asyncio +async def test_handle_message_active_norms_update(agent): + agent._cond_norm_reverse_map["norm_active"] = "id_1" + agent._cond_norm_reverse_map["norm_inactive"] = "id_2" + + # Body is like: "('norm_active', 'other')" + # The split logic handles quotes etc. 
+ msg = InternalMessage(to="me", thread="active_norms_update", body="'norm_active', 'other'") + await agent.handle_message(msg) + + agent.pub_socket.send_multipart.assert_awaited_once() + payload = json.loads(agent.pub_socket.send_multipart.call_args[0][0][1]) + assert payload["type"] == "cond_norms_state_update" + norms = {n["id"]: n["active"] for n in payload["norms"]} + assert norms["id_1"] is True + assert norms["id_2"] is False + + +@pytest.mark.asyncio +async def test_send_experiment_control(agent): + # Test next_phase + await agent._send_experiment_control_to_bdi_core("next_phase") + agent.send.assert_awaited() + msg = agent.send.call_args[0][0] + assert msg.thread == "force_next_phase" + + # Test reset_phase + await agent._send_experiment_control_to_bdi_core("reset_phase") + msg = agent.send.call_args[0][0] + assert msg.thread == "reset_current_phase" + + # Test reset_experiment + await agent._send_experiment_control_to_bdi_core("reset_experiment") + msg = agent.send.call_args[0][0] + assert msg.thread == "reset_experiment" + + +@pytest.mark.asyncio +async def test_send_pause_command(agent): + await agent._send_pause_command("true") + # Sends to RI and VAD + assert agent.send.await_count == 2 + msgs = [call.args[0] for call in agent.send.call_args_list] + + ri_msg = next(m for m in msgs if m.to == settings.agent_settings.ri_communication_name) + assert json.loads(ri_msg.body)["endpoint"] == "" # PAUSE endpoint + assert json.loads(ri_msg.body)["data"] is True + + vad_msg = next(m for m in msgs if m.to == settings.agent_settings.vad_name) + assert vad_msg.body == "PAUSE" + + agent.send.reset_mock() + await agent._send_pause_command("false") + assert agent.send.await_count == 2 + vad_msg = next( + m for m in agent.send.call_args_list if m.args[0].to == settings.agent_settings.vad_name + ).args[0] + assert vad_msg.body == "RESUME" diff --git a/test/unit/api/v1/endpoints/test_program_endpoint.py b/test/unit/api/v1/endpoints/test_program_endpoint.py index 
178159c..c1a3fd9 100644 --- a/test/unit/api/v1/endpoints/test_program_endpoint.py +++ b/test/unit/api/v1/endpoints/test_program_endpoint.py @@ -1,4 +1,5 @@ import json +import uuid from unittest.mock import AsyncMock import pytest @@ -6,7 +7,7 @@ from fastapi import FastAPI from fastapi.testclient import TestClient from control_backend.api.v1.endpoints import program -from control_backend.schemas.program import Program +from control_backend.schemas.program import BasicNorm, Goal, Phase, Plan, Program @pytest.fixture @@ -25,29 +26,39 @@ def client(app): def make_valid_program_dict(): """Helper to create a valid Program JSON structure.""" - return { - "phases": [ - { - "id": "phase1", - "label": "basephase", - "norms": [{"id": "n1", "label": "norm", "norm": "be nice"}], - "goals": [ - {"id": "g1", "label": "goal", "description": "test goal", "achieved": False} + # Converting to JSON using Pydantic because it knows how to convert a UUID object + program_json_str = Program( + phases=[ + Phase( + id=uuid.uuid4(), + name="Basic Phase", + norms=[ + BasicNorm( + id=uuid.uuid4(), + name="Some norm", + norm="Do normal.", + ), ], - "triggers": [ - { - "id": "t1", - "label": "trigger", - "type": "keywords", - "keywords": [ - {"id": "kw1", "keyword": "keyword1"}, - {"id": "kw2", "keyword": "keyword2"}, - ], - }, + goals=[ + Goal( + id=uuid.uuid4(), + name="Some goal", + description="This description can be used to determine whether the goal " + "has been achieved.", + plan=Plan( + id=uuid.uuid4(), + name="Goal Plan", + steps=[], + ), + can_fail=False, + ), ], - } - ] - } + triggers=[], + ), + ], + ).model_dump_json() + # Converting back to a dict because that's what's expected + return json.loads(program_json_str) def test_receive_program_success(client): @@ -71,7 +82,8 @@ def test_receive_program_success(client): sent_bytes = args[0][1] sent_obj = json.loads(sent_bytes.decode()) - expected_obj = Program.model_validate(program_dict).model_dump() + # Converting to JSON using 
Pydantic because it knows how to handle UUIDs + expected_obj = json.loads(Program.model_validate(program_dict).model_dump_json()) assert sent_obj == expected_obj diff --git a/test/unit/api/v1/endpoints/test_router.py b/test/unit/api/v1/endpoints/test_router.py index 7303d9c..dd93d8d 100644 --- a/test/unit/api/v1/endpoints/test_router.py +++ b/test/unit/api/v1/endpoints/test_router.py @@ -11,6 +11,5 @@ def test_router_includes_expected_paths(): # Ensure at least one route under each prefix exists assert any(p.startswith("/robot") for p in paths) assert any(p.startswith("/message") for p in paths) - assert any(p.startswith("/sse") for p in paths) assert any(p.startswith("/logs") for p in paths) assert any(p.startswith("/program") for p in paths) diff --git a/test/unit/api/v1/endpoints/test_sse_endpoint.py b/test/unit/api/v1/endpoints/test_sse_endpoint.py deleted file mode 100644 index 75a4555..0000000 --- a/test/unit/api/v1/endpoints/test_sse_endpoint.py +++ /dev/null @@ -1,24 +0,0 @@ -import pytest -from fastapi import FastAPI -from fastapi.testclient import TestClient - -from control_backend.api.v1.endpoints import sse - - -@pytest.fixture -def app(): - app = FastAPI() - app.include_router(sse.router) - return app - - -@pytest.fixture -def client(app): - return TestClient(app) - - -def test_sse_route_exists(client): - """Minimal smoke test to ensure /sse route exists and responds.""" - response = client.get("/sse") - # Since implementation is not done, we only assert it doesn't crash - assert response.status_code == 200 diff --git a/test/unit/api/v1/endpoints/test_user_interact.py b/test/unit/api/v1/endpoints/test_user_interact.py new file mode 100644 index 0000000..ddb9932 --- /dev/null +++ b/test/unit/api/v1/endpoints/test_user_interact.py @@ -0,0 +1,96 @@ +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from fastapi import FastAPI +from fastapi.testclient import TestClient + +from control_backend.api.v1.endpoints import user_interact + + 
+@pytest.fixture +def app(): + app = FastAPI() + app.include_router(user_interact.router) + return app + + +@pytest.fixture +def client(app): + return TestClient(app) + + +@pytest.mark.asyncio +async def test_receive_button_event(client): + mock_pub_socket = AsyncMock() + client.app.state.endpoints_pub_socket = mock_pub_socket + + payload = {"type": "speech", "context": "hello"} + response = client.post("/button_pressed", json=payload) + + assert response.status_code == 202 + assert response.json() == {"status": "Event received"} + + mock_pub_socket.send_multipart.assert_awaited_once() + args = mock_pub_socket.send_multipart.call_args[0][0] + assert args[0] == b"button_pressed" + assert "speech" in args[1].decode() + + +@pytest.mark.asyncio +async def test_receive_button_event_invalid_payload(client): + mock_pub_socket = AsyncMock() + client.app.state.endpoints_pub_socket = mock_pub_socket + + # Missing context + payload = {"type": "speech"} + response = client.post("/button_pressed", json=payload) + + assert response.status_code == 422 + mock_pub_socket.send_multipart.assert_not_called() + + +@pytest.mark.asyncio +async def test_experiment_stream_direct_call(): + """ + Directly calling the endpoint function to test the streaming logic + without dealing with TestClient streaming limitations. + """ + mock_socket = AsyncMock() + # 1. recv data + # 2. recv timeout + # 3. disconnect (request.is_disconnected returns True) + mock_socket.recv_multipart.side_effect = [ + (b"topic", b"message1"), + TimeoutError(), + (b"topic", b"message2"), # Should not be reached if disconnect checks work + ] + mock_socket.close = MagicMock() + mock_socket.connect = MagicMock() + mock_socket.subscribe = MagicMock() + + mock_context = MagicMock() + mock_context.socket.return_value = mock_socket + + with patch( + "control_backend.api.v1.endpoints.user_interact.Context.instance", return_value=mock_context + ): + mock_request = AsyncMock() + # is_disconnected sequence: + # 1. 
False (before first recv) -> reads message1 + # 2. False (before second recv) -> triggers TimeoutError, continues + # 3. True (before third recv) -> break loop + mock_request.is_disconnected.side_effect = [False, False, True] + + response = await user_interact.experiment_stream(mock_request) + + lines = [] + # Consume the generator + async for line in response.body_iterator: + lines.append(line) + + assert "data: message1\n\n" in lines + assert len(lines) == 1 + + mock_socket.connect.assert_called() + mock_socket.subscribe.assert_called_with(b"experiment") + mock_socket.close.assert_called() diff --git a/test/unit/conftest.py b/test/unit/conftest.py index 6ab989e..5e925d0 100644 --- a/test/unit/conftest.py +++ b/test/unit/conftest.py @@ -25,7 +25,6 @@ def mock_settings(): mock.zmq_settings.internal_sub_address = "tcp://localhost:5561" mock.zmq_settings.ri_command_address = "tcp://localhost:0000" mock.agent_settings.bdi_core_name = "bdi_core_agent" - mock.agent_settings.bdi_belief_collector_name = "belief_collector_agent" mock.agent_settings.llm_name = "llm_agent" mock.agent_settings.robot_speech_name = "robot_speech_agent" mock.agent_settings.transcription_name = "transcription_agent" @@ -33,6 +32,7 @@ def mock_settings(): mock.agent_settings.vad_name = "vad_agent" mock.behaviour_settings.sleep_s = 0.01 # Speed up tests mock.behaviour_settings.comm_setup_max_retries = 1 + mock.behaviour_settings.agentspeak_file = "src/control_backend/agents/bdi/agentspeak.asl" yield mock diff --git a/test/unit/schemas/test_ui_program_message.py b/test/unit/schemas/test_ui_program_message.py index 7ed544e..6f6d5fd 100644 --- a/test/unit/schemas/test_ui_program_message.py +++ b/test/unit/schemas/test_ui_program_message.py @@ -1,49 +1,66 @@ +import uuid + import pytest from pydantic import ValidationError from control_backend.schemas.program import ( + BasicNorm, + ConditionalNorm, Goal, - KeywordTrigger, - Norm, + InferredBelief, + KeywordBelief, + LogicalOperator, Phase, + Plan, 
Program, - TriggerKeyword, + SemanticBelief, + Trigger, ) -def base_norm() -> Norm: - return Norm( - id="norm1", - label="testNorm", +def base_norm() -> BasicNorm: + return BasicNorm( + id=uuid.uuid4(), + name="testNormName", norm="testNormNorm", + critical=False, ) def base_goal() -> Goal: return Goal( - id="goal1", - label="testGoal", - description="testGoalDescription", - achieved=False, + id=uuid.uuid4(), + name="testGoalName", + description="This description can be used to determine whether the goal has been achieved.", + plan=Plan( + id=uuid.uuid4(), + name="testGoalPlanName", + steps=[], + ), + can_fail=False, ) -def base_trigger() -> KeywordTrigger: - return KeywordTrigger( - id="trigger1", - label="testTrigger", - type="keywords", - keywords=[ - TriggerKeyword(id="keyword1", keyword="testKeyword1"), - TriggerKeyword(id="keyword1", keyword="testKeyword2"), - ], +def base_trigger() -> Trigger: + return Trigger( + id=uuid.uuid4(), + name="testTriggerName", + condition=KeywordBelief( + id=uuid.uuid4(), + name="testTriggerKeywordBeliefTriggerName", + keyword="Keyword", + ), + plan=Plan( + id=uuid.uuid4(), + name="testTriggerPlanName", + steps=[], + ), ) def base_phase() -> Phase: return Phase( - id="phase1", - label="basephase", + id=uuid.uuid4(), norms=[base_norm()], goals=[base_goal()], triggers=[base_trigger()], @@ -58,7 +75,7 @@ def invalid_program() -> dict: # wrong types inside phases list (not Phase objects) return { "phases": [ - {"id": "phase1"}, # incomplete + {"id": uuid.uuid4()}, # incomplete {"not_a_phase": True}, ] } @@ -77,11 +94,112 @@ def test_valid_deepprogram(): # validate nested components directly phase = validated.phases[0] assert isinstance(phase.goals[0], Goal) - assert isinstance(phase.triggers[0], KeywordTrigger) - assert isinstance(phase.norms[0], Norm) + assert isinstance(phase.triggers[0], Trigger) + assert isinstance(phase.norms[0], BasicNorm) def test_invalid_program(): bad = invalid_program() with pytest.raises(ValidationError): 
Program.model_validate(bad) + + +def test_conditional_norm_parsing(): + """ + Check that pydantic is able to preserve the type of the norm, that it doesn't lose its + "condition" field when serializing and deserializing. + """ + norm = ConditionalNorm( + name="testNormName", + id=uuid.uuid4(), + norm="testNormNorm", + critical=False, + condition=KeywordBelief( + name="testKeywordBelief", + id=uuid.uuid4(), + keyword="testKeywordBelief", + ), + ) + program = Program( + phases=[ + Phase( + name="Some phase", + id=uuid.uuid4(), + norms=[norm], + goals=[], + triggers=[], + ), + ], + ) + + parsed_program = Program.model_validate_json(program.model_dump_json()) + parsed_norm = parsed_program.phases[0].norms[0] + + assert hasattr(parsed_norm, "condition") + assert isinstance(parsed_norm, ConditionalNorm) + + +def test_belief_type_parsing(): + """ + Check that pydantic is able to discern between the different types of beliefs when serializing + and deserializing. + """ + keyword_belief = KeywordBelief( + name="testKeywordBelief", + id=uuid.uuid4(), + keyword="something", + ) + semantic_belief = SemanticBelief( + name="testSemanticBelief", + id=uuid.uuid4(), + description="something", + ) + inferred_belief = InferredBelief( + name="testInferredBelief", + id=uuid.uuid4(), + operator=LogicalOperator.OR, + left=keyword_belief, + right=semantic_belief, + ) + + program = Program( + phases=[ + Phase( + name="Some phase", + id=uuid.uuid4(), + norms=[], + goals=[], + triggers=[ + Trigger( + name="testTriggerKeywordTrigger", + id=uuid.uuid4(), + condition=keyword_belief, + plan=Plan(name="testTriggerPlanName", id=uuid.uuid4(), steps=[]), + ), + Trigger( + name="testTriggerSemanticTrigger", + id=uuid.uuid4(), + condition=semantic_belief, + plan=Plan(name="testTriggerPlanName", id=uuid.uuid4(), steps=[]), + ), + Trigger( + name="testTriggerInferredTrigger", + id=uuid.uuid4(), + condition=inferred_belief, + plan=Plan(name="testTriggerPlanName", id=uuid.uuid4(), steps=[]), + ), + ], + 
), + ], + ) + + parsed_program = Program.model_validate_json(program.model_dump_json()) + + parsed_keyword_belief = parsed_program.phases[0].triggers[0].condition + assert isinstance(parsed_keyword_belief, KeywordBelief) + + parsed_semantic_belief = parsed_program.phases[0].triggers[1].condition + assert isinstance(parsed_semantic_belief, SemanticBelief) + + parsed_inferred_belief = parsed_program.phases[0].triggers[2].condition + assert isinstance(parsed_inferred_belief, InferredBelief) diff --git a/test/unit/test_main_sockets.py b/test/unit/test_main_sockets.py new file mode 100644 index 0000000..662147a --- /dev/null +++ b/test/unit/test_main_sockets.py @@ -0,0 +1,40 @@ +from unittest.mock import MagicMock, patch + +import zmq + +from control_backend.main import setup_sockets + + +def test_setup_sockets_proxy(): + mock_context = MagicMock() + mock_pub = MagicMock() + mock_sub = MagicMock() + + mock_context.socket.side_effect = [mock_pub, mock_sub] + + with patch("zmq.asyncio.Context.instance", return_value=mock_context): + with patch("zmq.proxy") as mock_proxy: + setup_sockets() + + mock_pub.bind.assert_called() + mock_sub.bind.assert_called() + mock_proxy.assert_called_with(mock_sub, mock_pub) + + # Check cleanup + mock_pub.close.assert_called() + mock_sub.close.assert_called() + + +def test_setup_sockets_proxy_error(): + mock_context = MagicMock() + mock_pub = MagicMock() + mock_sub = MagicMock() + mock_context.socket.side_effect = [mock_pub, mock_sub] + + with patch("zmq.asyncio.Context.instance", return_value=mock_context): + with patch("zmq.proxy", side_effect=zmq.ZMQError): + with patch("control_backend.main.logger") as mock_logger: + setup_sockets() + mock_logger.warning.assert_called() + mock_pub.close.assert_called() + mock_sub.close.assert_called() diff --git a/uv.lock b/uv.lock index ff4b8a7..ea39c17 100644 --- a/uv.lock +++ b/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 3 +revision = 2 requires-python = ">=3.13" resolution-markers = [ 
"python_full_version >= '3.14'", @@ -997,6 +997,7 @@ dependencies = [ { name = "pydantic" }, { name = "pydantic-settings" }, { name = "python-json-logger" }, + { name = "python-slugify" }, { name = "pyyaml" }, { name = "pyzmq" }, { name = "silero-vad" }, @@ -1029,6 +1030,7 @@ test = [ { name = "pytest-asyncio" }, { name = "pytest-cov" }, { name = "pytest-mock" }, + { name = "python-slugify" }, { name = "pyyaml" }, { name = "pyzmq" }, { name = "soundfile" }, @@ -1046,6 +1048,7 @@ requires-dist = [ { name = "pydantic", specifier = ">=2.12.0" }, { name = "pydantic-settings", specifier = ">=2.11.0" }, { name = "python-json-logger", specifier = ">=4.0.0" }, + { name = "python-slugify", specifier = ">=8.0.4" }, { name = "pyyaml", specifier = ">=6.0.3" }, { name = "pyzmq", specifier = ">=27.1.0" }, { name = "silero-vad", specifier = ">=6.0.0" }, @@ -1078,6 +1081,7 @@ test = [ { name = "pytest-asyncio", specifier = ">=1.2.0" }, { name = "pytest-cov", specifier = ">=7.0.0" }, { name = "pytest-mock", specifier = ">=3.15.1" }, + { name = "python-slugify", specifier = ">=8.0.4" }, { name = "pyyaml", specifier = ">=6.0.3" }, { name = "pyzmq", specifier = ">=27.1.0" }, { name = "soundfile", specifier = ">=0.13.1" }, @@ -1341,6 +1345,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/45/58/38b5afbc1a800eeea951b9285d3912613f2603bdf897a4ab0f4bd7f405fc/python_multipart-0.0.20-py3-none-any.whl", hash = "sha256:8a62d3a8335e06589fe01f2a3e178cdcc632f3fbe0d492ad9ee0ec35aab1f104", size = 24546, upload-time = "2024-12-16T19:45:44.423Z" }, ] +[[package]] +name = "python-slugify" +version = "8.0.4" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "text-unidecode" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/87/c7/5e1547c44e31da50a460df93af11a535ace568ef89d7a811069ead340c4a/python-slugify-8.0.4.tar.gz", hash = "sha256:59202371d1d05b54a9e7720c5e038f928f45daaffe41dd10822f3907b937c856", size = 10921, upload-time = 
"2024-02-08T18:32:45.488Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a4/62/02da182e544a51a5c3ccf4b03ab79df279f9c60c5e82d5e8bec7ca26ac11/python_slugify-8.0.4-py2.py3-none-any.whl", hash = "sha256:276540b79961052b66b7d116620b36518847f52d5fd9e3a70164fc8c50faa6b8", size = 10051, upload-time = "2024-02-08T18:32:43.911Z" }, +] + [[package]] name = "pyyaml" version = "6.0.3" @@ -1864,6 +1880,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/a2/09/77d55d46fd61b4a135c444fc97158ef34a095e5681d0a6c10b75bf356191/sympy-1.14.0-py3-none-any.whl", hash = "sha256:e091cc3e99d2141a0ba2847328f5479b05d94a6635cb96148ccb3f34671bd8f5", size = 6299353, upload-time = "2025-04-27T18:04:59.103Z" }, ] +[[package]] +name = "text-unidecode" +version = "1.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/ab/e2/e9a00f0ccb71718418230718b3d900e71a5d16e701a3dae079a21e9cd8f8/text-unidecode-1.3.tar.gz", hash = "sha256:bad6603bb14d279193107714b288be206cac565dfa49aa5b105294dd5c4aab93", size = 76885, upload-time = "2019-08-30T21:36:45.405Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a6/a5/c0b6468d3824fe3fde30dbb5e1f687b291608f9473681bbf7dabbf5a87d7/text_unidecode-1.3-py2.py3-none-any.whl", hash = "sha256:1311f10e8b895935241623731c2ba64f4c455287888b18189350b67134a822e8", size = 78154, upload-time = "2019-08-30T21:37:03.543Z" }, +] + [[package]] name = "tiktoken" version = "0.12.0"