feat: support history, norms and goals for LLM
ref: N25B-299
@@ -1,13 +1,16 @@
import json
import re
import uuid
from collections.abc import AsyncGenerator

import httpx
from pydantic import ValidationError

from control_backend.agents import BaseAgent
from control_backend.core.agent_system import InternalMessage
from control_backend.core.config import settings

from ...schemas.llm_prompt_message import LLMPromptMessage
from .llm_instructions import LLMInstructions


@@ -18,19 +21,26 @@ class LLMAgent(BaseAgent):
    and responds with processed LLM output.
    """

    def __init__(self, name: str):
        super().__init__(name)
        self.history = []

    async def setup(self):
        self.logger.info("Setting up %s.", self.name)

    async def handle_message(self, msg: InternalMessage):
        if msg.sender == settings.agent_settings.bdi_core_name:
            self.logger.debug("Processing message from BDI core.")
            await self._process_bdi_message(msg)
            try:
                prompt_message = LLMPromptMessage.model_validate_json(msg.body)
                await self._process_bdi_message(prompt_message)
            except ValidationError:
                self.logger.debug("Prompt message from BDI core is invalid.")
        else:
self.logger.debug("Message ignored (not from BDI core.")

    async def _process_bdi_message(self, message: InternalMessage):
        user_text = message.body
        async for chunk in self._query_llm(user_text):
    async def _process_bdi_message(self, message: LLMPromptMessage):
        async for chunk in self._query_llm(message.text, message.norms, message.goals):
            await self._send_reply(chunk)
        self.logger.debug(
            "Finished processing BDI message. Response sent in chunks to BDI core."
@@ -47,39 +57,49 @@ class LLMAgent(BaseAgent):
        )
        await self.send(reply)

    async def _query_llm(self, prompt: str) -> AsyncGenerator[str]:
    async def _query_llm(
        self, prompt: str, norms: list[str], goals: list[str]
    ) -> AsyncGenerator[str]:
        """
        Sends a chat completion request to the local LLM service and streams the response by
        yielding fragments separated by punctuation marks.

        :param prompt: Input text prompt to pass to the LLM.
        :param norms: Norms the LLM should hold itself to.
        :param goals: Goals the LLM should achieve.
        :yield: Fragments of the LLM-generated content.
        """
        instructions = LLMInstructions(
            "- Be friendly and respectful.\n"
            "- Make the conversation feel natural and engaging.\n"
            "- Speak like a pirate.\n"
            "- When the user asks what you can do, tell them.",
            "- Try to learn the user's name during conversation.\n"
            "- Suggest playing a game of asking yes or no questions where you think of a word "
            "and the user must guess it.",
        self.history.append(
            {
                "role": "user",
                "content": prompt,
            }
        )

        instructions = LLMInstructions(norms if norms else None, goals if goals else None)
        messages = [
            {
                "role": "developer",
                "content": instructions.build_developer_instruction(),
            },
            {
                "role": "user",
                "content": prompt,
            },
            *self.history,
        ]

        message_id = str(uuid.uuid4())

        try:
            full_message = ""
            current_chunk = ""
            async for token in self._stream_query_llm(messages):
                full_message += token
                current_chunk += token

                self.logger.info(
                    "Received token: %s",
                    full_message,
                    extra={"reference": message_id},  # Used in the UI to update old logs
                )

                # Stream the message in chunks separated by punctuation.
                # We include the delimiter in the emitted chunk for natural flow.
                pattern = re.compile(r".*?(?:,|;|:|—|–|\.{3}|…|\.|\?|!)\s*", re.DOTALL)
@@ -92,6 +112,13 @@ class LLMAgent(BaseAgent):
            # Yield any remaining tail
            if current_chunk:
                yield current_chunk

            self.history.append(
                {
                    "role": "assistant",
                    "content": full_message,
                }
            )
        except httpx.HTTPError as err:
            self.logger.error("HTTP error.", exc_info=err)
            yield "LLM service unavailable."
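
Not part of the commit: a minimal, self-contained sketch of how the punctuation pattern compiled above could be used to peel complete chunks off the accumulating buffer. Only the compiled regex is taken from the diff; the drain_chunks helper and the sample text are illustrative assumptions about the unchanged lines hidden between the last two hunks.

import re

# Chunking pattern copied from the diff above.
pattern = re.compile(r".*?(?:,|;|:|—|–|\.{3}|…|\.|\?|!)\s*", re.DOTALL)

def drain_chunks(buffer: str) -> tuple[list[str], str]:
    # Illustrative helper (assumed, not from the commit): peel complete,
    # punctuation-terminated chunks off the front of `buffer` and return
    # them together with the unmatched tail.
    chunks, pos = [], 0
    while match := pattern.match(buffer, pos):
        chunks.append(match.group(0))
        pos = match.end()
    return chunks, buffer[pos:]

chunks, tail = drain_chunks("Ahoy, matey! What be yer name? I am")
print(chunks)  # ['Ahoy, ', 'matey! ', 'What be yer name? ']
print(tail)    # 'I am'

Each emitted chunk keeps its trailing delimiter, which matches the comment in the diff about including the delimiter for natural flow.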
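Also not shown in this diff is _stream_query_llm, which the new code consumes token by token. As a rough sketch only, assuming an OpenAI-compatible /chat/completions endpoint that streams SSE-style "data:" lines (the URL, model name, and payload fields below are assumptions, not the project's actual configuration):

import json
from collections.abc import AsyncGenerator

import httpx

async def stream_chat_tokens(
    messages: list[dict], base_url: str = "http://localhost:8000/v1"
) -> AsyncGenerator[str]:
    # Hypothetical stand-in for _stream_query_llm: POST the chat request with
    # stream=True and yield each content delta as it arrives.
    payload = {"model": "local-model", "messages": messages, "stream": True}
    async with httpx.AsyncClient(timeout=None) as client:
        async with client.stream("POST", f"{base_url}/chat/completions", json=payload) as response:
            response.raise_for_status()
            async for line in response.aiter_lines():
                if not line.startswith("data: "):
                    continue
                data = line.removeprefix("data: ")
                if data == "[DONE]":
                    break
                delta = json.loads(data)["choices"][0]["delta"]
                if "content" in delta:
                    yield delta["content"]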