refactor: remove SPADE dependencies

Did not look at tests yet, this is a very non-final commit.

ref: N25B-300
This commit is contained in:
2025-11-20 14:35:28 +01:00
parent 6025721866
commit bb3f81d2e8
20 changed files with 757 additions and 1683 deletions

View File

@@ -1,8 +1,9 @@
import asyncio
import numpy as np
import torch
import zmq
import zmq.asyncio as azmq
from spade.behaviour import CyclicBehaviour
from control_backend.agents import BaseAgent
from control_backend.core.config import settings
@@ -26,7 +27,7 @@ class SocketPoller[T]:
:param timeout_ms: A timeout in milliseconds to wait for data.
"""
self.socket = socket
self.poller = zmq.Poller()
self.poller = azmq.Poller()
self.poller.register(self.socket, zmq.POLLIN)
self.timeout_ms = timeout_ms
@@ -38,81 +39,12 @@ class SocketPoller[T]:
:return: Data from the socket or None.
"""
timeout_ms = timeout_ms or self.timeout_ms
socks = dict(self.poller.poll(timeout_ms))
socks = dict(await self.poller.poll(timeout_ms))
if socks.get(self.socket) == zmq.POLLIN:
return await self.socket.recv()
return None
class StreamingBehaviour(CyclicBehaviour):
def __init__(self, audio_in_socket: azmq.Socket, audio_out_socket: azmq.Socket):
super().__init__()
self.audio_in_poller = SocketPoller[bytes](audio_in_socket)
self.model, _ = torch.hub.load(
repo_or_dir=settings.vad_settings.repo_or_dir,
model=settings.vad_settings.model_name,
force_reload=False,
)
self.audio_out_socket = audio_out_socket
self.audio_buffer = np.array([], dtype=np.float32)
self.i_since_speech = (
settings.behaviour_settings.vad_initial_since_speech
) # Used to allow small pauses in speech
self._ready = False
async def reset(self):
"""Clears the ZeroMQ queue and tells this behavior to start."""
discarded = 0
# Poll for the shortest amount of time possible to clear the queue
while await self.audio_in_poller.poll(1) is not None:
discarded += 1
self.agent.logger.info(f"Discarded {discarded} audio packets before starting.")
self._ready = True
async def run(self) -> None:
if not self._ready:
return
data = await self.audio_in_poller.poll()
if data is None:
if len(self.audio_buffer) > 0:
self.agent.logger.debug(
"No audio data received. Discarding buffer until new data arrives."
)
self.audio_buffer = np.array([], dtype=np.float32)
self.i_since_speech = settings.behaviour_settings.vad_initial_since_speech
return
# copy otherwise Torch will be sad that it's immutable
chunk = np.frombuffer(data, dtype=np.float32).copy()
prob = self.model(torch.from_numpy(chunk), settings.vad_settings.sample_rate_hz).item()
non_speech_patience = settings.behaviour_settings.vad_non_speech_patience_chunks
prob_threshold = settings.behaviour_settings.vad_prob_threshold
if prob > prob_threshold:
if self.i_since_speech > non_speech_patience:
self.agent.logger.debug("Speech started.")
self.audio_buffer = np.append(self.audio_buffer, chunk)
self.i_since_speech = 0
return
self.i_since_speech += 1
# prob < 0.5, so speech maybe ended. Wait a bit more before to be more certain
if self.i_since_speech <= non_speech_patience:
self.audio_buffer = np.append(self.audio_buffer, chunk)
return
# Speech probably ended. Make sure we have a usable amount of data.
if len(self.audio_buffer) >= 3 * len(chunk):
self.agent.logger.debug("Speech ended.")
await self.audio_out_socket.send(self.audio_buffer[: -2 * len(chunk)].tobytes())
# At this point, we know that the speech has ended.
# Prepend the last chunk that had no speech, for a more fluent boundary
self.audio_buffer = chunk
class VADAgent(BaseAgent):
"""
An agent which listens to an audio stream, does Voice Activity Detection (VAD), and sends
@@ -120,16 +52,54 @@ class VADAgent(BaseAgent):
"""
def __init__(self, audio_in_address: str, audio_in_bind: bool):
jid = settings.agent_settings.vad_name + "@" + settings.agent_settings.host
super().__init__(jid, settings.agent_settings.vad_name)
super().__init__(settings.agent_settings.vad_name)
self.audio_in_address = audio_in_address
self.audio_in_bind = audio_in_bind
self.audio_in_socket: azmq.Socket | None = None
self.audio_out_socket: azmq.Socket | None = None
self.audio_in_poller: SocketPoller | None = None
self.streaming_behaviour: StreamingBehaviour | None = None
self.audio_buffer = np.array([], dtype=np.float32)
self.i_since_speech = settings.behaviour_settings.vad_initial_since_speech
self._ready = False
self.model = None
async def setup(self):
self.logger.info("Setting up %s", self.jid)
self._connect_audio_in_socket()
audio_out_port = self._connect_audio_out_socket()
if audio_out_port is None:
self.logger.error("Could not bind output socket, stopping.")
await self.stop()
return
audio_out_address = f"tcp://localhost:{audio_out_port}"
# Initialize VAD model
try:
self.model, _ = torch.hub.load(
repo_or_dir=settings.vad_settings.repo_or_dir,
model=settings.vad_settings.model_name,
force_reload=False,
)
except Exception:
self.logger.exception("Failed to load VAD model.")
await self.stop()
return
# Warmup/reset
await self.reset_stream()
await self.add_background_task(self._streaming_loop())
# Start agents dependent on the output audio fragments here
transcriber = TranscriptionAgent(audio_out_address)
await transcriber.start()
self.logger.info("Finished setting up %s", self.jid)
async def stop(self):
"""
@@ -141,7 +111,7 @@ class VADAgent(BaseAgent):
if self.audio_out_socket is not None:
self.audio_out_socket.close()
self.audio_out_socket = None
return await super().stop()
await super().stop()
def _connect_audio_in_socket(self):
self.audio_in_socket = azmq.Context.instance().socket(zmq.SUB)
@@ -156,28 +126,67 @@ class VADAgent(BaseAgent):
"""Returns the port bound, or None if binding failed."""
try:
self.audio_out_socket = azmq.Context.instance().socket(zmq.PUB)
return self.audio_out_socket.bind_to_random_port("tcp://*", max_tries=100)
return self.audio_out_socket.bind_to_random_port("tcp://localhost", max_tries=100)
except zmq.ZMQBindError:
self.logger.error("Failed to bind an audio output socket after 100 tries.")
self.audio_out_socket = None
return None
async def setup(self):
self.logger.info("Setting up %s", self.jid)
async def reset_stream(self):
"""
Clears the ZeroMQ queue and sets ready state.
"""
discarded = 0
assert self.audio_in_poller is not None
while await self.audio_in_poller.poll(1) is not None:
discarded += 1
self.logger.info(f"Discarded {discarded} audio packets before starting.")
self._ready = True
self._connect_audio_in_socket()
async def _streaming_loop(self):
while self._running:
if not self._ready:
await asyncio.sleep(0.1)
continue
audio_out_port = self._connect_audio_out_socket()
if audio_out_port is None:
await self.stop()
return
audio_out_address = f"tcp://localhost:{audio_out_port}"
assert self.audio_in_poller is not None
data = await self.audio_in_poller.poll()
if data is None:
if len(self.audio_buffer) > 0:
self.logger.debug(
"No audio data received. Discarding buffer until new data arrives."
)
self.audio_buffer = np.array([], dtype=np.float32)
self.i_since_speech = settings.behaviour_settings.vad_initial_since_speech
continue
self.streaming_behaviour = StreamingBehaviour(self.audio_in_socket, self.audio_out_socket)
self.add_behaviour(self.streaming_behaviour)
# copy otherwise Torch will be sad that it's immutable
chunk = np.frombuffer(data, dtype=np.float32).copy()
assert self.model is not None
prob = self.model(torch.from_numpy(chunk), settings.vad_settings.sample_rate_hz).item()
non_speech_patience = settings.behaviour_settings.vad_non_speech_patience_chunks
prob_threshold = settings.behaviour_settings.vad_prob_threshold
# Start agents dependent on the output audio fragments here
transcriber = TranscriptionAgent(audio_out_address)
await transcriber.start()
if prob > prob_threshold:
if self.i_since_speech > non_speech_patience:
self.logger.debug("Speech started.")
self.audio_buffer = np.append(self.audio_buffer, chunk)
self.i_since_speech = 0
continue
self.logger.info("Finished setting up %s", self.jid)
self.i_since_speech += 1
# prob < threshold, so speech maybe ended. Wait a bit more before to be more certain
if self.i_since_speech <= non_speech_patience:
self.audio_buffer = np.append(self.audio_buffer, chunk)
continue
# Speech probably ended. Make sure we have a usable amount of data.
if len(self.audio_buffer) >= 3 * len(chunk):
self.logger.debug("Speech ended.")
assert self.audio_out_socket is not None
await self.audio_out_socket.send(self.audio_buffer[: -2 * len(chunk)].tobytes())
# At this point, we know that the speech has ended.
# Prepend the last chunk that had no speech, for a more fluent boundary
self.audio_buffer = chunk