test: make VAD tests work again
ref: N25B-301
This commit is contained in:
@@ -1,9 +1,8 @@
|
||||
import random
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
import zmq
|
||||
from spade.agent import Agent
|
||||
|
||||
from control_backend.agents.perception.vad_agent import VADAgent
|
||||
|
||||
@@ -15,11 +14,6 @@ def zmq_context(mocker):
|
||||
return mock_context
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def streaming(mocker):
|
||||
return mocker.patch("control_backend.agents.perception.vad_agent.StreamingBehaviour")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def per_transcription_agent(mocker):
|
||||
return mocker.patch(
|
||||
@@ -27,21 +21,36 @@ def per_transcription_agent(mocker):
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def torch_load(mocker):
|
||||
mock_torch = mocker.patch("control_backend.agents.perception.vad_agent.torch")
|
||||
model = MagicMock()
|
||||
mock_torch.hub.load.return_value = (model, None)
|
||||
mock_torch.from_numpy.side_effect = lambda arr: arr
|
||||
return mock_torch
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_normal_setup(streaming, per_transcription_agent):
|
||||
async def test_normal_setup(per_transcription_agent):
|
||||
"""
|
||||
Test that during normal setup, the VAD agent creates a Streaming behavior and creates audio
|
||||
sockets, and starts the TranscriptionAgent without loading real models.
|
||||
"""
|
||||
per_vad_agent = VADAgent("tcp://localhost:12345", False)
|
||||
per_vad_agent.add_behaviour = MagicMock()
|
||||
per_vad_agent._streaming_loop = AsyncMock()
|
||||
|
||||
async def swallow_background_task(coro):
|
||||
coro.close()
|
||||
|
||||
per_vad_agent.add_background_task = swallow_background_task
|
||||
per_vad_agent.reset_stream = AsyncMock()
|
||||
|
||||
await per_vad_agent.setup()
|
||||
|
||||
streaming.assert_called_once()
|
||||
per_vad_agent.add_behaviour.assert_called_once_with(streaming.return_value)
|
||||
per_transcription_agent.assert_called_once()
|
||||
per_transcription_agent.return_value.start.assert_called_once()
|
||||
per_vad_agent._streaming_loop.assert_called_once()
|
||||
per_vad_agent.reset_stream.assert_called_once()
|
||||
assert per_vad_agent.audio_in_socket is not None
|
||||
assert per_vad_agent.audio_out_socket is not None
|
||||
|
||||
@@ -91,16 +100,22 @@ async def test_out_socket_creation_failure(zmq_context):
|
||||
"""
|
||||
Test setup failure when the audio output socket cannot be created.
|
||||
"""
|
||||
with patch.object(Agent, "stop", new_callable=AsyncMock) as mock_super_stop:
|
||||
zmq_context.return_value.socket.return_value.bind_to_random_port.side_effect = (
|
||||
zmq.ZMQBindError
|
||||
)
|
||||
per_vad_agent = VADAgent("tcp://localhost:12345", False)
|
||||
zmq_context.return_value.socket.return_value.bind_to_random_port.side_effect = zmq.ZMQBindError
|
||||
per_vad_agent = VADAgent("tcp://localhost:12345", False)
|
||||
per_vad_agent.stop = AsyncMock()
|
||||
per_vad_agent.reset_stream = AsyncMock()
|
||||
per_vad_agent._streaming_loop = AsyncMock()
|
||||
per_vad_agent._connect_audio_out_socket = MagicMock(return_value=None)
|
||||
|
||||
await per_vad_agent.setup()
|
||||
async def swallow_background_task(coro):
|
||||
coro.close()
|
||||
|
||||
assert per_vad_agent.audio_out_socket is None
|
||||
mock_super_stop.assert_called_once()
|
||||
per_vad_agent.add_background_task = swallow_background_task
|
||||
|
||||
await per_vad_agent.setup()
|
||||
|
||||
assert per_vad_agent.audio_out_socket is None
|
||||
per_vad_agent.stop.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -109,6 +124,13 @@ async def test_stop(zmq_context, per_transcription_agent):
|
||||
Test that when the VAD agent is stopped, the sockets are closed correctly.
|
||||
"""
|
||||
per_vad_agent = VADAgent("tcp://localhost:12345", False)
|
||||
per_vad_agent.reset_stream = AsyncMock()
|
||||
per_vad_agent._streaming_loop = AsyncMock()
|
||||
|
||||
async def swallow_background_task(coro):
|
||||
coro.close()
|
||||
|
||||
per_vad_agent.add_background_task = swallow_background_task
|
||||
zmq_context.return_value.socket.return_value.bind_to_random_port.return_value = random.randint(
|
||||
1000,
|
||||
10000,
|
||||
|
||||
Reference in New Issue
Block a user