diff --git a/.gitignore b/.gitignore index b6490a9..41b7458 100644 --- a/.gitignore +++ b/.gitignore @@ -223,7 +223,8 @@ docs/* !docs/conf.py # Generated files -agentspeak.asl +*.asl +experiment-*.log diff --git a/.logging_config.yaml b/.logging_config.yaml index 4af5d56..f7cccf9 100644 --- a/.logging_config.yaml +++ b/.logging_config.yaml @@ -1,36 +1,57 @@ version: 1 custom_levels: - OBSERVATION: 25 - ACTION: 26 + OBSERVATION: 24 + ACTION: 25 + CHAT: 26 LLM: 9 formatters: # Console output colored: - (): "colorlog.ColoredFormatter" + class: colorlog.ColoredFormatter format: "{log_color}{asctime}.{msecs:03.0f} | {levelname:11} | {name:70} | {message}" style: "{" datefmt: "%H:%M:%S" # User-facing UI (structured JSON) - json_experiment: - (): "pythonjsonlogger.jsonlogger.JsonFormatter" + json: + class: pythonjsonlogger.jsonlogger.JsonFormatter format: "{name} {levelname} {levelno} {message} {created} {relativeCreated}" style: "{" + # Experiment stream for console and file output, with optional `role` field + experiment: + class: control_backend.logging.OptionalFieldFormatter + format: "%(asctime)s %(levelname)s %(role?)s %(message)s" + defaults: + role: "-" + +filters: + # Filter out any log records that have the extra field "partial" set to True, indicating that they + # will be replaced later. + partial: + (): control_backend.logging.PartialFilter + handlers: console: class: logging.StreamHandler level: DEBUG formatter: colored + filters: [partial] stream: ext://sys.stdout ui: class: zmq.log.handlers.PUBHandler level: LLM - formatter: json_experiment + formatter: json + file: + class: control_backend.logging.DatedFileHandler + formatter: experiment + filters: [partial] + # Directory must match config.logging_settings.experiment_log_directory + file_prefix: experiment_logs/experiment -# Level of external libraries +# Level for external libraries root: level: WARN handlers: [console] @@ -39,3 +60,6 @@ loggers: control_backend: level: LLM handlers: [ui] + experiment: # This name must match config.logging_settings.experiment_logger_name + level: DEBUG + handlers: [ui, file] diff --git a/src/control_backend/agents/actuation/robot_gesture_agent.py b/src/control_backend/agents/actuation/robot_gesture_agent.py index 997b684..9567940 100644 --- a/src/control_backend/agents/actuation/robot_gesture_agent.py +++ b/src/control_backend/agents/actuation/robot_gesture_agent.py @@ -1,4 +1,5 @@ import json +import logging import zmq import zmq.asyncio as azmq @@ -8,6 +9,8 @@ from control_backend.core.agent_system import InternalMessage from control_backend.core.config import settings from control_backend.schemas.ri_message import GestureCommand, RIEndpoint +experiment_logger = logging.getLogger(settings.logging_settings.experiment_logger_name) + class RobotGestureAgent(BaseAgent): """ @@ -111,6 +114,7 @@ class RobotGestureAgent(BaseAgent): gesture_command.data, ) return + experiment_logger.action("Gesture: %s", gesture_command.data) await self.pubsocket.send_json(gesture_command.model_dump()) except Exception: self.logger.exception("Error processing internal message.") diff --git a/src/control_backend/agents/base.py b/src/control_backend/agents/base.py index ec50af5..beccdaa 100644 --- a/src/control_backend/agents/base.py +++ b/src/control_backend/agents/base.py @@ -1,9 +1,10 @@ import logging +from abc import ABC from control_backend.core.agent_system import BaseAgent as CoreBaseAgent -class BaseAgent(CoreBaseAgent): +class BaseAgent(CoreBaseAgent, ABC): """ The primary base class for all implementation agents. diff --git a/src/control_backend/agents/bdi/agentspeak_ast.py b/src/control_backend/agents/bdi/agentspeak_ast.py index 19f48e2..12c7947 100644 --- a/src/control_backend/agents/bdi/agentspeak_ast.py +++ b/src/control_backend/agents/bdi/agentspeak_ast.py @@ -8,31 +8,78 @@ from enum import StrEnum class AstNode(ABC): """ Abstract base class for all elements of an AgentSpeak program. + + This class serves as the foundation for all AgentSpeak abstract syntax tree (AST) nodes. + It defines the core interface that all AST nodes must implement to generate AgentSpeak code. """ @abstractmethod def _to_agentspeak(self) -> str: """ Generates the AgentSpeak code string. + + This method converts the AST node into its corresponding + AgentSpeak source code representation. + + :return: The AgentSpeak code string representation of this node. """ pass def __str__(self) -> str: + """ + Returns the string representation of this AST node. + + This method provides a convenient way to get the AgentSpeak code representation + by delegating to the _to_agentspeak method. + + :return: The AgentSpeak code string representation of this node. + """ return self._to_agentspeak() class AstExpression(AstNode, ABC): """ Intermediate class for anything that can be used in a logical expression. + + This class extends AstNode to provide common functionality for all expressions + that can be used in logical operations within AgentSpeak programs. """ def __and__(self, other: ExprCoalescible) -> AstBinaryOp: + """ + Creates a logical AND operation between this expression and another. + + This method allows for operator overloading of the & operator to create + binary logical operations in a more intuitive syntax. + + :param other: The right-hand side expression to combine with this one. + :return: A new AstBinaryOp representing the logical AND operation. + """ return AstBinaryOp(self, BinaryOperatorType.AND, _coalesce_expr(other)) def __or__(self, other: ExprCoalescible) -> AstBinaryOp: + """ + Creates a logical OR operation between this expression and another. + + This method allows for operator overloading of the | operator to create + binary logical operations in a more intuitive syntax. + + :param other: The right-hand side expression to combine with this one. + :return: A new AstBinaryOp representing the logical OR operation. + """ return AstBinaryOp(self, BinaryOperatorType.OR, _coalesce_expr(other)) def __invert__(self) -> AstLogicalExpression: + """ + Creates a logical negation of this expression. + + This method allows for operator overloading of the ~ operator to create + negated expressions. If the expression is already a logical expression, + it toggles the negation flag. Otherwise, it wraps the expression in a + new AstLogicalExpression with negation set to True. + + :return: An AstLogicalExpression representing the negated form of this expression. + """ if isinstance(self, AstLogicalExpression): self.negated = not self.negated return self @@ -81,11 +128,25 @@ class AstTerm(AstExpression, ABC): class AstAtom(AstTerm): """ Represents a grounded atom in AgentSpeak (e.g., lowercase constants). + + Atoms are the simplest form of terms in AgentSpeak, representing concrete, + unchanging values. They are typically used as constants in logical expressions. + + :ivar value: The string value of this atom, which will be converted to lowercase + in the AgentSpeak representation. """ value: str def _to_agentspeak(self) -> str: + """ + Converts this atom to its AgentSpeak string representation. + + Atoms are represented in lowercase in AgentSpeak to distinguish them + from variables (which are capitalized). + + :return: The lowercase string representation of this atom. + """ return self.value.lower() @@ -93,11 +154,25 @@ class AstAtom(AstTerm): class AstVar(AstTerm): """ Represents an ungrounded variable in AgentSpeak (e.g., capitalized names). + + Variables in AgentSpeak are placeholders that can be bound to specific values + during execution. They are distinguished from atoms by their capitalization. + + :ivar name: The name of this variable, which will be capitalized in the + AgentSpeak representation. """ name: str def _to_agentspeak(self) -> str: + """ + Converts this variable to its AgentSpeak string representation. + + Variables are represented with capitalized names in AgentSpeak to distinguish + them from atoms (which are lowercase). + + :return: The capitalized string representation of this variable. + """ return self.name.capitalize() @@ -105,11 +180,21 @@ class AstVar(AstTerm): class AstNumber(AstTerm): """ Represents a numeric constant in AgentSpeak. + + Numeric constants can be either integers or floating-point numbers and are + used in logical expressions and comparisons. + + :ivar value: The numeric value of this constant (can be int or float). """ value: int | float def _to_agentspeak(self) -> str: + """ + Converts this numeric constant to its AgentSpeak string representation. + + :return: The string representation of the numeric value. + """ return str(self.value) @@ -117,11 +202,23 @@ class AstNumber(AstTerm): class AstString(AstTerm): """ Represents a string literal in AgentSpeak. + + String literals are used to represent textual data and are enclosed in + double quotes in the AgentSpeak representation. + + :ivar value: The string content of this literal. """ value: str def _to_agentspeak(self) -> str: + """ + Converts this string literal to its AgentSpeak string representation. + + String literals are enclosed in double quotes in AgentSpeak. + + :return: The string literal enclosed in double quotes. + """ return f'"{self.value}"' @@ -129,12 +226,26 @@ class AstString(AstTerm): class AstLiteral(AstTerm): """ Represents a literal (functor and terms) in AgentSpeak. + + Literals are the fundamental building blocks of AgentSpeak programs, consisting + of a functor (predicate name) and an optional list of terms (arguments). + + :ivar functor: The name of the predicate or function. + :ivar terms: A list of terms (arguments) for this literal. Defaults to an empty list. """ functor: str terms: list[AstTerm] = field(default_factory=list) def _to_agentspeak(self) -> str: + """ + Converts this literal to its AgentSpeak string representation. + + If the literal has no terms, it returns just the functor name. + Otherwise, it returns the functor followed by the terms in parentheses. + + :return: The AgentSpeak string representation of this literal. + """ if not self.terms: return self.functor args = ", ".join(map(str, self.terms)) @@ -142,6 +253,13 @@ class AstLiteral(AstTerm): class BinaryOperatorType(StrEnum): + """ + Enumeration of binary operator types used in AgentSpeak expressions. + + These operators are used to create binary operations between expressions, + including logical operations (AND, OR) and comparison operations. + """ + AND = "&" OR = "|" GREATER_THAN = ">" @@ -156,6 +274,13 @@ class BinaryOperatorType(StrEnum): class AstBinaryOp(AstExpression): """ Represents a binary logical or relational operation in AgentSpeak. + + Binary operations combine two expressions using a logical or comparison operator. + They are used to create complex logical conditions in AgentSpeak programs. + + :ivar left: The left-hand side expression of the operation. + :ivar operator: The binary operator type (AND, OR, comparison operators, etc.). + :ivar right: The right-hand side expression of the operation. """ left: AstExpression @@ -163,10 +288,25 @@ class AstBinaryOp(AstExpression): right: AstExpression def __post_init__(self): + """ + Post-initialization processing to ensure proper expression types. + + This method wraps the left and right expressions in AstLogicalExpression + instances if they aren't already, ensuring consistent handling throughout + the AST. + """ self.left = _as_logical(self.left) self.right = _as_logical(self.right) def _to_agentspeak(self) -> str: + """ + Converts this binary operation to its AgentSpeak string representation. + + The method handles proper parenthesization of sub-expressions to maintain + correct operator precedence and readability. + + :return: The AgentSpeak string representation of this binary operation. + """ l_str = str(self.left) r_str = str(self.right) @@ -185,12 +325,27 @@ class AstBinaryOp(AstExpression): class AstLogicalExpression(AstExpression): """ Represents a logical expression, potentially negated, in AgentSpeak. + + Logical expressions can be either positive or negated and form the basis + of conditions and beliefs in AgentSpeak programs. + + :ivar expression: The underlying expression being evaluated. + :ivar negated: Boolean flag indicating whether this expression is negated. """ expression: AstExpression negated: bool = False def _to_agentspeak(self) -> str: + """ + Converts this logical expression to its AgentSpeak string representation. + + If the expression is negated, it prepends 'not ' to the expression string. + For complex expressions (binary operations), it adds parentheses when negated + to maintain correct logical interpretation. + + :return: The AgentSpeak string representation of this logical expression. + """ expr_str = str(self.expression) if isinstance(self.expression, AstBinaryOp) and self.negated: expr_str = f"({expr_str})" @@ -198,31 +353,76 @@ class AstLogicalExpression(AstExpression): def _as_logical(expr: AstExpression) -> AstLogicalExpression: + """ + Converts an expression to a logical expression if it isn't already. + + This helper function ensures that expressions are properly wrapped in + AstLogicalExpression instances, which is necessary for consistent handling + of logical operations in the AST. + + :param expr: The expression to convert. + :return: The expression wrapped in an AstLogicalExpression if it wasn't already. + """ if isinstance(expr, AstLogicalExpression): return expr return AstLogicalExpression(expr) class StatementType(StrEnum): + """ + Enumeration of statement types that can appear in AgentSpeak plans. + + These statement types define the different kinds of actions and operations + that can be performed within the body of an AgentSpeak plan. + """ + EMPTY = "" + """Empty statement (no operation, used when evaluating a plan to true).""" + DO_ACTION = "." + """Execute an action defined in Python.""" + ACHIEVE_GOAL = "!" + """Achieve a goal (add a goal to be accomplished).""" + TEST_GOAL = "?" + """Test a goal (check if a goal can be achieved).""" + ADD_BELIEF = "+" + """Add a belief to the belief base.""" + REMOVE_BELIEF = "-" + """Remove a belief from the belief base.""" + REPLACE_BELIEF = "-+" + """Replace a belief in the belief base.""" @dataclass class AstStatement(AstNode): """ A statement that can appear inside a plan. + + Statements are the executable units within AgentSpeak plans. They consist + of a statement type (defining the operation) and an expression (defining + what to operate on). + + :ivar type: The type of statement (action, goal, belief operation, etc.). + :ivar expression: The expression that this statement operates on. """ type: StatementType expression: AstExpression def _to_agentspeak(self) -> str: + """ + Converts this statement to its AgentSpeak string representation. + + The representation consists of the statement type prefix followed by + the expression. + + :return: The AgentSpeak string representation of this statement. + """ return f"{self.type.value}{self.expression}" @@ -230,26 +430,59 @@ class AstStatement(AstNode): class AstRule(AstNode): """ Represents an inference rule in AgentSpeak. If there is no condition, it always holds. + + Rules define logical implications in AgentSpeak programs. They consist of a + result (conclusion) and an optional condition (premise). When the condition + holds, the result is inferred to be true. + + :ivar result: The conclusion or result of this rule. + :ivar condition: The premise or condition for this rule (optional). """ result: AstExpression condition: AstExpression | None = None def __post_init__(self): + """ + Post-initialization processing to ensure proper expression types. + + If a condition is provided, this method wraps it in an AstLogicalExpression + to ensure consistent handling throughout the AST. + """ if self.condition is not None: self.condition = _as_logical(self.condition) def _to_agentspeak(self) -> str: + """ + Converts this rule to its AgentSpeak string representation. + + If no condition is specified, the rule is represented as a simple fact. + If a condition is specified, it's represented as an implication (result :- condition). + + :return: The AgentSpeak string representation of this rule. + """ if not self.condition: return f"{self.result}." return f"{self.result} :- {self.condition}." class TriggerType(StrEnum): + """ + Enumeration of trigger types for AgentSpeak plans. + + Trigger types define what kind of events can activate an AgentSpeak plan. + Currently, the system supports triggers for added beliefs and added goals. + """ + ADDED_BELIEF = "+" + """Trigger when a belief is added to the belief base.""" + # REMOVED_BELIEF = "-" # TODO # MODIFIED_BELIEF = "^" # TODO + ADDED_GOAL = "+!" + """Trigger when a goal is added to be achieved.""" + # REMOVED_GOAL = "-!" # TODO @@ -257,6 +490,14 @@ class TriggerType(StrEnum): class AstPlan(AstNode): """ Represents a plan in AgentSpeak, consisting of a trigger, context, and body. + + Plans define the reactive behavior of agents in AgentSpeak. They specify what + actions to take when certain conditions are met (trigger and context). + + :ivar type: The type of trigger that activates this plan. + :ivar trigger_literal: The specific event or condition that triggers this plan. + :ivar context: A list of conditions that must hold for this plan to be applicable. + :ivar body: A list of statements to execute when this plan is triggered. """ type: TriggerType @@ -265,6 +506,16 @@ class AstPlan(AstNode): body: list[AstStatement] def _to_agentspeak(self) -> str: + """ + Converts this plan to its AgentSpeak string representation. + + The representation follows the standard AgentSpeak plan format: + trigger_type + trigger_literal + : context_conditions + <- body_statements. + + :return: The AgentSpeak string representation of this plan. + """ assert isinstance(self.trigger_literal, AstLiteral) indent = " " * 6 @@ -290,12 +541,26 @@ class AstPlan(AstNode): class AstProgram(AstNode): """ Represents a full AgentSpeak program, consisting of rules and plans. + + This is the root node of the AgentSpeak AST, containing all the rules + and plans that define the agent's behavior. + + :ivar rules: A list of inference rules in this program. + :ivar plans: A list of reactive plans in this program. """ rules: list[AstRule] = field(default_factory=list) plans: list[AstPlan] = field(default_factory=list) def _to_agentspeak(self) -> str: + """ + Converts this program to its AgentSpeak string representation. + + The representation consists of all rules followed by all plans, + separated by blank lines for readability. + + :return: The complete AgentSpeak source code for this program. + """ lines = [] lines.extend(map(str, self.rules)) diff --git a/src/control_backend/agents/bdi/agentspeak_generator.py b/src/control_backend/agents/bdi/agentspeak_generator.py index 93c41af..5ee9f15 100644 --- a/src/control_backend/agents/bdi/agentspeak_generator.py +++ b/src/control_backend/agents/bdi/agentspeak_generator.py @@ -18,6 +18,7 @@ from control_backend.agents.bdi.agentspeak_ast import ( StatementType, TriggerType, ) +from control_backend.core.config import settings from control_backend.schemas.program import ( BaseGoal, BasicNorm, @@ -47,6 +48,15 @@ class AgentSpeakGenerator: It handles the conversion of phases, norms, goals, and triggers into AgentSpeak rules and plans, ensuring the robot follows the defined behavioral logic. + + The generator follows a systematic approach: + 1. Sets up initial phase and cycle notification rules + 2. Adds keyword inference capabilities for natural language processing + 3. Creates default plans for common operations + 4. Processes each phase with its norms, goals, and triggers + 5. Adds fallback plans for robust execution + + :ivar _asp: The internal AgentSpeak program representation being built. """ _asp: AstProgram @@ -55,6 +65,10 @@ class AgentSpeakGenerator: """ Translates a Program object into an AgentSpeak source string. + This is the main entry point for the code generation process. It initializes + the AgentSpeak program structure and orchestrates the conversion of all + program elements into their AgentSpeak representations. + :param program: The behavior program to translate. :return: The generated AgentSpeak code as a string. """ @@ -77,6 +91,18 @@ class AgentSpeakGenerator: return str(self._asp) def _add_keyword_inference(self) -> None: + """ + Adds inference rules for keyword detection in user messages. + + This method creates rules that allow the system to detect when specific + keywords are mentioned in user messages. It uses string operations to + check if a keyword is a substring of the user's message. + + The generated rule has the form: + keyword_said(Keyword) :- user_said(Message) & .substring(Keyword, Message, Pos) & Pos >= 0 + + This enables the system to trigger behaviors based on keyword detection. + """ keyword = AstVar("Keyword") message = AstVar("Message") position = AstVar("Pos") @@ -91,12 +117,32 @@ class AgentSpeakGenerator: ) def _add_default_plans(self): + """ + Adds default plans for common operations. + + This method sets up the standard plans that handle fundamental operations + like replying with goals, simple speech actions, general replies, and + cycle notifications. These plans provide the basic infrastructure for + the agent's reactive behavior. + """ self._add_reply_with_goal_plan() self._add_say_plan() self._add_reply_plan() self._add_notify_cycle_plan() def _add_reply_with_goal_plan(self): + """ + Adds a plan for replying with a specific conversational goal. + + This plan handles the case where the agent needs to respond to user input + while pursuing a specific conversational goal. It: + 1. Marks that the agent has responded this turn + 2. Gathers all active norms + 3. Generates a reply that considers both the user message and the goal + + Trigger: +!reply_with_goal(Goal) + Context: user_said(Message) + """ self._asp.plans.append( AstPlan( TriggerType.ADDED_GOAL, @@ -122,6 +168,17 @@ class AgentSpeakGenerator: ) def _add_say_plan(self): + """ + Adds a plan for simple speech actions. + + This plan handles direct speech actions where the agent needs to say + a specific text. It: + 1. Marks that the agent has responded this turn + 2. Executes the speech action + + Trigger: +!say(Text) + Context: None (can be executed anytime) + """ self._asp.plans.append( AstPlan( TriggerType.ADDED_GOAL, @@ -135,6 +192,18 @@ class AgentSpeakGenerator: ) def _add_reply_plan(self): + """ + Adds a plan for general reply actions. + + This plan handles general reply actions where the agent needs to respond + to user input without a specific conversational goal. It: + 1. Marks that the agent has responded this turn + 2. Gathers all active norms + 3. Generates a reply based on the user message and norms + + Trigger: +!reply + Context: user_said(Message) + """ self._asp.plans.append( AstPlan( TriggerType.ADDED_GOAL, @@ -158,6 +227,19 @@ class AgentSpeakGenerator: ) def _add_notify_cycle_plan(self): + """ + Adds a plan for cycle notification. + + This plan handles the periodic notification cycle that allows the system + to monitor and report on the current state. It: + 1. Gathers all active norms + 2. Notifies the system about the current norms + 3. Waits briefly to allow processing + 4. Recursively triggers the next cycle + + Trigger: +!notify_cycle + Context: None (can be executed anytime) + """ self._asp.plans.append( AstPlan( TriggerType.ADDED_GOAL, @@ -181,6 +263,16 @@ class AgentSpeakGenerator: ) def _process_phases(self, phases: list[Phase]) -> None: + """ + Processes all phases in the program and their transitions. + + This method iterates through each phase and: + 1. Processes the current phase (norms, goals, triggers) + 2. Sets up transitions between phases + 3. Adds special handling for the end phase + + :param phases: The list of phases to process. + """ for curr_phase, next_phase in zip([None] + phases, phases + [None], strict=True): if curr_phase: self._process_phase(curr_phase) @@ -203,6 +295,17 @@ class AgentSpeakGenerator: ) def _process_phase(self, phase: Phase) -> None: + """ + Processes a single phase, including its norms, goals, and triggers. + + This method handles the complete processing of a phase by: + 1. Processing all norms in the phase + 2. Setting up the default execution loop for the phase + 3. Processing all goals in sequence + 4. Processing all triggers for reactive behavior + + :param phase: The phase to process. + """ for norm in phase.norms: self._process_norm(norm, phase) @@ -217,6 +320,21 @@ class AgentSpeakGenerator: self._process_trigger(trigger, phase) def _add_phase_transition(self, from_phase: Phase | None, to_phase: Phase | None) -> None: + """ + Adds plans for transitioning between phases. + + This method creates two plans for each phase transition: + 1. A check plan that verifies if transition conditions are met + 2. A force plan that actually performs the transition (can be forced externally) + + The transition involves: + - Notifying the system about the phase change + - Removing the current phase belief + - Adding the next phase belief + + :param from_phase: The phase being transitioned from (or None for initial setup). + :param to_phase: The phase being transitioned to (or None for end phase). + """ if from_phase is None: return from_phase_ast = self._astify(from_phase) @@ -246,18 +364,6 @@ class AgentSpeakGenerator: AstStatement(StatementType.ADD_BELIEF, to_phase_ast), ] - # if from_phase: - # body.extend( - # [ - # AstStatement( - # StatementType.TEST_GOAL, AstLiteral("user_said", [AstVar("Message")]) - # ), - # AstStatement( - # StatementType.REPLACE_BELIEF, AstLiteral("user_said", [AstVar("Message")]) - # ), - # ] - # ) - # Check self._asp.plans.append( AstPlan( @@ -278,6 +384,17 @@ class AgentSpeakGenerator: ) def _process_norm(self, norm: Norm, phase: Phase) -> None: + """ + Processes a norm and adds it as an inference rule. + + This method converts norms into AgentSpeak rules that define when + the norm should be active. It handles both basic norms (always active + in their phase) and conditional norms (active only when their condition + is met). + + :param norm: The norm to process. + :param phase: The phase this norm belongs to. + """ rule: AstRule | None = None match norm: @@ -296,6 +413,18 @@ class AgentSpeakGenerator: self._asp.rules.append(rule) def _add_default_loop(self, phase: Phase) -> None: + """ + Adds the default execution loop for a phase. + + This method creates the main reactive loop that runs when the agent + receives user input during a phase. The loop: + 1. Notifies the system about the user input + 2. Resets the response tracking + 3. Executes all phase goals + 4. Attempts phase transition + + :param phase: The phase to create the loop for. + """ actions = [] actions.append( @@ -304,7 +433,6 @@ class AgentSpeakGenerator: ) ) actions.append(AstStatement(StatementType.REMOVE_BELIEF, AstLiteral("responded_this_turn"))) - actions.append(AstStatement(StatementType.ACHIEVE_GOAL, AstLiteral("check_triggers"))) for goal in phase.goals: actions.append(AstStatement(StatementType.ACHIEVE_GOAL, self._astify(goal))) @@ -328,6 +456,22 @@ class AgentSpeakGenerator: continues_response: bool = False, main_goal: bool = False, ) -> None: + """ + Processes a goal and creates plans for achieving it. + + This method creates two plans for each goal: + 1. A main plan that executes the goal's steps when conditions are met + 2. A fallback plan that provides a default empty implementation (prevents crashes) + + The method also recursively processes any subgoals contained within + the goal's plan. + + :param goal: The goal to process. + :param phase: The phase this goal belongs to. + :param previous_goal: The previous goal in sequence (for dependency tracking). + :param continues_response: Whether this goal continues an existing response. + :param main_goal: Whether this is a main goal (for UI notification purposes). + """ context: list[AstExpression] = [self._astify(phase)] context.append(~self._astify(goal, achieved=True)) if previous_goal and previous_goal.can_fail: @@ -370,14 +514,39 @@ class AgentSpeakGenerator: prev_goal = subgoal def _step_to_statement(self, step: PlanElement) -> AstStatement: + """ + Converts a plan step to an AgentSpeak statement. + + This method transforms different types of plan elements into their + corresponding AgentSpeak statements. Goals and speech-related actions + become achieve-goal statements, while gesture actions become do-action + statements. + + :param step: The plan element to convert. + :return: The corresponding AgentSpeak statement. + """ match step: + # Note that SpeechAction gets included in the ACHIEVE_GOAL, since it's a goal internally case Goal() | SpeechAction() | LLMAction() as a: return AstStatement(StatementType.ACHIEVE_GOAL, self._astify(a)) case GestureAction() as a: return AstStatement(StatementType.DO_ACTION, self._astify(a)) - # TODO: separate handling of keyword and others def _process_trigger(self, trigger: Trigger, phase: Phase) -> None: + """ + Processes a trigger and creates plans for its execution. + + This method creates plans that execute when trigger conditions are met. + It handles both automatic triggering (when conditions are detected) and + manual forcing (from UI). The trigger execution includes: + 1. Notifying the system about trigger start + 2. Executing all trigger steps + 3. Waiting briefly for UI display + 4. Notifying the system about trigger end + + :param trigger: The trigger to process. + :param phase: The phase this trigger belongs to. + """ body = [] subgoals = [] @@ -394,7 +563,12 @@ class AgentSpeakGenerator: subgoals.append(step) # Arbitrary wait for UI to display nicely - body.append(AstStatement(StatementType.DO_ACTION, AstLiteral("wait", [AstNumber(2000)]))) + body.append( + AstStatement( + StatementType.DO_ACTION, + AstLiteral("wait", [AstNumber(settings.behaviour_settings.trigger_time_to_wait)]), + ) + ) body.append( AstStatement( @@ -419,6 +593,18 @@ class AgentSpeakGenerator: self._process_goal(subgoal, phase, continues_response=True) def _add_fallbacks(self): + """ + Adds fallback plans for robust execution, preventing crashes. + + This method creates fallback plans that provide default empty implementations + for key goals. These fallbacks ensure that the system can continue execution + even when no specific plans are applicable, preventing crashes. + + The fallbacks are created for: + - check_triggers: When no triggers are applicable + - transition_phase: When phase transition conditions aren't met + - force_transition_phase: When forced transitions aren't possible + """ # Trigger fallback self._asp.plans.append( AstPlan( @@ -451,14 +637,44 @@ class AgentSpeakGenerator: @singledispatchmethod def _astify(self, element: ProgramElement) -> AstExpression: + """ + Converts program elements to AgentSpeak expressions (base method). + + This is the base method for the singledispatch mechanism that handles + conversion of different program element types to their AgentSpeak + representations. Specific implementations are provided for each + element type through registered methods. + + :param element: The program element to convert. + :return: The corresponding AgentSpeak expression. + :raises NotImplementedError: If no specific implementation exists for the element type. + """ raise NotImplementedError(f"Cannot convert element {element} to an AgentSpeak expression.") @_astify.register def _(self, kwb: KeywordBelief) -> AstExpression: + """ + Converts a KeywordBelief to an AgentSpeak expression. + + Keyword beliefs are converted to keyword_said literals that check + if the keyword was mentioned in user input. + + :param kwb: The KeywordBelief to convert. + :return: An AstLiteral representing the keyword detection. + """ return AstLiteral("keyword_said", [AstString(kwb.keyword)]) @_astify.register def _(self, sb: SemanticBelief) -> AstExpression: + """ + Converts a SemanticBelief to an AgentSpeak expression. + + Semantic beliefs are converted to literals using their slugified names, + which are used for LLM-based belief evaluation. + + :param sb: The SemanticBelief to convert. + :return: An AstLiteral representing the semantic belief. + """ return AstLiteral(self.slugify(sb)) @_astify.register @@ -467,6 +683,15 @@ class AgentSpeakGenerator: @_astify.register def _(self, ib: InferredBelief) -> AstExpression: + """ + Converts an InferredBelief to an AgentSpeak expression. + + Inferred beliefs are converted to binary operations that combine + their left and right operands using the appropriate logical operator. + + :param ib: The InferredBelief to convert. + :return: An AstBinaryOp representing the logical combination. + """ return AstBinaryOp( self._astify(ib.left), BinaryOperatorType.AND if ib.operator == LogicalOperator.AND else BinaryOperatorType.OR, @@ -475,59 +700,187 @@ class AgentSpeakGenerator: @_astify.register def _(self, norm: Norm) -> AstExpression: + """ + Converts a Norm to an AgentSpeak expression. + + Norms are converted to literals with either 'norm' or 'critical_norm' + functors depending on their critical flag, with the norm text as an argument. + + Note that currently, critical norms are not yet functionally supported. They are possible + to astify for future use. + + :param norm: The Norm to convert. + :return: An AstLiteral representing the norm. + """ functor = "critical_norm" if norm.critical else "norm" return AstLiteral(functor, [AstString(norm.norm)]) @_astify.register def _(self, phase: Phase) -> AstExpression: + """ + Converts a Phase to an AgentSpeak expression. + + Phases are converted to phase literals with their unique identifier + as an argument, which is used for phase tracking and transitions. + + :param phase: The Phase to convert. + :return: An AstLiteral representing the phase. + """ return AstLiteral("phase", [AstString(str(phase.id))]) @_astify.register def _(self, goal: Goal, achieved: bool = False) -> AstExpression: + """ + Converts a Goal to an AgentSpeak expression. + + Goals are converted to literals using their slugified names. If the + achieved parameter is True, the literal is prefixed with 'achieved_'. + + :param goal: The Goal to convert. + :param achieved: Whether to represent this as an achieved goal. + :return: An AstLiteral representing the goal. + """ return AstLiteral(f"{'achieved_' if achieved else ''}{self._slugify_str(goal.name)}") @_astify.register def _(self, trigger: Trigger) -> AstExpression: + """ + Converts a Trigger to an AgentSpeak expression. + + Triggers are converted to literals using their slugified names, + which are used to identify and execute trigger plans. + + :param trigger: The Trigger to convert. + :return: An AstLiteral representing the trigger. + """ return AstLiteral(self.slugify(trigger)) @_astify.register def _(self, sa: SpeechAction) -> AstExpression: + """ + Converts a SpeechAction to an AgentSpeak expression. + + Speech actions are converted to say literals with the text content + as an argument, which are used for direct speech output. + + :param sa: The SpeechAction to convert. + :return: An AstLiteral representing the speech action. + """ return AstLiteral("say", [AstString(sa.text)]) @_astify.register def _(self, ga: GestureAction) -> AstExpression: + """ + Converts a GestureAction to an AgentSpeak expression. + + Gesture actions are converted to gesture literals with the gesture + type and name as arguments, which are used for physical robot gestures. + + :param ga: The GestureAction to convert. + :return: An AstLiteral representing the gesture action. + """ gesture = ga.gesture return AstLiteral("gesture", [AstString(gesture.type), AstString(gesture.name)]) @_astify.register def _(self, la: LLMAction) -> AstExpression: + """ + Converts an LLMAction to an AgentSpeak expression. + + LLM actions are converted to reply_with_goal literals with the + conversational goal as an argument, which are used for LLM-generated + responses guided by specific goals. + + :param la: The LLMAction to convert. + :return: An AstLiteral representing the LLM action. + """ return AstLiteral("reply_with_goal", [AstString(la.goal)]) @singledispatchmethod @staticmethod def slugify(element: ProgramElement) -> str: + """ + Converts program elements to slugs (base method). + + This is the base method for the singledispatch mechanism that handles + conversion of different program element types to their slug representations. + Specific implementations are provided for each element type through + registered methods. + + Slugs are used outside of AgentSpeak, mostly for identifying what to send to the AgentSpeak + program as beliefs. + + :param element: The program element to convert to a slug. + :return: The slug string representation. + :raises NotImplementedError: If no specific implementation exists for the element type. + """ raise NotImplementedError(f"Cannot convert element {element} to a slug.") @slugify.register @staticmethod def _(n: Norm) -> str: + """ + Converts a Norm to a slug. + + Norms are converted to slugs with the 'norm_' prefix followed by + the slugified norm text. + + :param n: The Norm to convert. + :return: The slug string representation. + """ return f"norm_{AgentSpeakGenerator._slugify_str(n.norm)}" @slugify.register @staticmethod def _(sb: SemanticBelief) -> str: + """ + Converts a SemanticBelief to a slug. + + Semantic beliefs are converted to slugs with the 'semantic_' prefix + followed by the slugified belief name. + + :param sb: The SemanticBelief to convert. + :return: The slug string representation. + """ return f"semantic_{AgentSpeakGenerator._slugify_str(sb.name)}" @slugify.register @staticmethod def _(g: BaseGoal) -> str: + """ + Converts a BaseGoal to a slug. + + Goals are converted to slugs using their slugified names directly. + + :param g: The BaseGoal to convert. + :return: The slug string representation. + """ return AgentSpeakGenerator._slugify_str(g.name) @slugify.register @staticmethod - def _(t: Trigger): + def _(t: Trigger) -> str: + """ + Converts a Trigger to a slug. + + Triggers are converted to slugs with the 'trigger_' prefix followed by + the slugified trigger name. + + :param t: The Trigger to convert. + :return: The slug string representation. + """ return f"trigger_{AgentSpeakGenerator._slugify_str(t.name)}" @staticmethod def _slugify_str(text: str) -> str: + """ + Converts a text string to a slug. + + This helper method converts arbitrary text to a URL-friendly slug format + by converting to lowercase, removing special characters, and replacing + spaces with underscores. It also removes common stopwords. + + :param text: The text string to convert. + :return: The slugified string. + """ return slugify(text, separator="_", stopwords=["a", "an", "the", "we", "you", "I"]) diff --git a/src/control_backend/agents/bdi/bdi_core_agent.py b/src/control_backend/agents/bdi/bdi_core_agent.py index 54b5149..698bbc4 100644 --- a/src/control_backend/agents/bdi/bdi_core_agent.py +++ b/src/control_backend/agents/bdi/bdi_core_agent.py @@ -1,6 +1,7 @@ import asyncio import copy import json +import logging import time from collections.abc import Iterable @@ -19,6 +20,9 @@ from control_backend.schemas.ri_message import GestureCommand, RIEndpoint, Speec DELIMITER = ";\n" # TODO: temporary until we support lists in AgentSpeak +experiment_logger = logging.getLogger(settings.logging_settings.experiment_logger_name) + + class BDICoreAgent(BaseAgent): """ BDI Core Agent. @@ -207,6 +211,9 @@ class BDICoreAgent(BaseAgent): else: term = agentspeak.Literal(name) + if name != "user_said": + experiment_logger.observation(f"Formed new belief: {name}{f'={args}' if args else ''}") + self.bdi_agent.call( agentspeak.Trigger.addition, agentspeak.GoalType.belief, @@ -244,6 +251,9 @@ class BDICoreAgent(BaseAgent): new_args = (agentspeak.Literal(arg) for arg in args) term = agentspeak.Literal(name, new_args) + if name != "user_said": + experiment_logger.observation(f"Removed belief: {name}{f'={args}' if args else ''}") + result = self.bdi_agent.call( agentspeak.Trigger.removal, agentspeak.GoalType.belief, @@ -386,6 +396,8 @@ class BDICoreAgent(BaseAgent): body=str(message_text), ) + experiment_logger.chat(str(message_text), extra={"role": "assistant"}) + self.add_behavior(self.send(chat_history_message)) yield diff --git a/src/control_backend/agents/bdi/bdi_program_manager.py b/src/control_backend/agents/bdi/bdi_program_manager.py index 730c8e5..3ea6a62 100644 --- a/src/control_backend/agents/bdi/bdi_program_manager.py +++ b/src/control_backend/agents/bdi/bdi_program_manager.py @@ -1,10 +1,12 @@ import asyncio import json +import logging import zmq from pydantic import ValidationError from zmq.asyncio import Context +import control_backend from control_backend.agents import BaseAgent from control_backend.agents.bdi.agentspeak_generator import AgentSpeakGenerator from control_backend.core.config import settings @@ -19,17 +21,21 @@ from control_backend.schemas.program import ( Program, ) +experiment_logger = logging.getLogger(settings.logging_settings.experiment_logger_name) + class BDIProgramManager(BaseAgent): """ BDI Program Manager Agent. This agent is responsible for receiving high-level programs (sequences of instructions/goals) - from the external HTTP API (via ZMQ) and translating them into core beliefs (norms and goals) - for the BDI Core Agent. In the future, it will be responsible for determining when goals are - met, and passing on new norms and goals accordingly. + from the external HTTP API (via ZMQ), transforming it into an AgentSpeak program, sharing the + program and its components to other agents, and keeping agents informed of the current state. :ivar sub_socket: The ZMQ SUB socket used to receive program updates. + :ivar _program: The current Program. + :ivar _phase: The current Phase. + :ivar _goal_mapping: A mapping of goal IDs to goals. """ _program: Program @@ -38,16 +44,28 @@ class BDIProgramManager(BaseAgent): def __init__(self, **kwargs): super().__init__(**kwargs) self.sub_socket = None + self._goal_mapping: dict[str, Goal] = {} def _initialize_internal_state(self, program: Program): + """ + Initialize the state of the program manager given a new Program. Reset the tracking of the + current phase to the first phase, make a mapping of goal IDs to goals, used during the life + of the program. + :param program: The new program. + """ self._program = program self._phase = program.phases[0] # start in first phase - self._goal_mapping: dict[str, Goal] = {} + self._goal_mapping = {} for phase in program.phases: for goal in phase.goals: self._populate_goal_mapping_with_goal(goal) def _populate_goal_mapping_with_goal(self, goal: Goal): + """ + Recurse through the given goal and its subgoals and add all goals found to the + ``self._goal_mapping``. + :param goal: The goal to add to the ``self._goal_mapping``, including subgoals. + """ self._goal_mapping[str(goal.id)] = goal for step in goal.plan.steps: if isinstance(step, Goal): @@ -63,7 +81,7 @@ class BDIProgramManager(BaseAgent): asl_str = asg.generate(program) - file_name = "src/control_backend/agents/bdi/agentspeak.asl" + file_name = settings.behaviour_settings.agentspeak_file with open(file_name, "w") as f: f.write(asl_str) @@ -88,6 +106,16 @@ class BDIProgramManager(BaseAgent): await self._send_achieved_goal_to_semantic_belief_extractor(goal_id) async def _transition_phase(self, old: str, new: str): + """ + When receiving a signal from the BDI core that the phase has changed, apply this change to + the current state and inform other agents about the change. + + :param old: The ID of the old phase. + :param new: The ID of the new phase. + """ + if self._phase is None: + return + if old != str(self._phase.id): self.logger.warning( f"Phase transition desync detected! ASL requested move from '{old}', " @@ -126,6 +154,13 @@ class BDIProgramManager(BaseAgent): self.add_behavior(self.send(msg)) def _extract_current_beliefs(self) -> list[Belief]: + """Extract beliefs from the current phase.""" + assert self._phase is not None, ( + "Invalid state, no phase set. Call this method only when " + "a program has been received and the end-phase has not " + "been reached." + ) + beliefs: list[Belief] = [] for norm in self._phase.norms: @@ -139,6 +174,7 @@ class BDIProgramManager(BaseAgent): @staticmethod def _extract_beliefs_from_belief(belief: Belief) -> list[Belief]: + """Recursively extract beliefs from the given belief.""" if isinstance(belief, InferredBelief): return BDIProgramManager._extract_beliefs_from_belief( belief.left @@ -146,9 +182,7 @@ class BDIProgramManager(BaseAgent): return [belief] async def _send_beliefs_to_semantic_belief_extractor(self): - """ - Extract beliefs from the program and send them to the Semantic Belief Extractor Agent. - """ + """Extract beliefs from the program and send them to the Semantic Belief Extractor Agent.""" beliefs = BeliefList(beliefs=self._extract_current_beliefs()) message = InternalMessage( @@ -168,9 +202,9 @@ class BDIProgramManager(BaseAgent): :return: All goals within and including the given goal. """ goals: list[Goal] = [goal] - for plan in goal.plan: - if isinstance(plan, Goal): - goals.extend(BDIProgramManager._extract_goals_from_goal(plan)) + for step in goal.plan.steps: + if isinstance(step, Goal): + goals.extend(BDIProgramManager._extract_goals_from_goal(step)) return goals def _extract_current_goals(self) -> list[Goal]: @@ -179,6 +213,12 @@ class BDIProgramManager(BaseAgent): :return: A list of Goal objects. """ + assert self._phase is not None, ( + "Invalid state, no phase set. Call this method only when " + "a program has been received and the end-phase has not " + "been reached." + ) + goals: list[Goal] = [] for goal in self._phase.goals: @@ -241,6 +281,18 @@ class BDIProgramManager(BaseAgent): await self.send(extractor_msg) self.logger.debug("Sent message to extractor agent to clear history.") + @staticmethod + def _rollover_experiment_logs(): + """ + A new experiment program started; make a new experiment log file. + """ + handlers = logging.getLogger(settings.logging_settings.experiment_logger_name).handlers + for handler in handlers: + if isinstance(handler, control_backend.logging.DatedFileHandler): + experiment_logger.action("Doing rollover...") + handler.do_rollover() + experiment_logger.debug("Finished rollover.") + async def _receive_programs(self): """ Continuous loop that receives program updates from the HTTP endpoint. @@ -261,6 +313,7 @@ class BDIProgramManager(BaseAgent): self._initialize_internal_state(program) await self._send_program_to_user_interrupt(program) await self._send_clear_llm_history() + self._rollover_experiment_logs() await asyncio.gather( self._create_agentspeak_and_send_to_bdi(program), diff --git a/src/control_backend/agents/bdi/text_belief_extractor_agent.py b/src/control_backend/agents/bdi/text_belief_extractor_agent.py index 9ea6b9a..4aaaec8 100644 --- a/src/control_backend/agents/bdi/text_belief_extractor_agent.py +++ b/src/control_backend/agents/bdi/text_belief_extractor_agent.py @@ -134,6 +134,10 @@ class TextBeliefExtractorAgent(BaseAgent): self.logger.warning("Received unexpected message from %s", msg.sender) def _reset_phase(self): + """ + Delete all state about the current phase, such as what beliefs exist and which ones are + true. + """ self.conversation = ChatHistory(messages=[]) self.belief_inferrer.available_beliefs.clear() self._current_beliefs = BeliefState() @@ -141,6 +145,11 @@ class TextBeliefExtractorAgent(BaseAgent): self._current_goal_completions = {} def _handle_beliefs_message(self, msg: InternalMessage): + """ + Handle the message from the Program Manager agent containing the beliefs that exist for this + phase. + :param msg: A list of beliefs. + """ try: belief_list = BeliefList.model_validate_json(msg.body) except ValidationError: @@ -158,6 +167,11 @@ class TextBeliefExtractorAgent(BaseAgent): ) def _handle_goals_message(self, msg: InternalMessage): + """ + Handle the message from the Program Manager agent containing the goals that exist for this + phase. + :param msg: A list of goals. + """ try: goals_list = GoalList.model_validate_json(msg.body) except ValidationError: @@ -177,6 +191,11 @@ class TextBeliefExtractorAgent(BaseAgent): ) def _handle_goal_achieved_message(self, msg: InternalMessage): + """ + Handle message that gets sent when goals are marked achieved from a user interrupt. This + goal should then not be changed by this agent anymore. + :param msg: List of goals that are marked achieved. + """ # NOTE: When goals can be marked unachieved, remember to re-add them to the goal_inferrer try: goals_list = GoalList.model_validate_json(msg.body) @@ -210,6 +229,10 @@ class TextBeliefExtractorAgent(BaseAgent): await self.send(belief_msg) async def _infer_new_beliefs(self): + """ + Determine which beliefs hold and do not hold for the current conversation state. When + beliefs change, a message is sent to the BDI core. + """ conversation_beliefs = await self.belief_inferrer.infer_from_conversation(self.conversation) new_beliefs = conversation_beliefs - self._current_beliefs @@ -233,6 +256,10 @@ class TextBeliefExtractorAgent(BaseAgent): await self.send(message) async def _infer_goal_completions(self): + """ + Determine which goals have been achieved given the current conversation state. When + a goal's achieved state changes, a message is sent to the BDI core. + """ goal_completions = await self.goal_inferrer.infer_from_conversation(self.conversation) new_achieved = [ @@ -377,19 +404,22 @@ class SemanticBeliefInferrer: for beliefs in self._split_into_chunks(self.available_beliefs, n_parallel) ] ) - retval = BeliefState() + new_beliefs = BeliefState() + # Collect beliefs from all parallel calls for beliefs in all_beliefs: if beliefs is None: continue + # For each, convert them to InternalBeliefs for belief_name, belief_holds in beliefs.items(): + # Skip beliefs that were marked not possible to determine if belief_holds is None: continue belief = InternalBelief(name=belief_name, arguments=None) if belief_holds: - retval.true.add(belief) + new_beliefs.true.add(belief) else: - retval.false.add(belief) - return retval + new_beliefs.false.add(belief) + return new_beliefs @staticmethod def _split_into_chunks[T](items: list[T], n: int) -> list[list[T]]: diff --git a/src/control_backend/agents/communication/ri_communication_agent.py b/src/control_backend/agents/communication/ri_communication_agent.py index 746705c..1072a96 100644 --- a/src/control_backend/agents/communication/ri_communication_agent.py +++ b/src/control_backend/agents/communication/ri_communication_agent.py @@ -146,7 +146,7 @@ class RICommunicationAgent(BaseAgent): # At this point, we have a valid response try: - self.logger.debug("Negotiation successful. Handling rn") + self.logger.debug("Negotiation successful.") await self._handle_negotiation_response(received_message) # Let UI know that we're connected topic = b"ping" diff --git a/src/control_backend/agents/llm/llm_agent.py b/src/control_backend/agents/llm/llm_agent.py index 8d81249..08a77e3 100644 --- a/src/control_backend/agents/llm/llm_agent.py +++ b/src/control_backend/agents/llm/llm_agent.py @@ -1,5 +1,6 @@ import asyncio import json +import logging import re import uuid from collections.abc import AsyncGenerator @@ -14,6 +15,8 @@ from control_backend.core.config import settings from ...schemas.llm_prompt_message import LLMPromptMessage from .llm_instructions import LLMInstructions +experiment_logger = logging.getLogger(settings.logging_settings.experiment_logger_name) + class LLMAgent(BaseAgent): """ @@ -170,7 +173,7 @@ class LLMAgent(BaseAgent): *self.history, ] - message_id = str(uuid.uuid4()) # noqa + message_id = str(uuid.uuid4()) try: full_message = "" @@ -179,10 +182,9 @@ class LLMAgent(BaseAgent): full_message += token current_chunk += token - self.logger.llm( - "Received token: %s", + experiment_logger.chat( full_message, - extra={"reference": message_id}, # Used in the UI to update old logs + extra={"role": "assistant", "reference": message_id, "partial": True}, ) # Stream the message in chunks separated by punctuation. @@ -197,6 +199,11 @@ class LLMAgent(BaseAgent): # Yield any remaining tail if current_chunk: yield current_chunk + + experiment_logger.chat( + full_message, + extra={"role": "assistant", "reference": message_id, "partial": False}, + ) except httpx.HTTPError as err: self.logger.error("HTTP error.", exc_info=err) yield "LLM service unavailable." diff --git a/src/control_backend/agents/perception/transcription_agent/transcription_agent.py b/src/control_backend/agents/perception/transcription_agent/transcription_agent.py index 795623d..e69fea6 100644 --- a/src/control_backend/agents/perception/transcription_agent/transcription_agent.py +++ b/src/control_backend/agents/perception/transcription_agent/transcription_agent.py @@ -1,4 +1,5 @@ import asyncio +import logging import numpy as np import zmq @@ -10,6 +11,8 @@ from control_backend.core.config import settings from .speech_recognizer import SpeechRecognizer +experiment_logger = logging.getLogger(settings.logging_settings.experiment_logger_name) + class TranscriptionAgent(BaseAgent): """ @@ -25,6 +28,8 @@ class TranscriptionAgent(BaseAgent): :ivar audio_in_socket: The ZMQ SUB socket instance. :ivar speech_recognizer: The speech recognition engine instance. :ivar _concurrency: Semaphore to limit concurrent transcriptions. + :ivar _current_speech_reference: The reference of the current user utterance, for synchronising + experiment logs. """ def __init__(self, audio_in_address: str): @@ -39,6 +44,7 @@ class TranscriptionAgent(BaseAgent): self.audio_in_socket: azmq.Socket | None = None self.speech_recognizer = None self._concurrency = None + self._current_speech_reference: str | None = None async def setup(self): """ @@ -63,6 +69,10 @@ class TranscriptionAgent(BaseAgent): self.logger.info("Finished setting up %s", self.name) + async def handle_message(self, msg: InternalMessage): + if msg.thread == "voice_activity": + self._current_speech_reference = msg.body + async def stop(self): """ Stop the agent and close sockets. @@ -96,24 +106,25 @@ class TranscriptionAgent(BaseAgent): async def _share_transcription(self, transcription: str): """ - Share a transcription to the other agents that depend on it. + Share a transcription to the other agents that depend on it, and to experiment logs. Currently sends to: - :attr:`settings.agent_settings.text_belief_extractor_name` + - The UI via the experiment logger :param transcription: The transcribed text. """ - receiver_names = [ - settings.agent_settings.text_belief_extractor_name, - ] + experiment_logger.chat( + transcription, + extra={"role": "user", "reference": self._current_speech_reference, "partial": False}, + ) - for receiver_name in receiver_names: - message = InternalMessage( - to=receiver_name, - sender=self.name, - body=transcription, - ) - await self.send(message) + message = InternalMessage( + to=settings.agent_settings.text_belief_extractor_name, + sender=self.name, + body=transcription, + ) + await self.send(message) async def _transcribing_loop(self) -> None: """ @@ -129,10 +140,9 @@ class TranscriptionAgent(BaseAgent): audio = np.frombuffer(audio_data, dtype=np.float32) speech = await self._transcribe(audio) if not speech: - self.logger.info("Nothing transcribed.") + self.logger.debug("Nothing transcribed.") continue - self.logger.info("Transcribed speech: %s", speech) await self._share_transcription(speech) except Exception as e: self.logger.error(f"Error in transcription loop: {e}") diff --git a/src/control_backend/agents/perception/vad_agent.py b/src/control_backend/agents/perception/vad_agent.py index 2b333f5..f397563 100644 --- a/src/control_backend/agents/perception/vad_agent.py +++ b/src/control_backend/agents/perception/vad_agent.py @@ -1,4 +1,6 @@ import asyncio +import logging +import uuid import numpy as np import torch @@ -12,6 +14,8 @@ from control_backend.schemas.internal_message import InternalMessage from ...schemas.program_status import PROGRAM_STATUS, ProgramStatus from .transcription_agent.transcription_agent import TranscriptionAgent +experiment_logger = logging.getLogger(settings.logging_settings.experiment_logger_name) + class SocketPoller[T]: """ @@ -252,6 +256,18 @@ class VADAgent(BaseAgent): if prob > prob_threshold: if self.i_since_speech > non_speech_patience + begin_silence_length: self.logger.debug("Speech started.") + reference = str(uuid.uuid4()) + experiment_logger.chat( + "...", + extra={"role": "user", "reference": reference, "partial": True}, + ) + await self.send( + InternalMessage( + to=settings.agent_settings.transcription_name, + body=reference, + thread="voice_activity", + ) + ) self.audio_buffer = np.append(self.audio_buffer, chunk) self.i_since_speech = 0 continue @@ -269,9 +285,10 @@ class VADAgent(BaseAgent): assert self.audio_out_socket is not None await self.audio_out_socket.send(self.audio_buffer[: -2 * len(chunk)].tobytes()) - # At this point, we know that the speech has ended. - # Prepend the last chunk that had no speech, for a more fluent boundary - self.audio_buffer = chunk + # At this point, we know that there is no speech. + # Prepend the last few chunks that had no speech, for a more fluent boundary. + self.audio_buffer = np.append(self.audio_buffer, chunk) + self.audio_buffer = self.audio_buffer[-begin_silence_length * len(chunk) :] async def handle_message(self, msg: InternalMessage): """ diff --git a/src/control_backend/agents/user_interrupt/user_interrupt_agent.py b/src/control_backend/agents/user_interrupt/user_interrupt_agent.py index 6138495..7cf67c0 100644 --- a/src/control_backend/agents/user_interrupt/user_interrupt_agent.py +++ b/src/control_backend/agents/user_interrupt/user_interrupt_agent.py @@ -1,4 +1,5 @@ import json +import logging import zmq from zmq.asyncio import Context @@ -8,13 +9,15 @@ from control_backend.agents.bdi.agentspeak_generator import AgentSpeakGenerator from control_backend.core.agent_system import InternalMessage from control_backend.core.config import settings from control_backend.schemas.belief_message import Belief, BeliefMessage -from control_backend.schemas.program import ConditionalNorm, Program +from control_backend.schemas.program import ConditionalNorm, Goal, Program from control_backend.schemas.ri_message import ( GestureCommand, RIEndpoint, SpeechCommand, ) +experiment_logger = logging.getLogger(settings.logging_settings.experiment_logger_name) + class UserInterruptAgent(BaseAgent): """ @@ -245,6 +248,16 @@ class UserInterruptAgent(BaseAgent): self._cond_norm_map = {} self._cond_norm_reverse_map = {} + def _register_goal(goal: Goal): + """Recursively register goals and their subgoals.""" + slug = AgentSpeakGenerator.slugify(goal) + self._goal_map[str(goal.id)] = slug + self._goal_reverse_map[slug] = str(goal.id) + + for step in goal.plan.steps: + if isinstance(step, Goal): + _register_goal(step) + for phase in program.phases: for trigger in phase.triggers: slug = AgentSpeakGenerator.slugify(trigger) @@ -252,8 +265,7 @@ class UserInterruptAgent(BaseAgent): self._trigger_reverse_map[slug] = str(trigger.id) for goal in phase.goals: - self._goal_map[str(goal.id)] = AgentSpeakGenerator.slugify(goal) - self._goal_reverse_map[AgentSpeakGenerator.slugify(goal)] = str(goal.id) + _register_goal(goal) for goal, id in self._goal_reverse_map.items(): self.logger.debug(f"Goal mapping: UI ID {goal} -> {id}") @@ -295,6 +307,7 @@ class UserInterruptAgent(BaseAgent): :param text_to_say: The string that the robot has to say. """ + experiment_logger.chat(text_to_say, extra={"role": "assistant"}) cmd = SpeechCommand(data=text_to_say, is_priority=True) out_msg = InternalMessage( to=settings.agent_settings.robot_speech_name, @@ -334,6 +347,7 @@ class UserInterruptAgent(BaseAgent): belief_name = f"force_{asl}" else: self.logger.warning("Tried to send belief with unknown type") + return belief = Belief(name=belief_name, arguments=None) self.logger.debug(f"Sending belief to BDI Core: {belief_name}") # Conditional norms are unachieved by removing the belief diff --git a/src/control_backend/api/v1/endpoints/logs.py b/src/control_backend/api/v1/endpoints/logs.py index ccccf44..0e2dff9 100644 --- a/src/control_backend/api/v1/endpoints/logs.py +++ b/src/control_backend/api/v1/endpoints/logs.py @@ -1,8 +1,9 @@ import logging +from pathlib import Path import zmq -from fastapi import APIRouter -from fastapi.responses import StreamingResponse +from fastapi import APIRouter, HTTPException +from fastapi.responses import FileResponse, StreamingResponse from zmq.asyncio import Context from control_backend.core.config import settings @@ -38,3 +39,29 @@ async def log_stream(): yield f"data: {message}\n\n" return StreamingResponse(gen(), media_type="text/event-stream") + + +LOGGING_DIR = Path(settings.logging_settings.experiment_log_directory).resolve() + + +@router.get("/logs/files") +@router.get("/api/logs/files") +async def log_directory(): + """ + Get a list of all log files stored in the experiment log file directory. + """ + return [f.name for f in LOGGING_DIR.glob("*.log")] + + +@router.get("/logs/files/{filename}") +@router.get("/api/logs/files/{filename}") +async def log_file(filename: str): + # Prevent path-traversal + file_path = (LOGGING_DIR / filename).resolve() # This .resolve() is important + if not file_path.is_relative_to(LOGGING_DIR): + raise HTTPException(status_code=400, detail="Invalid filename.") + + if not file_path.is_file(): + raise HTTPException(status_code=404, detail="File not found.") + + return FileResponse(file_path, filename=file_path.name) diff --git a/src/control_backend/core/config.py b/src/control_backend/core/config.py index a9d22a9..57ab5d7 100644 --- a/src/control_backend/core/config.py +++ b/src/control_backend/core/config.py @@ -82,6 +82,8 @@ class BehaviourSettings(BaseModel): emotions and update emotion beliefs. :ivar visual_emotion_recognition_min_frames_per_face: Minimum number of frames per face required to consider a face valid. + :ivar trigger_time_to_wait: Amount of milliseconds to wait before informing the UI about trigger + completion. """ # ATTENTION: When adding/removing settings, make sure to update the .env.example file @@ -108,6 +110,10 @@ class BehaviourSettings(BaseModel): # Visual Emotion Recognition settings visual_emotion_recognition_window_duration_s: int = 5 visual_emotion_recognition_min_frames_per_face: int = 3 + # AgentSpeak related settings + trigger_time_to_wait: int = 2000 + agentspeak_file: str = "src/control_backend/agents/bdi/agentspeak.asl" + class LLMSettings(BaseModel): """ @@ -162,6 +168,20 @@ class SpeechModelSettings(BaseModel): openai_model_name: str = "small.en" +class LoggingSettings(BaseModel): + """ + Configuration for logging. + + :ivar logging_config_file: Path to the logging configuration file. + :ivar experiment_log_directory: Location of the experiment logs. Must match the logging config. + :ivar experiment_logger_name: Name of the experiment logger. Must match the logging config. + """ + + logging_config_file: str = ".logging_config.yaml" + experiment_log_directory: str = "experiment_logs" + experiment_logger_name: str = "experiment" + + class Settings(BaseSettings): """ Global application settings. @@ -183,6 +203,8 @@ class Settings(BaseSettings): ri_host: str = "localhost" + logging_settings: LoggingSettings = LoggingSettings() + zmq_settings: ZMQSettings = ZMQSettings() agent_settings: AgentSettings = AgentSettings() diff --git a/src/control_backend/logging/__init__.py b/src/control_backend/logging/__init__.py index c97af40..a08e9b8 100644 --- a/src/control_backend/logging/__init__.py +++ b/src/control_backend/logging/__init__.py @@ -1 +1,4 @@ +from .dated_file_handler import DatedFileHandler as DatedFileHandler +from .optional_field_formatter import OptionalFieldFormatter as OptionalFieldFormatter +from .partial_filter import PartialFilter as PartialFilter from .setup_logging import setup_logging as setup_logging diff --git a/src/control_backend/logging/dated_file_handler.py b/src/control_backend/logging/dated_file_handler.py new file mode 100644 index 0000000..3a405bb --- /dev/null +++ b/src/control_backend/logging/dated_file_handler.py @@ -0,0 +1,38 @@ +from datetime import datetime +from logging import FileHandler +from pathlib import Path + + +class DatedFileHandler(FileHandler): + def __init__(self, file_prefix: str, **kwargs): + if not file_prefix: + raise ValueError("`file_prefix` argument cannot be empty.") + self._file_prefix = file_prefix + kwargs["filename"] = self._make_filename() + super().__init__(**kwargs) + + def _make_filename(self) -> str: + """ + Create the filename for the current logfile, using the configured file prefix and the + current date and time. If the directory does not exist, it gets created. + + :return: A filepath. + """ + filepath = Path(f"{self._file_prefix}-{datetime.now():%Y%m%d-%H%M%S}.log") + if not filepath.parent.is_dir(): + filepath.parent.mkdir(parents=True, exist_ok=True) + return str(filepath) + + def do_rollover(self): + """ + Close the current logfile and create a new one with the current date and time. + """ + self.acquire() + try: + if self.stream: + self.stream.close() + + self.baseFilename = self._make_filename() + self.stream = self._open() + finally: + self.release() diff --git a/src/control_backend/logging/optional_field_formatter.py b/src/control_backend/logging/optional_field_formatter.py new file mode 100644 index 0000000..886e9a4 --- /dev/null +++ b/src/control_backend/logging/optional_field_formatter.py @@ -0,0 +1,67 @@ +import logging +import re + + +class OptionalFieldFormatter(logging.Formatter): + """ + Logging formatter that supports optional fields marked by `?`. + + Optional fields are denoted by placing a `?` after the field name inside + the parentheses, e.g., `%(role?)s`. If the field is not provided in the + log call's `extra` dict, it will use the default value from `defaults` + or `None` if no default is specified. + + :param fmt: Format string with optional `%(name?)s` style fields. + :type fmt: str or None + :param datefmt: Date format string, passed to parent :class:`logging.Formatter`. + :type datefmt: str or None + :param style: Formatting style, must be '%'. Passed to parent. + :type style: str + :param defaults: Default values for optional fields when not provided. + :type defaults: dict or None + + :example: + + >>> formatter = OptionalFieldFormatter( + ... fmt="%(asctime)s %(levelname)s %(role?)s %(message)s", + ... defaults={"role": ""-""} + ... ) + >>> handler = logging.StreamHandler() + >>> handler.setFormatter(formatter) + >>> logger = logging.getLogger(__name__) + >>> logger.addHandler(handler) + >>> + >>> logger.chat("Hello there!", extra={"role": "USER"}) + 2025-01-15 10:30:00 CHAT USER Hello there! + >>> + >>> logger.info("A logging message") + 2025-01-15 10:30:01 INFO - A logging message + + .. note:: + Only `%`-style formatting is supported. The `{` and `$` styles are not + compatible with this formatter. + + .. seealso:: + :class:`logging.Formatter` for base formatter documentation. + """ + + # Match %(name?)s or %(name?)d etc. + OPTIONAL_PATTERN = re.compile(r"%\((\w+)\?\)([sdifFeEgGxXocrba%])") + + def __init__(self, fmt=None, datefmt=None, style="%", defaults=None): + self.defaults = defaults or {} + + self.optional_fields = set(self.OPTIONAL_PATTERN.findall(fmt or "")) + + # Convert %(name?)s to %(name)s for standard formatting + normalized_fmt = self.OPTIONAL_PATTERN.sub(r"%(\1)\2", fmt or "") + + super().__init__(normalized_fmt, datefmt, style) + + def format(self, record): + for field, _ in self.optional_fields: + if not hasattr(record, field): + default = self.defaults.get(field, None) + setattr(record, field, default) + + return super().format(record) diff --git a/src/control_backend/logging/partial_filter.py b/src/control_backend/logging/partial_filter.py new file mode 100644 index 0000000..1b121cb --- /dev/null +++ b/src/control_backend/logging/partial_filter.py @@ -0,0 +1,10 @@ +import logging + + +class PartialFilter(logging.Filter): + """ + Class to filter any log records that have the "partial" attribute set to ``True``. + """ + + def filter(self, record): + return getattr(record, "partial", False) is not True diff --git a/src/control_backend/logging/setup_logging.py b/src/control_backend/logging/setup_logging.py index 05ae85a..7147dcc 100644 --- a/src/control_backend/logging/setup_logging.py +++ b/src/control_backend/logging/setup_logging.py @@ -37,7 +37,7 @@ def add_logging_level(level_name: str, level_num: int, method_name: str | None = setattr(logging, method_name, log_to_root) -def setup_logging(path: str = ".logging_config.yaml") -> None: +def setup_logging(path: str = settings.logging_settings.logging_config_file) -> None: """ Setup logging configuration of the CB. Tries to load the logging configuration from a file, in which we specify custom loggers, formatters, handlers, etc. @@ -65,7 +65,7 @@ def setup_logging(path: str = ".logging_config.yaml") -> None: # Patch ZMQ PUBHandler to know about custom levels if custom_levels: - for logger_name in ("control_backend",): + for logger_name in config.get("loggers", {}): logger = logging.getLogger(logger_name) for handler in logger.handlers: if isinstance(handler, PUBHandler): diff --git a/src/control_backend/schemas/program.py b/src/control_backend/schemas/program.py index 9bc6e0d..15e0bc3 100644 --- a/src/control_backend/schemas/program.py +++ b/src/control_backend/schemas/program.py @@ -22,6 +22,13 @@ class ProgramElement(BaseModel): class LogicalOperator(Enum): """ Logical operators for combining beliefs. + + These operators define how beliefs can be combined to form more complex + logical conditions. They are used in inferred beliefs to create compound + beliefs from simpler ones. + + AND: Both operands must be true for the result to be true. + OR: At least one operand must be true for the result to be true. """ AND = "AND" @@ -36,7 +43,15 @@ class KeywordBelief(ProgramElement): """ Represents a belief that is activated when a specific keyword is detected in the user's speech. + Keyword beliefs provide a simple but effective way to detect specific topics + or intentions in user speech. They are triggered when the exact keyword + string appears in the transcribed user input. + :ivar keyword: The string to look for in the transcription. + + Example: + A keyword belief with keyword="robot" would be activated when the user + says "I like the robot" or "Tell me about robots". """ name: str = "" @@ -48,8 +63,18 @@ class SemanticBelief(ProgramElement): Represents a belief whose truth value is determined by an LLM analyzing the conversation context. + Semantic beliefs provide more sophisticated belief detection by using + an LLM to analyze the conversation context and determine + if the belief should be considered true. This allows for more nuanced + and context-aware belief evaluation. + :ivar description: A natural language description of what this belief represents, used as a prompt for the LLM. + + Example: + A semantic belief with description="The user is expressing frustration" + would be activated when the LLM determines that the user's statements + indicate frustration, even if no specific keywords are used. """ description: str @@ -59,6 +84,11 @@ class InferredBelief(ProgramElement): """ Represents a belief derived from other beliefs using logical operators. + Inferred beliefs allow for the creation of complex belief structures by + combining simpler beliefs using logical operators. This enables the + representation of sophisticated conditions and relationships between + different aspects of the conversation or context. + :ivar operator: The :class:`LogicalOperator` (AND/OR) to apply. :ivar left: The left operand (another belief). :ivar right: The right operand (another belief). @@ -83,8 +113,16 @@ class Norm(ProgramElement): """ Base class for behavioral norms that guide the robot's interactions. + Norms represent guidelines, principles, or rules that should govern the + robot's behavior during interactions. They can be either basic (always + active in their phase) or conditional (active only when specific beliefs + are true). + :ivar norm: The textual description of the norm. :ivar critical: Whether this norm is considered critical and should be strictly enforced. + + Critical norms are currently not supported yet, but are intended for norms that should + ABSOLUTELY NOT be violated, possible cheched by additional validator agents. """ name: str = "" @@ -95,6 +133,13 @@ class Norm(ProgramElement): class BasicNorm(Norm): """ A simple behavioral norm that is always considered for activation when its phase is active. + + Basic norms are the most straightforward type of norms. They are active + throughout their assigned phase and provide consistent behavioral guidance + without any additional conditions. + + These norms are suitable for general principles that should always apply + during a particular interaction phase. """ pass @@ -104,7 +149,20 @@ class ConditionalNorm(Norm): """ A behavioral norm that is only active when a specific condition (belief) is met. + Conditional norms provide context-sensitive behavioral guidance. They are + only active and considered for activation when their associated condition + (belief) is true. This allows for more nuanced and adaptive behavior that + responds to the specific context of the interaction. + + An important note, is that the current implementation of these norms for keyword-based beliefs + is that they only hold for 1 turn, as keyword-based conditions often express temporary + conditions. + :ivar condition: The :class:`Belief` that must hold for this norm to be active. + + Example: + A conditional norm with the condition "user is frustrated" might specify + that the robot should use more empathetic language and avoid complex topics. """ condition: Belief @@ -116,7 +174,12 @@ type PlanElement = Goal | Action class Plan(ProgramElement): """ Represents a list of steps to execute. Each of these steps can be a goal (with its own plan) - or a simple action. + or a simple action. + + Plans define sequences of actions and subgoals that the robot should execute + to achieve a particular objective. They form the procedural knowledge of + the behavior program, specifying what the robot should do in different + situations. :ivar steps: The actions or subgoals to execute, in order. """ @@ -132,6 +195,10 @@ class BaseGoal(ProgramElement): :ivar description: A description of the goal, used to determine if it has been achieved. :ivar can_fail: Whether we can fail to achieve the goal after executing the plan. + + The can_fail attribute determines whether goal achievement is binary + (success/failure) or whether it can be determined through conversation + analysis. """ description: str = "" @@ -141,9 +208,13 @@ class BaseGoal(ProgramElement): class Goal(BaseGoal): """ Represents an objective to be achieved. To reach the goal, we should execute the corresponding - plan. It inherits from the BaseGoal a variable `can_fail`, which if true will cause the + plan. It inherits from the BaseGoal a variable `can_fail`, which, if true, will cause the completion to be determined based on the conversation. + Goals extend base goals by including a specific plan to achieve the objective. + They form the core of the robot's proactive behavior, defining both what + should be accomplished and how to accomplish it. + Instances of this goal are not hashable because a plan is not hashable. :ivar plan: The plan to execute. @@ -172,6 +243,10 @@ class Gesture(BaseModel): :ivar type: Whether to use a specific "single" gesture or a random one from a "tag" category. :ivar name: The identifier for the gesture or tag. + + The type field determines how the gesture is selected: + - "single": Use the specific gesture identified by name + - "tag": Select a random gesture from the category identified by name """ type: Literal["tag", "single"] @@ -194,6 +269,10 @@ class LLMAction(ProgramElement): An action that triggers an LLM-generated conversational response. :ivar goal: A temporary conversational goal to guide the LLM's response generation. + + The goal parameter provides high-level guidance to the LLM about what + the response should aim to achieve, while allowing the LLM flexibility + in how to express it. """ name: str = "" @@ -231,6 +310,10 @@ class Program(BaseModel): """ The top-level container for a complete robot behavior definition. + The Program class represents the complete specification of a robot's + behavioral logic. It contains all the phases, norms, goals, triggers, + and actions that define how the robot should behave during interactions. + :ivar phases: An ordered list of :class:`Phase` objects defining the interaction flow. """ diff --git a/test/unit/agents/actuation/test_robot_gesture_agent.py b/test/unit/agents/actuation/test_robot_gesture_agent.py index 1e6fd8a..20d7d51 100644 --- a/test/unit/agents/actuation/test_robot_gesture_agent.py +++ b/test/unit/agents/actuation/test_robot_gesture_agent.py @@ -1,5 +1,5 @@ import json -from unittest.mock import AsyncMock, MagicMock +from unittest.mock import AsyncMock, MagicMock, patch import pytest import zmq @@ -19,6 +19,12 @@ def zmq_context(mocker): return mock_context +@pytest.fixture(autouse=True) +def mock_experiment_logger(): + with patch("control_backend.agents.actuation.robot_gesture_agent.experiment_logger") as logger: + yield logger + + @pytest.mark.asyncio async def test_setup_bind(zmq_context, mocker): """Setup binds and subscribes to internal commands.""" diff --git a/test/unit/agents/bdi/test_bdi_core_agent.py b/test/unit/agents/bdi/test_bdi_core_agent.py index 6245d5b..1bf0107 100644 --- a/test/unit/agents/bdi/test_bdi_core_agent.py +++ b/test/unit/agents/bdi/test_bdi_core_agent.py @@ -26,6 +26,12 @@ def agent(): return agent +@pytest.fixture(autouse=True) +def mock_experiment_logger(): + with patch("control_backend.agents.bdi.bdi_core_agent.experiment_logger") as logger: + yield logger + + @pytest.mark.asyncio async def test_setup_loads_asl(mock_agentspeak_env, agent): # Mock file opening diff --git a/test/unit/agents/bdi/test_bdi_program_manager.py b/test/unit/agents/bdi/test_bdi_program_manager.py index 540a172..646075b 100644 --- a/test/unit/agents/bdi/test_bdi_program_manager.py +++ b/test/unit/agents/bdi/test_bdi_program_manager.py @@ -8,7 +8,17 @@ import pytest from control_backend.agents.bdi.bdi_program_manager import BDIProgramManager from control_backend.core.agent_system import InternalMessage -from control_backend.schemas.program import BasicNorm, Goal, Phase, Plan, Program +from control_backend.schemas.program import ( + BasicNorm, + ConditionalNorm, + Goal, + InferredBelief, + KeywordBelief, + Phase, + Plan, + Program, + Trigger, +) # Fix Windows Proactor loop for zmq if sys.platform.startswith("win"): @@ -59,7 +69,7 @@ async def test_create_agentspeak_and_send_to_bdi(mock_settings): await manager._create_agentspeak_and_send_to_bdi(program) # Check file writing - mock_file.assert_called_with("src/control_backend/agents/bdi/agentspeak.asl", "w") + mock_file.assert_called_with(mock_settings.behaviour_settings.agentspeak_file, "w") handle = mock_file() handle.write.assert_called() @@ -67,7 +77,7 @@ async def test_create_agentspeak_and_send_to_bdi(mock_settings): msg: InternalMessage = manager.send.await_args[0][0] assert msg.thread == "new_program" assert msg.to == mock_settings.agent_settings.bdi_core_name - assert msg.body == "src/control_backend/agents/bdi/agentspeak.asl" + assert msg.body == mock_settings.behaviour_settings.agentspeak_file @pytest.mark.asyncio @@ -295,3 +305,98 @@ async def test_setup(mock_settings): # 3. Adds behavior manager.add_behavior.assert_called() + + +@pytest.mark.asyncio +async def test_send_program_to_user_interrupt(mock_settings): + """Test directly sending the program to the user interrupt agent.""" + mock_settings.agent_settings.user_interrupt_name = "user_interrupt_agent" + + manager = BDIProgramManager(name="program_manager_test") + manager.send = AsyncMock() + + program = Program.model_validate_json(make_valid_program_json()) + + await manager._send_program_to_user_interrupt(program) + + assert manager.send.await_count == 1 + msg = manager.send.await_args[0][0] + assert msg.to == "user_interrupt_agent" + assert msg.thread == "new_program" + assert "Basic Phase" in msg.body + + +@pytest.mark.asyncio +async def test_complex_program_extraction(): + manager = BDIProgramManager(name="program_manager_test") + + # 1. Create Complex Components + + # Inferred Belief (A & B) + belief_left = KeywordBelief(id=uuid.uuid4(), name="b1", keyword="hot") + belief_right = KeywordBelief(id=uuid.uuid4(), name="b2", keyword="sunny") + inferred_belief = InferredBelief( + id=uuid.uuid4(), name="b_inf", operator="AND", left=belief_left, right=belief_right + ) + + # Conditional Norm + cond_norm = ConditionalNorm( + id=uuid.uuid4(), name="norm_cond", norm="wear_hat", condition=inferred_belief + ) + + # Trigger with Inferred Belief condition + dummy_plan = Plan(id=uuid.uuid4(), name="dummy_plan", steps=[]) + trigger = Trigger(id=uuid.uuid4(), name="trigger_1", condition=inferred_belief, plan=dummy_plan) + + # Nested Goal + sub_goal = Goal( + id=uuid.uuid4(), + name="sub_goal", + description="desc", + plan=Plan(id=uuid.uuid4(), name="empty", steps=[]), + can_fail=True, + ) + + parent_goal = Goal( + id=uuid.uuid4(), + name="parent_goal", + description="desc", + # The plan contains the sub_goal as a step + plan=Plan(id=uuid.uuid4(), name="parent_plan", steps=[sub_goal]), + can_fail=False, + ) + + # 2. Assemble Program + phase = Phase( + id=uuid.uuid4(), + name="Complex Phase", + norms=[cond_norm], + goals=[parent_goal], + triggers=[trigger], + ) + program = Program(phases=[phase]) + + # 3. Initialize Internal State (Triggers _populate_goal_mapping -> Nested Goal logic) + manager._initialize_internal_state(program) + + # Assertion for Line 53-54 (Mapping population) + # Both parent and sub-goal should be mapped + assert str(parent_goal.id) in manager._goal_mapping + assert str(sub_goal.id) in manager._goal_mapping + + # 4. Test Belief Extraction (Triggers lines 132-133, 142-146) + beliefs = manager._extract_current_beliefs() + + # Should extract recursive beliefs from cond_norm and trigger + # Inferred belief splits into Left + Right. Since we use it twice, we get duplicates + # checking existence is enough. + belief_names = [b.name for b in beliefs] + assert "b1" in belief_names + assert "b2" in belief_names + + # 5. Test Goal Extraction (Triggers lines 173, 185) + goals = manager._extract_current_goals() + + goal_names = [g.name for g in goals] + assert "parent_goal" in goal_names + assert "sub_goal" in goal_names diff --git a/test/unit/agents/llm/test_llm_agent.py b/test/unit/agents/llm/test_llm_agent.py index bd407cc..bbd6e93 100644 --- a/test/unit/agents/llm/test_llm_agent.py +++ b/test/unit/agents/llm/test_llm_agent.py @@ -18,6 +18,12 @@ def mock_httpx_client(): yield mock_client +@pytest.fixture(autouse=True) +def mock_experiment_logger(): + with patch("control_backend.agents.llm.llm_agent.experiment_logger") as logger: + yield logger + + @pytest.mark.asyncio async def test_llm_processing_success(mock_httpx_client, mock_settings): # Setup the mock response for the stream diff --git a/test/unit/agents/perception/transcription_agent/test_transcription_agent.py b/test/unit/agents/perception/transcription_agent/test_transcription_agent.py index 57875ca..f5a4d1c 100644 --- a/test/unit/agents/perception/transcription_agent/test_transcription_agent.py +++ b/test/unit/agents/perception/transcription_agent/test_transcription_agent.py @@ -14,6 +14,15 @@ from control_backend.agents.perception.transcription_agent.transcription_agent i ) +@pytest.fixture(autouse=True) +def mock_experiment_logger(): + with patch( + "control_backend.agents.perception" + ".transcription_agent.transcription_agent.experiment_logger" + ) as logger: + yield logger + + @pytest.mark.asyncio async def test_transcription_agent_flow(mock_zmq_context): mock_sub = MagicMock() diff --git a/test/unit/agents/perception/vad_agent/test_vad_streaming.py b/test/unit/agents/perception/vad_agent/test_vad_streaming.py index 349fab2..b53f63d 100644 --- a/test/unit/agents/perception/vad_agent/test_vad_streaming.py +++ b/test/unit/agents/perception/vad_agent/test_vad_streaming.py @@ -24,7 +24,9 @@ def audio_out_socket(): @pytest.fixture def vad_agent(audio_out_socket): - return VADAgent("tcp://localhost:5555", False) + agent = VADAgent("tcp://localhost:5555", False) + agent._internal_pub_socket = AsyncMock() + return agent @pytest.fixture(autouse=True) @@ -44,6 +46,12 @@ def patch_settings(monkeypatch): monkeypatch.setattr(vad_agent.settings.vad_settings, "sample_rate_hz", 16_000, raising=False) +@pytest.fixture(autouse=True) +def mock_experiment_logger(): + with patch("control_backend.agents.perception.vad_agent.experiment_logger") as logger: + yield logger + + async def simulate_streaming_with_probabilities(streaming, probabilities: list[float]): """ Simulates a streaming scenario with given VAD model probabilities for testing purposes. @@ -84,14 +92,15 @@ async def test_voice_activity_detected(audio_out_socket, vad_agent): Test a scenario where there is voice activity detected between silences. """ speech_chunk_count = 5 - probabilities = [0.0] * 5 + [1.0] * speech_chunk_count + [0.0] * 5 + begin_silence_chunks = settings.behaviour_settings.vad_begin_silence_chunks + probabilities = [0.0] * 15 + [1.0] * speech_chunk_count + [0.0] * 5 vad_agent.audio_out_socket = audio_out_socket await simulate_streaming_with_probabilities(vad_agent, probabilities) audio_out_socket.send.assert_called_once() data = audio_out_socket.send.call_args[0][0] assert isinstance(data, bytes) - assert len(data) == 512 * 4 * (speech_chunk_count + 1) + assert len(data) == 512 * 4 * (begin_silence_chunks + speech_chunk_count) @pytest.mark.asyncio @@ -101,8 +110,9 @@ async def test_voice_activity_short_pause(audio_out_socket, vad_agent): short pause. """ speech_chunk_count = 5 + begin_silence_chunks = settings.behaviour_settings.vad_begin_silence_chunks probabilities = ( - [0.0] * 5 + [1.0] * speech_chunk_count + [0.0] + [1.0] * speech_chunk_count + [0.0] * 5 + [0.0] * 15 + [1.0] * speech_chunk_count + [0.0] + [1.0] * speech_chunk_count + [0.0] * 5 ) vad_agent.audio_out_socket = audio_out_socket await simulate_streaming_with_probabilities(vad_agent, probabilities) @@ -110,8 +120,8 @@ async def test_voice_activity_short_pause(audio_out_socket, vad_agent): audio_out_socket.send.assert_called_once() data = audio_out_socket.send.call_args[0][0] assert isinstance(data, bytes) - # Expecting 13 chunks (2*5 with speech, 1 pause between, 1 as padding) - assert len(data) == 512 * 4 * (speech_chunk_count * 2 + 1 + 1) + # Expecting 13 chunks (2*5 with speech, 1 pause between, begin_silence_chunks as padding) + assert len(data) == 512 * 4 * (speech_chunk_count * 2 + 1 + begin_silence_chunks) @pytest.mark.asyncio diff --git a/test/unit/agents/user_interrupt/test_user_interrupt.py b/test/unit/agents/user_interrupt/test_user_interrupt.py index 7a71891..c41d79e 100644 --- a/test/unit/agents/user_interrupt/test_user_interrupt.py +++ b/test/unit/agents/user_interrupt/test_user_interrupt.py @@ -1,12 +1,13 @@ import asyncio import json -from unittest.mock import AsyncMock, MagicMock +from unittest.mock import AsyncMock, MagicMock, patch import pytest from control_backend.agents.user_interrupt.user_interrupt_agent import UserInterruptAgent from control_backend.core.agent_system import InternalMessage from control_backend.core.config import settings +from control_backend.schemas.belief_message import BeliefMessage from control_backend.schemas.program import ( ConditionalNorm, Goal, @@ -29,6 +30,14 @@ def agent(): return agent +@pytest.fixture(autouse=True) +def mock_experiment_logger(): + with patch( + "control_backend.agents.user_interrupt.user_interrupt_agent.experiment_logger" + ) as logger: + yield logger + + @pytest.mark.asyncio async def test_send_to_speech_agent(agent): """Verify speech command format.""" @@ -309,3 +318,375 @@ async def test_send_pause_command(agent): m for m in agent.send.call_args_list if m.args[0].to == settings.agent_settings.vad_name ).args[0] assert vad_msg.body == "RESUME" + + +@pytest.mark.asyncio +async def test_setup(agent): + """Test the setup method initializes sockets correctly.""" + with patch("control_backend.agents.user_interrupt.user_interrupt_agent.Context") as MockContext: + mock_ctx_instance = MagicMock() + MockContext.instance.return_value = mock_ctx_instance + + mock_sub = MagicMock() + mock_pub = MagicMock() + mock_ctx_instance.socket.side_effect = [mock_sub, mock_pub] + + # MOCK add_behavior so we don't rely on internal attributes + agent.add_behavior = MagicMock() + + await agent.setup() + + # Check sockets + mock_sub.connect.assert_called_with(settings.zmq_settings.internal_sub_address) + mock_pub.connect.assert_called_with(settings.zmq_settings.internal_pub_address) + + # Verify add_behavior was called + agent.add_behavior.assert_called_once() + + +@pytest.mark.asyncio +async def test_receive_loop_json_error(agent): + """Verify that malformed JSON is caught and logged without crashing the loop.""" + agent.sub_socket.recv_multipart.side_effect = [ + (b"topic", b"INVALID{JSON"), + asyncio.CancelledError, + ] + + try: + await agent._receive_button_event() + except asyncio.CancelledError: + pass + + agent.logger.error.assert_called_with("Received invalid JSON payload on topic %s", b"topic") + + +@pytest.mark.asyncio +async def test_receive_loop_override_trigger(agent): + """Verify routing 'override' to a Trigger.""" + agent._trigger_map["101"] = "trigger_slug" + payload = json.dumps({"type": "override", "context": "101"}).encode() + + agent.sub_socket.recv_multipart.side_effect = [(b"topic", payload), asyncio.CancelledError] + agent._send_to_bdi = AsyncMock() + + try: + await agent._receive_button_event() + except asyncio.CancelledError: + pass + + agent._send_to_bdi.assert_awaited_once_with("force_trigger", "trigger_slug") + + +@pytest.mark.asyncio +async def test_receive_loop_override_norm(agent): + """Verify routing 'override' to a Conditional Norm.""" + agent._cond_norm_map["202"] = "norm_slug" + payload = json.dumps({"type": "override", "context": "202"}).encode() + + agent.sub_socket.recv_multipart.side_effect = [(b"topic", payload), asyncio.CancelledError] + agent._send_to_bdi_belief = AsyncMock() + + try: + await agent._receive_button_event() + except asyncio.CancelledError: + pass + + agent._send_to_bdi_belief.assert_awaited_once_with("norm_slug", "cond_norm") + + +@pytest.mark.asyncio +async def test_receive_loop_override_missing(agent): + """Verify warning log when an override ID is not found in any map.""" + payload = json.dumps({"type": "override", "context": "999"}).encode() + + agent.sub_socket.recv_multipart.side_effect = [(b"topic", payload), asyncio.CancelledError] + + try: + await agent._receive_button_event() + except asyncio.CancelledError: + pass + + agent.logger.warning.assert_called_with("Could not determine which element to override.") + + +@pytest.mark.asyncio +async def test_receive_loop_unachieve_logic(agent): + """Verify success and failure paths for override_unachieve.""" + agent._cond_norm_map["202"] = "norm_slug" + + success_payload = json.dumps({"type": "override_unachieve", "context": "202"}).encode() + fail_payload = json.dumps({"type": "override_unachieve", "context": "999"}).encode() + + agent.sub_socket.recv_multipart.side_effect = [ + (b"topic", success_payload), + (b"topic", fail_payload), + asyncio.CancelledError, + ] + agent._send_to_bdi_belief = AsyncMock() + + try: + await agent._receive_button_event() + except asyncio.CancelledError: + pass + + # Assert success call (True flag for unachieve) + agent._send_to_bdi_belief.assert_any_call("norm_slug", "cond_norm", True) + # Assert failure log + agent.logger.warning.assert_called_with( + "Could not determine which conditional norm to unachieve." + ) + + +@pytest.mark.asyncio +async def test_receive_loop_pause_resume(agent): + """Verify pause and resume toggle logic and logging.""" + pause_payload = json.dumps({"type": "pause", "context": "true"}).encode() + resume_payload = json.dumps({"type": "pause", "context": ""}).encode() + + agent.sub_socket.recv_multipart.side_effect = [ + (b"topic", pause_payload), + (b"topic", resume_payload), + asyncio.CancelledError, + ] + agent._send_pause_command = AsyncMock() + + try: + await agent._receive_button_event() + except asyncio.CancelledError: + pass + + agent._send_pause_command.assert_any_call("true") + agent._send_pause_command.assert_any_call("") + agent.logger.info.assert_any_call("Sent pause command.") + agent.logger.info.assert_any_call("Sent resume command.") + + +@pytest.mark.asyncio +async def test_receive_loop_phase_control(agent): + """Verify experiment flow control (next_phase).""" + payload = json.dumps({"type": "next_phase", "context": ""}).encode() + + agent.sub_socket.recv_multipart.side_effect = [(b"topic", payload), asyncio.CancelledError] + agent._send_experiment_control_to_bdi_core = AsyncMock() + + try: + await agent._receive_button_event() + except asyncio.CancelledError: + pass + + agent._send_experiment_control_to_bdi_core.assert_awaited_once_with("next_phase") + + +@pytest.mark.asyncio +async def test_handle_message_unknown_thread(agent): + """Test handling of an unknown message thread (lines 213-214).""" + msg = InternalMessage(to="me", thread="unknown_thread", body="test") + await agent.handle_message(msg) + + agent.logger.debug.assert_called_with( + "Received internal message on unhandled thread: unknown_thread" + ) + + +@pytest.mark.asyncio +async def test_send_to_bdi_belief_edge_cases(agent): + """ + Covers: + - Unknown asl_type warning (lines 326-328) + - unachieve=True logic (lines 334-337) + """ + # 1. Unknown Type + await agent._send_to_bdi_belief("slug", "unknown_type") + + agent.logger.warning.assert_called_with("Tried to send belief with unknown type") + agent.send.assert_not_called() + + # Reset mock for part 2 + agent.send.reset_mock() + + # 2. Unachieve = True + await agent._send_to_bdi_belief("slug", "cond_norm", unachieve=True) + + agent.send.assert_awaited() + sent_msg = agent.send.call_args.args[0] + + # Verify it is a delete operation + body_obj = BeliefMessage.model_validate_json(sent_msg.body) + + # Verify 'delete' has content + assert body_obj.delete is not None + assert len(body_obj.delete) == 1 + assert body_obj.delete[0].name == "force_slug" + + # Verify 'create' is empty (handling both None and []) + assert not body_obj.create + + +@pytest.mark.asyncio +async def test_send_experiment_control_unknown(agent): + """Test sending an unknown experiment control type (lines 366-367).""" + await agent._send_experiment_control_to_bdi_core("invalid_command") + + agent.logger.warning.assert_called_with( + "Received unknown experiment control type '%s' to send to BDI Core.", "invalid_command" + ) + + # Ensure it still sends an empty message (as per code logic, though thread is empty) + agent.send.assert_awaited() + msg = agent.send.call_args[0][0] + assert msg.thread == "" + + +@pytest.mark.asyncio +async def test_create_mapping_recursive_goals(agent): + """Verify that nested subgoals are correctly registered in the mapping.""" + import uuid + + # 1. Setup IDs + parent_goal_id = uuid.uuid4() + child_goal_id = uuid.uuid4() + + # 2. Create the child goal + child_goal = Goal( + id=child_goal_id, + name="child_goal", + description="I am a subgoal", + plan=Plan(id=uuid.uuid4(), name="p_child", steps=[]), + ) + + # 3. Create the parent goal and put the child goal inside its plan steps + parent_goal = Goal( + id=parent_goal_id, + name="parent_goal", + description="I am a parent", + plan=Plan(id=uuid.uuid4(), name="p_parent", steps=[child_goal]), # Nested here + ) + + # 4. Build the program + phase = Phase( + id=uuid.uuid4(), + name="phase1", + norms=[], + goals=[parent_goal], # Only the parent is top-level + triggers=[], + ) + prog = Program(phases=[phase]) + + # 5. Execute mapping + msg = InternalMessage(to="me", thread="new_program", body=prog.model_dump_json()) + await agent.handle_message(msg) + + # 6. Assertions + # Check parent + assert str(parent_goal_id) in agent._goal_map + assert agent._goal_map[str(parent_goal_id)] == "parent_goal" + + # Check child (This confirms the recursion worked) + assert str(child_goal_id) in agent._goal_map + assert agent._goal_map[str(child_goal_id)] == "child_goal" + assert agent._goal_reverse_map["child_goal"] == str(child_goal_id) + + +@pytest.mark.asyncio +async def test_receive_loop_advanced_scenarios(agent): + """ + Covers: + - JSONDecodeError (lines 86-88) + - Override: Trigger found (lines 108-109) + - Override: Norm found (lines 114-115) + - Override: Nothing found (line 134) + - Override Unachieve: Success & Fail (lines 136-145) + - Pause: Context true/false logs (lines 150-157) + - Next Phase (line 160) + """ + # 1. Setup Data Maps + agent._trigger_map["101"] = "trigger_slug" + agent._cond_norm_map["202"] = "norm_slug" + + # 2. Define Payloads + # A. Invalid JSON + bad_json = b"INVALID{JSON" + + # B. Override -> Trigger + override_trigger = json.dumps({"type": "override", "context": "101"}).encode() + + # C. Override -> Norm + override_norm = json.dumps({"type": "override", "context": "202"}).encode() + + # D. Override -> Unknown + override_fail = json.dumps({"type": "override", "context": "999"}).encode() + + # E. Unachieve -> Success + unachieve_success = json.dumps({"type": "override_unachieve", "context": "202"}).encode() + + # F. Unachieve -> Fail + unachieve_fail = json.dumps({"type": "override_unachieve", "context": "999"}).encode() + + # G. Pause (True) + pause_true = json.dumps({"type": "pause", "context": "true"}).encode() + + # H. Pause (False/Resume) + pause_false = json.dumps({"type": "pause", "context": ""}).encode() + + # I. Next Phase + next_phase = json.dumps({"type": "next_phase", "context": ""}).encode() + + # 3. Setup Socket + agent.sub_socket.recv_multipart.side_effect = [ + (b"topic", bad_json), + (b"topic", override_trigger), + (b"topic", override_norm), + (b"topic", override_fail), + (b"topic", unachieve_success), + (b"topic", unachieve_fail), + (b"topic", pause_true), + (b"topic", pause_false), + (b"topic", next_phase), + asyncio.CancelledError, # End loop + ] + + # Mock internal helpers to verify calls + agent._send_to_bdi = AsyncMock() + agent._send_to_bdi_belief = AsyncMock() + agent._send_pause_command = AsyncMock() + agent._send_experiment_control_to_bdi_core = AsyncMock() + + # 4. Run Loop + try: + await agent._receive_button_event() + except asyncio.CancelledError: + pass + + # 5. Assertions + + # JSON Error + agent.logger.error.assert_called_with("Received invalid JSON payload on topic %s", b"topic") + + # Override Trigger + agent._send_to_bdi.assert_awaited_with("force_trigger", "trigger_slug") + + # Override Norm + # We expect _send_to_bdi_belief to be called for the norm + # Note: The loop calls _send_to_bdi_belief(asl_cond_norm, "cond_norm") + agent._send_to_bdi_belief.assert_any_call("norm_slug", "cond_norm") + + # Override Fail (Warning log) + agent.logger.warning.assert_any_call("Could not determine which element to override.") + + # Unachieve Success + # Loop calls _send_to_bdi_belief(asl_cond_norm, "cond_norm", True) + agent._send_to_bdi_belief.assert_any_call("norm_slug", "cond_norm", True) + + # Unachieve Fail + agent.logger.warning.assert_any_call("Could not determine which conditional norm to unachieve.") + + # Pause Logic + agent._send_pause_command.assert_any_call("true") + agent.logger.info.assert_any_call("Sent pause command.") + + # Resume Logic + agent._send_pause_command.assert_any_call("") + agent.logger.info.assert_any_call("Sent resume command.") + + # Next Phase + agent._send_experiment_control_to_bdi_core.assert_awaited_with("next_phase") diff --git a/test/unit/api/v1/endpoints/test_logs_endpoint.py b/test/unit/api/v1/endpoints/test_logs_endpoint.py index 50ee740..4aaa90e 100644 --- a/test/unit/api/v1/endpoints/test_logs_endpoint.py +++ b/test/unit/api/v1/endpoints/test_logs_endpoint.py @@ -1,7 +1,7 @@ -from unittest.mock import patch +from unittest.mock import MagicMock, patch import pytest -from fastapi import FastAPI +from fastapi import FastAPI, HTTPException from fastapi.testclient import TestClient from starlette.responses import StreamingResponse @@ -61,3 +61,67 @@ async def test_log_stream_endpoint_lines(client): # Optional: assert subscribe/connect were called assert dummy_socket.subscribed # at least some log levels subscribed assert dummy_socket.connected # connect was called + + +@patch("control_backend.api.v1.endpoints.logs.LOGGING_DIR") +def test_files_endpoint(LOGGING_DIR, client): + file_1, file_2 = MagicMock(), MagicMock() + file_1.name = "file_1" + file_2.name = "file_2" + LOGGING_DIR.glob.return_value = [file_1, file_2] + result = client.get("/api/logs/files") + + assert result.status_code == 200 + assert result.json() == ["file_1", "file_2"] + + +@patch("control_backend.api.v1.endpoints.logs.FileResponse") +@patch("control_backend.api.v1.endpoints.logs.LOGGING_DIR") +def test_log_file_endpoint_success(LOGGING_DIR, MockFileResponse, client): + mock_file_path = MagicMock() + mock_file_path.is_relative_to.return_value = True + mock_file_path.is_file.return_value = True + mock_file_path.name = "test.log" + + LOGGING_DIR.__truediv__ = MagicMock(return_value=mock_file_path) + mock_file_path.resolve.return_value = mock_file_path + + MockFileResponse.return_value = MagicMock() + + result = client.get("/api/logs/files/test.log") + + assert result.status_code == 200 + MockFileResponse.assert_called_once_with(mock_file_path, filename="test.log") + + +@pytest.mark.asyncio +@patch("control_backend.api.v1.endpoints.logs.LOGGING_DIR") +async def test_log_file_endpoint_path_traversal(LOGGING_DIR): + from control_backend.api.v1.endpoints.logs import log_file + + mock_file_path = MagicMock() + mock_file_path.is_relative_to.return_value = False + + LOGGING_DIR.__truediv__ = MagicMock(return_value=mock_file_path) + mock_file_path.resolve.return_value = mock_file_path + + with pytest.raises(HTTPException) as exc_info: + await log_file("../secret.txt") + + assert exc_info.value.status_code == 400 + assert exc_info.value.detail == "Invalid filename." + + +@patch("control_backend.api.v1.endpoints.logs.LOGGING_DIR") +def test_log_file_endpoint_file_not_found(LOGGING_DIR, client): + mock_file_path = MagicMock() + mock_file_path.is_relative_to.return_value = True + mock_file_path.is_file.return_value = False + + LOGGING_DIR.__truediv__ = MagicMock(return_value=mock_file_path) + mock_file_path.resolve.return_value = mock_file_path + + result = client.get("/api/logs/files/nonexistent.log") + + assert result.status_code == 404 + assert result.json()["detail"] == "File not found." diff --git a/test/unit/api/v1/endpoints/test_user_interact.py b/test/unit/api/v1/endpoints/test_user_interact.py index ddb9932..9785eec 100644 --- a/test/unit/api/v1/endpoints/test_user_interact.py +++ b/test/unit/api/v1/endpoints/test_user_interact.py @@ -94,3 +94,55 @@ async def test_experiment_stream_direct_call(): mock_socket.connect.assert_called() mock_socket.subscribe.assert_called_with(b"experiment") mock_socket.close.assert_called() + + +@pytest.mark.asyncio +async def test_status_stream_direct_call(): + """ + Test the status stream, ensuring it handles messages and sends pings on timeout. + """ + mock_socket = AsyncMock() + + # Define the sequence of events for the socket: + # 1. Successfully receive a message + # 2. Timeout (which should trigger the ': ping' yield) + # 3. Another message (which won't be reached because we'll simulate disconnect) + mock_socket.recv_multipart.side_effect = [ + (b"topic", b"status_update"), + TimeoutError(), + (b"topic", b"ignored_msg"), + ] + + mock_socket.close = MagicMock() + mock_socket.connect = MagicMock() + mock_socket.subscribe = MagicMock() + + mock_context = MagicMock() + mock_context.socket.return_value = mock_socket + + # Mock the ZMQ Context to return our mock_socket + with patch( + "control_backend.api.v1.endpoints.user_interact.Context.instance", return_value=mock_context + ): + mock_request = AsyncMock() + + # is_disconnected sequence: + # 1. False -> Process "status_update" + # 2. False -> Process TimeoutError (yield ping) + # 3. True -> Break loop (client disconnected) + mock_request.is_disconnected.side_effect = [False, False, True] + + # Call the status_stream function explicitly + response = await user_interact.status_stream(mock_request) + + lines = [] + async for line in response.body_iterator: + lines.append(line) + + # Assertions + assert "data: status_update\n\n" in lines + assert ": ping\n\n" in lines # Verify lines 91-92 (ping logic) + + mock_socket.connect.assert_called() + mock_socket.subscribe.assert_called_with(b"status") + mock_socket.close.assert_called() diff --git a/test/unit/conftest.py b/test/unit/conftest.py index d5f06e5..5e925d0 100644 --- a/test/unit/conftest.py +++ b/test/unit/conftest.py @@ -32,6 +32,7 @@ def mock_settings(): mock.agent_settings.vad_name = "vad_agent" mock.behaviour_settings.sleep_s = 0.01 # Speed up tests mock.behaviour_settings.comm_setup_max_retries = 1 + mock.behaviour_settings.agentspeak_file = "src/control_backend/agents/bdi/agentspeak.asl" yield mock diff --git a/test/unit/logging/test_dated_file_handler.py b/test/unit/logging/test_dated_file_handler.py new file mode 100644 index 0000000..14809fb --- /dev/null +++ b/test/unit/logging/test_dated_file_handler.py @@ -0,0 +1,45 @@ +from unittest.mock import MagicMock, patch + +import pytest + +from control_backend.logging.dated_file_handler import DatedFileHandler + + +@patch("control_backend.logging.dated_file_handler.DatedFileHandler._open") +def test_reset(open_): + stream = MagicMock() + open_.return_value = stream + + # A file should be opened when the logger is created + handler = DatedFileHandler(file_prefix="anything") + assert open_.call_count == 1 + + # Upon reset, the current file should be closed, and a new one should be opened + handler.do_rollover() + assert stream.close.call_count == 1 + assert open_.call_count == 2 + + +@patch("control_backend.logging.dated_file_handler.Path") +@patch("control_backend.logging.dated_file_handler.DatedFileHandler._open") +def test_creates_dir(open_, Path_): + stream = MagicMock() + open_.return_value = stream + + test_path = MagicMock() + test_path.parent.is_dir.return_value = False + Path_.return_value = test_path + + DatedFileHandler(file_prefix="anything") + + # The directory should've been created + test_path.parent.mkdir.assert_called_once() + + +@patch("control_backend.logging.dated_file_handler.DatedFileHandler._open") +def test_invalid_constructor(_): + with pytest.raises(ValueError): + DatedFileHandler(file_prefix=None) + + with pytest.raises(ValueError): + DatedFileHandler(file_prefix="") diff --git a/test/unit/logging/test_optional_field_formatter.py b/test/unit/logging/test_optional_field_formatter.py new file mode 100644 index 0000000..ae75bd9 --- /dev/null +++ b/test/unit/logging/test_optional_field_formatter.py @@ -0,0 +1,218 @@ +import logging + +import pytest + +from control_backend.logging.optional_field_formatter import OptionalFieldFormatter + + +@pytest.fixture +def logger(): + """Create a fresh logger for each test.""" + logger = logging.getLogger(f"test_{id(object())}") + logger.setLevel(logging.DEBUG) + logger.handlers = [] + return logger + + +@pytest.fixture +def log_output(logger): + """Capture log output and return a function to get it.""" + + class ListHandler(logging.Handler): + def __init__(self): + super().__init__() + self.records = [] + + def emit(self, record): + self.records.append(self.format(record)) + + handler = ListHandler() + logger.addHandler(handler) + + def get_output(): + return handler.records + + return get_output + + +def test_optional_field_present(logger, log_output): + """Optional field should appear when provided in extra.""" + formatter = OptionalFieldFormatter("%(levelname)s - %(role?)s - %(message)s") + logger.handlers[0].setFormatter(formatter) + + logger.info("test message", extra={"role": "user"}) + + assert log_output() == ["INFO - user - test message"] + + +def test_optional_field_missing_no_default(logger, log_output): + """Missing optional field with no default should be None.""" + formatter = OptionalFieldFormatter("%(levelname)s - %(role?)s - %(message)s") + logger.handlers[0].setFormatter(formatter) + + logger.info("test message") + + assert log_output() == ["INFO - None - test message"] + + +def test_optional_field_missing_with_default(logger, log_output): + """Missing optional field should use provided default.""" + formatter = OptionalFieldFormatter( + "%(levelname)s - %(role?)s - %(message)s", defaults={"role": "assistant"} + ) + logger.handlers[0].setFormatter(formatter) + + logger.info("test message") + + assert log_output() == ["INFO - assistant - test message"] + + +def test_optional_field_overrides_default(logger, log_output): + """Provided extra value should override default.""" + formatter = OptionalFieldFormatter( + "%(levelname)s - %(role?)s - %(message)s", defaults={"role": "assistant"} + ) + logger.handlers[0].setFormatter(formatter) + + logger.info("test message", extra={"role": "user"}) + + assert log_output() == ["INFO - user - test message"] + + +def test_multiple_optional_fields(logger, log_output): + """Multiple optional fields should work independently.""" + formatter = OptionalFieldFormatter( + "%(levelname)s - %(role?)s - %(request_id?)s - %(message)s", defaults={"role": "assistant"} + ) + logger.handlers[0].setFormatter(formatter) + + logger.info("test", extra={"request_id": "123"}) + + assert log_output() == ["INFO - assistant - 123 - test"] + + +def test_mixed_optional_and_required_fields(logger, log_output): + """Standard fields should work alongside optional fields.""" + formatter = OptionalFieldFormatter("%(levelname)s %(name)s %(role?)s %(message)s") + logger.handlers[0].setFormatter(formatter) + + logger.info("test", extra={"role": "user"}) + + output = log_output()[0] + assert "INFO" in output + assert "user" in output + assert "test" in output + + +def test_no_optional_fields(logger, log_output): + """Formatter should work normally with no optional fields.""" + formatter = OptionalFieldFormatter("%(levelname)s %(message)s") + logger.handlers[0].setFormatter(formatter) + + logger.info("test message") + + assert log_output() == ["INFO test message"] + + +def test_integer_format_specifier(logger, log_output): + """Optional fields with %d specifier should work.""" + formatter = OptionalFieldFormatter( + "%(levelname)s %(count?)d %(message)s", defaults={"count": 0} + ) + logger.handlers[0].setFormatter(formatter) + + logger.info("test", extra={"count": 42}) + + assert log_output() == ["INFO 42 test"] + + +def test_float_format_specifier(logger, log_output): + """Optional fields with %f specifier should work.""" + formatter = OptionalFieldFormatter( + "%(levelname)s %(duration?)f %(message)s", defaults={"duration": 0.0} + ) + logger.handlers[0].setFormatter(formatter) + + logger.info("test", extra={"duration": 1.5}) + + assert "1.5" in log_output()[0] + + +def test_empty_string_default(logger, log_output): + """Empty string default should work.""" + formatter = OptionalFieldFormatter("%(levelname)s %(role?)s %(message)s", defaults={"role": ""}) + logger.handlers[0].setFormatter(formatter) + + logger.info("test") + + assert log_output() == ["INFO test"] + + +def test_none_format_string(): + """None format string should not raise.""" + formatter = OptionalFieldFormatter(fmt=None) + assert formatter.optional_fields == set() + + +def test_optional_fields_parsed_correctly(): + """Check that optional fields are correctly identified.""" + formatter = OptionalFieldFormatter("%(asctime)s %(role?)s %(level?)d %(name)s") + + assert formatter.optional_fields == {("role", "s"), ("level", "d")} + + +def test_format_string_normalized(): + """Check that ? is removed from format string.""" + formatter = OptionalFieldFormatter("%(role?)s %(message)s") + + assert "?" not in formatter._style._fmt + assert "%(role)s" in formatter._style._fmt + + +def test_field_with_underscore(logger, log_output): + """Field names with underscores should work.""" + formatter = OptionalFieldFormatter("%(levelname)s %(user_id?)s %(message)s") + logger.handlers[0].setFormatter(formatter) + + logger.info("test", extra={"user_id": "abc123"}) + + assert log_output() == ["INFO abc123 test"] + + +def test_field_with_numbers(logger, log_output): + """Field names with numbers should work.""" + formatter = OptionalFieldFormatter("%(levelname)s %(field2?)s %(message)s") + logger.handlers[0].setFormatter(formatter) + + logger.info("test", extra={"field2": "value"}) + + assert log_output() == ["INFO value test"] + + +def test_multiple_log_calls(logger, log_output): + """Formatter should work correctly across multiple log calls.""" + formatter = OptionalFieldFormatter( + "%(levelname)s %(role?)s %(message)s", defaults={"role": "other"} + ) + logger.handlers[0].setFormatter(formatter) + + logger.info("first", extra={"role": "assistant"}) + logger.info("second") + logger.info("third", extra={"role": "user"}) + + assert log_output() == [ + "INFO assistant first", + "INFO other second", + "INFO user third", + ] + + +def test_default_not_mutated(logger, log_output): + """Original defaults dict should not be mutated.""" + defaults = {"role": "other"} + formatter = OptionalFieldFormatter("%(levelname)s %(role?)s %(message)s", defaults=defaults) + logger.handlers[0].setFormatter(formatter) + + logger.info("test") + + assert defaults == {"role": "other"} diff --git a/test/unit/logging/test_partial_filter.py b/test/unit/logging/test_partial_filter.py new file mode 100644 index 0000000..bd5ef10 --- /dev/null +++ b/test/unit/logging/test_partial_filter.py @@ -0,0 +1,83 @@ +import logging + +import pytest + +from control_backend.logging import PartialFilter + + +@pytest.fixture +def logger(): + """Create a fresh logger for each test.""" + logger = logging.getLogger(f"test_{id(object())}") + logger.setLevel(logging.DEBUG) + logger.handlers = [] + return logger + + +@pytest.fixture +def log_output(logger): + """Capture log output and return a function to get it.""" + + class ListHandler(logging.Handler): + def __init__(self): + super().__init__() + self.records = [] + + def emit(self, record): + self.records.append(self.format(record)) + + handler = ListHandler() + handler.addFilter(PartialFilter()) + handler.setFormatter(logging.Formatter("%(message)s")) + logger.addHandler(handler) + + return lambda: handler.records + + +def test_no_partial_attribute(logger, log_output): + """Records without partial attribute should pass through.""" + logger.info("normal message") + + assert log_output() == ["normal message"] + + +def test_partial_true_filtered(logger, log_output): + """Records with partial=True should be filtered out.""" + logger.info("partial message", extra={"partial": True}) + + assert log_output() == [] + + +def test_partial_false_passes(logger, log_output): + """Records with partial=False should pass through.""" + logger.info("complete message", extra={"partial": False}) + + assert log_output() == ["complete message"] + + +def test_partial_none_passes(logger, log_output): + """Records with partial=None should pass through.""" + logger.info("message", extra={"partial": None}) + + assert log_output() == ["message"] + + +def test_partial_truthy_value_passes(logger, log_output): + """ + Records with truthy but non-True partial should pass through, that is, only when it's exactly + ``True`` should it pass. + """ + logger.info("message", extra={"partial": "yes"}) + + assert log_output() == ["message"] + + +def test_multiple_records_mixed(logger, log_output): + """Filter should handle mixed records correctly.""" + logger.info("first") + logger.info("second", extra={"partial": True}) + logger.info("third", extra={"partial": False}) + logger.info("fourth", extra={"partial": True}) + logger.info("fifth") + + assert log_output() == ["first", "third", "fifth"]