fix: allow Whisper to generate more tokens based on audio length

Before, it sometimes cut off the transcription too early.

ref: N25B-209
This commit is contained in:
Twirre Meulenbelt
2025-11-05 10:41:11 +01:00
parent 8e4d8f9d1e
commit 5c228df109
2 changed files with 16 additions and 5 deletions

View File

@@ -36,16 +36,16 @@ class SpeechRecognizer(abc.ABC):
def _estimate_max_tokens(audio: np.ndarray) -> int:
"""
Estimate the maximum length of a given audio sample in tokens. Assumes a maximum speaking
rate of 300 words per minute (2x average), and assumes that 3 words is 4 tokens.
rate of 450 words per minute (3x average), and assumes that 3 words is 4 tokens.
:param audio: The audio sample (16 kHz) to use for length estimation.
:return: The estimated length of the transcribed audio in tokens.
"""
length_seconds = len(audio) / 16_000
length_minutes = length_seconds / 60
word_count = length_minutes * 300
word_count = length_minutes * 450
token_count = word_count / 3 * 4
return int(token_count)
return int(token_count) + 10
def _get_decode_options(self, audio: np.ndarray) -> dict:
"""
@@ -84,7 +84,12 @@ class MLXWhisperSpeechRecognizer(SpeechRecognizer):
def recognize_speech(self, audio: np.ndarray) -> str:
self.load_model()
return mlx_whisper.transcribe(audio, path_or_hf_repo=self.model_name)["text"]
return mlx_whisper.transcribe(
audio,
path_or_hf_repo=self.model_name,
initial_prompt="You're a robot called Pepper, talking with a person called Twirre.",
**self._get_decode_options(audio),
)["text"].strip()
class OpenAIWhisperSpeechRecognizer(SpeechRecognizer):
@@ -101,5 +106,7 @@ class OpenAIWhisperSpeechRecognizer(SpeechRecognizer):
def recognize_speech(self, audio: np.ndarray) -> str:
self.load_model()
return whisper.transcribe(
self.model, audio, decode_options=self._get_decode_options(audio)
self.model,
audio,
**self._get_decode_options(audio)
)["text"]

View File

@@ -59,6 +59,10 @@ class TranscriptionAgent(Agent):
audio = await self.audio_in_socket.recv()
audio = np.frombuffer(audio, dtype=np.float32)
speech = await self._transcribe(audio)
if not speech:
logger.info("Nothing transcribed.")
return
logger.info("Transcribed speech: %s", speech)
await self._share_transcription(speech)