feat: add tests and better model validation for gesture commands

ref: N25B-334
This commit is contained in:
Björn Otgaar
2025-12-04 15:13:27 +01:00
parent b93c39420e
commit fe4a060188
5 changed files with 209 additions and 20 deletions

View File

@@ -10,6 +10,10 @@ def speech_agent_path():
return "control_backend.agents.communication.ri_communication_agent.RobotSpeechAgent"
def gesture_agent_path():
return "control_backend.agents.communication.ri_communication_agent.RobotGestureAgent"
@pytest.fixture
def zmq_context(mocker):
mock_context = mocker.patch(
@@ -22,7 +26,7 @@ def zmq_context(mocker):
def negotiation_message(
actuation_port: int = 5556,
bind_main: bool = False,
bind_actuation: bool = True,
bind_actuation: bool = False,
main_port: int = 5555,
):
return {
@@ -41,9 +45,12 @@ async def test_setup_success_connects_and_starts_robot(zmq_context):
fake_socket.recv_json = AsyncMock(return_value=negotiation_message())
fake_socket.send_multipart = AsyncMock()
with patch(speech_agent_path(), autospec=True) as MockRobot:
robot_instance = MockRobot.return_value
robot_instance.start = AsyncMock()
with (
patch(speech_agent_path(), autospec=True) as MockSpeech,
patch(gesture_agent_path(), autospec=True) as MockGesture,
):
MockSpeech.return_value.start = AsyncMock()
MockGesture.return_value.start = AsyncMock()
agent = RICommunicationAgent("ri_comm", address="tcp://localhost:5555", bind=False)
agent.add_behavior = MagicMock()
@@ -52,9 +59,17 @@ async def test_setup_success_connects_and_starts_robot(zmq_context):
fake_socket.connect.assert_any_call("tcp://localhost:5555")
fake_socket.send_json.assert_any_call({"endpoint": "negotiate/ports", "data": {}})
robot_instance.start.assert_awaited_once()
MockRobot.assert_called_once_with(ANY, address="tcp://*:5556", bind=True)
MockSpeech.return_value.start.assert_awaited_once()
MockGesture.return_value.start.assert_awaited_once()
MockSpeech.assert_called_once_with(ANY, address="tcp://localhost:5556", bind=False)
MockGesture.assert_called_once_with(
ANY,
address="tcp://localhost:5556",
bind=False,
gesture_data=[],
)
agent.add_behavior.assert_called_once()
assert agent.connected is True
@@ -69,10 +84,13 @@ async def test_setup_binds_when_requested(zmq_context):
agent.add_behavior = MagicMock()
with patch(speech_agent_path(), autospec=True) as MockRobot:
MockRobot.return_value.start = AsyncMock()
with (
patch(speech_agent_path(), autospec=True) as MockSpeech,
patch(gesture_agent_path(), autospec=True) as MockGesture,
):
MockSpeech.return_value.start = AsyncMock()
MockGesture.return_value.start = AsyncMock()
await agent.setup()
fake_socket.bind.assert_any_call("tcp://localhost:5555")
agent.add_behavior.assert_called_once()
@@ -88,7 +106,6 @@ async def test_negotiate_invalid_endpoint_retries(zmq_context):
agent._req_socket = fake_socket
success = await agent._negotiate_connection(max_retries=1)
assert success is False
@@ -112,8 +129,12 @@ async def test_handle_negotiation_response_updates_req_socket(zmq_context):
fake_socket = zmq_context.return_value.socket.return_value
agent = RICommunicationAgent("ri_comm", address="tcp://localhost:5555", bind=False)
agent._req_socket = fake_socket
with patch(speech_agent_path(), autospec=True) as MockRobot:
MockRobot.return_value.start = AsyncMock()
with (
patch(speech_agent_path(), autospec=True) as MockSpeech,
patch(gesture_agent_path(), autospec=True) as MockGesture,
):
MockSpeech.return_value.start = AsyncMock()
MockGesture.return_value.start = AsyncMock()
await agent._handle_negotiation_response(
negotiation_message(
main_port=6000,
@@ -135,7 +156,6 @@ async def test_handle_disconnection_publishes_and_reconnects():
agent._negotiate_connection = AsyncMock(return_value=True)
await agent._handle_disconnection()
pub_socket.send_multipart.assert_awaited()
assert agent.connected is True
@@ -192,7 +212,7 @@ async def test_setup_warns_on_failed_negotiate(zmq_context, mocker):
fake_socket.recv_json = AsyncMock()
agent = RICommunicationAgent("ri_comm")
async def swallow(coro):
def swallow(coro):
coro.close()
agent.add_behavior = swallow

View File

@@ -21,7 +21,21 @@ def invalid_command_1():
def invalid_command_2():
return GestureCommand(endpoint=RIEndpoint.PING, data="Hey!")
return RIMessage(endpoint=RIEndpoint.PING, data="Hey!")
def invalid_command_3():
return RIMessage(endpoint=RIEndpoint.GESTURE_SINGLE, data={1, 2, 3})
def invalid_command_4():
test: RIMessage = GestureCommand(endpoint=RIEndpoint.GESTURE_SINGLE, data="asdsad")
def change_endpoint(msg: RIMessage):
msg.endpoint = RIEndpoint.PING
change_endpoint(test)
return test
def test_valid_speech_command_1():
@@ -56,3 +70,19 @@ def test_invalid_gesture_command_1():
with pytest.raises(ValidationError):
GestureCommand.model_validate(command)
def test_invalid_gesture_command_2():
command = invalid_command_3()
RIMessage.model_validate(command)
with pytest.raises(ValidationError):
GestureCommand.model_validate(command)
def test_invalid_gesture_command_3():
command = invalid_command_4()
RIMessage.model_validate(command)
with pytest.raises(ValidationError):
GestureCommand.model_validate(command)