|
|
|
|
@@ -6,10 +6,13 @@ import httpx
|
|
|
|
|
import pytest
|
|
|
|
|
|
|
|
|
|
from control_backend.agents.bdi import TextBeliefExtractorAgent
|
|
|
|
|
from control_backend.agents.bdi.text_belief_extractor_agent import BeliefState
|
|
|
|
|
from control_backend.core.agent_system import InternalMessage
|
|
|
|
|
from control_backend.core.config import settings
|
|
|
|
|
from control_backend.schemas.belief_list import BeliefList
|
|
|
|
|
from control_backend.schemas.belief_message import Belief as InternalBelief
|
|
|
|
|
from control_backend.schemas.belief_message import BeliefMessage
|
|
|
|
|
from control_backend.schemas.chat_history import ChatHistory, ChatMessage
|
|
|
|
|
from control_backend.schemas.program import (
|
|
|
|
|
ConditionalNorm,
|
|
|
|
|
KeywordBelief,
|
|
|
|
|
@@ -23,11 +26,21 @@ from control_backend.schemas.program import (
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.fixture
|
|
|
|
|
def agent():
|
|
|
|
|
agent = TextBeliefExtractorAgent("text_belief_agent")
|
|
|
|
|
agent.send = AsyncMock()
|
|
|
|
|
agent._query_llm = AsyncMock()
|
|
|
|
|
return agent
|
|
|
|
|
def llm():
|
|
|
|
|
llm = TextBeliefExtractorAgent.LLM(MagicMock(), 4)
|
|
|
|
|
llm._query_llm = AsyncMock()
|
|
|
|
|
return llm
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.fixture
|
|
|
|
|
def agent(llm):
|
|
|
|
|
with patch(
|
|
|
|
|
"control_backend.agents.bdi.text_belief_extractor_agent.TextBeliefExtractorAgent.LLM",
|
|
|
|
|
return_value=llm,
|
|
|
|
|
):
|
|
|
|
|
agent = TextBeliefExtractorAgent("text_belief_agent")
|
|
|
|
|
agent.send = AsyncMock()
|
|
|
|
|
return agent
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.fixture
|
|
|
|
|
@@ -102,24 +115,12 @@ async def test_handle_message_from_transcriber(agent, mock_settings):
|
|
|
|
|
|
|
|
|
|
agent.send.assert_awaited_once() # noqa # `agent.send` has no such property, but we mock it.
|
|
|
|
|
sent: InternalMessage = agent.send.call_args.args[0] # noqa
|
|
|
|
|
assert sent.to == mock_settings.agent_settings.bdi_belief_collector_name
|
|
|
|
|
assert sent.to == mock_settings.agent_settings.bdi_core_name
|
|
|
|
|
assert sent.thread == "beliefs"
|
|
|
|
|
parsed = json.loads(sent.body)
|
|
|
|
|
assert parsed == {"beliefs": {"user_said": [transcription]}, "type": "belief_extraction_text"}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
|
|
|
async def test_process_user_said(agent, mock_settings):
|
|
|
|
|
transcription = "this is a test"
|
|
|
|
|
|
|
|
|
|
await agent._user_said(transcription)
|
|
|
|
|
|
|
|
|
|
agent.send.assert_awaited_once() # noqa # `agent.send` has no such property, but we mock it.
|
|
|
|
|
sent: InternalMessage = agent.send.call_args.args[0] # noqa
|
|
|
|
|
assert sent.to == mock_settings.agent_settings.bdi_belief_collector_name
|
|
|
|
|
assert sent.thread == "beliefs"
|
|
|
|
|
parsed = json.loads(sent.body)
|
|
|
|
|
assert parsed["beliefs"]["user_said"] == [transcription]
|
|
|
|
|
parsed = BeliefMessage.model_validate_json(sent.body)
|
|
|
|
|
replaced_last = parsed.replace.pop()
|
|
|
|
|
assert replaced_last.name == "user_said"
|
|
|
|
|
assert replaced_last.arguments == [transcription]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
|
|
|
@@ -144,46 +145,46 @@ async def test_query_llm():
|
|
|
|
|
"control_backend.agents.bdi.text_belief_extractor_agent.httpx.AsyncClient",
|
|
|
|
|
return_value=mock_async_client,
|
|
|
|
|
):
|
|
|
|
|
agent = TextBeliefExtractorAgent("text_belief_agent")
|
|
|
|
|
llm = TextBeliefExtractorAgent.LLM(MagicMock(), 4)
|
|
|
|
|
|
|
|
|
|
res = await agent._query_llm("hello world", {"type": "null"})
|
|
|
|
|
res = await llm._query_llm("hello world", {"type": "null"})
|
|
|
|
|
# Response content was set as "null", so should be deserialized as None
|
|
|
|
|
assert res is None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
|
|
|
async def test_retry_query_llm_success(agent):
|
|
|
|
|
agent._query_llm.return_value = None
|
|
|
|
|
res = await agent._retry_query_llm("hello world", {"type": "null"})
|
|
|
|
|
async def test_retry_query_llm_success(llm):
|
|
|
|
|
llm._query_llm.return_value = None
|
|
|
|
|
res = await llm.query("hello world", {"type": "null"})
|
|
|
|
|
|
|
|
|
|
agent._query_llm.assert_called_once()
|
|
|
|
|
llm._query_llm.assert_called_once()
|
|
|
|
|
assert res is None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
|
|
|
async def test_retry_query_llm_success_after_failure(agent):
|
|
|
|
|
agent._query_llm.side_effect = [KeyError(), "real value"]
|
|
|
|
|
res = await agent._retry_query_llm("hello world", {"type": "string"})
|
|
|
|
|
async def test_retry_query_llm_success_after_failure(llm):
|
|
|
|
|
llm._query_llm.side_effect = [KeyError(), "real value"]
|
|
|
|
|
res = await llm.query("hello world", {"type": "string"})
|
|
|
|
|
|
|
|
|
|
assert agent._query_llm.call_count == 2
|
|
|
|
|
assert llm._query_llm.call_count == 2
|
|
|
|
|
assert res == "real value"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
|
|
|
async def test_retry_query_llm_failures(agent):
|
|
|
|
|
agent._query_llm.side_effect = [KeyError(), KeyError(), KeyError(), "real value"]
|
|
|
|
|
res = await agent._retry_query_llm("hello world", {"type": "string"})
|
|
|
|
|
async def test_retry_query_llm_failures(llm):
|
|
|
|
|
llm._query_llm.side_effect = [KeyError(), KeyError(), KeyError(), "real value"]
|
|
|
|
|
res = await llm.query("hello world", {"type": "string"})
|
|
|
|
|
|
|
|
|
|
assert agent._query_llm.call_count == 3
|
|
|
|
|
assert llm._query_llm.call_count == 3
|
|
|
|
|
assert res is None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
|
|
|
async def test_retry_query_llm_fail_immediately(agent):
|
|
|
|
|
agent._query_llm.side_effect = [KeyError(), "real value"]
|
|
|
|
|
res = await agent._retry_query_llm("hello world", {"type": "string"}, tries=1)
|
|
|
|
|
async def test_retry_query_llm_fail_immediately(llm):
|
|
|
|
|
llm._query_llm.side_effect = [KeyError(), "real value"]
|
|
|
|
|
res = await llm.query("hello world", {"type": "string"}, tries=1)
|
|
|
|
|
|
|
|
|
|
assert agent._query_llm.call_count == 1
|
|
|
|
|
assert llm._query_llm.call_count == 1
|
|
|
|
|
assert res is None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@@ -192,7 +193,7 @@ async def test_extracting_semantic_beliefs(agent):
|
|
|
|
|
"""
|
|
|
|
|
The Program Manager sends beliefs to this agent. Test whether the agent handles them correctly.
|
|
|
|
|
"""
|
|
|
|
|
assert len(agent.available_beliefs) == 0
|
|
|
|
|
assert len(agent.belief_inferrer.available_beliefs) == 0
|
|
|
|
|
beliefs = BeliefList(
|
|
|
|
|
beliefs=[
|
|
|
|
|
KeywordBelief(
|
|
|
|
|
@@ -213,26 +214,28 @@ async def test_extracting_semantic_beliefs(agent):
|
|
|
|
|
to=settings.agent_settings.text_belief_extractor_name,
|
|
|
|
|
sender=settings.agent_settings.bdi_program_manager_name,
|
|
|
|
|
body=beliefs.model_dump_json(),
|
|
|
|
|
thread="beliefs",
|
|
|
|
|
),
|
|
|
|
|
)
|
|
|
|
|
assert len(agent.available_beliefs) == 2
|
|
|
|
|
assert len(agent.belief_inferrer.available_beliefs) == 2
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
|
|
|
async def test_handle_invalid_program(agent, sample_program):
|
|
|
|
|
agent.available_beliefs.append(sample_program.phases[0].norms[0].condition)
|
|
|
|
|
agent.available_beliefs.append(sample_program.phases[0].triggers[0].condition)
|
|
|
|
|
assert len(agent.available_beliefs) == 2
|
|
|
|
|
async def test_handle_invalid_beliefs(agent, sample_program):
|
|
|
|
|
agent.belief_inferrer.available_beliefs.append(sample_program.phases[0].norms[0].condition)
|
|
|
|
|
agent.belief_inferrer.available_beliefs.append(sample_program.phases[0].triggers[0].condition)
|
|
|
|
|
assert len(agent.belief_inferrer.available_beliefs) == 2
|
|
|
|
|
|
|
|
|
|
await agent.handle_message(
|
|
|
|
|
InternalMessage(
|
|
|
|
|
to=settings.agent_settings.text_belief_extractor_name,
|
|
|
|
|
sender=settings.agent_settings.bdi_program_manager_name,
|
|
|
|
|
body=json.dumps({"phases": "Invalid"}),
|
|
|
|
|
thread="beliefs",
|
|
|
|
|
),
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
assert len(agent.available_beliefs) == 2
|
|
|
|
|
assert len(agent.belief_inferrer.available_beliefs) == 2
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
|
|
|
@@ -254,13 +257,13 @@ async def test_handle_robot_response(agent):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
|
|
|
async def test_simulated_real_turn_with_beliefs(agent, sample_program):
|
|
|
|
|
async def test_simulated_real_turn_with_beliefs(agent, llm, sample_program):
|
|
|
|
|
"""Test sending user message to extract beliefs from."""
|
|
|
|
|
agent.available_beliefs.append(sample_program.phases[0].norms[0].condition)
|
|
|
|
|
agent.available_beliefs.append(sample_program.phases[0].triggers[0].condition)
|
|
|
|
|
agent.belief_inferrer.available_beliefs.append(sample_program.phases[0].norms[0].condition)
|
|
|
|
|
agent.belief_inferrer.available_beliefs.append(sample_program.phases[0].triggers[0].condition)
|
|
|
|
|
|
|
|
|
|
# Send a user message with the belief that there's no more booze
|
|
|
|
|
agent._query_llm.return_value = {"is_pirate": None, "no_more_booze": True}
|
|
|
|
|
llm._query_llm.return_value = {"is_pirate": None, "no_more_booze": True}
|
|
|
|
|
assert len(agent.conversation.messages) == 0
|
|
|
|
|
await agent.handle_message(
|
|
|
|
|
InternalMessage(
|
|
|
|
|
@@ -275,20 +278,20 @@ async def test_simulated_real_turn_with_beliefs(agent, sample_program):
|
|
|
|
|
assert agent.send.call_count == 2
|
|
|
|
|
|
|
|
|
|
# First should be the beliefs message
|
|
|
|
|
message: InternalMessage = agent.send.call_args_list[0].args[0]
|
|
|
|
|
message: InternalMessage = agent.send.call_args_list[1].args[0]
|
|
|
|
|
beliefs = BeliefMessage.model_validate_json(message.body)
|
|
|
|
|
assert len(beliefs.create) == 1
|
|
|
|
|
assert beliefs.create[0].name == "no_more_booze"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
|
|
|
async def test_simulated_real_turn_no_beliefs(agent, sample_program):
|
|
|
|
|
async def test_simulated_real_turn_no_beliefs(agent, llm, sample_program):
|
|
|
|
|
"""Test a user message to extract beliefs from, but no beliefs are formed."""
|
|
|
|
|
agent.available_beliefs.append(sample_program.phases[0].norms[0].condition)
|
|
|
|
|
agent.available_beliefs.append(sample_program.phases[0].triggers[0].condition)
|
|
|
|
|
agent.belief_inferrer.available_beliefs.append(sample_program.phases[0].norms[0].condition)
|
|
|
|
|
agent.belief_inferrer.available_beliefs.append(sample_program.phases[0].triggers[0].condition)
|
|
|
|
|
|
|
|
|
|
# Send a user message with no new beliefs
|
|
|
|
|
agent._query_llm.return_value = {"is_pirate": None, "no_more_booze": None}
|
|
|
|
|
llm._query_llm.return_value = {"is_pirate": None, "no_more_booze": None}
|
|
|
|
|
await agent.handle_message(
|
|
|
|
|
InternalMessage(
|
|
|
|
|
to=settings.agent_settings.text_belief_extractor_name,
|
|
|
|
|
@@ -302,17 +305,17 @@ async def test_simulated_real_turn_no_beliefs(agent, sample_program):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
|
|
|
async def test_simulated_real_turn_no_new_beliefs(agent, sample_program):
|
|
|
|
|
async def test_simulated_real_turn_no_new_beliefs(agent, llm, sample_program):
|
|
|
|
|
"""
|
|
|
|
|
Test a user message to extract beliefs from, but no new beliefs are formed because they already
|
|
|
|
|
existed.
|
|
|
|
|
"""
|
|
|
|
|
agent.available_beliefs.append(sample_program.phases[0].norms[0].condition)
|
|
|
|
|
agent.available_beliefs.append(sample_program.phases[0].triggers[0].condition)
|
|
|
|
|
agent.beliefs["is_pirate"] = True
|
|
|
|
|
agent.belief_inferrer.available_beliefs.append(sample_program.phases[0].norms[0].condition)
|
|
|
|
|
agent.belief_inferrer.available_beliefs.append(sample_program.phases[0].triggers[0].condition)
|
|
|
|
|
agent._current_beliefs = BeliefState(true={InternalBelief(name="is_pirate", arguments=None)})
|
|
|
|
|
|
|
|
|
|
# Send a user message with the belief the user is a pirate, still
|
|
|
|
|
agent._query_llm.return_value = {"is_pirate": True, "no_more_booze": None}
|
|
|
|
|
llm._query_llm.return_value = {"is_pirate": True, "no_more_booze": None}
|
|
|
|
|
await agent.handle_message(
|
|
|
|
|
InternalMessage(
|
|
|
|
|
to=settings.agent_settings.text_belief_extractor_name,
|
|
|
|
|
@@ -326,17 +329,19 @@ async def test_simulated_real_turn_no_new_beliefs(agent, sample_program):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
|
|
|
async def test_simulated_real_turn_remove_belief(agent, sample_program):
|
|
|
|
|
async def test_simulated_real_turn_remove_belief(agent, llm, sample_program):
|
|
|
|
|
"""
|
|
|
|
|
Test a user message to extract beliefs from, but an existing belief is determined no longer to
|
|
|
|
|
hold.
|
|
|
|
|
"""
|
|
|
|
|
agent.available_beliefs.append(sample_program.phases[0].norms[0].condition)
|
|
|
|
|
agent.available_beliefs.append(sample_program.phases[0].triggers[0].condition)
|
|
|
|
|
agent.beliefs["no_more_booze"] = True
|
|
|
|
|
agent.belief_inferrer.available_beliefs.append(sample_program.phases[0].norms[0].condition)
|
|
|
|
|
agent.belief_inferrer.available_beliefs.append(sample_program.phases[0].triggers[0].condition)
|
|
|
|
|
agent._current_beliefs = BeliefState(
|
|
|
|
|
true={InternalBelief(name="no_more_booze", arguments=None)},
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# Send a user message with the belief the user is a pirate, still
|
|
|
|
|
agent._query_llm.return_value = {"is_pirate": None, "no_more_booze": False}
|
|
|
|
|
llm._query_llm.return_value = {"is_pirate": None, "no_more_booze": False}
|
|
|
|
|
await agent.handle_message(
|
|
|
|
|
InternalMessage(
|
|
|
|
|
to=settings.agent_settings.text_belief_extractor_name,
|
|
|
|
|
@@ -349,18 +354,23 @@ async def test_simulated_real_turn_remove_belief(agent, sample_program):
|
|
|
|
|
assert agent.send.call_count == 2
|
|
|
|
|
|
|
|
|
|
# Agent's current beliefs should've changed
|
|
|
|
|
assert not agent.beliefs["no_more_booze"]
|
|
|
|
|
assert any(b.name == "no_more_booze" for b in agent._current_beliefs.false)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
|
|
|
async def test_llm_failure_handling(agent, sample_program):
|
|
|
|
|
async def test_llm_failure_handling(agent, llm, sample_program):
|
|
|
|
|
"""
|
|
|
|
|
Check that the agent handles failures gracefully without crashing.
|
|
|
|
|
"""
|
|
|
|
|
agent._query_llm.side_effect = httpx.HTTPError("")
|
|
|
|
|
agent.available_beliefs.append(sample_program.phases[0].norms[0].condition)
|
|
|
|
|
agent.available_beliefs.append(sample_program.phases[0].triggers[0].condition)
|
|
|
|
|
llm._query_llm.side_effect = httpx.HTTPError("")
|
|
|
|
|
agent.belief_inferrer.available_beliefs.append(sample_program.phases[0].norms[0].condition)
|
|
|
|
|
agent.belief_inferrer.available_beliefs.append(sample_program.phases[0].triggers[0].condition)
|
|
|
|
|
|
|
|
|
|
belief_changes = await agent._infer_turn()
|
|
|
|
|
belief_changes = await agent.belief_inferrer.infer_from_conversation(
|
|
|
|
|
ChatHistory(
|
|
|
|
|
messages=[ChatMessage(role="user", content="Good day!")],
|
|
|
|
|
),
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
assert len(belief_changes) == 0
|
|
|
|
|
assert len(belief_changes.true) == 0
|
|
|
|
|
assert len(belief_changes.false) == 0
|
|
|
|
|
|