The Big One #43
@@ -101,7 +101,6 @@ class BDICoreAgent(BaseAgent):
|
|||||||
maybe_more_work = True
|
maybe_more_work = True
|
||||||
while maybe_more_work:
|
while maybe_more_work:
|
||||||
maybe_more_work = False
|
maybe_more_work = False
|
||||||
self.logger.debug("Stepping BDI.")
|
|
||||||
if self.bdi_agent.step():
|
if self.bdi_agent.step():
|
||||||
maybe_more_work = True
|
maybe_more_work = True
|
||||||
|
|
||||||
|
|||||||
@@ -67,14 +67,14 @@ class BDIProgramManager(BaseAgent):
|
|||||||
|
|
||||||
await self.send(msg)
|
await self.send(msg)
|
||||||
|
|
||||||
def handle_message(self, msg: InternalMessage):
|
async def handle_message(self, msg: InternalMessage):
|
||||||
match msg.thread:
|
match msg.thread:
|
||||||
case "transition_phase":
|
case "transition_phase":
|
||||||
phases = json.loads(msg.body)
|
phases = json.loads(msg.body)
|
||||||
|
|
||||||
self._transition_phase(phases["old"], phases["new"])
|
await self._transition_phase(phases["old"], phases["new"])
|
||||||
|
|
||||||
def _transition_phase(self, old: str, new: str):
|
async def _transition_phase(self, old: str, new: str):
|
||||||
assert old == str(self._phase.id)
|
assert old == str(self._phase.id)
|
||||||
|
|
||||||
if new == "end":
|
if new == "end":
|
||||||
@@ -85,8 +85,8 @@ class BDIProgramManager(BaseAgent):
|
|||||||
if str(phase.id) == new:
|
if str(phase.id) == new:
|
||||||
self._phase = phase
|
self._phase = phase
|
||||||
|
|
||||||
self._send_beliefs_to_semantic_belief_extractor()
|
await self._send_beliefs_to_semantic_belief_extractor()
|
||||||
self._send_goals_to_semantic_belief_extractor()
|
await self._send_goals_to_semantic_belief_extractor()
|
||||||
|
|
||||||
# Notify user interaction agent
|
# Notify user interaction agent
|
||||||
msg = InternalMessage(
|
msg = InternalMessage(
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
norms("").
|
norms("").
|
||||||
|
|
||||||
+user_said(Message) : norms(Norms) <-
|
+user_said(Message) : norms(Norms) <-
|
||||||
|
.notify_user_said(Message);
|
||||||
-user_said(Message);
|
-user_said(Message);
|
||||||
.reply(Message, Norms).
|
.reply(Message, Norms).
|
||||||
|
|||||||
@@ -90,7 +90,7 @@ class TextBeliefExtractorAgent(BaseAgent):
|
|||||||
self.logger.debug("Received text from LLM: %s", msg.body)
|
self.logger.debug("Received text from LLM: %s", msg.body)
|
||||||
self._apply_conversation_message(ChatMessage(role="assistant", content=msg.body))
|
self._apply_conversation_message(ChatMessage(role="assistant", content=msg.body))
|
||||||
case settings.agent_settings.bdi_program_manager_name:
|
case settings.agent_settings.bdi_program_manager_name:
|
||||||
self._handle_program_manager_message(msg)
|
await self._handle_program_manager_message(msg)
|
||||||
case _:
|
case _:
|
||||||
self.logger.info("Discarding message from %s", sender)
|
self.logger.info("Discarding message from %s", sender)
|
||||||
return
|
return
|
||||||
@@ -105,7 +105,7 @@ class TextBeliefExtractorAgent(BaseAgent):
|
|||||||
length_limit = settings.behaviour_settings.conversation_history_length_limit
|
length_limit = settings.behaviour_settings.conversation_history_length_limit
|
||||||
self.conversation.messages = (self.conversation.messages + [message])[-length_limit:]
|
self.conversation.messages = (self.conversation.messages + [message])[-length_limit:]
|
||||||
|
|
||||||
def _handle_program_manager_message(self, msg: InternalMessage):
|
async def _handle_program_manager_message(self, msg: InternalMessage):
|
||||||
"""
|
"""
|
||||||
Handle a message from the program manager: extract available beliefs and goals from it.
|
Handle a message from the program manager: extract available beliefs and goals from it.
|
||||||
|
|
||||||
@@ -114,8 +114,10 @@ class TextBeliefExtractorAgent(BaseAgent):
|
|||||||
match msg.thread:
|
match msg.thread:
|
||||||
case "beliefs":
|
case "beliefs":
|
||||||
self._handle_beliefs_message(msg)
|
self._handle_beliefs_message(msg)
|
||||||
|
await self._infer_new_beliefs()
|
||||||
case "goals":
|
case "goals":
|
||||||
self._handle_goals_message(msg)
|
self._handle_goals_message(msg)
|
||||||
|
await self._infer_goal_completions()
|
||||||
case "conversation_history":
|
case "conversation_history":
|
||||||
if msg.body == "reset":
|
if msg.body == "reset":
|
||||||
self._reset()
|
self._reset()
|
||||||
@@ -141,8 +143,9 @@ class TextBeliefExtractorAgent(BaseAgent):
|
|||||||
available_beliefs = [b for b in belief_list.beliefs if isinstance(b, SemanticBelief)]
|
available_beliefs = [b for b in belief_list.beliefs if isinstance(b, SemanticBelief)]
|
||||||
self.belief_inferrer.available_beliefs = available_beliefs
|
self.belief_inferrer.available_beliefs = available_beliefs
|
||||||
self.logger.debug(
|
self.logger.debug(
|
||||||
"Received %d semantic beliefs from the program manager.",
|
"Received %d semantic beliefs from the program manager: %s",
|
||||||
len(available_beliefs),
|
len(available_beliefs),
|
||||||
|
", ".join(b.name for b in available_beliefs),
|
||||||
)
|
)
|
||||||
|
|
||||||
def _handle_goals_message(self, msg: InternalMessage):
|
def _handle_goals_message(self, msg: InternalMessage):
|
||||||
@@ -158,8 +161,9 @@ class TextBeliefExtractorAgent(BaseAgent):
|
|||||||
available_goals = [g for g in goals_list.goals if g.can_fail]
|
available_goals = [g for g in goals_list.goals if g.can_fail]
|
||||||
self.goal_inferrer.goals = available_goals
|
self.goal_inferrer.goals = available_goals
|
||||||
self.logger.debug(
|
self.logger.debug(
|
||||||
"Received %d failable goals from the program manager.",
|
"Received %d failable goals from the program manager: %s",
|
||||||
len(available_goals),
|
len(available_goals),
|
||||||
|
", ".join(g.name for g in available_goals),
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _user_said(self, text: str):
|
async def _user_said(self, text: str):
|
||||||
@@ -183,6 +187,7 @@ class TextBeliefExtractorAgent(BaseAgent):
|
|||||||
|
|
||||||
new_beliefs = conversation_beliefs - self._current_beliefs
|
new_beliefs = conversation_beliefs - self._current_beliefs
|
||||||
if not new_beliefs:
|
if not new_beliefs:
|
||||||
|
self.logger.debug("No new beliefs detected.")
|
||||||
return
|
return
|
||||||
|
|
||||||
self._current_beliefs |= new_beliefs
|
self._current_beliefs |= new_beliefs
|
||||||
@@ -217,6 +222,7 @@ class TextBeliefExtractorAgent(BaseAgent):
|
|||||||
self._current_goal_completions[goal] = achieved
|
self._current_goal_completions[goal] = achieved
|
||||||
|
|
||||||
if not new_achieved and not new_not_achieved:
|
if not new_achieved and not new_not_achieved:
|
||||||
|
self.logger.debug("No goal achievement changes detected.")
|
||||||
return
|
return
|
||||||
|
|
||||||
belief_changes = BeliefMessage(
|
belief_changes = BeliefMessage(
|
||||||
|
|||||||
@@ -248,7 +248,7 @@ class RICommunicationAgent(BaseAgent):
|
|||||||
self._req_socket.recv_json(), timeout=seconds_to_wait_total / 2
|
self._req_socket.recv_json(), timeout=seconds_to_wait_total / 2
|
||||||
)
|
)
|
||||||
|
|
||||||
if message["endpoint"] and message["endpoint"] != "ping":
|
if "endpoint" in message and message["endpoint"] != "ping":
|
||||||
self.logger.debug(f'Received message "{message}" from RI.')
|
self.logger.debug(f'Received message "{message}" from RI.')
|
||||||
if "endpoint" not in message:
|
if "endpoint" not in message:
|
||||||
self.logger.warning("No received endpoint in message, expected ping endpoint.")
|
self.logger.warning("No received endpoint in message, expected ping endpoint.")
|
||||||
|
|||||||
@@ -91,7 +91,7 @@ class BehaviourSettings(BaseModel):
|
|||||||
vad_prob_threshold: float = 0.5
|
vad_prob_threshold: float = 0.5
|
||||||
vad_initial_since_speech: int = 100
|
vad_initial_since_speech: int = 100
|
||||||
vad_non_speech_patience_chunks: int = 15
|
vad_non_speech_patience_chunks: int = 15
|
||||||
vad_begin_silence_chunks: int = 3
|
vad_begin_silence_chunks: int = 6
|
||||||
|
|
||||||
# transcription behaviour
|
# transcription behaviour
|
||||||
transcription_max_concurrent_tasks: int = 3
|
transcription_max_concurrent_tasks: int = 3
|
||||||
|
|||||||
@@ -108,8 +108,8 @@ async def test_send_clear_llm_history(mock_settings):
|
|||||||
|
|
||||||
await manager._send_clear_llm_history()
|
await manager._send_clear_llm_history()
|
||||||
|
|
||||||
assert manager.send.await_count == 1
|
assert manager.send.await_count == 2
|
||||||
msg: InternalMessage = manager.send.await_args[0][0]
|
msg: InternalMessage = manager.send.await_args_list[0][0][0]
|
||||||
|
|
||||||
# Verify the content and recipient
|
# Verify the content and recipient
|
||||||
assert msg.body == "clear_history"
|
assert msg.body == "clear_history"
|
||||||
|
|||||||
@@ -6,10 +6,13 @@ import httpx
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from control_backend.agents.bdi import TextBeliefExtractorAgent
|
from control_backend.agents.bdi import TextBeliefExtractorAgent
|
||||||
|
from control_backend.agents.bdi.text_belief_extractor_agent import BeliefState
|
||||||
from control_backend.core.agent_system import InternalMessage
|
from control_backend.core.agent_system import InternalMessage
|
||||||
from control_backend.core.config import settings
|
from control_backend.core.config import settings
|
||||||
from control_backend.schemas.belief_list import BeliefList
|
from control_backend.schemas.belief_list import BeliefList
|
||||||
|
from control_backend.schemas.belief_message import Belief as InternalBelief
|
||||||
from control_backend.schemas.belief_message import BeliefMessage
|
from control_backend.schemas.belief_message import BeliefMessage
|
||||||
|
from control_backend.schemas.chat_history import ChatHistory, ChatMessage
|
||||||
from control_backend.schemas.program import (
|
from control_backend.schemas.program import (
|
||||||
ConditionalNorm,
|
ConditionalNorm,
|
||||||
KeywordBelief,
|
KeywordBelief,
|
||||||
@@ -23,10 +26,20 @@ from control_backend.schemas.program import (
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def agent():
|
def llm():
|
||||||
|
llm = TextBeliefExtractorAgent.LLM(MagicMock(), 4)
|
||||||
|
llm._query_llm = AsyncMock()
|
||||||
|
return llm
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def agent(llm):
|
||||||
|
with patch(
|
||||||
|
"control_backend.agents.bdi.text_belief_extractor_agent.TextBeliefExtractorAgent.LLM",
|
||||||
|
return_value=llm,
|
||||||
|
):
|
||||||
agent = TextBeliefExtractorAgent("text_belief_agent")
|
agent = TextBeliefExtractorAgent("text_belief_agent")
|
||||||
agent.send = AsyncMock()
|
agent.send = AsyncMock()
|
||||||
agent._query_llm = AsyncMock()
|
|
||||||
return agent
|
return agent
|
||||||
|
|
||||||
|
|
||||||
@@ -102,24 +115,12 @@ async def test_handle_message_from_transcriber(agent, mock_settings):
|
|||||||
|
|
||||||
agent.send.assert_awaited_once() # noqa # `agent.send` has no such property, but we mock it.
|
agent.send.assert_awaited_once() # noqa # `agent.send` has no such property, but we mock it.
|
||||||
sent: InternalMessage = agent.send.call_args.args[0] # noqa
|
sent: InternalMessage = agent.send.call_args.args[0] # noqa
|
||||||
assert sent.to == mock_settings.agent_settings.bdi_belief_collector_name
|
assert sent.to == mock_settings.agent_settings.bdi_core_name
|
||||||
assert sent.thread == "beliefs"
|
assert sent.thread == "beliefs"
|
||||||
parsed = json.loads(sent.body)
|
parsed = BeliefMessage.model_validate_json(sent.body)
|
||||||
assert parsed == {"beliefs": {"user_said": [transcription]}, "type": "belief_extraction_text"}
|
replaced_last = parsed.replace.pop()
|
||||||
|
assert replaced_last.name == "user_said"
|
||||||
|
assert replaced_last.arguments == [transcription]
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_process_user_said(agent, mock_settings):
|
|
||||||
transcription = "this is a test"
|
|
||||||
|
|
||||||
await agent._user_said(transcription)
|
|
||||||
|
|
||||||
agent.send.assert_awaited_once() # noqa # `agent.send` has no such property, but we mock it.
|
|
||||||
sent: InternalMessage = agent.send.call_args.args[0] # noqa
|
|
||||||
assert sent.to == mock_settings.agent_settings.bdi_belief_collector_name
|
|
||||||
assert sent.thread == "beliefs"
|
|
||||||
parsed = json.loads(sent.body)
|
|
||||||
assert parsed["beliefs"]["user_said"] == [transcription]
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@@ -144,46 +145,46 @@ async def test_query_llm():
|
|||||||
"control_backend.agents.bdi.text_belief_extractor_agent.httpx.AsyncClient",
|
"control_backend.agents.bdi.text_belief_extractor_agent.httpx.AsyncClient",
|
||||||
return_value=mock_async_client,
|
return_value=mock_async_client,
|
||||||
):
|
):
|
||||||
agent = TextBeliefExtractorAgent("text_belief_agent")
|
llm = TextBeliefExtractorAgent.LLM(MagicMock(), 4)
|
||||||
|
|
||||||
res = await agent._query_llm("hello world", {"type": "null"})
|
res = await llm._query_llm("hello world", {"type": "null"})
|
||||||
# Response content was set as "null", so should be deserialized as None
|
# Response content was set as "null", so should be deserialized as None
|
||||||
assert res is None
|
assert res is None
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_retry_query_llm_success(agent):
|
async def test_retry_query_llm_success(llm):
|
||||||
agent._query_llm.return_value = None
|
llm._query_llm.return_value = None
|
||||||
res = await agent._retry_query_llm("hello world", {"type": "null"})
|
res = await llm.query("hello world", {"type": "null"})
|
||||||
|
|
||||||
agent._query_llm.assert_called_once()
|
llm._query_llm.assert_called_once()
|
||||||
assert res is None
|
assert res is None
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_retry_query_llm_success_after_failure(agent):
|
async def test_retry_query_llm_success_after_failure(llm):
|
||||||
agent._query_llm.side_effect = [KeyError(), "real value"]
|
llm._query_llm.side_effect = [KeyError(), "real value"]
|
||||||
res = await agent._retry_query_llm("hello world", {"type": "string"})
|
res = await llm.query("hello world", {"type": "string"})
|
||||||
|
|
||||||
assert agent._query_llm.call_count == 2
|
assert llm._query_llm.call_count == 2
|
||||||
assert res == "real value"
|
assert res == "real value"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_retry_query_llm_failures(agent):
|
async def test_retry_query_llm_failures(llm):
|
||||||
agent._query_llm.side_effect = [KeyError(), KeyError(), KeyError(), "real value"]
|
llm._query_llm.side_effect = [KeyError(), KeyError(), KeyError(), "real value"]
|
||||||
res = await agent._retry_query_llm("hello world", {"type": "string"})
|
res = await llm.query("hello world", {"type": "string"})
|
||||||
|
|
||||||
assert agent._query_llm.call_count == 3
|
assert llm._query_llm.call_count == 3
|
||||||
assert res is None
|
assert res is None
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_retry_query_llm_fail_immediately(agent):
|
async def test_retry_query_llm_fail_immediately(llm):
|
||||||
agent._query_llm.side_effect = [KeyError(), "real value"]
|
llm._query_llm.side_effect = [KeyError(), "real value"]
|
||||||
res = await agent._retry_query_llm("hello world", {"type": "string"}, tries=1)
|
res = await llm.query("hello world", {"type": "string"}, tries=1)
|
||||||
|
|
||||||
assert agent._query_llm.call_count == 1
|
assert llm._query_llm.call_count == 1
|
||||||
assert res is None
|
assert res is None
|
||||||
|
|
||||||
|
|
||||||
@@ -192,7 +193,7 @@ async def test_extracting_semantic_beliefs(agent):
|
|||||||
"""
|
"""
|
||||||
The Program Manager sends beliefs to this agent. Test whether the agent handles them correctly.
|
The Program Manager sends beliefs to this agent. Test whether the agent handles them correctly.
|
||||||
"""
|
"""
|
||||||
assert len(agent.available_beliefs) == 0
|
assert len(agent.belief_inferrer.available_beliefs) == 0
|
||||||
beliefs = BeliefList(
|
beliefs = BeliefList(
|
||||||
beliefs=[
|
beliefs=[
|
||||||
KeywordBelief(
|
KeywordBelief(
|
||||||
@@ -213,26 +214,28 @@ async def test_extracting_semantic_beliefs(agent):
|
|||||||
to=settings.agent_settings.text_belief_extractor_name,
|
to=settings.agent_settings.text_belief_extractor_name,
|
||||||
sender=settings.agent_settings.bdi_program_manager_name,
|
sender=settings.agent_settings.bdi_program_manager_name,
|
||||||
body=beliefs.model_dump_json(),
|
body=beliefs.model_dump_json(),
|
||||||
|
thread="beliefs",
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
assert len(agent.available_beliefs) == 2
|
assert len(agent.belief_inferrer.available_beliefs) == 2
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_handle_invalid_program(agent, sample_program):
|
async def test_handle_invalid_beliefs(agent, sample_program):
|
||||||
agent.available_beliefs.append(sample_program.phases[0].norms[0].condition)
|
agent.belief_inferrer.available_beliefs.append(sample_program.phases[0].norms[0].condition)
|
||||||
agent.available_beliefs.append(sample_program.phases[0].triggers[0].condition)
|
agent.belief_inferrer.available_beliefs.append(sample_program.phases[0].triggers[0].condition)
|
||||||
assert len(agent.available_beliefs) == 2
|
assert len(agent.belief_inferrer.available_beliefs) == 2
|
||||||
|
|
||||||
await agent.handle_message(
|
await agent.handle_message(
|
||||||
InternalMessage(
|
InternalMessage(
|
||||||
to=settings.agent_settings.text_belief_extractor_name,
|
to=settings.agent_settings.text_belief_extractor_name,
|
||||||
sender=settings.agent_settings.bdi_program_manager_name,
|
sender=settings.agent_settings.bdi_program_manager_name,
|
||||||
body=json.dumps({"phases": "Invalid"}),
|
body=json.dumps({"phases": "Invalid"}),
|
||||||
|
thread="beliefs",
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
assert len(agent.available_beliefs) == 2
|
assert len(agent.belief_inferrer.available_beliefs) == 2
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@@ -254,13 +257,13 @@ async def test_handle_robot_response(agent):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_simulated_real_turn_with_beliefs(agent, sample_program):
|
async def test_simulated_real_turn_with_beliefs(agent, llm, sample_program):
|
||||||
"""Test sending user message to extract beliefs from."""
|
"""Test sending user message to extract beliefs from."""
|
||||||
agent.available_beliefs.append(sample_program.phases[0].norms[0].condition)
|
agent.belief_inferrer.available_beliefs.append(sample_program.phases[0].norms[0].condition)
|
||||||
agent.available_beliefs.append(sample_program.phases[0].triggers[0].condition)
|
agent.belief_inferrer.available_beliefs.append(sample_program.phases[0].triggers[0].condition)
|
||||||
|
|
||||||
# Send a user message with the belief that there's no more booze
|
# Send a user message with the belief that there's no more booze
|
||||||
agent._query_llm.return_value = {"is_pirate": None, "no_more_booze": True}
|
llm._query_llm.return_value = {"is_pirate": None, "no_more_booze": True}
|
||||||
assert len(agent.conversation.messages) == 0
|
assert len(agent.conversation.messages) == 0
|
||||||
await agent.handle_message(
|
await agent.handle_message(
|
||||||
InternalMessage(
|
InternalMessage(
|
||||||
@@ -275,20 +278,20 @@ async def test_simulated_real_turn_with_beliefs(agent, sample_program):
|
|||||||
assert agent.send.call_count == 2
|
assert agent.send.call_count == 2
|
||||||
|
|
||||||
# First should be the beliefs message
|
# First should be the beliefs message
|
||||||
message: InternalMessage = agent.send.call_args_list[0].args[0]
|
message: InternalMessage = agent.send.call_args_list[1].args[0]
|
||||||
beliefs = BeliefMessage.model_validate_json(message.body)
|
beliefs = BeliefMessage.model_validate_json(message.body)
|
||||||
assert len(beliefs.create) == 1
|
assert len(beliefs.create) == 1
|
||||||
assert beliefs.create[0].name == "no_more_booze"
|
assert beliefs.create[0].name == "no_more_booze"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_simulated_real_turn_no_beliefs(agent, sample_program):
|
async def test_simulated_real_turn_no_beliefs(agent, llm, sample_program):
|
||||||
"""Test a user message to extract beliefs from, but no beliefs are formed."""
|
"""Test a user message to extract beliefs from, but no beliefs are formed."""
|
||||||
agent.available_beliefs.append(sample_program.phases[0].norms[0].condition)
|
agent.belief_inferrer.available_beliefs.append(sample_program.phases[0].norms[0].condition)
|
||||||
agent.available_beliefs.append(sample_program.phases[0].triggers[0].condition)
|
agent.belief_inferrer.available_beliefs.append(sample_program.phases[0].triggers[0].condition)
|
||||||
|
|
||||||
# Send a user message with no new beliefs
|
# Send a user message with no new beliefs
|
||||||
agent._query_llm.return_value = {"is_pirate": None, "no_more_booze": None}
|
llm._query_llm.return_value = {"is_pirate": None, "no_more_booze": None}
|
||||||
await agent.handle_message(
|
await agent.handle_message(
|
||||||
InternalMessage(
|
InternalMessage(
|
||||||
to=settings.agent_settings.text_belief_extractor_name,
|
to=settings.agent_settings.text_belief_extractor_name,
|
||||||
@@ -302,17 +305,17 @@ async def test_simulated_real_turn_no_beliefs(agent, sample_program):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_simulated_real_turn_no_new_beliefs(agent, sample_program):
|
async def test_simulated_real_turn_no_new_beliefs(agent, llm, sample_program):
|
||||||
"""
|
"""
|
||||||
Test a user message to extract beliefs from, but no new beliefs are formed because they already
|
Test a user message to extract beliefs from, but no new beliefs are formed because they already
|
||||||
existed.
|
existed.
|
||||||
"""
|
"""
|
||||||
agent.available_beliefs.append(sample_program.phases[0].norms[0].condition)
|
agent.belief_inferrer.available_beliefs.append(sample_program.phases[0].norms[0].condition)
|
||||||
agent.available_beliefs.append(sample_program.phases[0].triggers[0].condition)
|
agent.belief_inferrer.available_beliefs.append(sample_program.phases[0].triggers[0].condition)
|
||||||
agent.beliefs["is_pirate"] = True
|
agent._current_beliefs = BeliefState(true={InternalBelief(name="is_pirate", arguments=None)})
|
||||||
|
|
||||||
# Send a user message with the belief the user is a pirate, still
|
# Send a user message with the belief the user is a pirate, still
|
||||||
agent._query_llm.return_value = {"is_pirate": True, "no_more_booze": None}
|
llm._query_llm.return_value = {"is_pirate": True, "no_more_booze": None}
|
||||||
await agent.handle_message(
|
await agent.handle_message(
|
||||||
InternalMessage(
|
InternalMessage(
|
||||||
to=settings.agent_settings.text_belief_extractor_name,
|
to=settings.agent_settings.text_belief_extractor_name,
|
||||||
@@ -326,17 +329,19 @@ async def test_simulated_real_turn_no_new_beliefs(agent, sample_program):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_simulated_real_turn_remove_belief(agent, sample_program):
|
async def test_simulated_real_turn_remove_belief(agent, llm, sample_program):
|
||||||
"""
|
"""
|
||||||
Test a user message to extract beliefs from, but an existing belief is determined no longer to
|
Test a user message to extract beliefs from, but an existing belief is determined no longer to
|
||||||
hold.
|
hold.
|
||||||
"""
|
"""
|
||||||
agent.available_beliefs.append(sample_program.phases[0].norms[0].condition)
|
agent.belief_inferrer.available_beliefs.append(sample_program.phases[0].norms[0].condition)
|
||||||
agent.available_beliefs.append(sample_program.phases[0].triggers[0].condition)
|
agent.belief_inferrer.available_beliefs.append(sample_program.phases[0].triggers[0].condition)
|
||||||
agent.beliefs["no_more_booze"] = True
|
agent._current_beliefs = BeliefState(
|
||||||
|
true={InternalBelief(name="no_more_booze", arguments=None)},
|
||||||
|
)
|
||||||
|
|
||||||
# Send a user message with the belief the user is a pirate, still
|
# Send a user message with the belief the user is a pirate, still
|
||||||
agent._query_llm.return_value = {"is_pirate": None, "no_more_booze": False}
|
llm._query_llm.return_value = {"is_pirate": None, "no_more_booze": False}
|
||||||
await agent.handle_message(
|
await agent.handle_message(
|
||||||
InternalMessage(
|
InternalMessage(
|
||||||
to=settings.agent_settings.text_belief_extractor_name,
|
to=settings.agent_settings.text_belief_extractor_name,
|
||||||
@@ -349,18 +354,23 @@ async def test_simulated_real_turn_remove_belief(agent, sample_program):
|
|||||||
assert agent.send.call_count == 2
|
assert agent.send.call_count == 2
|
||||||
|
|
||||||
# Agent's current beliefs should've changed
|
# Agent's current beliefs should've changed
|
||||||
assert not agent.beliefs["no_more_booze"]
|
assert any(b.name == "no_more_booze" for b in agent._current_beliefs.false)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_llm_failure_handling(agent, sample_program):
|
async def test_llm_failure_handling(agent, llm, sample_program):
|
||||||
"""
|
"""
|
||||||
Check that the agent handles failures gracefully without crashing.
|
Check that the agent handles failures gracefully without crashing.
|
||||||
"""
|
"""
|
||||||
agent._query_llm.side_effect = httpx.HTTPError("")
|
llm._query_llm.side_effect = httpx.HTTPError("")
|
||||||
agent.available_beliefs.append(sample_program.phases[0].norms[0].condition)
|
agent.belief_inferrer.available_beliefs.append(sample_program.phases[0].norms[0].condition)
|
||||||
agent.available_beliefs.append(sample_program.phases[0].triggers[0].condition)
|
agent.belief_inferrer.available_beliefs.append(sample_program.phases[0].triggers[0].condition)
|
||||||
|
|
||||||
belief_changes = await agent._infer_turn()
|
belief_changes = await agent.belief_inferrer.infer_from_conversation(
|
||||||
|
ChatHistory(
|
||||||
|
messages=[ChatMessage(role="user", content="Good day!")],
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
assert len(belief_changes) == 0
|
assert len(belief_changes.true) == 0
|
||||||
|
assert len(belief_changes.false) == 0
|
||||||
|
|||||||
Reference in New Issue
Block a user