diff --git a/.gitignore b/.gitignore index f6ad342..b6322a4 100644 --- a/.gitignore +++ b/.gitignore @@ -199,7 +199,7 @@ cython_debug/ # that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore # and can be added to the global gitignore or merged into this file. However, if you prefer, # you could uncomment the following to ignore the entire vscode folder -# .vscode/ +.vscode/ # Ruff stuff: .ruff_cache/ diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index f4e1883..f2d3c52 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -22,5 +22,6 @@ test: tags: - test script: - - uv run --only-group test pytest + # - uv run --group integration-test pytest test/integration + - uv run --only-group test pytest test/unit diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 0000000..b2b8866 --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,7 @@ +{ + "python.testing.pytestArgs": [ + "test" + ], + "python.testing.unittestEnabled": false, + "python.testing.pytestEnabled": true +} \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 641c302..8299d0f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,6 +11,10 @@ dependencies = [ "pyaudio>=0.2.14", "pydantic>=2.12.0", "pydantic-settings>=2.11.0", + "pytest>=8.4.2", + "pytest-asyncio>=1.2.0", + "pytest-cov>=7.0.0", + "pytest-mock>=3.15.1", "pyzmq>=27.1.0", "silero-vad>=6.0.0", "spade>=4.1.0", diff --git a/src/control_backend/agents/ri_command_agent.py b/src/control_backend/agents/ri_command_agent.py new file mode 100644 index 0000000..01fc824 --- /dev/null +++ b/src/control_backend/agents/ri_command_agent.py @@ -0,0 +1,74 @@ +import json +import logging +from spade.agent import Agent +from spade.behaviour import CyclicBehaviour +import zmq + +from control_backend.core.config import settings +from control_backend.core.zmq_context import context +from control_backend.schemas.ri_message import SpeechCommand + +logger = logging.getLogger(__name__) + + +class RICommandAgent(Agent): + subsocket: zmq.Socket + pubsocket: zmq.Socket + address = "" + bind = False + + def __init__( + self, + jid: str, + password: str, + port: int = 5222, + verify_security: bool = False, + address="tcp://localhost:0000", + bind=False, + ): + super().__init__(jid, password, port, verify_security) + self.address = address + self.bind = bind + + class SendCommandsBehaviour(CyclicBehaviour): + async def run(self): + """ + Run the command publishing loop indefinetely. + """ + assert self.agent is not None + # Get a message internally (with topic command) + topic, body = await self.agent.subsocket.recv_multipart() + + # Try to get body + try: + body = json.loads(body) + message = SpeechCommand.model_validate(body) + + # Send to the robot. + await self.agent.pubsocket.send_json(message.model_dump()) + except Exception as e: + logger.error("Error processing message: %s", e) + + async def setup(self): + """ + Setup the command agent + """ + logger.info("Setting up %s", self.jid) + + # To the robot + self.pubsocket = context.socket(zmq.PUB) + if self.bind: + self.pubsocket.bind(self.address) + else: + self.pubsocket.connect(self.address) + + # Receive internal topics regarding commands + self.subsocket = context.socket(zmq.SUB) + self.subsocket.connect(settings.zmq_settings.internal_comm_address) + self.subsocket.setsockopt(zmq.SUBSCRIBE, b"command") + + # Add behaviour to our agent + commands_behaviour = self.SendCommandsBehaviour() + self.add_behaviour(commands_behaviour) + + logger.info("Finished setting up %s", self.jid) diff --git a/src/control_backend/agents/ri_communication_agent.py b/src/control_backend/agents/ri_communication_agent.py new file mode 100644 index 0000000..504c707 --- /dev/null +++ b/src/control_backend/agents/ri_communication_agent.py @@ -0,0 +1,165 @@ +import asyncio +import json +import logging +from spade.agent import Agent +from spade.behaviour import CyclicBehaviour +import zmq + +from control_backend.core.config import settings +from control_backend.core.zmq_context import context +from control_backend.schemas.message import Message +from control_backend.agents.ri_command_agent import RICommandAgent + +logger = logging.getLogger(__name__) + + +class RICommunicationAgent(Agent): + req_socket: zmq.Socket + _address = "" + _bind = True + + def __init__( + self, + jid: str, + password: str, + port: int = 5222, + verify_security: bool = False, + address="tcp://localhost:0000", + bind=False, + ): + super().__init__(jid, password, port, verify_security) + self._address = address + self._bind = bind + + class ListenBehaviour(CyclicBehaviour): + async def run(self): + """ + Run the listening (ping) loop indefinetely. + """ + assert self.agent is not None + + # We need to listen and sent pings. + message = {"endpoint": "ping", "data": {"id": "e.g. some reference id"}} + await self.agent.req_socket.send_json(message) + + # Wait up to three seconds for a reply:) + try: + message = await asyncio.wait_for(self.agent.req_socket.recv_json(), timeout=3.0) + + # We didnt get a reply :( + except asyncio.TimeoutError as e: + logger.info("No ping retrieved in 3 seconds, killing myself.") + self.kill() + + logger.debug('Received message "%s"', message) + if "endpoint" not in message: + logger.error("No received endpoint in message, excepted ping endpoint.") + return + + # See what endpoint we received + match message["endpoint"]: + case "ping": + await asyncio.sleep(1) + case _: + logger.info( + "Received message with topic different than ping, while ping expected." + ) + + async def setup(self, max_retries: int = 5): + """ + Try to setup the communication agent, we have 5 retries in case we dont have a response yet. + """ + logger.info("Setting up %s", self.jid) + retries = 0 + + # Let's try a certain amount of times before failing connection + while retries < max_retries: + # Bind request socket + self.req_socket = context.socket(zmq.REQ) + if self._bind: + self.req_socket.bind(self._address) + else: + self.req_socket.connect(self._address) + + # Send our message and receive one back:) + message = {"endpoint": "negotiate/ports", "data": None} + await self.req_socket.send_json(message) + + try: + received_message = await asyncio.wait_for(self.req_socket.recv_json(), timeout=20.0) + + except asyncio.TimeoutError: + logger.warning( + "No connection established in 20 seconds (attempt %d/%d)", + retries + 1, + max_retries, + ) + retries += 1 + continue + + except Exception as e: + logger.error("Unexpected error during negotiation: %s", e) + retries += 1 + continue + + # Validate endpoint + endpoint = received_message.get("endpoint") + if endpoint != "negotiate/ports": + # TODO: Should this send a message back? + logger.error( + "Invalid endpoint '%s' received (attempt %d/%d)", + endpoint, + retries + 1, + max_retries, + ) + retries += 1 + continue + + # At this point, we have a valid response + try: + for port_data in received_message["data"]: + id = port_data["id"] + port = port_data["port"] + bind = port_data["bind"] + + if not bind: + addr = f"tcp://localhost:{port}" + else: + addr = f"tcp://*:{port}" + + match id: + case "main": + if addr != self._address: + if not bind: + self.req_socket.connect(addr) + else: + self.req_socket.bind(addr) + case "actuation": + ri_commands_agent = RICommandAgent( + settings.agent_settings.ri_command_agent_name + + "@" + + settings.agent_settings.host, + settings.agent_settings.ri_command_agent_name, + address=addr, + bind=bind, + ) + await ri_commands_agent.start() + case _: + logger.warning("Unhandled negotiation id: %s", id) + + except Exception as e: + logger.error("Error unpacking negotiation data: %s", e) + retries += 1 + continue + + # setup succeeded + break + + else: + logger.error("Failed to set up RICommunicationAgent after %d retries", max_retries) + return + + # Set up ping behaviour + listen_behaviour = self.ListenBehaviour() + self.add_behaviour(listen_behaviour) + logger.info("Finished setting up %s", self.jid) diff --git a/src/control_backend/api/v1/endpoints/command.py b/src/control_backend/api/v1/endpoints/command.py new file mode 100644 index 0000000..badaf90 --- /dev/null +++ b/src/control_backend/api/v1/endpoints/command.py @@ -0,0 +1,22 @@ +from fastapi import APIRouter, Request +import logging + +from zmq import Socket + +from control_backend.schemas.ri_message import SpeechCommand, RIEndpoint + +logger = logging.getLogger(__name__) + +router = APIRouter() + + +@router.post("/command", status_code=202) +async def receive_command(command: SpeechCommand, request: Request): + # Validate and retrieve data. + SpeechCommand.model_validate(command) + topic = b"command" + pub_socket: Socket = request.app.state.internal_comm_socket + pub_socket.send_multipart([topic, command.model_dump_json().encode()]) + + + return {"status": "Command received"} diff --git a/src/control_backend/api/v1/router.py b/src/control_backend/api/v1/router.py index f2fb39a..dc7aea9 100644 --- a/src/control_backend/api/v1/router.py +++ b/src/control_backend/api/v1/router.py @@ -1,9 +1,11 @@ from fastapi.routing import APIRouter -from control_backend.api.v1.endpoints import message, sse +from control_backend.api.v1.endpoints import message, sse, command api_router = APIRouter() api_router.include_router(message.router, tags=["Messages"]) api_router.include_router(sse.router, tags=["SSE"]) + +api_router.include_router(command.router, tags=["Commands"]) diff --git a/src/control_backend/core/config.py b/src/control_backend/core/config.py index acc3d76..4758618 100644 --- a/src/control_backend/core/config.py +++ b/src/control_backend/core/config.py @@ -12,6 +12,9 @@ class AgentSettings(BaseModel): belief_collector_agent_name: str = "belief_collector" vad_agent_name: str = "vad_agent" + ri_communication_agent_name: str = "ri_communication_agent" + ri_command_agent_name: str = "ri_command_agent" + class Settings(BaseSettings): app_title: str = "PepperPlus" diff --git a/src/control_backend/main.py b/src/control_backend/main.py index 82038d2..fc5aa6c 100644 --- a/src/control_backend/main.py +++ b/src/control_backend/main.py @@ -9,6 +9,7 @@ from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware # Internal imports +from control_backend.agents.ri_communication_agent import RICommunicationAgent from control_backend.agents.bdi.bdi_core import BDICoreAgent from control_backend.agents.vad_agent import VADAgent from control_backend.api.v1.router import api_router @@ -31,6 +32,14 @@ async def lifespan(app: FastAPI): logger.info("Internal publishing socket bound to %s", internal_comm_socket) # Initiate agents + ri_communication_agent = RICommunicationAgent( + settings.agent_settings.ri_communication_agent_name + "@" + settings.agent_settings.host, + settings.agent_settings.ri_communication_agent_name, + address="tcp://*:5555", + bind=True, + ) + await ri_communication_agent.start() + bdi_core = BDICoreAgent( settings.agent_settings.bdi_core_agent_name + "@" + settings.agent_settings.host, settings.agent_settings.bdi_core_agent_name, diff --git a/src/control_backend/schemas/ri_message.py b/src/control_backend/schemas/ri_message.py new file mode 100644 index 0000000..97b7930 --- /dev/null +++ b/src/control_backend/schemas/ri_message.py @@ -0,0 +1,20 @@ +from enum import Enum +from typing import Any, Literal + +from pydantic import BaseModel, Field, ValidationError + + +class RIEndpoint(str, Enum): + SPEECH = "actuate/speech" + PING = "ping" + NEGOTIATE_PORTS = "negotiate/ports" + + +class RIMessage(BaseModel): + endpoint: RIEndpoint + data: Any + + +class SpeechCommand(RIMessage): + endpoint: RIEndpoint = RIEndpoint(RIEndpoint.SPEECH) + data: str diff --git a/test/integration/agents/test_ri_commands_agent.py b/test/integration/agents/test_ri_commands_agent.py new file mode 100644 index 0000000..219d682 --- /dev/null +++ b/test/integration/agents/test_ri_commands_agent.py @@ -0,0 +1,102 @@ +import asyncio +import zmq +import json +import pytest +from unittest.mock import AsyncMock, MagicMock, patch +from control_backend.agents.ri_command_agent import RICommandAgent +from control_backend.schemas.ri_message import SpeechCommand + + +@pytest.mark.asyncio +async def test_setup_bind(monkeypatch): + """Test setup with bind=True""" + fake_socket = MagicMock() + monkeypatch.setattr( + "control_backend.agents.ri_command_agent.context.socket", lambda _: fake_socket + ) + + agent = RICommandAgent("test@server", "password", address="tcp://localhost:5555", bind=True) + monkeypatch.setattr( + "control_backend.agents.ri_command_agent.settings", + MagicMock(zmq_settings=MagicMock(internal_comm_address="tcp://internal:1234")), + ) + + await agent.setup() + + # Ensure PUB socket bound + fake_socket.bind.assert_any_call("tcp://localhost:5555") + # Ensure SUB socket connected to internal address and subscribed + fake_socket.connect.assert_any_call("tcp://internal:1234") + fake_socket.setsockopt.assert_any_call(zmq.SUBSCRIBE, b"command") + + # Ensure behaviour attached + assert any(isinstance(b, agent.SendCommandsBehaviour) for b in agent.behaviours) + + +@pytest.mark.asyncio +async def test_setup_connect(monkeypatch): + """Test setup with bind=False""" + fake_socket = MagicMock() + monkeypatch.setattr( + "control_backend.agents.ri_command_agent.context.socket", lambda _: fake_socket + ) + + agent = RICommandAgent("test@server", "password", address="tcp://localhost:5555", bind=False) + monkeypatch.setattr( + "control_backend.agents.ri_command_agent.settings", + MagicMock(zmq_settings=MagicMock(internal_comm_address="tcp://internal:1234")), + ) + + await agent.setup() + + # Ensure PUB socket connected + fake_socket.connect.assert_any_call("tcp://localhost:5555") + + +@pytest.mark.asyncio +async def test_send_commands_behaviour_valid_message(): + """Test behaviour with valid JSON message""" + fake_socket = AsyncMock() + message_dict = {"message": "hello"} + fake_socket.recv_multipart = AsyncMock( + return_value=(b"command", json.dumps(message_dict).encode("utf-8")) + ) + fake_socket.send_json = AsyncMock() + + agent = RICommandAgent("test@server", "password") + agent.subsocket = fake_socket + agent.pubsocket = fake_socket + + behaviour = agent.SendCommandsBehaviour() + behaviour.agent = agent + + with patch("control_backend.agents.ri_command_agent.SpeechCommand") as MockSpeechCommand: + mock_message = MagicMock() + MockSpeechCommand.model_validate.return_value = mock_message + + await behaviour.run() + + fake_socket.recv_multipart.assert_awaited() + fake_socket.send_json.assert_awaited_with(mock_message.model_dump()) + + +@pytest.mark.asyncio +async def test_send_commands_behaviour_invalid_message(caplog): + """Test behaviour with invalid JSON message triggers error logging""" + fake_socket = AsyncMock() + fake_socket.recv_multipart = AsyncMock(return_value=(b"command", b"{invalid_json}")) + fake_socket.send_json = AsyncMock() + + agent = RICommandAgent("test@server", "password") + agent.subsocket = fake_socket + agent.pubsocket = fake_socket + + behaviour = agent.SendCommandsBehaviour() + behaviour.agent = agent + + with caplog.at_level("ERROR"): + await behaviour.run() + + fake_socket.recv_multipart.assert_awaited() + fake_socket.send_json.assert_not_awaited() + assert "Error processing message" in caplog.text diff --git a/test/integration/agents/test_ri_communication_agent.py b/test/integration/agents/test_ri_communication_agent.py new file mode 100644 index 0000000..3e4a056 --- /dev/null +++ b/test/integration/agents/test_ri_communication_agent.py @@ -0,0 +1,591 @@ +import asyncio +import pytest +from unittest.mock import AsyncMock, MagicMock, patch, ANY +from control_backend.agents.ri_communication_agent import RICommunicationAgent + + +def fake_json_correct_negototiate_1(): + return AsyncMock( + return_value={ + "endpoint": "negotiate/ports", + "data": [ + {"id": "main", "port": 5555, "bind": False}, + {"id": "actuation", "port": 5556, "bind": True}, + ], + } + ) + + +def fake_json_correct_negototiate_2(): + return AsyncMock( + return_value={ + "endpoint": "negotiate/ports", + "data": [ + {"id": "main", "port": 5555, "bind": False}, + {"id": "actuation", "port": 5557, "bind": True}, + ], + } + ) + + +def fake_json_correct_negototiate_3(): + return AsyncMock( + return_value={ + "endpoint": "negotiate/ports", + "data": [ + {"id": "main", "port": 5555, "bind": True}, + {"id": "actuation", "port": 5557, "bind": True}, + ], + } + ) + + +def fake_json_correct_negototiate_4(): + # Different port, do bind + return AsyncMock( + return_value={ + "endpoint": "negotiate/ports", + "data": [ + {"id": "main", "port": 4555, "bind": True}, + {"id": "actuation", "port": 5557, "bind": True}, + ], + } + ) + + +def fake_json_correct_negototiate_5(): + # Different port, dont bind. + return AsyncMock( + return_value={ + "endpoint": "negotiate/ports", + "data": [ + {"id": "main", "port": 4555, "bind": False}, + {"id": "actuation", "port": 5557, "bind": True}, + ], + } + ) + + +def fake_json_wrong_negototiate_1(): + return AsyncMock(return_value={"endpoint": "ping", "data": ""}) + + +def fake_json_invalid_id_negototiate(): + return AsyncMock( + return_value={ + "endpoint": "negotiate/ports", + "data": [ + {"id": "banana", "port": 4555, "bind": False}, + {"id": "tomato", "port": 5557, "bind": True}, + ], + } + ) + + +@pytest.mark.asyncio +async def test_setup_creates_socket_and_negotiate_1(monkeypatch): + """ + Test the setup of the communication agent + """ + # --- Arrange --- + fake_socket = MagicMock() + fake_socket.send_json = AsyncMock() + fake_socket.recv_json = fake_json_correct_negototiate_1() + + # Mock context.socket to return our fake socket + monkeypatch.setattr( + "control_backend.agents.ri_communication_agent.context.socket", lambda _: fake_socket + ) + + # Mock RICommandAgent agent startup + with patch( + "control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True + ) as MockCommandAgent: + fake_agent_instance = MockCommandAgent.return_value + fake_agent_instance.start = AsyncMock() + + # --- Act --- + agent = RICommunicationAgent( + "test@server", "password", address="tcp://localhost:5555", bind=False + ) + await agent.setup() + + # --- Assert --- + fake_socket.connect.assert_any_call("tcp://localhost:5555") + fake_socket.send_json.assert_any_call({"endpoint": "negotiate/ports", "data": None}) + fake_socket.recv_json.assert_awaited() + fake_agent_instance.start.assert_awaited() + MockCommandAgent.assert_called_once_with( + ANY, # Server Name + ANY, # Server Password + address="tcp://*:5556", # derived from the 'port' value in negotiation + bind=True, + ) + # Ensure the agent attached a ListenBehaviour + assert any(isinstance(b, agent.ListenBehaviour) for b in agent.behaviours) + + +@pytest.mark.asyncio +async def test_setup_creates_socket_and_negotiate_2(monkeypatch): + """ + Test the setup of the communication agent + """ + # --- Arrange --- + fake_socket = MagicMock() + fake_socket.send_json = AsyncMock() + fake_socket.recv_json = fake_json_correct_negototiate_2() + + # Mock context.socket to return our fake socket + monkeypatch.setattr( + "control_backend.agents.ri_communication_agent.context.socket", lambda _: fake_socket + ) + + # Mock RICommandAgent agent startup + with patch( + "control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True + ) as MockCommandAgent: + fake_agent_instance = MockCommandAgent.return_value + fake_agent_instance.start = AsyncMock() + + # --- Act --- + agent = RICommunicationAgent( + "test@server", "password", address="tcp://localhost:5555", bind=False + ) + await agent.setup() + + # --- Assert --- + fake_socket.connect.assert_any_call("tcp://localhost:5555") + fake_socket.send_json.assert_any_call({"endpoint": "negotiate/ports", "data": None}) + fake_socket.recv_json.assert_awaited() + fake_agent_instance.start.assert_awaited() + MockCommandAgent.assert_called_once_with( + ANY, # Server Name + ANY, # Server Password + address="tcp://*:5557", # derived from the 'port' value in negotiation + bind=True, + ) + # Ensure the agent attached a ListenBehaviour + assert any(isinstance(b, agent.ListenBehaviour) for b in agent.behaviours) + + +@pytest.mark.asyncio +async def test_setup_creates_socket_and_negotiate_3(monkeypatch, caplog): + """ + Test the functionality of setup with incorrect negotiation message + """ + # --- Arrange --- + fake_socket = MagicMock() + fake_socket.send_json = AsyncMock() + fake_socket.recv_json = fake_json_wrong_negototiate_1() + + # Mock context.socket to return our fake socket + monkeypatch.setattr( + "control_backend.agents.ri_communication_agent.context.socket", lambda _: fake_socket + ) + + # Mock RICommandAgent agent startup + + # We are sending wrong negotiation info to the communication agent, so we should retry and expect a + # better response, within a limited time. + with patch( + "control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True + ) as MockCommandAgent: + fake_agent_instance = MockCommandAgent.return_value + fake_agent_instance.start = AsyncMock() + + # --- Act --- + with caplog.at_level("ERROR"): + agent = RICommunicationAgent( + "test@server", "password", address="tcp://localhost:5555", bind=False + ) + await agent.setup(max_retries=1) + + # --- Assert --- + fake_socket.connect.assert_any_call("tcp://localhost:5555") + fake_socket.recv_json.assert_awaited() + + # Since it failed, there should not be any command agent. + fake_agent_instance.start.assert_not_awaited() + assert "Failed to set up RICommunicationAgent" in caplog.text + + # Ensure the agent did not attach a ListenBehaviour + assert not any(isinstance(b, agent.ListenBehaviour) for b in agent.behaviours) + + +@pytest.mark.asyncio +async def test_setup_creates_socket_and_negotiate_4(monkeypatch): + """ + Test the setup of the communication agent with different bind value + """ + # --- Arrange --- + fake_socket = MagicMock() + fake_socket.send_json = AsyncMock() + fake_socket.recv_json = fake_json_correct_negototiate_3() + + # Mock context.socket to return our fake socket + monkeypatch.setattr( + "control_backend.agents.ri_communication_agent.context.socket", lambda _: fake_socket + ) + + # Mock RICommandAgent agent startup + with patch( + "control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True + ) as MockCommandAgent: + fake_agent_instance = MockCommandAgent.return_value + fake_agent_instance.start = AsyncMock() + + # --- Act --- + agent = RICommunicationAgent( + "test@server", "password", address="tcp://localhost:5555", bind=True + ) + await agent.setup() + + # --- Assert --- + fake_socket.bind.assert_any_call("tcp://localhost:5555") + fake_socket.send_json.assert_any_call({"endpoint": "negotiate/ports", "data": None}) + fake_socket.recv_json.assert_awaited() + fake_agent_instance.start.assert_awaited() + MockCommandAgent.assert_called_once_with( + ANY, # Server Name + ANY, # Server Password + address="tcp://*:5557", # derived from the 'port' value in negotiation + bind=True, + ) + # Ensure the agent attached a ListenBehaviour + assert any(isinstance(b, agent.ListenBehaviour) for b in agent.behaviours) + + +@pytest.mark.asyncio +async def test_setup_creates_socket_and_negotiate_5(monkeypatch): + """ + Test the setup of the communication agent + """ + # --- Arrange --- + fake_socket = MagicMock() + fake_socket.send_json = AsyncMock() + fake_socket.recv_json = fake_json_correct_negototiate_4() + + # Mock context.socket to return our fake socket + monkeypatch.setattr( + "control_backend.agents.ri_communication_agent.context.socket", lambda _: fake_socket + ) + + # Mock RICommandAgent agent startup + with patch( + "control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True + ) as MockCommandAgent: + fake_agent_instance = MockCommandAgent.return_value + fake_agent_instance.start = AsyncMock() + + # --- Act --- + agent = RICommunicationAgent( + "test@server", "password", address="tcp://localhost:5555", bind=False + ) + await agent.setup() + + # --- Assert --- + fake_socket.connect.assert_any_call("tcp://localhost:5555") + fake_socket.send_json.assert_any_call({"endpoint": "negotiate/ports", "data": None}) + fake_socket.recv_json.assert_awaited() + fake_agent_instance.start.assert_awaited() + MockCommandAgent.assert_called_once_with( + ANY, # Server Name + ANY, # Server Password + address="tcp://*:5557", # derived from the 'port' value in negotiation + bind=True, + ) + # Ensure the agent attached a ListenBehaviour + assert any(isinstance(b, agent.ListenBehaviour) for b in agent.behaviours) + + +@pytest.mark.asyncio +async def test_setup_creates_socket_and_negotiate_6(monkeypatch): + """ + Test the setup of the communication agent + """ + # --- Arrange --- + fake_socket = MagicMock() + fake_socket.send_json = AsyncMock() + fake_socket.recv_json = fake_json_correct_negototiate_5() + + # Mock context.socket to return our fake socket + monkeypatch.setattr( + "control_backend.agents.ri_communication_agent.context.socket", lambda _: fake_socket + ) + + # Mock RICommandAgent agent startup + with patch( + "control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True + ) as MockCommandAgent: + fake_agent_instance = MockCommandAgent.return_value + fake_agent_instance.start = AsyncMock() + + # --- Act --- + agent = RICommunicationAgent( + "test@server", "password", address="tcp://localhost:5555", bind=False + ) + await agent.setup() + + # --- Assert --- + fake_socket.connect.assert_any_call("tcp://localhost:5555") + fake_socket.send_json.assert_any_call({"endpoint": "negotiate/ports", "data": None}) + fake_socket.recv_json.assert_awaited() + fake_agent_instance.start.assert_awaited() + MockCommandAgent.assert_called_once_with( + ANY, # Server Name + ANY, # Server Password + address="tcp://*:5557", # derived from the 'port' value in negotiation + bind=True, + ) + # Ensure the agent attached a ListenBehaviour + assert any(isinstance(b, agent.ListenBehaviour) for b in agent.behaviours) + + +@pytest.mark.asyncio +async def test_setup_creates_socket_and_negotiate_7(monkeypatch, caplog): + """ + Test the functionality of setup with incorrect id + """ + # --- Arrange --- + fake_socket = MagicMock() + fake_socket.send_json = AsyncMock() + fake_socket.recv_json = fake_json_invalid_id_negototiate() + + # Mock context.socket to return our fake socket + monkeypatch.setattr( + "control_backend.agents.ri_communication_agent.context.socket", lambda _: fake_socket + ) + + # Mock RICommandAgent agent startup + + # We are sending wrong negotiation info to the communication agent, so we should retry and expect a + # better response, within a limited time. + with patch( + "control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True + ) as MockCommandAgent: + fake_agent_instance = MockCommandAgent.return_value + fake_agent_instance.start = AsyncMock() + + # --- Act --- + with caplog.at_level("WARNING"): + agent = RICommunicationAgent( + "test@server", "password", address="tcp://localhost:5555", bind=False + ) + await agent.setup(max_retries=1) + + # --- Assert --- + fake_socket.connect.assert_any_call("tcp://localhost:5555") + fake_socket.recv_json.assert_awaited() + + # Since it failed, there should not be any command agent. + fake_agent_instance.start.assert_not_awaited() + assert "Unhandled negotiation id:" in caplog.text + + +@pytest.mark.asyncio +async def test_setup_creates_socket_and_negotiate_timeout(monkeypatch, caplog): + """ + Test the functionality of setup with incorrect negotiation message + """ + # --- Arrange --- + fake_socket = MagicMock() + fake_socket.send_json = AsyncMock() + fake_socket.recv_json = AsyncMock(side_effect=asyncio.TimeoutError) + + # Mock context.socket to return our fake socket + monkeypatch.setattr( + "control_backend.agents.ri_communication_agent.context.socket", lambda _: fake_socket + ) + + with patch( + "control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True + ) as MockCommandAgent: + fake_agent_instance = MockCommandAgent.return_value + fake_agent_instance.start = AsyncMock() + + # --- Act --- + with caplog.at_level("WARNING"): + agent = RICommunicationAgent( + "test@server", "password", address="tcp://localhost:5555", bind=False + ) + await agent.setup(max_retries=1) + + # --- Assert --- + fake_socket.connect.assert_any_call("tcp://localhost:5555") + + # Since it failed, there should not be any command agent. + fake_agent_instance.start.assert_not_awaited() + assert "No connection established in 20 seconds" in caplog.text + + # Ensure the agent did not attach a ListenBehaviour + assert not any(isinstance(b, agent.ListenBehaviour) for b in agent.behaviours) + + +@pytest.mark.asyncio +async def test_listen_behaviour_ping_correct(caplog): + fake_socket = AsyncMock() + fake_socket.send_json = AsyncMock() + fake_socket.recv_json = AsyncMock(return_value={"endpoint": "ping", "data": {}}) + + # TODO: Integration test between actual server and password needed for spade agents + agent = RICommunicationAgent("test@server", "password") + agent.req_socket = fake_socket + + behaviour = agent.ListenBehaviour() + agent.add_behaviour(behaviour) + + # Run once (CyclicBehaviour normally loops) + with caplog.at_level("DEBUG"): + await behaviour.run() + + fake_socket.send_json.assert_awaited() + fake_socket.recv_json.assert_awaited() + assert "Received message" in caplog.text + + +@pytest.mark.asyncio +async def test_listen_behaviour_ping_wrong_endpoint(caplog): + """ + Test if our listen behaviour can work with wrong messages (wrong endpoint) + """ + fake_socket = AsyncMock() + fake_socket.send_json = AsyncMock() + + # This is a message for the wrong endpoint >:( + fake_socket.recv_json = AsyncMock( + return_value={ + "endpoint": "negotiate/ports", + "data": [ + {"id": "main", "port": 5555, "bind": False}, + {"id": "actuation", "port": 5556, "bind": True}, + ], + } + ) + + agent = RICommunicationAgent("test@server", "password") + agent.req_socket = fake_socket + + behaviour = agent.ListenBehaviour() + agent.add_behaviour(behaviour) + + # Run once (CyclicBehaviour normally loops) + with caplog.at_level("INFO"): + await behaviour.run() + + assert "Received message with topic different than ping, while ping expected." in caplog.text + fake_socket.send_json.assert_awaited() + fake_socket.recv_json.assert_awaited() + + +@pytest.mark.asyncio +async def test_listen_behaviour_timeout(caplog): + fake_socket = AsyncMock() + fake_socket.send_json = AsyncMock() + # recv_json will never resolve, simulate timeout + fake_socket.recv_json = AsyncMock(side_effect=asyncio.TimeoutError) + + agent = RICommunicationAgent("test@server", "password") + agent.req_socket = fake_socket + + behaviour = agent.ListenBehaviour() + agent.add_behaviour(behaviour) + + with caplog.at_level("INFO"): + await behaviour.run() + + assert "No ping retrieved in 3 seconds" in caplog.text + + +@pytest.mark.asyncio +async def test_listen_behaviour_ping_no_endpoint(caplog): + """ + Test if our listen behaviour can work with wrong messages (wrong endpoint) + """ + fake_socket = AsyncMock() + fake_socket.send_json = AsyncMock() + + # This is a message without endpoint >:( + fake_socket.recv_json = AsyncMock( + return_value={ + "data": "I dont have an endpoint >:)", + } + ) + + agent = RICommunicationAgent("test@server", "password") + agent.req_socket = fake_socket + + behaviour = agent.ListenBehaviour() + agent.add_behaviour(behaviour) + + # Run once (CyclicBehaviour normally loops) + with caplog.at_level("ERROR"): + await behaviour.run() + + assert "No received endpoint in message, excepted ping endpoint." in caplog.text + fake_socket.send_json.assert_awaited() + fake_socket.recv_json.assert_awaited() + + +@pytest.mark.asyncio +async def test_setup_unexpected_exception(monkeypatch, caplog): + fake_socket = MagicMock() + fake_socket.send_json = AsyncMock() + # Simulate unexpected exception during recv_json() + fake_socket.recv_json = AsyncMock(side_effect=Exception("boom!")) + + monkeypatch.setattr( + "control_backend.agents.ri_communication_agent.context.socket", lambda _: fake_socket + ) + + agent = RICommunicationAgent( + "test@server", "password", address="tcp://localhost:5555", bind=False + ) + + with caplog.at_level("ERROR"): + await agent.setup(max_retries=1) + + # Ensure that the error was logged + assert "Unexpected error during negotiation: boom!" in caplog.text + + +@pytest.mark.asyncio +async def test_setup_unpacking_exception(monkeypatch, caplog): + # --- Arrange --- + fake_socket = MagicMock() + fake_socket.send_json = AsyncMock() + + # Make recv_json return malformed negotiation data to trigger unpacking exception + malformed_data = { + "endpoint": "negotiate/ports", + "data": [{"id": "main"}], + } # missing 'port' and 'bind' + fake_socket.recv_json = AsyncMock(return_value=malformed_data) + + # Patch context.socket + monkeypatch.setattr( + "control_backend.agents.ri_communication_agent.context.socket", lambda _: fake_socket + ) + + # Patch RICommandAgent so it won't actually start + with patch( + "control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True + ) as MockCommandAgent: + fake_agent_instance = MockCommandAgent.return_value + fake_agent_instance.start = AsyncMock() + + agent = RICommunicationAgent( + "test@server", "password", address="tcp://localhost:5555", bind=False + ) + + # --- Act & Assert --- + with caplog.at_level("ERROR"): + await agent.setup(max_retries=1) + + # Ensure the unpacking exception was logged + assert "Error unpacking negotiation data" in caplog.text + + # Ensure no command agent was started + fake_agent_instance.start.assert_not_awaited() + + # Ensure no behaviour was attached + assert not any(isinstance(b, agent.ListenBehaviour) for b in agent.behaviours) diff --git a/test/integration/api/endpoints/test_command_endpoint.py b/test/integration/api/endpoints/test_command_endpoint.py new file mode 100644 index 0000000..07bd866 --- /dev/null +++ b/test/integration/api/endpoints/test_command_endpoint.py @@ -0,0 +1,63 @@ +import pytest +from fastapi import FastAPI +from fastapi.testclient import TestClient +from unittest.mock import MagicMock + +from control_backend.api.v1.endpoints import command +from control_backend.schemas.ri_message import SpeechCommand + + +@pytest.fixture +def app(): + """ + Creates a FastAPI test app and attaches the router under test. + Also sets up a mock internal_comm_socket. + """ + app = FastAPI() + app.include_router(command.router) + app.state.internal_comm_socket = MagicMock() # mock ZMQ socket + return app + + +@pytest.fixture +def client(app): + """Create a test client for the app.""" + return TestClient(app) + + +def test_receive_command_endpoint(client, app): + """ + Test that a POST to /command sends the right multipart message + and returns a 202 with the expected JSON body. + """ + mock_socket = app.state.internal_comm_socket + + # Prepare test payload that matches SpeechCommand + payload = {"endpoint": "actuate/speech", "data": "yooo"} + + # Send POST request + response = client.post("/command", json=payload) + + # Check response + assert response.status_code == 202 + assert response.json() == {"status": "Command received"} + + # Verify that the socket was called with the correct data + assert mock_socket.send_multipart.called, "Socket should be used to send data" + + args, kwargs = mock_socket.send_multipart.call_args + sent_data = args[0] + + assert sent_data[0] == b"command" + # Check JSON encoding roughly matches + assert isinstance(SpeechCommand.model_validate_json(sent_data[1].decode()), SpeechCommand) + + +def test_receive_command_invalid_payload(client): + """ + Test invalid data handling (schema validation). + """ + # Missing required field(s) + bad_payload = {"invalid": "data"} + response = client.post("/command", json=bad_payload) + assert response.status_code == 422 # validation error diff --git a/test/integration/schemas/test_ri_message.py b/test/integration/schemas/test_ri_message.py new file mode 100644 index 0000000..aef9ae6 --- /dev/null +++ b/test/integration/schemas/test_ri_message.py @@ -0,0 +1,36 @@ +import pytest +from control_backend.schemas.ri_message import RIMessage, RIEndpoint, SpeechCommand +from pydantic import ValidationError + + +def valid_command_1(): + return SpeechCommand(data="Hallo?") + + +def invalid_command_1(): + return RIMessage(endpoint=RIEndpoint.PING, data="Hello again.") + + +def test_valid_speech_command_1(): + command = valid_command_1() + try: + RIMessage.model_validate(command) + SpeechCommand.model_validate(command) + assert True + except ValidationError: + assert False + + +def test_invalid_speech_command_1(): + command = invalid_command_1() + passed_ri_message_validation = False + try: + # Should succeed, still. + RIMessage.model_validate(command) + passed_ri_message_validation = True + + # Should fail. + SpeechCommand.model_validate(command) + assert False + except ValidationError: + assert passed_ri_message_validation diff --git a/uv.lock b/uv.lock index bbbb640..07ec3c1 100644 --- a/uv.lock +++ b/uv.lock @@ -1336,6 +1336,10 @@ dependencies = [ { name = "pyaudio" }, { name = "pydantic" }, { name = "pydantic-settings" }, + { name = "pytest" }, + { name = "pytest-asyncio" }, + { name = "pytest-cov" }, + { name = "pytest-mock" }, { name = "pyzmq" }, { name = "silero-vad" }, { name = "spade" }, @@ -1368,6 +1372,10 @@ requires-dist = [ { name = "pyaudio", specifier = ">=0.2.14" }, { name = "pydantic", specifier = ">=2.12.0" }, { name = "pydantic-settings", specifier = ">=2.11.0" }, + { name = "pytest", specifier = ">=8.4.2" }, + { name = "pytest-asyncio", specifier = ">=1.2.0" }, + { name = "pytest-cov", specifier = ">=7.0.0" }, + { name = "pytest-mock", specifier = ">=3.15.1" }, { name = "pyzmq", specifier = ">=27.1.0" }, { name = "silero-vad", specifier = ">=6.0.0" }, { name = "spade", specifier = ">=4.1.0" },