393 lines
12 KiB
Python
393 lines
12 KiB
Python
import json
|
|
from unittest.mock import AsyncMock, MagicMock
|
|
|
|
import pytest
|
|
import zmq
|
|
|
|
from control_backend.agents.actuation.robot_gesture_agent import RobotGestureAgent
|
|
from control_backend.core.agent_system import InternalMessage
|
|
from control_backend.schemas.ri_message import RIEndpoint
|
|
|
|
|
|
@pytest.fixture
def zmq_context(mocker):
    """Patch ``azmq.Context.instance`` as seen by the robot gesture agent module.

    The value handed to tests is the patched ``Context.instance`` callable itself.
    Tests reach the shared socket mock through
    ``zmq_context.return_value.socket.return_value``.
    """
    target = "control_backend.agents.actuation.robot_gesture_agent.azmq.Context.instance"
    patched_instance = mocker.patch(target)
    # Calling the patched Context.instance() yields this fresh MagicMock "context".
    patched_instance.return_value = MagicMock()
    return patched_instance
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_setup_bind(zmq_context, mocker):
    """Setup binds the PUB socket and subscribes the SUB socket to internal commands."""
    # The patched context returns this one mock for every socket() call,
    # so PUB and SUB operations are all recorded on the same object.
    fake_socket = zmq_context.return_value.socket.return_value
    agent = RobotGestureAgent("robot_gesture", address="tcp://localhost:5556", bind=True)

    settings = mocker.patch("control_backend.agents.actuation.robot_gesture_agent.settings")
    settings.zmq_settings.internal_sub_address = "tcp://internal:1234"

    agent.add_behavior = MagicMock()

    await agent.setup()

    # PUB socket is bound to the agent's own address because bind=True.
    fake_socket.bind.assert_any_call("tcp://localhost:5556")

    # SUB socket connects to the internal bus and subscribes to both topics.
    fake_socket.connect.assert_any_call("tcp://internal:1234")
    fake_socket.setsockopt.assert_any_call(zmq.SUBSCRIBE, b"command")
    fake_socket.setsockopt.assert_any_call(zmq.SUBSCRIBE, b"send_gestures")

    # Setup registers exactly two behaviors; assert the count, not merely "called".
    assert agent.add_behavior.call_count == 2
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_setup_connect(zmq_context, mocker):
    """Setup connects (never binds) the PUB socket when bind=False."""
    # The patched context returns this one mock for every socket() call,
    # so PUB and SUB operations are all recorded on the same object.
    fake_socket = zmq_context.return_value.socket.return_value
    agent = RobotGestureAgent("robot_gesture", address="tcp://localhost:5556", bind=False)

    settings = mocker.patch("control_backend.agents.actuation.robot_gesture_agent.settings")
    settings.zmq_settings.internal_sub_address = "tcp://internal:1234"

    agent.add_behavior = MagicMock()

    await agent.setup()

    # PUB socket connects to the agent's address instead of binding it.
    fake_socket.connect.assert_any_call("tcp://localhost:5556")
    fake_socket.bind.assert_not_called()
    # SUB socket still connects to the internal bus.
    fake_socket.connect.assert_any_call("tcp://internal:1234")

    # Setup registers exactly two behaviors; assert the count, not merely "called".
    assert agent.add_behavior.call_count == 2
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_handle_message_sends_valid_gesture_command():
    """A gesture-tag message naming a known tag is published on the robot PUB socket."""
    publisher = AsyncMock()
    agent = RobotGestureAgent("robot_gesture")
    agent.pubsocket = publisher

    # "hello" is one of the tags listed by availableTags().
    body = json.dumps({"endpoint": RIEndpoint.GESTURE_TAG, "data": "hello"})
    incoming = InternalMessage(to="robot", sender="tester", body=body)

    await agent.handle_message(incoming)

    publisher.send_json.assert_awaited_once()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_handle_message_sends_non_gesture_command():
    """A message for a non-gesture endpoint is ignored by this agent.

    Despite the test name, the expectation is that nothing is published.
    """
    publisher = AsyncMock()
    agent = RobotGestureAgent("robot_gesture")
    agent.pubsocket = publisher

    body = json.dumps({"endpoint": "some_other_endpoint", "data": "invalid_tag_not_in_list"})
    incoming = InternalMessage(to="robot", sender="tester", body=body)

    await agent.handle_message(incoming)

    publisher.send_json.assert_not_awaited()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_handle_message_rejects_invalid_gesture_tag():
    """A gesture-tag message whose tag is unknown is dropped instead of published."""
    publisher = AsyncMock()
    agent = RobotGestureAgent("robot_gesture")
    agent.pubsocket = publisher

    # The tag below is deliberately absent from availableTags().
    unknown_tag_payload = {"endpoint": RIEndpoint.GESTURE_TAG, "data": "invalid_tag_not_in_list"}
    incoming = InternalMessage(to="robot", sender="tester", body=json.dumps(unknown_tag_payload))

    await agent.handle_message(incoming)

    publisher.send_json.assert_not_awaited()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_handle_message_invalid_payload():
    """A body lacking the expected fields is rejected without raising or publishing."""
    publisher = AsyncMock()
    agent = RobotGestureAgent("robot_gesture")
    agent.pubsocket = publisher

    malformed_body = json.dumps({"bad": "data"})
    await agent.handle_message(InternalMessage(to="robot", sender="tester", body=malformed_body))

    publisher.send_json.assert_not_awaited()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_zmq_command_loop_valid_gesture_payload():
    """A UI ``command`` frame with a known gesture tag is read from SUB and published."""
    agent = RobotGestureAgent("robot_gesture")
    ui_command = {"endpoint": RIEndpoint.GESTURE_TAG, "data": "hello"}
    socket_mock = AsyncMock()

    async def deliver_single_frame():
        # Clear the run flag so the loop exits after handling this one frame.
        agent._running = False
        return (b"command", json.dumps(ui_command).encode("utf-8"))

    socket_mock.recv_multipart = deliver_single_frame
    socket_mock.send_json = AsyncMock()

    # One mock plays both SUB and PUB roles.
    agent.subsocket = agent.pubsocket = socket_mock
    agent._running = True

    await agent._zmq_command_loop()

    socket_mock.send_json.assert_awaited_once()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_zmq_command_loop_valid_non_gesture_payload():
    """A UI ``command`` frame for a non-gesture endpoint is left for other agents."""
    agent = RobotGestureAgent("robot_gesture")
    ui_command = {"endpoint": "some_other_endpoint", "data": "anything"}
    socket_mock = AsyncMock()

    async def deliver_single_frame():
        # Clear the run flag so the loop exits after handling this one frame.
        agent._running = False
        return (b"command", json.dumps(ui_command).encode("utf-8"))

    socket_mock.recv_multipart = deliver_single_frame
    socket_mock.send_json = AsyncMock()

    # One mock plays both SUB and PUB roles.
    agent.subsocket = agent.pubsocket = socket_mock
    agent._running = True

    await agent._zmq_command_loop()

    socket_mock.send_json.assert_not_awaited()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_zmq_command_loop_invalid_gesture_tag():
    """A UI ``command`` frame naming an unknown gesture tag is not forwarded."""
    agent = RobotGestureAgent("robot_gesture")
    ui_command = {"endpoint": RIEndpoint.GESTURE_TAG, "data": "invalid_tag_not_in_list"}
    socket_mock = AsyncMock()

    async def deliver_single_frame():
        # Clear the run flag so the loop exits after handling this one frame.
        agent._running = False
        return (b"command", json.dumps(ui_command).encode("utf-8"))

    socket_mock.recv_multipart = deliver_single_frame
    socket_mock.send_json = AsyncMock()

    # One mock plays both SUB and PUB roles.
    agent.subsocket = agent.pubsocket = socket_mock
    agent._running = True

    await agent._zmq_command_loop()

    socket_mock.send_json.assert_not_awaited()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_zmq_command_loop_invalid_json():
    """A ``command`` frame whose body is not valid JSON is dropped without publishing."""
    agent = RobotGestureAgent("robot_gesture")
    socket_mock = AsyncMock()

    async def deliver_single_frame():
        # Clear the run flag so the loop exits after handling this one frame.
        agent._running = False
        return (b"command", b"{not_json}")

    socket_mock.recv_multipart = deliver_single_frame
    socket_mock.send_json = AsyncMock()

    # One mock plays both SUB and PUB roles.
    agent.subsocket = agent.pubsocket = socket_mock
    agent._running = True

    await agent._zmq_command_loop()

    socket_mock.send_json.assert_not_awaited()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_zmq_command_loop_ignores_send_gestures_topic():
    """The command loop skips ``send_gestures`` frames; those belong to the fetch loop."""
    agent = RobotGestureAgent("robot_gesture")
    socket_mock = AsyncMock()

    async def deliver_single_frame():
        # Clear the run flag so the loop exits after handling this one frame.
        agent._running = False
        return (b"send_gestures", b"{}")

    socket_mock.recv_multipart = deliver_single_frame
    socket_mock.send_json = AsyncMock()

    # One mock plays both SUB and PUB roles.
    agent.subsocket = agent.pubsocket = socket_mock
    agent._running = True

    await agent._zmq_command_loop()

    socket_mock.send_json.assert_not_awaited()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_fetch_gestures_loop_without_amount():
    """A ``send_gestures`` request without an amount is answered with every available tag."""
    fake_socket = AsyncMock()

    async def recv_once():
        # Clear the run flag so the loop exits after this single request.
        agent._running = False
        return (b"send_gestures", b"{}")

    fake_socket.recv_multipart = recv_once
    fake_socket.send_multipart = AsyncMock()

    agent = RobotGestureAgent("robot_gesture")
    agent.subsocket = fake_socket
    agent.pubsocket = fake_socket
    agent._running = True

    await agent._fetch_gestures_loop()

    fake_socket.send_multipart.assert_awaited_once()

    # First positional argument is the multipart frame list: [topic, json_body, ...].
    frames = fake_socket.send_multipart.call_args.args[0]
    assert frames[0] == b"get_gestures"
    response = json.loads(frames[1])
    assert "tags" in response
    # "All tags" means the full availableTags() set, not merely a non-empty list.
    assert set(response["tags"]) == set(agent.availableTags())
    assert "hello" in response["tags"]
    assert "yes" in response["tags"]
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_fetch_gestures_loop_with_amount():
    """A ``send_gestures`` request carrying an amount is answered with that many known tags."""
    fake_socket = AsyncMock()
    amount = 5

    async def recv_once():
        # Clear the run flag so the loop exits after this single request.
        agent._running = False
        return (b"send_gestures", json.dumps(amount).encode())

    fake_socket.recv_multipart = recv_once
    fake_socket.send_multipart = AsyncMock()

    agent = RobotGestureAgent("robot_gesture")
    agent.subsocket = fake_socket
    agent.pubsocket = fake_socket
    agent._running = True

    await agent._fetch_gestures_loop()

    fake_socket.send_multipart.assert_awaited_once()

    # First positional argument is the multipart frame list: [topic, json_body, ...].
    frames = fake_socket.send_multipart.call_args.args[0]
    assert frames[0] == b"get_gestures"
    response = json.loads(frames[1])
    assert "tags" in response
    assert len(response["tags"]) == amount
    # The limited reply must still be drawn from the agent's known tags.
    assert set(response["tags"]) <= set(agent.availableTags())
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_fetch_gestures_loop_ignores_command_topic():
    """The fetch loop skips ``command`` frames; those belong to the command loop."""
    agent = RobotGestureAgent("robot_gesture")
    socket_mock = AsyncMock()

    async def deliver_single_frame():
        # Clear the run flag so the loop exits after handling this one frame.
        agent._running = False
        return (b"command", b"{}")

    socket_mock.recv_multipart = deliver_single_frame
    socket_mock.send_multipart = AsyncMock()

    # One mock plays both SUB and PUB roles.
    agent.subsocket = agent.pubsocket = socket_mock
    agent._running = True

    await agent._fetch_gestures_loop()

    socket_mock.send_multipart.assert_not_awaited()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_fetch_gestures_loop_invalid_request():
    """An unparseable request body is tolerated and answered with all tags."""
    fake_socket = AsyncMock()

    async def recv_once():
        # Clear the run flag so the loop exits after this single request.
        agent._running = False
        # Body that is neither an integer nor valid JSON.
        return (b"send_gestures", b"not_json")

    fake_socket.recv_multipart = recv_once
    fake_socket.send_multipart = AsyncMock()

    agent = RobotGestureAgent("robot_gesture")
    agent.subsocket = fake_socket
    agent.pubsocket = fake_socket
    agent._running = True

    await agent._fetch_gestures_loop()

    # A reply must still be sent, and it must carry the full tag list
    # (same as a request without an amount), not just "something".
    fake_socket.send_multipart.assert_awaited_once()
    frames = fake_socket.send_multipart.call_args.args[0]
    assert frames[0] == b"get_gestures"
    assert set(json.loads(frames[1])["tags"]) == set(agent.availableTags())
|
|
|
|
|
|
def test_available_tags():
    """availableTags() returns a non-empty list that includes the well-known tags."""
    tags = RobotGestureAgent("robot_gesture").availableTags()

    assert isinstance(tags, list)
    assert len(tags) > 0
    for expected in ("hello", "yes", "no"):
        assert expected in tags
    # A made-up tag must not be reported as available.
    assert "invalid_tag_not_in_list" not in tags
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_stop_closes_sockets():
    """stop() closes both the PUB and the SUB socket exactly once."""
    pub, sub = MagicMock(), MagicMock()
    agent = RobotGestureAgent("robot_gesture")
    # Keep local handles so the assertions still work if stop() replaces the attributes.
    agent.pubsocket = pub
    agent.subsocket = sub

    await agent.stop()

    for sock in (pub, sub):
        sock.close.assert_called_once()
|
|
|
|
|
|
def test_initialization_with_custom_gesture_data():
    """Agent stores custom gesture data passed at construction time.

    Nothing here is awaited, so this is a plain synchronous test.

    The current implementation does not use ``gesture_data`` in availableTags().
    The final assertion pins that, so a future change to the agent surfaces here
    and this test can be updated accordingly.
    """
    custom_gestures = ["custom1", "custom2", "custom3"]
    agent = RobotGestureAgent("robot_gesture", gesture_data=custom_gestures)

    assert agent.gesture_data == custom_gestures
    # availableTags() ignores gesture_data for now; update this if the agent starts using it.
    assert not set(custom_gestures) & set(agent.availableTags())
|