312 lines
10 KiB
Python
312 lines
10 KiB
Python
import asyncio
import json
import uuid
from unittest.mock import AsyncMock, MagicMock

import pytest

from control_backend.agents.user_interrupt.user_interrupt_agent import UserInterruptAgent
from control_backend.core.agent_system import InternalMessage
from control_backend.core.config import settings
from control_backend.schemas.program import (
    ConditionalNorm,
    Goal,
    KeywordBelief,
    Phase,
    Plan,
    Program,
    Trigger,
)
from control_backend.schemas.ri_message import RIEndpoint
|
|
|
|
|
|
@pytest.fixture
def agent():
    """Build a UserInterruptAgent whose outward-facing surfaces are mocks.

    ``send`` and both ZMQ sockets become ``AsyncMock`` so tests can await them
    and inspect what was sent or received; ``logger`` becomes a ``MagicMock``
    so log calls can be asserted on.
    """
    instance = UserInterruptAgent(name="user_interrupt_agent")
    instance.logger = MagicMock()
    for attr_name in ("send", "sub_socket", "pub_socket"):
        setattr(instance, attr_name, AsyncMock())
    return instance
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_send_to_speech_agent(agent):
    """A speech command is sent once, to the robot speech agent, flagged as priority."""
    await agent._send_to_speech_agent("Hello World")

    agent.send.assert_awaited_once()
    message: InternalMessage = agent.send.call_args[0][0]
    assert message.to == settings.agent_settings.robot_speech_name

    payload = json.loads(message.body)
    assert payload["data"] == "Hello World"
    assert payload["is_priority"] is True
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_send_to_gesture_agent(agent):
    """A gesture command targets the robot gesture agent on the single-gesture endpoint."""
    await agent._send_to_gesture_agent("wave_hand")

    agent.send.assert_awaited_once()
    message: InternalMessage = agent.send.call_args[0][0]
    assert message.to == settings.agent_settings.robot_gesture_name

    payload = json.loads(message.body)
    assert payload["data"] == "wave_hand"
    assert payload["is_priority"] is True
    assert payload["endpoint"] == RIEndpoint.GESTURE_SINGLE.value
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_send_to_bdi_belief(agent):
    """A goal belief update is sent to BDI core on the 'beliefs' thread."""
    goal_slug = "some_goal"

    await agent._send_to_bdi_belief(goal_slug, "goal")

    agent.send.assert_awaited_once()
    belief_message = agent.send.call_args[0][0]
    assert belief_message.to == settings.agent_settings.bdi_core_name
    assert belief_message.thread == "beliefs"
    # Only the presence of the derived belief name is checked, not the body's full format.
    assert "achieved_some_goal" in belief_message.body
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_receive_loop_routing_success(agent):
    """
    Test that the loop correctly:
    1. Receives 'button_pressed' topic from ZMQ
    2. Parses the JSON payload to find 'type' and 'context'
    3. Calls the correct handler method based on 'type'
    """

    def button_event(event_type, context):
        """Build a (topic, JSON payload) frame pair for one button press."""
        return b"button_pressed", json.dumps({"type": event_type, "context": context}).encode()

    # An override is resolved through the agent's id maps: it calls _send_to_bdi
    # for triggers/norms, or _send_to_bdi_belief for goals. Registering the
    # context in _goal_map makes this override resolve to a goal.
    agent._goal_map["Hello Override"] = "some_goal_slug"

    agent.sub_socket.recv_multipart.side_effect = [
        button_event("speech", "Hello Speech"),
        button_event("gesture", "Hello Gesture"),
        button_event("override", "Hello Override"),
        asyncio.CancelledError,  # Stop the infinite loop
    ]

    agent._send_to_speech_agent = AsyncMock()
    agent._send_to_gesture_agent = AsyncMock()
    agent._send_to_bdi_belief = AsyncMock()

    try:
        await agent._receive_button_event()
    except asyncio.CancelledError:
        pass

    # Yield once to the event loop in case handling scheduled follow-up work.
    await asyncio.sleep(0)

    # assert_awaited_once_with also pins await_count == 1, so no separate count checks.
    agent._send_to_speech_agent.assert_awaited_once_with("Hello Speech")
    agent._send_to_gesture_agent.assert_awaited_once_with("Hello Gesture")
    agent._send_to_bdi_belief.assert_awaited_once_with("some_goal_slug", "goal")
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_receive_loop_unknown_type(agent):
    """Test that unknown 'type' values in the JSON log a warning and do not crash."""
    payload_unknown = json.dumps({"type": "unknown_thing", "context": "some_data"}).encode()

    agent.sub_socket.recv_multipart.side_effect = [
        (b"button_pressed", payload_unknown),
        asyncio.CancelledError,  # Stop the infinite loop
    ]

    # Stub the speech, gesture and override (goal belief) handlers so that any
    # accidental dispatch of the unknown type shows up as a recorded call.
    agent._send_to_speech_agent = AsyncMock()
    agent._send_to_gesture_agent = AsyncMock()
    agent._send_to_bdi_belief = AsyncMock()

    try:
        await agent._receive_button_event()
    except asyncio.CancelledError:
        pass

    # Yield once to the event loop in case handling scheduled follow-up work.
    await asyncio.sleep(0)

    # Ensure no handlers were called
    agent._send_to_speech_agent.assert_not_called()
    agent._send_to_gesture_agent.assert_not_called()
    agent._send_to_bdi_belief.assert_not_called()

    agent.logger.warning.assert_called()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_create_mapping(agent):
    """A 'new_program' message populates the UI-id -> slug maps.

    The program holds one trigger, one goal and one conditional norm. Each
    element's UUID (as a string) must appear in the matching map with the
    slug asserted below.
    """
    trigger_id = uuid.uuid4()
    goal_id = uuid.uuid4()
    norm_id = uuid.uuid4()

    keyword = KeywordBelief(id=uuid.uuid4(), name="k1", keyword="key")
    plan = Plan(id=uuid.uuid4(), name="p1", steps=[])

    trigger = Trigger(id=trigger_id, name="my_trigger", condition=keyword, plan=plan)
    goal = Goal(id=goal_id, name="my_goal", description="desc", plan=plan)
    norm = ConditionalNorm(id=norm_id, name="my_norm", norm="be polite", condition=keyword)

    phase = Phase(id=uuid.uuid4(), name="phase1", norms=[norm], goals=[goal], triggers=[trigger])
    program = Program(phases=[phase])

    # The mapping is built through the public handle_message entry point.
    msg = InternalMessage(to="me", thread="new_program", body=program.model_dump_json())
    await agent.handle_message(msg)

    assert str(trigger_id) in agent._trigger_map
    assert agent._trigger_map[str(trigger_id)] == "trigger_my_trigger"

    assert str(goal_id) in agent._goal_map
    assert agent._goal_map[str(goal_id)] == "my_goal"

    assert str(norm_id) in agent._cond_norm_map
    assert agent._cond_norm_map[str(norm_id)] == "norm_be_polite"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_create_mapping_invalid_json(agent):
    """A 'new_program' body that is not valid JSON is logged as an error, not raised."""
    msg = InternalMessage(to="me", thread="new_program", body="invalid json")
    await agent.handle_message(msg)

    agent.logger.error.assert_called()

    # Nothing should have been mapped from the unparseable program.
    # NOTE(review): assumes a freshly constructed agent starts with empty maps.
    assert not agent._trigger_map
    assert not agent._goal_map
    assert not agent._cond_norm_map
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_handle_message_trigger_start(agent):
    """'trigger_start' publishes an 'experiment' trigger_update marking the UI trigger achieved."""
    # slug -> UI id; the published id must be the UI id, not the slug in the body.
    agent._trigger_reverse_map["trigger_slug"] = "ui_id_123"

    await agent.handle_message(
        InternalMessage(to="me", thread="trigger_start", body="trigger_slug")
    )

    publish = agent.pub_socket.send_multipart
    publish.assert_awaited_once()
    frames = publish.call_args[0][0]
    assert frames[0] == b"experiment"

    update = json.loads(frames[1])
    assert update["type"] == "trigger_update"
    assert update["id"] == "ui_id_123"
    assert update["achieved"] is True
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_handle_message_trigger_end(agent):
    """'trigger_end' publishes an 'experiment' trigger_update marking the UI trigger not achieved."""
    # slug -> UI id; the published id must be the UI id, not the slug in the body.
    agent._trigger_reverse_map["trigger_slug"] = "ui_id_123"

    msg = InternalMessage(to="me", thread="trigger_end", body="trigger_slug")
    await agent.handle_message(msg)

    agent.pub_socket.send_multipart.assert_awaited_once()
    frames = agent.pub_socket.send_multipart.call_args[0][0]
    # Same topic and id translation as checked for trigger_start.
    assert frames[0] == b"experiment"

    payload = json.loads(frames[1])
    assert payload["type"] == "trigger_update"
    assert payload["id"] == "ui_id_123"
    assert payload["achieved"] is False
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_handle_message_transition_phase(agent):
    """'transition_phase' publishes a phase_update carrying the phase id from the body."""
    phase_id = "phase_id_123"
    await agent.handle_message(
        InternalMessage(to="me", thread="transition_phase", body=phase_id)
    )

    publish = agent.pub_socket.send_multipart
    publish.assert_awaited_once()
    update = json.loads(publish.call_args[0][0][1])
    assert update["type"] == "phase_update"
    assert update["id"] == phase_id
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_handle_message_goal_start(agent):
    """'goal_start' publishes a goal_update marking the mapped UI goal active."""
    # slug -> UI id used to translate the message body.
    agent._goal_reverse_map["goal_slug"] = "goal_id_123"

    await agent.handle_message(InternalMessage(to="me", thread="goal_start", body="goal_slug"))

    publish = agent.pub_socket.send_multipart
    publish.assert_awaited_once()
    update = json.loads(publish.call_args[0][0][1])
    assert update["type"] == "goal_update"
    assert update["id"] == "goal_id_123"
    assert update["active"] is True
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_handle_message_active_norms_update(agent):
    """An active-norms update publishes the on/off state of every mapped norm."""
    agent._cond_norm_reverse_map["norm_active"] = "id_1"
    agent._cond_norm_reverse_map["norm_inactive"] = "id_2"

    # The body lists the currently active norm slugs, quoted and comma-separated
    # (a tuple-like string such as "('norm_active', 'other')"); the agent is
    # expected to strip the quoting. 'other' has no UI mapping here.
    active_body = "'norm_active', 'other'"
    await agent.handle_message(
        InternalMessage(to="me", thread="active_norms_update", body=active_body)
    )

    publish = agent.pub_socket.send_multipart
    publish.assert_awaited_once()
    update = json.loads(publish.call_args[0][0][1])
    assert update["type"] == "cond_norms_state_update"

    state_by_id = {entry["id"]: entry["active"] for entry in update["norms"]}
    assert state_by_id["id_1"] is True
    assert state_by_id["id_2"] is False
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_send_experiment_control(agent):
    """Each UI experiment-control command is forwarded on its matching BDI core thread."""
    command_to_thread = {
        "next_phase": "force_next_phase",
        "reset_phase": "reset_current_phase",
        "reset_experiment": "reset_experiment",
    }

    for command, expected_thread in command_to_thread.items():
        await agent._send_experiment_control_to_bdi_core(command)
        agent.send.assert_awaited()
        # call_args always reflects the most recent send.
        assert agent.send.call_args[0][0].thread == expected_thread
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_send_pause_command(agent):
    """Pause/resume sends one message to the RI communication agent and one to the VAD agent."""

    def message_to(recipient):
        """Return the first message sent to *recipient* since the last mock reset."""
        return next(
            c.args[0] for c in agent.send.call_args_list if c.args[0].to == recipient
        )

    ri_name = settings.agent_settings.ri_communication_name
    vad_name = settings.agent_settings.vad_name

    # Pause: sends to RI and VAD
    await agent._send_pause_command("true")
    assert agent.send.await_count == 2

    ri_body = json.loads(message_to(ri_name).body)
    assert ri_body["endpoint"] == ""  # PAUSE endpoint
    assert ri_body["data"] is True
    assert message_to(vad_name).body == "PAUSE"

    # Resume: same two recipients, with the flags inverted.
    agent.send.reset_mock()
    await agent._send_pause_command("false")
    assert agent.send.await_count == 2

    # NOTE(review): assumes "false" maps to data=False, mirroring "true" -> True above.
    assert json.loads(message_to(ri_name).body)["data"] is False
    assert message_to(vad_name).body == "RESUME"
|