Merge remote-tracking branch 'origin/dev' into feat/vad-agent

# Conflicts:
#	pyproject.toml
#	src/control_backend/main.py
#	uv.lock
This commit is contained in:
Twirre Meulenbelt
2025-10-28 10:44:03 +01:00
18 changed files with 271 additions and 95 deletions

View File

@@ -5,13 +5,15 @@ from spade_bdi.bdi import BDIAgent
from control_backend.agents.bdi.behaviours.belief_setter import BeliefSetter
class BDICoreAgent(BDIAgent):
"""
This is the Brain agent that does the belief inference with AgentSpeak.
This is the Brain agent that does the belief inference with AgentSpeak.
This is a continuous process that happens automatically in the background.
This class contains all the actions that can be called from AgentSpeak plans.
It has the BeliefSetter behaviour.
"""
logger = logging.getLogger("BDI Core")
async def setup(self):
@@ -31,5 +33,3 @@ class BDICoreAgent(BDIAgent):
def _send_to_llm(self, message) -> str:
"""TODO: implement"""
return f"This is a reply to {message}"

View File

@@ -8,15 +8,17 @@ from spade_bdi.bdi import BDIAgent
from control_backend.core.config import settings
class BeliefSetter(CyclicBehaviour):
"""
This is the behaviour that the BDI agent runs.
This behaviour waits for incoming message and processes it based on sender.
Currently, t only waits for messages containing beliefs from Belief Collector and adds these to its KB.
This is the behaviour that the BDI agent runs. This behaviour waits for an incoming
message and processes it based on its sender. Currently, it only waits for messages
containing beliefs from BeliefCollector and adds these to its KB.
"""
agent: BDIAgent
logger = logging.getLogger("BDI/Belief Setter")
async def run(self):
msg = await self.receive(timeout=0.1)
if msg:
@@ -36,7 +38,8 @@ class BeliefSetter(CyclicBehaviour):
pass
def _process_belief_message(self, message: Message):
if not message.body: return
if not message.body:
return
match message.thread:
case "beliefs":
@@ -48,7 +51,6 @@ class BeliefSetter(CyclicBehaviour):
case _:
pass
def _set_beliefs(self, beliefs: dict[str, list[list[str]]]):
if self.agent.bdi is None:
self.logger.warning("Cannot set beliefs, since agent's BDI is not yet initialized.")

View File

@@ -18,6 +18,7 @@ class SocketPoller[T]:
Convenience class for polling a socket for data with a timeout, persisting a zmq.Poller for
multiple usages.
"""
def __init__(self, socket: azmq.Socket, timeout_ms: int = 100):
"""
:param socket: The socket to poll and get data from.
@@ -46,9 +47,9 @@ class Streaming(CyclicBehaviour):
def __init__(self, audio_in_socket: azmq.Socket, audio_out_socket: azmq.Socket):
super().__init__()
self.audio_in_poller = SocketPoller[bytes](audio_in_socket)
self.model, _ = torch.hub.load(repo_or_dir="snakers4/silero-vad",
model="silero_vad",
force_reload=False)
self.model, _ = torch.hub.load(
repo_or_dir="snakers4/silero-vad", model="silero_vad", force_reload=False
)
self.audio_out_socket = audio_out_socket
self.audio_buffer = np.array([], dtype=np.float32)
@@ -59,8 +60,10 @@ class Streaming(CyclicBehaviour):
data = await self.audio_in_poller.poll()
if data is None:
if self.i_since_data % 10 == 0:
logger.debug("Failed to receive audio from socket for %d ms.",
self.audio_in_poller.timeout_ms*(self.i_since_data+1))
logger.debug(
"Failed to receive audio from socket for %d ms.",
self.audio_in_poller.timeout_ms * (self.i_since_data + 1),
)
self.i_since_data += 1
return
self.i_since_data = 0
@@ -70,7 +73,8 @@ class Streaming(CyclicBehaviour):
prob = self.model(torch.from_numpy(chunk), 16000).item()
if prob > 0.5:
if self.i_since_speech > 3: logger.debug("Speech started.")
if self.i_since_speech > 3:
logger.debug("Speech started.")
self.audio_buffer = np.append(self.audio_buffer, chunk)
self.i_since_speech = 0
return
@@ -82,9 +86,9 @@ class Streaming(CyclicBehaviour):
return
# Speech probably ended. Make sure we have a usable amount of data.
if len(self.audio_buffer) >= 3*len(chunk):
if len(self.audio_buffer) >= 3 * len(chunk):
logger.debug("Speech ended.")
await self.audio_out_socket.send(self.audio_buffer[:-2*len(chunk)].tobytes())
await self.audio_out_socket.send(self.audio_buffer[: -2 * len(chunk)].tobytes())
# At this point, we know that the speech has ended.
# Prepend the last chunk that had no speech, for a more fluent boundary
@@ -96,8 +100,9 @@ class VADAgent(Agent):
An agent which listens to an audio stream, does Voice Activity Detection (VAD), and sends
fragments with detected speech to other agents over ZeroMQ.
"""
def __init__(self, audio_in_address: str, audio_in_bind: bool):
jid = settings.agent_settings.vad_agent_name + '@' + settings.agent_settings.host
jid = settings.agent_settings.vad_agent_name + "@" + settings.agent_settings.host
super().__init__(jid, settings.agent_settings.vad_agent_name)
self.audio_in_address = audio_in_address
@@ -146,7 +151,6 @@ class VADAgent(Agent):
if audio_out_port is None:
await self.stop()
return
audio_out_address = f"tcp://localhost:{audio_out_port}"
streaming = Streaming(self.audio_in_socket, self.audio_out_socket)
self.add_behaviour(streaming)

View File

@@ -1,6 +1,6 @@
from fastapi import APIRouter, Request
import logging
from fastapi import APIRouter, Request
from zmq import Socket
from control_backend.schemas.message import Message
@@ -9,6 +9,7 @@ logger = logging.getLogger(__name__)
router = APIRouter()
@router.post("/message", status_code=202)
async def receive_message(message: Message, request: Request):
logger.info("Received message: %s", message.message)

View File

@@ -2,7 +2,8 @@ from fastapi import APIRouter, Request
router = APIRouter()
# TODO: implement
@router.get("/sse")
async def sse(request: Request):
pass
pass

View File

@@ -4,12 +4,6 @@ from control_backend.api.v1.endpoints import message, sse
api_router = APIRouter()
api_router.include_router(
message.router,
tags=["Messages"]
)
api_router.include_router(message.router, tags=["Messages"])
api_router.include_router(
sse.router,
tags=["SSE"]
)
api_router.include_router(sse.router, tags=["SSE"])

View File

@@ -1,4 +1,3 @@
from re import L
from pydantic import BaseModel
from pydantic_settings import BaseSettings, SettingsConfigDict
@@ -22,7 +21,7 @@ class Settings(BaseSettings):
zmq_settings: ZMQSettings = ZMQSettings()
agent_settings: AgentSettings = AgentSettings()
model_config = SettingsConfigDict(env_file=".env")

View File

@@ -1,26 +1,24 @@
# Standard library imports
import asyncio
import json
# External imports
import contextlib
import logging
import zmq
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
import logging
from spade.agent import Agent, Message
from spade.behaviour import OneShotBehaviour
import zmq
# Internal imports
from control_backend.agents.bdi.bdi_core import BDICoreAgent
from control_backend.agents.vad_agent import VADAgent
from control_backend.api.v1.router import api_router
from control_backend.core.config import AgentSettings, settings
from control_backend.core.config import settings
from control_backend.core.zmq_context import context
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.DEBUG)
@contextlib.asynccontextmanager
async def lifespan(app: FastAPI):
logger.info("%s starting up.", app.title)
@@ -33,27 +31,33 @@ async def lifespan(app: FastAPI):
logger.info("Internal publishing socket bound to %s", internal_comm_socket)
# Initiate agents
bdi_core = BDICoreAgent(settings.agent_settings.bdi_core_agent_name + '@' + settings.agent_settings.host, settings.agent_settings.bdi_core_agent_name, "src/control_backend/agents/bdi/rules.asl")
bdi_core = BDICoreAgent(
settings.agent_settings.bdi_core_agent_name + "@" + settings.agent_settings.host,
settings.agent_settings.bdi_core_agent_name,
"src/control_backend/agents/bdi/rules.asl",
)
await bdi_core.start()
_temp_vad_agent = VADAgent("tcp://localhost:5558", False)
await _temp_vad_agent.start()
yield
logger.info("%s shutting down.", app.title)
# if __name__ == "__main__":
app = FastAPI(title=settings.app_title, lifespan=lifespan)
# This middleware allows other origins to communicate with us
app.add_middleware(
CORSMiddleware, # https://developer.mozilla.org/en-US/docs/Web/HTTP/Guides/CORS
allow_origins=[settings.ui_url], # address of our UI application
allow_methods=["*"], # GET, POST, etc.
CORSMiddleware, # https://developer.mozilla.org/en-US/docs/Web/HTTP/Guides/CORS
allow_origins=[settings.ui_url], # address of our UI application
allow_methods=["*"], # GET, POST, etc.
)
app.include_router(api_router, prefix="") # TODO: make prefix /api/v1
app.include_router(api_router, prefix="") # TODO: make prefix /api/v1
@app.get("/")
async def root():

View File

@@ -1,4 +1,5 @@
from pydantic import BaseModel
class Message(BaseModel):
message: str
message: str