test: add first unit test for VAD agent
Mocking audio input probabilities, checking whether it publishes audio data on the output socket. ref: N25B-213
This commit is contained in:
@@ -3,6 +3,7 @@ import logging
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import zmq
|
import zmq
|
||||||
|
import zmq.asyncio as azmq
|
||||||
from spade.agent import Agent
|
from spade.agent import Agent
|
||||||
from spade.behaviour import CyclicBehaviour
|
from spade.behaviour import CyclicBehaviour
|
||||||
|
|
||||||
@@ -13,7 +14,7 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
class SocketPoller[T]:
|
class SocketPoller[T]:
|
||||||
def __init__(self, socket: zmq.Socket[T]):
|
def __init__(self, socket: azmq.Socket):
|
||||||
self.socket = socket
|
self.socket = socket
|
||||||
self.poller = zmq.Poller()
|
self.poller = zmq.Poller()
|
||||||
self.poller.register(self.socket, zmq.POLLIN)
|
self.poller.register(self.socket, zmq.POLLIN)
|
||||||
@@ -31,6 +32,56 @@ class SocketPoller[T]:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
class Streaming(CyclicBehaviour):
|
||||||
|
def __init__(self, audio_in_socket: azmq.Socket, audio_out_socket: azmq.Socket):
|
||||||
|
super().__init__()
|
||||||
|
self.audio_in_poller = SocketPoller[bytes](audio_in_socket)
|
||||||
|
self.model, _ = torch.hub.load(repo_or_dir="snakers4/silero-vad",
|
||||||
|
model="silero_vad",
|
||||||
|
force_reload=False)
|
||||||
|
self.audio_out_socket = audio_out_socket
|
||||||
|
|
||||||
|
self.audio_buffer = np.array([], dtype=np.float32) # TODO: Consider using a Tensor
|
||||||
|
self.i_since_data = 0 # Used to avoid logging every cycle if audio input stops
|
||||||
|
self.i_since_speech = 0 # Used to allow small pauses in speech
|
||||||
|
|
||||||
|
async def run(self) -> None:
|
||||||
|
timeout_ms = 100
|
||||||
|
data = await self.audio_in_poller.poll(timeout_ms)
|
||||||
|
if data is None:
|
||||||
|
if self.i_since_data % 10 == 0:
|
||||||
|
logger.debug("Failed to receive audio from socket for %d ms.",
|
||||||
|
timeout_ms*self.i_since_data)
|
||||||
|
self.i_since_data += 1
|
||||||
|
return
|
||||||
|
self.i_since_data = 0
|
||||||
|
|
||||||
|
# copy otherwise Torch will be sad that it's immutable
|
||||||
|
chunk = np.frombuffer(data, dtype=np.float32).copy()
|
||||||
|
prob = self.model(torch.from_numpy(chunk), 16000).item()
|
||||||
|
|
||||||
|
if prob > 0.5:
|
||||||
|
if self.i_since_speech > 3: logger.debug("Speech started.")
|
||||||
|
self.audio_buffer = np.append(self.audio_buffer, chunk)
|
||||||
|
self.i_since_speech = 0
|
||||||
|
return
|
||||||
|
self.i_since_speech += 1
|
||||||
|
|
||||||
|
# prob < 0.5, so speech maybe ended. Wait a bit more before to be more certain
|
||||||
|
if self.i_since_speech <= 3:
|
||||||
|
self.audio_buffer = np.append(self.audio_buffer, chunk)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Speech probably ended. Make sure we have a usable amount of data.
|
||||||
|
if len(self.audio_buffer) >= 3*len(chunk):
|
||||||
|
logger.debug("Speech ended.")
|
||||||
|
await self.audio_out_socket.send(self.audio_buffer.tobytes())
|
||||||
|
|
||||||
|
# At this point, we know that the speech has ended.
|
||||||
|
# Prepend the last chunk that had no speech, for a more fluent boundary
|
||||||
|
self.audio_buffer = chunk
|
||||||
|
|
||||||
|
|
||||||
class VADAgent(Agent):
|
class VADAgent(Agent):
|
||||||
"""
|
"""
|
||||||
An agent which listens to an audio stream, does Voice Activity Detection (VAD), and sends
|
An agent which listens to an audio stream, does Voice Activity Detection (VAD), and sends
|
||||||
@@ -43,57 +94,8 @@ class VADAgent(Agent):
|
|||||||
self.audio_in_address = audio_in_address
|
self.audio_in_address = audio_in_address
|
||||||
self.audio_in_bind = audio_in_bind
|
self.audio_in_bind = audio_in_bind
|
||||||
|
|
||||||
self.audio_in_socket: zmq.Socket | None = None
|
self.audio_in_socket: azmq.Socket | None = None
|
||||||
self.audio_out_socket: zmq.Socket | None = None
|
self.audio_out_socket: azmq.Socket | None = None
|
||||||
|
|
||||||
class Stream(CyclicBehaviour):
|
|
||||||
def __init__(self, audio_in_socket: zmq.Socket, audio_out_socket: zmq.Socket):
|
|
||||||
super().__init__()
|
|
||||||
self.audio_in_poller = SocketPoller[bytes](audio_in_socket)
|
|
||||||
self.model, _ = torch.hub.load(repo_or_dir="snakers4/silero-vad",
|
|
||||||
model="silero_vad",
|
|
||||||
force_reload=False)
|
|
||||||
self.audio_out_socket = audio_out_socket
|
|
||||||
|
|
||||||
self.audio_buffer = np.array([], dtype=np.float32) # TODO: Consider using a Tensor
|
|
||||||
self.i_since_data = 0 # Used to avoid logging every cycle if audio input stops
|
|
||||||
self.i_since_speech = 0 # Used to allow small pauses in speech
|
|
||||||
|
|
||||||
async def run(self) -> None:
|
|
||||||
timeout_ms = 100
|
|
||||||
data = await self.audio_in_poller.poll(timeout_ms)
|
|
||||||
if data is None:
|
|
||||||
if self.i_since_data % 10 == 0:
|
|
||||||
logger.debug("Failed to receive audio from socket for %d ms.",
|
|
||||||
timeout_ms*self.i_since_data)
|
|
||||||
self.i_since_data += 1
|
|
||||||
return
|
|
||||||
self.i_since_data = 0
|
|
||||||
|
|
||||||
# copy otherwise Torch will be sad that it's immutable
|
|
||||||
chunk = np.frombuffer(data, dtype=np.float32).copy()
|
|
||||||
prob = self.model(torch.from_numpy(chunk), 16000).item()
|
|
||||||
|
|
||||||
if prob > 0.5:
|
|
||||||
if self.i_since_speech > 3: logger.debug("Speech started.")
|
|
||||||
self.audio_buffer = np.append(self.audio_buffer, chunk)
|
|
||||||
self.i_since_speech = 0
|
|
||||||
return
|
|
||||||
self.i_since_speech += 1
|
|
||||||
|
|
||||||
# prob < 0.5, so speech maybe ended. Wait a bit more before to be more certain
|
|
||||||
if self.i_since_speech <= 3:
|
|
||||||
self.audio_buffer = np.append(self.audio_buffer, chunk)
|
|
||||||
return
|
|
||||||
|
|
||||||
# Speech probably ended. Make sure we have a usable amount of data.
|
|
||||||
if len(self.audio_buffer) >= 3*len(chunk):
|
|
||||||
logger.debug("Speech ended.")
|
|
||||||
self.audio_out_socket.send(self.audio_buffer)
|
|
||||||
|
|
||||||
# At this point, we know that the speech has ended.
|
|
||||||
# Prepend the last chunk that had no speech, for a more fluent boundary
|
|
||||||
self.audio_buffer = chunk
|
|
||||||
|
|
||||||
async def stop(self):
|
async def stop(self):
|
||||||
"""
|
"""
|
||||||
@@ -133,9 +135,10 @@ class VADAgent(Agent):
|
|||||||
if audio_out_port is None:
|
if audio_out_port is None:
|
||||||
await self.stop()
|
await self.stop()
|
||||||
return
|
return
|
||||||
|
audio_out_address = f"tcp://localhost:{audio_out_port}"
|
||||||
|
|
||||||
stream = self.Stream(self.audio_in_socket, self.audio_out_socket)
|
streaming = Streaming(self.audio_in_socket, self.audio_out_socket)
|
||||||
self.add_behaviour(stream)
|
self.add_behaviour(streaming)
|
||||||
|
|
||||||
# ... start agents dependent on the output audio fragments here
|
# ... start agents dependent on the output audio fragments here
|
||||||
|
|
||||||
|
|||||||
@@ -6,9 +6,6 @@ from pydantic_settings import BaseSettings, SettingsConfigDict
|
|||||||
class ZMQSettings(BaseModel):
|
class ZMQSettings(BaseModel):
|
||||||
internal_comm_address: str = "tcp://localhost:5560"
|
internal_comm_address: str = "tcp://localhost:5560"
|
||||||
|
|
||||||
audio_fragment_port: int = 5561
|
|
||||||
audio_fragment_address: str = f"tcp://localhost:{audio_fragment_port}"
|
|
||||||
|
|
||||||
|
|
||||||
class AgentSettings(BaseModel):
|
class AgentSettings(BaseModel):
|
||||||
host: str = "localhost"
|
host: str = "localhost"
|
||||||
|
|||||||
45
test/unit/agents/test_vad_streaming.py
Normal file
45
test/unit/agents/test_vad_streaming.py
Normal file
@@ -0,0 +1,45 @@
|
|||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from control_backend.agents.vad_agent import Streaming
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def audio_in_socket():
|
||||||
|
return AsyncMock()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def audio_out_socket():
|
||||||
|
return AsyncMock()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def streaming(audio_in_socket, audio_out_socket):
|
||||||
|
return Streaming(audio_in_socket, audio_out_socket)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_voice_activity_detected(audio_in_socket, audio_out_socket, streaming):
|
||||||
|
# After three chunks of audio with speech probability of 1.0, then four chunks of audio with
|
||||||
|
# speech probability of 0.0, it should send a message over the audio out socket
|
||||||
|
probabilities = [1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0]
|
||||||
|
model_item = MagicMock()
|
||||||
|
model_item.item.side_effect = probabilities
|
||||||
|
streaming.model = MagicMock()
|
||||||
|
streaming.model.return_value = model_item
|
||||||
|
|
||||||
|
audio_in_poller = AsyncMock()
|
||||||
|
audio_in_poller.poll.return_value = np.empty(shape=512, dtype=np.float32)
|
||||||
|
streaming.audio_in_poller = audio_in_poller
|
||||||
|
|
||||||
|
for _ in probabilities:
|
||||||
|
await streaming.run()
|
||||||
|
|
||||||
|
audio_out_socket.send.assert_called_once()
|
||||||
|
data = audio_out_socket.send.call_args[0][0]
|
||||||
|
assert isinstance(data, bytes)
|
||||||
|
# each sample has 512 frames of 4 bytes, expecting 5 chunks (3 with speech, 2 as padding)
|
||||||
|
assert len(data) == 512*4*5
|
||||||
Reference in New Issue
Block a user