Compare commits

..

10 Commits

Author | SHA1 | Message | Date

173326d4ad build: add Dockerfile
ref: N25B-280
2025-11-14 14:06:39 +01:00

9c538d927f build: integrate Docker functionality
Add environment variables throughout the code base to support Docker compose integration.
ref: N25B-280
2025-11-14 13:59:13 +01:00

Twirre Meulenbelt  1518b14867 fix: set empty default norms and goals
ref: N25B-200
2025-11-12 16:33:53 +01:00

Twirre Meulenbelt  858a554c78 feat: update endpoint to support the new UI request type
The UI now sends a Program as defined in our schemas.
ref: N25B-200
2025-11-12 16:30:42 +01:00

Twirre Meulenbelt  5376b3bb4c feat: create the program manager agent to interpret programs from the UI
Extracts norms and goals and sends these to the BDI agent.
ref: N25B-200
2025-11-12 16:28:52 +01:00

8cd8988fe0 feat: (hopefully) optional norms and goals
ref: N25B-200
2025-11-12 14:00:50 +01:00

Twirre Meulenbelt  919604493e Merge remote-tracking branch 'origin/feat/recieve-programs-ui' into demo
2025-11-12 13:39:53 +01:00

273f621b1b feat: norms and goals in BDI
ref: N25B-200
2025-11-12 13:35:15 +01:00

Twirre Meulenbelt  e39139cac9 fix: VAD agent requires reset
Otherwise, it won't start up correctly.
ref: N25B-266
2025-11-12 12:06:17 +01:00

Twirre Meulenbelt  b785493b97 fix: messages are None when no message is received
ref: N25B-265
2025-11-12 11:47:59 +01:00
93 changed files with 3911 additions and 5793 deletions

14
.dockerignore Normal file
View File

@@ -0,0 +1,14 @@
.git
.venv
__pycache__/
*.pyc
.dockerignore
Dockerfile
README.md
.gitlab-ci.yml
.gitignore
.pre-commit-config.yaml
.githooks/
test/
.pytest_cache/
.ruff_cache/

4
.gitignore vendored
View File

@@ -218,9 +218,7 @@ __marimo__/
# MacOS
.DS_Store
# Docs
docs/*
!docs/conf.py

View File

@@ -22,4 +22,6 @@ test:
tags:
- test
script:
- uv run --only-group test pytest test
# - uv run --group integration-test pytest test/integration
- uv run --only-group test pytest test/unit

View File

@@ -8,7 +8,7 @@ formatters:
# Console output
colored:
(): "colorlog.ColoredFormatter"
format: "{log_color}{asctime}.{msecs:03.0f} | {levelname:11} | {name:70} | {message}"
format: "{log_color}{asctime} | {levelname:11} | {name:70} | {message}"
style: "{"
datefmt: "%H:%M:%S"

21
Dockerfile Normal file
View File

@@ -0,0 +1,21 @@
# Debian based image
FROM ghcr.io/astral-sh/uv:0.9.8-trixie-slim
WORKDIR /app
ENV VIRTUAL_ENV=/app/.venv
ENV PATH="$VIRTUAL_ENV/bin:$PATH"
RUN apt-get update && apt-get install -y gcc=4:14.2.0-1 portaudio19-dev && apt-get install -y ca-certificates && rm -rf /var/lib/apt/lists/*
COPY pyproject.toml uv.lock .python-version ./
RUN uv sync
COPY . .
EXPOSE 8000
ENV PYTHONPATH=src
CMD [ "fastapi", "run", "src/control_backend/main.py" ]
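
The companion commit ("build: integrate Docker functionality") mentions wiring environment variables through the code base for Docker compose. A minimal sketch of how such overrides might be picked up with pydantic-settings, which is already a project dependency; the env prefix, field defaults, and variable names below are illustrative assumptions, not the project's actual configuration:

```python
# Hedged sketch: pydantic-settings reads these fields from the environment, so a
# docker-compose service could override them, e.g.
# CB_ZMQ_INTERNAL_SUB_ADDRESS=tcp://broker:5556.
# The prefix and defaults are assumptions for illustration only.
from pydantic_settings import BaseSettings, SettingsConfigDict


class ZMQSettings(BaseSettings):
    model_config = SettingsConfigDict(env_prefix="CB_ZMQ_")

    internal_sub_address: str = "tcp://localhost:5556"
    ri_command_address: str = "tcp://localhost:5557"


if __name__ == "__main__":
    print(ZMQSettings().internal_sub_address)
```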

View File

@@ -63,30 +63,3 @@ git config --local --unset core.hooksPath
```
Then run the pre-commit install commands again.
## Documentation
Generate documentation web pages using:
### Linux & macOS
```bash
PYTHONPATH=src sphinx-apidoc -F -o docs src/control_backend
```
### Windows
```bash
$env:PYTHONPATH="src"; sphinx-apidoc -F -o docs src/control_backend
```
Optionally, in the `conf.py` file in the `docs` folder, change preferences.
In the `docs` folder:
### Linux & macOS
```bash
make html
```
### Windows
```bash
.\make.bat html
```

View File

@@ -1,40 +0,0 @@
# Configuration file for the Sphinx documentation builder.
#
# For the full list of built-in configuration values, see the documentation:
# https://www.sphinx-doc.org/en/master/usage/configuration.html
import os
import sys
sys.path.insert(0, os.path.abspath("../src"))
# -- Project information -----------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information
project = "control_backend"
copyright = "2025, Author"
author = "Author"
# -- General configuration ---------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
extensions = [
"sphinx.ext.autodoc",
"sphinx.ext.viewcode",
"sphinx.ext.todo",
]
templates_path = ["_templates"]
exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]
language = "en"
# -- Options for HTML output -------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output
html_theme = "sphinx_rtd_theme"
html_static_path = ["_static"]
# -- Options for todo extension ----------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/extensions/todo.html#configuration
todo_include_todos = True

View File

@@ -5,51 +5,43 @@ description = "Add your description here"
readme = "README.md"
requires-python = ">=3.13"
dependencies = [
"agentspeak>=0.2.2",
"colorlog>=6.10.1",
"fastapi[all]>=0.115.6",
"mlx-whisper>=0.4.3 ; sys_platform == 'darwin'",
"numpy>=2.3.3",
"openai-whisper>=20250625",
"pyaudio>=0.2.14",
"pydantic>=2.12.0",
"pydantic-settings>=2.11.0",
"python-json-logger>=4.0.0",
"pyyaml>=6.0.3",
"pyzmq>=27.1.0",
"silero-vad>=6.0.0",
"sphinx>=7.3.7",
"sphinx-rtd-theme>=3.0.2",
"torch>=2.8.0",
"uvicorn>=0.37.0",
"colorlog>=6.10.1",
"fastapi[all]>=0.115.6",
"mlx-whisper>=0.4.3 ; sys_platform == 'darwin'",
"numpy>=2.3.3",
"openai-whisper>=20250625",
"pyaudio>=0.2.14",
"pydantic>=2.12.0",
"pydantic-settings>=2.11.0",
"pytest>=8.4.2",
"pytest-asyncio>=1.2.0",
"pytest-cov>=7.0.0",
"pytest-mock>=3.15.1",
"python-json-logger>=4.0.0",
"pyyaml>=6.0.3",
"pyzmq>=27.1.0",
"silero-vad>=6.0.0",
"spade>=4.1.0",
"spade-bdi>=0.3.2",
"torch>=2.8.0",
"uvicorn>=0.37.0",
]
[dependency-groups]
dev = [
"pre-commit>=4.3.0",
"pytest>=8.4.2",
"pytest-asyncio>=1.2.0",
"pytest-cov>=7.0.0",
"pytest-mock>=3.15.1",
"soundfile>=0.13.1",
"ruff>=0.14.2",
"ruff-format>=0.3.0",
"pre-commit>=4.3.0",
"ruff>=0.14.2",
"ruff-format>=0.3.0",
]
integration-test = [
"soundfile>=0.13.1",
]
test = [
"agentspeak>=0.2.2",
"fastapi>=0.115.6",
"httpx>=0.28.1",
"mlx-whisper>=0.4.3 ; sys_platform == 'darwin'",
"openai-whisper>=20250625",
"pydantic>=2.12.0",
"pydantic-settings>=2.11.0",
"pytest>=8.4.2",
"pytest-asyncio>=1.2.0",
"pytest-cov>=7.0.0",
"pytest-mock>=3.15.1",
"pyyaml>=6.0.3",
"pyzmq>=27.1.0",
"soundfile>=0.13.1",
"numpy>=2.3.3",
"pytest>=8.4.2",
"pytest-asyncio>=1.2.0",
"pytest-cov>=7.0.0",
"pytest-mock>=3.15.1",
]
[tool.pytest.ini_options]
@@ -60,15 +52,15 @@ line-length = 100
[tool.ruff.lint]
extend-select = [
"E", # pycodestyle
"F", # pyflakes
"I", # isort (import sorting)
"UP", # pyupgrade (modernize code)
"B", # flake8-bugbear (common bugs)
"C4", # flake8-comprehensions (unnecessary comprehensions)
"E", # pycodestyle
"F", # pyflakes
"I", # isort (import sorting)
"UP", # pyupgrade (modernize code)
"B", # flake8-bugbear (common bugs)
"C4", # flake8-comprehensions (unnecessary comprehensions)
]
ignore = [
"E226", # spaces around operators
"E701", # multiple statements on a single line
"E226", # spaces around operators
"E701", # multiple statements on a single line
]

View File

@@ -1 +1,7 @@
from .base import BaseAgent as BaseAgent
from .belief_collector.belief_collector import BeliefCollectorAgent as BeliefCollectorAgent
from .llm.llm import LLMAgent as LLMAgent
from .ri_command_agent import RICommandAgent as RICommandAgent
from .ri_communication_agent import RICommunicationAgent as RICommunicationAgent
from .transcription.transcription_agent import TranscriptionAgent as TranscriptionAgent
from .vad_agent import VADAgent as VADAgent

View File

@@ -1,2 +0,0 @@
from .robot_gesture_agent import RobotGestureAgent as RobotGestureAgent
from .robot_speech_agent import RobotSpeechAgent as RobotSpeechAgent

View File

@@ -1,295 +0,0 @@
import json
import zmq
import zmq.asyncio as azmq
from control_backend.agents import BaseAgent
from control_backend.core.agent_system import InternalMessage
from control_backend.core.config import settings
from control_backend.schemas.ri_message import GestureCommand, RIEndpoint
class RobotGestureAgent(BaseAgent):
"""
This agent acts as a bridge between the control backend and the Robot Interface (RI).
It receives gesture commands from other agents or from the UI,
and forwards them to the robot via a ZMQ PUB socket.
:ivar subsocket: ZMQ SUB socket for receiving external commands (e.g., from UI).
:ivar pubsocket: ZMQ PUB socket for sending commands to the Robot Interface.
:ivar address: Address to bind/connect the PUB socket.
:ivar bind: Whether to bind or connect the PUB socket.
:ivar gesture_data: A list of strings for available gestures
"""
subsocket: azmq.Socket
pubsocket: azmq.Socket
address = ""
bind = False
gesture_data = []
def __init__(
self,
name: str,
address=settings.zmq_settings.ri_command_address,
bind=False,
gesture_data=None,
):
if gesture_data is None:
self.gesture_data = []
else:
self.gesture_data = gesture_data
super().__init__(name)
self.address = address
self.bind = bind
async def setup(self):
"""
Initialize the agent.
1. Sets up the PUB socket to talk to the robot.
2. Sets up the SUB socket to listen for "command" topics (from UI/External).
3. Starts the loop for handling ZMQ commands.
"""
self.logger.info("Setting up %s", self.name)
context = azmq.Context.instance()
# To the robot
self.pubsocket = context.socket(zmq.PUB)
if self.bind: # TODO: Should this ever be the case?
self.pubsocket.bind(self.address)
else:
self.pubsocket.connect(self.address)
# Receive internal topics regarding commands
self.subsocket = context.socket(zmq.SUB)
self.subsocket.connect(settings.zmq_settings.internal_sub_address)
self.subsocket.setsockopt(zmq.SUBSCRIBE, b"command")
self.subsocket.setsockopt(zmq.SUBSCRIBE, b"send_gestures")
self.add_behavior(self._zmq_command_loop())
self.add_behavior(self._fetch_gestures_loop())
self.logger.info("Finished setting up %s", self.name)
async def stop(self):
if self.subsocket:
self.subsocket.close()
if self.pubsocket:
self.pubsocket.close()
await super().stop()
async def handle_message(self, msg: InternalMessage):
"""
Handle commands received from other internal Python agents.
Validates the message as a :class:`GestureCommand` and forwards it to the robot.
:param msg: The internal message containing the command.
"""
try:
gesture_command = GestureCommand.model_validate_json(msg.body)
if gesture_command.endpoint == RIEndpoint.GESTURE_TAG:
if gesture_command.data not in self.availableTags():
self.logger.warning(
"Received gesture tag '%s' which is not in available tags. Early returning",
gesture_command.data,
)
return
await self.pubsocket.send_json(gesture_command.model_dump())
except Exception:
self.logger.exception("Error processing internal message.")
async def _zmq_command_loop(self):
"""
Loop to handle commands received via ZMQ (e.g., from the UI).
Listens on the 'command' topic, validates the JSON and forwards it to the robot.
"""
while self._running:
try:
topic, body = await self.subsocket.recv_multipart()
# Don't process send_gestures here
if topic != b"command":
continue
body = json.loads(body)
gesture_command = GestureCommand.model_validate(body)
if gesture_command.endpoint == RIEndpoint.GESTURE_TAG:
if gesture_command.data not in self.availableTags():
self.logger.warning(
"Received gesture tag '%s' which is not in available tags.\
Early returning",
gesture_command.data,
)
continue
await self.pubsocket.send_json(gesture_command.model_dump())
except Exception:
self.logger.exception("Error processing ZMQ message.")
async def _fetch_gestures_loop(self):
"""
Loop to handle fetching gestures received via ZMQ (e.g., from the UI).
Listens on the 'send_gestures' topic, and returns a list on the get_gestures topic.
"""
while self._running:
try:
topic, body = await self.subsocket.recv_multipart()
# Don't process commands here
if topic != b"send_gestures":
continue
try:
body = json.loads(body)
except json.JSONDecodeError:
body = None
# We could have the body be the number of gestures you want to fetch or something.
amount = None
if isinstance(body, int):
amount = body
tags = self.availableTags()[:amount] if amount else self.availableTags()
response = json.dumps({"tags": tags}).encode()
await self.pubsocket.send_multipart(
[
b"get_gestures",
response,
]
)
except Exception:
self.logger.exception("Error fetching gesture tags.")
def availableTags(self):
"""
Returns the available gesture tags.
:return: List of available gesture tags.
"""
return [
"above",
"affirmative",
"afford",
"agitated",
"all",
"allright",
"alright",
"any",
"assuage",
"assuage",
"attemper",
"back",
"bashful",
"beg",
"beseech",
"blank",
"body language",
"bored",
"bow",
"but",
"call",
"calm",
"choose",
"choice",
"cloud",
"cogitate",
"cool",
"crazy",
"disappointed",
"down",
"earth",
"empty",
"embarrassed",
"enthusiastic",
"entire",
"estimate",
"except",
"exalted",
"excited",
"explain",
"far",
"field",
"floor",
"forlorn",
"friendly",
"front",
"frustrated",
"gentle",
"gift",
"give",
"ground",
"happy",
"hello",
"her",
"here",
"hey",
"hi",
"him",
"hopeless",
"hysterical",
"I",
"implore",
"indicate",
"joyful",
"me",
"meditate",
"modest",
"negative",
"nervous",
"no",
"not know",
"nothing",
"offer",
"ok",
"once upon a time",
"oppose",
"or",
"pacify",
"pick",
"placate",
"please",
"present",
"proffer",
"quiet",
"reason",
"refute",
"reject",
"rousing",
"sad",
"select",
"shamefaced",
"show",
"show sky",
"sky",
"soothe",
"sun",
"supplicate",
"tablet",
"tall",
"them",
"there",
"think",
"timid",
"top",
"unless",
"up",
"upstairs",
"void",
"warm",
"winner",
"yeah",
"yes",
"yoo-hoo",
"you",
"your",
"zero",
"zestful",
]

View File

@@ -1,103 +0,0 @@
import json
import zmq
import zmq.asyncio as azmq
from control_backend.agents import BaseAgent
from control_backend.core.agent_system import InternalMessage
from control_backend.core.config import settings
from control_backend.schemas.ri_message import SpeechCommand
class RobotSpeechAgent(BaseAgent):
"""
This agent acts as a bridge between the control backend and the Robot Interface (RI).
It receives speech commands from other agents or from the UI,
and forwards them to the robot via a ZMQ PUB socket.
:ivar subsocket: ZMQ SUB socket for receiving external commands (e.g., from UI).
:ivar pubsocket: ZMQ PUB socket for sending commands to the Robot Interface.
:ivar address: Address to bind/connect the PUB socket.
:ivar bind: Whether to bind or connect the PUB socket.
"""
subsocket: azmq.Socket
pubsocket: azmq.Socket
address = ""
bind = False
def __init__(
self,
name: str,
address=settings.zmq_settings.ri_command_address,
bind=False,
):
super().__init__(name)
self.address = address
self.bind = bind
async def setup(self):
"""
Initialize the agent.
1. Sets up the PUB socket to talk to the robot.
2. Sets up the SUB socket to listen for "command" topics (from UI/External).
3. Starts the loop for handling ZMQ commands.
"""
self.logger.info("Setting up %s", self.name)
context = azmq.Context.instance()
# To the robot
self.pubsocket = context.socket(zmq.PUB)
if self.bind: # TODO: Should this ever be the case?
self.pubsocket.bind(self.address)
else:
self.pubsocket.connect(self.address)
# Receive internal topics regarding commands
self.subsocket = context.socket(zmq.SUB)
self.subsocket.connect(settings.zmq_settings.internal_sub_address)
self.subsocket.setsockopt(zmq.SUBSCRIBE, b"command")
self.add_behavior(self._zmq_command_loop())
self.logger.info("Finished setting up %s", self.name)
async def stop(self):
if self.subsocket:
self.subsocket.close()
if self.pubsocket:
self.pubsocket.close()
await super().stop()
async def handle_message(self, msg: InternalMessage):
"""
Handle commands received from other internal Python agents.
Validates the message as a :class:`SpeechCommand` and forwards it to the robot.
:param msg: The internal message containing the command.
"""
try:
speech_command = SpeechCommand.model_validate_json(msg.body)
await self.pubsocket.send_json(speech_command.model_dump())
except Exception:
self.logger.exception("Error processing internal message.")
async def _zmq_command_loop(self):
"""
Loop to handle commands received via ZMQ (e.g., from the UI).
Listens on the 'command' topic, validates the JSON, and forwards it to the robot.
"""
while self._running:
try:
_, body = await self.subsocket.recv_multipart()
body = json.loads(body)
message = SpeechCommand.model_validate(body)
await self.pubsocket.send_json(message.model_dump())
except Exception:
self.logger.exception("Error processing ZMQ message.")

View File

@@ -1,26 +1,18 @@
import logging
from control_backend.core.agent_system import BaseAgent as CoreBaseAgent
from spade.agent import Agent
class BaseAgent(CoreBaseAgent):
class BaseAgent(Agent):
"""
The primary base class for all implementation agents.
Inherits from :class:`control_backend.core.agent_system.BaseAgent`.
This class ensures that every agent instance is automatically equipped with a
properly configured ``logger``.
:ivar logger: A logger instance named after the agent's package and class.
Base agent class for our agents to inherit from.
This ensures that all agents have a logger.
"""
logger: logging.Logger
# Whenever a subclass is initiated, give it the correct logger
def __init_subclass__(cls, **kwargs) -> None:
"""
Whenever a subclass is initiated, give it the correct logger.
:param kwargs: Keyword arguments for the subclass.
"""
super().__init_subclass__(**kwargs)
cls.logger = logging.getLogger(__package__).getChild(cls.__name__)
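
A small sketch of what the hookup above gives a subclass (the agent class below is hypothetical, for illustration only):

```python
# Hedged sketch: any subclass of BaseAgent gets a logger named "<package>.<ClassName>"
# via __init_subclass__, so it can log without any extra setup.
from control_backend.agents.base import BaseAgent


class ExampleAgent(BaseAgent):  # hypothetical agent, not part of the repository
    async def setup(self):
        self.logger.info("ExampleAgent setup complete")
```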

View File

@@ -1,8 +1,2 @@
from control_backend.agents.bdi.bdi_core_agent import BDICoreAgent as BDICoreAgent
from .belief_collector_agent import (
BDIBeliefCollectorAgent as BDIBeliefCollectorAgent,
)
from .text_belief_extractor_agent import (
TextBeliefExtractorAgent as TextBeliefExtractorAgent,
)
from .bdi_core import BDICoreAgent as BDICoreAgent
from .text_extractor import TBeliefExtractorAgent as TBeliefExtractorAgent

View File

@@ -0,0 +1,106 @@
import json
import logging
import agentspeak
from spade.behaviour import OneShotBehaviour
from spade.message import Message
from spade_bdi.bdi import BDIAgent
from control_backend.core.config import settings
from .behaviours.belief_setter import BeliefSetterBehaviour
from .behaviours.receive_llm_resp_behaviour import ReceiveLLMResponseBehaviour
class BDICoreAgent(BDIAgent):
"""
This is the Brain agent that performs belief inference with AgentSpeak.
This is a continuous process that happens automatically in the background.
This class contains all the actions that can be called from AgentSpeak plans.
It has the BeliefSetter behaviour and can send requests to and receive responses from the LLM agent.
"""
logger = logging.getLogger(__package__).getChild(__name__)
async def setup(self) -> None:
"""
Initializes belief behaviors and message routing.
"""
self.logger.info("BDICoreAgent setup started.")
self.add_behaviour(BeliefSetterBehaviour())
self.add_behaviour(ReceiveLLMResponseBehaviour())
self.logger.info("BDICoreAgent setup complete.")
def add_custom_actions(self, actions) -> None:
"""
Registers custom AgentSpeak actions callable from plans.
"""
@actions.add(".reply", 3)
def _reply(agent: "BDICoreAgent", term, intention):
"""
Sends text to the LLM (AgentSpeak action).
Example: .reply("Hello LLM!")
"""
message_text = agentspeak.grounded(term.args[0], intention.scope)
norms = agentspeak.grounded(term.args[1], intention.scope)
goals = agentspeak.grounded(term.args[2], intention.scope)
self.logger.debug("Norms: %s", norms)
self.logger.debug("Goals: %s", goals)
self.logger.debug("User text: %s", message_text)
self._send_to_llm(str(message_text), str(norms), str(goals))
yield
@actions.add(".reply_no_norms", 2)
def _reply_no_norms(agent: "BDICoreAgent", term, intention):
message_text = agentspeak.grounded(term.args[0], intention.scope)
goals = agentspeak.grounded(term.args[1], intention.scope)
self.logger.debug("Goals: %s", goals)
self.logger.debug("User text: %s", message_text)
self._send_to_llm(str(message_text), goals=str(goals))
@actions.add(".reply_no_goals", 2)
def _reply_no_goals(agent: "BDICoreAgent", term, intention):
message_text = agentspeak.grounded(term.args[0], intention.scope)
norms = agentspeak.grounded(term.args[1], intention.scope)
self.logger.debug("Norms: %s", norms)
self.logger.debug("User text: %s", message_text)
self._send_to_llm(str(message_text), norms=str(norms))
@actions.add(".reply_no_goals_no_norms", 1)
def _reply_no_goals_no_norms(agent: "BDICoreAgent", term, intention):
message_text = agentspeak.grounded(term.args[0], intention.scope)
self.logger.debug("User text: %s", message_text)
self._send_to_llm(message_text)
def _send_to_llm(self, text: str, norms: str = None, goals: str = None):
"""
Sends a text query to the LLM Agent asynchronously.
"""
class SendBehaviour(OneShotBehaviour):
async def run(self) -> None:
message_dict = {
"text": text,
"norms": norms if norms else "",
"goals": goals if goals else "",
}
msg = Message(
to=settings.agent_settings.llm_agent_name + "@" + settings.agent_settings.host,
body=json.dumps(message_dict),
)
await self.send(msg)
self.agent.logger.info("Message sent to LLM agent: %s", text)
self.add_behaviour(SendBehaviour())
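
For orientation, a hedged sketch of how a spade_bdi-based agent like this is typically constructed and started; the JID, password, and ASL path are placeholders, and a reachable XMPP server is assumed:

```python
# Sketch only, assuming the usual spade_bdi constructor (jid, password, asl_file).
import asyncio

from control_backend.agents.bdi import BDICoreAgent


async def main():
    # Placeholder credentials and ASL path; spade agents need a running XMPP server.
    agent = BDICoreAgent("bdi_core@localhost", "secret", "plans/demo.asl")
    await agent.start()      # runs setup(), attaching the behaviours shown above
    await asyncio.sleep(5)   # let the BDI reasoning cycle run briefly
    await agent.stop()


if __name__ == "__main__":
    asyncio.run(main())
```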

View File

@@ -1,277 +0,0 @@
import asyncio
import copy
import time
from collections.abc import Iterable
import agentspeak
import agentspeak.runtime
import agentspeak.stdlib
from pydantic import ValidationError
from control_backend.agents.base import BaseAgent
from control_backend.core.agent_system import InternalMessage
from control_backend.core.config import settings
from control_backend.schemas.belief_message import Belief, BeliefMessage
from control_backend.schemas.llm_prompt_message import LLMPromptMessage
from control_backend.schemas.ri_message import SpeechCommand
DELIMITER = ";\n" # TODO: temporary until we support lists in AgentSpeak
class BDICoreAgent(BaseAgent):
"""
BDI Core Agent.
This is the central reasoning agent of the system, powered by the **AgentSpeak(L)** language.
It maintains a belief base (representing the state of the world) and a set of plans (rules).
It runs an internal BDI (Belief-Desire-Intention) cycle using the ``agentspeak`` library.
When beliefs change (e.g., via :meth:`_apply_beliefs`), the agent evaluates its plans to
determine the best course of action.
**Custom Actions:**
It defines custom actions (like ``.reply``) that allow the AgentSpeak code to interact with
external Python agents (e.g., querying the LLM).
:ivar bdi_agent: The internal AgentSpeak agent instance.
:ivar asl_file: Path to the AgentSpeak source file (.asl).
:ivar env: The AgentSpeak environment.
:ivar actions: A registry of custom actions available to the AgentSpeak code.
:ivar _wake_bdi_loop: Event used to wake up the reasoning loop when new beliefs arrive.
"""
bdi_agent: agentspeak.runtime.Agent
def __init__(self, name: str, asl: str):
super().__init__(name)
self.asl_file = asl
self.env = agentspeak.runtime.Environment()
# Deep copy because we don't actually want to modify the standard actions globally
self.actions = copy.deepcopy(agentspeak.stdlib.actions)
self._wake_bdi_loop = asyncio.Event()
async def setup(self) -> None:
"""
Initialize the BDI agent.
1. Registers custom actions (like ``.reply``).
2. Loads the .asl source file.
3. Starts the reasoning loop (:meth:`_bdi_loop`) in the background.
"""
self.logger.debug("Setup started.")
self._add_custom_actions()
await self._load_asl()
# Start the BDI cycle loop
self.add_behavior(self._bdi_loop())
self._wake_bdi_loop.set()
self.logger.debug("Setup complete.")
async def _load_asl(self):
"""
Load and parse the AgentSpeak source file.
"""
try:
with open(self.asl_file) as source:
self.bdi_agent = self.env.build_agent(source, self.actions)
except FileNotFoundError:
self.logger.warning(f"Could not find the specified ASL file at {self.asl_file}.")
self.bdi_agent = agentspeak.runtime.Agent(self.env, self.name)
async def _bdi_loop(self):
"""
The main BDI reasoning loop.
It waits for the ``_wake_bdi_loop`` event (set when beliefs change or actions complete).
When awake, it steps through the AgentSpeak interpreter. It also handles sleeping if
the agent has deferred intentions (deadlines).
"""
while self._running:
await (
self._wake_bdi_loop.wait()
) # gets set whenever there's an update to the belief base
# Agent knows when it's expected to have to do its next thing
maybe_more_work = True
while maybe_more_work:
maybe_more_work = False
self.logger.debug("Stepping BDI.")
if self.bdi_agent.step():
maybe_more_work = True
if not maybe_more_work:
deadline = self.bdi_agent.shortest_deadline()
if deadline:
self.logger.debug("Sleeping until %s", deadline)
await asyncio.sleep(deadline - time.time())
maybe_more_work = True
else:
self._wake_bdi_loop.clear()
self.logger.debug("No more deadlines. Halting BDI loop.")
async def handle_message(self, msg: InternalMessage):
"""
Handle incoming messages.
- **Beliefs**: Updates the internal belief base.
- **LLM Responses**: Forwards the generated text to the Robot Speech Agent (actuation).
:param msg: The received internal message.
"""
self.logger.debug("Processing message from %s.", msg.sender)
if msg.thread == "beliefs":
try:
beliefs = BeliefMessage.model_validate_json(msg.body).beliefs
self._apply_beliefs(beliefs)
except ValidationError:
self.logger.exception("Error processing belief.")
return
# The message was not a belief, handle special cases based on sender
match msg.sender:
case settings.agent_settings.llm_name:
content = msg.body
self.logger.info("Received LLM response: %s", content)
# Forward to Robot Speech Agent
cmd = SpeechCommand(data=content)
out_msg = InternalMessage(
to=settings.agent_settings.robot_speech_name,
sender=self.name,
body=cmd.model_dump_json(),
)
await self.send(out_msg)
def _apply_beliefs(self, beliefs: list[Belief]):
"""
Update the belief base with a list of new beliefs.
If ``replace=True`` is set on a belief, it removes all existing beliefs with that name
before adding the new one.
"""
if not beliefs:
return
for belief in beliefs:
if belief.replace:
self._remove_all_with_name(belief.name)
self._add_belief(belief.name, belief.arguments)
def _add_belief(self, name: str, args: Iterable[str] = []):
"""
Add a single belief to the BDI agent.
:param name: The functor/name of the belief (e.g., "user_said").
:param args: Arguments for the belief.
"""
# new_args = (agentspeak.Literal(arg) for arg in args) # TODO: Eventually support multiple
merged_args = DELIMITER.join(arg for arg in args)
new_args = (agentspeak.Literal(merged_args),)
term = agentspeak.Literal(name, new_args)
self.bdi_agent.call(
agentspeak.Trigger.addition,
agentspeak.GoalType.belief,
term,
agentspeak.runtime.Intention(),
)
self._wake_bdi_loop.set()
self.logger.debug(f"Added belief {self.format_belief_string(name, args)}")
def _remove_belief(self, name: str, args: Iterable[str]):
"""
Removes a specific belief (with arguments), if it exists.
"""
new_args = (agentspeak.Literal(arg) for arg in args)
term = agentspeak.Literal(name, new_args)
result = self.bdi_agent.call(
agentspeak.Trigger.removal,
agentspeak.GoalType.belief,
term,
agentspeak.runtime.Intention(),
)
if result:
self.logger.debug(f"Removed belief {self.format_belief_string(name, args)}")
self._wake_bdi_loop.set()
else:
self.logger.debug("Failed to remove belief (it was not in the belief base).")
def _remove_all_with_name(self, name: str):
"""
Removes all beliefs that match the given `name`.
"""
relevant_groups = []
for key in self.bdi_agent.beliefs:
if key[0] == name:
relevant_groups.append(key)
removed_count = 0
for group in relevant_groups:
beliefs_to_remove = list(self.bdi_agent.beliefs[group])
for belief in beliefs_to_remove:
self.bdi_agent.call(
agentspeak.Trigger.removal,
agentspeak.GoalType.belief,
belief,
agentspeak.runtime.Intention(),
)
removed_count += 1
self._wake_bdi_loop.set()
self.logger.debug(f"Removed {removed_count} beliefs.")
def _add_custom_actions(self) -> None:
"""
Add any custom actions here. Inside `@self.actions.add()`, the first argument is
the name of the function in the ASL file, and the second the amount of arguments
the function expects (which will be located in `term.args`).
"""
@self.actions.add(".reply", 3)
def _reply(agent: "BDICoreAgent", term, intention):
"""
Sends text to the LLM (AgentSpeak action).
Example: .reply("Hello LLM!", "Some norm", "Some goal")
"""
message_text = agentspeak.grounded(term.args[0], intention.scope)
norms = agentspeak.grounded(term.args[1], intention.scope)
goals = agentspeak.grounded(term.args[2], intention.scope)
self.logger.debug("Norms: %s", norms)
self.logger.debug("Goals: %s", goals)
self.logger.debug("User text: %s", message_text)
asyncio.create_task(self._send_to_llm(str(message_text), str(norms), str(goals)))
yield
async def _send_to_llm(self, text: str, norms: str = None, goals: str = None):
"""
Sends a text query to the LLM agent asynchronously.
"""
prompt = LLMPromptMessage(
text=text,
norms=norms.split("\n") if norms else [],
goals=goals.split("\n") if norms else [],
)
msg = InternalMessage(
to=settings.agent_settings.llm_name,
sender=self.name,
body=prompt.model_dump_json(),
)
await self.send(msg)
self.logger.info("Message sent to LLM agent: %s", text)
@staticmethod
def format_belief_string(name: str, args: Iterable[str] = []):
"""
Given a belief's name and its args, return a string of the form "name(*args)"
"""
return f"{name}{'(' if args else ''}{','.join(args)}{')' if args else ''}"

View File

@@ -1,94 +0,0 @@
import zmq
from pydantic import ValidationError
from zmq.asyncio import Context
from control_backend.agents import BaseAgent
from control_backend.core.agent_system import InternalMessage
from control_backend.core.config import settings
from control_backend.schemas.belief_message import Belief, BeliefMessage
from control_backend.schemas.program import Program
class BDIProgramManager(BaseAgent):
"""
BDI Program Manager Agent.
This agent is responsible for receiving high-level programs (sequences of instructions/goals)
from the external HTTP API (via ZMQ) and translating them into core beliefs (norms and goals)
for the BDI Core Agent. In the future, it will be responsible for determining when goals are
met, and passing on new norms and goals accordingly.
:ivar sub_socket: The ZMQ SUB socket used to receive program updates.
"""
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.sub_socket = None
async def _send_to_bdi(self, program: Program):
"""
Convert a received program into BDI beliefs and send them to the BDI Core Agent.
Currently, it takes the **first phase** of the program and extracts:
- **Norms**: Constraints or rules the agent must follow.
- **Goals**: Objectives the agent must achieve.
These are sent as a ``BeliefMessage`` with ``replace=True``, meaning they will
overwrite any existing norms/goals of the same name in the BDI agent.
:param program: The program object received from the API.
"""
first_phase = program.phases[0]
norms_belief = Belief(
name="norms",
arguments=[norm.norm for norm in first_phase.norms],
replace=True,
)
goals_belief = Belief(
name="goals",
arguments=[goal.description for goal in first_phase.goals],
replace=True,
)
program_beliefs = BeliefMessage(beliefs=[norms_belief, goals_belief])
message = InternalMessage(
to=settings.agent_settings.bdi_core_name,
sender=self.name,
body=program_beliefs.model_dump_json(),
thread="beliefs",
)
await self.send(message)
self.logger.debug("Sent new norms and goals to the BDI agent.")
async def _receive_programs(self):
"""
Continuous loop that receives program updates from the HTTP endpoint.
It listens to the ``program`` topic on the internal ZMQ SUB socket.
When a program is received, it is validated and forwarded to BDI via :meth:`_send_to_bdi`.
"""
while True:
topic, body = await self.sub_socket.recv_multipart()
try:
program = Program.model_validate_json(body)
except ValidationError:
self.logger.exception("Received an invalid program.")
continue
await self._send_to_bdi(program)
async def setup(self):
"""
Initialize the agent.
Connects the internal ZMQ SUB socket and subscribes to the 'program' topic.
Starts the background behavior to receive programs.
"""
context = Context.instance()
self.sub_socket = context.socket(zmq.SUB)
self.sub_socket.connect(settings.zmq_settings.internal_sub_address)
self.sub_socket.subscribe("program")
self.add_behavior(self._receive_programs())

View File

@@ -0,0 +1,27 @@
import zmq
from zmq.asyncio import Context
from control_backend.agents import BaseAgent
from control_backend.core.config import settings
from .receive_programs_behavior import ReceiveProgramsBehavior
class BDIProgramManager(BaseAgent):
"""
Interprets programs received from the HTTP endpoint. Extracts norms, goals, and triggers, and
forwards them to the BDI as beliefs.
"""
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.sub_socket = None
async def setup(self):
context = Context.instance()
self.sub_socket = context.socket(zmq.SUB)
self.sub_socket.connect(settings.zmq_settings.internal_sub_address)
self.sub_socket.subscribe("program")
self.add_behaviour(ReceiveProgramsBehavior())

View File

@@ -0,0 +1,59 @@
import json
from pydantic import ValidationError
from spade.behaviour import CyclicBehaviour
from spade.message import Message
from control_backend.core.config import settings
from control_backend.schemas.program import Program
class ReceiveProgramsBehavior(CyclicBehaviour):
async def _receive(self) -> Program | None:
topic, body = await self.agent.sub_socket.recv_multipart()
try:
return Program.model_validate_json(body)
except ValidationError as e:
self.agent.logger.error("Received an invalid program.", exc_info=e)
return None
def _extract_norms(self, program: Program) -> str:
"""First phase only for now, as a single newline delimited string."""
if not program.phases:
return ""
if not program.phases[0].phaseData.norms:
return ""
norm_values = [norm.value for norm in program.phases[0].phaseData.norms]
return "\n".join(norm_values)
def _extract_goals(self, program: Program) -> str:
"""First phase only for now, as a single newline delimited string."""
if not program.phases:
return ""
if not program.phases[0].phaseData.goals:
return ""
goal_descriptions = [goal.description for goal in program.phases[0].phaseData.goals]
return "\n".join(goal_descriptions)
async def _send_to_bdi(self, program: Program):
temp_allowed_parts = {
"norms": [self._extract_norms(program)],
"goals": [self._extract_goals(program)],
}
message = Message(
to=settings.agent_settings.bdi_core_agent_name + "@" + settings.agent_settings.host,
sender=self.agent.jid,
body=json.dumps(temp_allowed_parts),
thread="beliefs",
)
await self.send(message)
self.agent.logger.debug("Sent new norms and goals to the BDI agent.")
async def run(self):
program = await self._receive()
if not program:
return
await self._send_to_bdi(program)
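
Judging from the attribute access above (phases[0].phaseData.norms[*].value and phases[0].phaseData.goals[*].description), a program payload presumably looks roughly like the following sketch; the authoritative shape is the Program model in control_backend.schemas.program, which may contain more fields:

```python
# Illustrative payload only, inferred from the extraction code above; not the real schema.
example_program = {
    "phases": [
        {
            "phaseData": {
                "norms": [{"value": "Do not interrupt the participant"}],
                "goals": [{"description": "Greet the participant by name"}],
            }
        }
    ],
}
```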

View File

@@ -0,0 +1,92 @@
import json
from spade.agent import Message
from spade.behaviour import CyclicBehaviour
from spade_bdi.bdi import BDIAgent
from control_backend.core.config import settings
class BeliefSetterBehaviour(CyclicBehaviour):
"""
This is the behaviour that the BDI agent runs. It waits for incoming
messages and updates the agent's beliefs accordingly.
"""
agent: BDIAgent
async def run(self):
"""Polls for messages and processes them."""
msg = await self.receive(timeout=1)
if not msg:
return
self.agent.logger.debug(
"Received message from %s with thread '%s' and body: %s",
msg.sender,
msg.thread,
msg.body,
)
self._process_message(msg)
def _process_message(self, message: Message):
"""Routes the message to the correct processing function based on the sender."""
sender = message.sender.node # removes host from jid and converts to str
self.agent.logger.debug("Processing message from sender: %s", sender)
match sender:
case settings.agent_settings.belief_collector_agent_name:
self.agent.logger.debug(
"Message is from the belief collector agent. Processing as belief message."
)
self._process_belief_message(message)
case settings.agent_settings.program_manager_agent_name:
self.agent.logger.debug(
"Processing message from the program manager. Processing as belief message."
)
self._process_belief_message(message)
case _:
self.agent.logger.debug("Not from expected agents, discarding message")
pass
def _process_belief_message(self, message: Message):
if not message.body:
self.agent.logger.debug("Ignoring message with empty body from %s", message.sender.node)
return
match message.thread:
case "beliefs":
try:
beliefs: dict[str, list[str]] = json.loads(message.body)
self._set_beliefs(beliefs)
except json.JSONDecodeError:
self.agent.logger.error(
"Could not decode beliefs from JSON. Message body: '%s'",
message.body,
exc_info=True,
)
case _:
pass
def _set_beliefs(self, beliefs: dict[str, list[str]]):
"""Removes previous values for beliefs and updates them with the provided values."""
if self.agent.bdi is None:
self.agent.logger.warning("Cannot set beliefs; agent's BDI is not yet initialized.")
return
if not beliefs:
self.agent.logger.debug("Received an empty set of beliefs. No beliefs were updated.")
return
# Set new beliefs (outdated beliefs are automatically removed)
for belief, arguments in beliefs.items():
self.agent.logger.debug("Setting belief %s with arguments %s", belief, arguments)
self.agent.bdi.set_belief(belief, *arguments)
# Special case: if there's a new user message, flag that we haven't responded yet
if belief == "user_said":
self.agent.bdi.set_belief("new_message")
self.agent.logger.debug(
"Detected 'user_said' belief, also setting 'new_message' belief."
)
self.agent.logger.info("Successfully updated %d beliefs.", len(beliefs))

View File

@@ -0,0 +1,39 @@
from spade.behaviour import CyclicBehaviour
from spade.message import Message
from control_backend.core.config import settings
from control_backend.schemas.ri_message import SpeechCommand
class ReceiveLLMResponseBehaviour(CyclicBehaviour):
"""
Behaviour that receives responses from the LLM agent and forwards them to the RI command agent as speech commands.
"""
async def run(self):
msg = await self.receive(timeout=1)
if not msg:
return
sender = msg.sender.node
match sender:
case settings.agent_settings.llm_agent_name:
content = msg.body
self.agent.logger.info("Received LLM response: %s", content)
speech_command = SpeechCommand(data=content)
message = Message(
to=settings.agent_settings.ri_command_agent_name
+ "@"
+ settings.agent_settings.host,
sender=self.agent.jid,
body=speech_command.model_dump_json(),
)
self.agent.logger.debug("Sending message: %s", message)
await self.send(message)
case _:
self.agent.logger.debug("Discarding message from %s", sender)
pass

View File

@@ -0,0 +1,104 @@
import json
import logging
from spade.behaviour import CyclicBehaviour
from spade.message import Message
from control_backend.core.config import settings
class BeliefFromText(CyclicBehaviour):
logger = logging.getLogger(__name__)
# TODO: LLM prompt still hardcoded
llm_instruction_prompt = """
You are an information extraction assistant for a BDI agent. Your task is to extract values \
from a user's text to bind a list of ungrounded beliefs. Rules:
You will receive a JSON object with "beliefs" (a list of ungrounded AgentSpeak beliefs) \
and "text" (user's transcript).
Analyze the text to find values that semantically match the variables (X,Y,Z) in the beliefs.
A single piece of text might contain multiple instances that match a belief.
Respond ONLY with a single JSON object.
The JSON object's keys should be the belief functors (e.g., "weather").
The value for each key must be a list of lists.
Each inner list must contain the extracted arguments (as strings) for one instance \
of that belief.
CRITICAL: If no information in the text matches a belief, DO NOT include that key \
in your response.
"""
# on_start agent receives message containing the beliefs to look out for and
# sets up the LLM with instruction prompt
# async def on_start(self):
# msg = await self.receive(timeout=0.1)
# self.beliefs = dict from the message
# send instruction prompt to LLM
beliefs: dict[str, list[str]]
beliefs = {"mood": ["X"], "car": ["Y"]}
async def run(self):
msg = await self.receive(timeout=1)
if not msg:
return
sender = msg.sender.node
match sender:
case settings.agent_settings.transcription_agent_name:
self.logger.debug("Received text from transcriber: %s", msg.body)
await self._process_transcription_demo(msg.body)
case _:
self.logger.info("Discarding message from %s", sender)
pass
async def _process_transcription(self, text: str):
text_prompt = f"Text: {text}"
beliefs_prompt = "These are the beliefs to be bound:\n"
for belief, values in self.beliefs.items():
beliefs_prompt += f"{belief}({', '.join(values)})\n"
prompt = text_prompt + beliefs_prompt
self.logger.info(prompt)
# prompt_msg = Message(to="LLMAgent@whatever")
# response = self.send(prompt_msg)
# Mock response; the response is beliefs in JSON format, which parses to dict[str, list[list[str]]]
response = '{"mood": [["happy"]]}'
# Verify by trying to parse
try:
json.loads(response)
belief_message = Message()
belief_message.to = (
settings.agent_settings.belief_collector_agent_name
+ "@"
+ settings.agent_settings.host
)
belief_message.body = response
belief_message.thread = "beliefs"
await self.send(belief_message)
self.agent.logger.info("Sent beliefs to BDI.")
except json.JSONDecodeError:
# Parsing failed, so the response is in the wrong format, log warning
self.agent.logger.warning("Received LLM response in incorrect format.")
async def _process_transcription_demo(self, txt: str):
"""
Demo version to process the transcription input to beliefs. For the demo only the belief
'user_said' is relevant, so this function simply makes a dict with key: "user_said",
value: txt and passes this to the Belief Collector agent.
"""
belief = {"beliefs": {"user_said": [txt]}, "type": "belief_extraction_text"}
payload = json.dumps(belief)
belief_msg = Message()
belief_msg.to = (
settings.agent_settings.belief_collector_agent_name + "@" + settings.agent_settings.host
)
belief_msg.body = payload
belief_msg.thread = "beliefs"
await self.send(belief_msg)
self.logger.info("Sent %d beliefs to the belief collector.", len(belief["beliefs"]))

View File

@@ -1,152 +0,0 @@
import json
from pydantic import ValidationError
from control_backend.agents.base import BaseAgent
from control_backend.core.agent_system import InternalMessage
from control_backend.core.config import settings
from control_backend.schemas.belief_message import Belief, BeliefMessage
class BDIBeliefCollectorAgent(BaseAgent):
"""
BDI Belief Collector Agent.
This agent acts as a central aggregator for beliefs derived from various sources (e.g., text,
emotion, vision). It receives raw extracted data from other agents,
normalizes them into valid :class:`Belief` objects, and forwards them as a unified packet to the
BDI Core Agent.
It serves as a funnel to ensure the BDI agent receives a consistent stream of beliefs.
"""
async def setup(self):
"""
Initialize the agent.
"""
self.logger.info("Setting up %s", self.name)
async def handle_message(self, msg: InternalMessage):
"""
Handle incoming messages from other extractor agents.
Routes the message to specific handlers based on the 'type' field in the JSON body.
Supported types:
- ``belief_extraction_text``: Handled by :meth:`_handle_belief_text`
- ``emotion_extraction_text``: Handled by :meth:`_handle_emo_text`
:param msg: The received internal message.
"""
sender_node = msg.sender
# Parse JSON payload
try:
payload = json.loads(msg.body)
except Exception as e:
self.logger.warning(
"BeliefCollector: failed to parse JSON from %s. Body=%r Error=%s",
sender_node,
msg.body,
e,
)
return
msg_type = payload.get("type")
# Prefer explicit 'type' field
if msg_type == "belief_extraction_text":
self.logger.debug("Message routed to _handle_belief_text (sender=%s)", sender_node)
await self._handle_belief_text(payload, sender_node)
# This is not implemented yet, but we keep the structure for future use
elif msg_type == "emotion_extraction_text":
self.logger.debug("Message routed to _handle_emo_text (sender=%s)", sender_node)
await self._handle_emo_text(payload, sender_node)
else:
self.logger.warning(
"Unrecognized message (sender=%s, type=%r). Ignoring.", sender_node, msg_type
)
async def _handle_belief_text(self, payload: dict, origin: str):
"""
Process text-based belief extraction payloads.
Expected payload format::
{
"type": "belief_extraction_text",
"beliefs": {
"user_said": ["Can you help me?"],
"intention": ["ask_help"]
}
}
Validates and converts the dictionary items into :class:`Belief` objects.
:param payload: The dictionary payload containing belief data.
:param origin: The name of the sender agent.
"""
beliefs = payload.get("beliefs", {})
if not beliefs:
self.logger.debug("Received empty beliefs set.")
return
def try_create_belief(name, arguments) -> Belief | None:
"""
Create a belief object from name and arguments, or return None silently if the input is
not correct.
:param name: The name of the belief.
:param arguments: The arguments of the belief.
:return: A Belief object if the input is valid or None.
"""
try:
return Belief(name=name, arguments=arguments)
except ValidationError:
return None
beliefs = [
belief
for name, arguments in beliefs.items()
if (belief := try_create_belief(name, arguments)) is not None
]
self.logger.debug("Forwarding %d beliefs.", len(beliefs))
for belief in beliefs:
for argument in belief.arguments:
self.logger.debug(" - %s %s", belief.name, argument)
await self._send_beliefs_to_bdi(beliefs, origin=origin)
async def _handle_emo_text(self, payload: dict, origin: str):
"""
Process emotion extraction payloads.
**TODO**: Implement this method once emotion recognition is integrated.
:param payload: The dictionary payload containing emotion data.
:param origin: The name of the sender agent.
"""
pass
async def _send_beliefs_to_bdi(self, beliefs: list[Belief], origin: str | None = None):
"""
Send a list of aggregated beliefs to the BDI Core Agent.
Wraps the beliefs in a :class:`BeliefMessage` and sends it via the 'beliefs' thread.
:param beliefs: The list of Belief objects to send.
:param origin: (Optional) The original source of the beliefs (unused currently).
"""
if not beliefs:
return
msg = InternalMessage(
to=settings.agent_settings.bdi_core_name,
sender=self.name,
body=BeliefMessage(beliefs=beliefs).model_dump_json(),
thread="beliefs",
)
await self.send(msg)
self.logger.info("Sent %d belief(s) to BDI core.", len(beliefs))

View File

@@ -1,6 +1,18 @@
norms("").
goals("").
+user_said(Message) : norms(Norms) & goals(Goals) <-
-user_said(Message);
+new_message : user_said(Message) & norms(Norms) & goals(Goals) <-
-new_message;
.reply(Message, Norms, Goals).
// +new_message : user_said(Message) & norms(Norms) <-
// -new_message;
// .reply_no_goals(Message, Norms).
//
// +new_message : user_said(Message) & goals(Goals) <-
// -new_message;
// .reply_no_norms(Message, Goals).
//
// +new_message : user_said(Message) <-
// -new_message;
// .reply_no_goals_no_norms(Message).
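
Read together with BeliefSetterBehaviour, the active plan above fires when new_message is added while user_said, norms, and goals hold, and then calls the registered .reply action. A hedged sketch of the belief updates that would trigger it, mirroring what BeliefSetterBehaviour does; `agent` is assumed to be a started BDICoreAgent (see the earlier sketch):

```python
# Sketch only: these set_belief calls mirror BeliefSetterBehaviour and would trigger
# the +new_message plan above, which in turn invokes the custom .reply action.
def trigger_reply(agent):  # agent: a started BDICoreAgent (assumption for illustration)
    agent.bdi.set_belief("norms", "Be polite")
    agent.bdi.set_belief("goals", "Help the participant")
    agent.bdi.set_belief("user_said", "Hello robot")
    agent.bdi.set_belief("new_message")  # flag consumed (-new_message) by the plan
```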

View File

@@ -1,65 +0,0 @@
import json
from control_backend.agents.base import BaseAgent
from control_backend.core.agent_system import InternalMessage
from control_backend.core.config import settings
class TextBeliefExtractorAgent(BaseAgent):
"""
Text Belief Extractor Agent.
This agent is responsible for processing raw text (e.g., from speech transcription) and
extracting semantic beliefs from it.
In the current demonstration version, it performs a simple wrapping of the user's input
into a ``user_said`` belief. In a full implementation, this agent would likely interact
with an LLM or NLU engine to extract intent, entities, and other structured information.
"""
async def setup(self):
"""
Initialize the agent and its resources.
"""
self.logger.info("Settting up %s.", self.name)
# Setup LLM belief context if needed (currently demo is just passthrough)
self.beliefs = {"mood": ["X"], "car": ["Y"]}
async def handle_message(self, msg: InternalMessage):
"""
Handle incoming messages, primarily from the Transcription Agent.
:param msg: The received message containing transcribed text.
"""
sender = msg.sender
if sender == settings.agent_settings.transcription_name:
self.logger.debug("Received text from transcriber: %s", msg.body)
await self._process_transcription_demo(msg.body)
else:
self.logger.info("Discarding message from %s", sender)
async def _process_transcription_demo(self, txt: str):
"""
Process the transcribed text and generate beliefs.
**Demo Implementation:**
Currently, this method takes the raw text ``txt`` and wraps it into a belief structure:
``user_said("txt")``.
This belief is then sent to the :class:`BDIBeliefCollectorAgent`.
:param txt: The raw transcribed text string.
"""
# For demo, just wrapping user text as user_said belief
belief = {"beliefs": {"user_said": [txt]}, "type": "belief_extraction_text"}
payload = json.dumps(belief)
belief_msg = InternalMessage(
to=settings.agent_settings.bdi_belief_collector_name,
sender=self.name,
body=payload,
thread="beliefs",
)
await self.send(belief_msg)
self.logger.info("Sent %d beliefs to the belief collector.", len(belief["beliefs"]))

View File

@@ -0,0 +1,8 @@
from control_backend.agents.base import BaseAgent
from .behaviours.text_belief_extractor import BeliefFromText
class TBeliefExtractorAgent(BaseAgent):
async def setup(self):
self.add_behaviour(BeliefFromText())

View File

@@ -0,0 +1,94 @@
import json
from json import JSONDecodeError
from spade.agent import Message
from spade.behaviour import CyclicBehaviour
from control_backend.core.config import settings
class ContinuousBeliefCollector(CyclicBehaviour):
"""
Continuously collects beliefs/emotions from extractor agents,
then sends a unified belief packet to the BDI agent.
"""
async def run(self):
msg = await self.receive(timeout=1)
if not msg:
return
await self._process_message(msg)
async def _process_message(self, msg: Message):
sender_node = msg.sender.node
# Parse JSON payload
try:
payload = json.loads(msg.body)
except JSONDecodeError as e:
self.agent.logger.warning(
"BeliefCollector: failed to parse JSON from %s. Body=%r Error=%s",
sender_node,
msg.body,
e,
)
return
msg_type = payload.get("type")
# Prefer explicit 'type' field
if msg_type == "belief_extraction_text" or sender_node == "belief_text_agent_mock":
self.agent.logger.debug(
"Message routed to _handle_belief_text (sender=%s)", sender_node
)
await self._handle_belief_text(payload, sender_node)
# This is not implemented yet, but we keep the structure for future use
elif msg_type == "emotion_extraction_text" or sender_node == "emo_text_agent_mock":
self.agent.logger.debug("Message routed to _handle_emo_text (sender=%s)", sender_node)
await self._handle_emo_text(payload, sender_node)
else:
self.agent.logger.warning(
"Unrecognized message (sender=%s, type=%r). Ignoring.", sender_node, msg_type
)
async def _handle_belief_text(self, payload: dict, origin: str):
"""
Expected payload:
{
"type": "belief_extraction_text",
"beliefs": {"user_said": ["Can you help me?"]}
}
"""
beliefs = payload.get("beliefs", {})
if not beliefs:
self.agent.logger.debug("Received empty beliefs set.")
return
self.agent.logger.debug("Forwarding %d beliefs.", len(beliefs))
for belief_name, belief_list in beliefs.items():
for belief in belief_list:
self.agent.logger.debug(" - %s %s", belief_name, str(belief))
await self._send_beliefs_to_bdi(beliefs, origin=origin)
async def _handle_emo_text(self, payload: dict, origin: str):
"""TODO: implement (after we have emotional recogntion)"""
pass
async def _send_beliefs_to_bdi(self, beliefs: dict[str, list[str]], origin: str | None = None):
"""
Sends a unified belief packet to the BDI agent.
"""
if not beliefs:
return
to_jid = f"{settings.agent_settings.bdi_core_agent_name}@{settings.agent_settings.host}"
msg = Message(to=to_jid, sender=self.agent.jid, thread="beliefs")
msg.body = json.dumps(beliefs)
await self.send(msg)
self.agent.logger.info("Sent %d belief(s) to BDI core.", len(beliefs))

View File

@@ -0,0 +1,11 @@
from control_backend.agents.base import BaseAgent
from .behaviours.continuous_collect import ContinuousBeliefCollector
class BeliefCollectorAgent(BaseAgent):
async def setup(self):
self.logger.info("BeliefCollectorAgent starting (%s)", self.jid)
# Attach the continuous collector behaviour (listens and forwards to BDI)
self.add_behaviour(ContinuousBeliefCollector())
self.logger.info("BeliefCollectorAgent ready.")

View File

@@ -1 +0,0 @@
from .ri_communication_agent import RICommunicationAgent as RICommunicationAgent

View File

@@ -1,294 +0,0 @@
import asyncio
import json
import zmq
import zmq.asyncio as azmq
from zmq.asyncio import Context
from control_backend.agents import BaseAgent
from control_backend.agents.actuation.robot_gesture_agent import RobotGestureAgent
from control_backend.core.config import settings
from ..actuation.robot_speech_agent import RobotSpeechAgent
class RICommunicationAgent(BaseAgent):
"""
Robot Interface (RI) Communication Agent.
This agent manages the high-level connection negotiation and health checking (heartbeat)
between the Control Backend and the Robot Interface (or UI).
It acts as a service discovery mechanism:
1. It initiates a handshake (negotiation) to discover where other services (like the robot
command listener) are listening.
2. It spawns specific agents
(like :class:`~control_backend.agents.actuation.robot_speech_agent.RobotSpeechAgent`)
once the connection details are established.
3. It maintains a "ping" loop to ensure the connection remains active.
:ivar _address: The ZMQ address to attempt the initial connection negotiation.
:ivar _bind: Whether to bind or connect the negotiation socket.
:ivar _req_socket: ZMQ REQ socket for negotiation and pings.
:ivar pub_socket: ZMQ PUB socket for internal notifications (e.g., ping status).
:ivar connected: Boolean flag indicating active connection status.
"""
def __init__(
self,
name: str,
address=settings.zmq_settings.ri_command_address,
bind=False,
):
super().__init__(name)
self._address = address
self._bind = bind
self._req_socket: azmq.Socket | None = None
self.pub_socket: azmq.Socket | None = None
self.connected = False
async def setup(self):
"""
Initialize the agent and attempt connection.
Tries to negotiate connection up to ``behaviour_settings.comm_setup_max_retries`` times.
If successful, starts the :meth:`_listen_loop`.
"""
self.logger.info("Setting up %s", self.name)
# Bind request socket
await self._setup_sockets()
if await self._negotiate_connection():
self.connected = True
self.add_behavior(self._listen_loop())
else:
self.logger.warning("Failed to negotiate connection during setup.")
self.logger.info("Finished setting up %s", self.name)
async def _setup_sockets(self, force=False):
"""
Initialize ZMQ sockets (REQ for negotiation, PUB for internal updates).
"""
# Bind request socket
if self._req_socket is None or force:
self._req_socket = Context.instance().socket(zmq.REQ)
if self._bind:
self._req_socket.bind(self._address)
else:
self._req_socket.connect(self._address)
if self.pub_socket is None or force:
self.pub_socket = Context.instance().socket(zmq.PUB)
self.pub_socket.connect(settings.zmq_settings.internal_pub_address)
async def _negotiate_connection(
self, max_retries: int = settings.behaviour_settings.comm_setup_max_retries
):
"""
Perform the handshake protocol with the Robot Interface.
Sends a ``negotiate/ports`` request and expects a configuration response containing
port assignments for various services (e.g., actuation).
:param max_retries: Number of attempts before giving up.
:return: True if negotiation succeeded, False otherwise.
"""
retries = 0
while retries < max_retries:
if self._req_socket is None:
retries += 1
continue
# Send our message and receive one back
message = {"endpoint": "negotiate/ports", "data": {}}
await self._req_socket.send_json(message)
retry_frequency = 1.0
try:
received_message = await asyncio.wait_for(
self._req_socket.recv_json(), timeout=retry_frequency
)
except TimeoutError:
self.logger.warning(
"No connection established in %d seconds (attempt %d/%d)",
retries * retry_frequency,
retries + 1,
max_retries,
)
retries += 1
continue
except Exception as e:
self.logger.warning("Unexpected error during negotiation: %s", e)
retries += 1
continue
# Validate endpoint
endpoint = received_message.get("endpoint")
if endpoint != "negotiate/ports":
self.logger.warning(
"Invalid endpoint '%s' received (attempt %d/%d)",
endpoint,
retries + 1,
max_retries,
)
retries += 1
await asyncio.sleep(1)
continue
# At this point, we have a valid response
try:
await self._handle_negotiation_response(received_message)
# Let UI know that we're connected
topic = b"ping"
data = json.dumps(True).encode()
if self.pub_socket:
await self.pub_socket.send_multipart([topic, data])
return True
except Exception as e:
self.logger.warning("Error unpacking negotiation data: %s", e)
retries += 1
await asyncio.sleep(settings.behaviour_settings.sleep_s)
continue
return False
async def _handle_negotiation_response(self, received_message):
"""
Parse the negotiation response and initialize services.
Based on the response, it might re-connect the main socket or spawn new agents
(e.g., for robot actuation).
"""
for port_data in received_message["data"]:
id = port_data["id"]
port = port_data["port"]
bind = port_data["bind"]
if not bind:
addr = f"tcp://localhost:{port}"
else:
addr = f"tcp://*:{port}"
match id:
case "main":
if addr != self._address:
assert self._req_socket is not None
if not bind:
self._req_socket.connect(addr)
else:
self._req_socket.bind(addr)
case "actuation":
gesture_data = port_data.get("gestures", [])
robot_speech_agent = RobotSpeechAgent(
settings.agent_settings.robot_speech_name,
address=addr,
bind=bind,
)
robot_gesture_agent = RobotGestureAgent(
settings.agent_settings.robot_gesture_name,
address=addr,
bind=bind,
gesture_data=gesture_data,
)
await robot_speech_agent.start()
await asyncio.sleep(0.1) # Small delay
await robot_gesture_agent.start()
case _:
self.logger.warning("Unhandled negotiation id: %s", id)
async def stop(self):
"""
Closes all sockets.
"""
if self._req_socket:
self._req_socket.close()
if self.pub_socket:
self.pub_socket.close()
await super().stop()
async def _listen_loop(self):
"""
Maintain the connection via a heartbeat (ping) loop.
Sends a ``ping`` request periodically and waits for a reply.
If pings fail repeatedly, it triggers a disconnection handler to restart negotiation.
"""
while self._running:
if not self.connected:
await asyncio.sleep(settings.behaviour_settings.sleep_s)
continue
# We need to listen and send pings.
message = {"endpoint": "ping", "data": {"id": "e.g. some reference id"}}
seconds_to_wait_total = settings.behaviour_settings.sleep_s
try:
assert self._req_socket is not None
await asyncio.wait_for(
self._req_socket.send_json(message), timeout=seconds_to_wait_total / 2
)
except TimeoutError:
self.logger.debug(
"Waited too long to send message - "
"we probably dont have any receivers... but let's check!"
)
# Wait up to {seconds_to_wait_total/2} seconds for a reply
try:
assert self._req_socket is not None
message = await asyncio.wait_for(
self._req_socket.recv_json(), timeout=seconds_to_wait_total / 2
)
self.logger.debug(f'Received message "{message}" from RI.')
if "endpoint" not in message:
self.logger.warning("No received endpoint in message, expected ping endpoint.")
continue
# See what endpoint we received
match message["endpoint"]:
case "ping":
topic = b"ping"
data = json.dumps(True).encode()
if self.pub_socket is not None:
await self.pub_socket.send_multipart([topic, data])
await asyncio.sleep(settings.behaviour_settings.sleep_s)
case _:
self.logger.debug(
"Received message with topic different than ping, while ping expected."
)
# We didn't get a reply
except TimeoutError:
self.logger.info(
f"No ping retrieved in {seconds_to_wait_total} seconds, "
"sending UI disconnection event and attempting to restart."
)
await self._handle_disconnection()
continue
except Exception:
self.logger.error("Error while waiting for ping message.", exc_info=True)
raise
async def _handle_disconnection(self):
"""
Handle connection loss.
Notifies the UI of disconnection (via internal PUB) and attempts to restart negotiation.
"""
self.connected = False
# Tell UI we're disconnected.
topic = b"ping"
data = json.dumps(False).encode()
if self.pub_socket:
try:
await asyncio.wait_for(self.pub_socket.send_multipart([topic, data]), 5)
except TimeoutError:
self.logger.warning("Connection ping for router timed out.")
# Try to reboot/renegotiate
self.logger.debug("Restarting communication negotiation.")
if await self._negotiate_connection(max_retries=1):
self.connected = True
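
For reference, a minimal sketch of the reply shape that _handle_negotiation_response iterates over; the concrete ids, ports, and gesture names below are illustrative and not taken from the Robot Interface specification.

# Illustrative "negotiate/ports" reply; every entry in "data" carries an
# "id", a "port" and a "bind" flag, and "actuation" may add "gestures".
example_reply = {
    "endpoint": "negotiate/ports",
    "data": [
        {"id": "main", "port": 5555, "bind": False},
        {"id": "actuation", "port": 5556, "bind": False, "gestures": ["wave"]},
    ],
}
# With this reply, "main" would re-point the REQ socket at tcp://localhost:5555
# (bind=False), while "actuation" would spawn RobotSpeechAgent and
# RobotGestureAgent against port 5556.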

View File

@@ -1 +0,0 @@
from .llm_agent import LLMAgent as LLMAgent

View File

@@ -0,0 +1,159 @@
import json
import re
from collections.abc import AsyncGenerator
import httpx
from spade.behaviour import CyclicBehaviour
from spade.message import Message
from control_backend.agents import BaseAgent
from control_backend.core.config import settings
from .llm_instructions import LLMInstructions
class LLMAgent(BaseAgent):
"""
Agent responsible for processing user text input and querying a locally
hosted LLM for text generation. Receives messages from the BDI Core Agent
and responds with processed LLM output.
"""
class ReceiveMessageBehaviour(CyclicBehaviour):
"""
Cyclic behaviour to continuously listen for incoming messages from
the BDI Core Agent and handle them.
"""
async def run(self):
"""
Receives SPADE messages and processes only those originating from the
configured BDI agent.
"""
msg = await self.receive(timeout=1)
if not msg:
return
sender = msg.sender.node
self.agent.logger.debug(
"Received message: %s from %s",
msg.body,
sender,
)
if sender == settings.agent_settings.bdi_core_agent_name:
self.agent.logger.debug("Processing message from BDI Core Agent")
await self._process_bdi_message(msg)
else:
self.agent.logger.debug("Message ignored (not from BDI Core Agent)")
async def _process_bdi_message(self, message: Message):
"""
Forwards user text from the BDI to the LLM and replies with the generated text in chunks
separated by punctuation.
"""
try:
message = json.loads(message.body)
except json.JSONDecodeError:
self.agent.logger.error("Could not process BDI message.", exc_info=True)
# Consume the streaming generator and send a reply for every chunk
async for chunk in self._query_llm(message["text"], message["norms"], message["goals"]):
await self._reply(chunk)
self.agent.logger.debug(
"Finished processing BDI message. Response sent in chunks to BDI Core Agent."
)
async def _reply(self, msg: str):
"""
Sends a response message back to the BDI Core Agent.
"""
reply = Message(
to=settings.agent_settings.bdi_core_agent_name + "@" + settings.agent_settings.host,
body=msg,
)
await self.send(reply)
async def _query_llm(self, prompt: str, norms: str, goals: str) -> AsyncGenerator[str]:
"""
Sends a chat completion request to the local LLM service and streams the response by
yielding fragments separated by punctuation such as commas and periods.
:param prompt: Input text prompt to pass to the LLM.
:yield: Fragments of the LLM-generated content.
"""
instructions = LLMInstructions(norms if norms else None, goals if goals else None)
messages = [
{
"role": "developer",
"content": instructions.build_developer_instruction(),
},
{
"role": "user",
"content": prompt,
},
]
try:
current_chunk = ""
async for token in self._stream_query_llm(messages):
current_chunk += token
# Stream the message in chunks separated by punctuation.
# We include the delimiter in the emitted chunk for natural flow.
pattern = re.compile(r".*?(?:,|;|:|—||\.{3}|…|\.|\?|!)\s*", re.DOTALL)
for m in pattern.finditer(current_chunk):
chunk = m.group(0)
if chunk:
yield current_chunk
current_chunk = ""
# Yield any remaining tail
if current_chunk:
yield current_chunk
except httpx.HTTPError as err:
self.agent.logger.error("HTTP error.", exc_info=err)
yield "LLM service unavailable."
except Exception as err:
self.agent.logger.error("Unexpected error.", exc_info=err)
yield "Error processing the request."
async def _stream_query_llm(self, messages) -> AsyncGenerator[str]:
"""Raises httpx.HTTPError when the API gives an error."""
async with httpx.AsyncClient(timeout=None) as client:
async with client.stream(
"POST",
settings.llm_settings.local_llm_url,
json={
"model": settings.llm_settings.local_llm_model,
"messages": messages,
"temperature": 0.3,
"stream": True,
},
) as response:
response.raise_for_status()
async for line in response.aiter_lines():
if not line or not line.startswith("data: "):
continue
data = line[len("data: ") :]
if data.strip() == "[DONE]":
break
try:
event = json.loads(data)
delta = event.get("choices", [{}])[0].get("delta", {}).get("content")
if delta:
yield delta
except json.JSONDecodeError:
self.agent.logger.error("Failed to parse LLM response: %s", data)
async def setup(self):
"""
Sets up the SPADE behaviour to filter and process messages from the
BDI Core Agent.
"""
behaviour = self.ReceiveMessageBehaviour()
self.add_behaviour(behaviour)
self.logger.info("LLMAgent setup complete")

View File

@@ -1,195 +0,0 @@
import json
import re
import uuid
from collections.abc import AsyncGenerator
import httpx
from pydantic import ValidationError
from control_backend.agents import BaseAgent
from control_backend.core.agent_system import InternalMessage
from control_backend.core.config import settings
from ...schemas.llm_prompt_message import LLMPromptMessage
from .llm_instructions import LLMInstructions
class LLMAgent(BaseAgent):
"""
LLM Agent.
This agent is responsible for processing user text input and querying a locally
hosted LLM for text generation. It acts as the conversational brain of the system.
It receives :class:`~control_backend.schemas.llm_prompt_message.LLMPromptMessage`
payloads from the BDI Core Agent, constructs a conversation history, queries the
LLM via HTTP, and streams the response back to the BDI agent in natural chunks
(e.g., sentence by sentence).
:ivar history: A list of dictionaries representing the conversation history (Role/Content).
"""
def __init__(self, name: str):
super().__init__(name)
self.history = []
async def setup(self):
self.logger.info("Setting up %s.", self.name)
async def handle_message(self, msg: InternalMessage):
"""
Handle incoming messages.
Expects messages from :attr:`settings.agent_settings.bdi_core_name` containing
an :class:`LLMPromptMessage` in the body.
:param msg: The received internal message.
"""
if msg.sender == settings.agent_settings.bdi_core_name:
self.logger.debug("Processing message from BDI core.")
try:
prompt_message = LLMPromptMessage.model_validate_json(msg.body)
await self._process_bdi_message(prompt_message)
except ValidationError:
self.logger.debug("Prompt message from BDI core is invalid.")
else:
self.logger.debug("Message ignored (not from BDI core.")
async def _process_bdi_message(self, message: LLMPromptMessage):
"""
Orchestrate the LLM query and response streaming.
Iterates over the chunks yielded by :meth:`_query_llm` and forwards them
individually to the BDI agent via :meth:`_send_reply`.
:param message: The parsed prompt message containing text, norms, and goals.
"""
async for chunk in self._query_llm(message.text, message.norms, message.goals):
await self._send_reply(chunk)
self.logger.debug(
"Finished processing BDI message. Response sent in chunks to BDI core."
)
async def _send_reply(self, msg: str):
"""
Sends a response message (chunk) back to the BDI Core Agent.
:param msg: The text content of the chunk.
"""
reply = InternalMessage(
to=settings.agent_settings.bdi_core_name,
sender=self.name,
body=msg,
)
await self.send(reply)
async def _query_llm(
self, prompt: str, norms: list[str], goals: list[str]
) -> AsyncGenerator[str]:
"""
Send a chat completion request to the local LLM service and stream the response.
It constructs the full prompt using
:class:`~control_backend.agents.llm.llm_instructions.LLMInstructions`.
It streams the response from the LLM and buffers tokens until a natural break (punctuation)
is reached, then yields the chunk. This ensures that the robot speaks in complete phrases
rather than individual tokens.
:param prompt: Input text prompt to pass to the LLM.
:param norms: Norms the LLM should hold itself to.
:param goals: Goals the LLM should achieve.
:yield: Fragments of the LLM-generated content (e.g., sentences/phrases).
"""
self.history.append(
{
"role": "user",
"content": prompt,
}
)
instructions = LLMInstructions(norms if norms else None, goals if goals else None)
messages = [
{
"role": "developer",
"content": instructions.build_developer_instruction(),
},
*self.history,
]
message_id = str(uuid.uuid4()) # noqa
try:
full_message = ""
current_chunk = ""
async for token in self._stream_query_llm(messages):
full_message += token
current_chunk += token
self.logger.info(
"Received token: %s",
full_message,
extra={"reference": message_id}, # Used in the UI to update old logs
)
# Stream the message in chunks separated by punctuation.
# We include the delimiter in the emitted chunk for natural flow.
pattern = re.compile(r".*?(?:,|;|:|—||\.{3}|…|\.|\?|!)\s*", re.DOTALL)
for m in pattern.finditer(current_chunk):
chunk = m.group(0)
if chunk:
yield current_chunk
current_chunk = ""
# Yield any remaining tail
if current_chunk:
yield current_chunk
self.history.append(
{
"role": "assistant",
"content": full_message,
}
)
except httpx.HTTPError as err:
self.logger.error("HTTP error.", exc_info=err)
yield "LLM service unavailable."
except Exception as err:
self.logger.error("Unexpected error.", exc_info=err)
yield "Error processing the request."
async def _stream_query_llm(self, messages) -> AsyncGenerator[str]:
"""
Perform the raw HTTP streaming request to the LLM API.
:param messages: The list of message dictionaries (role/content).
:yield: Raw text tokens (deltas) from the SSE stream.
:raises httpx.HTTPError: If the API returns a non-200 status.
"""
async with httpx.AsyncClient() as client:
async with client.stream(
"POST",
settings.llm_settings.local_llm_url,
json={
"model": settings.llm_settings.local_llm_model,
"messages": messages,
"temperature": 0.3,
"stream": True,
},
) as response:
response.raise_for_status()
async for line in response.aiter_lines():
if not line or not line.startswith("data: "):
continue
data = line[len("data: ") :]
if data.strip() == "[DONE]":
break
try:
event = json.loads(data)
delta = event.get("choices", [{}])[0].get("delta", {}).get("content")
if delta:
yield delta
except json.JSONDecodeError:
self.logger.error("Failed to parse LLM response: %s", data)

View File

@@ -1,45 +1,27 @@
class LLMInstructions:
"""
Helper class to construct the system instructions for the LLM.
It combines the base persona (Pepper robot) with dynamic norms and goals
provided by the BDI system.
If no norms or goals are given, the defaults below are used.
:ivar norms: A list of behavioral norms.
:ivar goals: A list of specific conversational goals.
Defines structured instructions that are sent along with each request
to the LLM to guide its behavior (norms, goals, etc.).
"""
@staticmethod
def default_norms() -> list[str]:
return [
"Be friendly and respectful.",
"Make the conversation feel natural and engaging.",
]
def default_norms() -> str:
return "Be friendly and respectful.\nMake the conversation feel natural and engaging."
@staticmethod
def default_goals() -> list[str]:
return [
"Try to learn the user's name during conversation.",
]
def default_goals() -> str:
return """
Try to learn the user's name during conversation.
""".strip()
def __init__(self, norms: list[str] | None = None, goals: list[str] | None = None):
self.norms = norms or self.default_norms()
self.goals = goals or self.default_goals()
def __init__(self, norms: str | None = None, goals: str | None = None):
self.norms = norms if norms is not None else self.default_norms()
self.goals = goals if goals is not None else self.default_goals()
def build_developer_instruction(self) -> str:
"""
Builds the final system prompt string.
The prompt includes:
1. Persona definition.
2. Constraint on response length.
3. Instructions on how to handle goals (reach them in order, but prioritize natural flow).
4. The specific list of norms.
5. The specific list of goals.
:return: The formatted system prompt string.
Builds a multi-line formatted instruction string for the LLM.
Includes only non-empty structured fields.
"""
sections = [
"You are a Pepper robot engaging in natural human conversation.",
@@ -50,14 +32,12 @@ class LLMInstructions:
if self.norms:
sections.append("Norms to follow:")
for norm in self.norms:
sections.append("- " + norm)
sections.append(self.norms)
sections.append("")
if self.goals:
sections.append("Goals to reach:")
for goal in self.goals:
sections.append("- " + goal)
sections.append(self.goals)
sections.append("")
return "\n".join(sections).strip()

View File

@@ -0,0 +1,44 @@
import json
from spade.agent import Agent
from spade.behaviour import OneShotBehaviour
from spade.message import Message
from control_backend.core.config import settings
class BeliefTextAgent(Agent):
class SendOnceBehaviourBlfText(OneShotBehaviour):
async def run(self):
to_jid = (
settings.agent_settings.belief_collector_agent_name
+ "@"
+ settings.agent_settings.host
)
# Send multiple beliefs in one JSON payload
payload = {
"type": "belief_extraction_text",
"beliefs": {
"user_said": [
"hello test",
"Can you help me?",
"stop talking to me",
"No",
"Pepper do a dance",
]
},
}
msg = Message(to=to_jid)
msg.body = json.dumps(payload)
await self.send(msg)
print(f"Beliefs sent to {to_jid}!")
self.exit_code = "Job Finished!"
await self.agent.stop()
async def setup(self):
print("BeliefTextAgent started")
self.b = self.SendOnceBehaviourBlfText()
self.add_behaviour(self.b)

View File

@@ -1,4 +0,0 @@
from .transcription_agent.transcription_agent import (
TranscriptionAgent as TranscriptionAgent,
)
from .vad_agent import VADAgent as VADAgent

View File

@@ -1,138 +0,0 @@
import asyncio
import numpy as np
import zmq
import zmq.asyncio as azmq
from control_backend.agents import BaseAgent
from control_backend.core.agent_system import InternalMessage
from control_backend.core.config import settings
from .speech_recognizer import SpeechRecognizer
class TranscriptionAgent(BaseAgent):
"""
Transcription Agent.
This agent listens to audio fragments (containing speech) on a ZMQ SUB socket,
transcribes them using the configured :class:`SpeechRecognizer`, and sends the
resulting text to other agents (e.g., the Text Belief Extractor).
It uses an internal semaphore to limit the number of concurrent transcription tasks.
:ivar audio_in_address: The ZMQ address to receive audio from (usually from VAD Agent).
:ivar audio_in_socket: The ZMQ SUB socket instance.
:ivar speech_recognizer: The speech recognition engine instance.
:ivar _concurrency: Semaphore to limit concurrent transcriptions.
"""
def __init__(self, audio_in_address: str):
"""
Initialize the Transcription Agent.
:param audio_in_address: The ZMQ address of the audio source (e.g., VAD output).
"""
super().__init__(settings.agent_settings.transcription_name)
self.audio_in_address = audio_in_address
self.audio_in_socket: azmq.Socket | None = None
self.speech_recognizer = None
self._concurrency = None
async def setup(self):
"""
Initialize the agent resources.
1. Connects to the audio input ZMQ socket.
2. Initializes the :class:`SpeechRecognizer` (choosing the best available backend).
3. Starts the background transcription loop.
"""
self.logger.info("Setting up %s", self.name)
self._connect_audio_in_socket()
# Initialize recognizer and semaphore
max_concurrent_tasks = settings.behaviour_settings.transcription_max_concurrent_tasks
self._concurrency = asyncio.Semaphore(max_concurrent_tasks)
self.speech_recognizer = SpeechRecognizer.best_type()
self.speech_recognizer.load_model() # Warmup
# Start background loop
self.add_behavior(self._transcribing_loop())
self.logger.info("Finished setting up %s", self.name)
async def stop(self):
"""
Stop the agent and close sockets.
"""
assert self.audio_in_socket is not None
self.audio_in_socket.close()
self.audio_in_socket = None
return await super().stop()
def _connect_audio_in_socket(self):
"""
Helper to connect the ZMQ SUB socket for audio input.
"""
self.audio_in_socket = azmq.Context.instance().socket(zmq.SUB)
self.audio_in_socket.setsockopt_string(zmq.SUBSCRIBE, "")
self.audio_in_socket.connect(self.audio_in_address)
async def _transcribe(self, audio: np.ndarray) -> str:
"""
Run the speech recognition on the audio data.
This runs in a separate thread (via `asyncio.to_thread`) to avoid blocking the event loop,
constrained by the concurrency semaphore.
:param audio: The audio data as a numpy array.
:return: The transcribed text string.
"""
assert self._concurrency is not None and self.speech_recognizer is not None
async with self._concurrency:
return await asyncio.to_thread(self.speech_recognizer.recognize_speech, audio)
async def _share_transcription(self, transcription: str):
"""
Share a transcription to the other agents that depend on it.
Currently sends to:
- :attr:`settings.agent_settings.text_belief_extractor_name`
:param transcription: The transcribed text.
"""
receiver_names = [
settings.agent_settings.text_belief_extractor_name,
]
for receiver_name in receiver_names:
message = InternalMessage(
to=receiver_name,
sender=self.name,
body=transcription,
)
await self.send(message)
async def _transcribing_loop(self) -> None:
"""
The main loop for receiving audio and triggering transcription.
Receives audio chunks from ZMQ, decodes them to float32, and calls :meth:`_transcribe`.
If speech is found, it calls :meth:`_share_transcription`.
"""
while self._running:
try:
assert self.audio_in_socket is not None
audio_data = await self.audio_in_socket.recv()
audio = np.frombuffer(audio_data, dtype=np.float32)
speech = await self._transcribe(audio)
if not speech:
self.logger.info("Nothing transcribed.")
continue
self.logger.info("Transcribed speech: %s", speech)
await self._share_transcription(speech)
except Exception as e:
self.logger.error(f"Error in transcription loop: {e}")

View File

@@ -1,231 +0,0 @@
import asyncio
import numpy as np
import torch
import zmq
import zmq.asyncio as azmq
from control_backend.agents import BaseAgent
from control_backend.core.config import settings
from .transcription_agent.transcription_agent import TranscriptionAgent
class SocketPoller[T]:
"""
Convenience class for polling a socket for data with a timeout, persisting a zmq.Poller for
multiple usages.
:param T: The type of data returned by the socket.
"""
def __init__(
self,
socket: azmq.Socket,
timeout_ms: int = settings.behaviour_settings.socket_poller_timeout_ms,
):
"""
:param socket: The socket to poll and get data from.
:param timeout_ms: A timeout in milliseconds to wait for data.
"""
self.socket = socket
self.poller = azmq.Poller()
self.poller.register(self.socket, zmq.POLLIN)
self.timeout_ms = timeout_ms
async def poll(self, timeout_ms: int | None = None) -> T | None:
"""
Get data from the socket, or None if the timeout is reached.
:param timeout_ms: If given, the timeout. Otherwise, ``self.timeout_ms`` is used.
:return: Data from the socket or None.
"""
timeout_ms = timeout_ms or self.timeout_ms
socks = dict(await self.poller.poll(timeout_ms))
if socks.get(self.socket) == zmq.POLLIN:
return await self.socket.recv()
return None
class VADAgent(BaseAgent):
"""
Voice Activity Detection (VAD) Agent.
This agent:
1. Receives an audio stream (via ZMQ).
2. Processes the audio using the Silero VAD model to detect speech.
3. Buffers potential speech segments.
4. Publishes valid speech fragments (containing speech plus small buffer) to a ZMQ PUB socket.
5. Instantiates and starts agents (like :class:`TranscriptionAgent`) that use this output.
:ivar audio_in_address: Address of the input audio stream.
:ivar audio_in_bind: Whether to bind or connect to the input address.
:ivar audio_out_socket: ZMQ PUB socket for sending speech fragments.
"""
def __init__(self, audio_in_address: str, audio_in_bind: bool):
"""
Initialize the VAD Agent.
:param audio_in_address: ZMQ address for input audio.
:param audio_in_bind: True if this agent should bind to the input address, False to connect.
"""
super().__init__(settings.agent_settings.vad_name)
self.audio_in_address = audio_in_address
self.audio_in_bind = audio_in_bind
self.audio_in_socket: azmq.Socket | None = None
self.audio_out_socket: azmq.Socket | None = None
self.audio_in_poller: SocketPoller | None = None
self.audio_buffer = np.array([], dtype=np.float32)
self.i_since_speech = settings.behaviour_settings.vad_initial_since_speech
self._ready = asyncio.Event()
self.model = None
async def setup(self):
"""
Initialize resources.
1. Connects audio input socket.
2. Binds audio output socket (random port).
3. Loads VAD model from Torch Hub.
4. Starts the streaming loop.
5. Instantiates and starts the :class:`TranscriptionAgent` with the output address.
"""
self.logger.info("Setting up %s", self.name)
self._connect_audio_in_socket()
audio_out_port = self._connect_audio_out_socket()
if audio_out_port is None:
self.logger.error("Could not bind output socket, stopping.")
await self.stop()
return
audio_out_address = f"tcp://localhost:{audio_out_port}"
# Initialize VAD model
try:
self.model, _ = torch.hub.load(
repo_or_dir=settings.vad_settings.repo_or_dir,
model=settings.vad_settings.model_name,
force_reload=False,
)
except Exception:
self.logger.exception("Failed to load VAD model.")
await self.stop()
return
# Warmup/reset
await self.reset_stream()
self.add_behavior(self._streaming_loop())
# Start agents dependent on the output audio fragments here
transcriber = TranscriptionAgent(audio_out_address)
await transcriber.start()
self.logger.info("Finished setting up %s", self.name)
async def stop(self):
"""
Stop listening to audio, stop publishing audio, close sockets.
"""
if self.audio_in_socket is not None:
self.audio_in_socket.close()
self.audio_in_socket = None
if self.audio_out_socket is not None:
self.audio_out_socket.close()
self.audio_out_socket = None
await super().stop()
def _connect_audio_in_socket(self):
"""
Connects (or binds) the socket for listening to audio from RI.
"""
self.audio_in_socket = azmq.Context.instance().socket(zmq.SUB)
self.audio_in_socket.setsockopt_string(zmq.SUBSCRIBE, "")
if self.audio_in_bind:
self.audio_in_socket.bind(self.audio_in_address)
else:
self.audio_in_socket.connect(self.audio_in_address)
self.audio_in_poller = SocketPoller[bytes](self.audio_in_socket)
def _connect_audio_out_socket(self) -> int | None:
"""
Returns the port bound, or None if binding failed.
"""
try:
self.audio_out_socket = azmq.Context.instance().socket(zmq.PUB)
return self.audio_out_socket.bind_to_random_port("tcp://localhost", max_tries=100)
except zmq.ZMQBindError:
self.logger.error("Failed to bind an audio output socket after 100 tries.")
self.audio_out_socket = None
return None
async def reset_stream(self):
"""
Clears the ZeroMQ queue and sets ready state.
"""
discarded = 0
assert self.audio_in_poller is not None
while await self.audio_in_poller.poll(1) is not None:
discarded += 1
self.logger.info(f"Discarded {discarded} audio packets before starting.")
self._ready.set()
async def _streaming_loop(self):
"""
Main loop for processing audio stream.
1. Polls for new audio chunks.
2. Passes chunk to VAD model.
3. Manages `i_since_speech` counter to determine start/end of speech.
4. Buffers speech + context.
5. Sends complete speech segment to output socket when silence is detected.
"""
await self._ready.wait()
while self._running:
assert self.audio_in_poller is not None
data = await self.audio_in_poller.poll()
if data is None:
if len(self.audio_buffer) > 0:
self.logger.debug(
"No audio data received. Discarding buffer until new data arrives."
)
self.audio_buffer = np.array([], dtype=np.float32)
self.i_since_speech = settings.behaviour_settings.vad_initial_since_speech
continue
# copy otherwise Torch will be sad that it's immutable
chunk = np.frombuffer(data, dtype=np.float32).copy()
assert self.model is not None
prob = self.model(torch.from_numpy(chunk), settings.vad_settings.sample_rate_hz).item()
non_speech_patience = settings.behaviour_settings.vad_non_speech_patience_chunks
prob_threshold = settings.behaviour_settings.vad_prob_threshold
if prob > prob_threshold:
if self.i_since_speech > non_speech_patience:
self.logger.debug("Speech started.")
self.audio_buffer = np.append(self.audio_buffer, chunk)
self.i_since_speech = 0
continue
self.i_since_speech += 1
# prob < threshold, so speech may have ended. Wait a bit longer to be more certain.
if self.i_since_speech <= non_speech_patience:
self.audio_buffer = np.append(self.audio_buffer, chunk)
continue
# Speech probably ended. Make sure we have a usable amount of data.
if len(self.audio_buffer) >= 3 * len(chunk):
self.logger.debug("Speech ended.")
assert self.audio_out_socket is not None
await self.audio_out_socket.send(self.audio_buffer[: -2 * len(chunk)].tobytes())
# At this point, we know that the speech has ended.
# Prepend the last chunk that had no speech, for a more fluent boundary
self.audio_buffer = chunk
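
A standalone sketch of the buffering and patience logic in _streaming_loop, driven by made-up VAD probabilities instead of the Silero model; the 0.5 threshold and patience of 3 chunks mirror the inline values in the removed version of this agent.

# Made-up per-chunk speech probabilities standing in for the VAD model output.
prob_threshold = 0.5
non_speech_patience = 3

buffer: list[int] = []   # chunk indices standing in for audio chunks
i_since_speech = 100     # start "far from speech", as in the agent

for i, prob in enumerate([0.1, 0.9, 0.8, 0.2, 0.3, 0.1, 0.05, 0.1]):
    if prob > prob_threshold:
        if i_since_speech > non_speech_patience:
            print("Speech started at chunk", i)
        buffer.append(i)
        i_since_speech = 0
        continue
    i_since_speech += 1
    if i_since_speech <= non_speech_patience:
        buffer.append(i)   # keep a little trailing context
        continue
    if len(buffer) >= 3:
        print("Speech ended; emitting chunks", buffer[:-2])
    buffer = [i]           # keep the last silent chunk as leading context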

View File

@@ -0,0 +1,92 @@
import json
import spade.agent
import zmq
from spade.behaviour import CyclicBehaviour
from zmq.asyncio import Context
from control_backend.agents import BaseAgent
from control_backend.core.config import settings
from control_backend.schemas.ri_message import SpeechCommand
class RICommandAgent(BaseAgent):
subsocket: zmq.Socket
pubsocket: zmq.Socket
address = ""
bind = False
def __init__(
self,
jid: str,
password: str,
port: int = 5222,
verify_security: bool = False,
address="tcp://localhost:0000",
bind=False,
):
super().__init__(jid, password, port, verify_security)
self.address = address
self.bind = bind
class SendCommandsBehaviour(CyclicBehaviour):
"""Behaviour for sending commands received from the UI."""
async def run(self):
"""
Run the command publishing loop indefinitely.
"""
assert self.agent is not None
# Get a message internally (with topic command)
topic, body = await self.agent.subsocket.recv_multipart()
# Try to get body
try:
body = json.loads(body)
message = SpeechCommand.model_validate(body)
# Send to the robot.
await self.agent.pubsocket.send_json(message.model_dump())
except Exception as e:
self.agent.logger.error("Error processing message: %s", e)
class SendPythonCommandsBehaviour(CyclicBehaviour):
"""Behaviour for sending commands received from other Python agents."""
async def run(self):
message: spade.agent.Message = await self.receive(timeout=1)
if not message:
return
if message and message.to == self.agent.jid:
try:
speech_command = SpeechCommand.model_validate_json(message.body)
await self.agent.pubsocket.send_json(speech_command.model_dump())
except Exception as e:
self.agent.logger.error("Error processing message: %s", e)
async def setup(self):
"""
Set up the command agent.
"""
self.logger.info("Setting up %s", self.jid)
context = Context.instance()
# To the robot
self.pubsocket = context.socket(zmq.PUB)
if self.bind:
self.pubsocket.bind(self.address)
else:
self.pubsocket.connect(self.address)
# Receive internal topics regarding commands
self.subsocket = context.socket(zmq.SUB)
self.subsocket.connect(settings.zmq_settings.internal_sub_address)
self.subsocket.setsockopt(zmq.SUBSCRIBE, b"command")
# Add behaviour to our agent
commands_behaviour = self.SendCommandsBehaviour()
self.add_behaviour(commands_behaviour)
self.add_behaviour(self.SendPythonCommandsBehaviour())
self.logger.info("Finished setting up %s", self.jid)

View File

@@ -0,0 +1,162 @@
import asyncio
import zmq
from spade.behaviour import CyclicBehaviour
from zmq.asyncio import Context
from control_backend.agents import BaseAgent
from control_backend.core.config import settings
from .ri_command_agent import RICommandAgent
class RICommunicationAgent(BaseAgent):
req_socket: zmq.Socket
_address = ""
_bind = True
def __init__(
self,
jid: str,
password: str,
port: int = 5222,
verify_security: bool = False,
address=None,
bind=True,
):
super().__init__(jid, password, port, verify_security)
if not address:
self.logger.critical("No address set for negotiations.")
raise Exception # TODO: improve
self._address = address
self._bind = bind
class ListenBehaviour(CyclicBehaviour):
async def run(self):
"""
Run the listening (ping) loop indefinitely.
"""
assert self.agent is not None
# We need to listen and send pings.
message = {"endpoint": "ping", "data": {"id": "e.g. some reference id"}}
await self.agent.req_socket.send_json(message)
# Wait up to three seconds for a reply:)
try:
message = await asyncio.wait_for(self.agent.req_socket.recv_json(), timeout=3.0)
# We didnt get a reply :(
except TimeoutError:
self.agent.logger.info("No ping retrieved in 3 seconds, killing myself.")
self.kill()
self.agent.logger.debug('Received message "%s"', message)
if "endpoint" not in message:
self.agent.logger.error("No received endpoint in message, excepted ping endpoint.")
return
# See what endpoint we received
match message["endpoint"]:
case "ping":
await asyncio.sleep(1)
case _:
self.agent.logger.info(
"Received message with topic different than ping, while ping expected."
)
async def setup(self, max_retries: int = 5):
"""
Try to set up the communication agent; we retry up to 5 times in case we don't have a response yet.
"""
self.logger.info("Setting up %s", self.jid)
retries = 0
# Let's try a certain amount of times before failing connection
while retries < max_retries:
# Bind request socket
self.req_socket = Context.instance().socket(zmq.REQ)
if self._bind:
self.req_socket.bind(self._address)
else:
self.req_socket.connect(self._address)
# Send our message and receive one back:)
message = {"endpoint": "negotiate/ports", "data": None}
await self.req_socket.send_json(message)
try:
received_message = await asyncio.wait_for(self.req_socket.recv_json(), timeout=20.0)
except TimeoutError:
self.logger.warning(
"No connection established in 20 seconds (attempt %d/%d)",
retries + 1,
max_retries,
)
retries += 1
continue
except Exception as e:
self.logger.error("Unexpected error during negotiation: %s", e)
retries += 1
continue
# Validate endpoint
endpoint = received_message.get("endpoint")
if endpoint != "negotiate/ports":
# TODO: Should this send a message back?
self.logger.error(
"Invalid endpoint '%s' received (attempt %d/%d)",
endpoint,
retries + 1,
max_retries,
)
retries += 1
continue
# At this point, we have a valid response
try:
for port_data in received_message["data"]:
id = port_data["id"]
port = port_data["port"]
bind = port_data["bind"]
addr = f"tcp://{settings.zmq_settings.external_host}:{port}"
match id:
case "main":
if addr != self._address:
if not bind:
self.req_socket.connect(addr)
else:
self.req_socket.bind(addr)
case "actuation":
ri_commands_agent = RICommandAgent(
settings.agent_settings.ri_command_agent_name
+ "@"
+ settings.agent_settings.host,
settings.agent_settings.ri_command_agent_name,
address=addr,
bind=bind,
)
await ri_commands_agent.start()
case _:
self.logger.warning("Unhandled negotiation id: %s", id)
except Exception as e:
self.logger.error("Error unpacking negotiation data: %s", e)
retries += 1
continue
# setup succeeded
break
else:
self.logger.error("Failed to set up RICommunicationAgent after %d retries", max_retries)
return
# Set up ping behaviour
listen_behaviour = self.ListenBehaviour()
self.add_behaviour(listen_behaviour)
self.logger.info("Finished setting up %s", self.jid)

View File

@@ -10,32 +10,17 @@ import numpy as np
import torch
import whisper
from control_backend.core.config import settings
class SpeechRecognizer(abc.ABC):
"""
Abstract base class for speech recognition backends.
Provides a common interface for loading models and transcribing audio,
as well as heuristics for estimating token counts to optimize decoding.
:ivar limit_output_length: If True, limits the generated text length based on audio duration.
"""
def __init__(self, limit_output_length=True):
"""
:param limit_output_length: When ``True``, the length of the generated speech will be
limited by the length of the input audio and some heuristics.
:param limit_output_length: When `True`, the length of the generated speech will be limited
by the length of the input audio and some heuristics.
"""
self.limit_output_length = limit_output_length
@abc.abstractmethod
def load_model(self):
"""
Load the speech recognition model into memory.
"""
...
def load_model(self): ...
@abc.abstractmethod
def recognize_speech(self, audio: np.ndarray) -> str:
@@ -43,33 +28,29 @@ class SpeechRecognizer(abc.ABC):
Recognize speech from the given audio sample.
:param audio: A full utterance sample. Audio must be 16 kHz, mono, np.float32, values in the
range [-1.0, 1.0].
:return: The recognized speech text.
range [-1.0, 1.0].
:return: Recognized speech.
"""
@staticmethod
def _estimate_max_tokens(audio: np.ndarray) -> int:
"""
Estimate the maximum length of a given audio sample in tokens.
Assumes a maximum speaking rate of 450 words per minute (3x average), and assumes that
3 words is approx. 4 tokens.
Estimate the maximum length of a given audio sample in tokens. Assumes a maximum speaking
rate of 450 words per minute (3x average), and assumes that 3 words is 4 tokens.
:param audio: The audio sample (16 kHz) to use for length estimation.
:return: The estimated length of the transcribed audio in tokens.
"""
length_seconds = len(audio) / settings.vad_settings.sample_rate_hz
length_seconds = len(audio) / 16_000
length_minutes = length_seconds / 60
word_count = length_minutes * settings.behaviour_settings.transcription_words_per_minute
token_count = word_count / settings.behaviour_settings.transcription_words_per_token
return int(token_count) + settings.behaviour_settings.transcription_token_buffer
word_count = length_minutes * 450
token_count = word_count / 3 * 4
return int(token_count) + 10
def _get_decode_options(self, audio: np.ndarray) -> dict:
"""
Construct decoding options for the Whisper model.
:param audio: The audio sample (16 kHz) to use to determine options like max decode length.
:return: A dict that can be used to construct ``whisper.DecodingOptions`` (or equivalent).
:return: A dict that can be used to construct `whisper.DecodingOptions`.
"""
options = {}
if self.limit_output_length:
@@ -78,12 +59,7 @@ class SpeechRecognizer(abc.ABC):
@staticmethod
def best_type():
"""
Factory method to get the best available `SpeechRecognizer`.
:return: An instance of :class:`MLXWhisperSpeechRecognizer` if on macOS with Apple Silicon,
otherwise :class:`OpenAIWhisperSpeechRecognizer`.
"""
"""Get the best type of SpeechRecognizer based on system capabilities."""
if torch.mps.is_available():
print("Choosing MLX Whisper model.")
return MLXWhisperSpeechRecognizer()
@@ -93,20 +69,12 @@ class SpeechRecognizer(abc.ABC):
class MLXWhisperSpeechRecognizer(SpeechRecognizer):
"""
Speech recognizer using the MLX framework (optimized for Apple Silicon).
"""
def __init__(self, limit_output_length=True):
super().__init__(limit_output_length)
self.was_loaded = False
self.model_name = settings.speech_model_settings.mlx_model_name
self.model_name = "mlx-community/whisper-small.en-mlx"
def load_model(self):
"""
Ensures the model is downloaded and cached. MLX loads dynamically, so this
pre-fetches the model.
"""
if self.was_loaded:
return
# There appears to be no dedicated mechanism to preload a model, but this `get_model` does
@@ -124,24 +92,15 @@ class MLXWhisperSpeechRecognizer(SpeechRecognizer):
class OpenAIWhisperSpeechRecognizer(SpeechRecognizer):
"""
Speech recognizer using the standard OpenAI Whisper library (PyTorch).
"""
def __init__(self, limit_output_length=True):
super().__init__(limit_output_length)
self.model = None
def load_model(self):
"""
Loads the OpenAI Whisper model onto the available device (CUDA or CPU).
"""
if self.model is not None:
return
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
self.model = whisper.load_model(
settings.speech_model_settings.openai_model_name, device=device
)
self.model = whisper.load_model("small.en", device=device)
def recognize_speech(self, audio: np.ndarray) -> str:
self.load_model()
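
As a worked instance of the _estimate_max_tokens heuristic, using the constants from the inline version (450 words per minute, 4 tokens per 3 words, a 10-token buffer) for a 4-second clip at 16 kHz:

sample_rate_hz = 16_000
audio_samples = 4 * sample_rate_hz                     # a 4-second clip
length_minutes = audio_samples / sample_rate_hz / 60   # ~0.067 minutes
word_count = length_minutes * 450                      # 30 words at most
token_count = word_count / 3 * 4                       # 40 tokens
max_tokens = int(token_count) + 10                     # 50
print(max_tokens)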

View File

@@ -0,0 +1,86 @@
import asyncio
import numpy as np
import zmq
import zmq.asyncio as azmq
from spade.behaviour import CyclicBehaviour
from spade.message import Message
from control_backend.agents import BaseAgent
from control_backend.core.config import settings
from .speech_recognizer import SpeechRecognizer
class TranscriptionAgent(BaseAgent):
"""
An agent which listens to audio fragments with voice, transcribes them, and sends the
transcription to other agents.
"""
def __init__(self, audio_in_address: str):
jid = settings.agent_settings.transcription_agent_name + "@" + settings.agent_settings.host
super().__init__(jid, settings.agent_settings.transcription_agent_name)
self.audio_in_address = audio_in_address
self.audio_in_socket: azmq.Socket | None = None
class Transcribing(CyclicBehaviour):
def __init__(self, audio_in_socket: azmq.Socket):
super().__init__()
self.audio_in_socket = audio_in_socket
self.speech_recognizer = SpeechRecognizer.best_type()
self._concurrency = asyncio.Semaphore(3)
def warmup(self):
"""Load the transcription model into memory to speed up the first transcription."""
self.speech_recognizer.load_model()
async def _transcribe(self, audio: np.ndarray) -> str:
async with self._concurrency:
return await asyncio.to_thread(self.speech_recognizer.recognize_speech, audio)
async def _share_transcription(self, transcription: str):
"""Share a transcription to the other agents that depend on it."""
receiver_jids = [
settings.agent_settings.text_belief_extractor_agent_name
+ "@"
+ settings.agent_settings.host,
] # Set message receivers here
for receiver_jid in receiver_jids:
message = Message(to=receiver_jid, body=transcription)
await self.send(message)
async def run(self) -> None:
audio = await self.audio_in_socket.recv()
audio = np.frombuffer(audio, dtype=np.float32)
speech = await self._transcribe(audio)
if not speech:
self.agent.logger.info("Nothing transcribed.")
return
self.agent.logger.info("Transcribed speech: %s", speech)
await self._share_transcription(speech)
async def stop(self):
self.audio_in_socket.close()
self.audio_in_socket = None
return await super().stop()
def _connect_audio_in_socket(self):
self.audio_in_socket = azmq.Context.instance().socket(zmq.SUB)
self.audio_in_socket.setsockopt_string(zmq.SUBSCRIBE, "")
self.audio_in_socket.connect(self.audio_in_address)
async def setup(self):
self.logger.info("Setting up %s", self.jid)
self._connect_audio_in_socket()
transcribing = self.Transcribing(self.audio_in_socket)
transcribing.warmup()
self.add_behaviour(transcribing)
self.logger.info("Finished setting up %s", self.jid)

View File

@@ -0,0 +1,172 @@
import numpy as np
import torch
import zmq
import zmq.asyncio as azmq
from spade.behaviour import CyclicBehaviour
from control_backend.agents import BaseAgent
from control_backend.core.config import settings
from .transcription.transcription_agent import TranscriptionAgent
class SocketPoller[T]:
"""
Convenience class for polling a socket for data with a timeout, persisting a zmq.Poller for
multiple usages.
"""
def __init__(self, socket: azmq.Socket, timeout_ms: int = 100):
"""
:param socket: The socket to poll and get data from.
:param timeout_ms: A timeout in milliseconds to wait for data.
"""
self.socket = socket
self.poller = zmq.Poller()
self.poller.register(self.socket, zmq.POLLIN)
self.timeout_ms = timeout_ms
async def poll(self, timeout_ms: int | None = None) -> T | None:
"""
Get data from the socket, or None if the timeout is reached.
:param timeout_ms: If given, the timeout. Otherwise, `self.timeout_ms` is used.
:return: Data from the socket or None.
"""
timeout_ms = timeout_ms or self.timeout_ms
socks = dict(self.poller.poll(timeout_ms))
if socks.get(self.socket) == zmq.POLLIN:
return await self.socket.recv()
return None
class Streaming(CyclicBehaviour):
def __init__(self, audio_in_socket: azmq.Socket, audio_out_socket: azmq.Socket):
super().__init__()
self.audio_in_poller = SocketPoller[bytes](audio_in_socket)
self.model, _ = torch.hub.load(
repo_or_dir="snakers4/silero-vad", model="silero_vad", force_reload=False
)
self.audio_out_socket = audio_out_socket
self.audio_buffer = np.array([], dtype=np.float32)
self.i_since_speech = 100 # Used to allow small pauses in speech
self._ready = False
async def reset(self):
"""Clears the ZeroMQ queue and tells this behavior to start."""
discarded = 0
while await self.audio_in_poller.poll(1) is not None:
discarded += 1
self.agent.logger.info(f"Discarded {discarded} audio packets before starting.")
self._ready = True
async def run(self) -> None:
if not self._ready:
return
data = await self.audio_in_poller.poll()
if data is None:
if len(self.audio_buffer) > 0:
self.agent.logger.debug(
"No audio data received. Discarding buffer until new data arrives."
)
self.audio_buffer = np.array([], dtype=np.float32)
self.i_since_speech = 100
return
# copy otherwise Torch will be sad that it's immutable
chunk = np.frombuffer(data, dtype=np.float32).copy()
prob = self.model(torch.from_numpy(chunk), 16000).item()
if prob > 0.5:
if self.i_since_speech > 3:
self.agent.logger.debug("Speech started.")
self.audio_buffer = np.append(self.audio_buffer, chunk)
self.i_since_speech = 0
return
self.i_since_speech += 1
# prob < 0.5, so speech may have ended. Wait a bit longer to be more certain.
if self.i_since_speech <= 3:
self.audio_buffer = np.append(self.audio_buffer, chunk)
return
# Speech probably ended. Make sure we have a usable amount of data.
if len(self.audio_buffer) >= 3 * len(chunk):
self.agent.logger.debug("Speech ended.")
await self.audio_out_socket.send(self.audio_buffer[: -2 * len(chunk)].tobytes())
# At this point, we know that the speech has ended.
# Prepend the last chunk that had no speech, for a more fluent boundary
self.audio_buffer = chunk
class VADAgent(BaseAgent):
"""
An agent which listens to an audio stream, does Voice Activity Detection (VAD), and sends
fragments with detected speech to other agents over ZeroMQ.
"""
def __init__(self, audio_in_address: str, audio_in_bind: bool):
jid = settings.agent_settings.vad_agent_name + "@" + settings.agent_settings.host
super().__init__(jid, settings.agent_settings.vad_agent_name)
self.audio_in_address = audio_in_address
self.audio_in_bind = audio_in_bind
self.audio_in_socket: azmq.Socket | None = None
self.audio_out_socket: azmq.Socket | None = None
self.streaming_behaviour: Streaming | None = None
async def stop(self):
"""
Stop listening to audio, stop publishing audio, close sockets.
"""
if self.audio_in_socket is not None:
self.audio_in_socket.close()
self.audio_in_socket = None
if self.audio_out_socket is not None:
self.audio_out_socket.close()
self.audio_out_socket = None
return await super().stop()
def _connect_audio_in_socket(self):
self.audio_in_socket = azmq.Context.instance().socket(zmq.SUB)
self.audio_in_socket.setsockopt_string(zmq.SUBSCRIBE, "")
if self.audio_in_bind:
self.audio_in_socket.bind(self.audio_in_address)
else:
self.audio_in_socket.connect(self.audio_in_address)
self.audio_in_poller = SocketPoller[bytes](self.audio_in_socket)
def _connect_audio_out_socket(self) -> int | None:
"""Returns the port bound, or None if binding failed."""
try:
self.audio_out_socket = azmq.Context.instance().socket(zmq.PUB)
return self.audio_out_socket.bind_to_random_port("tcp://*", max_tries=100)
except zmq.ZMQBindError:
self.logger.error("Failed to bind an audio output socket after 100 tries.")
self.audio_out_socket = None
return None
async def setup(self):
self.logger.info("Setting up %s", self.jid)
self._connect_audio_in_socket()
audio_out_port = self._connect_audio_out_socket()
if audio_out_port is None:
await self.stop()
return
audio_out_address = f"tcp://localhost:{audio_out_port}"
self.streaming_behaviour = Streaming(self.audio_in_socket, self.audio_out_socket)
self.add_behaviour(self.streaming_behaviour)
# Start agents dependent on the output audio fragments here
transcriber = TranscriptionAgent(audio_out_address)
await transcriber.start()
self.logger.info("Finished setting up %s", self.jid)

View File

@@ -0,0 +1,20 @@
import logging
from fastapi import APIRouter, Request
from control_backend.schemas.ri_message import SpeechCommand
logger = logging.getLogger(__name__)
router = APIRouter()
@router.post("/command", status_code=202)
async def receive_command(command: SpeechCommand, request: Request):
# Validate and retrieve data.
SpeechCommand.model_validate(command)
topic = b"command"
pub_socket = request.app.state.endpoints_pub_socket
await pub_socket.send_multipart([topic, command.model_dump_json().encode()])
return {"status": "Command received"}

View File

@@ -15,14 +15,6 @@ router = APIRouter()
# DO NOT LOG INSIDE THIS FUNCTION
@router.get("/logs/stream")
async def log_stream():
"""
Server-Sent Events (SSE) endpoint for real-time log streaming.
Subscribes to the internal ZMQ logging topic and forwards log records to the client.
Allows the frontend to display live logs from the backend.
:return: A StreamingResponse yielding SSE data.
"""
context = Context.instance()
socket = context.socket(zmq.SUB)

View File

@@ -11,14 +11,6 @@ router = APIRouter()
@router.post("/message", status_code=202)
async def receive_message(message: Message, request: Request):
"""
Generic endpoint to receive text messages.
Publishes the message to the internal 'message' topic via ZMQ.
:param message: The message payload.
:param request: The FastAPI request object (used to access app state).
"""
logger.info("Received message: %s", message.message)
topic = b"message"

View File

@@ -1,7 +1,9 @@
import logging
from fastapi import APIRouter, Request
from fastapi import APIRouter, HTTPException, Request
from pydantic import ValidationError
from control_backend.schemas.message import Message
from control_backend.schemas.program import Program
logger = logging.getLogger(__name__)
@@ -9,18 +11,20 @@ router = APIRouter()
@router.post("/program", status_code=202)
async def receive_message(program: Program, request: Request):
async def receive_message(program: Message, request: Request):
"""
Endpoint to upload a new Behavior Program.
Validates the program structure (phases, norms, goals) and publishes it to the internal
'program' topic. The :class:`~control_backend.agents.bdi.bdi_program_manager.BDIProgramManager`
will pick this up and update the BDI agent.
:param program: The parsed Program object.
:param request: The FastAPI request object.
Receives a BehaviorProgram as a stringified JSON list inside `message`.
Converts it into real Phase objects.
"""
logger.debug("Received raw program: %s", program)
raw_str = program.message # This is the JSON string
# Validate program
try:
program = Program.model_validate_json(raw_str)
except ValidationError as e:
logger.error("Failed to validate program JSON: %s", e)
raise HTTPException(status_code=400, detail="Not a valid program") from None
# send away
topic = b"program"

View File

@@ -1,143 +0,0 @@
import asyncio
import json
import logging
import zmq.asyncio
from fastapi import APIRouter, HTTPException, Request
from fastapi.responses import StreamingResponse
from pydantic import ValidationError
from zmq.asyncio import Context, Socket
from control_backend.core.config import settings
from control_backend.schemas.ri_message import GestureCommand, SpeechCommand
logger = logging.getLogger(__name__)
router = APIRouter()
@router.post("/command", status_code=202)
async def receive_command(command: SpeechCommand, request: Request):
"""
Send a direct speech command to the robot.
Publishes the command to the internal 'command' topic. The
:class:`~control_backend.agents.actuation.robot_speech_agent.RobotSpeechAgent`
or
:class:`~control_backend.agents.actuation.robot_speech_agent.RobotGestureAgent`
will forward this to the robot.
:param command: The speech command payload.
:param request: The FastAPI request object.
"""
# Validate and retrieve data.
validated = None
valid_commands = (GestureCommand, SpeechCommand)
for command_model in valid_commands:
try:
validated = command_model.model_validate(command)
except ValidationError:
continue
if validated is None:
raise HTTPException(status_code=422, detail="Payload is not valid for command models")
topic = b"command"
pub_socket: Socket = request.app.state.endpoints_pub_socket
await pub_socket.send_multipart([topic, validated.model_dump_json().encode()])
return {"status": "Command received"}
@router.get("/ping_check")
async def ping(request: Request):
"""
Simple HTTP ping endpoint to check if the backend is reachable.
"""
pass
@router.get("/get_available_gesture_tags")
async def get_available_gesture_tags(request: Request):
"""
Endpoint to retrieve the available gesture tags for the robot.
:param request: The FastAPI request object.
:return: A list of available gesture tags.
"""
sub_socket = Context.instance().socket(zmq.SUB)
sub_socket.connect(settings.zmq_settings.internal_sub_address)
sub_socket.setsockopt(zmq.SUBSCRIBE, b"get_gestures")
pub_socket: Socket = request.app.state.endpoints_pub_socket
topic = b"send_gestures"
# TODO: Implement a way to get a certain amount from the UI, rather than everything.
amount = None
timeout = 5 # seconds
await pub_socket.send_multipart([topic, amount.to_bytes(4, "big") if amount else b""])
try:
_, body = await asyncio.wait_for(sub_socket.recv_multipart(), timeout=timeout)
except TimeoutError:
body = b"tags: []"
logger.debug("got timeout error fetching gestures")
# Handle empty response and JSON decode errors
available_tags = []
if body:
try:
available_tags = json.loads(body).get("tags", [])
except json.JSONDecodeError as e:
logger.error(f"Failed to parse gesture tags JSON: {e}, body: {body}")
# Return empty list on JSON error
available_tags = []
return {"available_gesture_tags": available_tags}
@router.get("/ping_stream")
async def ping_stream(request: Request):
"""
SSE endpoint for monitoring the Robot Interface connection status.
Subscribes to the internal 'ping' topic (published by the RI Communication Agent)
and yields status updates to the client.
:return: A StreamingResponse of connection status events.
"""
async def event_stream():
# Set up internal socket to receive ping updates
sub_socket = Context.instance().socket(zmq.SUB)
sub_socket.connect(settings.zmq_settings.internal_sub_address)
sub_socket.setsockopt(zmq.SUBSCRIBE, b"ping")
connected = False
ping_frequency = 2
# Even though its most likely the updates should alternate
# (So, True - False - True - False for connectivity),
# let's still check.
while True:
try:
topic, body = await asyncio.wait_for(
sub_socket.recv_multipart(), timeout=ping_frequency
)
connected = json.loads(body)
except TimeoutError:
logger.debug("got timeout error in ping loop in ping router")
connected = False
# Stop if client disconnected
if await request.is_disconnected():
logger.info("Client disconnected from SSE")
break
logger.debug(f"Yielded new connection event in robot ping router: {str(connected)}")
connectedJson = json.dumps(connected)
yield (f"data: {connectedJson}\n\n")
return StreamingResponse(event_stream(), media_type="text/event-stream")
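
A minimal client sketch for consuming this SSE endpoint; the host and port are assumptions about how the app is served, and the /robot prefix follows the router include elsewhere in this diff.

import httpx

with httpx.stream("GET", "http://localhost:8000/robot/ping_stream", timeout=None) as resp:
    for line in resp.iter_lines():
        if line.startswith("data: "):
            print("RI connected:", line[len("data: "):])  # "true" or "false"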

View File

@@ -6,7 +6,4 @@ router = APIRouter()
# TODO: implement
@router.get("/sse")
async def sse(request: Request):
"""
Placeholder for future Server-Sent Events endpoint.
"""
pass

View File

@@ -1,6 +1,6 @@
from fastapi.routing import APIRouter
from control_backend.api.v1.endpoints import logs, message, program, robot, sse
from control_backend.api.v1.endpoints import command, logs, message, program, sse
api_router = APIRouter()
@@ -8,7 +8,7 @@ api_router.include_router(message.router, tags=["Messages"])
api_router.include_router(sse.router, tags=["SSE"])
api_router.include_router(robot.router, prefix="/robot", tags=["Pings", "Commands"])
api_router.include_router(command.router, tags=["Commands"])
api_router.include_router(logs.router, tags=["Logs"])

View File

@@ -1,198 +0,0 @@
import asyncio
import logging
from abc import ABC, abstractmethod
from asyncio import Task
from collections.abc import Coroutine
import zmq
import zmq.asyncio as azmq
from control_backend.core.config import settings
from control_backend.schemas.internal_message import InternalMessage
# Central directory to resolve agent names to instances
_agent_directory: dict[str, "BaseAgent"] = {}
class AgentDirectory:
"""
Helper class to keep track of which agents are registered.
Used for handling message routing.
"""
@staticmethod
def register(name: str, agent: "BaseAgent"):
_agent_directory[name] = agent
@staticmethod
def get(name: str) -> "BaseAgent | None":
return _agent_directory.get(name)
class BaseAgent(ABC):
"""
Abstract base class for all agents in the system.
This class provides the foundational infrastructure for agent lifecycle management, messaging
(both intra-process and inter-process via ZMQ), and asynchronous behavior execution.
.. warning::
Do not inherit from this class directly for creating new agents. Instead, inherit from
:class:`control_backend.agents.base.BaseAgent`, which ensures proper logger configuration.
:ivar name: The unique name of the agent.
:ivar inbox: The queue for receiving internal messages.
:ivar _tasks: A set of currently running asynchronous tasks/behaviors.
:ivar _running: A boolean flag indicating if the agent is currently running.
:ivar logger: The logger instance for the agent.
"""
logger: logging.Logger
def __init__(self, name: str):
"""
Initialize the BaseAgent.
:param name: The unique identifier for this agent.
"""
self.name = name
self.inbox: asyncio.Queue[InternalMessage] = asyncio.Queue()
self._tasks: set[asyncio.Task] = set()
self._running = False
# Register immediately
AgentDirectory.register(name, self)
@abstractmethod
async def setup(self):
"""
Initialize agent-specific resources.
This method must be overridden by subclasses. It is called after the agent has started
and the ZMQ sockets have been initialized. Use this method to:
* Initialize connections (databases, APIs, etc.)
* Add initial behaviors using :meth:`add_behavior`
"""
pass
async def start(self):
"""
Start the agent and its internal loops.
This method:
1. Sets the running state to True.
2. Initializes ZeroMQ PUB/SUB sockets for inter-process communication.
3. Calls the user-defined :meth:`setup` method.
4. Starts the inbox processing loop and the ZMQ receiver loop.
"""
self.logger.info(f"Starting agent {self.name}")
self._running = True
context = azmq.Context.instance()
# Setup the internal publishing socket
self._internal_pub_socket = context.socket(zmq.PUB)
self._internal_pub_socket.connect(settings.zmq_settings.internal_pub_address)
# Setup the internal receiving socket
self._internal_sub_socket = context.socket(zmq.SUB)
self._internal_sub_socket.connect(settings.zmq_settings.internal_sub_address)
self._internal_sub_socket.subscribe(f"internal/{self.name}")
await self.setup()
# Start processing inbox and ZMQ messages
self.add_behavior(self._process_inbox())
self.add_behavior(self._receive_internal_zmq_loop())
async def stop(self):
"""
Stop the agent.
Sets the running state to False and cancels all running background tasks.
"""
self._running = False
for task in self._tasks:
task.cancel()
self.logger.info(f"Agent {self.name} stopped")
async def send(self, message: InternalMessage):
"""
Send a message to another agent.
This method intelligently routes the message:
* If the target agent is in the same process (found in :class:`AgentDirectory`),
the message is put directly into its inbox.
* If the target agent is not found locally, the message is serialized and sent
via ZeroMQ to the internal publication address.
:param message: The message to send.
"""
target = AgentDirectory.get(message.to)
if target:
await target.inbox.put(message)
self.logger.debug(f"Sent message {message.body} to {message.to} via regular inbox.")
else:
# Target agent is apparently in a different process; send via ZMQ instead
topic = f"internal/{message.to}".encode()
body = message.model_dump_json().encode()
await self._internal_pub_socket.send_multipart([topic, body])
self.logger.debug(f"Sent message {message.body} to {message.to} via ZMQ.")
async def _process_inbox(self):
"""
Internal loop that processes messages from the inbox.
Reads messages from ``self.inbox`` and passes them to :meth:`handle_message`.
"""
while self._running:
msg = await self.inbox.get()
self.logger.debug(f"Received message from {msg.sender}.")
await self.handle_message(msg)
async def _receive_internal_zmq_loop(self):
"""
Internal loop that listens for ZMQ messages.
Subscribes to ``internal/<agent_name>`` topics. When a message is received,
it is deserialized into an :class:`InternalMessage` and put into the local inbox.
This bridges the gap between inter-process ZMQ communication and the intra-process inbox.
"""
while self._running:
try:
_, body = await self._internal_sub_socket.recv_multipart()
msg = InternalMessage.model_validate_json(body)
await self.inbox.put(msg)
except asyncio.CancelledError:
break
except Exception:
self.logger.exception("Could not process ZMQ message.")
async def handle_message(self, msg: InternalMessage):
"""
Handle an incoming message.
This method must be overridden by subclasses to define how the agent reacts to messages.
:param msg: The received message.
:raises NotImplementedError: If not overridden by the subclass.
"""
raise NotImplementedError
def add_behavior(self, coro: Coroutine) -> Task:
"""
Add a background behavior (task) to the agent.
This is the preferred way to run continuous loops or long-running tasks within an agent.
The task is tracked and will be automatically cancelled when :meth:`stop` is called.
:param coro: The coroutine to execute as a task.
"""
task = asyncio.create_task(coro)
self._tasks.add(task)
task.add_done_callback(self._tasks.discard)
return task
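For context on how this (now removed) base class was meant to be used, a minimal sketch of a subclass follows; `EchoAgent` and its behavior are hypothetical, and the import paths are assumed from the old tests elsewhere in this diff. The class's own warning recommends inheriting via `control_backend.agents.base.BaseAgent` for logger configuration; the sketch sets a logger manually instead.

```python
# Hypothetical agent built on the removed BaseAgent API; names and import paths are assumptions.
import asyncio
import logging

from control_backend.core.agent_system import BaseAgent  # module path assumed
from control_backend.schemas.internal_message import InternalMessage


class EchoAgent(BaseAgent):
    logger = logging.getLogger(__name__)

    async def setup(self):
        # Behaviors added here are tracked and cancelled when stop() is called.
        self.add_behavior(self._heartbeat())

    async def handle_message(self, msg: InternalMessage):
        # Echo every incoming message straight back to its sender.
        await self.send(InternalMessage(to=msg.sender, sender=self.name, body=msg.body))

    async def _heartbeat(self):
        while self._running:
            self.logger.debug("EchoAgent is alive")
            await asyncio.sleep(1.0)
```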

View File

@@ -1,140 +1,37 @@
import os
from pydantic import BaseModel
from pydantic_settings import BaseSettings, SettingsConfigDict
class ZMQSettings(BaseModel):
"""
Configuration for ZeroMQ (ZMQ) addresses used for inter-process communication.
:ivar internal_pub_address: Address for the internal PUB socket.
:ivar internal_sub_address: Address for the internal SUB socket.
:ivar ri_command_address: Address for sending commands to the Robot Interface.
:ivar ri_communication_address: Address for receiving communication from the Robot Interface.
:ivar vad_agent_address: Address for the Voice Activity Detection (VAD) agent.
"""
internal_pub_address: str = "tcp://localhost:5560"
internal_sub_address: str = "tcp://localhost:5561"
ri_command_address: str = "tcp://localhost:0000"
ri_communication_address: str = "tcp://*:5555"
vad_agent_address: str = "tcp://localhost:5558"
external_host: str = "0.0.0.0"
class AgentSettings(BaseModel):
"""
Names of the various agents in the system. These names are used for routing messages.
host: str = os.environ.get("XMPP_HOST", "localhost")
bdi_core_agent_name: str = "bdi_core"
belief_collector_agent_name: str = "belief_collector"
text_belief_extractor_agent_name: str = "text_belief_extractor"
vad_agent_name: str = "vad_agent"
llm_agent_name: str = "llm_agent"
test_agent_name: str = "test_agent"
transcription_agent_name: str = "transcription_agent"
program_manager_agent_name: str = "program_manager"
:ivar bdi_core_name: Name of the BDI Core Agent.
:ivar bdi_belief_collector_name: Name of the Belief Collector Agent.
:ivar bdi_program_manager_name: Name of the BDI Program Manager Agent.
:ivar text_belief_extractor_name: Name of the Text Belief Extractor Agent.
:ivar vad_name: Name of the Voice Activity Detection (VAD) Agent.
:ivar llm_name: Name of the Large Language Model (LLM) Agent.
:ivar test_name: Name of the Test Agent.
:ivar transcription_name: Name of the Transcription Agent.
:ivar ri_communication_name: Name of the RI Communication Agent.
:ivar robot_speech_name: Name of the Robot Speech Agent.
"""
# agent names
bdi_core_name: str = "bdi_core_agent"
bdi_belief_collector_name: str = "belief_collector_agent"
bdi_program_manager_name: str = "bdi_program_manager_agent"
text_belief_extractor_name: str = "text_belief_extractor_agent"
vad_name: str = "vad_agent"
llm_name: str = "llm_agent"
test_name: str = "test_agent"
transcription_name: str = "transcription_agent"
ri_communication_name: str = "ri_communication_agent"
robot_speech_name: str = "robot_speech_agent"
robot_gesture_name: str = "robot_gesture_agent"
class BehaviourSettings(BaseModel):
"""
Configuration for agent behaviors and parameters.
:ivar sleep_s: Default sleep time in seconds for loops.
:ivar comm_setup_max_retries: Maximum number of retries for setting up communication.
:ivar socket_poller_timeout_ms: Timeout in milliseconds for socket polling.
:ivar vad_prob_threshold: Probability threshold for Voice Activity Detection.
:ivar vad_initial_since_speech: Initial value for 'since speech' counter in VAD.
:ivar vad_non_speech_patience_chunks: Number of non-speech chunks to wait before speech is considered ended.
:ivar transcription_max_concurrent_tasks: Maximum number of concurrent transcription tasks.
:ivar transcription_words_per_minute: Estimated words per minute for transcription timing.
:ivar transcription_words_per_token: Estimated words per token for transcription timing.
:ivar transcription_token_buffer: Buffer for transcription tokens.
"""
sleep_s: float = 1.0
comm_setup_max_retries: int = 5
socket_poller_timeout_ms: int = 100
# VAD settings
vad_prob_threshold: float = 0.5
vad_initial_since_speech: int = 100
vad_non_speech_patience_chunks: int = 3
# transcription behaviour
transcription_max_concurrent_tasks: int = 3
transcription_words_per_minute: int = 300
transcription_words_per_token: float = 0.75 # (3 words = 4 tokens)
transcription_token_buffer: int = 10
ri_communication_agent_name: str = "ri_communication_agent"
ri_command_agent_name: str = "ri_command_agent"
class LLMSettings(BaseModel):
"""
Configuration for the Large Language Model (LLM).
:ivar local_llm_url: URL for the local LLM API.
:ivar local_llm_model: Name of the local LLM model to use.
"""
local_llm_url: str = "http://localhost:1234/v1/chat/completions"
local_llm_model: str = "gpt-oss"
class VADSettings(BaseModel):
"""
Configuration for Voice Activity Detection (VAD) model.
:ivar repo_or_dir: Repository or directory for the VAD model.
:ivar model_name: Name of the VAD model.
:ivar sample_rate_hz: Sample rate in Hz for the VAD model.
"""
repo_or_dir: str = "snakers4/silero-vad"
model_name: str = "silero_vad"
sample_rate_hz: int = 16000
class SpeechModelSettings(BaseModel):
"""
Configuration for speech recognition models.
:ivar mlx_model_name: Model name for MLX-based speech recognition.
:ivar openai_model_name: Model name for OpenAI-based speech recognition.
"""
# model identifiers for speech recognition
mlx_model_name: str = "mlx-community/whisper-small.en-mlx"
openai_model_name: str = "small.en"
local_llm_url: str = os.environ.get("LLM_URL", "http://localhost:1234/v1/") + "chat/completions"
local_llm_model: str = os.environ.get("LLM_MODEL", "openai/gpt-oss-20b")
class Settings(BaseSettings):
"""
Global application settings.
:ivar app_title: Title of the application.
:ivar ui_url: URL of the frontend UI.
:ivar zmq_settings: ZMQ configuration.
:ivar agent_settings: Agent name configuration.
:ivar behaviour_settings: Behavior configuration.
:ivar vad_settings: VAD model configuration.
:ivar speech_model_settings: Speech model configuration.
:ivar llm_settings: LLM configuration.
"""
app_title: str = "PepperPlus"
ui_url: str = "http://localhost:5173"
@@ -143,15 +40,9 @@ class Settings(BaseSettings):
agent_settings: AgentSettings = AgentSettings()
behaviour_settings: BehaviourSettings = BehaviourSettings()
vad_settings: VADSettings = VADSettings()
speech_model_settings: SpeechModelSettings = SpeechModelSettings()
llm_settings: LLMSettings = LLMSettings()
model_config = SettingsConfigDict(env_file=".env", env_nested_delimiter="__")
model_config = SettingsConfigDict(env_file=".env")
settings = Settings()
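As a quick illustration of how the new environment-variable defaults behave under Docker Compose, a sketch follows; the values and the post-change field layout are assumptions read off this diff. Note that the variables must be set before the config module is imported, because `os.environ.get` runs when the class bodies are defined.

```python
# Sketch: env vars must be present before importing the settings module, because the
# new defaults call os.environ.get() at class-definition time. Values are examples only.
import os

os.environ["LLM_URL"] = "http://llm:1234/v1/"
os.environ["LLM_MODEL"] = "openai/gpt-oss-20b"
os.environ["XMPP_HOST"] = "xmpp"

from control_backend.core.config import settings  # noqa: E402  (import after env setup)

print(settings.llm_settings.local_llm_url)   # -> http://llm:1234/v1/chat/completions
print(settings.agent_settings.host)          # -> xmpp
```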

View File

@@ -37,12 +37,6 @@ def add_logging_level(level_name: str, level_num: int, method_name: str | None =
def setup_logging(path: str = ".logging_config.yaml") -> None:
"""
Set up the logging configuration of the Control Backend. Tries to load the logging configuration
from a file in which we specify custom loggers, formatters, handlers, etc.
:param path: Path to the YAML logging configuration file.
:return: None
"""
if os.path.exists(path):
with open(path) as f:
try:

View File

@@ -1,22 +1,6 @@
"""
Control Backend Main Application.
This module defines the FastAPI application that serves as the entry point for the
Control Backend. It manages the lifecycle of the entire system, including:
1. **Socket Initialization**: Setting up the internal ZeroMQ PUB/SUB proxy for agent communication.
2. **Agent Management**: Instantiating and starting all agents.
3. **API Routing**: Exposing REST endpoints for external interaction.
Lifecycle Manager
-----------------
The :func:`lifespan` context manager handles the startup and shutdown sequences:
- **Startup**: Configures logging, starts the ZMQ proxy, connects sockets, and launches agents.
- **Shutdown**: Handles graceful cleanup (though currently minimal).
"""
import contextlib
import logging
import os
import threading
import zmq
@@ -24,25 +8,14 @@ from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from zmq.asyncio import Context
# BDI agents
from control_backend.agents.bdi import (
BDIBeliefCollectorAgent,
BDICoreAgent,
TextBeliefExtractorAgent,
from control_backend.agents import (
BeliefCollectorAgent,
LLMAgent,
RICommunicationAgent,
VADAgent,
)
from control_backend.agents.bdi.bdi_program_manager import BDIProgramManager
# Communication agents
from control_backend.agents.communication import RICommunicationAgent
# Emotional Agents
# LLM Agents
from control_backend.agents.llm import LLMAgent
# Perceive agents
from control_backend.agents.perception import VADAgent
# Other backend imports
from control_backend.agents.bdi import BDICoreAgent, TBeliefExtractorAgent
from control_backend.agents.bdi.bdi_program_manager.bdi_program_manager import BDIProgramManager
from control_backend.api.v1.router import api_router
from control_backend.core.config import settings
from control_backend.logging import setup_logging
@@ -51,12 +24,6 @@ logger = logging.getLogger(__name__)
def setup_sockets():
"""
Initialize and run the internal ZeroMQ Proxy (XPUB/XSUB).
This proxy acts as the central message bus, forwarding messages published on the
internal PUB address to all subscribers on the internal SUB address.
"""
context = Context.instance()
internal_pub_socket = context.socket(zmq.XPUB)
@@ -83,6 +50,9 @@ async def lifespan(app: FastAPI):
# --- APPLICATION STARTUP ---
setup_logging()
logger.info("%s is starting up.", app.title)
logger.info(
"LLM_URL: %s, LLM_MODEL: %s", os.environ.get("LLM_URL"), os.environ.get("LLM_MODEL")
)
# Initiate sockets
proxy_thread = threading.Thread(target=setup_sockets)
@@ -97,72 +67,88 @@ async def lifespan(app: FastAPI):
# --- Initialize Agents ---
logger.info("Initializing and starting agents.")
agents_to_start = {
"RICommunicationAgent": (
RICommunicationAgent,
{
"name": settings.agent_settings.ri_communication_name,
"address": settings.zmq_settings.ri_communication_address,
"name": settings.agent_settings.ri_communication_agent_name,
"jid": f"{settings.agent_settings.ri_communication_agent_name}"
f"@{settings.agent_settings.host}",
"password": settings.agent_settings.ri_communication_agent_name,
"address": f"tcp://{settings.zmq_settings.external_host}:5555",
"bind": True,
},
),
"LLMAgent": (
LLMAgent,
{
"name": settings.agent_settings.llm_name,
"name": settings.agent_settings.llm_agent_name,
"jid": f"{settings.agent_settings.llm_agent_name}@{settings.agent_settings.host}",
"password": settings.agent_settings.llm_agent_name,
},
),
"BDICoreAgent": (
BDICoreAgent,
{
"name": settings.agent_settings.bdi_core_name,
"name": settings.agent_settings.bdi_core_agent_name,
"jid": f"{settings.agent_settings.bdi_core_agent_name}@"
f"{settings.agent_settings.host}",
"password": settings.agent_settings.bdi_core_agent_name,
"asl": "src/control_backend/agents/bdi/rules.asl",
},
),
"BeliefCollectorAgent": (
BDIBeliefCollectorAgent,
BeliefCollectorAgent,
{
"name": settings.agent_settings.bdi_belief_collector_name,
"name": settings.agent_settings.belief_collector_agent_name,
"jid": f"{settings.agent_settings.belief_collector_agent_name}@"
f"{settings.agent_settings.host}",
"password": settings.agent_settings.belief_collector_agent_name,
},
),
"TextBeliefExtractorAgent": (
TextBeliefExtractorAgent,
"TBeliefExtractor": (
TBeliefExtractorAgent,
{
"name": settings.agent_settings.text_belief_extractor_name,
"name": settings.agent_settings.text_belief_extractor_agent_name,
"jid": f"{settings.agent_settings.text_belief_extractor_agent_name}@"
f"{settings.agent_settings.host}",
"password": settings.agent_settings.text_belief_extractor_agent_name,
},
),
"VADAgent": (
VADAgent,
{"audio_in_address": settings.zmq_settings.vad_agent_address, "audio_in_bind": False},
{
"audio_in_address": f"tcp://{settings.zmq_settings.external_host}:5558",
"audio_in_bind": True,
},
),
"ProgramManagerAgent": (
"ProgramManager": (
BDIProgramManager,
{
"name": settings.agent_settings.bdi_program_manager_name,
"name": settings.agent_settings.program_manager_agent_name,
"jid": f"{settings.agent_settings.program_manager_agent_name}@"
f"{settings.agent_settings.host}",
"password": settings.agent_settings.program_manager_agent_name,
},
),
}
agents = []
vad_agent = None
vad_agent_instance = None
for name, (agent_class, kwargs) in agents_to_start.items():
try:
logger.debug("Starting agent: %s", name)
agent_instance = agent_class(**kwargs)
agent_instance = agent_class(**{k: v for k, v in kwargs.items() if k != "name"})
await agent_instance.start()
if isinstance(agent_instance, VADAgent):
vad_agent = agent_instance
agents.append(agent_instance)
vad_agent_instance = agent_instance
logger.info("Agent '%s' started successfully.", name)
except Exception as e:
logger.error("Failed to start agent '%s': %s", name, e, exc_info=True)
# Consider if the application should continue if an agent fails to start.
raise
assert vad_agent is not None
await vad_agent.reset_stream()
await vad_agent_instance.streaming_behaviour.reset()
logger.info("Application startup complete.")

View File

@@ -1,23 +0,0 @@
from pydantic import BaseModel
class Belief(BaseModel):
"""
Represents a single belief in the BDI system.
:ivar name: The functor or name of the belief (e.g., 'user_said').
:ivar arguments: A list of string arguments for the belief.
:ivar replace: If True, existing beliefs with this name should be replaced by this one.
"""
name: str
arguments: list[str]
replace: bool = False
class BeliefMessage(BaseModel):
"""
A container for transporting a list of beliefs between agents.
"""
beliefs: list[Belief]
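A brief usage sketch of these (removed) belief schemas; the import path and the belief contents are illustrative assumptions.

```python
# Usage sketch; import path and belief contents are assumptions.
from control_backend.schemas.belief import Belief, BeliefMessage

message = BeliefMessage(
    beliefs=[
        Belief(name="user_said", arguments=["hello there"]),                  # replace defaults to False
        Belief(name="conversation_active", arguments=["true"], replace=True),
    ]
)
print(message.model_dump_json())  # serialized form, e.g. carried as an InternalMessage body
```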

View File

@@ -1,17 +0,0 @@
from pydantic import BaseModel
class InternalMessage(BaseModel):
"""
Standard message envelope for communication between agents within the Control Backend.
:ivar to: The name of the destination agent.
:ivar sender: The name of the sending agent.
:ivar body: The string payload (often a JSON-serialized model).
:ivar thread: An optional thread identifier/topic to categorize the message (e.g., 'beliefs').
"""
to: str
sender: str
body: str
thread: str | None = None

View File

@@ -1,18 +0,0 @@
from pydantic import BaseModel
class LLMPromptMessage(BaseModel):
"""
Payload sent from the BDI agent to the LLM agent.
Contains the user's text input along with the dynamic context (norms and goals)
that the LLM should use to generate a response.
:ivar text: The user's input text.
:ivar norms: A list of active behavioral norms.
:ivar goals: A list of active goals to pursue.
"""
text: str
norms: list[str]
goals: list[str]
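A small sketch of the payload this (removed) schema describes; the import path and contents are illustrative assumptions.

```python
# Usage sketch; import path and contents are assumptions.
from control_backend.schemas.llm_prompt_message import LLMPromptMessage

prompt = LLMPromptMessage(
    text="I feel a bit nervous today.",
    norms=["Keep answers short.", "Never give medical advice."],
    goals=["Put the user at ease."],
)
payload = prompt.model_dump_json()  # typically carried as an InternalMessage body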

View File

@@ -2,8 +2,4 @@ from pydantic import BaseModel
class Message(BaseModel):
"""
A generic message wrapper, typically used for simple API responses.
"""
message: str

View File

@@ -2,70 +2,37 @@ from pydantic import BaseModel
class Norm(BaseModel):
"""
Represents a behavioral norm.
:ivar id: Unique identifier.
:ivar label: Human-readable label.
:ivar norm: The actual norm text describing the behavior.
"""
id: str
label: str
norm: str
name: str
value: str
class Goal(BaseModel):
"""
Represents an objective to be achieved.
:ivar id: Unique identifier.
:ivar label: Human-readable label.
:ivar description: Detailed description of the goal.
:ivar achieved: Status flag indicating if the goal has been met.
"""
id: str
label: str
name: str
description: str
achieved: bool
class TriggerKeyword(BaseModel):
id: str
keyword: str
class KeywordTrigger(BaseModel):
class Trigger(BaseModel):
id: str
label: str
type: str
keywords: list[TriggerKeyword]
value: list[str]
class PhaseData(BaseModel):
norms: list[Norm]
goals: list[Goal]
triggers: list[Trigger]
class Phase(BaseModel):
"""
A distinct phase within a program, containing norms, goals, and triggers.
:ivar id: Unique identifier.
:ivar label: Human-readable label.
:ivar norms: List of norms active in this phase.
:ivar goals: List of goals to pursue in this phase.
:ivar triggers: List of triggers that define transitions out of this phase.
"""
id: str
label: str
norms: list[Norm]
goals: list[Goal]
triggers: list[KeywordTrigger]
name: str
nextPhaseId: str
phaseData: PhaseData
class Program(BaseModel):
"""
Represents a complete interaction program, consisting of a sequence or set of phases.
:ivar phases: The list of phases that make up the program.
"""
phases: list[Phase]
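Since old and new lines are interleaved in this hunk, the exact post-change shape is partly inferred. Under that caveat, a hedged sketch of a Program as the UI now appears to send it (nesting norms, goals, and triggers under `phaseData`; all values invented):

```python
# Hedged sketch of the UI-shaped Program; field layout inferred from the added lines above.
program = Program(
    phases=[
        Phase(
            id="phase-1",
            name="Introduction",
            nextPhaseId="phase-2",
            phaseData=PhaseData(
                norms=[Norm(id="n1", name="politeness", value="Always stay polite.")],
                goals=[Goal(id="g1", name="greet", description="Greet the user.", achieved=False)],
                triggers=[Trigger(id="t1", type="keyword", value=["next", "continue"])],
            ),
        )
    ]
)
```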

View File

@@ -1,64 +1,20 @@
from enum import Enum
from typing import Any, Literal
from typing import Any
from pydantic import BaseModel, model_validator
from pydantic import BaseModel
class RIEndpoint(str, Enum):
"""
Enumeration of valid endpoints for the Robot Interface (RI).
"""
SPEECH = "actuate/speech"
GESTURE_SINGLE = "actuate/gesture/single"
GESTURE_TAG = "actuate/gesture/tag"
PING = "ping"
NEGOTIATE_PORTS = "negotiate/ports"
class RIMessage(BaseModel):
"""
Base schema for messages sent to the Robot Interface.
:ivar endpoint: The target endpoint/action on the RI.
:ivar data: The payload associated with the action.
"""
endpoint: RIEndpoint
data: Any
class SpeechCommand(RIMessage):
"""
A specific command to make the robot speak.
:ivar endpoint: Fixed to ``RIEndpoint.SPEECH``.
:ivar data: The text string to be spoken.
"""
endpoint: RIEndpoint = RIEndpoint.SPEECH
data: str
class GestureCommand(RIMessage):
"""
A specific command to make the robot do a gesture.
:ivar endpoint: Should be ``RIEndpoint.GESTURE_SINGLE`` or ``RIEndpoint.GESTURE_TAG``.
:ivar data: The id of the gesture to be executed.
"""
endpoint: Literal[ # pyright: ignore[reportIncompatibleVariableOverride] - We validate this stricter rule ourselves
RIEndpoint.GESTURE_SINGLE, RIEndpoint.GESTURE_TAG
]
data: str
@model_validator(mode="after")
def check_endpoint(self):
allowed = {
RIEndpoint.GESTURE_SINGLE,
RIEndpoint.GESTURE_TAG,
}
if self.endpoint not in allowed:
raise ValueError("endpoint must be GESTURE_SINGLE or GESTURE_TAG")
return self
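A short usage sketch of the message schemas above; the import path matches the schema tests later in this diff, and the spoken text is invented.

```python
# Usage sketch; mirrors how the schema tests in this diff construct commands.
from control_backend.schemas.ri_message import RIEndpoint, RIMessage, SpeechCommand

say_hi = SpeechCommand(data="Hello, I am Pepper.")   # endpoint defaults to RIEndpoint.SPEECH
ping = RIMessage(endpoint=RIEndpoint.PING, data="")  # generic envelope for any endpoint

print(say_hi.model_dump_json())  # {"endpoint": "actuate/speech", "data": "Hello, I am Pepper."}
```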

View File

@@ -1,144 +0,0 @@
import random
from unittest.mock import AsyncMock, MagicMock
import pytest
import zmq
from control_backend.agents.perception.vad_agent import VADAgent
@pytest.fixture
def zmq_context(mocker):
mock_context = mocker.patch("control_backend.agents.perception.vad_agent.azmq.Context.instance")
mock_context.return_value = MagicMock()
return mock_context
@pytest.fixture
def per_transcription_agent(mocker):
return mocker.patch(
"control_backend.agents.perception.vad_agent.TranscriptionAgent", autospec=True
)
@pytest.fixture(autouse=True)
def torch_load(mocker):
mock_torch = mocker.patch("control_backend.agents.perception.vad_agent.torch")
model = MagicMock()
mock_torch.hub.load.return_value = (model, None)
mock_torch.from_numpy.side_effect = lambda arr: arr
return mock_torch
@pytest.mark.asyncio
async def test_normal_setup(per_transcription_agent):
"""
Test that during normal setup, the VAD agent creates a Streaming behavior and creates audio
sockets, and starts the TranscriptionAgent without loading real models.
"""
per_vad_agent = VADAgent("tcp://localhost:12345", False)
per_vad_agent._streaming_loop = AsyncMock()
async def swallow_background_task(coro):
coro.close()
per_vad_agent.add_behavior = swallow_background_task
per_vad_agent.reset_stream = AsyncMock()
await per_vad_agent.setup()
per_transcription_agent.assert_called_once()
per_transcription_agent.return_value.start.assert_called_once()
per_vad_agent._streaming_loop.assert_called_once()
per_vad_agent.reset_stream.assert_called_once()
assert per_vad_agent.audio_in_socket is not None
assert per_vad_agent.audio_out_socket is not None
@pytest.mark.parametrize("do_bind", [True, False])
def test_in_socket_creation(zmq_context, do_bind: bool):
"""
Test that the VAD agent creates an audio input socket, differentiating between binding and
connecting.
"""
per_vad_agent = VADAgent(f"tcp://{'*' if do_bind else 'localhost'}:12345", do_bind)
per_vad_agent._connect_audio_in_socket()
assert per_vad_agent.audio_in_socket is not None
zmq_context.return_value.socket.assert_called_once_with(zmq.SUB)
zmq_context.return_value.socket.return_value.setsockopt_string.assert_called_once_with(
zmq.SUBSCRIBE,
"",
)
if do_bind:
zmq_context.return_value.socket.return_value.bind.assert_called_once_with("tcp://*:12345")
else:
zmq_context.return_value.socket.return_value.connect.assert_called_once_with(
"tcp://localhost:12345"
)
def test_out_socket_creation(zmq_context):
"""
Test that the VAD agent creates an audio output socket correctly.
"""
per_vad_agent = VADAgent("tcp://localhost:12345", False)
per_vad_agent._connect_audio_out_socket()
assert per_vad_agent.audio_out_socket is not None
zmq_context.return_value.socket.assert_called_once_with(zmq.PUB)
zmq_context.return_value.socket.return_value.bind_to_random_port.assert_called_once()
@pytest.mark.asyncio
async def test_out_socket_creation_failure(zmq_context):
"""
Test setup failure when the audio output socket cannot be created.
"""
zmq_context.return_value.socket.return_value.bind_to_random_port.side_effect = zmq.ZMQBindError
per_vad_agent = VADAgent("tcp://localhost:12345", False)
per_vad_agent.stop = AsyncMock()
per_vad_agent.reset_stream = AsyncMock()
per_vad_agent._streaming_loop = AsyncMock()
per_vad_agent._connect_audio_out_socket = MagicMock(return_value=None)
async def swallow_background_task(coro):
coro.close()
per_vad_agent.add_behavior = swallow_background_task
await per_vad_agent.setup()
assert per_vad_agent.audio_out_socket is None
per_vad_agent.stop.assert_called_once()
@pytest.mark.asyncio
async def test_stop(zmq_context, per_transcription_agent):
"""
Test that when the VAD agent is stopped, the sockets are closed correctly.
"""
per_vad_agent = VADAgent("tcp://localhost:12345", False)
per_vad_agent.reset_stream = AsyncMock()
per_vad_agent._streaming_loop = AsyncMock()
async def swallow_background_task(coro):
coro.close()
per_vad_agent.add_behavior = swallow_background_task
zmq_context.return_value.socket.return_value.bind_to_random_port.return_value = random.randint(
1000,
10000,
)
await per_vad_agent.setup()
await per_vad_agent.stop()
assert zmq_context.return_value.socket.return_value.close.call_count == 2
assert per_vad_agent.audio_in_socket is None
assert per_vad_agent.audio_out_socket is None

View File

@@ -1,99 +0,0 @@
import os
from unittest.mock import AsyncMock, MagicMock
import pytest
import soundfile as sf
import zmq
from control_backend.agents.perception.vad_agent import VADAgent
@pytest.fixture(autouse=True)
def patch_settings():
from control_backend.agents.perception import vad_agent
vad_agent.settings.behaviour_settings.vad_prob_threshold = 0.5
vad_agent.settings.behaviour_settings.vad_non_speech_patience_chunks = 3
vad_agent.settings.behaviour_settings.vad_initial_since_speech = 0
vad_agent.settings.vad_settings.sample_rate_hz = 16_000
@pytest.fixture(autouse=True)
def mock_torch(mocker):
mock_torch = mocker.patch("control_backend.agents.perception.vad_agent.torch")
mock_torch.from_numpy.side_effect = lambda arr: arr
return mock_torch
def get_audio_chunks() -> list[bytes]:
curr_file = os.path.realpath(__file__)
curr_dir = os.path.dirname(curr_file)
file = f"{curr_dir}/speech_with_pauses_16k_1c_float32.wav"
chunk_size = 512
chunks = []
with sf.SoundFile(file, "r") as f:
assert f.samplerate == 16000
assert f.channels == 1
assert f.subtype == "FLOAT"
while True:
data = f.read(chunk_size, dtype="float32")
if len(data) != chunk_size:
break
chunks.append(data.tobytes())
return chunks
@pytest.mark.asyncio
async def test_real_audio(mocker):
"""
Test the VAD agent's streaming loop with input, output, and the VAD model mocked, feeding real
audio as input. Ensure that it outputs some fragments with audio.
"""
audio_chunks = get_audio_chunks()
audio_in_socket = AsyncMock()
audio_in_socket.recv.side_effect = audio_chunks
mock_poller: MagicMock = mocker.patch("control_backend.agents.perception.vad_agent.azmq.Poller")
mock_poller.return_value.poll = AsyncMock(return_value=[(audio_in_socket, zmq.POLLIN)])
audio_out_socket = AsyncMock()
vad_agent = VADAgent("tcp://localhost:12345", False)
vad_agent.audio_out_socket = audio_out_socket
# Use a fake model that marks most chunks as speech and ends with a few silences
silence_padding = 5
probabilities = [1.0] * len(audio_chunks) + [0.0] * silence_padding
chunk_bytes = audio_chunks + [b"\x00" * len(audio_chunks[0])] * silence_padding
model_item = MagicMock()
model_item.item.side_effect = probabilities
vad_agent.model = MagicMock(return_value=model_item)
class DummyPoller:
def __init__(self, data, agent):
self.data = data
self.agent = agent
async def poll(self, timeout_ms=None):
if self.data:
return self.data.pop(0)
self.agent._running = False
return None
vad_agent.audio_in_poller = DummyPoller(chunk_bytes, vad_agent)
vad_agent._ready = AsyncMock()
vad_agent._running = True
vad_agent.i_since_speech = 0
await vad_agent._streaming_loop()
audio_out_socket.send.assert_called()
for args in audio_out_socket.send.call_args_list:
assert isinstance(args[0][0], bytes)
assert len(args[0][0]) >= 512 * 4 * 3 # Should be at least 3 chunks of audio

View File

@@ -0,0 +1,99 @@
import json
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
import zmq
from control_backend.agents.ri_command_agent import RICommandAgent
@pytest.fixture
def zmq_context(mocker):
mock_context = mocker.patch("control_backend.agents.vad_agent.azmq.Context.instance")
mock_context.return_value = MagicMock()
return mock_context
@pytest.mark.asyncio
async def test_setup_bind(zmq_context, mocker):
"""Test setup with bind=True"""
fake_socket = zmq_context.return_value.socket.return_value
agent = RICommandAgent("test@server", "password", address="tcp://localhost:5555", bind=True)
settings = mocker.patch("control_backend.agents.ri_command_agent.settings")
settings.zmq_settings.internal_sub_address = "tcp://internal:1234"
await agent.setup()
# Ensure PUB socket bound
fake_socket.bind.assert_any_call("tcp://localhost:5555")
# Ensure SUB socket connected to internal address and subscribed
fake_socket.connect.assert_any_call("tcp://internal:1234")
fake_socket.setsockopt.assert_any_call(zmq.SUBSCRIBE, b"command")
# Ensure behaviour attached
assert any(isinstance(b, agent.SendCommandsBehaviour) for b in agent.behaviours)
@pytest.mark.asyncio
async def test_setup_connect(zmq_context, mocker):
"""Test setup with bind=False"""
fake_socket = zmq_context.return_value.socket.return_value
agent = RICommandAgent("test@server", "password", address="tcp://localhost:5555", bind=False)
settings = mocker.patch("control_backend.agents.ri_command_agent.settings")
settings.zmq_settings.internal_sub_address = "tcp://internal:1234"
await agent.setup()
# Ensure PUB socket connected
fake_socket.connect.assert_any_call("tcp://localhost:5555")
@pytest.mark.asyncio
async def test_send_commands_behaviour_valid_message():
"""Test behaviour with valid JSON message"""
fake_socket = AsyncMock()
message_dict = {"message": "hello"}
fake_socket.recv_multipart = AsyncMock(
return_value=(b"command", json.dumps(message_dict).encode("utf-8"))
)
fake_socket.send_json = AsyncMock()
agent = RICommandAgent("test@server", "password")
agent.subsocket = fake_socket
agent.pubsocket = fake_socket
behaviour = agent.SendCommandsBehaviour()
behaviour.agent = agent
with patch("control_backend.agents.ri_command_agent.SpeechCommand") as MockSpeechCommand:
mock_message = MagicMock()
MockSpeechCommand.model_validate.return_value = mock_message
await behaviour.run()
fake_socket.recv_multipart.assert_awaited()
fake_socket.send_json.assert_awaited_with(mock_message.model_dump())
@pytest.mark.asyncio
async def test_send_commands_behaviour_invalid_message(caplog):
"""Test behaviour with invalid JSON message triggers error logging"""
fake_socket = AsyncMock()
fake_socket.recv_multipart = AsyncMock(return_value=(b"command", b"{invalid_json}"))
fake_socket.send_json = AsyncMock()
agent = RICommandAgent("test@server", "password")
agent.subsocket = fake_socket
agent.pubsocket = fake_socket
behaviour = agent.SendCommandsBehaviour()
behaviour.agent = agent
with caplog.at_level("ERROR"):
await behaviour.run()
fake_socket.recv_multipart.assert_awaited()
fake_socket.send_json.assert_not_awaited()
assert "Error processing message" in caplog.text

View File

@@ -0,0 +1,551 @@
import asyncio
from unittest.mock import ANY, AsyncMock, MagicMock, patch
import pytest
from control_backend.agents.ri_communication_agent import RICommunicationAgent
def fake_json_correct_negototiate_1():
return AsyncMock(
return_value={
"endpoint": "negotiate/ports",
"data": [
{"id": "main", "port": 5555, "bind": False},
{"id": "actuation", "port": 5556, "bind": True},
],
}
)
def fake_json_correct_negototiate_2():
return AsyncMock(
return_value={
"endpoint": "negotiate/ports",
"data": [
{"id": "main", "port": 5555, "bind": False},
{"id": "actuation", "port": 5557, "bind": True},
],
}
)
def fake_json_correct_negototiate_3():
return AsyncMock(
return_value={
"endpoint": "negotiate/ports",
"data": [
{"id": "main", "port": 5555, "bind": True},
{"id": "actuation", "port": 5557, "bind": True},
],
}
)
def fake_json_correct_negototiate_4():
# Different port, do bind
return AsyncMock(
return_value={
"endpoint": "negotiate/ports",
"data": [
{"id": "main", "port": 4555, "bind": True},
{"id": "actuation", "port": 5557, "bind": True},
],
}
)
def fake_json_correct_negototiate_5():
# Different port, don't bind.
return AsyncMock(
return_value={
"endpoint": "negotiate/ports",
"data": [
{"id": "main", "port": 4555, "bind": False},
{"id": "actuation", "port": 5557, "bind": True},
],
}
)
def fake_json_wrong_negototiate_1():
return AsyncMock(return_value={"endpoint": "ping", "data": ""})
def fake_json_invalid_id_negototiate():
return AsyncMock(
return_value={
"endpoint": "negotiate/ports",
"data": [
{"id": "banana", "port": 4555, "bind": False},
{"id": "tomato", "port": 5557, "bind": True},
],
}
)
@pytest.fixture
def zmq_context(mocker):
mock_context = mocker.patch("control_backend.agents.vad_agent.azmq.Context.instance")
mock_context.return_value = MagicMock()
return mock_context
@pytest.mark.asyncio
async def test_setup_creates_socket_and_negotiate_1(zmq_context):
"""
Test the setup of the communication agent
"""
# --- Arrange ---
fake_socket = zmq_context.return_value.socket.return_value
fake_socket.send_json = AsyncMock()
fake_socket.recv_json = fake_json_correct_negototiate_1()
# Mock RICommandAgent agent startup
with patch(
"control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True
) as MockCommandAgent:
fake_agent_instance = MockCommandAgent.return_value
fake_agent_instance.start = AsyncMock()
# --- Act ---
agent = RICommunicationAgent(
"test@server", "password", address="tcp://localhost:5555", bind=False
)
await agent.setup()
# --- Assert ---
fake_socket.connect.assert_any_call("tcp://localhost:5555")
fake_socket.send_json.assert_any_call({"endpoint": "negotiate/ports", "data": None})
fake_socket.recv_json.assert_awaited()
fake_agent_instance.start.assert_awaited()
MockCommandAgent.assert_called_once_with(
ANY, # Server Name
ANY, # Server Password
address="tcp://*:5556", # derived from the 'port' value in negotiation
bind=True,
)
# Ensure the agent attached a ListenBehaviour
assert any(isinstance(b, agent.ListenBehaviour) for b in agent.behaviours)
@pytest.mark.asyncio
async def test_setup_creates_socket_and_negotiate_2(zmq_context):
"""
Test the setup of the communication agent
"""
# --- Arrange ---
fake_socket = zmq_context.return_value.socket.return_value
fake_socket.send_json = AsyncMock()
fake_socket.recv_json = fake_json_correct_negototiate_2()
# Mock RICommandAgent agent startup
with patch(
"control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True
) as MockCommandAgent:
fake_agent_instance = MockCommandAgent.return_value
fake_agent_instance.start = AsyncMock()
# --- Act ---
agent = RICommunicationAgent(
"test@server", "password", address="tcp://localhost:5555", bind=False
)
await agent.setup()
# --- Assert ---
fake_socket.connect.assert_any_call("tcp://localhost:5555")
fake_socket.send_json.assert_any_call({"endpoint": "negotiate/ports", "data": None})
fake_socket.recv_json.assert_awaited()
fake_agent_instance.start.assert_awaited()
MockCommandAgent.assert_called_once_with(
ANY, # Server Name
ANY, # Server Password
address="tcp://*:5557", # derived from the 'port' value in negotiation
bind=True,
)
# Ensure the agent attached a ListenBehaviour
assert any(isinstance(b, agent.ListenBehaviour) for b in agent.behaviours)
@pytest.mark.asyncio
async def test_setup_creates_socket_and_negotiate_3(zmq_context, caplog):
"""
Test the functionality of setup with incorrect negotiation message
"""
# --- Arrange ---
fake_socket = zmq_context.return_value.socket.return_value
fake_socket.send_json = AsyncMock()
fake_socket.recv_json = fake_json_wrong_negototiate_1()
# Mock RICommandAgent agent startup
# We are sending wrong negotiation info to the communication agent,
# so we should retry and expect a better response, within a limited time.
with patch(
"control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True
) as MockCommandAgent:
fake_agent_instance = MockCommandAgent.return_value
fake_agent_instance.start = AsyncMock()
# --- Act ---
with caplog.at_level("ERROR"):
agent = RICommunicationAgent(
"test@server", "password", address="tcp://localhost:5555", bind=False
)
await agent.setup(max_retries=1)
# --- Assert ---
fake_socket.connect.assert_any_call("tcp://localhost:5555")
fake_socket.recv_json.assert_awaited()
# Since it failed, there should not be any command agent.
fake_agent_instance.start.assert_not_awaited()
assert "Failed to set up RICommunicationAgent" in caplog.text
# Ensure the agent did not attach a ListenBehaviour
assert not any(isinstance(b, agent.ListenBehaviour) for b in agent.behaviours)
@pytest.mark.asyncio
async def test_setup_creates_socket_and_negotiate_4(zmq_context):
"""
Test the setup of the communication agent with different bind value
"""
# --- Arrange ---
fake_socket = zmq_context.return_value.socket.return_value
fake_socket.send_json = AsyncMock()
fake_socket.recv_json = fake_json_correct_negototiate_3()
# Mock RICommandAgent agent startup
with patch(
"control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True
) as MockCommandAgent:
fake_agent_instance = MockCommandAgent.return_value
fake_agent_instance.start = AsyncMock()
# --- Act ---
agent = RICommunicationAgent(
"test@server", "password", address="tcp://localhost:5555", bind=True
)
await agent.setup()
# --- Assert ---
fake_socket.bind.assert_any_call("tcp://localhost:5555")
fake_socket.send_json.assert_any_call({"endpoint": "negotiate/ports", "data": None})
fake_socket.recv_json.assert_awaited()
fake_agent_instance.start.assert_awaited()
MockCommandAgent.assert_called_once_with(
ANY, # Server Name
ANY, # Server Password
address="tcp://*:5557", # derived from the 'port' value in negotiation
bind=True,
)
# Ensure the agent attached a ListenBehaviour
assert any(isinstance(b, agent.ListenBehaviour) for b in agent.behaviours)
@pytest.mark.asyncio
async def test_setup_creates_socket_and_negotiate_5(zmq_context):
"""
Test the setup of the communication agent
"""
# --- Arrange ---
fake_socket = zmq_context.return_value.socket.return_value
fake_socket.send_json = AsyncMock()
fake_socket.recv_json = fake_json_correct_negototiate_4()
# Mock RICommandAgent agent startup
with patch(
"control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True
) as MockCommandAgent:
fake_agent_instance = MockCommandAgent.return_value
fake_agent_instance.start = AsyncMock()
# --- Act ---
agent = RICommunicationAgent(
"test@server", "password", address="tcp://localhost:5555", bind=False
)
await agent.setup()
# --- Assert ---
fake_socket.connect.assert_any_call("tcp://localhost:5555")
fake_socket.send_json.assert_any_call({"endpoint": "negotiate/ports", "data": None})
fake_socket.recv_json.assert_awaited()
fake_agent_instance.start.assert_awaited()
MockCommandAgent.assert_called_once_with(
ANY, # Server Name
ANY, # Server Password
address="tcp://*:5557", # derived from the 'port' value in negotiation
bind=True,
)
# Ensure the agent attached a ListenBehaviour
assert any(isinstance(b, agent.ListenBehaviour) for b in agent.behaviours)
@pytest.mark.asyncio
async def test_setup_creates_socket_and_negotiate_6(zmq_context):
"""
Test the setup of the communication agent
"""
# --- Arrange ---
fake_socket = zmq_context.return_value.socket.return_value
fake_socket.send_json = AsyncMock()
fake_socket.recv_json = fake_json_correct_negototiate_5()
# Mock RICommandAgent agent startup
with patch(
"control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True
) as MockCommandAgent:
fake_agent_instance = MockCommandAgent.return_value
fake_agent_instance.start = AsyncMock()
# --- Act ---
agent = RICommunicationAgent(
"test@server", "password", address="tcp://localhost:5555", bind=False
)
await agent.setup()
# --- Assert ---
fake_socket.connect.assert_any_call("tcp://localhost:5555")
fake_socket.send_json.assert_any_call({"endpoint": "negotiate/ports", "data": None})
fake_socket.recv_json.assert_awaited()
fake_agent_instance.start.assert_awaited()
MockCommandAgent.assert_called_once_with(
ANY, # Server Name
ANY, # Server Password
address="tcp://*:5557", # derived from the 'port' value in negotiation
bind=True,
)
# Ensure the agent attached a ListenBehaviour
assert any(isinstance(b, agent.ListenBehaviour) for b in agent.behaviours)
@pytest.mark.asyncio
async def test_setup_creates_socket_and_negotiate_7(zmq_context, caplog):
"""
Test the functionality of setup with incorrect id
"""
# --- Arrange ---
fake_socket = zmq_context.return_value.socket.return_value
fake_socket.send_json = AsyncMock()
fake_socket.recv_json = fake_json_invalid_id_negototiate()
# Mock RICommandAgent agent startup
# We are sending wrong negotiation info to the communication agent,
# so we should retry and expect a better response, within a limited time.
with patch(
"control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True
) as MockCommandAgent:
fake_agent_instance = MockCommandAgent.return_value
fake_agent_instance.start = AsyncMock()
# --- Act ---
with caplog.at_level("WARNING"):
agent = RICommunicationAgent(
"test@server", "password", address="tcp://localhost:5555", bind=False
)
await agent.setup(max_retries=1)
# --- Assert ---
fake_socket.connect.assert_any_call("tcp://localhost:5555")
fake_socket.recv_json.assert_awaited()
# Since it failed, there should not be any command agent.
fake_agent_instance.start.assert_not_awaited()
assert "Unhandled negotiation id:" in caplog.text
@pytest.mark.asyncio
async def test_setup_creates_socket_and_negotiate_timeout(zmq_context, caplog):
"""
Test the functionality of setup with incorrect negotiation message
"""
# --- Arrange ---
fake_socket = zmq_context.return_value.socket.return_value
fake_socket.send_json = AsyncMock()
fake_socket.recv_json = AsyncMock(side_effect=asyncio.TimeoutError)
with patch(
"control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True
) as MockCommandAgent:
fake_agent_instance = MockCommandAgent.return_value
fake_agent_instance.start = AsyncMock()
# --- Act ---
with caplog.at_level("WARNING"):
agent = RICommunicationAgent(
"test@server", "password", address="tcp://localhost:5555", bind=False
)
await agent.setup(max_retries=1)
# --- Assert ---
fake_socket.connect.assert_any_call("tcp://localhost:5555")
# Since it failed, there should not be any command agent.
fake_agent_instance.start.assert_not_awaited()
assert "No connection established in 20 seconds" in caplog.text
# Ensure the agent did not attach a ListenBehaviour
assert not any(isinstance(b, agent.ListenBehaviour) for b in agent.behaviours)
@pytest.mark.asyncio
async def test_listen_behaviour_ping_correct(caplog):
fake_socket = AsyncMock()
fake_socket.send_json = AsyncMock()
fake_socket.recv_json = AsyncMock(return_value={"endpoint": "ping", "data": {}})
# TODO: An integration test against a real XMPP server with real credentials is needed for SPADE agents
agent = RICommunicationAgent("test@server", "password")
agent.req_socket = fake_socket
behaviour = agent.ListenBehaviour()
agent.add_behaviour(behaviour)
# Run once (CyclicBehaviour normally loops)
with caplog.at_level("DEBUG"):
await behaviour.run()
fake_socket.send_json.assert_awaited()
fake_socket.recv_json.assert_awaited()
assert "Received message" in caplog.text
@pytest.mark.asyncio
async def test_listen_behaviour_ping_wrong_endpoint(caplog):
"""
Test if our listen behaviour can work with wrong messages (wrong endpoint)
"""
fake_socket = AsyncMock()
fake_socket.send_json = AsyncMock()
# This is a message for the wrong endpoint >:(
fake_socket.recv_json = AsyncMock(
return_value={
"endpoint": "negotiate/ports",
"data": [
{"id": "main", "port": 5555, "bind": False},
{"id": "actuation", "port": 5556, "bind": True},
],
}
)
agent = RICommunicationAgent("test@server", "password")
agent.req_socket = fake_socket
behaviour = agent.ListenBehaviour()
agent.add_behaviour(behaviour)
# Run once (CyclicBehaviour normally loops)
with caplog.at_level("INFO"):
await behaviour.run()
assert "Received message with topic different than ping, while ping expected." in caplog.text
fake_socket.send_json.assert_awaited()
fake_socket.recv_json.assert_awaited()
@pytest.mark.asyncio
async def test_listen_behaviour_timeout(zmq_context, caplog):
fake_socket = zmq_context.return_value.socket.return_value
fake_socket.send_json = AsyncMock()
# recv_json will never resolve, simulate timeout
fake_socket.recv_json = AsyncMock(side_effect=asyncio.TimeoutError)
agent = RICommunicationAgent("test@server", "password")
agent.req_socket = fake_socket
behaviour = agent.ListenBehaviour()
agent.add_behaviour(behaviour)
with caplog.at_level("INFO"):
await behaviour.run()
assert "No ping retrieved in 3 seconds" in caplog.text
@pytest.mark.asyncio
async def test_listen_behaviour_ping_no_endpoint(caplog):
"""
Test if our listen behaviour can work with wrong messages (wrong endpoint)
"""
fake_socket = AsyncMock()
fake_socket.send_json = AsyncMock()
# This is a message without endpoint >:(
fake_socket.recv_json = AsyncMock(
return_value={
"data": "I dont have an endpoint >:)",
}
)
agent = RICommunicationAgent("test@server", "password")
agent.req_socket = fake_socket
behaviour = agent.ListenBehaviour()
agent.add_behaviour(behaviour)
# Run once (CyclicBehaviour normally loops)
with caplog.at_level("ERROR"):
await behaviour.run()
assert "No received endpoint in message, excepted ping endpoint." in caplog.text
fake_socket.send_json.assert_awaited()
fake_socket.recv_json.assert_awaited()
@pytest.mark.asyncio
async def test_setup_unexpected_exception(zmq_context, caplog):
fake_socket = zmq_context.return_value.socket.return_value
fake_socket.send_json = AsyncMock()
# Simulate unexpected exception during recv_json()
fake_socket.recv_json = AsyncMock(side_effect=Exception("boom!"))
agent = RICommunicationAgent(
"test@server", "password", address="tcp://localhost:5555", bind=False
)
with caplog.at_level("ERROR"):
await agent.setup(max_retries=1)
# Ensure that the error was logged
assert "Unexpected error during negotiation: boom!" in caplog.text
@pytest.mark.asyncio
async def test_setup_unpacking_exception(zmq_context, caplog):
# --- Arrange ---
fake_socket = zmq_context.return_value.socket.return_value
fake_socket.send_json = AsyncMock()
# Make recv_json return malformed negotiation data to trigger unpacking exception
malformed_data = {
"endpoint": "negotiate/ports",
"data": [{"id": "main"}],
} # missing 'port' and 'bind'
fake_socket.recv_json = AsyncMock(return_value=malformed_data)
# Patch RICommandAgent so it won't actually start
with patch(
"control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True
) as MockCommandAgent:
fake_agent_instance = MockCommandAgent.return_value
fake_agent_instance.start = AsyncMock()
agent = RICommunicationAgent(
"test@server", "password", address="tcp://localhost:5555", bind=False
)
# --- Act & Assert ---
with caplog.at_level("ERROR"):
await agent.setup(max_retries=1)
# Ensure the unpacking exception was logged
assert "Error unpacking negotiation data" in caplog.text
# Ensure no command agent was started
fake_agent_instance.start.assert_not_awaited()
# Ensure no behaviour was attached
assert not any(isinstance(b, agent.ListenBehaviour) for b in agent.behaviours)

View File

@@ -0,0 +1,120 @@
import random
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
import zmq
from spade.agent import Agent
from control_backend.agents.vad_agent import VADAgent
@pytest.fixture
def zmq_context(mocker):
mock_context = mocker.patch("control_backend.agents.vad_agent.azmq.Context.instance")
mock_context.return_value = MagicMock()
return mock_context
@pytest.fixture
def streaming(mocker):
return mocker.patch("control_backend.agents.vad_agent.Streaming")
@pytest.fixture
def transcription_agent(mocker):
return mocker.patch("control_backend.agents.vad_agent.TranscriptionAgent", autospec=True)
@pytest.mark.asyncio
async def test_normal_setup(streaming, transcription_agent):
"""
Test that during normal setup, the VAD agent creates a Streaming behavior and creates audio
sockets, and starts the TranscriptionAgent without loading real models.
"""
vad_agent = VADAgent("tcp://localhost:12345", False)
vad_agent.add_behaviour = MagicMock()
await vad_agent.setup()
streaming.assert_called_once()
vad_agent.add_behaviour.assert_called_once_with(streaming.return_value)
transcription_agent.assert_called_once()
transcription_agent.return_value.start.assert_called_once()
assert vad_agent.audio_in_socket is not None
assert vad_agent.audio_out_socket is not None
@pytest.mark.parametrize("do_bind", [True, False])
def test_in_socket_creation(zmq_context, do_bind: bool):
"""
Test that the VAD agent creates an audio input socket, differentiating between binding and
connecting.
"""
vad_agent = VADAgent(f"tcp://{'*' if do_bind else 'localhost'}:12345", do_bind)
vad_agent._connect_audio_in_socket()
assert vad_agent.audio_in_socket is not None
zmq_context.return_value.socket.assert_called_once_with(zmq.SUB)
zmq_context.return_value.socket.return_value.setsockopt_string.assert_called_once_with(
zmq.SUBSCRIBE,
"",
)
if do_bind:
zmq_context.return_value.socket.return_value.bind.assert_called_once_with("tcp://*:12345")
else:
zmq_context.return_value.socket.return_value.connect.assert_called_once_with(
"tcp://localhost:12345"
)
def test_out_socket_creation(zmq_context):
"""
Test that the VAD agent creates an audio output socket correctly.
"""
vad_agent = VADAgent("tcp://localhost:12345", False)
vad_agent._connect_audio_out_socket()
assert vad_agent.audio_out_socket is not None
zmq_context.return_value.socket.assert_called_once_with(zmq.PUB)
zmq_context.return_value.socket.return_value.bind_to_random_port.assert_called_once()
@pytest.mark.asyncio
async def test_out_socket_creation_failure(zmq_context):
"""
Test setup failure when the audio output socket cannot be created.
"""
with patch.object(Agent, "stop", new_callable=AsyncMock) as mock_super_stop:
zmq_context.return_value.socket.return_value.bind_to_random_port.side_effect = (
zmq.ZMQBindError
)
vad_agent = VADAgent("tcp://localhost:12345", False)
await vad_agent.setup()
assert vad_agent.audio_out_socket is None
mock_super_stop.assert_called_once()
@pytest.mark.asyncio
async def test_stop(zmq_context, transcription_agent):
"""
Test that when the VAD agent is stopped, the sockets are closed correctly.
"""
vad_agent = VADAgent("tcp://localhost:12345", False)
zmq_context.return_value.socket.return_value.bind_to_random_port.return_value = random.randint(
1000,
10000,
)
await vad_agent.setup()
await vad_agent.stop()
assert zmq_context.return_value.socket.return_value.close.call_count == 2
assert vad_agent.audio_in_socket is None
assert vad_agent.audio_out_socket is None

View File

@@ -0,0 +1,59 @@
import os
from unittest.mock import AsyncMock, MagicMock
import pytest
import soundfile as sf
import zmq
from control_backend.agents.vad_agent import Streaming
def get_audio_chunks() -> list[bytes]:
curr_file = os.path.realpath(__file__)
curr_dir = os.path.dirname(curr_file)
file = f"{curr_dir}/speech_with_pauses_16k_1c_float32.wav"
chunk_size = 512
chunks = []
with sf.SoundFile(file, "r") as f:
assert f.samplerate == 16000
assert f.channels == 1
assert f.subtype == "FLOAT"
while True:
data = f.read(chunk_size, dtype="float32")
if len(data) != chunk_size:
break
chunks.append(data.tobytes())
return chunks
@pytest.mark.asyncio
async def test_real_audio(mocker):
"""
Test the VAD agent with only input and output mocked. Using the real model, using real audio as
input. Ensure that it outputs some fragments with audio.
"""
audio_chunks = get_audio_chunks()
audio_in_socket = AsyncMock()
audio_in_socket.recv.side_effect = audio_chunks
mock_poller: MagicMock = mocker.patch("control_backend.agents.vad_agent.zmq.Poller")
mock_poller.return_value.poll.return_value = [(audio_in_socket, zmq.POLLIN)]
audio_out_socket = AsyncMock()
vad_streamer = Streaming(audio_in_socket, audio_out_socket)
vad_streamer._ready = True
vad_streamer.agent = MagicMock()
for _ in audio_chunks:
await vad_streamer.run()
audio_out_socket.send.assert_called()
for args in audio_out_socket.send.call_args_list:
assert isinstance(args[0][0], bytes)
assert len(args[0][0]) >= 512 * 4 * 3 # Should be at least 3 chunks of audio

View File

@@ -0,0 +1,61 @@
from unittest.mock import AsyncMock
import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
from control_backend.api.v1.endpoints import command
from control_backend.schemas.ri_message import SpeechCommand
@pytest.fixture
def app():
"""
Creates a FastAPI test app and attaches the router under test.
Also sets up a mock internal_comm_socket.
"""
app = FastAPI()
app.include_router(command.router)
return app
@pytest.fixture
def client(app):
"""Create a test client for the app."""
return TestClient(app)
def test_receive_command_success(client):
"""
Test for successful reception of a command. Ensures the status code is 202 and the response body
is correct. It also verifies that the ZeroMQ socket's send_multipart method is called with the
expected data.
"""
# Arrange
mock_pub_socket = AsyncMock()
client.app.state.endpoints_pub_socket = mock_pub_socket
command_data = {"endpoint": "actuate/speech", "data": "This is a test"}
speech_command = SpeechCommand(**command_data)
# Act
response = client.post("/command", json=command_data)
# Assert
assert response.status_code == 202
assert response.json() == {"status": "Command received"}
# Verify that the ZMQ socket was used correctly
mock_pub_socket.send_multipart.assert_awaited_once_with(
[b"command", speech_command.model_dump_json().encode()]
)
def test_receive_command_invalid_payload(client):
"""
Test invalid data handling (schema validation).
"""
# Missing required field(s)
bad_payload = {"invalid": "data"}
response = client.post("/command", json=bad_payload)
assert response.status_code == 422 # validation error

View File

@@ -0,0 +1,26 @@
import pytest
from pydantic import ValidationError
from control_backend.schemas.ri_message import RIEndpoint, RIMessage, SpeechCommand
def valid_command_1():
return SpeechCommand(data="Hallo?")
def invalid_command_1():
return RIMessage(endpoint=RIEndpoint.PING, data="Hello again.")
def test_valid_speech_command_1():
command = valid_command_1()
RIMessage.model_validate(command)
SpeechCommand.model_validate(command)
def test_invalid_speech_command_1():
command = invalid_command_1()
RIMessage.model_validate(command)
with pytest.raises(ValidationError):
SpeechCommand.model_validate(command)

View File

@@ -1,392 +0,0 @@
import json
from unittest.mock import AsyncMock, MagicMock
import pytest
import zmq
from control_backend.agents.actuation.robot_gesture_agent import RobotGestureAgent
from control_backend.core.agent_system import InternalMessage
from control_backend.schemas.ri_message import RIEndpoint
@pytest.fixture
def zmq_context(mocker):
"""Mock the ZMQ context."""
mock_context = mocker.patch(
"control_backend.agents.actuation.robot_gesture_agent.azmq.Context.instance"
)
mock_context.return_value = MagicMock()
return mock_context
@pytest.mark.asyncio
async def test_setup_bind(zmq_context, mocker):
"""Setup binds and subscribes to internal commands."""
fake_socket = zmq_context.return_value.socket.return_value
agent = RobotGestureAgent("robot_gesture", address="tcp://localhost:5556", bind=True)
settings = mocker.patch("control_backend.agents.actuation.robot_gesture_agent.settings")
settings.zmq_settings.internal_sub_address = "tcp://internal:1234"
agent.add_behavior = MagicMock()
await agent.setup()
# Check PUB socket binding
fake_socket.bind.assert_any_call("tcp://localhost:5556")
# Check SUB socket connection and subscriptions
fake_socket.connect.assert_any_call("tcp://internal:1234")
fake_socket.setsockopt.assert_any_call(zmq.SUBSCRIBE, b"command")
fake_socket.setsockopt.assert_any_call(zmq.SUBSCRIBE, b"send_gestures")
# Check behavior was added
agent.add_behavior.assert_called() # Twice, even.
@pytest.mark.asyncio
async def test_setup_connect(zmq_context, mocker):
"""Setup connects when bind=False."""
fake_socket = zmq_context.return_value.socket.return_value
agent = RobotGestureAgent("robot_gesture", address="tcp://localhost:5556", bind=False)
settings = mocker.patch("control_backend.agents.actuation.robot_gesture_agent.settings")
settings.zmq_settings.internal_sub_address = "tcp://internal:1234"
agent.add_behavior = MagicMock()
await agent.setup()
# Check PUB socket connection (not binding)
fake_socket.connect.assert_any_call("tcp://localhost:5556")
fake_socket.connect.assert_any_call("tcp://internal:1234")
# Check behavior was added
agent.add_behavior.assert_called() # Twice, actually.
@pytest.mark.asyncio
async def test_handle_message_sends_valid_gesture_command():
"""Internal message with valid gesture tag is forwarded to robot pub socket."""
pubsocket = AsyncMock()
agent = RobotGestureAgent("robot_gesture")
agent.pubsocket = pubsocket
payload = {
"endpoint": RIEndpoint.GESTURE_TAG,
"data": "hello", # "hello" is in availableTags
}
msg = InternalMessage(to="robot", sender="tester", body=json.dumps(payload))
await agent.handle_message(msg)
pubsocket.send_json.assert_awaited_once()
@pytest.mark.asyncio
async def test_handle_message_sends_non_gesture_command():
"""Internal message with non-gesture endpoint is not handled by this agent."""
pubsocket = AsyncMock()
agent = RobotGestureAgent("robot_gesture")
agent.pubsocket = pubsocket
payload = {"endpoint": "some_other_endpoint", "data": "invalid_tag_not_in_list"}
msg = InternalMessage(to="robot", sender="tester", body=json.dumps(payload))
await agent.handle_message(msg)
pubsocket.send_json.assert_not_awaited()
@pytest.mark.asyncio
async def test_handle_message_rejects_invalid_gesture_tag():
"""Internal message with invalid gesture tag is not forwarded."""
pubsocket = AsyncMock()
agent = RobotGestureAgent("robot_gesture")
agent.pubsocket = pubsocket
# Use a tag that's not in availableTags
payload = {"endpoint": RIEndpoint.GESTURE_TAG, "data": "invalid_tag_not_in_list"}
msg = InternalMessage(to="robot", sender="tester", body=json.dumps(payload))
await agent.handle_message(msg)
pubsocket.send_json.assert_not_awaited()
@pytest.mark.asyncio
async def test_handle_message_invalid_payload():
"""Invalid payload is caught and does not send."""
pubsocket = AsyncMock()
agent = RobotGestureAgent("robot_gesture")
agent.pubsocket = pubsocket
msg = InternalMessage(to="robot", sender="tester", body=json.dumps({"bad": "data"}))
await agent.handle_message(msg)
pubsocket.send_json.assert_not_awaited()
@pytest.mark.asyncio
async def test_zmq_command_loop_valid_gesture_payload():
"""UI command with valid gesture tag is read from SUB and published."""
command = {"endpoint": RIEndpoint.GESTURE_TAG, "data": "hello"}
fake_socket = AsyncMock()
async def recv_once():
# stop after first iteration
agent._running = False
return (b"command", json.dumps(command).encode("utf-8"))
fake_socket.recv_multipart = recv_once
fake_socket.send_json = AsyncMock()
agent = RobotGestureAgent("robot_gesture")
agent.subsocket = fake_socket
agent.pubsocket = fake_socket
agent._running = True
await agent._zmq_command_loop()
fake_socket.send_json.assert_awaited_once()
@pytest.mark.asyncio
async def test_zmq_command_loop_valid_non_gesture_payload():
"""UI command with non-gesture endpoint is not handled by this agent."""
command = {"endpoint": "some_other_endpoint", "data": "anything"}
fake_socket = AsyncMock()
async def recv_once():
agent._running = False
return (b"command", json.dumps(command).encode("utf-8"))
fake_socket.recv_multipart = recv_once
fake_socket.send_json = AsyncMock()
agent = RobotGestureAgent("robot_gesture")
agent.subsocket = fake_socket
agent.pubsocket = fake_socket
agent._running = True
await agent._zmq_command_loop()
fake_socket.send_json.assert_not_awaited()
@pytest.mark.asyncio
async def test_zmq_command_loop_invalid_gesture_tag():
"""UI command with invalid gesture tag is not forwarded."""
command = {"endpoint": RIEndpoint.GESTURE_TAG, "data": "invalid_tag_not_in_list"}
fake_socket = AsyncMock()
async def recv_once():
agent._running = False
return (b"command", json.dumps(command).encode("utf-8"))
fake_socket.recv_multipart = recv_once
fake_socket.send_json = AsyncMock()
agent = RobotGestureAgent("robot_gesture")
agent.subsocket = fake_socket
agent.pubsocket = fake_socket
agent._running = True
await agent._zmq_command_loop()
fake_socket.send_json.assert_not_awaited()
@pytest.mark.asyncio
async def test_zmq_command_loop_invalid_json():
"""Invalid JSON is ignored without sending."""
fake_socket = AsyncMock()
async def recv_once():
agent._running = False
return (b"command", b"{not_json}")
fake_socket.recv_multipart = recv_once
fake_socket.send_json = AsyncMock()
agent = RobotGestureAgent("robot_gesture")
agent.subsocket = fake_socket
agent.pubsocket = fake_socket
agent._running = True
await agent._zmq_command_loop()
fake_socket.send_json.assert_not_awaited()
@pytest.mark.asyncio
async def test_zmq_command_loop_ignores_send_gestures_topic():
"""send_gestures topic is ignored in command loop."""
fake_socket = AsyncMock()
async def recv_once():
agent._running = False
return (b"send_gestures", b"{}")
fake_socket.recv_multipart = recv_once
fake_socket.send_json = AsyncMock()
agent = RobotGestureAgent("robot_gesture")
agent.subsocket = fake_socket
agent.pubsocket = fake_socket
agent._running = True
await agent._zmq_command_loop()
fake_socket.send_json.assert_not_awaited()
@pytest.mark.asyncio
async def test_fetch_gestures_loop_without_amount():
"""Fetch gestures request without amount returns all tags."""
fake_socket = AsyncMock()
async def recv_once():
agent._running = False
return (b"send_gestures", b"{}")
fake_socket.recv_multipart = recv_once
fake_socket.send_multipart = AsyncMock()
agent = RobotGestureAgent("robot_gesture")
agent.subsocket = fake_socket
agent.pubsocket = fake_socket
agent._running = True
await agent._fetch_gestures_loop()
fake_socket.send_multipart.assert_awaited_once()
# Check the response contains all tags
args, kwargs = fake_socket.send_multipart.call_args
assert args[0][0] == b"get_gestures"
response = json.loads(args[0][1])
assert "tags" in response
assert len(response["tags"]) > 0
# Check it includes some expected tags
assert "hello" in response["tags"]
assert "yes" in response["tags"]
@pytest.mark.asyncio
async def test_fetch_gestures_loop_with_amount():
"""Fetch gestures request with amount returns limited tags."""
fake_socket = AsyncMock()
amount = 5
async def recv_once():
agent._running = False
return (b"send_gestures", json.dumps(amount).encode())
fake_socket.recv_multipart = recv_once
fake_socket.send_multipart = AsyncMock()
agent = RobotGestureAgent("robot_gesture")
agent.subsocket = fake_socket
agent.pubsocket = fake_socket
agent._running = True
await agent._fetch_gestures_loop()
fake_socket.send_multipart.assert_awaited_once()
args, kwargs = fake_socket.send_multipart.call_args
assert args[0][0] == b"get_gestures"
response = json.loads(args[0][1])
assert "tags" in response
assert len(response["tags"]) == amount
@pytest.mark.asyncio
async def test_fetch_gestures_loop_ignores_command_topic():
"""Command topic is ignored in fetch gestures loop."""
fake_socket = AsyncMock()
async def recv_once():
agent._running = False
return (b"command", b"{}")
fake_socket.recv_multipart = recv_once
fake_socket.send_multipart = AsyncMock()
agent = RobotGestureAgent("robot_gesture")
agent.subsocket = fake_socket
agent.pubsocket = fake_socket
agent._running = True
await agent._fetch_gestures_loop()
fake_socket.send_multipart.assert_not_awaited()
@pytest.mark.asyncio
async def test_fetch_gestures_loop_invalid_request():
"""Invalid request body is handled gracefully."""
fake_socket = AsyncMock()
async def recv_once():
agent._running = False
# Send a non-integer, non-JSON body
return (b"send_gestures", b"not_json")
fake_socket.recv_multipart = recv_once
fake_socket.send_multipart = AsyncMock()
agent = RobotGestureAgent("robot_gesture")
agent.subsocket = fake_socket
agent.pubsocket = fake_socket
agent._running = True
await agent._fetch_gestures_loop()
# Should still send a response (all tags)
fake_socket.send_multipart.assert_awaited_once()
def test_available_tags():
"""Test that availableTags returns the expected list."""
agent = RobotGestureAgent("robot_gesture")
tags = agent.availableTags()
assert isinstance(tags, list)
assert len(tags) > 0
# Check some expected tags are present
assert "hello" in tags
assert "yes" in tags
assert "no" in tags
# Check a non-existent tag is not present
assert "invalid_tag_not_in_list" not in tags
@pytest.mark.asyncio
async def test_stop_closes_sockets():
"""Stop method closes both sockets."""
pubsocket = MagicMock()
subsocket = MagicMock()
agent = RobotGestureAgent("robot_gesture")
agent.pubsocket = pubsocket
agent.subsocket = subsocket
await agent.stop()
pubsocket.close.assert_called_once()
subsocket.close.assert_called_once()
@pytest.mark.asyncio
async def test_initialization_with_custom_gesture_data():
"""Agent can be initialized with custom gesture data."""
custom_gestures = ["custom1", "custom2", "custom3"]
agent = RobotGestureAgent("robot_gesture", gesture_data=custom_gestures)
# Note: The current implementation doesn't use the gesture_data parameter
# in availableTags(). This test documents that behavior.
# If you update the agent to use gesture_data, update this test accordingly.
assert agent.gesture_data == custom_gestures

View File

@@ -1,139 +0,0 @@
import json
from unittest.mock import AsyncMock, MagicMock
import pytest
import zmq
from control_backend.agents.actuation.robot_speech_agent import RobotSpeechAgent
from control_backend.core.agent_system import InternalMessage
@pytest.fixture
def zmq_context(mocker):
mock_context = mocker.patch(
"control_backend.agents.actuation.robot_speech_agent.azmq.Context.instance"
)
mock_context.return_value = MagicMock()
return mock_context
@pytest.mark.asyncio
async def test_setup_bind(zmq_context, mocker):
"""Setup binds and subscribes to internal commands."""
fake_socket = zmq_context.return_value.socket.return_value
agent = RobotSpeechAgent("robot_speech", address="tcp://localhost:5555", bind=True)
settings = mocker.patch("control_backend.agents.actuation.robot_speech_agent.settings")
settings.zmq_settings.internal_sub_address = "tcp://internal:1234"
agent.add_behavior = MagicMock()
await agent.setup()
fake_socket.bind.assert_any_call("tcp://localhost:5555")
fake_socket.connect.assert_any_call("tcp://internal:1234")
fake_socket.setsockopt.assert_any_call(zmq.SUBSCRIBE, b"command")
agent.add_behavior.assert_called_once()
@pytest.mark.asyncio
async def test_setup_connect(zmq_context, mocker):
"""Setup connects when bind=False."""
fake_socket = zmq_context.return_value.socket.return_value
agent = RobotSpeechAgent("robot_speech", address="tcp://localhost:5555", bind=False)
settings = mocker.patch("control_backend.agents.actuation.robot_speech_agent.settings")
settings.zmq_settings.internal_sub_address = "tcp://internal:1234"
agent.add_behavior = MagicMock()
await agent.setup()
fake_socket.connect.assert_any_call("tcp://localhost:5555")
fake_socket.connect.assert_any_call("tcp://internal:1234")
agent.add_behavior.assert_called_once()
@pytest.mark.asyncio
async def test_handle_message_sends_command():
"""Internal message is forwarded to robot pub socket as JSON."""
pubsocket = AsyncMock()
agent = RobotSpeechAgent("robot_speech")
agent.pubsocket = pubsocket
payload = {"endpoint": "actuate/speech", "data": "hello"}
msg = InternalMessage(to="robot", sender="tester", body=json.dumps(payload))
await agent.handle_message(msg)
pubsocket.send_json.assert_awaited_once_with(payload)
@pytest.mark.asyncio
async def test_zmq_command_loop_valid_payload(zmq_context):
"""UI command is read from SUB and published."""
command = {"endpoint": "actuate/speech", "data": "hello"}
fake_socket = AsyncMock()
async def recv_once():
# stop after first iteration
agent._running = False
return (b"command", json.dumps(command).encode("utf-8"))
fake_socket.recv_multipart = recv_once
fake_socket.send_json = AsyncMock()
agent = RobotSpeechAgent("robot_speech")
agent.subsocket = fake_socket
agent.pubsocket = fake_socket
agent._running = True
await agent._zmq_command_loop()
fake_socket.send_json.assert_awaited_once_with(command)
@pytest.mark.asyncio
async def test_zmq_command_loop_invalid_json():
"""Invalid JSON is ignored without sending."""
fake_socket = AsyncMock()
async def recv_once():
agent._running = False
return (b"command", b"{not_json}")
fake_socket.recv_multipart = recv_once
fake_socket.send_json = AsyncMock()
agent = RobotSpeechAgent("robot_speech")
agent.subsocket = fake_socket
agent.pubsocket = fake_socket
agent._running = True
await agent._zmq_command_loop()
fake_socket.send_json.assert_not_awaited()
@pytest.mark.asyncio
async def test_handle_message_invalid_payload():
"""Invalid payload is caught and does not send."""
pubsocket = AsyncMock()
agent = RobotSpeechAgent("robot_speech")
agent.pubsocket = pubsocket
msg = InternalMessage(to="robot", sender="tester", body=json.dumps({"bad": "data"}))
await agent.handle_message(msg)
pubsocket.send_json.assert_not_awaited()
@pytest.mark.asyncio
async def test_stop_closes_sockets():
pubsocket = MagicMock()
subsocket = MagicMock()
agent = RobotSpeechAgent("robot_speech")
agent.pubsocket = pubsocket
agent.subsocket = subsocket
await agent.stop()
pubsocket.close.assert_called_once()
subsocket.close.assert_called_once()

View File

@@ -0,0 +1,209 @@
import json
import logging
from unittest.mock import AsyncMock, MagicMock, call
import pytest
from control_backend.agents.bdi.behaviours.belief_setter import BeliefSetterBehaviour
# Define a constant for the collector agent name to use in tests
COLLECTOR_AGENT_NAME = "belief_collector"
COLLECTOR_AGENT_JID = f"{COLLECTOR_AGENT_NAME}@test"
@pytest.fixture
def mock_agent(mocker):
"""Fixture to create a mock BDIAgent."""
agent = MagicMock()
agent.bdi = MagicMock()
agent.jid = "bdi_agent@test"
return agent
@pytest.fixture
def belief_setter(mock_agent, mocker):
"""Fixture to create an instance of BeliefSetterBehaviour with a mocked agent."""
# Patch the settings to use a predictable agent name
mocker.patch(
"control_backend.agents.bdi.behaviours.belief_setter.settings.agent_settings.belief_collector_agent_name",
COLLECTOR_AGENT_NAME,
)
setter = BeliefSetterBehaviour()
setter.agent = mock_agent
# Mock the receive method; its return value is controlled in each test
setter.receive = AsyncMock()
return setter
def create_mock_message(sender_node: str, body: str, thread: str) -> MagicMock:
"""Helper function to create a configured mock message."""
msg = MagicMock()
msg.sender.node = sender_node # MagicMock automatically creates nested mocks
msg.body = body
msg.thread = thread
return msg
@pytest.mark.asyncio
async def test_run_message_received(belief_setter, mocker):
"""
Test that when a message is received, _process_message is called.
"""
# Arrange
msg = MagicMock()
belief_setter.receive.return_value = msg
mocker.patch.object(belief_setter, "_process_message")
# Act
await belief_setter.run()
# Assert
belief_setter._process_message.assert_called_once_with(msg)
def test_process_message_from_belief_collector(belief_setter, mocker):
"""
Test processing a message from the correct belief collector agent.
"""
# Arrange
msg = create_mock_message(sender_node=COLLECTOR_AGENT_NAME, body="", thread="")
mock_process_belief = mocker.patch.object(belief_setter, "_process_belief_message")
# Act
belief_setter._process_message(msg)
# Assert
mock_process_belief.assert_called_once_with(msg)
def test_process_message_from_other_agent(belief_setter, mocker):
"""
Test that messages from other agents are ignored.
"""
# Arrange
msg = create_mock_message(sender_node="other_agent", body="", thread="")
mock_process_belief = mocker.patch.object(belief_setter, "_process_belief_message")
# Act
belief_setter._process_message(msg)
# Assert
mock_process_belief.assert_not_called()
def test_process_belief_message_valid_json(belief_setter, mocker):
"""
Test processing a valid belief message with correct thread and JSON body.
"""
# Arrange
beliefs_payload = {"is_hot": ["kitchen"], "is_clean": ["kitchen", "bathroom"]}
msg = create_mock_message(
sender_node=COLLECTOR_AGENT_JID, body=json.dumps(beliefs_payload), thread="beliefs"
)
mock_set_beliefs = mocker.patch.object(belief_setter, "_set_beliefs")
# Act
belief_setter._process_belief_message(msg)
# Assert
mock_set_beliefs.assert_called_once_with(beliefs_payload)
def test_process_belief_message_invalid_json(belief_setter, mocker, caplog):
"""
Test that a message with invalid JSON is handled gracefully and no beliefs are set.
"""
# Arrange
msg = create_mock_message(
sender_node=COLLECTOR_AGENT_JID, body="this is not a json string", thread="beliefs"
)
mock_set_beliefs = mocker.patch.object(belief_setter, "_set_beliefs")
# Act
belief_setter._process_belief_message(msg)
# Assert
mock_set_beliefs.assert_not_called()
def test_process_belief_message_wrong_thread(belief_setter, mocker):
"""
Test that a message with an incorrect thread is ignored.
"""
# Arrange
msg = create_mock_message(
sender_node=COLLECTOR_AGENT_JID, body='{"some": "data"}', thread="not_beliefs"
)
mock_set_beliefs = mocker.patch.object(belief_setter, "_set_beliefs")
# Act
belief_setter._process_belief_message(msg)
# Assert
mock_set_beliefs.assert_not_called()
def test_process_belief_message_empty_body(belief_setter, mocker):
"""
Test that a message with an empty body is ignored.
"""
# Arrange
msg = create_mock_message(sender_node=COLLECTOR_AGENT_JID, body="", thread="beliefs")
mock_set_beliefs = mocker.patch.object(belief_setter, "_set_beliefs")
# Act
belief_setter._process_belief_message(msg)
# Assert
mock_set_beliefs.assert_not_called()
def test_set_beliefs_success(belief_setter, mock_agent, caplog):
"""
Test that beliefs are correctly set on the agent's BDI.
"""
# Arrange
beliefs_to_set = {
"is_hot": ["kitchen"],
"door_opened": ["front_door", "back_door"],
}
# Act
with caplog.at_level(logging.INFO):
belief_setter._set_beliefs(beliefs_to_set)
# Assert
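# Each belief name should be set exactly once, with its stored arguments unpacked as
# positional arguments to set_belief.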
expected_calls = [
call("is_hot", "kitchen"),
call("door_opened", "front_door", "back_door"),
]
mock_agent.bdi.set_belief.assert_has_calls(expected_calls, any_order=True)
assert mock_agent.bdi.set_belief.call_count == 2
# def test_responded_unset(belief_setter, mock_agent):
# # Arrange
# new_beliefs = {"user_said": ["message"]}
#
# # Act
# belief_setter._set_beliefs(new_beliefs)
#
# # Assert
# mock_agent.bdi.set_belief.assert_has_calls([call("user_said", "message")])
# mock_agent.bdi.remove_belief.assert_has_calls([call("responded")])
# def test_set_beliefs_bdi_not_initialized(belief_setter, mock_agent, caplog):
# """
# Test that a warning is logged if the agent's BDI is not initialized.
# """
# # Arrange
# mock_agent.bdi = None # Simulate BDI not being ready
# beliefs_to_set = {"is_hot": ["kitchen"]}
#
# # Act
# with caplog.at_level(logging.WARNING):
# belief_setter._set_beliefs(beliefs_to_set)
#
# # Assert
# assert "Cannot set beliefs, since agent's BDI is not yet initialized." in caplog.text

View File

@@ -1,126 +0,0 @@
import json
from unittest.mock import AsyncMock, MagicMock, mock_open, patch
import agentspeak
import pytest
from control_backend.agents.bdi.bdi_core_agent import BDICoreAgent
from control_backend.core.agent_system import InternalMessage
from control_backend.core.config import settings
from control_backend.schemas.belief_message import Belief, BeliefMessage
@pytest.fixture
def mock_agentspeak_env():
with patch("agentspeak.runtime.Environment") as mock_env:
yield mock_env
@pytest.fixture
def agent():
agent = BDICoreAgent("bdi_agent", "dummy.asl")
agent.send = AsyncMock()
agent.bdi_agent = MagicMock()
return agent
@pytest.mark.asyncio
async def test_setup_loads_asl(mock_agentspeak_env, agent):
# Mock file opening
with patch("builtins.open", mock_open(read_data="+initial_goal.")):
await agent.setup()
# Check if environment tried to build agent
mock_agentspeak_env.return_value.build_agent.assert_called()
@pytest.mark.asyncio
async def test_setup_no_asl(mock_agentspeak_env, agent):
with patch("builtins.open", side_effect=FileNotFoundError):
await agent.setup()
mock_agentspeak_env.return_value.build_agent.assert_not_called()
@pytest.mark.asyncio
async def test_handle_belief_collector_message(agent, mock_settings):
"""Test that incoming beliefs are added to the BDI agent"""
beliefs = [Belief(name="user_said", arguments=["Hello"])]
msg = InternalMessage(
to="bdi_agent",
sender=mock_settings.agent_settings.bdi_belief_collector_name,
body=BeliefMessage(beliefs=beliefs).model_dump_json(),
thread="beliefs",
)
await agent.handle_message(msg)
# Expect bdi_agent.call to be triggered to add belief
args = agent.bdi_agent.call.call_args.args
assert args[0] == agentspeak.Trigger.addition
assert args[1] == agentspeak.GoalType.belief
assert args[2] == agentspeak.Literal("user_said", (agentspeak.Literal("Hello"),))
@pytest.mark.asyncio
async def test_incorrect_belief_collector_message(agent, mock_settings):
"""Test that incorrect message format triggers an exception."""
msg = InternalMessage(
to="bdi_agent",
sender=mock_settings.agent_settings.bdi_belief_collector_name,
body=json.dumps({"bad_format": "bad_format"}),
thread="beliefs",
)
await agent.handle_message(msg)
agent.bdi_agent.call.assert_not_called() # did not set belief
@pytest.mark.asyncio
async def test():
pass
@pytest.mark.asyncio
async def test_handle_llm_response(agent):
"""Test that LLM responses are forwarded to the Robot Speech Agent"""
msg = InternalMessage(
to="bdi_agent", sender=settings.agent_settings.llm_name, body="This is the LLM reply"
)
await agent.handle_message(msg)
# Verify forward
assert agent.send.called
sent_msg = agent.send.call_args[0][0]
assert sent_msg.to == settings.agent_settings.robot_speech_name
assert "This is the LLM reply" in sent_msg.body
@pytest.mark.asyncio
async def test_custom_actions(agent):
agent._send_to_llm = MagicMock(side_effect=agent.send) # Mock specific method
# Initialize actions manually since we didn't call setup with real file
agent._add_custom_actions()
# Find the action
action_fn = None
for (functor, _), fn in agent.actions.actions.items():
if functor == ".reply":
action_fn = fn
break
assert action_fn is not None
# Invoke action
mock_term = MagicMock()
mock_term.args = ["Hello", "Norm", "Goal"]
mock_intention = MagicMock()
# Run generator
gen = action_fn(agent, mock_term, mock_intention)
next(gen) # Execute
agent._send_to_llm.assert_called_with("Hello", "Norm", "Goal")

View File

@@ -1,89 +0,0 @@
import json
from unittest.mock import AsyncMock
import pytest
from control_backend.agents.bdi import (
BDIBeliefCollectorAgent,
)
from control_backend.core.agent_system import InternalMessage
from control_backend.core.config import settings
from control_backend.schemas.belief_message import Belief
@pytest.fixture
def agent():
agent = BDIBeliefCollectorAgent("belief_collector_agent")
return agent
def make_msg(body: dict, sender: str = "sender"):
return InternalMessage(to="collector", sender=sender, body=json.dumps(body))
@pytest.mark.asyncio
async def test_handle_message_routes_belief_text(agent, mocker):
"""
Test that when a message is received, _handle_belief_text is called with that message.
"""
payload = {"type": "belief_extraction_text", "beliefs": {"user_said": [["hi"]]}}
spy = mocker.patch.object(agent, "_handle_belief_text", new_callable=AsyncMock)
await agent.handle_message(make_msg(payload))
spy.assert_awaited_once_with(payload, "sender")
@pytest.mark.asyncio
async def test_handle_message_routes_emotion(agent, mocker):
payload = {"type": "emotion_extraction_text"}
spy = mocker.patch.object(agent, "_handle_emo_text", new_callable=AsyncMock)
await agent.handle_message(make_msg(payload))
spy.assert_awaited_once_with(payload, "sender")
@pytest.mark.asyncio
async def test_handle_message_bad_json(agent, mocker):
agent._handle_belief_text = AsyncMock()
bad_msg = InternalMessage(to="collector", sender="sender", body="not json")
await agent.handle_message(bad_msg)
agent._handle_belief_text.assert_not_awaited()
@pytest.mark.asyncio
async def test_handle_belief_text_sends_when_beliefs_exist(agent, mocker):
payload = {"type": "belief_extraction_text", "beliefs": {"user_said": ["hello"]}}
spy = mocker.patch.object(agent, "_send_beliefs_to_bdi", new_callable=AsyncMock)
expected = [Belief(name="user_said", arguments=["hello"])]
await agent._handle_belief_text(payload, "origin")
spy.assert_awaited_once_with(expected, origin="origin")
@pytest.mark.asyncio
async def test_handle_belief_text_no_send_when_empty(agent, mocker):
payload = {"type": "belief_extraction_text", "beliefs": {}}
spy = mocker.patch.object(agent, "_send_beliefs_to_bdi", new_callable=AsyncMock)
await agent._handle_belief_text(payload, "origin")
spy.assert_not_awaited()
@pytest.mark.asyncio
async def test_send_beliefs_to_bdi(agent):
agent.send = AsyncMock()
beliefs = [Belief(name="user_said", arguments=["hello", "world"])]
await agent._send_beliefs_to_bdi(beliefs, origin="origin")
agent.send.assert_awaited_once()
sent: InternalMessage = agent.send.call_args.args[0]
assert sent.to == settings.agent_settings.bdi_core_name
assert sent.thread == "beliefs"
assert json.loads(sent.body)["beliefs"] == [belief.model_dump() for belief in beliefs]

View File

@@ -1,58 +0,0 @@
import json
from unittest.mock import AsyncMock
import pytest
from control_backend.agents.bdi import (
TextBeliefExtractorAgent,
)
from control_backend.core.agent_system import InternalMessage
@pytest.fixture
def agent():
agent = TextBeliefExtractorAgent("text_belief_agent")
agent.send = AsyncMock()
return agent
def make_msg(sender: str, body: str, thread: str | None = None) -> InternalMessage:
return InternalMessage(to="unused", sender=sender, body=body, thread=thread)
@pytest.mark.asyncio
async def test_handle_message_ignores_other_agents(agent):
msg = make_msg("unknown", "some data", None)
await agent.handle_message(msg)
agent.send.assert_not_called()  # noqa # The real `agent.send` lacks mock assertion attributes, but the fixture replaces it with an AsyncMock.
@pytest.mark.asyncio
async def test_handle_message_from_transcriber(agent, mock_settings):
transcription = "hello world"
msg = make_msg(mock_settings.agent_settings.transcription_name, transcription, None)
await agent.handle_message(msg)
agent.send.assert_awaited_once()  # noqa # The real `agent.send` lacks mock assertion attributes, but the fixture replaces it with an AsyncMock.
sent: InternalMessage = agent.send.call_args.args[0] # noqa
assert sent.to == mock_settings.agent_settings.bdi_belief_collector_name
assert sent.thread == "beliefs"
parsed = json.loads(sent.body)
assert parsed == {"beliefs": {"user_said": [transcription]}, "type": "belief_extraction_text"}
@pytest.mark.asyncio
async def test_process_transcription_demo(agent, mock_settings):
transcription = "this is a test"
await agent._process_transcription_demo(transcription)
agent.send.assert_awaited_once()  # noqa # The real `agent.send` lacks mock assertion attributes, but the fixture replaces it with an AsyncMock.
sent: InternalMessage = agent.send.call_args.args[0] # noqa
assert sent.to == mock_settings.agent_settings.bdi_belief_collector_name
assert sent.thread == "beliefs"
parsed = json.loads(sent.body)
assert parsed["beliefs"]["user_said"] == [transcription]

View File

@@ -0,0 +1,101 @@
import json
from unittest.mock import AsyncMock, MagicMock
import pytest
from control_backend.agents.belief_collector.behaviours.continuous_collect import (
ContinuousBeliefCollector,
)
def create_mock_message(sender_node: str, body: str) -> MagicMock:
"""Helper function to create a configured mock message."""
msg = MagicMock()
msg.sender.node = sender_node # MagicMock automatically creates nested mocks
msg.body = body
return msg
@pytest.fixture
def mock_agent(mocker):
"""Fixture to create a mock Agent."""
agent = MagicMock()
agent.jid = "belief_collector_agent@test"
return agent
@pytest.fixture
def continuous_collector(mock_agent, mocker):
"""Fixture to create an instance of ContinuousBeliefCollector with a mocked agent."""
# Patch asyncio.sleep to prevent tests from actually waiting
mocker.patch("asyncio.sleep", return_value=None)
collector = ContinuousBeliefCollector()
collector.agent = mock_agent
# Mock the receive method; its return value is controlled in each test
collector.receive = AsyncMock()
return collector
@pytest.mark.asyncio
async def test_run_message_received(continuous_collector, mocker):
"""
Test that when a message is received, _process_message is called with that message.
"""
# Arrange
mock_msg = MagicMock()
continuous_collector.receive.return_value = mock_msg
mocker.patch.object(continuous_collector, "_process_message")
# Act
await continuous_collector.run()
# Assert
continuous_collector._process_message.assert_awaited_once_with(mock_msg)
@pytest.mark.asyncio
async def test_routes_to_handle_belief_text_by_type(continuous_collector, mocker):
msg = create_mock_message(
"anyone",
json.dumps({"type": "belief_extraction_text", "beliefs": {"user_said": [["hi"]]}}),
)
spy = mocker.patch.object(continuous_collector, "_handle_belief_text", new=AsyncMock())
await continuous_collector._process_message(msg)
spy.assert_awaited_once()
@pytest.mark.asyncio
async def test_routes_to_handle_belief_text_by_sender(continuous_collector, mocker):
msg = create_mock_message(
"belief_text_agent_mock", json.dumps({"beliefs": {"user_said": [["hi"]]}})
)
spy = mocker.patch.object(continuous_collector, "_handle_belief_text", new=AsyncMock())
await continuous_collector._process_message(msg)
spy.assert_awaited_once()
@pytest.mark.asyncio
async def test_routes_to_handle_emo_text(continuous_collector, mocker):
msg = create_mock_message("anyone", json.dumps({"type": "emotion_extraction_text"}))
spy = mocker.patch.object(continuous_collector, "_handle_emo_text", new=AsyncMock())
await continuous_collector._process_message(msg)
spy.assert_awaited_once()
@pytest.mark.asyncio
async def test_belief_text_happy_path_sends(continuous_collector, mocker):
payload = {"type": "belief_extraction_text", "beliefs": {"user_said": ["hello test", "No"]}}
continuous_collector.send = AsyncMock()
await continuous_collector._handle_belief_text(payload, "belief_text_agent_mock")
# make sure we attempted a send
continuous_collector.send.assert_awaited_once()
@pytest.mark.asyncio
async def test_belief_text_coerces_non_strings(continuous_collector, mocker):
payload = {"type": "belief_extraction_text", "beliefs": {"user_said": [["hi", 123]]}}
continuous_collector.send = AsyncMock()
await continuous_collector._handle_belief_text(payload, "origin")
continuous_collector.send.assert_awaited_once()

View File

@@ -0,0 +1,187 @@
import json
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from spade.message import Message
from control_backend.agents.bdi.behaviours.text_belief_extractor import BeliefFromText
@pytest.fixture
def mock_settings():
"""
Mocks the settings object that the behaviour imports.
We patch it at the source where it's imported by the module under test.
"""
# Create a mock object that mimics the nested structure
settings_mock = MagicMock()
settings_mock.agent_settings.transcription_agent_name = "transcriber"
settings_mock.agent_settings.belief_collector_agent_name = "collector"
settings_mock.agent_settings.host = "fake.host"
# Use patch to replace the settings object during the test
# The patch target must match the module where the behaviour imports settings
# (here: control_backend.agents.bdi.behaviours.text_belief_extractor).
with patch(
"control_backend.agents.bdi.behaviours.text_belief_extractor.settings", settings_mock
):
yield settings_mock
@pytest.fixture
def behavior(mock_settings):
"""
Creates an instance of the BeliefFromText behaviour and mocks its
agent, logger, send, and receive methods.
"""
b = BeliefFromText()
b.agent = MagicMock()
b.send = AsyncMock()
b.receive = AsyncMock()
return b
def create_mock_message(sender_node: str, body: str, thread: str) -> MagicMock:
"""Helper function to create a configured mock message."""
msg = MagicMock()
msg.sender.node = sender_node # MagicMock automatically creates nested mocks
msg.body = body
msg.thread = thread
return msg
@pytest.mark.asyncio
async def test_run_no_message(behavior):
"""
Tests the run() method when no message is received.
"""
# Arrange: Configure receive to return None
behavior.receive.return_value = None
# Act: Run the behavior
await behavior.run()
# Assert
# 1. Check that receive was called
behavior.receive.assert_called_once()
# 2. Check that no message was sent
behavior.send.assert_not_called()
@pytest.mark.asyncio
async def test_run_message_from_other_agent(behavior):
"""
Tests the run() method when a message is received from an
unknown agent (not the transcriber).
"""
# Arrange: Create a mock message from an unknown sender
mock_msg = create_mock_message("unknown", "some data", None)
behavior.receive.return_value = mock_msg
behavior._process_transcription_demo = MagicMock()
# Act
await behavior.run()
# Assert
# 1. Check that receive was called
behavior.receive.assert_called_once()
# 2. Check that _process_transcription_demo was not called
behavior._process_transcription_demo.assert_not_called()
@pytest.mark.asyncio
async def test_run_message_from_transcriber_demo(behavior, mock_settings, monkeypatch):
"""
Tests the main success path: receiving a message from the
transcription agent, which triggers _process_transcription_demo.
"""
# Arrange: Create a mock message from the transcriber
transcription_text = "hello world"
mock_msg = create_mock_message(
mock_settings.agent_settings.transcription_agent_name, transcription_text, None
)
behavior.receive.return_value = mock_msg
# Act
await behavior.run()
# Assert
# 1. Check that receive was called
behavior.receive.assert_called_once()
# 2. Check that send was called *once*
behavior.send.assert_called_once()
# 3. Deeply inspect the message that was sent
sent_msg: Message = behavior.send.call_args[0][0]
assert (
sent_msg.to
== mock_settings.agent_settings.belief_collector_agent_name
+ "@"
+ mock_settings.agent_settings.host
)
# Check thread
assert sent_msg.thread == "beliefs"
# Parse the received JSON string back into a dict
expected_dict = {
"beliefs": {"user_said": [transcription_text]},
"type": "belief_extraction_text",
}
sent_dict = json.loads(sent_msg.body)
# Assert that the dictionaries are equal
assert sent_dict == expected_dict
@pytest.mark.asyncio
async def test_process_transcription_success(behavior, mock_settings):
"""
Tests the (currently unused) _process_transcription method's
success path, using its hardcoded mock response.
"""
# Arrange
test_text = "I am feeling happy"
# This is the hardcoded response inside the method
expected_response_body = '{"mood": [["happy"]]}'
# Act
await behavior._process_transcription(test_text)
# Assert
# 1. Check that a message was sent
behavior.send.assert_called_once()
# 2. Inspect the sent message
sent_msg: Message = behavior.send.call_args[0][0]
expected_to = (
mock_settings.agent_settings.belief_collector_agent_name
+ "@"
+ mock_settings.agent_settings.host
)
assert str(sent_msg.to) == expected_to
assert sent_msg.thread == "beliefs"
assert sent_msg.body == expected_response_body
@pytest.mark.asyncio
async def test_process_transcription_json_decode_error(behavior, mock_settings):
"""
Tests the _process_transcription method's error handling
when the (mocked) response is invalid JSON.
We do this by patching json.loads to raise an error.
"""
# Arrange
test_text = "I am feeling happy"
# Patch json.loads to raise an error when called
with patch("json.loads", side_effect=json.JSONDecodeError("Mock error", "", 0)):
# Act
await behavior._process_transcription(test_text)
# Assert
# 1. Check that NO message was sent
behavior.send.assert_not_called()

View File

@@ -1,356 +0,0 @@
import asyncio
from unittest.mock import ANY, AsyncMock, MagicMock, patch
import pytest
from control_backend.agents.communication.ri_communication_agent import RICommunicationAgent
def speech_agent_path():
return "control_backend.agents.communication.ri_communication_agent.RobotSpeechAgent"
def gesture_agent_path():
return "control_backend.agents.communication.ri_communication_agent.RobotGestureAgent"
@pytest.fixture
def zmq_context(mocker):
mock_context = mocker.patch(
"control_backend.agents.communication.ri_communication_agent.Context.instance"
)
mock_context.return_value = MagicMock()
return mock_context
def negotiation_message(
actuation_port: int = 5556,
bind_main: bool = False,
bind_actuation: bool = False,
main_port: int = 5555,
):
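# Simulated robot reply to the port negotiation request: one port entry per channel
# ("main" and "actuation").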
return {
"endpoint": "negotiate/ports",
"data": [
{"id": "main", "port": main_port, "bind": bind_main},
{"id": "actuation", "port": actuation_port, "bind": bind_actuation},
],
}
@pytest.mark.asyncio
async def test_setup_success_connects_and_starts_robot(zmq_context):
fake_socket = zmq_context.return_value.socket.return_value
fake_socket.send_json = AsyncMock()
fake_socket.recv_json = AsyncMock(return_value=negotiation_message())
fake_socket.send_multipart = AsyncMock()
with (
patch(speech_agent_path(), autospec=True) as MockSpeech,
patch(gesture_agent_path(), autospec=True) as MockGesture,
):
MockSpeech.return_value.start = AsyncMock()
MockGesture.return_value.start = AsyncMock()
agent = RICommunicationAgent("ri_comm", address="tcp://localhost:5555", bind=False)
agent.add_behavior = MagicMock()
await agent.setup()
fake_socket.connect.assert_any_call("tcp://localhost:5555")
fake_socket.send_json.assert_any_call({"endpoint": "negotiate/ports", "data": {}})
MockSpeech.return_value.start.assert_awaited_once()
MockGesture.return_value.start.assert_awaited_once()
MockSpeech.assert_called_once_with(ANY, address="tcp://localhost:5556", bind=False)
MockGesture.assert_called_once_with(
ANY,
address="tcp://localhost:5556",
bind=False,
gesture_data=[],
)
agent.add_behavior.assert_called_once()
assert agent.connected is True
@pytest.mark.asyncio
async def test_setup_binds_when_requested(zmq_context):
fake_socket = zmq_context.return_value.socket.return_value
fake_socket.send_json = AsyncMock()
fake_socket.recv_json = AsyncMock(return_value=negotiation_message(bind_main=True))
fake_socket.send_multipart = AsyncMock()
agent = RICommunicationAgent("ri_comm", address="tcp://localhost:5555", bind=True)
agent.add_behavior = MagicMock()
with (
patch(speech_agent_path(), autospec=True) as MockSpeech,
patch(gesture_agent_path(), autospec=True) as MockGesture,
):
MockSpeech.return_value.start = AsyncMock()
MockGesture.return_value.start = AsyncMock()
await agent.setup()
fake_socket.bind.assert_any_call("tcp://localhost:5555")
agent.add_behavior.assert_called_once()
@pytest.mark.asyncio
async def test_negotiate_invalid_endpoint_retries(zmq_context):
fake_socket = zmq_context.return_value.socket.return_value
fake_socket.send_json = AsyncMock()
fake_socket.recv_json = AsyncMock(return_value={"endpoint": "ping", "data": {}})
fake_socket.send_multipart = AsyncMock()
agent = RICommunicationAgent("ri_comm", address="tcp://localhost:5555", bind=False)
agent._req_socket = fake_socket
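# A reply with an unexpected endpoint should not count as a successful negotiation.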
success = await agent._negotiate_connection(max_retries=1)
assert success is False
@pytest.mark.asyncio
async def test_negotiate_timeout(zmq_context):
fake_socket = zmq_context.return_value.socket.return_value
fake_socket.send_json = AsyncMock()
fake_socket.recv_json = AsyncMock(side_effect=asyncio.TimeoutError)
fake_socket.send_multipart = AsyncMock()
agent = RICommunicationAgent("ri_comm", address="tcp://localhost:5555", bind=False)
agent._req_socket = fake_socket
success = await agent._negotiate_connection(max_retries=1)
assert success is False
@pytest.mark.asyncio
async def test_handle_negotiation_response_updates_req_socket(zmq_context):
fake_socket = zmq_context.return_value.socket.return_value
agent = RICommunicationAgent("ri_comm", address="tcp://localhost:5555", bind=False)
agent._req_socket = fake_socket
with (
patch(speech_agent_path(), autospec=True) as MockSpeech,
patch(gesture_agent_path(), autospec=True) as MockGesture,
):
MockSpeech.return_value.start = AsyncMock()
MockGesture.return_value.start = AsyncMock()
await agent._handle_negotiation_response(
negotiation_message(
main_port=6000,
actuation_port=6001,
bind_main=False,
bind_actuation=False,
)
)
fake_socket.connect.assert_any_call("tcp://localhost:6000")
@pytest.mark.asyncio
async def test_handle_disconnection_publishes_and_reconnects():
pub_socket = AsyncMock()
agent = RICommunicationAgent("ri_comm")
agent.pub_socket = pub_socket
agent.connected = True
agent._negotiate_connection = AsyncMock(return_value=True)
await agent._handle_disconnection()
pub_socket.send_multipart.assert_awaited()
assert agent.connected is True
@pytest.mark.asyncio
async def test_listen_loop_handles_non_ping(zmq_context):
fake_socket = zmq_context.return_value.socket.return_value
fake_socket.send_json = AsyncMock()
async def recv_once():
agent._running = False
return {"endpoint": "negotiate/ports", "data": {}}
fake_socket.recv_json = recv_once
agent = RICommunicationAgent("ri_comm")
agent._req_socket = fake_socket
agent.pub_socket = AsyncMock()
agent.connected = True
agent._running = True
await agent._listen_loop()
fake_socket.send_json.assert_called()
@pytest.mark.asyncio
async def test_negotiate_unexpected_error(zmq_context):
fake_socket = zmq_context.return_value.socket.return_value
fake_socket.send_json = AsyncMock()
fake_socket.recv_json = AsyncMock(side_effect=Exception("boom"))
agent = RICommunicationAgent("ri_comm")
agent._req_socket = fake_socket
assert await agent._negotiate_connection(max_retries=1) is False
@pytest.mark.asyncio
async def test_negotiate_handle_response_error(zmq_context):
fake_socket = zmq_context.return_value.socket.return_value
fake_socket.send_json = AsyncMock()
fake_socket.recv_json = AsyncMock(return_value=negotiation_message())
agent = RICommunicationAgent("ri_comm")
agent._req_socket = fake_socket
agent._handle_negotiation_response = AsyncMock(side_effect=Exception("bad response"))
assert await agent._negotiate_connection(max_retries=1) is False
@pytest.mark.asyncio
async def test_setup_warns_on_failed_negotiate(zmq_context, mocker):
fake_socket = zmq_context.return_value.socket.return_value
fake_socket.send_json = AsyncMock()
fake_socket.recv_json = AsyncMock()
agent = RICommunicationAgent("ri_comm")
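# Replace add_behavior with a stub that closes the coroutine, so setup() leaves no pending task.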
def swallow(coro):
coro.close()
agent.add_behavior = swallow
agent._negotiate_connection = AsyncMock(return_value=False)
await agent.setup()
assert agent.connected is False
@pytest.mark.asyncio
async def test_handle_negotiation_response_unhandled_id():
agent = RICommunicationAgent("ri_comm")
await agent._handle_negotiation_response(
{"data": [{"id": "other", "port": 5000, "bind": False}]}
)
@pytest.mark.asyncio
async def test_stop_closes_sockets():
req = MagicMock()
pub = MagicMock()
agent = RICommunicationAgent("ri_comm")
agent._req_socket = req
agent.pub_socket = pub
await agent.stop()
req.close.assert_called_once()
pub.close.assert_called_once()
@pytest.mark.asyncio
async def test_listen_loop_not_connected(monkeypatch):
agent = RICommunicationAgent("ri_comm")
agent._running = True
agent.connected = False
agent._req_socket = AsyncMock()
async def fake_sleep(duration):
agent._running = False
monkeypatch.setattr("asyncio.sleep", fake_sleep)
await agent._listen_loop()
@pytest.mark.asyncio
async def test_listen_loop_send_and_recv_timeout():
req = AsyncMock()
req.send_json = AsyncMock(side_effect=TimeoutError)
req.recv_json = AsyncMock(side_effect=TimeoutError)
agent = RICommunicationAgent("ri_comm")
agent._req_socket = req
agent.pub_socket = AsyncMock()
agent.connected = True
agent._running = True
async def stop_run():
agent._running = False
agent._handle_disconnection = AsyncMock(side_effect=stop_run)
await agent._listen_loop()
agent._handle_disconnection.assert_awaited()
@pytest.mark.asyncio
async def test_listen_loop_missing_endpoint(monkeypatch):
req = AsyncMock()
req.send_json = AsyncMock()
async def recv_once():
agent._running = False
return {"data": {}}
req.recv_json = recv_once
agent = RICommunicationAgent("ri_comm")
agent._req_socket = req
agent.pub_socket = AsyncMock()
agent.connected = True
agent._running = True
await agent._listen_loop()
@pytest.mark.asyncio
async def test_listen_loop_generic_exception():
req = AsyncMock()
req.send_json = AsyncMock()
req.recv_json = AsyncMock(side_effect=ValueError("boom"))
agent = RICommunicationAgent("ri_comm")
agent._req_socket = req
agent.pub_socket = AsyncMock()
agent.connected = True
agent._running = True
with pytest.raises(ValueError):
await agent._listen_loop()
@pytest.mark.asyncio
async def test_handle_disconnection_timeout(monkeypatch):
pub = AsyncMock()
pub.send_multipart = AsyncMock(side_effect=TimeoutError)
agent = RICommunicationAgent("ri_comm")
agent.pub_socket = pub
agent._negotiate_connection = AsyncMock(return_value=False)
await agent._handle_disconnection()
pub.send_multipart.assert_awaited()
@pytest.mark.asyncio
async def test_listen_loop_ping_sends_internal(zmq_context):
fake_socket = zmq_context.return_value.socket.return_value
fake_socket.send_json = AsyncMock()
pub_socket = AsyncMock()
agent = RICommunicationAgent("ri_comm")
agent._req_socket = fake_socket
agent.pub_socket = pub_socket
agent.connected = True
agent._running = True
async def recv_once():
agent._running = False
return {"endpoint": "ping", "data": {}}
fake_socket.recv_json = recv_once
await agent._listen_loop()
pub_socket.send_multipart.assert_awaited()

View File

@@ -1,136 +0,0 @@
"""Mocks `httpx` and tests chunking logic."""
from unittest.mock import AsyncMock, MagicMock, patch
import httpx
import pytest
from control_backend.agents.llm.llm_agent import LLMAgent, LLMInstructions
from control_backend.core.agent_system import InternalMessage
from control_backend.schemas.llm_prompt_message import LLMPromptMessage
@pytest.fixture
def mock_httpx_client():
with patch("httpx.AsyncClient") as mock_cls:
mock_client = AsyncMock()
mock_cls.return_value.__aenter__.return_value = mock_client
yield mock_client
@pytest.mark.asyncio
async def test_llm_processing_success(mock_httpx_client, mock_settings):
# Setup the mock response for the stream
mock_response = MagicMock()
mock_response.raise_for_status = MagicMock()
# Simulate stream lines
lines = [
b'data: {"choices": [{"delta": {"content": "Hello"}}]}',
b'data: {"choices": [{"delta": {"content": " world"}}]}',
b'data: {"choices": [{"delta": {"content": "."}}]}',
b"data: [DONE]",
]
async def aiter_lines_gen():
for line in lines:
yield line.decode()
mock_response.aiter_lines.side_effect = aiter_lines_gen
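# httpx's AsyncClient.stream() is used as an async context manager, so the mock must
# provide __aenter__/__aexit__.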
mock_stream_context = MagicMock()
mock_stream_context.__aenter__ = AsyncMock(return_value=mock_response)
mock_stream_context.__aexit__ = AsyncMock(return_value=None)
# Configure the client
mock_httpx_client.stream = MagicMock(return_value=mock_stream_context)
# Setup Agent
agent = LLMAgent("llm_agent")
agent.send = AsyncMock() # Mock the send method to verify replies
# Simulate receiving a message from BDI
prompt = LLMPromptMessage(text="Hi", norms=[], goals=[])
msg = InternalMessage(
to="llm_agent",
sender=mock_settings.agent_settings.bdi_core_name,
body=prompt.model_dump_json(),
)
await agent.handle_message(msg)
# Verification
# "Hello world." constitutes one sentence/chunk based on punctuation split
# The agent should call send once with the full sentence
assert agent.send.called
args = agent.send.call_args[0][0]
assert args.to == mock_settings.agent_settings.bdi_core_name
assert "Hello world." in args.body
@pytest.mark.asyncio
async def test_llm_processing_errors(mock_httpx_client, mock_settings):
agent = LLMAgent("llm_agent")
agent.send = AsyncMock()
prompt = LLMPromptMessage(text="Hi", norms=[], goals=[])
msg = InternalMessage(
to="llm",
sender=mock_settings.agent_settings.bdi_core_name,
body=prompt.model_dump_json(),
)
# HTTP Error
mock_httpx_client.stream = MagicMock(side_effect=httpx.HTTPError("Fail"))
await agent.handle_message(msg)
assert "LLM service unavailable." in agent.send.call_args[0][0].body
# General Exception
agent.send.reset_mock()
mock_httpx_client.stream = MagicMock(side_effect=Exception("Boom"))
await agent.handle_message(msg)
assert "Error processing the request." in agent.send.call_args[0][0].body
@pytest.mark.asyncio
async def test_llm_json_error(mock_httpx_client, mock_settings):
# Test malformed JSON in stream
mock_response = MagicMock()
mock_response.raise_for_status = MagicMock()
async def aiter_lines_gen():
yield "data: {bad_json"
yield "data: [DONE]"
mock_response.aiter_lines.side_effect = aiter_lines_gen
mock_stream_context = MagicMock()
mock_stream_context.__aenter__ = AsyncMock(return_value=mock_response)
mock_stream_context.__aexit__ = AsyncMock(return_value=None)
mock_httpx_client.stream = MagicMock(return_value=mock_stream_context)
agent = LLMAgent("llm_agent")
agent.send = AsyncMock()
with patch.object(agent.logger, "error") as log:
prompt = LLMPromptMessage(text="Hi", norms=[], goals=[])
msg = InternalMessage(
to="llm",
sender=mock_settings.agent_settings.bdi_core_name,
body=prompt.model_dump_json(),
)
await agent.handle_message(msg)
log.assert_called() # Should log JSONDecodeError
def test_llm_instructions():
# Full custom
instr = LLMInstructions(norms=["N1", "N2"], goals=["G1", "G2"])
text = instr.build_developer_instruction()
assert "Norms to follow:\n- N1\n- N2" in text
assert "Goals to reach:\n- G1\n- G2" in text
# Defaults
instr_def = LLMInstructions()
text_def = instr_def.build_developer_instruction()
assert "Norms to follow" in text_def
assert "Goals to reach" in text_def

View File

@@ -1,122 +0,0 @@
import asyncio
from unittest.mock import AsyncMock, MagicMock, patch
import numpy as np
import pytest
from control_backend.agents.perception.transcription_agent.speech_recognizer import (
MLXWhisperSpeechRecognizer,
OpenAIWhisperSpeechRecognizer,
SpeechRecognizer,
)
from control_backend.agents.perception.transcription_agent.transcription_agent import (
TranscriptionAgent,
)
@pytest.mark.asyncio
async def test_transcription_agent_flow(mock_zmq_context):
mock_sub = MagicMock()
mock_sub.recv = AsyncMock()
# Setup context to return this specific mock socket
mock_zmq_context.instance.return_value.socket.return_value = mock_sub
# Data: [Audio Bytes, Cancel Loop]
fake_audio = np.zeros(16000, dtype=np.float32).tobytes()
mock_sub.recv.side_effect = [fake_audio, asyncio.CancelledError()]
# Mock Recognizer
with patch.object(SpeechRecognizer, "best_type") as mock_best:
mock_recognizer = MagicMock()
mock_recognizer.recognize_speech.return_value = "Hello"
mock_best.return_value = mock_recognizer
agent = TranscriptionAgent("tcp://in")
agent.send = AsyncMock()
agent._running = True
agent.add_behavior = AsyncMock()
await agent.setup()
try:
await agent._transcribing_loop()
except asyncio.CancelledError:
pass
# Check transcription happened
assert mock_recognizer.recognize_speech.called
# Check sending
assert agent.send.called
assert agent.send.call_args[0][0].body == "Hello"
await agent.stop()
@pytest.mark.asyncio
async def test_transcription_empty(mock_zmq_context):
mock_sub = MagicMock()
mock_sub.recv = AsyncMock()
mock_zmq_context.instance.return_value.socket.return_value = mock_sub
# Return valid audio, but recognizer returns empty string
fake_audio = np.zeros(10, dtype=np.float32).tobytes()
mock_sub.recv.side_effect = [fake_audio, asyncio.CancelledError()]
with patch.object(SpeechRecognizer, "best_type") as mock_best:
mock_recognizer = MagicMock()
mock_recognizer.recognize_speech.return_value = ""
mock_best.return_value = mock_recognizer
agent = TranscriptionAgent("tcp://in")
agent.send = AsyncMock()
await agent.setup()
try:
await agent._transcribing_loop()
except asyncio.CancelledError:
pass
# Should NOT send message
agent.send.assert_not_called()
def test_speech_recognizer_factory():
# Test Factory Logic
with patch("torch.mps.is_available", return_value=True):
assert isinstance(SpeechRecognizer.best_type(), MLXWhisperSpeechRecognizer)
with patch("torch.mps.is_available", return_value=False):
assert isinstance(SpeechRecognizer.best_type(), OpenAIWhisperSpeechRecognizer)
def test_openai_recognizer():
with patch("whisper.load_model") as load_mock:
with patch("whisper.transcribe") as trans_mock:
rec = OpenAIWhisperSpeechRecognizer()
rec.load_model()
load_mock.assert_called()
trans_mock.return_value = {"text": "Hi"}
res = rec.recognize_speech(np.zeros(10))
assert res == "Hi"
def test_mlx_recognizer():
# Fix: On Linux, 'mlx_whisper' isn't imported by the module, so it's missing from dir().
# We must use create=True to inject it into the module namespace during the test.
module_path = "control_backend.agents.perception.transcription_agent.speech_recognizer"
with patch("sys.platform", "darwin"):
with patch(f"{module_path}.mlx_whisper", create=True) as mlx_mock:
with patch(f"{module_path}.ModelHolder", create=True) as holder_mock:
# We also need to mock mlx.core if it's used for types/constants
with patch(f"{module_path}.mx", create=True):
rec = MLXWhisperSpeechRecognizer()
rec.load_model()
holder_mock.get_model.assert_called()
mlx_mock.transcribe.return_value = {"text": "Hi"}
res = rec.recognize_speech(np.zeros(10))
assert res == "Hi"

View File

@@ -1,125 +0,0 @@
from unittest.mock import AsyncMock, MagicMock
import numpy as np
import pytest
from control_backend.agents.perception.vad_agent import VADAgent
@pytest.fixture
def audio_out_socket():
return AsyncMock()
@pytest.fixture
def vad_agent(audio_out_socket):
return VADAgent("tcp://localhost:5555", False)
@pytest.fixture(autouse=True)
def patch_settings(monkeypatch):
# Patch the settings that vad_agent.run() reads
from control_backend.agents.perception import vad_agent
monkeypatch.setattr(
vad_agent.settings.behaviour_settings, "vad_prob_threshold", 0.5, raising=False
)
monkeypatch.setattr(
vad_agent.settings.behaviour_settings, "vad_non_speech_patience_chunks", 2, raising=False
)
monkeypatch.setattr(
vad_agent.settings.behaviour_settings, "vad_initial_since_speech", 0, raising=False
)
monkeypatch.setattr(vad_agent.settings.vad_settings, "sample_rate_hz", 16_000, raising=False)
async def simulate_streaming_with_probabilities(streaming, probabilities: list[float]):
"""
Simulates a streaming scenario with given VAD model probabilities for testing purposes.
:param streaming: The streaming component to be tested.
:param probabilities: A list of probabilities representing the outputs of the VAD model.
"""
model_item = MagicMock()
model_item.item.side_effect = probabilities
streaming.model = MagicMock(return_value=model_item)
# Prepare deterministic audio chunks and a poller that stops the loop when exhausted
chunk_bytes = np.empty(shape=512, dtype=np.float32).tobytes()
chunks = [chunk_bytes for _ in probabilities]
class DummyPoller:
def __init__(self, data, agent):
self.data = data
self.agent = agent
async def poll(self, timeout_ms=None):
if self.data:
return self.data.pop(0)
# Stop the loop cleanly once we've consumed all chunks
self.agent._running = False
return None
streaming.audio_in_poller = DummyPoller(chunks, streaming)
streaming._ready = AsyncMock()
streaming._running = True
await streaming._streaming_loop()
@pytest.mark.asyncio
async def test_voice_activity_detected(audio_out_socket, vad_agent):
"""
Test a scenario where there is voice activity detected between silences.
"""
speech_chunk_count = 5
probabilities = [0.0] * 5 + [1.0] * speech_chunk_count + [0.0] * 5
vad_agent.audio_out_socket = audio_out_socket
await simulate_streaming_with_probabilities(vad_agent, probabilities)
audio_out_socket.send.assert_called_once()
data = audio_out_socket.send.call_args[0][0]
assert isinstance(data, bytes)
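# Each chunk is 512 float32 samples (4 bytes each): 5 speech chunks plus 1 padding chunk.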
assert len(data) == 512 * 4 * (speech_chunk_count + 1)
@pytest.mark.asyncio
async def test_voice_activity_short_pause(audio_out_socket, vad_agent):
"""
Test a scenario where there is a short pause between speech, checking whether it ignores the
short pause.
"""
speech_chunk_count = 5
probabilities = (
[0.0] * 5 + [1.0] * speech_chunk_count + [0.0] + [1.0] * speech_chunk_count + [0.0] * 5
)
vad_agent.audio_out_socket = audio_out_socket
await simulate_streaming_with_probabilities(vad_agent, probabilities)
audio_out_socket.send.assert_called_once()
data = audio_out_socket.send.call_args[0][0]
assert isinstance(data, bytes)
# Expecting 12 chunks (2*5 with speech, 1 pause between, 1 as padding)
assert len(data) == 512 * 4 * (speech_chunk_count * 2 + 1 + 1)
@pytest.mark.asyncio
async def test_no_data(audio_out_socket, vad_agent):
"""
Test a scenario where there is no data received. This should not cause errors.
"""
class DummyPoller:
async def poll(self, timeout_ms=None):
vad_agent._running = False
return None
vad_agent.audio_out_socket = audio_out_socket
vad_agent.audio_in_poller = DummyPoller()
vad_agent._ready = AsyncMock()
vad_agent._running = True
await vad_agent._streaming_loop()
audio_out_socket.send.assert_not_called()
assert len(vad_agent.audio_buffer) == 0

View File

@@ -3,7 +3,7 @@ from unittest.mock import AsyncMock, MagicMock
import pytest
import zmq
from control_backend.agents.perception.vad_agent import SocketPoller
from control_backend.agents.vad_agent import SocketPoller
@pytest.fixture
@@ -16,8 +16,8 @@ async def test_socket_poller_with_data(socket, mocker):
socket_data = b"test"
socket.recv.return_value = socket_data
mock_poller: MagicMock = mocker.patch("control_backend.agents.perception.vad_agent.azmq.Poller")
mock_poller.return_value.poll = AsyncMock(return_value=[(socket, zmq.POLLIN)])
mock_poller: MagicMock = mocker.patch("control_backend.agents.vad_agent.zmq.Poller")
mock_poller.return_value.poll.return_value = [(socket, zmq.POLLIN)]
poller = SocketPoller(socket)
# Calling `poll` twice to be able to check that the poller is reused
@@ -35,8 +35,8 @@ async def test_socket_poller_with_data(socket, mocker):
@pytest.mark.asyncio
async def test_socket_poller_no_data(socket, mocker):
mock_poller: MagicMock = mocker.patch("control_backend.agents.perception.vad_agent.azmq.Poller")
mock_poller.return_value.poll = AsyncMock(return_value=[])
mock_poller: MagicMock = mocker.patch("control_backend.agents.vad_agent.zmq.Poller")
mock_poller.return_value.poll.return_value = []
poller = SocketPoller(socket)
data = await poller.poll()

View File

@@ -0,0 +1,106 @@
from unittest.mock import AsyncMock, MagicMock
import numpy as np
import pytest
from control_backend.agents.vad_agent import Streaming
@pytest.fixture
def audio_in_socket():
return AsyncMock()
@pytest.fixture
def audio_out_socket():
return AsyncMock()
@pytest.fixture
def mock_agent(mocker):
"""Fixture to create a mock BDIAgent."""
agent = MagicMock()
agent.jid = "vad_agent@test"
return agent
@pytest.fixture
def streaming(audio_in_socket, audio_out_socket, mock_agent):
import torch
torch.hub.load.return_value = (..., ...)  # torch is replaced by a MagicMock in conftest, so hub.load can be stubbed here
streaming = Streaming(audio_in_socket, audio_out_socket)
streaming._ready = True
streaming.agent = mock_agent
return streaming
async def simulate_streaming_with_probabilities(streaming, probabilities: list[float]):
"""
Simulates a streaming scenario with given VAD model probabilities for testing purposes.
:param streaming: The streaming component to be tested.
:param probabilities: A list of probabilities representing the outputs of the VAD model.
"""
model_item = MagicMock()
model_item.item.side_effect = probabilities
streaming.model = MagicMock()
streaming.model.return_value = model_item
audio_in_poller = AsyncMock()
audio_in_poller.poll.return_value = np.empty(shape=512, dtype=np.float32)
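# Every poll returns one 512-sample float32 chunk for the (mocked) VAD model to score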
streaming.audio_in_poller = audio_in_poller
for _ in probabilities:
await streaming.run()
@pytest.mark.asyncio
async def test_voice_activity_detected(audio_in_socket, audio_out_socket, streaming):
"""
Test a scenario where there is voice activity detected between silences.
"""
speech_chunk_count = 5
probabilities = [0.0] * 5 + [1.0] * speech_chunk_count + [0.0] * 5
await simulate_streaming_with_probabilities(streaming, probabilities)
audio_out_socket.send.assert_called_once()
data = audio_out_socket.send.call_args[0][0]
assert isinstance(data, bytes)
# Each chunk has 512 float32 samples of 4 bytes each; expecting 7 chunks (5 with speech, 2 as padding)
assert len(data) == 512 * 4 * (speech_chunk_count + 2)
@pytest.mark.asyncio
async def test_voice_activity_short_pause(audio_in_socket, audio_out_socket, streaming):
"""
Test a scenario with a short pause between speech segments, checking that the short pause is
ignored.
"""
speech_chunk_count = 5
probabilities = (
[0.0] * 5 + [1.0] * speech_chunk_count + [0.0] + [1.0] * speech_chunk_count + [0.0] * 5
)
await simulate_streaming_with_probabilities(streaming, probabilities)
audio_out_socket.send.assert_called_once()
data = audio_out_socket.send.call_args[0][0]
assert isinstance(data, bytes)
# Expecting 13 chunks (2*5 with speech, 1 pause between, 2 as padding)
assert len(data) == 512 * 4 * (speech_chunk_count * 2 + 1 + 2)
@pytest.mark.asyncio
async def test_no_data(audio_in_socket, audio_out_socket, streaming):
"""
Test a scenario where there is no data received. This should not cause errors.
"""
audio_in_poller = AsyncMock()
audio_in_poller.poll.return_value = None
streaming.audio_in_poller = audio_in_poller
await streaming.run()
audio_out_socket.send.assert_not_called()
assert len(streaming.audio_buffer) == 0

View File

@@ -1,30 +1,11 @@
import numpy as np
import pytest
from control_backend.agents.perception.transcription_agent.speech_recognizer import (
from control_backend.agents.transcription.speech_recognizer import (
OpenAIWhisperSpeechRecognizer,
SpeechRecognizer,
)
@pytest.fixture(autouse=True)
def patch_sr_settings(monkeypatch):
# Patch the *module-local* settings that SpeechRecognizer imported
from control_backend.agents.perception.transcription_agent import speech_recognizer as sr
# Provide real numbers for everything _estimate_max_tokens() reads
monkeypatch.setattr(sr.settings.vad_settings, "sample_rate_hz", 16_000, raising=False)
monkeypatch.setattr(
sr.settings.behaviour_settings, "transcription_words_per_minute", 450, raising=False
)
monkeypatch.setattr(
sr.settings.behaviour_settings, "transcription_words_per_token", 0.75, raising=False
)
monkeypatch.setattr(
sr.settings.behaviour_settings, "transcription_token_buffer", 10, raising=False
)
def test_estimate_max_tokens():
"""Inputting one minute of audio, assuming 450 words per minute and adding a 10 token padding,
expecting 610 tokens."""
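# Worked arithmetic: 1 minute * 450 words/min = 450 words; 450 / 0.75 words-per-token = 600 tokens; + 10 buffer = 610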

View File

@@ -1,125 +0,0 @@
import json
from unittest.mock import AsyncMock
import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
from control_backend.api.v1.endpoints import program
from control_backend.schemas.program import Program
@pytest.fixture
def app():
"""Create a FastAPI app with the /program route and mock socket."""
app = FastAPI()
app.include_router(program.router)
return app
@pytest.fixture
def client(app):
"""Create a TestClient."""
return TestClient(app)
def make_valid_program_dict():
"""Helper to create a valid Program JSON structure."""
return {
"phases": [
{
"id": "phase1",
"label": "basephase",
"norms": [{"id": "n1", "label": "norm", "norm": "be nice"}],
"goals": [
{"id": "g1", "label": "goal", "description": "test goal", "achieved": False}
],
"triggers": [
{
"id": "t1",
"label": "trigger",
"type": "keywords",
"keywords": [
{"id": "kw1", "keyword": "keyword1"},
{"id": "kw2", "keyword": "keyword2"},
],
},
],
}
]
}
def test_receive_program_success(client):
"""Valid Program JSON should be parsed and sent through the socket."""
mock_pub_socket = AsyncMock()
client.app.state.endpoints_pub_socket = mock_pub_socket
program_dict = make_valid_program_dict()
response = client.post("/program", json=program_dict)
assert response.status_code == 202
assert response.json() == {"status": "Program parsed"}
# Verify socket call
mock_pub_socket.send_multipart.assert_awaited_once()
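# Messages are published as [topic, payload] multipart frames: topic b"program", payload the JSON-encoded Program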
args, kwargs = mock_pub_socket.send_multipart.await_args
assert args[0][0] == b"program"
sent_bytes = args[0][1]
sent_obj = json.loads(sent_bytes.decode())
expected_obj = Program.model_validate(program_dict).model_dump()
assert sent_obj == expected_obj
def test_receive_program_invalid_json(client):
"""
Invalid JSON (malformed) -> FastAPI never calls endpoint.
It returns a 422 Unprocessable Entity.
"""
mock_pub_socket = AsyncMock()
client.app.state.endpoints_pub_socket = mock_pub_socket
# FastAPI only accepts valid JSON bodies, so send raw string
response = client.post("/program", content="{invalid json}")
assert response.status_code == 422
mock_pub_socket.send_multipart.assert_not_called()
def test_receive_program_invalid_deep_structure(client):
"""
Valid JSON but schema invalid -> Pydantic throws validation error -> 422.
"""
mock_pub_socket = AsyncMock()
client.app.state.endpoints_pub_socket = mock_pub_socket
# Missing "value" in norms element
bad_program = {
"phases": [
{
"id": "phase1",
"name": "deepfail",
"nextPhaseId": "phase2",
"phaseData": {
"norms": [
{"id": "n1", "name": "norm"} # INVALID: missing "value"
],
"goals": [
{"id": "g1", "name": "goal", "description": "desc", "achieved": False}
],
"triggers": [
{"id": "t1", "label": "trigger", "type": "keyword", "value": ["start"]}
],
},
}
]
}
response = client.post("/program", json=bad_program)
assert response.status_code == 422
mock_pub_socket.send_multipart.assert_not_called()

View File

@@ -1,452 +0,0 @@
import json
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
import zmq.asyncio
from fastapi import FastAPI
from fastapi.testclient import TestClient
from control_backend.api.v1.endpoints import robot
from control_backend.schemas.ri_message import GestureCommand, SpeechCommand
@pytest.fixture
def app():
"""
Creates a FastAPI test app and attaches the router under test.
The publishing socket is mocked per-test via app.state.endpoints_pub_socket.
"""
app = FastAPI()
app.include_router(robot.router)
return app
@pytest.fixture
def client(app):
"""Create a test client for the app."""
return TestClient(app)
@pytest.fixture
def mock_zmq_context():
"""Mock the ZMQ context."""
with patch("control_backend.api.v1.endpoints.robot.Context.instance") as mock_context:
context_instance = MagicMock()
mock_context.return_value = context_instance
yield context_instance
@pytest.fixture
def mock_sockets(mock_zmq_context):
"""Mock ZMQ sockets."""
mock_sub_socket = AsyncMock(spec=zmq.asyncio.Socket)
mock_pub_socket = AsyncMock(spec=zmq.asyncio.Socket)
mock_zmq_context.socket.return_value = mock_sub_socket
return {"sub": mock_sub_socket, "pub": mock_pub_socket}
def test_receive_speech_command_success(client):
"""
Test for successful reception of a command. Ensures the status code is 202 and the response body
is correct. It also verifies that the ZeroMQ socket's send_multipart method is called with the
expected data.
"""
# Arrange
mock_pub_socket = AsyncMock()
client.app.state.endpoints_pub_socket = mock_pub_socket
command_data = {"endpoint": "actuate/speech", "data": "This is a test"}
speech_command = SpeechCommand(**command_data)
# Act
response = client.post("/command", json=command_data)
# Assert
assert response.status_code == 202
assert response.json() == {"status": "Command received"}
# Verify that the ZMQ socket was used correctly
mock_pub_socket.send_multipart.assert_awaited_once_with(
[b"command", speech_command.model_dump_json().encode()]
)
def test_receive_gesture_command_success(client):
"""
Test for successful reception of a command. Ensures the status code is 202 and the response body
is correct. It also verifies that the ZeroMQ socket's send_multipart method is called with the
expected data.
"""
# Arrange
mock_pub_socket = AsyncMock()
client.app.state.endpoints_pub_socket = mock_pub_socket
command_data = {"endpoint": "actuate/gesture/tag", "data": "happy"}
gesture_command = GestureCommand(**command_data)
# Act
response = client.post("/command", json=command_data)
# Assert
assert response.status_code == 202
assert response.json() == {"status": "Command received"}
# Verify that the ZMQ socket was used correctly
mock_pub_socket.send_multipart.assert_awaited_once_with(
[b"command", gesture_command.model_dump_json().encode()]
)
def test_receive_command_invalid_payload(client):
"""
Test invalid data handling (schema validation).
"""
# Missing required field(s)
bad_payload = {"invalid": "data"}
response = client.post("/command", json=bad_payload)
assert response.status_code == 422 # validation error
def test_ping_check_returns_none(client):
"""Ensure /ping_check returns 200 and None (currently unimplemented)."""
response = client.get("/ping_check")
assert response.status_code == 200
assert response.json() is None
# TODO: Convert these mock sockets to the fixture.
@pytest.mark.asyncio
async def test_ping_stream_yields_ping_event(monkeypatch):
"""Test that ping_stream yields a proper SSE message when a ping is received."""
mock_sub_socket = AsyncMock()
mock_sub_socket.connect = MagicMock()
mock_sub_socket.setsockopt = MagicMock()
mock_sub_socket.recv_multipart = AsyncMock(return_value=[b"ping", b"true"])
mock_context = MagicMock()
mock_context.socket.return_value = mock_sub_socket
monkeypatch.setattr(robot.Context, "instance", lambda: mock_context)
mock_request = AsyncMock()
mock_request.is_disconnected = AsyncMock(side_effect=[False, True])
response = await robot.ping_stream(mock_request)
generator = aiter(response.body_iterator)
event = await anext(generator)
event_text = event.decode() if isinstance(event, bytes) else str(event)
assert event_text.strip() == "data: true"
with pytest.raises(StopAsyncIteration):
await anext(generator)
mock_sub_socket.connect.assert_called_once()
mock_sub_socket.setsockopt.assert_called_once_with(robot.zmq.SUBSCRIBE, b"ping")
mock_sub_socket.recv_multipart.assert_awaited()
@pytest.mark.asyncio
async def test_ping_stream_handles_timeout(monkeypatch):
"""Test that ping_stream continues looping on TimeoutError."""
mock_sub_socket = AsyncMock()
mock_sub_socket.connect = MagicMock()
mock_sub_socket.setsockopt = MagicMock()
mock_sub_socket.recv_multipart.side_effect = TimeoutError()
mock_context = MagicMock()
mock_context.socket.return_value = mock_sub_socket
monkeypatch.setattr(robot.Context, "instance", lambda: mock_context)
mock_request = AsyncMock()
mock_request.is_disconnected = AsyncMock(return_value=True)
response = await robot.ping_stream(mock_request)
generator = aiter(response.body_iterator)
with pytest.raises(StopAsyncIteration):
await anext(generator)
mock_sub_socket.connect.assert_called_once()
mock_sub_socket.setsockopt.assert_called_once_with(robot.zmq.SUBSCRIBE, b"ping")
mock_sub_socket.recv_multipart.assert_awaited()
@pytest.mark.asyncio
async def test_ping_stream_yields_json_values(monkeypatch):
"""Ensure ping_stream correctly parses and yields JSON body values."""
mock_sub_socket = AsyncMock()
mock_sub_socket.connect = MagicMock()
mock_sub_socket.setsockopt = MagicMock()
mock_sub_socket.recv_multipart = AsyncMock(
return_value=[b"ping", json.dumps({"connected": True}).encode()]
)
mock_context = MagicMock()
mock_context.socket.return_value = mock_sub_socket
monkeypatch.setattr(robot.Context, "instance", lambda: mock_context)
mock_request = AsyncMock()
mock_request.is_disconnected = AsyncMock(side_effect=[False, True])
response = await robot.ping_stream(mock_request)
generator = aiter(response.body_iterator)
event = await anext(generator)
event_text = event.decode() if isinstance(event, bytes) else str(event)
assert "connected" in event_text
assert "true" in event_text
mock_sub_socket.connect.assert_called_once()
mock_sub_socket.setsockopt.assert_called_once_with(robot.zmq.SUBSCRIBE, b"ping")
mock_sub_socket.recv_multipart.assert_awaited()
# New tests for get_available_gesture_tags endpoint
@pytest.mark.asyncio
async def test_get_available_gesture_tags_success(client, monkeypatch):
"""
Test successful retrieval of available gesture tags.
"""
# Arrange
mock_sub_socket = AsyncMock()
mock_sub_socket.connect = MagicMock()
mock_sub_socket.setsockopt = MagicMock()
# Simulate a response with gesture tags
response_data = {"tags": ["wave", "nod", "point", "dance"]}
mock_sub_socket.recv_multipart = AsyncMock(
return_value=[b"get_gestures", json.dumps(response_data).encode()]
)
mock_context = MagicMock()
mock_context.socket.return_value = mock_sub_socket
monkeypatch.setattr(robot.Context, "instance", lambda: mock_context)
mock_pub_socket = AsyncMock()
client.app.state.endpoints_pub_socket = mock_pub_socket
# Mock settings
mock_settings = MagicMock()
mock_settings.zmq_settings.internal_sub_address = "tcp://localhost:5555"
monkeypatch.setattr(robot, "settings", mock_settings)
# Mock logger to avoid actual logging
mock_logger = MagicMock()
monkeypatch.setattr(robot.logger, "debug", mock_logger)
# Act
response = client.get("/get_available_gesture_tags")
# Assert
assert response.status_code == 200
assert response.json() == {"available_gesture_tags": ["wave", "nod", "point", "dance"]}
# Verify ZeroMQ interactions
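# The endpoint publishes an empty b"send_gestures" request and awaits the b"get_gestures" reply on the SUB socket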
mock_sub_socket.connect.assert_called_once_with("tcp://localhost:5555")
mock_sub_socket.setsockopt.assert_called_once_with(robot.zmq.SUBSCRIBE, b"get_gestures")
mock_pub_socket.send_multipart.assert_awaited_once_with([b"send_gestures", b""])
mock_sub_socket.recv_multipart.assert_awaited_once()
@pytest.mark.asyncio
async def test_get_available_gesture_tags_with_amount(client, monkeypatch):
"""
Test retrieval of gesture tags with a specific amount parameter.
This tests the TODO in the endpoint about getting a certain amount from the UI.
"""
# Arrange
mock_sub_socket = AsyncMock()
mock_sub_socket.connect = MagicMock()
mock_sub_socket.setsockopt = MagicMock()
# Simulate a response with gesture tags
response_data = {"tags": ["wave", "nod"]}
mock_sub_socket.recv_multipart = AsyncMock(
return_value=[b"get_gestures", json.dumps(response_data).encode()]
)
mock_context = MagicMock()
mock_context.socket.return_value = mock_sub_socket
monkeypatch.setattr(robot.Context, "instance", lambda: mock_context)
mock_pub_socket = AsyncMock()
client.app.state.endpoints_pub_socket = mock_pub_socket
# Mock settings
mock_settings = MagicMock()
mock_settings.zmq_settings.internal_sub_address = "tcp://localhost:5555"
monkeypatch.setattr(robot, "settings", mock_settings)
# Mock logger
mock_logger = MagicMock()
monkeypatch.setattr(robot.logger, "debug", mock_logger)
# Act - the endpoint does not yet accept an "amount" query parameter (see the TODO in the
# endpoint code), so this exercises the current default behaviour
response = client.get("/get_available_gesture_tags")
# Assert
assert response.status_code == 200
assert response.json() == {"available_gesture_tags": ["wave", "nod"]}
# The endpoint currently doesn't use the amount parameter, so it should send empty bytes
mock_pub_socket.send_multipart.assert_awaited_once_with([b"send_gestures", b""])
@pytest.mark.asyncio
async def test_get_available_gesture_tags_timeout(client, monkeypatch):
"""
Test timeout scenario when fetching gesture tags.
"""
# Arrange
mock_sub_socket = AsyncMock()
mock_sub_socket.connect = MagicMock()
mock_sub_socket.setsockopt = MagicMock()
# Simulate a timeout
mock_sub_socket.recv_multipart = AsyncMock(side_effect=TimeoutError)
mock_context = MagicMock()
mock_context.socket.return_value = mock_sub_socket
monkeypatch.setattr(robot.Context, "instance", lambda: mock_context)
mock_pub_socket = AsyncMock()
client.app.state.endpoints_pub_socket = mock_pub_socket
# Mock settings
mock_settings = MagicMock()
mock_settings.zmq_settings.internal_sub_address = "tcp://localhost:5555"
monkeypatch.setattr(robot, "settings", mock_settings)
# Mock logger to verify debug message is logged
mock_logger = MagicMock()
monkeypatch.setattr(robot.logger, "debug", mock_logger)
# Act
response = client.get("/get_available_gesture_tags")
# Assert
assert response.status_code == 200
# On timeout the endpoint should fall back to an empty tag list rather than raising on the empty body
assert response.json() == {"available_gesture_tags": []}
# Verify the timeout was logged
mock_logger.assert_called_once_with("got timeout error fetching gestures")
# Verify ZeroMQ interactions
mock_sub_socket.connect.assert_called_once_with("tcp://localhost:5555")
mock_sub_socket.setsockopt.assert_called_once_with(robot.zmq.SUBSCRIBE, b"get_gestures")
mock_pub_socket.send_multipart.assert_awaited_once_with([b"send_gestures", b""])
mock_sub_socket.recv_multipart.assert_awaited_once()
@pytest.mark.asyncio
async def test_get_available_gesture_tags_empty_response(client, monkeypatch):
"""
Test scenario when response contains no tags.
"""
# Arrange
mock_sub_socket = AsyncMock()
mock_sub_socket.connect = MagicMock()
mock_sub_socket.setsockopt = MagicMock()
# Simulate a response with empty tags
response_data = {"tags": []}
mock_sub_socket.recv_multipart = AsyncMock(
return_value=[b"get_gestures", json.dumps(response_data).encode()]
)
mock_context = MagicMock()
mock_context.socket.return_value = mock_sub_socket
monkeypatch.setattr(robot.Context, "instance", lambda: mock_context)
mock_pub_socket = AsyncMock()
client.app.state.endpoints_pub_socket = mock_pub_socket
# Mock settings
mock_settings = MagicMock()
mock_settings.zmq_settings.internal_sub_address = "tcp://localhost:5555"
monkeypatch.setattr(robot, "settings", mock_settings)
# Act
response = client.get("/get_available_gesture_tags")
# Assert
assert response.status_code == 200
assert response.json() == {"available_gesture_tags": []}
@pytest.mark.asyncio
async def test_get_available_gesture_tags_missing_tags_key(client, monkeypatch):
"""
Test scenario when response JSON doesn't contain 'tags' key.
"""
# Arrange
mock_sub_socket = AsyncMock()
mock_sub_socket.connect = MagicMock()
mock_sub_socket.setsockopt = MagicMock()
# Simulate a response without 'tags' key
response_data = {"some_other_key": "value"}
mock_sub_socket.recv_multipart = AsyncMock(
return_value=[b"get_gestures", json.dumps(response_data).encode()]
)
mock_context = MagicMock()
mock_context.socket.return_value = mock_sub_socket
monkeypatch.setattr(robot.Context, "instance", lambda: mock_context)
mock_pub_socket = AsyncMock()
client.app.state.endpoints_pub_socket = mock_pub_socket
# Mock settings
mock_settings = MagicMock()
mock_settings.zmq_settings.internal_sub_address = "tcp://localhost:5555"
monkeypatch.setattr(robot, "settings", mock_settings)
# Act
response = client.get("/get_available_gesture_tags")
# Assert
assert response.status_code == 200
# .get("tags", []) should return empty list if 'tags' key is missing
assert response.json() == {"available_gesture_tags": []}
@pytest.mark.asyncio
async def test_get_available_gesture_tags_invalid_json(client, monkeypatch):
"""
Test scenario when response contains invalid JSON.
"""
# Arrange
mock_sub_socket = AsyncMock()
mock_sub_socket.connect = MagicMock()
mock_sub_socket.setsockopt = MagicMock()
# Simulate a response with invalid JSON
mock_sub_socket.recv_multipart = AsyncMock(return_value=[b"get_gestures", b"invalid json"])
mock_context = MagicMock()
mock_context.socket.return_value = mock_sub_socket
monkeypatch.setattr(robot.Context, "instance", lambda: mock_context)
mock_pub_socket = AsyncMock()
client.app.state.endpoints_pub_socket = mock_pub_socket
# Mock settings
mock_settings = MagicMock()
mock_settings.zmq_settings.internal_sub_address = "tcp://localhost:5555"
monkeypatch.setattr(robot, "settings", mock_settings)
# Act
response = client.get("/get_available_gesture_tags")
# Assert - invalid JSON in the reply should be handled gracefully and yield an empty tag list
assert response.status_code == 200
assert response.json() == {"available_gesture_tags": []}

View File

@@ -1,43 +1,66 @@
from unittest.mock import MagicMock, patch
import pytest
from control_backend.core.agent_system import _agent_directory
import sys
from unittest.mock import MagicMock
@pytest.fixture(autouse=True)
def reset_agent_directory():
def pytest_configure(config):
"""
Automatically clears the global agent directory before and after each test
to prevent state leakage between tests.
This hook runs at the start of the pytest session, before any tests are
collected. It mocks heavy or unavailable modules to prevent ImportErrors.
"""
_agent_directory.clear()
yield
_agent_directory.clear()
# --- Mock spade and spade-bdi ---
mock_agentspeak = MagicMock()
mock_httpx = MagicMock()
mock_pydantic = MagicMock()
mock_spade = MagicMock()
mock_spade.agent = MagicMock()
mock_spade.behaviour = MagicMock()
mock_spade.message = MagicMock()
mock_spade_bdi = MagicMock()
mock_spade_bdi.bdi = MagicMock()
mock_spade.agent.Message = MagicMock()
mock_spade.behaviour.CyclicBehaviour = type("CyclicBehaviour", (object,), {})
mock_spade_bdi.bdi.BDIAgent = type("BDIAgent", (object,), {})
@pytest.fixture
def mock_settings():
with patch("control_backend.core.config.settings") as mock:
# Set default values that match the pydantic model defaults
# to avoid AttributeErrors during tests
mock.zmq_settings.internal_pub_address = "tcp://localhost:5560"
mock.zmq_settings.internal_sub_address = "tcp://localhost:5561"
mock.zmq_settings.ri_command_address = "tcp://localhost:0000"
mock.agent_settings.bdi_core_name = "bdi_core_agent"
mock.agent_settings.bdi_belief_collector_name = "belief_collector_agent"
mock.agent_settings.llm_name = "llm_agent"
mock.agent_settings.robot_speech_name = "robot_speech_agent"
mock.agent_settings.transcription_name = "transcription_agent"
mock.agent_settings.text_belief_extractor_name = "text_belief_extractor_agent"
mock.agent_settings.vad_name = "vad_agent"
mock.behaviour_settings.sleep_s = 0.01 # Speed up tests
mock.behaviour_settings.comm_setup_max_retries = 1
yield mock
sys.modules["agentspeak"] = mock_agentspeak
sys.modules["httpx"] = mock_httpx
sys.modules["pydantic"] = mock_pydantic
sys.modules["spade"] = mock_spade
sys.modules["spade.agent"] = mock_spade.agent
sys.modules["spade.behaviour"] = mock_spade.behaviour
sys.modules["spade.message"] = mock_spade.message
sys.modules["spade_bdi"] = mock_spade_bdi
sys.modules["spade_bdi.bdi"] = mock_spade_bdi.bdi
# --- Mock the config module to prevent Pydantic ImportError ---
mock_config_module = MagicMock()
@pytest.fixture
def mock_zmq_context():
with patch("zmq.asyncio.Context") as mock:
mock.instance.return_value = MagicMock()
yield mock
# The code under test does `from ... import settings`, so our mock module
# must have a `settings` attribute. We'll make it a MagicMock so we can
# configure it later in our tests using mocker.patch.
mock_config_module.settings = MagicMock()
sys.modules["control_backend.core.config"] = mock_config_module
# --- Mock torch and zmq for VAD ---
mock_torch = MagicMock()
mock_zmq = MagicMock()
mock_zmq.asyncio = mock_zmq
# In individual tests, these can be imported and the return values changed
sys.modules["torch"] = mock_torch
sys.modules["zmq"] = mock_zmq
sys.modules["zmq.asyncio"] = mock_zmq.asyncio
# --- Mock whisper ---
mock_whisper = MagicMock()
mock_mlx = MagicMock()
mock_mlx.core = MagicMock()
mock_mlx_whisper = MagicMock()
mock_mlx_whisper.transcribe = MagicMock()
sys.modules["whisper"] = mock_whisper
sys.modules["mlx"] = mock_mlx
sys.modules["mlx.core"] = mock_mlx
sys.modules["mlx_whisper"] = mock_mlx_whisper
sys.modules["mlx_whisper.transcribe"] = mock_mlx_whisper.transcribe

View File

@@ -1,72 +0,0 @@
"""Test the base class logic, message passing and background task handling."""
import asyncio
import logging
from unittest.mock import AsyncMock
import pytest
from control_backend.core.agent_system import AgentDirectory, BaseAgent, InternalMessage
class ConcreteTestAgent(BaseAgent):
logger = logging.getLogger("test")
def __init__(self, name: str):
super().__init__(name)
self.received = []
async def setup(self):
pass
async def handle_message(self, msg: InternalMessage):
self.received.append(msg)
if msg.body == "stop":
await self.stop()
@pytest.mark.asyncio
async def test_agent_lifecycle():
agent = ConcreteTestAgent("lifecycle_agent")
await agent.start()
assert agent._running is True
# Test background task
async def dummy_task():
pass
task = agent.add_behavior(dummy_task())
assert task in agent._tasks
await task  # Wait for the task to finish
assert task not in agent._tasks
assert len(agent._tasks) == 2 # message handling tasks are still running
await agent.stop()
assert agent._running is False
await asyncio.sleep(0.01)
# Tasks should be cancelled
assert len(agent._tasks) == 0
@pytest.mark.asyncio
async def test_send_unknown_agent():
agent = ConcreteTestAgent("sender")
msg = InternalMessage(to="unknown_receiver", sender="sender", body="boo")
agent._internal_pub_socket = AsyncMock()
await agent.send(msg)
agent._internal_pub_socket.send_multipart.assert_called()
@pytest.mark.asyncio
async def test_get_agent():
agent = ConcreteTestAgent("registrant")
assert AgentDirectory.get("registrant") == agent
assert AgentDirectory.get("non_existent") is None

View File

@@ -1,14 +0,0 @@
"""Test if settings load correctly and environment variables override defaults."""
from control_backend.core.config import Settings
def test_default_settings():
settings = Settings()
assert settings.app_title == "PepperPlus"
def test_env_override(monkeypatch):
monkeypatch.setenv("APP_TITLE", "TestPepper")
settings = Settings()
assert settings.app_title == "TestPepper"

View File

@@ -1,88 +0,0 @@
import logging
from unittest.mock import mock_open, patch
import pytest
from control_backend.logging.setup_logging import add_logging_level, setup_logging
def test_add_logging_level():
# Add a unique level to avoid conflicts with other tests/libraries
level_name = "TESTLEVEL"
level_num = 35
add_logging_level(level_name, level_num)
assert logging.getLevelName(level_num) == level_name
assert hasattr(logging, level_name)
assert hasattr(logging.getLoggerClass(), level_name.lower())
# Test functionality
logger = logging.getLogger("test_custom_level")
with patch.object(logger, "_log") as mock_log:
getattr(logger, level_name.lower())("message")
mock_log.assert_called_with(level_num, "message", ())
# Test duplicates
with pytest.raises(AttributeError):
add_logging_level(level_name, level_num)
with pytest.raises(AttributeError):
add_logging_level("INFO", 20) # Existing level
def test_setup_logging_no_file(caplog):
with patch("os.path.exists", return_value=False):
setup_logging("dummy.yaml")
assert "Logging config file not found" in caplog.text
def test_setup_logging_yaml_error(caplog):
with patch("os.path.exists", return_value=True):
with patch("builtins.open", mock_open(read_data="invalid: [yaml")):
with patch("logging.config.dictConfig") as mock_dict_config:
setup_logging("config.yaml")
# Verify we logged the warning
assert "Could not load logging configuration" in caplog.text
# Verify dictConfig was called with empty dict (which would crash real dictConfig)
mock_dict_config.assert_called_with({})
assert "Could not load logging configuration" in caplog.text
def test_setup_logging_success():
config_data = """
version: 1
handlers:
console:
class: logging.StreamHandler
root:
handlers: [console]
level: INFO
custom_levels:
MYLEVEL: 15
"""
with patch("os.path.exists", return_value=True):
with patch("builtins.open", mock_open(read_data=config_data)):
with patch("logging.config.dictConfig") as mock_dict_config:
setup_logging("config.yaml")
mock_dict_config.assert_called()
assert hasattr(logging, "MYLEVEL")
def test_setup_logging_zmq_handler(mock_zmq_context):
config_data = """
version: 1
handlers:
ui:
class: logging.NullHandler
# In real config this would be a zmq handler, but for unit test logic
# we just want to see if the socket injection happens
"""
with patch("os.path.exists", return_value=True):
with patch("builtins.open", mock_open(read_data=config_data)):
with patch("logging.config.dictConfig") as mock_dict_config:
setup_logging("config.yaml")
args = mock_dict_config.call_args[0][0]
assert "interface_or_socket" in args["handlers"]["ui"]

View File

@@ -1,88 +0,0 @@
import pytest
from pydantic import ValidationError
from control_backend.schemas.ri_message import GestureCommand, RIEndpoint, RIMessage, SpeechCommand
def valid_command_1():
return SpeechCommand(data="Hallo?")
def valid_command_2():
return GestureCommand(endpoint=RIEndpoint.GESTURE_TAG, data="happy")
def valid_command_3():
return GestureCommand(endpoint=RIEndpoint.GESTURE_SINGLE, data="happy_1")
def invalid_command_1():
return RIMessage(endpoint=RIEndpoint.PING, data="Hello again.")
def invalid_command_2():
return RIMessage(endpoint=RIEndpoint.PING, data="Hey!")
def invalid_command_3():
return RIMessage(endpoint=RIEndpoint.GESTURE_SINGLE, data={1, 2, 3})
def invalid_command_4():
test: RIMessage = GestureCommand(endpoint=RIEndpoint.GESTURE_SINGLE, data="asdsad")
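# Mutate the endpoint to PING after construction so re-validating as a GestureCommand fails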
def change_endpoint(msg: RIMessage):
msg.endpoint = RIEndpoint.PING
change_endpoint(test)
return test
def test_valid_speech_command_1():
command = valid_command_1()
RIMessage.model_validate(command)
SpeechCommand.model_validate(command)
def test_valid_gesture_command_1():
command = valid_command_2()
RIMessage.model_validate(command)
GestureCommand.model_validate(command)
def test_valid_gesture_command_2():
command = valid_command_3()
RIMessage.model_validate(command)
GestureCommand.model_validate(command)
def test_invalid_speech_command_1():
command = invalid_command_1()
RIMessage.model_validate(command)
with pytest.raises(ValidationError):
SpeechCommand.model_validate(command)
def test_invalid_gesture_command_1():
command = invalid_command_2()
RIMessage.model_validate(command)
with pytest.raises(ValidationError):
GestureCommand.model_validate(command)
def test_invalid_gesture_command_2():
command = invalid_command_3()
RIMessage.model_validate(command)
with pytest.raises(ValidationError):
GestureCommand.model_validate(command)
def test_invalid_gesture_command_3():
command = invalid_command_4()
RIMessage.model_validate(command)
with pytest.raises(ValidationError):
GestureCommand.model_validate(command)

View File

@@ -1,87 +0,0 @@
import pytest
from pydantic import ValidationError
from control_backend.schemas.program import (
Goal,
KeywordTrigger,
Norm,
Phase,
Program,
TriggerKeyword,
)
def base_norm() -> Norm:
return Norm(
id="norm1",
label="testNorm",
norm="testNormNorm",
)
def base_goal() -> Goal:
return Goal(
id="goal1",
label="testGoal",
description="testGoalDescription",
achieved=False,
)
def base_trigger() -> KeywordTrigger:
return KeywordTrigger(
id="trigger1",
label="testTrigger",
type="keywords",
keywords=[
TriggerKeyword(id="keyword1", keyword="testKeyword1"),
TriggerKeyword(id="keyword1", keyword="testKeyword2"),
],
)
def base_phase() -> Phase:
return Phase(
id="phase1",
label="basephase",
norms=[base_norm()],
goals=[base_goal()],
triggers=[base_trigger()],
)
def base_program() -> Program:
return Program(phases=[base_phase()])
def invalid_program() -> dict:
# wrong types inside phases list (not Phase objects)
return {
"phases": [
{"id": "phase1"}, # incomplete
{"not_a_phase": True},
]
}
def test_valid_program():
program = base_program()
validated = Program.model_validate(program)
assert isinstance(validated, Program)
assert validated.phases[0].norms[0].norm == "testNormNorm"
def test_valid_deepprogram():
program = base_program()
validated = Program.model_validate(program)
# validate nested components directly
phase = validated.phases[0]
assert isinstance(phase.goals[0], Goal)
assert isinstance(phase.triggers[0], KeywordTrigger)
assert isinstance(phase.norms[0], Norm)
def test_invalid_program():
bad = invalid_program()
with pytest.raises(ValidationError):
Program.model_validate(bad)

1033
uv.lock generated

File diff suppressed because it is too large