Merge remote-tracking branch 'origin/dev' into demo

This commit is contained in:
Twirre Meulenbelt
2025-11-05 12:38:08 +01:00
15 changed files with 187 additions and 163 deletions

View File

@@ -5,9 +5,9 @@ import spade.agent
import zmq
from spade.agent import Agent
from spade.behaviour import CyclicBehaviour
from zmq.asyncio import Context
from control_backend.core.config import settings
from control_backend.core.zmq_context import context
from control_backend.schemas.ri_message import SpeechCommand
logger = logging.getLogger(__name__)
@@ -71,6 +71,8 @@ class RICommandAgent(Agent):
"""
logger.info("Setting up %s", self.jid)
context = Context.instance()
# To the robot
self.pubsocket = context.socket(zmq.PUB)
if self.bind:
@@ -80,7 +82,7 @@ class RICommandAgent(Agent):
# Receive internal topics regarding commands
self.subsocket = context.socket(zmq.SUB)
self.subsocket.connect(settings.zmq_settings.internal_comm_address)
self.subsocket.connect(settings.zmq_settings.internal_sub_address)
self.subsocket.setsockopt(zmq.SUBSCRIBE, b"command")
# Add behaviour to our agent

View File

@@ -4,10 +4,10 @@ import logging
import zmq
from spade.agent import Agent
from spade.behaviour import CyclicBehaviour
from zmq.asyncio import Context
from control_backend.agents.ri_command_agent import RICommandAgent
from control_backend.core.config import settings
from control_backend.core.zmq_context import context
logger = logging.getLogger(__name__)
@@ -74,7 +74,7 @@ class RICommunicationAgent(Agent):
# Let's try a certain amount of times before failing connection
while retries < max_retries:
# Bind request socket
self.req_socket = context.socket(zmq.REQ)
self.req_socket = Context.instance().socket(zmq.REQ)
if self._bind:
self.req_socket.bind(self._address)
else:

View File

@@ -10,7 +10,6 @@ from spade.message import Message
from control_backend.agents.transcription.speech_recognizer import SpeechRecognizer
from control_backend.core.config import settings
from control_backend.core.zmq_context import context as zmq_context
logger = logging.getLogger(__name__)
@@ -73,7 +72,7 @@ class TranscriptionAgent(Agent):
return await super().stop()
def _connect_audio_in_socket(self):
self.audio_in_socket = zmq_context.socket(zmq.SUB)
self.audio_in_socket = azmq.Context.instance().socket(zmq.SUB)
self.audio_in_socket.setsockopt_string(zmq.SUBSCRIBE, "")
self.audio_in_socket.connect(self.audio_in_address)

View File

@@ -9,7 +9,6 @@ from spade.behaviour import CyclicBehaviour
from control_backend.agents.transcription import TranscriptionAgent
from control_backend.core.config import settings
from control_backend.core.zmq_context import context as zmq_context
logger = logging.getLogger(__name__)
@@ -66,7 +65,8 @@ class Streaming(CyclicBehaviour):
self._ready = True
async def run(self) -> None:
if not self._ready: return
if not self._ready:
return
data = await self.audio_in_poller.poll()
if data is None:
@@ -134,7 +134,7 @@ class VADAgent(Agent):
return await super().stop()
def _connect_audio_in_socket(self):
self.audio_in_socket = zmq_context.socket(zmq.SUB)
self.audio_in_socket = azmq.Context.instance().socket(zmq.SUB)
self.audio_in_socket.setsockopt_string(zmq.SUBSCRIBE, "")
if self.audio_in_bind:
self.audio_in_socket.bind(self.audio_in_address)
@@ -145,7 +145,7 @@ class VADAgent(Agent):
def _connect_audio_out_socket(self) -> int | None:
"""Returns the port bound, or None if binding failed."""
try:
self.audio_out_socket = zmq_context.socket(zmq.PUB)
self.audio_out_socket = azmq.Context.instance().socket(zmq.PUB)
return self.audio_out_socket.bind_to_random_port("tcp://*", max_tries=100)
except zmq.ZMQBindError:
logger.error("Failed to bind an audio output socket after 100 tries.")

View File

@@ -1,7 +1,6 @@
import logging
from fastapi import APIRouter, Request
from zmq import Socket
from control_backend.schemas.ri_message import SpeechCommand
@@ -15,7 +14,7 @@ async def receive_command(command: SpeechCommand, request: Request):
# Validate and retrieve data.
SpeechCommand.model_validate(command)
topic = b"command"
pub_socket: Socket = request.app.state.internal_comm_socket
pub_socket.send_multipart([topic, command.model_dump_json().encode()])
pub_socket = request.app.state.endpoints_pub_socket
await pub_socket.send_multipart([topic, command.model_dump_json().encode()])
return {"status": "Command received"}

View File

@@ -1,7 +1,6 @@
import logging
from fastapi import APIRouter, Request
from zmq import Socket
from control_backend.schemas.message import Message
@@ -17,8 +16,7 @@ async def receive_message(message: Message, request: Request):
topic = b"message"
body = message.model_dump_json().encode("utf-8")
pub_socket: Socket = request.app.state.internal_comm_socket
pub_socket.send_multipart([topic, body])
pub_socket = request.app.state.endpoints_pub_socket
await pub_socket.send_multipart([topic, body])
return {"status": "Message received"}

View File

@@ -3,7 +3,8 @@ from pydantic_settings import BaseSettings, SettingsConfigDict
class ZMQSettings(BaseModel):
internal_comm_address: str = "tcp://localhost:5560"
internal_pub_address: str = "tcp://localhost:5560"
internal_sub_address: str = "tcp://localhost:5561"
class AgentSettings(BaseModel):

View File

@@ -1,3 +0,0 @@
from zmq.asyncio import Context
context = Context()

View File

@@ -3,37 +3,59 @@
# External imports
import contextlib
import logging
import threading
import zmq
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from zmq.asyncio import Context
from control_backend.agents.bdi.bdi_core import BDICoreAgent
from control_backend.agents.bdi.text_extractor import TBeliefExtractor
from control_backend.agents.belief_collector.belief_collector import BeliefCollectorAgent
from control_backend.agents.llm.llm import LLMAgent
# Internal imports
from control_backend.agents.ri_communication_agent import RICommunicationAgent
from control_backend.agents.vad_agent import VADAgent
from control_backend.api.v1.router import api_router
from control_backend.core.config import settings
from control_backend.core.zmq_context import context
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.DEBUG)
def setup_sockets():
context = Context.instance()
internal_pub_socket = context.socket(zmq.XPUB)
internal_pub_socket.bind(settings.zmq_settings.internal_sub_address)
logger.debug("Internal publishing socket bound to %s", internal_pub_socket)
internal_sub_socket = context.socket(zmq.XSUB)
internal_sub_socket.bind(settings.zmq_settings.internal_pub_address)
logger.debug("Internal subscribing socket bound to %s", internal_sub_socket)
try:
zmq.proxy(internal_sub_socket, internal_pub_socket)
except zmq.ZMQError:
logger.warning("Error while handling PUB/SUB proxy. Closing sockets.")
finally:
internal_pub_socket.close()
internal_sub_socket.close()
@contextlib.asynccontextmanager
async def lifespan(app: FastAPI):
logger.info("%s starting up.", app.title)
# Initiate sockets
internal_comm_socket = context.socket(zmq.PUB)
internal_comm_address = settings.zmq_settings.internal_comm_address
internal_comm_socket.bind(internal_comm_address)
app.state.internal_comm_socket = internal_comm_socket
logger.info("Internal publishing socket bound to %s", internal_comm_socket)
proxy_thread = threading.Thread(target=setup_sockets)
proxy_thread.daemon = True
proxy_thread.start()
context = Context.instance()
endpoints_pub_socket = context.socket(zmq.PUB)
endpoints_pub_socket.connect(settings.zmq_settings.internal_pub_address)
app.state.endpoints_pub_socket = endpoints_pub_socket
# Initiate agents
ri_communication_agent = RICommunicationAgent(