feat: (maybe) stop response when new user message
If we get a new message before the LLM is done responding, interrupt it. ref: N25B-452
This commit is contained in:
@@ -338,7 +338,7 @@ class BDICoreAgent(BaseAgent):
|
||||
yield
|
||||
|
||||
@self.actions.add(".reply_with_goal", 3)
|
||||
def _reply_with_goal(agent: "BDICoreAgent", term, intention):
|
||||
def _reply_with_goal(agent, term, intention):
|
||||
"""
|
||||
Let the LLM generate a response to a user's utterance with the current norms and a
|
||||
specific goal.
|
||||
|
||||
@@ -318,6 +318,9 @@ class TextBeliefExtractorAgent(BaseAgent):
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
settings.llm_settings.local_llm_url,
|
||||
headers={"Authorization": f"Bearer {settings.llm_settings.api_key}"}
|
||||
if settings.llm_settings.api_key
|
||||
else {},
|
||||
json={
|
||||
"model": settings.llm_settings.local_llm_model,
|
||||
"messages": [{"role": "user", "content": prompt}],
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
import json
|
||||
import re
|
||||
import uuid
|
||||
@@ -32,6 +33,9 @@ class LLMAgent(BaseAgent):
|
||||
def __init__(self, name: str):
|
||||
super().__init__(name)
|
||||
self.history = []
|
||||
self._querying = False
|
||||
self._interrupted = False
|
||||
self._go_ahead = asyncio.Event()
|
||||
|
||||
async def setup(self):
|
||||
self.logger.info("Setting up %s.", self.name)
|
||||
@@ -50,7 +54,7 @@ class LLMAgent(BaseAgent):
|
||||
case "prompt_message":
|
||||
try:
|
||||
prompt_message = LLMPromptMessage.model_validate_json(msg.body)
|
||||
await self._process_bdi_message(prompt_message)
|
||||
self.add_behavior(self._process_bdi_message(prompt_message)) # no block
|
||||
except ValidationError:
|
||||
self.logger.debug("Prompt message from BDI core is invalid.")
|
||||
case "assistant_message":
|
||||
@@ -73,12 +77,35 @@ class LLMAgent(BaseAgent):
|
||||
|
||||
:param message: The parsed prompt message containing text, norms, and goals.
|
||||
"""
|
||||
if self._querying:
|
||||
self.logger.debug("Received another BDI prompt while processing previous message.")
|
||||
self._interrupted = True # interrupt the previous processing
|
||||
await self._go_ahead.wait() # wait until we get the go-ahead
|
||||
|
||||
self._go_ahead.clear()
|
||||
self._querying = True
|
||||
full_message = ""
|
||||
async for chunk in self._query_llm(message.text, message.norms, message.goals):
|
||||
if self._interrupted:
|
||||
self.logger.debug("Interrupted processing of previous message.")
|
||||
break
|
||||
await self._send_reply(chunk)
|
||||
full_message += chunk
|
||||
self.logger.debug("Finished processing BDI message. Response sent in chunks to BDI core.")
|
||||
await self._send_full_reply(full_message)
|
||||
else:
|
||||
self._querying = False
|
||||
|
||||
self.history.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": full_message,
|
||||
}
|
||||
)
|
||||
self.logger.debug(
|
||||
"Finished processing BDI message. Response sent in chunks to BDI core."
|
||||
)
|
||||
await self._send_full_reply(full_message)
|
||||
|
||||
self._interrupted = False
|
||||
|
||||
async def _send_reply(self, msg: str):
|
||||
"""
|
||||
@@ -141,7 +168,7 @@ class LLMAgent(BaseAgent):
|
||||
full_message += token
|
||||
current_chunk += token
|
||||
|
||||
self.logger.llm(
|
||||
self.logger.debug(
|
||||
"Received token: %s",
|
||||
full_message,
|
||||
extra={"reference": message_id}, # Used in the UI to update old logs
|
||||
@@ -159,13 +186,6 @@ class LLMAgent(BaseAgent):
|
||||
# Yield any remaining tail
|
||||
if current_chunk:
|
||||
yield current_chunk
|
||||
|
||||
self.history.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": full_message,
|
||||
}
|
||||
)
|
||||
except httpx.HTTPError as err:
|
||||
self.logger.error("HTTP error.", exc_info=err)
|
||||
yield "LLM service unavailable."
|
||||
@@ -185,6 +205,9 @@ class LLMAgent(BaseAgent):
|
||||
async with client.stream(
|
||||
"POST",
|
||||
settings.llm_settings.local_llm_url,
|
||||
headers={"Authorization": f"Bearer {settings.llm_settings.api_key}"}
|
||||
if settings.llm_settings.api_key
|
||||
else {},
|
||||
json={
|
||||
"model": settings.llm_settings.local_llm_model,
|
||||
"messages": messages,
|
||||
|
||||
@@ -117,6 +117,7 @@ class LLMSettings(BaseModel):
|
||||
|
||||
local_llm_url: str = "http://localhost:1234/v1/chat/completions"
|
||||
local_llm_model: str = "gpt-oss"
|
||||
api_key: str = ""
|
||||
chat_temperature: float = 1.0
|
||||
code_temperature: float = 0.3
|
||||
n_parallel: int = 4
|
||||
|
||||
Reference in New Issue
Block a user