diff --git a/src/control_backend/agents/bdi/bdi_program_manager.py b/src/control_backend/agents/bdi/bdi_program_manager.py index 092a2c6..25b7364 100644 --- a/src/control_backend/agents/bdi/bdi_program_manager.py +++ b/src/control_backend/agents/bdi/bdi_program_manager.py @@ -42,6 +42,16 @@ class BDIProgramManager(BaseAgent): def _initialize_internal_state(self, program: Program): self._program = program self._phase = program.phases[0] # start in first phase + self._goal_mapping: dict[str, Goal] = {} + for phase in program.phases: + for goal in phase.goals: + self._populate_goal_mapping_with_goal(goal) + + def _populate_goal_mapping_with_goal(self, goal: Goal): + self._goal_mapping[str(goal.id)] = goal + for step in goal.plan.steps: + if isinstance(step, Goal): + self._populate_goal_mapping_with_goal(step) async def _create_agentspeak_and_send_to_bdi(self, program: Program): """ @@ -73,6 +83,9 @@ class BDIProgramManager(BaseAgent): phases = json.loads(msg.body) await self._transition_phase(phases["old"], phases["new"]) + case "achieve_goal": + goal_id = msg.body + await self._send_achieved_goal_to_semantic_belief_extractor(goal_id) async def _transition_phase(self, old: str, new: str): if old != str(self._phase.id): @@ -138,6 +151,19 @@ class BDIProgramManager(BaseAgent): await self.send(message) + @staticmethod + def _extract_goals_from_goal(goal: Goal) -> list[Goal]: + """ + Extract all goals from a given goal, that is: the goal itself and any subgoals. + + :return: All goals within and including the given goal. + """ + goals: list[Goal] = [goal] + for plan in goal.plan: + if isinstance(plan, Goal): + goals.extend(BDIProgramManager._extract_goals_from_goal(plan)) + return goals + def _extract_current_goals(self) -> list[Goal]: """ Extract all goals from the program, including subgoals. @@ -146,15 +172,8 @@ class BDIProgramManager(BaseAgent): """ goals: list[Goal] = [] - def extract_goals_from_goal(goal_: Goal) -> list[Goal]: - goals_: list[Goal] = [goal] - for plan in goal_.plan: - if isinstance(plan, Goal): - goals_.extend(extract_goals_from_goal(plan)) - return goals_ - for goal in self._phase.goals: - goals.extend(extract_goals_from_goal(goal)) + goals.extend(self._extract_goals_from_goal(goal)) return goals @@ -173,6 +192,25 @@ class BDIProgramManager(BaseAgent): await self.send(message) + async def _send_achieved_goal_to_semantic_belief_extractor(self, achieved_goal_id: str): + """ + Inform the semantic belief extractor when a goal is marked achieved. + + :param achieved_goal_id: The id of the achieved goal. + """ + goal = self._goal_mapping.get(achieved_goal_id) + if goal is None: + self.logger.debug(f"Goal with ID {achieved_goal_id} marked achieved but was not found.") + return + + goals = self._extract_goals_from_goal(goal) + message = InternalMessage( + to=settings.agent_settings.text_belief_extractor_name, + body=GoalList(goals=goals).model_dump_json(), + thread="achieved_goals", + ) + await self.send(message) + async def _send_clear_llm_history(self): """ Clear the LLM Agent's conversation history. diff --git a/src/control_backend/agents/user_interrupt/user_interrupt_agent.py b/src/control_backend/agents/user_interrupt/user_interrupt_agent.py index 108e821..d994121 100644 --- a/src/control_backend/agents/user_interrupt/user_interrupt_agent.py +++ b/src/control_backend/agents/user_interrupt/user_interrupt_agent.py @@ -7,6 +7,7 @@ from control_backend.agents import BaseAgent from control_backend.agents.bdi.agentspeak_generator import AgentSpeakGenerator from control_backend.core.agent_system import InternalMessage from control_backend.core.config import settings +from control_backend.schemas.belief_message import Belief, BeliefMessage from control_backend.schemas.program import ConditionalNorm, Program from control_backend.schemas.ri_message import ( GestureCommand, @@ -31,7 +32,7 @@ class UserInterruptAgent(BaseAgent): Prioritized actions clear the current RI queue before inserting the new item, ensuring they are executed immediately after Pepper's current action has been fulfilled. - :ivar sub_socket: The ZMQ SUB socket used to receive user intterupts. + :ivar sub_socket: The ZMQ SUB socket used to receive user interrupts. """ def __init__(self, **kwargs): @@ -118,8 +119,23 @@ class UserInterruptAgent(BaseAgent): "Forwarded button press (override) with context '%s' to BDIProgramManager.", event_context, ) + elif asl_goal := self._goal_map.get(ui_id): + await self._send_to_bdi_belief(asl_goal) + self.logger.info( + "Forwarded button press (override) with context '%s' to BDI Core.", + event_context, + ) + + goal_achieve_msg = InternalMessage( + to=settings.agent_settings.bdi_program_manager_name, + thread="achieve_goal", + body=ui_id, + ) + + await self.send(goal_achieve_msg) else: self.logger.warning("Could not determine which element to override.") + elif event_type == "pause": self.logger.debug( "Received pause/resume button press with context '%s'.", event_context @@ -291,6 +307,19 @@ class UserInterruptAgent(BaseAgent): await self.send(msg) self.logger.info(f"Directly forced {thread} in BDI: {body}") + async def _send_to_bdi_belief(self, asl_goal: str): + """Send belief to BDI Core""" + belief_name = f"achieved_{asl_goal}" + belief = Belief(name=belief_name) + self.logger.debug(f"Sending belief to BDI Core: {belief_name}") + belief_message = BeliefMessage(create=[belief]) + msg = InternalMessage( + to=settings.agent_settings.bdi_core_name, + thread="beliefs", + body=belief_message.model_dump_json(), + ) + await self.send(msg) + async def _send_experiment_control_to_bdi_core(self, type): """ method to send experiment control buttons to bdi core. diff --git a/src/control_backend/api/v1/endpoints/robot.py b/src/control_backend/api/v1/endpoints/robot.py index afbf1ac..95a9c40 100644 --- a/src/control_backend/api/v1/endpoints/robot.py +++ b/src/control_backend/api/v1/endpoints/robot.py @@ -137,7 +137,6 @@ async def ping_stream(request: Request): logger.info("Client disconnected from SSE") break - logger.debug(f"Yielded new connection event in robot ping router: {str(connected)}") connectedJson = json.dumps(connected) yield (f"data: {connectedJson}\n\n") diff --git a/src/control_backend/schemas/belief_message.py b/src/control_backend/schemas/belief_message.py index 51411b3..226833e 100644 --- a/src/control_backend/schemas/belief_message.py +++ b/src/control_backend/schemas/belief_message.py @@ -11,7 +11,7 @@ class Belief(BaseModel): """ name: str - arguments: list[str] | None + arguments: list[str] | None = None # To make it hashable model_config = {"frozen": True}