Refactored ZMQ context implementation #16

Merged
k.marinus merged 7 commits from refactor/zmq-internal-socket-behaviour into dev 2025-11-05 11:35:27 +00:00
4 changed files with 61 additions and 134 deletions
Showing only changes of commit b008562554 - Show all commits

View File

@@ -7,19 +7,21 @@ from control_backend.agents.ri_command_agent import RICommandAgent
from control_backend.schemas.ri_message import SpeechCommand from control_backend.schemas.ri_message import SpeechCommand
@pytest.fixture
def zmq_context(mocker):
mock_context = mocker.patch("control_backend.agents.vad_agent.azmq.Context.instance")
mock_context.return_value = MagicMock()
return mock_context
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_setup_bind(monkeypatch): async def test_setup_bind(zmq_context, mocker):
"""Test setup with bind=True""" """Test setup with bind=True"""
fake_socket = MagicMock() fake_socket = zmq_context.return_value.socket.return_value
monkeypatch.setattr(
"control_backend.agents.ri_command_agent.context.socket", lambda _: fake_socket
)
agent = RICommandAgent("test@server", "password", address="tcp://localhost:5555", bind=True) agent = RICommandAgent("test@server", "password", address="tcp://localhost:5555", bind=True)
monkeypatch.setattr( settings = mocker.patch("control_backend.agents.ri_command_agent.settings")
"control_backend.agents.ri_command_agent.settings", settings.zmq_settings.internal_sub_address = "tcp://internal:1234"
MagicMock(zmq_settings=MagicMock(internal_comm_address="tcp://internal:1234")),
)
await agent.setup() await agent.setup()
@@ -34,18 +36,13 @@ async def test_setup_bind(monkeypatch):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_setup_connect(monkeypatch): async def test_setup_connect(zmq_context, mocker):
"""Test setup with bind=False""" """Test setup with bind=False"""
fake_socket = MagicMock() fake_socket = zmq_context.return_value.socket.return_value
monkeypatch.setattr(
"control_backend.agents.ri_command_agent.context.socket", lambda _: fake_socket
)
agent = RICommandAgent("test@server", "password", address="tcp://localhost:5555", bind=False) agent = RICommandAgent("test@server", "password", address="tcp://localhost:5555", bind=False)
monkeypatch.setattr( settings = mocker.patch("control_backend.agents.ri_command_agent.settings")
"control_backend.agents.ri_command_agent.settings", settings.zmq_settings.internal_sub_address = "tcp://internal:1234"
MagicMock(zmq_settings=MagicMock(internal_comm_address="tcp://internal:1234")),
)
await agent.setup() await agent.setup()

View File

@@ -82,21 +82,23 @@ def fake_json_invalid_id_negototiate():
) )
@pytest.fixture
def zmq_context(mocker):
mock_context = mocker.patch("control_backend.agents.vad_agent.azmq.Context.instance")
mock_context.return_value = MagicMock()
return mock_context
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_setup_creates_socket_and_negotiate_1(monkeypatch): async def test_setup_creates_socket_and_negotiate_1(zmq_context):
""" """
Test the setup of the communication agent Test the setup of the communication agent
""" """
# --- Arrange --- # --- Arrange ---
fake_socket = MagicMock() fake_socket = zmq_context.return_value.socket.return_value
fake_socket.send_json = AsyncMock() fake_socket.send_json = AsyncMock()
fake_socket.recv_json = fake_json_correct_negototiate_1() fake_socket.recv_json = fake_json_correct_negototiate_1()
# Mock context.socket to return our fake socket
monkeypatch.setattr(
"control_backend.agents.ri_communication_agent.context.socket", lambda _: fake_socket
)
# Mock RICommandAgent agent startup # Mock RICommandAgent agent startup
with patch( with patch(
"control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True "control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True
@@ -126,20 +128,15 @@ async def test_setup_creates_socket_and_negotiate_1(monkeypatch):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_setup_creates_socket_and_negotiate_2(monkeypatch): async def test_setup_creates_socket_and_negotiate_2(zmq_context):
""" """
Test the setup of the communication agent Test the setup of the communication agent
""" """
# --- Arrange --- # --- Arrange ---
fake_socket = MagicMock() fake_socket = zmq_context.return_value.socket.return_value
fake_socket.send_json = AsyncMock() fake_socket.send_json = AsyncMock()
fake_socket.recv_json = fake_json_correct_negototiate_2() fake_socket.recv_json = fake_json_correct_negototiate_2()
# Mock context.socket to return our fake socket
monkeypatch.setattr(
"control_backend.agents.ri_communication_agent.context.socket", lambda _: fake_socket
)
# Mock RICommandAgent agent startup # Mock RICommandAgent agent startup
with patch( with patch(
"control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True "control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True
@@ -169,20 +166,15 @@ async def test_setup_creates_socket_and_negotiate_2(monkeypatch):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_setup_creates_socket_and_negotiate_3(monkeypatch, caplog): async def test_setup_creates_socket_and_negotiate_3(zmq_context, caplog):
""" """
Test the functionality of setup with incorrect negotiation message Test the functionality of setup with incorrect negotiation message
""" """
# --- Arrange --- # --- Arrange ---
fake_socket = MagicMock() fake_socket = zmq_context.return_value.socket.return_value
fake_socket.send_json = AsyncMock() fake_socket.send_json = AsyncMock()
fake_socket.recv_json = fake_json_wrong_negototiate_1() fake_socket.recv_json = fake_json_wrong_negototiate_1()
# Mock context.socket to return our fake socket
monkeypatch.setattr(
"control_backend.agents.ri_communication_agent.context.socket", lambda _: fake_socket
)
# Mock RICommandAgent agent startup # Mock RICommandAgent agent startup
# We are sending wrong negotiation info to the communication agent, so we should retry and expect a # We are sending wrong negotiation info to the communication agent, so we should retry and expect a
@@ -213,20 +205,15 @@ async def test_setup_creates_socket_and_negotiate_3(monkeypatch, caplog):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_setup_creates_socket_and_negotiate_4(monkeypatch): async def test_setup_creates_socket_and_negotiate_4(zmq_context):
""" """
Test the setup of the communication agent with different bind value Test the setup of the communication agent with different bind value
""" """
# --- Arrange --- # --- Arrange ---
fake_socket = MagicMock() fake_socket = zmq_context.return_value.socket.return_value
fake_socket.send_json = AsyncMock() fake_socket.send_json = AsyncMock()
fake_socket.recv_json = fake_json_correct_negototiate_3() fake_socket.recv_json = fake_json_correct_negototiate_3()
# Mock context.socket to return our fake socket
monkeypatch.setattr(
"control_backend.agents.ri_communication_agent.context.socket", lambda _: fake_socket
)
# Mock RICommandAgent agent startup # Mock RICommandAgent agent startup
with patch( with patch(
"control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True "control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True
@@ -256,20 +243,15 @@ async def test_setup_creates_socket_and_negotiate_4(monkeypatch):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_setup_creates_socket_and_negotiate_5(monkeypatch): async def test_setup_creates_socket_and_negotiate_5(zmq_context):
""" """
Test the setup of the communication agent Test the setup of the communication agent
""" """
# --- Arrange --- # --- Arrange ---
fake_socket = MagicMock() fake_socket = zmq_context.return_value.socket.return_value
fake_socket.send_json = AsyncMock() fake_socket.send_json = AsyncMock()
fake_socket.recv_json = fake_json_correct_negototiate_4() fake_socket.recv_json = fake_json_correct_negototiate_4()
# Mock context.socket to return our fake socket
monkeypatch.setattr(
"control_backend.agents.ri_communication_agent.context.socket", lambda _: fake_socket
)
# Mock RICommandAgent agent startup # Mock RICommandAgent agent startup
with patch( with patch(
"control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True "control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True
@@ -299,20 +281,15 @@ async def test_setup_creates_socket_and_negotiate_5(monkeypatch):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_setup_creates_socket_and_negotiate_6(monkeypatch): async def test_setup_creates_socket_and_negotiate_6(zmq_context):
""" """
Test the setup of the communication agent Test the setup of the communication agent
""" """
# --- Arrange --- # --- Arrange ---
fake_socket = MagicMock() fake_socket = zmq_context.return_value.socket.return_value
fake_socket.send_json = AsyncMock() fake_socket.send_json = AsyncMock()
fake_socket.recv_json = fake_json_correct_negototiate_5() fake_socket.recv_json = fake_json_correct_negototiate_5()
# Mock context.socket to return our fake socket
monkeypatch.setattr(
"control_backend.agents.ri_communication_agent.context.socket", lambda _: fake_socket
)
# Mock RICommandAgent agent startup # Mock RICommandAgent agent startup
with patch( with patch(
"control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True "control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True
@@ -342,20 +319,15 @@ async def test_setup_creates_socket_and_negotiate_6(monkeypatch):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_setup_creates_socket_and_negotiate_7(monkeypatch, caplog): async def test_setup_creates_socket_and_negotiate_7(zmq_context, caplog):
""" """
Test the functionality of setup with incorrect id Test the functionality of setup with incorrect id
""" """
# --- Arrange --- # --- Arrange ---
fake_socket = MagicMock() fake_socket = zmq_context.return_value.socket.return_value
fake_socket.send_json = AsyncMock() fake_socket.send_json = AsyncMock()
fake_socket.recv_json = fake_json_invalid_id_negototiate() fake_socket.recv_json = fake_json_invalid_id_negototiate()
# Mock context.socket to return our fake socket
monkeypatch.setattr(
"control_backend.agents.ri_communication_agent.context.socket", lambda _: fake_socket
)
# Mock RICommandAgent agent startup # Mock RICommandAgent agent startup
# We are sending wrong negotiation info to the communication agent, so we should retry and expect a # We are sending wrong negotiation info to the communication agent, so we should retry and expect a
@@ -383,20 +355,15 @@ async def test_setup_creates_socket_and_negotiate_7(monkeypatch, caplog):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_setup_creates_socket_and_negotiate_timeout(monkeypatch, caplog): async def test_setup_creates_socket_and_negotiate_timeout(zmq_context, caplog):
""" """
Test the functionality of setup with incorrect negotiation message Test the functionality of setup with incorrect negotiation message
""" """
# --- Arrange --- # --- Arrange ---
fake_socket = MagicMock() fake_socket = zmq_context.return_value.socket.return_value
fake_socket.send_json = AsyncMock() fake_socket.send_json = AsyncMock()
fake_socket.recv_json = AsyncMock(side_effect=asyncio.TimeoutError) fake_socket.recv_json = AsyncMock(side_effect=asyncio.TimeoutError)
# Mock context.socket to return our fake socket
monkeypatch.setattr(
"control_backend.agents.ri_communication_agent.context.socket", lambda _: fake_socket
)
with patch( with patch(
"control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True "control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True
) as MockCommandAgent: ) as MockCommandAgent:
@@ -478,8 +445,8 @@ async def test_listen_behaviour_ping_wrong_endpoint(caplog):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_listen_behaviour_timeout(caplog): async def test_listen_behaviour_timeout(zmq_context, caplog):
fake_socket = AsyncMock() fake_socket = zmq_context.return_value.socket.return_value
fake_socket.send_json = AsyncMock() fake_socket.send_json = AsyncMock()
# recv_json will never resolve, simulate timeout # recv_json will never resolve, simulate timeout
fake_socket.recv_json = AsyncMock(side_effect=asyncio.TimeoutError) fake_socket.recv_json = AsyncMock(side_effect=asyncio.TimeoutError)
@@ -527,16 +494,12 @@ async def test_listen_behaviour_ping_no_endpoint(caplog):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_setup_unexpected_exception(monkeypatch, caplog): async def test_setup_unexpected_exception(zmq_context, caplog):
fake_socket = MagicMock() fake_socket = zmq_context.return_value.socket.return_value
fake_socket.send_json = AsyncMock() fake_socket.send_json = AsyncMock()
# Simulate unexpected exception during recv_json() # Simulate unexpected exception during recv_json()
fake_socket.recv_json = AsyncMock(side_effect=Exception("boom!")) fake_socket.recv_json = AsyncMock(side_effect=Exception("boom!"))
monkeypatch.setattr(
"control_backend.agents.ri_communication_agent.context.socket", lambda _: fake_socket
)
agent = RICommunicationAgent( agent = RICommunicationAgent(
"test@server", "password", address="tcp://localhost:5555", bind=False "test@server", "password", address="tcp://localhost:5555", bind=False
) )
@@ -549,9 +512,9 @@ async def test_setup_unexpected_exception(monkeypatch, caplog):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_setup_unpacking_exception(monkeypatch, caplog): async def test_setup_unpacking_exception(zmq_context, caplog):
# --- Arrange --- # --- Arrange ---
fake_socket = MagicMock() fake_socket = zmq_context.return_value.socket.return_value
fake_socket.send_json = AsyncMock() fake_socket.send_json = AsyncMock()
# Make recv_json return malformed negotiation data to trigger unpacking exception # Make recv_json return malformed negotiation data to trigger unpacking exception
@@ -561,11 +524,6 @@ async def test_setup_unpacking_exception(monkeypatch, caplog):
} # missing 'port' and 'bind' } # missing 'port' and 'bind'
fake_socket.recv_json = AsyncMock(return_value=malformed_data) fake_socket.recv_json = AsyncMock(return_value=malformed_data)
# Patch context.socket
monkeypatch.setattr(
"control_backend.agents.ri_communication_agent.context.socket", lambda _: fake_socket
)
# Patch RICommandAgent so it won't actually start # Patch RICommandAgent so it won't actually start
with patch( with patch(
"control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True "control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True

View File

@@ -10,7 +10,9 @@ from control_backend.agents.vad_agent import VADAgent
@pytest.fixture @pytest.fixture
def zmq_context(mocker): def zmq_context(mocker):
return mocker.patch("control_backend.agents.vad_agent.zmq_context") mock_context = mocker.patch("control_backend.agents.vad_agent.azmq.Context.instance")
mock_context.return_value = MagicMock()
return mock_context
@pytest.fixture @pytest.fixture
@@ -54,13 +56,13 @@ def test_in_socket_creation(zmq_context, do_bind: bool):
assert vad_agent.audio_in_socket is not None assert vad_agent.audio_in_socket is not None
zmq_context.socket.assert_called_once_with(zmq.SUB) zmq_context.return_value.socket.assert_called_once_with(zmq.SUB)
zmq_context.socket.return_value.setsockopt_string.assert_called_once_with(zmq.SUBSCRIBE, "") zmq_context.return_value.socket.return_value.setsockopt_string.assert_called_once_with(zmq.SUBSCRIBE, "")
if do_bind: if do_bind:
zmq_context.socket.return_value.bind.assert_called_once_with("tcp://*:12345") zmq_context.return_value.socket.return_value.bind.assert_called_once_with("tcp://*:12345")
else: else:
zmq_context.socket.return_value.connect.assert_called_once_with("tcp://localhost:12345") zmq_context.return_value.socket.return_value.connect.assert_called_once_with("tcp://localhost:12345")
def test_out_socket_creation(zmq_context): def test_out_socket_creation(zmq_context):
@@ -73,8 +75,8 @@ def test_out_socket_creation(zmq_context):
assert vad_agent.audio_out_socket is not None assert vad_agent.audio_out_socket is not None
zmq_context.socket.assert_called_once_with(zmq.PUB) zmq_context.return_value.socket.assert_called_once_with(zmq.PUB)
zmq_context.socket.return_value.bind_to_random_port.assert_called_once() zmq_context.return_value.socket.return_value.bind_to_random_port.assert_called_once()
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -83,7 +85,7 @@ async def test_out_socket_creation_failure(zmq_context):
Test setup failure when the audio output socket cannot be created. Test setup failure when the audio output socket cannot be created.
""" """
with patch.object(Agent, "stop", new_callable=AsyncMock) as mock_super_stop: with patch.object(Agent, "stop", new_callable=AsyncMock) as mock_super_stop:
zmq_context.socket.return_value.bind_to_random_port.side_effect = zmq.ZMQBindError zmq_context.return_value.socket.return_value.bind_to_random_port.side_effect = zmq.ZMQBindError
vad_agent = VADAgent("tcp://localhost:12345", False) vad_agent = VADAgent("tcp://localhost:12345", False)
await vad_agent.setup() await vad_agent.setup()
@@ -98,11 +100,11 @@ async def test_stop(zmq_context, transcription_agent):
Test that when the VAD agent is stopped, the sockets are closed correctly. Test that when the VAD agent is stopped, the sockets are closed correctly.
""" """
vad_agent = VADAgent("tcp://localhost:12345", False) vad_agent = VADAgent("tcp://localhost:12345", False)
zmq_context.socket.return_value.bind_to_random_port.return_value = random.randint(1000, 10000) zmq_context.return_value.socket.return_value.bind_to_random_port.return_value = random.randint(1000, 10000)
await vad_agent.setup() await vad_agent.setup()
await vad_agent.stop() await vad_agent.stop()
assert zmq_context.socket.return_value.close.call_count == 2 assert zmq_context.return_value.socket.return_value.close.call_count == 2
assert vad_agent.audio_in_socket is None assert vad_agent.audio_in_socket is None
assert vad_agent.audio_out_socket is None assert vad_agent.audio_out_socket is None

View File

@@ -26,8 +26,8 @@ def client(app):
@pytest.mark.asyncio @pytest.mark.asyncio
@patch("control_backend.api.endpoints.command.Context.instance") @patch("control_backend.api.v1.endpoints.command.Context.instance")
async def test_receive_command_success(mock_context_instance, async_client): async def test_receive_command_success(mock_context_instance, client):
""" """
Test for successful reception of a command. Test for successful reception of a command.
Ensures the status code is 202 and the response body is correct. Ensures the status code is 202 and the response body is correct.
@@ -35,54 +35,24 @@ async def test_receive_command_success(mock_context_instance, async_client):
""" """
# Arrange # Arrange
mock_pub_socket = AsyncMock() mock_pub_socket = AsyncMock()
mock_context_instance.return_value.socket.return_value = mock_pub_socket client.app.state.endpoints_pub_socket = mock_pub_socket
command_data = {"command": "test_command", "text": "This is a test"} command_data = {"endpoint": "actuate/speech", "data": "This is a test"}
speech_command = SpeechCommand(**command_data) speech_command = SpeechCommand(**command_data)
# Act # Act
response = await async_client.post("/command", json=command_data) response = client.post("/command", json=command_data)
# Assert # Assert
assert response.status_code == 202 assert response.status_code == 202
assert response.json() == {"status": "Command received"} assert response.json() == {"status": "Command received"}
# Verify that the ZMQ socket was used correctly # Verify that the ZMQ socket was used correctly
mock_context_instance.return_value.socket.assert_called_once_with(1) # zmq.PUB is 1
mock_pub_socket.connect.assert_called_once()
mock_pub_socket.send_multipart.assert_awaited_once_with( mock_pub_socket.send_multipart.assert_awaited_once_with(
[b"command", speech_command.model_dump_json().encode()] [b"command", speech_command.model_dump_json().encode()]
) )
def test_receive_command_endpoint(client, app, mocker):
"""
Test that a POST to /command sends the right multipart message
and returns a 202 with the expected JSON body.
"""
mock_socket = mocker.patch.object()
# Prepare test payload that matches SpeechCommand
payload = {"endpoint": "actuate/speech", "data": "yooo"}
# Send POST request
response = client.post("/command", json=payload)
# Check response
assert response.status_code == 202
assert response.json() == {"status": "Command received"}
# Verify that the socket was called with the correct data
assert mock_socket.send_multipart.called, "Socket should be used to send data"
args, kwargs = mock_socket.send_multipart.call_args
sent_data = args[0]
assert sent_data[0] == b"command"
# Check JSON encoding roughly matches
assert isinstance(SpeechCommand.model_validate_json(sent_data[1].decode()), SpeechCommand)
def test_receive_command_invalid_payload(client): def test_receive_command_invalid_payload(client):
""" """
Test invalid data handling (schema validation). Test invalid data handling (schema validation).