test: create tests for belief extractor agent
Includes changes in schemas. Change type of `norms` in `Program` imperceptibly, big changes in schema of `BeliefMessage` to support deleting beliefs. ref: N25B-380
This commit is contained in:
@@ -11,7 +11,7 @@ from pydantic import ValidationError
|
||||
from control_backend.agents.base import BaseAgent
|
||||
from control_backend.core.agent_system import InternalMessage
|
||||
from control_backend.core.config import settings
|
||||
from control_backend.schemas.belief_message import Belief, BeliefMessage
|
||||
from control_backend.schemas.belief_message import BeliefMessage
|
||||
from control_backend.schemas.llm_prompt_message import LLMPromptMessage
|
||||
from control_backend.schemas.ri_message import SpeechCommand
|
||||
|
||||
@@ -124,8 +124,8 @@ class BDICoreAgent(BaseAgent):
|
||||
|
||||
if msg.thread == "beliefs":
|
||||
try:
|
||||
beliefs = BeliefMessage.model_validate_json(msg.body).beliefs
|
||||
self._apply_beliefs(beliefs)
|
||||
belief_changes = BeliefMessage.model_validate_json(msg.body)
|
||||
self._apply_belief_changes(belief_changes)
|
||||
except ValidationError:
|
||||
self.logger.exception("Error processing belief.")
|
||||
return
|
||||
@@ -145,21 +145,28 @@ class BDICoreAgent(BaseAgent):
|
||||
)
|
||||
await self.send(out_msg)
|
||||
|
||||
def _apply_beliefs(self, beliefs: list[Belief]):
|
||||
def _apply_belief_changes(self, belief_changes: BeliefMessage):
|
||||
"""
|
||||
Update the belief base with a list of new beliefs.
|
||||
|
||||
If ``replace=True`` is set on a belief, it removes all existing beliefs with that name
|
||||
before adding the new one.
|
||||
For beliefs in ``belief_changes.replace``, it removes all existing beliefs with that name
|
||||
before adding one new one.
|
||||
|
||||
:param belief_changes: The changes in beliefs to apply.
|
||||
"""
|
||||
if not beliefs:
|
||||
if not belief_changes.create and not belief_changes.replace and not belief_changes.delete:
|
||||
return
|
||||
|
||||
for belief in beliefs:
|
||||
if belief.replace:
|
||||
self._remove_all_with_name(belief.name)
|
||||
for belief in belief_changes.create:
|
||||
self._add_belief(belief.name, belief.arguments)
|
||||
|
||||
for belief in belief_changes.replace:
|
||||
self._remove_all_with_name(belief.name)
|
||||
self._add_belief(belief.name, belief.arguments)
|
||||
|
||||
for belief in belief_changes.delete:
|
||||
self._remove_belief(belief.name, belief.arguments)
|
||||
|
||||
def _add_belief(self, name: str, args: list[str] = None):
|
||||
"""
|
||||
Add a single belief to the BDI agent.
|
||||
|
||||
@@ -144,7 +144,7 @@ class BDIBeliefCollectorAgent(BaseAgent):
|
||||
msg = InternalMessage(
|
||||
to=settings.agent_settings.bdi_core_name,
|
||||
sender=self.name,
|
||||
body=BeliefMessage(beliefs=beliefs).model_dump_json(),
|
||||
body=BeliefMessage(create=beliefs).model_dump_json(),
|
||||
thread="beliefs",
|
||||
)
|
||||
|
||||
|
||||
@@ -34,8 +34,8 @@ class TextBeliefExtractorAgent(BaseAgent):
|
||||
|
||||
def __init__(self, name: str):
|
||||
super().__init__(name)
|
||||
self.beliefs = {}
|
||||
self.available_beliefs = []
|
||||
self.beliefs: dict[str, bool] = {}
|
||||
self.available_beliefs: list[SemanticBelief] = []
|
||||
self.conversation = ChatHistory(messages=[])
|
||||
|
||||
async def setup(self):
|
||||
@@ -151,23 +151,30 @@ class TextBeliefExtractorAgent(BaseAgent):
|
||||
return
|
||||
|
||||
candidate_beliefs = await self._infer_turn()
|
||||
new_beliefs: list[InternalBelief] = []
|
||||
belief_changes = BeliefMessage()
|
||||
for belief_key, belief_value in candidate_beliefs.items():
|
||||
if belief_value is None:
|
||||
continue
|
||||
old_belief_value = self.beliefs.get(belief_key)
|
||||
# TODO: Do we need this check? Can we send the same beliefs multiple times?
|
||||
if belief_value == old_belief_value:
|
||||
continue
|
||||
|
||||
self.beliefs[belief_key] = belief_value
|
||||
new_beliefs.append(
|
||||
InternalBelief(name=belief_key, arguments=[belief_value], replace=True),
|
||||
)
|
||||
|
||||
belief = InternalBelief(name=belief_key, arguments=None)
|
||||
if belief_value:
|
||||
belief_changes.create.append(belief)
|
||||
else:
|
||||
belief_changes.delete.append(belief)
|
||||
|
||||
# Return if there were no changes in beliefs
|
||||
if not belief_changes.has_values():
|
||||
return
|
||||
|
||||
beliefs_message = InternalMessage(
|
||||
to=settings.agent_settings.bdi_core_name,
|
||||
sender=self.name,
|
||||
body=BeliefMessage(beliefs=new_beliefs).model_dump_json(),
|
||||
body=belief_changes.model_dump_json(),
|
||||
thread="beliefs",
|
||||
)
|
||||
await self.send(beliefs_message)
|
||||
@@ -184,7 +191,7 @@ class TextBeliefExtractorAgent(BaseAgent):
|
||||
|
||||
:return: A dict mapping belief names to a value ``True``, ``False`` or ``None``.
|
||||
"""
|
||||
n_parallel = min(settings.llm_settings.n_parallel - 1, len(self.available_beliefs))
|
||||
n_parallel = max(1, min(settings.llm_settings.n_parallel - 1, len(self.available_beliefs)))
|
||||
all_beliefs = await asyncio.gather(
|
||||
*[
|
||||
self._infer_beliefs(self.conversation, beliefs)
|
||||
@@ -286,7 +293,7 @@ Respond with a JSON similar to the following, but with the property names as giv
|
||||
|
||||
try:
|
||||
return await self._query_llm(prompt, schema)
|
||||
except (httpx.HTTPStatusError, json.JSONDecodeError, KeyError) as e:
|
||||
except (httpx.HTTPError, json.JSONDecodeError, KeyError) as e:
|
||||
if try_count < tries:
|
||||
continue
|
||||
self.logger.exception(
|
||||
|
||||
@@ -6,18 +6,27 @@ class Belief(BaseModel):
|
||||
Represents a single belief in the BDI system.
|
||||
|
||||
:ivar name: The functor or name of the belief (e.g., 'user_said').
|
||||
:ivar arguments: A list of string arguments for the belief.
|
||||
:ivar replace: If True, existing beliefs with this name should be replaced by this one.
|
||||
:ivar arguments: A list of string arguments for the belief, or None if the belief has no
|
||||
arguments.
|
||||
"""
|
||||
|
||||
name: str
|
||||
arguments: list[str]
|
||||
replace: bool = False
|
||||
arguments: list[str] | None
|
||||
|
||||
|
||||
class BeliefMessage(BaseModel):
|
||||
"""
|
||||
A container for transporting a list of beliefs between agents.
|
||||
A container for communicating beliefs between agents.
|
||||
|
||||
:ivar create: Beliefs to create.
|
||||
:ivar delete: Beliefs to delete.
|
||||
:ivar replace: Beliefs to replace. Deletes all beliefs with the same name, replacing them with
|
||||
one new belief.
|
||||
"""
|
||||
|
||||
beliefs: list[Belief]
|
||||
create: list[Belief] = []
|
||||
delete: list[Belief] = []
|
||||
replace: list[Belief] = []
|
||||
|
||||
def has_values(self) -> bool:
|
||||
return len(self.create) > 0 or len(self.delete) > 0 or len(self.replace) > 0
|
||||
|
||||
@@ -194,7 +194,7 @@ class Phase(ProgramElement):
|
||||
"""
|
||||
|
||||
name: str = ""
|
||||
norms: list[Norm]
|
||||
norms: list[BasicNorm | ConditionalNorm]
|
||||
goals: list[Goal]
|
||||
triggers: list[Trigger]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user