Compare commits

..

3 Commits

Author SHA1 Message Date
Twirre Meulenbelt
3406e9ac2f feat: make the pipeline work with Program and AgentSpeak
ref: N25B-429
2026-01-06 15:26:44 +01:00
a357b6990b feat: send program to bdi core
ref: N25B-376
2026-01-06 12:11:37 +01:00
9eea4ee345 feat: new ASL generation
ref: N25B-376
2026-01-02 12:08:20 +01:00
37 changed files with 638 additions and 2174 deletions

View File

@@ -1,9 +0,0 @@
%{first_multiline_commit_description}
To verify:
- [ ] Style checks pass
- [ ] Pipeline (tests) pass
- [ ] Documentation is up to date
- [ ] Tests are up to date (new code is covered)
- [ ] ...

View File

@@ -28,7 +28,6 @@ class RobotGestureAgent(BaseAgent):
address = "" address = ""
bind = False bind = False
gesture_data = [] gesture_data = []
single_gesture_data = []
def __init__( def __init__(
self, self,
@@ -36,10 +35,8 @@ class RobotGestureAgent(BaseAgent):
address=settings.zmq_settings.ri_command_address, address=settings.zmq_settings.ri_command_address,
bind=False, bind=False,
gesture_data=None, gesture_data=None,
single_gesture_data=None,
): ):
self.gesture_data = gesture_data or [] self.gesture_data = gesture_data or []
self.single_gesture_data = single_gesture_data or []
super().__init__(name) super().__init__(name)
self.address = address self.address = address
self.bind = bind self.bind = bind
@@ -102,13 +99,7 @@ class RobotGestureAgent(BaseAgent):
gesture_command.data, gesture_command.data,
) )
return return
elif gesture_command.endpoint == RIEndpoint.GESTURE_SINGLE:
if gesture_command.data not in self.single_gesture_data:
self.logger.warning(
"Received gesture '%s' which is not in available gestures. Early returning",
gesture_command.data,
)
return
await self.pubsocket.send_json(gesture_command.model_dump()) await self.pubsocket.send_json(gesture_command.model_dump())
except Exception: except Exception:
self.logger.exception("Error processing internal message.") self.logger.exception("Error processing internal message.")

View File

@@ -187,9 +187,10 @@ class StatementType(StrEnum):
EMPTY = "" EMPTY = ""
DO_ACTION = "." DO_ACTION = "."
ACHIEVE_GOAL = "!" ACHIEVE_GOAL = "!"
# TEST_GOAL = "?" # TODO TEST_GOAL = "?"
ADD_BELIEF = "+" ADD_BELIEF = "+"
REMOVE_BELIEF = "-" REMOVE_BELIEF = "-"
REPLACE_BELIEF = "-+"
@dataclass @dataclass

View File

@@ -0,0 +1,373 @@
from functools import singledispatchmethod
from slugify import slugify
from control_backend.agents.bdi.agentspeak_ast import (
AstBinaryOp,
AstExpression,
AstLiteral,
AstPlan,
AstProgram,
AstRule,
AstStatement,
AstString,
AstVar,
BinaryOperatorType,
StatementType,
TriggerType,
)
from control_backend.schemas.program import (
BasicNorm,
ConditionalNorm,
GestureAction,
Goal,
InferredBelief,
KeywordBelief,
LLMAction,
LogicalOperator,
Norm,
Phase,
PlanElement,
Program,
ProgramElement,
SemanticBelief,
SpeechAction,
Trigger,
)
class AgentSpeakGenerator:
_asp: AstProgram
def generate(self, program: Program) -> str:
self._asp = AstProgram()
self._asp.rules.append(AstRule(self._astify(program.phases[0])))
self._add_keyword_inference()
self._add_default_plans()
self._process_phases(program.phases)
self._add_fallbacks()
return str(self._asp)
def _add_keyword_inference(self) -> None:
keyword = AstVar("Keyword")
message = AstVar("Message")
position = AstVar("Pos")
self._asp.rules.append(
AstRule(
AstLiteral("keyword_said", [keyword]),
AstLiteral("user_said", [message])
& AstLiteral(".substring", [keyword, message, position])
& (position >= 0),
)
)
def _add_default_plans(self):
self._add_reply_with_goal_plan()
self._add_say_plan()
self._add_reply_plan()
def _add_reply_with_goal_plan(self):
self._asp.plans.append(
AstPlan(
TriggerType.ADDED_GOAL,
AstLiteral("reply_with_goal", [AstVar("Goal")]),
[AstLiteral("user_said", [AstVar("Message")])],
[
AstStatement(StatementType.ADD_BELIEF, AstLiteral("responded_this_turn")),
AstStatement(
StatementType.DO_ACTION,
AstLiteral(
"findall",
[AstVar("Norm"), AstLiteral("norm", [AstVar("Norm")]), AstVar("Norms")],
),
),
AstStatement(
StatementType.DO_ACTION,
AstLiteral(
"reply_with_goal", [AstVar("Message"), AstVar("Norms"), AstVar("Goal")]
),
),
],
)
)
def _add_say_plan(self):
self._asp.plans.append(
AstPlan(
TriggerType.ADDED_GOAL,
AstLiteral("say", [AstVar("Text")]),
[],
[
AstStatement(StatementType.ADD_BELIEF, AstLiteral("responded_this_turn")),
AstStatement(StatementType.DO_ACTION, AstLiteral("say", [AstVar("Text")])),
],
)
)
def _add_reply_plan(self):
self._asp.plans.append(
AstPlan(
TriggerType.ADDED_GOAL,
AstLiteral("reply"),
[AstLiteral("user_said", [AstVar("Message")])],
[
AstStatement(StatementType.ADD_BELIEF, AstLiteral("responded_this_turn")),
AstStatement(
StatementType.DO_ACTION,
AstLiteral(
"findall",
[AstVar("Norm"), AstLiteral("norm", [AstVar("Norm")]), AstVar("Norms")],
),
),
AstStatement(
StatementType.DO_ACTION,
AstLiteral("reply", [AstVar("Message"), AstVar("Norms")]),
),
],
)
)
def _process_phases(self, phases: list[Phase]) -> None:
for curr_phase, next_phase in zip([None] + phases, phases + [None], strict=True):
if curr_phase:
self._process_phase(curr_phase)
self._add_phase_transition(curr_phase, next_phase)
# End phase behavior
# When deleting this, the entire `reply` plan and action can be deleted
self._asp.plans.append(
AstPlan(
type=TriggerType.ADDED_BELIEF,
trigger_literal=AstLiteral("user_said", [AstVar("Message")]),
context=[AstLiteral("phase", [AstString("end")])],
body=[AstStatement(StatementType.ACHIEVE_GOAL, AstLiteral("reply"))],
)
)
def _process_phase(self, phase: Phase) -> None:
for norm in phase.norms:
self._process_norm(norm, phase)
self._add_default_loop(phase)
previous_goal = None
for goal in phase.goals:
self._process_goal(goal, phase, previous_goal)
previous_goal = goal
for trigger in phase.triggers:
self._process_trigger(trigger, phase)
def _add_phase_transition(self, from_phase: Phase | None, to_phase: Phase | None) -> None:
if from_phase is None:
return
from_phase_ast = self._astify(from_phase)
to_phase_ast = (
self._astify(to_phase) if to_phase else AstLiteral("phase", [AstString("end")])
)
context = [from_phase_ast, ~AstLiteral("responded_this_turn")]
if from_phase and from_phase.goals:
context.append(self._astify(from_phase.goals[-1], achieved=True))
body = [
AstStatement(StatementType.REMOVE_BELIEF, from_phase_ast),
AstStatement(StatementType.ADD_BELIEF, to_phase_ast),
]
if from_phase:
body.extend(
[
AstStatement(
StatementType.TEST_GOAL, AstLiteral("user_said", [AstVar("Message")])
),
AstStatement(
StatementType.REPLACE_BELIEF, AstLiteral("user_said", [AstVar("Message")])
),
]
)
self._asp.plans.append(
AstPlan(TriggerType.ADDED_GOAL, AstLiteral("transition_phase"), context, body)
)
def _process_norm(self, norm: Norm, phase: Phase) -> None:
rule: AstRule | None = None
match norm:
case ConditionalNorm(condition=cond):
rule = AstRule(self._astify(norm), self._astify(phase) & self._astify(cond))
case BasicNorm():
rule = AstRule(self._astify(norm), self._astify(phase))
if not rule:
return
self._asp.rules.append(rule)
def _add_default_loop(self, phase: Phase) -> None:
actions = []
actions.append(AstStatement(StatementType.REMOVE_BELIEF, AstLiteral("responded_this_turn")))
actions.append(AstStatement(StatementType.ACHIEVE_GOAL, AstLiteral("check_triggers")))
for goal in phase.goals:
actions.append(AstStatement(StatementType.ACHIEVE_GOAL, self._astify(goal)))
actions.append(AstStatement(StatementType.ACHIEVE_GOAL, AstLiteral("transition_phase")))
self._asp.plans.append(
AstPlan(
TriggerType.ADDED_BELIEF,
AstLiteral("user_said", [AstVar("Message")]),
[self._astify(phase)],
actions,
)
)
def _process_goal(
self,
goal: Goal,
phase: Phase,
previous_goal: Goal | None = None,
continues_response: bool = False,
) -> None:
context: list[AstExpression] = [self._astify(phase)]
context.append(~self._astify(goal, achieved=True))
if previous_goal and previous_goal.can_fail:
context.append(self._astify(previous_goal, achieved=True))
if not continues_response:
context.append(~AstLiteral("responded_this_turn"))
body = []
subgoals = []
for step in goal.plan.steps:
body.append(self._step_to_statement(step))
if isinstance(step, Goal):
subgoals.append(step)
if not goal.can_fail and not continues_response:
body.append(AstStatement(StatementType.ADD_BELIEF, self._astify(goal, achieved=True)))
self._asp.plans.append(AstPlan(TriggerType.ADDED_GOAL, self._astify(goal), context, body))
self._asp.plans.append(
AstPlan(
TriggerType.ADDED_GOAL,
self._astify(goal),
context=[],
body=[AstStatement(StatementType.EMPTY, AstLiteral("true"))],
)
)
prev_goal = None
for subgoal in subgoals:
self._process_goal(subgoal, phase, prev_goal)
prev_goal = subgoal
def _step_to_statement(self, step: PlanElement) -> AstStatement:
match step:
case Goal() | SpeechAction() | LLMAction() as a:
return AstStatement(StatementType.ACHIEVE_GOAL, self._astify(a))
case GestureAction() as a:
return AstStatement(StatementType.DO_ACTION, self._astify(a))
# TODO: separate handling of keyword and others
def _process_trigger(self, trigger: Trigger, phase: Phase) -> None:
body = []
subgoals = []
for step in trigger.plan.steps:
body.append(self._step_to_statement(step))
if isinstance(step, Goal):
step.can_fail = False # triggers are continuous sequence
subgoals.append(step)
self._asp.plans.append(
AstPlan(
TriggerType.ADDED_GOAL,
AstLiteral("check_triggers"),
[self._astify(phase), self._astify(trigger.condition)],
body,
)
)
for subgoal in subgoals:
self._process_goal(subgoal, phase, continues_response=True)
def _add_fallbacks(self):
# Trigger fallback
self._asp.plans.append(
AstPlan(
TriggerType.ADDED_GOAL,
AstLiteral("check_triggers"),
[],
[AstStatement(StatementType.EMPTY, AstLiteral("true"))],
)
)
# Phase transition fallback
self._asp.plans.append(
AstPlan(
TriggerType.ADDED_GOAL,
AstLiteral("transition_phase"),
[],
[AstStatement(StatementType.EMPTY, AstLiteral("true"))],
)
)
@singledispatchmethod
def _astify(self, element: ProgramElement) -> AstExpression:
raise NotImplementedError(f"Cannot convert element {element} to an AgentSpeak expression.")
@_astify.register
def _(self, kwb: KeywordBelief) -> AstExpression:
return AstLiteral("keyword_said", [AstString(kwb.keyword)])
@_astify.register
def _(self, sb: SemanticBelief) -> AstExpression:
return AstLiteral(f"semantic_{self._slugify_str(sb.description)}")
@_astify.register
def _(self, ib: InferredBelief) -> AstExpression:
return AstBinaryOp(
self._astify(ib.left),
BinaryOperatorType.AND if ib.operator == LogicalOperator.AND else BinaryOperatorType.OR,
self._astify(ib.right),
)
@_astify.register
def _(self, norm: Norm) -> AstExpression:
functor = "critical_norm" if norm.critical else "norm"
return AstLiteral(functor, [AstString(norm.norm)])
@_astify.register
def _(self, phase: Phase) -> AstExpression:
return AstLiteral("phase", [AstString(str(phase.id))])
@_astify.register
def _(self, goal: Goal, achieved: bool = False) -> AstExpression:
return AstLiteral(f"{'achieved_' if achieved else ''}{self._slugify_str(goal.name)}")
@_astify.register
def _(self, sa: SpeechAction) -> AstExpression:
return AstLiteral("say", [AstString(sa.text)])
@_astify.register
def _(self, ga: GestureAction) -> AstExpression:
gesture = ga.gesture
return AstLiteral("gesture", [AstString(gesture.type), AstString(gesture.name)])
@_astify.register
def _(self, la: LLMAction) -> AstExpression:
return AstLiteral("reply_with_goal", [AstString(la.goal)])
@staticmethod
def _slugify_str(text: str) -> str:
return slugify(text, separator="_", stopwords=["a", "an", "the", "we", "you", "I"])

View File

@@ -11,7 +11,7 @@ from pydantic import ValidationError
from control_backend.agents.base import BaseAgent from control_backend.agents.base import BaseAgent
from control_backend.core.agent_system import InternalMessage from control_backend.core.agent_system import InternalMessage
from control_backend.core.config import settings from control_backend.core.config import settings
from control_backend.schemas.belief_message import BeliefMessage from control_backend.schemas.belief_message import Belief, BeliefMessage
from control_backend.schemas.llm_prompt_message import LLMPromptMessage from control_backend.schemas.llm_prompt_message import LLMPromptMessage
from control_backend.schemas.ri_message import SpeechCommand from control_backend.schemas.ri_message import SpeechCommand
@@ -42,13 +42,13 @@ class BDICoreAgent(BaseAgent):
bdi_agent: agentspeak.runtime.Agent bdi_agent: agentspeak.runtime.Agent
def __init__(self, name: str, asl: str): def __init__(self, name: str):
super().__init__(name) super().__init__(name)
self.asl_file = asl
self.env = agentspeak.runtime.Environment() self.env = agentspeak.runtime.Environment()
# Deep copy because we don't actually want to modify the standard actions globally # Deep copy because we don't actually want to modify the standard actions globally
self.actions = copy.deepcopy(agentspeak.stdlib.actions) self.actions = copy.deepcopy(agentspeak.stdlib.actions)
self._wake_bdi_loop = asyncio.Event() self._wake_bdi_loop = asyncio.Event()
self._bdi_loop_task = None
async def setup(self) -> None: async def setup(self) -> None:
""" """
@@ -65,19 +65,22 @@ class BDICoreAgent(BaseAgent):
await self._load_asl() await self._load_asl()
# Start the BDI cycle loop # Start the BDI cycle loop
self.add_behavior(self._bdi_loop()) self._bdi_loop_task = self.add_behavior(self._bdi_loop())
self._wake_bdi_loop.set() self._wake_bdi_loop.set()
self.logger.debug("Setup complete.") self.logger.debug("Setup complete.")
async def _load_asl(self): async def _load_asl(self, file_name: str | None = None) -> None:
""" """
Load and parse the AgentSpeak source file. Load and parse the AgentSpeak source file.
""" """
file_name = file_name or "src/control_backend/agents/bdi/default_behavior.asl"
try: try:
with open(self.asl_file) as source: with open(file_name) as source:
self.bdi_agent = self.env.build_agent(source, self.actions) self.bdi_agent = self.env.build_agent(source, self.actions)
self.logger.info(f"Loaded new ASL from {file_name}.")
except FileNotFoundError: except FileNotFoundError:
self.logger.warning(f"Could not find the specified ASL file at {self.asl_file}.") self.logger.warning(f"Could not find the specified ASL file at {file_name}.")
self.bdi_agent = agentspeak.runtime.Agent(self.env, self.name) self.bdi_agent = agentspeak.runtime.Agent(self.env, self.name)
async def _bdi_loop(self): async def _bdi_loop(self):
@@ -116,6 +119,7 @@ class BDICoreAgent(BaseAgent):
Handle incoming messages. Handle incoming messages.
- **Beliefs**: Updates the internal belief base. - **Beliefs**: Updates the internal belief base.
- **Program**: Updates the internal agentspeak file to match the current program.
- **LLM Responses**: Forwards the generated text to the Robot Speech Agent (actuation). - **LLM Responses**: Forwards the generated text to the Robot Speech Agent (actuation).
:param msg: The received internal message. :param msg: The received internal message.
@@ -124,12 +128,19 @@ class BDICoreAgent(BaseAgent):
if msg.thread == "beliefs": if msg.thread == "beliefs":
try: try:
belief_changes = BeliefMessage.model_validate_json(msg.body) beliefs = BeliefMessage.model_validate_json(msg.body).beliefs
self._apply_belief_changes(belief_changes) self._apply_beliefs(beliefs)
except ValidationError: except ValidationError:
self.logger.exception("Error processing belief.") self.logger.exception("Error processing belief.")
return return
# New agentspeak file
if msg.thread == "new_program":
if self._bdi_loop_task:
self._bdi_loop_task.cancel()
await self._load_asl(msg.body)
self.add_behavior(self._bdi_loop())
# The message was not a belief, handle special cases based on sender # The message was not a belief, handle special cases based on sender
match msg.sender: match msg.sender:
case settings.agent_settings.llm_name: case settings.agent_settings.llm_name:
@@ -145,28 +156,21 @@ class BDICoreAgent(BaseAgent):
) )
await self.send(out_msg) await self.send(out_msg)
def _apply_belief_changes(self, belief_changes: BeliefMessage): def _apply_beliefs(self, beliefs: list[Belief]):
""" """
Update the belief base with a list of new beliefs. Update the belief base with a list of new beliefs.
For beliefs in ``belief_changes.replace``, it removes all existing beliefs with that name If ``replace=True`` is set on a belief, it removes all existing beliefs with that name
before adding one new one. before adding the new one.
:param belief_changes: The changes in beliefs to apply.
""" """
if not belief_changes.create and not belief_changes.replace and not belief_changes.delete: if not beliefs:
return return
for belief in belief_changes.create: for belief in beliefs:
if belief.replace:
self._remove_all_with_name(belief.name)
self._add_belief(belief.name, belief.arguments) self._add_belief(belief.name, belief.arguments)
for belief in belief_changes.replace:
self._remove_all_with_name(belief.name)
self._add_belief(belief.name, belief.arguments)
for belief in belief_changes.delete:
self._remove_belief(belief.name, belief.arguments)
def _add_belief(self, name: str, args: list[str] = None): def _add_belief(self, name: str, args: list[str] = None):
""" """
Add a single belief to the BDI agent. Add a single belief to the BDI agent.
@@ -246,20 +250,18 @@ class BDICoreAgent(BaseAgent):
the function expects (which will be located in `term.args`). the function expects (which will be located in `term.args`).
""" """
@self.actions.add(".reply", 3) @self.actions.add(".reply", 2)
def _reply(agent: "BDICoreAgent", term, intention): def _reply(agent: "BDICoreAgent", term, intention):
""" """
Let the LLM generate a response to a user's utterance with the current norms and goals. Let the LLM generate a response to a user's utterance with the current norms and goals.
""" """
message_text = agentspeak.grounded(term.args[0], intention.scope) message_text = agentspeak.grounded(term.args[0], intention.scope)
norms = agentspeak.grounded(term.args[1], intention.scope) norms = agentspeak.grounded(term.args[1], intention.scope)
goals = agentspeak.grounded(term.args[2], intention.scope)
self.logger.debug("Norms: %s", norms) self.logger.debug("Norms: %s", norms)
self.logger.debug("Goals: %s", goals)
self.logger.debug("User text: %s", message_text) self.logger.debug("User text: %s", message_text)
asyncio.create_task(self._send_to_llm(str(message_text), str(norms), str(goals))) self.add_behavior(self._send_to_llm(str(message_text), str(norms), ""))
yield yield
@self.actions.add(".reply_with_goal", 3) @self.actions.add(".reply_with_goal", 3)
@@ -278,7 +280,7 @@ class BDICoreAgent(BaseAgent):
norms, norms,
goal, goal,
) )
# asyncio.create_task(self._send_to_llm(str(message_text), norms, str(goal))) self.add_behavior(self._send_to_llm(str(message_text), str(norms), str(goal)))
yield yield
@self.actions.add(".say", 1) @self.actions.add(".say", 1)
@@ -290,13 +292,14 @@ class BDICoreAgent(BaseAgent):
self.logger.debug('"say" action called with text=%s', message_text) self.logger.debug('"say" action called with text=%s', message_text)
# speech_command = SpeechCommand(data=message_text) speech_command = SpeechCommand(data=message_text)
# speech_message = InternalMessage( speech_message = InternalMessage(
# to=settings.agent_settings.robot_speech_name, to=settings.agent_settings.robot_speech_name,
# sender=settings.agent_settings.bdi_core_name, sender=settings.agent_settings.bdi_core_name,
# body=speech_command.model_dump_json(), body=speech_command.model_dump_json(),
# ) )
# asyncio.create_task(agent.send(speech_message)) # TODO: add to conversation history
self.add_behavior(self.send(speech_message))
yield yield
@self.actions.add(".gesture", 2) @self.actions.add(".gesture", 2)

View File

@@ -1,598 +1,12 @@
import uuid
from collections.abc import Iterable
import zmq import zmq
from pydantic import ValidationError from pydantic import ValidationError
from slugify import slugify
from zmq.asyncio import Context from zmq.asyncio import Context
from control_backend.agents import BaseAgent from control_backend.agents import BaseAgent
from control_backend.agents.bdi.agentspeak_generator import AgentSpeakGenerator
from control_backend.core.config import settings from control_backend.core.config import settings
from control_backend.schemas.program import ( from control_backend.schemas.internal_message import InternalMessage
Action, from control_backend.schemas.program import Program
BasicBelief,
BasicNorm,
Belief,
ConditionalNorm,
GestureAction,
Goal,
InferredBelief,
KeywordBelief,
LLMAction,
LogicalOperator,
Phase,
Plan,
Program,
ProgramElement,
SemanticBelief,
SpeechAction,
Trigger,
)
test_program = Program(
phases=[
Phase(
norms=[
BasicNorm(norm="Talk like a pirate", id=uuid.uuid4()),
ConditionalNorm(
condition=InferredBelief(
left=KeywordBelief(keyword="Arr", id=uuid.uuid4()),
right=SemanticBelief(
description="testing", name="semantic belief", id=uuid.uuid4()
),
operator=LogicalOperator.OR,
name="Talking to a pirate",
id=uuid.uuid4(),
),
norm="Use nautical terms",
id=uuid.uuid4(),
),
ConditionalNorm(
condition=SemanticBelief(
description="We are talking to a child",
name="talking to child",
id=uuid.uuid4(),
),
norm="Do not use cuss words",
id=uuid.uuid4(),
),
],
triggers=[
Trigger(
condition=InferredBelief(
left=KeywordBelief(keyword="key", id=uuid.uuid4()),
right=InferredBelief(
left=KeywordBelief(keyword="key2", id=uuid.uuid4()),
right=SemanticBelief(
description="Decode this", name="semantic belief 2", id=uuid.uuid4()
),
operator=LogicalOperator.OR,
name="test trigger inferred inner",
id=uuid.uuid4(),
),
operator=LogicalOperator.OR,
name="test trigger inferred outer",
id=uuid.uuid4(),
),
plan=Plan(
steps=[
SpeechAction(text="Testing trigger", id=uuid.uuid4()),
Goal(
name="Testing trigger",
plan=Plan(
steps=[LLMAction(goal="Do something", id=uuid.uuid4())],
id=uuid.uuid4(),
),
id=uuid.uuid4(),
),
],
id=uuid.uuid4(),
),
id=uuid.uuid4(),
)
],
goals=[
Goal(
name="Determine user age",
plan=Plan(
steps=[LLMAction(goal="Determine the age of the user.", id=uuid.uuid4())],
id=uuid.uuid4(),
),
id=uuid.uuid4(),
),
Goal(
name="Find the user's name",
plan=Plan(
steps=[
Goal(
name="Greet the user",
plan=Plan(
steps=[LLMAction(goal="Greet the user.", id=uuid.uuid4())],
id=uuid.uuid4(),
),
can_fail=False,
id=uuid.uuid4(),
),
Goal(
name="Ask for name",
plan=Plan(
steps=[
LLMAction(goal="Obtain the user's name.", id=uuid.uuid4())
],
id=uuid.uuid4(),
),
id=uuid.uuid4(),
),
],
id=uuid.uuid4(),
),
id=uuid.uuid4(),
),
Goal(
name="Tell a joke",
plan=Plan(
steps=[LLMAction(goal="Tell a joke.", id=uuid.uuid4())], id=uuid.uuid4()
),
id=uuid.uuid4(),
),
],
id=uuid.uuid4(),
),
Phase(
id=uuid.uuid4(),
norms=[
BasicNorm(norm="Use very gentle speech.", id=uuid.uuid4()),
ConditionalNorm(
condition=SemanticBelief(
description="We are talking to a child",
name="talking to child",
id=uuid.uuid4(),
),
norm="Do not use cuss words",
id=uuid.uuid4(),
),
],
triggers=[
Trigger(
condition=InferredBelief(
left=KeywordBelief(keyword="help", id=uuid.uuid4()),
right=SemanticBelief(
description="User is stuck", name="stuck", id=uuid.uuid4()
),
operator=LogicalOperator.OR,
name="help_or_stuck",
id=uuid.uuid4(),
),
plan=Plan(
steps=[
Goal(
name="Unblock user",
plan=Plan(
steps=[
LLMAction(
goal="Provide a step-by-step path to "
"resolve the user's issue.",
id=uuid.uuid4(),
)
],
id=uuid.uuid4(),
),
id=uuid.uuid4(),
),
],
id=uuid.uuid4(),
),
id=uuid.uuid4(),
),
],
goals=[
Goal(
name="Clarify intent",
plan=Plan(
steps=[
LLMAction(
goal="Ask 1-2 targeted questions to clarify the "
"user's intent, then proceed.",
id=uuid.uuid4(),
)
],
id=uuid.uuid4(),
),
id=uuid.uuid4(),
),
Goal(
name="Provide solution",
plan=Plan(
steps=[
LLMAction(
goal="Deliver a solution to complete the user's goal.",
id=uuid.uuid4(),
)
],
id=uuid.uuid4(),
),
id=uuid.uuid4(),
),
Goal(
name="Summarize next steps",
plan=Plan(
steps=[
LLMAction(
goal="Summarize what the user should do next.", id=uuid.uuid4()
)
],
id=uuid.uuid4(),
),
id=uuid.uuid4(),
),
],
),
]
)
def do_things():
print(AgentSpeakGenerator().generate(test_program))
class AgentSpeakGenerator:
"""
Converts Pydantic representation of behavior programs into AgentSpeak(L) code string.
"""
arrow_prefix = f"{' ' * 2}<-{' ' * 2}"
colon_prefix = f"{' ' * 2}:{' ' * 3}"
indent_prefix = " " * 6
def generate(self, program: Program) -> str:
lines = []
lines.append("")
lines += self._generate_initial_beliefs(program)
lines += self._generate_basic_flow(program)
lines += self._generate_phase_transitions(program)
lines += self._generate_norms(program)
lines += self._generate_belief_inference(program)
lines += self._generate_goals(program)
lines += self._generate_triggers(program)
return "\n".join(lines)
def _generate_initial_beliefs(self, program: Program) -> Iterable[str]:
yield "// --- Initial beliefs and agent startup ---"
yield "phase(start)."
yield ""
yield "+started"
yield f"{self.colon_prefix}phase(start)"
yield f"{self.arrow_prefix}phase({program.phases[0].id if program.phases else 'end'})."
yield from ["", ""]
def _generate_basic_flow(self, program: Program) -> Iterable[str]:
yield "// --- Basic flow ---"
for phase in program.phases:
yield from self._generate_basic_flow_per_phase(phase)
yield from ["", ""]
def _generate_basic_flow_per_phase(self, phase: Phase) -> Iterable[str]:
yield "+user_said(Message)"
yield f"{self.colon_prefix}phase({phase.id})"
goals = phase.goals
if goals:
yield f"{self.arrow_prefix}{self._slugify(goals[0], include_prefix=True)}"
for goal in goals[1:]:
yield f"{self.indent_prefix}{self._slugify(goal, include_prefix=True)}"
yield f"{self.indent_prefix if goals else self.arrow_prefix}!transition_phase."
def _generate_phase_transitions(self, program: Program) -> Iterable[str]:
yield "// --- Phase transitions ---"
if len(program.phases) == 0:
yield from ["", ""]
return
# TODO: remove outdated things
for i in range(-1, len(program.phases)):
predecessor = program.phases[i] if i >= 0 else None
successor = program.phases[i + 1] if i < len(program.phases) - 1 else None
yield from self._generate_phase_transition(predecessor, successor)
yield from self._generate_phase_transition(None, None) # to avoid failing plan
yield from ["", ""]
def _generate_phase_transition(
self, phase: Phase | None = None, next_phase: Phase | None = None
) -> Iterable[str]:
yield "+!transition_phase"
if phase is None and next_phase is None: # base case true to avoid failing plan
yield f"{self.arrow_prefix}true."
return
yield f"{self.colon_prefix}phase({phase.id if phase else 'start'})"
yield f"{self.arrow_prefix}-+phase({next_phase.id if next_phase else 'end'})."
def _generate_norms(self, program: Program) -> Iterable[str]:
yield "// --- Norms ---"
for phase in program.phases:
for norm in phase.norms:
if type(norm) is BasicNorm:
yield f"{self._slugify(norm)} :- phase({phase.id})."
if type(norm) is ConditionalNorm:
yield (
f"{self._slugify(norm)} :- phase({phase.id}) & "
f"{self._slugify(norm.condition)}."
)
yield from ["", ""]
def _generate_belief_inference(self, program: Program) -> Iterable[str]:
yield "// --- Belief inference rules ---"
for phase in program.phases:
for norm in phase.norms:
if not isinstance(norm, ConditionalNorm):
continue
yield from self._belief_inference_recursive(norm.condition)
for trigger in phase.triggers:
yield from self._belief_inference_recursive(trigger.condition)
yield from ["", ""]
def _belief_inference_recursive(self, belief: Belief) -> Iterable[str]:
if type(belief) is KeywordBelief:
yield (
f"{self._slugify(belief)} :- user_said(Message) & "
f'.substring(Message, "{belief.keyword}", Pos) & Pos >= 0.'
)
if type(belief) is InferredBelief:
yield (
f"{self._slugify(belief)} :- {self._slugify(belief.left)} "
f"{'&' if belief.operator == LogicalOperator.AND else '|'} "
f"{self._slugify(belief.right)}."
)
yield from self._belief_inference_recursive(belief.left)
yield from self._belief_inference_recursive(belief.right)
def _generate_goals(self, program: Program) -> Iterable[str]:
yield "// --- Goals ---"
for phase in program.phases:
previous_goal: Goal | None = None
for goal in phase.goals:
yield from self._generate_goal_plan_recursive(goal, phase, previous_goal)
previous_goal = goal
yield from ["", ""]
def _generate_goal_plan_recursive(
self, goal: Goal, phase: Phase, previous_goal: Goal | None = None
) -> Iterable[str]:
yield f"+{self._slugify(goal, include_prefix=True)}"
# Context
yield f"{self.colon_prefix}phase({phase.id}) &"
yield f"{self.indent_prefix}not responded_this_turn &"
yield f"{self.indent_prefix}not achieved_{self._slugify(goal)} &"
if previous_goal:
yield f"{self.indent_prefix}achieved_{self._slugify(previous_goal)}"
else:
yield f"{self.indent_prefix}true"
extra_goals_to_generate = []
steps = goal.plan.steps
if len(steps) == 0:
yield f"{self.arrow_prefix}true."
return
first_step = steps[0]
yield (
f"{self.arrow_prefix}{self._slugify(first_step, include_prefix=True)}"
f"{'.' if len(steps) == 1 and goal.can_fail else ';'}"
)
if isinstance(first_step, Goal):
extra_goals_to_generate.append(first_step)
for step in steps[1:-1]:
yield f"{self.indent_prefix}{self._slugify(step, include_prefix=True)};"
if isinstance(step, Goal):
extra_goals_to_generate.append(step)
if len(steps) > 1:
last_step = steps[-1]
yield (
f"{self.indent_prefix}{self._slugify(last_step, include_prefix=True)}"
f"{'.' if goal.can_fail else ';'}"
)
if isinstance(last_step, Goal):
extra_goals_to_generate.append(last_step)
if not goal.can_fail:
yield f"{self.indent_prefix}+achieved_{self._slugify(goal)}."
yield f"+{self._slugify(goal, include_prefix=True)}"
yield f"{self.arrow_prefix}true."
yield ""
extra_previous_goal: Goal | None = None
for extra_goal in extra_goals_to_generate:
yield from self._generate_goal_plan_recursive(extra_goal, phase, extra_previous_goal)
extra_previous_goal = extra_goal
def _generate_triggers(self, program: Program) -> Iterable[str]:
yield "// --- Triggers ---"
for phase in program.phases:
for trigger in phase.triggers:
yield from self._generate_trigger_plan(trigger, phase)
yield from ["", ""]
def _generate_trigger_plan(self, trigger: Trigger, phase: Phase) -> Iterable[str]:
belief_name = self._slugify(trigger.condition)
yield f"+{belief_name}"
yield f"{self.colon_prefix}phase({phase.id})"
extra_goals_to_generate = []
steps = trigger.plan.steps
if len(steps) == 0:
yield f"{self.arrow_prefix}true."
return
first_step = steps[0]
yield (
f"{self.arrow_prefix}{self._slugify(first_step, include_prefix=True)}"
f"{'.' if len(steps) == 1 else ';'}"
)
if isinstance(first_step, Goal):
extra_goals_to_generate.append(first_step)
for step in steps[1:-1]:
yield f"{self.indent_prefix}{self._slugify(step, include_prefix=True)};"
if isinstance(step, Goal):
extra_goals_to_generate.append(step)
if len(steps) > 1:
last_step = steps[-1]
yield f"{self.indent_prefix}{self._slugify(last_step, include_prefix=True)}."
if isinstance(last_step, Goal):
extra_goals_to_generate.append(last_step)
yield ""
extra_previous_goal: Goal | None = None
for extra_goal in extra_goals_to_generate:
yield from self._generate_trigger_plan_recursive(extra_goal, phase, extra_previous_goal)
extra_previous_goal = extra_goal
def _generate_trigger_plan_recursive(
self, goal: Goal, phase: Phase, previous_goal: Goal | None = None
) -> Iterable[str]:
yield f"+{self._slugify(goal, include_prefix=True)}"
extra_goals_to_generate = []
steps = goal.plan.steps
if len(steps) == 0:
yield f"{self.arrow_prefix}true."
return
first_step = steps[0]
yield (
f"{self.arrow_prefix}{self._slugify(first_step, include_prefix=True)}"
f"{'.' if len(steps) == 1 and goal.can_fail else ';'}"
)
if isinstance(first_step, Goal):
extra_goals_to_generate.append(first_step)
for step in steps[1:-1]:
yield f"{self.indent_prefix}{self._slugify(step, include_prefix=True)};"
if isinstance(step, Goal):
extra_goals_to_generate.append(step)
if len(steps) > 1:
last_step = steps[-1]
yield (
f"{self.indent_prefix}{self._slugify(last_step, include_prefix=True)}"
f"{'.' if goal.can_fail else ';'}"
)
if isinstance(last_step, Goal):
extra_goals_to_generate.append(last_step)
if not goal.can_fail:
yield f"{self.indent_prefix}+achieved_{self._slugify(goal)}."
yield f"+{self._slugify(goal, include_prefix=True)}"
yield f"{self.arrow_prefix}true."
yield ""
extra_previous_goal: Goal | None = None
for extra_goal in extra_goals_to_generate:
yield from self._generate_goal_plan_recursive(extra_goal, phase, extra_previous_goal)
extra_previous_goal = extra_goal
def _slugify(self, element: ProgramElement, include_prefix: bool = False) -> str:
def base_slugify_call(text: str):
return slugify(text, separator="_", stopwords=["a", "the"])
if type(element) is KeywordBelief:
return f'keyword_said("{element.keyword}")'
if type(element) is SemanticBelief:
name = element.name
return f"semantic_{base_slugify_call(name if name else element.description)}"
if isinstance(element, BasicNorm):
return f'norm("{element.norm}")'
if isinstance(element, Goal):
return f"{'!' if include_prefix else ''}{base_slugify_call(element.name)}"
if isinstance(element, SpeechAction):
return f'.say("{element.text}")'
if isinstance(element, GestureAction):
return f'.gesture("{element.gesture}")'
if isinstance(element, LLMAction):
return f'!generate_response_with_goal("{element.goal}")'
if isinstance(element, Action.__value__):
raise NotImplementedError(
"Have not implemented an ASL string representation for this action."
)
if element.name == "":
raise ValueError("Name must be initialized for this type of ProgramElement.")
return base_slugify_call(element.name)
def _extract_basic_beliefs_from_program(self, program: Program) -> list[BasicBelief]:
beliefs = []
for phase in program.phases:
for norm in phase.norms:
if isinstance(norm, ConditionalNorm):
beliefs += self._extract_basic_beliefs_from_belief(norm.condition)
for trigger in phase.triggers:
beliefs += self._extract_basic_beliefs_from_belief(trigger.condition)
return beliefs
def _extract_basic_beliefs_from_belief(self, belief: Belief) -> list[BasicBelief]:
if isinstance(belief, InferredBelief):
return self._extract_basic_beliefs_from_belief(
belief.left
) + self._extract_basic_beliefs_from_belief(belief.right)
return [belief]
class BDIProgramManager(BaseAgent): class BDIProgramManager(BaseAgent):
@@ -611,40 +25,36 @@ class BDIProgramManager(BaseAgent):
super().__init__(**kwargs) super().__init__(**kwargs)
self.sub_socket = None self.sub_socket = None
# async def _send_to_bdi(self, program: Program): async def _create_agentspeak_and_send_to_bdi(self, program: Program):
# """ """
# Convert a received program into BDI beliefs and send them to the BDI Core Agent. Convert a received program into BDI beliefs and send them to the BDI Core Agent.
#
# Currently, it takes the **first phase** of the program and extracts: Currently, it takes the **first phase** of the program and extracts:
# - **Norms**: Constraints or rules the agent must follow. - **Norms**: Constraints or rules the agent must follow.
# - **Goals**: Objectives the agent must achieve. - **Goals**: Objectives the agent must achieve.
#
# These are sent as a ``BeliefMessage`` with ``replace=True``, meaning they will These are sent as a ``BeliefMessage`` with ``replace=True``, meaning they will
# overwrite any existing norms/goals of the same name in the BDI agent. overwrite any existing norms/goals of the same name in the BDI agent.
#
# :param program: The program object received from the API. :param program: The program object received from the API.
# """ """
# first_phase = program.phases[0] asg = AgentSpeakGenerator()
# norms_belief = Belief(
# name="norms", asl_str = asg.generate(program)
# arguments=[norm.norm for norm in first_phase.norms],
# replace=True, file_name = "src/control_backend/agents/bdi/agentspeak.asl"
# )
# goals_belief = Belief( with open(file_name, "w") as f:
# name="goals", f.write(asl_str)
# arguments=[goal.description for goal in first_phase.goals],
# replace=True, msg = InternalMessage(
# ) sender=self.name,
# program_beliefs = BeliefMessage(beliefs=[norms_belief, goals_belief]) to=settings.agent_settings.bdi_core_name,
# body=file_name,
# message = InternalMessage( thread="new_program",
# to=settings.agent_settings.bdi_core_name, )
# sender=self.name,
# body=program_beliefs.model_dump_json(), await self.send(msg)
# thread="beliefs",
# )
# await self.send(message)
# self.logger.debug("Sent new norms and goals to the BDI agent.")
async def _receive_programs(self): async def _receive_programs(self):
""" """
@@ -662,7 +72,7 @@ class BDIProgramManager(BaseAgent):
self.logger.exception("Received an invalid program.") self.logger.exception("Received an invalid program.")
continue continue
await self._send_to_bdi(program) await self._create_agentspeak_and_send_to_bdi(program)
async def setup(self): async def setup(self):
""" """
@@ -678,7 +88,3 @@ class BDIProgramManager(BaseAgent):
self.sub_socket.subscribe("program") self.sub_socket.subscribe("program")
self.add_behavior(self._receive_programs()) self.add_behavior(self._receive_programs())
if __name__ == "__main__":
do_things()

View File

@@ -101,7 +101,7 @@ class BDIBeliefCollectorAgent(BaseAgent):
:return: A Belief object if the input is valid or None. :return: A Belief object if the input is valid or None.
""" """
try: try:
return Belief(name=name, arguments=arguments) return Belief(name=name, arguments=arguments, replace=name == "user_said")
except ValidationError: except ValidationError:
return None return None
@@ -144,7 +144,7 @@ class BDIBeliefCollectorAgent(BaseAgent):
msg = InternalMessage( msg = InternalMessage(
to=settings.agent_settings.bdi_core_name, to=settings.agent_settings.bdi_core_name,
sender=self.name, sender=self.name,
body=BeliefMessage(create=beliefs).model_dump_json(), body=BeliefMessage(beliefs=beliefs).model_dump_json(),
thread="beliefs", thread="beliefs",
) )

View File

@@ -0,0 +1,5 @@
norms("").
+user_said(Message) : norms(Norms) <-
-user_said(Message);
.reply(Message, Norms).

View File

@@ -1,6 +0,0 @@
norms("").
goals("").
+user_said(Message) : norms(Norms) & goals(Goals) <-
-user_said(Message);
.reply(Message, Norms, Goals).

View File

@@ -1,23 +1,8 @@
import asyncio
import json import json
import httpx
from pydantic import ValidationError
from slugify import slugify
from control_backend.agents.base import BaseAgent from control_backend.agents.base import BaseAgent
from control_backend.core.agent_system import InternalMessage from control_backend.core.agent_system import InternalMessage
from control_backend.core.config import settings from control_backend.core.config import settings
from control_backend.schemas.belief_message import Belief as InternalBelief
from control_backend.schemas.belief_message import BeliefMessage
from control_backend.schemas.chat_history import ChatHistory, ChatMessage
from control_backend.schemas.program import (
Belief,
ConditionalNorm,
InferredBelief,
Program,
SemanticBelief,
)
class TextBeliefExtractorAgent(BaseAgent): class TextBeliefExtractorAgent(BaseAgent):
@@ -27,110 +12,46 @@ class TextBeliefExtractorAgent(BaseAgent):
This agent is responsible for processing raw text (e.g., from speech transcription) and This agent is responsible for processing raw text (e.g., from speech transcription) and
extracting semantic beliefs from it. extracting semantic beliefs from it.
It uses the available beliefs received from the program manager to try to extract beliefs from a In the current demonstration version, it performs a simple wrapping of the user's input
user's message, sends and updated beliefs to the BDI core, and forms a ``user_said`` belief from into a ``user_said`` belief. In a full implementation, this agent would likely interact
the message itself. with an LLM or NLU engine to extract intent, entities, and other structured information.
""" """
def __init__(self, name: str):
super().__init__(name)
self.beliefs: dict[str, bool] = {}
self.available_beliefs: list[SemanticBelief] = []
self.conversation = ChatHistory(messages=[])
async def setup(self): async def setup(self):
""" """
Initialize the agent and its resources. Initialize the agent and its resources.
""" """
self.logger.info("Setting up %s.", self.name) self.logger.info("Settting up %s.", self.name)
# Setup LLM belief context if needed (currently demo is just passthrough)
self.beliefs = {"mood": ["X"], "car": ["Y"]}
async def handle_message(self, msg: InternalMessage): async def handle_message(self, msg: InternalMessage):
""" """
Handle incoming messages. Expect messages from the Transcriber agent, LLM agent, and the Handle incoming messages, primarily from the Transcription Agent.
Program manager agent.
:param msg: The received message. :param msg: The received message containing transcribed text.
""" """
sender = msg.sender sender = msg.sender
if sender == settings.agent_settings.transcription_name:
self.logger.debug("Received text from transcriber: %s", msg.body)
await self._process_transcription_demo(msg.body)
else:
self.logger.info("Discarding message from %s", sender)
match sender: async def _process_transcription_demo(self, txt: str):
case settings.agent_settings.transcription_name:
self.logger.debug("Received text from transcriber: %s", msg.body)
self._apply_conversation_message(ChatMessage(role="user", content=msg.body))
await self._infer_new_beliefs()
await self._user_said(msg.body)
case settings.agent_settings.llm_name:
self.logger.debug("Received text from LLM: %s", msg.body)
self._apply_conversation_message(ChatMessage(role="assistant", content=msg.body))
case settings.agent_settings.bdi_program_manager_name:
self._handle_program_manager_message(msg)
case _:
self.logger.info("Discarding message from %s", sender)
return
def _apply_conversation_message(self, message: ChatMessage):
""" """
Save the chat message to our conversation history, taking into account the conversation Process the transcribed text and generate beliefs.
length limit.
:param message: The chat message to add to the conversation history. **Demo Implementation:**
Currently, this method takes the raw text ``txt`` and wraps it into a belief structure:
``user_said("txt")``.
This belief is then sent to the :class:`BDIBeliefCollectorAgent`.
:param txt: The raw transcribed text string.
""" """
length_limit = settings.behaviour_settings.conversation_history_length_limit # For demo, just wrapping user text as user_said belief
self.conversation.messages = (self.conversation.messages + [message])[-length_limit:] belief = {"beliefs": {"user_said": [txt]}, "type": "belief_extraction_text"}
def _handle_program_manager_message(self, msg: InternalMessage):
"""
Handle a message from the program manager: extract available beliefs from it.
:param msg: The received message from the program manager.
"""
try:
program = Program.model_validate_json(msg.body)
except ValidationError:
self.logger.warning(
"Received message from program manager but it is not a valid program."
)
return
self.logger.debug("Received a program from the program manager.")
self.available_beliefs = self._extract_basic_beliefs_from_program(program)
# TODO Copied from an incomplete version of the program manager. Use that one instead.
@staticmethod
def _extract_basic_beliefs_from_program(program: Program) -> list[SemanticBelief]:
beliefs = []
for phase in program.phases:
for norm in phase.norms:
if isinstance(norm, ConditionalNorm):
beliefs += TextBeliefExtractorAgent._extract_basic_beliefs_from_belief(
norm.condition
)
for trigger in phase.triggers:
beliefs += TextBeliefExtractorAgent._extract_basic_beliefs_from_belief(
trigger.condition
)
return beliefs
# TODO Copied from an incomplete version of the program manager. Use that one instead.
@staticmethod
def _extract_basic_beliefs_from_belief(belief: Belief) -> list[SemanticBelief]:
if isinstance(belief, InferredBelief):
return TextBeliefExtractorAgent._extract_basic_beliefs_from_belief(
belief.left
) + TextBeliefExtractorAgent._extract_basic_beliefs_from_belief(belief.right)
return [belief]
async def _user_said(self, text: str):
"""
Create a belief for the user's full speech.
:param text: User's transcribed text.
"""
belief = {"beliefs": {"user_said": [text]}, "type": "belief_extraction_text"}
payload = json.dumps(belief) payload = json.dumps(belief)
belief_msg = InternalMessage( belief_msg = InternalMessage(
@@ -139,207 +60,6 @@ class TextBeliefExtractorAgent(BaseAgent):
body=payload, body=payload,
thread="beliefs", thread="beliefs",
) )
await self.send(belief_msg) await self.send(belief_msg)
self.logger.info("Sent %d beliefs to the belief collector.", len(belief["beliefs"]))
async def _infer_new_beliefs(self):
"""
Process conversation history to extract beliefs, semantically. Any changed beliefs are sent
to the BDI core.
"""
# Return instantly if there are no beliefs to infer
if not self.available_beliefs:
return
candidate_beliefs = await self._infer_turn()
belief_changes = BeliefMessage()
for belief_key, belief_value in candidate_beliefs.items():
if belief_value is None:
continue
old_belief_value = self.beliefs.get(belief_key)
if belief_value == old_belief_value:
continue
self.beliefs[belief_key] = belief_value
belief = InternalBelief(name=belief_key, arguments=None)
if belief_value:
belief_changes.create.append(belief)
else:
belief_changes.delete.append(belief)
# Return if there were no changes in beliefs
if not belief_changes.has_values():
return
beliefs_message = InternalMessage(
to=settings.agent_settings.bdi_core_name,
sender=self.name,
body=belief_changes.model_dump_json(),
thread="beliefs",
)
await self.send(beliefs_message)
@staticmethod
def _split_into_chunks[T](items: list[T], n: int) -> list[list[T]]:
k, m = divmod(len(items), n)
return [items[i * k + min(i, m) : (i + 1) * k + min(i + 1, m)] for i in range(n)]
async def _infer_turn(self) -> dict:
"""
Process the stored conversation history to extract semantic beliefs. Returns a list of
beliefs that have been set to ``True``, ``False`` or ``None``.
:return: A dict mapping belief names to a value ``True``, ``False`` or ``None``.
"""
n_parallel = max(1, min(settings.llm_settings.n_parallel - 1, len(self.available_beliefs)))
all_beliefs = await asyncio.gather(
*[
self._infer_beliefs(self.conversation, beliefs)
for beliefs in self._split_into_chunks(self.available_beliefs, n_parallel)
]
)
retval = {}
for beliefs in all_beliefs:
if beliefs is None:
continue
retval.update(beliefs)
return retval
@staticmethod
def _create_belief_schema(belief: SemanticBelief) -> tuple[str, dict]:
# TODO: use real belief names
return belief.name or slugify(belief.description), {
"type": ["boolean", "null"],
"description": belief.description,
}
@staticmethod
def _create_beliefs_schema(beliefs: list[SemanticBelief]) -> dict:
belief_schemas = [
TextBeliefExtractorAgent._create_belief_schema(belief) for belief in beliefs
]
return {
"type": "object",
"properties": dict(belief_schemas),
"required": [name for name, _ in belief_schemas],
}
@staticmethod
def _format_message(message: ChatMessage):
return f"{message.role.upper()}:\n{message.content}"
@staticmethod
def _format_conversation(conversation: ChatHistory):
return "\n\n".join(
[TextBeliefExtractorAgent._format_message(message) for message in conversation.messages]
)
@staticmethod
def _format_beliefs(beliefs: list[SemanticBelief]):
# TODO: use real belief names
return "\n".join(
[
f"- {belief.name or slugify(belief.description)}: {belief.description}"
for belief in beliefs
]
)
async def _infer_beliefs(
self,
conversation: ChatHistory,
beliefs: list[SemanticBelief],
) -> dict | None:
"""
Infer given beliefs based on the given conversation.
:param conversation: The conversation to infer beliefs from.
:param beliefs: The beliefs to infer.
:return: A dict containing belief names and a boolean whether they hold, or None if the
belief cannot be inferred based on the given conversation.
"""
example = {
"example_belief": True,
}
prompt = f"""{self._format_conversation(conversation)}
Given the above conversation, what beliefs can be inferred?
If there is no relevant information about a belief belief, give null.
In case messages conflict, prefer using the most recent messages for inference.
Choose from the following list of beliefs, formatted as (belief_name, description):
{self._format_beliefs(beliefs)}
Respond with a JSON similar to the following, but with the property names as given above:
{json.dumps(example, indent=2)}
"""
schema = self._create_beliefs_schema(beliefs)
return await self._retry_query_llm(prompt, schema)
async def _retry_query_llm(self, prompt: str, schema: dict, tries: int = 3) -> dict | None:
"""
Query the LLM with the given prompt and schema, return an instance of a dict conforming
to this schema. Try ``tries`` times, or return None.
:param prompt: Prompt to be queried.
:param schema: Schema to be queried.
:return: An instance of a dict conforming to this schema, or None if failed.
"""
try_count = 0
while try_count < tries:
try_count += 1
try:
return await self._query_llm(prompt, schema)
except (httpx.HTTPError, json.JSONDecodeError, KeyError) as e:
if try_count < tries:
continue
self.logger.exception(
"Failed to get LLM response after %d tries.",
try_count,
exc_info=e,
)
return None
@staticmethod
async def _query_llm(prompt: str, schema: dict) -> dict:
"""
Query an LLM with the given prompt and schema, return an instance of a dict conforming to
that schema.
:param prompt: The prompt to be queried.
:param schema: Schema to use during response.
:return: A dict conforming to this schema.
:raises httpx.HTTPStatusError: If the LLM server responded with an error.
:raises json.JSONDecodeError: If the LLM response was not valid JSON. May happen if the
response was cut off early due to length limitations.
:raises KeyError: If the LLM server responded with no error, but the response was invalid.
"""
async with httpx.AsyncClient() as client:
response = await client.post(
settings.llm_settings.local_llm_url,
json={
"model": settings.llm_settings.local_llm_model,
"messages": [{"role": "user", "content": prompt}],
"response_format": {
"type": "json_schema",
"json_schema": {
"name": "Beliefs",
"strict": True,
"schema": schema,
},
},
"reasoning_effort": "low",
"temperature": settings.llm_settings.code_temperature,
"stream": False,
},
timeout=None,
)
response.raise_for_status()
response_json = response.json()
json_message = response_json["choices"][0]["message"]["content"]
return json.loads(json_message)

View File

@@ -3,14 +3,11 @@ import json
import zmq import zmq
import zmq.asyncio as azmq import zmq.asyncio as azmq
from pydantic import ValidationError
from zmq.asyncio import Context from zmq.asyncio import Context
from control_backend.agents import BaseAgent from control_backend.agents import BaseAgent
from control_backend.agents.actuation.robot_gesture_agent import RobotGestureAgent from control_backend.agents.actuation.robot_gesture_agent import RobotGestureAgent
from control_backend.core.agent_system import InternalMessage
from control_backend.core.config import settings from control_backend.core.config import settings
from control_backend.schemas.ri_message import PauseCommand
from ..actuation.robot_speech_agent import RobotSpeechAgent from ..actuation.robot_speech_agent import RobotSpeechAgent
from ..perception import VADAgent from ..perception import VADAgent
@@ -185,7 +182,6 @@ class RICommunicationAgent(BaseAgent):
self._req_socket.bind(addr) self._req_socket.bind(addr)
case "actuation": case "actuation":
gesture_data = port_data.get("gestures", []) gesture_data = port_data.get("gestures", [])
single_gesture_data = port_data.get("single_gestures", [])
robot_speech_agent = RobotSpeechAgent( robot_speech_agent = RobotSpeechAgent(
settings.agent_settings.robot_speech_name, settings.agent_settings.robot_speech_name,
address=addr, address=addr,
@@ -196,7 +192,6 @@ class RICommunicationAgent(BaseAgent):
address=addr, address=addr,
bind=bind, bind=bind,
gesture_data=gesture_data, gesture_data=gesture_data,
single_gesture_data=single_gesture_data,
) )
await robot_speech_agent.start() await robot_speech_agent.start()
await asyncio.sleep(0.1) # Small delay await asyncio.sleep(0.1) # Small delay
@@ -301,11 +296,3 @@ class RICommunicationAgent(BaseAgent):
self.logger.debug("Restarting communication negotiation.") self.logger.debug("Restarting communication negotiation.")
if await self._negotiate_connection(max_retries=1): if await self._negotiate_connection(max_retries=1):
self.connected = True self.connected = True
async def handle_message(self, msg : InternalMessage):
try:
pause_command = PauseCommand.model_validate_json(msg.body)
self._req_socket.send_json(pause_command.model_dump())
self.logger.debug(self._req_socket.recv_json())
except ValidationError:
self.logger.warning("Incorrect message format for PauseCommand.")

View File

@@ -64,12 +64,11 @@ class LLMAgent(BaseAgent):
:param message: The parsed prompt message containing text, norms, and goals. :param message: The parsed prompt message containing text, norms, and goals.
""" """
full_message = ""
async for chunk in self._query_llm(message.text, message.norms, message.goals): async for chunk in self._query_llm(message.text, message.norms, message.goals):
await self._send_reply(chunk) await self._send_reply(chunk)
full_message += chunk self.logger.debug(
self.logger.debug("Finished processing BDI message. Response sent in chunks to BDI core.") "Finished processing BDI message. Response sent in chunks to BDI core."
await self._send_full_reply(full_message) )
async def _send_reply(self, msg: str): async def _send_reply(self, msg: str):
""" """
@@ -84,19 +83,6 @@ class LLMAgent(BaseAgent):
) )
await self.send(reply) await self.send(reply)
async def _send_full_reply(self, msg: str):
"""
Sends a response message (full) to agents that need it.
:param msg: The text content of the message.
"""
message = InternalMessage(
to=settings.agent_settings.text_belief_extractor_name,
sender=self.name,
body=msg,
)
await self.send(message)
async def _query_llm( async def _query_llm(
self, prompt: str, norms: list[str], goals: list[str] self, prompt: str, norms: list[str], goals: list[str]
) -> AsyncGenerator[str]: ) -> AsyncGenerator[str]:
@@ -186,7 +172,7 @@ class LLMAgent(BaseAgent):
json={ json={
"model": settings.llm_settings.local_llm_model, "model": settings.llm_settings.local_llm_model,
"messages": messages, "messages": messages,
"temperature": settings.llm_settings.chat_temperature, "temperature": 0.3,
"stream": True, "stream": True,
}, },
) as response: ) as response:

View File

@@ -1,68 +0,0 @@
import asyncio
import json
import zmq
from zmq.asyncio import Context
from control_backend.agents.base import BaseAgent
from control_backend.core.agent_system import InternalMessage
from control_backend.core.config import settings
class TestPauseAgent(BaseAgent):
def __init__(self, name: str):
super().__init__(name)
async def setup(self):
context = Context.instance()
self.pub_socket = context.socket(zmq.PUB)
self.pub_socket.connect(settings.zmq_settings.internal_pub_address)
self.add_behavior(self._pause_command_loop())
self.logger.debug("TestPauseAgent setup complete.")
async def _pause_command_loop(self):
print("Starting Pause command test loop.")
while True:
pause_command = {
"endpoint": "pause",
"data": True,
}
message = InternalMessage(
to="ri_communication_agent",
sender=self.name,
body=json.dumps(pause_command),
)
await self.send(message)
# User interrupt message
data = {
"type": "pause",
"context": True,
}
await self.pub_socket.send_multipart([b"button_pressed", json.dumps(data).encode()])
self.logger.info("Pausing robot actions.")
await asyncio.sleep(15) # Simulate delay between messages
pause_command = {
"endpoint": "pause",
"data": False,
}
message = InternalMessage(
to="ri_communication_agent",
sender=self.name,
body=json.dumps(pause_command),
)
await self.send(message)
# User interrupt message
data = {
"type": "pause",
"context": False,
}
await self.pub_socket.send_multipart([b"button_pressed", json.dumps(data).encode()])
self.logger.info("Resuming robot actions.")
await asyncio.sleep(15) # Simulate delay between messages

View File

@@ -7,7 +7,6 @@ import zmq.asyncio as azmq
from control_backend.agents import BaseAgent from control_backend.agents import BaseAgent
from control_backend.core.config import settings from control_backend.core.config import settings
from control_backend.schemas.internal_message import InternalMessage
from ...schemas.program_status import PROGRAM_STATUS, ProgramStatus from ...schemas.program_status import PROGRAM_STATUS, ProgramStatus
from .transcription_agent.transcription_agent import TranscriptionAgent from .transcription_agent.transcription_agent import TranscriptionAgent
@@ -87,12 +86,6 @@ class VADAgent(BaseAgent):
self.audio_buffer = np.array([], dtype=np.float32) self.audio_buffer = np.array([], dtype=np.float32)
self.i_since_speech = settings.behaviour_settings.vad_initial_since_speech self.i_since_speech = settings.behaviour_settings.vad_initial_since_speech
self._ready = asyncio.Event() self._ready = asyncio.Event()
# Pause control
self._reset_needed = False
self._paused = asyncio.Event()
self._paused.set() # Not paused at start
self.model = None self.model = None
async def setup(self): async def setup(self):
@@ -220,16 +213,6 @@ class VADAgent(BaseAgent):
""" """
await self._ready.wait() await self._ready.wait()
while self._running: while self._running:
await self._paused.wait()
# After being unpaused, reset stream and buffers
if self._reset_needed:
self.logger.debug("Resuming: resetting stream and buffers.")
await self._reset_stream()
self.audio_buffer = np.array([], dtype=np.float32)
self.i_since_speech = settings.behaviour_settings.vad_initial_since_speech
self._reset_needed = False
assert self.audio_in_poller is not None assert self.audio_in_poller is not None
data = await self.audio_in_poller.poll() data = await self.audio_in_poller.poll()
if data is None: if data is None:
@@ -271,27 +254,3 @@ class VADAgent(BaseAgent):
# At this point, we know that the speech has ended. # At this point, we know that the speech has ended.
# Prepend the last chunk that had no speech, for a more fluent boundary # Prepend the last chunk that had no speech, for a more fluent boundary
self.audio_buffer = chunk self.audio_buffer = chunk
async def handle_message(self, msg: InternalMessage):
"""
Handle incoming messages.
Expects messages to pause or resume the VAD processing from User Interrupt Agent.
:param msg: The received internal message.
"""
sender = msg.sender
if sender == settings.agent_settings.user_interrupt_name:
if msg.body == "PAUSE":
self.logger.info("Pausing VAD processing.")
self._paused.clear()
# If the robot needs to pick up speaking where it left off, do not set _reset_needed
self._reset_needed = True
elif msg.body == "RESUME":
self.logger.info("Resuming VAD processing.")
self._paused.set()
else:
self.logger.warning(f"Unknown command from User Interrupt Agent: {msg.body}")
else:
self.logger.debug(f"Ignoring message from unknown sender: {sender}")

View File

@@ -1,189 +0,0 @@
import json
import zmq
from zmq.asyncio import Context
from control_backend.agents import BaseAgent
from control_backend.core.agent_system import InternalMessage
from control_backend.core.config import settings
from control_backend.schemas.ri_message import (
GestureCommand,
PauseCommand,
RIEndpoint,
SpeechCommand,
)
class UserInterruptAgent(BaseAgent):
"""
User Interrupt Agent.
This agent receives button_pressed events from the external HTTP API
(via ZMQ) and uses the associated context to trigger one of the following actions:
- Send a prioritized message to the `RobotSpeechAgent`
- Send a prioritized gesture to the `RobotGestureAgent`
- Send a belief override to the `BDIProgramManager`in order to activate a
trigger/conditional norm or complete a goal.
Prioritized actions clear the current RI queue before inserting the new item,
ensuring they are executed immediately after Pepper's current action has been fulfilled.
:ivar sub_socket: The ZMQ SUB socket used to receive user intterupts.
"""
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.sub_socket = None
async def _receive_button_event(self):
"""
The behaviour of the UserInterruptAgent.
Continuous loop that receives button_pressed events from the button_pressed HTTP endpoint.
These events contain a type and a context.
These are the different types and contexts:
- type: "speech", context: string that the robot has to say.
- type: "gesture", context: single gesture name that the robot has to perform.
- type: "override", context: belief_id that overrides the goal/trigger/conditional norm.
"""
while True:
topic, body = await self.sub_socket.recv_multipart()
try:
event_data = json.loads(body)
event_type = event_data.get("type") # e.g., "speech", "gesture"
event_context = event_data.get("context") # e.g., "Hello, I am Pepper!"
except json.JSONDecodeError:
self.logger.error("Received invalid JSON payload on topic %s", topic)
continue
if event_type == "speech":
await self._send_to_speech_agent(event_context)
self.logger.info(
"Forwarded button press (speech) with context '%s' to RobotSpeechAgent.",
event_context,
)
elif event_type == "gesture":
await self._send_to_gesture_agent(event_context)
self.logger.info(
"Forwarded button press (gesture) with context '%s' to RobotGestureAgent.",
event_context,
)
elif event_type == "override":
await self._send_to_program_manager(event_context)
self.logger.info(
"Forwarded button press (override) with context '%s' to BDIProgramManager.",
event_context,
)
elif event_type == "pause":
await self._send_pause_command(event_context)
if event_context:
self.logger.info("Sent pause command.")
else:
self.logger.info("Sent resume command.")
else:
self.logger.warning(
"Received button press with unknown type '%s' (context: '%s').",
event_type,
event_context,
)
async def _send_to_speech_agent(self, text_to_say: str):
"""
method to send prioritized speech command to RobotSpeechAgent.
:param text_to_say: The string that the robot has to say.
"""
cmd = SpeechCommand(data=text_to_say, is_priority=True)
out_msg = InternalMessage(
to=settings.agent_settings.robot_speech_name,
sender=self.name,
body=cmd.model_dump_json(),
)
await self.send(out_msg)
async def _send_to_gesture_agent(self, single_gesture_name: str):
"""
method to send prioritized gesture command to RobotGestureAgent.
:param single_gesture_name: The gesture tag that the robot has to perform.
"""
# the endpoint is set to always be GESTURE_SINGLE for user interrupts
cmd = GestureCommand(
endpoint=RIEndpoint.GESTURE_SINGLE, data=single_gesture_name, is_priority=True
)
out_msg = InternalMessage(
to=settings.agent_settings.robot_gesture_name,
sender=self.name,
body=cmd.model_dump_json(),
)
await self.send(out_msg)
async def _send_to_program_manager(self, belief_id: str):
"""
Send a button_override belief to the BDIProgramManager.
:param belief_id: The belief_id that overrides the goal/trigger/conditional norm.
this id can belong to a basic belief or an inferred belief.
See also: https://utrechtuniversity.youtrack.cloud/articles/N25B-A-27/UI-components
"""
data = {"belief": belief_id}
message = InternalMessage(
to=settings.agent_settings.bdi_program_manager_name,
sender=self.name,
body=json.dumps(data),
thread="belief_override_id",
)
await self.send(message)
self.logger.info(
"Sent button_override belief with id '%s' to Program manager.",
belief_id,
)
async def _send_pause_command(self, pause : bool):
"""
Send a pause command to the Robot Interface via the RI Communication Agent.
Send a pause command to the other internal agents; for now just VAD agent.
"""
cmd = PauseCommand(data=pause)
message = InternalMessage(
to=settings.agent_settings.ri_communication_name,
sender=self.name,
body=cmd.model_dump_json(),
)
await self.send(message)
if pause:
# Send pause to VAD agent
vad_message = InternalMessage(
to=settings.agent_settings.vad_name,
sender=self.name,
body="PAUSE",
)
await self.send(vad_message)
self.logger.info("Sent pause command to VAD Agent and RI Communication Agent.")
else:
# Send resume to VAD agent
vad_message = InternalMessage(
to=settings.agent_settings.vad_name,
sender=self.name,
body="RESUME",
)
await self.send(vad_message)
self.logger.info("Sent resume command to VAD Agent and RI Communication Agent.")
async def setup(self):
"""
Initialize the agent.
Connects the internal ZMQ SUB socket and subscribes to the 'button_pressed' topic.
Starts the background behavior to receive the user interrupts.
"""
context = Context.instance()
self.sub_socket = context.socket(zmq.SUB)
self.sub_socket.connect(settings.zmq_settings.internal_sub_address)
self.sub_socket.subscribe("button_pressed")
self.add_behavior(self._receive_button_event())

View File

@@ -1,31 +0,0 @@
import logging
from fastapi import APIRouter, Request
from control_backend.schemas.events import ButtonPressedEvent
logger = logging.getLogger(__name__)
router = APIRouter()
@router.post("/button_pressed", status_code=202)
async def receive_button_event(event: ButtonPressedEvent, request: Request):
"""
Endpoint to handle external button press events.
Validates the event payload and publishes it to the internal 'button_pressed' topic.
Subscribers (in this case user_interrupt_agent) will pick this up to trigger
specific behaviors or state changes.
:param event: The parsed ButtonPressedEvent object.
:param request: The FastAPI request object.
"""
logger.debug("Received button event: %s | %s", event.type, event.context)
topic = b"button_pressed"
body = event.model_dump_json().encode()
pub_socket = request.app.state.endpoints_pub_socket
await pub_socket.send_multipart([topic, body])
return {"status": "Event received"}

View File

@@ -1,6 +1,6 @@
from fastapi.routing import APIRouter from fastapi.routing import APIRouter
from control_backend.api.v1.endpoints import button_pressed, logs, message, program, robot, sse from control_backend.api.v1.endpoints import logs, message, program, robot, sse
api_router = APIRouter() api_router = APIRouter()
@@ -13,5 +13,3 @@ api_router.include_router(robot.router, prefix="/robot", tags=["Pings", "Command
api_router.include_router(logs.router, tags=["Logs"]) api_router.include_router(logs.router, tags=["Logs"])
api_router.include_router(program.router, tags=["Program"]) api_router.include_router(program.router, tags=["Program"])
api_router.include_router(button_pressed.router, tags=["Button Pressed Events"])

View File

@@ -48,7 +48,6 @@ class AgentSettings(BaseModel):
ri_communication_name: str = "ri_communication_agent" ri_communication_name: str = "ri_communication_agent"
robot_speech_name: str = "robot_speech_agent" robot_speech_name: str = "robot_speech_agent"
robot_gesture_name: str = "robot_gesture_agent" robot_gesture_name: str = "robot_gesture_agent"
user_interrupt_name: str = "user_interrupt_agent"
class BehaviourSettings(BaseModel): class BehaviourSettings(BaseModel):
@@ -65,7 +64,6 @@ class BehaviourSettings(BaseModel):
:ivar transcription_words_per_minute: Estimated words per minute for transcription timing. :ivar transcription_words_per_minute: Estimated words per minute for transcription timing.
:ivar transcription_words_per_token: Estimated words per token for transcription timing. :ivar transcription_words_per_token: Estimated words per token for transcription timing.
:ivar transcription_token_buffer: Buffer for transcription tokens. :ivar transcription_token_buffer: Buffer for transcription tokens.
:ivar conversation_history_length_limit: The maximum amount of messages to extract beliefs from.
""" """
sleep_s: float = 1.0 sleep_s: float = 1.0
@@ -83,9 +81,6 @@ class BehaviourSettings(BaseModel):
transcription_words_per_token: float = 0.75 # (3 words = 4 tokens) transcription_words_per_token: float = 0.75 # (3 words = 4 tokens)
transcription_token_buffer: int = 10 transcription_token_buffer: int = 10
# Text belief extractor settings
conversation_history_length_limit: int = 10
class LLMSettings(BaseModel): class LLMSettings(BaseModel):
""" """
@@ -93,17 +88,10 @@ class LLMSettings(BaseModel):
:ivar local_llm_url: URL for the local LLM API. :ivar local_llm_url: URL for the local LLM API.
:ivar local_llm_model: Name of the local LLM model to use. :ivar local_llm_model: Name of the local LLM model to use.
:ivar chat_temperature: The temperature to use while generating chat responses.
:ivar code_temperature: The temperature to use while generating code-like responses like during
belief inference.
:ivar n_parallel: The number of parallel calls allowed to be made to the LLM.
""" """
local_llm_url: str = "http://localhost:1234/v1/chat/completions" local_llm_url: str = "http://localhost:1234/v1/chat/completions"
local_llm_model: str = "gpt-oss" local_llm_model: str = "gpt-oss"
chat_temperature: float = 1.0
code_temperature: float = 0.3
n_parallel: int = 4
class VADSettings(BaseModel): class VADSettings(BaseModel):

View File

@@ -40,10 +40,6 @@ from control_backend.agents.communication import RICommunicationAgent
from control_backend.agents.llm import LLMAgent from control_backend.agents.llm import LLMAgent
# Other backend imports # Other backend imports
from control_backend.agents.mock_agents.test_pause_ri import TestPauseAgent
# User Interrupt Agent
from control_backend.agents.user_interrupt.user_interrupt_agent import UserInterruptAgent
from control_backend.api.v1.router import api_router from control_backend.api.v1.router import api_router
from control_backend.core.config import settings from control_backend.core.config import settings
from control_backend.logging import setup_logging from control_backend.logging import setup_logging
@@ -121,7 +117,6 @@ async def lifespan(app: FastAPI):
BDICoreAgent, BDICoreAgent,
{ {
"name": settings.agent_settings.bdi_core_name, "name": settings.agent_settings.bdi_core_name,
"asl": "src/control_backend/agents/bdi/rules.asl",
}, },
), ),
"BeliefCollectorAgent": ( "BeliefCollectorAgent": (
@@ -142,18 +137,6 @@ async def lifespan(app: FastAPI):
"name": settings.agent_settings.bdi_program_manager_name, "name": settings.agent_settings.bdi_program_manager_name,
}, },
), ),
"TestPauseAgent": (
TestPauseAgent,
{
"name": "pause_test_agent",
},
),
"UserInterruptAgent": (
UserInterruptAgent,
{
"name": settings.agent_settings.user_interrupt_name,
},
),
} }
agents = [] agents = []

View File

@@ -6,27 +6,18 @@ class Belief(BaseModel):
Represents a single belief in the BDI system. Represents a single belief in the BDI system.
:ivar name: The functor or name of the belief (e.g., 'user_said'). :ivar name: The functor or name of the belief (e.g., 'user_said').
:ivar arguments: A list of string arguments for the belief, or None if the belief has no :ivar arguments: A list of string arguments for the belief.
arguments. :ivar replace: If True, existing beliefs with this name should be replaced by this one.
""" """
name: str name: str
arguments: list[str] | None arguments: list[str]
replace: bool = False
class BeliefMessage(BaseModel): class BeliefMessage(BaseModel):
""" """
A container for communicating beliefs between agents. A container for transporting a list of beliefs between agents.
:ivar create: Beliefs to create.
:ivar delete: Beliefs to delete.
:ivar replace: Beliefs to replace. Deletes all beliefs with the same name, replacing them with
one new belief.
""" """
create: list[Belief] = [] beliefs: list[Belief]
delete: list[Belief] = []
replace: list[Belief] = []
def has_values(self) -> bool:
return len(self.create) > 0 or len(self.delete) > 0 or len(self.replace) > 0

View File

@@ -1,10 +0,0 @@
from pydantic import BaseModel
class ChatMessage(BaseModel):
role: str
content: str
class ChatHistory(BaseModel):
messages: list[ChatMessage]

View File

@@ -1,6 +0,0 @@
from pydantic import BaseModel
class ButtonPressedEvent(BaseModel):
type: str
context: str

View File

@@ -14,7 +14,6 @@ class RIEndpoint(str, Enum):
GESTURE_TAG = "actuate/gesture/tag" GESTURE_TAG = "actuate/gesture/tag"
PING = "ping" PING = "ping"
NEGOTIATE_PORTS = "negotiate/ports" NEGOTIATE_PORTS = "negotiate/ports"
PAUSE = "pause"
class RIMessage(BaseModel): class RIMessage(BaseModel):
@@ -39,7 +38,6 @@ class SpeechCommand(RIMessage):
endpoint: RIEndpoint = RIEndpoint(RIEndpoint.SPEECH) endpoint: RIEndpoint = RIEndpoint(RIEndpoint.SPEECH)
data: str data: str
is_priority: bool = False
class GestureCommand(RIMessage): class GestureCommand(RIMessage):
@@ -54,7 +52,6 @@ class GestureCommand(RIMessage):
RIEndpoint.GESTURE_SINGLE, RIEndpoint.GESTURE_TAG RIEndpoint.GESTURE_SINGLE, RIEndpoint.GESTURE_TAG
] ]
data: str data: str
is_priority: bool = False
@model_validator(mode="after") @model_validator(mode="after")
def check_endpoint(self): def check_endpoint(self):
@@ -65,14 +62,3 @@ class GestureCommand(RIMessage):
if self.endpoint not in allowed: if self.endpoint not in allowed:
raise ValueError("endpoint must be GESTURE_SINGLE or GESTURE_TAG") raise ValueError("endpoint must be GESTURE_SINGLE or GESTURE_TAG")
return self return self
class PauseCommand(RIMessage):
"""
A specific command to pause or unpause the robot's actions.
:ivar endpoint: Fixed to ``RIEndpoint.PAUSE``.
:ivar data: A boolean indicating whether to pause (True) or unpause (False).
"""
endpoint: RIEndpoint = RIEndpoint(RIEndpoint.PAUSE)
data: bool

View File

@@ -64,7 +64,7 @@ async def test_handle_message_sends_command():
agent = mock_speech_agent() agent = mock_speech_agent()
agent.pubsocket = pubsocket agent.pubsocket = pubsocket
payload = {"endpoint": "actuate/speech", "data": "hello", "is_priority": False} payload = {"endpoint": "actuate/speech", "data": "hello"}
msg = InternalMessage(to="robot", sender="tester", body=json.dumps(payload)) msg = InternalMessage(to="robot", sender="tester", body=json.dumps(payload))
await agent.handle_message(msg) await agent.handle_message(msg)
@@ -75,7 +75,7 @@ async def test_handle_message_sends_command():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_zmq_command_loop_valid_payload(zmq_context): async def test_zmq_command_loop_valid_payload(zmq_context):
"""UI command is read from SUB and published.""" """UI command is read from SUB and published."""
command = {"endpoint": "actuate/speech", "data": "hello", "is_priority": False} command = {"endpoint": "actuate/speech", "data": "hello"}
fake_socket = AsyncMock() fake_socket = AsyncMock()
async def recv_once(): async def recv_once():

View File

@@ -51,7 +51,7 @@ async def test_handle_belief_collector_message(agent, mock_settings):
msg = InternalMessage( msg = InternalMessage(
to="bdi_agent", to="bdi_agent",
sender=mock_settings.agent_settings.bdi_belief_collector_name, sender=mock_settings.agent_settings.bdi_belief_collector_name,
body=BeliefMessage(create=beliefs).model_dump_json(), body=BeliefMessage(beliefs=beliefs).model_dump_json(),
thread="beliefs", thread="beliefs",
) )
@@ -64,26 +64,6 @@ async def test_handle_belief_collector_message(agent, mock_settings):
assert args[2] == agentspeak.Literal("user_said", (agentspeak.Literal("Hello"),)) assert args[2] == agentspeak.Literal("user_said", (agentspeak.Literal("Hello"),))
@pytest.mark.asyncio
async def test_handle_delete_belief_message(agent, mock_settings):
"""Test that incoming beliefs to be deleted are removed from the BDI agent"""
beliefs = [Belief(name="user_said", arguments=["Hello"])]
msg = InternalMessage(
to="bdi_agent",
sender=mock_settings.agent_settings.bdi_belief_collector_name,
body=BeliefMessage(delete=beliefs).model_dump_json(),
thread="beliefs",
)
await agent.handle_message(msg)
# Expect bdi_agent.call to be triggered to remove belief
args = agent.bdi_agent.call.call_args.args
assert args[0] == agentspeak.Trigger.removal
assert args[1] == agentspeak.GoalType.belief
assert args[2] == agentspeak.Literal("user_said", (agentspeak.Literal("Hello"),))
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_incorrect_belief_collector_message(agent, mock_settings): async def test_incorrect_belief_collector_message(agent, mock_settings):
"""Test that incorrect message format triggers an exception.""" """Test that incorrect message format triggers an exception."""
@@ -148,8 +128,7 @@ def test_add_belief_sets_event(agent):
agent._wake_bdi_loop = MagicMock() agent._wake_bdi_loop = MagicMock()
belief = Belief(name="test_belief", arguments=["a", "b"]) belief = Belief(name="test_belief", arguments=["a", "b"])
belief_changes = BeliefMessage(replace=[belief]) agent._apply_beliefs([belief])
agent._apply_belief_changes(belief_changes)
assert agent.bdi_agent.call.called assert agent.bdi_agent.call.called
agent._wake_bdi_loop.set.assert_called() agent._wake_bdi_loop.set.assert_called()
@@ -158,7 +137,7 @@ def test_add_belief_sets_event(agent):
def test_apply_beliefs_empty_returns(agent): def test_apply_beliefs_empty_returns(agent):
"""Line: if not beliefs: return""" """Line: if not beliefs: return"""
agent._wake_bdi_loop = MagicMock() agent._wake_bdi_loop = MagicMock()
agent._apply_belief_changes(BeliefMessage()) agent._apply_beliefs([])
agent.bdi_agent.call.assert_not_called() agent.bdi_agent.call.assert_not_called()
agent._wake_bdi_loop.set.assert_not_called() agent._wake_bdi_loop.set.assert_not_called()
@@ -241,9 +220,8 @@ def test_replace_belief_calls_remove_all(agent):
agent._remove_all_with_name = MagicMock() agent._remove_all_with_name = MagicMock()
agent._wake_bdi_loop = MagicMock() agent._wake_bdi_loop = MagicMock()
belief = Belief(name="user_said", arguments=["Hello"]) belief = Belief(name="user_said", arguments=["Hello"], replace=True)
belief_changes = BeliefMessage(replace=[belief]) agent._apply_beliefs([belief])
agent._apply_belief_changes(belief_changes)
agent._remove_all_with_name.assert_called_with("user_said") agent._remove_all_with_name.assert_called_with("user_said")

View File

@@ -1,6 +1,6 @@
import asyncio import asyncio
import json
import sys import sys
import uuid
from unittest.mock import AsyncMock from unittest.mock import AsyncMock
import pytest import pytest
@@ -8,52 +8,38 @@ import pytest
from control_backend.agents.bdi.bdi_program_manager import BDIProgramManager from control_backend.agents.bdi.bdi_program_manager import BDIProgramManager
from control_backend.core.agent_system import InternalMessage from control_backend.core.agent_system import InternalMessage
from control_backend.schemas.belief_message import BeliefMessage from control_backend.schemas.belief_message import BeliefMessage
from control_backend.schemas.program import BasicNorm, Goal, Phase, Plan, Program from control_backend.schemas.program import Program
# Fix Windows Proactor loop for zmq # Fix Windows Proactor loop for zmq
if sys.platform.startswith("win"): if sys.platform.startswith("win"):
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
def make_valid_program_json(norm="N1", goal="G1") -> str: def make_valid_program_json(norm="N1", goal="G1"):
return Program( return json.dumps(
phases=[ {
Phase( "phases": [
id=uuid.uuid4(), {
name="Basic Phase", "id": "phase1",
norms=[ "label": "Phase 1",
BasicNorm( "triggers": [],
id=uuid.uuid4(), "norms": [{"id": "n1", "label": "Norm 1", "norm": norm}],
name=norm, "goals": [
norm=norm, {"id": "g1", "label": "Goal 1", "description": goal, "achieved": False}
), ],
], }
goals=[ ]
Goal( }
id=uuid.uuid4(), )
name=goal,
plan=Plan(
id=uuid.uuid4(),
name="Goal Plan",
steps=[],
),
can_fail=False,
),
],
triggers=[],
),
],
).model_dump_json()
@pytest.mark.skip(reason="Functionality being rebuilt.")
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_send_to_bdi(): async def test_send_to_bdi():
manager = BDIProgramManager(name="program_manager_test") manager = BDIProgramManager(name="program_manager_test")
manager.send = AsyncMock() manager.send = AsyncMock()
program = Program.model_validate_json(make_valid_program_json()) program = Program.model_validate_json(make_valid_program_json())
await manager._send_to_bdi(program) await manager._create_agentspeak_and_send_to_bdi(program)
assert manager.send.await_count == 1 assert manager.send.await_count == 1
msg: InternalMessage = manager.send.await_args[0][0] msg: InternalMessage = manager.send.await_args[0][0]
@@ -76,7 +62,7 @@ async def test_receive_programs_valid_and_invalid():
manager = BDIProgramManager(name="program_manager_test") manager = BDIProgramManager(name="program_manager_test")
manager.sub_socket = sub manager.sub_socket = sub
manager._send_to_bdi = AsyncMock() manager._create_agentspeak_and_send_to_bdi = AsyncMock()
try: try:
# Will give StopAsyncIteration when the predefined `sub.recv_multipart` side-effects run out # Will give StopAsyncIteration when the predefined `sub.recv_multipart` side-effects run out
@@ -85,7 +71,7 @@ async def test_receive_programs_valid_and_invalid():
pass pass
# Only valid Program should have triggered _send_to_bdi # Only valid Program should have triggered _send_to_bdi
assert manager._send_to_bdi.await_count == 1 assert manager._create_agentspeak_and_send_to_bdi.await_count == 1
forwarded: Program = manager._send_to_bdi.await_args[0][0] forwarded: Program = manager._create_agentspeak_and_send_to_bdi.await_args[0][0]
assert forwarded.phases[0].norms[0].name == "N1" assert forwarded.phases[0].norms[0].norm == "N1"
assert forwarded.phases[0].goals[0].name == "G1" assert forwarded.phases[0].goals[0].description == "G1"

View File

@@ -86,7 +86,7 @@ async def test_send_beliefs_to_bdi(agent):
sent: InternalMessage = agent.send.call_args.args[0] sent: InternalMessage = agent.send.call_args.args[0]
assert sent.to == settings.agent_settings.bdi_core_name assert sent.to == settings.agent_settings.bdi_core_name
assert sent.thread == "beliefs" assert sent.thread == "beliefs"
assert json.loads(sent.body)["create"] == [belief.model_dump() for belief in beliefs] assert json.loads(sent.body)["beliefs"] == [belief.model_dump() for belief in beliefs]
@pytest.mark.asyncio @pytest.mark.asyncio

View File

@@ -1,346 +0,0 @@
import json
import uuid
from unittest.mock import AsyncMock, MagicMock, patch
import httpx
import pytest
from control_backend.agents.bdi import TextBeliefExtractorAgent
from control_backend.core.agent_system import InternalMessage
from control_backend.core.config import settings
from control_backend.schemas.belief_message import BeliefMessage
from control_backend.schemas.program import (
ConditionalNorm,
LLMAction,
Phase,
Plan,
Program,
SemanticBelief,
Trigger,
)
@pytest.fixture
def agent():
agent = TextBeliefExtractorAgent("text_belief_agent")
agent.send = AsyncMock()
agent._query_llm = AsyncMock()
return agent
@pytest.fixture
def sample_program():
return Program(
phases=[
Phase(
name="Some phase",
id=uuid.uuid4(),
norms=[
ConditionalNorm(
name="Some norm",
id=uuid.uuid4(),
norm="Use nautical terms.",
critical=False,
condition=SemanticBelief(
name="is_pirate",
id=uuid.uuid4(),
description="The user is a pirate. Perhaps because they say "
"they are, or because they speak like a pirate "
'with terms like "arr".',
),
),
],
goals=[],
triggers=[
Trigger(
name="Some trigger",
id=uuid.uuid4(),
condition=SemanticBelief(
name="no_more_booze",
id=uuid.uuid4(),
description="There is no more alcohol.",
),
plan=Plan(
name="Some plan",
id=uuid.uuid4(),
steps=[
LLMAction(
name="Some action",
id=uuid.uuid4(),
goal="Suggest eating chocolate instead.",
),
],
),
),
],
),
],
)
def make_msg(sender: str, body: str, thread: str | None = None) -> InternalMessage:
return InternalMessage(to="unused", sender=sender, body=body, thread=thread)
@pytest.mark.asyncio
async def test_handle_message_ignores_other_agents(agent):
msg = make_msg("unknown", "some data", None)
await agent.handle_message(msg)
agent.send.assert_not_called() # noqa # `agent.send` has no such property, but we mock it.
@pytest.mark.asyncio
async def test_handle_message_from_transcriber(agent, mock_settings):
transcription = "hello world"
msg = make_msg(mock_settings.agent_settings.transcription_name, transcription, None)
await agent.handle_message(msg)
agent.send.assert_awaited_once() # noqa # `agent.send` has no such property, but we mock it.
sent: InternalMessage = agent.send.call_args.args[0] # noqa
assert sent.to == mock_settings.agent_settings.bdi_belief_collector_name
assert sent.thread == "beliefs"
parsed = json.loads(sent.body)
assert parsed == {"beliefs": {"user_said": [transcription]}, "type": "belief_extraction_text"}
@pytest.mark.asyncio
async def test_process_user_said(agent, mock_settings):
transcription = "this is a test"
await agent._user_said(transcription)
agent.send.assert_awaited_once() # noqa # `agent.send` has no such property, but we mock it.
sent: InternalMessage = agent.send.call_args.args[0] # noqa
assert sent.to == mock_settings.agent_settings.bdi_belief_collector_name
assert sent.thread == "beliefs"
parsed = json.loads(sent.body)
assert parsed["beliefs"]["user_said"] == [transcription]
@pytest.mark.asyncio
async def test_query_llm():
mock_response = MagicMock()
mock_response.json.return_value = {
"choices": [
{
"message": {
"content": "null",
}
}
]
}
mock_client = AsyncMock()
mock_client.post.return_value = mock_response
mock_async_client = MagicMock()
mock_async_client.__aenter__.return_value = mock_client
mock_async_client.__aexit__.return_value = None
with patch(
"control_backend.agents.bdi.text_belief_extractor_agent.httpx.AsyncClient",
return_value=mock_async_client,
):
agent = TextBeliefExtractorAgent("text_belief_agent")
res = await agent._query_llm("hello world", {"type": "null"})
# Response content was set as "null", so should be deserialized as None
assert res is None
@pytest.mark.asyncio
async def test_retry_query_llm_success(agent):
agent._query_llm.return_value = None
res = await agent._retry_query_llm("hello world", {"type": "null"})
agent._query_llm.assert_called_once()
assert res is None
@pytest.mark.asyncio
async def test_retry_query_llm_success_after_failure(agent):
agent._query_llm.side_effect = [KeyError(), "real value"]
res = await agent._retry_query_llm("hello world", {"type": "string"})
assert agent._query_llm.call_count == 2
assert res == "real value"
@pytest.mark.asyncio
async def test_retry_query_llm_failures(agent):
agent._query_llm.side_effect = [KeyError(), KeyError(), KeyError(), "real value"]
res = await agent._retry_query_llm("hello world", {"type": "string"})
assert agent._query_llm.call_count == 3
assert res is None
@pytest.mark.asyncio
async def test_retry_query_llm_fail_immediately(agent):
agent._query_llm.side_effect = [KeyError(), "real value"]
res = await agent._retry_query_llm("hello world", {"type": "string"}, tries=1)
assert agent._query_llm.call_count == 1
assert res is None
@pytest.mark.asyncio
async def test_extracting_beliefs_from_program(agent, sample_program):
assert len(agent.available_beliefs) == 0
await agent.handle_message(
InternalMessage(
to=settings.agent_settings.text_belief_extractor_name,
sender=settings.agent_settings.bdi_program_manager_name,
body=sample_program.model_dump_json(),
),
)
assert len(agent.available_beliefs) == 2
@pytest.mark.asyncio
async def test_handle_invalid_program(agent, sample_program):
agent.available_beliefs.append(sample_program.phases[0].norms[0].condition)
agent.available_beliefs.append(sample_program.phases[0].triggers[0].condition)
assert len(agent.available_beliefs) == 2
await agent.handle_message(
InternalMessage(
to=settings.agent_settings.text_belief_extractor_name,
sender=settings.agent_settings.bdi_program_manager_name,
body=json.dumps({"phases": "Invalid"}),
),
)
assert len(agent.available_beliefs) == 2
@pytest.mark.asyncio
async def test_handle_robot_response(agent):
initial_length = len(agent.conversation.messages)
response = "Hi, I'm Pepper. What's your name?"
await agent.handle_message(
InternalMessage(
to=settings.agent_settings.text_belief_extractor_name,
sender=settings.agent_settings.llm_name,
body=response,
),
)
assert len(agent.conversation.messages) == initial_length + 1
assert agent.conversation.messages[-1].role == "assistant"
assert agent.conversation.messages[-1].content == response
@pytest.mark.asyncio
async def test_simulated_real_turn_with_beliefs(agent, sample_program):
"""Test sending user message to extract beliefs from."""
agent.available_beliefs.append(sample_program.phases[0].norms[0].condition)
agent.available_beliefs.append(sample_program.phases[0].triggers[0].condition)
# Send a user message with the belief that there's no more booze
agent._query_llm.return_value = {"is_pirate": None, "no_more_booze": True}
assert len(agent.conversation.messages) == 0
await agent.handle_message(
InternalMessage(
to=settings.agent_settings.text_belief_extractor_name,
sender=settings.agent_settings.transcription_name,
body="We're all out of schnaps.",
),
)
assert len(agent.conversation.messages) == 1
# There should be a belief set and sent to the BDI core, as well as the user_said belief
assert agent.send.call_count == 2
# First should be the beliefs message
message: InternalMessage = agent.send.call_args_list[0].args[0]
beliefs = BeliefMessage.model_validate_json(message.body)
assert len(beliefs.create) == 1
assert beliefs.create[0].name == "no_more_booze"
@pytest.mark.asyncio
async def test_simulated_real_turn_no_beliefs(agent, sample_program):
"""Test a user message to extract beliefs from, but no beliefs are formed."""
agent.available_beliefs.append(sample_program.phases[0].norms[0].condition)
agent.available_beliefs.append(sample_program.phases[0].triggers[0].condition)
# Send a user message with no new beliefs
agent._query_llm.return_value = {"is_pirate": None, "no_more_booze": None}
await agent.handle_message(
InternalMessage(
to=settings.agent_settings.text_belief_extractor_name,
sender=settings.agent_settings.transcription_name,
body="Hello there!",
),
)
# Only the user_said belief should've been sent
agent.send.assert_called_once()
@pytest.mark.asyncio
async def test_simulated_real_turn_no_new_beliefs(agent, sample_program):
"""
Test a user message to extract beliefs from, but no new beliefs are formed because they already
existed.
"""
agent.available_beliefs.append(sample_program.phases[0].norms[0].condition)
agent.available_beliefs.append(sample_program.phases[0].triggers[0].condition)
agent.beliefs["is_pirate"] = True
# Send a user message with the belief the user is a pirate, still
agent._query_llm.return_value = {"is_pirate": True, "no_more_booze": None}
await agent.handle_message(
InternalMessage(
to=settings.agent_settings.text_belief_extractor_name,
sender=settings.agent_settings.transcription_name,
body="Arr, nice to meet you, matey.",
),
)
# Only the user_said belief should've been sent, as no beliefs have changed
agent.send.assert_called_once()
@pytest.mark.asyncio
async def test_simulated_real_turn_remove_belief(agent, sample_program):
"""
Test a user message to extract beliefs from, but an existing belief is determined no longer to
hold.
"""
agent.available_beliefs.append(sample_program.phases[0].norms[0].condition)
agent.available_beliefs.append(sample_program.phases[0].triggers[0].condition)
agent.beliefs["no_more_booze"] = True
# Send a user message with the belief the user is a pirate, still
agent._query_llm.return_value = {"is_pirate": None, "no_more_booze": False}
await agent.handle_message(
InternalMessage(
to=settings.agent_settings.text_belief_extractor_name,
sender=settings.agent_settings.transcription_name,
body="I found an untouched barrel of wine!",
),
)
# Both user_said and belief change should've been sent
assert agent.send.call_count == 2
# Agent's current beliefs should've changed
assert not agent.beliefs["no_more_booze"]
@pytest.mark.asyncio
async def test_llm_failure_handling(agent, sample_program):
"""
Check that the agent handles failures gracefully without crashing.
"""
agent._query_llm.side_effect = httpx.HTTPError("")
agent.available_beliefs.append(sample_program.phases[0].norms[0].condition)
agent.available_beliefs.append(sample_program.phases[0].triggers[0].condition)
belief_changes = await agent._infer_turn()
assert len(belief_changes) == 0

View File

@@ -0,0 +1,65 @@
import json
from unittest.mock import AsyncMock
import pytest
from control_backend.agents.bdi import (
TextBeliefExtractorAgent,
)
from control_backend.core.agent_system import InternalMessage
@pytest.fixture
def agent():
agent = TextBeliefExtractorAgent("text_belief_agent")
agent.send = AsyncMock()
return agent
def make_msg(sender: str, body: str, thread: str | None = None) -> InternalMessage:
return InternalMessage(to="unused", sender=sender, body=body, thread=thread)
@pytest.mark.asyncio
async def test_handle_message_ignores_other_agents(agent):
msg = make_msg("unknown", "some data", None)
await agent.handle_message(msg)
agent.send.assert_not_called() # noqa # `agent.send` has no such property, but we mock it.
@pytest.mark.asyncio
async def test_handle_message_from_transcriber(agent, mock_settings):
transcription = "hello world"
msg = make_msg(mock_settings.agent_settings.transcription_name, transcription, None)
await agent.handle_message(msg)
agent.send.assert_awaited_once() # noqa # `agent.send` has no such property, but we mock it.
sent: InternalMessage = agent.send.call_args.args[0] # noqa
assert sent.to == mock_settings.agent_settings.bdi_belief_collector_name
assert sent.thread == "beliefs"
parsed = json.loads(sent.body)
assert parsed == {"beliefs": {"user_said": [transcription]}, "type": "belief_extraction_text"}
@pytest.mark.asyncio
async def test_process_transcription_demo(agent, mock_settings):
transcription = "this is a test"
await agent._process_transcription_demo(transcription)
agent.send.assert_awaited_once() # noqa # `agent.send` has no such property, but we mock it.
sent: InternalMessage = agent.send.call_args.args[0] # noqa
assert sent.to == mock_settings.agent_settings.bdi_belief_collector_name
assert sent.thread == "beliefs"
parsed = json.loads(sent.body)
assert parsed["beliefs"]["user_said"] == [transcription]
@pytest.mark.asyncio
async def test_setup_initializes_beliefs(agent):
"""Covers the setup method and ensures beliefs are initialized."""
await agent.setup()
assert agent.beliefs == {"mood": ["X"], "car": ["Y"]}

View File

@@ -67,7 +67,6 @@ async def test_setup_success_connects_and_starts_robot(zmq_context):
address="tcp://localhost:5556", address="tcp://localhost:5556",
bind=False, bind=False,
gesture_data=[], gesture_data=[],
single_gesture_data=[],
) )
agent.add_behavior.assert_called_once() agent.add_behavior.assert_called_once()

View File

@@ -66,7 +66,7 @@ async def test_llm_processing_success(mock_httpx_client, mock_settings):
# "Hello world." constitutes one sentence/chunk based on punctuation split # "Hello world." constitutes one sentence/chunk based on punctuation split
# The agent should call send once with the full sentence # The agent should call send once with the full sentence
assert agent.send.called assert agent.send.called
args = agent.send.call_args_list[0][0][0] args = agent.send.call_args[0][0]
assert args.to == mock_settings.agent_settings.bdi_core_name assert args.to == mock_settings.agent_settings.bdi_core_name
assert "Hello world." in args.body assert "Hello world." in args.body
@@ -197,9 +197,6 @@ async def test_query_llm_yields_final_tail_chunk(mock_settings):
agent = LLMAgent("llm_agent") agent = LLMAgent("llm_agent")
agent.send = AsyncMock() agent.send = AsyncMock()
agent.logger = MagicMock()
agent.logger.llm = MagicMock()
# Patch _stream_query_llm to yield tokens that do NOT end with punctuation # Patch _stream_query_llm to yield tokens that do NOT end with punctuation
async def fake_stream(messages): async def fake_stream(messages):
yield "Hello" yield "Hello"

View File

@@ -1,146 +0,0 @@
import asyncio
import json
from unittest.mock import AsyncMock, MagicMock
import pytest
from control_backend.agents.user_interrupt.user_interrupt_agent import UserInterruptAgent
from control_backend.core.agent_system import InternalMessage
from control_backend.core.config import settings
from control_backend.schemas.ri_message import RIEndpoint
@pytest.fixture
def agent():
agent = UserInterruptAgent(name="user_interrupt_agent")
agent.send = AsyncMock()
agent.logger = MagicMock()
agent.sub_socket = AsyncMock()
return agent
@pytest.mark.asyncio
async def test_send_to_speech_agent(agent):
"""Verify speech command format."""
await agent._send_to_speech_agent("Hello World")
agent.send.assert_awaited_once()
sent_msg: InternalMessage = agent.send.call_args.args[0]
assert sent_msg.to == settings.agent_settings.robot_speech_name
body = json.loads(sent_msg.body)
assert body["data"] == "Hello World"
assert body["is_priority"] is True
@pytest.mark.asyncio
async def test_send_to_gesture_agent(agent):
"""Verify gesture command format."""
await agent._send_to_gesture_agent("wave_hand")
agent.send.assert_awaited_once()
sent_msg: InternalMessage = agent.send.call_args.args[0]
assert sent_msg.to == settings.agent_settings.robot_gesture_name
body = json.loads(sent_msg.body)
assert body["data"] == "wave_hand"
assert body["is_priority"] is True
assert body["endpoint"] == RIEndpoint.GESTURE_SINGLE.value
@pytest.mark.asyncio
async def test_send_to_program_manager(agent):
"""Verify belief update format."""
context_str = "2"
await agent._send_to_program_manager(context_str)
agent.send.assert_awaited_once()
sent_msg: InternalMessage = agent.send.call_args.args[0]
assert sent_msg.to == settings.agent_settings.bdi_program_manager_name
assert sent_msg.thread == "belief_override_id"
body = json.loads(sent_msg.body)
assert body["belief"] == context_str
@pytest.mark.asyncio
async def test_receive_loop_routing_success(agent):
"""
Test that the loop correctly:
1. Receives 'button_pressed' topic from ZMQ
2. Parses the JSON payload to find 'type' and 'context'
3. Calls the correct handler method based on 'type'
"""
# Prepare JSON payloads as bytes
payload_speech = json.dumps({"type": "speech", "context": "Hello Speech"}).encode()
payload_gesture = json.dumps({"type": "gesture", "context": "Hello Gesture"}).encode()
payload_override = json.dumps({"type": "override", "context": "Hello Override"}).encode()
agent.sub_socket.recv_multipart.side_effect = [
(b"button_pressed", payload_speech),
(b"button_pressed", payload_gesture),
(b"button_pressed", payload_override),
asyncio.CancelledError, # Stop the infinite loop
]
agent._send_to_speech_agent = AsyncMock()
agent._send_to_gesture_agent = AsyncMock()
agent._send_to_program_manager = AsyncMock()
try:
await agent._receive_button_event()
except asyncio.CancelledError:
pass
await asyncio.sleep(0)
# Speech
agent._send_to_speech_agent.assert_awaited_once_with("Hello Speech")
# Gesture
agent._send_to_gesture_agent.assert_awaited_once_with("Hello Gesture")
# Override
agent._send_to_program_manager.assert_awaited_once_with("Hello Override")
assert agent._send_to_speech_agent.await_count == 1
assert agent._send_to_gesture_agent.await_count == 1
assert agent._send_to_program_manager.await_count == 1
@pytest.mark.asyncio
async def test_receive_loop_unknown_type(agent):
"""Test that unknown 'type' values in the JSON log a warning and do not crash."""
# Prepare a payload with an unknown type
payload_unknown = json.dumps({"type": "unknown_thing", "context": "some_data"}).encode()
agent.sub_socket.recv_multipart.side_effect = [
(b"button_pressed", payload_unknown),
asyncio.CancelledError,
]
agent._send_to_speech_agent = AsyncMock()
agent._send_to_gesture_agent = AsyncMock()
agent._send_to_belief_collector = AsyncMock()
try:
await agent._receive_button_event()
except asyncio.CancelledError:
pass
await asyncio.sleep(0)
# Ensure no handlers were called
agent._send_to_speech_agent.assert_not_called()
agent._send_to_gesture_agent.assert_not_called()
agent._send_to_belief_collector.assert_not_called()
agent.logger.warning.assert_called_with(
"Received button press with unknown type '%s' (context: '%s').",
"unknown_thing",
"some_data",
)

View File

@@ -1,5 +1,4 @@
import json import json
import uuid
from unittest.mock import AsyncMock from unittest.mock import AsyncMock
import pytest import pytest
@@ -7,7 +6,7 @@ from fastapi import FastAPI
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
from control_backend.api.v1.endpoints import program from control_backend.api.v1.endpoints import program
from control_backend.schemas.program import BasicNorm, Goal, Phase, Plan, Program from control_backend.schemas.program import Program
@pytest.fixture @pytest.fixture
@@ -26,37 +25,29 @@ def client(app):
def make_valid_program_dict(): def make_valid_program_dict():
"""Helper to create a valid Program JSON structure.""" """Helper to create a valid Program JSON structure."""
# Converting to JSON using Pydantic because it knows how to convert a UUID object return {
program_json_str = Program( "phases": [
phases=[ {
Phase( "id": "phase1",
id=uuid.uuid4(), "label": "basephase",
name="Basic Phase", "norms": [{"id": "n1", "label": "norm", "norm": "be nice"}],
norms=[ "goals": [
BasicNorm( {"id": "g1", "label": "goal", "description": "test goal", "achieved": False}
id=uuid.uuid4(),
name="Some norm",
norm="Do normal.",
),
], ],
goals=[ "triggers": [
Goal( {
id=uuid.uuid4(), "id": "t1",
name="Some goal", "label": "trigger",
plan=Plan( "type": "keywords",
id=uuid.uuid4(), "keywords": [
name="Goal Plan", {"id": "kw1", "keyword": "keyword1"},
steps=[], {"id": "kw2", "keyword": "keyword2"},
), ],
can_fail=False, },
),
], ],
triggers=[], }
), ]
], }
).model_dump_json()
# Converting back to a dict because that's what's expected
return json.loads(program_json_str)
def test_receive_program_success(client): def test_receive_program_success(client):
@@ -80,8 +71,7 @@ def test_receive_program_success(client):
sent_bytes = args[0][1] sent_bytes = args[0][1]
sent_obj = json.loads(sent_bytes.decode()) sent_obj = json.loads(sent_bytes.decode())
# Converting to JSON using Pydantic because it knows how to handle UUIDs expected_obj = Program.model_validate(program_dict).model_dump()
expected_obj = json.loads(Program.model_validate(program_dict).model_dump_json())
assert sent_obj == expected_obj assert sent_obj == expected_obj

View File

@@ -1,65 +1,49 @@
import uuid
import pytest import pytest
from pydantic import ValidationError from pydantic import ValidationError
from control_backend.schemas.program import ( from control_backend.schemas.program import (
BasicNorm,
ConditionalNorm,
Goal, Goal,
InferredBelief, KeywordTrigger,
KeywordBelief, Norm,
LogicalOperator,
Phase, Phase,
Plan,
Program, Program,
SemanticBelief, TriggerKeyword,
Trigger,
) )
def base_norm() -> BasicNorm: def base_norm() -> Norm:
return BasicNorm( return Norm(
id=uuid.uuid4(), id="norm1",
name="testNormName", label="testNorm",
norm="testNormNorm", norm="testNormNorm",
critical=False,
) )
def base_goal() -> Goal: def base_goal() -> Goal:
return Goal( return Goal(
id=uuid.uuid4(), id="goal1",
name="testGoalName", label="testGoal",
plan=Plan( description="testGoalDescription",
id=uuid.uuid4(), achieved=False,
name="testGoalPlanName",
steps=[],
),
can_fail=False,
) )
def base_trigger() -> Trigger: def base_trigger() -> KeywordTrigger:
return Trigger( return KeywordTrigger(
id=uuid.uuid4(), id="trigger1",
name="testTriggerName", label="testTrigger",
condition=KeywordBelief( type="keywords",
id=uuid.uuid4(), keywords=[
name="testTriggerKeywordBeliefTriggerName", TriggerKeyword(id="keyword1", keyword="testKeyword1"),
keyword="Keyword", TriggerKeyword(id="keyword1", keyword="testKeyword2"),
), ],
plan=Plan(
id=uuid.uuid4(),
name="testTriggerPlanName",
steps=[],
),
) )
def base_phase() -> Phase: def base_phase() -> Phase:
return Phase( return Phase(
id=uuid.uuid4(), id="phase1",
label="basephase",
norms=[base_norm()], norms=[base_norm()],
goals=[base_goal()], goals=[base_goal()],
triggers=[base_trigger()], triggers=[base_trigger()],
@@ -74,7 +58,7 @@ def invalid_program() -> dict:
# wrong types inside phases list (not Phase objects) # wrong types inside phases list (not Phase objects)
return { return {
"phases": [ "phases": [
{"id": uuid.uuid4()}, # incomplete {"id": "phase1"}, # incomplete
{"not_a_phase": True}, {"not_a_phase": True},
] ]
} }
@@ -93,112 +77,11 @@ def test_valid_deepprogram():
# validate nested components directly # validate nested components directly
phase = validated.phases[0] phase = validated.phases[0]
assert isinstance(phase.goals[0], Goal) assert isinstance(phase.goals[0], Goal)
assert isinstance(phase.triggers[0], Trigger) assert isinstance(phase.triggers[0], KeywordTrigger)
assert isinstance(phase.norms[0], BasicNorm) assert isinstance(phase.norms[0], Norm)
def test_invalid_program(): def test_invalid_program():
bad = invalid_program() bad = invalid_program()
with pytest.raises(ValidationError): with pytest.raises(ValidationError):
Program.model_validate(bad) Program.model_validate(bad)
def test_conditional_norm_parsing():
"""
Check that pydantic is able to preserve the type of the norm, that it doesn't lose its
"condition" field when serializing and deserializing.
"""
norm = ConditionalNorm(
name="testNormName",
id=uuid.uuid4(),
norm="testNormNorm",
critical=False,
condition=KeywordBelief(
name="testKeywordBelief",
id=uuid.uuid4(),
keyword="testKeywordBelief",
),
)
program = Program(
phases=[
Phase(
name="Some phase",
id=uuid.uuid4(),
norms=[norm],
goals=[],
triggers=[],
),
],
)
parsed_program = Program.model_validate_json(program.model_dump_json())
parsed_norm = parsed_program.phases[0].norms[0]
assert hasattr(parsed_norm, "condition")
assert isinstance(parsed_norm, ConditionalNorm)
def test_belief_type_parsing():
"""
Check that pydantic is able to discern between the different types of beliefs when serializing
and deserializing.
"""
keyword_belief = KeywordBelief(
name="testKeywordBelief",
id=uuid.uuid4(),
keyword="something",
)
semantic_belief = SemanticBelief(
name="testSemanticBelief",
id=uuid.uuid4(),
description="something",
)
inferred_belief = InferredBelief(
name="testInferredBelief",
id=uuid.uuid4(),
operator=LogicalOperator.OR,
left=keyword_belief,
right=semantic_belief,
)
program = Program(
phases=[
Phase(
name="Some phase",
id=uuid.uuid4(),
norms=[],
goals=[],
triggers=[
Trigger(
name="testTriggerKeywordTrigger",
id=uuid.uuid4(),
condition=keyword_belief,
plan=Plan(name="testTriggerPlanName", id=uuid.uuid4(), steps=[]),
),
Trigger(
name="testTriggerSemanticTrigger",
id=uuid.uuid4(),
condition=semantic_belief,
plan=Plan(name="testTriggerPlanName", id=uuid.uuid4(), steps=[]),
),
Trigger(
name="testTriggerInferredTrigger",
id=uuid.uuid4(),
condition=inferred_belief,
plan=Plan(name="testTriggerPlanName", id=uuid.uuid4(), steps=[]),
),
],
),
],
)
parsed_program = Program.model_validate_json(program.model_dump_json())
parsed_keyword_belief = parsed_program.phases[0].triggers[0].condition
assert isinstance(parsed_keyword_belief, KeywordBelief)
parsed_semantic_belief = parsed_program.phases[0].triggers[1].condition
assert isinstance(parsed_semantic_belief, SemanticBelief)
parsed_inferred_belief = parsed_program.phases[0].triggers[2].condition
assert isinstance(parsed_inferred_belief, InferredBelief)