Implement the VAD agent #10

Merged
0950726 merged 9 commits from feat/vad-agent into dev 2025-10-29 08:15:25 +00:00
4 changed files with 102 additions and 57 deletions
Showing only changes of commit ca5e59d029 - Show all commits

View File

@@ -3,6 +3,7 @@ import logging
import numpy as np
import torch
import zmq
import zmq.asyncio as azmq
from spade.agent import Agent
from spade.behaviour import CyclicBehaviour
@@ -13,7 +14,7 @@ logger = logging.getLogger(__name__)
class SocketPoller[T]:
def __init__(self, socket: zmq.Socket[T]):
def __init__(self, socket: azmq.Socket):
self.socket = socket
self.poller = zmq.Poller()
self.poller.register(self.socket, zmq.POLLIN)
@@ -31,23 +32,8 @@ class SocketPoller[T]:
return None
class VADAgent(Agent):
"""
An agent which listens to an audio stream, does Voice Activity Detection (VAD), and sends
fragments with detected speech to other agents over ZeroMQ.
"""
def __init__(self, audio_in_address: str, audio_in_bind: bool):
jid = settings.agent_settings.vad_agent_name + '@' + settings.agent_settings.host
super().__init__(jid, settings.agent_settings.vad_agent_name)
self.audio_in_address = audio_in_address
self.audio_in_bind = audio_in_bind
self.audio_in_socket: zmq.Socket | None = None
self.audio_out_socket: zmq.Socket | None = None
class Stream(CyclicBehaviour):
def __init__(self, audio_in_socket: zmq.Socket, audio_out_socket: zmq.Socket):
class Streaming(CyclicBehaviour):
def __init__(self, audio_in_socket: azmq.Socket, audio_out_socket: azmq.Socket):
super().__init__()
self.audio_in_poller = SocketPoller[bytes](audio_in_socket)
self.model, _ = torch.hub.load(repo_or_dir="snakers4/silero-vad",
@@ -89,12 +75,28 @@ class VADAgent(Agent):
# Speech probably ended. Make sure we have a usable amount of data.
if len(self.audio_buffer) >= 3*len(chunk):
logger.debug("Speech ended.")
self.audio_out_socket.send(self.audio_buffer)
await self.audio_out_socket.send(self.audio_buffer.tobytes())
# At this point, we know that the speech has ended.
# Prepend the last chunk that had no speech, for a more fluent boundary
self.audio_buffer = chunk
class VADAgent(Agent):
"""
An agent which listens to an audio stream, does Voice Activity Detection (VAD), and sends
fragments with detected speech to other agents over ZeroMQ.
"""
def __init__(self, audio_in_address: str, audio_in_bind: bool):
jid = settings.agent_settings.vad_agent_name + '@' + settings.agent_settings.host
super().__init__(jid, settings.agent_settings.vad_agent_name)
self.audio_in_address = audio_in_address
self.audio_in_bind = audio_in_bind
self.audio_in_socket: azmq.Socket | None = None
self.audio_out_socket: azmq.Socket | None = None
async def stop(self):
"""
Stop listening to audio, stop publishing audio, close sockets.
@@ -133,9 +135,10 @@ class VADAgent(Agent):
if audio_out_port is None:
await self.stop()
return
audio_out_address = f"tcp://localhost:{audio_out_port}"
stream = self.Stream(self.audio_in_socket, self.audio_out_socket)
self.add_behaviour(stream)
streaming = Streaming(self.audio_in_socket, self.audio_out_socket)
self.add_behaviour(streaming)
# ... start agents dependent on the output audio fragments here

View File

@@ -6,9 +6,6 @@ from pydantic_settings import BaseSettings, SettingsConfigDict
class ZMQSettings(BaseModel):
internal_comm_address: str = "tcp://localhost:5560"
audio_fragment_port: int = 5561
audio_fragment_address: str = f"tcp://localhost:{audio_fragment_port}"
class AgentSettings(BaseModel):
host: str = "localhost"

View File

@@ -0,0 +1,45 @@
from unittest.mock import AsyncMock, MagicMock
import numpy as np
import pytest
from control_backend.agents.vad_agent import Streaming
@pytest.fixture
def audio_in_socket():
return AsyncMock()
@pytest.fixture
def audio_out_socket():
return AsyncMock()
@pytest.fixture
def streaming(audio_in_socket, audio_out_socket):
return Streaming(audio_in_socket, audio_out_socket)
@pytest.mark.asyncio
async def test_voice_activity_detected(audio_in_socket, audio_out_socket, streaming):
# After three chunks of audio with speech probability of 1.0, then four chunks of audio with
# speech probability of 0.0, it should send a message over the audio out socket
probabilities = [1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0]
model_item = MagicMock()
model_item.item.side_effect = probabilities
streaming.model = MagicMock()
streaming.model.return_value = model_item
audio_in_poller = AsyncMock()
audio_in_poller.poll.return_value = np.empty(shape=512, dtype=np.float32)
streaming.audio_in_poller = audio_in_poller
for _ in probabilities:
await streaming.run()
audio_out_socket.send.assert_called_once()
data = audio_out_socket.send.call_args[0][0]
assert isinstance(data, bytes)
# each sample has 512 frames of 4 bytes, expecting 5 chunks (3 with speech, 2 as padding)
assert len(data) == 512*4*5