feat: implement transcriber agent
Uses speech fragments of the VAD agent, emits transcribed text over SPADE's default communication channel to no recipient for now. ref: N25B-209
This commit is contained in:
@@ -0,0 +1,62 @@
|
||||
import abc
|
||||
import sys
|
||||
|
||||
if sys.platform == "darwin":
|
||||
import mlx.core as mx
|
||||
import mlx_whisper
|
||||
from mlx_whisper.transcribe import ModelHolder
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import whisper
|
||||
|
||||
|
||||
class SpeechRecognizer(abc.ABC):
|
||||
@abc.abstractmethod
|
||||
def load_model(self): ...
|
||||
|
||||
@abc.abstractmethod
|
||||
def recognize_speech(self, audio: np.ndarray) -> str: ...
|
||||
|
||||
@staticmethod
|
||||
def best_type():
|
||||
if torch.mps.is_available():
|
||||
print("Choosing MLX Whisper model.")
|
||||
return MLXWhisperSpeechRecognizer()
|
||||
else:
|
||||
print("Choosing reference Whisper model.")
|
||||
return OpenAIWhisperSpeechRecognizer()
|
||||
|
||||
|
||||
class MLXWhisperSpeechRecognizer(SpeechRecognizer):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.model = None
|
||||
self.model_name = "mlx-community/whisper-small.en-mlx"
|
||||
|
||||
def load_model(self):
|
||||
if self.model is not None:
|
||||
return
|
||||
ModelHolder.get_model(
|
||||
self.model_name, mx.float16
|
||||
) # Should store it in memory for later usage
|
||||
|
||||
def recognize_speech(self, audio: np.ndarray) -> str:
|
||||
self.load_model()
|
||||
return mlx_whisper.transcribe(audio, path_or_hf_repo=self.model_name)["text"]
|
||||
|
||||
|
||||
class OpenAIWhisperSpeechRecognizer(SpeechRecognizer):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.model = None
|
||||
|
||||
def load_model(self):
|
||||
if self.model is not None:
|
||||
return
|
||||
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||
self.model = whisper.load_model("small.en", device=device)
|
||||
|
||||
def recognize_speech(self, audio: np.ndarray) -> str:
|
||||
self.load_model()
|
||||
return self.model.transcribe(audio)["text"]
|
||||
Reference in New Issue
Block a user