fix: allow Whisper to generate more tokens based on audio length
Previously, the token limit was estimated too low, so transcriptions were sometimes cut off too early. ref: N25B-209
@@ -36,16 +36,16 @@ class SpeechRecognizer(abc.ABC):
     def _estimate_max_tokens(audio: np.ndarray) -> int:
         """
         Estimate the maximum length of a given audio sample in tokens. Assumes a maximum speaking
-        rate of 300 words per minute (2x average), and assumes that 3 words is 4 tokens.
+        rate of 450 words per minute (3x average), and assumes that 3 words is 4 tokens.
 
         :param audio: The audio sample (16 kHz) to use for length estimation.
         :return: The estimated length of the transcribed audio in tokens.
         """
         length_seconds = len(audio) / 16_000
         length_minutes = length_seconds / 60
-        word_count = length_minutes * 300
+        word_count = length_minutes * 450
         token_count = word_count / 3 * 4
-        return int(token_count)
+        return int(token_count) + 10
 
     def _get_decode_options(self, audio: np.ndarray) -> dict:
         """
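To make the effect of the new constants concrete, here is a small standalone check (illustrative only, not part of the commit); the 30-second silent buffer is made up for the example:

import numpy as np

# 30 seconds of silence at 16 kHz, purely as a stand-in for a real recording.
audio = np.zeros(30 * 16_000, dtype=np.float32)

length_minutes = len(audio) / 16_000 / 60   # 0.5 minutes
word_count = length_minutes * 450           # 225 words at the assumed 450 wpm
token_count = int(word_count / 3 * 4) + 10  # 300 tokens plus a margin of 10

print(token_count)  # 310 (the old 300 wpm formula, without the margin, gave 200)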
@@ -84,7 +84,12 @@ class MLXWhisperSpeechRecognizer(SpeechRecognizer):
 
     def recognize_speech(self, audio: np.ndarray) -> str:
         self.load_model()
-        return mlx_whisper.transcribe(audio, path_or_hf_repo=self.model_name)["text"]
+        return mlx_whisper.transcribe(
+            audio,
+            path_or_hf_repo=self.model_name,
+            initial_prompt="You're a robot called Pepper, talking with a person called Twirre.",
+            **self._get_decode_options(audio),
+        )["text"].strip()
 
 
 class OpenAIWhisperSpeechRecognizer(SpeechRecognizer):
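The body of _get_decode_options is not part of this diff, so how the estimate reaches the decoder is not visible here. A minimal sketch of one plausible implementation, assuming the options map onto Whisper's sample_len decoding option (the maximum number of tokens to sample); the field name is an assumption, not taken from the commit:

    def _get_decode_options(self, audio: np.ndarray) -> dict:
        # Assumed mapping: cap the decoder at the estimated token budget.
        # OpenAI Whisper's DecodingOptions calls this field `sample_len`;
        # whether this project uses that exact key is not shown in the diff.
        return {"sample_len": self._estimate_max_tokens(audio)}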
@@ -101,5 +106,7 @@ class OpenAIWhisperSpeechRecognizer(SpeechRecognizer):
     def recognize_speech(self, audio: np.ndarray) -> str:
         self.load_model()
         return whisper.transcribe(
-            self.model, audio, decode_options=self._get_decode_options(audio)
+            self.model,
+            audio,
+            **self._get_decode_options(audio)
         )["text"]
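For context on the keyword change above: whisper.transcribe forwards extra keyword arguments to whisper.DecodingOptions, which is why the options dict is now unpacked with ** rather than passed as a single decode_options= argument. A self-contained example of the resulting call shape (the model name, silent audio, and sample_len value are placeholders):

import numpy as np
import whisper

model = whisper.load_model("base")          # placeholder model
audio = np.zeros(16_000, dtype=np.float32)  # one second of silence

# Extra keyword arguments are forwarded into whisper.DecodingOptions.
result = whisper.transcribe(model, audio, sample_len=310)
print(result["text"])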
@@ -59,6 +59,10 @@ class TranscriptionAgent(Agent):
         audio = await self.audio_in_socket.recv()
         audio = np.frombuffer(audio, dtype=np.float32)
         speech = await self._transcribe(audio)
+        if not speech:
+            logger.info("Nothing transcribed.")
+            return
+
         logger.info("Transcribed speech: %s", speech)
 
         await self._share_transcription(speech)