Create transcriber agent #15
@@ -1 +1,2 @@
+from .speech_recognizer import SpeechRecognizer as SpeechRecognizer
 from .transcription_agent import TranscriptionAgent as TranscriptionAgent
@@ -12,14 +12,54 @@ import whisper
 
 
 class SpeechRecognizer(abc.ABC):
+    def __init__(self, limit_output_length=True):
+        """
+        :param limit_output_length: When `True`, the length of the generated transcription will be limited
+            by the length of the input audio and some heuristics.
+        """
+        self.limit_output_length = limit_output_length
+
     @abc.abstractmethod
     def load_model(self): ...
 
     @abc.abstractmethod
-    def recognize_speech(self, audio: np.ndarray) -> str: ...
+    def recognize_speech(self, audio: np.ndarray) -> str:
+        """
+        Recognize speech from the given audio sample.
+
+        :param audio: A full utterance sample. Audio must be 16 kHz, mono, np.float32, values in the
+            range [-1.0, 1.0].
+        :return: Recognized speech.
+        """
+
+    @staticmethod
+    def _estimate_max_tokens(audio: np.ndarray) -> int:
+        """
+        Estimate the maximum length of a given audio sample in tokens. Assumes a maximum speaking
+        rate of 300 words per minute (2x the average) and that 3 words correspond to 4 tokens.
+
+        :param audio: The audio sample (16 kHz) to use for length estimation.
+        :return: The estimated length of the transcribed audio in tokens.
+        """
+        length_seconds = len(audio) / 16_000
+        length_minutes = length_seconds / 60
+        word_count = length_minutes * 300
+        token_count = word_count / 3 * 4
+        return int(token_count)
+
+    def _get_decode_options(self, audio: np.ndarray) -> dict:
+        """
+        :param audio: The audio sample (16 kHz) to use to determine options like max decode length.
+        :return: A dict that can be used to construct `whisper.DecodingOptions`.
+        """
+        options = {}
+        if self.limit_output_length:
+            options["sample_len"] = self._estimate_max_tokens(audio)
+        return options
+
     @staticmethod
     def best_type():
+        """Get the best type of SpeechRecognizer based on system capabilities."""
         if torch.mps.is_available():
             print("Choosing MLX Whisper model.")
             return MLXWhisperSpeechRecognizer()
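As a quick sanity check of the `_estimate_max_tokens` heuristic above (a worked example using the constants from its docstring; not part of the change): 30 seconds of 16 kHz audio is half a minute, so at most 150 words and therefore 200 tokens.

import numpy as np

audio = np.zeros(30 * 16_000, dtype=np.float32)  # stand-in for 30 s of 16 kHz mono audio

length_minutes = len(audio) / 16_000 / 60   # 0.5 minutes
word_count = length_minutes * 300           # 150 words at the assumed 300 wpm ceiling
token_count = int(word_count / 3 * 4)       # 4 tokens per 3 words -> 200 tokens
assert token_count == 200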
@@ -29,34 +69,37 @@ class SpeechRecognizer(abc.ABC):
 
 
 class MLXWhisperSpeechRecognizer(SpeechRecognizer):
-    def __init__(self):
-        super().__init__()
-        self.model = None
+    def __init__(self, limit_output_length=True):
+        super().__init__(limit_output_length)
+        self.was_loaded = False
         self.model_name = "mlx-community/whisper-small.en-mlx"
 
     def load_model(self):
-        if self.model is not None:
-            return
-        ModelHolder.get_model(
-            self.model_name, mx.float16
-        )  # Should store it in memory for later usage
+        if self.was_loaded: return
+        # There appears to be no dedicated mechanism to preload a model, but this `get_model` does
+        # store it in memory for later usage
+        ModelHolder.get_model(self.model_name, mx.float16)
+        self.was_loaded = True
 
     def recognize_speech(self, audio: np.ndarray) -> str:
         self.load_model()
-        return mlx_whisper.transcribe(audio, path_or_hf_repo=self.model_name)["text"]
+        return mlx_whisper.transcribe(audio,
+                                      path_or_hf_repo=self.model_name,
+                                      **self._get_decode_options(audio))["text"]
 
 
 class OpenAIWhisperSpeechRecognizer(SpeechRecognizer):
-    def __init__(self):
-        super().__init__()
+    def __init__(self, limit_output_length=True):
+        super().__init__(limit_output_length)
         self.model = None
 
     def load_model(self):
-        if self.model is not None:
-            return
+        if self.model is not None: return
         device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
         self.model = whisper.load_model("small.en", device=device)
 
     def recognize_speech(self, audio: np.ndarray) -> str:
         self.load_model()
-        return self.model.transcribe(audio)["text"]
+        return whisper.transcribe(self.model,
+                                  audio,
+                                  **self._get_decode_options(audio))["text"]
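Outside the agent, the recognizers can also be driven directly. A minimal sketch, assuming the package export added in the first hunk; the silent buffer is a placeholder for a real 16 kHz, mono, float32 utterance, and on non-Apple hardware `best_type()` presumably falls back to the PyTorch-based recognizer:

import numpy as np

from control_backend.agents.transcription import SpeechRecognizer

recognizer = SpeechRecognizer.best_type()  # MLX Whisper on Apple silicon
recognizer.load_model()                    # optional warm-up; recognize_speech() also loads lazily

utterance = np.zeros(2 * 16_000, dtype=np.float32)  # placeholder: 2 s of silence
print(recognizer.recognize_speech(utterance))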
@@ -35,18 +35,29 @@ class TranscriptionAgent(Agent):
             self.speech_recognizer = SpeechRecognizer.best_type()
             self._concurrency = asyncio.Semaphore(3)
 
+        def warmup(self):
+            """Load the transcription model into memory to speed up the first transcription."""
+            self.speech_recognizer.load_model()
+
         async def _transcribe(self, audio: np.ndarray) -> str:
             async with self._concurrency:
                 return await asyncio.to_thread(self.speech_recognizer.recognize_speech, audio)
 
+        async def _share_transcription(self, transcription: str):
+            """Share a transcription with the other agents that depend on it."""
+            receiver_jids = []  # Set message receivers here
+
+            for receiver_jid in receiver_jids:
+                message = Message(to=receiver_jid, body=transcription)
+                await self.send(message)
+
         async def run(self) -> None:
             audio = await self.audio_in_socket.recv()
             audio = np.frombuffer(audio, dtype=np.float32)
             speech = await self._transcribe(audio)
             logger.info("Transcribed speech: %s", speech)
 
-            message = Message(body=speech)
-            await self.send(message)
+            await self._share_transcription(speech)
 
         async def stop(self):
             self.audio_in_socket.close()
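The `_transcribe` coroutine above pairs a semaphore with `asyncio.to_thread` so that at most three blocking Whisper calls run at once without stalling the event loop. A self-contained sketch of that pattern, with a sleep standing in for the recognizer:

import asyncio
import time

_concurrency = asyncio.Semaphore(3)

def blocking_recognize(name: str) -> str:
    time.sleep(0.1)  # stand-in for a blocking Whisper transcription
    return f"transcript of {name}"

async def transcribe(name: str) -> str:
    async with _concurrency:  # at most three recognitions in flight
        return await asyncio.to_thread(blocking_recognize, name)

async def main():
    print(await asyncio.gather(*(transcribe(f"utterance {i}") for i in range(6))))

asyncio.run(main())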
@@ -64,6 +75,7 @@ class TranscriptionAgent(Agent):
         self._connect_audio_in_socket()
 
         transcribing = self.Transcribing(self.audio_in_socket)
+        transcribing.warmup()
         self.add_behaviour(transcribing)
 
         logger.info("Finished setting up %s", self.jid)
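The behaviour's `run()` (previous hunk) reconstructs the utterance with `np.frombuffer(audio, dtype=np.float32)`, so the socket payload is expected to be raw float32 samples. A round-trip sketch of that wire format with synthetic data, independent of the actual sockets:

import numpy as np

samples = np.linspace(-1.0, 1.0, num=16_000, dtype=np.float32)  # 1 s of synthetic 16 kHz audio
payload = samples.tobytes()                                     # bytes as they would travel over the audio socket
decoded = np.frombuffer(payload, dtype=np.float32)
assert np.array_equal(samples, decoded)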
test/unit/agents/transcription/test_speech_recognizer.py (new file, 39 lines)
@@ -0,0 +1,39 @@
+import numpy as np
+
+from control_backend.agents.transcription import SpeechRecognizer
+from control_backend.agents.transcription.speech_recognizer import OpenAIWhisperSpeechRecognizer
+
+
+def test_estimate_max_tokens():
+    """Inputting one minute of audio, assuming 300 words per minute, expecting 400 tokens."""
+    audio = np.empty(shape=(60*16_000), dtype=np.float32)
+
+    actual = SpeechRecognizer._estimate_max_tokens(audio)
+
+    assert actual == 400
+    assert isinstance(actual, int)
+
+
+def test_get_decode_options():
+    """Check whether the right decode options are given under different scenarios."""
+    audio = np.empty(shape=(60*16_000), dtype=np.float32)
+
+    # With the defaults, it should limit output length based on input size
+    recognizer = OpenAIWhisperSpeechRecognizer()
+    options = recognizer._get_decode_options(audio)
+
+    assert "sample_len" in options
+    assert isinstance(options["sample_len"], int)
+
+    # When explicitly enabled, it should limit output length based on input size
+    recognizer = OpenAIWhisperSpeechRecognizer(limit_output_length=True)
+    options = recognizer._get_decode_options(audio)
+
+    assert "sample_len" in options
+    assert isinstance(options["sample_len"], int)
+
+    # When disabled, it should not limit output length based on input size
+    recognizer = OpenAIWhisperSpeechRecognizer(limit_output_length=False)
+    options = recognizer._get_decode_options(audio)
+
+    assert "sample_len" not in options
@@ -11,6 +11,7 @@ def pytest_configure(config):
     mock_spade = MagicMock()
     mock_spade.agent = MagicMock()
     mock_spade.behaviour = MagicMock()
+    mock_spade.message = MagicMock()
     mock_spade_bdi = MagicMock()
     mock_spade_bdi.bdi = MagicMock()
 
@@ -21,6 +22,7 @@ def pytest_configure(config):
     sys.modules["spade"] = mock_spade
     sys.modules["spade.agent"] = mock_spade.agent
     sys.modules["spade.behaviour"] = mock_spade.behaviour
+    sys.modules["spade.message"] = mock_spade.message
     sys.modules["spade_bdi"] = mock_spade_bdi
     sys.modules["spade_bdi.bdi"] = mock_spade_bdi.bdi
 
@@ -43,3 +45,16 @@ def pytest_configure(config):
     sys.modules["torch"] = mock_torch
     sys.modules["zmq"] = mock_zmq
     sys.modules["zmq.asyncio"] = mock_zmq.asyncio
+
+    # --- Mock whisper ---
+    mock_whisper = MagicMock()
+    mock_mlx = MagicMock()
+    mock_mlx.core = MagicMock()
+    mock_mlx_whisper = MagicMock()
+    mock_mlx_whisper.transcribe = MagicMock()
+
+    sys.modules["whisper"] = mock_whisper
+    sys.modules["mlx"] = mock_mlx
+    sys.modules["mlx.core"] = mock_mlx.core
+    sys.modules["mlx_whisper"] = mock_mlx_whisper
+    sys.modules["mlx_whisper.transcribe"] = mock_mlx_whisper.transcribe
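These additions follow the pattern already used for spade, torch, and zmq: heavy optional dependencies are replaced in `sys.modules` before the code under test imports them, so the new recognizer tests run without Whisper or MLX installed. A simplified standalone illustration of the mechanism (not the project's conftest):

import sys
from unittest.mock import MagicMock

for name in ("torch", "whisper", "mlx", "mlx.core", "mlx_whisper"):
    sys.modules[name] = MagicMock()

import whisper  # resolves to the MagicMock above; no real model code is imported

print(type(whisper))  # <class 'unittest.mock.MagicMock'>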