diff --git a/src/control_backend/agents/bdi/text_belief_extractor_agent.py b/src/control_backend/agents/bdi/text_belief_extractor_agent.py index 0f2db01..0324573 100644 --- a/src/control_backend/agents/bdi/text_belief_extractor_agent.py +++ b/src/control_backend/agents/bdi/text_belief_extractor_agent.py @@ -1,8 +1,23 @@ +import asyncio import json +import httpx +from pydantic import ValidationError +from slugify import slugify + from control_backend.agents.base import BaseAgent from control_backend.core.agent_system import InternalMessage from control_backend.core.config import settings +from control_backend.schemas.belief_message import Belief as InternalBelief +from control_backend.schemas.belief_message import BeliefMessage +from control_backend.schemas.chat_history import ChatHistory, ChatMessage +from control_backend.schemas.program import ( + Belief, + ConditionalNorm, + InferredBelief, + Program, + SemanticBelief, +) class TextBeliefExtractorAgent(BaseAgent): @@ -12,46 +27,110 @@ class TextBeliefExtractorAgent(BaseAgent): This agent is responsible for processing raw text (e.g., from speech transcription) and extracting semantic beliefs from it. - In the current demonstration version, it performs a simple wrapping of the user's input - into a ``user_said`` belief. In a full implementation, this agent would likely interact - with an LLM or NLU engine to extract intent, entities, and other structured information. + It uses the available beliefs received from the program manager to try to extract beliefs from a + user's message, sends any updated beliefs to the BDI core, and forms a ``user_said`` belief from + the message itself. """ + def __init__(self, name: str): + super().__init__(name) + self.beliefs = {} + self.available_beliefs = [] + self.conversation = ChatHistory(messages=[]) + async def setup(self): """ Initialize the agent and its resources. 
""" - self.logger.info("Settting up %s.", self.name) - # Setup LLM belief context if needed (currently demo is just passthrough) - self.beliefs = {"mood": ["X"], "car": ["Y"]} + self.logger.info("Setting up %s.", self.name) async def handle_message(self, msg: InternalMessage): """ - Handle incoming messages, primarily from the Transcription Agent. + Handle incoming messages. Expect messages from the Transcriber agent, LLM agent, and the + Program manager agent. - :param msg: The received message containing transcribed text. + :param msg: The received message. """ sender = msg.sender - if sender == settings.agent_settings.transcription_name: - self.logger.debug("Received text from transcriber: %s", msg.body) - await self._process_transcription_demo(msg.body) - else: - self.logger.info("Discarding message from %s", sender) - async def _process_transcription_demo(self, txt: str): + match sender: + case settings.agent_settings.transcription_name: + self.logger.debug("Received text from transcriber: %s", msg.body) + self._apply_conversation_message(ChatMessage(role="user", content=msg.body)) + await self._infer_new_beliefs() + await self._user_said(msg.body) + case settings.agent_settings.llm_name: + self.logger.debug("Received text from LLM: %s", msg.body) + self._apply_conversation_message(ChatMessage(role="assistant", content=msg.body)) + case settings.agent_settings.bdi_program_manager_name: + self._handle_program_manager_message(msg) + case _: + self.logger.info("Discarding message from %s", sender) + return + + def _apply_conversation_message(self, message: ChatMessage): """ - Process the transcribed text and generate beliefs. + Save the chat message to our conversation history, taking into account the conversation + length limit. - **Demo Implementation:** - Currently, this method takes the raw text ``txt`` and wraps it into a belief structure: - ``user_said("txt")``. - - This belief is then sent to the :class:`BDIBeliefCollectorAgent`. 
- - :param txt: The raw transcribed text string. + :param message: The chat message to add to the conversation history. """ - # For demo, just wrapping user text as user_said belief - belief = {"beliefs": {"user_said": [txt]}, "type": "belief_extraction_text"} + length_limit = settings.behaviour_settings.conversation_history_length_limit + self.conversation.messages = (self.conversation.messages + [message])[-length_limit:] + + def _handle_program_manager_message(self, msg: InternalMessage): + """ + Handle a message from the program manager: extract available beliefs from it. + + :param msg: The received message from the program manager. + """ + try: + program = Program.model_validate_json(msg.body) + except ValidationError: + self.logger.warning( + "Received message from program manager but it is not a valid program." + ) + return + + self.logger.debug("Received a program from the program manager.") + + self.available_beliefs = self._extract_basic_beliefs_from_program(program) + + # TODO Copied from an incomplete version of the program manager. Use that one instead. + @staticmethod + def _extract_basic_beliefs_from_program(program: Program) -> list[SemanticBelief]: + beliefs = [] + + for phase in program.phases: + for norm in phase.norms: + if isinstance(norm, ConditionalNorm): + beliefs += TextBeliefExtractorAgent._extract_basic_beliefs_from_belief( + norm.condition + ) + + for trigger in phase.triggers: + beliefs += TextBeliefExtractorAgent._extract_basic_beliefs_from_belief( + trigger.condition + ) + + return beliefs + + # TODO Copied from an incomplete version of the program manager. Use that one instead. 
+ @staticmethod + def _extract_basic_beliefs_from_belief(belief: Belief) -> list[SemanticBelief]: + if isinstance(belief, InferredBelief): + return TextBeliefExtractorAgent._extract_basic_beliefs_from_belief( + belief.left + ) + TextBeliefExtractorAgent._extract_basic_beliefs_from_belief(belief.right) + return [belief] + + async def _user_said(self, text: str): + """ + Create a belief for the user's full speech. + + :param text: User's transcribed text. + """ + belief = {"beliefs": {"user_said": [text]}, "type": "belief_extraction_text"} payload = json.dumps(belief) belief_msg = InternalMessage( @@ -60,6 +139,200 @@ class TextBeliefExtractorAgent(BaseAgent): body=payload, thread="beliefs", ) - await self.send(belief_msg) - self.logger.info("Sent %d beliefs to the belief collector.", len(belief["beliefs"])) + + async def _infer_new_beliefs(self): + """ + Process conversation history to extract beliefs, semantically. Any changed beliefs are sent + to the BDI core. + """ + # Return instantly if there are no beliefs to infer + if not self.available_beliefs: + return + + candidate_beliefs = await self._infer_turn() + new_beliefs: list[InternalBelief] = [] + for belief_key, belief_value in candidate_beliefs.items(): + if belief_value is None: + continue + old_belief_value = self.beliefs.get(belief_key) + # TODO: Do we need this check? Can we send the same beliefs multiple times? 
+ if belief_value == old_belief_value: + continue + self.beliefs[belief_key] = belief_value + new_beliefs.append( + InternalBelief(name=belief_key, arguments=[belief_value], replace=True), + ) + + beliefs_message = InternalMessage( + to=settings.agent_settings.bdi_core_name, + sender=self.name, + body=BeliefMessage(beliefs=new_beliefs).model_dump_json(), + thread="beliefs", + ) + await self.send(beliefs_message) + + @staticmethod + def _split_into_chunks[T](items: list[T], n: int) -> list[list[T]]: + k, m = divmod(len(items), n) + return [items[i * k + min(i, m) : (i + 1) * k + min(i + 1, m)] for i in range(n)] + + async def _infer_turn(self) -> dict: + """ + Process the stored conversation history to extract semantic beliefs. Returns a dict of + beliefs that have been set to ``True``, ``False`` or ``None``. + + :return: A dict mapping belief names to a value ``True``, ``False`` or ``None``. + """ + n_parallel = max(1, min(settings.llm_settings.n_parallel - 1, len(self.available_beliefs))) + all_beliefs = await asyncio.gather( + *[ + self._infer_beliefs(self.conversation, beliefs) + for beliefs in self._split_into_chunks(self.available_beliefs, n_parallel) + ] + ) + retval = {} + for beliefs in all_beliefs: + if beliefs is None: + continue + retval.update(beliefs) + return retval + + @staticmethod + def _create_belief_schema(belief: SemanticBelief) -> tuple[str, dict]: + # TODO: use real belief names + return belief.name or slugify(belief.description), { + "type": ["boolean", "null"], + "description": belief.description, + } + + @staticmethod + def _create_beliefs_schema(beliefs: list[SemanticBelief]) -> dict: + belief_schemas = [ + TextBeliefExtractorAgent._create_belief_schema(belief) for belief in beliefs + ] + + return { + "type": "object", + "properties": dict(belief_schemas), + "required": [name for name, _ in belief_schemas], + } + + @staticmethod + def _format_message(message: ChatMessage): + return f"{message.role.upper()}:\n{message.content}" + + @staticmethod + 
def _format_conversation(conversation: ChatHistory): + return "\n\n".join( + [TextBeliefExtractorAgent._format_message(message) for message in conversation.messages] + ) + + @staticmethod + def _format_beliefs(beliefs: list[SemanticBelief]): + # TODO: use real belief names + return "\n".join( + [ + f"- {belief.name or slugify(belief.description)}: {belief.description}" + for belief in beliefs + ] + ) + + async def _infer_beliefs( + self, + conversation: ChatHistory, + beliefs: list[SemanticBelief], + ) -> dict | None: + """ + Infer given beliefs based on the given conversation. + :param conversation: The conversation to infer beliefs from. + :param beliefs: The beliefs to infer. + :return: A dict containing belief names and a boolean whether they hold, or None if the + belief cannot be inferred based on the given conversation. + """ + example = { + "example_belief": True, + } + + prompt = f"""{self._format_conversation(conversation)} + +Given the above conversation, what beliefs can be inferred? +If there is no relevant information about a belief belief, give null. +In case messages conflict, prefer using the most recent messages for inference. + +Choose from the following list of beliefs, formatted as (belief_name, description): +{self._format_beliefs(beliefs)} + +Respond with a JSON similar to the following, but with the property names as given above: +{json.dumps(example, indent=2)} +""" + + schema = self._create_beliefs_schema(beliefs) + + return await self._retry_query_llm(prompt, schema) + + async def _retry_query_llm(self, prompt: str, schema: dict, tries: int = 3) -> dict | None: + """ + Query the LLM with the given prompt and schema, return an instance of a dict conforming + to this schema. Try ``tries`` times, or return None. + + :param prompt: Prompt to be queried. + :param schema: Schema to be queried. + :return: An instance of a dict conforming to this schema, or None if failed. 
+ """ + try_count = 0 + while try_count < tries: + try_count += 1 + + try: + return await self._query_llm(prompt, schema) + except (httpx.HTTPError, json.JSONDecodeError, KeyError, IndexError) as e: + if try_count < tries: + continue + self.logger.exception( + "Failed to get LLM response after %d tries.", + try_count, + exc_info=e, + ) + + return None + + @staticmethod + async def _query_llm(prompt: str, schema: dict) -> dict: + """ + Query an LLM with the given prompt and schema, return an instance of a dict conforming to + that schema. + + :param prompt: The prompt to be queried. + :param schema: Schema to use during response. + :return: A dict conforming to this schema. + :raises httpx.HTTPStatusError: If the LLM server responded with an error. + :raises json.JSONDecodeError: If the LLM response was not valid JSON. May happen if the + response was cut off early due to length limitations. + :raises KeyError: If the LLM server responded with no error, but the response was invalid. + """ + async with httpx.AsyncClient() as client: + response = await client.post( + settings.llm_settings.local_llm_url, + json={ + "model": settings.llm_settings.local_llm_model, + "messages": [{"role": "user", "content": prompt}], + "response_format": { + "type": "json_schema", + "json_schema": { + "name": "Beliefs", + "strict": True, + "schema": schema, + }, + }, + "reasoning_effort": "low", + "temperature": settings.llm_settings.code_temperature, + "stream": False, + }, + timeout=None, + ) + response.raise_for_status() + + response_json = response.json() + json_message = response_json["choices"][0]["message"]["content"] + return json.loads(json_message) diff --git a/src/control_backend/agents/llm/llm_agent.py b/src/control_backend/agents/llm/llm_agent.py index 55099e2..17edec9 100644 --- a/src/control_backend/agents/llm/llm_agent.py +++ b/src/control_backend/agents/llm/llm_agent.py @@ -64,11 +64,12 @@ class LLMAgent(BaseAgent): :param message: The parsed prompt message containing text, 
norms, and goals. """ + full_message = "" async for chunk in self._query_llm(message.text, message.norms, message.goals): await self._send_reply(chunk) - self.logger.debug( - "Finished processing BDI message. Response sent in chunks to BDI core." - ) + full_message += chunk + self.logger.debug("Finished processing BDI message. Response sent in chunks to BDI core.") + await self._send_full_reply(full_message) async def _send_reply(self, msg: str): """ @@ -83,6 +84,19 @@ class LLMAgent(BaseAgent): ) await self.send(reply) + async def _send_full_reply(self, msg: str): + """ + Sends a response message (full) to agents that need it. + + :param msg: The text content of the message. + """ + message = InternalMessage( + to=settings.agent_settings.text_belief_extractor_name, + sender=self.name, + body=msg, + ) + await self.send(message) + async def _query_llm( self, prompt: str, norms: list[str], goals: list[str] ) -> AsyncGenerator[str]: @@ -172,7 +186,7 @@ class LLMAgent(BaseAgent): json={ "model": settings.llm_settings.local_llm_model, "messages": messages, - "temperature": 0.3, + "temperature": settings.llm_settings.chat_temperature, "stream": True, }, ) as response: diff --git a/src/control_backend/core/config.py b/src/control_backend/core/config.py index 927985b..1a2560a 100644 --- a/src/control_backend/core/config.py +++ b/src/control_backend/core/config.py @@ -65,6 +65,7 @@ class BehaviourSettings(BaseModel): :ivar transcription_words_per_minute: Estimated words per minute for transcription timing. :ivar transcription_words_per_token: Estimated words per token for transcription timing. :ivar transcription_token_buffer: Buffer for transcription tokens. + :ivar conversation_history_length_limit: The maximum number of messages to extract beliefs from. 
""" sleep_s: float = 1.0 @@ -82,6 +83,9 @@ class BehaviourSettings(BaseModel): transcription_words_per_token: float = 0.75 # (3 words = 4 tokens) transcription_token_buffer: int = 10 + # Text belief extractor settings + conversation_history_length_limit: int = 10 + class LLMSettings(BaseModel): """ @@ -89,10 +93,17 @@ class LLMSettings(BaseModel): :ivar local_llm_url: URL for the local LLM API. :ivar local_llm_model: Name of the local LLM model to use. + :ivar chat_temperature: The temperature to use while generating chat responses. + :ivar code_temperature: The temperature to use while generating code-like responses like during + belief inference. + :ivar n_parallel: The number of parallel calls allowed to be made to the LLM. """ local_llm_url: str = "http://localhost:1234/v1/chat/completions" local_llm_model: str = "gpt-oss" + chat_temperature: float = 1.0 + code_temperature: float = 0.3 + n_parallel: int = 4 class VADSettings(BaseModel): diff --git a/src/control_backend/schemas/chat_history.py b/src/control_backend/schemas/chat_history.py new file mode 100644 index 0000000..52fc224 --- /dev/null +++ b/src/control_backend/schemas/chat_history.py @@ -0,0 +1,10 @@ +from pydantic import BaseModel + + +class ChatMessage(BaseModel): + role: str + content: str + + +class ChatHistory(BaseModel): + messages: list[ChatMessage] diff --git a/src/control_backend/schemas/program.py b/src/control_backend/schemas/program.py index 28969b9..529a23d 100644 --- a/src/control_backend/schemas/program.py +++ b/src/control_backend/schemas/program.py @@ -1,64 +1,201 @@ -from pydantic import BaseModel +from enum import Enum +from typing import Literal + +from pydantic import UUID4, BaseModel -class Norm(BaseModel): +class ProgramElement(BaseModel): """ - Represents a behavioral norm. + Represents a basic element of our behavior program. + :ivar name: The researcher-assigned name of the element. :ivar id: Unique identifier. - :ivar label: Human-readable label. 
- :ivar norm: The actual norm text describing the behavior. """ - id: str - label: str - norm: str + name: str + id: UUID4 -class Goal(BaseModel): +class LogicalOperator(Enum): + AND = "AND" + OR = "OR" + + +type Belief = KeywordBelief | SemanticBelief | InferredBelief +type BasicBelief = KeywordBelief | SemanticBelief + + +class KeywordBelief(ProgramElement): """ - Represents an objective to be achieved. + Represents a belief that is set when the user spoken text contains a certain keyword. - :ivar id: Unique identifier. - :ivar label: Human-readable label. - :ivar description: Detailed description of the goal. - :ivar achieved: Status flag indicating if the goal has been met. + :ivar keyword: The keyword on which this belief gets set. """ - id: str - label: str - description: str - achieved: bool - - -class TriggerKeyword(BaseModel): - id: str + name: str = "" keyword: str -class KeywordTrigger(BaseModel): - id: str - label: str - type: str - keywords: list[TriggerKeyword] +class SemanticBelief(ProgramElement): + """ + Represents a belief that is set by semantic LLM validation. + + :ivar description: Description of how to form the belief, used by the LLM. + """ + + name: str = "" + description: str -class Phase(BaseModel): +class InferredBelief(ProgramElement): + """ + Represents a belief that gets formed by combining two beliefs with a logical AND or OR. + + These beliefs can also be :class:`InferredBelief`, leading to arbitrarily deep nesting. + + :ivar operator: The logical operator to apply. + :ivar left: The left part of the logical expression. + :ivar right: The right part of the logical expression. + """ + + name: str = "" + operator: LogicalOperator + left: Belief + right: Belief + + +type Norm = BasicNorm | ConditionalNorm + + +class BasicNorm(ProgramElement): + """ + Represents a behavioral norm. + + :ivar norm: The actual norm text describing the behavior. + :ivar critical: When true, this norm should absolutely not be violated (checked separately). 
+ """ + + name: str = "" + norm: str + critical: bool = False + + +class ConditionalNorm(BasicNorm): + """ + Represents a norm that is only active when a condition is met (i.e., a certain belief holds). + + :ivar condition: When to activate this norm. + """ + + condition: Belief + + +type PlanElement = Goal | Action + + +class Plan(ProgramElement): + """ + Represents a list of steps to execute. Each of these steps can be a goal (with its own plan) + or a simple action. + + :ivar steps: The actions or subgoals to execute, in order. + """ + + name: str = "" + steps: list[PlanElement] + + +class Goal(ProgramElement): + """ + Represents an objective to be achieved. To reach the goal, we should execute + the corresponding plan. If we can fail to achieve a goal after executing the plan, + for example when the achieving of the goal is dependent on the user's reply, this means + that the achieved status will be set from somewhere else in the program. + + :ivar plan: The plan to execute. + :ivar can_fail: Whether we can fail to achieve the goal after executing the plan. + """ + + plan: Plan + can_fail: bool = True + + +type Action = SpeechAction | GestureAction | LLMAction + + +class SpeechAction(ProgramElement): + """ + Represents the action of the robot speaking a literal text. + + :ivar text: The text to speak. + """ + + name: str = "" + text: str + + +class Gesture(BaseModel): + """ + Represents a gesture to be performed. Can be either a single gesture, + or a random gesture from a category (tag). + + :ivar type: The type of the gesture, "tag" or "single". + :ivar name: The name of the single gesture or tag. + """ + + type: Literal["tag", "single"] + name: str + + +class GestureAction(ProgramElement): + """ + Represents the action of the robot performing a gesture. + + :ivar gesture: The gesture to perform. 
+ """ + + name: str = "" + gesture: Gesture + + +class LLMAction(ProgramElement): + """ + Represents the action of letting an LLM generate a reply based on its chat history + and an additional goal added in the prompt. + + :ivar goal: The extra (temporary) goal to add to the LLM. + """ + + name: str = "" + goal: str + + +class Trigger(ProgramElement): + """ + Represents a belief-based trigger. When a belief is set, the corresponding plan is executed. + + :ivar condition: When to activate the trigger. + :ivar plan: The plan to execute. + """ + + name: str = "" + condition: Belief + plan: Plan + + +class Phase(ProgramElement): """ A distinct phase within a program, containing norms, goals, and triggers. - :ivar id: Unique identifier. - :ivar label: Human-readable label. :ivar norms: List of norms active in this phase. :ivar goals: List of goals to pursue in this phase. :ivar triggers: List of triggers that define transitions out of this phase. """ - id: str - label: str + name: str = "" norms: list[Norm] goals: list[Goal] - triggers: list[KeywordTrigger] + triggers: list[Trigger] class Program(BaseModel):