From b92471ff1c85448251f719d9a662e89718c0be81 Mon Sep 17 00:00:00 2001 From: Kasper Marinus Date: Thu, 30 Oct 2025 11:40:14 +0100 Subject: [PATCH] refactor: ZMQ context and proxy Use ZMQ's global context instance and setup an XPUB/XSUB proxy intermediary to allow for easier multi-pubs. close: N25B-217 --- .../agents/ri_command_agent.py | 9 +++-- .../agents/ri_communication_agent.py | 15 ++++---- .../transcription/transcription_agent.py | 6 +-- src/control_backend/agents/vad_agent.py | 5 +-- .../api/v1/endpoints/command.py | 14 ++++--- .../api/v1/endpoints/message.py | 10 +++-- src/control_backend/core/config.py | 5 ++- src/control_backend/core/zmq_context.py | 3 -- src/control_backend/main.py | 36 +++++++++++------- .../api/endpoints/test_command_endpoint.py | 38 +++++++++++++++++-- 10 files changed, 92 insertions(+), 49 deletions(-) delete mode 100644 src/control_backend/core/zmq_context.py diff --git a/src/control_backend/agents/ri_command_agent.py b/src/control_backend/agents/ri_command_agent.py index 01fc824..0dcc981 100644 --- a/src/control_backend/agents/ri_command_agent.py +++ b/src/control_backend/agents/ri_command_agent.py @@ -1,11 +1,12 @@ import json import logging + +import zmq from spade.agent import Agent from spade.behaviour import CyclicBehaviour -import zmq +from zmq.asyncio import Context from control_backend.core.config import settings -from control_backend.core.zmq_context import context from control_backend.schemas.ri_message import SpeechCommand logger = logging.getLogger(__name__) @@ -55,6 +56,8 @@ class RICommandAgent(Agent): """ logger.info("Setting up %s", self.jid) + context = Context.instance() + # To the robot self.pubsocket = context.socket(zmq.PUB) if self.bind: @@ -64,7 +67,7 @@ class RICommandAgent(Agent): # Receive internal topics regarding commands self.subsocket = context.socket(zmq.SUB) - self.subsocket.connect(settings.zmq_settings.internal_comm_address) + self.subsocket.connect(settings.zmq_settings.internal_sub_address) 
self.subsocket.setsockopt(zmq.SUBSCRIBE, b"command") # Add behaviour to our agent diff --git a/src/control_backend/agents/ri_communication_agent.py b/src/control_backend/agents/ri_communication_agent.py index 504c707..638b967 100644 --- a/src/control_backend/agents/ri_communication_agent.py +++ b/src/control_backend/agents/ri_communication_agent.py @@ -1,14 +1,13 @@ import asyncio -import json import logging + +import zmq from spade.agent import Agent from spade.behaviour import CyclicBehaviour -import zmq +from zmq.asyncio import Context -from control_backend.core.config import settings -from control_backend.core.zmq_context import context -from control_backend.schemas.message import Message from control_backend.agents.ri_command_agent import RICommandAgent +from control_backend.core.config import settings logger = logging.getLogger(__name__) @@ -47,7 +46,7 @@ class RICommunicationAgent(Agent): message = await asyncio.wait_for(self.agent.req_socket.recv_json(), timeout=3.0) # We didnt get a reply :( - except asyncio.TimeoutError as e: + except TimeoutError: logger.info("No ping retrieved in 3 seconds, killing myself.") self.kill() @@ -75,7 +74,7 @@ class RICommunicationAgent(Agent): # Let's try a certain amount of times before failing connection while retries < max_retries: # Bind request socket - self.req_socket = context.socket(zmq.REQ) + self.req_socket = Context.instance().socket(zmq.REQ) if self._bind: self.req_socket.bind(self._address) else: @@ -88,7 +87,7 @@ class RICommunicationAgent(Agent): try: received_message = await asyncio.wait_for(self.req_socket.recv_json(), timeout=20.0) - except asyncio.TimeoutError: + except TimeoutError: logger.warning( "No connection established in 20 seconds (attempt %d/%d)", retries + 1, diff --git a/src/control_backend/agents/transcription/transcription_agent.py b/src/control_backend/agents/transcription/transcription_agent.py index a2c8e2b..530bd68 100644 --- 
a/src/control_backend/agents/transcription/transcription_agent.py +++ b/src/control_backend/agents/transcription/transcription_agent.py @@ -10,7 +10,6 @@ from spade.message import Message from control_backend.agents.transcription.speech_recognizer import SpeechRecognizer from control_backend.core.config import settings -from control_backend.core.zmq_context import context as zmq_context logger = logging.getLogger(__name__) @@ -47,7 +46,8 @@ class TranscriptionAgent(Agent): """Share a transcription to the other agents that depend on it.""" receiver_jids = [ settings.agent_settings.text_belief_extractor_agent_name - + '@' + settings.agent_settings.host, + + "@" + + settings.agent_settings.host, ] # Set message receivers here for receiver_jid in receiver_jids: @@ -68,7 +68,7 @@ class TranscriptionAgent(Agent): return await super().stop() def _connect_audio_in_socket(self): - self.audio_in_socket = zmq_context.socket(zmq.SUB) + self.audio_in_socket = azmq.Context.instance().socket(zmq.SUB) self.audio_in_socket.setsockopt_string(zmq.SUBSCRIBE, "") self.audio_in_socket.connect(self.audio_in_address) diff --git a/src/control_backend/agents/vad_agent.py b/src/control_backend/agents/vad_agent.py index a228135..f16abf4 100644 --- a/src/control_backend/agents/vad_agent.py +++ b/src/control_backend/agents/vad_agent.py @@ -9,7 +9,6 @@ from spade.behaviour import CyclicBehaviour from control_backend.agents.transcription import TranscriptionAgent from control_backend.core.config import settings -from control_backend.core.zmq_context import context as zmq_context logger = logging.getLogger(__name__) @@ -121,7 +120,7 @@ class VADAgent(Agent): return await super().stop() def _connect_audio_in_socket(self): - self.audio_in_socket = zmq_context.socket(zmq.SUB) + self.audio_in_socket = azmq.Context.instance().socket(zmq.SUB) self.audio_in_socket.setsockopt_string(zmq.SUBSCRIBE, "") if self.audio_in_bind: self.audio_in_socket.bind(self.audio_in_address) @@ -132,7 +131,7 @@ class 
VADAgent(Agent): def _connect_audio_out_socket(self) -> int | None: """Returns the port bound, or None if binding failed.""" try: - self.audio_out_socket = zmq_context.socket(zmq.PUB) + self.audio_out_socket = azmq.Context.instance().socket(zmq.PUB) return self.audio_out_socket.bind_to_random_port("tcp://*", max_tries=100) except zmq.ZMQBindError: logger.error("Failed to bind an audio output socket after 100 tries.") diff --git a/src/control_backend/api/v1/endpoints/command.py b/src/control_backend/api/v1/endpoints/command.py index badaf90..88c859b 100644 --- a/src/control_backend/api/v1/endpoints/command.py +++ b/src/control_backend/api/v1/endpoints/command.py @@ -1,9 +1,11 @@ -from fastapi import APIRouter, Request import logging -from zmq import Socket +import zmq +from fastapi import APIRouter, Request +from zmq.asyncio import Context -from control_backend.schemas.ri_message import SpeechCommand, RIEndpoint +from control_backend.core.config import settings +from control_backend.schemas.ri_message import SpeechCommand logger = logging.getLogger(__name__) @@ -15,8 +17,8 @@ async def receive_command(command: SpeechCommand, request: Request): # Validate and retrieve data. 
SpeechCommand.model_validate(command) topic = b"command" - pub_socket: Socket = request.app.state.internal_comm_socket - pub_socket.send_multipart([topic, command.model_dump_json().encode()]) - + pub_socket = Context.instance().socket(zmq.PUB) + pub_socket.connect(settings.zmq_settings.internal_pub_address) + await pub_socket.send_multipart([topic, command.model_dump_json().encode()]) return {"status": "Command received"} diff --git a/src/control_backend/api/v1/endpoints/message.py b/src/control_backend/api/v1/endpoints/message.py index 1053c3c..1a58377 100644 --- a/src/control_backend/api/v1/endpoints/message.py +++ b/src/control_backend/api/v1/endpoints/message.py @@ -1,8 +1,10 @@ import logging +import zmq from fastapi import APIRouter, Request -from zmq import Socket +from zmq.asyncio import Context +from control_backend.core.config import settings from control_backend.schemas.message import Message logger = logging.getLogger(__name__) @@ -17,8 +19,8 @@ async def receive_message(message: Message, request: Request): topic = b"message" body = message.model_dump_json().encode("utf-8") - pub_socket: Socket = request.app.state.internal_comm_socket - - pub_socket.send_multipart([topic, body]) + pub_socket = Context.instance().socket(zmq.PUB) + pub_socket.connect(settings.zmq_settings.internal_pub_address) # connect: main.py binds this address + await pub_socket.send_multipart([topic, body]) return {"status": "Message received"} diff --git a/src/control_backend/core/config.py b/src/control_backend/core/config.py index 5e4b764..8de2403 100644 --- a/src/control_backend/core/config.py +++ b/src/control_backend/core/config.py @@ -3,7 +3,8 @@ from pydantic_settings import BaseSettings, SettingsConfigDict class ZMQSettings(BaseModel): - internal_comm_address: str = "tcp://localhost:5560" + internal_pub_address: str = "tcp://localhost:5560" + internal_sub_address: str = "tcp://localhost:5561" class AgentSettings(BaseModel): @@ -24,6 +25,7 @@ class LLMSettings(BaseModel): local_llm_url: str = 
"http://localhost:1234/v1/chat/completions" local_llm_model: str = "openai/gpt-oss-20b" + class Settings(BaseSettings): app_title: str = "PepperPlus" @@ -37,4 +39,5 @@ class Settings(BaseSettings): model_config = SettingsConfigDict(env_file=".env") + settings = Settings() diff --git a/src/control_backend/core/zmq_context.py b/src/control_backend/core/zmq_context.py deleted file mode 100644 index a74544f..0000000 --- a/src/control_backend/core/zmq_context.py +++ /dev/null @@ -1,3 +0,0 @@ -from zmq.asyncio import Context - -context = Context() diff --git a/src/control_backend/main.py b/src/control_backend/main.py index d3588ea..1543882 100644 --- a/src/control_backend/main.py +++ b/src/control_backend/main.py @@ -7,17 +7,18 @@ import logging import zmq from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware +from zmq.asyncio import Context + +from control_backend.agents.bdi.bdi_core import BDICoreAgent +from control_backend.agents.bdi.text_extractor import TBeliefExtractor +from control_backend.agents.belief_collector.belief_collector import BeliefCollectorAgent +from control_backend.agents.llm.llm import LLMAgent # Internal imports from control_backend.agents.ri_communication_agent import RICommunicationAgent -from control_backend.agents.bdi.bdi_core import BDICoreAgent from control_backend.agents.vad_agent import VADAgent -from control_backend.agents.llm.llm import LLMAgent -from control_backend.agents.bdi.text_extractor import TBeliefExtractor -from control_backend.agents.belief_collector.belief_collector import BeliefCollectorAgent from control_backend.api.v1.router import api_router from control_backend.core.config import settings -from control_backend.core.zmq_context import context logger = logging.getLogger(__name__) logging.basicConfig(level=logging.DEBUG) @@ -28,12 +29,17 @@ async def lifespan(app: FastAPI): logger.info("%s starting up.", app.title) # Initiate sockets - internal_comm_socket = context.socket(zmq.PUB) - 
internal_comm_address = settings.zmq_settings.internal_comm_address - internal_comm_socket.bind(internal_comm_address) - app.state.internal_comm_socket = internal_comm_socket - logger.info("Internal publishing socket bound to %s", internal_comm_socket) + context = Context.instance() + internal_pub_socket = context.socket(zmq.XPUB) + internal_pub_socket.bind(settings.zmq_settings.internal_sub_address) # subscribers connect here + logger.debug("Internal publishing socket bound to %s", internal_pub_socket) + + internal_sub_socket = context.socket(zmq.XSUB) + internal_sub_socket.bind(settings.zmq_settings.internal_pub_address) # publishers connect here + logger.debug("Internal subscribing socket bound to %s", internal_sub_socket) + import threading # zmq.proxy() blocks, so it must run off the event loop + threading.Thread(target=zmq.proxy, args=(internal_sub_socket, internal_pub_socket), daemon=True).start() # Initiate agents ri_communication_agent = RICommunicationAgent( @@ -45,26 +51,28 @@ async def lifespan(app: FastAPI): await ri_communication_agent.start() llm_agent = LLMAgent( - settings.agent_settings.llm_agent_name + '@' + settings.agent_settings.host, + settings.agent_settings.llm_agent_name + "@" + settings.agent_settings.host, settings.agent_settings.llm_agent_name, ) await llm_agent.start() bdi_core = BDICoreAgent( - settings.agent_settings.bdi_core_agent_name + '@' + settings.agent_settings.host, + settings.agent_settings.bdi_core_agent_name + "@" + settings.agent_settings.host, settings.agent_settings.bdi_core_agent_name, "src/control_backend/agents/bdi/rules.asl", ) await bdi_core.start() belief_collector = BeliefCollectorAgent( - settings.agent_settings.belief_collector_agent_name + '@' + settings.agent_settings.host, + settings.agent_settings.belief_collector_agent_name + "@" + settings.agent_settings.host, settings.agent_settings.belief_collector_agent_name, ) await belief_collector.start() text_belief_extractor = TBeliefExtractor( - settings.agent_settings.text_belief_extractor_agent_name + '@' + settings.agent_settings.host, + settings.agent_settings.text_belief_extractor_agent_name + + "@" + + settings.agent_settings.host, 
settings.agent_settings.text_belief_extractor_agent_name, ) await text_belief_extractor.start() diff --git a/test/integration/api/endpoints/test_command_endpoint.py b/test/integration/api/endpoints/test_command_endpoint.py index 07bd866..7e38924 100644 --- a/test/integration/api/endpoints/test_command_endpoint.py +++ b/test/integration/api/endpoints/test_command_endpoint.py @@ -1,7 +1,8 @@ +from unittest.mock import AsyncMock, patch + import pytest from fastapi import FastAPI from fastapi.testclient import TestClient -from unittest.mock import MagicMock from control_backend.api.v1.endpoints import command from control_backend.schemas.ri_message import SpeechCommand @@ -15,7 +16,6 @@ def app(): """ app = FastAPI() app.include_router(command.router) - app.state.internal_comm_socket = MagicMock() # mock ZMQ socket return app @@ -25,12 +25,42 @@ def client(app): return TestClient(app) -def test_receive_command_endpoint(client, app): +@pytest.mark.asyncio +@patch("control_backend.api.v1.endpoints.command.Context.instance") +async def test_receive_command_success(mock_context_instance, async_client): + """ + Test for successful reception of a command. + Ensures the status code is 202 and the response body is correct. + It also verifies that the ZeroMQ socket's send_multipart method is called with the expected data. 
+ """ + # Arrange + mock_pub_socket = AsyncMock() + mock_context_instance.return_value.socket.return_value = mock_pub_socket + + command_data = {"endpoint": "actuate/speech", "data": "This is a test"} + speech_command = SpeechCommand(**command_data) + + # Act + response = await async_client.post("/command", json=command_data) + + # Assert + assert response.status_code == 202 + assert response.json() == {"status": "Command received"} + + # Verify that the ZMQ socket was used correctly + mock_context_instance.return_value.socket.assert_called_once_with(1) # zmq.PUB is 1 + mock_pub_socket.connect.assert_called_once() + mock_pub_socket.send_multipart.assert_awaited_once_with( + [b"command", speech_command.model_dump_json().encode()] + ) + + +def test_receive_command_endpoint(client, app, mocker): """ Test that a POST to /command sends the right multipart message and returns a 202 with the expected JSON body. """ - mock_socket = app.state.internal_comm_socket + mock_socket = mocker.patch.object() # Prepare test payload that matches SpeechCommand payload = {"endpoint": "actuate/speech", "data": "yooo"}