From aa5b386f658415e52f22b4f0390e67974f9770d1 Mon Sep 17 00:00:00 2001 From: Twirre Meulenbelt <43213592+TwirreM@users.noreply.github.com> Date: Wed, 7 Jan 2026 17:08:23 +0100 Subject: [PATCH] feat: semantically determine goal completion ref: N25B-432 --- .../agents/bdi/bdi_core_agent.py | 13 +- .../agents/bdi/bdi_program_manager.py | 64 ++- .../agents/bdi/belief_collector_agent.py | 2 +- .../agents/bdi/text_belief_extractor_agent.py | 443 ++++++++++++------ src/control_backend/core/agent_system.py | 11 +- src/control_backend/schemas/belief_list.py | 5 + src/control_backend/schemas/belief_message.py | 3 + src/control_backend/schemas/program.py | 2 +- 8 files changed, 380 insertions(+), 163 deletions(-) diff --git a/src/control_backend/agents/bdi/bdi_core_agent.py b/src/control_backend/agents/bdi/bdi_core_agent.py index 58ece29..3baa493 100644 --- a/src/control_backend/agents/bdi/bdi_core_agent.py +++ b/src/control_backend/agents/bdi/bdi_core_agent.py @@ -205,12 +205,15 @@ class BDICoreAgent(BaseAgent): self.logger.debug(f"Added belief {self.format_belief_string(name, args)}") - def _remove_belief(self, name: str, args: Iterable[str]): + def _remove_belief(self, name: str, args: Iterable[str] | None): """ Removes a specific belief (with arguments), if it exists. """ - new_args = (agentspeak.Literal(arg) for arg in args) - term = agentspeak.Literal(name, new_args) + if args is None: + term = agentspeak.Literal(name) + else: + new_args = (agentspeak.Literal(arg) for arg in args) + term = agentspeak.Literal(name, new_args) result = self.bdi_agent.call( agentspeak.Trigger.removal, @@ -346,8 +349,8 @@ class BDICoreAgent(BaseAgent): self.logger.info("Message sent to LLM agent: %s", text) @staticmethod - def format_belief_string(name: str, args: Iterable[str] = []): + def format_belief_string(name: str, args: Iterable[str] | None = []): """ Given a belief's name and its args, return a string of the form "name(*args)" """ - return f"{name}{'(' if args else ''}{','.join(args)}{')' if args else ''}" + return f"{name}{'(' if args else ''}{','.join(args or [])}{')' if args else ''}" diff --git a/src/control_backend/agents/bdi/bdi_program_manager.py b/src/control_backend/agents/bdi/bdi_program_manager.py index 54e7196..96d924d 100644 --- a/src/control_backend/agents/bdi/bdi_program_manager.py +++ b/src/control_backend/agents/bdi/bdi_program_manager.py @@ -7,9 +7,9 @@ from zmq.asyncio import Context from control_backend.agents import BaseAgent from control_backend.agents.bdi.agentspeak_generator import AgentSpeakGenerator from control_backend.core.config import settings -from control_backend.schemas.belief_list import BeliefList +from control_backend.schemas.belief_list import BeliefList, GoalList from control_backend.schemas.internal_message import InternalMessage -from control_backend.schemas.program import Belief, ConditionalNorm, InferredBelief, Program +from control_backend.schemas.program import Belief, ConditionalNorm, Goal, InferredBelief, Program class BDIProgramManager(BaseAgent): @@ -63,24 +63,23 @@ class BDIProgramManager(BaseAgent): def _extract_beliefs_from_program(program: Program) -> list[Belief]: beliefs: list[Belief] = [] + def extract_beliefs_from_belief(belief: Belief) -> list[Belief]: + if isinstance(belief, InferredBelief): + return extract_beliefs_from_belief(belief.left) + extract_beliefs_from_belief( + belief.right + ) + return [belief] + for phase in program.phases: for norm in phase.norms: if isinstance(norm, ConditionalNorm): - beliefs += BDIProgramManager._extract_beliefs_from_belief(norm.condition) + beliefs += extract_beliefs_from_belief(norm.condition) for trigger in phase.triggers: - beliefs += BDIProgramManager._extract_beliefs_from_belief(trigger.condition) + beliefs += extract_beliefs_from_belief(trigger.condition) return beliefs - @staticmethod - def _extract_beliefs_from_belief(belief: Belief) -> list[Belief]: - if isinstance(belief, InferredBelief): - return BDIProgramManager._extract_beliefs_from_belief( - belief.left - ) + BDIProgramManager._extract_beliefs_from_belief(belief.right) - return [belief] - async def _send_beliefs_to_semantic_belief_extractor(self, program: Program): """ Extract beliefs from the program and send them to the Semantic Belief Extractor Agent. @@ -98,6 +97,46 @@ class BDIProgramManager(BaseAgent): await self.send(message) + @staticmethod + def _extract_goals_from_program(program: Program) -> list[Goal]: + """ + Extract all goals from the program, including subgoals. + + :param program: The program received from the API. + :return: A list of Goal objects. + """ + goals: list[Goal] = [] + + def extract_goals_from_goal(goal_: Goal) -> list[Goal]: + goals_: list[Goal] = [goal] + for plan in goal_.plan: + if isinstance(plan, Goal): + goals_.extend(extract_goals_from_goal(plan)) + return goals_ + + for phase in program.phases: + for goal in phase.goals: + goals.extend(extract_goals_from_goal(goal)) + + return goals + + async def _send_goals_to_semantic_belief_extractor(self, program: Program): + """ + Extract goals from the program and send them to the Semantic Belief Extractor Agent. + + :param program: The program received from the API. + """ + goals = GoalList(goals=self._extract_goals_from_program(program)) + + message = InternalMessage( + to=settings.agent_settings.text_belief_extractor_name, + sender=self.name, + body=goals.model_dump_json(), + thread="goals", + ) + + await self.send(message) + async def _receive_programs(self): """ Continuous loop that receives program updates from the HTTP endpoint. @@ -117,6 +156,7 @@ class BDIProgramManager(BaseAgent): await asyncio.gather( self._create_agentspeak_and_send_to_bdi(program), self._send_beliefs_to_semantic_belief_extractor(program), + self._send_goals_to_semantic_belief_extractor(program), ) async def setup(self): diff --git a/src/control_backend/agents/bdi/belief_collector_agent.py b/src/control_backend/agents/bdi/belief_collector_agent.py index 6f89d2a..ac0e2e5 100644 --- a/src/control_backend/agents/bdi/belief_collector_agent.py +++ b/src/control_backend/agents/bdi/belief_collector_agent.py @@ -101,7 +101,7 @@ class BDIBeliefCollectorAgent(BaseAgent): :return: A Belief object if the input is valid or None. """ try: - return Belief(name=name, arguments=arguments, replace=name == "user_said") + return Belief(name=name, arguments=arguments) except ValidationError: return None diff --git a/src/control_backend/agents/bdi/text_belief_extractor_agent.py b/src/control_backend/agents/bdi/text_belief_extractor_agent.py index 800d5e4..7e3570f 100644 --- a/src/control_backend/agents/bdi/text_belief_extractor_agent.py +++ b/src/control_backend/agents/bdi/text_belief_extractor_agent.py @@ -2,17 +2,45 @@ import asyncio import json import httpx -from pydantic import ValidationError +from pydantic import BaseModel, ValidationError from control_backend.agents.base import BaseAgent from control_backend.agents.bdi.agentspeak_generator import AgentSpeakGenerator from control_backend.core.agent_system import InternalMessage from control_backend.core.config import settings -from control_backend.schemas.belief_list import BeliefList +from control_backend.schemas.belief_list import BeliefList, GoalList from control_backend.schemas.belief_message import Belief as InternalBelief from control_backend.schemas.belief_message import BeliefMessage from control_backend.schemas.chat_history import ChatHistory, ChatMessage -from control_backend.schemas.program import SemanticBelief +from control_backend.schemas.program import Goal, SemanticBelief + +type JSONLike = None | bool | int | float | str | list["JSONLike"] | dict[str, "JSONLike"] + + +class BeliefState(BaseModel): + true: set[InternalBelief] = set() + false: set[InternalBelief] = set() + + def difference(self, other: "BeliefState") -> "BeliefState": + return BeliefState( + true=self.true - other.true, + false=self.false - other.false, + ) + + def union(self, other: "BeliefState") -> "BeliefState": + return BeliefState( + true=self.true | other.true, + false=self.false | other.false, + ) + + def __sub__(self, other): + return self.difference(other) + + def __or__(self, other): + return self.union(other) + + def __bool__(self): + return bool(self.true) or bool(self.false) class TextBeliefExtractorAgent(BaseAgent): @@ -27,12 +55,14 @@ class TextBeliefExtractorAgent(BaseAgent): the message itself. """ - def __init__(self, name: str, temperature: float = settings.llm_settings.code_temperature): + def __init__(self, name: str): super().__init__(name) - self.beliefs: dict[str, bool] = {} - self.available_beliefs: list[SemanticBelief] = [] + self._llm = self.LLM(self, settings.llm_settings.n_parallel) + self.belief_inferrer = SemanticBeliefInferrer(self._llm) + self.goal_inferrer = GoalAchievementInferrer(self._llm) + self._current_beliefs = BeliefState() + self._current_goal_completions: dict[str, bool] = {} self.conversation = ChatHistory(messages=[]) - self.temperature = temperature async def setup(self): """ @@ -53,8 +83,9 @@ class TextBeliefExtractorAgent(BaseAgent): case settings.agent_settings.transcription_name: self.logger.debug("Received text from transcriber: %s", msg.body) self._apply_conversation_message(ChatMessage(role="user", content=msg.body)) - await self._infer_new_beliefs() await self._user_said(msg.body) + await self._infer_new_beliefs() + await self._infer_goal_completions() case settings.agent_settings.llm_name: self.logger.debug("Received text from LLM: %s", msg.body) self._apply_conversation_message(ChatMessage(role="assistant", content=msg.body)) @@ -76,10 +107,19 @@ class TextBeliefExtractorAgent(BaseAgent): def _handle_program_manager_message(self, msg: InternalMessage): """ - Handle a message from the program manager: extract available beliefs from it. + Handle a message from the program manager: extract available beliefs and goals from it. :param msg: The received message from the program manager. """ + match msg.thread: + case "beliefs": + self._handle_beliefs_message(msg) + case "goals": + self._handle_goals_message(msg) + case _: + self.logger.warning("Received unexpected message from %s", msg.sender) + + def _handle_beliefs_message(self, msg: InternalMessage): try: belief_list = BeliefList.model_validate_json(msg.body) except ValidationError: @@ -88,10 +128,28 @@ class TextBeliefExtractorAgent(BaseAgent): ) return - self.available_beliefs = [b for b in belief_list.beliefs if isinstance(b, SemanticBelief)] + available_beliefs = [b for b in belief_list.beliefs if isinstance(b, SemanticBelief)] + self.belief_inferrer.available_beliefs = available_beliefs self.logger.debug( - "Received %d beliefs from the program manager.", - len(self.available_beliefs), + "Received %d semantic beliefs from the program manager.", + len(available_beliefs), + ) + + def _handle_goals_message(self, msg: InternalMessage): + try: + goals_list = GoalList.model_validate_json(msg.body) + except ValidationError: + self.logger.warning( + "Received message from program manager but it is not a valid list of goals." + ) + return + + # Use only goals that can fail, as the others are always assumed to be completed + available_goals = [g for g in goals_list.goals if g.can_fail] + self.goal_inferrer.goals = available_goals + self.logger.debug( + "Received %d failable goals from the program manager.", + len(available_goals), ) async def _user_said(self, text: str): @@ -111,109 +169,199 @@ class TextBeliefExtractorAgent(BaseAgent): await self.send(belief_msg) async def _infer_new_beliefs(self): - """ - Process conversation history to extract beliefs, semantically. Any changed beliefs are sent - to the BDI core. - """ - # Return instantly if there are no beliefs to infer - if not self.available_beliefs: + conversation_beliefs = await self.belief_inferrer.infer_from_conversation(self.conversation) + + new_beliefs = conversation_beliefs - self._current_beliefs + if not new_beliefs: return - candidate_beliefs = await self._infer_turn() - belief_changes = BeliefMessage() - for belief_key, belief_value in candidate_beliefs.items(): - if belief_value is None: - continue - old_belief_value = self.beliefs.get(belief_key) - if belief_value == old_belief_value: - continue + self._current_beliefs |= new_beliefs - self.beliefs[belief_key] = belief_value + belief_changes = BeliefMessage( + create=list(new_beliefs.true), + delete=list(new_beliefs.false), + ) - belief = InternalBelief(name=belief_key, arguments=None) - if belief_value: - belief_changes.create.append(belief) - else: - belief_changes.delete.append(belief) - - # Return if there were no changes in beliefs - if not belief_changes.has_values(): - return - - beliefs_message = InternalMessage( + message = InternalMessage( to=settings.agent_settings.bdi_core_name, sender=self.name, body=belief_changes.model_dump_json(), thread="beliefs", ) - await self.send(beliefs_message) + await self.send(message) - @staticmethod - def _split_into_chunks[T](items: list[T], n: int) -> list[list[T]]: - k, m = divmod(len(items), n) - return [items[i * k + min(i, m) : (i + 1) * k + min(i + 1, m)] for i in range(n)] + async def _infer_goal_completions(self): + goal_completions = await self.goal_inferrer.infer_from_conversation(self.conversation) - async def _infer_turn(self) -> dict: + new_achieved = [ + InternalBelief(name=goal, arguments=None) + for goal, achieved in goal_completions.items() + if achieved and self._current_goal_completions.get(goal) != achieved + ] + new_not_achieved = [ + InternalBelief(name=goal, arguments=None) + for goal, achieved in goal_completions.items() + if not achieved and self._current_goal_completions.get(goal) != achieved + ] + for goal, achieved in goal_completions.items(): + self._current_goal_completions[goal] = achieved + + if not new_achieved and not new_not_achieved: + return + + belief_changes = BeliefMessage( + create=new_achieved, + delete=new_not_achieved, + ) + message = InternalMessage( + to=settings.agent_settings.bdi_core_name, + sender=self.name, + body=belief_changes.model_dump_json(), + thread="beliefs", + ) + await self.send(message) + + class LLM: """ - Process the stored conversation history to extract semantic beliefs. Returns a list of - beliefs that have been set to ``True``, ``False`` or ``None``. - - :return: A dict mapping belief names to a value ``True``, ``False`` or ``None``. + Class that handles sending structured generation requests to an LLM. """ + + def __init__(self, agent: "TextBeliefExtractorAgent", n_parallel: int): + self._agent = agent + self._semaphore = asyncio.Semaphore(n_parallel) + + async def query(self, prompt: str, schema: dict, tries: int = 3) -> JSONLike | None: + """ + Query the LLM with the given prompt and schema, return an instance of a dict conforming + to this schema. Try ``tries`` times, or return None. + + :param prompt: Prompt to be queried. + :param schema: Schema to be queried. + :param tries: Number of times to try to query the LLM. + :return: An instance of a dict conforming to this schema, or None if failed. + """ + try_count = 0 + while try_count < tries: + try_count += 1 + + try: + return await self._query_llm(prompt, schema) + except (httpx.HTTPError, json.JSONDecodeError, KeyError) as e: + if try_count < tries: + continue + self._agent.logger.exception( + "Failed to get LLM response after %d tries.", + try_count, + exc_info=e, + ) + + return None + + async def _query_llm(self, prompt: str, schema: dict) -> JSONLike: + """ + Query an LLM with the given prompt and schema, return an instance of a dict conforming + to that schema. + + :param prompt: The prompt to be queried. + :param schema: Schema to use during response. + :return: A dict conforming to this schema. + :raises httpx.HTTPStatusError: If the LLM server responded with an error. + :raises json.JSONDecodeError: If the LLM response was not valid JSON. May happen if the + response was cut off early due to length limitations. + :raises KeyError: If the LLM server responded with no error, but the response was + invalid. + """ + async with self._semaphore: + async with httpx.AsyncClient() as client: + response = await client.post( + settings.llm_settings.local_llm_url, + json={ + "model": settings.llm_settings.local_llm_model, + "messages": [{"role": "user", "content": prompt}], + "response_format": { + "type": "json_schema", + "json_schema": { + "name": "Beliefs", + "strict": True, + "schema": schema, + }, + }, + "reasoning_effort": "low", + "temperature": settings.llm_settings.code_temperature, + "stream": False, + }, + timeout=30.0, + ) + response.raise_for_status() + + response_json = response.json() + json_message = response_json["choices"][0]["message"]["content"] + return json.loads(json_message) + + +class SemanticBeliefInferrer: + """ + Class that handles only prompting an LLM for semantic beliefs. + """ + + def __init__( + self, + llm: "TextBeliefExtractorAgent.LLM", + available_beliefs: list[SemanticBelief] | None = None, + ): + self._llm = llm + self.available_beliefs: list[SemanticBelief] = available_beliefs or [] + + async def infer_from_conversation(self, conversation: ChatHistory) -> BeliefState: + """ + Process conversation history to extract beliefs, semantically. The result is an object that + describes all beliefs that hold or don't hold based on the full conversation. + + :param conversation: The conversation history to be processed. + :return: An object that describes beliefs. + """ + # Return instantly if there are no beliefs to infer + if not self.available_beliefs: + return BeliefState() + n_parallel = max(1, min(settings.llm_settings.n_parallel - 1, len(self.available_beliefs))) - all_beliefs = await asyncio.gather( + all_beliefs: list[dict[str, bool | None] | None] = await asyncio.gather( *[ - self._infer_beliefs(self.conversation, beliefs) + self._infer_beliefs(conversation, beliefs) for beliefs in self._split_into_chunks(self.available_beliefs, n_parallel) ] ) - retval = {} + retval = BeliefState() for beliefs in all_beliefs: if beliefs is None: continue - retval.update(beliefs) + for belief_name, belief_holds in beliefs.items(): + if belief_holds is None: + continue + belief = InternalBelief(name=belief_name, arguments=None) + if belief_holds: + retval.true.add(belief) + else: + retval.false.add(belief) return retval @staticmethod - def _create_belief_schema(belief: SemanticBelief) -> tuple[str, dict]: - return AgentSpeakGenerator.slugify(belief), { - "type": ["boolean", "null"], - "description": belief.description, - } + def _split_into_chunks[T](items: list[T], n: int) -> list[list[T]]: + """ + Split a list into ``n`` chunks, making each chunk approximately ``len(items) / n`` long. - @staticmethod - def _create_beliefs_schema(beliefs: list[SemanticBelief]) -> dict: - belief_schemas = [ - TextBeliefExtractorAgent._create_belief_schema(belief) for belief in beliefs - ] - - return { - "type": "object", - "properties": dict(belief_schemas), - "required": [name for name, _ in belief_schemas], - } - - @staticmethod - def _format_message(message: ChatMessage): - return f"{message.role.upper()}:\n{message.content}" - - @staticmethod - def _format_conversation(conversation: ChatHistory): - return "\n\n".join( - [TextBeliefExtractorAgent._format_message(message) for message in conversation.messages] - ) - - @staticmethod - def _format_beliefs(beliefs: list[SemanticBelief]): - return "\n".join( - [f"- {AgentSpeakGenerator.slugify(belief)}: {belief.description}" for belief in beliefs] - ) + :param items: The list of items to split. + :param n: The number of desired chunks. + :return: A list of chunks each approximately ``len(items) / n`` long. + """ + k, m = divmod(len(items), n) + return [items[i * k + min(i, m) : (i + 1) * k + min(i + 1, m)] for i in range(n)] async def _infer_beliefs( self, conversation: ChatHistory, beliefs: list[SemanticBelief], - ) -> dict | None: + ) -> dict[str, bool | None] | None: """ Infer given beliefs based on the given conversation. :param conversation: The conversation to infer beliefs from. @@ -240,70 +388,79 @@ Respond with a JSON similar to the following, but with the property names as giv schema = self._create_beliefs_schema(beliefs) - return await self._retry_query_llm(prompt, schema) + return await self._llm.query(prompt, schema) - async def _retry_query_llm(self, prompt: str, schema: dict, tries: int = 3) -> dict | None: + @staticmethod + def _create_belief_schema(belief: SemanticBelief) -> tuple[str, dict]: + return AgentSpeakGenerator.slugify(belief), { + "type": ["boolean", "null"], + "description": belief.description, + } + + @staticmethod + def _create_beliefs_schema(beliefs: list[SemanticBelief]) -> dict: + belief_schemas = [ + SemanticBeliefInferrer._create_belief_schema(belief) for belief in beliefs + ] + + return { + "type": "object", + "properties": dict(belief_schemas), + "required": [name for name, _ in belief_schemas], + } + + @staticmethod + def _format_message(message: ChatMessage): + return f"{message.role.upper()}:\n{message.content}" + + @staticmethod + def _format_conversation(conversation: ChatHistory): + return "\n\n".join( + [SemanticBeliefInferrer._format_message(message) for message in conversation.messages] + ) + + @staticmethod + def _format_beliefs(beliefs: list[SemanticBelief]): + return "\n".join( + [f"- {AgentSpeakGenerator.slugify(belief)}: {belief.description}" for belief in beliefs] + ) + + +class GoalAchievementInferrer(SemanticBeliefInferrer): + def __init__(self, llm: TextBeliefExtractorAgent.LLM): + super().__init__(llm) + self.goals = [] + + async def infer_from_conversation(self, conversation: ChatHistory) -> dict[str, bool]: """ - Query the LLM with the given prompt and schema, return an instance of a dict conforming - to this schema. Try ``tries`` times, or return None. + Determine which goals have been achieved based on the given conversation. - :param prompt: Prompt to be queried. - :param schema: Schema to be queried. - :return: An instance of a dict conforming to this schema, or None if failed. + :param conversation: The conversation to infer goal completion from. + :return: A mapping of goals and a boolean whether they have been achieved. """ - try_count = 0 - while try_count < tries: - try_count += 1 + if not self.goals: + return {} - try: - return await self._query_llm(prompt, schema) - except (httpx.HTTPError, json.JSONDecodeError, KeyError) as e: - if try_count < tries: - continue - self.logger.exception( - "Failed to get LLM response after %d tries.", - try_count, - exc_info=e, - ) + goals_achieved = await asyncio.gather( + *[self._infer_goal(conversation, g) for g in self.goals] + ) + return { + f"achieved_{AgentSpeakGenerator.slugify(goal)}": achieved + for goal, achieved in zip(self.goals, goals_achieved, strict=True) + } - return None + async def _infer_goal(self, conversation: ChatHistory, goal: Goal) -> bool: + prompt = f"""{self._format_conversation(conversation)} - async def _query_llm(self, prompt: str, schema: dict) -> dict: - """ - Query an LLM with the given prompt and schema, return an instance of a dict conforming to - that schema. +Given the above conversation, what has the following goal been achieved? - :param prompt: The prompt to be queried. - :param schema: Schema to use during response. - :return: A dict conforming to this schema. - :raises httpx.HTTPStatusError: If the LLM server responded with an error. - :raises json.JSONDecodeError: If the LLM response was not valid JSON. May happen if the - response was cut off early due to length limitations. - :raises KeyError: If the LLM server responded with no error, but the response was invalid. - """ - async with httpx.AsyncClient() as client: - response = await client.post( - settings.llm_settings.local_llm_url, - json={ - "model": settings.llm_settings.local_llm_model, - "messages": [{"role": "user", "content": prompt}], - "response_format": { - "type": "json_schema", - "json_schema": { - "name": "Beliefs", - "strict": True, - "schema": schema, - }, - }, - "reasoning_effort": "low", - "temperature": self.temperature, - "stream": False, - }, - timeout=None, - ) - response.raise_for_status() +The name of the goal: {goal.name} +Description of the goal: {goal.description} - response_json = response.json() - json_message = response_json["choices"][0]["message"]["content"] - beliefs = json.loads(json_message) - return beliefs +Answer with literally only `true` or `false` (without backticks).""" + + schema = { + "type": "boolean", + } + + return await self._llm.query(prompt, schema) diff --git a/src/control_backend/core/agent_system.py b/src/control_backend/core/agent_system.py index 9d7a47f..e12a6b2 100644 --- a/src/control_backend/core/agent_system.py +++ b/src/control_backend/core/agent_system.py @@ -192,7 +192,16 @@ class BaseAgent(ABC): :param coro: The coroutine to execute as a task. """ - task = asyncio.create_task(coro) + + async def try_coro(coro_: Coroutine): + try: + await coro_ + except asyncio.CancelledError: + self.logger.debug("A behavior was canceled successfully: %s", coro_) + except Exception: + self.logger.warning("An exception occurred in a behavior.", exc_info=True) + + task = asyncio.create_task(try_coro(coro)) self._tasks.add(task) task.add_done_callback(self._tasks.discard) return task diff --git a/src/control_backend/schemas/belief_list.py b/src/control_backend/schemas/belief_list.py index ec6a7a1..b79247d 100644 --- a/src/control_backend/schemas/belief_list.py +++ b/src/control_backend/schemas/belief_list.py @@ -1,6 +1,7 @@ from pydantic import BaseModel from control_backend.schemas.program import Belief as ProgramBelief +from control_backend.schemas.program import Goal class BeliefList(BaseModel): @@ -12,3 +13,7 @@ class BeliefList(BaseModel): """ beliefs: list[ProgramBelief] + + +class GoalList(BaseModel): + goals: list[Goal] diff --git a/src/control_backend/schemas/belief_message.py b/src/control_backend/schemas/belief_message.py index 56a8a4a..51411b3 100644 --- a/src/control_backend/schemas/belief_message.py +++ b/src/control_backend/schemas/belief_message.py @@ -13,6 +13,9 @@ class Belief(BaseModel): name: str arguments: list[str] | None + # To make it hashable + model_config = {"frozen": True} + class BeliefMessage(BaseModel): """ diff --git a/src/control_backend/schemas/program.py b/src/control_backend/schemas/program.py index df20954..82c017e 100644 --- a/src/control_backend/schemas/program.py +++ b/src/control_backend/schemas/program.py @@ -117,7 +117,7 @@ class Goal(ProgramElement): :ivar can_fail: Whether we can fail to achieve the goal after executing the plan. """ - description: str + description: str = "" plan: Plan can_fail: bool = True