Files
pepperplus-cb/src/control_backend/agents/llm/llm_agent.py

200 lines
7.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import json
import re
import uuid
from collections.abc import AsyncGenerator
import httpx
from pydantic import ValidationError
from control_backend.agents import BaseAgent
from control_backend.core.agent_system import InternalMessage
from control_backend.core.config import settings
from ...schemas.llm_prompt_message import LLMPromptMessage
from .llm_instructions import LLMInstructions
class LLMAgent(BaseAgent):
    """
    LLM Agent.

    This agent is responsible for processing user text input and querying a locally
    hosted LLM for text generation. It acts as the conversational brain of the system.

    It receives :class:`~control_backend.schemas.llm_prompt_message.LLMPromptMessage`
    payloads from the BDI Core Agent, constructs a conversation history, queries the
    LLM via HTTP, and streams the response back to the BDI agent in natural chunks
    (e.g., sentence by sentence).

    :ivar history: A list of dictionaries representing the conversation history (Role/Content).
    """

    # Shortest run of text that ends in a phrase delimiter, plus trailing whitespace.
    # Delimiters: comma, semicolon, colon, em dash, en dash, three dots, ellipsis
    # character, full stop, question mark, exclamation mark. Non-ASCII delimiters are
    # written as escapes so they cannot be silently lost or confused with look-alikes.
    # ``\.{3}`` is listed before ``\.`` so "..." is consumed as one delimiter.
    _CHUNK_PATTERN = re.compile(
        r".*?(?:,|;|:|\u2014|\u2013|\.{3}|\u2026|\.|\?|!)\s*",
        re.DOTALL,
    )

    def __init__(self, name: str):
        super().__init__(name)
        self.history: list[dict[str, str]] = []

    async def setup(self):
        """Log that the agent is being set up."""
        self.logger.info("Setting up %s.", self.name)

    async def handle_message(self, msg: InternalMessage):
        """
        Handle incoming messages.

        Expects messages from :attr:`settings.agent_settings.bdi_core_name` containing
        an :class:`LLMPromptMessage` in the body. A ``"clear_history"`` body from
        :attr:`settings.agent_settings.bdi_program_manager_name` clears the conversation
        history. Messages from any other sender are ignored.

        :param msg: The received internal message.
        """
        sender = msg.sender
        if sender == settings.agent_settings.bdi_core_name:
            self.logger.debug("Processing message from BDI core.")
            # Only the parsing step is guarded, so that a validation error raised
            # while processing is not misreported as an invalid prompt.
            try:
                prompt_message = LLMPromptMessage.model_validate_json(msg.body)
            except ValidationError:
                self.logger.debug("Prompt message from BDI core is invalid.")
                return
            await self._process_bdi_message(prompt_message)
        elif sender == settings.agent_settings.bdi_program_manager_name:
            if msg.body == "clear_history":
                self.logger.debug("Clearing conversation history.")
                self.history.clear()
        else:
            self.logger.debug("Message ignored (not from BDI core.")

    async def _process_bdi_message(self, message: LLMPromptMessage):
        """
        Orchestrate the LLM query and response streaming.

        Iterates over the chunks yielded by :meth:`_query_llm` and forwards them
        individually to the BDI agent via :meth:`_send_reply`.

        :param message: The parsed prompt message containing text, norms, and goals.
        """
        async for chunk in self._query_llm(message.text, message.norms, message.goals):
            await self._send_reply(chunk)
        self.logger.debug(
            "Finished processing BDI message. Response sent in chunks to BDI core."
        )

    async def _send_reply(self, msg: str):
        """
        Sends a response message (chunk) back to the BDI Core Agent.

        :param msg: The text content of the chunk.
        """
        reply = InternalMessage(
            to=settings.agent_settings.bdi_core_name,
            sender=self.name,
            body=msg,
        )
        await self.send(reply)

    async def _query_llm(
        self, prompt: str, norms: list[str], goals: list[str]
    ) -> AsyncGenerator[str]:
        """
        Send a chat completion request to the local LLM service and stream the response.

        It constructs the full prompt using
        :class:`~control_backend.agents.llm.llm_instructions.LLMInstructions`.
        It streams the response from the LLM and buffers tokens until a natural break (punctuation)
        is reached, then yields the chunk. This ensures that the robot speaks in complete phrases
        rather than individual tokens.

        The prompt is appended to :attr:`history` before the request; the assistant's
        full reply is appended only after the stream completes without error. On
        failure a fixed fallback sentence is yielded instead of raising.

        :param prompt: Input text prompt to pass to the LLM.
        :param norms: Norms the LLM should hold itself to.
        :param goals: Goals the LLM should achieve.
        :yield: Fragments of the LLM-generated content (e.g., sentences/phrases).
        """
        self.history.append(
            {
                "role": "user",
                "content": prompt,
            }
        )

        # Empty norm/goal lists are passed on as None.
        instructions = LLMInstructions(norms if norms else None, goals if goals else None)
        messages = [
            {
                "role": "developer",
                "content": instructions.build_developer_instruction(),
            },
            *self.history,
        ]

        message_id = str(uuid.uuid4())  # noqa

        try:
            full_message = ""
            current_chunk = ""
            async for token in self._stream_query_llm(messages):
                full_message += token
                current_chunk += token

                self.logger.llm(
                    "Received token: %s",
                    full_message,
                    extra={"reference": message_id},  # Used in the UI to update old logs
                )

                # Stream the message in chunks separated by punctuation.
                # We include the delimiter in the emitted chunk for natural flow.
                # Each completed phrase is emitted on its own; text after the last
                # delimiter stays buffered until a later token completes it.
                consumed = 0
                for match in self._CHUNK_PATTERN.finditer(current_chunk):
                    phrase = match.group(0)
                    if phrase:
                        yield phrase
                    consumed = match.end()
                current_chunk = current_chunk[consumed:]

            # Yield any remaining tail
            if current_chunk:
                yield current_chunk

            self.history.append(
                {
                    "role": "assistant",
                    "content": full_message,
                }
            )
        except httpx.HTTPError as err:
            self.logger.error("HTTP error.", exc_info=err)
            yield "LLM service unavailable."
        except Exception as err:
            # Boundary handler: the caller always receives something to say.
            self.logger.error("Unexpected error.", exc_info=err)
            yield "Error processing the request."

    async def _stream_query_llm(self, messages) -> AsyncGenerator[str]:
        """
        Perform the raw HTTP streaming request to the LLM API.

        :param messages: The list of message dictionaries (role/content).
        :yield: Raw text tokens (deltas) from the SSE stream.
        :raises httpx.HTTPError: If the request fails or the API returns a
            non-success status.
        """
        # NOTE(review): the client is created with httpx's default timeouts; confirm
        # they are long enough for the local model's first-token latency.
        async with httpx.AsyncClient() as client:
            async with client.stream(
                "POST",
                settings.llm_settings.local_llm_url,
                json={
                    "model": settings.llm_settings.local_llm_model,
                    "messages": messages,
                    "temperature": 0.3,
                    "stream": True,
                },
            ) as response:
                response.raise_for_status()

                async for line in response.aiter_lines():
                    # Only "data: ..." lines carry payloads; skip blanks and other fields.
                    if not line.startswith("data: "):
                        continue

                    data = line[len("data: ") :]
                    if data.strip() == "[DONE]":
                        break

                    try:
                        event = json.loads(data)
                    except json.JSONDecodeError:
                        self.logger.error("Failed to parse LLM response: %s", data)
                        continue

                    # Tolerate events with a missing/empty "choices" list or a null
                    # "delta" instead of aborting the whole response on them.
                    choices = event.get("choices") or [{}]
                    delta = (choices[0].get("delta") or {}).get("content")
                    if delta:
                        yield delta