Compare commits

...

59 Commits

Author SHA1 Message Date
8506c0d9ef chore: remove belief collector and small tweaks 2026-01-16 15:07:44 +01:00
b1c18abffd test: bunch of tests
Written with AI, still need to check them

ref: N25B-449
2026-01-16 13:11:41 +01:00
39e1bb1ead fix: sync issues
ref: N25B-447
2026-01-14 15:28:29 +01:00
8f6662e64a feat: phase transitions
ref: N25B-446
2026-01-14 13:22:51 +01:00
0794c549a8 chore: remove agentspeak file from tracking 2026-01-14 11:27:29 +01:00
ff24ab7a27 fix: default behavior and end phase
ref: N25B-448
2026-01-14 11:24:19 +01:00
43ac8ad69f chore: delete outdated files
ref: N25B-446
2026-01-14 10:58:41 +01:00
Twirre Meulenbelt
f7669c021b feat: support force completed goals in semantic belief agent
ref: N25B-427
2026-01-13 17:04:44 +01:00
Björn Otgaar
8f52f8bf0c Merge branch 'feat/monitoringpage-cb' of git.science.uu.nl:ics/sp/2025/n25b/pepperplus-cb into feat/monitoringpage-cb 2026-01-13 14:03:40 +01:00
Björn Otgaar
2a94a45b34 chore: adjust 'phase_id' to 'id' for correct payload 2026-01-13 14:03:37 +01:00
f87651f691 fix: achieved goal in bdi core
ref: N25B-400
2026-01-13 12:26:18 +01:00
Pim Hutting
65e0b2d250 feat: added correct message
ref: N25B-400
2026-01-13 12:05:38 +01:00
177e844349 feat: send achieved goal from interrupt->manager->semantic
ref: N25B-400
2026-01-13 11:46:17 +01:00
Pim Hutting
0df6040444 feat: added sending goal overwrites in Userinter.
ref: N25B-400
2026-01-13 11:26:03 +01:00
Twirre Meulenbelt
af81bd8620 Merge branch 'feat/multiple-receivers' into feat/monitoringpage-cb
# Conflicts:
#	src/control_backend/core/agent_system.py
#	src/control_backend/schemas/internal_message.py
2026-01-13 11:14:18 +01:00
Twirre Meulenbelt
70e05b6c92 test: sending to multiple agents, including remote
ref: N25B-441
2026-01-13 11:10:35 +01:00
c0b8fb8612 feat: able to send to multiple receivers
ref: N25B-441
2026-01-13 11:06:42 +01:00
Pim Hutting
d499111ea4 feat: added pause functionality
Storms code wasnt fully included in Bjorns branch

ref: N25B-400
2026-01-13 00:52:04 +01:00
Pim Hutting
72c2c57f26 chore: merged button functionality and fix bug
merged björns branch that has the following button functionality
-Pause/resume
-Next phase
-Restart phase
-reset experiment
fix bug where norms where not properly sent to the user interrupt agent

ref: N25B-400
2026-01-12 19:31:50 +01:00
Pim Hutting
4a014b577a Merge remote-tracking branch 'origin/feat/reset-skip-buttons' into feat/monitoringpage-cb 2026-01-12 19:19:31 +01:00
Pim Hutting
c45a258b22 fix: fixed a bug where norms where not updated
Now in UserInterruptAgent we store the norm.norm and not the slugified norm

ref: N25B-400
2026-01-12 19:07:05 +01:00
0f09276477 fix: send norms back to UI
ref: N25B-400
2026-01-12 17:02:39 +01:00
4e113c2d5c fix: default plan and norm force
ref: N25B-400
2026-01-12 16:20:24 +01:00
Pim Hutting
54c835cc0f feat: added force_norm handling in BDI core agent
ref: N25B-400
2026-01-12 15:37:04 +01:00
Pim Hutting
c4ccbcd354 Merge remote-tracking branch 'origin/feat/extra-agentspeak-functionality' into feat/monitoringpage-cb 2026-01-12 15:24:48 +01:00
Pim Hutting
d202abcd1b fix: phases update correctly
there was a bug where phases would not update without restarting cb

ref: N25B-400
2026-01-12 12:51:24 +01:00
Twirre Meulenbelt
4b71981a3e fix: some bugs and some tests
ref: N25B-429
2026-01-12 09:00:50 +01:00
866d7c4958 fix: end phase loop correctly notifies about user_said
ref: N25B-429
2026-01-08 15:13:12 +01:00
Pim Hutting
5e2126fc21 chore: code cleanup
ref: N25B-400
2026-01-08 15:05:43 +01:00
Pim Hutting
500bbc2d82 feat: added goal start sending functionality
ref: N25B-400
2026-01-08 14:52:55 +01:00
133019a928 feat: trigger name and trigger checks on belief update
ref: N25B-429
2026-01-08 14:04:44 +01:00
4d0ba69443 fix: don't re-add user_said upon phase transition
ref: N25B-429
2026-01-08 13:44:25 +01:00
625ef0c365 feat: phase transition waits for all goals
ref: N25B-429
2026-01-08 13:36:03 +01:00
b88758fa76 feat: phase transition independent of response
ref: N25B-429
2026-01-08 13:33:37 +01:00
Pim Hutting
3a8d1730a1 fix: made mapping for conditional norms only
ref: N25B-400
2026-01-08 12:29:16 +01:00
Pim Hutting
b27e5180c4 feat: small implementation change
ref: N25B-400
2026-01-08 11:25:53 +01:00
Pim Hutting
6b34f4b82c fix: small bugfix
ref: N25B-400
2026-01-08 10:59:24 +01:00
Twirre Meulenbelt
45719c580b feat: prepend more silence before speech audio for better transcription beginnings
ref: N25B-429
2026-01-08 10:49:13 +01:00
Pim Hutting
4bf2be6359 feat: added a functionality for monitoring page
ref: N25B-400
2026-01-08 09:56:10 +01:00
Pim Hutting
20e5e46639 Merge remote-tracking branch 'origin/feat/extra-agentspeak-functionality' into feat/monitoringpage-cb 2026-01-07 22:42:40 +01:00
Pim Hutting
365d449666 feat: commit before I can merge new changes
ref: N25B-400
2026-01-07 22:41:59 +01:00
5a61225c6f feat: reset extractor history
ref: N25B-429
2026-01-07 18:10:13 +01:00
a30cea5231 Merge branch 'feat/semantic-beliefs' into feat/extra-agentspeak-functionality 2026-01-07 17:51:30 +01:00
Twirre Meulenbelt
93d67ccb66 feat: add reset functionality to semantic belief extractor
ref: N25B-432
2026-01-07 17:50:47 +01:00
240624f887 Merge branch 'dev' into feat/extra-agentspeak-functionality
# Conflicts:
#	src/control_backend/agents/bdi/bdi_program_manager.py
#	src/control_backend/agents/llm/llm_agent.py
#	test/unit/agents/bdi/test_bdi_program_manager.py
2026-01-07 17:46:48 +01:00
Pim Hutting
be6bbbb849 feat: added endpoint userinterrupt to userinterrupt
ref: N25B-400
2026-01-07 17:42:54 +01:00
8a77e8e1c7 feat: check goals only for this phase
Since conversation history still remains we can still check at a later point.

ref: N25B-429
2026-01-07 17:31:24 +01:00
3b4dccc760 Merge branch 'feat/semantic-beliefs' into feat/extra-agentspeak-functionality
# Conflicts:
#	src/control_backend/agents/bdi/bdi_program_manager.py
2026-01-07 17:20:52 +01:00
3d49e44cf7 fix: complete pipeline working
User interrupts still need to be tested.

ref: N25B-429
2026-01-07 17:13:58 +01:00
Twirre Meulenbelt
aa5b386f65 feat: semantically determine goal completion
ref: N25B-432
2026-01-07 17:08:23 +01:00
Twirre Meulenbelt
3189b9fee3 fix: let belief extractor send user_said belief
ref: N25B-429
2026-01-07 15:19:23 +01:00
Björn Otgaar
612a96940d Merge branch 'feat/environment-variables' into 'dev'
Docs for environment variables, parameterize some constants

See merge request ics/sp/2025/n25b/pepperplus-cb!38
2026-01-06 09:02:49 +00:00
Pim Hutting
4c20656c75 Merge branch 'feat/program-reset-llm' into 'dev'
feat: made program reset LLM

See merge request ics/sp/2025/n25b/pepperplus-cb!39
2026-01-02 15:13:05 +00:00
Pim Hutting
6ca86e4b81 feat: made program reset LLM 2026-01-02 15:13:04 +00:00
Twirre Meulenbelt
7d798f2e77 Merge remote-tracking branch 'origin/dev' into feat/environment-variables
# Conflicts:
#	src/control_backend/core/config.py
#	test/unit/agents/actuation/test_robot_speech_agent.py
2025-12-29 12:40:16 +01:00
Twirre Meulenbelt
5282c2471f Merge remote-tracking branch 'origin/dev' into feat/environment-variables
# Conflicts:
#	src/control_backend/core/config.py
#	test/unit/agents/actuation/test_robot_speech_agent.py
2025-12-29 12:35:39 +01:00
Twirre Meulenbelt
0c682d6440 feat: introduce .env.example, docs
The example includes options that are expected to be changed. It also includes a reference to where in the docs you can find a full list of options.

ref: N25B-352
2025-12-11 13:35:19 +01:00
Twirre Meulenbelt
32d8f20dc9 feat: parameterize RI host
Was "localhost" in RI Communication Agent, now uses configurable setting. Secretly also removing "localhost" from VAD agent, as its socket should be something that's "inproc".

ref: N25B-352
2025-12-11 12:12:15 +01:00
Twirre Meulenbelt
9cc0e39955 fix: failures main tests since VAD agent initialization was changed
The test still expects the VAD agent to be started in main, rather than in the RI Communication Agent.

ref: N25B-356
2025-12-11 12:04:24 +01:00
50 changed files with 3368 additions and 1511 deletions

20
.env.example Normal file
View File

@@ -0,0 +1,20 @@
# Example .env file. To use, make a copy, call it ".env" (i.e. removing the ".example" suffix), then you edit values.
# The hostname of the Robot Interface. Change if the Control Backend and Robot Interface are running on different computers.
RI_HOST="localhost"
# URL for the local LLM API. Must be an API that implements the OpenAI Chat Completions API, but most do.
LLM_SETTINGS__LOCAL_LLM_URL="http://localhost:1234/v1/chat/completions"
# Name of the local LLM model to use.
LLM_SETTINGS__LOCAL_LLM_MODEL="gpt-oss"
# Number of non-speech chunks to wait before speech ended. A chunk is approximately 31 ms. Increasing this number allows longer pauses in speech, but also increases response time.
BEHAVIOUR_SETTINGS__VAD_NON_SPEECH_PATIENCE_CHUNKS=15
# Timeout in milliseconds for socket polling. Increase this number if network latency/jitter is high, often the case when using Wi-Fi. Perhaps 500 ms. A symptom of this issue is transcriptions getting cut off.
BEHAVIOUR_SETTINGS__SOCKET_POLLER_TIMEOUT_MS=100
# For an exhaustive list of options, see the control_backend.core.config module in the docs.

2
.gitignore vendored
View File

@@ -222,6 +222,8 @@ __marimo__/
docs/*
!docs/conf.py
# Generated files
agentspeak.asl

View File

@@ -27,6 +27,7 @@ This + part might differ based on what model you choose.
copy the model name in the module loaded and replace local_llm_modelL. In settings.
## Running
To run the project (development server), execute the following command (while inside the root repository):
@@ -34,6 +35,14 @@ To run the project (development server), execute the following command (while in
uv run fastapi dev src/control_backend/main.py
```
### Environment Variables
You can use environment variables to change settings. Make a copy of the [`.env.example`](.env.example) file, name it `.env` and put it in the root directory. The file itself describes how to do the configuration.
For an exhaustive list of environment options, see the `control_backend.core.config` module in the docs.
## Testing
Testing happens automatically when opening a merge request to any branch. If you want to manually run the test suite, you can do so by running the following for unit tests:

View File

@@ -33,7 +33,7 @@ class RobotGestureAgent(BaseAgent):
def __init__(
self,
name: str,
address=settings.zmq_settings.ri_command_address,
address: str,
bind=False,
gesture_data=None,
single_gesture_data=None,

View File

@@ -1,8 +1,5 @@
from control_backend.agents.bdi.bdi_core_agent import BDICoreAgent as BDICoreAgent
from .belief_collector_agent import (
BDIBeliefCollectorAgent as BDIBeliefCollectorAgent,
)
from .text_belief_extractor_agent import (
TextBeliefExtractorAgent as TextBeliefExtractorAgent,
)

View File

@@ -77,7 +77,7 @@ class AstTerm(AstExpression, ABC):
return AstBinaryOp(self, BinaryOperatorType.NOT_EQUALS, _coalesce_expr(other))
@dataclass
@dataclass(eq=False)
class AstAtom(AstTerm):
"""
Grounded expression in all lowercase.
@@ -89,7 +89,7 @@ class AstAtom(AstTerm):
return self.value.lower()
@dataclass
@dataclass(eq=False)
class AstVar(AstTerm):
"""
Ungrounded variable expression. First letter capitalized.
@@ -101,7 +101,7 @@ class AstVar(AstTerm):
return self.name.capitalize()
@dataclass
@dataclass(eq=False)
class AstNumber(AstTerm):
value: int | float
@@ -109,7 +109,7 @@ class AstNumber(AstTerm):
return str(self.value)
@dataclass
@dataclass(eq=False)
class AstString(AstTerm):
value: str
@@ -117,7 +117,7 @@ class AstString(AstTerm):
return f'"{self.value}"'
@dataclass
@dataclass(eq=False)
class AstLiteral(AstTerm):
functor: str
terms: list[AstTerm] = field(default_factory=list)

View File

@@ -3,9 +3,11 @@ from functools import singledispatchmethod
from slugify import slugify
from control_backend.agents.bdi.agentspeak_ast import (
AstAtom,
AstBinaryOp,
AstExpression,
AstLiteral,
AstNumber,
AstPlan,
AstProgram,
AstRule,
@@ -17,6 +19,7 @@ from control_backend.agents.bdi.agentspeak_ast import (
TriggerType,
)
from control_backend.schemas.program import (
BaseGoal,
BasicNorm,
ConditionalNorm,
GestureAction,
@@ -42,7 +45,13 @@ class AgentSpeakGenerator:
def generate(self, program: Program) -> str:
self._asp = AstProgram()
self._asp.rules.append(AstRule(self._astify(program.phases[0])))
if program.phases:
self._asp.rules.append(AstRule(self._astify(program.phases[0])))
else:
self._asp.rules.append(AstRule(AstLiteral("phase", [AstString("end")])))
self._asp.rules.append(AstRule(AstLiteral("!notify_cycle")))
self._add_keyword_inference()
self._add_default_plans()
@@ -70,6 +79,7 @@ class AgentSpeakGenerator:
self._add_reply_with_goal_plan()
self._add_say_plan()
self._add_reply_plan()
self._add_notify_cycle_plan()
def _add_reply_with_goal_plan(self):
self._asp.plans.append(
@@ -132,6 +142,29 @@ class AgentSpeakGenerator:
)
)
def _add_notify_cycle_plan(self):
self._asp.plans.append(
AstPlan(
TriggerType.ADDED_GOAL,
AstLiteral("notify_cycle"),
[],
[
AstStatement(
StatementType.DO_ACTION,
AstLiteral(
"findall",
[AstVar("Norm"), AstLiteral("norm", [AstVar("Norm")]), AstVar("Norms")],
),
),
AstStatement(
StatementType.DO_ACTION, AstLiteral("notify_norms", [AstVar("Norms")])
),
AstStatement(StatementType.DO_ACTION, AstLiteral("wait", [AstNumber(100)])),
AstStatement(StatementType.ACHIEVE_GOAL, AstLiteral("notify_cycle")),
],
)
)
def _process_phases(self, phases: list[Phase]) -> None:
for curr_phase, next_phase in zip([None] + phases, phases + [None], strict=True):
if curr_phase:
@@ -145,7 +178,12 @@ class AgentSpeakGenerator:
type=TriggerType.ADDED_BELIEF,
trigger_literal=AstLiteral("user_said", [AstVar("Message")]),
context=[AstLiteral("phase", [AstString("end")])],
body=[AstStatement(StatementType.ACHIEVE_GOAL, AstLiteral("reply"))],
body=[
AstStatement(
StatementType.DO_ACTION, AstLiteral("notify_user_said", [AstVar("Message")])
),
AstStatement(StatementType.ACHIEVE_GOAL, AstLiteral("reply")),
],
)
)
@@ -157,7 +195,7 @@ class AgentSpeakGenerator:
previous_goal = None
for goal in phase.goals:
self._process_goal(goal, phase, previous_goal)
self._process_goal(goal, phase, previous_goal, main_goal=True)
previous_goal = goal
for trigger in phase.triggers:
@@ -171,29 +209,57 @@ class AgentSpeakGenerator:
self._astify(to_phase) if to_phase else AstLiteral("phase", [AstString("end")])
)
context = [from_phase_ast, ~AstLiteral("responded_this_turn")]
if from_phase and from_phase.goals:
context.append(self._astify(from_phase.goals[-1], achieved=True))
check_context = [from_phase_ast]
if from_phase:
for goal in from_phase.goals:
check_context.append(self._astify(goal, achieved=True))
force_context = [from_phase_ast]
body = [
AstStatement(
StatementType.DO_ACTION,
AstLiteral(
"notify_transition_phase",
[
AstString(str(from_phase.id)),
AstString(str(to_phase.id) if to_phase else "end"),
],
),
),
AstStatement(StatementType.REMOVE_BELIEF, from_phase_ast),
AstStatement(StatementType.ADD_BELIEF, to_phase_ast),
]
if from_phase:
body.extend(
[
AstStatement(
StatementType.TEST_GOAL, AstLiteral("user_said", [AstVar("Message")])
),
AstStatement(
StatementType.REPLACE_BELIEF, AstLiteral("user_said", [AstVar("Message")])
),
]
)
# if from_phase:
# body.extend(
# [
# AstStatement(
# StatementType.TEST_GOAL, AstLiteral("user_said", [AstVar("Message")])
# ),
# AstStatement(
# StatementType.REPLACE_BELIEF, AstLiteral("user_said", [AstVar("Message")])
# ),
# ]
# )
# Check
self._asp.plans.append(
AstPlan(TriggerType.ADDED_GOAL, AstLiteral("transition_phase"), context, body)
AstPlan(
TriggerType.ADDED_GOAL,
AstLiteral("transition_phase"),
check_context,
[
AstStatement(StatementType.ACHIEVE_GOAL, AstLiteral("force_transition_phase")),
],
)
)
# Force
self._asp.plans.append(
AstPlan(
TriggerType.ADDED_GOAL, AstLiteral("force_transition_phase"), force_context, body
)
)
def _process_norm(self, norm: Norm, phase: Phase) -> None:
@@ -201,7 +267,11 @@ class AgentSpeakGenerator:
match norm:
case ConditionalNorm(condition=cond):
rule = AstRule(self._astify(norm), self._astify(phase) & self._astify(cond))
rule = AstRule(
self._astify(norm),
self._astify(phase) & self._astify(cond)
| AstAtom(f"force_{self.slugify(norm)}"),
)
case BasicNorm():
rule = AstRule(self._astify(norm), self._astify(phase))
@@ -213,6 +283,11 @@ class AgentSpeakGenerator:
def _add_default_loop(self, phase: Phase) -> None:
actions = []
actions.append(
AstStatement(
StatementType.DO_ACTION, AstLiteral("notify_user_said", [AstVar("Message")])
)
)
actions.append(AstStatement(StatementType.REMOVE_BELIEF, AstLiteral("responded_this_turn")))
actions.append(AstStatement(StatementType.ACHIEVE_GOAL, AstLiteral("check_triggers")))
@@ -236,6 +311,7 @@ class AgentSpeakGenerator:
phase: Phase,
previous_goal: Goal | None = None,
continues_response: bool = False,
main_goal: bool = False,
) -> None:
context: list[AstExpression] = [self._astify(phase)]
context.append(~self._astify(goal, achieved=True))
@@ -245,6 +321,13 @@ class AgentSpeakGenerator:
context.append(~AstLiteral("responded_this_turn"))
body = []
if main_goal: # UI only needs to know about the main goals
body.append(
AstStatement(
StatementType.DO_ACTION,
AstLiteral("notify_goal_start", [AstString(self.slugify(goal))]),
)
)
subgoals = []
for step in goal.plan.steps:
@@ -283,12 +366,28 @@ class AgentSpeakGenerator:
body = []
subgoals = []
body.append(
AstStatement(
StatementType.DO_ACTION,
AstLiteral("notify_trigger_start", [AstString(self.slugify(trigger))]),
)
)
for step in trigger.plan.steps:
body.append(self._step_to_statement(step))
if isinstance(step, Goal):
step.can_fail = False # triggers are continuous sequence
subgoals.append(step)
# Arbitrary wait for UI to display nicely
body.append(AstStatement(StatementType.DO_ACTION, AstLiteral("wait", [AstNumber(2000)])))
body.append(
AstStatement(
StatementType.DO_ACTION,
AstLiteral("notify_trigger_end", [AstString(self.slugify(trigger))]),
)
)
self._asp.plans.append(
AstPlan(
TriggerType.ADDED_GOAL,
@@ -298,6 +397,9 @@ class AgentSpeakGenerator:
)
)
# Force trigger (from UI)
self._asp.plans.append(AstPlan(TriggerType.ADDED_GOAL, self._astify(trigger), [], body))
for subgoal in subgoals:
self._process_goal(subgoal, phase, continues_response=True)
@@ -332,13 +434,7 @@ class AgentSpeakGenerator:
@_astify.register
def _(self, sb: SemanticBelief) -> AstExpression:
return AstLiteral(self.get_semantic_belief_slug(sb))
@staticmethod
def get_semantic_belief_slug(sb: SemanticBelief) -> str:
# If you need a method like this for other types, make a public slugify singledispatch for
# all types.
return f"semantic_{AgentSpeakGenerator._slugify_str(sb.name)}"
return AstLiteral(self.slugify(sb))
@_astify.register
def _(self, ib: InferredBelief) -> AstExpression:
@@ -383,6 +479,11 @@ class AgentSpeakGenerator:
def slugify(element: ProgramElement) -> str:
raise NotImplementedError(f"Cannot convert element {element} to a slug.")
@slugify.register
@staticmethod
def _(n: Norm) -> str:
return f"norm_{AgentSpeakGenerator._slugify_str(n.norm)}"
@slugify.register
@staticmethod
def _(sb: SemanticBelief) -> str:
@@ -390,7 +491,7 @@ class AgentSpeakGenerator:
@slugify.register
@staticmethod
def _(g: Goal) -> str:
def _(g: BaseGoal) -> str:
return AgentSpeakGenerator._slugify_str(g.name)
@slugify.register

View File

@@ -1,203 +0,0 @@
import typing
from dataclasses import dataclass, field
# --- Types ---
@dataclass
class BeliefLiteral:
"""
Represents a literal or atom.
Example: phase(1), user_said("hello"), ~started
"""
functor: str
args: list[str] = field(default_factory=list)
negated: bool = False
def __str__(self):
# In ASL, 'not' is usually for closed-world assumption (prolog style),
# '~' is for explicit negation in beliefs.
# For simplicity in behavior trees, we often use 'not' for conditions.
prefix = "not " if self.negated else ""
if not self.args:
return f"{prefix}{self.functor}"
# Clean args to ensure strings are quoted if they look like strings,
# but usually the converter handles the quoting of string literals.
args_str = ", ".join(self.args)
return f"{prefix}{self.functor}({args_str})"
@dataclass
class GoalLiteral:
name: str
def __str__(self):
return f"!{self.name}"
@dataclass
class ActionLiteral:
"""
Represents a step in a plan body.
Example: .say("Hello") or !achieve_goal
"""
code: str
def __str__(self):
return self.code
@dataclass
class BinaryOp:
"""
Represents logical operations.
Example: (A & B) | C
"""
left: "Expression | str"
operator: typing.Literal["&", "|"]
right: "Expression | str"
def __str__(self):
l_str = str(self.left)
r_str = str(self.right)
if isinstance(self.left, BinaryOp):
l_str = f"({l_str})"
if isinstance(self.right, BinaryOp):
r_str = f"({r_str})"
return f"{l_str} {self.operator} {r_str}"
Literal = BeliefLiteral | GoalLiteral | ActionLiteral
Expression = Literal | BinaryOp | str
@dataclass
class Rule:
"""
Represents an inference rule.
Example: head :- body.
"""
head: Expression
body: Expression | None = None
def __str__(self):
if not self.body:
return f"{self.head}."
return f"{self.head} :- {self.body}."
@dataclass
class PersistentRule:
"""
Represents an inference rule, where the inferred belief is persistent when formed.
"""
head: Expression
body: Expression
def __str__(self):
if not self.body:
raise Exception("Rule without body should not be persistent.")
lines = []
if isinstance(self.body, BinaryOp):
lines.append(f"+{self.body.left}")
if self.body.operator == "&":
lines.append(f" : {self.body.right}")
lines.append(f" <- +{self.head}.")
if self.body.operator == "|":
lines.append(f"+{self.body.right}")
lines.append(f" <- +{self.head}.")
return "\n".join(lines)
@dataclass
class Plan:
"""
Represents a plan.
Syntax: +trigger : context <- body.
"""
trigger: BeliefLiteral | GoalLiteral
context: list[Expression] = field(default_factory=list)
body: list[ActionLiteral] = field(default_factory=list)
def __str__(self):
# Indentation settings
INDENT = " "
ARROW = "\n <- "
COLON = "\n : "
# Build Header
header = f"+{self.trigger}"
if self.context:
ctx_str = f" &\n{INDENT}".join(str(c) for c in self.context)
header += f"{COLON}{ctx_str}"
# Case 1: Empty body
if not self.body:
return f"{header}."
# Case 2: Short body (optional optimization, keeping it uniform usually better)
header += ARROW
lines = []
# We start the first action on the same line or next line.
# Let's put it on the next line for readability if there are multiple.
if len(self.body) == 1:
return f"{header}{self.body[0]}."
# First item
lines.append(f"{header}{self.body[0]};")
# Middle items
for item in self.body[1:-1]:
lines.append(f"{INDENT}{item};")
# Last item
lines.append(f"{INDENT}{self.body[-1]}.")
return "\n".join(lines)
@dataclass
class AgentSpeakFile:
"""
Root element representing the entire generated file.
"""
initial_beliefs: list[Rule] = field(default_factory=list)
inference_rules: list[Rule | PersistentRule] = field(default_factory=list)
plans: list[Plan] = field(default_factory=list)
def __str__(self):
sections = []
if self.initial_beliefs:
sections.append("// --- Initial Beliefs & Facts ---")
sections.extend(str(rule) for rule in self.initial_beliefs)
sections.append("")
if self.inference_rules:
sections.append("// --- Inference Rules ---")
sections.extend(str(rule) for rule in self.inference_rules if isinstance(rule, Rule))
sections.append("")
sections.extend(
str(rule) for rule in self.inference_rules if isinstance(rule, PersistentRule)
)
sections.append("")
if self.plans:
sections.append("// --- Plans ---")
# Separate plans by a newline for readability
sections.extend(str(plan) + "\n" for plan in self.plans)
return "\n".join(sections)

View File

@@ -1,425 +0,0 @@
import asyncio
import time
from functools import singledispatchmethod
from slugify import slugify
from control_backend.agents.bdi import BDICoreAgent
from control_backend.agents.bdi.asl_ast import (
ActionLiteral,
AgentSpeakFile,
BeliefLiteral,
BinaryOp,
Expression,
GoalLiteral,
PersistentRule,
Plan,
Rule,
)
from control_backend.agents.bdi.bdi_program_manager import test_program
from control_backend.schemas.program import (
BasicBelief,
Belief,
ConditionalNorm,
GestureAction,
Goal,
InferredBelief,
KeywordBelief,
LLMAction,
LogicalOperator,
Phase,
Program,
ProgramElement,
SemanticBelief,
SpeechAction,
)
async def do_things():
res = input("Wanna generate")
if res == "y":
program = AgentSpeakGenerator().generate(test_program)
filename = f"{int(time.time())}.asl"
with open(filename, "w") as f:
f.write(program)
else:
# filename = "0test.asl"
filename = "1766062491.asl"
bdi_agent = BDICoreAgent("BDICoreAgent", filename)
flag = asyncio.Event()
await bdi_agent.start()
await flag.wait()
def do_other_things():
print(AgentSpeakGenerator().generate(test_program))
class AgentSpeakGenerator:
"""
Converts a Pydantic Program behavior model into an AgentSpeak(L) AST,
then renders it to a string.
"""
def generate(self, program: Program) -> str:
asl = AgentSpeakFile()
self._generate_startup(program, asl)
for i, phase in enumerate(program.phases):
next_phase = program.phases[i + 1] if i < len(program.phases) - 1 else None
self._generate_phase_flow(phase, next_phase, asl)
self._generate_norms(phase, asl)
self._generate_goals(phase, asl)
self._generate_triggers(phase, asl)
self._generate_fallbacks(program, asl)
return str(asl)
# --- Section: Startup & Phase Management ---
def _generate_startup(self, program: Program, asl: AgentSpeakFile):
if not program.phases:
return
# Initial belief: phase(start).
asl.initial_beliefs.append(Rule(head=BeliefLiteral("phase", ['"start"'])))
# Startup plan: +started : phase(start) <- -phase(start); +phase(first_id).
asl.plans.append(
Plan(
trigger=BeliefLiteral("started"),
context=[BeliefLiteral("phase", ['"start"'])],
body=[
ActionLiteral('-phase("start")'),
ActionLiteral(f'+phase("{program.phases[0].id}")'),
],
)
)
# Initial plans:
asl.plans.append(
Plan(
trigger=GoalLiteral("generate_response_with_goal(Goal)"),
context=[BeliefLiteral("user_said", ["Message"])],
body=[
ActionLiteral("+responded_this_turn"),
ActionLiteral(".findall(Norm, norm(Norm), Norms)"),
ActionLiteral(".reply_with_goal(Message, Norms, Goal)"),
],
)
)
def _generate_phase_flow(self, phase: Phase, next_phase: Phase | None, asl: AgentSpeakFile):
"""Generates the main loop listener and the transition logic for this phase."""
# +user_said(Message) : phase(ID) <- !goal1; !goal2; !transition_phase.
goal_actions = [ActionLiteral("-responded_this_turn")]
goal_actions += [
ActionLiteral(f"!check_{self._slugify_str(keyword)}")
for keyword in self._get_keyword_conditionals(phase)
]
goal_actions += [ActionLiteral(f"!{self._slugify(g)}") for g in phase.goals]
goal_actions.append(ActionLiteral("!transition_phase"))
asl.plans.append(
Plan(
trigger=BeliefLiteral("user_said", ["Message"]),
context=[BeliefLiteral("phase", [f'"{phase.id}"'])],
body=goal_actions,
)
)
# +!transition_phase : phase(ID) <- -phase(ID); +(NEXT_ID).
next_id = str(next_phase.id) if next_phase else "end"
transition_context = [BeliefLiteral("phase", [f'"{phase.id}"'])]
if phase.goals:
transition_context.append(BeliefLiteral(f"achieved_{self._slugify(phase.goals[-1])}"))
asl.plans.append(
Plan(
trigger=GoalLiteral("transition_phase"),
context=transition_context,
body=[
ActionLiteral(f'-phase("{phase.id}")'),
ActionLiteral(f'+phase("{next_id}")'),
ActionLiteral("user_said(Anything)"),
ActionLiteral("-+user_said(Anything)"),
],
)
)
def _get_keyword_conditionals(self, phase: Phase) -> list[str]:
res = []
for belief in self._extract_basic_beliefs_from_phase(phase):
if isinstance(belief, KeywordBelief):
res.append(belief.keyword)
return res
# --- Section: Norms & Beliefs ---
def _generate_norms(self, phase: Phase, asl: AgentSpeakFile):
for norm in phase.norms:
norm_slug = f'"{norm.norm}"'
head = BeliefLiteral("norm", [norm_slug])
# Base context is the phase
phase_lit = BeliefLiteral("phase", [f'"{phase.id}"'])
if isinstance(norm, ConditionalNorm):
self._ensure_belief_inference(norm.condition, asl)
condition_expr = self._belief_to_expr(norm.condition)
body = BinaryOp(phase_lit, "&", condition_expr)
else:
body = phase_lit
asl.inference_rules.append(Rule(head=head, body=body))
def _ensure_belief_inference(self, belief: Belief, asl: AgentSpeakFile):
"""
Recursively adds rules to infer beliefs.
Checks strictly to avoid duplicates if necessary,
though ASL engines often handle redefinition or we can use a set to track processed IDs.
"""
if isinstance(belief, KeywordBelief):
pass
# # Rule: keyword_said("word") :- user_said(M) & .substring("word", M, P) & P >= 0.
# kwd_slug = f'"{belief.keyword}"'
# head = BeliefLiteral("keyword_said", [kwd_slug])
#
# # Avoid duplicates
# if any(str(r.head) == str(head) for r in asl.inference_rules):
# return
#
# body = BinaryOp(
# BeliefLiteral("user_said", ["Message"]),
# "&",
# BinaryOp(f".substring({kwd_slug}, Message, Pos)", "&", "Pos >= 0"),
# )
#
# asl.inference_rules.append(Rule(head=head, body=body))
elif isinstance(belief, InferredBelief):
self._ensure_belief_inference(belief.left, asl)
self._ensure_belief_inference(belief.right, asl)
slug = self._slugify(belief)
head = BeliefLiteral(slug)
if any(str(r.head) == str(head) for r in asl.inference_rules):
return
op_char = "&" if belief.operator == LogicalOperator.AND else "|"
body = BinaryOp(
self._belief_to_expr(belief.left), op_char, self._belief_to_expr(belief.right)
)
asl.inference_rules.append(PersistentRule(head=head, body=body))
def _belief_to_expr(self, belief: Belief) -> Expression:
if isinstance(belief, KeywordBelief):
return BeliefLiteral("keyword_said", [f'"{belief.keyword}"'])
else:
return BeliefLiteral(self._slugify(belief))
# --- Section: Goals ---
def _generate_goals(self, phase: Phase, asl: AgentSpeakFile):
previous_goal: Goal | None = None
for goal in phase.goals:
self._generate_goal_plan_recursive(goal, str(phase.id), previous_goal, asl)
previous_goal = goal
def _generate_goal_plan_recursive(
self,
goal: Goal,
phase_id: str,
previous_goal: Goal | None,
asl: AgentSpeakFile,
responded_needed: bool = True,
can_fail: bool = True,
):
goal_slug = self._slugify(goal)
# phase(ID) & not responded_this_turn & not achieved_goal
context = [
BeliefLiteral("phase", [f'"{phase_id}"']),
]
if responded_needed:
context.append(BeliefLiteral("responded_this_turn", negated=True))
if can_fail:
context.append(BeliefLiteral(f"achieved_{goal_slug}", negated=True))
if previous_goal:
prev_slug = self._slugify(previous_goal)
context.append(BeliefLiteral(f"achieved_{prev_slug}"))
body_actions = []
sub_goals_to_process = []
for step in goal.plan.steps:
if isinstance(step, Goal):
sub_slug = self._slugify(step)
body_actions.append(ActionLiteral(f"!{sub_slug}"))
sub_goals_to_process.append(step)
elif isinstance(step, SpeechAction):
body_actions.append(ActionLiteral(f'.say("{step.text}")'))
elif isinstance(step, GestureAction):
body_actions.append(ActionLiteral(f'.gesture("{step.gesture}")'))
elif isinstance(step, LLMAction):
body_actions.append(ActionLiteral(f'!generate_response_with_goal("{step.goal}")'))
# Mark achievement
if not goal.can_fail:
body_actions.append(ActionLiteral(f"+achieved_{goal_slug}"))
asl.plans.append(Plan(trigger=GoalLiteral(goal_slug), context=context, body=body_actions))
asl.plans.append(
Plan(trigger=GoalLiteral(goal_slug), context=[], body=[ActionLiteral("true")])
)
prev_sub = None
for sub_goal in sub_goals_to_process:
self._generate_goal_plan_recursive(sub_goal, phase_id, prev_sub, asl)
prev_sub = sub_goal
# --- Section: Triggers ---
def _generate_triggers(self, phase: Phase, asl: AgentSpeakFile):
for keyword in self._get_keyword_conditionals(phase):
asl.plans.append(
Plan(
trigger=GoalLiteral(f"check_{self._slugify_str(keyword)}"),
context=[
ActionLiteral(
f'user_said(Message) & .substring("{keyword}", Message, Pos) & Pos >= 0'
)
],
body=[
ActionLiteral(f'+keyword_said("{keyword}")'),
ActionLiteral(f'-keyword_said("{keyword}")'),
],
)
)
asl.plans.append(
Plan(
trigger=GoalLiteral(f"check_{self._slugify_str(keyword)}"),
body=[ActionLiteral("true")],
)
)
for trigger in phase.triggers:
self._ensure_belief_inference(trigger.condition, asl)
trigger_belief_slug = self._belief_to_expr(trigger.condition)
body_actions = []
sub_goals = []
for step in trigger.plan.steps:
if isinstance(step, Goal):
sub_slug = self._slugify(step)
body_actions.append(ActionLiteral(f"!{sub_slug}"))
sub_goals.append(step)
elif isinstance(step, SpeechAction):
body_actions.append(ActionLiteral(f'.say("{step.text}")'))
elif isinstance(step, GestureAction):
body_actions.append(
ActionLiteral(f'.gesture("{step.gesture.type}", "{step.gesture.name}")')
)
elif isinstance(step, LLMAction):
body_actions.append(
ActionLiteral(f'!generate_response_with_goal("{step.goal}")')
)
asl.plans.append(
Plan(
trigger=BeliefLiteral(trigger_belief_slug),
context=[BeliefLiteral("phase", [f'"{phase.id}"'])],
body=body_actions,
)
)
# Recurse for triggered goals
prev_sub = None
for sub_goal in sub_goals:
self._generate_goal_plan_recursive(
sub_goal, str(phase.id), prev_sub, asl, False, False
)
prev_sub = sub_goal
# --- Section: Fallbacks ---
def _generate_fallbacks(self, program: Program, asl: AgentSpeakFile):
asl.plans.append(
Plan(trigger=GoalLiteral("transition_phase"), context=[], body=[ActionLiteral("true")])
)
# --- Helpers ---
@singledispatchmethod
def _slugify(self, element: ProgramElement) -> str:
if element.name:
raise NotImplementedError("Cannot slugify this element.")
return self._slugify_str(element.name)
@_slugify.register
def _(self, goal: Goal) -> str:
if goal.name:
return self._slugify_str(goal.name)
return f"goal_{goal.id.hex}"
@_slugify.register
def _(self, kwb: KeywordBelief) -> str:
return f"keyword_said({kwb.keyword})"
@_slugify.register
def _(self, sb: SemanticBelief) -> str:
return self._slugify_str(sb.description)
@_slugify.register
def _(self, ib: InferredBelief) -> str:
return self._slugify_str(ib.name)
def _slugify_str(self, text: str) -> str:
return slugify(text, separator="_", stopwords=["a", "an", "the", "we", "you", "I"])
def _extract_basic_beliefs_from_program(self, program: Program) -> list[BasicBelief]:
beliefs = []
for phase in program.phases:
beliefs.extend(self._extract_basic_beliefs_from_phase(phase))
return beliefs
def _extract_basic_beliefs_from_phase(self, phase: Phase) -> list[BasicBelief]:
beliefs = []
for norm in phase.norms:
if isinstance(norm, ConditionalNorm):
beliefs += self._extract_basic_beliefs_from_belief(norm.condition)
for trigger in phase.triggers:
beliefs += self._extract_basic_beliefs_from_belief(trigger.condition)
return beliefs
def _extract_basic_beliefs_from_belief(self, belief: Belief) -> list[BasicBelief]:
if isinstance(belief, InferredBelief):
return self._extract_basic_beliefs_from_belief(
belief.left
) + self._extract_basic_beliefs_from_belief(belief.right)
return [belief]
if __name__ == "__main__":
asyncio.run(do_things())
# do_other_things()y

View File

@@ -1,5 +1,6 @@
import asyncio
import copy
import json
import time
from collections.abc import Iterable
@@ -13,7 +14,7 @@ from control_backend.core.agent_system import InternalMessage
from control_backend.core.config import settings
from control_backend.schemas.belief_message import BeliefMessage
from control_backend.schemas.llm_prompt_message import LLMPromptMessage
from control_backend.schemas.ri_message import SpeechCommand
from control_backend.schemas.ri_message import GestureCommand, RIEndpoint, SpeechCommand
DELIMITER = ";\n" # TODO: temporary until we support lists in AgentSpeak
@@ -100,14 +101,12 @@ class BDICoreAgent(BaseAgent):
maybe_more_work = True
while maybe_more_work:
maybe_more_work = False
self.logger.debug("Stepping BDI.")
if self.bdi_agent.step():
maybe_more_work = True
if not maybe_more_work:
deadline = self.bdi_agent.shortest_deadline()
if deadline:
self.logger.debug("Sleeping until %s", deadline)
await asyncio.sleep(deadline - time.time())
maybe_more_work = True
else:
@@ -155,6 +154,20 @@ class BDICoreAgent(BaseAgent):
body=cmd.model_dump_json(),
)
await self.send(out_msg)
case settings.agent_settings.user_interrupt_name:
self.logger.debug("Received user interruption: %s", msg)
match msg.thread:
case "force_phase_transition":
self._set_goal("transition_phase")
case "force_trigger":
self._force_trigger(msg.body)
case "force_norm":
self._force_norm(msg.body)
case "force_next_phase":
self._force_next_phase()
case _:
self.logger.warning("Received unknown user interruption: %s", msg)
def _apply_belief_changes(self, belief_changes: BeliefMessage):
"""
@@ -201,16 +214,35 @@ class BDICoreAgent(BaseAgent):
agentspeak.runtime.Intention(),
)
# Check for transitions
self.bdi_agent.call(
agentspeak.Trigger.addition,
agentspeak.GoalType.achievement,
agentspeak.Literal("transition_phase"),
agentspeak.runtime.Intention(),
)
# Check triggers
self.bdi_agent.call(
agentspeak.Trigger.addition,
agentspeak.GoalType.achievement,
agentspeak.Literal("check_triggers"),
agentspeak.runtime.Intention(),
)
self._wake_bdi_loop.set()
self.logger.debug(f"Added belief {self.format_belief_string(name, args)}")
def _remove_belief(self, name: str, args: Iterable[str]):
def _remove_belief(self, name: str, args: Iterable[str] | None):
"""
Removes a specific belief (with arguments), if it exists.
"""
new_args = (agentspeak.Literal(arg) for arg in args)
term = agentspeak.Literal(name, new_args)
if args is None:
term = agentspeak.Literal(name)
else:
new_args = (agentspeak.Literal(arg) for arg in args)
term = agentspeak.Literal(name, new_args)
result = self.bdi_agent.call(
agentspeak.Trigger.removal,
@@ -250,6 +282,43 @@ class BDICoreAgent(BaseAgent):
self.logger.debug(f"Removed {removed_count} beliefs.")
def _set_goal(self, name: str, args: Iterable[str] | None = None):
args = args or []
if args:
merged_args = DELIMITER.join(arg for arg in args)
new_args = (agentspeak.Literal(merged_args),)
term = agentspeak.Literal(name, new_args)
else:
term = agentspeak.Literal(name)
self.bdi_agent.call(
agentspeak.Trigger.addition,
agentspeak.GoalType.achievement,
term,
agentspeak.runtime.Intention(),
)
self._wake_bdi_loop.set()
self.logger.debug(f"Set goal !{self.format_belief_string(name, args)}.")
def _force_trigger(self, name: str):
self._set_goal(name)
self.logger.info("Manually forced trigger %s.", name)
# TODO: make this compatible for critical norms
def _force_norm(self, name: str):
self._add_belief(f"force_{name}")
self.logger.info("Manually forced norm %s.", name)
def _force_next_phase(self):
self._set_goal("force_transition_phase")
self.logger.info("Manually forced phase transition.")
def _add_custom_actions(self) -> None:
"""
Add any custom actions here. Inside `@self.actions.add()`, the first argument is
@@ -258,16 +327,13 @@ class BDICoreAgent(BaseAgent):
"""
@self.actions.add(".reply", 2)
def _reply(agent: "BDICoreAgent", term, intention):
def _reply(agent, term, intention):
"""
Let the LLM generate a response to a user's utterance with the current norms and goals.
"""
message_text = agentspeak.grounded(term.args[0], intention.scope)
norms = agentspeak.grounded(term.args[1], intention.scope)
self.logger.debug("Norms: %s", norms)
self.logger.debug("User text: %s", message_text)
self.add_behavior(self._send_to_llm(str(message_text), str(norms), ""))
yield
@@ -280,18 +346,24 @@ class BDICoreAgent(BaseAgent):
message_text = agentspeak.grounded(term.args[0], intention.scope)
norms = agentspeak.grounded(term.args[1], intention.scope)
goal = agentspeak.grounded(term.args[2], intention.scope)
self.logger.debug(
'"reply_with_goal" action called with message=%s, norms=%s, goal=%s',
message_text,
norms,
goal,
)
self.add_behavior(self._send_to_llm(str(message_text), str(norms), str(goal)))
yield
@self.actions.add(".notify_norms", 1)
def _notify_norms(agent, term, intention):
norms = agentspeak.grounded(term.args[0], intention.scope)
norm_update_message = InternalMessage(
to=settings.agent_settings.user_interrupt_name,
thread="active_norms_update",
body=str(norms),
)
self.add_behavior(self.send(norm_update_message, should_log=False))
yield
@self.actions.add(".say", 1)
def _say(agent: "BDICoreAgent", term, intention):
def _say(agent, term, intention):
"""
Make the robot say the given text instantly.
"""
@@ -305,12 +377,21 @@ class BDICoreAgent(BaseAgent):
sender=settings.agent_settings.bdi_core_name,
body=speech_command.model_dump_json(),
)
# TODO: add to conversation history
self.add_behavior(self.send(speech_message))
chat_history_message = InternalMessage(
to=settings.agent_settings.llm_name,
thread="assistant_message",
body=str(message_text),
)
self.add_behavior(self.send(chat_history_message))
yield
@self.actions.add(".gesture", 2)
def _gesture(agent: "BDICoreAgent", term, intention):
def _gesture(agent, term, intention):
"""
Make the robot perform the given gesture instantly.
"""
@@ -323,15 +404,118 @@ class BDICoreAgent(BaseAgent):
gesture_name,
)
# gesture = Gesture(type=gesture_type, name=gesture_name)
# gesture_message = InternalMessage(
# to=settings.agent_settings.robot_gesture_name,
# sender=settings.agent_settings.bdi_core_name,
# body=gesture.model_dump_json(),
# )
# asyncio.create_task(agent.send(gesture_message))
if str(gesture_type) == "single":
endpoint = RIEndpoint.GESTURE_SINGLE
elif str(gesture_type) == "tag":
endpoint = RIEndpoint.GESTURE_TAG
else:
self.logger.warning("Gesture type %s could not be resolved.", gesture_type)
endpoint = RIEndpoint.GESTURE_SINGLE
gesture_command = GestureCommand(endpoint=endpoint, data=gesture_name)
gesture_message = InternalMessage(
to=settings.agent_settings.robot_gesture_name,
sender=settings.agent_settings.bdi_core_name,
body=gesture_command.model_dump_json(),
)
self.add_behavior(self.send(gesture_message))
yield
@self.actions.add(".notify_user_said", 1)
def _notify_user_said(agent, term, intention):
user_said = agentspeak.grounded(term.args[0], intention.scope)
msg = InternalMessage(
to=settings.agent_settings.llm_name, thread="user_message", body=str(user_said)
)
self.add_behavior(self.send(msg))
yield
@self.actions.add(".notify_trigger_start", 1)
def _notify_trigger_start(agent, term, intention):
"""
Notify the UI about the trigger we just started doing.
"""
trigger_name = agentspeak.grounded(term.args[0], intention.scope)
self.logger.debug("Started trigger %s", trigger_name)
msg = InternalMessage(
to=settings.agent_settings.user_interrupt_name,
sender=self.name,
thread="trigger_start",
body=str(trigger_name),
)
# TODO: check with Pim
self.add_behavior(self.send(msg))
yield
@self.actions.add(".notify_trigger_end", 1)
def _notify_trigger_end(agent, term, intention):
"""
Notify the UI about the trigger we just started doing.
"""
trigger_name = agentspeak.grounded(term.args[0], intention.scope)
self.logger.debug("Finished trigger %s", trigger_name)
msg = InternalMessage(
to=settings.agent_settings.user_interrupt_name,
sender=self.name,
thread="trigger_end",
body=str(trigger_name),
)
self.add_behavior(self.send(msg))
yield
@self.actions.add(".notify_goal_start", 1)
def _notify_goal_start(agent, term, intention):
"""
Notify the UI about the goal we just started chasing.
"""
goal_name = agentspeak.grounded(term.args[0], intention.scope)
self.logger.debug("Started chasing goal %s", goal_name)
msg = InternalMessage(
to=settings.agent_settings.user_interrupt_name,
sender=self.name,
thread="goal_start",
body=str(goal_name),
)
self.add_behavior(self.send(msg))
yield
@self.actions.add(".notify_transition_phase", 2)
def _notify_transition_phase(agent, term, intention):
"""
Notify the BDI program manager about a phase transition.
"""
old = agentspeak.grounded(term.args[0], intention.scope)
new = agentspeak.grounded(term.args[1], intention.scope)
msg = InternalMessage(
to=settings.agent_settings.bdi_program_manager_name,
thread="transition_phase",
body=json.dumps({"old": str(old), "new": str(new)}),
)
self.add_behavior(self.send(msg))
yield
@self.actions.add(".notify_ui", 0)
def _notify_ui(agent, term, intention):
pass
async def _send_to_llm(self, text: str, norms: str, goals: str):
"""
Sends a text query to the LLM agent asynchronously.
@@ -341,13 +525,14 @@ class BDICoreAgent(BaseAgent):
to=settings.agent_settings.llm_name,
sender=self.name,
body=prompt.model_dump_json(),
thread="prompt_message",
)
await self.send(msg)
self.logger.info("Message sent to LLM agent: %s", text)
@staticmethod
def format_belief_string(name: str, args: Iterable[str] = []):
def format_belief_string(name: str, args: Iterable[str] | None = []):
"""
Given a belief's name and its args, return a string of the form "name(*args)"
"""
return f"{name}{'(' if args else ''}{','.join(args)}{')' if args else ''}"
return f"{name}{'(' if args else ''}{','.join(args or [])}{')' if args else ''}"

View File

@@ -1,4 +1,5 @@
import asyncio
import json
import zmq
from pydantic import ValidationError
@@ -7,9 +8,16 @@ from zmq.asyncio import Context
from control_backend.agents import BaseAgent
from control_backend.agents.bdi.agentspeak_generator import AgentSpeakGenerator
from control_backend.core.config import settings
from control_backend.schemas.belief_list import BeliefList
from control_backend.schemas.belief_list import BeliefList, GoalList
from control_backend.schemas.internal_message import InternalMessage
from control_backend.schemas.program import Belief, ConditionalNorm, InferredBelief, Program
from control_backend.schemas.program import (
Belief,
ConditionalNorm,
Goal,
InferredBelief,
Phase,
Program,
)
class BDIProgramManager(BaseAgent):
@@ -24,20 +32,30 @@ class BDIProgramManager(BaseAgent):
:ivar sub_socket: The ZMQ SUB socket used to receive program updates.
"""
_program: Program
_phase: Phase | None
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.sub_socket = None
def _initialize_internal_state(self, program: Program):
self._program = program
self._phase = program.phases[0] # start in first phase
self._goal_mapping: dict[str, Goal] = {}
for phase in program.phases:
for goal in phase.goals:
self._populate_goal_mapping_with_goal(goal)
def _populate_goal_mapping_with_goal(self, goal: Goal):
self._goal_mapping[str(goal.id)] = goal
for step in goal.plan.steps:
if isinstance(step, Goal):
self._populate_goal_mapping_with_goal(step)
async def _create_agentspeak_and_send_to_bdi(self, program: Program):
"""
Convert a received program into BDI beliefs and send them to the BDI Core Agent.
Currently, it takes the **first phase** of the program and extracts:
- **Norms**: Constraints or rules the agent must follow.
- **Goals**: Objectives the agent must achieve.
These are sent as a ``BeliefMessage`` with ``replace=True``, meaning they will
overwrite any existing norms/goals of the same name in the BDI agent.
Convert a received program into an AgentSpeak file and send it to the BDI Core Agent.
:param program: The program object received from the API.
"""
@@ -59,17 +77,63 @@ class BDIProgramManager(BaseAgent):
await self.send(msg)
@staticmethod
def _extract_beliefs_from_program(program: Program) -> list[Belief]:
async def handle_message(self, msg: InternalMessage):
match msg.thread:
case "transition_phase":
phases = json.loads(msg.body)
await self._transition_phase(phases["old"], phases["new"])
case "achieve_goal":
goal_id = msg.body
await self._send_achieved_goal_to_semantic_belief_extractor(goal_id)
async def _transition_phase(self, old: str, new: str):
if old != str(self._phase.id):
self.logger.warning(
f"Phase transition desync detected! ASL requested move from '{old}', "
f"but Python is currently in '{self._phase.id}'. Request ignored."
)
return
if new == "end":
self._phase = None
# Notify user interaction agent
msg = InternalMessage(
to=settings.agent_settings.user_interrupt_name,
thread="transition_phase",
body="end",
)
self.logger.info("Transitioned to end phase, notifying UserInterruptAgent.")
self.add_behavior(self.send(msg))
return
for phase in self._program.phases:
if str(phase.id) == new:
self._phase = phase
await self._send_beliefs_to_semantic_belief_extractor()
await self._send_goals_to_semantic_belief_extractor()
# Notify user interaction agent
msg = InternalMessage(
to=settings.agent_settings.user_interrupt_name,
thread="transition_phase",
body=str(self._phase.id),
)
self.logger.info(f"Transitioned to phase {new}, notifying UserInterruptAgent.")
self.add_behavior(self.send(msg))
def _extract_current_beliefs(self) -> list[Belief]:
beliefs: list[Belief] = []
for phase in program.phases:
for norm in phase.norms:
if isinstance(norm, ConditionalNorm):
beliefs += BDIProgramManager._extract_beliefs_from_belief(norm.condition)
for norm in self._phase.norms:
if isinstance(norm, ConditionalNorm):
beliefs += self._extract_beliefs_from_belief(norm.condition)
for trigger in phase.triggers:
beliefs += BDIProgramManager._extract_beliefs_from_belief(trigger.condition)
for trigger in self._phase.triggers:
beliefs += self._extract_beliefs_from_belief(trigger.condition)
return beliefs
@@ -81,13 +145,11 @@ class BDIProgramManager(BaseAgent):
) + BDIProgramManager._extract_beliefs_from_belief(belief.right)
return [belief]
async def _send_beliefs_to_semantic_belief_extractor(self, program: Program):
async def _send_beliefs_to_semantic_belief_extractor(self):
"""
Extract beliefs from the program and send them to the Semantic Belief Extractor Agent.
:param program: The program received from the API.
"""
beliefs = BeliefList(beliefs=self._extract_beliefs_from_program(program))
beliefs = BeliefList(beliefs=self._extract_current_beliefs())
message = InternalMessage(
to=settings.agent_settings.text_belief_extractor_name,
@@ -98,12 +160,94 @@ class BDIProgramManager(BaseAgent):
await self.send(message)
@staticmethod
def _extract_goals_from_goal(goal: Goal) -> list[Goal]:
"""
Extract all goals from a given goal, that is: the goal itself and any subgoals.
:return: All goals within and including the given goal.
"""
goals: list[Goal] = [goal]
for plan in goal.plan:
if isinstance(plan, Goal):
goals.extend(BDIProgramManager._extract_goals_from_goal(plan))
return goals
def _extract_current_goals(self) -> list[Goal]:
"""
Extract all goals from the program, including subgoals.
:return: A list of Goal objects.
"""
goals: list[Goal] = []
for goal in self._phase.goals:
goals.extend(self._extract_goals_from_goal(goal))
return goals
async def _send_goals_to_semantic_belief_extractor(self):
"""
Extract goals for the current phase and send them to the Semantic Belief Extractor Agent.
"""
goals = GoalList(goals=self._extract_current_goals())
message = InternalMessage(
to=settings.agent_settings.text_belief_extractor_name,
sender=self.name,
body=goals.model_dump_json(),
thread="goals",
)
await self.send(message)
async def _send_achieved_goal_to_semantic_belief_extractor(self, achieved_goal_id: str):
"""
Inform the semantic belief extractor when a goal is marked achieved.
:param achieved_goal_id: The id of the achieved goal.
"""
goal = self._goal_mapping.get(achieved_goal_id)
if goal is None:
self.logger.debug(f"Goal with ID {achieved_goal_id} marked achieved but was not found.")
return
goals = self._extract_goals_from_goal(goal)
message = InternalMessage(
to=settings.agent_settings.text_belief_extractor_name,
body=GoalList(goals=goals).model_dump_json(),
thread="achieved_goals",
)
await self.send(message)
async def _send_clear_llm_history(self):
"""
Clear the LLM Agent's conversation history.
Sends an empty history to the LLM Agent to reset its state.
"""
message = InternalMessage(
to=settings.agent_settings.llm_name,
body="clear_history",
)
await self.send(message)
self.logger.debug("Sent message to LLM agent to clear history.")
extractor_msg = InternalMessage(
to=settings.agent_settings.text_belief_extractor_name,
thread="conversation_history",
body="reset",
)
await self.send(extractor_msg)
self.logger.debug("Sent message to extractor agent to clear history.")
async def _receive_programs(self):
"""
Continuous loop that receives program updates from the HTTP endpoint.
It listens to the ``program`` topic on the internal ZMQ SUB socket.
When a program is received, it is validated and forwarded to BDI via :meth:`_send_to_bdi`.
Additionally, the LLM history is cleared via :meth:`_send_clear_llm_history`.
"""
while True:
topic, body = await self.sub_socket.recv_multipart()
@@ -111,21 +255,43 @@ class BDIProgramManager(BaseAgent):
try:
program = Program.model_validate_json(body)
except ValidationError:
self.logger.exception("Received an invalid program.")
self.logger.warning("Received an invalid program.")
continue
self._initialize_internal_state(program)
await self._send_program_to_user_interrupt(program)
await self._send_clear_llm_history()
await asyncio.gather(
self._create_agentspeak_and_send_to_bdi(program),
self._send_beliefs_to_semantic_belief_extractor(program),
self._send_beliefs_to_semantic_belief_extractor(),
self._send_goals_to_semantic_belief_extractor(),
)
async def _send_program_to_user_interrupt(self, program: Program):
"""
Send the received program to the User Interrupt Agent.
:param program: The program object received from the API.
"""
msg = InternalMessage(
sender=self.name,
to=settings.agent_settings.user_interrupt_name,
body=program.model_dump_json(),
thread="new_program",
)
await self.send(msg)
async def setup(self):
"""
Initialize the agent.
Connects the internal ZMQ SUB socket and subscribes to the 'program' topic.
Starts the background behavior to receive programs.
Starts the background behavior to receive programs. Initializes a default program.
"""
await self._create_agentspeak_and_send_to_bdi(Program(phases=[]))
context = Context.instance()
self.sub_socket = context.socket(zmq.SUB)

View File

@@ -1,152 +0,0 @@
import json
from pydantic import ValidationError
from control_backend.agents.base import BaseAgent
from control_backend.core.agent_system import InternalMessage
from control_backend.core.config import settings
from control_backend.schemas.belief_message import Belief, BeliefMessage
class BDIBeliefCollectorAgent(BaseAgent):
"""
BDI Belief Collector Agent.
This agent acts as a central aggregator for beliefs derived from various sources (e.g., text,
emotion, vision). It receives raw extracted data from other agents,
normalizes them into valid :class:`Belief` objects, and forwards them as a unified packet to the
BDI Core Agent.
It serves as a funnel to ensure the BDI agent receives a consistent stream of beliefs.
"""
async def setup(self):
"""
Initialize the agent.
"""
self.logger.info("Setting up %s", self.name)
async def handle_message(self, msg: InternalMessage):
"""
Handle incoming messages from other extractor agents.
Routes the message to specific handlers based on the 'type' field in the JSON body.
Supported types:
- ``belief_extraction_text``: Handled by :meth:`_handle_belief_text`
- ``emotion_extraction_text``: Handled by :meth:`_handle_emo_text`
:param msg: The received internal message.
"""
sender_node = msg.sender
# Parse JSON payload
try:
payload = json.loads(msg.body)
except Exception as e:
self.logger.warning(
"BeliefCollector: failed to parse JSON from %s. Body=%r Error=%s",
sender_node,
msg.body,
e,
)
return
msg_type = payload.get("type")
# Prefer explicit 'type' field
if msg_type == "belief_extraction_text":
self.logger.debug("Message routed to _handle_belief_text (sender=%s)", sender_node)
await self._handle_belief_text(payload, sender_node)
# This is not implemented yet, but we keep the structure for future use
elif msg_type == "emotion_extraction_text":
self.logger.debug("Message routed to _handle_emo_text (sender=%s)", sender_node)
await self._handle_emo_text(payload, sender_node)
else:
self.logger.warning(
"Unrecognized message (sender=%s, type=%r). Ignoring.", sender_node, msg_type
)
async def _handle_belief_text(self, payload: dict, origin: str):
"""
Process text-based belief extraction payloads.
Expected payload format::
{
"type": "belief_extraction_text",
"beliefs": {
"user_said": ["Can you help me?"],
"intention": ["ask_help"]
}
}
Validates and converts the dictionary items into :class:`Belief` objects.
:param payload: The dictionary payload containing belief data.
:param origin: The name of the sender agent.
"""
beliefs = payload.get("beliefs", {})
if not beliefs:
self.logger.debug("Received empty beliefs set.")
return
def try_create_belief(name, arguments) -> Belief | None:
"""
Create a belief object from name and arguments, or return None silently if the input is
not correct.
:param name: The name of the belief.
:param arguments: The arguments of the belief.
:return: A Belief object if the input is valid or None.
"""
try:
return Belief(name=name, arguments=arguments, replace=name == "user_said")
except ValidationError:
return None
beliefs = [
belief
for name, arguments in beliefs.items()
if (belief := try_create_belief(name, arguments)) is not None
]
self.logger.debug("Forwarding %d beliefs.", len(beliefs))
for belief in beliefs:
for argument in belief.arguments:
self.logger.debug(" - %s %s", belief.name, argument)
await self._send_beliefs_to_bdi(beliefs, origin=origin)
async def _handle_emo_text(self, payload: dict, origin: str):
"""
Process emotion extraction payloads.
**TODO**: Implement this method once emotion recognition is integrated.
:param payload: The dictionary payload containing emotion data.
:param origin: The name of the sender agent.
"""
pass
async def _send_beliefs_to_bdi(self, beliefs: list[Belief], origin: str | None = None):
"""
Send a list of aggregated beliefs to the BDI Core Agent.
Wraps the beliefs in a :class:`BeliefMessage` and sends it via the 'beliefs' thread.
:param beliefs: The list of Belief objects to send.
:param origin: (Optional) The original source of the beliefs (unused currently).
"""
if not beliefs:
return
msg = InternalMessage(
to=settings.agent_settings.bdi_core_name,
sender=self.name,
body=BeliefMessage(create=beliefs).model_dump_json(),
thread="beliefs",
)
await self.send(msg)
self.logger.info("Sent %d belief(s) to BDI core.", len(beliefs))

View File

@@ -1,5 +1,34 @@
norms("").
phase("end").
keyword_said(Keyword) :- (user_said(Message) & .substring(Keyword, Message, Pos)) & (Pos >= 0).
+user_said(Message) : norms(Norms) <-
-user_said(Message);
.reply(Message, Norms).
+!reply_with_goal(Goal)
: user_said(Message)
<- +responded_this_turn;
.findall(Norm, norm(Norm), Norms);
.reply_with_goal(Message, Norms, Goal).
+!say(Text)
<- +responded_this_turn;
.say(Text).
+!reply
: user_said(Message)
<- +responded_this_turn;
.findall(Norm, norm(Norm), Norms);
.reply(Message, Norms).
+!notify_cycle
<- .notify_ui;
.wait(1).
+user_said(Message)
: phase("end")
<- .notify_user_said(Message);
!reply.
+!check_triggers
<- true.
+!transition_phase
<- true.

View File

@@ -2,17 +2,45 @@ import asyncio
import json
import httpx
from pydantic import ValidationError
from pydantic import BaseModel, ValidationError
from control_backend.agents.base import BaseAgent
from control_backend.agents.bdi.agentspeak_generator import AgentSpeakGenerator
from control_backend.core.agent_system import InternalMessage
from control_backend.core.config import settings
from control_backend.schemas.belief_list import BeliefList
from control_backend.schemas.belief_list import BeliefList, GoalList
from control_backend.schemas.belief_message import Belief as InternalBelief
from control_backend.schemas.belief_message import BeliefMessage
from control_backend.schemas.chat_history import ChatHistory, ChatMessage
from control_backend.schemas.program import SemanticBelief
from control_backend.schemas.program import BaseGoal, SemanticBelief
type JSONLike = None | bool | int | float | str | list["JSONLike"] | dict[str, "JSONLike"]
class BeliefState(BaseModel):
true: set[InternalBelief] = set()
false: set[InternalBelief] = set()
def difference(self, other: "BeliefState") -> "BeliefState":
return BeliefState(
true=self.true - other.true,
false=self.false - other.false,
)
def union(self, other: "BeliefState") -> "BeliefState":
return BeliefState(
true=self.true | other.true,
false=self.false | other.false,
)
def __sub__(self, other):
return self.difference(other)
def __or__(self, other):
return self.union(other)
def __bool__(self):
return bool(self.true) or bool(self.false)
class TextBeliefExtractorAgent(BaseAgent):
@@ -27,12 +55,15 @@ class TextBeliefExtractorAgent(BaseAgent):
the message itself.
"""
def __init__(self, name: str, temperature: float = settings.llm_settings.code_temperature):
def __init__(self, name: str):
super().__init__(name)
self.beliefs: dict[str, bool] = {}
self.available_beliefs: list[SemanticBelief] = []
self._llm = self.LLM(self, settings.llm_settings.n_parallel)
self.belief_inferrer = SemanticBeliefInferrer(self._llm)
self.goal_inferrer = GoalAchievementInferrer(self._llm)
self._current_beliefs = BeliefState()
self._current_goal_completions: dict[str, bool] = {}
self._force_completed_goals: set[BaseGoal] = set()
self.conversation = ChatHistory(messages=[])
self.temperature = temperature
async def setup(self):
"""
@@ -53,13 +84,14 @@ class TextBeliefExtractorAgent(BaseAgent):
case settings.agent_settings.transcription_name:
self.logger.debug("Received text from transcriber: %s", msg.body)
self._apply_conversation_message(ChatMessage(role="user", content=msg.body))
await self._infer_new_beliefs()
await self._user_said(msg.body)
await self._infer_new_beliefs()
await self._infer_goal_completions()
case settings.agent_settings.llm_name:
self.logger.debug("Received text from LLM: %s", msg.body)
self._apply_conversation_message(ChatMessage(role="assistant", content=msg.body))
case settings.agent_settings.bdi_program_manager_name:
self._handle_program_manager_message(msg)
await self._handle_program_manager_message(msg)
case _:
self.logger.info("Discarding message from %s", sender)
return
@@ -74,12 +106,35 @@ class TextBeliefExtractorAgent(BaseAgent):
length_limit = settings.behaviour_settings.conversation_history_length_limit
self.conversation.messages = (self.conversation.messages + [message])[-length_limit:]
def _handle_program_manager_message(self, msg: InternalMessage):
async def _handle_program_manager_message(self, msg: InternalMessage):
"""
Handle a message from the program manager: extract available beliefs from it.
Handle a message from the program manager: extract available beliefs and goals from it.
:param msg: The received message from the program manager.
"""
match msg.thread:
case "beliefs":
self._handle_beliefs_message(msg)
await self._infer_new_beliefs()
case "goals":
self._handle_goals_message(msg)
await self._infer_goal_completions()
case "achieved_goals":
self._handle_goal_achieved_message(msg)
case "conversation_history":
if msg.body == "reset":
self._reset_phase()
case _:
self.logger.warning("Received unexpected message from %s", msg.sender)
def _reset_phase(self):
self.conversation = ChatHistory(messages=[])
self.belief_inferrer.available_beliefs.clear()
self._current_beliefs = BeliefState()
self.goal_inferrer.goals.clear()
self._current_goal_completions = {}
def _handle_beliefs_message(self, msg: InternalMessage):
try:
belief_list = BeliefList.model_validate_json(msg.body)
except ValidationError:
@@ -88,133 +143,262 @@ class TextBeliefExtractorAgent(BaseAgent):
)
return
self.available_beliefs = [b for b in belief_list.beliefs if isinstance(b, SemanticBelief)]
available_beliefs = [b for b in belief_list.beliefs if isinstance(b, SemanticBelief)]
self.belief_inferrer.available_beliefs = available_beliefs
self.logger.debug(
"Received %d beliefs from the program manager.",
len(self.available_beliefs),
"Received %d semantic beliefs from the program manager: %s",
len(available_beliefs),
", ".join(b.name for b in available_beliefs),
)
def _handle_goals_message(self, msg: InternalMessage):
try:
goals_list = GoalList.model_validate_json(msg.body)
except ValidationError:
self.logger.warning(
"Received message from program manager but it is not a valid list of goals."
)
return
# Use only goals that can fail, as the others are always assumed to be completed
available_goals = {g for g in goals_list.goals if g.can_fail}
available_goals -= self._force_completed_goals
self.goal_inferrer.goals = available_goals
self.logger.debug(
"Received %d failable goals from the program manager: %s",
len(available_goals),
", ".join(g.name for g in available_goals),
)
def _handle_goal_achieved_message(self, msg: InternalMessage):
# NOTE: When goals can be marked unachieved, remember to re-add them to the goal_inferrer
try:
goals_list = GoalList.model_validate_json(msg.body)
except ValidationError:
self.logger.warning(
"Received goal achieved message from the program manager, "
"but it is not a valid list of goals."
)
return
for goal in goals_list.goals:
self._force_completed_goals.add(goal)
self._current_goal_completions[f"achieved_{AgentSpeakGenerator.slugify(goal)}"] = True
self.goal_inferrer.goals -= self._force_completed_goals
async def _user_said(self, text: str):
"""
Create a belief for the user's full speech.
:param text: User's transcribed text.
"""
belief = {"beliefs": {"user_said": [text]}, "type": "belief_extraction_text"}
payload = json.dumps(belief)
belief_msg = InternalMessage(
to=settings.agent_settings.bdi_belief_collector_name,
to=settings.agent_settings.bdi_core_name,
sender=self.name,
body=payload,
body=BeliefMessage(
replace=[InternalBelief(name="user_said", arguments=[text])],
).model_dump_json(),
thread="beliefs",
)
await self.send(belief_msg)
async def _infer_new_beliefs(self):
"""
Process conversation history to extract beliefs, semantically. Any changed beliefs are sent
to the BDI core.
"""
# Return instantly if there are no beliefs to infer
if not self.available_beliefs:
conversation_beliefs = await self.belief_inferrer.infer_from_conversation(self.conversation)
new_beliefs = conversation_beliefs - self._current_beliefs
if not new_beliefs:
self.logger.debug("No new beliefs detected.")
return
candidate_beliefs = await self._infer_turn()
belief_changes = BeliefMessage()
for belief_key, belief_value in candidate_beliefs.items():
if belief_value is None:
continue
old_belief_value = self.beliefs.get(belief_key)
if belief_value == old_belief_value:
continue
self._current_beliefs |= new_beliefs
self.beliefs[belief_key] = belief_value
belief_changes = BeliefMessage(
create=list(new_beliefs.true),
delete=list(new_beliefs.false),
)
belief = InternalBelief(name=belief_key, arguments=None)
if belief_value:
belief_changes.create.append(belief)
else:
belief_changes.delete.append(belief)
# Return if there were no changes in beliefs
if not belief_changes.has_values():
return
beliefs_message = InternalMessage(
message = InternalMessage(
to=settings.agent_settings.bdi_core_name,
sender=self.name,
body=belief_changes.model_dump_json(),
thread="beliefs",
)
await self.send(beliefs_message)
await self.send(message)
@staticmethod
def _split_into_chunks[T](items: list[T], n: int) -> list[list[T]]:
k, m = divmod(len(items), n)
return [items[i * k + min(i, m) : (i + 1) * k + min(i + 1, m)] for i in range(n)]
async def _infer_goal_completions(self):
goal_completions = await self.goal_inferrer.infer_from_conversation(self.conversation)
async def _infer_turn(self) -> dict:
new_achieved = [
InternalBelief(name=goal, arguments=None)
for goal, achieved in goal_completions.items()
if achieved and self._current_goal_completions.get(goal) != achieved
]
new_not_achieved = [
InternalBelief(name=goal, arguments=None)
for goal, achieved in goal_completions.items()
if not achieved and self._current_goal_completions.get(goal) != achieved
]
for goal, achieved in goal_completions.items():
self._current_goal_completions[goal] = achieved
if not new_achieved and not new_not_achieved:
self.logger.debug("No goal achievement changes detected.")
return
belief_changes = BeliefMessage(
create=new_achieved,
delete=new_not_achieved,
)
message = InternalMessage(
to=settings.agent_settings.bdi_core_name,
sender=self.name,
body=belief_changes.model_dump_json(),
thread="beliefs",
)
await self.send(message)
class LLM:
"""
Process the stored conversation history to extract semantic beliefs. Returns a list of
beliefs that have been set to ``True``, ``False`` or ``None``.
:return: A dict mapping belief names to a value ``True``, ``False`` or ``None``.
Class that handles sending structured generation requests to an LLM.
"""
def __init__(self, agent: "TextBeliefExtractorAgent", n_parallel: int):
self._agent = agent
self._semaphore = asyncio.Semaphore(n_parallel)
async def query(self, prompt: str, schema: dict, tries: int = 3) -> JSONLike | None:
"""
Query the LLM with the given prompt and schema, return an instance of a dict conforming
to this schema. Try ``tries`` times, or return None.
:param prompt: Prompt to be queried.
:param schema: Schema to be queried.
:param tries: Number of times to try to query the LLM.
:return: An instance of a dict conforming to this schema, or None if failed.
"""
try_count = 0
while try_count < tries:
try_count += 1
try:
return await self._query_llm(prompt, schema)
except (httpx.HTTPError, json.JSONDecodeError, KeyError) as e:
if try_count < tries:
continue
self._agent.logger.exception(
"Failed to get LLM response after %d tries.",
try_count,
exc_info=e,
)
return None
async def _query_llm(self, prompt: str, schema: dict) -> JSONLike:
"""
Query an LLM with the given prompt and schema, return an instance of a dict conforming
to that schema.
:param prompt: The prompt to be queried.
:param schema: Schema to use during response.
:return: A dict conforming to this schema.
:raises httpx.HTTPStatusError: If the LLM server responded with an error.
:raises json.JSONDecodeError: If the LLM response was not valid JSON. May happen if the
response was cut off early due to length limitations.
:raises KeyError: If the LLM server responded with no error, but the response was
invalid.
"""
async with self._semaphore:
async with httpx.AsyncClient() as client:
response = await client.post(
settings.llm_settings.local_llm_url,
json={
"model": settings.llm_settings.local_llm_model,
"messages": [{"role": "user", "content": prompt}],
"response_format": {
"type": "json_schema",
"json_schema": {
"name": "Beliefs",
"strict": True,
"schema": schema,
},
},
"reasoning_effort": "low",
"temperature": settings.llm_settings.code_temperature,
"stream": False,
},
timeout=30.0,
)
response.raise_for_status()
response_json = response.json()
json_message = response_json["choices"][0]["message"]["content"]
return json.loads(json_message)
class SemanticBeliefInferrer:
"""
Class that handles only prompting an LLM for semantic beliefs.
"""
def __init__(
self,
llm: "TextBeliefExtractorAgent.LLM",
available_beliefs: list[SemanticBelief] | None = None,
):
self._llm = llm
self.available_beliefs: list[SemanticBelief] = available_beliefs or []
async def infer_from_conversation(self, conversation: ChatHistory) -> BeliefState:
"""
Process conversation history to extract beliefs, semantically. The result is an object that
describes all beliefs that hold or don't hold based on the full conversation.
:param conversation: The conversation history to be processed.
:return: An object that describes beliefs.
"""
# Return instantly if there are no beliefs to infer
if not self.available_beliefs:
return BeliefState()
n_parallel = max(1, min(settings.llm_settings.n_parallel - 1, len(self.available_beliefs)))
all_beliefs = await asyncio.gather(
all_beliefs: list[dict[str, bool | None] | None] = await asyncio.gather(
*[
self._infer_beliefs(self.conversation, beliefs)
self._infer_beliefs(conversation, beliefs)
for beliefs in self._split_into_chunks(self.available_beliefs, n_parallel)
]
)
retval = {}
retval = BeliefState()
for beliefs in all_beliefs:
if beliefs is None:
continue
retval.update(beliefs)
for belief_name, belief_holds in beliefs.items():
if belief_holds is None:
continue
belief = InternalBelief(name=belief_name, arguments=None)
if belief_holds:
retval.true.add(belief)
else:
retval.false.add(belief)
return retval
@staticmethod
def _create_belief_schema(belief: SemanticBelief) -> tuple[str, dict]:
return AgentSpeakGenerator.slugify(belief), {
"type": ["boolean", "null"],
"description": belief.description,
}
def _split_into_chunks[T](items: list[T], n: int) -> list[list[T]]:
"""
Split a list into ``n`` chunks, making each chunk approximately ``len(items) / n`` long.
@staticmethod
def _create_beliefs_schema(beliefs: list[SemanticBelief]) -> dict:
belief_schemas = [
TextBeliefExtractorAgent._create_belief_schema(belief) for belief in beliefs
]
return {
"type": "object",
"properties": dict(belief_schemas),
"required": [name for name, _ in belief_schemas],
}
@staticmethod
def _format_message(message: ChatMessage):
return f"{message.role.upper()}:\n{message.content}"
@staticmethod
def _format_conversation(conversation: ChatHistory):
return "\n\n".join(
[TextBeliefExtractorAgent._format_message(message) for message in conversation.messages]
)
@staticmethod
def _format_beliefs(beliefs: list[SemanticBelief]):
return "\n".join(
[f"- {AgentSpeakGenerator.slugify(belief)}: {belief.description}" for belief in beliefs]
)
:param items: The list of items to split.
:param n: The number of desired chunks.
:return: A list of chunks each approximately ``len(items) / n`` long.
"""
k, m = divmod(len(items), n)
return [items[i * k + min(i, m) : (i + 1) * k + min(i + 1, m)] for i in range(n)]
async def _infer_beliefs(
self,
conversation: ChatHistory,
beliefs: list[SemanticBelief],
) -> dict | None:
) -> dict[str, bool | None] | None:
"""
Infer given beliefs based on the given conversation.
:param conversation: The conversation to infer beliefs from.
@@ -241,70 +425,79 @@ Respond with a JSON similar to the following, but with the property names as giv
schema = self._create_beliefs_schema(beliefs)
return await self._retry_query_llm(prompt, schema)
return await self._llm.query(prompt, schema)
async def _retry_query_llm(self, prompt: str, schema: dict, tries: int = 3) -> dict | None:
@staticmethod
def _create_belief_schema(belief: SemanticBelief) -> tuple[str, dict]:
return AgentSpeakGenerator.slugify(belief), {
"type": ["boolean", "null"],
"description": belief.description,
}
@staticmethod
def _create_beliefs_schema(beliefs: list[SemanticBelief]) -> dict:
belief_schemas = [
SemanticBeliefInferrer._create_belief_schema(belief) for belief in beliefs
]
return {
"type": "object",
"properties": dict(belief_schemas),
"required": [name for name, _ in belief_schemas],
}
@staticmethod
def _format_message(message: ChatMessage):
return f"{message.role.upper()}:\n{message.content}"
@staticmethod
def _format_conversation(conversation: ChatHistory):
return "\n\n".join(
[SemanticBeliefInferrer._format_message(message) for message in conversation.messages]
)
@staticmethod
def _format_beliefs(beliefs: list[SemanticBelief]):
return "\n".join(
[f"- {AgentSpeakGenerator.slugify(belief)}: {belief.description}" for belief in beliefs]
)
class GoalAchievementInferrer(SemanticBeliefInferrer):
def __init__(self, llm: TextBeliefExtractorAgent.LLM):
super().__init__(llm)
self.goals: set[BaseGoal] = set()
async def infer_from_conversation(self, conversation: ChatHistory) -> dict[str, bool]:
"""
Query the LLM with the given prompt and schema, return an instance of a dict conforming
to this schema. Try ``tries`` times, or return None.
Determine which goals have been achieved based on the given conversation.
:param prompt: Prompt to be queried.
:param schema: Schema to be queried.
:return: An instance of a dict conforming to this schema, or None if failed.
:param conversation: The conversation to infer goal completion from.
:return: A mapping of goals and a boolean whether they have been achieved.
"""
try_count = 0
while try_count < tries:
try_count += 1
if not self.goals:
return {}
try:
return await self._query_llm(prompt, schema)
except (httpx.HTTPError, json.JSONDecodeError, KeyError) as e:
if try_count < tries:
continue
self.logger.exception(
"Failed to get LLM response after %d tries.",
try_count,
exc_info=e,
)
goals_achieved = await asyncio.gather(
*[self._infer_goal(conversation, g) for g in self.goals]
)
return {
f"achieved_{AgentSpeakGenerator.slugify(goal)}": achieved
for goal, achieved in zip(self.goals, goals_achieved, strict=True)
}
return None
async def _infer_goal(self, conversation: ChatHistory, goal: BaseGoal) -> bool:
prompt = f"""{self._format_conversation(conversation)}
async def _query_llm(self, prompt: str, schema: dict) -> dict:
"""
Query an LLM with the given prompt and schema, return an instance of a dict conforming to
that schema.
Given the above conversation, what has the following goal been achieved?
:param prompt: The prompt to be queried.
:param schema: Schema to use during response.
:return: A dict conforming to this schema.
:raises httpx.HTTPStatusError: If the LLM server responded with an error.
:raises json.JSONDecodeError: If the LLM response was not valid JSON. May happen if the
response was cut off early due to length limitations.
:raises KeyError: If the LLM server responded with no error, but the response was invalid.
"""
async with httpx.AsyncClient() as client:
response = await client.post(
settings.llm_settings.local_llm_url,
json={
"model": settings.llm_settings.local_llm_model,
"messages": [{"role": "user", "content": prompt}],
"response_format": {
"type": "json_schema",
"json_schema": {
"name": "Beliefs",
"strict": True,
"schema": schema,
},
},
"reasoning_effort": "low",
"temperature": self.temperature,
"stream": False,
},
timeout=None,
)
response.raise_for_status()
The name of the goal: {goal.name}
Description of the goal: {goal.description}
response_json = response.json()
json_message = response_json["choices"][0]["message"]["content"]
beliefs = json.loads(json_message)
return beliefs
Answer with literally only `true` or `false` (without backticks)."""
schema = {
"type": "boolean",
}
return await self._llm.query(prompt, schema)

View File

@@ -3,12 +3,14 @@ import json
import zmq
import zmq.asyncio as azmq
from pydantic import ValidationError
from zmq.asyncio import Context
from control_backend.agents import BaseAgent
from control_backend.agents.actuation.robot_gesture_agent import RobotGestureAgent
from control_backend.core.config import settings
from control_backend.schemas.internal_message import InternalMessage
from control_backend.schemas.ri_message import PauseCommand
from ..actuation.robot_speech_agent import RobotSpeechAgent
from ..perception import VADAgent
@@ -39,7 +41,7 @@ class RICommunicationAgent(BaseAgent):
def __init__(
self,
name: str,
address=settings.zmq_settings.ri_command_address,
address=settings.zmq_settings.ri_communication_address,
bind=False,
):
super().__init__(name)
@@ -172,7 +174,7 @@ class RICommunicationAgent(BaseAgent):
bind = port_data["bind"]
if not bind:
addr = f"tcp://localhost:{port}"
addr = f"tcp://{settings.ri_host}:{port}"
else:
addr = f"tcp://*:{port}"
@@ -255,7 +257,8 @@ class RICommunicationAgent(BaseAgent):
self._req_socket.recv_json(), timeout=seconds_to_wait_total / 2
)
self.logger.debug(f'Received message "{message}" from RI.')
if "endpoint" in message and message["endpoint"] != "ping":
self.logger.debug(f'Received message "{message}" from RI.')
if "endpoint" not in message:
self.logger.warning("No received endpoint in message, expected ping endpoint.")
continue
@@ -319,12 +322,9 @@ class RICommunicationAgent(BaseAgent):
self.connected = True
async def handle_message(self, msg: InternalMessage):
"""
Handle an incoming message.
Currently not implemented for this agent.
:param msg: The received message.
:raises NotImplementedError: Always, since this method is not implemented.
"""
self.logger.warning("custom warning for handle msg in ri coms %s", self.name)
try:
pause_command = PauseCommand.model_validate_json(msg.body)
await self._req_socket.send_json(pause_command.model_dump())
self.logger.debug(await self._req_socket.recv_json())
except ValidationError:
self.logger.warning("Incorrect message format for PauseCommand.")

View File

@@ -46,14 +46,23 @@ class LLMAgent(BaseAgent):
:param msg: The received internal message.
"""
if msg.sender == settings.agent_settings.bdi_core_name:
self.logger.debug("Processing message from BDI core.")
try:
prompt_message = LLMPromptMessage.model_validate_json(msg.body)
await self._process_bdi_message(prompt_message)
except ValidationError:
self.logger.debug("Prompt message from BDI core is invalid.")
match msg.thread:
case "prompt_message":
try:
prompt_message = LLMPromptMessage.model_validate_json(msg.body)
await self._process_bdi_message(prompt_message)
except ValidationError:
self.logger.debug("Prompt message from BDI core is invalid.")
case "assistant_message":
self.history.append({"role": "assistant", "content": msg.body})
case "user_message":
self.history.append({"role": "user", "content": msg.body})
elif msg.sender == settings.agent_settings.bdi_program_manager_name:
if msg.body == "clear_history":
self.logger.debug("Clearing conversation history.")
self.history.clear()
else:
self.logger.debug("Message ignored (not from BDI core.")
self.logger.debug("Message ignored.")
async def _process_bdi_message(self, message: LLMPromptMessage):
"""
@@ -114,13 +123,6 @@ class LLMAgent(BaseAgent):
:param goals: Goals the LLM should achieve.
:yield: Fragments of the LLM-generated content (e.g., sentences/phrases).
"""
self.history.append(
{
"role": "user",
"content": prompt,
}
)
instructions = LLMInstructions(norms if norms else None, goals if goals else None)
messages = [
{

View File

@@ -110,12 +110,11 @@ class VADAgent(BaseAgent):
self._connect_audio_in_socket()
audio_out_port = self._connect_audio_out_socket()
if audio_out_port is None:
audio_out_address = self._connect_audio_out_socket()
if audio_out_address is None:
self.logger.error("Could not bind output socket, stopping.")
await self.stop()
return
audio_out_address = f"tcp://localhost:{audio_out_port}"
# Connect to internal communication socket
self.program_sub_socket = azmq.Context.instance().socket(zmq.SUB)
@@ -168,13 +167,14 @@ class VADAgent(BaseAgent):
self.audio_in_socket.connect(self.audio_in_address)
self.audio_in_poller = SocketPoller[bytes](self.audio_in_socket)
def _connect_audio_out_socket(self) -> int | None:
def _connect_audio_out_socket(self) -> str | None:
"""
Returns the port bound, or None if binding failed.
Returns the address that was bound to, or None if binding failed.
"""
try:
self.audio_out_socket = azmq.Context.instance().socket(zmq.PUB)
return self.audio_out_socket.bind_to_random_port("tcp://localhost", max_tries=100)
self.audio_out_socket.bind(settings.zmq_settings.vad_pub_address)
return settings.zmq_settings.vad_pub_address
except zmq.ZMQBindError:
self.logger.error("Failed to bind an audio output socket after 100 tries.")
self.audio_out_socket = None
@@ -246,10 +246,11 @@ class VADAgent(BaseAgent):
assert self.model is not None
prob = self.model(torch.from_numpy(chunk), settings.vad_settings.sample_rate_hz).item()
non_speech_patience = settings.behaviour_settings.vad_non_speech_patience_chunks
begin_silence_length = settings.behaviour_settings.vad_begin_silence_chunks
prob_threshold = settings.behaviour_settings.vad_prob_threshold
if prob > prob_threshold:
if self.i_since_speech > non_speech_patience:
if self.i_since_speech > non_speech_patience + begin_silence_length:
self.logger.debug("Speech started.")
self.audio_buffer = np.append(self.audio_buffer, chunk)
self.i_since_speech = 0
@@ -263,7 +264,7 @@ class VADAgent(BaseAgent):
continue
# Speech probably ended. Make sure we have a usable amount of data.
if len(self.audio_buffer) >= 3 * len(chunk):
if len(self.audio_buffer) > begin_silence_length * len(chunk):
self.logger.debug("Speech ended.")
assert self.audio_out_socket is not None
await self.audio_out_socket.send(self.audio_buffer[: -2 * len(chunk)].tobytes())

View File

@@ -4,8 +4,11 @@ import zmq
from zmq.asyncio import Context
from control_backend.agents import BaseAgent
from control_backend.agents.bdi.agentspeak_generator import AgentSpeakGenerator
from control_backend.core.agent_system import InternalMessage
from control_backend.core.config import settings
from control_backend.schemas.belief_message import Belief, BeliefMessage
from control_backend.schemas.program import ConditionalNorm, Program
from control_backend.schemas.ri_message import (
GestureCommand,
PauseCommand,
@@ -29,12 +32,39 @@ class UserInterruptAgent(BaseAgent):
Prioritized actions clear the current RI queue before inserting the new item,
ensuring they are executed immediately after Pepper's current action has been fulfilled.
:ivar sub_socket: The ZMQ SUB socket used to receive user intterupts.
:ivar sub_socket: The ZMQ SUB socket used to receive user interrupts.
"""
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.sub_socket = None
self.pub_socket = None
self._trigger_map = {}
self._trigger_reverse_map = {}
self._goal_map = {} # id -> sluggified goal
self._goal_reverse_map = {} # sluggified goal -> id
self._cond_norm_map = {} # id -> sluggified cond norm
self._cond_norm_reverse_map = {} # sluggified cond norm -> id
async def setup(self):
"""
Initialize the agent.
Connects the internal ZMQ SUB socket and subscribes to the 'button_pressed' topic.
Starts the background behavior to receive the user interrupts.
"""
context = Context.instance()
self.sub_socket = context.socket(zmq.SUB)
self.sub_socket.connect(settings.zmq_settings.internal_sub_address)
self.sub_socket.subscribe("button_pressed")
self.pub_socket = context.socket(zmq.PUB)
self.pub_socket.connect(settings.zmq_settings.internal_pub_address)
self.add_behavior(self._receive_button_event())
async def _receive_button_event(self):
"""
@@ -46,6 +76,9 @@ class UserInterruptAgent(BaseAgent):
- type: "speech", context: string that the robot has to say.
- type: "gesture", context: single gesture name that the robot has to perform.
- type: "override", context: belief_id that overrides the goal/trigger/conditional norm.
- type: "pause", context: boolean indicating whether to pause
- type: "reset_phase", context: None, indicates to the BDI Core to
- type: "reset_experiment", context: None, indicates to the BDI Core to
"""
while True:
topic, body = await self.sub_socket.recv_multipart()
@@ -58,6 +91,8 @@ class UserInterruptAgent(BaseAgent):
self.logger.error("Received invalid JSON payload on topic %s", topic)
continue
self.logger.debug("Received event type %s", event_type)
if event_type == "speech":
await self._send_to_speech_agent(event_context)
self.logger.info(
@@ -71,11 +106,36 @@ class UserInterruptAgent(BaseAgent):
event_context,
)
elif event_type == "override":
await self._send_to_program_manager(event_context)
self.logger.info(
"Forwarded button press (override) with context '%s' to BDIProgramManager.",
event_context,
)
ui_id = str(event_context)
if asl_trigger := self._trigger_map.get(ui_id):
await self._send_to_bdi("force_trigger", asl_trigger)
self.logger.info(
"Forwarded button press (override) with context '%s' to BDI Core.",
event_context,
)
elif asl_cond_norm := self._cond_norm_map.get(ui_id):
await self._send_to_bdi("force_norm", asl_cond_norm)
self.logger.info(
"Forwarded button press (override) with context '%s' to BDIProgramManager.",
event_context,
)
elif asl_goal := self._goal_map.get(ui_id):
await self._send_to_bdi_belief(asl_goal)
self.logger.info(
"Forwarded button press (override) with context '%s' to BDI Core.",
event_context,
)
goal_achieve_msg = InternalMessage(
to=settings.agent_settings.bdi_program_manager_name,
thread="achieve_goal",
body=ui_id,
)
await self.send(goal_achieve_msg)
else:
self.logger.warning("Could not determine which element to override.")
elif event_type == "pause":
self.logger.debug(
"Received pause/resume button press with context '%s'.", event_context
@@ -88,7 +148,6 @@ class UserInterruptAgent(BaseAgent):
elif event_type in ["next_phase", "reset_phase", "reset_experiment"]:
await self._send_experiment_control_to_bdi_core(event_type)
else:
self.logger.warning(
"Received button press with unknown type '%s' (context: '%s').",
@@ -96,35 +155,121 @@ class UserInterruptAgent(BaseAgent):
event_context,
)
async def _send_experiment_control_to_bdi_core(self, type):
async def handle_message(self, msg: InternalMessage):
"""
method to send experiment control buttons to bdi core.
Handle commands received from other internal Python agents.
"""
match msg.thread:
case "new_program":
self._create_mapping(msg.body)
case "trigger_start":
# msg.body is the sluggified trigger
asl_slug = msg.body
ui_id = self._trigger_reverse_map.get(asl_slug)
:param type: the type of control button we should send to the bdi core.
"""
# Switch which thread we should send to bdi core
thread = ""
match type:
case "next_phase":
thread = "force_next_phase"
case "reset_phase":
thread = "reset_current_phase"
case "reset_experiment":
thread = "reset_experiment"
if ui_id:
payload = {"type": "trigger_update", "id": ui_id, "achieved": True}
await self._send_experiment_update(payload)
self.logger.info(f"UI Update: Trigger {asl_slug} started (ID: {ui_id})")
case "trigger_end":
asl_slug = msg.body
ui_id = self._trigger_reverse_map.get(asl_slug)
if ui_id:
payload = {"type": "trigger_update", "id": ui_id, "achieved": False}
await self._send_experiment_update(payload)
self.logger.info(f"UI Update: Trigger {asl_slug} ended (ID: {ui_id})")
case "transition_phase":
new_phase_id = msg.body
self.logger.info(f"Phase transition detected: {new_phase_id}")
payload = {"type": "phase_update", "id": new_phase_id}
await self._send_experiment_update(payload)
case "goal_start":
goal_name = msg.body
ui_id = self._goal_reverse_map.get(goal_name)
if ui_id:
payload = {"type": "goal_update", "id": ui_id, "active": True}
await self._send_experiment_update(payload)
self.logger.info(f"UI Update: Goal {goal_name} started (ID: {ui_id})")
case "active_norms_update":
norm_list = [s.strip("() '\",") for s in msg.body.split(",") if s.strip("() '\",")]
await self._broadcast_cond_norms(norm_list)
case _:
self.logger.warning(
"Received unknown experiment control type '%s' to send to BDI Core.",
type,
)
self.logger.debug(f"Received internal message on unhandled thread: {msg.thread}")
out_msg = InternalMessage(
to=settings.agent_settings.bdi_core_name,
sender=self.name,
thread=thread,
body="",
)
self.logger.debug("Sending experiment control '%s' to BDI Core.", thread)
await self.send(out_msg)
async def _broadcast_cond_norms(self, active_slugs: list[str]):
"""
Sends the current state of all conditional norms to the UI.
:param active_slugs: A list of slugs (strings) currently active in the BDI core.
"""
updates = []
for asl_slug, ui_id in self._cond_norm_reverse_map.items():
is_active = asl_slug in active_slugs
updates.append({"id": ui_id, "name": asl_slug, "active": is_active})
payload = {"type": "cond_norms_state_update", "norms": updates}
await self._send_experiment_update(payload, should_log=False)
# self.logger.debug(f"Broadcasted state for {len(updates)} conditional norms.")
def _create_mapping(self, program_json: str):
"""
Create mappings between UI IDs and ASL slugs for triggers, goals, and conditional norms
"""
try:
program = Program.model_validate_json(program_json)
self._trigger_map = {}
self._trigger_reverse_map = {}
self._goal_map = {}
self._cond_norm_map = {}
self._cond_norm_reverse_map = {}
for phase in program.phases:
for trigger in phase.triggers:
slug = AgentSpeakGenerator.slugify(trigger)
self._trigger_map[str(trigger.id)] = slug
self._trigger_reverse_map[slug] = str(trigger.id)
for goal in phase.goals:
self._goal_map[str(goal.id)] = AgentSpeakGenerator.slugify(goal)
self._goal_reverse_map[AgentSpeakGenerator.slugify(goal)] = str(goal.id)
for goal, id in self._goal_reverse_map.items():
self.logger.debug(f"Goal mapping: UI ID {goal} -> {id}")
for norm in phase.norms:
if isinstance(norm, ConditionalNorm):
asl_slug = AgentSpeakGenerator.slugify(norm)
norm_id = str(norm.id)
self._cond_norm_map[norm_id] = asl_slug
self._cond_norm_reverse_map[norm.norm] = norm_id
self.logger.debug("Added conditional norm %s", asl_slug)
self.logger.info(
f"Mapped {len(self._trigger_map)} triggers and {len(self._goal_map)} goals "
f"and {len(self._cond_norm_map)} conditional norms for UserInterruptAgent."
)
except Exception as e:
self.logger.error(f"Mapping failed: {e}")
async def _send_experiment_update(self, data, should_log: bool = True):
"""
Sends an update to the 'experiment' topic.
The SSE endpoint will pick this up and push it to the UI.
"""
if self.pub_socket:
topic = b"experiment"
body = json.dumps(data).encode("utf-8")
await self.pub_socket.send_multipart([topic, body])
if should_log:
self.logger.debug(f"Sent experiment update: {data}")
async def _send_to_speech_agent(self, text_to_say: str):
"""
@@ -157,26 +302,54 @@ class UserInterruptAgent(BaseAgent):
)
await self.send(out_msg)
async def _send_to_program_manager(self, belief_id: str):
"""
Send a button_override belief to the BDIProgramManager.
async def _send_to_bdi(self, thread: str, body: str):
"""Send slug of trigger to BDI"""
msg = InternalMessage(to=settings.agent_settings.bdi_core_name, thread=thread, body=body)
await self.send(msg)
self.logger.info(f"Directly forced {thread} in BDI: {body}")
:param belief_id: The belief_id that overrides the goal/trigger/conditional norm.
this id can belong to a basic belief or an inferred belief.
See also: https://utrechtuniversity.youtrack.cloud/articles/N25B-A-27/UI-components
async def _send_to_bdi_belief(self, asl_goal: str):
"""Send belief to BDI Core"""
belief_name = f"achieved_{asl_goal}"
belief = Belief(name=belief_name, arguments=None)
self.logger.debug(f"Sending belief to BDI Core: {belief_name}")
belief_message = BeliefMessage(create=[belief])
msg = InternalMessage(
to=settings.agent_settings.bdi_core_name,
thread="beliefs",
body=belief_message.model_dump_json(),
)
await self.send(msg)
async def _send_experiment_control_to_bdi_core(self, type):
"""
data = {"belief": belief_id}
message = InternalMessage(
to=settings.agent_settings.bdi_program_manager_name,
method to send experiment control buttons to bdi core.
:param type: the type of control button we should send to the bdi core.
"""
# Switch which thread we should send to bdi core
thread = ""
match type:
case "next_phase":
thread = "force_next_phase"
case "reset_phase":
thread = "reset_current_phase"
case "reset_experiment":
thread = "reset_experiment"
case _:
self.logger.warning(
"Received unknown experiment control type '%s' to send to BDI Core.",
type,
)
out_msg = InternalMessage(
to=settings.agent_settings.bdi_core_name,
sender=self.name,
body=json.dumps(data),
thread="belief_override_id",
)
await self.send(message)
self.logger.info(
"Sent button_override belief with id '%s' to Program manager.",
belief_id,
thread=thread,
body="",
)
self.logger.debug("Sending experiment control '%s' to BDI Core.", thread)
await self.send(out_msg)
async def _send_pause_command(self, pause):
"""
@@ -209,18 +382,3 @@ class UserInterruptAgent(BaseAgent):
)
await self.send(vad_message)
self.logger.info("Sent resume command to VAD Agent and RI Communication Agent.")
async def setup(self):
"""
Initialize the agent.
Connects the internal ZMQ SUB socket and subscribes to the 'button_pressed' topic.
Starts the background behavior to receive the user interrupts.
"""
context = Context.instance()
self.sub_socket = context.socket(zmq.SUB)
self.sub_socket.connect(settings.zmq_settings.internal_sub_address)
self.sub_socket.subscribe("button_pressed")
self.add_behavior(self._receive_button_event())

View File

@@ -1,31 +0,0 @@
import logging
from fastapi import APIRouter, Request
from control_backend.schemas.events import ButtonPressedEvent
logger = logging.getLogger(__name__)
router = APIRouter()
@router.post("/button_pressed", status_code=202)
async def receive_button_event(event: ButtonPressedEvent, request: Request):
"""
Endpoint to handle external button press events.
Validates the event payload and publishes it to the internal 'button_pressed' topic.
Subscribers (in this case user_interrupt_agent) will pick this up to trigger
specific behaviors or state changes.
:param event: The parsed ButtonPressedEvent object.
:param request: The FastAPI request object.
"""
logger.debug("Received button event: %s | %s", event.type, event.context)
topic = b"button_pressed"
body = event.model_dump_json().encode()
pub_socket = request.app.state.endpoints_pub_socket
await pub_socket.send_multipart([topic, body])
return {"status": "Event received"}

View File

@@ -137,7 +137,6 @@ async def ping_stream(request: Request):
logger.info("Client disconnected from SSE")
break
logger.debug(f"Yielded new connection event in robot ping router: {str(connected)}")
connectedJson = json.dumps(connected)
yield (f"data: {connectedJson}\n\n")

View File

@@ -0,0 +1,67 @@
import asyncio
import logging
import zmq
import zmq.asyncio
from fastapi import APIRouter, Request
from fastapi.responses import StreamingResponse
from zmq.asyncio import Context
from control_backend.core.config import settings
from control_backend.schemas.events import ButtonPressedEvent
logger = logging.getLogger(__name__)
router = APIRouter()
@router.post("/button_pressed", status_code=202)
async def receive_button_event(event: ButtonPressedEvent, request: Request):
"""
Endpoint to handle external button press events.
Validates the event payload and publishes it to the internal 'button_pressed' topic.
Subscribers (in this case user_interrupt_agent) will pick this up to trigger
specific behaviors or state changes.
:param event: The parsed ButtonPressedEvent object.
:param request: The FastAPI request object.
"""
logger.debug("Received button event: %s | %s", event.type, event.context)
topic = b"button_pressed"
body = event.model_dump_json().encode()
pub_socket = request.app.state.endpoints_pub_socket
await pub_socket.send_multipart([topic, body])
return {"status": "Event received"}
@router.get("/experiment_stream")
async def experiment_stream(request: Request):
# Use the asyncio-compatible context
context = Context.instance()
socket = context.socket(zmq.SUB)
# Connect and subscribe
socket.connect(settings.zmq_settings.internal_sub_address)
socket.subscribe(b"experiment")
async def gen():
try:
while True:
# Check if client closed the tab
if await request.is_disconnected():
logger.info("Client disconnected from experiment stream.")
break
try:
parts = await asyncio.wait_for(socket.recv_multipart(), timeout=1.0)
_, message = parts
yield f"data: {message.decode().strip()}\n\n"
except TimeoutError:
continue
finally:
socket.close()
return StreamingResponse(gen(), media_type="text/event-stream")

View File

@@ -1,6 +1,6 @@
from fastapi.routing import APIRouter
from control_backend.api.v1.endpoints import button_pressed, logs, message, program, robot, sse
from control_backend.api.v1.endpoints import logs, message, program, robot, sse, user_interact
api_router = APIRouter()
@@ -14,4 +14,4 @@ api_router.include_router(logs.router, tags=["Logs"])
api_router.include_router(program.router, tags=["Program"])
api_router.include_router(button_pressed.router, tags=["Button Pressed Events"])
api_router.include_router(user_interact.router, tags=["Button Pressed Events"])

View File

@@ -60,6 +60,9 @@ class BaseAgent(ABC):
self._tasks: set[asyncio.Task] = set()
self._running = False
self._internal_pub_socket: None | azmq.Socket = None
self._internal_sub_socket: None | azmq.Socket = None
# Register immediately
AgentDirectory.register(name, self)
@@ -117,7 +120,7 @@ class BaseAgent(ABC):
task.cancel()
self.logger.info(f"Agent {self.name} stopped")
async def send(self, message: InternalMessage):
async def send(self, message: InternalMessage, should_log: bool = True):
"""
Send a message to another agent.
@@ -130,16 +133,26 @@ class BaseAgent(ABC):
:param message: The message to send.
"""
target = AgentDirectory.get(message.to)
if target:
await target.inbox.put(message)
self.logger.debug(f"Sent message {message.body} to {message.to} via regular inbox.")
else:
# Apparently target agent is on a different process, send via ZMQ
topic = f"internal/{message.to}".encode()
body = message.model_dump_json().encode()
await self._internal_pub_socket.send_multipart([topic, body])
self.logger.debug(f"Sent message {message.body} to {message.to} via ZMQ.")
message.sender = self.name
to = message.to
receivers = [to] if isinstance(to, str) else to
for receiver in receivers:
target = AgentDirectory.get(receiver)
if target:
await target.inbox.put(message)
if should_log:
self.logger.debug(
f"Sent message {message.body} to {message.to} via regular inbox."
)
else:
# Apparently target agent is on a different process, send via ZMQ
topic = f"internal/{receiver}".encode()
body = message.model_dump_json().encode()
await self._internal_pub_socket.send_multipart([topic, body])
if should_log:
self.logger.debug(f"Sent message {message.body} to {message.to} via ZMQ.")
async def _process_inbox(self):
"""
@@ -149,7 +162,6 @@ class BaseAgent(ABC):
"""
while self._running:
msg = await self.inbox.get()
self.logger.debug(f"Received message from {msg.sender}.")
await self.handle_message(msg)
async def _receive_internal_zmq_loop(self):
@@ -192,7 +204,16 @@ class BaseAgent(ABC):
:param coro: The coroutine to execute as a task.
"""
task = asyncio.create_task(coro)
async def try_coro(coro_: Coroutine):
try:
await coro_
except asyncio.CancelledError:
self.logger.debug("A behavior was canceled successfully: %s", coro_)
except Exception:
self.logger.warning("An exception occurred in a behavior.", exc_info=True)
task = asyncio.create_task(try_coro(coro))
self._tasks.add(task)
task.add_done_callback(self._tasks.discard)
return task

View File

@@ -1,3 +1,12 @@
"""
An exhaustive overview of configurable options. All of these can be set using environment variables
by nesting with double underscores (__). Start from the ``Settings`` class.
For example, ``settings.ri_host`` becomes ``RI_HOST``, and
``settings.zmq_settings.ri_communication_address`` becomes
``ZMQ_SETTINGS__RI_COMMUNICATION_ADDRESS``.
"""
from pydantic import BaseModel
from pydantic_settings import BaseSettings, SettingsConfigDict
@@ -8,16 +17,17 @@ class ZMQSettings(BaseModel):
:ivar internal_pub_address: Address for the internal PUB socket.
:ivar internal_sub_address: Address for the internal SUB socket.
:ivar ri_command_address: Address for sending commands to the Robot Interface.
:ivar ri_communication_address: Address for receiving communication from the Robot Interface.
:ivar vad_agent_address: Address for the Voice Activity Detection (VAD) agent.
:ivar ri_communication_address: Address for the endpoint that the Robot Interface connects to.
:ivar vad_pub_address: Address that the VAD agent binds to and publishes audio segments to.
"""
# ATTENTION: When adding/removing settings, make sure to update the .env.example file
internal_pub_address: str = "tcp://localhost:5560"
internal_sub_address: str = "tcp://localhost:5561"
ri_command_address: str = "tcp://localhost:0000"
ri_communication_address: str = "tcp://*:5555"
internal_gesture_rep_adress: str = "tcp://localhost:7788"
vad_pub_address: str = "inproc://vad_stream"
class AgentSettings(BaseModel):
@@ -25,7 +35,6 @@ class AgentSettings(BaseModel):
Names of the various agents in the system. These names are used for routing messages.
:ivar bdi_core_name: Name of the BDI Core Agent.
:ivar bdi_belief_collector_name: Name of the Belief Collector Agent.
:ivar bdi_program_manager_name: Name of the BDI Program Manager Agent.
:ivar text_belief_extractor_name: Name of the Text Belief Extractor Agent.
:ivar vad_name: Name of the Voice Activity Detection (VAD) Agent.
@@ -36,9 +45,10 @@ class AgentSettings(BaseModel):
:ivar robot_speech_name: Name of the Robot Speech Agent.
"""
# ATTENTION: When adding/removing settings, make sure to update the .env.example file
# agent names
bdi_core_name: str = "bdi_core_agent"
bdi_belief_collector_name: str = "belief_collector_agent"
bdi_program_manager_name: str = "bdi_program_manager_agent"
text_belief_extractor_name: str = "text_belief_extractor_agent"
vad_name: str = "vad_agent"
@@ -61,6 +71,7 @@ class BehaviourSettings(BaseModel):
:ivar vad_prob_threshold: Probability threshold for Voice Activity Detection.
:ivar vad_initial_since_speech: Initial value for 'since speech' counter in VAD.
:ivar vad_non_speech_patience_chunks: Number of non-speech chunks to wait before speech ended.
:ivar vad_begin_silence_chunks: The number of chunks of silence to prepend to speech chunks.
:ivar transcription_max_concurrent_tasks: Maximum number of concurrent transcription tasks.
:ivar transcription_words_per_minute: Estimated words per minute for transcription timing.
:ivar transcription_words_per_token: Estimated words per token for transcription timing.
@@ -68,6 +79,8 @@ class BehaviourSettings(BaseModel):
:ivar conversation_history_length_limit: The maximum amount of messages to extract beliefs from.
"""
# ATTENTION: When adding/removing settings, make sure to update the .env.example file
sleep_s: float = 1.0
comm_setup_max_retries: int = 5
socket_poller_timeout_ms: int = 100
@@ -75,7 +88,8 @@ class BehaviourSettings(BaseModel):
# VAD settings
vad_prob_threshold: float = 0.5
vad_initial_since_speech: int = 100
vad_non_speech_patience_chunks: int = 3
vad_non_speech_patience_chunks: int = 15
vad_begin_silence_chunks: int = 6
# transcription behaviour
transcription_max_concurrent_tasks: int = 3
@@ -99,6 +113,8 @@ class LLMSettings(BaseModel):
:ivar n_parallel: The number of parallel calls allowed to be made to the LLM.
"""
# ATTENTION: When adding/removing settings, make sure to update the .env.example file
local_llm_url: str = "http://localhost:1234/v1/chat/completions"
local_llm_model: str = "gpt-oss"
chat_temperature: float = 1.0
@@ -115,6 +131,8 @@ class VADSettings(BaseModel):
:ivar sample_rate_hz: Sample rate in Hz for the VAD model.
"""
# ATTENTION: When adding/removing settings, make sure to update the .env.example file
repo_or_dir: str = "snakers4/silero-vad"
model_name: str = "silero_vad"
sample_rate_hz: int = 16000
@@ -128,6 +146,8 @@ class SpeechModelSettings(BaseModel):
:ivar openai_model_name: Model name for OpenAI-based speech recognition.
"""
# ATTENTION: When adding/removing settings, make sure to update the .env.example file
# model identifiers for speech recognition
mlx_model_name: str = "mlx-community/whisper-small.en-mlx"
openai_model_name: str = "small.en"
@@ -139,6 +159,7 @@ class Settings(BaseSettings):
:ivar app_title: Title of the application.
:ivar ui_url: URL of the frontend UI.
:ivar ri_host: The hostname of the Robot Interface.
:ivar zmq_settings: ZMQ configuration.
:ivar agent_settings: Agent name configuration.
:ivar behaviour_settings: Behavior configuration.
@@ -151,6 +172,8 @@ class Settings(BaseSettings):
ui_url: str = "http://localhost:5173"
ri_host: str = "localhost"
zmq_settings: ZMQSettings = ZMQSettings()
agent_settings: AgentSettings = AgentSettings()

View File

@@ -26,7 +26,6 @@ from zmq.asyncio import Context
# BDI agents
from control_backend.agents.bdi import (
BDIBeliefCollectorAgent,
BDICoreAgent,
TextBeliefExtractorAgent,
)
@@ -122,12 +121,6 @@ async def lifespan(app: FastAPI):
"name": settings.agent_settings.bdi_core_name,
},
),
"BeliefCollectorAgent": (
BDIBeliefCollectorAgent,
{
"name": settings.agent_settings.bdi_belief_collector_name,
},
),
"TextBeliefExtractorAgent": (
TextBeliefExtractorAgent,
{
@@ -172,6 +165,8 @@ async def lifespan(app: FastAPI):
await endpoints_pub_socket.send_multipart([PROGRAM_STATUS, ProgramStatus.STOPPING.value])
# Additional shutdown logic goes here
for agent in agents:
await agent.stop()
logger.info("Application shutdown complete.")

View File

@@ -1,5 +1,6 @@
from pydantic import BaseModel
from control_backend.schemas.program import BaseGoal
from control_backend.schemas.program import Belief as ProgramBelief
@@ -12,3 +13,7 @@ class BeliefList(BaseModel):
"""
beliefs: list[ProgramBelief]
class GoalList(BaseModel):
goals: list[BaseGoal]

View File

@@ -11,7 +11,10 @@ class Belief(BaseModel):
"""
name: str
arguments: list[str] | None
arguments: list[str] | None = None
# To make it hashable
model_config = {"frozen": True}
class BeliefMessage(BaseModel):

View File

@@ -1,3 +1,5 @@
from collections.abc import Iterable
from pydantic import BaseModel
@@ -11,7 +13,7 @@ class InternalMessage(BaseModel):
:ivar thread: An optional thread identifier/topic to categorize the message (e.g., 'beliefs').
"""
to: str
sender: str
to: str | Iterable[str]
sender: str | None = None
body: str
thread: str | None = None

View File

@@ -15,6 +15,9 @@ class ProgramElement(BaseModel):
name: str
id: UUID4
# To make program elements hashable
model_config = {"frozen": True}
class LogicalOperator(Enum):
AND = "AND"
@@ -105,23 +108,33 @@ class Plan(ProgramElement):
steps: list[PlanElement]
class Goal(ProgramElement):
class BaseGoal(ProgramElement):
"""
Represents an objective to be achieved. To reach the goal, we should execute
the corresponding plan. If we can fail to achieve a goal after executing the plan,
for example when the achieving of the goal is dependent on the user's reply, this means
that the achieved status will be set from somewhere else in the program.
Represents an objective to be achieved. This base version does not include a plan to achieve
this goal, and is used in semantic belief extraction.
:ivar description: A description of the goal, used to determine if it has been achieved.
:ivar plan: The plan to execute.
:ivar can_fail: Whether we can fail to achieve the goal after executing the plan.
"""
description: str
plan: Plan
description: str = ""
can_fail: bool = True
class Goal(BaseGoal):
"""
Represents an objective to be achieved. To reach the goal, we should execute the corresponding
plan. It inherits from the BaseGoal a variable `can_fail`, which if true will cause the
completion to be determined based on the conversation.
Instances of this goal are not hashable because a plan is not hashable.
:ivar plan: The plan to execute.
"""
plan: Plan
type Action = SpeechAction | GestureAction | LLMAction
@@ -180,7 +193,6 @@ class Trigger(ProgramElement):
:ivar plan: The plan to execute.
"""
name: str = ""
condition: Belief
plan: Plan

View File

@@ -91,7 +91,7 @@ def test_out_socket_creation(zmq_context):
assert per_vad_agent.audio_out_socket is not None
zmq_context.return_value.socket.assert_called_once_with(zmq.PUB)
zmq_context.return_value.socket.return_value.bind_to_random_port.assert_called_once()
zmq_context.return_value.socket.return_value.bind.assert_called_once_with("inproc://vad_stream")
@pytest.mark.asyncio

View File

@@ -28,7 +28,11 @@ async def test_setup_bind(zmq_context, mocker):
settings = mocker.patch("control_backend.agents.actuation.robot_gesture_agent.settings")
settings.zmq_settings.internal_sub_address = "tcp://internal:1234"
agent.add_behavior = MagicMock()
def close_coro(coro):
coro.close()
return MagicMock()
agent.add_behavior = MagicMock(side_effect=close_coro)
await agent.setup()
@@ -55,7 +59,11 @@ async def test_setup_connect(zmq_context, mocker):
settings = mocker.patch("control_backend.agents.actuation.robot_gesture_agent.settings")
settings.zmq_settings.internal_sub_address = "tcp://internal:1234"
agent.add_behavior = MagicMock()
def close_coro(coro):
coro.close()
return MagicMock()
agent.add_behavior = MagicMock(side_effect=close_coro)
await agent.setup()
@@ -73,7 +81,7 @@ async def test_setup_connect(zmq_context, mocker):
async def test_handle_message_sends_valid_gesture_command():
"""Internal message with valid gesture tag is forwarded to robot pub socket."""
pubsocket = AsyncMock()
agent = RobotGestureAgent("robot_gesture", gesture_data=["hello", "yes", "no"])
agent = RobotGestureAgent("robot_gesture", gesture_data=["hello", "yes", "no"], address="")
agent.pubsocket = pubsocket
payload = {
@@ -91,7 +99,7 @@ async def test_handle_message_sends_valid_gesture_command():
async def test_handle_message_sends_non_gesture_command():
"""Internal message with non-gesture endpoint is not forwarded by this agent."""
pubsocket = AsyncMock()
agent = RobotGestureAgent("robot_gesture", gesture_data=["hello", "yes", "no"])
agent = RobotGestureAgent("robot_gesture", gesture_data=["hello", "yes", "no"], address="")
agent.pubsocket = pubsocket
payload = {"endpoint": "some_other_endpoint", "data": "invalid_tag_not_in_list"}
@@ -107,7 +115,7 @@ async def test_handle_message_sends_non_gesture_command():
async def test_handle_message_rejects_invalid_gesture_tag():
"""Internal message with invalid gesture tag is not forwarded."""
pubsocket = AsyncMock()
agent = RobotGestureAgent("robot_gesture", gesture_data=["hello", "yes", "no"])
agent = RobotGestureAgent("robot_gesture", gesture_data=["hello", "yes", "no"], address="")
agent.pubsocket = pubsocket
# Use a tag that's not in gesture_data
@@ -119,11 +127,70 @@ async def test_handle_message_rejects_invalid_gesture_tag():
pubsocket.send_json.assert_not_awaited()
@pytest.mark.asyncio
async def test_handle_message_sends_valid_single_gesture_command():
"""Internal message with valid single gesture is forwarded."""
pubsocket = AsyncMock()
agent = RobotGestureAgent("robot_gesture", single_gesture_data=["wave", "point"], address="")
agent.pubsocket = pubsocket
payload = {
"endpoint": RIEndpoint.GESTURE_SINGLE,
"data": "wave",
}
msg = InternalMessage(to="robot", sender="tester", body=json.dumps(payload))
await agent.handle_message(msg)
pubsocket.send_json.assert_awaited_once()
@pytest.mark.asyncio
async def test_handle_message_rejects_invalid_single_gesture():
"""Internal message with invalid single gesture is not forwarded."""
pubsocket = AsyncMock()
agent = RobotGestureAgent("robot_gesture", single_gesture_data=["wave", "point"], address="")
agent.pubsocket = pubsocket
payload = {
"endpoint": RIEndpoint.GESTURE_SINGLE,
"data": "dance",
}
msg = InternalMessage(to="robot", sender="tester", body=json.dumps(payload))
await agent.handle_message(msg)
pubsocket.send_json.assert_not_awaited()
@pytest.mark.asyncio
async def test_zmq_command_loop_valid_single_gesture_payload():
"""UI command with valid single gesture is read from SUB and published."""
command = {"endpoint": RIEndpoint.GESTURE_SINGLE, "data": "wave"}
fake_socket = AsyncMock()
async def recv_once():
agent._running = False
return b"command", json.dumps(command).encode("utf-8")
fake_socket.recv_multipart = recv_once
fake_socket.send_json = AsyncMock()
agent = RobotGestureAgent("robot_gesture", single_gesture_data=["wave", "point"], address="")
agent.subsocket = fake_socket
agent.pubsocket = fake_socket
agent._running = True
await agent._zmq_command_loop()
fake_socket.send_json.assert_awaited_once()
@pytest.mark.asyncio
async def test_handle_message_invalid_payload():
"""Invalid payload is caught and does not send."""
pubsocket = AsyncMock()
agent = RobotGestureAgent("robot_gesture", gesture_data=["hello", "yes", "no"])
agent = RobotGestureAgent("robot_gesture", gesture_data=["hello", "yes", "no"], address="")
agent.pubsocket = pubsocket
msg = InternalMessage(to="robot", sender="tester", body=json.dumps({"bad": "data"}))
@@ -142,12 +209,12 @@ async def test_zmq_command_loop_valid_gesture_payload():
async def recv_once():
# stop after first iteration
agent._running = False
return (b"command", json.dumps(command).encode("utf-8"))
return b"command", json.dumps(command).encode("utf-8")
fake_socket.recv_multipart = recv_once
fake_socket.send_json = AsyncMock()
agent = RobotGestureAgent("robot_gesture", gesture_data=["hello", "yes", "no"])
agent = RobotGestureAgent("robot_gesture", gesture_data=["hello", "yes", "no"], address="")
agent.subsocket = fake_socket
agent.pubsocket = fake_socket
agent._running = True
@@ -165,12 +232,12 @@ async def test_zmq_command_loop_valid_non_gesture_payload():
async def recv_once():
agent._running = False
return (b"command", json.dumps(command).encode("utf-8"))
return b"command", json.dumps(command).encode("utf-8")
fake_socket.recv_multipart = recv_once
fake_socket.send_json = AsyncMock()
agent = RobotGestureAgent("robot_gesture", gesture_data=["hello", "yes", "no"])
agent = RobotGestureAgent("robot_gesture", gesture_data=["hello", "yes", "no"], address="")
agent.subsocket = fake_socket
agent.pubsocket = fake_socket
agent._running = True
@@ -188,12 +255,12 @@ async def test_zmq_command_loop_invalid_gesture_tag():
async def recv_once():
agent._running = False
return (b"command", json.dumps(command).encode("utf-8"))
return b"command", json.dumps(command).encode("utf-8")
fake_socket.recv_multipart = recv_once
fake_socket.send_json = AsyncMock()
agent = RobotGestureAgent("robot_gesture", gesture_data=["hello", "yes", "no"])
agent = RobotGestureAgent("robot_gesture", gesture_data=["hello", "yes", "no"], address="")
agent.subsocket = fake_socket
agent.pubsocket = fake_socket
agent._running = True
@@ -210,12 +277,12 @@ async def test_zmq_command_loop_invalid_json():
async def recv_once():
agent._running = False
return (b"command", b"{not_json}")
return b"command", b"{not_json}"
fake_socket.recv_multipart = recv_once
fake_socket.send_json = AsyncMock()
agent = RobotGestureAgent("robot_gesture", gesture_data=["hello", "yes", "no"])
agent = RobotGestureAgent("robot_gesture", gesture_data=["hello", "yes", "no"], address="")
agent.subsocket = fake_socket
agent.pubsocket = fake_socket
agent._running = True
@@ -232,12 +299,12 @@ async def test_zmq_command_loop_ignores_send_gestures_topic():
async def recv_once():
agent._running = False
return (b"send_gestures", b"{}")
return b"send_gestures", b"{}"
fake_socket.recv_multipart = recv_once
fake_socket.send_json = AsyncMock()
agent = RobotGestureAgent("robot_gesture", gesture_data=["hello", "yes", "no"])
agent = RobotGestureAgent("robot_gesture", gesture_data=["hello", "yes", "no"], address="")
agent.subsocket = fake_socket
agent.pubsocket = fake_socket
agent._running = True
@@ -259,7 +326,9 @@ async def test_fetch_gestures_loop_without_amount():
fake_repsocket.recv = recv_once
fake_repsocket.send = AsyncMock()
agent = RobotGestureAgent("robot_gesture", gesture_data=["hello", "yes", "no", "wave", "point"])
agent = RobotGestureAgent(
"robot_gesture", gesture_data=["hello", "yes", "no", "wave", "point"], address=""
)
agent.repsocket = fake_repsocket
agent._running = True
@@ -287,7 +356,9 @@ async def test_fetch_gestures_loop_with_amount():
fake_repsocket.recv = recv_once
fake_repsocket.send = AsyncMock()
agent = RobotGestureAgent("robot_gesture", gesture_data=["hello", "yes", "no", "wave", "point"])
agent = RobotGestureAgent(
"robot_gesture", gesture_data=["hello", "yes", "no", "wave", "point"], address=""
)
agent.repsocket = fake_repsocket
agent._running = True
@@ -315,7 +386,7 @@ async def test_fetch_gestures_loop_with_integer_request():
fake_repsocket.recv = recv_once
fake_repsocket.send = AsyncMock()
agent = RobotGestureAgent("robot_gesture", gesture_data=["hello", "yes", "no"])
agent = RobotGestureAgent("robot_gesture", gesture_data=["hello", "yes", "no"], address="")
agent.repsocket = fake_repsocket
agent._running = True
@@ -340,7 +411,7 @@ async def test_fetch_gestures_loop_with_invalid_json():
fake_repsocket.recv = recv_once
fake_repsocket.send = AsyncMock()
agent = RobotGestureAgent("robot_gesture", gesture_data=["hello", "yes", "no"])
agent = RobotGestureAgent("robot_gesture", gesture_data=["hello", "yes", "no"], address="")
agent.repsocket = fake_repsocket
agent._running = True
@@ -365,7 +436,7 @@ async def test_fetch_gestures_loop_with_non_integer_json():
fake_repsocket.recv = recv_once
fake_repsocket.send = AsyncMock()
agent = RobotGestureAgent("robot_gesture", gesture_data=["hello", "yes", "no"])
agent = RobotGestureAgent("robot_gesture", gesture_data=["hello", "yes", "no"], address="")
agent.repsocket = fake_repsocket
agent._running = True
@@ -381,7 +452,7 @@ async def test_fetch_gestures_loop_with_non_integer_json():
def test_gesture_data_attribute():
"""Test that gesture_data returns the expected list."""
gesture_data = ["hello", "yes", "no", "wave"]
agent = RobotGestureAgent("robot_gesture", gesture_data=gesture_data)
agent = RobotGestureAgent("robot_gesture", gesture_data=gesture_data, address="")
assert agent.gesture_data == gesture_data
assert isinstance(agent.gesture_data, list)
@@ -398,7 +469,7 @@ async def test_stop_closes_sockets():
pubsocket = MagicMock()
subsocket = MagicMock()
repsocket = MagicMock()
agent = RobotGestureAgent("robot_gesture")
agent = RobotGestureAgent("robot_gesture", address="")
agent.pubsocket = pubsocket
agent.subsocket = subsocket
agent.repsocket = repsocket
@@ -407,15 +478,14 @@ async def test_stop_closes_sockets():
pubsocket.close.assert_called_once()
subsocket.close.assert_called_once()
# Note: repsocket is not closed in stop() method, but you might want to add it
# repsocket.close.assert_called_once()
repsocket.close.assert_called_once()
@pytest.mark.asyncio
async def test_initialization_with_custom_gesture_data():
"""Agent can be initialized with custom gesture data."""
custom_gestures = ["custom1", "custom2", "custom3"]
agent = RobotGestureAgent("robot_gesture", gesture_data=custom_gestures)
agent = RobotGestureAgent("robot_gesture", gesture_data=custom_gestures, address="")
assert agent.gesture_data == custom_gestures
@@ -432,7 +502,7 @@ async def test_fetch_gestures_loop_handles_exception():
fake_repsocket.recv = recv_once
fake_repsocket.send = AsyncMock()
agent = RobotGestureAgent("robot_gesture", gesture_data=["hello", "yes", "no"])
agent = RobotGestureAgent("robot_gesture", gesture_data=["hello", "yes", "no"], address="")
agent.repsocket = fake_repsocket
agent.logger = MagicMock()
agent._running = True

View File

@@ -30,7 +30,11 @@ async def test_setup_bind(zmq_context, mocker):
settings = mocker.patch("control_backend.agents.actuation.robot_speech_agent.settings")
settings.zmq_settings.internal_sub_address = "tcp://internal:1234"
agent.add_behavior = MagicMock()
def close_coro(coro):
coro.close()
return MagicMock()
agent.add_behavior = MagicMock(side_effect=close_coro)
await agent.setup()
@@ -48,7 +52,11 @@ async def test_setup_connect(zmq_context, mocker):
settings = mocker.patch("control_backend.agents.actuation.robot_speech_agent.settings")
settings.zmq_settings.internal_sub_address = "tcp://internal:1234"
agent.add_behavior = MagicMock()
def close_coro(coro):
coro.close()
return MagicMock()
agent.add_behavior = MagicMock(side_effect=close_coro)
await agent.setup()

View File

@@ -0,0 +1,186 @@
import pytest
from control_backend.agents.bdi.agentspeak_ast import (
AstAtom,
AstBinaryOp,
AstLiteral,
AstLogicalExpression,
AstNumber,
AstPlan,
AstProgram,
AstRule,
AstStatement,
AstString,
AstVar,
BinaryOperatorType,
StatementType,
TriggerType,
_coalesce_expr,
)
def test_ast_atom():
atom = AstAtom("test")
assert str(atom) == "test"
assert atom._to_agentspeak() == "test"
def test_ast_var():
var = AstVar("Variable")
assert str(var) == "Variable"
assert var._to_agentspeak() == "Variable"
def test_ast_number():
num = AstNumber(42)
assert str(num) == "42"
num_float = AstNumber(3.14)
assert str(num_float) == "3.14"
def test_ast_string():
s = AstString("hello")
assert str(s) == '"hello"'
def test_ast_literal():
lit = AstLiteral("functor", [AstAtom("atom"), AstNumber(1)])
assert str(lit) == "functor(atom, 1)"
lit_empty = AstLiteral("functor")
assert str(lit_empty) == "functor"
def test_ast_binary_op():
left = AstNumber(1)
right = AstNumber(2)
op = AstBinaryOp(left, BinaryOperatorType.GREATER_THAN, right)
assert str(op) == "1 > 2"
# Test logical wrapper
assert isinstance(op.left, AstLogicalExpression)
assert isinstance(op.right, AstLogicalExpression)
def test_ast_binary_op_parens():
# 1 > 2
inner = AstBinaryOp(AstNumber(1), BinaryOperatorType.GREATER_THAN, AstNumber(2))
# (1 > 2) & 3
outer = AstBinaryOp(inner, BinaryOperatorType.AND, AstNumber(3))
assert str(outer) == "(1 > 2) & 3"
# 3 & (1 > 2)
outer_right = AstBinaryOp(AstNumber(3), BinaryOperatorType.AND, inner)
assert str(outer_right) == "3 & (1 > 2)"
def test_ast_binary_op_parens_negated():
inner = AstLogicalExpression(AstAtom("foo"), negated=True)
outer = AstBinaryOp(inner, BinaryOperatorType.AND, AstAtom("bar"))
# The current implementation checks `if self.left.negated: l_str = f"({l_str})"`
# str(inner) is "not foo"
# so we expect "(not foo) & bar"
assert str(outer) == "(not foo) & bar"
outer_right = AstBinaryOp(AstAtom("bar"), BinaryOperatorType.AND, inner)
assert str(outer_right) == "bar & (not foo)"
def test_ast_logical_expression_negation():
expr = AstLogicalExpression(AstAtom("true"), negated=True)
assert str(expr) == "not true"
expr_neg_neg = ~expr
assert str(expr_neg_neg) == "true"
assert not expr_neg_neg.negated
# Invert a non-logical expression (wraps it)
term = AstAtom("true")
inverted = ~term
assert isinstance(inverted, AstLogicalExpression)
assert inverted.negated
assert str(inverted) == "not true"
def test_ast_logical_expression_no_negation():
# _as_logical on already logical expression
expr = AstLogicalExpression(AstAtom("x"))
# Doing binary op will call _as_logical
op = AstBinaryOp(expr, BinaryOperatorType.AND, AstAtom("y"))
assert isinstance(op.left, AstLogicalExpression)
assert op.left is expr # Should reuse instance
def test_ast_operators():
t1 = AstAtom("a")
t2 = AstAtom("b")
assert str(t1 & t2) == "a & b"
assert str(t1 | t2) == "a | b"
assert str(t1 >= t2) == "a >= b"
assert str(t1 > t2) == "a > b"
assert str(t1 <= t2) == "a <= b"
assert str(t1 < t2) == "a < b"
assert str(t1 == t2) == "a == b"
assert str(t1 != t2) == r"a \== b"
def test_coalesce_expr():
t = AstAtom("a")
assert str(t & "b") == 'a & "b"'
assert str(t & 1) == "a & 1"
assert str(t & 1.5) == "a & 1.5"
with pytest.raises(TypeError):
_coalesce_expr(None)
def test_ast_statement():
stmt = AstStatement(StatementType.DO_ACTION, AstLiteral("action"))
assert str(stmt) == ".action"
def test_ast_rule():
# Rule with condition
rule = AstRule(AstLiteral("head"), AstLiteral("body"))
assert str(rule) == "head :- body."
# Rule without condition
rule_simple = AstRule(AstLiteral("fact"))
assert str(rule_simple) == "fact."
def test_ast_plan():
plan = AstPlan(
TriggerType.ADDED_GOAL,
AstLiteral("goal"),
[AstLiteral("context")],
[AstStatement(StatementType.DO_ACTION, AstLiteral("action"))],
)
output = str(plan)
# verify parts exist
assert "+!goal" in output
assert ": context" in output
assert "<- .action." in output
def test_ast_plan_no_context():
plan = AstPlan(
TriggerType.ADDED_GOAL,
AstLiteral("goal"),
[],
[AstStatement(StatementType.DO_ACTION, AstLiteral("action"))],
)
output = str(plan)
assert "+!goal" in output
assert ": " not in output
assert "<- .action." in output
def test_ast_program():
prog = AstProgram(
rules=[AstRule(AstLiteral("fact"))],
plans=[AstPlan(TriggerType.ADDED_BELIEF, AstLiteral("b"), [], [])],
)
output = str(prog)
assert "fact." in output
assert "+b" in output

View File

@@ -0,0 +1,187 @@
import uuid
import pytest
from control_backend.agents.bdi.agentspeak_ast import AstProgram
from control_backend.agents.bdi.agentspeak_generator import AgentSpeakGenerator
from control_backend.schemas.program import (
BasicNorm,
ConditionalNorm,
Gesture,
GestureAction,
Goal,
InferredBelief,
KeywordBelief,
LLMAction,
LogicalOperator,
Phase,
Plan,
Program,
SemanticBelief,
SpeechAction,
Trigger,
)
@pytest.fixture
def generator():
return AgentSpeakGenerator()
def test_generate_empty_program(generator):
prog = Program(phases=[])
code = generator.generate(prog)
assert 'phase("end").' in code
assert "!notify_cycle" in code
def test_generate_basic_norm(generator):
norm = BasicNorm(id=uuid.uuid4(), name="n1", norm="be nice")
phase = Phase(id=uuid.uuid4(), norms=[norm], goals=[], triggers=[])
prog = Program(phases=[phase])
code = generator.generate(prog)
assert f'norm("be nice") :- phase("{phase.id}").' in code
def test_generate_critical_norm(generator):
norm = BasicNorm(id=uuid.uuid4(), name="n1", norm="safety", critical=True)
phase = Phase(id=uuid.uuid4(), norms=[norm], goals=[], triggers=[])
prog = Program(phases=[phase])
code = generator.generate(prog)
assert f'critical_norm("safety") :- phase("{phase.id}").' in code
def test_generate_conditional_norm(generator):
cond = KeywordBelief(id=uuid.uuid4(), name="k1", keyword="please")
norm = ConditionalNorm(id=uuid.uuid4(), name="n1", norm="help", condition=cond)
phase = Phase(id=uuid.uuid4(), norms=[norm], goals=[], triggers=[])
prog = Program(phases=[phase])
code = generator.generate(prog)
assert 'norm("help")' in code
assert 'keyword_said("please")' in code
assert f"force_norm_{generator._slugify_str(norm.norm)}" in code
def test_generate_goal_and_plan(generator):
action = SpeechAction(id=uuid.uuid4(), name="s1", text="hello")
plan = Plan(id=uuid.uuid4(), name="p1", steps=[action])
# IMPORTANT: can_fail must be False for +achieved_ belief to be added
goal = Goal(id=uuid.uuid4(), name="g1", description="desc", plan=plan, can_fail=False)
phase = Phase(id=uuid.uuid4(), norms=[], goals=[goal], triggers=[])
prog = Program(phases=[phase])
code = generator.generate(prog)
# Check trigger for goal
goal_slug = generator._slugify_str(goal.name)
assert f"+!{goal_slug}" in code
assert f'phase("{phase.id}")' in code
assert '!say("hello")' in code
# Check success belief addition
assert f"+achieved_{goal_slug}" in code
def test_generate_subgoal(generator):
subplan = Plan(id=uuid.uuid4(), name="p2", steps=[])
subgoal = Goal(id=uuid.uuid4(), name="sub1", description="sub", plan=subplan)
plan = Plan(id=uuid.uuid4(), name="p1", steps=[subgoal])
goal = Goal(id=uuid.uuid4(), name="g1", description="main", plan=plan)
phase = Phase(id=uuid.uuid4(), norms=[], goals=[goal], triggers=[])
prog = Program(phases=[phase])
code = generator.generate(prog)
subgoal_slug = generator._slugify_str(subgoal.name)
# Main goal calls subgoal
assert f"!{subgoal_slug}" in code
# Subgoal plan exists
assert f"+!{subgoal_slug}" in code
def test_generate_trigger(generator):
cond = SemanticBelief(id=uuid.uuid4(), name="s1", description="desc")
plan = Plan(id=uuid.uuid4(), name="p1", steps=[])
trigger = Trigger(id=uuid.uuid4(), name="t1", condition=cond, plan=plan)
phase = Phase(id=uuid.uuid4(), norms=[], goals=[], triggers=[trigger])
prog = Program(phases=[phase])
code = generator.generate(prog)
# Trigger logic is added to check_triggers
assert f"{generator.slugify(cond)}" in code
assert f'notify_trigger_start("{generator.slugify(trigger)}")' in code
assert f'notify_trigger_end("{generator.slugify(trigger)}")' in code
def test_phase_transition(generator):
phase1 = Phase(id=uuid.uuid4(), name="p1", norms=[], goals=[], triggers=[])
phase2 = Phase(id=uuid.uuid4(), name="p2", norms=[], goals=[], triggers=[])
prog = Program(phases=[phase1, phase2])
code = generator.generate(prog)
assert "transition_phase" in code
assert f'phase("{phase1.id}")' in code
assert f'phase("{phase2.id}")' in code
assert "force_transition_phase" in code
def test_astify_gesture(generator):
gesture = Gesture(type="single", name="wave")
action = GestureAction(id=uuid.uuid4(), name="g1", gesture=gesture)
ast = generator._astify(action)
assert str(ast) == 'gesture("single", "wave")'
def test_astify_llm_action(generator):
action = LLMAction(id=uuid.uuid4(), name="l1", goal="be funny")
ast = generator._astify(action)
assert str(ast) == 'reply_with_goal("be funny")'
def test_astify_inferred_belief_and(generator):
left = KeywordBelief(id=uuid.uuid4(), name="k1", keyword="a")
right = KeywordBelief(id=uuid.uuid4(), name="k2", keyword="b")
inf = InferredBelief(
id=uuid.uuid4(), name="i1", operator=LogicalOperator.AND, left=left, right=right
)
ast = generator._astify(inf)
assert 'keyword_said("a") & keyword_said("b")' == str(ast)
def test_astify_inferred_belief_or(generator):
left = KeywordBelief(id=uuid.uuid4(), name="k1", keyword="a")
right = KeywordBelief(id=uuid.uuid4(), name="k2", keyword="b")
inf = InferredBelief(
id=uuid.uuid4(), name="i1", operator=LogicalOperator.OR, left=left, right=right
)
ast = generator._astify(inf)
assert 'keyword_said("a") | keyword_said("b")' == str(ast)
def test_astify_semantic_belief(generator):
sb = SemanticBelief(id=uuid.uuid4(), name="s1", description="desc")
ast = generator._astify(sb)
assert str(ast) == f"semantic_{generator._slugify_str(sb.name)}"
def test_slugify_not_implemented(generator):
with pytest.raises(NotImplementedError):
generator.slugify("not a program element")
def test_astify_not_implemented(generator):
with pytest.raises(NotImplementedError):
generator._astify("not a program element")
def test_process_phase_transition_from_none(generator):
# Initialize AstProgram manually as we are bypassing generate()
generator._asp = AstProgram()
# Should safely return doing nothing
generator._add_phase_transition(None, None)
assert len(generator._asp.plans) == 0

View File

@@ -45,23 +45,34 @@ async def test_setup_no_asl(mock_agentspeak_env, agent):
@pytest.mark.asyncio
async def test_handle_belief_collector_message(agent, mock_settings):
async def test_handle_belief_message(agent, mock_settings):
"""Test that incoming beliefs are added to the BDI agent"""
beliefs = [Belief(name="user_said", arguments=["Hello"])]
msg = InternalMessage(
to="bdi_agent",
sender=mock_settings.agent_settings.bdi_belief_collector_name,
sender=mock_settings.agent_settings.text_belief_extractor_name,
body=BeliefMessage(create=beliefs).model_dump_json(),
thread="beliefs",
)
await agent.handle_message(msg)
# Expect bdi_agent.call to be triggered to add belief
args = agent.bdi_agent.call.call_args.args
assert args[0] == agentspeak.Trigger.addition
assert args[1] == agentspeak.GoalType.belief
assert args[2] == agentspeak.Literal("user_said", (agentspeak.Literal("Hello"),))
# Check for the specific call we expect among all calls
# bdi_agent.call is called multiple times (for transition_phase, check_triggers)
# We want to confirm the belief addition call exists
found_call = False
for call in agent.bdi_agent.call.call_args_list:
args = call.args
if (
args[0] == agentspeak.Trigger.addition
and args[1] == agentspeak.GoalType.belief
and args[2].functor == "user_said"
and args[2].args[0].functor == "Hello"
):
found_call = True
break
assert found_call, "Expected belief addition call not found in bdi_agent.call history"
@pytest.mark.asyncio
@@ -71,25 +82,33 @@ async def test_handle_delete_belief_message(agent, mock_settings):
msg = InternalMessage(
to="bdi_agent",
sender=mock_settings.agent_settings.bdi_belief_collector_name,
sender=mock_settings.agent_settings.text_belief_extractor_name,
body=BeliefMessage(delete=beliefs).model_dump_json(),
thread="beliefs",
)
await agent.handle_message(msg)
# Expect bdi_agent.call to be triggered to remove belief
args = agent.bdi_agent.call.call_args.args
assert args[0] == agentspeak.Trigger.removal
assert args[1] == agentspeak.GoalType.belief
assert args[2] == agentspeak.Literal("user_said", (agentspeak.Literal("Hello"),))
found_call = False
for call in agent.bdi_agent.call.call_args_list:
args = call.args
if (
args[0] == agentspeak.Trigger.removal
and args[1] == agentspeak.GoalType.belief
and args[2].functor == "user_said"
and args[2].args[0].functor == "Hello"
):
found_call = True
break
assert found_call
@pytest.mark.asyncio
async def test_incorrect_belief_collector_message(agent, mock_settings):
async def test_incorrect_belief_message(agent, mock_settings):
"""Test that incorrect message format triggers an exception."""
msg = InternalMessage(
to="bdi_agent",
sender=mock_settings.agent_settings.bdi_belief_collector_name,
sender=mock_settings.agent_settings.text_belief_extractor_name,
body=json.dumps({"bad_format": "bad_format"}),
thread="beliefs",
)
@@ -171,7 +190,11 @@ def test_remove_belief_success_wakes_loop(agent):
agent._remove_belief("remove_me", ["x"])
assert agent.bdi_agent.call.called
trigger, goaltype, literal, *_ = agent.bdi_agent.call.call_args.args
call_args = agent.bdi_agent.call.call_args.args
trigger = call_args[0]
goaltype = call_args[1]
literal = call_args[2]
assert trigger == agentspeak.Trigger.removal
assert goaltype == agentspeak.GoalType.belief
@@ -288,3 +311,216 @@ async def test_deadline_sleep_branch(agent):
duration = time.time() - start_time
assert duration >= 0.004 # loop slept until deadline
@pytest.mark.asyncio
async def test_handle_new_program(agent):
agent._load_asl = AsyncMock()
agent.add_behavior = MagicMock()
# Mock existing loop task so it can be cancelled
mock_task = MagicMock()
mock_task.cancel = MagicMock()
agent._bdi_loop_task = mock_task
def close_coro(coro):
coro.close()
return MagicMock()
agent.add_behavior = MagicMock(side_effect=close_coro)
msg = InternalMessage(to="bdi_agent", thread="new_program", body="path/to/asl.asl")
await agent.handle_message(msg)
mock_task.cancel.assert_called_once()
agent._load_asl.assert_awaited_once_with("path/to/asl.asl")
agent.add_behavior.assert_called()
@pytest.mark.asyncio
async def test_handle_user_interrupts(agent, mock_settings):
mock_settings.agent_settings.user_interrupt_name = "user_interrupt_agent"
# force_phase_transition
agent._set_goal = MagicMock()
msg = InternalMessage(
to="bdi_agent",
sender=mock_settings.agent_settings.user_interrupt_name,
thread="force_phase_transition",
body="",
)
await agent.handle_message(msg)
agent._set_goal.assert_called_with("transition_phase")
# force_trigger
agent._force_trigger = MagicMock()
msg.thread = "force_trigger"
msg.body = "trigger_x"
await agent.handle_message(msg)
agent._force_trigger.assert_called_with("trigger_x")
# force_norm
agent._force_norm = MagicMock()
msg.thread = "force_norm"
msg.body = "norm_y"
await agent.handle_message(msg)
agent._force_norm.assert_called_with("norm_y")
# force_next_phase
agent._force_next_phase = MagicMock()
msg.thread = "force_next_phase"
msg.body = ""
await agent.handle_message(msg)
agent._force_next_phase.assert_called_once()
# unknown interrupt
agent.logger = MagicMock()
msg.thread = "unknown_thing"
await agent.handle_message(msg)
agent.logger.warning.assert_called()
@pytest.mark.asyncio
async def test_custom_action_reply_with_goal(agent):
agent._send_to_llm = MagicMock(side_effect=agent.send)
agent._add_custom_actions()
action_fn = agent.actions.actions[(".reply_with_goal", 3)]
mock_term = MagicMock(args=["msg", "norms", "goal"])
gen = action_fn(agent, mock_term, MagicMock())
next(gen)
agent._send_to_llm.assert_called_with("msg", "norms", "goal")
@pytest.mark.asyncio
async def test_custom_action_notify_norms(agent):
agent._add_custom_actions()
action_fn = agent.actions.actions[(".notify_norms", 1)]
mock_term = MagicMock(args=["norms_list"])
gen = action_fn(agent, mock_term, MagicMock())
next(gen)
agent.send.assert_called()
msg = agent.send.call_args[0][0]
assert msg.thread == "active_norms_update"
assert msg.body == "norms_list"
@pytest.mark.asyncio
async def test_custom_action_say(agent):
agent._add_custom_actions()
action_fn = agent.actions.actions[(".say", 1)]
mock_term = MagicMock(args=["hello"])
gen = action_fn(agent, mock_term, MagicMock())
next(gen)
assert agent.send.call_count == 2
msgs = [c[0][0] for c in agent.send.call_args_list]
assert any(m.to == settings.agent_settings.robot_speech_name for m in msgs)
assert any(
m.to == settings.agent_settings.llm_name and m.thread == "assistant_message" for m in msgs
)
@pytest.mark.asyncio
async def test_custom_action_gesture(agent):
agent._add_custom_actions()
# Test single
action_fn = agent.actions.actions[(".gesture", 2)]
mock_term = MagicMock(args=["single", "wave"])
gen = action_fn(agent, mock_term, MagicMock())
next(gen)
msg = agent.send.call_args[0][0]
assert "actuate/gesture/single" in msg.body
# Test tag
mock_term.args = ["tag", "happy"]
gen = action_fn(agent, mock_term, MagicMock())
next(gen)
msg = agent.send.call_args[0][0]
assert "actuate/gesture/tag" in msg.body
@pytest.mark.asyncio
async def test_custom_action_notify_user_said(agent):
agent._add_custom_actions()
action_fn = agent.actions.actions[(".notify_user_said", 1)]
mock_term = MagicMock(args=["hello"])
gen = action_fn(agent, mock_term, MagicMock())
next(gen)
msg = agent.send.call_args[0][0]
assert msg.to == settings.agent_settings.llm_name
assert msg.thread == "user_message"
@pytest.mark.asyncio
async def test_custom_action_notify_trigger_start_end(agent):
agent._add_custom_actions()
# Start
action_fn = agent.actions.actions[(".notify_trigger_start", 1)]
gen = action_fn(agent, MagicMock(args=["t1"]), MagicMock())
next(gen)
assert agent.send.call_args[0][0].thread == "trigger_start"
# End
action_fn = agent.actions.actions[(".notify_trigger_end", 1)]
gen = action_fn(agent, MagicMock(args=["t1"]), MagicMock())
next(gen)
assert agent.send.call_args[0][0].thread == "trigger_end"
@pytest.mark.asyncio
async def test_custom_action_notify_goal_start(agent):
agent._add_custom_actions()
action_fn = agent.actions.actions[(".notify_goal_start", 1)]
gen = action_fn(agent, MagicMock(args=["g1"]), MagicMock())
next(gen)
assert agent.send.call_args[0][0].thread == "goal_start"
@pytest.mark.asyncio
async def test_custom_action_notify_transition_phase(agent):
agent._add_custom_actions()
action_fn = agent.actions.actions[(".notify_transition_phase", 2)]
gen = action_fn(agent, MagicMock(args=["old", "new"]), MagicMock())
next(gen)
msg = agent.send.call_args[0][0]
assert msg.thread == "transition_phase"
assert "old" in msg.body and "new" in msg.body
def test_remove_belief_no_args(agent):
agent._wake_bdi_loop = MagicMock()
agent.bdi_agent.call.return_value = True
agent._remove_belief("fact", None)
assert agent.bdi_agent.call.called
def test_set_goal_with_args(agent):
agent._wake_bdi_loop = MagicMock()
agent._set_goal("goal", ["arg1", "arg2"])
assert agent.bdi_agent.call.called
def test_format_belief_string():
assert BDICoreAgent.format_belief_string("b") == "b"
assert BDICoreAgent.format_belief_string("b", ["a1", "a2"]) == "b(a1,a2)"
def test_force_norm(agent):
agent._add_belief = MagicMock()
agent._force_norm("be_polite")
agent._add_belief.assert_called_with("force_be_polite")
def test_force_trigger(agent):
agent._set_goal = MagicMock()
agent._force_trigger("trig")
agent._set_goal.assert_called_with("trig")
def test_force_next_phase(agent):
agent._set_goal = MagicMock()
agent._force_next_phase()
agent._set_goal.assert_called_with("force_transition_phase")

View File

@@ -1,13 +1,13 @@
import asyncio
import json
import sys
import uuid
from unittest.mock import AsyncMock
from unittest.mock import AsyncMock, MagicMock, mock_open, patch
import pytest
from control_backend.agents.bdi.bdi_program_manager import BDIProgramManager
from control_backend.core.agent_system import InternalMessage
from control_backend.schemas.belief_message import BeliefMessage
from control_backend.schemas.program import BasicNorm, Goal, Phase, Plan, Program
# Fix Windows Proactor loop for zmq
@@ -48,24 +48,26 @@ def make_valid_program_json(norm="N1", goal="G1") -> str:
).model_dump_json()
@pytest.mark.skip(reason="Functionality being rebuilt.")
@pytest.mark.asyncio
async def test_send_to_bdi():
async def test_create_agentspeak_and_send_to_bdi(mock_settings):
manager = BDIProgramManager(name="program_manager_test")
manager.send = AsyncMock()
program = Program.model_validate_json(make_valid_program_json())
await manager._create_agentspeak_and_send_to_bdi(program)
with patch("builtins.open", mock_open()) as mock_file:
await manager._create_agentspeak_and_send_to_bdi(program)
# Check file writing
mock_file.assert_called_with("src/control_backend/agents/bdi/agentspeak.asl", "w")
handle = mock_file()
handle.write.assert_called()
assert manager.send.await_count == 1
msg: InternalMessage = manager.send.await_args[0][0]
assert msg.thread == "beliefs"
beliefs = BeliefMessage.model_validate_json(msg.body)
names = {b.name: b.arguments for b in beliefs.beliefs}
assert "norms" in names and names["norms"] == ["N1"]
assert "goals" in names and names["goals"] == ["G1"]
assert msg.thread == "new_program"
assert msg.to == mock_settings.agent_settings.bdi_core_name
assert msg.body == "src/control_backend/agents/bdi/agentspeak.asl"
@pytest.mark.asyncio
@@ -80,6 +82,10 @@ async def test_receive_programs_valid_and_invalid():
manager._internal_pub_socket = AsyncMock()
manager.sub_socket = sub
manager._create_agentspeak_and_send_to_bdi = AsyncMock()
manager._send_clear_llm_history = AsyncMock()
manager._send_program_to_user_interrupt = AsyncMock()
manager._send_beliefs_to_semantic_belief_extractor = AsyncMock()
manager._send_goals_to_semantic_belief_extractor = AsyncMock()
try:
# Will give StopAsyncIteration when the predefined `sub.recv_multipart` side-effects run out
@@ -92,3 +98,200 @@ async def test_receive_programs_valid_and_invalid():
forwarded: Program = manager._create_agentspeak_and_send_to_bdi.await_args[0][0]
assert forwarded.phases[0].norms[0].name == "N1"
assert forwarded.phases[0].goals[0].name == "G1"
# Verify history clear was triggered exactly once (for the valid program)
# The invalid program loop `continue`s before calling _send_clear_llm_history
assert manager._send_clear_llm_history.await_count == 1
@pytest.mark.asyncio
async def test_send_clear_llm_history(mock_settings):
# Ensure the mock returns a string for the agent name (just like in your LLM tests)
mock_settings.agent_settings.llm_agent_name = "llm_agent"
manager = BDIProgramManager(name="program_manager_test")
manager.send = AsyncMock()
await manager._send_clear_llm_history()
assert manager.send.await_count == 2
msg: InternalMessage = manager.send.await_args_list[0][0][0]
# Verify the content and recipient
assert msg.body == "clear_history"
@pytest.mark.asyncio
async def test_handle_message_transition_phase(mock_settings):
mock_settings.agent_settings.user_interrupt_name = "user_interrupt_agent"
manager = BDIProgramManager(name="program_manager_test")
manager.send = AsyncMock()
# Setup state
prog = Program.model_validate_json(make_valid_program_json(norm="N1", goal="G1"))
manager._initialize_internal_state(prog)
# Test valid transition (to same phase for simplicity, or we need 2 phases)
# Let's create a program with 2 phases
phase2_id = uuid.uuid4()
phase2 = Phase(id=phase2_id, name="Phase 2", norms=[], goals=[], triggers=[])
prog.phases.append(phase2)
manager._initialize_internal_state(prog)
current_phase_id = str(prog.phases[0].id)
next_phase_id = str(phase2_id)
payload = json.dumps({"old": current_phase_id, "new": next_phase_id})
msg = InternalMessage(to="me", sender="bdi", body=payload, thread="transition_phase")
await manager.handle_message(msg)
assert str(manager._phase.id) == next_phase_id
# Allow background tasks to run (add_behavior)
await asyncio.sleep(0)
# Check notifications sent
# 1. beliefs to extractor
# 2. goals to extractor
# 3. notification to user interrupt
assert manager.send.await_count >= 3
# Verify user interrupt notification
calls = manager.send.await_args_list
ui_msgs = [
c[0][0] for c in calls if c[0][0].to == mock_settings.agent_settings.user_interrupt_name
]
assert len(ui_msgs) > 0
assert ui_msgs[-1].body == next_phase_id
@pytest.mark.asyncio
async def test_handle_message_transition_phase_desync():
manager = BDIProgramManager(name="program_manager_test")
manager.logger = MagicMock()
prog = Program.model_validate_json(make_valid_program_json())
manager._initialize_internal_state(prog)
current_phase_id = str(prog.phases[0].id)
# Request transition from WRONG old phase
payload = json.dumps({"old": "wrong_id", "new": "some_new_id"})
msg = InternalMessage(to="me", sender="bdi", body=payload, thread="transition_phase")
await manager.handle_message(msg)
# Should warn and do nothing
manager.logger.warning.assert_called_once()
assert "Phase transition desync detected" in manager.logger.warning.call_args[0][0]
assert str(manager._phase.id) == current_phase_id
@pytest.mark.asyncio
async def test_handle_message_transition_phase_end(mock_settings):
mock_settings.agent_settings.user_interrupt_name = "user_interrupt_agent"
manager = BDIProgramManager(name="program_manager_test")
manager.send = AsyncMock()
prog = Program.model_validate_json(make_valid_program_json())
manager._initialize_internal_state(prog)
current_phase_id = str(prog.phases[0].id)
payload = json.dumps({"old": current_phase_id, "new": "end"})
msg = InternalMessage(to="me", sender="bdi", body=payload, thread="transition_phase")
await manager.handle_message(msg)
assert manager._phase is None
# Allow background tasks to run (add_behavior)
await asyncio.sleep(0)
# Verify notification to user interrupt
assert manager.send.await_count == 1
msg_sent = manager.send.await_args[0][0]
assert msg_sent.to == mock_settings.agent_settings.user_interrupt_name
assert msg_sent.body == "end"
@pytest.mark.asyncio
async def test_handle_message_achieve_goal(mock_settings):
mock_settings.agent_settings.text_belief_extractor_name = "text_belief_extractor_agent"
manager = BDIProgramManager(name="program_manager_test")
manager.send = AsyncMock()
prog = Program.model_validate_json(make_valid_program_json(goal="TargetGoal"))
manager._initialize_internal_state(prog)
goal_id = str(prog.phases[0].goals[0].id)
msg = InternalMessage(to="me", sender="ui", body=goal_id, thread="achieve_goal")
await manager.handle_message(msg)
# Should send achieved goals to text extractor
assert manager.send.await_count == 1
msg_sent = manager.send.await_args[0][0]
assert msg_sent.to == mock_settings.agent_settings.text_belief_extractor_name
assert msg_sent.thread == "achieved_goals"
# Verify body
from control_backend.schemas.belief_list import GoalList
gl = GoalList.model_validate_json(msg_sent.body)
assert len(gl.goals) == 1
assert gl.goals[0].name == "TargetGoal"
@pytest.mark.asyncio
async def test_handle_message_achieve_goal_not_found():
manager = BDIProgramManager(name="program_manager_test")
manager.send = AsyncMock()
manager.logger = MagicMock()
prog = Program.model_validate_json(make_valid_program_json())
manager._initialize_internal_state(prog)
msg = InternalMessage(to="me", sender="ui", body="non_existent_id", thread="achieve_goal")
await manager.handle_message(msg)
manager.send.assert_not_called()
manager.logger.debug.assert_called()
@pytest.mark.asyncio
async def test_setup(mock_settings):
manager = BDIProgramManager(name="program_manager_test")
manager.send = AsyncMock()
def close_coro(coro):
coro.close()
return MagicMock()
manager.add_behavior = MagicMock(side_effect=close_coro)
mock_context = MagicMock()
mock_sub = MagicMock()
mock_context.socket.return_value = mock_sub
with patch(
"control_backend.agents.bdi.bdi_program_manager.Context.instance", return_value=mock_context
):
# We also need to mock file writing in _create_agentspeak_and_send_to_bdi
with patch("builtins.open", new_callable=MagicMock):
await manager.setup()
# Check logic
# 1. Sends default empty program to BDI
assert manager.send.await_count == 1
assert manager.send.await_args[0][0].to == mock_settings.agent_settings.bdi_core_name
# 2. Connects SUB socket
mock_sub.connect.assert_called_with(mock_settings.zmq_settings.internal_sub_address)
mock_sub.subscribe.assert_called_with("program")
# 3. Adds behavior
manager.add_behavior.assert_called()

View File

@@ -1,135 +0,0 @@
import json
from unittest.mock import AsyncMock
import pytest
from control_backend.agents.bdi import (
BDIBeliefCollectorAgent,
)
from control_backend.core.agent_system import InternalMessage
from control_backend.core.config import settings
from control_backend.schemas.belief_message import Belief
@pytest.fixture
def agent():
agent = BDIBeliefCollectorAgent("belief_collector_agent")
return agent
def make_msg(body: dict, sender: str = "sender"):
return InternalMessage(to="collector", sender=sender, body=json.dumps(body))
@pytest.mark.asyncio
async def test_handle_message_routes_belief_text(agent, mocker):
"""
Test that when a message is received, _handle_belief_text is called with that message.
"""
payload = {"type": "belief_extraction_text", "beliefs": {"user_said": [["hi"]]}}
spy = mocker.patch.object(agent, "_handle_belief_text", new_callable=AsyncMock)
await agent.handle_message(make_msg(payload))
spy.assert_awaited_once_with(payload, "sender")
@pytest.mark.asyncio
async def test_handle_message_routes_emotion(agent, mocker):
payload = {"type": "emotion_extraction_text"}
spy = mocker.patch.object(agent, "_handle_emo_text", new_callable=AsyncMock)
await agent.handle_message(make_msg(payload))
spy.assert_awaited_once_with(payload, "sender")
@pytest.mark.asyncio
async def test_handle_message_bad_json(agent, mocker):
agent._handle_belief_text = AsyncMock()
bad_msg = InternalMessage(to="collector", sender="sender", body="not json")
await agent.handle_message(bad_msg)
agent._handle_belief_text.assert_not_awaited()
@pytest.mark.asyncio
async def test_handle_belief_text_sends_when_beliefs_exist(agent, mocker):
payload = {"type": "belief_extraction_text", "beliefs": {"user_said": ["hello"]}}
spy = mocker.patch.object(agent, "_send_beliefs_to_bdi", new_callable=AsyncMock)
expected = [Belief(name="user_said", arguments=["hello"])]
await agent._handle_belief_text(payload, "origin")
spy.assert_awaited_once_with(expected, origin="origin")
@pytest.mark.asyncio
async def test_handle_belief_text_no_send_when_empty(agent, mocker):
payload = {"type": "belief_extraction_text", "beliefs": {}}
spy = mocker.patch.object(agent, "_send_beliefs_to_bdi", new_callable=AsyncMock)
await agent._handle_belief_text(payload, "origin")
spy.assert_not_awaited()
@pytest.mark.asyncio
async def test_send_beliefs_to_bdi(agent):
agent.send = AsyncMock()
beliefs = [Belief(name="user_said", arguments=["hello", "world"])]
await agent._send_beliefs_to_bdi(beliefs, origin="origin")
agent.send.assert_awaited_once()
sent: InternalMessage = agent.send.call_args.args[0]
assert sent.to == settings.agent_settings.bdi_core_name
assert sent.thread == "beliefs"
assert json.loads(sent.body)["create"] == [belief.model_dump() for belief in beliefs]
@pytest.mark.asyncio
async def test_setup_executes(agent):
"""Covers setup and asserts the agent has a name."""
await agent.setup()
assert agent.name == "belief_collector_agent" # simple property assertion
@pytest.mark.asyncio
async def test_handle_message_unrecognized_type_executes(agent):
"""Covers the else branch for unrecognized message type."""
payload = {"type": "unknown_type"}
msg = make_msg(payload, sender="tester")
# Wrap send to ensure nothing is sent
agent.send = AsyncMock()
await agent.handle_message(msg)
# Assert no messages were sent
agent.send.assert_not_awaited()
@pytest.mark.asyncio
async def test_handle_emo_text_executes(agent):
"""Covers the _handle_emo_text method."""
# The method does nothing, but we can assert it returns None
result = await agent._handle_emo_text({}, "origin")
assert result is None
@pytest.mark.asyncio
async def test_send_beliefs_to_bdi_empty_executes(agent):
"""Covers early return when beliefs are empty."""
agent.send = AsyncMock()
await agent._send_beliefs_to_bdi({})
# Assert that nothing was sent
agent.send.assert_not_awaited()
@pytest.mark.asyncio
async def test_handle_belief_text_invalid_returns_none(agent, mocker):
payload = {"type": "belief_extraction_text", "beliefs": {"user_said": "invalid-argument"}}
result = await agent._handle_belief_text(payload, "origin")
# The method itself returns None
assert result is None

View File

@@ -6,11 +6,15 @@ import httpx
import pytest
from control_backend.agents.bdi import TextBeliefExtractorAgent
from control_backend.agents.bdi.text_belief_extractor_agent import BeliefState
from control_backend.core.agent_system import InternalMessage
from control_backend.core.config import settings
from control_backend.schemas.belief_list import BeliefList
from control_backend.schemas.belief_message import Belief as InternalBelief
from control_backend.schemas.belief_message import BeliefMessage
from control_backend.schemas.chat_history import ChatHistory, ChatMessage
from control_backend.schemas.program import (
BaseGoal, # Changed from Goal
ConditionalNorm,
KeywordBelief,
LLMAction,
@@ -23,11 +27,22 @@ from control_backend.schemas.program import (
@pytest.fixture
def agent():
agent = TextBeliefExtractorAgent("text_belief_agent")
agent.send = AsyncMock()
agent._query_llm = AsyncMock()
return agent
def llm():
llm = TextBeliefExtractorAgent.LLM(MagicMock(), 4)
# We must ensure _query_llm returns a dictionary so iterating it doesn't fail
llm._query_llm = AsyncMock(return_value={})
return llm
@pytest.fixture
def agent(llm):
with patch(
"control_backend.agents.bdi.text_belief_extractor_agent.TextBeliefExtractorAgent.LLM",
return_value=llm,
):
agent = TextBeliefExtractorAgent("text_belief_agent")
agent.send = AsyncMock()
return agent
@pytest.fixture
@@ -102,24 +117,12 @@ async def test_handle_message_from_transcriber(agent, mock_settings):
agent.send.assert_awaited_once() # noqa # `agent.send` has no such property, but we mock it.
sent: InternalMessage = agent.send.call_args.args[0] # noqa
assert sent.to == mock_settings.agent_settings.bdi_belief_collector_name
assert sent.to == mock_settings.agent_settings.bdi_core_name
assert sent.thread == "beliefs"
parsed = json.loads(sent.body)
assert parsed == {"beliefs": {"user_said": [transcription]}, "type": "belief_extraction_text"}
@pytest.mark.asyncio
async def test_process_user_said(agent, mock_settings):
transcription = "this is a test"
await agent._user_said(transcription)
agent.send.assert_awaited_once() # noqa # `agent.send` has no such property, but we mock it.
sent: InternalMessage = agent.send.call_args.args[0] # noqa
assert sent.to == mock_settings.agent_settings.bdi_belief_collector_name
assert sent.thread == "beliefs"
parsed = json.loads(sent.body)
assert parsed["beliefs"]["user_said"] == [transcription]
parsed = BeliefMessage.model_validate_json(sent.body)
replaced_last = parsed.replace.pop()
assert replaced_last.name == "user_said"
assert replaced_last.arguments == [transcription]
@pytest.mark.asyncio
@@ -144,46 +147,46 @@ async def test_query_llm():
"control_backend.agents.bdi.text_belief_extractor_agent.httpx.AsyncClient",
return_value=mock_async_client,
):
agent = TextBeliefExtractorAgent("text_belief_agent")
llm = TextBeliefExtractorAgent.LLM(MagicMock(), 4)
res = await agent._query_llm("hello world", {"type": "null"})
res = await llm._query_llm("hello world", {"type": "null"})
# Response content was set as "null", so should be deserialized as None
assert res is None
@pytest.mark.asyncio
async def test_retry_query_llm_success(agent):
agent._query_llm.return_value = None
res = await agent._retry_query_llm("hello world", {"type": "null"})
async def test_retry_query_llm_success(llm):
llm._query_llm.return_value = None
res = await llm.query("hello world", {"type": "null"})
agent._query_llm.assert_called_once()
llm._query_llm.assert_called_once()
assert res is None
@pytest.mark.asyncio
async def test_retry_query_llm_success_after_failure(agent):
agent._query_llm.side_effect = [KeyError(), "real value"]
res = await agent._retry_query_llm("hello world", {"type": "string"})
async def test_retry_query_llm_success_after_failure(llm):
llm._query_llm.side_effect = [KeyError(), "real value"]
res = await llm.query("hello world", {"type": "string"})
assert agent._query_llm.call_count == 2
assert llm._query_llm.call_count == 2
assert res == "real value"
@pytest.mark.asyncio
async def test_retry_query_llm_failures(agent):
agent._query_llm.side_effect = [KeyError(), KeyError(), KeyError(), "real value"]
res = await agent._retry_query_llm("hello world", {"type": "string"})
async def test_retry_query_llm_failures(llm):
llm._query_llm.side_effect = [KeyError(), KeyError(), KeyError(), "real value"]
res = await llm.query("hello world", {"type": "string"})
assert agent._query_llm.call_count == 3
assert llm._query_llm.call_count == 3
assert res is None
@pytest.mark.asyncio
async def test_retry_query_llm_fail_immediately(agent):
agent._query_llm.side_effect = [KeyError(), "real value"]
res = await agent._retry_query_llm("hello world", {"type": "string"}, tries=1)
async def test_retry_query_llm_fail_immediately(llm):
llm._query_llm.side_effect = [KeyError(), "real value"]
res = await llm.query("hello world", {"type": "string"}, tries=1)
assert agent._query_llm.call_count == 1
assert llm._query_llm.call_count == 1
assert res is None
@@ -192,7 +195,7 @@ async def test_extracting_semantic_beliefs(agent):
"""
The Program Manager sends beliefs to this agent. Test whether the agent handles them correctly.
"""
assert len(agent.available_beliefs) == 0
assert len(agent.belief_inferrer.available_beliefs) == 0
beliefs = BeliefList(
beliefs=[
KeywordBelief(
@@ -213,26 +216,28 @@ async def test_extracting_semantic_beliefs(agent):
to=settings.agent_settings.text_belief_extractor_name,
sender=settings.agent_settings.bdi_program_manager_name,
body=beliefs.model_dump_json(),
thread="beliefs",
),
)
assert len(agent.available_beliefs) == 2
assert len(agent.belief_inferrer.available_beliefs) == 2
@pytest.mark.asyncio
async def test_handle_invalid_program(agent, sample_program):
agent.available_beliefs.append(sample_program.phases[0].norms[0].condition)
agent.available_beliefs.append(sample_program.phases[0].triggers[0].condition)
assert len(agent.available_beliefs) == 2
async def test_handle_invalid_beliefs(agent, sample_program):
agent.belief_inferrer.available_beliefs.append(sample_program.phases[0].norms[0].condition)
agent.belief_inferrer.available_beliefs.append(sample_program.phases[0].triggers[0].condition)
assert len(agent.belief_inferrer.available_beliefs) == 2
await agent.handle_message(
InternalMessage(
to=settings.agent_settings.text_belief_extractor_name,
sender=settings.agent_settings.bdi_program_manager_name,
body=json.dumps({"phases": "Invalid"}),
thread="beliefs",
),
)
assert len(agent.available_beliefs) == 2
assert len(agent.belief_inferrer.available_beliefs) == 2
@pytest.mark.asyncio
@@ -254,13 +259,13 @@ async def test_handle_robot_response(agent):
@pytest.mark.asyncio
async def test_simulated_real_turn_with_beliefs(agent, sample_program):
async def test_simulated_real_turn_with_beliefs(agent, llm, sample_program):
"""Test sending user message to extract beliefs from."""
agent.available_beliefs.append(sample_program.phases[0].norms[0].condition)
agent.available_beliefs.append(sample_program.phases[0].triggers[0].condition)
agent.belief_inferrer.available_beliefs.append(sample_program.phases[0].norms[0].condition)
agent.belief_inferrer.available_beliefs.append(sample_program.phases[0].triggers[0].condition)
# Send a user message with the belief that there's no more booze
agent._query_llm.return_value = {"is_pirate": None, "no_more_booze": True}
llm._query_llm.return_value = {"is_pirate": None, "no_more_booze": True}
assert len(agent.conversation.messages) == 0
await agent.handle_message(
InternalMessage(
@@ -275,20 +280,20 @@ async def test_simulated_real_turn_with_beliefs(agent, sample_program):
assert agent.send.call_count == 2
# First should be the beliefs message
message: InternalMessage = agent.send.call_args_list[0].args[0]
message: InternalMessage = agent.send.call_args_list[1].args[0]
beliefs = BeliefMessage.model_validate_json(message.body)
assert len(beliefs.create) == 1
assert beliefs.create[0].name == "no_more_booze"
@pytest.mark.asyncio
async def test_simulated_real_turn_no_beliefs(agent, sample_program):
async def test_simulated_real_turn_no_beliefs(agent, llm, sample_program):
"""Test a user message to extract beliefs from, but no beliefs are formed."""
agent.available_beliefs.append(sample_program.phases[0].norms[0].condition)
agent.available_beliefs.append(sample_program.phases[0].triggers[0].condition)
agent.belief_inferrer.available_beliefs.append(sample_program.phases[0].norms[0].condition)
agent.belief_inferrer.available_beliefs.append(sample_program.phases[0].triggers[0].condition)
# Send a user message with no new beliefs
agent._query_llm.return_value = {"is_pirate": None, "no_more_booze": None}
llm._query_llm.return_value = {"is_pirate": None, "no_more_booze": None}
await agent.handle_message(
InternalMessage(
to=settings.agent_settings.text_belief_extractor_name,
@@ -302,17 +307,17 @@ async def test_simulated_real_turn_no_beliefs(agent, sample_program):
@pytest.mark.asyncio
async def test_simulated_real_turn_no_new_beliefs(agent, sample_program):
async def test_simulated_real_turn_no_new_beliefs(agent, llm, sample_program):
"""
Test a user message to extract beliefs from, but no new beliefs are formed because they already
existed.
"""
agent.available_beliefs.append(sample_program.phases[0].norms[0].condition)
agent.available_beliefs.append(sample_program.phases[0].triggers[0].condition)
agent.beliefs["is_pirate"] = True
agent.belief_inferrer.available_beliefs.append(sample_program.phases[0].norms[0].condition)
agent.belief_inferrer.available_beliefs.append(sample_program.phases[0].triggers[0].condition)
agent._current_beliefs = BeliefState(true={InternalBelief(name="is_pirate", arguments=None)})
# Send a user message with the belief the user is a pirate, still
agent._query_llm.return_value = {"is_pirate": True, "no_more_booze": None}
llm._query_llm.return_value = {"is_pirate": True, "no_more_booze": None}
await agent.handle_message(
InternalMessage(
to=settings.agent_settings.text_belief_extractor_name,
@@ -326,17 +331,19 @@ async def test_simulated_real_turn_no_new_beliefs(agent, sample_program):
@pytest.mark.asyncio
async def test_simulated_real_turn_remove_belief(agent, sample_program):
async def test_simulated_real_turn_remove_belief(agent, llm, sample_program):
"""
Test a user message to extract beliefs from, but an existing belief is determined no longer to
hold.
"""
agent.available_beliefs.append(sample_program.phases[0].norms[0].condition)
agent.available_beliefs.append(sample_program.phases[0].triggers[0].condition)
agent.beliefs["no_more_booze"] = True
agent.belief_inferrer.available_beliefs.append(sample_program.phases[0].norms[0].condition)
agent.belief_inferrer.available_beliefs.append(sample_program.phases[0].triggers[0].condition)
agent._current_beliefs = BeliefState(
true={InternalBelief(name="no_more_booze", arguments=None)},
)
# Send a user message with the belief the user is a pirate, still
agent._query_llm.return_value = {"is_pirate": None, "no_more_booze": False}
llm._query_llm.return_value = {"is_pirate": None, "no_more_booze": False}
await agent.handle_message(
InternalMessage(
to=settings.agent_settings.text_belief_extractor_name,
@@ -349,18 +356,199 @@ async def test_simulated_real_turn_remove_belief(agent, sample_program):
assert agent.send.call_count == 2
# Agent's current beliefs should've changed
assert not agent.beliefs["no_more_booze"]
assert any(b.name == "no_more_booze" for b in agent._current_beliefs.false)
@pytest.mark.asyncio
async def test_llm_failure_handling(agent, sample_program):
async def test_infer_goal_completions_sends_beliefs(agent, llm):
"""Test that inferred goal completions are sent to the BDI core."""
goal = BaseGoal(
id=uuid.uuid4(), name="Say Hello", description="The user said hello", can_fail=True
)
agent.goal_inferrer.goals = {goal}
# Mock goal inference: goal is achieved
llm.query = AsyncMock(return_value=True)
await agent._infer_goal_completions()
# Should send belief change to BDI core
agent.send.assert_awaited_once()
sent: InternalMessage = agent.send.call_args.args[0]
assert sent.to == settings.agent_settings.bdi_core_name
assert sent.thread == "beliefs"
parsed = BeliefMessage.model_validate_json(sent.body)
assert len(parsed.create) == 1
assert parsed.create[0].name == "achieved_say_hello"
@pytest.mark.asyncio
async def test_llm_failure_handling(agent, llm, sample_program):
"""
Check that the agent handles failures gracefully without crashing.
"""
agent._query_llm.side_effect = httpx.HTTPError("")
agent.available_beliefs.append(sample_program.phases[0].norms[0].condition)
agent.available_beliefs.append(sample_program.phases[0].triggers[0].condition)
llm._query_llm.side_effect = httpx.HTTPError("")
agent.belief_inferrer.available_beliefs.append(sample_program.phases[0].norms[0].condition)
agent.belief_inferrer.available_beliefs.append(sample_program.phases[0].triggers[0].condition)
belief_changes = await agent._infer_turn()
belief_changes = await agent.belief_inferrer.infer_from_conversation(
ChatHistory(
messages=[ChatMessage(role="user", content="Good day!")],
),
)
assert len(belief_changes) == 0
assert len(belief_changes.true) == 0
assert len(belief_changes.false) == 0
def test_belief_state_bool():
# Empty
bs = BeliefState()
assert not bs
# True set
bs_true = BeliefState(true={InternalBelief(name="a", arguments=None)})
assert bs_true
# False set
bs_false = BeliefState(false={InternalBelief(name="a", arguments=None)})
assert bs_false
@pytest.mark.asyncio
async def test_handle_beliefs_message_validation_error(agent, mock_settings):
# Invalid JSON
mock_settings.agent_settings.bdi_program_manager_name = "bdi_program_manager_agent"
msg = InternalMessage(
to="me",
sender=mock_settings.agent_settings.bdi_program_manager_name,
thread="beliefs",
body="invalid json",
)
# Should log warning and return
agent.logger = MagicMock()
await agent.handle_message(msg)
agent.logger.warning.assert_called()
# Invalid Model
msg.body = json.dumps({"beliefs": [{"invalid": "obj"}]})
await agent.handle_message(msg)
agent.logger.warning.assert_called()
@pytest.mark.asyncio
async def test_handle_goals_message_validation_error(agent, mock_settings):
mock_settings.agent_settings.bdi_program_manager_name = "bdi_program_manager_agent"
msg = InternalMessage(
to="me",
sender=mock_settings.agent_settings.bdi_program_manager_name,
thread="goals",
body="invalid json",
)
agent.logger = MagicMock()
await agent.handle_message(msg)
agent.logger.warning.assert_called()
@pytest.mark.asyncio
async def test_handle_goal_achieved_message_validation_error(agent, mock_settings):
mock_settings.agent_settings.bdi_program_manager_name = "bdi_program_manager_agent"
msg = InternalMessage(
to="me",
sender=mock_settings.agent_settings.bdi_program_manager_name,
thread="achieved_goals",
body="invalid json",
)
agent.logger = MagicMock()
await agent.handle_message(msg)
agent.logger.warning.assert_called()
@pytest.mark.asyncio
async def test_goal_inferrer_infer_from_conversation(agent, llm):
# Setup goals
# Use BaseGoal object as typically received by the extractor
g1 = BaseGoal(id=uuid.uuid4(), name="g1", description="desc", can_fail=True)
# Use real GoalAchievementInferrer
from control_backend.agents.bdi.text_belief_extractor_agent import GoalAchievementInferrer
inferrer = GoalAchievementInferrer(llm)
inferrer.goals = {g1}
# Mock LLM response
llm._query_llm.return_value = True
completions = await inferrer.infer_from_conversation(ChatHistory(messages=[]))
assert completions
# slugify uses slugify library, hard to predict exact string without it,
# but we can check values
assert list(completions.values())[0] is True
def test_apply_conversation_message_limit(agent):
with patch("control_backend.agents.bdi.text_belief_extractor_agent.settings") as mock_s:
mock_s.behaviour_settings.conversation_history_length_limit = 2
agent.conversation.messages = []
agent._apply_conversation_message(ChatMessage(role="user", content="1"))
agent._apply_conversation_message(ChatMessage(role="assistant", content="2"))
agent._apply_conversation_message(ChatMessage(role="user", content="3"))
assert len(agent.conversation.messages) == 2
assert agent.conversation.messages[0].content == "2"
assert agent.conversation.messages[1].content == "3"
@pytest.mark.asyncio
async def test_handle_program_manager_reset(agent):
with patch("control_backend.agents.bdi.text_belief_extractor_agent.settings") as mock_s:
mock_s.agent_settings.bdi_program_manager_name = "pm"
agent.conversation.messages = [ChatMessage(role="user", content="hi")]
agent.belief_inferrer.available_beliefs = [
SemanticBelief(id=uuid.uuid4(), name="b", description="d")
]
msg = InternalMessage(to="me", sender="pm", thread="conversation_history", body="reset")
await agent.handle_message(msg)
assert len(agent.conversation.messages) == 0
assert len(agent.belief_inferrer.available_beliefs) == 0
def test_split_into_chunks():
from control_backend.agents.bdi.text_belief_extractor_agent import SemanticBeliefInferrer
items = [1, 2, 3, 4, 5]
chunks = SemanticBeliefInferrer._split_into_chunks(items, 2)
assert len(chunks) == 2
assert len(chunks[0]) + len(chunks[1]) == 5
@pytest.mark.asyncio
async def test_infer_beliefs_call(agent, llm):
from control_backend.agents.bdi.text_belief_extractor_agent import SemanticBeliefInferrer
inferrer = SemanticBeliefInferrer(llm)
sb = SemanticBelief(id=uuid.uuid4(), name="is_happy", description="User is happy")
llm.query = AsyncMock(return_value={"is_happy": True})
res = await inferrer._infer_beliefs(ChatHistory(messages=[]), [sb])
assert res == {"is_happy": True}
llm.query.assert_called_once()
@pytest.mark.asyncio
async def test_infer_goal_call(agent, llm):
from control_backend.agents.bdi.text_belief_extractor_agent import GoalAchievementInferrer
inferrer = GoalAchievementInferrer(llm)
goal = BaseGoal(id=uuid.uuid4(), name="g1", description="d")
llm.query = AsyncMock(return_value=True)
res = await inferrer._infer_goal(ChatHistory(messages=[]), goal)
assert res is True
llm.query.assert_called_once()

View File

@@ -4,6 +4,8 @@ from unittest.mock import ANY, AsyncMock, MagicMock, patch
import pytest
from control_backend.agents.communication.ri_communication_agent import RICommunicationAgent
from control_backend.core.agent_system import InternalMessage
from control_backend.schemas.ri_message import PauseCommand, RIEndpoint
def speech_agent_path():
@@ -53,7 +55,11 @@ async def test_setup_success_connects_and_starts_robot(zmq_context):
MockGesture.return_value.start = AsyncMock()
agent = RICommunicationAgent("ri_comm", address="tcp://localhost:5555", bind=False)
agent.add_behavior = MagicMock()
def close_coro(coro):
coro.close()
return MagicMock()
agent.add_behavior = MagicMock(side_effect=close_coro)
await agent.setup()
@@ -83,7 +89,11 @@ async def test_setup_binds_when_requested(zmq_context):
agent = RICommunicationAgent("ri_comm", address="tcp://localhost:5555", bind=True)
agent.add_behavior = MagicMock()
def close_coro(coro):
coro.close()
return MagicMock()
agent.add_behavior = MagicMock(side_effect=close_coro)
with (
patch(speech_agent_path(), autospec=True) as MockSpeech,
@@ -151,6 +161,7 @@ async def test_handle_negotiation_response_updates_req_socket(zmq_context):
@pytest.mark.asyncio
async def test_handle_disconnection_publishes_and_reconnects():
pub_socket = AsyncMock()
pub_socket.close = MagicMock()
agent = RICommunicationAgent("ri_comm")
agent.pub_socket = pub_socket
agent.connected = True
@@ -233,6 +244,25 @@ async def test_handle_negotiation_response_unhandled_id():
)
@pytest.mark.asyncio
async def test_handle_negotiation_response_audio(zmq_context):
agent = RICommunicationAgent("ri_comm")
with patch(
"control_backend.agents.communication.ri_communication_agent.VADAgent", autospec=True
) as MockVAD:
MockVAD.return_value.start = AsyncMock()
await agent._handle_negotiation_response(
{"data": [{"id": "audio", "port": 7000, "bind": False}]}
)
MockVAD.assert_called_once_with(
audio_in_address="tcp://localhost:7000", audio_in_bind=False
)
MockVAD.return_value.start.assert_awaited_once()
@pytest.mark.asyncio
async def test_stop_closes_sockets():
req = MagicMock()
@@ -323,6 +353,7 @@ async def test_listen_loop_generic_exception():
@pytest.mark.asyncio
async def test_handle_disconnection_timeout(monkeypatch):
pub = AsyncMock()
pub.close = MagicMock()
pub.send_multipart = AsyncMock(side_effect=TimeoutError)
agent = RICommunicationAgent("ri_comm")
@@ -365,3 +396,38 @@ async def test_negotiate_req_socket_none_causes_retry(zmq_context):
result = await agent._negotiate_connection(max_retries=1)
assert result is False
@pytest.mark.asyncio
async def test_handle_message_pause_command(zmq_context):
"""Test handle_message with a valid PauseCommand."""
agent = RICommunicationAgent("ri_comm")
agent._req_socket = AsyncMock()
agent.logger = MagicMock()
agent._req_socket.recv_json.return_value = {"status": "ok"}
pause_cmd = PauseCommand(data=True)
msg = InternalMessage(to="ri_comm", sender="user_int", body=pause_cmd.model_dump_json())
await agent.handle_message(msg)
agent._req_socket.send_json.assert_awaited_once()
args = agent._req_socket.send_json.await_args[0][0]
assert args["endpoint"] == RIEndpoint.PAUSE.value
assert args["data"] is True
@pytest.mark.asyncio
async def test_handle_message_invalid_pause_command(zmq_context):
"""Test handle_message with invalid JSON."""
agent = RICommunicationAgent("ri_comm")
agent._req_socket = AsyncMock()
agent.logger = MagicMock()
msg = InternalMessage(to="ri_comm", sender="user_int", body="invalid json")
await agent.handle_message(msg)
agent.logger.warning.assert_called_with("Incorrect message format for PauseCommand.")
agent._req_socket.send_json.assert_not_called()

View File

@@ -58,17 +58,20 @@ async def test_llm_processing_success(mock_httpx_client, mock_settings):
to="llm_agent",
sender=mock_settings.agent_settings.bdi_core_name,
body=prompt.model_dump_json(),
thread="prompt_message", # REQUIRED: thread must match handle_message logic
)
await agent.handle_message(msg)
# Verification
# "Hello world." constitutes one sentence/chunk based on punctuation split
# The agent should call send once with the full sentence
# The agent should call send once with the full sentence, PLUS once more for full reply
assert agent.send.called
args = agent.send.call_args_list[0][0][0]
assert args.to == mock_settings.agent_settings.bdi_core_name
assert "Hello world." in args.body
# Check args. We expect at least one call sending "Hello world."
calls = agent.send.call_args_list
bodies = [c[0][0].body for c in calls]
assert any("Hello world." in b for b in bodies)
@pytest.mark.asyncio
@@ -80,18 +83,23 @@ async def test_llm_processing_errors(mock_httpx_client, mock_settings):
to="llm",
sender=mock_settings.agent_settings.bdi_core_name,
body=prompt.model_dump_json(),
thread="prompt_message",
)
# HTTP Error
# HTTP Error: stream method RAISES exception immediately
mock_httpx_client.stream = MagicMock(side_effect=httpx.HTTPError("Fail"))
await agent.handle_message(msg)
assert "LLM service unavailable." in agent.send.call_args[0][0].body
# Check that error message was sent
assert agent.send.called
assert "LLM service unavailable." in agent.send.call_args_list[0][0][0].body
# General Exception
agent.send.reset_mock()
mock_httpx_client.stream = MagicMock(side_effect=Exception("Boom"))
await agent.handle_message(msg)
assert "Error processing the request." in agent.send.call_args[0][0].body
assert "Error processing the request." in agent.send.call_args_list[0][0][0].body
@pytest.mark.asyncio
@@ -113,16 +121,19 @@ async def test_llm_json_error(mock_httpx_client, mock_settings):
agent = LLMAgent("llm_agent")
agent.send = AsyncMock()
# Ensure logger is mocked
agent.logger = MagicMock()
with patch.object(agent.logger, "error") as log:
prompt = LLMPromptMessage(text="Hi", norms=[], goals=[])
msg = InternalMessage(
to="llm",
sender=mock_settings.agent_settings.bdi_core_name,
body=prompt.model_dump_json(),
)
await agent.handle_message(msg)
log.assert_called() # Should log JSONDecodeError
prompt = LLMPromptMessage(text="Hi", norms=[], goals=[])
msg = InternalMessage(
to="llm",
sender=mock_settings.agent_settings.bdi_core_name,
body=prompt.model_dump_json(),
thread="prompt_message",
)
await agent.handle_message(msg)
agent.logger.error.assert_called() # Should log JSONDecodeError
def test_llm_instructions():
@@ -157,6 +168,7 @@ async def test_handle_message_validation_error_branch_no_send(mock_httpx_client,
to="llm_agent",
sender=mock_settings.agent_settings.bdi_core_name,
body=invalid_json,
thread="prompt_message",
)
await agent.handle_message(msg)
@@ -265,3 +277,48 @@ async def test_stream_query_llm_skips_non_data_lines(mock_httpx_client, mock_set
# Only the valid 'data:' line should yield content
assert tokens == ["Hi"]
@pytest.mark.asyncio
async def test_clear_history_command(mock_settings):
"""Test that the 'clear_history' message clears the agent's memory."""
# setup LLM to have some history
mock_settings.agent_settings.bdi_program_manager_name = "bdi_program_manager_agent"
agent = LLMAgent("llm_agent")
agent.history = [
{"role": "user", "content": "Old conversation context"},
{"role": "assistant", "content": "Old response"},
]
assert len(agent.history) == 2
msg = InternalMessage(
to="llm_agent",
sender=mock_settings.agent_settings.bdi_program_manager_name,
body="clear_history",
)
await agent.handle_message(msg)
assert len(agent.history) == 0
@pytest.mark.asyncio
async def test_handle_assistant_and_user_messages(mock_settings):
agent = LLMAgent("llm_agent")
# Assistant message
msg_ast = InternalMessage(
to="llm_agent",
sender=mock_settings.agent_settings.bdi_core_name,
thread="assistant_message",
body="I said this",
)
await agent.handle_message(msg_ast)
assert agent.history[-1] == {"role": "assistant", "content": "I said this"}
# User message
msg_usr = InternalMessage(
to="llm_agent",
sender=mock_settings.agent_settings.bdi_core_name,
thread="user_message",
body="User said this",
)
await agent.handle_message(msg_usr)
assert agent.history[-1] == {"role": "user", "content": "User said this"}

View File

@@ -55,4 +55,6 @@ def test_get_decode_options():
assert isinstance(options["sample_len"], int)
# When disabled, it should not limit output length based on input size
assert "sample_rate" not in options
recognizer = OpenAIWhisperSpeechRecognizer(limit_output_length=False)
options = recognizer._get_decode_options(audio)
assert "sample_len" not in options

View File

@@ -36,7 +36,12 @@ async def test_transcription_agent_flow(mock_zmq_context):
agent.send = AsyncMock()
agent._running = True
agent.add_behavior = AsyncMock()
def close_coro(coro):
coro.close()
return MagicMock()
agent.add_behavior = MagicMock(side_effect=close_coro)
await agent.setup()
@@ -143,7 +148,12 @@ async def test_transcription_loop_continues_after_error(mock_zmq_context):
agent = TranscriptionAgent("tcp://in")
agent._running = True # ← REQUIRED to enter the loop
agent.send = AsyncMock() # should never be called
agent.add_behavior = AsyncMock() # match other tests
def close_coro(coro):
coro.close()
return MagicMock()
agent.add_behavior = MagicMock(side_effect=close_coro) # match other tests
await agent.setup()
@@ -180,7 +190,12 @@ async def test_transcription_continue_branch_when_empty(mock_zmq_context):
# Make loop runnable
agent._running = True
agent.send = AsyncMock()
agent.add_behavior = AsyncMock()
def close_coro(coro):
coro.close()
return MagicMock()
agent.add_behavior = MagicMock(side_effect=close_coro)
await agent.setup()

View File

@@ -0,0 +1,152 @@
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from control_backend.agents.perception.vad_agent import VADAgent
from control_backend.core.agent_system import InternalMessage
from control_backend.schemas.program_status import PROGRAM_STATUS, ProgramStatus
@pytest.fixture(autouse=True)
def mock_zmq():
with patch("zmq.asyncio.Context") as mock:
mock.instance.return_value = MagicMock()
yield mock
@pytest.fixture
def agent():
return VADAgent("tcp://localhost:5555", False)
@pytest.mark.asyncio
async def test_handle_message_pause(agent):
agent._paused = MagicMock()
# It starts set (not paused)
msg = InternalMessage(to="vad", sender="user_interrupt_agent", body="PAUSE")
# We need to mock settings to match sender name
with patch("control_backend.agents.perception.vad_agent.settings") as mock_settings:
mock_settings.agent_settings.user_interrupt_name = "user_interrupt_agent"
await agent.handle_message(msg)
agent._paused.clear.assert_called_once()
assert agent._reset_needed is True
@pytest.mark.asyncio
async def test_handle_message_resume(agent):
agent._paused = MagicMock()
msg = InternalMessage(to="vad", sender="user_interrupt_agent", body="RESUME")
with patch("control_backend.agents.perception.vad_agent.settings") as mock_settings:
mock_settings.agent_settings.user_interrupt_name = "user_interrupt_agent"
await agent.handle_message(msg)
agent._paused.set.assert_called_once()
@pytest.mark.asyncio
async def test_handle_message_unknown_command(agent):
agent._paused = MagicMock()
msg = InternalMessage(to="vad", sender="user_interrupt_agent", body="UNKNOWN")
with patch("control_backend.agents.perception.vad_agent.settings") as mock_settings:
mock_settings.agent_settings.user_interrupt_name = "user_interrupt_agent"
agent.logger = MagicMock()
await agent.handle_message(msg)
agent._paused.clear.assert_not_called()
agent._paused.set.assert_not_called()
@pytest.mark.asyncio
async def test_handle_message_unknown_sender(agent):
agent._paused = MagicMock()
msg = InternalMessage(to="vad", sender="other_agent", body="PAUSE")
with patch("control_backend.agents.perception.vad_agent.settings") as mock_settings:
mock_settings.agent_settings.user_interrupt_name = "user_interrupt_agent"
await agent.handle_message(msg)
agent._paused.clear.assert_not_called()
@pytest.mark.asyncio
async def test_status_loop_waits_for_running(agent):
agent._running = True
agent.program_sub_socket = AsyncMock()
agent.program_sub_socket.close = MagicMock()
agent._reset_stream = AsyncMock()
# Sequence of messages:
# 1. Wrong topic
# 2. Right topic, wrong status (STARTING)
# 3. Right topic, RUNNING -> Should break loop
agent.program_sub_socket.recv_multipart.side_effect = [
(b"wrong_topic", b"whatever"),
(PROGRAM_STATUS, ProgramStatus.STARTING.value),
(PROGRAM_STATUS, ProgramStatus.RUNNING.value),
]
await agent._status_loop()
assert agent._reset_stream.await_count == 1
agent.program_sub_socket.close.assert_called_once()
@pytest.mark.asyncio
async def test_setup_success(agent, mock_zmq):
def close_coro(coro):
coro.close()
return MagicMock()
agent.add_behavior = MagicMock(side_effect=close_coro)
mock_context = mock_zmq.instance.return_value
mock_sub = MagicMock()
mock_pub = MagicMock()
# We expect multiple socket calls:
# 1. audio_in (SUB)
# 2. audio_out (PUB)
# 3. program_sub (SUB)
mock_context.socket.side_effect = [mock_sub, mock_pub, mock_sub]
with patch("control_backend.agents.perception.vad_agent.torch.hub.load") as mock_load:
mock_load.return_value = (MagicMock(), None)
with patch("control_backend.agents.perception.vad_agent.TranscriptionAgent") as MockTrans:
mock_trans_instance = MockTrans.return_value
mock_trans_instance.start = AsyncMock()
await agent.setup()
mock_trans_instance.start.assert_awaited_once()
assert agent.add_behavior.call_count == 2 # streaming_loop + status_loop
assert agent.audio_in_socket is not None
assert agent.audio_out_socket is not None
assert agent.program_sub_socket is not None
@pytest.mark.asyncio
async def test_reset_stream(agent):
mock_poller = MagicMock()
agent.audio_in_poller = mock_poller
# poll(1) returns not None twice, then None
mock_poller.poll = AsyncMock(side_effect=[b"data", b"data", None])
agent._ready = MagicMock()
await agent._reset_stream()
assert mock_poller.poll.await_count == 3
agent._ready.set.assert_called_once()

View File

@@ -5,6 +5,16 @@ import pytest
import zmq
from control_backend.agents.perception.vad_agent import VADAgent
from control_backend.core.config import settings
# We don't want to use real ZMQ in unit tests, for example because it can give errors when sockets
# aren't closed properly.
@pytest.fixture(autouse=True)
def mock_zmq():
with patch("zmq.asyncio.Context") as mock:
mock.instance.return_value = MagicMock()
yield mock
@pytest.fixture
@@ -126,6 +136,54 @@ async def test_no_data(audio_out_socket, vad_agent):
assert len(vad_agent.audio_buffer) == 0
@pytest.mark.asyncio
async def test_streaming_loop_reset_needed(audio_out_socket, vad_agent):
"""Test that _reset_needed branch works as expected."""
vad_agent._reset_needed = True
vad_agent._ready.set()
vad_agent._paused.set()
vad_agent._running = True
vad_agent.audio_buffer = np.array([1.0], dtype=np.float32)
vad_agent.i_since_speech = 0
# Mock _reset_stream to stop the loop by setting _running=False
async def mock_reset():
vad_agent._running = False
vad_agent._reset_stream = mock_reset
# Needs a poller to avoid AssertionError
vad_agent.audio_in_poller = AsyncMock()
vad_agent.audio_in_poller.poll.return_value = None
await vad_agent._streaming_loop()
assert vad_agent._reset_needed is False
assert len(vad_agent.audio_buffer) == 0
assert vad_agent.i_since_speech == settings.behaviour_settings.vad_initial_since_speech
@pytest.mark.asyncio
async def test_streaming_loop_no_data_clears_buffer(audio_out_socket, vad_agent):
"""Test that if poll returns None, buffer is cleared if not empty."""
vad_agent.audio_buffer = np.array([1.0], dtype=np.float32)
vad_agent._ready.set()
vad_agent._paused.set()
vad_agent._running = True
class MockPoller:
async def poll(self, timeout_ms=None):
vad_agent._running = False # stop after one poll
return None
vad_agent.audio_in_poller = MockPoller()
await vad_agent._streaming_loop()
assert len(vad_agent.audio_buffer) == 0
assert vad_agent.i_since_speech == settings.behaviour_settings.vad_initial_since_speech
@pytest.mark.asyncio
async def test_vad_model_load_failure_stops_agent(vad_agent):
"""
@@ -140,12 +198,10 @@ async def test_vad_model_load_failure_stops_agent(vad_agent):
# Patch stop to an AsyncMock so we can check it was awaited
vad_agent.stop = AsyncMock()
result = await vad_agent.setup()
await vad_agent.setup()
# Assert stop was called
vad_agent.stop.assert_awaited_once()
# Assert setup returned None
assert result is None
@pytest.mark.asyncio
@@ -155,7 +211,7 @@ async def test_audio_out_bind_failure_sets_none_and_logs(vad_agent, caplog):
audio_out_socket is set to None, None is returned, and an error is logged.
"""
mock_socket = MagicMock()
mock_socket.bind_to_random_port.side_effect = zmq.ZMQBindError()
mock_socket.bind.side_effect = zmq.ZMQBindError()
with patch("control_backend.agents.perception.vad_agent.azmq.Context.instance") as mock_ctx:
mock_ctx.return_value.socket.return_value = mock_socket

View File

@@ -0,0 +1,24 @@
import logging
from control_backend.agents.base import BaseAgent
class MyAgent(BaseAgent):
async def setup(self):
pass
async def handle_message(self, msg):
pass
def test_base_agent_logger_init():
# When defining a subclass, __init_subclass__ runs
# The BaseAgent in agents/base.py sets the logger
assert hasattr(MyAgent, "logger")
assert isinstance(MyAgent.logger, logging.Logger)
# The logger name depends on the package.
# Since this test file is running as a module, __package__ might be None or the test package.
# In 'src/control_backend/agents/base.py', it uses __package__ of base.py which is
# 'control_backend.agents'.
# So logger name should be control_backend.agents.MyAgent
assert MyAgent.logger.name == "control_backend.agents.MyAgent"

View File

@@ -7,6 +7,15 @@ import pytest
from control_backend.agents.user_interrupt.user_interrupt_agent import UserInterruptAgent
from control_backend.core.agent_system import InternalMessage
from control_backend.core.config import settings
from control_backend.schemas.program import (
ConditionalNorm,
Goal,
KeywordBelief,
Phase,
Plan,
Program,
Trigger,
)
from control_backend.schemas.ri_message import RIEndpoint
@@ -16,6 +25,7 @@ def agent():
agent.send = AsyncMock()
agent.logger = MagicMock()
agent.sub_socket = AsyncMock()
agent.pub_socket = AsyncMock()
return agent
@@ -49,21 +59,18 @@ async def test_send_to_gesture_agent(agent):
@pytest.mark.asyncio
async def test_send_to_program_manager(agent):
async def test_send_to_bdi_belief(agent):
"""Verify belief update format."""
context_str = "2"
context_str = "some_goal"
await agent._send_to_program_manager(context_str)
await agent._send_to_bdi_belief(context_str)
agent.send.assert_awaited_once()
sent_msg: InternalMessage = agent.send.call_args.args[0]
assert agent.send.await_count == 1
sent_msg = agent.send.call_args.args[0]
assert sent_msg.to == settings.agent_settings.bdi_program_manager_name
assert sent_msg.thread == "belief_override_id"
body = json.loads(sent_msg.body)
assert body["belief"] == context_str
assert sent_msg.to == settings.agent_settings.bdi_core_name
assert sent_msg.thread == "beliefs"
assert "achieved_some_goal" in sent_msg.body
@pytest.mark.asyncio
@@ -77,6 +84,10 @@ async def test_receive_loop_routing_success(agent):
# Prepare JSON payloads as bytes
payload_speech = json.dumps({"type": "speech", "context": "Hello Speech"}).encode()
payload_gesture = json.dumps({"type": "gesture", "context": "Hello Gesture"}).encode()
# override calls _send_to_bdi (for trigger/norm) OR _send_to_bdi_belief (for goal).
# To test routing, we need to populate the maps
agent._goal_map["Hello Override"] = "some_goal_slug"
payload_override = json.dumps({"type": "override", "context": "Hello Override"}).encode()
agent.sub_socket.recv_multipart.side_effect = [
@@ -88,7 +99,7 @@ async def test_receive_loop_routing_success(agent):
agent._send_to_speech_agent = AsyncMock()
agent._send_to_gesture_agent = AsyncMock()
agent._send_to_program_manager = AsyncMock()
agent._send_to_bdi_belief = AsyncMock()
try:
await agent._receive_button_event()
@@ -103,12 +114,12 @@ async def test_receive_loop_routing_success(agent):
# Gesture
agent._send_to_gesture_agent.assert_awaited_once_with("Hello Gesture")
# Override
agent._send_to_program_manager.assert_awaited_once_with("Hello Override")
# Override (since we mapped it to a goal)
agent._send_to_bdi_belief.assert_awaited_once_with("some_goal_slug")
assert agent._send_to_speech_agent.await_count == 1
assert agent._send_to_gesture_agent.await_count == 1
assert agent._send_to_program_manager.await_count == 1
assert agent._send_to_bdi_belief.await_count == 1
@pytest.mark.asyncio
@@ -125,7 +136,6 @@ async def test_receive_loop_unknown_type(agent):
agent._send_to_speech_agent = AsyncMock()
agent._send_to_gesture_agent = AsyncMock()
agent._send_to_belief_collector = AsyncMock()
try:
await agent._receive_button_event()
@@ -137,10 +147,165 @@ async def test_receive_loop_unknown_type(agent):
# Ensure no handlers were called
agent._send_to_speech_agent.assert_not_called()
agent._send_to_gesture_agent.assert_not_called()
agent._send_to_belief_collector.assert_not_called()
agent.logger.warning.assert_called_with(
"Received button press with unknown type '%s' (context: '%s').",
"unknown_thing",
"some_data",
)
agent.logger.warning.assert_called()
@pytest.mark.asyncio
async def test_create_mapping(agent):
# Create a program with a trigger, goal, and conditional norm
import uuid
trigger_id = uuid.uuid4()
goal_id = uuid.uuid4()
norm_id = uuid.uuid4()
cond = KeywordBelief(id=uuid.uuid4(), name="k1", keyword="key")
plan = Plan(id=uuid.uuid4(), name="p1", steps=[])
trigger = Trigger(id=trigger_id, name="my_trigger", condition=cond, plan=plan)
goal = Goal(id=goal_id, name="my_goal", description="desc", plan=plan)
cn = ConditionalNorm(id=norm_id, name="my_norm", norm="be polite", condition=cond)
phase = Phase(id=uuid.uuid4(), name="phase1", norms=[cn], goals=[goal], triggers=[trigger])
prog = Program(phases=[phase])
# Call create_mapping via handle_message
msg = InternalMessage(to="me", thread="new_program", body=prog.model_dump_json())
await agent.handle_message(msg)
# Check maps
assert str(trigger_id) in agent._trigger_map
assert agent._trigger_map[str(trigger_id)] == "trigger_my_trigger"
assert str(goal_id) in agent._goal_map
assert agent._goal_map[str(goal_id)] == "my_goal"
assert str(norm_id) in agent._cond_norm_map
assert agent._cond_norm_map[str(norm_id)] == "norm_be_polite"
@pytest.mark.asyncio
async def test_create_mapping_invalid_json(agent):
# Pass invalid json to handle_message thread "new_program"
msg = InternalMessage(to="me", thread="new_program", body="invalid json")
await agent.handle_message(msg)
# Should log error and maps should remain empty or cleared
agent.logger.error.assert_called()
@pytest.mark.asyncio
async def test_handle_message_trigger_start(agent):
# Setup reverse map manually
agent._trigger_reverse_map["trigger_slug"] = "ui_id_123"
msg = InternalMessage(to="me", thread="trigger_start", body="trigger_slug")
await agent.handle_message(msg)
agent.pub_socket.send_multipart.assert_awaited_once()
args = agent.pub_socket.send_multipart.call_args[0][0]
assert args[0] == b"experiment"
payload = json.loads(args[1])
assert payload["type"] == "trigger_update"
assert payload["id"] == "ui_id_123"
assert payload["achieved"] is True
@pytest.mark.asyncio
async def test_handle_message_trigger_end(agent):
agent._trigger_reverse_map["trigger_slug"] = "ui_id_123"
msg = InternalMessage(to="me", thread="trigger_end", body="trigger_slug")
await agent.handle_message(msg)
agent.pub_socket.send_multipart.assert_awaited_once()
payload = json.loads(agent.pub_socket.send_multipart.call_args[0][0][1])
assert payload["type"] == "trigger_update"
assert payload["achieved"] is False
@pytest.mark.asyncio
async def test_handle_message_transition_phase(agent):
msg = InternalMessage(to="me", thread="transition_phase", body="phase_id_123")
await agent.handle_message(msg)
agent.pub_socket.send_multipart.assert_awaited_once()
payload = json.loads(agent.pub_socket.send_multipart.call_args[0][0][1])
assert payload["type"] == "phase_update"
assert payload["id"] == "phase_id_123"
@pytest.mark.asyncio
async def test_handle_message_goal_start(agent):
agent._goal_reverse_map["goal_slug"] = "goal_id_123"
msg = InternalMessage(to="me", thread="goal_start", body="goal_slug")
await agent.handle_message(msg)
agent.pub_socket.send_multipart.assert_awaited_once()
payload = json.loads(agent.pub_socket.send_multipart.call_args[0][0][1])
assert payload["type"] == "goal_update"
assert payload["id"] == "goal_id_123"
assert payload["active"] is True
@pytest.mark.asyncio
async def test_handle_message_active_norms_update(agent):
agent._cond_norm_reverse_map["norm_active"] = "id_1"
agent._cond_norm_reverse_map["norm_inactive"] = "id_2"
# Body is like: "('norm_active', 'other')"
# The split logic handles quotes etc.
msg = InternalMessage(to="me", thread="active_norms_update", body="'norm_active', 'other'")
await agent.handle_message(msg)
agent.pub_socket.send_multipart.assert_awaited_once()
payload = json.loads(agent.pub_socket.send_multipart.call_args[0][0][1])
assert payload["type"] == "cond_norms_state_update"
norms = {n["id"]: n["active"] for n in payload["norms"]}
assert norms["id_1"] is True
assert norms["id_2"] is False
@pytest.mark.asyncio
async def test_send_experiment_control(agent):
# Test next_phase
await agent._send_experiment_control_to_bdi_core("next_phase")
agent.send.assert_awaited()
msg = agent.send.call_args[0][0]
assert msg.thread == "force_next_phase"
# Test reset_phase
await agent._send_experiment_control_to_bdi_core("reset_phase")
msg = agent.send.call_args[0][0]
assert msg.thread == "reset_current_phase"
# Test reset_experiment
await agent._send_experiment_control_to_bdi_core("reset_experiment")
msg = agent.send.call_args[0][0]
assert msg.thread == "reset_experiment"
@pytest.mark.asyncio
async def test_send_pause_command(agent):
await agent._send_pause_command("true")
# Sends to RI and VAD
assert agent.send.await_count == 2
msgs = [call.args[0] for call in agent.send.call_args_list]
ri_msg = next(m for m in msgs if m.to == settings.agent_settings.ri_communication_name)
assert json.loads(ri_msg.body)["endpoint"] == "" # PAUSE endpoint
assert json.loads(ri_msg.body)["data"] is True
vad_msg = next(m for m in msgs if m.to == settings.agent_settings.vad_name)
assert vad_msg.body == "PAUSE"
agent.send.reset_mock()
await agent._send_pause_command("false")
assert agent.send.await_count == 2
vad_msg = next(
m for m in agent.send.call_args_list if m.args[0].to == settings.agent_settings.vad_name
).args[0]
assert vad_msg.body == "RESUME"

View File

@@ -0,0 +1,96 @@
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
from control_backend.api.v1.endpoints import user_interact
@pytest.fixture
def app():
app = FastAPI()
app.include_router(user_interact.router)
return app
@pytest.fixture
def client(app):
return TestClient(app)
@pytest.mark.asyncio
async def test_receive_button_event(client):
mock_pub_socket = AsyncMock()
client.app.state.endpoints_pub_socket = mock_pub_socket
payload = {"type": "speech", "context": "hello"}
response = client.post("/button_pressed", json=payload)
assert response.status_code == 202
assert response.json() == {"status": "Event received"}
mock_pub_socket.send_multipart.assert_awaited_once()
args = mock_pub_socket.send_multipart.call_args[0][0]
assert args[0] == b"button_pressed"
assert "speech" in args[1].decode()
@pytest.mark.asyncio
async def test_receive_button_event_invalid_payload(client):
mock_pub_socket = AsyncMock()
client.app.state.endpoints_pub_socket = mock_pub_socket
# Missing context
payload = {"type": "speech"}
response = client.post("/button_pressed", json=payload)
assert response.status_code == 422
mock_pub_socket.send_multipart.assert_not_called()
@pytest.mark.asyncio
async def test_experiment_stream_direct_call():
"""
Directly calling the endpoint function to test the streaming logic
without dealing with TestClient streaming limitations.
"""
mock_socket = AsyncMock()
# 1. recv data
# 2. recv timeout
# 3. disconnect (request.is_disconnected returns True)
mock_socket.recv_multipart.side_effect = [
(b"topic", b"message1"),
TimeoutError(),
(b"topic", b"message2"), # Should not be reached if disconnect checks work
]
mock_socket.close = MagicMock()
mock_socket.connect = MagicMock()
mock_socket.subscribe = MagicMock()
mock_context = MagicMock()
mock_context.socket.return_value = mock_socket
with patch(
"control_backend.api.v1.endpoints.user_interact.Context.instance", return_value=mock_context
):
mock_request = AsyncMock()
# is_disconnected sequence:
# 1. False (before first recv) -> reads message1
# 2. False (before second recv) -> triggers TimeoutError, continues
# 3. True (before third recv) -> break loop
mock_request.is_disconnected.side_effect = [False, False, True]
response = await user_interact.experiment_stream(mock_request)
lines = []
# Consume the generator
async for line in response.body_iterator:
lines.append(line)
assert "data: message1\n\n" in lines
assert len(lines) == 1
mock_socket.connect.assert_called()
mock_socket.subscribe.assert_called_with(b"experiment")
mock_socket.close.assert_called()

View File

@@ -25,7 +25,6 @@ def mock_settings():
mock.zmq_settings.internal_sub_address = "tcp://localhost:5561"
mock.zmq_settings.ri_command_address = "tcp://localhost:0000"
mock.agent_settings.bdi_core_name = "bdi_core_agent"
mock.agent_settings.bdi_belief_collector_name = "belief_collector_agent"
mock.agent_settings.llm_name = "llm_agent"
mock.agent_settings.robot_speech_name = "robot_speech_agent"
mock.agent_settings.transcription_name = "transcription_agent"

View File

@@ -99,12 +99,75 @@ async def test_send_to_local_agent(monkeypatch):
# Patch inbox.put
target.inbox.put = AsyncMock()
message = InternalMessage(to="receiver", sender="sender", body="hello")
message = InternalMessage(to=target.name, sender=sender.name, body="hello")
await sender.send(message)
target.inbox.put.assert_awaited_once_with(message)
sender.logger.debug.assert_called_once()
@pytest.mark.asyncio
async def test_send_to_zmq_agent(monkeypatch):
sender = DummyAgent("sender")
target = "remote_receiver"
# Fake logger
sender.logger = MagicMock()
# Fake zmq
sender._internal_pub_socket = AsyncMock()
message = InternalMessage(to=target, sender=sender.name, body="hello")
await sender.send(message)
zmq_calls = sender._internal_pub_socket.send_multipart.call_args[0][0]
assert zmq_calls[0] == f"internal/{target}".encode()
@pytest.mark.asyncio
async def test_send_to_multiple_local_agents(monkeypatch):
sender = DummyAgent("sender")
target1 = DummyAgent("receiver1")
target2 = DummyAgent("receiver2")
# Fake logger
sender.logger = MagicMock()
# Patch inbox.put
target1.inbox.put = AsyncMock()
target2.inbox.put = AsyncMock()
message = InternalMessage(to=[target1.name, target2.name], sender=sender.name, body="hello")
await sender.send(message)
target1.inbox.put.assert_awaited_once_with(message)
target2.inbox.put.assert_awaited_once_with(message)
@pytest.mark.asyncio
async def test_send_to_multiple_agents(monkeypatch):
sender = DummyAgent("sender")
target1 = DummyAgent("receiver1")
target2 = "remote_receiver"
# Fake logger
sender.logger = MagicMock()
# Fake zmq
sender._internal_pub_socket = AsyncMock()
# Patch inbox.put
target1.inbox.put = AsyncMock()
message = InternalMessage(to=[target1.name, target2], sender=sender.name, body="hello")
await sender.send(message)
target1.inbox.put.assert_awaited_once_with(message)
zmq_calls = sender._internal_pub_socket.send_multipart.call_args[0][0]
assert zmq_calls[0] == f"internal/{target2}".encode()
@pytest.mark.asyncio

View File

@@ -0,0 +1,40 @@
from unittest.mock import MagicMock, patch
import zmq
from control_backend.main import setup_sockets
def test_setup_sockets_proxy():
mock_context = MagicMock()
mock_pub = MagicMock()
mock_sub = MagicMock()
mock_context.socket.side_effect = [mock_pub, mock_sub]
with patch("zmq.asyncio.Context.instance", return_value=mock_context):
with patch("zmq.proxy") as mock_proxy:
setup_sockets()
mock_pub.bind.assert_called()
mock_sub.bind.assert_called()
mock_proxy.assert_called_with(mock_sub, mock_pub)
# Check cleanup
mock_pub.close.assert_called()
mock_sub.close.assert_called()
def test_setup_sockets_proxy_error():
mock_context = MagicMock()
mock_pub = MagicMock()
mock_sub = MagicMock()
mock_context.socket.side_effect = [mock_pub, mock_sub]
with patch("zmq.asyncio.Context.instance", return_value=mock_context):
with patch("zmq.proxy", side_effect=zmq.ZMQError):
with patch("control_backend.main.logger") as mock_logger:
setup_sockets()
mock_logger.warning.assert_called()
mock_pub.close.assert_called()
mock_sub.close.assert_called()