Merge remote-tracking branch 'origin/dev' into fix/bdi-correct-belief-management

This commit is contained in:
2025-10-29 13:25:58 +01:00
25 changed files with 1618 additions and 6 deletions

View File

@@ -0,0 +1,74 @@
import json
import logging
from spade.agent import Agent
from spade.behaviour import CyclicBehaviour
import zmq
from control_backend.core.config import settings
from control_backend.core.zmq_context import context
from control_backend.schemas.ri_message import SpeechCommand
logger = logging.getLogger(__name__)
class RICommandAgent(Agent):
subsocket: zmq.Socket
pubsocket: zmq.Socket
address = ""
bind = False
def __init__(
self,
jid: str,
password: str,
port: int = 5222,
verify_security: bool = False,
address="tcp://localhost:0000",
bind=False,
):
super().__init__(jid, password, port, verify_security)
self.address = address
self.bind = bind
class SendCommandsBehaviour(CyclicBehaviour):
async def run(self):
"""
Run the command publishing loop indefinetely.
"""
assert self.agent is not None
# Get a message internally (with topic command)
topic, body = await self.agent.subsocket.recv_multipart()
# Try to get body
try:
body = json.loads(body)
message = SpeechCommand.model_validate(body)
# Send to the robot.
await self.agent.pubsocket.send_json(message.model_dump())
except Exception as e:
logger.error("Error processing message: %s", e)
async def setup(self):
"""
Setup the command agent
"""
logger.info("Setting up %s", self.jid)
# To the robot
self.pubsocket = context.socket(zmq.PUB)
if self.bind:
self.pubsocket.bind(self.address)
else:
self.pubsocket.connect(self.address)
# Receive internal topics regarding commands
self.subsocket = context.socket(zmq.SUB)
self.subsocket.connect(settings.zmq_settings.internal_comm_address)
self.subsocket.setsockopt(zmq.SUBSCRIBE, b"command")
# Add behaviour to our agent
commands_behaviour = self.SendCommandsBehaviour()
self.add_behaviour(commands_behaviour)
logger.info("Finished setting up %s", self.jid)

View File

@@ -0,0 +1,165 @@
import asyncio
import json
import logging
from spade.agent import Agent
from spade.behaviour import CyclicBehaviour
import zmq
from control_backend.core.config import settings
from control_backend.core.zmq_context import context
from control_backend.schemas.message import Message
from control_backend.agents.ri_command_agent import RICommandAgent
logger = logging.getLogger(__name__)
class RICommunicationAgent(Agent):
req_socket: zmq.Socket
_address = ""
_bind = True
def __init__(
self,
jid: str,
password: str,
port: int = 5222,
verify_security: bool = False,
address="tcp://localhost:0000",
bind=False,
):
super().__init__(jid, password, port, verify_security)
self._address = address
self._bind = bind
class ListenBehaviour(CyclicBehaviour):
async def run(self):
"""
Run the listening (ping) loop indefinetely.
"""
assert self.agent is not None
# We need to listen and sent pings.
message = {"endpoint": "ping", "data": {"id": "e.g. some reference id"}}
await self.agent.req_socket.send_json(message)
# Wait up to three seconds for a reply:)
try:
message = await asyncio.wait_for(self.agent.req_socket.recv_json(), timeout=3.0)
# We didnt get a reply :(
except asyncio.TimeoutError as e:
logger.info("No ping retrieved in 3 seconds, killing myself.")
self.kill()
logger.debug('Received message "%s"', message)
if "endpoint" not in message:
logger.error("No received endpoint in message, excepted ping endpoint.")
return
# See what endpoint we received
match message["endpoint"]:
case "ping":
await asyncio.sleep(1)
case _:
logger.info(
"Received message with topic different than ping, while ping expected."
)
async def setup(self, max_retries: int = 5):
"""
Try to setup the communication agent, we have 5 retries in case we dont have a response yet.
"""
logger.info("Setting up %s", self.jid)
retries = 0
# Let's try a certain amount of times before failing connection
while retries < max_retries:
# Bind request socket
self.req_socket = context.socket(zmq.REQ)
if self._bind:
self.req_socket.bind(self._address)
else:
self.req_socket.connect(self._address)
# Send our message and receive one back:)
message = {"endpoint": "negotiate/ports", "data": None}
await self.req_socket.send_json(message)
try:
received_message = await asyncio.wait_for(self.req_socket.recv_json(), timeout=20.0)
except asyncio.TimeoutError:
logger.warning(
"No connection established in 20 seconds (attempt %d/%d)",
retries + 1,
max_retries,
)
retries += 1
continue
except Exception as e:
logger.error("Unexpected error during negotiation: %s", e)
retries += 1
continue
# Validate endpoint
endpoint = received_message.get("endpoint")
if endpoint != "negotiate/ports":
# TODO: Should this send a message back?
logger.error(
"Invalid endpoint '%s' received (attempt %d/%d)",
endpoint,
retries + 1,
max_retries,
)
retries += 1
continue
# At this point, we have a valid response
try:
for port_data in received_message["data"]:
id = port_data["id"]
port = port_data["port"]
bind = port_data["bind"]
if not bind:
addr = f"tcp://localhost:{port}"
else:
addr = f"tcp://*:{port}"
match id:
case "main":
if addr != self._address:
if not bind:
self.req_socket.connect(addr)
else:
self.req_socket.bind(addr)
case "actuation":
ri_commands_agent = RICommandAgent(
settings.agent_settings.ri_command_agent_name
+ "@"
+ settings.agent_settings.host,
settings.agent_settings.ri_command_agent_name,
address=addr,
bind=bind,
)
await ri_commands_agent.start()
case _:
logger.warning("Unhandled negotiation id: %s", id)
except Exception as e:
logger.error("Error unpacking negotiation data: %s", e)
retries += 1
continue
# setup succeeded
break
else:
logger.error("Failed to set up RICommunicationAgent after %d retries", max_retries)
return
# Set up ping behaviour
listen_behaviour = self.ListenBehaviour()
self.add_behaviour(listen_behaviour)
logger.info("Finished setting up %s", self.jid)

View File

@@ -0,0 +1,156 @@
import logging
import numpy as np
import torch
import zmq
import zmq.asyncio as azmq
from spade.agent import Agent
from spade.behaviour import CyclicBehaviour
from control_backend.core.config import settings
from control_backend.core.zmq_context import context as zmq_context
logger = logging.getLogger(__name__)
class SocketPoller[T]:
"""
Convenience class for polling a socket for data with a timeout, persisting a zmq.Poller for
multiple usages.
"""
def __init__(self, socket: azmq.Socket, timeout_ms: int = 100):
"""
:param socket: The socket to poll and get data from.
:param timeout_ms: A timeout in milliseconds to wait for data.
"""
self.socket = socket
self.poller = zmq.Poller()
self.poller.register(self.socket, zmq.POLLIN)
self.timeout_ms = timeout_ms
async def poll(self, timeout_ms: int | None = None) -> T | None:
"""
Get data from the socket, or None if the timeout is reached.
:param timeout_ms: If given, the timeout. Otherwise, `self.timeout_ms` is used.
:return: Data from the socket or None.
"""
timeout_ms = timeout_ms or self.timeout_ms
socks = dict(self.poller.poll(timeout_ms))
if socks.get(self.socket) == zmq.POLLIN:
return await self.socket.recv()
return None
class Streaming(CyclicBehaviour):
def __init__(self, audio_in_socket: azmq.Socket, audio_out_socket: azmq.Socket):
super().__init__()
self.audio_in_poller = SocketPoller[bytes](audio_in_socket)
self.model, _ = torch.hub.load(
repo_or_dir="snakers4/silero-vad", model="silero_vad", force_reload=False
)
self.audio_out_socket = audio_out_socket
self.audio_buffer = np.array([], dtype=np.float32)
self.i_since_speech = 100 # Used to allow small pauses in speech
async def run(self) -> None:
data = await self.audio_in_poller.poll()
if data is None:
if len(self.audio_buffer) > 0:
logger.debug("No audio data received. Discarding buffer until new data arrives.")
self.audio_buffer = np.array([], dtype=np.float32)
self.i_since_speech = 100
return
# copy otherwise Torch will be sad that it's immutable
chunk = np.frombuffer(data, dtype=np.float32).copy()
prob = self.model(torch.from_numpy(chunk), 16000).item()
if prob > 0.5:
if self.i_since_speech > 3:
logger.debug("Speech started.")
self.audio_buffer = np.append(self.audio_buffer, chunk)
self.i_since_speech = 0
return
self.i_since_speech += 1
# prob < 0.5, so speech maybe ended. Wait a bit more before to be more certain
if self.i_since_speech <= 3:
self.audio_buffer = np.append(self.audio_buffer, chunk)
return
# Speech probably ended. Make sure we have a usable amount of data.
if len(self.audio_buffer) >= 3 * len(chunk):
logger.debug("Speech ended.")
await self.audio_out_socket.send(self.audio_buffer[: -2 * len(chunk)].tobytes())
# At this point, we know that the speech has ended.
# Prepend the last chunk that had no speech, for a more fluent boundary
self.audio_buffer = chunk
class VADAgent(Agent):
"""
An agent which listens to an audio stream, does Voice Activity Detection (VAD), and sends
fragments with detected speech to other agents over ZeroMQ.
"""
def __init__(self, audio_in_address: str, audio_in_bind: bool):
jid = settings.agent_settings.vad_agent_name + "@" + settings.agent_settings.host
super().__init__(jid, settings.agent_settings.vad_agent_name)
self.audio_in_address = audio_in_address
self.audio_in_bind = audio_in_bind
self.audio_in_socket: azmq.Socket | None = None
self.audio_out_socket: azmq.Socket | None = None
async def stop(self):
"""
Stop listening to audio, stop publishing audio, close sockets.
"""
if self.audio_in_socket is not None:
self.audio_in_socket.close()
self.audio_in_socket = None
if self.audio_out_socket is not None:
self.audio_out_socket.close()
self.audio_out_socket = None
return await super().stop()
def _connect_audio_in_socket(self):
self.audio_in_socket = zmq_context.socket(zmq.SUB)
self.audio_in_socket.setsockopt_string(zmq.SUBSCRIBE, "")
if self.audio_in_bind:
self.audio_in_socket.bind(self.audio_in_address)
else:
self.audio_in_socket.connect(self.audio_in_address)
self.audio_in_poller = SocketPoller[bytes](self.audio_in_socket)
def _connect_audio_out_socket(self) -> int | None:
"""Returns the port bound, or None if binding failed."""
try:
self.audio_out_socket = zmq_context.socket(zmq.PUB)
return self.audio_out_socket.bind_to_random_port("tcp://*", max_tries=100)
except zmq.ZMQBindError:
logger.error("Failed to bind an audio output socket after 100 tries.")
self.audio_out_socket = None
return None
async def setup(self):
logger.info("Setting up %s", self.jid)
self._connect_audio_in_socket()
audio_out_port = self._connect_audio_out_socket()
if audio_out_port is None:
await self.stop()
return
streaming = Streaming(self.audio_in_socket, self.audio_out_socket)
self.add_behaviour(streaming)
# ... start agents dependent on the output audio fragments here
logger.info("Finished setting up %s", self.jid)

View File

@@ -0,0 +1,22 @@
from fastapi import APIRouter, Request
import logging
from zmq import Socket
from control_backend.schemas.ri_message import SpeechCommand, RIEndpoint
logger = logging.getLogger(__name__)
router = APIRouter()
@router.post("/command", status_code=202)
async def receive_command(command: SpeechCommand, request: Request):
# Validate and retrieve data.
SpeechCommand.model_validate(command)
topic = b"command"
pub_socket: Socket = request.app.state.internal_comm_socket
pub_socket.send_multipart([topic, command.model_dump_json().encode()])
return {"status": "Command received"}

View File

@@ -1,9 +1,11 @@
from fastapi.routing import APIRouter
from control_backend.api.v1.endpoints import message, sse
from control_backend.api.v1.endpoints import message, sse, command
api_router = APIRouter()
api_router.include_router(message.router, tags=["Messages"])
api_router.include_router(sse.router, tags=["SSE"])
api_router.include_router(command.router, tags=["Commands"])

View File

@@ -10,7 +10,10 @@ class AgentSettings(BaseModel):
host: str = "localhost"
bdi_core_agent_name: str = "bdi_core"
belief_collector_agent_name: str = "belief_collector"
test_agent_name: str = "test_agent"
vad_agent_name: str = "vad_agent"
ri_communication_agent_name: str = "ri_communication_agent"
ri_command_agent_name: str = "ri_command_agent"
class Settings(BaseSettings):

View File

@@ -9,7 +9,9 @@ from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
# Internal imports
from control_backend.agents.ri_communication_agent import RICommunicationAgent
from control_backend.agents.bdi.bdi_core import BDICoreAgent
from control_backend.agents.vad_agent import VADAgent
from control_backend.api.v1.router import api_router
from control_backend.core.config import settings
from control_backend.core.zmq_context import context
@@ -30,6 +32,14 @@ async def lifespan(app: FastAPI):
logger.info("Internal publishing socket bound to %s", internal_comm_socket)
# Initiate agents
ri_communication_agent = RICommunicationAgent(
settings.agent_settings.ri_communication_agent_name + "@" + settings.agent_settings.host,
settings.agent_settings.ri_communication_agent_name,
address="tcp://*:5555",
bind=True,
)
await ri_communication_agent.start()
bdi_core = BDICoreAgent(
settings.agent_settings.bdi_core_agent_name + "@" + settings.agent_settings.host,
settings.agent_settings.bdi_core_agent_name,
@@ -37,6 +47,9 @@ async def lifespan(app: FastAPI):
)
await bdi_core.start()
_temp_vad_agent = VADAgent("tcp://localhost:5558", False)
await _temp_vad_agent.start()
yield
logger.info("%s shutting down.", app.title)

View File

@@ -0,0 +1,20 @@
from enum import Enum
from typing import Any, Literal
from pydantic import BaseModel, Field, ValidationError
class RIEndpoint(str, Enum):
SPEECH = "actuate/speech"
PING = "ping"
NEGOTIATE_PORTS = "negotiate/ports"
class RIMessage(BaseModel):
endpoint: RIEndpoint
data: Any
class SpeechCommand(RIMessage):
endpoint: RIEndpoint = RIEndpoint(RIEndpoint.SPEECH)
data: str