feat: support force completed goals in semantic belief agent

ref: N25B-427
This commit is contained in:
Twirre Meulenbelt
2026-01-13 17:04:44 +01:00
parent 8f52f8bf0c
commit f7669c021b
5 changed files with 52 additions and 17 deletions

View File

@@ -18,6 +18,7 @@ from control_backend.agents.bdi.agentspeak_ast import (
TriggerType,
)
from control_backend.schemas.program import (
BaseGoal,
BasicNorm,
ConditionalNorm,
GestureAction,
@@ -436,7 +437,7 @@ class AgentSpeakGenerator:
@slugify.register
@staticmethod
def _(g: Goal) -> str:
def _(g: BaseGoal) -> str:
return AgentSpeakGenerator._slugify_str(g.name)
@slugify.register

View File

@@ -12,7 +12,7 @@ from control_backend.schemas.belief_list import BeliefList, GoalList
from control_backend.schemas.belief_message import Belief as InternalBelief
from control_backend.schemas.belief_message import BeliefMessage
from control_backend.schemas.chat_history import ChatHistory, ChatMessage
from control_backend.schemas.program import Goal, SemanticBelief
from control_backend.schemas.program import BaseGoal, SemanticBelief
type JSONLike = None | bool | int | float | str | list["JSONLike"] | dict[str, "JSONLike"]
@@ -62,6 +62,7 @@ class TextBeliefExtractorAgent(BaseAgent):
self.goal_inferrer = GoalAchievementInferrer(self._llm)
self._current_beliefs = BeliefState()
self._current_goal_completions: dict[str, bool] = {}
self._force_completed_goals: set[BaseGoal] = set()
self.conversation = ChatHistory(messages=[])
async def setup(self):
@@ -118,13 +119,15 @@ class TextBeliefExtractorAgent(BaseAgent):
case "goals":
self._handle_goals_message(msg)
await self._infer_goal_completions()
case "achieved_goals":
self._handle_goal_achieved_message(msg)
case "conversation_history":
if msg.body == "reset":
self._reset()
self._reset_phase()
case _:
self.logger.warning("Received unexpected message from %s", msg.sender)
def _reset(self):
def _reset_phase(self):
self.conversation = ChatHistory(messages=[])
self.belief_inferrer.available_beliefs.clear()
self._current_beliefs = BeliefState()
@@ -158,7 +161,8 @@ class TextBeliefExtractorAgent(BaseAgent):
return
# Use only goals that can fail, as the others are always assumed to be completed
available_goals = [g for g in goals_list.goals if g.can_fail]
available_goals = {g for g in goals_list.goals if g.can_fail}
available_goals -= self._force_completed_goals
self.goal_inferrer.goals = available_goals
self.logger.debug(
"Received %d failable goals from the program manager: %s",
@@ -166,6 +170,23 @@ class TextBeliefExtractorAgent(BaseAgent):
", ".join(g.name for g in available_goals),
)
def _handle_goal_achieved_message(self, msg: InternalMessage):
# NOTE: When goals can be marked unachieved, remember to re-add them to the goal_inferrer
try:
goals_list = GoalList.model_validate_json(msg.body)
except ValidationError:
self.logger.warning(
"Received goal achieved message from the program manager, "
"but it is not a valid list of goals."
)
return
for goal in goals_list.goals:
self._force_completed_goals.add(goal)
self._current_goal_completions[f"achieved_{AgentSpeakGenerator.slugify(goal)}"] = True
self.goal_inferrer.goals -= self._force_completed_goals
async def _user_said(self, text: str):
"""
Create a belief for the user's full speech.
@@ -445,7 +466,7 @@ Respond with a JSON similar to the following, but with the property names as giv
class GoalAchievementInferrer(SemanticBeliefInferrer):
def __init__(self, llm: TextBeliefExtractorAgent.LLM):
super().__init__(llm)
self.goals = []
self.goals: set[BaseGoal] = set()
async def infer_from_conversation(self, conversation: ChatHistory) -> dict[str, bool]:
"""
@@ -465,7 +486,7 @@ class GoalAchievementInferrer(SemanticBeliefInferrer):
for goal, achieved in zip(self.goals, goals_achieved, strict=True)
}
async def _infer_goal(self, conversation: ChatHistory, goal: Goal) -> bool:
async def _infer_goal(self, conversation: ChatHistory, goal: BaseGoal) -> bool:
prompt = f"""{self._format_conversation(conversation)}
Given the above conversation, what has the following goal been achieved?