298 lines
10 KiB
Python
298 lines
10 KiB
Python
import asyncio
|
|
import json
|
|
import sys
|
|
import uuid
|
|
from unittest.mock import AsyncMock, MagicMock, mock_open, patch
|
|
|
|
import pytest
|
|
|
|
from control_backend.agents.bdi.bdi_program_manager import BDIProgramManager
|
|
from control_backend.core.agent_system import InternalMessage
|
|
from control_backend.schemas.program import BasicNorm, Goal, Phase, Plan, Program
|
|
|
|
# zmq's asyncio support does not work with the Proactor event loop that Windows
# uses by default, so fall back to the selector-based loop on that platform.
_ON_WINDOWS = sys.platform.startswith("win")
if _ON_WINDOWS:
    asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
|
|
|
|
|
|
def make_valid_program_json(norm="N1", goal="G1") -> str:
    """Build a minimal, schema-valid ``Program`` and return it serialized as JSON.

    The program contains a single phase named "Basic Phase" holding:

    * one ``BasicNorm`` whose ``name`` and ``norm`` text are both *norm*;
    * one ``Goal`` named *goal* (``can_fail=False``) with an empty "Goal Plan".

    The phase has no triggers. Every entity gets a fresh ``uuid.uuid4()`` id,
    so the ids differ between calls.
    """
    basic_norm = BasicNorm(id=uuid.uuid4(), name=norm, norm=norm)

    empty_plan = Plan(id=uuid.uuid4(), name="Goal Plan", steps=[])
    single_goal = Goal(
        id=uuid.uuid4(),
        name=goal,
        description="This description can be used to determine whether the goal "
        "has been achieved.",
        plan=empty_plan,
        can_fail=False,
    )

    only_phase = Phase(
        id=uuid.uuid4(),
        name="Basic Phase",
        norms=[basic_norm],
        goals=[single_goal],
        triggers=[],
    )

    program = Program(phases=[only_phase])
    return program.model_dump_json()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_create_agentspeak_and_send_to_bdi(mock_settings):
    """Generating AgentSpeak writes the configured file and notifies the BDI core."""
    program = Program.model_validate_json(make_valid_program_json())

    manager = BDIProgramManager(name="program_manager_test")
    manager.send = AsyncMock()

    with patch("builtins.open", mock_open()) as fake_open:
        await manager._create_agentspeak_and_send_to_bdi(program)

    # The generated code is written, in "w" mode, to the configured AgentSpeak path.
    fake_open.assert_called_with(mock_settings.behaviour_settings.agentspeak_file, "w")
    fake_open.return_value.write.assert_called()

    # Exactly one "new_program" message tells the BDI core where the file lives.
    manager.send.assert_awaited_once()
    sent: InternalMessage = manager.send.await_args.args[0]
    assert sent.thread == "new_program"
    assert sent.to == mock_settings.agent_settings.bdi_core_name
    assert sent.body == mock_settings.behaviour_settings.agentspeak_file
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_receive_programs_valid_and_invalid():
    """Malformed program payloads are skipped; valid ones are processed."""
    manager = BDIProgramManager(name="program_manager_test")
    manager._internal_pub_socket = AsyncMock()

    # The first frame carries unparsable JSON, the second a well-formed program.
    fake_sub = AsyncMock()
    fake_sub.recv_multipart.side_effect = [
        (b"program", b"{bad json"),
        (b"program", make_valid_program_json().encode()),
    ]
    manager.sub_socket = fake_sub

    # Stub every downstream step so only the receive loop itself is exercised.
    for helper_name in (
        "_create_agentspeak_and_send_to_bdi",
        "_send_clear_llm_history",
        "_send_program_to_user_interrupt",
        "_send_beliefs_to_semantic_belief_extractor",
        "_send_goals_to_semantic_belief_extractor",
    ):
        setattr(manager, helper_name, AsyncMock())

    try:
        await manager._receive_programs()
    except StopAsyncIteration:
        # The AsyncMock raises this once its scripted side effects run out;
        # that is what ends the receive loop in this test.
        pass

    # Only the valid program reaches AgentSpeak generation.
    create_mock = manager._create_agentspeak_and_send_to_bdi
    assert create_mock.await_count == 1
    forwarded: Program = create_mock.await_args.args[0]
    assert forwarded.phases[0].norms[0].name == "N1"
    assert forwarded.phases[0].goals[0].name == "G1"

    # The malformed payload is skipped (the loop `continue`s) before the LLM
    # history is cleared, so the clear happens once, for the valid program.
    assert manager._send_clear_llm_history.await_count == 1
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_send_clear_llm_history(mock_settings):
    """Clearing LLM history sends two messages, one of them to the LLM agent.

    The first message must carry the ``"clear_history"`` body.
    """
    # Give the LLM agent name a real string so messages addressed to it carry a
    # plain string recipient that can be compared below.
    mock_settings.agent_settings.llm_agent_name = "llm_agent"

    manager = BDIProgramManager(name="program_manager_test")
    manager.send = AsyncMock()

    await manager._send_clear_llm_history()

    assert manager.send.await_count == 2
    sent_msgs: list[InternalMessage] = [c.args[0] for c in manager.send.await_args_list]

    # Verify the content of the first message...
    assert sent_msgs[0].body == "clear_history"
    # ...and that the LLM agent is one of the recipients (previously unchecked).
    assert any(m.to == mock_settings.agent_settings.llm_agent_name for m in sent_msgs)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_handle_message_transition_phase(mock_settings):
    """A valid phase transition updates the current phase and notifies other agents."""
    mock_settings.agent_settings.user_interrupt_name = "user_interrupt_agent"
    manager = BDIProgramManager(name="program_manager_test")
    manager.send = AsyncMock()

    # Build a two-phase program so there is a distinct phase to move to, then
    # initialise the manager's state from it once (the second phase must be
    # appended *before* initialisation, so a prior init call would be redundant).
    prog = Program.model_validate_json(make_valid_program_json(norm="N1", goal="G1"))
    phase2_id = uuid.uuid4()
    prog.phases.append(Phase(id=phase2_id, name="Phase 2", norms=[], goals=[], triggers=[]))
    manager._initialize_internal_state(prog)

    current_phase_id = str(prog.phases[0].id)
    next_phase_id = str(phase2_id)

    payload = json.dumps({"old": current_phase_id, "new": next_phase_id})
    msg = InternalMessage(to="me", sender="bdi", body=payload, thread="transition_phase")
    await manager.handle_message(msg)

    assert str(manager._phase.id) == next_phase_id

    # Yield once so behaviours scheduled via add_behavior get a chance to run.
    await asyncio.sleep(0)

    # Expected notifications:
    # 1. beliefs to the extractor
    # 2. goals to the extractor
    # 3. the new phase id to the user-interrupt agent
    assert manager.send.await_count >= 3

    ui_name = mock_settings.agent_settings.user_interrupt_name
    ui_msgs = [c.args[0] for c in manager.send.await_args_list if c.args[0].to == ui_name]
    assert ui_msgs
    assert ui_msgs[-1].body == next_phase_id
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_handle_message_transition_phase_desync():
    """A transition whose "old" phase is not the current one is rejected.

    The manager must log a desync warning, keep its current phase, and send nothing.
    """
    manager = BDIProgramManager(name="program_manager_test")
    manager.logger = MagicMock()
    # Mock send so "do nothing" can actually be asserted below.
    manager.send = AsyncMock()

    prog = Program.model_validate_json(make_valid_program_json())
    manager._initialize_internal_state(prog)
    current_phase_id = str(prog.phases[0].id)

    # Request a transition from a phase the manager is NOT in.
    payload = json.dumps({"old": "wrong_id", "new": "some_new_id"})
    msg = InternalMessage(to="me", sender="bdi", body=payload, thread="transition_phase")
    await manager.handle_message(msg)

    # Let any behaviours scheduled via add_behavior run before checking sends.
    await asyncio.sleep(0)

    # Should warn and do nothing else.
    manager.logger.warning.assert_called_once()
    assert "Phase transition desync detected" in manager.logger.warning.call_args[0][0]
    assert str(manager._phase.id) == current_phase_id
    manager.send.assert_not_called()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_handle_message_transition_phase_end(mock_settings):
    """Transitioning to "end" clears the phase and notifies the user-interrupt agent."""
    mock_settings.agent_settings.user_interrupt_name = "user_interrupt_agent"

    manager = BDIProgramManager(name="program_manager_test")
    manager.send = AsyncMock()

    program = Program.model_validate_json(make_valid_program_json())
    manager._initialize_internal_state(program)

    transition = {"old": str(program.phases[0].id), "new": "end"}
    await manager.handle_message(
        InternalMessage(
            to="me", sender="bdi", body=json.dumps(transition), thread="transition_phase"
        )
    )

    assert manager._phase is None

    # Yield so behaviours scheduled through add_behavior run before inspecting sends.
    await asyncio.sleep(0)

    # A single "end" notification goes to the user-interrupt agent.
    manager.send.assert_awaited_once()
    notification = manager.send.await_args.args[0]
    assert notification.to == mock_settings.agent_settings.user_interrupt_name
    assert notification.body == "end"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_handle_message_achieve_goal(mock_settings):
    """Achieving a known goal reports it to the text belief extractor."""
    from control_backend.schemas.belief_list import GoalList

    mock_settings.agent_settings.text_belief_extractor_name = "text_belief_extractor_agent"
    manager = BDIProgramManager(name="program_manager_test")
    manager.send = AsyncMock()

    program = Program.model_validate_json(make_valid_program_json(goal="TargetGoal"))
    manager._initialize_internal_state(program)
    target_goal_id = str(program.phases[0].goals[0].id)

    await manager.handle_message(
        InternalMessage(to="me", sender="ui", body=target_goal_id, thread="achieve_goal")
    )

    # Exactly one "achieved_goals" message goes to the text belief extractor.
    manager.send.assert_awaited_once()
    sent = manager.send.await_args.args[0]
    assert sent.to == mock_settings.agent_settings.text_belief_extractor_name
    assert sent.thread == "achieved_goals"

    # Its body is a serialized GoalList holding just the achieved goal.
    achieved_goals = GoalList.model_validate_json(sent.body).goals
    assert len(achieved_goals) == 1
    assert achieved_goals[0].name == "TargetGoal"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_handle_message_achieve_goal_not_found():
    """An unknown goal id sends nothing and results in a debug log entry."""
    manager = BDIProgramManager(name="program_manager_test")
    manager.send = AsyncMock()
    manager.logger = MagicMock()
    manager._initialize_internal_state(Program.model_validate_json(make_valid_program_json()))

    unknown_goal_msg = InternalMessage(
        to="me", sender="ui", body="non_existent_id", thread="achieve_goal"
    )
    await manager.handle_message(unknown_goal_msg)

    manager.send.assert_not_called()
    manager.logger.debug.assert_called()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_setup(mock_settings):
    """setup() sends a default program, wires the SUB socket and registers a behaviour."""
    manager = BDIProgramManager(name="program_manager_test")
    manager.send = AsyncMock()

    def _discard_coroutine(coro):
        # add_behavior is handed a coroutine; close it so it never runs and no
        # "coroutine was never awaited" warning is emitted.
        coro.close()
        return MagicMock()

    manager.add_behavior = MagicMock(side_effect=_discard_coroutine)

    fake_sub = MagicMock()
    fake_context = MagicMock()
    fake_context.socket.return_value = fake_sub

    # Context.instance() yields the fake context. open() is stubbed because setup
    # goes through _create_agentspeak_and_send_to_bdi, which writes a file.
    context_patch = patch(
        "control_backend.agents.bdi.bdi_program_manager.Context.instance",
        return_value=fake_context,
    )
    open_patch = patch("builtins.open", new_callable=MagicMock)
    with context_patch, open_patch:
        await manager.setup()

    # 1. A default (empty) program is sent to the BDI core.
    manager.send.assert_awaited_once()
    assert manager.send.await_args.args[0].to == mock_settings.agent_settings.bdi_core_name

    # 2. The SUB socket connects to the internal address and subscribes to "program".
    fake_sub.connect.assert_called_with(mock_settings.zmq_settings.internal_sub_address)
    fake_sub.subscribe.assert_called_with("program")

    # 3. At least one behaviour is registered.
    manager.add_behavior.assert_called()
|