test: complete VAD unit and integration tests
Including an integration test with real voice audio. ref: N25B-213
This commit is contained in:
10
README.md
10
README.md
@@ -24,10 +24,16 @@ uv run fastapi dev src/control_backend/main.py
|
||||
```
|
||||
|
||||
## Testing
|
||||
Testing happens automatically when opening a merge request to any branch. If you want to manually run the test suite, you can do so by running the following:
|
||||
Testing happens automatically when opening a merge request to any branch. If you want to manually run the test suite, you can do so by running the following for unit tests:
|
||||
|
||||
```bash
|
||||
uv run --only-group test pytest
|
||||
uv run --only-group test pytest test/unit
|
||||
```
|
||||
|
||||
Or for integration tests:
|
||||
|
||||
```bash
|
||||
uv run --only-group integration-test pytest test/integration
|
||||
```
|
||||
|
||||
## GitHooks
|
||||
|
||||
@@ -20,6 +20,9 @@ dependencies = [
|
||||
]
|
||||
|
||||
[dependency-groups]
|
||||
integration-test = [
|
||||
"soundfile>=0.13.1",
|
||||
]
|
||||
test = [
|
||||
"pytest>=8.4.2",
|
||||
"pytest-asyncio>=1.2.0",
|
||||
|
||||
@@ -14,18 +14,28 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SocketPoller[T]:
|
||||
def __init__(self, socket: azmq.Socket):
|
||||
"""
|
||||
Convenience class for polling a socket for data with a timeout, persisting a zmq.Poller for
|
||||
multiple usages.
|
||||
"""
|
||||
def __init__(self, socket: azmq.Socket, timeout_ms: int = 100):
|
||||
"""
|
||||
:param socket: The socket to poll and get data from.
|
||||
:param timeout_ms: A timeout in milliseconds to wait for data.
|
||||
"""
|
||||
self.socket = socket
|
||||
self.poller = zmq.Poller()
|
||||
self.poller.register(self.socket, zmq.POLLIN)
|
||||
self.timeout_ms = timeout_ms
|
||||
|
||||
async def poll(self, timeout_ms: int) -> T | None:
|
||||
async def poll(self, timeout_ms: int | None = None) -> T | None:
|
||||
"""
|
||||
Get data from the socket, or None if the timeout is reached.
|
||||
|
||||
:param timeout_ms: The number of milliseconds to wait for the socket.
|
||||
:param timeout_ms: If given, the timeout. Otherwise, `self.timeout_ms` is used.
|
||||
:return: Data from the socket or None.
|
||||
"""
|
||||
timeout_ms = timeout_ms or self.timeout_ms
|
||||
socks = dict(self.poller.poll(timeout_ms))
|
||||
if socks.get(self.socket) == zmq.POLLIN:
|
||||
return await self.socket.recv()
|
||||
@@ -41,17 +51,16 @@ class Streaming(CyclicBehaviour):
|
||||
force_reload=False)
|
||||
self.audio_out_socket = audio_out_socket
|
||||
|
||||
self.audio_buffer = np.array([], dtype=np.float32) # TODO: Consider using a Tensor
|
||||
self.audio_buffer = np.array([], dtype=np.float32)
|
||||
self.i_since_data = 0 # Used to avoid logging every cycle if audio input stops
|
||||
self.i_since_speech = 0 # Used to allow small pauses in speech
|
||||
self.i_since_speech = 100 # Used to allow small pauses in speech
|
||||
|
||||
async def run(self) -> None:
|
||||
timeout_ms = 100
|
||||
data = await self.audio_in_poller.poll(timeout_ms)
|
||||
data = await self.audio_in_poller.poll()
|
||||
if data is None:
|
||||
if self.i_since_data % 10 == 0:
|
||||
logger.debug("Failed to receive audio from socket for %d ms.",
|
||||
timeout_ms*self.i_since_data)
|
||||
self.audio_in_poller.timeout_ms*(self.i_since_data+1))
|
||||
self.i_since_data += 1
|
||||
return
|
||||
self.i_since_data = 0
|
||||
@@ -75,7 +84,7 @@ class Streaming(CyclicBehaviour):
|
||||
# Speech probably ended. Make sure we have a usable amount of data.
|
||||
if len(self.audio_buffer) >= 3*len(chunk):
|
||||
logger.debug("Speech ended.")
|
||||
await self.audio_out_socket.send(self.audio_buffer.tobytes())
|
||||
await self.audio_out_socket.send(self.audio_buffer[:-2*len(chunk)].tobytes())
|
||||
|
||||
# At this point, we know that the speech has ended.
|
||||
# Prepend the last chunk that had no speech, for a more fluent boundary
|
||||
@@ -101,8 +110,10 @@ class VADAgent(Agent):
|
||||
"""
|
||||
Stop listening to audio, stop publishing audio, close sockets.
|
||||
"""
|
||||
if self.audio_in_socket is not None:
|
||||
self.audio_in_socket.close()
|
||||
self.audio_in_socket = None
|
||||
if self.audio_out_socket is not None:
|
||||
self.audio_out_socket.close()
|
||||
self.audio_out_socket = None
|
||||
return await super().stop()
|
||||
|
||||
Binary file not shown.
97
test/integration/agents/vad_agent/test_vad_agent.py
Normal file
97
test/integration/agents/vad_agent/test_vad_agent.py
Normal file
@@ -0,0 +1,97 @@
|
||||
from unittest.mock import MagicMock, AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
import zmq
|
||||
from spade.agent import Agent
|
||||
|
||||
from control_backend.agents.vad_agent import VADAgent
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def zmq_context(mocker):
|
||||
return mocker.patch("control_backend.agents.vad_agent.zmq_context")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def streaming(mocker):
|
||||
return mocker.patch("control_backend.agents.vad_agent.Streaming")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_normal_setup(streaming):
|
||||
"""
|
||||
Test that during normal setup, the VAD agent creates a Streaming behavior and creates audio sockets.
|
||||
"""
|
||||
vad_agent = VADAgent("tcp://localhost:12345", False)
|
||||
vad_agent.add_behaviour = MagicMock()
|
||||
|
||||
await vad_agent.setup()
|
||||
|
||||
streaming.assert_called_once()
|
||||
vad_agent.add_behaviour.assert_called_once_with(streaming.return_value)
|
||||
assert vad_agent.audio_in_socket is not None
|
||||
assert vad_agent.audio_out_socket is not None
|
||||
|
||||
|
||||
@pytest.mark.parametrize("do_bind", [True, False])
|
||||
def test_in_socket_creation(zmq_context, do_bind: bool):
|
||||
"""
|
||||
Test that the VAD agent creates an audio input socket, differentiating between binding and connecting.
|
||||
"""
|
||||
vad_agent = VADAgent(f"tcp://{"*" if do_bind else "localhost"}:12345", do_bind)
|
||||
|
||||
vad_agent._connect_audio_in_socket()
|
||||
|
||||
assert vad_agent.audio_in_socket is not None
|
||||
|
||||
zmq_context.socket.assert_called_once_with(zmq.SUB)
|
||||
zmq_context.socket.return_value.setsockopt_string.assert_called_once_with(zmq.SUBSCRIBE, "")
|
||||
|
||||
if do_bind:
|
||||
zmq_context.socket.return_value.bind.assert_called_once_with("tcp://*:12345")
|
||||
else:
|
||||
zmq_context.socket.return_value.connect.assert_called_once_with("tcp://localhost:12345")
|
||||
|
||||
|
||||
def test_out_socket_creation(zmq_context):
|
||||
"""
|
||||
Test that the VAD agent creates an audio output socket correctly.
|
||||
"""
|
||||
vad_agent = VADAgent("tcp://localhost:12345", False)
|
||||
|
||||
vad_agent._connect_audio_out_socket()
|
||||
|
||||
assert vad_agent.audio_out_socket is not None
|
||||
|
||||
zmq_context.socket.assert_called_once_with(zmq.PUB)
|
||||
zmq_context.socket.return_value.bind_to_random_port.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_out_socket_creation_failure(zmq_context):
|
||||
"""
|
||||
Test setup failure when the audio output socket cannot be created.
|
||||
"""
|
||||
with patch.object(Agent, "stop", new_callable=AsyncMock) as mock_super_stop:
|
||||
zmq_context.socket.return_value.bind_to_random_port.side_effect = zmq.ZMQBindError
|
||||
vad_agent = VADAgent("tcp://localhost:12345", False)
|
||||
|
||||
await vad_agent.setup()
|
||||
|
||||
assert vad_agent.audio_out_socket is None
|
||||
mock_super_stop.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stop(zmq_context):
|
||||
"""
|
||||
Test that when the VAD agent is stopped, the sockets are closed correctly.
|
||||
"""
|
||||
vad_agent = VADAgent("tcp://localhost:12345", False)
|
||||
|
||||
await vad_agent.setup()
|
||||
await vad_agent.stop()
|
||||
|
||||
assert zmq_context.socket.return_value.close.call_count == 2
|
||||
assert vad_agent.audio_in_socket is None
|
||||
assert vad_agent.audio_out_socket is None
|
||||
57
test/integration/agents/vad_agent/test_vad_with_audio.py
Normal file
57
test/integration/agents/vad_agent/test_vad_with_audio.py
Normal file
@@ -0,0 +1,57 @@
|
||||
import os
|
||||
from unittest.mock import MagicMock, AsyncMock
|
||||
|
||||
import pytest
|
||||
import soundfile as sf
|
||||
import zmq
|
||||
|
||||
from control_backend.agents.vad_agent import Streaming
|
||||
|
||||
|
||||
def get_audio_chunks() -> list[bytes]:
|
||||
curr_file = os.path.realpath(__file__)
|
||||
curr_dir = os.path.dirname(curr_file)
|
||||
file = f"{curr_dir}/speech_with_pauses_16k_1c_float32.wav"
|
||||
|
||||
chunk_size = 512
|
||||
|
||||
chunks = []
|
||||
|
||||
with sf.SoundFile(file, 'r') as f:
|
||||
assert f.samplerate == 16000
|
||||
assert f.channels == 1
|
||||
assert f.subtype == "FLOAT"
|
||||
|
||||
while True:
|
||||
data = f.read(chunk_size, dtype="float32")
|
||||
if len(data) != chunk_size:
|
||||
break
|
||||
|
||||
chunks.append(data.tobytes())
|
||||
|
||||
return chunks
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_real_audio(mocker):
|
||||
"""
|
||||
Test the VAD agent with only input and output mocked. Using the real model, using real audio as
|
||||
input. Ensure that it outputs some fragments with audio.
|
||||
"""
|
||||
audio_chunks = get_audio_chunks()
|
||||
audio_in_socket = AsyncMock()
|
||||
audio_in_socket.recv.side_effect = audio_chunks
|
||||
|
||||
mock_poller: MagicMock = mocker.patch("control_backend.agents.vad_agent.zmq.Poller")
|
||||
mock_poller.return_value.poll.return_value = [(audio_in_socket, zmq.POLLIN)]
|
||||
|
||||
audio_out_socket = AsyncMock()
|
||||
|
||||
vad_streamer = Streaming(audio_in_socket, audio_out_socket)
|
||||
for _ in audio_chunks:
|
||||
await vad_streamer.run()
|
||||
|
||||
audio_out_socket.send.assert_called()
|
||||
for args in audio_out_socket.send.call_args_list:
|
||||
assert isinstance(args[0][0], bytes)
|
||||
assert len(args[0][0]) >= 512*4*3 # Should be at least 3 chunks of audio
|
||||
46
test/unit/agents/test_vad_socket_poller.py
Normal file
46
test/unit/agents/test_vad_socket_poller.py
Normal file
@@ -0,0 +1,46 @@
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
import zmq
|
||||
|
||||
from control_backend.agents.vad_agent import SocketPoller
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def socket():
|
||||
return AsyncMock()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_socket_poller_with_data(socket, mocker):
|
||||
socket_data = b"test"
|
||||
socket.recv.return_value = socket_data
|
||||
|
||||
mock_poller: MagicMock = mocker.patch("control_backend.agents.vad_agent.zmq.Poller")
|
||||
mock_poller.return_value.poll.return_value = [(socket, zmq.POLLIN)]
|
||||
|
||||
poller = SocketPoller(socket)
|
||||
# Calling `poll` twice to be able to check that the poller is reused
|
||||
await poller.poll()
|
||||
data = await poller.poll()
|
||||
|
||||
assert data == socket_data
|
||||
|
||||
# Ensure that the poller was reused
|
||||
mock_poller.assert_called_once_with()
|
||||
mock_poller.return_value.register.assert_called_once_with(socket, zmq.POLLIN)
|
||||
|
||||
assert socket.recv.call_count == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_socket_poller_no_data(socket, mocker):
|
||||
mock_poller: MagicMock = mocker.patch("control_backend.agents.vad_agent.zmq.Poller")
|
||||
mock_poller.return_value.poll.return_value = []
|
||||
|
||||
poller = SocketPoller(socket)
|
||||
data = await poller.poll()
|
||||
|
||||
assert data is None
|
||||
|
||||
socket.recv.assert_not_called()
|
||||
@@ -21,11 +21,13 @@ def streaming(audio_in_socket, audio_out_socket):
|
||||
return Streaming(audio_in_socket, audio_out_socket)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_voice_activity_detected(audio_in_socket, audio_out_socket, streaming):
|
||||
# After three chunks of audio with speech probability of 1.0, then four chunks of audio with
|
||||
# speech probability of 0.0, it should send a message over the audio out socket
|
||||
probabilities = [1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0]
|
||||
async def simulate_streaming_with_probabilities(streaming, probabilities: list[float]):
|
||||
"""
|
||||
Simulates a streaming scenario with given VAD model probabilities for testing purposes.
|
||||
|
||||
:param streaming: The streaming component to be tested.
|
||||
:param probabilities: A list of probabilities representing the outputs of the VAD model.
|
||||
"""
|
||||
model_item = MagicMock()
|
||||
model_item.item.side_effect = probabilities
|
||||
streaming.model = MagicMock()
|
||||
@@ -38,8 +40,53 @@ async def test_voice_activity_detected(audio_in_socket, audio_out_socket, stream
|
||||
for _ in probabilities:
|
||||
await streaming.run()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_voice_activity_detected(audio_in_socket, audio_out_socket, streaming):
|
||||
"""
|
||||
Test a scenario where there is voice activity detected between silences.
|
||||
:return:
|
||||
"""
|
||||
speech_chunk_count = 5
|
||||
probabilities = [0.0]*5 + [1.0]*speech_chunk_count + [0.0]*5
|
||||
await simulate_streaming_with_probabilities(streaming, probabilities)
|
||||
|
||||
audio_out_socket.send.assert_called_once()
|
||||
data = audio_out_socket.send.call_args[0][0]
|
||||
assert isinstance(data, bytes)
|
||||
# each sample has 512 frames of 4 bytes, expecting 5 chunks (3 with speech, 2 as padding)
|
||||
assert len(data) == 512*4*5
|
||||
# each sample has 512 frames of 4 bytes, expecting 7 chunks (5 with speech, 2 as padding)
|
||||
assert len(data) == 512*4*(speech_chunk_count+2)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_voice_activity_short_pause(audio_in_socket, audio_out_socket, streaming):
|
||||
"""
|
||||
Test a scenario where there is a short pause between speech, checking whether it ignores the
|
||||
short pause.
|
||||
"""
|
||||
speech_chunk_count = 5
|
||||
probabilities = [0.0]*5 + [1.0]*speech_chunk_count + [0.0] + [1.0]*speech_chunk_count + [0.0]*5
|
||||
await simulate_streaming_with_probabilities(streaming, probabilities)
|
||||
|
||||
audio_out_socket.send.assert_called_once()
|
||||
data = audio_out_socket.send.call_args[0][0]
|
||||
assert isinstance(data, bytes)
|
||||
# Expecting 13 chunks (2*5 with speech, 1 pause between, 2 as padding)
|
||||
assert len(data) == 512*4*(speech_chunk_count*2+1+2)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_data(audio_in_socket, audio_out_socket, streaming):
|
||||
"""
|
||||
Test a scenario where there is no data received. This should not cause errors.
|
||||
"""
|
||||
audio_in_poller = AsyncMock()
|
||||
audio_in_poller.poll.return_value = None
|
||||
streaming.audio_in_poller = audio_in_poller
|
||||
|
||||
assert streaming.i_since_data == 0
|
||||
|
||||
await streaming.run()
|
||||
|
||||
audio_out_socket.send.assert_not_called()
|
||||
assert streaming.i_since_data == 1
|
||||
|
||||
23
uv.lock
generated
23
uv.lock
generated
@@ -1309,6 +1309,9 @@ dependencies = [
|
||||
]
|
||||
|
||||
[package.dev-dependencies]
|
||||
integration-test = [
|
||||
{ name = "soundfile" },
|
||||
]
|
||||
test = [
|
||||
{ name = "pytest" },
|
||||
{ name = "pytest-asyncio" },
|
||||
@@ -1333,6 +1336,7 @@ requires-dist = [
|
||||
]
|
||||
|
||||
[package.metadata.requires-dev]
|
||||
integration-test = [{ name = "soundfile", specifier = ">=0.13.1" }]
|
||||
test = [
|
||||
{ name = "pytest", specifier = ">=8.4.2" },
|
||||
{ name = "pytest-asyncio", specifier = ">=1.2.0" },
|
||||
@@ -2081,6 +2085,25 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/e9/44/75a9c9421471a6c4805dbf2356f7c181a29c1879239abab1ea2cc8f38b40/sniffio-1.3.1-py3-none-any.whl", hash = "sha256:2f6da418d1f1e0fddd844478f41680e794e6051915791a034ff65e5f100525a2", size = 10235, upload-time = "2024-02-25T23:20:01.196Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "soundfile"
|
||||
version = "0.13.1"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "cffi" },
|
||||
{ name = "numpy" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/e1/41/9b873a8c055582859b239be17902a85339bec6a30ad162f98c9b0288a2cc/soundfile-0.13.1.tar.gz", hash = "sha256:b2c68dab1e30297317080a5b43df57e302584c49e2942defdde0acccc53f0e5b", size = 46156, upload-time = "2025-01-25T09:17:04.831Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/64/28/e2a36573ccbcf3d57c00626a21fe51989380636e821b341d36ccca0c1c3a/soundfile-0.13.1-py2.py3-none-any.whl", hash = "sha256:a23c717560da2cf4c7b5ae1142514e0fd82d6bbd9dfc93a50423447142f2c445", size = 25751, upload-time = "2025-01-25T09:16:44.235Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/ea/ab/73e97a5b3cc46bba7ff8650a1504348fa1863a6f9d57d7001c6b67c5f20e/soundfile-0.13.1-py2.py3-none-macosx_10_9_x86_64.whl", hash = "sha256:82dc664d19831933fe59adad199bf3945ad06d84bc111a5b4c0d3089a5b9ec33", size = 1142250, upload-time = "2025-01-25T09:16:47.583Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/a0/e5/58fd1a8d7b26fc113af244f966ee3aecf03cb9293cb935daaddc1e455e18/soundfile-0.13.1-py2.py3-none-macosx_11_0_arm64.whl", hash = "sha256:743f12c12c4054921e15736c6be09ac26b3b3d603aef6fd69f9dde68748f2593", size = 1101406, upload-time = "2025-01-25T09:16:49.662Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/58/ae/c0e4a53d77cf6e9a04179535766b3321b0b9ced5f70522e4caf9329f0046/soundfile-0.13.1-py2.py3-none-manylinux_2_28_aarch64.whl", hash = "sha256:9c9e855f5a4d06ce4213f31918653ab7de0c5a8d8107cd2427e44b42df547deb", size = 1235729, upload-time = "2025-01-25T09:16:53.018Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/57/5e/70bdd9579b35003a489fc850b5047beeda26328053ebadc1fb60f320f7db/soundfile-0.13.1-py2.py3-none-manylinux_2_28_x86_64.whl", hash = "sha256:03267c4e493315294834a0870f31dbb3b28a95561b80b134f0bd3cf2d5f0e618", size = 1313646, upload-time = "2025-01-25T09:16:54.872Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/fe/df/8c11dc4dfceda14e3003bb81a0d0edcaaf0796dd7b4f826ea3e532146bba/soundfile-0.13.1-py2.py3-none-win32.whl", hash = "sha256:c734564fab7c5ddf8e9be5bf70bab68042cd17e9c214c06e365e20d64f9a69d5", size = 899881, upload-time = "2025-01-25T09:16:56.663Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/14/e9/6b761de83277f2f02ded7e7ea6f07828ec78e4b229b80e4ca55dd205b9dc/soundfile-0.13.1-py2.py3-none-win_amd64.whl", hash = "sha256:1e70a05a0626524a69e9f0f4dd2ec174b4e9567f4d8b6c11d38b5c289be36ee9", size = 1019162, upload-time = "2025-01-25T09:16:59.573Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "spade"
|
||||
version = "4.1.0"
|
||||
|
||||
Reference in New Issue
Block a user