import json
import re
import uuid
from collections.abc import AsyncGenerator

import httpx
from pydantic import ValidationError

from control_backend.agents import BaseAgent
from control_backend.core.agent_system import InternalMessage
from control_backend.core.config import settings

from ...schemas.llm_prompt_message import LLMPromptMessage
from .llm_instructions import LLMInstructions


class LLMAgent(BaseAgent):
    """
    LLM Agent.

    This agent is responsible for processing user text input and querying a locally
    hosted LLM for text generation. It acts as the conversational brain of the system.

    It receives :class:`~control_backend.schemas.llm_prompt_message.LLMPromptMessage`
    payloads from the BDI Core Agent, constructs a conversation history, queries the
    LLM via HTTP, and streams the response back to the BDI agent in natural chunks
    (e.g., sentence by sentence).

    :ivar history: A list of dictionaries representing the conversation history (role/content).
    """

    def __init__(self, name: str):
        super().__init__(name)
        self.history = []
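        # Illustrative shape of a history entry (matching how entries are
        # appended in _query_llm below):
        #   {"role": "user", "content": "What's the weather like?"}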

    async def setup(self):
        self.logger.info("Setting up %s.", self.name)

    async def handle_message(self, msg: InternalMessage):
        """
        Handle incoming messages.

        Expects messages from :attr:`settings.agent_settings.bdi_core_name` containing
        an :class:`LLMPromptMessage` in the body.

        :param msg: The received internal message.
        """
        if msg.sender == settings.agent_settings.bdi_core_name:
            self.logger.debug("Processing message from BDI core.")
            try:
                prompt_message = LLMPromptMessage.model_validate_json(msg.body)
                await self._process_bdi_message(prompt_message)
            except ValidationError:
                self.logger.debug("Prompt message from BDI core is invalid.")
        elif msg.sender == settings.agent_settings.bdi_program_manager_name:
            if msg.body == "clear_history":
                self.logger.debug("Clearing conversation history.")
                self.history.clear()
        else:
            self.logger.debug("Message ignored (not from BDI core or program manager).")

    async def _process_bdi_message(self, message: LLMPromptMessage):
        """
        Orchestrate the LLM query and response streaming.

        Iterates over the chunks yielded by :meth:`_query_llm` and forwards them
        individually to the BDI agent via :meth:`_send_reply`.

        :param message: The parsed prompt message containing text, norms, and goals.
        """
        async for chunk in self._query_llm(message.text, message.norms, message.goals):
            await self._send_reply(chunk)
        self.logger.debug(
            "Finished processing BDI message. Response sent in chunks to BDI core."
        )

    async def _send_reply(self, msg: str):
        """
        Send a response message (chunk) back to the BDI Core Agent.

        :param msg: The text content of the chunk.
        """
        reply = InternalMessage(
            to=settings.agent_settings.bdi_core_name,
            sender=self.name,
            body=msg,
        )
        await self.send(reply)

    async def _query_llm(
        self, prompt: str, norms: list[str], goals: list[str]
    ) -> AsyncGenerator[str]:
        """
        Send a chat completion request to the local LLM service and stream the response.

        It constructs the full prompt using
        :class:`~control_backend.agents.llm.llm_instructions.LLMInstructions`.
        It streams the response from the LLM and buffers tokens until a natural break
        (punctuation) is reached, then yields the chunk. This ensures that the robot
        speaks in complete phrases rather than individual tokens.

        :param prompt: Input text prompt to pass to the LLM.
        :param norms: Norms the LLM should hold itself to.
        :param goals: Goals the LLM should achieve.
        :yield: Fragments of the LLM-generated content (e.g., sentences/phrases).
        """
        self.history.append(
            {
                "role": "user",
                "content": prompt,
            }
        )

        instructions = LLMInstructions(norms if norms else None, goals if goals else None)
        messages = [
            {
                "role": "developer",
                "content": instructions.build_developer_instruction(),
            },
            *self.history,
        ]
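        # Illustrative shape of the request messages (developer instruction first,
        # then the running conversation):
        #   [{"role": "developer", "content": "<instruction text>"},
        #    {"role": "user", "content": "Hi"},
        #    {"role": "assistant", "content": "Hello!"}, ...]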

        message_id = str(uuid.uuid4())  # noqa

        try:
            full_message = ""
            current_chunk = ""
            async for token in self._stream_query_llm(messages):
                full_message += token
                current_chunk += token

                self.logger.llm(
                    "Received token: %s",
                    full_message,
                    extra={"reference": message_id},  # Used in the UI to update old logs
                )

                # Stream the message in chunks separated by punctuation.
                # We include the delimiter in the emitted chunk for natural flow.
                pattern = re.compile(r".*?(?:,|;|:|—|–|\.{3}|…|\.|\?|!)\s*", re.DOTALL)
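                # Illustrative example (assumed tokenization): a buffer holding
                # "Sure, let me check. One mo" would yield "Sure, " and
                # "let me check. " and keep "One mo" buffered for the next token.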
                last_end = 0
                for m in pattern.finditer(current_chunk):
                    chunk = m.group(0)
                    if chunk:
                        yield chunk
                    last_end = m.end()
                # Keep only the unmatched tail in the buffer.
                current_chunk = current_chunk[last_end:]

            # Yield any remaining tail
            if current_chunk:
                yield current_chunk

            self.history.append(
                {
                    "role": "assistant",
                    "content": full_message,
                }
            )
        except httpx.HTTPError as err:
            self.logger.error("HTTP error.", exc_info=err)
            yield "LLM service unavailable."
        except Exception as err:
            self.logger.error("Unexpected error.", exc_info=err)
            yield "Error processing the request."

    async def _stream_query_llm(self, messages: list[dict[str, str]]) -> AsyncGenerator[str]:
        """
        Perform the raw HTTP streaming request to the LLM API.

        :param messages: The list of message dictionaries (role/content).
        :yield: Raw text tokens (deltas) from the SSE stream.
        :raises httpx.HTTPError: If the request fails or the API returns an error status.
        """
        async with httpx.AsyncClient() as client:
            async with client.stream(
                "POST",
                settings.llm_settings.local_llm_url,
                json={
                    "model": settings.llm_settings.local_llm_model,
                    "messages": messages,
                    "temperature": 0.3,
                    "stream": True,
                },
            ) as response:
                response.raise_for_status()

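                # The endpoint is assumed to stream OpenAI-style SSE lines, e.g.
                #   data: {"choices": [{"delta": {"content": "Hel"}}]}
                #   data: [DONE]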
                async for line in response.aiter_lines():
                    if not line or not line.startswith("data: "):
                        continue

                    data = line[len("data: ") :]
                    if data.strip() == "[DONE]":
                        break

                    try:
                        event = json.loads(data)
                        delta = event.get("choices", [{}])[0].get("delta", {}).get("content")
                        if delta:
                            yield delta
                    except json.JSONDecodeError:
                        self.logger.error("Failed to parse LLM response: %s", data)