Merge remote-tracking branch 'origin/dev' into feat/norms-and-goals-program

This commit is contained in:
Twirre Meulenbelt
2025-11-25 11:29:27 +01:00
10 changed files with 26 additions and 59 deletions

View File

@@ -30,7 +30,7 @@ HEADER=$(head -n 1 "$COMMIT_MSG_FILE")
# Check for Merge commits (covers 'git merge' and PR merges from GitHub/GitLab) # Check for Merge commits (covers 'git merge' and PR merges from GitHub/GitLab)
# Examples: "Merge branch 'main' into ...", "Merge pull request #123 from ..." # Examples: "Merge branch 'main' into ...", "Merge pull request #123 from ..."
MERGE_PATTERN="^Merge (branch|pull request|tag) .*" MERGE_PATTERN="^Merge (remote-tracking )?(branch|pull request|tag) .*"
if [[ "$HEADER" =~ $MERGE_PATTERN ]]; then if [[ "$HEADER" =~ $MERGE_PATTERN ]]; then
echo -e "${GREEN}Merge commit detected by message content. Skipping validation.${NC}" echo -e "${GREEN}Merge commit detected by message content. Skipping validation.${NC}"
exit 0 exit 0

View File

@@ -45,7 +45,7 @@ class RobotSpeechAgent(BaseAgent):
self.subsocket.connect(settings.zmq_settings.internal_sub_address) self.subsocket.connect(settings.zmq_settings.internal_sub_address)
self.subsocket.setsockopt(zmq.SUBSCRIBE, b"command") self.subsocket.setsockopt(zmq.SUBSCRIBE, b"command")
await self.add_behavior(self._zmq_command_loop()) self.add_behavior(self._zmq_command_loop())
self.logger.info("Finished setting up %s", self.name) self.logger.info("Finished setting up %s", self.name)

View File

@@ -37,7 +37,7 @@ class BDICoreAgent(BaseAgent):
await self._load_asl() await self._load_asl()
# Start the BDI cycle loop # Start the BDI cycle loop
await self.add_behavior(self._bdi_loop()) self.add_behavior(self._bdi_loop())
self._wake_bdi_loop.set() self._wake_bdi_loop.set()
self.logger.debug("Setup complete.") self.logger.debug("Setup complete.")

View File

@@ -37,7 +37,7 @@ class RICommunicationAgent(BaseAgent):
if await self._negotiate_connection(): if await self._negotiate_connection():
self.connected = True self.connected = True
await self.add_behavior(self._listen_loop()) self.add_behavior(self._listen_loop())
else: else:
self.logger.warning("Failed to negotiate connection during setup.") self.logger.warning("Failed to negotiate connection during setup.")

View File

@@ -37,7 +37,7 @@ class TranscriptionAgent(BaseAgent):
self.speech_recognizer.load_model() # Warmup self.speech_recognizer.load_model() # Warmup
# Start background loop # Start background loop
await self.add_behavior(self._transcribing_loop()) self.add_behavior(self._transcribing_loop())
self.logger.info("Finished setting up %s", self.name) self.logger.info("Finished setting up %s", self.name)

View File

@@ -93,7 +93,7 @@ class VADAgent(BaseAgent):
# Warmup/reset # Warmup/reset
await self.reset_stream() await self.reset_stream()
await self.add_behavior(self._streaming_loop()) self.add_behavior(self._streaming_loop())
# Start agents dependent on the output audio fragments here # Start agents dependent on the output audio fragments here
transcriber = TranscriptionAgent(audio_out_address) transcriber = TranscriptionAgent(audio_out_address)

View File

@@ -1,6 +1,7 @@
import asyncio import asyncio
import logging import logging
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from asyncio import Task
from collections.abc import Coroutine from collections.abc import Coroutine
import zmq import zmq
@@ -75,8 +76,8 @@ class BaseAgent(ABC):
await self.setup() await self.setup()
# Start processing inbox and ZMQ messages # Start processing inbox and ZMQ messages
await self.add_behavior(self._process_inbox()) self.add_behavior(self._process_inbox())
await self.add_behavior(self._receive_internal_zmq_loop()) self.add_behavior(self._receive_internal_zmq_loop())
async def stop(self): async def stop(self):
"""Stops the agent.""" """Stops the agent."""
@@ -128,7 +129,7 @@ class BaseAgent(ABC):
"""Override this to handle incoming messages.""" """Override this to handle incoming messages."""
raise NotImplementedError raise NotImplementedError
async def add_behavior(self, coro: Coroutine): def add_behavior(self, coro: Coroutine) -> Task:
""" """
Helper to add a behavior to the agent. To add asynchronous behavior to an agent, define Helper to add a behavior to the agent. To add asynchronous behavior to an agent, define
an `async` function and add it to the task list by calling :func:`add_behavior` an `async` function and add it to the task list by calling :func:`add_behavior`
@@ -138,3 +139,4 @@ class BaseAgent(ABC):
task = asyncio.create_task(coro) task = asyncio.create_task(coro)
self._tasks.add(task) self._tasks.add(task)
task.add_done_callback(self._tasks.discard) task.add_done_callback(self._tasks.discard)
return task

View File

@@ -25,24 +25,14 @@ async def test_setup_bind(zmq_context, mocker):
settings = mocker.patch("control_backend.agents.actuation.robot_speech_agent.settings") settings = mocker.patch("control_backend.agents.actuation.robot_speech_agent.settings")
settings.zmq_settings.internal_sub_address = "tcp://internal:1234" settings.zmq_settings.internal_sub_address = "tcp://internal:1234"
# Swallow background task coroutines to avoid un-awaited warnings agent.add_behavior = MagicMock()
class Swallow:
def __init__(self):
self.calls = 0
async def __call__(self, coro):
self.calls += 1
coro.close()
swallow = Swallow()
agent.add_behavior = swallow
await agent.setup() await agent.setup()
fake_socket.bind.assert_any_call("tcp://localhost:5555") fake_socket.bind.assert_any_call("tcp://localhost:5555")
fake_socket.connect.assert_any_call("tcp://internal:1234") fake_socket.connect.assert_any_call("tcp://internal:1234")
fake_socket.setsockopt.assert_any_call(zmq.SUBSCRIBE, b"command") fake_socket.setsockopt.assert_any_call(zmq.SUBSCRIBE, b"command")
assert swallow.calls == 1 agent.add_behavior.assert_called_once()
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -53,22 +43,13 @@ async def test_setup_connect(zmq_context, mocker):
settings = mocker.patch("control_backend.agents.actuation.robot_speech_agent.settings") settings = mocker.patch("control_backend.agents.actuation.robot_speech_agent.settings")
settings.zmq_settings.internal_sub_address = "tcp://internal:1234" settings.zmq_settings.internal_sub_address = "tcp://internal:1234"
class Swallow: agent.add_behavior = MagicMock()
def __init__(self):
self.calls = 0
async def __call__(self, coro):
self.calls += 1
coro.close()
swallow = Swallow()
agent.add_behavior = swallow
await agent.setup() await agent.setup()
fake_socket.connect.assert_any_call("tcp://localhost:5555") fake_socket.connect.assert_any_call("tcp://localhost:5555")
fake_socket.connect.assert_any_call("tcp://internal:1234") fake_socket.connect.assert_any_call("tcp://internal:1234")
assert swallow.calls == 1 agent.add_behavior.assert_called_once()
@pytest.mark.asyncio @pytest.mark.asyncio

View File

@@ -46,16 +46,7 @@ async def test_setup_success_connects_and_starts_robot(zmq_context):
robot_instance.start = AsyncMock() robot_instance.start = AsyncMock()
agent = RICommunicationAgent("ri_comm", address="tcp://localhost:5555", bind=False) agent = RICommunicationAgent("ri_comm", address="tcp://localhost:5555", bind=False)
class Swallow: agent.add_behavior = MagicMock()
def __init__(self):
self.calls = 0
async def __call__(self, coro):
self.calls += 1
coro.close()
swallow = Swallow()
agent.add_behavior = swallow
await agent.setup() await agent.setup()
@@ -63,7 +54,7 @@ async def test_setup_success_connects_and_starts_robot(zmq_context):
fake_socket.send_json.assert_any_call({"endpoint": "negotiate/ports", "data": {}}) fake_socket.send_json.assert_any_call({"endpoint": "negotiate/ports", "data": {}})
robot_instance.start.assert_awaited_once() robot_instance.start.assert_awaited_once()
MockRobot.assert_called_once_with(ANY, address="tcp://*:5556", bind=True) MockRobot.assert_called_once_with(ANY, address="tcp://*:5556", bind=True)
assert swallow.calls == 1 agent.add_behavior.assert_called_once()
assert agent.connected is True assert agent.connected is True
@@ -76,23 +67,14 @@ async def test_setup_binds_when_requested(zmq_context):
agent = RICommunicationAgent("ri_comm", address="tcp://localhost:5555", bind=True) agent = RICommunicationAgent("ri_comm", address="tcp://localhost:5555", bind=True)
class Swallow: agent.add_behavior = MagicMock()
def __init__(self):
self.calls = 0
async def __call__(self, coro):
self.calls += 1
coro.close()
swallow = Swallow()
agent.add_behavior = swallow
with patch(speech_agent_path(), autospec=True) as MockRobot: with patch(speech_agent_path(), autospec=True) as MockRobot:
MockRobot.return_value.start = AsyncMock() MockRobot.return_value.start = AsyncMock()
await agent.setup() await agent.setup()
fake_socket.bind.assert_any_call("tcp://localhost:5555") fake_socket.bind.assert_any_call("tcp://localhost:5555")
assert swallow.calls == 1 agent.add_behavior.assert_called_once()
@pytest.mark.asyncio @pytest.mark.asyncio

View File

@@ -33,14 +33,16 @@ async def test_agent_lifecycle():
# Test background task # Test background task
async def dummy_task(): async def dummy_task():
await asyncio.sleep(0.01) pass
await agent.add_behavior(dummy_task()) task = agent.add_behavior(dummy_task())
assert len(agent._tasks) > 0 assert task in agent._tasks
await task
# Wait for task to finish # Wait for task to finish
await asyncio.sleep(0.02) assert task not in agent._tasks
assert len(agent._tasks) == 2 # message handling tasks are running assert len(agent._tasks) == 2 # message handling tasks are still running
await agent.stop() await agent.stop()
assert agent._running is False assert agent._running is False