refactor: ZMQ context and proxy

Use ZMQ's global context instance and setup an XPUB/XSUB proxy intermediary to allow for easier multi-pubs.

close: N25B-217
This commit is contained in:
2025-10-30 11:40:14 +01:00
parent 657c300bc7
commit b92471ff1c
10 changed files with 92 additions and 49 deletions

View File

@@ -1,11 +1,12 @@
import json import json
import logging import logging
import zmq
from spade.agent import Agent from spade.agent import Agent
from spade.behaviour import CyclicBehaviour from spade.behaviour import CyclicBehaviour
import zmq from zmq.asyncio import Context
from control_backend.core.config import settings from control_backend.core.config import settings
from control_backend.core.zmq_context import context
from control_backend.schemas.ri_message import SpeechCommand from control_backend.schemas.ri_message import SpeechCommand
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -55,6 +56,8 @@ class RICommandAgent(Agent):
""" """
logger.info("Setting up %s", self.jid) logger.info("Setting up %s", self.jid)
context = Context.instance()
# To the robot # To the robot
self.pubsocket = context.socket(zmq.PUB) self.pubsocket = context.socket(zmq.PUB)
if self.bind: if self.bind:
@@ -64,7 +67,7 @@ class RICommandAgent(Agent):
# Receive internal topics regarding commands # Receive internal topics regarding commands
self.subsocket = context.socket(zmq.SUB) self.subsocket = context.socket(zmq.SUB)
self.subsocket.connect(settings.zmq_settings.internal_comm_address) self.subsocket.connect(settings.zmq_settings.internal_sub_address)
self.subsocket.setsockopt(zmq.SUBSCRIBE, b"command") self.subsocket.setsockopt(zmq.SUBSCRIBE, b"command")
# Add behaviour to our agent # Add behaviour to our agent

View File

@@ -1,14 +1,13 @@
import asyncio import asyncio
import json
import logging import logging
import zmq
from spade.agent import Agent from spade.agent import Agent
from spade.behaviour import CyclicBehaviour from spade.behaviour import CyclicBehaviour
import zmq from zmq.asyncio import Context
from control_backend.core.config import settings
from control_backend.core.zmq_context import context
from control_backend.schemas.message import Message
from control_backend.agents.ri_command_agent import RICommandAgent from control_backend.agents.ri_command_agent import RICommandAgent
from control_backend.core.config import settings
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -47,7 +46,7 @@ class RICommunicationAgent(Agent):
message = await asyncio.wait_for(self.agent.req_socket.recv_json(), timeout=3.0) message = await asyncio.wait_for(self.agent.req_socket.recv_json(), timeout=3.0)
# We didnt get a reply :( # We didnt get a reply :(
except asyncio.TimeoutError as e: except TimeoutError:
logger.info("No ping retrieved in 3 seconds, killing myself.") logger.info("No ping retrieved in 3 seconds, killing myself.")
self.kill() self.kill()
@@ -75,7 +74,7 @@ class RICommunicationAgent(Agent):
# Let's try a certain amount of times before failing connection # Let's try a certain amount of times before failing connection
while retries < max_retries: while retries < max_retries:
# Bind request socket # Bind request socket
self.req_socket = context.socket(zmq.REQ) self.req_socket = Context.instance().socket(zmq.REQ)
if self._bind: if self._bind:
self.req_socket.bind(self._address) self.req_socket.bind(self._address)
else: else:
@@ -88,7 +87,7 @@ class RICommunicationAgent(Agent):
try: try:
received_message = await asyncio.wait_for(self.req_socket.recv_json(), timeout=20.0) received_message = await asyncio.wait_for(self.req_socket.recv_json(), timeout=20.0)
except asyncio.TimeoutError: except TimeoutError:
logger.warning( logger.warning(
"No connection established in 20 seconds (attempt %d/%d)", "No connection established in 20 seconds (attempt %d/%d)",
retries + 1, retries + 1,

View File

@@ -10,7 +10,6 @@ from spade.message import Message
from control_backend.agents.transcription.speech_recognizer import SpeechRecognizer from control_backend.agents.transcription.speech_recognizer import SpeechRecognizer
from control_backend.core.config import settings from control_backend.core.config import settings
from control_backend.core.zmq_context import context as zmq_context
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -47,7 +46,8 @@ class TranscriptionAgent(Agent):
"""Share a transcription to the other agents that depend on it.""" """Share a transcription to the other agents that depend on it."""
receiver_jids = [ receiver_jids = [
settings.agent_settings.text_belief_extractor_agent_name settings.agent_settings.text_belief_extractor_agent_name
+ '@' + settings.agent_settings.host, + "@"
+ settings.agent_settings.host,
] # Set message receivers here ] # Set message receivers here
for receiver_jid in receiver_jids: for receiver_jid in receiver_jids:
@@ -68,7 +68,7 @@ class TranscriptionAgent(Agent):
return await super().stop() return await super().stop()
def _connect_audio_in_socket(self): def _connect_audio_in_socket(self):
self.audio_in_socket = zmq_context.socket(zmq.SUB) self.audio_in_socket = azmq.Context.instance().socket(zmq.SUB)
self.audio_in_socket.setsockopt_string(zmq.SUBSCRIBE, "") self.audio_in_socket.setsockopt_string(zmq.SUBSCRIBE, "")
self.audio_in_socket.connect(self.audio_in_address) self.audio_in_socket.connect(self.audio_in_address)

View File

@@ -9,7 +9,6 @@ from spade.behaviour import CyclicBehaviour
from control_backend.agents.transcription import TranscriptionAgent from control_backend.agents.transcription import TranscriptionAgent
from control_backend.core.config import settings from control_backend.core.config import settings
from control_backend.core.zmq_context import context as zmq_context
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -121,7 +120,7 @@ class VADAgent(Agent):
return await super().stop() return await super().stop()
def _connect_audio_in_socket(self): def _connect_audio_in_socket(self):
self.audio_in_socket = zmq_context.socket(zmq.SUB) self.audio_in_socket = azmq.Context.instance().socket(zmq.SUB)
self.audio_in_socket.setsockopt_string(zmq.SUBSCRIBE, "") self.audio_in_socket.setsockopt_string(zmq.SUBSCRIBE, "")
if self.audio_in_bind: if self.audio_in_bind:
self.audio_in_socket.bind(self.audio_in_address) self.audio_in_socket.bind(self.audio_in_address)
@@ -132,7 +131,7 @@ class VADAgent(Agent):
def _connect_audio_out_socket(self) -> int | None: def _connect_audio_out_socket(self) -> int | None:
"""Returns the port bound, or None if binding failed.""" """Returns the port bound, or None if binding failed."""
try: try:
self.audio_out_socket = zmq_context.socket(zmq.PUB) self.audio_out_socket = azmq.Context.instance().socket(zmq.PUB)
return self.audio_out_socket.bind_to_random_port("tcp://*", max_tries=100) return self.audio_out_socket.bind_to_random_port("tcp://*", max_tries=100)
except zmq.ZMQBindError: except zmq.ZMQBindError:
logger.error("Failed to bind an audio output socket after 100 tries.") logger.error("Failed to bind an audio output socket after 100 tries.")

View File

@@ -1,9 +1,11 @@
from fastapi import APIRouter, Request
import logging import logging
from zmq import Socket import zmq
from fastapi import APIRouter, Request
from zmq.asyncio import Context
from control_backend.schemas.ri_message import SpeechCommand, RIEndpoint from control_backend.core.config import settings
from control_backend.schemas.ri_message import SpeechCommand
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -15,8 +17,8 @@ async def receive_command(command: SpeechCommand, request: Request):
# Validate and retrieve data. # Validate and retrieve data.
SpeechCommand.model_validate(command) SpeechCommand.model_validate(command)
topic = b"command" topic = b"command"
pub_socket: Socket = request.app.state.internal_comm_socket pub_socket = Context.instance().socket(zmq.PUB)
pub_socket.send_multipart([topic, command.model_dump_json().encode()]) pub_socket.connect(settings.zmq_settings.internal_pub_address)
await pub_socket.send_multipart([topic, command.model_dump_json().encode()])
return {"status": "Command received"} return {"status": "Command received"}

View File

@@ -1,8 +1,10 @@
import logging import logging
import zmq
from fastapi import APIRouter, Request from fastapi import APIRouter, Request
from zmq import Socket from zmq.asyncio import Context
from control_backend.core.config import settings
from control_backend.schemas.message import Message from control_backend.schemas.message import Message
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -17,8 +19,8 @@ async def receive_message(message: Message, request: Request):
topic = b"message" topic = b"message"
body = message.model_dump_json().encode("utf-8") body = message.model_dump_json().encode("utf-8")
pub_socket: Socket = request.app.state.internal_comm_socket pub_socket = Context.instance().socket(zmq.PUB)
pub_socket.bind(settings.zmq_settings.internal_pub_address)
pub_socket.send_multipart([topic, body]) await pub_socket.send_multipart([topic, body])
return {"status": "Message received"} return {"status": "Message received"}

View File

@@ -3,7 +3,8 @@ from pydantic_settings import BaseSettings, SettingsConfigDict
class ZMQSettings(BaseModel): class ZMQSettings(BaseModel):
internal_comm_address: str = "tcp://localhost:5560" internal_pub_address: str = "tcp://localhost:5560"
internal_sub_address: str = "tcp://localhost:5561"
class AgentSettings(BaseModel): class AgentSettings(BaseModel):
@@ -24,6 +25,7 @@ class LLMSettings(BaseModel):
local_llm_url: str = "http://localhost:1234/v1/chat/completions" local_llm_url: str = "http://localhost:1234/v1/chat/completions"
local_llm_model: str = "openai/gpt-oss-20b" local_llm_model: str = "openai/gpt-oss-20b"
class Settings(BaseSettings): class Settings(BaseSettings):
app_title: str = "PepperPlus" app_title: str = "PepperPlus"
@@ -37,4 +39,5 @@ class Settings(BaseSettings):
model_config = SettingsConfigDict(env_file=".env") model_config = SettingsConfigDict(env_file=".env")
settings = Settings() settings = Settings()

View File

@@ -1,3 +0,0 @@
from zmq.asyncio import Context
context = Context()

View File

@@ -7,17 +7,18 @@ import logging
import zmq import zmq
from fastapi import FastAPI from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from zmq.asyncio import Context
from control_backend.agents.bdi.bdi_core import BDICoreAgent
from control_backend.agents.bdi.text_extractor import TBeliefExtractor
from control_backend.agents.belief_collector.belief_collector import BeliefCollectorAgent
from control_backend.agents.llm.llm import LLMAgent
# Internal imports # Internal imports
from control_backend.agents.ri_communication_agent import RICommunicationAgent from control_backend.agents.ri_communication_agent import RICommunicationAgent
from control_backend.agents.bdi.bdi_core import BDICoreAgent
from control_backend.agents.vad_agent import VADAgent from control_backend.agents.vad_agent import VADAgent
from control_backend.agents.llm.llm import LLMAgent
from control_backend.agents.bdi.text_extractor import TBeliefExtractor
from control_backend.agents.belief_collector.belief_collector import BeliefCollectorAgent
from control_backend.api.v1.router import api_router from control_backend.api.v1.router import api_router
from control_backend.core.config import settings from control_backend.core.config import settings
from control_backend.core.zmq_context import context
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.DEBUG) logging.basicConfig(level=logging.DEBUG)
@@ -28,12 +29,17 @@ async def lifespan(app: FastAPI):
logger.info("%s starting up.", app.title) logger.info("%s starting up.", app.title)
# Initiate sockets # Initiate sockets
internal_comm_socket = context.socket(zmq.PUB) context = Context.instance()
internal_comm_address = settings.zmq_settings.internal_comm_address
internal_comm_socket.bind(internal_comm_address)
app.state.internal_comm_socket = internal_comm_socket
logger.info("Internal publishing socket bound to %s", internal_comm_socket)
internal_pub_socket = context.socket(zmq.XPUB)
internal_pub_socket.bind(settings.zmq_settings.internal_pub_address)
logger.debug("Internal publishing socket bound to %s", internal_pub_socket)
internal_sub_socket = context.socket(zmq.XSUB)
internal_sub_socket.bind(settings.zmq_settings.internal_sub_address)
logger.debug("Internal subscribing socket bound to %s", internal_sub_socket)
zmq.proxy(internal_pub_socket, internal_sub_socket)
# Initiate agents # Initiate agents
ri_communication_agent = RICommunicationAgent( ri_communication_agent = RICommunicationAgent(
@@ -45,26 +51,28 @@ async def lifespan(app: FastAPI):
await ri_communication_agent.start() await ri_communication_agent.start()
llm_agent = LLMAgent( llm_agent = LLMAgent(
settings.agent_settings.llm_agent_name + '@' + settings.agent_settings.host, settings.agent_settings.llm_agent_name + "@" + settings.agent_settings.host,
settings.agent_settings.llm_agent_name, settings.agent_settings.llm_agent_name,
) )
await llm_agent.start() await llm_agent.start()
bdi_core = BDICoreAgent( bdi_core = BDICoreAgent(
settings.agent_settings.bdi_core_agent_name + '@' + settings.agent_settings.host, settings.agent_settings.bdi_core_agent_name + "@" + settings.agent_settings.host,
settings.agent_settings.bdi_core_agent_name, settings.agent_settings.bdi_core_agent_name,
"src/control_backend/agents/bdi/rules.asl", "src/control_backend/agents/bdi/rules.asl",
) )
await bdi_core.start() await bdi_core.start()
belief_collector = BeliefCollectorAgent( belief_collector = BeliefCollectorAgent(
settings.agent_settings.belief_collector_agent_name + '@' + settings.agent_settings.host, settings.agent_settings.belief_collector_agent_name + "@" + settings.agent_settings.host,
settings.agent_settings.belief_collector_agent_name, settings.agent_settings.belief_collector_agent_name,
) )
await belief_collector.start() await belief_collector.start()
text_belief_extractor = TBeliefExtractor( text_belief_extractor = TBeliefExtractor(
settings.agent_settings.text_belief_extractor_agent_name + '@' + settings.agent_settings.host, settings.agent_settings.text_belief_extractor_agent_name
+ "@"
+ settings.agent_settings.host,
settings.agent_settings.text_belief_extractor_agent_name, settings.agent_settings.text_belief_extractor_agent_name,
) )
await text_belief_extractor.start() await text_belief_extractor.start()

View File

@@ -1,7 +1,8 @@
from unittest.mock import AsyncMock, patch
import pytest import pytest
from fastapi import FastAPI from fastapi import FastAPI
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
from unittest.mock import MagicMock
from control_backend.api.v1.endpoints import command from control_backend.api.v1.endpoints import command
from control_backend.schemas.ri_message import SpeechCommand from control_backend.schemas.ri_message import SpeechCommand
@@ -15,7 +16,6 @@ def app():
""" """
app = FastAPI() app = FastAPI()
app.include_router(command.router) app.include_router(command.router)
app.state.internal_comm_socket = MagicMock() # mock ZMQ socket
return app return app
@@ -25,12 +25,42 @@ def client(app):
return TestClient(app) return TestClient(app)
def test_receive_command_endpoint(client, app): @pytest.mark.asyncio
@patch("control_backend.api.endpoints.command.Context.instance")
async def test_receive_command_success(mock_context_instance, async_client):
"""
Test for successful reception of a command.
Ensures the status code is 202 and the response body is correct.
It also verifies that the ZeroMQ socket's send_multipart method is called with the expected data.
"""
# Arrange
mock_pub_socket = AsyncMock()
mock_context_instance.return_value.socket.return_value = mock_pub_socket
command_data = {"command": "test_command", "text": "This is a test"}
speech_command = SpeechCommand(**command_data)
# Act
response = await async_client.post("/command", json=command_data)
# Assert
assert response.status_code == 202
assert response.json() == {"status": "Command received"}
# Verify that the ZMQ socket was used correctly
mock_context_instance.return_value.socket.assert_called_once_with(1) # zmq.PUB is 1
mock_pub_socket.connect.assert_called_once()
mock_pub_socket.send_multipart.assert_awaited_once_with(
[b"command", speech_command.model_dump_json().encode()]
)
def test_receive_command_endpoint(client, app, mocker):
""" """
Test that a POST to /command sends the right multipart message Test that a POST to /command sends the right multipart message
and returns a 202 with the expected JSON body. and returns a 202 with the expected JSON body.
""" """
mock_socket = app.state.internal_comm_socket mock_socket = mocker.patch.object()
# Prepare test payload that matches SpeechCommand # Prepare test payload that matches SpeechCommand
payload = {"endpoint": "actuate/speech", "data": "yooo"} payload = {"endpoint": "actuate/speech", "data": "yooo"}