Implement the VAD agent #10
142
src/control_backend/agents/vad_agent.py
Normal file
142
src/control_backend/agents/vad_agent.py
Normal file
@@ -0,0 +1,142 @@
|
||||
import logging
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import zmq
|
||||
from spade.agent import Agent
|
||||
from spade.behaviour import CyclicBehaviour
|
||||
|
||||
from control_backend.core.config import settings
|
||||
from control_backend.core.zmq_context import context as zmq_context
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SocketPoller[T]:
|
||||
def __init__(self, socket: zmq.Socket[T]):
|
||||
self.socket = socket
|
||||
self.poller = zmq.Poller()
|
||||
self.poller.register(self.socket, zmq.POLLIN)
|
||||
|
||||
async def poll(self, timeout_ms: int) -> T | None:
|
||||
"""
|
||||
Get data from the socket, or None if the timeout is reached.
|
||||
|
||||
:param timeout_ms: The number of milliseconds to wait for the socket.
|
||||
:return: Data from the socket or None.
|
||||
"""
|
||||
socks = dict(self.poller.poll(timeout_ms))
|
||||
if socks.get(self.socket) == zmq.POLLIN:
|
||||
return await self.socket.recv()
|
||||
return None
|
||||
|
||||
|
||||
class VADAgent(Agent):
|
||||
"""
|
||||
An agent which listens to an audio stream, does Voice Activity Detection (VAD), and sends
|
||||
fragments with detected speech to other agents over ZeroMQ.
|
||||
"""
|
||||
def __init__(self, audio_in_address: str, audio_in_bind: bool):
|
||||
jid = settings.agent_settings.vad_agent_name + '@' + settings.agent_settings.host
|
||||
super().__init__(jid, settings.agent_settings.vad_agent_name)
|
||||
|
||||
self.audio_in_address = audio_in_address
|
||||
self.audio_in_bind = audio_in_bind
|
||||
|
||||
self.audio_in_socket: zmq.Socket | None = None
|
||||
self.audio_out_socket: zmq.Socket | None = None
|
||||
|
||||
class Stream(CyclicBehaviour):
|
||||
def __init__(self, audio_in_socket: zmq.Socket, audio_out_socket: zmq.Socket):
|
||||
super().__init__()
|
||||
self.audio_in_poller = SocketPoller[bytes](audio_in_socket)
|
||||
self.model, _ = torch.hub.load(repo_or_dir="snakers4/silero-vad",
|
||||
model="silero_vad",
|
||||
force_reload=False)
|
||||
self.audio_out_socket = audio_out_socket
|
||||
|
||||
self.audio_buffer = np.array([], dtype=np.float32) # TODO: Consider using a Tensor
|
||||
self.i_since_data = 0 # Used to avoid logging every cycle if audio input stops
|
||||
self.i_since_speech = 0 # Used to allow small pauses in speech
|
||||
|
||||
async def run(self) -> None:
|
||||
timeout_ms = 100
|
||||
data = await self.audio_in_poller.poll(timeout_ms)
|
||||
if data is None:
|
||||
if self.i_since_data % 10 == 0:
|
||||
logger.debug("Failed to receive audio from socket for %d ms.",
|
||||
timeout_ms*self.i_since_data)
|
||||
self.i_since_data += 1
|
||||
return
|
||||
self.i_since_data = 0
|
||||
|
||||
# copy otherwise Torch will be sad that it's immutable
|
||||
chunk = np.frombuffer(data, dtype=np.float32).copy()
|
||||
prob = self.model(torch.from_numpy(chunk), 16000).item()
|
||||
|
||||
if prob > 0.5:
|
||||
if self.i_since_speech > 3: logger.debug("Speech started.")
|
||||
self.audio_buffer = np.append(self.audio_buffer, chunk)
|
||||
self.i_since_speech = 0
|
||||
return
|
||||
self.i_since_speech += 1
|
||||
|
||||
# prob < 0.5, so speech maybe ended. Wait a bit more before to be more certain
|
||||
if self.i_since_speech <= 3:
|
||||
self.audio_buffer = np.append(self.audio_buffer, chunk)
|
||||
return
|
||||
|
||||
# Speech probably ended. Make sure we have a usable amount of data.
|
||||
if len(self.audio_buffer) >= 3*len(chunk):
|
||||
logger.debug("Speech ended.")
|
||||
self.audio_out_socket.send(self.audio_buffer)
|
||||
|
||||
# At this point, we know that the speech has ended.
|
||||
# Prepend the last chunk that had no speech, for a more fluent boundary
|
||||
self.audio_buffer = chunk
|
||||
|
||||
async def stop(self):
|
||||
"""
|
||||
Stop listening to audio, stop publishing audio, close sockets.
|
||||
"""
|
||||
self.audio_in_socket.close()
|
||||
self.audio_in_socket = None
|
||||
self.audio_out_socket.close()
|
||||
self.audio_out_socket = None
|
||||
return await super().stop()
|
||||
|
||||
def _connect_audio_in_socket(self):
|
||||
self.audio_in_socket = zmq_context.socket(zmq.SUB)
|
||||
self.audio_in_socket.setsockopt_string(zmq.SUBSCRIBE, "")
|
||||
if self.audio_in_bind:
|
||||
self.audio_in_socket.bind(self.audio_in_address)
|
||||
else:
|
||||
self.audio_in_socket.connect(self.audio_in_address)
|
||||
self.audio_in_poller = SocketPoller[bytes](self.audio_in_socket)
|
||||
|
||||
def _connect_audio_out_socket(self) -> int | None:
|
||||
"""Returns the port bound, or None if binding failed."""
|
||||
try:
|
||||
self.audio_out_socket = zmq_context.socket(zmq.PUB)
|
||||
return self.audio_out_socket.bind_to_random_port("tcp://*", max_tries=100)
|
||||
except zmq.ZMQBindError:
|
||||
logger.error("Failed to bind an audio output socket after 100 tries.")
|
||||
self.audio_out_socket = None
|
||||
return None
|
||||
|
||||
async def setup(self):
|
||||
logger.info("Setting up %s", self.jid)
|
||||
|
||||
self._connect_audio_in_socket()
|
||||
|
||||
audio_out_port = self._connect_audio_out_socket()
|
||||
if audio_out_port is None:
|
||||
await self.stop()
|
||||
return
|
||||
|
||||
stream = self.Stream(self.audio_in_socket, self.audio_out_socket)
|
||||
self.add_behaviour(stream)
|
||||
|
||||
# ... start agents dependent on the output audio fragments here
|
||||
|
||||
logger.info("Finished setting up %s", self.jid)
|
||||
@@ -2,14 +2,20 @@ from re import L
|
||||
from pydantic import BaseModel
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
|
||||
class ZMQSettings(BaseModel):
|
||||
internal_comm_address: str = "tcp://localhost:5560"
|
||||
|
||||
audio_fragment_port: int = 5561
|
||||
audio_fragment_address: str = f"tcp://localhost:{audio_fragment_port}"
|
||||
|
||||
|
||||
class AgentSettings(BaseModel):
|
||||
host: str = "localhost"
|
||||
bdi_core_agent_name: str = "bdi_core"
|
||||
belief_collector_agent_name: str = "belief_collector"
|
||||
test_agent_name: str = "test_agent"
|
||||
vad_agent_name: str = "vad_agent"
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
app_title: str = "PepperPlus"
|
||||
@@ -22,4 +28,5 @@ class Settings(BaseSettings):
|
||||
|
||||
model_config = SettingsConfigDict(env_file=".env")
|
||||
|
||||
|
||||
settings = Settings()
|
||||
|
||||
@@ -13,6 +13,7 @@ import zmq
|
||||
|
||||
# Internal imports
|
||||
from control_backend.agents.bdi.bdi_core import BDICoreAgent
|
||||
from control_backend.agents.vad_agent import VADAgent
|
||||
from control_backend.api.v1.router import api_router
|
||||
from control_backend.core.config import AgentSettings, settings
|
||||
from control_backend.core.zmq_context import context
|
||||
@@ -35,6 +36,9 @@ async def lifespan(app: FastAPI):
|
||||
bdi_core = BDICoreAgent(settings.agent_settings.bdi_core_agent_name + '@' + settings.agent_settings.host, settings.agent_settings.bdi_core_agent_name, "src/control_backend/agents/bdi/rules.asl")
|
||||
await bdi_core.start()
|
||||
|
||||
_temp_vad_agent = VADAgent("tcp://localhost:5558", False)
|
||||
await _temp_vad_agent.start()
|
||||
|
||||
yield
|
||||
|
||||
logger.info("%s shutting down.", app.title)
|
||||
|
||||
Reference in New Issue
Block a user