377 lines
13 KiB
Python
377 lines
13 KiB
Python
import json
|
|
import uuid
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
import httpx
|
|
import pytest
|
|
|
|
from control_backend.agents.bdi import TextBeliefExtractorAgent
|
|
from control_backend.agents.bdi.text_belief_extractor_agent import BeliefState
|
|
from control_backend.core.agent_system import InternalMessage
|
|
from control_backend.core.config import settings
|
|
from control_backend.schemas.belief_list import BeliefList
|
|
from control_backend.schemas.belief_message import Belief as InternalBelief
|
|
from control_backend.schemas.belief_message import BeliefMessage
|
|
from control_backend.schemas.chat_history import ChatHistory, ChatMessage
|
|
from control_backend.schemas.program import (
|
|
ConditionalNorm,
|
|
KeywordBelief,
|
|
LLMAction,
|
|
Phase,
|
|
Plan,
|
|
Program,
|
|
SemanticBelief,
|
|
Trigger,
|
|
)
|
|
|
|
|
|
@pytest.fixture
def llm():
    """Provide a ``TextBeliefExtractorAgent.LLM`` whose ``_query_llm`` is an ``AsyncMock``."""
    stubbed_llm = TextBeliefExtractorAgent.LLM(MagicMock(), 4)
    # Swap out the low-level query so each test decides what it returns or raises.
    stubbed_llm._query_llm = AsyncMock()
    return stubbed_llm
|
|
|
|
|
|
@pytest.fixture
def agent(llm):
    """Provide a ``TextBeliefExtractorAgent`` constructed while ``LLM`` is patched to yield ``llm``."""
    llm_class_path = (
        "control_backend.agents.bdi.text_belief_extractor_agent.TextBeliefExtractorAgent.LLM"
    )
    with patch(llm_class_path, return_value=llm):
        extractor = TextBeliefExtractorAgent("text_belief_agent")
    # Mock sending so tests can inspect (or rule out) outgoing messages.
    extractor.send = AsyncMock()
    return extractor
|
|
|
|
|
|
@pytest.fixture
def sample_program():
    """
    Provide a single-phase ``Program`` for the belief-extraction tests.

    The phase has one conditional norm (condition: the ``is_pirate`` semantic belief), one trigger
    (condition: the ``no_more_booze`` semantic belief) with a one-step plan, and no goals.
    """
    is_pirate = SemanticBelief(
        name="is_pirate",
        id=uuid.uuid4(),
        description="The user is a pirate. Perhaps because they say "
        "they are, or because they speak like a pirate "
        'with terms like "arr".',
    )
    nautical_norm = ConditionalNorm(
        name="Some norm",
        id=uuid.uuid4(),
        norm="Use nautical terms.",
        critical=False,
        condition=is_pirate,
    )

    no_more_booze = SemanticBelief(
        name="no_more_booze",
        id=uuid.uuid4(),
        description="There is no more alcohol.",
    )
    chocolate_action = LLMAction(
        name="Some action",
        id=uuid.uuid4(),
        goal="Suggest eating chocolate instead.",
    )
    chocolate_plan = Plan(name="Some plan", id=uuid.uuid4(), steps=[chocolate_action])
    booze_trigger = Trigger(
        name="Some trigger",
        id=uuid.uuid4(),
        condition=no_more_booze,
        plan=chocolate_plan,
    )

    only_phase = Phase(
        name="Some phase",
        id=uuid.uuid4(),
        norms=[nautical_norm],
        goals=[],
        triggers=[booze_trigger],
    )
    return Program(phases=[only_phase])
|
|
|
|
|
|
def make_msg(sender: str, body: str, thread: str | None = None) -> InternalMessage:
    """Build an ``InternalMessage`` from ``sender`` addressed to a placeholder recipient."""
    fields = {"to": "unused", "sender": sender, "body": body, "thread": thread}
    return InternalMessage(**fields)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_handle_message_ignores_other_agents(agent):
    """A message from an unrecognised sender must not result in anything being sent."""
    await agent.handle_message(make_msg("unknown", "some data", None))

    # `agent.send` is the AsyncMock installed by the fixture, so the mock assertion exists.
    agent.send.assert_not_called()  # noqa
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_handle_message_from_transcriber(agent, mock_settings):
    """A transcription is forwarded to the BDI core as a replacing ``user_said`` belief."""
    agent_names = mock_settings.agent_settings
    transcription = "hello world"
    incoming = make_msg(agent_names.transcription_name, transcription, None)

    await agent.handle_message(incoming)

    # `agent.send` is the AsyncMock installed by the fixture, so the mock assertion exists.
    agent.send.assert_awaited_once()  # noqa
    sent: InternalMessage = agent.send.call_args.args[0]  # noqa
    assert sent.thread == "beliefs"
    assert sent.to == agent_names.bdi_core_name

    # The last entry of the `replace` list should carry the transcription.
    belief_message = BeliefMessage.model_validate_json(sent.body)
    last_replaced = belief_message.replace[-1]
    assert last_replaced.name == "user_said"
    assert last_replaced.arguments == [transcription]
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_query_llm():
    """With the HTTP client mocked, a response whose content is ``"null"`` yields ``None``."""
    http_response = MagicMock()
    http_response.json.return_value = {"choices": [{"message": {"content": "null"}}]}

    http_client = AsyncMock()
    http_client.post.return_value = http_response

    # Stand-in for the object `httpx.AsyncClient(...)` returns; entering it gives `http_client`.
    client_context = MagicMock()
    client_context.__aenter__.return_value = http_client
    client_context.__aexit__.return_value = None

    client_path = "control_backend.agents.bdi.text_belief_extractor_agent.httpx.AsyncClient"
    with patch(client_path, return_value=client_context):
        llm = TextBeliefExtractorAgent.LLM(MagicMock(), 4)
        res = await llm._query_llm("hello world", {"type": "null"})

    # The response content was the JSON text "null", which deserialises to None.
    assert res is None
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_retry_query_llm_success(llm):
    """``query`` calls ``_query_llm`` exactly once when the first attempt succeeds."""
    llm._query_llm.return_value = None

    outcome = await llm.query("hello world", {"type": "null"})

    assert outcome is None
    llm._query_llm.assert_called_once()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_retry_query_llm_success_after_failure(llm):
    """``query`` retries after a ``KeyError`` and returns the value of the second attempt."""
    attempts = [KeyError(), "real value"]
    llm._query_llm.side_effect = attempts

    outcome = await llm.query("hello world", {"type": "string"})

    assert outcome == "real value"
    assert llm._query_llm.call_count == 2
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_retry_query_llm_failures(llm):
    """``query`` gives up after three failing attempts and returns ``None``."""
    # A fourth attempt would succeed, but it must never be made.
    attempts = [KeyError(), KeyError(), KeyError(), "real value"]
    llm._query_llm.side_effect = attempts

    outcome = await llm.query("hello world", {"type": "string"})

    assert outcome is None
    assert llm._query_llm.call_count == 3
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_retry_query_llm_fail_immediately(llm):
    """With ``tries=1``, ``query`` does not retry after a failure and returns ``None``."""
    # The second attempt would succeed, but `tries=1` must prevent it.
    attempts = [KeyError(), "real value"]
    llm._query_llm.side_effect = attempts

    outcome = await llm.query("hello world", {"type": "string"}, tries=1)

    assert outcome is None
    assert llm._query_llm.call_count == 1
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_extracting_semantic_beliefs(agent):
    """
    The Program Manager sends a belief list to this agent; check how it is stored.

    One keyword belief and two semantic beliefs are sent, and two beliefs should be available
    afterwards (presumably the semantic ones, going by the test name).
    """
    assert len(agent.belief_inferrer.available_beliefs) == 0

    keyword_belief = KeywordBelief(id=uuid.uuid4(), name="keyword_hello", keyword="hello")
    semantic_beliefs = [
        SemanticBelief(
            id=uuid.uuid4(),
            name=f"semantic_hello_{number}",
            description=f"Some semantic belief {number}",
        )
        for number in (1, 2)
    ]
    belief_list = BeliefList(beliefs=[keyword_belief, *semantic_beliefs])

    program_manager_message = InternalMessage(
        to=settings.agent_settings.text_belief_extractor_name,
        sender=settings.agent_settings.bdi_program_manager_name,
        body=belief_list.model_dump_json(),
        thread="beliefs",
    )
    await agent.handle_message(program_manager_message)

    assert len(agent.belief_inferrer.available_beliefs) == 2
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_handle_invalid_beliefs(agent, sample_program):
    """A belief message with an invalid body leaves the available beliefs untouched."""
    phase = sample_program.phases[0]
    for condition in (phase.norms[0].condition, phase.triggers[0].condition):
        agent.belief_inferrer.available_beliefs.append(condition)
    assert len(agent.belief_inferrer.available_beliefs) == 2

    invalid_message = InternalMessage(
        to=settings.agent_settings.text_belief_extractor_name,
        sender=settings.agent_settings.bdi_program_manager_name,
        body=json.dumps({"phases": "Invalid"}),
        thread="beliefs",
    )
    await agent.handle_message(invalid_message)

    # Still the same two beliefs that were registered before the invalid message.
    assert len(agent.belief_inferrer.available_beliefs) == 2
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_handle_robot_response(agent):
    """A message from the LLM agent is appended to the conversation as an assistant message."""
    response = "Hi, I'm Pepper. What's your name?"
    messages_before = len(agent.conversation.messages)

    robot_message = InternalMessage(
        to=settings.agent_settings.text_belief_extractor_name,
        sender=settings.agent_settings.llm_name,
        body=response,
    )
    await agent.handle_message(robot_message)

    assert len(agent.conversation.messages) == messages_before + 1
    newest = agent.conversation.messages[-1]
    assert newest.role == "assistant"
    assert newest.content == response
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_simulated_real_turn_with_beliefs(agent, llm, sample_program):
    """Test sending a user message from which a new belief is extracted."""
    phase = sample_program.phases[0]
    for condition in (phase.norms[0].condition, phase.triggers[0].condition):
        agent.belief_inferrer.available_beliefs.append(condition)

    # The mocked LLM reports that there's no more booze and is undecided about piracy.
    llm._query_llm.return_value = {"is_pirate": None, "no_more_booze": True}
    assert len(agent.conversation.messages) == 0

    user_message = InternalMessage(
        to=settings.agent_settings.text_belief_extractor_name,
        sender=settings.agent_settings.transcription_name,
        body="We're all out of schnaps.",
    )
    await agent.handle_message(user_message)
    assert len(agent.conversation.messages) == 1

    # Two messages go out: the user_said belief and the newly inferred belief.
    assert agent.send.call_count == 2

    # The second send call carries the inferred beliefs.
    inferred: InternalMessage = agent.send.call_args_list[1].args[0]
    belief_message = BeliefMessage.model_validate_json(inferred.body)
    assert len(belief_message.create) == 1
    assert belief_message.create[0].name == "no_more_booze"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_simulated_real_turn_no_beliefs(agent, llm, sample_program):
    """Test a user message from which no beliefs are formed."""
    phase = sample_program.phases[0]
    for condition in (phase.norms[0].condition, phase.triggers[0].condition):
        agent.belief_inferrer.available_beliefs.append(condition)

    # The mocked LLM is undecided about every belief.
    llm._query_llm.return_value = {"is_pirate": None, "no_more_booze": None}

    user_message = InternalMessage(
        to=settings.agent_settings.text_belief_extractor_name,
        sender=settings.agent_settings.transcription_name,
        body="Hello there!",
    )
    await agent.handle_message(user_message)

    # Nothing but the user_said belief should have gone out.
    agent.send.assert_called_once()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_simulated_real_turn_no_new_beliefs(agent, llm, sample_program):
    """
    Test a user message whose inferred beliefs were already held, so that no belief update is
    sent.
    """
    phase = sample_program.phases[0]
    for condition in (phase.norms[0].condition, phase.triggers[0].condition):
        agent.belief_inferrer.available_beliefs.append(condition)

    # The agent already believes the user is a pirate.
    already_pirate = InternalBelief(name="is_pirate", arguments=None)
    agent._current_beliefs = BeliefState(true={already_pirate})

    # The mocked LLM confirms the user is (still) a pirate.
    llm._query_llm.return_value = {"is_pirate": True, "no_more_booze": None}

    user_message = InternalMessage(
        to=settings.agent_settings.text_belief_extractor_name,
        sender=settings.agent_settings.transcription_name,
        body="Arr, nice to meet you, matey.",
    )
    await agent.handle_message(user_message)

    # No beliefs changed, so nothing but the user_said belief should have gone out.
    agent.send.assert_called_once()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_simulated_real_turn_remove_belief(agent, llm, sample_program):
    """
    Test a user message after which a previously held belief is determined to no longer hold.
    """
    phase = sample_program.phases[0]
    for condition in (phase.norms[0].condition, phase.triggers[0].condition):
        agent.belief_inferrer.available_beliefs.append(condition)

    # The agent starts out believing there is no more booze.
    out_of_booze = InternalBelief(name="no_more_booze", arguments=None)
    agent._current_beliefs = BeliefState(true={out_of_booze})

    # The mocked LLM reports that the no_more_booze belief is now false.
    llm._query_llm.return_value = {"is_pirate": None, "no_more_booze": False}

    user_message = InternalMessage(
        to=settings.agent_settings.text_belief_extractor_name,
        sender=settings.agent_settings.transcription_name,
        body="I found an untouched barrel of wine!",
    )
    await agent.handle_message(user_message)

    # Both the user_said belief and the belief change should have gone out.
    assert agent.send.call_count == 2

    # The agent's own belief state should now list no_more_booze as false.
    assert any(belief.name == "no_more_booze" for belief in agent._current_beliefs.false)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_llm_failure_handling(agent, llm, sample_program):
    """
    Check that an HTTP error from the LLM yields no belief changes instead of a crash.
    """
    llm._query_llm.side_effect = httpx.HTTPError("")
    phase = sample_program.phases[0]
    for condition in (phase.norms[0].condition, phase.triggers[0].condition):
        agent.belief_inferrer.available_beliefs.append(condition)

    conversation = ChatHistory(messages=[ChatMessage(role="user", content="Good day!")])
    belief_changes = await agent.belief_inferrer.infer_from_conversation(conversation)

    # Neither new true beliefs nor new false beliefs may be reported.
    assert len(belief_changes.true) == 0
    assert len(belief_changes.false) == 0
|