diff --git a/src/control_backend/agents/actuation/robot_gesture_agent.py b/src/control_backend/agents/actuation/robot_gesture_agent.py index 9f51d21..8447190 100644 --- a/src/control_backend/agents/actuation/robot_gesture_agent.py +++ b/src/control_backend/agents/actuation/robot_gesture_agent.py @@ -86,8 +86,8 @@ class RobotGestureAgent(BaseAgent): :param msg: The internal message containing the command. """ try: - speech_command = GestureCommand.model_validate_json(msg.body) - await self.pubsocket.send_json(speech_command.model_dump()) + gesture_command = GestureCommand.model_validate_json(msg.body) + await self.pubsocket.send_json(gesture_command.model_dump()) except Exception: self.logger.exception("Error processing internal message.") @@ -107,3 +107,129 @@ class RobotGestureAgent(BaseAgent): await self.pubsocket.send_json(message.model_dump()) except Exception: self.logger.exception("Error processing ZMQ message.") + + def availableTags(self): + """ + Returns the available gesture tags. + + :return: List of available gesture tags. + """ + return [ + "above", + "affirmative", + "afford", + "agitated", + "all", + "allright", + "alright", + "any", + "assuage", + "assuage", + "attemper", + "back", + "bashful", + "beg", + "beseech", + "blank", + "body language", + "bored", + "bow", + "but", + "call", + "calm", + "choose", + "choice", + "cloud", + "cogitate", + "cool", + "crazy", + "disappointed", + "down", + "earth", + "empty", + "embarrassed", + "enthusiastic", + "entire", + "estimate", + "except", + "exalted", + "excited", + "explain", + "far", + "field", + "floor", + "forlorn", + "friendly", + "front", + "frustrated", + "gentle", + "gift", + "give", + "ground", + "happy", + "hello", + "her", + "here", + "hey", + "hi", + "him", + "hopeless", + "hysterical", + "I", + "implore", + "indicate", + "joyful", + "me", + "meditate", + "modest", + "negative", + "nervous", + "no", + "not know", + "nothing", + "offer", + "ok", + "once upon a time", + "oppose", + "or", + "pacify", + "pick", + "placate", + "please", + "present", + "proffer", + "quiet", + "reason", + "refute", + "reject", + "rousing", + "sad", + "select", + "shamefaced", + "show", + "show sky", + "sky", + "soothe", + "sun", + "supplicate", + "tablet", + "tall", + "them", + "there", + "think", + "timid", + "top", + "unless", + "up", + "upstairs", + "void", + "warm", + "winner", + "yeah", + "yes", + "yoo-hoo", + "you", + "your", + "zero", + "zestful", + ] diff --git a/src/control_backend/agents/communication/ri_communication_agent.py b/src/control_backend/agents/communication/ri_communication_agent.py index 1b72fe7..5b89088 100644 --- a/src/control_backend/agents/communication/ri_communication_agent.py +++ b/src/control_backend/agents/communication/ri_communication_agent.py @@ -193,6 +193,7 @@ class RICommunicationAgent(BaseAgent): gesture_data=gesture_data, ) await robot_speech_agent.start() + await asyncio.sleep(0.1) # Small delay await robot_gesture_agent.start() case _: self.logger.warning("Unhandled negotiation id: %s", id) diff --git a/src/control_backend/schemas/ri_message.py b/src/control_backend/schemas/ri_message.py index fd073a3..3f3abea 100644 --- a/src/control_backend/schemas/ri_message.py +++ b/src/control_backend/schemas/ri_message.py @@ -1,7 +1,7 @@ from enum import Enum -from typing import Any +from typing import Any, Literal -from pydantic import BaseModel +from pydantic import BaseModel, model_validator class RIEndpoint(str, Enum): @@ -48,5 +48,17 @@ class GestureCommand(RIMessage): :ivar data: The id of the gesture to be executed. """ - endpoint: RIEndpoint = RIEndpoint(RIEndpoint.GESTURE_TAG) or RIEndpoint(RIEndpoint.GESTURE_TAG) + endpoint: Literal[ # pyright: ignore[reportIncompatibleVariableOverride] - We validate this stricter rule ourselves + RIEndpoint.GESTURE_SINGLE, RIEndpoint.GESTURE_TAG + ] data: str + + @model_validator(mode="after") + def check_endpoint(self): + allowed = { + RIEndpoint.GESTURE_SINGLE, + RIEndpoint.GESTURE_TAG, + } + if self.endpoint not in allowed: + raise ValueError("endpoint must be GESTURE_SINGLE or GESTURE_TAG") + return self diff --git a/test/unit/agents/communication/test_ri_communication_agent.py b/test/unit/agents/communication/test_ri_communication_agent.py index 747c4d2..54f3c5a 100644 --- a/test/unit/agents/communication/test_ri_communication_agent.py +++ b/test/unit/agents/communication/test_ri_communication_agent.py @@ -10,6 +10,10 @@ def speech_agent_path(): return "control_backend.agents.communication.ri_communication_agent.RobotSpeechAgent" +def gesture_agent_path(): + return "control_backend.agents.communication.ri_communication_agent.RobotGestureAgent" + + @pytest.fixture def zmq_context(mocker): mock_context = mocker.patch( @@ -22,7 +26,7 @@ def zmq_context(mocker): def negotiation_message( actuation_port: int = 5556, bind_main: bool = False, - bind_actuation: bool = True, + bind_actuation: bool = False, main_port: int = 5555, ): return { @@ -41,9 +45,12 @@ async def test_setup_success_connects_and_starts_robot(zmq_context): fake_socket.recv_json = AsyncMock(return_value=negotiation_message()) fake_socket.send_multipart = AsyncMock() - with patch(speech_agent_path(), autospec=True) as MockRobot: - robot_instance = MockRobot.return_value - robot_instance.start = AsyncMock() + with ( + patch(speech_agent_path(), autospec=True) as MockSpeech, + patch(gesture_agent_path(), autospec=True) as MockGesture, + ): + MockSpeech.return_value.start = AsyncMock() + MockGesture.return_value.start = AsyncMock() agent = RICommunicationAgent("ri_comm", address="tcp://localhost:5555", bind=False) agent.add_behavior = MagicMock() @@ -52,9 +59,17 @@ async def test_setup_success_connects_and_starts_robot(zmq_context): fake_socket.connect.assert_any_call("tcp://localhost:5555") fake_socket.send_json.assert_any_call({"endpoint": "negotiate/ports", "data": {}}) - robot_instance.start.assert_awaited_once() - MockRobot.assert_called_once_with(ANY, address="tcp://*:5556", bind=True) + MockSpeech.return_value.start.assert_awaited_once() + MockGesture.return_value.start.assert_awaited_once() + MockSpeech.assert_called_once_with(ANY, address="tcp://localhost:5556", bind=False) + MockGesture.assert_called_once_with( + ANY, + address="tcp://localhost:5556", + bind=False, + gesture_data=[], + ) agent.add_behavior.assert_called_once() + assert agent.connected is True @@ -69,10 +84,13 @@ async def test_setup_binds_when_requested(zmq_context): agent.add_behavior = MagicMock() - with patch(speech_agent_path(), autospec=True) as MockRobot: - MockRobot.return_value.start = AsyncMock() + with ( + patch(speech_agent_path(), autospec=True) as MockSpeech, + patch(gesture_agent_path(), autospec=True) as MockGesture, + ): + MockSpeech.return_value.start = AsyncMock() + MockGesture.return_value.start = AsyncMock() await agent.setup() - fake_socket.bind.assert_any_call("tcp://localhost:5555") agent.add_behavior.assert_called_once() @@ -88,7 +106,6 @@ async def test_negotiate_invalid_endpoint_retries(zmq_context): agent._req_socket = fake_socket success = await agent._negotiate_connection(max_retries=1) - assert success is False @@ -112,8 +129,12 @@ async def test_handle_negotiation_response_updates_req_socket(zmq_context): fake_socket = zmq_context.return_value.socket.return_value agent = RICommunicationAgent("ri_comm", address="tcp://localhost:5555", bind=False) agent._req_socket = fake_socket - with patch(speech_agent_path(), autospec=True) as MockRobot: - MockRobot.return_value.start = AsyncMock() + with ( + patch(speech_agent_path(), autospec=True) as MockSpeech, + patch(gesture_agent_path(), autospec=True) as MockGesture, + ): + MockSpeech.return_value.start = AsyncMock() + MockGesture.return_value.start = AsyncMock() await agent._handle_negotiation_response( negotiation_message( main_port=6000, @@ -135,7 +156,6 @@ async def test_handle_disconnection_publishes_and_reconnects(): agent._negotiate_connection = AsyncMock(return_value=True) await agent._handle_disconnection() - pub_socket.send_multipart.assert_awaited() assert agent.connected is True @@ -192,7 +212,7 @@ async def test_setup_warns_on_failed_negotiate(zmq_context, mocker): fake_socket.recv_json = AsyncMock() agent = RICommunicationAgent("ri_comm") - async def swallow(coro): + def swallow(coro): coro.close() agent.add_behavior = swallow diff --git a/test/unit/schemas/test_ri_message.py b/test/unit/schemas/test_ri_message.py index 193f7c3..40601ec 100644 --- a/test/unit/schemas/test_ri_message.py +++ b/test/unit/schemas/test_ri_message.py @@ -21,7 +21,21 @@ def invalid_command_1(): def invalid_command_2(): - return GestureCommand(endpoint=RIEndpoint.PING, data="Hey!") + return RIMessage(endpoint=RIEndpoint.PING, data="Hey!") + + +def invalid_command_3(): + return RIMessage(endpoint=RIEndpoint.GESTURE_SINGLE, data={1, 2, 3}) + + +def invalid_command_4(): + test: RIMessage = GestureCommand(endpoint=RIEndpoint.GESTURE_SINGLE, data="asdsad") + + def change_endpoint(msg: RIMessage): + msg.endpoint = RIEndpoint.PING + + change_endpoint(test) + return test def test_valid_speech_command_1(): @@ -56,3 +70,19 @@ def test_invalid_gesture_command_1(): with pytest.raises(ValidationError): GestureCommand.model_validate(command) + + +def test_invalid_gesture_command_2(): + command = invalid_command_3() + RIMessage.model_validate(command) + + with pytest.raises(ValidationError): + GestureCommand.model_validate(command) + + +def test_invalid_gesture_command_3(): + command = invalid_command_4() + RIMessage.model_validate(command) + + with pytest.raises(ValidationError): + GestureCommand.model_validate(command)