Compare commits

..

114 Commits

Author SHA1 Message Date
8506c0d9ef chore: remove belief collector and small tweaks 2026-01-16 15:07:44 +01:00
b1c18abffd test: bunch of tests
Written with AI, still need to check them

ref: N25B-449
2026-01-16 13:11:41 +01:00
39e1bb1ead fix: sync issues
ref: N25B-447
2026-01-14 15:28:29 +01:00
8f6662e64a feat: phase transitions
ref: N25B-446
2026-01-14 13:22:51 +01:00
0794c549a8 chore: remove agentspeak file from tracking 2026-01-14 11:27:29 +01:00
ff24ab7a27 fix: default behavior and end phase
ref: N25B-448
2026-01-14 11:24:19 +01:00
43ac8ad69f chore: delete outdated files
ref: N25B-446
2026-01-14 10:58:41 +01:00
Twirre Meulenbelt
f7669c021b feat: support force completed goals in semantic belief agent
ref: N25B-427
2026-01-13 17:04:44 +01:00
Björn Otgaar
8f52f8bf0c Merge branch 'feat/monitoringpage-cb' of git.science.uu.nl:ics/sp/2025/n25b/pepperplus-cb into feat/monitoringpage-cb 2026-01-13 14:03:40 +01:00
Björn Otgaar
2a94a45b34 chore: adjust 'phase_id' to 'id' for correct payload 2026-01-13 14:03:37 +01:00
f87651f691 fix: achieved goal in bdi core
ref: N25B-400
2026-01-13 12:26:18 +01:00
Pim Hutting
65e0b2d250 feat: added correct message
ref: N25B-400
2026-01-13 12:05:38 +01:00
177e844349 feat: send achieved goal from interrupt->manager->semantic
ref: N25B-400
2026-01-13 11:46:17 +01:00
Pim Hutting
0df6040444 feat: added sending goal overwrites in Userinter.
ref: N25B-400
2026-01-13 11:26:03 +01:00
Twirre Meulenbelt
af81bd8620 Merge branch 'feat/multiple-receivers' into feat/monitoringpage-cb
# Conflicts:
#	src/control_backend/core/agent_system.py
#	src/control_backend/schemas/internal_message.py
2026-01-13 11:14:18 +01:00
Twirre Meulenbelt
70e05b6c92 test: sending to multiple agents, including remote
ref: N25B-441
2026-01-13 11:10:35 +01:00
c0b8fb8612 feat: able to send to multiple receivers
ref: N25B-441
2026-01-13 11:06:42 +01:00
Pim Hutting
d499111ea4 feat: added pause functionality
Storms code wasnt fully included in Bjorns branch

ref: N25B-400
2026-01-13 00:52:04 +01:00
Pim Hutting
72c2c57f26 chore: merged button functionality and fix bug
merged björns branch that has the following button functionality
-Pause/resume
-Next phase
-Restart phase
-reset experiment
fix bug where norms where not properly sent to the user interrupt agent

ref: N25B-400
2026-01-12 19:31:50 +01:00
Pim Hutting
4a014b577a Merge remote-tracking branch 'origin/feat/reset-skip-buttons' into feat/monitoringpage-cb 2026-01-12 19:19:31 +01:00
Pim Hutting
c45a258b22 fix: fixed a bug where norms where not updated
Now in UserInterruptAgent we store the norm.norm and not the slugified norm

ref: N25B-400
2026-01-12 19:07:05 +01:00
0f09276477 fix: send norms back to UI
ref: N25B-400
2026-01-12 17:02:39 +01:00
4e113c2d5c fix: default plan and norm force
ref: N25B-400
2026-01-12 16:20:24 +01:00
Pim Hutting
54c835cc0f feat: added force_norm handling in BDI core agent
ref: N25B-400
2026-01-12 15:37:04 +01:00
Pim Hutting
c4ccbcd354 Merge remote-tracking branch 'origin/feat/extra-agentspeak-functionality' into feat/monitoringpage-cb 2026-01-12 15:24:48 +01:00
Pim Hutting
d202abcd1b fix: phases update correctly
there was a bug where phases would not update without restarting cb

ref: N25B-400
2026-01-12 12:51:24 +01:00
Twirre Meulenbelt
4b71981a3e fix: some bugs and some tests
ref: N25B-429
2026-01-12 09:00:50 +01:00
Björn Otgaar
c91b999104 chore: fix bugs and make sure connected robots work 2026-01-08 15:31:44 +01:00
866d7c4958 fix: end phase loop correctly notifies about user_said
ref: N25B-429
2026-01-08 15:13:12 +01:00
Pim Hutting
5e2126fc21 chore: code cleanup
ref: N25B-400
2026-01-08 15:05:43 +01:00
Pim Hutting
500bbc2d82 feat: added goal start sending functionality
ref: N25B-400
2026-01-08 14:52:55 +01:00
133019a928 feat: trigger name and trigger checks on belief update
ref: N25B-429
2026-01-08 14:04:44 +01:00
4d0ba69443 fix: don't re-add user_said upon phase transition
ref: N25B-429
2026-01-08 13:44:25 +01:00
625ef0c365 feat: phase transition waits for all goals
ref: N25B-429
2026-01-08 13:36:03 +01:00
b88758fa76 feat: phase transition independent of response
ref: N25B-429
2026-01-08 13:33:37 +01:00
Björn Otgaar
1360567820 chore: indenting 2026-01-08 13:01:38 +01:00
Björn Otgaar
cc0d5af28c chore: fixing bugs 2026-01-08 12:56:22 +01:00
Pim Hutting
3a8d1730a1 fix: made mapping for conditional norms only
ref: N25B-400
2026-01-08 12:29:16 +01:00
Pim Hutting
b27e5180c4 feat: small implementation change
ref: N25B-400
2026-01-08 11:25:53 +01:00
Pim Hutting
6b34f4b82c fix: small bugfix
ref: N25B-400
2026-01-08 10:59:24 +01:00
Twirre Meulenbelt
45719c580b feat: prepend more silence before speech audio for better transcription beginnings
ref: N25B-429
2026-01-08 10:49:13 +01:00
Pim Hutting
4bf2be6359 feat: added a functionality for monitoring page
ref: N25B-400
2026-01-08 09:56:10 +01:00
Pim Hutting
20e5e46639 Merge remote-tracking branch 'origin/feat/extra-agentspeak-functionality' into feat/monitoringpage-cb 2026-01-07 22:42:40 +01:00
Pim Hutting
365d449666 feat: commit before I can merge new changes
ref: N25B-400
2026-01-07 22:41:59 +01:00
Björn Otgaar
be88323cf7 chore: add one endpoint fo avoid errors 2026-01-07 18:24:35 +01:00
5a61225c6f feat: reset extractor history
ref: N25B-429
2026-01-07 18:10:13 +01:00
a30cea5231 Merge branch 'feat/semantic-beliefs' into feat/extra-agentspeak-functionality 2026-01-07 17:51:30 +01:00
Twirre Meulenbelt
93d67ccb66 feat: add reset functionality to semantic belief extractor
ref: N25B-432
2026-01-07 17:50:47 +01:00
240624f887 Merge branch 'dev' into feat/extra-agentspeak-functionality
# Conflicts:
#	src/control_backend/agents/bdi/bdi_program_manager.py
#	src/control_backend/agents/llm/llm_agent.py
#	test/unit/agents/bdi/test_bdi_program_manager.py
2026-01-07 17:46:48 +01:00
Pim Hutting
be6bbbb849 feat: added endpoint userinterrupt to userinterrupt
ref: N25B-400
2026-01-07 17:42:54 +01:00
8a77e8e1c7 feat: check goals only for this phase
Since conversation history still remains we can still check at a later point.

ref: N25B-429
2026-01-07 17:31:24 +01:00
3b4dccc760 Merge branch 'feat/semantic-beliefs' into feat/extra-agentspeak-functionality
# Conflicts:
#	src/control_backend/agents/bdi/bdi_program_manager.py
2026-01-07 17:20:52 +01:00
3d49e44cf7 fix: complete pipeline working
User interrupts still need to be tested.

ref: N25B-429
2026-01-07 17:13:58 +01:00
Twirre Meulenbelt
aa5b386f65 feat: semantically determine goal completion
ref: N25B-432
2026-01-07 17:08:23 +01:00
Storm
76dfcb23ef feat: added pause functionality
ref: N25B-350
2026-01-07 16:03:49 +01:00
Twirre Meulenbelt
3189b9fee3 fix: let belief extractor send user_said belief
ref: N25B-429
2026-01-07 15:19:23 +01:00
Björn Otgaar
34afca6652 chore: automatically send the experiment controls to the bdi core in the user interupt agent. 2026-01-07 15:07:33 +01:00
Björn Otgaar
324a63e5cc chore: add styles to user_interrupt_agent 2026-01-07 14:45:42 +01:00
07d70cb781 fix: single dispatch order
ref: N25B-429
2026-01-07 13:02:23 +01:00
af832980c8 feat: general slugify method
ref: N25B-429
2026-01-07 12:24:46 +01:00
Twirre Meulenbelt
cabe35cdbd feat: integrate AgentSpeak with semantic belief extraction
ref: N25B-429
2026-01-07 11:44:48 +01:00
Twirre Meulenbelt
de8e829d3e Merge remote-tracking branch 'origin/feat/agentspeak-generation' into feat/semantic-beliefs
# Conflicts:
#	test/unit/agents/bdi/test_bdi_program_manager.py
2026-01-06 15:30:59 +01:00
Twirre Meulenbelt
3406e9ac2f feat: make the pipeline work with Program and AgentSpeak
ref: N25B-429
2026-01-06 15:26:44 +01:00
a357b6990b feat: send program to bdi core
ref: N25B-376
2026-01-06 12:11:37 +01:00
Björn Otgaar
612a96940d Merge branch 'feat/environment-variables' into 'dev'
Docs for environment variables, parameterize some constants

See merge request ics/sp/2025/n25b/pepperplus-cb!38
2026-01-06 09:02:49 +00:00
Pim Hutting
4c20656c75 Merge branch 'feat/program-reset-llm' into 'dev'
feat: made program reset LLM

See merge request ics/sp/2025/n25b/pepperplus-cb!39
2026-01-02 15:13:05 +00:00
Pim Hutting
6ca86e4b81 feat: made program reset LLM 2026-01-02 15:13:04 +00:00
9eea4ee345 feat: new ASL generation
ref: N25B-376
2026-01-02 12:08:20 +01:00
Twirre Meulenbelt
42ee5c76d8 test: create tests for belief extractor agent
Includes changes in schemas. Change type of `norms` in `Program` imperceptibly, big changes in schema of `BeliefMessage` to support deleting beliefs.

ref: N25B-380
2025-12-29 17:12:02 +01:00
Twirre Meulenbelt
7d798f2e77 Merge remote-tracking branch 'origin/dev' into feat/environment-variables
# Conflicts:
#	src/control_backend/core/config.py
#	test/unit/agents/actuation/test_robot_speech_agent.py
2025-12-29 12:40:16 +01:00
Twirre Meulenbelt
5282c2471f Merge remote-tracking branch 'origin/dev' into feat/environment-variables
# Conflicts:
#	src/control_backend/core/config.py
#	test/unit/agents/actuation/test_robot_speech_agent.py
2025-12-29 12:35:39 +01:00
Twirre Meulenbelt
57b1276cb5 test: make tests work again after changing Program schema
ref: N25B-380
2025-12-29 12:31:51 +01:00
Twirre Meulenbelt
7e0dc9ce1c Merge remote-tracking branch 'origin/feat/agentspeak-generation' into feat/semantic-beliefs
# Conflicts:
#	src/control_backend/schemas/program.py
2025-12-23 17:36:39 +01:00
3253760ef1 feat: new AST representation
File names will be changed eventually.

ref: N25B-376
2025-12-23 17:30:35 +01:00
Twirre Meulenbelt
71cefdfef3 fix: add types to all config properties
ref: N25B-380
2025-12-23 17:14:49 +01:00
Twirre Meulenbelt
33501093a1 feat: extract semantic beliefs from conversation
ref: N25B-380
2025-12-23 17:09:58 +01:00
Luijkx,S.O.H. (Storm)
adbb7ffd5c Merge branch 'feat/user-interrupt-agent' into 'dev'
create UserInterruptAgent with connection to UI

See merge request ics/sp/2025/n25b/pepperplus-cb!40
2025-12-22 13:56:03 +00:00
Pim Hutting
0501a9fba3 create UserInterruptAgent with connection to UI 2025-12-22 13:56:02 +00:00
756e1f0dc5 feat: persistent rules and stuff
So ugly

ref: N25B-376
2025-12-18 14:33:42 +01:00
Twirre Meulenbelt
f91cec6708 fix: things in AgentSpeak, add custom actions
ref: N25B-376
2025-12-18 11:50:16 +01:00
28262eb27e fix: default case for plans
ref: N25B-376
2025-12-17 16:20:37 +01:00
1d36d2e089 feat: (hopefully) better intermediate representation
ref: N25B-376
2025-12-17 15:33:27 +01:00
742e36b94f chore: non-optional uuid id
ref: N25B-376
2025-12-17 14:30:14 +01:00
Twirre Meulenbelt
57fe3ae3f6 Merge remote-tracking branch 'origin/dev' into feat/agentspeak-generation 2025-12-17 13:20:14 +01:00
e704ec5ed4 feat: basic flow and phase transitions
ref: N25B-376
2025-12-16 17:00:32 +01:00
Twirre Meulenbelt
27f04f0958 style: use yield instead of returning arrays
ref: N25B-376
2025-12-16 16:11:01 +01:00
Twirre Meulenbelt
8cc177041a feat: add a second phase in test_program
ref: N25B-376
2025-12-16 15:12:22 +01:00
4a432a603f fix: separate trigger plan generation
ref: N25B-376
2025-12-16 14:12:04 +01:00
JobvAlewijk
3e7f2ef574 Merge branch 'feat/quiet-llm' into 'dev'
feat: implemented extra log level for LLM token stream

See merge request ics/sp/2025/n25b/pepperplus-cb!37
2025-12-16 11:26:37 +00:00
Luijkx,S.O.H. (Storm)
78abad55d3 feat: implemented extra log level for LLM token stream 2025-12-16 11:26:35 +00:00
bab4800698 feat: add trigger generation
ref: N25B-376
2025-12-16 12:10:52 +01:00
Twirre
4ab6b2a0e6 Merge branch 'feat/cb2ri-gestures' into 'dev'
Gestures in the CB.

See merge request ics/sp/2025/n25b/pepperplus-cb!36
2025-12-16 09:24:25 +00:00
Twirre Meulenbelt
db5504db20 chore: remove redundant check 2025-12-16 10:22:11 +01:00
d043c54336 refactor: program restructure
Also includes some AgentSpeak generation.

ref: N25B-376
2025-12-16 10:21:50 +01:00
Björn Otgaar
f15a518984 fix: tests
ref: N25B-334
2025-12-15 11:52:01 +01:00
Björn Otgaar
71d86f5fb0 Merge branch 'feat/cb2ri-gestures' of git.science.uu.nl:ics/sp/2025/n25b/pepperplus-cb into feat/cb2ri-gestures 2025-12-15 11:36:12 +01:00
Björn Otgaar
daf31ac6a6 fix: change the address to the config, update some logic, seperate the api endpoint, renaming things. yes, the tests don't work right now- this shouldn't be merged yet.
ref: N25B-334
2025-12-15 11:35:56 +01:00
Björn Otgaar
b2d014753d Apply 1 suggestion(s) to 1 file(s)
Co-authored-by: Pim Hutting <p.r.p.hutting@students.uu.nl>
2025-12-11 15:08:15 +00:00
Twirre Meulenbelt
0c682d6440 feat: introduce .env.example, docs
The example includes options that are expected to be changed. It also includes a reference to where in the docs you can find a full list of options.

ref: N25B-352
2025-12-11 13:35:19 +01:00
Björn Otgaar
2e472ea292 chore: remove wrong test paths 2025-12-11 12:48:18 +01:00
Björn Otgaar
1c9b722ba3 Merge branch 'dev' into feat/cb2ri-gestures 2025-12-11 12:46:32 +01:00
Twirre Meulenbelt
32d8f20dc9 feat: parameterize RI host
Was "localhost" in RI Communication Agent, now uses configurable setting. Secretly also removing "localhost" from VAD agent, as its socket should be something that's "inproc".

ref: N25B-352
2025-12-11 12:12:15 +01:00
Twirre Meulenbelt
9cc0e39955 fix: failures main tests since VAD agent initialization was changed
The test still expects the VAD agent to be started in main, rather than in the RI Communication Agent.

ref: N25B-356
2025-12-11 12:04:24 +01:00
Björn Otgaar
2366255b92 Merge branch 'fix/correct-vad-starting' into 'dev'
Move VAD agent creation to RI communication agent

See merge request ics/sp/2025/n25b/pepperplus-cb!34
2025-12-09 14:57:09 +00:00
Björn Otgaar
7f34fede81 fix: fix the tests
ref: N25B-334
2025-12-09 15:37:00 +01:00
Luijkx,S.O.H. (Storm)
a9255cb6e7 Merge branch 'test/coverage-max-all' into 'dev'
test: increased cb test coverage

See merge request ics/sp/2025/n25b/pepperplus-cb!32
2025-12-09 13:14:03 +00:00
JobvAlewijk
7f7c658901 test: increased cb test coverage 2025-12-09 13:14:02 +00:00
Björn Otgaar
3d62e7fc0c Merge branch 'feat/cb2ri-gestures' of git.science.uu.nl:ics/sp/2025/n25b/pepperplus-cb into feat/cb2ri-gestures 2025-12-09 14:13:01 +01:00
Björn Otgaar
6034263259 fix: correct the gestures bugs, change gestures socket to request/reply
ref: N25B-334
2025-12-09 14:08:59 +01:00
JobvAlewijk
63897f5969 chore: double tag 2025-12-09 12:33:43 +01:00
JobvAlewijk
a3cf389c05 Merge branch 'dev' of https://git.science.uu.nl/ics/sp/2025/n25b/pepperplus-cb into fix/correct-vad-starting 2025-12-07 23:09:31 +01:00
JobvAlewijk
de2e56ffce Merge branch 'fix/fix-socket-typing' into 'dev'
chore: fix socket typing in robot speech agent

See merge request ics/sp/2025/n25b/pepperplus-cb!33
2025-12-03 14:26:46 +00:00
Twirre Meulenbelt
21e9d05d6e fix: move VAD agent creation to RI communication agent
Previously, it was started in main, but it should use values negotiated by the RI communication agent.

ref: N25B-356
2025-12-03 15:07:29 +01:00
Björn Otgaar
bacc63aa31 chore: fix socket typing in robot speech agent 2025-12-02 14:22:39 +01:00
70 changed files with 6757 additions and 1076 deletions

20
.env.example Normal file
View File

@@ -0,0 +1,20 @@
# Example .env file. To use, make a copy, call it ".env" (i.e. removing the ".example" suffix), then you edit values.
# The hostname of the Robot Interface. Change if the Control Backend and Robot Interface are running on different computers.
RI_HOST="localhost"
# URL for the local LLM API. Must be an API that implements the OpenAI Chat Completions API, but most do.
LLM_SETTINGS__LOCAL_LLM_URL="http://localhost:1234/v1/chat/completions"
# Name of the local LLM model to use.
LLM_SETTINGS__LOCAL_LLM_MODEL="gpt-oss"
# Number of non-speech chunks to wait before speech ended. A chunk is approximately 31 ms. Increasing this number allows longer pauses in speech, but also increases response time.
BEHAVIOUR_SETTINGS__VAD_NON_SPEECH_PATIENCE_CHUNKS=15
# Timeout in milliseconds for socket polling. Increase this number if network latency/jitter is high, often the case when using Wi-Fi. Perhaps 500 ms. A symptom of this issue is transcriptions getting cut off.
BEHAVIOUR_SETTINGS__SOCKET_POLLER_TIMEOUT_MS=100
# For an exhaustive list of options, see the control_backend.core.config module in the docs.

2
.gitignore vendored
View File

@@ -222,6 +222,8 @@ __marimo__/
docs/*
!docs/conf.py
# Generated files
agentspeak.asl

View File

@@ -0,0 +1,9 @@
%{first_multiline_commit_description}
To verify:
- [ ] Style checks pass
- [ ] Pipeline (tests) pass
- [ ] Documentation is up to date
- [ ] Tests are up to date (new code is covered)
- [ ] ...

View File

@@ -3,6 +3,7 @@ version: 1
custom_levels:
OBSERVATION: 25
ACTION: 26
LLM: 9
formatters:
# Console output
@@ -26,7 +27,7 @@ handlers:
stream: ext://sys.stdout
ui:
class: zmq.log.handlers.PUBHandler
level: DEBUG
level: LLM
formatter: json_experiment
# Level of external libraries
@@ -36,5 +37,5 @@ root:
loggers:
control_backend:
level: DEBUG
level: LLM
handlers: [ui]

View File

@@ -27,6 +27,7 @@ This + part might differ based on what model you choose.
copy the model name in the module loaded and replace local_llm_modelL. In settings.
## Running
To run the project (development server), execute the following command (while inside the root repository):
@@ -34,6 +35,14 @@ To run the project (development server), execute the following command (while in
uv run fastapi dev src/control_backend/main.py
```
### Environment Variables
You can use environment variables to change settings. Make a copy of the [`.env.example`](.env.example) file, name it `.env` and put it in the root directory. The file itself describes how to do the configuration.
For an exhaustive list of environment options, see the `control_backend.core.config` module in the docs.
## Testing
Testing happens automatically when opening a merge request to any branch. If you want to manually run the test suite, you can do so by running the following for unit tests:

View File

@@ -15,6 +15,7 @@ dependencies = [
"pydantic>=2.12.0",
"pydantic-settings>=2.11.0",
"python-json-logger>=4.0.0",
"python-slugify>=8.0.4",
"pyyaml>=6.0.3",
"pyzmq>=27.1.0",
"silero-vad>=6.0.0",

View File

@@ -12,7 +12,7 @@ from control_backend.schemas.ri_message import GestureCommand, RIEndpoint
class RobotGestureAgent(BaseAgent):
"""
This agent acts as a bridge between the control backend and the Robot Interface (RI).
It receives speech commands from other agents or from the UI,
It receives gesture commands from other agents or from the UI,
and forwards them to the robot via a ZMQ PUB socket.
:ivar subsocket: ZMQ SUB socket for receiving external commands (e.g., from UI).
@@ -23,22 +23,23 @@ class RobotGestureAgent(BaseAgent):
"""
subsocket: azmq.Socket
repsocket: azmq.Socket
pubsocket: azmq.Socket
address = ""
bind = False
gesture_data = []
single_gesture_data = []
def __init__(
self,
name: str,
address=settings.zmq_settings.ri_command_address,
address: str,
bind=False,
gesture_data=None,
single_gesture_data=None,
):
if gesture_data is None:
self.gesture_data = []
else:
self.gesture_data = gesture_data
self.gesture_data = gesture_data or []
self.single_gesture_data = single_gesture_data or []
super().__init__(name)
self.address = address
self.bind = bind
@@ -56,9 +57,8 @@ class RobotGestureAgent(BaseAgent):
context = azmq.Context.instance()
# To the robot
self.pubsocket = context.socket(zmq.PUB)
if self.bind: # TODO: Should this ever be the case?
if self.bind:
self.pubsocket.bind(self.address)
else:
self.pubsocket.connect(self.address)
@@ -69,6 +69,10 @@ class RobotGestureAgent(BaseAgent):
self.subsocket.setsockopt(zmq.SUBSCRIBE, b"command")
self.subsocket.setsockopt(zmq.SUBSCRIBE, b"send_gestures")
# REP socket for replying to gesture requests
self.repsocket = context.socket(zmq.REP)
self.repsocket.bind(settings.zmq_settings.internal_gesture_rep_adress)
self.add_behavior(self._zmq_command_loop())
self.add_behavior(self._fetch_gestures_loop())
@@ -79,6 +83,8 @@ class RobotGestureAgent(BaseAgent):
self.subsocket.close()
if self.pubsocket:
self.pubsocket.close()
if self.repsocket:
self.repsocket.close()
await super().stop()
async def handle_message(self, msg: InternalMessage):
@@ -92,13 +98,19 @@ class RobotGestureAgent(BaseAgent):
try:
gesture_command = GestureCommand.model_validate_json(msg.body)
if gesture_command.endpoint == RIEndpoint.GESTURE_TAG:
if gesture_command.data not in self.availableTags():
if gesture_command.data not in self.gesture_data:
self.logger.warning(
"Received gesture tag '%s' which is not in available tags. Early returning",
gesture_command.data,
)
return
elif gesture_command.endpoint == RIEndpoint.GESTURE_SINGLE:
if gesture_command.data not in self.single_gesture_data:
self.logger.warning(
"Received gesture '%s' which is not in available gestures. Early returning",
gesture_command.data,
)
return
await self.pubsocket.send_json(gesture_command.model_dump())
except Exception:
self.logger.exception("Error processing internal message.")
@@ -120,7 +132,7 @@ class RobotGestureAgent(BaseAgent):
body = json.loads(body)
gesture_command = GestureCommand.model_validate(body)
if gesture_command.endpoint == RIEndpoint.GESTURE_TAG:
if gesture_command.data not in self.availableTags():
if gesture_command.data not in self.gesture_data:
self.logger.warning(
"Received gesture tag '%s' which is not in available tags.\
Early returning",
@@ -139,157 +151,23 @@ class RobotGestureAgent(BaseAgent):
"""
while self._running:
try:
topic, body = await self.subsocket.recv_multipart()
# Don't process commands here
if topic != b"send_gestures":
continue
# Get a request
body = await self.repsocket.recv()
# Figure out amount, if specified
try:
body = json.loads(body)
except json.JSONDecodeError:
body = None
# We could have the body be the nummer of gestures you want to fetch or something.
amount = None
if isinstance(body, int):
amount = body
tags = self.availableTags()[:amount] if amount else self.availableTags()
# Fetch tags from gesture data and respond
tags = self.gesture_data[:amount] if amount else self.gesture_data
response = json.dumps({"tags": tags}).encode()
await self.pubsocket.send_multipart(
[
b"get_gestures",
response,
]
)
await self.repsocket.send(response)
except Exception:
self.logger.exception("Error fetching gesture tags.")
def availableTags(self):
"""
Returns the available gesture tags.
:return: List of available gesture tags.
"""
return [
"above",
"affirmative",
"afford",
"agitated",
"all",
"allright",
"alright",
"any",
"assuage",
"assuage",
"attemper",
"back",
"bashful",
"beg",
"beseech",
"blank",
"body language",
"bored",
"bow",
"but",
"call",
"calm",
"choose",
"choice",
"cloud",
"cogitate",
"cool",
"crazy",
"disappointed",
"down",
"earth",
"empty",
"embarrassed",
"enthusiastic",
"entire",
"estimate",
"except",
"exalted",
"excited",
"explain",
"far",
"field",
"floor",
"forlorn",
"friendly",
"front",
"frustrated",
"gentle",
"gift",
"give",
"ground",
"happy",
"hello",
"her",
"here",
"hey",
"hi",
"him",
"hopeless",
"hysterical",
"I",
"implore",
"indicate",
"joyful",
"me",
"meditate",
"modest",
"negative",
"nervous",
"no",
"not know",
"nothing",
"offer",
"ok",
"once upon a time",
"oppose",
"or",
"pacify",
"pick",
"placate",
"please",
"present",
"proffer",
"quiet",
"reason",
"refute",
"reject",
"rousing",
"sad",
"select",
"shamefaced",
"show",
"show sky",
"sky",
"soothe",
"sun",
"supplicate",
"tablet",
"tall",
"them",
"there",
"think",
"timid",
"top",
"unless",
"up",
"upstairs",
"void",
"warm",
"winner",
"yeah",
"yes",
"yoo-hoo",
"you",
"your",
"zero",
"zestful",
]

View File

@@ -29,7 +29,7 @@ class RobotSpeechAgent(BaseAgent):
def __init__(
self,
name: str,
address=settings.zmq_settings.ri_command_address,
address: str,
bind=False,
):
super().__init__(name)

View File

@@ -1,8 +1,5 @@
from control_backend.agents.bdi.bdi_core_agent import BDICoreAgent as BDICoreAgent
from .belief_collector_agent import (
BDIBeliefCollectorAgent as BDIBeliefCollectorAgent,
)
from .text_belief_extractor_agent import (
TextBeliefExtractorAgent as TextBeliefExtractorAgent,
)

View File

@@ -0,0 +1,273 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from enum import StrEnum
class AstNode(ABC):
"""
Abstract base class for all elements of an AgentSpeak program.
"""
@abstractmethod
def _to_agentspeak(self) -> str:
"""
Generates the AgentSpeak code string.
"""
pass
def __str__(self) -> str:
return self._to_agentspeak()
class AstExpression(AstNode, ABC):
"""
Intermediate class for anything that can be used in a logical expression.
"""
def __and__(self, other: ExprCoalescible) -> AstBinaryOp:
return AstBinaryOp(self, BinaryOperatorType.AND, _coalesce_expr(other))
def __or__(self, other: ExprCoalescible) -> AstBinaryOp:
return AstBinaryOp(self, BinaryOperatorType.OR, _coalesce_expr(other))
def __invert__(self) -> AstLogicalExpression:
if isinstance(self, AstLogicalExpression):
self.negated = not self.negated
return self
return AstLogicalExpression(self, negated=True)
type ExprCoalescible = AstExpression | str | int | float
def _coalesce_expr(value: ExprCoalescible) -> AstExpression:
if isinstance(value, AstExpression):
return value
if isinstance(value, str):
return AstString(value)
if isinstance(value, (int, float)):
return AstNumber(value)
raise TypeError(f"Cannot coalesce type {type(value)} into an AstTerm.")
@dataclass
class AstTerm(AstExpression, ABC):
"""
Base class for terms appearing inside literals.
"""
def __ge__(self, other: ExprCoalescible) -> AstBinaryOp:
return AstBinaryOp(self, BinaryOperatorType.GREATER_EQUALS, _coalesce_expr(other))
def __gt__(self, other: ExprCoalescible) -> AstBinaryOp:
return AstBinaryOp(self, BinaryOperatorType.GREATER_THAN, _coalesce_expr(other))
def __le__(self, other: ExprCoalescible) -> AstBinaryOp:
return AstBinaryOp(self, BinaryOperatorType.LESS_EQUALS, _coalesce_expr(other))
def __lt__(self, other: ExprCoalescible) -> AstBinaryOp:
return AstBinaryOp(self, BinaryOperatorType.LESS_THAN, _coalesce_expr(other))
def __eq__(self, other: ExprCoalescible) -> AstBinaryOp:
return AstBinaryOp(self, BinaryOperatorType.EQUALS, _coalesce_expr(other))
def __ne__(self, other: ExprCoalescible) -> AstBinaryOp:
return AstBinaryOp(self, BinaryOperatorType.NOT_EQUALS, _coalesce_expr(other))
@dataclass(eq=False)
class AstAtom(AstTerm):
"""
Grounded expression in all lowercase.
"""
value: str
def _to_agentspeak(self) -> str:
return self.value.lower()
@dataclass(eq=False)
class AstVar(AstTerm):
"""
Ungrounded variable expression. First letter capitalized.
"""
name: str
def _to_agentspeak(self) -> str:
return self.name.capitalize()
@dataclass(eq=False)
class AstNumber(AstTerm):
value: int | float
def _to_agentspeak(self) -> str:
return str(self.value)
@dataclass(eq=False)
class AstString(AstTerm):
value: str
def _to_agentspeak(self) -> str:
return f'"{self.value}"'
@dataclass(eq=False)
class AstLiteral(AstTerm):
functor: str
terms: list[AstTerm] = field(default_factory=list)
def _to_agentspeak(self) -> str:
if not self.terms:
return self.functor
args = ", ".join(map(str, self.terms))
return f"{self.functor}({args})"
class BinaryOperatorType(StrEnum):
AND = "&"
OR = "|"
GREATER_THAN = ">"
LESS_THAN = "<"
EQUALS = "=="
NOT_EQUALS = "\\=="
GREATER_EQUALS = ">="
LESS_EQUALS = "<="
@dataclass
class AstBinaryOp(AstExpression):
left: AstExpression
operator: BinaryOperatorType
right: AstExpression
def __post_init__(self):
self.left = _as_logical(self.left)
self.right = _as_logical(self.right)
def _to_agentspeak(self) -> str:
l_str = str(self.left)
r_str = str(self.right)
assert isinstance(self.left, AstLogicalExpression)
assert isinstance(self.right, AstLogicalExpression)
if isinstance(self.left.expression, AstBinaryOp) or self.left.negated:
l_str = f"({l_str})"
if isinstance(self.right.expression, AstBinaryOp) or self.right.negated:
r_str = f"({r_str})"
return f"{l_str} {self.operator.value} {r_str}"
@dataclass
class AstLogicalExpression(AstExpression):
expression: AstExpression
negated: bool = False
def _to_agentspeak(self) -> str:
expr_str = str(self.expression)
if isinstance(self.expression, AstBinaryOp) and self.negated:
expr_str = f"({expr_str})"
return f"{'not ' if self.negated else ''}{expr_str}"
def _as_logical(expr: AstExpression) -> AstLogicalExpression:
if isinstance(expr, AstLogicalExpression):
return expr
return AstLogicalExpression(expr)
class StatementType(StrEnum):
EMPTY = ""
DO_ACTION = "."
ACHIEVE_GOAL = "!"
TEST_GOAL = "?"
ADD_BELIEF = "+"
REMOVE_BELIEF = "-"
REPLACE_BELIEF = "-+"
@dataclass
class AstStatement(AstNode):
"""
A statement that can appear inside a plan.
"""
type: StatementType
expression: AstExpression
def _to_agentspeak(self) -> str:
return f"{self.type.value}{self.expression}"
@dataclass
class AstRule(AstNode):
result: AstExpression
condition: AstExpression | None = None
def __post_init__(self):
if self.condition is not None:
self.condition = _as_logical(self.condition)
def _to_agentspeak(self) -> str:
if not self.condition:
return f"{self.result}."
return f"{self.result} :- {self.condition}."
class TriggerType(StrEnum):
ADDED_BELIEF = "+"
# REMOVED_BELIEF = "-" # TODO
# MODIFIED_BELIEF = "^" # TODO
ADDED_GOAL = "+!"
# REMOVED_GOAL = "-!" # TODO
@dataclass
class AstPlan(AstNode):
type: TriggerType
trigger_literal: AstExpression
context: list[AstExpression]
body: list[AstStatement]
def _to_agentspeak(self) -> str:
assert isinstance(self.trigger_literal, AstLiteral)
indent = " " * 6
colon = " : "
arrow = " <- "
lines = []
lines.append(f"{self.type.value}{self.trigger_literal}")
if self.context:
lines.append(colon + f" &\n{indent}".join(str(c) for c in self.context))
if self.body:
lines.append(arrow + f";\n{indent}".join(str(s) for s in self.body) + ".")
lines.append("")
return "\n".join(lines)
@dataclass
class AstProgram(AstNode):
rules: list[AstRule] = field(default_factory=list)
plans: list[AstPlan] = field(default_factory=list)
def _to_agentspeak(self) -> str:
lines = []
lines.extend(map(str, self.rules))
lines.extend(["", ""])
lines.extend(map(str, self.plans))
return "\n".join(lines)

View File

@@ -0,0 +1,504 @@
from functools import singledispatchmethod
from slugify import slugify
from control_backend.agents.bdi.agentspeak_ast import (
AstAtom,
AstBinaryOp,
AstExpression,
AstLiteral,
AstNumber,
AstPlan,
AstProgram,
AstRule,
AstStatement,
AstString,
AstVar,
BinaryOperatorType,
StatementType,
TriggerType,
)
from control_backend.schemas.program import (
BaseGoal,
BasicNorm,
ConditionalNorm,
GestureAction,
Goal,
InferredBelief,
KeywordBelief,
LLMAction,
LogicalOperator,
Norm,
Phase,
PlanElement,
Program,
ProgramElement,
SemanticBelief,
SpeechAction,
Trigger,
)
class AgentSpeakGenerator:
_asp: AstProgram
def generate(self, program: Program) -> str:
self._asp = AstProgram()
if program.phases:
self._asp.rules.append(AstRule(self._astify(program.phases[0])))
else:
self._asp.rules.append(AstRule(AstLiteral("phase", [AstString("end")])))
self._asp.rules.append(AstRule(AstLiteral("!notify_cycle")))
self._add_keyword_inference()
self._add_default_plans()
self._process_phases(program.phases)
self._add_fallbacks()
return str(self._asp)
def _add_keyword_inference(self) -> None:
keyword = AstVar("Keyword")
message = AstVar("Message")
position = AstVar("Pos")
self._asp.rules.append(
AstRule(
AstLiteral("keyword_said", [keyword]),
AstLiteral("user_said", [message])
& AstLiteral(".substring", [keyword, message, position])
& (position >= 0),
)
)
def _add_default_plans(self):
self._add_reply_with_goal_plan()
self._add_say_plan()
self._add_reply_plan()
self._add_notify_cycle_plan()
def _add_reply_with_goal_plan(self):
self._asp.plans.append(
AstPlan(
TriggerType.ADDED_GOAL,
AstLiteral("reply_with_goal", [AstVar("Goal")]),
[AstLiteral("user_said", [AstVar("Message")])],
[
AstStatement(StatementType.ADD_BELIEF, AstLiteral("responded_this_turn")),
AstStatement(
StatementType.DO_ACTION,
AstLiteral(
"findall",
[AstVar("Norm"), AstLiteral("norm", [AstVar("Norm")]), AstVar("Norms")],
),
),
AstStatement(
StatementType.DO_ACTION,
AstLiteral(
"reply_with_goal", [AstVar("Message"), AstVar("Norms"), AstVar("Goal")]
),
),
],
)
)
def _add_say_plan(self):
self._asp.plans.append(
AstPlan(
TriggerType.ADDED_GOAL,
AstLiteral("say", [AstVar("Text")]),
[],
[
AstStatement(StatementType.ADD_BELIEF, AstLiteral("responded_this_turn")),
AstStatement(StatementType.DO_ACTION, AstLiteral("say", [AstVar("Text")])),
],
)
)
def _add_reply_plan(self):
self._asp.plans.append(
AstPlan(
TriggerType.ADDED_GOAL,
AstLiteral("reply"),
[AstLiteral("user_said", [AstVar("Message")])],
[
AstStatement(StatementType.ADD_BELIEF, AstLiteral("responded_this_turn")),
AstStatement(
StatementType.DO_ACTION,
AstLiteral(
"findall",
[AstVar("Norm"), AstLiteral("norm", [AstVar("Norm")]), AstVar("Norms")],
),
),
AstStatement(
StatementType.DO_ACTION,
AstLiteral("reply", [AstVar("Message"), AstVar("Norms")]),
),
],
)
)
def _add_notify_cycle_plan(self):
self._asp.plans.append(
AstPlan(
TriggerType.ADDED_GOAL,
AstLiteral("notify_cycle"),
[],
[
AstStatement(
StatementType.DO_ACTION,
AstLiteral(
"findall",
[AstVar("Norm"), AstLiteral("norm", [AstVar("Norm")]), AstVar("Norms")],
),
),
AstStatement(
StatementType.DO_ACTION, AstLiteral("notify_norms", [AstVar("Norms")])
),
AstStatement(StatementType.DO_ACTION, AstLiteral("wait", [AstNumber(100)])),
AstStatement(StatementType.ACHIEVE_GOAL, AstLiteral("notify_cycle")),
],
)
)
def _process_phases(self, phases: list[Phase]) -> None:
for curr_phase, next_phase in zip([None] + phases, phases + [None], strict=True):
if curr_phase:
self._process_phase(curr_phase)
self._add_phase_transition(curr_phase, next_phase)
# End phase behavior
# When deleting this, the entire `reply` plan and action can be deleted
self._asp.plans.append(
AstPlan(
type=TriggerType.ADDED_BELIEF,
trigger_literal=AstLiteral("user_said", [AstVar("Message")]),
context=[AstLiteral("phase", [AstString("end")])],
body=[
AstStatement(
StatementType.DO_ACTION, AstLiteral("notify_user_said", [AstVar("Message")])
),
AstStatement(StatementType.ACHIEVE_GOAL, AstLiteral("reply")),
],
)
)
def _process_phase(self, phase: Phase) -> None:
for norm in phase.norms:
self._process_norm(norm, phase)
self._add_default_loop(phase)
previous_goal = None
for goal in phase.goals:
self._process_goal(goal, phase, previous_goal, main_goal=True)
previous_goal = goal
for trigger in phase.triggers:
self._process_trigger(trigger, phase)
def _add_phase_transition(self, from_phase: Phase | None, to_phase: Phase | None) -> None:
if from_phase is None:
return
from_phase_ast = self._astify(from_phase)
to_phase_ast = (
self._astify(to_phase) if to_phase else AstLiteral("phase", [AstString("end")])
)
check_context = [from_phase_ast]
if from_phase:
for goal in from_phase.goals:
check_context.append(self._astify(goal, achieved=True))
force_context = [from_phase_ast]
body = [
AstStatement(
StatementType.DO_ACTION,
AstLiteral(
"notify_transition_phase",
[
AstString(str(from_phase.id)),
AstString(str(to_phase.id) if to_phase else "end"),
],
),
),
AstStatement(StatementType.REMOVE_BELIEF, from_phase_ast),
AstStatement(StatementType.ADD_BELIEF, to_phase_ast),
]
# if from_phase:
# body.extend(
# [
# AstStatement(
# StatementType.TEST_GOAL, AstLiteral("user_said", [AstVar("Message")])
# ),
# AstStatement(
# StatementType.REPLACE_BELIEF, AstLiteral("user_said", [AstVar("Message")])
# ),
# ]
# )
# Check
self._asp.plans.append(
AstPlan(
TriggerType.ADDED_GOAL,
AstLiteral("transition_phase"),
check_context,
[
AstStatement(StatementType.ACHIEVE_GOAL, AstLiteral("force_transition_phase")),
],
)
)
# Force
self._asp.plans.append(
AstPlan(
TriggerType.ADDED_GOAL, AstLiteral("force_transition_phase"), force_context, body
)
)
def _process_norm(self, norm: Norm, phase: Phase) -> None:
rule: AstRule | None = None
match norm:
case ConditionalNorm(condition=cond):
rule = AstRule(
self._astify(norm),
self._astify(phase) & self._astify(cond)
| AstAtom(f"force_{self.slugify(norm)}"),
)
case BasicNorm():
rule = AstRule(self._astify(norm), self._astify(phase))
if not rule:
return
self._asp.rules.append(rule)
def _add_default_loop(self, phase: Phase) -> None:
actions = []
actions.append(
AstStatement(
StatementType.DO_ACTION, AstLiteral("notify_user_said", [AstVar("Message")])
)
)
actions.append(AstStatement(StatementType.REMOVE_BELIEF, AstLiteral("responded_this_turn")))
actions.append(AstStatement(StatementType.ACHIEVE_GOAL, AstLiteral("check_triggers")))
for goal in phase.goals:
actions.append(AstStatement(StatementType.ACHIEVE_GOAL, self._astify(goal)))
actions.append(AstStatement(StatementType.ACHIEVE_GOAL, AstLiteral("transition_phase")))
self._asp.plans.append(
AstPlan(
TriggerType.ADDED_BELIEF,
AstLiteral("user_said", [AstVar("Message")]),
[self._astify(phase)],
actions,
)
)
def _process_goal(
self,
goal: Goal,
phase: Phase,
previous_goal: Goal | None = None,
continues_response: bool = False,
main_goal: bool = False,
) -> None:
context: list[AstExpression] = [self._astify(phase)]
context.append(~self._astify(goal, achieved=True))
if previous_goal and previous_goal.can_fail:
context.append(self._astify(previous_goal, achieved=True))
if not continues_response:
context.append(~AstLiteral("responded_this_turn"))
body = []
if main_goal: # UI only needs to know about the main goals
body.append(
AstStatement(
StatementType.DO_ACTION,
AstLiteral("notify_goal_start", [AstString(self.slugify(goal))]),
)
)
subgoals = []
for step in goal.plan.steps:
body.append(self._step_to_statement(step))
if isinstance(step, Goal):
subgoals.append(step)
if not goal.can_fail and not continues_response:
body.append(AstStatement(StatementType.ADD_BELIEF, self._astify(goal, achieved=True)))
self._asp.plans.append(AstPlan(TriggerType.ADDED_GOAL, self._astify(goal), context, body))
self._asp.plans.append(
AstPlan(
TriggerType.ADDED_GOAL,
self._astify(goal),
context=[],
body=[AstStatement(StatementType.EMPTY, AstLiteral("true"))],
)
)
prev_goal = None
for subgoal in subgoals:
self._process_goal(subgoal, phase, prev_goal)
prev_goal = subgoal
def _step_to_statement(self, step: PlanElement) -> AstStatement:
match step:
case Goal() | SpeechAction() | LLMAction() as a:
return AstStatement(StatementType.ACHIEVE_GOAL, self._astify(a))
case GestureAction() as a:
return AstStatement(StatementType.DO_ACTION, self._astify(a))
# TODO: separate handling of keyword and others
def _process_trigger(self, trigger: Trigger, phase: Phase) -> None:
body = []
subgoals = []
body.append(
AstStatement(
StatementType.DO_ACTION,
AstLiteral("notify_trigger_start", [AstString(self.slugify(trigger))]),
)
)
for step in trigger.plan.steps:
body.append(self._step_to_statement(step))
if isinstance(step, Goal):
step.can_fail = False # triggers are continuous sequence
subgoals.append(step)
# Arbitrary wait for UI to display nicely
body.append(AstStatement(StatementType.DO_ACTION, AstLiteral("wait", [AstNumber(2000)])))
body.append(
AstStatement(
StatementType.DO_ACTION,
AstLiteral("notify_trigger_end", [AstString(self.slugify(trigger))]),
)
)
self._asp.plans.append(
AstPlan(
TriggerType.ADDED_GOAL,
AstLiteral("check_triggers"),
[self._astify(phase), self._astify(trigger.condition)],
body,
)
)
# Force trigger (from UI)
self._asp.plans.append(AstPlan(TriggerType.ADDED_GOAL, self._astify(trigger), [], body))
for subgoal in subgoals:
self._process_goal(subgoal, phase, continues_response=True)
def _add_fallbacks(self):
# Trigger fallback
self._asp.plans.append(
AstPlan(
TriggerType.ADDED_GOAL,
AstLiteral("check_triggers"),
[],
[AstStatement(StatementType.EMPTY, AstLiteral("true"))],
)
)
# Phase transition fallback
self._asp.plans.append(
AstPlan(
TriggerType.ADDED_GOAL,
AstLiteral("transition_phase"),
[],
[AstStatement(StatementType.EMPTY, AstLiteral("true"))],
)
)
@singledispatchmethod
def _astify(self, element: ProgramElement) -> AstExpression:
raise NotImplementedError(f"Cannot convert element {element} to an AgentSpeak expression.")
@_astify.register
def _(self, kwb: KeywordBelief) -> AstExpression:
return AstLiteral("keyword_said", [AstString(kwb.keyword)])
@_astify.register
def _(self, sb: SemanticBelief) -> AstExpression:
return AstLiteral(self.slugify(sb))
@_astify.register
def _(self, ib: InferredBelief) -> AstExpression:
return AstBinaryOp(
self._astify(ib.left),
BinaryOperatorType.AND if ib.operator == LogicalOperator.AND else BinaryOperatorType.OR,
self._astify(ib.right),
)
@_astify.register
def _(self, norm: Norm) -> AstExpression:
functor = "critical_norm" if norm.critical else "norm"
return AstLiteral(functor, [AstString(norm.norm)])
@_astify.register
def _(self, phase: Phase) -> AstExpression:
return AstLiteral("phase", [AstString(str(phase.id))])
@_astify.register
def _(self, goal: Goal, achieved: bool = False) -> AstExpression:
return AstLiteral(f"{'achieved_' if achieved else ''}{self._slugify_str(goal.name)}")
@_astify.register
def _(self, trigger: Trigger) -> AstExpression:
return AstLiteral(self.slugify(trigger))
@_astify.register
def _(self, sa: SpeechAction) -> AstExpression:
return AstLiteral("say", [AstString(sa.text)])
@_astify.register
def _(self, ga: GestureAction) -> AstExpression:
gesture = ga.gesture
return AstLiteral("gesture", [AstString(gesture.type), AstString(gesture.name)])
@_astify.register
def _(self, la: LLMAction) -> AstExpression:
return AstLiteral("reply_with_goal", [AstString(la.goal)])
@singledispatchmethod
@staticmethod
def slugify(element: ProgramElement) -> str:
raise NotImplementedError(f"Cannot convert element {element} to a slug.")
@slugify.register
@staticmethod
def _(n: Norm) -> str:
return f"norm_{AgentSpeakGenerator._slugify_str(n.norm)}"
@slugify.register
@staticmethod
def _(sb: SemanticBelief) -> str:
return f"semantic_{AgentSpeakGenerator._slugify_str(sb.name)}"
@slugify.register
@staticmethod
def _(g: BaseGoal) -> str:
return AgentSpeakGenerator._slugify_str(g.name)
@slugify.register
@staticmethod
def _(t: Trigger):
return f"trigger_{AgentSpeakGenerator._slugify_str(t.name)}"
@staticmethod
def _slugify_str(text: str) -> str:
return slugify(text, separator="_", stopwords=["a", "an", "the", "we", "you", "I"])

View File

@@ -1,5 +1,6 @@
import asyncio
import copy
import json
import time
from collections.abc import Iterable
@@ -11,9 +12,9 @@ from pydantic import ValidationError
from control_backend.agents.base import BaseAgent
from control_backend.core.agent_system import InternalMessage
from control_backend.core.config import settings
from control_backend.schemas.belief_message import Belief, BeliefMessage
from control_backend.schemas.belief_message import BeliefMessage
from control_backend.schemas.llm_prompt_message import LLMPromptMessage
from control_backend.schemas.ri_message import SpeechCommand
from control_backend.schemas.ri_message import GestureCommand, RIEndpoint, SpeechCommand
DELIMITER = ";\n" # TODO: temporary until we support lists in AgentSpeak
@@ -42,13 +43,13 @@ class BDICoreAgent(BaseAgent):
bdi_agent: agentspeak.runtime.Agent
def __init__(self, name: str, asl: str):
def __init__(self, name: str):
super().__init__(name)
self.asl_file = asl
self.env = agentspeak.runtime.Environment()
# Deep copy because we don't actually want to modify the standard actions globally
self.actions = copy.deepcopy(agentspeak.stdlib.actions)
self._wake_bdi_loop = asyncio.Event()
self._bdi_loop_task = None
async def setup(self) -> None:
"""
@@ -65,19 +66,22 @@ class BDICoreAgent(BaseAgent):
await self._load_asl()
# Start the BDI cycle loop
self.add_behavior(self._bdi_loop())
self._bdi_loop_task = self.add_behavior(self._bdi_loop())
self._wake_bdi_loop.set()
self.logger.debug("Setup complete.")
async def _load_asl(self):
async def _load_asl(self, file_name: str | None = None) -> None:
"""
Load and parse the AgentSpeak source file.
"""
file_name = file_name or "src/control_backend/agents/bdi/default_behavior.asl"
try:
with open(self.asl_file) as source:
with open(file_name) as source:
self.bdi_agent = self.env.build_agent(source, self.actions)
self.logger.info(f"Loaded new ASL from {file_name}.")
except FileNotFoundError:
self.logger.warning(f"Could not find the specified ASL file at {self.asl_file}.")
self.logger.warning(f"Could not find the specified ASL file at {file_name}.")
self.bdi_agent = agentspeak.runtime.Agent(self.env, self.name)
async def _bdi_loop(self):
@@ -97,14 +101,12 @@ class BDICoreAgent(BaseAgent):
maybe_more_work = True
while maybe_more_work:
maybe_more_work = False
self.logger.debug("Stepping BDI.")
if self.bdi_agent.step():
maybe_more_work = True
if not maybe_more_work:
deadline = self.bdi_agent.shortest_deadline()
if deadline:
self.logger.debug("Sleeping until %s", deadline)
await asyncio.sleep(deadline - time.time())
maybe_more_work = True
else:
@@ -116,6 +118,7 @@ class BDICoreAgent(BaseAgent):
Handle incoming messages.
- **Beliefs**: Updates the internal belief base.
- **Program**: Updates the internal agentspeak file to match the current program.
- **LLM Responses**: Forwards the generated text to the Robot Speech Agent (actuation).
:param msg: The received internal message.
@@ -124,12 +127,19 @@ class BDICoreAgent(BaseAgent):
if msg.thread == "beliefs":
try:
beliefs = BeliefMessage.model_validate_json(msg.body).beliefs
self._apply_beliefs(beliefs)
belief_changes = BeliefMessage.model_validate_json(msg.body)
self._apply_belief_changes(belief_changes)
except ValidationError:
self.logger.exception("Error processing belief.")
return
# New agentspeak file
if msg.thread == "new_program":
if self._bdi_loop_task:
self._bdi_loop_task.cancel()
await self._load_asl(msg.body)
self.add_behavior(self._bdi_loop())
# The message was not a belief, handle special cases based on sender
match msg.sender:
case settings.agent_settings.llm_name:
@@ -144,23 +154,44 @@ class BDICoreAgent(BaseAgent):
body=cmd.model_dump_json(),
)
await self.send(out_msg)
case settings.agent_settings.user_interrupt_name:
self.logger.debug("Received user interruption: %s", msg)
def _apply_beliefs(self, beliefs: list[Belief]):
match msg.thread:
case "force_phase_transition":
self._set_goal("transition_phase")
case "force_trigger":
self._force_trigger(msg.body)
case "force_norm":
self._force_norm(msg.body)
case "force_next_phase":
self._force_next_phase()
case _:
self.logger.warning("Received unknown user interruption: %s", msg)
def _apply_belief_changes(self, belief_changes: BeliefMessage):
"""
Update the belief base with a list of new beliefs.
If ``replace=True`` is set on a belief, it removes all existing beliefs with that name
before adding the new one.
For beliefs in ``belief_changes.replace``, it removes all existing beliefs with that name
before adding one new one.
:param belief_changes: The changes in beliefs to apply.
"""
if not beliefs:
if not belief_changes.create and not belief_changes.replace and not belief_changes.delete:
return
for belief in beliefs:
if belief.replace:
self._remove_all_with_name(belief.name)
for belief in belief_changes.create:
self._add_belief(belief.name, belief.arguments)
def _add_belief(self, name: str, args: Iterable[str] = []):
for belief in belief_changes.replace:
self._remove_all_with_name(belief.name)
self._add_belief(belief.name, belief.arguments)
for belief in belief_changes.delete:
self._remove_belief(belief.name, belief.arguments)
def _add_belief(self, name: str, args: list[str] = None):
"""
Add a single belief to the BDI agent.
@@ -168,9 +199,13 @@ class BDICoreAgent(BaseAgent):
:param args: Arguments for the belief.
"""
# new_args = (agentspeak.Literal(arg) for arg in args) # TODO: Eventually support multiple
merged_args = DELIMITER.join(arg for arg in args)
new_args = (agentspeak.Literal(merged_args),)
term = agentspeak.Literal(name, new_args)
args = args or []
if args:
merged_args = DELIMITER.join(arg for arg in args)
new_args = (agentspeak.Literal(merged_args),)
term = agentspeak.Literal(name, new_args)
else:
term = agentspeak.Literal(name)
self.bdi_agent.call(
agentspeak.Trigger.addition,
@@ -179,16 +214,35 @@ class BDICoreAgent(BaseAgent):
agentspeak.runtime.Intention(),
)
# Check for transitions
self.bdi_agent.call(
agentspeak.Trigger.addition,
agentspeak.GoalType.achievement,
agentspeak.Literal("transition_phase"),
agentspeak.runtime.Intention(),
)
# Check triggers
self.bdi_agent.call(
agentspeak.Trigger.addition,
agentspeak.GoalType.achievement,
agentspeak.Literal("check_triggers"),
agentspeak.runtime.Intention(),
)
self._wake_bdi_loop.set()
self.logger.debug(f"Added belief {self.format_belief_string(name, args)}")
def _remove_belief(self, name: str, args: Iterable[str]):
def _remove_belief(self, name: str, args: Iterable[str] | None):
"""
Removes a specific belief (with arguments), if it exists.
"""
new_args = (agentspeak.Literal(arg) for arg in args)
term = agentspeak.Literal(name, new_args)
if args is None:
term = agentspeak.Literal(name)
else:
new_args = (agentspeak.Literal(arg) for arg in args)
term = agentspeak.Literal(name, new_args)
result = self.bdi_agent.call(
agentspeak.Trigger.removal,
@@ -228,6 +282,43 @@ class BDICoreAgent(BaseAgent):
self.logger.debug(f"Removed {removed_count} beliefs.")
def _set_goal(self, name: str, args: Iterable[str] | None = None):
args = args or []
if args:
merged_args = DELIMITER.join(arg for arg in args)
new_args = (agentspeak.Literal(merged_args),)
term = agentspeak.Literal(name, new_args)
else:
term = agentspeak.Literal(name)
self.bdi_agent.call(
agentspeak.Trigger.addition,
agentspeak.GoalType.achievement,
term,
agentspeak.runtime.Intention(),
)
self._wake_bdi_loop.set()
self.logger.debug(f"Set goal !{self.format_belief_string(name, args)}.")
def _force_trigger(self, name: str):
self._set_goal(name)
self.logger.info("Manually forced trigger %s.", name)
# TODO: make this compatible for critical norms
def _force_norm(self, name: str):
self._add_belief(f"force_{name}")
self.logger.info("Manually forced norm %s.", name)
def _force_next_phase(self):
self._set_goal("force_transition_phase")
self.logger.info("Manually forced phase transition.")
def _add_custom_actions(self) -> None:
"""
Add any custom actions here. Inside `@self.actions.add()`, the first argument is
@@ -235,43 +326,213 @@ class BDICoreAgent(BaseAgent):
the function expects (which will be located in `term.args`).
"""
@self.actions.add(".reply", 3)
def _reply(agent: "BDICoreAgent", term, intention):
@self.actions.add(".reply", 2)
def _reply(agent, term, intention):
"""
Sends text to the LLM (AgentSpeak action).
Example: .reply("Hello LLM!", "Some norm", "Some goal")
Let the LLM generate a response to a user's utterance with the current norms and goals.
"""
message_text = agentspeak.grounded(term.args[0], intention.scope)
norms = agentspeak.grounded(term.args[1], intention.scope)
goals = agentspeak.grounded(term.args[2], intention.scope)
self.logger.debug("Norms: %s", norms)
self.logger.debug("Goals: %s", goals)
self.logger.debug("User text: %s", message_text)
asyncio.create_task(self._send_to_llm(str(message_text), str(norms), str(goals)))
self.add_behavior(self._send_to_llm(str(message_text), str(norms), ""))
yield
async def _send_to_llm(self, text: str, norms: str = None, goals: str = None):
@self.actions.add(".reply_with_goal", 3)
def _reply_with_goal(agent: "BDICoreAgent", term, intention):
"""
Let the LLM generate a response to a user's utterance with the current norms and a
specific goal.
"""
message_text = agentspeak.grounded(term.args[0], intention.scope)
norms = agentspeak.grounded(term.args[1], intention.scope)
goal = agentspeak.grounded(term.args[2], intention.scope)
self.add_behavior(self._send_to_llm(str(message_text), str(norms), str(goal)))
yield
@self.actions.add(".notify_norms", 1)
def _notify_norms(agent, term, intention):
norms = agentspeak.grounded(term.args[0], intention.scope)
norm_update_message = InternalMessage(
to=settings.agent_settings.user_interrupt_name,
thread="active_norms_update",
body=str(norms),
)
self.add_behavior(self.send(norm_update_message, should_log=False))
yield
@self.actions.add(".say", 1)
def _say(agent, term, intention):
"""
Make the robot say the given text instantly.
"""
message_text = agentspeak.grounded(term.args[0], intention.scope)
self.logger.debug('"say" action called with text=%s', message_text)
speech_command = SpeechCommand(data=message_text)
speech_message = InternalMessage(
to=settings.agent_settings.robot_speech_name,
sender=settings.agent_settings.bdi_core_name,
body=speech_command.model_dump_json(),
)
self.add_behavior(self.send(speech_message))
chat_history_message = InternalMessage(
to=settings.agent_settings.llm_name,
thread="assistant_message",
body=str(message_text),
)
self.add_behavior(self.send(chat_history_message))
yield
@self.actions.add(".gesture", 2)
def _gesture(agent, term, intention):
"""
Make the robot perform the given gesture instantly.
"""
gesture_type = agentspeak.grounded(term.args[0], intention.scope)
gesture_name = agentspeak.grounded(term.args[1], intention.scope)
self.logger.debug(
'"gesture" action called with type=%s, name=%s',
gesture_type,
gesture_name,
)
if str(gesture_type) == "single":
endpoint = RIEndpoint.GESTURE_SINGLE
elif str(gesture_type) == "tag":
endpoint = RIEndpoint.GESTURE_TAG
else:
self.logger.warning("Gesture type %s could not be resolved.", gesture_type)
endpoint = RIEndpoint.GESTURE_SINGLE
gesture_command = GestureCommand(endpoint=endpoint, data=gesture_name)
gesture_message = InternalMessage(
to=settings.agent_settings.robot_gesture_name,
sender=settings.agent_settings.bdi_core_name,
body=gesture_command.model_dump_json(),
)
self.add_behavior(self.send(gesture_message))
yield
@self.actions.add(".notify_user_said", 1)
def _notify_user_said(agent, term, intention):
user_said = agentspeak.grounded(term.args[0], intention.scope)
msg = InternalMessage(
to=settings.agent_settings.llm_name, thread="user_message", body=str(user_said)
)
self.add_behavior(self.send(msg))
yield
@self.actions.add(".notify_trigger_start", 1)
def _notify_trigger_start(agent, term, intention):
"""
Notify the UI about the trigger we just started doing.
"""
trigger_name = agentspeak.grounded(term.args[0], intention.scope)
self.logger.debug("Started trigger %s", trigger_name)
msg = InternalMessage(
to=settings.agent_settings.user_interrupt_name,
sender=self.name,
thread="trigger_start",
body=str(trigger_name),
)
# TODO: check with Pim
self.add_behavior(self.send(msg))
yield
@self.actions.add(".notify_trigger_end", 1)
def _notify_trigger_end(agent, term, intention):
"""
Notify the UI about the trigger we just started doing.
"""
trigger_name = agentspeak.grounded(term.args[0], intention.scope)
self.logger.debug("Finished trigger %s", trigger_name)
msg = InternalMessage(
to=settings.agent_settings.user_interrupt_name,
sender=self.name,
thread="trigger_end",
body=str(trigger_name),
)
self.add_behavior(self.send(msg))
yield
@self.actions.add(".notify_goal_start", 1)
def _notify_goal_start(agent, term, intention):
"""
Notify the UI about the goal we just started chasing.
"""
goal_name = agentspeak.grounded(term.args[0], intention.scope)
self.logger.debug("Started chasing goal %s", goal_name)
msg = InternalMessage(
to=settings.agent_settings.user_interrupt_name,
sender=self.name,
thread="goal_start",
body=str(goal_name),
)
self.add_behavior(self.send(msg))
yield
@self.actions.add(".notify_transition_phase", 2)
def _notify_transition_phase(agent, term, intention):
"""
Notify the BDI program manager about a phase transition.
"""
old = agentspeak.grounded(term.args[0], intention.scope)
new = agentspeak.grounded(term.args[1], intention.scope)
msg = InternalMessage(
to=settings.agent_settings.bdi_program_manager_name,
thread="transition_phase",
body=json.dumps({"old": str(old), "new": str(new)}),
)
self.add_behavior(self.send(msg))
yield
@self.actions.add(".notify_ui", 0)
def _notify_ui(agent, term, intention):
pass
async def _send_to_llm(self, text: str, norms: str, goals: str):
"""
Sends a text query to the LLM agent asynchronously.
"""
prompt = LLMPromptMessage(
text=text,
norms=norms.split("\n") if norms else [],
goals=goals.split("\n") if norms else [],
)
prompt = LLMPromptMessage(text=text, norms=norms.split("\n"), goals=goals.split("\n"))
msg = InternalMessage(
to=settings.agent_settings.llm_name,
sender=self.name,
body=prompt.model_dump_json(),
thread="prompt_message",
)
await self.send(msg)
self.logger.info("Message sent to LLM agent: %s", text)
@staticmethod
def format_belief_string(name: str, args: Iterable[str] = []):
def format_belief_string(name: str, args: Iterable[str] | None = []):
"""
Given a belief's name and its args, return a string of the form "name(*args)"
"""
return f"{name}{'(' if args else ''}{','.join(args)}{')' if args else ''}"
return f"{name}{'(' if args else ''}{','.join(args or [])}{')' if args else ''}"

View File

@@ -1,12 +1,23 @@
import asyncio
import json
import zmq
from pydantic import ValidationError
from zmq.asyncio import Context
from control_backend.agents import BaseAgent
from control_backend.core.agent_system import InternalMessage
from control_backend.agents.bdi.agentspeak_generator import AgentSpeakGenerator
from control_backend.core.config import settings
from control_backend.schemas.belief_message import Belief, BeliefMessage
from control_backend.schemas.program import Program
from control_backend.schemas.belief_list import BeliefList, GoalList
from control_backend.schemas.internal_message import InternalMessage
from control_backend.schemas.program import (
Belief,
ConditionalNorm,
Goal,
InferredBelief,
Phase,
Program,
)
class BDIProgramManager(BaseAgent):
@@ -21,44 +32,214 @@ class BDIProgramManager(BaseAgent):
:ivar sub_socket: The ZMQ SUB socket used to receive program updates.
"""
_program: Program
_phase: Phase | None
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.sub_socket = None
async def _send_to_bdi(self, program: Program):
def _initialize_internal_state(self, program: Program):
self._program = program
self._phase = program.phases[0] # start in first phase
self._goal_mapping: dict[str, Goal] = {}
for phase in program.phases:
for goal in phase.goals:
self._populate_goal_mapping_with_goal(goal)
def _populate_goal_mapping_with_goal(self, goal: Goal):
self._goal_mapping[str(goal.id)] = goal
for step in goal.plan.steps:
if isinstance(step, Goal):
self._populate_goal_mapping_with_goal(step)
async def _create_agentspeak_and_send_to_bdi(self, program: Program):
"""
Convert a received program into BDI beliefs and send them to the BDI Core Agent.
Currently, it takes the **first phase** of the program and extracts:
- **Norms**: Constraints or rules the agent must follow.
- **Goals**: Objectives the agent must achieve.
These are sent as a ``BeliefMessage`` with ``replace=True``, meaning they will
overwrite any existing norms/goals of the same name in the BDI agent.
Convert a received program into an AgentSpeak file and send it to the BDI Core Agent.
:param program: The program object received from the API.
"""
first_phase = program.phases[0]
norms_belief = Belief(
name="norms",
arguments=[norm.norm for norm in first_phase.norms],
replace=True,
asg = AgentSpeakGenerator()
asl_str = asg.generate(program)
file_name = "src/control_backend/agents/bdi/agentspeak.asl"
with open(file_name, "w") as f:
f.write(asl_str)
msg = InternalMessage(
sender=self.name,
to=settings.agent_settings.bdi_core_name,
body=file_name,
thread="new_program",
)
goals_belief = Belief(
name="goals",
arguments=[goal.description for goal in first_phase.goals],
replace=True,
await self.send(msg)
async def handle_message(self, msg: InternalMessage):
match msg.thread:
case "transition_phase":
phases = json.loads(msg.body)
await self._transition_phase(phases["old"], phases["new"])
case "achieve_goal":
goal_id = msg.body
await self._send_achieved_goal_to_semantic_belief_extractor(goal_id)
async def _transition_phase(self, old: str, new: str):
if old != str(self._phase.id):
self.logger.warning(
f"Phase transition desync detected! ASL requested move from '{old}', "
f"but Python is currently in '{self._phase.id}'. Request ignored."
)
return
if new == "end":
self._phase = None
# Notify user interaction agent
msg = InternalMessage(
to=settings.agent_settings.user_interrupt_name,
thread="transition_phase",
body="end",
)
self.logger.info("Transitioned to end phase, notifying UserInterruptAgent.")
self.add_behavior(self.send(msg))
return
for phase in self._program.phases:
if str(phase.id) == new:
self._phase = phase
await self._send_beliefs_to_semantic_belief_extractor()
await self._send_goals_to_semantic_belief_extractor()
# Notify user interaction agent
msg = InternalMessage(
to=settings.agent_settings.user_interrupt_name,
thread="transition_phase",
body=str(self._phase.id),
)
program_beliefs = BeliefMessage(beliefs=[norms_belief, goals_belief])
self.logger.info(f"Transitioned to phase {new}, notifying UserInterruptAgent.")
self.add_behavior(self.send(msg))
def _extract_current_beliefs(self) -> list[Belief]:
beliefs: list[Belief] = []
for norm in self._phase.norms:
if isinstance(norm, ConditionalNorm):
beliefs += self._extract_beliefs_from_belief(norm.condition)
for trigger in self._phase.triggers:
beliefs += self._extract_beliefs_from_belief(trigger.condition)
return beliefs
@staticmethod
def _extract_beliefs_from_belief(belief: Belief) -> list[Belief]:
if isinstance(belief, InferredBelief):
return BDIProgramManager._extract_beliefs_from_belief(
belief.left
) + BDIProgramManager._extract_beliefs_from_belief(belief.right)
return [belief]
async def _send_beliefs_to_semantic_belief_extractor(self):
"""
Extract beliefs from the program and send them to the Semantic Belief Extractor Agent.
"""
beliefs = BeliefList(beliefs=self._extract_current_beliefs())
message = InternalMessage(
to=settings.agent_settings.bdi_core_name,
to=settings.agent_settings.text_belief_extractor_name,
sender=self.name,
body=program_beliefs.model_dump_json(),
body=beliefs.model_dump_json(),
thread="beliefs",
)
await self.send(message)
self.logger.debug("Sent new norms and goals to the BDI agent.")
@staticmethod
def _extract_goals_from_goal(goal: Goal) -> list[Goal]:
"""
Extract all goals from a given goal, that is: the goal itself and any subgoals.
:return: All goals within and including the given goal.
"""
goals: list[Goal] = [goal]
for plan in goal.plan:
if isinstance(plan, Goal):
goals.extend(BDIProgramManager._extract_goals_from_goal(plan))
return goals
def _extract_current_goals(self) -> list[Goal]:
"""
Extract all goals from the program, including subgoals.
:return: A list of Goal objects.
"""
goals: list[Goal] = []
for goal in self._phase.goals:
goals.extend(self._extract_goals_from_goal(goal))
return goals
async def _send_goals_to_semantic_belief_extractor(self):
"""
Extract goals for the current phase and send them to the Semantic Belief Extractor Agent.
"""
goals = GoalList(goals=self._extract_current_goals())
message = InternalMessage(
to=settings.agent_settings.text_belief_extractor_name,
sender=self.name,
body=goals.model_dump_json(),
thread="goals",
)
await self.send(message)
async def _send_achieved_goal_to_semantic_belief_extractor(self, achieved_goal_id: str):
"""
Inform the semantic belief extractor when a goal is marked achieved.
:param achieved_goal_id: The id of the achieved goal.
"""
goal = self._goal_mapping.get(achieved_goal_id)
if goal is None:
self.logger.debug(f"Goal with ID {achieved_goal_id} marked achieved but was not found.")
return
goals = self._extract_goals_from_goal(goal)
message = InternalMessage(
to=settings.agent_settings.text_belief_extractor_name,
body=GoalList(goals=goals).model_dump_json(),
thread="achieved_goals",
)
await self.send(message)
async def _send_clear_llm_history(self):
"""
Clear the LLM Agent's conversation history.
Sends an empty history to the LLM Agent to reset its state.
"""
message = InternalMessage(
to=settings.agent_settings.llm_name,
body="clear_history",
)
await self.send(message)
self.logger.debug("Sent message to LLM agent to clear history.")
extractor_msg = InternalMessage(
to=settings.agent_settings.text_belief_extractor_name,
thread="conversation_history",
body="reset",
)
await self.send(extractor_msg)
self.logger.debug("Sent message to extractor agent to clear history.")
async def _receive_programs(self):
"""
@@ -66,6 +247,7 @@ class BDIProgramManager(BaseAgent):
It listens to the ``program`` topic on the internal ZMQ SUB socket.
When a program is received, it is validated and forwarded to BDI via :meth:`_send_to_bdi`.
Additionally, the LLM history is cleared via :meth:`_send_clear_llm_history`.
"""
while True:
topic, body = await self.sub_socket.recv_multipart()
@@ -73,18 +255,43 @@ class BDIProgramManager(BaseAgent):
try:
program = Program.model_validate_json(body)
except ValidationError:
self.logger.exception("Received an invalid program.")
self.logger.warning("Received an invalid program.")
continue
await self._send_to_bdi(program)
self._initialize_internal_state(program)
await self._send_program_to_user_interrupt(program)
await self._send_clear_llm_history()
await asyncio.gather(
self._create_agentspeak_and_send_to_bdi(program),
self._send_beliefs_to_semantic_belief_extractor(),
self._send_goals_to_semantic_belief_extractor(),
)
async def _send_program_to_user_interrupt(self, program: Program):
"""
Send the received program to the User Interrupt Agent.
:param program: The program object received from the API.
"""
msg = InternalMessage(
sender=self.name,
to=settings.agent_settings.user_interrupt_name,
body=program.model_dump_json(),
thread="new_program",
)
await self.send(msg)
async def setup(self):
"""
Initialize the agent.
Connects the internal ZMQ SUB socket and subscribes to the 'program' topic.
Starts the background behavior to receive programs.
Starts the background behavior to receive programs. Initializes a default program.
"""
await self._create_agentspeak_and_send_to_bdi(Program(phases=[]))
context = Context.instance()
self.sub_socket = context.socket(zmq.SUB)

View File

@@ -1,152 +0,0 @@
import json
from pydantic import ValidationError
from control_backend.agents.base import BaseAgent
from control_backend.core.agent_system import InternalMessage
from control_backend.core.config import settings
from control_backend.schemas.belief_message import Belief, BeliefMessage
class BDIBeliefCollectorAgent(BaseAgent):
"""
BDI Belief Collector Agent.
This agent acts as a central aggregator for beliefs derived from various sources (e.g., text,
emotion, vision). It receives raw extracted data from other agents,
normalizes them into valid :class:`Belief` objects, and forwards them as a unified packet to the
BDI Core Agent.
It serves as a funnel to ensure the BDI agent receives a consistent stream of beliefs.
"""
async def setup(self):
"""
Initialize the agent.
"""
self.logger.info("Setting up %s", self.name)
async def handle_message(self, msg: InternalMessage):
"""
Handle incoming messages from other extractor agents.
Routes the message to specific handlers based on the 'type' field in the JSON body.
Supported types:
- ``belief_extraction_text``: Handled by :meth:`_handle_belief_text`
- ``emotion_extraction_text``: Handled by :meth:`_handle_emo_text`
:param msg: The received internal message.
"""
sender_node = msg.sender
# Parse JSON payload
try:
payload = json.loads(msg.body)
except Exception as e:
self.logger.warning(
"BeliefCollector: failed to parse JSON from %s. Body=%r Error=%s",
sender_node,
msg.body,
e,
)
return
msg_type = payload.get("type")
# Prefer explicit 'type' field
if msg_type == "belief_extraction_text":
self.logger.debug("Message routed to _handle_belief_text (sender=%s)", sender_node)
await self._handle_belief_text(payload, sender_node)
# This is not implemented yet, but we keep the structure for future use
elif msg_type == "emotion_extraction_text":
self.logger.debug("Message routed to _handle_emo_text (sender=%s)", sender_node)
await self._handle_emo_text(payload, sender_node)
else:
self.logger.warning(
"Unrecognized message (sender=%s, type=%r). Ignoring.", sender_node, msg_type
)
async def _handle_belief_text(self, payload: dict, origin: str):
"""
Process text-based belief extraction payloads.
Expected payload format::
{
"type": "belief_extraction_text",
"beliefs": {
"user_said": ["Can you help me?"],
"intention": ["ask_help"]
}
}
Validates and converts the dictionary items into :class:`Belief` objects.
:param payload: The dictionary payload containing belief data.
:param origin: The name of the sender agent.
"""
beliefs = payload.get("beliefs", {})
if not beliefs:
self.logger.debug("Received empty beliefs set.")
return
def try_create_belief(name, arguments) -> Belief | None:
"""
Create a belief object from name and arguments, or return None silently if the input is
not correct.
:param name: The name of the belief.
:param arguments: The arguments of the belief.
:return: A Belief object if the input is valid or None.
"""
try:
return Belief(name=name, arguments=arguments)
except ValidationError:
return None
beliefs = [
belief
for name, arguments in beliefs.items()
if (belief := try_create_belief(name, arguments)) is not None
]
self.logger.debug("Forwarding %d beliefs.", len(beliefs))
for belief in beliefs:
for argument in belief.arguments:
self.logger.debug(" - %s %s", belief.name, argument)
await self._send_beliefs_to_bdi(beliefs, origin=origin)
async def _handle_emo_text(self, payload: dict, origin: str):
"""
Process emotion extraction payloads.
**TODO**: Implement this method once emotion recognition is integrated.
:param payload: The dictionary payload containing emotion data.
:param origin: The name of the sender agent.
"""
pass
async def _send_beliefs_to_bdi(self, beliefs: list[Belief], origin: str | None = None):
"""
Send a list of aggregated beliefs to the BDI Core Agent.
Wraps the beliefs in a :class:`BeliefMessage` and sends it via the 'beliefs' thread.
:param beliefs: The list of Belief objects to send.
:param origin: (Optional) The original source of the beliefs (unused currently).
"""
if not beliefs:
return
msg = InternalMessage(
to=settings.agent_settings.bdi_core_name,
sender=self.name,
body=BeliefMessage(beliefs=beliefs).model_dump_json(),
thread="beliefs",
)
await self.send(msg)
self.logger.info("Sent %d belief(s) to BDI core.", len(beliefs))

View File

@@ -0,0 +1,34 @@
phase("end").
keyword_said(Keyword) :- (user_said(Message) & .substring(Keyword, Message, Pos)) & (Pos >= 0).
+!reply_with_goal(Goal)
: user_said(Message)
<- +responded_this_turn;
.findall(Norm, norm(Norm), Norms);
.reply_with_goal(Message, Norms, Goal).
+!say(Text)
<- +responded_this_turn;
.say(Text).
+!reply
: user_said(Message)
<- +responded_this_turn;
.findall(Norm, norm(Norm), Norms);
.reply(Message, Norms).
+!notify_cycle
<- .notify_ui;
.wait(1).
+user_said(Message)
: phase("end")
<- .notify_user_said(Message);
!reply.
+!check_triggers
<- true.
+!transition_phase
<- true.

View File

@@ -1,6 +0,0 @@
norms("").
goals("").
+user_said(Message) : norms(Norms) & goals(Goals) <-
-user_said(Message);
.reply(Message, Norms, Goals).

View File

@@ -1,8 +1,46 @@
import asyncio
import json
import httpx
from pydantic import BaseModel, ValidationError
from control_backend.agents.base import BaseAgent
from control_backend.agents.bdi.agentspeak_generator import AgentSpeakGenerator
from control_backend.core.agent_system import InternalMessage
from control_backend.core.config import settings
from control_backend.schemas.belief_list import BeliefList, GoalList
from control_backend.schemas.belief_message import Belief as InternalBelief
from control_backend.schemas.belief_message import BeliefMessage
from control_backend.schemas.chat_history import ChatHistory, ChatMessage
from control_backend.schemas.program import BaseGoal, SemanticBelief
type JSONLike = None | bool | int | float | str | list["JSONLike"] | dict[str, "JSONLike"]
class BeliefState(BaseModel):
true: set[InternalBelief] = set()
false: set[InternalBelief] = set()
def difference(self, other: "BeliefState") -> "BeliefState":
return BeliefState(
true=self.true - other.true,
false=self.false - other.false,
)
def union(self, other: "BeliefState") -> "BeliefState":
return BeliefState(
true=self.true | other.true,
false=self.false | other.false,
)
def __sub__(self, other):
return self.difference(other)
def __or__(self, other):
return self.union(other)
def __bool__(self):
return bool(self.true) or bool(self.false)
class TextBeliefExtractorAgent(BaseAgent):
@@ -12,54 +50,454 @@ class TextBeliefExtractorAgent(BaseAgent):
This agent is responsible for processing raw text (e.g., from speech transcription) and
extracting semantic beliefs from it.
In the current demonstration version, it performs a simple wrapping of the user's input
into a ``user_said`` belief. In a full implementation, this agent would likely interact
with an LLM or NLU engine to extract intent, entities, and other structured information.
It uses the available beliefs received from the program manager to try to extract beliefs from a
user's message, sends and updated beliefs to the BDI core, and forms a ``user_said`` belief from
the message itself.
"""
def __init__(self, name: str):
super().__init__(name)
self._llm = self.LLM(self, settings.llm_settings.n_parallel)
self.belief_inferrer = SemanticBeliefInferrer(self._llm)
self.goal_inferrer = GoalAchievementInferrer(self._llm)
self._current_beliefs = BeliefState()
self._current_goal_completions: dict[str, bool] = {}
self._force_completed_goals: set[BaseGoal] = set()
self.conversation = ChatHistory(messages=[])
async def setup(self):
"""
Initialize the agent and its resources.
"""
self.logger.info("Settting up %s.", self.name)
# Setup LLM belief context if needed (currently demo is just passthrough)
self.beliefs = {"mood": ["X"], "car": ["Y"]}
self.logger.info("Setting up %s.", self.name)
async def handle_message(self, msg: InternalMessage):
"""
Handle incoming messages, primarily from the Transcription Agent.
Handle incoming messages. Expect messages from the Transcriber agent, LLM agent, and the
Program manager agent.
:param msg: The received message containing transcribed text.
:param msg: The received message.
"""
sender = msg.sender
if sender == settings.agent_settings.transcription_name:
self.logger.debug("Received text from transcriber: %s", msg.body)
await self._process_transcription_demo(msg.body)
else:
self.logger.info("Discarding message from %s", sender)
async def _process_transcription_demo(self, txt: str):
match sender:
case settings.agent_settings.transcription_name:
self.logger.debug("Received text from transcriber: %s", msg.body)
self._apply_conversation_message(ChatMessage(role="user", content=msg.body))
await self._user_said(msg.body)
await self._infer_new_beliefs()
await self._infer_goal_completions()
case settings.agent_settings.llm_name:
self.logger.debug("Received text from LLM: %s", msg.body)
self._apply_conversation_message(ChatMessage(role="assistant", content=msg.body))
case settings.agent_settings.bdi_program_manager_name:
await self._handle_program_manager_message(msg)
case _:
self.logger.info("Discarding message from %s", sender)
return
def _apply_conversation_message(self, message: ChatMessage):
"""
Process the transcribed text and generate beliefs.
Save the chat message to our conversation history, taking into account the conversation
length limit.
**Demo Implementation:**
Currently, this method takes the raw text ``txt`` and wraps it into a belief structure:
``user_said("txt")``.
This belief is then sent to the :class:`BDIBeliefCollectorAgent`.
:param txt: The raw transcribed text string.
:param message: The chat message to add to the conversation history.
"""
# For demo, just wrapping user text as user_said belief
belief = {"beliefs": {"user_said": [txt]}, "type": "belief_extraction_text"}
payload = json.dumps(belief)
length_limit = settings.behaviour_settings.conversation_history_length_limit
self.conversation.messages = (self.conversation.messages + [message])[-length_limit:]
belief_msg = InternalMessage(
to=settings.agent_settings.bdi_belief_collector_name,
sender=self.name,
body=payload,
thread="beliefs",
async def _handle_program_manager_message(self, msg: InternalMessage):
"""
Handle a message from the program manager: extract available beliefs and goals from it.
:param msg: The received message from the program manager.
"""
match msg.thread:
case "beliefs":
self._handle_beliefs_message(msg)
await self._infer_new_beliefs()
case "goals":
self._handle_goals_message(msg)
await self._infer_goal_completions()
case "achieved_goals":
self._handle_goal_achieved_message(msg)
case "conversation_history":
if msg.body == "reset":
self._reset_phase()
case _:
self.logger.warning("Received unexpected message from %s", msg.sender)
def _reset_phase(self):
self.conversation = ChatHistory(messages=[])
self.belief_inferrer.available_beliefs.clear()
self._current_beliefs = BeliefState()
self.goal_inferrer.goals.clear()
self._current_goal_completions = {}
def _handle_beliefs_message(self, msg: InternalMessage):
try:
belief_list = BeliefList.model_validate_json(msg.body)
except ValidationError:
self.logger.warning(
"Received message from program manager but it is not a valid list of beliefs."
)
return
available_beliefs = [b for b in belief_list.beliefs if isinstance(b, SemanticBelief)]
self.belief_inferrer.available_beliefs = available_beliefs
self.logger.debug(
"Received %d semantic beliefs from the program manager: %s",
len(available_beliefs),
", ".join(b.name for b in available_beliefs),
)
def _handle_goals_message(self, msg: InternalMessage):
try:
goals_list = GoalList.model_validate_json(msg.body)
except ValidationError:
self.logger.warning(
"Received message from program manager but it is not a valid list of goals."
)
return
# Use only goals that can fail, as the others are always assumed to be completed
available_goals = {g for g in goals_list.goals if g.can_fail}
available_goals -= self._force_completed_goals
self.goal_inferrer.goals = available_goals
self.logger.debug(
"Received %d failable goals from the program manager: %s",
len(available_goals),
", ".join(g.name for g in available_goals),
)
def _handle_goal_achieved_message(self, msg: InternalMessage):
# NOTE: When goals can be marked unachieved, remember to re-add them to the goal_inferrer
try:
goals_list = GoalList.model_validate_json(msg.body)
except ValidationError:
self.logger.warning(
"Received goal achieved message from the program manager, "
"but it is not a valid list of goals."
)
return
for goal in goals_list.goals:
self._force_completed_goals.add(goal)
self._current_goal_completions[f"achieved_{AgentSpeakGenerator.slugify(goal)}"] = True
self.goal_inferrer.goals -= self._force_completed_goals
async def _user_said(self, text: str):
"""
Create a belief for the user's full speech.
:param text: User's transcribed text.
"""
belief_msg = InternalMessage(
to=settings.agent_settings.bdi_core_name,
sender=self.name,
body=BeliefMessage(
replace=[InternalBelief(name="user_said", arguments=[text])],
).model_dump_json(),
thread="beliefs",
)
await self.send(belief_msg)
self.logger.info("Sent %d beliefs to the belief collector.", len(belief["beliefs"]))
async def _infer_new_beliefs(self):
conversation_beliefs = await self.belief_inferrer.infer_from_conversation(self.conversation)
new_beliefs = conversation_beliefs - self._current_beliefs
if not new_beliefs:
self.logger.debug("No new beliefs detected.")
return
self._current_beliefs |= new_beliefs
belief_changes = BeliefMessage(
create=list(new_beliefs.true),
delete=list(new_beliefs.false),
)
message = InternalMessage(
to=settings.agent_settings.bdi_core_name,
sender=self.name,
body=belief_changes.model_dump_json(),
thread="beliefs",
)
await self.send(message)
async def _infer_goal_completions(self):
goal_completions = await self.goal_inferrer.infer_from_conversation(self.conversation)
new_achieved = [
InternalBelief(name=goal, arguments=None)
for goal, achieved in goal_completions.items()
if achieved and self._current_goal_completions.get(goal) != achieved
]
new_not_achieved = [
InternalBelief(name=goal, arguments=None)
for goal, achieved in goal_completions.items()
if not achieved and self._current_goal_completions.get(goal) != achieved
]
for goal, achieved in goal_completions.items():
self._current_goal_completions[goal] = achieved
if not new_achieved and not new_not_achieved:
self.logger.debug("No goal achievement changes detected.")
return
belief_changes = BeliefMessage(
create=new_achieved,
delete=new_not_achieved,
)
message = InternalMessage(
to=settings.agent_settings.bdi_core_name,
sender=self.name,
body=belief_changes.model_dump_json(),
thread="beliefs",
)
await self.send(message)
class LLM:
"""
Class that handles sending structured generation requests to an LLM.
"""
def __init__(self, agent: "TextBeliefExtractorAgent", n_parallel: int):
self._agent = agent
self._semaphore = asyncio.Semaphore(n_parallel)
async def query(self, prompt: str, schema: dict, tries: int = 3) -> JSONLike | None:
"""
Query the LLM with the given prompt and schema, return an instance of a dict conforming
to this schema. Try ``tries`` times, or return None.
:param prompt: Prompt to be queried.
:param schema: Schema to be queried.
:param tries: Number of times to try to query the LLM.
:return: An instance of a dict conforming to this schema, or None if failed.
"""
try_count = 0
while try_count < tries:
try_count += 1
try:
return await self._query_llm(prompt, schema)
except (httpx.HTTPError, json.JSONDecodeError, KeyError) as e:
if try_count < tries:
continue
self._agent.logger.exception(
"Failed to get LLM response after %d tries.",
try_count,
exc_info=e,
)
return None
async def _query_llm(self, prompt: str, schema: dict) -> JSONLike:
"""
Query an LLM with the given prompt and schema, return an instance of a dict conforming
to that schema.
:param prompt: The prompt to be queried.
:param schema: Schema to use during response.
:return: A dict conforming to this schema.
:raises httpx.HTTPStatusError: If the LLM server responded with an error.
:raises json.JSONDecodeError: If the LLM response was not valid JSON. May happen if the
response was cut off early due to length limitations.
:raises KeyError: If the LLM server responded with no error, but the response was
invalid.
"""
async with self._semaphore:
async with httpx.AsyncClient() as client:
response = await client.post(
settings.llm_settings.local_llm_url,
json={
"model": settings.llm_settings.local_llm_model,
"messages": [{"role": "user", "content": prompt}],
"response_format": {
"type": "json_schema",
"json_schema": {
"name": "Beliefs",
"strict": True,
"schema": schema,
},
},
"reasoning_effort": "low",
"temperature": settings.llm_settings.code_temperature,
"stream": False,
},
timeout=30.0,
)
response.raise_for_status()
response_json = response.json()
json_message = response_json["choices"][0]["message"]["content"]
return json.loads(json_message)
class SemanticBeliefInferrer:
"""
Class that handles only prompting an LLM for semantic beliefs.
"""
def __init__(
self,
llm: "TextBeliefExtractorAgent.LLM",
available_beliefs: list[SemanticBelief] | None = None,
):
self._llm = llm
self.available_beliefs: list[SemanticBelief] = available_beliefs or []
async def infer_from_conversation(self, conversation: ChatHistory) -> BeliefState:
"""
Process conversation history to extract beliefs, semantically. The result is an object that
describes all beliefs that hold or don't hold based on the full conversation.
:param conversation: The conversation history to be processed.
:return: An object that describes beliefs.
"""
# Return instantly if there are no beliefs to infer
if not self.available_beliefs:
return BeliefState()
n_parallel = max(1, min(settings.llm_settings.n_parallel - 1, len(self.available_beliefs)))
all_beliefs: list[dict[str, bool | None] | None] = await asyncio.gather(
*[
self._infer_beliefs(conversation, beliefs)
for beliefs in self._split_into_chunks(self.available_beliefs, n_parallel)
]
)
retval = BeliefState()
for beliefs in all_beliefs:
if beliefs is None:
continue
for belief_name, belief_holds in beliefs.items():
if belief_holds is None:
continue
belief = InternalBelief(name=belief_name, arguments=None)
if belief_holds:
retval.true.add(belief)
else:
retval.false.add(belief)
return retval
@staticmethod
def _split_into_chunks[T](items: list[T], n: int) -> list[list[T]]:
"""
Split a list into ``n`` chunks, making each chunk approximately ``len(items) / n`` long.
:param items: The list of items to split.
:param n: The number of desired chunks.
:return: A list of chunks each approximately ``len(items) / n`` long.
"""
k, m = divmod(len(items), n)
return [items[i * k + min(i, m) : (i + 1) * k + min(i + 1, m)] for i in range(n)]
async def _infer_beliefs(
self,
conversation: ChatHistory,
beliefs: list[SemanticBelief],
) -> dict[str, bool | None] | None:
"""
Infer given beliefs based on the given conversation.
:param conversation: The conversation to infer beliefs from.
:param beliefs: The beliefs to infer.
:return: A dict containing belief names and a boolean whether they hold, or None if the
belief cannot be inferred based on the given conversation.
"""
example = {
"example_belief": True,
}
prompt = f"""{self._format_conversation(conversation)}
Given the above conversation, what beliefs can be inferred?
If there is no relevant information about a belief belief, give null.
In case messages conflict, prefer using the most recent messages for inference.
Choose from the following list of beliefs, formatted as `- <belief_name>: <description>`:
{self._format_beliefs(beliefs)}
Respond with a JSON similar to the following, but with the property names as given above:
{json.dumps(example, indent=2)}
"""
schema = self._create_beliefs_schema(beliefs)
return await self._llm.query(prompt, schema)
@staticmethod
def _create_belief_schema(belief: SemanticBelief) -> tuple[str, dict]:
return AgentSpeakGenerator.slugify(belief), {
"type": ["boolean", "null"],
"description": belief.description,
}
@staticmethod
def _create_beliefs_schema(beliefs: list[SemanticBelief]) -> dict:
belief_schemas = [
SemanticBeliefInferrer._create_belief_schema(belief) for belief in beliefs
]
return {
"type": "object",
"properties": dict(belief_schemas),
"required": [name for name, _ in belief_schemas],
}
@staticmethod
def _format_message(message: ChatMessage):
return f"{message.role.upper()}:\n{message.content}"
@staticmethod
def _format_conversation(conversation: ChatHistory):
return "\n\n".join(
[SemanticBeliefInferrer._format_message(message) for message in conversation.messages]
)
@staticmethod
def _format_beliefs(beliefs: list[SemanticBelief]):
return "\n".join(
[f"- {AgentSpeakGenerator.slugify(belief)}: {belief.description}" for belief in beliefs]
)
class GoalAchievementInferrer(SemanticBeliefInferrer):
def __init__(self, llm: TextBeliefExtractorAgent.LLM):
super().__init__(llm)
self.goals: set[BaseGoal] = set()
async def infer_from_conversation(self, conversation: ChatHistory) -> dict[str, bool]:
"""
Determine which goals have been achieved based on the given conversation.
:param conversation: The conversation to infer goal completion from.
:return: A mapping of goals and a boolean whether they have been achieved.
"""
if not self.goals:
return {}
goals_achieved = await asyncio.gather(
*[self._infer_goal(conversation, g) for g in self.goals]
)
return {
f"achieved_{AgentSpeakGenerator.slugify(goal)}": achieved
for goal, achieved in zip(self.goals, goals_achieved, strict=True)
}
async def _infer_goal(self, conversation: ChatHistory, goal: BaseGoal) -> bool:
prompt = f"""{self._format_conversation(conversation)}
Given the above conversation, what has the following goal been achieved?
The name of the goal: {goal.name}
Description of the goal: {goal.description}
Answer with literally only `true` or `false` (without backticks)."""
schema = {
"type": "boolean",
}
return await self._llm.query(prompt, schema)

View File

@@ -3,13 +3,17 @@ import json
import zmq
import zmq.asyncio as azmq
from pydantic import ValidationError
from zmq.asyncio import Context
from control_backend.agents import BaseAgent
from control_backend.agents.actuation.robot_gesture_agent import RobotGestureAgent
from control_backend.core.config import settings
from control_backend.schemas.internal_message import InternalMessage
from control_backend.schemas.ri_message import PauseCommand
from ..actuation.robot_speech_agent import RobotSpeechAgent
from ..perception import VADAgent
class RICommunicationAgent(BaseAgent):
@@ -37,7 +41,7 @@ class RICommunicationAgent(BaseAgent):
def __init__(
self,
name: str,
address=settings.zmq_settings.ri_command_address,
address=settings.zmq_settings.ri_communication_address,
bind=False,
):
super().__init__(name)
@@ -46,6 +50,8 @@ class RICommunicationAgent(BaseAgent):
self._req_socket: azmq.Socket | None = None
self.pub_socket: azmq.Socket | None = None
self.connected = False
self.gesture_agent: RobotGestureAgent | None = None
self.speech_agent: RobotSpeechAgent | None = None
async def setup(self):
"""
@@ -139,6 +145,7 @@ class RICommunicationAgent(BaseAgent):
# At this point, we have a valid response
try:
self.logger.debug("Negotiation successful. Handling rn")
await self._handle_negotiation_response(received_message)
# Let UI know that we're connected
topic = b"ping"
@@ -167,7 +174,7 @@ class RICommunicationAgent(BaseAgent):
bind = port_data["bind"]
if not bind:
addr = f"tcp://localhost:{port}"
addr = f"tcp://{settings.ri_host}:{port}"
else:
addr = f"tcp://*:{port}"
@@ -181,20 +188,27 @@ class RICommunicationAgent(BaseAgent):
self._req_socket.bind(addr)
case "actuation":
gesture_data = port_data.get("gestures", [])
single_gesture_data = port_data.get("single_gestures", [])
robot_speech_agent = RobotSpeechAgent(
settings.agent_settings.robot_speech_name,
address=addr,
bind=bind,
)
self.speech_agent = robot_speech_agent
robot_gesture_agent = RobotGestureAgent(
settings.agent_settings.robot_gesture_name,
address=addr,
bind=bind,
gesture_data=gesture_data,
single_gesture_data=single_gesture_data,
)
self.gesture_agent = robot_gesture_agent
await robot_speech_agent.start()
await asyncio.sleep(0.1) # Small delay
await robot_gesture_agent.start()
case "audio":
vad_agent = VADAgent(audio_in_address=addr, audio_in_bind=bind)
await vad_agent.start()
case _:
self.logger.warning("Unhandled negotiation id: %s", id)
@@ -219,6 +233,7 @@ class RICommunicationAgent(BaseAgent):
while self._running:
if not self.connected:
await asyncio.sleep(settings.behaviour_settings.sleep_s)
self.logger.debug("Not connected, skipping ping loop iteration.")
continue
# We need to listen and send pings.
@@ -242,7 +257,8 @@ class RICommunicationAgent(BaseAgent):
self._req_socket.recv_json(), timeout=seconds_to_wait_total / 2
)
self.logger.debug(f'Received message "{message}" from RI.')
if "endpoint" in message and message["endpoint"] != "ping":
self.logger.debug(f'Received message "{message}" from RI.')
if "endpoint" not in message:
self.logger.warning("No received endpoint in message, expected ping endpoint.")
continue
@@ -282,13 +298,33 @@ class RICommunicationAgent(BaseAgent):
# Tell UI we're disconnected.
topic = b"ping"
data = json.dumps(False).encode()
self.logger.debug("1")
if self.pub_socket:
try:
self.logger.debug("2")
await asyncio.wait_for(self.pub_socket.send_multipart([topic, data]), 5)
except TimeoutError:
self.logger.debug("3")
self.logger.warning("Connection ping for router timed out.")
# Try to reboot/renegotiate
if self.gesture_agent is not None:
await self.gesture_agent.stop()
if self.speech_agent is not None:
await self.speech_agent.stop()
if self.pub_socket is not None:
self.pub_socket.close()
self.logger.debug("Restarting communication negotiation.")
if await self._negotiate_connection(max_retries=1):
if await self._negotiate_connection(max_retries=2):
self.connected = True
async def handle_message(self, msg: InternalMessage):
try:
pause_command = PauseCommand.model_validate_json(msg.body)
await self._req_socket.send_json(pause_command.model_dump())
self.logger.debug(await self._req_socket.recv_json())
except ValidationError:
self.logger.warning("Incorrect message format for PauseCommand.")

View File

@@ -46,14 +46,23 @@ class LLMAgent(BaseAgent):
:param msg: The received internal message.
"""
if msg.sender == settings.agent_settings.bdi_core_name:
self.logger.debug("Processing message from BDI core.")
try:
prompt_message = LLMPromptMessage.model_validate_json(msg.body)
await self._process_bdi_message(prompt_message)
except ValidationError:
self.logger.debug("Prompt message from BDI core is invalid.")
match msg.thread:
case "prompt_message":
try:
prompt_message = LLMPromptMessage.model_validate_json(msg.body)
await self._process_bdi_message(prompt_message)
except ValidationError:
self.logger.debug("Prompt message from BDI core is invalid.")
case "assistant_message":
self.history.append({"role": "assistant", "content": msg.body})
case "user_message":
self.history.append({"role": "user", "content": msg.body})
elif msg.sender == settings.agent_settings.bdi_program_manager_name:
if msg.body == "clear_history":
self.logger.debug("Clearing conversation history.")
self.history.clear()
else:
self.logger.debug("Message ignored (not from BDI core.")
self.logger.debug("Message ignored.")
async def _process_bdi_message(self, message: LLMPromptMessage):
"""
@@ -64,11 +73,12 @@ class LLMAgent(BaseAgent):
:param message: The parsed prompt message containing text, norms, and goals.
"""
full_message = ""
async for chunk in self._query_llm(message.text, message.norms, message.goals):
await self._send_reply(chunk)
self.logger.debug(
"Finished processing BDI message. Response sent in chunks to BDI core."
)
full_message += chunk
self.logger.debug("Finished processing BDI message. Response sent in chunks to BDI core.")
await self._send_full_reply(full_message)
async def _send_reply(self, msg: str):
"""
@@ -83,6 +93,19 @@ class LLMAgent(BaseAgent):
)
await self.send(reply)
async def _send_full_reply(self, msg: str):
"""
Sends a response message (full) to agents that need it.
:param msg: The text content of the message.
"""
message = InternalMessage(
to=settings.agent_settings.text_belief_extractor_name,
sender=self.name,
body=msg,
)
await self.send(message)
async def _query_llm(
self, prompt: str, norms: list[str], goals: list[str]
) -> AsyncGenerator[str]:
@@ -100,13 +123,6 @@ class LLMAgent(BaseAgent):
:param goals: Goals the LLM should achieve.
:yield: Fragments of the LLM-generated content (e.g., sentences/phrases).
"""
self.history.append(
{
"role": "user",
"content": prompt,
}
)
instructions = LLMInstructions(norms if norms else None, goals if goals else None)
messages = [
{
@@ -125,7 +141,7 @@ class LLMAgent(BaseAgent):
full_message += token
current_chunk += token
self.logger.info(
self.logger.llm(
"Received token: %s",
full_message,
extra={"reference": message_id}, # Used in the UI to update old logs
@@ -172,7 +188,7 @@ class LLMAgent(BaseAgent):
json={
"model": settings.llm_settings.local_llm_model,
"messages": messages,
"temperature": 0.3,
"temperature": settings.llm_settings.chat_temperature,
"stream": True,
},
) as response:

View File

@@ -7,7 +7,9 @@ import zmq.asyncio as azmq
from control_backend.agents import BaseAgent
from control_backend.core.config import settings
from control_backend.schemas.internal_message import InternalMessage
from ...schemas.program_status import PROGRAM_STATUS, ProgramStatus
from .transcription_agent.transcription_agent import TranscriptionAgent
@@ -61,6 +63,7 @@ class VADAgent(BaseAgent):
:ivar audio_in_address: Address of the input audio stream.
:ivar audio_in_bind: Whether to bind or connect to the input address.
:ivar audio_out_socket: ZMQ PUB socket for sending speech fragments.
:ivar program_sub_socket: ZMQ SUB socket for receiving program status updates.
"""
def __init__(self, audio_in_address: str, audio_in_bind: bool):
@@ -79,9 +82,17 @@ class VADAgent(BaseAgent):
self.audio_out_socket: azmq.Socket | None = None
self.audio_in_poller: SocketPoller | None = None
self.program_sub_socket: azmq.Socket | None = None
self.audio_buffer = np.array([], dtype=np.float32)
self.i_since_speech = settings.behaviour_settings.vad_initial_since_speech
self._ready = asyncio.Event()
# Pause control
self._reset_needed = False
self._paused = asyncio.Event()
self._paused.set() # Not paused at start
self.model = None
async def setup(self):
@@ -90,20 +101,25 @@ class VADAgent(BaseAgent):
1. Connects audio input socket.
2. Binds audio output socket (random port).
3. Loads VAD model from Torch Hub.
4. Starts the streaming loop.
5. Instantiates and starts the :class:`TranscriptionAgent` with the output address.
3. Connects to program communication socket.
4. Loads VAD model from Torch Hub.
5. Starts the streaming loop.
6. Instantiates and starts the :class:`TranscriptionAgent` with the output address.
"""
self.logger.info("Setting up %s", self.name)
self._connect_audio_in_socket()
audio_out_port = self._connect_audio_out_socket()
if audio_out_port is None:
audio_out_address = self._connect_audio_out_socket()
if audio_out_address is None:
self.logger.error("Could not bind output socket, stopping.")
await self.stop()
return
audio_out_address = f"tcp://localhost:{audio_out_port}"
# Connect to internal communication socket
self.program_sub_socket = azmq.Context.instance().socket(zmq.SUB)
self.program_sub_socket.connect(settings.zmq_settings.internal_sub_address)
self.program_sub_socket.subscribe(PROGRAM_STATUS)
# Initialize VAD model
try:
@@ -117,10 +133,8 @@ class VADAgent(BaseAgent):
await self.stop()
return
# Warmup/reset
await self.reset_stream()
self.add_behavior(self._streaming_loop())
self.add_behavior(self._status_loop())
# Start agents dependent on the output audio fragments here
transcriber = TranscriptionAgent(audio_out_address)
@@ -153,19 +167,20 @@ class VADAgent(BaseAgent):
self.audio_in_socket.connect(self.audio_in_address)
self.audio_in_poller = SocketPoller[bytes](self.audio_in_socket)
def _connect_audio_out_socket(self) -> int | None:
def _connect_audio_out_socket(self) -> str | None:
"""
Returns the port bound, or None if binding failed.
Returns the address that was bound to, or None if binding failed.
"""
try:
self.audio_out_socket = azmq.Context.instance().socket(zmq.PUB)
return self.audio_out_socket.bind_to_random_port("tcp://localhost", max_tries=100)
self.audio_out_socket.bind(settings.zmq_settings.vad_pub_address)
return settings.zmq_settings.vad_pub_address
except zmq.ZMQBindError:
self.logger.error("Failed to bind an audio output socket after 100 tries.")
self.audio_out_socket = None
return None
async def reset_stream(self):
async def _reset_stream(self):
"""
Clears the ZeroMQ queue and sets ready state.
"""
@@ -176,6 +191,23 @@ class VADAgent(BaseAgent):
self.logger.info(f"Discarded {discarded} audio packets before starting.")
self._ready.set()
async def _status_loop(self):
"""Loop for checking program status. Only start listening if program is RUNNING."""
while self._running:
topic, body = await self.program_sub_socket.recv_multipart()
if topic != PROGRAM_STATUS:
continue
if body != ProgramStatus.RUNNING.value:
continue
# Program is now running, we can start our stream
await self._reset_stream()
# We don't care about further status updates
self.program_sub_socket.close()
break
async def _streaming_loop(self):
"""
Main loop for processing audio stream.
@@ -188,6 +220,16 @@ class VADAgent(BaseAgent):
"""
await self._ready.wait()
while self._running:
await self._paused.wait()
# After being unpaused, reset stream and buffers
if self._reset_needed:
self.logger.debug("Resuming: resetting stream and buffers.")
await self._reset_stream()
self.audio_buffer = np.array([], dtype=np.float32)
self.i_since_speech = settings.behaviour_settings.vad_initial_since_speech
self._reset_needed = False
assert self.audio_in_poller is not None
data = await self.audio_in_poller.poll()
if data is None:
@@ -204,10 +246,11 @@ class VADAgent(BaseAgent):
assert self.model is not None
prob = self.model(torch.from_numpy(chunk), settings.vad_settings.sample_rate_hz).item()
non_speech_patience = settings.behaviour_settings.vad_non_speech_patience_chunks
begin_silence_length = settings.behaviour_settings.vad_begin_silence_chunks
prob_threshold = settings.behaviour_settings.vad_prob_threshold
if prob > prob_threshold:
if self.i_since_speech > non_speech_patience:
if self.i_since_speech > non_speech_patience + begin_silence_length:
self.logger.debug("Speech started.")
self.audio_buffer = np.append(self.audio_buffer, chunk)
self.i_since_speech = 0
@@ -221,7 +264,7 @@ class VADAgent(BaseAgent):
continue
# Speech probably ended. Make sure we have a usable amount of data.
if len(self.audio_buffer) >= 3 * len(chunk):
if len(self.audio_buffer) > begin_silence_length * len(chunk):
self.logger.debug("Speech ended.")
assert self.audio_out_socket is not None
await self.audio_out_socket.send(self.audio_buffer[: -2 * len(chunk)].tobytes())
@@ -229,3 +272,27 @@ class VADAgent(BaseAgent):
# At this point, we know that the speech has ended.
# Prepend the last chunk that had no speech, for a more fluent boundary
self.audio_buffer = chunk
async def handle_message(self, msg: InternalMessage):
"""
Handle incoming messages.
Expects messages to pause or resume the VAD processing from User Interrupt Agent.
:param msg: The received internal message.
"""
sender = msg.sender
if sender == settings.agent_settings.user_interrupt_name:
if msg.body == "PAUSE":
self.logger.info("Pausing VAD processing.")
self._paused.clear()
# If the robot needs to pick up speaking where it left off, do not set _reset_needed
self._reset_needed = True
elif msg.body == "RESUME":
self.logger.info("Resuming VAD processing.")
self._paused.set()
else:
self.logger.warning(f"Unknown command from User Interrupt Agent: {msg.body}")
else:
self.logger.debug(f"Ignoring message from unknown sender: {sender}")

View File

@@ -0,0 +1,384 @@
import json
import zmq
from zmq.asyncio import Context
from control_backend.agents import BaseAgent
from control_backend.agents.bdi.agentspeak_generator import AgentSpeakGenerator
from control_backend.core.agent_system import InternalMessage
from control_backend.core.config import settings
from control_backend.schemas.belief_message import Belief, BeliefMessage
from control_backend.schemas.program import ConditionalNorm, Program
from control_backend.schemas.ri_message import (
GestureCommand,
PauseCommand,
RIEndpoint,
SpeechCommand,
)
class UserInterruptAgent(BaseAgent):
"""
User Interrupt Agent.
This agent receives button_pressed events from the external HTTP API
(via ZMQ) and uses the associated context to trigger one of the following actions:
- Send a prioritized message to the `RobotSpeechAgent`
- Send a prioritized gesture to the `RobotGestureAgent`
- Send a belief override to the `BDIProgramManager`in order to activate a
trigger/conditional norm or complete a goal.
Prioritized actions clear the current RI queue before inserting the new item,
ensuring they are executed immediately after Pepper's current action has been fulfilled.
:ivar sub_socket: The ZMQ SUB socket used to receive user interrupts.
"""
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.sub_socket = None
self.pub_socket = None
self._trigger_map = {}
self._trigger_reverse_map = {}
self._goal_map = {} # id -> sluggified goal
self._goal_reverse_map = {} # sluggified goal -> id
self._cond_norm_map = {} # id -> sluggified cond norm
self._cond_norm_reverse_map = {} # sluggified cond norm -> id
async def setup(self):
"""
Initialize the agent.
Connects the internal ZMQ SUB socket and subscribes to the 'button_pressed' topic.
Starts the background behavior to receive the user interrupts.
"""
context = Context.instance()
self.sub_socket = context.socket(zmq.SUB)
self.sub_socket.connect(settings.zmq_settings.internal_sub_address)
self.sub_socket.subscribe("button_pressed")
self.pub_socket = context.socket(zmq.PUB)
self.pub_socket.connect(settings.zmq_settings.internal_pub_address)
self.add_behavior(self._receive_button_event())
async def _receive_button_event(self):
"""
The behaviour of the UserInterruptAgent.
Continuous loop that receives button_pressed events from the button_pressed HTTP endpoint.
These events contain a type and a context.
These are the different types and contexts:
- type: "speech", context: string that the robot has to say.
- type: "gesture", context: single gesture name that the robot has to perform.
- type: "override", context: belief_id that overrides the goal/trigger/conditional norm.
- type: "pause", context: boolean indicating whether to pause
- type: "reset_phase", context: None, indicates to the BDI Core to
- type: "reset_experiment", context: None, indicates to the BDI Core to
"""
while True:
topic, body = await self.sub_socket.recv_multipart()
try:
event_data = json.loads(body)
event_type = event_data.get("type") # e.g., "speech", "gesture"
event_context = event_data.get("context") # e.g., "Hello, I am Pepper!"
except json.JSONDecodeError:
self.logger.error("Received invalid JSON payload on topic %s", topic)
continue
self.logger.debug("Received event type %s", event_type)
if event_type == "speech":
await self._send_to_speech_agent(event_context)
self.logger.info(
"Forwarded button press (speech) with context '%s' to RobotSpeechAgent.",
event_context,
)
elif event_type == "gesture":
await self._send_to_gesture_agent(event_context)
self.logger.info(
"Forwarded button press (gesture) with context '%s' to RobotGestureAgent.",
event_context,
)
elif event_type == "override":
ui_id = str(event_context)
if asl_trigger := self._trigger_map.get(ui_id):
await self._send_to_bdi("force_trigger", asl_trigger)
self.logger.info(
"Forwarded button press (override) with context '%s' to BDI Core.",
event_context,
)
elif asl_cond_norm := self._cond_norm_map.get(ui_id):
await self._send_to_bdi("force_norm", asl_cond_norm)
self.logger.info(
"Forwarded button press (override) with context '%s' to BDIProgramManager.",
event_context,
)
elif asl_goal := self._goal_map.get(ui_id):
await self._send_to_bdi_belief(asl_goal)
self.logger.info(
"Forwarded button press (override) with context '%s' to BDI Core.",
event_context,
)
goal_achieve_msg = InternalMessage(
to=settings.agent_settings.bdi_program_manager_name,
thread="achieve_goal",
body=ui_id,
)
await self.send(goal_achieve_msg)
else:
self.logger.warning("Could not determine which element to override.")
elif event_type == "pause":
self.logger.debug(
"Received pause/resume button press with context '%s'.", event_context
)
await self._send_pause_command(event_context)
if event_context:
self.logger.info("Sent pause command.")
else:
self.logger.info("Sent resume command.")
elif event_type in ["next_phase", "reset_phase", "reset_experiment"]:
await self._send_experiment_control_to_bdi_core(event_type)
else:
self.logger.warning(
"Received button press with unknown type '%s' (context: '%s').",
event_type,
event_context,
)
async def handle_message(self, msg: InternalMessage):
"""
Handle commands received from other internal Python agents.
"""
match msg.thread:
case "new_program":
self._create_mapping(msg.body)
case "trigger_start":
# msg.body is the sluggified trigger
asl_slug = msg.body
ui_id = self._trigger_reverse_map.get(asl_slug)
if ui_id:
payload = {"type": "trigger_update", "id": ui_id, "achieved": True}
await self._send_experiment_update(payload)
self.logger.info(f"UI Update: Trigger {asl_slug} started (ID: {ui_id})")
case "trigger_end":
asl_slug = msg.body
ui_id = self._trigger_reverse_map.get(asl_slug)
if ui_id:
payload = {"type": "trigger_update", "id": ui_id, "achieved": False}
await self._send_experiment_update(payload)
self.logger.info(f"UI Update: Trigger {asl_slug} ended (ID: {ui_id})")
case "transition_phase":
new_phase_id = msg.body
self.logger.info(f"Phase transition detected: {new_phase_id}")
payload = {"type": "phase_update", "id": new_phase_id}
await self._send_experiment_update(payload)
case "goal_start":
goal_name = msg.body
ui_id = self._goal_reverse_map.get(goal_name)
if ui_id:
payload = {"type": "goal_update", "id": ui_id, "active": True}
await self._send_experiment_update(payload)
self.logger.info(f"UI Update: Goal {goal_name} started (ID: {ui_id})")
case "active_norms_update":
norm_list = [s.strip("() '\",") for s in msg.body.split(",") if s.strip("() '\",")]
await self._broadcast_cond_norms(norm_list)
case _:
self.logger.debug(f"Received internal message on unhandled thread: {msg.thread}")
async def _broadcast_cond_norms(self, active_slugs: list[str]):
"""
Sends the current state of all conditional norms to the UI.
:param active_slugs: A list of slugs (strings) currently active in the BDI core.
"""
updates = []
for asl_slug, ui_id in self._cond_norm_reverse_map.items():
is_active = asl_slug in active_slugs
updates.append({"id": ui_id, "name": asl_slug, "active": is_active})
payload = {"type": "cond_norms_state_update", "norms": updates}
await self._send_experiment_update(payload, should_log=False)
# self.logger.debug(f"Broadcasted state for {len(updates)} conditional norms.")
def _create_mapping(self, program_json: str):
"""
Create mappings between UI IDs and ASL slugs for triggers, goals, and conditional norms
"""
try:
program = Program.model_validate_json(program_json)
self._trigger_map = {}
self._trigger_reverse_map = {}
self._goal_map = {}
self._cond_norm_map = {}
self._cond_norm_reverse_map = {}
for phase in program.phases:
for trigger in phase.triggers:
slug = AgentSpeakGenerator.slugify(trigger)
self._trigger_map[str(trigger.id)] = slug
self._trigger_reverse_map[slug] = str(trigger.id)
for goal in phase.goals:
self._goal_map[str(goal.id)] = AgentSpeakGenerator.slugify(goal)
self._goal_reverse_map[AgentSpeakGenerator.slugify(goal)] = str(goal.id)
for goal, id in self._goal_reverse_map.items():
self.logger.debug(f"Goal mapping: UI ID {goal} -> {id}")
for norm in phase.norms:
if isinstance(norm, ConditionalNorm):
asl_slug = AgentSpeakGenerator.slugify(norm)
norm_id = str(norm.id)
self._cond_norm_map[norm_id] = asl_slug
self._cond_norm_reverse_map[norm.norm] = norm_id
self.logger.debug("Added conditional norm %s", asl_slug)
self.logger.info(
f"Mapped {len(self._trigger_map)} triggers and {len(self._goal_map)} goals "
f"and {len(self._cond_norm_map)} conditional norms for UserInterruptAgent."
)
except Exception as e:
self.logger.error(f"Mapping failed: {e}")
async def _send_experiment_update(self, data, should_log: bool = True):
"""
Sends an update to the 'experiment' topic.
The SSE endpoint will pick this up and push it to the UI.
"""
if self.pub_socket:
topic = b"experiment"
body = json.dumps(data).encode("utf-8")
await self.pub_socket.send_multipart([topic, body])
if should_log:
self.logger.debug(f"Sent experiment update: {data}")
async def _send_to_speech_agent(self, text_to_say: str):
"""
method to send prioritized speech command to RobotSpeechAgent.
:param text_to_say: The string that the robot has to say.
"""
cmd = SpeechCommand(data=text_to_say, is_priority=True)
out_msg = InternalMessage(
to=settings.agent_settings.robot_speech_name,
sender=self.name,
body=cmd.model_dump_json(),
)
await self.send(out_msg)
async def _send_to_gesture_agent(self, single_gesture_name: str):
"""
method to send prioritized gesture command to RobotGestureAgent.
:param single_gesture_name: The gesture tag that the robot has to perform.
"""
# the endpoint is set to always be GESTURE_SINGLE for user interrupts
cmd = GestureCommand(
endpoint=RIEndpoint.GESTURE_SINGLE, data=single_gesture_name, is_priority=True
)
out_msg = InternalMessage(
to=settings.agent_settings.robot_gesture_name,
sender=self.name,
body=cmd.model_dump_json(),
)
await self.send(out_msg)
async def _send_to_bdi(self, thread: str, body: str):
"""Send slug of trigger to BDI"""
msg = InternalMessage(to=settings.agent_settings.bdi_core_name, thread=thread, body=body)
await self.send(msg)
self.logger.info(f"Directly forced {thread} in BDI: {body}")
async def _send_to_bdi_belief(self, asl_goal: str):
"""Send belief to BDI Core"""
belief_name = f"achieved_{asl_goal}"
belief = Belief(name=belief_name, arguments=None)
self.logger.debug(f"Sending belief to BDI Core: {belief_name}")
belief_message = BeliefMessage(create=[belief])
msg = InternalMessage(
to=settings.agent_settings.bdi_core_name,
thread="beliefs",
body=belief_message.model_dump_json(),
)
await self.send(msg)
async def _send_experiment_control_to_bdi_core(self, type):
"""
method to send experiment control buttons to bdi core.
:param type: the type of control button we should send to the bdi core.
"""
# Switch which thread we should send to bdi core
thread = ""
match type:
case "next_phase":
thread = "force_next_phase"
case "reset_phase":
thread = "reset_current_phase"
case "reset_experiment":
thread = "reset_experiment"
case _:
self.logger.warning(
"Received unknown experiment control type '%s' to send to BDI Core.",
type,
)
out_msg = InternalMessage(
to=settings.agent_settings.bdi_core_name,
sender=self.name,
thread=thread,
body="",
)
self.logger.debug("Sending experiment control '%s' to BDI Core.", thread)
await self.send(out_msg)
async def _send_pause_command(self, pause):
"""
Send a pause command to the Robot Interface via the RI Communication Agent.
Send a pause command to the other internal agents; for now just VAD agent.
"""
cmd = PauseCommand(data=pause)
message = InternalMessage(
to=settings.agent_settings.ri_communication_name,
sender=self.name,
body=cmd.model_dump_json(),
)
await self.send(message)
if pause == "true":
# Send pause to VAD agent
vad_message = InternalMessage(
to=settings.agent_settings.vad_name,
sender=self.name,
body="PAUSE",
)
await self.send(vad_message)
self.logger.info("Sent pause command to VAD Agent and RI Communication Agent.")
else:
# Send resume to VAD agent
vad_message = InternalMessage(
to=settings.agent_settings.vad_name,
sender=self.name,
body="RESUME",
)
await self.send(vad_message)
self.logger.info("Sent resume command to VAD Agent and RI Communication Agent.")

View File

@@ -3,9 +3,8 @@ import json
import logging
import zmq.asyncio
from fastapi import APIRouter, HTTPException, Request
from fastapi import APIRouter, Request
from fastapi.responses import StreamingResponse
from pydantic import ValidationError
from zmq.asyncio import Context, Socket
from control_backend.core.config import settings
@@ -16,38 +15,44 @@ logger = logging.getLogger(__name__)
router = APIRouter()
@router.post("/command", status_code=202)
async def receive_command(command: SpeechCommand, request: Request):
@router.post("/command/speech", status_code=202)
async def receive_command_speech(command: SpeechCommand, request: Request):
"""
Send a direct speech command to the robot.
Publishes the command to the internal 'command' topic. The
:class:`~control_backend.agents.actuation.robot_speech_agent.RobotSpeechAgent`
or
will forward this to the robot.
:param command: The speech command payload.
:param request: The FastAPI request object.
"""
topic = b"command"
pub_socket: Socket = request.app.state.endpoints_pub_socket
await pub_socket.send_multipart([topic, command.model_dump_json().encode()])
return {"status": "Speech command received"}
@router.post("/command/gesture", status_code=202)
async def receive_command_gesture(command: GestureCommand, request: Request):
"""
Send a direct gesture command to the robot.
Publishes the command to the internal 'command' topic. The
:class:`~control_backend.agents.actuation.robot_speech_agent.RobotGestureAgent`
will forward this to the robot.
:param command: The speech command payload.
:param request: The FastAPI request object.
"""
# Validate and retrieve data.
validated = None
valid_commands = (GestureCommand, SpeechCommand)
for command_model in valid_commands:
try:
validated = command_model.model_validate(command)
except ValidationError:
continue
if validated is None:
raise HTTPException(status_code=422, detail="Payload is not valid for command models")
topic = b"command"
pub_socket: Socket = request.app.state.endpoints_pub_socket
await pub_socket.send_multipart([topic, validated.model_dump_json().encode()])
await pub_socket.send_multipart([topic, command.model_dump_json().encode()])
return {"status": "Command received"}
return {"status": "Gesture command received"}
@router.get("/ping_check")
@@ -58,31 +63,27 @@ async def ping(request: Request):
pass
@router.get("/get_available_gesture_tags")
async def get_available_gesture_tags(request: Request):
@router.get("/commands/gesture/tags")
async def get_available_gesture_tags(request: Request, count=0):
"""
Endpoint to retrieve the available gesture tags for the robot.
:param request: The FastAPI request object.
:return: A list of available gesture tags.
"""
sub_socket = Context.instance().socket(zmq.SUB)
sub_socket.connect(settings.zmq_settings.internal_sub_address)
sub_socket.setsockopt(zmq.SUBSCRIBE, b"get_gestures")
req_socket = Context.instance().socket(zmq.REQ)
req_socket.connect(settings.zmq_settings.internal_gesture_rep_adress)
pub_socket: Socket = request.app.state.endpoints_pub_socket
topic = b"send_gestures"
# TODO: Implement a way to get a certain ammount from the UI, rather than everything.
amount = None
# Check to see if we've got any count given in the query parameter
amount = count or None
timeout = 5 # seconds
await pub_socket.send_multipart([topic, amount.to_bytes(4, "big") if amount else b""])
await req_socket.send(f"{amount}".encode() if amount else b"None")
try:
_, body = await asyncio.wait_for(sub_socket.recv_multipart(), timeout=timeout)
body = await asyncio.wait_for(req_socket.recv(), timeout=timeout)
except TimeoutError:
body = b"tags: []"
logger.debug("got timeout error fetching gestures")
body = '{"tags": []}'
logger.debug("Got timeout error fetching gestures.")
# Handle empty response and JSON decode errors
available_tags = []
@@ -136,7 +137,6 @@ async def ping_stream(request: Request):
logger.info("Client disconnected from SSE")
break
logger.debug(f"Yielded new connection event in robot ping router: {str(connected)}")
connectedJson = json.dumps(connected)
yield (f"data: {connectedJson}\n\n")

View File

@@ -0,0 +1,67 @@
import asyncio
import logging
import zmq
import zmq.asyncio
from fastapi import APIRouter, Request
from fastapi.responses import StreamingResponse
from zmq.asyncio import Context
from control_backend.core.config import settings
from control_backend.schemas.events import ButtonPressedEvent
logger = logging.getLogger(__name__)
router = APIRouter()
@router.post("/button_pressed", status_code=202)
async def receive_button_event(event: ButtonPressedEvent, request: Request):
"""
Endpoint to handle external button press events.
Validates the event payload and publishes it to the internal 'button_pressed' topic.
Subscribers (in this case user_interrupt_agent) will pick this up to trigger
specific behaviors or state changes.
:param event: The parsed ButtonPressedEvent object.
:param request: The FastAPI request object.
"""
logger.debug("Received button event: %s | %s", event.type, event.context)
topic = b"button_pressed"
body = event.model_dump_json().encode()
pub_socket = request.app.state.endpoints_pub_socket
await pub_socket.send_multipart([topic, body])
return {"status": "Event received"}
@router.get("/experiment_stream")
async def experiment_stream(request: Request):
# Use the asyncio-compatible context
context = Context.instance()
socket = context.socket(zmq.SUB)
# Connect and subscribe
socket.connect(settings.zmq_settings.internal_sub_address)
socket.subscribe(b"experiment")
async def gen():
try:
while True:
# Check if client closed the tab
if await request.is_disconnected():
logger.info("Client disconnected from experiment stream.")
break
try:
parts = await asyncio.wait_for(socket.recv_multipart(), timeout=1.0)
_, message = parts
yield f"data: {message.decode().strip()}\n\n"
except TimeoutError:
continue
finally:
socket.close()
return StreamingResponse(gen(), media_type="text/event-stream")

View File

@@ -1,6 +1,6 @@
from fastapi.routing import APIRouter
from control_backend.api.v1.endpoints import logs, message, program, robot, sse
from control_backend.api.v1.endpoints import logs, message, program, robot, sse, user_interact
api_router = APIRouter()
@@ -13,3 +13,5 @@ api_router.include_router(robot.router, prefix="/robot", tags=["Pings", "Command
api_router.include_router(logs.router, tags=["Logs"])
api_router.include_router(program.router, tags=["Program"])
api_router.include_router(user_interact.router, tags=["Button Pressed Events"])

View File

@@ -60,6 +60,9 @@ class BaseAgent(ABC):
self._tasks: set[asyncio.Task] = set()
self._running = False
self._internal_pub_socket: None | azmq.Socket = None
self._internal_sub_socket: None | azmq.Socket = None
# Register immediately
AgentDirectory.register(name, self)
@@ -117,7 +120,7 @@ class BaseAgent(ABC):
task.cancel()
self.logger.info(f"Agent {self.name} stopped")
async def send(self, message: InternalMessage):
async def send(self, message: InternalMessage, should_log: bool = True):
"""
Send a message to another agent.
@@ -130,16 +133,26 @@ class BaseAgent(ABC):
:param message: The message to send.
"""
target = AgentDirectory.get(message.to)
if target:
await target.inbox.put(message)
self.logger.debug(f"Sent message {message.body} to {message.to} via regular inbox.")
else:
# Apparently target agent is on a different process, send via ZMQ
topic = f"internal/{message.to}".encode()
body = message.model_dump_json().encode()
await self._internal_pub_socket.send_multipart([topic, body])
self.logger.debug(f"Sent message {message.body} to {message.to} via ZMQ.")
message.sender = self.name
to = message.to
receivers = [to] if isinstance(to, str) else to
for receiver in receivers:
target = AgentDirectory.get(receiver)
if target:
await target.inbox.put(message)
if should_log:
self.logger.debug(
f"Sent message {message.body} to {message.to} via regular inbox."
)
else:
# Apparently target agent is on a different process, send via ZMQ
topic = f"internal/{receiver}".encode()
body = message.model_dump_json().encode()
await self._internal_pub_socket.send_multipart([topic, body])
if should_log:
self.logger.debug(f"Sent message {message.body} to {message.to} via ZMQ.")
async def _process_inbox(self):
"""
@@ -149,7 +162,6 @@ class BaseAgent(ABC):
"""
while self._running:
msg = await self.inbox.get()
self.logger.debug(f"Received message from {msg.sender}.")
await self.handle_message(msg)
async def _receive_internal_zmq_loop(self):
@@ -192,7 +204,16 @@ class BaseAgent(ABC):
:param coro: The coroutine to execute as a task.
"""
task = asyncio.create_task(coro)
async def try_coro(coro_: Coroutine):
try:
await coro_
except asyncio.CancelledError:
self.logger.debug("A behavior was canceled successfully: %s", coro_)
except Exception:
self.logger.warning("An exception occurred in a behavior.", exc_info=True)
task = asyncio.create_task(try_coro(coro))
self._tasks.add(task)
task.add_done_callback(self._tasks.discard)
return task

View File

@@ -1,3 +1,12 @@
"""
An exhaustive overview of configurable options. All of these can be set using environment variables
by nesting with double underscores (__). Start from the ``Settings`` class.
For example, ``settings.ri_host`` becomes ``RI_HOST``, and
``settings.zmq_settings.ri_communication_address`` becomes
``ZMQ_SETTINGS__RI_COMMUNICATION_ADDRESS``.
"""
from pydantic import BaseModel
from pydantic_settings import BaseSettings, SettingsConfigDict
@@ -8,16 +17,17 @@ class ZMQSettings(BaseModel):
:ivar internal_pub_address: Address for the internal PUB socket.
:ivar internal_sub_address: Address for the internal SUB socket.
:ivar ri_command_address: Address for sending commands to the Robot Interface.
:ivar ri_communication_address: Address for receiving communication from the Robot Interface.
:ivar vad_agent_address: Address for the Voice Activity Detection (VAD) agent.
:ivar ri_communication_address: Address for the endpoint that the Robot Interface connects to.
:ivar vad_pub_address: Address that the VAD agent binds to and publishes audio segments to.
"""
# ATTENTION: When adding/removing settings, make sure to update the .env.example file
internal_pub_address: str = "tcp://localhost:5560"
internal_sub_address: str = "tcp://localhost:5561"
ri_command_address: str = "tcp://localhost:0000"
ri_communication_address: str = "tcp://*:5555"
vad_agent_address: str = "tcp://localhost:5558"
internal_gesture_rep_adress: str = "tcp://localhost:7788"
vad_pub_address: str = "inproc://vad_stream"
class AgentSettings(BaseModel):
@@ -25,7 +35,6 @@ class AgentSettings(BaseModel):
Names of the various agents in the system. These names are used for routing messages.
:ivar bdi_core_name: Name of the BDI Core Agent.
:ivar bdi_belief_collector_name: Name of the Belief Collector Agent.
:ivar bdi_program_manager_name: Name of the BDI Program Manager Agent.
:ivar text_belief_extractor_name: Name of the Text Belief Extractor Agent.
:ivar vad_name: Name of the Voice Activity Detection (VAD) Agent.
@@ -36,9 +45,10 @@ class AgentSettings(BaseModel):
:ivar robot_speech_name: Name of the Robot Speech Agent.
"""
# ATTENTION: When adding/removing settings, make sure to update the .env.example file
# agent names
bdi_core_name: str = "bdi_core_agent"
bdi_belief_collector_name: str = "belief_collector_agent"
bdi_program_manager_name: str = "bdi_program_manager_agent"
text_belief_extractor_name: str = "text_belief_extractor_agent"
vad_name: str = "vad_agent"
@@ -48,6 +58,7 @@ class AgentSettings(BaseModel):
ri_communication_name: str = "ri_communication_agent"
robot_speech_name: str = "robot_speech_agent"
robot_gesture_name: str = "robot_gesture_agent"
user_interrupt_name: str = "user_interrupt_agent"
class BehaviourSettings(BaseModel):
@@ -60,12 +71,16 @@ class BehaviourSettings(BaseModel):
:ivar vad_prob_threshold: Probability threshold for Voice Activity Detection.
:ivar vad_initial_since_speech: Initial value for 'since speech' counter in VAD.
:ivar vad_non_speech_patience_chunks: Number of non-speech chunks to wait before speech ended.
:ivar vad_begin_silence_chunks: The number of chunks of silence to prepend to speech chunks.
:ivar transcription_max_concurrent_tasks: Maximum number of concurrent transcription tasks.
:ivar transcription_words_per_minute: Estimated words per minute for transcription timing.
:ivar transcription_words_per_token: Estimated words per token for transcription timing.
:ivar transcription_token_buffer: Buffer for transcription tokens.
:ivar conversation_history_length_limit: The maximum amount of messages to extract beliefs from.
"""
# ATTENTION: When adding/removing settings, make sure to update the .env.example file
sleep_s: float = 1.0
comm_setup_max_retries: int = 5
socket_poller_timeout_ms: int = 100
@@ -73,7 +88,8 @@ class BehaviourSettings(BaseModel):
# VAD settings
vad_prob_threshold: float = 0.5
vad_initial_since_speech: int = 100
vad_non_speech_patience_chunks: int = 3
vad_non_speech_patience_chunks: int = 15
vad_begin_silence_chunks: int = 6
# transcription behaviour
transcription_max_concurrent_tasks: int = 3
@@ -81,6 +97,9 @@ class BehaviourSettings(BaseModel):
transcription_words_per_token: float = 0.75 # (3 words = 4 tokens)
transcription_token_buffer: int = 10
# Text belief extractor settings
conversation_history_length_limit: int = 10
class LLMSettings(BaseModel):
"""
@@ -88,10 +107,19 @@ class LLMSettings(BaseModel):
:ivar local_llm_url: URL for the local LLM API.
:ivar local_llm_model: Name of the local LLM model to use.
:ivar chat_temperature: The temperature to use while generating chat responses.
:ivar code_temperature: The temperature to use while generating code-like responses like during
belief inference.
:ivar n_parallel: The number of parallel calls allowed to be made to the LLM.
"""
# ATTENTION: When adding/removing settings, make sure to update the .env.example file
local_llm_url: str = "http://localhost:1234/v1/chat/completions"
local_llm_model: str = "gpt-oss"
chat_temperature: float = 1.0
code_temperature: float = 0.3
n_parallel: int = 4
class VADSettings(BaseModel):
@@ -103,6 +131,8 @@ class VADSettings(BaseModel):
:ivar sample_rate_hz: Sample rate in Hz for the VAD model.
"""
# ATTENTION: When adding/removing settings, make sure to update the .env.example file
repo_or_dir: str = "snakers4/silero-vad"
model_name: str = "silero_vad"
sample_rate_hz: int = 16000
@@ -116,6 +146,8 @@ class SpeechModelSettings(BaseModel):
:ivar openai_model_name: Model name for OpenAI-based speech recognition.
"""
# ATTENTION: When adding/removing settings, make sure to update the .env.example file
# model identifiers for speech recognition
mlx_model_name: str = "mlx-community/whisper-small.en-mlx"
openai_model_name: str = "small.en"
@@ -127,6 +159,7 @@ class Settings(BaseSettings):
:ivar app_title: Title of the application.
:ivar ui_url: URL of the frontend UI.
:ivar ri_host: The hostname of the Robot Interface.
:ivar zmq_settings: ZMQ configuration.
:ivar agent_settings: Agent name configuration.
:ivar behaviour_settings: Behavior configuration.
@@ -139,6 +172,8 @@ class Settings(BaseSettings):
ui_url: str = "http://localhost:5173"
ri_host: str = "localhost"
zmq_settings: ZMQSettings = ZMQSettings()
agent_settings: AgentSettings = AgentSettings()

View File

@@ -4,6 +4,7 @@ import os
import yaml
import zmq
from zmq.log.handlers import PUBHandler
from control_backend.core.config import settings
@@ -51,15 +52,27 @@ def setup_logging(path: str = ".logging_config.yaml") -> None:
logging.warning(f"Could not load logging configuration: {e}")
config = {}
if "custom_levels" in config:
for level_name, level_num in config["custom_levels"].items():
add_logging_level(level_name, level_num)
custom_levels = config.get("custom_levels", {}) or {}
for level_name, level_num in custom_levels.items():
add_logging_level(level_name, level_num)
if config.get("handlers") is not None and config.get("handlers").get("ui"):
pub_socket = zmq.Context.instance().socket(zmq.PUB)
pub_socket.connect(settings.zmq_settings.internal_pub_address)
config["handlers"]["ui"]["interface_or_socket"] = pub_socket
logging.config.dictConfig(config)
# Patch ZMQ PUBHandler to know about custom levels
if custom_levels:
for logger_name in ("control_backend",):
logger = logging.getLogger(logger_name)
for handler in logger.handlers:
if isinstance(handler, PUBHandler):
# Use the INFO formatter as the default template
default_fmt = handler.formatters[logging.INFO]
for level_num in custom_levels.values():
handler.setFormatter(default_fmt, level=level_num)
else:
logging.warning("Logging config file not found. Using default logging configuration.")

View File

@@ -26,7 +26,6 @@ from zmq.asyncio import Context
# BDI agents
from control_backend.agents.bdi import (
BDIBeliefCollectorAgent,
BDICoreAgent,
TextBeliefExtractorAgent,
)
@@ -39,13 +38,14 @@ from control_backend.agents.communication import RICommunicationAgent
# LLM Agents
from control_backend.agents.llm import LLMAgent
# Perceive agents
from control_backend.agents.perception import VADAgent
# User Interrupt Agent
from control_backend.agents.user_interrupt.user_interrupt_agent import UserInterruptAgent
# Other backend imports
from control_backend.api.v1.router import api_router
from control_backend.core.config import settings
from control_backend.logging import setup_logging
from control_backend.schemas.program_status import PROGRAM_STATUS, ProgramStatus
logger = logging.getLogger(__name__)
@@ -95,6 +95,8 @@ async def lifespan(app: FastAPI):
endpoints_pub_socket.connect(settings.zmq_settings.internal_pub_address)
app.state.endpoints_pub_socket = endpoints_pub_socket
await endpoints_pub_socket.send_multipart([PROGRAM_STATUS, ProgramStatus.STARTING.value])
# --- Initialize Agents ---
logger.info("Initializing and starting agents.")
@@ -117,13 +119,6 @@ async def lifespan(app: FastAPI):
BDICoreAgent,
{
"name": settings.agent_settings.bdi_core_name,
"asl": "src/control_backend/agents/bdi/rules.asl",
},
),
"BeliefCollectorAgent": (
BDIBeliefCollectorAgent,
{
"name": settings.agent_settings.bdi_belief_collector_name,
},
),
"TextBeliefExtractorAgent": (
@@ -132,46 +127,46 @@ async def lifespan(app: FastAPI):
"name": settings.agent_settings.text_belief_extractor_name,
},
),
"VADAgent": (
VADAgent,
{"audio_in_address": settings.zmq_settings.vad_agent_address, "audio_in_bind": False},
),
"ProgramManagerAgent": (
BDIProgramManager,
{
"name": settings.agent_settings.bdi_program_manager_name,
},
),
"UserInterruptAgent": (
UserInterruptAgent,
{
"name": settings.agent_settings.user_interrupt_name,
},
),
}
agents = []
vad_agent = None
for name, (agent_class, kwargs) in agents_to_start.items():
try:
logger.debug("Starting agent: %s", name)
agent_instance = agent_class(**kwargs)
await agent_instance.start()
if isinstance(agent_instance, VADAgent):
vad_agent = agent_instance
agents.append(agent_instance)
logger.info("Agent '%s' started successfully.", name)
except Exception as e:
logger.error("Failed to start agent '%s': %s", name, e, exc_info=True)
raise
assert vad_agent is not None
await vad_agent.reset_stream()
logger.info("Application startup complete.")
await endpoints_pub_socket.send_multipart([PROGRAM_STATUS, ProgramStatus.RUNNING.value])
yield
# --- APPLICATION SHUTDOWN ---
logger.info("%s is shutting down.", app.title)
# Potential shutdown logic goes here
await endpoints_pub_socket.send_multipart([PROGRAM_STATUS, ProgramStatus.STOPPING.value])
# Additional shutdown logic goes here
for agent in agents:
await agent.stop()
logger.info("Application shutdown complete.")

View File

@@ -0,0 +1,19 @@
from pydantic import BaseModel
from control_backend.schemas.program import BaseGoal
from control_backend.schemas.program import Belief as ProgramBelief
class BeliefList(BaseModel):
"""
Represents a list of beliefs, separated from a program. Useful in agents which need to
communicate beliefs.
:ivar: beliefs: The list of beliefs.
"""
beliefs: list[ProgramBelief]
class GoalList(BaseModel):
goals: list[BaseGoal]

View File

@@ -6,18 +6,30 @@ class Belief(BaseModel):
Represents a single belief in the BDI system.
:ivar name: The functor or name of the belief (e.g., 'user_said').
:ivar arguments: A list of string arguments for the belief.
:ivar replace: If True, existing beliefs with this name should be replaced by this one.
:ivar arguments: A list of string arguments for the belief, or None if the belief has no
arguments.
"""
name: str
arguments: list[str]
replace: bool = False
arguments: list[str] | None = None
# To make it hashable
model_config = {"frozen": True}
class BeliefMessage(BaseModel):
"""
A container for transporting a list of beliefs between agents.
A container for communicating beliefs between agents.
:ivar create: Beliefs to create.
:ivar delete: Beliefs to delete.
:ivar replace: Beliefs to replace. Deletes all beliefs with the same name, replacing them with
one new belief.
"""
beliefs: list[Belief]
create: list[Belief] = []
delete: list[Belief] = []
replace: list[Belief] = []
def has_values(self) -> bool:
return len(self.create) > 0 or len(self.delete) > 0 or len(self.replace) > 0

View File

@@ -0,0 +1,10 @@
from pydantic import BaseModel
class ChatMessage(BaseModel):
role: str
content: str
class ChatHistory(BaseModel):
messages: list[ChatMessage]

View File

@@ -0,0 +1,6 @@
from pydantic import BaseModel
class ButtonPressedEvent(BaseModel):
type: str
context: str

View File

@@ -1,3 +1,5 @@
from collections.abc import Iterable
from pydantic import BaseModel
@@ -11,7 +13,7 @@ class InternalMessage(BaseModel):
:ivar thread: An optional thread identifier/topic to categorize the message (e.g., 'beliefs').
"""
to: str
sender: str
to: str | Iterable[str]
sender: str | None = None
body: str
thread: str | None = None

View File

@@ -1,64 +1,215 @@
from pydantic import BaseModel
from enum import Enum
from typing import Literal
from pydantic import UUID4, BaseModel
class Norm(BaseModel):
class ProgramElement(BaseModel):
"""
Represents a behavioral norm.
Represents a basic element of our behavior program.
:ivar name: The researcher-assigned name of the element.
:ivar id: Unique identifier.
:ivar label: Human-readable label.
:ivar norm: The actual norm text describing the behavior.
"""
id: str
label: str
norm: str
name: str
id: UUID4
# To make program elements hashable
model_config = {"frozen": True}
class Goal(BaseModel):
class LogicalOperator(Enum):
AND = "AND"
OR = "OR"
type Belief = KeywordBelief | SemanticBelief | InferredBelief
type BasicBelief = KeywordBelief | SemanticBelief
class KeywordBelief(ProgramElement):
"""
Represents an objective to be achieved.
Represents a belief that is set when the user spoken text contains a certain keyword.
:ivar id: Unique identifier.
:ivar label: Human-readable label.
:ivar description: Detailed description of the goal.
:ivar achieved: Status flag indicating if the goal has been met.
:ivar keyword: The keyword on which this belief gets set.
"""
id: str
label: str
description: str
achieved: bool
class TriggerKeyword(BaseModel):
id: str
name: str = ""
keyword: str
class KeywordTrigger(BaseModel):
id: str
label: str
type: str
keywords: list[TriggerKeyword]
class SemanticBelief(ProgramElement):
"""
Represents a belief that is set by semantic LLM validation.
:ivar description: Description of how to form the belief, used by the LLM.
"""
description: str
class Phase(BaseModel):
class InferredBelief(ProgramElement):
"""
Represents a belief that gets formed by combining two beliefs with a logical AND or OR.
These beliefs can also be :class:`InferredBelief`, leading to arbitrarily deep nesting.
:ivar operator: The logical operator to apply.
:ivar left: The left part of the logical expression.
:ivar right: The right part of the logical expression.
"""
name: str = ""
operator: LogicalOperator
left: Belief
right: Belief
class Norm(ProgramElement):
name: str = ""
norm: str
critical: bool = False
class BasicNorm(Norm):
"""
Represents a behavioral norm.
:ivar norm: The actual norm text describing the behavior.
:ivar critical: When true, this norm should absolutely not be violated (checked separately).
"""
pass
class ConditionalNorm(Norm):
"""
Represents a norm that is only active when a condition is met (i.e., a certain belief holds).
:ivar condition: When to activate this norm.
"""
condition: Belief
type PlanElement = Goal | Action
class Plan(ProgramElement):
"""
Represents a list of steps to execute. Each of these steps can be a goal (with its own plan)
or a simple action.
:ivar steps: The actions or subgoals to execute, in order.
"""
name: str = ""
steps: list[PlanElement]
class BaseGoal(ProgramElement):
"""
Represents an objective to be achieved. This base version does not include a plan to achieve
this goal, and is used in semantic belief extraction.
:ivar description: A description of the goal, used to determine if it has been achieved.
:ivar can_fail: Whether we can fail to achieve the goal after executing the plan.
"""
description: str = ""
can_fail: bool = True
class Goal(BaseGoal):
"""
Represents an objective to be achieved. To reach the goal, we should execute the corresponding
plan. It inherits from the BaseGoal a variable `can_fail`, which if true will cause the
completion to be determined based on the conversation.
Instances of this goal are not hashable because a plan is not hashable.
:ivar plan: The plan to execute.
"""
plan: Plan
type Action = SpeechAction | GestureAction | LLMAction
class SpeechAction(ProgramElement):
"""
Represents the action of the robot speaking a literal text.
:ivar text: The text to speak.
"""
name: str = ""
text: str
class Gesture(BaseModel):
"""
Represents a gesture to be performed. Can be either a single gesture,
or a random gesture from a category (tag).
:ivar type: The type of the gesture, "tag" or "single".
:ivar name: The name of the single gesture or tag.
"""
type: Literal["tag", "single"]
name: str
class GestureAction(ProgramElement):
"""
Represents the action of the robot performing a gesture.
:ivar gesture: The gesture to perform.
"""
name: str = ""
gesture: Gesture
class LLMAction(ProgramElement):
"""
Represents the action of letting an LLM generate a reply based on its chat history
and an additional goal added in the prompt.
:ivar goal: The extra (temporary) goal to add to the LLM.
"""
name: str = ""
goal: str
class Trigger(ProgramElement):
"""
Represents a belief-based trigger. When a belief is set, the corresponding plan is executed.
:ivar condition: When to activate the trigger.
:ivar plan: The plan to execute.
"""
condition: Belief
plan: Plan
class Phase(ProgramElement):
"""
A distinct phase within a program, containing norms, goals, and triggers.
:ivar id: Unique identifier.
:ivar label: Human-readable label.
:ivar norms: List of norms active in this phase.
:ivar goals: List of goals to pursue in this phase.
:ivar triggers: List of triggers that define transitions out of this phase.
"""
id: str
label: str
norms: list[Norm]
name: str = ""
norms: list[BasicNorm | ConditionalNorm]
goals: list[Goal]
triggers: list[KeywordTrigger]
triggers: list[Trigger]
class Program(BaseModel):

View File

@@ -0,0 +1,16 @@
from enum import Enum
PROGRAM_STATUS = b"internal/program_status"
"""A topic key for the program status."""
class ProgramStatus(Enum):
"""
Used in internal communication, to tell agents what the status of the program is.
For example, the VAD agent only starts listening when the program is RUNNING.
"""
STARTING = b"starting"
RUNNING = b"running"
STOPPING = b"stopping"

View File

@@ -14,6 +14,7 @@ class RIEndpoint(str, Enum):
GESTURE_TAG = "actuate/gesture/tag"
PING = "ping"
NEGOTIATE_PORTS = "negotiate/ports"
PAUSE = ""
class RIMessage(BaseModel):
@@ -38,6 +39,7 @@ class SpeechCommand(RIMessage):
endpoint: RIEndpoint = RIEndpoint(RIEndpoint.SPEECH)
data: str
is_priority: bool = False
class GestureCommand(RIMessage):
@@ -52,6 +54,7 @@ class GestureCommand(RIMessage):
RIEndpoint.GESTURE_SINGLE, RIEndpoint.GESTURE_TAG
]
data: str
is_priority: bool = False
@model_validator(mode="after")
def check_endpoint(self):
@@ -62,3 +65,15 @@ class GestureCommand(RIMessage):
if self.endpoint not in allowed:
raise ValueError("endpoint must be GESTURE_SINGLE or GESTURE_TAG")
return self
class PauseCommand(RIMessage):
"""
A specific command to pause or unpause the robot's actions.
:ivar endpoint: Fixed to ``RIEndpoint.PAUSE``.
:ivar data: A boolean indicating whether to pause (True) or unpause (False).
"""
endpoint: RIEndpoint = RIEndpoint(RIEndpoint.PAUSE)
data: bool

View File

@@ -5,6 +5,7 @@ import pytest
import zmq
from control_backend.agents.perception.vad_agent import VADAgent
from control_backend.schemas.program_status import PROGRAM_STATUS, ProgramStatus
@pytest.fixture
@@ -43,14 +44,12 @@ async def test_normal_setup(per_transcription_agent):
coro.close()
per_vad_agent.add_behavior = swallow_background_task
per_vad_agent.reset_stream = AsyncMock()
await per_vad_agent.setup()
per_transcription_agent.assert_called_once()
per_transcription_agent.return_value.start.assert_called_once()
per_vad_agent._streaming_loop.assert_called_once()
per_vad_agent.reset_stream.assert_called_once()
assert per_vad_agent.audio_in_socket is not None
assert per_vad_agent.audio_out_socket is not None
@@ -92,7 +91,7 @@ def test_out_socket_creation(zmq_context):
assert per_vad_agent.audio_out_socket is not None
zmq_context.return_value.socket.assert_called_once_with(zmq.PUB)
zmq_context.return_value.socket.return_value.bind_to_random_port.assert_called_once()
zmq_context.return_value.socket.return_value.bind.assert_called_once_with("inproc://vad_stream")
@pytest.mark.asyncio
@@ -103,7 +102,7 @@ async def test_out_socket_creation_failure(zmq_context):
zmq_context.return_value.socket.return_value.bind_to_random_port.side_effect = zmq.ZMQBindError
per_vad_agent = VADAgent("tcp://localhost:12345", False)
per_vad_agent.stop = AsyncMock()
per_vad_agent.reset_stream = AsyncMock()
per_vad_agent._reset_stream = AsyncMock()
per_vad_agent._streaming_loop = AsyncMock()
per_vad_agent._connect_audio_out_socket = MagicMock(return_value=None)
@@ -124,7 +123,7 @@ async def test_stop(zmq_context, per_transcription_agent):
Test that when the VAD agent is stopped, the sockets are closed correctly.
"""
per_vad_agent = VADAgent("tcp://localhost:12345", False)
per_vad_agent.reset_stream = AsyncMock()
per_vad_agent._reset_stream = AsyncMock()
per_vad_agent._streaming_loop = AsyncMock()
async def swallow_background_task(coro):
@@ -142,3 +141,66 @@ async def test_stop(zmq_context, per_transcription_agent):
assert zmq_context.return_value.socket.return_value.close.call_count == 2
assert per_vad_agent.audio_in_socket is None
assert per_vad_agent.audio_out_socket is None
@pytest.mark.asyncio
async def test_application_startup_complete(zmq_context):
"""Check that it resets the stream when the program finishes startup."""
vad_agent = VADAgent("tcp://localhost:12345", False)
vad_agent._running = True
vad_agent._reset_stream = AsyncMock()
vad_agent.program_sub_socket = AsyncMock()
vad_agent.program_sub_socket.recv_multipart.side_effect = [
(PROGRAM_STATUS, ProgramStatus.RUNNING.value),
]
await vad_agent._status_loop()
vad_agent._reset_stream.assert_called_once()
vad_agent.program_sub_socket.close.assert_called_once()
@pytest.mark.asyncio
async def test_application_other_status(zmq_context):
"""
Check that it does nothing when the internal communication message is a status update, but not
running.
"""
vad_agent = VADAgent("tcp://localhost:12345", False)
vad_agent._running = True
vad_agent._reset_stream = AsyncMock()
vad_agent.program_sub_socket = AsyncMock()
vad_agent.program_sub_socket.recv_multipart.side_effect = [
(PROGRAM_STATUS, ProgramStatus.STARTING.value),
(PROGRAM_STATUS, ProgramStatus.STOPPING.value),
]
try:
# Raises StopAsyncIteration the third time it calls `program_sub_socket.recv_multipart`
await vad_agent._status_loop()
except StopAsyncIteration:
pass
vad_agent._reset_stream.assert_not_called()
@pytest.mark.asyncio
async def test_application_message_other(zmq_context):
"""
Check that it does nothing when there's an internal communication message that is not a status
update.
"""
vad_agent = VADAgent("tcp://localhost:12345", False)
vad_agent._running = True
vad_agent._reset_stream = AsyncMock()
vad_agent.program_sub_socket = AsyncMock()
vad_agent.program_sub_socket.recv_multipart.side_effect = [(b"internal/other", b"Whatever")]
try:
# Raises StopAsyncIteration the second time it calls `program_sub_socket.recv_multipart`
await vad_agent._status_loop()
except StopAsyncIteration:
pass
vad_agent._reset_stream.assert_not_called()

View File

@@ -28,20 +28,26 @@ async def test_setup_bind(zmq_context, mocker):
settings = mocker.patch("control_backend.agents.actuation.robot_gesture_agent.settings")
settings.zmq_settings.internal_sub_address = "tcp://internal:1234"
agent.add_behavior = MagicMock()
def close_coro(coro):
coro.close()
return MagicMock()
agent.add_behavior = MagicMock(side_effect=close_coro)
await agent.setup()
# Check PUB socket binding
fake_socket.bind.assert_any_call("tcp://localhost:5556")
# Check REP socket binding
fake_socket.bind.assert_called()
# Check SUB socket connection and subscriptions
fake_socket.connect.assert_any_call("tcp://internal:1234")
fake_socket.setsockopt.assert_any_call(zmq.SUBSCRIBE, b"command")
fake_socket.setsockopt.assert_any_call(zmq.SUBSCRIBE, b"send_gestures")
# Check behavior was added
agent.add_behavior.assert_called() # Twice, even.
# Check behavior was added (twice: once for command loop, once for fetch gestures loop)
assert agent.add_behavior.call_count == 2
@pytest.mark.asyncio
@@ -53,28 +59,34 @@ async def test_setup_connect(zmq_context, mocker):
settings = mocker.patch("control_backend.agents.actuation.robot_gesture_agent.settings")
settings.zmq_settings.internal_sub_address = "tcp://internal:1234"
agent.add_behavior = MagicMock()
def close_coro(coro):
coro.close()
return MagicMock()
agent.add_behavior = MagicMock(side_effect=close_coro)
await agent.setup()
# Check PUB socket connection (not binding)
fake_socket.connect.assert_any_call("tcp://localhost:5556")
fake_socket.connect.assert_any_call("tcp://internal:1234")
# Check REP socket binding (always binds)
fake_socket.bind.assert_called()
# Check behavior was added
agent.add_behavior.assert_called() # Twice, actually.
# Check behavior was added (twice)
assert agent.add_behavior.call_count == 2
@pytest.mark.asyncio
async def test_handle_message_sends_valid_gesture_command():
"""Internal message with valid gesture tag is forwarded to robot pub socket."""
pubsocket = AsyncMock()
agent = RobotGestureAgent("robot_gesture")
agent = RobotGestureAgent("robot_gesture", gesture_data=["hello", "yes", "no"], address="")
agent.pubsocket = pubsocket
payload = {
"endpoint": RIEndpoint.GESTURE_TAG,
"data": "hello", # "hello" is in availableTags
"data": "hello", # "hello" is in gesture_data
}
msg = InternalMessage(to="robot", sender="tester", body=json.dumps(payload))
@@ -85,9 +97,9 @@ async def test_handle_message_sends_valid_gesture_command():
@pytest.mark.asyncio
async def test_handle_message_sends_non_gesture_command():
"""Internal message with non-gesture endpoint is not handled by this agent."""
"""Internal message with non-gesture endpoint is not forwarded by this agent."""
pubsocket = AsyncMock()
agent = RobotGestureAgent("robot_gesture")
agent = RobotGestureAgent("robot_gesture", gesture_data=["hello", "yes", "no"], address="")
agent.pubsocket = pubsocket
payload = {"endpoint": "some_other_endpoint", "data": "invalid_tag_not_in_list"}
@@ -95,6 +107,7 @@ async def test_handle_message_sends_non_gesture_command():
await agent.handle_message(msg)
# Non-gesture endpoints should not be forwarded by this agent
pubsocket.send_json.assert_not_awaited()
@@ -102,10 +115,10 @@ async def test_handle_message_sends_non_gesture_command():
async def test_handle_message_rejects_invalid_gesture_tag():
"""Internal message with invalid gesture tag is not forwarded."""
pubsocket = AsyncMock()
agent = RobotGestureAgent("robot_gesture")
agent = RobotGestureAgent("robot_gesture", gesture_data=["hello", "yes", "no"], address="")
agent.pubsocket = pubsocket
# Use a tag that's not in availableTags
# Use a tag that's not in gesture_data
payload = {"endpoint": RIEndpoint.GESTURE_TAG, "data": "invalid_tag_not_in_list"}
msg = InternalMessage(to="robot", sender="tester", body=json.dumps(payload))
@@ -114,11 +127,70 @@ async def test_handle_message_rejects_invalid_gesture_tag():
pubsocket.send_json.assert_not_awaited()
@pytest.mark.asyncio
async def test_handle_message_sends_valid_single_gesture_command():
"""Internal message with valid single gesture is forwarded."""
pubsocket = AsyncMock()
agent = RobotGestureAgent("robot_gesture", single_gesture_data=["wave", "point"], address="")
agent.pubsocket = pubsocket
payload = {
"endpoint": RIEndpoint.GESTURE_SINGLE,
"data": "wave",
}
msg = InternalMessage(to="robot", sender="tester", body=json.dumps(payload))
await agent.handle_message(msg)
pubsocket.send_json.assert_awaited_once()
@pytest.mark.asyncio
async def test_handle_message_rejects_invalid_single_gesture():
"""Internal message with invalid single gesture is not forwarded."""
pubsocket = AsyncMock()
agent = RobotGestureAgent("robot_gesture", single_gesture_data=["wave", "point"], address="")
agent.pubsocket = pubsocket
payload = {
"endpoint": RIEndpoint.GESTURE_SINGLE,
"data": "dance",
}
msg = InternalMessage(to="robot", sender="tester", body=json.dumps(payload))
await agent.handle_message(msg)
pubsocket.send_json.assert_not_awaited()
@pytest.mark.asyncio
async def test_zmq_command_loop_valid_single_gesture_payload():
"""UI command with valid single gesture is read from SUB and published."""
command = {"endpoint": RIEndpoint.GESTURE_SINGLE, "data": "wave"}
fake_socket = AsyncMock()
async def recv_once():
agent._running = False
return b"command", json.dumps(command).encode("utf-8")
fake_socket.recv_multipart = recv_once
fake_socket.send_json = AsyncMock()
agent = RobotGestureAgent("robot_gesture", single_gesture_data=["wave", "point"], address="")
agent.subsocket = fake_socket
agent.pubsocket = fake_socket
agent._running = True
await agent._zmq_command_loop()
fake_socket.send_json.assert_awaited_once()
@pytest.mark.asyncio
async def test_handle_message_invalid_payload():
"""Invalid payload is caught and does not send."""
pubsocket = AsyncMock()
agent = RobotGestureAgent("robot_gesture")
agent = RobotGestureAgent("robot_gesture", gesture_data=["hello", "yes", "no"], address="")
agent.pubsocket = pubsocket
msg = InternalMessage(to="robot", sender="tester", body=json.dumps({"bad": "data"}))
@@ -137,12 +209,12 @@ async def test_zmq_command_loop_valid_gesture_payload():
async def recv_once():
# stop after first iteration
agent._running = False
return (b"command", json.dumps(command).encode("utf-8"))
return b"command", json.dumps(command).encode("utf-8")
fake_socket.recv_multipart = recv_once
fake_socket.send_json = AsyncMock()
agent = RobotGestureAgent("robot_gesture")
agent = RobotGestureAgent("robot_gesture", gesture_data=["hello", "yes", "no"], address="")
agent.subsocket = fake_socket
agent.pubsocket = fake_socket
agent._running = True
@@ -154,18 +226,18 @@ async def test_zmq_command_loop_valid_gesture_payload():
@pytest.mark.asyncio
async def test_zmq_command_loop_valid_non_gesture_payload():
"""UI command with non-gesture endpoint is not handled by this agent."""
"""UI command with non-gesture endpoint is not forwarded by this agent."""
command = {"endpoint": "some_other_endpoint", "data": "anything"}
fake_socket = AsyncMock()
async def recv_once():
agent._running = False
return (b"command", json.dumps(command).encode("utf-8"))
return b"command", json.dumps(command).encode("utf-8")
fake_socket.recv_multipart = recv_once
fake_socket.send_json = AsyncMock()
agent = RobotGestureAgent("robot_gesture")
agent = RobotGestureAgent("robot_gesture", gesture_data=["hello", "yes", "no"], address="")
agent.subsocket = fake_socket
agent.pubsocket = fake_socket
agent._running = True
@@ -183,12 +255,12 @@ async def test_zmq_command_loop_invalid_gesture_tag():
async def recv_once():
agent._running = False
return (b"command", json.dumps(command).encode("utf-8"))
return b"command", json.dumps(command).encode("utf-8")
fake_socket.recv_multipart = recv_once
fake_socket.send_json = AsyncMock()
agent = RobotGestureAgent("robot_gesture")
agent = RobotGestureAgent("robot_gesture", gesture_data=["hello", "yes", "no"], address="")
agent.subsocket = fake_socket
agent.pubsocket = fake_socket
agent._running = True
@@ -205,12 +277,12 @@ async def test_zmq_command_loop_invalid_json():
async def recv_once():
agent._running = False
return (b"command", b"{not_json}")
return b"command", b"{not_json}"
fake_socket.recv_multipart = recv_once
fake_socket.send_json = AsyncMock()
agent = RobotGestureAgent("robot_gesture")
agent = RobotGestureAgent("robot_gesture", gesture_data=["hello", "yes", "no"], address="")
agent.subsocket = fake_socket
agent.pubsocket = fake_socket
agent._running = True
@@ -227,12 +299,12 @@ async def test_zmq_command_loop_ignores_send_gestures_topic():
async def recv_once():
agent._running = False
return (b"send_gestures", b"{}")
return b"send_gestures", b"{}"
fake_socket.recv_multipart = recv_once
fake_socket.send_json = AsyncMock()
agent = RobotGestureAgent("robot_gesture")
agent = RobotGestureAgent("robot_gesture", gesture_data=["hello", "yes", "no"], address="")
agent.subsocket = fake_socket
agent.pubsocket = fake_socket
agent._running = True
@@ -245,148 +317,198 @@ async def test_zmq_command_loop_ignores_send_gestures_topic():
@pytest.mark.asyncio
async def test_fetch_gestures_loop_without_amount():
"""Fetch gestures request without amount returns all tags."""
fake_socket = AsyncMock()
fake_repsocket = AsyncMock()
async def recv_once():
agent._running = False
return (b"send_gestures", b"{}")
return b"{}" # Empty JSON request
fake_socket.recv_multipart = recv_once
fake_socket.send_multipart = AsyncMock()
fake_repsocket.recv = recv_once
fake_repsocket.send = AsyncMock()
agent = RobotGestureAgent("robot_gesture")
agent.subsocket = fake_socket
agent.pubsocket = fake_socket
agent = RobotGestureAgent(
"robot_gesture", gesture_data=["hello", "yes", "no", "wave", "point"], address=""
)
agent.repsocket = fake_repsocket
agent._running = True
await agent._fetch_gestures_loop()
fake_socket.send_multipart.assert_awaited_once()
fake_repsocket.send.assert_awaited_once()
# Check the response contains all tags
args, kwargs = fake_socket.send_multipart.call_args
assert args[0][0] == b"get_gestures"
response = json.loads(args[0][1])
args, kwargs = fake_repsocket.send.call_args
response = json.loads(args[0])
assert "tags" in response
assert len(response["tags"]) > 0
# Check it includes some expected tags
assert "hello" in response["tags"]
assert "yes" in response["tags"]
assert response["tags"] == ["hello", "yes", "no", "wave", "point"]
@pytest.mark.asyncio
async def test_fetch_gestures_loop_with_amount():
"""Fetch gestures request with amount returns limited tags."""
fake_socket = AsyncMock()
amount = 5
fake_repsocket = AsyncMock()
amount = 3
async def recv_once():
agent._running = False
return (b"send_gestures", json.dumps(amount).encode())
return json.dumps(amount).encode()
fake_socket.recv_multipart = recv_once
fake_socket.send_multipart = AsyncMock()
fake_repsocket.recv = recv_once
fake_repsocket.send = AsyncMock()
agent = RobotGestureAgent("robot_gesture")
agent.subsocket = fake_socket
agent.pubsocket = fake_socket
agent = RobotGestureAgent(
"robot_gesture", gesture_data=["hello", "yes", "no", "wave", "point"], address=""
)
agent.repsocket = fake_repsocket
agent._running = True
await agent._fetch_gestures_loop()
fake_socket.send_multipart.assert_awaited_once()
fake_repsocket.send.assert_awaited_once()
args, kwargs = fake_socket.send_multipart.call_args
assert args[0][0] == b"get_gestures"
response = json.loads(args[0][1])
args, kwargs = fake_repsocket.send.call_args
response = json.loads(args[0])
assert "tags" in response
assert len(response["tags"]) == amount
assert response["tags"] == ["hello", "yes", "no"]
@pytest.mark.asyncio
async def test_fetch_gestures_loop_ignores_command_topic():
"""Command topic is ignored in fetch gestures loop."""
fake_socket = AsyncMock()
async def test_fetch_gestures_loop_with_integer_request():
"""Fetch gestures request with integer amount."""
fake_repsocket = AsyncMock()
amount = 2
async def recv_once():
agent._running = False
return (b"command", b"{}")
return json.dumps(amount).encode()
fake_socket.recv_multipart = recv_once
fake_socket.send_multipart = AsyncMock()
fake_repsocket.recv = recv_once
fake_repsocket.send = AsyncMock()
agent = RobotGestureAgent("robot_gesture")
agent.subsocket = fake_socket
agent.pubsocket = fake_socket
agent = RobotGestureAgent("robot_gesture", gesture_data=["hello", "yes", "no"], address="")
agent.repsocket = fake_repsocket
agent._running = True
await agent._fetch_gestures_loop()
fake_socket.send_multipart.assert_not_awaited()
fake_repsocket.send.assert_awaited_once()
args, kwargs = fake_repsocket.send.call_args
response = json.loads(args[0])
assert response["tags"] == ["hello", "yes"]
@pytest.mark.asyncio
async def test_fetch_gestures_loop_invalid_request():
"""Invalid request body is handled gracefully."""
fake_socket = AsyncMock()
async def test_fetch_gestures_loop_with_invalid_json():
"""Invalid JSON request returns all tags."""
fake_repsocket = AsyncMock()
async def recv_once():
agent._running = False
# Send a non-integer, non-JSON body
return (b"send_gestures", b"not_json")
return b"not_json"
fake_socket.recv_multipart = recv_once
fake_socket.send_multipart = AsyncMock()
fake_repsocket.recv = recv_once
fake_repsocket.send = AsyncMock()
agent = RobotGestureAgent("robot_gesture")
agent.subsocket = fake_socket
agent.pubsocket = fake_socket
agent = RobotGestureAgent("robot_gesture", gesture_data=["hello", "yes", "no"], address="")
agent.repsocket = fake_repsocket
agent._running = True
await agent._fetch_gestures_loop()
# Should still send a response (all tags)
fake_socket.send_multipart.assert_awaited_once()
fake_repsocket.send.assert_awaited_once()
args, kwargs = fake_repsocket.send.call_args
response = json.loads(args[0])
assert response["tags"] == ["hello", "yes", "no"]
def test_available_tags():
"""Test that availableTags returns the expected list."""
agent = RobotGestureAgent("robot_gesture")
@pytest.mark.asyncio
async def test_fetch_gestures_loop_with_non_integer_json():
"""Non-integer JSON request returns all tags."""
fake_repsocket = AsyncMock()
tags = agent.availableTags()
async def recv_once():
agent._running = False
return json.dumps({"not": "an_integer"}).encode()
assert isinstance(tags, list)
assert len(tags) > 0
# Check some expected tags are present
assert "hello" in tags
assert "yes" in tags
assert "no" in tags
# Check a non-existent tag is not present
assert "invalid_tag_not_in_list" not in tags
fake_repsocket.recv = recv_once
fake_repsocket.send = AsyncMock()
agent = RobotGestureAgent("robot_gesture", gesture_data=["hello", "yes", "no"], address="")
agent.repsocket = fake_repsocket
agent._running = True
await agent._fetch_gestures_loop()
fake_repsocket.send.assert_awaited_once()
args, kwargs = fake_repsocket.send.call_args
response = json.loads(args[0])
assert response["tags"] == ["hello", "yes", "no"]
def test_gesture_data_attribute():
"""Test that gesture_data returns the expected list."""
gesture_data = ["hello", "yes", "no", "wave"]
agent = RobotGestureAgent("robot_gesture", gesture_data=gesture_data, address="")
assert agent.gesture_data == gesture_data
assert isinstance(agent.gesture_data, list)
assert len(agent.gesture_data) == 4
assert "hello" in agent.gesture_data
assert "yes" in agent.gesture_data
assert "no" in agent.gesture_data
assert "invalid_tag_not_in_list" not in agent.gesture_data
@pytest.mark.asyncio
async def test_stop_closes_sockets():
"""Stop method closes both sockets."""
"""Stop method closes all sockets."""
pubsocket = MagicMock()
subsocket = MagicMock()
agent = RobotGestureAgent("robot_gesture")
repsocket = MagicMock()
agent = RobotGestureAgent("robot_gesture", address="")
agent.pubsocket = pubsocket
agent.subsocket = subsocket
agent.repsocket = repsocket
await agent.stop()
pubsocket.close.assert_called_once()
subsocket.close.assert_called_once()
repsocket.close.assert_called_once()
@pytest.mark.asyncio
async def test_initialization_with_custom_gesture_data():
"""Agent can be initialized with custom gesture data."""
custom_gestures = ["custom1", "custom2", "custom3"]
agent = RobotGestureAgent("robot_gesture", gesture_data=custom_gestures)
agent = RobotGestureAgent("robot_gesture", gesture_data=custom_gestures, address="")
# Note: The current implementation doesn't use the gesture_data parameter
# in availableTags(). This test documents that behavior.
# If you update the agent to use gesture_data, update this test accordingly.
assert agent.gesture_data == custom_gestures
@pytest.mark.asyncio
async def test_fetch_gestures_loop_handles_exception():
"""Exception in fetch gestures loop is caught and logged."""
fake_repsocket = AsyncMock()
async def recv_once():
agent._running = False
raise Exception("Test exception")
fake_repsocket.recv = recv_once
fake_repsocket.send = AsyncMock()
agent = RobotGestureAgent("robot_gesture", gesture_data=["hello", "yes", "no"], address="")
agent.repsocket = fake_repsocket
agent.logger = MagicMock()
agent._running = True
# Should not raise exception
await agent._fetch_gestures_loop()
# Exception should be logged
agent.logger.exception.assert_called_once()

View File

@@ -8,6 +8,11 @@ from control_backend.agents.actuation.robot_speech_agent import RobotSpeechAgent
from control_backend.core.agent_system import InternalMessage
def mock_speech_agent():
agent = RobotSpeechAgent("robot_speech", address="tcp://localhost:5555", bind=False)
return agent
@pytest.fixture
def zmq_context(mocker):
mock_context = mocker.patch(
@@ -25,7 +30,11 @@ async def test_setup_bind(zmq_context, mocker):
settings = mocker.patch("control_backend.agents.actuation.robot_speech_agent.settings")
settings.zmq_settings.internal_sub_address = "tcp://internal:1234"
agent.add_behavior = MagicMock()
def close_coro(coro):
coro.close()
return MagicMock()
agent.add_behavior = MagicMock(side_effect=close_coro)
await agent.setup()
@@ -43,7 +52,11 @@ async def test_setup_connect(zmq_context, mocker):
settings = mocker.patch("control_backend.agents.actuation.robot_speech_agent.settings")
settings.zmq_settings.internal_sub_address = "tcp://internal:1234"
agent.add_behavior = MagicMock()
def close_coro(coro):
coro.close()
return MagicMock()
agent.add_behavior = MagicMock(side_effect=close_coro)
await agent.setup()
@@ -56,10 +69,10 @@ async def test_setup_connect(zmq_context, mocker):
async def test_handle_message_sends_command():
"""Internal message is forwarded to robot pub socket as JSON."""
pubsocket = AsyncMock()
agent = RobotSpeechAgent("robot_speech")
agent = mock_speech_agent()
agent.pubsocket = pubsocket
payload = {"endpoint": "actuate/speech", "data": "hello"}
payload = {"endpoint": "actuate/speech", "data": "hello", "is_priority": False}
msg = InternalMessage(to="robot", sender="tester", body=json.dumps(payload))
await agent.handle_message(msg)
@@ -70,7 +83,7 @@ async def test_handle_message_sends_command():
@pytest.mark.asyncio
async def test_zmq_command_loop_valid_payload(zmq_context):
"""UI command is read from SUB and published."""
command = {"endpoint": "actuate/speech", "data": "hello"}
command = {"endpoint": "actuate/speech", "data": "hello", "is_priority": False}
fake_socket = AsyncMock()
async def recv_once():
@@ -80,7 +93,7 @@ async def test_zmq_command_loop_valid_payload(zmq_context):
fake_socket.recv_multipart = recv_once
fake_socket.send_json = AsyncMock()
agent = RobotSpeechAgent("robot_speech")
agent = mock_speech_agent()
agent.subsocket = fake_socket
agent.pubsocket = fake_socket
agent._running = True
@@ -101,7 +114,7 @@ async def test_zmq_command_loop_invalid_json():
fake_socket.recv_multipart = recv_once
fake_socket.send_json = AsyncMock()
agent = RobotSpeechAgent("robot_speech")
agent = mock_speech_agent()
agent.subsocket = fake_socket
agent.pubsocket = fake_socket
agent._running = True
@@ -115,7 +128,7 @@ async def test_zmq_command_loop_invalid_json():
async def test_handle_message_invalid_payload():
"""Invalid payload is caught and does not send."""
pubsocket = AsyncMock()
agent = RobotSpeechAgent("robot_speech")
agent = mock_speech_agent()
agent.pubsocket = pubsocket
msg = InternalMessage(to="robot", sender="tester", body=json.dumps({"bad": "data"}))
@@ -129,7 +142,7 @@ async def test_handle_message_invalid_payload():
async def test_stop_closes_sockets():
pubsocket = MagicMock()
subsocket = MagicMock()
agent = RobotSpeechAgent("robot_speech")
agent = mock_speech_agent()
agent.pubsocket = pubsocket
agent.subsocket = subsocket

View File

@@ -0,0 +1,186 @@
import pytest
from control_backend.agents.bdi.agentspeak_ast import (
AstAtom,
AstBinaryOp,
AstLiteral,
AstLogicalExpression,
AstNumber,
AstPlan,
AstProgram,
AstRule,
AstStatement,
AstString,
AstVar,
BinaryOperatorType,
StatementType,
TriggerType,
_coalesce_expr,
)
def test_ast_atom():
atom = AstAtom("test")
assert str(atom) == "test"
assert atom._to_agentspeak() == "test"
def test_ast_var():
var = AstVar("Variable")
assert str(var) == "Variable"
assert var._to_agentspeak() == "Variable"
def test_ast_number():
num = AstNumber(42)
assert str(num) == "42"
num_float = AstNumber(3.14)
assert str(num_float) == "3.14"
def test_ast_string():
s = AstString("hello")
assert str(s) == '"hello"'
def test_ast_literal():
lit = AstLiteral("functor", [AstAtom("atom"), AstNumber(1)])
assert str(lit) == "functor(atom, 1)"
lit_empty = AstLiteral("functor")
assert str(lit_empty) == "functor"
def test_ast_binary_op():
left = AstNumber(1)
right = AstNumber(2)
op = AstBinaryOp(left, BinaryOperatorType.GREATER_THAN, right)
assert str(op) == "1 > 2"
# Test logical wrapper
assert isinstance(op.left, AstLogicalExpression)
assert isinstance(op.right, AstLogicalExpression)
def test_ast_binary_op_parens():
# 1 > 2
inner = AstBinaryOp(AstNumber(1), BinaryOperatorType.GREATER_THAN, AstNumber(2))
# (1 > 2) & 3
outer = AstBinaryOp(inner, BinaryOperatorType.AND, AstNumber(3))
assert str(outer) == "(1 > 2) & 3"
# 3 & (1 > 2)
outer_right = AstBinaryOp(AstNumber(3), BinaryOperatorType.AND, inner)
assert str(outer_right) == "3 & (1 > 2)"
def test_ast_binary_op_parens_negated():
inner = AstLogicalExpression(AstAtom("foo"), negated=True)
outer = AstBinaryOp(inner, BinaryOperatorType.AND, AstAtom("bar"))
# The current implementation checks `if self.left.negated: l_str = f"({l_str})"`
# str(inner) is "not foo"
# so we expect "(not foo) & bar"
assert str(outer) == "(not foo) & bar"
outer_right = AstBinaryOp(AstAtom("bar"), BinaryOperatorType.AND, inner)
assert str(outer_right) == "bar & (not foo)"
def test_ast_logical_expression_negation():
expr = AstLogicalExpression(AstAtom("true"), negated=True)
assert str(expr) == "not true"
expr_neg_neg = ~expr
assert str(expr_neg_neg) == "true"
assert not expr_neg_neg.negated
# Invert a non-logical expression (wraps it)
term = AstAtom("true")
inverted = ~term
assert isinstance(inverted, AstLogicalExpression)
assert inverted.negated
assert str(inverted) == "not true"
def test_ast_logical_expression_no_negation():
# _as_logical on already logical expression
expr = AstLogicalExpression(AstAtom("x"))
# Doing binary op will call _as_logical
op = AstBinaryOp(expr, BinaryOperatorType.AND, AstAtom("y"))
assert isinstance(op.left, AstLogicalExpression)
assert op.left is expr # Should reuse instance
def test_ast_operators():
t1 = AstAtom("a")
t2 = AstAtom("b")
assert str(t1 & t2) == "a & b"
assert str(t1 | t2) == "a | b"
assert str(t1 >= t2) == "a >= b"
assert str(t1 > t2) == "a > b"
assert str(t1 <= t2) == "a <= b"
assert str(t1 < t2) == "a < b"
assert str(t1 == t2) == "a == b"
assert str(t1 != t2) == r"a \== b"
def test_coalesce_expr():
t = AstAtom("a")
assert str(t & "b") == 'a & "b"'
assert str(t & 1) == "a & 1"
assert str(t & 1.5) == "a & 1.5"
with pytest.raises(TypeError):
_coalesce_expr(None)
def test_ast_statement():
stmt = AstStatement(StatementType.DO_ACTION, AstLiteral("action"))
assert str(stmt) == ".action"
def test_ast_rule():
# Rule with condition
rule = AstRule(AstLiteral("head"), AstLiteral("body"))
assert str(rule) == "head :- body."
# Rule without condition
rule_simple = AstRule(AstLiteral("fact"))
assert str(rule_simple) == "fact."
def test_ast_plan():
plan = AstPlan(
TriggerType.ADDED_GOAL,
AstLiteral("goal"),
[AstLiteral("context")],
[AstStatement(StatementType.DO_ACTION, AstLiteral("action"))],
)
output = str(plan)
# verify parts exist
assert "+!goal" in output
assert ": context" in output
assert "<- .action." in output
def test_ast_plan_no_context():
plan = AstPlan(
TriggerType.ADDED_GOAL,
AstLiteral("goal"),
[],
[AstStatement(StatementType.DO_ACTION, AstLiteral("action"))],
)
output = str(plan)
assert "+!goal" in output
assert ": " not in output
assert "<- .action." in output
def test_ast_program():
prog = AstProgram(
rules=[AstRule(AstLiteral("fact"))],
plans=[AstPlan(TriggerType.ADDED_BELIEF, AstLiteral("b"), [], [])],
)
output = str(prog)
assert "fact." in output
assert "+b" in output

View File

@@ -0,0 +1,187 @@
import uuid
import pytest
from control_backend.agents.bdi.agentspeak_ast import AstProgram
from control_backend.agents.bdi.agentspeak_generator import AgentSpeakGenerator
from control_backend.schemas.program import (
BasicNorm,
ConditionalNorm,
Gesture,
GestureAction,
Goal,
InferredBelief,
KeywordBelief,
LLMAction,
LogicalOperator,
Phase,
Plan,
Program,
SemanticBelief,
SpeechAction,
Trigger,
)
@pytest.fixture
def generator():
return AgentSpeakGenerator()
def test_generate_empty_program(generator):
prog = Program(phases=[])
code = generator.generate(prog)
assert 'phase("end").' in code
assert "!notify_cycle" in code
def test_generate_basic_norm(generator):
norm = BasicNorm(id=uuid.uuid4(), name="n1", norm="be nice")
phase = Phase(id=uuid.uuid4(), norms=[norm], goals=[], triggers=[])
prog = Program(phases=[phase])
code = generator.generate(prog)
assert f'norm("be nice") :- phase("{phase.id}").' in code
def test_generate_critical_norm(generator):
norm = BasicNorm(id=uuid.uuid4(), name="n1", norm="safety", critical=True)
phase = Phase(id=uuid.uuid4(), norms=[norm], goals=[], triggers=[])
prog = Program(phases=[phase])
code = generator.generate(prog)
assert f'critical_norm("safety") :- phase("{phase.id}").' in code
def test_generate_conditional_norm(generator):
cond = KeywordBelief(id=uuid.uuid4(), name="k1", keyword="please")
norm = ConditionalNorm(id=uuid.uuid4(), name="n1", norm="help", condition=cond)
phase = Phase(id=uuid.uuid4(), norms=[norm], goals=[], triggers=[])
prog = Program(phases=[phase])
code = generator.generate(prog)
assert 'norm("help")' in code
assert 'keyword_said("please")' in code
assert f"force_norm_{generator._slugify_str(norm.norm)}" in code
def test_generate_goal_and_plan(generator):
action = SpeechAction(id=uuid.uuid4(), name="s1", text="hello")
plan = Plan(id=uuid.uuid4(), name="p1", steps=[action])
# IMPORTANT: can_fail must be False for +achieved_ belief to be added
goal = Goal(id=uuid.uuid4(), name="g1", description="desc", plan=plan, can_fail=False)
phase = Phase(id=uuid.uuid4(), norms=[], goals=[goal], triggers=[])
prog = Program(phases=[phase])
code = generator.generate(prog)
# Check trigger for goal
goal_slug = generator._slugify_str(goal.name)
assert f"+!{goal_slug}" in code
assert f'phase("{phase.id}")' in code
assert '!say("hello")' in code
# Check success belief addition
assert f"+achieved_{goal_slug}" in code
def test_generate_subgoal(generator):
subplan = Plan(id=uuid.uuid4(), name="p2", steps=[])
subgoal = Goal(id=uuid.uuid4(), name="sub1", description="sub", plan=subplan)
plan = Plan(id=uuid.uuid4(), name="p1", steps=[subgoal])
goal = Goal(id=uuid.uuid4(), name="g1", description="main", plan=plan)
phase = Phase(id=uuid.uuid4(), norms=[], goals=[goal], triggers=[])
prog = Program(phases=[phase])
code = generator.generate(prog)
subgoal_slug = generator._slugify_str(subgoal.name)
# Main goal calls subgoal
assert f"!{subgoal_slug}" in code
# Subgoal plan exists
assert f"+!{subgoal_slug}" in code
def test_generate_trigger(generator):
cond = SemanticBelief(id=uuid.uuid4(), name="s1", description="desc")
plan = Plan(id=uuid.uuid4(), name="p1", steps=[])
trigger = Trigger(id=uuid.uuid4(), name="t1", condition=cond, plan=plan)
phase = Phase(id=uuid.uuid4(), norms=[], goals=[], triggers=[trigger])
prog = Program(phases=[phase])
code = generator.generate(prog)
# Trigger logic is added to check_triggers
assert f"{generator.slugify(cond)}" in code
assert f'notify_trigger_start("{generator.slugify(trigger)}")' in code
assert f'notify_trigger_end("{generator.slugify(trigger)}")' in code
def test_phase_transition(generator):
phase1 = Phase(id=uuid.uuid4(), name="p1", norms=[], goals=[], triggers=[])
phase2 = Phase(id=uuid.uuid4(), name="p2", norms=[], goals=[], triggers=[])
prog = Program(phases=[phase1, phase2])
code = generator.generate(prog)
assert "transition_phase" in code
assert f'phase("{phase1.id}")' in code
assert f'phase("{phase2.id}")' in code
assert "force_transition_phase" in code
def test_astify_gesture(generator):
gesture = Gesture(type="single", name="wave")
action = GestureAction(id=uuid.uuid4(), name="g1", gesture=gesture)
ast = generator._astify(action)
assert str(ast) == 'gesture("single", "wave")'
def test_astify_llm_action(generator):
action = LLMAction(id=uuid.uuid4(), name="l1", goal="be funny")
ast = generator._astify(action)
assert str(ast) == 'reply_with_goal("be funny")'
def test_astify_inferred_belief_and(generator):
left = KeywordBelief(id=uuid.uuid4(), name="k1", keyword="a")
right = KeywordBelief(id=uuid.uuid4(), name="k2", keyword="b")
inf = InferredBelief(
id=uuid.uuid4(), name="i1", operator=LogicalOperator.AND, left=left, right=right
)
ast = generator._astify(inf)
assert 'keyword_said("a") & keyword_said("b")' == str(ast)
def test_astify_inferred_belief_or(generator):
left = KeywordBelief(id=uuid.uuid4(), name="k1", keyword="a")
right = KeywordBelief(id=uuid.uuid4(), name="k2", keyword="b")
inf = InferredBelief(
id=uuid.uuid4(), name="i1", operator=LogicalOperator.OR, left=left, right=right
)
ast = generator._astify(inf)
assert 'keyword_said("a") | keyword_said("b")' == str(ast)
def test_astify_semantic_belief(generator):
sb = SemanticBelief(id=uuid.uuid4(), name="s1", description="desc")
ast = generator._astify(sb)
assert str(ast) == f"semantic_{generator._slugify_str(sb.name)}"
def test_slugify_not_implemented(generator):
with pytest.raises(NotImplementedError):
generator.slugify("not a program element")
def test_astify_not_implemented(generator):
with pytest.raises(NotImplementedError):
generator._astify("not a program element")
def test_process_phase_transition_from_none(generator):
# Initialize AstProgram manually as we are bypassing generate()
generator._asp = AstProgram()
# Should safely return doing nothing
generator._add_phase_transition(None, None)
assert len(generator._asp.plans) == 0

View File

@@ -1,4 +1,6 @@
import asyncio
import json
import time
from unittest.mock import AsyncMock, MagicMock, mock_open, patch
import agentspeak
@@ -18,7 +20,7 @@ def mock_agentspeak_env():
@pytest.fixture
def agent():
agent = BDICoreAgent("bdi_agent", "dummy.asl")
agent = BDICoreAgent("bdi_agent")
agent.send = AsyncMock()
agent.bdi_agent = MagicMock()
return agent
@@ -43,31 +45,70 @@ async def test_setup_no_asl(mock_agentspeak_env, agent):
@pytest.mark.asyncio
async def test_handle_belief_collector_message(agent, mock_settings):
async def test_handle_belief_message(agent, mock_settings):
"""Test that incoming beliefs are added to the BDI agent"""
beliefs = [Belief(name="user_said", arguments=["Hello"])]
msg = InternalMessage(
to="bdi_agent",
sender=mock_settings.agent_settings.bdi_belief_collector_name,
body=BeliefMessage(beliefs=beliefs).model_dump_json(),
sender=mock_settings.agent_settings.text_belief_extractor_name,
body=BeliefMessage(create=beliefs).model_dump_json(),
thread="beliefs",
)
await agent.handle_message(msg)
# Expect bdi_agent.call to be triggered to add belief
args = agent.bdi_agent.call.call_args.args
assert args[0] == agentspeak.Trigger.addition
assert args[1] == agentspeak.GoalType.belief
assert args[2] == agentspeak.Literal("user_said", (agentspeak.Literal("Hello"),))
# Check for the specific call we expect among all calls
# bdi_agent.call is called multiple times (for transition_phase, check_triggers)
# We want to confirm the belief addition call exists
found_call = False
for call in agent.bdi_agent.call.call_args_list:
args = call.args
if (
args[0] == agentspeak.Trigger.addition
and args[1] == agentspeak.GoalType.belief
and args[2].functor == "user_said"
and args[2].args[0].functor == "Hello"
):
found_call = True
break
assert found_call, "Expected belief addition call not found in bdi_agent.call history"
@pytest.mark.asyncio
async def test_incorrect_belief_collector_message(agent, mock_settings):
async def test_handle_delete_belief_message(agent, mock_settings):
"""Test that incoming beliefs to be deleted are removed from the BDI agent"""
beliefs = [Belief(name="user_said", arguments=["Hello"])]
msg = InternalMessage(
to="bdi_agent",
sender=mock_settings.agent_settings.text_belief_extractor_name,
body=BeliefMessage(delete=beliefs).model_dump_json(),
thread="beliefs",
)
await agent.handle_message(msg)
found_call = False
for call in agent.bdi_agent.call.call_args_list:
args = call.args
if (
args[0] == agentspeak.Trigger.removal
and args[1] == agentspeak.GoalType.belief
and args[2].functor == "user_said"
and args[2].args[0].functor == "Hello"
):
found_call = True
break
assert found_call
@pytest.mark.asyncio
async def test_incorrect_belief_message(agent, mock_settings):
"""Test that incorrect message format triggers an exception."""
msg = InternalMessage(
to="bdi_agent",
sender=mock_settings.agent_settings.bdi_belief_collector_name,
sender=mock_settings.agent_settings.text_belief_extractor_name,
body=json.dumps({"bad_format": "bad_format"}),
thread="beliefs",
)
@@ -77,11 +118,6 @@ async def test_incorrect_belief_collector_message(agent, mock_settings):
agent.bdi_agent.call.assert_not_called() # did not set belief
@pytest.mark.asyncio
async def test():
pass
@pytest.mark.asyncio
async def test_handle_llm_response(agent):
"""Test that LLM responses are forwarded to the Robot Speech Agent"""
@@ -116,11 +152,375 @@ async def test_custom_actions(agent):
# Invoke action
mock_term = MagicMock()
mock_term.args = ["Hello", "Norm", "Goal"]
mock_term.args = ["Hello", "Norm"]
mock_intention = MagicMock()
# Run generator
gen = action_fn(agent, mock_term, mock_intention)
next(gen) # Execute
agent._send_to_llm.assert_called_with("Hello", "Norm", "Goal")
agent._send_to_llm.assert_called_with("Hello", "Norm", "")
def test_add_belief_sets_event(agent):
"""Test that a belief triggers wake event and call()"""
agent._wake_bdi_loop = MagicMock()
belief = Belief(name="test_belief", arguments=["a", "b"])
belief_changes = BeliefMessage(replace=[belief])
agent._apply_belief_changes(belief_changes)
assert agent.bdi_agent.call.called
agent._wake_bdi_loop.set.assert_called()
def test_apply_beliefs_empty_returns(agent):
"""Line: if not beliefs: return"""
agent._wake_bdi_loop = MagicMock()
agent._apply_belief_changes(BeliefMessage())
agent.bdi_agent.call.assert_not_called()
agent._wake_bdi_loop.set.assert_not_called()
def test_remove_belief_success_wakes_loop(agent):
"""Line: if result: wake set"""
agent._wake_bdi_loop = MagicMock()
agent.bdi_agent.call.return_value = True
agent._remove_belief("remove_me", ["x"])
assert agent.bdi_agent.call.called
call_args = agent.bdi_agent.call.call_args.args
trigger = call_args[0]
goaltype = call_args[1]
literal = call_args[2]
assert trigger == agentspeak.Trigger.removal
assert goaltype == agentspeak.GoalType.belief
assert literal.functor == "remove_me"
assert literal.args[0].functor == "x"
agent._wake_bdi_loop.set.assert_called()
def test_remove_belief_failure_does_not_wake(agent):
"""Line: else result is False"""
agent._wake_bdi_loop = MagicMock()
agent.bdi_agent.call.return_value = False
agent._remove_belief("not_there", ["y"])
assert agent.bdi_agent.call.called # removal was attempted
agent._wake_bdi_loop.set.assert_not_called()
def test_remove_all_with_name_wakes_loop(agent):
"""Cover _remove_all_with_name() removed counter + wake"""
agent._wake_bdi_loop = MagicMock()
fake_literal = agentspeak.Literal("delete_me", (agentspeak.Literal("arg1"),))
fake_key = ("delete_me", 1)
agent.bdi_agent.beliefs = {fake_key: {fake_literal}}
agent._remove_all_with_name("delete_me")
assert agent.bdi_agent.call.called
agent._wake_bdi_loop.set.assert_called()
@pytest.mark.asyncio
async def test_bdi_step_true_branch_hits_line_67(agent):
"""Force step() to return True once so line 67 is actually executed"""
# counter that isn't tied to MagicMock.call_count ordering
counter = {"i": 0}
def fake_step():
counter["i"] += 1
return counter["i"] == 1 # True only first time
# Important: wrap fake_step into another mock so `.called` still exists
agent.bdi_agent.step = MagicMock(side_effect=fake_step)
agent.bdi_agent.shortest_deadline = MagicMock(return_value=None)
agent._running = True
agent._wake_bdi_loop = asyncio.Event()
agent._wake_bdi_loop.set()
task = asyncio.create_task(agent._bdi_loop())
await asyncio.sleep(0.01)
task.cancel()
try:
await task
except asyncio.CancelledError:
pass
assert agent.bdi_agent.step.called
assert counter["i"] >= 1 # proves True branch ran
def test_replace_belief_calls_remove_all(agent):
"""Cover: if belief.replace: self._remove_all_with_name()"""
agent._remove_all_with_name = MagicMock()
agent._wake_bdi_loop = MagicMock()
belief = Belief(name="user_said", arguments=["Hello"])
belief_changes = BeliefMessage(replace=[belief])
agent._apply_belief_changes(belief_changes)
agent._remove_all_with_name.assert_called_with("user_said")
@pytest.mark.asyncio
async def test_send_to_llm_creates_prompt_and_sends(agent):
"""Cover entire _send_to_llm() including message send and logger.info"""
agent.bdi_agent = MagicMock() # ensure mocked BDI does not interfere
agent._wake_bdi_loop = MagicMock()
await agent._send_to_llm("hello world", "n1\nn2", "g1")
# send() was called
assert agent.send.called
sent_msg: InternalMessage = agent.send.call_args.args[0]
# Message routing values correct
assert sent_msg.to == settings.agent_settings.llm_name
assert "hello world" in sent_msg.body
# JSON contains split norms/goals
body = json.loads(sent_msg.body)
assert body["norms"] == ["n1", "n2"]
assert body["goals"] == ["g1"]
@pytest.mark.asyncio
async def test_deadline_sleep_branch(agent):
"""Specifically assert the if deadline: sleep → maybe_more_work=True branch"""
future_deadline = time.time() + 0.005
agent.bdi_agent.step.return_value = False
agent.bdi_agent.shortest_deadline.return_value = future_deadline
start_time = time.time()
agent._running = True
agent._wake_bdi_loop = asyncio.Event()
agent._wake_bdi_loop.set()
task = asyncio.create_task(agent._bdi_loop())
await asyncio.sleep(0.01)
task.cancel()
duration = time.time() - start_time
assert duration >= 0.004 # loop slept until deadline
@pytest.mark.asyncio
async def test_handle_new_program(agent):
agent._load_asl = AsyncMock()
agent.add_behavior = MagicMock()
# Mock existing loop task so it can be cancelled
mock_task = MagicMock()
mock_task.cancel = MagicMock()
agent._bdi_loop_task = mock_task
def close_coro(coro):
coro.close()
return MagicMock()
agent.add_behavior = MagicMock(side_effect=close_coro)
msg = InternalMessage(to="bdi_agent", thread="new_program", body="path/to/asl.asl")
await agent.handle_message(msg)
mock_task.cancel.assert_called_once()
agent._load_asl.assert_awaited_once_with("path/to/asl.asl")
agent.add_behavior.assert_called()
@pytest.mark.asyncio
async def test_handle_user_interrupts(agent, mock_settings):
mock_settings.agent_settings.user_interrupt_name = "user_interrupt_agent"
# force_phase_transition
agent._set_goal = MagicMock()
msg = InternalMessage(
to="bdi_agent",
sender=mock_settings.agent_settings.user_interrupt_name,
thread="force_phase_transition",
body="",
)
await agent.handle_message(msg)
agent._set_goal.assert_called_with("transition_phase")
# force_trigger
agent._force_trigger = MagicMock()
msg.thread = "force_trigger"
msg.body = "trigger_x"
await agent.handle_message(msg)
agent._force_trigger.assert_called_with("trigger_x")
# force_norm
agent._force_norm = MagicMock()
msg.thread = "force_norm"
msg.body = "norm_y"
await agent.handle_message(msg)
agent._force_norm.assert_called_with("norm_y")
# force_next_phase
agent._force_next_phase = MagicMock()
msg.thread = "force_next_phase"
msg.body = ""
await agent.handle_message(msg)
agent._force_next_phase.assert_called_once()
# unknown interrupt
agent.logger = MagicMock()
msg.thread = "unknown_thing"
await agent.handle_message(msg)
agent.logger.warning.assert_called()
@pytest.mark.asyncio
async def test_custom_action_reply_with_goal(agent):
agent._send_to_llm = MagicMock(side_effect=agent.send)
agent._add_custom_actions()
action_fn = agent.actions.actions[(".reply_with_goal", 3)]
mock_term = MagicMock(args=["msg", "norms", "goal"])
gen = action_fn(agent, mock_term, MagicMock())
next(gen)
agent._send_to_llm.assert_called_with("msg", "norms", "goal")
@pytest.mark.asyncio
async def test_custom_action_notify_norms(agent):
agent._add_custom_actions()
action_fn = agent.actions.actions[(".notify_norms", 1)]
mock_term = MagicMock(args=["norms_list"])
gen = action_fn(agent, mock_term, MagicMock())
next(gen)
agent.send.assert_called()
msg = agent.send.call_args[0][0]
assert msg.thread == "active_norms_update"
assert msg.body == "norms_list"
@pytest.mark.asyncio
async def test_custom_action_say(agent):
agent._add_custom_actions()
action_fn = agent.actions.actions[(".say", 1)]
mock_term = MagicMock(args=["hello"])
gen = action_fn(agent, mock_term, MagicMock())
next(gen)
assert agent.send.call_count == 2
msgs = [c[0][0] for c in agent.send.call_args_list]
assert any(m.to == settings.agent_settings.robot_speech_name for m in msgs)
assert any(
m.to == settings.agent_settings.llm_name and m.thread == "assistant_message" for m in msgs
)
@pytest.mark.asyncio
async def test_custom_action_gesture(agent):
agent._add_custom_actions()
# Test single
action_fn = agent.actions.actions[(".gesture", 2)]
mock_term = MagicMock(args=["single", "wave"])
gen = action_fn(agent, mock_term, MagicMock())
next(gen)
msg = agent.send.call_args[0][0]
assert "actuate/gesture/single" in msg.body
# Test tag
mock_term.args = ["tag", "happy"]
gen = action_fn(agent, mock_term, MagicMock())
next(gen)
msg = agent.send.call_args[0][0]
assert "actuate/gesture/tag" in msg.body
@pytest.mark.asyncio
async def test_custom_action_notify_user_said(agent):
agent._add_custom_actions()
action_fn = agent.actions.actions[(".notify_user_said", 1)]
mock_term = MagicMock(args=["hello"])
gen = action_fn(agent, mock_term, MagicMock())
next(gen)
msg = agent.send.call_args[0][0]
assert msg.to == settings.agent_settings.llm_name
assert msg.thread == "user_message"
@pytest.mark.asyncio
async def test_custom_action_notify_trigger_start_end(agent):
agent._add_custom_actions()
# Start
action_fn = agent.actions.actions[(".notify_trigger_start", 1)]
gen = action_fn(agent, MagicMock(args=["t1"]), MagicMock())
next(gen)
assert agent.send.call_args[0][0].thread == "trigger_start"
# End
action_fn = agent.actions.actions[(".notify_trigger_end", 1)]
gen = action_fn(agent, MagicMock(args=["t1"]), MagicMock())
next(gen)
assert agent.send.call_args[0][0].thread == "trigger_end"
@pytest.mark.asyncio
async def test_custom_action_notify_goal_start(agent):
agent._add_custom_actions()
action_fn = agent.actions.actions[(".notify_goal_start", 1)]
gen = action_fn(agent, MagicMock(args=["g1"]), MagicMock())
next(gen)
assert agent.send.call_args[0][0].thread == "goal_start"
@pytest.mark.asyncio
async def test_custom_action_notify_transition_phase(agent):
agent._add_custom_actions()
action_fn = agent.actions.actions[(".notify_transition_phase", 2)]
gen = action_fn(agent, MagicMock(args=["old", "new"]), MagicMock())
next(gen)
msg = agent.send.call_args[0][0]
assert msg.thread == "transition_phase"
assert "old" in msg.body and "new" in msg.body
def test_remove_belief_no_args(agent):
agent._wake_bdi_loop = MagicMock()
agent.bdi_agent.call.return_value = True
agent._remove_belief("fact", None)
assert agent.bdi_agent.call.called
def test_set_goal_with_args(agent):
agent._wake_bdi_loop = MagicMock()
agent._set_goal("goal", ["arg1", "arg2"])
assert agent.bdi_agent.call.called
def test_format_belief_string():
assert BDICoreAgent.format_belief_string("b") == "b"
assert BDICoreAgent.format_belief_string("b", ["a1", "a2"]) == "b(a1,a2)"
def test_force_norm(agent):
agent._add_belief = MagicMock()
agent._force_norm("be_polite")
agent._add_belief.assert_called_with("force_be_polite")
def test_force_trigger(agent):
agent._set_goal = MagicMock()
agent._force_trigger("trig")
agent._set_goal.assert_called_with("trig")
def test_force_next_phase(agent):
agent._set_goal = MagicMock()
agent._force_next_phase()
agent._set_goal.assert_called_with("force_transition_phase")

View File

@@ -0,0 +1,297 @@
import asyncio
import json
import sys
import uuid
from unittest.mock import AsyncMock, MagicMock, mock_open, patch
import pytest
from control_backend.agents.bdi.bdi_program_manager import BDIProgramManager
from control_backend.core.agent_system import InternalMessage
from control_backend.schemas.program import BasicNorm, Goal, Phase, Plan, Program
# Fix Windows Proactor loop for zmq
if sys.platform.startswith("win"):
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
def make_valid_program_json(norm="N1", goal="G1") -> str:
return Program(
phases=[
Phase(
id=uuid.uuid4(),
name="Basic Phase",
norms=[
BasicNorm(
id=uuid.uuid4(),
name=norm,
norm=norm,
),
],
goals=[
Goal(
id=uuid.uuid4(),
name=goal,
description="This description can be used to determine whether the goal "
"has been achieved.",
plan=Plan(
id=uuid.uuid4(),
name="Goal Plan",
steps=[],
),
can_fail=False,
),
],
triggers=[],
),
],
).model_dump_json()
@pytest.mark.asyncio
async def test_create_agentspeak_and_send_to_bdi(mock_settings):
manager = BDIProgramManager(name="program_manager_test")
manager.send = AsyncMock()
program = Program.model_validate_json(make_valid_program_json())
with patch("builtins.open", mock_open()) as mock_file:
await manager._create_agentspeak_and_send_to_bdi(program)
# Check file writing
mock_file.assert_called_with("src/control_backend/agents/bdi/agentspeak.asl", "w")
handle = mock_file()
handle.write.assert_called()
assert manager.send.await_count == 1
msg: InternalMessage = manager.send.await_args[0][0]
assert msg.thread == "new_program"
assert msg.to == mock_settings.agent_settings.bdi_core_name
assert msg.body == "src/control_backend/agents/bdi/agentspeak.asl"
@pytest.mark.asyncio
async def test_receive_programs_valid_and_invalid():
sub = AsyncMock()
sub.recv_multipart.side_effect = [
(b"program", b"{bad json"),
(b"program", make_valid_program_json().encode()),
]
manager = BDIProgramManager(name="program_manager_test")
manager._internal_pub_socket = AsyncMock()
manager.sub_socket = sub
manager._create_agentspeak_and_send_to_bdi = AsyncMock()
manager._send_clear_llm_history = AsyncMock()
manager._send_program_to_user_interrupt = AsyncMock()
manager._send_beliefs_to_semantic_belief_extractor = AsyncMock()
manager._send_goals_to_semantic_belief_extractor = AsyncMock()
try:
# Will give StopAsyncIteration when the predefined `sub.recv_multipart` side-effects run out
await manager._receive_programs()
except StopAsyncIteration:
pass
# Only valid Program should have triggered _send_to_bdi
assert manager._create_agentspeak_and_send_to_bdi.await_count == 1
forwarded: Program = manager._create_agentspeak_and_send_to_bdi.await_args[0][0]
assert forwarded.phases[0].norms[0].name == "N1"
assert forwarded.phases[0].goals[0].name == "G1"
# Verify history clear was triggered exactly once (for the valid program)
# The invalid program loop `continue`s before calling _send_clear_llm_history
assert manager._send_clear_llm_history.await_count == 1
@pytest.mark.asyncio
async def test_send_clear_llm_history(mock_settings):
# Ensure the mock returns a string for the agent name (just like in your LLM tests)
mock_settings.agent_settings.llm_agent_name = "llm_agent"
manager = BDIProgramManager(name="program_manager_test")
manager.send = AsyncMock()
await manager._send_clear_llm_history()
assert manager.send.await_count == 2
msg: InternalMessage = manager.send.await_args_list[0][0][0]
# Verify the content and recipient
assert msg.body == "clear_history"
@pytest.mark.asyncio
async def test_handle_message_transition_phase(mock_settings):
mock_settings.agent_settings.user_interrupt_name = "user_interrupt_agent"
manager = BDIProgramManager(name="program_manager_test")
manager.send = AsyncMock()
# Setup state
prog = Program.model_validate_json(make_valid_program_json(norm="N1", goal="G1"))
manager._initialize_internal_state(prog)
# Test valid transition (to same phase for simplicity, or we need 2 phases)
# Let's create a program with 2 phases
phase2_id = uuid.uuid4()
phase2 = Phase(id=phase2_id, name="Phase 2", norms=[], goals=[], triggers=[])
prog.phases.append(phase2)
manager._initialize_internal_state(prog)
current_phase_id = str(prog.phases[0].id)
next_phase_id = str(phase2_id)
payload = json.dumps({"old": current_phase_id, "new": next_phase_id})
msg = InternalMessage(to="me", sender="bdi", body=payload, thread="transition_phase")
await manager.handle_message(msg)
assert str(manager._phase.id) == next_phase_id
# Allow background tasks to run (add_behavior)
await asyncio.sleep(0)
# Check notifications sent
# 1. beliefs to extractor
# 2. goals to extractor
# 3. notification to user interrupt
assert manager.send.await_count >= 3
# Verify user interrupt notification
calls = manager.send.await_args_list
ui_msgs = [
c[0][0] for c in calls if c[0][0].to == mock_settings.agent_settings.user_interrupt_name
]
assert len(ui_msgs) > 0
assert ui_msgs[-1].body == next_phase_id
@pytest.mark.asyncio
async def test_handle_message_transition_phase_desync():
manager = BDIProgramManager(name="program_manager_test")
manager.logger = MagicMock()
prog = Program.model_validate_json(make_valid_program_json())
manager._initialize_internal_state(prog)
current_phase_id = str(prog.phases[0].id)
# Request transition from WRONG old phase
payload = json.dumps({"old": "wrong_id", "new": "some_new_id"})
msg = InternalMessage(to="me", sender="bdi", body=payload, thread="transition_phase")
await manager.handle_message(msg)
# Should warn and do nothing
manager.logger.warning.assert_called_once()
assert "Phase transition desync detected" in manager.logger.warning.call_args[0][0]
assert str(manager._phase.id) == current_phase_id
@pytest.mark.asyncio
async def test_handle_message_transition_phase_end(mock_settings):
mock_settings.agent_settings.user_interrupt_name = "user_interrupt_agent"
manager = BDIProgramManager(name="program_manager_test")
manager.send = AsyncMock()
prog = Program.model_validate_json(make_valid_program_json())
manager._initialize_internal_state(prog)
current_phase_id = str(prog.phases[0].id)
payload = json.dumps({"old": current_phase_id, "new": "end"})
msg = InternalMessage(to="me", sender="bdi", body=payload, thread="transition_phase")
await manager.handle_message(msg)
assert manager._phase is None
# Allow background tasks to run (add_behavior)
await asyncio.sleep(0)
# Verify notification to user interrupt
assert manager.send.await_count == 1
msg_sent = manager.send.await_args[0][0]
assert msg_sent.to == mock_settings.agent_settings.user_interrupt_name
assert msg_sent.body == "end"
@pytest.mark.asyncio
async def test_handle_message_achieve_goal(mock_settings):
mock_settings.agent_settings.text_belief_extractor_name = "text_belief_extractor_agent"
manager = BDIProgramManager(name="program_manager_test")
manager.send = AsyncMock()
prog = Program.model_validate_json(make_valid_program_json(goal="TargetGoal"))
manager._initialize_internal_state(prog)
goal_id = str(prog.phases[0].goals[0].id)
msg = InternalMessage(to="me", sender="ui", body=goal_id, thread="achieve_goal")
await manager.handle_message(msg)
# Should send achieved goals to text extractor
assert manager.send.await_count == 1
msg_sent = manager.send.await_args[0][0]
assert msg_sent.to == mock_settings.agent_settings.text_belief_extractor_name
assert msg_sent.thread == "achieved_goals"
# Verify body
from control_backend.schemas.belief_list import GoalList
gl = GoalList.model_validate_json(msg_sent.body)
assert len(gl.goals) == 1
assert gl.goals[0].name == "TargetGoal"
@pytest.mark.asyncio
async def test_handle_message_achieve_goal_not_found():
manager = BDIProgramManager(name="program_manager_test")
manager.send = AsyncMock()
manager.logger = MagicMock()
prog = Program.model_validate_json(make_valid_program_json())
manager._initialize_internal_state(prog)
msg = InternalMessage(to="me", sender="ui", body="non_existent_id", thread="achieve_goal")
await manager.handle_message(msg)
manager.send.assert_not_called()
manager.logger.debug.assert_called()
@pytest.mark.asyncio
async def test_setup(mock_settings):
manager = BDIProgramManager(name="program_manager_test")
manager.send = AsyncMock()
def close_coro(coro):
coro.close()
return MagicMock()
manager.add_behavior = MagicMock(side_effect=close_coro)
mock_context = MagicMock()
mock_sub = MagicMock()
mock_context.socket.return_value = mock_sub
with patch(
"control_backend.agents.bdi.bdi_program_manager.Context.instance", return_value=mock_context
):
# We also need to mock file writing in _create_agentspeak_and_send_to_bdi
with patch("builtins.open", new_callable=MagicMock):
await manager.setup()
# Check logic
# 1. Sends default empty program to BDI
assert manager.send.await_count == 1
assert manager.send.await_args[0][0].to == mock_settings.agent_settings.bdi_core_name
# 2. Connects SUB socket
mock_sub.connect.assert_called_with(mock_settings.zmq_settings.internal_sub_address)
mock_sub.subscribe.assert_called_with("program")
# 3. Adds behavior
manager.add_behavior.assert_called()

View File

@@ -1,89 +0,0 @@
import json
from unittest.mock import AsyncMock
import pytest
from control_backend.agents.bdi import (
BDIBeliefCollectorAgent,
)
from control_backend.core.agent_system import InternalMessage
from control_backend.core.config import settings
from control_backend.schemas.belief_message import Belief
@pytest.fixture
def agent():
agent = BDIBeliefCollectorAgent("belief_collector_agent")
return agent
def make_msg(body: dict, sender: str = "sender"):
return InternalMessage(to="collector", sender=sender, body=json.dumps(body))
@pytest.mark.asyncio
async def test_handle_message_routes_belief_text(agent, mocker):
"""
Test that when a message is received, _handle_belief_text is called with that message.
"""
payload = {"type": "belief_extraction_text", "beliefs": {"user_said": [["hi"]]}}
spy = mocker.patch.object(agent, "_handle_belief_text", new_callable=AsyncMock)
await agent.handle_message(make_msg(payload))
spy.assert_awaited_once_with(payload, "sender")
@pytest.mark.asyncio
async def test_handle_message_routes_emotion(agent, mocker):
payload = {"type": "emotion_extraction_text"}
spy = mocker.patch.object(agent, "_handle_emo_text", new_callable=AsyncMock)
await agent.handle_message(make_msg(payload))
spy.assert_awaited_once_with(payload, "sender")
@pytest.mark.asyncio
async def test_handle_message_bad_json(agent, mocker):
agent._handle_belief_text = AsyncMock()
bad_msg = InternalMessage(to="collector", sender="sender", body="not json")
await agent.handle_message(bad_msg)
agent._handle_belief_text.assert_not_awaited()
@pytest.mark.asyncio
async def test_handle_belief_text_sends_when_beliefs_exist(agent, mocker):
payload = {"type": "belief_extraction_text", "beliefs": {"user_said": ["hello"]}}
spy = mocker.patch.object(agent, "_send_beliefs_to_bdi", new_callable=AsyncMock)
expected = [Belief(name="user_said", arguments=["hello"])]
await agent._handle_belief_text(payload, "origin")
spy.assert_awaited_once_with(expected, origin="origin")
@pytest.mark.asyncio
async def test_handle_belief_text_no_send_when_empty(agent, mocker):
payload = {"type": "belief_extraction_text", "beliefs": {}}
spy = mocker.patch.object(agent, "_send_beliefs_to_bdi", new_callable=AsyncMock)
await agent._handle_belief_text(payload, "origin")
spy.assert_not_awaited()
@pytest.mark.asyncio
async def test_send_beliefs_to_bdi(agent):
agent.send = AsyncMock()
beliefs = [Belief(name="user_said", arguments=["hello", "world"])]
await agent._send_beliefs_to_bdi(beliefs, origin="origin")
agent.send.assert_awaited_once()
sent: InternalMessage = agent.send.call_args.args[0]
assert sent.to == settings.agent_settings.bdi_core_name
assert sent.thread == "beliefs"
assert json.loads(sent.body)["beliefs"] == [belief.model_dump() for belief in beliefs]

View File

@@ -0,0 +1,554 @@
import json
import uuid
from unittest.mock import AsyncMock, MagicMock, patch
import httpx
import pytest
from control_backend.agents.bdi import TextBeliefExtractorAgent
from control_backend.agents.bdi.text_belief_extractor_agent import BeliefState
from control_backend.core.agent_system import InternalMessage
from control_backend.core.config import settings
from control_backend.schemas.belief_list import BeliefList
from control_backend.schemas.belief_message import Belief as InternalBelief
from control_backend.schemas.belief_message import BeliefMessage
from control_backend.schemas.chat_history import ChatHistory, ChatMessage
from control_backend.schemas.program import (
BaseGoal, # Changed from Goal
ConditionalNorm,
KeywordBelief,
LLMAction,
Phase,
Plan,
Program,
SemanticBelief,
Trigger,
)
@pytest.fixture
def llm():
llm = TextBeliefExtractorAgent.LLM(MagicMock(), 4)
# We must ensure _query_llm returns a dictionary so iterating it doesn't fail
llm._query_llm = AsyncMock(return_value={})
return llm
@pytest.fixture
def agent(llm):
with patch(
"control_backend.agents.bdi.text_belief_extractor_agent.TextBeliefExtractorAgent.LLM",
return_value=llm,
):
agent = TextBeliefExtractorAgent("text_belief_agent")
agent.send = AsyncMock()
return agent
@pytest.fixture
def sample_program():
return Program(
phases=[
Phase(
name="Some phase",
id=uuid.uuid4(),
norms=[
ConditionalNorm(
name="Some norm",
id=uuid.uuid4(),
norm="Use nautical terms.",
critical=False,
condition=SemanticBelief(
name="is_pirate",
id=uuid.uuid4(),
description="The user is a pirate. Perhaps because they say "
"they are, or because they speak like a pirate "
'with terms like "arr".',
),
),
],
goals=[],
triggers=[
Trigger(
name="Some trigger",
id=uuid.uuid4(),
condition=SemanticBelief(
name="no_more_booze",
id=uuid.uuid4(),
description="There is no more alcohol.",
),
plan=Plan(
name="Some plan",
id=uuid.uuid4(),
steps=[
LLMAction(
name="Some action",
id=uuid.uuid4(),
goal="Suggest eating chocolate instead.",
),
],
),
),
],
),
],
)
def make_msg(sender: str, body: str, thread: str | None = None) -> InternalMessage:
return InternalMessage(to="unused", sender=sender, body=body, thread=thread)
@pytest.mark.asyncio
async def test_handle_message_ignores_other_agents(agent):
msg = make_msg("unknown", "some data", None)
await agent.handle_message(msg)
agent.send.assert_not_called() # noqa # `agent.send` has no such property, but we mock it.
@pytest.mark.asyncio
async def test_handle_message_from_transcriber(agent, mock_settings):
transcription = "hello world"
msg = make_msg(mock_settings.agent_settings.transcription_name, transcription, None)
await agent.handle_message(msg)
agent.send.assert_awaited_once() # noqa # `agent.send` has no such property, but we mock it.
sent: InternalMessage = agent.send.call_args.args[0] # noqa
assert sent.to == mock_settings.agent_settings.bdi_core_name
assert sent.thread == "beliefs"
parsed = BeliefMessage.model_validate_json(sent.body)
replaced_last = parsed.replace.pop()
assert replaced_last.name == "user_said"
assert replaced_last.arguments == [transcription]
@pytest.mark.asyncio
async def test_query_llm():
mock_response = MagicMock()
mock_response.json.return_value = {
"choices": [
{
"message": {
"content": "null",
}
}
]
}
mock_client = AsyncMock()
mock_client.post.return_value = mock_response
mock_async_client = MagicMock()
mock_async_client.__aenter__.return_value = mock_client
mock_async_client.__aexit__.return_value = None
with patch(
"control_backend.agents.bdi.text_belief_extractor_agent.httpx.AsyncClient",
return_value=mock_async_client,
):
llm = TextBeliefExtractorAgent.LLM(MagicMock(), 4)
res = await llm._query_llm("hello world", {"type": "null"})
# Response content was set as "null", so should be deserialized as None
assert res is None
@pytest.mark.asyncio
async def test_retry_query_llm_success(llm):
llm._query_llm.return_value = None
res = await llm.query("hello world", {"type": "null"})
llm._query_llm.assert_called_once()
assert res is None
@pytest.mark.asyncio
async def test_retry_query_llm_success_after_failure(llm):
llm._query_llm.side_effect = [KeyError(), "real value"]
res = await llm.query("hello world", {"type": "string"})
assert llm._query_llm.call_count == 2
assert res == "real value"
@pytest.mark.asyncio
async def test_retry_query_llm_failures(llm):
llm._query_llm.side_effect = [KeyError(), KeyError(), KeyError(), "real value"]
res = await llm.query("hello world", {"type": "string"})
assert llm._query_llm.call_count == 3
assert res is None
@pytest.mark.asyncio
async def test_retry_query_llm_fail_immediately(llm):
llm._query_llm.side_effect = [KeyError(), "real value"]
res = await llm.query("hello world", {"type": "string"}, tries=1)
assert llm._query_llm.call_count == 1
assert res is None
@pytest.mark.asyncio
async def test_extracting_semantic_beliefs(agent):
"""
The Program Manager sends beliefs to this agent. Test whether the agent handles them correctly.
"""
assert len(agent.belief_inferrer.available_beliefs) == 0
beliefs = BeliefList(
beliefs=[
KeywordBelief(
id=uuid.uuid4(),
name="keyword_hello",
keyword="hello",
),
SemanticBelief(
id=uuid.uuid4(), name="semantic_hello_1", description="Some semantic belief 1"
),
SemanticBelief(
id=uuid.uuid4(), name="semantic_hello_2", description="Some semantic belief 2"
),
]
)
await agent.handle_message(
InternalMessage(
to=settings.agent_settings.text_belief_extractor_name,
sender=settings.agent_settings.bdi_program_manager_name,
body=beliefs.model_dump_json(),
thread="beliefs",
),
)
assert len(agent.belief_inferrer.available_beliefs) == 2
@pytest.mark.asyncio
async def test_handle_invalid_beliefs(agent, sample_program):
agent.belief_inferrer.available_beliefs.append(sample_program.phases[0].norms[0].condition)
agent.belief_inferrer.available_beliefs.append(sample_program.phases[0].triggers[0].condition)
assert len(agent.belief_inferrer.available_beliefs) == 2
await agent.handle_message(
InternalMessage(
to=settings.agent_settings.text_belief_extractor_name,
sender=settings.agent_settings.bdi_program_manager_name,
body=json.dumps({"phases": "Invalid"}),
thread="beliefs",
),
)
assert len(agent.belief_inferrer.available_beliefs) == 2
@pytest.mark.asyncio
async def test_handle_robot_response(agent):
initial_length = len(agent.conversation.messages)
response = "Hi, I'm Pepper. What's your name?"
await agent.handle_message(
InternalMessage(
to=settings.agent_settings.text_belief_extractor_name,
sender=settings.agent_settings.llm_name,
body=response,
),
)
assert len(agent.conversation.messages) == initial_length + 1
assert agent.conversation.messages[-1].role == "assistant"
assert agent.conversation.messages[-1].content == response
@pytest.mark.asyncio
async def test_simulated_real_turn_with_beliefs(agent, llm, sample_program):
"""Test sending user message to extract beliefs from."""
agent.belief_inferrer.available_beliefs.append(sample_program.phases[0].norms[0].condition)
agent.belief_inferrer.available_beliefs.append(sample_program.phases[0].triggers[0].condition)
# Send a user message with the belief that there's no more booze
llm._query_llm.return_value = {"is_pirate": None, "no_more_booze": True}
assert len(agent.conversation.messages) == 0
await agent.handle_message(
InternalMessage(
to=settings.agent_settings.text_belief_extractor_name,
sender=settings.agent_settings.transcription_name,
body="We're all out of schnaps.",
),
)
assert len(agent.conversation.messages) == 1
# There should be a belief set and sent to the BDI core, as well as the user_said belief
assert agent.send.call_count == 2
# First should be the beliefs message
message: InternalMessage = agent.send.call_args_list[1].args[0]
beliefs = BeliefMessage.model_validate_json(message.body)
assert len(beliefs.create) == 1
assert beliefs.create[0].name == "no_more_booze"
@pytest.mark.asyncio
async def test_simulated_real_turn_no_beliefs(agent, llm, sample_program):
"""Test a user message to extract beliefs from, but no beliefs are formed."""
agent.belief_inferrer.available_beliefs.append(sample_program.phases[0].norms[0].condition)
agent.belief_inferrer.available_beliefs.append(sample_program.phases[0].triggers[0].condition)
# Send a user message with no new beliefs
llm._query_llm.return_value = {"is_pirate": None, "no_more_booze": None}
await agent.handle_message(
InternalMessage(
to=settings.agent_settings.text_belief_extractor_name,
sender=settings.agent_settings.transcription_name,
body="Hello there!",
),
)
# Only the user_said belief should've been sent
agent.send.assert_called_once()
@pytest.mark.asyncio
async def test_simulated_real_turn_no_new_beliefs(agent, llm, sample_program):
"""
Test a user message to extract beliefs from, but no new beliefs are formed because they already
existed.
"""
agent.belief_inferrer.available_beliefs.append(sample_program.phases[0].norms[0].condition)
agent.belief_inferrer.available_beliefs.append(sample_program.phases[0].triggers[0].condition)
agent._current_beliefs = BeliefState(true={InternalBelief(name="is_pirate", arguments=None)})
# Send a user message with the belief the user is a pirate, still
llm._query_llm.return_value = {"is_pirate": True, "no_more_booze": None}
await agent.handle_message(
InternalMessage(
to=settings.agent_settings.text_belief_extractor_name,
sender=settings.agent_settings.transcription_name,
body="Arr, nice to meet you, matey.",
),
)
# Only the user_said belief should've been sent, as no beliefs have changed
agent.send.assert_called_once()
@pytest.mark.asyncio
async def test_simulated_real_turn_remove_belief(agent, llm, sample_program):
"""
Test a user message to extract beliefs from, but an existing belief is determined no longer to
hold.
"""
agent.belief_inferrer.available_beliefs.append(sample_program.phases[0].norms[0].condition)
agent.belief_inferrer.available_beliefs.append(sample_program.phases[0].triggers[0].condition)
agent._current_beliefs = BeliefState(
true={InternalBelief(name="no_more_booze", arguments=None)},
)
# Send a user message with the belief the user is a pirate, still
llm._query_llm.return_value = {"is_pirate": None, "no_more_booze": False}
await agent.handle_message(
InternalMessage(
to=settings.agent_settings.text_belief_extractor_name,
sender=settings.agent_settings.transcription_name,
body="I found an untouched barrel of wine!",
),
)
# Both user_said and belief change should've been sent
assert agent.send.call_count == 2
# Agent's current beliefs should've changed
assert any(b.name == "no_more_booze" for b in agent._current_beliefs.false)
@pytest.mark.asyncio
async def test_infer_goal_completions_sends_beliefs(agent, llm):
"""Test that inferred goal completions are sent to the BDI core."""
goal = BaseGoal(
id=uuid.uuid4(), name="Say Hello", description="The user said hello", can_fail=True
)
agent.goal_inferrer.goals = {goal}
# Mock goal inference: goal is achieved
llm.query = AsyncMock(return_value=True)
await agent._infer_goal_completions()
# Should send belief change to BDI core
agent.send.assert_awaited_once()
sent: InternalMessage = agent.send.call_args.args[0]
assert sent.to == settings.agent_settings.bdi_core_name
assert sent.thread == "beliefs"
parsed = BeliefMessage.model_validate_json(sent.body)
assert len(parsed.create) == 1
assert parsed.create[0].name == "achieved_say_hello"
@pytest.mark.asyncio
async def test_llm_failure_handling(agent, llm, sample_program):
"""
Check that the agent handles failures gracefully without crashing.
"""
llm._query_llm.side_effect = httpx.HTTPError("")
agent.belief_inferrer.available_beliefs.append(sample_program.phases[0].norms[0].condition)
agent.belief_inferrer.available_beliefs.append(sample_program.phases[0].triggers[0].condition)
belief_changes = await agent.belief_inferrer.infer_from_conversation(
ChatHistory(
messages=[ChatMessage(role="user", content="Good day!")],
),
)
assert len(belief_changes.true) == 0
assert len(belief_changes.false) == 0
def test_belief_state_bool():
# Empty
bs = BeliefState()
assert not bs
# True set
bs_true = BeliefState(true={InternalBelief(name="a", arguments=None)})
assert bs_true
# False set
bs_false = BeliefState(false={InternalBelief(name="a", arguments=None)})
assert bs_false
@pytest.mark.asyncio
async def test_handle_beliefs_message_validation_error(agent, mock_settings):
# Invalid JSON
mock_settings.agent_settings.bdi_program_manager_name = "bdi_program_manager_agent"
msg = InternalMessage(
to="me",
sender=mock_settings.agent_settings.bdi_program_manager_name,
thread="beliefs",
body="invalid json",
)
# Should log warning and return
agent.logger = MagicMock()
await agent.handle_message(msg)
agent.logger.warning.assert_called()
# Invalid Model
msg.body = json.dumps({"beliefs": [{"invalid": "obj"}]})
await agent.handle_message(msg)
agent.logger.warning.assert_called()
@pytest.mark.asyncio
async def test_handle_goals_message_validation_error(agent, mock_settings):
mock_settings.agent_settings.bdi_program_manager_name = "bdi_program_manager_agent"
msg = InternalMessage(
to="me",
sender=mock_settings.agent_settings.bdi_program_manager_name,
thread="goals",
body="invalid json",
)
agent.logger = MagicMock()
await agent.handle_message(msg)
agent.logger.warning.assert_called()
@pytest.mark.asyncio
async def test_handle_goal_achieved_message_validation_error(agent, mock_settings):
mock_settings.agent_settings.bdi_program_manager_name = "bdi_program_manager_agent"
msg = InternalMessage(
to="me",
sender=mock_settings.agent_settings.bdi_program_manager_name,
thread="achieved_goals",
body="invalid json",
)
agent.logger = MagicMock()
await agent.handle_message(msg)
agent.logger.warning.assert_called()
@pytest.mark.asyncio
async def test_goal_inferrer_infer_from_conversation(agent, llm):
# Setup goals
# Use BaseGoal object as typically received by the extractor
g1 = BaseGoal(id=uuid.uuid4(), name="g1", description="desc", can_fail=True)
# Use real GoalAchievementInferrer
from control_backend.agents.bdi.text_belief_extractor_agent import GoalAchievementInferrer
inferrer = GoalAchievementInferrer(llm)
inferrer.goals = {g1}
# Mock LLM response
llm._query_llm.return_value = True
completions = await inferrer.infer_from_conversation(ChatHistory(messages=[]))
assert completions
# slugify uses slugify library, hard to predict exact string without it,
# but we can check values
assert list(completions.values())[0] is True
def test_apply_conversation_message_limit(agent):
with patch("control_backend.agents.bdi.text_belief_extractor_agent.settings") as mock_s:
mock_s.behaviour_settings.conversation_history_length_limit = 2
agent.conversation.messages = []
agent._apply_conversation_message(ChatMessage(role="user", content="1"))
agent._apply_conversation_message(ChatMessage(role="assistant", content="2"))
agent._apply_conversation_message(ChatMessage(role="user", content="3"))
assert len(agent.conversation.messages) == 2
assert agent.conversation.messages[0].content == "2"
assert agent.conversation.messages[1].content == "3"
@pytest.mark.asyncio
async def test_handle_program_manager_reset(agent):
with patch("control_backend.agents.bdi.text_belief_extractor_agent.settings") as mock_s:
mock_s.agent_settings.bdi_program_manager_name = "pm"
agent.conversation.messages = [ChatMessage(role="user", content="hi")]
agent.belief_inferrer.available_beliefs = [
SemanticBelief(id=uuid.uuid4(), name="b", description="d")
]
msg = InternalMessage(to="me", sender="pm", thread="conversation_history", body="reset")
await agent.handle_message(msg)
assert len(agent.conversation.messages) == 0
assert len(agent.belief_inferrer.available_beliefs) == 0
def test_split_into_chunks():
from control_backend.agents.bdi.text_belief_extractor_agent import SemanticBeliefInferrer
items = [1, 2, 3, 4, 5]
chunks = SemanticBeliefInferrer._split_into_chunks(items, 2)
assert len(chunks) == 2
assert len(chunks[0]) + len(chunks[1]) == 5
@pytest.mark.asyncio
async def test_infer_beliefs_call(agent, llm):
from control_backend.agents.bdi.text_belief_extractor_agent import SemanticBeliefInferrer
inferrer = SemanticBeliefInferrer(llm)
sb = SemanticBelief(id=uuid.uuid4(), name="is_happy", description="User is happy")
llm.query = AsyncMock(return_value={"is_happy": True})
res = await inferrer._infer_beliefs(ChatHistory(messages=[]), [sb])
assert res == {"is_happy": True}
llm.query.assert_called_once()
@pytest.mark.asyncio
async def test_infer_goal_call(agent, llm):
from control_backend.agents.bdi.text_belief_extractor_agent import GoalAchievementInferrer
inferrer = GoalAchievementInferrer(llm)
goal = BaseGoal(id=uuid.uuid4(), name="g1", description="d")
llm.query = AsyncMock(return_value=True)
res = await inferrer._infer_goal(ChatHistory(messages=[]), goal)
assert res is True
llm.query.assert_called_once()

View File

@@ -1,58 +0,0 @@
import json
from unittest.mock import AsyncMock
import pytest
from control_backend.agents.bdi import (
TextBeliefExtractorAgent,
)
from control_backend.core.agent_system import InternalMessage
@pytest.fixture
def agent():
agent = TextBeliefExtractorAgent("text_belief_agent")
agent.send = AsyncMock()
return agent
def make_msg(sender: str, body: str, thread: str | None = None) -> InternalMessage:
return InternalMessage(to="unused", sender=sender, body=body, thread=thread)
@pytest.mark.asyncio
async def test_handle_message_ignores_other_agents(agent):
msg = make_msg("unknown", "some data", None)
await agent.handle_message(msg)
agent.send.assert_not_called() # noqa # `agent.send` has no such property, but we mock it.
@pytest.mark.asyncio
async def test_handle_message_from_transcriber(agent, mock_settings):
transcription = "hello world"
msg = make_msg(mock_settings.agent_settings.transcription_name, transcription, None)
await agent.handle_message(msg)
agent.send.assert_awaited_once() # noqa # `agent.send` has no such property, but we mock it.
sent: InternalMessage = agent.send.call_args.args[0] # noqa
assert sent.to == mock_settings.agent_settings.bdi_belief_collector_name
assert sent.thread == "beliefs"
parsed = json.loads(sent.body)
assert parsed == {"beliefs": {"user_said": [transcription]}, "type": "belief_extraction_text"}
@pytest.mark.asyncio
async def test_process_transcription_demo(agent, mock_settings):
transcription = "this is a test"
await agent._process_transcription_demo(transcription)
agent.send.assert_awaited_once() # noqa # `agent.send` has no such property, but we mock it.
sent: InternalMessage = agent.send.call_args.args[0] # noqa
assert sent.to == mock_settings.agent_settings.bdi_belief_collector_name
assert sent.thread == "beliefs"
parsed = json.loads(sent.body)
assert parsed["beliefs"]["user_said"] == [transcription]

View File

@@ -4,6 +4,8 @@ from unittest.mock import ANY, AsyncMock, MagicMock, patch
import pytest
from control_backend.agents.communication.ri_communication_agent import RICommunicationAgent
from control_backend.core.agent_system import InternalMessage
from control_backend.schemas.ri_message import PauseCommand, RIEndpoint
def speech_agent_path():
@@ -53,7 +55,11 @@ async def test_setup_success_connects_and_starts_robot(zmq_context):
MockGesture.return_value.start = AsyncMock()
agent = RICommunicationAgent("ri_comm", address="tcp://localhost:5555", bind=False)
agent.add_behavior = MagicMock()
def close_coro(coro):
coro.close()
return MagicMock()
agent.add_behavior = MagicMock(side_effect=close_coro)
await agent.setup()
@@ -67,6 +73,7 @@ async def test_setup_success_connects_and_starts_robot(zmq_context):
address="tcp://localhost:5556",
bind=False,
gesture_data=[],
single_gesture_data=[],
)
agent.add_behavior.assert_called_once()
@@ -82,7 +89,11 @@ async def test_setup_binds_when_requested(zmq_context):
agent = RICommunicationAgent("ri_comm", address="tcp://localhost:5555", bind=True)
agent.add_behavior = MagicMock()
def close_coro(coro):
coro.close()
return MagicMock()
agent.add_behavior = MagicMock(side_effect=close_coro)
with (
patch(speech_agent_path(), autospec=True) as MockSpeech,
@@ -150,6 +161,7 @@ async def test_handle_negotiation_response_updates_req_socket(zmq_context):
@pytest.mark.asyncio
async def test_handle_disconnection_publishes_and_reconnects():
pub_socket = AsyncMock()
pub_socket.close = MagicMock()
agent = RICommunicationAgent("ri_comm")
agent.pub_socket = pub_socket
agent.connected = True
@@ -232,6 +244,25 @@ async def test_handle_negotiation_response_unhandled_id():
)
@pytest.mark.asyncio
async def test_handle_negotiation_response_audio(zmq_context):
agent = RICommunicationAgent("ri_comm")
with patch(
"control_backend.agents.communication.ri_communication_agent.VADAgent", autospec=True
) as MockVAD:
MockVAD.return_value.start = AsyncMock()
await agent._handle_negotiation_response(
{"data": [{"id": "audio", "port": 7000, "bind": False}]}
)
MockVAD.assert_called_once_with(
audio_in_address="tcp://localhost:7000", audio_in_bind=False
)
MockVAD.return_value.start.assert_awaited_once()
@pytest.mark.asyncio
async def test_stop_closes_sockets():
req = MagicMock()
@@ -322,6 +353,7 @@ async def test_listen_loop_generic_exception():
@pytest.mark.asyncio
async def test_handle_disconnection_timeout(monkeypatch):
pub = AsyncMock()
pub.close = MagicMock()
pub.send_multipart = AsyncMock(side_effect=TimeoutError)
agent = RICommunicationAgent("ri_comm")
@@ -354,3 +386,48 @@ async def test_listen_loop_ping_sends_internal(zmq_context):
await agent._listen_loop()
pub_socket.send_multipart.assert_awaited()
@pytest.mark.asyncio
async def test_negotiate_req_socket_none_causes_retry(zmq_context):
agent = RICommunicationAgent("ri_comm")
agent._req_socket = None
result = await agent._negotiate_connection(max_retries=1)
assert result is False
@pytest.mark.asyncio
async def test_handle_message_pause_command(zmq_context):
"""Test handle_message with a valid PauseCommand."""
agent = RICommunicationAgent("ri_comm")
agent._req_socket = AsyncMock()
agent.logger = MagicMock()
agent._req_socket.recv_json.return_value = {"status": "ok"}
pause_cmd = PauseCommand(data=True)
msg = InternalMessage(to="ri_comm", sender="user_int", body=pause_cmd.model_dump_json())
await agent.handle_message(msg)
agent._req_socket.send_json.assert_awaited_once()
args = agent._req_socket.send_json.await_args[0][0]
assert args["endpoint"] == RIEndpoint.PAUSE.value
assert args["data"] is True
@pytest.mark.asyncio
async def test_handle_message_invalid_pause_command(zmq_context):
"""Test handle_message with invalid JSON."""
agent = RICommunicationAgent("ri_comm")
agent._req_socket = AsyncMock()
agent.logger = MagicMock()
msg = InternalMessage(to="ri_comm", sender="user_int", body="invalid json")
await agent.handle_message(msg)
agent.logger.warning.assert_called_with("Incorrect message format for PauseCommand.")
agent._req_socket.send_json.assert_not_called()

View File

@@ -49,23 +49,29 @@ async def test_llm_processing_success(mock_httpx_client, mock_settings):
agent = LLMAgent("llm_agent")
agent.send = AsyncMock() # Mock the send method to verify replies
mock_logger = MagicMock()
agent.logger = mock_logger
# Simulate receiving a message from BDI
prompt = LLMPromptMessage(text="Hi", norms=[], goals=[])
msg = InternalMessage(
to="llm_agent",
sender=mock_settings.agent_settings.bdi_core_name,
body=prompt.model_dump_json(),
thread="prompt_message", # REQUIRED: thread must match handle_message logic
)
await agent.handle_message(msg)
# Verification
# "Hello world." constitutes one sentence/chunk based on punctuation split
# The agent should call send once with the full sentence
# The agent should call send once with the full sentence, PLUS once more for full reply
assert agent.send.called
args = agent.send.call_args[0][0]
assert args.to == mock_settings.agent_settings.bdi_core_name
assert "Hello world." in args.body
# Check args. We expect at least one call sending "Hello world."
calls = agent.send.call_args_list
bodies = [c[0][0].body for c in calls]
assert any("Hello world." in b for b in bodies)
@pytest.mark.asyncio
@@ -77,18 +83,23 @@ async def test_llm_processing_errors(mock_httpx_client, mock_settings):
to="llm",
sender=mock_settings.agent_settings.bdi_core_name,
body=prompt.model_dump_json(),
thread="prompt_message",
)
# HTTP Error
# HTTP Error: stream method RAISES exception immediately
mock_httpx_client.stream = MagicMock(side_effect=httpx.HTTPError("Fail"))
await agent.handle_message(msg)
assert "LLM service unavailable." in agent.send.call_args[0][0].body
# Check that error message was sent
assert agent.send.called
assert "LLM service unavailable." in agent.send.call_args_list[0][0][0].body
# General Exception
agent.send.reset_mock()
mock_httpx_client.stream = MagicMock(side_effect=Exception("Boom"))
await agent.handle_message(msg)
assert "Error processing the request." in agent.send.call_args[0][0].body
assert "Error processing the request." in agent.send.call_args_list[0][0][0].body
@pytest.mark.asyncio
@@ -110,16 +121,19 @@ async def test_llm_json_error(mock_httpx_client, mock_settings):
agent = LLMAgent("llm_agent")
agent.send = AsyncMock()
# Ensure logger is mocked
agent.logger = MagicMock()
with patch.object(agent.logger, "error") as log:
prompt = LLMPromptMessage(text="Hi", norms=[], goals=[])
msg = InternalMessage(
to="llm",
sender=mock_settings.agent_settings.bdi_core_name,
body=prompt.model_dump_json(),
)
await agent.handle_message(msg)
log.assert_called() # Should log JSONDecodeError
prompt = LLMPromptMessage(text="Hi", norms=[], goals=[])
msg = InternalMessage(
to="llm",
sender=mock_settings.agent_settings.bdi_core_name,
body=prompt.model_dump_json(),
thread="prompt_message",
)
await agent.handle_message(msg)
agent.logger.error.assert_called() # Should log JSONDecodeError
def test_llm_instructions():
@@ -134,3 +148,177 @@ def test_llm_instructions():
text_def = instr_def.build_developer_instruction()
assert "Norms to follow" in text_def
assert "Goals to reach" in text_def
@pytest.mark.asyncio
async def test_handle_message_validation_error_branch_no_send(mock_httpx_client, mock_settings):
"""
Covers the ValidationError branch:
except ValidationError:
self.logger.debug("Prompt message from BDI core is invalid.")
Assert: no message is sent.
"""
agent = LLMAgent("llm_agent")
agent.send = AsyncMock()
# Invalid JSON that triggers ValidationError in LLMPromptMessage
invalid_json = '{"text": "Hi", "wrong_field": 123}' # field not in schema
msg = InternalMessage(
to="llm_agent",
sender=mock_settings.agent_settings.bdi_core_name,
body=invalid_json,
thread="prompt_message",
)
await agent.handle_message(msg)
# Should not send any reply
agent.send.assert_not_called()
@pytest.mark.asyncio
async def test_handle_message_ignored_sender_branch_no_send(mock_httpx_client, mock_settings):
"""
Covers the else branch for messages not from BDI core:
else:
self.logger.debug("Message ignored (not from BDI core.")
Assert: no message is sent.
"""
agent = LLMAgent("llm_agent")
agent.send = AsyncMock()
msg = InternalMessage(
to="llm_agent",
sender="some_other_agent", # Not BDI core
body='{"text": "Hi"}',
)
await agent.handle_message(msg)
# Should not send any reply
agent.send.assert_not_called()
@pytest.mark.asyncio
async def test_query_llm_yields_final_tail_chunk(mock_settings):
"""
Covers the branch: if current_chunk: yield current_chunk
Ensure that the last partial chunk is emitted.
"""
agent = LLMAgent("llm_agent")
agent.send = AsyncMock()
agent.logger = MagicMock()
agent.logger.llm = MagicMock()
# Patch _stream_query_llm to yield tokens that do NOT end with punctuation
async def fake_stream(messages):
yield "Hello"
yield " world" # No punctuation to trigger the normal chunking
agent._stream_query_llm = fake_stream
prompt = LLMPromptMessage(text="Hi", norms=[], goals=[])
# Collect chunks yielded
chunks = []
async for chunk in agent._query_llm(prompt.text, prompt.norms, prompt.goals):
chunks.append(chunk)
# The final chunk should be yielded
assert chunks[-1] == "Hello world"
assert any("Hello" in c for c in chunks)
@pytest.mark.asyncio
async def test_stream_query_llm_skips_non_data_lines(mock_httpx_client, mock_settings):
"""
Covers: if not line or not line.startswith("data: "): continue
Feed lines that are empty or do not start with 'data:' and check they are skipped.
"""
# Mock response
mock_response = MagicMock()
mock_response.raise_for_status = MagicMock()
lines = [
"", # empty line
"not data", # invalid prefix
'data: {"choices": [{"delta": {"content": "Hi"}}]}',
"data: [DONE]",
]
async def aiter_lines_gen():
for line in lines:
yield line
mock_response.aiter_lines.side_effect = aiter_lines_gen
# Proper async context manager for stream
mock_stream_context = MagicMock()
mock_stream_context.__aenter__ = AsyncMock(return_value=mock_response)
mock_stream_context.__aexit__ = AsyncMock(return_value=None)
# Make stream return the async context manager
mock_httpx_client.stream = MagicMock(return_value=mock_stream_context)
agent = LLMAgent("llm_agent")
agent.send = AsyncMock()
# Patch settings for local LLM URL
with patch("control_backend.agents.llm.llm_agent.settings") as mock_sett:
mock_sett.llm_settings.local_llm_url = "http://localhost"
mock_sett.llm_settings.local_llm_model = "test-model"
# Collect tokens
tokens = []
async for token in agent._stream_query_llm([]):
tokens.append(token)
# Only the valid 'data:' line should yield content
assert tokens == ["Hi"]
@pytest.mark.asyncio
async def test_clear_history_command(mock_settings):
"""Test that the 'clear_history' message clears the agent's memory."""
# setup LLM to have some history
mock_settings.agent_settings.bdi_program_manager_name = "bdi_program_manager_agent"
agent = LLMAgent("llm_agent")
agent.history = [
{"role": "user", "content": "Old conversation context"},
{"role": "assistant", "content": "Old response"},
]
assert len(agent.history) == 2
msg = InternalMessage(
to="llm_agent",
sender=mock_settings.agent_settings.bdi_program_manager_name,
body="clear_history",
)
await agent.handle_message(msg)
assert len(agent.history) == 0
@pytest.mark.asyncio
async def test_handle_assistant_and_user_messages(mock_settings):
agent = LLMAgent("llm_agent")
# Assistant message
msg_ast = InternalMessage(
to="llm_agent",
sender=mock_settings.agent_settings.bdi_core_name,
thread="assistant_message",
body="I said this",
)
await agent.handle_message(msg_ast)
assert agent.history[-1] == {"role": "assistant", "content": "I said this"}
# User message
msg_usr = InternalMessage(
to="llm_agent",
sender=mock_settings.agent_settings.bdi_core_name,
thread="user_message",
body="User said this",
)
await agent.handle_message(msg_usr)
assert agent.history[-1] == {"role": "user", "content": "User said this"}

View File

@@ -55,4 +55,6 @@ def test_get_decode_options():
assert isinstance(options["sample_len"], int)
# When disabled, it should not limit output length based on input size
assert "sample_rate" not in options
recognizer = OpenAIWhisperSpeechRecognizer(limit_output_length=False)
options = recognizer._get_decode_options(audio)
assert "sample_len" not in options

View File

@@ -36,7 +36,12 @@ async def test_transcription_agent_flow(mock_zmq_context):
agent.send = AsyncMock()
agent._running = True
agent.add_behavior = AsyncMock()
def close_coro(coro):
coro.close()
return MagicMock()
agent.add_behavior = MagicMock(side_effect=close_coro)
await agent.setup()
@@ -120,3 +125,93 @@ def test_mlx_recognizer():
mlx_mock.transcribe.return_value = {"text": "Hi"}
res = rec.recognize_speech(np.zeros(10))
assert res == "Hi"
@pytest.mark.asyncio
async def test_transcription_loop_continues_after_error(mock_zmq_context):
mock_sub = MagicMock()
mock_sub.recv = AsyncMock()
mock_zmq_context.instance.return_value.socket.return_value = mock_sub
fake_audio = np.zeros(16000, dtype=np.float32).tobytes()
mock_sub.recv.side_effect = [
fake_audio, # first iteration → recognizer fails
asyncio.CancelledError(), # second iteration → stop loop
]
with patch.object(SpeechRecognizer, "best_type") as mock_best:
mock_recognizer = MagicMock()
mock_recognizer.recognize_speech.side_effect = RuntimeError("fail")
mock_best.return_value = mock_recognizer
agent = TranscriptionAgent("tcp://in")
agent._running = True # ← REQUIRED to enter the loop
agent.send = AsyncMock() # should never be called
def close_coro(coro):
coro.close()
return MagicMock()
agent.add_behavior = MagicMock(side_effect=close_coro) # match other tests
await agent.setup()
try:
await agent._transcribing_loop()
except asyncio.CancelledError:
pass
# recognizer failed, so we should never send anything
agent.send.assert_not_called()
# recv must have been called twice (audio then CancelledError)
assert mock_sub.recv.call_count == 2
@pytest.mark.asyncio
async def test_transcription_continue_branch_when_empty(mock_zmq_context):
mock_sub = MagicMock()
mock_sub.recv = AsyncMock()
mock_zmq_context.instance.return_value.socket.return_value = mock_sub
# First recv → audio chunk
# Second recv → Cancel loop → stop iteration
fake_audio = np.zeros(16000, dtype=np.float32).tobytes()
mock_sub.recv.side_effect = [fake_audio, asyncio.CancelledError()]
with patch.object(SpeechRecognizer, "best_type") as mock_best:
mock_recognizer = MagicMock()
mock_recognizer.recognize_speech.return_value = "" # <— triggers the continue branch
mock_best.return_value = mock_recognizer
agent = TranscriptionAgent("tcp://in")
# Make loop runnable
agent._running = True
agent.send = AsyncMock()
def close_coro(coro):
coro.close()
return MagicMock()
agent.add_behavior = MagicMock(side_effect=close_coro)
await agent.setup()
# Execute loop manually
try:
await agent._transcribing_loop()
except asyncio.CancelledError:
pass
# → Because of "continue", NO sending should occur
agent.send.assert_not_called()
# → Continue was hit, so we must have read exactly 2 times:
# - first audio
# - second CancelledError
assert mock_sub.recv.call_count == 2
# → recognizer was called once (first iteration)
assert mock_recognizer.recognize_speech.call_count == 1

View File

@@ -0,0 +1,152 @@
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from control_backend.agents.perception.vad_agent import VADAgent
from control_backend.core.agent_system import InternalMessage
from control_backend.schemas.program_status import PROGRAM_STATUS, ProgramStatus
@pytest.fixture(autouse=True)
def mock_zmq():
with patch("zmq.asyncio.Context") as mock:
mock.instance.return_value = MagicMock()
yield mock
@pytest.fixture
def agent():
return VADAgent("tcp://localhost:5555", False)
@pytest.mark.asyncio
async def test_handle_message_pause(agent):
agent._paused = MagicMock()
# It starts set (not paused)
msg = InternalMessage(to="vad", sender="user_interrupt_agent", body="PAUSE")
# We need to mock settings to match sender name
with patch("control_backend.agents.perception.vad_agent.settings") as mock_settings:
mock_settings.agent_settings.user_interrupt_name = "user_interrupt_agent"
await agent.handle_message(msg)
agent._paused.clear.assert_called_once()
assert agent._reset_needed is True
@pytest.mark.asyncio
async def test_handle_message_resume(agent):
agent._paused = MagicMock()
msg = InternalMessage(to="vad", sender="user_interrupt_agent", body="RESUME")
with patch("control_backend.agents.perception.vad_agent.settings") as mock_settings:
mock_settings.agent_settings.user_interrupt_name = "user_interrupt_agent"
await agent.handle_message(msg)
agent._paused.set.assert_called_once()
@pytest.mark.asyncio
async def test_handle_message_unknown_command(agent):
agent._paused = MagicMock()
msg = InternalMessage(to="vad", sender="user_interrupt_agent", body="UNKNOWN")
with patch("control_backend.agents.perception.vad_agent.settings") as mock_settings:
mock_settings.agent_settings.user_interrupt_name = "user_interrupt_agent"
agent.logger = MagicMock()
await agent.handle_message(msg)
agent._paused.clear.assert_not_called()
agent._paused.set.assert_not_called()
@pytest.mark.asyncio
async def test_handle_message_unknown_sender(agent):
agent._paused = MagicMock()
msg = InternalMessage(to="vad", sender="other_agent", body="PAUSE")
with patch("control_backend.agents.perception.vad_agent.settings") as mock_settings:
mock_settings.agent_settings.user_interrupt_name = "user_interrupt_agent"
await agent.handle_message(msg)
agent._paused.clear.assert_not_called()
@pytest.mark.asyncio
async def test_status_loop_waits_for_running(agent):
agent._running = True
agent.program_sub_socket = AsyncMock()
agent.program_sub_socket.close = MagicMock()
agent._reset_stream = AsyncMock()
# Sequence of messages:
# 1. Wrong topic
# 2. Right topic, wrong status (STARTING)
# 3. Right topic, RUNNING -> Should break loop
agent.program_sub_socket.recv_multipart.side_effect = [
(b"wrong_topic", b"whatever"),
(PROGRAM_STATUS, ProgramStatus.STARTING.value),
(PROGRAM_STATUS, ProgramStatus.RUNNING.value),
]
await agent._status_loop()
assert agent._reset_stream.await_count == 1
agent.program_sub_socket.close.assert_called_once()
@pytest.mark.asyncio
async def test_setup_success(agent, mock_zmq):
def close_coro(coro):
coro.close()
return MagicMock()
agent.add_behavior = MagicMock(side_effect=close_coro)
mock_context = mock_zmq.instance.return_value
mock_sub = MagicMock()
mock_pub = MagicMock()
# We expect multiple socket calls:
# 1. audio_in (SUB)
# 2. audio_out (PUB)
# 3. program_sub (SUB)
mock_context.socket.side_effect = [mock_sub, mock_pub, mock_sub]
with patch("control_backend.agents.perception.vad_agent.torch.hub.load") as mock_load:
mock_load.return_value = (MagicMock(), None)
with patch("control_backend.agents.perception.vad_agent.TranscriptionAgent") as MockTrans:
mock_trans_instance = MockTrans.return_value
mock_trans_instance.start = AsyncMock()
await agent.setup()
mock_trans_instance.start.assert_awaited_once()
assert agent.add_behavior.call_count == 2 # streaming_loop + status_loop
assert agent.audio_in_socket is not None
assert agent.audio_out_socket is not None
assert agent.program_sub_socket is not None
@pytest.mark.asyncio
async def test_reset_stream(agent):
mock_poller = MagicMock()
agent.audio_in_poller = mock_poller
# poll(1) returns not None twice, then None
mock_poller.poll = AsyncMock(side_effect=[b"data", b"data", None])
agent._ready = MagicMock()
await agent._reset_stream()
assert mock_poller.poll.await_count == 3
agent._ready.set.assert_called_once()

View File

@@ -1,9 +1,20 @@
from unittest.mock import AsyncMock, MagicMock
from unittest.mock import AsyncMock, MagicMock, patch
import numpy as np
import pytest
import zmq
from control_backend.agents.perception.vad_agent import VADAgent
from control_backend.core.config import settings
# We don't want to use real ZMQ in unit tests, for example because it can give errors when sockets
# aren't closed properly.
@pytest.fixture(autouse=True)
def mock_zmq():
with patch("zmq.asyncio.Context") as mock:
mock.instance.return_value = MagicMock()
yield mock
@pytest.fixture
@@ -123,3 +134,90 @@ async def test_no_data(audio_out_socket, vad_agent):
audio_out_socket.send.assert_not_called()
assert len(vad_agent.audio_buffer) == 0
@pytest.mark.asyncio
async def test_streaming_loop_reset_needed(audio_out_socket, vad_agent):
"""Test that _reset_needed branch works as expected."""
vad_agent._reset_needed = True
vad_agent._ready.set()
vad_agent._paused.set()
vad_agent._running = True
vad_agent.audio_buffer = np.array([1.0], dtype=np.float32)
vad_agent.i_since_speech = 0
# Mock _reset_stream to stop the loop by setting _running=False
async def mock_reset():
vad_agent._running = False
vad_agent._reset_stream = mock_reset
# Needs a poller to avoid AssertionError
vad_agent.audio_in_poller = AsyncMock()
vad_agent.audio_in_poller.poll.return_value = None
await vad_agent._streaming_loop()
assert vad_agent._reset_needed is False
assert len(vad_agent.audio_buffer) == 0
assert vad_agent.i_since_speech == settings.behaviour_settings.vad_initial_since_speech
@pytest.mark.asyncio
async def test_streaming_loop_no_data_clears_buffer(audio_out_socket, vad_agent):
"""Test that if poll returns None, buffer is cleared if not empty."""
vad_agent.audio_buffer = np.array([1.0], dtype=np.float32)
vad_agent._ready.set()
vad_agent._paused.set()
vad_agent._running = True
class MockPoller:
async def poll(self, timeout_ms=None):
vad_agent._running = False # stop after one poll
return None
vad_agent.audio_in_poller = MockPoller()
await vad_agent._streaming_loop()
assert len(vad_agent.audio_buffer) == 0
assert vad_agent.i_since_speech == settings.behaviour_settings.vad_initial_since_speech
@pytest.mark.asyncio
async def test_vad_model_load_failure_stops_agent(vad_agent):
"""
Test that if loading the VAD model raises an Exception, it is caught,
the agent logs an exception, stops itself, and setup returns.
"""
# Patch torch.hub.load to raise an exception
with patch(
"control_backend.agents.perception.vad_agent.torch.hub.load",
side_effect=Exception("model fail"),
):
# Patch stop to an AsyncMock so we can check it was awaited
vad_agent.stop = AsyncMock()
await vad_agent.setup()
# Assert stop was called
vad_agent.stop.assert_awaited_once()
@pytest.mark.asyncio
async def test_audio_out_bind_failure_sets_none_and_logs(vad_agent, caplog):
"""
Test that if binding the output socket raises ZMQBindError,
audio_out_socket is set to None, None is returned, and an error is logged.
"""
mock_socket = MagicMock()
mock_socket.bind.side_effect = zmq.ZMQBindError()
with patch("control_backend.agents.perception.vad_agent.azmq.Context.instance") as mock_ctx:
mock_ctx.return_value.socket.return_value = mock_socket
with caplog.at_level("ERROR"):
port = vad_agent._connect_audio_out_socket()
assert port is None
assert vad_agent.audio_out_socket is None
assert caplog.text is not None

View File

@@ -0,0 +1,24 @@
import logging
from control_backend.agents.base import BaseAgent
class MyAgent(BaseAgent):
async def setup(self):
pass
async def handle_message(self, msg):
pass
def test_base_agent_logger_init():
# When defining a subclass, __init_subclass__ runs
# The BaseAgent in agents/base.py sets the logger
assert hasattr(MyAgent, "logger")
assert isinstance(MyAgent.logger, logging.Logger)
# The logger name depends on the package.
# Since this test file is running as a module, __package__ might be None or the test package.
# In 'src/control_backend/agents/base.py', it uses __package__ of base.py which is
# 'control_backend.agents'.
# So logger name should be control_backend.agents.MyAgent
assert MyAgent.logger.name == "control_backend.agents.MyAgent"

View File

@@ -0,0 +1,311 @@
import asyncio
import json
from unittest.mock import AsyncMock, MagicMock
import pytest
from control_backend.agents.user_interrupt.user_interrupt_agent import UserInterruptAgent
from control_backend.core.agent_system import InternalMessage
from control_backend.core.config import settings
from control_backend.schemas.program import (
ConditionalNorm,
Goal,
KeywordBelief,
Phase,
Plan,
Program,
Trigger,
)
from control_backend.schemas.ri_message import RIEndpoint
@pytest.fixture
def agent():
agent = UserInterruptAgent(name="user_interrupt_agent")
agent.send = AsyncMock()
agent.logger = MagicMock()
agent.sub_socket = AsyncMock()
agent.pub_socket = AsyncMock()
return agent
@pytest.mark.asyncio
async def test_send_to_speech_agent(agent):
"""Verify speech command format."""
await agent._send_to_speech_agent("Hello World")
agent.send.assert_awaited_once()
sent_msg: InternalMessage = agent.send.call_args.args[0]
assert sent_msg.to == settings.agent_settings.robot_speech_name
body = json.loads(sent_msg.body)
assert body["data"] == "Hello World"
assert body["is_priority"] is True
@pytest.mark.asyncio
async def test_send_to_gesture_agent(agent):
"""Verify gesture command format."""
await agent._send_to_gesture_agent("wave_hand")
agent.send.assert_awaited_once()
sent_msg: InternalMessage = agent.send.call_args.args[0]
assert sent_msg.to == settings.agent_settings.robot_gesture_name
body = json.loads(sent_msg.body)
assert body["data"] == "wave_hand"
assert body["is_priority"] is True
assert body["endpoint"] == RIEndpoint.GESTURE_SINGLE.value
@pytest.mark.asyncio
async def test_send_to_bdi_belief(agent):
"""Verify belief update format."""
context_str = "some_goal"
await agent._send_to_bdi_belief(context_str)
assert agent.send.await_count == 1
sent_msg = agent.send.call_args.args[0]
assert sent_msg.to == settings.agent_settings.bdi_core_name
assert sent_msg.thread == "beliefs"
assert "achieved_some_goal" in sent_msg.body
@pytest.mark.asyncio
async def test_receive_loop_routing_success(agent):
"""
Test that the loop correctly:
1. Receives 'button_pressed' topic from ZMQ
2. Parses the JSON payload to find 'type' and 'context'
3. Calls the correct handler method based on 'type'
"""
# Prepare JSON payloads as bytes
payload_speech = json.dumps({"type": "speech", "context": "Hello Speech"}).encode()
payload_gesture = json.dumps({"type": "gesture", "context": "Hello Gesture"}).encode()
# override calls _send_to_bdi (for trigger/norm) OR _send_to_bdi_belief (for goal).
# To test routing, we need to populate the maps
agent._goal_map["Hello Override"] = "some_goal_slug"
payload_override = json.dumps({"type": "override", "context": "Hello Override"}).encode()
agent.sub_socket.recv_multipart.side_effect = [
(b"button_pressed", payload_speech),
(b"button_pressed", payload_gesture),
(b"button_pressed", payload_override),
asyncio.CancelledError, # Stop the infinite loop
]
agent._send_to_speech_agent = AsyncMock()
agent._send_to_gesture_agent = AsyncMock()
agent._send_to_bdi_belief = AsyncMock()
try:
await agent._receive_button_event()
except asyncio.CancelledError:
pass
await asyncio.sleep(0)
# Speech
agent._send_to_speech_agent.assert_awaited_once_with("Hello Speech")
# Gesture
agent._send_to_gesture_agent.assert_awaited_once_with("Hello Gesture")
# Override (since we mapped it to a goal)
agent._send_to_bdi_belief.assert_awaited_once_with("some_goal_slug")
assert agent._send_to_speech_agent.await_count == 1
assert agent._send_to_gesture_agent.await_count == 1
assert agent._send_to_bdi_belief.await_count == 1
@pytest.mark.asyncio
async def test_receive_loop_unknown_type(agent):
"""Test that unknown 'type' values in the JSON log a warning and do not crash."""
# Prepare a payload with an unknown type
payload_unknown = json.dumps({"type": "unknown_thing", "context": "some_data"}).encode()
agent.sub_socket.recv_multipart.side_effect = [
(b"button_pressed", payload_unknown),
asyncio.CancelledError,
]
agent._send_to_speech_agent = AsyncMock()
agent._send_to_gesture_agent = AsyncMock()
try:
await agent._receive_button_event()
except asyncio.CancelledError:
pass
await asyncio.sleep(0)
# Ensure no handlers were called
agent._send_to_speech_agent.assert_not_called()
agent._send_to_gesture_agent.assert_not_called()
agent.logger.warning.assert_called()
@pytest.mark.asyncio
async def test_create_mapping(agent):
# Create a program with a trigger, goal, and conditional norm
import uuid
trigger_id = uuid.uuid4()
goal_id = uuid.uuid4()
norm_id = uuid.uuid4()
cond = KeywordBelief(id=uuid.uuid4(), name="k1", keyword="key")
plan = Plan(id=uuid.uuid4(), name="p1", steps=[])
trigger = Trigger(id=trigger_id, name="my_trigger", condition=cond, plan=plan)
goal = Goal(id=goal_id, name="my_goal", description="desc", plan=plan)
cn = ConditionalNorm(id=norm_id, name="my_norm", norm="be polite", condition=cond)
phase = Phase(id=uuid.uuid4(), name="phase1", norms=[cn], goals=[goal], triggers=[trigger])
prog = Program(phases=[phase])
# Call create_mapping via handle_message
msg = InternalMessage(to="me", thread="new_program", body=prog.model_dump_json())
await agent.handle_message(msg)
# Check maps
assert str(trigger_id) in agent._trigger_map
assert agent._trigger_map[str(trigger_id)] == "trigger_my_trigger"
assert str(goal_id) in agent._goal_map
assert agent._goal_map[str(goal_id)] == "my_goal"
assert str(norm_id) in agent._cond_norm_map
assert agent._cond_norm_map[str(norm_id)] == "norm_be_polite"
@pytest.mark.asyncio
async def test_create_mapping_invalid_json(agent):
# Pass invalid json to handle_message thread "new_program"
msg = InternalMessage(to="me", thread="new_program", body="invalid json")
await agent.handle_message(msg)
# Should log error and maps should remain empty or cleared
agent.logger.error.assert_called()
@pytest.mark.asyncio
async def test_handle_message_trigger_start(agent):
# Setup reverse map manually
agent._trigger_reverse_map["trigger_slug"] = "ui_id_123"
msg = InternalMessage(to="me", thread="trigger_start", body="trigger_slug")
await agent.handle_message(msg)
agent.pub_socket.send_multipart.assert_awaited_once()
args = agent.pub_socket.send_multipart.call_args[0][0]
assert args[0] == b"experiment"
payload = json.loads(args[1])
assert payload["type"] == "trigger_update"
assert payload["id"] == "ui_id_123"
assert payload["achieved"] is True
@pytest.mark.asyncio
async def test_handle_message_trigger_end(agent):
agent._trigger_reverse_map["trigger_slug"] = "ui_id_123"
msg = InternalMessage(to="me", thread="trigger_end", body="trigger_slug")
await agent.handle_message(msg)
agent.pub_socket.send_multipart.assert_awaited_once()
payload = json.loads(agent.pub_socket.send_multipart.call_args[0][0][1])
assert payload["type"] == "trigger_update"
assert payload["achieved"] is False
@pytest.mark.asyncio
async def test_handle_message_transition_phase(agent):
msg = InternalMessage(to="me", thread="transition_phase", body="phase_id_123")
await agent.handle_message(msg)
agent.pub_socket.send_multipart.assert_awaited_once()
payload = json.loads(agent.pub_socket.send_multipart.call_args[0][0][1])
assert payload["type"] == "phase_update"
assert payload["id"] == "phase_id_123"
@pytest.mark.asyncio
async def test_handle_message_goal_start(agent):
agent._goal_reverse_map["goal_slug"] = "goal_id_123"
msg = InternalMessage(to="me", thread="goal_start", body="goal_slug")
await agent.handle_message(msg)
agent.pub_socket.send_multipart.assert_awaited_once()
payload = json.loads(agent.pub_socket.send_multipart.call_args[0][0][1])
assert payload["type"] == "goal_update"
assert payload["id"] == "goal_id_123"
assert payload["active"] is True
@pytest.mark.asyncio
async def test_handle_message_active_norms_update(agent):
agent._cond_norm_reverse_map["norm_active"] = "id_1"
agent._cond_norm_reverse_map["norm_inactive"] = "id_2"
# Body is like: "('norm_active', 'other')"
# The split logic handles quotes etc.
msg = InternalMessage(to="me", thread="active_norms_update", body="'norm_active', 'other'")
await agent.handle_message(msg)
agent.pub_socket.send_multipart.assert_awaited_once()
payload = json.loads(agent.pub_socket.send_multipart.call_args[0][0][1])
assert payload["type"] == "cond_norms_state_update"
norms = {n["id"]: n["active"] for n in payload["norms"]}
assert norms["id_1"] is True
assert norms["id_2"] is False
@pytest.mark.asyncio
async def test_send_experiment_control(agent):
# Test next_phase
await agent._send_experiment_control_to_bdi_core("next_phase")
agent.send.assert_awaited()
msg = agent.send.call_args[0][0]
assert msg.thread == "force_next_phase"
# Test reset_phase
await agent._send_experiment_control_to_bdi_core("reset_phase")
msg = agent.send.call_args[0][0]
assert msg.thread == "reset_current_phase"
# Test reset_experiment
await agent._send_experiment_control_to_bdi_core("reset_experiment")
msg = agent.send.call_args[0][0]
assert msg.thread == "reset_experiment"
@pytest.mark.asyncio
async def test_send_pause_command(agent):
await agent._send_pause_command("true")
# Sends to RI and VAD
assert agent.send.await_count == 2
msgs = [call.args[0] for call in agent.send.call_args_list]
ri_msg = next(m for m in msgs if m.to == settings.agent_settings.ri_communication_name)
assert json.loads(ri_msg.body)["endpoint"] == "" # PAUSE endpoint
assert json.loads(ri_msg.body)["data"] is True
vad_msg = next(m for m in msgs if m.to == settings.agent_settings.vad_name)
assert vad_msg.body == "PAUSE"
agent.send.reset_mock()
await agent._send_pause_command("false")
assert agent.send.await_count == 2
vad_msg = next(
m for m in agent.send.call_args_list if m.args[0].to == settings.agent_settings.vad_name
).args[0]
assert vad_msg.body == "RESUME"

View File

@@ -0,0 +1,63 @@
from unittest.mock import patch
import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
from starlette.responses import StreamingResponse
from control_backend.api.v1.endpoints import logs
@pytest.fixture
def client():
"""TestClient with logs router included."""
app = FastAPI()
app.include_router(logs.router)
return TestClient(app)
@pytest.mark.asyncio
async def test_log_stream_endpoint_lines(client):
"""Call /logs/stream with a mocked ZMQ socket to cover all lines."""
# Dummy socket to mock ZMQ behavior
class DummySocket:
def __init__(self):
self.subscribed = []
self.connected = False
self.recv_count = 0
def subscribe(self, topic):
self.subscribed.append(topic)
def connect(self, addr):
self.connected = True
async def recv_multipart(self):
# Return one message, then stop generator
if self.recv_count == 0:
self.recv_count += 1
return (b"INFO", b"test message")
else:
raise StopAsyncIteration
dummy_socket = DummySocket()
# Patch Context.instance().socket() to return dummy socket
with patch("control_backend.api.v1.endpoints.logs.Context.instance") as mock_context:
mock_context.return_value.socket.return_value = dummy_socket
# Call the endpoint directly
response = await logs.log_stream()
assert isinstance(response, StreamingResponse)
# Fetch one chunk from the generator
gen = response.body_iterator
chunk = await gen.__anext__()
if isinstance(chunk, bytes):
chunk = chunk.decode("utf-8")
assert "data:" in chunk
# Optional: assert subscribe/connect were called
assert dummy_socket.subscribed # at least some log levels subscribed
assert dummy_socket.connected # connect was called

View File

@@ -0,0 +1,45 @@
import json
import pytest
from fastapi.testclient import TestClient
from control_backend.api.v1.endpoints import message
@pytest.fixture
def client():
"""FastAPI TestClient for the message router."""
from fastapi import FastAPI
app = FastAPI()
app.include_router(message.router)
return TestClient(app)
def test_receive_message_post(client, monkeypatch):
"""Test POST /message endpoint sends message to pub socket."""
# Dummy pub socket to capture sent messages
class DummyPubSocket:
def __init__(self):
self.sent = []
async def send_multipart(self, msg):
self.sent.append(msg)
dummy_socket = DummyPubSocket()
# Patch app.state.endpoints_pub_socket
client.app.state.endpoints_pub_socket = dummy_socket
data = {"message": "Hello world"}
response = client.post("/message", json=data)
assert response.status_code == 202
assert response.json() == {"status": "Message received"}
# Ensure the message was sent via pub_socket
assert len(dummy_socket.sent) == 1
topic, body = dummy_socket.sent[0]
parsed = json.loads(body.decode("utf-8"))
assert parsed["message"] == "Hello world"

View File

@@ -1,4 +1,5 @@
import json
import uuid
from unittest.mock import AsyncMock
import pytest
@@ -6,7 +7,7 @@ from fastapi import FastAPI
from fastapi.testclient import TestClient
from control_backend.api.v1.endpoints import program
from control_backend.schemas.program import Program
from control_backend.schemas.program import BasicNorm, Goal, Phase, Plan, Program
@pytest.fixture
@@ -25,29 +26,39 @@ def client(app):
def make_valid_program_dict():
"""Helper to create a valid Program JSON structure."""
return {
"phases": [
{
"id": "phase1",
"label": "basephase",
"norms": [{"id": "n1", "label": "norm", "norm": "be nice"}],
"goals": [
{"id": "g1", "label": "goal", "description": "test goal", "achieved": False}
# Converting to JSON using Pydantic because it knows how to convert a UUID object
program_json_str = Program(
phases=[
Phase(
id=uuid.uuid4(),
name="Basic Phase",
norms=[
BasicNorm(
id=uuid.uuid4(),
name="Some norm",
norm="Do normal.",
),
],
"triggers": [
{
"id": "t1",
"label": "trigger",
"type": "keywords",
"keywords": [
{"id": "kw1", "keyword": "keyword1"},
{"id": "kw2", "keyword": "keyword2"},
],
},
goals=[
Goal(
id=uuid.uuid4(),
name="Some goal",
description="This description can be used to determine whether the goal "
"has been achieved.",
plan=Plan(
id=uuid.uuid4(),
name="Goal Plan",
steps=[],
),
can_fail=False,
),
],
}
]
}
triggers=[],
),
],
).model_dump_json()
# Converting back to a dict because that's what's expected
return json.loads(program_json_str)
def test_receive_program_success(client):
@@ -71,7 +82,8 @@ def test_receive_program_success(client):
sent_bytes = args[0][1]
sent_obj = json.loads(sent_bytes.decode())
expected_obj = Program.model_validate(program_dict).model_dump()
# Converting to JSON using Pydantic because it knows how to handle UUIDs
expected_obj = json.loads(Program.model_validate(program_dict).model_dump_json())
assert sent_obj == expected_obj

View File

@@ -1,3 +1,4 @@
# tests/test_robot_endpoints.py
import json
from unittest.mock import AsyncMock, MagicMock, patch
@@ -29,7 +30,7 @@ def client(app):
@pytest.fixture
def mock_zmq_context():
"""Mock the ZMQ context."""
"""Mock the ZMQ context used by the endpoint module."""
with patch("control_backend.api.v1.endpoints.robot.Context.instance") as mock_context:
context_instance = MagicMock()
mock_context.return_value = context_instance
@@ -38,13 +39,13 @@ def mock_zmq_context():
@pytest.fixture
def mock_sockets(mock_zmq_context):
"""Mock ZMQ sockets."""
"""Optional helper if you want both a sub and req/push socket available."""
mock_sub_socket = AsyncMock(spec=zmq.asyncio.Socket)
mock_pub_socket = AsyncMock(spec=zmq.asyncio.Socket)
mock_req_socket = AsyncMock(spec=zmq.asyncio.Socket)
mock_zmq_context.socket.return_value = mock_sub_socket
return {"sub": mock_sub_socket, "pub": mock_pub_socket}
return {"sub": mock_sub_socket, "req": mock_req_socket}
def test_receive_speech_command_success(client):
@@ -61,11 +62,11 @@ def test_receive_speech_command_success(client):
speech_command = SpeechCommand(**command_data)
# Act
response = client.post("/command", json=command_data)
response = client.post("/command/speech", json=command_data)
# Assert
assert response.status_code == 202
assert response.json() == {"status": "Command received"}
assert response.json() == {"status": "Speech command received"}
# Verify that the ZMQ socket was used correctly
mock_pub_socket.send_multipart.assert_awaited_once_with(
@@ -75,9 +76,8 @@ def test_receive_speech_command_success(client):
def test_receive_gesture_command_success(client):
"""
Test for successful reception of a command. Ensures the status code is 202 and the response body
is correct. It also verifies that the ZeroMQ socket's send_multipart method is called with the
expected data.
Test for successful reception of a command that is a gesture command.
Ensures the status code is 202 and the response body is correct.
"""
# Arrange
mock_pub_socket = AsyncMock()
@@ -87,11 +87,11 @@ def test_receive_gesture_command_success(client):
gesture_command = GestureCommand(**command_data)
# Act
response = client.post("/command", json=command_data)
response = client.post("/command/gesture", json=command_data)
# Assert
assert response.status_code == 202
assert response.json() == {"status": "Command received"}
assert response.json() == {"status": "Gesture command received"}
# Verify that the ZMQ socket was used correctly
mock_pub_socket.send_multipart.assert_awaited_once_with(
@@ -99,13 +99,23 @@ def test_receive_gesture_command_success(client):
)
def test_receive_command_invalid_payload(client):
def test_receive_speech_command_invalid_payload(client):
"""
Test invalid data handling (schema validation).
"""
# Missing required field(s)
bad_payload = {"invalid": "data"}
response = client.post("/command", json=bad_payload)
response = client.post("/command/speech", json=bad_payload)
assert response.status_code == 422 # validation error
def test_receive_gesture_command_invalid_payload(client):
"""
Test invalid data handling (schema validation).
"""
# Missing required field(s)
bad_payload = {"invalid": "data"}
response = client.post("/command/gesture", json=bad_payload)
assert response.status_code == 422 # validation error
@@ -116,7 +126,9 @@ def test_ping_check_returns_none(client):
assert response.json() is None
# TODO: Convert these mock sockets to the fixture.
# ----------------------------
# ping_stream tests (unchanged behavior)
# ----------------------------
@pytest.mark.asyncio
async def test_ping_stream_yields_ping_event(monkeypatch):
"""Test that ping_stream yields a proper SSE message when a ping is received."""
@@ -129,6 +141,11 @@ async def test_ping_stream_yields_ping_event(monkeypatch):
mock_context.socket.return_value = mock_sub_socket
monkeypatch.setattr(robot.Context, "instance", lambda: mock_context)
# patch settings address used by ping_stream
mock_settings = MagicMock()
mock_settings.zmq_settings.internal_sub_address = "tcp://localhost:5555"
monkeypatch.setattr(robot, "settings", mock_settings)
mock_request = AsyncMock()
mock_request.is_disconnected = AsyncMock(side_effect=[False, True])
@@ -142,7 +159,7 @@ async def test_ping_stream_yields_ping_event(monkeypatch):
with pytest.raises(StopAsyncIteration):
await anext(generator)
mock_sub_socket.connect.assert_called_once()
mock_sub_socket.connect.assert_called_once_with("tcp://localhost:5555")
mock_sub_socket.setsockopt.assert_called_once_with(robot.zmq.SUBSCRIBE, b"ping")
mock_sub_socket.recv_multipart.assert_awaited()
@@ -159,6 +176,10 @@ async def test_ping_stream_handles_timeout(monkeypatch):
mock_context.socket.return_value = mock_sub_socket
monkeypatch.setattr(robot.Context, "instance", lambda: mock_context)
mock_settings = MagicMock()
mock_settings.zmq_settings.internal_sub_address = "tcp://localhost:5555"
monkeypatch.setattr(robot, "settings", mock_settings)
mock_request = AsyncMock()
mock_request.is_disconnected = AsyncMock(return_value=True)
@@ -168,7 +189,7 @@ async def test_ping_stream_handles_timeout(monkeypatch):
with pytest.raises(StopAsyncIteration):
await anext(generator)
mock_sub_socket.connect.assert_called_once()
mock_sub_socket.connect.assert_called_once_with("tcp://localhost:5555")
mock_sub_socket.setsockopt.assert_called_once_with(robot.zmq.SUBSCRIBE, b"ping")
mock_sub_socket.recv_multipart.assert_awaited()
@@ -187,6 +208,10 @@ async def test_ping_stream_yields_json_values(monkeypatch):
mock_context.socket.return_value = mock_sub_socket
monkeypatch.setattr(robot.Context, "instance", lambda: mock_context)
mock_settings = MagicMock()
mock_settings.zmq_settings.internal_sub_address = "tcp://localhost:5555"
monkeypatch.setattr(robot, "settings", mock_settings)
mock_request = AsyncMock()
mock_request.is_disconnected = AsyncMock(side_effect=[False, True])
@@ -199,183 +224,135 @@ async def test_ping_stream_yields_json_values(monkeypatch):
assert "connected" in event_text
assert "true" in event_text
mock_sub_socket.connect.assert_called_once()
mock_sub_socket.connect.assert_called_once_with("tcp://localhost:5555")
mock_sub_socket.setsockopt.assert_called_once_with(robot.zmq.SUBSCRIBE, b"ping")
mock_sub_socket.recv_multipart.assert_awaited()
# New tests for get_available_gesture_tags endpoint
# ----------------------------
# Updated get_available_gesture_tags tests (REQ socket on tcp://localhost:7788)
# ----------------------------
@pytest.mark.asyncio
async def test_get_available_gesture_tags_success(client, monkeypatch):
"""
Test successful retrieval of available gesture tags.
Test successful retrieval of available gesture tags using a REQ socket.
"""
# Arrange
mock_sub_socket = AsyncMock()
mock_sub_socket.connect = MagicMock()
mock_sub_socket.setsockopt = MagicMock()
# Simulate a response with gesture tags
mock_req_socket = AsyncMock(spec=zmq.asyncio.Socket)
mock_req_socket.connect = MagicMock()
mock_req_socket.send = AsyncMock()
response_data = {"tags": ["wave", "nod", "point", "dance"]}
mock_sub_socket.recv_multipart = AsyncMock(
return_value=[b"get_gestures", json.dumps(response_data).encode()]
)
mock_req_socket.recv = AsyncMock(return_value=json.dumps(response_data).encode())
mock_context = MagicMock()
mock_context.socket.return_value = mock_sub_socket
mock_context.socket.return_value = mock_req_socket
monkeypatch.setattr(robot.Context, "instance", lambda: mock_context)
mock_pub_socket = AsyncMock()
client.app.state.endpoints_pub_socket = mock_pub_socket
# Mock settings
mock_settings = MagicMock()
mock_settings.zmq_settings.internal_sub_address = "tcp://localhost:5555"
monkeypatch.setattr(robot, "settings", mock_settings)
# Mock logger to avoid actual logging
mock_logger = MagicMock()
monkeypatch.setattr(robot.logger, "debug", mock_logger)
# Replace logger methods to avoid noisy logs in tests
monkeypatch.setattr(robot.logger, "debug", MagicMock())
monkeypatch.setattr(robot.logger, "error", MagicMock())
# Act
response = client.get("/get_available_gesture_tags")
response = client.get("/commands/gesture/tags")
# Assert
assert response.status_code == 200
assert response.json() == {"available_gesture_tags": ["wave", "nod", "point", "dance"]}
# Verify ZeroMQ interactions
mock_sub_socket.connect.assert_called_once_with("tcp://localhost:5555")
mock_sub_socket.setsockopt.assert_called_once_with(robot.zmq.SUBSCRIBE, b"get_gestures")
mock_pub_socket.send_multipart.assert_awaited_once_with([b"send_gestures", b""])
mock_sub_socket.recv_multipart.assert_awaited_once()
# Verify ZeroMQ REQ interactions
mock_req_socket.connect.assert_called_once_with("tcp://localhost:7788")
mock_req_socket.send.assert_awaited_once_with(b"None")
mock_req_socket.recv.assert_awaited_once()
@pytest.mark.asyncio
async def test_get_available_gesture_tags_with_amount(client, monkeypatch):
"""
Test retrieval of gesture tags with a specific amount parameter.
This tests the TODO in the endpoint about getting a certain amount from the UI.
The endpoint currently ignores the 'amount' TODO, so behavior is the same as 'success'.
This test asserts that the endpoint still sends b"None" and returns the tags.
"""
# Arrange
mock_sub_socket = AsyncMock()
mock_sub_socket.connect = MagicMock()
mock_sub_socket.setsockopt = MagicMock()
# Simulate a response with gesture tags
mock_req_socket = AsyncMock(spec=zmq.asyncio.Socket)
mock_req_socket.connect = MagicMock()
mock_req_socket.send = AsyncMock()
response_data = {"tags": ["wave", "nod"]}
mock_sub_socket.recv_multipart = AsyncMock(
return_value=[b"get_gestures", json.dumps(response_data).encode()]
)
mock_req_socket.recv = AsyncMock(return_value=json.dumps(response_data).encode())
mock_context = MagicMock()
mock_context.socket.return_value = mock_sub_socket
mock_context.socket.return_value = mock_req_socket
monkeypatch.setattr(robot.Context, "instance", lambda: mock_context)
mock_pub_socket = AsyncMock()
client.app.state.endpoints_pub_socket = mock_pub_socket
monkeypatch.setattr(robot.logger, "debug", MagicMock())
monkeypatch.setattr(robot.logger, "error", MagicMock())
# Mock settings
mock_settings = MagicMock()
mock_settings.zmq_settings.internal_sub_address = "tcp://localhost:5555"
monkeypatch.setattr(robot, "settings", mock_settings)
# Mock logger
mock_logger = MagicMock()
monkeypatch.setattr(robot.logger, "debug", mock_logger)
# Act - Note: The endpoint currently doesn't support query parameters for amount,
# but we're testing what happens if the UI sends an amount (the TODO in the code)
# For now, we test the current behavior
response = client.get("/get_available_gesture_tags")
# Act
response = client.get("/commands/gesture/tags")
# Assert
assert response.status_code == 200
assert response.json() == {"available_gesture_tags": ["wave", "nod"]}
# The endpoint currently doesn't use the amount parameter, so it should send empty bytes
mock_pub_socket.send_multipart.assert_awaited_once_with([b"send_gestures", b""])
mock_req_socket.connect.assert_called_once_with("tcp://localhost:7788")
mock_req_socket.send.assert_awaited_once_with(b"None")
@pytest.mark.asyncio
async def test_get_available_gesture_tags_timeout(client, monkeypatch):
"""
Test timeout scenario when fetching gesture tags.
Test timeout scenario when fetching gesture tags. Endpoint should handle TimeoutError
and return an empty list while logging the timeout.
"""
# Arrange
mock_sub_socket = AsyncMock()
mock_sub_socket.connect = MagicMock()
mock_sub_socket.setsockopt = MagicMock()
# Simulate a timeout
mock_sub_socket.recv_multipart = AsyncMock(side_effect=TimeoutError)
mock_req_socket = AsyncMock(spec=zmq.asyncio.Socket)
mock_req_socket.connect = MagicMock()
mock_req_socket.send = AsyncMock()
mock_req_socket.recv = AsyncMock(side_effect=TimeoutError)
mock_context = MagicMock()
mock_context.socket.return_value = mock_sub_socket
mock_context.socket.return_value = mock_req_socket
monkeypatch.setattr(robot.Context, "instance", lambda: mock_context)
mock_pub_socket = AsyncMock()
client.app.state.endpoints_pub_socket = mock_pub_socket
# Mock settings
mock_settings = MagicMock()
mock_settings.zmq_settings.internal_sub_address = "tcp://localhost:5555"
monkeypatch.setattr(robot, "settings", mock_settings)
# Mock logger to verify debug message is logged
mock_logger = MagicMock()
monkeypatch.setattr(robot.logger, "debug", mock_logger)
# Patch logger.debug so we can assert it was called with the expected message
mock_debug = MagicMock()
monkeypatch.setattr(robot.logger, "debug", mock_debug)
monkeypatch.setattr(robot.logger, "error", MagicMock())
# Act
response = client.get("/get_available_gesture_tags")
response = client.get("/commands/gesture/tags")
# Assert
assert response.status_code == 200
# On timeout, body becomes b"" and json.loads(b"") raises JSONDecodeError
# But looking at the endpoint code, it will try to parse empty bytes which will fail
# Let's check what actually happens
assert response.json() == {"available_gesture_tags": []}
# Verify the timeout was logged
mock_logger.assert_called_once_with("got timeout error fetching gestures")
# Verify the timeout was logged using the exact string from the endpoint code
mock_debug.assert_called_once_with("Got timeout error fetching gestures.")
# Verify ZeroMQ interactions
mock_sub_socket.connect.assert_called_once_with("tcp://localhost:5555")
mock_sub_socket.setsockopt.assert_called_once_with(robot.zmq.SUBSCRIBE, b"get_gestures")
mock_pub_socket.send_multipart.assert_awaited_once_with([b"send_gestures", b""])
mock_sub_socket.recv_multipart.assert_awaited_once()
mock_req_socket.connect.assert_called_once_with("tcp://localhost:7788")
mock_req_socket.send.assert_awaited_once_with(b"None")
mock_req_socket.recv.assert_awaited_once()
@pytest.mark.asyncio
async def test_get_available_gesture_tags_empty_response(client, monkeypatch):
"""
Test scenario when response contains no tags.
Test scenario when response contains an empty 'tags' list.
"""
# Arrange
mock_sub_socket = AsyncMock()
mock_sub_socket.connect = MagicMock()
mock_sub_socket.setsockopt = MagicMock()
# Simulate a response with empty tags
mock_req_socket = AsyncMock(spec=zmq.asyncio.Socket)
mock_req_socket.connect = MagicMock()
mock_req_socket.send = AsyncMock()
response_data = {"tags": []}
mock_sub_socket.recv_multipart = AsyncMock(
return_value=[b"get_gestures", json.dumps(response_data).encode()]
)
mock_req_socket.recv = AsyncMock(return_value=json.dumps(response_data).encode())
mock_context = MagicMock()
mock_context.socket.return_value = mock_sub_socket
mock_context.socket.return_value = mock_req_socket
monkeypatch.setattr(robot.Context, "instance", lambda: mock_context)
mock_pub_socket = AsyncMock()
client.app.state.endpoints_pub_socket = mock_pub_socket
# Mock settings
mock_settings = MagicMock()
mock_settings.zmq_settings.internal_sub_address = "tcp://localhost:5555"
monkeypatch.setattr(robot, "settings", mock_settings)
monkeypatch.setattr(robot.logger, "debug", MagicMock())
monkeypatch.setattr(robot.logger, "error", MagicMock())
# Act
response = client.get("/get_available_gesture_tags")
response = client.get("/commands/gesture/tags")
# Assert
assert response.status_code == 200
@@ -388,65 +365,51 @@ async def test_get_available_gesture_tags_missing_tags_key(client, monkeypatch):
Test scenario when response JSON doesn't contain 'tags' key.
"""
# Arrange
mock_sub_socket = AsyncMock()
mock_sub_socket.connect = MagicMock()
mock_sub_socket.setsockopt = MagicMock()
# Simulate a response without 'tags' key
mock_req_socket = AsyncMock(spec=zmq.asyncio.Socket)
mock_req_socket.connect = MagicMock()
mock_req_socket.send = AsyncMock()
response_data = {"some_other_key": "value"}
mock_sub_socket.recv_multipart = AsyncMock(
return_value=[b"get_gestures", json.dumps(response_data).encode()]
)
mock_req_socket.recv = AsyncMock(return_value=json.dumps(response_data).encode())
mock_context = MagicMock()
mock_context.socket.return_value = mock_sub_socket
mock_context.socket.return_value = mock_req_socket
monkeypatch.setattr(robot.Context, "instance", lambda: mock_context)
mock_pub_socket = AsyncMock()
client.app.state.endpoints_pub_socket = mock_pub_socket
# Mock settings
mock_settings = MagicMock()
mock_settings.zmq_settings.internal_sub_address = "tcp://localhost:5555"
monkeypatch.setattr(robot, "settings", mock_settings)
monkeypatch.setattr(robot.logger, "debug", MagicMock())
monkeypatch.setattr(robot.logger, "error", MagicMock())
# Act
response = client.get("/get_available_gesture_tags")
response = client.get("/commands/gesture/tags")
# Assert
assert response.status_code == 200
# .get("tags", []) should return empty list if 'tags' key is missing
assert response.json() == {"available_gesture_tags": []}
@pytest.mark.asyncio
async def test_get_available_gesture_tags_invalid_json(client, monkeypatch):
"""
Test scenario when response contains invalid JSON.
Test scenario when response contains invalid JSON. Endpoint should log the error
and return an empty list.
"""
# Arrange
mock_sub_socket = AsyncMock()
mock_sub_socket.connect = MagicMock()
mock_sub_socket.setsockopt = MagicMock()
# Simulate a response with invalid JSON
mock_sub_socket.recv_multipart = AsyncMock(return_value=[b"get_gestures", b"invalid json"])
mock_req_socket = AsyncMock(spec=zmq.asyncio.Socket)
mock_req_socket.connect = MagicMock()
mock_req_socket.send = AsyncMock()
mock_req_socket.recv = AsyncMock(return_value=b"invalid json")
mock_context = MagicMock()
mock_context.socket.return_value = mock_sub_socket
mock_context.socket.return_value = mock_req_socket
monkeypatch.setattr(robot.Context, "instance", lambda: mock_context)
mock_pub_socket = AsyncMock()
client.app.state.endpoints_pub_socket = mock_pub_socket
# Mock settings
mock_settings = MagicMock()
mock_settings.zmq_settings.internal_sub_address = "tcp://localhost:5555"
monkeypatch.setattr(robot, "settings", mock_settings)
mock_error = MagicMock()
monkeypatch.setattr(robot.logger, "error", mock_error)
monkeypatch.setattr(robot.logger, "debug", MagicMock())
# Act
response = client.get("/get_available_gesture_tags")
response = client.get("/commands/gesture/tags")
# Assert - invalid JSON should raise an exception
# Assert - invalid JSON should lead to empty list and error log invocation
assert response.status_code == 200
assert response.json() == {"available_gesture_tags": []}
assert mock_error.call_count == 1

View File

@@ -0,0 +1,16 @@
from fastapi.routing import APIRoute
from control_backend.api.v1.router import api_router # <--- corrected import
def test_router_includes_expected_paths():
"""Ensure api_router includes main router prefixes."""
routes = [r for r in api_router.routes if isinstance(r, APIRoute)]
paths = [r.path for r in routes]
# Ensure at least one route under each prefix exists
assert any(p.startswith("/robot") for p in paths)
assert any(p.startswith("/message") for p in paths)
assert any(p.startswith("/sse") for p in paths)
assert any(p.startswith("/logs") for p in paths)
assert any(p.startswith("/program") for p in paths)

View File

@@ -0,0 +1,24 @@
import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
from control_backend.api.v1.endpoints import sse
@pytest.fixture
def app():
app = FastAPI()
app.include_router(sse.router)
return app
@pytest.fixture
def client(app):
return TestClient(app)
def test_sse_route_exists(client):
"""Minimal smoke test to ensure /sse route exists and responds."""
response = client.get("/sse")
# Since implementation is not done, we only assert it doesn't crash
assert response.status_code == 200

View File

@@ -0,0 +1,96 @@
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
from control_backend.api.v1.endpoints import user_interact
@pytest.fixture
def app():
app = FastAPI()
app.include_router(user_interact.router)
return app
@pytest.fixture
def client(app):
return TestClient(app)
@pytest.mark.asyncio
async def test_receive_button_event(client):
mock_pub_socket = AsyncMock()
client.app.state.endpoints_pub_socket = mock_pub_socket
payload = {"type": "speech", "context": "hello"}
response = client.post("/button_pressed", json=payload)
assert response.status_code == 202
assert response.json() == {"status": "Event received"}
mock_pub_socket.send_multipart.assert_awaited_once()
args = mock_pub_socket.send_multipart.call_args[0][0]
assert args[0] == b"button_pressed"
assert "speech" in args[1].decode()
@pytest.mark.asyncio
async def test_receive_button_event_invalid_payload(client):
mock_pub_socket = AsyncMock()
client.app.state.endpoints_pub_socket = mock_pub_socket
# Missing context
payload = {"type": "speech"}
response = client.post("/button_pressed", json=payload)
assert response.status_code == 422
mock_pub_socket.send_multipart.assert_not_called()
@pytest.mark.asyncio
async def test_experiment_stream_direct_call():
"""
Directly calling the endpoint function to test the streaming logic
without dealing with TestClient streaming limitations.
"""
mock_socket = AsyncMock()
# 1. recv data
# 2. recv timeout
# 3. disconnect (request.is_disconnected returns True)
mock_socket.recv_multipart.side_effect = [
(b"topic", b"message1"),
TimeoutError(),
(b"topic", b"message2"), # Should not be reached if disconnect checks work
]
mock_socket.close = MagicMock()
mock_socket.connect = MagicMock()
mock_socket.subscribe = MagicMock()
mock_context = MagicMock()
mock_context.socket.return_value = mock_socket
with patch(
"control_backend.api.v1.endpoints.user_interact.Context.instance", return_value=mock_context
):
mock_request = AsyncMock()
# is_disconnected sequence:
# 1. False (before first recv) -> reads message1
# 2. False (before second recv) -> triggers TimeoutError, continues
# 3. True (before third recv) -> break loop
mock_request.is_disconnected.side_effect = [False, False, True]
response = await user_interact.experiment_stream(mock_request)
lines = []
# Consume the generator
async for line in response.body_iterator:
lines.append(line)
assert "data: message1\n\n" in lines
assert len(lines) == 1
mock_socket.connect.assert_called()
mock_socket.subscribe.assert_called_with(b"experiment")
mock_socket.close.assert_called()

View File

@@ -25,7 +25,6 @@ def mock_settings():
mock.zmq_settings.internal_sub_address = "tcp://localhost:5561"
mock.zmq_settings.ri_command_address = "tcp://localhost:0000"
mock.agent_settings.bdi_core_name = "bdi_core_agent"
mock.agent_settings.bdi_belief_collector_name = "belief_collector_agent"
mock.agent_settings.llm_name = "llm_agent"
mock.agent_settings.robot_speech_name = "robot_speech_agent"
mock.agent_settings.transcription_name = "transcription_agent"

View File

@@ -2,7 +2,7 @@
import asyncio
import logging
from unittest.mock import AsyncMock
from unittest.mock import AsyncMock, MagicMock
import pytest
@@ -70,3 +70,205 @@ async def test_get_agent():
agent = ConcreteTestAgent("registrant")
assert AgentDirectory.get("registrant") == agent
assert AgentDirectory.get("non_existent") is None
class DummyAgent(BaseAgent):
async def setup(self):
pass # we will test this separately
async def handle_message(self, msg: InternalMessage):
self.last_handled = msg
@pytest.mark.asyncio
async def test_base_agent_setup_is_noop():
agent = DummyAgent("dummy")
# Should simply return without error
assert await agent.setup() is None
@pytest.mark.asyncio
async def test_send_to_local_agent(monkeypatch):
sender = DummyAgent("sender")
target = DummyAgent("receiver")
# Fake logger
sender.logger = MagicMock()
# Patch inbox.put
target.inbox.put = AsyncMock()
message = InternalMessage(to=target.name, sender=sender.name, body="hello")
await sender.send(message)
target.inbox.put.assert_awaited_once_with(message)
@pytest.mark.asyncio
async def test_send_to_zmq_agent(monkeypatch):
sender = DummyAgent("sender")
target = "remote_receiver"
# Fake logger
sender.logger = MagicMock()
# Fake zmq
sender._internal_pub_socket = AsyncMock()
message = InternalMessage(to=target, sender=sender.name, body="hello")
await sender.send(message)
zmq_calls = sender._internal_pub_socket.send_multipart.call_args[0][0]
assert zmq_calls[0] == f"internal/{target}".encode()
@pytest.mark.asyncio
async def test_send_to_multiple_local_agents(monkeypatch):
sender = DummyAgent("sender")
target1 = DummyAgent("receiver1")
target2 = DummyAgent("receiver2")
# Fake logger
sender.logger = MagicMock()
# Patch inbox.put
target1.inbox.put = AsyncMock()
target2.inbox.put = AsyncMock()
message = InternalMessage(to=[target1.name, target2.name], sender=sender.name, body="hello")
await sender.send(message)
target1.inbox.put.assert_awaited_once_with(message)
target2.inbox.put.assert_awaited_once_with(message)
@pytest.mark.asyncio
async def test_send_to_multiple_agents(monkeypatch):
sender = DummyAgent("sender")
target1 = DummyAgent("receiver1")
target2 = "remote_receiver"
# Fake logger
sender.logger = MagicMock()
# Fake zmq
sender._internal_pub_socket = AsyncMock()
# Patch inbox.put
target1.inbox.put = AsyncMock()
message = InternalMessage(to=[target1.name, target2], sender=sender.name, body="hello")
await sender.send(message)
target1.inbox.put.assert_awaited_once_with(message)
zmq_calls = sender._internal_pub_socket.send_multipart.call_args[0][0]
assert zmq_calls[0] == f"internal/{target2}".encode()
@pytest.mark.asyncio
async def test_process_inbox_calls_handle_message(monkeypatch):
agent = DummyAgent("dummy")
agent.logger = MagicMock()
# Make agent running so loop triggers
agent._running = True
# Prepare inbox to give one message then stop
msg = InternalMessage(to="dummy", sender="x", body="test")
async def get_once():
agent._running = False # stop after first iteration
return msg
agent.inbox.get = AsyncMock(side_effect=get_once)
agent.handle_message = AsyncMock()
await agent._process_inbox()
agent.handle_message.assert_awaited_once_with(msg)
@pytest.mark.asyncio
async def test_receive_internal_zmq_loop_success(monkeypatch):
agent = DummyAgent("dummy")
agent.logger = MagicMock()
agent._running = True
mock_socket = MagicMock()
mock_socket.recv_multipart = AsyncMock(
side_effect=[
(
b"topic",
InternalMessage(to="dummy", sender="x", body="hi").model_dump_json().encode(),
),
asyncio.CancelledError(), # stop loop
]
)
agent._internal_sub_socket = mock_socket
agent.inbox.put = AsyncMock()
await agent._receive_internal_zmq_loop()
agent.inbox.put.assert_awaited() # message forwarded
@pytest.mark.asyncio
async def test_receive_internal_zmq_loop_exception_logs_error():
agent = DummyAgent("dummy")
agent.logger = MagicMock()
agent._running = True
mock_socket = MagicMock()
mock_socket.recv_multipart = AsyncMock(
side_effect=[Exception("boom"), asyncio.CancelledError()]
)
agent._internal_sub_socket = mock_socket
agent.inbox.put = AsyncMock()
await agent._receive_internal_zmq_loop()
agent.logger.exception.assert_called_once()
assert "Could not process ZMQ message." in agent.logger.exception.call_args[0][0]
@pytest.mark.asyncio
async def test_base_agent_handle_message_not_implemented():
class RawAgent(BaseAgent):
async def setup(self):
pass
agent = RawAgent("raw")
msg = InternalMessage(to="raw", sender="x", body="hi")
with pytest.raises(NotImplementedError):
await BaseAgent.handle_message(agent, msg)
@pytest.mark.asyncio
async def test_base_agent_setup_abstract_method_body_executes():
"""
Covers the 'pass' inside BaseAgent.setup().
Since BaseAgent is abstract, we do NOT instantiate it.
We call the coroutine function directly on BaseAgent and pass a dummy self.
"""
class Dummy:
"""Minimal stub to act as 'self'."""
pass
stub = Dummy()
# Call BaseAgent.setup() as an unbound coroutine, passing stub as 'self'
result = await BaseAgent.setup(stub)
# The method contains only 'pass', so it returns None
assert result is None

View File

@@ -86,3 +86,34 @@ def test_setup_logging_zmq_handler(mock_zmq_context):
args = mock_dict_config.call_args[0][0]
assert "interface_or_socket" in args["handlers"]["ui"]
def test_add_logging_level_method_name_exists_in_logging():
# method_name explicitly set to an existing logging method → triggers first hasattr branch
with pytest.raises(AttributeError) as exc:
add_logging_level("NEWDUPLEVEL", 37, method_name="info")
assert "info already defined in logging module" in str(exc.value)
def test_add_logging_level_method_name_exists_in_logger_class():
# 'makeRecord' exists on Logger class but not on the logging module
with pytest.raises(AttributeError) as exc:
add_logging_level("ANOTHERLEVEL", 38, method_name="makeRecord")
assert "makeRecord already defined in logger class" in str(exc.value)
def test_add_logging_level_log_to_root_path_executes_without_error():
# Verify log_to_root is installed and callable — without asserting logging output
level_name = "ROOTTEST"
level_num = 36
add_logging_level(level_name, level_num)
# Simply call the injected root logger method
# The line is executed even if we don't validate output
root_logging_method = getattr(logging, level_name.lower(), None)
assert callable(root_logging_method)
# Execute the method to hit log_to_root in coverage.
# No need to verify log output.
root_logging_method("some message")

View File

@@ -0,0 +1,12 @@
from control_backend.schemas.message import Message
def base_message() -> Message:
return Message(message="Example")
def test_valid_message():
mess = base_message()
validated = Message.model_validate(mess)
assert isinstance(validated, Message)
assert validated.message == "Example"

View File

@@ -1,49 +1,66 @@
import uuid
import pytest
from pydantic import ValidationError
from control_backend.schemas.program import (
BasicNorm,
ConditionalNorm,
Goal,
KeywordTrigger,
Norm,
InferredBelief,
KeywordBelief,
LogicalOperator,
Phase,
Plan,
Program,
TriggerKeyword,
SemanticBelief,
Trigger,
)
def base_norm() -> Norm:
return Norm(
id="norm1",
label="testNorm",
def base_norm() -> BasicNorm:
return BasicNorm(
id=uuid.uuid4(),
name="testNormName",
norm="testNormNorm",
critical=False,
)
def base_goal() -> Goal:
return Goal(
id="goal1",
label="testGoal",
description="testGoalDescription",
achieved=False,
id=uuid.uuid4(),
name="testGoalName",
description="This description can be used to determine whether the goal has been achieved.",
plan=Plan(
id=uuid.uuid4(),
name="testGoalPlanName",
steps=[],
),
can_fail=False,
)
def base_trigger() -> KeywordTrigger:
return KeywordTrigger(
id="trigger1",
label="testTrigger",
type="keywords",
keywords=[
TriggerKeyword(id="keyword1", keyword="testKeyword1"),
TriggerKeyword(id="keyword1", keyword="testKeyword2"),
],
def base_trigger() -> Trigger:
return Trigger(
id=uuid.uuid4(),
name="testTriggerName",
condition=KeywordBelief(
id=uuid.uuid4(),
name="testTriggerKeywordBeliefTriggerName",
keyword="Keyword",
),
plan=Plan(
id=uuid.uuid4(),
name="testTriggerPlanName",
steps=[],
),
)
def base_phase() -> Phase:
return Phase(
id="phase1",
label="basephase",
id=uuid.uuid4(),
norms=[base_norm()],
goals=[base_goal()],
triggers=[base_trigger()],
@@ -58,7 +75,7 @@ def invalid_program() -> dict:
# wrong types inside phases list (not Phase objects)
return {
"phases": [
{"id": "phase1"}, # incomplete
{"id": uuid.uuid4()}, # incomplete
{"not_a_phase": True},
]
}
@@ -77,11 +94,112 @@ def test_valid_deepprogram():
# validate nested components directly
phase = validated.phases[0]
assert isinstance(phase.goals[0], Goal)
assert isinstance(phase.triggers[0], KeywordTrigger)
assert isinstance(phase.norms[0], Norm)
assert isinstance(phase.triggers[0], Trigger)
assert isinstance(phase.norms[0], BasicNorm)
def test_invalid_program():
bad = invalid_program()
with pytest.raises(ValidationError):
Program.model_validate(bad)
def test_conditional_norm_parsing():
"""
Check that pydantic is able to preserve the type of the norm, that it doesn't lose its
"condition" field when serializing and deserializing.
"""
norm = ConditionalNorm(
name="testNormName",
id=uuid.uuid4(),
norm="testNormNorm",
critical=False,
condition=KeywordBelief(
name="testKeywordBelief",
id=uuid.uuid4(),
keyword="testKeywordBelief",
),
)
program = Program(
phases=[
Phase(
name="Some phase",
id=uuid.uuid4(),
norms=[norm],
goals=[],
triggers=[],
),
],
)
parsed_program = Program.model_validate_json(program.model_dump_json())
parsed_norm = parsed_program.phases[0].norms[0]
assert hasattr(parsed_norm, "condition")
assert isinstance(parsed_norm, ConditionalNorm)
def test_belief_type_parsing():
"""
Check that pydantic is able to discern between the different types of beliefs when serializing
and deserializing.
"""
keyword_belief = KeywordBelief(
name="testKeywordBelief",
id=uuid.uuid4(),
keyword="something",
)
semantic_belief = SemanticBelief(
name="testSemanticBelief",
id=uuid.uuid4(),
description="something",
)
inferred_belief = InferredBelief(
name="testInferredBelief",
id=uuid.uuid4(),
operator=LogicalOperator.OR,
left=keyword_belief,
right=semantic_belief,
)
program = Program(
phases=[
Phase(
name="Some phase",
id=uuid.uuid4(),
norms=[],
goals=[],
triggers=[
Trigger(
name="testTriggerKeywordTrigger",
id=uuid.uuid4(),
condition=keyword_belief,
plan=Plan(name="testTriggerPlanName", id=uuid.uuid4(), steps=[]),
),
Trigger(
name="testTriggerSemanticTrigger",
id=uuid.uuid4(),
condition=semantic_belief,
plan=Plan(name="testTriggerPlanName", id=uuid.uuid4(), steps=[]),
),
Trigger(
name="testTriggerInferredTrigger",
id=uuid.uuid4(),
condition=inferred_belief,
plan=Plan(name="testTriggerPlanName", id=uuid.uuid4(), steps=[]),
),
],
),
],
)
parsed_program = Program.model_validate_json(program.model_dump_json())
parsed_keyword_belief = parsed_program.phases[0].triggers[0].condition
assert isinstance(parsed_keyword_belief, KeywordBelief)
parsed_semantic_belief = parsed_program.phases[0].triggers[1].condition
assert isinstance(parsed_semantic_belief, SemanticBelief)
parsed_inferred_belief = parsed_program.phases[0].triggers[2].condition
assert isinstance(parsed_inferred_belief, InferredBelief)

73
test/unit/test_main.py Normal file
View File

@@ -0,0 +1,73 @@
import asyncio
import sys
from unittest.mock import AsyncMock, patch
import pytest
from fastapi.testclient import TestClient
from control_backend.api.v1.router import api_router
from control_backend.main import app, lifespan
# Fix event loop on Windows
if sys.platform == "win32":
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
@pytest.fixture
def client():
# Patch setup_logging so it does nothing
with patch("control_backend.main.setup_logging"):
with TestClient(app) as c:
yield c
def test_root_fast():
# Patch heavy startup code so it doesnt slow down
with patch("control_backend.main.setup_logging"), patch("control_backend.main.lifespan"):
client = TestClient(app)
resp = client.get("/")
assert resp.status_code == 200
assert resp.json() == {"status": "ok"}
def test_cors_middleware_added():
"""Test that CORSMiddleware is correctly added to the app."""
from starlette.middleware.cors import CORSMiddleware
middleware_classes = [m.cls for m in app.user_middleware]
assert CORSMiddleware in middleware_classes
def test_api_router_included():
"""Test that the API router is included in the FastAPI app."""
route_paths = [r.path for r in app.routes]
for route in api_router.routes:
assert route.path in route_paths
@pytest.mark.asyncio
async def test_lifespan_agent_start_exception():
"""
Trigger an exception during agent startup to cover the error logging branch.
Ensures exceptions are logged properly and re-raised.
"""
with (
patch(
"control_backend.main.RICommunicationAgent.start", new_callable=AsyncMock
) as ri_start,
patch("control_backend.main.setup_logging"),
patch("control_backend.main.threading.Thread"),
):
# Force RICommunicationAgent.start to raise an exception
ri_start.side_effect = Exception("Test exception")
with patch("control_backend.main.logger") as mock_logger:
with pytest.raises(Exception, match="Test exception"):
async with lifespan(app):
pass
# Verify the error was logged correctly
assert mock_logger.error.called
args, _ = mock_logger.error.call_args
assert isinstance(args[2], Exception)

View File

@@ -0,0 +1,40 @@
from unittest.mock import MagicMock, patch
import zmq
from control_backend.main import setup_sockets
def test_setup_sockets_proxy():
mock_context = MagicMock()
mock_pub = MagicMock()
mock_sub = MagicMock()
mock_context.socket.side_effect = [mock_pub, mock_sub]
with patch("zmq.asyncio.Context.instance", return_value=mock_context):
with patch("zmq.proxy") as mock_proxy:
setup_sockets()
mock_pub.bind.assert_called()
mock_sub.bind.assert_called()
mock_proxy.assert_called_with(mock_sub, mock_pub)
# Check cleanup
mock_pub.close.assert_called()
mock_sub.close.assert_called()
def test_setup_sockets_proxy_error():
mock_context = MagicMock()
mock_pub = MagicMock()
mock_sub = MagicMock()
mock_context.socket.side_effect = [mock_pub, mock_sub]
with patch("zmq.asyncio.Context.instance", return_value=mock_context):
with patch("zmq.proxy", side_effect=zmq.ZMQError):
with patch("control_backend.main.logger") as mock_logger:
setup_sockets()
mock_logger.warning.assert_called()
mock_pub.close.assert_called()
mock_sub.close.assert_called()

23
uv.lock generated
View File

@@ -997,6 +997,7 @@ dependencies = [
{ name = "pydantic" },
{ name = "pydantic-settings" },
{ name = "python-json-logger" },
{ name = "python-slugify" },
{ name = "pyyaml" },
{ name = "pyzmq" },
{ name = "silero-vad" },
@@ -1046,6 +1047,7 @@ requires-dist = [
{ name = "pydantic", specifier = ">=2.12.0" },
{ name = "pydantic-settings", specifier = ">=2.11.0" },
{ name = "python-json-logger", specifier = ">=4.0.0" },
{ name = "python-slugify", specifier = ">=8.0.4" },
{ name = "pyyaml", specifier = ">=6.0.3" },
{ name = "pyzmq", specifier = ">=27.1.0" },
{ name = "silero-vad", specifier = ">=6.0.0" },
@@ -1341,6 +1343,18 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/45/58/38b5afbc1a800eeea951b9285d3912613f2603bdf897a4ab0f4bd7f405fc/python_multipart-0.0.20-py3-none-any.whl", hash = "sha256:8a62d3a8335e06589fe01f2a3e178cdcc632f3fbe0d492ad9ee0ec35aab1f104", size = 24546, upload-time = "2024-12-16T19:45:44.423Z" },
]
[[package]]
name = "python-slugify"
version = "8.0.4"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "text-unidecode" },
]
sdist = { url = "https://files.pythonhosted.org/packages/87/c7/5e1547c44e31da50a460df93af11a535ace568ef89d7a811069ead340c4a/python-slugify-8.0.4.tar.gz", hash = "sha256:59202371d1d05b54a9e7720c5e038f928f45daaffe41dd10822f3907b937c856", size = 10921, upload-time = "2024-02-08T18:32:45.488Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/a4/62/02da182e544a51a5c3ccf4b03ab79df279f9c60c5e82d5e8bec7ca26ac11/python_slugify-8.0.4-py2.py3-none-any.whl", hash = "sha256:276540b79961052b66b7d116620b36518847f52d5fd9e3a70164fc8c50faa6b8", size = 10051, upload-time = "2024-02-08T18:32:43.911Z" },
]
[[package]]
name = "pyyaml"
version = "6.0.3"
@@ -1864,6 +1878,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/a2/09/77d55d46fd61b4a135c444fc97158ef34a095e5681d0a6c10b75bf356191/sympy-1.14.0-py3-none-any.whl", hash = "sha256:e091cc3e99d2141a0ba2847328f5479b05d94a6635cb96148ccb3f34671bd8f5", size = 6299353, upload-time = "2025-04-27T18:04:59.103Z" },
]
[[package]]
name = "text-unidecode"
version = "1.3"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/ab/e2/e9a00f0ccb71718418230718b3d900e71a5d16e701a3dae079a21e9cd8f8/text-unidecode-1.3.tar.gz", hash = "sha256:bad6603bb14d279193107714b288be206cac565dfa49aa5b105294dd5c4aab93", size = 76885, upload-time = "2019-08-30T21:36:45.405Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/a6/a5/c0b6468d3824fe3fde30dbb5e1f687b291608f9473681bbf7dabbf5a87d7/text_unidecode-1.3-py2.py3-none-any.whl", hash = "sha256:1311f10e8b895935241623731c2ba64f4c455287888b18189350b67134a822e8", size = 78154, upload-time = "2019-08-30T21:37:03.543Z" },
]
[[package]]
name = "tiktoken"
version = "0.12.0"