refactor: ZMQ context and proxy

Use ZMQ's global context instance and setup an XPUB/XSUB proxy intermediary to allow for easier multi-pubs.

close: N25B-217
This commit is contained in:
2025-10-30 11:40:14 +01:00
parent 657c300bc7
commit b92471ff1c
10 changed files with 92 additions and 49 deletions

View File

@@ -1,11 +1,12 @@
import json
import logging
import zmq
from spade.agent import Agent
from spade.behaviour import CyclicBehaviour
import zmq
from zmq.asyncio import Context
from control_backend.core.config import settings
from control_backend.core.zmq_context import context
from control_backend.schemas.ri_message import SpeechCommand
logger = logging.getLogger(__name__)
@@ -55,6 +56,8 @@ class RICommandAgent(Agent):
"""
logger.info("Setting up %s", self.jid)
context = Context.instance()
# To the robot
self.pubsocket = context.socket(zmq.PUB)
if self.bind:
@@ -64,7 +67,7 @@ class RICommandAgent(Agent):
# Receive internal topics regarding commands
self.subsocket = context.socket(zmq.SUB)
self.subsocket.connect(settings.zmq_settings.internal_comm_address)
self.subsocket.connect(settings.zmq_settings.internal_sub_address)
self.subsocket.setsockopt(zmq.SUBSCRIBE, b"command")
# Add behaviour to our agent

View File

@@ -1,14 +1,13 @@
import asyncio
import json
import logging
import zmq
from spade.agent import Agent
from spade.behaviour import CyclicBehaviour
import zmq
from zmq.asyncio import Context
from control_backend.core.config import settings
from control_backend.core.zmq_context import context
from control_backend.schemas.message import Message
from control_backend.agents.ri_command_agent import RICommandAgent
from control_backend.core.config import settings
logger = logging.getLogger(__name__)
@@ -47,7 +46,7 @@ class RICommunicationAgent(Agent):
message = await asyncio.wait_for(self.agent.req_socket.recv_json(), timeout=3.0)
# We didnt get a reply :(
except asyncio.TimeoutError as e:
except TimeoutError:
logger.info("No ping retrieved in 3 seconds, killing myself.")
self.kill()
@@ -75,7 +74,7 @@ class RICommunicationAgent(Agent):
# Let's try a certain amount of times before failing connection
while retries < max_retries:
# Bind request socket
self.req_socket = context.socket(zmq.REQ)
self.req_socket = Context.instance().socket(zmq.REQ)
if self._bind:
self.req_socket.bind(self._address)
else:
@@ -88,7 +87,7 @@ class RICommunicationAgent(Agent):
try:
received_message = await asyncio.wait_for(self.req_socket.recv_json(), timeout=20.0)
except asyncio.TimeoutError:
except TimeoutError:
logger.warning(
"No connection established in 20 seconds (attempt %d/%d)",
retries + 1,

View File

@@ -10,7 +10,6 @@ from spade.message import Message
from control_backend.agents.transcription.speech_recognizer import SpeechRecognizer
from control_backend.core.config import settings
from control_backend.core.zmq_context import context as zmq_context
logger = logging.getLogger(__name__)
@@ -47,7 +46,8 @@ class TranscriptionAgent(Agent):
"""Share a transcription to the other agents that depend on it."""
receiver_jids = [
settings.agent_settings.text_belief_extractor_agent_name
+ '@' + settings.agent_settings.host,
+ "@"
+ settings.agent_settings.host,
] # Set message receivers here
for receiver_jid in receiver_jids:
@@ -68,7 +68,7 @@ class TranscriptionAgent(Agent):
return await super().stop()
def _connect_audio_in_socket(self):
self.audio_in_socket = zmq_context.socket(zmq.SUB)
self.audio_in_socket = azmq.Context.instance().socket(zmq.SUB)
self.audio_in_socket.setsockopt_string(zmq.SUBSCRIBE, "")
self.audio_in_socket.connect(self.audio_in_address)

View File

@@ -9,7 +9,6 @@ from spade.behaviour import CyclicBehaviour
from control_backend.agents.transcription import TranscriptionAgent
from control_backend.core.config import settings
from control_backend.core.zmq_context import context as zmq_context
logger = logging.getLogger(__name__)
@@ -121,7 +120,7 @@ class VADAgent(Agent):
return await super().stop()
def _connect_audio_in_socket(self):
self.audio_in_socket = zmq_context.socket(zmq.SUB)
self.audio_in_socket = azmq.Context.instance().socket(zmq.SUB)
self.audio_in_socket.setsockopt_string(zmq.SUBSCRIBE, "")
if self.audio_in_bind:
self.audio_in_socket.bind(self.audio_in_address)
@@ -132,7 +131,7 @@ class VADAgent(Agent):
def _connect_audio_out_socket(self) -> int | None:
"""Returns the port bound, or None if binding failed."""
try:
self.audio_out_socket = zmq_context.socket(zmq.PUB)
self.audio_out_socket = azmq.Context.instance().socket(zmq.PUB)
return self.audio_out_socket.bind_to_random_port("tcp://*", max_tries=100)
except zmq.ZMQBindError:
logger.error("Failed to bind an audio output socket after 100 tries.")

View File

@@ -1,9 +1,11 @@
from fastapi import APIRouter, Request
import logging
from zmq import Socket
import zmq
from fastapi import APIRouter, Request
from zmq.asyncio import Context
from control_backend.schemas.ri_message import SpeechCommand, RIEndpoint
from control_backend.core.config import settings
from control_backend.schemas.ri_message import SpeechCommand
logger = logging.getLogger(__name__)
@@ -15,8 +17,8 @@ async def receive_command(command: SpeechCommand, request: Request):
# Validate and retrieve data.
SpeechCommand.model_validate(command)
topic = b"command"
pub_socket: Socket = request.app.state.internal_comm_socket
pub_socket.send_multipart([topic, command.model_dump_json().encode()])
pub_socket = Context.instance().socket(zmq.PUB)
pub_socket.connect(settings.zmq_settings.internal_pub_address)
await pub_socket.send_multipart([topic, command.model_dump_json().encode()])
return {"status": "Command received"}

View File

@@ -1,8 +1,10 @@
import logging
import zmq
from fastapi import APIRouter, Request
from zmq import Socket
from zmq.asyncio import Context
from control_backend.core.config import settings
from control_backend.schemas.message import Message
logger = logging.getLogger(__name__)
@@ -17,8 +19,8 @@ async def receive_message(message: Message, request: Request):
topic = b"message"
body = message.model_dump_json().encode("utf-8")
pub_socket: Socket = request.app.state.internal_comm_socket
pub_socket.send_multipart([topic, body])
pub_socket = Context.instance().socket(zmq.PUB)
pub_socket.bind(settings.zmq_settings.internal_pub_address)
await pub_socket.send_multipart([topic, body])
return {"status": "Message received"}

View File

@@ -3,7 +3,8 @@ from pydantic_settings import BaseSettings, SettingsConfigDict
class ZMQSettings(BaseModel):
internal_comm_address: str = "tcp://localhost:5560"
internal_pub_address: str = "tcp://localhost:5560"
internal_sub_address: str = "tcp://localhost:5561"
class AgentSettings(BaseModel):
@@ -24,6 +25,7 @@ class LLMSettings(BaseModel):
local_llm_url: str = "http://localhost:1234/v1/chat/completions"
local_llm_model: str = "openai/gpt-oss-20b"
class Settings(BaseSettings):
app_title: str = "PepperPlus"
@@ -37,4 +39,5 @@ class Settings(BaseSettings):
model_config = SettingsConfigDict(env_file=".env")
settings = Settings()

View File

@@ -1,3 +0,0 @@
from zmq.asyncio import Context
context = Context()

View File

@@ -7,17 +7,18 @@ import logging
import zmq
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from zmq.asyncio import Context
from control_backend.agents.bdi.bdi_core import BDICoreAgent
from control_backend.agents.bdi.text_extractor import TBeliefExtractor
from control_backend.agents.belief_collector.belief_collector import BeliefCollectorAgent
from control_backend.agents.llm.llm import LLMAgent
# Internal imports
from control_backend.agents.ri_communication_agent import RICommunicationAgent
from control_backend.agents.bdi.bdi_core import BDICoreAgent
from control_backend.agents.vad_agent import VADAgent
from control_backend.agents.llm.llm import LLMAgent
from control_backend.agents.bdi.text_extractor import TBeliefExtractor
from control_backend.agents.belief_collector.belief_collector import BeliefCollectorAgent
from control_backend.api.v1.router import api_router
from control_backend.core.config import settings
from control_backend.core.zmq_context import context
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.DEBUG)
@@ -28,12 +29,17 @@ async def lifespan(app: FastAPI):
logger.info("%s starting up.", app.title)
# Initiate sockets
internal_comm_socket = context.socket(zmq.PUB)
internal_comm_address = settings.zmq_settings.internal_comm_address
internal_comm_socket.bind(internal_comm_address)
app.state.internal_comm_socket = internal_comm_socket
logger.info("Internal publishing socket bound to %s", internal_comm_socket)
context = Context.instance()
internal_pub_socket = context.socket(zmq.XPUB)
internal_pub_socket.bind(settings.zmq_settings.internal_pub_address)
logger.debug("Internal publishing socket bound to %s", internal_pub_socket)
internal_sub_socket = context.socket(zmq.XSUB)
internal_sub_socket.bind(settings.zmq_settings.internal_sub_address)
logger.debug("Internal subscribing socket bound to %s", internal_sub_socket)
zmq.proxy(internal_pub_socket, internal_sub_socket)
# Initiate agents
ri_communication_agent = RICommunicationAgent(
@@ -45,26 +51,28 @@ async def lifespan(app: FastAPI):
await ri_communication_agent.start()
llm_agent = LLMAgent(
settings.agent_settings.llm_agent_name + '@' + settings.agent_settings.host,
settings.agent_settings.llm_agent_name + "@" + settings.agent_settings.host,
settings.agent_settings.llm_agent_name,
)
await llm_agent.start()
bdi_core = BDICoreAgent(
settings.agent_settings.bdi_core_agent_name + '@' + settings.agent_settings.host,
settings.agent_settings.bdi_core_agent_name + "@" + settings.agent_settings.host,
settings.agent_settings.bdi_core_agent_name,
"src/control_backend/agents/bdi/rules.asl",
)
await bdi_core.start()
belief_collector = BeliefCollectorAgent(
settings.agent_settings.belief_collector_agent_name + '@' + settings.agent_settings.host,
settings.agent_settings.belief_collector_agent_name + "@" + settings.agent_settings.host,
settings.agent_settings.belief_collector_agent_name,
)
await belief_collector.start()
text_belief_extractor = TBeliefExtractor(
settings.agent_settings.text_belief_extractor_agent_name + '@' + settings.agent_settings.host,
settings.agent_settings.text_belief_extractor_agent_name
+ "@"
+ settings.agent_settings.host,
settings.agent_settings.text_belief_extractor_agent_name,
)
await text_belief_extractor.start()