Compare commits

...

37 Commits

Author SHA1 Message Date
JobvAlewijk
128e9b3c00 chore: fixed sending stuff to ui 2026-01-06 14:52:53 +01:00
JobvAlewijk
a96e332d63 chore: cleanup 2026-01-06 14:12:11 +01:00
JobvAlewijk
955b3109bc Merge branch 'dev' of https://git.science.uu.nl/ics/sp/2025/n25b/pepperplus-cb into feat/10-basic-gestures 2026-01-06 14:08:55 +01:00
Björn Otgaar
612a96940d Merge branch 'feat/environment-variables' into 'dev'
Docs for environment variables, parameterize some constants

See merge request ics/sp/2025/n25b/pepperplus-cb!38
2026-01-06 09:02:49 +00:00
Pim Hutting
4c20656c75 Merge branch 'feat/program-reset-llm' into 'dev'
feat: made program reset LLM

See merge request ics/sp/2025/n25b/pepperplus-cb!39
2026-01-02 15:13:05 +00:00
Pim Hutting
6ca86e4b81 feat: made program reset LLM 2026-01-02 15:13:04 +00:00
JobvAlewijk
aa6a90f4e1 Merge branch 'dev' of https://git.science.uu.nl/ics/sp/2025/n25b/pepperplus-cb into feat/10-basic-gestures 2025-12-29 19:43:03 +01:00
JobvAlewijk
3571bd614f feat: single gestures are forwarded properly to ui
ref: N25B-399
2025-12-29 19:23:10 +01:00
JobvAlewijk
8cfd59c14b feat: added way to communicate 10 basic gestures
ref: N25B-399
2025-12-29 16:00:25 +01:00
Twirre Meulenbelt
7d798f2e77 Merge remote-tracking branch 'origin/dev' into feat/environment-variables
# Conflicts:
#	src/control_backend/core/config.py
#	test/unit/agents/actuation/test_robot_speech_agent.py
2025-12-29 12:40:16 +01:00
Twirre Meulenbelt
5282c2471f Merge remote-tracking branch 'origin/dev' into feat/environment-variables
# Conflicts:
#	src/control_backend/core/config.py
#	test/unit/agents/actuation/test_robot_speech_agent.py
2025-12-29 12:35:39 +01:00
Luijkx,S.O.H. (Storm)
adbb7ffd5c Merge branch 'feat/user-interrupt-agent' into 'dev'
create UserInterruptAgent with connection to UI

See merge request ics/sp/2025/n25b/pepperplus-cb!40
2025-12-22 13:56:03 +00:00
Pim Hutting
0501a9fba3 create UserInterruptAgent with connection to UI 2025-12-22 13:56:02 +00:00
JobvAlewijk
3e7f2ef574 Merge branch 'feat/quiet-llm' into 'dev'
feat: implemented extra log level for LLM token stream

See merge request ics/sp/2025/n25b/pepperplus-cb!37
2025-12-16 11:26:37 +00:00
Luijkx,S.O.H. (Storm)
78abad55d3 feat: implemented extra log level for LLM token stream 2025-12-16 11:26:35 +00:00
Twirre
4ab6b2a0e6 Merge branch 'feat/cb2ri-gestures' into 'dev'
Gestures in the CB.

See merge request ics/sp/2025/n25b/pepperplus-cb!36
2025-12-16 09:24:25 +00:00
Twirre Meulenbelt
db5504db20 chore: remove redundant check 2025-12-16 10:22:11 +01:00
Björn Otgaar
f15a518984 fix: tests
ref: N25B-334
2025-12-15 11:52:01 +01:00
Björn Otgaar
71d86f5fb0 Merge branch 'feat/cb2ri-gestures' of git.science.uu.nl:ics/sp/2025/n25b/pepperplus-cb into feat/cb2ri-gestures 2025-12-15 11:36:12 +01:00
Björn Otgaar
daf31ac6a6 fix: change the address to the config, update some logic, seperate the api endpoint, renaming things. yes, the tests don't work right now- this shouldn't be merged yet.
ref: N25B-334
2025-12-15 11:35:56 +01:00
Björn Otgaar
b2d014753d Apply 1 suggestion(s) to 1 file(s)
Co-authored-by: Pim Hutting <p.r.p.hutting@students.uu.nl>
2025-12-11 15:08:15 +00:00
Twirre Meulenbelt
0c682d6440 feat: introduce .env.example, docs
The example includes options that are expected to be changed. It also includes a reference to where in the docs you can find a full list of options.

ref: N25B-352
2025-12-11 13:35:19 +01:00
Björn Otgaar
2e472ea292 chore: remove wrong test paths 2025-12-11 12:48:18 +01:00
Björn Otgaar
1c9b722ba3 Merge branch 'dev' into feat/cb2ri-gestures 2025-12-11 12:46:32 +01:00
Twirre Meulenbelt
32d8f20dc9 feat: parameterize RI host
Was "localhost" in RI Communication Agent, now uses configurable setting. Secretly also removing "localhost" from VAD agent, as its socket should be something that's "inproc".

ref: N25B-352
2025-12-11 12:12:15 +01:00
Twirre Meulenbelt
9cc0e39955 fix: failures main tests since VAD agent initialization was changed
The test still expects the VAD agent to be started in main, rather than in the RI Communication Agent.

ref: N25B-356
2025-12-11 12:04:24 +01:00
Björn Otgaar
2366255b92 Merge branch 'fix/correct-vad-starting' into 'dev'
Move VAD agent creation to RI communication agent

See merge request ics/sp/2025/n25b/pepperplus-cb!34
2025-12-09 14:57:09 +00:00
Björn Otgaar
7f34fede81 fix: fix the tests
ref: N25B-334
2025-12-09 15:37:00 +01:00
Luijkx,S.O.H. (Storm)
a9255cb6e7 Merge branch 'test/coverage-max-all' into 'dev'
test: increased cb test coverage

See merge request ics/sp/2025/n25b/pepperplus-cb!32
2025-12-09 13:14:03 +00:00
JobvAlewijk
7f7c658901 test: increased cb test coverage 2025-12-09 13:14:02 +00:00
Björn Otgaar
3d62e7fc0c Merge branch 'feat/cb2ri-gestures' of git.science.uu.nl:ics/sp/2025/n25b/pepperplus-cb into feat/cb2ri-gestures 2025-12-09 14:13:01 +01:00
Björn Otgaar
6034263259 fix: correct the gestures bugs, change gestures socket to request/reply
ref: N25B-334
2025-12-09 14:08:59 +01:00
JobvAlewijk
63897f5969 chore: double tag 2025-12-09 12:33:43 +01:00
JobvAlewijk
a3cf389c05 Merge branch 'dev' of https://git.science.uu.nl/ics/sp/2025/n25b/pepperplus-cb into fix/correct-vad-starting 2025-12-07 23:09:31 +01:00
JobvAlewijk
de2e56ffce Merge branch 'fix/fix-socket-typing' into 'dev'
chore: fix socket typing in robot speech agent

See merge request ics/sp/2025/n25b/pepperplus-cb!33
2025-12-03 14:26:46 +00:00
Twirre Meulenbelt
21e9d05d6e fix: move VAD agent creation to RI communication agent
Previously, it was started in main, but it should use values negotiated by the RI communication agent.

ref: N25B-356
2025-12-03 15:07:29 +01:00
Björn Otgaar
bacc63aa31 chore: fix socket typing in robot speech agent 2025-12-02 14:22:39 +01:00
42 changed files with 1980 additions and 686 deletions

20
.env.example Normal file
View File

@@ -0,0 +1,20 @@
# Example .env file. To use, make a copy, call it ".env" (i.e. removing the ".example" suffix), then you edit values.
# The hostname of the Robot Interface. Change if the Control Backend and Robot Interface are running on different computers.
RI_HOST="localhost"
# URL for the local LLM API. Must be an API that implements the OpenAI Chat Completions API, but most do.
LLM_SETTINGS__LOCAL_LLM_URL="http://localhost:1234/v1/chat/completions"
# Name of the local LLM model to use.
LLM_SETTINGS__LOCAL_LLM_MODEL="gpt-oss"
# Number of non-speech chunks to wait before speech ended. A chunk is approximately 31 ms. Increasing this number allows longer pauses in speech, but also increases response time.
BEHAVIOUR_SETTINGS__VAD_NON_SPEECH_PATIENCE_CHUNKS=3
# Timeout in milliseconds for socket polling. Increase this number if network latency/jitter is high, often the case when using Wi-Fi. Perhaps 500 ms. A symptom of this issue is transcriptions getting cut off.
BEHAVIOUR_SETTINGS__SOCKET_POLLER_TIMEOUT_MS=100
# For an exhaustive list of options, see the control_backend.core.config module in the docs.

View File

@@ -0,0 +1,9 @@
%{first_multiline_commit_description}
To verify:
- [ ] Style checks pass
- [ ] Pipeline (tests) pass
- [ ] Documentation is up to date
- [ ] Tests are up to date (new code is covered)
- [ ] ...

View File

@@ -3,6 +3,7 @@ version: 1
custom_levels: custom_levels:
OBSERVATION: 25 OBSERVATION: 25
ACTION: 26 ACTION: 26
LLM: 9
formatters: formatters:
# Console output # Console output
@@ -26,7 +27,7 @@ handlers:
stream: ext://sys.stdout stream: ext://sys.stdout
ui: ui:
class: zmq.log.handlers.PUBHandler class: zmq.log.handlers.PUBHandler
level: DEBUG level: LLM
formatter: json_experiment formatter: json_experiment
# Level of external libraries # Level of external libraries
@@ -36,5 +37,5 @@ root:
loggers: loggers:
control_backend: control_backend:
level: DEBUG level: LLM
handlers: [ui] handlers: [ui]

View File

@@ -27,6 +27,7 @@ This + part might differ based on what model you choose.
copy the model name in the module loaded and replace local_llm_modelL. In settings. copy the model name in the module loaded and replace local_llm_modelL. In settings.
## Running ## Running
To run the project (development server), execute the following command (while inside the root repository): To run the project (development server), execute the following command (while inside the root repository):
@@ -34,6 +35,14 @@ To run the project (development server), execute the following command (while in
uv run fastapi dev src/control_backend/main.py uv run fastapi dev src/control_backend/main.py
``` ```
### Environment Variables
You can use environment variables to change settings. Make a copy of the [`.env.example`](.env.example) file, name it `.env` and put it in the root directory. The file itself describes how to do the configuration.
For an exhaustive list of environment options, see the `control_backend.core.config` module in the docs.
## Testing ## Testing
Testing happens automatically when opening a merge request to any branch. If you want to manually run the test suite, you can do so by running the following for unit tests: Testing happens automatically when opening a merge request to any branch. If you want to manually run the test suite, you can do so by running the following for unit tests:

View File

@@ -12,33 +12,39 @@ from control_backend.schemas.ri_message import GestureCommand, RIEndpoint
class RobotGestureAgent(BaseAgent): class RobotGestureAgent(BaseAgent):
""" """
This agent acts as a bridge between the control backend and the Robot Interface (RI). This agent acts as a bridge between the control backend and the Robot Interface (RI).
It receives speech commands from other agents or from the UI, It receives gesture commands from other agents or from the UI,
and forwards them to the robot via a ZMQ PUB socket. and forwards them to the robot via a ZMQ PUB socket.
:ivar subsocket: ZMQ SUB socket for receiving external commands (e.g., from UI). :ivar subsocket: ZMQ SUB socket for receiving external commands (e.g., from UI).
:ivar pubsocket: ZMQ PUB socket for sending commands to the Robot Interface. :ivar pubsocket: ZMQ PUB socket for sending commands to the Robot Interface.
:ivar address: Address to bind/connect the PUB socket. :ivar address: Address to bind/connect the PUB socket.
:ivar bind: Whether to bind or connect the PUB socket. :ivar bind: Whether to bind or connect the PUB socket.
:ivar gesture_data: A list of strings for available gestures :ivar gesture_tags: A list of strings for available gesture tags
:ivar gesture_basic: A list of strings for 10 basisc gestures
:ivar gesture_single: A list of strings for all available gestures
""" """
subsocket: azmq.Socket subsocket: azmq.Socket
repsocket: azmq.Socket
pubsocket: azmq.Socket pubsocket: azmq.Socket
address = "" address = ""
bind = False bind = False
gesture_data = [] gesture_tags = []
gesture_basic = []
gesture_single = []
def __init__( def __init__(
self, self,
name: str, name: str,
address=settings.zmq_settings.ri_command_address, address: str,
bind=False, bind=False,
gesture_data=None, gesture_tags=None,
gesture_basic=None,
gesture_single=None,
): ):
if gesture_data is None: self.gesture_tags = gesture_tags or []
self.gesture_data = [] self.gesture_basic = gesture_basic or []
else: self.gesture_single = gesture_single or []
self.gesture_data = gesture_data
super().__init__(name) super().__init__(name)
self.address = address self.address = address
self.bind = bind self.bind = bind
@@ -56,9 +62,8 @@ class RobotGestureAgent(BaseAgent):
context = azmq.Context.instance() context = azmq.Context.instance()
# To the robot # To the robot
self.pubsocket = context.socket(zmq.PUB) self.pubsocket = context.socket(zmq.PUB)
if self.bind: # TODO: Should this ever be the case? if self.bind:
self.pubsocket.bind(self.address) self.pubsocket.bind(self.address)
else: else:
self.pubsocket.connect(self.address) self.pubsocket.connect(self.address)
@@ -69,6 +74,10 @@ class RobotGestureAgent(BaseAgent):
self.subsocket.setsockopt(zmq.SUBSCRIBE, b"command") self.subsocket.setsockopt(zmq.SUBSCRIBE, b"command")
self.subsocket.setsockopt(zmq.SUBSCRIBE, b"send_gestures") self.subsocket.setsockopt(zmq.SUBSCRIBE, b"send_gestures")
# REP socket for replying to gesture requests
self.repsocket = context.socket(zmq.REP)
self.repsocket.bind(settings.zmq_settings.internal_gesture_rep_adress)
self.add_behavior(self._zmq_command_loop()) self.add_behavior(self._zmq_command_loop())
self.add_behavior(self._fetch_gestures_loop()) self.add_behavior(self._fetch_gestures_loop())
@@ -92,13 +101,19 @@ class RobotGestureAgent(BaseAgent):
try: try:
gesture_command = GestureCommand.model_validate_json(msg.body) gesture_command = GestureCommand.model_validate_json(msg.body)
if gesture_command.endpoint == RIEndpoint.GESTURE_TAG: if gesture_command.endpoint == RIEndpoint.GESTURE_TAG:
if gesture_command.data not in self.availableTags(): if gesture_command.data not in self.gesture_tags:
self.logger.warning( self.logger.warning(
"Received gesture tag '%s' which is not in available tags. Early returning", "Received gesture tag '%s' which is not in available tags. Early returning",
gesture_command.data, gesture_command.data,
) )
return return
elif gesture_command.endpoint == RIEndpoint.GESTURE_SINGLE:
if gesture_command.data not in self.gesture_single:
self.logger.warning(
"Received gesture '%s' which is not in available gestures. Early returning",
gesture_command.data,
)
return
await self.pubsocket.send_json(gesture_command.model_dump()) await self.pubsocket.send_json(gesture_command.model_dump())
except Exception: except Exception:
self.logger.exception("Error processing internal message.") self.logger.exception("Error processing internal message.")
@@ -120,7 +135,7 @@ class RobotGestureAgent(BaseAgent):
body = json.loads(body) body = json.loads(body)
gesture_command = GestureCommand.model_validate(body) gesture_command = GestureCommand.model_validate(body)
if gesture_command.endpoint == RIEndpoint.GESTURE_TAG: if gesture_command.endpoint == RIEndpoint.GESTURE_TAG:
if gesture_command.data not in self.availableTags(): if gesture_command.data not in (self.gesture_tags or self.gesture_single):
self.logger.warning( self.logger.warning(
"Received gesture tag '%s' which is not in available tags.\ "Received gesture tag '%s' which is not in available tags.\
Early returning", Early returning",
@@ -133,163 +148,39 @@ class RobotGestureAgent(BaseAgent):
async def _fetch_gestures_loop(self): async def _fetch_gestures_loop(self):
""" """
Loop to handle fetching gestures received via ZMQ (e.g., from the UI). REP socket handler for gesture queries.
Supports:
Listens on the 'send_gestures' topic, and returns a list on the get_gestures topic. - tags
- basic_gestures
- single_gestures
""" """
while self._running: while self._running:
try: try:
topic, body = await self.subsocket.recv_multipart() req = await self.repsocket.recv_json()
# Don't process commands here req_type = req.get("type")
if topic != b"send_gestures": amount = req.get("count")
if req_type == "tags":
data = self.gesture_tags
key = "tags"
elif req_type == "basic":
data = self.gesture_basic
key = "basic_gestures"
elif req_type == "single":
data = self.gesture_single
key = "single_gestures"
else:
await self.repsocket.send_json({})
continue continue
try: if amount:
body = json.loads(body) data = data[:amount]
except json.JSONDecodeError:
body = None
# We could have the body be the nummer of gestures you want to fetch or something. await self.repsocket.send_json({key: data})
amount = None
if isinstance(body, int):
amount = body
tags = self.availableTags()[:amount] if amount else self.availableTags()
response = json.dumps({"tags": tags}).encode()
await self.pubsocket.send_multipart(
[
b"get_gestures",
response,
]
)
except Exception: except Exception:
self.logger.exception("Error fetching gesture tags.") self.logger.exception("Error fetching gestures.")
def availableTags(self):
"""
Returns the available gesture tags.
:return: List of available gesture tags.
"""
return [
"above",
"affirmative",
"afford",
"agitated",
"all",
"allright",
"alright",
"any",
"assuage",
"assuage",
"attemper",
"back",
"bashful",
"beg",
"beseech",
"blank",
"body language",
"bored",
"bow",
"but",
"call",
"calm",
"choose",
"choice",
"cloud",
"cogitate",
"cool",
"crazy",
"disappointed",
"down",
"earth",
"empty",
"embarrassed",
"enthusiastic",
"entire",
"estimate",
"except",
"exalted",
"excited",
"explain",
"far",
"field",
"floor",
"forlorn",
"friendly",
"front",
"frustrated",
"gentle",
"gift",
"give",
"ground",
"happy",
"hello",
"her",
"here",
"hey",
"hi",
"him",
"hopeless",
"hysterical",
"I",
"implore",
"indicate",
"joyful",
"me",
"meditate",
"modest",
"negative",
"nervous",
"no",
"not know",
"nothing",
"offer",
"ok",
"once upon a time",
"oppose",
"or",
"pacify",
"pick",
"placate",
"please",
"present",
"proffer",
"quiet",
"reason",
"refute",
"reject",
"rousing",
"sad",
"select",
"shamefaced",
"show",
"show sky",
"sky",
"soothe",
"sun",
"supplicate",
"tablet",
"tall",
"them",
"there",
"think",
"timid",
"top",
"unless",
"up",
"upstairs",
"void",
"warm",
"winner",
"yeah",
"yes",
"yoo-hoo",
"you",
"your",
"zero",
"zestful",
]

View File

@@ -29,7 +29,7 @@ class RobotSpeechAgent(BaseAgent):
def __init__( def __init__(
self, self,
name: str, name: str,
address=settings.zmq_settings.ri_command_address, address: str,
bind=False, bind=False,
): ):
super().__init__(name) super().__init__(name)

View File

@@ -60,24 +60,41 @@ class BDIProgramManager(BaseAgent):
await self.send(message) await self.send(message)
self.logger.debug("Sent new norms and goals to the BDI agent.") self.logger.debug("Sent new norms and goals to the BDI agent.")
async def _send_clear_llm_history(self):
"""
Clear the LLM Agent's conversation history.
Sends an empty history to the LLM Agent to reset its state.
"""
message = InternalMessage(
to=settings.agent_settings.llm_name,
sender=self.name,
body="clear_history",
threads="clear history message",
)
await self.send(message)
self.logger.debug("Sent message to LLM agent to clear history.")
async def _receive_programs(self): async def _receive_programs(self):
""" """
Continuous loop that receives program updates from the HTTP endpoint. Continuous loop that receives program updates from the HTTP endpoint.
It listens to the ``program`` topic on the internal ZMQ SUB socket. It listens to the ``program`` topic on the internal ZMQ SUB socket.
When a program is received, it is validated and forwarded to BDI via :meth:`_send_to_bdi`. When a program is received, it is validated and forwarded to BDI via :meth:`_send_to_bdi`.
Additionally, the LLM history is cleared via :meth:`_send_clear_llm_history`.
""" """
while True: while True:
topic, body = await self.sub_socket.recv_multipart() topic, body = await self.sub_socket.recv_multipart()
try: try:
program = Program.model_validate_json(body) program = Program.model_validate_json(body)
await self._send_to_bdi(program)
await self._send_clear_llm_history()
except ValidationError: except ValidationError:
self.logger.exception("Received an invalid program.") self.logger.exception("Received an invalid program.")
continue continue
await self._send_to_bdi(program)
async def setup(self): async def setup(self):
""" """
Initialize the agent. Initialize the agent.

View File

@@ -10,6 +10,7 @@ from control_backend.agents.actuation.robot_gesture_agent import RobotGestureAge
from control_backend.core.config import settings from control_backend.core.config import settings
from ..actuation.robot_speech_agent import RobotSpeechAgent from ..actuation.robot_speech_agent import RobotSpeechAgent
from ..perception import VADAgent
class RICommunicationAgent(BaseAgent): class RICommunicationAgent(BaseAgent):
@@ -37,7 +38,7 @@ class RICommunicationAgent(BaseAgent):
def __init__( def __init__(
self, self,
name: str, name: str,
address=settings.zmq_settings.ri_command_address, address=settings.zmq_settings.ri_communication_address,
bind=False, bind=False,
): ):
super().__init__(name) super().__init__(name)
@@ -167,7 +168,7 @@ class RICommunicationAgent(BaseAgent):
bind = port_data["bind"] bind = port_data["bind"]
if not bind: if not bind:
addr = f"tcp://localhost:{port}" addr = f"tcp://{settings.ri_host}:{port}"
else: else:
addr = f"tcp://*:{port}" addr = f"tcp://*:{port}"
@@ -180,7 +181,9 @@ class RICommunicationAgent(BaseAgent):
else: else:
self._req_socket.bind(addr) self._req_socket.bind(addr)
case "actuation": case "actuation":
gesture_data = port_data.get("gestures", []) gesture_tags = port_data.get("gestures", [])
gesture_single = port_data.get("single_gestures", [])
gesture_basic = port_data.get("basic_gestures", [])
robot_speech_agent = RobotSpeechAgent( robot_speech_agent = RobotSpeechAgent(
settings.agent_settings.robot_speech_name, settings.agent_settings.robot_speech_name,
address=addr, address=addr,
@@ -190,11 +193,16 @@ class RICommunicationAgent(BaseAgent):
settings.agent_settings.robot_gesture_name, settings.agent_settings.robot_gesture_name,
address=addr, address=addr,
bind=bind, bind=bind,
gesture_data=gesture_data, gesture_tags=gesture_tags,
gesture_basic=gesture_basic,
gesture_single=gesture_single,
) )
await robot_speech_agent.start() await robot_speech_agent.start()
await asyncio.sleep(0.1) # Small delay await asyncio.sleep(0.1) # Small delay
await robot_gesture_agent.start() await robot_gesture_agent.start()
case "audio":
vad_agent = VADAgent(audio_in_address=addr, audio_in_bind=bind)
await vad_agent.start()
case _: case _:
self.logger.warning("Unhandled negotiation id: %s", id) self.logger.warning("Unhandled negotiation id: %s", id)

View File

@@ -52,6 +52,10 @@ class LLMAgent(BaseAgent):
await self._process_bdi_message(prompt_message) await self._process_bdi_message(prompt_message)
except ValidationError: except ValidationError:
self.logger.debug("Prompt message from BDI core is invalid.") self.logger.debug("Prompt message from BDI core is invalid.")
elif msg.sender == settings.agent_settings.bdi_program_manager_name:
if msg.body == "clear_history":
self.logger.debug("Clearing conversation history.")
self.history.clear()
else: else:
self.logger.debug("Message ignored (not from BDI core.") self.logger.debug("Message ignored (not from BDI core.")
@@ -125,7 +129,7 @@ class LLMAgent(BaseAgent):
full_message += token full_message += token
current_chunk += token current_chunk += token
self.logger.info( self.logger.llm(
"Received token: %s", "Received token: %s",
full_message, full_message,
extra={"reference": message_id}, # Used in the UI to update old logs extra={"reference": message_id}, # Used in the UI to update old logs

View File

@@ -8,6 +8,7 @@ import zmq.asyncio as azmq
from control_backend.agents import BaseAgent from control_backend.agents import BaseAgent
from control_backend.core.config import settings from control_backend.core.config import settings
from ...schemas.program_status import PROGRAM_STATUS, ProgramStatus
from .transcription_agent.transcription_agent import TranscriptionAgent from .transcription_agent.transcription_agent import TranscriptionAgent
@@ -61,6 +62,7 @@ class VADAgent(BaseAgent):
:ivar audio_in_address: Address of the input audio stream. :ivar audio_in_address: Address of the input audio stream.
:ivar audio_in_bind: Whether to bind or connect to the input address. :ivar audio_in_bind: Whether to bind or connect to the input address.
:ivar audio_out_socket: ZMQ PUB socket for sending speech fragments. :ivar audio_out_socket: ZMQ PUB socket for sending speech fragments.
:ivar program_sub_socket: ZMQ SUB socket for receiving program status updates.
""" """
def __init__(self, audio_in_address: str, audio_in_bind: bool): def __init__(self, audio_in_address: str, audio_in_bind: bool):
@@ -79,6 +81,8 @@ class VADAgent(BaseAgent):
self.audio_out_socket: azmq.Socket | None = None self.audio_out_socket: azmq.Socket | None = None
self.audio_in_poller: SocketPoller | None = None self.audio_in_poller: SocketPoller | None = None
self.program_sub_socket: azmq.Socket | None = None
self.audio_buffer = np.array([], dtype=np.float32) self.audio_buffer = np.array([], dtype=np.float32)
self.i_since_speech = settings.behaviour_settings.vad_initial_since_speech self.i_since_speech = settings.behaviour_settings.vad_initial_since_speech
self._ready = asyncio.Event() self._ready = asyncio.Event()
@@ -90,20 +94,25 @@ class VADAgent(BaseAgent):
1. Connects audio input socket. 1. Connects audio input socket.
2. Binds audio output socket (random port). 2. Binds audio output socket (random port).
3. Loads VAD model from Torch Hub. 3. Connects to program communication socket.
4. Starts the streaming loop. 4. Loads VAD model from Torch Hub.
5. Instantiates and starts the :class:`TranscriptionAgent` with the output address. 5. Starts the streaming loop.
6. Instantiates and starts the :class:`TranscriptionAgent` with the output address.
""" """
self.logger.info("Setting up %s", self.name) self.logger.info("Setting up %s", self.name)
self._connect_audio_in_socket() self._connect_audio_in_socket()
audio_out_port = self._connect_audio_out_socket() audio_out_address = self._connect_audio_out_socket()
if audio_out_port is None: if audio_out_address is None:
self.logger.error("Could not bind output socket, stopping.") self.logger.error("Could not bind output socket, stopping.")
await self.stop() await self.stop()
return return
audio_out_address = f"tcp://localhost:{audio_out_port}"
# Connect to internal communication socket
self.program_sub_socket = azmq.Context.instance().socket(zmq.SUB)
self.program_sub_socket.connect(settings.zmq_settings.internal_sub_address)
self.program_sub_socket.subscribe(PROGRAM_STATUS)
# Initialize VAD model # Initialize VAD model
try: try:
@@ -117,10 +126,8 @@ class VADAgent(BaseAgent):
await self.stop() await self.stop()
return return
# Warmup/reset
await self.reset_stream()
self.add_behavior(self._streaming_loop()) self.add_behavior(self._streaming_loop())
self.add_behavior(self._status_loop())
# Start agents dependent on the output audio fragments here # Start agents dependent on the output audio fragments here
transcriber = TranscriptionAgent(audio_out_address) transcriber = TranscriptionAgent(audio_out_address)
@@ -153,19 +160,20 @@ class VADAgent(BaseAgent):
self.audio_in_socket.connect(self.audio_in_address) self.audio_in_socket.connect(self.audio_in_address)
self.audio_in_poller = SocketPoller[bytes](self.audio_in_socket) self.audio_in_poller = SocketPoller[bytes](self.audio_in_socket)
def _connect_audio_out_socket(self) -> int | None: def _connect_audio_out_socket(self) -> str | None:
""" """
Returns the port bound, or None if binding failed. Returns the address that was bound to, or None if binding failed.
""" """
try: try:
self.audio_out_socket = azmq.Context.instance().socket(zmq.PUB) self.audio_out_socket = azmq.Context.instance().socket(zmq.PUB)
return self.audio_out_socket.bind_to_random_port("tcp://localhost", max_tries=100) self.audio_out_socket.bind(settings.zmq_settings.vad_pub_address)
return settings.zmq_settings.vad_pub_address
except zmq.ZMQBindError: except zmq.ZMQBindError:
self.logger.error("Failed to bind an audio output socket after 100 tries.") self.logger.error("Failed to bind an audio output socket after 100 tries.")
self.audio_out_socket = None self.audio_out_socket = None
return None return None
async def reset_stream(self): async def _reset_stream(self):
""" """
Clears the ZeroMQ queue and sets ready state. Clears the ZeroMQ queue and sets ready state.
""" """
@@ -176,6 +184,23 @@ class VADAgent(BaseAgent):
self.logger.info(f"Discarded {discarded} audio packets before starting.") self.logger.info(f"Discarded {discarded} audio packets before starting.")
self._ready.set() self._ready.set()
async def _status_loop(self):
"""Loop for checking program status. Only start listening if program is RUNNING."""
while self._running:
topic, body = await self.program_sub_socket.recv_multipart()
if topic != PROGRAM_STATUS:
continue
if body != ProgramStatus.RUNNING.value:
continue
# Program is now running, we can start our stream
await self._reset_stream()
# We don't care about further status updates
self.program_sub_socket.close()
break
async def _streaming_loop(self): async def _streaming_loop(self):
""" """
Main loop for processing audio stream. Main loop for processing audio stream.

View File

@@ -0,0 +1,146 @@
import json
import zmq
from zmq.asyncio import Context
from control_backend.agents import BaseAgent
from control_backend.core.agent_system import InternalMessage
from control_backend.core.config import settings
from control_backend.schemas.ri_message import GestureCommand, RIEndpoint, SpeechCommand
class UserInterruptAgent(BaseAgent):
"""
User Interrupt Agent.
This agent receives button_pressed events from the external HTTP API
(via ZMQ) and uses the associated context to trigger one of the following actions:
- Send a prioritized message to the `RobotSpeechAgent`
- Send a prioritized gesture to the `RobotGestureAgent`
- Send a belief override to the `BDIProgramManager`in order to activate a
trigger/conditional norm or complete a goal.
Prioritized actions clear the current RI queue before inserting the new item,
ensuring they are executed immediately after Pepper's current action has been fulfilled.
:ivar sub_socket: The ZMQ SUB socket used to receive user intterupts.
"""
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.sub_socket = None
async def _receive_button_event(self):
"""
The behaviour of the UserInterruptAgent.
Continuous loop that receives button_pressed events from the button_pressed HTTP endpoint.
These events contain a type and a context.
These are the different types and contexts:
- type: "speech", context: string that the robot has to say.
- type: "gesture", context: single gesture name that the robot has to perform.
- type: "override", context: belief_id that overrides the goal/trigger/conditional norm.
"""
while True:
topic, body = await self.sub_socket.recv_multipart()
try:
event_data = json.loads(body)
event_type = event_data.get("type") # e.g., "speech", "gesture"
event_context = event_data.get("context") # e.g., "Hello, I am Pepper!"
except json.JSONDecodeError:
self.logger.error("Received invalid JSON payload on topic %s", topic)
continue
if event_type == "speech":
await self._send_to_speech_agent(event_context)
self.logger.info(
"Forwarded button press (speech) with context '%s' to RobotSpeechAgent.",
event_context,
)
elif event_type == "gesture":
await self._send_to_gesture_agent(event_context)
self.logger.info(
"Forwarded button press (gesture) with context '%s' to RobotGestureAgent.",
event_context,
)
elif event_type == "override":
await self._send_to_program_manager(event_context)
self.logger.info(
"Forwarded button press (override) with context '%s' to BDIProgramManager.",
event_context,
)
else:
self.logger.warning(
"Received button press with unknown type '%s' (context: '%s').",
event_type,
event_context,
)
async def _send_to_speech_agent(self, text_to_say: str):
"""
method to send prioritized speech command to RobotSpeechAgent.
:param text_to_say: The string that the robot has to say.
"""
cmd = SpeechCommand(data=text_to_say, is_priority=True)
out_msg = InternalMessage(
to=settings.agent_settings.robot_speech_name,
sender=self.name,
body=cmd.model_dump_json(),
)
await self.send(out_msg)
async def _send_to_gesture_agent(self, single_gesture_name: str):
"""
method to send prioritized gesture command to RobotGestureAgent.
:param single_gesture_name: The gesture tag that the robot has to perform.
"""
# the endpoint is set to always be GESTURE_SINGLE for user interrupts
cmd = GestureCommand(
endpoint=RIEndpoint.GESTURE_SINGLE, data=single_gesture_name, is_priority=True
)
out_msg = InternalMessage(
to=settings.agent_settings.robot_gesture_name,
sender=self.name,
body=cmd.model_dump_json(),
)
await self.send(out_msg)
async def _send_to_program_manager(self, belief_id: str):
"""
Send a button_override belief to the BDIProgramManager.
:param belief_id: The belief_id that overrides the goal/trigger/conditional norm.
this id can belong to a basic belief or an inferred belief.
See also: https://utrechtuniversity.youtrack.cloud/articles/N25B-A-27/UI-components
"""
data = {"belief": belief_id}
message = InternalMessage(
to=settings.agent_settings.bdi_program_manager_name,
sender=self.name,
body=json.dumps(data),
thread="belief_override_id",
)
await self.send(message)
self.logger.info(
"Sent button_override belief with id '%s' to Program manager.",
belief_id,
)
async def setup(self):
"""
Initialize the agent.
Connects the internal ZMQ SUB socket and subscribes to the 'button_pressed' topic.
Starts the background behavior to receive the user interrupts.
"""
context = Context.instance()
self.sub_socket = context.socket(zmq.SUB)
self.sub_socket.connect(settings.zmq_settings.internal_sub_address)
self.sub_socket.subscribe("button_pressed")
self.add_behavior(self._receive_button_event())

View File

@@ -0,0 +1,31 @@
import logging
from fastapi import APIRouter, Request
from control_backend.schemas.events import ButtonPressedEvent
logger = logging.getLogger(__name__)
router = APIRouter()
@router.post("/button_pressed", status_code=202)
async def receive_button_event(event: ButtonPressedEvent, request: Request):
"""
Endpoint to handle external button press events.
Validates the event payload and publishes it to the internal 'button_pressed' topic.
Subscribers (in this case user_interrupt_agent) will pick this up to trigger
specific behaviors or state changes.
:param event: The parsed ButtonPressedEvent object.
:param request: The FastAPI request object.
"""
logger.debug("Received button event: %s | %s", event.type, event.context)
topic = b"button_pressed"
body = event.model_dump_json().encode()
pub_socket = request.app.state.endpoints_pub_socket
await pub_socket.send_multipart([topic, body])
return {"status": "Event received"}

View File

@@ -3,9 +3,8 @@ import json
import logging import logging
import zmq.asyncio import zmq.asyncio
from fastapi import APIRouter, HTTPException, Request from fastapi import APIRouter, Request
from fastapi.responses import StreamingResponse from fastapi.responses import StreamingResponse
from pydantic import ValidationError
from zmq.asyncio import Context, Socket from zmq.asyncio import Context, Socket
from control_backend.core.config import settings from control_backend.core.config import settings
@@ -16,38 +15,44 @@ logger = logging.getLogger(__name__)
router = APIRouter() router = APIRouter()
@router.post("/command", status_code=202) @router.post("/command/speech", status_code=202)
async def receive_command(command: SpeechCommand, request: Request): async def receive_command_speech(command: SpeechCommand, request: Request):
""" """
Send a direct speech command to the robot. Send a direct speech command to the robot.
Publishes the command to the internal 'command' topic. The Publishes the command to the internal 'command' topic. The
:class:`~control_backend.agents.actuation.robot_speech_agent.RobotSpeechAgent` :class:`~control_backend.agents.actuation.robot_speech_agent.RobotSpeechAgent`
or will forward this to the robot.
:param command: The speech command payload.
:param request: The FastAPI request object.
"""
topic = b"command"
pub_socket: Socket = request.app.state.endpoints_pub_socket
await pub_socket.send_multipart([topic, command.model_dump_json().encode()])
return {"status": "Speech command received"}
@router.post("/command/gesture", status_code=202)
async def receive_command_gesture(command: GestureCommand, request: Request):
"""
Send a direct gesture command to the robot.
Publishes the command to the internal 'command' topic. The
:class:`~control_backend.agents.actuation.robot_speech_agent.RobotGestureAgent` :class:`~control_backend.agents.actuation.robot_speech_agent.RobotGestureAgent`
will forward this to the robot. will forward this to the robot.
:param command: The speech command payload. :param command: The speech command payload.
:param request: The FastAPI request object. :param request: The FastAPI request object.
""" """
# Validate and retrieve data.
validated = None
valid_commands = (GestureCommand, SpeechCommand)
for command_model in valid_commands:
try:
validated = command_model.model_validate(command)
except ValidationError:
continue
if validated is None:
raise HTTPException(status_code=422, detail="Payload is not valid for command models")
topic = b"command" topic = b"command"
pub_socket: Socket = request.app.state.endpoints_pub_socket pub_socket: Socket = request.app.state.endpoints_pub_socket
await pub_socket.send_multipart([topic, validated.model_dump_json().encode()]) await pub_socket.send_multipart([topic, command.model_dump_json().encode()])
return {"status": "Command received"} return {"status": "Gesture command received"}
@router.get("/ping_check") @router.get("/ping_check")
@@ -58,31 +63,28 @@ async def ping(request: Request):
pass pass
@router.get("/get_available_gesture_tags") @router.get("/commands/gesture/tags")
async def get_available_gesture_tags(request: Request): async def get_available_gesture_tags(request: Request, count=0):
""" """
Endpoint to retrieve the available gesture tags for the robot. Endpoint to retrieve the available gesture tags for the robot.
:param request: The FastAPI request object. :param request: The FastAPI request object.
:return: A list of available gesture tags. :return: A list of available gesture tags.
""" """
sub_socket = Context.instance().socket(zmq.SUB) req_socket = Context.instance().socket(zmq.REQ)
sub_socket.connect(settings.zmq_settings.internal_sub_address) req_socket.connect(settings.zmq_settings.internal_gesture_rep_adress)
sub_socket.setsockopt(zmq.SUBSCRIBE, b"get_gestures")
pub_socket: Socket = request.app.state.endpoints_pub_socket # Check to see if we've got any count given in the query parameter
topic = b"send_gestures" amount = count or None
# TODO: Implement a way to get a certain ammount from the UI, rather than everything.
amount = None
timeout = 5 # seconds timeout = 5 # seconds
await pub_socket.send_multipart([topic, amount.to_bytes(4, "big") if amount else b""]) await req_socket.send_json({"type": "tags", "count": amount})
try: try:
_, body = await asyncio.wait_for(sub_socket.recv_multipart(), timeout=timeout) body = await asyncio.wait_for(req_socket.recv(), timeout=timeout)
except TimeoutError: except TimeoutError:
body = b"tags: []" body = '{"tags": []}'
logger.debug("got timeout error fetching gestures") logger.debug("Got timeout error fetching gestures.")
# Handle empty response and JSON decode errors # Handle empty response and JSON decode errors
available_tags = [] available_tags = []
@@ -93,8 +95,75 @@ async def get_available_gesture_tags(request: Request):
logger.error(f"Failed to parse gesture tags JSON: {e}, body: {body}") logger.error(f"Failed to parse gesture tags JSON: {e}, body: {body}")
# Return empty list on JSON error # Return empty list on JSON error
available_tags = [] available_tags = []
return {"available_gestures": available_tags}
return {"available_gesture_tags": available_tags}
@router.get("/commands/gesture/single")
async def get_available_gestures(request: Request, count=0):
"""
Endpoint to retrieve the available gestures for the robot.
:param request: The FastAPI request object.
:return: A list of available gestures.
"""
req_socket = Context.instance().socket(zmq.REQ)
req_socket.connect(settings.zmq_settings.internal_gesture_rep_adress)
# Check to see if we've got any count given in the query parameter
amount = count or None
timeout = 5 # seconds
await req_socket.send_json({"type": "single", "count": amount})
try:
body = await asyncio.wait_for(req_socket.recv(), timeout=timeout)
except TimeoutError:
body = '{"tags": []}'
logger.debug("Got timeout error fetching gestures.")
# Handle empty response and JSON decode errors
available_tags = []
if body:
try:
available_tags = json.loads(body).get("single_gestures", [])
except json.JSONDecodeError as e:
logger.error(f"Failed to parse gesture tags JSON: {e}, body: {body}")
# Return empty list on JSON error
available_tags = []
return {"available_gestures": available_tags}
@router.get("/commands/gesture/basic")
async def get_available_basic_gestures(request: Request, count=0):
"""
Endpoint to retrieve the available gesture tags for the robot.
:param request: The FastAPI request object.
:return: A list of 10 available gestures.
"""
req_socket = Context.instance().socket(zmq.REQ)
req_socket.connect(settings.zmq_settings.internal_gesture_rep_adress)
# Check to see if we've got any count given in the query parameter
amount = count or None
timeout = 5 # seconds
await req_socket.send_json({"type": "basic", "count": amount})
try:
body = await asyncio.wait_for(req_socket.recv(), timeout=timeout)
except TimeoutError:
body = '{"tags": []}'
logger.debug("Got timeout error fetching gestures.")
# Handle empty response and JSON decode errors
available_tags = []
if body:
try:
available_tags = json.loads(body).get("basic_gestures", [])
except json.JSONDecodeError as e:
logger.error(f"Failed to parse gesture tags JSON: {e}, body: {body}")
# Return empty list on JSON error
available_tags = []
return {"available_gestures": available_tags}
@router.get("/ping_stream") @router.get("/ping_stream")

View File

@@ -1,6 +1,6 @@
from fastapi.routing import APIRouter from fastapi.routing import APIRouter
from control_backend.api.v1.endpoints import logs, message, program, robot, sse from control_backend.api.v1.endpoints import button_pressed, logs, message, program, robot, sse
api_router = APIRouter() api_router = APIRouter()
@@ -13,3 +13,5 @@ api_router.include_router(robot.router, prefix="/robot", tags=["Pings", "Command
api_router.include_router(logs.router, tags=["Logs"]) api_router.include_router(logs.router, tags=["Logs"])
api_router.include_router(program.router, tags=["Program"]) api_router.include_router(program.router, tags=["Program"])
api_router.include_router(button_pressed.router, tags=["Button Pressed Events"])

View File

@@ -1,3 +1,12 @@
"""
An exhaustive overview of configurable options. All of these can be set using environment variables
by nesting with double underscores (__). Start from the ``Settings`` class.
For example, ``settings.ri_host`` becomes ``RI_HOST``, and
``settings.zmq_settings.ri_communication_address`` becomes
``ZMQ_SETTINGS__RI_COMMUNICATION_ADDRESS``.
"""
from pydantic import BaseModel from pydantic import BaseModel
from pydantic_settings import BaseSettings, SettingsConfigDict from pydantic_settings import BaseSettings, SettingsConfigDict
@@ -8,16 +17,17 @@ class ZMQSettings(BaseModel):
:ivar internal_pub_address: Address for the internal PUB socket. :ivar internal_pub_address: Address for the internal PUB socket.
:ivar internal_sub_address: Address for the internal SUB socket. :ivar internal_sub_address: Address for the internal SUB socket.
:ivar ri_command_address: Address for sending commands to the Robot Interface. :ivar ri_communication_address: Address for the endpoint that the Robot Interface connects to.
:ivar ri_communication_address: Address for receiving communication from the Robot Interface. :ivar vad_pub_address: Address that the VAD agent binds to and publishes audio segments to.
:ivar vad_agent_address: Address for the Voice Activity Detection (VAD) agent.
""" """
# ATTENTION: When adding/removing settings, make sure to update the .env.example file
internal_pub_address: str = "tcp://localhost:5560" internal_pub_address: str = "tcp://localhost:5560"
internal_sub_address: str = "tcp://localhost:5561" internal_sub_address: str = "tcp://localhost:5561"
ri_command_address: str = "tcp://localhost:0000"
ri_communication_address: str = "tcp://*:5555" ri_communication_address: str = "tcp://*:5555"
vad_agent_address: str = "tcp://localhost:5558" internal_gesture_rep_adress: str = "tcp://localhost:7788"
vad_pub_address: str = "inproc://vad_stream"
class AgentSettings(BaseModel): class AgentSettings(BaseModel):
@@ -36,6 +46,8 @@ class AgentSettings(BaseModel):
:ivar robot_speech_name: Name of the Robot Speech Agent. :ivar robot_speech_name: Name of the Robot Speech Agent.
""" """
# ATTENTION: When adding/removing settings, make sure to update the .env.example file
# agent names # agent names
bdi_core_name: str = "bdi_core_agent" bdi_core_name: str = "bdi_core_agent"
bdi_belief_collector_name: str = "belief_collector_agent" bdi_belief_collector_name: str = "belief_collector_agent"
@@ -48,6 +60,7 @@ class AgentSettings(BaseModel):
ri_communication_name: str = "ri_communication_agent" ri_communication_name: str = "ri_communication_agent"
robot_speech_name: str = "robot_speech_agent" robot_speech_name: str = "robot_speech_agent"
robot_gesture_name: str = "robot_gesture_agent" robot_gesture_name: str = "robot_gesture_agent"
user_interrupt_name: str = "user_interrupt_agent"
class BehaviourSettings(BaseModel): class BehaviourSettings(BaseModel):
@@ -66,6 +79,8 @@ class BehaviourSettings(BaseModel):
:ivar transcription_token_buffer: Buffer for transcription tokens. :ivar transcription_token_buffer: Buffer for transcription tokens.
""" """
# ATTENTION: When adding/removing settings, make sure to update the .env.example file
sleep_s: float = 1.0 sleep_s: float = 1.0
comm_setup_max_retries: int = 5 comm_setup_max_retries: int = 5
socket_poller_timeout_ms: int = 100 socket_poller_timeout_ms: int = 100
@@ -90,6 +105,8 @@ class LLMSettings(BaseModel):
:ivar local_llm_model: Name of the local LLM model to use. :ivar local_llm_model: Name of the local LLM model to use.
""" """
# ATTENTION: When adding/removing settings, make sure to update the .env.example file
local_llm_url: str = "http://localhost:1234/v1/chat/completions" local_llm_url: str = "http://localhost:1234/v1/chat/completions"
local_llm_model: str = "gpt-oss" local_llm_model: str = "gpt-oss"
@@ -103,6 +120,8 @@ class VADSettings(BaseModel):
:ivar sample_rate_hz: Sample rate in Hz for the VAD model. :ivar sample_rate_hz: Sample rate in Hz for the VAD model.
""" """
# ATTENTION: When adding/removing settings, make sure to update the .env.example file
repo_or_dir: str = "snakers4/silero-vad" repo_or_dir: str = "snakers4/silero-vad"
model_name: str = "silero_vad" model_name: str = "silero_vad"
sample_rate_hz: int = 16000 sample_rate_hz: int = 16000
@@ -116,6 +135,8 @@ class SpeechModelSettings(BaseModel):
:ivar openai_model_name: Model name for OpenAI-based speech recognition. :ivar openai_model_name: Model name for OpenAI-based speech recognition.
""" """
# ATTENTION: When adding/removing settings, make sure to update the .env.example file
# model identifiers for speech recognition # model identifiers for speech recognition
mlx_model_name: str = "mlx-community/whisper-small.en-mlx" mlx_model_name: str = "mlx-community/whisper-small.en-mlx"
openai_model_name: str = "small.en" openai_model_name: str = "small.en"
@@ -127,6 +148,7 @@ class Settings(BaseSettings):
:ivar app_title: Title of the application. :ivar app_title: Title of the application.
:ivar ui_url: URL of the frontend UI. :ivar ui_url: URL of the frontend UI.
:ivar ri_host: The hostname of the Robot Interface.
:ivar zmq_settings: ZMQ configuration. :ivar zmq_settings: ZMQ configuration.
:ivar agent_settings: Agent name configuration. :ivar agent_settings: Agent name configuration.
:ivar behaviour_settings: Behavior configuration. :ivar behaviour_settings: Behavior configuration.
@@ -139,6 +161,8 @@ class Settings(BaseSettings):
ui_url: str = "http://localhost:5173" ui_url: str = "http://localhost:5173"
ri_host: str = "localhost"
zmq_settings: ZMQSettings = ZMQSettings() zmq_settings: ZMQSettings = ZMQSettings()
agent_settings: AgentSettings = AgentSettings() agent_settings: AgentSettings = AgentSettings()

View File

@@ -4,6 +4,7 @@ import os
import yaml import yaml
import zmq import zmq
from zmq.log.handlers import PUBHandler
from control_backend.core.config import settings from control_backend.core.config import settings
@@ -51,15 +52,27 @@ def setup_logging(path: str = ".logging_config.yaml") -> None:
logging.warning(f"Could not load logging configuration: {e}") logging.warning(f"Could not load logging configuration: {e}")
config = {} config = {}
if "custom_levels" in config: custom_levels = config.get("custom_levels", {}) or {}
for level_name, level_num in config["custom_levels"].items(): for level_name, level_num in custom_levels.items():
add_logging_level(level_name, level_num) add_logging_level(level_name, level_num)
if config.get("handlers") is not None and config.get("handlers").get("ui"): if config.get("handlers") is not None and config.get("handlers").get("ui"):
pub_socket = zmq.Context.instance().socket(zmq.PUB) pub_socket = zmq.Context.instance().socket(zmq.PUB)
pub_socket.connect(settings.zmq_settings.internal_pub_address) pub_socket.connect(settings.zmq_settings.internal_pub_address)
config["handlers"]["ui"]["interface_or_socket"] = pub_socket config["handlers"]["ui"]["interface_or_socket"] = pub_socket
logging.config.dictConfig(config) logging.config.dictConfig(config)
# Patch ZMQ PUBHandler to know about custom levels
if custom_levels:
for logger_name in ("control_backend",):
logger = logging.getLogger(logger_name)
for handler in logger.handlers:
if isinstance(handler, PUBHandler):
# Use the INFO formatter as the default template
default_fmt = handler.formatters[logging.INFO]
for level_num in custom_levels.values():
handler.setFormatter(default_fmt, level=level_num)
else: else:
logging.warning("Logging config file not found. Using default logging configuration.") logging.warning("Logging config file not found. Using default logging configuration.")

View File

@@ -39,13 +39,14 @@ from control_backend.agents.communication import RICommunicationAgent
# LLM Agents # LLM Agents
from control_backend.agents.llm import LLMAgent from control_backend.agents.llm import LLMAgent
# Perceive agents # User Interrupt Agent
from control_backend.agents.perception import VADAgent from control_backend.agents.user_interrupt.user_interrupt_agent import UserInterruptAgent
# Other backend imports # Other backend imports
from control_backend.api.v1.router import api_router from control_backend.api.v1.router import api_router
from control_backend.core.config import settings from control_backend.core.config import settings
from control_backend.logging import setup_logging from control_backend.logging import setup_logging
from control_backend.schemas.program_status import PROGRAM_STATUS, ProgramStatus
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -95,6 +96,8 @@ async def lifespan(app: FastAPI):
endpoints_pub_socket.connect(settings.zmq_settings.internal_pub_address) endpoints_pub_socket.connect(settings.zmq_settings.internal_pub_address)
app.state.endpoints_pub_socket = endpoints_pub_socket app.state.endpoints_pub_socket = endpoints_pub_socket
await endpoints_pub_socket.send_multipart([PROGRAM_STATUS, ProgramStatus.STARTING.value])
# --- Initialize Agents --- # --- Initialize Agents ---
logger.info("Initializing and starting agents.") logger.info("Initializing and starting agents.")
@@ -132,46 +135,44 @@ async def lifespan(app: FastAPI):
"name": settings.agent_settings.text_belief_extractor_name, "name": settings.agent_settings.text_belief_extractor_name,
}, },
), ),
"VADAgent": (
VADAgent,
{"audio_in_address": settings.zmq_settings.vad_agent_address, "audio_in_bind": False},
),
"ProgramManagerAgent": ( "ProgramManagerAgent": (
BDIProgramManager, BDIProgramManager,
{ {
"name": settings.agent_settings.bdi_program_manager_name, "name": settings.agent_settings.bdi_program_manager_name,
}, },
), ),
"UserInterruptAgent": (
UserInterruptAgent,
{
"name": settings.agent_settings.user_interrupt_name,
},
),
} }
agents = [] agents = []
vad_agent = None
for name, (agent_class, kwargs) in agents_to_start.items(): for name, (agent_class, kwargs) in agents_to_start.items():
try: try:
logger.debug("Starting agent: %s", name) logger.debug("Starting agent: %s", name)
agent_instance = agent_class(**kwargs) agent_instance = agent_class(**kwargs)
await agent_instance.start() await agent_instance.start()
if isinstance(agent_instance, VADAgent):
vad_agent = agent_instance
agents.append(agent_instance) agents.append(agent_instance)
logger.info("Agent '%s' started successfully.", name) logger.info("Agent '%s' started successfully.", name)
except Exception as e: except Exception as e:
logger.error("Failed to start agent '%s': %s", name, e, exc_info=True) logger.error("Failed to start agent '%s': %s", name, e, exc_info=True)
raise raise
assert vad_agent is not None
await vad_agent.reset_stream()
logger.info("Application startup complete.") logger.info("Application startup complete.")
await endpoints_pub_socket.send_multipart([PROGRAM_STATUS, ProgramStatus.RUNNING.value])
yield yield
# --- APPLICATION SHUTDOWN --- # --- APPLICATION SHUTDOWN ---
logger.info("%s is shutting down.", app.title) logger.info("%s is shutting down.", app.title)
# Potential shutdown logic goes here await endpoints_pub_socket.send_multipart([PROGRAM_STATUS, ProgramStatus.STOPPING.value])
# Additional shutdown logic goes here
logger.info("Application shutdown complete.") logger.info("Application shutdown complete.")

View File

@@ -0,0 +1,6 @@
from pydantic import BaseModel
class ButtonPressedEvent(BaseModel):
type: str
context: str

View File

@@ -0,0 +1,16 @@
from enum import Enum
PROGRAM_STATUS = b"internal/program_status"
"""A topic key for the program status."""
class ProgramStatus(Enum):
"""
Used in internal communication, to tell agents what the status of the program is.
For example, the VAD agent only starts listening when the program is RUNNING.
"""
STARTING = b"starting"
RUNNING = b"running"
STOPPING = b"stopping"

View File

@@ -12,6 +12,7 @@ class RIEndpoint(str, Enum):
SPEECH = "actuate/speech" SPEECH = "actuate/speech"
GESTURE_SINGLE = "actuate/gesture/single" GESTURE_SINGLE = "actuate/gesture/single"
GESTURE_TAG = "actuate/gesture/tag" GESTURE_TAG = "actuate/gesture/tag"
GESTURE_BASIC = "actuate/gesture/single"
PING = "ping" PING = "ping"
NEGOTIATE_PORTS = "negotiate/ports" NEGOTIATE_PORTS = "negotiate/ports"
@@ -38,6 +39,7 @@ class SpeechCommand(RIMessage):
endpoint: RIEndpoint = RIEndpoint(RIEndpoint.SPEECH) endpoint: RIEndpoint = RIEndpoint(RIEndpoint.SPEECH)
data: str data: str
is_priority: bool = False
class GestureCommand(RIMessage): class GestureCommand(RIMessage):
@@ -52,13 +54,11 @@ class GestureCommand(RIMessage):
RIEndpoint.GESTURE_SINGLE, RIEndpoint.GESTURE_TAG RIEndpoint.GESTURE_SINGLE, RIEndpoint.GESTURE_TAG
] ]
data: str data: str
is_priority: bool = False
@model_validator(mode="after") @model_validator(mode="after")
def check_endpoint(self): def check_endpoint(self):
allowed = { allowed = {RIEndpoint.GESTURE_SINGLE, RIEndpoint.GESTURE_TAG}
RIEndpoint.GESTURE_SINGLE,
RIEndpoint.GESTURE_TAG,
}
if self.endpoint not in allowed: if self.endpoint not in allowed:
raise ValueError("endpoint must be GESTURE_SINGLE or GESTURE_TAG") raise ValueError("endpoint must be GESTURE_SINGLE, GESTURE_TAG or GESTURE_BASIC")
return self return self

View File

@@ -5,6 +5,7 @@ import pytest
import zmq import zmq
from control_backend.agents.perception.vad_agent import VADAgent from control_backend.agents.perception.vad_agent import VADAgent
from control_backend.schemas.program_status import PROGRAM_STATUS, ProgramStatus
@pytest.fixture @pytest.fixture
@@ -43,14 +44,12 @@ async def test_normal_setup(per_transcription_agent):
coro.close() coro.close()
per_vad_agent.add_behavior = swallow_background_task per_vad_agent.add_behavior = swallow_background_task
per_vad_agent.reset_stream = AsyncMock()
await per_vad_agent.setup() await per_vad_agent.setup()
per_transcription_agent.assert_called_once() per_transcription_agent.assert_called_once()
per_transcription_agent.return_value.start.assert_called_once() per_transcription_agent.return_value.start.assert_called_once()
per_vad_agent._streaming_loop.assert_called_once() per_vad_agent._streaming_loop.assert_called_once()
per_vad_agent.reset_stream.assert_called_once()
assert per_vad_agent.audio_in_socket is not None assert per_vad_agent.audio_in_socket is not None
assert per_vad_agent.audio_out_socket is not None assert per_vad_agent.audio_out_socket is not None
@@ -92,7 +91,7 @@ def test_out_socket_creation(zmq_context):
assert per_vad_agent.audio_out_socket is not None assert per_vad_agent.audio_out_socket is not None
zmq_context.return_value.socket.assert_called_once_with(zmq.PUB) zmq_context.return_value.socket.assert_called_once_with(zmq.PUB)
zmq_context.return_value.socket.return_value.bind_to_random_port.assert_called_once() zmq_context.return_value.socket.return_value.bind.assert_called_once_with("inproc://vad_stream")
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -103,7 +102,7 @@ async def test_out_socket_creation_failure(zmq_context):
zmq_context.return_value.socket.return_value.bind_to_random_port.side_effect = zmq.ZMQBindError zmq_context.return_value.socket.return_value.bind_to_random_port.side_effect = zmq.ZMQBindError
per_vad_agent = VADAgent("tcp://localhost:12345", False) per_vad_agent = VADAgent("tcp://localhost:12345", False)
per_vad_agent.stop = AsyncMock() per_vad_agent.stop = AsyncMock()
per_vad_agent.reset_stream = AsyncMock() per_vad_agent._reset_stream = AsyncMock()
per_vad_agent._streaming_loop = AsyncMock() per_vad_agent._streaming_loop = AsyncMock()
per_vad_agent._connect_audio_out_socket = MagicMock(return_value=None) per_vad_agent._connect_audio_out_socket = MagicMock(return_value=None)
@@ -124,7 +123,7 @@ async def test_stop(zmq_context, per_transcription_agent):
Test that when the VAD agent is stopped, the sockets are closed correctly. Test that when the VAD agent is stopped, the sockets are closed correctly.
""" """
per_vad_agent = VADAgent("tcp://localhost:12345", False) per_vad_agent = VADAgent("tcp://localhost:12345", False)
per_vad_agent.reset_stream = AsyncMock() per_vad_agent._reset_stream = AsyncMock()
per_vad_agent._streaming_loop = AsyncMock() per_vad_agent._streaming_loop = AsyncMock()
async def swallow_background_task(coro): async def swallow_background_task(coro):
@@ -142,3 +141,66 @@ async def test_stop(zmq_context, per_transcription_agent):
assert zmq_context.return_value.socket.return_value.close.call_count == 2 assert zmq_context.return_value.socket.return_value.close.call_count == 2
assert per_vad_agent.audio_in_socket is None assert per_vad_agent.audio_in_socket is None
assert per_vad_agent.audio_out_socket is None assert per_vad_agent.audio_out_socket is None
@pytest.mark.asyncio
async def test_application_startup_complete(zmq_context):
"""Check that it resets the stream when the program finishes startup."""
vad_agent = VADAgent("tcp://localhost:12345", False)
vad_agent._running = True
vad_agent._reset_stream = AsyncMock()
vad_agent.program_sub_socket = AsyncMock()
vad_agent.program_sub_socket.recv_multipart.side_effect = [
(PROGRAM_STATUS, ProgramStatus.RUNNING.value),
]
await vad_agent._status_loop()
vad_agent._reset_stream.assert_called_once()
vad_agent.program_sub_socket.close.assert_called_once()
@pytest.mark.asyncio
async def test_application_other_status(zmq_context):
"""
Check that it does nothing when the internal communication message is a status update, but not
running.
"""
vad_agent = VADAgent("tcp://localhost:12345", False)
vad_agent._running = True
vad_agent._reset_stream = AsyncMock()
vad_agent.program_sub_socket = AsyncMock()
vad_agent.program_sub_socket.recv_multipart.side_effect = [
(PROGRAM_STATUS, ProgramStatus.STARTING.value),
(PROGRAM_STATUS, ProgramStatus.STOPPING.value),
]
try:
# Raises StopAsyncIteration the third time it calls `program_sub_socket.recv_multipart`
await vad_agent._status_loop()
except StopAsyncIteration:
pass
vad_agent._reset_stream.assert_not_called()
@pytest.mark.asyncio
async def test_application_message_other(zmq_context):
"""
Check that it does nothing when there's an internal communication message that is not a status
update.
"""
vad_agent = VADAgent("tcp://localhost:12345", False)
vad_agent._running = True
vad_agent._reset_stream = AsyncMock()
vad_agent.program_sub_socket = AsyncMock()
vad_agent.program_sub_socket.recv_multipart.side_effect = [(b"internal/other", b"Whatever")]
try:
# Raises StopAsyncIteration the second time it calls `program_sub_socket.recv_multipart`
await vad_agent._status_loop()
except StopAsyncIteration:
pass
vad_agent._reset_stream.assert_not_called()

View File

@@ -11,7 +11,6 @@ from control_backend.schemas.ri_message import RIEndpoint
@pytest.fixture @pytest.fixture
def zmq_context(mocker): def zmq_context(mocker):
"""Mock the ZMQ context."""
mock_context = mocker.patch( mock_context = mocker.patch(
"control_backend.agents.actuation.robot_gesture_agent.azmq.Context.instance" "control_backend.agents.actuation.robot_gesture_agent.azmq.Context.instance"
) )
@@ -21,61 +20,54 @@ def zmq_context(mocker):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_setup_bind(zmq_context, mocker): async def test_setup_bind(zmq_context, mocker):
"""Setup binds and subscribes to internal commands."""
fake_socket = zmq_context.return_value.socket.return_value fake_socket = zmq_context.return_value.socket.return_value
agent = RobotGestureAgent("robot_gesture", address="tcp://localhost:5556", bind=True) agent = RobotGestureAgent("robot_gesture", address="tcp://localhost:5556", bind=True)
settings = mocker.patch("control_backend.agents.actuation.robot_gesture_agent.settings") settings = mocker.patch("control_backend.agents.actuation.robot_gesture_agent.settings")
settings.zmq_settings.internal_sub_address = "tcp://internal:1234" settings.zmq_settings.internal_sub_address = "tcp://internal:1234"
settings.zmq_settings.internal_gesture_rep_adress = "tcp://internal:5557"
agent.add_behavior = MagicMock() agent.add_behavior = MagicMock()
await agent.setup() await agent.setup()
# Check PUB socket binding
fake_socket.bind.assert_any_call("tcp://localhost:5556") fake_socket.bind.assert_any_call("tcp://localhost:5556")
# Check SUB socket connection and subscriptions
fake_socket.connect.assert_any_call("tcp://internal:1234") fake_socket.connect.assert_any_call("tcp://internal:1234")
fake_socket.setsockopt.assert_any_call(zmq.SUBSCRIBE, b"command") fake_socket.setsockopt.assert_any_call(zmq.SUBSCRIBE, b"command")
fake_socket.setsockopt.assert_any_call(zmq.SUBSCRIBE, b"send_gestures") fake_socket.setsockopt.assert_any_call(zmq.SUBSCRIBE, b"send_gestures")
assert agent.add_behavior.call_count == 2
# Check behavior was added
agent.add_behavior.assert_called() # Twice, even.
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_setup_connect(zmq_context, mocker): async def test_setup_connect(zmq_context, mocker):
"""Setup connects when bind=False."""
fake_socket = zmq_context.return_value.socket.return_value fake_socket = zmq_context.return_value.socket.return_value
agent = RobotGestureAgent("robot_gesture", address="tcp://localhost:5556", bind=False) agent = RobotGestureAgent("robot_gesture", address="tcp://localhost:5556", bind=False)
settings = mocker.patch("control_backend.agents.actuation.robot_gesture_agent.settings") settings = mocker.patch("control_backend.agents.actuation.robot_gesture_agent.settings")
settings.zmq_settings.internal_sub_address = "tcp://internal:1234" settings.zmq_settings.internal_sub_address = "tcp://internal:1234"
settings.zmq_settings.internal_gesture_rep_adress = "tcp://internal:5557"
agent.add_behavior = MagicMock() agent.add_behavior = MagicMock()
await agent.setup() await agent.setup()
# Check PUB socket connection (not binding)
fake_socket.connect.assert_any_call("tcp://localhost:5556") fake_socket.connect.assert_any_call("tcp://localhost:5556")
fake_socket.connect.assert_any_call("tcp://internal:1234") fake_socket.connect.assert_any_call("tcp://internal:1234")
fake_socket.bind.assert_called()
# Check behavior was added assert agent.add_behavior.call_count == 2
agent.add_behavior.assert_called() # Twice, actually.
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_handle_message_sends_valid_gesture_command(): async def test_handle_message_valid_gesture_tag():
"""Internal message with valid gesture tag is forwarded to robot pub socket."""
pubsocket = AsyncMock() pubsocket = AsyncMock()
agent = RobotGestureAgent("robot_gesture") agent = RobotGestureAgent(
"robot_gesture",
address="",
gesture_tags=["hello"],
)
agent.pubsocket = pubsocket agent.pubsocket = pubsocket
payload = { payload = {"endpoint": RIEndpoint.GESTURE_TAG, "data": "hello"}
"endpoint": RIEndpoint.GESTURE_TAG,
"data": "hello", # "hello" is in availableTags
}
msg = InternalMessage(to="robot", sender="tester", body=json.dumps(payload)) msg = InternalMessage(to="robot", sender="tester", body=json.dumps(payload))
await agent.handle_message(msg) await agent.handle_message(msg)
@@ -84,13 +76,16 @@ async def test_handle_message_sends_valid_gesture_command():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_handle_message_sends_non_gesture_command(): async def test_handle_message_invalid_gesture_tag():
"""Internal message with non-gesture endpoint is not handled by this agent."""
pubsocket = AsyncMock() pubsocket = AsyncMock()
agent = RobotGestureAgent("robot_gesture") agent = RobotGestureAgent(
"robot_gesture",
address="",
gesture_tags=["hello"],
)
agent.pubsocket = pubsocket agent.pubsocket = pubsocket
payload = {"endpoint": "some_other_endpoint", "data": "invalid_tag_not_in_list"} payload = {"endpoint": RIEndpoint.GESTURE_TAG, "data": "nope"}
msg = InternalMessage(to="robot", sender="tester", body=json.dumps(payload)) msg = InternalMessage(to="robot", sender="tester", body=json.dumps(payload))
await agent.handle_message(msg) await agent.handle_message(msg)
@@ -99,50 +94,38 @@ async def test_handle_message_sends_non_gesture_command():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_handle_message_rejects_invalid_gesture_tag(): async def test_handle_message_invalid_payload_logged():
"""Internal message with invalid gesture tag is not forwarded."""
pubsocket = AsyncMock() pubsocket = AsyncMock()
agent = RobotGestureAgent("robot_gesture") agent = RobotGestureAgent("robot_gesture", address="")
agent.pubsocket = pubsocket agent.pubsocket = pubsocket
agent.logger = MagicMock()
# Use a tag that's not in availableTags msg = InternalMessage(to="robot", sender="tester", body="not json")
payload = {"endpoint": RIEndpoint.GESTURE_TAG, "data": "invalid_tag_not_in_list"}
msg = InternalMessage(to="robot", sender="tester", body=json.dumps(payload))
await agent.handle_message(msg) await agent.handle_message(msg)
pubsocket.send_json.assert_not_awaited() pubsocket.send_json.assert_not_awaited()
agent.logger.exception.assert_called_once()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_handle_message_invalid_payload(): async def test_zmq_command_loop_valid_gesture():
"""Invalid payload is caught and does not send."""
pubsocket = AsyncMock()
agent = RobotGestureAgent("robot_gesture")
agent.pubsocket = pubsocket
msg = InternalMessage(to="robot", sender="tester", body=json.dumps({"bad": "data"}))
await agent.handle_message(msg)
pubsocket.send_json.assert_not_awaited()
@pytest.mark.asyncio
async def test_zmq_command_loop_valid_gesture_payload():
"""UI command with valid gesture tag is read from SUB and published."""
command = {"endpoint": RIEndpoint.GESTURE_TAG, "data": "hello"}
fake_socket = AsyncMock() fake_socket = AsyncMock()
async def recv_once(): async def recv_once():
# stop after first iteration
agent._running = False agent._running = False
return (b"command", json.dumps(command).encode("utf-8")) return b"command", json.dumps(
{"endpoint": RIEndpoint.GESTURE_TAG, "data": "hello"}
).encode()
fake_socket.recv_multipart = recv_once fake_socket.recv_multipart = recv_once
fake_socket.send_json = AsyncMock() fake_socket.send_json = AsyncMock()
agent = RobotGestureAgent("robot_gesture") agent = RobotGestureAgent(
"robot_gesture",
address="",
gesture_tags=["hello"],
)
agent.subsocket = fake_socket agent.subsocket = fake_socket
agent.pubsocket = fake_socket agent.pubsocket = fake_socket
agent._running = True agent._running = True
@@ -153,64 +136,23 @@ async def test_zmq_command_loop_valid_gesture_payload():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_zmq_command_loop_valid_non_gesture_payload(): async def test_zmq_command_loop_invalid_tag():
"""UI command with non-gesture endpoint is not handled by this agent."""
command = {"endpoint": "some_other_endpoint", "data": "anything"}
fake_socket = AsyncMock() fake_socket = AsyncMock()
async def recv_once(): async def recv_once():
agent._running = False agent._running = False
return (b"command", json.dumps(command).encode("utf-8")) return b"command", json.dumps(
{"endpoint": RIEndpoint.GESTURE_TAG, "data": "invalid"}
).encode()
fake_socket.recv_multipart = recv_once fake_socket.recv_multipart = recv_once
fake_socket.send_json = AsyncMock() fake_socket.send_json = AsyncMock()
agent = RobotGestureAgent("robot_gesture") agent = RobotGestureAgent(
agent.subsocket = fake_socket "robot_gesture",
agent.pubsocket = fake_socket address="",
agent._running = True gesture_tags=["hello"],
)
await agent._zmq_command_loop()
fake_socket.send_json.assert_not_awaited()
@pytest.mark.asyncio
async def test_zmq_command_loop_invalid_gesture_tag():
"""UI command with invalid gesture tag is not forwarded."""
command = {"endpoint": RIEndpoint.GESTURE_TAG, "data": "invalid_tag_not_in_list"}
fake_socket = AsyncMock()
async def recv_once():
agent._running = False
return (b"command", json.dumps(command).encode("utf-8"))
fake_socket.recv_multipart = recv_once
fake_socket.send_json = AsyncMock()
agent = RobotGestureAgent("robot_gesture")
agent.subsocket = fake_socket
agent.pubsocket = fake_socket
agent._running = True
await agent._zmq_command_loop()
fake_socket.send_json.assert_not_awaited()
@pytest.mark.asyncio
async def test_zmq_command_loop_invalid_json():
"""Invalid JSON is ignored without sending."""
fake_socket = AsyncMock()
async def recv_once():
agent._running = False
return (b"command", b"{not_json}")
fake_socket.recv_multipart = recv_once
fake_socket.send_json = AsyncMock()
agent = RobotGestureAgent("robot_gesture")
agent.subsocket = fake_socket agent.subsocket = fake_socket
agent.pubsocket = fake_socket agent.pubsocket = fake_socket
agent._running = True agent._running = True
@@ -222,17 +164,16 @@ async def test_zmq_command_loop_invalid_json():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_zmq_command_loop_ignores_send_gestures_topic(): async def test_zmq_command_loop_ignores_send_gestures_topic():
"""send_gestures topic is ignored in command loop."""
fake_socket = AsyncMock() fake_socket = AsyncMock()
async def recv_once(): async def recv_once():
agent._running = False agent._running = False
return (b"send_gestures", b"{}") return b"send_gestures", b"{}"
fake_socket.recv_multipart = recv_once fake_socket.recv_multipart = recv_once
fake_socket.send_json = AsyncMock() fake_socket.send_json = AsyncMock()
agent = RobotGestureAgent("robot_gesture") agent = RobotGestureAgent("robot_gesture", address="")
agent.subsocket = fake_socket agent.subsocket = fake_socket
agent.pubsocket = fake_socket agent.pubsocket = fake_socket
agent._running = True agent._running = True
@@ -243,150 +184,68 @@ async def test_zmq_command_loop_ignores_send_gestures_topic():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_fetch_gestures_loop_without_amount(): async def test_fetch_gestures_tags():
"""Fetch gestures request without amount returns all tags.""" fake_repsocket = AsyncMock()
fake_socket = AsyncMock()
async def recv_once(): async def recv_once():
agent._running = False agent._running = False
return (b"send_gestures", b"{}") return {"type": "tags"}
fake_socket.recv_multipart = recv_once fake_repsocket.recv_json = recv_once
fake_socket.send_multipart = AsyncMock() fake_repsocket.send_json = AsyncMock()
agent = RobotGestureAgent("robot_gesture") agent = RobotGestureAgent(
agent.subsocket = fake_socket "robot_gesture",
agent.pubsocket = fake_socket address="",
gesture_tags=["hello", "yes", "no"],
)
agent.repsocket = fake_repsocket
agent._running = True agent._running = True
await agent._fetch_gestures_loop() await agent._fetch_gestures_loop()
fake_socket.send_multipart.assert_awaited_once() fake_repsocket.send_json.assert_awaited_once_with({"tags": ["hello", "yes", "no"]})
# Check the response contains all tags
args, kwargs = fake_socket.send_multipart.call_args
assert args[0][0] == b"get_gestures"
response = json.loads(args[0][1])
assert "tags" in response
assert len(response["tags"]) > 0
# Check it includes some expected tags
assert "hello" in response["tags"]
assert "yes" in response["tags"]
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_fetch_gestures_loop_with_amount(): async def test_fetch_gestures_basic():
"""Fetch gestures request with amount returns limited tags.""" fake_repsocket = AsyncMock()
fake_socket = AsyncMock()
amount = 5
async def recv_once(): async def recv_once():
agent._running = False agent._running = False
return (b"send_gestures", json.dumps(amount).encode()) return {"type": "basic"}
fake_socket.recv_multipart = recv_once fake_repsocket.recv_json = recv_once
fake_socket.send_multipart = AsyncMock() fake_repsocket.send_json = AsyncMock()
agent = RobotGestureAgent("robot_gesture") agent = RobotGestureAgent(
agent.subsocket = fake_socket "robot_gesture",
agent.pubsocket = fake_socket address="",
gesture_basic=["wave", "point"],
)
agent.repsocket = fake_repsocket
agent._running = True agent._running = True
await agent._fetch_gestures_loop() await agent._fetch_gestures_loop()
fake_socket.send_multipart.assert_awaited_once() fake_repsocket.send_json.assert_awaited_once_with({"basic_gestures": ["wave", "point"]})
args, kwargs = fake_socket.send_multipart.call_args
assert args[0][0] == b"get_gestures"
response = json.loads(args[0][1])
assert "tags" in response
assert len(response["tags"]) == amount
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_fetch_gestures_loop_ignores_command_topic(): async def test_fetch_gestures_unknown_type():
"""Command topic is ignored in fetch gestures loop.""" fake_repsocket = AsyncMock()
fake_socket = AsyncMock()
async def recv_once(): async def recv_once():
agent._running = False agent._running = False
return (b"command", b"{}") return {"type": "unknown"}
fake_socket.recv_multipart = recv_once fake_repsocket.recv_json = recv_once
fake_socket.send_multipart = AsyncMock() fake_repsocket.send_json = AsyncMock()
agent = RobotGestureAgent("robot_gesture") agent = RobotGestureAgent("robot_gesture", address="")
agent.subsocket = fake_socket agent.repsocket = fake_repsocket
agent.pubsocket = fake_socket
agent._running = True agent._running = True
await agent._fetch_gestures_loop() await agent._fetch_gestures_loop()
fake_socket.send_multipart.assert_not_awaited() fake_repsocket.send_json.assert_awaited_once_with({})
@pytest.mark.asyncio
async def test_fetch_gestures_loop_invalid_request():
"""Invalid request body is handled gracefully."""
fake_socket = AsyncMock()
async def recv_once():
agent._running = False
# Send a non-integer, non-JSON body
return (b"send_gestures", b"not_json")
fake_socket.recv_multipart = recv_once
fake_socket.send_multipart = AsyncMock()
agent = RobotGestureAgent("robot_gesture")
agent.subsocket = fake_socket
agent.pubsocket = fake_socket
agent._running = True
await agent._fetch_gestures_loop()
# Should still send a response (all tags)
fake_socket.send_multipart.assert_awaited_once()
def test_available_tags():
"""Test that availableTags returns the expected list."""
agent = RobotGestureAgent("robot_gesture")
tags = agent.availableTags()
assert isinstance(tags, list)
assert len(tags) > 0
# Check some expected tags are present
assert "hello" in tags
assert "yes" in tags
assert "no" in tags
# Check a non-existent tag is not present
assert "invalid_tag_not_in_list" not in tags
@pytest.mark.asyncio
async def test_stop_closes_sockets():
"""Stop method closes both sockets."""
pubsocket = MagicMock()
subsocket = MagicMock()
agent = RobotGestureAgent("robot_gesture")
agent.pubsocket = pubsocket
agent.subsocket = subsocket
await agent.stop()
pubsocket.close.assert_called_once()
subsocket.close.assert_called_once()
@pytest.mark.asyncio
async def test_initialization_with_custom_gesture_data():
"""Agent can be initialized with custom gesture data."""
custom_gestures = ["custom1", "custom2", "custom3"]
agent = RobotGestureAgent("robot_gesture", gesture_data=custom_gestures)
# Note: The current implementation doesn't use the gesture_data parameter
# in availableTags(). This test documents that behavior.
# If you update the agent to use gesture_data, update this test accordingly.
assert agent.gesture_data == custom_gestures

View File

@@ -8,6 +8,11 @@ from control_backend.agents.actuation.robot_speech_agent import RobotSpeechAgent
from control_backend.core.agent_system import InternalMessage from control_backend.core.agent_system import InternalMessage
def mock_speech_agent():
agent = RobotSpeechAgent("robot_speech", address="tcp://localhost:5555", bind=False)
return agent
@pytest.fixture @pytest.fixture
def zmq_context(mocker): def zmq_context(mocker):
mock_context = mocker.patch( mock_context = mocker.patch(
@@ -56,10 +61,10 @@ async def test_setup_connect(zmq_context, mocker):
async def test_handle_message_sends_command(): async def test_handle_message_sends_command():
"""Internal message is forwarded to robot pub socket as JSON.""" """Internal message is forwarded to robot pub socket as JSON."""
pubsocket = AsyncMock() pubsocket = AsyncMock()
agent = RobotSpeechAgent("robot_speech") agent = mock_speech_agent()
agent.pubsocket = pubsocket agent.pubsocket = pubsocket
payload = {"endpoint": "actuate/speech", "data": "hello"} payload = {"endpoint": "actuate/speech", "data": "hello", "is_priority": False}
msg = InternalMessage(to="robot", sender="tester", body=json.dumps(payload)) msg = InternalMessage(to="robot", sender="tester", body=json.dumps(payload))
await agent.handle_message(msg) await agent.handle_message(msg)
@@ -70,7 +75,7 @@ async def test_handle_message_sends_command():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_zmq_command_loop_valid_payload(zmq_context): async def test_zmq_command_loop_valid_payload(zmq_context):
"""UI command is read from SUB and published.""" """UI command is read from SUB and published."""
command = {"endpoint": "actuate/speech", "data": "hello"} command = {"endpoint": "actuate/speech", "data": "hello", "is_priority": False}
fake_socket = AsyncMock() fake_socket = AsyncMock()
async def recv_once(): async def recv_once():
@@ -80,7 +85,7 @@ async def test_zmq_command_loop_valid_payload(zmq_context):
fake_socket.recv_multipart = recv_once fake_socket.recv_multipart = recv_once
fake_socket.send_json = AsyncMock() fake_socket.send_json = AsyncMock()
agent = RobotSpeechAgent("robot_speech") agent = mock_speech_agent()
agent.subsocket = fake_socket agent.subsocket = fake_socket
agent.pubsocket = fake_socket agent.pubsocket = fake_socket
agent._running = True agent._running = True
@@ -101,7 +106,7 @@ async def test_zmq_command_loop_invalid_json():
fake_socket.recv_multipart = recv_once fake_socket.recv_multipart = recv_once
fake_socket.send_json = AsyncMock() fake_socket.send_json = AsyncMock()
agent = RobotSpeechAgent("robot_speech") agent = mock_speech_agent()
agent.subsocket = fake_socket agent.subsocket = fake_socket
agent.pubsocket = fake_socket agent.pubsocket = fake_socket
agent._running = True agent._running = True
@@ -115,7 +120,7 @@ async def test_zmq_command_loop_invalid_json():
async def test_handle_message_invalid_payload(): async def test_handle_message_invalid_payload():
"""Invalid payload is caught and does not send.""" """Invalid payload is caught and does not send."""
pubsocket = AsyncMock() pubsocket = AsyncMock()
agent = RobotSpeechAgent("robot_speech") agent = mock_speech_agent()
agent.pubsocket = pubsocket agent.pubsocket = pubsocket
msg = InternalMessage(to="robot", sender="tester", body=json.dumps({"bad": "data"})) msg = InternalMessage(to="robot", sender="tester", body=json.dumps({"bad": "data"}))
@@ -129,7 +134,7 @@ async def test_handle_message_invalid_payload():
async def test_stop_closes_sockets(): async def test_stop_closes_sockets():
pubsocket = MagicMock() pubsocket = MagicMock()
subsocket = MagicMock() subsocket = MagicMock()
agent = RobotSpeechAgent("robot_speech") agent = mock_speech_agent()
agent.pubsocket = pubsocket agent.pubsocket = pubsocket
agent.subsocket = subsocket agent.subsocket = subsocket

View File

@@ -1,4 +1,6 @@
import asyncio
import json import json
import time
from unittest.mock import AsyncMock, MagicMock, mock_open, patch from unittest.mock import AsyncMock, MagicMock, mock_open, patch
import agentspeak import agentspeak
@@ -77,11 +79,6 @@ async def test_incorrect_belief_collector_message(agent, mock_settings):
agent.bdi_agent.call.assert_not_called() # did not set belief agent.bdi_agent.call.assert_not_called() # did not set belief
@pytest.mark.asyncio
async def test():
pass
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_handle_llm_response(agent): async def test_handle_llm_response(agent):
"""Test that LLM responses are forwarded to the Robot Speech Agent""" """Test that LLM responses are forwarded to the Robot Speech Agent"""
@@ -124,3 +121,148 @@ async def test_custom_actions(agent):
next(gen) # Execute next(gen) # Execute
agent._send_to_llm.assert_called_with("Hello", "Norm", "Goal") agent._send_to_llm.assert_called_with("Hello", "Norm", "Goal")
def test_add_belief_sets_event(agent):
"""Test that a belief triggers wake event and call()"""
agent._wake_bdi_loop = MagicMock()
belief = Belief(name="test_belief", arguments=["a", "b"])
agent._apply_beliefs([belief])
assert agent.bdi_agent.call.called
agent._wake_bdi_loop.set.assert_called()
def test_apply_beliefs_empty_returns(agent):
"""Line: if not beliefs: return"""
agent._wake_bdi_loop = MagicMock()
agent._apply_beliefs([])
agent.bdi_agent.call.assert_not_called()
agent._wake_bdi_loop.set.assert_not_called()
def test_remove_belief_success_wakes_loop(agent):
"""Line: if result: wake set"""
agent._wake_bdi_loop = MagicMock()
agent.bdi_agent.call.return_value = True
agent._remove_belief("remove_me", ["x"])
assert agent.bdi_agent.call.called
trigger, goaltype, literal, *_ = agent.bdi_agent.call.call_args.args
assert trigger == agentspeak.Trigger.removal
assert goaltype == agentspeak.GoalType.belief
assert literal.functor == "remove_me"
assert literal.args[0].functor == "x"
agent._wake_bdi_loop.set.assert_called()
def test_remove_belief_failure_does_not_wake(agent):
"""Line: else result is False"""
agent._wake_bdi_loop = MagicMock()
agent.bdi_agent.call.return_value = False
agent._remove_belief("not_there", ["y"])
assert agent.bdi_agent.call.called # removal was attempted
agent._wake_bdi_loop.set.assert_not_called()
def test_remove_all_with_name_wakes_loop(agent):
"""Cover _remove_all_with_name() removed counter + wake"""
agent._wake_bdi_loop = MagicMock()
fake_literal = agentspeak.Literal("delete_me", (agentspeak.Literal("arg1"),))
fake_key = ("delete_me", 1)
agent.bdi_agent.beliefs = {fake_key: {fake_literal}}
agent._remove_all_with_name("delete_me")
assert agent.bdi_agent.call.called
agent._wake_bdi_loop.set.assert_called()
@pytest.mark.asyncio
async def test_bdi_step_true_branch_hits_line_67(agent):
"""Force step() to return True once so line 67 is actually executed"""
# counter that isn't tied to MagicMock.call_count ordering
counter = {"i": 0}
def fake_step():
counter["i"] += 1
return counter["i"] == 1 # True only first time
# Important: wrap fake_step into another mock so `.called` still exists
agent.bdi_agent.step = MagicMock(side_effect=fake_step)
agent.bdi_agent.shortest_deadline = MagicMock(return_value=None)
agent._running = True
agent._wake_bdi_loop = asyncio.Event()
agent._wake_bdi_loop.set()
task = asyncio.create_task(agent._bdi_loop())
await asyncio.sleep(0.01)
task.cancel()
try:
await task
except asyncio.CancelledError:
pass
assert agent.bdi_agent.step.called
assert counter["i"] >= 1 # proves True branch ran
def test_replace_belief_calls_remove_all(agent):
"""Cover: if belief.replace: self._remove_all_with_name()"""
agent._remove_all_with_name = MagicMock()
agent._wake_bdi_loop = MagicMock()
belief = Belief(name="user_said", arguments=["Hello"], replace=True)
agent._apply_beliefs([belief])
agent._remove_all_with_name.assert_called_with("user_said")
@pytest.mark.asyncio
async def test_send_to_llm_creates_prompt_and_sends(agent):
"""Cover entire _send_to_llm() including message send and logger.info"""
agent.bdi_agent = MagicMock() # ensure mocked BDI does not interfere
agent._wake_bdi_loop = MagicMock()
await agent._send_to_llm("hello world", "n1\nn2", "g1")
# send() was called
assert agent.send.called
sent_msg: InternalMessage = agent.send.call_args.args[0]
# Message routing values correct
assert sent_msg.to == settings.agent_settings.llm_name
assert "hello world" in sent_msg.body
# JSON contains split norms/goals
body = json.loads(sent_msg.body)
assert body["norms"] == ["n1", "n2"]
assert body["goals"] == ["g1"]
@pytest.mark.asyncio
async def test_deadline_sleep_branch(agent):
"""Specifically assert the if deadline: sleep → maybe_more_work=True branch"""
future_deadline = time.time() + 0.005
agent.bdi_agent.step.return_value = False
agent.bdi_agent.shortest_deadline.return_value = future_deadline
start_time = time.time()
agent._running = True
agent._wake_bdi_loop = asyncio.Event()
agent._wake_bdi_loop.set()
task = asyncio.create_task(agent._bdi_loop())
await asyncio.sleep(0.01)
task.cancel()
duration = time.time() - start_time
assert duration >= 0.004 # loop slept until deadline

View File

@@ -0,0 +1,99 @@
import asyncio
import json
import sys
from unittest.mock import AsyncMock
import pytest
from control_backend.agents.bdi.bdi_program_manager import BDIProgramManager
from control_backend.core.agent_system import InternalMessage
from control_backend.schemas.belief_message import BeliefMessage
from control_backend.schemas.program import Program
# Fix Windows Proactor loop for zmq
if sys.platform.startswith("win"):
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
def make_valid_program_json(norm="N1", goal="G1"):
return json.dumps(
{
"phases": [
{
"id": "phase1",
"label": "Phase 1",
"triggers": [],
"norms": [{"id": "n1", "label": "Norm 1", "norm": norm}],
"goals": [
{"id": "g1", "label": "Goal 1", "description": goal, "achieved": False}
],
}
]
}
)
@pytest.mark.asyncio
async def test_send_to_bdi():
manager = BDIProgramManager(name="program_manager_test")
manager.send = AsyncMock()
program = Program.model_validate_json(make_valid_program_json())
await manager._send_to_bdi(program)
assert manager.send.await_count == 1
msg: InternalMessage = manager.send.await_args[0][0]
assert msg.thread == "beliefs"
beliefs = BeliefMessage.model_validate_json(msg.body)
names = {b.name: b.arguments for b in beliefs.beliefs}
assert "norms" in names and names["norms"] == ["N1"]
assert "goals" in names and names["goals"] == ["G1"]
@pytest.mark.asyncio
async def test_receive_programs_valid_and_invalid():
sub = AsyncMock()
sub.recv_multipart.side_effect = [
(b"program", b"{bad json"),
(b"program", make_valid_program_json().encode()),
]
manager = BDIProgramManager(name="program_manager_test")
manager.sub_socket = sub
manager._send_to_bdi = AsyncMock()
manager._send_clear_llm_history = AsyncMock()
try:
# Will give StopAsyncIteration when the predefined `sub.recv_multipart` side-effects run out
await manager._receive_programs()
except StopAsyncIteration:
pass
# Only valid Program should have triggered _send_to_bdi
assert manager._send_to_bdi.await_count == 1
forwarded: Program = manager._send_to_bdi.await_args[0][0]
assert forwarded.phases[0].norms[0].norm == "N1"
assert forwarded.phases[0].goals[0].description == "G1"
# Verify history clear was triggered
assert manager._send_clear_llm_history.await_count == 1
@pytest.mark.asyncio
async def test_send_clear_llm_history(mock_settings):
# Ensure the mock returns a string for the agent name (just like in your LLM tests)
mock_settings.agent_settings.llm_agent_name = "llm_agent"
manager = BDIProgramManager(name="program_manager_test")
manager.send = AsyncMock()
await manager._send_clear_llm_history()
assert manager.send.await_count == 1
msg: InternalMessage = manager.send.await_args[0][0]
# Verify the content and recipient
assert msg.body == "clear_history"
assert msg.to == "llm_agent"

View File

@@ -87,3 +87,49 @@ async def test_send_beliefs_to_bdi(agent):
assert sent.to == settings.agent_settings.bdi_core_name assert sent.to == settings.agent_settings.bdi_core_name
assert sent.thread == "beliefs" assert sent.thread == "beliefs"
assert json.loads(sent.body)["beliefs"] == [belief.model_dump() for belief in beliefs] assert json.loads(sent.body)["beliefs"] == [belief.model_dump() for belief in beliefs]
@pytest.mark.asyncio
async def test_setup_executes(agent):
"""Covers setup and asserts the agent has a name."""
await agent.setup()
assert agent.name == "belief_collector_agent" # simple property assertion
@pytest.mark.asyncio
async def test_handle_message_unrecognized_type_executes(agent):
"""Covers the else branch for unrecognized message type."""
payload = {"type": "unknown_type"}
msg = make_msg(payload, sender="tester")
# Wrap send to ensure nothing is sent
agent.send = AsyncMock()
await agent.handle_message(msg)
# Assert no messages were sent
agent.send.assert_not_awaited()
@pytest.mark.asyncio
async def test_handle_emo_text_executes(agent):
"""Covers the _handle_emo_text method."""
# The method does nothing, but we can assert it returns None
result = await agent._handle_emo_text({}, "origin")
assert result is None
@pytest.mark.asyncio
async def test_send_beliefs_to_bdi_empty_executes(agent):
"""Covers early return when beliefs are empty."""
agent.send = AsyncMock()
await agent._send_beliefs_to_bdi({})
# Assert that nothing was sent
agent.send.assert_not_awaited()
@pytest.mark.asyncio
async def test_handle_belief_text_invalid_returns_none(agent, mocker):
payload = {"type": "belief_extraction_text", "beliefs": {"user_said": "invalid-argument"}}
result = await agent._handle_belief_text(payload, "origin")
# The method itself returns None
assert result is None

View File

@@ -56,3 +56,10 @@ async def test_process_transcription_demo(agent, mock_settings):
assert sent.thread == "beliefs" assert sent.thread == "beliefs"
parsed = json.loads(sent.body) parsed = json.loads(sent.body)
assert parsed["beliefs"]["user_said"] == [transcription] assert parsed["beliefs"]["user_said"] == [transcription]
@pytest.mark.asyncio
async def test_setup_initializes_beliefs(agent):
"""Covers the setup method and ensures beliefs are initialized."""
await agent.setup()
assert agent.beliefs == {"mood": ["X"], "car": ["Y"]}

View File

@@ -61,15 +61,18 @@ async def test_setup_success_connects_and_starts_robot(zmq_context):
fake_socket.send_json.assert_any_call({"endpoint": "negotiate/ports", "data": {}}) fake_socket.send_json.assert_any_call({"endpoint": "negotiate/ports", "data": {}})
MockSpeech.return_value.start.assert_awaited_once() MockSpeech.return_value.start.assert_awaited_once()
MockGesture.return_value.start.assert_awaited_once() MockGesture.return_value.start.assert_awaited_once()
MockSpeech.assert_called_once_with(ANY, address="tcp://localhost:5556", bind=False) MockSpeech.assert_called_once_with(ANY, address="tcp://localhost:5556", bind=False)
MockGesture.assert_called_once_with( MockGesture.assert_called_once_with(
ANY, ANY,
address="tcp://localhost:5556", address="tcp://localhost:5556",
bind=False, bind=False,
gesture_data=[], gesture_tags=[],
gesture_basic=[],
gesture_single=[],
) )
agent.add_behavior.assert_called_once()
agent.add_behavior.assert_called_once()
assert agent.connected is True assert agent.connected is True
@@ -354,3 +357,13 @@ async def test_listen_loop_ping_sends_internal(zmq_context):
await agent._listen_loop() await agent._listen_loop()
pub_socket.send_multipart.assert_awaited() pub_socket.send_multipart.assert_awaited()
@pytest.mark.asyncio
async def test_negotiate_req_socket_none_causes_retry(zmq_context):
agent = RICommunicationAgent("ri_comm")
agent._req_socket = None
result = await agent._negotiate_connection(max_retries=1)
assert result is False

View File

@@ -49,6 +49,9 @@ async def test_llm_processing_success(mock_httpx_client, mock_settings):
agent = LLMAgent("llm_agent") agent = LLMAgent("llm_agent")
agent.send = AsyncMock() # Mock the send method to verify replies agent.send = AsyncMock() # Mock the send method to verify replies
mock_logger = MagicMock()
agent.logger = mock_logger
# Simulate receiving a message from BDI # Simulate receiving a message from BDI
prompt = LLMPromptMessage(text="Hi", norms=[], goals=[]) prompt = LLMPromptMessage(text="Hi", norms=[], goals=[])
msg = InternalMessage( msg = InternalMessage(
@@ -134,3 +137,151 @@ def test_llm_instructions():
text_def = instr_def.build_developer_instruction() text_def = instr_def.build_developer_instruction()
assert "Norms to follow" in text_def assert "Norms to follow" in text_def
assert "Goals to reach" in text_def assert "Goals to reach" in text_def
@pytest.mark.asyncio
async def test_handle_message_validation_error_branch_no_send(mock_httpx_client, mock_settings):
"""
Covers the ValidationError branch:
except ValidationError:
self.logger.debug("Prompt message from BDI core is invalid.")
Assert: no message is sent.
"""
agent = LLMAgent("llm_agent")
agent.send = AsyncMock()
# Invalid JSON that triggers ValidationError in LLMPromptMessage
invalid_json = '{"text": "Hi", "wrong_field": 123}' # field not in schema
msg = InternalMessage(
to="llm_agent",
sender=mock_settings.agent_settings.bdi_core_name,
body=invalid_json,
)
await agent.handle_message(msg)
# Should not send any reply
agent.send.assert_not_called()
@pytest.mark.asyncio
async def test_handle_message_ignored_sender_branch_no_send(mock_httpx_client, mock_settings):
"""
Covers the else branch for messages not from BDI core:
else:
self.logger.debug("Message ignored (not from BDI core.")
Assert: no message is sent.
"""
agent = LLMAgent("llm_agent")
agent.send = AsyncMock()
msg = InternalMessage(
to="llm_agent",
sender="some_other_agent", # Not BDI core
body='{"text": "Hi"}',
)
await agent.handle_message(msg)
# Should not send any reply
agent.send.assert_not_called()
@pytest.mark.asyncio
async def test_query_llm_yields_final_tail_chunk(mock_settings):
"""
Covers the branch: if current_chunk: yield current_chunk
Ensure that the last partial chunk is emitted.
"""
agent = LLMAgent("llm_agent")
agent.send = AsyncMock()
agent.logger = MagicMock()
agent.logger.llm = MagicMock()
# Patch _stream_query_llm to yield tokens that do NOT end with punctuation
async def fake_stream(messages):
yield "Hello"
yield " world" # No punctuation to trigger the normal chunking
agent._stream_query_llm = fake_stream
prompt = LLMPromptMessage(text="Hi", norms=[], goals=[])
# Collect chunks yielded
chunks = []
async for chunk in agent._query_llm(prompt.text, prompt.norms, prompt.goals):
chunks.append(chunk)
# The final chunk should be yielded
assert chunks[-1] == "Hello world"
assert any("Hello" in c for c in chunks)
@pytest.mark.asyncio
async def test_stream_query_llm_skips_non_data_lines(mock_httpx_client, mock_settings):
"""
Covers: if not line or not line.startswith("data: "): continue
Feed lines that are empty or do not start with 'data:' and check they are skipped.
"""
# Mock response
mock_response = MagicMock()
mock_response.raise_for_status = MagicMock()
lines = [
"", # empty line
"not data", # invalid prefix
'data: {"choices": [{"delta": {"content": "Hi"}}]}',
"data: [DONE]",
]
async def aiter_lines_gen():
for line in lines:
yield line
mock_response.aiter_lines.side_effect = aiter_lines_gen
# Proper async context manager for stream
mock_stream_context = MagicMock()
mock_stream_context.__aenter__ = AsyncMock(return_value=mock_response)
mock_stream_context.__aexit__ = AsyncMock(return_value=None)
# Make stream return the async context manager
mock_httpx_client.stream = MagicMock(return_value=mock_stream_context)
agent = LLMAgent("llm_agent")
agent.send = AsyncMock()
# Patch settings for local LLM URL
with patch("control_backend.agents.llm.llm_agent.settings") as mock_sett:
mock_sett.llm_settings.local_llm_url = "http://localhost"
mock_sett.llm_settings.local_llm_model = "test-model"
# Collect tokens
tokens = []
async for token in agent._stream_query_llm([]):
tokens.append(token)
# Only the valid 'data:' line should yield content
assert tokens == ["Hi"]
@pytest.mark.asyncio
async def test_clear_history_command(mock_settings):
"""Test that the 'clear_history' message clears the agent's memory."""
# setup LLM to have some history
mock_settings.agent_settings.bdi_program_manager_name = "bdi_program_manager_agent"
agent = LLMAgent("llm_agent")
agent.history = [
{"role": "user", "content": "Old conversation context"},
{"role": "assistant", "content": "Old response"},
]
assert len(agent.history) == 2
msg = InternalMessage(
to="llm_agent",
sender=mock_settings.agent_settings.bdi_program_manager_name,
body="clear_history",
)
await agent.handle_message(msg)
assert len(agent.history) == 0

View File

@@ -120,3 +120,83 @@ def test_mlx_recognizer():
mlx_mock.transcribe.return_value = {"text": "Hi"} mlx_mock.transcribe.return_value = {"text": "Hi"}
res = rec.recognize_speech(np.zeros(10)) res = rec.recognize_speech(np.zeros(10))
assert res == "Hi" assert res == "Hi"
@pytest.mark.asyncio
async def test_transcription_loop_continues_after_error(mock_zmq_context):
mock_sub = MagicMock()
mock_sub.recv = AsyncMock()
mock_zmq_context.instance.return_value.socket.return_value = mock_sub
fake_audio = np.zeros(16000, dtype=np.float32).tobytes()
mock_sub.recv.side_effect = [
fake_audio, # first iteration → recognizer fails
asyncio.CancelledError(), # second iteration → stop loop
]
with patch.object(SpeechRecognizer, "best_type") as mock_best:
mock_recognizer = MagicMock()
mock_recognizer.recognize_speech.side_effect = RuntimeError("fail")
mock_best.return_value = mock_recognizer
agent = TranscriptionAgent("tcp://in")
agent._running = True # ← REQUIRED to enter the loop
agent.send = AsyncMock() # should never be called
agent.add_behavior = AsyncMock() # match other tests
await agent.setup()
try:
await agent._transcribing_loop()
except asyncio.CancelledError:
pass
# recognizer failed, so we should never send anything
agent.send.assert_not_called()
# recv must have been called twice (audio then CancelledError)
assert mock_sub.recv.call_count == 2
@pytest.mark.asyncio
async def test_transcription_continue_branch_when_empty(mock_zmq_context):
mock_sub = MagicMock()
mock_sub.recv = AsyncMock()
mock_zmq_context.instance.return_value.socket.return_value = mock_sub
# First recv → audio chunk
# Second recv → Cancel loop → stop iteration
fake_audio = np.zeros(16000, dtype=np.float32).tobytes()
mock_sub.recv.side_effect = [fake_audio, asyncio.CancelledError()]
with patch.object(SpeechRecognizer, "best_type") as mock_best:
mock_recognizer = MagicMock()
mock_recognizer.recognize_speech.return_value = "" # <— triggers the continue branch
mock_best.return_value = mock_recognizer
agent = TranscriptionAgent("tcp://in")
# Make loop runnable
agent._running = True
agent.send = AsyncMock()
agent.add_behavior = AsyncMock()
await agent.setup()
# Execute loop manually
try:
await agent._transcribing_loop()
except asyncio.CancelledError:
pass
# → Because of "continue", NO sending should occur
agent.send.assert_not_called()
# → Continue was hit, so we must have read exactly 2 times:
# - first audio
# - second CancelledError
assert mock_sub.recv.call_count == 2
# → recognizer was called once (first iteration)
assert mock_recognizer.recognize_speech.call_count == 1

View File

@@ -1,11 +1,21 @@
from unittest.mock import AsyncMock, MagicMock from unittest.mock import AsyncMock, MagicMock, patch
import numpy as np import numpy as np
import pytest import pytest
import zmq
from control_backend.agents.perception.vad_agent import VADAgent from control_backend.agents.perception.vad_agent import VADAgent
# We don't want to use real ZMQ in unit tests, for example because it can give errors when sockets
# aren't closed properly.
@pytest.fixture(autouse=True)
def mock_zmq():
with patch("zmq.asyncio.Context") as mock:
mock.instance.return_value = MagicMock()
yield mock
@pytest.fixture @pytest.fixture
def audio_out_socket(): def audio_out_socket():
return AsyncMock() return AsyncMock()
@@ -123,3 +133,42 @@ async def test_no_data(audio_out_socket, vad_agent):
audio_out_socket.send.assert_not_called() audio_out_socket.send.assert_not_called()
assert len(vad_agent.audio_buffer) == 0 assert len(vad_agent.audio_buffer) == 0
@pytest.mark.asyncio
async def test_vad_model_load_failure_stops_agent(vad_agent):
"""
Test that if loading the VAD model raises an Exception, it is caught,
the agent logs an exception, stops itself, and setup returns.
"""
# Patch torch.hub.load to raise an exception
with patch(
"control_backend.agents.perception.vad_agent.torch.hub.load",
side_effect=Exception("model fail"),
):
# Patch stop to an AsyncMock so we can check it was awaited
vad_agent.stop = AsyncMock()
await vad_agent.setup()
# Assert stop was called
vad_agent.stop.assert_awaited_once()
@pytest.mark.asyncio
async def test_audio_out_bind_failure_sets_none_and_logs(vad_agent, caplog):
"""
Test that if binding the output socket raises ZMQBindError,
audio_out_socket is set to None, None is returned, and an error is logged.
"""
mock_socket = MagicMock()
mock_socket.bind.side_effect = zmq.ZMQBindError()
with patch("control_backend.agents.perception.vad_agent.azmq.Context.instance") as mock_ctx:
mock_ctx.return_value.socket.return_value = mock_socket
with caplog.at_level("ERROR"):
port = vad_agent._connect_audio_out_socket()
assert port is None
assert vad_agent.audio_out_socket is None
assert caplog.text is not None

View File

@@ -0,0 +1,146 @@
import asyncio
import json
from unittest.mock import AsyncMock, MagicMock
import pytest
from control_backend.agents.user_interrupt.user_interrupt_agent import UserInterruptAgent
from control_backend.core.agent_system import InternalMessage
from control_backend.core.config import settings
from control_backend.schemas.ri_message import RIEndpoint
@pytest.fixture
def agent():
agent = UserInterruptAgent(name="user_interrupt_agent")
agent.send = AsyncMock()
agent.logger = MagicMock()
agent.sub_socket = AsyncMock()
return agent
@pytest.mark.asyncio
async def test_send_to_speech_agent(agent):
"""Verify speech command format."""
await agent._send_to_speech_agent("Hello World")
agent.send.assert_awaited_once()
sent_msg: InternalMessage = agent.send.call_args.args[0]
assert sent_msg.to == settings.agent_settings.robot_speech_name
body = json.loads(sent_msg.body)
assert body["data"] == "Hello World"
assert body["is_priority"] is True
@pytest.mark.asyncio
async def test_send_to_gesture_agent(agent):
"""Verify gesture command format."""
await agent._send_to_gesture_agent("wave_hand")
agent.send.assert_awaited_once()
sent_msg: InternalMessage = agent.send.call_args.args[0]
assert sent_msg.to == settings.agent_settings.robot_gesture_name
body = json.loads(sent_msg.body)
assert body["data"] == "wave_hand"
assert body["is_priority"] is True
assert body["endpoint"] == RIEndpoint.GESTURE_SINGLE.value
@pytest.mark.asyncio
async def test_send_to_program_manager(agent):
"""Verify belief update format."""
context_str = "2"
await agent._send_to_program_manager(context_str)
agent.send.assert_awaited_once()
sent_msg: InternalMessage = agent.send.call_args.args[0]
assert sent_msg.to == settings.agent_settings.bdi_program_manager_name
assert sent_msg.thread == "belief_override_id"
body = json.loads(sent_msg.body)
assert body["belief"] == context_str
@pytest.mark.asyncio
async def test_receive_loop_routing_success(agent):
"""
Test that the loop correctly:
1. Receives 'button_pressed' topic from ZMQ
2. Parses the JSON payload to find 'type' and 'context'
3. Calls the correct handler method based on 'type'
"""
# Prepare JSON payloads as bytes
payload_speech = json.dumps({"type": "speech", "context": "Hello Speech"}).encode()
payload_gesture = json.dumps({"type": "gesture", "context": "Hello Gesture"}).encode()
payload_override = json.dumps({"type": "override", "context": "Hello Override"}).encode()
agent.sub_socket.recv_multipart.side_effect = [
(b"button_pressed", payload_speech),
(b"button_pressed", payload_gesture),
(b"button_pressed", payload_override),
asyncio.CancelledError, # Stop the infinite loop
]
agent._send_to_speech_agent = AsyncMock()
agent._send_to_gesture_agent = AsyncMock()
agent._send_to_program_manager = AsyncMock()
try:
await agent._receive_button_event()
except asyncio.CancelledError:
pass
await asyncio.sleep(0)
# Speech
agent._send_to_speech_agent.assert_awaited_once_with("Hello Speech")
# Gesture
agent._send_to_gesture_agent.assert_awaited_once_with("Hello Gesture")
# Override
agent._send_to_program_manager.assert_awaited_once_with("Hello Override")
assert agent._send_to_speech_agent.await_count == 1
assert agent._send_to_gesture_agent.await_count == 1
assert agent._send_to_program_manager.await_count == 1
@pytest.mark.asyncio
async def test_receive_loop_unknown_type(agent):
"""Test that unknown 'type' values in the JSON log a warning and do not crash."""
# Prepare a payload with an unknown type
payload_unknown = json.dumps({"type": "unknown_thing", "context": "some_data"}).encode()
agent.sub_socket.recv_multipart.side_effect = [
(b"button_pressed", payload_unknown),
asyncio.CancelledError,
]
agent._send_to_speech_agent = AsyncMock()
agent._send_to_gesture_agent = AsyncMock()
agent._send_to_belief_collector = AsyncMock()
try:
await agent._receive_button_event()
except asyncio.CancelledError:
pass
await asyncio.sleep(0)
# Ensure no handlers were called
agent._send_to_speech_agent.assert_not_called()
agent._send_to_gesture_agent.assert_not_called()
agent._send_to_belief_collector.assert_not_called()
agent.logger.warning.assert_called_with(
"Received button press with unknown type '%s' (context: '%s').",
"unknown_thing",
"some_data",
)

View File

@@ -0,0 +1,63 @@
from unittest.mock import patch
import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
from starlette.responses import StreamingResponse
from control_backend.api.v1.endpoints import logs
@pytest.fixture
def client():
"""TestClient with logs router included."""
app = FastAPI()
app.include_router(logs.router)
return TestClient(app)
@pytest.mark.asyncio
async def test_log_stream_endpoint_lines(client):
"""Call /logs/stream with a mocked ZMQ socket to cover all lines."""
# Dummy socket to mock ZMQ behavior
class DummySocket:
def __init__(self):
self.subscribed = []
self.connected = False
self.recv_count = 0
def subscribe(self, topic):
self.subscribed.append(topic)
def connect(self, addr):
self.connected = True
async def recv_multipart(self):
# Return one message, then stop generator
if self.recv_count == 0:
self.recv_count += 1
return (b"INFO", b"test message")
else:
raise StopAsyncIteration
dummy_socket = DummySocket()
# Patch Context.instance().socket() to return dummy socket
with patch("control_backend.api.v1.endpoints.logs.Context.instance") as mock_context:
mock_context.return_value.socket.return_value = dummy_socket
# Call the endpoint directly
response = await logs.log_stream()
assert isinstance(response, StreamingResponse)
# Fetch one chunk from the generator
gen = response.body_iterator
chunk = await gen.__anext__()
if isinstance(chunk, bytes):
chunk = chunk.decode("utf-8")
assert "data:" in chunk
# Optional: assert subscribe/connect were called
assert dummy_socket.subscribed # at least some log levels subscribed
assert dummy_socket.connected # connect was called

View File

@@ -0,0 +1,45 @@
import json
import pytest
from fastapi.testclient import TestClient
from control_backend.api.v1.endpoints import message
@pytest.fixture
def client():
"""FastAPI TestClient for the message router."""
from fastapi import FastAPI
app = FastAPI()
app.include_router(message.router)
return TestClient(app)
def test_receive_message_post(client, monkeypatch):
"""Test POST /message endpoint sends message to pub socket."""
# Dummy pub socket to capture sent messages
class DummyPubSocket:
def __init__(self):
self.sent = []
async def send_multipart(self, msg):
self.sent.append(msg)
dummy_socket = DummyPubSocket()
# Patch app.state.endpoints_pub_socket
client.app.state.endpoints_pub_socket = dummy_socket
data = {"message": "Hello world"}
response = client.post("/message", json=data)
assert response.status_code == 202
assert response.json() == {"status": "Message received"}
# Ensure the message was sent via pub_socket
assert len(dummy_socket.sent) == 1
topic, body = dummy_socket.sent[0]
parsed = json.loads(body.decode("utf-8"))
assert parsed["message"] == "Hello world"

View File

@@ -1,3 +1,4 @@
# tests/test_robot_endpoints.py
import json import json
from unittest.mock import AsyncMock, MagicMock, patch from unittest.mock import AsyncMock, MagicMock, patch
@@ -29,7 +30,7 @@ def client(app):
@pytest.fixture @pytest.fixture
def mock_zmq_context(): def mock_zmq_context():
"""Mock the ZMQ context.""" """Mock the ZMQ context used by the endpoint module."""
with patch("control_backend.api.v1.endpoints.robot.Context.instance") as mock_context: with patch("control_backend.api.v1.endpoints.robot.Context.instance") as mock_context:
context_instance = MagicMock() context_instance = MagicMock()
mock_context.return_value = context_instance mock_context.return_value = context_instance
@@ -38,13 +39,13 @@ def mock_zmq_context():
@pytest.fixture @pytest.fixture
def mock_sockets(mock_zmq_context): def mock_sockets(mock_zmq_context):
"""Mock ZMQ sockets.""" """Optional helper if you want both a sub and req/push socket available."""
mock_sub_socket = AsyncMock(spec=zmq.asyncio.Socket) mock_sub_socket = AsyncMock(spec=zmq.asyncio.Socket)
mock_pub_socket = AsyncMock(spec=zmq.asyncio.Socket) mock_req_socket = AsyncMock(spec=zmq.asyncio.Socket)
mock_zmq_context.socket.return_value = mock_sub_socket mock_zmq_context.socket.return_value = mock_sub_socket
return {"sub": mock_sub_socket, "pub": mock_pub_socket} return {"sub": mock_sub_socket, "req": mock_req_socket}
def test_receive_speech_command_success(client): def test_receive_speech_command_success(client):
@@ -61,11 +62,11 @@ def test_receive_speech_command_success(client):
speech_command = SpeechCommand(**command_data) speech_command = SpeechCommand(**command_data)
# Act # Act
response = client.post("/command", json=command_data) response = client.post("/command/speech", json=command_data)
# Assert # Assert
assert response.status_code == 202 assert response.status_code == 202
assert response.json() == {"status": "Command received"} assert response.json() == {"status": "Speech command received"}
# Verify that the ZMQ socket was used correctly # Verify that the ZMQ socket was used correctly
mock_pub_socket.send_multipart.assert_awaited_once_with( mock_pub_socket.send_multipart.assert_awaited_once_with(
@@ -75,9 +76,8 @@ def test_receive_speech_command_success(client):
def test_receive_gesture_command_success(client): def test_receive_gesture_command_success(client):
""" """
Test for successful reception of a command. Ensures the status code is 202 and the response body Test for successful reception of a command that is a gesture command.
is correct. It also verifies that the ZeroMQ socket's send_multipart method is called with the Ensures the status code is 202 and the response body is correct.
expected data.
""" """
# Arrange # Arrange
mock_pub_socket = AsyncMock() mock_pub_socket = AsyncMock()
@@ -87,11 +87,11 @@ def test_receive_gesture_command_success(client):
gesture_command = GestureCommand(**command_data) gesture_command = GestureCommand(**command_data)
# Act # Act
response = client.post("/command", json=command_data) response = client.post("/command/gesture", json=command_data)
# Assert # Assert
assert response.status_code == 202 assert response.status_code == 202
assert response.json() == {"status": "Command received"} assert response.json() == {"status": "Gesture command received"}
# Verify that the ZMQ socket was used correctly # Verify that the ZMQ socket was used correctly
mock_pub_socket.send_multipart.assert_awaited_once_with( mock_pub_socket.send_multipart.assert_awaited_once_with(
@@ -99,13 +99,23 @@ def test_receive_gesture_command_success(client):
) )
def test_receive_command_invalid_payload(client): def test_receive_speech_command_invalid_payload(client):
""" """
Test invalid data handling (schema validation). Test invalid data handling (schema validation).
""" """
# Missing required field(s) # Missing required field(s)
bad_payload = {"invalid": "data"} bad_payload = {"invalid": "data"}
response = client.post("/command", json=bad_payload) response = client.post("/command/speech", json=bad_payload)
assert response.status_code == 422 # validation error
def test_receive_gesture_command_invalid_payload(client):
"""
Test invalid data handling (schema validation).
"""
# Missing required field(s)
bad_payload = {"invalid": "data"}
response = client.post("/command/gesture", json=bad_payload)
assert response.status_code == 422 # validation error assert response.status_code == 422 # validation error
@@ -116,7 +126,9 @@ def test_ping_check_returns_none(client):
assert response.json() is None assert response.json() is None
# TODO: Convert these mock sockets to the fixture. # ----------------------------
# ping_stream tests (unchanged behavior)
# ----------------------------
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_ping_stream_yields_ping_event(monkeypatch): async def test_ping_stream_yields_ping_event(monkeypatch):
"""Test that ping_stream yields a proper SSE message when a ping is received.""" """Test that ping_stream yields a proper SSE message when a ping is received."""
@@ -129,6 +141,11 @@ async def test_ping_stream_yields_ping_event(monkeypatch):
mock_context.socket.return_value = mock_sub_socket mock_context.socket.return_value = mock_sub_socket
monkeypatch.setattr(robot.Context, "instance", lambda: mock_context) monkeypatch.setattr(robot.Context, "instance", lambda: mock_context)
# patch settings address used by ping_stream
mock_settings = MagicMock()
mock_settings.zmq_settings.internal_sub_address = "tcp://localhost:5555"
monkeypatch.setattr(robot, "settings", mock_settings)
mock_request = AsyncMock() mock_request = AsyncMock()
mock_request.is_disconnected = AsyncMock(side_effect=[False, True]) mock_request.is_disconnected = AsyncMock(side_effect=[False, True])
@@ -142,7 +159,7 @@ async def test_ping_stream_yields_ping_event(monkeypatch):
with pytest.raises(StopAsyncIteration): with pytest.raises(StopAsyncIteration):
await anext(generator) await anext(generator)
mock_sub_socket.connect.assert_called_once() mock_sub_socket.connect.assert_called_once_with("tcp://localhost:5555")
mock_sub_socket.setsockopt.assert_called_once_with(robot.zmq.SUBSCRIBE, b"ping") mock_sub_socket.setsockopt.assert_called_once_with(robot.zmq.SUBSCRIBE, b"ping")
mock_sub_socket.recv_multipart.assert_awaited() mock_sub_socket.recv_multipart.assert_awaited()
@@ -159,6 +176,10 @@ async def test_ping_stream_handles_timeout(monkeypatch):
mock_context.socket.return_value = mock_sub_socket mock_context.socket.return_value = mock_sub_socket
monkeypatch.setattr(robot.Context, "instance", lambda: mock_context) monkeypatch.setattr(robot.Context, "instance", lambda: mock_context)
mock_settings = MagicMock()
mock_settings.zmq_settings.internal_sub_address = "tcp://localhost:5555"
monkeypatch.setattr(robot, "settings", mock_settings)
mock_request = AsyncMock() mock_request = AsyncMock()
mock_request.is_disconnected = AsyncMock(return_value=True) mock_request.is_disconnected = AsyncMock(return_value=True)
@@ -168,7 +189,7 @@ async def test_ping_stream_handles_timeout(monkeypatch):
with pytest.raises(StopAsyncIteration): with pytest.raises(StopAsyncIteration):
await anext(generator) await anext(generator)
mock_sub_socket.connect.assert_called_once() mock_sub_socket.connect.assert_called_once_with("tcp://localhost:5555")
mock_sub_socket.setsockopt.assert_called_once_with(robot.zmq.SUBSCRIBE, b"ping") mock_sub_socket.setsockopt.assert_called_once_with(robot.zmq.SUBSCRIBE, b"ping")
mock_sub_socket.recv_multipart.assert_awaited() mock_sub_socket.recv_multipart.assert_awaited()
@@ -187,6 +208,10 @@ async def test_ping_stream_yields_json_values(monkeypatch):
mock_context.socket.return_value = mock_sub_socket mock_context.socket.return_value = mock_sub_socket
monkeypatch.setattr(robot.Context, "instance", lambda: mock_context) monkeypatch.setattr(robot.Context, "instance", lambda: mock_context)
mock_settings = MagicMock()
mock_settings.zmq_settings.internal_sub_address = "tcp://localhost:5555"
monkeypatch.setattr(robot, "settings", mock_settings)
mock_request = AsyncMock() mock_request = AsyncMock()
mock_request.is_disconnected = AsyncMock(side_effect=[False, True]) mock_request.is_disconnected = AsyncMock(side_effect=[False, True])
@@ -199,254 +224,169 @@ async def test_ping_stream_yields_json_values(monkeypatch):
assert "connected" in event_text assert "connected" in event_text
assert "true" in event_text assert "true" in event_text
mock_sub_socket.connect.assert_called_once() mock_sub_socket.connect.assert_called_once_with("tcp://localhost:5555")
mock_sub_socket.setsockopt.assert_called_once_with(robot.zmq.SUBSCRIBE, b"ping") mock_sub_socket.setsockopt.assert_called_once_with(robot.zmq.SUBSCRIBE, b"ping")
mock_sub_socket.recv_multipart.assert_awaited() mock_sub_socket.recv_multipart.assert_awaited()
# New tests for get_available_gesture_tags endpoint
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_available_gesture_tags_success(client, monkeypatch): async def test_get_available_single_gestures_success(client, monkeypatch):
""" """
Test successful retrieval of available gesture tags. Test successful retrieval of single gestures.
""" """
# Arrange mock_req_socket = AsyncMock(spec=zmq.asyncio.Socket)
mock_sub_socket = AsyncMock() mock_req_socket.connect = MagicMock()
mock_sub_socket.connect = MagicMock() mock_req_socket.send_json = AsyncMock()
mock_sub_socket.setsockopt = MagicMock() response_data = {"single_gestures": ["wave", "point"]}
mock_req_socket.recv = AsyncMock(return_value=json.dumps(response_data).encode())
# Simulate a response with gesture tags
response_data = {"tags": ["wave", "nod", "point", "dance"]}
mock_sub_socket.recv_multipart = AsyncMock(
return_value=[b"get_gestures", json.dumps(response_data).encode()]
)
mock_context = MagicMock() mock_context = MagicMock()
mock_context.socket.return_value = mock_sub_socket mock_context.socket.return_value = mock_req_socket
monkeypatch.setattr(robot.Context, "instance", lambda: mock_context) monkeypatch.setattr(robot.Context, "instance", lambda: mock_context)
mock_pub_socket = AsyncMock() monkeypatch.setattr(robot.logger, "debug", MagicMock())
client.app.state.endpoints_pub_socket = mock_pub_socket monkeypatch.setattr(robot.logger, "error", MagicMock())
# Mock settings response = client.get("/commands/gesture/single")
mock_settings = MagicMock()
mock_settings.zmq_settings.internal_sub_address = "tcp://localhost:5555"
monkeypatch.setattr(robot, "settings", mock_settings)
# Mock logger to avoid actual logging
mock_logger = MagicMock()
monkeypatch.setattr(robot.logger, "debug", mock_logger)
# Act
response = client.get("/get_available_gesture_tags")
# Assert
assert response.status_code == 200 assert response.status_code == 200
assert response.json() == {"available_gesture_tags": ["wave", "nod", "point", "dance"]} assert response.json() == {"available_gestures": ["wave", "point"]}
# Verify ZeroMQ interactions mock_req_socket.connect.assert_called_once_with("tcp://localhost:7788")
mock_sub_socket.connect.assert_called_once_with("tcp://localhost:5555") mock_req_socket.send_json.assert_awaited_once_with({"type": "single", "count": None})
mock_sub_socket.setsockopt.assert_called_once_with(robot.zmq.SUBSCRIBE, b"get_gestures") mock_req_socket.recv.assert_awaited_once()
mock_pub_socket.send_multipart.assert_awaited_once_with([b"send_gestures", b""])
mock_sub_socket.recv_multipart.assert_awaited_once()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_available_gesture_tags_with_amount(client, monkeypatch): async def test_get_available_single_gestures_timeout(client, monkeypatch):
""" mock_req_socket = AsyncMock(spec=zmq.asyncio.Socket)
Test retrieval of gesture tags with a specific amount parameter. mock_req_socket.connect = MagicMock()
This tests the TODO in the endpoint about getting a certain amount from the UI. mock_req_socket.send_json = AsyncMock()
""" mock_req_socket.recv = AsyncMock(side_effect=TimeoutError)
# Arrange
mock_sub_socket = AsyncMock()
mock_sub_socket.connect = MagicMock()
mock_sub_socket.setsockopt = MagicMock()
# Simulate a response with gesture tags
response_data = {"tags": ["wave", "nod"]}
mock_sub_socket.recv_multipart = AsyncMock(
return_value=[b"get_gestures", json.dumps(response_data).encode()]
)
mock_context = MagicMock() mock_context = MagicMock()
mock_context.socket.return_value = mock_sub_socket mock_context.socket.return_value = mock_req_socket
monkeypatch.setattr(robot.Context, "instance", lambda: mock_context) monkeypatch.setattr(robot.Context, "instance", lambda: mock_context)
mock_pub_socket = AsyncMock() response = client.get("/commands/gesture/single")
client.app.state.endpoints_pub_socket = mock_pub_socket
# Mock settings
mock_settings = MagicMock()
mock_settings.zmq_settings.internal_sub_address = "tcp://localhost:5555"
monkeypatch.setattr(robot, "settings", mock_settings)
# Mock logger
mock_logger = MagicMock()
monkeypatch.setattr(robot.logger, "debug", mock_logger)
# Act - Note: The endpoint currently doesn't support query parameters for amount,
# but we're testing what happens if the UI sends an amount (the TODO in the code)
# For now, we test the current behavior
response = client.get("/get_available_gesture_tags")
# Assert
assert response.status_code == 200 assert response.status_code == 200
assert response.json() == {"available_gesture_tags": ["wave", "nod"]} assert response.json() == {"available_gestures": []}
# The endpoint currently doesn't use the amount parameter, so it should send empty bytes
mock_pub_socket.send_multipart.assert_awaited_once_with([b"send_gestures", b""])
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_available_gesture_tags_timeout(client, monkeypatch): async def test_get_available_single_gestures_missing_key(client, monkeypatch):
""" """
Test timeout scenario when fetching gesture tags. Test response missing 'single_gestures' key.
""" """
# Arrange mock_req_socket = AsyncMock(spec=zmq.asyncio.Socket)
mock_sub_socket = AsyncMock() mock_req_socket.connect = MagicMock()
mock_sub_socket.connect = MagicMock() mock_req_socket.send_json = AsyncMock()
mock_sub_socket.setsockopt = MagicMock() mock_req_socket.recv = AsyncMock(return_value=json.dumps({"unexpected": "value"}).encode())
# Simulate a timeout
mock_sub_socket.recv_multipart = AsyncMock(side_effect=TimeoutError)
mock_context = MagicMock() mock_context = MagicMock()
mock_context.socket.return_value = mock_sub_socket mock_context.socket.return_value = mock_req_socket
monkeypatch.setattr(robot.Context, "instance", lambda: mock_context) monkeypatch.setattr(robot.Context, "instance", lambda: mock_context)
mock_pub_socket = AsyncMock() monkeypatch.setattr(robot.logger, "debug", MagicMock())
client.app.state.endpoints_pub_socket = mock_pub_socket monkeypatch.setattr(robot.logger, "error", MagicMock())
# Mock settings response = client.get("/commands/gesture/single")
mock_settings = MagicMock()
mock_settings.zmq_settings.internal_sub_address = "tcp://localhost:5555"
monkeypatch.setattr(robot, "settings", mock_settings)
# Mock logger to verify debug message is logged
mock_logger = MagicMock()
monkeypatch.setattr(robot.logger, "debug", mock_logger)
# Act
response = client.get("/get_available_gesture_tags")
# Assert
assert response.status_code == 200 assert response.status_code == 200
# On timeout, body becomes b"" and json.loads(b"") raises JSONDecodeError assert response.json() == {"available_gestures": []}
# But looking at the endpoint code, it will try to parse empty bytes which will fail
# Let's check what actually happens
assert response.json() == {"available_gesture_tags": []}
# Verify the timeout was logged
mock_logger.assert_called_once_with("got timeout error fetching gestures")
# Verify ZeroMQ interactions
mock_sub_socket.connect.assert_called_once_with("tcp://localhost:5555")
mock_sub_socket.setsockopt.assert_called_once_with(robot.zmq.SUBSCRIBE, b"get_gestures")
mock_pub_socket.send_multipart.assert_awaited_once_with([b"send_gestures", b""])
mock_sub_socket.recv_multipart.assert_awaited_once()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_available_gesture_tags_empty_response(client, monkeypatch): async def test_get_available_single_gestures_invalid_json(client, monkeypatch):
""" """
Test scenario when response contains no tags. Test invalid JSON response for single gestures.
""" """
# Arrange mock_req_socket = AsyncMock(spec=zmq.asyncio.Socket)
mock_sub_socket = AsyncMock() mock_req_socket.connect = MagicMock()
mock_sub_socket.connect = MagicMock() mock_req_socket.send_json = AsyncMock()
mock_sub_socket.setsockopt = MagicMock() mock_req_socket.recv = AsyncMock(return_value=b"not-json")
# Simulate a response with empty tags
response_data = {"tags": []}
mock_sub_socket.recv_multipart = AsyncMock(
return_value=[b"get_gestures", json.dumps(response_data).encode()]
)
mock_context = MagicMock() mock_context = MagicMock()
mock_context.socket.return_value = mock_sub_socket mock_context.socket.return_value = mock_req_socket
monkeypatch.setattr(robot.Context, "instance", lambda: mock_context) monkeypatch.setattr(robot.Context, "instance", lambda: mock_context)
mock_pub_socket = AsyncMock() mock_error = MagicMock()
client.app.state.endpoints_pub_socket = mock_pub_socket monkeypatch.setattr(robot.logger, "error", mock_error)
monkeypatch.setattr(robot.logger, "debug", MagicMock())
# Mock settings response = client.get("/commands/gesture/single")
mock_settings = MagicMock()
mock_settings.zmq_settings.internal_sub_address = "tcp://localhost:5555"
monkeypatch.setattr(robot, "settings", mock_settings)
# Act
response = client.get("/get_available_gesture_tags")
# Assert
assert response.status_code == 200 assert response.status_code == 200
assert response.json() == {"available_gesture_tags": []} assert response.json() == {"available_gestures": []}
assert mock_error.call_count == 1
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_available_gesture_tags_missing_tags_key(client, monkeypatch): async def test_get_available_basic_gestures_success(client, monkeypatch):
""" """
Test scenario when response JSON doesn't contain 'tags' key. Test successful retrieval of basic gestures.
""" """
# Arrange mock_req_socket = AsyncMock(spec=zmq.asyncio.Socket)
mock_sub_socket = AsyncMock() mock_req_socket.connect = MagicMock()
mock_sub_socket.connect = MagicMock() mock_req_socket.send_json = AsyncMock()
mock_sub_socket.setsockopt = MagicMock() response_data = {"basic_gestures": ["nod", "shake"]}
mock_req_socket.recv = AsyncMock(return_value=json.dumps(response_data).encode())
# Simulate a response without 'tags' key
response_data = {"some_other_key": "value"}
mock_sub_socket.recv_multipart = AsyncMock(
return_value=[b"get_gestures", json.dumps(response_data).encode()]
)
mock_context = MagicMock() mock_context = MagicMock()
mock_context.socket.return_value = mock_sub_socket mock_context.socket.return_value = mock_req_socket
monkeypatch.setattr(robot.Context, "instance", lambda: mock_context) monkeypatch.setattr(robot.Context, "instance", lambda: mock_context)
mock_pub_socket = AsyncMock() monkeypatch.setattr(robot.logger, "debug", MagicMock())
client.app.state.endpoints_pub_socket = mock_pub_socket monkeypatch.setattr(robot.logger, "error", MagicMock())
# Mock settings response = client.get("/commands/gesture/basic")
mock_settings = MagicMock()
mock_settings.zmq_settings.internal_sub_address = "tcp://localhost:5555"
monkeypatch.setattr(robot, "settings", mock_settings)
# Act
response = client.get("/get_available_gesture_tags")
# Assert
assert response.status_code == 200 assert response.status_code == 200
# .get("tags", []) should return empty list if 'tags' key is missing assert response.json() == {"available_gestures": ["nod", "shake"]}
assert response.json() == {"available_gesture_tags": []}
mock_req_socket.connect.assert_called_once_with("tcp://localhost:7788")
mock_req_socket.send_json.assert_awaited_once_with({"type": "basic", "count": None})
mock_req_socket.recv.assert_awaited_once()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_available_gesture_tags_invalid_json(client, monkeypatch): async def test_get_available_basic_gestures_timeout(client, monkeypatch):
""" mock_req_socket = AsyncMock(spec=zmq.asyncio.Socket)
Test scenario when response contains invalid JSON. mock_req_socket.connect = MagicMock()
""" mock_req_socket.send_json = AsyncMock()
# Arrange mock_req_socket.recv = AsyncMock(side_effect=TimeoutError)
mock_sub_socket = AsyncMock()
mock_sub_socket.connect = MagicMock()
mock_sub_socket.setsockopt = MagicMock()
# Simulate a response with invalid JSON
mock_sub_socket.recv_multipart = AsyncMock(return_value=[b"get_gestures", b"invalid json"])
mock_context = MagicMock() mock_context = MagicMock()
mock_context.socket.return_value = mock_sub_socket mock_context.socket.return_value = mock_req_socket
monkeypatch.setattr(robot.Context, "instance", lambda: mock_context) monkeypatch.setattr(robot.Context, "instance", lambda: mock_context)
mock_pub_socket = AsyncMock() response = client.get("/commands/gesture/basic")
client.app.state.endpoints_pub_socket = mock_pub_socket
# Mock settings
mock_settings = MagicMock()
mock_settings.zmq_settings.internal_sub_address = "tcp://localhost:5555"
monkeypatch.setattr(robot, "settings", mock_settings)
# Act
response = client.get("/get_available_gesture_tags")
# Assert - invalid JSON should raise an exception
assert response.status_code == 200 assert response.status_code == 200
assert response.json() == {"available_gesture_tags": []} assert response.json() == {"available_gestures": []}
@pytest.mark.asyncio
async def test_get_available_basic_gestures_invalid_json(client, monkeypatch):
"""
Test invalid JSON response for basic gestures.
"""
mock_req_socket = AsyncMock(spec=zmq.asyncio.Socket)
mock_req_socket.connect = MagicMock()
mock_req_socket.send_json = AsyncMock()
mock_req_socket.recv = AsyncMock(return_value=b"{invalid json")
mock_context = MagicMock()
mock_context.socket.return_value = mock_req_socket
monkeypatch.setattr(robot.Context, "instance", lambda: mock_context)
mock_error = MagicMock()
monkeypatch.setattr(robot.logger, "error", mock_error)
monkeypatch.setattr(robot.logger, "debug", MagicMock())
response = client.get("/commands/gesture/basic")
assert response.status_code == 200
assert response.json() == {"available_gestures": []}
assert mock_error.call_count == 1

View File

@@ -0,0 +1,16 @@
from fastapi.routing import APIRoute
from control_backend.api.v1.router import api_router # <--- corrected import
def test_router_includes_expected_paths():
"""Ensure api_router includes main router prefixes."""
routes = [r for r in api_router.routes if isinstance(r, APIRoute)]
paths = [r.path for r in routes]
# Ensure at least one route under each prefix exists
assert any(p.startswith("/robot") for p in paths)
assert any(p.startswith("/message") for p in paths)
assert any(p.startswith("/sse") for p in paths)
assert any(p.startswith("/logs") for p in paths)
assert any(p.startswith("/program") for p in paths)

View File

@@ -0,0 +1,24 @@
import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
from control_backend.api.v1.endpoints import sse
@pytest.fixture
def app():
app = FastAPI()
app.include_router(sse.router)
return app
@pytest.fixture
def client(app):
return TestClient(app)
def test_sse_route_exists(client):
"""Minimal smoke test to ensure /sse route exists and responds."""
response = client.get("/sse")
# Since implementation is not done, we only assert it doesn't crash
assert response.status_code == 200

View File

@@ -2,7 +2,7 @@
import asyncio import asyncio
import logging import logging
from unittest.mock import AsyncMock from unittest.mock import AsyncMock, MagicMock
import pytest import pytest
@@ -70,3 +70,142 @@ async def test_get_agent():
agent = ConcreteTestAgent("registrant") agent = ConcreteTestAgent("registrant")
assert AgentDirectory.get("registrant") == agent assert AgentDirectory.get("registrant") == agent
assert AgentDirectory.get("non_existent") is None assert AgentDirectory.get("non_existent") is None
class DummyAgent(BaseAgent):
async def setup(self):
pass # we will test this separately
async def handle_message(self, msg: InternalMessage):
self.last_handled = msg
@pytest.mark.asyncio
async def test_base_agent_setup_is_noop():
agent = DummyAgent("dummy")
# Should simply return without error
assert await agent.setup() is None
@pytest.mark.asyncio
async def test_send_to_local_agent(monkeypatch):
sender = DummyAgent("sender")
target = DummyAgent("receiver")
# Fake logger
sender.logger = MagicMock()
# Patch inbox.put
target.inbox.put = AsyncMock()
message = InternalMessage(to="receiver", sender="sender", body="hello")
await sender.send(message)
target.inbox.put.assert_awaited_once_with(message)
sender.logger.debug.assert_called_once()
@pytest.mark.asyncio
async def test_process_inbox_calls_handle_message(monkeypatch):
agent = DummyAgent("dummy")
agent.logger = MagicMock()
# Make agent running so loop triggers
agent._running = True
# Prepare inbox to give one message then stop
msg = InternalMessage(to="dummy", sender="x", body="test")
async def get_once():
agent._running = False # stop after first iteration
return msg
agent.inbox.get = AsyncMock(side_effect=get_once)
agent.handle_message = AsyncMock()
await agent._process_inbox()
agent.handle_message.assert_awaited_once_with(msg)
@pytest.mark.asyncio
async def test_receive_internal_zmq_loop_success(monkeypatch):
agent = DummyAgent("dummy")
agent.logger = MagicMock()
agent._running = True
mock_socket = MagicMock()
mock_socket.recv_multipart = AsyncMock(
side_effect=[
(
b"topic",
InternalMessage(to="dummy", sender="x", body="hi").model_dump_json().encode(),
),
asyncio.CancelledError(), # stop loop
]
)
agent._internal_sub_socket = mock_socket
agent.inbox.put = AsyncMock()
await agent._receive_internal_zmq_loop()
agent.inbox.put.assert_awaited() # message forwarded
@pytest.mark.asyncio
async def test_receive_internal_zmq_loop_exception_logs_error():
agent = DummyAgent("dummy")
agent.logger = MagicMock()
agent._running = True
mock_socket = MagicMock()
mock_socket.recv_multipart = AsyncMock(
side_effect=[Exception("boom"), asyncio.CancelledError()]
)
agent._internal_sub_socket = mock_socket
agent.inbox.put = AsyncMock()
await agent._receive_internal_zmq_loop()
agent.logger.exception.assert_called_once()
assert "Could not process ZMQ message." in agent.logger.exception.call_args[0][0]
@pytest.mark.asyncio
async def test_base_agent_handle_message_not_implemented():
class RawAgent(BaseAgent):
async def setup(self):
pass
agent = RawAgent("raw")
msg = InternalMessage(to="raw", sender="x", body="hi")
with pytest.raises(NotImplementedError):
await BaseAgent.handle_message(agent, msg)
@pytest.mark.asyncio
async def test_base_agent_setup_abstract_method_body_executes():
"""
Covers the 'pass' inside BaseAgent.setup().
Since BaseAgent is abstract, we do NOT instantiate it.
We call the coroutine function directly on BaseAgent and pass a dummy self.
"""
class Dummy:
"""Minimal stub to act as 'self'."""
pass
stub = Dummy()
# Call BaseAgent.setup() as an unbound coroutine, passing stub as 'self'
result = await BaseAgent.setup(stub)
# The method contains only 'pass', so it returns None
assert result is None

View File

@@ -86,3 +86,34 @@ def test_setup_logging_zmq_handler(mock_zmq_context):
args = mock_dict_config.call_args[0][0] args = mock_dict_config.call_args[0][0]
assert "interface_or_socket" in args["handlers"]["ui"] assert "interface_or_socket" in args["handlers"]["ui"]
def test_add_logging_level_method_name_exists_in_logging():
# method_name explicitly set to an existing logging method → triggers first hasattr branch
with pytest.raises(AttributeError) as exc:
add_logging_level("NEWDUPLEVEL", 37, method_name="info")
assert "info already defined in logging module" in str(exc.value)
def test_add_logging_level_method_name_exists_in_logger_class():
# 'makeRecord' exists on Logger class but not on the logging module
with pytest.raises(AttributeError) as exc:
add_logging_level("ANOTHERLEVEL", 38, method_name="makeRecord")
assert "makeRecord already defined in logger class" in str(exc.value)
def test_add_logging_level_log_to_root_path_executes_without_error():
# Verify log_to_root is installed and callable — without asserting logging output
level_name = "ROOTTEST"
level_num = 36
add_logging_level(level_name, level_num)
# Simply call the injected root logger method
# The line is executed even if we don't validate output
root_logging_method = getattr(logging, level_name.lower(), None)
assert callable(root_logging_method)
# Execute the method to hit log_to_root in coverage.
# No need to verify log output.
root_logging_method("some message")

View File

@@ -0,0 +1,12 @@
from control_backend.schemas.message import Message
def base_message() -> Message:
return Message(message="Example")
def test_valid_message():
mess = base_message()
validated = Message.model_validate(mess)
assert isinstance(validated, Message)
assert validated.message == "Example"

73
test/unit/test_main.py Normal file
View File

@@ -0,0 +1,73 @@
import asyncio
import sys
from unittest.mock import AsyncMock, patch
import pytest
from fastapi.testclient import TestClient
from control_backend.api.v1.router import api_router
from control_backend.main import app, lifespan
# Fix event loop on Windows
if sys.platform == "win32":
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
@pytest.fixture
def client():
# Patch setup_logging so it does nothing
with patch("control_backend.main.setup_logging"):
with TestClient(app) as c:
yield c
def test_root_fast():
# Patch heavy startup code so it doesnt slow down
with patch("control_backend.main.setup_logging"), patch("control_backend.main.lifespan"):
client = TestClient(app)
resp = client.get("/")
assert resp.status_code == 200
assert resp.json() == {"status": "ok"}
def test_cors_middleware_added():
"""Test that CORSMiddleware is correctly added to the app."""
from starlette.middleware.cors import CORSMiddleware
middleware_classes = [m.cls for m in app.user_middleware]
assert CORSMiddleware in middleware_classes
def test_api_router_included():
"""Test that the API router is included in the FastAPI app."""
route_paths = [r.path for r in app.routes]
for route in api_router.routes:
assert route.path in route_paths
@pytest.mark.asyncio
async def test_lifespan_agent_start_exception():
"""
Trigger an exception during agent startup to cover the error logging branch.
Ensures exceptions are logged properly and re-raised.
"""
with (
patch(
"control_backend.main.RICommunicationAgent.start", new_callable=AsyncMock
) as ri_start,
patch("control_backend.main.setup_logging"),
patch("control_backend.main.threading.Thread"),
):
# Force RICommunicationAgent.start to raise an exception
ri_start.side_effect = Exception("Test exception")
with patch("control_backend.main.logger") as mock_logger:
with pytest.raises(Exception, match="Test exception"):
async with lifespan(app):
pass
# Verify the error was logged correctly
assert mock_logger.error.called
args, _ = mock_logger.error.call_args
assert isinstance(args[2], Exception)