diff --git a/README.md b/README.md index c2a8702..57b052d 100644 --- a/README.md +++ b/README.md @@ -24,10 +24,16 @@ uv run fastapi dev src/control_backend/main.py ``` ## Testing -Testing happens automatically when opening a merge request to any branch. If you want to manually run the test suite, you can do so by running the following: +Testing happens automatically when opening a merge request to any branch. If you want to manually run the test suite, you can do so by running the following for unit tests: ```bash -uv run --only-group test pytest +uv run --only-group test pytest test/unit +``` + +Or for integration tests: + +```bash +uv run --only-group integration-test pytest test/integration ``` ## GitHooks diff --git a/pyproject.toml b/pyproject.toml index 6776668..7fadc00 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,6 +20,9 @@ dependencies = [ ] [dependency-groups] +integration-test = [ + "soundfile>=0.13.1", +] test = [ "pytest>=8.4.2", "pytest-asyncio>=1.2.0", diff --git a/src/control_backend/agents/vad_agent.py b/src/control_backend/agents/vad_agent.py index f0325c2..1e08502 100644 --- a/src/control_backend/agents/vad_agent.py +++ b/src/control_backend/agents/vad_agent.py @@ -14,18 +14,28 @@ logger = logging.getLogger(__name__) class SocketPoller[T]: - def __init__(self, socket: azmq.Socket): + """ + Convenience class for polling a socket for data with a timeout, persisting a zmq.Poller for + multiple usages. + """ + def __init__(self, socket: azmq.Socket, timeout_ms: int = 100): + """ + :param socket: The socket to poll and get data from. + :param timeout_ms: A timeout in milliseconds to wait for data. + """ self.socket = socket self.poller = zmq.Poller() self.poller.register(self.socket, zmq.POLLIN) + self.timeout_ms = timeout_ms - async def poll(self, timeout_ms: int) -> T | None: + async def poll(self, timeout_ms: int | None = None) -> T | None: """ Get data from the socket, or None if the timeout is reached. - :param timeout_ms: The number of milliseconds to wait for the socket. + :param timeout_ms: If given, the timeout. Otherwise, `self.timeout_ms` is used. :return: Data from the socket or None. """ + timeout_ms = timeout_ms or self.timeout_ms socks = dict(self.poller.poll(timeout_ms)) if socks.get(self.socket) == zmq.POLLIN: return await self.socket.recv() @@ -41,17 +51,16 @@ class Streaming(CyclicBehaviour): force_reload=False) self.audio_out_socket = audio_out_socket - self.audio_buffer = np.array([], dtype=np.float32) # TODO: Consider using a Tensor + self.audio_buffer = np.array([], dtype=np.float32) self.i_since_data = 0 # Used to avoid logging every cycle if audio input stops - self.i_since_speech = 0 # Used to allow small pauses in speech + self.i_since_speech = 100 # Used to allow small pauses in speech async def run(self) -> None: - timeout_ms = 100 - data = await self.audio_in_poller.poll(timeout_ms) + data = await self.audio_in_poller.poll() if data is None: if self.i_since_data % 10 == 0: logger.debug("Failed to receive audio from socket for %d ms.", - timeout_ms*self.i_since_data) + self.audio_in_poller.timeout_ms*(self.i_since_data+1)) self.i_since_data += 1 return self.i_since_data = 0 @@ -75,7 +84,7 @@ class Streaming(CyclicBehaviour): # Speech probably ended. Make sure we have a usable amount of data. if len(self.audio_buffer) >= 3*len(chunk): logger.debug("Speech ended.") - await self.audio_out_socket.send(self.audio_buffer.tobytes()) + await self.audio_out_socket.send(self.audio_buffer[:-2*len(chunk)].tobytes()) # At this point, we know that the speech has ended. # Prepend the last chunk that had no speech, for a more fluent boundary @@ -101,10 +110,12 @@ class VADAgent(Agent): """ Stop listening to audio, stop publishing audio, close sockets. """ - self.audio_in_socket.close() - self.audio_in_socket = None - self.audio_out_socket.close() - self.audio_out_socket = None + if self.audio_in_socket is not None: + self.audio_in_socket.close() + self.audio_in_socket = None + if self.audio_out_socket is not None: + self.audio_out_socket.close() + self.audio_out_socket = None return await super().stop() def _connect_audio_in_socket(self): diff --git a/test/__init__.py b/test/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/test/integration/agents/vad_agent/speech_with_pauses_16k_1c_float32.wav b/test/integration/agents/vad_agent/speech_with_pauses_16k_1c_float32.wav new file mode 100644 index 0000000..530bc0a Binary files /dev/null and b/test/integration/agents/vad_agent/speech_with_pauses_16k_1c_float32.wav differ diff --git a/test/integration/agents/vad_agent/test_vad_agent.py b/test/integration/agents/vad_agent/test_vad_agent.py new file mode 100644 index 0000000..6357918 --- /dev/null +++ b/test/integration/agents/vad_agent/test_vad_agent.py @@ -0,0 +1,97 @@ +from unittest.mock import MagicMock, AsyncMock, patch + +import pytest +import zmq +from spade.agent import Agent + +from control_backend.agents.vad_agent import VADAgent + + +@pytest.fixture +def zmq_context(mocker): + return mocker.patch("control_backend.agents.vad_agent.zmq_context") + + +@pytest.fixture +def streaming(mocker): + return mocker.patch("control_backend.agents.vad_agent.Streaming") + + +@pytest.mark.asyncio +async def test_normal_setup(streaming): + """ + Test that during normal setup, the VAD agent creates a Streaming behavior and creates audio sockets. + """ + vad_agent = VADAgent("tcp://localhost:12345", False) + vad_agent.add_behaviour = MagicMock() + + await vad_agent.setup() + + streaming.assert_called_once() + vad_agent.add_behaviour.assert_called_once_with(streaming.return_value) + assert vad_agent.audio_in_socket is not None + assert vad_agent.audio_out_socket is not None + + +@pytest.mark.parametrize("do_bind", [True, False]) +def test_in_socket_creation(zmq_context, do_bind: bool): + """ + Test that the VAD agent creates an audio input socket, differentiating between binding and connecting. + """ + vad_agent = VADAgent(f"tcp://{"*" if do_bind else "localhost"}:12345", do_bind) + + vad_agent._connect_audio_in_socket() + + assert vad_agent.audio_in_socket is not None + + zmq_context.socket.assert_called_once_with(zmq.SUB) + zmq_context.socket.return_value.setsockopt_string.assert_called_once_with(zmq.SUBSCRIBE, "") + + if do_bind: + zmq_context.socket.return_value.bind.assert_called_once_with("tcp://*:12345") + else: + zmq_context.socket.return_value.connect.assert_called_once_with("tcp://localhost:12345") + + +def test_out_socket_creation(zmq_context): + """ + Test that the VAD agent creates an audio output socket correctly. + """ + vad_agent = VADAgent("tcp://localhost:12345", False) + + vad_agent._connect_audio_out_socket() + + assert vad_agent.audio_out_socket is not None + + zmq_context.socket.assert_called_once_with(zmq.PUB) + zmq_context.socket.return_value.bind_to_random_port.assert_called_once() + + +@pytest.mark.asyncio +async def test_out_socket_creation_failure(zmq_context): + """ + Test setup failure when the audio output socket cannot be created. + """ + with patch.object(Agent, "stop", new_callable=AsyncMock) as mock_super_stop: + zmq_context.socket.return_value.bind_to_random_port.side_effect = zmq.ZMQBindError + vad_agent = VADAgent("tcp://localhost:12345", False) + + await vad_agent.setup() + + assert vad_agent.audio_out_socket is None + mock_super_stop.assert_called_once() + + +@pytest.mark.asyncio +async def test_stop(zmq_context): + """ + Test that when the VAD agent is stopped, the sockets are closed correctly. + """ + vad_agent = VADAgent("tcp://localhost:12345", False) + + await vad_agent.setup() + await vad_agent.stop() + + assert zmq_context.socket.return_value.close.call_count == 2 + assert vad_agent.audio_in_socket is None + assert vad_agent.audio_out_socket is None diff --git a/test/integration/agents/vad_agent/test_vad_with_audio.py b/test/integration/agents/vad_agent/test_vad_with_audio.py new file mode 100644 index 0000000..fa0b781 --- /dev/null +++ b/test/integration/agents/vad_agent/test_vad_with_audio.py @@ -0,0 +1,57 @@ +import os +from unittest.mock import MagicMock, AsyncMock + +import pytest +import soundfile as sf +import zmq + +from control_backend.agents.vad_agent import Streaming + + +def get_audio_chunks() -> list[bytes]: + curr_file = os.path.realpath(__file__) + curr_dir = os.path.dirname(curr_file) + file = f"{curr_dir}/speech_with_pauses_16k_1c_float32.wav" + + chunk_size = 512 + + chunks = [] + + with sf.SoundFile(file, 'r') as f: + assert f.samplerate == 16000 + assert f.channels == 1 + assert f.subtype == "FLOAT" + + while True: + data = f.read(chunk_size, dtype="float32") + if len(data) != chunk_size: + break + + chunks.append(data.tobytes()) + + return chunks + + +@pytest.mark.asyncio +async def test_real_audio(mocker): + """ + Test the VAD agent with only input and output mocked. Using the real model, using real audio as + input. Ensure that it outputs some fragments with audio. + """ + audio_chunks = get_audio_chunks() + audio_in_socket = AsyncMock() + audio_in_socket.recv.side_effect = audio_chunks + + mock_poller: MagicMock = mocker.patch("control_backend.agents.vad_agent.zmq.Poller") + mock_poller.return_value.poll.return_value = [(audio_in_socket, zmq.POLLIN)] + + audio_out_socket = AsyncMock() + + vad_streamer = Streaming(audio_in_socket, audio_out_socket) + for _ in audio_chunks: + await vad_streamer.run() + + audio_out_socket.send.assert_called() + for args in audio_out_socket.send.call_args_list: + assert isinstance(args[0][0], bytes) + assert len(args[0][0]) >= 512*4*3 # Should be at least 3 chunks of audio diff --git a/test/unit/agents/test_vad_socket_poller.py b/test/unit/agents/test_vad_socket_poller.py new file mode 100644 index 0000000..aaf8d0f --- /dev/null +++ b/test/unit/agents/test_vad_socket_poller.py @@ -0,0 +1,46 @@ +from unittest.mock import AsyncMock, MagicMock + +import pytest +import zmq + +from control_backend.agents.vad_agent import SocketPoller + + +@pytest.fixture +def socket(): + return AsyncMock() + + +@pytest.mark.asyncio +async def test_socket_poller_with_data(socket, mocker): + socket_data = b"test" + socket.recv.return_value = socket_data + + mock_poller: MagicMock = mocker.patch("control_backend.agents.vad_agent.zmq.Poller") + mock_poller.return_value.poll.return_value = [(socket, zmq.POLLIN)] + + poller = SocketPoller(socket) + # Calling `poll` twice to be able to check that the poller is reused + await poller.poll() + data = await poller.poll() + + assert data == socket_data + + # Ensure that the poller was reused + mock_poller.assert_called_once_with() + mock_poller.return_value.register.assert_called_once_with(socket, zmq.POLLIN) + + assert socket.recv.call_count == 2 + + +@pytest.mark.asyncio +async def test_socket_poller_no_data(socket, mocker): + mock_poller: MagicMock = mocker.patch("control_backend.agents.vad_agent.zmq.Poller") + mock_poller.return_value.poll.return_value = [] + + poller = SocketPoller(socket) + data = await poller.poll() + + assert data is None + + socket.recv.assert_not_called() diff --git a/test/unit/agents/test_vad_streaming.py b/test/unit/agents/test_vad_streaming.py index c48626d..17456cc 100644 --- a/test/unit/agents/test_vad_streaming.py +++ b/test/unit/agents/test_vad_streaming.py @@ -21,11 +21,13 @@ def streaming(audio_in_socket, audio_out_socket): return Streaming(audio_in_socket, audio_out_socket) -@pytest.mark.asyncio -async def test_voice_activity_detected(audio_in_socket, audio_out_socket, streaming): - # After three chunks of audio with speech probability of 1.0, then four chunks of audio with - # speech probability of 0.0, it should send a message over the audio out socket - probabilities = [1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0] +async def simulate_streaming_with_probabilities(streaming, probabilities: list[float]): + """ + Simulates a streaming scenario with given VAD model probabilities for testing purposes. + + :param streaming: The streaming component to be tested. + :param probabilities: A list of probabilities representing the outputs of the VAD model. + """ model_item = MagicMock() model_item.item.side_effect = probabilities streaming.model = MagicMock() @@ -38,8 +40,53 @@ async def test_voice_activity_detected(audio_in_socket, audio_out_socket, stream for _ in probabilities: await streaming.run() + +@pytest.mark.asyncio +async def test_voice_activity_detected(audio_in_socket, audio_out_socket, streaming): + """ + Test a scenario where there is voice activity detected between silences. + :return: + """ + speech_chunk_count = 5 + probabilities = [0.0]*5 + [1.0]*speech_chunk_count + [0.0]*5 + await simulate_streaming_with_probabilities(streaming, probabilities) + audio_out_socket.send.assert_called_once() data = audio_out_socket.send.call_args[0][0] assert isinstance(data, bytes) - # each sample has 512 frames of 4 bytes, expecting 5 chunks (3 with speech, 2 as padding) - assert len(data) == 512*4*5 + # each sample has 512 frames of 4 bytes, expecting 7 chunks (5 with speech, 2 as padding) + assert len(data) == 512*4*(speech_chunk_count+2) + + +@pytest.mark.asyncio +async def test_voice_activity_short_pause(audio_in_socket, audio_out_socket, streaming): + """ + Test a scenario where there is a short pause between speech, checking whether it ignores the + short pause. + """ + speech_chunk_count = 5 + probabilities = [0.0]*5 + [1.0]*speech_chunk_count + [0.0] + [1.0]*speech_chunk_count + [0.0]*5 + await simulate_streaming_with_probabilities(streaming, probabilities) + + audio_out_socket.send.assert_called_once() + data = audio_out_socket.send.call_args[0][0] + assert isinstance(data, bytes) + # Expecting 13 chunks (2*5 with speech, 1 pause between, 2 as padding) + assert len(data) == 512*4*(speech_chunk_count*2+1+2) + + +@pytest.mark.asyncio +async def test_no_data(audio_in_socket, audio_out_socket, streaming): + """ + Test a scenario where there is no data received. This should not cause errors. + """ + audio_in_poller = AsyncMock() + audio_in_poller.poll.return_value = None + streaming.audio_in_poller = audio_in_poller + + assert streaming.i_since_data == 0 + + await streaming.run() + + audio_out_socket.send.assert_not_called() + assert streaming.i_since_data == 1 diff --git a/uv.lock b/uv.lock index 050aa28..9e7324c 100644 --- a/uv.lock +++ b/uv.lock @@ -1309,6 +1309,9 @@ dependencies = [ ] [package.dev-dependencies] +integration-test = [ + { name = "soundfile" }, +] test = [ { name = "pytest" }, { name = "pytest-asyncio" }, @@ -1333,6 +1336,7 @@ requires-dist = [ ] [package.metadata.requires-dev] +integration-test = [{ name = "soundfile", specifier = ">=0.13.1" }] test = [ { name = "pytest", specifier = ">=8.4.2" }, { name = "pytest-asyncio", specifier = ">=1.2.0" }, @@ -2081,6 +2085,25 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/e9/44/75a9c9421471a6c4805dbf2356f7c181a29c1879239abab1ea2cc8f38b40/sniffio-1.3.1-py3-none-any.whl", hash = "sha256:2f6da418d1f1e0fddd844478f41680e794e6051915791a034ff65e5f100525a2", size = 10235, upload-time = "2024-02-25T23:20:01.196Z" }, ] +[[package]] +name = "soundfile" +version = "0.13.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "cffi" }, + { name = "numpy" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/e1/41/9b873a8c055582859b239be17902a85339bec6a30ad162f98c9b0288a2cc/soundfile-0.13.1.tar.gz", hash = "sha256:b2c68dab1e30297317080a5b43df57e302584c49e2942defdde0acccc53f0e5b", size = 46156, upload-time = "2025-01-25T09:17:04.831Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/64/28/e2a36573ccbcf3d57c00626a21fe51989380636e821b341d36ccca0c1c3a/soundfile-0.13.1-py2.py3-none-any.whl", hash = "sha256:a23c717560da2cf4c7b5ae1142514e0fd82d6bbd9dfc93a50423447142f2c445", size = 25751, upload-time = "2025-01-25T09:16:44.235Z" }, + { url = "https://files.pythonhosted.org/packages/ea/ab/73e97a5b3cc46bba7ff8650a1504348fa1863a6f9d57d7001c6b67c5f20e/soundfile-0.13.1-py2.py3-none-macosx_10_9_x86_64.whl", hash = "sha256:82dc664d19831933fe59adad199bf3945ad06d84bc111a5b4c0d3089a5b9ec33", size = 1142250, upload-time = "2025-01-25T09:16:47.583Z" }, + { url = "https://files.pythonhosted.org/packages/a0/e5/58fd1a8d7b26fc113af244f966ee3aecf03cb9293cb935daaddc1e455e18/soundfile-0.13.1-py2.py3-none-macosx_11_0_arm64.whl", hash = "sha256:743f12c12c4054921e15736c6be09ac26b3b3d603aef6fd69f9dde68748f2593", size = 1101406, upload-time = "2025-01-25T09:16:49.662Z" }, + { url = "https://files.pythonhosted.org/packages/58/ae/c0e4a53d77cf6e9a04179535766b3321b0b9ced5f70522e4caf9329f0046/soundfile-0.13.1-py2.py3-none-manylinux_2_28_aarch64.whl", hash = "sha256:9c9e855f5a4d06ce4213f31918653ab7de0c5a8d8107cd2427e44b42df547deb", size = 1235729, upload-time = "2025-01-25T09:16:53.018Z" }, + { url = "https://files.pythonhosted.org/packages/57/5e/70bdd9579b35003a489fc850b5047beeda26328053ebadc1fb60f320f7db/soundfile-0.13.1-py2.py3-none-manylinux_2_28_x86_64.whl", hash = "sha256:03267c4e493315294834a0870f31dbb3b28a95561b80b134f0bd3cf2d5f0e618", size = 1313646, upload-time = "2025-01-25T09:16:54.872Z" }, + { url = "https://files.pythonhosted.org/packages/fe/df/8c11dc4dfceda14e3003bb81a0d0edcaaf0796dd7b4f826ea3e532146bba/soundfile-0.13.1-py2.py3-none-win32.whl", hash = "sha256:c734564fab7c5ddf8e9be5bf70bab68042cd17e9c214c06e365e20d64f9a69d5", size = 899881, upload-time = "2025-01-25T09:16:56.663Z" }, + { url = "https://files.pythonhosted.org/packages/14/e9/6b761de83277f2f02ded7e7ea6f07828ec78e4b229b80e4ca55dd205b9dc/soundfile-0.13.1-py2.py3-none-win_amd64.whl", hash = "sha256:1e70a05a0626524a69e9f0f4dd2ec174b4e9567f4d8b6c11d38b5c289be36ee9", size = 1019162, upload-time = "2025-01-25T09:16:59.573Z" }, +] + [[package]] name = "spade" version = "4.1.0"