fix: allow Whisper to generate more tokens based on audio length
Before, it sometimes cut off the transcription too early. ref: N25B-209
This commit is contained in:
@@ -36,16 +36,16 @@ class SpeechRecognizer(abc.ABC):
|
||||
def _estimate_max_tokens(audio: np.ndarray) -> int:
|
||||
"""
|
||||
Estimate the maximum length of a given audio sample in tokens. Assumes a maximum speaking
|
||||
rate of 300 words per minute (2x average), and assumes that 3 words is 4 tokens.
|
||||
rate of 450 words per minute (3x average), and assumes that 3 words is 4 tokens.
|
||||
|
||||
:param audio: The audio sample (16 kHz) to use for length estimation.
|
||||
:return: The estimated length of the transcribed audio in tokens.
|
||||
"""
|
||||
length_seconds = len(audio) / 16_000
|
||||
length_minutes = length_seconds / 60
|
||||
word_count = length_minutes * 300
|
||||
word_count = length_minutes * 450
|
||||
token_count = word_count / 3 * 4
|
||||
return int(token_count)
|
||||
return int(token_count) + 10
|
||||
|
||||
def _get_decode_options(self, audio: np.ndarray) -> dict:
|
||||
"""
|
||||
@@ -84,7 +84,12 @@ class MLXWhisperSpeechRecognizer(SpeechRecognizer):
|
||||
|
||||
def recognize_speech(self, audio: np.ndarray) -> str:
|
||||
self.load_model()
|
||||
return mlx_whisper.transcribe(audio, path_or_hf_repo=self.model_name)["text"]
|
||||
return mlx_whisper.transcribe(
|
||||
audio,
|
||||
path_or_hf_repo=self.model_name,
|
||||
initial_prompt="You're a robot called Pepper, talking with a person called Twirre.",
|
||||
**self._get_decode_options(audio),
|
||||
)["text"].strip()
|
||||
|
||||
|
||||
class OpenAIWhisperSpeechRecognizer(SpeechRecognizer):
|
||||
@@ -101,5 +106,7 @@ class OpenAIWhisperSpeechRecognizer(SpeechRecognizer):
|
||||
def recognize_speech(self, audio: np.ndarray) -> str:
|
||||
self.load_model()
|
||||
return whisper.transcribe(
|
||||
self.model, audio, decode_options=self._get_decode_options(audio)
|
||||
self.model,
|
||||
audio,
|
||||
**self._get_decode_options(audio)
|
||||
)["text"]
|
||||
|
||||
@@ -59,6 +59,10 @@ class TranscriptionAgent(Agent):
|
||||
audio = await self.audio_in_socket.recv()
|
||||
audio = np.frombuffer(audio, dtype=np.float32)
|
||||
speech = await self._transcribe(audio)
|
||||
if not speech:
|
||||
logger.info("Nothing transcribed.")
|
||||
return
|
||||
|
||||
logger.info("Transcribed speech: %s", speech)
|
||||
|
||||
await self._share_transcription(speech)
|
||||
|
||||
Reference in New Issue
Block a user