# Previously, the TranscriptionAgent was started in main, but it should use
# values negotiated by the RI communication agent. ref: N25B-356
import asyncio
|
|
|
|
import numpy as np
|
|
import torch
|
|
import zmq
|
|
import zmq.asyncio as azmq
|
|
|
|
from control_backend.agents import BaseAgent
|
|
from control_backend.core.config import settings
|
|
|
|
from ...schemas.program_status import PROGRAM_STATUS, ProgramStatus
|
|
from .transcription_agent.transcription_agent import TranscriptionAgent
|
|
|
|
|
|
class SocketPoller[T]:
|
|
"""
|
|
Convenience class for polling a socket for data with a timeout, persisting a zmq.Poller for
|
|
multiple usages.
|
|
|
|
:param T: The type of data returned by the socket.
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
socket: azmq.Socket,
|
|
timeout_ms: int = settings.behaviour_settings.socket_poller_timeout_ms,
|
|
):
|
|
"""
|
|
:param socket: The socket to poll and get data from.
|
|
:param timeout_ms: A timeout in milliseconds to wait for data.
|
|
"""
|
|
self.socket = socket
|
|
self.poller = azmq.Poller()
|
|
self.poller.register(self.socket, zmq.POLLIN)
|
|
self.timeout_ms = timeout_ms
|
|
|
|
async def poll(self, timeout_ms: int | None = None) -> T | None:
|
|
"""
|
|
Get data from the socket, or None if the timeout is reached.
|
|
|
|
:param timeout_ms: If given, the timeout. Otherwise, ``self.timeout_ms`` is used.
|
|
:return: Data from the socket or None.
|
|
"""
|
|
timeout_ms = timeout_ms or self.timeout_ms
|
|
socks = dict(await self.poller.poll(timeout_ms))
|
|
if socks.get(self.socket) == zmq.POLLIN:
|
|
return await self.socket.recv()
|
|
return None
|
|
|
|
|
|
class VADAgent(BaseAgent):
    """
    Voice Activity Detection (VAD) Agent.

    This agent:
    1. Receives an audio stream (via ZMQ).
    2. Processes the audio using the Silero VAD model to detect speech.
    3. Buffers potential speech segments.
    4. Publishes valid speech fragments (containing speech plus small buffer) to a ZMQ PUB socket.
    5. Instantiates and starts agents (like :class:`TranscriptionAgent`) that use this output.

    :ivar audio_in_address: Address of the input audio stream.
    :ivar audio_in_bind: Whether to bind or connect to the input address.
    :ivar audio_out_socket: ZMQ PUB socket for sending speech fragments.
    :ivar program_sub_socket: ZMQ SUB socket for receiving program status updates.
    """

    def __init__(self, audio_in_address: str, audio_in_bind: bool):
        """
        Initialize the VAD Agent.

        :param audio_in_address: ZMQ address for input audio.
        :param audio_in_bind: True if this agent should bind to the input address, False to connect.
        """
        super().__init__(settings.agent_settings.vad_name)

        self.audio_in_address = audio_in_address
        self.audio_in_bind = audio_in_bind

        # Sockets are created lazily in setup(); None means "not connected yet".
        self.audio_in_socket: azmq.Socket | None = None
        self.audio_out_socket: azmq.Socket | None = None
        self.audio_in_poller: SocketPoller | None = None

        self.program_sub_socket: azmq.Socket | None = None

        # Rolling buffer of float32 samples that may belong to a speech segment.
        self.audio_buffer = np.array([], dtype=np.float32)
        # Count of consecutive chunks processed since speech was last detected.
        self.i_since_speech = settings.behaviour_settings.vad_initial_since_speech
        # Set once the program status is RUNNING and stale queued audio was discarded.
        self._ready = asyncio.Event()
        # Silero VAD model, loaded from Torch Hub in setup().
        self.model = None

    async def setup(self):
        """
        Initialize resources.

        1. Connects audio input socket.
        2. Binds audio output socket (random port).
        3. Connects to program communication socket.
        4. Loads VAD model from Torch Hub.
        5. Starts the streaming loop.
        6. Instantiates and starts the :class:`TranscriptionAgent` with the output address.
        """
        self.logger.info("Setting up %s", self.name)

        self._connect_audio_in_socket()

        audio_out_port = self._connect_audio_out_socket()
        if audio_out_port is None:
            self.logger.error("Could not bind output socket, stopping.")
            await self.stop()
            return
        audio_out_address = f"tcp://localhost:{audio_out_port}"

        # Connect to internal communication socket
        self.program_sub_socket = azmq.Context.instance().socket(zmq.SUB)
        self.program_sub_socket.connect(settings.zmq_settings.internal_sub_address)
        self.program_sub_socket.subscribe(PROGRAM_STATUS)

        # Initialize VAD model
        try:
            self.model, _ = torch.hub.load(
                repo_or_dir=settings.vad_settings.repo_or_dir,
                model=settings.vad_settings.model_name,
                force_reload=False,
            )
        except Exception:
            # torch.hub.load can fail for many reasons (network, cache, bad repo);
            # log the traceback and shut down cleanly instead of crashing the agent.
            self.logger.exception("Failed to load VAD model.")
            await self.stop()
            return

        self.add_behavior(self._streaming_loop())
        self.add_behavior(self._status_loop())

        # Start agents dependent on the output audio fragments here
        transcriber = TranscriptionAgent(audio_out_address)
        await transcriber.start()

        self.logger.info("Finished setting up %s", self.name)

    async def stop(self):
        """
        Stop listening to audio, stop publishing audio, close sockets.
        """
        if self.audio_in_socket is not None:
            self.audio_in_socket.close()
            self.audio_in_socket = None
        if self.audio_out_socket is not None:
            self.audio_out_socket.close()
            self.audio_out_socket = None
        # _status_loop closes this socket itself once the program is RUNNING, but if
        # the agent stops before that the socket would leak. close() is idempotent,
        # so closing again here is safe.
        if self.program_sub_socket is not None:
            self.program_sub_socket.close()
            self.program_sub_socket = None
        await super().stop()

    def _connect_audio_in_socket(self):
        """
        Connects (or binds) the socket for listening to audio from RI.

        Subscribes to all messages and wraps the socket in a :class:`SocketPoller`
        for timeout-based receiving.
        """
        self.audio_in_socket = azmq.Context.instance().socket(zmq.SUB)
        self.audio_in_socket.setsockopt_string(zmq.SUBSCRIBE, "")
        if self.audio_in_bind:
            self.audio_in_socket.bind(self.audio_in_address)
        else:
            self.audio_in_socket.connect(self.audio_in_address)
        self.audio_in_poller = SocketPoller[bytes](self.audio_in_socket)

    def _connect_audio_out_socket(self) -> int | None:
        """
        Bind the PUB socket for speech fragments to a random local TCP port.

        :return: The port bound, or None if binding failed.
        """
        try:
            self.audio_out_socket = azmq.Context.instance().socket(zmq.PUB)
            return self.audio_out_socket.bind_to_random_port("tcp://localhost", max_tries=100)
        except zmq.ZMQBindError:
            self.logger.error("Failed to bind an audio output socket after 100 tries.")
            self.audio_out_socket = None
            return None

    async def _reset_stream(self):
        """
        Clears the ZeroMQ queue and sets ready state.

        Drains any audio that queued up before the program started so the
        streaming loop begins with live data only.
        """
        discarded = 0
        assert self.audio_in_poller is not None
        # Poll with a 1 ms timeout until the queue is empty.
        while await self.audio_in_poller.poll(1) is not None:
            discarded += 1
        # Lazy %-style args, consistent with the other log calls in this file.
        self.logger.info("Discarded %d audio packets before starting.", discarded)
        self._ready.set()

    async def _status_loop(self):
        """Loop for checking program status. Only start listening if program is RUNNING."""
        while self._running:
            topic, body = await self.program_sub_socket.recv_multipart()

            if topic != PROGRAM_STATUS:
                continue
            # NOTE(review): recv_multipart yields bytes; this assumes PROGRAM_STATUS and
            # ProgramStatus.RUNNING.value are bytes-comparable — confirm against schema.
            if body != ProgramStatus.RUNNING.value:
                continue

            # Program is now running, we can start our stream
            await self._reset_stream()

            # We don't care about further status updates
            self.program_sub_socket.close()
            break

    async def _streaming_loop(self):
        """
        Main loop for processing audio stream.

        1. Polls for new audio chunks.
        2. Passes chunk to VAD model.
        3. Manages `i_since_speech` counter to determine start/end of speech.
        4. Buffers speech + context.
        5. Sends complete speech segment to output socket when silence is detected.
        """
        # Block until _status_loop signals that the program is RUNNING.
        await self._ready.wait()
        while self._running:
            assert self.audio_in_poller is not None
            data = await self.audio_in_poller.poll()
            if data is None:
                # Input stream stalled: drop any partial segment and reset the
                # speech counter so a fresh segment starts when data resumes.
                if len(self.audio_buffer) > 0:
                    self.logger.debug(
                        "No audio data received. Discarding buffer until new data arrives."
                    )
                    self.audio_buffer = np.array([], dtype=np.float32)
                    self.i_since_speech = settings.behaviour_settings.vad_initial_since_speech
                continue

            # copy otherwise Torch will be sad that it's immutable
            chunk = np.frombuffer(data, dtype=np.float32).copy()
            assert self.model is not None
            prob = self.model(torch.from_numpy(chunk), settings.vad_settings.sample_rate_hz).item()
            non_speech_patience = settings.behaviour_settings.vad_non_speech_patience_chunks
            prob_threshold = settings.behaviour_settings.vad_prob_threshold

            if prob > prob_threshold:
                if self.i_since_speech > non_speech_patience:
                    self.logger.debug("Speech started.")
                self.audio_buffer = np.append(self.audio_buffer, chunk)
                self.i_since_speech = 0
                continue

            self.i_since_speech += 1

            # prob < threshold, so speech maybe ended. Wait a bit more before to be more certain
            if self.i_since_speech <= non_speech_patience:
                self.audio_buffer = np.append(self.audio_buffer, chunk)
                continue

            # Speech probably ended. Make sure we have a usable amount of data.
            if len(self.audio_buffer) >= 3 * len(chunk):
                self.logger.debug("Speech ended.")
                assert self.audio_out_socket is not None
                # Trim the last two buffered (silent) chunks before publishing.
                # NOTE(review): the trim is a fixed 2 chunks, independent of
                # vad_non_speech_patience_chunks — confirm this is intended.
                await self.audio_out_socket.send(self.audio_buffer[: -2 * len(chunk)].tobytes())

            # At this point, we know that the speech has ended.
            # Prepend the last chunk that had no speech, for a more fluent boundary
            self.audio_buffer = chunk