docs: add docs to CB
Pretty much every class and method should have documentation now. ref: N25B-295
This commit is contained in:
@@ -10,6 +10,17 @@ from control_backend.schemas.ri_message import SpeechCommand
|
||||
|
||||
|
||||
class RobotSpeechAgent(BaseAgent):
|
||||
"""
|
||||
This agent acts as a bridge between the control backend and the Robot Interface (RI).
|
||||
It receives speech commands from other agents or from the UI,
|
||||
and forwards them to the robot via a ZMQ PUB socket.
|
||||
|
||||
:ivar subsocket: ZMQ SUB socket for receiving external commands (e.g., from UI).
|
||||
:ivar pubsocket: ZMQ PUB socket for sending commands to the Robot Interface.
|
||||
:ivar address: Address to bind/connect the PUB socket.
|
||||
:ivar bind: Whether to bind or connect the PUB socket.
|
||||
"""
|
||||
|
||||
subsocket: zmq.Socket
|
||||
pubsocket: zmq.Socket
|
||||
address = ""
|
||||
@@ -27,7 +38,11 @@ class RobotSpeechAgent(BaseAgent):
|
||||
|
||||
async def setup(self):
|
||||
"""
|
||||
Setup the robot speech command agent
|
||||
Initialize the agent.
|
||||
|
||||
1. Sets up the PUB socket to talk to the robot.
|
||||
2. Sets up the SUB socket to listen for "command" topics (from UI/External).
|
||||
3. Starts the loop for handling ZMQ commands.
|
||||
"""
|
||||
self.logger.info("Setting up %s", self.name)
|
||||
|
||||
@@ -58,7 +73,11 @@ class RobotSpeechAgent(BaseAgent):
|
||||
|
||||
async def handle_message(self, msg: InternalMessage):
|
||||
"""
|
||||
Handle commands received from other Python agents.
|
||||
Handle commands received from other internal Python agents.
|
||||
|
||||
Validates the message as a :class:`SpeechCommand` and forwards it to the robot.
|
||||
|
||||
:param msg: The internal message containing the command.
|
||||
"""
|
||||
try:
|
||||
speech_command = SpeechCommand.model_validate_json(msg.body)
|
||||
@@ -68,7 +87,9 @@ class RobotSpeechAgent(BaseAgent):
|
||||
|
||||
async def _zmq_command_loop(self):
|
||||
"""
|
||||
Handle commands from the UI.
|
||||
Loop to handle commands received via ZMQ (e.g., from the UI).
|
||||
|
||||
Listens on the 'command' topic, validates the JSON, and forwards it to the robot.
|
||||
"""
|
||||
while self._running:
|
||||
try:
|
||||
|
||||
@@ -5,8 +5,13 @@ from control_backend.core.agent_system import BaseAgent as CoreBaseAgent
|
||||
|
||||
class BaseAgent(CoreBaseAgent):
|
||||
"""
|
||||
Base agent class for our agents to inherit from. This just ensures
|
||||
all agents have a logger.
|
||||
The primary base class for all implementation agents.
|
||||
|
||||
Inherits from :class:`control_backend.core.agent_system.BaseAgent`.
|
||||
This class ensures that every agent instance is automatically equipped with a
|
||||
properly configured ``logger``.
|
||||
|
||||
:ivar logger: A logger instance named after the agent's package and class.
|
||||
"""
|
||||
|
||||
logger: logging.Logger
|
||||
|
||||
@@ -19,6 +19,27 @@ DELIMITER = ";\n" # TODO: temporary until we support lists in AgentSpeak
|
||||
|
||||
|
||||
class BDICoreAgent(BaseAgent):
|
||||
"""
|
||||
BDI Core Agent.
|
||||
|
||||
This is the central reasoning agent of the system, powered by the **AgentSpeak(L)** language.
|
||||
It maintains a belief base (representing the state of the world) and a set of plans (rules).
|
||||
|
||||
It runs an internal BDI (Belief-Desire-Intention) cycle using the ``agentspeak`` library.
|
||||
When beliefs change (e.g., via :meth:`_apply_beliefs`), the agent evaluates its plans to
|
||||
determine the best course of action.
|
||||
|
||||
**Custom Actions:**
|
||||
It defines custom actions (like ``.reply``) that allow the AgentSpeak code to interact with
|
||||
external Python agents (e.g., querying the LLM).
|
||||
|
||||
:ivar bdi_agent: The internal AgentSpeak agent instance.
|
||||
:ivar asl_file: Path to the AgentSpeak source file (.asl).
|
||||
:ivar env: The AgentSpeak environment.
|
||||
:ivar actions: A registry of custom actions available to the AgentSpeak code.
|
||||
:ivar _wake_bdi_loop: Event used to wake up the reasoning loop when new beliefs arrive.
|
||||
"""
|
||||
|
||||
bdi_agent: agentspeak.runtime.Agent
|
||||
|
||||
def __init__(self, name: str, asl: str):
|
||||
@@ -30,6 +51,13 @@ class BDICoreAgent(BaseAgent):
|
||||
self._wake_bdi_loop = asyncio.Event()
|
||||
|
||||
async def setup(self) -> None:
|
||||
"""
|
||||
Initialize the BDI agent.
|
||||
|
||||
1. Registers custom actions (like ``.reply``).
|
||||
2. Loads the .asl source file.
|
||||
3. Starts the reasoning loop (:meth:`_bdi_loop`) in the background.
|
||||
"""
|
||||
self.logger.debug("Setup started.")
|
||||
|
||||
self._add_custom_actions()
|
||||
@@ -42,6 +70,9 @@ class BDICoreAgent(BaseAgent):
|
||||
self.logger.debug("Setup complete.")
|
||||
|
||||
async def _load_asl(self):
|
||||
"""
|
||||
Load and parse the AgentSpeak source file.
|
||||
"""
|
||||
try:
|
||||
with open(self.asl_file) as source:
|
||||
self.bdi_agent = self.env.build_agent(source, self.actions)
|
||||
@@ -51,7 +82,11 @@ class BDICoreAgent(BaseAgent):
|
||||
|
||||
async def _bdi_loop(self):
|
||||
"""
|
||||
Runs the AgentSpeak BDI loop. Efficiently checks for when the next expected work will be.
|
||||
The main BDI reasoning loop.
|
||||
|
||||
It waits for the ``_wake_bdi_loop`` event (set when beliefs change or actions complete).
|
||||
When awake, it steps through the AgentSpeak interpreter. It also handles sleeping if
|
||||
the agent has deferred intentions (deadlines).
|
||||
"""
|
||||
while self._running:
|
||||
await (
|
||||
@@ -78,7 +113,12 @@ class BDICoreAgent(BaseAgent):
|
||||
|
||||
async def handle_message(self, msg: InternalMessage):
|
||||
"""
|
||||
Route incoming messages (Beliefs or LLM responses).
|
||||
Handle incoming messages.
|
||||
|
||||
- **Beliefs**: Updates the internal belief base.
|
||||
- **LLM Responses**: Forwards the generated text to the Robot Speech Agent (actuation).
|
||||
|
||||
:param msg: The received internal message.
|
||||
"""
|
||||
self.logger.debug("Processing message from %s.", msg.sender)
|
||||
|
||||
@@ -106,6 +146,12 @@ class BDICoreAgent(BaseAgent):
|
||||
await self.send(out_msg)
|
||||
|
||||
def _apply_beliefs(self, beliefs: list[Belief]):
|
||||
"""
|
||||
Update the belief base with a list of new beliefs.
|
||||
|
||||
If ``replace=True`` is set on a belief, it removes all existing beliefs with that name
|
||||
before adding the new one.
|
||||
"""
|
||||
if not beliefs:
|
||||
return
|
||||
|
||||
@@ -115,6 +161,12 @@ class BDICoreAgent(BaseAgent):
|
||||
self._add_belief(belief.name, belief.arguments)
|
||||
|
||||
def _add_belief(self, name: str, args: Iterable[str] = []):
|
||||
"""
|
||||
Add a single belief to the BDI agent.
|
||||
|
||||
:param name: The functor/name of the belief (e.g., "user_said").
|
||||
:param args: Arguments for the belief.
|
||||
"""
|
||||
# new_args = (agentspeak.Literal(arg) for arg in args) # TODO: Eventually support multiple
|
||||
merged_args = DELIMITER.join(arg for arg in args)
|
||||
new_args = (agentspeak.Literal(merged_args),)
|
||||
|
||||
@@ -11,8 +11,14 @@ from control_backend.schemas.program import Program
|
||||
|
||||
class BDIProgramManager(BaseAgent):
|
||||
"""
|
||||
Will interpret programs received from the HTTP endpoint. Extracts norms, goals, triggers and
|
||||
forwards them to the BDI as beliefs.
|
||||
BDI Program Manager Agent.
|
||||
|
||||
This agent is responsible for receiving high-level programs (sequences of instructions/goals)
|
||||
from the external HTTP API (via ZMQ) and translating them into core beliefs (norms and goals)
|
||||
for the BDI Core Agent. In the future, it will be responsible for determining when goals are
|
||||
met, and passing on new norms and goals accordingly.
|
||||
|
||||
:ivar sub_socket: The ZMQ SUB socket used to receive program updates.
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
@@ -20,6 +26,18 @@ class BDIProgramManager(BaseAgent):
|
||||
self.sub_socket = None
|
||||
|
||||
async def _send_to_bdi(self, program: Program):
|
||||
"""
|
||||
Convert a received program into BDI beliefs and send them to the BDI Core Agent.
|
||||
|
||||
Currently, it takes the **first phase** of the program and extracts:
|
||||
- **Norms**: Constraints or rules the agent must follow.
|
||||
- **Goals**: Objectives the agent must achieve.
|
||||
|
||||
These are sent as a ``BeliefMessage`` with ``replace=True``, meaning they will
|
||||
overwrite any existing norms/goals of the same name in the BDI agent.
|
||||
|
||||
:param program: The program object received from the API.
|
||||
"""
|
||||
first_phase = program.phases[0]
|
||||
norms_belief = Belief(
|
||||
name="norms",
|
||||
@@ -44,7 +62,10 @@ class BDIProgramManager(BaseAgent):
|
||||
|
||||
async def _receive_programs(self):
|
||||
"""
|
||||
Continuously receive programs from the HTTP endpoint, sent to us over ZMQ.
|
||||
Continuous loop that receives program updates from the HTTP endpoint.
|
||||
|
||||
It listens to the ``program`` topic on the internal ZMQ SUB socket.
|
||||
When a program is received, it is validated and forwarded to BDI via :meth:`_send_to_bdi`.
|
||||
"""
|
||||
while True:
|
||||
topic, body = await self.sub_socket.recv_multipart()
|
||||
@@ -58,6 +79,12 @@ class BDIProgramManager(BaseAgent):
|
||||
await self._send_to_bdi(program)
|
||||
|
||||
async def setup(self):
|
||||
"""
|
||||
Initialize the agent.
|
||||
|
||||
Connects the internal ZMQ SUB socket and subscribes to the 'program' topic.
|
||||
Starts the background behavior to receive programs.
|
||||
"""
|
||||
context = Context.instance()
|
||||
|
||||
self.sub_socket = context.socket(zmq.SUB)
|
||||
|
||||
@@ -10,14 +10,33 @@ from control_backend.schemas.belief_message import Belief, BeliefMessage
|
||||
|
||||
class BDIBeliefCollectorAgent(BaseAgent):
|
||||
"""
|
||||
Continuously collects beliefs/emotions from extractor agents and forwards a
|
||||
unified belief packet to the BDI agent.
|
||||
BDI Belief Collector Agent.
|
||||
|
||||
This agent acts as a central aggregator for beliefs derived from various sources (e.g., text,
|
||||
emotion, vision). It receives raw extracted data from other agents,
|
||||
normalizes them into valid :class:`Belief` objects, and forwards them as a unified packet to the
|
||||
BDI Core Agent.
|
||||
|
||||
It serves as a funnel to ensure the BDI agent receives a consistent stream of beliefs.
|
||||
"""
|
||||
|
||||
async def setup(self):
|
||||
"""
|
||||
Initialize the agent.
|
||||
"""
|
||||
self.logger.info("Setting up %s", self.name)
|
||||
|
||||
async def handle_message(self, msg: InternalMessage):
|
||||
"""
|
||||
Handle incoming messages from other extractor agents.
|
||||
|
||||
Routes the message to specific handlers based on the 'type' field in the JSON body.
|
||||
Supported types:
|
||||
- ``belief_extraction_text``: Handled by :meth:`_handle_belief_text`
|
||||
- ``emotion_extraction_text``: Handled by :meth:`_handle_emo_text`
|
||||
|
||||
:param msg: The received internal message.
|
||||
"""
|
||||
sender_node = msg.sender
|
||||
|
||||
# Parse JSON payload
|
||||
@@ -49,12 +68,22 @@ class BDIBeliefCollectorAgent(BaseAgent):
|
||||
|
||||
async def _handle_belief_text(self, payload: dict, origin: str):
|
||||
"""
|
||||
Expected payload:
|
||||
{
|
||||
"type": "belief_extraction_text",
|
||||
"beliefs": {"user_said": ["Can you help me?"]}
|
||||
Process text-based belief extraction payloads.
|
||||
|
||||
}
|
||||
Expected payload format::
|
||||
|
||||
{
|
||||
"type": "belief_extraction_text",
|
||||
"beliefs": {
|
||||
"user_said": ["Can you help me?"],
|
||||
"intention": ["ask_help"]
|
||||
}
|
||||
}
|
||||
|
||||
Validates and converts the dictionary items into :class:`Belief` objects.
|
||||
|
||||
:param payload: The dictionary payload containing belief data.
|
||||
:param origin: The name of the sender agent.
|
||||
"""
|
||||
beliefs = payload.get("beliefs", {})
|
||||
|
||||
@@ -90,12 +119,24 @@ class BDIBeliefCollectorAgent(BaseAgent):
|
||||
await self._send_beliefs_to_bdi(beliefs, origin=origin)
|
||||
|
||||
async def _handle_emo_text(self, payload: dict, origin: str):
|
||||
"""TODO: implement (after we have emotional recognition)"""
|
||||
"""
|
||||
Process emotion extraction payloads.
|
||||
|
||||
**TODO**: Implement this method once emotion recognition is integrated.
|
||||
|
||||
:param payload: The dictionary payload containing emotion data.
|
||||
:param origin: The name of the sender agent.
|
||||
"""
|
||||
pass
|
||||
|
||||
async def _send_beliefs_to_bdi(self, beliefs: list[Belief], origin: str | None = None):
|
||||
"""
|
||||
Sends a unified belief packet to the BDI agent.
|
||||
Send a list of aggregated beliefs to the BDI Core Agent.
|
||||
|
||||
Wraps the beliefs in a :class:`BeliefMessage` and sends it via the 'beliefs' thread.
|
||||
|
||||
:param beliefs: The list of Belief objects to send.
|
||||
:param origin: (Optional) The original source of the beliefs (unused currently).
|
||||
"""
|
||||
if not beliefs:
|
||||
return
|
||||
|
||||
@@ -6,12 +6,31 @@ from control_backend.core.config import settings
|
||||
|
||||
|
||||
class TextBeliefExtractorAgent(BaseAgent):
|
||||
"""
|
||||
Text Belief Extractor Agent.
|
||||
|
||||
This agent is responsible for processing raw text (e.g., from speech transcription) and
|
||||
extracting semantic beliefs from it.
|
||||
|
||||
In the current demonstration version, it performs a simple wrapping of the user's input
|
||||
into a ``user_said`` belief. In a full implementation, this agent would likely interact
|
||||
with an LLM or NLU engine to extract intent, entities, and other structured information.
|
||||
"""
|
||||
|
||||
async def setup(self):
|
||||
"""
|
||||
Initialize the agent and its resources.
|
||||
"""
|
||||
self.logger.info("Settting up %s.", self.name)
|
||||
# Setup LLM belief context if needed (currently demo is just passthrough)
|
||||
self.beliefs = {"mood": ["X"], "car": ["Y"]}
|
||||
|
||||
async def handle_message(self, msg: InternalMessage):
|
||||
"""
|
||||
Handle incoming messages, primarily from the Transcription Agent.
|
||||
|
||||
:param msg: The received message containing transcribed text.
|
||||
"""
|
||||
sender = msg.sender
|
||||
if sender == settings.agent_settings.transcription_name:
|
||||
self.logger.debug("Received text from transcriber: %s", msg.body)
|
||||
@@ -21,7 +40,15 @@ class TextBeliefExtractorAgent(BaseAgent):
|
||||
|
||||
async def _process_transcription_demo(self, txt: str):
|
||||
"""
|
||||
Demo version to process the transcription input to beliefs.
|
||||
Process the transcribed text and generate beliefs.
|
||||
|
||||
**Demo Implementation:**
|
||||
Currently, this method takes the raw text ``txt`` and wraps it into a belief structure:
|
||||
``user_said("txt")``.
|
||||
|
||||
This belief is then sent to the :class:`BDIBeliefCollectorAgent`.
|
||||
|
||||
:param txt: The raw transcribed text string.
|
||||
"""
|
||||
# For demo, just wrapping user text as user_said belief
|
||||
belief = {"beliefs": {"user_said": [txt]}, "type": "belief_extraction_text"}
|
||||
|
||||
@@ -12,6 +12,27 @@ from ..actuation.robot_speech_agent import RobotSpeechAgent
|
||||
|
||||
|
||||
class RICommunicationAgent(BaseAgent):
|
||||
"""
|
||||
Robot Interface (RI) Communication Agent.
|
||||
|
||||
This agent manages the high-level connection negotiation and health checking (heartbeat)
|
||||
between the Control Backend and the Robot Interface (or UI).
|
||||
|
||||
It acts as a service discovery mechanism:
|
||||
1. It initiates a handshake (negotiation) to discover where other services (like the robot
|
||||
command listener) are listening.
|
||||
2. It spawns specific agents
|
||||
(like :class:`~control_backend.agents.actuation.robot_speech_agent.RobotSpeechAgent`)
|
||||
once the connection details are established.
|
||||
3. It maintains a "ping" loop to ensure the connection remains active.
|
||||
|
||||
:ivar _address: The ZMQ address to attempt the initial connection negotiation.
|
||||
:ivar _bind: Whether to bind or connect the negotiation socket.
|
||||
:ivar _req_socket: ZMQ REQ socket for negotiation and pings.
|
||||
:ivar pub_socket: ZMQ PUB socket for internal notifications (e.g., ping status).
|
||||
:ivar connected: Boolean flag indicating active connection status.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
@@ -27,8 +48,10 @@ class RICommunicationAgent(BaseAgent):
|
||||
|
||||
async def setup(self):
|
||||
"""
|
||||
Try to set up the communication agent, we have `behaviour_settings.comm_setup_max_retries`
|
||||
retries in case we don't have a response yet.
|
||||
Initialize the agent and attempt connection.
|
||||
|
||||
Tries to negotiate connection up to ``behaviour_settings.comm_setup_max_retries`` times.
|
||||
If successful, starts the :meth:`_listen_loop`.
|
||||
"""
|
||||
self.logger.info("Setting up %s", self.name)
|
||||
|
||||
@@ -45,7 +68,7 @@ class RICommunicationAgent(BaseAgent):
|
||||
|
||||
async def _setup_sockets(self, force=False):
|
||||
"""
|
||||
Sets up request socket for communication agent.
|
||||
Initialize ZMQ sockets (REQ for negotiation, PUB for internal updates).
|
||||
"""
|
||||
# Bind request socket
|
||||
if self._req_socket is None or force:
|
||||
@@ -62,6 +85,15 @@ class RICommunicationAgent(BaseAgent):
|
||||
async def _negotiate_connection(
|
||||
self, max_retries: int = settings.behaviour_settings.comm_setup_max_retries
|
||||
):
|
||||
"""
|
||||
Perform the handshake protocol with the Robot Interface.
|
||||
|
||||
Sends a ``negotiate/ports`` request and expects a configuration response containing
|
||||
port assignments for various services (e.g., actuation).
|
||||
|
||||
:param max_retries: Number of attempts before giving up.
|
||||
:return: True if negotiation succeeded, False otherwise.
|
||||
"""
|
||||
retries = 0
|
||||
while retries < max_retries:
|
||||
if self._req_socket is None:
|
||||
@@ -122,6 +154,12 @@ class RICommunicationAgent(BaseAgent):
|
||||
return False
|
||||
|
||||
async def _handle_negotiation_response(self, received_message):
|
||||
"""
|
||||
Parse the negotiation response and initialize services.
|
||||
|
||||
Based on the response, it might re-connect the main socket or spawn new agents
|
||||
(e.g., for robot actuation).
|
||||
"""
|
||||
for port_data in received_message["data"]:
|
||||
id = port_data["id"]
|
||||
port = port_data["port"]
|
||||
@@ -159,7 +197,10 @@ class RICommunicationAgent(BaseAgent):
|
||||
|
||||
async def _listen_loop(self):
|
||||
"""
|
||||
Run the listening (ping) loop indefinitely.
|
||||
Maintain the connection via a heartbeat (ping) loop.
|
||||
|
||||
Sends a ``ping`` request periodically and waits for a reply.
|
||||
If pings fail repeatedly, it triggers a disconnection handler to restart negotiation.
|
||||
"""
|
||||
while self._running:
|
||||
if not self.connected:
|
||||
@@ -217,6 +258,11 @@ class RICommunicationAgent(BaseAgent):
|
||||
raise
|
||||
|
||||
async def _handle_disconnection(self):
|
||||
"""
|
||||
Handle connection loss.
|
||||
|
||||
Notifies the UI of disconnection (via internal PUB) and attempts to restart negotiation.
|
||||
"""
|
||||
self.connected = False
|
||||
|
||||
# Tell UI we're disconnected.
|
||||
|
||||
@@ -16,9 +16,17 @@ from .llm_instructions import LLMInstructions
|
||||
|
||||
class LLMAgent(BaseAgent):
|
||||
"""
|
||||
Agent responsible for processing user text input and querying a locally
|
||||
hosted LLM for text generation. Receives messages from the BDI Core Agent
|
||||
and responds with processed LLM output.
|
||||
LLM Agent.
|
||||
|
||||
This agent is responsible for processing user text input and querying a locally
|
||||
hosted LLM for text generation. It acts as the conversational brain of the system.
|
||||
|
||||
It receives :class:`~control_backend.schemas.llm_prompt_message.LLMPromptMessage`
|
||||
payloads from the BDI Core Agent, constructs a conversation history, queries the
|
||||
LLM via HTTP, and streams the response back to the BDI agent in natural chunks
|
||||
(e.g., sentence by sentence).
|
||||
|
||||
:ivar history: A list of dictionaries representing the conversation history (Role/Content).
|
||||
"""
|
||||
|
||||
def __init__(self, name: str):
|
||||
@@ -29,6 +37,14 @@ class LLMAgent(BaseAgent):
|
||||
self.logger.info("Setting up %s.", self.name)
|
||||
|
||||
async def handle_message(self, msg: InternalMessage):
|
||||
"""
|
||||
Handle incoming messages.
|
||||
|
||||
Expects messages from :attr:`settings.agent_settings.bdi_core_name` containing
|
||||
an :class:`LLMPromptMessage` in the body.
|
||||
|
||||
:param msg: The received internal message.
|
||||
"""
|
||||
if msg.sender == settings.agent_settings.bdi_core_name:
|
||||
self.logger.debug("Processing message from BDI core.")
|
||||
try:
|
||||
@@ -40,6 +56,14 @@ class LLMAgent(BaseAgent):
|
||||
self.logger.debug("Message ignored (not from BDI core.")
|
||||
|
||||
async def _process_bdi_message(self, message: LLMPromptMessage):
|
||||
"""
|
||||
Orchestrate the LLM query and response streaming.
|
||||
|
||||
Iterates over the chunks yielded by :meth:`_query_llm` and forwards them
|
||||
individually to the BDI agent via :meth:`_send_reply`.
|
||||
|
||||
:param message: The parsed prompt message containing text, norms, and goals.
|
||||
"""
|
||||
async for chunk in self._query_llm(message.text, message.norms, message.goals):
|
||||
await self._send_reply(chunk)
|
||||
self.logger.debug(
|
||||
@@ -48,7 +72,9 @@ class LLMAgent(BaseAgent):
|
||||
|
||||
async def _send_reply(self, msg: str):
|
||||
"""
|
||||
Sends a response message back to the BDI Core Agent.
|
||||
Sends a response message (chunk) back to the BDI Core Agent.
|
||||
|
||||
:param msg: The text content of the chunk.
|
||||
"""
|
||||
reply = InternalMessage(
|
||||
to=settings.agent_settings.bdi_core_name,
|
||||
@@ -61,13 +87,18 @@ class LLMAgent(BaseAgent):
|
||||
self, prompt: str, norms: list[str], goals: list[str]
|
||||
) -> AsyncGenerator[str]:
|
||||
"""
|
||||
Sends a chat completion request to the local LLM service and streams the response by
|
||||
yielding fragments separated by punctuation like.
|
||||
Send a chat completion request to the local LLM service and stream the response.
|
||||
|
||||
It constructs the full prompt using
|
||||
:class:`~control_backend.agents.llm.llm_instructions.LLMInstructions`.
|
||||
It streams the response from the LLM and buffers tokens until a natural break (punctuation)
|
||||
is reached, then yields the chunk. This ensures that the robot speaks in complete phrases
|
||||
rather than individual tokens.
|
||||
|
||||
:param prompt: Input text prompt to pass to the LLM.
|
||||
:param norms: Norms the LLM should hold itself to.
|
||||
:param goals: Goals the LLM should achieve.
|
||||
:yield: Fragments of the LLM-generated content.
|
||||
:yield: Fragments of the LLM-generated content (e.g., sentences/phrases).
|
||||
"""
|
||||
self.history.append(
|
||||
{
|
||||
@@ -85,7 +116,7 @@ class LLMAgent(BaseAgent):
|
||||
*self.history,
|
||||
]
|
||||
|
||||
message_id = str(uuid.uuid4())
|
||||
message_id = str(uuid.uuid4()) # noqa
|
||||
|
||||
try:
|
||||
full_message = ""
|
||||
@@ -127,7 +158,13 @@ class LLMAgent(BaseAgent):
|
||||
yield "Error processing the request."
|
||||
|
||||
async def _stream_query_llm(self, messages) -> AsyncGenerator[str]:
|
||||
"""Raises httpx.HTTPError when the API gives an error."""
|
||||
"""
|
||||
Perform the raw HTTP streaming request to the LLM API.
|
||||
|
||||
:param messages: The list of message dictionaries (role/content).
|
||||
:yield: Raw text tokens (deltas) from the SSE stream.
|
||||
:raises httpx.HTTPError: If the API returns a non-200 status.
|
||||
"""
|
||||
async with httpx.AsyncClient(timeout=None) as client:
|
||||
async with client.stream(
|
||||
"POST",
|
||||
|
||||
@@ -1,7 +1,12 @@
|
||||
class LLMInstructions:
|
||||
"""
|
||||
Defines structured instructions that are sent along with each request
|
||||
to the LLM to guide its behavior (norms, goals, etc.).
|
||||
Helper class to construct the system instructions for the LLM.
|
||||
|
||||
It combines the base persona (Pepper robot) with dynamic norms and goals
|
||||
provided by the BDI system.
|
||||
|
||||
:ivar norms: A list of behavioral norms.
|
||||
:ivar goals: A list of specific conversational goals.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
@@ -17,14 +22,22 @@ class LLMInstructions:
|
||||
"Try to learn the user's name during conversation.",
|
||||
]
|
||||
|
||||
def __init__(self, norms: list[str] = None, goals: list[str] = None):
|
||||
def __init__(self, norms: list[str] | None = None, goals: list[str] | None = None):
|
||||
self.norms = norms or self.default_norms()
|
||||
self.goals = goals or self.default_goals()
|
||||
|
||||
def build_developer_instruction(self) -> str:
|
||||
"""
|
||||
Builds a multi-line formatted instruction string for the LLM.
|
||||
Includes only non-empty structured fields.
|
||||
Builds the final system prompt string.
|
||||
|
||||
The prompt includes:
|
||||
1. Persona definition.
|
||||
2. Constraint on response length.
|
||||
3. Instructions on how to handle goals (reach them in order, but prioritize natural flow).
|
||||
4. The specific list of norms.
|
||||
5. The specific list of goals.
|
||||
|
||||
:return: The formatted system prompt string.
|
||||
"""
|
||||
sections = [
|
||||
"You are a Pepper robot engaging in natural human conversation.",
|
||||
|
||||
@@ -14,15 +14,28 @@ from control_backend.core.config import settings
|
||||
|
||||
|
||||
class SpeechRecognizer(abc.ABC):
|
||||
"""
|
||||
Abstract base class for speech recognition backends.
|
||||
|
||||
Provides a common interface for loading models and transcribing audio,
|
||||
as well as heuristics for estimating token counts to optimize decoding.
|
||||
|
||||
:ivar limit_output_length: If True, limits the generated text length based on audio duration.
|
||||
"""
|
||||
|
||||
def __init__(self, limit_output_length=True):
|
||||
"""
|
||||
:param limit_output_length: When `True`, the length of the generated speech will be limited
|
||||
by the length of the input audio and some heuristics.
|
||||
:param limit_output_length: When ``True``, the length of the generated speech will be
|
||||
limited by the length of the input audio and some heuristics.
|
||||
"""
|
||||
self.limit_output_length = limit_output_length
|
||||
|
||||
@abc.abstractmethod
|
||||
def load_model(self): ...
|
||||
def load_model(self):
|
||||
"""
|
||||
Load the speech recognition model into memory.
|
||||
"""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
def recognize_speech(self, audio: np.ndarray) -> str:
|
||||
@@ -30,15 +43,17 @@ class SpeechRecognizer(abc.ABC):
|
||||
Recognize speech from the given audio sample.
|
||||
|
||||
:param audio: A full utterance sample. Audio must be 16 kHz, mono, np.float32, values in the
|
||||
range [-1.0, 1.0].
|
||||
:return: Recognized speech.
|
||||
range [-1.0, 1.0].
|
||||
:return: The recognized speech text.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def _estimate_max_tokens(audio: np.ndarray) -> int:
|
||||
"""
|
||||
Estimate the maximum length of a given audio sample in tokens. Assumes a maximum speaking
|
||||
rate of 450 words per minute (3x average), and assumes that 3 words is 4 tokens.
|
||||
Estimate the maximum length of a given audio sample in tokens.
|
||||
|
||||
Assumes a maximum speaking rate of 450 words per minute (3x average), and assumes that
|
||||
3 words is approx. 4 tokens.
|
||||
|
||||
:param audio: The audio sample (16 kHz) to use for length estimation.
|
||||
:return: The estimated length of the transcribed audio in tokens.
|
||||
@@ -51,8 +66,10 @@ class SpeechRecognizer(abc.ABC):
|
||||
|
||||
def _get_decode_options(self, audio: np.ndarray) -> dict:
|
||||
"""
|
||||
Construct decoding options for the Whisper model.
|
||||
|
||||
:param audio: The audio sample (16 kHz) to use to determine options like max decode length.
|
||||
:return: A dict that can be used to construct `whisper.DecodingOptions`.
|
||||
:return: A dict that can be used to construct ``whisper.DecodingOptions`` (or equivalent).
|
||||
"""
|
||||
options = {}
|
||||
if self.limit_output_length:
|
||||
@@ -61,7 +78,12 @@ class SpeechRecognizer(abc.ABC):
|
||||
|
||||
@staticmethod
|
||||
def best_type():
|
||||
"""Get the best type of SpeechRecognizer based on system capabilities."""
|
||||
"""
|
||||
Factory method to get the best available `SpeechRecognizer`.
|
||||
|
||||
:return: An instance of :class:`MLXWhisperSpeechRecognizer` if on macOS with Apple Silicon,
|
||||
otherwise :class:`OpenAIWhisperSpeechRecognizer`.
|
||||
"""
|
||||
if torch.mps.is_available():
|
||||
print("Choosing MLX Whisper model.")
|
||||
return MLXWhisperSpeechRecognizer()
|
||||
@@ -71,12 +93,20 @@ class SpeechRecognizer(abc.ABC):
|
||||
|
||||
|
||||
class MLXWhisperSpeechRecognizer(SpeechRecognizer):
|
||||
"""
|
||||
Speech recognizer using the MLX framework (optimized for Apple Silicon).
|
||||
"""
|
||||
|
||||
def __init__(self, limit_output_length=True):
|
||||
super().__init__(limit_output_length)
|
||||
self.was_loaded = False
|
||||
self.model_name = settings.speech_model_settings.mlx_model_name
|
||||
|
||||
def load_model(self):
|
||||
"""
|
||||
Ensures the model is downloaded and cached. MLX loads dynamically, so this
|
||||
pre-fetches the model.
|
||||
"""
|
||||
if self.was_loaded:
|
||||
return
|
||||
# There appears to be no dedicated mechanism to preload a model, but this `get_model` does
|
||||
@@ -94,11 +124,18 @@ class MLXWhisperSpeechRecognizer(SpeechRecognizer):
|
||||
|
||||
|
||||
class OpenAIWhisperSpeechRecognizer(SpeechRecognizer):
|
||||
"""
|
||||
Speech recognizer using the standard OpenAI Whisper library (PyTorch).
|
||||
"""
|
||||
|
||||
def __init__(self, limit_output_length=True):
|
||||
super().__init__(limit_output_length)
|
||||
self.model = None
|
||||
|
||||
def load_model(self):
|
||||
"""
|
||||
Loads the OpenAI Whisper model onto the available device (CUDA or CPU).
|
||||
"""
|
||||
if self.model is not None:
|
||||
return
|
||||
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
@@ -13,11 +13,26 @@ from .speech_recognizer import SpeechRecognizer
|
||||
|
||||
class TranscriptionAgent(BaseAgent):
|
||||
"""
|
||||
An agent which listens to audio fragments with voice, transcribes them, and sends the
|
||||
transcription to other agents.
|
||||
Transcription Agent.
|
||||
|
||||
This agent listens to audio fragments (containing speech) on a ZMQ SUB socket,
|
||||
transcribes them using the configured :class:`SpeechRecognizer`, and sends the
|
||||
resulting text to other agents (e.g., the Text Belief Extractor).
|
||||
|
||||
It uses an internal semaphore to limit the number of concurrent transcription tasks.
|
||||
|
||||
:ivar audio_in_address: The ZMQ address to receive audio from (usually from VAD Agent).
|
||||
:ivar audio_in_socket: The ZMQ SUB socket instance.
|
||||
:ivar speech_recognizer: The speech recognition engine instance.
|
||||
:ivar _concurrency: Semaphore to limit concurrent transcriptions.
|
||||
"""
|
||||
|
||||
def __init__(self, audio_in_address: str):
|
||||
"""
|
||||
Initialize the Transcription Agent.
|
||||
|
||||
:param audio_in_address: The ZMQ address of the audio source (e.g., VAD output).
|
||||
"""
|
||||
super().__init__(settings.agent_settings.transcription_name)
|
||||
|
||||
self.audio_in_address = audio_in_address
|
||||
@@ -26,6 +41,13 @@ class TranscriptionAgent(BaseAgent):
|
||||
self._concurrency = None
|
||||
|
||||
async def setup(self):
|
||||
"""
|
||||
Initialize the agent resources.
|
||||
|
||||
1. Connects to the audio input ZMQ socket.
|
||||
2. Initializes the :class:`SpeechRecognizer` (choosing the best available backend).
|
||||
3. Starts the background transcription loop.
|
||||
"""
|
||||
self.logger.info("Setting up %s", self.name)
|
||||
|
||||
self._connect_audio_in_socket()
|
||||
@@ -42,23 +64,45 @@ class TranscriptionAgent(BaseAgent):
|
||||
self.logger.info("Finished setting up %s", self.name)
|
||||
|
||||
async def stop(self):
|
||||
"""
|
||||
Stop the agent and close sockets.
|
||||
"""
|
||||
assert self.audio_in_socket is not None
|
||||
self.audio_in_socket.close()
|
||||
self.audio_in_socket = None
|
||||
return await super().stop()
|
||||
|
||||
def _connect_audio_in_socket(self):
|
||||
"""
|
||||
Helper to connect the ZMQ SUB socket for audio input.
|
||||
"""
|
||||
self.audio_in_socket = azmq.Context.instance().socket(zmq.SUB)
|
||||
self.audio_in_socket.setsockopt_string(zmq.SUBSCRIBE, "")
|
||||
self.audio_in_socket.connect(self.audio_in_address)
|
||||
|
||||
async def _transcribe(self, audio: np.ndarray) -> str:
|
||||
"""
|
||||
Run the speech recognition on the audio data.
|
||||
|
||||
This runs in a separate thread (via `asyncio.to_thread`) to avoid blocking the event loop,
|
||||
constrained by the concurrency semaphore.
|
||||
|
||||
:param audio: The audio data as a numpy array.
|
||||
:return: The transcribed text string.
|
||||
"""
|
||||
assert self._concurrency is not None and self.speech_recognizer is not None
|
||||
async with self._concurrency:
|
||||
return await asyncio.to_thread(self.speech_recognizer.recognize_speech, audio)
|
||||
|
||||
async def _share_transcription(self, transcription: str):
|
||||
"""Share a transcription to the other agents that depend on it."""
|
||||
"""
|
||||
Share a transcription to the other agents that depend on it.
|
||||
|
||||
Currently sends to:
|
||||
- :attr:`settings.agent_settings.text_belief_extractor_name`
|
||||
|
||||
:param transcription: The transcribed text.
|
||||
"""
|
||||
receiver_names = [
|
||||
settings.agent_settings.text_belief_extractor_name,
|
||||
]
|
||||
@@ -72,6 +116,12 @@ class TranscriptionAgent(BaseAgent):
|
||||
await self.send(message)
|
||||
|
||||
async def _transcribing_loop(self) -> None:
|
||||
"""
|
||||
The main loop for receiving audio and triggering transcription.
|
||||
|
||||
Receives audio chunks from ZMQ, decodes them to float32, and calls :meth:`_transcribe`.
|
||||
If speech is found, it calls :meth:`_share_transcription`.
|
||||
"""
|
||||
while self._running:
|
||||
try:
|
||||
assert self.audio_in_socket is not None
|
||||
|
||||
@@ -15,6 +15,8 @@ class SocketPoller[T]:
|
||||
"""
|
||||
Convenience class for polling a socket for data with a timeout, persisting a zmq.Poller for
|
||||
multiple usages.
|
||||
|
||||
:param T: The type of data returned by the socket.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -35,7 +37,7 @@ class SocketPoller[T]:
|
||||
"""
|
||||
Get data from the socket, or None if the timeout is reached.
|
||||
|
||||
:param timeout_ms: If given, the timeout. Otherwise, `self.timeout_ms` is used.
|
||||
:param timeout_ms: If given, the timeout. Otherwise, ``self.timeout_ms`` is used.
|
||||
:return: Data from the socket or None.
|
||||
"""
|
||||
timeout_ms = timeout_ms or self.timeout_ms
|
||||
@@ -47,11 +49,27 @@ class SocketPoller[T]:
|
||||
|
||||
class VADAgent(BaseAgent):
|
||||
"""
|
||||
An agent which listens to an audio stream, does Voice Activity Detection (VAD), and sends
|
||||
fragments with detected speech to other agents over ZeroMQ.
|
||||
Voice Activity Detection (VAD) Agent.
|
||||
|
||||
This agent:
|
||||
1. Receives an audio stream (via ZMQ).
|
||||
2. Processes the audio using the Silero VAD model to detect speech.
|
||||
3. Buffers potential speech segments.
|
||||
4. Publishes valid speech fragments (containing speech plus small buffer) to a ZMQ PUB socket.
|
||||
5. Instantiates and starts agents (like :class:`TranscriptionAgent`) that use this output.
|
||||
|
||||
:ivar audio_in_address: Address of the input audio stream.
|
||||
:ivar audio_in_bind: Whether to bind or connect to the input address.
|
||||
:ivar audio_out_socket: ZMQ PUB socket for sending speech fragments.
|
||||
"""
|
||||
|
||||
def __init__(self, audio_in_address: str, audio_in_bind: bool):
|
||||
"""
|
||||
Initialize the VAD Agent.
|
||||
|
||||
:param audio_in_address: ZMQ address for input audio.
|
||||
:param audio_in_bind: True if this agent should bind to the input address, False to connect.
|
||||
"""
|
||||
super().__init__(settings.agent_settings.vad_name)
|
||||
|
||||
self.audio_in_address = audio_in_address
|
||||
@@ -67,6 +85,15 @@ class VADAgent(BaseAgent):
|
||||
self.model = None
|
||||
|
||||
async def setup(self):
|
||||
"""
|
||||
Initialize resources.
|
||||
|
||||
1. Connects audio input socket.
|
||||
2. Binds audio output socket (random port).
|
||||
3. Loads VAD model from Torch Hub.
|
||||
4. Starts the streaming loop.
|
||||
5. Instantiates and starts the :class:`TranscriptionAgent` with the output address.
|
||||
"""
|
||||
self.logger.info("Setting up %s", self.name)
|
||||
|
||||
self._connect_audio_in_socket()
|
||||
@@ -123,7 +150,9 @@ class VADAgent(BaseAgent):
|
||||
self.audio_in_poller = SocketPoller[bytes](self.audio_in_socket)
|
||||
|
||||
def _connect_audio_out_socket(self) -> int | None:
|
||||
"""Returns the port bound, or None if binding failed."""
|
||||
"""
|
||||
Returns the port bound, or None if binding failed.
|
||||
"""
|
||||
try:
|
||||
self.audio_out_socket = azmq.Context.instance().socket(zmq.PUB)
|
||||
return self.audio_out_socket.bind_to_random_port("tcp://localhost", max_tries=100)
|
||||
@@ -144,6 +173,15 @@ class VADAgent(BaseAgent):
|
||||
self._ready.set()
|
||||
|
||||
async def _streaming_loop(self):
|
||||
"""
|
||||
Main loop for processing audio stream.
|
||||
|
||||
1. Polls for new audio chunks.
|
||||
2. Passes chunk to VAD model.
|
||||
3. Manages `i_since_speech` counter to determine start/end of speech.
|
||||
4. Buffers speech + context.
|
||||
5. Sends complete speech segment to output socket when silence is detected.
|
||||
"""
|
||||
await self._ready.wait()
|
||||
while self._running:
|
||||
assert self.audio_in_poller is not None
|
||||
|
||||
Reference in New Issue
Block a user