feat: stop the in-progress response when a new user message arrives
If a new user message arrives before the LLM has finished responding, interrupt the in-progress response. ref: N25B-452
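For context, the change boils down to a small interrupt handshake between two overlapping calls of the same handler: the call that is still streaming checks an interrupt flag on every chunk, and the call that arrived later waits on an event until the old one has stopped. The sketch below is a simplified, self-contained illustration of that pattern. It borrows the flag and event names from the diff below, but the class, its methods, and the demo are invented; where the real implementation releases the go-ahead event is not visible in this diff, so the sketch releases it at the end of the streaming call to keep the example complete.

import asyncio


class Responder:
    """Illustrative sketch of the interrupt handshake; not the project's code."""

    def __init__(self):
        self._querying = False             # a reply is currently being streamed
        self._interrupted = False          # the streaming call should stop early
        self._go_ahead = asyncio.Event()   # released once the old call has stopped

    async def respond(self, prompt: str):
        if self._querying:
            # A previous reply is still streaming: ask it to stop, then wait.
            self._interrupted = True
            await self._go_ahead.wait()

        self._go_ahead.clear()
        self._querying = True
        async for chunk in self._stream(prompt):
            if self._interrupted:
                break
            print(chunk, end="", flush=True)
        print()
        self._querying = False
        self._interrupted = False
        self._go_ahead.set()               # let any waiting call proceed

    async def _stream(self, prompt: str):
        # Stand-in for a streaming LLM call.
        for word in f"echoing: {prompt}".split():
            await asyncio.sleep(0.1)
            yield word + " "


async def main():
    r = Responder()
    first = asyncio.create_task(r.respond("a long first answer " * 4))
    await asyncio.sleep(0.25)              # interrupt mid-stream
    await r.respond("never mind, answer this instead")
    await first


asyncio.run(main())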
@@ -338,7 +338,7 @@ class BDICoreAgent(BaseAgent):
             yield
 
         @self.actions.add(".reply_with_goal", 3)
-        def _reply_with_goal(agent: "BDICoreAgent", term, intention):
+        def _reply_with_goal(agent, term, intention):
             """
             Let the LLM generate a response to a user's utterance with the current norms and a
             specific goal.
@@ -318,6 +318,9 @@ class TextBeliefExtractorAgent(BaseAgent):
         async with httpx.AsyncClient() as client:
             response = await client.post(
                 settings.llm_settings.local_llm_url,
+                headers={"Authorization": f"Bearer {settings.llm_settings.api_key}"}
+                if settings.llm_settings.api_key
+                else {},
                 json={
                     "model": settings.llm_settings.local_llm_model,
                     "messages": [{"role": "user", "content": prompt}],
@@ -1,3 +1,4 @@
+import asyncio
 import json
 import re
 import uuid
@@ -32,6 +33,9 @@ class LLMAgent(BaseAgent):
     def __init__(self, name: str):
         super().__init__(name)
         self.history = []
+        self._querying = False
+        self._interrupted = False
+        self._go_ahead = asyncio.Event()
 
     async def setup(self):
         self.logger.info("Setting up %s.", self.name)
@@ -50,7 +54,7 @@ class LLMAgent(BaseAgent):
             case "prompt_message":
                 try:
                     prompt_message = LLMPromptMessage.model_validate_json(msg.body)
-                    await self._process_bdi_message(prompt_message)
+                    self.add_behavior(self._process_bdi_message(prompt_message))  # no block
                 except ValidationError:
                     self.logger.debug("Prompt message from BDI core is invalid.")
             case "assistant_message":
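The one-line change in the hunk above is what makes interruption reachable at all: previously the receive loop awaited `_process_bdi_message`, so a second `prompt_message` could not even be read until the first reply had finished. `self.add_behavior(...)` appears to be this project's BaseAgent helper for running the coroutine without blocking (hence the `# no block` comment). With plain asyncio, the equivalent fire-and-forget dispatch looks roughly like the sketch below; the queue, the `process` handler, and the task bookkeeping are illustrative, not the repository's API.

import asyncio
import contextlib


async def process(msg: str) -> None:
    await asyncio.sleep(1)  # stand-in for streaming an LLM reply
    print("processed:", msg)


async def handle_messages(queue: asyncio.Queue) -> None:
    background = set()
    while True:
        msg = await queue.get()
        # Schedule the handler instead of awaiting it, so the receive loop
        # stays responsive and can see the next message immediately.
        task = asyncio.create_task(process(msg))
        background.add(task)                         # keep a strong reference
        task.add_done_callback(background.discard)


async def main() -> None:
    queue: asyncio.Queue = asyncio.Queue()
    loop_task = asyncio.create_task(handle_messages(queue))
    for i in range(3):
        await queue.put(f"msg {i}")
    await asyncio.sleep(1.5)   # let the handlers finish
    loop_task.cancel()
    with contextlib.suppress(asyncio.CancelledError):
        await loop_task


asyncio.run(main())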
@@ -73,12 +77,35 @@ class LLMAgent(BaseAgent):
 
         :param message: The parsed prompt message containing text, norms, and goals.
         """
+        if self._querying:
+            self.logger.debug("Received another BDI prompt while processing previous message.")
+            self._interrupted = True  # interrupt the previous processing
+            await self._go_ahead.wait()  # wait until we get the go-ahead
+
+        self._go_ahead.clear()
+        self._querying = True
         full_message = ""
         async for chunk in self._query_llm(message.text, message.norms, message.goals):
+            if self._interrupted:
+                self.logger.debug("Interrupted processing of previous message.")
+                break
             await self._send_reply(chunk)
             full_message += chunk
-        self.logger.debug("Finished processing BDI message. Response sent in chunks to BDI core.")
-        await self._send_full_reply(full_message)
+        else:
+            self._querying = False
+
+            self.history.append(
+                {
+                    "role": "assistant",
+                    "content": full_message,
+                }
+            )
+            self.logger.debug(
+                "Finished processing BDI message. Response sent in chunks to BDI core."
+            )
+            await self._send_full_reply(full_message)
+
+        self._interrupted = False
 
     async def _send_reply(self, msg: str):
         """
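One detail that is easy to misread in the hunk above: the history update, the "Finished processing" log line, and `_send_full_reply` now sit in the `else` clause of the `async for` loop. A loop's `else` runs only when the loop ends without hitting `break`, so those steps are skipped exactly when the reply was interrupted. A tiny standalone reminder of the semantics (the names here are invented for illustration):

import asyncio


async def numbers():
    for i in range(5):
        await asyncio.sleep(0)
        yield i


async def consume(interrupt: bool):
    async for i in numbers():
        if interrupt and i == 2:
            break                      # the else clause below is skipped
    else:
        print("finished without interruption")


asyncio.run(consume(interrupt=False))  # prints the message
asyncio.run(consume(interrupt=True))   # prints nothing

This is also why the `self.history.append(...)` block disappears from `_query_llm` in a later hunk: the assistant turn is now recorded once, in `_process_bdi_message`, and only for replies that ran to completion.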
@@ -141,7 +168,7 @@ class LLMAgent(BaseAgent):
                 full_message += token
                 current_chunk += token
 
-                self.logger.llm(
+                self.logger.debug(
                     "Received token: %s",
                     full_message,
                     extra={"reference": message_id},  # Used in the UI to update old logs
@@ -159,13 +186,6 @@ class LLMAgent(BaseAgent):
             # Yield any remaining tail
             if current_chunk:
                 yield current_chunk
-
-            self.history.append(
-                {
-                    "role": "assistant",
-                    "content": full_message,
-                }
-            )
         except httpx.HTTPError as err:
             self.logger.error("HTTP error.", exc_info=err)
             yield "LLM service unavailable."
@@ -185,6 +205,9 @@ class LLMAgent(BaseAgent):
             async with client.stream(
                 "POST",
                 settings.llm_settings.local_llm_url,
+                headers={"Authorization": f"Bearer {settings.llm_settings.api_key}"}
+                if settings.llm_settings.api_key
+                else {},
                 json={
                     "model": settings.llm_settings.local_llm_model,
                     "messages": messages,
@@ -117,6 +117,7 @@ class LLMSettings(BaseModel):
 
     local_llm_url: str = "http://localhost:1234/v1/chat/completions"
     local_llm_model: str = "gpt-oss"
+    api_key: str = ""
    chat_temperature: float = 1.0
     code_temperature: float = 0.3
     n_parallel: int = 4
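Both HTTP call sites, the belief extractor's `client.post` and the LLM agent's `client.stream`, now attach an `Authorization` header only when `settings.llm_settings.api_key` is non-empty; with the new empty-string default the conditional expression falls back to `{}` and requests go out unauthenticated, as before. A minimal sketch of the same idiom against an OpenAI-style chat completions endpoint; the URL, model name, and key below are placeholders, not project configuration:

import asyncio
import httpx

LLM_URL = "http://localhost:1234/v1/chat/completions"  # placeholder endpoint
MODEL = "gpt-oss"                                       # placeholder model name
API_KEY = ""                                            # empty -> no auth header


async def ask(prompt: str) -> str:
    # Same conditional-header idiom as in the diff: only add the bearer
    # token when a key is actually configured.
    headers = {"Authorization": f"Bearer {API_KEY}"} if API_KEY else {}
    async with httpx.AsyncClient(timeout=60) as client:
        response = await client.post(
            LLM_URL,
            headers=headers,
            json={
                "model": MODEL,
                "messages": [{"role": "user", "content": prompt}],
            },
        )
        response.raise_for_status()
        return response.json()["choices"][0]["message"]["content"]


if __name__ == "__main__":
    print(asyncio.run(ask("Hello!")))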