Merge remote-tracking branch 'origin/dev' into refactor/config-file
# Conflicts:
#	src/control_backend/agents/ri_communication_agent.py
#	src/control_backend/core/config.py
#	src/control_backend/main.py
@@ -0,0 +1,111 @@
import abc
import sys

if sys.platform == "darwin":
    import mlx.core as mx
    import mlx_whisper
    from mlx_whisper.transcribe import ModelHolder

import numpy as np
import torch
import whisper

from control_backend.core.config import settings


class SpeechRecognizer(abc.ABC):
    def __init__(self, limit_output_length=True):
        """
        :param limit_output_length: When `True`, the length of the generated speech will be limited
            by the length of the input audio and some heuristics.
        """
        self.limit_output_length = limit_output_length

    @abc.abstractmethod
    def load_model(self): ...

    @abc.abstractmethod
    def recognize_speech(self, audio: np.ndarray) -> str:
        """
        Recognize speech from the given audio sample.

        :param audio: A full utterance sample. Audio must be 16 kHz, mono, np.float32, values in the
            range [-1.0, 1.0].
        :return: Recognized speech.
        """

    @staticmethod
    def _estimate_max_tokens(audio: np.ndarray) -> int:
        """
        Estimate the maximum length of a given audio sample in tokens. Assumes a maximum speaking
        rate of 450 words per minute (3x average), and assumes that 3 words is 4 tokens.

        :param audio: The audio sample (16 kHz) to use for length estimation.
        :return: The estimated length of the transcribed audio in tokens.
        """
        length_seconds = len(audio) / settings.vad_settings.sample_rate_hz
        length_minutes = length_seconds / 60
        word_count = length_minutes * settings.behaviour_settings.transcription_words_per_minute
        token_count = word_count / settings.behaviour_settings.transcription_words_per_token
        return int(token_count) + settings.behaviour_settings.transcription_token_buffer

    def _get_decode_options(self, audio: np.ndarray) -> dict:
        """
        :param audio: The audio sample (16 kHz) to use to determine options like max decode length.
        :return: A dict that can be used to construct `whisper.DecodingOptions`.
        """
        options = {}
        if self.limit_output_length:
            options["sample_len"] = self._estimate_max_tokens(audio)
        return options

    @staticmethod
    def best_type():
        """Get the best type of SpeechRecognizer based on system capabilities."""
        if torch.mps.is_available():
            print("Choosing MLX Whisper model.")
            return MLXWhisperSpeechRecognizer()
        else:
            print("Choosing reference Whisper model.")
            return OpenAIWhisperSpeechRecognizer()


class MLXWhisperSpeechRecognizer(SpeechRecognizer):
    def __init__(self, limit_output_length=True):
        super().__init__(limit_output_length)
        self.was_loaded = False
        self.model_name = settings.speech_model_settings.mlx_model_name

    def load_model(self):
        if self.was_loaded:
            return
        # There appears to be no dedicated mechanism to preload a model, but this `get_model` does
        # store it in memory for later usage
        ModelHolder.get_model(self.model_name, mx.float16)
        self.was_loaded = True

    def recognize_speech(self, audio: np.ndarray) -> str:
        self.load_model()
        return mlx_whisper.transcribe(
            audio,
            path_or_hf_repo=self.model_name,
            **self._get_decode_options(audio),
        )["text"].strip()


class OpenAIWhisperSpeechRecognizer(SpeechRecognizer):
    def __init__(self, limit_output_length=True):
        super().__init__(limit_output_length)
        self.model = None

    def load_model(self):
        if self.model is not None:
            return
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.model = whisper.load_model(
            settings.speech_model_settings.openai_model_name, device=device
        )

    def recognize_speech(self, audio: np.ndarray) -> str:
        self.load_model()
        return whisper.transcribe(self.model, audio, **self._get_decode_options(audio))["text"]
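For reference, the sizing heuristic in `_estimate_max_tokens` can be checked in isolation. The sketch below is not part of the diff; its constants stand in for the project settings (sample rate, words per minute, words per token, token buffer) and are assumptions, not the configured values.

# Standalone sketch of the token-cap heuristic; constant values are assumed, not the real settings.
import numpy as np

SAMPLE_RATE_HZ = 16_000   # assumed VAD sample rate
WORDS_PER_MINUTE = 450    # assumed ceiling: 3x an average speaking rate
WORDS_PER_TOKEN = 0.75    # assumed: 3 words correspond to 4 tokens
TOKEN_BUFFER = 10         # assumed safety margin

def estimate_max_tokens(audio: np.ndarray) -> int:
    length_minutes = len(audio) / SAMPLE_RATE_HZ / 60
    word_count = length_minutes * WORDS_PER_MINUTE
    return int(word_count / WORDS_PER_TOKEN) + TOKEN_BUFFER

five_seconds = np.zeros(5 * SAMPLE_RATE_HZ, dtype=np.float32)
print(estimate_max_tokens(five_seconds))  # 5 s -> 37.5 words -> 50 tokens + 10 buffer = 60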
@@ -0,0 +1,87 @@
import asyncio

import numpy as np
import zmq
import zmq.asyncio as azmq
from spade.behaviour import CyclicBehaviour
from spade.message import Message

from control_backend.agents import BaseAgent
from control_backend.core.config import settings

from .speech_recognizer import SpeechRecognizer


class TranscriptionAgent(BaseAgent):
    """
    An agent which listens to audio fragments with voice, transcribes them, and sends the
    transcription to other agents.
    """

    def __init__(self, audio_in_address: str):
        jid = settings.agent_settings.transcription_name + "@" + settings.agent_settings.host
        super().__init__(jid, settings.agent_settings.transcription_name)

        self.audio_in_address = audio_in_address
        self.audio_in_socket: azmq.Socket | None = None

    class TranscribingBehaviour(CyclicBehaviour):
        def __init__(self, audio_in_socket: azmq.Socket):
            super().__init__()
            max_concurrent_tasks = settings.behaviour_settings.transcription_max_concurrent_tasks
            self.audio_in_socket = audio_in_socket
            self.speech_recognizer = SpeechRecognizer.best_type()
            self._concurrency = asyncio.Semaphore(max_concurrent_tasks)

        def warmup(self):
            """Load the transcription model into memory to speed up the first transcription."""
            self.speech_recognizer.load_model()

        async def _transcribe(self, audio: np.ndarray) -> str:
            async with self._concurrency:
                return await asyncio.to_thread(self.speech_recognizer.recognize_speech, audio)

        async def _share_transcription(self, transcription: str):
            """Share a transcription to the other agents that depend on it."""
            receiver_jids = [
                settings.agent_settings.text_belief_extractor_name
                + "@"
                + settings.agent_settings.host,
            ]  # Set message receivers here

            for receiver_jid in receiver_jids:
                message = Message(to=receiver_jid, body=transcription)
                await self.send(message)

        async def run(self) -> None:
            audio = await self.audio_in_socket.recv()
            audio = np.frombuffer(audio, dtype=np.float32)
            speech = await self._transcribe(audio)
            if not speech:
                self.agent.logger.info("Nothing transcribed.")
                return

            self.agent.logger.info("Transcribed speech: %s", speech)

            await self._share_transcription(speech)

    async def stop(self):
        self.audio_in_socket.close()
        self.audio_in_socket = None
        return await super().stop()

    def _connect_audio_in_socket(self):
        self.audio_in_socket = azmq.Context.instance().socket(zmq.SUB)
        self.audio_in_socket.setsockopt_string(zmq.SUBSCRIBE, "")
        self.audio_in_socket.connect(self.audio_in_address)

    async def setup(self):
        self.logger.info("Setting up %s", self.jid)

        self._connect_audio_in_socket()

        transcribing = self.TranscribingBehaviour(self.audio_in_socket)
        transcribing.warmup()
        self.add_behaviour(transcribing)

        self.logger.info("Finished setting up %s", self.jid)
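For context, the agent above SUB-connects to `audio_in_address` and treats every ZeroMQ message as one complete utterance of 16 kHz mono float32 PCM. A minimal producer might look like the sketch below; the endpoint address is an assumption and must match whatever `audio_in_address` the agent is constructed with.

# Hedged sketch of an audio producer for TranscriptionAgent; the address is illustrative only.
import numpy as np
import zmq

AUDIO_OUT_ADDRESS = "tcp://127.0.0.1:5555"  # assumption: must equal the agent's audio_in_address

publisher = zmq.Context.instance().socket(zmq.PUB)
publisher.bind(AUDIO_OUT_ADDRESS)

# One second of silence as a stand-in utterance; the agent rebuilds it with np.frombuffer(..., np.float32).
utterance = np.zeros(16_000, dtype=np.float32)
publisher.send(utterance.tobytes())

Note that ZeroMQ PUB/SUB discards messages published before the subscriber has connected, so a real producer would only start sending once the agent's setup has run.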