580 lines
19 KiB
Python
580 lines
19 KiB
Python
import asyncio
|
|
import json
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
import pytest
|
|
|
|
from control_backend.agents.user_interrupt.user_interrupt_agent import UserInterruptAgent
|
|
from control_backend.core.agent_system import InternalMessage
|
|
from control_backend.core.config import settings
|
|
from control_backend.schemas.belief_message import BeliefMessage
|
|
from control_backend.schemas.program import (
|
|
ConditionalNorm,
|
|
Goal,
|
|
KeywordBelief,
|
|
Phase,
|
|
Plan,
|
|
Program,
|
|
Trigger,
|
|
)
|
|
from control_backend.schemas.ri_message import RIEndpoint
|
|
|
|
|
|
@pytest.fixture
def agent():
    """Build a UserInterruptAgent whose I/O surface is replaced by mocks.

    ``send`` and both ZMQ sockets become ``AsyncMock`` objects and the logger a
    ``MagicMock``, so tests can inspect outgoing traffic and log calls without
    any real transport.
    """
    instance = UserInterruptAgent(name="user_interrupt_agent")
    instance.send = AsyncMock()
    instance.logger = MagicMock()
    instance.sub_socket, instance.pub_socket = AsyncMock(), AsyncMock()
    return instance
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_send_to_speech_agent(agent):
    """The speech helper sends one priority message to the robot speech agent."""
    await agent._send_to_speech_agent("Hello World")

    agent.send.assert_awaited_once()
    outgoing: InternalMessage = agent.send.call_args.args[0]
    assert outgoing.to == settings.agent_settings.robot_speech_name

    payload = json.loads(outgoing.body)
    assert payload["data"] == "Hello World"
    assert payload["is_priority"] is True
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_send_to_gesture_agent(agent):
    """The gesture helper sends one priority single-gesture command to the gesture agent."""
    await agent._send_to_gesture_agent("wave_hand")

    agent.send.assert_awaited_once()
    outgoing: InternalMessage = agent.send.call_args.args[0]
    assert outgoing.to == settings.agent_settings.robot_gesture_name

    payload = json.loads(outgoing.body)
    expected = {
        "data": "wave_hand",
        "endpoint": RIEndpoint.GESTURE_SINGLE.value,
    }
    assert payload["data"] == expected["data"]
    assert payload["is_priority"] is True
    assert payload["endpoint"] == expected["endpoint"]
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_send_to_bdi_belief(agent):
    """A goal belief update is sent to the BDI core on the ``beliefs`` thread."""
    goal_slug = "some_goal"
    await agent._send_to_bdi_belief(goal_slug, "goal")

    assert agent.send.await_count == 1
    outgoing = agent.send.call_args.args[0]

    assert outgoing.to == settings.agent_settings.bdi_core_name
    assert outgoing.thread == "beliefs"
    # The serialized body must mention the slug with an "achieved_" prefix.
    assert "achieved_some_goal" in outgoing.body
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_receive_loop_routing_success(agent):
    """
    Drive the button-event loop with three ``button_pressed`` events and check that it:
    1. reads each event from the ZMQ subscriber socket,
    2. decodes the JSON payload into its ``type`` and ``context`` fields,
    3. calls the handler that matches ``type`` with the ``context`` value.
    """

    def press(kind, context):
        # One (topic, payload) frame pair, as the UI publishes it.
        body = json.dumps({"type": kind, "context": context}).encode()
        return (b"button_pressed", body)

    # An override can only be routed once its UI id is known. Register it as a goal
    # so it reaches _send_to_bdi_belief. Trigger overrides go to _send_to_bdi instead.
    agent._goal_map["Hello Override"] = "some_goal_slug"

    agent.sub_socket.recv_multipart.side_effect = [
        press("speech", "Hello Speech"),
        press("gesture", "Hello Gesture"),
        press("override", "Hello Override"),
        asyncio.CancelledError,  # breaks out of the otherwise endless receive loop
    ]

    speech_handler = agent._send_to_speech_agent = AsyncMock()
    gesture_handler = agent._send_to_gesture_agent = AsyncMock()
    belief_handler = agent._send_to_bdi_belief = AsyncMock()

    try:
        await agent._receive_button_event()
    except asyncio.CancelledError:
        pass

    await asyncio.sleep(0)

    speech_handler.assert_awaited_once_with("Hello Speech")
    gesture_handler.assert_awaited_once_with("Hello Gesture")
    # The override was mapped to a goal above, so it becomes a goal belief.
    belief_handler.assert_awaited_once_with("some_goal_slug", "goal")

    for handler in (speech_handler, gesture_handler, belief_handler):
        assert handler.await_count == 1
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_receive_loop_unknown_type(agent):
    """An unrecognised ``type`` logs a warning, calls no handler and does not crash."""
    unknown_event = (
        b"button_pressed",
        json.dumps({"type": "unknown_thing", "context": "some_data"}).encode(),
    )
    agent.sub_socket.recv_multipart.side_effect = [unknown_event, asyncio.CancelledError]

    agent._send_to_speech_agent = AsyncMock()
    agent._send_to_gesture_agent = AsyncMock()

    try:
        await agent._receive_button_event()
    except asyncio.CancelledError:
        pass

    await asyncio.sleep(0)

    # Neither output handler may fire for an unknown button type.
    for handler in (agent._send_to_speech_agent, agent._send_to_gesture_agent):
        handler.assert_not_called()

    agent.logger.warning.assert_called()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_create_mapping(agent):
    """Loading a program fills the UI-id -> slug maps for triggers, goals and norms."""
    import uuid

    trigger_id, goal_id, norm_id = uuid.uuid4(), uuid.uuid4(), uuid.uuid4()

    # One keyword condition and one empty plan are shared by all program elements.
    keyword = KeywordBelief(id=uuid.uuid4(), name="k1", keyword="key")
    empty_plan = Plan(id=uuid.uuid4(), name="p1", steps=[])

    norm = ConditionalNorm(id=norm_id, name="my_norm", norm="be polite", condition=keyword)
    goal = Goal(id=goal_id, name="my_goal", description="desc", plan=empty_plan)
    trigger = Trigger(id=trigger_id, name="my_trigger", condition=keyword, plan=empty_plan)

    program = Program(
        phases=[
            Phase(
                id=uuid.uuid4(),
                name="phase1",
                norms=[norm],
                goals=[goal],
                triggers=[trigger],
            )
        ]
    )

    # The mapping is built when the agent receives the program on "new_program".
    await agent.handle_message(
        InternalMessage(to="me", thread="new_program", body=program.model_dump_json())
    )

    expectations = (
        (agent._trigger_map, trigger_id, "trigger_my_trigger"),
        (agent._goal_map, goal_id, "my_goal"),
        (agent._cond_norm_map, norm_id, "norm_be_polite"),
    )
    for mapping, element_id, slug in expectations:
        assert str(element_id) in mapping
        assert mapping[str(element_id)] == slug
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_create_mapping_invalid_json(agent):
    """A malformed ``new_program`` body is logged as an error and builds no mappings."""
    msg = InternalMessage(to="me", thread="new_program", body="invalid json")
    await agent.handle_message(msg)

    agent.logger.error.assert_called()

    # Whether the agent leaves its maps untouched or clears them first, a failed
    # parse must not leave entries behind. This assumes a freshly constructed
    # agent (from the fixture) starts with empty maps.
    assert not agent._trigger_map
    assert not agent._goal_map
    assert not agent._cond_norm_map
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_handle_message_trigger_start(agent):
    """``trigger_start`` publishes an achieved trigger_update for the mapped UI id."""
    # Seed the slug -> UI id lookup directly instead of loading a whole program.
    agent._trigger_reverse_map["trigger_slug"] = "ui_id_123"

    await agent.handle_message(
        InternalMessage(to="me", thread="trigger_start", body="trigger_slug")
    )

    publish = agent.pub_socket.send_multipart
    publish.assert_awaited_once()
    frames = publish.call_args[0][0]
    assert frames[0] == b"experiment"

    update = json.loads(frames[1])
    assert update["type"] == "trigger_update"
    assert update["id"] == "ui_id_123"
    assert update["achieved"] is True
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_handle_message_trigger_end(agent):
    """``trigger_end`` publishes a trigger_update with ``achieved`` set to False."""
    agent._trigger_reverse_map["trigger_slug"] = "ui_id_123"

    end_msg = InternalMessage(to="me", thread="trigger_end", body="trigger_slug")
    await agent.handle_message(end_msg)

    publish = agent.pub_socket.send_multipart
    publish.assert_awaited_once()
    frames = publish.call_args[0][0]

    update = json.loads(frames[1])
    assert update["type"] == "trigger_update"
    assert update["achieved"] is False
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_handle_message_transition_phase(agent):
    """``transition_phase`` republishes the new phase id as a phase_update."""
    await agent.handle_message(
        InternalMessage(to="me", thread="transition_phase", body="phase_id_123")
    )

    publish = agent.pub_socket.send_multipart
    publish.assert_awaited_once()

    update = json.loads(publish.call_args[0][0][1])
    assert update["type"] == "phase_update"
    assert update["id"] == "phase_id_123"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_handle_message_goal_start(agent):
    """``goal_start`` publishes an active goal_update for the mapped UI id."""
    agent._goal_reverse_map["goal_slug"] = "goal_id_123"

    start_msg = InternalMessage(to="me", thread="goal_start", body="goal_slug")
    await agent.handle_message(start_msg)

    publish = agent.pub_socket.send_multipart
    publish.assert_awaited_once()

    update = json.loads(publish.call_args[0][0][1])
    assert update["type"] == "goal_update"
    assert update["id"] == "goal_id_123"
    assert update["active"] is True
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_handle_message_active_norms_update(agent):
    """Norms named in the update are reported active; other mapped norms inactive."""
    agent._cond_norm_reverse_map.update({"norm_active": "id_1", "norm_inactive": "id_2"})

    # The body here is a comma-separated list of quoted norm slugs. "other" has no
    # UI mapping, and "norm_inactive" is deliberately left out of the list.
    msg = InternalMessage(to="me", thread="active_norms_update", body="'norm_active', 'other'")
    await agent.handle_message(msg)

    publish = agent.pub_socket.send_multipart
    publish.assert_awaited_once()

    update = json.loads(publish.call_args[0][0][1])
    assert update["type"] == "cond_norms_state_update"

    state_by_id = {entry["id"]: entry["active"] for entry in update["norms"]}
    assert state_by_id["id_1"] is True
    assert state_by_id["id_2"] is False
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_send_experiment_control(agent):
    """Each experiment-control command is sent to the BDI core on its own thread."""
    expected_threads = {
        "next_phase": "force_next_phase",
        "reset_phase": "reset_current_phase",
        "reset_experiment": "reset_experiment",
    }

    for command, thread in expected_threads.items():
        await agent._send_experiment_control_to_bdi_core(command)
        agent.send.assert_awaited()
        # call_args always holds the most recent send, i.e. this command's message.
        assert agent.send.call_args[0][0].thread == thread
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_send_pause_command(agent):
    """Pause and resume each notify both the RI communication agent and the VAD agent."""
    names = settings.agent_settings

    # Pause: one message to RI, one to VAD.
    await agent._send_pause_command("true")
    assert agent.send.await_count == 2
    sent = [c.args[0] for c in agent.send.call_args_list]

    ri_msg = next(m for m in sent if m.to == names.ri_communication_name)
    ri_body = json.loads(ri_msg.body)
    assert ri_body["endpoint"] == ""  # PAUSE endpoint
    assert ri_body["data"] is True

    vad_msg = next(m for m in sent if m.to == names.vad_name)
    assert vad_msg.body == "PAUSE"

    # Resume: start from a clean call history and check the VAD message only.
    agent.send.reset_mock()
    await agent._send_pause_command("false")
    assert agent.send.await_count == 2
    sent = [c.args[0] for c in agent.send.call_args_list]

    vad_msg = next(m for m in sent if m.to == names.vad_name)
    assert vad_msg.body == "RESUME"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_setup(agent):
    """``setup`` connects the subscriber and publisher sockets and registers one behavior."""
    context_path = "control_backend.agents.user_interrupt.user_interrupt_agent.Context"
    with patch(context_path) as context_cls:
        ctx = MagicMock()
        context_cls.instance.return_value = ctx

        # Sockets are handed out in creation order: subscriber first, then publisher.
        sub_sock, pub_sock = MagicMock(), MagicMock()
        ctx.socket.side_effect = [sub_sock, pub_sock]

        # Stub add_behavior so the test does not depend on the agent's internal attributes.
        agent.add_behavior = MagicMock()

        await agent.setup()

        sub_sock.connect.assert_called_with(settings.zmq_settings.internal_sub_address)
        pub_sock.connect.assert_called_with(settings.zmq_settings.internal_pub_address)

        agent.add_behavior.assert_called_once()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_receive_loop_json_error(agent):
    """A non-JSON payload is logged as an error and the loop keeps running."""
    incoming = [(b"topic", b"INVALID{JSON"), asyncio.CancelledError]
    agent.sub_socket.recv_multipart.side_effect = incoming

    try:
        await agent._receive_button_event()
    except asyncio.CancelledError:
        pass

    agent.logger.error.assert_called_with("Received invalid JSON payload on topic %s", b"topic")
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_receive_loop_override_trigger(agent):
    """An override whose id maps to a trigger is forwarded as ``force_trigger``."""
    ui_id = "101"
    agent._trigger_map[ui_id] = "trigger_slug"
    event = json.dumps({"type": "override", "context": ui_id}).encode()

    agent.sub_socket.recv_multipart.side_effect = [(b"topic", event), asyncio.CancelledError]
    send_to_bdi = agent._send_to_bdi = AsyncMock()

    try:
        await agent._receive_button_event()
    except asyncio.CancelledError:
        pass

    send_to_bdi.assert_awaited_once_with("force_trigger", "trigger_slug")
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_receive_loop_override_norm(agent):
    """An override whose id maps to a conditional norm becomes a ``cond_norm`` belief."""
    ui_id = "202"
    agent._cond_norm_map[ui_id] = "norm_slug"
    event = json.dumps({"type": "override", "context": ui_id}).encode()

    agent.sub_socket.recv_multipart.side_effect = [(b"topic", event), asyncio.CancelledError]
    send_belief = agent._send_to_bdi_belief = AsyncMock()

    try:
        await agent._receive_button_event()
    except asyncio.CancelledError:
        pass

    send_belief.assert_awaited_once_with("norm_slug", "cond_norm")
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_receive_loop_override_missing(agent):
    """An override id that appears in no mapping produces a warning."""
    orphan = json.dumps({"type": "override", "context": "999"}).encode()
    agent.sub_socket.recv_multipart.side_effect = [(b"topic", orphan), asyncio.CancelledError]

    try:
        await agent._receive_button_event()
    except asyncio.CancelledError:
        pass

    expected_warning = "Could not determine which element to override."
    agent.logger.warning.assert_called_with(expected_warning)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_receive_loop_unachieve_logic(agent):
    """Check the success and failure paths of ``override_unachieve``.

    A mapped conditional-norm id is forwarded with ``unachieve`` set. An unmapped
    id only logs a warning and must not send a second belief update.
    """
    agent._cond_norm_map["202"] = "norm_slug"

    def unachieve(ui_id):
        # One (topic, payload) frame pair for an override_unachieve button press.
        return (b"topic", json.dumps({"type": "override_unachieve", "context": ui_id}).encode())

    agent.sub_socket.recv_multipart.side_effect = [
        unachieve("202"),  # mapped: should reach the BDI core
        unachieve("999"),  # unmapped: should only warn
        asyncio.CancelledError,
    ]
    agent._send_to_bdi_belief = AsyncMock()

    try:
        await agent._receive_button_event()
    except asyncio.CancelledError:
        pass

    # Exactly one await. assert_any_call would also pass if the unknown id
    # produced a spurious second belief update.
    agent._send_to_bdi_belief.assert_awaited_once_with("norm_slug", "cond_norm", True)

    agent.logger.warning.assert_called_with(
        "Could not determine which conditional norm to unachieve."
    )
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_receive_loop_pause_resume(agent):
    """Check that pause/resume presses are forwarded in order and logged.

    A non-empty ``context`` pauses and an empty ``context`` resumes.
    """
    pause_payload = json.dumps({"type": "pause", "context": "true"}).encode()
    resume_payload = json.dumps({"type": "pause", "context": ""}).encode()

    agent.sub_socket.recv_multipart.side_effect = [
        (b"topic", pause_payload),
        (b"topic", resume_payload),
        asyncio.CancelledError,
    ]
    agent._send_pause_command = AsyncMock()

    try:
        await agent._receive_button_event()
    except asyncio.CancelledError:
        pass

    # Pin the exact sequence. assert_any_call checks neither count nor order, so
    # a duplicated or reordered toggle would slip through.
    awaited = [c.args for c in agent._send_pause_command.await_args_list]
    assert awaited == [("true",), ("",)]

    agent.logger.info.assert_any_call("Sent pause command.")
    agent.logger.info.assert_any_call("Sent resume command.")
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_receive_loop_phase_control(agent):
    """A ``next_phase`` button press is forwarded as experiment control to the BDI core."""
    next_phase = json.dumps({"type": "next_phase", "context": ""}).encode()
    agent.sub_socket.recv_multipart.side_effect = [(b"topic", next_phase), asyncio.CancelledError]
    control = agent._send_experiment_control_to_bdi_core = AsyncMock()

    try:
        await agent._receive_button_event()
    except asyncio.CancelledError:
        pass

    control.assert_awaited_once_with("next_phase")
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_handle_message_unknown_thread(agent):
    """A message on a thread the agent does not handle is logged at debug level."""
    stray = InternalMessage(to="me", thread="unknown_thread", body="test")
    await agent.handle_message(stray)

    expected_log = "Received internal message on unhandled thread: unknown_thread"
    agent.logger.debug.assert_called_with(expected_log)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_send_to_bdi_belief_edge_cases(agent):
    """
    Exercise the less common branches of ``_send_to_bdi_belief``:

    - an unknown belief type is rejected with a warning, and nothing is sent;
    - ``unachieve=True`` produces a BeliefMessage that deletes the ``force_``
      belief and creates nothing.
    """
    # Part 1: unknown belief type.
    await agent._send_to_bdi_belief("slug", "unknown_type")

    agent.logger.warning.assert_called_with("Tried to send belief with unknown type")
    agent.send.assert_not_called()

    # Start part 2 from a clean call history.
    agent.send.reset_mock()

    # Part 2: withdraw a conditional-norm belief.
    await agent._send_to_bdi_belief("slug", "cond_norm", unachieve=True)
    agent.send.assert_awaited()

    belief = BeliefMessage.model_validate_json(agent.send.call_args.args[0].body)

    # Exactly one belief is deleted, named after the slug with a "force_" prefix.
    assert belief.delete is not None
    assert len(belief.delete) == 1
    assert belief.delete[0].name == "force_slug"

    # Nothing is created; accept either None or an empty list.
    assert not belief.create
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_send_experiment_control_unknown(agent):
    """An unknown control command logs a warning but still sends with an empty thread."""
    bogus = "invalid_command"
    await agent._send_experiment_control_to_bdi_core(bogus)

    agent.logger.warning.assert_called_with(
        "Received unknown experiment control type '%s' to send to BDI Core.", bogus
    )

    # The current implementation still sends a message, just without a thread.
    agent.send.assert_awaited()
    assert agent.send.call_args[0][0].thread == ""
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_create_mapping_recursive_goals(agent):
    """Goals nested inside another goal's plan steps are registered in the goal maps."""
    import uuid

    parent_id, child_id = uuid.uuid4(), uuid.uuid4()

    # The child goal is reachable only through the parent's plan steps.
    child = Goal(
        id=child_id,
        name="child_goal",
        description="I am a subgoal",
        plan=Plan(id=uuid.uuid4(), name="p_child", steps=[]),
    )
    parent = Goal(
        id=parent_id,
        name="parent_goal",
        description="I am a parent",
        plan=Plan(id=uuid.uuid4(), name="p_parent", steps=[child]),
    )

    # Only the parent is listed as a top-level goal of the phase.
    program = Program(
        phases=[
            Phase(id=uuid.uuid4(), name="phase1", norms=[], goals=[parent], triggers=[])
        ]
    )

    await agent.handle_message(
        InternalMessage(to="me", thread="new_program", body=program.model_dump_json())
    )

    goal_map = agent._goal_map

    # The top-level goal is mapped directly...
    assert str(parent_id) in goal_map
    assert goal_map[str(parent_id)] == "parent_goal"

    # ...and the nested one appears only if the mapping recursed into plan steps.
    assert str(child_id) in goal_map
    assert goal_map[str(child_id)] == "child_goal"
    assert agent._goal_reverse_map["child_goal"] == str(child_id)
|