Implement the VAD agent #10

Merged
0950726 merged 9 commits from feat/vad-agent into dev 2025-10-29 08:15:25 +00:00
4 changed files with 102 additions and 57 deletions
Showing only changes of commit ca5e59d029

View File

@@ -3,6 +3,7 @@ import logging
 import numpy as np
 import torch
 import zmq
+import zmq.asyncio as azmq
 from spade.agent import Agent
 from spade.behaviour import CyclicBehaviour
@@ -13,7 +14,7 @@ logger = logging.getLogger(__name__)
 class SocketPoller[T]:
-    def __init__(self, socket: zmq.Socket[T]):
+    def __init__(self, socket: azmq.Socket):
         self.socket = socket
         self.poller = zmq.Poller()
         self.poller.register(self.socket, zmq.POLLIN)
@@ -31,23 +32,8 @@ class SocketPoller[T]:
         return None


-class VADAgent(Agent):
-    """
-    An agent which listens to an audio stream, does Voice Activity Detection (VAD), and sends
-    fragments with detected speech to other agents over ZeroMQ.
-    """
-
-    def __init__(self, audio_in_address: str, audio_in_bind: bool):
-        jid = settings.agent_settings.vad_agent_name + '@' + settings.agent_settings.host
-        super().__init__(jid, settings.agent_settings.vad_agent_name)
-        self.audio_in_address = audio_in_address
-        self.audio_in_bind = audio_in_bind
-        self.audio_in_socket: zmq.Socket | None = None
-        self.audio_out_socket: zmq.Socket | None = None
-
-    class Stream(CyclicBehaviour):
-        def __init__(self, audio_in_socket: zmq.Socket, audio_out_socket: zmq.Socket):
+class Streaming(CyclicBehaviour):
+    def __init__(self, audio_in_socket: azmq.Socket, audio_out_socket: azmq.Socket):
         super().__init__()
         self.audio_in_poller = SocketPoller[bytes](audio_in_socket)
         self.model, _ = torch.hub.load(repo_or_dir="snakers4/silero-vad",
@@ -89,12 +75,28 @@ class VADAgent(Agent):
             # Speech probably ended. Make sure we have a usable amount of data.
             if len(self.audio_buffer) >= 3*len(chunk):
                 logger.debug("Speech ended.")
-                self.audio_out_socket.send(self.audio_buffer)
+                await self.audio_out_socket.send(self.audio_buffer.tobytes())
             # At this point, we know that the speech has ended.
             # Prepend the last chunk that had no speech, for a more fluent boundary
             self.audio_buffer = chunk

+
+class VADAgent(Agent):
+    """
+    An agent which listens to an audio stream, does Voice Activity Detection (VAD), and sends
+    fragments with detected speech to other agents over ZeroMQ.
+    """
+
+    def __init__(self, audio_in_address: str, audio_in_bind: bool):
+        jid = settings.agent_settings.vad_agent_name + '@' + settings.agent_settings.host
+        super().__init__(jid, settings.agent_settings.vad_agent_name)
+        self.audio_in_address = audio_in_address
+        self.audio_in_bind = audio_in_bind
+        self.audio_in_socket: azmq.Socket | None = None
+        self.audio_out_socket: azmq.Socket | None = None
+
     async def stop(self):
         """
         Stop listening to audio, stop publishing audio, close sockets.
@@ -133,9 +135,10 @@ class VADAgent(Agent):
         if audio_out_port is None:
             await self.stop()
             return
-        stream = self.Stream(self.audio_in_socket, self.audio_out_socket)
-        self.add_behaviour(stream)
+        audio_out_address = f"tcp://localhost:{audio_out_port}"
+        streaming = Streaming(self.audio_in_socket, self.audio_out_socket)
+        self.add_behaviour(streaming)
         # ... start agents dependent on the output audio fragments here
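
Note on the async move: the Streaming behaviour's sockets are now zmq.asyncio and send is awaited, but SocketPoller still appears to register them with a synchronous zmq.Poller, whose poll() blocks the event loop for up to its timeout. Below is a minimal sketch of a fully asynchronous poller; the class name, timeout default, and recv call are illustrative, not taken from this PR.

import zmq
import zmq.asyncio as azmq


class AsyncSocketPoller:
    """Hypothetical fully-async counterpart of SocketPoller (not part of this PR)."""

    def __init__(self, socket: azmq.Socket, timeout_ms: int = 100):
        self.socket = socket
        self.timeout_ms = timeout_ms
        # zmq.asyncio.Poller makes poll() awaitable, so waiting for the next
        # audio chunk does not block the agent's event loop.
        self.poller = azmq.Poller()
        self.poller.register(self.socket, zmq.POLLIN)

    async def poll(self) -> bytes | None:
        events = dict(await self.poller.poll(self.timeout_ms))
        if self.socket in events:
            return await self.socket.recv()
        return None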

View File

@@ -6,9 +6,6 @@ from pydantic_settings import BaseSettings, SettingsConfigDict
 class ZMQSettings(BaseModel):
     internal_comm_address: str = "tcp://localhost:5560"
-    audio_fragment_port: int = 5561
-    audio_fragment_address: str = f"tcp://localhost:{audio_fragment_port}"


 class AgentSettings(BaseModel):
     host: str = "localhost"
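
With the fixed audio_fragment_port removed from settings, the VAD agent has to learn its output port at runtime and build the address itself (see audio_out_address in the first file). A minimal sketch of that pattern using pyzmq's bind_to_random_port; the helper name and the PUSH socket type are assumptions, not taken from this PR.

import zmq
import zmq.asyncio as azmq


def open_audio_out_socket() -> tuple[azmq.Socket, str]:
    """Hypothetical helper: bind the fragment publisher to an OS-assigned port."""
    ctx = azmq.Context.instance()
    socket = ctx.socket(zmq.PUSH)  # socket type is an assumption
    # bind_to_random_port picks a free port and returns it, so no port needs
    # to be reserved in the settings module.
    port = socket.bind_to_random_port("tcp://*")
    return socket, f"tcp://localhost:{port}"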

View File

@@ -0,0 +1,45 @@
+from unittest.mock import AsyncMock, MagicMock
+
+import numpy as np
+import pytest
+
+from control_backend.agents.vad_agent import Streaming
+
+
+@pytest.fixture
+def audio_in_socket():
+    return AsyncMock()
+
+
+@pytest.fixture
+def audio_out_socket():
+    return AsyncMock()
+
+
+@pytest.fixture
+def streaming(audio_in_socket, audio_out_socket):
+    return Streaming(audio_in_socket, audio_out_socket)
+
+
+@pytest.mark.asyncio
+async def test_voice_activity_detected(audio_in_socket, audio_out_socket, streaming):
+    # After three chunks of audio with speech probability of 1.0, then four chunks of audio with
+    # speech probability of 0.0, it should send a message over the audio out socket.
+    probabilities = [1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0]
+
+    model_item = MagicMock()
+    model_item.item.side_effect = probabilities
+    streaming.model = MagicMock()
+    streaming.model.return_value = model_item
+
+    audio_in_poller = AsyncMock()
+    audio_in_poller.poll.return_value = np.empty(shape=512, dtype=np.float32)
+    streaming.audio_in_poller = audio_in_poller
+
+    for _ in probabilities:
+        await streaming.run()
+
+    audio_out_socket.send.assert_called_once()
+    data = audio_out_socket.send.call_args[0][0]
+    assert isinstance(data, bytes)
+    # Each chunk has 512 samples of 4 bytes; expecting 5 chunks (3 with speech, 2 as padding).
+    assert len(data) == 512*4*5
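
For reference, the expected payload size follows from the float32 buffer sent with .tobytes(): 512 samples × 4 bytes × 5 chunks = 10,240 bytes. A consumer could turn such a fragment back into audio samples as sketched below; the socket wiring is an assumption, only the dtype is implied by the diff.

import numpy as np
import zmq.asyncio as azmq


async def receive_fragment(socket: azmq.Socket) -> np.ndarray:
    """Hypothetical consumer side: decode one speech fragment into float32 samples."""
    data = await socket.recv()
    # The fragment is a raw float32 buffer, e.g. 10,240 bytes for the
    # 5 chunks of 512 samples asserted in the test above.
    return np.frombuffer(data, dtype=np.float32)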