feat: create tests for all currect functionality and add get available tags router
ref: N25B-334
This commit is contained in:
@@ -6,7 +6,7 @@ import zmq.asyncio as azmq
|
|||||||
from control_backend.agents import BaseAgent
|
from control_backend.agents import BaseAgent
|
||||||
from control_backend.core.agent_system import InternalMessage
|
from control_backend.core.agent_system import InternalMessage
|
||||||
from control_backend.core.config import settings
|
from control_backend.core.config import settings
|
||||||
from control_backend.schemas.ri_message import GestureCommand
|
from control_backend.schemas.ri_message import GestureCommand, RIEndpoint
|
||||||
|
|
||||||
|
|
||||||
class RobotGestureAgent(BaseAgent):
|
class RobotGestureAgent(BaseAgent):
|
||||||
@@ -36,7 +36,9 @@ class RobotGestureAgent(BaseAgent):
|
|||||||
gesture_data=None,
|
gesture_data=None,
|
||||||
):
|
):
|
||||||
if gesture_data is None:
|
if gesture_data is None:
|
||||||
gesture_data = []
|
self.gesture_data = []
|
||||||
|
else:
|
||||||
|
self.gesture_data = gesture_data
|
||||||
super().__init__(name)
|
super().__init__(name)
|
||||||
self.address = address
|
self.address = address
|
||||||
self.bind = bind
|
self.bind = bind
|
||||||
@@ -65,8 +67,10 @@ class RobotGestureAgent(BaseAgent):
|
|||||||
self.subsocket = context.socket(zmq.SUB)
|
self.subsocket = context.socket(zmq.SUB)
|
||||||
self.subsocket.connect(settings.zmq_settings.internal_sub_address)
|
self.subsocket.connect(settings.zmq_settings.internal_sub_address)
|
||||||
self.subsocket.setsockopt(zmq.SUBSCRIBE, b"command")
|
self.subsocket.setsockopt(zmq.SUBSCRIBE, b"command")
|
||||||
# This one
|
self.subsocket.setsockopt(zmq.SUBSCRIBE, b"send_gestures")
|
||||||
|
|
||||||
self.add_behavior(self._zmq_command_loop())
|
self.add_behavior(self._zmq_command_loop())
|
||||||
|
self.add_behavior(self._fetch_gestures_loop())
|
||||||
|
|
||||||
self.logger.info("Finished setting up %s", self.name)
|
self.logger.info("Finished setting up %s", self.name)
|
||||||
|
|
||||||
@@ -87,6 +91,14 @@ class RobotGestureAgent(BaseAgent):
|
|||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
gesture_command = GestureCommand.model_validate_json(msg.body)
|
gesture_command = GestureCommand.model_validate_json(msg.body)
|
||||||
|
if gesture_command.endpoint == RIEndpoint.GESTURE_TAG:
|
||||||
|
if gesture_command.data not in self.availableTags():
|
||||||
|
self.logger.warning(
|
||||||
|
"Received gesture tag '%s' which is not in available tags. Early returning",
|
||||||
|
gesture_command.data,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
await self.pubsocket.send_json(gesture_command.model_dump())
|
await self.pubsocket.send_json(gesture_command.model_dump())
|
||||||
except Exception:
|
except Exception:
|
||||||
self.logger.exception("Error processing internal message.")
|
self.logger.exception("Error processing internal message.")
|
||||||
@@ -99,15 +111,63 @@ class RobotGestureAgent(BaseAgent):
|
|||||||
"""
|
"""
|
||||||
while self._running:
|
while self._running:
|
||||||
try:
|
try:
|
||||||
_, body = await self.subsocket.recv_multipart()
|
topic, body = await self.subsocket.recv_multipart()
|
||||||
|
|
||||||
|
# Don't process send_gestures here
|
||||||
|
if topic != b"command":
|
||||||
|
continue
|
||||||
|
|
||||||
body = json.loads(body)
|
body = json.loads(body)
|
||||||
message = GestureCommand.model_validate(body)
|
gesture_command = GestureCommand.model_validate(body)
|
||||||
|
if gesture_command.endpoint == RIEndpoint.GESTURE_TAG:
|
||||||
await self.pubsocket.send_json(message.model_dump())
|
if gesture_command.data not in self.availableTags():
|
||||||
|
self.logger.warning(
|
||||||
|
"Received gesture tag '%s' which is not in available tags.\
|
||||||
|
Early returning",
|
||||||
|
gesture_command.data,
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
await self.pubsocket.send_json(gesture_command.model_dump())
|
||||||
except Exception:
|
except Exception:
|
||||||
self.logger.exception("Error processing ZMQ message.")
|
self.logger.exception("Error processing ZMQ message.")
|
||||||
|
|
||||||
|
async def _fetch_gestures_loop(self):
|
||||||
|
"""
|
||||||
|
Loop to handle fetching gestures received via ZMQ (e.g., from the UI).
|
||||||
|
|
||||||
|
Listens on the 'send_gestures' topic, and returns a list on the get_gestures topic.
|
||||||
|
"""
|
||||||
|
while self._running:
|
||||||
|
try:
|
||||||
|
topic, body = await self.subsocket.recv_multipart()
|
||||||
|
|
||||||
|
# Don't process commands here
|
||||||
|
if topic != b"send_gestures":
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
body = json.loads(body)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
body = None
|
||||||
|
|
||||||
|
# We could have the body be the nummer of gestures you want to fetch or something.
|
||||||
|
amount = None
|
||||||
|
if isinstance(body, int):
|
||||||
|
amount = body
|
||||||
|
|
||||||
|
tags = self.availableTags()[:amount] if amount else self.availableTags()
|
||||||
|
response = json.dumps({"tags": tags}).encode()
|
||||||
|
|
||||||
|
await self.pubsocket.send_multipart(
|
||||||
|
[
|
||||||
|
b"get_gestures",
|
||||||
|
response,
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception:
|
||||||
|
self.logger.exception("Error fetching gesture tags.")
|
||||||
|
|
||||||
def availableTags(self):
|
def availableTags(self):
|
||||||
"""
|
"""
|
||||||
Returns the available gesture tags.
|
Returns the available gesture tags.
|
||||||
|
|||||||
@@ -58,6 +58,45 @@ async def ping(request: Request):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/get_available_gesture_tags")
|
||||||
|
async def get_available_gesture_tags(request: Request):
|
||||||
|
"""
|
||||||
|
Endpoint to retrieve the available gesture tags for the robot.
|
||||||
|
|
||||||
|
:param request: The FastAPI request object.
|
||||||
|
:return: A list of available gesture tags.
|
||||||
|
"""
|
||||||
|
sub_socket = Context.instance().socket(zmq.SUB)
|
||||||
|
sub_socket.connect(settings.zmq_settings.internal_sub_address)
|
||||||
|
sub_socket.setsockopt(zmq.SUBSCRIBE, b"get_gestures")
|
||||||
|
|
||||||
|
pub_socket: Socket = request.app.state.endpoints_pub_socket
|
||||||
|
topic = b"send_gestures"
|
||||||
|
|
||||||
|
# TODO: Implement a way to get a certain ammount from the UI, rather than everything.
|
||||||
|
amount = None
|
||||||
|
timeout = 5 # seconds
|
||||||
|
|
||||||
|
await pub_socket.send_multipart([topic, amount.to_bytes(4, "big") if amount else b""])
|
||||||
|
try:
|
||||||
|
_, body = await asyncio.wait_for(sub_socket.recv_multipart(), timeout=timeout)
|
||||||
|
except TimeoutError:
|
||||||
|
body = b"tags: []"
|
||||||
|
logger.debug("got timeout error fetching gestures")
|
||||||
|
|
||||||
|
# Handle empty response and JSON decode errors
|
||||||
|
available_tags = []
|
||||||
|
if body:
|
||||||
|
try:
|
||||||
|
available_tags = json.loads(body).get("tags", [])
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
logger.error(f"Failed to parse gesture tags JSON: {e}, body: {body}")
|
||||||
|
# Return empty list on JSON error
|
||||||
|
available_tags = []
|
||||||
|
|
||||||
|
return {"available_gesture_tags": available_tags}
|
||||||
|
|
||||||
|
|
||||||
@router.get("/ping_stream")
|
@router.get("/ping_stream")
|
||||||
async def ping_stream(request: Request):
|
async def ping_stream(request: Request):
|
||||||
"""
|
"""
|
||||||
|
|||||||
392
test/unit/agents/actuation/test_robot_gesture_agent.py
Normal file
392
test/unit/agents/actuation/test_robot_gesture_agent.py
Normal file
@@ -0,0 +1,392 @@
|
|||||||
|
import json
|
||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import zmq
|
||||||
|
|
||||||
|
from control_backend.agents.actuation.robot_gesture_agent import RobotGestureAgent
|
||||||
|
from control_backend.core.agent_system import InternalMessage
|
||||||
|
from control_backend.schemas.ri_message import RIEndpoint
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def zmq_context(mocker):
|
||||||
|
"""Mock the ZMQ context."""
|
||||||
|
mock_context = mocker.patch(
|
||||||
|
"control_backend.agents.actuation.robot_gesture_agent.azmq.Context.instance"
|
||||||
|
)
|
||||||
|
mock_context.return_value = MagicMock()
|
||||||
|
return mock_context
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_setup_bind(zmq_context, mocker):
|
||||||
|
"""Setup binds and subscribes to internal commands."""
|
||||||
|
fake_socket = zmq_context.return_value.socket.return_value
|
||||||
|
agent = RobotGestureAgent("robot_gesture", address="tcp://localhost:5556", bind=True)
|
||||||
|
|
||||||
|
settings = mocker.patch("control_backend.agents.actuation.robot_gesture_agent.settings")
|
||||||
|
settings.zmq_settings.internal_sub_address = "tcp://internal:1234"
|
||||||
|
|
||||||
|
agent.add_behavior = MagicMock()
|
||||||
|
|
||||||
|
await agent.setup()
|
||||||
|
|
||||||
|
# Check PUB socket binding
|
||||||
|
fake_socket.bind.assert_any_call("tcp://localhost:5556")
|
||||||
|
|
||||||
|
# Check SUB socket connection and subscriptions
|
||||||
|
fake_socket.connect.assert_any_call("tcp://internal:1234")
|
||||||
|
fake_socket.setsockopt.assert_any_call(zmq.SUBSCRIBE, b"command")
|
||||||
|
fake_socket.setsockopt.assert_any_call(zmq.SUBSCRIBE, b"send_gestures")
|
||||||
|
|
||||||
|
# Check behavior was added
|
||||||
|
agent.add_behavior.assert_called() # Twice, even.
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_setup_connect(zmq_context, mocker):
|
||||||
|
"""Setup connects when bind=False."""
|
||||||
|
fake_socket = zmq_context.return_value.socket.return_value
|
||||||
|
agent = RobotGestureAgent("robot_gesture", address="tcp://localhost:5556", bind=False)
|
||||||
|
|
||||||
|
settings = mocker.patch("control_backend.agents.actuation.robot_gesture_agent.settings")
|
||||||
|
settings.zmq_settings.internal_sub_address = "tcp://internal:1234"
|
||||||
|
|
||||||
|
agent.add_behavior = MagicMock()
|
||||||
|
|
||||||
|
await agent.setup()
|
||||||
|
|
||||||
|
# Check PUB socket connection (not binding)
|
||||||
|
fake_socket.connect.assert_any_call("tcp://localhost:5556")
|
||||||
|
fake_socket.connect.assert_any_call("tcp://internal:1234")
|
||||||
|
|
||||||
|
# Check behavior was added
|
||||||
|
agent.add_behavior.assert_called() # Twice, actually.
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_handle_message_sends_valid_gesture_command():
|
||||||
|
"""Internal message with valid gesture tag is forwarded to robot pub socket."""
|
||||||
|
pubsocket = AsyncMock()
|
||||||
|
agent = RobotGestureAgent("robot_gesture")
|
||||||
|
agent.pubsocket = pubsocket
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
"endpoint": RIEndpoint.GESTURE_TAG,
|
||||||
|
"data": "hello", # "hello" is in availableTags
|
||||||
|
}
|
||||||
|
msg = InternalMessage(to="robot", sender="tester", body=json.dumps(payload))
|
||||||
|
|
||||||
|
await agent.handle_message(msg)
|
||||||
|
|
||||||
|
pubsocket.send_json.assert_awaited_once()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_handle_message_sends_non_gesture_command():
|
||||||
|
"""Internal message with non-gesture endpoint is not handled by this agent."""
|
||||||
|
pubsocket = AsyncMock()
|
||||||
|
agent = RobotGestureAgent("robot_gesture")
|
||||||
|
agent.pubsocket = pubsocket
|
||||||
|
|
||||||
|
payload = {"endpoint": "some_other_endpoint", "data": "invalid_tag_not_in_list"}
|
||||||
|
msg = InternalMessage(to="robot", sender="tester", body=json.dumps(payload))
|
||||||
|
|
||||||
|
await agent.handle_message(msg)
|
||||||
|
|
||||||
|
pubsocket.send_json.assert_not_awaited()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_handle_message_rejects_invalid_gesture_tag():
|
||||||
|
"""Internal message with invalid gesture tag is not forwarded."""
|
||||||
|
pubsocket = AsyncMock()
|
||||||
|
agent = RobotGestureAgent("robot_gesture")
|
||||||
|
agent.pubsocket = pubsocket
|
||||||
|
|
||||||
|
# Use a tag that's not in availableTags
|
||||||
|
payload = {"endpoint": RIEndpoint.GESTURE_TAG, "data": "invalid_tag_not_in_list"}
|
||||||
|
msg = InternalMessage(to="robot", sender="tester", body=json.dumps(payload))
|
||||||
|
|
||||||
|
await agent.handle_message(msg)
|
||||||
|
|
||||||
|
pubsocket.send_json.assert_not_awaited()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_handle_message_invalid_payload():
|
||||||
|
"""Invalid payload is caught and does not send."""
|
||||||
|
pubsocket = AsyncMock()
|
||||||
|
agent = RobotGestureAgent("robot_gesture")
|
||||||
|
agent.pubsocket = pubsocket
|
||||||
|
|
||||||
|
msg = InternalMessage(to="robot", sender="tester", body=json.dumps({"bad": "data"}))
|
||||||
|
|
||||||
|
await agent.handle_message(msg)
|
||||||
|
|
||||||
|
pubsocket.send_json.assert_not_awaited()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_zmq_command_loop_valid_gesture_payload():
|
||||||
|
"""UI command with valid gesture tag is read from SUB and published."""
|
||||||
|
command = {"endpoint": RIEndpoint.GESTURE_TAG, "data": "hello"}
|
||||||
|
fake_socket = AsyncMock()
|
||||||
|
|
||||||
|
async def recv_once():
|
||||||
|
# stop after first iteration
|
||||||
|
agent._running = False
|
||||||
|
return (b"command", json.dumps(command).encode("utf-8"))
|
||||||
|
|
||||||
|
fake_socket.recv_multipart = recv_once
|
||||||
|
fake_socket.send_json = AsyncMock()
|
||||||
|
|
||||||
|
agent = RobotGestureAgent("robot_gesture")
|
||||||
|
agent.subsocket = fake_socket
|
||||||
|
agent.pubsocket = fake_socket
|
||||||
|
agent._running = True
|
||||||
|
|
||||||
|
await agent._zmq_command_loop()
|
||||||
|
|
||||||
|
fake_socket.send_json.assert_awaited_once()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_zmq_command_loop_valid_non_gesture_payload():
|
||||||
|
"""UI command with non-gesture endpoint is not handled by this agent."""
|
||||||
|
command = {"endpoint": "some_other_endpoint", "data": "anything"}
|
||||||
|
fake_socket = AsyncMock()
|
||||||
|
|
||||||
|
async def recv_once():
|
||||||
|
agent._running = False
|
||||||
|
return (b"command", json.dumps(command).encode("utf-8"))
|
||||||
|
|
||||||
|
fake_socket.recv_multipart = recv_once
|
||||||
|
fake_socket.send_json = AsyncMock()
|
||||||
|
|
||||||
|
agent = RobotGestureAgent("robot_gesture")
|
||||||
|
agent.subsocket = fake_socket
|
||||||
|
agent.pubsocket = fake_socket
|
||||||
|
agent._running = True
|
||||||
|
|
||||||
|
await agent._zmq_command_loop()
|
||||||
|
|
||||||
|
fake_socket.send_json.assert_not_awaited()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_zmq_command_loop_invalid_gesture_tag():
|
||||||
|
"""UI command with invalid gesture tag is not forwarded."""
|
||||||
|
command = {"endpoint": RIEndpoint.GESTURE_TAG, "data": "invalid_tag_not_in_list"}
|
||||||
|
fake_socket = AsyncMock()
|
||||||
|
|
||||||
|
async def recv_once():
|
||||||
|
agent._running = False
|
||||||
|
return (b"command", json.dumps(command).encode("utf-8"))
|
||||||
|
|
||||||
|
fake_socket.recv_multipart = recv_once
|
||||||
|
fake_socket.send_json = AsyncMock()
|
||||||
|
|
||||||
|
agent = RobotGestureAgent("robot_gesture")
|
||||||
|
agent.subsocket = fake_socket
|
||||||
|
agent.pubsocket = fake_socket
|
||||||
|
agent._running = True
|
||||||
|
|
||||||
|
await agent._zmq_command_loop()
|
||||||
|
|
||||||
|
fake_socket.send_json.assert_not_awaited()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_zmq_command_loop_invalid_json():
|
||||||
|
"""Invalid JSON is ignored without sending."""
|
||||||
|
fake_socket = AsyncMock()
|
||||||
|
|
||||||
|
async def recv_once():
|
||||||
|
agent._running = False
|
||||||
|
return (b"command", b"{not_json}")
|
||||||
|
|
||||||
|
fake_socket.recv_multipart = recv_once
|
||||||
|
fake_socket.send_json = AsyncMock()
|
||||||
|
|
||||||
|
agent = RobotGestureAgent("robot_gesture")
|
||||||
|
agent.subsocket = fake_socket
|
||||||
|
agent.pubsocket = fake_socket
|
||||||
|
agent._running = True
|
||||||
|
|
||||||
|
await agent._zmq_command_loop()
|
||||||
|
|
||||||
|
fake_socket.send_json.assert_not_awaited()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_zmq_command_loop_ignores_send_gestures_topic():
|
||||||
|
"""send_gestures topic is ignored in command loop."""
|
||||||
|
fake_socket = AsyncMock()
|
||||||
|
|
||||||
|
async def recv_once():
|
||||||
|
agent._running = False
|
||||||
|
return (b"send_gestures", b"{}")
|
||||||
|
|
||||||
|
fake_socket.recv_multipart = recv_once
|
||||||
|
fake_socket.send_json = AsyncMock()
|
||||||
|
|
||||||
|
agent = RobotGestureAgent("robot_gesture")
|
||||||
|
agent.subsocket = fake_socket
|
||||||
|
agent.pubsocket = fake_socket
|
||||||
|
agent._running = True
|
||||||
|
|
||||||
|
await agent._zmq_command_loop()
|
||||||
|
|
||||||
|
fake_socket.send_json.assert_not_awaited()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_fetch_gestures_loop_without_amount():
|
||||||
|
"""Fetch gestures request without amount returns all tags."""
|
||||||
|
fake_socket = AsyncMock()
|
||||||
|
|
||||||
|
async def recv_once():
|
||||||
|
agent._running = False
|
||||||
|
return (b"send_gestures", b"{}")
|
||||||
|
|
||||||
|
fake_socket.recv_multipart = recv_once
|
||||||
|
fake_socket.send_multipart = AsyncMock()
|
||||||
|
|
||||||
|
agent = RobotGestureAgent("robot_gesture")
|
||||||
|
agent.subsocket = fake_socket
|
||||||
|
agent.pubsocket = fake_socket
|
||||||
|
agent._running = True
|
||||||
|
|
||||||
|
await agent._fetch_gestures_loop()
|
||||||
|
|
||||||
|
fake_socket.send_multipart.assert_awaited_once()
|
||||||
|
|
||||||
|
# Check the response contains all tags
|
||||||
|
args, kwargs = fake_socket.send_multipart.call_args
|
||||||
|
assert args[0][0] == b"get_gestures"
|
||||||
|
response = json.loads(args[0][1])
|
||||||
|
assert "tags" in response
|
||||||
|
assert len(response["tags"]) > 0
|
||||||
|
# Check it includes some expected tags
|
||||||
|
assert "hello" in response["tags"]
|
||||||
|
assert "yes" in response["tags"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_fetch_gestures_loop_with_amount():
|
||||||
|
"""Fetch gestures request with amount returns limited tags."""
|
||||||
|
fake_socket = AsyncMock()
|
||||||
|
amount = 5
|
||||||
|
|
||||||
|
async def recv_once():
|
||||||
|
agent._running = False
|
||||||
|
return (b"send_gestures", json.dumps(amount).encode())
|
||||||
|
|
||||||
|
fake_socket.recv_multipart = recv_once
|
||||||
|
fake_socket.send_multipart = AsyncMock()
|
||||||
|
|
||||||
|
agent = RobotGestureAgent("robot_gesture")
|
||||||
|
agent.subsocket = fake_socket
|
||||||
|
agent.pubsocket = fake_socket
|
||||||
|
agent._running = True
|
||||||
|
|
||||||
|
await agent._fetch_gestures_loop()
|
||||||
|
|
||||||
|
fake_socket.send_multipart.assert_awaited_once()
|
||||||
|
|
||||||
|
args, kwargs = fake_socket.send_multipart.call_args
|
||||||
|
assert args[0][0] == b"get_gestures"
|
||||||
|
response = json.loads(args[0][1])
|
||||||
|
assert "tags" in response
|
||||||
|
assert len(response["tags"]) == amount
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_fetch_gestures_loop_ignores_command_topic():
|
||||||
|
"""Command topic is ignored in fetch gestures loop."""
|
||||||
|
fake_socket = AsyncMock()
|
||||||
|
|
||||||
|
async def recv_once():
|
||||||
|
agent._running = False
|
||||||
|
return (b"command", b"{}")
|
||||||
|
|
||||||
|
fake_socket.recv_multipart = recv_once
|
||||||
|
fake_socket.send_multipart = AsyncMock()
|
||||||
|
|
||||||
|
agent = RobotGestureAgent("robot_gesture")
|
||||||
|
agent.subsocket = fake_socket
|
||||||
|
agent.pubsocket = fake_socket
|
||||||
|
agent._running = True
|
||||||
|
|
||||||
|
await agent._fetch_gestures_loop()
|
||||||
|
|
||||||
|
fake_socket.send_multipart.assert_not_awaited()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_fetch_gestures_loop_invalid_request():
|
||||||
|
"""Invalid request body is handled gracefully."""
|
||||||
|
fake_socket = AsyncMock()
|
||||||
|
|
||||||
|
async def recv_once():
|
||||||
|
agent._running = False
|
||||||
|
# Send a non-integer, non-JSON body
|
||||||
|
return (b"send_gestures", b"not_json")
|
||||||
|
|
||||||
|
fake_socket.recv_multipart = recv_once
|
||||||
|
fake_socket.send_multipart = AsyncMock()
|
||||||
|
|
||||||
|
agent = RobotGestureAgent("robot_gesture")
|
||||||
|
agent.subsocket = fake_socket
|
||||||
|
agent.pubsocket = fake_socket
|
||||||
|
agent._running = True
|
||||||
|
|
||||||
|
await agent._fetch_gestures_loop()
|
||||||
|
|
||||||
|
# Should still send a response (all tags)
|
||||||
|
fake_socket.send_multipart.assert_awaited_once()
|
||||||
|
|
||||||
|
|
||||||
|
def test_available_tags():
|
||||||
|
"""Test that availableTags returns the expected list."""
|
||||||
|
agent = RobotGestureAgent("robot_gesture")
|
||||||
|
|
||||||
|
tags = agent.availableTags()
|
||||||
|
|
||||||
|
assert isinstance(tags, list)
|
||||||
|
assert len(tags) > 0
|
||||||
|
# Check some expected tags are present
|
||||||
|
assert "hello" in tags
|
||||||
|
assert "yes" in tags
|
||||||
|
assert "no" in tags
|
||||||
|
# Check a non-existent tag is not present
|
||||||
|
assert "invalid_tag_not_in_list" not in tags
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_stop_closes_sockets():
|
||||||
|
"""Stop method closes both sockets."""
|
||||||
|
pubsocket = MagicMock()
|
||||||
|
subsocket = MagicMock()
|
||||||
|
agent = RobotGestureAgent("robot_gesture")
|
||||||
|
agent.pubsocket = pubsocket
|
||||||
|
agent.subsocket = subsocket
|
||||||
|
|
||||||
|
await agent.stop()
|
||||||
|
|
||||||
|
pubsocket.close.assert_called_once()
|
||||||
|
subsocket.close.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_initialization_with_custom_gesture_data():
|
||||||
|
"""Agent can be initialized with custom gesture data."""
|
||||||
|
custom_gestures = ["custom1", "custom2", "custom3"]
|
||||||
|
agent = RobotGestureAgent("robot_gesture", gesture_data=custom_gestures)
|
||||||
|
|
||||||
|
# Note: The current implementation doesn't use the gesture_data parameter
|
||||||
|
# in availableTags(). This test documents that behavior.
|
||||||
|
# If you update the agent to use gesture_data, update this test accordingly.
|
||||||
|
assert agent.gesture_data == custom_gestures
|
||||||
@@ -1,7 +1,8 @@
|
|||||||
import json
|
import json
|
||||||
from unittest.mock import AsyncMock, MagicMock
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
import zmq.asyncio
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
from fastapi.testclient import TestClient
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
@@ -26,6 +27,26 @@ def client(app):
|
|||||||
return TestClient(app)
|
return TestClient(app)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_zmq_context():
|
||||||
|
"""Mock the ZMQ context."""
|
||||||
|
with patch("control_backend.api.v1.endpoints.robot.Context.instance") as mock_context:
|
||||||
|
context_instance = MagicMock()
|
||||||
|
mock_context.return_value = context_instance
|
||||||
|
yield context_instance
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_sockets(mock_zmq_context):
|
||||||
|
"""Mock ZMQ sockets."""
|
||||||
|
mock_sub_socket = AsyncMock(spec=zmq.asyncio.Socket)
|
||||||
|
mock_pub_socket = AsyncMock(spec=zmq.asyncio.Socket)
|
||||||
|
|
||||||
|
mock_zmq_context.socket.return_value = mock_sub_socket
|
||||||
|
|
||||||
|
return {"sub": mock_sub_socket, "pub": mock_pub_socket}
|
||||||
|
|
||||||
|
|
||||||
def test_receive_command_success(client):
|
def test_receive_command_success(client):
|
||||||
"""
|
"""
|
||||||
Test for successful reception of a command. Ensures the status code is 202 and the response body
|
Test for successful reception of a command. Ensures the status code is 202 and the response body
|
||||||
@@ -69,6 +90,7 @@ def test_ping_check_returns_none(client):
|
|||||||
assert response.json() is None
|
assert response.json() is None
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: Convert these mock sockets to the fixture.
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_ping_stream_yields_ping_event(monkeypatch):
|
async def test_ping_stream_yields_ping_event(monkeypatch):
|
||||||
"""Test that ping_stream yields a proper SSE message when a ping is received."""
|
"""Test that ping_stream yields a proper SSE message when a ping is received."""
|
||||||
@@ -154,3 +176,251 @@ async def test_ping_stream_yields_json_values(monkeypatch):
|
|||||||
mock_sub_socket.connect.assert_called_once()
|
mock_sub_socket.connect.assert_called_once()
|
||||||
mock_sub_socket.setsockopt.assert_called_once_with(robot.zmq.SUBSCRIBE, b"ping")
|
mock_sub_socket.setsockopt.assert_called_once_with(robot.zmq.SUBSCRIBE, b"ping")
|
||||||
mock_sub_socket.recv_multipart.assert_awaited()
|
mock_sub_socket.recv_multipart.assert_awaited()
|
||||||
|
|
||||||
|
|
||||||
|
# New tests for get_available_gesture_tags endpoint
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_available_gesture_tags_success(client, monkeypatch):
|
||||||
|
"""
|
||||||
|
Test successful retrieval of available gesture tags.
|
||||||
|
"""
|
||||||
|
# Arrange
|
||||||
|
mock_sub_socket = AsyncMock()
|
||||||
|
mock_sub_socket.connect = MagicMock()
|
||||||
|
mock_sub_socket.setsockopt = MagicMock()
|
||||||
|
|
||||||
|
# Simulate a response with gesture tags
|
||||||
|
response_data = {"tags": ["wave", "nod", "point", "dance"]}
|
||||||
|
mock_sub_socket.recv_multipart = AsyncMock(
|
||||||
|
return_value=[b"get_gestures", json.dumps(response_data).encode()]
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_context = MagicMock()
|
||||||
|
mock_context.socket.return_value = mock_sub_socket
|
||||||
|
monkeypatch.setattr(robot.Context, "instance", lambda: mock_context)
|
||||||
|
|
||||||
|
mock_pub_socket = AsyncMock()
|
||||||
|
client.app.state.endpoints_pub_socket = mock_pub_socket
|
||||||
|
|
||||||
|
# Mock settings
|
||||||
|
mock_settings = MagicMock()
|
||||||
|
mock_settings.zmq_settings.internal_sub_address = "tcp://localhost:5555"
|
||||||
|
monkeypatch.setattr(robot, "settings", mock_settings)
|
||||||
|
|
||||||
|
# Mock logger to avoid actual logging
|
||||||
|
mock_logger = MagicMock()
|
||||||
|
monkeypatch.setattr(robot.logger, "debug", mock_logger)
|
||||||
|
|
||||||
|
# Act
|
||||||
|
response = client.get("/get_available_gesture_tags")
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json() == {"available_gesture_tags": ["wave", "nod", "point", "dance"]}
|
||||||
|
|
||||||
|
# Verify ZeroMQ interactions
|
||||||
|
mock_sub_socket.connect.assert_called_once_with("tcp://localhost:5555")
|
||||||
|
mock_sub_socket.setsockopt.assert_called_once_with(robot.zmq.SUBSCRIBE, b"get_gestures")
|
||||||
|
mock_pub_socket.send_multipart.assert_awaited_once_with([b"send_gestures", b""])
|
||||||
|
mock_sub_socket.recv_multipart.assert_awaited_once()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_available_gesture_tags_with_amount(client, monkeypatch):
|
||||||
|
"""
|
||||||
|
Test retrieval of gesture tags with a specific amount parameter.
|
||||||
|
This tests the TODO in the endpoint about getting a certain amount from the UI.
|
||||||
|
"""
|
||||||
|
# Arrange
|
||||||
|
mock_sub_socket = AsyncMock()
|
||||||
|
mock_sub_socket.connect = MagicMock()
|
||||||
|
mock_sub_socket.setsockopt = MagicMock()
|
||||||
|
|
||||||
|
# Simulate a response with gesture tags
|
||||||
|
response_data = {"tags": ["wave", "nod"]}
|
||||||
|
mock_sub_socket.recv_multipart = AsyncMock(
|
||||||
|
return_value=[b"get_gestures", json.dumps(response_data).encode()]
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_context = MagicMock()
|
||||||
|
mock_context.socket.return_value = mock_sub_socket
|
||||||
|
monkeypatch.setattr(robot.Context, "instance", lambda: mock_context)
|
||||||
|
|
||||||
|
mock_pub_socket = AsyncMock()
|
||||||
|
client.app.state.endpoints_pub_socket = mock_pub_socket
|
||||||
|
|
||||||
|
# Mock settings
|
||||||
|
mock_settings = MagicMock()
|
||||||
|
mock_settings.zmq_settings.internal_sub_address = "tcp://localhost:5555"
|
||||||
|
monkeypatch.setattr(robot, "settings", mock_settings)
|
||||||
|
|
||||||
|
# Mock logger
|
||||||
|
mock_logger = MagicMock()
|
||||||
|
monkeypatch.setattr(robot.logger, "debug", mock_logger)
|
||||||
|
|
||||||
|
# Act - Note: The endpoint currently doesn't support query parameters for amount,
|
||||||
|
# but we're testing what happens if the UI sends an amount (the TODO in the code)
|
||||||
|
# For now, we test the current behavior
|
||||||
|
response = client.get("/get_available_gesture_tags")
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json() == {"available_gesture_tags": ["wave", "nod"]}
|
||||||
|
|
||||||
|
# The endpoint currently doesn't use the amount parameter, so it should send empty bytes
|
||||||
|
mock_pub_socket.send_multipart.assert_awaited_once_with([b"send_gestures", b""])
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_available_gesture_tags_timeout(client, monkeypatch):
|
||||||
|
"""
|
||||||
|
Test timeout scenario when fetching gesture tags.
|
||||||
|
"""
|
||||||
|
# Arrange
|
||||||
|
mock_sub_socket = AsyncMock()
|
||||||
|
mock_sub_socket.connect = MagicMock()
|
||||||
|
mock_sub_socket.setsockopt = MagicMock()
|
||||||
|
|
||||||
|
# Simulate a timeout
|
||||||
|
mock_sub_socket.recv_multipart = AsyncMock(side_effect=TimeoutError)
|
||||||
|
|
||||||
|
mock_context = MagicMock()
|
||||||
|
mock_context.socket.return_value = mock_sub_socket
|
||||||
|
monkeypatch.setattr(robot.Context, "instance", lambda: mock_context)
|
||||||
|
|
||||||
|
mock_pub_socket = AsyncMock()
|
||||||
|
client.app.state.endpoints_pub_socket = mock_pub_socket
|
||||||
|
|
||||||
|
# Mock settings
|
||||||
|
mock_settings = MagicMock()
|
||||||
|
mock_settings.zmq_settings.internal_sub_address = "tcp://localhost:5555"
|
||||||
|
monkeypatch.setattr(robot, "settings", mock_settings)
|
||||||
|
|
||||||
|
# Mock logger to verify debug message is logged
|
||||||
|
mock_logger = MagicMock()
|
||||||
|
monkeypatch.setattr(robot.logger, "debug", mock_logger)
|
||||||
|
|
||||||
|
# Act
|
||||||
|
response = client.get("/get_available_gesture_tags")
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert response.status_code == 200
|
||||||
|
# On timeout, body becomes b"" and json.loads(b"") raises JSONDecodeError
|
||||||
|
# But looking at the endpoint code, it will try to parse empty bytes which will fail
|
||||||
|
# Let's check what actually happens
|
||||||
|
assert response.json() == {"available_gesture_tags": []}
|
||||||
|
|
||||||
|
# Verify the timeout was logged
|
||||||
|
mock_logger.assert_called_once_with("got timeout error fetching gestures")
|
||||||
|
|
||||||
|
# Verify ZeroMQ interactions
|
||||||
|
mock_sub_socket.connect.assert_called_once_with("tcp://localhost:5555")
|
||||||
|
mock_sub_socket.setsockopt.assert_called_once_with(robot.zmq.SUBSCRIBE, b"get_gestures")
|
||||||
|
mock_pub_socket.send_multipart.assert_awaited_once_with([b"send_gestures", b""])
|
||||||
|
mock_sub_socket.recv_multipart.assert_awaited_once()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_available_gesture_tags_empty_response(client, monkeypatch):
|
||||||
|
"""
|
||||||
|
Test scenario when response contains no tags.
|
||||||
|
"""
|
||||||
|
# Arrange
|
||||||
|
mock_sub_socket = AsyncMock()
|
||||||
|
mock_sub_socket.connect = MagicMock()
|
||||||
|
mock_sub_socket.setsockopt = MagicMock()
|
||||||
|
|
||||||
|
# Simulate a response with empty tags
|
||||||
|
response_data = {"tags": []}
|
||||||
|
mock_sub_socket.recv_multipart = AsyncMock(
|
||||||
|
return_value=[b"get_gestures", json.dumps(response_data).encode()]
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_context = MagicMock()
|
||||||
|
mock_context.socket.return_value = mock_sub_socket
|
||||||
|
monkeypatch.setattr(robot.Context, "instance", lambda: mock_context)
|
||||||
|
|
||||||
|
mock_pub_socket = AsyncMock()
|
||||||
|
client.app.state.endpoints_pub_socket = mock_pub_socket
|
||||||
|
|
||||||
|
# Mock settings
|
||||||
|
mock_settings = MagicMock()
|
||||||
|
mock_settings.zmq_settings.internal_sub_address = "tcp://localhost:5555"
|
||||||
|
monkeypatch.setattr(robot, "settings", mock_settings)
|
||||||
|
|
||||||
|
# Act
|
||||||
|
response = client.get("/get_available_gesture_tags")
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json() == {"available_gesture_tags": []}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_available_gesture_tags_missing_tags_key(client, monkeypatch):
|
||||||
|
"""
|
||||||
|
Test scenario when response JSON doesn't contain 'tags' key.
|
||||||
|
"""
|
||||||
|
# Arrange
|
||||||
|
mock_sub_socket = AsyncMock()
|
||||||
|
mock_sub_socket.connect = MagicMock()
|
||||||
|
mock_sub_socket.setsockopt = MagicMock()
|
||||||
|
|
||||||
|
# Simulate a response without 'tags' key
|
||||||
|
response_data = {"some_other_key": "value"}
|
||||||
|
mock_sub_socket.recv_multipart = AsyncMock(
|
||||||
|
return_value=[b"get_gestures", json.dumps(response_data).encode()]
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_context = MagicMock()
|
||||||
|
mock_context.socket.return_value = mock_sub_socket
|
||||||
|
monkeypatch.setattr(robot.Context, "instance", lambda: mock_context)
|
||||||
|
|
||||||
|
mock_pub_socket = AsyncMock()
|
||||||
|
client.app.state.endpoints_pub_socket = mock_pub_socket
|
||||||
|
|
||||||
|
# Mock settings
|
||||||
|
mock_settings = MagicMock()
|
||||||
|
mock_settings.zmq_settings.internal_sub_address = "tcp://localhost:5555"
|
||||||
|
monkeypatch.setattr(robot, "settings", mock_settings)
|
||||||
|
|
||||||
|
# Act
|
||||||
|
response = client.get("/get_available_gesture_tags")
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert response.status_code == 200
|
||||||
|
# .get("tags", []) should return empty list if 'tags' key is missing
|
||||||
|
assert response.json() == {"available_gesture_tags": []}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_available_gesture_tags_invalid_json(client, monkeypatch):
|
||||||
|
"""
|
||||||
|
Test scenario when response contains invalid JSON.
|
||||||
|
"""
|
||||||
|
# Arrange
|
||||||
|
mock_sub_socket = AsyncMock()
|
||||||
|
mock_sub_socket.connect = MagicMock()
|
||||||
|
mock_sub_socket.setsockopt = MagicMock()
|
||||||
|
|
||||||
|
# Simulate a response with invalid JSON
|
||||||
|
mock_sub_socket.recv_multipart = AsyncMock(return_value=[b"get_gestures", b"invalid json"])
|
||||||
|
|
||||||
|
mock_context = MagicMock()
|
||||||
|
mock_context.socket.return_value = mock_sub_socket
|
||||||
|
monkeypatch.setattr(robot.Context, "instance", lambda: mock_context)
|
||||||
|
|
||||||
|
mock_pub_socket = AsyncMock()
|
||||||
|
client.app.state.endpoints_pub_socket = mock_pub_socket
|
||||||
|
|
||||||
|
# Mock settings
|
||||||
|
mock_settings = MagicMock()
|
||||||
|
mock_settings.zmq_settings.internal_sub_address = "tcp://localhost:5555"
|
||||||
|
monkeypatch.setattr(robot, "settings", mock_settings)
|
||||||
|
|
||||||
|
# Act
|
||||||
|
response = client.get("/get_available_gesture_tags")
|
||||||
|
|
||||||
|
# Assert - invalid JSON should raise an exception
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json() == {"available_gesture_tags": []}
|
||||||
|
|||||||
Reference in New Issue
Block a user