Remove SPADE dependency #29

Merged
k.marinus merged 28 commits from refactor/remove-spade into dev 2025-11-25 10:26:07 +00:00
51 changed files with 2408 additions and 3105 deletions

View File

@@ -22,6 +22,4 @@ test:
tags: tags:
- test - test
script: script:
# - uv run --group integration-test pytest test/integration - uv run --only-group test pytest test
- uv run --only-group test pytest test/unit

View File

@@ -8,7 +8,7 @@ formatters:
# Console output # Console output
colored: colored:
(): "colorlog.ColoredFormatter" (): "colorlog.ColoredFormatter"
format: "{log_color}{asctime} | {levelname:11} | {name:70} | {message}" format: "{log_color}{asctime}.{msecs:03.0f} | {levelname:11} | {name:70} | {message}"
style: "{" style: "{"
datefmt: "%H:%M:%S" datefmt: "%H:%M:%S"

View File

@@ -4,34 +4,35 @@
# https://www.sphinx-doc.org/en/master/usage/configuration.html # https://www.sphinx-doc.org/en/master/usage/configuration.html
import os import os
import sys import sys
sys.path.insert(0, os.path.abspath("../src")) sys.path.insert(0, os.path.abspath("../src"))
# -- Project information ----------------------------------------------------- # -- Project information -----------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information
project = 'control_backend' project = "control_backend"
copyright = '2025, Author' copyright = "2025, Author"
author = 'Author' author = "Author"
# -- General configuration --------------------------------------------------- # -- General configuration ---------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
extensions = [ extensions = [
'sphinx.ext.autodoc', "sphinx.ext.autodoc",
'sphinx.ext.viewcode', "sphinx.ext.viewcode",
'sphinx.ext.todo', "sphinx.ext.todo",
] ]
templates_path = ['_templates'] templates_path = ["_templates"]
exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]
language = 'en' language = "en"
# -- Options for HTML output ------------------------------------------------- # -- Options for HTML output -------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output
html_theme = 'sphinx_rtd_theme' html_theme = "sphinx_rtd_theme"
html_static_path = ['_static'] html_static_path = ["_static"]
# -- Options for todo extension ---------------------------------------------- # -- Options for todo extension ----------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/extensions/todo.html#configuration # https://www.sphinx-doc.org/en/master/usage/extensions/todo.html#configuration

View File

@@ -5,6 +5,7 @@ description = "Add your description here"
readme = "README.md" readme = "README.md"
requires-python = ">=3.13" requires-python = ">=3.13"
dependencies = [ dependencies = [
"agentspeak>=0.2.2",
"colorlog>=6.10.1", "colorlog>=6.10.1",
"fastapi[all]>=0.115.6", "fastapi[all]>=0.115.6",
"mlx-whisper>=0.4.3 ; sys_platform == 'darwin'", "mlx-whisper>=0.4.3 ; sys_platform == 'darwin'",
@@ -13,16 +14,10 @@ dependencies = [
"pyaudio>=0.2.14", "pyaudio>=0.2.14",
"pydantic>=2.12.0", "pydantic>=2.12.0",
"pydantic-settings>=2.11.0", "pydantic-settings>=2.11.0",
"pytest>=8.4.2",
"pytest-asyncio>=1.2.0",
"pytest-cov>=7.0.0",
"pytest-mock>=3.15.1",
"python-json-logger>=4.0.0", "python-json-logger>=4.0.0",
"pyyaml>=6.0.3", "pyyaml>=6.0.3",
"pyzmq>=27.1.0", "pyzmq>=27.1.0",
"silero-vad>=6.0.0", "silero-vad>=6.0.0",
"spade>=4.1.0",
"spade-bdi>=0.3.2",
"sphinx>=7.3.7", "sphinx>=7.3.7",
"sphinx-rtd-theme>=3.0.2", "sphinx-rtd-theme>=3.0.2",
"torch>=2.8.0", "torch>=2.8.0",
@@ -32,18 +27,29 @@ dependencies = [
[dependency-groups] [dependency-groups]
dev = [ dev = [
"pre-commit>=4.3.0", "pre-commit>=4.3.0",
"ruff>=0.14.2",
"ruff-format>=0.3.0",
]
integration-test = [
"soundfile>=0.13.1",
]
test = [
"numpy>=2.3.3",
"pytest>=8.4.2", "pytest>=8.4.2",
"pytest-asyncio>=1.2.0", "pytest-asyncio>=1.2.0",
"pytest-cov>=7.0.0", "pytest-cov>=7.0.0",
"pytest-mock>=3.15.1", "pytest-mock>=3.15.1",
"soundfile>=0.13.1",
"ruff>=0.14.2",
"ruff-format>=0.3.0",
]
test = [
"agentspeak>=0.2.2",
"fastapi>=0.115.6",
"httpx>=0.28.1",
"mlx-whisper>=0.4.3 ; sys_platform == 'darwin'",
"openai-whisper>=20250625",
"pydantic>=2.12.0",
"pydantic-settings>=2.11.0",
"pytest>=8.4.2",
"pytest-asyncio>=1.2.0",
"pytest-cov>=7.0.0",
"pytest-mock>=3.15.1",
"pyyaml>=6.0.3",
"pyzmq>=27.1.0",
"soundfile>=0.13.1",
] ]
[tool.pytest.ini_options] [tool.pytest.ini_options]

View File

@@ -1,11 +1,10 @@
import json import json
import spade.agent
import zmq import zmq
from spade.behaviour import CyclicBehaviour import zmq.asyncio as azmq
from zmq.asyncio import Context
from control_backend.agents import BaseAgent from control_backend.agents import BaseAgent
from control_backend.core.agent_system import InternalMessage
from control_backend.core.config import settings from control_backend.core.config import settings
from control_backend.schemas.ri_message import SpeechCommand from control_backend.schemas.ri_message import SpeechCommand
@@ -18,57 +17,21 @@ class RobotSpeechAgent(BaseAgent):
def __init__( def __init__(
self, self,
jid: str, name: str,
password: str,
port: int = settings.agent_settings.default_spade_port,
verify_security: bool = False,
address=settings.zmq_settings.ri_command_address, address=settings.zmq_settings.ri_command_address,
bind=False, bind=False,
): ):
super().__init__(jid, password, port, verify_security) super().__init__(name)
self.address = address self.address = address
self.bind = bind self.bind = bind
class SendZMQCommandsBehaviour(CyclicBehaviour):
"""Behaviour for sending commands received from the UI."""
async def run(self):
"""
Run the command publishing loop indefinetely.
"""
assert self.agent is not None
# Get a message internally (with topic command)
topic, body = await self.agent.subsocket.recv_multipart()
# Try to get body
try:
body = json.loads(body)
message = SpeechCommand.model_validate(body)
# Send to the robot.
await self.agent.pubsocket.send_json(message.model_dump())
except Exception as e:
self.agent.logger.error("Error processing message: %s", e)
class SendSpadeCommandsBehaviour(CyclicBehaviour):
"""Behaviour for sending commands received from other Python agents."""
async def run(self):
message: spade.agent.Message = await self.receive(timeout=0.1)
if message and message.to == self.agent.jid:
try:
speech_command = SpeechCommand.model_validate_json(message.body)
await self.agent.pubsocket.send_json(speech_command.model_dump())
except Exception as e:
self.agent.logger.error("Error processing message: %s", e)
async def setup(self): async def setup(self):
""" """
Setup the robot speech command agent Setup the robot speech command agent
""" """
self.logger.info("Setting up %s", self.jid) self.logger.info("Setting up %s", self.name)
context = Context.instance() context = azmq.Context.instance()
# To the robot # To the robot
self.pubsocket = context.socket(zmq.PUB) self.pubsocket = context.socket(zmq.PUB)
@@ -82,9 +45,38 @@ class RobotSpeechAgent(BaseAgent):
self.subsocket.connect(settings.zmq_settings.internal_sub_address) self.subsocket.connect(settings.zmq_settings.internal_sub_address)
self.subsocket.setsockopt(zmq.SUBSCRIBE, b"command") self.subsocket.setsockopt(zmq.SUBSCRIBE, b"command")
# Add behaviour to our agent self.add_behavior(self._zmq_command_loop())
commands_behaviour = self.SendZMQCommandsBehaviour()
self.add_behaviour(commands_behaviour)
self.add_behaviour(self.SendSpadeCommandsBehaviour())
self.logger.info("Finished setting up %s", self.jid) self.logger.info("Finished setting up %s", self.name)
async def stop(self):
if self.subsocket:
self.subsocket.close()
if self.pubsocket:
self.pubsocket.close()
await super().stop()
async def handle_message(self, msg: InternalMessage):
"""
Handle commands received from other Python agents.
"""
try:
speech_command = SpeechCommand.model_validate_json(msg.body)
await self.pubsocket.send_json(speech_command.model_dump())
except Exception:
self.logger.exception("Error processing internal message.")
async def _zmq_command_loop(self):
"""
Handle commands from the UI.
"""
while self._running:
try:
_, body = await self.subsocket.recv_multipart()
body = json.loads(body)
message = SpeechCommand.model_validate(body)
await self.pubsocket.send_json(message.model_dump())
except Exception:
self.logger.exception("Error processing ZMQ message.")

View File

@@ -1,12 +1,12 @@
import logging import logging
from spade.agent import Agent from control_backend.core.agent_system import BaseAgent as CoreBaseAgent
class BaseAgent(Agent): class BaseAgent(CoreBaseAgent):
""" """
Base agent class for our agents to inherit from. Base agent class for our agents to inherit from. This just ensures
This ensures that all agents have a logger. all agents have a logger.
""" """
logger: logging.Logger logger: logging.Logger

View File

@@ -1,7 +1,7 @@
from .bdi_core_agent.bdi_core_agent import BDICoreAgent as BDICoreAgent from .bdi_core_agent.bdi_core_agent import BDICoreAgent as BDICoreAgent
from .belief_collector_agent.belief_collector_agent import ( from .belief_collector_agent import (
BDIBeliefCollectorAgent as BDIBeliefCollectorAgent, BDIBeliefCollectorAgent as BDIBeliefCollectorAgent,
) )
from .text_belief_extractor_agent.text_belief_extractor_agent import ( from .text_belief_extractor_agent import (
TextBeliefExtractorAgent as TextBeliefExtractorAgent, TextBeliefExtractorAgent as TextBeliefExtractorAgent,
) )

View File

@@ -1,67 +1,201 @@
import logging import asyncio
import copy
import time
from collections.abc import Iterable
import agentspeak import agentspeak
from spade.behaviour import OneShotBehaviour import agentspeak.runtime
from spade.message import Message import agentspeak.stdlib
from spade_bdi.bdi import BDIAgent from pydantic import ValidationError
from control_backend.agents.base import BaseAgent
from control_backend.core.agent_system import InternalMessage
from control_backend.core.config import settings from control_backend.core.config import settings
from control_backend.schemas.belief_message import BeliefMessage
from .behaviours.belief_setter_behaviour import BeliefSetterBehaviour from control_backend.schemas.ri_message import SpeechCommand
from .behaviours.receive_llm_resp_behaviour import ReceiveLLMResponseBehaviour
class BDICoreAgent(BDIAgent): class BDICoreAgent(BaseAgent):
""" bdi_agent: agentspeak.runtime.Agent
This is the Brain agent that does the belief inference with AgentSpeak.
This is a continous process that happens automatically in the background.
This class contains all the actions that can be called from AgentSpeak plans.
It has the BeliefSetter behaviour and can aks and recieve requests from the LLM agent.
"""
logger = logging.getLogger(__package__).getChild(__name__) def __init__(self, name: str, asl: str):
super().__init__(name)
self.asl_file = asl
self.env = agentspeak.runtime.Environment()
# Deep copy because we don't actually want to modify the standard actions globally
self.actions = copy.deepcopy(agentspeak.stdlib.actions)
self._wake_bdi_loop = asyncio.Event()
async def setup(self) -> None: async def setup(self) -> None:
""" self.logger.debug("Setup started.")
Initializes belief behaviors and message routing.
"""
self.logger.info("BDICoreAgent setup started.")
self.add_behaviour(BeliefSetterBehaviour()) self._add_custom_actions()
self.add_behaviour(ReceiveLLMResponseBehaviour())
self.logger.info("BDICoreAgent setup complete.") await self._load_asl()
def add_custom_actions(self, actions) -> None: # Start the BDI cycle loop
""" self.add_behavior(self._bdi_loop())
Registers custom AgentSpeak actions callable from plans. self._wake_bdi_loop.set()
""" self.logger.debug("Setup complete.")
@actions.add(".reply", 1) async def _load_asl(self):
def _reply(agent: "BDICoreAgent", term, intention): try:
""" with open(self.asl_file) as source:
Sends text to the LLM (AgentSpeak action). self.bdi_agent = self.env.build_agent(source, self.actions)
Example: .reply("Hello LLM!") except FileNotFoundError:
""" self.logger.warning(f"Could not find the specified ASL file at {self.asl_file}.")
message_text = agentspeak.grounded(term.args[0], intention.scope) self.bdi_agent = agentspeak.runtime.Agent(self.env, self.name)
self.logger.debug("Reply action sending: %s", message_text)
self._send_to_llm(str(message_text)) async def _bdi_loop(self):
yield
def _send_to_llm(self, text: str):
""" """
Sends a text query to the LLM Agent asynchronously. Runs the AgentSpeak BDI loop. Efficiently checks for when the next expected work will be.
""" """
while self._running:
await (
self._wake_bdi_loop.wait()
) # gets set whenever there's an update to the belief base
class SendBehaviour(OneShotBehaviour): # Agent knows when it's expected to have to do its next thing
async def run(self) -> None: maybe_more_work = True
msg = Message( while maybe_more_work:
to=settings.agent_settings.llm_name + "@" + settings.agent_settings.host, maybe_more_work = False
body=text, self.logger.debug("Stepping BDI.")
if self.bdi_agent.step():
maybe_more_work = True
if not maybe_more_work:
deadline = self.bdi_agent.shortest_deadline()
if deadline:
self.logger.debug("Sleeping until %s", deadline)
await asyncio.sleep(deadline - time.time())
maybe_more_work = True
else:
self._wake_bdi_loop.clear()
self.logger.debug("No more deadlines. Halting BDI loop.")
async def handle_message(self, msg: InternalMessage):
"""
Route incoming messages (Beliefs or LLM responses).
"""
sender = msg.sender
match sender:
case settings.agent_settings.bdi_belief_collector_name:
self.logger.debug("Processing message from belief collector.")
try:
if msg.thread == "beliefs":
beliefs = BeliefMessage.model_validate_json(msg.body).beliefs
self._add_beliefs(beliefs)
except ValidationError:
self.logger.exception("Error processing belief.")
case settings.agent_settings.llm_name:
content = msg.body
self.logger.info("Received LLM response: %s", content)
# Forward to Robot Speech Agent
cmd = SpeechCommand(data=content)
out_msg = InternalMessage(
to=settings.agent_settings.robot_speech_name,
sender=self.name,
body=cmd.model_dump_json(),
)
await self.send(out_msg)
def _add_beliefs(self, beliefs: dict[str, list[str]]):
if not beliefs:
return
for name, args in beliefs.items():
self._add_belief(name, args)
def _add_belief(self, name: str, args: Iterable[str] = []):
new_args = (agentspeak.Literal(arg) for arg in args)
term = agentspeak.Literal(name, new_args)
self.bdi_agent.call(
agentspeak.Trigger.addition,
agentspeak.GoalType.belief,
term,
agentspeak.runtime.Intention(),
) )
await self.send(msg) self._wake_bdi_loop.set()
self.agent.logger.info("Message sent to LLM agent: %s", text)
self.add_behaviour(SendBehaviour()) self.logger.debug(f"Added belief {self.format_belief_string(name, args)}")
def _remove_belief(self, name: str, args: Iterable[str]):
"""
Removes a specific belief (with arguments), if it exists.
"""
new_args = (agentspeak.Literal(arg) for arg in args)
term = agentspeak.Literal(name, new_args)
result = self.bdi_agent.call(
agentspeak.Trigger.removal,
agentspeak.GoalType.belief,
term,
agentspeak.runtime.Intention(),
)
if result:
self.logger.debug(f"Removed belief {self.format_belief_string(name, args)}")
self._wake_bdi_loop.set()
else:
self.logger.debug("Failed to remove belief (it was not in the belief base).")
# TODO: decide if this is needed
def _remove_all_with_name(self, name: str):
"""
Removes all beliefs that match the given `name`.
"""
relevant_groups = []
for key in self.bdi_agent.beliefs:
if key[0] == name:
relevant_groups.append(key)
removed_count = 0
for group in relevant_groups:
for belief in self.bdi_agent.beliefs[group]:
self.bdi_agent.call(
agentspeak.Trigger.removal,
agentspeak.GoalType.belief,
belief,
agentspeak.runtime.Intention(),
)
removed_count += 1
self._wake_bdi_loop.set()
self.logger.debug(f"Removed {removed_count} beliefs.")
def _add_custom_actions(self) -> None:
"""
Add any custom actions here. Inside `@self.actions.add()`, the first argument is
the name of the function in the ASL file, and the second the amount of arguments
the function expects (which will be located in `term.args`).
"""
@self.actions.add(".reply", 1)
def _reply(agent, term, intention):
"""
Sends text to the LLM.
"""
message_text = agentspeak.grounded(term.args[0], intention.scope)
asyncio.create_task(self._send_to_llm(str(message_text)))
yield
async def _send_to_llm(self, text: str):
"""
Sends a text query to the LLM agent asynchronously.
"""
msg = InternalMessage(to=settings.agent_settings.llm_name, sender=self.name, body=text)
await self.send(msg)
self.logger.info("Message sent to LLM agent: %s", text)
@staticmethod
def format_belief_string(name: str, args: Iterable[str] = []):
"""
Given a belief's name and its args, return a string of the form "name(*args)"
"""
return f"{name}{'(' if args else ''}{','.join(args)}{')' if args else ''}"

View File

@@ -1,85 +0,0 @@
import json
from spade.agent import Message
from spade.behaviour import CyclicBehaviour
from spade_bdi.bdi import BDIAgent
from control_backend.core.config import settings
class BeliefSetterBehaviour(CyclicBehaviour):
"""
This is the behaviour that the BDI agent runs. This behaviour waits for incoming
message and updates the agent's beliefs accordingly.
"""
agent: BDIAgent
async def run(self):
"""Polls for messages and processes them."""
msg = await self.receive()
self.agent.logger.debug(
"Received message from %s with thread '%s' and body: %s",
msg.sender,
msg.thread,
msg.body,
)
self._process_message(msg)
def _process_message(self, message: Message):
"""Routes the message to the correct processing function based on the sender."""
sender = message.sender.node # removes host from jid and converts to str
self.agent.logger.debug("Processing message from sender: %s", sender)
match sender:
case settings.agent_settings.bdi_belief_collector_name:
self.agent.logger.debug(
"Message is from the belief collector agent. Processing as belief message."
)
self._process_belief_message(message)
case _:
self.agent.logger.debug("Not the belief agent, discarding message")
pass
def _process_belief_message(self, message: Message):
if not message.body:
self.agent.logger.debug("Ignoring message with empty body from %s", message.sender.node)
return
match message.thread:
case "beliefs":
try:
beliefs: dict[str, list[str]] = json.loads(message.body)
self._set_beliefs(beliefs)
except json.JSONDecodeError:
self.agent.logger.error(
"Could not decode beliefs from JSON. Message body: '%s'",
message.body,
exc_info=True,
)
case _:
pass
def _set_beliefs(self, beliefs: dict[str, list[str]]):
"""Removes previous values for beliefs and updates them with the provided values."""
if self.agent.bdi is None:
self.agent.logger.warning("Cannot set beliefs; agent's BDI is not yet initialized.")
return
if not beliefs:
self.agent.logger.debug("Received an empty set of beliefs. No beliefs were updated.")
return
# Set new beliefs (outdated beliefs are automatically removed)
for belief, arguments in beliefs.items():
self.agent.logger.debug("Setting belief %s with arguments %s", belief, arguments)
self.agent.bdi.set_belief(belief, *arguments)
# Special case: if there's a new user message, flag that we haven't responded yet
if belief == "user_said":
self.agent.bdi.set_belief("new_message")
self.agent.logger.debug(
"Detected 'user_said' belief, also setting 'new_message' belief."
)
self.agent.logger.info("Successfully updated %d beliefs.", len(beliefs))

View File

@@ -1,37 +0,0 @@
from spade.behaviour import CyclicBehaviour
from spade.message import Message
from control_backend.core.config import settings
from control_backend.schemas.ri_message import SpeechCommand
class ReceiveLLMResponseBehaviour(CyclicBehaviour):
"""
Adds behavior to receive responses from the LLM Agent.
"""
async def run(self):
msg = await self.receive()
sender = msg.sender.node
match sender:
case settings.agent_settings.llm_name:
content = msg.body
self.agent.logger.info("Received LLM response: %s", content)
speech_command = SpeechCommand(data=content)
message = Message(
to=settings.agent_settings.robot_speech_name
+ "@"
+ settings.agent_settings.host,
sender=self.agent.jid,
body=speech_command.model_dump_json(),
)
self.agent.logger.debug("Sending message: %s", message)
await self.send(message)
case _:
self.agent.logger.debug("Discarding message from %s", sender)
pass

View File

@@ -1,3 +1,3 @@
+new_message : user_said(Message) <- +user_said(Message) <-
-new_message; -user_said(Message);
.reply(Message). .reply(Message).

View File

@@ -0,0 +1,89 @@
import json
from control_backend.agents.base import BaseAgent
from control_backend.core.agent_system import InternalMessage
from control_backend.core.config import settings
from control_backend.schemas.belief_message import BeliefMessage
class BDIBeliefCollectorAgent(BaseAgent):
"""
Continuously collects beliefs/emotions from extractor agents and forwards a
unified belief packet to the BDI agent.
"""
async def setup(self):
self.logger.info("Setting up %s", self.name)
async def handle_message(self, msg: InternalMessage):
sender_node = msg.sender
# Parse JSON payload
try:
payload = json.loads(msg.body)
except Exception as e:
self.logger.warning(
"BeliefCollector: failed to parse JSON from %s. Body=%r Error=%s",
sender_node,
msg.body,
e,
)
return
msg_type = payload.get("type")
# Prefer explicit 'type' field
if msg_type == "belief_extraction_text":
self.logger.debug("Message routed to _handle_belief_text (sender=%s)", sender_node)
await self._handle_belief_text(payload, sender_node)
# This is not implemented yet, but we keep the structure for future use
elif msg_type == "emotion_extraction_text":
self.logger.debug("Message routed to _handle_emo_text (sender=%s)", sender_node)
await self._handle_emo_text(payload, sender_node)
else:
self.logger.warning(
"Unrecognized message (sender=%s, type=%r). Ignoring.", sender_node, msg_type
)
async def _handle_belief_text(self, payload: dict, origin: str):
"""
Expected payload:
{
"type": "belief_extraction_text",
"beliefs": {"user_said": ["Can you help me?"]}
}
"""
beliefs = payload.get("beliefs", {})
if not beliefs:
self.logger.debug("Received empty beliefs set.")
return
self.logger.debug("Forwarding %d beliefs.", len(beliefs))
for belief_name, belief_list in beliefs.items():
for belief in belief_list:
self.logger.debug(" - %s %s", belief_name, str(belief))
await self._send_beliefs_to_bdi(beliefs, origin=origin)
async def _handle_emo_text(self, payload: dict, origin: str):
"""TODO: implement (after we have emotional recognition)"""
pass
async def _send_beliefs_to_bdi(self, beliefs: dict, origin: str | None = None):
"""
Sends a unified belief packet to the BDI agent.
"""
if not beliefs:
return
msg = InternalMessage(
to=settings.agent_settings.bdi_core_name,
sender=self.name,
body=BeliefMessage(beliefs=beliefs).model_dump_json(),
thread="beliefs",
)
await self.send(msg)
self.logger.info("Sent %d belief(s) to BDI core.", len(beliefs))

View File

@@ -1,92 +0,0 @@
import json
from json import JSONDecodeError
from spade.agent import Message
from spade.behaviour import CyclicBehaviour
from control_backend.core.config import settings
class BeliefCollectorBehaviour(CyclicBehaviour):
"""
Continuously collects beliefs/emotions from extractor agents:
Then we send a unified belief packet to the BDI agent.
"""
async def run(self):
msg = await self.receive()
await self._process_message(msg)
async def _process_message(self, msg: Message):
sender_node = msg.sender.node
# Parse JSON payload
try:
payload = json.loads(msg.body)
except JSONDecodeError as e:
self.agent.logger.warning(
"BeliefCollector: failed to parse JSON from %s. Body=%r Error=%s",
sender_node,
msg.body,
e,
)
return
msg_type = payload.get("type")
# Prefer explicit 'type' field
if msg_type == "belief_extraction_text" or sender_node == "bel_text_agent_mock":
self.agent.logger.debug(
"Message routed to _handle_belief_text (sender=%s)", sender_node
)
await self._handle_belief_text(payload, sender_node)
# This is not implemented yet, but we keep the structure for future use
elif msg_type == "emotion_extraction_text" or sender_node == "emo_text_agent_mock":
self.agent.logger.debug("Message routed to _handle_emo_text (sender=%s)", sender_node)
await self._handle_emo_text(payload, sender_node)
else:
self.agent.logger.warning(
"Unrecognized message (sender=%s, type=%r). Ignoring.", sender_node, msg_type
)
async def _handle_belief_text(self, payload: dict, origin: str):
"""
Expected payload:
{
"type": "belief_extraction_text",
"beliefs": {"user_said": ["Can you help me?"]}
}
"""
beliefs = payload.get("beliefs", {})
if not beliefs:
self.agent.logger.debug("Received empty beliefs set.")
return
self.agent.logger.debug("Forwarding %d beliefs.", len(beliefs))
for belief_name, belief_list in beliefs.items():
for belief in belief_list:
self.agent.logger.debug(" - %s %s", belief_name, str(belief))
await self._send_beliefs_to_bdi(beliefs, origin=origin)
async def _handle_emo_text(self, payload: dict, origin: str):
"""TODO: implement (after we have emotional recogntion)"""
pass
async def _send_beliefs_to_bdi(self, beliefs: list[str], origin: str | None = None):
"""
Sends a unified belief packet to the BDI agent.
"""
if not beliefs:
return
to_jid = f"{settings.agent_settings.bdi_core_name}@{settings.agent_settings.host}"
msg = Message(to=to_jid, sender=self.agent.jid, thread="beliefs")
msg.body = json.dumps(beliefs)
await self.send(msg)
self.agent.logger.info("Sent %d belief(s) to BDI core.", len(beliefs))

View File

@@ -1,11 +0,0 @@
from control_backend.agents.base import BaseAgent
from .behaviours.belief_collector_behaviour import BeliefCollectorBehaviour
class BDIBeliefCollectorAgent(BaseAgent):
async def setup(self):
self.logger.info("BDIBeliefCollectorAgent starting (%s)", self.jid)
# Attach the continuous collector behaviour (listens and forwards to BDI)
self.add_behaviour(BeliefCollectorBehaviour())
self.logger.info("BDIBeliefCollectorAgent ready.")

View File

@@ -0,0 +1,38 @@
import json
from control_backend.agents.base import BaseAgent
from control_backend.core.agent_system import InternalMessage
from control_backend.core.config import settings
class TextBeliefExtractorAgent(BaseAgent):
async def setup(self):
self.logger.info("Settting up %s.", self.name)
# Setup LLM belief context if needed (currently demo is just passthrough)
self.beliefs = {"mood": ["X"], "car": ["Y"]}
async def handle_message(self, msg: InternalMessage):
sender = msg.sender
if sender == settings.agent_settings.transcription_name:
self.logger.debug("Received text from transcriber: %s", msg.body)
await self._process_transcription_demo(msg.body)
else:
self.logger.info("Discarding message from %s", sender)
async def _process_transcription_demo(self, txt: str):
"""
Demo version to process the transcription input to beliefs.
"""
# For demo, just wrapping user text as user_said belief
belief = {"beliefs": {"user_said": [txt]}, "type": "belief_extraction_text"}
payload = json.dumps(belief)
belief_msg = InternalMessage(
to=settings.agent_settings.bdi_belief_collector_name,
sender=self.name,
body=payload,
thread="beliefs",
)
await self.send(belief_msg)
self.logger.info("Sent %d beliefs to the belief collector.", len(belief["beliefs"]))

View File

@@ -1,104 +0,0 @@
import json
import logging
from spade.behaviour import CyclicBehaviour
from spade.message import Message
from control_backend.core.config import settings
class TextBeliefExtractorBehaviour(CyclicBehaviour):
logger = logging.getLogger(__name__)
# TODO: LLM prompt nog hardcoded
llm_instruction_prompt = """
You are an information extraction assistent for a BDI agent. Your task is to extract values \
from a user's text to bind a list of ungrounded beliefs. Rules:
You will receive a JSON object with "beliefs" (a list of ungrounded AgentSpeak beliefs) \
and "text" (user's transcript).
Analyze the text to find values that sematically match the variables (X,Y,Z) in the beliefs.
A single piece of text might contain multiple instances that match a belief.
Respond ONLY with a single JSON object.
The JSON object's keys should be the belief functors (e.g., "weather").
The value for each key must be a list of lists.
Each inner list must contain the extracted arguments (as strings) for one instance \
of that belief.
CRITICAL: If no information in the text matches a belief, DO NOT include that key \
in your response.
"""
# on_start agent receives message containing the beliefs to look out for and
# sets up the LLM with instruction prompt
# async def on_start(self):
# msg = await self.receive(timeout=0.1)
# self.beliefs = dict uit message
# send instruction prompt to LLM
beliefs: dict[str, list[str]]
beliefs = {"mood": ["X"], "car": ["Y"]}
async def run(self):
msg = await self.receive()
if msg is None:
return
sender = msg.sender.node
match sender:
case settings.agent_settings.transcription_name:
self.logger.debug("Received text from transcriber: %s", msg.body)
await self._process_transcription_demo(msg.body)
case _:
self.logger.info("Discarding message from %s", sender)
pass
async def _process_transcription(self, text: str):
text_prompt = f"Text: {text}"
beliefs_prompt = "These are the beliefs to be bound:\n"
for belief, values in self.beliefs.items():
beliefs_prompt += f"{belief}({', '.join(values)})\n"
prompt = text_prompt + beliefs_prompt
self.logger.info(prompt)
# prompt_msg = Message(to="LLMAgent@whatever")
# response = self.send(prompt_msg)
# Mock response; response is beliefs in JSON format, it parses do dict[str,list[list[str]]]
response = '{"mood": [["happy"]]}'
# Verify by trying to parse
try:
json.loads(response)
belief_message = Message()
belief_message.to = (
settings.agent_settings.bdi_belief_collector_name
+ "@"
+ settings.agent_settings.host
)
belief_message.body = response
belief_message.thread = "beliefs"
await self.send(belief_message)
self.agent.logger.info("Sent beliefs to BDI.")
except json.JSONDecodeError:
# Parsing failed, so the response is in the wrong format, log warning
self.agent.logger.warning("Received LLM response in incorrect format.")
async def _process_transcription_demo(self, txt: str):
"""
Demo version to process the transcription input to beliefs. For the demo only the belief
'user_said' is relevant, so this function simply makes a dict with key: "user_said",
value: txt and passes this to the Belief Collector agent.
"""
belief = {"beliefs": {"user_said": [txt]}, "type": "belief_extraction_text"}
payload = json.dumps(belief)
belief_msg = Message()
belief_msg.to = (
settings.agent_settings.bdi_belief_collector_name + "@" + settings.agent_settings.host
)
belief_msg.body = payload
belief_msg.thread = "beliefs"
await self.send(belief_msg)
self.logger.info("Sent %d beliefs to the belief collector.", len(belief["beliefs"]))

View File

@@ -1,8 +0,0 @@
from control_backend.agents.base import BaseAgent
from .behaviours.text_belief_extractor_behaviour import TextBeliefExtractorBehaviour
class TextBeliefExtractorAgent(BaseAgent):
async def setup(self):
self.add_behaviour(TextBeliefExtractorBehaviour())

View File

@@ -1,8 +1,8 @@
import asyncio import asyncio
import json import json
import zmq.asyncio import zmq
from spade.behaviour import CyclicBehaviour import zmq.asyncio as azmq
from zmq.asyncio import Context from zmq.asyncio import Context
from control_backend.agents import BaseAgent from control_backend.agents import BaseAgent
@@ -12,109 +12,38 @@ from ..actuation.robot_speech_agent import RobotSpeechAgent
class RICommunicationAgent(BaseAgent): class RICommunicationAgent(BaseAgent):
req_socket: zmq.Socket
_address = ""
_bind = True
connected = False
def __init__( def __init__(
self, self,
jid: str, name: str,
password: str,
port: int = settings.agent_settings.default_spade_port,
verify_security: bool = False,
address=settings.zmq_settings.ri_command_address, address=settings.zmq_settings.ri_command_address,
bind=False, bind=False,
): ):
super().__init__(jid, password, port, verify_security) super().__init__(name)
self._address = address self._address = address
self._bind = bind self._bind = bind
self._req_socket: zmq.asyncio.Socket | None = None self._req_socket: azmq.Socket | None = None
self.pub_socket: zmq.asyncio.Socket | None = None self.pub_socket: azmq.Socket | None = None
self.connected = False
class ListenBehaviour(CyclicBehaviour): async def setup(self):
async def run(self):
""" """
Run the listening (ping) loop indefinetely. Try to set up the communication agent, we have `behaviour_settings.comm_setup_max_retries`
retries in case we don't have a response yet.
""" """
assert self.agent is not None self.logger.info("Setting up %s", self.name)
if not self.agent.connected: # Bind request socket
await asyncio.sleep(settings.behaviour_settings.sleep_s) await self._setup_sockets()
return
# We need to listen and sent pings. if await self._negotiate_connection():
message = {"endpoint": "ping", "data": {"id": "e.g. some reference id"}} self.connected = True
seconds_to_wait_total = settings.behaviour_settings.sleep_s self.add_behavior(self._listen_loop())
try:
await asyncio.wait_for(
self.agent._req_socket.send_json(message), timeout=seconds_to_wait_total / 2
)
except TimeoutError:
self.agent.logger.debug(
"Waited too long to send message - "
"we probably dont have any receivers... but let's check!"
)
# Wait up to {seconds_to_wait_total/2} seconds for a reply
try:
message = await asyncio.wait_for(
self.agent._req_socket.recv_json(), timeout=seconds_to_wait_total / 2
)
# We didnt get a reply
except TimeoutError:
self.agent.logger.info(
f"No ping retrieved in {seconds_to_wait_total} seconds, "
"sending UI disconnection event and attempting to restart."
)
# Make sure we dont retry receiving messages untill we're setup.
self.agent.connected = False
self.agent.remove_behaviour(self)
# Tell UI we're disconnected.
topic = b"ping"
data = json.dumps(False).encode()
if self.agent.pub_socket is None:
self.agent.logger.warning(
"Communication agent pub socket not correctly initialized."
)
else: else:
try: self.logger.warning("Failed to negotiate connection during setup.")
await asyncio.wait_for(
self.agent.pub_socket.send_multipart([topic, data]), 5
)
except TimeoutError:
self.agent.logger.warning(
f"Initial connection ping for router timed out in {self.agent.name}."
)
# Try to reboot. self.logger.info("Finished setting up %s", self.name)
self.agent.logger.debug("Restarting communication agent.")
await self.agent.setup()
self.agent.logger.debug(f'Received message "{message}" from RI.') async def _setup_sockets(self, force=False):
if "endpoint" not in message:
self.agent.logger.warning(
"No received endpoint in message, expected ping endpoint."
)
return
# See what endpoint we received
match message["endpoint"]:
case "ping":
topic = b"ping"
data = json.dumps(True).encode()
if self.agent.pub_socket is not None:
await self.agent.pub_socket.send_multipart([topic, data])
await asyncio.sleep(settings.behaviour_settings.sleep_s)
case _:
self.agent.logger.debug(
"Received message with topic different than ping, while ping expected."
)
async def setup_sockets(self, force=False):
""" """
Sets up request socket for communication agent. Sets up request socket for communication agent.
""" """
@@ -130,21 +59,13 @@ class RICommunicationAgent(BaseAgent):
self.pub_socket = Context.instance().socket(zmq.PUB) self.pub_socket = Context.instance().socket(zmq.PUB)
self.pub_socket.connect(settings.zmq_settings.internal_pub_address) self.pub_socket.connect(settings.zmq_settings.internal_pub_address)
async def setup(self, max_retries: int = settings.behaviour_settings.comm_setup_max_retries): async def _negotiate_connection(
""" self, max_retries: int = settings.behaviour_settings.comm_setup_max_retries
Try to set up the communication agent, we have `behaviour_settings.comm_setup_max_retries` ):
retries in case we don't have a response yet.
"""
self.logger.info("Setting up %s", self.jid)
# Bind request socket
await self.setup_sockets()
retries = 0 retries = 0
# Let's try a certain amount of times before failing connection
while retries < max_retries: while retries < max_retries:
# Make sure the socket is properly setup.
if self._req_socket is None: if self._req_socket is None:
retries += 1
continue continue
# Send our message and receive one back # Send our message and receive one back
@@ -156,7 +77,6 @@ class RICommunicationAgent(BaseAgent):
received_message = await asyncio.wait_for( received_message = await asyncio.wait_for(
self._req_socket.recv_json(), timeout=retry_frequency self._req_socket.recv_json(), timeout=retry_frequency
) )
except TimeoutError: except TimeoutError:
self.logger.warning( self.logger.warning(
"No connection established in %d seconds (attempt %d/%d)", "No connection established in %d seconds (attempt %d/%d)",
@@ -166,7 +86,6 @@ class RICommunicationAgent(BaseAgent):
) )
retries += 1 retries += 1
continue continue
except Exception as e: except Exception as e:
self.logger.warning("Unexpected error during negotiation: %s", e) self.logger.warning("Unexpected error during negotiation: %s", e)
retries += 1 retries += 1
@@ -187,6 +106,22 @@ class RICommunicationAgent(BaseAgent):
# At this point, we have a valid response # At this point, we have a valid response
try: try:
await self._handle_negotiation_response(received_message)
# Let UI know that we're connected
topic = b"ping"
data = json.dumps(True).encode()
if self.pub_socket:
await self.pub_socket.send_multipart([topic, data])
return True
except Exception as e:
self.logger.warning("Error unpacking negotiation data: %s", e)
retries += 1
await asyncio.sleep(settings.behaviour_settings.sleep_s)
continue
return False
async def _handle_negotiation_response(self, received_message):
for port_data in received_message["data"]: for port_data in received_message["data"]:
id = port_data["id"] id = port_data["id"]
port = port_data["port"] port = port_data["port"]
@@ -200,15 +135,13 @@ class RICommunicationAgent(BaseAgent):
match id: match id:
case "main": case "main":
if addr != self._address: if addr != self._address:
assert self._req_socket is not None
if not bind: if not bind:
self._req_socket.connect(addr) self._req_socket.connect(addr)
else: else:
self._req_socket.bind(addr) self._req_socket.bind(addr)
case "actuation": case "actuation":
ri_commands_agent = RobotSpeechAgent( ri_commands_agent = RobotSpeechAgent(
settings.agent_settings.robot_speech_name
+ "@"
+ settings.agent_settings.host,
settings.agent_settings.robot_speech_name, settings.agent_settings.robot_speech_name,
address=addr, address=addr,
bind=bind, bind=bind,
@@ -217,34 +150,85 @@ class RICommunicationAgent(BaseAgent):
case _: case _:
self.logger.warning("Unhandled negotiation id: %s", id) self.logger.warning("Unhandled negotiation id: %s", id)
except Exception as e: async def stop(self):
self.logger.warning("Error unpacking negotiation data: %s", e) if self._req_socket:
retries += 1 self._req_socket.close()
await asyncio.sleep(1) if self.pub_socket:
self.pub_socket.close()
await super().stop()
async def _listen_loop(self):
"""
Run the listening (ping) loop indefinitely.
"""
while self._running:
if not self.connected:
await asyncio.sleep(settings.behaviour_settings.sleep_s)
continue continue
# setup succeeded # We need to listen and send pings.
break message = {"endpoint": "ping", "data": {"id": "e.g. some reference id"}}
seconds_to_wait_total = settings.behaviour_settings.sleep_s
try:
assert self._req_socket is not None
await asyncio.wait_for(
self._req_socket.send_json(message), timeout=seconds_to_wait_total / 2
)
except TimeoutError:
self.logger.debug(
"Waited too long to send message - "
"we probably dont have any receivers... but let's check!"
)
else: # Wait up to {seconds_to_wait_total/2} seconds for a reply
self.logger.warning("Failed to set up %s after %d retries", self.name, max_retries) try:
return assert self._req_socket is not None
message = await asyncio.wait_for(
self._req_socket.recv_json(), timeout=seconds_to_wait_total / 2
)
# Set up ping behaviour self.logger.debug(f'Received message "{message}" from RI.')
listen_behaviour = self.ListenBehaviour() if "endpoint" not in message:
self.add_behaviour(listen_behaviour) self.logger.warning("No received endpoint in message, expected ping endpoint.")
continue
# Let UI know that we're connected # See what endpoint we received
match message["endpoint"]:
case "ping":
topic = b"ping" topic = b"ping"
data = json.dumps(True).encode() data = json.dumps(True).encode()
if self.pub_socket is None: if self.pub_socket is not None:
self.logger.warning("Communication agent pub socket not correctly initialized.") await self.pub_socket.send_multipart([topic, data])
else: await asyncio.sleep(settings.behaviour_settings.sleep_s)
case _:
self.logger.debug(
"Received message with topic different than ping, while ping expected."
)
# We didnt get a reply
except TimeoutError:
self.logger.info(
f"No ping retrieved in {seconds_to_wait_total} seconds, "
"sending UI disconnection event and attempting to restart."
)
await self._handle_disconnection()
continue
except Exception:
self.logger.error("Error while waiting for ping message.", exc_info=True)
raise
async def _handle_disconnection(self):
self.connected = False
# Tell UI we're disconnected.
topic = b"ping"
data = json.dumps(False).encode()
if self.pub_socket:
try: try:
await asyncio.wait_for(self.pub_socket.send_multipart([topic, data]), 5) await asyncio.wait_for(self.pub_socket.send_multipart([topic, data]), 5)
except TimeoutError: except TimeoutError:
self.logger.warning("Initial connection ping for router timed out in com_ri_agent.") self.logger.warning("Connection ping for router timed out.")
# Make sure to start listening now that we're connected. # Try to reboot/renegotiate
self.logger.debug("Restarting communication negotiation.")
if await self._negotiate_connection(max_retries=1):
self.connected = True self.connected = True
self.logger.info("Finished setting up %s", self.jid)

View File

@@ -3,10 +3,9 @@ import re
from collections.abc import AsyncGenerator from collections.abc import AsyncGenerator
import httpx import httpx
from spade.behaviour import CyclicBehaviour
from spade.message import Message
from control_backend.agents import BaseAgent from control_backend.agents import BaseAgent
from control_backend.core.agent_system import InternalMessage
from control_backend.core.config import settings from control_backend.core.config import settings
from .llm_instructions import LLMInstructions from .llm_instructions import LLMInstructions
@@ -19,51 +18,31 @@ class LLMAgent(BaseAgent):
and responds with processed LLM output. and responds with processed LLM output.
""" """
class ReceiveMessageBehaviour(CyclicBehaviour): async def setup(self):
""" self.logger.info("Setting up %s.", self.name)
Cyclic behaviour to continuously listen for incoming messages from
the BDI Core Agent and handle them.
"""
async def run(self): async def handle_message(self, msg: InternalMessage):
""" if msg.sender == settings.agent_settings.bdi_core_name:
Receives SPADE messages and processes only those originating from the self.logger.debug("Processing message from BDI core.")
configured BDI agent.
"""
msg = await self.receive()
sender = msg.sender.node
self.agent.logger.debug(
"Received message: %s from %s",
msg.body,
sender,
)
if sender == settings.agent_settings.bdi_core_name:
self.agent.logger.debug("Processing message from BDI Core Agent")
await self._process_bdi_message(msg) await self._process_bdi_message(msg)
else: else:
self.agent.logger.debug("Message ignored (not from BDI Core Agent)") self.logger.debug("Message ignored (not from BDI core.")
async def _process_bdi_message(self, message: Message): async def _process_bdi_message(self, message: InternalMessage):
"""
Forwards user text from the BDI to the LLM and replies with the generated text in chunks
separated by punctuation.
"""
user_text = message.body user_text = message.body
# Consume the streaming generator and send a reply for every chunk
async for chunk in self._query_llm(user_text): async for chunk in self._query_llm(user_text):
await self._reply(chunk) await self._send_reply(chunk)
self.agent.logger.debug( self.logger.debug(
"Finished processing BDI message. Response sent in chunks to BDI Core Agent." "Finished processing BDI message. Response sent in chunks to BDI core."
) )
async def _reply(self, msg: str): async def _send_reply(self, msg: str):
""" """
Sends a response message back to the BDI Core Agent. Sends a response message back to the BDI Core Agent.
""" """
reply = Message( reply = InternalMessage(
to=settings.agent_settings.bdi_core_name + "@" + settings.agent_settings.host, to=settings.agent_settings.bdi_core_name,
sender=self.name,
body=msg, body=msg,
) )
await self.send(reply) await self.send(reply)
@@ -114,10 +93,10 @@ class LLMAgent(BaseAgent):
if current_chunk: if current_chunk:
yield current_chunk yield current_chunk
except httpx.HTTPError as err: except httpx.HTTPError as err:
self.agent.logger.error("HTTP error.", exc_info=err) self.logger.error("HTTP error.", exc_info=err)
yield "LLM service unavailable." yield "LLM service unavailable."
except Exception as err: except Exception as err:
self.agent.logger.error("Unexpected error.", exc_info=err) self.logger.error("Unexpected error.", exc_info=err)
yield "Error processing the request." yield "Error processing the request."
async def _stream_query_llm(self, messages) -> AsyncGenerator[str]: async def _stream_query_llm(self, messages) -> AsyncGenerator[str]:
@@ -149,13 +128,4 @@ class LLMAgent(BaseAgent):
if delta: if delta:
yield delta yield delta
except json.JSONDecodeError: except json.JSONDecodeError:
self.agent.logger.error("Failed to parse LLM response: %s", data) self.logger.error("Failed to parse LLM response: %s", data)
async def setup(self):
"""
Sets up the SPADE behaviour to filter and process messages from the
BDI Core Agent.
"""
behaviour = self.ReceiveMessageBehaviour()
self.add_behaviour(behaviour)
self.logger.info("LLMAgent setup complete")

View File

@@ -3,10 +3,9 @@ import asyncio
import numpy as np import numpy as np
import zmq import zmq
import zmq.asyncio as azmq import zmq.asyncio as azmq
from spade.behaviour import CyclicBehaviour
from spade.message import Message
from control_backend.agents import BaseAgent from control_backend.agents import BaseAgent
from control_backend.core.agent_system import InternalMessage
from control_backend.core.config import settings from control_backend.core.config import settings
from .speech_recognizer import SpeechRecognizer from .speech_recognizer import SpeechRecognizer
@@ -19,53 +18,31 @@ class TranscriptionAgent(BaseAgent):
""" """
def __init__(self, audio_in_address: str): def __init__(self, audio_in_address: str):
jid = settings.agent_settings.transcription_name + "@" + settings.agent_settings.host super().__init__(settings.agent_settings.transcription_name)
super().__init__(jid, settings.agent_settings.transcription_name)
self.audio_in_address = audio_in_address self.audio_in_address = audio_in_address
self.audio_in_socket: azmq.Socket | None = None self.audio_in_socket: azmq.Socket | None = None
self.speech_recognizer = None
self._concurrency = None
class TranscribingBehaviour(CyclicBehaviour): async def setup(self):
def __init__(self, audio_in_socket: azmq.Socket): self.logger.info("Setting up %s", self.name)
super().__init__()
self._connect_audio_in_socket()
# Initialize recognizer and semaphore
max_concurrent_tasks = settings.behaviour_settings.transcription_max_concurrent_tasks max_concurrent_tasks = settings.behaviour_settings.transcription_max_concurrent_tasks
self.audio_in_socket = audio_in_socket
self.speech_recognizer = SpeechRecognizer.best_type()
self._concurrency = asyncio.Semaphore(max_concurrent_tasks) self._concurrency = asyncio.Semaphore(max_concurrent_tasks)
self.speech_recognizer = SpeechRecognizer.best_type()
self.speech_recognizer.load_model() # Warmup
def warmup(self): # Start background loop
"""Load the transcription model into memory to speed up the first transcription.""" self.add_behavior(self._transcribing_loop())
self.speech_recognizer.load_model()
async def _transcribe(self, audio: np.ndarray) -> str: self.logger.info("Finished setting up %s", self.name)
async with self._concurrency:
return await asyncio.to_thread(self.speech_recognizer.recognize_speech, audio)
async def _share_transcription(self, transcription: str):
"""Share a transcription to the other agents that depend on it."""
receiver_jids = [
settings.agent_settings.text_belief_extractor_name
+ "@"
+ settings.agent_settings.host,
] # Set message receivers here
for receiver_jid in receiver_jids:
message = Message(to=receiver_jid, body=transcription)
await self.send(message)
async def run(self) -> None:
audio = await self.audio_in_socket.recv()
audio = np.frombuffer(audio, dtype=np.float32)
speech = await self._transcribe(audio)
if not speech:
self.agent.logger.info("Nothing transcribed.")
return
self.agent.logger.info("Transcribed speech: %s", speech)
await self._share_transcription(speech)
async def stop(self): async def stop(self):
assert self.audio_in_socket is not None
self.audio_in_socket.close() self.audio_in_socket.close()
self.audio_in_socket = None self.audio_in_socket = None
return await super().stop() return await super().stop()
@@ -75,13 +52,37 @@ class TranscriptionAgent(BaseAgent):
self.audio_in_socket.setsockopt_string(zmq.SUBSCRIBE, "") self.audio_in_socket.setsockopt_string(zmq.SUBSCRIBE, "")
self.audio_in_socket.connect(self.audio_in_address) self.audio_in_socket.connect(self.audio_in_address)
async def setup(self): async def _transcribe(self, audio: np.ndarray) -> str:
self.logger.info("Setting up %s", self.jid) assert self._concurrency is not None and self.speech_recognizer is not None
async with self._concurrency:
return await asyncio.to_thread(self.speech_recognizer.recognize_speech, audio)
self._connect_audio_in_socket() async def _share_transcription(self, transcription: str):
"""Share a transcription to the other agents that depend on it."""
receiver_names = [
settings.agent_settings.text_belief_extractor_name,
]
transcribing = self.TranscribingBehaviour(self.audio_in_socket) for receiver_name in receiver_names:
transcribing.warmup() message = InternalMessage(
self.add_behaviour(transcribing) to=receiver_name,
sender=self.name,
body=transcription,
)
await self.send(message)
self.logger.info("Finished setting up %s", self.jid) async def _transcribing_loop(self) -> None:
while self._running:
try:
assert self.audio_in_socket is not None
audio_data = await self.audio_in_socket.recv()
audio = np.frombuffer(audio_data, dtype=np.float32)
speech = await self._transcribe(audio)
if not speech:
self.logger.info("Nothing transcribed.")
continue
self.logger.info("Transcribed speech: %s", speech)
await self._share_transcription(speech)
except Exception as e:
self.logger.error(f"Error in transcription loop: {e}")

View File

@@ -1,8 +1,9 @@
import asyncio
import numpy as np import numpy as np
import torch import torch
import zmq import zmq
import zmq.asyncio as azmq import zmq.asyncio as azmq
from spade.behaviour import CyclicBehaviour
from control_backend.agents import BaseAgent from control_backend.agents import BaseAgent
from control_backend.core.config import settings from control_backend.core.config import settings
@@ -26,7 +27,7 @@ class SocketPoller[T]:
:param timeout_ms: A timeout in milliseconds to wait for data. :param timeout_ms: A timeout in milliseconds to wait for data.
""" """
self.socket = socket self.socket = socket
self.poller = zmq.Poller() self.poller = azmq.Poller()
self.poller.register(self.socket, zmq.POLLIN) self.poller.register(self.socket, zmq.POLLIN)
self.timeout_ms = timeout_ms self.timeout_ms = timeout_ms
@@ -38,81 +39,12 @@ class SocketPoller[T]:
:return: Data from the socket or None. :return: Data from the socket or None.
""" """
timeout_ms = timeout_ms or self.timeout_ms timeout_ms = timeout_ms or self.timeout_ms
socks = dict(self.poller.poll(timeout_ms)) socks = dict(await self.poller.poll(timeout_ms))
if socks.get(self.socket) == zmq.POLLIN: if socks.get(self.socket) == zmq.POLLIN:
return await self.socket.recv() return await self.socket.recv()
return None return None
class StreamingBehaviour(CyclicBehaviour):
def __init__(self, audio_in_socket: azmq.Socket, audio_out_socket: azmq.Socket):
super().__init__()
self.audio_in_poller = SocketPoller[bytes](audio_in_socket)
self.model, _ = torch.hub.load(
repo_or_dir=settings.vad_settings.repo_or_dir,
model=settings.vad_settings.model_name,
force_reload=False,
)
self.audio_out_socket = audio_out_socket
self.audio_buffer = np.array([], dtype=np.float32)
self.i_since_speech = (
settings.behaviour_settings.vad_initial_since_speech
) # Used to allow small pauses in speech
self._ready = False
async def reset(self):
"""Clears the ZeroMQ queue and tells this behavior to start."""
discarded = 0
# Poll for the shortest amount of time possible to clear the queue
while await self.audio_in_poller.poll(1) is not None:
discarded += 1
self.agent.logger.info(f"Discarded {discarded} audio packets before starting.")
self._ready = True
async def run(self) -> None:
if not self._ready:
return
data = await self.audio_in_poller.poll()
if data is None:
if len(self.audio_buffer) > 0:
self.agent.logger.debug(
"No audio data received. Discarding buffer until new data arrives."
)
self.audio_buffer = np.array([], dtype=np.float32)
self.i_since_speech = settings.behaviour_settings.vad_initial_since_speech
return
# copy otherwise Torch will be sad that it's immutable
chunk = np.frombuffer(data, dtype=np.float32).copy()
prob = self.model(torch.from_numpy(chunk), settings.vad_settings.sample_rate_hz).item()
non_speech_patience = settings.behaviour_settings.vad_non_speech_patience_chunks
prob_threshold = settings.behaviour_settings.vad_prob_threshold
if prob > prob_threshold:
if self.i_since_speech > non_speech_patience:
self.agent.logger.debug("Speech started.")
self.audio_buffer = np.append(self.audio_buffer, chunk)
self.i_since_speech = 0
return
self.i_since_speech += 1
# prob < 0.5, so speech maybe ended. Wait a bit more before to be more certain
if self.i_since_speech <= non_speech_patience:
self.audio_buffer = np.append(self.audio_buffer, chunk)
return
# Speech probably ended. Make sure we have a usable amount of data.
if len(self.audio_buffer) >= 3 * len(chunk):
self.agent.logger.debug("Speech ended.")
await self.audio_out_socket.send(self.audio_buffer[: -2 * len(chunk)].tobytes())
# At this point, we know that the speech has ended.
# Prepend the last chunk that had no speech, for a more fluent boundary
self.audio_buffer = chunk
class VADAgent(BaseAgent): class VADAgent(BaseAgent):
""" """
An agent which listens to an audio stream, does Voice Activity Detection (VAD), and sends An agent which listens to an audio stream, does Voice Activity Detection (VAD), and sends
@@ -120,16 +52,54 @@ class VADAgent(BaseAgent):
""" """
def __init__(self, audio_in_address: str, audio_in_bind: bool): def __init__(self, audio_in_address: str, audio_in_bind: bool):
jid = settings.agent_settings.vad_name + "@" + settings.agent_settings.host super().__init__(settings.agent_settings.vad_name)
super().__init__(jid, settings.agent_settings.vad_name)
self.audio_in_address = audio_in_address self.audio_in_address = audio_in_address
self.audio_in_bind = audio_in_bind self.audio_in_bind = audio_in_bind
self.audio_in_socket: azmq.Socket | None = None self.audio_in_socket: azmq.Socket | None = None
self.audio_out_socket: azmq.Socket | None = None self.audio_out_socket: azmq.Socket | None = None
self.audio_in_poller: SocketPoller | None = None
self.streaming_behaviour: StreamingBehaviour | None = None self.audio_buffer = np.array([], dtype=np.float32)
self.i_since_speech = settings.behaviour_settings.vad_initial_since_speech
self._ready = asyncio.Event()
self.model = None
async def setup(self):
self.logger.info("Setting up %s", self.name)
self._connect_audio_in_socket()
audio_out_port = self._connect_audio_out_socket()
if audio_out_port is None:
self.logger.error("Could not bind output socket, stopping.")
await self.stop()
return
audio_out_address = f"tcp://localhost:{audio_out_port}"
# Initialize VAD model
try:
self.model, _ = torch.hub.load(
repo_or_dir=settings.vad_settings.repo_or_dir,
model=settings.vad_settings.model_name,
force_reload=False,
)
except Exception:
self.logger.exception("Failed to load VAD model.")
await self.stop()
return
# Warmup/reset
await self.reset_stream()
self.add_behavior(self._streaming_loop())
# Start agents dependent on the output audio fragments here
transcriber = TranscriptionAgent(audio_out_address)
await transcriber.start()
self.logger.info("Finished setting up %s", self.name)
async def stop(self): async def stop(self):
""" """
@@ -141,7 +111,7 @@ class VADAgent(BaseAgent):
if self.audio_out_socket is not None: if self.audio_out_socket is not None:
self.audio_out_socket.close() self.audio_out_socket.close()
self.audio_out_socket = None self.audio_out_socket = None
return await super().stop() await super().stop()
def _connect_audio_in_socket(self): def _connect_audio_in_socket(self):
self.audio_in_socket = azmq.Context.instance().socket(zmq.SUB) self.audio_in_socket = azmq.Context.instance().socket(zmq.SUB)
@@ -156,28 +126,64 @@ class VADAgent(BaseAgent):
"""Returns the port bound, or None if binding failed.""" """Returns the port bound, or None if binding failed."""
try: try:
self.audio_out_socket = azmq.Context.instance().socket(zmq.PUB) self.audio_out_socket = azmq.Context.instance().socket(zmq.PUB)
return self.audio_out_socket.bind_to_random_port("tcp://*", max_tries=100) return self.audio_out_socket.bind_to_random_port("tcp://localhost", max_tries=100)
except zmq.ZMQBindError: except zmq.ZMQBindError:
self.logger.error("Failed to bind an audio output socket after 100 tries.") self.logger.error("Failed to bind an audio output socket after 100 tries.")
self.audio_out_socket = None self.audio_out_socket = None
return None return None
async def setup(self): async def reset_stream(self):
self.logger.info("Setting up %s", self.jid) """
Clears the ZeroMQ queue and sets ready state.
"""
discarded = 0
assert self.audio_in_poller is not None
while await self.audio_in_poller.poll(1) is not None:
discarded += 1
self.logger.info(f"Discarded {discarded} audio packets before starting.")
self._ready.set()
self._connect_audio_in_socket() async def _streaming_loop(self):
await self._ready.wait()
while self._running:
assert self.audio_in_poller is not None
data = await self.audio_in_poller.poll()
if data is None:
if len(self.audio_buffer) > 0:
self.logger.debug(
"No audio data received. Discarding buffer until new data arrives."
)
self.audio_buffer = np.array([], dtype=np.float32)
self.i_since_speech = settings.behaviour_settings.vad_initial_since_speech
continue
audio_out_port = self._connect_audio_out_socket() # copy otherwise Torch will be sad that it's immutable
if audio_out_port is None: chunk = np.frombuffer(data, dtype=np.float32).copy()
await self.stop() assert self.model is not None
return prob = self.model(torch.from_numpy(chunk), settings.vad_settings.sample_rate_hz).item()
audio_out_address = f"tcp://localhost:{audio_out_port}" non_speech_patience = settings.behaviour_settings.vad_non_speech_patience_chunks
prob_threshold = settings.behaviour_settings.vad_prob_threshold
self.streaming_behaviour = StreamingBehaviour(self.audio_in_socket, self.audio_out_socket) if prob > prob_threshold:
self.add_behaviour(self.streaming_behaviour) if self.i_since_speech > non_speech_patience:
self.logger.debug("Speech started.")
self.audio_buffer = np.append(self.audio_buffer, chunk)
self.i_since_speech = 0
continue
# Start agents dependent on the output audio fragments here self.i_since_speech += 1
transcriber = TranscriptionAgent(audio_out_address)
await transcriber.start()
self.logger.info("Finished setting up %s", self.jid) # prob < threshold, so speech maybe ended. Wait a bit more before to be more certain
if self.i_since_speech <= non_speech_patience:
self.audio_buffer = np.append(self.audio_buffer, chunk)
continue
# Speech probably ended. Make sure we have a usable amount of data.
if len(self.audio_buffer) >= 3 * len(chunk):
self.logger.debug("Speech ended.")
assert self.audio_out_socket is not None
await self.audio_out_socket.send(self.audio_buffer[: -2 * len(chunk)].tobytes())
# At this point, we know that the speech has ended.
# Prepend the last chunk that had no speech, for a more fluent boundary
self.audio_buffer = chunk

View File

@@ -0,0 +1,142 @@
import asyncio
import logging
from abc import ABC, abstractmethod
from asyncio import Task
from collections.abc import Coroutine
import zmq
import zmq.asyncio as azmq
from control_backend.core.config import settings
from control_backend.schemas.internal_message import InternalMessage
# Central directory to resolve agent names to instances
_agent_directory: dict[str, "BaseAgent"] = {}
class AgentDirectory:
"""
Helper class to keep track of which agents are registered.
Used for handling message routing.
"""
@staticmethod
def register(name: str, agent: "BaseAgent"):
_agent_directory[name] = agent
@staticmethod
def get(name: str) -> "BaseAgent | None":
return _agent_directory.get(name)
class BaseAgent(ABC):
"""
Abstract base class for all agents. To make a new agent, inherit from
`control_backend.agents.BaseAgent`, not this class. That ensures that a
logger is present with the correct name pattern.
When subclassing, the `setup()` method needs to be overwritten. To handle
messages from other agents, overwrite the `handle_message()` method. To
send messages to other agents, use the `send()` method. To add custom
behaviors/tasks to the agent, use the `add_background_task()` method.
"""
logger: logging.Logger
def __init__(self, name: str):
self.name = name
self.inbox: asyncio.Queue[InternalMessage] = asyncio.Queue()
self._tasks: set[asyncio.Task] = set()
self._running = False
# Register immediately
AgentDirectory.register(name, self)
@abstractmethod
async def setup(self):
"""Overwrite this to initialize resources."""
pass
async def start(self):
"""Starts the agent and its loops."""
self.logger.info(f"Starting agent {self.name}")
self._running = True
context = azmq.Context.instance()
# Setup the internal publishing socket
self._internal_pub_socket = context.socket(zmq.PUB)
self._internal_pub_socket.connect(settings.zmq_settings.internal_pub_address)
# Setup the internal receiving socket
self._internal_sub_socket = context.socket(zmq.SUB)
self._internal_sub_socket.connect(settings.zmq_settings.internal_sub_address)
self._internal_sub_socket.subscribe(f"internal/{self.name}")
await self.setup()
# Start processing inbox and ZMQ messages
self.add_behavior(self._process_inbox())
self.add_behavior(self._receive_internal_zmq_loop())
async def stop(self):
"""Stops the agent."""
self._running = False
for task in self._tasks:
task.cancel()
self.logger.info(f"Agent {self.name} stopped")
async def send(self, message: InternalMessage):
"""
Sends a message to another agent.
"""
target = AgentDirectory.get(message.to)
if target:
await target.inbox.put(message)
self.logger.debug(f"Sent message {message.body} to {message.to} via regular inbox.")
else:
# Apparently target agent is on a different process, send via ZMQ
topic = f"internal/{message.to}".encode()
body = message.model_dump_json().encode()
await self._internal_pub_socket.send_multipart([topic, body])
self.logger.debug(f"Sent message {message.body} to {message.to} via ZMQ.")
async def _process_inbox(self):
"""Default loop: equivalent to a CyclicBehaviour receiving messages."""
while self._running:
msg = await self.inbox.get()
self.logger.debug(f"Received message from {msg.sender}.")
await self.handle_message(msg)
async def _receive_internal_zmq_loop(self):
"""
Listens for internal messages sent from agents on another process via ZMQ
and puts them into the normal inbox.
"""
while self._running:
try:
_, body = await self._internal_sub_socket.recv_multipart()
msg = InternalMessage.model_validate_json(body)
await self.inbox.put(msg)
except asyncio.CancelledError:
break
except Exception:
self.logger.exception("Could not process ZMQ message.")
async def handle_message(self, msg: InternalMessage):
"""Override this to handle incoming messages."""
raise NotImplementedError
def add_behavior(self, coro: Coroutine) -> Task:
"""
Helper to add a behavior to the agent. To add asynchronous behavior to an agent, define
an `async` function and add it to the task list by calling :func:`add_behavior`
with it. This should happen in the :func:`setup` method of the agent. For an example, see:
:func:`~control_backend.agents.bdi.BDICoreAgent`.
"""
task = asyncio.create_task(coro)
self._tasks.add(task)
task.add_done_callback(self._tasks.discard)
return task

View File

@@ -11,9 +11,6 @@ class ZMQSettings(BaseModel):
class AgentSettings(BaseModel): class AgentSettings(BaseModel):
# connection settings
host: str = "localhost"
# agent names # agent names
bdi_core_name: str = "bdi_core_agent" bdi_core_name: str = "bdi_core_agent"
bdi_belief_collector_name: str = "belief_collector_agent" bdi_belief_collector_name: str = "belief_collector_agent"
@@ -25,9 +22,6 @@ class AgentSettings(BaseModel):
ri_communication_name: str = "ri_communication_agent" ri_communication_name: str = "ri_communication_agent"
robot_speech_name: str = "robot_speech_agent" robot_speech_name: str = "robot_speech_agent"
# default SPADE port
default_spade_port: int = 5222
class BehaviourSettings(BaseModel): class BehaviourSettings(BaseModel):
sleep_s: float = 1.0 sleep_s: float = 1.0
@@ -81,7 +75,7 @@ class Settings(BaseSettings):
llm_settings: LLMSettings = LLMSettings() llm_settings: LLMSettings = LLMSettings()
model_config = SettingsConfigDict(env_file=".env") model_config = SettingsConfigDict(env_file=".env", env_nested_delimiter="__")
settings = Settings() settings = Settings()

View File

@@ -7,7 +7,6 @@ from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from zmq.asyncio import Context from zmq.asyncio import Context
# Act agents
# BDI agents # BDI agents
from control_backend.agents.bdi import ( from control_backend.agents.bdi import (
BDIBeliefCollectorAgent, BDIBeliefCollectorAgent,
@@ -60,7 +59,6 @@ async def lifespan(app: FastAPI):
# --- APPLICATION STARTUP --- # --- APPLICATION STARTUP ---
setup_logging() setup_logging()
logger.info("%s is starting up.", app.title) logger.info("%s is starting up.", app.title)
logger.warning("testing extra", extra={"extra1": "one", "extra2": "two"})
# Initiate sockets # Initiate sockets
proxy_thread = threading.Thread(target=setup_sockets) proxy_thread = threading.Thread(target=setup_sockets)
@@ -75,14 +73,12 @@ async def lifespan(app: FastAPI):
# --- Initialize Agents --- # --- Initialize Agents ---
logger.info("Initializing and starting agents.") logger.info("Initializing and starting agents.")
agents_to_start = { agents_to_start = {
"RICommunicationAgent": ( "RICommunicationAgent": (
RICommunicationAgent, RICommunicationAgent,
{ {
"name": settings.agent_settings.ri_communication_name, "name": settings.agent_settings.ri_communication_name,
"jid": f"{settings.agent_settings.ri_communication_name}"
f"@{settings.agent_settings.host}",
"password": settings.agent_settings.ri_communication_name,
"address": settings.zmq_settings.ri_communication_address, "address": settings.zmq_settings.ri_communication_address,
"bind": True, "bind": True,
}, },
@@ -91,16 +87,12 @@ async def lifespan(app: FastAPI):
LLMAgent, LLMAgent,
{ {
"name": settings.agent_settings.llm_name, "name": settings.agent_settings.llm_name,
"jid": f"{settings.agent_settings.llm_name}@{settings.agent_settings.host}",
"password": settings.agent_settings.llm_name,
}, },
), ),
"BDICoreAgent": ( "BDICoreAgent": (
BDICoreAgent, BDICoreAgent,
{ {
"name": settings.agent_settings.bdi_core_name, "name": settings.agent_settings.bdi_core_name,
"jid": f"{settings.agent_settings.bdi_core_name}@{settings.agent_settings.host}",
"password": settings.agent_settings.bdi_core_name,
"asl": "src/control_backend/agents/bdi/bdi_core_agent/rules.asl", "asl": "src/control_backend/agents/bdi/bdi_core_agent/rules.asl",
}, },
), ),
@@ -108,18 +100,12 @@ async def lifespan(app: FastAPI):
BDIBeliefCollectorAgent, BDIBeliefCollectorAgent,
{ {
"name": settings.agent_settings.bdi_belief_collector_name, "name": settings.agent_settings.bdi_belief_collector_name,
"jid": f"{settings.agent_settings.bdi_belief_collector_name}@"
f"{settings.agent_settings.host}",
"password": settings.agent_settings.bdi_belief_collector_name,
}, },
), ),
"TextBeliefExtractorAgent": ( "TextBeliefExtractorAgent": (
TextBeliefExtractorAgent, TextBeliefExtractorAgent,
{ {
"name": settings.agent_settings.text_belief_extractor_name, "name": settings.agent_settings.text_belief_extractor_name,
"jid": f"{settings.agent_settings.text_belief_extractor_name}@"
f"{settings.agent_settings.host}",
"password": settings.agent_settings.text_belief_extractor_name,
}, },
), ),
"VADAgent": ( "VADAgent": (
@@ -128,22 +114,25 @@ async def lifespan(app: FastAPI):
), ),
} }
agents = []
vad_agent = None vad_agent = None
for name, (agent_class, kwargs) in agents_to_start.items(): for name, (agent_class, kwargs) in agents_to_start.items():
try: try:
logger.debug("Starting agent: %s", name) logger.debug("Starting agent: %s", name)
agent_instance = agent_class(**{k: v for k, v in kwargs.items() if k != "name"}) agent_instance = agent_class(**kwargs)
await agent_instance.start() await agent_instance.start()
if isinstance(agent_instance, VADAgent): if isinstance(agent_instance, VADAgent):
vad_agent = agent_instance vad_agent = agent_instance
agents.append(agent_instance)
logger.info("Agent '%s' started successfully.", name) logger.info("Agent '%s' started successfully.", name)
except Exception as e: except Exception as e:
logger.error("Failed to start agent '%s': %s", name, e, exc_info=True) logger.error("Failed to start agent '%s': %s", name, e, exc_info=True)
# Consider if the application should continue if an agent fails to start.
raise raise
await vad_agent.streaming_behaviour.reset() assert vad_agent is not None
await vad_agent.reset_stream()
logger.info("Application startup complete.") logger.info("Application startup complete.")

View File

@@ -0,0 +1,5 @@
from pydantic import BaseModel
class BeliefMessage(BaseModel):
beliefs: dict[str, list[str]]

View File

@@ -0,0 +1,12 @@
from pydantic import BaseModel
class InternalMessage(BaseModel):
"""
Represents a message to an agent.
"""
to: str
sender: str
body: str
thread: str | None = None

View File

@@ -1,98 +0,0 @@
import json
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
import zmq
from control_backend.agents.actuation.robot_speech_agent import RobotSpeechAgent
@pytest.fixture
def zmq_context(mocker):
mock_context = mocker.patch(
"control_backend.agents.actuation.robot_speech_agent.zmq.Context.instance"
)
mock_context.return_value = MagicMock()
return mock_context
@pytest.mark.asyncio
async def test_setup_bind(zmq_context, mocker):
"""Test setup with bind=True"""
fake_socket = zmq_context.return_value.socket.return_value
agent = RobotSpeechAgent("test@server", "password", address="tcp://localhost:5555", bind=True)
settings = mocker.patch("control_backend.agents.actuation.robot_speech_agent.settings")
settings.zmq_settings.internal_sub_address = "tcp://internal:1234"
await agent.setup()
fake_socket.bind.assert_any_call("tcp://localhost:5555")
fake_socket.connect.assert_any_call("tcp://internal:1234")
fake_socket.setsockopt.assert_any_call(zmq.SUBSCRIBE, b"command")
# Ensure behaviour attached
assert any(isinstance(b, agent.SendZMQCommandsBehaviour) for b in agent.behaviours)
@pytest.mark.asyncio
async def test_setup_connect(zmq_context, mocker):
"""Test setup with bind=False"""
fake_socket = zmq_context.return_value.socket.return_value
agent = RobotSpeechAgent("test@server", "password", address="tcp://localhost:5555", bind=False)
settings = mocker.patch("control_backend.agents.actuation.robot_speech_agent.settings")
settings.zmq_settings.internal_sub_address = "tcp://internal:1234"
await agent.setup()
fake_socket.connect.assert_any_call("tcp://localhost:5555")
@pytest.mark.asyncio
async def test_send_commands_behaviour_valid_message():
"""Test behaviour with valid JSON message"""
fake_socket = AsyncMock()
message_dict = {"message": "hello"}
fake_socket.recv_multipart = AsyncMock(
return_value=(b"command", json.dumps(message_dict).encode("utf-8"))
)
fake_socket.send_json = AsyncMock()
agent = RobotSpeechAgent("test@server", "password")
agent.subsocket = fake_socket
agent.pubsocket = fake_socket
behaviour = agent.SendZMQCommandsBehaviour()
behaviour.agent = agent
with patch(
"control_backend.agents.actuation.robot_speech_agent.SpeechCommand"
) as MockSpeechCommand:
mock_message = MagicMock()
MockSpeechCommand.model_validate.return_value = mock_message
await behaviour.run()
fake_socket.recv_multipart.assert_awaited()
fake_socket.send_json.assert_awaited_with(mock_message.model_dump())
@pytest.mark.asyncio
async def test_send_commands_behaviour_invalid_message():
"""Test behaviour with invalid JSON message triggers error logging"""
fake_socket = AsyncMock()
fake_socket.recv_multipart = AsyncMock(return_value=(b"command", b"{invalid_json}"))
fake_socket.send_json = AsyncMock()
agent = RobotSpeechAgent("test@server", "password")
agent.subsocket = fake_socket
agent.pubsocket = fake_socket
behaviour = agent.SendZMQCommandsBehaviour()
behaviour.agent = agent
await behaviour.run()
fake_socket.recv_multipart.assert_awaited()
fake_socket.send_json.assert_not_awaited()

View File

@@ -1,567 +0,0 @@
import asyncio
from unittest.mock import ANY, AsyncMock, MagicMock, patch
import pytest
from control_backend.agents.communication.ri_communication_agent import RICommunicationAgent
def speech_agent_path():
return "control_backend.agents.communication.ri_communication_agent.RobotSpeechAgent"
def fake_json_correct_negototiate_1():
return AsyncMock(
return_value={
"endpoint": "negotiate/ports",
"data": [
{"id": "main", "port": 5555, "bind": False},
{"id": "actuation", "port": 5556, "bind": True},
],
}
)
def fake_json_correct_negototiate_2():
return AsyncMock(
return_value={
"endpoint": "negotiate/ports",
"data": [
{"id": "main", "port": 5555, "bind": False},
{"id": "actuation", "port": 5557, "bind": True},
],
}
)
def fake_json_correct_negototiate_3():
return AsyncMock(
return_value={
"endpoint": "negotiate/ports",
"data": [
{"id": "main", "port": 5555, "bind": True},
{"id": "actuation", "port": 5557, "bind": True},
],
}
)
def fake_json_correct_negototiate_4():
# Different port, do bind
return AsyncMock(
return_value={
"endpoint": "negotiate/ports",
"data": [
{"id": "main", "port": 4555, "bind": True},
{"id": "actuation", "port": 5557, "bind": True},
],
}
)
def fake_json_correct_negototiate_5():
# Different port, dont bind.
return AsyncMock(
return_value={
"endpoint": "negotiate/ports",
"data": [
{"id": "main", "port": 4555, "bind": False},
{"id": "actuation", "port": 5557, "bind": True},
],
}
)
def fake_json_wrong_negototiate_1():
return AsyncMock(return_value={"endpoint": "ping", "data": ""})
def fake_json_invalid_id_negototiate():
return AsyncMock(
return_value={
"endpoint": "negotiate/ports",
"data": [
{"id": "banana", "port": 4555, "bind": False},
{"id": "tomato", "port": 5557, "bind": True},
],
}
)
@pytest.fixture
def zmq_context(mocker):
mock_context = mocker.patch(
"control_backend.agents.communication.ri_communication_agent.zmq.Context.instance"
)
mock_context.return_value = MagicMock()
return mock_context
@pytest.mark.asyncio
async def test_setup_creates_socket_and_negotiate_1(zmq_context):
"""
Test the setup of the communication agent
"""
# --- Arrange ---
fake_socket = zmq_context.return_value.socket.return_value
fake_socket.send_json = AsyncMock()
fake_socket.recv_json = fake_json_correct_negototiate_1()
fake_socket.send_multipart = AsyncMock()
# Mock ActSpeechAgent agent startup
with patch(speech_agent_path(), autospec=True) as MockCommandAgent:
fake_agent_instance = MockCommandAgent.return_value
fake_agent_instance.start = AsyncMock()
# --- Act ---
agent = RICommunicationAgent(
"test@server",
"password",
address="tcp://localhost:5555",
bind=False,
)
await agent.setup()
# --- Assert ---
fake_socket.connect.assert_any_call("tcp://localhost:5555")
fake_socket.send_json.assert_any_call({"endpoint": "negotiate/ports", "data": {}})
fake_socket.recv_json.assert_awaited()
fake_agent_instance.start.assert_awaited()
MockCommandAgent.assert_called_once_with(
ANY, # Server Name
ANY, # Server Password
address="tcp://*:5556", # derived from the 'port' value in negotiation
bind=True,
)
# Ensure the agent attached a ListenBehaviour
assert any(isinstance(b, agent.ListenBehaviour) for b in agent.behaviours)
@pytest.mark.asyncio
async def test_setup_creates_socket_and_negotiate_2(zmq_context):
"""
Test the setup of the communication agent
"""
# --- Arrange ---
fake_socket = zmq_context.return_value.socket.return_value
fake_socket.send_json = AsyncMock()
fake_socket.recv_json = fake_json_correct_negototiate_2()
fake_socket.send_multipart = AsyncMock()
# Mock ActSpeechAgent agent startup
with patch(speech_agent_path(), autospec=True) as MockCommandAgent:
fake_agent_instance = MockCommandAgent.return_value
fake_agent_instance.start = AsyncMock()
# --- Act ---
agent = RICommunicationAgent(
"test@server",
"password",
address="tcp://localhost:5555",
bind=False,
)
await agent.setup()
# --- Assert ---
fake_socket.connect.assert_any_call("tcp://localhost:5555")
fake_socket.send_json.assert_any_call({"endpoint": "negotiate/ports", "data": {}})
fake_socket.recv_json.assert_awaited()
fake_agent_instance.start.assert_awaited()
MockCommandAgent.assert_called_once_with(
ANY, # Server Name
ANY, # Server Password
address="tcp://*:5557", # derived from the 'port' value in negotiation
bind=True,
)
# Ensure the agent attached a ListenBehaviour
assert any(isinstance(b, agent.ListenBehaviour) for b in agent.behaviours)
@pytest.mark.asyncio
async def test_setup_creates_socket_and_negotiate_3(zmq_context):
"""
Test the functionality of setup with incorrect negotiation message
"""
# --- Arrange ---
fake_socket = zmq_context.return_value.socket.return_value
fake_socket.send_json = AsyncMock()
fake_socket.recv_json = fake_json_wrong_negototiate_1()
fake_socket.send_multipart = AsyncMock()
# Mock ActSpeechAgent agent startup
# We are sending wrong negotiation info to the communication agent,
# so we should retry and expect a better response, within a limited time.
with patch(speech_agent_path(), autospec=True) as MockCommandAgent:
fake_agent_instance = MockCommandAgent.return_value
fake_agent_instance.start = AsyncMock()
# --- Act ---
agent = RICommunicationAgent(
"test@server",
"password",
address="tcp://localhost:5555",
bind=False,
)
await agent.setup(max_retries=1)
# --- Assert ---
fake_socket.connect.assert_any_call("tcp://localhost:5555")
fake_socket.recv_json.assert_awaited()
# Since it failed, there should not be any command agent.
fake_agent_instance.start.assert_not_awaited()
# Ensure the agent did not attach a ListenBehaviour
assert not any(isinstance(b, agent.ListenBehaviour) for b in agent.behaviours)
@pytest.mark.asyncio
async def test_setup_creates_socket_and_negotiate_4(zmq_context):
"""
Test the setup of the communication agent with different bind value
"""
# --- Arrange ---
fake_socket = zmq_context.return_value.socket.return_value
fake_socket.send_json = AsyncMock()
fake_socket.recv_json = fake_json_correct_negototiate_3()
fake_socket.send_multipart = AsyncMock()
# Mock ActSpeechAgent agent startup
with patch(speech_agent_path(), autospec=True) as MockCommandAgent:
fake_agent_instance = MockCommandAgent.return_value
fake_agent_instance.start = AsyncMock()
# --- Act ---
agent = RICommunicationAgent(
"test@server",
"password",
address="tcp://localhost:5555",
bind=True,
)
await agent.setup()
# --- Assert ---
fake_socket.bind.assert_any_call("tcp://localhost:5555")
fake_socket.send_json.assert_any_call({"endpoint": "negotiate/ports", "data": {}})
fake_socket.recv_json.assert_awaited()
fake_agent_instance.start.assert_awaited()
MockCommandAgent.assert_called_once_with(
ANY, # Server Name
ANY, # Server Password
address="tcp://*:5557", # derived from the 'port' value in negotiation
bind=True,
)
# Ensure the agent attached a ListenBehaviour
assert any(isinstance(b, agent.ListenBehaviour) for b in agent.behaviours)
@pytest.mark.asyncio
async def test_setup_creates_socket_and_negotiate_5(zmq_context):
"""
Test the setup of the communication agent
"""
# --- Arrange ---
fake_socket = zmq_context.return_value.socket.return_value
fake_socket.send_json = AsyncMock()
fake_socket.recv_json = fake_json_correct_negototiate_4()
fake_socket.send_multipart = AsyncMock()
# Mock ActSpeechAgent agent startup
with patch(speech_agent_path(), autospec=True) as MockCommandAgent:
fake_agent_instance = MockCommandAgent.return_value
fake_agent_instance.start = AsyncMock()
# --- Act ---
agent = RICommunicationAgent(
"test@server",
"password",
address="tcp://localhost:5555",
bind=False,
)
await agent.setup()
# --- Assert ---
fake_socket.connect.assert_any_call("tcp://localhost:5555")
fake_socket.send_json.assert_any_call({"endpoint": "negotiate/ports", "data": {}})
fake_socket.recv_json.assert_awaited()
fake_agent_instance.start.assert_awaited()
MockCommandAgent.assert_called_once_with(
ANY, # Server Name
ANY, # Server Password
address="tcp://*:5557", # derived from the 'port' value in negotiation
bind=True,
)
# Ensure the agent attached a ListenBehaviour
assert any(isinstance(b, agent.ListenBehaviour) for b in agent.behaviours)
@pytest.mark.asyncio
async def test_setup_creates_socket_and_negotiate_6(zmq_context):
"""
Test the setup of the communication agent
"""
# --- Arrange ---
fake_socket = zmq_context.return_value.socket.return_value
fake_socket.send_json = AsyncMock()
fake_socket.recv_json = fake_json_correct_negototiate_5()
fake_socket.send_multipart = AsyncMock()
# Mock ActSpeechAgent agent startup
with patch(speech_agent_path(), autospec=True) as MockCommandAgent:
fake_agent_instance = MockCommandAgent.return_value
fake_agent_instance.start = AsyncMock()
# --- Act ---
agent = RICommunicationAgent(
"test@server",
"password",
address="tcp://localhost:5555",
bind=False,
)
await agent.setup()
# --- Assert ---
fake_socket.connect.assert_any_call("tcp://localhost:5555")
fake_socket.send_json.assert_any_call({"endpoint": "negotiate/ports", "data": {}})
fake_socket.recv_json.assert_awaited()
fake_agent_instance.start.assert_awaited()
MockCommandAgent.assert_called_once_with(
ANY, # Server Name
ANY, # Server Password
address="tcp://*:5557", # derived from the 'port' value in negotiation
bind=True,
)
# Ensure the agent attached a ListenBehaviour
assert any(isinstance(b, agent.ListenBehaviour) for b in agent.behaviours)
@pytest.mark.asyncio
async def test_setup_creates_socket_and_negotiate_7(zmq_context):
"""
Test the functionality of setup with incorrect id
"""
# --- Arrange ---
fake_socket = zmq_context.return_value.socket.return_value
fake_socket.send_json = AsyncMock()
fake_socket.recv_json = fake_json_invalid_id_negototiate()
fake_socket.send_multipart = AsyncMock()
# Mock ActSpeechAgent agent startup
# We are sending wrong negotiation info to the communication agent,
# so we should retry and expect a better response, within a limited time.
with patch(speech_agent_path(), autospec=True) as MockCommandAgent:
fake_agent_instance = MockCommandAgent.return_value
fake_agent_instance.start = AsyncMock()
# --- Act ---
agent = RICommunicationAgent(
"test@server",
"password",
address="tcp://localhost:5555",
bind=False,
)
await agent.setup(max_retries=1)
# --- Assert ---
fake_socket.connect.assert_any_call("tcp://localhost:5555")
fake_socket.recv_json.assert_awaited()
# Since it failed, there should not be any command agent.
fake_agent_instance.start.assert_not_awaited()
@pytest.mark.asyncio
async def test_setup_creates_socket_and_negotiate_timeout(zmq_context):
"""
Test the functionality of setup with incorrect negotiation message
"""
# --- Arrange ---
fake_socket = zmq_context.return_value.socket.return_value
fake_socket.send_json = AsyncMock()
fake_socket.recv_json = AsyncMock(side_effect=asyncio.TimeoutError)
fake_socket.send_multipart = AsyncMock()
with patch(speech_agent_path(), autospec=True) as MockCommandAgent:
fake_agent_instance = MockCommandAgent.return_value
fake_agent_instance.start = AsyncMock()
# --- Act ---
agent = RICommunicationAgent(
"test@server",
"password",
address="tcp://localhost:5555",
bind=False,
)
await agent.setup(max_retries=1)
# --- Assert ---
fake_socket.connect.assert_any_call("tcp://localhost:5555")
# Since it failed, there should not be any command agent.
fake_agent_instance.start.assert_not_awaited()
# Ensure the agent did not attach a ListenBehaviour
assert not any(isinstance(b, agent.ListenBehaviour) for b in agent.behaviours)
@pytest.mark.asyncio
async def test_listen_behaviour_ping_correct():
fake_socket = AsyncMock()
fake_socket.send_json = AsyncMock()
fake_socket.recv_json = AsyncMock(return_value={"endpoint": "ping", "data": {}})
fake_socket.send_multipart = AsyncMock()
agent = RICommunicationAgent("test@server", "password")
agent._req_socket = fake_socket
agent.connected = True
behaviour = agent.ListenBehaviour()
agent.add_behaviour(behaviour)
await behaviour.run()
fake_socket.send_json.assert_awaited()
fake_socket.recv_json.assert_awaited()
@pytest.mark.asyncio
async def test_listen_behaviour_ping_wrong_endpoint():
"""
Test if our listen behaviour can work with wrong messages (wrong endpoint)
"""
fake_socket = AsyncMock()
fake_socket.send_json = AsyncMock()
# This is a message for the wrong endpoint >:(
fake_socket.recv_json = AsyncMock(
return_value={
"endpoint": "negotiate/ports",
"data": [
{"id": "main", "port": 5555, "bind": False},
{"id": "actuation", "port": 5556, "bind": True},
],
}
)
fake_pub_socket = AsyncMock()
agent = RICommunicationAgent("test@server", "password", fake_pub_socket)
agent._req_socket = fake_socket
agent.connected = True
behaviour = agent.ListenBehaviour()
agent.add_behaviour(behaviour)
# Run once (CyclicBehaviour normally loops)
await behaviour.run()
fake_socket.send_json.assert_awaited()
fake_socket.recv_json.assert_awaited()
@pytest.mark.asyncio
async def test_listen_behaviour_timeout(zmq_context):
fake_socket = zmq_context.return_value.socket.return_value
fake_socket.send_json = AsyncMock()
# recv_json will never resolve, simulate timeout
fake_socket.recv_json = AsyncMock(side_effect=asyncio.TimeoutError)
fake_socket.send_multipart = AsyncMock()
agent = RICommunicationAgent("test@server", "password")
agent._req_socket = fake_socket
agent.connected = True
behaviour = agent.ListenBehaviour()
agent.add_behaviour(behaviour)
await behaviour.run()
assert not any(isinstance(b, agent.ListenBehaviour) for b in agent.behaviours)
assert not agent.connected
@pytest.mark.asyncio
async def test_listen_behaviour_ping_no_endpoint():
"""
Test if our listen behaviour can work with wrong messages (wrong endpoint)
"""
fake_socket = AsyncMock()
fake_socket.send_json = AsyncMock()
fake_socket.send_multipart = AsyncMock()
# This is a message without endpoint >:(
fake_socket.recv_json = AsyncMock(
return_value={
"data": "I dont have an endpoint >:)",
}
)
agent = RICommunicationAgent("test@server", "password")
agent._req_socket = fake_socket
agent.connected = True
behaviour = agent.ListenBehaviour()
agent.add_behaviour(behaviour)
await behaviour.run()
fake_socket.send_json.assert_awaited()
fake_socket.recv_json.assert_awaited()
@pytest.mark.asyncio
async def test_setup_unexpected_exception(zmq_context):
fake_socket = zmq_context.return_value.socket.return_value
fake_socket.send_json = AsyncMock()
# Simulate unexpected exception during recv_json()
fake_socket.recv_json = AsyncMock(side_effect=Exception("boom!"))
fake_socket.send_multipart = AsyncMock()
agent = RICommunicationAgent(
"test@server",
"password",
address="tcp://localhost:5555",
bind=False,
)
await agent.setup(max_retries=1)
assert not any(isinstance(b, agent.ListenBehaviour) for b in agent.behaviours)
assert not agent.connected
@pytest.mark.asyncio
async def test_setup_unpacking_exception(zmq_context):
# --- Arrange ---
fake_socket = zmq_context.return_value.socket.return_value
fake_socket.send_json = AsyncMock()
fake_socket.send_multipart = AsyncMock()
# Make recv_json return malformed negotiation data to trigger unpacking exception
malformed_data = {
"endpoint": "negotiate/ports",
"data": [{"id": "main"}],
} # missing 'port' and 'bind'
fake_socket.recv_json = AsyncMock(return_value=malformed_data)
# Patch ActSpeechAgent so it won't actually start
with patch(speech_agent_path(), autospec=True) as MockCommandAgent:
fake_agent_instance = MockCommandAgent.return_value
fake_agent_instance.start = AsyncMock()
agent = RICommunicationAgent(
"test@server",
"password",
address="tcp://localhost:5555",
bind=False,
)
# --- Act & Assert ---
await agent.setup(max_retries=1)
# Ensure no command agent was started
fake_agent_instance.start.assert_not_awaited()
# Ensure no behaviour was attached
assert not any(isinstance(b, agent.ListenBehaviour) for b in agent.behaviours)

View File

@@ -1,9 +1,8 @@
import random import random
from unittest.mock import AsyncMock, MagicMock, patch from unittest.mock import AsyncMock, MagicMock
import pytest import pytest
import zmq import zmq
from spade.agent import Agent
from control_backend.agents.perception.vad_agent import VADAgent from control_backend.agents.perception.vad_agent import VADAgent
@@ -15,11 +14,6 @@ def zmq_context(mocker):
return mock_context return mock_context
@pytest.fixture
def streaming(mocker):
return mocker.patch("control_backend.agents.perception.vad_agent.StreamingBehaviour")
@pytest.fixture @pytest.fixture
def per_transcription_agent(mocker): def per_transcription_agent(mocker):
return mocker.patch( return mocker.patch(
@@ -27,21 +21,36 @@ def per_transcription_agent(mocker):
) )
@pytest.fixture(autouse=True)
def torch_load(mocker):
mock_torch = mocker.patch("control_backend.agents.perception.vad_agent.torch")
model = MagicMock()
mock_torch.hub.load.return_value = (model, None)
mock_torch.from_numpy.side_effect = lambda arr: arr
return mock_torch
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_normal_setup(streaming, per_transcription_agent): async def test_normal_setup(per_transcription_agent):
""" """
Test that during normal setup, the VAD agent creates a Streaming behavior and creates audio Test that during normal setup, the VAD agent creates a Streaming behavior and creates audio
sockets, and starts the TranscriptionAgent without loading real models. sockets, and starts the TranscriptionAgent without loading real models.
""" """
per_vad_agent = VADAgent("tcp://localhost:12345", False) per_vad_agent = VADAgent("tcp://localhost:12345", False)
per_vad_agent.add_behaviour = MagicMock() per_vad_agent._streaming_loop = AsyncMock()
async def swallow_background_task(coro):
coro.close()
per_vad_agent.add_behavior = swallow_background_task
per_vad_agent.reset_stream = AsyncMock()
await per_vad_agent.setup() await per_vad_agent.setup()
streaming.assert_called_once()
per_vad_agent.add_behaviour.assert_called_once_with(streaming.return_value)
per_transcription_agent.assert_called_once() per_transcription_agent.assert_called_once()
per_transcription_agent.return_value.start.assert_called_once() per_transcription_agent.return_value.start.assert_called_once()
per_vad_agent._streaming_loop.assert_called_once()
per_vad_agent.reset_stream.assert_called_once()
assert per_vad_agent.audio_in_socket is not None assert per_vad_agent.audio_in_socket is not None
assert per_vad_agent.audio_out_socket is not None assert per_vad_agent.audio_out_socket is not None
@@ -91,16 +100,22 @@ async def test_out_socket_creation_failure(zmq_context):
""" """
Test setup failure when the audio output socket cannot be created. Test setup failure when the audio output socket cannot be created.
""" """
with patch.object(Agent, "stop", new_callable=AsyncMock) as mock_super_stop: zmq_context.return_value.socket.return_value.bind_to_random_port.side_effect = zmq.ZMQBindError
zmq_context.return_value.socket.return_value.bind_to_random_port.side_effect = (
zmq.ZMQBindError
)
per_vad_agent = VADAgent("tcp://localhost:12345", False) per_vad_agent = VADAgent("tcp://localhost:12345", False)
per_vad_agent.stop = AsyncMock()
per_vad_agent.reset_stream = AsyncMock()
per_vad_agent._streaming_loop = AsyncMock()
per_vad_agent._connect_audio_out_socket = MagicMock(return_value=None)
async def swallow_background_task(coro):
coro.close()
per_vad_agent.add_behavior = swallow_background_task
await per_vad_agent.setup() await per_vad_agent.setup()
assert per_vad_agent.audio_out_socket is None assert per_vad_agent.audio_out_socket is None
mock_super_stop.assert_called_once() per_vad_agent.stop.assert_called_once()
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -109,6 +124,13 @@ async def test_stop(zmq_context, per_transcription_agent):
Test that when the VAD agent is stopped, the sockets are closed correctly. Test that when the VAD agent is stopped, the sockets are closed correctly.
""" """
per_vad_agent = VADAgent("tcp://localhost:12345", False) per_vad_agent = VADAgent("tcp://localhost:12345", False)
per_vad_agent.reset_stream = AsyncMock()
per_vad_agent._streaming_loop = AsyncMock()
async def swallow_background_task(coro):
coro.close()
per_vad_agent.add_behavior = swallow_background_task
zmq_context.return_value.socket.return_value.bind_to_random_port.return_value = random.randint( zmq_context.return_value.socket.return_value.bind_to_random_port.return_value = random.randint(
1000, 1000,
10000, 10000,

View File

@@ -5,7 +5,24 @@ import pytest
import soundfile as sf import soundfile as sf
import zmq import zmq
from control_backend.agents.perception.vad_agent import StreamingBehaviour from control_backend.agents.perception.vad_agent import VADAgent
@pytest.fixture(autouse=True)
def patch_settings():
from control_backend.agents.perception import vad_agent
vad_agent.settings.behaviour_settings.vad_prob_threshold = 0.5
vad_agent.settings.behaviour_settings.vad_non_speech_patience_chunks = 3
vad_agent.settings.behaviour_settings.vad_initial_since_speech = 0
vad_agent.settings.vad_settings.sample_rate_hz = 16_000
@pytest.fixture(autouse=True)
def mock_torch(mocker):
mock_torch = mocker.patch("control_backend.agents.perception.vad_agent.torch")
mock_torch.from_numpy.side_effect = lambda arr: arr
return mock_torch
def get_audio_chunks() -> list[bytes]: def get_audio_chunks() -> list[bytes]:
@@ -42,16 +59,39 @@ async def test_real_audio(mocker):
audio_in_socket = AsyncMock() audio_in_socket = AsyncMock()
audio_in_socket.recv.side_effect = audio_chunks audio_in_socket.recv.side_effect = audio_chunks
mock_poller: MagicMock = mocker.patch("control_backend.agents.perception.vad_agent.zmq.Poller") mock_poller: MagicMock = mocker.patch("control_backend.agents.perception.vad_agent.azmq.Poller")
mock_poller.return_value.poll.return_value = [(audio_in_socket, zmq.POLLIN)] mock_poller.return_value.poll = AsyncMock(return_value=[(audio_in_socket, zmq.POLLIN)])
audio_out_socket = AsyncMock() audio_out_socket = AsyncMock()
vad_streamer = StreamingBehaviour(audio_in_socket, audio_out_socket) vad_agent = VADAgent("tcp://localhost:12345", False)
vad_streamer._ready = True vad_agent.audio_out_socket = audio_out_socket
vad_streamer.agent = MagicMock()
for _ in audio_chunks: # Use a fake model that marks most chunks as speech and ends with a few silences
await vad_streamer.run() silence_padding = 5
probabilities = [1.0] * len(audio_chunks) + [0.0] * silence_padding
chunk_bytes = audio_chunks + [b"\x00" * len(audio_chunks[0])] * silence_padding
model_item = MagicMock()
model_item.item.side_effect = probabilities
vad_agent.model = MagicMock(return_value=model_item)
class DummyPoller:
def __init__(self, data, agent):
self.data = data
self.agent = agent
async def poll(self, timeout_ms=None):
if self.data:
return self.data.pop(0)
self.agent._running = False
return None
vad_agent.audio_in_poller = DummyPoller(chunk_bytes, vad_agent)
vad_agent._ready = AsyncMock()
vad_agent._running = True
vad_agent.i_since_speech = 0
await vad_agent._streaming_loop()
audio_out_socket.send.assert_called() audio_out_socket.send.assert_called()
for args in audio_out_socket.send.call_args_list: for args in audio_out_socket.send.call_args_list:

View File

@@ -0,0 +1,139 @@
import json
from unittest.mock import AsyncMock, MagicMock
import pytest
import zmq
from control_backend.agents.actuation.robot_speech_agent import RobotSpeechAgent
from control_backend.core.agent_system import InternalMessage
@pytest.fixture
def zmq_context(mocker):
mock_context = mocker.patch(
"control_backend.agents.actuation.robot_speech_agent.azmq.Context.instance"
)
mock_context.return_value = MagicMock()
return mock_context
@pytest.mark.asyncio
async def test_setup_bind(zmq_context, mocker):
"""Setup binds and subscribes to internal commands."""
fake_socket = zmq_context.return_value.socket.return_value
agent = RobotSpeechAgent("robot_speech", address="tcp://localhost:5555", bind=True)
settings = mocker.patch("control_backend.agents.actuation.robot_speech_agent.settings")
settings.zmq_settings.internal_sub_address = "tcp://internal:1234"
agent.add_behavior = MagicMock()
await agent.setup()
fake_socket.bind.assert_any_call("tcp://localhost:5555")
fake_socket.connect.assert_any_call("tcp://internal:1234")
fake_socket.setsockopt.assert_any_call(zmq.SUBSCRIBE, b"command")
agent.add_behavior.assert_called_once()
@pytest.mark.asyncio
async def test_setup_connect(zmq_context, mocker):
"""Setup connects when bind=False."""
fake_socket = zmq_context.return_value.socket.return_value
agent = RobotSpeechAgent("robot_speech", address="tcp://localhost:5555", bind=False)
settings = mocker.patch("control_backend.agents.actuation.robot_speech_agent.settings")
settings.zmq_settings.internal_sub_address = "tcp://internal:1234"
agent.add_behavior = MagicMock()
await agent.setup()
fake_socket.connect.assert_any_call("tcp://localhost:5555")
fake_socket.connect.assert_any_call("tcp://internal:1234")
agent.add_behavior.assert_called_once()
@pytest.mark.asyncio
async def test_handle_message_sends_command():
"""Internal message is forwarded to robot pub socket as JSON."""
pubsocket = AsyncMock()
agent = RobotSpeechAgent("robot_speech")
agent.pubsocket = pubsocket
payload = {"endpoint": "actuate/speech", "data": "hello"}
msg = InternalMessage(to="robot", sender="tester", body=json.dumps(payload))
await agent.handle_message(msg)
pubsocket.send_json.assert_awaited_once_with(payload)
@pytest.mark.asyncio
async def test_zmq_command_loop_valid_payload(zmq_context):
"""UI command is read from SUB and published."""
command = {"endpoint": "actuate/speech", "data": "hello"}
fake_socket = AsyncMock()
async def recv_once():
# stop after first iteration
agent._running = False
return (b"command", json.dumps(command).encode("utf-8"))
fake_socket.recv_multipart = recv_once
fake_socket.send_json = AsyncMock()
agent = RobotSpeechAgent("robot_speech")
agent.subsocket = fake_socket
agent.pubsocket = fake_socket
agent._running = True
await agent._zmq_command_loop()
fake_socket.send_json.assert_awaited_once_with(command)
@pytest.mark.asyncio
async def test_zmq_command_loop_invalid_json():
"""Invalid JSON is ignored without sending."""
fake_socket = AsyncMock()
async def recv_once():
agent._running = False
return (b"command", b"{not_json}")
fake_socket.recv_multipart = recv_once
fake_socket.send_json = AsyncMock()
agent = RobotSpeechAgent("robot_speech")
agent.subsocket = fake_socket
agent.pubsocket = fake_socket
agent._running = True
await agent._zmq_command_loop()
fake_socket.send_json.assert_not_awaited()
@pytest.mark.asyncio
async def test_handle_message_invalid_payload():
"""Invalid payload is caught and does not send."""
pubsocket = AsyncMock()
agent = RobotSpeechAgent("robot_speech")
agent.pubsocket = pubsocket
msg = InternalMessage(to="robot", sender="tester", body=json.dumps({"bad": "data"}))
await agent.handle_message(msg)
pubsocket.send_json.assert_not_awaited()
@pytest.mark.asyncio
async def test_stop_closes_sockets():
pubsocket = MagicMock()
subsocket = MagicMock()
agent = RobotSpeechAgent("robot_speech")
agent.pubsocket = pubsocket
agent.subsocket = subsocket
await agent.stop()
pubsocket.close.assert_called_once()
subsocket.close.assert_called_once()

View File

@@ -1,212 +0,0 @@
import json
import logging
from unittest.mock import AsyncMock, MagicMock, call
import pytest
from control_backend.agents.bdi.bdi_core_agent.behaviours.belief_setter_behaviour import (
BeliefSetterBehaviour,
)
# Define a constant for the collector agent name to use in tests
COLLECTOR_AGENT_NAME = "belief_collector_agent"
COLLECTOR_AGENT_JID = f"{COLLECTOR_AGENT_NAME}@test"
@pytest.fixture
def mock_agent(mocker):
"""Fixture to create a mock BDIAgent."""
agent = MagicMock()
agent.bdi = MagicMock()
agent.jid = "bdi_agent@test"
return agent
@pytest.fixture
def belief_setter_behaviour(mock_agent, mocker):
"""Fixture to create an instance of BeliefSetterBehaviour with a mocked agent."""
# Patch the settings to use a predictable agent name
mocker.patch(
"control_backend.agents.bdi.bdi_core_agent."
"behaviours.belief_setter_behaviour.settings.agent_settings.bdi_belief_collector_name",
COLLECTOR_AGENT_NAME,
)
setter = BeliefSetterBehaviour()
setter.agent = mock_agent
# Mock the receive method, we will control its return value in each test
setter.receive = AsyncMock()
return setter
def create_mock_message(sender_node: str, body: str, thread: str) -> MagicMock:
"""Helper function to create a configured mock message."""
msg = MagicMock()
msg.sender.node = sender_node # MagicMock automatically creates nested mocks
msg.body = body
msg.thread = thread
return msg
@pytest.mark.asyncio
async def test_run_message_received(belief_setter_behaviour, mocker):
"""
Test that when a message is received, _process_message is called.
"""
# Arrange
msg = MagicMock()
belief_setter_behaviour.receive.return_value = msg
mocker.patch.object(belief_setter_behaviour, "_process_message")
# Act
await belief_setter_behaviour.run()
# Assert
belief_setter_behaviour._process_message.assert_called_once_with(msg)
def test_process_message_from_bdi_belief_collector_agent(belief_setter_behaviour, mocker):
"""
Test processing a message from the correct belief collector agent.
"""
# Arrange
msg = create_mock_message(sender_node=COLLECTOR_AGENT_NAME, body="", thread="")
mock_process_belief = mocker.patch.object(belief_setter_behaviour, "_process_belief_message")
# Act
belief_setter_behaviour._process_message(msg)
# Assert
mock_process_belief.assert_called_once_with(msg)
def test_process_message_from_other_agent(belief_setter_behaviour, mocker):
"""
Test that messages from other agents are ignored.
"""
# Arrange
msg = create_mock_message(sender_node="other_agent", body="", thread="")
mock_process_belief = mocker.patch.object(belief_setter_behaviour, "_process_belief_message")
# Act
belief_setter_behaviour._process_message(msg)
# Assert
mock_process_belief.assert_not_called()
def test_process_belief_message_valid_json(belief_setter_behaviour, mocker):
"""
Test processing a valid belief message with correct thread and JSON body.
"""
# Arrange
beliefs_payload = {"is_hot": ["kitchen"], "is_clean": ["kitchen", "bathroom"]}
msg = create_mock_message(
sender_node=COLLECTOR_AGENT_JID, body=json.dumps(beliefs_payload), thread="beliefs"
)
mock_set_beliefs = mocker.patch.object(belief_setter_behaviour, "_set_beliefs")
# Act
belief_setter_behaviour._process_belief_message(msg)
# Assert
mock_set_beliefs.assert_called_once_with(beliefs_payload)
def test_process_belief_message_invalid_json(belief_setter_behaviour, mocker, caplog):
"""
Test that a message with invalid JSON is handled gracefully and an error is logged.
"""
# Arrange
msg = create_mock_message(
sender_node=COLLECTOR_AGENT_JID, body="this is not a json string", thread="beliefs"
)
mock_set_beliefs = mocker.patch.object(belief_setter_behaviour, "_set_beliefs")
# Act
belief_setter_behaviour._process_belief_message(msg)
# Assert
mock_set_beliefs.assert_not_called()
def test_process_belief_message_wrong_thread(belief_setter_behaviour, mocker):
"""
Test that a message with an incorrect thread is ignored.
"""
# Arrange
msg = create_mock_message(
sender_node=COLLECTOR_AGENT_JID, body='{"some": "data"}', thread="not_beliefs"
)
mock_set_beliefs = mocker.patch.object(belief_setter_behaviour, "_set_beliefs")
# Act
belief_setter_behaviour._process_belief_message(msg)
# Assert
mock_set_beliefs.assert_not_called()
def test_process_belief_message_empty_body(belief_setter_behaviour, mocker):
"""
Test that a message with an empty body is ignored.
"""
# Arrange
msg = create_mock_message(sender_node=COLLECTOR_AGENT_JID, body="", thread="beliefs")
mock_set_beliefs = mocker.patch.object(belief_setter_behaviour, "_set_beliefs")
# Act
belief_setter_behaviour._process_belief_message(msg)
# Assert
mock_set_beliefs.assert_not_called()
def test_set_beliefs_success(belief_setter_behaviour, mock_agent, caplog):
"""
Test that beliefs are correctly set on the agent's BDI.
"""
# Arrange
beliefs_to_set = {
"is_hot": ["kitchen"],
"door_opened": ["front_door", "back_door"],
}
# Act
with caplog.at_level(logging.INFO):
belief_setter_behaviour._set_beliefs(beliefs_to_set)
# Assert
expected_calls = [
call("is_hot", "kitchen"),
call("door_opened", "front_door", "back_door"),
]
mock_agent.bdi.set_belief.assert_has_calls(expected_calls, any_order=True)
assert mock_agent.bdi.set_belief.call_count == 2
# def test_responded_unset(belief_setter_behaviour, mock_agent):
# # Arrange
# new_beliefs = {"user_said": ["message"]}
#
# # Act
# belief_setter_behaviour._set_beliefs(new_beliefs)
#
# # Assert
# mock_agent.bdi.set_belief.assert_has_calls([call("user_said", "message")])
# mock_agent.bdi.remove_belief.assert_has_calls([call("responded")])
# def test_set_beliefs_bdi_not_initialized(belief_setter_behaviour, mock_agent, caplog):
# """
# Test that a warning is logged if the agent's BDI is not initialized.
# """
# # Arrange
# mock_agent.bdi = None # Simulate BDI not being ready
# beliefs_to_set = {"is_hot": ["kitchen"]}
#
# # Act
# with caplog.at_level(logging.WARNING):
# belief_setter_behaviour._set_beliefs(beliefs_to_set)
#
# # Assert
# assert "Cannot set beliefs, since agent's BDI is not yet initialized." in caplog.text

View File

@@ -1,101 +0,0 @@
import json
from unittest.mock import AsyncMock, MagicMock
import pytest
from control_backend.agents.bdi.belief_collector_agent.behaviours.belief_collector_behaviour import ( # noqa: E501
BeliefCollectorBehaviour,
)
def create_mock_message(sender_node: str, body: str) -> MagicMock:
"""Helper function to create a configured mock message."""
msg = MagicMock()
msg.sender.node = sender_node # MagicMock automatically creates nested mocks
msg.body = body
return msg
@pytest.fixture
def mock_agent(mocker):
"""Fixture to create a mock Agent."""
agent = MagicMock()
agent.jid = "belief_collector_agent@test"
return agent
@pytest.fixture
def bel_collector_behaviouror(mock_agent, mocker):
"""Fixture to create an instance of BelCollectorBehaviour with a mocked agent."""
# Patch asyncio.sleep to prevent tests from actually waiting
mocker.patch("asyncio.sleep", return_value=None)
collector = BeliefCollectorBehaviour()
collector.agent = mock_agent
# Mock the receive method, we will control its return value in each test
collector.receive = AsyncMock()
return collector
@pytest.mark.asyncio
async def test_run_message_received(bel_collector_behaviouror, mocker):
"""
Test that when a message is received, _process_message is called with that message.
"""
# Arrange
mock_msg = MagicMock()
bel_collector_behaviouror.receive.return_value = mock_msg
mocker.patch.object(bel_collector_behaviouror, "_process_message")
# Act
await bel_collector_behaviouror.run()
# Assert
bel_collector_behaviouror._process_message.assert_awaited_once_with(mock_msg)
@pytest.mark.asyncio
async def test_routes_to_handle_belief_text_by_type(bel_collector_behaviouror, mocker):
msg = create_mock_message(
"anyone",
json.dumps({"type": "belief_extraction_text", "beliefs": {"user_said": [["hi"]]}}),
)
spy = mocker.patch.object(bel_collector_behaviouror, "_handle_belief_text", new=AsyncMock())
await bel_collector_behaviouror._process_message(msg)
spy.assert_awaited_once()
@pytest.mark.asyncio
async def test_routes_to_handle_belief_text_by_sender(bel_collector_behaviouror, mocker):
msg = create_mock_message(
"bel_text_agent_mock", json.dumps({"beliefs": {"user_said": [["hi"]]}})
)
spy = mocker.patch.object(bel_collector_behaviouror, "_handle_belief_text", new=AsyncMock())
await bel_collector_behaviouror._process_message(msg)
spy.assert_awaited_once()
@pytest.mark.asyncio
async def test_routes_to_handle_emo_text(bel_collector_behaviouror, mocker):
msg = create_mock_message("anyone", json.dumps({"type": "emotion_extraction_text"}))
spy = mocker.patch.object(bel_collector_behaviouror, "_handle_emo_text", new=AsyncMock())
await bel_collector_behaviouror._process_message(msg)
spy.assert_awaited_once()
@pytest.mark.asyncio
async def test_belief_text_happy_path_sends(bel_collector_behaviouror, mocker):
payload = {"type": "belief_extraction_text", "beliefs": {"user_said": ["hello test", "No"]}}
bel_collector_behaviouror.send = AsyncMock()
await bel_collector_behaviouror._handle_belief_text(payload, "bel_text_agent_mock")
# make sure we attempted a send
bel_collector_behaviouror.send.assert_awaited_once()
@pytest.mark.asyncio
async def test_belief_text_coerces_non_strings(bel_collector_behaviouror, mocker):
payload = {"type": "belief_extraction_text", "beliefs": {"user_said": [["hi", 123]]}}
bel_collector_behaviouror.send = AsyncMock()
await bel_collector_behaviouror._handle_belief_text(payload, "origin")
bel_collector_behaviouror.send.assert_awaited_once()

View File

@@ -0,0 +1,126 @@
import json
from unittest.mock import AsyncMock, MagicMock, mock_open, patch
import agentspeak
import pytest
from control_backend.agents.bdi.bdi_core_agent.bdi_core_agent import BDICoreAgent
from control_backend.core.agent_system import InternalMessage
from control_backend.core.config import settings
from control_backend.schemas.belief_message import BeliefMessage
@pytest.fixture
def mock_agentspeak_env():
with patch("agentspeak.runtime.Environment") as mock_env:
yield mock_env
@pytest.fixture
def agent():
agent = BDICoreAgent("bdi_agent", "dummy.asl")
agent.send = AsyncMock()
agent.bdi_agent = MagicMock()
return agent
@pytest.mark.asyncio
async def test_setup_loads_asl(mock_agentspeak_env, agent):
# Mock file opening
with patch("builtins.open", mock_open(read_data="+initial_goal.")):
await agent.setup()
# Check if environment tried to build agent
mock_agentspeak_env.return_value.build_agent.assert_called()
@pytest.mark.asyncio
async def test_setup_no_asl(mock_agentspeak_env, agent):
with patch("builtins.open", side_effect=FileNotFoundError):
await agent.setup()
mock_agentspeak_env.return_value.build_agent.assert_not_called()
@pytest.mark.asyncio
async def test_handle_belief_collector_message(agent, mock_settings):
"""Test that incoming beliefs are added to the BDI agent"""
beliefs = {"user_said": ["Hello"]}
msg = InternalMessage(
to="bdi_agent",
sender=mock_settings.agent_settings.bdi_belief_collector_name,
body=BeliefMessage(beliefs=beliefs).model_dump_json(),
thread="beliefs",
)
await agent.handle_message(msg)
# Expect bdi_agent.call to be triggered to add belief
args = agent.bdi_agent.call.call_args.args
assert args[0] == agentspeak.Trigger.addition
assert args[1] == agentspeak.GoalType.belief
assert args[2] == agentspeak.Literal("user_said", (agentspeak.Literal("Hello"),))
@pytest.mark.asyncio
async def test_incorrect_belief_collector_message(agent, mock_settings):
"""Test that incorrect message format triggers an exception."""
msg = InternalMessage(
to="bdi_agent",
sender=mock_settings.agent_settings.bdi_belief_collector_name,
body=json.dumps({"bad_format": "bad_format"}),
thread="beliefs",
)
await agent.handle_message(msg)
agent.bdi_agent.call.assert_not_called() # did not set belief
@pytest.mark.asyncio
async def test():
pass
@pytest.mark.asyncio
async def test_handle_llm_response(agent):
"""Test that LLM responses are forwarded to the Robot Speech Agent"""
msg = InternalMessage(
to="bdi_agent", sender=settings.agent_settings.llm_name, body="This is the LLM reply"
)
await agent.handle_message(msg)
# Verify forward
assert agent.send.called
sent_msg = agent.send.call_args[0][0]
assert sent_msg.to == settings.agent_settings.robot_speech_name
assert "This is the LLM reply" in sent_msg.body
@pytest.mark.asyncio
async def test_custom_actions(agent):
agent._send_to_llm = MagicMock(side_effect=agent.send) # Mock specific method
# Initialize actions manually since we didn't call setup with real file
agent._add_custom_actions()
# Find the action
action_fn = None
for (functor, _), fn in agent.actions.actions.items():
if functor == ".reply":
action_fn = fn
break
assert action_fn is not None
# Invoke action
mock_term = MagicMock()
mock_term.args = ["Hello"]
mock_intention = MagicMock()
# Run generator
gen = action_fn(agent, mock_term, mock_intention)
next(gen) # Execute
agent._send_to_llm.assert_called_with("Hello")

View File

@@ -0,0 +1,87 @@
import json
from unittest.mock import AsyncMock
import pytest
from control_backend.agents.bdi import (
BDIBeliefCollectorAgent,
)
from control_backend.core.agent_system import InternalMessage
from control_backend.core.config import settings
@pytest.fixture
def agent():
agent = BDIBeliefCollectorAgent("belief_collector_agent")
return agent
def make_msg(body: dict, sender: str = "sender"):
return InternalMessage(to="collector", sender=sender, body=json.dumps(body))
@pytest.mark.asyncio
async def test_handle_message_routes_belief_text(agent, mocker):
"""
Test that when a message is received, _handle_belief_text is called with that message.
"""
payload = {"type": "belief_extraction_text", "beliefs": {"user_said": [["hi"]]}}
spy = mocker.patch.object(agent, "_handle_belief_text", new_callable=AsyncMock)
await agent.handle_message(make_msg(payload))
spy.assert_awaited_once_with(payload, "sender")
@pytest.mark.asyncio
async def test_handle_message_routes_emotion(agent, mocker):
payload = {"type": "emotion_extraction_text"}
spy = mocker.patch.object(agent, "_handle_emo_text", new_callable=AsyncMock)
await agent.handle_message(make_msg(payload))
spy.assert_awaited_once_with(payload, "sender")
@pytest.mark.asyncio
async def test_handle_message_bad_json(agent, mocker):
agent._handle_belief_text = AsyncMock()
bad_msg = InternalMessage(to="collector", sender="sender", body="not json")
await agent.handle_message(bad_msg)
agent._handle_belief_text.assert_not_awaited()
@pytest.mark.asyncio
async def test_handle_belief_text_sends_when_beliefs_exist(agent, mocker):
payload = {"type": "belief_extraction_text", "beliefs": {"user_said": ["hello"]}}
spy = mocker.patch.object(agent, "_send_beliefs_to_bdi", new_callable=AsyncMock)
await agent._handle_belief_text(payload, "origin")
spy.assert_awaited_once_with(payload["beliefs"], origin="origin")
@pytest.mark.asyncio
async def test_handle_belief_text_no_send_when_empty(agent, mocker):
payload = {"type": "belief_extraction_text", "beliefs": {}}
spy = mocker.patch.object(agent, "_send_beliefs_to_bdi", new_callable=AsyncMock)
await agent._handle_belief_text(payload, "origin")
spy.assert_not_awaited()
@pytest.mark.asyncio
async def test_send_beliefs_to_bdi(agent):
agent.send = AsyncMock()
beliefs = {"user_said": ["hello", "world"]}
await agent._send_beliefs_to_bdi(beliefs, origin="origin")
agent.send.assert_awaited_once()
sent: InternalMessage = agent.send.call_args.args[0]
assert sent.to == settings.agent_settings.bdi_core_name
assert sent.thread == "beliefs"
assert json.loads(sent.body)["beliefs"] == beliefs

View File

@@ -0,0 +1,58 @@
import json
from unittest.mock import AsyncMock
import pytest
from control_backend.agents.bdi import (
TextBeliefExtractorAgent,
)
from control_backend.core.agent_system import InternalMessage
@pytest.fixture
def agent():
agent = TextBeliefExtractorAgent("text_belief_agent")
agent.send = AsyncMock()
return agent
def make_msg(sender: str, body: str, thread: str | None = None) -> InternalMessage:
return InternalMessage(to="unused", sender=sender, body=body, thread=thread)
@pytest.mark.asyncio
async def test_handle_message_ignores_other_agents(agent):
msg = make_msg("unknown", "some data", None)
await agent.handle_message(msg)
agent.send.assert_not_called() # noqa # `agent.send` has no such property, but we mock it.
@pytest.mark.asyncio
async def test_handle_message_from_transcriber(agent, mock_settings):
transcription = "hello world"
msg = make_msg(mock_settings.agent_settings.transcription_name, transcription, None)
await agent.handle_message(msg)
agent.send.assert_awaited_once() # noqa # `agent.send` has no such property, but we mock it.
sent: InternalMessage = agent.send.call_args.args[0] # noqa
assert sent.to == mock_settings.agent_settings.bdi_belief_collector_name
assert sent.thread == "beliefs"
parsed = json.loads(sent.body)
assert parsed == {"beliefs": {"user_said": [transcription]}, "type": "belief_extraction_text"}
@pytest.mark.asyncio
async def test_process_transcription_demo(agent, mock_settings):
transcription = "this is a test"
await agent._process_transcription_demo(transcription)
agent.send.assert_awaited_once() # noqa # `agent.send` has no such property, but we mock it.
sent: InternalMessage = agent.send.call_args.args[0] # noqa
assert sent.to == mock_settings.agent_settings.bdi_belief_collector_name
assert sent.thread == "beliefs"
parsed = json.loads(sent.body)
assert parsed["beliefs"]["user_said"] == [transcription]

View File

@@ -1,191 +0,0 @@
import json
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from spade.message import Message
from control_backend.agents.bdi.text_belief_extractor_agent.behaviours.text_belief_extractor_behaviour import ( # noqa: E501, We can't shorten this import.
TextBeliefExtractorBehaviour,
)
@pytest.fixture
def mock_settings():
"""
Mocks the settings object that the behaviour imports.
We patch it at the source where it's imported by the module under test.
"""
# Create a mock object that mimics the nested structure
settings_mock = MagicMock()
settings_mock.agent_settings.transcription_name = "transcriber"
settings_mock.agent_settings.bdi_belief_collector_name = "collector"
settings_mock.agent_settings.host = "fake.host"
# Use patch to replace the settings object during the test
# Adjust 'control_backend.behaviours.belief_from_text.settings' to where
# your behaviour file imports it from.
with patch(
"control_backend.agents.bdi.text_belief_extractor_agent.behaviours"
".text_belief_extractor_behaviour.settings",
settings_mock,
):
yield settings_mock
@pytest.fixture
def behavior(mock_settings):
"""
Creates an instance of the BDITextBeliefBehaviour behaviour and mocks its
agent, logger, send, and receive methods.
"""
b = TextBeliefExtractorBehaviour()
b.agent = MagicMock()
b.send = AsyncMock()
b.receive = AsyncMock()
return b
def create_mock_message(sender_node: str, body: str, thread: str) -> MagicMock:
"""Helper function to create a configured mock message."""
msg = MagicMock()
msg.sender.node = sender_node # MagicMock automatically creates nested mocks
msg.body = body
msg.thread = thread
return msg
@pytest.mark.asyncio
async def test_run_no_message(behavior):
"""
Tests the run() method when no message is received.
"""
# Arrange: Configure receive to return None
behavior.receive.return_value = None
# Act: Run the behavior
await behavior.run()
# Assert
# 1. Check that receive was called
behavior.receive.assert_called_once()
# 2. Check that no message was sent
behavior.send.assert_not_called()
@pytest.mark.asyncio
async def test_run_message_from_other_agent(behavior):
"""
Tests the run() method when a message is received from an
unknown agent (not the transcriber).
"""
# Arrange: Create a mock message from an unknown sender
mock_msg = create_mock_message("unknown", "some data", None)
behavior.receive.return_value = mock_msg
behavior._process_transcription_demo = MagicMock()
# Act
await behavior.run()
# Assert
# 1. Check that receive was called
behavior.receive.assert_called_once()
# 2. Check that _process_transcription_demo was not sent
behavior._process_transcription_demo.assert_not_called()
@pytest.mark.asyncio
async def test_run_message_from_transcriber_demo(behavior, mock_settings, monkeypatch):
"""
Tests the main success path: receiving a message from the
transcription agent, which triggers _process_transcription_demo.
"""
# Arrange: Create a mock message from the transcriber
transcription_text = "hello world"
mock_msg = create_mock_message(
mock_settings.agent_settings.transcription_name, transcription_text, None
)
behavior.receive.return_value = mock_msg
# Act
await behavior.run()
# Assert
# 1. Check that receive was called
behavior.receive.assert_called_once()
# 2. Check that send was called *once*
behavior.send.assert_called_once()
# 3. Deeply inspect the message that was sent
sent_msg: Message = behavior.send.call_args[0][0]
assert (
sent_msg.to
== mock_settings.agent_settings.bdi_belief_collector_name
+ "@"
+ mock_settings.agent_settings.host
)
# Check thread
assert sent_msg.thread == "beliefs"
# Parse the received JSON string back into a dict
expected_dict = {
"beliefs": {"user_said": [transcription_text]},
"type": "belief_extraction_text",
}
sent_dict = json.loads(sent_msg.body)
# Assert that the dictionaries are equal
assert sent_dict == expected_dict
@pytest.mark.asyncio
async def test_process_transcription_success(behavior, mock_settings):
"""
Tests the (currently unused) _process_transcription method's
success path, using its hardcoded mock response.
"""
# Arrange
test_text = "I am feeling happy"
# This is the hardcoded response inside the method
expected_response_body = '{"mood": [["happy"]]}'
# Act
await behavior._process_transcription(test_text)
# Assert
# 1. Check that a message was sent
behavior.send.assert_called_once()
# 2. Inspect the sent message
sent_msg: Message = behavior.send.call_args[0][0]
expected_to = (
mock_settings.agent_settings.bdi_belief_collector_name
+ "@"
+ mock_settings.agent_settings.host
)
assert str(sent_msg.to) == expected_to
assert sent_msg.thread == "beliefs"
assert sent_msg.body == expected_response_body
@pytest.mark.asyncio
async def test_process_transcription_json_decode_error(behavior, mock_settings):
"""
Tests the _process_transcription method's error handling
when the (mocked) response is invalid JSON.
We do this by patching json.loads to raise an error.
"""
# Arrange
test_text = "I am feeling happy"
# Patch json.loads to raise an error when called
with patch("json.loads", side_effect=json.JSONDecodeError("Mock error", "", 0)):
# Act
await behavior._process_transcription(test_text)
# Assert
# 1. Check that NO message was sent
behavior.send.assert_not_called()

View File

@@ -0,0 +1,336 @@
import asyncio
from unittest.mock import ANY, AsyncMock, MagicMock, patch
import pytest
from control_backend.agents.communication.ri_communication_agent import RICommunicationAgent
def speech_agent_path():
return "control_backend.agents.communication.ri_communication_agent.RobotSpeechAgent"
@pytest.fixture
def zmq_context(mocker):
mock_context = mocker.patch(
"control_backend.agents.communication.ri_communication_agent.Context.instance"
)
mock_context.return_value = MagicMock()
return mock_context
def negotiation_message(
actuation_port: int = 5556,
bind_main: bool = False,
bind_actuation: bool = True,
main_port: int = 5555,
):
return {
"endpoint": "negotiate/ports",
"data": [
{"id": "main", "port": main_port, "bind": bind_main},
{"id": "actuation", "port": actuation_port, "bind": bind_actuation},
],
}
@pytest.mark.asyncio
async def test_setup_success_connects_and_starts_robot(zmq_context):
fake_socket = zmq_context.return_value.socket.return_value
fake_socket.send_json = AsyncMock()
fake_socket.recv_json = AsyncMock(return_value=negotiation_message())
fake_socket.send_multipart = AsyncMock()
with patch(speech_agent_path(), autospec=True) as MockRobot:
robot_instance = MockRobot.return_value
robot_instance.start = AsyncMock()
agent = RICommunicationAgent("ri_comm", address="tcp://localhost:5555", bind=False)
agent.add_behavior = MagicMock()
await agent.setup()
fake_socket.connect.assert_any_call("tcp://localhost:5555")
fake_socket.send_json.assert_any_call({"endpoint": "negotiate/ports", "data": {}})
robot_instance.start.assert_awaited_once()
MockRobot.assert_called_once_with(ANY, address="tcp://*:5556", bind=True)
agent.add_behavior.assert_called_once()
assert agent.connected is True
@pytest.mark.asyncio
async def test_setup_binds_when_requested(zmq_context):
fake_socket = zmq_context.return_value.socket.return_value
fake_socket.send_json = AsyncMock()
fake_socket.recv_json = AsyncMock(return_value=negotiation_message(bind_main=True))
fake_socket.send_multipart = AsyncMock()
agent = RICommunicationAgent("ri_comm", address="tcp://localhost:5555", bind=True)
agent.add_behavior = MagicMock()
with patch(speech_agent_path(), autospec=True) as MockRobot:
MockRobot.return_value.start = AsyncMock()
await agent.setup()
fake_socket.bind.assert_any_call("tcp://localhost:5555")
agent.add_behavior.assert_called_once()
@pytest.mark.asyncio
async def test_negotiate_invalid_endpoint_retries(zmq_context):
fake_socket = zmq_context.return_value.socket.return_value
fake_socket.send_json = AsyncMock()
fake_socket.recv_json = AsyncMock(return_value={"endpoint": "ping", "data": {}})
fake_socket.send_multipart = AsyncMock()
agent = RICommunicationAgent("ri_comm", address="tcp://localhost:5555", bind=False)
agent._req_socket = fake_socket
success = await agent._negotiate_connection(max_retries=1)
assert success is False
@pytest.mark.asyncio
async def test_negotiate_timeout(zmq_context):
fake_socket = zmq_context.return_value.socket.return_value
fake_socket.send_json = AsyncMock()
fake_socket.recv_json = AsyncMock(side_effect=asyncio.TimeoutError)
fake_socket.send_multipart = AsyncMock()
agent = RICommunicationAgent("ri_comm", address="tcp://localhost:5555", bind=False)
agent._req_socket = fake_socket
success = await agent._negotiate_connection(max_retries=1)
assert success is False
@pytest.mark.asyncio
async def test_handle_negotiation_response_updates_req_socket(zmq_context):
fake_socket = zmq_context.return_value.socket.return_value
agent = RICommunicationAgent("ri_comm", address="tcp://localhost:5555", bind=False)
agent._req_socket = fake_socket
with patch(speech_agent_path(), autospec=True) as MockRobot:
MockRobot.return_value.start = AsyncMock()
await agent._handle_negotiation_response(
negotiation_message(
main_port=6000,
actuation_port=6001,
bind_main=False,
bind_actuation=False,
)
)
fake_socket.connect.assert_any_call("tcp://localhost:6000")
@pytest.mark.asyncio
async def test_handle_disconnection_publishes_and_reconnects():
pub_socket = AsyncMock()
agent = RICommunicationAgent("ri_comm")
agent.pub_socket = pub_socket
agent.connected = True
agent._negotiate_connection = AsyncMock(return_value=True)
await agent._handle_disconnection()
pub_socket.send_multipart.assert_awaited()
assert agent.connected is True
@pytest.mark.asyncio
async def test_listen_loop_handles_non_ping(zmq_context):
fake_socket = zmq_context.return_value.socket.return_value
fake_socket.send_json = AsyncMock()
async def recv_once():
agent._running = False
return {"endpoint": "negotiate/ports", "data": {}}
fake_socket.recv_json = recv_once
agent = RICommunicationAgent("ri_comm")
agent._req_socket = fake_socket
agent.pub_socket = AsyncMock()
agent.connected = True
agent._running = True
await agent._listen_loop()
fake_socket.send_json.assert_called()
@pytest.mark.asyncio
async def test_negotiate_unexpected_error(zmq_context):
fake_socket = zmq_context.return_value.socket.return_value
fake_socket.send_json = AsyncMock()
fake_socket.recv_json = AsyncMock(side_effect=Exception("boom"))
agent = RICommunicationAgent("ri_comm")
agent._req_socket = fake_socket
assert await agent._negotiate_connection(max_retries=1) is False
@pytest.mark.asyncio
async def test_negotiate_handle_response_error(zmq_context):
fake_socket = zmq_context.return_value.socket.return_value
fake_socket.send_json = AsyncMock()
fake_socket.recv_json = AsyncMock(return_value=negotiation_message())
agent = RICommunicationAgent("ri_comm")
agent._req_socket = fake_socket
agent._handle_negotiation_response = AsyncMock(side_effect=Exception("bad response"))
assert await agent._negotiate_connection(max_retries=1) is False
@pytest.mark.asyncio
async def test_setup_warns_on_failed_negotiate(zmq_context, mocker):
fake_socket = zmq_context.return_value.socket.return_value
fake_socket.send_json = AsyncMock()
fake_socket.recv_json = AsyncMock()
agent = RICommunicationAgent("ri_comm")
async def swallow(coro):
coro.close()
agent.add_behavior = swallow
agent._negotiate_connection = AsyncMock(return_value=False)
await agent.setup()
assert agent.connected is False
@pytest.mark.asyncio
async def test_handle_negotiation_response_unhandled_id():
agent = RICommunicationAgent("ri_comm")
await agent._handle_negotiation_response(
{"data": [{"id": "other", "port": 5000, "bind": False}]}
)
@pytest.mark.asyncio
async def test_stop_closes_sockets():
req = MagicMock()
pub = MagicMock()
agent = RICommunicationAgent("ri_comm")
agent._req_socket = req
agent.pub_socket = pub
await agent.stop()
req.close.assert_called_once()
pub.close.assert_called_once()
@pytest.mark.asyncio
async def test_listen_loop_not_connected(monkeypatch):
agent = RICommunicationAgent("ri_comm")
agent._running = True
agent.connected = False
agent._req_socket = AsyncMock()
async def fake_sleep(duration):
agent._running = False
monkeypatch.setattr("asyncio.sleep", fake_sleep)
await agent._listen_loop()
@pytest.mark.asyncio
async def test_listen_loop_send_and_recv_timeout():
req = AsyncMock()
req.send_json = AsyncMock(side_effect=TimeoutError)
req.recv_json = AsyncMock(side_effect=TimeoutError)
agent = RICommunicationAgent("ri_comm")
agent._req_socket = req
agent.pub_socket = AsyncMock()
agent.connected = True
agent._running = True
async def stop_run():
agent._running = False
agent._handle_disconnection = AsyncMock(side_effect=stop_run)
await agent._listen_loop()
agent._handle_disconnection.assert_awaited()
@pytest.mark.asyncio
async def test_listen_loop_missing_endpoint(monkeypatch):
req = AsyncMock()
req.send_json = AsyncMock()
async def recv_once():
agent._running = False
return {"data": {}}
req.recv_json = recv_once
agent = RICommunicationAgent("ri_comm")
agent._req_socket = req
agent.pub_socket = AsyncMock()
agent.connected = True
agent._running = True
await agent._listen_loop()
@pytest.mark.asyncio
async def test_listen_loop_generic_exception():
req = AsyncMock()
req.send_json = AsyncMock()
req.recv_json = AsyncMock(side_effect=ValueError("boom"))
agent = RICommunicationAgent("ri_comm")
agent._req_socket = req
agent.pub_socket = AsyncMock()
agent.connected = True
agent._running = True
with pytest.raises(ValueError):
await agent._listen_loop()
@pytest.mark.asyncio
async def test_handle_disconnection_timeout(monkeypatch):
pub = AsyncMock()
pub.send_multipart = AsyncMock(side_effect=TimeoutError)
agent = RICommunicationAgent("ri_comm")
agent.pub_socket = pub
agent._negotiate_connection = AsyncMock(return_value=False)
await agent._handle_disconnection()
pub.send_multipart.assert_awaited()
@pytest.mark.asyncio
async def test_listen_loop_ping_sends_internal(zmq_context):
fake_socket = zmq_context.return_value.socket.return_value
fake_socket.send_json = AsyncMock()
pub_socket = AsyncMock()
agent = RICommunicationAgent("ri_comm")
agent._req_socket = fake_socket
agent.pub_socket = pub_socket
agent.connected = True
agent._running = True
async def recv_once():
agent._running = False
return {"endpoint": "ping", "data": {}}
fake_socket.recv_json = recv_once
await agent._listen_loop()
pub_socket.send_multipart.assert_awaited()

View File

@@ -0,0 +1,124 @@
"""Mocks `httpx` and tests chunking logic."""
from unittest.mock import AsyncMock, MagicMock, patch
import httpx
import pytest
from control_backend.agents.llm.llm_agent import LLMAgent, LLMInstructions
from control_backend.core.agent_system import InternalMessage
@pytest.fixture
def mock_httpx_client():
with patch("httpx.AsyncClient") as mock_cls:
mock_client = AsyncMock()
mock_cls.return_value.__aenter__.return_value = mock_client
yield mock_client
@pytest.mark.asyncio
async def test_llm_processing_success(mock_httpx_client, mock_settings):
# Setup the mock response for the stream
mock_response = MagicMock()
mock_response.raise_for_status = MagicMock()
# Simulate stream lines
lines = [
b'data: {"choices": [{"delta": {"content": "Hello"}}]}',
b'data: {"choices": [{"delta": {"content": " world"}}]}',
b'data: {"choices": [{"delta": {"content": "."}}]}',
b"data: [DONE]",
]
async def aiter_lines_gen():
for line in lines:
yield line.decode()
mock_response.aiter_lines.side_effect = aiter_lines_gen
mock_stream_context = MagicMock()
mock_stream_context.__aenter__ = AsyncMock(return_value=mock_response)
mock_stream_context.__aexit__ = AsyncMock(return_value=None)
# Configure the client
mock_httpx_client.stream = MagicMock(return_value=mock_stream_context)
# Setup Agent
agent = LLMAgent("llm_agent")
agent.send = AsyncMock() # Mock the send method to verify replies
# Simulate receiving a message from BDI
msg = InternalMessage(
to="llm_agent", sender=mock_settings.agent_settings.bdi_core_name, body="Hi"
)
await agent.handle_message(msg)
# Verification
# "Hello world." constitutes one sentence/chunk based on punctuation split
# The agent should call send once with the full sentence
assert agent.send.called
args = agent.send.call_args[0][0]
assert args.to == mock_settings.agent_settings.bdi_core_name
assert "Hello world." in args.body
@pytest.mark.asyncio
async def test_llm_processing_errors(mock_httpx_client, mock_settings):
agent = LLMAgent("llm_agent")
agent.send = AsyncMock()
msg = InternalMessage(to="llm", sender=mock_settings.agent_settings.bdi_core_name, body="Hi")
# HTTP Error
mock_httpx_client.stream = MagicMock(side_effect=httpx.HTTPError("Fail"))
await agent.handle_message(msg)
assert "LLM service unavailable." in agent.send.call_args[0][0].body
# General Exception
agent.send.reset_mock()
mock_httpx_client.stream = MagicMock(side_effect=Exception("Boom"))
await agent.handle_message(msg)
assert "Error processing the request." in agent.send.call_args[0][0].body
@pytest.mark.asyncio
async def test_llm_json_error(mock_httpx_client, mock_settings):
# Test malformed JSON in stream
mock_response = MagicMock()
mock_response.raise_for_status = MagicMock()
async def aiter_lines_gen():
yield "data: {bad_json"
yield "data: [DONE]"
mock_response.aiter_lines.side_effect = aiter_lines_gen
mock_stream_context = MagicMock()
mock_stream_context.__aenter__ = AsyncMock(return_value=mock_response)
mock_stream_context.__aexit__ = AsyncMock(return_value=None)
mock_httpx_client.stream = MagicMock(return_value=mock_stream_context)
agent = LLMAgent("llm_agent")
agent.send = AsyncMock()
with patch.object(agent.logger, "error") as log:
msg = InternalMessage(
to="llm", sender=mock_settings.agent_settings.bdi_core_name, body="Hi"
)
await agent.handle_message(msg)
log.assert_called() # Should log JSONDecodeError
def test_llm_instructions():
# Full custom
instr = LLMInstructions(norms="N", goals="G")
text = instr.build_developer_instruction()
assert "Norms to follow:\nN" in text
assert "Goals to reach:\nG" in text
# Defaults
instr_def = LLMInstructions()
text_def = instr_def.build_developer_instruction()
assert "Norms to follow" in text_def
assert "Goals to reach" in text_def

View File

@@ -0,0 +1,122 @@
import asyncio
from unittest.mock import AsyncMock, MagicMock, patch
import numpy as np
import pytest
from control_backend.agents.perception.transcription_agent.speech_recognizer import (
MLXWhisperSpeechRecognizer,
OpenAIWhisperSpeechRecognizer,
SpeechRecognizer,
)
from control_backend.agents.perception.transcription_agent.transcription_agent import (
TranscriptionAgent,
)
@pytest.mark.asyncio
async def test_transcription_agent_flow(mock_zmq_context):
mock_sub = MagicMock()
mock_sub.recv = AsyncMock()
# Setup context to return this specific mock socket
mock_zmq_context.instance.return_value.socket.return_value = mock_sub
# Data: [Audio Bytes, Cancel Loop]
fake_audio = np.zeros(16000, dtype=np.float32).tobytes()
mock_sub.recv.side_effect = [fake_audio, asyncio.CancelledError()]
# Mock Recognizer
with patch.object(SpeechRecognizer, "best_type") as mock_best:
mock_recognizer = MagicMock()
mock_recognizer.recognize_speech.return_value = "Hello"
mock_best.return_value = mock_recognizer
agent = TranscriptionAgent("tcp://in")
agent.send = AsyncMock()
agent._running = True
agent.add_behavior = AsyncMock()
await agent.setup()
try:
await agent._transcribing_loop()
except asyncio.CancelledError:
pass
# Check transcription happened
assert mock_recognizer.recognize_speech.called
# Check sending
assert agent.send.called
assert agent.send.call_args[0][0].body == "Hello"
await agent.stop()
@pytest.mark.asyncio
async def test_transcription_empty(mock_zmq_context):
mock_sub = MagicMock()
mock_sub.recv = AsyncMock()
mock_zmq_context.instance.return_value.socket.return_value = mock_sub
# Return valid audio, but recognizer returns empty string
fake_audio = np.zeros(10, dtype=np.float32).tobytes()
mock_sub.recv.side_effect = [fake_audio, asyncio.CancelledError()]
with patch.object(SpeechRecognizer, "best_type") as mock_best:
mock_recognizer = MagicMock()
mock_recognizer.recognize_speech.return_value = ""
mock_best.return_value = mock_recognizer
agent = TranscriptionAgent("tcp://in")
agent.send = AsyncMock()
await agent.setup()
try:
await agent._transcribing_loop()
except asyncio.CancelledError:
pass
# Should NOT send message
agent.send.assert_not_called()
def test_speech_recognizer_factory():
# Test Factory Logic
with patch("torch.mps.is_available", return_value=True):
assert isinstance(SpeechRecognizer.best_type(), MLXWhisperSpeechRecognizer)
with patch("torch.mps.is_available", return_value=False):
assert isinstance(SpeechRecognizer.best_type(), OpenAIWhisperSpeechRecognizer)
def test_openai_recognizer():
with patch("whisper.load_model") as load_mock:
with patch("whisper.transcribe") as trans_mock:
rec = OpenAIWhisperSpeechRecognizer()
rec.load_model()
load_mock.assert_called()
trans_mock.return_value = {"text": "Hi"}
res = rec.recognize_speech(np.zeros(10))
assert res == "Hi"
def test_mlx_recognizer():
# Fix: On Linux, 'mlx_whisper' isn't imported by the module, so it's missing from dir().
# We must use create=True to inject it into the module namespace during the test.
module_path = "control_backend.agents.perception.transcription_agent.speech_recognizer"
with patch("sys.platform", "darwin"):
with patch(f"{module_path}.mlx_whisper", create=True) as mlx_mock:
with patch(f"{module_path}.ModelHolder", create=True) as holder_mock:
# We also need to mock mlx.core if it's used for types/constants
with patch(f"{module_path}.mx", create=True):
rec = MLXWhisperSpeechRecognizer()
rec.load_model()
holder_mock.get_model.assert_called()
mlx_mock.transcribe.return_value = {"text": "Hi"}
res = rec.recognize_speech(np.zeros(10))
assert res == "Hi"

View File

@@ -16,8 +16,8 @@ async def test_socket_poller_with_data(socket, mocker):
socket_data = b"test" socket_data = b"test"
socket.recv.return_value = socket_data socket.recv.return_value = socket_data
mock_poller: MagicMock = mocker.patch("control_backend.agents.perception.vad_agent.zmq.Poller") mock_poller: MagicMock = mocker.patch("control_backend.agents.perception.vad_agent.azmq.Poller")
mock_poller.return_value.poll.return_value = [(socket, zmq.POLLIN)] mock_poller.return_value.poll = AsyncMock(return_value=[(socket, zmq.POLLIN)])
poller = SocketPoller(socket) poller = SocketPoller(socket)
# Calling `poll` twice to be able to check that the poller is reused # Calling `poll` twice to be able to check that the poller is reused
@@ -35,8 +35,8 @@ async def test_socket_poller_with_data(socket, mocker):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_socket_poller_no_data(socket, mocker): async def test_socket_poller_no_data(socket, mocker):
mock_poller: MagicMock = mocker.patch("control_backend.agents.perception.vad_agent.zmq.Poller") mock_poller: MagicMock = mocker.patch("control_backend.agents.perception.vad_agent.azmq.Poller")
mock_poller.return_value.poll.return_value = [] mock_poller.return_value.poll = AsyncMock(return_value=[])
poller = SocketPoller(socket) poller = SocketPoller(socket)
data = await poller.poll() data = await poller.poll()

View File

@@ -3,12 +3,7 @@ from unittest.mock import AsyncMock, MagicMock
import numpy as np import numpy as np
import pytest import pytest
from control_backend.agents.perception.vad_agent import StreamingBehaviour from control_backend.agents.perception.vad_agent import VADAgent
@pytest.fixture
def audio_in_socket():
return AsyncMock()
@pytest.fixture @pytest.fixture
@@ -17,22 +12,8 @@ def audio_out_socket():
@pytest.fixture @pytest.fixture
def mock_agent(mocker): def vad_agent(audio_out_socket):
"""Fixture to create a mock BDIAgent.""" return VADAgent("tcp://localhost:5555", False)
agent = MagicMock()
agent.jid = "vad_agent@test"
return agent
@pytest.fixture
def streaming(audio_in_socket, audio_out_socket, mock_agent):
import torch
torch.hub.load.return_value = (..., ...) # Mock
streaming = StreamingBehaviour(audio_in_socket, audio_out_socket)
streaming._ready = True
streaming.agent = mock_agent
return streaming
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
@@ -61,25 +42,40 @@ async def simulate_streaming_with_probabilities(streaming, probabilities: list[f
""" """
model_item = MagicMock() model_item = MagicMock()
model_item.item.side_effect = probabilities model_item.item.side_effect = probabilities
streaming.model = MagicMock() streaming.model = MagicMock(return_value=model_item)
streaming.model.return_value = model_item
audio_in_poller = AsyncMock() # Prepare deterministic audio chunks and a poller that stops the loop when exhausted
audio_in_poller.poll.return_value = np.empty(shape=512, dtype=np.float32) chunk_bytes = np.empty(shape=512, dtype=np.float32).tobytes()
streaming.audio_in_poller = audio_in_poller chunks = [chunk_bytes for _ in probabilities]
for _ in probabilities: class DummyPoller:
await streaming.run() def __init__(self, data, agent):
self.data = data
self.agent = agent
async def poll(self, timeout_ms=None):
if self.data:
return self.data.pop(0)
# Stop the loop cleanly once we've consumed all chunks
self.agent._running = False
return None
streaming.audio_in_poller = DummyPoller(chunks, streaming)
streaming._ready = AsyncMock()
streaming._running = True
await streaming._streaming_loop()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_voice_activity_detected(audio_in_socket, audio_out_socket, streaming): async def test_voice_activity_detected(audio_out_socket, vad_agent):
""" """
Test a scenario where there is voice activity detected between silences. Test a scenario where there is voice activity detected between silences.
""" """
speech_chunk_count = 5 speech_chunk_count = 5
probabilities = [0.0] * 5 + [1.0] * speech_chunk_count + [0.0] * 5 probabilities = [0.0] * 5 + [1.0] * speech_chunk_count + [0.0] * 5
await simulate_streaming_with_probabilities(streaming, probabilities) vad_agent.audio_out_socket = audio_out_socket
await simulate_streaming_with_probabilities(vad_agent, probabilities)
audio_out_socket.send.assert_called_once() audio_out_socket.send.assert_called_once()
data = audio_out_socket.send.call_args[0][0] data = audio_out_socket.send.call_args[0][0]
@@ -88,7 +84,7 @@ async def test_voice_activity_detected(audio_in_socket, audio_out_socket, stream
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_voice_activity_short_pause(audio_in_socket, audio_out_socket, streaming): async def test_voice_activity_short_pause(audio_out_socket, vad_agent):
""" """
Test a scenario where there is a short pause between speech, checking whether it ignores the Test a scenario where there is a short pause between speech, checking whether it ignores the
short pause. short pause.
@@ -97,7 +93,8 @@ async def test_voice_activity_short_pause(audio_in_socket, audio_out_socket, str
probabilities = ( probabilities = (
[0.0] * 5 + [1.0] * speech_chunk_count + [0.0] + [1.0] * speech_chunk_count + [0.0] * 5 [0.0] * 5 + [1.0] * speech_chunk_count + [0.0] + [1.0] * speech_chunk_count + [0.0] * 5
) )
await simulate_streaming_with_probabilities(streaming, probabilities) vad_agent.audio_out_socket = audio_out_socket
await simulate_streaming_with_probabilities(vad_agent, probabilities)
audio_out_socket.send.assert_called_once() audio_out_socket.send.assert_called_once()
data = audio_out_socket.send.call_args[0][0] data = audio_out_socket.send.call_args[0][0]
@@ -107,15 +104,22 @@ async def test_voice_activity_short_pause(audio_in_socket, audio_out_socket, str
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_no_data(audio_in_socket, audio_out_socket, streaming): async def test_no_data(audio_out_socket, vad_agent):
""" """
Test a scenario where there is no data received. This should not cause errors. Test a scenario where there is no data received. This should not cause errors.
""" """
audio_in_poller = AsyncMock()
audio_in_poller.poll.return_value = None
streaming.audio_in_poller = audio_in_poller
await streaming.run() class DummyPoller:
async def poll(self, timeout_ms=None):
vad_agent._running = False
return None
vad_agent.audio_out_socket = audio_out_socket
vad_agent.audio_in_poller = DummyPoller()
vad_agent._ready = AsyncMock()
vad_agent._running = True
await vad_agent._streaming_loop()
audio_out_socket.send.assert_not_called() audio_out_socket.send.assert_not_called()
assert len(streaming.audio_buffer) == 0 assert len(vad_agent.audio_buffer) == 0

View File

@@ -1,66 +1,43 @@
import sys from unittest.mock import MagicMock, patch
from unittest.mock import MagicMock
import pytest
from control_backend.core.agent_system import _agent_directory
def pytest_configure(config): @pytest.fixture(autouse=True)
def reset_agent_directory():
""" """
This hook runs at the start of the pytest session, before any tests are Automatically clears the global agent directory before and after each test
collected. It mocks heavy or unavailable modules to prevent ImportErrors. to prevent state leakage between tests.
""" """
# --- Mock spade and spade-bdi --- _agent_directory.clear()
mock_agentspeak = MagicMock() yield
mock_httpx = MagicMock() _agent_directory.clear()
mock_pydantic = MagicMock()
mock_spade = MagicMock()
mock_spade.agent = MagicMock()
mock_spade.behaviour = MagicMock()
mock_spade.message = MagicMock()
mock_spade_bdi = MagicMock()
mock_spade_bdi.bdi = MagicMock()
mock_spade.agent.Message = MagicMock()
mock_spade.behaviour.CyclicBehaviour = type("CyclicBehaviour", (object,), {})
mock_spade_bdi.bdi.BDIAgent = type("BDIAgent", (object,), {})
sys.modules["agentspeak"] = mock_agentspeak @pytest.fixture
sys.modules["httpx"] = mock_httpx def mock_settings():
sys.modules["pydantic"] = mock_pydantic with patch("control_backend.core.config.settings") as mock:
sys.modules["spade"] = mock_spade # Set default values that match the pydantic model defaults
sys.modules["spade.agent"] = mock_spade.agent # to avoid AttributeErrors during tests
sys.modules["spade.behaviour"] = mock_spade.behaviour mock.zmq_settings.internal_pub_address = "tcp://localhost:5560"
sys.modules["spade.message"] = mock_spade.message mock.zmq_settings.internal_sub_address = "tcp://localhost:5561"
sys.modules["spade_bdi"] = mock_spade_bdi mock.zmq_settings.ri_command_address = "tcp://localhost:0000"
sys.modules["spade_bdi.bdi"] = mock_spade_bdi.bdi mock.agent_settings.bdi_core_name = "bdi_core_agent"
mock.agent_settings.bdi_belief_collector_name = "belief_collector_agent"
mock.agent_settings.llm_name = "llm_agent"
mock.agent_settings.robot_speech_name = "robot_speech_agent"
mock.agent_settings.transcription_name = "transcription_agent"
mock.agent_settings.text_belief_extractor_name = "text_belief_extractor_agent"
mock.agent_settings.vad_name = "vad_agent"
mock.behaviour_settings.sleep_s = 0.01 # Speed up tests
mock.behaviour_settings.comm_setup_max_retries = 1
yield mock
# --- Mock the config module to prevent Pydantic ImportError ---
mock_config_module = MagicMock()
# The code under test does `from ... import settings`, so our mock module @pytest.fixture
# must have a `settings` attribute. We'll make it a MagicMock so we can def mock_zmq_context():
# configure it later in our tests using mocker.patch. with patch("zmq.asyncio.Context") as mock:
mock_config_module.settings = MagicMock() mock.instance.return_value = MagicMock()
yield mock
sys.modules["control_backend.core.config"] = mock_config_module
# --- Mock torch and zmq for VAD ---
mock_torch = MagicMock()
mock_zmq = MagicMock()
mock_zmq.asyncio = mock_zmq
# In individual tests, these can be imported and the return values changed
sys.modules["torch"] = mock_torch
sys.modules["zmq"] = mock_zmq
sys.modules["zmq.asyncio"] = mock_zmq.asyncio
# --- Mock whisper ---
mock_whisper = MagicMock()
mock_mlx = MagicMock()
mock_mlx.core = MagicMock()
mock_mlx_whisper = MagicMock()
mock_mlx_whisper.transcribe = MagicMock()
sys.modules["whisper"] = mock_whisper
sys.modules["mlx"] = mock_mlx
sys.modules["mlx.core"] = mock_mlx
sys.modules["mlx_whisper"] = mock_mlx_whisper
sys.modules["mlx_whisper.transcribe"] = mock_mlx_whisper.transcribe

View File

@@ -0,0 +1,72 @@
"""Test the base class logic, message passing and background task handling."""
import asyncio
import logging
from unittest.mock import AsyncMock
import pytest
from control_backend.core.agent_system import AgentDirectory, BaseAgent, InternalMessage
class ConcreteTestAgent(BaseAgent):
logger = logging.getLogger("test")
def __init__(self, name: str):
super().__init__(name)
self.received = []
async def setup(self):
pass
async def handle_message(self, msg: InternalMessage):
self.received.append(msg)
if msg.body == "stop":
await self.stop()
@pytest.mark.asyncio
async def test_agent_lifecycle():
agent = ConcreteTestAgent("lifecycle_agent")
await agent.start()
assert agent._running is True
# Test background task
async def dummy_task():
pass
task = agent.add_behavior(dummy_task())
assert task in agent._tasks
await task
# Wait for task to finish
assert task not in agent._tasks
assert len(agent._tasks) == 2 # message handling tasks are still running
await agent.stop()
assert agent._running is False
await asyncio.sleep(0.01)
# Tasks should be cancelled
assert len(agent._tasks) == 0
@pytest.mark.asyncio
async def test_send_unknown_agent():
agent = ConcreteTestAgent("sender")
msg = InternalMessage(to="unknown_receiver", sender="sender", body="boo")
agent._internal_pub_socket = AsyncMock()
await agent.send(msg)
agent._internal_pub_socket.send_multipart.assert_called()
@pytest.mark.asyncio
async def test_get_agent():
agent = ConcreteTestAgent("registrant")
assert AgentDirectory.get("registrant") == agent
assert AgentDirectory.get("non_existent") is None

View File

@@ -0,0 +1,14 @@
"""Test if settings load correctly and environment variables override defaults."""
from control_backend.core.config import Settings
def test_default_settings():
settings = Settings()
assert settings.app_title == "PepperPlus"
def test_env_override(monkeypatch):
monkeypatch.setenv("APP_TITLE", "TestPepper")
settings = Settings()
assert settings.app_title == "TestPepper"

View File

@@ -0,0 +1,88 @@
import logging
from unittest.mock import mock_open, patch
import pytest
from control_backend.logging.setup_logging import add_logging_level, setup_logging
def test_add_logging_level():
# Add a unique level to avoid conflicts with other tests/libraries
level_name = "TESTLEVEL"
level_num = 35
add_logging_level(level_name, level_num)
assert logging.getLevelName(level_num) == level_name
assert hasattr(logging, level_name)
assert hasattr(logging.getLoggerClass(), level_name.lower())
# Test functionality
logger = logging.getLogger("test_custom_level")
with patch.object(logger, "_log") as mock_log:
getattr(logger, level_name.lower())("message")
mock_log.assert_called_with(level_num, "message", ())
# Test duplicates
with pytest.raises(AttributeError):
add_logging_level(level_name, level_num)
with pytest.raises(AttributeError):
add_logging_level("INFO", 20) # Existing level
def test_setup_logging_no_file(caplog):
with patch("os.path.exists", return_value=False):
setup_logging("dummy.yaml")
assert "Logging config file not found" in caplog.text
def test_setup_logging_yaml_error(caplog):
with patch("os.path.exists", return_value=True):
with patch("builtins.open", mock_open(read_data="invalid: [yaml")):
with patch("logging.config.dictConfig") as mock_dict_config:
setup_logging("config.yaml")
# Verify we logged the warning
assert "Could not load logging configuration" in caplog.text
# Verify dictConfig was called with empty dict (which would crash real dictConfig)
mock_dict_config.assert_called_with({})
assert "Could not load logging configuration" in caplog.text
def test_setup_logging_success():
config_data = """
version: 1
handlers:
console:
class: logging.StreamHandler
root:
handlers: [console]
level: INFO
custom_levels:
MYLEVEL: 15
"""
with patch("os.path.exists", return_value=True):
with patch("builtins.open", mock_open(read_data=config_data)):
with patch("logging.config.dictConfig") as mock_dict_config:
setup_logging("config.yaml")
mock_dict_config.assert_called()
assert hasattr(logging, "MYLEVEL")
def test_setup_logging_zmq_handler(mock_zmq_context):
config_data = """
version: 1
handlers:
ui:
class: logging.NullHandler
# In real config this would be a zmq handler, but for unit test logic
# we just want to see if the socket injection happens
"""
with patch("os.path.exists", return_value=True):
with patch("builtins.open", mock_open(read_data=config_data)):
with patch("logging.config.dictConfig") as mock_dict_config:
setup_logging("config.yaml")
args = mock_dict_config.call_args[0][0]
assert "interface_or_socket" in args["handlers"]["ui"]

959
uv.lock generated

File diff suppressed because it is too large Load Diff