214 lines
6.9 KiB
Python
214 lines
6.9 KiB
Python
"""
|
|
This program has been developed by students from the bachelor Computer Science at Utrecht
|
|
University within the Software Project course.
|
|
© Copyright Utrecht University (Department of Information and Computing Sciences)
|
|
"""
|
|
|
|
import random
|
|
from unittest.mock import AsyncMock, MagicMock
|
|
|
|
import pytest
|
|
import zmq
|
|
|
|
from control_backend.agents.perception.vad_agent import VADAgent
|
|
from control_backend.schemas.program_status import PROGRAM_STATUS, ProgramStatus
|
|
|
|
|
|
@pytest.fixture
def zmq_context(mocker):
    """Swap the VAD agent module's ``azmq.Context.instance`` for a mock.

    The fixture value is the patched ``instance`` callable. The context object the agent obtains
    is its ``return_value``, and every socket created from that context is the single shared
    ``return_value.socket.return_value`` mock.
    """
    patched_instance = mocker.patch(
        "control_backend.agents.perception.vad_agent.azmq.Context.instance"
    )
    patched_instance.return_value = MagicMock()
    return patched_instance
|
|
|
|
|
|
@pytest.fixture
def per_transcription_agent(mocker):
    """Replace the ``TranscriptionAgent`` class seen by the VAD agent module with an autospec mock.

    The fixture value stands in for the class itself, so the instance the VAD agent constructs is
    available as ``per_transcription_agent.return_value``.
    """
    patch_target = "control_backend.agents.perception.vad_agent.TranscriptionAgent"
    transcription_agent_cls = mocker.patch(patch_target, autospec=True)
    return transcription_agent_cls
|
|
|
|
|
|
@pytest.fixture(autouse=True)
def torch_load(mocker):
    """Stub the ``torch`` module used by the VAD agent, automatically, for every test here.

    ``torch.hub.load`` yields a ``(mock_model, None)`` pair instead of fetching a real model, and
    ``torch.from_numpy`` hands its argument back unchanged.
    """
    torch_stub = mocker.patch("control_backend.agents.perception.vad_agent.torch")
    torch_stub.hub.load.return_value = (MagicMock(), None)
    torch_stub.from_numpy.side_effect = lambda array: array
    return torch_stub
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_normal_setup(zmq_context, per_transcription_agent):
    """
    Test that during normal setup, the VAD agent creates a Streaming behavior and creates audio
    sockets, and starts the TranscriptionAgent without loading real models or opening real ZMQ
    sockets.
    """
    per_vad_agent = VADAgent("tcp://localhost:12345", False)
    # Stub the async helpers the same way test_stop does, so setup() only exercises its own
    # wiring against the mocked ZMQ context.
    per_vad_agent._reset_stream = AsyncMock()
    per_vad_agent._streaming_loop = AsyncMock()

    def swallow_background_task(coro):
        # Close the coroutine instead of scheduling it, so no background task outlives the test.
        coro.close()

    per_vad_agent.add_behavior = swallow_background_task
    # Mirror test_stop: give bind_to_random_port a concrete int in case setup() uses the port.
    zmq_context.return_value.socket.return_value.bind_to_random_port.return_value = 12346

    await per_vad_agent.setup()

    per_transcription_agent.assert_called_once()
    per_transcription_agent.return_value.start.assert_called_once()
    per_vad_agent._streaming_loop.assert_called_once()
    assert per_vad_agent.audio_in_socket is not None
    assert per_vad_agent.audio_out_socket is not None
|
|
|
|
|
|
@pytest.mark.parametrize("do_bind", [True, False])
def test_in_socket_creation(zmq_context, do_bind: bool):
    """
    Check that the audio input socket is a SUB socket subscribed to all topics, and that it is
    bound to the address when ``do_bind`` is set and connected to it otherwise.
    """
    host = "*" if do_bind else "localhost"
    expected_endpoint = f"tcp://{host}:12345"
    per_vad_agent = VADAgent(expected_endpoint, do_bind)

    per_vad_agent._connect_audio_in_socket()

    assert per_vad_agent.audio_in_socket is not None

    mocked_ctx = zmq_context.return_value
    mocked_sock = mocked_ctx.socket.return_value
    mocked_ctx.socket.assert_called_once_with(zmq.SUB)
    mocked_sock.setsockopt_string.assert_called_once_with(zmq.SUBSCRIBE, "")

    # Only the call matching the requested mode is checked.
    endpoint_call = mocked_sock.bind if do_bind else mocked_sock.connect
    endpoint_call.assert_called_once_with(expected_endpoint)
|
|
|
|
|
|
def test_out_socket_creation(zmq_context):
    """
    Check that the audio output socket is a PUB socket bound to ``inproc://vad_stream``.
    """
    per_vad_agent = VADAgent("tcp://localhost:12345", False)
    per_vad_agent._connect_audio_out_socket()

    assert per_vad_agent.audio_out_socket is not None

    mocked_ctx = zmq_context.return_value
    mocked_ctx.socket.assert_called_once_with(zmq.PUB)
    mocked_ctx.socket.return_value.bind.assert_called_once_with("inproc://vad_stream")
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_out_socket_creation_failure(zmq_context):
    """
    Check that setup() stops the agent when the audio output socket ends up missing.
    """
    mocked_sock = zmq_context.return_value.socket.return_value
    mocked_sock.bind_to_random_port.side_effect = zmq.ZMQBindError

    per_vad_agent = VADAgent("tcp://localhost:12345", False)
    # Replace the async lifecycle helpers so only setup()'s own control flow runs.
    for async_attr in ("stop", "_reset_stream", "_streaming_loop"):
        setattr(per_vad_agent, async_attr, AsyncMock())
    # The mocked connector never assigns audio_out_socket, simulating a failed connection.
    per_vad_agent._connect_audio_out_socket = MagicMock(return_value=None)
    # Close background coroutines instead of scheduling them.
    per_vad_agent.add_behavior = lambda coro: coro.close()

    await per_vad_agent.setup()

    assert per_vad_agent.audio_out_socket is None
    per_vad_agent.stop.assert_called_once()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_stop(zmq_context, per_transcription_agent):
    """
    Check that stopping a set-up VAD agent closes its sockets (two ``close()`` calls on the shared
    mocked socket) and clears both audio socket references.
    """
    per_vad_agent = VADAgent("tcp://localhost:12345", False)
    per_vad_agent._reset_stream = AsyncMock()
    per_vad_agent._streaming_loop = AsyncMock()
    # Close background coroutines instead of scheduling them.
    per_vad_agent.add_behavior = lambda coro: coro.close()

    mocked_sock = zmq_context.return_value.socket.return_value
    mocked_sock.bind_to_random_port.return_value = random.randint(1000, 10000)

    await per_vad_agent.setup()
    await per_vad_agent.stop()

    assert mocked_sock.close.call_count == 2
    assert per_vad_agent.audio_in_socket is None
    assert per_vad_agent.audio_out_socket is None
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_application_startup_complete(zmq_context):
    """Check that a RUNNING program status resets the stream and the status socket gets closed."""
    vad_agent = VADAgent("tcp://localhost:12345", False)
    vad_agent._running = True
    vad_agent._reset_stream = AsyncMock()

    status_socket = AsyncMock()
    # close() is expected to be called synchronously, so it must not return a coroutine.
    status_socket.close = MagicMock()
    status_socket.recv_multipart.side_effect = [(PROGRAM_STATUS, ProgramStatus.RUNNING.value)]
    vad_agent.program_sub_socket = status_socket

    await vad_agent._status_loop()

    vad_agent._reset_stream.assert_called_once()
    status_socket.close.assert_called_once()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_application_other_status(zmq_context):
    """
    Check that it does nothing when the internal communication message is a status update, but not
    running.
    """
    vad_agent = VADAgent("tcp://localhost:12345", False)
    vad_agent._running = True
    vad_agent._reset_stream = AsyncMock()
    vad_agent.program_sub_socket = AsyncMock()
    # As in test_application_startup_complete, close() is a synchronous call. A plain mock keeps
    # it from returning an un-awaited coroutine if the status loop closes the socket on exit.
    vad_agent.program_sub_socket.close = MagicMock()

    vad_agent.program_sub_socket.recv_multipart.side_effect = [
        (PROGRAM_STATUS, ProgramStatus.STARTING.value),
        (PROGRAM_STATUS, ProgramStatus.STOPPING.value),
    ]
    try:
        # Raises StopAsyncIteration the third time it calls `program_sub_socket.recv_multipart`,
        # once the side_effect list is exhausted; that is what ends the loop in this test.
        await vad_agent._status_loop()
    except StopAsyncIteration:
        pass

    vad_agent._reset_stream.assert_not_called()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_application_message_other(zmq_context):
    """
    Check that it does nothing when there's an internal communication message that is not a status
    update.
    """
    vad_agent = VADAgent("tcp://localhost:12345", False)
    vad_agent._running = True
    vad_agent._reset_stream = AsyncMock()
    vad_agent.program_sub_socket = AsyncMock()
    # As in test_application_startup_complete, close() is a synchronous call. A plain mock keeps
    # it from returning an un-awaited coroutine if the status loop closes the socket on exit.
    vad_agent.program_sub_socket.close = MagicMock()

    vad_agent.program_sub_socket.recv_multipart.side_effect = [(b"internal/other", b"Whatever")]

    try:
        # Raises StopAsyncIteration the second time it calls `program_sub_socket.recv_multipart`,
        # once the side_effect list is exhausted; that is what ends the loop in this test.
        await vad_agent._status_loop()
    except StopAsyncIteration:
        pass

    vad_agent._reset_stream.assert_not_called()
|