From 9e926178da32c423fd12b2118698ec52c7714305 Mon Sep 17 00:00:00 2001 From: Pim Hutting Date: Wed, 5 Nov 2025 13:43:57 +0100 Subject: [PATCH] refactor: remove constants and put in config file removed all constants from all files and put them in src/control_backend/core/config.py also removed some old mock agents that we don't use anymore ref: N25B-236 --- .../agents/bdi/behaviours/belief_setter.py | 3 +- .../behaviours/receive_llm_resp_behaviour.py | 3 +- .../bdi/behaviours/text_belief_extractor.py | 3 +- .../behaviours/continuous_collect.py | 3 +- src/control_backend/agents/llm/llm.py | 6 ++- .../agents/mock_agents/__init__.py | 0 .../agents/mock_agents/belief_text_mock.py | 44 ----------------- .../agents/ri_command_agent.py | 4 +- .../agents/ri_communication_agent.py | 8 ++-- .../agents/transcription/speech_recognizer.py | 14 ++++-- .../transcription/transcription_agent.py | 3 +- src/control_backend/agents/vad_agent.py | 26 ++++++---- src/control_backend/core/config.py | 47 ++++++++++++++++++- src/control_backend/main.py | 4 +- 14 files changed, 95 insertions(+), 73 deletions(-) delete mode 100644 src/control_backend/agents/mock_agents/__init__.py delete mode 100644 src/control_backend/agents/mock_agents/belief_text_mock.py diff --git a/src/control_backend/agents/bdi/behaviours/belief_setter.py b/src/control_backend/agents/bdi/behaviours/belief_setter.py index 2f64036..69950b6 100644 --- a/src/control_backend/agents/bdi/behaviours/belief_setter.py +++ b/src/control_backend/agents/bdi/behaviours/belief_setter.py @@ -18,7 +18,8 @@ class BeliefSetterBehaviour(CyclicBehaviour): logger = logging.getLogger("BDI/Belief Setter") async def run(self): - msg = await self.receive(timeout=0.1) + t = settings.behaviour_settings.default_rcv_timeout + msg = await self.receive(timeout=t) if msg: self.logger.info(f"Received message {msg.body}") self._process_message(msg) diff 
--git a/src/control_backend/agents/bdi/behaviours/receive_llm_resp_behaviour.py b/src/control_backend/agents/bdi/behaviours/receive_llm_resp_behaviour.py index dc6e862..0d4788e 100644 --- a/src/control_backend/agents/bdi/behaviours/receive_llm_resp_behaviour.py +++ b/src/control_backend/agents/bdi/behaviours/receive_llm_resp_behaviour.py @@ -13,7 +13,8 @@ class ReceiveLLMResponseBehaviour(CyclicBehaviour): logger = logging.getLogger("BDI/LLM Reciever") async def run(self): - msg = await self.receive(timeout=2) + t = settings.behaviour_settings.llm_response_rcv_timeout + msg = await self.receive(timeout=t) if not msg: return diff --git a/src/control_backend/agents/bdi/behaviours/text_belief_extractor.py b/src/control_backend/agents/bdi/behaviours/text_belief_extractor.py index ed06463..9f10f1c 100644 --- a/src/control_backend/agents/bdi/behaviours/text_belief_extractor.py +++ b/src/control_backend/agents/bdi/behaviours/text_belief_extractor.py @@ -39,7 +39,8 @@ class BeliefFromText(CyclicBehaviour): beliefs = {"mood": ["X"], "car": ["Y"]} async def run(self): - msg = await self.receive(timeout=0.1) + t = settings.behaviour_settings.default_rcv_timeout + msg = await self.receive(timeout=t) if msg: sender = msg.sender.node match sender: diff --git a/src/control_backend/agents/belief_collector/behaviours/continuous_collect.py b/src/control_backend/agents/belief_collector/behaviours/continuous_collect.py index eb3ee5d..fb0a5af 100644 --- a/src/control_backend/agents/belief_collector/behaviours/continuous_collect.py +++ b/src/control_backend/agents/belief_collector/behaviours/continuous_collect.py @@ -16,7 +16,8 @@ class ContinuousBeliefCollector(CyclicBehaviour): """ async def run(self): - msg = await self.receive(timeout=0.1) # Wait for 0.1s + t = settings.behaviour_settings.default_rcv_timeout + msg = await self.receive(timeout=t) if msg: await self._process_message(msg) diff --git a/src/control_backend/agents/llm/llm.py b/src/control_backend/agents/llm/llm.py index 
c3c17ab..6944180 100644 --- a/src/control_backend/agents/llm/llm.py +++ b/src/control_backend/agents/llm/llm.py @@ -35,7 +35,8 @@ class LLMAgent(Agent): Receives SPADE messages and processes only those originating from the configured BDI agent. """ - msg = await self.receive(timeout=1) + t = settings.behaviour_settings.llm_response_rcv_timeout + msg = await self.receive(timeout=t) if not msg: return @@ -78,7 +79,8 @@ class LLMAgent(Agent): :param prompt: Input text prompt to pass to the LLM. :return: LLM-generated content or fallback message. """ - async with httpx.AsyncClient(timeout=120.0) as client: + t = settings.llm_settings.request_timeout_s + async with httpx.AsyncClient(timeout=t) as client: # Example dynamic content for future (optional) instructions = LLMInstructions() diff --git a/src/control_backend/agents/mock_agents/__init__.py b/src/control_backend/agents/mock_agents/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/src/control_backend/agents/mock_agents/belief_text_mock.py b/src/control_backend/agents/mock_agents/belief_text_mock.py deleted file mode 100644 index 27c5e49..0000000 --- a/src/control_backend/agents/mock_agents/belief_text_mock.py +++ /dev/null @@ -1,44 +0,0 @@ -import json - -from spade.agent import Agent -from spade.behaviour import OneShotBehaviour -from spade.message import Message - -from control_backend.core.config import settings - - -class BeliefTextAgent(Agent): - class SendOnceBehaviourBlfText(OneShotBehaviour): - async def run(self): - to_jid = ( - settings.agent_settings.belief_collector_agent_name - + "@" - + settings.agent_settings.host - ) - - # Send multiple beliefs in one JSON payload - payload = { - "type": "belief_extraction_text", - "beliefs": { - "user_said": [ - "hello test", - "Can you help me?", - "stop talking to me", - "No", - "Pepper do a dance", - ] - }, - } - - msg = Message(to=to_jid) - msg.body = json.dumps(payload) - await self.send(msg) - print(f"Beliefs sent to {to_jid}!") - - 
self.exit_code = "Job Finished!" - await self.agent.stop() - - async def setup(self): - print("BeliefTextAgent started") - self.b = self.SendOnceBehaviourBlfText() - self.add_behaviour(self.b) diff --git a/src/control_backend/agents/ri_command_agent.py b/src/control_backend/agents/ri_command_agent.py index 51b8064..fc238f5 100644 --- a/src/control_backend/agents/ri_command_agent.py +++ b/src/control_backend/agents/ri_command_agent.py @@ -22,9 +22,9 @@ class RICommandAgent(Agent): self, jid: str, password: str, - port: int = 5222, + port: int = settings.agent_settings.default_spade_port, verify_security: bool = False, - address="tcp://localhost:0000", + address=settings.zmq_settings.ri_command_address, bind=False, ): super().__init__(jid, password, port, verify_security) diff --git a/src/control_backend/agents/ri_communication_agent.py b/src/control_backend/agents/ri_communication_agent.py index 8d56b09..c2340a6 100644 --- a/src/control_backend/agents/ri_communication_agent.py +++ b/src/control_backend/agents/ri_communication_agent.py @@ -21,9 +21,9 @@ class RICommunicationAgent(Agent): self, jid: str, password: str, - port: int = 5222, + port: int = settings.agent_settings.default_spade_port, verify_security: bool = False, - address="tcp://localhost:0000", + address=settings.zmq_settings.ri_command_address, bind=False, ): super().__init__(jid, password, port, verify_security) @@ -58,13 +58,13 @@ class RICommunicationAgent(Agent): # See what endpoint we received match message["endpoint"]: case "ping": - await asyncio.sleep(1) + await asyncio.sleep(settings.behaviour_settings.ping_sleep_s) case _: logger.info( "Received message with topic different than ping, while ping expected." ) - async def setup(self, max_retries: int = 5): + async def setup(self, max_retries: int = settings.behaviour_settings.comm_setup_max_retries): """ Try to setup the communication agent, we have 5 retries in case we dont have a response yet. 
""" diff --git a/src/control_backend/agents/transcription/speech_recognizer.py b/src/control_backend/agents/transcription/speech_recognizer.py index 19d82ff..40d9215 100644 --- a/src/control_backend/agents/transcription/speech_recognizer.py +++ b/src/control_backend/agents/transcription/speech_recognizer.py @@ -10,6 +10,8 @@ import numpy as np import torch import whisper +from control_backend.core.config import settings + class SpeechRecognizer(abc.ABC): def __init__(self, limit_output_length=True): @@ -41,10 +43,10 @@ class SpeechRecognizer(abc.ABC): :param audio: The audio sample (16 kHz) to use for length estimation. :return: The estimated length of the transcribed audio in tokens. """ - length_seconds = len(audio) / 16_000 + length_seconds = len(audio) / settings.vad_settings.sample_rate_hz length_minutes = length_seconds / 60 - word_count = length_minutes * 300 - token_count = word_count / 3 * 4 + word_count = length_minutes * settings.behaviour_settings.transcription_words_per_minute + token_count = word_count / settings.behaviour_settings.transcription_words_per_token return int(token_count) def _get_decode_options(self, audio: np.ndarray) -> dict: @@ -72,7 +74,7 @@ class MLXWhisperSpeechRecognizer(SpeechRecognizer): def __init__(self, limit_output_length=True): super().__init__(limit_output_length) self.was_loaded = False - self.model_name = "mlx-community/whisper-small.en-mlx" + self.model_name = settings.speech_model_settings.mlx_model_name def load_model(self): if self.was_loaded: @@ -99,7 +101,9 @@ class OpenAIWhisperSpeechRecognizer(SpeechRecognizer): if self.model is not None: return device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") - self.model = whisper.load_model("small.en", device=device) + self.model = whisper.load_model( + settings.speech_model_settings.openai_model_name, device=device + ) def recognize_speech(self, audio: np.ndarray) -> str: self.load_model() diff --git 
a/src/control_backend/agents/transcription/transcription_agent.py b/src/control_backend/agents/transcription/transcription_agent.py index 2d936c4..52c0056 100644 --- a/src/control_backend/agents/transcription/transcription_agent.py +++ b/src/control_backend/agents/transcription/transcription_agent.py @@ -31,9 +31,10 @@ class TranscriptionAgent(Agent): class Transcribing(CyclicBehaviour): def __init__(self, audio_in_socket: azmq.Socket): super().__init__() + max_concurrent_tasks = settings.behaviour_settings.transcription_max_concurrent_tasks self.audio_in_socket = audio_in_socket self.speech_recognizer = SpeechRecognizer.best_type() - self._concurrency = asyncio.Semaphore(3) + self._concurrency = asyncio.Semaphore(max_concurrent_tasks) def warmup(self): """Load the transcription model into memory to speed up the first transcription.""" diff --git a/src/control_backend/agents/vad_agent.py b/src/control_backend/agents/vad_agent.py index a228135..42c26ef 100644 --- a/src/control_backend/agents/vad_agent.py +++ b/src/control_backend/agents/vad_agent.py @@ -20,7 +20,11 @@ class SocketPoller[T]: multiple usages. """ - def __init__(self, socket: azmq.Socket, timeout_ms: int = 100): + def __init__( + self, + socket: azmq.Socket, + timeout_ms: int = settings.behaviour_settings.socket_poller_timeout_ms, + ): """ :param socket: The socket to poll and get data from. :param timeout_ms: A timeout in milliseconds to wait for data. 
@@ -49,12 +53,16 @@ class Streaming(CyclicBehaviour): super().__init__() self.audio_in_poller = SocketPoller[bytes](audio_in_socket) self.model, _ = torch.hub.load( - repo_or_dir="snakers4/silero-vad", model="silero_vad", force_reload=False + repo_or_dir=settings.vad_settings.repo_or_dir, + model=settings.vad_settings.model_name, + force_reload=False, ) self.audio_out_socket = audio_out_socket self.audio_buffer = np.array([], dtype=np.float32) - self.i_since_speech = 100 # Used to allow small pauses in speech + self.i_since_speech = ( + settings.behaviour_settings.vad_initial_since_speech + ) # Used to allow small pauses in speech async def run(self) -> None: data = await self.audio_in_poller.poll() @@ -62,15 +70,17 @@ class Streaming(CyclicBehaviour): if len(self.audio_buffer) > 0: logger.debug("No audio data received. Discarding buffer until new data arrives.") self.audio_buffer = np.array([], dtype=np.float32) - self.i_since_speech = 100 + self.i_since_speech = settings.behaviour_settings.vad_initial_since_speech return # copy otherwise Torch will be sad that it's immutable chunk = np.frombuffer(data, dtype=np.float32).copy() - prob = self.model(torch.from_numpy(chunk), 16000).item() + prob = self.model(torch.from_numpy(chunk), settings.vad_settings.sample_rate_hz).item() + non_speech_patience = settings.behaviour_settings.vad_non_speech_patience_chunks + prob_threshold = settings.behaviour_settings.vad_prob_threshold - if prob > 0.5: - if self.i_since_speech > 3: + if prob > prob_threshold: + if self.i_since_speech > non_speech_patience: logger.debug("Speech started.") self.audio_buffer = np.append(self.audio_buffer, chunk) self.i_since_speech = 0 @@ -78,7 +88,7 @@ class Streaming(CyclicBehaviour): self.i_since_speech += 1 # prob < 0.5, so speech maybe ended. 
Wait a bit more before to be more certain - if self.i_since_speech <= 3: + if self.i_since_speech <= non_speech_patience: self.audio_buffer = np.append(self.audio_buffer, chunk) return diff --git a/src/control_backend/core/config.py b/src/control_backend/core/config.py index 2fd16b8..826d972 100644 --- a/src/control_backend/core/config.py +++ b/src/control_backend/core/config.py @@ -4,10 +4,16 @@ from pydantic_settings import BaseSettings, SettingsConfigDict class ZMQSettings(BaseModel): internal_comm_address: str = "tcp://localhost:5560" + ri_command_address: str = "tcp://localhost:0000" + ri_communication_address: str = "tcp://*:5555" + vad_agent_address: str = "tcp://localhost:5558" class AgentSettings(BaseModel): + # connection settings host: str = "localhost" + + # agent names bdi_core_agent_name: str = "bdi_core" belief_collector_agent_name: str = "belief_collector" text_belief_extractor_agent_name: str = "text_belief_extractor" @@ -15,14 +21,48 @@ llm_agent_name: str = "llm_agent" test_agent_name: str = "test_agent" transcription_agent_name: str = "transcription_agent" ri_communication_agent_name: str = "ri_communication_agent" ri_command_agent_name: str = "ri_command_agent" + # default SPADE port + default_spade_port: int = 5222 + + +class BehaviourSettings(BaseModel): + default_rcv_timeout: float = 0.1 + llm_response_rcv_timeout: float = 1.0 + ping_sleep_s: float = 1.0 + comm_setup_max_retries: int = 5 + socket_poller_timeout_ms: int = 100 + + # VAD settings + vad_prob_threshold: float = 0.5 + vad_initial_since_speech: int = 100 + vad_non_speech_patience_chunks: int = 3 + + # transcription behaviour + transcription_max_concurrent_tasks: int = 3 + transcription_words_per_minute: int = 300 + transcription_words_per_token: float = 0.75 # (3 words = 4 tokens) + class LLMSettings(BaseModel): local_llm_url: str = "http://localhost:1234/v1/chat/completions" local_llm_model: str = "openai/gpt-oss-20b" + request_timeout_s: int = 
120 + + +class VADSettings(BaseModel): + repo_or_dir: str = "snakers4/silero-vad" + model_name: str = "silero_vad" + sample_rate_hz: int = 16000 + + +class SpeechModelSettings(BaseModel): + # model identifiers for speech recognition + mlx_model_name: str = "mlx-community/whisper-small.en-mlx" + openai_model_name: str = "small.en" class Settings(BaseSettings): @@ -34,6 +73,12 @@ class Settings(BaseSettings): agent_settings: AgentSettings = AgentSettings() + behaviour_settings: BehaviourSettings = BehaviourSettings() + + vad_settings: VADSettings = VADSettings() + + speech_model_settings: SpeechModelSettings = SpeechModelSettings() + llm_settings: LLMSettings = LLMSettings() model_config = SettingsConfigDict(env_file=".env") diff --git a/src/control_backend/main.py b/src/control_backend/main.py index 138957c..a2cc7f6 100644 --- a/src/control_backend/main.py +++ b/src/control_backend/main.py @@ -39,7 +39,7 @@ async def lifespan(app: FastAPI): ri_communication_agent = RICommunicationAgent( settings.agent_settings.ri_communication_agent_name + "@" + settings.agent_settings.host, settings.agent_settings.ri_communication_agent_name, - address="tcp://*:5555", + address=settings.zmq_settings.ri_communication_address, bind=True, ) await ri_communication_agent.start() @@ -71,7 +71,7 @@ async def lifespan(app: FastAPI): ) await text_belief_extractor.start() - _temp_vad_agent = VADAgent("tcp://localhost:5558", False) + _temp_vad_agent = VADAgent(settings.zmq_settings.vad_agent_address, False) await _temp_vad_agent.start() yield