diff --git a/.githooks/check-branch-name.sh b/.githooks/check-branch-name.sh index 752e199..6a6669a 100755 --- a/.githooks/check-branch-name.sh +++ b/.githooks/check-branch-name.sh @@ -1,4 +1,4 @@ -#!/bin/bash +#!/usr/bin/env bash # This script checks if the current branch name follows the specified format. # It's designed to be used as a 'pre-commit' git hook. @@ -10,7 +10,7 @@ # An array of allowed commit types ALLOWED_TYPES=(feat fix refactor perf style test docs build chore revert) # An array of branches to ignore -IGNORED_BRANCHES=(main dev) +IGNORED_BRANCHES=(main dev demo) # --- Colors for Output --- RED='\033[0;31m' diff --git a/.githooks/check-commit-msg.sh b/.githooks/check-commit-msg.sh index 82bd441..eacf2a8 100755 --- a/.githooks/check-commit-msg.sh +++ b/.githooks/check-commit-msg.sh @@ -1,4 +1,4 @@ -#!/bin/bash +#!/usr/bin/env bash # This script checks if a commit message follows the specified format. # It's designed to be used as a 'commit-msg' git hook. @@ -23,6 +23,44 @@ NC='\033[0m' # No Color # The first argument to the hook is the path to the file containing the commit message COMMIT_MSG_FILE=$1 +# --- Automated Commit Detection --- + +# Read the first line (header) for initial checks +HEADER=$(head -n 1 "$COMMIT_MSG_FILE") + +# Check for Merge commits (covers 'git merge' and PR merges from GitHub/GitLab) +# Examples: "Merge branch 'main' into ...", "Merge pull request #123 from ..." +MERGE_PATTERN="^Merge (branch|pull request|tag) .*" +if [[ "$HEADER" =~ $MERGE_PATTERN ]]; then + echo -e "${GREEN}Merge commit detected by message content. Skipping validation.${NC}" + exit 0 +fi + +# Check for Revert commits +# Example: "Revert "feat: add new feature"" +REVERT_PATTERN="^Revert \".*\"" +if [[ "$HEADER" =~ $REVERT_PATTERN ]]; then + echo -e "${GREEN}Revert commit detected by message content. Skipping validation.${NC}" + exit 0 +fi + +# Check for Cherry-pick commits (this pattern appears at the end of the message) +# Example: "(cherry picked from commit deadbeef...)" +# We use grep -q to search the whole file quietly. +CHERRY_PICK_PATTERN="\(cherry picked from commit [a-f0-9]{7,40}\)" +if grep -qE "$CHERRY_PICK_PATTERN" "$COMMIT_MSG_FILE"; then + echo -e "${GREEN}Cherry-pick detected by message content. Skipping validation.${NC}" + exit 0 +fi + +# Check for Squash +# Example: "Squash commits ..." +SQUASH_PATTERN="^Squash .+" +if [[ "$HEADER" =~ $SQUASH_PATTERN ]]; then + echo -e "${GREEN}Squash commit detected by message content. Skipping validation.${NC}" + exit 0 +fi + # --- Validation Functions --- # Function to print an error message and exit @@ -56,20 +94,24 @@ if ! [[ "$HEADER" =~ $HEADER_REGEX ]]; then error_exit "Invalid header format.\n\nHeader must be in the format: : \nAllowed types: ${ALLOWED_TYPES[*]}\nExample: feat: add new user authentication feature" fi -# 3. Validate the footer (last line) of the commit message -FOOTER=$(tail -n 1 "$COMMIT_MSG_FILE") +# Only validate footer if commit type is not chore +TYPE=$(echo "$HEADER" | cut -d':' -f1) +if [ "$TYPE" != "chore" ]; then + # 3. Validate the footer (last line) of the commit message + FOOTER=$(tail -n 1 "$COMMIT_MSG_FILE") -# Regex breakdown: -# ^(ref|close) - Starts with 'ref' or 'close' -# : - Followed by a literal colon -# \s - Followed by a single space -# N25B- - Followed by the literal string 'N25B-' -# [0-9]+ - Followed by one or more digits -# $ - End of the line -FOOTER_REGEX="^(ref|close): N25B-[0-9]+$" + # Regex breakdown: + # ^(ref|close) - Starts with 'ref' or 'close' + # : - Followed by a literal colon + # \s - Followed by a single space + # N25B- - Followed by the literal string 'N25B-' + # [0-9]+ - Followed by one or more digits + # $ - End of the line + FOOTER_REGEX="^(ref|close): N25B-[0-9]+$" -if ! [[ "$FOOTER" =~ $FOOTER_REGEX ]]; then - error_exit "Invalid footer format.\n\nFooter must be in the format: [ref/close]: \nExample: ref: N25B-123" + if ! [[ "$FOOTER" =~ $FOOTER_REGEX ]]; then + error_exit "Invalid footer format.\n\nFooter must be in the format: [ref/close]: \nExample: ref: N25B-123" + fi fi # 4. If the message has more than 2 lines, validate the separator diff --git a/src/control_backend/agents/ri_command_agent.py b/src/control_backend/agents/ri_command_agent.py index 9c5d2bc..dac41f3 100644 --- a/src/control_backend/agents/ri_command_agent.py +++ b/src/control_backend/agents/ri_command_agent.py @@ -5,9 +5,9 @@ import spade.agent import zmq from spade.agent import Agent from spade.behaviour import CyclicBehaviour +from zmq.asyncio import Context from control_backend.core.config import settings -from control_backend.core.zmq_context import context from control_backend.schemas.ri_message import SpeechCommand logger = logging.getLogger(__name__) @@ -71,6 +71,8 @@ class RICommandAgent(Agent): """ logger.info("Setting up %s", self.jid) + context = Context.instance() + # To the robot self.pubsocket = context.socket(zmq.PUB) if self.bind: @@ -80,7 +82,7 @@ class RICommandAgent(Agent): # Receive internal topics regarding commands self.subsocket = context.socket(zmq.SUB) - self.subsocket.connect(settings.zmq_settings.internal_comm_address) + self.subsocket.connect(settings.zmq_settings.internal_sub_address) self.subsocket.setsockopt(zmq.SUBSCRIBE, b"command") # Add behaviour to our agent diff --git a/src/control_backend/agents/ri_communication_agent.py b/src/control_backend/agents/ri_communication_agent.py index 8d56b09..638b967 100644 --- a/src/control_backend/agents/ri_communication_agent.py +++ b/src/control_backend/agents/ri_communication_agent.py @@ -4,10 +4,10 @@ import logging import zmq from spade.agent import Agent from spade.behaviour import CyclicBehaviour +from zmq.asyncio import Context from control_backend.agents.ri_command_agent import RICommandAgent from control_backend.core.config import settings -from control_backend.core.zmq_context import context logger = logging.getLogger(__name__) @@ -74,7 +74,7 @@ class RICommunicationAgent(Agent): # Let's try a certain amount of times before failing connection while retries < max_retries: # Bind request socket - self.req_socket = context.socket(zmq.REQ) + self.req_socket = Context.instance().socket(zmq.REQ) if self._bind: self.req_socket.bind(self._address) else: diff --git a/src/control_backend/agents/transcription/transcription_agent.py b/src/control_backend/agents/transcription/transcription_agent.py index 196fd28..25103a4 100644 --- a/src/control_backend/agents/transcription/transcription_agent.py +++ b/src/control_backend/agents/transcription/transcription_agent.py @@ -10,7 +10,6 @@ from spade.message import Message from control_backend.agents.transcription.speech_recognizer import SpeechRecognizer from control_backend.core.config import settings -from control_backend.core.zmq_context import context as zmq_context logger = logging.getLogger(__name__) @@ -73,7 +72,7 @@ class TranscriptionAgent(Agent): return await super().stop() def _connect_audio_in_socket(self): - self.audio_in_socket = zmq_context.socket(zmq.SUB) + self.audio_in_socket = azmq.Context.instance().socket(zmq.SUB) self.audio_in_socket.setsockopt_string(zmq.SUBSCRIBE, "") self.audio_in_socket.connect(self.audio_in_address) diff --git a/src/control_backend/agents/vad_agent.py b/src/control_backend/agents/vad_agent.py index 5b7f598..9cf2adf 100644 --- a/src/control_backend/agents/vad_agent.py +++ b/src/control_backend/agents/vad_agent.py @@ -9,7 +9,6 @@ from spade.behaviour import CyclicBehaviour from control_backend.agents.transcription import TranscriptionAgent from control_backend.core.config import settings -from control_backend.core.zmq_context import context as zmq_context logger = logging.getLogger(__name__) @@ -66,7 +65,8 @@ class Streaming(CyclicBehaviour): self._ready = True async def run(self) -> None: - if not self._ready: return + if not self._ready: + return data = await self.audio_in_poller.poll() if data is None: @@ -134,7 +134,7 @@ class VADAgent(Agent): return await super().stop() def _connect_audio_in_socket(self): - self.audio_in_socket = zmq_context.socket(zmq.SUB) + self.audio_in_socket = azmq.Context.instance().socket(zmq.SUB) self.audio_in_socket.setsockopt_string(zmq.SUBSCRIBE, "") if self.audio_in_bind: self.audio_in_socket.bind(self.audio_in_address) @@ -145,7 +145,7 @@ class VADAgent(Agent): def _connect_audio_out_socket(self) -> int | None: """Returns the port bound, or None if binding failed.""" try: - self.audio_out_socket = zmq_context.socket(zmq.PUB) + self.audio_out_socket = azmq.Context.instance().socket(zmq.PUB) return self.audio_out_socket.bind_to_random_port("tcp://*", max_tries=100) except zmq.ZMQBindError: logger.error("Failed to bind an audio output socket after 100 tries.") diff --git a/src/control_backend/api/v1/endpoints/command.py b/src/control_backend/api/v1/endpoints/command.py index e19290f..d7f963b 100644 --- a/src/control_backend/api/v1/endpoints/command.py +++ b/src/control_backend/api/v1/endpoints/command.py @@ -1,7 +1,6 @@ import logging from fastapi import APIRouter, Request -from zmq import Socket from control_backend.schemas.ri_message import SpeechCommand @@ -15,7 +14,7 @@ async def receive_command(command: SpeechCommand, request: Request): # Validate and retrieve data. SpeechCommand.model_validate(command) topic = b"command" - pub_socket: Socket = request.app.state.internal_comm_socket - pub_socket.send_multipart([topic, command.model_dump_json().encode()]) + pub_socket = request.app.state.endpoints_pub_socket + await pub_socket.send_multipart([topic, command.model_dump_json().encode()]) return {"status": "Command received"} diff --git a/src/control_backend/api/v1/endpoints/message.py b/src/control_backend/api/v1/endpoints/message.py index 1053c3c..bd88a0b 100644 --- a/src/control_backend/api/v1/endpoints/message.py +++ b/src/control_backend/api/v1/endpoints/message.py @@ -1,7 +1,6 @@ import logging from fastapi import APIRouter, Request -from zmq import Socket from control_backend.schemas.message import Message @@ -17,8 +16,7 @@ async def receive_message(message: Message, request: Request): topic = b"message" body = message.model_dump_json().encode("utf-8") - pub_socket: Socket = request.app.state.internal_comm_socket - - pub_socket.send_multipart([topic, body]) + pub_socket = request.app.state.endpoints_pub_socket + await pub_socket.send_multipart([topic, body]) return {"status": "Message received"} diff --git a/src/control_backend/core/config.py b/src/control_backend/core/config.py index 2fd16b8..8de2403 100644 --- a/src/control_backend/core/config.py +++ b/src/control_backend/core/config.py @@ -3,7 +3,8 @@ from pydantic_settings import BaseSettings, SettingsConfigDict class ZMQSettings(BaseModel): - internal_comm_address: str = "tcp://localhost:5560" + internal_pub_address: str = "tcp://localhost:5560" + internal_sub_address: str = "tcp://localhost:5561" class AgentSettings(BaseModel): diff --git a/src/control_backend/core/zmq_context.py b/src/control_backend/core/zmq_context.py deleted file mode 100644 index a74544f..0000000 --- a/src/control_backend/core/zmq_context.py +++ /dev/null @@ -1,3 +0,0 @@ -from zmq.asyncio import Context - -context = Context() diff --git a/src/control_backend/main.py b/src/control_backend/main.py index 4fddc1e..043eefd 100644 --- a/src/control_backend/main.py +++ b/src/control_backend/main.py @@ -3,37 +3,59 @@ # External imports import contextlib import logging +import threading import zmq from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware +from zmq.asyncio import Context from control_backend.agents.bdi.bdi_core import BDICoreAgent from control_backend.agents.bdi.text_extractor import TBeliefExtractor from control_backend.agents.belief_collector.belief_collector import BeliefCollectorAgent from control_backend.agents.llm.llm import LLMAgent - -# Internal imports from control_backend.agents.ri_communication_agent import RICommunicationAgent from control_backend.agents.vad_agent import VADAgent from control_backend.api.v1.router import api_router from control_backend.core.config import settings -from control_backend.core.zmq_context import context logger = logging.getLogger(__name__) logging.basicConfig(level=logging.DEBUG) +def setup_sockets(): + context = Context.instance() + + internal_pub_socket = context.socket(zmq.XPUB) + internal_pub_socket.bind(settings.zmq_settings.internal_sub_address) + logger.debug("Internal publishing socket bound to %s", internal_pub_socket) + + internal_sub_socket = context.socket(zmq.XSUB) + internal_sub_socket.bind(settings.zmq_settings.internal_pub_address) + logger.debug("Internal subscribing socket bound to %s", internal_sub_socket) + try: + zmq.proxy(internal_sub_socket, internal_pub_socket) + except zmq.ZMQError: + logger.warning("Error while handling PUB/SUB proxy. Closing sockets.") + finally: + internal_pub_socket.close() + internal_sub_socket.close() + + @contextlib.asynccontextmanager async def lifespan(app: FastAPI): logger.info("%s starting up.", app.title) # Initiate sockets - internal_comm_socket = context.socket(zmq.PUB) - internal_comm_address = settings.zmq_settings.internal_comm_address - internal_comm_socket.bind(internal_comm_address) - app.state.internal_comm_socket = internal_comm_socket - logger.info("Internal publishing socket bound to %s", internal_comm_socket) + proxy_thread = threading.Thread(target=setup_sockets) + proxy_thread.daemon = True + proxy_thread.start() + + context = Context.instance() + + endpoints_pub_socket = context.socket(zmq.PUB) + endpoints_pub_socket.connect(settings.zmq_settings.internal_pub_address) + app.state.endpoints_pub_socket = endpoints_pub_socket # Initiate agents ri_communication_agent = RICommunicationAgent( diff --git a/test/integration/agents/test_ri_commands_agent.py b/test/integration/agents/test_ri_commands_agent.py index 4249401..00edcb1 100644 --- a/test/integration/agents/test_ri_commands_agent.py +++ b/test/integration/agents/test_ri_commands_agent.py @@ -7,19 +7,21 @@ import zmq from control_backend.agents.ri_command_agent import RICommandAgent +@pytest.fixture +def zmq_context(mocker): + mock_context = mocker.patch("control_backend.agents.vad_agent.azmq.Context.instance") + mock_context.return_value = MagicMock() + return mock_context + + @pytest.mark.asyncio -async def test_setup_bind(monkeypatch): +async def test_setup_bind(zmq_context, mocker): """Test setup with bind=True""" - fake_socket = MagicMock() - monkeypatch.setattr( - "control_backend.agents.ri_command_agent.context.socket", lambda _: fake_socket - ) + fake_socket = zmq_context.return_value.socket.return_value agent = RICommandAgent("test@server", "password", address="tcp://localhost:5555", bind=True) - monkeypatch.setattr( - "control_backend.agents.ri_command_agent.settings", - MagicMock(zmq_settings=MagicMock(internal_comm_address="tcp://internal:1234")), - ) + settings = mocker.patch("control_backend.agents.ri_command_agent.settings") + settings.zmq_settings.internal_sub_address = "tcp://internal:1234" await agent.setup() @@ -34,18 +36,13 @@ async def test_setup_bind(monkeypatch): @pytest.mark.asyncio -async def test_setup_connect(monkeypatch): +async def test_setup_connect(zmq_context, mocker): """Test setup with bind=False""" - fake_socket = MagicMock() - monkeypatch.setattr( - "control_backend.agents.ri_command_agent.context.socket", lambda _: fake_socket - ) + fake_socket = zmq_context.return_value.socket.return_value agent = RICommandAgent("test@server", "password", address="tcp://localhost:5555", bind=False) - monkeypatch.setattr( - "control_backend.agents.ri_command_agent.settings", - MagicMock(zmq_settings=MagicMock(internal_comm_address="tcp://internal:1234")), - ) + settings = mocker.patch("control_backend.agents.ri_command_agent.settings") + settings.zmq_settings.internal_sub_address = "tcp://internal:1234" await agent.setup() diff --git a/test/integration/agents/test_ri_communication_agent.py b/test/integration/agents/test_ri_communication_agent.py index fd555e1..443d609 100644 --- a/test/integration/agents/test_ri_communication_agent.py +++ b/test/integration/agents/test_ri_communication_agent.py @@ -84,21 +84,23 @@ def fake_json_invalid_id_negototiate(): ) +@pytest.fixture +def zmq_context(mocker): + mock_context = mocker.patch("control_backend.agents.vad_agent.azmq.Context.instance") + mock_context.return_value = MagicMock() + return mock_context + + @pytest.mark.asyncio -async def test_setup_creates_socket_and_negotiate_1(monkeypatch): +async def test_setup_creates_socket_and_negotiate_1(zmq_context): """ Test the setup of the communication agent """ # --- Arrange --- - fake_socket = MagicMock() + fake_socket = zmq_context.return_value.socket.return_value fake_socket.send_json = AsyncMock() fake_socket.recv_json = fake_json_correct_negototiate_1() - # Mock context.socket to return our fake socket - monkeypatch.setattr( - "control_backend.agents.ri_communication_agent.context.socket", lambda _: fake_socket - ) - # Mock RICommandAgent agent startup with patch( "control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True @@ -128,20 +130,15 @@ async def test_setup_creates_socket_and_negotiate_1(monkeypatch): @pytest.mark.asyncio -async def test_setup_creates_socket_and_negotiate_2(monkeypatch): +async def test_setup_creates_socket_and_negotiate_2(zmq_context): """ Test the setup of the communication agent """ # --- Arrange --- - fake_socket = MagicMock() + fake_socket = zmq_context.return_value.socket.return_value fake_socket.send_json = AsyncMock() fake_socket.recv_json = fake_json_correct_negototiate_2() - # Mock context.socket to return our fake socket - monkeypatch.setattr( - "control_backend.agents.ri_communication_agent.context.socket", lambda _: fake_socket - ) - # Mock RICommandAgent agent startup with patch( "control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True @@ -171,20 +168,15 @@ async def test_setup_creates_socket_and_negotiate_2(monkeypatch): @pytest.mark.asyncio -async def test_setup_creates_socket_and_negotiate_3(monkeypatch, caplog): +async def test_setup_creates_socket_and_negotiate_3(zmq_context, caplog): """ Test the functionality of setup with incorrect negotiation message """ # --- Arrange --- - fake_socket = MagicMock() + fake_socket = zmq_context.return_value.socket.return_value fake_socket.send_json = AsyncMock() fake_socket.recv_json = fake_json_wrong_negototiate_1() - # Mock context.socket to return our fake socket - monkeypatch.setattr( - "control_backend.agents.ri_communication_agent.context.socket", lambda _: fake_socket - ) - # Mock RICommandAgent agent startup # We are sending wrong negotiation info to the communication agent, @@ -215,20 +207,15 @@ async def test_setup_creates_socket_and_negotiate_3(monkeypatch, caplog): @pytest.mark.asyncio -async def test_setup_creates_socket_and_negotiate_4(monkeypatch): +async def test_setup_creates_socket_and_negotiate_4(zmq_context): """ Test the setup of the communication agent with different bind value """ # --- Arrange --- - fake_socket = MagicMock() + fake_socket = zmq_context.return_value.socket.return_value fake_socket.send_json = AsyncMock() fake_socket.recv_json = fake_json_correct_negototiate_3() - # Mock context.socket to return our fake socket - monkeypatch.setattr( - "control_backend.agents.ri_communication_agent.context.socket", lambda _: fake_socket - ) - # Mock RICommandAgent agent startup with patch( "control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True @@ -258,20 +245,15 @@ async def test_setup_creates_socket_and_negotiate_4(monkeypatch): @pytest.mark.asyncio -async def test_setup_creates_socket_and_negotiate_5(monkeypatch): +async def test_setup_creates_socket_and_negotiate_5(zmq_context): """ Test the setup of the communication agent """ # --- Arrange --- - fake_socket = MagicMock() + fake_socket = zmq_context.return_value.socket.return_value fake_socket.send_json = AsyncMock() fake_socket.recv_json = fake_json_correct_negototiate_4() - # Mock context.socket to return our fake socket - monkeypatch.setattr( - "control_backend.agents.ri_communication_agent.context.socket", lambda _: fake_socket - ) - # Mock RICommandAgent agent startup with patch( "control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True @@ -301,20 +283,15 @@ async def test_setup_creates_socket_and_negotiate_5(monkeypatch): @pytest.mark.asyncio -async def test_setup_creates_socket_and_negotiate_6(monkeypatch): +async def test_setup_creates_socket_and_negotiate_6(zmq_context): """ Test the setup of the communication agent """ # --- Arrange --- - fake_socket = MagicMock() + fake_socket = zmq_context.return_value.socket.return_value fake_socket.send_json = AsyncMock() fake_socket.recv_json = fake_json_correct_negototiate_5() - # Mock context.socket to return our fake socket - monkeypatch.setattr( - "control_backend.agents.ri_communication_agent.context.socket", lambda _: fake_socket - ) - # Mock RICommandAgent agent startup with patch( "control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True @@ -344,20 +321,15 @@ async def test_setup_creates_socket_and_negotiate_6(monkeypatch): @pytest.mark.asyncio -async def test_setup_creates_socket_and_negotiate_7(monkeypatch, caplog): +async def test_setup_creates_socket_and_negotiate_7(zmq_context, caplog): """ Test the functionality of setup with incorrect id """ # --- Arrange --- - fake_socket = MagicMock() + fake_socket = zmq_context.return_value.socket.return_value fake_socket.send_json = AsyncMock() fake_socket.recv_json = fake_json_invalid_id_negototiate() - # Mock context.socket to return our fake socket - monkeypatch.setattr( - "control_backend.agents.ri_communication_agent.context.socket", lambda _: fake_socket - ) - # Mock RICommandAgent agent startup # We are sending wrong negotiation info to the communication agent, @@ -385,20 +357,15 @@ async def test_setup_creates_socket_and_negotiate_7(monkeypatch, caplog): @pytest.mark.asyncio -async def test_setup_creates_socket_and_negotiate_timeout(monkeypatch, caplog): +async def test_setup_creates_socket_and_negotiate_timeout(zmq_context, caplog): """ Test the functionality of setup with incorrect negotiation message """ # --- Arrange --- - fake_socket = MagicMock() + fake_socket = zmq_context.return_value.socket.return_value fake_socket.send_json = AsyncMock() fake_socket.recv_json = AsyncMock(side_effect=asyncio.TimeoutError) - # Mock context.socket to return our fake socket - monkeypatch.setattr( - "control_backend.agents.ri_communication_agent.context.socket", lambda _: fake_socket - ) - with patch( "control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True ) as MockCommandAgent: @@ -480,8 +447,8 @@ async def test_listen_behaviour_ping_wrong_endpoint(caplog): @pytest.mark.asyncio -async def test_listen_behaviour_timeout(caplog): - fake_socket = AsyncMock() +async def test_listen_behaviour_timeout(zmq_context, caplog): + fake_socket = zmq_context.return_value.socket.return_value fake_socket.send_json = AsyncMock() # recv_json will never resolve, simulate timeout fake_socket.recv_json = AsyncMock(side_effect=asyncio.TimeoutError) @@ -529,16 +496,12 @@ async def test_listen_behaviour_ping_no_endpoint(caplog): @pytest.mark.asyncio -async def test_setup_unexpected_exception(monkeypatch, caplog): - fake_socket = MagicMock() +async def test_setup_unexpected_exception(zmq_context, caplog): + fake_socket = zmq_context.return_value.socket.return_value fake_socket.send_json = AsyncMock() # Simulate unexpected exception during recv_json() fake_socket.recv_json = AsyncMock(side_effect=Exception("boom!")) - monkeypatch.setattr( - "control_backend.agents.ri_communication_agent.context.socket", lambda _: fake_socket - ) - agent = RICommunicationAgent( "test@server", "password", address="tcp://localhost:5555", bind=False ) @@ -551,9 +514,9 @@ async def test_setup_unexpected_exception(monkeypatch, caplog): @pytest.mark.asyncio -async def test_setup_unpacking_exception(monkeypatch, caplog): +async def test_setup_unpacking_exception(zmq_context, caplog): # --- Arrange --- - fake_socket = MagicMock() + fake_socket = zmq_context.return_value.socket.return_value fake_socket.send_json = AsyncMock() # Make recv_json return malformed negotiation data to trigger unpacking exception @@ -563,11 +526,6 @@ async def test_setup_unpacking_exception(monkeypatch, caplog): } # missing 'port' and 'bind' fake_socket.recv_json = AsyncMock(return_value=malformed_data) - # Patch context.socket - monkeypatch.setattr( - "control_backend.agents.ri_communication_agent.context.socket", lambda _: fake_socket - ) - # Patch RICommandAgent so it won't actually start with patch( "control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True diff --git a/test/integration/agents/vad_agent/test_vad_agent.py b/test/integration/agents/vad_agent/test_vad_agent.py index 54c9d82..0e1fae2 100644 --- a/test/integration/agents/vad_agent/test_vad_agent.py +++ b/test/integration/agents/vad_agent/test_vad_agent.py @@ -10,7 +10,9 @@ from control_backend.agents.vad_agent import VADAgent @pytest.fixture def zmq_context(mocker): - return mocker.patch("control_backend.agents.vad_agent.zmq_context") + mock_context = mocker.patch("control_backend.agents.vad_agent.azmq.Context.instance") + mock_context.return_value = MagicMock() + return mock_context @pytest.fixture @@ -54,13 +56,18 @@ def test_in_socket_creation(zmq_context, do_bind: bool): assert vad_agent.audio_in_socket is not None - zmq_context.socket.assert_called_once_with(zmq.SUB) - zmq_context.socket.return_value.setsockopt_string.assert_called_once_with(zmq.SUBSCRIBE, "") + zmq_context.return_value.socket.assert_called_once_with(zmq.SUB) + zmq_context.return_value.socket.return_value.setsockopt_string.assert_called_once_with( + zmq.SUBSCRIBE, + "", + ) if do_bind: - zmq_context.socket.return_value.bind.assert_called_once_with("tcp://*:12345") + zmq_context.return_value.socket.return_value.bind.assert_called_once_with("tcp://*:12345") else: - zmq_context.socket.return_value.connect.assert_called_once_with("tcp://localhost:12345") + zmq_context.return_value.socket.return_value.connect.assert_called_once_with( + "tcp://localhost:12345" + ) def test_out_socket_creation(zmq_context): @@ -73,8 +80,8 @@ def test_out_socket_creation(zmq_context): assert vad_agent.audio_out_socket is not None - zmq_context.socket.assert_called_once_with(zmq.PUB) - zmq_context.socket.return_value.bind_to_random_port.assert_called_once() + zmq_context.return_value.socket.assert_called_once_with(zmq.PUB) + zmq_context.return_value.socket.return_value.bind_to_random_port.assert_called_once() @pytest.mark.asyncio @@ -83,7 +90,9 @@ async def test_out_socket_creation_failure(zmq_context): Test setup failure when the audio output socket cannot be created. """ with patch.object(Agent, "stop", new_callable=AsyncMock) as mock_super_stop: - zmq_context.socket.return_value.bind_to_random_port.side_effect = zmq.ZMQBindError + zmq_context.return_value.socket.return_value.bind_to_random_port.side_effect = ( + zmq.ZMQBindError + ) vad_agent = VADAgent("tcp://localhost:12345", False) await vad_agent.setup() @@ -98,11 +107,14 @@ async def test_stop(zmq_context, transcription_agent): Test that when the VAD agent is stopped, the sockets are closed correctly. """ vad_agent = VADAgent("tcp://localhost:12345", False) - zmq_context.socket.return_value.bind_to_random_port.return_value = random.randint(1000, 10000) + zmq_context.return_value.socket.return_value.bind_to_random_port.return_value = random.randint( + 1000, + 10000, + ) await vad_agent.setup() await vad_agent.stop() - assert zmq_context.socket.return_value.close.call_count == 2 + assert zmq_context.return_value.socket.return_value.close.call_count == 2 assert vad_agent.audio_in_socket is None assert vad_agent.audio_out_socket is None diff --git a/test/integration/api/endpoints/test_command_endpoint.py b/test/integration/api/endpoints/test_command_endpoint.py index 04890c1..1c9213a 100644 --- a/test/integration/api/endpoints/test_command_endpoint.py +++ b/test/integration/api/endpoints/test_command_endpoint.py @@ -1,4 +1,4 @@ -from unittest.mock import MagicMock +from unittest.mock import AsyncMock import pytest from fastapi import FastAPI @@ -16,7 +16,6 @@ def app(): """ app = FastAPI() app.include_router(command.router) - app.state.internal_comm_socket = MagicMock() # mock ZMQ socket return app @@ -26,32 +25,30 @@ def client(app): return TestClient(app) -def test_receive_command_endpoint(client, app): +def test_receive_command_success(client): """ - Test that a POST to /command sends the right multipart message - and returns a 202 with the expected JSON body. + Test for successful reception of a command. Ensures the status code is 202 and the response body + is correct. It also verifies that the ZeroMQ socket's send_multipart method is called with the + expected data. """ - mock_socket = app.state.internal_comm_socket + # Arrange + mock_pub_socket = AsyncMock() + client.app.state.endpoints_pub_socket = mock_pub_socket - # Prepare test payload that matches SpeechCommand - payload = {"endpoint": "actuate/speech", "data": "yooo"} + command_data = {"endpoint": "actuate/speech", "data": "This is a test"} + speech_command = SpeechCommand(**command_data) - # Send POST request - response = client.post("/command", json=payload) + # Act + response = client.post("/command", json=command_data) - # Check response + # Assert assert response.status_code == 202 assert response.json() == {"status": "Command received"} - # Verify that the socket was called with the correct data - assert mock_socket.send_multipart.called, "Socket should be used to send data" - - args, kwargs = mock_socket.send_multipart.call_args - sent_data = args[0] - - assert sent_data[0] == b"command" - # Check JSON encoding roughly matches - assert isinstance(SpeechCommand.model_validate_json(sent_data[1].decode()), SpeechCommand) + # Verify that the ZMQ socket was used correctly + mock_pub_socket.send_multipart.assert_awaited_once_with( + [b"command", speech_command.model_dump_json().encode()] + ) def test_receive_command_invalid_payload(client):