fix: unit tests fixes and ruff formating

N25B-205
This commit is contained in:
Björn Otgaar
2025-10-28 11:31:05 +01:00
parent 52faa59184
commit 47a87d0b4a
11 changed files with 307 additions and 209 deletions

View File

@@ -10,13 +10,22 @@ from control_backend.schemas.ri_message import SpeechCommand
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class RICommandAgent(Agent): class RICommandAgent(Agent):
subsocket: zmq.Socket subsocket: zmq.Socket
pubsocket: zmq.Socket pubsocket: zmq.Socket
address = "" address = ""
bind = False bind = False
def __init__(self, jid: str, password: str, port: int = 5222, verify_security: bool = False, address = "tcp://localhost:0000", bind = False): def __init__(
self,
jid: str,
password: str,
port: int = 5222,
verify_security: bool = False,
address="tcp://localhost:0000",
bind=False,
):
super().__init__(jid, password, port, verify_security) super().__init__(jid, password, port, verify_security)
self.address = address self.address = address
self.bind = bind self.bind = bind
@@ -50,7 +59,7 @@ class RICommandAgent(Agent):
self.pubsocket = context.socket(zmq.PUB) self.pubsocket = context.socket(zmq.PUB)
if self.bind: if self.bind:
self.pubsocket.bind(self.address) self.pubsocket.bind(self.address)
else : else:
self.pubsocket.connect(self.address) self.pubsocket.connect(self.address)
# Receive internal topics regarding commands # Receive internal topics regarding commands

View File

@@ -12,14 +12,21 @@ from control_backend.agents.ri_command_agent import RICommandAgent
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class RICommunicationAgent(Agent): class RICommunicationAgent(Agent):
req_socket: zmq.Socket req_socket: zmq.Socket
_address = "" _address = ""
_bind = True _bind = True
def __init__(self, jid: str, password: str, port: int = 5222, def __init__(
verify_security: bool = False, address = "tcp://localhost:0000", self,
bind = False): jid: str,
password: str,
port: int = 5222,
verify_security: bool = False,
address="tcp://localhost:0000",
bind=False,
):
super().__init__(jid, password, port, verify_security) super().__init__(jid, password, port, verify_security)
self._address = address self._address = address
self._bind = bind self._bind = bind
@@ -37,16 +44,14 @@ class RICommunicationAgent(Agent):
# Wait up to three seconds for a reply:) # Wait up to three seconds for a reply:)
try: try:
message = await asyncio.wait_for( message = await asyncio.wait_for(self.agent.req_socket.recv_json(), timeout=3.0)
self.agent.req_socket.recv_json(),
timeout=3.0)
# We didnt get a reply :( # We didnt get a reply :(
except asyncio.TimeoutError as e: except asyncio.TimeoutError as e:
logger.info("No ping retrieved in 3 seconds, killing myself.") logger.info("No ping retrieved in 3 seconds, killing myself.")
self.kill() self.kill()
logger.debug("Received message \"%s\"", message) logger.debug('Received message "%s"', message)
if "endpoint" not in message: if "endpoint" not in message:
logger.error("No received endpoint in message, excepted ping endpoint.") logger.error("No received endpoint in message, excepted ping endpoint.")
return return
@@ -56,9 +61,9 @@ class RICommunicationAgent(Agent):
case "ping": case "ping":
await asyncio.sleep(1) await asyncio.sleep(1)
case _: case _:
logger.info("Received message with topic different than ping," \ logger.info(
" while ping expected.") "Received message with topic different than ping, while ping expected."
)
async def setup(self, max_retries: int = 5): async def setup(self, max_retries: int = 5):
""" """
@@ -67,7 +72,6 @@ class RICommunicationAgent(Agent):
logger.info("Setting up %s", self.jid) logger.info("Setting up %s", self.jid)
retries = 0 retries = 0
# Let's try a certain amount of times before failing connection # Let's try a certain amount of times before failing connection
while retries < max_retries: while retries < max_retries:
# Bind request socket # Bind request socket
@@ -85,8 +89,11 @@ class RICommunicationAgent(Agent):
received_message = await asyncio.wait_for(self.req_socket.recv_json(), timeout=20.0) received_message = await asyncio.wait_for(self.req_socket.recv_json(), timeout=20.0)
except asyncio.TimeoutError: except asyncio.TimeoutError:
logger.warning("No connection established in 20 seconds (attempt %d/%d)", logger.warning(
retries + 1, max_retries) "No connection established in 20 seconds (attempt %d/%d)",
retries + 1,
max_retries,
)
retries += 1 retries += 1
continue continue
@@ -99,8 +106,12 @@ class RICommunicationAgent(Agent):
endpoint = received_message.get("endpoint") endpoint = received_message.get("endpoint")
if endpoint != "negotiate/ports": if endpoint != "negotiate/ports":
# TODO: Should this send a message back? # TODO: Should this send a message back?
logger.error("Invalid endpoint '%s' received (attempt %d/%d)", logger.error(
endpoint, retries + 1, max_retries) "Invalid endpoint '%s' received (attempt %d/%d)",
endpoint,
retries + 1,
max_retries,
)
retries += 1 retries += 1
continue continue
@@ -125,11 +136,13 @@ class RICommunicationAgent(Agent):
self.req_socket.bind(addr) self.req_socket.bind(addr)
case "actuation": case "actuation":
ri_commands_agent = RICommandAgent( ri_commands_agent = RICommandAgent(
settings.agent_settings.ri_command_agent_name + settings.agent_settings.ri_command_agent_name
'@' + settings.agent_settings.host, + "@"
+ settings.agent_settings.host,
settings.agent_settings.ri_command_agent_name, settings.agent_settings.ri_command_agent_name,
address=addr, address=addr,
bind=bind ) bind=bind,
)
await ri_commands_agent.start() await ri_commands_agent.start()
case _: case _:
logger.warning("Unhandled negotiation id: %s", id) logger.warning("Unhandled negotiation id: %s", id)
@@ -150,5 +163,3 @@ class RICommunicationAgent(Agent):
listen_behaviour = self.ListenBehaviour() listen_behaviour = self.ListenBehaviour()
self.add_behaviour(listen_behaviour) self.add_behaviour(listen_behaviour)
logger.info("Finished setting up %s", self.jid) logger.info("Finished setting up %s", self.jid)

View File

@@ -9,6 +9,7 @@ logger = logging.getLogger(__name__)
router = APIRouter() router = APIRouter()
@router.post("/command", status_code=202) @router.post("/command", status_code=202)
async def receive_command(command: SpeechCommand, request: Request): async def receive_command(command: SpeechCommand, request: Request):
# Validate and retrieve data. # Validate and retrieve data.
@@ -17,4 +18,5 @@ async def receive_command(command: SpeechCommand, request: Request):
pub_socket: Socket = request.app.state.internal_comm_socket pub_socket: Socket = request.app.state.internal_comm_socket
pub_socket.send_multipart([topic, command.model_dump_json().encode()]) pub_socket.send_multipart([topic, command.model_dump_json().encode()])
return {"status": "Command received"} return {"status": "Command received"}

View File

@@ -6,12 +6,6 @@ api_router = APIRouter()
api_router.include_router(message.router, tags=["Messages"]) api_router.include_router(message.router, tags=["Messages"])
api_router.include_router( api_router.include_router(sse.router, tags=["SSE"])
sse.router,
tags=["SSE"]
)
api_router.include_router( api_router.include_router(command.router, tags=["Commands"])
command.router,
tags=["Commands"]
)

View File

@@ -15,6 +15,7 @@ class AgentSettings(BaseModel):
ri_communication_agent_name: str = "ri_communication_agent" ri_communication_agent_name: str = "ri_communication_agent"
ri_command_agent_name: str = "ri_command_agent" ri_command_agent_name: str = "ri_command_agent"
class Settings(BaseSettings): class Settings(BaseSettings):
app_title: str = "PepperPlus" app_title: str = "PepperPlus"

View File

@@ -31,16 +31,19 @@ async def lifespan(app: FastAPI):
logger.info("Internal publishing socket bound to %s", internal_comm_socket) logger.info("Internal publishing socket bound to %s", internal_comm_socket)
# Initiate agents # Initiate agents
ri_communication_agent = RICommunicationAgent(settings.agent_settings.ri_communication_agent_name + ri_communication_agent = RICommunicationAgent(
'@' + settings.agent_settings.host, settings.agent_settings.ri_communication_agent_name + "@" + settings.agent_settings.host,
settings.agent_settings.ri_communication_agent_name, settings.agent_settings.ri_communication_agent_name,
address="tcp://*:5555", bind=True) address="tcp://*:5555",
bind=True,
)
await ri_communication_agent.start() await ri_communication_agent.start()
bdi_core = BDICoreAgent(settings.agent_settings.bdi_core_agent_name + bdi_core = BDICoreAgent(
'@' + settings.agent_settings.host, settings.agent_settings.bdi_core_agent_name + "@" + settings.agent_settings.host,
settings.agent_settings.bdi_core_agent_name, settings.agent_settings.bdi_core_agent_name,
"src/control_backend/agents/bdi/rules.asl") "src/control_backend/agents/bdi/rules.asl",
)
await bdi_core.start() await bdi_core.start()
yield yield

View File

@@ -6,16 +6,20 @@ from unittest.mock import AsyncMock, MagicMock, patch
from control_backend.agents.ri_command_agent import RICommandAgent from control_backend.agents.ri_command_agent import RICommandAgent
from control_backend.schemas.ri_message import SpeechCommand from control_backend.schemas.ri_message import SpeechCommand
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_setup_bind(monkeypatch): async def test_setup_bind(monkeypatch):
"""Test setup with bind=True""" """Test setup with bind=True"""
fake_socket = MagicMock() fake_socket = MagicMock()
monkeypatch.setattr("control_backend.agents.ri_command_agent.context.socket", monkeypatch.setattr(
lambda _: fake_socket) "control_backend.agents.ri_command_agent.context.socket", lambda _: fake_socket
)
agent = RICommandAgent("test@server", "password", address="tcp://localhost:5555", bind=True) agent = RICommandAgent("test@server", "password", address="tcp://localhost:5555", bind=True)
monkeypatch.setattr("control_backend.agents.ri_command_agent.settings", monkeypatch.setattr(
MagicMock(zmq_settings=MagicMock(internal_comm_address="tcp://internal:1234"))) "control_backend.agents.ri_command_agent.settings",
MagicMock(zmq_settings=MagicMock(internal_comm_address="tcp://internal:1234")),
)
await agent.setup() await agent.setup()
@@ -28,29 +32,35 @@ async def test_setup_bind(monkeypatch):
# Ensure behaviour attached # Ensure behaviour attached
assert any(isinstance(b, agent.SendCommandsBehaviour) for b in agent.behaviours) assert any(isinstance(b, agent.SendCommandsBehaviour) for b in agent.behaviours)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_setup_connect(monkeypatch): async def test_setup_connect(monkeypatch):
"""Test setup with bind=False""" """Test setup with bind=False"""
fake_socket = MagicMock() fake_socket = MagicMock()
monkeypatch.setattr("control_backend.agents.ri_command_agent.context.socket", monkeypatch.setattr(
lambda _: fake_socket) "control_backend.agents.ri_command_agent.context.socket", lambda _: fake_socket
)
agent = RICommandAgent("test@server", "password", address="tcp://localhost:5555", bind=False) agent = RICommandAgent("test@server", "password", address="tcp://localhost:5555", bind=False)
monkeypatch.setattr("control_backend.agents.ri_command_agent.settings", monkeypatch.setattr(
MagicMock(zmq_settings=MagicMock(internal_comm_address="tcp://internal:1234"))) "control_backend.agents.ri_command_agent.settings",
MagicMock(zmq_settings=MagicMock(internal_comm_address="tcp://internal:1234")),
)
await agent.setup() await agent.setup()
# Ensure PUB socket connected # Ensure PUB socket connected
fake_socket.connect.assert_any_call("tcp://localhost:5555") fake_socket.connect.assert_any_call("tcp://localhost:5555")
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_send_commands_behaviour_valid_message(): async def test_send_commands_behaviour_valid_message():
"""Test behaviour with valid JSON message""" """Test behaviour with valid JSON message"""
fake_socket = AsyncMock() fake_socket = AsyncMock()
message_dict = {"message": "hello"} message_dict = {"message": "hello"}
fake_socket.recv_multipart = AsyncMock(return_value=(b"command", fake_socket.recv_multipart = AsyncMock(
json.dumps(message_dict).encode("utf-8"))) return_value=(b"command", json.dumps(message_dict).encode("utf-8"))
)
fake_socket.send_json = AsyncMock() fake_socket.send_json = AsyncMock()
agent = RICommandAgent("test@server", "password") agent = RICommandAgent("test@server", "password")
@@ -60,14 +70,15 @@ async def test_send_commands_behaviour_valid_message():
behaviour = agent.SendCommandsBehaviour() behaviour = agent.SendCommandsBehaviour()
behaviour.agent = agent behaviour.agent = agent
with patch('control_backend.agents.ri_command_agent.SpeechCommand') as MockSpeechCommand: with patch("control_backend.agents.ri_command_agent.SpeechCommand") as MockSpeechCommand:
mock_message = MagicMock() mock_message = MagicMock()
MockSpeechCommand.model_validate.return_value = mock_message MockSpeechCommand.model_validate.return_value = mock_message
await behaviour.run() await behaviour.run()
fake_socket.recv_multipart.assert_awaited() fake_socket.recv_multipart.assert_awaited()
fake_socket.send_json.assert_awaited_with(mock_message) fake_socket.send_json.assert_awaited_with(mock_message.model_dump())
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_send_commands_behaviour_invalid_message(caplog): async def test_send_commands_behaviour_invalid_message(caplog):

View File

@@ -3,60 +3,84 @@ import pytest
from unittest.mock import AsyncMock, MagicMock, patch, ANY from unittest.mock import AsyncMock, MagicMock, patch, ANY
from control_backend.agents.ri_communication_agent import RICommunicationAgent from control_backend.agents.ri_communication_agent import RICommunicationAgent
def fake_json_correct_negototiate_1(): def fake_json_correct_negototiate_1():
return AsyncMock(return_value={ return AsyncMock(
return_value={
"endpoint": "negotiate/ports", "endpoint": "negotiate/ports",
"data": [ "data": [
{"id": "main", "port": 5555, "bind": False}, {"id": "main", "port": 5555, "bind": False},
{"id": "actuation", "port": 5556, "bind": True}, {"id": "actuation", "port": 5556, "bind": True},
]}) ],
}
)
def fake_json_correct_negototiate_2(): def fake_json_correct_negototiate_2():
return AsyncMock(return_value={ return AsyncMock(
return_value={
"endpoint": "negotiate/ports", "endpoint": "negotiate/ports",
"data": [ "data": [
{"id": "main", "port": 5555, "bind": False}, {"id": "main", "port": 5555, "bind": False},
{"id": "actuation", "port": 5557, "bind": True}, {"id": "actuation", "port": 5557, "bind": True},
]}) ],
}
)
def fake_json_correct_negototiate_3(): def fake_json_correct_negototiate_3():
return AsyncMock(return_value={ return AsyncMock(
return_value={
"endpoint": "negotiate/ports", "endpoint": "negotiate/ports",
"data": [ "data": [
{"id": "main", "port": 5555, "bind": True}, {"id": "main", "port": 5555, "bind": True},
{"id": "actuation", "port": 5557, "bind": True}, {"id": "actuation", "port": 5557, "bind": True},
]}) ],
}
)
def fake_json_correct_negototiate_4(): def fake_json_correct_negototiate_4():
# Different port, do bind # Different port, do bind
return AsyncMock(return_value={ return AsyncMock(
return_value={
"endpoint": "negotiate/ports", "endpoint": "negotiate/ports",
"data": [ "data": [
{"id": "main", "port": 4555, "bind": True}, {"id": "main", "port": 4555, "bind": True},
{"id": "actuation", "port": 5557, "bind": True}, {"id": "actuation", "port": 5557, "bind": True},
]}) ],
}
)
def fake_json_correct_negototiate_5(): def fake_json_correct_negototiate_5():
# Different port, dont bind. # Different port, dont bind.
return AsyncMock(return_value={ return AsyncMock(
return_value={
"endpoint": "negotiate/ports", "endpoint": "negotiate/ports",
"data": [ "data": [
{"id": "main", "port": 4555, "bind": False}, {"id": "main", "port": 4555, "bind": False},
{"id": "actuation", "port": 5557, "bind": True}, {"id": "actuation", "port": 5557, "bind": True},
]}) ],
}
)
def fake_json_wrong_negototiate_1(): def fake_json_wrong_negototiate_1():
return AsyncMock(return_value={ return AsyncMock(return_value={"endpoint": "ping", "data": ""})
"endpoint": "ping",
"data": ""})
def fake_json_invalid_id_negototiate(): def fake_json_invalid_id_negototiate():
return AsyncMock(return_value={ return AsyncMock(
return_value={
"endpoint": "negotiate/ports", "endpoint": "negotiate/ports",
"data": [ "data": [
{"id": "banana", "port": 4555, "bind": False}, {"id": "banana", "port": 4555, "bind": False},
{"id": "tomato", "port": 5557, "bind": True}, {"id": "tomato", "port": 5557, "bind": True},
]}) ],
}
)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_setup_creates_socket_and_negotiate_1(monkeypatch): async def test_setup_creates_socket_and_negotiate_1(monkeypatch):
@@ -69,18 +93,21 @@ async def test_setup_creates_socket_and_negotiate_1(monkeypatch):
fake_socket.recv_json = fake_json_correct_negototiate_1() fake_socket.recv_json = fake_json_correct_negototiate_1()
# Mock context.socket to return our fake socket # Mock context.socket to return our fake socket
monkeypatch.setattr("control_backend.agents.ri_communication_agent.context.socket", monkeypatch.setattr(
lambda _: fake_socket) "control_backend.agents.ri_communication_agent.context.socket", lambda _: fake_socket
)
# Mock RICommandAgent agent startup # Mock RICommandAgent agent startup
with patch("control_backend.agents.ri_communication_agent.RICommandAgent", with patch(
autospec=True) as MockCommandAgent: "control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True
) as MockCommandAgent:
fake_agent_instance = MockCommandAgent.return_value fake_agent_instance = MockCommandAgent.return_value
fake_agent_instance.start = AsyncMock() fake_agent_instance.start = AsyncMock()
# --- Act --- # --- Act ---
agent = RICommunicationAgent("test@server", "password", agent = RICommunicationAgent(
address="tcp://localhost:5555", bind=False) "test@server", "password", address="tcp://localhost:5555", bind=False
)
await agent.setup() await agent.setup()
# --- Assert --- # --- Assert ---
@@ -92,11 +119,12 @@ async def test_setup_creates_socket_and_negotiate_1(monkeypatch):
ANY, # Server Name ANY, # Server Name
ANY, # Server Password ANY, # Server Password
address="tcp://*:5556", # derived from the 'port' value in negotiation address="tcp://*:5556", # derived from the 'port' value in negotiation
bind=True bind=True,
) )
# Ensure the agent attached a ListenBehaviour # Ensure the agent attached a ListenBehaviour
assert any(isinstance(b, agent.ListenBehaviour) for b in agent.behaviours) assert any(isinstance(b, agent.ListenBehaviour) for b in agent.behaviours)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_setup_creates_socket_and_negotiate_2(monkeypatch): async def test_setup_creates_socket_and_negotiate_2(monkeypatch):
""" """
@@ -108,18 +136,21 @@ async def test_setup_creates_socket_and_negotiate_2(monkeypatch):
fake_socket.recv_json = fake_json_correct_negototiate_2() fake_socket.recv_json = fake_json_correct_negototiate_2()
# Mock context.socket to return our fake socket # Mock context.socket to return our fake socket
monkeypatch.setattr("control_backend.agents.ri_communication_agent.context.socket", monkeypatch.setattr(
lambda _: fake_socket) "control_backend.agents.ri_communication_agent.context.socket", lambda _: fake_socket
)
# Mock RICommandAgent agent startup # Mock RICommandAgent agent startup
with patch("control_backend.agents.ri_communication_agent.RICommandAgent", with patch(
autospec=True) as MockCommandAgent: "control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True
) as MockCommandAgent:
fake_agent_instance = MockCommandAgent.return_value fake_agent_instance = MockCommandAgent.return_value
fake_agent_instance.start = AsyncMock() fake_agent_instance.start = AsyncMock()
# --- Act --- # --- Act ---
agent = RICommunicationAgent("test@server", "password", agent = RICommunicationAgent(
address="tcp://localhost:5555", bind=False) "test@server", "password", address="tcp://localhost:5555", bind=False
)
await agent.setup() await agent.setup()
# --- Assert --- # --- Assert ---
@@ -131,11 +162,12 @@ async def test_setup_creates_socket_and_negotiate_2(monkeypatch):
ANY, # Server Name ANY, # Server Name
ANY, # Server Password ANY, # Server Password
address="tcp://*:5557", # derived from the 'port' value in negotiation address="tcp://*:5557", # derived from the 'port' value in negotiation
bind=True bind=True,
) )
# Ensure the agent attached a ListenBehaviour # Ensure the agent attached a ListenBehaviour
assert any(isinstance(b, agent.ListenBehaviour) for b in agent.behaviours) assert any(isinstance(b, agent.ListenBehaviour) for b in agent.behaviours)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_setup_creates_socket_and_negotiate_3(monkeypatch, caplog): async def test_setup_creates_socket_and_negotiate_3(monkeypatch, caplog):
""" """
@@ -147,23 +179,25 @@ async def test_setup_creates_socket_and_negotiate_3(monkeypatch, caplog):
fake_socket.recv_json = fake_json_wrong_negototiate_1() fake_socket.recv_json = fake_json_wrong_negototiate_1()
# Mock context.socket to return our fake socket # Mock context.socket to return our fake socket
monkeypatch.setattr("control_backend.agents.ri_communication_agent.context.socket", monkeypatch.setattr(
lambda _: fake_socket) "control_backend.agents.ri_communication_agent.context.socket", lambda _: fake_socket
)
# Mock RICommandAgent agent startup # Mock RICommandAgent agent startup
# We are sending wrong negotiation info to the communication agent, so we should retry and expect a # We are sending wrong negotiation info to the communication agent, so we should retry and expect a
# better response, within a limited time. # better response, within a limited time.
with patch("control_backend.agents.ri_communication_agent.RICommandAgent", with patch(
autospec=True) as MockCommandAgent: "control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True
) as MockCommandAgent:
fake_agent_instance = MockCommandAgent.return_value fake_agent_instance = MockCommandAgent.return_value
fake_agent_instance.start = AsyncMock() fake_agent_instance.start = AsyncMock()
# --- Act --- # --- Act ---
with caplog.at_level("ERROR"): with caplog.at_level("ERROR"):
agent = RICommunicationAgent("test@server", "password", agent = RICommunicationAgent(
address="tcp://localhost:5555", bind=False) "test@server", "password", address="tcp://localhost:5555", bind=False
)
await agent.setup(max_retries=1) await agent.setup(max_retries=1)
# --- Assert --- # --- Assert ---
@@ -177,6 +211,7 @@ async def test_setup_creates_socket_and_negotiate_3(monkeypatch, caplog):
# Ensure the agent did not attach a ListenBehaviour # Ensure the agent did not attach a ListenBehaviour
assert not any(isinstance(b, agent.ListenBehaviour) for b in agent.behaviours) assert not any(isinstance(b, agent.ListenBehaviour) for b in agent.behaviours)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_setup_creates_socket_and_negotiate_4(monkeypatch): async def test_setup_creates_socket_and_negotiate_4(monkeypatch):
""" """
@@ -188,18 +223,21 @@ async def test_setup_creates_socket_and_negotiate_4(monkeypatch):
fake_socket.recv_json = fake_json_correct_negototiate_3() fake_socket.recv_json = fake_json_correct_negototiate_3()
# Mock context.socket to return our fake socket # Mock context.socket to return our fake socket
monkeypatch.setattr("control_backend.agents.ri_communication_agent.context.socket", monkeypatch.setattr(
lambda _: fake_socket) "control_backend.agents.ri_communication_agent.context.socket", lambda _: fake_socket
)
# Mock RICommandAgent agent startup # Mock RICommandAgent agent startup
with patch("control_backend.agents.ri_communication_agent.RICommandAgent", with patch(
autospec=True) as MockCommandAgent: "control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True
) as MockCommandAgent:
fake_agent_instance = MockCommandAgent.return_value fake_agent_instance = MockCommandAgent.return_value
fake_agent_instance.start = AsyncMock() fake_agent_instance.start = AsyncMock()
# --- Act --- # --- Act ---
agent = RICommunicationAgent("test@server", "password", agent = RICommunicationAgent(
address="tcp://localhost:5555", bind=True) "test@server", "password", address="tcp://localhost:5555", bind=True
)
await agent.setup() await agent.setup()
# --- Assert --- # --- Assert ---
@@ -211,11 +249,12 @@ async def test_setup_creates_socket_and_negotiate_4(monkeypatch):
ANY, # Server Name ANY, # Server Name
ANY, # Server Password ANY, # Server Password
address="tcp://*:5557", # derived from the 'port' value in negotiation address="tcp://*:5557", # derived from the 'port' value in negotiation
bind=True bind=True,
) )
# Ensure the agent attached a ListenBehaviour # Ensure the agent attached a ListenBehaviour
assert any(isinstance(b, agent.ListenBehaviour) for b in agent.behaviours) assert any(isinstance(b, agent.ListenBehaviour) for b in agent.behaviours)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_setup_creates_socket_and_negotiate_5(monkeypatch): async def test_setup_creates_socket_and_negotiate_5(monkeypatch):
""" """
@@ -227,18 +266,21 @@ async def test_setup_creates_socket_and_negotiate_5(monkeypatch):
fake_socket.recv_json = fake_json_correct_negototiate_4() fake_socket.recv_json = fake_json_correct_negototiate_4()
# Mock context.socket to return our fake socket # Mock context.socket to return our fake socket
monkeypatch.setattr("control_backend.agents.ri_communication_agent.context.socket", monkeypatch.setattr(
lambda _: fake_socket) "control_backend.agents.ri_communication_agent.context.socket", lambda _: fake_socket
)
# Mock RICommandAgent agent startup # Mock RICommandAgent agent startup
with patch("control_backend.agents.ri_communication_agent.RICommandAgent", with patch(
autospec=True) as MockCommandAgent: "control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True
) as MockCommandAgent:
fake_agent_instance = MockCommandAgent.return_value fake_agent_instance = MockCommandAgent.return_value
fake_agent_instance.start = AsyncMock() fake_agent_instance.start = AsyncMock()
# --- Act --- # --- Act ---
agent = RICommunicationAgent("test@server", "password", agent = RICommunicationAgent(
address="tcp://localhost:5555", bind=False) "test@server", "password", address="tcp://localhost:5555", bind=False
)
await agent.setup() await agent.setup()
# --- Assert --- # --- Assert ---
@@ -250,11 +292,12 @@ async def test_setup_creates_socket_and_negotiate_5(monkeypatch):
ANY, # Server Name ANY, # Server Name
ANY, # Server Password ANY, # Server Password
address="tcp://*:5557", # derived from the 'port' value in negotiation address="tcp://*:5557", # derived from the 'port' value in negotiation
bind=True bind=True,
) )
# Ensure the agent attached a ListenBehaviour # Ensure the agent attached a ListenBehaviour
assert any(isinstance(b, agent.ListenBehaviour) for b in agent.behaviours) assert any(isinstance(b, agent.ListenBehaviour) for b in agent.behaviours)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_setup_creates_socket_and_negotiate_6(monkeypatch): async def test_setup_creates_socket_and_negotiate_6(monkeypatch):
""" """
@@ -266,18 +309,21 @@ async def test_setup_creates_socket_and_negotiate_6(monkeypatch):
fake_socket.recv_json = fake_json_correct_negototiate_5() fake_socket.recv_json = fake_json_correct_negototiate_5()
# Mock context.socket to return our fake socket # Mock context.socket to return our fake socket
monkeypatch.setattr("control_backend.agents.ri_communication_agent.context.socket", monkeypatch.setattr(
lambda _: fake_socket) "control_backend.agents.ri_communication_agent.context.socket", lambda _: fake_socket
)
# Mock RICommandAgent agent startup # Mock RICommandAgent agent startup
with patch("control_backend.agents.ri_communication_agent.RICommandAgent", with patch(
autospec=True) as MockCommandAgent: "control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True
) as MockCommandAgent:
fake_agent_instance = MockCommandAgent.return_value fake_agent_instance = MockCommandAgent.return_value
fake_agent_instance.start = AsyncMock() fake_agent_instance.start = AsyncMock()
# --- Act --- # --- Act ---
agent = RICommunicationAgent("test@server", "password", agent = RICommunicationAgent(
address="tcp://localhost:5555", bind=False) "test@server", "password", address="tcp://localhost:5555", bind=False
)
await agent.setup() await agent.setup()
# --- Assert --- # --- Assert ---
@@ -289,11 +335,12 @@ async def test_setup_creates_socket_and_negotiate_6(monkeypatch):
ANY, # Server Name ANY, # Server Name
ANY, # Server Password ANY, # Server Password
address="tcp://*:5557", # derived from the 'port' value in negotiation address="tcp://*:5557", # derived from the 'port' value in negotiation
bind=True bind=True,
) )
# Ensure the agent attached a ListenBehaviour # Ensure the agent attached a ListenBehaviour
assert any(isinstance(b, agent.ListenBehaviour) for b in agent.behaviours) assert any(isinstance(b, agent.ListenBehaviour) for b in agent.behaviours)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_setup_creates_socket_and_negotiate_7(monkeypatch, caplog): async def test_setup_creates_socket_and_negotiate_7(monkeypatch, caplog):
""" """
@@ -305,23 +352,25 @@ async def test_setup_creates_socket_and_negotiate_7(monkeypatch, caplog):
fake_socket.recv_json = fake_json_invalid_id_negototiate() fake_socket.recv_json = fake_json_invalid_id_negototiate()
# Mock context.socket to return our fake socket # Mock context.socket to return our fake socket
monkeypatch.setattr("control_backend.agents.ri_communication_agent.context.socket", monkeypatch.setattr(
lambda _: fake_socket) "control_backend.agents.ri_communication_agent.context.socket", lambda _: fake_socket
)
# Mock RICommandAgent agent startup # Mock RICommandAgent agent startup
# We are sending wrong negotiation info to the communication agent, so we should retry and expect a # We are sending wrong negotiation info to the communication agent, so we should retry and expect a
# better response, within a limited time. # better response, within a limited time.
with patch("control_backend.agents.ri_communication_agent.RICommandAgent", with patch(
autospec=True) as MockCommandAgent: "control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True
) as MockCommandAgent:
fake_agent_instance = MockCommandAgent.return_value fake_agent_instance = MockCommandAgent.return_value
fake_agent_instance.start = AsyncMock() fake_agent_instance.start = AsyncMock()
# --- Act --- # --- Act ---
with caplog.at_level("WARNING"): with caplog.at_level("WARNING"):
agent = RICommunicationAgent("test@server", "password", agent = RICommunicationAgent(
address="tcp://localhost:5555", bind=False) "test@server", "password", address="tcp://localhost:5555", bind=False
)
await agent.setup(max_retries=1) await agent.setup(max_retries=1)
# --- Assert --- # --- Assert ---
@@ -344,18 +393,21 @@ async def test_setup_creates_socket_and_negotiate_timeout(monkeypatch, caplog):
fake_socket.recv_json = AsyncMock(side_effect=asyncio.TimeoutError) fake_socket.recv_json = AsyncMock(side_effect=asyncio.TimeoutError)
# Mock context.socket to return our fake socket # Mock context.socket to return our fake socket
monkeypatch.setattr("control_backend.agents.ri_communication_agent.context.socket", monkeypatch.setattr(
lambda _: fake_socket) "control_backend.agents.ri_communication_agent.context.socket", lambda _: fake_socket
)
with patch("control_backend.agents.ri_communication_agent.RICommandAgent", with patch(
autospec=True) as MockCommandAgent: "control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True
) as MockCommandAgent:
fake_agent_instance = MockCommandAgent.return_value fake_agent_instance = MockCommandAgent.return_value
fake_agent_instance.start = AsyncMock() fake_agent_instance.start = AsyncMock()
# --- Act --- # --- Act ---
with caplog.at_level("WARNING"): with caplog.at_level("WARNING"):
agent = RICommunicationAgent("test@server", "password", agent = RICommunicationAgent(
address="tcp://localhost:5555", bind=False) "test@server", "password", address="tcp://localhost:5555", bind=False
)
await agent.setup(max_retries=1) await agent.setup(max_retries=1)
# --- Assert --- # --- Assert ---
@@ -368,6 +420,7 @@ async def test_setup_creates_socket_and_negotiate_timeout(monkeypatch, caplog):
# Ensure the agent did not attach a ListenBehaviour # Ensure the agent did not attach a ListenBehaviour
assert not any(isinstance(b, agent.ListenBehaviour) for b in agent.behaviours) assert not any(isinstance(b, agent.ListenBehaviour) for b in agent.behaviours)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_listen_behaviour_ping_correct(caplog): async def test_listen_behaviour_ping_correct(caplog):
fake_socket = AsyncMock() fake_socket = AsyncMock()
@@ -389,6 +442,7 @@ async def test_listen_behaviour_ping_correct(caplog):
fake_socket.recv_json.assert_awaited() fake_socket.recv_json.assert_awaited()
assert "Received message" in caplog.text assert "Received message" in caplog.text
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_listen_behaviour_ping_wrong_endpoint(caplog): async def test_listen_behaviour_ping_wrong_endpoint(caplog):
""" """
@@ -398,12 +452,15 @@ async def test_listen_behaviour_ping_wrong_endpoint(caplog):
fake_socket.send_json = AsyncMock() fake_socket.send_json = AsyncMock()
# This is a message for the wrong endpoint >:( # This is a message for the wrong endpoint >:(
fake_socket.recv_json = AsyncMock(return_value={ fake_socket.recv_json = AsyncMock(
return_value={
"endpoint": "negotiate/ports", "endpoint": "negotiate/ports",
"data": [ "data": [
{"id": "main", "port": 5555, "bind": False}, {"id": "main", "port": 5555, "bind": False},
{"id": "actuation", "port": 5556, "bind": True}, {"id": "actuation", "port": 5556, "bind": True},
]}) ],
}
)
agent = RICommunicationAgent("test@server", "password") agent = RICommunicationAgent("test@server", "password")
agent.req_socket = fake_socket agent.req_socket = fake_socket
@@ -415,11 +472,11 @@ async def test_listen_behaviour_ping_wrong_endpoint(caplog):
with caplog.at_level("INFO"): with caplog.at_level("INFO"):
await behaviour.run() await behaviour.run()
assert "Received message with topic different than ping, while ping expected." in caplog.text assert "Received message with topic different than ping, while ping expected." in caplog.text
fake_socket.send_json.assert_awaited() fake_socket.send_json.assert_awaited()
fake_socket.recv_json.assert_awaited() fake_socket.recv_json.assert_awaited()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_listen_behaviour_timeout(caplog): async def test_listen_behaviour_timeout(caplog):
fake_socket = AsyncMock() fake_socket = AsyncMock()
@@ -438,6 +495,7 @@ async def test_listen_behaviour_timeout(caplog):
assert "No ping retrieved in 3 seconds" in caplog.text assert "No ping retrieved in 3 seconds" in caplog.text
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_listen_behaviour_ping_no_endpoint(caplog): async def test_listen_behaviour_ping_no_endpoint(caplog):
""" """
@@ -447,9 +505,11 @@ async def test_listen_behaviour_ping_no_endpoint(caplog):
fake_socket.send_json = AsyncMock() fake_socket.send_json = AsyncMock()
# This is a message without endpoint >:( # This is a message without endpoint >:(
fake_socket.recv_json = AsyncMock(return_value={ fake_socket.recv_json = AsyncMock(
return_value={
"data": "I dont have an endpoint >:)", "data": "I dont have an endpoint >:)",
}) }
)
agent = RICommunicationAgent("test@server", "password") agent = RICommunicationAgent("test@server", "password")
agent.req_socket = fake_socket agent.req_socket = fake_socket
@@ -465,6 +525,7 @@ async def test_listen_behaviour_ping_no_endpoint(caplog):
fake_socket.send_json.assert_awaited() fake_socket.send_json.assert_awaited()
fake_socket.recv_json.assert_awaited() fake_socket.recv_json.assert_awaited()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_setup_unexpected_exception(monkeypatch, caplog): async def test_setup_unexpected_exception(monkeypatch, caplog):
fake_socket = MagicMock() fake_socket = MagicMock()
@@ -473,12 +534,12 @@ async def test_setup_unexpected_exception(monkeypatch, caplog):
fake_socket.recv_json = AsyncMock(side_effect=Exception("boom!")) fake_socket.recv_json = AsyncMock(side_effect=Exception("boom!"))
monkeypatch.setattr( monkeypatch.setattr(
"control_backend.agents.ri_communication_agent.context.socket", "control_backend.agents.ri_communication_agent.context.socket", lambda _: fake_socket
lambda _: fake_socket
) )
agent = RICommunicationAgent("test@server", "password", agent = RICommunicationAgent(
address="tcp://localhost:5555", bind=False) "test@server", "password", address="tcp://localhost:5555", bind=False
)
with caplog.at_level("ERROR"): with caplog.at_level("ERROR"):
await agent.setup(max_retries=1) await agent.setup(max_retries=1)
@@ -486,6 +547,7 @@ async def test_setup_unexpected_exception(monkeypatch, caplog):
# Ensure that the error was logged # Ensure that the error was logged
assert "Unexpected error during negotiation: boom!" in caplog.text assert "Unexpected error during negotiation: boom!" in caplog.text
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_setup_unpacking_exception(monkeypatch, caplog): async def test_setup_unpacking_exception(monkeypatch, caplog):
# --- Arrange --- # --- Arrange ---
@@ -493,24 +555,27 @@ async def test_setup_unpacking_exception(monkeypatch, caplog):
fake_socket.send_json = AsyncMock() fake_socket.send_json = AsyncMock()
# Make recv_json return malformed negotiation data to trigger unpacking exception # Make recv_json return malformed negotiation data to trigger unpacking exception
malformed_data = {"endpoint": "negotiate/ports", malformed_data = {
"data": [ {"id": "main"} ]} # missing 'port' and 'bind' "endpoint": "negotiate/ports",
"data": [{"id": "main"}],
} # missing 'port' and 'bind'
fake_socket.recv_json = AsyncMock(return_value=malformed_data) fake_socket.recv_json = AsyncMock(return_value=malformed_data)
# Patch context.socket # Patch context.socket
monkeypatch.setattr( monkeypatch.setattr(
"control_backend.agents.ri_communication_agent.context.socket", "control_backend.agents.ri_communication_agent.context.socket", lambda _: fake_socket
lambda _: fake_socket
) )
# Patch RICommandAgent so it won't actually start # Patch RICommandAgent so it won't actually start
with patch("control_backend.agents.ri_communication_agent.RICommandAgent", with patch(
autospec=True) as MockCommandAgent: "control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True
) as MockCommandAgent:
fake_agent_instance = MockCommandAgent.return_value fake_agent_instance = MockCommandAgent.return_value
fake_agent_instance.start = AsyncMock() fake_agent_instance.start = AsyncMock()
agent = RICommunicationAgent("test@server", "password", agent = RICommunicationAgent(
address="tcp://localhost:5555", bind=False) "test@server", "password", address="tcp://localhost:5555", bind=False
)
# --- Act & Assert --- # --- Act & Assert ---
with caplog.at_level("ERROR"): with caplog.at_level("ERROR"):

View File

@@ -6,6 +6,7 @@ from unittest.mock import MagicMock
from control_backend.api.v1.endpoints import command from control_backend.api.v1.endpoints import command
from control_backend.schemas.ri_message import SpeechCommand from control_backend.schemas.ri_message import SpeechCommand
@pytest.fixture @pytest.fixture
def app(): def app():
""" """
@@ -49,7 +50,7 @@ def test_receive_command_endpoint(client, app):
assert sent_data[0] == b"command" assert sent_data[0] == b"command"
# Check JSON encoding roughly matches # Check JSON encoding roughly matches
assert isinstance(sent_data[1], SpeechCommand) assert isinstance(SpeechCommand.model_validate_json(sent_data[1].decode()), SpeechCommand)
def test_receive_command_invalid_payload(client): def test_receive_command_invalid_payload(client):

View File

@@ -2,12 +2,15 @@ import pytest
from control_backend.schemas.ri_message import RIMessage, RIEndpoint, SpeechCommand from control_backend.schemas.ri_message import RIMessage, RIEndpoint, SpeechCommand
from pydantic import ValidationError from pydantic import ValidationError
def valid_command_1(): def valid_command_1():
return SpeechCommand(data="Hallo?") return SpeechCommand(data="Hallo?")
def invalid_command_1(): def invalid_command_1():
return RIMessage(endpoint=RIEndpoint.PING, data="Hello again.") return RIMessage(endpoint=RIEndpoint.PING, data="Hello again.")
def test_valid_speech_command_1(): def test_valid_speech_command_1():
command = valid_command_1() command = valid_command_1()
try: try:
@@ -31,5 +34,3 @@ def test_invalid_speech_command_1():
assert False assert False
except ValidationError: except ValidationError:
assert passed_ri_message_validation assert passed_ri_message_validation