147 lines
4.5 KiB
Python
147 lines
4.5 KiB
Python
import asyncio
|
|
import json
|
|
from unittest.mock import AsyncMock, MagicMock
|
|
|
|
import pytest
|
|
|
|
from control_backend.agents.user_interrupt.user_interrupt_agent import UserInterruptAgent
|
|
from control_backend.core.agent_system import InternalMessage
|
|
from control_backend.core.config import settings
|
|
from control_backend.schemas.ri_message import RIEndpoint
|
|
|
|
|
|
@pytest.fixture
def agent():
    """Provide a ``UserInterruptAgent`` with its collaborators swapped for mocks.

    ``send`` and ``sub_socket`` become ``AsyncMock`` objects (so tests can await
    them, script their results and inspect calls), and ``logger`` becomes a
    ``MagicMock`` so log calls can be asserted on.
    """
    interrupt_agent = UserInterruptAgent(name="user_interrupt_agent")

    # Outbound messages are captured instead of delivered.
    interrupt_agent.send = AsyncMock()
    # Log calls are recorded for later assertions.
    interrupt_agent.logger = MagicMock()
    # Inbound socket reads are scripted per test via side_effect.
    interrupt_agent.sub_socket = AsyncMock()

    return interrupt_agent
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_send_to_speech_agent(agent):
    """The speech handler emits one priority message carrying the text to the speech agent."""
    spoken_text = "Hello World"

    await agent._send_to_speech_agent(spoken_text)

    # Exactly one outbound message, addressed to the robot speech agent.
    agent.send.assert_awaited_once()
    message: InternalMessage = agent.send.call_args[0][0]
    assert message.to == settings.agent_settings.robot_speech_name

    # The JSON body carries the text and is flagged as priority.
    payload = json.loads(message.body)
    assert payload["data"] == spoken_text
    assert payload["is_priority"] is True
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_send_to_gesture_agent(agent):
    """The gesture handler emits one priority single-gesture message to the gesture agent."""
    gesture_name = "wave_hand"

    await agent._send_to_gesture_agent(gesture_name)

    # Exactly one outbound message, addressed to the robot gesture agent.
    agent.send.assert_awaited_once()
    message: InternalMessage = agent.send.call_args[0][0]
    assert message.to == settings.agent_settings.robot_gesture_name

    # The JSON body names the gesture, is priority, and targets the
    # single-gesture endpoint.
    payload = json.loads(message.body)
    assert payload["data"] == gesture_name
    assert payload["is_priority"] is True
    assert payload["endpoint"] == RIEndpoint.GESTURE_SINGLE.value
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_send_to_program_manager(agent):
    """A belief override reaches the BDI program manager on the override thread."""
    belief_context = "2"

    await agent._send_to_program_manager(belief_context)

    # Exactly one outbound message, to the program manager, on the
    # "belief_override_id" thread.
    agent.send.assert_awaited_once()
    message: InternalMessage = agent.send.call_args[0][0]
    assert message.to == settings.agent_settings.bdi_program_manager_name
    assert message.thread == "belief_override_id"

    # The JSON body forwards the context string unchanged under "belief".
    assert json.loads(message.body)["belief"] == belief_context
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_receive_loop_routing_success(agent):
    """
    Test that the loop correctly:
    1. Receives 'button_pressed' topic from ZMQ
    2. Parses the JSON payload to find 'type' and 'context'
    3. Calls the correct handler method based on 'type'
    """
    # Prepare JSON payloads as bytes
    payload_speech = json.dumps({"type": "speech", "context": "Hello Speech"}).encode()
    payload_gesture = json.dumps({"type": "gesture", "context": "Hello Gesture"}).encode()
    payload_override = json.dumps({"type": "override", "context": "Hello Override"}).encode()

    agent.sub_socket.recv_multipart.side_effect = [
        (b"button_pressed", payload_speech),
        (b"button_pressed", payload_gesture),
        (b"button_pressed", payload_override),
        asyncio.CancelledError,  # Stop the infinite loop
    ]

    agent._send_to_speech_agent = AsyncMock()
    agent._send_to_gesture_agent = AsyncMock()
    agent._send_to_program_manager = AsyncMock()

    try:
        await agent._receive_button_event()
    except asyncio.CancelledError:
        pass

    # Yield to the event loop once, in case the agent schedules handlers as
    # tasks instead of awaiting them inline (not visible from this file).
    await asyncio.sleep(0)

    # Each assert_awaited_once_with also enforces await_count == 1, so no
    # separate count assertions are needed.
    agent._send_to_speech_agent.assert_awaited_once_with("Hello Speech")  # speech
    agent._send_to_gesture_agent.assert_awaited_once_with("Hello Gesture")  # gesture
    agent._send_to_program_manager.assert_awaited_once_with("Hello Override")  # override
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_receive_loop_unknown_type(agent):
    """Test that unknown 'type' values in the JSON log a warning and do not crash."""
    # Prepare a payload with an unknown type
    payload_unknown = json.dumps({"type": "unknown_thing", "context": "some_data"}).encode()

    agent.sub_socket.recv_multipart.side_effect = [
        (b"button_pressed", payload_unknown),
        asyncio.CancelledError,  # Stop the infinite loop
    ]

    # Mock every handler the loop routes to (speech, gesture, override) so we
    # can prove none of them fired. The override route goes to
    # _send_to_program_manager; mocking any other name would make the
    # "not called" check for that route vacuous.
    agent._send_to_speech_agent = AsyncMock()
    agent._send_to_gesture_agent = AsyncMock()
    agent._send_to_program_manager = AsyncMock()

    try:
        await agent._receive_button_event()
    except asyncio.CancelledError:
        pass

    # Yield once in case handlers are scheduled as tasks rather than awaited.
    await asyncio.sleep(0)

    # Ensure no handlers were called
    agent._send_to_speech_agent.assert_not_called()
    agent._send_to_gesture_agent.assert_not_called()
    agent._send_to_program_manager.assert_not_called()

    agent.logger.warning.assert_called_with(
        "Received button press with unknown type '%s' (context: '%s').",
        "unknown_thing",
        "some_data",
    )
|