diff --git a/src/control_backend/agents/ri_command_agent.py b/src/control_backend/agents/ri_command_agent.py index 9e3ee5b..01fc824 100644 --- a/src/control_backend/agents/ri_command_agent.py +++ b/src/control_backend/agents/ri_command_agent.py @@ -10,13 +10,22 @@ from control_backend.schemas.ri_message import SpeechCommand logger = logging.getLogger(__name__) + class RICommandAgent(Agent): subsocket: zmq.Socket pubsocket: zmq.Socket address = "" bind = False - def __init__(self, jid: str, password: str, port: int = 5222, verify_security: bool = False, address = "tcp://localhost:0000", bind = False): + def __init__( + self, + jid: str, + password: str, + port: int = 5222, + verify_security: bool = False, + address="tcp://localhost:0000", + bind=False, + ): super().__init__(jid, password, port, verify_security) self.address = address self.bind = bind @@ -29,12 +38,12 @@ class RICommandAgent(Agent): assert self.agent is not None # Get a message internally (with topic command) topic, body = await self.agent.subsocket.recv_multipart() - + # Try to get body try: body = json.loads(body) message = SpeechCommand.model_validate(body) - + # Send to the robot. await self.agent.pubsocket.send_json(message.model_dump()) except Exception as e: @@ -48,11 +57,11 @@ class RICommandAgent(Agent): # To the robot self.pubsocket = context.socket(zmq.PUB) - if self.bind: + if self.bind: self.pubsocket.bind(self.address) - else : + else: self.pubsocket.connect(self.address) - + # Receive internal topics regarding commands self.subsocket = context.socket(zmq.SUB) self.subsocket.connect(settings.zmq_settings.internal_comm_address) diff --git a/src/control_backend/agents/ri_communication_agent.py b/src/control_backend/agents/ri_communication_agent.py index 8889d7c..504c707 100644 --- a/src/control_backend/agents/ri_communication_agent.py +++ b/src/control_backend/agents/ri_communication_agent.py @@ -12,14 +12,21 @@ from control_backend.agents.ri_command_agent import RICommandAgent logger = logging.getLogger(__name__) + class RICommunicationAgent(Agent): req_socket: zmq.Socket _address = "" _bind = True - def __init__(self, jid: str, password: str, port: int = 5222, - verify_security: bool = False, address = "tcp://localhost:0000", - bind = False): + def __init__( + self, + jid: str, + password: str, + port: int = 5222, + verify_security: bool = False, + address="tcp://localhost:0000", + bind=False, + ): super().__init__(jid, password, port, verify_security) self._address = address self._bind = bind @@ -37,28 +44,26 @@ class RICommunicationAgent(Agent): # Wait up to three seconds for a reply:) try: - message = await asyncio.wait_for( - self.agent.req_socket.recv_json(), - timeout=3.0) + message = await asyncio.wait_for(self.agent.req_socket.recv_json(), timeout=3.0) # We didnt get a reply :( - except asyncio.TimeoutError as e: + except asyncio.TimeoutError as e: logger.info("No ping retrieved in 3 seconds, killing myself.") self.kill() - logger.debug("Received message \"%s\"", message) + logger.debug('Received message "%s"', message) if "endpoint" not in message: logger.error("No received endpoint in message, excepted ping endpoint.") return - + # See what endpoint we received match message["endpoint"]: - case "ping": + case "ping": await asyncio.sleep(1) case _: - logger.info("Received message with topic different than ping," \ - " while ping expected.") - + logger.info( + "Received message with topic different than ping, while ping expected." + ) async def setup(self, max_retries: int = 5): """ @@ -67,14 +72,13 @@ class RICommunicationAgent(Agent): logger.info("Setting up %s", self.jid) retries = 0 - # Let's try a certain amount of times before failing connection while retries < max_retries: # Bind request socket self.req_socket = context.socket(zmq.REQ) - if self._bind: + if self._bind: self.req_socket.bind(self._address) - else: + else: self.req_socket.connect(self._address) # Send our message and receive one back:) @@ -85,10 +89,13 @@ class RICommunicationAgent(Agent): received_message = await asyncio.wait_for(self.req_socket.recv_json(), timeout=20.0) except asyncio.TimeoutError: - logger.warning("No connection established in 20 seconds (attempt %d/%d)", - retries + 1, max_retries) + logger.warning( + "No connection established in 20 seconds (attempt %d/%d)", + retries + 1, + max_retries, + ) retries += 1 - continue + continue except Exception as e: logger.error("Unexpected error during negotiation: %s", e) @@ -99,10 +106,14 @@ class RICommunicationAgent(Agent): endpoint = received_message.get("endpoint") if endpoint != "negotiate/ports": # TODO: Should this send a message back? - logger.error("Invalid endpoint '%s' received (attempt %d/%d)", - endpoint, retries + 1, max_retries) + logger.error( + "Invalid endpoint '%s' received (attempt %d/%d)", + endpoint, + retries + 1, + max_retries, + ) retries += 1 - continue + continue # At this point, we have a valid response try: @@ -113,7 +124,7 @@ class RICommunicationAgent(Agent): if not bind: addr = f"tcp://localhost:{port}" - else: + else: addr = f"tcp://*:{port}" match id: @@ -125,11 +136,13 @@ class RICommunicationAgent(Agent): self.req_socket.bind(addr) case "actuation": ri_commands_agent = RICommandAgent( - settings.agent_settings.ri_command_agent_name + - '@' + settings.agent_settings.host, - settings.agent_settings.ri_command_agent_name, - address=addr, - bind=bind ) + settings.agent_settings.ri_command_agent_name + + "@" + + settings.agent_settings.host, + settings.agent_settings.ri_command_agent_name, + address=addr, + bind=bind, + ) await ri_commands_agent.start() case _: logger.warning("Unhandled negotiation id: %s", id) @@ -150,5 +163,3 @@ class RICommunicationAgent(Agent): listen_behaviour = self.ListenBehaviour() self.add_behaviour(listen_behaviour) logger.info("Finished setting up %s", self.jid) - - diff --git a/src/control_backend/api/v1/endpoints/command.py b/src/control_backend/api/v1/endpoints/command.py index 14e6fae..badaf90 100644 --- a/src/control_backend/api/v1/endpoints/command.py +++ b/src/control_backend/api/v1/endpoints/command.py @@ -9,6 +9,7 @@ logger = logging.getLogger(__name__) router = APIRouter() + @router.post("/command", status_code=202) async def receive_command(command: SpeechCommand, request: Request): # Validate and retrieve data. @@ -16,5 +17,6 @@ async def receive_command(command: SpeechCommand, request: Request): topic = b"command" pub_socket: Socket = request.app.state.internal_comm_socket pub_socket.send_multipart([topic, command.model_dump_json().encode()]) + return {"status": "Command received"} diff --git a/src/control_backend/api/v1/router.py b/src/control_backend/api/v1/router.py index 396921b..dc7aea9 100644 --- a/src/control_backend/api/v1/router.py +++ b/src/control_backend/api/v1/router.py @@ -6,12 +6,6 @@ api_router = APIRouter() api_router.include_router(message.router, tags=["Messages"]) -api_router.include_router( - sse.router, - tags=["SSE"] -) +api_router.include_router(sse.router, tags=["SSE"]) -api_router.include_router( - command.router, - tags=["Commands"] -) +api_router.include_router(command.router, tags=["Commands"]) diff --git a/src/control_backend/core/config.py b/src/control_backend/core/config.py index 069b7e9..f48d54f 100644 --- a/src/control_backend/core/config.py +++ b/src/control_backend/core/config.py @@ -15,6 +15,7 @@ class AgentSettings(BaseModel): ri_communication_agent_name: str = "ri_communication_agent" ri_command_agent_name: str = "ri_command_agent" + class Settings(BaseSettings): app_title: str = "PepperPlus" diff --git a/src/control_backend/main.py b/src/control_backend/main.py index 7878d5e..e398552 100644 --- a/src/control_backend/main.py +++ b/src/control_backend/main.py @@ -31,16 +31,19 @@ async def lifespan(app: FastAPI): logger.info("Internal publishing socket bound to %s", internal_comm_socket) # Initiate agents - ri_communication_agent = RICommunicationAgent(settings.agent_settings.ri_communication_agent_name + - '@' + settings.agent_settings.host, - settings.agent_settings.ri_communication_agent_name, - address="tcp://*:5555", bind=True) + ri_communication_agent = RICommunicationAgent( + settings.agent_settings.ri_communication_agent_name + "@" + settings.agent_settings.host, + settings.agent_settings.ri_communication_agent_name, + address="tcp://*:5555", + bind=True, + ) await ri_communication_agent.start() - bdi_core = BDICoreAgent(settings.agent_settings.bdi_core_agent_name + - '@' + settings.agent_settings.host, - settings.agent_settings.bdi_core_agent_name, - "src/control_backend/agents/bdi/rules.asl") + bdi_core = BDICoreAgent( + settings.agent_settings.bdi_core_agent_name + "@" + settings.agent_settings.host, + settings.agent_settings.bdi_core_agent_name, + "src/control_backend/agents/bdi/rules.asl", + ) await bdi_core.start() yield diff --git a/src/control_backend/schemas/ri_message.py b/src/control_backend/schemas/ri_message.py index b369703..97b7930 100644 --- a/src/control_backend/schemas/ri_message.py +++ b/src/control_backend/schemas/ri_message.py @@ -12,7 +12,7 @@ class RIEndpoint(str, Enum): class RIMessage(BaseModel): endpoint: RIEndpoint - data: Any + data: Any class SpeechCommand(RIMessage): diff --git a/test/integration/agents/test_ri_commands_agent.py b/test/integration/agents/test_ri_commands_agent.py index a21af3c..219d682 100644 --- a/test/integration/agents/test_ri_commands_agent.py +++ b/test/integration/agents/test_ri_commands_agent.py @@ -6,16 +6,20 @@ from unittest.mock import AsyncMock, MagicMock, patch from control_backend.agents.ri_command_agent import RICommandAgent from control_backend.schemas.ri_message import SpeechCommand + @pytest.mark.asyncio async def test_setup_bind(monkeypatch): """Test setup with bind=True""" fake_socket = MagicMock() - monkeypatch.setattr("control_backend.agents.ri_command_agent.context.socket", - lambda _: fake_socket) + monkeypatch.setattr( + "control_backend.agents.ri_command_agent.context.socket", lambda _: fake_socket + ) agent = RICommandAgent("test@server", "password", address="tcp://localhost:5555", bind=True) - monkeypatch.setattr("control_backend.agents.ri_command_agent.settings", - MagicMock(zmq_settings=MagicMock(internal_comm_address="tcp://internal:1234"))) + monkeypatch.setattr( + "control_backend.agents.ri_command_agent.settings", + MagicMock(zmq_settings=MagicMock(internal_comm_address="tcp://internal:1234")), + ) await agent.setup() @@ -28,29 +32,35 @@ async def test_setup_bind(monkeypatch): # Ensure behaviour attached assert any(isinstance(b, agent.SendCommandsBehaviour) for b in agent.behaviours) + @pytest.mark.asyncio async def test_setup_connect(monkeypatch): """Test setup with bind=False""" fake_socket = MagicMock() - monkeypatch.setattr("control_backend.agents.ri_command_agent.context.socket", - lambda _: fake_socket) + monkeypatch.setattr( + "control_backend.agents.ri_command_agent.context.socket", lambda _: fake_socket + ) agent = RICommandAgent("test@server", "password", address="tcp://localhost:5555", bind=False) - monkeypatch.setattr("control_backend.agents.ri_command_agent.settings", - MagicMock(zmq_settings=MagicMock(internal_comm_address="tcp://internal:1234"))) + monkeypatch.setattr( + "control_backend.agents.ri_command_agent.settings", + MagicMock(zmq_settings=MagicMock(internal_comm_address="tcp://internal:1234")), + ) await agent.setup() # Ensure PUB socket connected fake_socket.connect.assert_any_call("tcp://localhost:5555") + @pytest.mark.asyncio async def test_send_commands_behaviour_valid_message(): """Test behaviour with valid JSON message""" fake_socket = AsyncMock() message_dict = {"message": "hello"} - fake_socket.recv_multipart = AsyncMock(return_value=(b"command", - json.dumps(message_dict).encode("utf-8"))) + fake_socket.recv_multipart = AsyncMock( + return_value=(b"command", json.dumps(message_dict).encode("utf-8")) + ) fake_socket.send_json = AsyncMock() agent = RICommandAgent("test@server", "password") @@ -60,14 +70,15 @@ async def test_send_commands_behaviour_valid_message(): behaviour = agent.SendCommandsBehaviour() behaviour.agent = agent - with patch('control_backend.agents.ri_command_agent.SpeechCommand') as MockSpeechCommand: + with patch("control_backend.agents.ri_command_agent.SpeechCommand") as MockSpeechCommand: mock_message = MagicMock() MockSpeechCommand.model_validate.return_value = mock_message await behaviour.run() fake_socket.recv_multipart.assert_awaited() - fake_socket.send_json.assert_awaited_with(mock_message) + fake_socket.send_json.assert_awaited_with(mock_message.model_dump()) + @pytest.mark.asyncio async def test_send_commands_behaviour_invalid_message(caplog): diff --git a/test/integration/agents/test_ri_communication_agent.py b/test/integration/agents/test_ri_communication_agent.py index d778640..3e4a056 100644 --- a/test/integration/agents/test_ri_communication_agent.py +++ b/test/integration/agents/test_ri_communication_agent.py @@ -3,60 +3,84 @@ import pytest from unittest.mock import AsyncMock, MagicMock, patch, ANY from control_backend.agents.ri_communication_agent import RICommunicationAgent + def fake_json_correct_negototiate_1(): - return AsyncMock(return_value={ - "endpoint": "negotiate/ports", - "data": [ - {"id": "main", "port": 5555, "bind": False}, - {"id": "actuation", "port": 5556, "bind": True}, - ]}) + return AsyncMock( + return_value={ + "endpoint": "negotiate/ports", + "data": [ + {"id": "main", "port": 5555, "bind": False}, + {"id": "actuation", "port": 5556, "bind": True}, + ], + } + ) + def fake_json_correct_negototiate_2(): - return AsyncMock(return_value={ - "endpoint": "negotiate/ports", - "data": [ - {"id": "main", "port": 5555, "bind": False}, - {"id": "actuation", "port": 5557, "bind": True}, - ]}) + return AsyncMock( + return_value={ + "endpoint": "negotiate/ports", + "data": [ + {"id": "main", "port": 5555, "bind": False}, + {"id": "actuation", "port": 5557, "bind": True}, + ], + } + ) + def fake_json_correct_negototiate_3(): - return AsyncMock(return_value={ - "endpoint": "negotiate/ports", - "data": [ - {"id": "main", "port": 5555, "bind": True}, - {"id": "actuation", "port": 5557, "bind": True}, - ]}) + return AsyncMock( + return_value={ + "endpoint": "negotiate/ports", + "data": [ + {"id": "main", "port": 5555, "bind": True}, + {"id": "actuation", "port": 5557, "bind": True}, + ], + } + ) + def fake_json_correct_negototiate_4(): # Different port, do bind - return AsyncMock(return_value={ - "endpoint": "negotiate/ports", - "data": [ - {"id": "main", "port": 4555, "bind": True}, - {"id": "actuation", "port": 5557, "bind": True}, - ]}) + return AsyncMock( + return_value={ + "endpoint": "negotiate/ports", + "data": [ + {"id": "main", "port": 4555, "bind": True}, + {"id": "actuation", "port": 5557, "bind": True}, + ], + } + ) + def fake_json_correct_negototiate_5(): # Different port, dont bind. - return AsyncMock(return_value={ - "endpoint": "negotiate/ports", - "data": [ - {"id": "main", "port": 4555, "bind": False}, - {"id": "actuation", "port": 5557, "bind": True}, - ]}) + return AsyncMock( + return_value={ + "endpoint": "negotiate/ports", + "data": [ + {"id": "main", "port": 4555, "bind": False}, + {"id": "actuation", "port": 5557, "bind": True}, + ], + } + ) + def fake_json_wrong_negototiate_1(): - return AsyncMock(return_value={ - "endpoint": "ping", - "data": ""}) + return AsyncMock(return_value={"endpoint": "ping", "data": ""}) + def fake_json_invalid_id_negototiate(): - return AsyncMock(return_value={ - "endpoint": "negotiate/ports", - "data": [ - {"id": "banana", "port": 4555, "bind": False}, - {"id": "tomato", "port": 5557, "bind": True}, - ]}) + return AsyncMock( + return_value={ + "endpoint": "negotiate/ports", + "data": [ + {"id": "banana", "port": 4555, "bind": False}, + {"id": "tomato", "port": 5557, "bind": True}, + ], + } + ) + @pytest.mark.asyncio async def test_setup_creates_socket_and_negotiate_1(monkeypatch): @@ -67,20 +91,23 @@ async def test_setup_creates_socket_and_negotiate_1(monkeypatch): fake_socket = MagicMock() fake_socket.send_json = AsyncMock() fake_socket.recv_json = fake_json_correct_negototiate_1() - + # Mock context.socket to return our fake socket - monkeypatch.setattr("control_backend.agents.ri_communication_agent.context.socket", - lambda _: fake_socket) + monkeypatch.setattr( + "control_backend.agents.ri_communication_agent.context.socket", lambda _: fake_socket + ) # Mock RICommandAgent agent startup - with patch("control_backend.agents.ri_communication_agent.RICommandAgent", - autospec=True) as MockCommandAgent: + with patch( + "control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True + ) as MockCommandAgent: fake_agent_instance = MockCommandAgent.return_value fake_agent_instance.start = AsyncMock() # --- Act --- - agent = RICommunicationAgent("test@server", "password", - address="tcp://localhost:5555", bind=False) + agent = RICommunicationAgent( + "test@server", "password", address="tcp://localhost:5555", bind=False + ) await agent.setup() # --- Assert --- @@ -89,14 +116,15 @@ async def test_setup_creates_socket_and_negotiate_1(monkeypatch): fake_socket.recv_json.assert_awaited() fake_agent_instance.start.assert_awaited() MockCommandAgent.assert_called_once_with( - ANY, # Server Name - ANY, # Server Password + ANY, # Server Name + ANY, # Server Password address="tcp://*:5556", # derived from the 'port' value in negotiation - bind=True + bind=True, ) # Ensure the agent attached a ListenBehaviour assert any(isinstance(b, agent.ListenBehaviour) for b in agent.behaviours) + @pytest.mark.asyncio async def test_setup_creates_socket_and_negotiate_2(monkeypatch): """ @@ -106,20 +134,23 @@ async def test_setup_creates_socket_and_negotiate_2(monkeypatch): fake_socket = MagicMock() fake_socket.send_json = AsyncMock() fake_socket.recv_json = fake_json_correct_negototiate_2() - + # Mock context.socket to return our fake socket - monkeypatch.setattr("control_backend.agents.ri_communication_agent.context.socket", - lambda _: fake_socket) + monkeypatch.setattr( + "control_backend.agents.ri_communication_agent.context.socket", lambda _: fake_socket + ) # Mock RICommandAgent agent startup - with patch("control_backend.agents.ri_communication_agent.RICommandAgent", - autospec=True) as MockCommandAgent: + with patch( + "control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True + ) as MockCommandAgent: fake_agent_instance = MockCommandAgent.return_value fake_agent_instance.start = AsyncMock() # --- Act --- - agent = RICommunicationAgent("test@server", "password", - address="tcp://localhost:5555", bind=False) + agent = RICommunicationAgent( + "test@server", "password", address="tcp://localhost:5555", bind=False + ) await agent.setup() # --- Assert --- @@ -128,14 +159,15 @@ async def test_setup_creates_socket_and_negotiate_2(monkeypatch): fake_socket.recv_json.assert_awaited() fake_agent_instance.start.assert_awaited() MockCommandAgent.assert_called_once_with( - ANY, # Server Name - ANY, # Server Password + ANY, # Server Name + ANY, # Server Password address="tcp://*:5557", # derived from the 'port' value in negotiation - bind=True + bind=True, ) # Ensure the agent attached a ListenBehaviour assert any(isinstance(b, agent.ListenBehaviour) for b in agent.behaviours) + @pytest.mark.asyncio async def test_setup_creates_socket_and_negotiate_3(monkeypatch, caplog): """ @@ -145,25 +177,27 @@ async def test_setup_creates_socket_and_negotiate_3(monkeypatch, caplog): fake_socket = MagicMock() fake_socket.send_json = AsyncMock() fake_socket.recv_json = fake_json_wrong_negototiate_1() - + # Mock context.socket to return our fake socket - monkeypatch.setattr("control_backend.agents.ri_communication_agent.context.socket", - lambda _: fake_socket) + monkeypatch.setattr( + "control_backend.agents.ri_communication_agent.context.socket", lambda _: fake_socket + ) # Mock RICommandAgent agent startup - # We are sending wrong negotiation info to the communication agent, so we should retry and expect a # better response, within a limited time. - with patch("control_backend.agents.ri_communication_agent.RICommandAgent", - autospec=True) as MockCommandAgent: + with patch( + "control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True + ) as MockCommandAgent: fake_agent_instance = MockCommandAgent.return_value fake_agent_instance.start = AsyncMock() # --- Act --- with caplog.at_level("ERROR"): - agent = RICommunicationAgent("test@server", "password", - address="tcp://localhost:5555", bind=False) + agent = RICommunicationAgent( + "test@server", "password", address="tcp://localhost:5555", bind=False + ) await agent.setup(max_retries=1) # --- Assert --- @@ -173,10 +207,11 @@ async def test_setup_creates_socket_and_negotiate_3(monkeypatch, caplog): # Since it failed, there should not be any command agent. fake_agent_instance.start.assert_not_awaited() assert "Failed to set up RICommunicationAgent" in caplog.text - + # Ensure the agent did not attach a ListenBehaviour assert not any(isinstance(b, agent.ListenBehaviour) for b in agent.behaviours) + @pytest.mark.asyncio async def test_setup_creates_socket_and_negotiate_4(monkeypatch): """ @@ -186,20 +221,23 @@ async def test_setup_creates_socket_and_negotiate_4(monkeypatch): fake_socket = MagicMock() fake_socket.send_json = AsyncMock() fake_socket.recv_json = fake_json_correct_negototiate_3() - + # Mock context.socket to return our fake socket - monkeypatch.setattr("control_backend.agents.ri_communication_agent.context.socket", - lambda _: fake_socket) + monkeypatch.setattr( + "control_backend.agents.ri_communication_agent.context.socket", lambda _: fake_socket + ) # Mock RICommandAgent agent startup - with patch("control_backend.agents.ri_communication_agent.RICommandAgent", - autospec=True) as MockCommandAgent: + with patch( + "control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True + ) as MockCommandAgent: fake_agent_instance = MockCommandAgent.return_value fake_agent_instance.start = AsyncMock() # --- Act --- - agent = RICommunicationAgent("test@server", "password", - address="tcp://localhost:5555", bind=True) + agent = RICommunicationAgent( + "test@server", "password", address="tcp://localhost:5555", bind=True + ) await agent.setup() # --- Assert --- @@ -208,14 +246,15 @@ async def test_setup_creates_socket_and_negotiate_4(monkeypatch): fake_socket.recv_json.assert_awaited() fake_agent_instance.start.assert_awaited() MockCommandAgent.assert_called_once_with( - ANY, # Server Name - ANY, # Server Password + ANY, # Server Name + ANY, # Server Password address="tcp://*:5557", # derived from the 'port' value in negotiation - bind=True + bind=True, ) # Ensure the agent attached a ListenBehaviour assert any(isinstance(b, agent.ListenBehaviour) for b in agent.behaviours) + @pytest.mark.asyncio async def test_setup_creates_socket_and_negotiate_5(monkeypatch): """ @@ -225,20 +264,23 @@ async def test_setup_creates_socket_and_negotiate_5(monkeypatch): fake_socket = MagicMock() fake_socket.send_json = AsyncMock() fake_socket.recv_json = fake_json_correct_negototiate_4() - + # Mock context.socket to return our fake socket - monkeypatch.setattr("control_backend.agents.ri_communication_agent.context.socket", - lambda _: fake_socket) + monkeypatch.setattr( + "control_backend.agents.ri_communication_agent.context.socket", lambda _: fake_socket + ) # Mock RICommandAgent agent startup - with patch("control_backend.agents.ri_communication_agent.RICommandAgent", - autospec=True) as MockCommandAgent: + with patch( + "control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True + ) as MockCommandAgent: fake_agent_instance = MockCommandAgent.return_value fake_agent_instance.start = AsyncMock() # --- Act --- - agent = RICommunicationAgent("test@server", "password", - address="tcp://localhost:5555", bind=False) + agent = RICommunicationAgent( + "test@server", "password", address="tcp://localhost:5555", bind=False + ) await agent.setup() # --- Assert --- @@ -247,14 +289,15 @@ async def test_setup_creates_socket_and_negotiate_5(monkeypatch): fake_socket.recv_json.assert_awaited() fake_agent_instance.start.assert_awaited() MockCommandAgent.assert_called_once_with( - ANY, # Server Name - ANY, # Server Password + ANY, # Server Name + ANY, # Server Password address="tcp://*:5557", # derived from the 'port' value in negotiation - bind=True + bind=True, ) # Ensure the agent attached a ListenBehaviour assert any(isinstance(b, agent.ListenBehaviour) for b in agent.behaviours) + @pytest.mark.asyncio async def test_setup_creates_socket_and_negotiate_6(monkeypatch): """ @@ -264,20 +307,23 @@ async def test_setup_creates_socket_and_negotiate_6(monkeypatch): fake_socket = MagicMock() fake_socket.send_json = AsyncMock() fake_socket.recv_json = fake_json_correct_negototiate_5() - + # Mock context.socket to return our fake socket - monkeypatch.setattr("control_backend.agents.ri_communication_agent.context.socket", - lambda _: fake_socket) + monkeypatch.setattr( + "control_backend.agents.ri_communication_agent.context.socket", lambda _: fake_socket + ) # Mock RICommandAgent agent startup - with patch("control_backend.agents.ri_communication_agent.RICommandAgent", - autospec=True) as MockCommandAgent: + with patch( + "control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True + ) as MockCommandAgent: fake_agent_instance = MockCommandAgent.return_value fake_agent_instance.start = AsyncMock() # --- Act --- - agent = RICommunicationAgent("test@server", "password", - address="tcp://localhost:5555", bind=False) + agent = RICommunicationAgent( + "test@server", "password", address="tcp://localhost:5555", bind=False + ) await agent.setup() # --- Assert --- @@ -286,14 +332,15 @@ async def test_setup_creates_socket_and_negotiate_6(monkeypatch): fake_socket.recv_json.assert_awaited() fake_agent_instance.start.assert_awaited() MockCommandAgent.assert_called_once_with( - ANY, # Server Name - ANY, # Server Password + ANY, # Server Name + ANY, # Server Password address="tcp://*:5557", # derived from the 'port' value in negotiation - bind=True + bind=True, ) # Ensure the agent attached a ListenBehaviour assert any(isinstance(b, agent.ListenBehaviour) for b in agent.behaviours) + @pytest.mark.asyncio async def test_setup_creates_socket_and_negotiate_7(monkeypatch, caplog): """ @@ -303,25 +350,27 @@ async def test_setup_creates_socket_and_negotiate_7(monkeypatch, caplog): fake_socket = MagicMock() fake_socket.send_json = AsyncMock() fake_socket.recv_json = fake_json_invalid_id_negototiate() - + # Mock context.socket to return our fake socket - monkeypatch.setattr("control_backend.agents.ri_communication_agent.context.socket", - lambda _: fake_socket) + monkeypatch.setattr( + "control_backend.agents.ri_communication_agent.context.socket", lambda _: fake_socket + ) # Mock RICommandAgent agent startup - # We are sending wrong negotiation info to the communication agent, so we should retry and expect a # better response, within a limited time. - with patch("control_backend.agents.ri_communication_agent.RICommandAgent", - autospec=True) as MockCommandAgent: + with patch( + "control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True + ) as MockCommandAgent: fake_agent_instance = MockCommandAgent.return_value fake_agent_instance.start = AsyncMock() # --- Act --- with caplog.at_level("WARNING"): - agent = RICommunicationAgent("test@server", "password", - address="tcp://localhost:5555", bind=False) + agent = RICommunicationAgent( + "test@server", "password", address="tcp://localhost:5555", bind=False + ) await agent.setup(max_retries=1) # --- Assert --- @@ -342,32 +391,36 @@ async def test_setup_creates_socket_and_negotiate_timeout(monkeypatch, caplog): fake_socket = MagicMock() fake_socket.send_json = AsyncMock() fake_socket.recv_json = AsyncMock(side_effect=asyncio.TimeoutError) - - # Mock context.socket to return our fake socket - monkeypatch.setattr("control_backend.agents.ri_communication_agent.context.socket", - lambda _: fake_socket) - with patch("control_backend.agents.ri_communication_agent.RICommandAgent", - autospec=True) as MockCommandAgent: + # Mock context.socket to return our fake socket + monkeypatch.setattr( + "control_backend.agents.ri_communication_agent.context.socket", lambda _: fake_socket + ) + + with patch( + "control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True + ) as MockCommandAgent: fake_agent_instance = MockCommandAgent.return_value fake_agent_instance.start = AsyncMock() # --- Act --- with caplog.at_level("WARNING"): - agent = RICommunicationAgent("test@server", "password", - address="tcp://localhost:5555", bind=False) + agent = RICommunicationAgent( + "test@server", "password", address="tcp://localhost:5555", bind=False + ) await agent.setup(max_retries=1) # --- Assert --- fake_socket.connect.assert_any_call("tcp://localhost:5555") - + # Since it failed, there should not be any command agent. fake_agent_instance.start.assert_not_awaited() assert "No connection established in 20 seconds" in caplog.text - + # Ensure the agent did not attach a ListenBehaviour assert not any(isinstance(b, agent.ListenBehaviour) for b in agent.behaviours) + @pytest.mark.asyncio async def test_listen_behaviour_ping_correct(caplog): fake_socket = AsyncMock() @@ -389,6 +442,7 @@ async def test_listen_behaviour_ping_correct(caplog): fake_socket.recv_json.assert_awaited() assert "Received message" in caplog.text + @pytest.mark.asyncio async def test_listen_behaviour_ping_wrong_endpoint(caplog): """ @@ -398,12 +452,15 @@ async def test_listen_behaviour_ping_wrong_endpoint(caplog): fake_socket.send_json = AsyncMock() # This is a message for the wrong endpoint >:( - fake_socket.recv_json = AsyncMock(return_value={ - "endpoint": "negotiate/ports", - "data": [ - {"id": "main", "port": 5555, "bind": False}, - {"id": "actuation", "port": 5556, "bind": True}, - ]}) + fake_socket.recv_json = AsyncMock( + return_value={ + "endpoint": "negotiate/ports", + "data": [ + {"id": "main", "port": 5555, "bind": False}, + {"id": "actuation", "port": 5556, "bind": True}, + ], + } + ) agent = RICommunicationAgent("test@server", "password") agent.req_socket = fake_socket @@ -415,11 +472,11 @@ async def test_listen_behaviour_ping_wrong_endpoint(caplog): with caplog.at_level("INFO"): await behaviour.run() - assert "Received message with topic different than ping, while ping expected." in caplog.text fake_socket.send_json.assert_awaited() fake_socket.recv_json.assert_awaited() + @pytest.mark.asyncio async def test_listen_behaviour_timeout(caplog): fake_socket = AsyncMock() @@ -438,6 +495,7 @@ async def test_listen_behaviour_timeout(caplog): assert "No ping retrieved in 3 seconds" in caplog.text + @pytest.mark.asyncio async def test_listen_behaviour_ping_no_endpoint(caplog): """ @@ -447,9 +505,11 @@ async def test_listen_behaviour_ping_no_endpoint(caplog): fake_socket.send_json = AsyncMock() # This is a message without endpoint >:( - fake_socket.recv_json = AsyncMock(return_value={ - "data": "I dont have an endpoint >:)", - }) + fake_socket.recv_json = AsyncMock( + return_value={ + "data": "I dont have an endpoint >:)", + } + ) agent = RICommunicationAgent("test@server", "password") agent.req_socket = fake_socket @@ -465,6 +525,7 @@ async def test_listen_behaviour_ping_no_endpoint(caplog): fake_socket.send_json.assert_awaited() fake_socket.recv_json.assert_awaited() + @pytest.mark.asyncio async def test_setup_unexpected_exception(monkeypatch, caplog): fake_socket = MagicMock() @@ -473,12 +534,12 @@ async def test_setup_unexpected_exception(monkeypatch, caplog): fake_socket.recv_json = AsyncMock(side_effect=Exception("boom!")) monkeypatch.setattr( - "control_backend.agents.ri_communication_agent.context.socket", - lambda _: fake_socket + "control_backend.agents.ri_communication_agent.context.socket", lambda _: fake_socket ) - agent = RICommunicationAgent("test@server", "password", - address="tcp://localhost:5555", bind=False) + agent = RICommunicationAgent( + "test@server", "password", address="tcp://localhost:5555", bind=False + ) with caplog.at_level("ERROR"): await agent.setup(max_retries=1) @@ -486,6 +547,7 @@ async def test_setup_unexpected_exception(monkeypatch, caplog): # Ensure that the error was logged assert "Unexpected error during negotiation: boom!" in caplog.text + @pytest.mark.asyncio async def test_setup_unpacking_exception(monkeypatch, caplog): # --- Arrange --- @@ -493,24 +555,27 @@ async def test_setup_unpacking_exception(monkeypatch, caplog): fake_socket.send_json = AsyncMock() # Make recv_json return malformed negotiation data to trigger unpacking exception - malformed_data = {"endpoint": "negotiate/ports", - "data": [ {"id": "main"} ]} # missing 'port' and 'bind' + malformed_data = { + "endpoint": "negotiate/ports", + "data": [{"id": "main"}], + } # missing 'port' and 'bind' fake_socket.recv_json = AsyncMock(return_value=malformed_data) # Patch context.socket monkeypatch.setattr( - "control_backend.agents.ri_communication_agent.context.socket", - lambda _: fake_socket + "control_backend.agents.ri_communication_agent.context.socket", lambda _: fake_socket ) # Patch RICommandAgent so it won't actually start - with patch("control_backend.agents.ri_communication_agent.RICommandAgent", - autospec=True) as MockCommandAgent: + with patch( + "control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True + ) as MockCommandAgent: fake_agent_instance = MockCommandAgent.return_value fake_agent_instance.start = AsyncMock() - agent = RICommunicationAgent("test@server", "password", - address="tcp://localhost:5555", bind=False) + agent = RICommunicationAgent( + "test@server", "password", address="tcp://localhost:5555", bind=False + ) # --- Act & Assert --- with caplog.at_level("ERROR"): @@ -523,4 +588,4 @@ async def test_setup_unpacking_exception(monkeypatch, caplog): fake_agent_instance.start.assert_not_awaited() # Ensure no behaviour was attached - assert not any(isinstance(b, agent.ListenBehaviour) for b in agent.behaviours) \ No newline at end of file + assert not any(isinstance(b, agent.ListenBehaviour) for b in agent.behaviours) diff --git a/test/integration/api/endpoints/test_command_endpoint.py b/test/integration/api/endpoints/test_command_endpoint.py index 3ab1be3..07bd866 100644 --- a/test/integration/api/endpoints/test_command_endpoint.py +++ b/test/integration/api/endpoints/test_command_endpoint.py @@ -6,6 +6,7 @@ from unittest.mock import MagicMock from control_backend.api.v1.endpoints import command from control_backend.schemas.ri_message import SpeechCommand + @pytest.fixture def app(): """ @@ -49,7 +50,7 @@ def test_receive_command_endpoint(client, app): assert sent_data[0] == b"command" # Check JSON encoding roughly matches - assert isinstance(sent_data[1], SpeechCommand) + assert isinstance(SpeechCommand.model_validate_json(sent_data[1].decode()), SpeechCommand) def test_receive_command_invalid_payload(client): @@ -59,4 +60,4 @@ def test_receive_command_invalid_payload(client): # Missing required field(s) bad_payload = {"invalid": "data"} response = client.post("/command", json=bad_payload) - assert response.status_code == 422 # validation error \ No newline at end of file + assert response.status_code == 422 # validation error diff --git a/test/integration/schemas/test_ri_message.py b/test/integration/schemas/test_ri_message.py index b840f97..aef9ae6 100644 --- a/test/integration/schemas/test_ri_message.py +++ b/test/integration/schemas/test_ri_message.py @@ -2,12 +2,15 @@ import pytest from control_backend.schemas.ri_message import RIMessage, RIEndpoint, SpeechCommand from pydantic import ValidationError + def valid_command_1(): return SpeechCommand(data="Hallo?") + def invalid_command_1(): return RIMessage(endpoint=RIEndpoint.PING, data="Hello again.") + def test_valid_speech_command_1(): command = valid_command_1() try: @@ -16,7 +19,7 @@ def test_valid_speech_command_1(): assert True except ValidationError: assert False - + def test_invalid_speech_command_1(): command = invalid_command_1() @@ -31,5 +34,3 @@ def test_invalid_speech_command_1(): assert False except ValidationError: assert passed_ri_message_validation - - \ No newline at end of file