From e2a71ad6c220beb56184c837cdf4b58bf8eada86 Mon Sep 17 00:00:00 2001 From: JobvAlewijk Date: Mon, 24 Nov 2025 20:37:59 +0000 Subject: [PATCH] test: added main tests --- test/unit/test_main.py | 222 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 222 insertions(+) create mode 100644 test/unit/test_main.py diff --git a/test/unit/test_main.py b/test/unit/test_main.py new file mode 100644 index 0000000..f323630 --- /dev/null +++ b/test/unit/test_main.py @@ -0,0 +1,222 @@ +import pytest +import threading +import zmq + +import robot_interface.main as main_mod +from robot_interface.state import state + + +class FakeSocket: + """Mock ZMQ socket for testing.""" + def __init__(self, socket_type, messages=None): + self.socket_type = socket_type + self.messages = messages or [] + self.sent = [] + self.closed = False + + def recv_json(self): + if not self.messages: + raise RuntimeError("No more messages") + return self.messages.pop(0) + + def send_json(self, msg): + self.sent.append(msg) + + def getsockopt(self, opt): + if opt == zmq.TYPE: + return self.socket_type + + def close(self): + self.closed = True + + +class FakeReceiver: + """Base class for main/actuation receivers.""" + def __init__(self, socket): + self.socket = socket + self._called = [] + + def handle_message(self, msg): + self._called.append(msg) + return {"endpoint": "pong", "data": "ok"} + + def close(self): + pass + + +class DummySender: + """Mock sender to test start methods.""" + def __init__(self): + self.called = False + + def start_video_rcv(self): + self.called = True + + def start(self): + self.called = True + + def close(self): + pass + + +@pytest.fixture +def fake_sockets(): + """Create default fake main and actuation sockets.""" + main_sock = FakeSocket(zmq.REP) + act_sock = FakeSocket(zmq.SUB) + return main_sock, act_sock + + +@pytest.fixture +def fake_poll(monkeypatch): + """Patch zmq.Poller to simulate a single polling cycle based on socket messages.""" + class FakePoller: + def __init__(self): + self.registered = {} + self.used = False + + def register(self, socket, flags): + self.registered[socket] = flags + + def poll(self, timeout): + # Only return sockets that still have messages + active_socks = { + s: flags + for s, flags + in self.registered.items() + if getattr(s, "messages", []) + } + if active_socks: + return active_socks + # No more messages, exit loop + state.exit_event.set() + return {} + + poller_instance = FakePoller() + monkeypatch.setattr(main_mod.zmq, "Poller", lambda: poller_instance) + return poller_instance + + +@pytest.fixture +def patched_main_components(monkeypatch, fake_sockets, fake_poll): + """ + Fixture to patch main receivers and senders with fakes. + Returns the fake instances for inspection in tests. + """ + main_sock, act_sock = fake_sockets + fake_main = FakeReceiver(main_sock) + fake_act = FakeReceiver(act_sock) + video_sender = DummySender() + audio_sender = DummySender() + + monkeypatch.setattr(main_mod, "MainReceiver", lambda ctx: fake_main) + monkeypatch.setattr(main_mod, "ActuationReceiver", lambda ctx: fake_act) + monkeypatch.setattr(main_mod, "VideoSender", lambda ctx: video_sender) + monkeypatch.setattr(main_mod, "AudioSender", lambda ctx: audio_sender) + + # Register sockets for the fake poller + fake_poll.registered = {main_sock: zmq.POLLIN, act_sock: zmq.POLLIN} + + return fake_main, fake_act, video_sender, audio_sender + + +def test_main_loop_rep_response(patched_main_components): + """REP socket returns proper response and handlers are called.""" + state.initialize() + fake_main, fake_act, video_sender, audio_sender = patched_main_components + + fake_main.socket.messages = [{"endpoint": "ping", "data": "x"}] + fake_act.socket.messages = [{"endpoint": "actuate/speech", "data": "hello"}] + + main_mod.main_loop(object()) + + assert fake_main.socket.sent == [{"endpoint": "pong", "data": "ok"}] + assert fake_main._called + assert fake_act._called + assert video_sender.called + assert audio_sender.called + state.deinitialize() + + +@pytest.mark.parametrize( + "messages", + [ + [{"no_endpoint": True}], # Invalid dict + [["not", "a", "dict"]] # Non-dict message + ] +) +def test_main_loop_invalid_or_non_dict_message(patched_main_components, messages): + """Invalid or non-dict messages are ignored.""" + state.initialize() + fake_main, _, _, _ = patched_main_components + + fake_main.socket.messages = messages + main_mod.main_loop(object()) + assert fake_main.socket.sent == [] + state.deinitialize() + + +def test_main_loop_handler_returns_none(patched_main_components, monkeypatch): + """Handler returning None still triggers send_json(None).""" + state.initialize() + fake_main, _, _, _ = patched_main_components + + class NoneHandler(FakeReceiver): + def handle_message(self, msg): + self._called.append(msg) + return None + + monkeypatch.setattr(main_mod, "MainReceiver", lambda ctx: NoneHandler(fake_main.socket)) + fake_main.socket.messages = [{"endpoint": "some", "data": None}] + + main_mod.main_loop(object()) + assert fake_main.socket.sent == [None] + state.deinitialize() + + +def test_main_loop_overtime_callback(patched_main_components, monkeypatch): + """TimeBlock callback is triggered if handler takes too long.""" + state.initialize() + fake_main, _, _, _ = patched_main_components + fake_main.socket.messages = [{"endpoint": "ping", "data": "x"}] + + class FakeTimeBlock: + def __init__(self, callback, limit_ms): + self.callback = callback + def __enter__(self): + return self + def __exit__(self, *a): + self.callback(999.0) + + monkeypatch.setattr(main_mod, "TimeBlock", FakeTimeBlock) + main_mod.main_loop(object()) + assert fake_main.socket.sent == [{"endpoint": "pong", "data": "ok"}] + state.deinitialize() + + +def test_main_keyboard_interrupt(monkeypatch): + """main() handles KeyboardInterrupt and cleans up.""" + called = {"deinitialized": False, "term_called": False} + + class FakeContext: + def term(self): called["term_called"] = True + + monkeypatch.setattr(main_mod.zmq, "Context", lambda: FakeContext()) + + def raise_keyboard_interrupt(*_): + raise KeyboardInterrupt() + monkeypatch.setattr(main_mod, "main_loop", raise_keyboard_interrupt) + + def fake_initialize(): + state.is_initialized = True + state.exit_event = threading.Event() + def fake_deinitialize(): + called["deinitialized"] = True + state.is_initialized = False + + monkeypatch.setattr(main_mod.state, "initialize", fake_initialize) + monkeypatch.setattr(main_mod.state, "deinitialize", fake_deinitialize) + + main_mod.main() + assert called["term_called"] is True + assert called["deinitialized"] is True