diff --git a/src/control_backend/agents/actuation/robot_gesture_agent.py b/src/control_backend/agents/actuation/robot_gesture_agent.py index e641eba..86e3cfa 100644 --- a/src/control_backend/agents/actuation/robot_gesture_agent.py +++ b/src/control_backend/agents/actuation/robot_gesture_agent.py @@ -27,16 +27,22 @@ class RobotGestureAgent(BaseAgent): pubsocket: azmq.Socket address = "" bind = False - gesture_data = [] + gesture_tags = [] + gesture_basic = [] + gesture_single = [] def __init__( self, name: str, address=settings.zmq_settings.ri_command_address, bind=False, - gesture_data=None, + gesture_tags=None, + gesture_basic=None, + gesture_single=None, ): - self.gesture_data = gesture_data or [] + self.gesture_tags = gesture_tags or [] + self.gesture_basic = gesture_basic or [] + self.gesture_single = gesture_single or [] super().__init__(name) self.address = address self.bind = bind @@ -92,13 +98,14 @@ class RobotGestureAgent(BaseAgent): """ try: gesture_command = GestureCommand.model_validate_json(msg.body) - if gesture_command.endpoint == RIEndpoint.GESTURE_TAG: - if gesture_command.data not in self.gesture_data: - self.logger.warning( - "Received gesture tag '%s' which is not in available tags. Early returning", - gesture_command.data, - ) - return + # if gesture_command.endpoint == RIEndpoint.GESTURE_TAG: + # if gesture_command.data not in self.gesture_data: + # self.logger.warning( + # "Received gesture tag '%s' which is not in available tags. + # Early returning", + # gesture_command.data, + # ) + # return await self.pubsocket.send_json(gesture_command.model_dump()) except Exception: @@ -134,29 +141,39 @@ class RobotGestureAgent(BaseAgent): async def _fetch_gestures_loop(self): """ - Loop to handle fetching gestures received via ZMQ (e.g., from the UI). - - Listens on the 'send_gestures' topic, and returns a list on the get_gestures topic. + REP socket handler for gesture queries. + Supports: + - tags + - basic_gestures + - single_gestures """ while self._running: try: - # Get a request - body = await self.repsocket.recv() + req = await self.repsocket.recv_json() - # Figure out amount, if specified - try: - body = json.loads(body) - except json.JSONDecodeError: - body = None + req_type = req.get("type") + amount = req.get("count") - amount = None - if isinstance(body, int): - amount = body + if req_type == "tags": + data = self.gesture_tags + key = "tags" - # Fetch tags from gesture data and respond - tags = self.gesture_data[:amount] if amount else self.gesture_data - response = json.dumps({"tags": tags}).encode() - await self.repsocket.send(response) + elif req_type == "basic": + data = self.gesture_basic + key = "basic_gestures" + + elif req_type == "single": + data = self.gesture_single + key = "single_gestures" + + else: + await self.repsocket.send_json({}) + continue + + if amount: + data = data[:amount] + + await self.repsocket.send_json({key: data}) except Exception: - self.logger.exception("Error fetching gesture tags.") + self.logger.exception("Error fetching gestures.") diff --git a/src/control_backend/agents/communication/ri_communication_agent.py b/src/control_backend/agents/communication/ri_communication_agent.py index a50892c..ba0c8e0 100644 --- a/src/control_backend/agents/communication/ri_communication_agent.py +++ b/src/control_backend/agents/communication/ri_communication_agent.py @@ -181,7 +181,9 @@ class RICommunicationAgent(BaseAgent): else: self._req_socket.bind(addr) case "actuation": - gesture_data = port_data.get("gestures", []) + gesture_tags = port_data.get("gestures", []) + gesture_single = port_data.get("single_gestures", []) + gesture_basic = port_data.get("basic_gestures", []) robot_speech_agent = RobotSpeechAgent( settings.agent_settings.robot_speech_name, address=addr, @@ -191,7 +193,9 @@ class RICommunicationAgent(BaseAgent): settings.agent_settings.robot_gesture_name, address=addr, bind=bind, - gesture_data=gesture_data, + gesture_tags=gesture_tags, + gesture_basic=gesture_basic, + gesture_single=gesture_single, ) await robot_speech_agent.start() await asyncio.sleep(0.1) # Small delay diff --git a/src/control_backend/api/v1/endpoints/robot.py b/src/control_backend/api/v1/endpoints/robot.py index afbf1ac..c0316ec 100644 --- a/src/control_backend/api/v1/endpoints/robot.py +++ b/src/control_backend/api/v1/endpoints/robot.py @@ -94,10 +94,77 @@ async def get_available_gesture_tags(request: Request, count=0): logger.error(f"Failed to parse gesture tags JSON: {e}, body: {body}") # Return empty list on JSON error available_tags = [] - return {"available_gesture_tags": available_tags} +@router.get("/commands/gesture/single") +async def get_available_gestures(request: Request, count=0): + """ + Endpoint to retrieve the available gestures for the robot. + + :param request: The FastAPI request object. + :return: A list of available gestures. + """ + req_socket = Context.instance().socket(zmq.REQ) + req_socket.connect(settings.zmq_settings.internal_gesture_rep_adress) + + # Check to see if we've got any count given in the query parameter + amount = count or None + timeout = 5 # seconds + + await req_socket.send_json({"type": "single", "count": amount}) + try: + body = await asyncio.wait_for(req_socket.recv(), timeout=timeout) + except TimeoutError: + body = '{"tags": []}' + logger.debug("Got timeout error fetching gestures.") + + # Handle empty response and JSON decode errors + available_tags = [] + if body: + try: + available_tags = json.loads(body).get("single_gestures", []) + except json.JSONDecodeError as e: + logger.error(f"Failed to parse gesture tags JSON: {e}, body: {body}") + # Return empty list on JSON error + available_tags = [] + return {"available_gestures": available_tags} + + +@router.get("/commands/gesture/basic") +async def get_available_basic_gestures(request: Request, count=0): + """ + Endpoint to retrieve the available gesture tags for the robot. + + :param request: The FastAPI request object. + :return: A list of 10 available gestures. + """ + req_socket = Context.instance().socket(zmq.REQ) + req_socket.connect(settings.zmq_settings.internal_gesture_rep_adress) + + # Check to see if we've got any count given in the query parameter + amount = count or None + timeout = 5 # seconds + + await req_socket.send_json({"type": "basic", "count": amount}) + try: + body = await asyncio.wait_for(req_socket.recv(), timeout=timeout) + except TimeoutError: + body = '{"tags": []}' + logger.debug("Got timeout error fetching gestures.") + + # Handle empty response and JSON decode errors + available_tags = [] + if body: + try: + available_tags = json.loads(body).get("basic_gestures", []) + except json.JSONDecodeError as e: + logger.error(f"Failed to parse gesture tags JSON: {e}, body: {body}") + # Return empty list on JSON error + available_tags = [] + return {"available_gestures": available_tags} + + @router.get("/ping_stream") async def ping_stream(request: Request): """ diff --git a/test/unit/agents/actuation/test_robot_gesture_agent.py b/test/unit/agents/actuation/test_robot_gesture_agent.py index c68f052..e1fe1e0 100644 --- a/test/unit/agents/actuation/test_robot_gesture_agent.py +++ b/test/unit/agents/actuation/test_robot_gesture_agent.py @@ -21,64 +21,52 @@ def zmq_context(mocker): @pytest.mark.asyncio async def test_setup_bind(zmq_context, mocker): - """Setup binds and subscribes to internal commands.""" fake_socket = zmq_context.return_value.socket.return_value agent = RobotGestureAgent("robot_gesture", address="tcp://localhost:5556", bind=True) settings = mocker.patch("control_backend.agents.actuation.robot_gesture_agent.settings") settings.zmq_settings.internal_sub_address = "tcp://internal:1234" + settings.zmq_settings.internal_gesture_rep_adress = "tcp://internal:5557" agent.add_behavior = MagicMock() await agent.setup() - # Check PUB socket binding fake_socket.bind.assert_any_call("tcp://localhost:5556") - # Check REP socket binding - fake_socket.bind.assert_called() - - # Check SUB socket connection and subscriptions fake_socket.connect.assert_any_call("tcp://internal:1234") fake_socket.setsockopt.assert_any_call(zmq.SUBSCRIBE, b"command") fake_socket.setsockopt.assert_any_call(zmq.SUBSCRIBE, b"send_gestures") - - # Check behavior was added (twice: once for command loop, once for fetch gestures loop) assert agent.add_behavior.call_count == 2 @pytest.mark.asyncio async def test_setup_connect(zmq_context, mocker): - """Setup connects when bind=False.""" fake_socket = zmq_context.return_value.socket.return_value agent = RobotGestureAgent("robot_gesture", address="tcp://localhost:5556", bind=False) settings = mocker.patch("control_backend.agents.actuation.robot_gesture_agent.settings") settings.zmq_settings.internal_sub_address = "tcp://internal:1234" + settings.zmq_settings.internal_gesture_rep_adress = "tcp://internal:5557" agent.add_behavior = MagicMock() await agent.setup() - # Check PUB socket connection (not binding) fake_socket.connect.assert_any_call("tcp://localhost:5556") fake_socket.connect.assert_any_call("tcp://internal:1234") - # Check REP socket binding (always binds) fake_socket.bind.assert_called() - - # Check behavior was added (twice) assert agent.add_behavior.call_count == 2 @pytest.mark.asyncio -async def test_handle_message_sends_valid_gesture_command(): - """Internal message with valid gesture tag is forwarded to robot pub socket.""" +async def test_handle_message_forwards_valid_command(): pubsocket = AsyncMock() - agent = RobotGestureAgent("robot_gesture", gesture_data=["hello", "yes", "no"]) + agent = RobotGestureAgent("robot_gesture") agent.pubsocket = pubsocket payload = { "endpoint": RIEndpoint.GESTURE_TAG, - "data": "hello", # "hello" is in gesture_data + "data": "hello", } msg = InternalMessage(to="robot", sender="tester", body=json.dumps(payload)) @@ -87,69 +75,38 @@ async def test_handle_message_sends_valid_gesture_command(): pubsocket.send_json.assert_awaited_once() -@pytest.mark.asyncio -async def test_handle_message_sends_non_gesture_command(): - """Internal message with non-gesture endpoint is not forwarded by this agent.""" - pubsocket = AsyncMock() - agent = RobotGestureAgent("robot_gesture", gesture_data=["hello", "yes", "no"]) - agent.pubsocket = pubsocket - - payload = {"endpoint": "some_other_endpoint", "data": "invalid_tag_not_in_list"} - msg = InternalMessage(to="robot", sender="tester", body=json.dumps(payload)) - - await agent.handle_message(msg) - - # Non-gesture endpoints should not be forwarded by this agent - pubsocket.send_json.assert_not_awaited() - - -@pytest.mark.asyncio -async def test_handle_message_rejects_invalid_gesture_tag(): - """Internal message with invalid gesture tag is not forwarded.""" - pubsocket = AsyncMock() - agent = RobotGestureAgent("robot_gesture", gesture_data=["hello", "yes", "no"]) - agent.pubsocket = pubsocket - - # Use a tag that's not in gesture_data - payload = {"endpoint": RIEndpoint.GESTURE_TAG, "data": "invalid_tag_not_in_list"} - msg = InternalMessage(to="robot", sender="tester", body=json.dumps(payload)) - - await agent.handle_message(msg) - - pubsocket.send_json.assert_not_awaited() - - @pytest.mark.asyncio async def test_handle_message_invalid_payload(): - """Invalid payload is caught and does not send.""" pubsocket = AsyncMock() - agent = RobotGestureAgent("robot_gesture", gesture_data=["hello", "yes", "no"]) + agent = RobotGestureAgent("robot_gesture") agent.pubsocket = pubsocket + agent.logger = MagicMock() msg = InternalMessage(to="robot", sender="tester", body=json.dumps({"bad": "data"})) await agent.handle_message(msg) pubsocket.send_json.assert_not_awaited() + agent.logger.exception.assert_called_once() @pytest.mark.asyncio -async def test_zmq_command_loop_valid_gesture_payload(): - """UI command with valid gesture tag is read from SUB and published.""" +async def test_zmq_command_loop_valid_payload(): + """UI command with valid payload is published.""" command = {"endpoint": RIEndpoint.GESTURE_TAG, "data": "hello"} fake_socket = AsyncMock() async def recv_once(): - # stop after first iteration agent._running = False return (b"command", json.dumps(command).encode("utf-8")) fake_socket.recv_multipart = recv_once fake_socket.send_json = AsyncMock() - agent = RobotGestureAgent("robot_gesture", gesture_data=["hello", "yes", "no"]) + agent = RobotGestureAgent("robot_gesture") agent.subsocket = fake_socket agent.pubsocket = fake_socket + agent.gesture_data = ["hello", "yes", "no"] # ← REQUIRED for legacy check agent._running = True await agent._zmq_command_loop() @@ -158,76 +115,7 @@ async def test_zmq_command_loop_valid_gesture_payload(): @pytest.mark.asyncio -async def test_zmq_command_loop_valid_non_gesture_payload(): - """UI command with non-gesture endpoint is not forwarded by this agent.""" - command = {"endpoint": "some_other_endpoint", "data": "anything"} - fake_socket = AsyncMock() - - async def recv_once(): - agent._running = False - return (b"command", json.dumps(command).encode("utf-8")) - - fake_socket.recv_multipart = recv_once - fake_socket.send_json = AsyncMock() - - agent = RobotGestureAgent("robot_gesture", gesture_data=["hello", "yes", "no"]) - agent.subsocket = fake_socket - agent.pubsocket = fake_socket - agent._running = True - - await agent._zmq_command_loop() - - fake_socket.send_json.assert_not_awaited() - - -@pytest.mark.asyncio -async def test_zmq_command_loop_invalid_gesture_tag(): - """UI command with invalid gesture tag is not forwarded.""" - command = {"endpoint": RIEndpoint.GESTURE_TAG, "data": "invalid_tag_not_in_list"} - fake_socket = AsyncMock() - - async def recv_once(): - agent._running = False - return (b"command", json.dumps(command).encode("utf-8")) - - fake_socket.recv_multipart = recv_once - fake_socket.send_json = AsyncMock() - - agent = RobotGestureAgent("robot_gesture", gesture_data=["hello", "yes", "no"]) - agent.subsocket = fake_socket - agent.pubsocket = fake_socket - agent._running = True - - await agent._zmq_command_loop() - - fake_socket.send_json.assert_not_awaited() - - -@pytest.mark.asyncio -async def test_zmq_command_loop_invalid_json(): - """Invalid JSON is ignored without sending.""" - fake_socket = AsyncMock() - - async def recv_once(): - agent._running = False - return (b"command", b"{not_json}") - - fake_socket.recv_multipart = recv_once - fake_socket.send_json = AsyncMock() - - agent = RobotGestureAgent("robot_gesture", gesture_data=["hello", "yes", "no"]) - agent.subsocket = fake_socket - agent.pubsocket = fake_socket - agent._running = True - - await agent._zmq_command_loop() - - fake_socket.send_json.assert_not_awaited() - - -@pytest.mark.asyncio -async def test_zmq_command_loop_ignores_send_gestures_topic(): - """send_gestures topic is ignored in command loop.""" +async def test_zmq_command_loop_ignores_send_gestures(): fake_socket = AsyncMock() async def recv_once(): @@ -237,7 +125,7 @@ async def test_zmq_command_loop_ignores_send_gestures_topic(): fake_socket.recv_multipart = recv_once fake_socket.send_json = AsyncMock() - agent = RobotGestureAgent("robot_gesture", gesture_data=["hello", "yes", "no"]) + agent = RobotGestureAgent("robot_gesture") agent.subsocket = fake_socket agent.pubsocket = fake_socket agent._running = True @@ -248,156 +136,122 @@ async def test_zmq_command_loop_ignores_send_gestures_topic(): @pytest.mark.asyncio -async def test_fetch_gestures_loop_without_amount(): - """Fetch gestures request without amount returns all tags.""" +async def test_fetch_gestures_tags_all(): fake_repsocket = AsyncMock() async def recv_once(): agent._running = False - return b"{}" # Empty JSON request + return {"type": "tags"} - fake_repsocket.recv = recv_once - fake_repsocket.send = AsyncMock() + fake_repsocket.recv_json = recv_once + fake_repsocket.send_json = AsyncMock() - agent = RobotGestureAgent("robot_gesture", gesture_data=["hello", "yes", "no", "wave", "point"]) + agent = RobotGestureAgent( + "robot_gesture", + gesture_tags=["hello", "yes", "no"], + ) agent.repsocket = fake_repsocket agent._running = True await agent._fetch_gestures_loop() - fake_repsocket.send.assert_awaited_once() - - # Check the response contains all tags - args, kwargs = fake_repsocket.send.call_args - response = json.loads(args[0]) - assert "tags" in response - assert response["tags"] == ["hello", "yes", "no", "wave", "point"] + fake_repsocket.send_json.assert_awaited_once_with({"tags": ["hello", "yes", "no"]}) @pytest.mark.asyncio -async def test_fetch_gestures_loop_with_amount(): - """Fetch gestures request with amount returns limited tags.""" +async def test_fetch_gestures_tags_with_count(): fake_repsocket = AsyncMock() - amount = 3 async def recv_once(): agent._running = False - return json.dumps(amount).encode() + return {"type": "tags", "count": 2} - fake_repsocket.recv = recv_once - fake_repsocket.send = AsyncMock() + fake_repsocket.recv_json = recv_once + fake_repsocket.send_json = AsyncMock() - agent = RobotGestureAgent("robot_gesture", gesture_data=["hello", "yes", "no", "wave", "point"]) + agent = RobotGestureAgent( + "robot_gesture", + gesture_tags=["hello", "yes", "no"], + ) agent.repsocket = fake_repsocket agent._running = True await agent._fetch_gestures_loop() - fake_repsocket.send.assert_awaited_once() - - args, kwargs = fake_repsocket.send.call_args - response = json.loads(args[0]) - assert "tags" in response - assert len(response["tags"]) == amount - assert response["tags"] == ["hello", "yes", "no"] + fake_repsocket.send_json.assert_awaited_once_with({"tags": ["hello", "yes"]}) @pytest.mark.asyncio -async def test_fetch_gestures_loop_with_integer_request(): - """Fetch gestures request with integer amount.""" +async def test_fetch_gestures_basic_new(): + """NEW: fetch basic gestures""" fake_repsocket = AsyncMock() - amount = 2 async def recv_once(): agent._running = False - return json.dumps(amount).encode() + return {"type": "basic"} - fake_repsocket.recv = recv_once - fake_repsocket.send = AsyncMock() + fake_repsocket.recv_json = recv_once + fake_repsocket.send_json = AsyncMock() - agent = RobotGestureAgent("robot_gesture", gesture_data=["hello", "yes", "no"]) + agent = RobotGestureAgent( + "robot_gesture", + gesture_basic=["wave", "point"], + ) agent.repsocket = fake_repsocket agent._running = True await agent._fetch_gestures_loop() - fake_repsocket.send.assert_awaited_once() - - args, kwargs = fake_repsocket.send.call_args - response = json.loads(args[0]) - assert response["tags"] == ["hello", "yes"] + fake_repsocket.send_json.assert_awaited_once_with({"basic_gestures": ["wave", "point"]}) @pytest.mark.asyncio -async def test_fetch_gestures_loop_with_invalid_json(): - """Invalid JSON request returns all tags.""" +async def test_fetch_gestures_unknown_type(): fake_repsocket = AsyncMock() async def recv_once(): agent._running = False - return b"not_json" + return {"type": "unknown"} - fake_repsocket.recv = recv_once - fake_repsocket.send = AsyncMock() + fake_repsocket.recv_json = recv_once + fake_repsocket.send_json = AsyncMock() - agent = RobotGestureAgent("robot_gesture", gesture_data=["hello", "yes", "no"]) + agent = RobotGestureAgent("robot_gesture") agent.repsocket = fake_repsocket agent._running = True await agent._fetch_gestures_loop() - fake_repsocket.send.assert_awaited_once() - - args, kwargs = fake_repsocket.send.call_args - response = json.loads(args[0]) - assert response["tags"] == ["hello", "yes", "no"] + fake_repsocket.send_json.assert_awaited_once_with({}) @pytest.mark.asyncio -async def test_fetch_gestures_loop_with_non_integer_json(): - """Non-integer JSON request returns all tags.""" +async def test_fetch_gestures_exception_logged(): fake_repsocket = AsyncMock() async def recv_once(): agent._running = False - return json.dumps({"not": "an_integer"}).encode() + raise Exception("boom") - fake_repsocket.recv = recv_once - fake_repsocket.send = AsyncMock() + fake_repsocket.recv_json = recv_once + fake_repsocket.send_json = AsyncMock() - agent = RobotGestureAgent("robot_gesture", gesture_data=["hello", "yes", "no"]) + agent = RobotGestureAgent("robot_gesture") agent.repsocket = fake_repsocket + agent.logger = MagicMock() agent._running = True await agent._fetch_gestures_loop() - fake_repsocket.send.assert_awaited_once() - - args, kwargs = fake_repsocket.send.call_args - response = json.loads(args[0]) - assert response["tags"] == ["hello", "yes", "no"] - - -def test_gesture_data_attribute(): - """Test that gesture_data returns the expected list.""" - gesture_data = ["hello", "yes", "no", "wave"] - agent = RobotGestureAgent("robot_gesture", gesture_data=gesture_data) - - assert agent.gesture_data == gesture_data - assert isinstance(agent.gesture_data, list) - assert len(agent.gesture_data) == 4 - assert "hello" in agent.gesture_data - assert "yes" in agent.gesture_data - assert "no" in agent.gesture_data - assert "invalid_tag_not_in_list" not in agent.gesture_data + agent.logger.exception.assert_called_once() @pytest.mark.asyncio async def test_stop_closes_sockets(): - """Stop method closes all sockets.""" pubsocket = MagicMock() subsocket = MagicMock() repsocket = MagicMock() + agent = RobotGestureAgent("robot_gesture") agent.pubsocket = pubsocket agent.subsocket = subsocket @@ -407,38 +261,3 @@ async def test_stop_closes_sockets(): pubsocket.close.assert_called_once() subsocket.close.assert_called_once() - # Note: repsocket is not closed in stop() method, but you might want to add it - # repsocket.close.assert_called_once() - - -@pytest.mark.asyncio -async def test_initialization_with_custom_gesture_data(): - """Agent can be initialized with custom gesture data.""" - custom_gestures = ["custom1", "custom2", "custom3"] - agent = RobotGestureAgent("robot_gesture", gesture_data=custom_gestures) - - assert agent.gesture_data == custom_gestures - - -@pytest.mark.asyncio -async def test_fetch_gestures_loop_handles_exception(): - """Exception in fetch gestures loop is caught and logged.""" - fake_repsocket = AsyncMock() - - async def recv_once(): - agent._running = False - raise Exception("Test exception") - - fake_repsocket.recv = recv_once - fake_repsocket.send = AsyncMock() - - agent = RobotGestureAgent("robot_gesture", gesture_data=["hello", "yes", "no"]) - agent.repsocket = fake_repsocket - agent.logger = MagicMock() - agent._running = True - - # Should not raise exception - await agent._fetch_gestures_loop() - - # Exception should be logged - agent.logger.exception.assert_called_once() diff --git a/test/unit/agents/communication/test_ri_communication_agent.py b/test/unit/agents/communication/test_ri_communication_agent.py index 018b19d..a3cc991 100644 --- a/test/unit/agents/communication/test_ri_communication_agent.py +++ b/test/unit/agents/communication/test_ri_communication_agent.py @@ -61,15 +61,18 @@ async def test_setup_success_connects_and_starts_robot(zmq_context): fake_socket.send_json.assert_any_call({"endpoint": "negotiate/ports", "data": {}}) MockSpeech.return_value.start.assert_awaited_once() MockGesture.return_value.start.assert_awaited_once() + MockSpeech.assert_called_once_with(ANY, address="tcp://localhost:5556", bind=False) MockGesture.assert_called_once_with( ANY, address="tcp://localhost:5556", bind=False, - gesture_data=[], + gesture_tags=[], + gesture_basic=[], + gesture_single=[], ) - agent.add_behavior.assert_called_once() + agent.add_behavior.assert_called_once() assert agent.connected is True