Compare commits

..

55 Commits

Author SHA1 Message Date
Storm
cdb7fac53f Merge branch 'dev' into feat/pause-functionality 2026-01-07 15:50:45 +01:00
Storm
d1ad2c1549 feat: implement pausing functionality in CB
ref: N25B-350
2026-01-06 18:08:43 +01:00
Björn Otgaar
612a96940d Merge branch 'feat/environment-variables' into 'dev'
Docs for environment variables, parameterize some constants

See merge request ics/sp/2025/n25b/pepperplus-cb!38
2026-01-06 09:02:49 +00:00
Pim Hutting
4c20656c75 Merge branch 'feat/program-reset-llm' into 'dev'
feat: made program reset LLM

See merge request ics/sp/2025/n25b/pepperplus-cb!39
2026-01-02 15:13:05 +00:00
Pim Hutting
6ca86e4b81 feat: made program reset LLM 2026-01-02 15:13:04 +00:00
Storm
867837dcc4 feat: implemented pause functionality in VAD agent
Functionality is implemented by pausing the _streaming_loop function.

ref: N25B-350
2025-12-30 15:58:18 +02:00
Storm
9adeb1efff Merge branch 'feat/semantic-beliefs' into feat/pause-functionality 2025-12-30 15:52:12 +02:00
Twirre Meulenbelt
42ee5c76d8 test: create tests for belief extractor agent
Includes changes in schemas. Change type of `norms` in `Program` imperceptibly, big changes in schema of `BeliefMessage` to support deleting beliefs.

ref: N25B-380
2025-12-29 17:12:02 +01:00
Twirre Meulenbelt
7d798f2e77 Merge remote-tracking branch 'origin/dev' into feat/environment-variables
# Conflicts:
#	src/control_backend/core/config.py
#	test/unit/agents/actuation/test_robot_speech_agent.py
2025-12-29 12:40:16 +01:00
Twirre Meulenbelt
5282c2471f Merge remote-tracking branch 'origin/dev' into feat/environment-variables
# Conflicts:
#	src/control_backend/core/config.py
#	test/unit/agents/actuation/test_robot_speech_agent.py
2025-12-29 12:35:39 +01:00
Twirre Meulenbelt
57b1276cb5 test: make tests work again after changing Program schema
ref: N25B-380
2025-12-29 12:31:51 +01:00
Storm
200bd27d9b Merge branch 'dev' into feat/pause-functionality 2025-12-29 12:45:14 +02:00
Twirre Meulenbelt
7e0dc9ce1c Merge remote-tracking branch 'origin/feat/agentspeak-generation' into feat/semantic-beliefs
# Conflicts:
#	src/control_backend/schemas/program.py
2025-12-23 17:36:39 +01:00
3253760ef1 feat: new AST representation
File names will be changed eventually.

ref: N25B-376
2025-12-23 17:30:35 +01:00
Twirre Meulenbelt
71cefdfef3 fix: add types to all config properties
ref: N25B-380
2025-12-23 17:14:49 +01:00
Twirre Meulenbelt
33501093a1 feat: extract semantic beliefs from conversation
ref: N25B-380
2025-12-23 17:09:58 +01:00
Luijkx,S.O.H. (Storm)
adbb7ffd5c Merge branch 'feat/user-interrupt-agent' into 'dev'
create UserInterruptAgent with connection to UI

See merge request ics/sp/2025/n25b/pepperplus-cb!40
2025-12-22 13:56:03 +00:00
Pim Hutting
0501a9fba3 create UserInterruptAgent with connection to UI 2025-12-22 13:56:02 +00:00
Storm
539e814c5a feat: functionality implemented for RI pausing functionality
Currently, no CB pausing functionality has been implemented yet. This commit only includes necessary changes to use RI pausing.

ref: N25B-350
2025-12-22 14:02:18 +01:00
756e1f0dc5 feat: persistent rules and stuff
So ugly

ref: N25B-376
2025-12-18 14:33:42 +01:00
Twirre Meulenbelt
f91cec6708 fix: things in AgentSpeak, add custom actions
ref: N25B-376
2025-12-18 11:50:16 +01:00
28262eb27e fix: default case for plans
ref: N25B-376
2025-12-17 16:20:37 +01:00
1d36d2e089 feat: (hopefully) better intermediate representation
ref: N25B-376
2025-12-17 15:33:27 +01:00
742e36b94f chore: non-optional uuid id
ref: N25B-376
2025-12-17 14:30:14 +01:00
Twirre Meulenbelt
57fe3ae3f6 Merge remote-tracking branch 'origin/dev' into feat/agentspeak-generation 2025-12-17 13:20:14 +01:00
e704ec5ed4 feat: basic flow and phase transitions
ref: N25B-376
2025-12-16 17:00:32 +01:00
Twirre Meulenbelt
27f04f0958 style: use yield instead of returning arrays
ref: N25B-376
2025-12-16 16:11:01 +01:00
Twirre Meulenbelt
8cc177041a feat: add a second phase in test_program
ref: N25B-376
2025-12-16 15:12:22 +01:00
4a432a603f fix: separate trigger plan generation
ref: N25B-376
2025-12-16 14:12:04 +01:00
JobvAlewijk
3e7f2ef574 Merge branch 'feat/quiet-llm' into 'dev'
feat: implemented extra log level for LLM token stream

See merge request ics/sp/2025/n25b/pepperplus-cb!37
2025-12-16 11:26:37 +00:00
Luijkx,S.O.H. (Storm)
78abad55d3 feat: implemented extra log level for LLM token stream 2025-12-16 11:26:35 +00:00
bab4800698 feat: add trigger generation
ref: N25B-376
2025-12-16 12:10:52 +01:00
Twirre
4ab6b2a0e6 Merge branch 'feat/cb2ri-gestures' into 'dev'
Gestures in the CB.

See merge request ics/sp/2025/n25b/pepperplus-cb!36
2025-12-16 09:24:25 +00:00
Twirre Meulenbelt
db5504db20 chore: remove redundant check 2025-12-16 10:22:11 +01:00
d043c54336 refactor: program restructure
Also includes some AgentSpeak generation.

ref: N25B-376
2025-12-16 10:21:50 +01:00
Björn Otgaar
f15a518984 fix: tests
ref: N25B-334
2025-12-15 11:52:01 +01:00
Björn Otgaar
71d86f5fb0 Merge branch 'feat/cb2ri-gestures' of git.science.uu.nl:ics/sp/2025/n25b/pepperplus-cb into feat/cb2ri-gestures 2025-12-15 11:36:12 +01:00
Björn Otgaar
daf31ac6a6 fix: change the address to the config, update some logic, seperate the api endpoint, renaming things. yes, the tests don't work right now- this shouldn't be merged yet.
ref: N25B-334
2025-12-15 11:35:56 +01:00
Björn Otgaar
b2d014753d Apply 1 suggestion(s) to 1 file(s)
Co-authored-by: Pim Hutting <p.r.p.hutting@students.uu.nl>
2025-12-11 15:08:15 +00:00
Twirre Meulenbelt
0c682d6440 feat: introduce .env.example, docs
The example includes options that are expected to be changed. It also includes a reference to where in the docs you can find a full list of options.

ref: N25B-352
2025-12-11 13:35:19 +01:00
Björn Otgaar
2e472ea292 chore: remove wrong test paths 2025-12-11 12:48:18 +01:00
Björn Otgaar
1c9b722ba3 Merge branch 'dev' into feat/cb2ri-gestures 2025-12-11 12:46:32 +01:00
Twirre Meulenbelt
32d8f20dc9 feat: parameterize RI host
Was "localhost" in RI Communication Agent, now uses configurable setting. Secretly also removing "localhost" from VAD agent, as its socket should be something that's "inproc".

ref: N25B-352
2025-12-11 12:12:15 +01:00
Twirre Meulenbelt
9cc0e39955 fix: failures main tests since VAD agent initialization was changed
The test still expects the VAD agent to be started in main, rather than in the RI Communication Agent.

ref: N25B-356
2025-12-11 12:04:24 +01:00
Björn Otgaar
2366255b92 Merge branch 'fix/correct-vad-starting' into 'dev'
Move VAD agent creation to RI communication agent

See merge request ics/sp/2025/n25b/pepperplus-cb!34
2025-12-09 14:57:09 +00:00
Björn Otgaar
7f34fede81 fix: fix the tests
ref: N25B-334
2025-12-09 15:37:00 +01:00
Luijkx,S.O.H. (Storm)
a9255cb6e7 Merge branch 'test/coverage-max-all' into 'dev'
test: increased cb test coverage

See merge request ics/sp/2025/n25b/pepperplus-cb!32
2025-12-09 13:14:03 +00:00
JobvAlewijk
7f7c658901 test: increased cb test coverage 2025-12-09 13:14:02 +00:00
Björn Otgaar
3d62e7fc0c Merge branch 'feat/cb2ri-gestures' of git.science.uu.nl:ics/sp/2025/n25b/pepperplus-cb into feat/cb2ri-gestures 2025-12-09 14:13:01 +01:00
Björn Otgaar
6034263259 fix: correct the gestures bugs, change gestures socket to request/reply
ref: N25B-334
2025-12-09 14:08:59 +01:00
JobvAlewijk
63897f5969 chore: double tag 2025-12-09 12:33:43 +01:00
JobvAlewijk
a3cf389c05 Merge branch 'dev' of https://git.science.uu.nl/ics/sp/2025/n25b/pepperplus-cb into fix/correct-vad-starting 2025-12-07 23:09:31 +01:00
JobvAlewijk
de2e56ffce Merge branch 'fix/fix-socket-typing' into 'dev'
chore: fix socket typing in robot speech agent

See merge request ics/sp/2025/n25b/pepperplus-cb!33
2025-12-03 14:26:46 +00:00
Twirre Meulenbelt
21e9d05d6e fix: move VAD agent creation to RI communication agent
Previously, it was started in main, but it should use values negotiated by the RI communication agent.

ref: N25B-356
2025-12-03 15:07:29 +01:00
Björn Otgaar
bacc63aa31 chore: fix socket typing in robot speech agent 2025-12-02 14:22:39 +01:00
57 changed files with 4680 additions and 709 deletions

View File

@@ -0,0 +1,9 @@
%{first_multiline_commit_description}
To verify:
- [ ] Style checks pass
- [ ] Pipeline (tests) pass
- [ ] Documentation is up to date
- [ ] Tests are up to date (new code is covered)
- [ ] ...

View File

@@ -3,6 +3,7 @@ version: 1
custom_levels:
OBSERVATION: 25
ACTION: 26
LLM: 9
formatters:
# Console output
@@ -26,7 +27,7 @@ handlers:
stream: ext://sys.stdout
ui:
class: zmq.log.handlers.PUBHandler
level: DEBUG
level: LLM
formatter: json_experiment
# Level of external libraries
@@ -36,5 +37,5 @@ root:
loggers:
control_backend:
level: DEBUG
level: LLM
handlers: [ui]

View File

@@ -15,6 +15,7 @@ dependencies = [
"pydantic>=2.12.0",
"pydantic-settings>=2.11.0",
"python-json-logger>=4.0.0",
"python-slugify>=8.0.4",
"pyyaml>=6.0.3",
"pyzmq>=27.1.0",
"silero-vad>=6.0.0",

View File

@@ -12,7 +12,7 @@ from control_backend.schemas.ri_message import GestureCommand, RIEndpoint
class RobotGestureAgent(BaseAgent):
"""
This agent acts as a bridge between the control backend and the Robot Interface (RI).
It receives speech commands from other agents or from the UI,
It receives gesture commands from other agents or from the UI,
and forwards them to the robot via a ZMQ PUB socket.
:ivar subsocket: ZMQ SUB socket for receiving external commands (e.g., from UI).
@@ -23,10 +23,12 @@ class RobotGestureAgent(BaseAgent):
"""
subsocket: azmq.Socket
repsocket: azmq.Socket
pubsocket: azmq.Socket
address = ""
bind = False
gesture_data = []
single_gesture_data = []
def __init__(
self,
@@ -34,11 +36,10 @@ class RobotGestureAgent(BaseAgent):
address=settings.zmq_settings.ri_command_address,
bind=False,
gesture_data=None,
single_gesture_data=None,
):
if gesture_data is None:
self.gesture_data = []
else:
self.gesture_data = gesture_data
self.gesture_data = gesture_data or []
self.single_gesture_data = single_gesture_data or []
super().__init__(name)
self.address = address
self.bind = bind
@@ -56,9 +57,8 @@ class RobotGestureAgent(BaseAgent):
context = azmq.Context.instance()
# To the robot
self.pubsocket = context.socket(zmq.PUB)
if self.bind: # TODO: Should this ever be the case?
if self.bind:
self.pubsocket.bind(self.address)
else:
self.pubsocket.connect(self.address)
@@ -69,6 +69,10 @@ class RobotGestureAgent(BaseAgent):
self.subsocket.setsockopt(zmq.SUBSCRIBE, b"command")
self.subsocket.setsockopt(zmq.SUBSCRIBE, b"send_gestures")
# REP socket for replying to gesture requests
self.repsocket = context.socket(zmq.REP)
self.repsocket.bind(settings.zmq_settings.internal_gesture_rep_adress)
self.add_behavior(self._zmq_command_loop())
self.add_behavior(self._fetch_gestures_loop())
@@ -92,13 +96,19 @@ class RobotGestureAgent(BaseAgent):
try:
gesture_command = GestureCommand.model_validate_json(msg.body)
if gesture_command.endpoint == RIEndpoint.GESTURE_TAG:
if gesture_command.data not in self.availableTags():
if gesture_command.data not in self.gesture_data:
self.logger.warning(
"Received gesture tag '%s' which is not in available tags. Early returning",
gesture_command.data,
)
return
elif gesture_command.endpoint == RIEndpoint.GESTURE_SINGLE:
if gesture_command.data not in self.single_gesture_data:
self.logger.warning(
"Received gesture '%s' which is not in available gestures. Early returning",
gesture_command.data,
)
return
await self.pubsocket.send_json(gesture_command.model_dump())
except Exception:
self.logger.exception("Error processing internal message.")
@@ -120,7 +130,7 @@ class RobotGestureAgent(BaseAgent):
body = json.loads(body)
gesture_command = GestureCommand.model_validate(body)
if gesture_command.endpoint == RIEndpoint.GESTURE_TAG:
if gesture_command.data not in self.availableTags():
if gesture_command.data not in self.gesture_data:
self.logger.warning(
"Received gesture tag '%s' which is not in available tags.\
Early returning",
@@ -139,157 +149,23 @@ class RobotGestureAgent(BaseAgent):
"""
while self._running:
try:
topic, body = await self.subsocket.recv_multipart()
# Don't process commands here
if topic != b"send_gestures":
continue
# Get a request
body = await self.repsocket.recv()
# Figure out amount, if specified
try:
body = json.loads(body)
except json.JSONDecodeError:
body = None
# We could have the body be the nummer of gestures you want to fetch or something.
amount = None
if isinstance(body, int):
amount = body
tags = self.availableTags()[:amount] if amount else self.availableTags()
# Fetch tags from gesture data and respond
tags = self.gesture_data[:amount] if amount else self.gesture_data
response = json.dumps({"tags": tags}).encode()
await self.pubsocket.send_multipart(
[
b"get_gestures",
response,
]
)
await self.repsocket.send(response)
except Exception:
self.logger.exception("Error fetching gesture tags.")
def availableTags(self):
"""
Returns the available gesture tags.
:return: List of available gesture tags.
"""
return [
"above",
"affirmative",
"afford",
"agitated",
"all",
"allright",
"alright",
"any",
"assuage",
"assuage",
"attemper",
"back",
"bashful",
"beg",
"beseech",
"blank",
"body language",
"bored",
"bow",
"but",
"call",
"calm",
"choose",
"choice",
"cloud",
"cogitate",
"cool",
"crazy",
"disappointed",
"down",
"earth",
"empty",
"embarrassed",
"enthusiastic",
"entire",
"estimate",
"except",
"exalted",
"excited",
"explain",
"far",
"field",
"floor",
"forlorn",
"friendly",
"front",
"frustrated",
"gentle",
"gift",
"give",
"ground",
"happy",
"hello",
"her",
"here",
"hey",
"hi",
"him",
"hopeless",
"hysterical",
"I",
"implore",
"indicate",
"joyful",
"me",
"meditate",
"modest",
"negative",
"nervous",
"no",
"not know",
"nothing",
"offer",
"ok",
"once upon a time",
"oppose",
"or",
"pacify",
"pick",
"placate",
"please",
"present",
"proffer",
"quiet",
"reason",
"refute",
"reject",
"rousing",
"sad",
"select",
"shamefaced",
"show",
"show sky",
"sky",
"soothe",
"sun",
"supplicate",
"tablet",
"tall",
"them",
"there",
"think",
"timid",
"top",
"unless",
"up",
"upstairs",
"void",
"warm",
"winner",
"yeah",
"yes",
"yoo-hoo",
"you",
"your",
"zero",
"zestful",
]

View File

@@ -29,7 +29,7 @@ class RobotSpeechAgent(BaseAgent):
def __init__(
self,
name: str,
address=settings.zmq_settings.ri_command_address,
address: str,
bind=False,
):
super().__init__(name)

View File

@@ -0,0 +1,203 @@
import typing
from dataclasses import dataclass, field
# --- Types ---
@dataclass
class BeliefLiteral:
"""
Represents a literal or atom.
Example: phase(1), user_said("hello"), ~started
"""
functor: str
args: list[str] = field(default_factory=list)
negated: bool = False
def __str__(self):
# In ASL, 'not' is usually for closed-world assumption (prolog style),
# '~' is for explicit negation in beliefs.
# For simplicity in behavior trees, we often use 'not' for conditions.
prefix = "not " if self.negated else ""
if not self.args:
return f"{prefix}{self.functor}"
# Clean args to ensure strings are quoted if they look like strings,
# but usually the converter handles the quoting of string literals.
args_str = ", ".join(self.args)
return f"{prefix}{self.functor}({args_str})"
@dataclass
class GoalLiteral:
name: str
def __str__(self):
return f"!{self.name}"
@dataclass
class ActionLiteral:
"""
Represents a step in a plan body.
Example: .say("Hello") or !achieve_goal
"""
code: str
def __str__(self):
return self.code
@dataclass
class BinaryOp:
"""
Represents logical operations.
Example: (A & B) | C
"""
left: "Expression | str"
operator: typing.Literal["&", "|"]
right: "Expression | str"
def __str__(self):
l_str = str(self.left)
r_str = str(self.right)
if isinstance(self.left, BinaryOp):
l_str = f"({l_str})"
if isinstance(self.right, BinaryOp):
r_str = f"({r_str})"
return f"{l_str} {self.operator} {r_str}"
Literal = BeliefLiteral | GoalLiteral | ActionLiteral
Expression = Literal | BinaryOp | str
@dataclass
class Rule:
"""
Represents an inference rule.
Example: head :- body.
"""
head: Expression
body: Expression | None = None
def __str__(self):
if not self.body:
return f"{self.head}."
return f"{self.head} :- {self.body}."
@dataclass
class PersistentRule:
"""
Represents an inference rule, where the inferred belief is persistent when formed.
"""
head: Expression
body: Expression
def __str__(self):
if not self.body:
raise Exception("Rule without body should not be persistent.")
lines = []
if isinstance(self.body, BinaryOp):
lines.append(f"+{self.body.left}")
if self.body.operator == "&":
lines.append(f" : {self.body.right}")
lines.append(f" <- +{self.head}.")
if self.body.operator == "|":
lines.append(f"+{self.body.right}")
lines.append(f" <- +{self.head}.")
return "\n".join(lines)
@dataclass
class Plan:
"""
Represents a plan.
Syntax: +trigger : context <- body.
"""
trigger: BeliefLiteral | GoalLiteral
context: list[Expression] = field(default_factory=list)
body: list[ActionLiteral] = field(default_factory=list)
def __str__(self):
# Indentation settings
INDENT = " "
ARROW = "\n <- "
COLON = "\n : "
# Build Header
header = f"+{self.trigger}"
if self.context:
ctx_str = f" &\n{INDENT}".join(str(c) for c in self.context)
header += f"{COLON}{ctx_str}"
# Case 1: Empty body
if not self.body:
return f"{header}."
# Case 2: Short body (optional optimization, keeping it uniform usually better)
header += ARROW
lines = []
# We start the first action on the same line or next line.
# Let's put it on the next line for readability if there are multiple.
if len(self.body) == 1:
return f"{header}{self.body[0]}."
# First item
lines.append(f"{header}{self.body[0]};")
# Middle items
for item in self.body[1:-1]:
lines.append(f"{INDENT}{item};")
# Last item
lines.append(f"{INDENT}{self.body[-1]}.")
return "\n".join(lines)
@dataclass
class AgentSpeakFile:
"""
Root element representing the entire generated file.
"""
initial_beliefs: list[Rule] = field(default_factory=list)
inference_rules: list[Rule | PersistentRule] = field(default_factory=list)
plans: list[Plan] = field(default_factory=list)
def __str__(self):
sections = []
if self.initial_beliefs:
sections.append("// --- Initial Beliefs & Facts ---")
sections.extend(str(rule) for rule in self.initial_beliefs)
sections.append("")
if self.inference_rules:
sections.append("// --- Inference Rules ---")
sections.extend(str(rule) for rule in self.inference_rules if isinstance(rule, Rule))
sections.append("")
sections.extend(
str(rule) for rule in self.inference_rules if isinstance(rule, PersistentRule)
)
sections.append("")
if self.plans:
sections.append("// --- Plans ---")
# Separate plans by a newline for readability
sections.extend(str(plan) + "\n" for plan in self.plans)
return "\n".join(sections)

View File

@@ -0,0 +1,425 @@
import asyncio
import time
from functools import singledispatchmethod
from slugify import slugify
from control_backend.agents.bdi import BDICoreAgent
from control_backend.agents.bdi.asl_ast import (
ActionLiteral,
AgentSpeakFile,
BeliefLiteral,
BinaryOp,
Expression,
GoalLiteral,
PersistentRule,
Plan,
Rule,
)
from control_backend.agents.bdi.bdi_program_manager import test_program
from control_backend.schemas.program import (
BasicBelief,
Belief,
ConditionalNorm,
GestureAction,
Goal,
InferredBelief,
KeywordBelief,
LLMAction,
LogicalOperator,
Phase,
Program,
ProgramElement,
SemanticBelief,
SpeechAction,
)
async def do_things():
res = input("Wanna generate")
if res == "y":
program = AgentSpeakGenerator().generate(test_program)
filename = f"{int(time.time())}.asl"
with open(filename, "w") as f:
f.write(program)
else:
# filename = "0test.asl"
filename = "1766062491.asl"
bdi_agent = BDICoreAgent("BDICoreAgent", filename)
flag = asyncio.Event()
await bdi_agent.start()
await flag.wait()
def do_other_things():
print(AgentSpeakGenerator().generate(test_program))
class AgentSpeakGenerator:
"""
Converts a Pydantic Program behavior model into an AgentSpeak(L) AST,
then renders it to a string.
"""
def generate(self, program: Program) -> str:
asl = AgentSpeakFile()
self._generate_startup(program, asl)
for i, phase in enumerate(program.phases):
next_phase = program.phases[i + 1] if i < len(program.phases) - 1 else None
self._generate_phase_flow(phase, next_phase, asl)
self._generate_norms(phase, asl)
self._generate_goals(phase, asl)
self._generate_triggers(phase, asl)
self._generate_fallbacks(program, asl)
return str(asl)
# --- Section: Startup & Phase Management ---
def _generate_startup(self, program: Program, asl: AgentSpeakFile):
if not program.phases:
return
# Initial belief: phase(start).
asl.initial_beliefs.append(Rule(head=BeliefLiteral("phase", ['"start"'])))
# Startup plan: +started : phase(start) <- -phase(start); +phase(first_id).
asl.plans.append(
Plan(
trigger=BeliefLiteral("started"),
context=[BeliefLiteral("phase", ['"start"'])],
body=[
ActionLiteral('-phase("start")'),
ActionLiteral(f'+phase("{program.phases[0].id}")'),
],
)
)
# Initial plans:
asl.plans.append(
Plan(
trigger=GoalLiteral("generate_response_with_goal(Goal)"),
context=[BeliefLiteral("user_said", ["Message"])],
body=[
ActionLiteral("+responded_this_turn"),
ActionLiteral(".findall(Norm, norm(Norm), Norms)"),
ActionLiteral(".reply_with_goal(Message, Norms, Goal)"),
],
)
)
def _generate_phase_flow(self, phase: Phase, next_phase: Phase | None, asl: AgentSpeakFile):
"""Generates the main loop listener and the transition logic for this phase."""
# +user_said(Message) : phase(ID) <- !goal1; !goal2; !transition_phase.
goal_actions = [ActionLiteral("-responded_this_turn")]
goal_actions += [
ActionLiteral(f"!check_{self._slugify_str(keyword)}")
for keyword in self._get_keyword_conditionals(phase)
]
goal_actions += [ActionLiteral(f"!{self._slugify(g)}") for g in phase.goals]
goal_actions.append(ActionLiteral("!transition_phase"))
asl.plans.append(
Plan(
trigger=BeliefLiteral("user_said", ["Message"]),
context=[BeliefLiteral("phase", [f'"{phase.id}"'])],
body=goal_actions,
)
)
# +!transition_phase : phase(ID) <- -phase(ID); +(NEXT_ID).
next_id = str(next_phase.id) if next_phase else "end"
transition_context = [BeliefLiteral("phase", [f'"{phase.id}"'])]
if phase.goals:
transition_context.append(BeliefLiteral(f"achieved_{self._slugify(phase.goals[-1])}"))
asl.plans.append(
Plan(
trigger=GoalLiteral("transition_phase"),
context=transition_context,
body=[
ActionLiteral(f'-phase("{phase.id}")'),
ActionLiteral(f'+phase("{next_id}")'),
ActionLiteral("user_said(Anything)"),
ActionLiteral("-+user_said(Anything)"),
],
)
)
def _get_keyword_conditionals(self, phase: Phase) -> list[str]:
res = []
for belief in self._extract_basic_beliefs_from_phase(phase):
if isinstance(belief, KeywordBelief):
res.append(belief.keyword)
return res
# --- Section: Norms & Beliefs ---
def _generate_norms(self, phase: Phase, asl: AgentSpeakFile):
for norm in phase.norms:
norm_slug = f'"{norm.norm}"'
head = BeliefLiteral("norm", [norm_slug])
# Base context is the phase
phase_lit = BeliefLiteral("phase", [f'"{phase.id}"'])
if isinstance(norm, ConditionalNorm):
self._ensure_belief_inference(norm.condition, asl)
condition_expr = self._belief_to_expr(norm.condition)
body = BinaryOp(phase_lit, "&", condition_expr)
else:
body = phase_lit
asl.inference_rules.append(Rule(head=head, body=body))
def _ensure_belief_inference(self, belief: Belief, asl: AgentSpeakFile):
"""
Recursively adds rules to infer beliefs.
Checks strictly to avoid duplicates if necessary,
though ASL engines often handle redefinition or we can use a set to track processed IDs.
"""
if isinstance(belief, KeywordBelief):
pass
# # Rule: keyword_said("word") :- user_said(M) & .substring("word", M, P) & P >= 0.
# kwd_slug = f'"{belief.keyword}"'
# head = BeliefLiteral("keyword_said", [kwd_slug])
#
# # Avoid duplicates
# if any(str(r.head) == str(head) for r in asl.inference_rules):
# return
#
# body = BinaryOp(
# BeliefLiteral("user_said", ["Message"]),
# "&",
# BinaryOp(f".substring({kwd_slug}, Message, Pos)", "&", "Pos >= 0"),
# )
#
# asl.inference_rules.append(Rule(head=head, body=body))
elif isinstance(belief, InferredBelief):
self._ensure_belief_inference(belief.left, asl)
self._ensure_belief_inference(belief.right, asl)
slug = self._slugify(belief)
head = BeliefLiteral(slug)
if any(str(r.head) == str(head) for r in asl.inference_rules):
return
op_char = "&" if belief.operator == LogicalOperator.AND else "|"
body = BinaryOp(
self._belief_to_expr(belief.left), op_char, self._belief_to_expr(belief.right)
)
asl.inference_rules.append(PersistentRule(head=head, body=body))
def _belief_to_expr(self, belief: Belief) -> Expression:
if isinstance(belief, KeywordBelief):
return BeliefLiteral("keyword_said", [f'"{belief.keyword}"'])
else:
return BeliefLiteral(self._slugify(belief))
# --- Section: Goals ---
def _generate_goals(self, phase: Phase, asl: AgentSpeakFile):
previous_goal: Goal | None = None
for goal in phase.goals:
self._generate_goal_plan_recursive(goal, str(phase.id), previous_goal, asl)
previous_goal = goal
def _generate_goal_plan_recursive(
self,
goal: Goal,
phase_id: str,
previous_goal: Goal | None,
asl: AgentSpeakFile,
responded_needed: bool = True,
can_fail: bool = True,
):
goal_slug = self._slugify(goal)
# phase(ID) & not responded_this_turn & not achieved_goal
context = [
BeliefLiteral("phase", [f'"{phase_id}"']),
]
if responded_needed:
context.append(BeliefLiteral("responded_this_turn", negated=True))
if can_fail:
context.append(BeliefLiteral(f"achieved_{goal_slug}", negated=True))
if previous_goal:
prev_slug = self._slugify(previous_goal)
context.append(BeliefLiteral(f"achieved_{prev_slug}"))
body_actions = []
sub_goals_to_process = []
for step in goal.plan.steps:
if isinstance(step, Goal):
sub_slug = self._slugify(step)
body_actions.append(ActionLiteral(f"!{sub_slug}"))
sub_goals_to_process.append(step)
elif isinstance(step, SpeechAction):
body_actions.append(ActionLiteral(f'.say("{step.text}")'))
elif isinstance(step, GestureAction):
body_actions.append(ActionLiteral(f'.gesture("{step.gesture}")'))
elif isinstance(step, LLMAction):
body_actions.append(ActionLiteral(f'!generate_response_with_goal("{step.goal}")'))
# Mark achievement
if not goal.can_fail:
body_actions.append(ActionLiteral(f"+achieved_{goal_slug}"))
asl.plans.append(Plan(trigger=GoalLiteral(goal_slug), context=context, body=body_actions))
asl.plans.append(
Plan(trigger=GoalLiteral(goal_slug), context=[], body=[ActionLiteral("true")])
)
prev_sub = None
for sub_goal in sub_goals_to_process:
self._generate_goal_plan_recursive(sub_goal, phase_id, prev_sub, asl)
prev_sub = sub_goal
# --- Section: Triggers ---
def _generate_triggers(self, phase: Phase, asl: AgentSpeakFile):
for keyword in self._get_keyword_conditionals(phase):
asl.plans.append(
Plan(
trigger=GoalLiteral(f"check_{self._slugify_str(keyword)}"),
context=[
ActionLiteral(
f'user_said(Message) & .substring("{keyword}", Message, Pos) & Pos >= 0'
)
],
body=[
ActionLiteral(f'+keyword_said("{keyword}")'),
ActionLiteral(f'-keyword_said("{keyword}")'),
],
)
)
asl.plans.append(
Plan(
trigger=GoalLiteral(f"check_{self._slugify_str(keyword)}"),
body=[ActionLiteral("true")],
)
)
for trigger in phase.triggers:
self._ensure_belief_inference(trigger.condition, asl)
trigger_belief_slug = self._belief_to_expr(trigger.condition)
body_actions = []
sub_goals = []
for step in trigger.plan.steps:
if isinstance(step, Goal):
sub_slug = self._slugify(step)
body_actions.append(ActionLiteral(f"!{sub_slug}"))
sub_goals.append(step)
elif isinstance(step, SpeechAction):
body_actions.append(ActionLiteral(f'.say("{step.text}")'))
elif isinstance(step, GestureAction):
body_actions.append(
ActionLiteral(f'.gesture("{step.gesture.type}", "{step.gesture.name}")')
)
elif isinstance(step, LLMAction):
body_actions.append(
ActionLiteral(f'!generate_response_with_goal("{step.goal}")')
)
asl.plans.append(
Plan(
trigger=BeliefLiteral(trigger_belief_slug),
context=[BeliefLiteral("phase", [f'"{phase.id}"'])],
body=body_actions,
)
)
# Recurse for triggered goals
prev_sub = None
for sub_goal in sub_goals:
self._generate_goal_plan_recursive(
sub_goal, str(phase.id), prev_sub, asl, False, False
)
prev_sub = sub_goal
# --- Section: Fallbacks ---
def _generate_fallbacks(self, program: Program, asl: AgentSpeakFile):
asl.plans.append(
Plan(trigger=GoalLiteral("transition_phase"), context=[], body=[ActionLiteral("true")])
)
# --- Helpers ---
@singledispatchmethod
def _slugify(self, element: ProgramElement) -> str:
if element.name:
raise NotImplementedError("Cannot slugify this element.")
return self._slugify_str(element.name)
@_slugify.register
def _(self, goal: Goal) -> str:
if goal.name:
return self._slugify_str(goal.name)
return f"goal_{goal.id.hex}"
@_slugify.register
def _(self, kwb: KeywordBelief) -> str:
return f"keyword_said({kwb.keyword})"
@_slugify.register
def _(self, sb: SemanticBelief) -> str:
return self._slugify_str(sb.description)
@_slugify.register
def _(self, ib: InferredBelief) -> str:
return self._slugify_str(ib.name)
def _slugify_str(self, text: str) -> str:
return slugify(text, separator="_", stopwords=["a", "an", "the", "we", "you", "I"])
def _extract_basic_beliefs_from_program(self, program: Program) -> list[BasicBelief]:
beliefs = []
for phase in program.phases:
beliefs.extend(self._extract_basic_beliefs_from_phase(phase))
return beliefs
def _extract_basic_beliefs_from_phase(self, phase: Phase) -> list[BasicBelief]:
beliefs = []
for norm in phase.norms:
if isinstance(norm, ConditionalNorm):
beliefs += self._extract_basic_beliefs_from_belief(norm.condition)
for trigger in phase.triggers:
beliefs += self._extract_basic_beliefs_from_belief(trigger.condition)
return beliefs
def _extract_basic_beliefs_from_belief(self, belief: Belief) -> list[BasicBelief]:
if isinstance(belief, InferredBelief):
return self._extract_basic_beliefs_from_belief(
belief.left
) + self._extract_basic_beliefs_from_belief(belief.right)
return [belief]
if __name__ == "__main__":
asyncio.run(do_things())
# do_other_things()y

View File

@@ -0,0 +1,272 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from enum import StrEnum
class AstNode(ABC):
"""
Abstract base class for all elements of an AgentSpeak program.
"""
@abstractmethod
def _to_agentspeak(self) -> str:
"""
Generates the AgentSpeak code string.
"""
pass
def __str__(self) -> str:
return self._to_agentspeak()
class AstExpression(AstNode, ABC):
"""
Intermediate class for anything that can be used in a logical expression.
"""
def __and__(self, other: ExprCoalescible) -> AstBinaryOp:
return AstBinaryOp(self, BinaryOperatorType.AND, _coalesce_expr(other))
def __or__(self, other: ExprCoalescible) -> AstBinaryOp:
return AstBinaryOp(self, BinaryOperatorType.OR, _coalesce_expr(other))
def __invert__(self) -> AstLogicalExpression:
if isinstance(self, AstLogicalExpression):
self.negated = not self.negated
return self
return AstLogicalExpression(self, negated=True)
type ExprCoalescible = AstExpression | str | int | float
def _coalesce_expr(value: ExprCoalescible) -> AstExpression:
if isinstance(value, AstExpression):
return value
if isinstance(value, str):
return AstString(value)
if isinstance(value, (int, float)):
return AstNumber(value)
raise TypeError(f"Cannot coalesce type {type(value)} into an AstTerm.")
@dataclass
class AstTerm(AstExpression, ABC):
"""
Base class for terms appearing inside literals.
"""
def __ge__(self, other: ExprCoalescible) -> AstBinaryOp:
return AstBinaryOp(self, BinaryOperatorType.GREATER_EQUALS, _coalesce_expr(other))
def __gt__(self, other: ExprCoalescible) -> AstBinaryOp:
return AstBinaryOp(self, BinaryOperatorType.GREATER_THAN, _coalesce_expr(other))
def __le__(self, other: ExprCoalescible) -> AstBinaryOp:
return AstBinaryOp(self, BinaryOperatorType.LESS_EQUALS, _coalesce_expr(other))
def __lt__(self, other: ExprCoalescible) -> AstBinaryOp:
return AstBinaryOp(self, BinaryOperatorType.LESS_THAN, _coalesce_expr(other))
def __eq__(self, other: ExprCoalescible) -> AstBinaryOp:
return AstBinaryOp(self, BinaryOperatorType.EQUALS, _coalesce_expr(other))
def __ne__(self, other: ExprCoalescible) -> AstBinaryOp:
return AstBinaryOp(self, BinaryOperatorType.NOT_EQUALS, _coalesce_expr(other))
@dataclass
class AstAtom(AstTerm):
"""
Grounded expression in all lowercase.
"""
value: str
def _to_agentspeak(self) -> str:
return self.value.lower()
@dataclass
class AstVar(AstTerm):
"""
Ungrounded variable expression. First letter capitalized.
"""
name: str
def _to_agentspeak(self) -> str:
return self.name.capitalize()
@dataclass
class AstNumber(AstTerm):
value: int | float
def _to_agentspeak(self) -> str:
return str(self.value)
@dataclass
class AstString(AstTerm):
value: str
def _to_agentspeak(self) -> str:
return f'"{self.value}"'
@dataclass
class AstLiteral(AstTerm):
functor: str
terms: list[AstTerm] = field(default_factory=list)
def _to_agentspeak(self) -> str:
if not self.terms:
return self.functor
args = ", ".join(map(str, self.terms))
return f"{self.functor}({args})"
class BinaryOperatorType(StrEnum):
AND = "&"
OR = "|"
GREATER_THAN = ">"
LESS_THAN = "<"
EQUALS = "=="
NOT_EQUALS = "\\=="
GREATER_EQUALS = ">="
LESS_EQUALS = "<="
@dataclass
class AstBinaryOp(AstExpression):
left: AstExpression
operator: BinaryOperatorType
right: AstExpression
def __post_init__(self):
self.left = _as_logical(self.left)
self.right = _as_logical(self.right)
def _to_agentspeak(self) -> str:
l_str = str(self.left)
r_str = str(self.right)
assert isinstance(self.left, AstLogicalExpression)
assert isinstance(self.right, AstLogicalExpression)
if isinstance(self.left.expression, AstBinaryOp) or self.left.negated:
l_str = f"({l_str})"
if isinstance(self.right.expression, AstBinaryOp) or self.right.negated:
r_str = f"({r_str})"
return f"{l_str} {self.operator.value} {r_str}"
@dataclass
class AstLogicalExpression(AstExpression):
expression: AstExpression
negated: bool = False
def _to_agentspeak(self) -> str:
expr_str = str(self.expression)
if isinstance(self.expression, AstBinaryOp) and self.negated:
expr_str = f"({expr_str})"
return f"{'not ' if self.negated else ''}{expr_str}"
def _as_logical(expr: AstExpression) -> AstLogicalExpression:
if isinstance(expr, AstLogicalExpression):
return expr
return AstLogicalExpression(expr)
class StatementType(StrEnum):
EMPTY = ""
DO_ACTION = "."
ACHIEVE_GOAL = "!"
# TEST_GOAL = "?" # TODO
ADD_BELIEF = "+"
REMOVE_BELIEF = "-"
@dataclass
class AstStatement(AstNode):
"""
A statement that can appear inside a plan.
"""
type: StatementType
expression: AstExpression
def _to_agentspeak(self) -> str:
return f"{self.type.value}{self.expression}"
@dataclass
class AstRule(AstNode):
result: AstExpression
condition: AstExpression | None = None
def __post_init__(self):
if self.condition is not None:
self.condition = _as_logical(self.condition)
def _to_agentspeak(self) -> str:
if not self.condition:
return f"{self.result}."
return f"{self.result} :- {self.condition}."
class TriggerType(StrEnum):
ADDED_BELIEF = "+"
# REMOVED_BELIEF = "-" # TODO
# MODIFIED_BELIEF = "^" # TODO
ADDED_GOAL = "+!"
# REMOVED_GOAL = "-!" # TODO
@dataclass
class AstPlan(AstNode):
type: TriggerType
trigger_literal: AstExpression
context: list[AstExpression]
body: list[AstStatement]
def _to_agentspeak(self) -> str:
assert isinstance(self.trigger_literal, AstLiteral)
indent = " " * 6
colon = " : "
arrow = " <- "
lines = []
lines.append(f"{self.type.value}{self.trigger_literal}")
if self.context:
lines.append(colon + f" &\n{indent}".join(str(c) for c in self.context))
if self.body:
lines.append(arrow + f";\n{indent}".join(str(s) for s in self.body) + ".")
lines.append("")
return "\n".join(lines)
@dataclass
class AstProgram(AstNode):
rules: list[AstRule] = field(default_factory=list)
plans: list[AstPlan] = field(default_factory=list)
def _to_agentspeak(self) -> str:
lines = []
lines.extend(map(str, self.rules))
lines.extend(["", ""])
lines.extend(map(str, self.plans))
return "\n".join(lines)

View File

@@ -11,7 +11,7 @@ from pydantic import ValidationError
from control_backend.agents.base import BaseAgent
from control_backend.core.agent_system import InternalMessage
from control_backend.core.config import settings
from control_backend.schemas.belief_message import Belief, BeliefMessage
from control_backend.schemas.belief_message import BeliefMessage
from control_backend.schemas.llm_prompt_message import LLMPromptMessage
from control_backend.schemas.ri_message import SpeechCommand
@@ -124,8 +124,8 @@ class BDICoreAgent(BaseAgent):
if msg.thread == "beliefs":
try:
beliefs = BeliefMessage.model_validate_json(msg.body).beliefs
self._apply_beliefs(beliefs)
belief_changes = BeliefMessage.model_validate_json(msg.body)
self._apply_belief_changes(belief_changes)
except ValidationError:
self.logger.exception("Error processing belief.")
return
@@ -145,22 +145,29 @@ class BDICoreAgent(BaseAgent):
)
await self.send(out_msg)
def _apply_beliefs(self, beliefs: list[Belief]):
def _apply_belief_changes(self, belief_changes: BeliefMessage):
"""
Update the belief base with a list of new beliefs.
If ``replace=True`` is set on a belief, it removes all existing beliefs with that name
before adding the new one.
For beliefs in ``belief_changes.replace``, it removes all existing beliefs with that name
before adding one new one.
:param belief_changes: The changes in beliefs to apply.
"""
if not beliefs:
if not belief_changes.create and not belief_changes.replace and not belief_changes.delete:
return
for belief in beliefs:
if belief.replace:
self._remove_all_with_name(belief.name)
for belief in belief_changes.create:
self._add_belief(belief.name, belief.arguments)
def _add_belief(self, name: str, args: Iterable[str] = []):
for belief in belief_changes.replace:
self._remove_all_with_name(belief.name)
self._add_belief(belief.name, belief.arguments)
for belief in belief_changes.delete:
self._remove_belief(belief.name, belief.arguments)
def _add_belief(self, name: str, args: list[str] = None):
"""
Add a single belief to the BDI agent.
@@ -168,9 +175,13 @@ class BDICoreAgent(BaseAgent):
:param args: Arguments for the belief.
"""
# new_args = (agentspeak.Literal(arg) for arg in args) # TODO: Eventually support multiple
merged_args = DELIMITER.join(arg for arg in args)
new_args = (agentspeak.Literal(merged_args),)
term = agentspeak.Literal(name, new_args)
args = args or []
if args:
merged_args = DELIMITER.join(arg for arg in args)
new_args = (agentspeak.Literal(merged_args),)
term = agentspeak.Literal(name, new_args)
else:
term = agentspeak.Literal(name)
self.bdi_agent.call(
agentspeak.Trigger.addition,
@@ -238,8 +249,7 @@ class BDICoreAgent(BaseAgent):
@self.actions.add(".reply", 3)
def _reply(agent: "BDICoreAgent", term, intention):
"""
Sends text to the LLM (AgentSpeak action).
Example: .reply("Hello LLM!", "Some norm", "Some goal")
Let the LLM generate a response to a user's utterance with the current norms and goals.
"""
message_text = agentspeak.grounded(term.args[0], intention.scope)
norms = agentspeak.grounded(term.args[1], intention.scope)
@@ -252,15 +262,71 @@ class BDICoreAgent(BaseAgent):
asyncio.create_task(self._send_to_llm(str(message_text), str(norms), str(goals)))
yield
async def _send_to_llm(self, text: str, norms: str = None, goals: str = None):
@self.actions.add(".reply_with_goal", 3)
def _reply_with_goal(agent: "BDICoreAgent", term, intention):
"""
Let the LLM generate a response to a user's utterance with the current norms and a
specific goal.
"""
message_text = agentspeak.grounded(term.args[0], intention.scope)
norms = agentspeak.grounded(term.args[1], intention.scope)
goal = agentspeak.grounded(term.args[2], intention.scope)
self.logger.debug(
'"reply_with_goal" action called with message=%s, norms=%s, goal=%s',
message_text,
norms,
goal,
)
# asyncio.create_task(self._send_to_llm(str(message_text), norms, str(goal)))
yield
@self.actions.add(".say", 1)
def _say(agent: "BDICoreAgent", term, intention):
"""
Make the robot say the given text instantly.
"""
message_text = agentspeak.grounded(term.args[0], intention.scope)
self.logger.debug('"say" action called with text=%s', message_text)
# speech_command = SpeechCommand(data=message_text)
# speech_message = InternalMessage(
# to=settings.agent_settings.robot_speech_name,
# sender=settings.agent_settings.bdi_core_name,
# body=speech_command.model_dump_json(),
# )
# asyncio.create_task(agent.send(speech_message))
yield
@self.actions.add(".gesture", 2)
def _gesture(agent: "BDICoreAgent", term, intention):
"""
Make the robot perform the given gesture instantly.
"""
gesture_type = agentspeak.grounded(term.args[0], intention.scope)
gesture_name = agentspeak.grounded(term.args[1], intention.scope)
self.logger.debug(
'"gesture" action called with type=%s, name=%s',
gesture_type,
gesture_name,
)
# gesture = Gesture(type=gesture_type, name=gesture_name)
# gesture_message = InternalMessage(
# to=settings.agent_settings.robot_gesture_name,
# sender=settings.agent_settings.bdi_core_name,
# body=gesture.model_dump_json(),
# )
# asyncio.create_task(agent.send(gesture_message))
yield
async def _send_to_llm(self, text: str, norms: str, goals: str):
"""
Sends a text query to the LLM agent asynchronously.
"""
prompt = LLMPromptMessage(
text=text,
norms=norms.split("\n") if norms else [],
goals=goals.split("\n") if norms else [],
)
prompt = LLMPromptMessage(text=text, norms=norms.split("\n"), goals=goals.split("\n"))
msg = InternalMessage(
to=settings.agent_settings.llm_name,
sender=self.name,

View File

@@ -1,12 +1,598 @@
import uuid
from collections.abc import Iterable
import zmq
from pydantic import ValidationError
from slugify import slugify
from zmq.asyncio import Context
from control_backend.agents import BaseAgent
from control_backend.core.agent_system import InternalMessage
from control_backend.core.config import settings
from control_backend.schemas.belief_message import Belief, BeliefMessage
from control_backend.schemas.program import Program
from control_backend.schemas.program import (
Action,
BasicBelief,
BasicNorm,
Belief,
ConditionalNorm,
GestureAction,
Goal,
InferredBelief,
KeywordBelief,
LLMAction,
LogicalOperator,
Phase,
Plan,
Program,
ProgramElement,
SemanticBelief,
SpeechAction,
Trigger,
)
test_program = Program(
phases=[
Phase(
norms=[
BasicNorm(norm="Talk like a pirate", id=uuid.uuid4()),
ConditionalNorm(
condition=InferredBelief(
left=KeywordBelief(keyword="Arr", id=uuid.uuid4()),
right=SemanticBelief(
description="testing", name="semantic belief", id=uuid.uuid4()
),
operator=LogicalOperator.OR,
name="Talking to a pirate",
id=uuid.uuid4(),
),
norm="Use nautical terms",
id=uuid.uuid4(),
),
ConditionalNorm(
condition=SemanticBelief(
description="We are talking to a child",
name="talking to child",
id=uuid.uuid4(),
),
norm="Do not use cuss words",
id=uuid.uuid4(),
),
],
triggers=[
Trigger(
condition=InferredBelief(
left=KeywordBelief(keyword="key", id=uuid.uuid4()),
right=InferredBelief(
left=KeywordBelief(keyword="key2", id=uuid.uuid4()),
right=SemanticBelief(
description="Decode this", name="semantic belief 2", id=uuid.uuid4()
),
operator=LogicalOperator.OR,
name="test trigger inferred inner",
id=uuid.uuid4(),
),
operator=LogicalOperator.OR,
name="test trigger inferred outer",
id=uuid.uuid4(),
),
plan=Plan(
steps=[
SpeechAction(text="Testing trigger", id=uuid.uuid4()),
Goal(
name="Testing trigger",
plan=Plan(
steps=[LLMAction(goal="Do something", id=uuid.uuid4())],
id=uuid.uuid4(),
),
id=uuid.uuid4(),
),
],
id=uuid.uuid4(),
),
id=uuid.uuid4(),
)
],
goals=[
Goal(
name="Determine user age",
plan=Plan(
steps=[LLMAction(goal="Determine the age of the user.", id=uuid.uuid4())],
id=uuid.uuid4(),
),
id=uuid.uuid4(),
),
Goal(
name="Find the user's name",
plan=Plan(
steps=[
Goal(
name="Greet the user",
plan=Plan(
steps=[LLMAction(goal="Greet the user.", id=uuid.uuid4())],
id=uuid.uuid4(),
),
can_fail=False,
id=uuid.uuid4(),
),
Goal(
name="Ask for name",
plan=Plan(
steps=[
LLMAction(goal="Obtain the user's name.", id=uuid.uuid4())
],
id=uuid.uuid4(),
),
id=uuid.uuid4(),
),
],
id=uuid.uuid4(),
),
id=uuid.uuid4(),
),
Goal(
name="Tell a joke",
plan=Plan(
steps=[LLMAction(goal="Tell a joke.", id=uuid.uuid4())], id=uuid.uuid4()
),
id=uuid.uuid4(),
),
],
id=uuid.uuid4(),
),
Phase(
id=uuid.uuid4(),
norms=[
BasicNorm(norm="Use very gentle speech.", id=uuid.uuid4()),
ConditionalNorm(
condition=SemanticBelief(
description="We are talking to a child",
name="talking to child",
id=uuid.uuid4(),
),
norm="Do not use cuss words",
id=uuid.uuid4(),
),
],
triggers=[
Trigger(
condition=InferredBelief(
left=KeywordBelief(keyword="help", id=uuid.uuid4()),
right=SemanticBelief(
description="User is stuck", name="stuck", id=uuid.uuid4()
),
operator=LogicalOperator.OR,
name="help_or_stuck",
id=uuid.uuid4(),
),
plan=Plan(
steps=[
Goal(
name="Unblock user",
plan=Plan(
steps=[
LLMAction(
goal="Provide a step-by-step path to "
"resolve the user's issue.",
id=uuid.uuid4(),
)
],
id=uuid.uuid4(),
),
id=uuid.uuid4(),
),
],
id=uuid.uuid4(),
),
id=uuid.uuid4(),
),
],
goals=[
Goal(
name="Clarify intent",
plan=Plan(
steps=[
LLMAction(
goal="Ask 1-2 targeted questions to clarify the "
"user's intent, then proceed.",
id=uuid.uuid4(),
)
],
id=uuid.uuid4(),
),
id=uuid.uuid4(),
),
Goal(
name="Provide solution",
plan=Plan(
steps=[
LLMAction(
goal="Deliver a solution to complete the user's goal.",
id=uuid.uuid4(),
)
],
id=uuid.uuid4(),
),
id=uuid.uuid4(),
),
Goal(
name="Summarize next steps",
plan=Plan(
steps=[
LLMAction(
goal="Summarize what the user should do next.", id=uuid.uuid4()
)
],
id=uuid.uuid4(),
),
id=uuid.uuid4(),
),
],
),
]
)
def do_things():
print(AgentSpeakGenerator().generate(test_program))
class AgentSpeakGenerator:
"""
Converts Pydantic representation of behavior programs into AgentSpeak(L) code string.
"""
arrow_prefix = f"{' ' * 2}<-{' ' * 2}"
colon_prefix = f"{' ' * 2}:{' ' * 3}"
indent_prefix = " " * 6
def generate(self, program: Program) -> str:
lines = []
lines.append("")
lines += self._generate_initial_beliefs(program)
lines += self._generate_basic_flow(program)
lines += self._generate_phase_transitions(program)
lines += self._generate_norms(program)
lines += self._generate_belief_inference(program)
lines += self._generate_goals(program)
lines += self._generate_triggers(program)
return "\n".join(lines)
def _generate_initial_beliefs(self, program: Program) -> Iterable[str]:
yield "// --- Initial beliefs and agent startup ---"
yield "phase(start)."
yield ""
yield "+started"
yield f"{self.colon_prefix}phase(start)"
yield f"{self.arrow_prefix}phase({program.phases[0].id if program.phases else 'end'})."
yield from ["", ""]
def _generate_basic_flow(self, program: Program) -> Iterable[str]:
yield "// --- Basic flow ---"
for phase in program.phases:
yield from self._generate_basic_flow_per_phase(phase)
yield from ["", ""]
def _generate_basic_flow_per_phase(self, phase: Phase) -> Iterable[str]:
yield "+user_said(Message)"
yield f"{self.colon_prefix}phase({phase.id})"
goals = phase.goals
if goals:
yield f"{self.arrow_prefix}{self._slugify(goals[0], include_prefix=True)}"
for goal in goals[1:]:
yield f"{self.indent_prefix}{self._slugify(goal, include_prefix=True)}"
yield f"{self.indent_prefix if goals else self.arrow_prefix}!transition_phase."
def _generate_phase_transitions(self, program: Program) -> Iterable[str]:
yield "// --- Phase transitions ---"
if len(program.phases) == 0:
yield from ["", ""]
return
# TODO: remove outdated things
for i in range(-1, len(program.phases)):
predecessor = program.phases[i] if i >= 0 else None
successor = program.phases[i + 1] if i < len(program.phases) - 1 else None
yield from self._generate_phase_transition(predecessor, successor)
yield from self._generate_phase_transition(None, None) # to avoid failing plan
yield from ["", ""]
def _generate_phase_transition(
self, phase: Phase | None = None, next_phase: Phase | None = None
) -> Iterable[str]:
yield "+!transition_phase"
if phase is None and next_phase is None: # base case true to avoid failing plan
yield f"{self.arrow_prefix}true."
return
yield f"{self.colon_prefix}phase({phase.id if phase else 'start'})"
yield f"{self.arrow_prefix}-+phase({next_phase.id if next_phase else 'end'})."
def _generate_norms(self, program: Program) -> Iterable[str]:
yield "// --- Norms ---"
for phase in program.phases:
for norm in phase.norms:
if type(norm) is BasicNorm:
yield f"{self._slugify(norm)} :- phase({phase.id})."
if type(norm) is ConditionalNorm:
yield (
f"{self._slugify(norm)} :- phase({phase.id}) & "
f"{self._slugify(norm.condition)}."
)
yield from ["", ""]
def _generate_belief_inference(self, program: Program) -> Iterable[str]:
yield "// --- Belief inference rules ---"
for phase in program.phases:
for norm in phase.norms:
if not isinstance(norm, ConditionalNorm):
continue
yield from self._belief_inference_recursive(norm.condition)
for trigger in phase.triggers:
yield from self._belief_inference_recursive(trigger.condition)
yield from ["", ""]
def _belief_inference_recursive(self, belief: Belief) -> Iterable[str]:
if type(belief) is KeywordBelief:
yield (
f"{self._slugify(belief)} :- user_said(Message) & "
f'.substring(Message, "{belief.keyword}", Pos) & Pos >= 0.'
)
if type(belief) is InferredBelief:
yield (
f"{self._slugify(belief)} :- {self._slugify(belief.left)} "
f"{'&' if belief.operator == LogicalOperator.AND else '|'} "
f"{self._slugify(belief.right)}."
)
yield from self._belief_inference_recursive(belief.left)
yield from self._belief_inference_recursive(belief.right)
def _generate_goals(self, program: Program) -> Iterable[str]:
yield "// --- Goals ---"
for phase in program.phases:
previous_goal: Goal | None = None
for goal in phase.goals:
yield from self._generate_goal_plan_recursive(goal, phase, previous_goal)
previous_goal = goal
yield from ["", ""]
def _generate_goal_plan_recursive(
self, goal: Goal, phase: Phase, previous_goal: Goal | None = None
) -> Iterable[str]:
yield f"+{self._slugify(goal, include_prefix=True)}"
# Context
yield f"{self.colon_prefix}phase({phase.id}) &"
yield f"{self.indent_prefix}not responded_this_turn &"
yield f"{self.indent_prefix}not achieved_{self._slugify(goal)} &"
if previous_goal:
yield f"{self.indent_prefix}achieved_{self._slugify(previous_goal)}"
else:
yield f"{self.indent_prefix}true"
extra_goals_to_generate = []
steps = goal.plan.steps
if len(steps) == 0:
yield f"{self.arrow_prefix}true."
return
first_step = steps[0]
yield (
f"{self.arrow_prefix}{self._slugify(first_step, include_prefix=True)}"
f"{'.' if len(steps) == 1 and goal.can_fail else ';'}"
)
if isinstance(first_step, Goal):
extra_goals_to_generate.append(first_step)
for step in steps[1:-1]:
yield f"{self.indent_prefix}{self._slugify(step, include_prefix=True)};"
if isinstance(step, Goal):
extra_goals_to_generate.append(step)
if len(steps) > 1:
last_step = steps[-1]
yield (
f"{self.indent_prefix}{self._slugify(last_step, include_prefix=True)}"
f"{'.' if goal.can_fail else ';'}"
)
if isinstance(last_step, Goal):
extra_goals_to_generate.append(last_step)
if not goal.can_fail:
yield f"{self.indent_prefix}+achieved_{self._slugify(goal)}."
yield f"+{self._slugify(goal, include_prefix=True)}"
yield f"{self.arrow_prefix}true."
yield ""
extra_previous_goal: Goal | None = None
for extra_goal in extra_goals_to_generate:
yield from self._generate_goal_plan_recursive(extra_goal, phase, extra_previous_goal)
extra_previous_goal = extra_goal
def _generate_triggers(self, program: Program) -> Iterable[str]:
yield "// --- Triggers ---"
for phase in program.phases:
for trigger in phase.triggers:
yield from self._generate_trigger_plan(trigger, phase)
yield from ["", ""]
def _generate_trigger_plan(self, trigger: Trigger, phase: Phase) -> Iterable[str]:
belief_name = self._slugify(trigger.condition)
yield f"+{belief_name}"
yield f"{self.colon_prefix}phase({phase.id})"
extra_goals_to_generate = []
steps = trigger.plan.steps
if len(steps) == 0:
yield f"{self.arrow_prefix}true."
return
first_step = steps[0]
yield (
f"{self.arrow_prefix}{self._slugify(first_step, include_prefix=True)}"
f"{'.' if len(steps) == 1 else ';'}"
)
if isinstance(first_step, Goal):
extra_goals_to_generate.append(first_step)
for step in steps[1:-1]:
yield f"{self.indent_prefix}{self._slugify(step, include_prefix=True)};"
if isinstance(step, Goal):
extra_goals_to_generate.append(step)
if len(steps) > 1:
last_step = steps[-1]
yield f"{self.indent_prefix}{self._slugify(last_step, include_prefix=True)}."
if isinstance(last_step, Goal):
extra_goals_to_generate.append(last_step)
yield ""
extra_previous_goal: Goal | None = None
for extra_goal in extra_goals_to_generate:
yield from self._generate_trigger_plan_recursive(extra_goal, phase, extra_previous_goal)
extra_previous_goal = extra_goal
def _generate_trigger_plan_recursive(
self, goal: Goal, phase: Phase, previous_goal: Goal | None = None
) -> Iterable[str]:
yield f"+{self._slugify(goal, include_prefix=True)}"
extra_goals_to_generate = []
steps = goal.plan.steps
if len(steps) == 0:
yield f"{self.arrow_prefix}true."
return
first_step = steps[0]
yield (
f"{self.arrow_prefix}{self._slugify(first_step, include_prefix=True)}"
f"{'.' if len(steps) == 1 and goal.can_fail else ';'}"
)
if isinstance(first_step, Goal):
extra_goals_to_generate.append(first_step)
for step in steps[1:-1]:
yield f"{self.indent_prefix}{self._slugify(step, include_prefix=True)};"
if isinstance(step, Goal):
extra_goals_to_generate.append(step)
if len(steps) > 1:
last_step = steps[-1]
yield (
f"{self.indent_prefix}{self._slugify(last_step, include_prefix=True)}"
f"{'.' if goal.can_fail else ';'}"
)
if isinstance(last_step, Goal):
extra_goals_to_generate.append(last_step)
if not goal.can_fail:
yield f"{self.indent_prefix}+achieved_{self._slugify(goal)}."
yield f"+{self._slugify(goal, include_prefix=True)}"
yield f"{self.arrow_prefix}true."
yield ""
extra_previous_goal: Goal | None = None
for extra_goal in extra_goals_to_generate:
yield from self._generate_goal_plan_recursive(extra_goal, phase, extra_previous_goal)
extra_previous_goal = extra_goal
def _slugify(self, element: ProgramElement, include_prefix: bool = False) -> str:
def base_slugify_call(text: str):
return slugify(text, separator="_", stopwords=["a", "the"])
if type(element) is KeywordBelief:
return f'keyword_said("{element.keyword}")'
if type(element) is SemanticBelief:
name = element.name
return f"semantic_{base_slugify_call(name if name else element.description)}"
if isinstance(element, BasicNorm):
return f'norm("{element.norm}")'
if isinstance(element, Goal):
return f"{'!' if include_prefix else ''}{base_slugify_call(element.name)}"
if isinstance(element, SpeechAction):
return f'.say("{element.text}")'
if isinstance(element, GestureAction):
return f'.gesture("{element.gesture}")'
if isinstance(element, LLMAction):
return f'!generate_response_with_goal("{element.goal}")'
if isinstance(element, Action.__value__):
raise NotImplementedError(
"Have not implemented an ASL string representation for this action."
)
if element.name == "":
raise ValueError("Name must be initialized for this type of ProgramElement.")
return base_slugify_call(element.name)
def _extract_basic_beliefs_from_program(self, program: Program) -> list[BasicBelief]:
beliefs = []
for phase in program.phases:
for norm in phase.norms:
if isinstance(norm, ConditionalNorm):
beliefs += self._extract_basic_beliefs_from_belief(norm.condition)
for trigger in phase.triggers:
beliefs += self._extract_basic_beliefs_from_belief(trigger.condition)
return beliefs
def _extract_basic_beliefs_from_belief(self, belief: Belief) -> list[BasicBelief]:
if isinstance(belief, InferredBelief):
return self._extract_basic_beliefs_from_belief(
belief.left
) + self._extract_basic_beliefs_from_belief(belief.right)
return [belief]
class BDIProgramManager(BaseAgent):
@@ -25,40 +611,40 @@ class BDIProgramManager(BaseAgent):
super().__init__(**kwargs)
self.sub_socket = None
async def _send_to_bdi(self, program: Program):
"""
Convert a received program into BDI beliefs and send them to the BDI Core Agent.
Currently, it takes the **first phase** of the program and extracts:
- **Norms**: Constraints or rules the agent must follow.
- **Goals**: Objectives the agent must achieve.
These are sent as a ``BeliefMessage`` with ``replace=True``, meaning they will
overwrite any existing norms/goals of the same name in the BDI agent.
:param program: The program object received from the API.
"""
first_phase = program.phases[0]
norms_belief = Belief(
name="norms",
arguments=[norm.norm for norm in first_phase.norms],
replace=True,
)
goals_belief = Belief(
name="goals",
arguments=[goal.description for goal in first_phase.goals],
replace=True,
)
program_beliefs = BeliefMessage(beliefs=[norms_belief, goals_belief])
message = InternalMessage(
to=settings.agent_settings.bdi_core_name,
sender=self.name,
body=program_beliefs.model_dump_json(),
thread="beliefs",
)
await self.send(message)
self.logger.debug("Sent new norms and goals to the BDI agent.")
# async def _send_to_bdi(self, program: Program):
# """
# Convert a received program into BDI beliefs and send them to the BDI Core Agent.
#
# Currently, it takes the **first phase** of the program and extracts:
# - **Norms**: Constraints or rules the agent must follow.
# - **Goals**: Objectives the agent must achieve.
#
# These are sent as a ``BeliefMessage`` with ``replace=True``, meaning they will
# overwrite any existing norms/goals of the same name in the BDI agent.
#
# :param program: The program object received from the API.
# """
# first_phase = program.phases[0]
# norms_belief = Belief(
# name="norms",
# arguments=[norm.norm for norm in first_phase.norms],
# replace=True,
# )
# goals_belief = Belief(
# name="goals",
# arguments=[goal.description for goal in first_phase.goals],
# replace=True,
# )
# program_beliefs = BeliefMessage(beliefs=[norms_belief, goals_belief])
#
# message = InternalMessage(
# to=settings.agent_settings.bdi_core_name,
# sender=self.name,
# body=program_beliefs.model_dump_json(),
# thread="beliefs",
# )
# await self.send(message)
# self.logger.debug("Sent new norms and goals to the BDI agent.")
async def _receive_programs(self):
"""
@@ -92,3 +678,7 @@ class BDIProgramManager(BaseAgent):
self.sub_socket.subscribe("program")
self.add_behavior(self._receive_programs())
if __name__ == "__main__":
do_things()

View File

@@ -144,7 +144,7 @@ class BDIBeliefCollectorAgent(BaseAgent):
msg = InternalMessage(
to=settings.agent_settings.bdi_core_name,
sender=self.name,
body=BeliefMessage(beliefs=beliefs).model_dump_json(),
body=BeliefMessage(create=beliefs).model_dump_json(),
thread="beliefs",
)

View File

View File

View File

@@ -1,8 +1,23 @@
import asyncio
import json
import httpx
from pydantic import ValidationError
from slugify import slugify
from control_backend.agents.base import BaseAgent
from control_backend.core.agent_system import InternalMessage
from control_backend.core.config import settings
from control_backend.schemas.belief_message import Belief as InternalBelief
from control_backend.schemas.belief_message import BeliefMessage
from control_backend.schemas.chat_history import ChatHistory, ChatMessage
from control_backend.schemas.program import (
Belief,
ConditionalNorm,
InferredBelief,
Program,
SemanticBelief,
)
class TextBeliefExtractorAgent(BaseAgent):
@@ -12,46 +27,110 @@ class TextBeliefExtractorAgent(BaseAgent):
This agent is responsible for processing raw text (e.g., from speech transcription) and
extracting semantic beliefs from it.
In the current demonstration version, it performs a simple wrapping of the user's input
into a ``user_said`` belief. In a full implementation, this agent would likely interact
with an LLM or NLU engine to extract intent, entities, and other structured information.
It uses the available beliefs received from the program manager to try to extract beliefs from a
user's message, sends and updated beliefs to the BDI core, and forms a ``user_said`` belief from
the message itself.
"""
def __init__(self, name: str):
super().__init__(name)
self.beliefs: dict[str, bool] = {}
self.available_beliefs: list[SemanticBelief] = []
self.conversation = ChatHistory(messages=[])
async def setup(self):
"""
Initialize the agent and its resources.
"""
self.logger.info("Settting up %s.", self.name)
# Setup LLM belief context if needed (currently demo is just passthrough)
self.beliefs = {"mood": ["X"], "car": ["Y"]}
self.logger.info("Setting up %s.", self.name)
async def handle_message(self, msg: InternalMessage):
"""
Handle incoming messages, primarily from the Transcription Agent.
Handle incoming messages. Expect messages from the Transcriber agent, LLM agent, and the
Program manager agent.
:param msg: The received message containing transcribed text.
:param msg: The received message.
"""
sender = msg.sender
if sender == settings.agent_settings.transcription_name:
self.logger.debug("Received text from transcriber: %s", msg.body)
await self._process_transcription_demo(msg.body)
else:
self.logger.info("Discarding message from %s", sender)
async def _process_transcription_demo(self, txt: str):
match sender:
case settings.agent_settings.transcription_name:
self.logger.debug("Received text from transcriber: %s", msg.body)
self._apply_conversation_message(ChatMessage(role="user", content=msg.body))
await self._infer_new_beliefs()
await self._user_said(msg.body)
case settings.agent_settings.llm_name:
self.logger.debug("Received text from LLM: %s", msg.body)
self._apply_conversation_message(ChatMessage(role="assistant", content=msg.body))
case settings.agent_settings.bdi_program_manager_name:
self._handle_program_manager_message(msg)
case _:
self.logger.info("Discarding message from %s", sender)
return
def _apply_conversation_message(self, message: ChatMessage):
"""
Process the transcribed text and generate beliefs.
Save the chat message to our conversation history, taking into account the conversation
length limit.
**Demo Implementation:**
Currently, this method takes the raw text ``txt`` and wraps it into a belief structure:
``user_said("txt")``.
This belief is then sent to the :class:`BDIBeliefCollectorAgent`.
:param txt: The raw transcribed text string.
:param message: The chat message to add to the conversation history.
"""
# For demo, just wrapping user text as user_said belief
belief = {"beliefs": {"user_said": [txt]}, "type": "belief_extraction_text"}
length_limit = settings.behaviour_settings.conversation_history_length_limit
self.conversation.messages = (self.conversation.messages + [message])[-length_limit:]
def _handle_program_manager_message(self, msg: InternalMessage):
"""
Handle a message from the program manager: extract available beliefs from it.
:param msg: The received message from the program manager.
"""
try:
program = Program.model_validate_json(msg.body)
except ValidationError:
self.logger.warning(
"Received message from program manager but it is not a valid program."
)
return
self.logger.debug("Received a program from the program manager.")
self.available_beliefs = self._extract_basic_beliefs_from_program(program)
# TODO Copied from an incomplete version of the program manager. Use that one instead.
@staticmethod
def _extract_basic_beliefs_from_program(program: Program) -> list[SemanticBelief]:
beliefs = []
for phase in program.phases:
for norm in phase.norms:
if isinstance(norm, ConditionalNorm):
beliefs += TextBeliefExtractorAgent._extract_basic_beliefs_from_belief(
norm.condition
)
for trigger in phase.triggers:
beliefs += TextBeliefExtractorAgent._extract_basic_beliefs_from_belief(
trigger.condition
)
return beliefs
# TODO Copied from an incomplete version of the program manager. Use that one instead.
@staticmethod
def _extract_basic_beliefs_from_belief(belief: Belief) -> list[SemanticBelief]:
if isinstance(belief, InferredBelief):
return TextBeliefExtractorAgent._extract_basic_beliefs_from_belief(
belief.left
) + TextBeliefExtractorAgent._extract_basic_beliefs_from_belief(belief.right)
return [belief]
async def _user_said(self, text: str):
"""
Create a belief for the user's full speech.
:param text: User's transcribed text.
"""
belief = {"beliefs": {"user_said": [text]}, "type": "belief_extraction_text"}
payload = json.dumps(belief)
belief_msg = InternalMessage(
@@ -60,6 +139,207 @@ class TextBeliefExtractorAgent(BaseAgent):
body=payload,
thread="beliefs",
)
await self.send(belief_msg)
self.logger.info("Sent %d beliefs to the belief collector.", len(belief["beliefs"]))
async def _infer_new_beliefs(self):
"""
Process conversation history to extract beliefs, semantically. Any changed beliefs are sent
to the BDI core.
"""
# Return instantly if there are no beliefs to infer
if not self.available_beliefs:
return
candidate_beliefs = await self._infer_turn()
belief_changes = BeliefMessage()
for belief_key, belief_value in candidate_beliefs.items():
if belief_value is None:
continue
old_belief_value = self.beliefs.get(belief_key)
if belief_value == old_belief_value:
continue
self.beliefs[belief_key] = belief_value
belief = InternalBelief(name=belief_key, arguments=None)
if belief_value:
belief_changes.create.append(belief)
else:
belief_changes.delete.append(belief)
# Return if there were no changes in beliefs
if not belief_changes.has_values():
return
beliefs_message = InternalMessage(
to=settings.agent_settings.bdi_core_name,
sender=self.name,
body=belief_changes.model_dump_json(),
thread="beliefs",
)
await self.send(beliefs_message)
@staticmethod
def _split_into_chunks[T](items: list[T], n: int) -> list[list[T]]:
k, m = divmod(len(items), n)
return [items[i * k + min(i, m) : (i + 1) * k + min(i + 1, m)] for i in range(n)]
async def _infer_turn(self) -> dict:
"""
Process the stored conversation history to extract semantic beliefs. Returns a list of
beliefs that have been set to ``True``, ``False`` or ``None``.
:return: A dict mapping belief names to a value ``True``, ``False`` or ``None``.
"""
n_parallel = max(1, min(settings.llm_settings.n_parallel - 1, len(self.available_beliefs)))
all_beliefs = await asyncio.gather(
*[
self._infer_beliefs(self.conversation, beliefs)
for beliefs in self._split_into_chunks(self.available_beliefs, n_parallel)
]
)
retval = {}
for beliefs in all_beliefs:
if beliefs is None:
continue
retval.update(beliefs)
return retval
@staticmethod
def _create_belief_schema(belief: SemanticBelief) -> tuple[str, dict]:
# TODO: use real belief names
return belief.name or slugify(belief.description), {
"type": ["boolean", "null"],
"description": belief.description,
}
@staticmethod
def _create_beliefs_schema(beliefs: list[SemanticBelief]) -> dict:
belief_schemas = [
TextBeliefExtractorAgent._create_belief_schema(belief) for belief in beliefs
]
return {
"type": "object",
"properties": dict(belief_schemas),
"required": [name for name, _ in belief_schemas],
}
@staticmethod
def _format_message(message: ChatMessage):
return f"{message.role.upper()}:\n{message.content}"
@staticmethod
def _format_conversation(conversation: ChatHistory):
return "\n\n".join(
[TextBeliefExtractorAgent._format_message(message) for message in conversation.messages]
)
@staticmethod
def _format_beliefs(beliefs: list[SemanticBelief]):
# TODO: use real belief names
return "\n".join(
[
f"- {belief.name or slugify(belief.description)}: {belief.description}"
for belief in beliefs
]
)
async def _infer_beliefs(
self,
conversation: ChatHistory,
beliefs: list[SemanticBelief],
) -> dict | None:
"""
Infer given beliefs based on the given conversation.
:param conversation: The conversation to infer beliefs from.
:param beliefs: The beliefs to infer.
:return: A dict containing belief names and a boolean whether they hold, or None if the
belief cannot be inferred based on the given conversation.
"""
example = {
"example_belief": True,
}
prompt = f"""{self._format_conversation(conversation)}
Given the above conversation, what beliefs can be inferred?
If there is no relevant information about a belief belief, give null.
In case messages conflict, prefer using the most recent messages for inference.
Choose from the following list of beliefs, formatted as (belief_name, description):
{self._format_beliefs(beliefs)}
Respond with a JSON similar to the following, but with the property names as given above:
{json.dumps(example, indent=2)}
"""
schema = self._create_beliefs_schema(beliefs)
return await self._retry_query_llm(prompt, schema)
async def _retry_query_llm(self, prompt: str, schema: dict, tries: int = 3) -> dict | None:
"""
Query the LLM with the given prompt and schema, return an instance of a dict conforming
to this schema. Try ``tries`` times, or return None.
:param prompt: Prompt to be queried.
:param schema: Schema to be queried.
:return: An instance of a dict conforming to this schema, or None if failed.
"""
try_count = 0
while try_count < tries:
try_count += 1
try:
return await self._query_llm(prompt, schema)
except (httpx.HTTPError, json.JSONDecodeError, KeyError) as e:
if try_count < tries:
continue
self.logger.exception(
"Failed to get LLM response after %d tries.",
try_count,
exc_info=e,
)
return None
@staticmethod
async def _query_llm(prompt: str, schema: dict) -> dict:
"""
Query an LLM with the given prompt and schema, return an instance of a dict conforming to
that schema.
:param prompt: The prompt to be queried.
:param schema: Schema to use during response.
:return: A dict conforming to this schema.
:raises httpx.HTTPStatusError: If the LLM server responded with an error.
:raises json.JSONDecodeError: If the LLM response was not valid JSON. May happen if the
response was cut off early due to length limitations.
:raises KeyError: If the LLM server responded with no error, but the response was invalid.
"""
async with httpx.AsyncClient() as client:
response = await client.post(
settings.llm_settings.local_llm_url,
json={
"model": settings.llm_settings.local_llm_model,
"messages": [{"role": "user", "content": prompt}],
"response_format": {
"type": "json_schema",
"json_schema": {
"name": "Beliefs",
"strict": True,
"schema": schema,
},
},
"reasoning_effort": "low",
"temperature": settings.llm_settings.code_temperature,
"stream": False,
},
timeout=None,
)
response.raise_for_status()
response_json = response.json()
json_message = response_json["choices"][0]["message"]["content"]
return json.loads(json_message)

View File

@@ -3,13 +3,17 @@ import json
import zmq
import zmq.asyncio as azmq
from pydantic import ValidationError
from zmq.asyncio import Context
from control_backend.agents import BaseAgent
from control_backend.agents.actuation.robot_gesture_agent import RobotGestureAgent
from control_backend.core.agent_system import InternalMessage
from control_backend.core.config import settings
from control_backend.schemas.ri_message import PauseCommand
from ..actuation.robot_speech_agent import RobotSpeechAgent
from ..perception import VADAgent
class RICommunicationAgent(BaseAgent):
@@ -181,6 +185,7 @@ class RICommunicationAgent(BaseAgent):
self._req_socket.bind(addr)
case "actuation":
gesture_data = port_data.get("gestures", [])
single_gesture_data = port_data.get("single_gestures", [])
robot_speech_agent = RobotSpeechAgent(
settings.agent_settings.robot_speech_name,
address=addr,
@@ -191,10 +196,14 @@ class RICommunicationAgent(BaseAgent):
address=addr,
bind=bind,
gesture_data=gesture_data,
single_gesture_data=single_gesture_data,
)
await robot_speech_agent.start()
await asyncio.sleep(0.1) # Small delay
await robot_gesture_agent.start()
case "audio":
vad_agent = VADAgent(audio_in_address=addr, audio_in_bind=bind)
await vad_agent.start()
case _:
self.logger.warning("Unhandled negotiation id: %s", id)
@@ -292,3 +301,11 @@ class RICommunicationAgent(BaseAgent):
self.logger.debug("Restarting communication negotiation.")
if await self._negotiate_connection(max_retries=1):
self.connected = True
async def handle_message(self, msg : InternalMessage):
try:
pause_command = PauseCommand.model_validate_json(msg.body)
self._req_socket.send_json(pause_command.model_dump())
self.logger.debug(self._req_socket.recv_json())
except ValidationError:
self.logger.warning("Incorrect message format for PauseCommand.")

View File

@@ -64,11 +64,12 @@ class LLMAgent(BaseAgent):
:param message: The parsed prompt message containing text, norms, and goals.
"""
full_message = ""
async for chunk in self._query_llm(message.text, message.norms, message.goals):
await self._send_reply(chunk)
self.logger.debug(
"Finished processing BDI message. Response sent in chunks to BDI core."
)
full_message += chunk
self.logger.debug("Finished processing BDI message. Response sent in chunks to BDI core.")
await self._send_full_reply(full_message)
async def _send_reply(self, msg: str):
"""
@@ -83,6 +84,19 @@ class LLMAgent(BaseAgent):
)
await self.send(reply)
async def _send_full_reply(self, msg: str):
"""
Sends a response message (full) to agents that need it.
:param msg: The text content of the message.
"""
message = InternalMessage(
to=settings.agent_settings.text_belief_extractor_name,
sender=self.name,
body=msg,
)
await self.send(message)
async def _query_llm(
self, prompt: str, norms: list[str], goals: list[str]
) -> AsyncGenerator[str]:
@@ -125,7 +139,7 @@ class LLMAgent(BaseAgent):
full_message += token
current_chunk += token
self.logger.info(
self.logger.llm(
"Received token: %s",
full_message,
extra={"reference": message_id}, # Used in the UI to update old logs
@@ -172,7 +186,7 @@ class LLMAgent(BaseAgent):
json={
"model": settings.llm_settings.local_llm_model,
"messages": messages,
"temperature": 0.3,
"temperature": settings.llm_settings.chat_temperature,
"stream": True,
},
) as response:

View File

@@ -0,0 +1,68 @@
import asyncio
import json
import zmq
from zmq.asyncio import Context
from control_backend.agents.base import BaseAgent
from control_backend.core.agent_system import InternalMessage
from control_backend.core.config import settings
class TestPauseAgent(BaseAgent):
def __init__(self, name: str):
super().__init__(name)
async def setup(self):
context = Context.instance()
self.pub_socket = context.socket(zmq.PUB)
self.pub_socket.connect(settings.zmq_settings.internal_pub_address)
self.add_behavior(self._pause_command_loop())
self.logger.debug("TestPauseAgent setup complete.")
async def _pause_command_loop(self):
print("Starting Pause command test loop.")
while True:
pause_command = {
"endpoint": "pause",
"data": True,
}
message = InternalMessage(
to="ri_communication_agent",
sender=self.name,
body=json.dumps(pause_command),
)
await self.send(message)
# User interrupt message
data = {
"type": "pause",
"context": True,
}
await self.pub_socket.send_multipart([b"button_pressed", json.dumps(data).encode()])
self.logger.info("Pausing robot actions.")
await asyncio.sleep(15) # Simulate delay between messages
pause_command = {
"endpoint": "pause",
"data": False,
}
message = InternalMessage(
to="ri_communication_agent",
sender=self.name,
body=json.dumps(pause_command),
)
await self.send(message)
# User interrupt message
data = {
"type": "pause",
"context": False,
}
await self.pub_socket.send_multipart([b"button_pressed", json.dumps(data).encode()])
self.logger.info("Resuming robot actions.")
await asyncio.sleep(15) # Simulate delay between messages

View File

@@ -7,7 +7,9 @@ import zmq.asyncio as azmq
from control_backend.agents import BaseAgent
from control_backend.core.config import settings
from control_backend.schemas.internal_message import InternalMessage
from ...schemas.program_status import PROGRAM_STATUS, ProgramStatus
from .transcription_agent.transcription_agent import TranscriptionAgent
@@ -61,6 +63,7 @@ class VADAgent(BaseAgent):
:ivar audio_in_address: Address of the input audio stream.
:ivar audio_in_bind: Whether to bind or connect to the input address.
:ivar audio_out_socket: ZMQ PUB socket for sending speech fragments.
:ivar program_sub_socket: ZMQ SUB socket for receiving program status updates.
"""
def __init__(self, audio_in_address: str, audio_in_bind: bool):
@@ -79,9 +82,17 @@ class VADAgent(BaseAgent):
self.audio_out_socket: azmq.Socket | None = None
self.audio_in_poller: SocketPoller | None = None
self.program_sub_socket: azmq.Socket | None = None
self.audio_buffer = np.array([], dtype=np.float32)
self.i_since_speech = settings.behaviour_settings.vad_initial_since_speech
self._ready = asyncio.Event()
# Pause control
self._reset_needed = False
self._paused = asyncio.Event()
self._paused.set() # Not paused at start
self.model = None
async def setup(self):
@@ -90,9 +101,10 @@ class VADAgent(BaseAgent):
1. Connects audio input socket.
2. Binds audio output socket (random port).
3. Loads VAD model from Torch Hub.
4. Starts the streaming loop.
5. Instantiates and starts the :class:`TranscriptionAgent` with the output address.
3. Connects to program communication socket.
4. Loads VAD model from Torch Hub.
5. Starts the streaming loop.
6. Instantiates and starts the :class:`TranscriptionAgent` with the output address.
"""
self.logger.info("Setting up %s", self.name)
@@ -105,6 +117,11 @@ class VADAgent(BaseAgent):
return
audio_out_address = f"tcp://localhost:{audio_out_port}"
# Connect to internal communication socket
self.program_sub_socket = azmq.Context.instance().socket(zmq.SUB)
self.program_sub_socket.connect(settings.zmq_settings.internal_sub_address)
self.program_sub_socket.subscribe(PROGRAM_STATUS)
# Initialize VAD model
try:
self.model, _ = torch.hub.load(
@@ -117,10 +134,8 @@ class VADAgent(BaseAgent):
await self.stop()
return
# Warmup/reset
await self.reset_stream()
self.add_behavior(self._streaming_loop())
self.add_behavior(self._status_loop())
# Start agents dependent on the output audio fragments here
transcriber = TranscriptionAgent(audio_out_address)
@@ -165,7 +180,7 @@ class VADAgent(BaseAgent):
self.audio_out_socket = None
return None
async def reset_stream(self):
async def _reset_stream(self):
"""
Clears the ZeroMQ queue and sets ready state.
"""
@@ -176,6 +191,23 @@ class VADAgent(BaseAgent):
self.logger.info(f"Discarded {discarded} audio packets before starting.")
self._ready.set()
async def _status_loop(self):
"""Loop for checking program status. Only start listening if program is RUNNING."""
while self._running:
topic, body = await self.program_sub_socket.recv_multipart()
if topic != PROGRAM_STATUS:
continue
if body != ProgramStatus.RUNNING.value:
continue
# Program is now running, we can start our stream
await self._reset_stream()
# We don't care about further status updates
self.program_sub_socket.close()
break
async def _streaming_loop(self):
"""
Main loop for processing audio stream.
@@ -188,6 +220,16 @@ class VADAgent(BaseAgent):
"""
await self._ready.wait()
while self._running:
await self._paused.wait()
# After being unpaused, reset stream and buffers
if self._reset_needed:
self.logger.debug("Resuming: resetting stream and buffers.")
await self._reset_stream()
self.audio_buffer = np.array([], dtype=np.float32)
self.i_since_speech = settings.behaviour_settings.vad_initial_since_speech
self._reset_needed = False
assert self.audio_in_poller is not None
data = await self.audio_in_poller.poll()
if data is None:
@@ -229,3 +271,27 @@ class VADAgent(BaseAgent):
# At this point, we know that the speech has ended.
# Prepend the last chunk that had no speech, for a more fluent boundary
self.audio_buffer = chunk
async def handle_message(self, msg: InternalMessage):
"""
Handle incoming messages.
Expects messages to pause or resume the VAD processing from User Interrupt Agent.
:param msg: The received internal message.
"""
sender = msg.sender
if sender == settings.agent_settings.user_interrupt_name:
if msg.body == "PAUSE":
self.logger.info("Pausing VAD processing.")
self._paused.clear()
# If the robot needs to pick up speaking where it left off, do not set _reset_needed
self._reset_needed = True
elif msg.body == "RESUME":
self.logger.info("Resuming VAD processing.")
self._paused.set()
else:
self.logger.warning(f"Unknown command from User Interrupt Agent: {msg.body}")
else:
self.logger.debug(f"Ignoring message from unknown sender: {sender}")

View File

@@ -0,0 +1,189 @@
import json
import zmq
from zmq.asyncio import Context
from control_backend.agents import BaseAgent
from control_backend.core.agent_system import InternalMessage
from control_backend.core.config import settings
from control_backend.schemas.ri_message import (
GestureCommand,
PauseCommand,
RIEndpoint,
SpeechCommand,
)
class UserInterruptAgent(BaseAgent):
"""
User Interrupt Agent.
This agent receives button_pressed events from the external HTTP API
(via ZMQ) and uses the associated context to trigger one of the following actions:
- Send a prioritized message to the `RobotSpeechAgent`
- Send a prioritized gesture to the `RobotGestureAgent`
- Send a belief override to the `BDIProgramManager`in order to activate a
trigger/conditional norm or complete a goal.
Prioritized actions clear the current RI queue before inserting the new item,
ensuring they are executed immediately after Pepper's current action has been fulfilled.
:ivar sub_socket: The ZMQ SUB socket used to receive user intterupts.
"""
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.sub_socket = None
async def _receive_button_event(self):
"""
The behaviour of the UserInterruptAgent.
Continuous loop that receives button_pressed events from the button_pressed HTTP endpoint.
These events contain a type and a context.
These are the different types and contexts:
- type: "speech", context: string that the robot has to say.
- type: "gesture", context: single gesture name that the robot has to perform.
- type: "override", context: belief_id that overrides the goal/trigger/conditional norm.
"""
while True:
topic, body = await self.sub_socket.recv_multipart()
try:
event_data = json.loads(body)
event_type = event_data.get("type") # e.g., "speech", "gesture"
event_context = event_data.get("context") # e.g., "Hello, I am Pepper!"
except json.JSONDecodeError:
self.logger.error("Received invalid JSON payload on topic %s", topic)
continue
if event_type == "speech":
await self._send_to_speech_agent(event_context)
self.logger.info(
"Forwarded button press (speech) with context '%s' to RobotSpeechAgent.",
event_context,
)
elif event_type == "gesture":
await self._send_to_gesture_agent(event_context)
self.logger.info(
"Forwarded button press (gesture) with context '%s' to RobotGestureAgent.",
event_context,
)
elif event_type == "override":
await self._send_to_program_manager(event_context)
self.logger.info(
"Forwarded button press (override) with context '%s' to BDIProgramManager.",
event_context,
)
elif event_type == "pause":
await self._send_pause_command(event_context)
if event_context:
self.logger.info("Sent pause command.")
else:
self.logger.info("Sent resume command.")
else:
self.logger.warning(
"Received button press with unknown type '%s' (context: '%s').",
event_type,
event_context,
)
async def _send_to_speech_agent(self, text_to_say: str):
"""
method to send prioritized speech command to RobotSpeechAgent.
:param text_to_say: The string that the robot has to say.
"""
cmd = SpeechCommand(data=text_to_say, is_priority=True)
out_msg = InternalMessage(
to=settings.agent_settings.robot_speech_name,
sender=self.name,
body=cmd.model_dump_json(),
)
await self.send(out_msg)
async def _send_to_gesture_agent(self, single_gesture_name: str):
"""
method to send prioritized gesture command to RobotGestureAgent.
:param single_gesture_name: The gesture tag that the robot has to perform.
"""
# the endpoint is set to always be GESTURE_SINGLE for user interrupts
cmd = GestureCommand(
endpoint=RIEndpoint.GESTURE_SINGLE, data=single_gesture_name, is_priority=True
)
out_msg = InternalMessage(
to=settings.agent_settings.robot_gesture_name,
sender=self.name,
body=cmd.model_dump_json(),
)
await self.send(out_msg)
async def _send_to_program_manager(self, belief_id: str):
"""
Send a button_override belief to the BDIProgramManager.
:param belief_id: The belief_id that overrides the goal/trigger/conditional norm.
this id can belong to a basic belief or an inferred belief.
See also: https://utrechtuniversity.youtrack.cloud/articles/N25B-A-27/UI-components
"""
data = {"belief": belief_id}
message = InternalMessage(
to=settings.agent_settings.bdi_program_manager_name,
sender=self.name,
body=json.dumps(data),
thread="belief_override_id",
)
await self.send(message)
self.logger.info(
"Sent button_override belief with id '%s' to Program manager.",
belief_id,
)
async def _send_pause_command(self, pause : bool):
"""
Send a pause command to the Robot Interface via the RI Communication Agent.
Send a pause command to the other internal agents; for now just VAD agent.
"""
cmd = PauseCommand(data=pause)
message = InternalMessage(
to=settings.agent_settings.ri_communication_name,
sender=self.name,
body=cmd.model_dump_json(),
)
await self.send(message)
if pause:
# Send pause to VAD agent
vad_message = InternalMessage(
to=settings.agent_settings.vad_name,
sender=self.name,
body="PAUSE",
)
await self.send(vad_message)
self.logger.info("Sent pause command to VAD Agent and RI Communication Agent.")
else:
# Send resume to VAD agent
vad_message = InternalMessage(
to=settings.agent_settings.vad_name,
sender=self.name,
body="RESUME",
)
await self.send(vad_message)
self.logger.info("Sent resume command to VAD Agent and RI Communication Agent.")
async def setup(self):
"""
Initialize the agent.
Connects the internal ZMQ SUB socket and subscribes to the 'button_pressed' topic.
Starts the background behavior to receive the user interrupts.
"""
context = Context.instance()
self.sub_socket = context.socket(zmq.SUB)
self.sub_socket.connect(settings.zmq_settings.internal_sub_address)
self.sub_socket.subscribe("button_pressed")
self.add_behavior(self._receive_button_event())

View File

@@ -0,0 +1,31 @@
import logging
from fastapi import APIRouter, Request
from control_backend.schemas.events import ButtonPressedEvent
logger = logging.getLogger(__name__)
router = APIRouter()
@router.post("/button_pressed", status_code=202)
async def receive_button_event(event: ButtonPressedEvent, request: Request):
"""
Endpoint to handle external button press events.
Validates the event payload and publishes it to the internal 'button_pressed' topic.
Subscribers (in this case user_interrupt_agent) will pick this up to trigger
specific behaviors or state changes.
:param event: The parsed ButtonPressedEvent object.
:param request: The FastAPI request object.
"""
logger.debug("Received button event: %s | %s", event.type, event.context)
topic = b"button_pressed"
body = event.model_dump_json().encode()
pub_socket = request.app.state.endpoints_pub_socket
await pub_socket.send_multipart([topic, body])
return {"status": "Event received"}

View File

@@ -3,9 +3,8 @@ import json
import logging
import zmq.asyncio
from fastapi import APIRouter, HTTPException, Request
from fastapi import APIRouter, Request
from fastapi.responses import StreamingResponse
from pydantic import ValidationError
from zmq.asyncio import Context, Socket
from control_backend.core.config import settings
@@ -16,38 +15,44 @@ logger = logging.getLogger(__name__)
router = APIRouter()
@router.post("/command", status_code=202)
async def receive_command(command: SpeechCommand, request: Request):
@router.post("/command/speech", status_code=202)
async def receive_command_speech(command: SpeechCommand, request: Request):
"""
Send a direct speech command to the robot.
Publishes the command to the internal 'command' topic. The
:class:`~control_backend.agents.actuation.robot_speech_agent.RobotSpeechAgent`
or
will forward this to the robot.
:param command: The speech command payload.
:param request: The FastAPI request object.
"""
topic = b"command"
pub_socket: Socket = request.app.state.endpoints_pub_socket
await pub_socket.send_multipart([topic, command.model_dump_json().encode()])
return {"status": "Speech command received"}
@router.post("/command/gesture", status_code=202)
async def receive_command_gesture(command: GestureCommand, request: Request):
"""
Send a direct gesture command to the robot.
Publishes the command to the internal 'command' topic. The
:class:`~control_backend.agents.actuation.robot_speech_agent.RobotGestureAgent`
will forward this to the robot.
:param command: The speech command payload.
:param request: The FastAPI request object.
"""
# Validate and retrieve data.
validated = None
valid_commands = (GestureCommand, SpeechCommand)
for command_model in valid_commands:
try:
validated = command_model.model_validate(command)
except ValidationError:
continue
if validated is None:
raise HTTPException(status_code=422, detail="Payload is not valid for command models")
topic = b"command"
pub_socket: Socket = request.app.state.endpoints_pub_socket
await pub_socket.send_multipart([topic, validated.model_dump_json().encode()])
await pub_socket.send_multipart([topic, command.model_dump_json().encode()])
return {"status": "Command received"}
return {"status": "Gesture command received"}
@router.get("/ping_check")
@@ -58,31 +63,27 @@ async def ping(request: Request):
pass
@router.get("/get_available_gesture_tags")
async def get_available_gesture_tags(request: Request):
@router.get("/commands/gesture/tags")
async def get_available_gesture_tags(request: Request, count=0):
"""
Endpoint to retrieve the available gesture tags for the robot.
:param request: The FastAPI request object.
:return: A list of available gesture tags.
"""
sub_socket = Context.instance().socket(zmq.SUB)
sub_socket.connect(settings.zmq_settings.internal_sub_address)
sub_socket.setsockopt(zmq.SUBSCRIBE, b"get_gestures")
req_socket = Context.instance().socket(zmq.REQ)
req_socket.connect(settings.zmq_settings.internal_gesture_rep_adress)
pub_socket: Socket = request.app.state.endpoints_pub_socket
topic = b"send_gestures"
# TODO: Implement a way to get a certain ammount from the UI, rather than everything.
amount = None
# Check to see if we've got any count given in the query parameter
amount = count or None
timeout = 5 # seconds
await pub_socket.send_multipart([topic, amount.to_bytes(4, "big") if amount else b""])
await req_socket.send(f"{amount}".encode() if amount else b"None")
try:
_, body = await asyncio.wait_for(sub_socket.recv_multipart(), timeout=timeout)
body = await asyncio.wait_for(req_socket.recv(), timeout=timeout)
except TimeoutError:
body = b"tags: []"
logger.debug("got timeout error fetching gestures")
body = '{"tags": []}'
logger.debug("Got timeout error fetching gestures.")
# Handle empty response and JSON decode errors
available_tags = []

View File

@@ -1,6 +1,6 @@
from fastapi.routing import APIRouter
from control_backend.api.v1.endpoints import logs, message, program, robot, sse
from control_backend.api.v1.endpoints import button_pressed, logs, message, program, robot, sse
api_router = APIRouter()
@@ -13,3 +13,5 @@ api_router.include_router(robot.router, prefix="/robot", tags=["Pings", "Command
api_router.include_router(logs.router, tags=["Logs"])
api_router.include_router(program.router, tags=["Program"])
api_router.include_router(button_pressed.router, tags=["Button Pressed Events"])

View File

@@ -17,7 +17,7 @@ class ZMQSettings(BaseModel):
internal_sub_address: str = "tcp://localhost:5561"
ri_command_address: str = "tcp://localhost:0000"
ri_communication_address: str = "tcp://*:5555"
vad_agent_address: str = "tcp://localhost:5558"
internal_gesture_rep_adress: str = "tcp://localhost:7788"
class AgentSettings(BaseModel):
@@ -48,6 +48,7 @@ class AgentSettings(BaseModel):
ri_communication_name: str = "ri_communication_agent"
robot_speech_name: str = "robot_speech_agent"
robot_gesture_name: str = "robot_gesture_agent"
user_interrupt_name: str = "user_interrupt_agent"
class BehaviourSettings(BaseModel):
@@ -64,6 +65,7 @@ class BehaviourSettings(BaseModel):
:ivar transcription_words_per_minute: Estimated words per minute for transcription timing.
:ivar transcription_words_per_token: Estimated words per token for transcription timing.
:ivar transcription_token_buffer: Buffer for transcription tokens.
:ivar conversation_history_length_limit: The maximum amount of messages to extract beliefs from.
"""
sleep_s: float = 1.0
@@ -81,6 +83,9 @@ class BehaviourSettings(BaseModel):
transcription_words_per_token: float = 0.75 # (3 words = 4 tokens)
transcription_token_buffer: int = 10
# Text belief extractor settings
conversation_history_length_limit: int = 10
class LLMSettings(BaseModel):
"""
@@ -88,10 +93,17 @@ class LLMSettings(BaseModel):
:ivar local_llm_url: URL for the local LLM API.
:ivar local_llm_model: Name of the local LLM model to use.
:ivar chat_temperature: The temperature to use while generating chat responses.
:ivar code_temperature: The temperature to use while generating code-like responses like during
belief inference.
:ivar n_parallel: The number of parallel calls allowed to be made to the LLM.
"""
local_llm_url: str = "http://localhost:1234/v1/chat/completions"
local_llm_model: str = "gpt-oss"
chat_temperature: float = 1.0
code_temperature: float = 0.3
n_parallel: int = 4
class VADSettings(BaseModel):

View File

@@ -4,6 +4,7 @@ import os
import yaml
import zmq
from zmq.log.handlers import PUBHandler
from control_backend.core.config import settings
@@ -51,15 +52,27 @@ def setup_logging(path: str = ".logging_config.yaml") -> None:
logging.warning(f"Could not load logging configuration: {e}")
config = {}
if "custom_levels" in config:
for level_name, level_num in config["custom_levels"].items():
add_logging_level(level_name, level_num)
custom_levels = config.get("custom_levels", {}) or {}
for level_name, level_num in custom_levels.items():
add_logging_level(level_name, level_num)
if config.get("handlers") is not None and config.get("handlers").get("ui"):
pub_socket = zmq.Context.instance().socket(zmq.PUB)
pub_socket.connect(settings.zmq_settings.internal_pub_address)
config["handlers"]["ui"]["interface_or_socket"] = pub_socket
logging.config.dictConfig(config)
# Patch ZMQ PUBHandler to know about custom levels
if custom_levels:
for logger_name in ("control_backend",):
logger = logging.getLogger(logger_name)
for handler in logger.handlers:
if isinstance(handler, PUBHandler):
# Use the INFO formatter as the default template
default_fmt = handler.formatters[logging.INFO]
for level_num in custom_levels.values():
handler.setFormatter(default_fmt, level=level_num)
else:
logging.warning("Logging config file not found. Using default logging configuration.")

View File

@@ -39,13 +39,15 @@ from control_backend.agents.communication import RICommunicationAgent
# LLM Agents
from control_backend.agents.llm import LLMAgent
# Perceive agents
from control_backend.agents.perception import VADAgent
# Other backend imports
from control_backend.agents.mock_agents.test_pause_ri import TestPauseAgent
# User Interrupt Agent
from control_backend.agents.user_interrupt.user_interrupt_agent import UserInterruptAgent
from control_backend.api.v1.router import api_router
from control_backend.core.config import settings
from control_backend.logging import setup_logging
from control_backend.schemas.program_status import PROGRAM_STATUS, ProgramStatus
logger = logging.getLogger(__name__)
@@ -95,6 +97,8 @@ async def lifespan(app: FastAPI):
endpoints_pub_socket.connect(settings.zmq_settings.internal_pub_address)
app.state.endpoints_pub_socket = endpoints_pub_socket
await endpoints_pub_socket.send_multipart([PROGRAM_STATUS, ProgramStatus.STARTING.value])
# --- Initialize Agents ---
logger.info("Initializing and starting agents.")
@@ -132,46 +136,50 @@ async def lifespan(app: FastAPI):
"name": settings.agent_settings.text_belief_extractor_name,
},
),
"VADAgent": (
VADAgent,
{"audio_in_address": settings.zmq_settings.vad_agent_address, "audio_in_bind": False},
),
"ProgramManagerAgent": (
BDIProgramManager,
{
"name": settings.agent_settings.bdi_program_manager_name,
},
),
"TestPauseAgent": (
TestPauseAgent,
{
"name": "pause_test_agent",
},
),
"UserInterruptAgent": (
UserInterruptAgent,
{
"name": settings.agent_settings.user_interrupt_name,
},
),
}
agents = []
vad_agent = None
for name, (agent_class, kwargs) in agents_to_start.items():
try:
logger.debug("Starting agent: %s", name)
agent_instance = agent_class(**kwargs)
await agent_instance.start()
if isinstance(agent_instance, VADAgent):
vad_agent = agent_instance
agents.append(agent_instance)
logger.info("Agent '%s' started successfully.", name)
except Exception as e:
logger.error("Failed to start agent '%s': %s", name, e, exc_info=True)
raise
assert vad_agent is not None
await vad_agent.reset_stream()
logger.info("Application startup complete.")
await endpoints_pub_socket.send_multipart([PROGRAM_STATUS, ProgramStatus.RUNNING.value])
yield
# --- APPLICATION SHUTDOWN ---
logger.info("%s is shutting down.", app.title)
# Potential shutdown logic goes here
await endpoints_pub_socket.send_multipart([PROGRAM_STATUS, ProgramStatus.STOPPING.value])
# Additional shutdown logic goes here
logger.info("Application shutdown complete.")

View File

@@ -6,18 +6,27 @@ class Belief(BaseModel):
Represents a single belief in the BDI system.
:ivar name: The functor or name of the belief (e.g., 'user_said').
:ivar arguments: A list of string arguments for the belief.
:ivar replace: If True, existing beliefs with this name should be replaced by this one.
:ivar arguments: A list of string arguments for the belief, or None if the belief has no
arguments.
"""
name: str
arguments: list[str]
replace: bool = False
arguments: list[str] | None
class BeliefMessage(BaseModel):
"""
A container for transporting a list of beliefs between agents.
A container for communicating beliefs between agents.
:ivar create: Beliefs to create.
:ivar delete: Beliefs to delete.
:ivar replace: Beliefs to replace. Deletes all beliefs with the same name, replacing them with
one new belief.
"""
beliefs: list[Belief]
create: list[Belief] = []
delete: list[Belief] = []
replace: list[Belief] = []
def has_values(self) -> bool:
return len(self.create) > 0 or len(self.delete) > 0 or len(self.replace) > 0

View File

@@ -0,0 +1,10 @@
from pydantic import BaseModel
class ChatMessage(BaseModel):
role: str
content: str
class ChatHistory(BaseModel):
messages: list[ChatMessage]

View File

@@ -0,0 +1,6 @@
from pydantic import BaseModel
class ButtonPressedEvent(BaseModel):
type: str
context: str

View File

@@ -1,64 +1,202 @@
from pydantic import BaseModel
from enum import Enum
from typing import Literal
from pydantic import UUID4, BaseModel
class Norm(BaseModel):
class ProgramElement(BaseModel):
"""
Represents a behavioral norm.
Represents a basic element of our behavior program.
:ivar name: The researcher-assigned name of the element.
:ivar id: Unique identifier.
:ivar label: Human-readable label.
:ivar norm: The actual norm text describing the behavior.
"""
id: str
label: str
norm: str
name: str
id: UUID4
class Goal(BaseModel):
class LogicalOperator(Enum):
AND = "AND"
OR = "OR"
type Belief = KeywordBelief | SemanticBelief | InferredBelief
type BasicBelief = KeywordBelief | SemanticBelief
class KeywordBelief(ProgramElement):
"""
Represents an objective to be achieved.
Represents a belief that is set when the user spoken text contains a certain keyword.
:ivar id: Unique identifier.
:ivar label: Human-readable label.
:ivar description: Detailed description of the goal.
:ivar achieved: Status flag indicating if the goal has been met.
:ivar keyword: The keyword on which this belief gets set.
"""
id: str
label: str
description: str
achieved: bool
class TriggerKeyword(BaseModel):
id: str
name: str = ""
keyword: str
class KeywordTrigger(BaseModel):
id: str
label: str
type: str
keywords: list[TriggerKeyword]
class SemanticBelief(ProgramElement):
"""
Represents a belief that is set by semantic LLM validation.
:ivar description: Description of how to form the belief, used by the LLM.
"""
name: str = ""
description: str
class Phase(BaseModel):
class InferredBelief(ProgramElement):
"""
Represents a belief that gets formed by combining two beliefs with a logical AND or OR.
These beliefs can also be :class:`InferredBelief`, leading to arbitrarily deep nesting.
:ivar operator: The logical operator to apply.
:ivar left: The left part of the logical expression.
:ivar right: The right part of the logical expression.
"""
name: str = ""
operator: LogicalOperator
left: Belief
right: Belief
class Norm(ProgramElement):
name: str = ""
norm: str
critical: bool = False
class BasicNorm(Norm):
"""
Represents a behavioral norm.
:ivar norm: The actual norm text describing the behavior.
:ivar critical: When true, this norm should absolutely not be violated (checked separately).
"""
pass
class ConditionalNorm(Norm):
"""
Represents a norm that is only active when a condition is met (i.e., a certain belief holds).
:ivar condition: When to activate this norm.
"""
condition: Belief
type PlanElement = Goal | Action
class Plan(ProgramElement):
"""
Represents a list of steps to execute. Each of these steps can be a goal (with its own plan)
or a simple action.
:ivar steps: The actions or subgoals to execute, in order.
"""
name: str = ""
steps: list[PlanElement]
class Goal(ProgramElement):
"""
Represents an objective to be achieved. To reach the goal, we should execute
the corresponding plan. If we can fail to achieve a goal after executing the plan,
for example when the achieving of the goal is dependent on the user's reply, this means
that the achieved status will be set from somewhere else in the program.
:ivar plan: The plan to execute.
:ivar can_fail: Whether we can fail to achieve the goal after executing the plan.
"""
plan: Plan
can_fail: bool = True
type Action = SpeechAction | GestureAction | LLMAction
class SpeechAction(ProgramElement):
"""
Represents the action of the robot speaking a literal text.
:ivar text: The text to speak.
"""
name: str = ""
text: str
class Gesture(BaseModel):
"""
Represents a gesture to be performed. Can be either a single gesture,
or a random gesture from a category (tag).
:ivar type: The type of the gesture, "tag" or "single".
:ivar name: The name of the single gesture or tag.
"""
type: Literal["tag", "single"]
name: str
class GestureAction(ProgramElement):
"""
Represents the action of the robot performing a gesture.
:ivar gesture: The gesture to perform.
"""
name: str = ""
gesture: Gesture
class LLMAction(ProgramElement):
"""
Represents the action of letting an LLM generate a reply based on its chat history
and an additional goal added in the prompt.
:ivar goal: The extra (temporary) goal to add to the LLM.
"""
name: str = ""
goal: str
class Trigger(ProgramElement):
"""
Represents a belief-based trigger. When a belief is set, the corresponding plan is executed.
:ivar condition: When to activate the trigger.
:ivar plan: The plan to execute.
"""
name: str = ""
condition: Belief
plan: Plan
class Phase(ProgramElement):
"""
A distinct phase within a program, containing norms, goals, and triggers.
:ivar id: Unique identifier.
:ivar label: Human-readable label.
:ivar norms: List of norms active in this phase.
:ivar goals: List of goals to pursue in this phase.
:ivar triggers: List of triggers that define transitions out of this phase.
"""
id: str
label: str
norms: list[Norm]
name: str = ""
norms: list[BasicNorm | ConditionalNorm]
goals: list[Goal]
triggers: list[KeywordTrigger]
triggers: list[Trigger]
class Program(BaseModel):

View File

@@ -0,0 +1,16 @@
from enum import Enum
PROGRAM_STATUS = b"internal/program_status"
"""A topic key for the program status."""
class ProgramStatus(Enum):
"""
Used in internal communication, to tell agents what the status of the program is.
For example, the VAD agent only starts listening when the program is RUNNING.
"""
STARTING = b"starting"
RUNNING = b"running"
STOPPING = b"stopping"

View File

@@ -14,6 +14,7 @@ class RIEndpoint(str, Enum):
GESTURE_TAG = "actuate/gesture/tag"
PING = "ping"
NEGOTIATE_PORTS = "negotiate/ports"
PAUSE = "pause"
class RIMessage(BaseModel):
@@ -38,6 +39,7 @@ class SpeechCommand(RIMessage):
endpoint: RIEndpoint = RIEndpoint(RIEndpoint.SPEECH)
data: str
is_priority: bool = False
class GestureCommand(RIMessage):
@@ -52,6 +54,7 @@ class GestureCommand(RIMessage):
RIEndpoint.GESTURE_SINGLE, RIEndpoint.GESTURE_TAG
]
data: str
is_priority: bool = False
@model_validator(mode="after")
def check_endpoint(self):
@@ -62,3 +65,14 @@ class GestureCommand(RIMessage):
if self.endpoint not in allowed:
raise ValueError("endpoint must be GESTURE_SINGLE or GESTURE_TAG")
return self
class PauseCommand(RIMessage):
"""
A specific command to pause or unpause the robot's actions.
:ivar endpoint: Fixed to ``RIEndpoint.PAUSE``.
:ivar data: A boolean indicating whether to pause (True) or unpause (False).
"""
endpoint: RIEndpoint = RIEndpoint(RIEndpoint.PAUSE)
data: bool

View File

@@ -5,6 +5,7 @@ import pytest
import zmq
from control_backend.agents.perception.vad_agent import VADAgent
from control_backend.schemas.program_status import PROGRAM_STATUS, ProgramStatus
@pytest.fixture
@@ -43,14 +44,12 @@ async def test_normal_setup(per_transcription_agent):
coro.close()
per_vad_agent.add_behavior = swallow_background_task
per_vad_agent.reset_stream = AsyncMock()
await per_vad_agent.setup()
per_transcription_agent.assert_called_once()
per_transcription_agent.return_value.start.assert_called_once()
per_vad_agent._streaming_loop.assert_called_once()
per_vad_agent.reset_stream.assert_called_once()
assert per_vad_agent.audio_in_socket is not None
assert per_vad_agent.audio_out_socket is not None
@@ -103,7 +102,7 @@ async def test_out_socket_creation_failure(zmq_context):
zmq_context.return_value.socket.return_value.bind_to_random_port.side_effect = zmq.ZMQBindError
per_vad_agent = VADAgent("tcp://localhost:12345", False)
per_vad_agent.stop = AsyncMock()
per_vad_agent.reset_stream = AsyncMock()
per_vad_agent._reset_stream = AsyncMock()
per_vad_agent._streaming_loop = AsyncMock()
per_vad_agent._connect_audio_out_socket = MagicMock(return_value=None)
@@ -124,7 +123,7 @@ async def test_stop(zmq_context, per_transcription_agent):
Test that when the VAD agent is stopped, the sockets are closed correctly.
"""
per_vad_agent = VADAgent("tcp://localhost:12345", False)
per_vad_agent.reset_stream = AsyncMock()
per_vad_agent._reset_stream = AsyncMock()
per_vad_agent._streaming_loop = AsyncMock()
async def swallow_background_task(coro):
@@ -142,3 +141,66 @@ async def test_stop(zmq_context, per_transcription_agent):
assert zmq_context.return_value.socket.return_value.close.call_count == 2
assert per_vad_agent.audio_in_socket is None
assert per_vad_agent.audio_out_socket is None
@pytest.mark.asyncio
async def test_application_startup_complete(zmq_context):
"""Check that it resets the stream when the program finishes startup."""
vad_agent = VADAgent("tcp://localhost:12345", False)
vad_agent._running = True
vad_agent._reset_stream = AsyncMock()
vad_agent.program_sub_socket = AsyncMock()
vad_agent.program_sub_socket.recv_multipart.side_effect = [
(PROGRAM_STATUS, ProgramStatus.RUNNING.value),
]
await vad_agent._status_loop()
vad_agent._reset_stream.assert_called_once()
vad_agent.program_sub_socket.close.assert_called_once()
@pytest.mark.asyncio
async def test_application_other_status(zmq_context):
"""
Check that it does nothing when the internal communication message is a status update, but not
running.
"""
vad_agent = VADAgent("tcp://localhost:12345", False)
vad_agent._running = True
vad_agent._reset_stream = AsyncMock()
vad_agent.program_sub_socket = AsyncMock()
vad_agent.program_sub_socket.recv_multipart.side_effect = [
(PROGRAM_STATUS, ProgramStatus.STARTING.value),
(PROGRAM_STATUS, ProgramStatus.STOPPING.value),
]
try:
# Raises StopAsyncIteration the third time it calls `program_sub_socket.recv_multipart`
await vad_agent._status_loop()
except StopAsyncIteration:
pass
vad_agent._reset_stream.assert_not_called()
@pytest.mark.asyncio
async def test_application_message_other(zmq_context):
"""
Check that it does nothing when there's an internal communication message that is not a status
update.
"""
vad_agent = VADAgent("tcp://localhost:12345", False)
vad_agent._running = True
vad_agent._reset_stream = AsyncMock()
vad_agent.program_sub_socket = AsyncMock()
vad_agent.program_sub_socket.recv_multipart.side_effect = [(b"internal/other", b"Whatever")]
try:
# Raises StopAsyncIteration the second time it calls `program_sub_socket.recv_multipart`
await vad_agent._status_loop()
except StopAsyncIteration:
pass
vad_agent._reset_stream.assert_not_called()

View File

@@ -34,14 +34,16 @@ async def test_setup_bind(zmq_context, mocker):
# Check PUB socket binding
fake_socket.bind.assert_any_call("tcp://localhost:5556")
# Check REP socket binding
fake_socket.bind.assert_called()
# Check SUB socket connection and subscriptions
fake_socket.connect.assert_any_call("tcp://internal:1234")
fake_socket.setsockopt.assert_any_call(zmq.SUBSCRIBE, b"command")
fake_socket.setsockopt.assert_any_call(zmq.SUBSCRIBE, b"send_gestures")
# Check behavior was added
agent.add_behavior.assert_called() # Twice, even.
# Check behavior was added (twice: once for command loop, once for fetch gestures loop)
assert agent.add_behavior.call_count == 2
@pytest.mark.asyncio
@@ -60,21 +62,23 @@ async def test_setup_connect(zmq_context, mocker):
# Check PUB socket connection (not binding)
fake_socket.connect.assert_any_call("tcp://localhost:5556")
fake_socket.connect.assert_any_call("tcp://internal:1234")
# Check REP socket binding (always binds)
fake_socket.bind.assert_called()
# Check behavior was added
agent.add_behavior.assert_called() # Twice, actually.
# Check behavior was added (twice)
assert agent.add_behavior.call_count == 2
@pytest.mark.asyncio
async def test_handle_message_sends_valid_gesture_command():
"""Internal message with valid gesture tag is forwarded to robot pub socket."""
pubsocket = AsyncMock()
agent = RobotGestureAgent("robot_gesture")
agent = RobotGestureAgent("robot_gesture", gesture_data=["hello", "yes", "no"])
agent.pubsocket = pubsocket
payload = {
"endpoint": RIEndpoint.GESTURE_TAG,
"data": "hello", # "hello" is in availableTags
"data": "hello", # "hello" is in gesture_data
}
msg = InternalMessage(to="robot", sender="tester", body=json.dumps(payload))
@@ -85,9 +89,9 @@ async def test_handle_message_sends_valid_gesture_command():
@pytest.mark.asyncio
async def test_handle_message_sends_non_gesture_command():
"""Internal message with non-gesture endpoint is not handled by this agent."""
"""Internal message with non-gesture endpoint is not forwarded by this agent."""
pubsocket = AsyncMock()
agent = RobotGestureAgent("robot_gesture")
agent = RobotGestureAgent("robot_gesture", gesture_data=["hello", "yes", "no"])
agent.pubsocket = pubsocket
payload = {"endpoint": "some_other_endpoint", "data": "invalid_tag_not_in_list"}
@@ -95,6 +99,7 @@ async def test_handle_message_sends_non_gesture_command():
await agent.handle_message(msg)
# Non-gesture endpoints should not be forwarded by this agent
pubsocket.send_json.assert_not_awaited()
@@ -102,10 +107,10 @@ async def test_handle_message_sends_non_gesture_command():
async def test_handle_message_rejects_invalid_gesture_tag():
"""Internal message with invalid gesture tag is not forwarded."""
pubsocket = AsyncMock()
agent = RobotGestureAgent("robot_gesture")
agent = RobotGestureAgent("robot_gesture", gesture_data=["hello", "yes", "no"])
agent.pubsocket = pubsocket
# Use a tag that's not in availableTags
# Use a tag that's not in gesture_data
payload = {"endpoint": RIEndpoint.GESTURE_TAG, "data": "invalid_tag_not_in_list"}
msg = InternalMessage(to="robot", sender="tester", body=json.dumps(payload))
@@ -118,7 +123,7 @@ async def test_handle_message_rejects_invalid_gesture_tag():
async def test_handle_message_invalid_payload():
"""Invalid payload is caught and does not send."""
pubsocket = AsyncMock()
agent = RobotGestureAgent("robot_gesture")
agent = RobotGestureAgent("robot_gesture", gesture_data=["hello", "yes", "no"])
agent.pubsocket = pubsocket
msg = InternalMessage(to="robot", sender="tester", body=json.dumps({"bad": "data"}))
@@ -142,7 +147,7 @@ async def test_zmq_command_loop_valid_gesture_payload():
fake_socket.recv_multipart = recv_once
fake_socket.send_json = AsyncMock()
agent = RobotGestureAgent("robot_gesture")
agent = RobotGestureAgent("robot_gesture", gesture_data=["hello", "yes", "no"])
agent.subsocket = fake_socket
agent.pubsocket = fake_socket
agent._running = True
@@ -154,7 +159,7 @@ async def test_zmq_command_loop_valid_gesture_payload():
@pytest.mark.asyncio
async def test_zmq_command_loop_valid_non_gesture_payload():
"""UI command with non-gesture endpoint is not handled by this agent."""
"""UI command with non-gesture endpoint is not forwarded by this agent."""
command = {"endpoint": "some_other_endpoint", "data": "anything"}
fake_socket = AsyncMock()
@@ -165,7 +170,7 @@ async def test_zmq_command_loop_valid_non_gesture_payload():
fake_socket.recv_multipart = recv_once
fake_socket.send_json = AsyncMock()
agent = RobotGestureAgent("robot_gesture")
agent = RobotGestureAgent("robot_gesture", gesture_data=["hello", "yes", "no"])
agent.subsocket = fake_socket
agent.pubsocket = fake_socket
agent._running = True
@@ -188,7 +193,7 @@ async def test_zmq_command_loop_invalid_gesture_tag():
fake_socket.recv_multipart = recv_once
fake_socket.send_json = AsyncMock()
agent = RobotGestureAgent("robot_gesture")
agent = RobotGestureAgent("robot_gesture", gesture_data=["hello", "yes", "no"])
agent.subsocket = fake_socket
agent.pubsocket = fake_socket
agent._running = True
@@ -210,7 +215,7 @@ async def test_zmq_command_loop_invalid_json():
fake_socket.recv_multipart = recv_once
fake_socket.send_json = AsyncMock()
agent = RobotGestureAgent("robot_gesture")
agent = RobotGestureAgent("robot_gesture", gesture_data=["hello", "yes", "no"])
agent.subsocket = fake_socket
agent.pubsocket = fake_socket
agent._running = True
@@ -232,7 +237,7 @@ async def test_zmq_command_loop_ignores_send_gestures_topic():
fake_socket.recv_multipart = recv_once
fake_socket.send_json = AsyncMock()
agent = RobotGestureAgent("robot_gesture")
agent = RobotGestureAgent("robot_gesture", gesture_data=["hello", "yes", "no"])
agent.subsocket = fake_socket
agent.pubsocket = fake_socket
agent._running = True
@@ -245,139 +250,165 @@ async def test_zmq_command_loop_ignores_send_gestures_topic():
@pytest.mark.asyncio
async def test_fetch_gestures_loop_without_amount():
"""Fetch gestures request without amount returns all tags."""
fake_socket = AsyncMock()
fake_repsocket = AsyncMock()
async def recv_once():
agent._running = False
return (b"send_gestures", b"{}")
return b"{}" # Empty JSON request
fake_socket.recv_multipart = recv_once
fake_socket.send_multipart = AsyncMock()
fake_repsocket.recv = recv_once
fake_repsocket.send = AsyncMock()
agent = RobotGestureAgent("robot_gesture")
agent.subsocket = fake_socket
agent.pubsocket = fake_socket
agent = RobotGestureAgent("robot_gesture", gesture_data=["hello", "yes", "no", "wave", "point"])
agent.repsocket = fake_repsocket
agent._running = True
await agent._fetch_gestures_loop()
fake_socket.send_multipart.assert_awaited_once()
fake_repsocket.send.assert_awaited_once()
# Check the response contains all tags
args, kwargs = fake_socket.send_multipart.call_args
assert args[0][0] == b"get_gestures"
response = json.loads(args[0][1])
args, kwargs = fake_repsocket.send.call_args
response = json.loads(args[0])
assert "tags" in response
assert len(response["tags"]) > 0
# Check it includes some expected tags
assert "hello" in response["tags"]
assert "yes" in response["tags"]
assert response["tags"] == ["hello", "yes", "no", "wave", "point"]
@pytest.mark.asyncio
async def test_fetch_gestures_loop_with_amount():
"""Fetch gestures request with amount returns limited tags."""
fake_socket = AsyncMock()
amount = 5
fake_repsocket = AsyncMock()
amount = 3
async def recv_once():
agent._running = False
return (b"send_gestures", json.dumps(amount).encode())
return json.dumps(amount).encode()
fake_socket.recv_multipart = recv_once
fake_socket.send_multipart = AsyncMock()
fake_repsocket.recv = recv_once
fake_repsocket.send = AsyncMock()
agent = RobotGestureAgent("robot_gesture")
agent.subsocket = fake_socket
agent.pubsocket = fake_socket
agent = RobotGestureAgent("robot_gesture", gesture_data=["hello", "yes", "no", "wave", "point"])
agent.repsocket = fake_repsocket
agent._running = True
await agent._fetch_gestures_loop()
fake_socket.send_multipart.assert_awaited_once()
fake_repsocket.send.assert_awaited_once()
args, kwargs = fake_socket.send_multipart.call_args
assert args[0][0] == b"get_gestures"
response = json.loads(args[0][1])
args, kwargs = fake_repsocket.send.call_args
response = json.loads(args[0])
assert "tags" in response
assert len(response["tags"]) == amount
assert response["tags"] == ["hello", "yes", "no"]
@pytest.mark.asyncio
async def test_fetch_gestures_loop_ignores_command_topic():
"""Command topic is ignored in fetch gestures loop."""
fake_socket = AsyncMock()
async def test_fetch_gestures_loop_with_integer_request():
"""Fetch gestures request with integer amount."""
fake_repsocket = AsyncMock()
amount = 2
async def recv_once():
agent._running = False
return (b"command", b"{}")
return json.dumps(amount).encode()
fake_socket.recv_multipart = recv_once
fake_socket.send_multipart = AsyncMock()
fake_repsocket.recv = recv_once
fake_repsocket.send = AsyncMock()
agent = RobotGestureAgent("robot_gesture")
agent.subsocket = fake_socket
agent.pubsocket = fake_socket
agent = RobotGestureAgent("robot_gesture", gesture_data=["hello", "yes", "no"])
agent.repsocket = fake_repsocket
agent._running = True
await agent._fetch_gestures_loop()
fake_socket.send_multipart.assert_not_awaited()
fake_repsocket.send.assert_awaited_once()
args, kwargs = fake_repsocket.send.call_args
response = json.loads(args[0])
assert response["tags"] == ["hello", "yes"]
@pytest.mark.asyncio
async def test_fetch_gestures_loop_invalid_request():
"""Invalid request body is handled gracefully."""
fake_socket = AsyncMock()
async def test_fetch_gestures_loop_with_invalid_json():
"""Invalid JSON request returns all tags."""
fake_repsocket = AsyncMock()
async def recv_once():
agent._running = False
# Send a non-integer, non-JSON body
return (b"send_gestures", b"not_json")
return b"not_json"
fake_socket.recv_multipart = recv_once
fake_socket.send_multipart = AsyncMock()
fake_repsocket.recv = recv_once
fake_repsocket.send = AsyncMock()
agent = RobotGestureAgent("robot_gesture")
agent.subsocket = fake_socket
agent.pubsocket = fake_socket
agent = RobotGestureAgent("robot_gesture", gesture_data=["hello", "yes", "no"])
agent.repsocket = fake_repsocket
agent._running = True
await agent._fetch_gestures_loop()
# Should still send a response (all tags)
fake_socket.send_multipart.assert_awaited_once()
fake_repsocket.send.assert_awaited_once()
args, kwargs = fake_repsocket.send.call_args
response = json.loads(args[0])
assert response["tags"] == ["hello", "yes", "no"]
def test_available_tags():
"""Test that availableTags returns the expected list."""
agent = RobotGestureAgent("robot_gesture")
@pytest.mark.asyncio
async def test_fetch_gestures_loop_with_non_integer_json():
"""Non-integer JSON request returns all tags."""
fake_repsocket = AsyncMock()
tags = agent.availableTags()
async def recv_once():
agent._running = False
return json.dumps({"not": "an_integer"}).encode()
assert isinstance(tags, list)
assert len(tags) > 0
# Check some expected tags are present
assert "hello" in tags
assert "yes" in tags
assert "no" in tags
# Check a non-existent tag is not present
assert "invalid_tag_not_in_list" not in tags
fake_repsocket.recv = recv_once
fake_repsocket.send = AsyncMock()
agent = RobotGestureAgent("robot_gesture", gesture_data=["hello", "yes", "no"])
agent.repsocket = fake_repsocket
agent._running = True
await agent._fetch_gestures_loop()
fake_repsocket.send.assert_awaited_once()
args, kwargs = fake_repsocket.send.call_args
response = json.loads(args[0])
assert response["tags"] == ["hello", "yes", "no"]
def test_gesture_data_attribute():
"""Test that gesture_data returns the expected list."""
gesture_data = ["hello", "yes", "no", "wave"]
agent = RobotGestureAgent("robot_gesture", gesture_data=gesture_data)
assert agent.gesture_data == gesture_data
assert isinstance(agent.gesture_data, list)
assert len(agent.gesture_data) == 4
assert "hello" in agent.gesture_data
assert "yes" in agent.gesture_data
assert "no" in agent.gesture_data
assert "invalid_tag_not_in_list" not in agent.gesture_data
@pytest.mark.asyncio
async def test_stop_closes_sockets():
"""Stop method closes both sockets."""
"""Stop method closes all sockets."""
pubsocket = MagicMock()
subsocket = MagicMock()
repsocket = MagicMock()
agent = RobotGestureAgent("robot_gesture")
agent.pubsocket = pubsocket
agent.subsocket = subsocket
agent.repsocket = repsocket
await agent.stop()
pubsocket.close.assert_called_once()
subsocket.close.assert_called_once()
# Note: repsocket is not closed in stop() method, but you might want to add it
# repsocket.close.assert_called_once()
@pytest.mark.asyncio
@@ -386,7 +417,28 @@ async def test_initialization_with_custom_gesture_data():
custom_gestures = ["custom1", "custom2", "custom3"]
agent = RobotGestureAgent("robot_gesture", gesture_data=custom_gestures)
# Note: The current implementation doesn't use the gesture_data parameter
# in availableTags(). This test documents that behavior.
# If you update the agent to use gesture_data, update this test accordingly.
assert agent.gesture_data == custom_gestures
@pytest.mark.asyncio
async def test_fetch_gestures_loop_handles_exception():
"""Exception in fetch gestures loop is caught and logged."""
fake_repsocket = AsyncMock()
async def recv_once():
agent._running = False
raise Exception("Test exception")
fake_repsocket.recv = recv_once
fake_repsocket.send = AsyncMock()
agent = RobotGestureAgent("robot_gesture", gesture_data=["hello", "yes", "no"])
agent.repsocket = fake_repsocket
agent.logger = MagicMock()
agent._running = True
# Should not raise exception
await agent._fetch_gestures_loop()
# Exception should be logged
agent.logger.exception.assert_called_once()

View File

@@ -8,6 +8,11 @@ from control_backend.agents.actuation.robot_speech_agent import RobotSpeechAgent
from control_backend.core.agent_system import InternalMessage
def mock_speech_agent():
agent = RobotSpeechAgent("robot_speech", address="tcp://localhost:5555", bind=False)
return agent
@pytest.fixture
def zmq_context(mocker):
mock_context = mocker.patch(
@@ -56,10 +61,10 @@ async def test_setup_connect(zmq_context, mocker):
async def test_handle_message_sends_command():
"""Internal message is forwarded to robot pub socket as JSON."""
pubsocket = AsyncMock()
agent = RobotSpeechAgent("robot_speech")
agent = mock_speech_agent()
agent.pubsocket = pubsocket
payload = {"endpoint": "actuate/speech", "data": "hello"}
payload = {"endpoint": "actuate/speech", "data": "hello", "is_priority": False}
msg = InternalMessage(to="robot", sender="tester", body=json.dumps(payload))
await agent.handle_message(msg)
@@ -70,7 +75,7 @@ async def test_handle_message_sends_command():
@pytest.mark.asyncio
async def test_zmq_command_loop_valid_payload(zmq_context):
"""UI command is read from SUB and published."""
command = {"endpoint": "actuate/speech", "data": "hello"}
command = {"endpoint": "actuate/speech", "data": "hello", "is_priority": False}
fake_socket = AsyncMock()
async def recv_once():
@@ -80,7 +85,7 @@ async def test_zmq_command_loop_valid_payload(zmq_context):
fake_socket.recv_multipart = recv_once
fake_socket.send_json = AsyncMock()
agent = RobotSpeechAgent("robot_speech")
agent = mock_speech_agent()
agent.subsocket = fake_socket
agent.pubsocket = fake_socket
agent._running = True
@@ -101,7 +106,7 @@ async def test_zmq_command_loop_invalid_json():
fake_socket.recv_multipart = recv_once
fake_socket.send_json = AsyncMock()
agent = RobotSpeechAgent("robot_speech")
agent = mock_speech_agent()
agent.subsocket = fake_socket
agent.pubsocket = fake_socket
agent._running = True
@@ -115,7 +120,7 @@ async def test_zmq_command_loop_invalid_json():
async def test_handle_message_invalid_payload():
"""Invalid payload is caught and does not send."""
pubsocket = AsyncMock()
agent = RobotSpeechAgent("robot_speech")
agent = mock_speech_agent()
agent.pubsocket = pubsocket
msg = InternalMessage(to="robot", sender="tester", body=json.dumps({"bad": "data"}))
@@ -129,7 +134,7 @@ async def test_handle_message_invalid_payload():
async def test_stop_closes_sockets():
pubsocket = MagicMock()
subsocket = MagicMock()
agent = RobotSpeechAgent("robot_speech")
agent = mock_speech_agent()
agent.pubsocket = pubsocket
agent.subsocket = subsocket

View File

@@ -1,4 +1,6 @@
import asyncio
import json
import time
from unittest.mock import AsyncMock, MagicMock, mock_open, patch
import agentspeak
@@ -49,7 +51,7 @@ async def test_handle_belief_collector_message(agent, mock_settings):
msg = InternalMessage(
to="bdi_agent",
sender=mock_settings.agent_settings.bdi_belief_collector_name,
body=BeliefMessage(beliefs=beliefs).model_dump_json(),
body=BeliefMessage(create=beliefs).model_dump_json(),
thread="beliefs",
)
@@ -62,6 +64,26 @@ async def test_handle_belief_collector_message(agent, mock_settings):
assert args[2] == agentspeak.Literal("user_said", (agentspeak.Literal("Hello"),))
@pytest.mark.asyncio
async def test_handle_delete_belief_message(agent, mock_settings):
"""Test that incoming beliefs to be deleted are removed from the BDI agent"""
beliefs = [Belief(name="user_said", arguments=["Hello"])]
msg = InternalMessage(
to="bdi_agent",
sender=mock_settings.agent_settings.bdi_belief_collector_name,
body=BeliefMessage(delete=beliefs).model_dump_json(),
thread="beliefs",
)
await agent.handle_message(msg)
# Expect bdi_agent.call to be triggered to remove belief
args = agent.bdi_agent.call.call_args.args
assert args[0] == agentspeak.Trigger.removal
assert args[1] == agentspeak.GoalType.belief
assert args[2] == agentspeak.Literal("user_said", (agentspeak.Literal("Hello"),))
@pytest.mark.asyncio
async def test_incorrect_belief_collector_message(agent, mock_settings):
"""Test that incorrect message format triggers an exception."""
@@ -77,11 +99,6 @@ async def test_incorrect_belief_collector_message(agent, mock_settings):
agent.bdi_agent.call.assert_not_called() # did not set belief
@pytest.mark.asyncio
async def test():
pass
@pytest.mark.asyncio
async def test_handle_llm_response(agent):
"""Test that LLM responses are forwarded to the Robot Speech Agent"""
@@ -124,3 +141,150 @@ async def test_custom_actions(agent):
next(gen) # Execute
agent._send_to_llm.assert_called_with("Hello", "Norm", "Goal")
def test_add_belief_sets_event(agent):
"""Test that a belief triggers wake event and call()"""
agent._wake_bdi_loop = MagicMock()
belief = Belief(name="test_belief", arguments=["a", "b"])
belief_changes = BeliefMessage(replace=[belief])
agent._apply_belief_changes(belief_changes)
assert agent.bdi_agent.call.called
agent._wake_bdi_loop.set.assert_called()
def test_apply_beliefs_empty_returns(agent):
"""Line: if not beliefs: return"""
agent._wake_bdi_loop = MagicMock()
agent._apply_belief_changes(BeliefMessage())
agent.bdi_agent.call.assert_not_called()
agent._wake_bdi_loop.set.assert_not_called()
def test_remove_belief_success_wakes_loop(agent):
"""Line: if result: wake set"""
agent._wake_bdi_loop = MagicMock()
agent.bdi_agent.call.return_value = True
agent._remove_belief("remove_me", ["x"])
assert agent.bdi_agent.call.called
trigger, goaltype, literal, *_ = agent.bdi_agent.call.call_args.args
assert trigger == agentspeak.Trigger.removal
assert goaltype == agentspeak.GoalType.belief
assert literal.functor == "remove_me"
assert literal.args[0].functor == "x"
agent._wake_bdi_loop.set.assert_called()
def test_remove_belief_failure_does_not_wake(agent):
"""Line: else result is False"""
agent._wake_bdi_loop = MagicMock()
agent.bdi_agent.call.return_value = False
agent._remove_belief("not_there", ["y"])
assert agent.bdi_agent.call.called # removal was attempted
agent._wake_bdi_loop.set.assert_not_called()
def test_remove_all_with_name_wakes_loop(agent):
"""Cover _remove_all_with_name() removed counter + wake"""
agent._wake_bdi_loop = MagicMock()
fake_literal = agentspeak.Literal("delete_me", (agentspeak.Literal("arg1"),))
fake_key = ("delete_me", 1)
agent.bdi_agent.beliefs = {fake_key: {fake_literal}}
agent._remove_all_with_name("delete_me")
assert agent.bdi_agent.call.called
agent._wake_bdi_loop.set.assert_called()
@pytest.mark.asyncio
async def test_bdi_step_true_branch_hits_line_67(agent):
"""Force step() to return True once so line 67 is actually executed"""
# counter that isn't tied to MagicMock.call_count ordering
counter = {"i": 0}
def fake_step():
counter["i"] += 1
return counter["i"] == 1 # True only first time
# Important: wrap fake_step into another mock so `.called` still exists
agent.bdi_agent.step = MagicMock(side_effect=fake_step)
agent.bdi_agent.shortest_deadline = MagicMock(return_value=None)
agent._running = True
agent._wake_bdi_loop = asyncio.Event()
agent._wake_bdi_loop.set()
task = asyncio.create_task(agent._bdi_loop())
await asyncio.sleep(0.01)
task.cancel()
try:
await task
except asyncio.CancelledError:
pass
assert agent.bdi_agent.step.called
assert counter["i"] >= 1 # proves True branch ran
def test_replace_belief_calls_remove_all(agent):
"""Cover: if belief.replace: self._remove_all_with_name()"""
agent._remove_all_with_name = MagicMock()
agent._wake_bdi_loop = MagicMock()
belief = Belief(name="user_said", arguments=["Hello"])
belief_changes = BeliefMessage(replace=[belief])
agent._apply_belief_changes(belief_changes)
agent._remove_all_with_name.assert_called_with("user_said")
@pytest.mark.asyncio
async def test_send_to_llm_creates_prompt_and_sends(agent):
"""Cover entire _send_to_llm() including message send and logger.info"""
agent.bdi_agent = MagicMock() # ensure mocked BDI does not interfere
agent._wake_bdi_loop = MagicMock()
await agent._send_to_llm("hello world", "n1\nn2", "g1")
# send() was called
assert agent.send.called
sent_msg: InternalMessage = agent.send.call_args.args[0]
# Message routing values correct
assert sent_msg.to == settings.agent_settings.llm_name
assert "hello world" in sent_msg.body
# JSON contains split norms/goals
body = json.loads(sent_msg.body)
assert body["norms"] == ["n1", "n2"]
assert body["goals"] == ["g1"]
@pytest.mark.asyncio
async def test_deadline_sleep_branch(agent):
"""Specifically assert the if deadline: sleep → maybe_more_work=True branch"""
future_deadline = time.time() + 0.005
agent.bdi_agent.step.return_value = False
agent.bdi_agent.shortest_deadline.return_value = future_deadline
start_time = time.time()
agent._running = True
agent._wake_bdi_loop = asyncio.Event()
agent._wake_bdi_loop.set()
task = asyncio.create_task(agent._bdi_loop())
await asyncio.sleep(0.01)
task.cancel()
duration = time.time() - start_time
assert duration >= 0.004 # loop slept until deadline

View File

@@ -0,0 +1,91 @@
import asyncio
import sys
import uuid
from unittest.mock import AsyncMock
import pytest
from control_backend.agents.bdi.bdi_program_manager import BDIProgramManager
from control_backend.core.agent_system import InternalMessage
from control_backend.schemas.belief_message import BeliefMessage
from control_backend.schemas.program import BasicNorm, Goal, Phase, Plan, Program
# Fix Windows Proactor loop for zmq
if sys.platform.startswith("win"):
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
def make_valid_program_json(norm="N1", goal="G1") -> str:
return Program(
phases=[
Phase(
id=uuid.uuid4(),
name="Basic Phase",
norms=[
BasicNorm(
id=uuid.uuid4(),
name=norm,
norm=norm,
),
],
goals=[
Goal(
id=uuid.uuid4(),
name=goal,
plan=Plan(
id=uuid.uuid4(),
name="Goal Plan",
steps=[],
),
can_fail=False,
),
],
triggers=[],
),
],
).model_dump_json()
@pytest.mark.skip(reason="Functionality being rebuilt.")
@pytest.mark.asyncio
async def test_send_to_bdi():
manager = BDIProgramManager(name="program_manager_test")
manager.send = AsyncMock()
program = Program.model_validate_json(make_valid_program_json())
await manager._send_to_bdi(program)
assert manager.send.await_count == 1
msg: InternalMessage = manager.send.await_args[0][0]
assert msg.thread == "beliefs"
beliefs = BeliefMessage.model_validate_json(msg.body)
names = {b.name: b.arguments for b in beliefs.beliefs}
assert "norms" in names and names["norms"] == ["N1"]
assert "goals" in names and names["goals"] == ["G1"]
@pytest.mark.asyncio
async def test_receive_programs_valid_and_invalid():
sub = AsyncMock()
sub.recv_multipart.side_effect = [
(b"program", b"{bad json"),
(b"program", make_valid_program_json().encode()),
]
manager = BDIProgramManager(name="program_manager_test")
manager.sub_socket = sub
manager._send_to_bdi = AsyncMock()
try:
# Will give StopAsyncIteration when the predefined `sub.recv_multipart` side-effects run out
await manager._receive_programs()
except StopAsyncIteration:
pass
# Only valid Program should have triggered _send_to_bdi
assert manager._send_to_bdi.await_count == 1
forwarded: Program = manager._send_to_bdi.await_args[0][0]
assert forwarded.phases[0].norms[0].name == "N1"
assert forwarded.phases[0].goals[0].name == "G1"

View File

@@ -86,4 +86,50 @@ async def test_send_beliefs_to_bdi(agent):
sent: InternalMessage = agent.send.call_args.args[0]
assert sent.to == settings.agent_settings.bdi_core_name
assert sent.thread == "beliefs"
assert json.loads(sent.body)["beliefs"] == [belief.model_dump() for belief in beliefs]
assert json.loads(sent.body)["create"] == [belief.model_dump() for belief in beliefs]
@pytest.mark.asyncio
async def test_setup_executes(agent):
"""Covers setup and asserts the agent has a name."""
await agent.setup()
assert agent.name == "belief_collector_agent" # simple property assertion
@pytest.mark.asyncio
async def test_handle_message_unrecognized_type_executes(agent):
"""Covers the else branch for unrecognized message type."""
payload = {"type": "unknown_type"}
msg = make_msg(payload, sender="tester")
# Wrap send to ensure nothing is sent
agent.send = AsyncMock()
await agent.handle_message(msg)
# Assert no messages were sent
agent.send.assert_not_awaited()
@pytest.mark.asyncio
async def test_handle_emo_text_executes(agent):
"""Covers the _handle_emo_text method."""
# The method does nothing, but we can assert it returns None
result = await agent._handle_emo_text({}, "origin")
assert result is None
@pytest.mark.asyncio
async def test_send_beliefs_to_bdi_empty_executes(agent):
"""Covers early return when beliefs are empty."""
agent.send = AsyncMock()
await agent._send_beliefs_to_bdi({})
# Assert that nothing was sent
agent.send.assert_not_awaited()
@pytest.mark.asyncio
async def test_handle_belief_text_invalid_returns_none(agent, mocker):
payload = {"type": "belief_extraction_text", "beliefs": {"user_said": "invalid-argument"}}
result = await agent._handle_belief_text(payload, "origin")
# The method itself returns None
assert result is None

View File

@@ -0,0 +1,346 @@
import json
import uuid
from unittest.mock import AsyncMock, MagicMock, patch
import httpx
import pytest
from control_backend.agents.bdi import TextBeliefExtractorAgent
from control_backend.core.agent_system import InternalMessage
from control_backend.core.config import settings
from control_backend.schemas.belief_message import BeliefMessage
from control_backend.schemas.program import (
ConditionalNorm,
LLMAction,
Phase,
Plan,
Program,
SemanticBelief,
Trigger,
)
@pytest.fixture
def agent():
agent = TextBeliefExtractorAgent("text_belief_agent")
agent.send = AsyncMock()
agent._query_llm = AsyncMock()
return agent
@pytest.fixture
def sample_program():
return Program(
phases=[
Phase(
name="Some phase",
id=uuid.uuid4(),
norms=[
ConditionalNorm(
name="Some norm",
id=uuid.uuid4(),
norm="Use nautical terms.",
critical=False,
condition=SemanticBelief(
name="is_pirate",
id=uuid.uuid4(),
description="The user is a pirate. Perhaps because they say "
"they are, or because they speak like a pirate "
'with terms like "arr".',
),
),
],
goals=[],
triggers=[
Trigger(
name="Some trigger",
id=uuid.uuid4(),
condition=SemanticBelief(
name="no_more_booze",
id=uuid.uuid4(),
description="There is no more alcohol.",
),
plan=Plan(
name="Some plan",
id=uuid.uuid4(),
steps=[
LLMAction(
name="Some action",
id=uuid.uuid4(),
goal="Suggest eating chocolate instead.",
),
],
),
),
],
),
],
)
def make_msg(sender: str, body: str, thread: str | None = None) -> InternalMessage:
return InternalMessage(to="unused", sender=sender, body=body, thread=thread)
@pytest.mark.asyncio
async def test_handle_message_ignores_other_agents(agent):
msg = make_msg("unknown", "some data", None)
await agent.handle_message(msg)
agent.send.assert_not_called() # noqa # `agent.send` has no such property, but we mock it.
@pytest.mark.asyncio
async def test_handle_message_from_transcriber(agent, mock_settings):
transcription = "hello world"
msg = make_msg(mock_settings.agent_settings.transcription_name, transcription, None)
await agent.handle_message(msg)
agent.send.assert_awaited_once() # noqa # `agent.send` has no such property, but we mock it.
sent: InternalMessage = agent.send.call_args.args[0] # noqa
assert sent.to == mock_settings.agent_settings.bdi_belief_collector_name
assert sent.thread == "beliefs"
parsed = json.loads(sent.body)
assert parsed == {"beliefs": {"user_said": [transcription]}, "type": "belief_extraction_text"}
@pytest.mark.asyncio
async def test_process_user_said(agent, mock_settings):
transcription = "this is a test"
await agent._user_said(transcription)
agent.send.assert_awaited_once() # noqa # `agent.send` has no such property, but we mock it.
sent: InternalMessage = agent.send.call_args.args[0] # noqa
assert sent.to == mock_settings.agent_settings.bdi_belief_collector_name
assert sent.thread == "beliefs"
parsed = json.loads(sent.body)
assert parsed["beliefs"]["user_said"] == [transcription]
@pytest.mark.asyncio
async def test_query_llm():
mock_response = MagicMock()
mock_response.json.return_value = {
"choices": [
{
"message": {
"content": "null",
}
}
]
}
mock_client = AsyncMock()
mock_client.post.return_value = mock_response
mock_async_client = MagicMock()
mock_async_client.__aenter__.return_value = mock_client
mock_async_client.__aexit__.return_value = None
with patch(
"control_backend.agents.bdi.text_belief_extractor_agent.httpx.AsyncClient",
return_value=mock_async_client,
):
agent = TextBeliefExtractorAgent("text_belief_agent")
res = await agent._query_llm("hello world", {"type": "null"})
# Response content was set as "null", so should be deserialized as None
assert res is None
@pytest.mark.asyncio
async def test_retry_query_llm_success(agent):
agent._query_llm.return_value = None
res = await agent._retry_query_llm("hello world", {"type": "null"})
agent._query_llm.assert_called_once()
assert res is None
@pytest.mark.asyncio
async def test_retry_query_llm_success_after_failure(agent):
agent._query_llm.side_effect = [KeyError(), "real value"]
res = await agent._retry_query_llm("hello world", {"type": "string"})
assert agent._query_llm.call_count == 2
assert res == "real value"
@pytest.mark.asyncio
async def test_retry_query_llm_failures(agent):
agent._query_llm.side_effect = [KeyError(), KeyError(), KeyError(), "real value"]
res = await agent._retry_query_llm("hello world", {"type": "string"})
assert agent._query_llm.call_count == 3
assert res is None
@pytest.mark.asyncio
async def test_retry_query_llm_fail_immediately(agent):
agent._query_llm.side_effect = [KeyError(), "real value"]
res = await agent._retry_query_llm("hello world", {"type": "string"}, tries=1)
assert agent._query_llm.call_count == 1
assert res is None
@pytest.mark.asyncio
async def test_extracting_beliefs_from_program(agent, sample_program):
assert len(agent.available_beliefs) == 0
await agent.handle_message(
InternalMessage(
to=settings.agent_settings.text_belief_extractor_name,
sender=settings.agent_settings.bdi_program_manager_name,
body=sample_program.model_dump_json(),
),
)
assert len(agent.available_beliefs) == 2
@pytest.mark.asyncio
async def test_handle_invalid_program(agent, sample_program):
agent.available_beliefs.append(sample_program.phases[0].norms[0].condition)
agent.available_beliefs.append(sample_program.phases[0].triggers[0].condition)
assert len(agent.available_beliefs) == 2
await agent.handle_message(
InternalMessage(
to=settings.agent_settings.text_belief_extractor_name,
sender=settings.agent_settings.bdi_program_manager_name,
body=json.dumps({"phases": "Invalid"}),
),
)
assert len(agent.available_beliefs) == 2
@pytest.mark.asyncio
async def test_handle_robot_response(agent):
initial_length = len(agent.conversation.messages)
response = "Hi, I'm Pepper. What's your name?"
await agent.handle_message(
InternalMessage(
to=settings.agent_settings.text_belief_extractor_name,
sender=settings.agent_settings.llm_name,
body=response,
),
)
assert len(agent.conversation.messages) == initial_length + 1
assert agent.conversation.messages[-1].role == "assistant"
assert agent.conversation.messages[-1].content == response
@pytest.mark.asyncio
async def test_simulated_real_turn_with_beliefs(agent, sample_program):
"""Test sending user message to extract beliefs from."""
agent.available_beliefs.append(sample_program.phases[0].norms[0].condition)
agent.available_beliefs.append(sample_program.phases[0].triggers[0].condition)
# Send a user message with the belief that there's no more booze
agent._query_llm.return_value = {"is_pirate": None, "no_more_booze": True}
assert len(agent.conversation.messages) == 0
await agent.handle_message(
InternalMessage(
to=settings.agent_settings.text_belief_extractor_name,
sender=settings.agent_settings.transcription_name,
body="We're all out of schnaps.",
),
)
assert len(agent.conversation.messages) == 1
# There should be a belief set and sent to the BDI core, as well as the user_said belief
assert agent.send.call_count == 2
# First should be the beliefs message
message: InternalMessage = agent.send.call_args_list[0].args[0]
beliefs = BeliefMessage.model_validate_json(message.body)
assert len(beliefs.create) == 1
assert beliefs.create[0].name == "no_more_booze"
@pytest.mark.asyncio
async def test_simulated_real_turn_no_beliefs(agent, sample_program):
"""Test a user message to extract beliefs from, but no beliefs are formed."""
agent.available_beliefs.append(sample_program.phases[0].norms[0].condition)
agent.available_beliefs.append(sample_program.phases[0].triggers[0].condition)
# Send a user message with no new beliefs
agent._query_llm.return_value = {"is_pirate": None, "no_more_booze": None}
await agent.handle_message(
InternalMessage(
to=settings.agent_settings.text_belief_extractor_name,
sender=settings.agent_settings.transcription_name,
body="Hello there!",
),
)
# Only the user_said belief should've been sent
agent.send.assert_called_once()
@pytest.mark.asyncio
async def test_simulated_real_turn_no_new_beliefs(agent, sample_program):
"""
Test a user message to extract beliefs from, but no new beliefs are formed because they already
existed.
"""
agent.available_beliefs.append(sample_program.phases[0].norms[0].condition)
agent.available_beliefs.append(sample_program.phases[0].triggers[0].condition)
agent.beliefs["is_pirate"] = True
# Send a user message with the belief the user is a pirate, still
agent._query_llm.return_value = {"is_pirate": True, "no_more_booze": None}
await agent.handle_message(
InternalMessage(
to=settings.agent_settings.text_belief_extractor_name,
sender=settings.agent_settings.transcription_name,
body="Arr, nice to meet you, matey.",
),
)
# Only the user_said belief should've been sent, as no beliefs have changed
agent.send.assert_called_once()
@pytest.mark.asyncio
async def test_simulated_real_turn_remove_belief(agent, sample_program):
"""
Test a user message to extract beliefs from, but an existing belief is determined no longer to
hold.
"""
agent.available_beliefs.append(sample_program.phases[0].norms[0].condition)
agent.available_beliefs.append(sample_program.phases[0].triggers[0].condition)
agent.beliefs["no_more_booze"] = True
# Send a user message with the belief the user is a pirate, still
agent._query_llm.return_value = {"is_pirate": None, "no_more_booze": False}
await agent.handle_message(
InternalMessage(
to=settings.agent_settings.text_belief_extractor_name,
sender=settings.agent_settings.transcription_name,
body="I found an untouched barrel of wine!",
),
)
# Both user_said and belief change should've been sent
assert agent.send.call_count == 2
# Agent's current beliefs should've changed
assert not agent.beliefs["no_more_booze"]
@pytest.mark.asyncio
async def test_llm_failure_handling(agent, sample_program):
"""
Check that the agent handles failures gracefully without crashing.
"""
agent._query_llm.side_effect = httpx.HTTPError("")
agent.available_beliefs.append(sample_program.phases[0].norms[0].condition)
agent.available_beliefs.append(sample_program.phases[0].triggers[0].condition)
belief_changes = await agent._infer_turn()
assert len(belief_changes) == 0

View File

@@ -1,58 +0,0 @@
import json
from unittest.mock import AsyncMock
import pytest
from control_backend.agents.bdi import (
TextBeliefExtractorAgent,
)
from control_backend.core.agent_system import InternalMessage
@pytest.fixture
def agent():
agent = TextBeliefExtractorAgent("text_belief_agent")
agent.send = AsyncMock()
return agent
def make_msg(sender: str, body: str, thread: str | None = None) -> InternalMessage:
return InternalMessage(to="unused", sender=sender, body=body, thread=thread)
@pytest.mark.asyncio
async def test_handle_message_ignores_other_agents(agent):
msg = make_msg("unknown", "some data", None)
await agent.handle_message(msg)
agent.send.assert_not_called() # noqa # `agent.send` has no such property, but we mock it.
@pytest.mark.asyncio
async def test_handle_message_from_transcriber(agent, mock_settings):
transcription = "hello world"
msg = make_msg(mock_settings.agent_settings.transcription_name, transcription, None)
await agent.handle_message(msg)
agent.send.assert_awaited_once() # noqa # `agent.send` has no such property, but we mock it.
sent: InternalMessage = agent.send.call_args.args[0] # noqa
assert sent.to == mock_settings.agent_settings.bdi_belief_collector_name
assert sent.thread == "beliefs"
parsed = json.loads(sent.body)
assert parsed == {"beliefs": {"user_said": [transcription]}, "type": "belief_extraction_text"}
@pytest.mark.asyncio
async def test_process_transcription_demo(agent, mock_settings):
transcription = "this is a test"
await agent._process_transcription_demo(transcription)
agent.send.assert_awaited_once() # noqa # `agent.send` has no such property, but we mock it.
sent: InternalMessage = agent.send.call_args.args[0] # noqa
assert sent.to == mock_settings.agent_settings.bdi_belief_collector_name
assert sent.thread == "beliefs"
parsed = json.loads(sent.body)
assert parsed["beliefs"]["user_said"] == [transcription]

View File

@@ -67,6 +67,7 @@ async def test_setup_success_connects_and_starts_robot(zmq_context):
address="tcp://localhost:5556",
bind=False,
gesture_data=[],
single_gesture_data=[],
)
agent.add_behavior.assert_called_once()
@@ -354,3 +355,13 @@ async def test_listen_loop_ping_sends_internal(zmq_context):
await agent._listen_loop()
pub_socket.send_multipart.assert_awaited()
@pytest.mark.asyncio
async def test_negotiate_req_socket_none_causes_retry(zmq_context):
agent = RICommunicationAgent("ri_comm")
agent._req_socket = None
result = await agent._negotiate_connection(max_retries=1)
assert result is False

View File

@@ -49,6 +49,9 @@ async def test_llm_processing_success(mock_httpx_client, mock_settings):
agent = LLMAgent("llm_agent")
agent.send = AsyncMock() # Mock the send method to verify replies
mock_logger = MagicMock()
agent.logger = mock_logger
# Simulate receiving a message from BDI
prompt = LLMPromptMessage(text="Hi", norms=[], goals=[])
msg = InternalMessage(
@@ -63,7 +66,7 @@ async def test_llm_processing_success(mock_httpx_client, mock_settings):
# "Hello world." constitutes one sentence/chunk based on punctuation split
# The agent should call send once with the full sentence
assert agent.send.called
args = agent.send.call_args[0][0]
args = agent.send.call_args_list[0][0][0]
assert args.to == mock_settings.agent_settings.bdi_core_name
assert "Hello world." in args.body
@@ -134,3 +137,131 @@ def test_llm_instructions():
text_def = instr_def.build_developer_instruction()
assert "Norms to follow" in text_def
assert "Goals to reach" in text_def
@pytest.mark.asyncio
async def test_handle_message_validation_error_branch_no_send(mock_httpx_client, mock_settings):
"""
Covers the ValidationError branch:
except ValidationError:
self.logger.debug("Prompt message from BDI core is invalid.")
Assert: no message is sent.
"""
agent = LLMAgent("llm_agent")
agent.send = AsyncMock()
# Invalid JSON that triggers ValidationError in LLMPromptMessage
invalid_json = '{"text": "Hi", "wrong_field": 123}' # field not in schema
msg = InternalMessage(
to="llm_agent",
sender=mock_settings.agent_settings.bdi_core_name,
body=invalid_json,
)
await agent.handle_message(msg)
# Should not send any reply
agent.send.assert_not_called()
@pytest.mark.asyncio
async def test_handle_message_ignored_sender_branch_no_send(mock_httpx_client, mock_settings):
"""
Covers the else branch for messages not from BDI core:
else:
self.logger.debug("Message ignored (not from BDI core.")
Assert: no message is sent.
"""
agent = LLMAgent("llm_agent")
agent.send = AsyncMock()
msg = InternalMessage(
to="llm_agent",
sender="some_other_agent", # Not BDI core
body='{"text": "Hi"}',
)
await agent.handle_message(msg)
# Should not send any reply
agent.send.assert_not_called()
@pytest.mark.asyncio
async def test_query_llm_yields_final_tail_chunk(mock_settings):
"""
Covers the branch: if current_chunk: yield current_chunk
Ensure that the last partial chunk is emitted.
"""
agent = LLMAgent("llm_agent")
agent.send = AsyncMock()
agent.logger = MagicMock()
agent.logger.llm = MagicMock()
# Patch _stream_query_llm to yield tokens that do NOT end with punctuation
async def fake_stream(messages):
yield "Hello"
yield " world" # No punctuation to trigger the normal chunking
agent._stream_query_llm = fake_stream
prompt = LLMPromptMessage(text="Hi", norms=[], goals=[])
# Collect chunks yielded
chunks = []
async for chunk in agent._query_llm(prompt.text, prompt.norms, prompt.goals):
chunks.append(chunk)
# The final chunk should be yielded
assert chunks[-1] == "Hello world"
assert any("Hello" in c for c in chunks)
@pytest.mark.asyncio
async def test_stream_query_llm_skips_non_data_lines(mock_httpx_client, mock_settings):
"""
Covers: if not line or not line.startswith("data: "): continue
Feed lines that are empty or do not start with 'data:' and check they are skipped.
"""
# Mock response
mock_response = MagicMock()
mock_response.raise_for_status = MagicMock()
lines = [
"", # empty line
"not data", # invalid prefix
'data: {"choices": [{"delta": {"content": "Hi"}}]}',
"data: [DONE]",
]
async def aiter_lines_gen():
for line in lines:
yield line
mock_response.aiter_lines.side_effect = aiter_lines_gen
# Proper async context manager for stream
mock_stream_context = MagicMock()
mock_stream_context.__aenter__ = AsyncMock(return_value=mock_response)
mock_stream_context.__aexit__ = AsyncMock(return_value=None)
# Make stream return the async context manager
mock_httpx_client.stream = MagicMock(return_value=mock_stream_context)
agent = LLMAgent("llm_agent")
agent.send = AsyncMock()
# Patch settings for local LLM URL
with patch("control_backend.agents.llm.llm_agent.settings") as mock_sett:
mock_sett.llm_settings.local_llm_url = "http://localhost"
mock_sett.llm_settings.local_llm_model = "test-model"
# Collect tokens
tokens = []
async for token in agent._stream_query_llm([]):
tokens.append(token)
# Only the valid 'data:' line should yield content
assert tokens == ["Hi"]

View File

@@ -120,3 +120,83 @@ def test_mlx_recognizer():
mlx_mock.transcribe.return_value = {"text": "Hi"}
res = rec.recognize_speech(np.zeros(10))
assert res == "Hi"
@pytest.mark.asyncio
async def test_transcription_loop_continues_after_error(mock_zmq_context):
mock_sub = MagicMock()
mock_sub.recv = AsyncMock()
mock_zmq_context.instance.return_value.socket.return_value = mock_sub
fake_audio = np.zeros(16000, dtype=np.float32).tobytes()
mock_sub.recv.side_effect = [
fake_audio, # first iteration → recognizer fails
asyncio.CancelledError(), # second iteration → stop loop
]
with patch.object(SpeechRecognizer, "best_type") as mock_best:
mock_recognizer = MagicMock()
mock_recognizer.recognize_speech.side_effect = RuntimeError("fail")
mock_best.return_value = mock_recognizer
agent = TranscriptionAgent("tcp://in")
agent._running = True # ← REQUIRED to enter the loop
agent.send = AsyncMock() # should never be called
agent.add_behavior = AsyncMock() # match other tests
await agent.setup()
try:
await agent._transcribing_loop()
except asyncio.CancelledError:
pass
# recognizer failed, so we should never send anything
agent.send.assert_not_called()
# recv must have been called twice (audio then CancelledError)
assert mock_sub.recv.call_count == 2
@pytest.mark.asyncio
async def test_transcription_continue_branch_when_empty(mock_zmq_context):
mock_sub = MagicMock()
mock_sub.recv = AsyncMock()
mock_zmq_context.instance.return_value.socket.return_value = mock_sub
# First recv → audio chunk
# Second recv → Cancel loop → stop iteration
fake_audio = np.zeros(16000, dtype=np.float32).tobytes()
mock_sub.recv.side_effect = [fake_audio, asyncio.CancelledError()]
with patch.object(SpeechRecognizer, "best_type") as mock_best:
mock_recognizer = MagicMock()
mock_recognizer.recognize_speech.return_value = "" # <— triggers the continue branch
mock_best.return_value = mock_recognizer
agent = TranscriptionAgent("tcp://in")
# Make loop runnable
agent._running = True
agent.send = AsyncMock()
agent.add_behavior = AsyncMock()
await agent.setup()
# Execute loop manually
try:
await agent._transcribing_loop()
except asyncio.CancelledError:
pass
# → Because of "continue", NO sending should occur
agent.send.assert_not_called()
# → Continue was hit, so we must have read exactly 2 times:
# - first audio
# - second CancelledError
assert mock_sub.recv.call_count == 2
# → recognizer was called once (first iteration)
assert mock_recognizer.recognize_speech.call_count == 1

View File

@@ -1,7 +1,8 @@
from unittest.mock import AsyncMock, MagicMock
from unittest.mock import AsyncMock, MagicMock, patch
import numpy as np
import pytest
import zmq
from control_backend.agents.perception.vad_agent import VADAgent
@@ -123,3 +124,44 @@ async def test_no_data(audio_out_socket, vad_agent):
audio_out_socket.send.assert_not_called()
assert len(vad_agent.audio_buffer) == 0
@pytest.mark.asyncio
async def test_vad_model_load_failure_stops_agent(vad_agent):
"""
Test that if loading the VAD model raises an Exception, it is caught,
the agent logs an exception, stops itself, and setup returns.
"""
# Patch torch.hub.load to raise an exception
with patch(
"control_backend.agents.perception.vad_agent.torch.hub.load",
side_effect=Exception("model fail"),
):
# Patch stop to an AsyncMock so we can check it was awaited
vad_agent.stop = AsyncMock()
result = await vad_agent.setup()
# Assert stop was called
vad_agent.stop.assert_awaited_once()
# Assert setup returned None
assert result is None
@pytest.mark.asyncio
async def test_audio_out_bind_failure_sets_none_and_logs(vad_agent, caplog):
"""
Test that if binding the output socket raises ZMQBindError,
audio_out_socket is set to None, None is returned, and an error is logged.
"""
mock_socket = MagicMock()
mock_socket.bind_to_random_port.side_effect = zmq.ZMQBindError()
with patch("control_backend.agents.perception.vad_agent.azmq.Context.instance") as mock_ctx:
mock_ctx.return_value.socket.return_value = mock_socket
with caplog.at_level("ERROR"):
port = vad_agent._connect_audio_out_socket()
assert port is None
assert vad_agent.audio_out_socket is None
assert caplog.text is not None

View File

@@ -0,0 +1,146 @@
import asyncio
import json
from unittest.mock import AsyncMock, MagicMock
import pytest
from control_backend.agents.user_interrupt.user_interrupt_agent import UserInterruptAgent
from control_backend.core.agent_system import InternalMessage
from control_backend.core.config import settings
from control_backend.schemas.ri_message import RIEndpoint
@pytest.fixture
def agent():
agent = UserInterruptAgent(name="user_interrupt_agent")
agent.send = AsyncMock()
agent.logger = MagicMock()
agent.sub_socket = AsyncMock()
return agent
@pytest.mark.asyncio
async def test_send_to_speech_agent(agent):
"""Verify speech command format."""
await agent._send_to_speech_agent("Hello World")
agent.send.assert_awaited_once()
sent_msg: InternalMessage = agent.send.call_args.args[0]
assert sent_msg.to == settings.agent_settings.robot_speech_name
body = json.loads(sent_msg.body)
assert body["data"] == "Hello World"
assert body["is_priority"] is True
@pytest.mark.asyncio
async def test_send_to_gesture_agent(agent):
"""Verify gesture command format."""
await agent._send_to_gesture_agent("wave_hand")
agent.send.assert_awaited_once()
sent_msg: InternalMessage = agent.send.call_args.args[0]
assert sent_msg.to == settings.agent_settings.robot_gesture_name
body = json.loads(sent_msg.body)
assert body["data"] == "wave_hand"
assert body["is_priority"] is True
assert body["endpoint"] == RIEndpoint.GESTURE_SINGLE.value
@pytest.mark.asyncio
async def test_send_to_program_manager(agent):
"""Verify belief update format."""
context_str = "2"
await agent._send_to_program_manager(context_str)
agent.send.assert_awaited_once()
sent_msg: InternalMessage = agent.send.call_args.args[0]
assert sent_msg.to == settings.agent_settings.bdi_program_manager_name
assert sent_msg.thread == "belief_override_id"
body = json.loads(sent_msg.body)
assert body["belief"] == context_str
@pytest.mark.asyncio
async def test_receive_loop_routing_success(agent):
"""
Test that the loop correctly:
1. Receives 'button_pressed' topic from ZMQ
2. Parses the JSON payload to find 'type' and 'context'
3. Calls the correct handler method based on 'type'
"""
# Prepare JSON payloads as bytes
payload_speech = json.dumps({"type": "speech", "context": "Hello Speech"}).encode()
payload_gesture = json.dumps({"type": "gesture", "context": "Hello Gesture"}).encode()
payload_override = json.dumps({"type": "override", "context": "Hello Override"}).encode()
agent.sub_socket.recv_multipart.side_effect = [
(b"button_pressed", payload_speech),
(b"button_pressed", payload_gesture),
(b"button_pressed", payload_override),
asyncio.CancelledError, # Stop the infinite loop
]
agent._send_to_speech_agent = AsyncMock()
agent._send_to_gesture_agent = AsyncMock()
agent._send_to_program_manager = AsyncMock()
try:
await agent._receive_button_event()
except asyncio.CancelledError:
pass
await asyncio.sleep(0)
# Speech
agent._send_to_speech_agent.assert_awaited_once_with("Hello Speech")
# Gesture
agent._send_to_gesture_agent.assert_awaited_once_with("Hello Gesture")
# Override
agent._send_to_program_manager.assert_awaited_once_with("Hello Override")
assert agent._send_to_speech_agent.await_count == 1
assert agent._send_to_gesture_agent.await_count == 1
assert agent._send_to_program_manager.await_count == 1
@pytest.mark.asyncio
async def test_receive_loop_unknown_type(agent):
"""Test that unknown 'type' values in the JSON log a warning and do not crash."""
# Prepare a payload with an unknown type
payload_unknown = json.dumps({"type": "unknown_thing", "context": "some_data"}).encode()
agent.sub_socket.recv_multipart.side_effect = [
(b"button_pressed", payload_unknown),
asyncio.CancelledError,
]
agent._send_to_speech_agent = AsyncMock()
agent._send_to_gesture_agent = AsyncMock()
agent._send_to_belief_collector = AsyncMock()
try:
await agent._receive_button_event()
except asyncio.CancelledError:
pass
await asyncio.sleep(0)
# Ensure no handlers were called
agent._send_to_speech_agent.assert_not_called()
agent._send_to_gesture_agent.assert_not_called()
agent._send_to_belief_collector.assert_not_called()
agent.logger.warning.assert_called_with(
"Received button press with unknown type '%s' (context: '%s').",
"unknown_thing",
"some_data",
)

View File

@@ -0,0 +1,63 @@
from unittest.mock import patch
import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
from starlette.responses import StreamingResponse
from control_backend.api.v1.endpoints import logs
@pytest.fixture
def client():
"""TestClient with logs router included."""
app = FastAPI()
app.include_router(logs.router)
return TestClient(app)
@pytest.mark.asyncio
async def test_log_stream_endpoint_lines(client):
"""Call /logs/stream with a mocked ZMQ socket to cover all lines."""
# Dummy socket to mock ZMQ behavior
class DummySocket:
def __init__(self):
self.subscribed = []
self.connected = False
self.recv_count = 0
def subscribe(self, topic):
self.subscribed.append(topic)
def connect(self, addr):
self.connected = True
async def recv_multipart(self):
# Return one message, then stop generator
if self.recv_count == 0:
self.recv_count += 1
return (b"INFO", b"test message")
else:
raise StopAsyncIteration
dummy_socket = DummySocket()
# Patch Context.instance().socket() to return dummy socket
with patch("control_backend.api.v1.endpoints.logs.Context.instance") as mock_context:
mock_context.return_value.socket.return_value = dummy_socket
# Call the endpoint directly
response = await logs.log_stream()
assert isinstance(response, StreamingResponse)
# Fetch one chunk from the generator
gen = response.body_iterator
chunk = await gen.__anext__()
if isinstance(chunk, bytes):
chunk = chunk.decode("utf-8")
assert "data:" in chunk
# Optional: assert subscribe/connect were called
assert dummy_socket.subscribed # at least some log levels subscribed
assert dummy_socket.connected # connect was called

View File

@@ -0,0 +1,45 @@
import json
import pytest
from fastapi.testclient import TestClient
from control_backend.api.v1.endpoints import message
@pytest.fixture
def client():
"""FastAPI TestClient for the message router."""
from fastapi import FastAPI
app = FastAPI()
app.include_router(message.router)
return TestClient(app)
def test_receive_message_post(client, monkeypatch):
"""Test POST /message endpoint sends message to pub socket."""
# Dummy pub socket to capture sent messages
class DummyPubSocket:
def __init__(self):
self.sent = []
async def send_multipart(self, msg):
self.sent.append(msg)
dummy_socket = DummyPubSocket()
# Patch app.state.endpoints_pub_socket
client.app.state.endpoints_pub_socket = dummy_socket
data = {"message": "Hello world"}
response = client.post("/message", json=data)
assert response.status_code == 202
assert response.json() == {"status": "Message received"}
# Ensure the message was sent via pub_socket
assert len(dummy_socket.sent) == 1
topic, body = dummy_socket.sent[0]
parsed = json.loads(body.decode("utf-8"))
assert parsed["message"] == "Hello world"

View File

@@ -1,4 +1,5 @@
import json
import uuid
from unittest.mock import AsyncMock
import pytest
@@ -6,7 +7,7 @@ from fastapi import FastAPI
from fastapi.testclient import TestClient
from control_backend.api.v1.endpoints import program
from control_backend.schemas.program import Program
from control_backend.schemas.program import BasicNorm, Goal, Phase, Plan, Program
@pytest.fixture
@@ -25,29 +26,37 @@ def client(app):
def make_valid_program_dict():
"""Helper to create a valid Program JSON structure."""
return {
"phases": [
{
"id": "phase1",
"label": "basephase",
"norms": [{"id": "n1", "label": "norm", "norm": "be nice"}],
"goals": [
{"id": "g1", "label": "goal", "description": "test goal", "achieved": False}
# Converting to JSON using Pydantic because it knows how to convert a UUID object
program_json_str = Program(
phases=[
Phase(
id=uuid.uuid4(),
name="Basic Phase",
norms=[
BasicNorm(
id=uuid.uuid4(),
name="Some norm",
norm="Do normal.",
),
],
"triggers": [
{
"id": "t1",
"label": "trigger",
"type": "keywords",
"keywords": [
{"id": "kw1", "keyword": "keyword1"},
{"id": "kw2", "keyword": "keyword2"},
],
},
goals=[
Goal(
id=uuid.uuid4(),
name="Some goal",
plan=Plan(
id=uuid.uuid4(),
name="Goal Plan",
steps=[],
),
can_fail=False,
),
],
}
]
}
triggers=[],
),
],
).model_dump_json()
# Converting back to a dict because that's what's expected
return json.loads(program_json_str)
def test_receive_program_success(client):
@@ -71,7 +80,8 @@ def test_receive_program_success(client):
sent_bytes = args[0][1]
sent_obj = json.loads(sent_bytes.decode())
expected_obj = Program.model_validate(program_dict).model_dump()
# Converting to JSON using Pydantic because it knows how to handle UUIDs
expected_obj = json.loads(Program.model_validate(program_dict).model_dump_json())
assert sent_obj == expected_obj

View File

@@ -1,3 +1,4 @@
# tests/test_robot_endpoints.py
import json
from unittest.mock import AsyncMock, MagicMock, patch
@@ -29,7 +30,7 @@ def client(app):
@pytest.fixture
def mock_zmq_context():
"""Mock the ZMQ context."""
"""Mock the ZMQ context used by the endpoint module."""
with patch("control_backend.api.v1.endpoints.robot.Context.instance") as mock_context:
context_instance = MagicMock()
mock_context.return_value = context_instance
@@ -38,13 +39,13 @@ def mock_zmq_context():
@pytest.fixture
def mock_sockets(mock_zmq_context):
"""Mock ZMQ sockets."""
"""Optional helper if you want both a sub and req/push socket available."""
mock_sub_socket = AsyncMock(spec=zmq.asyncio.Socket)
mock_pub_socket = AsyncMock(spec=zmq.asyncio.Socket)
mock_req_socket = AsyncMock(spec=zmq.asyncio.Socket)
mock_zmq_context.socket.return_value = mock_sub_socket
return {"sub": mock_sub_socket, "pub": mock_pub_socket}
return {"sub": mock_sub_socket, "req": mock_req_socket}
def test_receive_speech_command_success(client):
@@ -61,11 +62,11 @@ def test_receive_speech_command_success(client):
speech_command = SpeechCommand(**command_data)
# Act
response = client.post("/command", json=command_data)
response = client.post("/command/speech", json=command_data)
# Assert
assert response.status_code == 202
assert response.json() == {"status": "Command received"}
assert response.json() == {"status": "Speech command received"}
# Verify that the ZMQ socket was used correctly
mock_pub_socket.send_multipart.assert_awaited_once_with(
@@ -75,9 +76,8 @@ def test_receive_speech_command_success(client):
def test_receive_gesture_command_success(client):
"""
Test for successful reception of a command. Ensures the status code is 202 and the response body
is correct. It also verifies that the ZeroMQ socket's send_multipart method is called with the
expected data.
Test for successful reception of a command that is a gesture command.
Ensures the status code is 202 and the response body is correct.
"""
# Arrange
mock_pub_socket = AsyncMock()
@@ -87,11 +87,11 @@ def test_receive_gesture_command_success(client):
gesture_command = GestureCommand(**command_data)
# Act
response = client.post("/command", json=command_data)
response = client.post("/command/gesture", json=command_data)
# Assert
assert response.status_code == 202
assert response.json() == {"status": "Command received"}
assert response.json() == {"status": "Gesture command received"}
# Verify that the ZMQ socket was used correctly
mock_pub_socket.send_multipart.assert_awaited_once_with(
@@ -99,13 +99,23 @@ def test_receive_gesture_command_success(client):
)
def test_receive_command_invalid_payload(client):
def test_receive_speech_command_invalid_payload(client):
"""
Test invalid data handling (schema validation).
"""
# Missing required field(s)
bad_payload = {"invalid": "data"}
response = client.post("/command", json=bad_payload)
response = client.post("/command/speech", json=bad_payload)
assert response.status_code == 422 # validation error
def test_receive_gesture_command_invalid_payload(client):
"""
Test invalid data handling (schema validation).
"""
# Missing required field(s)
bad_payload = {"invalid": "data"}
response = client.post("/command/gesture", json=bad_payload)
assert response.status_code == 422 # validation error
@@ -116,7 +126,9 @@ def test_ping_check_returns_none(client):
assert response.json() is None
# TODO: Convert these mock sockets to the fixture.
# ----------------------------
# ping_stream tests (unchanged behavior)
# ----------------------------
@pytest.mark.asyncio
async def test_ping_stream_yields_ping_event(monkeypatch):
"""Test that ping_stream yields a proper SSE message when a ping is received."""
@@ -129,6 +141,11 @@ async def test_ping_stream_yields_ping_event(monkeypatch):
mock_context.socket.return_value = mock_sub_socket
monkeypatch.setattr(robot.Context, "instance", lambda: mock_context)
# patch settings address used by ping_stream
mock_settings = MagicMock()
mock_settings.zmq_settings.internal_sub_address = "tcp://localhost:5555"
monkeypatch.setattr(robot, "settings", mock_settings)
mock_request = AsyncMock()
mock_request.is_disconnected = AsyncMock(side_effect=[False, True])
@@ -142,7 +159,7 @@ async def test_ping_stream_yields_ping_event(monkeypatch):
with pytest.raises(StopAsyncIteration):
await anext(generator)
mock_sub_socket.connect.assert_called_once()
mock_sub_socket.connect.assert_called_once_with("tcp://localhost:5555")
mock_sub_socket.setsockopt.assert_called_once_with(robot.zmq.SUBSCRIBE, b"ping")
mock_sub_socket.recv_multipart.assert_awaited()
@@ -159,6 +176,10 @@ async def test_ping_stream_handles_timeout(monkeypatch):
mock_context.socket.return_value = mock_sub_socket
monkeypatch.setattr(robot.Context, "instance", lambda: mock_context)
mock_settings = MagicMock()
mock_settings.zmq_settings.internal_sub_address = "tcp://localhost:5555"
monkeypatch.setattr(robot, "settings", mock_settings)
mock_request = AsyncMock()
mock_request.is_disconnected = AsyncMock(return_value=True)
@@ -168,7 +189,7 @@ async def test_ping_stream_handles_timeout(monkeypatch):
with pytest.raises(StopAsyncIteration):
await anext(generator)
mock_sub_socket.connect.assert_called_once()
mock_sub_socket.connect.assert_called_once_with("tcp://localhost:5555")
mock_sub_socket.setsockopt.assert_called_once_with(robot.zmq.SUBSCRIBE, b"ping")
mock_sub_socket.recv_multipart.assert_awaited()
@@ -187,6 +208,10 @@ async def test_ping_stream_yields_json_values(monkeypatch):
mock_context.socket.return_value = mock_sub_socket
monkeypatch.setattr(robot.Context, "instance", lambda: mock_context)
mock_settings = MagicMock()
mock_settings.zmq_settings.internal_sub_address = "tcp://localhost:5555"
monkeypatch.setattr(robot, "settings", mock_settings)
mock_request = AsyncMock()
mock_request.is_disconnected = AsyncMock(side_effect=[False, True])
@@ -199,183 +224,135 @@ async def test_ping_stream_yields_json_values(monkeypatch):
assert "connected" in event_text
assert "true" in event_text
mock_sub_socket.connect.assert_called_once()
mock_sub_socket.connect.assert_called_once_with("tcp://localhost:5555")
mock_sub_socket.setsockopt.assert_called_once_with(robot.zmq.SUBSCRIBE, b"ping")
mock_sub_socket.recv_multipart.assert_awaited()
# New tests for get_available_gesture_tags endpoint
# ----------------------------
# Updated get_available_gesture_tags tests (REQ socket on tcp://localhost:7788)
# ----------------------------
@pytest.mark.asyncio
async def test_get_available_gesture_tags_success(client, monkeypatch):
"""
Test successful retrieval of available gesture tags.
Test successful retrieval of available gesture tags using a REQ socket.
"""
# Arrange
mock_sub_socket = AsyncMock()
mock_sub_socket.connect = MagicMock()
mock_sub_socket.setsockopt = MagicMock()
# Simulate a response with gesture tags
mock_req_socket = AsyncMock(spec=zmq.asyncio.Socket)
mock_req_socket.connect = MagicMock()
mock_req_socket.send = AsyncMock()
response_data = {"tags": ["wave", "nod", "point", "dance"]}
mock_sub_socket.recv_multipart = AsyncMock(
return_value=[b"get_gestures", json.dumps(response_data).encode()]
)
mock_req_socket.recv = AsyncMock(return_value=json.dumps(response_data).encode())
mock_context = MagicMock()
mock_context.socket.return_value = mock_sub_socket
mock_context.socket.return_value = mock_req_socket
monkeypatch.setattr(robot.Context, "instance", lambda: mock_context)
mock_pub_socket = AsyncMock()
client.app.state.endpoints_pub_socket = mock_pub_socket
# Mock settings
mock_settings = MagicMock()
mock_settings.zmq_settings.internal_sub_address = "tcp://localhost:5555"
monkeypatch.setattr(robot, "settings", mock_settings)
# Mock logger to avoid actual logging
mock_logger = MagicMock()
monkeypatch.setattr(robot.logger, "debug", mock_logger)
# Replace logger methods to avoid noisy logs in tests
monkeypatch.setattr(robot.logger, "debug", MagicMock())
monkeypatch.setattr(robot.logger, "error", MagicMock())
# Act
response = client.get("/get_available_gesture_tags")
response = client.get("/commands/gesture/tags")
# Assert
assert response.status_code == 200
assert response.json() == {"available_gesture_tags": ["wave", "nod", "point", "dance"]}
# Verify ZeroMQ interactions
mock_sub_socket.connect.assert_called_once_with("tcp://localhost:5555")
mock_sub_socket.setsockopt.assert_called_once_with(robot.zmq.SUBSCRIBE, b"get_gestures")
mock_pub_socket.send_multipart.assert_awaited_once_with([b"send_gestures", b""])
mock_sub_socket.recv_multipart.assert_awaited_once()
# Verify ZeroMQ REQ interactions
mock_req_socket.connect.assert_called_once_with("tcp://localhost:7788")
mock_req_socket.send.assert_awaited_once_with(b"None")
mock_req_socket.recv.assert_awaited_once()
@pytest.mark.asyncio
async def test_get_available_gesture_tags_with_amount(client, monkeypatch):
"""
Test retrieval of gesture tags with a specific amount parameter.
This tests the TODO in the endpoint about getting a certain amount from the UI.
The endpoint currently ignores the 'amount' TODO, so behavior is the same as 'success'.
This test asserts that the endpoint still sends b"None" and returns the tags.
"""
# Arrange
mock_sub_socket = AsyncMock()
mock_sub_socket.connect = MagicMock()
mock_sub_socket.setsockopt = MagicMock()
# Simulate a response with gesture tags
mock_req_socket = AsyncMock(spec=zmq.asyncio.Socket)
mock_req_socket.connect = MagicMock()
mock_req_socket.send = AsyncMock()
response_data = {"tags": ["wave", "nod"]}
mock_sub_socket.recv_multipart = AsyncMock(
return_value=[b"get_gestures", json.dumps(response_data).encode()]
)
mock_req_socket.recv = AsyncMock(return_value=json.dumps(response_data).encode())
mock_context = MagicMock()
mock_context.socket.return_value = mock_sub_socket
mock_context.socket.return_value = mock_req_socket
monkeypatch.setattr(robot.Context, "instance", lambda: mock_context)
mock_pub_socket = AsyncMock()
client.app.state.endpoints_pub_socket = mock_pub_socket
monkeypatch.setattr(robot.logger, "debug", MagicMock())
monkeypatch.setattr(robot.logger, "error", MagicMock())
# Mock settings
mock_settings = MagicMock()
mock_settings.zmq_settings.internal_sub_address = "tcp://localhost:5555"
monkeypatch.setattr(robot, "settings", mock_settings)
# Mock logger
mock_logger = MagicMock()
monkeypatch.setattr(robot.logger, "debug", mock_logger)
# Act - Note: The endpoint currently doesn't support query parameters for amount,
# but we're testing what happens if the UI sends an amount (the TODO in the code)
# For now, we test the current behavior
response = client.get("/get_available_gesture_tags")
# Act
response = client.get("/commands/gesture/tags")
# Assert
assert response.status_code == 200
assert response.json() == {"available_gesture_tags": ["wave", "nod"]}
# The endpoint currently doesn't use the amount parameter, so it should send empty bytes
mock_pub_socket.send_multipart.assert_awaited_once_with([b"send_gestures", b""])
mock_req_socket.connect.assert_called_once_with("tcp://localhost:7788")
mock_req_socket.send.assert_awaited_once_with(b"None")
@pytest.mark.asyncio
async def test_get_available_gesture_tags_timeout(client, monkeypatch):
"""
Test timeout scenario when fetching gesture tags.
Test timeout scenario when fetching gesture tags. Endpoint should handle TimeoutError
and return an empty list while logging the timeout.
"""
# Arrange
mock_sub_socket = AsyncMock()
mock_sub_socket.connect = MagicMock()
mock_sub_socket.setsockopt = MagicMock()
# Simulate a timeout
mock_sub_socket.recv_multipart = AsyncMock(side_effect=TimeoutError)
mock_req_socket = AsyncMock(spec=zmq.asyncio.Socket)
mock_req_socket.connect = MagicMock()
mock_req_socket.send = AsyncMock()
mock_req_socket.recv = AsyncMock(side_effect=TimeoutError)
mock_context = MagicMock()
mock_context.socket.return_value = mock_sub_socket
mock_context.socket.return_value = mock_req_socket
monkeypatch.setattr(robot.Context, "instance", lambda: mock_context)
mock_pub_socket = AsyncMock()
client.app.state.endpoints_pub_socket = mock_pub_socket
# Mock settings
mock_settings = MagicMock()
mock_settings.zmq_settings.internal_sub_address = "tcp://localhost:5555"
monkeypatch.setattr(robot, "settings", mock_settings)
# Mock logger to verify debug message is logged
mock_logger = MagicMock()
monkeypatch.setattr(robot.logger, "debug", mock_logger)
# Patch logger.debug so we can assert it was called with the expected message
mock_debug = MagicMock()
monkeypatch.setattr(robot.logger, "debug", mock_debug)
monkeypatch.setattr(robot.logger, "error", MagicMock())
# Act
response = client.get("/get_available_gesture_tags")
response = client.get("/commands/gesture/tags")
# Assert
assert response.status_code == 200
# On timeout, body becomes b"" and json.loads(b"") raises JSONDecodeError
# But looking at the endpoint code, it will try to parse empty bytes which will fail
# Let's check what actually happens
assert response.json() == {"available_gesture_tags": []}
# Verify the timeout was logged
mock_logger.assert_called_once_with("got timeout error fetching gestures")
# Verify the timeout was logged using the exact string from the endpoint code
mock_debug.assert_called_once_with("Got timeout error fetching gestures.")
# Verify ZeroMQ interactions
mock_sub_socket.connect.assert_called_once_with("tcp://localhost:5555")
mock_sub_socket.setsockopt.assert_called_once_with(robot.zmq.SUBSCRIBE, b"get_gestures")
mock_pub_socket.send_multipart.assert_awaited_once_with([b"send_gestures", b""])
mock_sub_socket.recv_multipart.assert_awaited_once()
mock_req_socket.connect.assert_called_once_with("tcp://localhost:7788")
mock_req_socket.send.assert_awaited_once_with(b"None")
mock_req_socket.recv.assert_awaited_once()
@pytest.mark.asyncio
async def test_get_available_gesture_tags_empty_response(client, monkeypatch):
"""
Test scenario when response contains no tags.
Test scenario when response contains an empty 'tags' list.
"""
# Arrange
mock_sub_socket = AsyncMock()
mock_sub_socket.connect = MagicMock()
mock_sub_socket.setsockopt = MagicMock()
# Simulate a response with empty tags
mock_req_socket = AsyncMock(spec=zmq.asyncio.Socket)
mock_req_socket.connect = MagicMock()
mock_req_socket.send = AsyncMock()
response_data = {"tags": []}
mock_sub_socket.recv_multipart = AsyncMock(
return_value=[b"get_gestures", json.dumps(response_data).encode()]
)
mock_req_socket.recv = AsyncMock(return_value=json.dumps(response_data).encode())
mock_context = MagicMock()
mock_context.socket.return_value = mock_sub_socket
mock_context.socket.return_value = mock_req_socket
monkeypatch.setattr(robot.Context, "instance", lambda: mock_context)
mock_pub_socket = AsyncMock()
client.app.state.endpoints_pub_socket = mock_pub_socket
# Mock settings
mock_settings = MagicMock()
mock_settings.zmq_settings.internal_sub_address = "tcp://localhost:5555"
monkeypatch.setattr(robot, "settings", mock_settings)
monkeypatch.setattr(robot.logger, "debug", MagicMock())
monkeypatch.setattr(robot.logger, "error", MagicMock())
# Act
response = client.get("/get_available_gesture_tags")
response = client.get("/commands/gesture/tags")
# Assert
assert response.status_code == 200
@@ -388,65 +365,51 @@ async def test_get_available_gesture_tags_missing_tags_key(client, monkeypatch):
Test scenario when response JSON doesn't contain 'tags' key.
"""
# Arrange
mock_sub_socket = AsyncMock()
mock_sub_socket.connect = MagicMock()
mock_sub_socket.setsockopt = MagicMock()
# Simulate a response without 'tags' key
mock_req_socket = AsyncMock(spec=zmq.asyncio.Socket)
mock_req_socket.connect = MagicMock()
mock_req_socket.send = AsyncMock()
response_data = {"some_other_key": "value"}
mock_sub_socket.recv_multipart = AsyncMock(
return_value=[b"get_gestures", json.dumps(response_data).encode()]
)
mock_req_socket.recv = AsyncMock(return_value=json.dumps(response_data).encode())
mock_context = MagicMock()
mock_context.socket.return_value = mock_sub_socket
mock_context.socket.return_value = mock_req_socket
monkeypatch.setattr(robot.Context, "instance", lambda: mock_context)
mock_pub_socket = AsyncMock()
client.app.state.endpoints_pub_socket = mock_pub_socket
# Mock settings
mock_settings = MagicMock()
mock_settings.zmq_settings.internal_sub_address = "tcp://localhost:5555"
monkeypatch.setattr(robot, "settings", mock_settings)
monkeypatch.setattr(robot.logger, "debug", MagicMock())
monkeypatch.setattr(robot.logger, "error", MagicMock())
# Act
response = client.get("/get_available_gesture_tags")
response = client.get("/commands/gesture/tags")
# Assert
assert response.status_code == 200
# .get("tags", []) should return empty list if 'tags' key is missing
assert response.json() == {"available_gesture_tags": []}
@pytest.mark.asyncio
async def test_get_available_gesture_tags_invalid_json(client, monkeypatch):
"""
Test scenario when response contains invalid JSON.
Test scenario when response contains invalid JSON. Endpoint should log the error
and return an empty list.
"""
# Arrange
mock_sub_socket = AsyncMock()
mock_sub_socket.connect = MagicMock()
mock_sub_socket.setsockopt = MagicMock()
# Simulate a response with invalid JSON
mock_sub_socket.recv_multipart = AsyncMock(return_value=[b"get_gestures", b"invalid json"])
mock_req_socket = AsyncMock(spec=zmq.asyncio.Socket)
mock_req_socket.connect = MagicMock()
mock_req_socket.send = AsyncMock()
mock_req_socket.recv = AsyncMock(return_value=b"invalid json")
mock_context = MagicMock()
mock_context.socket.return_value = mock_sub_socket
mock_context.socket.return_value = mock_req_socket
monkeypatch.setattr(robot.Context, "instance", lambda: mock_context)
mock_pub_socket = AsyncMock()
client.app.state.endpoints_pub_socket = mock_pub_socket
# Mock settings
mock_settings = MagicMock()
mock_settings.zmq_settings.internal_sub_address = "tcp://localhost:5555"
monkeypatch.setattr(robot, "settings", mock_settings)
mock_error = MagicMock()
monkeypatch.setattr(robot.logger, "error", mock_error)
monkeypatch.setattr(robot.logger, "debug", MagicMock())
# Act
response = client.get("/get_available_gesture_tags")
response = client.get("/commands/gesture/tags")
# Assert - invalid JSON should raise an exception
# Assert - invalid JSON should lead to empty list and error log invocation
assert response.status_code == 200
assert response.json() == {"available_gesture_tags": []}
assert mock_error.call_count == 1

View File

@@ -0,0 +1,16 @@
from fastapi.routing import APIRoute
from control_backend.api.v1.router import api_router # <--- corrected import
def test_router_includes_expected_paths():
"""Ensure api_router includes main router prefixes."""
routes = [r for r in api_router.routes if isinstance(r, APIRoute)]
paths = [r.path for r in routes]
# Ensure at least one route under each prefix exists
assert any(p.startswith("/robot") for p in paths)
assert any(p.startswith("/message") for p in paths)
assert any(p.startswith("/sse") for p in paths)
assert any(p.startswith("/logs") for p in paths)
assert any(p.startswith("/program") for p in paths)

View File

@@ -0,0 +1,24 @@
import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
from control_backend.api.v1.endpoints import sse
@pytest.fixture
def app():
app = FastAPI()
app.include_router(sse.router)
return app
@pytest.fixture
def client(app):
return TestClient(app)
def test_sse_route_exists(client):
"""Minimal smoke test to ensure /sse route exists and responds."""
response = client.get("/sse")
# Since implementation is not done, we only assert it doesn't crash
assert response.status_code == 200

View File

@@ -2,7 +2,7 @@
import asyncio
import logging
from unittest.mock import AsyncMock
from unittest.mock import AsyncMock, MagicMock
import pytest
@@ -70,3 +70,142 @@ async def test_get_agent():
agent = ConcreteTestAgent("registrant")
assert AgentDirectory.get("registrant") == agent
assert AgentDirectory.get("non_existent") is None
class DummyAgent(BaseAgent):
async def setup(self):
pass # we will test this separately
async def handle_message(self, msg: InternalMessage):
self.last_handled = msg
@pytest.mark.asyncio
async def test_base_agent_setup_is_noop():
agent = DummyAgent("dummy")
# Should simply return without error
assert await agent.setup() is None
@pytest.mark.asyncio
async def test_send_to_local_agent(monkeypatch):
sender = DummyAgent("sender")
target = DummyAgent("receiver")
# Fake logger
sender.logger = MagicMock()
# Patch inbox.put
target.inbox.put = AsyncMock()
message = InternalMessage(to="receiver", sender="sender", body="hello")
await sender.send(message)
target.inbox.put.assert_awaited_once_with(message)
sender.logger.debug.assert_called_once()
@pytest.mark.asyncio
async def test_process_inbox_calls_handle_message(monkeypatch):
agent = DummyAgent("dummy")
agent.logger = MagicMock()
# Make agent running so loop triggers
agent._running = True
# Prepare inbox to give one message then stop
msg = InternalMessage(to="dummy", sender="x", body="test")
async def get_once():
agent._running = False # stop after first iteration
return msg
agent.inbox.get = AsyncMock(side_effect=get_once)
agent.handle_message = AsyncMock()
await agent._process_inbox()
agent.handle_message.assert_awaited_once_with(msg)
@pytest.mark.asyncio
async def test_receive_internal_zmq_loop_success(monkeypatch):
agent = DummyAgent("dummy")
agent.logger = MagicMock()
agent._running = True
mock_socket = MagicMock()
mock_socket.recv_multipart = AsyncMock(
side_effect=[
(
b"topic",
InternalMessage(to="dummy", sender="x", body="hi").model_dump_json().encode(),
),
asyncio.CancelledError(), # stop loop
]
)
agent._internal_sub_socket = mock_socket
agent.inbox.put = AsyncMock()
await agent._receive_internal_zmq_loop()
agent.inbox.put.assert_awaited() # message forwarded
@pytest.mark.asyncio
async def test_receive_internal_zmq_loop_exception_logs_error():
agent = DummyAgent("dummy")
agent.logger = MagicMock()
agent._running = True
mock_socket = MagicMock()
mock_socket.recv_multipart = AsyncMock(
side_effect=[Exception("boom"), asyncio.CancelledError()]
)
agent._internal_sub_socket = mock_socket
agent.inbox.put = AsyncMock()
await agent._receive_internal_zmq_loop()
agent.logger.exception.assert_called_once()
assert "Could not process ZMQ message." in agent.logger.exception.call_args[0][0]
@pytest.mark.asyncio
async def test_base_agent_handle_message_not_implemented():
class RawAgent(BaseAgent):
async def setup(self):
pass
agent = RawAgent("raw")
msg = InternalMessage(to="raw", sender="x", body="hi")
with pytest.raises(NotImplementedError):
await BaseAgent.handle_message(agent, msg)
@pytest.mark.asyncio
async def test_base_agent_setup_abstract_method_body_executes():
"""
Covers the 'pass' inside BaseAgent.setup().
Since BaseAgent is abstract, we do NOT instantiate it.
We call the coroutine function directly on BaseAgent and pass a dummy self.
"""
class Dummy:
"""Minimal stub to act as 'self'."""
pass
stub = Dummy()
# Call BaseAgent.setup() as an unbound coroutine, passing stub as 'self'
result = await BaseAgent.setup(stub)
# The method contains only 'pass', so it returns None
assert result is None

View File

@@ -86,3 +86,34 @@ def test_setup_logging_zmq_handler(mock_zmq_context):
args = mock_dict_config.call_args[0][0]
assert "interface_or_socket" in args["handlers"]["ui"]
def test_add_logging_level_method_name_exists_in_logging():
# method_name explicitly set to an existing logging method → triggers first hasattr branch
with pytest.raises(AttributeError) as exc:
add_logging_level("NEWDUPLEVEL", 37, method_name="info")
assert "info already defined in logging module" in str(exc.value)
def test_add_logging_level_method_name_exists_in_logger_class():
# 'makeRecord' exists on Logger class but not on the logging module
with pytest.raises(AttributeError) as exc:
add_logging_level("ANOTHERLEVEL", 38, method_name="makeRecord")
assert "makeRecord already defined in logger class" in str(exc.value)
def test_add_logging_level_log_to_root_path_executes_without_error():
# Verify log_to_root is installed and callable — without asserting logging output
level_name = "ROOTTEST"
level_num = 36
add_logging_level(level_name, level_num)
# Simply call the injected root logger method
# The line is executed even if we don't validate output
root_logging_method = getattr(logging, level_name.lower(), None)
assert callable(root_logging_method)
# Execute the method to hit log_to_root in coverage.
# No need to verify log output.
root_logging_method("some message")

View File

@@ -0,0 +1,12 @@
from control_backend.schemas.message import Message
def base_message() -> Message:
return Message(message="Example")
def test_valid_message():
mess = base_message()
validated = Message.model_validate(mess)
assert isinstance(validated, Message)
assert validated.message == "Example"

View File

@@ -1,49 +1,65 @@
import uuid
import pytest
from pydantic import ValidationError
from control_backend.schemas.program import (
BasicNorm,
ConditionalNorm,
Goal,
KeywordTrigger,
Norm,
InferredBelief,
KeywordBelief,
LogicalOperator,
Phase,
Plan,
Program,
TriggerKeyword,
SemanticBelief,
Trigger,
)
def base_norm() -> Norm:
return Norm(
id="norm1",
label="testNorm",
def base_norm() -> BasicNorm:
return BasicNorm(
id=uuid.uuid4(),
name="testNormName",
norm="testNormNorm",
critical=False,
)
def base_goal() -> Goal:
return Goal(
id="goal1",
label="testGoal",
description="testGoalDescription",
achieved=False,
id=uuid.uuid4(),
name="testGoalName",
plan=Plan(
id=uuid.uuid4(),
name="testGoalPlanName",
steps=[],
),
can_fail=False,
)
def base_trigger() -> KeywordTrigger:
return KeywordTrigger(
id="trigger1",
label="testTrigger",
type="keywords",
keywords=[
TriggerKeyword(id="keyword1", keyword="testKeyword1"),
TriggerKeyword(id="keyword1", keyword="testKeyword2"),
],
def base_trigger() -> Trigger:
return Trigger(
id=uuid.uuid4(),
name="testTriggerName",
condition=KeywordBelief(
id=uuid.uuid4(),
name="testTriggerKeywordBeliefTriggerName",
keyword="Keyword",
),
plan=Plan(
id=uuid.uuid4(),
name="testTriggerPlanName",
steps=[],
),
)
def base_phase() -> Phase:
return Phase(
id="phase1",
label="basephase",
id=uuid.uuid4(),
norms=[base_norm()],
goals=[base_goal()],
triggers=[base_trigger()],
@@ -58,7 +74,7 @@ def invalid_program() -> dict:
# wrong types inside phases list (not Phase objects)
return {
"phases": [
{"id": "phase1"}, # incomplete
{"id": uuid.uuid4()}, # incomplete
{"not_a_phase": True},
]
}
@@ -77,11 +93,112 @@ def test_valid_deepprogram():
# validate nested components directly
phase = validated.phases[0]
assert isinstance(phase.goals[0], Goal)
assert isinstance(phase.triggers[0], KeywordTrigger)
assert isinstance(phase.norms[0], Norm)
assert isinstance(phase.triggers[0], Trigger)
assert isinstance(phase.norms[0], BasicNorm)
def test_invalid_program():
bad = invalid_program()
with pytest.raises(ValidationError):
Program.model_validate(bad)
def test_conditional_norm_parsing():
"""
Check that pydantic is able to preserve the type of the norm, that it doesn't lose its
"condition" field when serializing and deserializing.
"""
norm = ConditionalNorm(
name="testNormName",
id=uuid.uuid4(),
norm="testNormNorm",
critical=False,
condition=KeywordBelief(
name="testKeywordBelief",
id=uuid.uuid4(),
keyword="testKeywordBelief",
),
)
program = Program(
phases=[
Phase(
name="Some phase",
id=uuid.uuid4(),
norms=[norm],
goals=[],
triggers=[],
),
],
)
parsed_program = Program.model_validate_json(program.model_dump_json())
parsed_norm = parsed_program.phases[0].norms[0]
assert hasattr(parsed_norm, "condition")
assert isinstance(parsed_norm, ConditionalNorm)
def test_belief_type_parsing():
"""
Check that pydantic is able to discern between the different types of beliefs when serializing
and deserializing.
"""
keyword_belief = KeywordBelief(
name="testKeywordBelief",
id=uuid.uuid4(),
keyword="something",
)
semantic_belief = SemanticBelief(
name="testSemanticBelief",
id=uuid.uuid4(),
description="something",
)
inferred_belief = InferredBelief(
name="testInferredBelief",
id=uuid.uuid4(),
operator=LogicalOperator.OR,
left=keyword_belief,
right=semantic_belief,
)
program = Program(
phases=[
Phase(
name="Some phase",
id=uuid.uuid4(),
norms=[],
goals=[],
triggers=[
Trigger(
name="testTriggerKeywordTrigger",
id=uuid.uuid4(),
condition=keyword_belief,
plan=Plan(name="testTriggerPlanName", id=uuid.uuid4(), steps=[]),
),
Trigger(
name="testTriggerSemanticTrigger",
id=uuid.uuid4(),
condition=semantic_belief,
plan=Plan(name="testTriggerPlanName", id=uuid.uuid4(), steps=[]),
),
Trigger(
name="testTriggerInferredTrigger",
id=uuid.uuid4(),
condition=inferred_belief,
plan=Plan(name="testTriggerPlanName", id=uuid.uuid4(), steps=[]),
),
],
),
],
)
parsed_program = Program.model_validate_json(program.model_dump_json())
parsed_keyword_belief = parsed_program.phases[0].triggers[0].condition
assert isinstance(parsed_keyword_belief, KeywordBelief)
parsed_semantic_belief = parsed_program.phases[0].triggers[1].condition
assert isinstance(parsed_semantic_belief, SemanticBelief)
parsed_inferred_belief = parsed_program.phases[0].triggers[2].condition
assert isinstance(parsed_inferred_belief, InferredBelief)

73
test/unit/test_main.py Normal file
View File

@@ -0,0 +1,73 @@
import asyncio
import sys
from unittest.mock import AsyncMock, patch
import pytest
from fastapi.testclient import TestClient
from control_backend.api.v1.router import api_router
from control_backend.main import app, lifespan
# Fix event loop on Windows
if sys.platform == "win32":
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
@pytest.fixture
def client():
# Patch setup_logging so it does nothing
with patch("control_backend.main.setup_logging"):
with TestClient(app) as c:
yield c
def test_root_fast():
# Patch heavy startup code so it doesnt slow down
with patch("control_backend.main.setup_logging"), patch("control_backend.main.lifespan"):
client = TestClient(app)
resp = client.get("/")
assert resp.status_code == 200
assert resp.json() == {"status": "ok"}
def test_cors_middleware_added():
"""Test that CORSMiddleware is correctly added to the app."""
from starlette.middleware.cors import CORSMiddleware
middleware_classes = [m.cls for m in app.user_middleware]
assert CORSMiddleware in middleware_classes
def test_api_router_included():
"""Test that the API router is included in the FastAPI app."""
route_paths = [r.path for r in app.routes]
for route in api_router.routes:
assert route.path in route_paths
@pytest.mark.asyncio
async def test_lifespan_agent_start_exception():
"""
Trigger an exception during agent startup to cover the error logging branch.
Ensures exceptions are logged properly and re-raised.
"""
with (
patch(
"control_backend.main.RICommunicationAgent.start", new_callable=AsyncMock
) as ri_start,
patch("control_backend.main.setup_logging"),
patch("control_backend.main.threading.Thread"),
):
# Force RICommunicationAgent.start to raise an exception
ri_start.side_effect = Exception("Test exception")
with patch("control_backend.main.logger") as mock_logger:
with pytest.raises(Exception, match="Test exception"):
async with lifespan(app):
pass
# Verify the error was logged correctly
assert mock_logger.error.called
args, _ = mock_logger.error.call_args
assert isinstance(args[2], Exception)

23
uv.lock generated
View File

@@ -997,6 +997,7 @@ dependencies = [
{ name = "pydantic" },
{ name = "pydantic-settings" },
{ name = "python-json-logger" },
{ name = "python-slugify" },
{ name = "pyyaml" },
{ name = "pyzmq" },
{ name = "silero-vad" },
@@ -1046,6 +1047,7 @@ requires-dist = [
{ name = "pydantic", specifier = ">=2.12.0" },
{ name = "pydantic-settings", specifier = ">=2.11.0" },
{ name = "python-json-logger", specifier = ">=4.0.0" },
{ name = "python-slugify", specifier = ">=8.0.4" },
{ name = "pyyaml", specifier = ">=6.0.3" },
{ name = "pyzmq", specifier = ">=27.1.0" },
{ name = "silero-vad", specifier = ">=6.0.0" },
@@ -1341,6 +1343,18 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/45/58/38b5afbc1a800eeea951b9285d3912613f2603bdf897a4ab0f4bd7f405fc/python_multipart-0.0.20-py3-none-any.whl", hash = "sha256:8a62d3a8335e06589fe01f2a3e178cdcc632f3fbe0d492ad9ee0ec35aab1f104", size = 24546, upload-time = "2024-12-16T19:45:44.423Z" },
]
[[package]]
name = "python-slugify"
version = "8.0.4"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "text-unidecode" },
]
sdist = { url = "https://files.pythonhosted.org/packages/87/c7/5e1547c44e31da50a460df93af11a535ace568ef89d7a811069ead340c4a/python-slugify-8.0.4.tar.gz", hash = "sha256:59202371d1d05b54a9e7720c5e038f928f45daaffe41dd10822f3907b937c856", size = 10921, upload-time = "2024-02-08T18:32:45.488Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/a4/62/02da182e544a51a5c3ccf4b03ab79df279f9c60c5e82d5e8bec7ca26ac11/python_slugify-8.0.4-py2.py3-none-any.whl", hash = "sha256:276540b79961052b66b7d116620b36518847f52d5fd9e3a70164fc8c50faa6b8", size = 10051, upload-time = "2024-02-08T18:32:43.911Z" },
]
[[package]]
name = "pyyaml"
version = "6.0.3"
@@ -1864,6 +1878,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/a2/09/77d55d46fd61b4a135c444fc97158ef34a095e5681d0a6c10b75bf356191/sympy-1.14.0-py3-none-any.whl", hash = "sha256:e091cc3e99d2141a0ba2847328f5479b05d94a6635cb96148ccb3f34671bd8f5", size = 6299353, upload-time = "2025-04-27T18:04:59.103Z" },
]
[[package]]
name = "text-unidecode"
version = "1.3"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/ab/e2/e9a00f0ccb71718418230718b3d900e71a5d16e701a3dae079a21e9cd8f8/text-unidecode-1.3.tar.gz", hash = "sha256:bad6603bb14d279193107714b288be206cac565dfa49aa5b105294dd5c4aab93", size = 76885, upload-time = "2019-08-30T21:36:45.405Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/a6/a5/c0b6468d3824fe3fde30dbb5e1f687b291608f9473681bbf7dabbf5a87d7/text_unidecode-1.3-py2.py3-none-any.whl", hash = "sha256:1311f10e8b895935241623731c2ba64f4c455287888b18189350b67134a822e8", size = 78154, upload-time = "2019-08-30T21:37:03.543Z" },
]
[[package]]
name = "tiktoken"
version = "0.12.0"