feat: implement transcriber agent
Uses speech fragments of the VAD agent, emits transcribed text over SPADE's default communication channel to no recipient for now. ref: N25B-209
This commit is contained in:
1
src/control_backend/agents/transcription/__init__.py
Normal file
1
src/control_backend/agents/transcription/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
from .transcription_agent import TranscriptionAgent as TranscriptionAgent
|
||||||
@@ -0,0 +1,62 @@
|
|||||||
|
import abc
|
||||||
|
import sys
|
||||||
|
|
||||||
|
if sys.platform == "darwin":
|
||||||
|
import mlx.core as mx
|
||||||
|
import mlx_whisper
|
||||||
|
from mlx_whisper.transcribe import ModelHolder
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import whisper
|
||||||
|
|
||||||
|
|
||||||
|
class SpeechRecognizer(abc.ABC):
|
||||||
|
@abc.abstractmethod
|
||||||
|
def load_model(self): ...
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def recognize_speech(self, audio: np.ndarray) -> str: ...
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def best_type():
|
||||||
|
if torch.mps.is_available():
|
||||||
|
print("Choosing MLX Whisper model.")
|
||||||
|
return MLXWhisperSpeechRecognizer()
|
||||||
|
else:
|
||||||
|
print("Choosing reference Whisper model.")
|
||||||
|
return OpenAIWhisperSpeechRecognizer()
|
||||||
|
|
||||||
|
|
||||||
|
class MLXWhisperSpeechRecognizer(SpeechRecognizer):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.model = None
|
||||||
|
self.model_name = "mlx-community/whisper-small.en-mlx"
|
||||||
|
|
||||||
|
def load_model(self):
|
||||||
|
if self.model is not None:
|
||||||
|
return
|
||||||
|
ModelHolder.get_model(
|
||||||
|
self.model_name, mx.float16
|
||||||
|
) # Should store it in memory for later usage
|
||||||
|
|
||||||
|
def recognize_speech(self, audio: np.ndarray) -> str:
|
||||||
|
self.load_model()
|
||||||
|
return mlx_whisper.transcribe(audio, path_or_hf_repo=self.model_name)["text"]
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAIWhisperSpeechRecognizer(SpeechRecognizer):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.model = None
|
||||||
|
|
||||||
|
def load_model(self):
|
||||||
|
if self.model is not None:
|
||||||
|
return
|
||||||
|
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||||
|
self.model = whisper.load_model("small.en", device=device)
|
||||||
|
|
||||||
|
def recognize_speech(self, audio: np.ndarray) -> str:
|
||||||
|
self.load_model()
|
||||||
|
return self.model.transcribe(audio)["text"]
|
||||||
@@ -0,0 +1,69 @@
|
|||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import zmq
|
||||||
|
import zmq.asyncio as azmq
|
||||||
|
from spade.agent import Agent
|
||||||
|
from spade.behaviour import CyclicBehaviour
|
||||||
|
from spade.message import Message
|
||||||
|
|
||||||
|
from control_backend.agents.transcription.speech_recognizer import SpeechRecognizer
|
||||||
|
from control_backend.core.config import settings
|
||||||
|
from control_backend.core.zmq_context import context as zmq_context
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class TranscriptionAgent(Agent):
|
||||||
|
"""
|
||||||
|
An agent which listens to audio fragments with voice, transcribes them, and sends the
|
||||||
|
transcription to other agents.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, audio_in_address: str):
|
||||||
|
jid = settings.agent_settings.transcription_agent_name + "@" + settings.agent_settings.host
|
||||||
|
super().__init__(jid, settings.agent_settings.transcription_agent_name)
|
||||||
|
|
||||||
|
self.audio_in_address = audio_in_address
|
||||||
|
self.audio_in_socket: azmq.Socket | None = None
|
||||||
|
|
||||||
|
class Transcribing(CyclicBehaviour):
|
||||||
|
def __init__(self, audio_in_socket: azmq.Socket):
|
||||||
|
super().__init__()
|
||||||
|
self.audio_in_socket = audio_in_socket
|
||||||
|
self.speech_recognizer = SpeechRecognizer.best_type()
|
||||||
|
self._concurrency = asyncio.Semaphore(3)
|
||||||
|
|
||||||
|
async def _transcribe(self, audio: np.ndarray) -> str:
|
||||||
|
async with self._concurrency:
|
||||||
|
return await asyncio.to_thread(self.speech_recognizer.recognize_speech, audio)
|
||||||
|
|
||||||
|
async def run(self) -> None:
|
||||||
|
audio = await self.audio_in_socket.recv()
|
||||||
|
audio = np.frombuffer(audio, dtype=np.float32)
|
||||||
|
speech = await self._transcribe(audio)
|
||||||
|
logger.info("Transcribed speech: %s", speech)
|
||||||
|
|
||||||
|
message = Message(body=speech)
|
||||||
|
await self.send(message)
|
||||||
|
|
||||||
|
async def stop(self):
|
||||||
|
self.audio_in_socket.close()
|
||||||
|
self.audio_in_socket = None
|
||||||
|
return await super().stop()
|
||||||
|
|
||||||
|
def _connect_audio_in_socket(self):
|
||||||
|
self.audio_in_socket = zmq_context.socket(zmq.SUB)
|
||||||
|
self.audio_in_socket.setsockopt_string(zmq.SUBSCRIBE, "")
|
||||||
|
self.audio_in_socket.connect(self.audio_in_address)
|
||||||
|
|
||||||
|
async def setup(self):
|
||||||
|
logger.info("Setting up %s", self.jid)
|
||||||
|
|
||||||
|
self._connect_audio_in_socket()
|
||||||
|
|
||||||
|
transcribing = self.Transcribing(self.audio_in_socket)
|
||||||
|
self.add_behaviour(transcribing)
|
||||||
|
|
||||||
|
logger.info("Finished setting up %s", self.jid)
|
||||||
@@ -7,6 +7,7 @@ import zmq.asyncio as azmq
|
|||||||
from spade.agent import Agent
|
from spade.agent import Agent
|
||||||
from spade.behaviour import CyclicBehaviour
|
from spade.behaviour import CyclicBehaviour
|
||||||
|
|
||||||
|
from control_backend.agents.transcription import TranscriptionAgent
|
||||||
from control_backend.core.config import settings
|
from control_backend.core.config import settings
|
||||||
from control_backend.core.zmq_context import context as zmq_context
|
from control_backend.core.zmq_context import context as zmq_context
|
||||||
|
|
||||||
@@ -145,10 +146,13 @@ class VADAgent(Agent):
|
|||||||
if audio_out_port is None:
|
if audio_out_port is None:
|
||||||
await self.stop()
|
await self.stop()
|
||||||
return
|
return
|
||||||
|
audio_out_address = f"tcp://localhost:{audio_out_port}"
|
||||||
|
|
||||||
streaming = Streaming(self.audio_in_socket, self.audio_out_socket)
|
streaming = Streaming(self.audio_in_socket, self.audio_out_socket)
|
||||||
self.add_behaviour(streaming)
|
self.add_behaviour(streaming)
|
||||||
|
|
||||||
# ... start agents dependent on the output audio fragments here
|
# Start agents dependent on the output audio fragments here
|
||||||
|
transcriber = TranscriptionAgent(audio_out_address)
|
||||||
|
await transcriber.start()
|
||||||
|
|
||||||
logger.info("Finished setting up %s", self.jid)
|
logger.info("Finished setting up %s", self.jid)
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ class AgentSettings(BaseModel):
|
|||||||
bdi_core_agent_name: str = "bdi_core"
|
bdi_core_agent_name: str = "bdi_core"
|
||||||
belief_collector_agent_name: str = "belief_collector"
|
belief_collector_agent_name: str = "belief_collector"
|
||||||
vad_agent_name: str = "vad_agent"
|
vad_agent_name: str = "vad_agent"
|
||||||
|
transcription_agent_name: str = "transcription_agent"
|
||||||
|
|
||||||
ri_communication_agent_name: str = "ri_communication_agent"
|
ri_communication_agent_name: str = "ri_communication_agent"
|
||||||
ri_command_agent_name: str = "ri_command_agent"
|
ri_command_agent_name: str = "ri_command_agent"
|
||||||
|
|||||||
Reference in New Issue
Block a user