feat: add tests and better model validation for gesture commands
ref: N25B-334
This commit is contained in:
@@ -86,8 +86,8 @@ class RobotGestureAgent(BaseAgent):
|
|||||||
:param msg: The internal message containing the command.
|
:param msg: The internal message containing the command.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
speech_command = GestureCommand.model_validate_json(msg.body)
|
gesture_command = GestureCommand.model_validate_json(msg.body)
|
||||||
await self.pubsocket.send_json(speech_command.model_dump())
|
await self.pubsocket.send_json(gesture_command.model_dump())
|
||||||
except Exception:
|
except Exception:
|
||||||
self.logger.exception("Error processing internal message.")
|
self.logger.exception("Error processing internal message.")
|
||||||
|
|
||||||
@@ -107,3 +107,129 @@ class RobotGestureAgent(BaseAgent):
|
|||||||
await self.pubsocket.send_json(message.model_dump())
|
await self.pubsocket.send_json(message.model_dump())
|
||||||
except Exception:
|
except Exception:
|
||||||
self.logger.exception("Error processing ZMQ message.")
|
self.logger.exception("Error processing ZMQ message.")
|
||||||
|
|
||||||
|
def availableTags(self):
|
||||||
|
"""
|
||||||
|
Returns the available gesture tags.
|
||||||
|
|
||||||
|
:return: List of available gesture tags.
|
||||||
|
"""
|
||||||
|
return [
|
||||||
|
"above",
|
||||||
|
"affirmative",
|
||||||
|
"afford",
|
||||||
|
"agitated",
|
||||||
|
"all",
|
||||||
|
"allright",
|
||||||
|
"alright",
|
||||||
|
"any",
|
||||||
|
"assuage",
|
||||||
|
"assuage",
|
||||||
|
"attemper",
|
||||||
|
"back",
|
||||||
|
"bashful",
|
||||||
|
"beg",
|
||||||
|
"beseech",
|
||||||
|
"blank",
|
||||||
|
"body language",
|
||||||
|
"bored",
|
||||||
|
"bow",
|
||||||
|
"but",
|
||||||
|
"call",
|
||||||
|
"calm",
|
||||||
|
"choose",
|
||||||
|
"choice",
|
||||||
|
"cloud",
|
||||||
|
"cogitate",
|
||||||
|
"cool",
|
||||||
|
"crazy",
|
||||||
|
"disappointed",
|
||||||
|
"down",
|
||||||
|
"earth",
|
||||||
|
"empty",
|
||||||
|
"embarrassed",
|
||||||
|
"enthusiastic",
|
||||||
|
"entire",
|
||||||
|
"estimate",
|
||||||
|
"except",
|
||||||
|
"exalted",
|
||||||
|
"excited",
|
||||||
|
"explain",
|
||||||
|
"far",
|
||||||
|
"field",
|
||||||
|
"floor",
|
||||||
|
"forlorn",
|
||||||
|
"friendly",
|
||||||
|
"front",
|
||||||
|
"frustrated",
|
||||||
|
"gentle",
|
||||||
|
"gift",
|
||||||
|
"give",
|
||||||
|
"ground",
|
||||||
|
"happy",
|
||||||
|
"hello",
|
||||||
|
"her",
|
||||||
|
"here",
|
||||||
|
"hey",
|
||||||
|
"hi",
|
||||||
|
"him",
|
||||||
|
"hopeless",
|
||||||
|
"hysterical",
|
||||||
|
"I",
|
||||||
|
"implore",
|
||||||
|
"indicate",
|
||||||
|
"joyful",
|
||||||
|
"me",
|
||||||
|
"meditate",
|
||||||
|
"modest",
|
||||||
|
"negative",
|
||||||
|
"nervous",
|
||||||
|
"no",
|
||||||
|
"not know",
|
||||||
|
"nothing",
|
||||||
|
"offer",
|
||||||
|
"ok",
|
||||||
|
"once upon a time",
|
||||||
|
"oppose",
|
||||||
|
"or",
|
||||||
|
"pacify",
|
||||||
|
"pick",
|
||||||
|
"placate",
|
||||||
|
"please",
|
||||||
|
"present",
|
||||||
|
"proffer",
|
||||||
|
"quiet",
|
||||||
|
"reason",
|
||||||
|
"refute",
|
||||||
|
"reject",
|
||||||
|
"rousing",
|
||||||
|
"sad",
|
||||||
|
"select",
|
||||||
|
"shamefaced",
|
||||||
|
"show",
|
||||||
|
"show sky",
|
||||||
|
"sky",
|
||||||
|
"soothe",
|
||||||
|
"sun",
|
||||||
|
"supplicate",
|
||||||
|
"tablet",
|
||||||
|
"tall",
|
||||||
|
"them",
|
||||||
|
"there",
|
||||||
|
"think",
|
||||||
|
"timid",
|
||||||
|
"top",
|
||||||
|
"unless",
|
||||||
|
"up",
|
||||||
|
"upstairs",
|
||||||
|
"void",
|
||||||
|
"warm",
|
||||||
|
"winner",
|
||||||
|
"yeah",
|
||||||
|
"yes",
|
||||||
|
"yoo-hoo",
|
||||||
|
"you",
|
||||||
|
"your",
|
||||||
|
"zero",
|
||||||
|
"zestful",
|
||||||
|
]
|
||||||
|
|||||||
@@ -193,6 +193,7 @@ class RICommunicationAgent(BaseAgent):
|
|||||||
gesture_data=gesture_data,
|
gesture_data=gesture_data,
|
||||||
)
|
)
|
||||||
await robot_speech_agent.start()
|
await robot_speech_agent.start()
|
||||||
|
await asyncio.sleep(0.1) # Small delay
|
||||||
await robot_gesture_agent.start()
|
await robot_gesture_agent.start()
|
||||||
case _:
|
case _:
|
||||||
self.logger.warning("Unhandled negotiation id: %s", id)
|
self.logger.warning("Unhandled negotiation id: %s", id)
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any
|
from typing import Any, Literal
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel, model_validator
|
||||||
|
|
||||||
|
|
||||||
class RIEndpoint(str, Enum):
|
class RIEndpoint(str, Enum):
|
||||||
@@ -48,5 +48,17 @@ class GestureCommand(RIMessage):
|
|||||||
:ivar data: The id of the gesture to be executed.
|
:ivar data: The id of the gesture to be executed.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
endpoint: RIEndpoint = RIEndpoint(RIEndpoint.GESTURE_TAG) or RIEndpoint(RIEndpoint.GESTURE_TAG)
|
endpoint: Literal[ # pyright: ignore[reportIncompatibleVariableOverride] - We validate this stricter rule ourselves
|
||||||
|
RIEndpoint.GESTURE_SINGLE, RIEndpoint.GESTURE_TAG
|
||||||
|
]
|
||||||
data: str
|
data: str
|
||||||
|
|
||||||
|
@model_validator(mode="after")
|
||||||
|
def check_endpoint(self):
|
||||||
|
allowed = {
|
||||||
|
RIEndpoint.GESTURE_SINGLE,
|
||||||
|
RIEndpoint.GESTURE_TAG,
|
||||||
|
}
|
||||||
|
if self.endpoint not in allowed:
|
||||||
|
raise ValueError("endpoint must be GESTURE_SINGLE or GESTURE_TAG")
|
||||||
|
return self
|
||||||
|
|||||||
@@ -10,6 +10,10 @@ def speech_agent_path():
|
|||||||
return "control_backend.agents.communication.ri_communication_agent.RobotSpeechAgent"
|
return "control_backend.agents.communication.ri_communication_agent.RobotSpeechAgent"
|
||||||
|
|
||||||
|
|
||||||
|
def gesture_agent_path():
|
||||||
|
return "control_backend.agents.communication.ri_communication_agent.RobotGestureAgent"
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def zmq_context(mocker):
|
def zmq_context(mocker):
|
||||||
mock_context = mocker.patch(
|
mock_context = mocker.patch(
|
||||||
@@ -22,7 +26,7 @@ def zmq_context(mocker):
|
|||||||
def negotiation_message(
|
def negotiation_message(
|
||||||
actuation_port: int = 5556,
|
actuation_port: int = 5556,
|
||||||
bind_main: bool = False,
|
bind_main: bool = False,
|
||||||
bind_actuation: bool = True,
|
bind_actuation: bool = False,
|
||||||
main_port: int = 5555,
|
main_port: int = 5555,
|
||||||
):
|
):
|
||||||
return {
|
return {
|
||||||
@@ -41,9 +45,12 @@ async def test_setup_success_connects_and_starts_robot(zmq_context):
|
|||||||
fake_socket.recv_json = AsyncMock(return_value=negotiation_message())
|
fake_socket.recv_json = AsyncMock(return_value=negotiation_message())
|
||||||
fake_socket.send_multipart = AsyncMock()
|
fake_socket.send_multipart = AsyncMock()
|
||||||
|
|
||||||
with patch(speech_agent_path(), autospec=True) as MockRobot:
|
with (
|
||||||
robot_instance = MockRobot.return_value
|
patch(speech_agent_path(), autospec=True) as MockSpeech,
|
||||||
robot_instance.start = AsyncMock()
|
patch(gesture_agent_path(), autospec=True) as MockGesture,
|
||||||
|
):
|
||||||
|
MockSpeech.return_value.start = AsyncMock()
|
||||||
|
MockGesture.return_value.start = AsyncMock()
|
||||||
agent = RICommunicationAgent("ri_comm", address="tcp://localhost:5555", bind=False)
|
agent = RICommunicationAgent("ri_comm", address="tcp://localhost:5555", bind=False)
|
||||||
|
|
||||||
agent.add_behavior = MagicMock()
|
agent.add_behavior = MagicMock()
|
||||||
@@ -52,9 +59,17 @@ async def test_setup_success_connects_and_starts_robot(zmq_context):
|
|||||||
|
|
||||||
fake_socket.connect.assert_any_call("tcp://localhost:5555")
|
fake_socket.connect.assert_any_call("tcp://localhost:5555")
|
||||||
fake_socket.send_json.assert_any_call({"endpoint": "negotiate/ports", "data": {}})
|
fake_socket.send_json.assert_any_call({"endpoint": "negotiate/ports", "data": {}})
|
||||||
robot_instance.start.assert_awaited_once()
|
MockSpeech.return_value.start.assert_awaited_once()
|
||||||
MockRobot.assert_called_once_with(ANY, address="tcp://*:5556", bind=True)
|
MockGesture.return_value.start.assert_awaited_once()
|
||||||
|
MockSpeech.assert_called_once_with(ANY, address="tcp://localhost:5556", bind=False)
|
||||||
|
MockGesture.assert_called_once_with(
|
||||||
|
ANY,
|
||||||
|
address="tcp://localhost:5556",
|
||||||
|
bind=False,
|
||||||
|
gesture_data=[],
|
||||||
|
)
|
||||||
agent.add_behavior.assert_called_once()
|
agent.add_behavior.assert_called_once()
|
||||||
|
|
||||||
assert agent.connected is True
|
assert agent.connected is True
|
||||||
|
|
||||||
|
|
||||||
@@ -69,10 +84,13 @@ async def test_setup_binds_when_requested(zmq_context):
|
|||||||
|
|
||||||
agent.add_behavior = MagicMock()
|
agent.add_behavior = MagicMock()
|
||||||
|
|
||||||
with patch(speech_agent_path(), autospec=True) as MockRobot:
|
with (
|
||||||
MockRobot.return_value.start = AsyncMock()
|
patch(speech_agent_path(), autospec=True) as MockSpeech,
|
||||||
|
patch(gesture_agent_path(), autospec=True) as MockGesture,
|
||||||
|
):
|
||||||
|
MockSpeech.return_value.start = AsyncMock()
|
||||||
|
MockGesture.return_value.start = AsyncMock()
|
||||||
await agent.setup()
|
await agent.setup()
|
||||||
|
|
||||||
fake_socket.bind.assert_any_call("tcp://localhost:5555")
|
fake_socket.bind.assert_any_call("tcp://localhost:5555")
|
||||||
agent.add_behavior.assert_called_once()
|
agent.add_behavior.assert_called_once()
|
||||||
|
|
||||||
@@ -88,7 +106,6 @@ async def test_negotiate_invalid_endpoint_retries(zmq_context):
|
|||||||
agent._req_socket = fake_socket
|
agent._req_socket = fake_socket
|
||||||
|
|
||||||
success = await agent._negotiate_connection(max_retries=1)
|
success = await agent._negotiate_connection(max_retries=1)
|
||||||
|
|
||||||
assert success is False
|
assert success is False
|
||||||
|
|
||||||
|
|
||||||
@@ -112,8 +129,12 @@ async def test_handle_negotiation_response_updates_req_socket(zmq_context):
|
|||||||
fake_socket = zmq_context.return_value.socket.return_value
|
fake_socket = zmq_context.return_value.socket.return_value
|
||||||
agent = RICommunicationAgent("ri_comm", address="tcp://localhost:5555", bind=False)
|
agent = RICommunicationAgent("ri_comm", address="tcp://localhost:5555", bind=False)
|
||||||
agent._req_socket = fake_socket
|
agent._req_socket = fake_socket
|
||||||
with patch(speech_agent_path(), autospec=True) as MockRobot:
|
with (
|
||||||
MockRobot.return_value.start = AsyncMock()
|
patch(speech_agent_path(), autospec=True) as MockSpeech,
|
||||||
|
patch(gesture_agent_path(), autospec=True) as MockGesture,
|
||||||
|
):
|
||||||
|
MockSpeech.return_value.start = AsyncMock()
|
||||||
|
MockGesture.return_value.start = AsyncMock()
|
||||||
await agent._handle_negotiation_response(
|
await agent._handle_negotiation_response(
|
||||||
negotiation_message(
|
negotiation_message(
|
||||||
main_port=6000,
|
main_port=6000,
|
||||||
@@ -135,7 +156,6 @@ async def test_handle_disconnection_publishes_and_reconnects():
|
|||||||
agent._negotiate_connection = AsyncMock(return_value=True)
|
agent._negotiate_connection = AsyncMock(return_value=True)
|
||||||
|
|
||||||
await agent._handle_disconnection()
|
await agent._handle_disconnection()
|
||||||
|
|
||||||
pub_socket.send_multipart.assert_awaited()
|
pub_socket.send_multipart.assert_awaited()
|
||||||
assert agent.connected is True
|
assert agent.connected is True
|
||||||
|
|
||||||
@@ -192,7 +212,7 @@ async def test_setup_warns_on_failed_negotiate(zmq_context, mocker):
|
|||||||
fake_socket.recv_json = AsyncMock()
|
fake_socket.recv_json = AsyncMock()
|
||||||
agent = RICommunicationAgent("ri_comm")
|
agent = RICommunicationAgent("ri_comm")
|
||||||
|
|
||||||
async def swallow(coro):
|
def swallow(coro):
|
||||||
coro.close()
|
coro.close()
|
||||||
|
|
||||||
agent.add_behavior = swallow
|
agent.add_behavior = swallow
|
||||||
|
|||||||
@@ -21,7 +21,21 @@ def invalid_command_1():
|
|||||||
|
|
||||||
|
|
||||||
def invalid_command_2():
|
def invalid_command_2():
|
||||||
return GestureCommand(endpoint=RIEndpoint.PING, data="Hey!")
|
return RIMessage(endpoint=RIEndpoint.PING, data="Hey!")
|
||||||
|
|
||||||
|
|
||||||
|
def invalid_command_3():
|
||||||
|
return RIMessage(endpoint=RIEndpoint.GESTURE_SINGLE, data={1, 2, 3})
|
||||||
|
|
||||||
|
|
||||||
|
def invalid_command_4():
|
||||||
|
test: RIMessage = GestureCommand(endpoint=RIEndpoint.GESTURE_SINGLE, data="asdsad")
|
||||||
|
|
||||||
|
def change_endpoint(msg: RIMessage):
|
||||||
|
msg.endpoint = RIEndpoint.PING
|
||||||
|
|
||||||
|
change_endpoint(test)
|
||||||
|
return test
|
||||||
|
|
||||||
|
|
||||||
def test_valid_speech_command_1():
|
def test_valid_speech_command_1():
|
||||||
@@ -56,3 +70,19 @@ def test_invalid_gesture_command_1():
|
|||||||
|
|
||||||
with pytest.raises(ValidationError):
|
with pytest.raises(ValidationError):
|
||||||
GestureCommand.model_validate(command)
|
GestureCommand.model_validate(command)
|
||||||
|
|
||||||
|
|
||||||
|
def test_invalid_gesture_command_2():
|
||||||
|
command = invalid_command_3()
|
||||||
|
RIMessage.model_validate(command)
|
||||||
|
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
GestureCommand.model_validate(command)
|
||||||
|
|
||||||
|
|
||||||
|
def test_invalid_gesture_command_3():
|
||||||
|
command = invalid_command_4()
|
||||||
|
RIMessage.model_validate(command)
|
||||||
|
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
GestureCommand.model_validate(command)
|
||||||
|
|||||||
Reference in New Issue
Block a user