Merge branch 'refactor/config-file' into 'dev'
refactor: remove constants and put in config file See merge request ics/sp/2025/n25b/pepperplus-cb!24
This commit was merged in pull request #24.
This commit is contained in:
@@ -20,9 +20,9 @@ class RobotSpeechAgent(BaseAgent):
|
|||||||
self,
|
self,
|
||||||
jid: str,
|
jid: str,
|
||||||
password: str,
|
password: str,
|
||||||
port: int = 5222,
|
port: int = settings.agent_settings.default_spade_port,
|
||||||
verify_security: bool = False,
|
verify_security: bool = False,
|
||||||
address="tcp://localhost:0000",
|
address=settings.zmq_settings.ri_command_address,
|
||||||
bind=False,
|
bind=False,
|
||||||
):
|
):
|
||||||
super().__init__(jid, password, port, verify_security)
|
super().__init__(jid, password, port, verify_security)
|
||||||
|
|||||||
@@ -21,9 +21,9 @@ class RICommunicationAgent(BaseAgent):
|
|||||||
self,
|
self,
|
||||||
jid: str,
|
jid: str,
|
||||||
password: str,
|
password: str,
|
||||||
port: int = 5222,
|
port: int = settings.agent_settings.default_spade_port,
|
||||||
verify_security: bool = False,
|
verify_security: bool = False,
|
||||||
address="tcp://localhost:0000",
|
address=settings.zmq_settings.ri_command_address,
|
||||||
bind=False,
|
bind=False,
|
||||||
):
|
):
|
||||||
super().__init__(jid, password, port, verify_security)
|
super().__init__(jid, password, port, verify_security)
|
||||||
@@ -40,12 +40,12 @@ class RICommunicationAgent(BaseAgent):
|
|||||||
assert self.agent is not None
|
assert self.agent is not None
|
||||||
|
|
||||||
if not self.agent.connected:
|
if not self.agent.connected:
|
||||||
await asyncio.sleep(1)
|
await asyncio.sleep(settings.behaviour_settings.sleep_s)
|
||||||
return
|
return
|
||||||
|
|
||||||
# We need to listen and send pings.
|
# We need to listen and send pings.
|
||||||
message = {"endpoint": "ping", "data": {"id": "e.g. some reference id"}}
|
message = {"endpoint": "ping", "data": {"id": "e.g. some reference id"}}
|
||||||
seconds_to_wait_total = 1.0
|
seconds_to_wait_total = settings.behaviour_settings.sleep_s
|
||||||
try:
|
try:
|
||||||
await asyncio.wait_for(
|
await asyncio.wait_for(
|
||||||
self.agent._req_socket.send_json(message), timeout=seconds_to_wait_total / 2
|
self.agent._req_socket.send_json(message), timeout=seconds_to_wait_total / 2
|
||||||
@@ -87,7 +87,7 @@ class RICommunicationAgent(BaseAgent):
|
|||||||
)
|
)
|
||||||
except TimeoutError:
|
except TimeoutError:
|
||||||
self.agent.logger.warning(
|
self.agent.logger.warning(
|
||||||
"Initial connection ping for router timed out in com_ri_agent."
|
f"Initial connection ping for router timed out in {self.agent.name}."
|
||||||
)
|
)
|
||||||
|
|
||||||
# Try to reboot.
|
# Try to reboot.
|
||||||
@@ -108,7 +108,7 @@ class RICommunicationAgent(BaseAgent):
|
|||||||
data = json.dumps(True).encode()
|
data = json.dumps(True).encode()
|
||||||
if self.agent.pub_socket is not None:
|
if self.agent.pub_socket is not None:
|
||||||
await self.agent.pub_socket.send_multipart([topic, data])
|
await self.agent.pub_socket.send_multipart([topic, data])
|
||||||
await asyncio.sleep(1)
|
await asyncio.sleep(settings.behaviour_settings.sleep_s)
|
||||||
case _:
|
case _:
|
||||||
self.agent.logger.debug(
|
self.agent.logger.debug(
|
||||||
"Received message with topic different than ping, while ping expected."
|
"Received message with topic different than ping, while ping expected."
|
||||||
@@ -130,9 +130,10 @@ class RICommunicationAgent(BaseAgent):
|
|||||||
self.pub_socket = Context.instance().socket(zmq.PUB)
|
self.pub_socket = Context.instance().socket(zmq.PUB)
|
||||||
self.pub_socket.connect(settings.zmq_settings.internal_pub_address)
|
self.pub_socket.connect(settings.zmq_settings.internal_pub_address)
|
||||||
|
|
||||||
async def setup(self, max_retries: int = 100):
|
async def setup(self, max_retries: int = settings.behaviour_settings.comm_setup_max_retries):
|
||||||
"""
|
"""
|
||||||
Try to setup the communication agent, we have 5 retries in case we dont have a response yet.
|
Try to set up the communication agent, we have `behaviour_settings.comm_setup_max_retries`
|
||||||
|
retries in case we don't have a response yet.
|
||||||
"""
|
"""
|
||||||
self.logger.info("Setting up %s", self.jid)
|
self.logger.info("Setting up %s", self.jid)
|
||||||
|
|
||||||
|
|||||||
@@ -10,6 +10,8 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
import whisper
|
import whisper
|
||||||
|
|
||||||
|
from control_backend.core.config import settings
|
||||||
|
|
||||||
|
|
||||||
class SpeechRecognizer(abc.ABC):
|
class SpeechRecognizer(abc.ABC):
|
||||||
def __init__(self, limit_output_length=True):
|
def __init__(self, limit_output_length=True):
|
||||||
@@ -41,11 +43,11 @@ class SpeechRecognizer(abc.ABC):
|
|||||||
:param audio: The audio sample (16 kHz) to use for length estimation.
|
:param audio: The audio sample (16 kHz) to use for length estimation.
|
||||||
:return: The estimated length of the transcribed audio in tokens.
|
:return: The estimated length of the transcribed audio in tokens.
|
||||||
"""
|
"""
|
||||||
length_seconds = len(audio) / 16_000
|
length_seconds = len(audio) / settings.vad_settings.sample_rate_hz
|
||||||
length_minutes = length_seconds / 60
|
length_minutes = length_seconds / 60
|
||||||
word_count = length_minutes * 450
|
word_count = length_minutes * settings.behaviour_settings.transcription_words_per_minute
|
||||||
token_count = word_count / 3 * 4
|
token_count = word_count / settings.behaviour_settings.transcription_words_per_token
|
||||||
return int(token_count) + 10
|
return int(token_count) + settings.behaviour_settings.transcription_token_buffer
|
||||||
|
|
||||||
def _get_decode_options(self, audio: np.ndarray) -> dict:
|
def _get_decode_options(self, audio: np.ndarray) -> dict:
|
||||||
"""
|
"""
|
||||||
@@ -72,7 +74,7 @@ class MLXWhisperSpeechRecognizer(SpeechRecognizer):
|
|||||||
def __init__(self, limit_output_length=True):
|
def __init__(self, limit_output_length=True):
|
||||||
super().__init__(limit_output_length)
|
super().__init__(limit_output_length)
|
||||||
self.was_loaded = False
|
self.was_loaded = False
|
||||||
self.model_name = "mlx-community/whisper-small.en-mlx"
|
self.model_name = settings.speech_model_settings.mlx_model_name
|
||||||
|
|
||||||
def load_model(self):
|
def load_model(self):
|
||||||
if self.was_loaded:
|
if self.was_loaded:
|
||||||
@@ -100,7 +102,9 @@ class OpenAIWhisperSpeechRecognizer(SpeechRecognizer):
|
|||||||
if self.model is not None:
|
if self.model is not None:
|
||||||
return
|
return
|
||||||
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||||
self.model = whisper.load_model("small.en", device=device)
|
self.model = whisper.load_model(
|
||||||
|
settings.speech_model_settings.openai_model_name, device=device
|
||||||
|
)
|
||||||
|
|
||||||
def recognize_speech(self, audio: np.ndarray) -> str:
|
def recognize_speech(self, audio: np.ndarray) -> str:
|
||||||
self.load_model()
|
self.load_model()
|
||||||
|
|||||||
@@ -28,9 +28,10 @@ class TranscriptionAgent(BaseAgent):
|
|||||||
class TranscribingBehaviour(CyclicBehaviour):
|
class TranscribingBehaviour(CyclicBehaviour):
|
||||||
def __init__(self, audio_in_socket: azmq.Socket):
|
def __init__(self, audio_in_socket: azmq.Socket):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
max_concurrent_tasks = settings.behaviour_settings.transcription_max_concurrent_tasks
|
||||||
self.audio_in_socket = audio_in_socket
|
self.audio_in_socket = audio_in_socket
|
||||||
self.speech_recognizer = SpeechRecognizer.best_type()
|
self.speech_recognizer = SpeechRecognizer.best_type()
|
||||||
self._concurrency = asyncio.Semaphore(3)
|
self._concurrency = asyncio.Semaphore(max_concurrent_tasks)
|
||||||
|
|
||||||
def warmup(self):
|
def warmup(self):
|
||||||
"""Load the transcription model into memory to speed up the first transcription."""
|
"""Load the transcription model into memory to speed up the first transcription."""
|
||||||
|
|||||||
@@ -16,7 +16,11 @@ class SocketPoller[T]:
|
|||||||
multiple usages.
|
multiple usages.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, socket: azmq.Socket, timeout_ms: int = 100):
|
def __init__(
|
||||||
|
self,
|
||||||
|
socket: azmq.Socket,
|
||||||
|
timeout_ms: int = settings.behaviour_settings.socket_poller_timeout_ms,
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
:param socket: The socket to poll and get data from.
|
:param socket: The socket to poll and get data from.
|
||||||
:param timeout_ms: A timeout in milliseconds to wait for data.
|
:param timeout_ms: A timeout in milliseconds to wait for data.
|
||||||
@@ -45,17 +49,22 @@ class StreamingBehaviour(CyclicBehaviour):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self.audio_in_poller = SocketPoller[bytes](audio_in_socket)
|
self.audio_in_poller = SocketPoller[bytes](audio_in_socket)
|
||||||
self.model, _ = torch.hub.load(
|
self.model, _ = torch.hub.load(
|
||||||
repo_or_dir="snakers4/silero-vad", model="silero_vad", force_reload=False
|
repo_or_dir=settings.vad_settings.repo_or_dir,
|
||||||
|
model=settings.vad_settings.model_name,
|
||||||
|
force_reload=False,
|
||||||
)
|
)
|
||||||
self.audio_out_socket = audio_out_socket
|
self.audio_out_socket = audio_out_socket
|
||||||
|
|
||||||
self.audio_buffer = np.array([], dtype=np.float32)
|
self.audio_buffer = np.array([], dtype=np.float32)
|
||||||
self.i_since_speech = 100 # Used to allow small pauses in speech
|
self.i_since_speech = (
|
||||||
|
settings.behaviour_settings.vad_initial_since_speech
|
||||||
|
) # Used to allow small pauses in speech
|
||||||
self._ready = False
|
self._ready = False
|
||||||
|
|
||||||
async def reset(self):
|
async def reset(self):
|
||||||
"""Clears the ZeroMQ queue and tells this behavior to start."""
|
"""Clears the ZeroMQ queue and tells this behavior to start."""
|
||||||
discarded = 0
|
discarded = 0
|
||||||
|
# Poll for the shortest amount of time possible to clear the queue
|
||||||
while await self.audio_in_poller.poll(1) is not None:
|
while await self.audio_in_poller.poll(1) is not None:
|
||||||
discarded += 1
|
discarded += 1
|
||||||
self.agent.logger.info(f"Discarded {discarded} audio packets before starting.")
|
self.agent.logger.info(f"Discarded {discarded} audio packets before starting.")
|
||||||
@@ -72,15 +81,17 @@ class StreamingBehaviour(CyclicBehaviour):
|
|||||||
"No audio data received. Discarding buffer until new data arrives."
|
"No audio data received. Discarding buffer until new data arrives."
|
||||||
)
|
)
|
||||||
self.audio_buffer = np.array([], dtype=np.float32)
|
self.audio_buffer = np.array([], dtype=np.float32)
|
||||||
self.i_since_speech = 100
|
self.i_since_speech = settings.behaviour_settings.vad_initial_since_speech
|
||||||
return
|
return
|
||||||
|
|
||||||
# copy otherwise Torch will be sad that it's immutable
|
# copy otherwise Torch will be sad that it's immutable
|
||||||
chunk = np.frombuffer(data, dtype=np.float32).copy()
|
chunk = np.frombuffer(data, dtype=np.float32).copy()
|
||||||
prob = self.model(torch.from_numpy(chunk), 16000).item()
|
prob = self.model(torch.from_numpy(chunk), settings.vad_settings.sample_rate_hz).item()
|
||||||
|
non_speech_patience = settings.behaviour_settings.vad_non_speech_patience_chunks
|
||||||
|
prob_threshold = settings.behaviour_settings.vad_prob_threshold
|
||||||
|
|
||||||
if prob > 0.5:
|
if prob > prob_threshold:
|
||||||
if self.i_since_speech > 3:
|
if self.i_since_speech > non_speech_patience:
|
||||||
self.agent.logger.debug("Speech started.")
|
self.agent.logger.debug("Speech started.")
|
||||||
self.audio_buffer = np.append(self.audio_buffer, chunk)
|
self.audio_buffer = np.append(self.audio_buffer, chunk)
|
||||||
self.i_since_speech = 0
|
self.i_since_speech = 0
|
||||||
@@ -88,7 +99,7 @@ class StreamingBehaviour(CyclicBehaviour):
|
|||||||
self.i_since_speech += 1
|
self.i_since_speech += 1
|
||||||
|
|
||||||
# prob < 0.5, so speech may have ended. Wait a bit longer to be more certain
|
# prob < 0.5, so speech may have ended. Wait a bit longer to be more certain
|
||||||
if self.i_since_speech <= 3:
|
if self.i_since_speech <= non_speech_patience:
|
||||||
self.audio_buffer = np.append(self.audio_buffer, chunk)
|
self.audio_buffer = np.append(self.audio_buffer, chunk)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|||||||
@@ -5,10 +5,16 @@ from pydantic_settings import BaseSettings, SettingsConfigDict
|
|||||||
class ZMQSettings(BaseModel):
|
class ZMQSettings(BaseModel):
|
||||||
internal_pub_address: str = "tcp://localhost:5560"
|
internal_pub_address: str = "tcp://localhost:5560"
|
||||||
internal_sub_address: str = "tcp://localhost:5561"
|
internal_sub_address: str = "tcp://localhost:5561"
|
||||||
|
ri_command_address: str = "tcp://localhost:0000"
|
||||||
|
ri_communication_address: str = "tcp://*:5555"
|
||||||
|
vad_agent_address: str = "tcp://localhost:5558"
|
||||||
|
|
||||||
|
|
||||||
class AgentSettings(BaseModel):
|
class AgentSettings(BaseModel):
|
||||||
|
# connection settings
|
||||||
host: str = "localhost"
|
host: str = "localhost"
|
||||||
|
|
||||||
|
# agent names
|
||||||
bdi_core_name: str = "bdi_core_agent"
|
bdi_core_name: str = "bdi_core_agent"
|
||||||
bdi_belief_collector_name: str = "belief_collector_agent"
|
bdi_belief_collector_name: str = "belief_collector_agent"
|
||||||
text_belief_extractor_name: str = "text_belief_extractor_agent"
|
text_belief_extractor_name: str = "text_belief_extractor_agent"
|
||||||
@@ -16,14 +22,46 @@ class AgentSettings(BaseModel):
|
|||||||
llm_name: str = "llm_agent"
|
llm_name: str = "llm_agent"
|
||||||
test_name: str = "test_agent"
|
test_name: str = "test_agent"
|
||||||
transcription_name: str = "transcription_agent"
|
transcription_name: str = "transcription_agent"
|
||||||
|
|
||||||
ri_communication_name: str = "ri_communication_agent"
|
ri_communication_name: str = "ri_communication_agent"
|
||||||
robot_speech_name: str = "robot_speech_agent"
|
robot_speech_name: str = "robot_speech_agent"
|
||||||
|
|
||||||
|
# default SPADE port
|
||||||
|
default_spade_port: int = 5222
|
||||||
|
|
||||||
|
|
||||||
|
class BehaviourSettings(BaseModel):
|
||||||
|
sleep_s: float = 1.0
|
||||||
|
comm_setup_max_retries: int = 5
|
||||||
|
socket_poller_timeout_ms: int = 100
|
||||||
|
|
||||||
|
# VAD settings
|
||||||
|
vad_prob_threshold: float = 0.5
|
||||||
|
vad_initial_since_speech: int = 100
|
||||||
|
vad_non_speech_patience_chunks: int = 3
|
||||||
|
|
||||||
|
# transcription behaviour
|
||||||
|
transcription_max_concurrent_tasks: int = 3
|
||||||
|
transcription_words_per_minute: int = 300
|
||||||
|
transcription_words_per_token: float = 0.75 # (3 words = 4 tokens)
|
||||||
|
transcription_token_buffer: int = 10
|
||||||
|
|
||||||
|
|
||||||
class LLMSettings(BaseModel):
|
class LLMSettings(BaseModel):
|
||||||
local_llm_url: str = "http://localhost:1234/v1/chat/completions"
|
local_llm_url: str = "http://localhost:1234/v1/chat/completions"
|
||||||
local_llm_model: str = "openai/gpt-oss-20b"
|
local_llm_model: str = "openai/gpt-oss-20b"
|
||||||
|
request_timeout_s: int = 120
|
||||||
|
|
||||||
|
|
||||||
|
class VADSettings(BaseModel):
|
||||||
|
repo_or_dir: str = "snakers4/silero-vad"
|
||||||
|
model_name: str = "silero_vad"
|
||||||
|
sample_rate_hz: int = 16000
|
||||||
|
|
||||||
|
|
||||||
|
class SpeechModelSettings(BaseModel):
|
||||||
|
# model identifiers for speech recognition
|
||||||
|
mlx_model_name: str = "mlx-community/whisper-small.en-mlx"
|
||||||
|
openai_model_name: str = "small.en"
|
||||||
|
|
||||||
|
|
||||||
class Settings(BaseSettings):
|
class Settings(BaseSettings):
|
||||||
@@ -35,6 +73,12 @@ class Settings(BaseSettings):
|
|||||||
|
|
||||||
agent_settings: AgentSettings = AgentSettings()
|
agent_settings: AgentSettings = AgentSettings()
|
||||||
|
|
||||||
|
behaviour_settings: BehaviourSettings = BehaviourSettings()
|
||||||
|
|
||||||
|
vad_settings: VADSettings = VADSettings()
|
||||||
|
|
||||||
|
speech_model_settings: SpeechModelSettings = SpeechModelSettings()
|
||||||
|
|
||||||
llm_settings: LLMSettings = LLMSettings()
|
llm_settings: LLMSettings = LLMSettings()
|
||||||
|
|
||||||
model_config = SettingsConfigDict(env_file=".env")
|
model_config = SettingsConfigDict(env_file=".env")
|
||||||
|
|||||||
@@ -83,7 +83,7 @@ async def lifespan(app: FastAPI):
|
|||||||
"jid": f"{settings.agent_settings.ri_communication_name}"
|
"jid": f"{settings.agent_settings.ri_communication_name}"
|
||||||
f"@{settings.agent_settings.host}",
|
f"@{settings.agent_settings.host}",
|
||||||
"password": settings.agent_settings.ri_communication_name,
|
"password": settings.agent_settings.ri_communication_name,
|
||||||
"address": "tcp://*:5555",
|
"address": settings.zmq_settings.ri_communication_address,
|
||||||
"bind": True,
|
"bind": True,
|
||||||
},
|
},
|
||||||
),
|
),
|
||||||
@@ -124,7 +124,7 @@ async def lifespan(app: FastAPI):
|
|||||||
),
|
),
|
||||||
"VADAgent": (
|
"VADAgent": (
|
||||||
VADAgent,
|
VADAgent,
|
||||||
{"audio_in_address": "tcp://localhost:5558", "audio_in_bind": False},
|
{"audio_in_address": settings.zmq_settings.vad_agent_address, "audio_in_bind": False},
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
|
||||||
from control_backend.agents.perception.transcription_agent.speech_recognizer import (
|
from control_backend.agents.perception.transcription_agent.speech_recognizer import (
|
||||||
OpenAIWhisperSpeechRecognizer,
|
OpenAIWhisperSpeechRecognizer,
|
||||||
@@ -6,6 +7,24 @@ from control_backend.agents.perception.transcription_agent.speech_recognizer imp
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def patch_sr_settings(monkeypatch):
|
||||||
|
# Patch the *module-local* settings that SpeechRecognizer imported
|
||||||
|
from control_backend.agents.perception.transcription_agent import speech_recognizer as sr
|
||||||
|
|
||||||
|
# Provide real numbers for everything _estimate_max_tokens() reads
|
||||||
|
monkeypatch.setattr(sr.settings.vad_settings, "sample_rate_hz", 16_000, raising=False)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
sr.settings.behaviour_settings, "transcription_words_per_minute", 450, raising=False
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
sr.settings.behaviour_settings, "transcription_words_per_token", 0.75, raising=False
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
sr.settings.behaviour_settings, "transcription_token_buffer", 10, raising=False
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_estimate_max_tokens():
|
def test_estimate_max_tokens():
|
||||||
"""Inputting one minute of audio, assuming 450 words per minute and adding a 10 token padding,
|
"""Inputting one minute of audio, assuming 450 words per minute and adding a 10 token padding,
|
||||||
expecting 610 tokens."""
|
expecting 610 tokens."""
|
||||||
|
|||||||
@@ -35,6 +35,23 @@ def streaming(audio_in_socket, audio_out_socket, mock_agent):
|
|||||||
return streaming
|
return streaming
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def patch_settings(monkeypatch):
|
||||||
|
# Patch the settings that vad_agent.run() reads
|
||||||
|
from control_backend.agents.perception import vad_agent
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
vad_agent.settings.behaviour_settings, "vad_prob_threshold", 0.5, raising=False
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
vad_agent.settings.behaviour_settings, "vad_non_speech_patience_chunks", 2, raising=False
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
vad_agent.settings.behaviour_settings, "vad_initial_since_speech", 0, raising=False
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(vad_agent.settings.vad_settings, "sample_rate_hz", 16_000, raising=False)
|
||||||
|
|
||||||
|
|
||||||
async def simulate_streaming_with_probabilities(streaming, probabilities: list[float]):
|
async def simulate_streaming_with_probabilities(streaming, probabilities: list[float]):
|
||||||
"""
|
"""
|
||||||
Simulates a streaming scenario with given VAD model probabilities for testing purposes.
|
Simulates a streaming scenario with given VAD model probabilities for testing purposes.
|
||||||
@@ -59,7 +76,6 @@ async def simulate_streaming_with_probabilities(streaming, probabilities: list[f
|
|||||||
async def test_voice_activity_detected(audio_in_socket, audio_out_socket, streaming):
|
async def test_voice_activity_detected(audio_in_socket, audio_out_socket, streaming):
|
||||||
"""
|
"""
|
||||||
Test a scenario where there is voice activity detected between silences.
|
Test a scenario where there is voice activity detected between silences.
|
||||||
:return:
|
|
||||||
"""
|
"""
|
||||||
speech_chunk_count = 5
|
speech_chunk_count = 5
|
||||||
probabilities = [0.0] * 5 + [1.0] * speech_chunk_count + [0.0] * 5
|
probabilities = [0.0] * 5 + [1.0] * speech_chunk_count + [0.0] * 5
|
||||||
@@ -68,8 +84,7 @@ async def test_voice_activity_detected(audio_in_socket, audio_out_socket, stream
|
|||||||
audio_out_socket.send.assert_called_once()
|
audio_out_socket.send.assert_called_once()
|
||||||
data = audio_out_socket.send.call_args[0][0]
|
data = audio_out_socket.send.call_args[0][0]
|
||||||
assert isinstance(data, bytes)
|
assert isinstance(data, bytes)
|
||||||
# each sample has 512 frames of 4 bytes, expecting 7 chunks (5 with speech, 2 as padding)
|
assert len(data) == 512 * 4 * (speech_chunk_count + 1)
|
||||||
assert len(data) == 512 * 4 * (speech_chunk_count + 2)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@@ -87,8 +102,8 @@ async def test_voice_activity_short_pause(audio_in_socket, audio_out_socket, str
|
|||||||
audio_out_socket.send.assert_called_once()
|
audio_out_socket.send.assert_called_once()
|
||||||
data = audio_out_socket.send.call_args[0][0]
|
data = audio_out_socket.send.call_args[0][0]
|
||||||
assert isinstance(data, bytes)
|
assert isinstance(data, bytes)
|
||||||
# Expecting 13 chunks (2*5 with speech, 1 pause between, 2 as padding)
|
# Expecting 12 chunks (2*5 with speech, 1 pause between, 1 as padding)
|
||||||
assert len(data) == 512 * 4 * (speech_chunk_count * 2 + 1 + 2)
|
assert len(data) == 512 * 4 * (speech_chunk_count * 2 + 1 + 1)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
|||||||
Reference in New Issue
Block a user