diff --git a/src/control_backend/agents/ri_command_agent.py b/src/control_backend/agents/ri_command_agent.py index 22ec751..c35bf16 100644 --- a/src/control_backend/agents/ri_command_agent.py +++ b/src/control_backend/agents/ri_command_agent.py @@ -2,10 +2,10 @@ import json import zmq from spade.behaviour import CyclicBehaviour +from zmq.asyncio import Context from control_backend.agents import BaseAgent from control_backend.core.config import settings -from control_backend.core.zmq_context import context from control_backend.schemas.ri_message import SpeechCommand @@ -53,6 +53,8 @@ class RICommandAgent(BaseAgent): """ self.logger.info("Setting up %s", self.jid) + context = Context.instance() + # To the robot self.pubsocket = context.socket(zmq.PUB) if self.bind: @@ -62,7 +64,7 @@ class RICommandAgent(BaseAgent): # Receive internal topics regarding commands self.subsocket = context.socket(zmq.SUB) - self.subsocket.connect(settings.zmq_settings.internal_comm_address) + self.subsocket.connect(settings.zmq_settings.internal_sub_address) self.subsocket.setsockopt(zmq.SUBSCRIBE, b"command") # Add behaviour to our agent diff --git a/src/control_backend/agents/ri_communication_agent.py b/src/control_backend/agents/ri_communication_agent.py index 4e7680a..76d6431 100644 --- a/src/control_backend/agents/ri_communication_agent.py +++ b/src/control_backend/agents/ri_communication_agent.py @@ -2,10 +2,10 @@ import asyncio import zmq from spade.behaviour import CyclicBehaviour +from zmq.asyncio import Context from control_backend.agents import BaseAgent from control_backend.core.config import settings -from control_backend.core.zmq_context import context from .ri_command_agent import RICommandAgent @@ -72,7 +72,7 @@ class RICommunicationAgent(BaseAgent): # Let's try a certain amount of times before failing connection while retries < max_retries: # Bind request socket - self.req_socket = context.socket(zmq.REQ) + self.req_socket = Context.instance().socket(zmq.REQ) if self._bind: self.req_socket.bind(self._address) else: diff --git a/src/control_backend/agents/transcription/transcription_agent.py b/src/control_backend/agents/transcription/transcription_agent.py index 495f623..4ace513 100644 --- a/src/control_backend/agents/transcription/transcription_agent.py +++ b/src/control_backend/agents/transcription/transcription_agent.py @@ -8,7 +8,6 @@ from spade.message import Message from control_backend.agents import BaseAgent from control_backend.core.config import settings -from control_backend.core.zmq_context import context as zmq_context from .speech_recognizer import SpeechRecognizer @@ -67,7 +66,7 @@ class TranscriptionAgent(BaseAgent): return await super().stop() def _connect_audio_in_socket(self): - self.audio_in_socket = zmq_context.socket(zmq.SUB) + self.audio_in_socket = azmq.Context.instance().socket(zmq.SUB) self.audio_in_socket.setsockopt_string(zmq.SUBSCRIBE, "") self.audio_in_socket.connect(self.audio_in_address) diff --git a/src/control_backend/agents/vad_agent.py b/src/control_backend/agents/vad_agent.py index 099b49a..2b127b8 100644 --- a/src/control_backend/agents/vad_agent.py +++ b/src/control_backend/agents/vad_agent.py @@ -6,7 +6,6 @@ from spade.behaviour import CyclicBehaviour from control_backend.agents import BaseAgent from control_backend.core.config import settings -from control_backend.core.zmq_context import context as zmq_context from .transcription.transcription_agent import TranscriptionAgent @@ -120,7 +119,7 @@ class VADAgent(BaseAgent): return await super().stop() def _connect_audio_in_socket(self): - self.audio_in_socket = zmq_context.socket(zmq.SUB) + self.audio_in_socket = azmq.Context.instance().socket(zmq.SUB) self.audio_in_socket.setsockopt_string(zmq.SUBSCRIBE, "") if self.audio_in_bind: self.audio_in_socket.bind(self.audio_in_address) @@ -131,7 +130,7 @@ class VADAgent(BaseAgent): def _connect_audio_out_socket(self) -> int | None: """Returns the port bound, or None if binding failed.""" try: - self.audio_out_socket = zmq_context.socket(zmq.PUB) + self.audio_out_socket = azmq.Context.instance().socket(zmq.PUB) return self.audio_out_socket.bind_to_random_port("tcp://*", max_tries=100) except zmq.ZMQBindError: self.logger.error("Failed to bind an audio output socket after 100 tries.") diff --git a/src/control_backend/api/v1/endpoints/command.py b/src/control_backend/api/v1/endpoints/command.py index e19290f..d7f963b 100644 --- a/src/control_backend/api/v1/endpoints/command.py +++ b/src/control_backend/api/v1/endpoints/command.py @@ -1,7 +1,6 @@ import logging from fastapi import APIRouter, Request -from zmq import Socket from control_backend.schemas.ri_message import SpeechCommand @@ -15,7 +14,7 @@ async def receive_command(command: SpeechCommand, request: Request): # Validate and retrieve data. SpeechCommand.model_validate(command) topic = b"command" - pub_socket: Socket = request.app.state.internal_comm_socket - pub_socket.send_multipart([topic, command.model_dump_json().encode()]) + pub_socket = request.app.state.endpoints_pub_socket + await pub_socket.send_multipart([topic, command.model_dump_json().encode()]) return {"status": "Command received"} diff --git a/src/control_backend/api/v1/endpoints/message.py b/src/control_backend/api/v1/endpoints/message.py index 1053c3c..bd88a0b 100644 --- a/src/control_backend/api/v1/endpoints/message.py +++ b/src/control_backend/api/v1/endpoints/message.py @@ -1,7 +1,6 @@ import logging from fastapi import APIRouter, Request -from zmq import Socket from control_backend.schemas.message import Message @@ -17,8 +16,7 @@ async def receive_message(message: Message, request: Request): topic = b"message" body = message.model_dump_json().encode("utf-8") - pub_socket: Socket = request.app.state.internal_comm_socket - - pub_socket.send_multipart([topic, body]) + pub_socket = request.app.state.endpoints_pub_socket + await pub_socket.send_multipart([topic, body]) return {"status": "Message received"} diff --git a/src/control_backend/core/config.py b/src/control_backend/core/config.py index 2fd16b8..8de2403 100644 --- a/src/control_backend/core/config.py +++ b/src/control_backend/core/config.py @@ -3,7 +3,8 @@ from pydantic_settings import BaseSettings, SettingsConfigDict class ZMQSettings(BaseModel): - internal_comm_address: str = "tcp://localhost:5560" + internal_pub_address: str = "tcp://localhost:5560" + internal_sub_address: str = "tcp://localhost:5561" class AgentSettings(BaseModel): diff --git a/src/control_backend/core/zmq_context.py b/src/control_backend/core/zmq_context.py deleted file mode 100644 index a74544f..0000000 --- a/src/control_backend/core/zmq_context.py +++ /dev/null @@ -1,3 +0,0 @@ -from zmq.asyncio import Context - -context = Context() diff --git a/src/control_backend/main.py b/src/control_backend/main.py index 2df89e2..1fbf4fa 100644 --- a/src/control_backend/main.py +++ b/src/control_backend/main.py @@ -1,9 +1,11 @@ import contextlib import logging +import threading import zmq from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware +from zmq.asyncio import Context from control_backend.agents import ( BeliefCollectorAgent, @@ -14,12 +16,30 @@ from control_backend.agents import ( from control_backend.agents.bdi import BDICoreAgent, TBeliefExtractorAgent from control_backend.api.v1.router import api_router from control_backend.core.config import settings -from control_backend.core.zmq_context import context from control_backend.logging import setup_logging logger = logging.getLogger(__name__) +def setup_sockets(): + context = Context.instance() + + internal_pub_socket = context.socket(zmq.XPUB) + internal_pub_socket.bind(settings.zmq_settings.internal_sub_address) + logger.debug("Internal publishing socket bound to %s", internal_pub_socket) + + internal_sub_socket = context.socket(zmq.XSUB) + internal_sub_socket.bind(settings.zmq_settings.internal_pub_address) + logger.debug("Internal subscribing socket bound to %s", internal_sub_socket) + try: + zmq.proxy(internal_sub_socket, internal_pub_socket) + except zmq.ZMQError: + logger.warning("Error while handling PUB/SUB proxy. Closing sockets.") + finally: + internal_pub_socket.close() + internal_sub_socket.close() + + @contextlib.asynccontextmanager async def lifespan(app: FastAPI): """ @@ -29,18 +49,16 @@ async def lifespan(app: FastAPI): setup_logging() logger.info("%s is starting up.", app.title) - # --- Initialize Sockets --- - logger.info("Initializing ZeroMQ sockets.") - try: - internal_comm_socket = context.socket(zmq.PUB) - internal_comm_address = settings.zmq_settings.internal_comm_address - logger.debug("Binding internal PUB socket to address: %s", internal_comm_address) - internal_comm_socket.bind(internal_comm_address) - app.state.internal_comm_socket = internal_comm_socket - logger.info("Internal communication socket bound successfully.") - except Exception as e: - logger.error("Failed to bind internal communication socket: %s", e, exc_info=True) - raise + # Initiate sockets + proxy_thread = threading.Thread(target=setup_sockets) + proxy_thread.daemon = True + proxy_thread.start() + + context = Context.instance() + + endpoints_pub_socket = context.socket(zmq.PUB) + endpoints_pub_socket.connect(settings.zmq_settings.internal_pub_address) + app.state.endpoints_pub_socket = endpoints_pub_socket # --- Initialize Agents --- logger.info("Initializing and starting agents.") diff --git a/test/integration/agents/test_ri_commands_agent.py b/test/integration/agents/test_ri_commands_agent.py index 4249401..00edcb1 100644 --- a/test/integration/agents/test_ri_commands_agent.py +++ b/test/integration/agents/test_ri_commands_agent.py @@ -7,19 +7,21 @@ import zmq from control_backend.agents.ri_command_agent import RICommandAgent +@pytest.fixture +def zmq_context(mocker): + mock_context = mocker.patch("control_backend.agents.vad_agent.azmq.Context.instance") + mock_context.return_value = MagicMock() + return mock_context + + @pytest.mark.asyncio -async def test_setup_bind(monkeypatch): +async def test_setup_bind(zmq_context, mocker): """Test setup with bind=True""" - fake_socket = MagicMock() - monkeypatch.setattr( - "control_backend.agents.ri_command_agent.context.socket", lambda _: fake_socket - ) + fake_socket = zmq_context.return_value.socket.return_value agent = RICommandAgent("test@server", "password", address="tcp://localhost:5555", bind=True) - monkeypatch.setattr( - "control_backend.agents.ri_command_agent.settings", - MagicMock(zmq_settings=MagicMock(internal_comm_address="tcp://internal:1234")), - ) + settings = mocker.patch("control_backend.agents.ri_command_agent.settings") + settings.zmq_settings.internal_sub_address = "tcp://internal:1234" await agent.setup() @@ -34,18 +36,13 @@ async def test_setup_bind(monkeypatch): @pytest.mark.asyncio -async def test_setup_connect(monkeypatch): +async def test_setup_connect(zmq_context, mocker): """Test setup with bind=False""" - fake_socket = MagicMock() - monkeypatch.setattr( - "control_backend.agents.ri_command_agent.context.socket", lambda _: fake_socket - ) + fake_socket = zmq_context.return_value.socket.return_value agent = RICommandAgent("test@server", "password", address="tcp://localhost:5555", bind=False) - monkeypatch.setattr( - "control_backend.agents.ri_command_agent.settings", - MagicMock(zmq_settings=MagicMock(internal_comm_address="tcp://internal:1234")), - ) + settings = mocker.patch("control_backend.agents.ri_command_agent.settings") + settings.zmq_settings.internal_sub_address = "tcp://internal:1234" await agent.setup() diff --git a/test/integration/agents/test_ri_communication_agent.py b/test/integration/agents/test_ri_communication_agent.py index fd555e1..443d609 100644 --- a/test/integration/agents/test_ri_communication_agent.py +++ b/test/integration/agents/test_ri_communication_agent.py @@ -84,21 +84,23 @@ def fake_json_invalid_id_negototiate(): ) +@pytest.fixture +def zmq_context(mocker): + mock_context = mocker.patch("control_backend.agents.vad_agent.azmq.Context.instance") + mock_context.return_value = MagicMock() + return mock_context + + @pytest.mark.asyncio -async def test_setup_creates_socket_and_negotiate_1(monkeypatch): +async def test_setup_creates_socket_and_negotiate_1(zmq_context): """ Test the setup of the communication agent """ # --- Arrange --- - fake_socket = MagicMock() + fake_socket = zmq_context.return_value.socket.return_value fake_socket.send_json = AsyncMock() fake_socket.recv_json = fake_json_correct_negototiate_1() - # Mock context.socket to return our fake socket - monkeypatch.setattr( - "control_backend.agents.ri_communication_agent.context.socket", lambda _: fake_socket - ) - # Mock RICommandAgent agent startup with patch( "control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True @@ -128,20 +130,15 @@ async def test_setup_creates_socket_and_negotiate_1(monkeypatch): @pytest.mark.asyncio -async def test_setup_creates_socket_and_negotiate_2(monkeypatch): +async def test_setup_creates_socket_and_negotiate_2(zmq_context): """ Test the setup of the communication agent """ # --- Arrange --- - fake_socket = MagicMock() + fake_socket = zmq_context.return_value.socket.return_value fake_socket.send_json = AsyncMock() fake_socket.recv_json = fake_json_correct_negototiate_2() - # Mock context.socket to return our fake socket - monkeypatch.setattr( - "control_backend.agents.ri_communication_agent.context.socket", lambda _: fake_socket - ) - # Mock RICommandAgent agent startup with patch( "control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True @@ -171,20 +168,15 @@ async def test_setup_creates_socket_and_negotiate_2(monkeypatch): @pytest.mark.asyncio -async def test_setup_creates_socket_and_negotiate_3(monkeypatch, caplog): +async def test_setup_creates_socket_and_negotiate_3(zmq_context, caplog): """ Test the functionality of setup with incorrect negotiation message """ # --- Arrange --- - fake_socket = MagicMock() + fake_socket = zmq_context.return_value.socket.return_value fake_socket.send_json = AsyncMock() fake_socket.recv_json = fake_json_wrong_negototiate_1() - # Mock context.socket to return our fake socket - monkeypatch.setattr( - "control_backend.agents.ri_communication_agent.context.socket", lambda _: fake_socket - ) - # Mock RICommandAgent agent startup # We are sending wrong negotiation info to the communication agent, @@ -215,20 +207,15 @@ async def test_setup_creates_socket_and_negotiate_3(monkeypatch, caplog): @pytest.mark.asyncio -async def test_setup_creates_socket_and_negotiate_4(monkeypatch): +async def test_setup_creates_socket_and_negotiate_4(zmq_context): """ Test the setup of the communication agent with different bind value """ # --- Arrange --- - fake_socket = MagicMock() + fake_socket = zmq_context.return_value.socket.return_value fake_socket.send_json = AsyncMock() fake_socket.recv_json = fake_json_correct_negototiate_3() - # Mock context.socket to return our fake socket - monkeypatch.setattr( - "control_backend.agents.ri_communication_agent.context.socket", lambda _: fake_socket - ) - # Mock RICommandAgent agent startup with patch( "control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True @@ -258,20 +245,15 @@ async def test_setup_creates_socket_and_negotiate_4(monkeypatch): @pytest.mark.asyncio -async def test_setup_creates_socket_and_negotiate_5(monkeypatch): +async def test_setup_creates_socket_and_negotiate_5(zmq_context): """ Test the setup of the communication agent """ # --- Arrange --- - fake_socket = MagicMock() + fake_socket = zmq_context.return_value.socket.return_value fake_socket.send_json = AsyncMock() fake_socket.recv_json = fake_json_correct_negototiate_4() - # Mock context.socket to return our fake socket - monkeypatch.setattr( - "control_backend.agents.ri_communication_agent.context.socket", lambda _: fake_socket - ) - # Mock RICommandAgent agent startup with patch( "control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True @@ -301,20 +283,15 @@ async def test_setup_creates_socket_and_negotiate_5(monkeypatch): @pytest.mark.asyncio -async def test_setup_creates_socket_and_negotiate_6(monkeypatch): +async def test_setup_creates_socket_and_negotiate_6(zmq_context): """ Test the setup of the communication agent """ # --- Arrange --- - fake_socket = MagicMock() + fake_socket = zmq_context.return_value.socket.return_value fake_socket.send_json = AsyncMock() fake_socket.recv_json = fake_json_correct_negototiate_5() - # Mock context.socket to return our fake socket - monkeypatch.setattr( - "control_backend.agents.ri_communication_agent.context.socket", lambda _: fake_socket - ) - # Mock RICommandAgent agent startup with patch( "control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True @@ -344,20 +321,15 @@ async def test_setup_creates_socket_and_negotiate_6(monkeypatch): @pytest.mark.asyncio -async def test_setup_creates_socket_and_negotiate_7(monkeypatch, caplog): +async def test_setup_creates_socket_and_negotiate_7(zmq_context, caplog): """ Test the functionality of setup with incorrect id """ # --- Arrange --- - fake_socket = MagicMock() + fake_socket = zmq_context.return_value.socket.return_value fake_socket.send_json = AsyncMock() fake_socket.recv_json = fake_json_invalid_id_negototiate() - # Mock context.socket to return our fake socket - monkeypatch.setattr( - "control_backend.agents.ri_communication_agent.context.socket", lambda _: fake_socket - ) - # Mock RICommandAgent agent startup # We are sending wrong negotiation info to the communication agent, @@ -385,20 +357,15 @@ async def test_setup_creates_socket_and_negotiate_7(monkeypatch, caplog): @pytest.mark.asyncio -async def test_setup_creates_socket_and_negotiate_timeout(monkeypatch, caplog): +async def test_setup_creates_socket_and_negotiate_timeout(zmq_context, caplog): """ Test the functionality of setup with incorrect negotiation message """ # --- Arrange --- - fake_socket = MagicMock() + fake_socket = zmq_context.return_value.socket.return_value fake_socket.send_json = AsyncMock() fake_socket.recv_json = AsyncMock(side_effect=asyncio.TimeoutError) - # Mock context.socket to return our fake socket - monkeypatch.setattr( - "control_backend.agents.ri_communication_agent.context.socket", lambda _: fake_socket - ) - with patch( "control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True ) as MockCommandAgent: @@ -480,8 +447,8 @@ async def test_listen_behaviour_ping_wrong_endpoint(caplog): @pytest.mark.asyncio -async def test_listen_behaviour_timeout(caplog): - fake_socket = AsyncMock() +async def test_listen_behaviour_timeout(zmq_context, caplog): + fake_socket = zmq_context.return_value.socket.return_value fake_socket.send_json = AsyncMock() # recv_json will never resolve, simulate timeout fake_socket.recv_json = AsyncMock(side_effect=asyncio.TimeoutError) @@ -529,16 +496,12 @@ async def test_listen_behaviour_ping_no_endpoint(caplog): @pytest.mark.asyncio -async def test_setup_unexpected_exception(monkeypatch, caplog): - fake_socket = MagicMock() +async def test_setup_unexpected_exception(zmq_context, caplog): + fake_socket = zmq_context.return_value.socket.return_value fake_socket.send_json = AsyncMock() # Simulate unexpected exception during recv_json() fake_socket.recv_json = AsyncMock(side_effect=Exception("boom!")) - monkeypatch.setattr( - "control_backend.agents.ri_communication_agent.context.socket", lambda _: fake_socket - ) - agent = RICommunicationAgent( "test@server", "password", address="tcp://localhost:5555", bind=False ) @@ -551,9 +514,9 @@ async def test_setup_unexpected_exception(monkeypatch, caplog): @pytest.mark.asyncio -async def test_setup_unpacking_exception(monkeypatch, caplog): +async def test_setup_unpacking_exception(zmq_context, caplog): # --- Arrange --- - fake_socket = MagicMock() + fake_socket = zmq_context.return_value.socket.return_value fake_socket.send_json = AsyncMock() # Make recv_json return malformed negotiation data to trigger unpacking exception @@ -563,11 +526,6 @@ async def test_setup_unpacking_exception(monkeypatch, caplog): } # missing 'port' and 'bind' fake_socket.recv_json = AsyncMock(return_value=malformed_data) - # Patch context.socket - monkeypatch.setattr( - "control_backend.agents.ri_communication_agent.context.socket", lambda _: fake_socket - ) - # Patch RICommandAgent so it won't actually start with patch( "control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True diff --git a/test/integration/agents/vad_agent/test_vad_agent.py b/test/integration/agents/vad_agent/test_vad_agent.py index 54c9d82..0e1fae2 100644 --- a/test/integration/agents/vad_agent/test_vad_agent.py +++ b/test/integration/agents/vad_agent/test_vad_agent.py @@ -10,7 +10,9 @@ from control_backend.agents.vad_agent import VADAgent @pytest.fixture def zmq_context(mocker): - return mocker.patch("control_backend.agents.vad_agent.zmq_context") + mock_context = mocker.patch("control_backend.agents.vad_agent.azmq.Context.instance") + mock_context.return_value = MagicMock() + return mock_context @pytest.fixture @@ -54,13 +56,18 @@ def test_in_socket_creation(zmq_context, do_bind: bool): assert vad_agent.audio_in_socket is not None - zmq_context.socket.assert_called_once_with(zmq.SUB) - zmq_context.socket.return_value.setsockopt_string.assert_called_once_with(zmq.SUBSCRIBE, "") + zmq_context.return_value.socket.assert_called_once_with(zmq.SUB) + zmq_context.return_value.socket.return_value.setsockopt_string.assert_called_once_with( + zmq.SUBSCRIBE, + "", + ) if do_bind: - zmq_context.socket.return_value.bind.assert_called_once_with("tcp://*:12345") + zmq_context.return_value.socket.return_value.bind.assert_called_once_with("tcp://*:12345") else: - zmq_context.socket.return_value.connect.assert_called_once_with("tcp://localhost:12345") + zmq_context.return_value.socket.return_value.connect.assert_called_once_with( + "tcp://localhost:12345" + ) def test_out_socket_creation(zmq_context): @@ -73,8 +80,8 @@ def test_out_socket_creation(zmq_context): assert vad_agent.audio_out_socket is not None - zmq_context.socket.assert_called_once_with(zmq.PUB) - zmq_context.socket.return_value.bind_to_random_port.assert_called_once() + zmq_context.return_value.socket.assert_called_once_with(zmq.PUB) + zmq_context.return_value.socket.return_value.bind_to_random_port.assert_called_once() @pytest.mark.asyncio @@ -83,7 +90,9 @@ async def test_out_socket_creation_failure(zmq_context): Test setup failure when the audio output socket cannot be created. """ with patch.object(Agent, "stop", new_callable=AsyncMock) as mock_super_stop: - zmq_context.socket.return_value.bind_to_random_port.side_effect = zmq.ZMQBindError + zmq_context.return_value.socket.return_value.bind_to_random_port.side_effect = ( + zmq.ZMQBindError + ) vad_agent = VADAgent("tcp://localhost:12345", False) await vad_agent.setup() @@ -98,11 +107,14 @@ async def test_stop(zmq_context, transcription_agent): Test that when the VAD agent is stopped, the sockets are closed correctly. """ vad_agent = VADAgent("tcp://localhost:12345", False) - zmq_context.socket.return_value.bind_to_random_port.return_value = random.randint(1000, 10000) + zmq_context.return_value.socket.return_value.bind_to_random_port.return_value = random.randint( + 1000, + 10000, + ) await vad_agent.setup() await vad_agent.stop() - assert zmq_context.socket.return_value.close.call_count == 2 + assert zmq_context.return_value.socket.return_value.close.call_count == 2 assert vad_agent.audio_in_socket is None assert vad_agent.audio_out_socket is None diff --git a/test/integration/api/endpoints/test_command_endpoint.py b/test/integration/api/endpoints/test_command_endpoint.py index 04890c1..1c9213a 100644 --- a/test/integration/api/endpoints/test_command_endpoint.py +++ b/test/integration/api/endpoints/test_command_endpoint.py @@ -1,4 +1,4 @@ -from unittest.mock import MagicMock +from unittest.mock import AsyncMock import pytest from fastapi import FastAPI @@ -16,7 +16,6 @@ def app(): """ app = FastAPI() app.include_router(command.router) - app.state.internal_comm_socket = MagicMock() # mock ZMQ socket return app @@ -26,32 +25,30 @@ def client(app): return TestClient(app) -def test_receive_command_endpoint(client, app): +def test_receive_command_success(client): """ - Test that a POST to /command sends the right multipart message - and returns a 202 with the expected JSON body. + Test for successful reception of a command. Ensures the status code is 202 and the response body + is correct. It also verifies that the ZeroMQ socket's send_multipart method is called with the + expected data. """ - mock_socket = app.state.internal_comm_socket + # Arrange + mock_pub_socket = AsyncMock() + client.app.state.endpoints_pub_socket = mock_pub_socket - # Prepare test payload that matches SpeechCommand - payload = {"endpoint": "actuate/speech", "data": "yooo"} + command_data = {"endpoint": "actuate/speech", "data": "This is a test"} + speech_command = SpeechCommand(**command_data) - # Send POST request - response = client.post("/command", json=payload) + # Act + response = client.post("/command", json=command_data) - # Check response + # Assert assert response.status_code == 202 assert response.json() == {"status": "Command received"} - # Verify that the socket was called with the correct data - assert mock_socket.send_multipart.called, "Socket should be used to send data" - - args, kwargs = mock_socket.send_multipart.call_args - sent_data = args[0] - - assert sent_data[0] == b"command" - # Check JSON encoding roughly matches - assert isinstance(SpeechCommand.model_validate_json(sent_data[1].decode()), SpeechCommand) + # Verify that the ZMQ socket was used correctly + mock_pub_socket.send_multipart.assert_awaited_once_with( + [b"command", speech_command.model_dump_json().encode()] + ) def test_receive_command_invalid_payload(client):