Merge remote-tracking branch 'origin/dev' into demo
# Conflicts: # src/control_backend/agents/bdi/behaviours/receive_llm_resp_behaviour.py # src/control_backend/agents/llm/llm.py # src/control_backend/agents/ri_command_agent.py # src/control_backend/agents/transcription/speech_recognizer.py
This commit is contained in:
@@ -75,7 +75,8 @@ class MLXWhisperSpeechRecognizer(SpeechRecognizer):
|
||||
self.model_name = "mlx-community/whisper-small.en-mlx"
|
||||
|
||||
def load_model(self):
|
||||
if self.was_loaded: return
|
||||
if self.was_loaded:
|
||||
return
|
||||
# There appears to be no dedicated mechanism to preload a model, but this `get_model` does
|
||||
# store it in memory for later usage
|
||||
ModelHolder.get_model(self.model_name, mx.float16)
|
||||
@@ -83,7 +84,7 @@ class MLXWhisperSpeechRecognizer(SpeechRecognizer):
|
||||
|
||||
def recognize_speech(self, audio: np.ndarray) -> str:
|
||||
self.load_model()
|
||||
return mlx_whisper.transcribe(audio, path_or_hf_repo=self.model_name)["text"].strip()
|
||||
return mlx_whisper.transcribe(audio, path_or_hf_repo=self.model_name)["text"]
|
||||
|
||||
|
||||
class OpenAIWhisperSpeechRecognizer(SpeechRecognizer):
|
||||
@@ -92,12 +93,13 @@ class OpenAIWhisperSpeechRecognizer(SpeechRecognizer):
|
||||
self.model = None
|
||||
|
||||
def load_model(self):
|
||||
if self.model is not None: return
|
||||
if self.model is not None:
|
||||
return
|
||||
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||
self.model = whisper.load_model("small.en", device=device)
|
||||
|
||||
def recognize_speech(self, audio: np.ndarray) -> str:
|
||||
self.load_model()
|
||||
return whisper.transcribe(self.model,
|
||||
audio,
|
||||
decode_options=self._get_decode_options(audio))["text"]
|
||||
return whisper.transcribe(
|
||||
self.model, audio, decode_options=self._get_decode_options(audio)
|
||||
)["text"]
|
||||
|
||||
@@ -47,7 +47,8 @@ class TranscriptionAgent(Agent):
|
||||
"""Share a transcription to the other agents that depend on it."""
|
||||
receiver_jids = [
|
||||
settings.agent_settings.text_belief_extractor_agent_name
|
||||
+ '@' + settings.agent_settings.host,
|
||||
+ "@"
|
||||
+ settings.agent_settings.host,
|
||||
] # Set message receivers here
|
||||
|
||||
for receiver_jid in receiver_jids:
|
||||
|
||||
Reference in New Issue
Block a user