fix: ruff checks is now in order:)

ref: N25B-205
This commit is contained in:
Björn Otgaar
2025-10-30 16:41:35 +01:00
parent af3e4ae56a
commit 30453be4b2
10 changed files with 117 additions and 84 deletions

View File

@@ -1,8 +1,9 @@
import json
import logging
import zmq
from spade.agent import Agent
from spade.behaviour import CyclicBehaviour
import zmq
from control_backend.core.config import settings
from control_backend.core.zmq_context import context

View File

@@ -1,23 +1,21 @@
import asyncio
import json
import logging
from spade.agent import Agent
from spade.behaviour import CyclicBehaviour
import zmq
import zmq.asyncio
from spade.agent import Agent
from spade.behaviour import CyclicBehaviour
from control_backend.agents.ri_command_agent import RICommandAgent
from control_backend.core.config import settings
from control_backend.core.zmq_context import context
from control_backend.schemas.ri_message import RIMessage
from control_backend.agents.ri_command_agent import RICommandAgent
logger = logging.getLogger(__name__)
class RICommunicationAgent(Agent):
_pub_socket: zmq.asyncio.Socket
_pub_socket: zmq.asyncio.Socket
req_socket: zmq.asyncio.Socket | None
_address = ""
_bind = True
@@ -32,7 +30,6 @@ class RICommunicationAgent(Agent):
verify_security: bool = False,
address="tcp://localhost:0000",
bind=False,
):
super().__init__(jid, password, port, verify_security)
self._address = address
@@ -54,9 +51,10 @@ class RICommunicationAgent(Agent):
await asyncio.wait_for(
self.agent.req_socket.send_json(message), timeout=seconds_to_wait_total / 2
)
except TimeoutError as e:
except TimeoutError:
logger.debug(
f"Waited too long to send message - we probably dont have any receivers... but let's check!"
"Waited too long to send message - "
"we probably dont have any receivers... but let's check!"
)
# Wait up to three seconds for a reply:)
@@ -67,8 +65,11 @@ class RICommunicationAgent(Agent):
)
# We didnt get a reply :(
except TimeoutError as e:
logger.info(f"No ping back retrieved in {seconds_to_wait_total/2} seconds totalling {seconds_to_wait_total} of time, killing myself (or maybe just laying low).")
except TimeoutError:
logger.info(
f"No ping back retrieved in {seconds_to_wait_total / 2} seconds totalling"
f"{seconds_to_wait_total} of time, killing myself (or maybe just laying low)."
)
# TODO: Send event to UI letting know that we've lost connection
topic = b"ping"
data = json.dumps(False).encode()
@@ -95,8 +96,7 @@ class RICommunicationAgent(Agent):
"Received message with topic different than ping, while ping expected."
)
async def setup_req_socket(self, force = False):
async def setup_req_socket(self, force=False):
"""
Sets up request socket for communication agent.
"""
@@ -107,7 +107,6 @@ class RICommunicationAgent(Agent):
else:
self.req_socket.connect(self._address)
async def setup(self, max_retries: int = 5):
"""
Try to setup the communication agent, we have 5 retries in case we dont have a response yet.
@@ -116,15 +115,14 @@ class RICommunicationAgent(Agent):
# Bind request socket
await self.setup_req_socket()
retries = 0
# Let's try a certain amount of times before failing connection
while retries < max_retries:
# Make sure the socket is properly setup.
if self.req_socket is None:
continue
# Send our message and receive one back:)
message = {"endpoint": "negotiate/ports", "data": {}}
await self.req_socket.send_json(message)
@@ -132,7 +130,7 @@ class RICommunicationAgent(Agent):
try:
received_message = await asyncio.wait_for(self.req_socket.recv_json(), timeout=20.0)
except asyncio.TimeoutError:
except TimeoutError:
logger.warning(
"No connection established in 20 seconds (attempt %d/%d)",
retries + 1,

View File

@@ -1,18 +1,15 @@
from fastapi import APIRouter, Request
from fastapi.responses import JSONResponse, StreamingResponse
import logging
import asyncio
import zmq.asyncio
import json
import datetime
import logging
import zmq.asyncio
from fastapi import APIRouter, Request
from fastapi.responses import StreamingResponse
from zmq.asyncio import Socket
from control_backend.core.zmq_context import context
from control_backend.core.config import settings
from control_backend.schemas.ri_message import SpeechCommand, RIEndpoint
from control_backend.core.zmq_context import context
from control_backend.schemas.ri_message import SpeechCommand
logger = logging.getLogger(__name__)
@@ -24,7 +21,7 @@ async def receive_command(command: SpeechCommand, request: Request):
# Validate and retrieve data.
SpeechCommand.model_validate(command)
topic = b"command"
pub_socket : Socket = request.app.state.internal_comm_socket
pub_socket: Socket = request.app.state.internal_comm_socket
pub_socket.send_multipart([topic, command.model_dump_json().encode()])
return {"status": "Command received"}
@@ -38,6 +35,7 @@ async def ping(request: Request):
@router.get("/ping_stream")
async def ping_stream(request: Request):
"""Stream live updates whenever the device state changes."""
async def event_stream():
# Set up internal socket to receive ping updates
logger.debug("Ping stream router event stream entered.")
@@ -47,7 +45,7 @@ async def ping_stream(request: Request):
sub_socket.setsockopt(zmq.SUBSCRIBE, b"ping")
connected = False
ping_frequency = 1 # How many seconds between ping attempts
ping_frequency = 1 # How many seconds between ping attempts
# Even though its most likely the updates should alternate
# So, True - False - True - False for connectivity.
@@ -55,21 +53,21 @@ async def ping_stream(request: Request):
while True:
logger.debug("Ping stream entered listening ")
try:
topic, body = await asyncio.wait_for(sub_socket.recv_multipart(), timeout=ping_frequency)
topic, body = await asyncio.wait_for(
sub_socket.recv_multipart(), timeout=ping_frequency
)
logger.debug("got ping change in ping_stream router")
connected = json.loads(body)
except TimeoutError as e:
except TimeoutError:
await asyncio.sleep(0.1)
# Stop if client disconnected
if await request.is_disconnected():
print("Client disconnected from SSE")
break
logger.debug(f"Yielded new connection event in robot ping router: {str(connected)}")
falseJson = json.dumps(connected)
yield (f"data: {falseJson}\n\n")
return StreamingResponse(event_stream(), media_type="text/event-stream")
return StreamingResponse(event_stream(), media_type="text/event-stream")

View File

@@ -1,6 +1,6 @@
from fastapi.routing import APIRouter
from control_backend.api.v1.endpoints import message, sse, robot
from control_backend.api.v1.endpoints import message, robot, sse
api_router = APIRouter()
@@ -8,4 +8,4 @@ api_router.include_router(message.router, tags=["Messages"])
api_router.include_router(sse.router, tags=["SSE"])
api_router.include_router(robot.router, prefix="/robot", tags=["Pings", "Commands"])
api_router.include_router(robot.router, prefix="/robot", tags=["Pings", "Commands"])

View File

@@ -8,9 +8,10 @@ import zmq
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from control_backend.agents.bdi.bdi_core import BDICoreAgent
# Internal imports
from control_backend.agents.ri_communication_agent import RICommunicationAgent
from control_backend.agents.bdi.bdi_core import BDICoreAgent
from control_backend.api.v1.router import api_router
from control_backend.core.config import settings
from control_backend.core.zmq_context import context
@@ -32,12 +33,14 @@ async def lifespan(app: FastAPI):
# Initiate agents
ri_communication_agent = RICommunicationAgent(
jid=settings.agent_settings.ri_communication_agent_name + "@" + settings.agent_settings.host,
jid=settings.agent_settings.ri_communication_agent_name
+ "@"
+ settings.agent_settings.host,
password=settings.agent_settings.ri_communication_agent_name,
pub_socket=internal_comm_socket,
address="tcp://*:5555",
bind=True,
)
)
await ri_communication_agent.start()
bdi_core = BDICoreAgent(

View File

@@ -1,7 +1,7 @@
from enum import Enum
from typing import Any, Literal
from typing import Any
from pydantic import BaseModel, Field, ValidationError
from pydantic import BaseModel
class RIEndpoint(str, Enum):

View File

@@ -1,10 +1,10 @@
import asyncio
import zmq
import json
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
import zmq
from control_backend.agents.ri_command_agent import RICommandAgent
from control_backend.schemas.ri_message import SpeechCommand
@pytest.mark.asyncio

View File

@@ -1,6 +1,8 @@
import asyncio
from unittest.mock import ANY, AsyncMock, MagicMock, patch
import pytest
from unittest.mock import AsyncMock, MagicMock, patch, ANY
from control_backend.agents.ri_communication_agent import RICommunicationAgent
@@ -109,7 +111,11 @@ async def test_setup_creates_socket_and_negotiate_1(monkeypatch):
# --- Act ---
agent = RICommunicationAgent(
"test@server", "password", pub_socket=fake_pub_socket, address="tcp://localhost:5555", bind=False
"test@server",
"password",
pub_socket=fake_pub_socket,
address="tcp://localhost:5555",
bind=False,
)
await agent.setup()
@@ -153,7 +159,11 @@ async def test_setup_creates_socket_and_negotiate_2(monkeypatch):
# --- Act ---
agent = RICommunicationAgent(
"test@server", "password", pub_socket=fake_pub_socket, address="tcp://localhost:5555", bind=False
"test@server",
"password",
pub_socket=fake_pub_socket,
address="tcp://localhost:5555",
bind=False,
)
await agent.setup()
@@ -189,8 +199,8 @@ async def test_setup_creates_socket_and_negotiate_3(monkeypatch, caplog):
# Mock RICommandAgent agent startup
# We are sending wrong negotiation info to the communication agent, so we should retry and expect a
# better response, within a limited time.
# We are sending wrong negotiation info to the communication agent,
# so we should retry and expect a better response, within a limited time.
with patch(
"control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True
) as MockCommandAgent:
@@ -200,7 +210,11 @@ async def test_setup_creates_socket_and_negotiate_3(monkeypatch, caplog):
# --- Act ---
with caplog.at_level("ERROR"):
agent = RICommunicationAgent(
"test@server", "password", pub_socket=fake_pub_socket, address="tcp://localhost:5555", bind=False
"test@server",
"password",
pub_socket=fake_pub_socket,
address="tcp://localhost:5555",
bind=False,
)
await agent.setup(max_retries=1)
@@ -240,7 +254,11 @@ async def test_setup_creates_socket_and_negotiate_4(monkeypatch):
fake_pub_socket = AsyncMock()
# --- Act ---
agent = RICommunicationAgent(
"test@server", "password", pub_socket=fake_pub_socket, address="tcp://localhost:5555", bind=True
"test@server",
"password",
pub_socket=fake_pub_socket,
address="tcp://localhost:5555",
bind=True,
)
await agent.setup()
@@ -283,7 +301,11 @@ async def test_setup_creates_socket_and_negotiate_5(monkeypatch):
fake_pub_socket = AsyncMock()
# --- Act ---
agent = RICommunicationAgent(
"test@server", "password", pub_socket=fake_pub_socket, address="tcp://localhost:5555", bind=False
"test@server",
"password",
pub_socket=fake_pub_socket,
address="tcp://localhost:5555",
bind=False,
)
await agent.setup()
@@ -326,7 +348,11 @@ async def test_setup_creates_socket_and_negotiate_6(monkeypatch):
fake_pub_socket = AsyncMock()
# --- Act ---
agent = RICommunicationAgent(
"test@server", "password", pub_socket=fake_pub_socket, address="tcp://localhost:5555", bind=False
"test@server",
"password",
pub_socket=fake_pub_socket,
address="tcp://localhost:5555",
bind=False,
)
await agent.setup()
@@ -362,8 +388,8 @@ async def test_setup_creates_socket_and_negotiate_7(monkeypatch, caplog):
# Mock RICommandAgent agent startup
# We are sending wrong negotiation info to the communication agent, so we should retry and expect a
# better response, within a limited time.
# We are sending wrong negotiation info to the communication agent,
# so we should retry and expect a etter response, within a limited time.
with patch(
"control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True
) as MockCommandAgent:
@@ -374,7 +400,11 @@ async def test_setup_creates_socket_and_negotiate_7(monkeypatch, caplog):
# --- Act ---
with caplog.at_level("WARNING"):
agent = RICommunicationAgent(
"test@server", "password", pub_socket=fake_pub_socket, address="tcp://localhost:5555", bind=False
"test@server",
"password",
pub_socket=fake_pub_socket,
address="tcp://localhost:5555",
bind=False,
)
await agent.setup(max_retries=1)
@@ -408,11 +438,15 @@ async def test_setup_creates_socket_and_negotiate_timeout(monkeypatch, caplog):
fake_agent_instance = MockCommandAgent.return_value
fake_agent_instance.start = AsyncMock()
fake_pub_socket = AsyncMock()
# --- Act ---
with caplog.at_level("WARNING"):
agent = RICommunicationAgent(
"test@server", "password", pub_socket=fake_pub_socket, address="tcp://localhost:5555", bind=False
"test@server",
"password",
pub_socket=fake_pub_socket,
address="tcp://localhost:5555",
bind=False,
)
await agent.setup(max_retries=1)
@@ -544,13 +578,16 @@ async def test_setup_unexpected_exception(monkeypatch, caplog):
# Simulate unexpected exception during recv_json()
fake_socket.recv_json = AsyncMock(side_effect=Exception("boom!"))
monkeypatch.setattr(
"control_backend.agents.ri_communication_agent.context.socket", lambda _: fake_socket
)
agent = RICommunicationAgent(
"test@server", "password", pub_socket=fake_pub_socket, address="tcp://localhost:5555", bind=False
"test@server",
"password",
pub_socket=fake_pub_socket,
address="tcp://localhost:5555",
bind=False,
)
with caplog.at_level("ERROR"):
@@ -587,7 +624,11 @@ async def test_setup_unpacking_exception(monkeypatch, caplog):
fake_pub_socket = AsyncMock()
agent = RICommunicationAgent(
"test@server", "password", pub_socket=fake_pub_socket, address="tcp://localhost:5555", bind=False
"test@server",
"password",
pub_socket=fake_pub_socket,
address="tcp://localhost:5555",
bind=False,
)
# --- Act & Assert ---

View File

@@ -1,7 +1,8 @@
from unittest.mock import MagicMock
import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
from unittest.mock import MagicMock
from control_backend.api.v1.endpoints import robot
from control_backend.schemas.ri_message import SpeechCommand

View File

@@ -1,7 +1,8 @@
import pytest
from control_backend.schemas.ri_message import RIMessage, RIEndpoint, SpeechCommand
from pydantic import ValidationError
from control_backend.schemas.ri_message import RIEndpoint, RIMessage, SpeechCommand
def valid_command_1():
return SpeechCommand(data="Hallo?")
@@ -13,24 +14,14 @@ def invalid_command_1():
def test_valid_speech_command_1():
command = valid_command_1()
try:
RIMessage.model_validate(command)
SpeechCommand.model_validate(command)
assert True
except ValidationError:
assert False
RIMessage.model_validate(command)
SpeechCommand.model_validate(command)
assert True
def test_invalid_speech_command_1():
command = invalid_command_1()
passed_ri_message_validation = False
try:
# Should succeed, still.
RIMessage.model_validate(command)
passed_ri_message_validation = True
# Should fail.
RIMessage.model_validate(command)
with pytest.raises(ValidationError):
SpeechCommand.model_validate(command)
assert False
except ValidationError:
assert passed_ri_message_validation
assert True