fix: merge refactor/zmq-internal-socket-behaviour into feat/cb2ui-robot-connections (and fixed all ruff/test issues before committing).

ref: None
Björn Otgaar
2025-10-31 14:16:11 +01:00
38 changed files with 1761 additions and 167 deletions

View File

@@ -1,9 +1,15 @@
import logging
import agentspeak
from spade.behaviour import OneShotBehaviour
from spade.message import Message
from spade_bdi.bdi import BDIAgent
from control_backend.agents.bdi.behaviours.belief_setter import BeliefSetter
from control_backend.agents.bdi.behaviours.belief_setter import BeliefSetterBehaviour
from control_backend.agents.bdi.behaviours.receive_llm_resp_behaviour import (
ReceiveLLMResponseBehaviour,
)
from control_backend.core.config import settings
class BDICoreAgent(BDIAgent):
@@ -11,25 +17,52 @@ class BDICoreAgent(BDIAgent):
This is the Brain agent that does the belief inference with AgentSpeak.
This is a continuous process that happens automatically in the background.
This class contains all the actions that can be called from AgentSpeak plans.
It has the BeliefSetter behaviour.
It has the BeliefSetter behaviour and can send requests to and receive responses from the LLM agent.
"""
logger = logging.getLogger("BDI Core")
logger = logging.getLogger("bdi_core_agent")
async def setup(self):
belief_setter = BeliefSetter()
self.add_behaviour(belief_setter)
async def setup(self) -> None:
"""
Initializes belief behaviors and message routing.
"""
self.logger.info("BDICoreAgent setup started")
self.add_behaviour(BeliefSetterBehaviour())
self.add_behaviour(ReceiveLLMResponseBehaviour())
self.logger.info("BDICoreAgent setup complete")
def add_custom_actions(self, actions) -> None:
"""
Registers custom AgentSpeak actions callable from plans.
"""
def add_custom_actions(self, actions):
@actions.add(".reply", 1)
def _reply(agent, term, intention):
message = agentspeak.grounded(term.args[0], intention.scope)
self.logger.info(f"Replying to message: {message}")
reply = self._send_to_llm(message)
self.logger.info(f"Received reply: {reply}")
def _reply(agent: "BDICoreAgent", term, intention):
"""
Sends text to the LLM (AgentSpeak action).
Example: .reply("Hello LLM!")
"""
message_text = agentspeak.grounded(term.args[0], intention.scope)
self.logger.info("Reply action sending: %s", message_text)
self._send_to_llm(str(message_text))
yield
def _send_to_llm(self, message) -> str:
"""TODO: implement"""
return f"This is a reply to {message}"
def _send_to_llm(self, text: str):
"""
Sends a text query to the LLM Agent asynchronously.
"""
class SendBehaviour(OneShotBehaviour):
async def run(self) -> None:
msg = Message(
to=settings.agent_settings.llm_agent_name + "@" + settings.agent_settings.host,
body=text,
)
await self.send(msg)
self.agent.logger.info("Message sent to LLM: %s", text)
self.add_behaviour(SendBehaviour())

View File

@@ -1,4 +1,3 @@
import asyncio
import json
import logging
@@ -9,11 +8,10 @@ from spade_bdi.bdi import BDIAgent
from control_backend.core.config import settings
class BeliefSetter(CyclicBehaviour):
class BeliefSetterBehaviour(CyclicBehaviour):
"""
This is the behaviour that the BDI agent runs. This behaviour waits for incoming
messages and processes them based on the sender. Currently, it only waits for messages
containing beliefs from BeliefCollector and adds these to its KB.
messages and processes them based on the sender.
"""
agent: BDIAgent
@@ -24,7 +22,6 @@ class BeliefSetter(CyclicBehaviour):
if msg:
self.logger.info(f"Received message {msg.body}")
self._process_message(msg)
await asyncio.sleep(1)
def _process_message(self, message: Message):
sender = message.sender.node # removes host from jid and converts to str
@@ -35,6 +32,7 @@ class BeliefSetter(CyclicBehaviour):
self.logger.debug("Processing message from belief collector.")
self._process_belief_message(message)
case _:
self.logger.debug("Not the belief agent, discarding message")
pass
def _process_belief_message(self, message: Message):
@@ -44,19 +42,25 @@ class BeliefSetter(CyclicBehaviour):
match message.thread:
case "beliefs":
try:
beliefs: dict[str, list[list[str]]] = json.loads(message.body)
beliefs: dict[str, list[str]] = json.loads(message.body)
self._set_beliefs(beliefs)
except json.JSONDecodeError as e:
self.logger.error("Could not decode beliefs into JSON format: %s", e)
case _:
pass
def _set_beliefs(self, beliefs: dict[str, list[list[str]]]):
def _set_beliefs(self, beliefs: dict[str, list[str]]):
"""Remove previous values for beliefs and update them with the provided values."""
if self.agent.bdi is None:
self.logger.warning("Cannot set beliefs, since agent's BDI is not yet initialized.")
return
for belief, arguments_list in beliefs.items():
for arguments in arguments_list:
self.agent.bdi.set_belief(belief, *arguments)
self.logger.info("Set belief %s with arguments %s", belief, arguments)
# Set new beliefs (outdated beliefs are automatically removed)
for belief, arguments in beliefs.items():
self.agent.bdi.set_belief(belief, *arguments)
# Special case: if there's a new user message, flag that we haven't responded yet
if belief == "user_said":
self.agent.bdi.set_belief("new_message")
self.logger.info("Set belief %s with arguments %s", belief, arguments)
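# Example of the flow above: for a transcription "hello test", the demo belief
# extractor (later in this commit) has the collector forward a message with
# thread="beliefs" and body='{"user_said": ["hello test"]}', which results in
# self.agent.bdi.set_belief("user_said", "hello test") followed by
# self.agent.bdi.set_belief("new_message").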

View File

@@ -0,0 +1,28 @@
import logging
from spade.behaviour import CyclicBehaviour
from control_backend.core.config import settings
class ReceiveLLMResponseBehaviour(CyclicBehaviour):
"""
Behaviour that receives responses from the LLM Agent.
"""
logger = logging.getLogger("BDI/LLM Receiver")
async def run(self):
msg = await self.receive(timeout=2)
if not msg:
return
sender = msg.sender.node
match sender:
case settings.agent_settings.llm_agent_name:
content = msg.body
self.logger.info("Received LLM response: %s", content)
# Here the BDI can pass the message back as a response
case _:
self.logger.debug("Not from the llm, discarding message")
pass

View File

@@ -0,0 +1,100 @@
import asyncio
import json
import logging
from spade.behaviour import CyclicBehaviour
from spade.message import Message
from control_backend.core.config import settings
class BeliefFromText(CyclicBehaviour):
logger = logging.getLogger("Belief From Text")
# TODO: LLM prompt is still hardcoded
llm_instruction_prompt = """
You are an information extraction assistant for a BDI agent.
Your task is to extract values from a user's text to bind a list of ungrounded beliefs. Rules:
You will receive a JSON object with "beliefs"
(a list of ungrounded AgentSpeak beliefs) and "text" (user's transcript).
Analyze the text to find values that semantically match the variables (X,Y,Z) in the beliefs.
A single piece of text might contain multiple instances that match a belief.
Respond ONLY with a single JSON object.
The JSON object's keys should be the belief functors (e.g., "weather").
The value for each key must be a list of lists.
Each inner list must contain the extracted arguments
(as strings) for one instance of that belief.
CRITICAL: If no information in the text matches a belief,
DO NOT include that key in your response.
"""
# on_start agent receives message containing the beliefs to look out
# for and sets up the LLM with instruction prompt
# async def on_start(self):
# msg = await self.receive(timeout=0.1)
# self.beliefs = dict from the message
# send instruction prompt to LLM
beliefs: dict[str, list[str]]
beliefs = {"mood": ["X"], "car": ["Y"]}
async def run(self):
msg = await self.receive(timeout=0.1)
if msg:
sender = msg.sender.node
match sender:
case settings.agent_settings.transcription_agent_name:
self.logger.info("Received text from transcriber.")
await self._process_transcription_demo(msg.body)
case _:
self.logger.info("Received message from other agent.")
pass
await asyncio.sleep(1)
async def _process_transcription(self, text: str):
text_prompt = f"Text: {text}"
beliefs_prompt = "These are the beliefs to be bound:\n"
for belief, values in self.beliefs.items():
beliefs_prompt += f"{belief}({', '.join(values)})\n"
prompt = text_prompt + beliefs_prompt
self.logger.info(prompt)
# prompt_msg = Message(to="LLMAgent@whatever")
# response = self.send(prompt_msg)
# Mock response; the response is beliefs in JSON format, which parses to dict[str, list[list[str]]]
response = '{"mood": [["happy"]]}'
# Verify by trying to parse
try:
json.loads(response)
belief_message = Message(
to=settings.agent_settings.bdi_core_agent_name + "@" + settings.agent_settings.host,
body=response,
)
belief_message.thread = "beliefs"
await self.send(belief_message)
self.logger.info("Sent beliefs to BDI.")
except json.JSONDecodeError:
# Parsing failed, so the response is in the wrong format, log warning
self.logger.warning("Received LLM response in incorrect format.")
async def _process_transcription_demo(self, txt: str):
"""
Demo version to process the transcription input to beliefs. For the demo only the belief
'user_said' is relevant, so this function simply makes a dict with key: "user_said",
value: txt and passes this to the Belief Collector agent.
"""
belief = {"beliefs": {"user_said": [txt]}, "type": "belief_extraction_text"}
payload = json.dumps(belief)
belief_msg = Message(
to=settings.agent_settings.belief_collector_agent_name
+ "@"
+ settings.agent_settings.host,
body=payload,
)
belief_msg.thread = "beliefs"
await self.send(belief_msg)
self.logger.info("Sent beliefs to Belief Collector.")

View File

@@ -1,3 +1,3 @@
+user_said(Message) : not responded <-
+responded;
+new_message : user_said(Message) <-
-new_message;
.reply(Message).
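A sketch of the round trip this plan enables, based on the code elsewhere in this commit:

    self.agent.bdi.set_belief("user_said", "hello")   # BeliefSetterBehaviour
    self.agent.bdi.set_belief("new_message")
    # +new_message fires, retracts new_message, and calls .reply("hello"),
    # which BDICoreAgent._send_to_llm forwards to the LLM agent as a SPADE message.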

View File

@@ -0,0 +1,9 @@
from spade.agent import Agent
from control_backend.agents.bdi.behaviours.text_belief_extractor import BeliefFromText
class TBeliefExtractor(Agent):
async def setup(self):
self.b = BeliefFromText()
self.add_behaviour(self.b)

View File

@@ -0,0 +1,117 @@
import json
import logging
from spade.agent import Message
from spade.behaviour import CyclicBehaviour
from control_backend.core.config import settings
logger = logging.getLogger(__name__)
class ContinuousBeliefCollector(CyclicBehaviour):
"""
Continuously collects beliefs/emotions from extractor agents,
then sends a unified belief packet to the BDI agent.
"""
async def run(self):
msg = await self.receive(timeout=0.1) # Wait for 0.1s
if msg:
await self._process_message(msg)
async def _process_message(self, msg: Message):
sender_node = self._sender_node(msg)
# Parse JSON payload
try:
payload = json.loads(msg.body)
except Exception as e:
logger.warning(
"BeliefCollector: failed to parse JSON from %s. Body=%r Error=%s",
sender_node,
msg.body,
e,
)
return
msg_type = payload.get("type")
# Prefer explicit 'type' field
if msg_type == "belief_extraction_text" or sender_node == "belief_text_agent_mock":
logger.info(
"BeliefCollector: message routed to _handle_belief_text (sender=%s)", sender_node
)
await self._handle_belief_text(payload, sender_node)
# This is not implemented yet, but we keep the structure for future use
elif msg_type == "emotion_extraction_text" or sender_node == "emo_text_agent_mock":
logger.info(
"BeliefCollector: message routed to _handle_emo_text (sender=%s)", sender_node
)
await self._handle_emo_text(payload, sender_node)
else:
logger.info(
"BeliefCollector: unrecognized message (sender=%s, type=%r). Ignoring.",
sender_node,
msg_type,
)
@staticmethod
def _sender_node(msg: Message) -> str:
"""
Extracts the 'node' (localpart) of the sender JID.
E.g., 'agent@host/resource' -> 'agent'
"""
s = str(msg.sender) if msg.sender is not None else "no_sender"
return s.split("@", 1)[0] if "@" in s else s
async def _handle_belief_text(self, payload: dict, origin: str):
"""
Expected payload:
{
"type": "belief_extraction_text",
"beliefs": {"user_said": ["hello"","Can you help me?",
"stop talking to me","No","Pepper do a dance"]}
}
"""
beliefs = payload.get("beliefs", {})
if not beliefs:
logger.info("BeliefCollector: no beliefs to process.")
return
if not isinstance(beliefs, dict):
logger.warning("BeliefCollector: 'beliefs' is not a dict: %r", beliefs)
return
if not all(isinstance(v, list) for v in beliefs.values()):
logger.warning("BeliefCollector: 'beliefs' values are not all lists: %r", beliefs)
return
logger.info("BeliefCollector: forwarding %d beliefs.", len(beliefs))
for belief_name, belief_list in beliefs.items():
for belief in belief_list:
logger.info(" - %s %s", belief_name, str(belief))
await self._send_beliefs_to_bdi(beliefs, origin=origin)
async def _handle_emo_text(self, payload: dict, origin: str):
"""TODO: implement (after we have emotional recogntion)"""
pass
async def _send_beliefs_to_bdi(self, beliefs: dict[str, list[str]], origin: str | None = None):
"""
Sends a unified belief packet to the BDI agent.
"""
if not beliefs:
return
to_jid = f"{settings.agent_settings.bdi_core_agent_name}@{settings.agent_settings.host}"
msg = Message(to=to_jid, sender=self.agent.jid, thread="beliefs")
msg.body = json.dumps(beliefs)
await self.send(msg)
logger.info("BeliefCollector: sent %d belief(s) to BDI at %s", len(beliefs), to_jid)

View File

@@ -0,0 +1,15 @@
import logging
from spade.agent import Agent
from .behaviours.continuous_collect import ContinuousBeliefCollector
logger = logging.getLogger(__name__)
class BeliefCollectorAgent(Agent):
async def setup(self):
logger.info("BeliefCollectorAgent starting (%s)", self.jid)
# Attach the continuous collector behaviour (listens and forwards to BDI)
self.add_behaviour(ContinuousBeliefCollector())
logger.info("BeliefCollectorAgent ready.")

View File

@@ -0,0 +1,123 @@
"""
LLM Agent module for routing text queries from the BDI Core Agent to a local LLM
service and returning its responses back to the BDI Core Agent.
"""
import logging
from typing import Any
import httpx
from spade.agent import Agent
from spade.behaviour import CyclicBehaviour
from spade.message import Message
from control_backend.agents.llm.llm_instructions import LLMInstructions
from control_backend.core.config import settings
class LLMAgent(Agent):
"""
Agent responsible for processing user text input and querying a locally
hosted LLM for text generation. Receives messages from the BDI Core Agent
and responds with processed LLM output.
"""
logger = logging.getLogger("llm_agent")
class ReceiveMessageBehaviour(CyclicBehaviour):
"""
Cyclic behaviour to continuously listen for incoming messages from
the BDI Core Agent and handle them.
"""
async def run(self):
"""
Receives SPADE messages and processes only those originating from the
configured BDI agent.
"""
msg = await self.receive(timeout=1)
if not msg:
return
sender = msg.sender.node
self.agent.logger.info(
"Received message: %s from %s",
msg.body,
sender,
)
if sender == settings.agent_settings.bdi_core_agent_name:
self.agent.logger.debug("Processing message from BDI Core Agent")
await self._process_bdi_message(msg)
else:
self.agent.logger.debug("Message ignored (not from BDI Core Agent)")
async def _process_bdi_message(self, message: Message):
"""
Forwards user text to the LLM and replies with the generated text.
"""
user_text = message.body
llm_response = await self._query_llm(user_text)
await self._reply(llm_response)
async def _reply(self, msg: str):
"""
Sends a response message back to the BDI Core Agent.
"""
reply = Message(
to=settings.agent_settings.bdi_core_agent_name + "@" + settings.agent_settings.host,
body=msg,
)
await self.send(reply)
self.agent.logger.info("Reply sent to BDI Core Agent")
async def _query_llm(self, prompt: str) -> str:
"""
Sends a chat completion request to the local LLM service.
:param prompt: Input text prompt to pass to the LLM.
:return: LLM-generated content or fallback message.
"""
async with httpx.AsyncClient(timeout=120.0) as client:
# Example dynamic content for future (optional)
instructions = LLMInstructions()
developer_instruction = instructions.build_developer_instruction()
response = await client.post(
settings.llm_settings.local_llm_url,
headers={"Content-Type": "application/json"},
json={
"model": settings.llm_settings.local_llm_model,
"messages": [
{"role": "developer", "content": developer_instruction},
{"role": "user", "content": prompt},
],
"temperature": 0.3,
},
)
try:
response.raise_for_status()
data: dict[str, Any] = response.json()
return (
data.get("choices", [{}])[0]
.get("message", {})
.get("content", "No response")
)
except httpx.HTTPError as err:
self.agent.logger.error("HTTP error: %s", err)
return "LLM service unavailable."
except Exception as err:
self.agent.logger.error("Unexpected error: %s", err)
return "Error processing the request."
async def setup(self):
"""
Sets up the SPADE behaviour to filter and process messages from the
BDI Core Agent.
"""
self.logger.info("LLMAgent setup complete")
behaviour = self.ReceiveMessageBehaviour()
self.add_behaviour(behaviour)

View File

@@ -0,0 +1,44 @@
class LLMInstructions:
"""
Defines structured instructions that are sent along with each request
to the LLM to guide its behavior (norms, goals, etc.).
"""
@staticmethod
def default_norms() -> str:
return """
Be friendly and respectful.
Make the conversation feel natural and engaging.
""".strip()
@staticmethod
def default_goals() -> str:
return """
Try to learn the user's name during conversation.
""".strip()
def __init__(self, norms: str | None = None, goals: str | None = None):
self.norms = norms if norms is not None else self.default_norms()
self.goals = goals if goals is not None else self.default_goals()
def build_developer_instruction(self) -> str:
"""
Builds a multi-line formatted instruction string for the LLM.
Includes only non-empty structured fields.
"""
sections = [
"You are a Pepper robot engaging in natural human conversation.",
"Keep responses between 15 sentences, unless instructed otherwise.\n",
]
if self.norms:
sections.append("Norms to follow:")
sections.append(self.norms)
sections.append("")
if self.goals:
sections.append("Goals to reach:")
sections.append(self.goals)
sections.append("")
return "\n".join(sections).strip()

View File

@@ -0,0 +1,43 @@
import json
from spade.agent import Agent
from spade.behaviour import OneShotBehaviour
from spade.message import Message
from control_backend.core.config import settings
class BeliefTextAgent(Agent):
class SendOnceBehaviourBlfText(OneShotBehaviour):
async def run(self):
to_jid = (
f"{settings.agent_settings.belief_collector_agent_name}"
f"@{settings.agent_settings.host}"
)
# Send multiple beliefs in one JSON payload
payload = {
"type": "belief_extraction_text",
"beliefs": {
"user_said": [
"hello test",
"Can you help me?",
"stop talking to me",
"No",
"Pepper do a dance",
]
},
}
msg = Message(to=to_jid)
msg.body = json.dumps(payload)
await self.send(msg)
print(f"Beliefs sent to {to_jid}!")
self.exit_code = "Job Finished!"
await self.agent.stop()
async def setup(self):
print("BeliefTextAgent started")
self.b = self.SendOnceBehaviourBlfText()
self.add_behaviour(self.b)

View File

@@ -4,9 +4,9 @@ import logging
import zmq
from spade.agent import Agent
from spade.behaviour import CyclicBehaviour
from zmq.asyncio import Context
from control_backend.core.config import settings
from control_backend.core.zmq_context import context
from control_backend.schemas.ri_message import SpeechCommand
logger = logging.getLogger(__name__)
@@ -56,16 +56,18 @@ class RICommandAgent(Agent):
"""
logger.info("Setting up %s", self.jid)
context = Context.instance()
# To the robot
self.pubsocket = context.socket(zmq.PUB)
if self.bind:
if self.bind: # TODO: Should this ever be the case?
self.pubsocket.bind(self.address)
else:
self.pubsocket.connect(self.address)
# Receive internal topics regarding commands
self.subsocket = context.socket(zmq.SUB)
self.subsocket.connect(settings.zmq_settings.internal_comm_address)
self.subsocket.connect(settings.zmq_settings.internal_sub_address)
self.subsocket.setsockopt(zmq.SUBSCRIBE, b"command")
# Add behaviour to our agent

View File

@@ -2,21 +2,18 @@ import asyncio
import json
import logging
import zmq
import zmq.asyncio
from spade.agent import Agent
from spade.behaviour import CyclicBehaviour
from zmq.asyncio import Context
from control_backend.agents.ri_command_agent import RICommandAgent
from control_backend.core.config import settings
from control_backend.core.zmq_context import context
logger = logging.getLogger(__name__)
class RICommunicationAgent(Agent):
_pub_socket: zmq.asyncio.Socket
req_socket: zmq.asyncio.Socket | None
_address = ""
_bind = True
connected = False
@@ -25,7 +22,6 @@ class RICommunicationAgent(Agent):
self,
jid: str,
password: str,
pub_socket: zmq.asyncio.Socket,
port: int = 5222,
verify_security: bool = False,
address="tcp://localhost:0000",
@@ -34,8 +30,8 @@ class RICommunicationAgent(Agent):
super().__init__(jid, password, port, verify_security)
self._address = address
self._bind = bind
self.req_socket = None
self._pub_socket = pub_socket
self._req_socket: zmq.asyncio.Socket | None = None
self.pub_socket: zmq.asyncio.Socket | None = None
class ListenBehaviour(CyclicBehaviour):
async def run(self):
@@ -49,7 +45,7 @@ class RICommunicationAgent(Agent):
seconds_to_wait_total = 1.0
try:
await asyncio.wait_for(
self.agent.req_socket.send_json(message), timeout=seconds_to_wait_total / 2
self.agent._req_socket.send_json(message), timeout=seconds_to_wait_total / 2
)
except TimeoutError:
logger.debug(
@@ -61,23 +57,13 @@ class RICommunicationAgent(Agent):
try:
logger.debug(f"waiting for message for {seconds_to_wait_total / 2} seconds.")
message = await asyncio.wait_for(
self.agent.req_socket.recv_json(), timeout=seconds_to_wait_total / 2
self.agent._req_socket.recv_json(), timeout=seconds_to_wait_total / 2
)
# We didn't get a reply :(
except TimeoutError:
logger.info(
f"No ping back retrieved in {seconds_to_wait_total / 2} seconds totalling"
f"{seconds_to_wait_total} of time, killing myself (or maybe just laying low)."
)
# TODO: Send event to UI letting know that we've lost connection
topic = b"ping"
data = json.dumps(False).encode()
self.agent._pub_socket.send_multipart([topic, data])
await self.agent.setup()
except Exception as e:
logger.debug(f"Differennt exception: {e}")
logger.info("No ping retrieved in 3 seconds, killing myself.")
self.kill()
logger.debug('Received message "%s"', message)
if "endpoint" not in message:
@@ -89,46 +75,53 @@ class RICommunicationAgent(Agent):
case "ping":
topic = b"ping"
data = json.dumps(True).encode()
await self.agent._pub_socket.send_multipart([topic, data])
if self.agent.pub_socket is not None:
await self.agent.pub_socket.send_multipart([topic, data])
await asyncio.sleep(1)
case _:
logger.info(
"Received message with topic different than ping, while ping expected."
)
async def setup_req_socket(self, force=False):
async def setup_sockets(self, force=False):
"""
Sets up the request and publish sockets for the communication agent.
"""
if self.req_socket is None or force:
self.req_socket = context.socket(zmq.REQ)
if self._bind:
self.req_socket.bind(self._address)
# Bind request socket
if self._req_socket is None or force:
self._req_socket = Context.instance().socket(zmq.REQ)
if self._bind: # TODO: Should this ever be the case with new architecture?
self._req_socket.bind(self._address)
else:
self.req_socket.connect(self._address)
self._req_socket.connect(self._address)
async def setup(self, max_retries: int = 5):
# TODO: Check with Kasper
if self.pub_socket is None or force:
self.pub_socket = Context.instance().socket(zmq.PUB)
self.pub_socket.connect(settings.zmq_settings.internal_pub_address)
async def setup(self, max_retries: int = 100):
"""
Try to set up the communication agent, retrying up to max_retries times in case we don't have a response yet.
"""
logger.info("Setting up %s", self.jid)
# Bind request socket
await self.setup_req_socket()
await self.setup_sockets()
retries = 0
# Let's try a certain amount of times before failing connection
while retries < max_retries:
# Make sure the socket is properly setup.
if self.req_socket is None:
if self._req_socket is None:
continue
# Send our message and receive one back:)
message = {"endpoint": "negotiate/ports", "data": {}}
await self.req_socket.send_json(message)
await self._req_socket.send_json(message)
try:
received_message = await asyncio.wait_for(self.req_socket.recv_json(), timeout=20.0)
received_message = await asyncio.wait_for(self._req_socket.recv_json(), timeout=1.0)
except TimeoutError:
logger.warning(
@@ -173,9 +166,9 @@ class RICommunicationAgent(Agent):
case "main":
if addr != self._address:
if not bind:
self.req_socket.connect(addr)
else:
self.req_socket.bind(addr)
self._req_socket.connect(addr)
else: # TODO: Should this ever be the case?
self._req_socket.bind(addr)
case "actuation":
ri_commands_agent = RICommandAgent(
settings.agent_settings.ri_command_agent_name
@@ -205,9 +198,17 @@ class RICommunicationAgent(Agent):
listen_behaviour = self.ListenBehaviour()
self.add_behaviour(listen_behaviour)
# TODO: Let UI know that we're connected >:)
# Let UI know that we're connected >:)
topic = b"ping"
data = json.dumps(True).encode()
await self._pub_socket.send_multipart([topic, data])
if self.pub_socket is None:
logger.error("communication agent pub socket not correctly initialized.")
else:
try:
await asyncio.wait_for(self.pub_socket.send_multipart([topic, data]), 5)
except TimeoutError:
logger.error(
"Initial connection ping for router timed out in ri_communication_agent."
)
self.connected = True
logger.info("Finished setting up %s", self.jid)

View File

@@ -0,0 +1,2 @@
from .speech_recognizer import SpeechRecognizer as SpeechRecognizer
from .transcription_agent import TranscriptionAgent as TranscriptionAgent

View File

@@ -0,0 +1,108 @@
import abc
import sys
if sys.platform == "darwin":
import mlx.core as mx
import mlx_whisper
from mlx_whisper.transcribe import ModelHolder
import numpy as np
import torch
import whisper
class SpeechRecognizer(abc.ABC):
def __init__(self, limit_output_length=True):
"""
:param limit_output_length: When `True`, the length of the generated speech will be limited
by the length of the input audio and some heuristics.
"""
self.limit_output_length = limit_output_length
@abc.abstractmethod
def load_model(self): ...
@abc.abstractmethod
def recognize_speech(self, audio: np.ndarray) -> str:
"""
Recognize speech from the given audio sample.
:param audio: A full utterance sample. Audio must be 16 kHz, mono, np.float32, values in the
range [-1.0, 1.0].
:return: Recognized speech.
"""
@staticmethod
def _estimate_max_tokens(audio: np.ndarray) -> int:
"""
Estimate the maximum length of a given audio sample in tokens. Assumes a maximum speaking
rate of 300 words per minute (2x average), and assumes that 3 words is 4 tokens.
:param audio: The audio sample (16 kHz) to use for length estimation.
:return: The estimated length of the transcribed audio in tokens.
"""
length_seconds = len(audio) / 16_000
length_minutes = length_seconds / 60
word_count = length_minutes * 300
token_count = word_count / 3 * 4
return int(token_count)
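# Worked example of the estimate above: a 10-second clip is 10 / 60 minutes,
# so at 300 wpm that is 50 words, and 50 / 3 * 4 ≈ 66 tokens; this value is
# later passed as sample_len in the decoding options.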
def _get_decode_options(self, audio: np.ndarray) -> dict:
"""
:param audio: The audio sample (16 kHz) to use to determine options like max decode length.
:return: A dict that can be used to construct `whisper.DecodingOptions`.
"""
options = {}
if self.limit_output_length:
options["sample_len"] = self._estimate_max_tokens(audio)
return options
@staticmethod
def best_type():
"""Get the best type of SpeechRecognizer based on system capabilities."""
if torch.mps.is_available():
print("Choosing MLX Whisper model.")
return MLXWhisperSpeechRecognizer()
else:
print("Choosing reference Whisper model.")
return OpenAIWhisperSpeechRecognizer()
class MLXWhisperSpeechRecognizer(SpeechRecognizer):
def __init__(self, limit_output_length=True):
super().__init__(limit_output_length)
self.was_loaded = False
self.model_name = "mlx-community/whisper-small.en-mlx"
def load_model(self):
if self.was_loaded:
return
# There appears to be no dedicated mechanism to preload a model, but this `get_model` does
# store it in memory for later usage
ModelHolder.get_model(self.model_name, mx.float16)
self.was_loaded = True
def recognize_speech(self, audio: np.ndarray) -> str:
self.load_model()
return mlx_whisper.transcribe(
audio, path_or_hf_repo=self.model_name, **self._get_decode_options(audio)
)["text"]
class OpenAIWhisperSpeechRecognizer(SpeechRecognizer):
def __init__(self, limit_output_length=True):
super().__init__(limit_output_length)
self.model = None
def load_model(self):
if self.model is not None:
return
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
self.model = whisper.load_model("small.en", device=device)
def recognize_speech(self, audio: np.ndarray) -> str:
self.load_model()
return whisper.transcribe(
self.model, audio, **self._get_decode_options(audio)
)["text"]

View File

@@ -0,0 +1,84 @@
import asyncio
import logging
import numpy as np
import zmq
import zmq.asyncio as azmq
from spade.agent import Agent
from spade.behaviour import CyclicBehaviour
from spade.message import Message
from control_backend.agents.transcription.speech_recognizer import SpeechRecognizer
from control_backend.core.config import settings
logger = logging.getLogger(__name__)
class TranscriptionAgent(Agent):
"""
An agent which listens to audio fragments with voice, transcribes them, and sends the
transcription to other agents.
"""
def __init__(self, audio_in_address: str):
jid = settings.agent_settings.transcription_agent_name + "@" + settings.agent_settings.host
super().__init__(jid, settings.agent_settings.transcription_agent_name)
self.audio_in_address = audio_in_address
self.audio_in_socket: azmq.Socket | None = None
class Transcribing(CyclicBehaviour):
def __init__(self, audio_in_socket: azmq.Socket):
super().__init__()
self.audio_in_socket = audio_in_socket
self.speech_recognizer = SpeechRecognizer.best_type()
self._concurrency = asyncio.Semaphore(3)
def warmup(self):
"""Load the transcription model into memory to speed up the first transcription."""
self.speech_recognizer.load_model()
async def _transcribe(self, audio: np.ndarray) -> str:
async with self._concurrency:
return await asyncio.to_thread(self.speech_recognizer.recognize_speech, audio)
async def _share_transcription(self, transcription: str):
"""Share a transcription to the other agents that depend on it."""
receiver_jids = [
settings.agent_settings.text_belief_extractor_agent_name
+ "@"
+ settings.agent_settings.host,
] # Set message receivers here
for receiver_jid in receiver_jids:
message = Message(to=receiver_jid, body=transcription)
await self.send(message)
async def run(self) -> None:
audio = await self.audio_in_socket.recv()
audio = np.frombuffer(audio, dtype=np.float32)
speech = await self._transcribe(audio)
logger.info("Transcribed speech: %s", speech)
await self._share_transcription(speech)
async def stop(self):
self.audio_in_socket.close()
self.audio_in_socket = None
return await super().stop()
def _connect_audio_in_socket(self):
self.audio_in_socket = azmq.Context.instance().socket(zmq.SUB)
self.audio_in_socket.setsockopt_string(zmq.SUBSCRIBE, "")
self.audio_in_socket.connect(self.audio_in_address)
async def setup(self):
logger.info("Setting up %s", self.jid)
self._connect_audio_in_socket()
transcribing = self.Transcribing(self.audio_in_socket)
transcribing.warmup()
self.add_behaviour(transcribing)
logger.info("Finished setting up %s", self.jid)

View File

@@ -0,0 +1,159 @@
import logging
import numpy as np
import torch
import zmq
import zmq.asyncio as azmq
from spade.agent import Agent
from spade.behaviour import CyclicBehaviour
from control_backend.agents.transcription import TranscriptionAgent
from control_backend.core.config import settings
logger = logging.getLogger(__name__)
class SocketPoller[T]:
"""
Convenience class for polling a socket for data with a timeout, persisting a zmq.Poller for
multiple usages.
"""
def __init__(self, socket: azmq.Socket, timeout_ms: int = 100):
"""
:param socket: The socket to poll and get data from.
:param timeout_ms: A timeout in milliseconds to wait for data.
"""
self.socket = socket
self.poller = zmq.Poller()
self.poller.register(self.socket, zmq.POLLIN)
self.timeout_ms = timeout_ms
async def poll(self, timeout_ms: int | None = None) -> T | None:
"""
Get data from the socket, or None if the timeout is reached.
:param timeout_ms: If given, the timeout. Otherwise, `self.timeout_ms` is used.
:return: Data from the socket or None.
"""
timeout_ms = timeout_ms or self.timeout_ms
socks = dict(self.poller.poll(timeout_ms))
if socks.get(self.socket) == zmq.POLLIN:
return await self.socket.recv()
return None
class Streaming(CyclicBehaviour):
def __init__(self, audio_in_socket: azmq.Socket, audio_out_socket: azmq.Socket):
super().__init__()
self.audio_in_poller = SocketPoller[bytes](audio_in_socket)
self.model, _ = torch.hub.load(
repo_or_dir="snakers4/silero-vad", model="silero_vad", force_reload=False
)
self.audio_out_socket = audio_out_socket
self.audio_buffer = np.array([], dtype=np.float32)
self.i_since_speech = 100 # Used to allow small pauses in speech
async def run(self) -> None:
data = await self.audio_in_poller.poll()
if data is None:
if len(self.audio_buffer) > 0:
logger.debug("No audio data received. Discarding buffer until new data arrives.")
self.audio_buffer = np.array([], dtype=np.float32)
self.i_since_speech = 100
return
# copy otherwise Torch will be sad that it's immutable
chunk = np.frombuffer(data, dtype=np.float32).copy()
prob = self.model(torch.from_numpy(chunk), 16000).item()
if prob > 0.5:
if self.i_since_speech > 3:
logger.debug("Speech started.")
self.audio_buffer = np.append(self.audio_buffer, chunk)
self.i_since_speech = 0
return
self.i_since_speech += 1
# prob < 0.5, so speech may have ended. Wait a bit more to be more certain
if self.i_since_speech <= 3:
self.audio_buffer = np.append(self.audio_buffer, chunk)
return
# Speech probably ended. Make sure we have a usable amount of data.
if len(self.audio_buffer) >= 3 * len(chunk):
logger.debug("Speech ended.")
await self.audio_out_socket.send(self.audio_buffer[: -2 * len(chunk)].tobytes())
# At this point, we know that the speech has ended.
# Prepend the last chunk that had no speech, for a more fluent boundary
self.audio_buffer = chunk
class VADAgent(Agent):
"""
An agent which listens to an audio stream, does Voice Activity Detection (VAD), and sends
fragments with detected speech to other agents over ZeroMQ.
"""
def __init__(self, audio_in_address: str, audio_in_bind: bool):
jid = settings.agent_settings.vad_agent_name + "@" + settings.agent_settings.host
super().__init__(jid, settings.agent_settings.vad_agent_name)
self.audio_in_address = audio_in_address
self.audio_in_bind = audio_in_bind
self.audio_in_socket: azmq.Socket | None = None
self.audio_out_socket: azmq.Socket | None = None
async def stop(self):
"""
Stop listening to audio, stop publishing audio, close sockets.
"""
if self.audio_in_socket is not None:
self.audio_in_socket.close()
self.audio_in_socket = None
if self.audio_out_socket is not None:
self.audio_out_socket.close()
self.audio_out_socket = None
return await super().stop()
def _connect_audio_in_socket(self):
self.audio_in_socket = azmq.Context.instance().socket(zmq.SUB)
self.audio_in_socket.setsockopt_string(zmq.SUBSCRIBE, "")
if self.audio_in_bind:
self.audio_in_socket.bind(self.audio_in_address)
else:
self.audio_in_socket.connect(self.audio_in_address)
self.audio_in_poller = SocketPoller[bytes](self.audio_in_socket)
def _connect_audio_out_socket(self) -> int | None:
"""Returns the port bound, or None if binding failed."""
try:
self.audio_out_socket = azmq.Context.instance().socket(zmq.PUB)
return self.audio_out_socket.bind_to_random_port("tcp://*", max_tries=100)
except zmq.ZMQBindError:
logger.error("Failed to bind an audio output socket after 100 tries.")
self.audio_out_socket = None
return None
async def setup(self):
logger.info("Setting up %s", self.jid)
self._connect_audio_in_socket()
audio_out_port = self._connect_audio_out_socket()
if audio_out_port is None:
await self.stop()
return
audio_out_address = f"tcp://localhost:{audio_out_port}"
streaming = Streaming(self.audio_in_socket, self.audio_out_socket)
self.add_behaviour(streaming)
# Start agents dependent on the output audio fragments here
transcriber = TranscriptionAgent(audio_out_address)
await transcriber.start()
logger.info("Finished setting up %s", self.jid)

View File

@@ -1,7 +1,6 @@
import logging
from fastapi import APIRouter, Request
from zmq import Socket
from control_backend.schemas.message import Message
@@ -17,8 +16,7 @@ async def receive_message(message: Message, request: Request):
topic = b"message"
body = message.model_dump_json().encode("utf-8")
pub_socket: Socket = request.app.state.internal_comm_socket
pub_socket.send_multipart([topic, body])
pub_socket = request.app.state.endpoints_pub_socket
await pub_socket.send_multipart([topic, body])
return {"status": "Message received"}

View File

@@ -5,10 +5,9 @@ import logging
import zmq.asyncio
from fastapi import APIRouter, Request
from fastapi.responses import StreamingResponse
from zmq.asyncio import Socket
from zmq.asyncio import Context, Socket
from control_backend.core.config import settings
from control_backend.core.zmq_context import context
from control_backend.schemas.ri_message import SpeechCommand
logger = logging.getLogger(__name__)
@@ -21,6 +20,8 @@ async def receive_command(command: SpeechCommand, request: Request):
# Validate and retrieve data.
SpeechCommand.model_validate(command)
topic = b"command"
# TODO: Check with Kasper
pub_socket: Socket = request.app.state.internal_comm_socket
pub_socket.send_multipart([topic, command.model_dump_json().encode()])
@@ -40,8 +41,9 @@ async def ping_stream(request: Request):
# Set up internal socket to receive ping updates
logger.debug("Ping stream router event stream entered.")
sub_socket = context.socket(zmq.SUB)
sub_socket.connect(settings.zmq_settings.internal_comm_address)
# TODO: Check with Kasper
sub_socket = Context.instance().socket(zmq.SUB)
sub_socket.connect(settings.zmq_settings.internal_sub_address)
sub_socket.setsockopt(zmq.SUBSCRIBE, b"ping")
connected = False

View File

@@ -3,19 +3,29 @@ from pydantic_settings import BaseSettings, SettingsConfigDict
class ZMQSettings(BaseModel):
internal_comm_address: str = "tcp://localhost:5560"
internal_pub_address: str = "tcp://localhost:5560"
internal_sub_address: str = "tcp://localhost:5561"
class AgentSettings(BaseModel):
host: str = "localhost"
bdi_core_agent_name: str = "bdi_core"
belief_collector_agent_name: str = "belief_collector"
text_belief_extractor_agent_name: str = "text_belief_extractor"
vad_agent_name: str = "vad_agent"
llm_agent_name: str = "llm_agent"
test_agent_name: str = "test_agent"
transcription_agent_name: str = "transcription_agent"
ri_communication_agent_name: str = "ri_communication_agent"
ri_command_agent_name: str = "ri_command_agent"
class LLMSettings(BaseModel):
local_llm_url: str = "http://localhost:1234/v1/chat/completions"
local_llm_model: str = "openai/gpt-oss-20b"
class Settings(BaseSettings):
app_title: str = "PepperPlus"
@@ -25,6 +35,8 @@ class Settings(BaseSettings):
agent_settings: AgentSettings = AgentSettings()
llm_settings: LLMSettings = LLMSettings()
model_config = SettingsConfigDict(env_file=".env")

View File

@@ -1,3 +0,0 @@
from zmq.asyncio import Context
context = Context()

View File

@@ -3,33 +3,61 @@
# External imports
import contextlib
import logging
import threading
import zmq
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from zmq.asyncio import Context
from control_backend.agents.bdi.bdi_core import BDICoreAgent
from control_backend.agents.bdi.text_extractor import TBeliefExtractor
from control_backend.agents.belief_collector.belief_collector import BeliefCollectorAgent
from control_backend.agents.llm.llm import LLMAgent
# Internal imports
from control_backend.agents.ri_communication_agent import RICommunicationAgent
from control_backend.agents.vad_agent import VADAgent
from control_backend.api.v1.router import api_router
from control_backend.core.config import settings
from control_backend.core.zmq_context import context
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.DEBUG)
def setup_sockets():
context = Context.instance()
internal_pub_socket = context.socket(zmq.XPUB)
internal_pub_socket.bind(settings.zmq_settings.internal_sub_address)
logger.debug("Internal publishing socket bound to %s", internal_pub_socket)
internal_sub_socket = context.socket(zmq.XSUB)
internal_sub_socket.bind(settings.zmq_settings.internal_pub_address)
logger.debug("Internal subscribing socket bound to %s", internal_sub_socket)
try:
zmq.proxy(internal_sub_socket, internal_pub_socket)
except zmq.ZMQError:
logger.warning("Error while handling PUB/SUB proxy. Closing sockets.")
finally:
internal_pub_socket.close()
internal_sub_socket.close()
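# A minimal sketch of how the proxy above is used elsewhere in this commit:
# publishers connect a PUB socket to internal_pub_address (the XSUB side) and
# subscribers connect a SUB socket to internal_sub_address (the XPUB side), e.g.
#   pub = Context.instance().socket(zmq.PUB)
#   pub.connect(settings.zmq_settings.internal_pub_address)
#   sub = Context.instance().socket(zmq.SUB)
#   sub.connect(settings.zmq_settings.internal_sub_address)
#   sub.setsockopt(zmq.SUBSCRIBE, b"ping")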
@contextlib.asynccontextmanager
async def lifespan(app: FastAPI):
logger.info("%s starting up.", app.title)
# Initiate sockets
internal_comm_socket = context.socket(zmq.PUB)
internal_comm_address = settings.zmq_settings.internal_comm_address
internal_comm_socket.bind(internal_comm_address)
app.state.internal_comm_socket = internal_comm_socket
logger.info("Internal publishing socket bound to %s", internal_comm_socket)
proxy_thread = threading.Thread(target=setup_sockets)
proxy_thread.daemon = True
proxy_thread.start()
context = Context.instance()
endpoints_pub_socket = context.socket(zmq.PUB)
endpoints_pub_socket.connect(settings.zmq_settings.internal_pub_address)
app.state.endpoints_pub_socket = endpoints_pub_socket
# Initiate agents
ri_communication_agent = RICommunicationAgent(
@@ -37,12 +65,17 @@ async def lifespan(app: FastAPI):
+ "@"
+ settings.agent_settings.host,
password=settings.agent_settings.ri_communication_agent_name,
pub_socket=internal_comm_socket,
address="tcp://*:5555",
bind=True,
)
await ri_communication_agent.start()
llm_agent = LLMAgent(
settings.agent_settings.llm_agent_name + "@" + settings.agent_settings.host,
settings.agent_settings.llm_agent_name,
)
await llm_agent.start()
bdi_core = BDICoreAgent(
settings.agent_settings.bdi_core_agent_name + "@" + settings.agent_settings.host,
settings.agent_settings.bdi_core_agent_name,
@@ -50,6 +83,23 @@ async def lifespan(app: FastAPI):
)
await bdi_core.start()
belief_collector = BeliefCollectorAgent(
settings.agent_settings.belief_collector_agent_name + "@" + settings.agent_settings.host,
settings.agent_settings.belief_collector_agent_name,
)
await belief_collector.start()
text_belief_extractor = TBeliefExtractor(
settings.agent_settings.text_belief_extractor_agent_name
+ "@"
+ settings.agent_settings.host,
settings.agent_settings.text_belief_extractor_agent_name,
)
await text_belief_extractor.start()
_temp_vad_agent = VADAgent("tcp://localhost:5558", False)
await _temp_vad_agent.start()
yield
logger.info("%s shutting down.", app.title)