refactor: testing
Redid testing structure, added tests and changed some tests. ref: N25B-301
This commit is contained in:
@@ -1,158 +0,0 @@
|
||||
import json
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
import zmq
|
||||
|
||||
from control_backend.agents.actuation.robot_speech_agent import RobotSpeechAgent
|
||||
from control_backend.core.agent_system import InternalMessage
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def zmq_context(mocker):
|
||||
mock_context = mocker.patch(
|
||||
"control_backend.agents.actuation.robot_speech_agent.azmq.Context.instance"
|
||||
)
|
||||
mock_context.return_value = MagicMock()
|
||||
return mock_context
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_setup_bind(zmq_context, mocker):
|
||||
"""Setup binds and subscribes to internal commands."""
|
||||
fake_socket = zmq_context.return_value.socket.return_value
|
||||
agent = RobotSpeechAgent("robot_speech", address="tcp://localhost:5555", bind=True)
|
||||
settings = mocker.patch("control_backend.agents.actuation.robot_speech_agent.settings")
|
||||
settings.zmq_settings.internal_sub_address = "tcp://internal:1234"
|
||||
|
||||
# Swallow background task coroutines to avoid un-awaited warnings
|
||||
class Swallow:
|
||||
def __init__(self):
|
||||
self.calls = 0
|
||||
|
||||
async def __call__(self, coro):
|
||||
self.calls += 1
|
||||
coro.close()
|
||||
|
||||
swallow = Swallow()
|
||||
agent.add_background_task = swallow
|
||||
|
||||
await agent.setup()
|
||||
|
||||
fake_socket.bind.assert_any_call("tcp://localhost:5555")
|
||||
fake_socket.connect.assert_any_call("tcp://internal:1234")
|
||||
fake_socket.setsockopt.assert_any_call(zmq.SUBSCRIBE, b"command")
|
||||
assert swallow.calls == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_setup_connect(zmq_context, mocker):
|
||||
"""Setup connects when bind=False."""
|
||||
fake_socket = zmq_context.return_value.socket.return_value
|
||||
agent = RobotSpeechAgent("robot_speech", address="tcp://localhost:5555", bind=False)
|
||||
settings = mocker.patch("control_backend.agents.actuation.robot_speech_agent.settings")
|
||||
settings.zmq_settings.internal_sub_address = "tcp://internal:1234"
|
||||
|
||||
class Swallow:
|
||||
def __init__(self):
|
||||
self.calls = 0
|
||||
|
||||
async def __call__(self, coro):
|
||||
self.calls += 1
|
||||
coro.close()
|
||||
|
||||
swallow = Swallow()
|
||||
agent.add_background_task = swallow
|
||||
|
||||
await agent.setup()
|
||||
|
||||
fake_socket.connect.assert_any_call("tcp://localhost:5555")
|
||||
fake_socket.connect.assert_any_call("tcp://internal:1234")
|
||||
assert swallow.calls == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_message_sends_command():
|
||||
"""Internal message is forwarded to robot pub socket as JSON."""
|
||||
pubsocket = AsyncMock()
|
||||
agent = RobotSpeechAgent("robot_speech")
|
||||
agent.pubsocket = pubsocket
|
||||
|
||||
payload = {"endpoint": "actuate/speech", "data": "hello"}
|
||||
msg = InternalMessage(to="robot", sender="tester", body=json.dumps(payload))
|
||||
|
||||
await agent.handle_message(msg)
|
||||
|
||||
pubsocket.send_json.assert_awaited_once_with(payload)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_zmq_command_loop_valid_payload(zmq_context):
|
||||
"""UI command is read from SUB and published."""
|
||||
command = {"endpoint": "actuate/speech", "data": "hello"}
|
||||
fake_socket = AsyncMock()
|
||||
|
||||
async def recv_once():
|
||||
# stop after first iteration
|
||||
agent._running = False
|
||||
return (b"command", json.dumps(command).encode("utf-8"))
|
||||
|
||||
fake_socket.recv_multipart = recv_once
|
||||
fake_socket.send_json = AsyncMock()
|
||||
agent = RobotSpeechAgent("robot_speech")
|
||||
agent.subsocket = fake_socket
|
||||
agent.pubsocket = fake_socket
|
||||
agent._running = True
|
||||
|
||||
await agent._zmq_command_loop()
|
||||
|
||||
fake_socket.send_json.assert_awaited_once_with(command)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_zmq_command_loop_invalid_json():
|
||||
"""Invalid JSON is ignored without sending."""
|
||||
fake_socket = AsyncMock()
|
||||
|
||||
async def recv_once():
|
||||
agent._running = False
|
||||
return (b"command", b"{not_json}")
|
||||
|
||||
fake_socket.recv_multipart = recv_once
|
||||
fake_socket.send_json = AsyncMock()
|
||||
agent = RobotSpeechAgent("robot_speech")
|
||||
agent.subsocket = fake_socket
|
||||
agent.pubsocket = fake_socket
|
||||
agent._running = True
|
||||
|
||||
await agent._zmq_command_loop()
|
||||
|
||||
fake_socket.send_json.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_message_invalid_payload():
|
||||
"""Invalid payload is caught and does not send."""
|
||||
pubsocket = AsyncMock()
|
||||
agent = RobotSpeechAgent("robot_speech")
|
||||
agent.pubsocket = pubsocket
|
||||
|
||||
msg = InternalMessage(to="robot", sender="tester", body=json.dumps({"bad": "data"}))
|
||||
|
||||
await agent.handle_message(msg)
|
||||
|
||||
pubsocket.send_json.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stop_closes_sockets():
|
||||
pubsocket = MagicMock()
|
||||
subsocket = MagicMock()
|
||||
agent = RobotSpeechAgent("robot_speech")
|
||||
agent.pubsocket = pubsocket
|
||||
agent.subsocket = subsocket
|
||||
|
||||
await agent.stop()
|
||||
|
||||
pubsocket.close.assert_called_once()
|
||||
subsocket.close.assert_called_once()
|
||||
@@ -1,354 +0,0 @@
|
||||
import asyncio
|
||||
from unittest.mock import ANY, AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from control_backend.agents.communication.ri_communication_agent import RICommunicationAgent
|
||||
|
||||
|
||||
def speech_agent_path():
|
||||
return "control_backend.agents.communication.ri_communication_agent.RobotSpeechAgent"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def zmq_context(mocker):
|
||||
mock_context = mocker.patch(
|
||||
"control_backend.agents.communication.ri_communication_agent.Context.instance"
|
||||
)
|
||||
mock_context.return_value = MagicMock()
|
||||
return mock_context
|
||||
|
||||
|
||||
def negotiation_message(
|
||||
actuation_port: int = 5556,
|
||||
bind_main: bool = False,
|
||||
bind_actuation: bool = True,
|
||||
main_port: int = 5555,
|
||||
):
|
||||
return {
|
||||
"endpoint": "negotiate/ports",
|
||||
"data": [
|
||||
{"id": "main", "port": main_port, "bind": bind_main},
|
||||
{"id": "actuation", "port": actuation_port, "bind": bind_actuation},
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_setup_success_connects_and_starts_robot(zmq_context):
|
||||
fake_socket = zmq_context.return_value.socket.return_value
|
||||
fake_socket.send_json = AsyncMock()
|
||||
fake_socket.recv_json = AsyncMock(return_value=negotiation_message())
|
||||
fake_socket.send_multipart = AsyncMock()
|
||||
|
||||
with patch(speech_agent_path(), autospec=True) as MockRobot:
|
||||
robot_instance = MockRobot.return_value
|
||||
robot_instance.start = AsyncMock()
|
||||
agent = RICommunicationAgent("ri_comm", address="tcp://localhost:5555", bind=False)
|
||||
|
||||
class Swallow:
|
||||
def __init__(self):
|
||||
self.calls = 0
|
||||
|
||||
async def __call__(self, coro):
|
||||
self.calls += 1
|
||||
coro.close()
|
||||
|
||||
swallow = Swallow()
|
||||
agent.add_background_task = swallow
|
||||
|
||||
await agent.setup()
|
||||
|
||||
fake_socket.connect.assert_any_call("tcp://localhost:5555")
|
||||
fake_socket.send_json.assert_any_call({"endpoint": "negotiate/ports", "data": {}})
|
||||
robot_instance.start.assert_awaited_once()
|
||||
MockRobot.assert_called_once_with(ANY, address="tcp://*:5556", bind=True)
|
||||
assert swallow.calls == 1
|
||||
assert agent.connected is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_setup_binds_when_requested(zmq_context):
|
||||
fake_socket = zmq_context.return_value.socket.return_value
|
||||
fake_socket.send_json = AsyncMock()
|
||||
fake_socket.recv_json = AsyncMock(return_value=negotiation_message(bind_main=True))
|
||||
fake_socket.send_multipart = AsyncMock()
|
||||
|
||||
agent = RICommunicationAgent("ri_comm", address="tcp://localhost:5555", bind=True)
|
||||
|
||||
class Swallow:
|
||||
def __init__(self):
|
||||
self.calls = 0
|
||||
|
||||
async def __call__(self, coro):
|
||||
self.calls += 1
|
||||
coro.close()
|
||||
|
||||
swallow = Swallow()
|
||||
agent.add_background_task = swallow
|
||||
|
||||
with patch(speech_agent_path(), autospec=True) as MockRobot:
|
||||
MockRobot.return_value.start = AsyncMock()
|
||||
await agent.setup()
|
||||
|
||||
fake_socket.bind.assert_any_call("tcp://localhost:5555")
|
||||
assert swallow.calls == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_negotiate_invalid_endpoint_retries(zmq_context):
|
||||
fake_socket = zmq_context.return_value.socket.return_value
|
||||
fake_socket.send_json = AsyncMock()
|
||||
fake_socket.recv_json = AsyncMock(return_value={"endpoint": "ping", "data": {}})
|
||||
fake_socket.send_multipart = AsyncMock()
|
||||
|
||||
agent = RICommunicationAgent("ri_comm", address="tcp://localhost:5555", bind=False)
|
||||
agent._req_socket = fake_socket
|
||||
|
||||
success = await agent._negotiate_connection(max_retries=1)
|
||||
|
||||
assert success is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_negotiate_timeout(zmq_context):
|
||||
fake_socket = zmq_context.return_value.socket.return_value
|
||||
fake_socket.send_json = AsyncMock()
|
||||
fake_socket.recv_json = AsyncMock(side_effect=asyncio.TimeoutError)
|
||||
fake_socket.send_multipart = AsyncMock()
|
||||
|
||||
agent = RICommunicationAgent("ri_comm", address="tcp://localhost:5555", bind=False)
|
||||
agent._req_socket = fake_socket
|
||||
|
||||
success = await agent._negotiate_connection(max_retries=1)
|
||||
|
||||
assert success is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_negotiation_response_updates_req_socket(zmq_context):
|
||||
fake_socket = zmq_context.return_value.socket.return_value
|
||||
agent = RICommunicationAgent("ri_comm", address="tcp://localhost:5555", bind=False)
|
||||
agent._req_socket = fake_socket
|
||||
with patch(speech_agent_path(), autospec=True) as MockRobot:
|
||||
MockRobot.return_value.start = AsyncMock()
|
||||
await agent._handle_negotiation_response(
|
||||
negotiation_message(
|
||||
main_port=6000,
|
||||
actuation_port=6001,
|
||||
bind_main=False,
|
||||
bind_actuation=False,
|
||||
)
|
||||
)
|
||||
|
||||
fake_socket.connect.assert_any_call("tcp://localhost:6000")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_disconnection_publishes_and_reconnects():
|
||||
pub_socket = AsyncMock()
|
||||
agent = RICommunicationAgent("ri_comm")
|
||||
agent.pub_socket = pub_socket
|
||||
agent.connected = True
|
||||
agent._negotiate_connection = AsyncMock(return_value=True)
|
||||
|
||||
await agent._handle_disconnection()
|
||||
|
||||
pub_socket.send_multipart.assert_awaited()
|
||||
assert agent.connected is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_listen_loop_handles_non_ping(zmq_context):
|
||||
fake_socket = zmq_context.return_value.socket.return_value
|
||||
fake_socket.send_json = AsyncMock()
|
||||
|
||||
async def recv_once():
|
||||
agent._running = False
|
||||
return {"endpoint": "negotiate/ports", "data": {}}
|
||||
|
||||
fake_socket.recv_json = recv_once
|
||||
agent = RICommunicationAgent("ri_comm")
|
||||
agent._req_socket = fake_socket
|
||||
agent.pub_socket = AsyncMock()
|
||||
agent.connected = True
|
||||
agent._running = True
|
||||
|
||||
await agent._listen_loop()
|
||||
|
||||
fake_socket.send_json.assert_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_negotiate_unexpected_error(zmq_context):
|
||||
fake_socket = zmq_context.return_value.socket.return_value
|
||||
fake_socket.send_json = AsyncMock()
|
||||
fake_socket.recv_json = AsyncMock(side_effect=Exception("boom"))
|
||||
agent = RICommunicationAgent("ri_comm")
|
||||
agent._req_socket = fake_socket
|
||||
|
||||
assert await agent._negotiate_connection(max_retries=1) is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_negotiate_handle_response_error(zmq_context):
|
||||
fake_socket = zmq_context.return_value.socket.return_value
|
||||
fake_socket.send_json = AsyncMock()
|
||||
fake_socket.recv_json = AsyncMock(return_value=negotiation_message())
|
||||
|
||||
agent = RICommunicationAgent("ri_comm")
|
||||
agent._req_socket = fake_socket
|
||||
agent._handle_negotiation_response = AsyncMock(side_effect=Exception("bad response"))
|
||||
|
||||
assert await agent._negotiate_connection(max_retries=1) is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_setup_warns_on_failed_negotiate(zmq_context, mocker):
|
||||
fake_socket = zmq_context.return_value.socket.return_value
|
||||
fake_socket.send_json = AsyncMock()
|
||||
fake_socket.recv_json = AsyncMock()
|
||||
agent = RICommunicationAgent("ri_comm")
|
||||
|
||||
async def swallow(coro):
|
||||
coro.close()
|
||||
|
||||
agent.add_background_task = swallow
|
||||
agent._negotiate_connection = AsyncMock(return_value=False)
|
||||
|
||||
await agent.setup()
|
||||
|
||||
assert agent.connected is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_negotiation_response_unhandled_id():
|
||||
agent = RICommunicationAgent("ri_comm")
|
||||
|
||||
await agent._handle_negotiation_response(
|
||||
{"data": [{"id": "other", "port": 5000, "bind": False}]}
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stop_closes_sockets():
|
||||
req = MagicMock()
|
||||
pub = MagicMock()
|
||||
agent = RICommunicationAgent("ri_comm")
|
||||
agent._req_socket = req
|
||||
agent.pub_socket = pub
|
||||
|
||||
await agent.stop()
|
||||
|
||||
req.close.assert_called_once()
|
||||
pub.close.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_listen_loop_not_connected(monkeypatch):
|
||||
agent = RICommunicationAgent("ri_comm")
|
||||
agent._running = True
|
||||
agent.connected = False
|
||||
agent._req_socket = AsyncMock()
|
||||
|
||||
async def fake_sleep(duration):
|
||||
agent._running = False
|
||||
|
||||
monkeypatch.setattr("asyncio.sleep", fake_sleep)
|
||||
|
||||
await agent._listen_loop()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_listen_loop_send_and_recv_timeout():
|
||||
req = AsyncMock()
|
||||
req.send_json = AsyncMock(side_effect=TimeoutError)
|
||||
req.recv_json = AsyncMock(side_effect=TimeoutError)
|
||||
|
||||
agent = RICommunicationAgent("ri_comm")
|
||||
agent._req_socket = req
|
||||
agent.pub_socket = AsyncMock()
|
||||
agent.connected = True
|
||||
agent._running = True
|
||||
|
||||
async def stop_run():
|
||||
agent._running = False
|
||||
|
||||
agent._handle_disconnection = AsyncMock(side_effect=stop_run)
|
||||
|
||||
await agent._listen_loop()
|
||||
|
||||
agent._handle_disconnection.assert_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_listen_loop_missing_endpoint(monkeypatch):
|
||||
req = AsyncMock()
|
||||
req.send_json = AsyncMock()
|
||||
|
||||
async def recv_once():
|
||||
agent._running = False
|
||||
return {"data": {}}
|
||||
|
||||
req.recv_json = recv_once
|
||||
|
||||
agent = RICommunicationAgent("ri_comm")
|
||||
agent._req_socket = req
|
||||
agent.pub_socket = AsyncMock()
|
||||
agent.connected = True
|
||||
agent._running = True
|
||||
|
||||
await agent._listen_loop()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_listen_loop_generic_exception():
|
||||
req = AsyncMock()
|
||||
req.send_json = AsyncMock()
|
||||
req.recv_json = AsyncMock(side_effect=ValueError("boom"))
|
||||
|
||||
agent = RICommunicationAgent("ri_comm")
|
||||
agent._req_socket = req
|
||||
agent.pub_socket = AsyncMock()
|
||||
agent.connected = True
|
||||
agent._running = True
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
await agent._listen_loop()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_disconnection_timeout(monkeypatch):
|
||||
pub = AsyncMock()
|
||||
pub.send_multipart = AsyncMock(side_effect=TimeoutError)
|
||||
|
||||
agent = RICommunicationAgent("ri_comm")
|
||||
agent.pub_socket = pub
|
||||
agent._negotiate_connection = AsyncMock(return_value=False)
|
||||
|
||||
await agent._handle_disconnection()
|
||||
|
||||
pub.send_multipart.assert_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_listen_loop_ping_sends_internal(zmq_context):
|
||||
fake_socket = zmq_context.return_value.socket.return_value
|
||||
fake_socket.send_json = AsyncMock()
|
||||
pub_socket = AsyncMock()
|
||||
|
||||
agent = RICommunicationAgent("ri_comm")
|
||||
agent._req_socket = fake_socket
|
||||
agent.pub_socket = pub_socket
|
||||
agent.connected = True
|
||||
agent._running = True
|
||||
|
||||
async def recv_once():
|
||||
agent._running = False
|
||||
return {"endpoint": "ping", "data": {}}
|
||||
|
||||
fake_socket.recv_json = recv_once
|
||||
|
||||
await agent._listen_loop()
|
||||
|
||||
pub_socket.send_multipart.assert_awaited()
|
||||
@@ -1,125 +0,0 @@
|
||||
import json
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from control_backend.api.v1.endpoints import program
|
||||
from control_backend.schemas.program import Program
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app():
|
||||
"""Create a FastAPI app with the /program route and mock socket."""
|
||||
app = FastAPI()
|
||||
app.include_router(program.router)
|
||||
return app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client(app):
|
||||
"""Create a TestClient."""
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
def make_valid_program_dict():
|
||||
"""Helper to create a valid Program JSON structure."""
|
||||
return {
|
||||
"phases": [
|
||||
{
|
||||
"id": "phase1",
|
||||
"name": "basephase",
|
||||
"nextPhaseId": "phase2",
|
||||
"phaseData": {
|
||||
"norms": [{"id": "n1", "name": "norm", "value": "be nice"}],
|
||||
"goals": [
|
||||
{"id": "g1", "name": "goal", "description": "test goal", "achieved": False}
|
||||
],
|
||||
"triggers": [
|
||||
{
|
||||
"id": "t1",
|
||||
"label": "trigger",
|
||||
"type": "keyword",
|
||||
"value": ["stop", "exit"],
|
||||
}
|
||||
],
|
||||
},
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
def test_receive_program_success(client):
|
||||
"""Valid Program JSON should be parsed and sent through the socket."""
|
||||
mock_pub_socket = AsyncMock()
|
||||
client.app.state.endpoints_pub_socket = mock_pub_socket
|
||||
|
||||
program_dict = make_valid_program_dict()
|
||||
|
||||
response = client.post("/program", json=program_dict)
|
||||
|
||||
assert response.status_code == 202
|
||||
assert response.json() == {"status": "Program parsed"}
|
||||
|
||||
# Verify socket call
|
||||
mock_pub_socket.send_multipart.assert_awaited_once()
|
||||
args, kwargs = mock_pub_socket.send_multipart.await_args
|
||||
|
||||
assert args[0][0] == b"program"
|
||||
|
||||
sent_bytes = args[0][1]
|
||||
sent_obj = json.loads(sent_bytes.decode())
|
||||
|
||||
expected_obj = Program.model_validate(program_dict).model_dump()
|
||||
assert sent_obj == expected_obj
|
||||
|
||||
|
||||
def test_receive_program_invalid_json(client):
|
||||
"""
|
||||
Invalid JSON (malformed) -> FastAPI never calls endpoint.
|
||||
It returns a 422 Unprocessable Entity.
|
||||
"""
|
||||
mock_pub_socket = AsyncMock()
|
||||
client.app.state.endpoints_pub_socket = mock_pub_socket
|
||||
|
||||
# FastAPI only accepts valid JSON bodies, so send raw string
|
||||
response = client.post("/program", content="{invalid json}")
|
||||
|
||||
assert response.status_code == 422
|
||||
mock_pub_socket.send_multipart.assert_not_called()
|
||||
|
||||
|
||||
def test_receive_program_invalid_deep_structure(client):
|
||||
"""
|
||||
Valid JSON but schema invalid -> Pydantic throws validation error -> 422.
|
||||
"""
|
||||
mock_pub_socket = AsyncMock()
|
||||
client.app.state.endpoints_pub_socket = mock_pub_socket
|
||||
|
||||
# Missing "value" in norms element
|
||||
bad_program = {
|
||||
"phases": [
|
||||
{
|
||||
"id": "phase1",
|
||||
"name": "deepfail",
|
||||
"nextPhaseId": "phase2",
|
||||
"phaseData": {
|
||||
"norms": [
|
||||
{"id": "n1", "name": "norm"} # INVALID: missing "value"
|
||||
],
|
||||
"goals": [
|
||||
{"id": "g1", "name": "goal", "description": "desc", "achieved": False}
|
||||
],
|
||||
"triggers": [
|
||||
{"id": "t1", "label": "trigger", "type": "keyword", "value": ["start"]}
|
||||
],
|
||||
},
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
response = client.post("/program", json=bad_program)
|
||||
|
||||
assert response.status_code == 422
|
||||
mock_pub_socket.send_multipart.assert_not_called()
|
||||
@@ -1,156 +0,0 @@
|
||||
import json
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from control_backend.api.v1.endpoints import robot
|
||||
from control_backend.schemas.ri_message import SpeechCommand
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app():
|
||||
"""
|
||||
Creates a FastAPI test app and attaches the router under test.
|
||||
Also sets up a mock internal_comm_socket.
|
||||
"""
|
||||
app = FastAPI()
|
||||
app.include_router(robot.router)
|
||||
return app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client(app):
|
||||
"""Create a test client for the app."""
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
def test_receive_command_success(client):
|
||||
"""
|
||||
Test for successful reception of a command. Ensures the status code is 202 and the response body
|
||||
is correct. It also verifies that the ZeroMQ socket's send_multipart method is called with the
|
||||
expected data.
|
||||
"""
|
||||
# Arrange
|
||||
mock_pub_socket = AsyncMock()
|
||||
client.app.state.endpoints_pub_socket = mock_pub_socket
|
||||
|
||||
command_data = {"endpoint": "actuate/speech", "data": "This is a test"}
|
||||
speech_command = SpeechCommand(**command_data)
|
||||
|
||||
# Act
|
||||
response = client.post("/command", json=command_data)
|
||||
|
||||
# Assert
|
||||
assert response.status_code == 202
|
||||
assert response.json() == {"status": "Command received"}
|
||||
|
||||
# Verify that the ZMQ socket was used correctly
|
||||
mock_pub_socket.send_multipart.assert_awaited_once_with(
|
||||
[b"command", speech_command.model_dump_json().encode()]
|
||||
)
|
||||
|
||||
|
||||
def test_receive_command_invalid_payload(client):
|
||||
"""
|
||||
Test invalid data handling (schema validation).
|
||||
"""
|
||||
# Missing required field(s)
|
||||
bad_payload = {"invalid": "data"}
|
||||
response = client.post("/command", json=bad_payload)
|
||||
assert response.status_code == 422 # validation error
|
||||
|
||||
|
||||
def test_ping_check_returns_none(client):
|
||||
"""Ensure /ping_check returns 200 and None (currently unimplemented)."""
|
||||
response = client.get("/ping_check")
|
||||
assert response.status_code == 200
|
||||
assert response.json() is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ping_stream_yields_ping_event(monkeypatch):
|
||||
"""Test that ping_stream yields a proper SSE message when a ping is received."""
|
||||
mock_sub_socket = AsyncMock()
|
||||
mock_sub_socket.connect = MagicMock()
|
||||
mock_sub_socket.setsockopt = MagicMock()
|
||||
mock_sub_socket.recv_multipart = AsyncMock(return_value=[b"ping", b"true"])
|
||||
|
||||
mock_context = MagicMock()
|
||||
mock_context.socket.return_value = mock_sub_socket
|
||||
monkeypatch.setattr(robot.Context, "instance", lambda: mock_context)
|
||||
|
||||
mock_request = AsyncMock()
|
||||
mock_request.is_disconnected = AsyncMock(side_effect=[False, True])
|
||||
|
||||
response = await robot.ping_stream(mock_request)
|
||||
generator = aiter(response.body_iterator)
|
||||
|
||||
event = await anext(generator)
|
||||
event_text = event.decode() if isinstance(event, bytes) else str(event)
|
||||
assert event_text.strip() == "data: true"
|
||||
|
||||
with pytest.raises(StopAsyncIteration):
|
||||
await anext(generator)
|
||||
|
||||
mock_sub_socket.connect.assert_called_once()
|
||||
mock_sub_socket.setsockopt.assert_called_once_with(robot.zmq.SUBSCRIBE, b"ping")
|
||||
mock_sub_socket.recv_multipart.assert_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ping_stream_handles_timeout(monkeypatch):
|
||||
"""Test that ping_stream continues looping on TimeoutError."""
|
||||
mock_sub_socket = AsyncMock()
|
||||
mock_sub_socket.connect = MagicMock()
|
||||
mock_sub_socket.setsockopt = MagicMock()
|
||||
mock_sub_socket.recv_multipart.side_effect = TimeoutError()
|
||||
|
||||
mock_context = MagicMock()
|
||||
mock_context.socket.return_value = mock_sub_socket
|
||||
monkeypatch.setattr(robot.Context, "instance", lambda: mock_context)
|
||||
|
||||
mock_request = AsyncMock()
|
||||
mock_request.is_disconnected = AsyncMock(return_value=True)
|
||||
|
||||
response = await robot.ping_stream(mock_request)
|
||||
generator = aiter(response.body_iterator)
|
||||
|
||||
with pytest.raises(StopAsyncIteration):
|
||||
await anext(generator)
|
||||
|
||||
mock_sub_socket.connect.assert_called_once()
|
||||
mock_sub_socket.setsockopt.assert_called_once_with(robot.zmq.SUBSCRIBE, b"ping")
|
||||
mock_sub_socket.recv_multipart.assert_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ping_stream_yields_json_values(monkeypatch):
|
||||
"""Ensure ping_stream correctly parses and yields JSON body values."""
|
||||
mock_sub_socket = AsyncMock()
|
||||
mock_sub_socket.connect = MagicMock()
|
||||
mock_sub_socket.setsockopt = MagicMock()
|
||||
mock_sub_socket.recv_multipart = AsyncMock(
|
||||
return_value=[b"ping", json.dumps({"connected": True}).encode()]
|
||||
)
|
||||
|
||||
mock_context = MagicMock()
|
||||
mock_context.socket.return_value = mock_sub_socket
|
||||
monkeypatch.setattr(robot.Context, "instance", lambda: mock_context)
|
||||
|
||||
mock_request = AsyncMock()
|
||||
mock_request.is_disconnected = AsyncMock(side_effect=[False, True])
|
||||
|
||||
response = await robot.ping_stream(mock_request)
|
||||
generator = aiter(response.body_iterator)
|
||||
|
||||
event = await anext(generator)
|
||||
event_text = event.decode() if isinstance(event, bytes) else str(event)
|
||||
|
||||
assert "connected" in event_text
|
||||
assert "true" in event_text
|
||||
|
||||
mock_sub_socket.connect.assert_called_once()
|
||||
mock_sub_socket.setsockopt.assert_called_once_with(robot.zmq.SUBSCRIBE, b"ping")
|
||||
mock_sub_socket.recv_multipart.assert_awaited()
|
||||
@@ -1,26 +0,0 @@
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from control_backend.schemas.ri_message import RIEndpoint, RIMessage, SpeechCommand
|
||||
|
||||
|
||||
def valid_command_1():
|
||||
return SpeechCommand(data="Hallo?")
|
||||
|
||||
|
||||
def invalid_command_1():
|
||||
return RIMessage(endpoint=RIEndpoint.PING, data="Hello again.")
|
||||
|
||||
|
||||
def test_valid_speech_command_1():
|
||||
command = valid_command_1()
|
||||
RIMessage.model_validate(command)
|
||||
SpeechCommand.model_validate(command)
|
||||
|
||||
|
||||
def test_invalid_speech_command_1():
|
||||
command = invalid_command_1()
|
||||
RIMessage.model_validate(command)
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
SpeechCommand.model_validate(command)
|
||||
@@ -1,85 +0,0 @@
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from control_backend.schemas.program import Goal, Norm, Phase, PhaseData, Program, Trigger
|
||||
|
||||
|
||||
def base_norm() -> Norm:
|
||||
return Norm(
|
||||
id="norm1",
|
||||
name="testNorm",
|
||||
value="you should act nice",
|
||||
)
|
||||
|
||||
|
||||
def base_goal() -> Goal:
|
||||
return Goal(
|
||||
id="goal1",
|
||||
name="testGoal",
|
||||
description="you should act nice",
|
||||
achieved=False,
|
||||
)
|
||||
|
||||
|
||||
def base_trigger() -> Trigger:
|
||||
return Trigger(
|
||||
id="trigger1",
|
||||
label="testTrigger",
|
||||
type="keyword",
|
||||
value=["Stop", "Exit"],
|
||||
)
|
||||
|
||||
|
||||
def base_phase_data() -> PhaseData:
|
||||
return PhaseData(
|
||||
norms=[base_norm()],
|
||||
goals=[base_goal()],
|
||||
triggers=[base_trigger()],
|
||||
)
|
||||
|
||||
|
||||
def base_phase() -> Phase:
|
||||
return Phase(
|
||||
id="phase1",
|
||||
name="basephase",
|
||||
nextPhaseId="phase2",
|
||||
phaseData=base_phase_data(),
|
||||
)
|
||||
|
||||
|
||||
def base_program() -> Program:
|
||||
return Program(phases=[base_phase()])
|
||||
|
||||
|
||||
def invalid_program() -> dict:
|
||||
# wrong types inside phases list (not Phase objects)
|
||||
return {
|
||||
"phases": [
|
||||
{"id": "phase1"}, # incomplete
|
||||
{"not_a_phase": True},
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
def test_valid_program():
|
||||
program = base_program()
|
||||
validated = Program.model_validate(program)
|
||||
assert isinstance(validated, Program)
|
||||
assert validated.phases[0].phaseData.norms[0].name == "testNorm"
|
||||
|
||||
|
||||
def test_valid_deepprogram():
|
||||
program = base_program()
|
||||
validated = Program.model_validate(program)
|
||||
# validate nested components directly
|
||||
phase = validated.phases[0]
|
||||
assert isinstance(phase.phaseData, PhaseData)
|
||||
assert isinstance(phase.phaseData.goals[0], Goal)
|
||||
assert isinstance(phase.phaseData.triggers[0], Trigger)
|
||||
assert isinstance(phase.phaseData.norms[0], Norm)
|
||||
|
||||
|
||||
def test_invalid_program():
|
||||
bad = invalid_program()
|
||||
with pytest.raises(ValidationError):
|
||||
Program.model_validate(bad)
|
||||
Reference in New Issue
Block a user