test: complete VAD unit and integration tests

Including an integration test with real voice audio.

ref: N25B-213
This commit is contained in:
Twirre Meulenbelt
2025-10-23 21:17:41 +02:00
parent ca5e59d029
commit d47074d091
10 changed files with 312 additions and 22 deletions

View File

@@ -24,10 +24,16 @@ uv run fastapi dev src/control_backend/main.py
```
## Testing
Testing happens automatically when opening a merge request to any branch. If you want to manually run the test suite, you can do so by running the following:
Testing happens automatically when opening a merge request to any branch. If you want to manually run the test suite, you can do so by running the following for unit tests:
```bash
uv run --only-group test pytest
uv run --only-group test pytest test/unit
```
Or for integration tests:
```bash
uv run --only-group integration-test pytest test/integration
```
## GitHooks

View File

@@ -20,6 +20,9 @@ dependencies = [
]
[dependency-groups]
integration-test = [
"soundfile>=0.13.1",
]
test = [
"pytest>=8.4.2",
"pytest-asyncio>=1.2.0",

View File

@@ -14,18 +14,28 @@ logger = logging.getLogger(__name__)
class SocketPoller[T]:
def __init__(self, socket: azmq.Socket):
"""
Convenience class for polling a socket for data with a timeout, persisting a zmq.Poller for
multiple usages.
"""
def __init__(self, socket: azmq.Socket, timeout_ms: int = 100):
"""
:param socket: The socket to poll and get data from.
:param timeout_ms: A timeout in milliseconds to wait for data.
"""
self.socket = socket
self.poller = zmq.Poller()
self.poller.register(self.socket, zmq.POLLIN)
self.timeout_ms = timeout_ms
async def poll(self, timeout_ms: int) -> T | None:
async def poll(self, timeout_ms: int | None = None) -> T | None:
"""
Get data from the socket, or None if the timeout is reached.
:param timeout_ms: The number of milliseconds to wait for the socket.
:param timeout_ms: If given, the timeout. Otherwise, `self.timeout_ms` is used.
:return: Data from the socket or None.
"""
timeout_ms = timeout_ms or self.timeout_ms
socks = dict(self.poller.poll(timeout_ms))
if socks.get(self.socket) == zmq.POLLIN:
return await self.socket.recv()
@@ -41,17 +51,16 @@ class Streaming(CyclicBehaviour):
force_reload=False)
self.audio_out_socket = audio_out_socket
self.audio_buffer = np.array([], dtype=np.float32) # TODO: Consider using a Tensor
self.audio_buffer = np.array([], dtype=np.float32)
self.i_since_data = 0 # Used to avoid logging every cycle if audio input stops
self.i_since_speech = 0 # Used to allow small pauses in speech
self.i_since_speech = 100 # Used to allow small pauses in speech
async def run(self) -> None:
timeout_ms = 100
data = await self.audio_in_poller.poll(timeout_ms)
data = await self.audio_in_poller.poll()
if data is None:
if self.i_since_data % 10 == 0:
logger.debug("Failed to receive audio from socket for %d ms.",
timeout_ms*self.i_since_data)
self.audio_in_poller.timeout_ms*(self.i_since_data+1))
self.i_since_data += 1
return
self.i_since_data = 0
@@ -75,7 +84,7 @@ class Streaming(CyclicBehaviour):
# Speech probably ended. Make sure we have a usable amount of data.
if len(self.audio_buffer) >= 3*len(chunk):
logger.debug("Speech ended.")
await self.audio_out_socket.send(self.audio_buffer.tobytes())
await self.audio_out_socket.send(self.audio_buffer[:-2*len(chunk)].tobytes())
# At this point, we know that the speech has ended.
# Prepend the last chunk that had no speech, for a more fluent boundary
@@ -101,10 +110,12 @@ class VADAgent(Agent):
"""
Stop listening to audio, stop publishing audio, close sockets.
"""
self.audio_in_socket.close()
self.audio_in_socket = None
self.audio_out_socket.close()
self.audio_out_socket = None
if self.audio_in_socket is not None:
self.audio_in_socket.close()
self.audio_in_socket = None
if self.audio_out_socket is not None:
self.audio_out_socket.close()
self.audio_out_socket = None
return await super().stop()
def _connect_audio_in_socket(self):

View File

View File

@@ -0,0 +1,97 @@
from unittest.mock import MagicMock, AsyncMock, patch
import pytest
import zmq
from spade.agent import Agent
from control_backend.agents.vad_agent import VADAgent
@pytest.fixture
def zmq_context(mocker):
return mocker.patch("control_backend.agents.vad_agent.zmq_context")
@pytest.fixture
def streaming(mocker):
return mocker.patch("control_backend.agents.vad_agent.Streaming")
@pytest.mark.asyncio
async def test_normal_setup(streaming):
"""
Test that during normal setup, the VAD agent creates a Streaming behavior and creates audio sockets.
"""
vad_agent = VADAgent("tcp://localhost:12345", False)
vad_agent.add_behaviour = MagicMock()
await vad_agent.setup()
streaming.assert_called_once()
vad_agent.add_behaviour.assert_called_once_with(streaming.return_value)
assert vad_agent.audio_in_socket is not None
assert vad_agent.audio_out_socket is not None
@pytest.mark.parametrize("do_bind", [True, False])
def test_in_socket_creation(zmq_context, do_bind: bool):
"""
Test that the VAD agent creates an audio input socket, differentiating between binding and connecting.
"""
vad_agent = VADAgent(f"tcp://{"*" if do_bind else "localhost"}:12345", do_bind)
vad_agent._connect_audio_in_socket()
assert vad_agent.audio_in_socket is not None
zmq_context.socket.assert_called_once_with(zmq.SUB)
zmq_context.socket.return_value.setsockopt_string.assert_called_once_with(zmq.SUBSCRIBE, "")
if do_bind:
zmq_context.socket.return_value.bind.assert_called_once_with("tcp://*:12345")
else:
zmq_context.socket.return_value.connect.assert_called_once_with("tcp://localhost:12345")
def test_out_socket_creation(zmq_context):
"""
Test that the VAD agent creates an audio output socket correctly.
"""
vad_agent = VADAgent("tcp://localhost:12345", False)
vad_agent._connect_audio_out_socket()
assert vad_agent.audio_out_socket is not None
zmq_context.socket.assert_called_once_with(zmq.PUB)
zmq_context.socket.return_value.bind_to_random_port.assert_called_once()
@pytest.mark.asyncio
async def test_out_socket_creation_failure(zmq_context):
"""
Test setup failure when the audio output socket cannot be created.
"""
with patch.object(Agent, "stop", new_callable=AsyncMock) as mock_super_stop:
zmq_context.socket.return_value.bind_to_random_port.side_effect = zmq.ZMQBindError
vad_agent = VADAgent("tcp://localhost:12345", False)
await vad_agent.setup()
assert vad_agent.audio_out_socket is None
mock_super_stop.assert_called_once()
@pytest.mark.asyncio
async def test_stop(zmq_context):
"""
Test that when the VAD agent is stopped, the sockets are closed correctly.
"""
vad_agent = VADAgent("tcp://localhost:12345", False)
await vad_agent.setup()
await vad_agent.stop()
assert zmq_context.socket.return_value.close.call_count == 2
assert vad_agent.audio_in_socket is None
assert vad_agent.audio_out_socket is None

View File

@@ -0,0 +1,57 @@
import os
from unittest.mock import MagicMock, AsyncMock
import pytest
import soundfile as sf
import zmq
from control_backend.agents.vad_agent import Streaming
def get_audio_chunks() -> list[bytes]:
curr_file = os.path.realpath(__file__)
curr_dir = os.path.dirname(curr_file)
file = f"{curr_dir}/speech_with_pauses_16k_1c_float32.wav"
chunk_size = 512
chunks = []
with sf.SoundFile(file, 'r') as f:
assert f.samplerate == 16000
assert f.channels == 1
assert f.subtype == "FLOAT"
while True:
data = f.read(chunk_size, dtype="float32")
if len(data) != chunk_size:
break
chunks.append(data.tobytes())
return chunks
@pytest.mark.asyncio
async def test_real_audio(mocker):
"""
Test the VAD agent with only input and output mocked. Using the real model, using real audio as
input. Ensure that it outputs some fragments with audio.
"""
audio_chunks = get_audio_chunks()
audio_in_socket = AsyncMock()
audio_in_socket.recv.side_effect = audio_chunks
mock_poller: MagicMock = mocker.patch("control_backend.agents.vad_agent.zmq.Poller")
mock_poller.return_value.poll.return_value = [(audio_in_socket, zmq.POLLIN)]
audio_out_socket = AsyncMock()
vad_streamer = Streaming(audio_in_socket, audio_out_socket)
for _ in audio_chunks:
await vad_streamer.run()
audio_out_socket.send.assert_called()
for args in audio_out_socket.send.call_args_list:
assert isinstance(args[0][0], bytes)
assert len(args[0][0]) >= 512*4*3 # Should be at least 3 chunks of audio

View File

@@ -0,0 +1,46 @@
from unittest.mock import AsyncMock, MagicMock
import pytest
import zmq
from control_backend.agents.vad_agent import SocketPoller
@pytest.fixture
def socket():
return AsyncMock()
@pytest.mark.asyncio
async def test_socket_poller_with_data(socket, mocker):
socket_data = b"test"
socket.recv.return_value = socket_data
mock_poller: MagicMock = mocker.patch("control_backend.agents.vad_agent.zmq.Poller")
mock_poller.return_value.poll.return_value = [(socket, zmq.POLLIN)]
poller = SocketPoller(socket)
# Calling `poll` twice to be able to check that the poller is reused
await poller.poll()
data = await poller.poll()
assert data == socket_data
# Ensure that the poller was reused
mock_poller.assert_called_once_with()
mock_poller.return_value.register.assert_called_once_with(socket, zmq.POLLIN)
assert socket.recv.call_count == 2
@pytest.mark.asyncio
async def test_socket_poller_no_data(socket, mocker):
mock_poller: MagicMock = mocker.patch("control_backend.agents.vad_agent.zmq.Poller")
mock_poller.return_value.poll.return_value = []
poller = SocketPoller(socket)
data = await poller.poll()
assert data is None
socket.recv.assert_not_called()

View File

@@ -21,11 +21,13 @@ def streaming(audio_in_socket, audio_out_socket):
return Streaming(audio_in_socket, audio_out_socket)
@pytest.mark.asyncio
async def test_voice_activity_detected(audio_in_socket, audio_out_socket, streaming):
# After three chunks of audio with speech probability of 1.0, then four chunks of audio with
# speech probability of 0.0, it should send a message over the audio out socket
probabilities = [1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0]
async def simulate_streaming_with_probabilities(streaming, probabilities: list[float]):
"""
Simulates a streaming scenario with given VAD model probabilities for testing purposes.
:param streaming: The streaming component to be tested.
:param probabilities: A list of probabilities representing the outputs of the VAD model.
"""
model_item = MagicMock()
model_item.item.side_effect = probabilities
streaming.model = MagicMock()
@@ -38,8 +40,53 @@ async def test_voice_activity_detected(audio_in_socket, audio_out_socket, stream
for _ in probabilities:
await streaming.run()
@pytest.mark.asyncio
async def test_voice_activity_detected(audio_in_socket, audio_out_socket, streaming):
"""
Test a scenario where there is voice activity detected between silences.
:return:
"""
speech_chunk_count = 5
probabilities = [0.0]*5 + [1.0]*speech_chunk_count + [0.0]*5
await simulate_streaming_with_probabilities(streaming, probabilities)
audio_out_socket.send.assert_called_once()
data = audio_out_socket.send.call_args[0][0]
assert isinstance(data, bytes)
# each sample has 512 frames of 4 bytes, expecting 5 chunks (3 with speech, 2 as padding)
assert len(data) == 512*4*5
# each sample has 512 frames of 4 bytes, expecting 7 chunks (5 with speech, 2 as padding)
assert len(data) == 512*4*(speech_chunk_count+2)
@pytest.mark.asyncio
async def test_voice_activity_short_pause(audio_in_socket, audio_out_socket, streaming):
"""
Test a scenario where there is a short pause between speech, checking whether it ignores the
short pause.
"""
speech_chunk_count = 5
probabilities = [0.0]*5 + [1.0]*speech_chunk_count + [0.0] + [1.0]*speech_chunk_count + [0.0]*5
await simulate_streaming_with_probabilities(streaming, probabilities)
audio_out_socket.send.assert_called_once()
data = audio_out_socket.send.call_args[0][0]
assert isinstance(data, bytes)
# Expecting 13 chunks (2*5 with speech, 1 pause between, 2 as padding)
assert len(data) == 512*4*(speech_chunk_count*2+1+2)
@pytest.mark.asyncio
async def test_no_data(audio_in_socket, audio_out_socket, streaming):
"""
Test a scenario where there is no data received. This should not cause errors.
"""
audio_in_poller = AsyncMock()
audio_in_poller.poll.return_value = None
streaming.audio_in_poller = audio_in_poller
assert streaming.i_since_data == 0
await streaming.run()
audio_out_socket.send.assert_not_called()
assert streaming.i_since_data == 1

23
uv.lock generated
View File

@@ -1309,6 +1309,9 @@ dependencies = [
]
[package.dev-dependencies]
integration-test = [
{ name = "soundfile" },
]
test = [
{ name = "pytest" },
{ name = "pytest-asyncio" },
@@ -1333,6 +1336,7 @@ requires-dist = [
]
[package.metadata.requires-dev]
integration-test = [{ name = "soundfile", specifier = ">=0.13.1" }]
test = [
{ name = "pytest", specifier = ">=8.4.2" },
{ name = "pytest-asyncio", specifier = ">=1.2.0" },
@@ -2081,6 +2085,25 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/e9/44/75a9c9421471a6c4805dbf2356f7c181a29c1879239abab1ea2cc8f38b40/sniffio-1.3.1-py3-none-any.whl", hash = "sha256:2f6da418d1f1e0fddd844478f41680e794e6051915791a034ff65e5f100525a2", size = 10235, upload-time = "2024-02-25T23:20:01.196Z" },
]
[[package]]
name = "soundfile"
version = "0.13.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "cffi" },
{ name = "numpy" },
]
sdist = { url = "https://files.pythonhosted.org/packages/e1/41/9b873a8c055582859b239be17902a85339bec6a30ad162f98c9b0288a2cc/soundfile-0.13.1.tar.gz", hash = "sha256:b2c68dab1e30297317080a5b43df57e302584c49e2942defdde0acccc53f0e5b", size = 46156, upload-time = "2025-01-25T09:17:04.831Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/64/28/e2a36573ccbcf3d57c00626a21fe51989380636e821b341d36ccca0c1c3a/soundfile-0.13.1-py2.py3-none-any.whl", hash = "sha256:a23c717560da2cf4c7b5ae1142514e0fd82d6bbd9dfc93a50423447142f2c445", size = 25751, upload-time = "2025-01-25T09:16:44.235Z" },
{ url = "https://files.pythonhosted.org/packages/ea/ab/73e97a5b3cc46bba7ff8650a1504348fa1863a6f9d57d7001c6b67c5f20e/soundfile-0.13.1-py2.py3-none-macosx_10_9_x86_64.whl", hash = "sha256:82dc664d19831933fe59adad199bf3945ad06d84bc111a5b4c0d3089a5b9ec33", size = 1142250, upload-time = "2025-01-25T09:16:47.583Z" },
{ url = "https://files.pythonhosted.org/packages/a0/e5/58fd1a8d7b26fc113af244f966ee3aecf03cb9293cb935daaddc1e455e18/soundfile-0.13.1-py2.py3-none-macosx_11_0_arm64.whl", hash = "sha256:743f12c12c4054921e15736c6be09ac26b3b3d603aef6fd69f9dde68748f2593", size = 1101406, upload-time = "2025-01-25T09:16:49.662Z" },
{ url = "https://files.pythonhosted.org/packages/58/ae/c0e4a53d77cf6e9a04179535766b3321b0b9ced5f70522e4caf9329f0046/soundfile-0.13.1-py2.py3-none-manylinux_2_28_aarch64.whl", hash = "sha256:9c9e855f5a4d06ce4213f31918653ab7de0c5a8d8107cd2427e44b42df547deb", size = 1235729, upload-time = "2025-01-25T09:16:53.018Z" },
{ url = "https://files.pythonhosted.org/packages/57/5e/70bdd9579b35003a489fc850b5047beeda26328053ebadc1fb60f320f7db/soundfile-0.13.1-py2.py3-none-manylinux_2_28_x86_64.whl", hash = "sha256:03267c4e493315294834a0870f31dbb3b28a95561b80b134f0bd3cf2d5f0e618", size = 1313646, upload-time = "2025-01-25T09:16:54.872Z" },
{ url = "https://files.pythonhosted.org/packages/fe/df/8c11dc4dfceda14e3003bb81a0d0edcaaf0796dd7b4f826ea3e532146bba/soundfile-0.13.1-py2.py3-none-win32.whl", hash = "sha256:c734564fab7c5ddf8e9be5bf70bab68042cd17e9c214c06e365e20d64f9a69d5", size = 899881, upload-time = "2025-01-25T09:16:56.663Z" },
{ url = "https://files.pythonhosted.org/packages/14/e9/6b761de83277f2f02ded7e7ea6f07828ec78e4b229b80e4ca55dd205b9dc/soundfile-0.13.1-py2.py3-none-win_amd64.whl", hash = "sha256:1e70a05a0626524a69e9f0f4dd2ec174b4e9567f4d8b6c11d38b5c289be36ee9", size = 1019162, upload-time = "2025-01-25T09:16:59.573Z" },
]
[[package]]
name = "spade"
version = "4.1.0"