The Big One #43

Merged
k.marinus merged 93 commits from feat/reset-experiment-and-phase into dev 2026-01-26 19:20:45 +00:00
34 changed files with 2172 additions and 81 deletions
Showing only changes of commit 57fe3ae3f6 - Show all commits

View File

@@ -3,6 +3,7 @@ version: 1
custom_levels:
OBSERVATION: 25
ACTION: 26
LLM: 9
formatters:
# Console output
@@ -26,7 +27,7 @@ handlers:
stream: ext://sys.stdout
ui:
class: zmq.log.handlers.PUBHandler
level: DEBUG
level: LLM
formatter: json_experiment
# Level of external libraries
@@ -36,5 +37,5 @@ root:
loggers:
control_backend:
level: DEBUG
level: LLM
handlers: [ui]

View File

@@ -1 +1,2 @@
from .robot_gesture_agent import RobotGestureAgent as RobotGestureAgent
from .robot_speech_agent import RobotSpeechAgent as RobotSpeechAgent

View File

@@ -0,0 +1,162 @@
import json
import zmq
import zmq.asyncio as azmq
from control_backend.agents import BaseAgent
from control_backend.core.agent_system import InternalMessage
from control_backend.core.config import settings
from control_backend.schemas.ri_message import GestureCommand, RIEndpoint
class RobotGestureAgent(BaseAgent):
"""
This agent acts as a bridge between the control backend and the Robot Interface (RI).
It receives gesture commands from other agents or from the UI,
and forwards them to the robot via a ZMQ PUB socket.
:ivar subsocket: ZMQ SUB socket for receiving external commands (e.g., from UI).
:ivar pubsocket: ZMQ PUB socket for sending commands to the Robot Interface.
:ivar address: Address to bind/connect the PUB socket.
:ivar bind: Whether to bind or connect the PUB socket.
:ivar gesture_data: A list of strings for available gestures
"""
subsocket: azmq.Socket
repsocket: azmq.Socket
pubsocket: azmq.Socket
address = ""
bind = False
gesture_data = []
def __init__(
self,
name: str,
address=settings.zmq_settings.ri_command_address,
bind=False,
gesture_data=None,
):
self.gesture_data = gesture_data or []
super().__init__(name)
self.address = address
self.bind = bind
async def setup(self):
"""
Initialize the agent.
1. Sets up the PUB socket to talk to the robot.
2. Sets up the SUB socket to listen for "command" topics (from UI/External).
3. Starts the loop for handling ZMQ commands.
"""
self.logger.info("Setting up %s", self.name)
context = azmq.Context.instance()
# To the robot
self.pubsocket = context.socket(zmq.PUB)
if self.bind:
self.pubsocket.bind(self.address)
else:
self.pubsocket.connect(self.address)
# Receive internal topics regarding commands
self.subsocket = context.socket(zmq.SUB)
self.subsocket.connect(settings.zmq_settings.internal_sub_address)
self.subsocket.setsockopt(zmq.SUBSCRIBE, b"command")
self.subsocket.setsockopt(zmq.SUBSCRIBE, b"send_gestures")
# REP socket for replying to gesture requests
self.repsocket = context.socket(zmq.REP)
self.repsocket.bind(settings.zmq_settings.internal_gesture_rep_adress)
self.add_behavior(self._zmq_command_loop())
self.add_behavior(self._fetch_gestures_loop())
self.logger.info("Finished setting up %s", self.name)
async def stop(self):
if self.subsocket:
self.subsocket.close()
if self.pubsocket:
self.pubsocket.close()
await super().stop()
async def handle_message(self, msg: InternalMessage):
"""
Handle commands received from other internal Python agents.
Validates the message as a :class:`GestureCommand` and forwards it to the robot.
:param msg: The internal message containing the command.
"""
try:
gesture_command = GestureCommand.model_validate_json(msg.body)
if gesture_command.endpoint == RIEndpoint.GESTURE_TAG:
if gesture_command.data not in self.gesture_data:
self.logger.warning(
"Received gesture tag '%s' which is not in available tags. Early returning",
gesture_command.data,
)
return
await self.pubsocket.send_json(gesture_command.model_dump())
except Exception:
self.logger.exception("Error processing internal message.")
async def _zmq_command_loop(self):
"""
Loop to handle commands received via ZMQ (e.g., from the UI).
Listens on the 'command' topic, validates the JSON and forwards it to the robot.
"""
while self._running:
try:
topic, body = await self.subsocket.recv_multipart()
# Don't process send_gestures here
if topic != b"command":
continue
body = json.loads(body)
gesture_command = GestureCommand.model_validate(body)
if gesture_command.endpoint == RIEndpoint.GESTURE_TAG:
if gesture_command.data not in self.gesture_data:
self.logger.warning(
"Received gesture tag '%s' which is not in available tags.\
Early returning",
gesture_command.data,
)
continue
await self.pubsocket.send_json(gesture_command.model_dump())
except Exception:
self.logger.exception("Error processing ZMQ message.")
async def _fetch_gestures_loop(self):
"""
Loop to handle fetching gestures received via ZMQ (e.g., from the UI).
Listens on the 'send_gestures' topic, and returns a list on the get_gestures topic.
"""
while self._running:
try:
# Get a request
body = await self.repsocket.recv()
# Figure out amount, if specified
try:
body = json.loads(body)
except json.JSONDecodeError:
body = None
amount = None
if isinstance(body, int):
amount = body
# Fetch tags from gesture data and respond
tags = self.gesture_data[:amount] if amount else self.gesture_data
response = json.dumps({"tags": tags}).encode()
await self.repsocket.send(response)
except Exception:
self.logger.exception("Error fetching gesture tags.")

View File

@@ -29,7 +29,7 @@ class RobotSpeechAgent(BaseAgent):
def __init__(
self,
name: str,
address=settings.zmq_settings.ri_command_address,
address: str,
bind=False,
):
super().__init__(name)

View File

@@ -6,9 +6,11 @@ import zmq.asyncio as azmq
from zmq.asyncio import Context
from control_backend.agents import BaseAgent
from control_backend.agents.actuation.robot_gesture_agent import RobotGestureAgent
from control_backend.core.config import settings
from ..actuation.robot_speech_agent import RobotSpeechAgent
from ..perception import VADAgent
class RICommunicationAgent(BaseAgent):
@@ -179,12 +181,24 @@ class RICommunicationAgent(BaseAgent):
else:
self._req_socket.bind(addr)
case "actuation":
ri_commands_agent = RobotSpeechAgent(
gesture_data = port_data.get("gestures", [])
robot_speech_agent = RobotSpeechAgent(
settings.agent_settings.robot_speech_name,
address=addr,
bind=bind,
)
await ri_commands_agent.start()
robot_gesture_agent = RobotGestureAgent(
settings.agent_settings.robot_gesture_name,
address=addr,
bind=bind,
gesture_data=gesture_data,
)
await robot_speech_agent.start()
await asyncio.sleep(0.1) # Small delay
await robot_gesture_agent.start()
case "audio":
vad_agent = VADAgent(audio_in_address=addr, audio_in_bind=bind)
await vad_agent.start()
case _:
self.logger.warning("Unhandled negotiation id: %s", id)

View File

@@ -125,7 +125,7 @@ class LLMAgent(BaseAgent):
full_message += token
current_chunk += token
self.logger.info(
self.logger.llm(
"Received token: %s",
full_message,
extra={"reference": message_id}, # Used in the UI to update old logs

View File

@@ -8,6 +8,7 @@ import zmq.asyncio as azmq
from control_backend.agents import BaseAgent
from control_backend.core.config import settings
from ...schemas.program_status import PROGRAM_STATUS, ProgramStatus
from .transcription_agent.transcription_agent import TranscriptionAgent
@@ -61,6 +62,7 @@ class VADAgent(BaseAgent):
:ivar audio_in_address: Address of the input audio stream.
:ivar audio_in_bind: Whether to bind or connect to the input address.
:ivar audio_out_socket: ZMQ PUB socket for sending speech fragments.
:ivar program_sub_socket: ZMQ SUB socket for receiving program status updates.
"""
def __init__(self, audio_in_address: str, audio_in_bind: bool):
@@ -79,6 +81,8 @@ class VADAgent(BaseAgent):
self.audio_out_socket: azmq.Socket | None = None
self.audio_in_poller: SocketPoller | None = None
self.program_sub_socket: azmq.Socket | None = None
self.audio_buffer = np.array([], dtype=np.float32)
self.i_since_speech = settings.behaviour_settings.vad_initial_since_speech
self._ready = asyncio.Event()
@@ -90,9 +94,10 @@ class VADAgent(BaseAgent):
1. Connects audio input socket.
2. Binds audio output socket (random port).
3. Loads VAD model from Torch Hub.
4. Starts the streaming loop.
5. Instantiates and starts the :class:`TranscriptionAgent` with the output address.
3. Connects to program communication socket.
4. Loads VAD model from Torch Hub.
5. Starts the streaming loop.
6. Instantiates and starts the :class:`TranscriptionAgent` with the output address.
"""
self.logger.info("Setting up %s", self.name)
@@ -105,6 +110,11 @@ class VADAgent(BaseAgent):
return
audio_out_address = f"tcp://localhost:{audio_out_port}"
# Connect to internal communication socket
self.program_sub_socket = azmq.Context.instance().socket(zmq.SUB)
self.program_sub_socket.connect(settings.zmq_settings.internal_sub_address)
self.program_sub_socket.subscribe(PROGRAM_STATUS)
# Initialize VAD model
try:
self.model, _ = torch.hub.load(
@@ -117,10 +127,8 @@ class VADAgent(BaseAgent):
await self.stop()
return
# Warmup/reset
await self.reset_stream()
self.add_behavior(self._streaming_loop())
self.add_behavior(self._status_loop())
# Start agents dependent on the output audio fragments here
transcriber = TranscriptionAgent(audio_out_address)
@@ -165,7 +173,7 @@ class VADAgent(BaseAgent):
self.audio_out_socket = None
return None
async def reset_stream(self):
async def _reset_stream(self):
"""
Clears the ZeroMQ queue and sets ready state.
"""
@@ -176,6 +184,23 @@ class VADAgent(BaseAgent):
self.logger.info(f"Discarded {discarded} audio packets before starting.")
self._ready.set()
async def _status_loop(self):
"""Loop for checking program status. Only start listening if program is RUNNING."""
while self._running:
topic, body = await self.program_sub_socket.recv_multipart()
if topic != PROGRAM_STATUS:
continue
if body != ProgramStatus.RUNNING.value:
continue
# Program is now running, we can start our stream
await self._reset_stream()
# We don't care about further status updates
self.program_sub_socket.close()
break
async def _streaming_loop(self):
"""
Main loop for processing audio stream.

View File

@@ -8,15 +8,15 @@ from fastapi.responses import StreamingResponse
from zmq.asyncio import Context, Socket
from control_backend.core.config import settings
from control_backend.schemas.ri_message import SpeechCommand
from control_backend.schemas.ri_message import GestureCommand, SpeechCommand
logger = logging.getLogger(__name__)
router = APIRouter()
@router.post("/command", status_code=202)
async def receive_command(command: SpeechCommand, request: Request):
@router.post("/command/speech", status_code=202)
async def receive_command_speech(command: SpeechCommand, request: Request):
"""
Send a direct speech command to the robot.
@@ -27,14 +27,32 @@ async def receive_command(command: SpeechCommand, request: Request):
:param command: The speech command payload.
:param request: The FastAPI request object.
"""
# Validate and retrieve data.
SpeechCommand.model_validate(command)
topic = b"command"
pub_socket: Socket = request.app.state.endpoints_pub_socket
await pub_socket.send_multipart([topic, command.model_dump_json().encode()])
return {"status": "Command received"}
return {"status": "Speech command received"}
@router.post("/command/gesture", status_code=202)
async def receive_command_gesture(command: GestureCommand, request: Request):
"""
Send a direct gesture command to the robot.
Publishes the command to the internal 'command' topic. The
:class:`~control_backend.agents.actuation.robot_speech_agent.RobotGestureAgent`
will forward this to the robot.
:param command: The speech command payload.
:param request: The FastAPI request object.
"""
topic = b"command"
pub_socket: Socket = request.app.state.endpoints_pub_socket
await pub_socket.send_multipart([topic, command.model_dump_json().encode()])
return {"status": "Gesture command received"}
@router.get("/ping_check")
@@ -45,6 +63,41 @@ async def ping(request: Request):
pass
@router.get("/commands/gesture/tags")
async def get_available_gesture_tags(request: Request, count=0):
"""
Endpoint to retrieve the available gesture tags for the robot.
:param request: The FastAPI request object.
:return: A list of available gesture tags.
"""
req_socket = Context.instance().socket(zmq.REQ)
req_socket.connect(settings.zmq_settings.internal_gesture_rep_adress)
# Check to see if we've got any count given in the query parameter
amount = count or None
timeout = 5 # seconds
await req_socket.send(f"{amount}".encode() if amount else b"None")
try:
body = await asyncio.wait_for(req_socket.recv(), timeout=timeout)
except TimeoutError:
body = '{"tags": []}'
logger.debug("Got timeout error fetching gestures.")
# Handle empty response and JSON decode errors
available_tags = []
if body:
try:
available_tags = json.loads(body).get("tags", [])
except json.JSONDecodeError as e:
logger.error(f"Failed to parse gesture tags JSON: {e}, body: {body}")
# Return empty list on JSON error
available_tags = []
return {"available_gesture_tags": available_tags}
@router.get("/ping_stream")
async def ping_stream(request: Request):
"""

View File

@@ -17,7 +17,7 @@ class ZMQSettings(BaseModel):
internal_sub_address: str = "tcp://localhost:5561"
ri_command_address: str = "tcp://localhost:0000"
ri_communication_address: str = "tcp://*:5555"
vad_agent_address: str = "tcp://localhost:5558"
internal_gesture_rep_adress: str = "tcp://localhost:7788"
class AgentSettings(BaseModel):
@@ -47,6 +47,7 @@ class AgentSettings(BaseModel):
transcription_name: str = "transcription_agent"
ri_communication_name: str = "ri_communication_agent"
robot_speech_name: str = "robot_speech_agent"
robot_gesture_name: str = "robot_gesture_agent"
class BehaviourSettings(BaseModel):

View File

@@ -4,6 +4,7 @@ import os
import yaml
import zmq
from zmq.log.handlers import PUBHandler
from control_backend.core.config import settings
@@ -51,15 +52,27 @@ def setup_logging(path: str = ".logging_config.yaml") -> None:
logging.warning(f"Could not load logging configuration: {e}")
config = {}
if "custom_levels" in config:
for level_name, level_num in config["custom_levels"].items():
add_logging_level(level_name, level_num)
custom_levels = config.get("custom_levels", {}) or {}
for level_name, level_num in custom_levels.items():
add_logging_level(level_name, level_num)
if config.get("handlers") is not None and config.get("handlers").get("ui"):
pub_socket = zmq.Context.instance().socket(zmq.PUB)
pub_socket.connect(settings.zmq_settings.internal_pub_address)
config["handlers"]["ui"]["interface_or_socket"] = pub_socket
logging.config.dictConfig(config)
# Patch ZMQ PUBHandler to know about custom levels
if custom_levels:
for logger_name in ("control_backend",):
logger = logging.getLogger(logger_name)
for handler in logger.handlers:
if isinstance(handler, PUBHandler):
# Use the INFO formatter as the default template
default_fmt = handler.formatters[logging.INFO]
for level_num in custom_levels.values():
handler.setFormatter(default_fmt, level=level_num)
else:
logging.warning("Logging config file not found. Using default logging configuration.")

View File

@@ -39,13 +39,11 @@ from control_backend.agents.communication import RICommunicationAgent
# LLM Agents
from control_backend.agents.llm import LLMAgent
# Perceive agents
from control_backend.agents.perception import VADAgent
# Other backend imports
from control_backend.api.v1.router import api_router
from control_backend.core.config import settings
from control_backend.logging import setup_logging
from control_backend.schemas.program_status import PROGRAM_STATUS, ProgramStatus
logger = logging.getLogger(__name__)
@@ -95,6 +93,8 @@ async def lifespan(app: FastAPI):
endpoints_pub_socket.connect(settings.zmq_settings.internal_pub_address)
app.state.endpoints_pub_socket = endpoints_pub_socket
await endpoints_pub_socket.send_multipart([PROGRAM_STATUS, ProgramStatus.STARTING.value])
# --- Initialize Agents ---
logger.info("Initializing and starting agents.")
@@ -132,10 +132,6 @@ async def lifespan(app: FastAPI):
"name": settings.agent_settings.text_belief_extractor_name,
},
),
"VADAgent": (
VADAgent,
{"audio_in_address": settings.zmq_settings.vad_agent_address, "audio_in_bind": False},
),
"ProgramManagerAgent": (
BDIProgramManager,
{
@@ -146,32 +142,28 @@ async def lifespan(app: FastAPI):
agents = []
vad_agent = None
for name, (agent_class, kwargs) in agents_to_start.items():
try:
logger.debug("Starting agent: %s", name)
agent_instance = agent_class(**kwargs)
await agent_instance.start()
if isinstance(agent_instance, VADAgent):
vad_agent = agent_instance
agents.append(agent_instance)
logger.info("Agent '%s' started successfully.", name)
except Exception as e:
logger.error("Failed to start agent '%s': %s", name, e, exc_info=True)
raise
assert vad_agent is not None
await vad_agent.reset_stream()
logger.info("Application startup complete.")
await endpoints_pub_socket.send_multipart([PROGRAM_STATUS, ProgramStatus.RUNNING.value])
yield
# --- APPLICATION SHUTDOWN ---
logger.info("%s is shutting down.", app.title)
# Potential shutdown logic goes here
await endpoints_pub_socket.send_multipart([PROGRAM_STATUS, ProgramStatus.STOPPING.value])
# Additional shutdown logic goes here
logger.info("Application shutdown complete.")

View File

@@ -0,0 +1,16 @@
from enum import Enum
PROGRAM_STATUS = b"internal/program_status"
"""A topic key for the program status."""
class ProgramStatus(Enum):
"""
Used in internal communication, to tell agents what the status of the program is.
For example, the VAD agent only starts listening when the program is RUNNING.
"""
STARTING = b"starting"
RUNNING = b"running"
STOPPING = b"stopping"

View File

@@ -1,7 +1,7 @@
from enum import Enum
from typing import Any
from typing import Any, Literal
from pydantic import BaseModel
from pydantic import BaseModel, model_validator
class RIEndpoint(str, Enum):
@@ -10,6 +10,8 @@ class RIEndpoint(str, Enum):
"""
SPEECH = "actuate/speech"
GESTURE_SINGLE = "actuate/gesture/single"
GESTURE_TAG = "actuate/gesture/tag"
PING = "ping"
NEGOTIATE_PORTS = "negotiate/ports"
@@ -36,3 +38,27 @@ class SpeechCommand(RIMessage):
endpoint: RIEndpoint = RIEndpoint(RIEndpoint.SPEECH)
data: str
class GestureCommand(RIMessage):
"""
A specific command to make the robot do a gesture.
:ivar endpoint: Should be ``RIEndpoint.GESTURE_SINGLE`` or ``RIEndpoint.GESTURE_TAG``.
:ivar data: The id of the gesture to be executed.
"""
endpoint: Literal[ # pyright: ignore[reportIncompatibleVariableOverride] - We validate this stricter rule ourselves
RIEndpoint.GESTURE_SINGLE, RIEndpoint.GESTURE_TAG
]
data: str
@model_validator(mode="after")
def check_endpoint(self):
allowed = {
RIEndpoint.GESTURE_SINGLE,
RIEndpoint.GESTURE_TAG,
}
if self.endpoint not in allowed:
raise ValueError("endpoint must be GESTURE_SINGLE or GESTURE_TAG")
return self

View File

@@ -5,6 +5,7 @@ import pytest
import zmq
from control_backend.agents.perception.vad_agent import VADAgent
from control_backend.schemas.program_status import PROGRAM_STATUS, ProgramStatus
@pytest.fixture
@@ -43,14 +44,12 @@ async def test_normal_setup(per_transcription_agent):
coro.close()
per_vad_agent.add_behavior = swallow_background_task
per_vad_agent.reset_stream = AsyncMock()
await per_vad_agent.setup()
per_transcription_agent.assert_called_once()
per_transcription_agent.return_value.start.assert_called_once()
per_vad_agent._streaming_loop.assert_called_once()
per_vad_agent.reset_stream.assert_called_once()
assert per_vad_agent.audio_in_socket is not None
assert per_vad_agent.audio_out_socket is not None
@@ -103,7 +102,7 @@ async def test_out_socket_creation_failure(zmq_context):
zmq_context.return_value.socket.return_value.bind_to_random_port.side_effect = zmq.ZMQBindError
per_vad_agent = VADAgent("tcp://localhost:12345", False)
per_vad_agent.stop = AsyncMock()
per_vad_agent.reset_stream = AsyncMock()
per_vad_agent._reset_stream = AsyncMock()
per_vad_agent._streaming_loop = AsyncMock()
per_vad_agent._connect_audio_out_socket = MagicMock(return_value=None)
@@ -124,7 +123,7 @@ async def test_stop(zmq_context, per_transcription_agent):
Test that when the VAD agent is stopped, the sockets are closed correctly.
"""
per_vad_agent = VADAgent("tcp://localhost:12345", False)
per_vad_agent.reset_stream = AsyncMock()
per_vad_agent._reset_stream = AsyncMock()
per_vad_agent._streaming_loop = AsyncMock()
async def swallow_background_task(coro):
@@ -142,3 +141,66 @@ async def test_stop(zmq_context, per_transcription_agent):
assert zmq_context.return_value.socket.return_value.close.call_count == 2
assert per_vad_agent.audio_in_socket is None
assert per_vad_agent.audio_out_socket is None
@pytest.mark.asyncio
async def test_application_startup_complete(zmq_context):
"""Check that it resets the stream when the program finishes startup."""
vad_agent = VADAgent("tcp://localhost:12345", False)
vad_agent._running = True
vad_agent._reset_stream = AsyncMock()
vad_agent.program_sub_socket = AsyncMock()
vad_agent.program_sub_socket.recv_multipart.side_effect = [
(PROGRAM_STATUS, ProgramStatus.RUNNING.value),
]
await vad_agent._status_loop()
vad_agent._reset_stream.assert_called_once()
vad_agent.program_sub_socket.close.assert_called_once()
@pytest.mark.asyncio
async def test_application_other_status(zmq_context):
"""
Check that it does nothing when the internal communication message is a status update, but not
running.
"""
vad_agent = VADAgent("tcp://localhost:12345", False)
vad_agent._running = True
vad_agent._reset_stream = AsyncMock()
vad_agent.program_sub_socket = AsyncMock()
vad_agent.program_sub_socket.recv_multipart.side_effect = [
(PROGRAM_STATUS, ProgramStatus.STARTING.value),
(PROGRAM_STATUS, ProgramStatus.STOPPING.value),
]
try:
# Raises StopAsyncIteration the third time it calls `program_sub_socket.recv_multipart`
await vad_agent._status_loop()
except StopAsyncIteration:
pass
vad_agent._reset_stream.assert_not_called()
@pytest.mark.asyncio
async def test_application_message_other(zmq_context):
"""
Check that it does nothing when there's an internal communication message that is not a status
update.
"""
vad_agent = VADAgent("tcp://localhost:12345", False)
vad_agent._running = True
vad_agent._reset_stream = AsyncMock()
vad_agent.program_sub_socket = AsyncMock()
vad_agent.program_sub_socket.recv_multipart.side_effect = [(b"internal/other", b"Whatever")]
try:
# Raises StopAsyncIteration the second time it calls `program_sub_socket.recv_multipart`
await vad_agent._status_loop()
except StopAsyncIteration:
pass
vad_agent._reset_stream.assert_not_called()

View File

@@ -0,0 +1,444 @@
import json
from unittest.mock import AsyncMock, MagicMock
import pytest
import zmq
from control_backend.agents.actuation.robot_gesture_agent import RobotGestureAgent
from control_backend.core.agent_system import InternalMessage
from control_backend.schemas.ri_message import RIEndpoint
@pytest.fixture
def zmq_context(mocker):
"""Mock the ZMQ context."""
mock_context = mocker.patch(
"control_backend.agents.actuation.robot_gesture_agent.azmq.Context.instance"
)
mock_context.return_value = MagicMock()
return mock_context
@pytest.mark.asyncio
async def test_setup_bind(zmq_context, mocker):
"""Setup binds and subscribes to internal commands."""
fake_socket = zmq_context.return_value.socket.return_value
agent = RobotGestureAgent("robot_gesture", address="tcp://localhost:5556", bind=True)
settings = mocker.patch("control_backend.agents.actuation.robot_gesture_agent.settings")
settings.zmq_settings.internal_sub_address = "tcp://internal:1234"
agent.add_behavior = MagicMock()
await agent.setup()
# Check PUB socket binding
fake_socket.bind.assert_any_call("tcp://localhost:5556")
# Check REP socket binding
fake_socket.bind.assert_called()
# Check SUB socket connection and subscriptions
fake_socket.connect.assert_any_call("tcp://internal:1234")
fake_socket.setsockopt.assert_any_call(zmq.SUBSCRIBE, b"command")
fake_socket.setsockopt.assert_any_call(zmq.SUBSCRIBE, b"send_gestures")
# Check behavior was added (twice: once for command loop, once for fetch gestures loop)
assert agent.add_behavior.call_count == 2
@pytest.mark.asyncio
async def test_setup_connect(zmq_context, mocker):
"""Setup connects when bind=False."""
fake_socket = zmq_context.return_value.socket.return_value
agent = RobotGestureAgent("robot_gesture", address="tcp://localhost:5556", bind=False)
settings = mocker.patch("control_backend.agents.actuation.robot_gesture_agent.settings")
settings.zmq_settings.internal_sub_address = "tcp://internal:1234"
agent.add_behavior = MagicMock()
await agent.setup()
# Check PUB socket connection (not binding)
fake_socket.connect.assert_any_call("tcp://localhost:5556")
fake_socket.connect.assert_any_call("tcp://internal:1234")
# Check REP socket binding (always binds)
fake_socket.bind.assert_called()
# Check behavior was added (twice)
assert agent.add_behavior.call_count == 2
@pytest.mark.asyncio
async def test_handle_message_sends_valid_gesture_command():
"""Internal message with valid gesture tag is forwarded to robot pub socket."""
pubsocket = AsyncMock()
agent = RobotGestureAgent("robot_gesture", gesture_data=["hello", "yes", "no"])
agent.pubsocket = pubsocket
payload = {
"endpoint": RIEndpoint.GESTURE_TAG,
"data": "hello", # "hello" is in gesture_data
}
msg = InternalMessage(to="robot", sender="tester", body=json.dumps(payload))
await agent.handle_message(msg)
pubsocket.send_json.assert_awaited_once()
@pytest.mark.asyncio
async def test_handle_message_sends_non_gesture_command():
"""Internal message with non-gesture endpoint is not forwarded by this agent."""
pubsocket = AsyncMock()
agent = RobotGestureAgent("robot_gesture", gesture_data=["hello", "yes", "no"])
agent.pubsocket = pubsocket
payload = {"endpoint": "some_other_endpoint", "data": "invalid_tag_not_in_list"}
msg = InternalMessage(to="robot", sender="tester", body=json.dumps(payload))
await agent.handle_message(msg)
# Non-gesture endpoints should not be forwarded by this agent
pubsocket.send_json.assert_not_awaited()
@pytest.mark.asyncio
async def test_handle_message_rejects_invalid_gesture_tag():
"""Internal message with invalid gesture tag is not forwarded."""
pubsocket = AsyncMock()
agent = RobotGestureAgent("robot_gesture", gesture_data=["hello", "yes", "no"])
agent.pubsocket = pubsocket
# Use a tag that's not in gesture_data
payload = {"endpoint": RIEndpoint.GESTURE_TAG, "data": "invalid_tag_not_in_list"}
msg = InternalMessage(to="robot", sender="tester", body=json.dumps(payload))
await agent.handle_message(msg)
pubsocket.send_json.assert_not_awaited()
@pytest.mark.asyncio
async def test_handle_message_invalid_payload():
"""Invalid payload is caught and does not send."""
pubsocket = AsyncMock()
agent = RobotGestureAgent("robot_gesture", gesture_data=["hello", "yes", "no"])
agent.pubsocket = pubsocket
msg = InternalMessage(to="robot", sender="tester", body=json.dumps({"bad": "data"}))
await agent.handle_message(msg)
pubsocket.send_json.assert_not_awaited()
@pytest.mark.asyncio
async def test_zmq_command_loop_valid_gesture_payload():
"""UI command with valid gesture tag is read from SUB and published."""
command = {"endpoint": RIEndpoint.GESTURE_TAG, "data": "hello"}
fake_socket = AsyncMock()
async def recv_once():
# stop after first iteration
agent._running = False
return (b"command", json.dumps(command).encode("utf-8"))
fake_socket.recv_multipart = recv_once
fake_socket.send_json = AsyncMock()
agent = RobotGestureAgent("robot_gesture", gesture_data=["hello", "yes", "no"])
agent.subsocket = fake_socket
agent.pubsocket = fake_socket
agent._running = True
await agent._zmq_command_loop()
fake_socket.send_json.assert_awaited_once()
@pytest.mark.asyncio
async def test_zmq_command_loop_valid_non_gesture_payload():
"""UI command with non-gesture endpoint is not forwarded by this agent."""
command = {"endpoint": "some_other_endpoint", "data": "anything"}
fake_socket = AsyncMock()
async def recv_once():
agent._running = False
return (b"command", json.dumps(command).encode("utf-8"))
fake_socket.recv_multipart = recv_once
fake_socket.send_json = AsyncMock()
agent = RobotGestureAgent("robot_gesture", gesture_data=["hello", "yes", "no"])
agent.subsocket = fake_socket
agent.pubsocket = fake_socket
agent._running = True
await agent._zmq_command_loop()
fake_socket.send_json.assert_not_awaited()
@pytest.mark.asyncio
async def test_zmq_command_loop_invalid_gesture_tag():
"""UI command with invalid gesture tag is not forwarded."""
command = {"endpoint": RIEndpoint.GESTURE_TAG, "data": "invalid_tag_not_in_list"}
fake_socket = AsyncMock()
async def recv_once():
agent._running = False
return (b"command", json.dumps(command).encode("utf-8"))
fake_socket.recv_multipart = recv_once
fake_socket.send_json = AsyncMock()
agent = RobotGestureAgent("robot_gesture", gesture_data=["hello", "yes", "no"])
agent.subsocket = fake_socket
agent.pubsocket = fake_socket
agent._running = True
await agent._zmq_command_loop()
fake_socket.send_json.assert_not_awaited()
@pytest.mark.asyncio
async def test_zmq_command_loop_invalid_json():
"""Invalid JSON is ignored without sending."""
fake_socket = AsyncMock()
async def recv_once():
agent._running = False
return (b"command", b"{not_json}")
fake_socket.recv_multipart = recv_once
fake_socket.send_json = AsyncMock()
agent = RobotGestureAgent("robot_gesture", gesture_data=["hello", "yes", "no"])
agent.subsocket = fake_socket
agent.pubsocket = fake_socket
agent._running = True
await agent._zmq_command_loop()
fake_socket.send_json.assert_not_awaited()
@pytest.mark.asyncio
async def test_zmq_command_loop_ignores_send_gestures_topic():
"""send_gestures topic is ignored in command loop."""
fake_socket = AsyncMock()
async def recv_once():
agent._running = False
return (b"send_gestures", b"{}")
fake_socket.recv_multipart = recv_once
fake_socket.send_json = AsyncMock()
agent = RobotGestureAgent("robot_gesture", gesture_data=["hello", "yes", "no"])
agent.subsocket = fake_socket
agent.pubsocket = fake_socket
agent._running = True
await agent._zmq_command_loop()
fake_socket.send_json.assert_not_awaited()
@pytest.mark.asyncio
async def test_fetch_gestures_loop_without_amount():
"""Fetch gestures request without amount returns all tags."""
fake_repsocket = AsyncMock()
async def recv_once():
agent._running = False
return b"{}" # Empty JSON request
fake_repsocket.recv = recv_once
fake_repsocket.send = AsyncMock()
agent = RobotGestureAgent("robot_gesture", gesture_data=["hello", "yes", "no", "wave", "point"])
agent.repsocket = fake_repsocket
agent._running = True
await agent._fetch_gestures_loop()
fake_repsocket.send.assert_awaited_once()
# Check the response contains all tags
args, kwargs = fake_repsocket.send.call_args
response = json.loads(args[0])
assert "tags" in response
assert response["tags"] == ["hello", "yes", "no", "wave", "point"]
@pytest.mark.asyncio
async def test_fetch_gestures_loop_with_amount():
"""Fetch gestures request with amount returns limited tags."""
fake_repsocket = AsyncMock()
amount = 3
async def recv_once():
agent._running = False
return json.dumps(amount).encode()
fake_repsocket.recv = recv_once
fake_repsocket.send = AsyncMock()
agent = RobotGestureAgent("robot_gesture", gesture_data=["hello", "yes", "no", "wave", "point"])
agent.repsocket = fake_repsocket
agent._running = True
await agent._fetch_gestures_loop()
fake_repsocket.send.assert_awaited_once()
args, kwargs = fake_repsocket.send.call_args
response = json.loads(args[0])
assert "tags" in response
assert len(response["tags"]) == amount
assert response["tags"] == ["hello", "yes", "no"]
@pytest.mark.asyncio
async def test_fetch_gestures_loop_with_integer_request():
"""Fetch gestures request with integer amount."""
fake_repsocket = AsyncMock()
amount = 2
async def recv_once():
agent._running = False
return json.dumps(amount).encode()
fake_repsocket.recv = recv_once
fake_repsocket.send = AsyncMock()
agent = RobotGestureAgent("robot_gesture", gesture_data=["hello", "yes", "no"])
agent.repsocket = fake_repsocket
agent._running = True
await agent._fetch_gestures_loop()
fake_repsocket.send.assert_awaited_once()
args, kwargs = fake_repsocket.send.call_args
response = json.loads(args[0])
assert response["tags"] == ["hello", "yes"]
@pytest.mark.asyncio
async def test_fetch_gestures_loop_with_invalid_json():
"""Invalid JSON request returns all tags."""
fake_repsocket = AsyncMock()
async def recv_once():
agent._running = False
return b"not_json"
fake_repsocket.recv = recv_once
fake_repsocket.send = AsyncMock()
agent = RobotGestureAgent("robot_gesture", gesture_data=["hello", "yes", "no"])
agent.repsocket = fake_repsocket
agent._running = True
await agent._fetch_gestures_loop()
fake_repsocket.send.assert_awaited_once()
args, kwargs = fake_repsocket.send.call_args
response = json.loads(args[0])
assert response["tags"] == ["hello", "yes", "no"]
@pytest.mark.asyncio
async def test_fetch_gestures_loop_with_non_integer_json():
"""Non-integer JSON request returns all tags."""
fake_repsocket = AsyncMock()
async def recv_once():
agent._running = False
return json.dumps({"not": "an_integer"}).encode()
fake_repsocket.recv = recv_once
fake_repsocket.send = AsyncMock()
agent = RobotGestureAgent("robot_gesture", gesture_data=["hello", "yes", "no"])
agent.repsocket = fake_repsocket
agent._running = True
await agent._fetch_gestures_loop()
fake_repsocket.send.assert_awaited_once()
args, kwargs = fake_repsocket.send.call_args
response = json.loads(args[0])
assert response["tags"] == ["hello", "yes", "no"]
def test_gesture_data_attribute():
"""Test that gesture_data returns the expected list."""
gesture_data = ["hello", "yes", "no", "wave"]
agent = RobotGestureAgent("robot_gesture", gesture_data=gesture_data)
assert agent.gesture_data == gesture_data
assert isinstance(agent.gesture_data, list)
assert len(agent.gesture_data) == 4
assert "hello" in agent.gesture_data
assert "yes" in agent.gesture_data
assert "no" in agent.gesture_data
assert "invalid_tag_not_in_list" not in agent.gesture_data
@pytest.mark.asyncio
async def test_stop_closes_sockets():
"""Stop method closes all sockets."""
pubsocket = MagicMock()
subsocket = MagicMock()
repsocket = MagicMock()
agent = RobotGestureAgent("robot_gesture")
agent.pubsocket = pubsocket
agent.subsocket = subsocket
agent.repsocket = repsocket
await agent.stop()
pubsocket.close.assert_called_once()
subsocket.close.assert_called_once()
# Note: repsocket is not closed in stop() method, but you might want to add it
# repsocket.close.assert_called_once()
@pytest.mark.asyncio
async def test_initialization_with_custom_gesture_data():
"""Agent can be initialized with custom gesture data."""
custom_gestures = ["custom1", "custom2", "custom3"]
agent = RobotGestureAgent("robot_gesture", gesture_data=custom_gestures)
assert agent.gesture_data == custom_gestures
@pytest.mark.asyncio
async def test_fetch_gestures_loop_handles_exception():
"""Exception in fetch gestures loop is caught and logged."""
fake_repsocket = AsyncMock()
async def recv_once():
agent._running = False
raise Exception("Test exception")
fake_repsocket.recv = recv_once
fake_repsocket.send = AsyncMock()
agent = RobotGestureAgent("robot_gesture", gesture_data=["hello", "yes", "no"])
agent.repsocket = fake_repsocket
agent.logger = MagicMock()
agent._running = True
# Should not raise exception
await agent._fetch_gestures_loop()
# Exception should be logged
agent.logger.exception.assert_called_once()

View File

@@ -8,6 +8,11 @@ from control_backend.agents.actuation.robot_speech_agent import RobotSpeechAgent
from control_backend.core.agent_system import InternalMessage
def mock_speech_agent():
agent = RobotSpeechAgent("robot_speech", address="tcp://localhost:5555", bind=False)
return agent
@pytest.fixture
def zmq_context(mocker):
mock_context = mocker.patch(
@@ -56,7 +61,7 @@ async def test_setup_connect(zmq_context, mocker):
async def test_handle_message_sends_command():
"""Internal message is forwarded to robot pub socket as JSON."""
pubsocket = AsyncMock()
agent = RobotSpeechAgent("robot_speech")
agent = mock_speech_agent()
agent.pubsocket = pubsocket
payload = {"endpoint": "actuate/speech", "data": "hello"}
@@ -80,7 +85,7 @@ async def test_zmq_command_loop_valid_payload(zmq_context):
fake_socket.recv_multipart = recv_once
fake_socket.send_json = AsyncMock()
agent = RobotSpeechAgent("robot_speech")
agent = mock_speech_agent()
agent.subsocket = fake_socket
agent.pubsocket = fake_socket
agent._running = True
@@ -101,7 +106,7 @@ async def test_zmq_command_loop_invalid_json():
fake_socket.recv_multipart = recv_once
fake_socket.send_json = AsyncMock()
agent = RobotSpeechAgent("robot_speech")
agent = mock_speech_agent()
agent.subsocket = fake_socket
agent.pubsocket = fake_socket
agent._running = True
@@ -115,7 +120,7 @@ async def test_zmq_command_loop_invalid_json():
async def test_handle_message_invalid_payload():
"""Invalid payload is caught and does not send."""
pubsocket = AsyncMock()
agent = RobotSpeechAgent("robot_speech")
agent = mock_speech_agent()
agent.pubsocket = pubsocket
msg = InternalMessage(to="robot", sender="tester", body=json.dumps({"bad": "data"}))
@@ -129,7 +134,7 @@ async def test_handle_message_invalid_payload():
async def test_stop_closes_sockets():
pubsocket = MagicMock()
subsocket = MagicMock()
agent = RobotSpeechAgent("robot_speech")
agent = mock_speech_agent()
agent.pubsocket = pubsocket
agent.subsocket = subsocket

View File

@@ -1,4 +1,6 @@
import asyncio
import json
import time
from unittest.mock import AsyncMock, MagicMock, mock_open, patch
import agentspeak
@@ -77,11 +79,6 @@ async def test_incorrect_belief_collector_message(agent, mock_settings):
agent.bdi_agent.call.assert_not_called() # did not set belief
@pytest.mark.asyncio
async def test():
pass
@pytest.mark.asyncio
async def test_handle_llm_response(agent):
"""Test that LLM responses are forwarded to the Robot Speech Agent"""
@@ -124,3 +121,148 @@ async def test_custom_actions(agent):
next(gen) # Execute
agent._send_to_llm.assert_called_with("Hello", "Norm", "Goal")
def test_add_belief_sets_event(agent):
"""Test that a belief triggers wake event and call()"""
agent._wake_bdi_loop = MagicMock()
belief = Belief(name="test_belief", arguments=["a", "b"])
agent._apply_beliefs([belief])
assert agent.bdi_agent.call.called
agent._wake_bdi_loop.set.assert_called()
def test_apply_beliefs_empty_returns(agent):
"""Line: if not beliefs: return"""
agent._wake_bdi_loop = MagicMock()
agent._apply_beliefs([])
agent.bdi_agent.call.assert_not_called()
agent._wake_bdi_loop.set.assert_not_called()
def test_remove_belief_success_wakes_loop(agent):
"""Line: if result: wake set"""
agent._wake_bdi_loop = MagicMock()
agent.bdi_agent.call.return_value = True
agent._remove_belief("remove_me", ["x"])
assert agent.bdi_agent.call.called
trigger, goaltype, literal, *_ = agent.bdi_agent.call.call_args.args
assert trigger == agentspeak.Trigger.removal
assert goaltype == agentspeak.GoalType.belief
assert literal.functor == "remove_me"
assert literal.args[0].functor == "x"
agent._wake_bdi_loop.set.assert_called()
def test_remove_belief_failure_does_not_wake(agent):
"""Line: else result is False"""
agent._wake_bdi_loop = MagicMock()
agent.bdi_agent.call.return_value = False
agent._remove_belief("not_there", ["y"])
assert agent.bdi_agent.call.called # removal was attempted
agent._wake_bdi_loop.set.assert_not_called()
def test_remove_all_with_name_wakes_loop(agent):
"""Cover _remove_all_with_name() removed counter + wake"""
agent._wake_bdi_loop = MagicMock()
fake_literal = agentspeak.Literal("delete_me", (agentspeak.Literal("arg1"),))
fake_key = ("delete_me", 1)
agent.bdi_agent.beliefs = {fake_key: {fake_literal}}
agent._remove_all_with_name("delete_me")
assert agent.bdi_agent.call.called
agent._wake_bdi_loop.set.assert_called()
@pytest.mark.asyncio
async def test_bdi_step_true_branch_hits_line_67(agent):
"""Force step() to return True once so line 67 is actually executed"""
# counter that isn't tied to MagicMock.call_count ordering
counter = {"i": 0}
def fake_step():
counter["i"] += 1
return counter["i"] == 1 # True only first time
# Important: wrap fake_step into another mock so `.called` still exists
agent.bdi_agent.step = MagicMock(side_effect=fake_step)
agent.bdi_agent.shortest_deadline = MagicMock(return_value=None)
agent._running = True
agent._wake_bdi_loop = asyncio.Event()
agent._wake_bdi_loop.set()
task = asyncio.create_task(agent._bdi_loop())
await asyncio.sleep(0.01)
task.cancel()
try:
await task
except asyncio.CancelledError:
pass
assert agent.bdi_agent.step.called
assert counter["i"] >= 1 # proves True branch ran
def test_replace_belief_calls_remove_all(agent):
"""Cover: if belief.replace: self._remove_all_with_name()"""
agent._remove_all_with_name = MagicMock()
agent._wake_bdi_loop = MagicMock()
belief = Belief(name="user_said", arguments=["Hello"], replace=True)
agent._apply_beliefs([belief])
agent._remove_all_with_name.assert_called_with("user_said")
@pytest.mark.asyncio
async def test_send_to_llm_creates_prompt_and_sends(agent):
"""Cover entire _send_to_llm() including message send and logger.info"""
agent.bdi_agent = MagicMock() # ensure mocked BDI does not interfere
agent._wake_bdi_loop = MagicMock()
await agent._send_to_llm("hello world", "n1\nn2", "g1")
# send() was called
assert agent.send.called
sent_msg: InternalMessage = agent.send.call_args.args[0]
# Message routing values correct
assert sent_msg.to == settings.agent_settings.llm_name
assert "hello world" in sent_msg.body
# JSON contains split norms/goals
body = json.loads(sent_msg.body)
assert body["norms"] == ["n1", "n2"]
assert body["goals"] == ["g1"]
@pytest.mark.asyncio
async def test_deadline_sleep_branch(agent):
"""Specifically assert the if deadline: sleep → maybe_more_work=True branch"""
future_deadline = time.time() + 0.005
agent.bdi_agent.step.return_value = False
agent.bdi_agent.shortest_deadline.return_value = future_deadline
start_time = time.time()
agent._running = True
agent._wake_bdi_loop = asyncio.Event()
agent._wake_bdi_loop.set()
task = asyncio.create_task(agent._bdi_loop())
await asyncio.sleep(0.01)
task.cancel()
duration = time.time() - start_time
assert duration >= 0.004 # loop slept until deadline

View File

@@ -0,0 +1,77 @@
import asyncio
import json
import sys
from unittest.mock import AsyncMock
import pytest
from control_backend.agents.bdi.bdi_program_manager import BDIProgramManager
from control_backend.core.agent_system import InternalMessage
from control_backend.schemas.belief_message import BeliefMessage
from control_backend.schemas.program import Program
# Fix Windows Proactor loop for zmq
if sys.platform.startswith("win"):
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
def make_valid_program_json(norm="N1", goal="G1"):
return json.dumps(
{
"phases": [
{
"id": "phase1",
"label": "Phase 1",
"triggers": [],
"norms": [{"id": "n1", "label": "Norm 1", "norm": norm}],
"goals": [
{"id": "g1", "label": "Goal 1", "description": goal, "achieved": False}
],
}
]
}
)
@pytest.mark.asyncio
async def test_send_to_bdi():
manager = BDIProgramManager(name="program_manager_test")
manager.send = AsyncMock()
program = Program.model_validate_json(make_valid_program_json())
await manager._send_to_bdi(program)
assert manager.send.await_count == 1
msg: InternalMessage = manager.send.await_args[0][0]
assert msg.thread == "beliefs"
beliefs = BeliefMessage.model_validate_json(msg.body)
names = {b.name: b.arguments for b in beliefs.beliefs}
assert "norms" in names and names["norms"] == ["N1"]
assert "goals" in names and names["goals"] == ["G1"]
@pytest.mark.asyncio
async def test_receive_programs_valid_and_invalid():
sub = AsyncMock()
sub.recv_multipart.side_effect = [
(b"program", b"{bad json"),
(b"program", make_valid_program_json().encode()),
]
manager = BDIProgramManager(name="program_manager_test")
manager.sub_socket = sub
manager._send_to_bdi = AsyncMock()
try:
# Will give StopAsyncIteration when the predefined `sub.recv_multipart` side-effects run out
await manager._receive_programs()
except StopAsyncIteration:
pass
# Only valid Program should have triggered _send_to_bdi
assert manager._send_to_bdi.await_count == 1
forwarded: Program = manager._send_to_bdi.await_args[0][0]
assert forwarded.phases[0].norms[0].norm == "N1"
assert forwarded.phases[0].goals[0].description == "G1"

View File

@@ -87,3 +87,49 @@ async def test_send_beliefs_to_bdi(agent):
assert sent.to == settings.agent_settings.bdi_core_name
assert sent.thread == "beliefs"
assert json.loads(sent.body)["beliefs"] == [belief.model_dump() for belief in beliefs]
@pytest.mark.asyncio
async def test_setup_executes(agent):
"""Covers setup and asserts the agent has a name."""
await agent.setup()
assert agent.name == "belief_collector_agent" # simple property assertion
@pytest.mark.asyncio
async def test_handle_message_unrecognized_type_executes(agent):
"""Covers the else branch for unrecognized message type."""
payload = {"type": "unknown_type"}
msg = make_msg(payload, sender="tester")
# Wrap send to ensure nothing is sent
agent.send = AsyncMock()
await agent.handle_message(msg)
# Assert no messages were sent
agent.send.assert_not_awaited()
@pytest.mark.asyncio
async def test_handle_emo_text_executes(agent):
"""Covers the _handle_emo_text method."""
# The method does nothing, but we can assert it returns None
result = await agent._handle_emo_text({}, "origin")
assert result is None
@pytest.mark.asyncio
async def test_send_beliefs_to_bdi_empty_executes(agent):
"""Covers early return when beliefs are empty."""
agent.send = AsyncMock()
await agent._send_beliefs_to_bdi({})
# Assert that nothing was sent
agent.send.assert_not_awaited()
@pytest.mark.asyncio
async def test_handle_belief_text_invalid_returns_none(agent, mocker):
payload = {"type": "belief_extraction_text", "beliefs": {"user_said": "invalid-argument"}}
result = await agent._handle_belief_text(payload, "origin")
# The method itself returns None
assert result is None

View File

@@ -56,3 +56,10 @@ async def test_process_transcription_demo(agent, mock_settings):
assert sent.thread == "beliefs"
parsed = json.loads(sent.body)
assert parsed["beliefs"]["user_said"] == [transcription]
@pytest.mark.asyncio
async def test_setup_initializes_beliefs(agent):
"""Covers the setup method and ensures beliefs are initialized."""
await agent.setup()
assert agent.beliefs == {"mood": ["X"], "car": ["Y"]}

View File

@@ -10,6 +10,10 @@ def speech_agent_path():
return "control_backend.agents.communication.ri_communication_agent.RobotSpeechAgent"
def gesture_agent_path():
return "control_backend.agents.communication.ri_communication_agent.RobotGestureAgent"
@pytest.fixture
def zmq_context(mocker):
mock_context = mocker.patch(
@@ -22,7 +26,7 @@ def zmq_context(mocker):
def negotiation_message(
actuation_port: int = 5556,
bind_main: bool = False,
bind_actuation: bool = True,
bind_actuation: bool = False,
main_port: int = 5555,
):
return {
@@ -41,9 +45,12 @@ async def test_setup_success_connects_and_starts_robot(zmq_context):
fake_socket.recv_json = AsyncMock(return_value=negotiation_message())
fake_socket.send_multipart = AsyncMock()
with patch(speech_agent_path(), autospec=True) as MockRobot:
robot_instance = MockRobot.return_value
robot_instance.start = AsyncMock()
with (
patch(speech_agent_path(), autospec=True) as MockSpeech,
patch(gesture_agent_path(), autospec=True) as MockGesture,
):
MockSpeech.return_value.start = AsyncMock()
MockGesture.return_value.start = AsyncMock()
agent = RICommunicationAgent("ri_comm", address="tcp://localhost:5555", bind=False)
agent.add_behavior = MagicMock()
@@ -52,9 +59,17 @@ async def test_setup_success_connects_and_starts_robot(zmq_context):
fake_socket.connect.assert_any_call("tcp://localhost:5555")
fake_socket.send_json.assert_any_call({"endpoint": "negotiate/ports", "data": {}})
robot_instance.start.assert_awaited_once()
MockRobot.assert_called_once_with(ANY, address="tcp://*:5556", bind=True)
MockSpeech.return_value.start.assert_awaited_once()
MockGesture.return_value.start.assert_awaited_once()
MockSpeech.assert_called_once_with(ANY, address="tcp://localhost:5556", bind=False)
MockGesture.assert_called_once_with(
ANY,
address="tcp://localhost:5556",
bind=False,
gesture_data=[],
)
agent.add_behavior.assert_called_once()
assert agent.connected is True
@@ -69,10 +84,13 @@ async def test_setup_binds_when_requested(zmq_context):
agent.add_behavior = MagicMock()
with patch(speech_agent_path(), autospec=True) as MockRobot:
MockRobot.return_value.start = AsyncMock()
with (
patch(speech_agent_path(), autospec=True) as MockSpeech,
patch(gesture_agent_path(), autospec=True) as MockGesture,
):
MockSpeech.return_value.start = AsyncMock()
MockGesture.return_value.start = AsyncMock()
await agent.setup()
fake_socket.bind.assert_any_call("tcp://localhost:5555")
agent.add_behavior.assert_called_once()
@@ -88,7 +106,6 @@ async def test_negotiate_invalid_endpoint_retries(zmq_context):
agent._req_socket = fake_socket
success = await agent._negotiate_connection(max_retries=1)
assert success is False
@@ -112,8 +129,12 @@ async def test_handle_negotiation_response_updates_req_socket(zmq_context):
fake_socket = zmq_context.return_value.socket.return_value
agent = RICommunicationAgent("ri_comm", address="tcp://localhost:5555", bind=False)
agent._req_socket = fake_socket
with patch(speech_agent_path(), autospec=True) as MockRobot:
MockRobot.return_value.start = AsyncMock()
with (
patch(speech_agent_path(), autospec=True) as MockSpeech,
patch(gesture_agent_path(), autospec=True) as MockGesture,
):
MockSpeech.return_value.start = AsyncMock()
MockGesture.return_value.start = AsyncMock()
await agent._handle_negotiation_response(
negotiation_message(
main_port=6000,
@@ -135,7 +156,6 @@ async def test_handle_disconnection_publishes_and_reconnects():
agent._negotiate_connection = AsyncMock(return_value=True)
await agent._handle_disconnection()
pub_socket.send_multipart.assert_awaited()
assert agent.connected is True
@@ -192,7 +212,7 @@ async def test_setup_warns_on_failed_negotiate(zmq_context, mocker):
fake_socket.recv_json = AsyncMock()
agent = RICommunicationAgent("ri_comm")
async def swallow(coro):
def swallow(coro):
coro.close()
agent.add_behavior = swallow
@@ -334,3 +354,13 @@ async def test_listen_loop_ping_sends_internal(zmq_context):
await agent._listen_loop()
pub_socket.send_multipart.assert_awaited()
@pytest.mark.asyncio
async def test_negotiate_req_socket_none_causes_retry(zmq_context):
agent = RICommunicationAgent("ri_comm")
agent._req_socket = None
result = await agent._negotiate_connection(max_retries=1)
assert result is False

View File

@@ -49,6 +49,9 @@ async def test_llm_processing_success(mock_httpx_client, mock_settings):
agent = LLMAgent("llm_agent")
agent.send = AsyncMock() # Mock the send method to verify replies
mock_logger = MagicMock()
agent.logger = mock_logger
# Simulate receiving a message from BDI
prompt = LLMPromptMessage(text="Hi", norms=[], goals=[])
msg = InternalMessage(
@@ -134,3 +137,128 @@ def test_llm_instructions():
text_def = instr_def.build_developer_instruction()
assert "Norms to follow" in text_def
assert "Goals to reach" in text_def
@pytest.mark.asyncio
async def test_handle_message_validation_error_branch_no_send(mock_httpx_client, mock_settings):
"""
Covers the ValidationError branch:
except ValidationError:
self.logger.debug("Prompt message from BDI core is invalid.")
Assert: no message is sent.
"""
agent = LLMAgent("llm_agent")
agent.send = AsyncMock()
# Invalid JSON that triggers ValidationError in LLMPromptMessage
invalid_json = '{"text": "Hi", "wrong_field": 123}' # field not in schema
msg = InternalMessage(
to="llm_agent",
sender=mock_settings.agent_settings.bdi_core_name,
body=invalid_json,
)
await agent.handle_message(msg)
# Should not send any reply
agent.send.assert_not_called()
@pytest.mark.asyncio
async def test_handle_message_ignored_sender_branch_no_send(mock_httpx_client, mock_settings):
"""
Covers the else branch for messages not from BDI core:
else:
self.logger.debug("Message ignored (not from BDI core.")
Assert: no message is sent.
"""
agent = LLMAgent("llm_agent")
agent.send = AsyncMock()
msg = InternalMessage(
to="llm_agent",
sender="some_other_agent", # Not BDI core
body='{"text": "Hi"}',
)
await agent.handle_message(msg)
# Should not send any reply
agent.send.assert_not_called()
@pytest.mark.asyncio
async def test_query_llm_yields_final_tail_chunk(mock_settings):
"""
Covers the branch: if current_chunk: yield current_chunk
Ensure that the last partial chunk is emitted.
"""
agent = LLMAgent("llm_agent")
agent.send = AsyncMock()
# Patch _stream_query_llm to yield tokens that do NOT end with punctuation
async def fake_stream(messages):
yield "Hello"
yield " world" # No punctuation to trigger the normal chunking
agent._stream_query_llm = fake_stream
prompt = LLMPromptMessage(text="Hi", norms=[], goals=[])
# Collect chunks yielded
chunks = []
async for chunk in agent._query_llm(prompt.text, prompt.norms, prompt.goals):
chunks.append(chunk)
# The final chunk should be yielded
assert chunks[-1] == "Hello world"
assert any("Hello" in c for c in chunks)
@pytest.mark.asyncio
async def test_stream_query_llm_skips_non_data_lines(mock_httpx_client, mock_settings):
"""
Covers: if not line or not line.startswith("data: "): continue
Feed lines that are empty or do not start with 'data:' and check they are skipped.
"""
# Mock response
mock_response = MagicMock()
mock_response.raise_for_status = MagicMock()
lines = [
"", # empty line
"not data", # invalid prefix
'data: {"choices": [{"delta": {"content": "Hi"}}]}',
"data: [DONE]",
]
async def aiter_lines_gen():
for line in lines:
yield line
mock_response.aiter_lines.side_effect = aiter_lines_gen
# Proper async context manager for stream
mock_stream_context = MagicMock()
mock_stream_context.__aenter__ = AsyncMock(return_value=mock_response)
mock_stream_context.__aexit__ = AsyncMock(return_value=None)
# Make stream return the async context manager
mock_httpx_client.stream = MagicMock(return_value=mock_stream_context)
agent = LLMAgent("llm_agent")
agent.send = AsyncMock()
# Patch settings for local LLM URL
with patch("control_backend.agents.llm.llm_agent.settings") as mock_sett:
mock_sett.llm_settings.local_llm_url = "http://localhost"
mock_sett.llm_settings.local_llm_model = "test-model"
# Collect tokens
tokens = []
async for token in agent._stream_query_llm([]):
tokens.append(token)
# Only the valid 'data:' line should yield content
assert tokens == ["Hi"]

View File

@@ -120,3 +120,83 @@ def test_mlx_recognizer():
mlx_mock.transcribe.return_value = {"text": "Hi"}
res = rec.recognize_speech(np.zeros(10))
assert res == "Hi"
@pytest.mark.asyncio
async def test_transcription_loop_continues_after_error(mock_zmq_context):
mock_sub = MagicMock()
mock_sub.recv = AsyncMock()
mock_zmq_context.instance.return_value.socket.return_value = mock_sub
fake_audio = np.zeros(16000, dtype=np.float32).tobytes()
mock_sub.recv.side_effect = [
fake_audio, # first iteration → recognizer fails
asyncio.CancelledError(), # second iteration → stop loop
]
with patch.object(SpeechRecognizer, "best_type") as mock_best:
mock_recognizer = MagicMock()
mock_recognizer.recognize_speech.side_effect = RuntimeError("fail")
mock_best.return_value = mock_recognizer
agent = TranscriptionAgent("tcp://in")
agent._running = True # ← REQUIRED to enter the loop
agent.send = AsyncMock() # should never be called
agent.add_behavior = AsyncMock() # match other tests
await agent.setup()
try:
await agent._transcribing_loop()
except asyncio.CancelledError:
pass
# recognizer failed, so we should never send anything
agent.send.assert_not_called()
# recv must have been called twice (audio then CancelledError)
assert mock_sub.recv.call_count == 2
@pytest.mark.asyncio
async def test_transcription_continue_branch_when_empty(mock_zmq_context):
mock_sub = MagicMock()
mock_sub.recv = AsyncMock()
mock_zmq_context.instance.return_value.socket.return_value = mock_sub
# First recv → audio chunk
# Second recv → Cancel loop → stop iteration
fake_audio = np.zeros(16000, dtype=np.float32).tobytes()
mock_sub.recv.side_effect = [fake_audio, asyncio.CancelledError()]
with patch.object(SpeechRecognizer, "best_type") as mock_best:
mock_recognizer = MagicMock()
mock_recognizer.recognize_speech.return_value = "" # <— triggers the continue branch
mock_best.return_value = mock_recognizer
agent = TranscriptionAgent("tcp://in")
# Make loop runnable
agent._running = True
agent.send = AsyncMock()
agent.add_behavior = AsyncMock()
await agent.setup()
# Execute loop manually
try:
await agent._transcribing_loop()
except asyncio.CancelledError:
pass
# → Because of "continue", NO sending should occur
agent.send.assert_not_called()
# → Continue was hit, so we must have read exactly 2 times:
# - first audio
# - second CancelledError
assert mock_sub.recv.call_count == 2
# → recognizer was called once (first iteration)
assert mock_recognizer.recognize_speech.call_count == 1

View File

@@ -1,7 +1,8 @@
from unittest.mock import AsyncMock, MagicMock
from unittest.mock import AsyncMock, MagicMock, patch
import numpy as np
import pytest
import zmq
from control_backend.agents.perception.vad_agent import VADAgent
@@ -123,3 +124,44 @@ async def test_no_data(audio_out_socket, vad_agent):
audio_out_socket.send.assert_not_called()
assert len(vad_agent.audio_buffer) == 0
@pytest.mark.asyncio
async def test_vad_model_load_failure_stops_agent(vad_agent):
"""
Test that if loading the VAD model raises an Exception, it is caught,
the agent logs an exception, stops itself, and setup returns.
"""
# Patch torch.hub.load to raise an exception
with patch(
"control_backend.agents.perception.vad_agent.torch.hub.load",
side_effect=Exception("model fail"),
):
# Patch stop to an AsyncMock so we can check it was awaited
vad_agent.stop = AsyncMock()
result = await vad_agent.setup()
# Assert stop was called
vad_agent.stop.assert_awaited_once()
# Assert setup returned None
assert result is None
@pytest.mark.asyncio
async def test_audio_out_bind_failure_sets_none_and_logs(vad_agent, caplog):
"""
Test that if binding the output socket raises ZMQBindError,
audio_out_socket is set to None, None is returned, and an error is logged.
"""
mock_socket = MagicMock()
mock_socket.bind_to_random_port.side_effect = zmq.ZMQBindError()
with patch("control_backend.agents.perception.vad_agent.azmq.Context.instance") as mock_ctx:
mock_ctx.return_value.socket.return_value = mock_socket
with caplog.at_level("ERROR"):
port = vad_agent._connect_audio_out_socket()
assert port is None
assert vad_agent.audio_out_socket is None
assert caplog.text is not None

View File

@@ -0,0 +1,63 @@
from unittest.mock import patch
import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
from starlette.responses import StreamingResponse
from control_backend.api.v1.endpoints import logs
@pytest.fixture
def client():
"""TestClient with logs router included."""
app = FastAPI()
app.include_router(logs.router)
return TestClient(app)
@pytest.mark.asyncio
async def test_log_stream_endpoint_lines(client):
"""Call /logs/stream with a mocked ZMQ socket to cover all lines."""
# Dummy socket to mock ZMQ behavior
class DummySocket:
def __init__(self):
self.subscribed = []
self.connected = False
self.recv_count = 0
def subscribe(self, topic):
self.subscribed.append(topic)
def connect(self, addr):
self.connected = True
async def recv_multipart(self):
# Return one message, then stop generator
if self.recv_count == 0:
self.recv_count += 1
return (b"INFO", b"test message")
else:
raise StopAsyncIteration
dummy_socket = DummySocket()
# Patch Context.instance().socket() to return dummy socket
with patch("control_backend.api.v1.endpoints.logs.Context.instance") as mock_context:
mock_context.return_value.socket.return_value = dummy_socket
# Call the endpoint directly
response = await logs.log_stream()
assert isinstance(response, StreamingResponse)
# Fetch one chunk from the generator
gen = response.body_iterator
chunk = await gen.__anext__()
if isinstance(chunk, bytes):
chunk = chunk.decode("utf-8")
assert "data:" in chunk
# Optional: assert subscribe/connect were called
assert dummy_socket.subscribed # at least some log levels subscribed
assert dummy_socket.connected # connect was called

View File

@@ -0,0 +1,45 @@
import json
import pytest
from fastapi.testclient import TestClient
from control_backend.api.v1.endpoints import message
@pytest.fixture
def client():
"""FastAPI TestClient for the message router."""
from fastapi import FastAPI
app = FastAPI()
app.include_router(message.router)
return TestClient(app)
def test_receive_message_post(client, monkeypatch):
"""Test POST /message endpoint sends message to pub socket."""
# Dummy pub socket to capture sent messages
class DummyPubSocket:
def __init__(self):
self.sent = []
async def send_multipart(self, msg):
self.sent.append(msg)
dummy_socket = DummyPubSocket()
# Patch app.state.endpoints_pub_socket
client.app.state.endpoints_pub_socket = dummy_socket
data = {"message": "Hello world"}
response = client.post("/message", json=data)
assert response.status_code == 202
assert response.json() == {"status": "Message received"}
# Ensure the message was sent via pub_socket
assert len(dummy_socket.sent) == 1
topic, body = dummy_socket.sent[0]
parsed = json.loads(body.decode("utf-8"))
assert parsed["message"] == "Hello world"

View File

@@ -1,12 +1,14 @@
# tests/test_robot_endpoints.py
import json
from unittest.mock import AsyncMock, MagicMock
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
import zmq.asyncio
from fastapi import FastAPI
from fastapi.testclient import TestClient
from control_backend.api.v1.endpoints import robot
from control_backend.schemas.ri_message import SpeechCommand
from control_backend.schemas.ri_message import GestureCommand, SpeechCommand
@pytest.fixture
@@ -26,7 +28,27 @@ def client(app):
return TestClient(app)
def test_receive_command_success(client):
@pytest.fixture
def mock_zmq_context():
"""Mock the ZMQ context used by the endpoint module."""
with patch("control_backend.api.v1.endpoints.robot.Context.instance") as mock_context:
context_instance = MagicMock()
mock_context.return_value = context_instance
yield context_instance
@pytest.fixture
def mock_sockets(mock_zmq_context):
"""Optional helper if you want both a sub and req/push socket available."""
mock_sub_socket = AsyncMock(spec=zmq.asyncio.Socket)
mock_req_socket = AsyncMock(spec=zmq.asyncio.Socket)
mock_zmq_context.socket.return_value = mock_sub_socket
return {"sub": mock_sub_socket, "req": mock_req_socket}
def test_receive_speech_command_success(client):
"""
Test for successful reception of a command. Ensures the status code is 202 and the response body
is correct. It also verifies that the ZeroMQ socket's send_multipart method is called with the
@@ -40,11 +62,11 @@ def test_receive_command_success(client):
speech_command = SpeechCommand(**command_data)
# Act
response = client.post("/command", json=command_data)
response = client.post("/command/speech", json=command_data)
# Assert
assert response.status_code == 202
assert response.json() == {"status": "Command received"}
assert response.json() == {"status": "Speech command received"}
# Verify that the ZMQ socket was used correctly
mock_pub_socket.send_multipart.assert_awaited_once_with(
@@ -52,13 +74,48 @@ def test_receive_command_success(client):
)
def test_receive_command_invalid_payload(client):
def test_receive_gesture_command_success(client):
"""
Test for successful reception of a command that is a gesture command.
Ensures the status code is 202 and the response body is correct.
"""
# Arrange
mock_pub_socket = AsyncMock()
client.app.state.endpoints_pub_socket = mock_pub_socket
command_data = {"endpoint": "actuate/gesture/tag", "data": "happy"}
gesture_command = GestureCommand(**command_data)
# Act
response = client.post("/command/gesture", json=command_data)
# Assert
assert response.status_code == 202
assert response.json() == {"status": "Gesture command received"}
# Verify that the ZMQ socket was used correctly
mock_pub_socket.send_multipart.assert_awaited_once_with(
[b"command", gesture_command.model_dump_json().encode()]
)
def test_receive_speech_command_invalid_payload(client):
"""
Test invalid data handling (schema validation).
"""
# Missing required field(s)
bad_payload = {"invalid": "data"}
response = client.post("/command", json=bad_payload)
response = client.post("/command/speech", json=bad_payload)
assert response.status_code == 422 # validation error
def test_receive_gesture_command_invalid_payload(client):
"""
Test invalid data handling (schema validation).
"""
# Missing required field(s)
bad_payload = {"invalid": "data"}
response = client.post("/command/gesture", json=bad_payload)
assert response.status_code == 422 # validation error
@@ -69,6 +126,9 @@ def test_ping_check_returns_none(client):
assert response.json() is None
# ----------------------------
# ping_stream tests (unchanged behavior)
# ----------------------------
@pytest.mark.asyncio
async def test_ping_stream_yields_ping_event(monkeypatch):
"""Test that ping_stream yields a proper SSE message when a ping is received."""
@@ -81,6 +141,11 @@ async def test_ping_stream_yields_ping_event(monkeypatch):
mock_context.socket.return_value = mock_sub_socket
monkeypatch.setattr(robot.Context, "instance", lambda: mock_context)
# patch settings address used by ping_stream
mock_settings = MagicMock()
mock_settings.zmq_settings.internal_sub_address = "tcp://localhost:5555"
monkeypatch.setattr(robot, "settings", mock_settings)
mock_request = AsyncMock()
mock_request.is_disconnected = AsyncMock(side_effect=[False, True])
@@ -94,7 +159,7 @@ async def test_ping_stream_yields_ping_event(monkeypatch):
with pytest.raises(StopAsyncIteration):
await anext(generator)
mock_sub_socket.connect.assert_called_once()
mock_sub_socket.connect.assert_called_once_with("tcp://localhost:5555")
mock_sub_socket.setsockopt.assert_called_once_with(robot.zmq.SUBSCRIBE, b"ping")
mock_sub_socket.recv_multipart.assert_awaited()
@@ -111,6 +176,10 @@ async def test_ping_stream_handles_timeout(monkeypatch):
mock_context.socket.return_value = mock_sub_socket
monkeypatch.setattr(robot.Context, "instance", lambda: mock_context)
mock_settings = MagicMock()
mock_settings.zmq_settings.internal_sub_address = "tcp://localhost:5555"
monkeypatch.setattr(robot, "settings", mock_settings)
mock_request = AsyncMock()
mock_request.is_disconnected = AsyncMock(return_value=True)
@@ -120,7 +189,7 @@ async def test_ping_stream_handles_timeout(monkeypatch):
with pytest.raises(StopAsyncIteration):
await anext(generator)
mock_sub_socket.connect.assert_called_once()
mock_sub_socket.connect.assert_called_once_with("tcp://localhost:5555")
mock_sub_socket.setsockopt.assert_called_once_with(robot.zmq.SUBSCRIBE, b"ping")
mock_sub_socket.recv_multipart.assert_awaited()
@@ -139,6 +208,10 @@ async def test_ping_stream_yields_json_values(monkeypatch):
mock_context.socket.return_value = mock_sub_socket
monkeypatch.setattr(robot.Context, "instance", lambda: mock_context)
mock_settings = MagicMock()
mock_settings.zmq_settings.internal_sub_address = "tcp://localhost:5555"
monkeypatch.setattr(robot, "settings", mock_settings)
mock_request = AsyncMock()
mock_request.is_disconnected = AsyncMock(side_effect=[False, True])
@@ -151,6 +224,192 @@ async def test_ping_stream_yields_json_values(monkeypatch):
assert "connected" in event_text
assert "true" in event_text
mock_sub_socket.connect.assert_called_once()
mock_sub_socket.connect.assert_called_once_with("tcp://localhost:5555")
mock_sub_socket.setsockopt.assert_called_once_with(robot.zmq.SUBSCRIBE, b"ping")
mock_sub_socket.recv_multipart.assert_awaited()
# ----------------------------
# Updated get_available_gesture_tags tests (REQ socket on tcp://localhost:7788)
# ----------------------------
@pytest.mark.asyncio
async def test_get_available_gesture_tags_success(client, monkeypatch):
"""
Test successful retrieval of available gesture tags using a REQ socket.
"""
# Arrange
mock_req_socket = AsyncMock(spec=zmq.asyncio.Socket)
mock_req_socket.connect = MagicMock()
mock_req_socket.send = AsyncMock()
response_data = {"tags": ["wave", "nod", "point", "dance"]}
mock_req_socket.recv = AsyncMock(return_value=json.dumps(response_data).encode())
mock_context = MagicMock()
mock_context.socket.return_value = mock_req_socket
monkeypatch.setattr(robot.Context, "instance", lambda: mock_context)
# Replace logger methods to avoid noisy logs in tests
monkeypatch.setattr(robot.logger, "debug", MagicMock())
monkeypatch.setattr(robot.logger, "error", MagicMock())
# Act
response = client.get("/commands/gesture/tags")
# Assert
assert response.status_code == 200
assert response.json() == {"available_gesture_tags": ["wave", "nod", "point", "dance"]}
# Verify ZeroMQ REQ interactions
mock_req_socket.connect.assert_called_once_with("tcp://localhost:7788")
mock_req_socket.send.assert_awaited_once_with(b"None")
mock_req_socket.recv.assert_awaited_once()
@pytest.mark.asyncio
async def test_get_available_gesture_tags_with_amount(client, monkeypatch):
"""
The endpoint currently ignores the 'amount' TODO, so behavior is the same as 'success'.
This test asserts that the endpoint still sends b"None" and returns the tags.
"""
# Arrange
mock_req_socket = AsyncMock(spec=zmq.asyncio.Socket)
mock_req_socket.connect = MagicMock()
mock_req_socket.send = AsyncMock()
response_data = {"tags": ["wave", "nod"]}
mock_req_socket.recv = AsyncMock(return_value=json.dumps(response_data).encode())
mock_context = MagicMock()
mock_context.socket.return_value = mock_req_socket
monkeypatch.setattr(robot.Context, "instance", lambda: mock_context)
monkeypatch.setattr(robot.logger, "debug", MagicMock())
monkeypatch.setattr(robot.logger, "error", MagicMock())
# Act
response = client.get("/commands/gesture/tags")
# Assert
assert response.status_code == 200
assert response.json() == {"available_gesture_tags": ["wave", "nod"]}
mock_req_socket.connect.assert_called_once_with("tcp://localhost:7788")
mock_req_socket.send.assert_awaited_once_with(b"None")
@pytest.mark.asyncio
async def test_get_available_gesture_tags_timeout(client, monkeypatch):
"""
Test timeout scenario when fetching gesture tags. Endpoint should handle TimeoutError
and return an empty list while logging the timeout.
"""
# Arrange
mock_req_socket = AsyncMock(spec=zmq.asyncio.Socket)
mock_req_socket.connect = MagicMock()
mock_req_socket.send = AsyncMock()
mock_req_socket.recv = AsyncMock(side_effect=TimeoutError)
mock_context = MagicMock()
mock_context.socket.return_value = mock_req_socket
monkeypatch.setattr(robot.Context, "instance", lambda: mock_context)
# Patch logger.debug so we can assert it was called with the expected message
mock_debug = MagicMock()
monkeypatch.setattr(robot.logger, "debug", mock_debug)
monkeypatch.setattr(robot.logger, "error", MagicMock())
# Act
response = client.get("/commands/gesture/tags")
# Assert
assert response.status_code == 200
assert response.json() == {"available_gesture_tags": []}
# Verify the timeout was logged using the exact string from the endpoint code
mock_debug.assert_called_once_with("Got timeout error fetching gestures.")
mock_req_socket.connect.assert_called_once_with("tcp://localhost:7788")
mock_req_socket.send.assert_awaited_once_with(b"None")
mock_req_socket.recv.assert_awaited_once()
@pytest.mark.asyncio
async def test_get_available_gesture_tags_empty_response(client, monkeypatch):
"""
Test scenario when response contains an empty 'tags' list.
"""
# Arrange
mock_req_socket = AsyncMock(spec=zmq.asyncio.Socket)
mock_req_socket.connect = MagicMock()
mock_req_socket.send = AsyncMock()
response_data = {"tags": []}
mock_req_socket.recv = AsyncMock(return_value=json.dumps(response_data).encode())
mock_context = MagicMock()
mock_context.socket.return_value = mock_req_socket
monkeypatch.setattr(robot.Context, "instance", lambda: mock_context)
monkeypatch.setattr(robot.logger, "debug", MagicMock())
monkeypatch.setattr(robot.logger, "error", MagicMock())
# Act
response = client.get("/commands/gesture/tags")
# Assert
assert response.status_code == 200
assert response.json() == {"available_gesture_tags": []}
@pytest.mark.asyncio
async def test_get_available_gesture_tags_missing_tags_key(client, monkeypatch):
"""
Test scenario when response JSON doesn't contain 'tags' key.
"""
# Arrange
mock_req_socket = AsyncMock(spec=zmq.asyncio.Socket)
mock_req_socket.connect = MagicMock()
mock_req_socket.send = AsyncMock()
response_data = {"some_other_key": "value"}
mock_req_socket.recv = AsyncMock(return_value=json.dumps(response_data).encode())
mock_context = MagicMock()
mock_context.socket.return_value = mock_req_socket
monkeypatch.setattr(robot.Context, "instance", lambda: mock_context)
monkeypatch.setattr(robot.logger, "debug", MagicMock())
monkeypatch.setattr(robot.logger, "error", MagicMock())
# Act
response = client.get("/commands/gesture/tags")
# Assert
assert response.status_code == 200
assert response.json() == {"available_gesture_tags": []}
@pytest.mark.asyncio
async def test_get_available_gesture_tags_invalid_json(client, monkeypatch):
"""
Test scenario when response contains invalid JSON. Endpoint should log the error
and return an empty list.
"""
# Arrange
mock_req_socket = AsyncMock(spec=zmq.asyncio.Socket)
mock_req_socket.connect = MagicMock()
mock_req_socket.send = AsyncMock()
mock_req_socket.recv = AsyncMock(return_value=b"invalid json")
mock_context = MagicMock()
mock_context.socket.return_value = mock_req_socket
monkeypatch.setattr(robot.Context, "instance", lambda: mock_context)
mock_error = MagicMock()
monkeypatch.setattr(robot.logger, "error", mock_error)
monkeypatch.setattr(robot.logger, "debug", MagicMock())
# Act
response = client.get("/commands/gesture/tags")
# Assert - invalid JSON should lead to empty list and error log invocation
assert response.status_code == 200
assert response.json() == {"available_gesture_tags": []}
assert mock_error.call_count == 1

View File

@@ -0,0 +1,16 @@
from fastapi.routing import APIRoute
from control_backend.api.v1.router import api_router # <--- corrected import
def test_router_includes_expected_paths():
"""Ensure api_router includes main router prefixes."""
routes = [r for r in api_router.routes if isinstance(r, APIRoute)]
paths = [r.path for r in routes]
# Ensure at least one route under each prefix exists
assert any(p.startswith("/robot") for p in paths)
assert any(p.startswith("/message") for p in paths)
assert any(p.startswith("/sse") for p in paths)
assert any(p.startswith("/logs") for p in paths)
assert any(p.startswith("/program") for p in paths)

View File

@@ -0,0 +1,24 @@
import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
from control_backend.api.v1.endpoints import sse
@pytest.fixture
def app():
app = FastAPI()
app.include_router(sse.router)
return app
@pytest.fixture
def client(app):
return TestClient(app)
def test_sse_route_exists(client):
"""Minimal smoke test to ensure /sse route exists and responds."""
response = client.get("/sse")
# Since implementation is not done, we only assert it doesn't crash
assert response.status_code == 200

View File

@@ -2,7 +2,7 @@
import asyncio
import logging
from unittest.mock import AsyncMock
from unittest.mock import AsyncMock, MagicMock
import pytest
@@ -70,3 +70,142 @@ async def test_get_agent():
agent = ConcreteTestAgent("registrant")
assert AgentDirectory.get("registrant") == agent
assert AgentDirectory.get("non_existent") is None
class DummyAgent(BaseAgent):
async def setup(self):
pass # we will test this separately
async def handle_message(self, msg: InternalMessage):
self.last_handled = msg
@pytest.mark.asyncio
async def test_base_agent_setup_is_noop():
agent = DummyAgent("dummy")
# Should simply return without error
assert await agent.setup() is None
@pytest.mark.asyncio
async def test_send_to_local_agent(monkeypatch):
sender = DummyAgent("sender")
target = DummyAgent("receiver")
# Fake logger
sender.logger = MagicMock()
# Patch inbox.put
target.inbox.put = AsyncMock()
message = InternalMessage(to="receiver", sender="sender", body="hello")
await sender.send(message)
target.inbox.put.assert_awaited_once_with(message)
sender.logger.debug.assert_called_once()
@pytest.mark.asyncio
async def test_process_inbox_calls_handle_message(monkeypatch):
agent = DummyAgent("dummy")
agent.logger = MagicMock()
# Make agent running so loop triggers
agent._running = True
# Prepare inbox to give one message then stop
msg = InternalMessage(to="dummy", sender="x", body="test")
async def get_once():
agent._running = False # stop after first iteration
return msg
agent.inbox.get = AsyncMock(side_effect=get_once)
agent.handle_message = AsyncMock()
await agent._process_inbox()
agent.handle_message.assert_awaited_once_with(msg)
@pytest.mark.asyncio
async def test_receive_internal_zmq_loop_success(monkeypatch):
agent = DummyAgent("dummy")
agent.logger = MagicMock()
agent._running = True
mock_socket = MagicMock()
mock_socket.recv_multipart = AsyncMock(
side_effect=[
(
b"topic",
InternalMessage(to="dummy", sender="x", body="hi").model_dump_json().encode(),
),
asyncio.CancelledError(), # stop loop
]
)
agent._internal_sub_socket = mock_socket
agent.inbox.put = AsyncMock()
await agent._receive_internal_zmq_loop()
agent.inbox.put.assert_awaited() # message forwarded
@pytest.mark.asyncio
async def test_receive_internal_zmq_loop_exception_logs_error():
agent = DummyAgent("dummy")
agent.logger = MagicMock()
agent._running = True
mock_socket = MagicMock()
mock_socket.recv_multipart = AsyncMock(
side_effect=[Exception("boom"), asyncio.CancelledError()]
)
agent._internal_sub_socket = mock_socket
agent.inbox.put = AsyncMock()
await agent._receive_internal_zmq_loop()
agent.logger.exception.assert_called_once()
assert "Could not process ZMQ message." in agent.logger.exception.call_args[0][0]
@pytest.mark.asyncio
async def test_base_agent_handle_message_not_implemented():
class RawAgent(BaseAgent):
async def setup(self):
pass
agent = RawAgent("raw")
msg = InternalMessage(to="raw", sender="x", body="hi")
with pytest.raises(NotImplementedError):
await BaseAgent.handle_message(agent, msg)
@pytest.mark.asyncio
async def test_base_agent_setup_abstract_method_body_executes():
"""
Covers the 'pass' inside BaseAgent.setup().
Since BaseAgent is abstract, we do NOT instantiate it.
We call the coroutine function directly on BaseAgent and pass a dummy self.
"""
class Dummy:
"""Minimal stub to act as 'self'."""
pass
stub = Dummy()
# Call BaseAgent.setup() as an unbound coroutine, passing stub as 'self'
result = await BaseAgent.setup(stub)
# The method contains only 'pass', so it returns None
assert result is None

View File

@@ -86,3 +86,34 @@ def test_setup_logging_zmq_handler(mock_zmq_context):
args = mock_dict_config.call_args[0][0]
assert "interface_or_socket" in args["handlers"]["ui"]
def test_add_logging_level_method_name_exists_in_logging():
# method_name explicitly set to an existing logging method → triggers first hasattr branch
with pytest.raises(AttributeError) as exc:
add_logging_level("NEWDUPLEVEL", 37, method_name="info")
assert "info already defined in logging module" in str(exc.value)
def test_add_logging_level_method_name_exists_in_logger_class():
# 'makeRecord' exists on Logger class but not on the logging module
with pytest.raises(AttributeError) as exc:
add_logging_level("ANOTHERLEVEL", 38, method_name="makeRecord")
assert "makeRecord already defined in logger class" in str(exc.value)
def test_add_logging_level_log_to_root_path_executes_without_error():
# Verify log_to_root is installed and callable — without asserting logging output
level_name = "ROOTTEST"
level_num = 36
add_logging_level(level_name, level_num)
# Simply call the injected root logger method
# The line is executed even if we don't validate output
root_logging_method = getattr(logging, level_name.lower(), None)
assert callable(root_logging_method)
# Execute the method to hit log_to_root in coverage.
# No need to verify log output.
root_logging_method("some message")

View File

@@ -0,0 +1,12 @@
from control_backend.schemas.message import Message
def base_message() -> Message:
return Message(message="Example")
def test_valid_message():
mess = base_message()
validated = Message.model_validate(mess)
assert isinstance(validated, Message)
assert validated.message == "Example"

View File

@@ -1,26 +1,88 @@
import pytest
from pydantic import ValidationError
from control_backend.schemas.ri_message import RIEndpoint, RIMessage, SpeechCommand
from control_backend.schemas.ri_message import GestureCommand, RIEndpoint, RIMessage, SpeechCommand
def valid_command_1():
return SpeechCommand(data="Hallo?")
def valid_command_2():
return GestureCommand(endpoint=RIEndpoint.GESTURE_TAG, data="happy")
def valid_command_3():
return GestureCommand(endpoint=RIEndpoint.GESTURE_SINGLE, data="happy_1")
def invalid_command_1():
return RIMessage(endpoint=RIEndpoint.PING, data="Hello again.")
def invalid_command_2():
return RIMessage(endpoint=RIEndpoint.PING, data="Hey!")
def invalid_command_3():
return RIMessage(endpoint=RIEndpoint.GESTURE_SINGLE, data={1, 2, 3})
def invalid_command_4():
test: RIMessage = GestureCommand(endpoint=RIEndpoint.GESTURE_SINGLE, data="asdsad")
def change_endpoint(msg: RIMessage):
msg.endpoint = RIEndpoint.PING
change_endpoint(test)
return test
def test_valid_speech_command_1():
command = valid_command_1()
RIMessage.model_validate(command)
SpeechCommand.model_validate(command)
def test_valid_gesture_command_1():
command = valid_command_2()
RIMessage.model_validate(command)
GestureCommand.model_validate(command)
def test_valid_gesture_command_2():
command = valid_command_3()
RIMessage.model_validate(command)
GestureCommand.model_validate(command)
def test_invalid_speech_command_1():
command = invalid_command_1()
RIMessage.model_validate(command)
with pytest.raises(ValidationError):
SpeechCommand.model_validate(command)
def test_invalid_gesture_command_1():
command = invalid_command_2()
RIMessage.model_validate(command)
with pytest.raises(ValidationError):
GestureCommand.model_validate(command)
def test_invalid_gesture_command_2():
command = invalid_command_3()
RIMessage.model_validate(command)
with pytest.raises(ValidationError):
GestureCommand.model_validate(command)
def test_invalid_gesture_command_3():
command = invalid_command_4()
RIMessage.model_validate(command)
with pytest.raises(ValidationError):
GestureCommand.model_validate(command)

73
test/unit/test_main.py Normal file
View File

@@ -0,0 +1,73 @@
import asyncio
import sys
from unittest.mock import AsyncMock, patch
import pytest
from fastapi.testclient import TestClient
from control_backend.api.v1.router import api_router
from control_backend.main import app, lifespan
# Fix event loop on Windows
if sys.platform == "win32":
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
@pytest.fixture
def client():
# Patch setup_logging so it does nothing
with patch("control_backend.main.setup_logging"):
with TestClient(app) as c:
yield c
def test_root_fast():
# Patch heavy startup code so it doesnt slow down
with patch("control_backend.main.setup_logging"), patch("control_backend.main.lifespan"):
client = TestClient(app)
resp = client.get("/")
assert resp.status_code == 200
assert resp.json() == {"status": "ok"}
def test_cors_middleware_added():
"""Test that CORSMiddleware is correctly added to the app."""
from starlette.middleware.cors import CORSMiddleware
middleware_classes = [m.cls for m in app.user_middleware]
assert CORSMiddleware in middleware_classes
def test_api_router_included():
"""Test that the API router is included in the FastAPI app."""
route_paths = [r.path for r in app.routes]
for route in api_router.routes:
assert route.path in route_paths
@pytest.mark.asyncio
async def test_lifespan_agent_start_exception():
"""
Trigger an exception during agent startup to cover the error logging branch.
Ensures exceptions are logged properly and re-raised.
"""
with (
patch(
"control_backend.main.RICommunicationAgent.start", new_callable=AsyncMock
) as ri_start,
patch("control_backend.main.setup_logging"),
patch("control_backend.main.threading.Thread"),
):
# Force RICommunicationAgent.start to raise an exception
ri_start.side_effect = Exception("Test exception")
with patch("control_backend.main.logger") as mock_logger:
with pytest.raises(Exception, match="Test exception"):
async with lifespan(app):
pass
# Verify the error was logged correctly
assert mock_logger.error.called
args, _ = mock_logger.error.call_args
assert isinstance(args[2], Exception)