From 5c228df1094454443edd8dd79b192c9b4f89edc6 Mon Sep 17 00:00:00 2001
From: Twirre Meulenbelt <43213592+TwirreM@users.noreply.github.com>
Date: Wed, 5 Nov 2025 10:41:11 +0100
Subject: [PATCH] fix: allow Whisper to generate more tokens based on audio
 length

Previously, the token budget was estimated from a speaking rate of 300
words per minute, which was too tight for fast speakers and sometimes
cut the transcription off too early. Raise the assumed maximum rate to
450 words per minute and add a small fixed margin so very short clips
still get a usable budget. Also pass the decode options through to the
MLX backend (it previously ignored them) and skip publishing empty
transcriptions.

ref: N25B-209
---
 .../agents/transcription/speech_recognizer.py   | 17 ++++++++++++-----
 .../agents/transcription/transcription_agent.py |  4 ++++
 2 files changed, 16 insertions(+), 5 deletions(-)

diff --git a/src/control_backend/agents/transcription/speech_recognizer.py b/src/control_backend/agents/transcription/speech_recognizer.py
index 45e42bf..9e61fd7 100644
--- a/src/control_backend/agents/transcription/speech_recognizer.py
+++ b/src/control_backend/agents/transcription/speech_recognizer.py
@@ -36,16 +36,16 @@ class SpeechRecognizer(abc.ABC):
     def _estimate_max_tokens(audio: np.ndarray) -> int:
         """
         Estimate the maximum length of a given audio sample in tokens. Assumes a maximum speaking
-        rate of 300 words per minute (2x average), and assumes that 3 words is 4 tokens.
+        rate of 450 words per minute (3x average), and assumes that 3 words correspond to 4 tokens.
 
         :param audio: The audio sample (16 kHz) to use for length estimation.
         :return: The estimated length of the transcribed audio in tokens.
         """
         length_seconds = len(audio) / 16_000
         length_minutes = length_seconds / 60
-        word_count = length_minutes * 300
+        word_count = length_minutes * 450
         token_count = word_count / 3 * 4
-        return int(token_count)
+        return int(token_count) + 10  # fixed margin so very short clips still get a budget
 
     def _get_decode_options(self, audio: np.ndarray) -> dict:
         """
@@ -84,7 +84,12 @@ class MLXWhisperSpeechRecognizer(SpeechRecognizer):
 
     def recognize_speech(self, audio: np.ndarray) -> str:
         self.load_model()
-        return mlx_whisper.transcribe(audio, path_or_hf_repo=self.model_name)["text"]
+        return mlx_whisper.transcribe(
+            audio,
+            path_or_hf_repo=self.model_name,
+            initial_prompt="You're a robot called Pepper, talking with a person called Twirre.",
+            **self._get_decode_options(audio),
+        )["text"].strip()
 
 
 class OpenAIWhisperSpeechRecognizer(SpeechRecognizer):
@@ -101,5 +106,7 @@ class OpenAIWhisperSpeechRecognizer(SpeechRecognizer):
     def recognize_speech(self, audio: np.ndarray) -> str:
         self.load_model()
         return whisper.transcribe(
-            self.model, audio, decode_options=self._get_decode_options(audio)
+            self.model,
+            audio,
+            **self._get_decode_options(audio)
         )["text"]
diff --git a/src/control_backend/agents/transcription/transcription_agent.py b/src/control_backend/agents/transcription/transcription_agent.py
index 2d936c4..196fd28 100644
--- a/src/control_backend/agents/transcription/transcription_agent.py
+++ b/src/control_backend/agents/transcription/transcription_agent.py
@@ -59,6 +59,10 @@ class TranscriptionAgent(Agent):
         audio = await self.audio_in_socket.recv()
         audio = np.frombuffer(audio, dtype=np.float32)
         speech = await self._transcribe(audio)
+        if not speech:
+            logger.info("Nothing transcribed.")
+            return
+
         logger.info("Transcribed speech: %s", speech)
         await self._share_transcription(speech)
 
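
Sanity check on the new budget math: for a 10-second clip at 16 kHz the
estimate works out to 450 / 60 * 10 = 75 words, then 75 / 3 * 4 = 100
tokens, plus the 10-token margin = 110 tokens. Below is a minimal
standalone sketch of the patched helper (assuming the same 16 kHz
float32 input the agent feeds it; the names here are illustrative, not
part of the patch):

    import numpy as np

    def estimate_max_tokens(audio: np.ndarray) -> int:
        length_seconds = len(audio) / 16_000   # audio is sampled at 16 kHz
        length_minutes = length_seconds / 60
        word_count = length_minutes * 450      # assumed ceiling: 3x average speaking rate
        token_count = word_count / 3 * 4       # ~4 tokens per 3 words
        return int(token_count) + 10           # fixed margin for very short clips

    ten_seconds = np.zeros(10 * 16_000, dtype=np.float32)
    print(estimate_max_tokens(ten_seconds))    # -> 110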