feat: limit transcription output length based on input

Using heuristics. Also adds documentation and initial unit tests.

ref: N25B-209
This commit is contained in:
Twirre Meulenbelt
2025-10-29 12:49:24 +01:00
parent 4d6bac7e2b
commit bec3e57658
5 changed files with 124 additions and 17 deletions

View File

@@ -0,0 +1,36 @@
import numpy as np
from control_backend.agents.transcription import SpeechRecognizer
from control_backend.agents.transcription.speech_recognizer import OpenAIWhisperSpeechRecognizer
def test_estimate_max_tokens():
"""Inputting one minute of audio, assuming 300 words per minute, expecting 400 tokens."""
audio = np.empty(shape=(60*16_000), dtype=np.float32)
actual = SpeechRecognizer._estimate_max_tokens(audio)
assert actual == 400
assert isinstance(actual, int)
def test_get_decode_options():
"""Check whether the right decode options are given under different scenarios."""
audio = np.empty(shape=(60*16_000), dtype=np.float32)
# With the defaults, it should limit output length based on input size
recognizer = OpenAIWhisperSpeechRecognizer()
options = recognizer._get_decode_options(audio)
assert "sample_len" in options
assert isinstance(options["sample_len"], int)
# When explicitly enabled, it should limit output length based on input size
recognizer = OpenAIWhisperSpeechRecognizer(limit_output_length=True)
options = recognizer._get_decode_options(audio)
assert "sample_len" in options
assert isinstance(options["sample_len"], int)
# When disabled, it should not limit output length based on input size
assert "sample_rate" not in options

View File

@@ -11,6 +11,7 @@ def pytest_configure(config):
mock_spade = MagicMock()
mock_spade.agent = MagicMock()
mock_spade.behaviour = MagicMock()
mock_spade.message = MagicMock()
mock_spade_bdi = MagicMock()
mock_spade_bdi.bdi = MagicMock()
@@ -21,6 +22,7 @@ def pytest_configure(config):
sys.modules["spade"] = mock_spade
sys.modules["spade.agent"] = mock_spade.agent
sys.modules["spade.behaviour"] = mock_spade.behaviour
sys.modules["spade.message"] = mock_spade.message
sys.modules["spade_bdi"] = mock_spade_bdi
sys.modules["spade_bdi.bdi"] = mock_spade_bdi.bdi
@@ -43,3 +45,16 @@ def pytest_configure(config):
sys.modules["torch"] = mock_torch
sys.modules["zmq"] = mock_zmq
sys.modules["zmq.asyncio"] = mock_zmq.asyncio
# --- Mock whisper ---
mock_whisper = MagicMock()
mock_mlx = MagicMock()
mock_mlx.core = MagicMock()
mock_mlx_whisper = MagicMock()
mock_mlx_whisper.transcribe = MagicMock()
sys.modules["whisper"] = mock_whisper
sys.modules["mlx"] = mock_mlx
sys.modules["mlx.core"] = mock_mlx
sys.modules["mlx_whisper"] = mock_mlx_whisper
sys.modules["mlx_whisper.transcribe"] = mock_mlx_whisper.transcribe