feat: extract semantic beliefs from conversation

ref: N25B-380
Twirre Meulenbelt
2025-12-23 17:09:58 +01:00
parent adbb7ffd5c
commit 33501093a1
5 changed files with 508 additions and 63 deletions


@@ -1,8 +1,23 @@
import asyncio
import json
import httpx
from pydantic import ValidationError
from slugify import slugify
from control_backend.agents.base import BaseAgent
from control_backend.core.agent_system import InternalMessage
from control_backend.core.config import settings
from control_backend.schemas.belief_message import Belief as InternalBelief
from control_backend.schemas.belief_message import BeliefMessage
from control_backend.schemas.chat_history import ChatHistory, ChatMessage
from control_backend.schemas.program import (
Belief,
ConditionalNorm,
InferredBelief,
Program,
SemanticBelief,
)
class TextBeliefExtractorAgent(BaseAgent):
@@ -12,46 +27,110 @@ class TextBeliefExtractorAgent(BaseAgent):
This agent is responsible for processing raw text (e.g., from speech transcription) and
extracting semantic beliefs from it.
It uses the available beliefs received from the program manager to try to extract beliefs from a
user's message, sends any updated beliefs to the BDI core, and forms a ``user_said`` belief from
the message itself.
"""
def __init__(self, name: str):
super().__init__(name)
self.beliefs = {}
self.available_beliefs = []
self.conversation = ChatHistory(messages=[])
async def setup(self):
"""
Initialize the agent and its resources.
"""
self.logger.info("Settting up %s.", self.name)
# Setup LLM belief context if needed (currently demo is just passthrough)
self.beliefs = {"mood": ["X"], "car": ["Y"]}
self.logger.info("Setting up %s.", self.name)
async def handle_message(self, msg: InternalMessage):
"""
Handle incoming messages. Expect messages from the Transcriber agent, LLM agent, and the
Program manager agent.
:param msg: The received message.
"""
sender = msg.sender
match sender:
case settings.agent_settings.transcription_name:
self.logger.debug("Received text from transcriber: %s", msg.body)
self._apply_conversation_message(ChatMessage(role="user", content=msg.body))
await self._infer_new_beliefs()
await self._user_said(msg.body)
case settings.agent_settings.llm_name:
self.logger.debug("Received text from LLM: %s", msg.body)
self._apply_conversation_message(ChatMessage(role="assistant", content=msg.body))
case settings.agent_settings.bdi_program_manager_name:
self._handle_program_manager_message(msg)
case _:
self.logger.info("Discarding message from %s", sender)
return
def _apply_conversation_message(self, message: ChatMessage):
"""
Save the chat message to our conversation history, taking into account the conversation
length limit.
:param message: The chat message to add to the conversation history.
"""
length_limit = settings.behaviour_settings.conversation_history_length_limit
self.conversation.messages = (self.conversation.messages + [message])[-length_limit:]
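# A minimal sketch of the truncation above, assuming a
# conversation_history_length_limit of 2: appending a third message keeps only
# the two most recent messages.
#
#   history = [m1, m2]
#   (history + [m3])[-2:]  # -> [m2, m3]; m1 is dropped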
def _handle_program_manager_message(self, msg: InternalMessage):
"""
Handle a message from the program manager: extract available beliefs from it.
:param msg: The received message from the program manager.
"""
try:
program = Program.model_validate_json(msg.body)
except ValidationError:
self.logger.warning(
"Received message from program manager but it is not a valid program."
)
return
self.logger.debug("Received a program from the program manager.")
self.available_beliefs = self._extract_basic_beliefs_from_program(program)
# TODO Copied from an incomplete version of the program manager. Use that one instead.
@staticmethod
def _extract_basic_beliefs_from_program(program: Program) -> list[SemanticBelief]:
beliefs = []
for phase in program.phases:
for norm in phase.norms:
if isinstance(norm, ConditionalNorm):
beliefs += TextBeliefExtractorAgent._extract_basic_beliefs_from_belief(
norm.condition
)
for trigger in phase.triggers:
beliefs += TextBeliefExtractorAgent._extract_basic_beliefs_from_belief(
trigger.condition
)
return beliefs
# TODO Copied from an incomplete version of the program manager. Use that one instead.
@staticmethod
def _extract_basic_beliefs_from_belief(belief: Belief) -> list[SemanticBelief]:
if isinstance(belief, InferredBelief):
return TextBeliefExtractorAgent._extract_basic_beliefs_from_belief(
belief.left
) + TextBeliefExtractorAgent._extract_basic_beliefs_from_belief(belief.right)
return [belief]
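# Illustrative example of the recursion above: an InferredBelief combines a
# left and right sub-belief, so a nested belief tree flattens into its leaf
# beliefs (the names a, b and c are hypothetical):
#
#   InferredBelief(left=InferredBelief(left=a, right=b), right=c)
#   -> [a, b, c]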
async def _user_said(self, text: str):
"""
Create a belief for the user's full speech.
:param text: User's transcribed text.
"""
belief = {"beliefs": {"user_said": [text]}, "type": "belief_extraction_text"}
payload = json.dumps(belief)
belief_msg = InternalMessage(
@@ -60,6 +139,200 @@ class TextBeliefExtractorAgent(BaseAgent):
body=payload,
thread="beliefs",
)
await self.send(belief_msg)
self.logger.info("Sent %d beliefs to the belief collector.", len(belief["beliefs"]))
async def _infer_new_beliefs(self):
"""
Process conversation history to extract beliefs, semantically. Any changed beliefs are sent
to the BDI core.
"""
# Return instantly if there are no beliefs to infer
if not self.available_beliefs:
return
candidate_beliefs = await self._infer_turn()
new_beliefs: list[InternalBelief] = []
for belief_key, belief_value in candidate_beliefs.items():
if belief_value is None:
continue
old_belief_value = self.beliefs.get(belief_key)
# TODO: Do we need this check? Can we send the same beliefs multiple times?
if belief_value == old_belief_value:
continue
self.beliefs[belief_key] = belief_value
new_beliefs.append(
InternalBelief(name=belief_key, arguments=[belief_value], replace=True),
)
beliefs_message = InternalMessage(
to=settings.agent_settings.bdi_core_name,
sender=self.name,
body=BeliefMessage(beliefs=new_beliefs).model_dump_json(),
thread="beliefs",
)
await self.send(beliefs_message)
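# Sketch of the message body sent above. The exact JSON follows the pydantic
# BeliefMessage/Belief models; the belief name here is hypothetical:
#
#   {"beliefs": [{"name": "user-owns-a-car", "arguments": [true], "replace": true}]}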
@staticmethod
def _split_into_chunks[T](items: list[T], n: int) -> list[list[T]]:
"""Split ``items`` into ``n`` contiguous chunks whose sizes differ by at most one."""
k, m = divmod(len(items), n)
return [items[i * k + min(i, m) : (i + 1) * k + min(i + 1, m)] for i in range(n)]
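# Example of the split above: 7 items over 3 chunks spreads the remainder over
# the first chunks, so chunk sizes differ by at most one.
#
#   _split_into_chunks([1, 2, 3, 4, 5, 6, 7], 3)
#   -> [[1, 2, 3], [4, 5], [6, 7]]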
async def _infer_turn(self) -> dict:
"""
Process the stored conversation history to extract semantic beliefs. Returns a dict
mapping belief names to ``True``, ``False`` or ``None``.
:return: A dict mapping belief names to a value ``True``, ``False`` or ``None``.
"""
# Keep one LLM slot free for other traffic, but always use at least one chunk
# (avoids a zero divisor in _split_into_chunks when n_parallel <= 1).
n_parallel = max(1, min(settings.llm_settings.n_parallel - 1, len(self.available_beliefs)))
all_beliefs = await asyncio.gather(
*[
self._infer_beliefs(self.conversation, beliefs)
for beliefs in self._split_into_chunks(self.available_beliefs, n_parallel)
]
)
retval = {}
for beliefs in all_beliefs:
if beliefs is None:
continue
retval.update(beliefs)
return retval
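# Illustrative merged result of the parallel calls above, with hypothetical
# belief names. A None value means the LLM found no verdict for that belief;
# such entries are skipped later by _infer_new_beliefs.
#
#   {"user-is-happy": True, "user-owns-a-car": False, "user-likes-music": None}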
@staticmethod
def _create_belief_schema(belief: SemanticBelief) -> tuple[str, dict]:
# TODO: use real belief names
return belief.name or slugify(belief.description), {
"type": ["boolean", "null"],
"description": belief.description,
}
@staticmethod
def _create_beliefs_schema(beliefs: list[SemanticBelief]) -> dict:
belief_schemas = [
TextBeliefExtractorAgent._create_belief_schema(belief) for belief in beliefs
]
return {
"type": "object",
"properties": dict(belief_schemas),
"required": [name for name, _ in belief_schemas],
}
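# Sketch of the schema produced by the two helpers above for one unnamed belief
# with the (hypothetical) description "The user is happy":
#
#   {
#       "type": "object",
#       "properties": {
#           "the-user-is-happy": {
#               "type": ["boolean", "null"],
#               "description": "The user is happy"
#           }
#       },
#       "required": ["the-user-is-happy"]
#   }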
@staticmethod
def _format_message(message: ChatMessage):
return f"{message.role.upper()}:\n{message.content}"
@staticmethod
def _format_conversation(conversation: ChatHistory):
return "\n\n".join(
[TextBeliefExtractorAgent._format_message(message) for message in conversation.messages]
)
@staticmethod
def _format_beliefs(beliefs: list[SemanticBelief]):
# TODO: use real belief names
return "\n".join(
[
f"- {belief.name or slugify(belief.description)}: {belief.description}"
for belief in beliefs
]
)
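# Example output of the formatters above, for a made-up two-message exchange
# and one hypothetical belief:
#
#   USER:
#   I just bought a new car!
#
#   ASSISTANT:
#   Congratulations!
#
#   - user-owns-a-car: The user owns a car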
async def _infer_beliefs(
self,
conversation: ChatHistory,
beliefs: list[SemanticBelief],
) -> dict | None:
"""
Infer given beliefs based on the given conversation.
:param conversation: The conversation to infer beliefs from.
:param beliefs: The beliefs to infer.
:return: A dict mapping belief names to a boolean (or ``None`` when a belief cannot be
inferred from the conversation), or None if the LLM query failed.
"""
example = {
"example_belief": True,
}
prompt = f"""{self._format_conversation(conversation)}
Given the above conversation, what beliefs can be inferred?
If there is no relevant information about a belief, give null.
In case messages conflict, prefer using the most recent messages for inference.
Choose from the following list of beliefs, formatted as (belief_name, description):
{self._format_beliefs(beliefs)}
Respond with a JSON similar to the following, but with the property names as given above:
{json.dumps(example, indent=2)}
"""
schema = self._create_beliefs_schema(beliefs)
return await self._retry_query_llm(prompt, schema)
async def _retry_query_llm(self, prompt: str, schema: dict, tries: int = 3) -> dict | None:
"""
Query the LLM with the given prompt and schema, returning a dict that conforms to the
schema. Try up to ``tries`` times, then give up and return None.
:param prompt: Prompt to send to the LLM.
:param schema: JSON schema the response must conform to.
:param tries: Maximum number of attempts before giving up.
:return: A dict conforming to the schema, or None if all attempts failed.
"""
try_count = 0
while try_count < tries:
try_count += 1
try:
return await self._query_llm(prompt, schema)
except (httpx.HTTPStatusError, json.JSONDecodeError, KeyError) as e:
if try_count < tries:
continue
self.logger.exception(
"Failed to get LLM response after %d tries.",
try_count,
exc_info=e,
)
return None
@staticmethod
async def _query_llm(prompt: str, schema: dict) -> dict:
"""
Query an LLM with the given prompt and schema, return an instance of a dict conforming to
that schema.
:param prompt: The prompt to be queried.
:param schema: JSON schema the response must conform to.
:return: A dict conforming to this schema.
:raises httpx.HTTPStatusError: If the LLM server responded with an error.
:raises json.JSONDecodeError: If the LLM response was not valid JSON. May happen if the
response was cut off early due to length limitations.
:raises KeyError: If the LLM server responded with no error, but the response was invalid.
"""
async with httpx.AsyncClient() as client:
response = await client.post(
settings.llm_settings.local_llm_url,
json={
"model": settings.llm_settings.local_llm_model,
"messages": [{"role": "user", "content": prompt}],
"response_format": {
"type": "json_schema",
"json_schema": {
"name": "Beliefs",
"strict": True,
"schema": schema,
},
},
"reasoning_effort": "low",
"temperature": settings.llm_settings.code_temperature,
"stream": False,
},
timeout=None,
)
response.raise_for_status()
response_json = response.json()
json_message = response_json["choices"][0]["message"]["content"]
return json.loads(json_message)
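# The parsing above assumes an OpenAI-compatible chat completions response; a
# minimal sketch of the expected shape (values are illustrative):
#
#   {"choices": [{"message": {"content": "{\"user-owns-a-car\": true}"}}]}
#
# The "content" field is itself a JSON string constrained by the json_schema
# response format, hence the final json.loads.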