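"""
Agent that queries a locally hosted LLM on behalf of the BDI Core Agent and
streams the generated text back in fragments.
"""
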
import json
import re
import uuid
from collections.abc import AsyncGenerator

import httpx
from pydantic import ValidationError

from control_backend.agents import BaseAgent
from control_backend.core.agent_system import InternalMessage
from control_backend.core.config import settings

from ...schemas.llm_prompt_message import LLMPromptMessage
from .llm_instructions import LLMInstructions


class LLMAgent(BaseAgent):
    """
    Agent responsible for processing user text input and querying a locally
    hosted LLM for text generation. Receives messages from the BDI Core Agent
    and responds with processed LLM output.
    """

    def __init__(self, name: str):
        super().__init__(name)
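        # Conversation history, replayed to the LLM on every request.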
        self.history = []

    async def setup(self):
        self.logger.info("Setting up %s.", self.name)

    async def handle_message(self, msg: InternalMessage):
        if msg.sender == settings.agent_settings.bdi_core_name:
            self.logger.debug("Processing message from BDI core.")
            try:
                prompt_message = LLMPromptMessage.model_validate_json(msg.body)
                await self._process_bdi_message(prompt_message)
            except ValidationError:
                self.logger.debug("Prompt message from BDI core is invalid.")
        else:
            self.logger.debug("Message ignored (not from BDI core).")

    async def _process_bdi_message(self, message: LLMPromptMessage):
        async for chunk in self._query_llm(message.text, message.norms, message.goals):
            await self._send_reply(chunk)
        self.logger.debug(
            "Finished processing BDI message. Response sent in chunks to BDI core."
        )

    async def _send_reply(self, msg: str):
        """
        Sends a response message back to the BDI Core Agent.
        """
        reply = InternalMessage(
            to=settings.agent_settings.bdi_core_name,
            sender=self.name,
            body=msg,
        )
        await self.send(reply)

    async def _query_llm(
        self, prompt: str, norms: list[str], goals: list[str]
    ) -> AsyncGenerator[str]:
        """
        Sends a chat completion request to the local LLM service and streams the response
        by yielding fragments separated by punctuation such as commas and sentence-ending
        marks.

        :param prompt: Input text prompt to pass to the LLM.
        :param norms: Norms the LLM should hold itself to.
        :param goals: Goals the LLM should achieve.
        :yield: Fragments of the LLM-generated content.
        """
        self.history.append(
            {
                "role": "user",
                "content": prompt,
            }
        )

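        # Norms and goals are folded into a developer instruction that precedes
        # the accumulated chat history.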
        instructions = LLMInstructions(norms if norms else None, goals if goals else None)
        messages = [
            {
                "role": "developer",
                "content": instructions.build_developer_instruction(),
            },
            *self.history,
        ]

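        # Stable reference id: streamed log records share it so the UI can update
        # the previous entry instead of appending a new one.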
        message_id = str(uuid.uuid4())

        try:
            full_message = ""
            current_chunk = ""
            async for token in self._stream_query_llm(messages):
                full_message += token
                current_chunk += token

                self.logger.info(
                    "Streaming LLM response: %s",
                    full_message,
                    extra={"reference": message_id},  # Used in the UI to update old logs
                )

                # Stream the message in chunks separated by punctuation.
                # We include the delimiter in the emitted chunk for natural flow.
                pattern = re.compile(r".*?(?:,|;|:|—|–|\.{3}|…|\.|\?|!)\s*", re.DOTALL)
                last_end = 0
                for m in pattern.finditer(current_chunk):
                    chunk = m.group(0)
                    if chunk:
                        yield chunk
                    last_end = m.end()
                # Keep the unterminated tail buffered until the next token arrives.
                current_chunk = current_chunk[last_end:]

            # Yield any remaining tail
            if current_chunk:
                yield current_chunk

            self.history.append(
                {
                    "role": "assistant",
                    "content": full_message,
                }
            )
        except httpx.HTTPError as err:
            self.logger.error("HTTP error.", exc_info=err)
            yield "LLM service unavailable."
        except Exception as err:
            self.logger.error("Unexpected error.", exc_info=err)
            yield "Error processing the request."

    async def _stream_query_llm(self, messages) -> AsyncGenerator[str]:
        """
        Streams raw content tokens from the local LLM service.

        Raises httpx.HTTPError when the API returns an error.
        """
        async with httpx.AsyncClient(timeout=None) as client:
            async with client.stream(
                "POST",
                settings.llm_settings.local_llm_url,
                json={
                    "model": settings.llm_settings.local_llm_model,
                    "messages": messages,
                    "temperature": 0.3,
                    "stream": True,
                },
            ) as response:
                response.raise_for_status()

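                # The service streams OpenAI-style server-sent events: each payload
                # line is prefixed with "data: ", and "data: [DONE]" ends the stream.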
                async for line in response.aiter_lines():
                    if not line or not line.startswith("data: "):
                        continue

                    data = line[len("data: ") :]
                    if data.strip() == "[DONE]":
                        break

                    try:
                        event = json.loads(data)
                        delta = event.get("choices", [{}])[0].get("delta", {}).get("content")
                        if delta:
                            yield delta
                    except json.JSONDecodeError:
                        self.logger.error("Failed to parse LLM response: %s", data)