Merge branch 'dev' into refactor/logging
This commit is contained in:
@@ -2,10 +2,10 @@ import json
|
||||
|
||||
import zmq
|
||||
from spade.behaviour import CyclicBehaviour
|
||||
from zmq.asyncio import Context
|
||||
|
||||
from control_backend.agents import BaseAgent
|
||||
from control_backend.core.config import settings
|
||||
from control_backend.core.zmq_context import context
|
||||
from control_backend.schemas.ri_message import SpeechCommand
|
||||
|
||||
|
||||
@@ -53,6 +53,8 @@ class RICommandAgent(BaseAgent):
|
||||
"""
|
||||
self.logger.info("Setting up %s", self.jid)
|
||||
|
||||
context = Context.instance()
|
||||
|
||||
# To the robot
|
||||
self.pubsocket = context.socket(zmq.PUB)
|
||||
if self.bind:
|
||||
@@ -62,7 +64,7 @@ class RICommandAgent(BaseAgent):
|
||||
|
||||
# Receive internal topics regarding commands
|
||||
self.subsocket = context.socket(zmq.SUB)
|
||||
self.subsocket.connect(settings.zmq_settings.internal_comm_address)
|
||||
self.subsocket.connect(settings.zmq_settings.internal_sub_address)
|
||||
self.subsocket.setsockopt(zmq.SUBSCRIBE, b"command")
|
||||
|
||||
# Add behaviour to our agent
|
||||
|
||||
@@ -2,10 +2,10 @@ import asyncio
|
||||
|
||||
import zmq
|
||||
from spade.behaviour import CyclicBehaviour
|
||||
from zmq.asyncio import Context
|
||||
|
||||
from control_backend.agents import BaseAgent
|
||||
from control_backend.core.config import settings
|
||||
from control_backend.core.zmq_context import context
|
||||
|
||||
from .ri_command_agent import RICommandAgent
|
||||
|
||||
@@ -72,7 +72,7 @@ class RICommunicationAgent(BaseAgent):
|
||||
# Let's try a certain amount of times before failing connection
|
||||
while retries < max_retries:
|
||||
# Bind request socket
|
||||
self.req_socket = context.socket(zmq.REQ)
|
||||
self.req_socket = Context.instance().socket(zmq.REQ)
|
||||
if self._bind:
|
||||
self.req_socket.bind(self._address)
|
||||
else:
|
||||
|
||||
@@ -8,7 +8,6 @@ from spade.message import Message
|
||||
|
||||
from control_backend.agents import BaseAgent
|
||||
from control_backend.core.config import settings
|
||||
from control_backend.core.zmq_context import context as zmq_context
|
||||
|
||||
from .speech_recognizer import SpeechRecognizer
|
||||
|
||||
@@ -67,7 +66,7 @@ class TranscriptionAgent(BaseAgent):
|
||||
return await super().stop()
|
||||
|
||||
def _connect_audio_in_socket(self):
|
||||
self.audio_in_socket = zmq_context.socket(zmq.SUB)
|
||||
self.audio_in_socket = azmq.Context.instance().socket(zmq.SUB)
|
||||
self.audio_in_socket.setsockopt_string(zmq.SUBSCRIBE, "")
|
||||
self.audio_in_socket.connect(self.audio_in_address)
|
||||
|
||||
|
||||
@@ -6,7 +6,6 @@ from spade.behaviour import CyclicBehaviour
|
||||
|
||||
from control_backend.agents import BaseAgent
|
||||
from control_backend.core.config import settings
|
||||
from control_backend.core.zmq_context import context as zmq_context
|
||||
|
||||
from .transcription.transcription_agent import TranscriptionAgent
|
||||
|
||||
@@ -120,7 +119,7 @@ class VADAgent(BaseAgent):
|
||||
return await super().stop()
|
||||
|
||||
def _connect_audio_in_socket(self):
|
||||
self.audio_in_socket = zmq_context.socket(zmq.SUB)
|
||||
self.audio_in_socket = azmq.Context.instance().socket(zmq.SUB)
|
||||
self.audio_in_socket.setsockopt_string(zmq.SUBSCRIBE, "")
|
||||
if self.audio_in_bind:
|
||||
self.audio_in_socket.bind(self.audio_in_address)
|
||||
@@ -131,7 +130,7 @@ class VADAgent(BaseAgent):
|
||||
def _connect_audio_out_socket(self) -> int | None:
|
||||
"""Returns the port bound, or None if binding failed."""
|
||||
try:
|
||||
self.audio_out_socket = zmq_context.socket(zmq.PUB)
|
||||
self.audio_out_socket = azmq.Context.instance().socket(zmq.PUB)
|
||||
return self.audio_out_socket.bind_to_random_port("tcp://*", max_tries=100)
|
||||
except zmq.ZMQBindError:
|
||||
self.logger.error("Failed to bind an audio output socket after 100 tries.")
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter, Request
|
||||
from zmq import Socket
|
||||
|
||||
from control_backend.schemas.ri_message import SpeechCommand
|
||||
|
||||
@@ -15,7 +14,7 @@ async def receive_command(command: SpeechCommand, request: Request):
|
||||
# Validate and retrieve data.
|
||||
SpeechCommand.model_validate(command)
|
||||
topic = b"command"
|
||||
pub_socket: Socket = request.app.state.internal_comm_socket
|
||||
pub_socket.send_multipart([topic, command.model_dump_json().encode()])
|
||||
pub_socket = request.app.state.endpoints_pub_socket
|
||||
await pub_socket.send_multipart([topic, command.model_dump_json().encode()])
|
||||
|
||||
return {"status": "Command received"}
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter, Request
|
||||
from zmq import Socket
|
||||
|
||||
from control_backend.schemas.message import Message
|
||||
|
||||
@@ -17,8 +16,7 @@ async def receive_message(message: Message, request: Request):
|
||||
topic = b"message"
|
||||
body = message.model_dump_json().encode("utf-8")
|
||||
|
||||
pub_socket: Socket = request.app.state.internal_comm_socket
|
||||
|
||||
pub_socket.send_multipart([topic, body])
|
||||
pub_socket = request.app.state.endpoints_pub_socket
|
||||
await pub_socket.send_multipart([topic, body])
|
||||
|
||||
return {"status": "Message received"}
|
||||
|
||||
@@ -3,7 +3,8 @@ from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
|
||||
class ZMQSettings(BaseModel):
|
||||
internal_comm_address: str = "tcp://localhost:5560"
|
||||
internal_pub_address: str = "tcp://localhost:5560"
|
||||
internal_sub_address: str = "tcp://localhost:5561"
|
||||
|
||||
|
||||
class AgentSettings(BaseModel):
|
||||
|
||||
@@ -1,3 +0,0 @@
|
||||
from zmq.asyncio import Context
|
||||
|
||||
context = Context()
|
||||
@@ -1,9 +1,11 @@
|
||||
import contextlib
|
||||
import logging
|
||||
import threading
|
||||
|
||||
import zmq
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from zmq.asyncio import Context
|
||||
|
||||
from control_backend.agents import (
|
||||
BeliefCollectorAgent,
|
||||
@@ -14,12 +16,30 @@ from control_backend.agents import (
|
||||
from control_backend.agents.bdi import BDICoreAgent, TBeliefExtractorAgent
|
||||
from control_backend.api.v1.router import api_router
|
||||
from control_backend.core.config import settings
|
||||
from control_backend.core.zmq_context import context
|
||||
from control_backend.logging import setup_logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def setup_sockets():
|
||||
context = Context.instance()
|
||||
|
||||
internal_pub_socket = context.socket(zmq.XPUB)
|
||||
internal_pub_socket.bind(settings.zmq_settings.internal_sub_address)
|
||||
logger.debug("Internal publishing socket bound to %s", internal_pub_socket)
|
||||
|
||||
internal_sub_socket = context.socket(zmq.XSUB)
|
||||
internal_sub_socket.bind(settings.zmq_settings.internal_pub_address)
|
||||
logger.debug("Internal subscribing socket bound to %s", internal_sub_socket)
|
||||
try:
|
||||
zmq.proxy(internal_sub_socket, internal_pub_socket)
|
||||
except zmq.ZMQError:
|
||||
logger.warning("Error while handling PUB/SUB proxy. Closing sockets.")
|
||||
finally:
|
||||
internal_pub_socket.close()
|
||||
internal_sub_socket.close()
|
||||
|
||||
|
||||
@contextlib.asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
"""
|
||||
@@ -29,18 +49,16 @@ async def lifespan(app: FastAPI):
|
||||
setup_logging()
|
||||
logger.info("%s is starting up.", app.title)
|
||||
|
||||
# --- Initialize Sockets ---
|
||||
logger.info("Initializing ZeroMQ sockets.")
|
||||
try:
|
||||
internal_comm_socket = context.socket(zmq.PUB)
|
||||
internal_comm_address = settings.zmq_settings.internal_comm_address
|
||||
logger.debug("Binding internal PUB socket to address: %s", internal_comm_address)
|
||||
internal_comm_socket.bind(internal_comm_address)
|
||||
app.state.internal_comm_socket = internal_comm_socket
|
||||
logger.info("Internal communication socket bound successfully.")
|
||||
except Exception as e:
|
||||
logger.error("Failed to bind internal communication socket: %s", e, exc_info=True)
|
||||
raise
|
||||
# Initiate sockets
|
||||
proxy_thread = threading.Thread(target=setup_sockets)
|
||||
proxy_thread.daemon = True
|
||||
proxy_thread.start()
|
||||
|
||||
context = Context.instance()
|
||||
|
||||
endpoints_pub_socket = context.socket(zmq.PUB)
|
||||
endpoints_pub_socket.connect(settings.zmq_settings.internal_pub_address)
|
||||
app.state.endpoints_pub_socket = endpoints_pub_socket
|
||||
|
||||
# --- Initialize Agents ---
|
||||
logger.info("Initializing and starting agents.")
|
||||
|
||||
Reference in New Issue
Block a user