diff --git a/src/control_backend/api/v1/endpoints/command.py b/src/control_backend/api/v1/endpoints/command.py index 88c859b..1ec76d5 100644 --- a/src/control_backend/api/v1/endpoints/command.py +++ b/src/control_backend/api/v1/endpoints/command.py @@ -17,8 +17,7 @@ async def receive_command(command: SpeechCommand, request: Request): # Validate and retrieve data. SpeechCommand.model_validate(command) topic = b"command" - pub_socket = Context.instance().socket(zmq.PUB) - pub_socket.connect(settings.zmq_settings.internal_pub_address) + pub_socket = request.app.state.endpoints_pub_socket await pub_socket.send_multipart([topic, command.model_dump_json().encode()]) return {"status": "Command received"} diff --git a/src/control_backend/api/v1/endpoints/message.py b/src/control_backend/api/v1/endpoints/message.py index 1a58377..bd88a0b 100644 --- a/src/control_backend/api/v1/endpoints/message.py +++ b/src/control_backend/api/v1/endpoints/message.py @@ -1,10 +1,7 @@ import logging -import zmq from fastapi import APIRouter, Request -from zmq.asyncio import Context -from control_backend.core.config import settings from control_backend.schemas.message import Message logger = logging.getLogger(__name__) @@ -19,8 +16,7 @@ async def receive_message(message: Message, request: Request): topic = b"message" body = message.model_dump_json().encode("utf-8") - pub_socket = Context.instance().socket(zmq.PUB) - pub_socket.bind(settings.zmq_settings.internal_pub_address) + pub_socket = request.app.state.endpoints_pub_socket await pub_socket.send_multipart([topic, body]) return {"status": "Message received"} diff --git a/src/control_backend/main.py b/src/control_backend/main.py index ff63e1f..9d0f664 100644 --- a/src/control_backend/main.py +++ b/src/control_backend/main.py @@ -28,14 +28,14 @@ def setup_sockets(): context = Context.instance() internal_pub_socket = context.socket(zmq.XPUB) - internal_pub_socket.bind(settings.zmq_settings.internal_pub_address) + internal_pub_socket.bind(settings.zmq_settings.internal_sub_address) logger.debug("Internal publishing socket bound to %s", internal_pub_socket) internal_sub_socket = context.socket(zmq.XSUB) - internal_sub_socket.bind(settings.zmq_settings.internal_sub_address) + internal_sub_socket.bind(settings.zmq_settings.internal_pub_address) logger.debug("Internal subscribing socket bound to %s", internal_sub_socket) try: - zmq.proxy(internal_pub_socket, internal_sub_socket) + zmq.proxy(internal_sub_socket, internal_pub_socket) except zmq.ZMQError: logger.warning("Error while handling PUB/SUB proxy. Closing sockets.") finally: @@ -51,6 +51,12 @@ async def lifespan(app: FastAPI): proxy_thread.daemon = True proxy_thread.start() + context = Context.instance() + + endpoints_pub_socket = context.socket(zmq.PUB) + endpoints_pub_socket.connect(settings.zmq_settings.internal_pub_address) + app.state.endpoints_pub_socket = endpoints_pub_socket + # Initiate agents ri_communication_agent = RICommunicationAgent( settings.agent_settings.ri_communication_agent_name + "@" + settings.agent_settings.host,