From bec3e57658a975e0ef14f4b95323628933526311 Mon Sep 17 00:00:00 2001
From: Twirre Meulenbelt <43213592+TwirreM@users.noreply.github.com>
Date: Wed, 29 Oct 2025 12:49:24 +0100
Subject: [PATCH] feat: limit transcription output length based on input

Uses heuristics based on the duration of the input audio. Also adds
documentation and initial unit tests.

ref: N25B-209
---
 .../agents/transcription/__init__.py          |  1 +
 .../agents/transcription/speech_recognizer.py | 71 +++++++++++++++----
 .../transcription/transcription_agent.py      | 16 +++-
 .../transcription/test_speech_recognizer.py   | 39 ++++++++++
 test/unit/conftest.py                         | 15 ++++
 5 files changed, 127 insertions(+), 15 deletions(-)
 create mode 100644 test/unit/agents/transcription/test_speech_recognizer.py

diff --git a/src/control_backend/agents/transcription/__init__.py b/src/control_backend/agents/transcription/__init__.py
index 3e87e70..fd3c8c5 100644
--- a/src/control_backend/agents/transcription/__init__.py
+++ b/src/control_backend/agents/transcription/__init__.py
@@ -1 +1,2 @@
+from .speech_recognizer import SpeechRecognizer as SpeechRecognizer
 from .transcription_agent import TranscriptionAgent as TranscriptionAgent
diff --git a/src/control_backend/agents/transcription/speech_recognizer.py b/src/control_backend/agents/transcription/speech_recognizer.py
index 58523a4..cf48fa7 100644
--- a/src/control_backend/agents/transcription/speech_recognizer.py
+++ b/src/control_backend/agents/transcription/speech_recognizer.py
@@ -12,14 +12,54 @@ import whisper
 
 
 class SpeechRecognizer(abc.ABC):
+    def __init__(self, limit_output_length: bool = True):
+        """
+        :param limit_output_length: When `True`, the length of the generated transcription is
+            limited based on the length of the input audio and some heuristics.
+        """
+        self.limit_output_length = limit_output_length
+
     @abc.abstractmethod
     def load_model(self): ...
 
     @abc.abstractmethod
-    def recognize_speech(self, audio: np.ndarray) -> str: ...
+    def recognize_speech(self, audio: np.ndarray) -> str:
+        """
+        Recognize speech from the given audio sample.
+
+        :param audio: A full utterance sample. Audio must be 16 kHz, mono, np.float32, with values
+            in the range [-1.0, 1.0].
+        :return: The recognized text.
+        """
+
+    @staticmethod
+    def _estimate_max_tokens(audio: np.ndarray) -> int:
+        """
+        Estimate the maximum length of a given audio sample in tokens. Assumes a maximum speaking
+        rate of 300 words per minute (twice the average) and that 3 words take 4 tokens.
+
+        :param audio: The audio sample (16 kHz) to use for length estimation.
+        :return: The estimated length of the transcribed audio in tokens.
+        """
+        length_seconds = len(audio) / 16_000
+        length_minutes = length_seconds / 60
+        word_count = length_minutes * 300
+        token_count = word_count / 3 * 4
+        return int(token_count)
+
+    def _get_decode_options(self, audio: np.ndarray) -> dict:
+        """
+        :param audio: The audio sample (16 kHz) used to determine options such as max decode length.
+        :return: A dict that can be used to construct `whisper.DecodingOptions`.
+        """
+        options = {}
+        if self.limit_output_length:
+            options["sample_len"] = self._estimate_max_tokens(audio)
+        return options
 
     @staticmethod
     def best_type():
+        """Create the best type of SpeechRecognizer based on system capabilities."""
         if torch.mps.is_available():
             print("Choosing MLX Whisper model.")
             return MLXWhisperSpeechRecognizer()
@@ -29,34 +69,39 @@ class SpeechRecognizer(abc.ABC):
 
 
 class MLXWhisperSpeechRecognizer(SpeechRecognizer):
-    def __init__(self):
-        super().__init__()
-        self.model = None
+    def __init__(self, limit_output_length: bool = True):
+        super().__init__(limit_output_length)
+        self.was_loaded = False
         self.model_name = "mlx-community/whisper-small.en-mlx"
 
     def load_model(self):
-        if self.model is not None:
-            return
-        ModelHolder.get_model(
-            self.model_name, mx.float16
-        )  # Should store it in memory for later usage
+        if self.was_loaded:
+            return
+        # There appears to be no dedicated mechanism to preload a model, but this `get_model` does
+        # store it in memory for later usage
+        ModelHolder.get_model(self.model_name, mx.float16)
+        self.was_loaded = True
 
     def recognize_speech(self, audio: np.ndarray) -> str:
         self.load_model()
-        return mlx_whisper.transcribe(audio, path_or_hf_repo=self.model_name)["text"]
+        return mlx_whisper.transcribe(audio,
+                                      path_or_hf_repo=self.model_name,
+                                      **self._get_decode_options(audio))["text"]
 
 
 class OpenAIWhisperSpeechRecognizer(SpeechRecognizer):
-    def __init__(self):
-        super().__init__()
+    def __init__(self, limit_output_length: bool = True):
+        super().__init__(limit_output_length)
         self.model = None
 
     def load_model(self):
         if self.model is not None:
             return
         device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
         self.model = whisper.load_model("small.en", device=device)
 
     def recognize_speech(self, audio: np.ndarray) -> str:
         self.load_model()
-        return self.model.transcribe(audio)["text"]
+        return whisper.transcribe(self.model,
+                                  audio,
+                                  **self._get_decode_options(audio))["text"]
diff --git a/src/control_backend/agents/transcription/transcription_agent.py b/src/control_backend/agents/transcription/transcription_agent.py
index b572f5e..dd18639 100644
--- a/src/control_backend/agents/transcription/transcription_agent.py
+++ b/src/control_backend/agents/transcription/transcription_agent.py
@@ -35,18 +35,29 @@ class TranscriptionAgent(Agent):
             self.speech_recognizer = SpeechRecognizer.best_type()
             self._concurrency = asyncio.Semaphore(3)
 
+        def warmup(self):
+            """Load the transcription model into memory to speed up the first transcription."""
+            self.speech_recognizer.load_model()
+
         async def _transcribe(self, audio: np.ndarray) -> str:
             async with self._concurrency:
                 return await asyncio.to_thread(self.speech_recognizer.recognize_speech, audio)
 
+        async def _share_transcription(self, transcription: str):
+            """Share a transcription with the other agents that depend on it."""
+            receiver_jids = []  # Set message receivers here
+
+            for receiver_jid in receiver_jids:
+                message = Message(to=receiver_jid, body=transcription)
+                await self.send(message)
+
         async def run(self) -> None:
             audio = await self.audio_in_socket.recv()
             audio = np.frombuffer(audio, dtype=np.float32)
             speech = await self._transcribe(audio)
             logger.info("Transcribed speech: %s", speech)
 
-            message = Message(body=speech)
-            await self.send(message)
+            await self._share_transcription(speech)
 
         async def stop(self):
             self.audio_in_socket.close()
@@ -64,6 +75,7 @@
         self._connect_audio_in_socket()
 
         transcribing = self.Transcribing(self.audio_in_socket)
+        transcribing.warmup()
         self.add_behaviour(transcribing)
 
         logger.info("Finished setting up %s", self.jid)
diff --git a/test/unit/agents/transcription/test_speech_recognizer.py b/test/unit/agents/transcription/test_speech_recognizer.py
new file mode 100644
index 0000000..6e7cde0
--- /dev/null
+++ b/test/unit/agents/transcription/test_speech_recognizer.py
@@ -0,0 +1,39 @@
+import numpy as np
+
+from control_backend.agents.transcription import SpeechRecognizer
+from control_backend.agents.transcription.speech_recognizer import OpenAIWhisperSpeechRecognizer
+
+
+def test_estimate_max_tokens():
+    """Inputting one minute of audio, assuming 300 words per minute, expecting 400 tokens."""
+    audio = np.empty(shape=(60 * 16_000), dtype=np.float32)
+
+    actual = SpeechRecognizer._estimate_max_tokens(audio)
+
+    assert actual == 400
+    assert isinstance(actual, int)
+
+
+def test_get_decode_options():
+    """Check whether the right decode options are given under different scenarios."""
+    audio = np.empty(shape=(60 * 16_000), dtype=np.float32)
+
+    # With the defaults, it should limit output length based on input size
+    recognizer = OpenAIWhisperSpeechRecognizer()
+    options = recognizer._get_decode_options(audio)
+
+    assert "sample_len" in options
+    assert isinstance(options["sample_len"], int)
+
+    # When explicitly enabled, it should limit output length based on input size
+    recognizer = OpenAIWhisperSpeechRecognizer(limit_output_length=True)
+    options = recognizer._get_decode_options(audio)
+
+    assert "sample_len" in options
+    assert isinstance(options["sample_len"], int)
+
+    # When disabled, it should not limit output length based on input size
+    recognizer = OpenAIWhisperSpeechRecognizer(limit_output_length=False)
+    options = recognizer._get_decode_options(audio)
+
+    assert "sample_len" not in options
diff --git a/test/unit/conftest.py b/test/unit/conftest.py
index 76ef272..ecf00c1 100644
--- a/test/unit/conftest.py
+++ b/test/unit/conftest.py
@@ -11,6 +11,7 @@ def pytest_configure(config):
     mock_spade = MagicMock()
     mock_spade.agent = MagicMock()
     mock_spade.behaviour = MagicMock()
+    mock_spade.message = MagicMock()
 
     mock_spade_bdi = MagicMock()
     mock_spade_bdi.bdi = MagicMock()
@@ -21,6 +22,7 @@ def pytest_configure(config):
     sys.modules["spade"] = mock_spade
     sys.modules["spade.agent"] = mock_spade.agent
     sys.modules["spade.behaviour"] = mock_spade.behaviour
+    sys.modules["spade.message"] = mock_spade.message
 
     sys.modules["spade_bdi"] = mock_spade_bdi
     sys.modules["spade_bdi.bdi"] = mock_spade_bdi.bdi
@@ -43,3 +45,16 @@ def pytest_configure(config):
     sys.modules["torch"] = mock_torch
     sys.modules["zmq"] = mock_zmq
    sys.modules["zmq.asyncio"] = mock_zmq.asyncio
+
+    # --- Mock whisper ---
+    mock_whisper = MagicMock()
+    mock_mlx = MagicMock()
+    mock_mlx.core = MagicMock()
+    mock_mlx_whisper = MagicMock()
+    mock_mlx_whisper.transcribe = MagicMock()
+
+    sys.modules["whisper"] = mock_whisper
+    sys.modules["mlx"] = mock_mlx
+    sys.modules["mlx.core"] = mock_mlx.core
+    sys.modules["mlx_whisper"] = mock_mlx_whisper
+    sys.modules["mlx_whisper.transcribe"] = mock_mlx_whisper.transcribe
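
Usage sketch (reviewer note, not part of the patch): a minimal illustration of how the
new `limit_output_length` flag behaves, assuming the package and its dependencies are
importable, on a hypothetical 10-second, 16 kHz mono float32 buffer.

    import numpy as np

    from control_backend.agents.transcription import SpeechRecognizer
    from control_backend.agents.transcription.speech_recognizer import (
        OpenAIWhisperSpeechRecognizer,
    )

    # 10 s of 16 kHz audio: 300 wpm cap -> 50 words -> 50 / 3 * 4 ~= 66 tokens
    audio = np.zeros(10 * 16_000, dtype=np.float32)
    assert SpeechRecognizer._estimate_max_tokens(audio) == 66

    # By default, decoding is capped via Whisper's `sample_len` option
    recognizer = OpenAIWhisperSpeechRecognizer()
    assert recognizer._get_decode_options(audio) == {"sample_len": 66}

    # Opting out produces no decoding constraints
    recognizer = OpenAIWhisperSpeechRecognizer(limit_output_length=False)
    assert recognizer._get_decode_options(audio) == {}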