Merge remote-tracking branch 'origin/dev' into demo

This commit is contained in:
Twirre Meulenbelt
2025-11-05 12:38:08 +01:00
15 changed files with 187 additions and 163 deletions

View File

@@ -1,4 +1,4 @@
#!/bin/bash #!/usr/bin/env bash
# This script checks if the current branch name follows the specified format. # This script checks if the current branch name follows the specified format.
# It's designed to be used as a 'pre-commit' git hook. # It's designed to be used as a 'pre-commit' git hook.
@@ -10,7 +10,7 @@
# An array of allowed commit types # An array of allowed commit types
ALLOWED_TYPES=(feat fix refactor perf style test docs build chore revert) ALLOWED_TYPES=(feat fix refactor perf style test docs build chore revert)
# An array of branches to ignore # An array of branches to ignore
IGNORED_BRANCHES=(main dev) IGNORED_BRANCHES=(main dev demo)
# --- Colors for Output --- # --- Colors for Output ---
RED='\033[0;31m' RED='\033[0;31m'

View File

@@ -1,4 +1,4 @@
#!/bin/bash #!/usr/bin/env bash
# This script checks if a commit message follows the specified format. # This script checks if a commit message follows the specified format.
# It's designed to be used as a 'commit-msg' git hook. # It's designed to be used as a 'commit-msg' git hook.
@@ -23,6 +23,44 @@ NC='\033[0m' # No Color
# The first argument to the hook is the path to the file containing the commit message # The first argument to the hook is the path to the file containing the commit message
COMMIT_MSG_FILE=$1 COMMIT_MSG_FILE=$1
# --- Automated Commit Detection ---
# Read the first line (header) for initial checks
HEADER=$(head -n 1 "$COMMIT_MSG_FILE")
# Check for Merge commits (covers 'git merge' and PR merges from GitHub/GitLab)
# Examples: "Merge branch 'main' into ...", "Merge pull request #123 from ..."
MERGE_PATTERN="^Merge (branch|pull request|tag) .*"
if [[ "$HEADER" =~ $MERGE_PATTERN ]]; then
echo -e "${GREEN}Merge commit detected by message content. Skipping validation.${NC}"
exit 0
fi
# Check for Revert commits
# Example: "Revert "feat: add new feature""
REVERT_PATTERN="^Revert \".*\""
if [[ "$HEADER" =~ $REVERT_PATTERN ]]; then
echo -e "${GREEN}Revert commit detected by message content. Skipping validation.${NC}"
exit 0
fi
# Check for Cherry-pick commits (this pattern appears at the end of the message)
# Example: "(cherry picked from commit deadbeef...)"
# We use grep -q to search the whole file quietly.
CHERRY_PICK_PATTERN="\(cherry picked from commit [a-f0-9]{7,40}\)"
if grep -qE "$CHERRY_PICK_PATTERN" "$COMMIT_MSG_FILE"; then
echo -e "${GREEN}Cherry-pick detected by message content. Skipping validation.${NC}"
exit 0
fi
# Check for Squash
# Example: "Squash commits ..."
SQUASH_PATTERN="^Squash .+"
if [[ "$HEADER" =~ $SQUASH_PATTERN ]]; then
echo -e "${GREEN}Squash commit detected by message content. Skipping validation.${NC}"
exit 0
fi
# --- Validation Functions --- # --- Validation Functions ---
# Function to print an error message and exit # Function to print an error message and exit
@@ -56,20 +94,24 @@ if ! [[ "$HEADER" =~ $HEADER_REGEX ]]; then
error_exit "Invalid header format.\n\nHeader must be in the format: <type>: <short description>\nAllowed types: ${ALLOWED_TYPES[*]}\nExample: feat: add new user authentication feature" error_exit "Invalid header format.\n\nHeader must be in the format: <type>: <short description>\nAllowed types: ${ALLOWED_TYPES[*]}\nExample: feat: add new user authentication feature"
fi fi
# 3. Validate the footer (last line) of the commit message # Only validate footer if commit type is not chore
FOOTER=$(tail -n 1 "$COMMIT_MSG_FILE") TYPE=$(echo "$HEADER" | cut -d':' -f1)
if [ "$TYPE" != "chore" ]; then
# 3. Validate the footer (last line) of the commit message
FOOTER=$(tail -n 1 "$COMMIT_MSG_FILE")
# Regex breakdown: # Regex breakdown:
# ^(ref|close) - Starts with 'ref' or 'close' # ^(ref|close) - Starts with 'ref' or 'close'
# : - Followed by a literal colon # : - Followed by a literal colon
# \s - Followed by a single space # \s - Followed by a single space
# N25B- - Followed by the literal string 'N25B-' # N25B- - Followed by the literal string 'N25B-'
# [0-9]+ - Followed by one or more digits # [0-9]+ - Followed by one or more digits
# $ - End of the line # $ - End of the line
FOOTER_REGEX="^(ref|close): N25B-[0-9]+$" FOOTER_REGEX="^(ref|close): N25B-[0-9]+$"
if ! [[ "$FOOTER" =~ $FOOTER_REGEX ]]; then if ! [[ "$FOOTER" =~ $FOOTER_REGEX ]]; then
error_exit "Invalid footer format.\n\nFooter must be in the format: [ref/close]: <issue identifier>\nExample: ref: N25B-123" error_exit "Invalid footer format.\n\nFooter must be in the format: [ref/close]: <issue identifier>\nExample: ref: N25B-123"
fi
fi fi
# 4. If the message has more than 2 lines, validate the separator # 4. If the message has more than 2 lines, validate the separator

View File

@@ -5,9 +5,9 @@ import spade.agent
import zmq import zmq
from spade.agent import Agent from spade.agent import Agent
from spade.behaviour import CyclicBehaviour from spade.behaviour import CyclicBehaviour
from zmq.asyncio import Context
from control_backend.core.config import settings from control_backend.core.config import settings
from control_backend.core.zmq_context import context
from control_backend.schemas.ri_message import SpeechCommand from control_backend.schemas.ri_message import SpeechCommand
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -71,6 +71,8 @@ class RICommandAgent(Agent):
""" """
logger.info("Setting up %s", self.jid) logger.info("Setting up %s", self.jid)
context = Context.instance()
# To the robot # To the robot
self.pubsocket = context.socket(zmq.PUB) self.pubsocket = context.socket(zmq.PUB)
if self.bind: if self.bind:
@@ -80,7 +82,7 @@ class RICommandAgent(Agent):
# Receive internal topics regarding commands # Receive internal topics regarding commands
self.subsocket = context.socket(zmq.SUB) self.subsocket = context.socket(zmq.SUB)
self.subsocket.connect(settings.zmq_settings.internal_comm_address) self.subsocket.connect(settings.zmq_settings.internal_sub_address)
self.subsocket.setsockopt(zmq.SUBSCRIBE, b"command") self.subsocket.setsockopt(zmq.SUBSCRIBE, b"command")
# Add behaviour to our agent # Add behaviour to our agent

View File

@@ -4,10 +4,10 @@ import logging
import zmq import zmq
from spade.agent import Agent from spade.agent import Agent
from spade.behaviour import CyclicBehaviour from spade.behaviour import CyclicBehaviour
from zmq.asyncio import Context
from control_backend.agents.ri_command_agent import RICommandAgent from control_backend.agents.ri_command_agent import RICommandAgent
from control_backend.core.config import settings from control_backend.core.config import settings
from control_backend.core.zmq_context import context
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -74,7 +74,7 @@ class RICommunicationAgent(Agent):
# Let's try a certain amount of times before failing connection # Let's try a certain amount of times before failing connection
while retries < max_retries: while retries < max_retries:
# Bind request socket # Bind request socket
self.req_socket = context.socket(zmq.REQ) self.req_socket = Context.instance().socket(zmq.REQ)
if self._bind: if self._bind:
self.req_socket.bind(self._address) self.req_socket.bind(self._address)
else: else:

View File

@@ -10,7 +10,6 @@ from spade.message import Message
from control_backend.agents.transcription.speech_recognizer import SpeechRecognizer from control_backend.agents.transcription.speech_recognizer import SpeechRecognizer
from control_backend.core.config import settings from control_backend.core.config import settings
from control_backend.core.zmq_context import context as zmq_context
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -73,7 +72,7 @@ class TranscriptionAgent(Agent):
return await super().stop() return await super().stop()
def _connect_audio_in_socket(self): def _connect_audio_in_socket(self):
self.audio_in_socket = zmq_context.socket(zmq.SUB) self.audio_in_socket = azmq.Context.instance().socket(zmq.SUB)
self.audio_in_socket.setsockopt_string(zmq.SUBSCRIBE, "") self.audio_in_socket.setsockopt_string(zmq.SUBSCRIBE, "")
self.audio_in_socket.connect(self.audio_in_address) self.audio_in_socket.connect(self.audio_in_address)

View File

@@ -9,7 +9,6 @@ from spade.behaviour import CyclicBehaviour
from control_backend.agents.transcription import TranscriptionAgent from control_backend.agents.transcription import TranscriptionAgent
from control_backend.core.config import settings from control_backend.core.config import settings
from control_backend.core.zmq_context import context as zmq_context
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -66,7 +65,8 @@ class Streaming(CyclicBehaviour):
self._ready = True self._ready = True
async def run(self) -> None: async def run(self) -> None:
if not self._ready: return if not self._ready:
return
data = await self.audio_in_poller.poll() data = await self.audio_in_poller.poll()
if data is None: if data is None:
@@ -134,7 +134,7 @@ class VADAgent(Agent):
return await super().stop() return await super().stop()
def _connect_audio_in_socket(self): def _connect_audio_in_socket(self):
self.audio_in_socket = zmq_context.socket(zmq.SUB) self.audio_in_socket = azmq.Context.instance().socket(zmq.SUB)
self.audio_in_socket.setsockopt_string(zmq.SUBSCRIBE, "") self.audio_in_socket.setsockopt_string(zmq.SUBSCRIBE, "")
if self.audio_in_bind: if self.audio_in_bind:
self.audio_in_socket.bind(self.audio_in_address) self.audio_in_socket.bind(self.audio_in_address)
@@ -145,7 +145,7 @@ class VADAgent(Agent):
def _connect_audio_out_socket(self) -> int | None: def _connect_audio_out_socket(self) -> int | None:
"""Returns the port bound, or None if binding failed.""" """Returns the port bound, or None if binding failed."""
try: try:
self.audio_out_socket = zmq_context.socket(zmq.PUB) self.audio_out_socket = azmq.Context.instance().socket(zmq.PUB)
return self.audio_out_socket.bind_to_random_port("tcp://*", max_tries=100) return self.audio_out_socket.bind_to_random_port("tcp://*", max_tries=100)
except zmq.ZMQBindError: except zmq.ZMQBindError:
logger.error("Failed to bind an audio output socket after 100 tries.") logger.error("Failed to bind an audio output socket after 100 tries.")

View File

@@ -1,7 +1,6 @@
import logging import logging
from fastapi import APIRouter, Request from fastapi import APIRouter, Request
from zmq import Socket
from control_backend.schemas.ri_message import SpeechCommand from control_backend.schemas.ri_message import SpeechCommand
@@ -15,7 +14,7 @@ async def receive_command(command: SpeechCommand, request: Request):
# Validate and retrieve data. # Validate and retrieve data.
SpeechCommand.model_validate(command) SpeechCommand.model_validate(command)
topic = b"command" topic = b"command"
pub_socket: Socket = request.app.state.internal_comm_socket pub_socket = request.app.state.endpoints_pub_socket
pub_socket.send_multipart([topic, command.model_dump_json().encode()]) await pub_socket.send_multipart([topic, command.model_dump_json().encode()])
return {"status": "Command received"} return {"status": "Command received"}

View File

@@ -1,7 +1,6 @@
import logging import logging
from fastapi import APIRouter, Request from fastapi import APIRouter, Request
from zmq import Socket
from control_backend.schemas.message import Message from control_backend.schemas.message import Message
@@ -17,8 +16,7 @@ async def receive_message(message: Message, request: Request):
topic = b"message" topic = b"message"
body = message.model_dump_json().encode("utf-8") body = message.model_dump_json().encode("utf-8")
pub_socket: Socket = request.app.state.internal_comm_socket pub_socket = request.app.state.endpoints_pub_socket
await pub_socket.send_multipart([topic, body])
pub_socket.send_multipart([topic, body])
return {"status": "Message received"} return {"status": "Message received"}

View File

@@ -3,7 +3,8 @@ from pydantic_settings import BaseSettings, SettingsConfigDict
class ZMQSettings(BaseModel): class ZMQSettings(BaseModel):
internal_comm_address: str = "tcp://localhost:5560" internal_pub_address: str = "tcp://localhost:5560"
internal_sub_address: str = "tcp://localhost:5561"
class AgentSettings(BaseModel): class AgentSettings(BaseModel):

View File

@@ -1,3 +0,0 @@
from zmq.asyncio import Context
context = Context()

View File

@@ -3,37 +3,59 @@
# External imports # External imports
import contextlib import contextlib
import logging import logging
import threading
import zmq import zmq
from fastapi import FastAPI from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from zmq.asyncio import Context
from control_backend.agents.bdi.bdi_core import BDICoreAgent from control_backend.agents.bdi.bdi_core import BDICoreAgent
from control_backend.agents.bdi.text_extractor import TBeliefExtractor from control_backend.agents.bdi.text_extractor import TBeliefExtractor
from control_backend.agents.belief_collector.belief_collector import BeliefCollectorAgent from control_backend.agents.belief_collector.belief_collector import BeliefCollectorAgent
from control_backend.agents.llm.llm import LLMAgent from control_backend.agents.llm.llm import LLMAgent
# Internal imports
from control_backend.agents.ri_communication_agent import RICommunicationAgent from control_backend.agents.ri_communication_agent import RICommunicationAgent
from control_backend.agents.vad_agent import VADAgent from control_backend.agents.vad_agent import VADAgent
from control_backend.api.v1.router import api_router from control_backend.api.v1.router import api_router
from control_backend.core.config import settings from control_backend.core.config import settings
from control_backend.core.zmq_context import context
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.DEBUG) logging.basicConfig(level=logging.DEBUG)
def setup_sockets():
context = Context.instance()
internal_pub_socket = context.socket(zmq.XPUB)
internal_pub_socket.bind(settings.zmq_settings.internal_sub_address)
logger.debug("Internal publishing socket bound to %s", internal_pub_socket)
internal_sub_socket = context.socket(zmq.XSUB)
internal_sub_socket.bind(settings.zmq_settings.internal_pub_address)
logger.debug("Internal subscribing socket bound to %s", internal_sub_socket)
try:
zmq.proxy(internal_sub_socket, internal_pub_socket)
except zmq.ZMQError:
logger.warning("Error while handling PUB/SUB proxy. Closing sockets.")
finally:
internal_pub_socket.close()
internal_sub_socket.close()
@contextlib.asynccontextmanager @contextlib.asynccontextmanager
async def lifespan(app: FastAPI): async def lifespan(app: FastAPI):
logger.info("%s starting up.", app.title) logger.info("%s starting up.", app.title)
# Initiate sockets # Initiate sockets
internal_comm_socket = context.socket(zmq.PUB) proxy_thread = threading.Thread(target=setup_sockets)
internal_comm_address = settings.zmq_settings.internal_comm_address proxy_thread.daemon = True
internal_comm_socket.bind(internal_comm_address) proxy_thread.start()
app.state.internal_comm_socket = internal_comm_socket
logger.info("Internal publishing socket bound to %s", internal_comm_socket) context = Context.instance()
endpoints_pub_socket = context.socket(zmq.PUB)
endpoints_pub_socket.connect(settings.zmq_settings.internal_pub_address)
app.state.endpoints_pub_socket = endpoints_pub_socket
# Initiate agents # Initiate agents
ri_communication_agent = RICommunicationAgent( ri_communication_agent = RICommunicationAgent(

View File

@@ -7,19 +7,21 @@ import zmq
from control_backend.agents.ri_command_agent import RICommandAgent from control_backend.agents.ri_command_agent import RICommandAgent
@pytest.fixture
def zmq_context(mocker):
mock_context = mocker.patch("control_backend.agents.vad_agent.azmq.Context.instance")
mock_context.return_value = MagicMock()
return mock_context
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_setup_bind(monkeypatch): async def test_setup_bind(zmq_context, mocker):
"""Test setup with bind=True""" """Test setup with bind=True"""
fake_socket = MagicMock() fake_socket = zmq_context.return_value.socket.return_value
monkeypatch.setattr(
"control_backend.agents.ri_command_agent.context.socket", lambda _: fake_socket
)
agent = RICommandAgent("test@server", "password", address="tcp://localhost:5555", bind=True) agent = RICommandAgent("test@server", "password", address="tcp://localhost:5555", bind=True)
monkeypatch.setattr( settings = mocker.patch("control_backend.agents.ri_command_agent.settings")
"control_backend.agents.ri_command_agent.settings", settings.zmq_settings.internal_sub_address = "tcp://internal:1234"
MagicMock(zmq_settings=MagicMock(internal_comm_address="tcp://internal:1234")),
)
await agent.setup() await agent.setup()
@@ -34,18 +36,13 @@ async def test_setup_bind(monkeypatch):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_setup_connect(monkeypatch): async def test_setup_connect(zmq_context, mocker):
"""Test setup with bind=False""" """Test setup with bind=False"""
fake_socket = MagicMock() fake_socket = zmq_context.return_value.socket.return_value
monkeypatch.setattr(
"control_backend.agents.ri_command_agent.context.socket", lambda _: fake_socket
)
agent = RICommandAgent("test@server", "password", address="tcp://localhost:5555", bind=False) agent = RICommandAgent("test@server", "password", address="tcp://localhost:5555", bind=False)
monkeypatch.setattr( settings = mocker.patch("control_backend.agents.ri_command_agent.settings")
"control_backend.agents.ri_command_agent.settings", settings.zmq_settings.internal_sub_address = "tcp://internal:1234"
MagicMock(zmq_settings=MagicMock(internal_comm_address="tcp://internal:1234")),
)
await agent.setup() await agent.setup()

View File

@@ -84,21 +84,23 @@ def fake_json_invalid_id_negototiate():
) )
@pytest.fixture
def zmq_context(mocker):
mock_context = mocker.patch("control_backend.agents.vad_agent.azmq.Context.instance")
mock_context.return_value = MagicMock()
return mock_context
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_setup_creates_socket_and_negotiate_1(monkeypatch): async def test_setup_creates_socket_and_negotiate_1(zmq_context):
""" """
Test the setup of the communication agent Test the setup of the communication agent
""" """
# --- Arrange --- # --- Arrange ---
fake_socket = MagicMock() fake_socket = zmq_context.return_value.socket.return_value
fake_socket.send_json = AsyncMock() fake_socket.send_json = AsyncMock()
fake_socket.recv_json = fake_json_correct_negototiate_1() fake_socket.recv_json = fake_json_correct_negototiate_1()
# Mock context.socket to return our fake socket
monkeypatch.setattr(
"control_backend.agents.ri_communication_agent.context.socket", lambda _: fake_socket
)
# Mock RICommandAgent agent startup # Mock RICommandAgent agent startup
with patch( with patch(
"control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True "control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True
@@ -128,20 +130,15 @@ async def test_setup_creates_socket_and_negotiate_1(monkeypatch):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_setup_creates_socket_and_negotiate_2(monkeypatch): async def test_setup_creates_socket_and_negotiate_2(zmq_context):
""" """
Test the setup of the communication agent Test the setup of the communication agent
""" """
# --- Arrange --- # --- Arrange ---
fake_socket = MagicMock() fake_socket = zmq_context.return_value.socket.return_value
fake_socket.send_json = AsyncMock() fake_socket.send_json = AsyncMock()
fake_socket.recv_json = fake_json_correct_negototiate_2() fake_socket.recv_json = fake_json_correct_negototiate_2()
# Mock context.socket to return our fake socket
monkeypatch.setattr(
"control_backend.agents.ri_communication_agent.context.socket", lambda _: fake_socket
)
# Mock RICommandAgent agent startup # Mock RICommandAgent agent startup
with patch( with patch(
"control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True "control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True
@@ -171,20 +168,15 @@ async def test_setup_creates_socket_and_negotiate_2(monkeypatch):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_setup_creates_socket_and_negotiate_3(monkeypatch, caplog): async def test_setup_creates_socket_and_negotiate_3(zmq_context, caplog):
""" """
Test the functionality of setup with incorrect negotiation message Test the functionality of setup with incorrect negotiation message
""" """
# --- Arrange --- # --- Arrange ---
fake_socket = MagicMock() fake_socket = zmq_context.return_value.socket.return_value
fake_socket.send_json = AsyncMock() fake_socket.send_json = AsyncMock()
fake_socket.recv_json = fake_json_wrong_negototiate_1() fake_socket.recv_json = fake_json_wrong_negototiate_1()
# Mock context.socket to return our fake socket
monkeypatch.setattr(
"control_backend.agents.ri_communication_agent.context.socket", lambda _: fake_socket
)
# Mock RICommandAgent agent startup # Mock RICommandAgent agent startup
# We are sending wrong negotiation info to the communication agent, # We are sending wrong negotiation info to the communication agent,
@@ -215,20 +207,15 @@ async def test_setup_creates_socket_and_negotiate_3(monkeypatch, caplog):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_setup_creates_socket_and_negotiate_4(monkeypatch): async def test_setup_creates_socket_and_negotiate_4(zmq_context):
""" """
Test the setup of the communication agent with different bind value Test the setup of the communication agent with different bind value
""" """
# --- Arrange --- # --- Arrange ---
fake_socket = MagicMock() fake_socket = zmq_context.return_value.socket.return_value
fake_socket.send_json = AsyncMock() fake_socket.send_json = AsyncMock()
fake_socket.recv_json = fake_json_correct_negototiate_3() fake_socket.recv_json = fake_json_correct_negototiate_3()
# Mock context.socket to return our fake socket
monkeypatch.setattr(
"control_backend.agents.ri_communication_agent.context.socket", lambda _: fake_socket
)
# Mock RICommandAgent agent startup # Mock RICommandAgent agent startup
with patch( with patch(
"control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True "control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True
@@ -258,20 +245,15 @@ async def test_setup_creates_socket_and_negotiate_4(monkeypatch):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_setup_creates_socket_and_negotiate_5(monkeypatch): async def test_setup_creates_socket_and_negotiate_5(zmq_context):
""" """
Test the setup of the communication agent Test the setup of the communication agent
""" """
# --- Arrange --- # --- Arrange ---
fake_socket = MagicMock() fake_socket = zmq_context.return_value.socket.return_value
fake_socket.send_json = AsyncMock() fake_socket.send_json = AsyncMock()
fake_socket.recv_json = fake_json_correct_negototiate_4() fake_socket.recv_json = fake_json_correct_negototiate_4()
# Mock context.socket to return our fake socket
monkeypatch.setattr(
"control_backend.agents.ri_communication_agent.context.socket", lambda _: fake_socket
)
# Mock RICommandAgent agent startup # Mock RICommandAgent agent startup
with patch( with patch(
"control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True "control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True
@@ -301,20 +283,15 @@ async def test_setup_creates_socket_and_negotiate_5(monkeypatch):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_setup_creates_socket_and_negotiate_6(monkeypatch): async def test_setup_creates_socket_and_negotiate_6(zmq_context):
""" """
Test the setup of the communication agent Test the setup of the communication agent
""" """
# --- Arrange --- # --- Arrange ---
fake_socket = MagicMock() fake_socket = zmq_context.return_value.socket.return_value
fake_socket.send_json = AsyncMock() fake_socket.send_json = AsyncMock()
fake_socket.recv_json = fake_json_correct_negototiate_5() fake_socket.recv_json = fake_json_correct_negototiate_5()
# Mock context.socket to return our fake socket
monkeypatch.setattr(
"control_backend.agents.ri_communication_agent.context.socket", lambda _: fake_socket
)
# Mock RICommandAgent agent startup # Mock RICommandAgent agent startup
with patch( with patch(
"control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True "control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True
@@ -344,20 +321,15 @@ async def test_setup_creates_socket_and_negotiate_6(monkeypatch):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_setup_creates_socket_and_negotiate_7(monkeypatch, caplog): async def test_setup_creates_socket_and_negotiate_7(zmq_context, caplog):
""" """
Test the functionality of setup with incorrect id Test the functionality of setup with incorrect id
""" """
# --- Arrange --- # --- Arrange ---
fake_socket = MagicMock() fake_socket = zmq_context.return_value.socket.return_value
fake_socket.send_json = AsyncMock() fake_socket.send_json = AsyncMock()
fake_socket.recv_json = fake_json_invalid_id_negototiate() fake_socket.recv_json = fake_json_invalid_id_negototiate()
# Mock context.socket to return our fake socket
monkeypatch.setattr(
"control_backend.agents.ri_communication_agent.context.socket", lambda _: fake_socket
)
# Mock RICommandAgent agent startup # Mock RICommandAgent agent startup
# We are sending wrong negotiation info to the communication agent, # We are sending wrong negotiation info to the communication agent,
@@ -385,20 +357,15 @@ async def test_setup_creates_socket_and_negotiate_7(monkeypatch, caplog):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_setup_creates_socket_and_negotiate_timeout(monkeypatch, caplog): async def test_setup_creates_socket_and_negotiate_timeout(zmq_context, caplog):
""" """
Test the functionality of setup with incorrect negotiation message Test the functionality of setup with incorrect negotiation message
""" """
# --- Arrange --- # --- Arrange ---
fake_socket = MagicMock() fake_socket = zmq_context.return_value.socket.return_value
fake_socket.send_json = AsyncMock() fake_socket.send_json = AsyncMock()
fake_socket.recv_json = AsyncMock(side_effect=asyncio.TimeoutError) fake_socket.recv_json = AsyncMock(side_effect=asyncio.TimeoutError)
# Mock context.socket to return our fake socket
monkeypatch.setattr(
"control_backend.agents.ri_communication_agent.context.socket", lambda _: fake_socket
)
with patch( with patch(
"control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True "control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True
) as MockCommandAgent: ) as MockCommandAgent:
@@ -480,8 +447,8 @@ async def test_listen_behaviour_ping_wrong_endpoint(caplog):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_listen_behaviour_timeout(caplog): async def test_listen_behaviour_timeout(zmq_context, caplog):
fake_socket = AsyncMock() fake_socket = zmq_context.return_value.socket.return_value
fake_socket.send_json = AsyncMock() fake_socket.send_json = AsyncMock()
# recv_json will never resolve, simulate timeout # recv_json will never resolve, simulate timeout
fake_socket.recv_json = AsyncMock(side_effect=asyncio.TimeoutError) fake_socket.recv_json = AsyncMock(side_effect=asyncio.TimeoutError)
@@ -529,16 +496,12 @@ async def test_listen_behaviour_ping_no_endpoint(caplog):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_setup_unexpected_exception(monkeypatch, caplog): async def test_setup_unexpected_exception(zmq_context, caplog):
fake_socket = MagicMock() fake_socket = zmq_context.return_value.socket.return_value
fake_socket.send_json = AsyncMock() fake_socket.send_json = AsyncMock()
# Simulate unexpected exception during recv_json() # Simulate unexpected exception during recv_json()
fake_socket.recv_json = AsyncMock(side_effect=Exception("boom!")) fake_socket.recv_json = AsyncMock(side_effect=Exception("boom!"))
monkeypatch.setattr(
"control_backend.agents.ri_communication_agent.context.socket", lambda _: fake_socket
)
agent = RICommunicationAgent( agent = RICommunicationAgent(
"test@server", "password", address="tcp://localhost:5555", bind=False "test@server", "password", address="tcp://localhost:5555", bind=False
) )
@@ -551,9 +514,9 @@ async def test_setup_unexpected_exception(monkeypatch, caplog):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_setup_unpacking_exception(monkeypatch, caplog): async def test_setup_unpacking_exception(zmq_context, caplog):
# --- Arrange --- # --- Arrange ---
fake_socket = MagicMock() fake_socket = zmq_context.return_value.socket.return_value
fake_socket.send_json = AsyncMock() fake_socket.send_json = AsyncMock()
# Make recv_json return malformed negotiation data to trigger unpacking exception # Make recv_json return malformed negotiation data to trigger unpacking exception
@@ -563,11 +526,6 @@ async def test_setup_unpacking_exception(monkeypatch, caplog):
} # missing 'port' and 'bind' } # missing 'port' and 'bind'
fake_socket.recv_json = AsyncMock(return_value=malformed_data) fake_socket.recv_json = AsyncMock(return_value=malformed_data)
# Patch context.socket
monkeypatch.setattr(
"control_backend.agents.ri_communication_agent.context.socket", lambda _: fake_socket
)
# Patch RICommandAgent so it won't actually start # Patch RICommandAgent so it won't actually start
with patch( with patch(
"control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True "control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True

View File

@@ -10,7 +10,9 @@ from control_backend.agents.vad_agent import VADAgent
@pytest.fixture @pytest.fixture
def zmq_context(mocker): def zmq_context(mocker):
return mocker.patch("control_backend.agents.vad_agent.zmq_context") mock_context = mocker.patch("control_backend.agents.vad_agent.azmq.Context.instance")
mock_context.return_value = MagicMock()
return mock_context
@pytest.fixture @pytest.fixture
@@ -54,13 +56,18 @@ def test_in_socket_creation(zmq_context, do_bind: bool):
assert vad_agent.audio_in_socket is not None assert vad_agent.audio_in_socket is not None
zmq_context.socket.assert_called_once_with(zmq.SUB) zmq_context.return_value.socket.assert_called_once_with(zmq.SUB)
zmq_context.socket.return_value.setsockopt_string.assert_called_once_with(zmq.SUBSCRIBE, "") zmq_context.return_value.socket.return_value.setsockopt_string.assert_called_once_with(
zmq.SUBSCRIBE,
"",
)
if do_bind: if do_bind:
zmq_context.socket.return_value.bind.assert_called_once_with("tcp://*:12345") zmq_context.return_value.socket.return_value.bind.assert_called_once_with("tcp://*:12345")
else: else:
zmq_context.socket.return_value.connect.assert_called_once_with("tcp://localhost:12345") zmq_context.return_value.socket.return_value.connect.assert_called_once_with(
"tcp://localhost:12345"
)
def test_out_socket_creation(zmq_context): def test_out_socket_creation(zmq_context):
@@ -73,8 +80,8 @@ def test_out_socket_creation(zmq_context):
assert vad_agent.audio_out_socket is not None assert vad_agent.audio_out_socket is not None
zmq_context.socket.assert_called_once_with(zmq.PUB) zmq_context.return_value.socket.assert_called_once_with(zmq.PUB)
zmq_context.socket.return_value.bind_to_random_port.assert_called_once() zmq_context.return_value.socket.return_value.bind_to_random_port.assert_called_once()
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -83,7 +90,9 @@ async def test_out_socket_creation_failure(zmq_context):
Test setup failure when the audio output socket cannot be created. Test setup failure when the audio output socket cannot be created.
""" """
with patch.object(Agent, "stop", new_callable=AsyncMock) as mock_super_stop: with patch.object(Agent, "stop", new_callable=AsyncMock) as mock_super_stop:
zmq_context.socket.return_value.bind_to_random_port.side_effect = zmq.ZMQBindError zmq_context.return_value.socket.return_value.bind_to_random_port.side_effect = (
zmq.ZMQBindError
)
vad_agent = VADAgent("tcp://localhost:12345", False) vad_agent = VADAgent("tcp://localhost:12345", False)
await vad_agent.setup() await vad_agent.setup()
@@ -98,11 +107,14 @@ async def test_stop(zmq_context, transcription_agent):
Test that when the VAD agent is stopped, the sockets are closed correctly. Test that when the VAD agent is stopped, the sockets are closed correctly.
""" """
vad_agent = VADAgent("tcp://localhost:12345", False) vad_agent = VADAgent("tcp://localhost:12345", False)
zmq_context.socket.return_value.bind_to_random_port.return_value = random.randint(1000, 10000) zmq_context.return_value.socket.return_value.bind_to_random_port.return_value = random.randint(
1000,
10000,
)
await vad_agent.setup() await vad_agent.setup()
await vad_agent.stop() await vad_agent.stop()
assert zmq_context.socket.return_value.close.call_count == 2 assert zmq_context.return_value.socket.return_value.close.call_count == 2
assert vad_agent.audio_in_socket is None assert vad_agent.audio_in_socket is None
assert vad_agent.audio_out_socket is None assert vad_agent.audio_out_socket is None

View File

@@ -1,4 +1,4 @@
from unittest.mock import MagicMock from unittest.mock import AsyncMock
import pytest import pytest
from fastapi import FastAPI from fastapi import FastAPI
@@ -16,7 +16,6 @@ def app():
""" """
app = FastAPI() app = FastAPI()
app.include_router(command.router) app.include_router(command.router)
app.state.internal_comm_socket = MagicMock() # mock ZMQ socket
return app return app
@@ -26,32 +25,30 @@ def client(app):
return TestClient(app) return TestClient(app)
def test_receive_command_endpoint(client, app): def test_receive_command_success(client):
""" """
Test that a POST to /command sends the right multipart message Test for successful reception of a command. Ensures the status code is 202 and the response body
and returns a 202 with the expected JSON body. is correct. It also verifies that the ZeroMQ socket's send_multipart method is called with the
expected data.
""" """
mock_socket = app.state.internal_comm_socket # Arrange
mock_pub_socket = AsyncMock()
client.app.state.endpoints_pub_socket = mock_pub_socket
# Prepare test payload that matches SpeechCommand command_data = {"endpoint": "actuate/speech", "data": "This is a test"}
payload = {"endpoint": "actuate/speech", "data": "yooo"} speech_command = SpeechCommand(**command_data)
# Send POST request # Act
response = client.post("/command", json=payload) response = client.post("/command", json=command_data)
# Check response # Assert
assert response.status_code == 202 assert response.status_code == 202
assert response.json() == {"status": "Command received"} assert response.json() == {"status": "Command received"}
# Verify that the socket was called with the correct data # Verify that the ZMQ socket was used correctly
assert mock_socket.send_multipart.called, "Socket should be used to send data" mock_pub_socket.send_multipart.assert_awaited_once_with(
[b"command", speech_command.model_dump_json().encode()]
args, kwargs = mock_socket.send_multipart.call_args )
sent_data = args[0]
assert sent_data[0] == b"command"
# Check JSON encoding roughly matches
assert isinstance(SpeechCommand.model_validate_json(sent_data[1].decode()), SpeechCommand)
def test_receive_command_invalid_payload(client): def test_receive_command_invalid_payload(client):