test: create tests for belief extractor agent

Includes changes in schemas. Change type of `norms` in `Program` imperceptibly, big changes in schema of `BeliefMessage` to support deleting beliefs.

ref: N25B-380
This commit is contained in:
Twirre Meulenbelt
2025-12-29 17:12:02 +01:00
parent 57b1276cb5
commit 42ee5c76d8
10 changed files with 530 additions and 92 deletions

View File

@@ -51,7 +51,7 @@ async def test_handle_belief_collector_message(agent, mock_settings):
msg = InternalMessage(
to="bdi_agent",
sender=mock_settings.agent_settings.bdi_belief_collector_name,
body=BeliefMessage(beliefs=beliefs).model_dump_json(),
body=BeliefMessage(create=beliefs).model_dump_json(),
thread="beliefs",
)
@@ -64,6 +64,26 @@ async def test_handle_belief_collector_message(agent, mock_settings):
assert args[2] == agentspeak.Literal("user_said", (agentspeak.Literal("Hello"),))
@pytest.mark.asyncio
async def test_handle_delete_belief_message(agent, mock_settings):
"""Test that incoming beliefs to be deleted are removed from the BDI agent"""
beliefs = [Belief(name="user_said", arguments=["Hello"])]
msg = InternalMessage(
to="bdi_agent",
sender=mock_settings.agent_settings.bdi_belief_collector_name,
body=BeliefMessage(delete=beliefs).model_dump_json(),
thread="beliefs",
)
await agent.handle_message(msg)
# Expect bdi_agent.call to be triggered to remove belief
args = agent.bdi_agent.call.call_args.args
assert args[0] == agentspeak.Trigger.removal
assert args[1] == agentspeak.GoalType.belief
assert args[2] == agentspeak.Literal("user_said", (agentspeak.Literal("Hello"),))
@pytest.mark.asyncio
async def test_incorrect_belief_collector_message(agent, mock_settings):
"""Test that incorrect message format triggers an exception."""
@@ -128,7 +148,8 @@ def test_add_belief_sets_event(agent):
agent._wake_bdi_loop = MagicMock()
belief = Belief(name="test_belief", arguments=["a", "b"])
agent._apply_beliefs([belief])
belief_changes = BeliefMessage(replace=[belief])
agent._apply_belief_changes(belief_changes)
assert agent.bdi_agent.call.called
agent._wake_bdi_loop.set.assert_called()
@@ -137,7 +158,7 @@ def test_add_belief_sets_event(agent):
def test_apply_beliefs_empty_returns(agent):
"""Line: if not beliefs: return"""
agent._wake_bdi_loop = MagicMock()
agent._apply_beliefs([])
agent._apply_belief_changes(BeliefMessage())
agent.bdi_agent.call.assert_not_called()
agent._wake_bdi_loop.set.assert_not_called()
@@ -220,8 +241,9 @@ def test_replace_belief_calls_remove_all(agent):
agent._remove_all_with_name = MagicMock()
agent._wake_bdi_loop = MagicMock()
belief = Belief(name="user_said", arguments=["Hello"], replace=True)
agent._apply_beliefs([belief])
belief = Belief(name="user_said", arguments=["Hello"])
belief_changes = BeliefMessage(replace=[belief])
agent._apply_belief_changes(belief_changes)
agent._remove_all_with_name.assert_called_with("user_said")

View File

@@ -86,7 +86,7 @@ async def test_send_beliefs_to_bdi(agent):
sent: InternalMessage = agent.send.call_args.args[0]
assert sent.to == settings.agent_settings.bdi_core_name
assert sent.thread == "beliefs"
assert json.loads(sent.body)["beliefs"] == [belief.model_dump() for belief in beliefs]
assert json.loads(sent.body)["create"] == [belief.model_dump() for belief in beliefs]
@pytest.mark.asyncio

View File

@@ -0,0 +1,346 @@
import json
import uuid
from unittest.mock import AsyncMock, MagicMock, patch
import httpx
import pytest
from control_backend.agents.bdi import TextBeliefExtractorAgent
from control_backend.core.agent_system import InternalMessage
from control_backend.core.config import settings
from control_backend.schemas.belief_message import BeliefMessage
from control_backend.schemas.program import (
ConditionalNorm,
LLMAction,
Phase,
Plan,
Program,
SemanticBelief,
Trigger,
)
@pytest.fixture
def agent():
agent = TextBeliefExtractorAgent("text_belief_agent")
agent.send = AsyncMock()
agent._query_llm = AsyncMock()
return agent
@pytest.fixture
def sample_program():
return Program(
phases=[
Phase(
name="Some phase",
id=uuid.uuid4(),
norms=[
ConditionalNorm(
name="Some norm",
id=uuid.uuid4(),
norm="Use nautical terms.",
critical=False,
condition=SemanticBelief(
name="is_pirate",
id=uuid.uuid4(),
description="The user is a pirate. Perhaps because they say "
"they are, or because they speak like a pirate "
'with terms like "arr".',
),
),
],
goals=[],
triggers=[
Trigger(
name="Some trigger",
id=uuid.uuid4(),
condition=SemanticBelief(
name="no_more_booze",
id=uuid.uuid4(),
description="There is no more alcohol.",
),
plan=Plan(
name="Some plan",
id=uuid.uuid4(),
steps=[
LLMAction(
name="Some action",
id=uuid.uuid4(),
goal="Suggest eating chocolate instead.",
),
],
),
),
],
),
],
)
def make_msg(sender: str, body: str, thread: str | None = None) -> InternalMessage:
return InternalMessage(to="unused", sender=sender, body=body, thread=thread)
@pytest.mark.asyncio
async def test_handle_message_ignores_other_agents(agent):
msg = make_msg("unknown", "some data", None)
await agent.handle_message(msg)
agent.send.assert_not_called() # noqa # `agent.send` has no such property, but we mock it.
@pytest.mark.asyncio
async def test_handle_message_from_transcriber(agent, mock_settings):
transcription = "hello world"
msg = make_msg(mock_settings.agent_settings.transcription_name, transcription, None)
await agent.handle_message(msg)
agent.send.assert_awaited_once() # noqa # `agent.send` has no such property, but we mock it.
sent: InternalMessage = agent.send.call_args.args[0] # noqa
assert sent.to == mock_settings.agent_settings.bdi_belief_collector_name
assert sent.thread == "beliefs"
parsed = json.loads(sent.body)
assert parsed == {"beliefs": {"user_said": [transcription]}, "type": "belief_extraction_text"}
@pytest.mark.asyncio
async def test_process_user_said(agent, mock_settings):
transcription = "this is a test"
await agent._user_said(transcription)
agent.send.assert_awaited_once() # noqa # `agent.send` has no such property, but we mock it.
sent: InternalMessage = agent.send.call_args.args[0] # noqa
assert sent.to == mock_settings.agent_settings.bdi_belief_collector_name
assert sent.thread == "beliefs"
parsed = json.loads(sent.body)
assert parsed["beliefs"]["user_said"] == [transcription]
@pytest.mark.asyncio
async def test_query_llm():
mock_response = MagicMock()
mock_response.json.return_value = {
"choices": [
{
"message": {
"content": "null",
}
}
]
}
mock_client = AsyncMock()
mock_client.post.return_value = mock_response
mock_async_client = MagicMock()
mock_async_client.__aenter__.return_value = mock_client
mock_async_client.__aexit__.return_value = None
with patch(
"control_backend.agents.bdi.text_belief_extractor_agent.httpx.AsyncClient",
return_value=mock_async_client,
):
agent = TextBeliefExtractorAgent("text_belief_agent")
res = await agent._query_llm("hello world", {"type": "null"})
# Response content was set as "null", so should be deserialized as None
assert res is None
@pytest.mark.asyncio
async def test_retry_query_llm_success(agent):
agent._query_llm.return_value = None
res = await agent._retry_query_llm("hello world", {"type": "null"})
agent._query_llm.assert_called_once()
assert res is None
@pytest.mark.asyncio
async def test_retry_query_llm_success_after_failure(agent):
agent._query_llm.side_effect = [KeyError(), "real value"]
res = await agent._retry_query_llm("hello world", {"type": "string"})
assert agent._query_llm.call_count == 2
assert res == "real value"
@pytest.mark.asyncio
async def test_retry_query_llm_failures(agent):
agent._query_llm.side_effect = [KeyError(), KeyError(), KeyError(), "real value"]
res = await agent._retry_query_llm("hello world", {"type": "string"})
assert agent._query_llm.call_count == 3
assert res is None
@pytest.mark.asyncio
async def test_retry_query_llm_fail_immediately(agent):
agent._query_llm.side_effect = [KeyError(), "real value"]
res = await agent._retry_query_llm("hello world", {"type": "string"}, tries=1)
assert agent._query_llm.call_count == 1
assert res is None
@pytest.mark.asyncio
async def test_extracting_beliefs_from_program(agent, sample_program):
assert len(agent.available_beliefs) == 0
await agent.handle_message(
InternalMessage(
to=settings.agent_settings.text_belief_extractor_name,
sender=settings.agent_settings.bdi_program_manager_name,
body=sample_program.model_dump_json(),
),
)
assert len(agent.available_beliefs) == 2
@pytest.mark.asyncio
async def test_handle_invalid_program(agent, sample_program):
agent.available_beliefs.append(sample_program.phases[0].norms[0].condition)
agent.available_beliefs.append(sample_program.phases[0].triggers[0].condition)
assert len(agent.available_beliefs) == 2
await agent.handle_message(
InternalMessage(
to=settings.agent_settings.text_belief_extractor_name,
sender=settings.agent_settings.bdi_program_manager_name,
body=json.dumps({"phases": "Invalid"}),
),
)
assert len(agent.available_beliefs) == 2
@pytest.mark.asyncio
async def test_handle_robot_response(agent):
initial_length = len(agent.conversation.messages)
response = "Hi, I'm Pepper. What's your name?"
await agent.handle_message(
InternalMessage(
to=settings.agent_settings.text_belief_extractor_name,
sender=settings.agent_settings.llm_name,
body=response,
),
)
assert len(agent.conversation.messages) == initial_length + 1
assert agent.conversation.messages[-1].role == "assistant"
assert agent.conversation.messages[-1].content == response
@pytest.mark.asyncio
async def test_simulated_real_turn_with_beliefs(agent, sample_program):
"""Test sending user message to extract beliefs from."""
agent.available_beliefs.append(sample_program.phases[0].norms[0].condition)
agent.available_beliefs.append(sample_program.phases[0].triggers[0].condition)
# Send a user message with the belief that there's no more booze
agent._query_llm.return_value = {"is_pirate": None, "no_more_booze": True}
assert len(agent.conversation.messages) == 0
await agent.handle_message(
InternalMessage(
to=settings.agent_settings.text_belief_extractor_name,
sender=settings.agent_settings.transcription_name,
body="We're all out of schnaps.",
),
)
assert len(agent.conversation.messages) == 1
# There should be a belief set and sent to the BDI core, as well as the user_said belief
assert agent.send.call_count == 2
# First should be the beliefs message
message: InternalMessage = agent.send.call_args_list[0].args[0]
beliefs = BeliefMessage.model_validate_json(message.body)
assert len(beliefs.create) == 1
assert beliefs.create[0].name == "no_more_booze"
@pytest.mark.asyncio
async def test_simulated_real_turn_no_beliefs(agent, sample_program):
"""Test a user message to extract beliefs from, but no beliefs are formed."""
agent.available_beliefs.append(sample_program.phases[0].norms[0].condition)
agent.available_beliefs.append(sample_program.phases[0].triggers[0].condition)
# Send a user message with no new beliefs
agent._query_llm.return_value = {"is_pirate": None, "no_more_booze": None}
await agent.handle_message(
InternalMessage(
to=settings.agent_settings.text_belief_extractor_name,
sender=settings.agent_settings.transcription_name,
body="Hello there!",
),
)
# Only the user_said belief should've been sent
agent.send.assert_called_once()
@pytest.mark.asyncio
async def test_simulated_real_turn_no_new_beliefs(agent, sample_program):
"""
Test a user message to extract beliefs from, but no new beliefs are formed because they already
existed.
"""
agent.available_beliefs.append(sample_program.phases[0].norms[0].condition)
agent.available_beliefs.append(sample_program.phases[0].triggers[0].condition)
agent.beliefs["is_pirate"] = True
# Send a user message with the belief the user is a pirate, still
agent._query_llm.return_value = {"is_pirate": True, "no_more_booze": None}
await agent.handle_message(
InternalMessage(
to=settings.agent_settings.text_belief_extractor_name,
sender=settings.agent_settings.transcription_name,
body="Arr, nice to meet you, matey.",
),
)
# Only the user_said belief should've been sent, as no beliefs have changed
agent.send.assert_called_once()
@pytest.mark.asyncio
async def test_simulated_real_turn_remove_belief(agent, sample_program):
"""
Test a user message to extract beliefs from, but an existing belief is determined no longer to
hold.
"""
agent.available_beliefs.append(sample_program.phases[0].norms[0].condition)
agent.available_beliefs.append(sample_program.phases[0].triggers[0].condition)
agent.beliefs["no_more_booze"] = True
# Send a user message with the belief the user is a pirate, still
agent._query_llm.return_value = {"is_pirate": None, "no_more_booze": False}
await agent.handle_message(
InternalMessage(
to=settings.agent_settings.text_belief_extractor_name,
sender=settings.agent_settings.transcription_name,
body="I found an untouched barrel of wine!",
),
)
# Both user_said and belief change should've been sent
assert agent.send.call_count == 2
# Agent's current beliefs should've changed
assert not agent.beliefs["no_more_booze"]
@pytest.mark.asyncio
async def test_llm_failure_handling(agent, sample_program):
"""
Check that the agent handles failures gracefully without crashing.
"""
agent._query_llm.side_effect = httpx.HTTPError("")
agent.available_beliefs.append(sample_program.phases[0].norms[0].condition)
agent.available_beliefs.append(sample_program.phases[0].triggers[0].condition)
belief_changes = await agent._infer_turn()
assert len(belief_changes) == 0

View File

@@ -1,58 +0,0 @@
import json
from unittest.mock import AsyncMock
import pytest
from control_backend.agents.bdi import (
TextBeliefExtractorAgent,
)
from control_backend.core.agent_system import InternalMessage
@pytest.fixture
def agent():
agent = TextBeliefExtractorAgent("text_belief_agent")
agent.send = AsyncMock()
return agent
def make_msg(sender: str, body: str, thread: str | None = None) -> InternalMessage:
return InternalMessage(to="unused", sender=sender, body=body, thread=thread)
@pytest.mark.asyncio
async def test_handle_message_ignores_other_agents(agent):
msg = make_msg("unknown", "some data", None)
await agent.handle_message(msg)
agent.send.assert_not_called() # noqa # `agent.send` has no such property, but we mock it.
@pytest.mark.asyncio
async def test_handle_message_from_transcriber(agent, mock_settings):
transcription = "hello world"
msg = make_msg(mock_settings.agent_settings.transcription_name, transcription, None)
await agent.handle_message(msg)
agent.send.assert_awaited_once() # noqa # `agent.send` has no such property, but we mock it.
sent: InternalMessage = agent.send.call_args.args[0] # noqa
assert sent.to == mock_settings.agent_settings.bdi_belief_collector_name
assert sent.thread == "beliefs"
parsed = json.loads(sent.body)
assert parsed == {"beliefs": {"user_said": [transcription]}, "type": "belief_extraction_text"}
@pytest.mark.asyncio
async def test_process_user_said(agent, mock_settings):
transcription = "this is a test"
await agent._user_said(transcription)
agent.send.assert_awaited_once() # noqa # `agent.send` has no such property, but we mock it.
sent: InternalMessage = agent.send.call_args.args[0] # noqa
assert sent.to == mock_settings.agent_settings.bdi_belief_collector_name
assert sent.thread == "beliefs"
parsed = json.loads(sent.body)
assert parsed["beliefs"]["user_said"] == [transcription]

View File

@@ -5,11 +5,15 @@ from pydantic import ValidationError
from control_backend.schemas.program import (
BasicNorm,
ConditionalNorm,
Goal,
InferredBelief,
KeywordBelief,
LogicalOperator,
Phase,
Plan,
Program,
SemanticBelief,
Trigger,
)
@@ -97,3 +101,104 @@ def test_invalid_program():
bad = invalid_program()
with pytest.raises(ValidationError):
Program.model_validate(bad)
def test_conditional_norm_parsing():
"""
Check that pydantic is able to preserve the type of the norm, that it doesn't lose its
"condition" field when serializing and deserializing.
"""
norm = ConditionalNorm(
name="testNormName",
id=uuid.uuid4(),
norm="testNormNorm",
critical=False,
condition=KeywordBelief(
name="testKeywordBelief",
id=uuid.uuid4(),
keyword="testKeywordBelief",
),
)
program = Program(
phases=[
Phase(
name="Some phase",
id=uuid.uuid4(),
norms=[norm],
goals=[],
triggers=[],
),
],
)
parsed_program = Program.model_validate_json(program.model_dump_json())
parsed_norm = parsed_program.phases[0].norms[0]
assert hasattr(parsed_norm, "condition")
assert isinstance(parsed_norm, ConditionalNorm)
def test_belief_type_parsing():
"""
Check that pydantic is able to discern between the different types of beliefs when serializing
and deserializing.
"""
keyword_belief = KeywordBelief(
name="testKeywordBelief",
id=uuid.uuid4(),
keyword="something",
)
semantic_belief = SemanticBelief(
name="testSemanticBelief",
id=uuid.uuid4(),
description="something",
)
inferred_belief = InferredBelief(
name="testInferredBelief",
id=uuid.uuid4(),
operator=LogicalOperator.OR,
left=keyword_belief,
right=semantic_belief,
)
program = Program(
phases=[
Phase(
name="Some phase",
id=uuid.uuid4(),
norms=[],
goals=[],
triggers=[
Trigger(
name="testTriggerKeywordTrigger",
id=uuid.uuid4(),
condition=keyword_belief,
plan=Plan(name="testTriggerPlanName", id=uuid.uuid4(), steps=[]),
),
Trigger(
name="testTriggerSemanticTrigger",
id=uuid.uuid4(),
condition=semantic_belief,
plan=Plan(name="testTriggerPlanName", id=uuid.uuid4(), steps=[]),
),
Trigger(
name="testTriggerInferredTrigger",
id=uuid.uuid4(),
condition=inferred_belief,
plan=Plan(name="testTriggerPlanName", id=uuid.uuid4(), steps=[]),
),
],
),
],
)
parsed_program = Program.model_validate_json(program.model_dump_json())
parsed_keyword_belief = parsed_program.phases[0].triggers[0].condition
assert isinstance(parsed_keyword_belief, KeywordBelief)
parsed_semantic_belief = parsed_program.phases[0].triggers[1].condition
assert isinstance(parsed_semantic_belief, SemanticBelief)
parsed_inferred_belief = parsed_program.phases[0].triggers[2].condition
assert isinstance(parsed_inferred_belief, InferredBelief)