feat: support history, norms and goals for LLM
ref: N25B-299
@@ -12,8 +12,11 @@ from control_backend.agents.base import BaseAgent
 from control_backend.core.agent_system import InternalMessage
 from control_backend.core.config import settings
 from control_backend.schemas.belief_message import Belief, BeliefMessage
+from control_backend.schemas.llm_prompt_message import LLMPromptMessage
 from control_backend.schemas.ri_message import SpeechCommand
 
+DELIMITER = ";\n"  # TODO: temporary until we support lists in AgentSpeak
+
 
 class BDICoreAgent(BaseAgent):
     bdi_agent: agentspeak.runtime.Agent
@@ -112,7 +115,9 @@ class BDICoreAgent(BaseAgent):
             self._add_belief(belief.name, belief.arguments)
 
     def _add_belief(self, name: str, args: Iterable[str] = []):
-        new_args = (agentspeak.Literal(arg) for arg in args)
+        # new_args = (agentspeak.Literal(arg) for arg in args)  # TODO: Eventually support multiple
+        merged_args = DELIMITER.join(arg for arg in args)
+        new_args = (agentspeak.Literal(merged_args),)
         term = agentspeak.Literal(name, new_args)
 
         self.bdi_agent.call(
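
Note on the delimiter workaround: until AgentSpeak lists are supported, `_add_belief` folds all arguments into a single delimiter-joined literal. A runnable sketch of the round-trip, using the `DELIMITER` from this commit (values illustrative):

    DELIMITER = ";\n"

    args = ["Be friendly and respectful.", "Speak like a pirate."]
    merged = DELIMITER.join(args)  # 'Be friendly and respectful.;\nSpeak like a pirate.'

    # _send_to_llm (below) recovers a list by splitting on "\n", so every element
    # except the last keeps the ';' half of the delimiter attached:
    merged.split("\n")             # ['Be friendly and respectful.;', 'Speak like a pirate.']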
@@ -178,21 +183,37 @@ class BDICoreAgent(BaseAgent):
         the function expects (which will be located in `term.args`).
         """
 
-        @self.actions.add(".reply", 1)
-        def _reply(agent, term, intention):
+        @self.actions.add(".reply", 3)
+        def _reply(agent: "BDICoreAgent", term, intention):
             """
-            Sends text to the LLM.
+            Sends text to the LLM (AgentSpeak action).
+            Example: .reply("Hello LLM!", "Some norm", "Some goal")
             """
             message_text = agentspeak.grounded(term.args[0], intention.scope)
+            norms = agentspeak.grounded(term.args[1], intention.scope)
+            goals = agentspeak.grounded(term.args[2], intention.scope)
 
-            asyncio.create_task(self._send_to_llm(str(message_text)))
+            self.logger.debug("Norms: %s", norms)
+            self.logger.debug("Goals: %s", goals)
+            self.logger.debug("User text: %s", message_text)
+
+            asyncio.create_task(self._send_to_llm(str(message_text), str(norms), str(goals)))
             yield
 
-    async def _send_to_llm(self, text: str):
+    async def _send_to_llm(self, text: str, norms: str | None = None, goals: str | None = None):
         """
         Sends a text query to the LLM agent asynchronously.
         """
-        msg = InternalMessage(to=settings.agent_settings.llm_name, sender=self.name, body=text)
+        prompt = LLMPromptMessage(
+            text=text,
+            norms=norms.split("\n") if norms else [],
+            goals=goals.split("\n") if goals else [],
+        )
+        msg = InternalMessage(
+            to=settings.agent_settings.llm_name,
+            sender=self.name,
+            body=prompt.model_dump_json(),
+        )
         await self.send(msg)
         self.logger.info("Message sent to LLM agent: %s", text)
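
For reference, the JSON envelope `_send_to_llm` now puts on the internal message body (runnable sketch; the model mirrors the new schema file at the end of this commit, values illustrative):

    from pydantic import BaseModel

    class LLMPromptMessage(BaseModel):
        text: str
        norms: list[str]
        goals: list[str]

    prompt = LLMPromptMessage(text="Hello LLM!", norms=["Some norm"], goals=["Some goal"])
    prompt.model_dump_json()
    # '{"text":"Hello LLM!","norms":["Some norm"],"goals":["Some goal"]}'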
@@ -1,3 +1,6 @@
-+user_said(Message) <-
+norms("").
+goals("").
+
++user_said(Message) : norms(Norms) & goals(Goals) <-
     -user_said(Message);
-    .reply(Message).
+    .reply(Message, Norms, Goals).
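
The empty `norms("")`/`goals("")` facts keep the new plan context satisfiable from the start; once real beliefs are asserted, the plan forwards them through `.reply/3`. A runnable pure-Python sketch of the data flow (no agentspeak required, names illustrative):

    DELIMITER = ";\n"

    # What _add_belief stores after merging a BeliefMessage's arguments:
    beliefs = {
        "norms": DELIMITER.join(["Be friendly and respectful."]),
        "goals": DELIMITER.join(["Learn the user's name."]),
    }

    # +user_said(Message) : norms(Norms) & goals(Goals) <- ... .reply(Message, Norms, Goals).
    message = "Hello there!"
    reply_args = (message, beliefs["norms"], beliefs["goals"])  # what the .reply action receives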
@@ -1,13 +1,16 @@ 
 import json
 import re
+import uuid
 from collections.abc import AsyncGenerator
 
 import httpx
+from pydantic import ValidationError
 
 from control_backend.agents import BaseAgent
 from control_backend.core.agent_system import InternalMessage
 from control_backend.core.config import settings
 
+from ...schemas.llm_prompt_message import LLMPromptMessage
 from .llm_instructions import LLMInstructions
 
 
@@ -18,19 +21,26 @@ class LLMAgent(BaseAgent):
     and responds with processed LLM output.
     """
 
+    def __init__(self, name: str):
+        super().__init__(name)
+        self.history = []
+
     async def setup(self):
         self.logger.info("Setting up %s.", self.name)
 
     async def handle_message(self, msg: InternalMessage):
         if msg.sender == settings.agent_settings.bdi_core_name:
             self.logger.debug("Processing message from BDI core.")
-            await self._process_bdi_message(msg)
+            try:
+                prompt_message = LLMPromptMessage.model_validate_json(msg.body)
+                await self._process_bdi_message(prompt_message)
+            except ValidationError:
+                self.logger.debug("Prompt message from BDI core is invalid.")
         else:
             self.logger.debug("Message ignored (not from BDI core).")
 
-    async def _process_bdi_message(self, message: InternalMessage):
-        user_text = message.body
-        async for chunk in self._query_llm(user_text):
+    async def _process_bdi_message(self, message: LLMPromptMessage):
+        async for chunk in self._query_llm(message.text, message.norms, message.goals):
             await self._send_reply(chunk)
         self.logger.debug(
             "Finished processing BDI message. Response sent in chunks to BDI core."
@@ -47,39 +57,49 @@ class LLMAgent(BaseAgent):
             )
         await self.send(reply)
 
-    async def _query_llm(self, prompt: str) -> AsyncGenerator[str]:
+    async def _query_llm(
+        self, prompt: str, norms: list[str], goals: list[str]
+    ) -> AsyncGenerator[str]:
         """
         Sends a chat completion request to the local LLM service and streams the response by
         yielding fragments separated by punctuation.
 
         :param prompt: Input text prompt to pass to the LLM.
+        :param norms: Norms the LLM should hold itself to.
+        :param goals: Goals the LLM should achieve.
         :yield: Fragments of the LLM-generated content.
         """
-        instructions = LLMInstructions(
-            "- Be friendly and respectful.\n"
-            "- Make the conversation feel natural and engaging.\n"
-            "- Speak like a pirate.\n"
-            "- When the user asks what you can do, tell them.",
-            "- Try to learn the user's name during conversation.\n"
-            "- Suggest playing a game of asking yes or no questions where you think of a word "
-            "and the user must guess it.",
-        )
+        self.history.append(
+            {
+                "role": "user",
+                "content": prompt,
+            }
+        )
 
+        instructions = LLMInstructions(norms if norms else None, goals if goals else None)
         messages = [
             {
                 "role": "developer",
                 "content": instructions.build_developer_instruction(),
             },
-            {
-                "role": "user",
-                "content": prompt,
-            },
+            *self.history,
         ]
 
+        message_id = str(uuid.uuid4())
+
         try:
+            full_message = ""
             current_chunk = ""
             async for token in self._stream_query_llm(messages):
+                full_message += token
                 current_chunk += token
 
+                self.logger.info(
+                    "Received token: %s",
+                    full_message,
+                    extra={"reference": message_id},  # Used in the UI to update old logs
+                )
+
                 # Stream the message in chunks separated by punctuation.
                 # We include the delimiter in the emitted chunk for natural flow.
                 pattern = re.compile(r".*?(?:,|;|:|—|–|\.{3}|…|\.|\?|!)\s*", re.DOTALL)
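
For context, roughly how that pattern carves the streamed buffer into speakable chunks (standalone, runnable sketch):

    import re

    pattern = re.compile(r".*?(?:,|;|:|—|–|\.{3}|…|\.|\?|!)\s*", re.DOTALL)
    pattern.findall("Ahoy, matey! What be yer name?")
    # ['Ahoy, ', 'matey! ', 'What be yer name?']
    # Each chunk keeps its trailing delimiter; anything after the last
    # punctuation mark is emitted later as the remaining tail.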
@@ -92,6 +112,13 @@ class LLMAgent(BaseAgent):
             # Yield any remaining tail
             if current_chunk:
                 yield current_chunk
+
+            self.history.append(
+                {
+                    "role": "assistant",
+                    "content": full_message,
+                }
+            )
         except httpx.HTTPError as err:
             self.logger.error("HTTP error.", exc_info=err)
             yield "LLM service unavailable."
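
With both appends in place, `self.history` alternates user and assistant turns, and each new request replays the whole conversation after the developer instruction. A sketch of the message list on a second turn (contents illustrative):

    messages = [
        {"role": "developer", "content": "<instructions.build_developer_instruction()>"},
        {"role": "user", "content": "Hello!"},
        {"role": "assistant", "content": "Ahoy! What be yer name?"},
        {"role": "user", "content": "I'm Sam."},  # appended at the top of _query_llm
    ]

Note that the `httpx.HTTPError` path skips the assistant append, so a failed turn leaves only the user entry in history.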
@@ -17,9 +17,9 @@ class LLMInstructions:
     Try to learn the user's name during conversation.
     """.strip()
 
-    def __init__(self, norms: str | None = None, goals: str | None = None):
-        self.norms = norms if norms is not None else self.default_norms()
-        self.goals = goals if goals is not None else self.default_goals()
+    def __init__(self, norms: list[str] | None = None, goals: list[str] | None = None):
+        self.norms = norms or self.default_norms()
+        self.goals = goals or self.default_goals()
 
     def build_developer_instruction(self) -> str:
         """
@@ -35,12 +35,14 @@ class LLMInstructions:
 
         if self.norms:
             sections.append("Norms to follow:")
-            sections.append(self.norms)
+            for norm in self.norms:
+                sections.append("- " + norm)
             sections.append("")
 
         if self.goals:
             sections.append("Goals to reach:")
-            sections.append(self.goals)
+            for goal in self.goals:
+                sections.append("- " + goal)
             sections.append("")
 
         return "\n".join(sections).strip()
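
With list-valued norms and goals, the builder now emits one bullet per entry. Roughly (a sketch assuming the class is in scope; entries illustrative, preamble elided):

    instructions = LLMInstructions(
        norms=["Be friendly and respectful.", "Speak like a pirate."],
        goals=["Try to learn the user's name during conversation."],
    )
    print(instructions.build_developer_instruction())
    # ...
    # Norms to follow:
    # - Be friendly and respectful.
    # - Speak like a pirate.
    #
    # Goals to reach:
    # - Try to learn the user's name during conversation.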
src/control_backend/schemas/llm_prompt_message.py (new file)
@@ -0,0 +1,7 @@
+from pydantic import BaseModel
+
+
+class LLMPromptMessage(BaseModel):
+    text: str
+    norms: list[str]
+    goals: list[str]
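
And the receiving side's round-trip, as `LLMAgent.handle_message` does it (sketch, assuming the `LLMPromptMessage` class above is in scope; payload illustrative):

    from pydantic import ValidationError

    raw = '{"text": "Hello LLM!", "norms": ["Some norm"], "goals": ["Some goal"]}'
    try:
        prompt = LLMPromptMessage.model_validate_json(raw)
    except ValidationError:
        pass  # invalid bodies are logged and dropped by the agent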