diff --git a/.githooks/check-commit-msg.sh b/.githooks/check-commit-msg.sh index eacf2a8..497a32f 100755 --- a/.githooks/check-commit-msg.sh +++ b/.githooks/check-commit-msg.sh @@ -30,7 +30,7 @@ HEADER=$(head -n 1 "$COMMIT_MSG_FILE") # Check for Merge commits (covers 'git merge' and PR merges from GitHub/GitLab) # Examples: "Merge branch 'main' into ...", "Merge pull request #123 from ..." -MERGE_PATTERN="^Merge (branch|pull request|tag) .*" +MERGE_PATTERN="^Merge (remote-tracking )?(branch|pull request|tag) .*" if [[ "$HEADER" =~ $MERGE_PATTERN ]]; then echo -e "${GREEN}Merge commit detected by message content. Skipping validation.${NC}" exit 0 diff --git a/src/control_backend/agents/actuation/robot_speech_agent.py b/src/control_backend/agents/actuation/robot_speech_agent.py index 65ac7dc..15fa07f 100644 --- a/src/control_backend/agents/actuation/robot_speech_agent.py +++ b/src/control_backend/agents/actuation/robot_speech_agent.py @@ -60,7 +60,7 @@ class RobotSpeechAgent(BaseAgent): self.subsocket.connect(settings.zmq_settings.internal_sub_address) self.subsocket.setsockopt(zmq.SUBSCRIBE, b"command") - await self.add_behavior(self._zmq_command_loop()) + self.add_behavior(self._zmq_command_loop()) self.logger.info("Finished setting up %s", self.name) diff --git a/src/control_backend/agents/bdi/bdi_core_agent/bdi_core_agent.py b/src/control_backend/agents/bdi/bdi_core_agent/bdi_core_agent.py index 124f537..f056e09 100644 --- a/src/control_backend/agents/bdi/bdi_core_agent/bdi_core_agent.py +++ b/src/control_backend/agents/bdi/bdi_core_agent/bdi_core_agent.py @@ -65,7 +65,7 @@ class BDICoreAgent(BaseAgent): await self._load_asl() # Start the BDI cycle loop - await self.add_behavior(self._bdi_loop()) + self.add_behavior(self._bdi_loop()) self._wake_bdi_loop.set() self.logger.debug("Setup complete.") diff --git a/src/control_backend/agents/bdi/bdi_program_manager.py b/src/control_backend/agents/bdi/bdi_program_manager.py index f910ff1..f8e826c 100644 --- a/src/control_backend/agents/bdi/bdi_program_manager.py +++ b/src/control_backend/agents/bdi/bdi_program_manager.py @@ -73,7 +73,7 @@ class BDIProgramManager(BaseAgent): try: program = Program.model_validate_json(body) except ValidationError as e: - self.logger.error("Received an invalid program.", exc_info=e) + self.logger.exception("Received an invalid program.") continue await self._send_to_bdi(program) @@ -91,4 +91,4 @@ class BDIProgramManager(BaseAgent): self.sub_socket.connect(settings.zmq_settings.internal_sub_address) self.sub_socket.subscribe("program") - await self.add_behavior(self._receive_programs()) + self.add_behavior(self._receive_programs()) diff --git a/src/control_backend/agents/communication/ri_communication_agent.py b/src/control_backend/agents/communication/ri_communication_agent.py index 401084a..a57400e 100644 --- a/src/control_backend/agents/communication/ri_communication_agent.py +++ b/src/control_backend/agents/communication/ri_communication_agent.py @@ -60,7 +60,7 @@ class RICommunicationAgent(BaseAgent): if await self._negotiate_connection(): self.connected = True - await self.add_behavior(self._listen_loop()) + self.add_behavior(self._listen_loop()) else: self.logger.warning("Failed to negotiate connection during setup.") diff --git a/src/control_backend/agents/perception/transcription_agent/transcription_agent.py b/src/control_backend/agents/perception/transcription_agent/transcription_agent.py index 7e58bb3..765d7ac 100644 --- a/src/control_backend/agents/perception/transcription_agent/transcription_agent.py +++ b/src/control_backend/agents/perception/transcription_agent/transcription_agent.py @@ -59,7 +59,7 @@ class TranscriptionAgent(BaseAgent): self.speech_recognizer.load_model() # Warmup # Start background loop - await self.add_behavior(self._transcribing_loop()) + self.add_behavior(self._transcribing_loop()) self.logger.info("Finished setting up %s", self.name) diff --git a/src/control_backend/agents/perception/vad_agent.py b/src/control_backend/agents/perception/vad_agent.py index fb9a197..48ac741 100644 --- a/src/control_backend/agents/perception/vad_agent.py +++ b/src/control_backend/agents/perception/vad_agent.py @@ -120,7 +120,7 @@ class VADAgent(BaseAgent): # Warmup/reset await self.reset_stream() - await self.add_behavior(self._streaming_loop()) + self.add_behavior(self._streaming_loop()) # Start agents dependent on the output audio fragments here transcriber = TranscriptionAgent(audio_out_address) diff --git a/src/control_backend/core/agent_system.py b/src/control_backend/core/agent_system.py index 3d82182..9d7a47f 100644 --- a/src/control_backend/core/agent_system.py +++ b/src/control_backend/core/agent_system.py @@ -1,6 +1,7 @@ import asyncio import logging from abc import ABC, abstractmethod +from asyncio import Task from collections.abc import Coroutine import zmq @@ -102,8 +103,8 @@ class BaseAgent(ABC): await self.setup() # Start processing inbox and ZMQ messages - await self.add_behavior(self._process_inbox()) - await self.add_behavior(self._receive_internal_zmq_loop()) + self.add_behavior(self._process_inbox()) + self.add_behavior(self._receive_internal_zmq_loop()) async def stop(self): """ @@ -182,7 +183,7 @@ class BaseAgent(ABC): """ raise NotImplementedError - async def add_behavior(self, coro: Coroutine): + def add_behavior(self, coro: Coroutine) -> Task: """ Add a background behavior (task) to the agent. @@ -194,3 +195,4 @@ class BaseAgent(ABC): task = asyncio.create_task(coro) self._tasks.add(task) task.add_done_callback(self._tasks.discard) + return task diff --git a/test/unit/agents/actuation/test_robot_speech_agent.py b/test/unit/agents/actuation/test_robot_speech_agent.py index 1ec2c6f..15324f6 100644 --- a/test/unit/agents/actuation/test_robot_speech_agent.py +++ b/test/unit/agents/actuation/test_robot_speech_agent.py @@ -25,24 +25,14 @@ async def test_setup_bind(zmq_context, mocker): settings = mocker.patch("control_backend.agents.actuation.robot_speech_agent.settings") settings.zmq_settings.internal_sub_address = "tcp://internal:1234" - # Swallow background task coroutines to avoid un-awaited warnings - class Swallow: - def __init__(self): - self.calls = 0 - - async def __call__(self, coro): - self.calls += 1 - coro.close() - - swallow = Swallow() - agent.add_behavior = swallow + agent.add_behavior = MagicMock() await agent.setup() fake_socket.bind.assert_any_call("tcp://localhost:5555") fake_socket.connect.assert_any_call("tcp://internal:1234") fake_socket.setsockopt.assert_any_call(zmq.SUBSCRIBE, b"command") - assert swallow.calls == 1 + agent.add_behavior.assert_called_once() @pytest.mark.asyncio @@ -53,22 +43,13 @@ async def test_setup_connect(zmq_context, mocker): settings = mocker.patch("control_backend.agents.actuation.robot_speech_agent.settings") settings.zmq_settings.internal_sub_address = "tcp://internal:1234" - class Swallow: - def __init__(self): - self.calls = 0 - - async def __call__(self, coro): - self.calls += 1 - coro.close() - - swallow = Swallow() - agent.add_behavior = swallow + agent.add_behavior = MagicMock() await agent.setup() fake_socket.connect.assert_any_call("tcp://localhost:5555") fake_socket.connect.assert_any_call("tcp://internal:1234") - assert swallow.calls == 1 + agent.add_behavior.assert_called_once() @pytest.mark.asyncio diff --git a/test/unit/agents/communication/test_ri_communication_agent.py b/test/unit/agents/communication/test_ri_communication_agent.py index 20b9379..747c4d2 100644 --- a/test/unit/agents/communication/test_ri_communication_agent.py +++ b/test/unit/agents/communication/test_ri_communication_agent.py @@ -46,16 +46,7 @@ async def test_setup_success_connects_and_starts_robot(zmq_context): robot_instance.start = AsyncMock() agent = RICommunicationAgent("ri_comm", address="tcp://localhost:5555", bind=False) - class Swallow: - def __init__(self): - self.calls = 0 - - async def __call__(self, coro): - self.calls += 1 - coro.close() - - swallow = Swallow() - agent.add_behavior = swallow + agent.add_behavior = MagicMock() await agent.setup() @@ -63,7 +54,7 @@ async def test_setup_success_connects_and_starts_robot(zmq_context): fake_socket.send_json.assert_any_call({"endpoint": "negotiate/ports", "data": {}}) robot_instance.start.assert_awaited_once() MockRobot.assert_called_once_with(ANY, address="tcp://*:5556", bind=True) - assert swallow.calls == 1 + agent.add_behavior.assert_called_once() assert agent.connected is True @@ -76,23 +67,14 @@ async def test_setup_binds_when_requested(zmq_context): agent = RICommunicationAgent("ri_comm", address="tcp://localhost:5555", bind=True) - class Swallow: - def __init__(self): - self.calls = 0 - - async def __call__(self, coro): - self.calls += 1 - coro.close() - - swallow = Swallow() - agent.add_behavior = swallow + agent.add_behavior = MagicMock() with patch(speech_agent_path(), autospec=True) as MockRobot: MockRobot.return_value.start = AsyncMock() await agent.setup() fake_socket.bind.assert_any_call("tcp://localhost:5555") - assert swallow.calls == 1 + agent.add_behavior.assert_called_once() @pytest.mark.asyncio diff --git a/test/unit/core/test_agent_system.py b/test/unit/core/test_agent_system.py index 5e954c8..f78b230 100644 --- a/test/unit/core/test_agent_system.py +++ b/test/unit/core/test_agent_system.py @@ -33,14 +33,16 @@ async def test_agent_lifecycle(): # Test background task async def dummy_task(): - await asyncio.sleep(0.01) + pass - await agent.add_behavior(dummy_task()) - assert len(agent._tasks) > 0 + task = agent.add_behavior(dummy_task()) + assert task in agent._tasks + + await task # Wait for task to finish - await asyncio.sleep(0.02) - assert len(agent._tasks) == 2 # message handling tasks are running + assert task not in agent._tasks + assert len(agent._tasks) == 2 # message handling tasks are still running await agent.stop() assert agent._running is False