diff --git a/README.md b/README.md index f51a14e..45f8f98 100644 --- a/README.md +++ b/README.md @@ -35,10 +35,16 @@ uv run fastapi dev src/control_backend/main.py ``` ## Testing -Testing happens automatically when opening a merge request to any branch. If you want to manually run the test suite, you can do so by running the following: +Testing happens automatically when opening a merge request to any branch. If you want to manually run the test suite, you can do so by running the following for unit tests: ```bash -uv run --only-group test pytest +uv run --only-group test pytest test/unit +``` + +Or for integration tests: + +```bash +uv run --group integration-test pytest test/integration ``` ## GitHooks diff --git a/pyproject.toml b/pyproject.toml index 8299d0f..ee3ca08 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,6 +7,7 @@ requires-python = ">=3.13" dependencies = [ "fastapi[all]>=0.115.6", "mlx-whisper>=0.4.3 ; sys_platform == 'darwin'", + "numpy>=2.3.3", "openai-whisper>=20250625", "pyaudio>=0.2.14", "pydantic>=2.12.0", @@ -33,6 +34,7 @@ integration-test = [ "soundfile>=0.13.1", ] test = [ + "numpy>=2.3.3", "pytest>=8.4.2", "pytest-asyncio>=1.2.0", "pytest-cov>=7.0.0", diff --git a/src/control_backend/agents/vad_agent.py b/src/control_backend/agents/vad_agent.py new file mode 100644 index 0000000..7b87fbb --- /dev/null +++ b/src/control_backend/agents/vad_agent.py @@ -0,0 +1,156 @@ +import logging + +import numpy as np +import torch +import zmq +import zmq.asyncio as azmq +from spade.agent import Agent +from spade.behaviour import CyclicBehaviour + +from control_backend.core.config import settings +from control_backend.core.zmq_context import context as zmq_context + +logger = logging.getLogger(__name__) + + +class SocketPoller[T]: + """ + Convenience class for polling a socket for data with a timeout, persisting a zmq.Poller for + multiple usages. + """ + + def __init__(self, socket: azmq.Socket, timeout_ms: int = 100): + """ + :param socket: The socket to poll and get data from. + :param timeout_ms: A timeout in milliseconds to wait for data. + """ + self.socket = socket + self.poller = zmq.Poller() + self.poller.register(self.socket, zmq.POLLIN) + self.timeout_ms = timeout_ms + + async def poll(self, timeout_ms: int | None = None) -> T | None: + """ + Get data from the socket, or None if the timeout is reached. + + :param timeout_ms: If given, the timeout. Otherwise, `self.timeout_ms` is used. + :return: Data from the socket or None. + """ + timeout_ms = timeout_ms or self.timeout_ms + socks = dict(self.poller.poll(timeout_ms)) + if socks.get(self.socket) == zmq.POLLIN: + return await self.socket.recv() + return None + + +class Streaming(CyclicBehaviour): + def __init__(self, audio_in_socket: azmq.Socket, audio_out_socket: azmq.Socket): + super().__init__() + self.audio_in_poller = SocketPoller[bytes](audio_in_socket) + self.model, _ = torch.hub.load( + repo_or_dir="snakers4/silero-vad", model="silero_vad", force_reload=False + ) + self.audio_out_socket = audio_out_socket + + self.audio_buffer = np.array([], dtype=np.float32) + self.i_since_speech = 100 # Used to allow small pauses in speech + + async def run(self) -> None: + data = await self.audio_in_poller.poll() + if data is None: + if len(self.audio_buffer) > 0: + logger.debug("No audio data received. Discarding buffer until new data arrives.") + self.audio_buffer = np.array([], dtype=np.float32) + self.i_since_speech = 100 + return + + # copy otherwise Torch will be sad that it's immutable + chunk = np.frombuffer(data, dtype=np.float32).copy() + prob = self.model(torch.from_numpy(chunk), 16000).item() + + if prob > 0.5: + if self.i_since_speech > 3: + logger.debug("Speech started.") + self.audio_buffer = np.append(self.audio_buffer, chunk) + self.i_since_speech = 0 + return + self.i_since_speech += 1 + + # prob < 0.5, so speech maybe ended. Wait a bit more before to be more certain + if self.i_since_speech <= 3: + self.audio_buffer = np.append(self.audio_buffer, chunk) + return + + # Speech probably ended. Make sure we have a usable amount of data. + if len(self.audio_buffer) >= 3 * len(chunk): + logger.debug("Speech ended.") + await self.audio_out_socket.send(self.audio_buffer[: -2 * len(chunk)].tobytes()) + + # At this point, we know that the speech has ended. + # Prepend the last chunk that had no speech, for a more fluent boundary + self.audio_buffer = chunk + + +class VADAgent(Agent): + """ + An agent which listens to an audio stream, does Voice Activity Detection (VAD), and sends + fragments with detected speech to other agents over ZeroMQ. + """ + + def __init__(self, audio_in_address: str, audio_in_bind: bool): + jid = settings.agent_settings.vad_agent_name + "@" + settings.agent_settings.host + super().__init__(jid, settings.agent_settings.vad_agent_name) + + self.audio_in_address = audio_in_address + self.audio_in_bind = audio_in_bind + + self.audio_in_socket: azmq.Socket | None = None + self.audio_out_socket: azmq.Socket | None = None + + async def stop(self): + """ + Stop listening to audio, stop publishing audio, close sockets. + """ + if self.audio_in_socket is not None: + self.audio_in_socket.close() + self.audio_in_socket = None + if self.audio_out_socket is not None: + self.audio_out_socket.close() + self.audio_out_socket = None + return await super().stop() + + def _connect_audio_in_socket(self): + self.audio_in_socket = zmq_context.socket(zmq.SUB) + self.audio_in_socket.setsockopt_string(zmq.SUBSCRIBE, "") + if self.audio_in_bind: + self.audio_in_socket.bind(self.audio_in_address) + else: + self.audio_in_socket.connect(self.audio_in_address) + self.audio_in_poller = SocketPoller[bytes](self.audio_in_socket) + + def _connect_audio_out_socket(self) -> int | None: + """Returns the port bound, or None if binding failed.""" + try: + self.audio_out_socket = zmq_context.socket(zmq.PUB) + return self.audio_out_socket.bind_to_random_port("tcp://*", max_tries=100) + except zmq.ZMQBindError: + logger.error("Failed to bind an audio output socket after 100 tries.") + self.audio_out_socket = None + return None + + async def setup(self): + logger.info("Setting up %s", self.jid) + + self._connect_audio_in_socket() + + audio_out_port = self._connect_audio_out_socket() + if audio_out_port is None: + await self.stop() + return + + streaming = Streaming(self.audio_in_socket, self.audio_out_socket) + self.add_behaviour(streaming) + + # ... start agents dependent on the output audio fragments here + + logger.info("Finished setting up %s", self.jid) diff --git a/src/control_backend/core/config.py b/src/control_backend/core/config.py index ab17f74..34032ba 100644 --- a/src/control_backend/core/config.py +++ b/src/control_backend/core/config.py @@ -10,6 +10,7 @@ class AgentSettings(BaseModel): host: str = "xmpp.twirre.dev" bdi_core_agent_name: str = "bdi_core" belief_collector_agent_name: str = "belief_collector" + vad_agent_name: str = "vad_agent" llm_agent_name: str = "llm_agent" test_agent_name: str = "test_agent" diff --git a/src/control_backend/main.py b/src/control_backend/main.py index 6b280f3..de357d8 100644 --- a/src/control_backend/main.py +++ b/src/control_backend/main.py @@ -11,6 +11,7 @@ from fastapi.middleware.cors import CORSMiddleware # Internal imports from control_backend.agents.ri_communication_agent import RICommunicationAgent from control_backend.agents.bdi.bdi_core import BDICoreAgent +from control_backend.agents.vad_agent import VADAgent from control_backend.agents.llm.llm import LLMAgent from control_backend.api.v1.router import api_router from control_backend.core.config import settings @@ -48,7 +49,10 @@ async def lifespan(app: FastAPI): bdi_core = BDICoreAgent(settings.agent_settings.bdi_core_agent_name + '@' + settings.agent_settings.host, "secret, ask twirre", "src/control_backend/agents/bdi/rules.asl") await bdi_core.start() - + + _temp_vad_agent = VADAgent("tcp://localhost:5558", False) + await _temp_vad_agent.start() + yield logger.info("%s shutting down.", app.title) diff --git a/test/__init__.py b/test/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/test/integration/agents/vad_agent/speech_with_pauses_16k_1c_float32.wav b/test/integration/agents/vad_agent/speech_with_pauses_16k_1c_float32.wav new file mode 100644 index 0000000..530bc0a Binary files /dev/null and b/test/integration/agents/vad_agent/speech_with_pauses_16k_1c_float32.wav differ diff --git a/test/integration/agents/vad_agent/test_vad_agent.py b/test/integration/agents/vad_agent/test_vad_agent.py new file mode 100644 index 0000000..293912e --- /dev/null +++ b/test/integration/agents/vad_agent/test_vad_agent.py @@ -0,0 +1,99 @@ +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +import zmq +from spade.agent import Agent + +from control_backend.agents.vad_agent import VADAgent + + +@pytest.fixture +def zmq_context(mocker): + return mocker.patch("control_backend.agents.vad_agent.zmq_context") + + +@pytest.fixture +def streaming(mocker): + return mocker.patch("control_backend.agents.vad_agent.Streaming") + + +@pytest.mark.asyncio +async def test_normal_setup(streaming): + """ + Test that during normal setup, the VAD agent creates a Streaming behavior and creates audio + sockets. + """ + vad_agent = VADAgent("tcp://localhost:12345", False) + vad_agent.add_behaviour = MagicMock() + + await vad_agent.setup() + + streaming.assert_called_once() + vad_agent.add_behaviour.assert_called_once_with(streaming.return_value) + assert vad_agent.audio_in_socket is not None + assert vad_agent.audio_out_socket is not None + + +@pytest.mark.parametrize("do_bind", [True, False]) +def test_in_socket_creation(zmq_context, do_bind: bool): + """ + Test that the VAD agent creates an audio input socket, differentiating between binding and + connecting. + """ + vad_agent = VADAgent(f"tcp://{'*' if do_bind else 'localhost'}:12345", do_bind) + + vad_agent._connect_audio_in_socket() + + assert vad_agent.audio_in_socket is not None + + zmq_context.socket.assert_called_once_with(zmq.SUB) + zmq_context.socket.return_value.setsockopt_string.assert_called_once_with(zmq.SUBSCRIBE, "") + + if do_bind: + zmq_context.socket.return_value.bind.assert_called_once_with("tcp://*:12345") + else: + zmq_context.socket.return_value.connect.assert_called_once_with("tcp://localhost:12345") + + +def test_out_socket_creation(zmq_context): + """ + Test that the VAD agent creates an audio output socket correctly. + """ + vad_agent = VADAgent("tcp://localhost:12345", False) + + vad_agent._connect_audio_out_socket() + + assert vad_agent.audio_out_socket is not None + + zmq_context.socket.assert_called_once_with(zmq.PUB) + zmq_context.socket.return_value.bind_to_random_port.assert_called_once() + + +@pytest.mark.asyncio +async def test_out_socket_creation_failure(zmq_context): + """ + Test setup failure when the audio output socket cannot be created. + """ + with patch.object(Agent, "stop", new_callable=AsyncMock) as mock_super_stop: + zmq_context.socket.return_value.bind_to_random_port.side_effect = zmq.ZMQBindError + vad_agent = VADAgent("tcp://localhost:12345", False) + + await vad_agent.setup() + + assert vad_agent.audio_out_socket is None + mock_super_stop.assert_called_once() + + +@pytest.mark.asyncio +async def test_stop(zmq_context): + """ + Test that when the VAD agent is stopped, the sockets are closed correctly. + """ + vad_agent = VADAgent("tcp://localhost:12345", False) + + await vad_agent.setup() + await vad_agent.stop() + + assert zmq_context.socket.return_value.close.call_count == 2 + assert vad_agent.audio_in_socket is None + assert vad_agent.audio_out_socket is None diff --git a/test/integration/agents/vad_agent/test_vad_with_audio.py b/test/integration/agents/vad_agent/test_vad_with_audio.py new file mode 100644 index 0000000..7d10aa3 --- /dev/null +++ b/test/integration/agents/vad_agent/test_vad_with_audio.py @@ -0,0 +1,57 @@ +import os +from unittest.mock import AsyncMock, MagicMock + +import pytest +import soundfile as sf +import zmq + +from control_backend.agents.vad_agent import Streaming + + +def get_audio_chunks() -> list[bytes]: + curr_file = os.path.realpath(__file__) + curr_dir = os.path.dirname(curr_file) + file = f"{curr_dir}/speech_with_pauses_16k_1c_float32.wav" + + chunk_size = 512 + + chunks = [] + + with sf.SoundFile(file, "r") as f: + assert f.samplerate == 16000 + assert f.channels == 1 + assert f.subtype == "FLOAT" + + while True: + data = f.read(chunk_size, dtype="float32") + if len(data) != chunk_size: + break + + chunks.append(data.tobytes()) + + return chunks + + +@pytest.mark.asyncio +async def test_real_audio(mocker): + """ + Test the VAD agent with only input and output mocked. Using the real model, using real audio as + input. Ensure that it outputs some fragments with audio. + """ + audio_chunks = get_audio_chunks() + audio_in_socket = AsyncMock() + audio_in_socket.recv.side_effect = audio_chunks + + mock_poller: MagicMock = mocker.patch("control_backend.agents.vad_agent.zmq.Poller") + mock_poller.return_value.poll.return_value = [(audio_in_socket, zmq.POLLIN)] + + audio_out_socket = AsyncMock() + + vad_streamer = Streaming(audio_in_socket, audio_out_socket) + for _ in audio_chunks: + await vad_streamer.run() + + audio_out_socket.send.assert_called() + for args in audio_out_socket.send.call_args_list: + assert isinstance(args[0][0], bytes) + assert len(args[0][0]) >= 512 * 4 * 3 # Should be at least 3 chunks of audio diff --git a/test/unit/agents/test_vad_socket_poller.py b/test/unit/agents/test_vad_socket_poller.py new file mode 100644 index 0000000..aaf8d0f --- /dev/null +++ b/test/unit/agents/test_vad_socket_poller.py @@ -0,0 +1,46 @@ +from unittest.mock import AsyncMock, MagicMock + +import pytest +import zmq + +from control_backend.agents.vad_agent import SocketPoller + + +@pytest.fixture +def socket(): + return AsyncMock() + + +@pytest.mark.asyncio +async def test_socket_poller_with_data(socket, mocker): + socket_data = b"test" + socket.recv.return_value = socket_data + + mock_poller: MagicMock = mocker.patch("control_backend.agents.vad_agent.zmq.Poller") + mock_poller.return_value.poll.return_value = [(socket, zmq.POLLIN)] + + poller = SocketPoller(socket) + # Calling `poll` twice to be able to check that the poller is reused + await poller.poll() + data = await poller.poll() + + assert data == socket_data + + # Ensure that the poller was reused + mock_poller.assert_called_once_with() + mock_poller.return_value.register.assert_called_once_with(socket, zmq.POLLIN) + + assert socket.recv.call_count == 2 + + +@pytest.mark.asyncio +async def test_socket_poller_no_data(socket, mocker): + mock_poller: MagicMock = mocker.patch("control_backend.agents.vad_agent.zmq.Poller") + mock_poller.return_value.poll.return_value = [] + + poller = SocketPoller(socket) + data = await poller.poll() + + assert data is None + + socket.recv.assert_not_called() diff --git a/test/unit/agents/test_vad_streaming.py b/test/unit/agents/test_vad_streaming.py new file mode 100644 index 0000000..9b38cd0 --- /dev/null +++ b/test/unit/agents/test_vad_streaming.py @@ -0,0 +1,95 @@ +from unittest.mock import AsyncMock, MagicMock + +import numpy as np +import pytest + +from control_backend.agents.vad_agent import Streaming + + +@pytest.fixture +def audio_in_socket(): + return AsyncMock() + + +@pytest.fixture +def audio_out_socket(): + return AsyncMock() + + +@pytest.fixture +def streaming(audio_in_socket, audio_out_socket): + import torch + + torch.hub.load.return_value = (..., ...) # Mock + return Streaming(audio_in_socket, audio_out_socket) + + +async def simulate_streaming_with_probabilities(streaming, probabilities: list[float]): + """ + Simulates a streaming scenario with given VAD model probabilities for testing purposes. + + :param streaming: The streaming component to be tested. + :param probabilities: A list of probabilities representing the outputs of the VAD model. + """ + model_item = MagicMock() + model_item.item.side_effect = probabilities + streaming.model = MagicMock() + streaming.model.return_value = model_item + + audio_in_poller = AsyncMock() + audio_in_poller.poll.return_value = np.empty(shape=512, dtype=np.float32) + streaming.audio_in_poller = audio_in_poller + + for _ in probabilities: + await streaming.run() + + +@pytest.mark.asyncio +async def test_voice_activity_detected(audio_in_socket, audio_out_socket, streaming): + """ + Test a scenario where there is voice activity detected between silences. + :return: + """ + speech_chunk_count = 5 + probabilities = [0.0] * 5 + [1.0] * speech_chunk_count + [0.0] * 5 + await simulate_streaming_with_probabilities(streaming, probabilities) + + audio_out_socket.send.assert_called_once() + data = audio_out_socket.send.call_args[0][0] + assert isinstance(data, bytes) + # each sample has 512 frames of 4 bytes, expecting 7 chunks (5 with speech, 2 as padding) + assert len(data) == 512 * 4 * (speech_chunk_count + 2) + + +@pytest.mark.asyncio +async def test_voice_activity_short_pause(audio_in_socket, audio_out_socket, streaming): + """ + Test a scenario where there is a short pause between speech, checking whether it ignores the + short pause. + """ + speech_chunk_count = 5 + probabilities = ( + [0.0] * 5 + [1.0] * speech_chunk_count + [0.0] + [1.0] * speech_chunk_count + [0.0] * 5 + ) + await simulate_streaming_with_probabilities(streaming, probabilities) + + audio_out_socket.send.assert_called_once() + data = audio_out_socket.send.call_args[0][0] + assert isinstance(data, bytes) + # Expecting 13 chunks (2*5 with speech, 1 pause between, 2 as padding) + assert len(data) == 512 * 4 * (speech_chunk_count * 2 + 1 + 2) + + +@pytest.mark.asyncio +async def test_no_data(audio_in_socket, audio_out_socket, streaming): + """ + Test a scenario where there is no data received. This should not cause errors. + """ + audio_in_poller = AsyncMock() + audio_in_poller.poll.return_value = None + streaming.audio_in_poller = audio_in_poller + + await streaming.run() + + audio_out_socket.send.assert_not_called() + assert len(streaming.audio_buffer) == 0 diff --git a/test/unit/conftest.py b/test/unit/conftest.py index d7c10f2..76ef272 100644 --- a/test/unit/conftest.py +++ b/test/unit/conftest.py @@ -33,3 +33,13 @@ def pytest_configure(config): mock_config_module.settings = MagicMock() sys.modules["control_backend.core.config"] = mock_config_module + + # --- Mock torch and zmq for VAD --- + mock_torch = MagicMock() + mock_zmq = MagicMock() + mock_zmq.asyncio = mock_zmq + + # In individual tests, these can be imported and the return values changed + sys.modules["torch"] = mock_torch + sys.modules["zmq"] = mock_zmq + sys.modules["zmq.asyncio"] = mock_zmq.asyncio diff --git a/uv.lock b/uv.lock index 07ec3c1..c2bb61a 100644 --- a/uv.lock +++ b/uv.lock @@ -1332,6 +1332,7 @@ source = { virtual = "." } dependencies = [ { name = "fastapi", extra = ["all"] }, { name = "mlx-whisper", marker = "sys_platform == 'darwin'" }, + { name = "numpy" }, { name = "openai-whisper" }, { name = "pyaudio" }, { name = "pydantic" }, @@ -1358,6 +1359,7 @@ integration-test = [ { name = "soundfile" }, ] test = [ + { name = "numpy" }, { name = "pytest" }, { name = "pytest-asyncio" }, { name = "pytest-cov" }, @@ -1368,6 +1370,7 @@ test = [ requires-dist = [ { name = "fastapi", extras = ["all"], specifier = ">=0.115.6" }, { name = "mlx-whisper", marker = "sys_platform == 'darwin'", specifier = ">=0.4.3" }, + { name = "numpy", specifier = ">=2.3.3" }, { name = "openai-whisper", specifier = ">=20250625" }, { name = "pyaudio", specifier = ">=0.2.14" }, { name = "pydantic", specifier = ">=2.12.0" }, @@ -1392,6 +1395,7 @@ dev = [ ] integration-test = [{ name = "soundfile", specifier = ">=0.13.1" }] test = [ + { name = "numpy", specifier = ">=2.3.3" }, { name = "pytest", specifier = ">=8.4.2" }, { name = "pytest-asyncio", specifier = ">=1.2.0" }, { name = "pytest-cov", specifier = ">=7.0.0" },