refactor: remove SPADE dependencies
Did not look at tests yet; this is a very non-final commit. ref: N25B-300
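In short: agents no longer address each other by XMPP JID through a SPADE server; they use plain agent names and the in-process InternalMessage from control_backend.core.agent_system. Both snippets below are lifted verbatim from the transcription agent's _share_transcription hunks, placed side by side for orientation:

```python
# Before (SPADE): receivers are XMPP JIDs; delivery goes through an XMPP server.
receiver_jid = settings.agent_settings.text_belief_extractor_name + "@" + settings.agent_settings.host
message = Message(to=receiver_jid, body=transcription)
await self.send(message)

# After: receivers are plain agent names; delivery stays in-process.
message = InternalMessage(
    to=settings.agent_settings.text_belief_extractor_name,
    sender=self.name,
    body=transcription,
)
await self.send(message)
```

The first diff below covers the transcription agent; the second covers the VAD agent, whose StreamingBehaviour is folded into the agent itself as a background loop.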
```diff
@@ -3,10 +3,9 @@ import asyncio
 import numpy as np
 import zmq
 import zmq.asyncio as azmq
-from spade.behaviour import CyclicBehaviour
-from spade.message import Message
 
 from control_backend.agents import BaseAgent
+from control_backend.core.agent_system import InternalMessage
 from control_backend.core.config import settings
 
 from .speech_recognizer import SpeechRecognizer
@@ -19,53 +18,31 @@ class TranscriptionAgent(BaseAgent):
     """
 
     def __init__(self, audio_in_address: str):
-        jid = settings.agent_settings.transcription_name + "@" + settings.agent_settings.host
-        super().__init__(jid, settings.agent_settings.transcription_name)
+        super().__init__(settings.agent_settings.transcription_name)
 
         self.audio_in_address = audio_in_address
         self.audio_in_socket: azmq.Socket | None = None
+        self.speech_recognizer = None
+        self._concurrency = None
 
-    class TranscribingBehaviour(CyclicBehaviour):
-        def __init__(self, audio_in_socket: azmq.Socket):
-            super().__init__()
-            max_concurrent_tasks = settings.behaviour_settings.transcription_max_concurrent_tasks
-            self.audio_in_socket = audio_in_socket
-            self.speech_recognizer = SpeechRecognizer.best_type()
-            self._concurrency = asyncio.Semaphore(max_concurrent_tasks)
+    async def setup(self):
+        self.logger.info("Setting up %s", self.name)
 
-        def warmup(self):
-            """Load the transcription model into memory to speed up the first transcription."""
-            self.speech_recognizer.load_model()
+        self._connect_audio_in_socket()
 
-        async def _transcribe(self, audio: np.ndarray) -> str:
-            async with self._concurrency:
-                return await asyncio.to_thread(self.speech_recognizer.recognize_speech, audio)
+        # Initialize recognizer and semaphore
+        max_concurrent_tasks = settings.behaviour_settings.transcription_max_concurrent_tasks
+        self._concurrency = asyncio.Semaphore(max_concurrent_tasks)
+        self.speech_recognizer = SpeechRecognizer.best_type()
+        self.speech_recognizer.load_model()  # Warmup
 
-        async def _share_transcription(self, transcription: str):
-            """Share a transcription to the other agents that depend on it."""
-            receiver_jids = [
-                settings.agent_settings.text_belief_extractor_name
-                + "@"
-                + settings.agent_settings.host,
-            ]  # Set message receivers here
+        # Start background loop
+        await self.add_background_task(self._transcribing_loop())
 
-            for receiver_jid in receiver_jids:
-                message = Message(to=receiver_jid, body=transcription)
-                await self.send(message)
-
-        async def run(self) -> None:
-            audio = await self.audio_in_socket.recv()
-            audio = np.frombuffer(audio, dtype=np.float32)
-            speech = await self._transcribe(audio)
-            if not speech:
-                self.agent.logger.info("Nothing transcribed.")
-                return
-
-            self.agent.logger.info("Transcribed speech: %s", speech)
-
-            await self._share_transcription(speech)
+        self.logger.info("Finished setting up %s", self.name)
 
     async def stop(self):
         assert self.audio_in_socket is not None
         self.audio_in_socket.close()
         self.audio_in_socket = None
         return await super().stop()
@@ -75,13 +52,37 @@ class TranscriptionAgent(BaseAgent):
         self.audio_in_socket.setsockopt_string(zmq.SUBSCRIBE, "")
         self.audio_in_socket.connect(self.audio_in_address)
 
-    async def setup(self):
-        self.logger.info("Setting up %s", self.jid)
+    async def _transcribe(self, audio: np.ndarray) -> str:
+        assert self._concurrency is not None and self.speech_recognizer is not None
+        async with self._concurrency:
+            return await asyncio.to_thread(self.speech_recognizer.recognize_speech, audio)
 
-        self._connect_audio_in_socket()
+    async def _share_transcription(self, transcription: str):
+        """Share a transcription to the other agents that depend on it."""
+        receiver_names = [
+            settings.agent_settings.text_belief_extractor_name,
+        ]
 
-        transcribing = self.TranscribingBehaviour(self.audio_in_socket)
-        transcribing.warmup()
-        self.add_behaviour(transcribing)
+        for receiver_name in receiver_names:
+            message = InternalMessage(
+                to=receiver_name,
+                sender=self.name,
+                body=transcription,
+            )
+            await self.send(message)
 
-        self.logger.info("Finished setting up %s", self.jid)
+    async def _transcribing_loop(self) -> None:
+        while self._running:
+            try:
+                assert self.audio_in_socket is not None
+                audio_data = await self.audio_in_socket.recv()
+                audio = np.frombuffer(audio_data, dtype=np.float32)
+                speech = await self._transcribe(audio)
+                if not speech:
+                    self.logger.info("Nothing transcribed.")
+                    continue
+
+                self.logger.info("Transcribed speech: %s", speech)
+                await self._share_transcription(speech)
+            except Exception as e:
+                self.logger.error(f"Error in transcription loop: {e}")
```
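With SPADE gone, CyclicBehaviour.run() no longer drives the agents; each agent now schedules a plain coroutine via add_background_task and loops on a _running flag. Neither helper is defined in this commit, so here is a minimal sketch of the contract the new code appears to assume (hypothetical; this is not the real BaseAgent):

```python
import asyncio


class BaseAgentSketch:
    """Hypothetical stand-in for control_backend.agents.BaseAgent; the real
    class is not part of this commit."""

    def __init__(self, name: str):
        self.name = name
        self._running = False
        self._tasks: list[asyncio.Task] = []

    async def start(self) -> None:
        self._running = True  # set before setup() so freshly scheduled loops run
        await self.setup()

    async def add_background_task(self, coro) -> None:
        # Schedule a long-running coroutine and keep a handle for cancellation.
        self._tasks.append(asyncio.create_task(coro))

    async def stop(self) -> None:
        self._running = False  # loops like _transcribing_loop() observe this flag
        for task in self._tasks:
            task.cancel()

    async def setup(self) -> None:
        pass  # overridden by subclasses such as TranscriptionAgent
```

The VAD agent below follows the same pattern, absorbing its StreamingBehaviour into a _streaming_loop: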
```diff
@@ -1,8 +1,9 @@
+import asyncio
+
 import numpy as np
 import torch
 import zmq
 import zmq.asyncio as azmq
-from spade.behaviour import CyclicBehaviour
 
 from control_backend.agents import BaseAgent
 from control_backend.core.config import settings
```
```diff
@@ -26,7 +27,7 @@ class SocketPoller[T]:
         :param timeout_ms: A timeout in milliseconds to wait for data.
         """
         self.socket = socket
-        self.poller = zmq.Poller()
+        self.poller = azmq.Poller()
         self.poller.register(self.socket, zmq.POLLIN)
         self.timeout_ms = timeout_ms
 
```
```diff
@@ -38,81 +39,12 @@ class SocketPoller[T]:
         :return: Data from the socket or None.
         """
         timeout_ms = timeout_ms or self.timeout_ms
-        socks = dict(self.poller.poll(timeout_ms))
+        socks = dict(await self.poller.poll(timeout_ms))
         if socks.get(self.socket) == zmq.POLLIN:
             return await self.socket.recv()
         return None
 
 
-class StreamingBehaviour(CyclicBehaviour):
-    def __init__(self, audio_in_socket: azmq.Socket, audio_out_socket: azmq.Socket):
-        super().__init__()
-        self.audio_in_poller = SocketPoller[bytes](audio_in_socket)
-        self.model, _ = torch.hub.load(
-            repo_or_dir=settings.vad_settings.repo_or_dir,
-            model=settings.vad_settings.model_name,
-            force_reload=False,
-        )
-        self.audio_out_socket = audio_out_socket
-
-        self.audio_buffer = np.array([], dtype=np.float32)
-        self.i_since_speech = (
-            settings.behaviour_settings.vad_initial_since_speech
-        )  # Used to allow small pauses in speech
-        self._ready = False
-
-    async def reset(self):
-        """Clears the ZeroMQ queue and tells this behavior to start."""
-        discarded = 0
-        # Poll for the shortest amount of time possible to clear the queue
-        while await self.audio_in_poller.poll(1) is not None:
-            discarded += 1
-        self.agent.logger.info(f"Discarded {discarded} audio packets before starting.")
-        self._ready = True
-
-    async def run(self) -> None:
-        if not self._ready:
-            return
-
-        data = await self.audio_in_poller.poll()
-        if data is None:
-            if len(self.audio_buffer) > 0:
-                self.agent.logger.debug(
-                    "No audio data received. Discarding buffer until new data arrives."
-                )
-                self.audio_buffer = np.array([], dtype=np.float32)
-                self.i_since_speech = settings.behaviour_settings.vad_initial_since_speech
-            return
-
-        # copy otherwise Torch will be sad that it's immutable
-        chunk = np.frombuffer(data, dtype=np.float32).copy()
-        prob = self.model(torch.from_numpy(chunk), settings.vad_settings.sample_rate_hz).item()
-        non_speech_patience = settings.behaviour_settings.vad_non_speech_patience_chunks
-        prob_threshold = settings.behaviour_settings.vad_prob_threshold
-
-        if prob > prob_threshold:
-            if self.i_since_speech > non_speech_patience:
-                self.agent.logger.debug("Speech started.")
-            self.audio_buffer = np.append(self.audio_buffer, chunk)
-            self.i_since_speech = 0
-            return
-        self.i_since_speech += 1
-
-        # prob < 0.5, so speech maybe ended. Wait a bit more before to be more certain
-        if self.i_since_speech <= non_speech_patience:
-            self.audio_buffer = np.append(self.audio_buffer, chunk)
-            return
-
-        # Speech probably ended. Make sure we have a usable amount of data.
-        if len(self.audio_buffer) >= 3 * len(chunk):
-            self.agent.logger.debug("Speech ended.")
-            await self.audio_out_socket.send(self.audio_buffer[: -2 * len(chunk)].tobytes())
-
-        # At this point, we know that the speech has ended.
-        # Prepend the last chunk that had no speech, for a more fluent boundary
-        self.audio_buffer = chunk
-
-
 class VADAgent(BaseAgent):
     """
     An agent which listens to an audio stream, does Voice Activity Detection (VAD), and sends
```
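The one-line Poller swap above is load-bearing: zmq.Poller.poll() blocks the whole event loop while waiting, whereas zmq.asyncio.Poller.poll() returns an awaitable, so other agents keep running. A self-contained sketch of the async pattern (standalone illustration; the address is made up):

```python
import asyncio

import zmq
import zmq.asyncio as azmq


async def poll_once(socket: azmq.Socket, timeout_ms: int) -> bytes | None:
    """Receive one message if it arrives within timeout_ms, else return None."""
    poller = azmq.Poller()
    poller.register(socket, zmq.POLLIN)
    socks = dict(await poller.poll(timeout_ms))  # awaitable: yields to the event loop
    if socks.get(socket) == zmq.POLLIN:
        return await socket.recv()
    return None


async def main():
    sub = azmq.Context.instance().socket(zmq.SUB)
    sub.setsockopt_string(zmq.SUBSCRIBE, "")
    sub.connect("tcp://localhost:5555")  # made-up address
    print(await poll_once(sub, timeout_ms=100))


asyncio.run(main())
```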
```diff
@@ -120,16 +52,54 @@ class VADAgent(BaseAgent):
     """
 
     def __init__(self, audio_in_address: str, audio_in_bind: bool):
-        jid = settings.agent_settings.vad_name + "@" + settings.agent_settings.host
-        super().__init__(jid, settings.agent_settings.vad_name)
+        super().__init__(settings.agent_settings.vad_name)
 
         self.audio_in_address = audio_in_address
         self.audio_in_bind = audio_in_bind
 
         self.audio_in_socket: azmq.Socket | None = None
         self.audio_out_socket: azmq.Socket | None = None
+        self.audio_in_poller: SocketPoller | None = None
 
-        self.streaming_behaviour: StreamingBehaviour | None = None
+        self.audio_buffer = np.array([], dtype=np.float32)
+        self.i_since_speech = settings.behaviour_settings.vad_initial_since_speech
+        self._ready = False
+        self.model = None
+
+    async def setup(self):
+        self.logger.info("Setting up %s", self.jid)
+
+        self._connect_audio_in_socket()
+
+        audio_out_port = self._connect_audio_out_socket()
+        if audio_out_port is None:
+            self.logger.error("Could not bind output socket, stopping.")
+            await self.stop()
+            return
+        audio_out_address = f"tcp://localhost:{audio_out_port}"
+
+        # Initialize VAD model
+        try:
+            self.model, _ = torch.hub.load(
+                repo_or_dir=settings.vad_settings.repo_or_dir,
+                model=settings.vad_settings.model_name,
+                force_reload=False,
+            )
+        except Exception:
+            self.logger.exception("Failed to load VAD model.")
+            await self.stop()
+            return
+
+        # Warmup/reset
+        await self.reset_stream()
+
+        await self.add_background_task(self._streaming_loop())
+
+        # Start agents dependent on the output audio fragments here
+        transcriber = TranscriptionAgent(audio_out_address)
+        await transcriber.start()
+
+        self.logger.info("Finished setting up %s", self.jid)
 
     async def stop(self):
         """
```
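torch.hub.load is now called inside setup() with values from settings, which this commit doesn't show. The call shape matches the Silero VAD distribution on torch.hub; under that assumption (not confirmed by the diff), loading and scoring a chunk looks like this:

```python
import torch

# Assumption: the settings point at Silero VAD (the actual repo_or_dir and
# model_name values live in config and are not shown in this commit).
model, _utils = torch.hub.load(
    repo_or_dir="snakers4/silero-vad",  # hypothetical settings.vad_settings.repo_or_dir
    model="silero_vad",                 # hypothetical settings.vad_settings.model_name
    force_reload=False,
)

# Per-chunk scoring as in _streaming_loop: a float32 chunk in, P(speech) out.
chunk = torch.zeros(512)  # recent Silero releases expect 512 samples at 16 kHz
prob = model(chunk, 16000).item()
print(f"speech probability: {prob:.3f}")
```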
```diff
@@ -141,7 +111,7 @@ class VADAgent(BaseAgent):
         if self.audio_out_socket is not None:
             self.audio_out_socket.close()
             self.audio_out_socket = None
-        return await super().stop()
+        await super().stop()
 
     def _connect_audio_in_socket(self):
         self.audio_in_socket = azmq.Context.instance().socket(zmq.SUB)
```
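Besides the loop rewrite, the next hunk narrows the PUB bind from all interfaces (tcp://*) to the loopback interface. For reference, bind_to_random_port tries ports in pyzmq's default ephemeral range, returns the one that bound, and raises zmq.ZMQBindError once max_tries is exhausted; a minimal synchronous illustration:

```python
import zmq

pub = zmq.Context.instance().socket(zmq.PUB)

# Returns the bound port, or raises zmq.ZMQBindError after max_tries failures.
port = pub.bind_to_random_port("tcp://127.0.0.1", max_tries=100)
print(f"publishing on tcp://127.0.0.1:{port}")
```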
```diff
@@ -156,28 +126,67 @@ class VADAgent(BaseAgent):
         """Returns the port bound, or None if binding failed."""
         try:
             self.audio_out_socket = azmq.Context.instance().socket(zmq.PUB)
-            return self.audio_out_socket.bind_to_random_port("tcp://*", max_tries=100)
+            return self.audio_out_socket.bind_to_random_port("tcp://localhost", max_tries=100)
         except zmq.ZMQBindError:
             self.logger.error("Failed to bind an audio output socket after 100 tries.")
             self.audio_out_socket = None
             return None
 
-    async def setup(self):
-        self.logger.info("Setting up %s", self.jid)
+    async def reset_stream(self):
+        """
+        Clears the ZeroMQ queue and sets ready state.
+        """
+        discarded = 0
+        assert self.audio_in_poller is not None
+        while await self.audio_in_poller.poll(1) is not None:
+            discarded += 1
+        self.logger.info(f"Discarded {discarded} audio packets before starting.")
+        self._ready = True
 
-        self._connect_audio_in_socket()
+    async def _streaming_loop(self):
+        while self._running:
+            if not self._ready:
+                await asyncio.sleep(0.1)
+                continue
 
-        audio_out_port = self._connect_audio_out_socket()
-        if audio_out_port is None:
-            await self.stop()
-            return
-        audio_out_address = f"tcp://localhost:{audio_out_port}"
+            assert self.audio_in_poller is not None
+            data = await self.audio_in_poller.poll()
+            if data is None:
+                if len(self.audio_buffer) > 0:
+                    self.logger.debug(
+                        "No audio data received. Discarding buffer until new data arrives."
+                    )
+                    self.audio_buffer = np.array([], dtype=np.float32)
+                    self.i_since_speech = settings.behaviour_settings.vad_initial_since_speech
+                continue
 
-        self.streaming_behaviour = StreamingBehaviour(self.audio_in_socket, self.audio_out_socket)
-        self.add_behaviour(self.streaming_behaviour)
+            # copy otherwise Torch will be sad that it's immutable
+            chunk = np.frombuffer(data, dtype=np.float32).copy()
+            assert self.model is not None
+            prob = self.model(torch.from_numpy(chunk), settings.vad_settings.sample_rate_hz).item()
+            non_speech_patience = settings.behaviour_settings.vad_non_speech_patience_chunks
+            prob_threshold = settings.behaviour_settings.vad_prob_threshold
 
-        # Start agents dependent on the output audio fragments here
-        transcriber = TranscriptionAgent(audio_out_address)
-        await transcriber.start()
+            if prob > prob_threshold:
+                if self.i_since_speech > non_speech_patience:
+                    self.logger.debug("Speech started.")
+                self.audio_buffer = np.append(self.audio_buffer, chunk)
+                self.i_since_speech = 0
+                continue
 
-        self.logger.info("Finished setting up %s", self.jid)
+            self.i_since_speech += 1
+
+            # prob < threshold, so speech maybe ended. Wait a bit more before to be more certain
+            if self.i_since_speech <= non_speech_patience:
+                self.audio_buffer = np.append(self.audio_buffer, chunk)
+                continue
+
+            # Speech probably ended. Make sure we have a usable amount of data.
+            if len(self.audio_buffer) >= 3 * len(chunk):
+                self.logger.debug("Speech ended.")
+                assert self.audio_out_socket is not None
+                await self.audio_out_socket.send(self.audio_buffer[: -2 * len(chunk)].tobytes())
+
+            # At this point, we know that the speech has ended.
+            # Prepend the last chunk that had no speech, for a more fluent boundary
+            self.audio_buffer = chunk
```
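A note on the buffer arithmetic at the end of _streaming_loop, which is easy to misread: once i_since_speech exceeds the patience threshold, the tail of the buffer holds the silent chunks that were appended while waiting for certainty, and the send trims a fixed two chunks off that tail. A standalone illustration with made-up sizes:

```python
import numpy as np

CHUNK = 512   # samples per chunk; made-up for this example
PATIENCE = 2  # stand-in for settings.behaviour_settings.vad_non_speech_patience_chunks


def finalize_utterance(buffer: np.ndarray) -> np.ndarray | None:
    """Mirror of the send decision, reached once PATIENCE silent chunks sit at the tail."""
    if len(buffer) < 3 * CHUNK:
        return None  # too little data to be usable speech
    return buffer[: -2 * CHUNK]  # drop (most of) the silent tail before publishing


buffer = np.ones(3 * CHUNK, dtype=np.float32)                             # 3 chunks of speech...
buffer = np.append(buffer, np.zeros(PATIENCE * CHUNK, dtype=np.float32))  # ...plus the silent tail
out = finalize_utterance(buffer)
assert out is not None and len(out) == 3 * CHUNK
```

Note that the fixed 2 matches the silent tail exactly only when vad_non_speech_patience_chunks is 2; larger patience values would ship some trailing silence.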