feat: limit transcription output length based on input
The output length limit is derived from the input using heuristics. Also adds documentation and initial unit tests. ref: N25B-209
This commit is contained in:
@@ -35,18 +35,29 @@ class TranscriptionAgent(Agent):
|
||||
self.speech_recognizer = SpeechRecognizer.best_type()
|
||||
self._concurrency = asyncio.Semaphore(3)
|
||||
|
||||
def warmup(self):
    """Preload the speech recognition model.

    Calls ``load_model()`` on the configured recognizer ahead of time so
    that the first transcription is faster.
    """
    recognizer = self.speech_recognizer
    recognizer.load_model()
|
||||
|
||||
async def _transcribe(self, audio: np.ndarray) -> str:
    """Transcribe ``audio`` to text without blocking the event loop.

    The recognizer's ``recognize_speech`` runs in a separate thread via
    ``asyncio.to_thread``. ``self._concurrency`` is held for the duration of
    the call, which bounds how many transcriptions run at the same time.
    """
    async with self._concurrency:
        recognize = self.speech_recognizer.recognize_speech
        transcription = await asyncio.to_thread(recognize, audio)
    return transcription
|
||||
|
||||
async def _share_transcription(self, transcription: str):
    """Send ``transcription`` to every agent that should receive it.

    One ``Message`` is built and sent per receiver JID. The receiver list
    is currently empty, so nothing is sent until receivers are filled in.
    """
    recipients = []  # Set message receivers here

    for jid in recipients:
        await self.send(Message(to=jid, body=transcription))
|
||||
|
||||
async def run(self) -> None:
    """Receive one audio buffer, transcribe it, and share the result.

    The bytes read from ``self.audio_in_socket`` are viewed as ``float32``
    samples, passed to ``_transcribe``, logged, and then handed to
    ``_share_transcription``.

    Raises:
        ValueError: propagated from ``np.frombuffer`` when the received
            buffer length is not a multiple of the ``float32`` item size.
    """
    raw_audio = await self.audio_in_socket.recv()
    audio = np.frombuffer(raw_audio, dtype=np.float32)
    speech = await self._transcribe(audio)
    logger.info("Transcribed speech: %s", speech)

    # Delivery goes only through _share_transcription, which addresses one
    # message per receiver. The earlier direct send of a Message without a
    # recipient is dropped so the transcription is not sent twice.
    await self._share_transcription(speech)
|
||||
|
||||
async def stop(self):
|
||||
self.audio_in_socket.close()
|
||||
@@ -64,6 +75,7 @@ class TranscriptionAgent(Agent):
|
||||
self._connect_audio_in_socket()
|
||||
|
||||
transcribing = self.Transcribing(self.audio_in_socket)
|
||||
transcribing.warmup()
|
||||
self.add_behaviour(transcribing)
|
||||
|
||||
logger.info("Finished setting up %s", self.jid)
|
||||
|
||||
Reference in New Issue
Block a user