Merge branch 'feat/cb2ui-robot-connections' into 'dev'

Merge new implementation of robot ri api with unit tests into dev.

See merge request ics/sp/2025/n25b/pepperplus-cb!6
This commit was merged in pull request #6.
This commit is contained in:
Twirre
2025-10-28 13:49:40 +00:00
17 changed files with 1136 additions and 3 deletions

2
.gitignore vendored
View File

@@ -199,7 +199,7 @@ cython_debug/
# that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore
# and can be added to the global gitignore or merged into this file. However, if you prefer,
# you could uncomment the following to ignore the entire vscode folder
# .vscode/
.vscode/
# Ruff stuff:
.ruff_cache/

View File

@@ -22,5 +22,6 @@ test:
tags:
- test
script:
- uv run --only-group test pytest
# - uv run --group integration-test pytest test/integration
- uv run --only-group test pytest test/unit

7
.vscode/settings.json vendored Normal file
View File

@@ -0,0 +1,7 @@
{
"python.testing.pytestArgs": [
"test"
],
"python.testing.unittestEnabled": false,
"python.testing.pytestEnabled": true
}

View File

@@ -11,6 +11,10 @@ dependencies = [
"pyaudio>=0.2.14",
"pydantic>=2.12.0",
"pydantic-settings>=2.11.0",
"pytest>=8.4.2",
"pytest-asyncio>=1.2.0",
"pytest-cov>=7.0.0",
"pytest-mock>=3.15.1",
"pyzmq>=27.1.0",
"silero-vad>=6.0.0",
"spade>=4.1.0",
@@ -25,6 +29,9 @@ dev = [
"ruff>=0.14.2",
"ruff-format>=0.3.0",
]
integration-test = [
"soundfile>=0.13.1",
]
test = [
"pytest>=8.4.2",
"pytest-asyncio>=1.2.0",

View File

@@ -0,0 +1,74 @@
import json
import logging
from spade.agent import Agent
from spade.behaviour import CyclicBehaviour
import zmq
from control_backend.core.config import settings
from control_backend.core.zmq_context import context
from control_backend.schemas.ri_message import SpeechCommand
logger = logging.getLogger(__name__)
class RICommandAgent(Agent):
subsocket: zmq.Socket
pubsocket: zmq.Socket
address = ""
bind = False
def __init__(
self,
jid: str,
password: str,
port: int = 5222,
verify_security: bool = False,
address="tcp://localhost:0000",
bind=False,
):
super().__init__(jid, password, port, verify_security)
self.address = address
self.bind = bind
class SendCommandsBehaviour(CyclicBehaviour):
async def run(self):
"""
Run the command publishing loop indefinetely.
"""
assert self.agent is not None
# Get a message internally (with topic command)
topic, body = await self.agent.subsocket.recv_multipart()
# Try to get body
try:
body = json.loads(body)
message = SpeechCommand.model_validate(body)
# Send to the robot.
await self.agent.pubsocket.send_json(message.model_dump())
except Exception as e:
logger.error("Error processing message: %s", e)
async def setup(self):
"""
Setup the command agent
"""
logger.info("Setting up %s", self.jid)
# To the robot
self.pubsocket = context.socket(zmq.PUB)
if self.bind:
self.pubsocket.bind(self.address)
else:
self.pubsocket.connect(self.address)
# Receive internal topics regarding commands
self.subsocket = context.socket(zmq.SUB)
self.subsocket.connect(settings.zmq_settings.internal_comm_address)
self.subsocket.setsockopt(zmq.SUBSCRIBE, b"command")
# Add behaviour to our agent
commands_behaviour = self.SendCommandsBehaviour()
self.add_behaviour(commands_behaviour)
logger.info("Finished setting up %s", self.jid)

View File

@@ -0,0 +1,165 @@
import asyncio
import json
import logging
from spade.agent import Agent
from spade.behaviour import CyclicBehaviour
import zmq
from control_backend.core.config import settings
from control_backend.core.zmq_context import context
from control_backend.schemas.message import Message
from control_backend.agents.ri_command_agent import RICommandAgent
logger = logging.getLogger(__name__)
class RICommunicationAgent(Agent):
req_socket: zmq.Socket
_address = ""
_bind = True
def __init__(
self,
jid: str,
password: str,
port: int = 5222,
verify_security: bool = False,
address="tcp://localhost:0000",
bind=False,
):
super().__init__(jid, password, port, verify_security)
self._address = address
self._bind = bind
class ListenBehaviour(CyclicBehaviour):
async def run(self):
"""
Run the listening (ping) loop indefinetely.
"""
assert self.agent is not None
# We need to listen and sent pings.
message = {"endpoint": "ping", "data": {"id": "e.g. some reference id"}}
await self.agent.req_socket.send_json(message)
# Wait up to three seconds for a reply:)
try:
message = await asyncio.wait_for(self.agent.req_socket.recv_json(), timeout=3.0)
# We didnt get a reply :(
except asyncio.TimeoutError as e:
logger.info("No ping retrieved in 3 seconds, killing myself.")
self.kill()
logger.debug('Received message "%s"', message)
if "endpoint" not in message:
logger.error("No received endpoint in message, excepted ping endpoint.")
return
# See what endpoint we received
match message["endpoint"]:
case "ping":
await asyncio.sleep(1)
case _:
logger.info(
"Received message with topic different than ping, while ping expected."
)
async def setup(self, max_retries: int = 5):
"""
Try to setup the communication agent, we have 5 retries in case we dont have a response yet.
"""
logger.info("Setting up %s", self.jid)
retries = 0
# Let's try a certain amount of times before failing connection
while retries < max_retries:
# Bind request socket
self.req_socket = context.socket(zmq.REQ)
if self._bind:
self.req_socket.bind(self._address)
else:
self.req_socket.connect(self._address)
# Send our message and receive one back:)
message = {"endpoint": "negotiate/ports", "data": None}
await self.req_socket.send_json(message)
try:
received_message = await asyncio.wait_for(self.req_socket.recv_json(), timeout=20.0)
except asyncio.TimeoutError:
logger.warning(
"No connection established in 20 seconds (attempt %d/%d)",
retries + 1,
max_retries,
)
retries += 1
continue
except Exception as e:
logger.error("Unexpected error during negotiation: %s", e)
retries += 1
continue
# Validate endpoint
endpoint = received_message.get("endpoint")
if endpoint != "negotiate/ports":
# TODO: Should this send a message back?
logger.error(
"Invalid endpoint '%s' received (attempt %d/%d)",
endpoint,
retries + 1,
max_retries,
)
retries += 1
continue
# At this point, we have a valid response
try:
for port_data in received_message["data"]:
id = port_data["id"]
port = port_data["port"]
bind = port_data["bind"]
if not bind:
addr = f"tcp://localhost:{port}"
else:
addr = f"tcp://*:{port}"
match id:
case "main":
if addr != self._address:
if not bind:
self.req_socket.connect(addr)
else:
self.req_socket.bind(addr)
case "actuation":
ri_commands_agent = RICommandAgent(
settings.agent_settings.ri_command_agent_name
+ "@"
+ settings.agent_settings.host,
settings.agent_settings.ri_command_agent_name,
address=addr,
bind=bind,
)
await ri_commands_agent.start()
case _:
logger.warning("Unhandled negotiation id: %s", id)
except Exception as e:
logger.error("Error unpacking negotiation data: %s", e)
retries += 1
continue
# setup succeeded
break
else:
logger.error("Failed to set up RICommunicationAgent after %d retries", max_retries)
return
# Set up ping behaviour
listen_behaviour = self.ListenBehaviour()
self.add_behaviour(listen_behaviour)
logger.info("Finished setting up %s", self.jid)

View File

@@ -0,0 +1,22 @@
from fastapi import APIRouter, Request
import logging
from zmq import Socket
from control_backend.schemas.ri_message import SpeechCommand, RIEndpoint
logger = logging.getLogger(__name__)
router = APIRouter()
@router.post("/command", status_code=202)
async def receive_command(command: SpeechCommand, request: Request):
# Validate and retrieve data.
SpeechCommand.model_validate(command)
topic = b"command"
pub_socket: Socket = request.app.state.internal_comm_socket
pub_socket.send_multipart([topic, command.model_dump_json().encode()])
return {"status": "Command received"}

View File

@@ -1,9 +1,11 @@
from fastapi.routing import APIRouter
from control_backend.api.v1.endpoints import message, sse
from control_backend.api.v1.endpoints import message, sse, command
api_router = APIRouter()
api_router.include_router(message.router, tags=["Messages"])
api_router.include_router(sse.router, tags=["SSE"])
api_router.include_router(command.router, tags=["Commands"])

View File

@@ -12,6 +12,9 @@ class AgentSettings(BaseModel):
belief_collector_agent_name: str = "belief_collector"
test_agent_name: str = "test_agent"
ri_communication_agent_name: str = "ri_communication_agent"
ri_command_agent_name: str = "ri_command_agent"
class Settings(BaseSettings):
app_title: str = "PepperPlus"

View File

@@ -9,6 +9,7 @@ from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
# Internal imports
from control_backend.agents.ri_communication_agent import RICommunicationAgent
from control_backend.agents.bdi.bdi_core import BDICoreAgent
from control_backend.api.v1.router import api_router
from control_backend.core.config import settings
@@ -30,6 +31,14 @@ async def lifespan(app: FastAPI):
logger.info("Internal publishing socket bound to %s", internal_comm_socket)
# Initiate agents
ri_communication_agent = RICommunicationAgent(
settings.agent_settings.ri_communication_agent_name + "@" + settings.agent_settings.host,
settings.agent_settings.ri_communication_agent_name,
address="tcp://*:5555",
bind=True,
)
await ri_communication_agent.start()
bdi_core = BDICoreAgent(
settings.agent_settings.bdi_core_agent_name + "@" + settings.agent_settings.host,
settings.agent_settings.bdi_core_agent_name,

View File

@@ -0,0 +1,20 @@
from enum import Enum
from typing import Any, Literal
from pydantic import BaseModel, Field, ValidationError
class RIEndpoint(str, Enum):
SPEECH = "actuate/speech"
PING = "ping"
NEGOTIATE_PORTS = "negotiate/ports"
class RIMessage(BaseModel):
endpoint: RIEndpoint
data: Any
class SpeechCommand(RIMessage):
endpoint: RIEndpoint = RIEndpoint(RIEndpoint.SPEECH)
data: str

View File

@@ -0,0 +1,102 @@
import asyncio
import zmq
import json
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from control_backend.agents.ri_command_agent import RICommandAgent
from control_backend.schemas.ri_message import SpeechCommand
@pytest.mark.asyncio
async def test_setup_bind(monkeypatch):
"""Test setup with bind=True"""
fake_socket = MagicMock()
monkeypatch.setattr(
"control_backend.agents.ri_command_agent.context.socket", lambda _: fake_socket
)
agent = RICommandAgent("test@server", "password", address="tcp://localhost:5555", bind=True)
monkeypatch.setattr(
"control_backend.agents.ri_command_agent.settings",
MagicMock(zmq_settings=MagicMock(internal_comm_address="tcp://internal:1234")),
)
await agent.setup()
# Ensure PUB socket bound
fake_socket.bind.assert_any_call("tcp://localhost:5555")
# Ensure SUB socket connected to internal address and subscribed
fake_socket.connect.assert_any_call("tcp://internal:1234")
fake_socket.setsockopt.assert_any_call(zmq.SUBSCRIBE, b"command")
# Ensure behaviour attached
assert any(isinstance(b, agent.SendCommandsBehaviour) for b in agent.behaviours)
@pytest.mark.asyncio
async def test_setup_connect(monkeypatch):
"""Test setup with bind=False"""
fake_socket = MagicMock()
monkeypatch.setattr(
"control_backend.agents.ri_command_agent.context.socket", lambda _: fake_socket
)
agent = RICommandAgent("test@server", "password", address="tcp://localhost:5555", bind=False)
monkeypatch.setattr(
"control_backend.agents.ri_command_agent.settings",
MagicMock(zmq_settings=MagicMock(internal_comm_address="tcp://internal:1234")),
)
await agent.setup()
# Ensure PUB socket connected
fake_socket.connect.assert_any_call("tcp://localhost:5555")
@pytest.mark.asyncio
async def test_send_commands_behaviour_valid_message():
"""Test behaviour with valid JSON message"""
fake_socket = AsyncMock()
message_dict = {"message": "hello"}
fake_socket.recv_multipart = AsyncMock(
return_value=(b"command", json.dumps(message_dict).encode("utf-8"))
)
fake_socket.send_json = AsyncMock()
agent = RICommandAgent("test@server", "password")
agent.subsocket = fake_socket
agent.pubsocket = fake_socket
behaviour = agent.SendCommandsBehaviour()
behaviour.agent = agent
with patch("control_backend.agents.ri_command_agent.SpeechCommand") as MockSpeechCommand:
mock_message = MagicMock()
MockSpeechCommand.model_validate.return_value = mock_message
await behaviour.run()
fake_socket.recv_multipart.assert_awaited()
fake_socket.send_json.assert_awaited_with(mock_message.model_dump())
@pytest.mark.asyncio
async def test_send_commands_behaviour_invalid_message(caplog):
"""Test behaviour with invalid JSON message triggers error logging"""
fake_socket = AsyncMock()
fake_socket.recv_multipart = AsyncMock(return_value=(b"command", b"{invalid_json}"))
fake_socket.send_json = AsyncMock()
agent = RICommandAgent("test@server", "password")
agent.subsocket = fake_socket
agent.pubsocket = fake_socket
behaviour = agent.SendCommandsBehaviour()
behaviour.agent = agent
with caplog.at_level("ERROR"):
await behaviour.run()
fake_socket.recv_multipart.assert_awaited()
fake_socket.send_json.assert_not_awaited()
assert "Error processing message" in caplog.text

View File

@@ -0,0 +1,591 @@
import asyncio
import pytest
from unittest.mock import AsyncMock, MagicMock, patch, ANY
from control_backend.agents.ri_communication_agent import RICommunicationAgent
def fake_json_correct_negototiate_1():
return AsyncMock(
return_value={
"endpoint": "negotiate/ports",
"data": [
{"id": "main", "port": 5555, "bind": False},
{"id": "actuation", "port": 5556, "bind": True},
],
}
)
def fake_json_correct_negototiate_2():
return AsyncMock(
return_value={
"endpoint": "negotiate/ports",
"data": [
{"id": "main", "port": 5555, "bind": False},
{"id": "actuation", "port": 5557, "bind": True},
],
}
)
def fake_json_correct_negototiate_3():
return AsyncMock(
return_value={
"endpoint": "negotiate/ports",
"data": [
{"id": "main", "port": 5555, "bind": True},
{"id": "actuation", "port": 5557, "bind": True},
],
}
)
def fake_json_correct_negototiate_4():
# Different port, do bind
return AsyncMock(
return_value={
"endpoint": "negotiate/ports",
"data": [
{"id": "main", "port": 4555, "bind": True},
{"id": "actuation", "port": 5557, "bind": True},
],
}
)
def fake_json_correct_negototiate_5():
# Different port, dont bind.
return AsyncMock(
return_value={
"endpoint": "negotiate/ports",
"data": [
{"id": "main", "port": 4555, "bind": False},
{"id": "actuation", "port": 5557, "bind": True},
],
}
)
def fake_json_wrong_negototiate_1():
return AsyncMock(return_value={"endpoint": "ping", "data": ""})
def fake_json_invalid_id_negototiate():
return AsyncMock(
return_value={
"endpoint": "negotiate/ports",
"data": [
{"id": "banana", "port": 4555, "bind": False},
{"id": "tomato", "port": 5557, "bind": True},
],
}
)
@pytest.mark.asyncio
async def test_setup_creates_socket_and_negotiate_1(monkeypatch):
"""
Test the setup of the communication agent
"""
# --- Arrange ---
fake_socket = MagicMock()
fake_socket.send_json = AsyncMock()
fake_socket.recv_json = fake_json_correct_negototiate_1()
# Mock context.socket to return our fake socket
monkeypatch.setattr(
"control_backend.agents.ri_communication_agent.context.socket", lambda _: fake_socket
)
# Mock RICommandAgent agent startup
with patch(
"control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True
) as MockCommandAgent:
fake_agent_instance = MockCommandAgent.return_value
fake_agent_instance.start = AsyncMock()
# --- Act ---
agent = RICommunicationAgent(
"test@server", "password", address="tcp://localhost:5555", bind=False
)
await agent.setup()
# --- Assert ---
fake_socket.connect.assert_any_call("tcp://localhost:5555")
fake_socket.send_json.assert_any_call({"endpoint": "negotiate/ports", "data": None})
fake_socket.recv_json.assert_awaited()
fake_agent_instance.start.assert_awaited()
MockCommandAgent.assert_called_once_with(
ANY, # Server Name
ANY, # Server Password
address="tcp://*:5556", # derived from the 'port' value in negotiation
bind=True,
)
# Ensure the agent attached a ListenBehaviour
assert any(isinstance(b, agent.ListenBehaviour) for b in agent.behaviours)
@pytest.mark.asyncio
async def test_setup_creates_socket_and_negotiate_2(monkeypatch):
"""
Test the setup of the communication agent
"""
# --- Arrange ---
fake_socket = MagicMock()
fake_socket.send_json = AsyncMock()
fake_socket.recv_json = fake_json_correct_negototiate_2()
# Mock context.socket to return our fake socket
monkeypatch.setattr(
"control_backend.agents.ri_communication_agent.context.socket", lambda _: fake_socket
)
# Mock RICommandAgent agent startup
with patch(
"control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True
) as MockCommandAgent:
fake_agent_instance = MockCommandAgent.return_value
fake_agent_instance.start = AsyncMock()
# --- Act ---
agent = RICommunicationAgent(
"test@server", "password", address="tcp://localhost:5555", bind=False
)
await agent.setup()
# --- Assert ---
fake_socket.connect.assert_any_call("tcp://localhost:5555")
fake_socket.send_json.assert_any_call({"endpoint": "negotiate/ports", "data": None})
fake_socket.recv_json.assert_awaited()
fake_agent_instance.start.assert_awaited()
MockCommandAgent.assert_called_once_with(
ANY, # Server Name
ANY, # Server Password
address="tcp://*:5557", # derived from the 'port' value in negotiation
bind=True,
)
# Ensure the agent attached a ListenBehaviour
assert any(isinstance(b, agent.ListenBehaviour) for b in agent.behaviours)
@pytest.mark.asyncio
async def test_setup_creates_socket_and_negotiate_3(monkeypatch, caplog):
"""
Test the functionality of setup with incorrect negotiation message
"""
# --- Arrange ---
fake_socket = MagicMock()
fake_socket.send_json = AsyncMock()
fake_socket.recv_json = fake_json_wrong_negototiate_1()
# Mock context.socket to return our fake socket
monkeypatch.setattr(
"control_backend.agents.ri_communication_agent.context.socket", lambda _: fake_socket
)
# Mock RICommandAgent agent startup
# We are sending wrong negotiation info to the communication agent, so we should retry and expect a
# better response, within a limited time.
with patch(
"control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True
) as MockCommandAgent:
fake_agent_instance = MockCommandAgent.return_value
fake_agent_instance.start = AsyncMock()
# --- Act ---
with caplog.at_level("ERROR"):
agent = RICommunicationAgent(
"test@server", "password", address="tcp://localhost:5555", bind=False
)
await agent.setup(max_retries=1)
# --- Assert ---
fake_socket.connect.assert_any_call("tcp://localhost:5555")
fake_socket.recv_json.assert_awaited()
# Since it failed, there should not be any command agent.
fake_agent_instance.start.assert_not_awaited()
assert "Failed to set up RICommunicationAgent" in caplog.text
# Ensure the agent did not attach a ListenBehaviour
assert not any(isinstance(b, agent.ListenBehaviour) for b in agent.behaviours)
@pytest.mark.asyncio
async def test_setup_creates_socket_and_negotiate_4(monkeypatch):
"""
Test the setup of the communication agent with different bind value
"""
# --- Arrange ---
fake_socket = MagicMock()
fake_socket.send_json = AsyncMock()
fake_socket.recv_json = fake_json_correct_negototiate_3()
# Mock context.socket to return our fake socket
monkeypatch.setattr(
"control_backend.agents.ri_communication_agent.context.socket", lambda _: fake_socket
)
# Mock RICommandAgent agent startup
with patch(
"control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True
) as MockCommandAgent:
fake_agent_instance = MockCommandAgent.return_value
fake_agent_instance.start = AsyncMock()
# --- Act ---
agent = RICommunicationAgent(
"test@server", "password", address="tcp://localhost:5555", bind=True
)
await agent.setup()
# --- Assert ---
fake_socket.bind.assert_any_call("tcp://localhost:5555")
fake_socket.send_json.assert_any_call({"endpoint": "negotiate/ports", "data": None})
fake_socket.recv_json.assert_awaited()
fake_agent_instance.start.assert_awaited()
MockCommandAgent.assert_called_once_with(
ANY, # Server Name
ANY, # Server Password
address="tcp://*:5557", # derived from the 'port' value in negotiation
bind=True,
)
# Ensure the agent attached a ListenBehaviour
assert any(isinstance(b, agent.ListenBehaviour) for b in agent.behaviours)
@pytest.mark.asyncio
async def test_setup_creates_socket_and_negotiate_5(monkeypatch):
"""
Test the setup of the communication agent
"""
# --- Arrange ---
fake_socket = MagicMock()
fake_socket.send_json = AsyncMock()
fake_socket.recv_json = fake_json_correct_negototiate_4()
# Mock context.socket to return our fake socket
monkeypatch.setattr(
"control_backend.agents.ri_communication_agent.context.socket", lambda _: fake_socket
)
# Mock RICommandAgent agent startup
with patch(
"control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True
) as MockCommandAgent:
fake_agent_instance = MockCommandAgent.return_value
fake_agent_instance.start = AsyncMock()
# --- Act ---
agent = RICommunicationAgent(
"test@server", "password", address="tcp://localhost:5555", bind=False
)
await agent.setup()
# --- Assert ---
fake_socket.connect.assert_any_call("tcp://localhost:5555")
fake_socket.send_json.assert_any_call({"endpoint": "negotiate/ports", "data": None})
fake_socket.recv_json.assert_awaited()
fake_agent_instance.start.assert_awaited()
MockCommandAgent.assert_called_once_with(
ANY, # Server Name
ANY, # Server Password
address="tcp://*:5557", # derived from the 'port' value in negotiation
bind=True,
)
# Ensure the agent attached a ListenBehaviour
assert any(isinstance(b, agent.ListenBehaviour) for b in agent.behaviours)
@pytest.mark.asyncio
async def test_setup_creates_socket_and_negotiate_6(monkeypatch):
"""
Test the setup of the communication agent
"""
# --- Arrange ---
fake_socket = MagicMock()
fake_socket.send_json = AsyncMock()
fake_socket.recv_json = fake_json_correct_negototiate_5()
# Mock context.socket to return our fake socket
monkeypatch.setattr(
"control_backend.agents.ri_communication_agent.context.socket", lambda _: fake_socket
)
# Mock RICommandAgent agent startup
with patch(
"control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True
) as MockCommandAgent:
fake_agent_instance = MockCommandAgent.return_value
fake_agent_instance.start = AsyncMock()
# --- Act ---
agent = RICommunicationAgent(
"test@server", "password", address="tcp://localhost:5555", bind=False
)
await agent.setup()
# --- Assert ---
fake_socket.connect.assert_any_call("tcp://localhost:5555")
fake_socket.send_json.assert_any_call({"endpoint": "negotiate/ports", "data": None})
fake_socket.recv_json.assert_awaited()
fake_agent_instance.start.assert_awaited()
MockCommandAgent.assert_called_once_with(
ANY, # Server Name
ANY, # Server Password
address="tcp://*:5557", # derived from the 'port' value in negotiation
bind=True,
)
# Ensure the agent attached a ListenBehaviour
assert any(isinstance(b, agent.ListenBehaviour) for b in agent.behaviours)
@pytest.mark.asyncio
async def test_setup_creates_socket_and_negotiate_7(monkeypatch, caplog):
"""
Test the functionality of setup with incorrect id
"""
# --- Arrange ---
fake_socket = MagicMock()
fake_socket.send_json = AsyncMock()
fake_socket.recv_json = fake_json_invalid_id_negototiate()
# Mock context.socket to return our fake socket
monkeypatch.setattr(
"control_backend.agents.ri_communication_agent.context.socket", lambda _: fake_socket
)
# Mock RICommandAgent agent startup
# We are sending wrong negotiation info to the communication agent, so we should retry and expect a
# better response, within a limited time.
with patch(
"control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True
) as MockCommandAgent:
fake_agent_instance = MockCommandAgent.return_value
fake_agent_instance.start = AsyncMock()
# --- Act ---
with caplog.at_level("WARNING"):
agent = RICommunicationAgent(
"test@server", "password", address="tcp://localhost:5555", bind=False
)
await agent.setup(max_retries=1)
# --- Assert ---
fake_socket.connect.assert_any_call("tcp://localhost:5555")
fake_socket.recv_json.assert_awaited()
# Since it failed, there should not be any command agent.
fake_agent_instance.start.assert_not_awaited()
assert "Unhandled negotiation id:" in caplog.text
@pytest.mark.asyncio
async def test_setup_creates_socket_and_negotiate_timeout(monkeypatch, caplog):
"""
Test the functionality of setup with incorrect negotiation message
"""
# --- Arrange ---
fake_socket = MagicMock()
fake_socket.send_json = AsyncMock()
fake_socket.recv_json = AsyncMock(side_effect=asyncio.TimeoutError)
# Mock context.socket to return our fake socket
monkeypatch.setattr(
"control_backend.agents.ri_communication_agent.context.socket", lambda _: fake_socket
)
with patch(
"control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True
) as MockCommandAgent:
fake_agent_instance = MockCommandAgent.return_value
fake_agent_instance.start = AsyncMock()
# --- Act ---
with caplog.at_level("WARNING"):
agent = RICommunicationAgent(
"test@server", "password", address="tcp://localhost:5555", bind=False
)
await agent.setup(max_retries=1)
# --- Assert ---
fake_socket.connect.assert_any_call("tcp://localhost:5555")
# Since it failed, there should not be any command agent.
fake_agent_instance.start.assert_not_awaited()
assert "No connection established in 20 seconds" in caplog.text
# Ensure the agent did not attach a ListenBehaviour
assert not any(isinstance(b, agent.ListenBehaviour) for b in agent.behaviours)
@pytest.mark.asyncio
async def test_listen_behaviour_ping_correct(caplog):
fake_socket = AsyncMock()
fake_socket.send_json = AsyncMock()
fake_socket.recv_json = AsyncMock(return_value={"endpoint": "ping", "data": {}})
# TODO: Integration test between actual server and password needed for spade agents
agent = RICommunicationAgent("test@server", "password")
agent.req_socket = fake_socket
behaviour = agent.ListenBehaviour()
agent.add_behaviour(behaviour)
# Run once (CyclicBehaviour normally loops)
with caplog.at_level("DEBUG"):
await behaviour.run()
fake_socket.send_json.assert_awaited()
fake_socket.recv_json.assert_awaited()
assert "Received message" in caplog.text
@pytest.mark.asyncio
async def test_listen_behaviour_ping_wrong_endpoint(caplog):
"""
Test if our listen behaviour can work with wrong messages (wrong endpoint)
"""
fake_socket = AsyncMock()
fake_socket.send_json = AsyncMock()
# This is a message for the wrong endpoint >:(
fake_socket.recv_json = AsyncMock(
return_value={
"endpoint": "negotiate/ports",
"data": [
{"id": "main", "port": 5555, "bind": False},
{"id": "actuation", "port": 5556, "bind": True},
],
}
)
agent = RICommunicationAgent("test@server", "password")
agent.req_socket = fake_socket
behaviour = agent.ListenBehaviour()
agent.add_behaviour(behaviour)
# Run once (CyclicBehaviour normally loops)
with caplog.at_level("INFO"):
await behaviour.run()
assert "Received message with topic different than ping, while ping expected." in caplog.text
fake_socket.send_json.assert_awaited()
fake_socket.recv_json.assert_awaited()
@pytest.mark.asyncio
async def test_listen_behaviour_timeout(caplog):
fake_socket = AsyncMock()
fake_socket.send_json = AsyncMock()
# recv_json will never resolve, simulate timeout
fake_socket.recv_json = AsyncMock(side_effect=asyncio.TimeoutError)
agent = RICommunicationAgent("test@server", "password")
agent.req_socket = fake_socket
behaviour = agent.ListenBehaviour()
agent.add_behaviour(behaviour)
with caplog.at_level("INFO"):
await behaviour.run()
assert "No ping retrieved in 3 seconds" in caplog.text
@pytest.mark.asyncio
async def test_listen_behaviour_ping_no_endpoint(caplog):
"""
Test if our listen behaviour can work with wrong messages (wrong endpoint)
"""
fake_socket = AsyncMock()
fake_socket.send_json = AsyncMock()
# This is a message without endpoint >:(
fake_socket.recv_json = AsyncMock(
return_value={
"data": "I dont have an endpoint >:)",
}
)
agent = RICommunicationAgent("test@server", "password")
agent.req_socket = fake_socket
behaviour = agent.ListenBehaviour()
agent.add_behaviour(behaviour)
# Run once (CyclicBehaviour normally loops)
with caplog.at_level("ERROR"):
await behaviour.run()
assert "No received endpoint in message, excepted ping endpoint." in caplog.text
fake_socket.send_json.assert_awaited()
fake_socket.recv_json.assert_awaited()
@pytest.mark.asyncio
async def test_setup_unexpected_exception(monkeypatch, caplog):
fake_socket = MagicMock()
fake_socket.send_json = AsyncMock()
# Simulate unexpected exception during recv_json()
fake_socket.recv_json = AsyncMock(side_effect=Exception("boom!"))
monkeypatch.setattr(
"control_backend.agents.ri_communication_agent.context.socket", lambda _: fake_socket
)
agent = RICommunicationAgent(
"test@server", "password", address="tcp://localhost:5555", bind=False
)
with caplog.at_level("ERROR"):
await agent.setup(max_retries=1)
# Ensure that the error was logged
assert "Unexpected error during negotiation: boom!" in caplog.text
@pytest.mark.asyncio
async def test_setup_unpacking_exception(monkeypatch, caplog):
# --- Arrange ---
fake_socket = MagicMock()
fake_socket.send_json = AsyncMock()
# Make recv_json return malformed negotiation data to trigger unpacking exception
malformed_data = {
"endpoint": "negotiate/ports",
"data": [{"id": "main"}],
} # missing 'port' and 'bind'
fake_socket.recv_json = AsyncMock(return_value=malformed_data)
# Patch context.socket
monkeypatch.setattr(
"control_backend.agents.ri_communication_agent.context.socket", lambda _: fake_socket
)
# Patch RICommandAgent so it won't actually start
with patch(
"control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True
) as MockCommandAgent:
fake_agent_instance = MockCommandAgent.return_value
fake_agent_instance.start = AsyncMock()
agent = RICommunicationAgent(
"test@server", "password", address="tcp://localhost:5555", bind=False
)
# --- Act & Assert ---
with caplog.at_level("ERROR"):
await agent.setup(max_retries=1)
# Ensure the unpacking exception was logged
assert "Error unpacking negotiation data" in caplog.text
# Ensure no command agent was started
fake_agent_instance.start.assert_not_awaited()
# Ensure no behaviour was attached
assert not any(isinstance(b, agent.ListenBehaviour) for b in agent.behaviours)

View File

@@ -0,0 +1,63 @@
import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
from unittest.mock import MagicMock
from control_backend.api.v1.endpoints import command
from control_backend.schemas.ri_message import SpeechCommand
@pytest.fixture
def app():
"""
Creates a FastAPI test app and attaches the router under test.
Also sets up a mock internal_comm_socket.
"""
app = FastAPI()
app.include_router(command.router)
app.state.internal_comm_socket = MagicMock() # mock ZMQ socket
return app
@pytest.fixture
def client(app):
"""Create a test client for the app."""
return TestClient(app)
def test_receive_command_endpoint(client, app):
"""
Test that a POST to /command sends the right multipart message
and returns a 202 with the expected JSON body.
"""
mock_socket = app.state.internal_comm_socket
# Prepare test payload that matches SpeechCommand
payload = {"endpoint": "actuate/speech", "data": "yooo"}
# Send POST request
response = client.post("/command", json=payload)
# Check response
assert response.status_code == 202
assert response.json() == {"status": "Command received"}
# Verify that the socket was called with the correct data
assert mock_socket.send_multipart.called, "Socket should be used to send data"
args, kwargs = mock_socket.send_multipart.call_args
sent_data = args[0]
assert sent_data[0] == b"command"
# Check JSON encoding roughly matches
assert isinstance(SpeechCommand.model_validate_json(sent_data[1].decode()), SpeechCommand)
def test_receive_command_invalid_payload(client):
"""
Test invalid data handling (schema validation).
"""
# Missing required field(s)
bad_payload = {"invalid": "data"}
response = client.post("/command", json=bad_payload)
assert response.status_code == 422 # validation error

View File

@@ -0,0 +1,36 @@
import pytest
from control_backend.schemas.ri_message import RIMessage, RIEndpoint, SpeechCommand
from pydantic import ValidationError
def valid_command_1():
return SpeechCommand(data="Hallo?")
def invalid_command_1():
return RIMessage(endpoint=RIEndpoint.PING, data="Hello again.")
def test_valid_speech_command_1():
command = valid_command_1()
try:
RIMessage.model_validate(command)
SpeechCommand.model_validate(command)
assert True
except ValidationError:
assert False
def test_invalid_speech_command_1():
command = invalid_command_1()
passed_ri_message_validation = False
try:
# Should succeed, still.
RIMessage.model_validate(command)
passed_ri_message_validation = True
# Should fail.
SpeechCommand.model_validate(command)
assert False
except ValidationError:
assert passed_ri_message_validation

31
uv.lock generated
View File

@@ -1336,6 +1336,10 @@ dependencies = [
{ name = "pyaudio" },
{ name = "pydantic" },
{ name = "pydantic-settings" },
{ name = "pytest" },
{ name = "pytest-asyncio" },
{ name = "pytest-cov" },
{ name = "pytest-mock" },
{ name = "pyzmq" },
{ name = "silero-vad" },
{ name = "spade" },
@@ -1350,6 +1354,9 @@ dev = [
{ name = "ruff" },
{ name = "ruff-format" },
]
integration-test = [
{ name = "soundfile" },
]
test = [
{ name = "pytest" },
{ name = "pytest-asyncio" },
@@ -1365,6 +1372,10 @@ requires-dist = [
{ name = "pyaudio", specifier = ">=0.2.14" },
{ name = "pydantic", specifier = ">=2.12.0" },
{ name = "pydantic-settings", specifier = ">=2.11.0" },
{ name = "pytest", specifier = ">=8.4.2" },
{ name = "pytest-asyncio", specifier = ">=1.2.0" },
{ name = "pytest-cov", specifier = ">=7.0.0" },
{ name = "pytest-mock", specifier = ">=3.15.1" },
{ name = "pyzmq", specifier = ">=27.1.0" },
{ name = "silero-vad", specifier = ">=6.0.0" },
{ name = "spade", specifier = ">=4.1.0" },
@@ -1379,6 +1390,7 @@ dev = [
{ name = "ruff", specifier = ">=0.14.2" },
{ name = "ruff-format", specifier = ">=0.3.0" },
]
integration-test = [{ name = "soundfile", specifier = ">=0.13.1" }]
test = [
{ name = "pytest", specifier = ">=8.4.2" },
{ name = "pytest-asyncio", specifier = ">=1.2.0" },
@@ -2208,6 +2220,25 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/e9/44/75a9c9421471a6c4805dbf2356f7c181a29c1879239abab1ea2cc8f38b40/sniffio-1.3.1-py3-none-any.whl", hash = "sha256:2f6da418d1f1e0fddd844478f41680e794e6051915791a034ff65e5f100525a2", size = 10235, upload-time = "2024-02-25T23:20:01.196Z" },
]
[[package]]
name = "soundfile"
version = "0.13.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "cffi" },
{ name = "numpy" },
]
sdist = { url = "https://files.pythonhosted.org/packages/e1/41/9b873a8c055582859b239be17902a85339bec6a30ad162f98c9b0288a2cc/soundfile-0.13.1.tar.gz", hash = "sha256:b2c68dab1e30297317080a5b43df57e302584c49e2942defdde0acccc53f0e5b", size = 46156, upload-time = "2025-01-25T09:17:04.831Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/64/28/e2a36573ccbcf3d57c00626a21fe51989380636e821b341d36ccca0c1c3a/soundfile-0.13.1-py2.py3-none-any.whl", hash = "sha256:a23c717560da2cf4c7b5ae1142514e0fd82d6bbd9dfc93a50423447142f2c445", size = 25751, upload-time = "2025-01-25T09:16:44.235Z" },
{ url = "https://files.pythonhosted.org/packages/ea/ab/73e97a5b3cc46bba7ff8650a1504348fa1863a6f9d57d7001c6b67c5f20e/soundfile-0.13.1-py2.py3-none-macosx_10_9_x86_64.whl", hash = "sha256:82dc664d19831933fe59adad199bf3945ad06d84bc111a5b4c0d3089a5b9ec33", size = 1142250, upload-time = "2025-01-25T09:16:47.583Z" },
{ url = "https://files.pythonhosted.org/packages/a0/e5/58fd1a8d7b26fc113af244f966ee3aecf03cb9293cb935daaddc1e455e18/soundfile-0.13.1-py2.py3-none-macosx_11_0_arm64.whl", hash = "sha256:743f12c12c4054921e15736c6be09ac26b3b3d603aef6fd69f9dde68748f2593", size = 1101406, upload-time = "2025-01-25T09:16:49.662Z" },
{ url = "https://files.pythonhosted.org/packages/58/ae/c0e4a53d77cf6e9a04179535766b3321b0b9ced5f70522e4caf9329f0046/soundfile-0.13.1-py2.py3-none-manylinux_2_28_aarch64.whl", hash = "sha256:9c9e855f5a4d06ce4213f31918653ab7de0c5a8d8107cd2427e44b42df547deb", size = 1235729, upload-time = "2025-01-25T09:16:53.018Z" },
{ url = "https://files.pythonhosted.org/packages/57/5e/70bdd9579b35003a489fc850b5047beeda26328053ebadc1fb60f320f7db/soundfile-0.13.1-py2.py3-none-manylinux_2_28_x86_64.whl", hash = "sha256:03267c4e493315294834a0870f31dbb3b28a95561b80b134f0bd3cf2d5f0e618", size = 1313646, upload-time = "2025-01-25T09:16:54.872Z" },
{ url = "https://files.pythonhosted.org/packages/fe/df/8c11dc4dfceda14e3003bb81a0d0edcaaf0796dd7b4f826ea3e532146bba/soundfile-0.13.1-py2.py3-none-win32.whl", hash = "sha256:c734564fab7c5ddf8e9be5bf70bab68042cd17e9c214c06e365e20d64f9a69d5", size = 899881, upload-time = "2025-01-25T09:16:56.663Z" },
{ url = "https://files.pythonhosted.org/packages/14/e9/6b761de83277f2f02ded7e7ea6f07828ec78e4b229b80e4ca55dd205b9dc/soundfile-0.13.1-py2.py3-none-win_amd64.whl", hash = "sha256:1e70a05a0626524a69e9f0f4dd2ec174b4e9567f4d8b6c11d38b5c289be36ee9", size = 1019162, upload-time = "2025-01-25T09:16:59.573Z" },
]
[[package]]
name = "spade"
version = "4.1.0"