Merge remote-tracking branch 'origin/dev' into demo

This commit is contained in:
Twirre Meulenbelt
2025-11-05 12:38:08 +01:00
15 changed files with 187 additions and 163 deletions

View File

@@ -10,7 +10,9 @@ from control_backend.agents.vad_agent import VADAgent
@pytest.fixture
def zmq_context(mocker):
return mocker.patch("control_backend.agents.vad_agent.zmq_context")
mock_context = mocker.patch("control_backend.agents.vad_agent.azmq.Context.instance")
mock_context.return_value = MagicMock()
return mock_context
@pytest.fixture
@@ -54,13 +56,18 @@ def test_in_socket_creation(zmq_context, do_bind: bool):
assert vad_agent.audio_in_socket is not None
zmq_context.socket.assert_called_once_with(zmq.SUB)
zmq_context.socket.return_value.setsockopt_string.assert_called_once_with(zmq.SUBSCRIBE, "")
zmq_context.return_value.socket.assert_called_once_with(zmq.SUB)
zmq_context.return_value.socket.return_value.setsockopt_string.assert_called_once_with(
zmq.SUBSCRIBE,
"",
)
if do_bind:
zmq_context.socket.return_value.bind.assert_called_once_with("tcp://*:12345")
zmq_context.return_value.socket.return_value.bind.assert_called_once_with("tcp://*:12345")
else:
zmq_context.socket.return_value.connect.assert_called_once_with("tcp://localhost:12345")
zmq_context.return_value.socket.return_value.connect.assert_called_once_with(
"tcp://localhost:12345"
)
def test_out_socket_creation(zmq_context):
@@ -73,8 +80,8 @@ def test_out_socket_creation(zmq_context):
assert vad_agent.audio_out_socket is not None
zmq_context.socket.assert_called_once_with(zmq.PUB)
zmq_context.socket.return_value.bind_to_random_port.assert_called_once()
zmq_context.return_value.socket.assert_called_once_with(zmq.PUB)
zmq_context.return_value.socket.return_value.bind_to_random_port.assert_called_once()
@pytest.mark.asyncio
@@ -83,7 +90,9 @@ async def test_out_socket_creation_failure(zmq_context):
Test setup failure when the audio output socket cannot be created.
"""
with patch.object(Agent, "stop", new_callable=AsyncMock) as mock_super_stop:
zmq_context.socket.return_value.bind_to_random_port.side_effect = zmq.ZMQBindError
zmq_context.return_value.socket.return_value.bind_to_random_port.side_effect = (
zmq.ZMQBindError
)
vad_agent = VADAgent("tcp://localhost:12345", False)
await vad_agent.setup()
@@ -98,11 +107,14 @@ async def test_stop(zmq_context, transcription_agent):
Test that when the VAD agent is stopped, the sockets are closed correctly.
"""
vad_agent = VADAgent("tcp://localhost:12345", False)
zmq_context.socket.return_value.bind_to_random_port.return_value = random.randint(1000, 10000)
zmq_context.return_value.socket.return_value.bind_to_random_port.return_value = random.randint(
1000,
10000,
)
await vad_agent.setup()
await vad_agent.stop()
assert zmq_context.socket.return_value.close.call_count == 2
assert zmq_context.return_value.socket.return_value.close.call_count == 2
assert vad_agent.audio_in_socket is None
assert vad_agent.audio_out_socket is None