test: create tests for belief extractor agent
Includes changes in schemas. Change type of `norms` in `Program` imperceptibly, big changes in schema of `BeliefMessage` to support deleting beliefs. ref: N25B-380
This commit is contained in:
@@ -11,7 +11,7 @@ from pydantic import ValidationError
|
|||||||
from control_backend.agents.base import BaseAgent
|
from control_backend.agents.base import BaseAgent
|
||||||
from control_backend.core.agent_system import InternalMessage
|
from control_backend.core.agent_system import InternalMessage
|
||||||
from control_backend.core.config import settings
|
from control_backend.core.config import settings
|
||||||
from control_backend.schemas.belief_message import Belief, BeliefMessage
|
from control_backend.schemas.belief_message import BeliefMessage
|
||||||
from control_backend.schemas.llm_prompt_message import LLMPromptMessage
|
from control_backend.schemas.llm_prompt_message import LLMPromptMessage
|
||||||
from control_backend.schemas.ri_message import SpeechCommand
|
from control_backend.schemas.ri_message import SpeechCommand
|
||||||
|
|
||||||
@@ -124,8 +124,8 @@ class BDICoreAgent(BaseAgent):
|
|||||||
|
|
||||||
if msg.thread == "beliefs":
|
if msg.thread == "beliefs":
|
||||||
try:
|
try:
|
||||||
beliefs = BeliefMessage.model_validate_json(msg.body).beliefs
|
belief_changes = BeliefMessage.model_validate_json(msg.body)
|
||||||
self._apply_beliefs(beliefs)
|
self._apply_belief_changes(belief_changes)
|
||||||
except ValidationError:
|
except ValidationError:
|
||||||
self.logger.exception("Error processing belief.")
|
self.logger.exception("Error processing belief.")
|
||||||
return
|
return
|
||||||
@@ -145,21 +145,28 @@ class BDICoreAgent(BaseAgent):
|
|||||||
)
|
)
|
||||||
await self.send(out_msg)
|
await self.send(out_msg)
|
||||||
|
|
||||||
def _apply_beliefs(self, beliefs: list[Belief]):
|
def _apply_belief_changes(self, belief_changes: BeliefMessage):
|
||||||
"""
|
"""
|
||||||
Update the belief base with a list of new beliefs.
|
Update the belief base with a list of new beliefs.
|
||||||
|
|
||||||
If ``replace=True`` is set on a belief, it removes all existing beliefs with that name
|
For beliefs in ``belief_changes.replace``, it removes all existing beliefs with that name
|
||||||
before adding the new one.
|
before adding one new one.
|
||||||
|
|
||||||
|
:param belief_changes: The changes in beliefs to apply.
|
||||||
"""
|
"""
|
||||||
if not beliefs:
|
if not belief_changes.create and not belief_changes.replace and not belief_changes.delete:
|
||||||
return
|
return
|
||||||
|
|
||||||
for belief in beliefs:
|
for belief in belief_changes.create:
|
||||||
if belief.replace:
|
self._add_belief(belief.name, belief.arguments)
|
||||||
|
|
||||||
|
for belief in belief_changes.replace:
|
||||||
self._remove_all_with_name(belief.name)
|
self._remove_all_with_name(belief.name)
|
||||||
self._add_belief(belief.name, belief.arguments)
|
self._add_belief(belief.name, belief.arguments)
|
||||||
|
|
||||||
|
for belief in belief_changes.delete:
|
||||||
|
self._remove_belief(belief.name, belief.arguments)
|
||||||
|
|
||||||
def _add_belief(self, name: str, args: list[str] = None):
|
def _add_belief(self, name: str, args: list[str] = None):
|
||||||
"""
|
"""
|
||||||
Add a single belief to the BDI agent.
|
Add a single belief to the BDI agent.
|
||||||
|
|||||||
@@ -144,7 +144,7 @@ class BDIBeliefCollectorAgent(BaseAgent):
|
|||||||
msg = InternalMessage(
|
msg = InternalMessage(
|
||||||
to=settings.agent_settings.bdi_core_name,
|
to=settings.agent_settings.bdi_core_name,
|
||||||
sender=self.name,
|
sender=self.name,
|
||||||
body=BeliefMessage(beliefs=beliefs).model_dump_json(),
|
body=BeliefMessage(create=beliefs).model_dump_json(),
|
||||||
thread="beliefs",
|
thread="beliefs",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -34,8 +34,8 @@ class TextBeliefExtractorAgent(BaseAgent):
|
|||||||
|
|
||||||
def __init__(self, name: str):
|
def __init__(self, name: str):
|
||||||
super().__init__(name)
|
super().__init__(name)
|
||||||
self.beliefs = {}
|
self.beliefs: dict[str, bool] = {}
|
||||||
self.available_beliefs = []
|
self.available_beliefs: list[SemanticBelief] = []
|
||||||
self.conversation = ChatHistory(messages=[])
|
self.conversation = ChatHistory(messages=[])
|
||||||
|
|
||||||
async def setup(self):
|
async def setup(self):
|
||||||
@@ -151,23 +151,30 @@ class TextBeliefExtractorAgent(BaseAgent):
|
|||||||
return
|
return
|
||||||
|
|
||||||
candidate_beliefs = await self._infer_turn()
|
candidate_beliefs = await self._infer_turn()
|
||||||
new_beliefs: list[InternalBelief] = []
|
belief_changes = BeliefMessage()
|
||||||
for belief_key, belief_value in candidate_beliefs.items():
|
for belief_key, belief_value in candidate_beliefs.items():
|
||||||
if belief_value is None:
|
if belief_value is None:
|
||||||
continue
|
continue
|
||||||
old_belief_value = self.beliefs.get(belief_key)
|
old_belief_value = self.beliefs.get(belief_key)
|
||||||
# TODO: Do we need this check? Can we send the same beliefs multiple times?
|
|
||||||
if belief_value == old_belief_value:
|
if belief_value == old_belief_value:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
self.beliefs[belief_key] = belief_value
|
self.beliefs[belief_key] = belief_value
|
||||||
new_beliefs.append(
|
|
||||||
InternalBelief(name=belief_key, arguments=[belief_value], replace=True),
|
belief = InternalBelief(name=belief_key, arguments=None)
|
||||||
)
|
if belief_value:
|
||||||
|
belief_changes.create.append(belief)
|
||||||
|
else:
|
||||||
|
belief_changes.delete.append(belief)
|
||||||
|
|
||||||
|
# Return if there were no changes in beliefs
|
||||||
|
if not belief_changes.has_values():
|
||||||
|
return
|
||||||
|
|
||||||
beliefs_message = InternalMessage(
|
beliefs_message = InternalMessage(
|
||||||
to=settings.agent_settings.bdi_core_name,
|
to=settings.agent_settings.bdi_core_name,
|
||||||
sender=self.name,
|
sender=self.name,
|
||||||
body=BeliefMessage(beliefs=new_beliefs).model_dump_json(),
|
body=belief_changes.model_dump_json(),
|
||||||
thread="beliefs",
|
thread="beliefs",
|
||||||
)
|
)
|
||||||
await self.send(beliefs_message)
|
await self.send(beliefs_message)
|
||||||
@@ -184,7 +191,7 @@ class TextBeliefExtractorAgent(BaseAgent):
|
|||||||
|
|
||||||
:return: A dict mapping belief names to a value ``True``, ``False`` or ``None``.
|
:return: A dict mapping belief names to a value ``True``, ``False`` or ``None``.
|
||||||
"""
|
"""
|
||||||
n_parallel = min(settings.llm_settings.n_parallel - 1, len(self.available_beliefs))
|
n_parallel = max(1, min(settings.llm_settings.n_parallel - 1, len(self.available_beliefs)))
|
||||||
all_beliefs = await asyncio.gather(
|
all_beliefs = await asyncio.gather(
|
||||||
*[
|
*[
|
||||||
self._infer_beliefs(self.conversation, beliefs)
|
self._infer_beliefs(self.conversation, beliefs)
|
||||||
@@ -286,7 +293,7 @@ Respond with a JSON similar to the following, but with the property names as giv
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
return await self._query_llm(prompt, schema)
|
return await self._query_llm(prompt, schema)
|
||||||
except (httpx.HTTPStatusError, json.JSONDecodeError, KeyError) as e:
|
except (httpx.HTTPError, json.JSONDecodeError, KeyError) as e:
|
||||||
if try_count < tries:
|
if try_count < tries:
|
||||||
continue
|
continue
|
||||||
self.logger.exception(
|
self.logger.exception(
|
||||||
|
|||||||
@@ -6,18 +6,27 @@ class Belief(BaseModel):
|
|||||||
Represents a single belief in the BDI system.
|
Represents a single belief in the BDI system.
|
||||||
|
|
||||||
:ivar name: The functor or name of the belief (e.g., 'user_said').
|
:ivar name: The functor or name of the belief (e.g., 'user_said').
|
||||||
:ivar arguments: A list of string arguments for the belief.
|
:ivar arguments: A list of string arguments for the belief, or None if the belief has no
|
||||||
:ivar replace: If True, existing beliefs with this name should be replaced by this one.
|
arguments.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
name: str
|
name: str
|
||||||
arguments: list[str]
|
arguments: list[str] | None
|
||||||
replace: bool = False
|
|
||||||
|
|
||||||
|
|
||||||
class BeliefMessage(BaseModel):
|
class BeliefMessage(BaseModel):
|
||||||
"""
|
"""
|
||||||
A container for transporting a list of beliefs between agents.
|
A container for communicating beliefs between agents.
|
||||||
|
|
||||||
|
:ivar create: Beliefs to create.
|
||||||
|
:ivar delete: Beliefs to delete.
|
||||||
|
:ivar replace: Beliefs to replace. Deletes all beliefs with the same name, replacing them with
|
||||||
|
one new belief.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
beliefs: list[Belief]
|
create: list[Belief] = []
|
||||||
|
delete: list[Belief] = []
|
||||||
|
replace: list[Belief] = []
|
||||||
|
|
||||||
|
def has_values(self) -> bool:
|
||||||
|
return len(self.create) > 0 or len(self.delete) > 0 or len(self.replace) > 0
|
||||||
|
|||||||
@@ -194,7 +194,7 @@ class Phase(ProgramElement):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
name: str = ""
|
name: str = ""
|
||||||
norms: list[Norm]
|
norms: list[BasicNorm | ConditionalNorm]
|
||||||
goals: list[Goal]
|
goals: list[Goal]
|
||||||
triggers: list[Trigger]
|
triggers: list[Trigger]
|
||||||
|
|
||||||
|
|||||||
@@ -51,7 +51,7 @@ async def test_handle_belief_collector_message(agent, mock_settings):
|
|||||||
msg = InternalMessage(
|
msg = InternalMessage(
|
||||||
to="bdi_agent",
|
to="bdi_agent",
|
||||||
sender=mock_settings.agent_settings.bdi_belief_collector_name,
|
sender=mock_settings.agent_settings.bdi_belief_collector_name,
|
||||||
body=BeliefMessage(beliefs=beliefs).model_dump_json(),
|
body=BeliefMessage(create=beliefs).model_dump_json(),
|
||||||
thread="beliefs",
|
thread="beliefs",
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -64,6 +64,26 @@ async def test_handle_belief_collector_message(agent, mock_settings):
|
|||||||
assert args[2] == agentspeak.Literal("user_said", (agentspeak.Literal("Hello"),))
|
assert args[2] == agentspeak.Literal("user_said", (agentspeak.Literal("Hello"),))
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_handle_delete_belief_message(agent, mock_settings):
|
||||||
|
"""Test that incoming beliefs to be deleted are removed from the BDI agent"""
|
||||||
|
beliefs = [Belief(name="user_said", arguments=["Hello"])]
|
||||||
|
|
||||||
|
msg = InternalMessage(
|
||||||
|
to="bdi_agent",
|
||||||
|
sender=mock_settings.agent_settings.bdi_belief_collector_name,
|
||||||
|
body=BeliefMessage(delete=beliefs).model_dump_json(),
|
||||||
|
thread="beliefs",
|
||||||
|
)
|
||||||
|
await agent.handle_message(msg)
|
||||||
|
|
||||||
|
# Expect bdi_agent.call to be triggered to remove belief
|
||||||
|
args = agent.bdi_agent.call.call_args.args
|
||||||
|
assert args[0] == agentspeak.Trigger.removal
|
||||||
|
assert args[1] == agentspeak.GoalType.belief
|
||||||
|
assert args[2] == agentspeak.Literal("user_said", (agentspeak.Literal("Hello"),))
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_incorrect_belief_collector_message(agent, mock_settings):
|
async def test_incorrect_belief_collector_message(agent, mock_settings):
|
||||||
"""Test that incorrect message format triggers an exception."""
|
"""Test that incorrect message format triggers an exception."""
|
||||||
@@ -128,7 +148,8 @@ def test_add_belief_sets_event(agent):
|
|||||||
agent._wake_bdi_loop = MagicMock()
|
agent._wake_bdi_loop = MagicMock()
|
||||||
|
|
||||||
belief = Belief(name="test_belief", arguments=["a", "b"])
|
belief = Belief(name="test_belief", arguments=["a", "b"])
|
||||||
agent._apply_beliefs([belief])
|
belief_changes = BeliefMessage(replace=[belief])
|
||||||
|
agent._apply_belief_changes(belief_changes)
|
||||||
|
|
||||||
assert agent.bdi_agent.call.called
|
assert agent.bdi_agent.call.called
|
||||||
agent._wake_bdi_loop.set.assert_called()
|
agent._wake_bdi_loop.set.assert_called()
|
||||||
@@ -137,7 +158,7 @@ def test_add_belief_sets_event(agent):
|
|||||||
def test_apply_beliefs_empty_returns(agent):
|
def test_apply_beliefs_empty_returns(agent):
|
||||||
"""Line: if not beliefs: return"""
|
"""Line: if not beliefs: return"""
|
||||||
agent._wake_bdi_loop = MagicMock()
|
agent._wake_bdi_loop = MagicMock()
|
||||||
agent._apply_beliefs([])
|
agent._apply_belief_changes(BeliefMessage())
|
||||||
agent.bdi_agent.call.assert_not_called()
|
agent.bdi_agent.call.assert_not_called()
|
||||||
agent._wake_bdi_loop.set.assert_not_called()
|
agent._wake_bdi_loop.set.assert_not_called()
|
||||||
|
|
||||||
@@ -220,8 +241,9 @@ def test_replace_belief_calls_remove_all(agent):
|
|||||||
agent._remove_all_with_name = MagicMock()
|
agent._remove_all_with_name = MagicMock()
|
||||||
agent._wake_bdi_loop = MagicMock()
|
agent._wake_bdi_loop = MagicMock()
|
||||||
|
|
||||||
belief = Belief(name="user_said", arguments=["Hello"], replace=True)
|
belief = Belief(name="user_said", arguments=["Hello"])
|
||||||
agent._apply_beliefs([belief])
|
belief_changes = BeliefMessage(replace=[belief])
|
||||||
|
agent._apply_belief_changes(belief_changes)
|
||||||
|
|
||||||
agent._remove_all_with_name.assert_called_with("user_said")
|
agent._remove_all_with_name.assert_called_with("user_said")
|
||||||
|
|
||||||
|
|||||||
@@ -86,7 +86,7 @@ async def test_send_beliefs_to_bdi(agent):
|
|||||||
sent: InternalMessage = agent.send.call_args.args[0]
|
sent: InternalMessage = agent.send.call_args.args[0]
|
||||||
assert sent.to == settings.agent_settings.bdi_core_name
|
assert sent.to == settings.agent_settings.bdi_core_name
|
||||||
assert sent.thread == "beliefs"
|
assert sent.thread == "beliefs"
|
||||||
assert json.loads(sent.body)["beliefs"] == [belief.model_dump() for belief in beliefs]
|
assert json.loads(sent.body)["create"] == [belief.model_dump() for belief in beliefs]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
|||||||
346
test/unit/agents/bdi/test_text_belief_extractor.py
Normal file
346
test/unit/agents/bdi/test_text_belief_extractor.py
Normal file
@@ -0,0 +1,346 @@
|
|||||||
|
import json
|
||||||
|
import uuid
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from control_backend.agents.bdi import TextBeliefExtractorAgent
|
||||||
|
from control_backend.core.agent_system import InternalMessage
|
||||||
|
from control_backend.core.config import settings
|
||||||
|
from control_backend.schemas.belief_message import BeliefMessage
|
||||||
|
from control_backend.schemas.program import (
|
||||||
|
ConditionalNorm,
|
||||||
|
LLMAction,
|
||||||
|
Phase,
|
||||||
|
Plan,
|
||||||
|
Program,
|
||||||
|
SemanticBelief,
|
||||||
|
Trigger,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def agent():
|
||||||
|
agent = TextBeliefExtractorAgent("text_belief_agent")
|
||||||
|
agent.send = AsyncMock()
|
||||||
|
agent._query_llm = AsyncMock()
|
||||||
|
return agent
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def sample_program():
|
||||||
|
return Program(
|
||||||
|
phases=[
|
||||||
|
Phase(
|
||||||
|
name="Some phase",
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
norms=[
|
||||||
|
ConditionalNorm(
|
||||||
|
name="Some norm",
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
norm="Use nautical terms.",
|
||||||
|
critical=False,
|
||||||
|
condition=SemanticBelief(
|
||||||
|
name="is_pirate",
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
description="The user is a pirate. Perhaps because they say "
|
||||||
|
"they are, or because they speak like a pirate "
|
||||||
|
'with terms like "arr".',
|
||||||
|
),
|
||||||
|
),
|
||||||
|
],
|
||||||
|
goals=[],
|
||||||
|
triggers=[
|
||||||
|
Trigger(
|
||||||
|
name="Some trigger",
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
condition=SemanticBelief(
|
||||||
|
name="no_more_booze",
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
description="There is no more alcohol.",
|
||||||
|
),
|
||||||
|
plan=Plan(
|
||||||
|
name="Some plan",
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
steps=[
|
||||||
|
LLMAction(
|
||||||
|
name="Some action",
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
goal="Suggest eating chocolate instead.",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
),
|
||||||
|
),
|
||||||
|
],
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def make_msg(sender: str, body: str, thread: str | None = None) -> InternalMessage:
|
||||||
|
return InternalMessage(to="unused", sender=sender, body=body, thread=thread)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_handle_message_ignores_other_agents(agent):
|
||||||
|
msg = make_msg("unknown", "some data", None)
|
||||||
|
|
||||||
|
await agent.handle_message(msg)
|
||||||
|
|
||||||
|
agent.send.assert_not_called() # noqa # `agent.send` has no such property, but we mock it.
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_handle_message_from_transcriber(agent, mock_settings):
|
||||||
|
transcription = "hello world"
|
||||||
|
msg = make_msg(mock_settings.agent_settings.transcription_name, transcription, None)
|
||||||
|
|
||||||
|
await agent.handle_message(msg)
|
||||||
|
|
||||||
|
agent.send.assert_awaited_once() # noqa # `agent.send` has no such property, but we mock it.
|
||||||
|
sent: InternalMessage = agent.send.call_args.args[0] # noqa
|
||||||
|
assert sent.to == mock_settings.agent_settings.bdi_belief_collector_name
|
||||||
|
assert sent.thread == "beliefs"
|
||||||
|
parsed = json.loads(sent.body)
|
||||||
|
assert parsed == {"beliefs": {"user_said": [transcription]}, "type": "belief_extraction_text"}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_process_user_said(agent, mock_settings):
|
||||||
|
transcription = "this is a test"
|
||||||
|
|
||||||
|
await agent._user_said(transcription)
|
||||||
|
|
||||||
|
agent.send.assert_awaited_once() # noqa # `agent.send` has no such property, but we mock it.
|
||||||
|
sent: InternalMessage = agent.send.call_args.args[0] # noqa
|
||||||
|
assert sent.to == mock_settings.agent_settings.bdi_belief_collector_name
|
||||||
|
assert sent.thread == "beliefs"
|
||||||
|
parsed = json.loads(sent.body)
|
||||||
|
assert parsed["beliefs"]["user_said"] == [transcription]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_query_llm():
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.json.return_value = {
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"message": {
|
||||||
|
"content": "null",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
mock_client = AsyncMock()
|
||||||
|
mock_client.post.return_value = mock_response
|
||||||
|
mock_async_client = MagicMock()
|
||||||
|
mock_async_client.__aenter__.return_value = mock_client
|
||||||
|
mock_async_client.__aexit__.return_value = None
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"control_backend.agents.bdi.text_belief_extractor_agent.httpx.AsyncClient",
|
||||||
|
return_value=mock_async_client,
|
||||||
|
):
|
||||||
|
agent = TextBeliefExtractorAgent("text_belief_agent")
|
||||||
|
|
||||||
|
res = await agent._query_llm("hello world", {"type": "null"})
|
||||||
|
# Response content was set as "null", so should be deserialized as None
|
||||||
|
assert res is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_retry_query_llm_success(agent):
|
||||||
|
agent._query_llm.return_value = None
|
||||||
|
res = await agent._retry_query_llm("hello world", {"type": "null"})
|
||||||
|
|
||||||
|
agent._query_llm.assert_called_once()
|
||||||
|
assert res is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_retry_query_llm_success_after_failure(agent):
|
||||||
|
agent._query_llm.side_effect = [KeyError(), "real value"]
|
||||||
|
res = await agent._retry_query_llm("hello world", {"type": "string"})
|
||||||
|
|
||||||
|
assert agent._query_llm.call_count == 2
|
||||||
|
assert res == "real value"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_retry_query_llm_failures(agent):
|
||||||
|
agent._query_llm.side_effect = [KeyError(), KeyError(), KeyError(), "real value"]
|
||||||
|
res = await agent._retry_query_llm("hello world", {"type": "string"})
|
||||||
|
|
||||||
|
assert agent._query_llm.call_count == 3
|
||||||
|
assert res is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_retry_query_llm_fail_immediately(agent):
|
||||||
|
agent._query_llm.side_effect = [KeyError(), "real value"]
|
||||||
|
res = await agent._retry_query_llm("hello world", {"type": "string"}, tries=1)
|
||||||
|
|
||||||
|
assert agent._query_llm.call_count == 1
|
||||||
|
assert res is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_extracting_beliefs_from_program(agent, sample_program):
|
||||||
|
assert len(agent.available_beliefs) == 0
|
||||||
|
await agent.handle_message(
|
||||||
|
InternalMessage(
|
||||||
|
to=settings.agent_settings.text_belief_extractor_name,
|
||||||
|
sender=settings.agent_settings.bdi_program_manager_name,
|
||||||
|
body=sample_program.model_dump_json(),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
assert len(agent.available_beliefs) == 2
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_handle_invalid_program(agent, sample_program):
|
||||||
|
agent.available_beliefs.append(sample_program.phases[0].norms[0].condition)
|
||||||
|
agent.available_beliefs.append(sample_program.phases[0].triggers[0].condition)
|
||||||
|
assert len(agent.available_beliefs) == 2
|
||||||
|
|
||||||
|
await agent.handle_message(
|
||||||
|
InternalMessage(
|
||||||
|
to=settings.agent_settings.text_belief_extractor_name,
|
||||||
|
sender=settings.agent_settings.bdi_program_manager_name,
|
||||||
|
body=json.dumps({"phases": "Invalid"}),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(agent.available_beliefs) == 2
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_handle_robot_response(agent):
|
||||||
|
initial_length = len(agent.conversation.messages)
|
||||||
|
response = "Hi, I'm Pepper. What's your name?"
|
||||||
|
|
||||||
|
await agent.handle_message(
|
||||||
|
InternalMessage(
|
||||||
|
to=settings.agent_settings.text_belief_extractor_name,
|
||||||
|
sender=settings.agent_settings.llm_name,
|
||||||
|
body=response,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(agent.conversation.messages) == initial_length + 1
|
||||||
|
assert agent.conversation.messages[-1].role == "assistant"
|
||||||
|
assert agent.conversation.messages[-1].content == response
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_simulated_real_turn_with_beliefs(agent, sample_program):
|
||||||
|
"""Test sending user message to extract beliefs from."""
|
||||||
|
agent.available_beliefs.append(sample_program.phases[0].norms[0].condition)
|
||||||
|
agent.available_beliefs.append(sample_program.phases[0].triggers[0].condition)
|
||||||
|
|
||||||
|
# Send a user message with the belief that there's no more booze
|
||||||
|
agent._query_llm.return_value = {"is_pirate": None, "no_more_booze": True}
|
||||||
|
assert len(agent.conversation.messages) == 0
|
||||||
|
await agent.handle_message(
|
||||||
|
InternalMessage(
|
||||||
|
to=settings.agent_settings.text_belief_extractor_name,
|
||||||
|
sender=settings.agent_settings.transcription_name,
|
||||||
|
body="We're all out of schnaps.",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
assert len(agent.conversation.messages) == 1
|
||||||
|
|
||||||
|
# There should be a belief set and sent to the BDI core, as well as the user_said belief
|
||||||
|
assert agent.send.call_count == 2
|
||||||
|
|
||||||
|
# First should be the beliefs message
|
||||||
|
message: InternalMessage = agent.send.call_args_list[0].args[0]
|
||||||
|
beliefs = BeliefMessage.model_validate_json(message.body)
|
||||||
|
assert len(beliefs.create) == 1
|
||||||
|
assert beliefs.create[0].name == "no_more_booze"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_simulated_real_turn_no_beliefs(agent, sample_program):
|
||||||
|
"""Test a user message to extract beliefs from, but no beliefs are formed."""
|
||||||
|
agent.available_beliefs.append(sample_program.phases[0].norms[0].condition)
|
||||||
|
agent.available_beliefs.append(sample_program.phases[0].triggers[0].condition)
|
||||||
|
|
||||||
|
# Send a user message with no new beliefs
|
||||||
|
agent._query_llm.return_value = {"is_pirate": None, "no_more_booze": None}
|
||||||
|
await agent.handle_message(
|
||||||
|
InternalMessage(
|
||||||
|
to=settings.agent_settings.text_belief_extractor_name,
|
||||||
|
sender=settings.agent_settings.transcription_name,
|
||||||
|
body="Hello there!",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Only the user_said belief should've been sent
|
||||||
|
agent.send.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_simulated_real_turn_no_new_beliefs(agent, sample_program):
|
||||||
|
"""
|
||||||
|
Test a user message to extract beliefs from, but no new beliefs are formed because they already
|
||||||
|
existed.
|
||||||
|
"""
|
||||||
|
agent.available_beliefs.append(sample_program.phases[0].norms[0].condition)
|
||||||
|
agent.available_beliefs.append(sample_program.phases[0].triggers[0].condition)
|
||||||
|
agent.beliefs["is_pirate"] = True
|
||||||
|
|
||||||
|
# Send a user message with the belief the user is a pirate, still
|
||||||
|
agent._query_llm.return_value = {"is_pirate": True, "no_more_booze": None}
|
||||||
|
await agent.handle_message(
|
||||||
|
InternalMessage(
|
||||||
|
to=settings.agent_settings.text_belief_extractor_name,
|
||||||
|
sender=settings.agent_settings.transcription_name,
|
||||||
|
body="Arr, nice to meet you, matey.",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Only the user_said belief should've been sent, as no beliefs have changed
|
||||||
|
agent.send.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_simulated_real_turn_remove_belief(agent, sample_program):
|
||||||
|
"""
|
||||||
|
Test a user message to extract beliefs from, but an existing belief is determined no longer to
|
||||||
|
hold.
|
||||||
|
"""
|
||||||
|
agent.available_beliefs.append(sample_program.phases[0].norms[0].condition)
|
||||||
|
agent.available_beliefs.append(sample_program.phases[0].triggers[0].condition)
|
||||||
|
agent.beliefs["no_more_booze"] = True
|
||||||
|
|
||||||
|
# Send a user message with the belief the user is a pirate, still
|
||||||
|
agent._query_llm.return_value = {"is_pirate": None, "no_more_booze": False}
|
||||||
|
await agent.handle_message(
|
||||||
|
InternalMessage(
|
||||||
|
to=settings.agent_settings.text_belief_extractor_name,
|
||||||
|
sender=settings.agent_settings.transcription_name,
|
||||||
|
body="I found an untouched barrel of wine!",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Both user_said and belief change should've been sent
|
||||||
|
assert agent.send.call_count == 2
|
||||||
|
|
||||||
|
# Agent's current beliefs should've changed
|
||||||
|
assert not agent.beliefs["no_more_booze"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_llm_failure_handling(agent, sample_program):
|
||||||
|
"""
|
||||||
|
Check that the agent handles failures gracefully without crashing.
|
||||||
|
"""
|
||||||
|
agent._query_llm.side_effect = httpx.HTTPError("")
|
||||||
|
agent.available_beliefs.append(sample_program.phases[0].norms[0].condition)
|
||||||
|
agent.available_beliefs.append(sample_program.phases[0].triggers[0].condition)
|
||||||
|
|
||||||
|
belief_changes = await agent._infer_turn()
|
||||||
|
|
||||||
|
assert len(belief_changes) == 0
|
||||||
@@ -1,58 +0,0 @@
|
|||||||
import json
|
|
||||||
from unittest.mock import AsyncMock
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from control_backend.agents.bdi import (
|
|
||||||
TextBeliefExtractorAgent,
|
|
||||||
)
|
|
||||||
from control_backend.core.agent_system import InternalMessage
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def agent():
|
|
||||||
agent = TextBeliefExtractorAgent("text_belief_agent")
|
|
||||||
agent.send = AsyncMock()
|
|
||||||
return agent
|
|
||||||
|
|
||||||
|
|
||||||
def make_msg(sender: str, body: str, thread: str | None = None) -> InternalMessage:
|
|
||||||
return InternalMessage(to="unused", sender=sender, body=body, thread=thread)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_handle_message_ignores_other_agents(agent):
|
|
||||||
msg = make_msg("unknown", "some data", None)
|
|
||||||
|
|
||||||
await agent.handle_message(msg)
|
|
||||||
|
|
||||||
agent.send.assert_not_called() # noqa # `agent.send` has no such property, but we mock it.
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_handle_message_from_transcriber(agent, mock_settings):
|
|
||||||
transcription = "hello world"
|
|
||||||
msg = make_msg(mock_settings.agent_settings.transcription_name, transcription, None)
|
|
||||||
|
|
||||||
await agent.handle_message(msg)
|
|
||||||
|
|
||||||
agent.send.assert_awaited_once() # noqa # `agent.send` has no such property, but we mock it.
|
|
||||||
sent: InternalMessage = agent.send.call_args.args[0] # noqa
|
|
||||||
assert sent.to == mock_settings.agent_settings.bdi_belief_collector_name
|
|
||||||
assert sent.thread == "beliefs"
|
|
||||||
parsed = json.loads(sent.body)
|
|
||||||
assert parsed == {"beliefs": {"user_said": [transcription]}, "type": "belief_extraction_text"}
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_process_user_said(agent, mock_settings):
|
|
||||||
transcription = "this is a test"
|
|
||||||
|
|
||||||
await agent._user_said(transcription)
|
|
||||||
|
|
||||||
agent.send.assert_awaited_once() # noqa # `agent.send` has no such property, but we mock it.
|
|
||||||
sent: InternalMessage = agent.send.call_args.args[0] # noqa
|
|
||||||
assert sent.to == mock_settings.agent_settings.bdi_belief_collector_name
|
|
||||||
assert sent.thread == "beliefs"
|
|
||||||
parsed = json.loads(sent.body)
|
|
||||||
assert parsed["beliefs"]["user_said"] == [transcription]
|
|
||||||
@@ -5,11 +5,15 @@ from pydantic import ValidationError
|
|||||||
|
|
||||||
from control_backend.schemas.program import (
|
from control_backend.schemas.program import (
|
||||||
BasicNorm,
|
BasicNorm,
|
||||||
|
ConditionalNorm,
|
||||||
Goal,
|
Goal,
|
||||||
|
InferredBelief,
|
||||||
KeywordBelief,
|
KeywordBelief,
|
||||||
|
LogicalOperator,
|
||||||
Phase,
|
Phase,
|
||||||
Plan,
|
Plan,
|
||||||
Program,
|
Program,
|
||||||
|
SemanticBelief,
|
||||||
Trigger,
|
Trigger,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -97,3 +101,104 @@ def test_invalid_program():
|
|||||||
bad = invalid_program()
|
bad = invalid_program()
|
||||||
with pytest.raises(ValidationError):
|
with pytest.raises(ValidationError):
|
||||||
Program.model_validate(bad)
|
Program.model_validate(bad)
|
||||||
|
|
||||||
|
|
||||||
|
def test_conditional_norm_parsing():
|
||||||
|
"""
|
||||||
|
Check that pydantic is able to preserve the type of the norm, that it doesn't lose its
|
||||||
|
"condition" field when serializing and deserializing.
|
||||||
|
"""
|
||||||
|
norm = ConditionalNorm(
|
||||||
|
name="testNormName",
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
norm="testNormNorm",
|
||||||
|
critical=False,
|
||||||
|
condition=KeywordBelief(
|
||||||
|
name="testKeywordBelief",
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
keyword="testKeywordBelief",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
program = Program(
|
||||||
|
phases=[
|
||||||
|
Phase(
|
||||||
|
name="Some phase",
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
norms=[norm],
|
||||||
|
goals=[],
|
||||||
|
triggers=[],
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
parsed_program = Program.model_validate_json(program.model_dump_json())
|
||||||
|
parsed_norm = parsed_program.phases[0].norms[0]
|
||||||
|
|
||||||
|
assert hasattr(parsed_norm, "condition")
|
||||||
|
assert isinstance(parsed_norm, ConditionalNorm)
|
||||||
|
|
||||||
|
|
||||||
|
def test_belief_type_parsing():
|
||||||
|
"""
|
||||||
|
Check that pydantic is able to discern between the different types of beliefs when serializing
|
||||||
|
and deserializing.
|
||||||
|
"""
|
||||||
|
keyword_belief = KeywordBelief(
|
||||||
|
name="testKeywordBelief",
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
keyword="something",
|
||||||
|
)
|
||||||
|
semantic_belief = SemanticBelief(
|
||||||
|
name="testSemanticBelief",
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
description="something",
|
||||||
|
)
|
||||||
|
inferred_belief = InferredBelief(
|
||||||
|
name="testInferredBelief",
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
operator=LogicalOperator.OR,
|
||||||
|
left=keyword_belief,
|
||||||
|
right=semantic_belief,
|
||||||
|
)
|
||||||
|
|
||||||
|
program = Program(
|
||||||
|
phases=[
|
||||||
|
Phase(
|
||||||
|
name="Some phase",
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
norms=[],
|
||||||
|
goals=[],
|
||||||
|
triggers=[
|
||||||
|
Trigger(
|
||||||
|
name="testTriggerKeywordTrigger",
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
condition=keyword_belief,
|
||||||
|
plan=Plan(name="testTriggerPlanName", id=uuid.uuid4(), steps=[]),
|
||||||
|
),
|
||||||
|
Trigger(
|
||||||
|
name="testTriggerSemanticTrigger",
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
condition=semantic_belief,
|
||||||
|
plan=Plan(name="testTriggerPlanName", id=uuid.uuid4(), steps=[]),
|
||||||
|
),
|
||||||
|
Trigger(
|
||||||
|
name="testTriggerInferredTrigger",
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
condition=inferred_belief,
|
||||||
|
plan=Plan(name="testTriggerPlanName", id=uuid.uuid4(), steps=[]),
|
||||||
|
),
|
||||||
|
],
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
parsed_program = Program.model_validate_json(program.model_dump_json())
|
||||||
|
|
||||||
|
parsed_keyword_belief = parsed_program.phases[0].triggers[0].condition
|
||||||
|
assert isinstance(parsed_keyword_belief, KeywordBelief)
|
||||||
|
|
||||||
|
parsed_semantic_belief = parsed_program.phases[0].triggers[1].condition
|
||||||
|
assert isinstance(parsed_semantic_belief, SemanticBelief)
|
||||||
|
|
||||||
|
parsed_inferred_belief = parsed_program.phases[0].triggers[2].condition
|
||||||
|
assert isinstance(parsed_inferred_belief, InferredBelief)
|
||||||
|
|||||||
Reference in New Issue
Block a user