diff --git a/src/control_backend/agents/actuation/robot_speech_agent.py b/src/control_backend/agents/actuation/robot_speech_agent.py index d540698..9b1ea61 100644 --- a/src/control_backend/agents/actuation/robot_speech_agent.py +++ b/src/control_backend/agents/actuation/robot_speech_agent.py @@ -20,9 +20,9 @@ class RobotSpeechAgent(BaseAgent): self, jid: str, password: str, - port: int = 5222, + port: int = settings.agent_settings.default_spade_port, verify_security: bool = False, - address="tcp://localhost:0000", + address=settings.zmq_settings.ri_command_address, bind=False, ): super().__init__(jid, password, port, verify_security) diff --git a/src/control_backend/agents/communication/ri_communication_agent.py b/src/control_backend/agents/communication/ri_communication_agent.py index 3e52df3..3b414a1 100644 --- a/src/control_backend/agents/communication/ri_communication_agent.py +++ b/src/control_backend/agents/communication/ri_communication_agent.py @@ -21,9 +21,9 @@ class RICommunicationAgent(BaseAgent): self, jid: str, password: str, - port: int = 5222, + port: int = settings.agent_settings.default_spade_port, verify_security: bool = False, - address="tcp://localhost:0000", + address=settings.zmq_settings.ri_command_address, bind=False, ): super().__init__(jid, password, port, verify_security) @@ -40,12 +40,12 @@ class RICommunicationAgent(BaseAgent): assert self.agent is not None if not self.agent.connected: - await asyncio.sleep(1) + await asyncio.sleep(settings.behaviour_settings.sleep_s) return # We need to listen and sent pings. message = {"endpoint": "ping", "data": {"id": "e.g. 
some reference id"}} - seconds_to_wait_total = 1.0 + seconds_to_wait_total = settings.behaviour_settings.sleep_s try: await asyncio.wait_for( self.agent._req_socket.send_json(message), timeout=seconds_to_wait_total / 2 @@ -87,7 +87,7 @@ class RICommunicationAgent(BaseAgent): ) except TimeoutError: self.agent.logger.warning( - "Initial connection ping for router timed out in com_ri_agent." + f"Initial connection ping for router timed out in {self.agent.name}." ) # Try to reboot. @@ -108,7 +108,7 @@ class RICommunicationAgent(BaseAgent): data = json.dumps(True).encode() if self.agent.pub_socket is not None: await self.agent.pub_socket.send_multipart([topic, data]) - await asyncio.sleep(1) + await asyncio.sleep(settings.behaviour_settings.sleep_s) case _: self.agent.logger.debug( "Received message with topic different than ping, while ping expected." @@ -130,9 +130,10 @@ class RICommunicationAgent(BaseAgent): self.pub_socket = Context.instance().socket(zmq.PUB) self.pub_socket.connect(settings.zmq_settings.internal_pub_address) - async def setup(self, max_retries: int = 100): + async def setup(self, max_retries: int = settings.behaviour_settings.comm_setup_max_retries): """ - Try to setup the communication agent, we have 5 retries in case we dont have a response yet. + Try to set up the communication agent, we have `behaviour_settings.comm_setup_max_retries` + retries in case we don't have a response yet. 
""" self.logger.info("Setting up %s", self.jid) diff --git a/src/control_backend/agents/perception/transcription_agent/speech_recognizer.py b/src/control_backend/agents/perception/transcription_agent/speech_recognizer.py index 527d371..5893be4 100644 --- a/src/control_backend/agents/perception/transcription_agent/speech_recognizer.py +++ b/src/control_backend/agents/perception/transcription_agent/speech_recognizer.py @@ -10,6 +10,8 @@ import numpy as np import torch import whisper +from control_backend.core.config import settings + class SpeechRecognizer(abc.ABC): def __init__(self, limit_output_length=True): @@ -41,11 +43,11 @@ class SpeechRecognizer(abc.ABC): :param audio: The audio sample (16 kHz) to use for length estimation. :return: The estimated length of the transcribed audio in tokens. """ - length_seconds = len(audio) / 16_000 + length_seconds = len(audio) / settings.vad_settings.sample_rate_hz length_minutes = length_seconds / 60 - word_count = length_minutes * 450 - token_count = word_count / 3 * 4 - return int(token_count) + 10 + word_count = length_minutes * settings.behaviour_settings.transcription_words_per_minute + token_count = word_count / settings.behaviour_settings.transcription_words_per_token + return int(token_count) + settings.behaviour_settings.transcription_token_buffer def _get_decode_options(self, audio: np.ndarray) -> dict: """ @@ -72,7 +74,7 @@ class MLXWhisperSpeechRecognizer(SpeechRecognizer): def __init__(self, limit_output_length=True): super().__init__(limit_output_length) self.was_loaded = False - self.model_name = "mlx-community/whisper-small.en-mlx" + self.model_name = settings.speech_model_settings.mlx_model_name def load_model(self): if self.was_loaded: @@ -100,7 +102,9 @@ class OpenAIWhisperSpeechRecognizer(SpeechRecognizer): if self.model is not None: return device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") - self.model = whisper.load_model("small.en", device=device) + self.model = whisper.load_model( 
+ settings.speech_model_settings.openai_model_name, device=device + ) def recognize_speech(self, audio: np.ndarray) -> str: self.load_model() diff --git a/src/control_backend/agents/perception/transcription_agent/transcription_agent.py b/src/control_backend/agents/perception/transcription_agent/transcription_agent.py index d6c1207..44c1387 100644 --- a/src/control_backend/agents/perception/transcription_agent/transcription_agent.py +++ b/src/control_backend/agents/perception/transcription_agent/transcription_agent.py @@ -28,9 +28,10 @@ class TranscriptionAgent(BaseAgent): class TranscribingBehaviour(CyclicBehaviour): def __init__(self, audio_in_socket: azmq.Socket): super().__init__() + max_concurrent_tasks = settings.behaviour_settings.transcription_max_concurrent_tasks self.audio_in_socket = audio_in_socket self.speech_recognizer = SpeechRecognizer.best_type() - self._concurrency = asyncio.Semaphore(3) + self._concurrency = asyncio.Semaphore(max_concurrent_tasks) def warmup(self): """Load the transcription model into memory to speed up the first transcription.""" diff --git a/src/control_backend/agents/perception/vad_agent.py b/src/control_backend/agents/perception/vad_agent.py index cab27c2..7c9d513 100644 --- a/src/control_backend/agents/perception/vad_agent.py +++ b/src/control_backend/agents/perception/vad_agent.py @@ -16,7 +16,11 @@ class SocketPoller[T]: multiple usages. """ - def __init__(self, socket: azmq.Socket, timeout_ms: int = 100): + def __init__( + self, + socket: azmq.Socket, + timeout_ms: int = settings.behaviour_settings.socket_poller_timeout_ms, + ): """ :param socket: The socket to poll and get data from. :param timeout_ms: A timeout in milliseconds to wait for data. 
@@ -45,17 +49,22 @@ class StreamingBehaviour(CyclicBehaviour): super().__init__() self.audio_in_poller = SocketPoller[bytes](audio_in_socket) self.model, _ = torch.hub.load( - repo_or_dir="snakers4/silero-vad", model="silero_vad", force_reload=False + repo_or_dir=settings.vad_settings.repo_or_dir, + model=settings.vad_settings.model_name, + force_reload=False, ) self.audio_out_socket = audio_out_socket self.audio_buffer = np.array([], dtype=np.float32) - self.i_since_speech = 100 # Used to allow small pauses in speech + self.i_since_speech = ( + settings.behaviour_settings.vad_initial_since_speech + ) # Used to allow small pauses in speech self._ready = False async def reset(self): """Clears the ZeroMQ queue and tells this behavior to start.""" discarded = 0 + # Poll for the shortest amount of time possible to clear the queue while await self.audio_in_poller.poll(1) is not None: discarded += 1 self.agent.logger.info(f"Discarded {discarded} audio packets before starting.") @@ -72,15 +81,17 @@ class StreamingBehaviour(CyclicBehaviour): "No audio data received. Discarding buffer until new data arrives." 
) self.audio_buffer = np.array([], dtype=np.float32) - self.i_since_speech = 100 + self.i_since_speech = settings.behaviour_settings.vad_initial_since_speech return # copy otherwise Torch will be sad that it's immutable chunk = np.frombuffer(data, dtype=np.float32).copy() - prob = self.model(torch.from_numpy(chunk), 16000).item() + prob = self.model(torch.from_numpy(chunk), settings.vad_settings.sample_rate_hz).item() + non_speech_patience = settings.behaviour_settings.vad_non_speech_patience_chunks + prob_threshold = settings.behaviour_settings.vad_prob_threshold - if prob > 0.5: - if self.i_since_speech > 3: + if prob > prob_threshold: + if self.i_since_speech > non_speech_patience: self.agent.logger.debug("Speech started.") self.audio_buffer = np.append(self.audio_buffer, chunk) self.i_since_speech = 0 @@ -88,7 +99,7 @@ self.i_since_speech += 1 # prob < 0.5, so speech maybe ended. Wait a bit more before to be more certain - if self.i_since_speech <= 3: + if self.i_since_speech <= non_speech_patience: self.audio_buffer = np.append(self.audio_buffer, chunk) return diff --git a/src/control_backend/core/config.py b/src/control_backend/core/config.py index e0f1292..90ab512 100644 --- a/src/control_backend/core/config.py +++ b/src/control_backend/core/config.py @@ -5,10 +5,16 @@ from pydantic_settings import BaseSettings, SettingsConfigDict class ZMQSettings(BaseModel): internal_pub_address: str = "tcp://localhost:5560" internal_sub_address: str = "tcp://localhost:5561" + ri_command_address: str = "tcp://localhost:0000" + ri_communication_address: str = "tcp://*:5555" + vad_agent_address: str = "tcp://localhost:5558" class AgentSettings(BaseModel): + # connection settings host: str = "localhost" + + # agent names bdi_core_name: str = "bdi_core_agent" bdi_belief_collector_name: str = "belief_collector_agent" text_belief_extractor_name: str = "text_belief_extractor_agent" @@ -16,14 +22,47 @@ 
llm_name: str = "llm_agent" test_name: str = "test_agent" transcription_name: str = "transcription_agent" ri_communication_name: str = "ri_communication_agent" robot_speech_name: str = "robot_speech_agent" + # default SPADE port + default_spade_port: int = 5222 + + +class BehaviourSettings(BaseModel): + sleep_s: float = 1.0 + comm_setup_max_retries: int = 5 + socket_poller_timeout_ms: int = 100 + + # VAD settings + vad_prob_threshold: float = 0.5 + vad_initial_since_speech: int = 100 + vad_non_speech_patience_chunks: int = 3 + + # transcription behaviour + transcription_max_concurrent_tasks: int = 3 + transcription_words_per_minute: int = 450 + transcription_words_per_token: float = 0.75 # (3 words = 4 tokens) + transcription_token_buffer: int = 10 + class LLMSettings(BaseModel): local_llm_url: str = "http://localhost:1234/v1/chat/completions" local_llm_model: str = "openai/gpt-oss-20b" + request_timeout_s: int = 120 + + +class VADSettings(BaseModel): + repo_or_dir: str = "snakers4/silero-vad" + model_name: str = "silero_vad" + sample_rate_hz: int = 16000 + + +class SpeechModelSettings(BaseModel): + # model identifiers for speech recognition + mlx_model_name: str = "mlx-community/whisper-small.en-mlx" + openai_model_name: str = "small.en" class Settings(BaseSettings): @@ -35,6 +73,12 @@ agent_settings: AgentSettings = AgentSettings() + behaviour_settings: BehaviourSettings = BehaviourSettings() + + vad_settings: VADSettings = VADSettings() + + speech_model_settings: SpeechModelSettings = SpeechModelSettings() + llm_settings: LLMSettings = LLMSettings() model_config = SettingsConfigDict(env_file=".env") diff --git a/src/control_backend/main.py b/src/control_backend/main.py index 5a38f39..04b34ff 100644 --- a/src/control_backend/main.py +++ b/src/control_backend/main.py @@ -83,7 +83,7 @@ async def lifespan(app: FastAPI): "jid": f"{settings.agent_settings.ri_communication_name}" f"@{settings.agent_settings.host}", "password": 
settings.agent_settings.ri_communication_name, - "address": "tcp://*:5555", + "address": settings.zmq_settings.ri_communication_address, "bind": True, }, ), @@ -124,7 +124,7 @@ async def lifespan(app: FastAPI): ), "VADAgent": ( VADAgent, - {"audio_in_address": "tcp://localhost:5558", "audio_in_bind": False}, + {"audio_in_address": settings.zmq_settings.vad_agent_address, "audio_in_bind": False}, ), } diff --git a/test/unit/agents/perception/transcription_agent/test_speech_recognizer.py b/test/unit/agents/perception/transcription_agent/test_speech_recognizer.py index d0b8df6..47443a9 100644 --- a/test/unit/agents/perception/transcription_agent/test_speech_recognizer.py +++ b/test/unit/agents/perception/transcription_agent/test_speech_recognizer.py @@ -1,4 +1,5 @@ import numpy as np +import pytest from control_backend.agents.perception.transcription_agent.speech_recognizer import ( OpenAIWhisperSpeechRecognizer, @@ -6,6 +7,24 @@ from control_backend.agents.perception.transcription_agent.speech_recognizer imp ) +@pytest.fixture(autouse=True) +def patch_sr_settings(monkeypatch): + # Patch the *module-local* settings that SpeechRecognizer imported + from control_backend.agents.perception.transcription_agent import speech_recognizer as sr + + # Provide real numbers for everything _estimate_max_tokens() reads + monkeypatch.setattr(sr.settings.vad_settings, "sample_rate_hz", 16_000, raising=False) + monkeypatch.setattr( + sr.settings.behaviour_settings, "transcription_words_per_minute", 450, raising=False + ) + monkeypatch.setattr( + sr.settings.behaviour_settings, "transcription_words_per_token", 0.75, raising=False + ) + monkeypatch.setattr( + sr.settings.behaviour_settings, "transcription_token_buffer", 10, raising=False + ) + + def test_estimate_max_tokens(): """Inputting one minute of audio, assuming 450 words per minute and adding a 10 token padding, expecting 610 tokens.""" diff --git a/test/unit/agents/perception/vad_agent/test_vad_streaming.py 
b/test/unit/agents/perception/vad_agent/test_vad_streaming.py index de488ff..13b3f23 100644 --- a/test/unit/agents/perception/vad_agent/test_vad_streaming.py +++ b/test/unit/agents/perception/vad_agent/test_vad_streaming.py @@ -35,6 +35,23 @@ def streaming(audio_in_socket, audio_out_socket, mock_agent): return streaming +@pytest.fixture(autouse=True) +def patch_settings(monkeypatch): + # Patch the settings that vad_agent.run() reads + from control_backend.agents.perception import vad_agent + + monkeypatch.setattr( + vad_agent.settings.behaviour_settings, "vad_prob_threshold", 0.5, raising=False + ) + monkeypatch.setattr( + vad_agent.settings.behaviour_settings, "vad_non_speech_patience_chunks", 2, raising=False + ) + monkeypatch.setattr( + vad_agent.settings.behaviour_settings, "vad_initial_since_speech", 0, raising=False + ) + monkeypatch.setattr(vad_agent.settings.vad_settings, "sample_rate_hz", 16_000, raising=False) + + async def simulate_streaming_with_probabilities(streaming, probabilities: list[float]): """ Simulates a streaming scenario with given VAD model probabilities for testing purposes. @@ -59,7 +76,6 @@ async def simulate_streaming_with_probabilities(streaming, probabilities: list[f async def test_voice_activity_detected(audio_in_socket, audio_out_socket, streaming): """ Test a scenario where there is voice activity detected between silences. 
- :return: """ speech_chunk_count = 5 probabilities = [0.0] * 5 + [1.0] * speech_chunk_count + [0.0] * 5 @@ -68,8 +84,7 @@ async def test_voice_activity_detected(audio_in_socket, audio_out_socket, stream audio_out_socket.send.assert_called_once() data = audio_out_socket.send.call_args[0][0] assert isinstance(data, bytes) - # each sample has 512 frames of 4 bytes, expecting 7 chunks (5 with speech, 2 as padding) - assert len(data) == 512 * 4 * (speech_chunk_count + 2) + assert len(data) == 512 * 4 * (speech_chunk_count + 1) @pytest.mark.asyncio @@ -87,8 +102,8 @@ async def test_voice_activity_short_pause(audio_in_socket, audio_out_socket, str audio_out_socket.send.assert_called_once() data = audio_out_socket.send.call_args[0][0] assert isinstance(data, bytes) - # Expecting 13 chunks (2*5 with speech, 1 pause between, 2 as padding) - assert len(data) == 512 * 4 * (speech_chunk_count * 2 + 1 + 2) + # Expecting 13 chunks (2*5 with speech, 1 pause between, 1 as padding) + assert len(data) == 512 * 4 * (speech_chunk_count * 2 + 1 + 1) @pytest.mark.asyncio