Merge remote-tracking branch 'origin/dev' into feat/vad-agent
# Conflicts:
#	pyproject.toml
#	src/control_backend/main.py
#	uv.lock
This commit is contained in:
@@ -18,6 +18,7 @@ class SocketPoller[T]:
|
||||
Convenience class for polling a socket for data with a timeout, persisting a zmq.Poller for
|
||||
multiple usages.
|
||||
"""
|
||||
|
||||
def __init__(self, socket: azmq.Socket, timeout_ms: int = 100):
|
||||
"""
|
||||
:param socket: The socket to poll and get data from.
|
||||
@@ -46,9 +47,9 @@ class Streaming(CyclicBehaviour):
|
||||
def __init__(self, audio_in_socket: azmq.Socket, audio_out_socket: azmq.Socket):
|
||||
super().__init__()
|
||||
self.audio_in_poller = SocketPoller[bytes](audio_in_socket)
|
||||
self.model, _ = torch.hub.load(repo_or_dir="snakers4/silero-vad",
|
||||
model="silero_vad",
|
||||
force_reload=False)
|
||||
self.model, _ = torch.hub.load(
|
||||
repo_or_dir="snakers4/silero-vad", model="silero_vad", force_reload=False
|
||||
)
|
||||
self.audio_out_socket = audio_out_socket
|
||||
|
||||
self.audio_buffer = np.array([], dtype=np.float32)
|
||||
@@ -59,8 +60,10 @@ class Streaming(CyclicBehaviour):
|
||||
data = await self.audio_in_poller.poll()
|
||||
if data is None:
|
||||
if self.i_since_data % 10 == 0:
|
||||
logger.debug("Failed to receive audio from socket for %d ms.",
|
||||
self.audio_in_poller.timeout_ms*(self.i_since_data+1))
|
||||
logger.debug(
|
||||
"Failed to receive audio from socket for %d ms.",
|
||||
self.audio_in_poller.timeout_ms * (self.i_since_data + 1),
|
||||
)
|
||||
self.i_since_data += 1
|
||||
return
|
||||
self.i_since_data = 0
|
||||
@@ -70,7 +73,8 @@ class Streaming(CyclicBehaviour):
|
||||
prob = self.model(torch.from_numpy(chunk), 16000).item()
|
||||
|
||||
if prob > 0.5:
|
||||
if self.i_since_speech > 3: logger.debug("Speech started.")
|
||||
if self.i_since_speech > 3:
|
||||
logger.debug("Speech started.")
|
||||
self.audio_buffer = np.append(self.audio_buffer, chunk)
|
||||
self.i_since_speech = 0
|
||||
return
|
||||
@@ -82,9 +86,9 @@ class Streaming(CyclicBehaviour):
|
||||
return
|
||||
|
||||
# Speech probably ended. Make sure we have a usable amount of data.
|
||||
if len(self.audio_buffer) >= 3*len(chunk):
|
||||
if len(self.audio_buffer) >= 3 * len(chunk):
|
||||
logger.debug("Speech ended.")
|
||||
await self.audio_out_socket.send(self.audio_buffer[:-2*len(chunk)].tobytes())
|
||||
await self.audio_out_socket.send(self.audio_buffer[: -2 * len(chunk)].tobytes())
|
||||
|
||||
# At this point, we know that the speech has ended.
|
||||
# Prepend the last chunk that had no speech, for a more fluent boundary
|
||||
@@ -96,8 +100,9 @@ class VADAgent(Agent):
|
||||
An agent which listens to an audio stream, does Voice Activity Detection (VAD), and sends
|
||||
fragments with detected speech to other agents over ZeroMQ.
|
||||
"""
|
||||
|
||||
def __init__(self, audio_in_address: str, audio_in_bind: bool):
|
||||
jid = settings.agent_settings.vad_agent_name + '@' + settings.agent_settings.host
|
||||
jid = settings.agent_settings.vad_agent_name + "@" + settings.agent_settings.host
|
||||
super().__init__(jid, settings.agent_settings.vad_agent_name)
|
||||
|
||||
self.audio_in_address = audio_in_address
|
||||
@@ -146,7 +151,6 @@ class VADAgent(Agent):
|
||||
if audio_out_port is None:
|
||||
await self.stop()
|
||||
return
|
||||
audio_out_address = f"tcp://localhost:{audio_out_port}"
|
||||
|
||||
streaming = Streaming(self.audio_in_socket, self.audio_out_socket)
|
||||
self.add_behaviour(streaming)
|
||||
|
||||
Reference in New Issue
Block a user