Compare commits

...

21 Commits

Author SHA1 Message Date
Twirre Meulenbelt
4b71981a3e fix: some bugs and some tests
ref: N25B-429
2026-01-12 09:00:50 +01:00
866d7c4958 fix: end phase loop correctly notifies about user_said
ref: N25B-429
2026-01-08 15:13:12 +01:00
133019a928 feat: trigger name and trigger checks on belief update
ref: N25B-429
2026-01-08 14:04:44 +01:00
4d0ba69443 fix: don't re-add user_said upon phase transition
ref: N25B-429
2026-01-08 13:44:25 +01:00
625ef0c365 feat: phase transition waits for all goals
ref: N25B-429
2026-01-08 13:36:03 +01:00
b88758fa76 feat: phase transition independent of response
ref: N25B-429
2026-01-08 13:33:37 +01:00
Twirre Meulenbelt
45719c580b feat: prepend more silence before speech audio for better transcription beginnings
ref: N25B-429
2026-01-08 10:49:13 +01:00
5a61225c6f feat: reset extractor history
ref: N25B-429
2026-01-07 18:10:13 +01:00
a30cea5231 Merge branch 'feat/semantic-beliefs' into feat/extra-agentspeak-functionality 2026-01-07 17:51:30 +01:00
240624f887 Merge branch 'dev' into feat/extra-agentspeak-functionality
# Conflicts:
#	src/control_backend/agents/bdi/bdi_program_manager.py
#	src/control_backend/agents/llm/llm_agent.py
#	test/unit/agents/bdi/test_bdi_program_manager.py
2026-01-07 17:46:48 +01:00
8a77e8e1c7 feat: check goals only for this phase
Since conversation history still remains we can still check at a later point.

ref: N25B-429
2026-01-07 17:31:24 +01:00
3b4dccc760 Merge branch 'feat/semantic-beliefs' into feat/extra-agentspeak-functionality
# Conflicts:
#	src/control_backend/agents/bdi/bdi_program_manager.py
2026-01-07 17:20:52 +01:00
3d49e44cf7 fix: complete pipeline working
User interrupts still need to be tested.

ref: N25B-429
2026-01-07 17:13:58 +01:00
Björn Otgaar
612a96940d Merge branch 'feat/environment-variables' into 'dev'
Docs for environment variables, parameterize some constants

See merge request ics/sp/2025/n25b/pepperplus-cb!38
2026-01-06 09:02:49 +00:00
Pim Hutting
4c20656c75 Merge branch 'feat/program-reset-llm' into 'dev'
feat: made program reset LLM

See merge request ics/sp/2025/n25b/pepperplus-cb!39
2026-01-02 15:13:05 +00:00
Pim Hutting
6ca86e4b81 feat: made program reset LLM 2026-01-02 15:13:04 +00:00
Twirre Meulenbelt
7d798f2e77 Merge remote-tracking branch 'origin/dev' into feat/environment-variables
# Conflicts:
#	src/control_backend/core/config.py
#	test/unit/agents/actuation/test_robot_speech_agent.py
2025-12-29 12:40:16 +01:00
Twirre Meulenbelt
5282c2471f Merge remote-tracking branch 'origin/dev' into feat/environment-variables
# Conflicts:
#	src/control_backend/core/config.py
#	test/unit/agents/actuation/test_robot_speech_agent.py
2025-12-29 12:35:39 +01:00
Twirre Meulenbelt
0c682d6440 feat: introduce .env.example, docs
The example includes options that are expected to be changed. It also includes a reference to where in the docs you can find a full list of options.

ref: N25B-352
2025-12-11 13:35:19 +01:00
Twirre Meulenbelt
32d8f20dc9 feat: parameterize RI host
Was "localhost" in RI Communication Agent, now uses configurable setting. Secretly also removing "localhost" from VAD agent, as its socket should be something that's "inproc".

ref: N25B-352
2025-12-11 12:12:15 +01:00
Twirre Meulenbelt
9cc0e39955 fix: failures main tests since VAD agent initialization was changed
The test still expects the VAD agent to be started in main, rather than in the RI Communication Agent.

ref: N25B-356
2025-12-11 12:04:24 +01:00
21 changed files with 611 additions and 218 deletions

.env.example (new file, 20 lines)

@@ -0,0 +1,20 @@
# Example .env file. To use, make a copy, call it ".env" (i.e. removing the ".example" suffix), then edit the values.
# The hostname of the Robot Interface. Change if the Control Backend and Robot Interface are running on different computers.
RI_HOST="localhost"
# URL for the local LLM API. Must be an API that implements the OpenAI Chat Completions API, but most do.
LLM_SETTINGS__LOCAL_LLM_URL="http://localhost:1234/v1/chat/completions"
# Name of the local LLM model to use.
LLM_SETTINGS__LOCAL_LLM_MODEL="gpt-oss"
# Number of non-speech chunks to wait before speech is considered ended. A chunk is approximately 31 ms. Increasing this number allows longer pauses in speech, but also increases response time.
BEHAVIOUR_SETTINGS__VAD_NON_SPEECH_PATIENCE_CHUNKS=15
# Timeout in milliseconds for socket polling. Increase this number (for example, to 500 ms) if network latency or jitter is high, which is often the case when using Wi-Fi. A symptom of this issue is transcriptions getting cut off.
BEHAVIOUR_SETTINGS__SOCKET_POLLER_TIMEOUT_MS=100
# For an exhaustive list of options, see the control_backend.core.config module in the docs.
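The double underscore in names such as `LLM_SETTINGS__LOCAL_LLM_URL` suggests a nested-settings convention (a settings group, then a field). How the project actually loads these values is not shown in this diff, so the following is only a minimal standard-library sketch of that convention. The helper name `nest_env` and the lowercase key mapping are assumptions made for illustration. The sketch also works out the silence window implied by the default patience value, using the approximate 31 ms chunk length stated in the file.

```python
def nest_env(env: dict[str, str], delimiter: str = "__") -> dict:
    """Group flat GROUP__FIELD names into nested dictionaries (hypothetical helper)."""
    nested: dict = {}
    for key, value in env.items():
        parts = key.lower().split(delimiter)
        node = nested
        # Walk (and create) one dictionary level per group name.
        for part in parts[:-1]:
            node = node.setdefault(part, {})
        node[parts[-1]] = value
    return nested


# Values copied from the .env.example shown above.
example_env = {
    "RI_HOST": "localhost",
    "LLM_SETTINGS__LOCAL_LLM_MODEL": "gpt-oss",
    "BEHAVIOUR_SETTINGS__VAD_NON_SPEECH_PATIENCE_CHUNKS": "15",
    "BEHAVIOUR_SETTINGS__SOCKET_POLLER_TIMEOUT_MS": "100",
}

settings = nest_env(example_env)

CHUNK_MS = 31  # approximate chunk length, taken from the comment in .env.example
patience_chunks = int(settings["behaviour_settings"]["vad_non_speech_patience_chunks"])
silence_window_ms = patience_chunks * CHUNK_MS  # 15 * 31 = 465
```

With the default of 15 chunks, the agent waits roughly 465 ms of silence before treating an utterance as finished, which is the response-time cost the comment refers to.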


@@ -27,6 +27,7 @@ This + part might differ based on what model you choose.
copy the model name in the module loaded and replace local_llm_modelL. In settings. copy the model name in the module loaded and replace local_llm_modelL. In settings.
## Running ## Running
To run the project (development server), execute the following command (while inside the root repository): To run the project (development server), execute the following command (while inside the root repository):
@@ -34,6 +35,14 @@ To run the project (development server), execute the following command (while in
uv run fastapi dev src/control_backend/main.py uv run fastapi dev src/control_backend/main.py
``` ```
### Environment Variables
You can use environment variables to change settings. Make a copy of the [`.env.example`](.env.example) file, name it `.env` and put it in the root directory. The file itself describes how to do the configuration.
For an exhaustive list of environment options, see the `control_backend.core.config` module in the docs.
## Testing ## Testing
Testing happens automatically when opening a merge request to any branch. If you want to manually run the test suite, you can do so by running the following for unit tests: Testing happens automatically when opening a merge request to any branch. If you want to manually run the test suite, you can do so by running the following for unit tests:


@@ -33,7 +33,7 @@ class RobotGestureAgent(BaseAgent):
def __init__( def __init__(
self, self,
name: str, name: str,
address=settings.zmq_settings.ri_command_address, address: str,
bind=False, bind=False,
gesture_data=None, gesture_data=None,
single_gesture_data=None, single_gesture_data=None,


@@ -145,7 +145,10 @@ class AgentSpeakGenerator:
type=TriggerType.ADDED_BELIEF, type=TriggerType.ADDED_BELIEF,
trigger_literal=AstLiteral("user_said", [AstVar("Message")]), trigger_literal=AstLiteral("user_said", [AstVar("Message")]),
context=[AstLiteral("phase", [AstString("end")])], context=[AstLiteral("phase", [AstString("end")])],
body=[AstStatement(StatementType.ACHIEVE_GOAL, AstLiteral("reply"))], body=[
AstStatement(StatementType.DO_ACTION, AstLiteral("notify_user_said")),
AstStatement(StatementType.ACHIEVE_GOAL, AstLiteral("reply")),
],
) )
) )
@@ -157,7 +160,7 @@ class AgentSpeakGenerator:
previous_goal = None previous_goal = None
for goal in phase.goals: for goal in phase.goals:
self._process_goal(goal, phase, previous_goal) self._process_goal(goal, phase, previous_goal, main_goal=True)
previous_goal = goal previous_goal = goal
for trigger in phase.triggers: for trigger in phase.triggers:
@@ -171,25 +174,40 @@ class AgentSpeakGenerator:
self._astify(to_phase) if to_phase else AstLiteral("phase", [AstString("end")]) self._astify(to_phase) if to_phase else AstLiteral("phase", [AstString("end")])
) )
context = [from_phase_ast, ~AstLiteral("responded_this_turn")] context = [from_phase_ast]
if from_phase and from_phase.goals: if from_phase:
context.append(self._astify(from_phase.goals[-1], achieved=True)) for goal in from_phase.goals:
context.append(self._astify(goal, achieved=True))
body = [ body = [
AstStatement(StatementType.REMOVE_BELIEF, from_phase_ast), AstStatement(StatementType.REMOVE_BELIEF, from_phase_ast),
AstStatement(StatementType.ADD_BELIEF, to_phase_ast), AstStatement(StatementType.ADD_BELIEF, to_phase_ast),
] ]
if from_phase: # if from_phase:
body.extend( # body.extend(
# [
# AstStatement(
# StatementType.TEST_GOAL, AstLiteral("user_said", [AstVar("Message")])
# ),
# AstStatement(
# StatementType.REPLACE_BELIEF, AstLiteral("user_said", [AstVar("Message")])
# ),
# ]
# )
# Notify outside world about transition
body.append(
AstStatement(
StatementType.DO_ACTION,
AstLiteral(
"notify_transition_phase",
[ [
AstStatement( AstString(str(from_phase.id)),
StatementType.TEST_GOAL, AstLiteral("user_said", [AstVar("Message")]) AstString(str(to_phase.id) if to_phase else "end"),
],
), ),
AstStatement( )
StatementType.REPLACE_BELIEF, AstLiteral("user_said", [AstVar("Message")])
),
]
) )
self._asp.plans.append( self._asp.plans.append(
@@ -213,6 +231,11 @@ class AgentSpeakGenerator:
def _add_default_loop(self, phase: Phase) -> None: def _add_default_loop(self, phase: Phase) -> None:
actions = [] actions = []
actions.append(
AstStatement(
StatementType.DO_ACTION, AstLiteral("notify_user_said", [AstVar("Message")])
)
)
actions.append(AstStatement(StatementType.REMOVE_BELIEF, AstLiteral("responded_this_turn"))) actions.append(AstStatement(StatementType.REMOVE_BELIEF, AstLiteral("responded_this_turn")))
actions.append(AstStatement(StatementType.ACHIEVE_GOAL, AstLiteral("check_triggers"))) actions.append(AstStatement(StatementType.ACHIEVE_GOAL, AstLiteral("check_triggers")))
@@ -236,6 +259,7 @@ class AgentSpeakGenerator:
phase: Phase, phase: Phase,
previous_goal: Goal | None = None, previous_goal: Goal | None = None,
continues_response: bool = False, continues_response: bool = False,
main_goal: bool = False,
) -> None: ) -> None:
context: list[AstExpression] = [self._astify(phase)] context: list[AstExpression] = [self._astify(phase)]
context.append(~self._astify(goal, achieved=True)) context.append(~self._astify(goal, achieved=True))
@@ -245,6 +269,13 @@ class AgentSpeakGenerator:
context.append(~AstLiteral("responded_this_turn")) context.append(~AstLiteral("responded_this_turn"))
body = [] body = []
if main_goal: # UI only needs to know about the main goals
body.append(
AstStatement(
StatementType.DO_ACTION,
AstLiteral("notify_goal_start", [AstString(self.slugify(goal))]),
)
)
subgoals = [] subgoals = []
for step in goal.plan.steps: for step in goal.plan.steps:
@@ -283,11 +314,23 @@ class AgentSpeakGenerator:
body = [] body = []
subgoals = [] subgoals = []
body.append(
AstStatement(
StatementType.DO_ACTION,
AstLiteral("notify_trigger_start", [AstString(self.slugify(trigger))]),
)
)
for step in trigger.plan.steps: for step in trigger.plan.steps:
body.append(self._step_to_statement(step)) body.append(self._step_to_statement(step))
if isinstance(step, Goal): if isinstance(step, Goal):
step.can_fail = False # triggers are continuous sequence step.can_fail = False # triggers are continuous sequence
subgoals.append(step) subgoals.append(step)
body.append(
AstStatement(
StatementType.DO_ACTION,
AstLiteral("notify_trigger_end", [AstString(self.slugify(trigger))]),
)
)
self._asp.plans.append( self._asp.plans.append(
AstPlan( AstPlan(
@@ -298,6 +341,9 @@ class AgentSpeakGenerator:
) )
) )
# Force trigger (from UI)
self._asp.plans.append(AstPlan(TriggerType.ADDED_GOAL, self._astify(trigger), [], body))
for subgoal in subgoals: for subgoal in subgoals:
self._process_goal(subgoal, phase, continues_response=True) self._process_goal(subgoal, phase, continues_response=True)
@@ -332,13 +378,7 @@ class AgentSpeakGenerator:
@_astify.register @_astify.register
def _(self, sb: SemanticBelief) -> AstExpression: def _(self, sb: SemanticBelief) -> AstExpression:
return AstLiteral(self.get_semantic_belief_slug(sb)) return AstLiteral(self.slugify(sb))
@staticmethod
def get_semantic_belief_slug(sb: SemanticBelief) -> str:
# If you need a method like this for other types, make a public slugify singledispatch for
# all types.
return f"semantic_{AgentSpeakGenerator._slugify_str(sb.name)}"
@_astify.register @_astify.register
def _(self, ib: InferredBelief) -> AstExpression: def _(self, ib: InferredBelief) -> AstExpression:
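The transition hunk above changes the plan context: the old version guarded on `responded_this_turn` and on the last goal of the phase only, while the new version requires every goal of the phase to be achieved. The sketch below illustrates that difference with plain strings instead of the project's AST classes. The function names and the `achieved_<goal>` naming are hypothetical and do not reflect how `_astify` actually renders goals.

```python
def transition_context_old(phase: str, goals: list[str]) -> list[str]:
    """Before the change: no response this turn, and only the last goal achieved."""
    context = [f'phase("{phase}")', "not responded_this_turn"]
    if goals:
        context.append(f"achieved_{goals[-1]}")
    return context


def transition_context_new(phase: str, goals: list[str]) -> list[str]:
    """After the change: every goal of the phase must be achieved."""
    context = [f'phase("{phase}")']
    context.extend(f"achieved_{goal}" for goal in goals)
    return context


goals = ["greet", "ask_name"]
old_context = " & ".join(transition_context_old("intro", goals))
new_context = " & ".join(transition_context_new("intro", goals))
```

Under the old context a phase could be left as soon as its final goal was marked achieved, even if an earlier goal was still open; the new context closes that gap.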


@@ -1,5 +1,6 @@
import asyncio import asyncio
import copy import copy
import json
import time import time
from collections.abc import Iterable from collections.abc import Iterable
@@ -13,7 +14,7 @@ from control_backend.core.agent_system import InternalMessage
from control_backend.core.config import settings from control_backend.core.config import settings
from control_backend.schemas.belief_message import BeliefMessage from control_backend.schemas.belief_message import BeliefMessage
from control_backend.schemas.llm_prompt_message import LLMPromptMessage from control_backend.schemas.llm_prompt_message import LLMPromptMessage
from control_backend.schemas.ri_message import SpeechCommand from control_backend.schemas.ri_message import GestureCommand, RIEndpoint, SpeechCommand
DELIMITER = ";\n" # TODO: temporary until we support lists in AgentSpeak DELIMITER = ";\n" # TODO: temporary until we support lists in AgentSpeak
@@ -100,7 +101,6 @@ class BDICoreAgent(BaseAgent):
maybe_more_work = True maybe_more_work = True
while maybe_more_work: while maybe_more_work:
maybe_more_work = False maybe_more_work = False
self.logger.debug("Stepping BDI.")
if self.bdi_agent.step(): if self.bdi_agent.step():
maybe_more_work = True maybe_more_work = True
@@ -155,6 +155,17 @@ class BDICoreAgent(BaseAgent):
body=cmd.model_dump_json(), body=cmd.model_dump_json(),
) )
await self.send(out_msg) await self.send(out_msg)
case settings.agent_settings.user_interrupt_name:
content = msg.body
self.logger.debug("Received user interruption: %s", content)
match msg.thread:
case "force_phase_transition":
self._set_goal("transition_phase")
case "force_trigger":
self._force_trigger(msg.body)
case _:
self.logger.warning("Received unknown user interruption: %s", msg)
def _apply_belief_changes(self, belief_changes: BeliefMessage): def _apply_belief_changes(self, belief_changes: BeliefMessage):
""" """
@@ -201,6 +212,22 @@ class BDICoreAgent(BaseAgent):
agentspeak.runtime.Intention(), agentspeak.runtime.Intention(),
) )
# Check for transitions
self.bdi_agent.call(
agentspeak.Trigger.addition,
agentspeak.GoalType.achievement,
agentspeak.Literal("transition_phase"),
agentspeak.runtime.Intention(),
)
# Check triggers
self.bdi_agent.call(
agentspeak.Trigger.addition,
agentspeak.GoalType.achievement,
agentspeak.Literal("check_triggers"),
agentspeak.runtime.Intention(),
)
self._wake_bdi_loop.set() self._wake_bdi_loop.set()
self.logger.debug(f"Added belief {self.format_belief_string(name, args)}") self.logger.debug(f"Added belief {self.format_belief_string(name, args)}")
@@ -253,6 +280,37 @@ class BDICoreAgent(BaseAgent):
self.logger.debug(f"Removed {removed_count} beliefs.") self.logger.debug(f"Removed {removed_count} beliefs.")
def _set_goal(self, name: str, args: Iterable[str] | None = None):
args = args or []
if args:
merged_args = DELIMITER.join(arg for arg in args)
new_args = (agentspeak.Literal(merged_args),)
term = agentspeak.Literal(name, new_args)
else:
term = agentspeak.Literal(name)
self.bdi_agent.call(
agentspeak.Trigger.addition,
agentspeak.GoalType.achievement,
term,
agentspeak.runtime.Intention(),
)
self._wake_bdi_loop.set()
self.logger.debug(f"Set goal !{self.format_belief_string(name, args)}.")
def _force_trigger(self, name: str):
self.bdi_agent.call(
agentspeak.Trigger.addition,
agentspeak.GoalType.achievement,
agentspeak.Literal(name),
agentspeak.runtime.Intention(),
)
self.logger.info("Manually forced trigger %s.", name)
def _add_custom_actions(self) -> None: def _add_custom_actions(self) -> None:
""" """
Add any custom actions here. Inside `@self.actions.add()`, the first argument is Add any custom actions here. Inside `@self.actions.add()`, the first argument is
@@ -261,7 +319,7 @@ class BDICoreAgent(BaseAgent):
""" """
@self.actions.add(".reply", 2) @self.actions.add(".reply", 2)
def _reply(agent: "BDICoreAgent", term, intention): def _reply(agent, term, intention):
""" """
Let the LLM generate a response to a user's utterance with the current norms and goals. Let the LLM generate a response to a user's utterance with the current norms and goals.
""" """
@@ -294,7 +352,7 @@ class BDICoreAgent(BaseAgent):
yield yield
@self.actions.add(".say", 1) @self.actions.add(".say", 1)
def _say(agent: "BDICoreAgent", term, intention): def _say(agent, term, intention):
""" """
Make the robot say the given text instantly. Make the robot say the given text instantly.
""" """
@@ -308,12 +366,21 @@ class BDICoreAgent(BaseAgent):
sender=settings.agent_settings.bdi_core_name, sender=settings.agent_settings.bdi_core_name,
body=speech_command.model_dump_json(), body=speech_command.model_dump_json(),
) )
# TODO: add to conversation history
self.add_behavior(self.send(speech_message)) self.add_behavior(self.send(speech_message))
chat_history_message = InternalMessage(
to=settings.agent_settings.llm_name,
thread="assistant_message",
body=str(message_text),
)
self.add_behavior(self.send(chat_history_message))
yield yield
@self.actions.add(".gesture", 2) @self.actions.add(".gesture", 2)
def _gesture(agent: "BDICoreAgent", term, intention): def _gesture(agent, term, intention):
""" """
Make the robot perform the given gesture instantly. Make the robot perform the given gesture instantly.
""" """
@@ -326,13 +393,113 @@ class BDICoreAgent(BaseAgent):
gesture_name, gesture_name,
) )
# gesture = Gesture(type=gesture_type, name=gesture_name) if str(gesture_type) == "single":
# gesture_message = InternalMessage( endpoint = RIEndpoint.GESTURE_SINGLE
# to=settings.agent_settings.robot_gesture_name, elif str(gesture_type) == "tag":
# sender=settings.agent_settings.bdi_core_name, endpoint = RIEndpoint.GESTURE_TAG
# body=gesture.model_dump_json(), else:
# ) self.logger.warning("Gesture type %s could not be resolved.", gesture_type)
# asyncio.create_task(agent.send(gesture_message)) endpoint = RIEndpoint.GESTURE_SINGLE
gesture_command = GestureCommand(endpoint=endpoint, data=gesture_name)
gesture_message = InternalMessage(
to=settings.agent_settings.robot_gesture_name,
sender=settings.agent_settings.bdi_core_name,
body=gesture_command.model_dump_json(),
)
self.add_behavior(self.send(gesture_message))
yield
@self.actions.add(".notify_user_said", 1)
def _notify_user_said(agent, term, intention):
user_said = agentspeak.grounded(term.args[0], intention.scope)
msg = InternalMessage(
to=settings.agent_settings.llm_name, thread="user_message", body=str(user_said)
)
self.add_behavior(self.send(msg))
yield
@self.actions.add(".notify_trigger_start", 1)
def _notify_trigger_start(agent, term, intention):
"""
Notify the UI about the trigger we just started doing.
"""
trigger_name = agentspeak.grounded(term.args[0], intention.scope)
self.logger.debug("Started trigger %s", trigger_name)
msg = InternalMessage(
to=settings.agent_settings.user_interrupt_name,
sender=self.name,
thread="trigger_start",
body=str(trigger_name),
)
# TODO: check with Pim
self.add_behavior(self.send(msg))
yield
@self.actions.add(".notify_trigger_end", 1)
def _notify_trigger_end(agent, term, intention):
"""
Notify the UI about the trigger we just finished doing.
"""
trigger_name = agentspeak.grounded(term.args[0], intention.scope)
self.logger.debug("Finished trigger %s", trigger_name)
msg = InternalMessage(
to=settings.agent_settings.user_interrupt_name,
sender=self.name,
thread="trigger_end",
body=str(trigger_name),
)
# TODO: check with Pim
self.add_behavior(self.send(msg))
yield
@self.actions.add(".notify_goal_start", 1)
def _notify_goal_start(agent, term, intention):
"""
Notify the UI about the goal we just started chasing.
"""
goal_name = agentspeak.grounded(term.args[0], intention.scope)
self.logger.debug("Started chasing goal %s", goal_name)
msg = InternalMessage(
to=settings.agent_settings.user_interrupt_name,
sender=self.name,
thread="goal_start",
body=str(goal_name),
)
self.add_behavior(self.send(msg))
yield
@self.actions.add(".notify_transition_phase", 2)
def _notify_transition_phase(agent, term, intention):
"""
Notify the BDI program manager about a phase transition.
"""
old = agentspeak.grounded(term.args[0], intention.scope)
new = agentspeak.grounded(term.args[1], intention.scope)
msg = InternalMessage(
to=settings.agent_settings.bdi_program_manager_name,
thread="transition_phase",
body=json.dumps({"old": str(old), "new": str(new)}),
)
self.add_behavior(self.send(msg))
yield yield
async def _send_to_llm(self, text: str, norms: str, goals: str): async def _send_to_llm(self, text: str, norms: str, goals: str):
@@ -344,6 +511,7 @@ class BDICoreAgent(BaseAgent):
to=settings.agent_settings.llm_name, to=settings.agent_settings.llm_name,
sender=self.name, sender=self.name,
body=prompt.model_dump_json(), body=prompt.model_dump_json(),
thread="prompt_message",
) )
await self.send(msg) await self.send(msg)
self.logger.info("Message sent to LLM agent: %s", text) self.logger.info("Message sent to LLM agent: %s", text)
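The `.gesture` action in this file maps a gesture type string to a Robot Interface endpoint and falls back to the single-gesture endpoint when the type is not recognised. A table-driven version of the same logic could look like the sketch below. `Endpoint` is only a stand-in for the project's `RIEndpoint`, its string values are made up, and `resolve_endpoint` is a hypothetical helper, not part of the commit.

```python
import logging
from enum import Enum

logger = logging.getLogger("gesture_sketch")


class Endpoint(Enum):
    """Stand-in for the project's RIEndpoint; the values here are invented."""

    GESTURE_SINGLE = "gesture/single"
    GESTURE_TAG = "gesture/tag"


_ENDPOINTS = {"single": Endpoint.GESTURE_SINGLE, "tag": Endpoint.GESTURE_TAG}


def resolve_endpoint(gesture_type: str) -> Endpoint:
    """Map a gesture type to an endpoint, defaulting to a single gesture."""
    endpoint = _ENDPOINTS.get(gesture_type)
    if endpoint is None:
        # Same fallback as in the diff: warn, then treat it as a single gesture.
        logger.warning("Gesture type %s could not be resolved.", gesture_type)
        return Endpoint.GESTURE_SINGLE
    return endpoint
```

A lookup table keeps the fallback in one place and makes adding another gesture type a one-line change, at the cost of one more module-level name.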


@@ -1,4 +1,5 @@
import asyncio import asyncio
import json
import zmq import zmq
from pydantic import ValidationError from pydantic import ValidationError
@@ -9,7 +10,14 @@ from control_backend.agents.bdi.agentspeak_generator import AgentSpeakGenerator
from control_backend.core.config import settings from control_backend.core.config import settings
from control_backend.schemas.belief_list import BeliefList, GoalList from control_backend.schemas.belief_list import BeliefList, GoalList
from control_backend.schemas.internal_message import InternalMessage from control_backend.schemas.internal_message import InternalMessage
from control_backend.schemas.program import Belief, ConditionalNorm, Goal, InferredBelief, Program from control_backend.schemas.program import (
Belief,
ConditionalNorm,
Goal,
InferredBelief,
Phase,
Program,
)
class BDIProgramManager(BaseAgent): class BDIProgramManager(BaseAgent):
@@ -24,20 +32,20 @@ class BDIProgramManager(BaseAgent):
:ivar sub_socket: The ZMQ SUB socket used to receive program updates. :ivar sub_socket: The ZMQ SUB socket used to receive program updates.
""" """
_program: Program
_phase: Phase | None
def __init__(self, **kwargs): def __init__(self, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
self.sub_socket = None self.sub_socket = None
def _initialize_internal_state(self, program: Program):
self._program = program
self._phase = program.phases[0] # start in first phase
async def _create_agentspeak_and_send_to_bdi(self, program: Program): async def _create_agentspeak_and_send_to_bdi(self, program: Program):
""" """
Convert a received program into BDI beliefs and send them to the BDI Core Agent. Convert a received program into an AgentSpeak file and send it to the BDI Core Agent.
Currently, it takes the **first phase** of the program and extracts:
- **Norms**: Constraints or rules the agent must follow.
- **Goals**: Objectives the agent must achieve.
These are sent as a ``BeliefMessage`` with ``replace=True``, meaning they will
overwrite any existing norms/goals of the same name in the BDI agent.
:param program: The program object received from the API. :param program: The program object received from the API.
""" """
@@ -59,34 +67,61 @@ class BDIProgramManager(BaseAgent):
await self.send(msg) await self.send(msg)
@staticmethod async def handle_message(self, msg: InternalMessage):
def _extract_beliefs_from_program(program: Program) -> list[Belief]: match msg.thread:
case "transition_phase":
phases = json.loads(msg.body)
await self._transition_phase(phases["old"], phases["new"])
async def _transition_phase(self, old: str, new: str):
assert old == str(self._phase.id)
if new == "end":
self._phase = None
return
for phase in self._program.phases:
if str(phase.id) == new:
self._phase = phase
await self._send_beliefs_to_semantic_belief_extractor()
await self._send_goals_to_semantic_belief_extractor()
# Notify user interaction agent
msg = InternalMessage(
to=settings.agent_settings.user_interrupt_name,
thread="transition_phase",
body=str(self._phase.id),
)
self.add_behavior(self.send(msg))
def _extract_current_beliefs(self) -> list[Belief]:
beliefs: list[Belief] = [] beliefs: list[Belief] = []
def extract_beliefs_from_belief(belief: Belief) -> list[Belief]: for norm in self._phase.norms:
if isinstance(belief, InferredBelief):
return extract_beliefs_from_belief(belief.left) + extract_beliefs_from_belief(
belief.right
)
return [belief]
for phase in program.phases:
for norm in phase.norms:
if isinstance(norm, ConditionalNorm): if isinstance(norm, ConditionalNorm):
beliefs += extract_beliefs_from_belief(norm.condition) beliefs += self._extract_beliefs_from_belief(norm.condition)
for trigger in phase.triggers: for trigger in self._phase.triggers:
beliefs += extract_beliefs_from_belief(trigger.condition) beliefs += self._extract_beliefs_from_belief(trigger.condition)
return beliefs return beliefs
async def _send_beliefs_to_semantic_belief_extractor(self, program: Program): @staticmethod
def _extract_beliefs_from_belief(belief: Belief) -> list[Belief]:
if isinstance(belief, InferredBelief):
return BDIProgramManager._extract_beliefs_from_belief(
belief.left
) + BDIProgramManager._extract_beliefs_from_belief(belief.right)
return [belief]
async def _send_beliefs_to_semantic_belief_extractor(self):
""" """
Extract beliefs from the program and send them to the Semantic Belief Extractor Agent. Extract beliefs from the program and send them to the Semantic Belief Extractor Agent.
:param program: The program received from the API.
""" """
beliefs = BeliefList(beliefs=self._extract_beliefs_from_program(program)) beliefs = BeliefList(beliefs=self._extract_current_beliefs())
message = InternalMessage( message = InternalMessage(
to=settings.agent_settings.text_belief_extractor_name, to=settings.agent_settings.text_belief_extractor_name,
@@ -97,12 +132,10 @@ class BDIProgramManager(BaseAgent):
await self.send(message) await self.send(message)
@staticmethod def _extract_current_goals(self) -> list[Goal]:
def _extract_goals_from_program(program: Program) -> list[Goal]:
""" """
Extract all goals from the program, including subgoals. Extract all goals from the program, including subgoals.
:param program: The program received from the API.
:return: A list of Goal objects. :return: A list of Goal objects.
""" """
goals: list[Goal] = [] goals: list[Goal] = []
@@ -114,19 +147,16 @@ class BDIProgramManager(BaseAgent):
goals_.extend(extract_goals_from_goal(plan)) goals_.extend(extract_goals_from_goal(plan))
return goals_ return goals_
for phase in program.phases: for goal in self._phase.goals:
for goal in phase.goals:
goals.extend(extract_goals_from_goal(goal)) goals.extend(extract_goals_from_goal(goal))
return goals return goals
async def _send_goals_to_semantic_belief_extractor(self, program: Program): async def _send_goals_to_semantic_belief_extractor(self):
""" """
Extract goals from the program and send them to the Semantic Belief Extractor Agent. Extract goals for the current phase and send them to the Semantic Belief Extractor Agent.
:param program: The program received from the API.
""" """
goals = GoalList(goals=self._extract_goals_from_program(program)) goals = GoalList(goals=self._extract_current_goals())
message = InternalMessage( message = InternalMessage(
to=settings.agent_settings.text_belief_extractor_name, to=settings.agent_settings.text_belief_extractor_name,
@@ -137,12 +167,34 @@ class BDIProgramManager(BaseAgent):
await self.send(message) await self.send(message)
async def _send_clear_llm_history(self):
"""
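Clear the conversation history of the LLM Agent and the Text Belief Extractor Agent.
Sends a ``clear_history`` message to the LLM Agent and a ``reset`` message on the ``conversation_history`` thread to the Text Belief Extractor Agent, so that both discard their conversation state.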
"""
message = InternalMessage(
to=settings.agent_settings.llm_name,
body="clear_history",
)
await self.send(message)
self.logger.debug("Sent message to LLM agent to clear history.")
extractor_msg = InternalMessage(
to=settings.agent_settings.text_belief_extractor_name,
thread="conversation_history",
body="reset",
)
await self.send(extractor_msg)
self.logger.debug("Sent message to extractor agent to clear history.")
async def _receive_programs(self): async def _receive_programs(self):
""" """
Continuous loop that receives program updates from the HTTP endpoint. Continuous loop that receives program updates from the HTTP endpoint.
It listens to the ``program`` topic on the internal ZMQ SUB socket. It listens to the ``program`` topic on the internal ZMQ SUB socket.
When a program is received, it is validated and forwarded to BDI via :meth:`_create_agentspeak_and_send_to_bdi`. When a program is received, it is validated and forwarded to BDI via :meth:`_create_agentspeak_and_send_to_bdi`.
Additionally, the LLM history is cleared via :meth:`_send_clear_llm_history`.
""" """
while True: while True:
topic, body = await self.sub_socket.recv_multipart() topic, body = await self.sub_socket.recv_multipart()
@@ -150,13 +202,17 @@ class BDIProgramManager(BaseAgent):
try: try:
program = Program.model_validate_json(body) program = Program.model_validate_json(body)
except ValidationError: except ValidationError:
self.logger.exception("Received an invalid program.") self.logger.warning("Received an invalid program.")
continue continue
self._initialize_internal_state(program)
await self._send_clear_llm_history()
await asyncio.gather( await asyncio.gather(
self._create_agentspeak_and_send_to_bdi(program), self._create_agentspeak_and_send_to_bdi(program),
self._send_beliefs_to_semantic_belief_extractor(program), self._send_beliefs_to_semantic_belief_extractor(),
self._send_goals_to_semantic_belief_extractor(program), self._send_goals_to_semantic_belief_extractor(),
) )
async def setup(self): async def setup(self):
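The phase transition travels between the two agents as a small JSON object with `old` and `new` keys, and the program manager then looks up the new phase by its id. The sketch below isolates that lookup as a pure function, under a few stated assumptions: `PhaseSketch` is a hypothetical stand-in for the project's `Phase` model, and the explicit `ValueError`s are a choice made for this sketch, whereas the code in the diff uses an `assert` and leaves the phase unchanged when no id matches.

```python
import json
from dataclasses import dataclass


@dataclass
class PhaseSketch:
    """Hypothetical stand-in for the project's Phase model."""

    id: str


def next_phase(phases, current, body):
    """Return the phase that becomes current after a transition message.

    `body` is the JSON text built on the sending side with
    json.dumps({"old": ..., "new": ...}).
    """
    payload = json.loads(body)
    if current is None or payload["old"] != str(current.id):
        raise ValueError("transition does not start from the current phase")
    if payload["new"] == "end":
        return None  # no phase left: the program has finished
    for phase in phases:
        if str(phase.id) == payload["new"]:
            return phase
    raise ValueError(f"unknown phase id: {payload['new']}")


phases = [PhaseSketch("p1"), PhaseSketch("p2")]
current = next_phase(phases, phases[0], json.dumps({"old": "p1", "new": "p2"}))
finished = next_phase(phases, current, json.dumps({"old": "p2", "new": "end"}))
```

Keeping the lookup free of side effects makes it easy to unit test separately from the message sending that follows it.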


@@ -1,5 +1,6 @@
norms(""). norms("").
+user_said(Message) : norms(Norms) <- +user_said(Message) : norms(Norms) <-
.notify_user_said(Message);
-user_said(Message); -user_said(Message);
.reply(Message, Norms). .reply(Message, Norms).


@@ -90,7 +90,7 @@ class TextBeliefExtractorAgent(BaseAgent):
self.logger.debug("Received text from LLM: %s", msg.body) self.logger.debug("Received text from LLM: %s", msg.body)
self._apply_conversation_message(ChatMessage(role="assistant", content=msg.body)) self._apply_conversation_message(ChatMessage(role="assistant", content=msg.body))
case settings.agent_settings.bdi_program_manager_name: case settings.agent_settings.bdi_program_manager_name:
self._handle_program_manager_message(msg) await self._handle_program_manager_message(msg)
case _: case _:
self.logger.info("Discarding message from %s", sender) self.logger.info("Discarding message from %s", sender)
return return
@@ -105,7 +105,7 @@ class TextBeliefExtractorAgent(BaseAgent):
length_limit = settings.behaviour_settings.conversation_history_length_limit length_limit = settings.behaviour_settings.conversation_history_length_limit
self.conversation.messages = (self.conversation.messages + [message])[-length_limit:] self.conversation.messages = (self.conversation.messages + [message])[-length_limit:]
def _handle_program_manager_message(self, msg: InternalMessage): async def _handle_program_manager_message(self, msg: InternalMessage):
""" """
Handle a message from the program manager: extract available beliefs and goals from it. Handle a message from the program manager: extract available beliefs and goals from it.
@@ -114,8 +114,10 @@ class TextBeliefExtractorAgent(BaseAgent):
match msg.thread: match msg.thread:
case "beliefs": case "beliefs":
self._handle_beliefs_message(msg) self._handle_beliefs_message(msg)
await self._infer_new_beliefs()
case "goals": case "goals":
self._handle_goals_message(msg) self._handle_goals_message(msg)
await self._infer_goal_completions()
case "conversation_history": case "conversation_history":
if msg.body == "reset": if msg.body == "reset":
self._reset() self._reset()
@@ -141,8 +143,9 @@ class TextBeliefExtractorAgent(BaseAgent):
available_beliefs = [b for b in belief_list.beliefs if isinstance(b, SemanticBelief)] available_beliefs = [b for b in belief_list.beliefs if isinstance(b, SemanticBelief)]
self.belief_inferrer.available_beliefs = available_beliefs self.belief_inferrer.available_beliefs = available_beliefs
self.logger.debug( self.logger.debug(
"Received %d semantic beliefs from the program manager.", "Received %d semantic beliefs from the program manager: %s",
len(available_beliefs), len(available_beliefs),
", ".join(b.name for b in available_beliefs),
) )
def _handle_goals_message(self, msg: InternalMessage): def _handle_goals_message(self, msg: InternalMessage):
@@ -158,8 +161,9 @@ class TextBeliefExtractorAgent(BaseAgent):
available_goals = [g for g in goals_list.goals if g.can_fail] available_goals = [g for g in goals_list.goals if g.can_fail]
self.goal_inferrer.goals = available_goals self.goal_inferrer.goals = available_goals
self.logger.debug( self.logger.debug(
"Received %d failable goals from the program manager.", "Received %d failable goals from the program manager: %s",
len(available_goals), len(available_goals),
", ".join(g.name for g in available_goals),
) )
async def _user_said(self, text: str): async def _user_said(self, text: str):
@@ -183,6 +187,7 @@ class TextBeliefExtractorAgent(BaseAgent):
new_beliefs = conversation_beliefs - self._current_beliefs new_beliefs = conversation_beliefs - self._current_beliefs
if not new_beliefs: if not new_beliefs:
self.logger.debug("No new beliefs detected.")
return return
self._current_beliefs |= new_beliefs self._current_beliefs |= new_beliefs
@@ -217,6 +222,7 @@ class TextBeliefExtractorAgent(BaseAgent):
self._current_goal_completions[goal] = achieved self._current_goal_completions[goal] = achieved
if not new_achieved and not new_not_achieved: if not new_achieved and not new_not_achieved:
self.logger.debug("No goal achievement changes detected.")
return return
belief_changes = BeliefMessage( belief_changes = BeliefMessage(


@@ -38,7 +38,7 @@ class RICommunicationAgent(BaseAgent):
def __init__( def __init__(
self, self,
name: str, name: str,
address=settings.zmq_settings.ri_command_address, address=settings.zmq_settings.ri_communication_address,
bind=False, bind=False,
): ):
super().__init__(name) super().__init__(name)
@@ -168,7 +168,7 @@ class RICommunicationAgent(BaseAgent):
bind = port_data["bind"] bind = port_data["bind"]
if not bind: if not bind:
addr = f"tcp://localhost:{port}" addr = f"tcp://{settings.ri_host}:{port}"
else: else:
addr = f"tcp://*:{port}" addr = f"tcp://*:{port}"
@@ -248,6 +248,7 @@ class RICommunicationAgent(BaseAgent):
self._req_socket.recv_json(), timeout=seconds_to_wait_total / 2 self._req_socket.recv_json(), timeout=seconds_to_wait_total / 2
) )
if "endpoint" in message and message["endpoint"] != "ping":
self.logger.debug(f'Received message "{message}" from RI.') self.logger.debug(f'Received message "{message}" from RI.')
if "endpoint" not in message: if "endpoint" not in message:
self.logger.warning("No received endpoint in message, expected ping endpoint.") self.logger.warning("No received endpoint in message, expected ping endpoint.")
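
The second hunk above swaps the hard-coded `localhost` for `settings.ri_host` when the agent connects instead of binding. A minimal sketch of that address selection, assuming a standalone helper (`build_ri_address` is a hypothetical name and does not exist in this repository):

```python
def build_ri_address(port: int, bind: bool, ri_host: str = "localhost") -> str:
    """Return the ZMQ TCP address for a negotiated port.

    Hypothetical helper mirroring the branch in the diff: bind on all
    interfaces, or connect to the configured Robot Interface host.
    """
    if bind:
        return f"tcp://*:{port}"
    return f"tcp://{ri_host}:{port}"


# Example host name is made up for illustration.
print(build_ri_address(5556, bind=False, ri_host="pepper.local"))
print(build_ri_address(5556, bind=True))
```

Keeping the host in one setting means a Robot Interface on another machine only needs `RI_HOST` changed, rather than edits to every connect call.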


@@ -46,14 +46,23 @@ class LLMAgent(BaseAgent):
:param msg: The received internal message. :param msg: The received internal message.
""" """
if msg.sender == settings.agent_settings.bdi_core_name: if msg.sender == settings.agent_settings.bdi_core_name:
self.logger.debug("Processing message from BDI core.") match msg.thread:
case "prompt_message":
try: try:
prompt_message = LLMPromptMessage.model_validate_json(msg.body) prompt_message = LLMPromptMessage.model_validate_json(msg.body)
await self._process_bdi_message(prompt_message) await self._process_bdi_message(prompt_message)
except ValidationError: except ValidationError:
self.logger.debug("Prompt message from BDI core is invalid.") self.logger.debug("Prompt message from BDI core is invalid.")
case "assistant_message":
self.history.append({"role": "assistant", "content": msg.body})
case "user_message":
self.history.append({"role": "user", "content": msg.body})
elif msg.sender == settings.agent_settings.bdi_program_manager_name:
if msg.body == "clear_history":
self.logger.debug("Clearing conversation history.")
self.history.clear()
else: else:
self.logger.debug("Message ignored (not from BDI core.") self.logger.debug("Message ignored.")
async def _process_bdi_message(self, message: LLMPromptMessage): async def _process_bdi_message(self, message: LLMPromptMessage):
""" """
@@ -114,13 +123,6 @@ class LLMAgent(BaseAgent):
:param goals: Goals the LLM should achieve. :param goals: Goals the LLM should achieve.
:yield: Fragments of the LLM-generated content (e.g., sentences/phrases). :yield: Fragments of the LLM-generated content (e.g., sentences/phrases).
""" """
self.history.append(
{
"role": "user",
"content": prompt,
}
)
instructions = LLMInstructions(norms if norms else None, goals if goals else None) instructions = LLMInstructions(norms if norms else None, goals if goals else None)
messages = [ messages = [
{ {


@@ -103,12 +103,11 @@ class VADAgent(BaseAgent):
self._connect_audio_in_socket() self._connect_audio_in_socket()
audio_out_port = self._connect_audio_out_socket() audio_out_address = self._connect_audio_out_socket()
if audio_out_port is None: if audio_out_address is None:
self.logger.error("Could not bind output socket, stopping.") self.logger.error("Could not bind output socket, stopping.")
await self.stop() await self.stop()
return return
audio_out_address = f"tcp://localhost:{audio_out_port}"
# Connect to internal communication socket # Connect to internal communication socket
self.program_sub_socket = azmq.Context.instance().socket(zmq.SUB) self.program_sub_socket = azmq.Context.instance().socket(zmq.SUB)
@@ -161,13 +160,14 @@ class VADAgent(BaseAgent):
self.audio_in_socket.connect(self.audio_in_address) self.audio_in_socket.connect(self.audio_in_address)
self.audio_in_poller = SocketPoller[bytes](self.audio_in_socket) self.audio_in_poller = SocketPoller[bytes](self.audio_in_socket)
def _connect_audio_out_socket(self) -> int | None: def _connect_audio_out_socket(self) -> str | None:
""" """
Returns the port bound, or None if binding failed. Returns the address that was bound to, or None if binding failed.
""" """
try: try:
self.audio_out_socket = azmq.Context.instance().socket(zmq.PUB) self.audio_out_socket = azmq.Context.instance().socket(zmq.PUB)
return self.audio_out_socket.bind_to_random_port("tcp://localhost", max_tries=100) self.audio_out_socket.bind(settings.zmq_settings.vad_pub_address)
return settings.zmq_settings.vad_pub_address
except zmq.ZMQBindError: except zmq.ZMQBindError:
self.logger.error("Failed to bind an audio output socket after 100 tries.") self.logger.error("Failed to bind an audio output socket after 100 tries.")
self.audio_out_socket = None self.audio_out_socket = None
@@ -229,10 +229,11 @@ class VADAgent(BaseAgent):
assert self.model is not None assert self.model is not None
prob = self.model(torch.from_numpy(chunk), settings.vad_settings.sample_rate_hz).item() prob = self.model(torch.from_numpy(chunk), settings.vad_settings.sample_rate_hz).item()
non_speech_patience = settings.behaviour_settings.vad_non_speech_patience_chunks non_speech_patience = settings.behaviour_settings.vad_non_speech_patience_chunks
begin_silence_length = settings.behaviour_settings.vad_begin_silence_chunks
prob_threshold = settings.behaviour_settings.vad_prob_threshold prob_threshold = settings.behaviour_settings.vad_prob_threshold
if prob > prob_threshold: if prob > prob_threshold:
if self.i_since_speech > non_speech_patience: if self.i_since_speech > non_speech_patience + begin_silence_length:
self.logger.debug("Speech started.") self.logger.debug("Speech started.")
self.audio_buffer = np.append(self.audio_buffer, chunk) self.audio_buffer = np.append(self.audio_buffer, chunk)
self.i_since_speech = 0 self.i_since_speech = 0
@@ -246,11 +247,12 @@ class VADAgent(BaseAgent):
continue continue
# Speech probably ended. Make sure we have a usable amount of data. # Speech probably ended. Make sure we have a usable amount of data.
if len(self.audio_buffer) >= 3 * len(chunk): if len(self.audio_buffer) > begin_silence_length * len(chunk):
self.logger.debug("Speech ended.") self.logger.debug("Speech ended.")
assert self.audio_out_socket is not None assert self.audio_out_socket is not None
await self.audio_out_socket.send(self.audio_buffer[: -2 * len(chunk)].tobytes()) await self.audio_out_socket.send(self.audio_buffer[: -2 * len(chunk)].tobytes())
# At this point, we know that the speech has ended. # At this point, we know that there is no speech.
# Prepend the last chunk that had no speech, for a more fluent boundary # Prepend the last few chunks that had no speech, for a more fluent boundary.
self.audio_buffer = chunk self.audio_buffer = np.append(self.audio_buffer, chunk)
self.audio_buffer = self.audio_buffer[-begin_silence_length * len(chunk) :]
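
The last hunk keeps a rolling window of recent non-speech audio so that the next utterance starts with a little silence. A minimal sketch of that trimming step, using plain lists instead of NumPy arrays (slicing behaves the same for 1-D arrays); `keep_trailing_silence` is a hypothetical name, and the zero guard is an addition of this sketch, not part of the diff:

```python
def keep_trailing_silence(
    buffer: list[float], chunk: list[float], begin_silence_chunks: int
) -> list[float]:
    """Append `chunk`, then keep only the last N chunks' worth of samples."""
    combined = buffer + chunk
    keep = begin_silence_chunks * len(chunk)
    if keep <= 0:
        # Guard: `combined[-0:]` returns the whole list, not an empty one.
        return []
    return combined[-keep:]


# Chunks of two samples, keeping two chunks of leading silence.
print(keep_trailing_silence([1, 2, 3, 4, 5, 6], [7, 8], begin_silence_chunks=2))
```

If `vad_begin_silence_chunks` were ever configured as 0, the slice in the diff would seemingly keep the entire buffer rather than none of it, so a guard like the one above may be worth considering.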


@@ -131,6 +131,7 @@ class BaseAgent(ABC):
:param message: The message to send. :param message: The message to send.
""" """
target = AgentDirectory.get(message.to) target = AgentDirectory.get(message.to)
message.sender = self.name
if target: if target:
await target.inbox.put(message) await target.inbox.put(message)
self.logger.debug(f"Sent message {message.body} to {message.to} via regular inbox.") self.logger.debug(f"Sent message {message.body} to {message.to} via regular inbox.")


@@ -1,3 +1,12 @@
"""
An exhaustive overview of configurable options. All of these can be set using environment variables
by nesting with double underscores (__). Start from the ``Settings`` class.
For example, ``settings.ri_host`` becomes ``RI_HOST``, and
``settings.zmq_settings.ri_communication_address`` becomes
``ZMQ_SETTINGS__RI_COMMUNICATION_ADDRESS``.
"""
from pydantic import BaseModel from pydantic import BaseModel
from pydantic_settings import BaseSettings, SettingsConfigDict from pydantic_settings import BaseSettings, SettingsConfigDict
@@ -8,16 +17,17 @@ class ZMQSettings(BaseModel):
:ivar internal_pub_address: Address for the internal PUB socket. :ivar internal_pub_address: Address for the internal PUB socket.
:ivar internal_sub_address: Address for the internal SUB socket. :ivar internal_sub_address: Address for the internal SUB socket.
:ivar ri_command_address: Address for sending commands to the Robot Interface. :ivar ri_communication_address: Address for the endpoint that the Robot Interface connects to.
:ivar ri_communication_address: Address for receiving communication from the Robot Interface. :ivar vad_pub_address: Address that the VAD agent binds to and publishes audio segments to.
:ivar vad_agent_address: Address for the Voice Activity Detection (VAD) agent.
""" """
# ATTENTION: When adding/removing settings, make sure to update the .env.example file
internal_pub_address: str = "tcp://localhost:5560" internal_pub_address: str = "tcp://localhost:5560"
internal_sub_address: str = "tcp://localhost:5561" internal_sub_address: str = "tcp://localhost:5561"
ri_command_address: str = "tcp://localhost:0000"
ri_communication_address: str = "tcp://*:5555" ri_communication_address: str = "tcp://*:5555"
internal_gesture_rep_adress: str = "tcp://localhost:7788" internal_gesture_rep_adress: str = "tcp://localhost:7788"
vad_pub_address: str = "inproc://vad_stream"
class AgentSettings(BaseModel): class AgentSettings(BaseModel):
@@ -36,6 +46,8 @@ class AgentSettings(BaseModel):
:ivar robot_speech_name: Name of the Robot Speech Agent. :ivar robot_speech_name: Name of the Robot Speech Agent.
""" """
# ATTENTION: When adding/removing settings, make sure to update the .env.example file
# agent names # agent names
bdi_core_name: str = "bdi_core_agent" bdi_core_name: str = "bdi_core_agent"
bdi_belief_collector_name: str = "belief_collector_agent" bdi_belief_collector_name: str = "belief_collector_agent"
@@ -61,6 +73,7 @@ class BehaviourSettings(BaseModel):
:ivar vad_prob_threshold: Probability threshold for Voice Activity Detection. :ivar vad_prob_threshold: Probability threshold for Voice Activity Detection.
:ivar vad_initial_since_speech: Initial value for 'since speech' counter in VAD. :ivar vad_initial_since_speech: Initial value for 'since speech' counter in VAD.
:ivar vad_non_speech_patience_chunks: Number of non-speech chunks to wait before speech ended. :ivar vad_non_speech_patience_chunks: Number of non-speech chunks to wait before speech ended.
:ivar vad_begin_silence_chunks: The number of chunks of silence to prepend to speech chunks.
:ivar transcription_max_concurrent_tasks: Maximum number of concurrent transcription tasks. :ivar transcription_max_concurrent_tasks: Maximum number of concurrent transcription tasks.
:ivar transcription_words_per_minute: Estimated words per minute for transcription timing. :ivar transcription_words_per_minute: Estimated words per minute for transcription timing.
:ivar transcription_words_per_token: Estimated words per token for transcription timing. :ivar transcription_words_per_token: Estimated words per token for transcription timing.
@@ -68,6 +81,8 @@ class BehaviourSettings(BaseModel):
:ivar conversation_history_length_limit: The maximum amount of messages to extract beliefs from. :ivar conversation_history_length_limit: The maximum amount of messages to extract beliefs from.
""" """
# ATTENTION: When adding/removing settings, make sure to update the .env.example file
sleep_s: float = 1.0 sleep_s: float = 1.0
comm_setup_max_retries: int = 5 comm_setup_max_retries: int = 5
socket_poller_timeout_ms: int = 100 socket_poller_timeout_ms: int = 100
@@ -75,7 +90,8 @@ class BehaviourSettings(BaseModel):
# VAD settings # VAD settings
vad_prob_threshold: float = 0.5 vad_prob_threshold: float = 0.5
vad_initial_since_speech: int = 100 vad_initial_since_speech: int = 100
vad_non_speech_patience_chunks: int = 3 vad_non_speech_patience_chunks: int = 15
vad_begin_silence_chunks: int = 6
# transcription behaviour # transcription behaviour
transcription_max_concurrent_tasks: int = 3 transcription_max_concurrent_tasks: int = 3
@@ -99,6 +115,8 @@ class LLMSettings(BaseModel):
:ivar n_parallel: The number of parallel calls allowed to be made to the LLM. :ivar n_parallel: The number of parallel calls allowed to be made to the LLM.
""" """
# ATTENTION: When adding/removing settings, make sure to update the .env.example file
local_llm_url: str = "http://localhost:1234/v1/chat/completions" local_llm_url: str = "http://localhost:1234/v1/chat/completions"
local_llm_model: str = "gpt-oss" local_llm_model: str = "gpt-oss"
chat_temperature: float = 1.0 chat_temperature: float = 1.0
@@ -115,6 +133,8 @@ class VADSettings(BaseModel):
:ivar sample_rate_hz: Sample rate in Hz for the VAD model. :ivar sample_rate_hz: Sample rate in Hz for the VAD model.
""" """
# ATTENTION: When adding/removing settings, make sure to update the .env.example file
repo_or_dir: str = "snakers4/silero-vad" repo_or_dir: str = "snakers4/silero-vad"
model_name: str = "silero_vad" model_name: str = "silero_vad"
sample_rate_hz: int = 16000 sample_rate_hz: int = 16000
@@ -128,6 +148,8 @@ class SpeechModelSettings(BaseModel):
:ivar openai_model_name: Model name for OpenAI-based speech recognition. :ivar openai_model_name: Model name for OpenAI-based speech recognition.
""" """
# ATTENTION: When adding/removing settings, make sure to update the .env.example file
# model identifiers for speech recognition # model identifiers for speech recognition
mlx_model_name: str = "mlx-community/whisper-small.en-mlx" mlx_model_name: str = "mlx-community/whisper-small.en-mlx"
openai_model_name: str = "small.en" openai_model_name: str = "small.en"
@@ -139,6 +161,7 @@ class Settings(BaseSettings):
:ivar app_title: Title of the application. :ivar app_title: Title of the application.
:ivar ui_url: URL of the frontend UI. :ivar ui_url: URL of the frontend UI.
:ivar ri_host: The hostname of the Robot Interface.
:ivar zmq_settings: ZMQ configuration. :ivar zmq_settings: ZMQ configuration.
:ivar agent_settings: Agent name configuration. :ivar agent_settings: Agent name configuration.
:ivar behaviour_settings: Behavior configuration. :ivar behaviour_settings: Behavior configuration.
@@ -151,6 +174,8 @@ class Settings(BaseSettings):
ui_url: str = "http://localhost:5173" ui_url: str = "http://localhost:5173"
ri_host: str = "localhost"
zmq_settings: ZMQSettings = ZMQSettings() zmq_settings: ZMQSettings = ZMQSettings()
agent_settings: AgentSettings = AgentSettings() agent_settings: AgentSettings = AgentSettings()
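
The module docstring added at the top of this file describes how a settings path maps to an environment variable name: upper-case each segment and join nested segments with a double underscore. The naming rule alone can be sketched as below. `env_var_name` is a hypothetical helper for illustration; the actual lookup is presumably done by pydantic-settings, whose `model_config` is not shown in this excerpt.

```python
def env_var_name(dotted_path: str, delimiter: str = "__") -> str:
    """Translate e.g. 'settings.zmq_settings.x' into 'ZMQ_SETTINGS__X'."""
    # Drop the leading 'settings.' object name if present (Python 3.9+).
    parts = dotted_path.removeprefix("settings.").split(".")
    return delimiter.join(part.upper() for part in parts)


print(env_var_name("settings.ri_host"))
print(env_var_name("settings.zmq_settings.ri_communication_address"))
```

The two calls reproduce the examples given in the docstring, `RI_HOST` and `ZMQ_SETTINGS__RI_COMMUNICATION_ADDRESS`.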


@@ -12,6 +12,6 @@ class InternalMessage(BaseModel):
""" """
to: str to: str
sender: str sender: str | None = None
body: str body: str
thread: str | None = None thread: str | None = None


@@ -180,7 +180,6 @@ class Trigger(ProgramElement):
:ivar plan: The plan to execute. :ivar plan: The plan to execute.
""" """
name: str = ""
condition: Belief condition: Belief
plan: Plan plan: Plan


@@ -91,7 +91,7 @@ def test_out_socket_creation(zmq_context):
assert per_vad_agent.audio_out_socket is not None assert per_vad_agent.audio_out_socket is not None
zmq_context.return_value.socket.assert_called_once_with(zmq.PUB) zmq_context.return_value.socket.assert_called_once_with(zmq.PUB)
zmq_context.return_value.socket.return_value.bind_to_random_port.assert_called_once() zmq_context.return_value.socket.return_value.bind.assert_called_once_with("inproc://vad_stream")
@pytest.mark.asyncio @pytest.mark.asyncio


@@ -73,7 +73,7 @@ async def test_setup_connect(zmq_context, mocker):
async def test_handle_message_sends_valid_gesture_command(): async def test_handle_message_sends_valid_gesture_command():
"""Internal message with valid gesture tag is forwarded to robot pub socket.""" """Internal message with valid gesture tag is forwarded to robot pub socket."""
pubsocket = AsyncMock() pubsocket = AsyncMock()
agent = RobotGestureAgent("robot_gesture", gesture_data=["hello", "yes", "no"]) agent = RobotGestureAgent("robot_gesture", gesture_data=["hello", "yes", "no"], address="")
agent.pubsocket = pubsocket agent.pubsocket = pubsocket
payload = { payload = {
@@ -91,7 +91,7 @@ async def test_handle_message_sends_valid_gesture_command():
async def test_handle_message_sends_non_gesture_command(): async def test_handle_message_sends_non_gesture_command():
"""Internal message with non-gesture endpoint is not forwarded by this agent.""" """Internal message with non-gesture endpoint is not forwarded by this agent."""
pubsocket = AsyncMock() pubsocket = AsyncMock()
agent = RobotGestureAgent("robot_gesture", gesture_data=["hello", "yes", "no"]) agent = RobotGestureAgent("robot_gesture", gesture_data=["hello", "yes", "no"], address="")
agent.pubsocket = pubsocket agent.pubsocket = pubsocket
payload = {"endpoint": "some_other_endpoint", "data": "invalid_tag_not_in_list"} payload = {"endpoint": "some_other_endpoint", "data": "invalid_tag_not_in_list"}
@@ -107,7 +107,7 @@ async def test_handle_message_sends_non_gesture_command():
async def test_handle_message_rejects_invalid_gesture_tag(): async def test_handle_message_rejects_invalid_gesture_tag():
"""Internal message with invalid gesture tag is not forwarded.""" """Internal message with invalid gesture tag is not forwarded."""
pubsocket = AsyncMock() pubsocket = AsyncMock()
agent = RobotGestureAgent("robot_gesture", gesture_data=["hello", "yes", "no"]) agent = RobotGestureAgent("robot_gesture", gesture_data=["hello", "yes", "no"], address="")
agent.pubsocket = pubsocket agent.pubsocket = pubsocket
# Use a tag that's not in gesture_data # Use a tag that's not in gesture_data
@@ -123,7 +123,7 @@ async def test_handle_message_rejects_invalid_gesture_tag():
async def test_handle_message_invalid_payload(): async def test_handle_message_invalid_payload():
"""Invalid payload is caught and does not send.""" """Invalid payload is caught and does not send."""
pubsocket = AsyncMock() pubsocket = AsyncMock()
agent = RobotGestureAgent("robot_gesture", gesture_data=["hello", "yes", "no"]) agent = RobotGestureAgent("robot_gesture", gesture_data=["hello", "yes", "no"], address="")
agent.pubsocket = pubsocket agent.pubsocket = pubsocket
msg = InternalMessage(to="robot", sender="tester", body=json.dumps({"bad": "data"})) msg = InternalMessage(to="robot", sender="tester", body=json.dumps({"bad": "data"}))
@@ -142,12 +142,12 @@ async def test_zmq_command_loop_valid_gesture_payload():
async def recv_once(): async def recv_once():
# stop after first iteration # stop after first iteration
agent._running = False agent._running = False
return (b"command", json.dumps(command).encode("utf-8")) return b"command", json.dumps(command).encode("utf-8")
fake_socket.recv_multipart = recv_once fake_socket.recv_multipart = recv_once
fake_socket.send_json = AsyncMock() fake_socket.send_json = AsyncMock()
agent = RobotGestureAgent("robot_gesture", gesture_data=["hello", "yes", "no"]) agent = RobotGestureAgent("robot_gesture", gesture_data=["hello", "yes", "no"], address="")
agent.subsocket = fake_socket agent.subsocket = fake_socket
agent.pubsocket = fake_socket agent.pubsocket = fake_socket
agent._running = True agent._running = True
@@ -165,12 +165,12 @@ async def test_zmq_command_loop_valid_non_gesture_payload():
async def recv_once(): async def recv_once():
agent._running = False agent._running = False
return (b"command", json.dumps(command).encode("utf-8")) return b"command", json.dumps(command).encode("utf-8")
fake_socket.recv_multipart = recv_once fake_socket.recv_multipart = recv_once
fake_socket.send_json = AsyncMock() fake_socket.send_json = AsyncMock()
agent = RobotGestureAgent("robot_gesture", gesture_data=["hello", "yes", "no"]) agent = RobotGestureAgent("robot_gesture", gesture_data=["hello", "yes", "no"], address="")
agent.subsocket = fake_socket agent.subsocket = fake_socket
agent.pubsocket = fake_socket agent.pubsocket = fake_socket
agent._running = True agent._running = True
@@ -188,12 +188,12 @@ async def test_zmq_command_loop_invalid_gesture_tag():
async def recv_once(): async def recv_once():
agent._running = False agent._running = False
return (b"command", json.dumps(command).encode("utf-8")) return b"command", json.dumps(command).encode("utf-8")
fake_socket.recv_multipart = recv_once fake_socket.recv_multipart = recv_once
fake_socket.send_json = AsyncMock() fake_socket.send_json = AsyncMock()
agent = RobotGestureAgent("robot_gesture", gesture_data=["hello", "yes", "no"]) agent = RobotGestureAgent("robot_gesture", gesture_data=["hello", "yes", "no"], address="")
agent.subsocket = fake_socket agent.subsocket = fake_socket
agent.pubsocket = fake_socket agent.pubsocket = fake_socket
agent._running = True agent._running = True
@@ -210,12 +210,12 @@ async def test_zmq_command_loop_invalid_json():
async def recv_once(): async def recv_once():
agent._running = False agent._running = False
return (b"command", b"{not_json}") return b"command", b"{not_json}"
fake_socket.recv_multipart = recv_once fake_socket.recv_multipart = recv_once
fake_socket.send_json = AsyncMock() fake_socket.send_json = AsyncMock()
agent = RobotGestureAgent("robot_gesture", gesture_data=["hello", "yes", "no"]) agent = RobotGestureAgent("robot_gesture", gesture_data=["hello", "yes", "no"], address="")
agent.subsocket = fake_socket agent.subsocket = fake_socket
agent.pubsocket = fake_socket agent.pubsocket = fake_socket
agent._running = True agent._running = True
@@ -232,12 +232,12 @@ async def test_zmq_command_loop_ignores_send_gestures_topic():
async def recv_once(): async def recv_once():
agent._running = False agent._running = False
return (b"send_gestures", b"{}") return b"send_gestures", b"{}"
fake_socket.recv_multipart = recv_once fake_socket.recv_multipart = recv_once
fake_socket.send_json = AsyncMock() fake_socket.send_json = AsyncMock()
agent = RobotGestureAgent("robot_gesture", gesture_data=["hello", "yes", "no"]) agent = RobotGestureAgent("robot_gesture", gesture_data=["hello", "yes", "no"], address="")
agent.subsocket = fake_socket agent.subsocket = fake_socket
agent.pubsocket = fake_socket agent.pubsocket = fake_socket
agent._running = True agent._running = True
@@ -259,7 +259,9 @@ async def test_fetch_gestures_loop_without_amount():
fake_repsocket.recv = recv_once fake_repsocket.recv = recv_once
fake_repsocket.send = AsyncMock() fake_repsocket.send = AsyncMock()
agent = RobotGestureAgent("robot_gesture", gesture_data=["hello", "yes", "no", "wave", "point"]) agent = RobotGestureAgent(
"robot_gesture", gesture_data=["hello", "yes", "no", "wave", "point"], address=""
)
agent.repsocket = fake_repsocket agent.repsocket = fake_repsocket
agent._running = True agent._running = True
@@ -287,7 +289,9 @@ async def test_fetch_gestures_loop_with_amount():
fake_repsocket.recv = recv_once fake_repsocket.recv = recv_once
fake_repsocket.send = AsyncMock() fake_repsocket.send = AsyncMock()
agent = RobotGestureAgent("robot_gesture", gesture_data=["hello", "yes", "no", "wave", "point"]) agent = RobotGestureAgent(
"robot_gesture", gesture_data=["hello", "yes", "no", "wave", "point"], address=""
)
agent.repsocket = fake_repsocket agent.repsocket = fake_repsocket
agent._running = True agent._running = True
@@ -315,7 +319,7 @@ async def test_fetch_gestures_loop_with_integer_request():
fake_repsocket.recv = recv_once fake_repsocket.recv = recv_once
fake_repsocket.send = AsyncMock() fake_repsocket.send = AsyncMock()
agent = RobotGestureAgent("robot_gesture", gesture_data=["hello", "yes", "no"]) agent = RobotGestureAgent("robot_gesture", gesture_data=["hello", "yes", "no"], address="")
agent.repsocket = fake_repsocket agent.repsocket = fake_repsocket
agent._running = True agent._running = True
@@ -340,7 +344,7 @@ async def test_fetch_gestures_loop_with_invalid_json():
fake_repsocket.recv = recv_once fake_repsocket.recv = recv_once
fake_repsocket.send = AsyncMock() fake_repsocket.send = AsyncMock()
agent = RobotGestureAgent("robot_gesture", gesture_data=["hello", "yes", "no"]) agent = RobotGestureAgent("robot_gesture", gesture_data=["hello", "yes", "no"], address="")
agent.repsocket = fake_repsocket agent.repsocket = fake_repsocket
agent._running = True agent._running = True
@@ -365,7 +369,7 @@ async def test_fetch_gestures_loop_with_non_integer_json():
fake_repsocket.recv = recv_once fake_repsocket.recv = recv_once
fake_repsocket.send = AsyncMock() fake_repsocket.send = AsyncMock()
agent = RobotGestureAgent("robot_gesture", gesture_data=["hello", "yes", "no"]) agent = RobotGestureAgent("robot_gesture", gesture_data=["hello", "yes", "no"], address="")
agent.repsocket = fake_repsocket agent.repsocket = fake_repsocket
agent._running = True agent._running = True
@@ -381,7 +385,7 @@ async def test_fetch_gestures_loop_with_non_integer_json():
def test_gesture_data_attribute(): def test_gesture_data_attribute():
"""Test that gesture_data returns the expected list.""" """Test that gesture_data returns the expected list."""
gesture_data = ["hello", "yes", "no", "wave"] gesture_data = ["hello", "yes", "no", "wave"]
agent = RobotGestureAgent("robot_gesture", gesture_data=gesture_data) agent = RobotGestureAgent("robot_gesture", gesture_data=gesture_data, address="")
assert agent.gesture_data == gesture_data assert agent.gesture_data == gesture_data
assert isinstance(agent.gesture_data, list) assert isinstance(agent.gesture_data, list)
@@ -398,7 +402,7 @@ async def test_stop_closes_sockets():
pubsocket = MagicMock() pubsocket = MagicMock()
subsocket = MagicMock() subsocket = MagicMock()
repsocket = MagicMock() repsocket = MagicMock()
agent = RobotGestureAgent("robot_gesture") agent = RobotGestureAgent("robot_gesture", address="")
agent.pubsocket = pubsocket agent.pubsocket = pubsocket
agent.subsocket = subsocket agent.subsocket = subsocket
agent.repsocket = repsocket agent.repsocket = repsocket
@@ -415,7 +419,7 @@ async def test_stop_closes_sockets():
async def test_initialization_with_custom_gesture_data(): async def test_initialization_with_custom_gesture_data():
"""Agent can be initialized with custom gesture data.""" """Agent can be initialized with custom gesture data."""
custom_gestures = ["custom1", "custom2", "custom3"] custom_gestures = ["custom1", "custom2", "custom3"]
agent = RobotGestureAgent("robot_gesture", gesture_data=custom_gestures) agent = RobotGestureAgent("robot_gesture", gesture_data=custom_gestures, address="")
assert agent.gesture_data == custom_gestures assert agent.gesture_data == custom_gestures
@@ -432,7 +436,7 @@ async def test_fetch_gestures_loop_handles_exception():
fake_repsocket.recv = recv_once fake_repsocket.recv = recv_once
fake_repsocket.send = AsyncMock() fake_repsocket.send = AsyncMock()
agent = RobotGestureAgent("robot_gesture", gesture_data=["hello", "yes", "no"]) agent = RobotGestureAgent("robot_gesture", gesture_data=["hello", "yes", "no"], address="")
agent.repsocket = fake_repsocket agent.repsocket = fake_repsocket
agent.logger = MagicMock() agent.logger = MagicMock()
agent._running = True agent._running = True


@@ -80,6 +80,7 @@ async def test_receive_programs_valid_and_invalid():
manager._internal_pub_socket = AsyncMock() manager._internal_pub_socket = AsyncMock()
manager.sub_socket = sub manager.sub_socket = sub
manager._create_agentspeak_and_send_to_bdi = AsyncMock() manager._create_agentspeak_and_send_to_bdi = AsyncMock()
manager._send_clear_llm_history = AsyncMock()
try: try:
# Will give StopAsyncIteration when the predefined `sub.recv_multipart` side-effects run out # Will give StopAsyncIteration when the predefined `sub.recv_multipart` side-effects run out
@@ -92,3 +93,24 @@ async def test_receive_programs_valid_and_invalid():
forwarded: Program = manager._create_agentspeak_and_send_to_bdi.await_args[0][0] forwarded: Program = manager._create_agentspeak_and_send_to_bdi.await_args[0][0]
assert forwarded.phases[0].norms[0].name == "N1" assert forwarded.phases[0].norms[0].name == "N1"
assert forwarded.phases[0].goals[0].name == "G1" assert forwarded.phases[0].goals[0].name == "G1"
# Verify history clear was triggered
assert manager._send_clear_llm_history.await_count == 1
@pytest.mark.asyncio
async def test_send_clear_llm_history(mock_settings):
# Ensure the mock returns a string for the agent name (just like in your LLM tests)
mock_settings.agent_settings.llm_agent_name = "llm_agent"
manager = BDIProgramManager(name="program_manager_test")
manager.send = AsyncMock()
await manager._send_clear_llm_history()
assert manager.send.await_count == 2
msg: InternalMessage = manager.send.await_args_list[0][0][0]
# Verify the content and recipient
assert msg.body == "clear_history"
assert msg.to == "llm_agent"


@@ -6,10 +6,13 @@ import httpx
import pytest import pytest
from control_backend.agents.bdi import TextBeliefExtractorAgent from control_backend.agents.bdi import TextBeliefExtractorAgent
from control_backend.agents.bdi.text_belief_extractor_agent import BeliefState
from control_backend.core.agent_system import InternalMessage from control_backend.core.agent_system import InternalMessage
from control_backend.core.config import settings from control_backend.core.config import settings
from control_backend.schemas.belief_list import BeliefList from control_backend.schemas.belief_list import BeliefList
from control_backend.schemas.belief_message import Belief as InternalBelief
from control_backend.schemas.belief_message import BeliefMessage from control_backend.schemas.belief_message import BeliefMessage
from control_backend.schemas.chat_history import ChatHistory, ChatMessage
from control_backend.schemas.program import ( from control_backend.schemas.program import (
ConditionalNorm, ConditionalNorm,
KeywordBelief, KeywordBelief,
@@ -23,10 +26,20 @@ from control_backend.schemas.program import (
 @pytest.fixture
-def agent():
+def llm():
+    llm = TextBeliefExtractorAgent.LLM(MagicMock(), 4)
+    llm._query_llm = AsyncMock()
+    return llm
+@pytest.fixture
+def agent(llm):
+    with patch(
+        "control_backend.agents.bdi.text_belief_extractor_agent.TextBeliefExtractorAgent.LLM",
+        return_value=llm,
+    ):
         agent = TextBeliefExtractorAgent("text_belief_agent")
         agent.send = AsyncMock()
-    agent._query_llm = AsyncMock()
         return agent
@@ -102,24 +115,12 @@ async def test_handle_message_from_transcriber(agent, mock_settings):
     agent.send.assert_awaited_once()  # noqa # `agent.send` has no such property, but we mock it.
     sent: InternalMessage = agent.send.call_args.args[0]  # noqa
-    assert sent.to == mock_settings.agent_settings.bdi_belief_collector_name
+    assert sent.to == mock_settings.agent_settings.bdi_core_name
     assert sent.thread == "beliefs"
-    parsed = json.loads(sent.body)
-    assert parsed == {"beliefs": {"user_said": [transcription]}, "type": "belief_extraction_text"}
+    parsed = BeliefMessage.model_validate_json(sent.body)
+    replaced_last = parsed.replace.pop()
+    assert replaced_last.name == "user_said"
+    assert replaced_last.arguments == [transcription]
-@pytest.mark.asyncio
-async def test_process_user_said(agent, mock_settings):
-    transcription = "this is a test"
-    await agent._user_said(transcription)
-    agent.send.assert_awaited_once()  # noqa # `agent.send` has no such property, but we mock it.
-    sent: InternalMessage = agent.send.call_args.args[0]  # noqa
-    assert sent.to == mock_settings.agent_settings.bdi_belief_collector_name
-    assert sent.thread == "beliefs"
-    parsed = json.loads(sent.body)
-    assert parsed["beliefs"]["user_said"] == [transcription]
 @pytest.mark.asyncio
@@ -144,46 +145,46 @@ async def test_query_llm():
         "control_backend.agents.bdi.text_belief_extractor_agent.httpx.AsyncClient",
         return_value=mock_async_client,
     ):
-        agent = TextBeliefExtractorAgent("text_belief_agent")
-        res = await agent._query_llm("hello world", {"type": "null"})
+        llm = TextBeliefExtractorAgent.LLM(MagicMock(), 4)
+        res = await llm._query_llm("hello world", {"type": "null"})
         # Response content was set as "null", so should be deserialized as None
         assert res is None
 @pytest.mark.asyncio
-async def test_retry_query_llm_success(agent):
-    agent._query_llm.return_value = None
-    res = await agent._retry_query_llm("hello world", {"type": "null"})
-    agent._query_llm.assert_called_once()
+async def test_retry_query_llm_success(llm):
+    llm._query_llm.return_value = None
+    res = await llm.query("hello world", {"type": "null"})
+    llm._query_llm.assert_called_once()
     assert res is None
 @pytest.mark.asyncio
-async def test_retry_query_llm_success_after_failure(agent):
-    agent._query_llm.side_effect = [KeyError(), "real value"]
-    res = await agent._retry_query_llm("hello world", {"type": "string"})
-    assert agent._query_llm.call_count == 2
+async def test_retry_query_llm_success_after_failure(llm):
+    llm._query_llm.side_effect = [KeyError(), "real value"]
+    res = await llm.query("hello world", {"type": "string"})
+    assert llm._query_llm.call_count == 2
     assert res == "real value"
 @pytest.mark.asyncio
-async def test_retry_query_llm_failures(agent):
-    agent._query_llm.side_effect = [KeyError(), KeyError(), KeyError(), "real value"]
-    res = await agent._retry_query_llm("hello world", {"type": "string"})
-    assert agent._query_llm.call_count == 3
+async def test_retry_query_llm_failures(llm):
+    llm._query_llm.side_effect = [KeyError(), KeyError(), KeyError(), "real value"]
+    res = await llm.query("hello world", {"type": "string"})
+    assert llm._query_llm.call_count == 3
     assert res is None
 @pytest.mark.asyncio
-async def test_retry_query_llm_fail_immediately(agent):
-    agent._query_llm.side_effect = [KeyError(), "real value"]
-    res = await agent._retry_query_llm("hello world", {"type": "string"}, tries=1)
-    assert agent._query_llm.call_count == 1
+async def test_retry_query_llm_fail_immediately(llm):
+    llm._query_llm.side_effect = [KeyError(), "real value"]
+    res = await llm.query("hello world", {"type": "string"}, tries=1)
+    assert llm._query_llm.call_count == 1
     assert res is None
@@ -192,7 +193,7 @@ async def test_extracting_semantic_beliefs(agent):
     """
     The Program Manager sends beliefs to this agent. Test whether the agent handles them correctly.
     """
-    assert len(agent.available_beliefs) == 0
+    assert len(agent.belief_inferrer.available_beliefs) == 0
     beliefs = BeliefList(
         beliefs=[
             KeywordBelief(
@@ -213,26 +214,28 @@ async def test_extracting_semantic_beliefs(agent):
             to=settings.agent_settings.text_belief_extractor_name,
             sender=settings.agent_settings.bdi_program_manager_name,
             body=beliefs.model_dump_json(),
+            thread="beliefs",
         ),
     )
-    assert len(agent.available_beliefs) == 2
+    assert len(agent.belief_inferrer.available_beliefs) == 2
 @pytest.mark.asyncio
-async def test_handle_invalid_program(agent, sample_program):
-    agent.available_beliefs.append(sample_program.phases[0].norms[0].condition)
-    agent.available_beliefs.append(sample_program.phases[0].triggers[0].condition)
-    assert len(agent.available_beliefs) == 2
+async def test_handle_invalid_beliefs(agent, sample_program):
+    agent.belief_inferrer.available_beliefs.append(sample_program.phases[0].norms[0].condition)
+    agent.belief_inferrer.available_beliefs.append(sample_program.phases[0].triggers[0].condition)
+    assert len(agent.belief_inferrer.available_beliefs) == 2
     await agent.handle_message(
         InternalMessage(
             to=settings.agent_settings.text_belief_extractor_name,
             sender=settings.agent_settings.bdi_program_manager_name,
             body=json.dumps({"phases": "Invalid"}),
+            thread="beliefs",
         ),
     )
-    assert len(agent.available_beliefs) == 2
+    assert len(agent.belief_inferrer.available_beliefs) == 2
 @pytest.mark.asyncio
@@ -254,13 +257,13 @@ async def test_handle_robot_response(agent):
 @pytest.mark.asyncio
-async def test_simulated_real_turn_with_beliefs(agent, sample_program):
+async def test_simulated_real_turn_with_beliefs(agent, llm, sample_program):
     """Test sending user message to extract beliefs from."""
-    agent.available_beliefs.append(sample_program.phases[0].norms[0].condition)
-    agent.available_beliefs.append(sample_program.phases[0].triggers[0].condition)
+    agent.belief_inferrer.available_beliefs.append(sample_program.phases[0].norms[0].condition)
+    agent.belief_inferrer.available_beliefs.append(sample_program.phases[0].triggers[0].condition)
     # Send a user message with the belief that there's no more booze
-    agent._query_llm.return_value = {"is_pirate": None, "no_more_booze": True}
+    llm._query_llm.return_value = {"is_pirate": None, "no_more_booze": True}
     assert len(agent.conversation.messages) == 0
     await agent.handle_message(
         InternalMessage(
@@ -275,20 +278,20 @@ async def test_simulated_real_turn_with_beliefs(agent, sample_program):
     assert agent.send.call_count == 2
     # First should be the beliefs message
-    message: InternalMessage = agent.send.call_args_list[0].args[0]
+    message: InternalMessage = agent.send.call_args_list[1].args[0]
     beliefs = BeliefMessage.model_validate_json(message.body)
     assert len(beliefs.create) == 1
     assert beliefs.create[0].name == "no_more_booze"
 @pytest.mark.asyncio
-async def test_simulated_real_turn_no_beliefs(agent, sample_program):
+async def test_simulated_real_turn_no_beliefs(agent, llm, sample_program):
     """Test a user message to extract beliefs from, but no beliefs are formed."""
-    agent.available_beliefs.append(sample_program.phases[0].norms[0].condition)
-    agent.available_beliefs.append(sample_program.phases[0].triggers[0].condition)
+    agent.belief_inferrer.available_beliefs.append(sample_program.phases[0].norms[0].condition)
+    agent.belief_inferrer.available_beliefs.append(sample_program.phases[0].triggers[0].condition)
     # Send a user message with no new beliefs
-    agent._query_llm.return_value = {"is_pirate": None, "no_more_booze": None}
+    llm._query_llm.return_value = {"is_pirate": None, "no_more_booze": None}
     await agent.handle_message(
         InternalMessage(
             to=settings.agent_settings.text_belief_extractor_name,
@@ -302,17 +305,17 @@ async def test_simulated_real_turn_no_beliefs(agent, sample_program):
 @pytest.mark.asyncio
-async def test_simulated_real_turn_no_new_beliefs(agent, sample_program):
+async def test_simulated_real_turn_no_new_beliefs(agent, llm, sample_program):
     """
     Test a user message to extract beliefs from, but no new beliefs are formed because they already
     existed.
     """
-    agent.available_beliefs.append(sample_program.phases[0].norms[0].condition)
-    agent.available_beliefs.append(sample_program.phases[0].triggers[0].condition)
-    agent.beliefs["is_pirate"] = True
+    agent.belief_inferrer.available_beliefs.append(sample_program.phases[0].norms[0].condition)
+    agent.belief_inferrer.available_beliefs.append(sample_program.phases[0].triggers[0].condition)
+    agent._current_beliefs = BeliefState(true={InternalBelief(name="is_pirate", arguments=None)})
     # Send a user message with the belief the user is a pirate, still
-    agent._query_llm.return_value = {"is_pirate": True, "no_more_booze": None}
+    llm._query_llm.return_value = {"is_pirate": True, "no_more_booze": None}
     await agent.handle_message(
         InternalMessage(
             to=settings.agent_settings.text_belief_extractor_name,
@@ -326,17 +329,19 @@ async def test_simulated_real_turn_no_new_beliefs(agent, sample_program):
 @pytest.mark.asyncio
-async def test_simulated_real_turn_remove_belief(agent, sample_program):
+async def test_simulated_real_turn_remove_belief(agent, llm, sample_program):
     """
     Test a user message to extract beliefs from, but an existing belief is determined no longer to
     hold.
     """
-    agent.available_beliefs.append(sample_program.phases[0].norms[0].condition)
-    agent.available_beliefs.append(sample_program.phases[0].triggers[0].condition)
-    agent.beliefs["no_more_booze"] = True
+    agent.belief_inferrer.available_beliefs.append(sample_program.phases[0].norms[0].condition)
+    agent.belief_inferrer.available_beliefs.append(sample_program.phases[0].triggers[0].condition)
+    agent._current_beliefs = BeliefState(
+        true={InternalBelief(name="no_more_booze", arguments=None)},
+    )
     # Send a user message with the belief the user is a pirate, still
-    agent._query_llm.return_value = {"is_pirate": None, "no_more_booze": False}
+    llm._query_llm.return_value = {"is_pirate": None, "no_more_booze": False}
     await agent.handle_message(
         InternalMessage(
             to=settings.agent_settings.text_belief_extractor_name,
@@ -349,18 +354,23 @@ async def test_simulated_real_turn_remove_belief(agent, sample_program):
     assert agent.send.call_count == 2
     # Agent's current beliefs should've changed
-    assert not agent.beliefs["no_more_booze"]
+    assert any(b.name == "no_more_booze" for b in agent._current_beliefs.false)
 @pytest.mark.asyncio
-async def test_llm_failure_handling(agent, sample_program):
+async def test_llm_failure_handling(agent, llm, sample_program):
     """
     Check that the agent handles failures gracefully without crashing.
     """
-    agent._query_llm.side_effect = httpx.HTTPError("")
-    agent.available_beliefs.append(sample_program.phases[0].norms[0].condition)
-    agent.available_beliefs.append(sample_program.phases[0].triggers[0].condition)
-    belief_changes = await agent._infer_turn()
-    assert len(belief_changes) == 0
+    llm._query_llm.side_effect = httpx.HTTPError("")
+    agent.belief_inferrer.available_beliefs.append(sample_program.phases[0].norms[0].condition)
+    agent.belief_inferrer.available_beliefs.append(sample_program.phases[0].triggers[0].condition)
+    belief_changes = await agent.belief_inferrer.infer_from_conversation(
+        ChatHistory(
+            messages=[ChatMessage(role="user", content="Good day!")],
+        ),
+    )
+    assert len(belief_changes.true) == 0
+    assert len(belief_changes.false) == 0
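The `llm.query` tests in this file pin down a retry contract: the wrapped call is attempted at most `tries` times (three by default, judging by `test_retry_query_llm_failures`), and `None` is returned once the attempts are exhausted. A minimal sketch of a wrapper with that behaviour; `query_with_retries` and the caught exception types are assumptions inferred from the tests, not the project's implementation:

```python
import asyncio
from unittest.mock import AsyncMock


async def query_with_retries(query_once, prompt, schema, tries=3):
    """Hypothetical retry wrapper; returns None if every attempt fails."""
    for _ in range(tries):
        try:
            return await query_once(prompt, schema)
        except (KeyError, ValueError):
            # Assumption: malformed responses surface as these exceptions.
            continue
    return None


# Fails once, then succeeds on the second attempt.
flaky = AsyncMock(side_effect=[KeyError(), "real value"])
recovered = asyncio.run(query_with_retries(flaky, "hello world", {"type": "string"}))

# Three failures exhaust the default budget before the good value is reached.
broken = AsyncMock(side_effect=[KeyError(), KeyError(), KeyError(), "real value"])
exhausted = asyncio.run(query_with_retries(broken, "hello world", {"type": "string"}))
```

Returning `None` instead of re-raising keeps the caller's control flow simple, at the cost of hiding which error occurred; logging inside the `except` branch would be a natural addition.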


@@ -265,3 +265,23 @@ async def test_stream_query_llm_skips_non_data_lines(mock_httpx_client, mock_set
     # Only the valid 'data:' line should yield content
     assert tokens == ["Hi"]
+@pytest.mark.asyncio
+async def test_clear_history_command(mock_settings):
+    """Test that the 'clear_history' message clears the agent's memory."""
+    # setup LLM to have some history
+    mock_settings.agent_settings.bdi_program_manager_name = "bdi_program_manager_agent"
+    agent = LLMAgent("llm_agent")
+    agent.history = [
+        {"role": "user", "content": "Old conversation context"},
+        {"role": "assistant", "content": "Old response"},
+    ]
+    assert len(agent.history) == 2
+    msg = InternalMessage(
+        to="llm_agent",
+        sender=mock_settings.agent_settings.bdi_program_manager_name,
+        body="clear_history",
+    )
+    await agent.handle_message(msg)
+    assert len(agent.history) == 0
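`test_clear_history_command` expects a `clear_history` message body to empty the agent's history. A minimal sketch of a handler that would satisfy it; `Message` and `HistoryAgent` are hypothetical stand-ins for `InternalMessage` and `LLMAgent`, and the sender check is an assumption suggested by the test's use of `bdi_program_manager_name`:

```python
import asyncio
from dataclasses import dataclass, field


@dataclass
class Message:
    """Hypothetical stand-in for InternalMessage."""
    to: str
    sender: str
    body: str


@dataclass
class HistoryAgent:
    """Hypothetical stand-in for an agent that keeps chat history."""
    name: str
    trusted_sender: str
    history: list = field(default_factory=list)

    async def handle_message(self, msg: Message) -> None:
        # Assumption: only the program manager may reset the history.
        if msg.sender == self.trusted_sender and msg.body == "clear_history":
            self.history.clear()


agent = HistoryAgent("llm_agent", trusted_sender="bdi_program_manager_agent")
agent.history = [{"role": "user", "content": "Old conversation context"}]

# A clear request from an unexpected sender leaves the history alone.
asyncio.run(agent.handle_message(Message("llm_agent", "someone_else", "clear_history")))
kept = len(agent.history)

# The same request from the trusted sender empties it.
asyncio.run(
    agent.handle_message(Message("llm_agent", "bdi_program_manager_agent", "clear_history"))
)
cleared = len(agent.history)
```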


@@ -7,6 +7,15 @@ import zmq
 from control_backend.agents.perception.vad_agent import VADAgent
+# We don't want to use real ZMQ in unit tests, for example because it can give errors when sockets
+# aren't closed properly.
+@pytest.fixture(autouse=True)
+def mock_zmq():
+    with patch("zmq.asyncio.Context") as mock:
+        mock.instance.return_value = MagicMock()
+        yield mock
 @pytest.fixture
 def audio_out_socket():
     return AsyncMock()
@@ -140,12 +149,10 @@ async def test_vad_model_load_failure_stops_agent(vad_agent):
     # Patch stop to an AsyncMock so we can check it was awaited
     vad_agent.stop = AsyncMock()
-    result = await vad_agent.setup()
+    await vad_agent.setup()
     # Assert stop was called
     vad_agent.stop.assert_awaited_once()
-    # Assert setup returned None
-    assert result is None
 @pytest.mark.asyncio
@@ -155,7 +162,7 @@ async def test_audio_out_bind_failure_sets_none_and_logs(vad_agent, caplog):
     audio_out_socket is set to None, None is returned, and an error is logged.
     """
     mock_socket = MagicMock()
-    mock_socket.bind_to_random_port.side_effect = zmq.ZMQBindError()
+    mock_socket.bind.side_effect = zmq.ZMQBindError()
     with patch("control_backend.agents.perception.vad_agent.azmq.Context.instance") as mock_ctx:
         mock_ctx.return_value.socket.return_value = mock_socket