feat: LLM agent #12
10
README.md
10
README.md
@@ -35,10 +35,16 @@ uv run fastapi dev src/control_backend/main.py
|
|||||||
```
|
```
|
||||||
|
|
||||||
## Testing
|
## Testing
|
||||||
Testing happens automatically when opening a merge request to any branch. If you want to manually run the test suite, you can do so by running the following:
|
Testing happens automatically when opening a merge request to any branch. If you want to manually run the test suite, you can do so by running the following for unit tests:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
uv run --only-group test pytest
|
uv run --only-group test pytest test/unit
|
||||||
|
```
|
||||||
|
|
||||||
|
Or for integration tests:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
uv run --group integration-test pytest test/integration
|
||||||
```
|
```
|
||||||
|
|
||||||
## GitHooks
|
## GitHooks
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ requires-python = ">=3.13"
|
|||||||
dependencies = [
|
dependencies = [
|
||||||
"fastapi[all]>=0.115.6",
|
"fastapi[all]>=0.115.6",
|
||||||
"mlx-whisper>=0.4.3 ; sys_platform == 'darwin'",
|
"mlx-whisper>=0.4.3 ; sys_platform == 'darwin'",
|
||||||
|
"numpy>=2.3.3",
|
||||||
"openai-whisper>=20250625",
|
"openai-whisper>=20250625",
|
||||||
"pyaudio>=0.2.14",
|
"pyaudio>=0.2.14",
|
||||||
"pydantic>=2.12.0",
|
"pydantic>=2.12.0",
|
||||||
@@ -33,6 +34,7 @@ integration-test = [
|
|||||||
"soundfile>=0.13.1",
|
"soundfile>=0.13.1",
|
||||||
]
|
]
|
||||||
test = [
|
test = [
|
||||||
|
"numpy>=2.3.3",
|
||||||
"pytest>=8.4.2",
|
"pytest>=8.4.2",
|
||||||
"pytest-asyncio>=1.2.0",
|
"pytest-asyncio>=1.2.0",
|
||||||
"pytest-cov>=7.0.0",
|
"pytest-cov>=7.0.0",
|
||||||
|
|||||||
156
src/control_backend/agents/vad_agent.py
Normal file
156
src/control_backend/agents/vad_agent.py
Normal file
@@ -0,0 +1,156 @@
|
|||||||
|
import logging
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import zmq
|
||||||
|
import zmq.asyncio as azmq
|
||||||
|
from spade.agent import Agent
|
||||||
|
from spade.behaviour import CyclicBehaviour
|
||||||
|
|
||||||
|
from control_backend.core.config import settings
|
||||||
|
from control_backend.core.zmq_context import context as zmq_context
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class SocketPoller[T]:
|
||||||
|
"""
|
||||||
|
Convenience class for polling a socket for data with a timeout, persisting a zmq.Poller for
|
||||||
|
multiple usages.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, socket: azmq.Socket, timeout_ms: int = 100):
|
||||||
|
"""
|
||||||
|
:param socket: The socket to poll and get data from.
|
||||||
|
:param timeout_ms: A timeout in milliseconds to wait for data.
|
||||||
|
"""
|
||||||
|
self.socket = socket
|
||||||
|
self.poller = zmq.Poller()
|
||||||
|
self.poller.register(self.socket, zmq.POLLIN)
|
||||||
|
self.timeout_ms = timeout_ms
|
||||||
|
|
||||||
|
async def poll(self, timeout_ms: int | None = None) -> T | None:
|
||||||
|
"""
|
||||||
|
Get data from the socket, or None if the timeout is reached.
|
||||||
|
|
||||||
|
:param timeout_ms: If given, the timeout. Otherwise, `self.timeout_ms` is used.
|
||||||
|
:return: Data from the socket or None.
|
||||||
|
"""
|
||||||
|
timeout_ms = timeout_ms or self.timeout_ms
|
||||||
|
socks = dict(self.poller.poll(timeout_ms))
|
||||||
|
if socks.get(self.socket) == zmq.POLLIN:
|
||||||
|
return await self.socket.recv()
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
class Streaming(CyclicBehaviour):
|
||||||
|
def __init__(self, audio_in_socket: azmq.Socket, audio_out_socket: azmq.Socket):
|
||||||
|
super().__init__()
|
||||||
|
self.audio_in_poller = SocketPoller[bytes](audio_in_socket)
|
||||||
|
self.model, _ = torch.hub.load(
|
||||||
|
repo_or_dir="snakers4/silero-vad", model="silero_vad", force_reload=False
|
||||||
|
)
|
||||||
|
self.audio_out_socket = audio_out_socket
|
||||||
|
|
||||||
|
self.audio_buffer = np.array([], dtype=np.float32)
|
||||||
|
self.i_since_speech = 100 # Used to allow small pauses in speech
|
||||||
|
|
||||||
|
async def run(self) -> None:
|
||||||
|
data = await self.audio_in_poller.poll()
|
||||||
|
if data is None:
|
||||||
|
if len(self.audio_buffer) > 0:
|
||||||
|
logger.debug("No audio data received. Discarding buffer until new data arrives.")
|
||||||
|
self.audio_buffer = np.array([], dtype=np.float32)
|
||||||
|
self.i_since_speech = 100
|
||||||
|
return
|
||||||
|
|
||||||
|
# copy otherwise Torch will be sad that it's immutable
|
||||||
|
chunk = np.frombuffer(data, dtype=np.float32).copy()
|
||||||
|
prob = self.model(torch.from_numpy(chunk), 16000).item()
|
||||||
|
|
||||||
|
if prob > 0.5:
|
||||||
|
if self.i_since_speech > 3:
|
||||||
|
logger.debug("Speech started.")
|
||||||
|
self.audio_buffer = np.append(self.audio_buffer, chunk)
|
||||||
|
self.i_since_speech = 0
|
||||||
|
return
|
||||||
|
self.i_since_speech += 1
|
||||||
|
|
||||||
|
# prob < 0.5, so speech maybe ended. Wait a bit more before to be more certain
|
||||||
|
if self.i_since_speech <= 3:
|
||||||
|
self.audio_buffer = np.append(self.audio_buffer, chunk)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Speech probably ended. Make sure we have a usable amount of data.
|
||||||
|
if len(self.audio_buffer) >= 3 * len(chunk):
|
||||||
|
logger.debug("Speech ended.")
|
||||||
|
await self.audio_out_socket.send(self.audio_buffer[: -2 * len(chunk)].tobytes())
|
||||||
|
|
||||||
|
# At this point, we know that the speech has ended.
|
||||||
|
# Prepend the last chunk that had no speech, for a more fluent boundary
|
||||||
|
self.audio_buffer = chunk
|
||||||
|
|
||||||
|
|
||||||
|
class VADAgent(Agent):
|
||||||
|
"""
|
||||||
|
An agent which listens to an audio stream, does Voice Activity Detection (VAD), and sends
|
||||||
|
fragments with detected speech to other agents over ZeroMQ.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, audio_in_address: str, audio_in_bind: bool):
|
||||||
|
jid = settings.agent_settings.vad_agent_name + "@" + settings.agent_settings.host
|
||||||
|
super().__init__(jid, settings.agent_settings.vad_agent_name)
|
||||||
|
|
||||||
|
self.audio_in_address = audio_in_address
|
||||||
|
self.audio_in_bind = audio_in_bind
|
||||||
|
|
||||||
|
self.audio_in_socket: azmq.Socket | None = None
|
||||||
|
self.audio_out_socket: azmq.Socket | None = None
|
||||||
|
|
||||||
|
async def stop(self):
|
||||||
|
"""
|
||||||
|
Stop listening to audio, stop publishing audio, close sockets.
|
||||||
|
"""
|
||||||
|
if self.audio_in_socket is not None:
|
||||||
|
self.audio_in_socket.close()
|
||||||
|
self.audio_in_socket = None
|
||||||
|
if self.audio_out_socket is not None:
|
||||||
|
self.audio_out_socket.close()
|
||||||
|
self.audio_out_socket = None
|
||||||
|
return await super().stop()
|
||||||
|
|
||||||
|
def _connect_audio_in_socket(self):
|
||||||
|
self.audio_in_socket = zmq_context.socket(zmq.SUB)
|
||||||
|
self.audio_in_socket.setsockopt_string(zmq.SUBSCRIBE, "")
|
||||||
|
if self.audio_in_bind:
|
||||||
|
self.audio_in_socket.bind(self.audio_in_address)
|
||||||
|
else:
|
||||||
|
self.audio_in_socket.connect(self.audio_in_address)
|
||||||
|
self.audio_in_poller = SocketPoller[bytes](self.audio_in_socket)
|
||||||
|
|
||||||
|
def _connect_audio_out_socket(self) -> int | None:
|
||||||
|
"""Returns the port bound, or None if binding failed."""
|
||||||
|
try:
|
||||||
|
self.audio_out_socket = zmq_context.socket(zmq.PUB)
|
||||||
|
return self.audio_out_socket.bind_to_random_port("tcp://*", max_tries=100)
|
||||||
|
except zmq.ZMQBindError:
|
||||||
|
logger.error("Failed to bind an audio output socket after 100 tries.")
|
||||||
|
self.audio_out_socket = None
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def setup(self):
|
||||||
|
logger.info("Setting up %s", self.jid)
|
||||||
|
|
||||||
|
self._connect_audio_in_socket()
|
||||||
|
|
||||||
|
audio_out_port = self._connect_audio_out_socket()
|
||||||
|
if audio_out_port is None:
|
||||||
|
await self.stop()
|
||||||
|
return
|
||||||
|
|
||||||
|
streaming = Streaming(self.audio_in_socket, self.audio_out_socket)
|
||||||
|
self.add_behaviour(streaming)
|
||||||
|
|
||||||
|
# ... start agents dependent on the output audio fragments here
|
||||||
|
|
||||||
|
logger.info("Finished setting up %s", self.jid)
|
||||||
@@ -10,6 +10,7 @@ class AgentSettings(BaseModel):
|
|||||||
host: str = "xmpp.twirre.dev"
|
host: str = "xmpp.twirre.dev"
|
||||||
bdi_core_agent_name: str = "bdi_core"
|
bdi_core_agent_name: str = "bdi_core"
|
||||||
belief_collector_agent_name: str = "belief_collector"
|
belief_collector_agent_name: str = "belief_collector"
|
||||||
|
vad_agent_name: str = "vad_agent"
|
||||||
llm_agent_name: str = "llm_agent"
|
llm_agent_name: str = "llm_agent"
|
||||||
test_agent_name: str = "test_agent"
|
test_agent_name: str = "test_agent"
|
||||||
|
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ from fastapi.middleware.cors import CORSMiddleware
|
|||||||
# Internal imports
|
# Internal imports
|
||||||
from control_backend.agents.ri_communication_agent import RICommunicationAgent
|
from control_backend.agents.ri_communication_agent import RICommunicationAgent
|
||||||
from control_backend.agents.bdi.bdi_core import BDICoreAgent
|
from control_backend.agents.bdi.bdi_core import BDICoreAgent
|
||||||
|
from control_backend.agents.vad_agent import VADAgent
|
||||||
from control_backend.agents.llm.llm import LLMAgent
|
from control_backend.agents.llm.llm import LLMAgent
|
||||||
from control_backend.api.v1.router import api_router
|
from control_backend.api.v1.router import api_router
|
||||||
from control_backend.core.config import settings
|
from control_backend.core.config import settings
|
||||||
@@ -49,6 +50,9 @@ async def lifespan(app: FastAPI):
|
|||||||
"secret, ask twirre", "src/control_backend/agents/bdi/rules.asl")
|
"secret, ask twirre", "src/control_backend/agents/bdi/rules.asl")
|
||||||
await bdi_core.start()
|
await bdi_core.start()
|
||||||
|
|
||||||
|
_temp_vad_agent = VADAgent("tcp://localhost:5558", False)
|
||||||
|
await _temp_vad_agent.start()
|
||||||
|
|
||||||
yield
|
yield
|
||||||
|
|
||||||
logger.info("%s shutting down.", app.title)
|
logger.info("%s shutting down.", app.title)
|
||||||
|
|||||||
Binary file not shown.
99
test/integration/agents/vad_agent/test_vad_agent.py
Normal file
99
test/integration/agents/vad_agent/test_vad_agent.py
Normal file
@@ -0,0 +1,99 @@
|
|||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import zmq
|
||||||
|
from spade.agent import Agent
|
||||||
|
|
||||||
|
from control_backend.agents.vad_agent import VADAgent
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def zmq_context(mocker):
|
||||||
|
return mocker.patch("control_backend.agents.vad_agent.zmq_context")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def streaming(mocker):
|
||||||
|
return mocker.patch("control_backend.agents.vad_agent.Streaming")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_normal_setup(streaming):
|
||||||
|
"""
|
||||||
|
Test that during normal setup, the VAD agent creates a Streaming behavior and creates audio
|
||||||
|
sockets.
|
||||||
|
"""
|
||||||
|
vad_agent = VADAgent("tcp://localhost:12345", False)
|
||||||
|
vad_agent.add_behaviour = MagicMock()
|
||||||
|
|
||||||
|
await vad_agent.setup()
|
||||||
|
|
||||||
|
streaming.assert_called_once()
|
||||||
|
vad_agent.add_behaviour.assert_called_once_with(streaming.return_value)
|
||||||
|
assert vad_agent.audio_in_socket is not None
|
||||||
|
assert vad_agent.audio_out_socket is not None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("do_bind", [True, False])
|
||||||
|
def test_in_socket_creation(zmq_context, do_bind: bool):
|
||||||
|
"""
|
||||||
|
Test that the VAD agent creates an audio input socket, differentiating between binding and
|
||||||
|
connecting.
|
||||||
|
"""
|
||||||
|
vad_agent = VADAgent(f"tcp://{'*' if do_bind else 'localhost'}:12345", do_bind)
|
||||||
|
|
||||||
|
vad_agent._connect_audio_in_socket()
|
||||||
|
|
||||||
|
assert vad_agent.audio_in_socket is not None
|
||||||
|
|
||||||
|
zmq_context.socket.assert_called_once_with(zmq.SUB)
|
||||||
|
zmq_context.socket.return_value.setsockopt_string.assert_called_once_with(zmq.SUBSCRIBE, "")
|
||||||
|
|
||||||
|
if do_bind:
|
||||||
|
zmq_context.socket.return_value.bind.assert_called_once_with("tcp://*:12345")
|
||||||
|
else:
|
||||||
|
zmq_context.socket.return_value.connect.assert_called_once_with("tcp://localhost:12345")
|
||||||
|
|
||||||
|
|
||||||
|
def test_out_socket_creation(zmq_context):
|
||||||
|
"""
|
||||||
|
Test that the VAD agent creates an audio output socket correctly.
|
||||||
|
"""
|
||||||
|
vad_agent = VADAgent("tcp://localhost:12345", False)
|
||||||
|
|
||||||
|
vad_agent._connect_audio_out_socket()
|
||||||
|
|
||||||
|
assert vad_agent.audio_out_socket is not None
|
||||||
|
|
||||||
|
zmq_context.socket.assert_called_once_with(zmq.PUB)
|
||||||
|
zmq_context.socket.return_value.bind_to_random_port.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_out_socket_creation_failure(zmq_context):
|
||||||
|
"""
|
||||||
|
Test setup failure when the audio output socket cannot be created.
|
||||||
|
"""
|
||||||
|
with patch.object(Agent, "stop", new_callable=AsyncMock) as mock_super_stop:
|
||||||
|
zmq_context.socket.return_value.bind_to_random_port.side_effect = zmq.ZMQBindError
|
||||||
|
vad_agent = VADAgent("tcp://localhost:12345", False)
|
||||||
|
|
||||||
|
await vad_agent.setup()
|
||||||
|
|
||||||
|
assert vad_agent.audio_out_socket is None
|
||||||
|
mock_super_stop.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_stop(zmq_context):
|
||||||
|
"""
|
||||||
|
Test that when the VAD agent is stopped, the sockets are closed correctly.
|
||||||
|
"""
|
||||||
|
vad_agent = VADAgent("tcp://localhost:12345", False)
|
||||||
|
|
||||||
|
await vad_agent.setup()
|
||||||
|
await vad_agent.stop()
|
||||||
|
|
||||||
|
assert zmq_context.socket.return_value.close.call_count == 2
|
||||||
|
assert vad_agent.audio_in_socket is None
|
||||||
|
assert vad_agent.audio_out_socket is None
|
||||||
57
test/integration/agents/vad_agent/test_vad_with_audio.py
Normal file
57
test/integration/agents/vad_agent/test_vad_with_audio.py
Normal file
@@ -0,0 +1,57 @@
|
|||||||
|
import os
|
||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import soundfile as sf
|
||||||
|
import zmq
|
||||||
|
|
||||||
|
from control_backend.agents.vad_agent import Streaming
|
||||||
|
|
||||||
|
|
||||||
|
def get_audio_chunks() -> list[bytes]:
|
||||||
|
curr_file = os.path.realpath(__file__)
|
||||||
|
curr_dir = os.path.dirname(curr_file)
|
||||||
|
file = f"{curr_dir}/speech_with_pauses_16k_1c_float32.wav"
|
||||||
|
|
||||||
|
chunk_size = 512
|
||||||
|
|
||||||
|
chunks = []
|
||||||
|
|
||||||
|
with sf.SoundFile(file, "r") as f:
|
||||||
|
assert f.samplerate == 16000
|
||||||
|
assert f.channels == 1
|
||||||
|
assert f.subtype == "FLOAT"
|
||||||
|
|
||||||
|
while True:
|
||||||
|
data = f.read(chunk_size, dtype="float32")
|
||||||
|
if len(data) != chunk_size:
|
||||||
|
break
|
||||||
|
|
||||||
|
chunks.append(data.tobytes())
|
||||||
|
|
||||||
|
return chunks
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_real_audio(mocker):
|
||||||
|
"""
|
||||||
|
Test the VAD agent with only input and output mocked. Using the real model, using real audio as
|
||||||
|
input. Ensure that it outputs some fragments with audio.
|
||||||
|
"""
|
||||||
|
audio_chunks = get_audio_chunks()
|
||||||
|
audio_in_socket = AsyncMock()
|
||||||
|
audio_in_socket.recv.side_effect = audio_chunks
|
||||||
|
|
||||||
|
mock_poller: MagicMock = mocker.patch("control_backend.agents.vad_agent.zmq.Poller")
|
||||||
|
mock_poller.return_value.poll.return_value = [(audio_in_socket, zmq.POLLIN)]
|
||||||
|
|
||||||
|
audio_out_socket = AsyncMock()
|
||||||
|
|
||||||
|
vad_streamer = Streaming(audio_in_socket, audio_out_socket)
|
||||||
|
for _ in audio_chunks:
|
||||||
|
await vad_streamer.run()
|
||||||
|
|
||||||
|
audio_out_socket.send.assert_called()
|
||||||
|
for args in audio_out_socket.send.call_args_list:
|
||||||
|
assert isinstance(args[0][0], bytes)
|
||||||
|
assert len(args[0][0]) >= 512 * 4 * 3 # Should be at least 3 chunks of audio
|
||||||
46
test/unit/agents/test_vad_socket_poller.py
Normal file
46
test/unit/agents/test_vad_socket_poller.py
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import zmq
|
||||||
|
|
||||||
|
from control_backend.agents.vad_agent import SocketPoller
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def socket():
|
||||||
|
return AsyncMock()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_socket_poller_with_data(socket, mocker):
|
||||||
|
socket_data = b"test"
|
||||||
|
socket.recv.return_value = socket_data
|
||||||
|
|
||||||
|
mock_poller: MagicMock = mocker.patch("control_backend.agents.vad_agent.zmq.Poller")
|
||||||
|
mock_poller.return_value.poll.return_value = [(socket, zmq.POLLIN)]
|
||||||
|
|
||||||
|
poller = SocketPoller(socket)
|
||||||
|
# Calling `poll` twice to be able to check that the poller is reused
|
||||||
|
await poller.poll()
|
||||||
|
data = await poller.poll()
|
||||||
|
|
||||||
|
assert data == socket_data
|
||||||
|
|
||||||
|
# Ensure that the poller was reused
|
||||||
|
mock_poller.assert_called_once_with()
|
||||||
|
mock_poller.return_value.register.assert_called_once_with(socket, zmq.POLLIN)
|
||||||
|
|
||||||
|
assert socket.recv.call_count == 2
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_socket_poller_no_data(socket, mocker):
|
||||||
|
mock_poller: MagicMock = mocker.patch("control_backend.agents.vad_agent.zmq.Poller")
|
||||||
|
mock_poller.return_value.poll.return_value = []
|
||||||
|
|
||||||
|
poller = SocketPoller(socket)
|
||||||
|
data = await poller.poll()
|
||||||
|
|
||||||
|
assert data is None
|
||||||
|
|
||||||
|
socket.recv.assert_not_called()
|
||||||
95
test/unit/agents/test_vad_streaming.py
Normal file
95
test/unit/agents/test_vad_streaming.py
Normal file
@@ -0,0 +1,95 @@
|
|||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from control_backend.agents.vad_agent import Streaming
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def audio_in_socket():
|
||||||
|
return AsyncMock()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def audio_out_socket():
|
||||||
|
return AsyncMock()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def streaming(audio_in_socket, audio_out_socket):
|
||||||
|
import torch
|
||||||
|
|
||||||
|
torch.hub.load.return_value = (..., ...) # Mock
|
||||||
|
return Streaming(audio_in_socket, audio_out_socket)
|
||||||
|
|
||||||
|
|
||||||
|
async def simulate_streaming_with_probabilities(streaming, probabilities: list[float]):
|
||||||
|
"""
|
||||||
|
Simulates a streaming scenario with given VAD model probabilities for testing purposes.
|
||||||
|
|
||||||
|
:param streaming: The streaming component to be tested.
|
||||||
|
:param probabilities: A list of probabilities representing the outputs of the VAD model.
|
||||||
|
"""
|
||||||
|
model_item = MagicMock()
|
||||||
|
model_item.item.side_effect = probabilities
|
||||||
|
streaming.model = MagicMock()
|
||||||
|
streaming.model.return_value = model_item
|
||||||
|
|
||||||
|
audio_in_poller = AsyncMock()
|
||||||
|
audio_in_poller.poll.return_value = np.empty(shape=512, dtype=np.float32)
|
||||||
|
streaming.audio_in_poller = audio_in_poller
|
||||||
|
|
||||||
|
for _ in probabilities:
|
||||||
|
await streaming.run()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_voice_activity_detected(audio_in_socket, audio_out_socket, streaming):
|
||||||
|
"""
|
||||||
|
Test a scenario where there is voice activity detected between silences.
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
speech_chunk_count = 5
|
||||||
|
probabilities = [0.0] * 5 + [1.0] * speech_chunk_count + [0.0] * 5
|
||||||
|
await simulate_streaming_with_probabilities(streaming, probabilities)
|
||||||
|
|
||||||
|
audio_out_socket.send.assert_called_once()
|
||||||
|
data = audio_out_socket.send.call_args[0][0]
|
||||||
|
assert isinstance(data, bytes)
|
||||||
|
# each sample has 512 frames of 4 bytes, expecting 7 chunks (5 with speech, 2 as padding)
|
||||||
|
assert len(data) == 512 * 4 * (speech_chunk_count + 2)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_voice_activity_short_pause(audio_in_socket, audio_out_socket, streaming):
|
||||||
|
"""
|
||||||
|
Test a scenario where there is a short pause between speech, checking whether it ignores the
|
||||||
|
short pause.
|
||||||
|
"""
|
||||||
|
speech_chunk_count = 5
|
||||||
|
probabilities = (
|
||||||
|
[0.0] * 5 + [1.0] * speech_chunk_count + [0.0] + [1.0] * speech_chunk_count + [0.0] * 5
|
||||||
|
)
|
||||||
|
await simulate_streaming_with_probabilities(streaming, probabilities)
|
||||||
|
|
||||||
|
audio_out_socket.send.assert_called_once()
|
||||||
|
data = audio_out_socket.send.call_args[0][0]
|
||||||
|
assert isinstance(data, bytes)
|
||||||
|
# Expecting 13 chunks (2*5 with speech, 1 pause between, 2 as padding)
|
||||||
|
assert len(data) == 512 * 4 * (speech_chunk_count * 2 + 1 + 2)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_no_data(audio_in_socket, audio_out_socket, streaming):
|
||||||
|
"""
|
||||||
|
Test a scenario where there is no data received. This should not cause errors.
|
||||||
|
"""
|
||||||
|
audio_in_poller = AsyncMock()
|
||||||
|
audio_in_poller.poll.return_value = None
|
||||||
|
streaming.audio_in_poller = audio_in_poller
|
||||||
|
|
||||||
|
await streaming.run()
|
||||||
|
|
||||||
|
audio_out_socket.send.assert_not_called()
|
||||||
|
assert len(streaming.audio_buffer) == 0
|
||||||
@@ -33,3 +33,13 @@ def pytest_configure(config):
|
|||||||
mock_config_module.settings = MagicMock()
|
mock_config_module.settings = MagicMock()
|
||||||
|
|
||||||
sys.modules["control_backend.core.config"] = mock_config_module
|
sys.modules["control_backend.core.config"] = mock_config_module
|
||||||
|
|
||||||
|
# --- Mock torch and zmq for VAD ---
|
||||||
|
mock_torch = MagicMock()
|
||||||
|
mock_zmq = MagicMock()
|
||||||
|
mock_zmq.asyncio = mock_zmq
|
||||||
|
|
||||||
|
# In individual tests, these can be imported and the return values changed
|
||||||
|
sys.modules["torch"] = mock_torch
|
||||||
|
sys.modules["zmq"] = mock_zmq
|
||||||
|
sys.modules["zmq.asyncio"] = mock_zmq.asyncio
|
||||||
|
|||||||
4
uv.lock
generated
4
uv.lock
generated
@@ -1332,6 +1332,7 @@ source = { virtual = "." }
|
|||||||
dependencies = [
|
dependencies = [
|
||||||
{ name = "fastapi", extra = ["all"] },
|
{ name = "fastapi", extra = ["all"] },
|
||||||
{ name = "mlx-whisper", marker = "sys_platform == 'darwin'" },
|
{ name = "mlx-whisper", marker = "sys_platform == 'darwin'" },
|
||||||
|
{ name = "numpy" },
|
||||||
{ name = "openai-whisper" },
|
{ name = "openai-whisper" },
|
||||||
{ name = "pyaudio" },
|
{ name = "pyaudio" },
|
||||||
{ name = "pydantic" },
|
{ name = "pydantic" },
|
||||||
@@ -1358,6 +1359,7 @@ integration-test = [
|
|||||||
{ name = "soundfile" },
|
{ name = "soundfile" },
|
||||||
]
|
]
|
||||||
test = [
|
test = [
|
||||||
|
{ name = "numpy" },
|
||||||
{ name = "pytest" },
|
{ name = "pytest" },
|
||||||
{ name = "pytest-asyncio" },
|
{ name = "pytest-asyncio" },
|
||||||
{ name = "pytest-cov" },
|
{ name = "pytest-cov" },
|
||||||
@@ -1368,6 +1370,7 @@ test = [
|
|||||||
requires-dist = [
|
requires-dist = [
|
||||||
{ name = "fastapi", extras = ["all"], specifier = ">=0.115.6" },
|
{ name = "fastapi", extras = ["all"], specifier = ">=0.115.6" },
|
||||||
{ name = "mlx-whisper", marker = "sys_platform == 'darwin'", specifier = ">=0.4.3" },
|
{ name = "mlx-whisper", marker = "sys_platform == 'darwin'", specifier = ">=0.4.3" },
|
||||||
|
{ name = "numpy", specifier = ">=2.3.3" },
|
||||||
{ name = "openai-whisper", specifier = ">=20250625" },
|
{ name = "openai-whisper", specifier = ">=20250625" },
|
||||||
{ name = "pyaudio", specifier = ">=0.2.14" },
|
{ name = "pyaudio", specifier = ">=0.2.14" },
|
||||||
{ name = "pydantic", specifier = ">=2.12.0" },
|
{ name = "pydantic", specifier = ">=2.12.0" },
|
||||||
@@ -1392,6 +1395,7 @@ dev = [
|
|||||||
]
|
]
|
||||||
integration-test = [{ name = "soundfile", specifier = ">=0.13.1" }]
|
integration-test = [{ name = "soundfile", specifier = ">=0.13.1" }]
|
||||||
test = [
|
test = [
|
||||||
|
{ name = "numpy", specifier = ">=2.3.3" },
|
||||||
{ name = "pytest", specifier = ">=8.4.2" },
|
{ name = "pytest", specifier = ">=8.4.2" },
|
||||||
{ name = "pytest-asyncio", specifier = ">=1.2.0" },
|
{ name = "pytest-asyncio", specifier = ">=1.2.0" },
|
||||||
{ name = "pytest-cov", specifier = ">=7.0.0" },
|
{ name = "pytest-cov", specifier = ">=7.0.0" },
|
||||||
|
|||||||
Reference in New Issue
Block a user