From 86938f79c0a3f3e7b776f277c8e10a25a8106133 Mon Sep 17 00:00:00 2001 From: Twirre Meulenbelt <43213592+TwirreM@users.noreply.github.com> Date: Thu, 30 Oct 2025 10:42:25 +0100 Subject: [PATCH 01/18] feat: end to end connected for demo Includes the Transcription agent. Involved updating the RI agent to receive messages from other agents, sending speech commands to the RI agent, and some performance optimizations. ref: N25B-216 --- .../behaviours/receive_llm_resp_behaviour.py | 17 ++- src/control_backend/agents/llm/llm.py | 123 ++++++++++++------ .../agents/llm/llm_instructions.py | 4 +- .../agents/ri_command_agent.py | 15 +++ .../agents/transcription/speech_recognizer.py | 3 - src/control_backend/agents/vad_agent.py | 17 ++- src/control_backend/main.py | 2 + 7 files changed, 132 insertions(+), 49 deletions(-) diff --git a/src/control_backend/agents/bdi/behaviours/receive_llm_resp_behaviour.py b/src/control_backend/agents/bdi/behaviours/receive_llm_resp_behaviour.py index 747ab4c..33525f0 100644 --- a/src/control_backend/agents/bdi/behaviours/receive_llm_resp_behaviour.py +++ b/src/control_backend/agents/bdi/behaviours/receive_llm_resp_behaviour.py @@ -1,15 +1,18 @@ import logging from spade.behaviour import CyclicBehaviour +from spade.message import Message from control_backend.core.config import settings +from control_backend.schemas.ri_message import SpeechCommand class ReceiveLLMResponseBehaviour(CyclicBehaviour): """ Adds behavior to receive responses from the LLM Agent. """ - logger = logging.getLogger("BDI/LLM Reciever") + logger = logging.getLogger("BDI/LLM Receiver") + async def run(self): msg = await self.receive(timeout=2) if not msg: @@ -20,7 +23,17 @@ class ReceiveLLMResponseBehaviour(CyclicBehaviour): case settings.agent_settings.llm_agent_name: content = msg.body self.logger.info("Received LLM response: %s", content) - #Here the BDI can pass the message back as a response + + speech_command = SpeechCommand(data=content) + + message = Message(to=settings.agent_settings.ri_command_agent_name + + '@' + settings.agent_settings.host, + sender=self.agent.jid, + body=speech_command.model_dump_json()) + + self.logger.debug("Sending message: %s", message) + + await self.send(message) case _: self.logger.debug("Not from the llm, discarding message") pass \ No newline at end of file diff --git a/src/control_backend/agents/llm/llm.py b/src/control_backend/agents/llm/llm.py index 0f78095..96658f6 100644 --- a/src/control_backend/agents/llm/llm.py +++ b/src/control_backend/agents/llm/llm.py @@ -2,9 +2,10 @@ LLM Agent module for routing text queries from the BDI Core Agent to a local LLM service and returning its responses back to the BDI Core Agent. """ - +import json import logging -from typing import Any +import re +from typing import AsyncGenerator import httpx from spade.agent import Agent @@ -54,11 +55,15 @@ class LLMAgent(Agent): async def _process_bdi_message(self, message: Message): """ - Forwards user text to the LLM and replies with the generated text. + Forwards user text from the BDI to the LLM and replies with the generated text in chunks + separated by punctuation. """ user_text = message.body - llm_response = await self._query_llm(user_text) - await self._reply(llm_response) + # Consume the streaming generator and send a reply for every chunk + async for chunk in self._query_llm(user_text): + await self._reply(chunk) + self.agent.logger.debug("Finished processing BDI message. 
" + "Response sent in chunks to BDI Core Agent.") async def _reply(self, msg: str): """ @@ -69,52 +74,88 @@ class LLMAgent(Agent): body=msg ) await self.send(reply) - self.agent.logger.info("Reply sent to BDI Core Agent") - async def _query_llm(self, prompt: str) -> str: + async def _query_llm(self, prompt: str) -> AsyncGenerator[str]: """ - Sends a chat completion request to the local LLM service. + Sends a chat completion request to the local LLM service and streams the response by + yielding fragments separated by punctuation like. :param prompt: Input text prompt to pass to the LLM. - :return: LLM-generated content or fallback message. + :yield: Fragments of the LLM-generated content. """ - async with httpx.AsyncClient(timeout=120.0) as client: - # Example dynamic content for future (optional) - - instructions = LLMInstructions() - developer_instruction = instructions.build_developer_instruction() - - response = await client.post( + instructions = LLMInstructions( + "- Be friendly and respectful.\n" + "- Make the conversation feel natural and engaging.\n" + "- Speak like a pirate.\n" + "- When the user asks what you can do, tell them.", + "- Try to learn the user's name during conversation.\n" + "- Suggest playing a game of asking yes or no questions where you think of a word " + "and the user must guess it.", + ) + messages = [ + { + "role": "developer", + "content": instructions.build_developer_instruction(), + }, + { + "role": "user", + "content": prompt, + } + ] + + try: + current_chunk = "" + async for token in self._stream_query_llm(messages): + current_chunk += token + + # Stream the message in chunks separated by punctuation. + # We include the delimiter in the emitted chunk for natural flow. + pattern = re.compile( + r".*?(?:,|;|:|—|–|-|\.{3}|…|\.|\?|!|\(|\)|\[|\]|/)\s*", + re.DOTALL + ) + for m in pattern.finditer(current_chunk): + chunk = m.group(0) + if chunk: + yield current_chunk + current_chunk = "" + + # Yield any remaining tail + if current_chunk: yield current_chunk + except httpx.HTTPError as err: + self.agent.logger.error("HTTP error.", exc_info=err) + yield "LLM service unavailable." + except Exception as err: + self.agent.logger.error("Unexpected error.", exc_info=err) + yield "Error processing the request." + + async def _stream_query_llm(self, messages) -> AsyncGenerator[str]: + """Raises httpx.HTTPError when the API gives an error.""" + async with httpx.AsyncClient(timeout=None) as client: + async with client.stream( + "POST", settings.llm_settings.local_llm_url, - headers={"Content-Type": "application/json"}, json={ "model": settings.llm_settings.local_llm_model, - "messages": [ - { - "role": "developer", - "content": developer_instruction - }, - { - "role": "user", - "content": prompt - } - ], - "temperature": 0.3 + "messages": messages, + "temperature": 0.3, + "stream": True, }, - ) - - try: + ) as response: response.raise_for_status() - data: dict[str, Any] = response.json() - return data.get("choices", [{}])[0].get( - "message", {} - ).get("content", "No response") - except httpx.HTTPError as err: - self.agent.logger.error("HTTP error: %s", err) - return "LLM service unavailable." - except Exception as err: - self.agent.logger.error("Unexpected error: %s", err) - return "Error processing the request." 
+ + async for line in response.aiter_lines(): + if not line or not line.startswith("data: "): continue + + data = line[len("data: "):] + if data.strip() == "[DONE]": break + + try: + event = json.loads(data) + delta = event.get("choices", [{}])[0].get("delta", {}).get("content") + if delta: yield delta + except json.JSONDecodeError: + self.agent.logger.error("Failed to parse LLM response: %s", data) async def setup(self): """ diff --git a/src/control_backend/agents/llm/llm_instructions.py b/src/control_backend/agents/llm/llm_instructions.py index 9636d88..e3aed7e 100644 --- a/src/control_backend/agents/llm/llm_instructions.py +++ b/src/control_backend/agents/llm/llm_instructions.py @@ -28,7 +28,9 @@ class LLMInstructions: """ sections = [ "You are a Pepper robot engaging in natural human conversation.", - "Keep responses between 1–5 sentences, unless instructed otherwise.\n", + "Keep responses between 1–3 sentences, unless told otherwise.\n", + "You're given goals to reach. Reach them in order, but make the conversation feel " + "natural. Some turns you should not try to achieve your goals.\n" ] if self.norms: diff --git a/src/control_backend/agents/ri_command_agent.py b/src/control_backend/agents/ri_command_agent.py index 01fc824..f8234ce 100644 --- a/src/control_backend/agents/ri_command_agent.py +++ b/src/control_backend/agents/ri_command_agent.py @@ -1,5 +1,7 @@ import json import logging + +import spade.agent from spade.agent import Agent from spade.behaviour import CyclicBehaviour import zmq @@ -31,6 +33,7 @@ class RICommandAgent(Agent): self.bind = bind class SendCommandsBehaviour(CyclicBehaviour): + """Behaviour for sending commands received from the UI.""" async def run(self): """ Run the command publishing loop indefinetely. @@ -49,6 +52,17 @@ class RICommandAgent(Agent): except Exception as e: logger.error("Error processing message: %s", e) + class SendPythonCommandsBehaviour(CyclicBehaviour): + """Behaviour for sending commands received from other Python agents.""" + async def run(self): + message: spade.agent.Message = await self.receive(timeout=0.1) + if message and message.to == self.agent.jid: + try: + speech_command = SpeechCommand.model_validate_json(message.body) + await self.agent.pubsocket.send_json(speech_command.model_dump()) + except Exception as e: + logger.error("Error processing message: %s", e) + async def setup(self): """ Setup the command agent @@ -70,5 +84,6 @@ class RICommandAgent(Agent): # Add behaviour to our agent commands_behaviour = self.SendCommandsBehaviour() self.add_behaviour(commands_behaviour) + self.add_behaviour(self.SendPythonCommandsBehaviour()) logger.info("Finished setting up %s", self.jid) diff --git a/src/control_backend/agents/transcription/speech_recognizer.py b/src/control_backend/agents/transcription/speech_recognizer.py index f316cda..83a5fd3 100644 --- a/src/control_backend/agents/transcription/speech_recognizer.py +++ b/src/control_backend/agents/transcription/speech_recognizer.py @@ -83,9 +83,6 @@ class MLXWhisperSpeechRecognizer(SpeechRecognizer): def recognize_speech(self, audio: np.ndarray) -> str: self.load_model() - return mlx_whisper.transcribe(audio, - path_or_hf_repo=self.model_name, - decode_options=self._get_decode_options(audio))["text"] return mlx_whisper.transcribe(audio, path_or_hf_repo=self.model_name)["text"].strip() diff --git a/src/control_backend/agents/vad_agent.py b/src/control_backend/agents/vad_agent.py index a228135..5b7f598 100644 --- a/src/control_backend/agents/vad_agent.py +++ 
b/src/control_backend/agents/vad_agent.py @@ -55,8 +55,19 @@ class Streaming(CyclicBehaviour): self.audio_buffer = np.array([], dtype=np.float32) self.i_since_speech = 100 # Used to allow small pauses in speech + self._ready = False + + async def reset(self): + """Clears the ZeroMQ queue and tells this behavior to start.""" + discarded = 0 + while await self.audio_in_poller.poll(1) is not None: + discarded += 1 + logging.info(f"Discarded {discarded} audio packets before starting.") + self._ready = True async def run(self) -> None: + if not self._ready: return + data = await self.audio_in_poller.poll() if data is None: if len(self.audio_buffer) > 0: @@ -108,6 +119,8 @@ class VADAgent(Agent): self.audio_in_socket: azmq.Socket | None = None self.audio_out_socket: azmq.Socket | None = None + self.streaming_behaviour: Streaming | None = None + async def stop(self): """ Stop listening to audio, stop publishing audio, close sockets. @@ -150,8 +163,8 @@ class VADAgent(Agent): return audio_out_address = f"tcp://localhost:{audio_out_port}" - streaming = Streaming(self.audio_in_socket, self.audio_out_socket) - self.add_behaviour(streaming) + self.streaming_behaviour = Streaming(self.audio_in_socket, self.audio_out_socket) + self.add_behaviour(self.streaming_behaviour) # Start agents dependent on the output audio fragments here transcriber = TranscriptionAgent(audio_out_address) diff --git a/src/control_backend/main.py b/src/control_backend/main.py index d3588ea..4684746 100644 --- a/src/control_backend/main.py +++ b/src/control_backend/main.py @@ -71,6 +71,8 @@ async def lifespan(app: FastAPI): _temp_vad_agent = VADAgent("tcp://localhost:5558", False) await _temp_vad_agent.start() + logger.info("VAD agent started, now making ready...") + await _temp_vad_agent.streaming_behaviour.reset() yield From 4ffe3b2071410655b59e7ebc117f69d79ccf57a4 Mon Sep 17 00:00:00 2001 From: Twirre Meulenbelt <43213592+TwirreM@users.noreply.github.com> Date: Thu, 30 Oct 2025 16:40:45 +0100 Subject: [PATCH 02/18] fix: make VAD unit tests work after changes Namely, the Streamer has to be marked ready. ref: N25B-216 --- test/integration/agents/vad_agent/test_vad_with_audio.py | 1 + test/unit/agents/test_vad_streaming.py | 4 +++- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/test/integration/agents/vad_agent/test_vad_with_audio.py b/test/integration/agents/vad_agent/test_vad_with_audio.py index 7d10aa3..fd7d4d7 100644 --- a/test/integration/agents/vad_agent/test_vad_with_audio.py +++ b/test/integration/agents/vad_agent/test_vad_with_audio.py @@ -48,6 +48,7 @@ async def test_real_audio(mocker): audio_out_socket = AsyncMock() vad_streamer = Streaming(audio_in_socket, audio_out_socket) + vad_streamer._ready = True for _ in audio_chunks: await vad_streamer.run() diff --git a/test/unit/agents/test_vad_streaming.py b/test/unit/agents/test_vad_streaming.py index 9b38cd0..ab2da0d 100644 --- a/test/unit/agents/test_vad_streaming.py +++ b/test/unit/agents/test_vad_streaming.py @@ -21,7 +21,9 @@ def streaming(audio_in_socket, audio_out_socket): import torch torch.hub.load.return_value = (..., ...) 
# Mock - return Streaming(audio_in_socket, audio_out_socket) + streaming = Streaming(audio_in_socket, audio_out_socket) + streaming._ready = True + return streaming async def simulate_streaming_with_probabilities(streaming, probabilities: list[float]): From e5782b421f9f3be32074cb9ca09528cc8dfa49f8 Mon Sep 17 00:00:00 2001 From: Kasper Date: Sun, 2 Nov 2025 18:45:57 +0100 Subject: [PATCH 03/18] build: fix pre-commit Moved the git hooks into shell scripts and placed them into the pre-commit configuration. Also extended robustness of the hooks. ref: N25B-241 --- .githooks/check-branch-name.sh | 77 ++++++++++++++++++++++++++++ .githooks/check-commit-msg.sh | 93 ++++++++++++++++++++++++++++++++++ .githooks/commit-msg | 16 ------ .githooks/pre-commit | 17 ------- .githooks/prepare-commit-msg | 9 ---- .pre-commit-config.yaml | 32 ++++++++---- 6 files changed, 193 insertions(+), 51 deletions(-) create mode 100755 .githooks/check-branch-name.sh create mode 100755 .githooks/check-commit-msg.sh delete mode 100644 .githooks/commit-msg delete mode 100644 .githooks/pre-commit delete mode 100644 .githooks/prepare-commit-msg diff --git a/.githooks/check-branch-name.sh b/.githooks/check-branch-name.sh new file mode 100755 index 0000000..752e199 --- /dev/null +++ b/.githooks/check-branch-name.sh @@ -0,0 +1,77 @@ +#!/bin/bash + +# This script checks if the current branch name follows the specified format. +# It's designed to be used as a 'pre-commit' git hook. + +# Format: / +# Example: feat/add-user-login + +# --- Configuration --- +# An array of allowed commit types +ALLOWED_TYPES=(feat fix refactor perf style test docs build chore revert) +# An array of branches to ignore +IGNORED_BRANCHES=(main dev) + +# --- Colors for Output --- +RED='\033[0;31m' +GREEN='\033[0;32m' +YELLOW='\033[1;33m' +NC='\033[0m' # No Color + +# --- Helper Functions --- +error_exit() { + echo -e "${RED}ERROR: $1${NC}" >&2 + echo -e "${YELLOW}Branch name format is incorrect. Aborting commit.${NC}" >&2 + exit 1 +} + +# --- Main Logic --- + +# 1. Get the current branch name +BRANCH_NAME=$(git symbolic-ref --short HEAD) + +# 2. Check if the current branch is in the ignored list +for ignored_branch in "${IGNORED_BRANCHES[@]}"; do + if [ "$BRANCH_NAME" == "$ignored_branch" ]; then + echo -e "${GREEN}Branch check skipped for default branch: $BRANCH_NAME${NC}" + exit 0 + fi +done + +# 3. Validate the overall structure: / +if ! [[ "$BRANCH_NAME" =~ ^[a-z]+/.+$ ]]; then + error_exit "Branch name must be in the format: /\nExample: feat/add-user-login" +fi + +# 4. Extract the type and description +TYPE=$(echo "$BRANCH_NAME" | cut -d'/' -f1) +DESCRIPTION=$(echo "$BRANCH_NAME" | cut -d'/' -f2-) + +# 5. Validate the +type_valid=false +for allowed_type in "${ALLOWED_TYPES[@]}"; do + if [ "$TYPE" == "$allowed_type" ]; then + type_valid=true + break + fi +done + +if [ "$type_valid" == false ]; then + error_exit "Invalid type '$TYPE'.\nAllowed types are: ${ALLOWED_TYPES[*]}" +fi + +# 6. Validate the +# Regex breakdown: +# ^[a-z0-9]+ - Starts with one or more lowercase letters/numbers (the first word). +# (-[a-z0-9]+){0,5} - Followed by a group of (dash + word) 0 to 5 times. +# $ - End of the string. +# This entire pattern enforces 1 to 6 words total, separated by dashes. +DESCRIPTION_REGEX="^[a-z0-9]+(-[a-z0-9]+){0,5}$" + +if ! 
[[ "$DESCRIPTION" =~ $DESCRIPTION_REGEX ]]; then + error_exit "Invalid short description '$DESCRIPTION'.\nIt must be a maximum of 6 words, all lowercase, separated by dashes.\nExample: add-new-user-authentication-feature" +fi + +# If all checks pass, exit successfully +echo -e "${GREEN}Branch name '$BRANCH_NAME' is valid.${NC}" +exit 0 diff --git a/.githooks/check-commit-msg.sh b/.githooks/check-commit-msg.sh new file mode 100755 index 0000000..82bd441 --- /dev/null +++ b/.githooks/check-commit-msg.sh @@ -0,0 +1,93 @@ +#!/bin/bash + +# This script checks if a commit message follows the specified format. +# It's designed to be used as a 'commit-msg' git hook. + +# Format: +# : +# +# [optional] +# +# [ref/close]: + +# --- Configuration --- +# An array of allowed commit types +ALLOWED_TYPES=(feat fix refactor perf style test docs build chore revert) + +# --- Colors for Output --- +RED='\033[0;31m' +GREEN='\033[0;32m' +YELLOW='\033[1;33m' +NC='\033[0m' # No Color + +# The first argument to the hook is the path to the file containing the commit message +COMMIT_MSG_FILE=$1 + +# --- Validation Functions --- + +# Function to print an error message and exit +# Usage: error_exit "Your error message here" +error_exit() { + # >&2 redirects echo to stderr + echo -e "${RED}ERROR: $1${NC}" >&2 + echo -e "${YELLOW}Commit message format is incorrect. Aborting commit.${NC}" >&2 + exit 1 +} + +# --- Main Logic --- + +# 1. Read the header (first line) of the commit message +HEADER=$(head -n 1 "$COMMIT_MSG_FILE") + +# 2. Validate the header format: : +# Regex breakdown: +# ^(type1|type2|...) - Starts with one of the allowed types +# : - Followed by a literal colon +# \s - Followed by a single space +# .+ - Followed by one or more characters for the description +# $ - End of the line +TYPES_REGEX=$( + IFS="|" + echo "${ALLOWED_TYPES[*]}" +) +HEADER_REGEX="^($TYPES_REGEX): .+$" + +if ! [[ "$HEADER" =~ $HEADER_REGEX ]]; then + error_exit "Invalid header format.\n\nHeader must be in the format: : \nAllowed types: ${ALLOWED_TYPES[*]}\nExample: feat: add new user authentication feature" +fi + +# 3. Validate the footer (last line) of the commit message +FOOTER=$(tail -n 1 "$COMMIT_MSG_FILE") + +# Regex breakdown: +# ^(ref|close) - Starts with 'ref' or 'close' +# : - Followed by a literal colon +# \s - Followed by a single space +# N25B- - Followed by the literal string 'N25B-' +# [0-9]+ - Followed by one or more digits +# $ - End of the line +FOOTER_REGEX="^(ref|close): N25B-[0-9]+$" + +if ! [[ "$FOOTER" =~ $FOOTER_REGEX ]]; then + error_exit "Invalid footer format.\n\nFooter must be in the format: [ref/close]: \nExample: ref: N25B-123" +fi + +# 4. If the message has more than 2 lines, validate the separator +# A blank line must exist between the header and the body. +LINE_COUNT=$(wc -l <"$COMMIT_MSG_FILE" | xargs) # xargs trims whitespace + +# We only care if there is a body. Header + Footer = 2 lines. +# Header + Blank Line + Body... + Footer > 2 lines. +if [ "$LINE_COUNT" -gt 2 ]; then + # Get the second line + SECOND_LINE=$(sed -n '2p' "$COMMIT_MSG_FILE") + + # Check if the second line is NOT empty. If it's not, it's an error. + if [ -n "$SECOND_LINE" ]; then + error_exit "Missing blank line between header and body.\n\nThe second line of your commit message must be empty if a body is present." 
+ fi +fi + +# If all checks pass, exit with success +echo -e "${GREEN}Commit message is valid.${NC}" +exit 0 diff --git a/.githooks/commit-msg b/.githooks/commit-msg deleted file mode 100644 index 41992ad..0000000 --- a/.githooks/commit-msg +++ /dev/null @@ -1,16 +0,0 @@ -#!/bin/sh - -commit_msg_file=$1 -commit_msg=$(cat "$commit_msg_file") - -if echo "$commit_msg" | grep -Eq "^(feat|fix|refactor|perf|style|test|docs|build|chore|revert): .+"; then - if echo "$commit_msg" | grep -Eq "^(ref|close):\sN25B-.+"; then - exit 0 - else - echo "❌ Commit message invalid! Must end with [ref/close]: N25B-000" - exit 1 - fi -else - echo "❌ Commit message invalid! Must start with : " - exit 1 -fi \ No newline at end of file diff --git a/.githooks/pre-commit b/.githooks/pre-commit deleted file mode 100644 index 7e94937..0000000 --- a/.githooks/pre-commit +++ /dev/null @@ -1,17 +0,0 @@ -#!/bin/sh - -# Get current branch -branch=$(git rev-parse --abbrev-ref HEAD) - -if echo "$branch" | grep -Eq "(dev|main)"; then - echo 0 -fi - -# allowed pattern -if echo "$branch" | grep -Eq "^(feat|fix|refactor|perf|style|test|docs|build|chore|revert)\/\w+(-\w+){0,5}$"; then - exit 0 -else - echo "❌ Invalid branch name: $branch" - echo "Branch must be named / (must have one to six words separated by a dash)" - exit 1 -fi \ No newline at end of file diff --git a/.githooks/prepare-commit-msg b/.githooks/prepare-commit-msg deleted file mode 100644 index 5b706c1..0000000 --- a/.githooks/prepare-commit-msg +++ /dev/null @@ -1,9 +0,0 @@ -#!/bin/sh - -echo "#: - -#[optional body] - -#[optional footer(s)] - -#[ref/close]: " > $1 \ No newline at end of file diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index c6ed188..41710dc 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,10 +1,24 @@ repos: -- repo: https://github.com/astral-sh/ruff-pre-commit - # Ruff version. - rev: v0.14.2 - hooks: - # Run the linter. - - id: ruff-check - args: [ --fix ] - # Run the formatter. - - id: ruff-format \ No newline at end of file + - repo: https://github.com/astral-sh/ruff-pre-commit + # Ruff version. + rev: v0.14.2 + hooks: + # Run the linter. + - id: ruff-check + # Run the formatter. + - id: ruff-format + # Configure local hooks + - repo: local + hooks: + - id: check-commit-msg + name: Check commit message format + entry: .githooks/check-commit-msg.sh + language: script + stages: [commit-msg] + - id: check-branch-name + name: Check branch name format + entry: .githooks/check-branch-name.sh + language: script + stages: [pre-commit] + always_run: true + pass_filenames: false From 48c97464175116c1f236e00cfaaee2e80dec1d29 Mon Sep 17 00:00:00 2001 From: Kasper Date: Sun, 2 Nov 2025 19:45:01 +0100 Subject: [PATCH 04/18] style: apply ruff check and format Made sure all ruff checks pass and formatted all files. 
ref: N25B-224 --- src/control_backend/agents/bdi/bdi_core.py | 6 +- .../agents/bdi/behaviours/belief_setter.py | 6 +- .../behaviours/receive_llm_resp_behaviour.py | 8 +- .../bdi/behaviours/text_belief_extractor.py | 40 +++++----- .../agents/bdi/text_extractor.py | 2 +- .../behaviours/continuous_collect.py | 38 +++++----- .../belief_collector/belief_collector.py | 4 +- src/control_backend/agents/llm/llm.py | 28 +++---- .../agents/mock_agents/belief_text_mock.py | 19 ++++- .../agents/ri_command_agent.py | 3 +- .../agents/ri_communication_agent.py | 11 ++- .../agents/transcription/speech_recognizer.py | 18 +++-- .../transcription/transcription_agent.py | 3 +- .../api/v1/endpoints/command.py | 5 +- src/control_backend/api/v1/router.py | 2 +- src/control_backend/core/config.py | 2 + src/control_backend/main.py | 20 ++--- src/control_backend/schemas/ri_message.py | 4 +- .../agents/test_ri_commands_agent.py | 8 +- .../agents/test_ri_communication_agent.py | 12 +-- .../api/endpoints/test_command_endpoint.py | 3 +- test/integration/schemas/test_ri_message.py | 22 ++---- .../bdi/behaviours/test_belief_setter.py | 1 + .../behaviours/test_continuous_collect.py | 73 +++++++++++++------ .../transcription/test_speech_recognizer.py | 4 +- 25 files changed, 199 insertions(+), 143 deletions(-) diff --git a/src/control_backend/agents/bdi/bdi_core.py b/src/control_backend/agents/bdi/bdi_core.py index 06c7b01..6e5cdc0 100644 --- a/src/control_backend/agents/bdi/bdi_core.py +++ b/src/control_backend/agents/bdi/bdi_core.py @@ -58,11 +58,11 @@ class BDICoreAgent(BDIAgent): class SendBehaviour(OneShotBehaviour): async def run(self) -> None: msg = Message( - to= settings.agent_settings.llm_agent_name + '@' + settings.agent_settings.host, - body= text + to=settings.agent_settings.llm_agent_name + "@" + settings.agent_settings.host, + body=text, ) await self.send(msg) self.agent.logger.info("Message sent to LLM: %s", text) - self.add_behaviour(SendBehaviour()) \ No newline at end of file + self.add_behaviour(SendBehaviour()) diff --git a/src/control_backend/agents/bdi/behaviours/belief_setter.py b/src/control_backend/agents/bdi/behaviours/belief_setter.py index 961288d..2f64036 100644 --- a/src/control_backend/agents/bdi/behaviours/belief_setter.py +++ b/src/control_backend/agents/bdi/behaviours/belief_setter.py @@ -3,7 +3,7 @@ import logging from spade.agent import Message from spade.behaviour import CyclicBehaviour -from spade_bdi.bdi import BDIAgent, BeliefNotInitiated +from spade_bdi.bdi import BDIAgent from control_backend.core.config import settings @@ -23,7 +23,6 @@ class BeliefSetterBehaviour(CyclicBehaviour): self.logger.info(f"Received message {msg.body}") self._process_message(msg) - def _process_message(self, message: Message): sender = message.sender.node # removes host from jid and converts to str self.logger.debug("Sender: %s", sender) @@ -61,6 +60,7 @@ class BeliefSetterBehaviour(CyclicBehaviour): self.agent.bdi.set_belief(belief, *arguments) # Special case: if there's a new user message, flag that we haven't responded yet - if belief == "user_said": self.agent.bdi.set_belief("new_message") + if belief == "user_said": + self.agent.bdi.set_belief("new_message") self.logger.info("Set belief %s with arguments %s", belief, arguments) diff --git a/src/control_backend/agents/bdi/behaviours/receive_llm_resp_behaviour.py b/src/control_backend/agents/bdi/behaviours/receive_llm_resp_behaviour.py index 747ab4c..dc6e862 100644 --- a/src/control_backend/agents/bdi/behaviours/receive_llm_resp_behaviour.py +++ 
b/src/control_backend/agents/bdi/behaviours/receive_llm_resp_behaviour.py @@ -9,18 +9,20 @@ class ReceiveLLMResponseBehaviour(CyclicBehaviour): """ Adds behavior to receive responses from the LLM Agent. """ + logger = logging.getLogger("BDI/LLM Reciever") + async def run(self): msg = await self.receive(timeout=2) if not msg: return - sender = msg.sender.node + sender = msg.sender.node match sender: case settings.agent_settings.llm_agent_name: content = msg.body self.logger.info("Received LLM response: %s", content) - #Here the BDI can pass the message back as a response + # Here the BDI can pass the message back as a response case _: self.logger.debug("Not from the llm, discarding message") - pass \ No newline at end of file + pass diff --git a/src/control_backend/agents/bdi/behaviours/text_belief_extractor.py b/src/control_backend/agents/bdi/behaviours/text_belief_extractor.py index c75e66c..ed06463 100644 --- a/src/control_backend/agents/bdi/behaviours/text_belief_extractor.py +++ b/src/control_backend/agents/bdi/behaviours/text_belief_extractor.py @@ -13,28 +13,30 @@ class BeliefFromText(CyclicBehaviour): # TODO: LLM prompt nog hardcoded llm_instruction_prompt = """ - You are an information extraction assistent for a BDI agent. Your task is to extract values from a user's text to bind a list of ungrounded beliefs. Rules: - You will receive a JSON object with "beliefs" (a list of ungrounded AgentSpeak beliefs) and "text" (user's transcript). + You are an information extraction assistent for a BDI agent. Your task is to extract values \ + from a user's text to bind a list of ungrounded beliefs. Rules: + You will receive a JSON object with "beliefs" (a list of ungrounded AgentSpeak beliefs) \ + and "text" (user's transcript). Analyze the text to find values that sematically match the variables (X,Y,Z) in the beliefs. A single piece of text might contain multiple instances that match a belief. Respond ONLY with a single JSON object. The JSON object's keys should be the belief functors (e.g., "weather"). The value for each key must be a list of lists. - Each inner list must contain the extracted arguments (as strings) for one instance of that belief. - CRITICAL: If no information in the text matches a belief, DO NOT include that key in your response. + Each inner list must contain the extracted arguments (as strings) for one instance \ + of that belief. + CRITICAL: If no information in the text matches a belief, DO NOT include that key \ + in your response. 
""" - # on_start agent receives message containing the beliefs to look out for and sets up the LLM with instruction prompt - #async def on_start(self): + # on_start agent receives message containing the beliefs to look out for and + # sets up the LLM with instruction prompt + # async def on_start(self): # msg = await self.receive(timeout=0.1) # self.beliefs = dict uit message # send instruction prompt to LLM beliefs: dict[str, list[str]] - beliefs = { - "mood": ["X"], - "car": ["Y"] - } + beliefs = {"mood": ["X"], "car": ["Y"]} async def run(self): msg = await self.receive(timeout=0.1) @@ -58,8 +60,8 @@ class BeliefFromText(CyclicBehaviour): prompt = text_prompt + beliefs_prompt self.logger.info(prompt) - #prompt_msg = Message(to="LLMAgent@whatever") - #response = self.send(prompt_msg) + # prompt_msg = Message(to="LLMAgent@whatever") + # response = self.send(prompt_msg) # Mock response; response is beliefs in JSON format, it parses do dict[str,list[list[str]]] response = '{"mood": [["happy"]]}' @@ -67,8 +69,9 @@ class BeliefFromText(CyclicBehaviour): try: json.loads(response) belief_message = Message( - to=settings.agent_settings.bdi_core_agent_name + '@' + settings.agent_settings.host, - body=response) + to=settings.agent_settings.bdi_core_agent_name + "@" + settings.agent_settings.host, + body=response, + ) belief_message.thread = "beliefs" await self.send(belief_message) @@ -85,9 +88,12 @@ class BeliefFromText(CyclicBehaviour): """ belief = {"beliefs": {"user_said": [txt]}, "type": "belief_extraction_text"} payload = json.dumps(belief) - belief_msg = Message(to=settings.agent_settings.belief_collector_agent_name - + '@' + settings.agent_settings.host, - body=payload) + belief_msg = Message( + to=settings.agent_settings.belief_collector_agent_name + + "@" + + settings.agent_settings.host, + body=payload, + ) belief_msg.thread = "beliefs" await self.send(belief_msg) diff --git a/src/control_backend/agents/bdi/text_extractor.py b/src/control_backend/agents/bdi/text_extractor.py index 596a3fe..ff9ad58 100644 --- a/src/control_backend/agents/bdi/text_extractor.py +++ b/src/control_backend/agents/bdi/text_extractor.py @@ -6,4 +6,4 @@ from control_backend.agents.bdi.behaviours.text_belief_extractor import BeliefFr class TBeliefExtractor(Agent): async def setup(self): self.b = BeliefFromText() - self.add_behaviour(self.b) \ No newline at end of file + self.add_behaviour(self.b) diff --git a/src/control_backend/agents/belief_collector/behaviours/continuous_collect.py b/src/control_backend/agents/belief_collector/behaviours/continuous_collect.py index 5dcf59d..eb3ee5d 100644 --- a/src/control_backend/agents/belief_collector/behaviours/continuous_collect.py +++ b/src/control_backend/agents/belief_collector/behaviours/continuous_collect.py @@ -1,11 +1,14 @@ import json import logging -from spade.behaviour import CyclicBehaviour + from spade.agent import Message +from spade.behaviour import CyclicBehaviour + from control_backend.core.config import settings logger = logging.getLogger(__name__) + class ContinuousBeliefCollector(CyclicBehaviour): """ Continuously collects beliefs/emotions from extractor agents: @@ -17,7 +20,6 @@ class ContinuousBeliefCollector(CyclicBehaviour): if msg: await self._process_message(msg) - async def _process_message(self, msg: Message): sender_node = self._sender_node(msg) @@ -27,7 +29,9 @@ class ContinuousBeliefCollector(CyclicBehaviour): except Exception as e: logger.warning( "BeliefCollector: failed to parse JSON from %s. 
Body=%r Error=%s", - sender_node, msg.body, e + sender_node, + msg.body, + e, ) return @@ -35,16 +39,21 @@ class ContinuousBeliefCollector(CyclicBehaviour): # Prefer explicit 'type' field if msg_type == "belief_extraction_text" or sender_node == "belief_text_agent_mock": - logger.info("BeliefCollector: message routed to _handle_belief_text (sender=%s)", sender_node) + logger.info( + "BeliefCollector: message routed to _handle_belief_text (sender=%s)", sender_node + ) await self._handle_belief_text(payload, sender_node) - #This is not implemented yet, but we keep the structure for future use - elif msg_type == "emotion_extraction_text" or sender_node == "emo_text_agent_mock": - logger.info("BeliefCollector: message routed to _handle_emo_text (sender=%s)", sender_node) + # This is not implemented yet, but we keep the structure for future use + elif msg_type == "emotion_extraction_text" or sender_node == "emo_text_agent_mock": + logger.info( + "BeliefCollector: message routed to _handle_emo_text (sender=%s)", sender_node + ) await self._handle_emo_text(payload, sender_node) else: logger.info( "BeliefCollector: unrecognized message (sender=%s, type=%r). Ignoring.", - sender_node, msg_type + sender_node, + msg_type, ) @staticmethod @@ -56,13 +65,12 @@ class ContinuousBeliefCollector(CyclicBehaviour): s = str(msg.sender) if msg.sender is not None else "no_sender" return s.split("@", 1)[0] if "@" in s else s - async def _handle_belief_text(self, payload: dict, origin: str): """ Expected payload: { "type": "belief_extraction_text", - "beliefs": {"user_said": ["hello"","Can you help me?","stop talking to me","No","Pepper do a dance"]} + "beliefs": {"user_said": ["Can you help me?"]} } @@ -72,11 +80,11 @@ class ContinuousBeliefCollector(CyclicBehaviour): if not beliefs: logger.info("BeliefCollector: no beliefs to process.") return - + if not isinstance(beliefs, dict): logger.warning("BeliefCollector: 'beliefs' is not a dict: %r", beliefs) return - + if not all(isinstance(v, list) for v in beliefs.values()): logger.warning("BeliefCollector: 'beliefs' values are not all lists: %r", beliefs) return @@ -84,17 +92,14 @@ class ContinuousBeliefCollector(CyclicBehaviour): logger.info("BeliefCollector: forwarding %d beliefs.", len(beliefs)) for belief_name, belief_list in beliefs.items(): for belief in belief_list: - logger.info(" - %s %s", belief_name,str(belief)) + logger.info(" - %s %s", belief_name, str(belief)) await self._send_beliefs_to_bdi(beliefs, origin=origin) - - async def _handle_emo_text(self, payload: dict, origin: str): """TODO: implement (after we have emotional recogntion)""" pass - async def _send_beliefs_to_bdi(self, beliefs: list[str], origin: str | None = None): """ Sends a unified belief packet to the BDI agent. 
@@ -107,6 +112,5 @@ class ContinuousBeliefCollector(CyclicBehaviour): msg = Message(to=to_jid, sender=self.agent.jid, thread="beliefs") msg.body = json.dumps(beliefs) - await self.send(msg) logger.info("BeliefCollector: sent %d belief(s) to BDI at %s", len(beliefs), to_jid) diff --git a/src/control_backend/agents/belief_collector/belief_collector.py b/src/control_backend/agents/belief_collector/belief_collector.py index dbb6095..8558242 100644 --- a/src/control_backend/agents/belief_collector/belief_collector.py +++ b/src/control_backend/agents/belief_collector/belief_collector.py @@ -1,13 +1,15 @@ import logging + from spade.agent import Agent from .behaviours.continuous_collect import ContinuousBeliefCollector logger = logging.getLogger(__name__) + class BeliefCollectorAgent(Agent): async def setup(self): logger.info("BeliefCollectorAgent starting (%s)", self.jid) # Attach the continuous collector behaviour (listens and forwards to BDI) self.add_behaviour(ContinuousBeliefCollector()) - logger.info("BeliefCollectorAgent ready.") \ No newline at end of file + logger.info("BeliefCollectorAgent ready.") diff --git a/src/control_backend/agents/llm/llm.py b/src/control_backend/agents/llm/llm.py index 0f78095..c3c17ab 100644 --- a/src/control_backend/agents/llm/llm.py +++ b/src/control_backend/agents/llm/llm.py @@ -65,8 +65,8 @@ class LLMAgent(Agent): Sends a response message back to the BDI Core Agent. """ reply = Message( - to=settings.agent_settings.bdi_core_agent_name + '@' + settings.agent_settings.host, - body=msg + to=settings.agent_settings.bdi_core_agent_name + "@" + settings.agent_settings.host, + body=msg, ) await self.send(reply) self.agent.logger.info("Reply sent to BDI Core Agent") @@ -80,35 +80,31 @@ class LLMAgent(Agent): """ async with httpx.AsyncClient(timeout=120.0) as client: # Example dynamic content for future (optional) - + instructions = LLMInstructions() developer_instruction = instructions.build_developer_instruction() - + response = await client.post( settings.llm_settings.local_llm_url, headers={"Content-Type": "application/json"}, json={ "model": settings.llm_settings.local_llm_model, "messages": [ - { - "role": "developer", - "content": developer_instruction - }, - { - "role": "user", - "content": prompt - } + {"role": "developer", "content": developer_instruction}, + {"role": "user", "content": prompt}, ], - "temperature": 0.3 + "temperature": 0.3, }, ) try: response.raise_for_status() data: dict[str, Any] = response.json() - return data.get("choices", [{}])[0].get( - "message", {} - ).get("content", "No response") + return ( + data.get("choices", [{}])[0] + .get("message", {}) + .get("content", "No response") + ) except httpx.HTTPError as err: self.agent.logger.error("HTTP error: %s", err) return "LLM service unavailable." 
diff --git a/src/control_backend/agents/mock_agents/belief_text_mock.py b/src/control_backend/agents/mock_agents/belief_text_mock.py index 607c2f5..27c5e49 100644 --- a/src/control_backend/agents/mock_agents/belief_text_mock.py +++ b/src/control_backend/agents/mock_agents/belief_text_mock.py @@ -1,18 +1,33 @@ import json + from spade.agent import Agent from spade.behaviour import OneShotBehaviour from spade.message import Message + from control_backend.core.config import settings + class BeliefTextAgent(Agent): class SendOnceBehaviourBlfText(OneShotBehaviour): async def run(self): - to_jid = f"{settings.agent_settings.belief_collector_agent_name}@{settings.agent_settings.host}" + to_jid = ( + settings.agent_settings.belief_collector_agent_name + + "@" + + settings.agent_settings.host + ) # Send multiple beliefs in one JSON payload payload = { "type": "belief_extraction_text", - "beliefs": {"user_said": ["hello test","Can you help me?","stop talking to me","No","Pepper do a dance"]} + "beliefs": { + "user_said": [ + "hello test", + "Can you help me?", + "stop talking to me", + "No", + "Pepper do a dance", + ] + }, } msg = Message(to=to_jid) diff --git a/src/control_backend/agents/ri_command_agent.py b/src/control_backend/agents/ri_command_agent.py index 01fc824..51b8064 100644 --- a/src/control_backend/agents/ri_command_agent.py +++ b/src/control_backend/agents/ri_command_agent.py @@ -1,8 +1,9 @@ import json import logging + +import zmq from spade.agent import Agent from spade.behaviour import CyclicBehaviour -import zmq from control_backend.core.config import settings from control_backend.core.zmq_context import context diff --git a/src/control_backend/agents/ri_communication_agent.py b/src/control_backend/agents/ri_communication_agent.py index 504c707..8d56b09 100644 --- a/src/control_backend/agents/ri_communication_agent.py +++ b/src/control_backend/agents/ri_communication_agent.py @@ -1,14 +1,13 @@ import asyncio -import json import logging + +import zmq from spade.agent import Agent from spade.behaviour import CyclicBehaviour -import zmq +from control_backend.agents.ri_command_agent import RICommandAgent from control_backend.core.config import settings from control_backend.core.zmq_context import context -from control_backend.schemas.message import Message -from control_backend.agents.ri_command_agent import RICommandAgent logger = logging.getLogger(__name__) @@ -47,7 +46,7 @@ class RICommunicationAgent(Agent): message = await asyncio.wait_for(self.agent.req_socket.recv_json(), timeout=3.0) # We didnt get a reply :( - except asyncio.TimeoutError as e: + except TimeoutError: logger.info("No ping retrieved in 3 seconds, killing myself.") self.kill() @@ -88,7 +87,7 @@ class RICommunicationAgent(Agent): try: received_message = await asyncio.wait_for(self.req_socket.recv_json(), timeout=20.0) - except asyncio.TimeoutError: + except TimeoutError: logger.warning( "No connection established in 20 seconds (attempt %d/%d)", retries + 1, diff --git a/src/control_backend/agents/transcription/speech_recognizer.py b/src/control_backend/agents/transcription/speech_recognizer.py index f316cda..19d82ff 100644 --- a/src/control_backend/agents/transcription/speech_recognizer.py +++ b/src/control_backend/agents/transcription/speech_recognizer.py @@ -75,7 +75,8 @@ class MLXWhisperSpeechRecognizer(SpeechRecognizer): self.model_name = "mlx-community/whisper-small.en-mlx" def load_model(self): - if self.was_loaded: return + if self.was_loaded: + return # There appears to be no dedicated mechanism to preload a 
model, but this `get_model` does # store it in memory for later usage ModelHolder.get_model(self.model_name, mx.float16) @@ -83,9 +84,9 @@ class MLXWhisperSpeechRecognizer(SpeechRecognizer): def recognize_speech(self, audio: np.ndarray) -> str: self.load_model() - return mlx_whisper.transcribe(audio, - path_or_hf_repo=self.model_name, - decode_options=self._get_decode_options(audio))["text"] + return mlx_whisper.transcribe( + audio, path_or_hf_repo=self.model_name, decode_options=self._get_decode_options(audio) + )["text"] return mlx_whisper.transcribe(audio, path_or_hf_repo=self.model_name)["text"].strip() @@ -95,12 +96,13 @@ class OpenAIWhisperSpeechRecognizer(SpeechRecognizer): self.model = None def load_model(self): - if self.model is not None: return + if self.model is not None: + return device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") self.model = whisper.load_model("small.en", device=device) def recognize_speech(self, audio: np.ndarray) -> str: self.load_model() - return whisper.transcribe(self.model, - audio, - decode_options=self._get_decode_options(audio))["text"] + return whisper.transcribe( + self.model, audio, decode_options=self._get_decode_options(audio) + )["text"] diff --git a/src/control_backend/agents/transcription/transcription_agent.py b/src/control_backend/agents/transcription/transcription_agent.py index a2c8e2b..2d936c4 100644 --- a/src/control_backend/agents/transcription/transcription_agent.py +++ b/src/control_backend/agents/transcription/transcription_agent.py @@ -47,7 +47,8 @@ class TranscriptionAgent(Agent): """Share a transcription to the other agents that depend on it.""" receiver_jids = [ settings.agent_settings.text_belief_extractor_agent_name - + '@' + settings.agent_settings.host, + + "@" + + settings.agent_settings.host, ] # Set message receivers here for receiver_jid in receiver_jids: diff --git a/src/control_backend/api/v1/endpoints/command.py b/src/control_backend/api/v1/endpoints/command.py index badaf90..e19290f 100644 --- a/src/control_backend/api/v1/endpoints/command.py +++ b/src/control_backend/api/v1/endpoints/command.py @@ -1,9 +1,9 @@ -from fastapi import APIRouter, Request import logging +from fastapi import APIRouter, Request from zmq import Socket -from control_backend.schemas.ri_message import SpeechCommand, RIEndpoint +from control_backend.schemas.ri_message import SpeechCommand logger = logging.getLogger(__name__) @@ -17,6 +17,5 @@ async def receive_command(command: SpeechCommand, request: Request): topic = b"command" pub_socket: Socket = request.app.state.internal_comm_socket pub_socket.send_multipart([topic, command.model_dump_json().encode()]) - return {"status": "Command received"} diff --git a/src/control_backend/api/v1/router.py b/src/control_backend/api/v1/router.py index dc7aea9..a23b3b3 100644 --- a/src/control_backend/api/v1/router.py +++ b/src/control_backend/api/v1/router.py @@ -1,6 +1,6 @@ from fastapi.routing import APIRouter -from control_backend.api.v1.endpoints import message, sse, command +from control_backend.api.v1.endpoints import command, message, sse api_router = APIRouter() diff --git a/src/control_backend/core/config.py b/src/control_backend/core/config.py index 5e4b764..2fd16b8 100644 --- a/src/control_backend/core/config.py +++ b/src/control_backend/core/config.py @@ -24,6 +24,7 @@ class LLMSettings(BaseModel): local_llm_url: str = "http://localhost:1234/v1/chat/completions" local_llm_model: str = "openai/gpt-oss-20b" + class Settings(BaseSettings): app_title: str = "PepperPlus" @@ -37,4 
+38,5 @@ class Settings(BaseSettings): model_config = SettingsConfigDict(env_file=".env") + settings = Settings() diff --git a/src/control_backend/main.py b/src/control_backend/main.py index d3588ea..138957c 100644 --- a/src/control_backend/main.py +++ b/src/control_backend/main.py @@ -8,13 +8,14 @@ import zmq from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware -# Internal imports -from control_backend.agents.ri_communication_agent import RICommunicationAgent from control_backend.agents.bdi.bdi_core import BDICoreAgent -from control_backend.agents.vad_agent import VADAgent -from control_backend.agents.llm.llm import LLMAgent from control_backend.agents.bdi.text_extractor import TBeliefExtractor from control_backend.agents.belief_collector.belief_collector import BeliefCollectorAgent +from control_backend.agents.llm.llm import LLMAgent + +# Internal imports +from control_backend.agents.ri_communication_agent import RICommunicationAgent +from control_backend.agents.vad_agent import VADAgent from control_backend.api.v1.router import api_router from control_backend.core.config import settings from control_backend.core.zmq_context import context @@ -34,7 +35,6 @@ async def lifespan(app: FastAPI): app.state.internal_comm_socket = internal_comm_socket logger.info("Internal publishing socket bound to %s", internal_comm_socket) - # Initiate agents ri_communication_agent = RICommunicationAgent( settings.agent_settings.ri_communication_agent_name + "@" + settings.agent_settings.host, @@ -45,26 +45,28 @@ async def lifespan(app: FastAPI): await ri_communication_agent.start() llm_agent = LLMAgent( - settings.agent_settings.llm_agent_name + '@' + settings.agent_settings.host, + settings.agent_settings.llm_agent_name + "@" + settings.agent_settings.host, settings.agent_settings.llm_agent_name, ) await llm_agent.start() bdi_core = BDICoreAgent( - settings.agent_settings.bdi_core_agent_name + '@' + settings.agent_settings.host, + settings.agent_settings.bdi_core_agent_name + "@" + settings.agent_settings.host, settings.agent_settings.bdi_core_agent_name, "src/control_backend/agents/bdi/rules.asl", ) await bdi_core.start() belief_collector = BeliefCollectorAgent( - settings.agent_settings.belief_collector_agent_name + '@' + settings.agent_settings.host, + settings.agent_settings.belief_collector_agent_name + "@" + settings.agent_settings.host, settings.agent_settings.belief_collector_agent_name, ) await belief_collector.start() text_belief_extractor = TBeliefExtractor( - settings.agent_settings.text_belief_extractor_agent_name + '@' + settings.agent_settings.host, + settings.agent_settings.text_belief_extractor_agent_name + + "@" + + settings.agent_settings.host, settings.agent_settings.text_belief_extractor_agent_name, ) await text_belief_extractor.start() diff --git a/src/control_backend/schemas/ri_message.py b/src/control_backend/schemas/ri_message.py index 97b7930..488b823 100644 --- a/src/control_backend/schemas/ri_message.py +++ b/src/control_backend/schemas/ri_message.py @@ -1,7 +1,7 @@ from enum import Enum -from typing import Any, Literal +from typing import Any -from pydantic import BaseModel, Field, ValidationError +from pydantic import BaseModel class RIEndpoint(str, Enum): diff --git a/test/integration/agents/test_ri_commands_agent.py b/test/integration/agents/test_ri_commands_agent.py index 219d682..4249401 100644 --- a/test/integration/agents/test_ri_commands_agent.py +++ b/test/integration/agents/test_ri_commands_agent.py @@ -1,10 +1,10 @@ -import asyncio -import 
zmq import json -import pytest from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +import zmq + from control_backend.agents.ri_command_agent import RICommandAgent -from control_backend.schemas.ri_message import SpeechCommand @pytest.mark.asyncio diff --git a/test/integration/agents/test_ri_communication_agent.py b/test/integration/agents/test_ri_communication_agent.py index 3e4a056..fd555e1 100644 --- a/test/integration/agents/test_ri_communication_agent.py +++ b/test/integration/agents/test_ri_communication_agent.py @@ -1,6 +1,8 @@ import asyncio +from unittest.mock import ANY, AsyncMock, MagicMock, patch + import pytest -from unittest.mock import AsyncMock, MagicMock, patch, ANY + from control_backend.agents.ri_communication_agent import RICommunicationAgent @@ -185,8 +187,8 @@ async def test_setup_creates_socket_and_negotiate_3(monkeypatch, caplog): # Mock RICommandAgent agent startup - # We are sending wrong negotiation info to the communication agent, so we should retry and expect a - # better response, within a limited time. + # We are sending wrong negotiation info to the communication agent, + # so we should retry and expect a better response, within a limited time. with patch( "control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True ) as MockCommandAgent: @@ -358,8 +360,8 @@ async def test_setup_creates_socket_and_negotiate_7(monkeypatch, caplog): # Mock RICommandAgent agent startup - # We are sending wrong negotiation info to the communication agent, so we should retry and expect a - # better response, within a limited time. + # We are sending wrong negotiation info to the communication agent, + # so we should retry and expect a better response, within a limited time. with patch( "control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True ) as MockCommandAgent: diff --git a/test/integration/api/endpoints/test_command_endpoint.py b/test/integration/api/endpoints/test_command_endpoint.py index 07bd866..04890c1 100644 --- a/test/integration/api/endpoints/test_command_endpoint.py +++ b/test/integration/api/endpoints/test_command_endpoint.py @@ -1,7 +1,8 @@ +from unittest.mock import MagicMock + import pytest from fastapi import FastAPI from fastapi.testclient import TestClient -from unittest.mock import MagicMock from control_backend.api.v1.endpoints import command from control_backend.schemas.ri_message import SpeechCommand diff --git a/test/integration/schemas/test_ri_message.py b/test/integration/schemas/test_ri_message.py index aef9ae6..5078f9a 100644 --- a/test/integration/schemas/test_ri_message.py +++ b/test/integration/schemas/test_ri_message.py @@ -1,7 +1,8 @@ import pytest -from control_backend.schemas.ri_message import RIMessage, RIEndpoint, SpeechCommand from pydantic import ValidationError +from control_backend.schemas.ri_message import RIEndpoint, RIMessage, SpeechCommand + def valid_command_1(): return SpeechCommand(data="Hallo?") @@ -13,24 +14,13 @@ def invalid_command_1(): def test_valid_speech_command_1(): command = valid_command_1() - try: - RIMessage.model_validate(command) - SpeechCommand.model_validate(command) - assert True - except ValidationError: - assert False + RIMessage.model_validate(command) + SpeechCommand.model_validate(command) def test_invalid_speech_command_1(): command = invalid_command_1() - passed_ri_message_validation = False - try: - # Should succeed, still. - RIMessage.model_validate(command) - passed_ri_message_validation = True + RIMessage.model_validate(command) - # Should fail. 
+ with pytest.raises(ValidationError): SpeechCommand.model_validate(command) - assert False - except ValidationError: - assert passed_ri_message_validation diff --git a/test/unit/agents/bdi/behaviours/test_belief_setter.py b/test/unit/agents/bdi/behaviours/test_belief_setter.py index 788e95a..c7bb0e9 100644 --- a/test/unit/agents/bdi/behaviours/test_belief_setter.py +++ b/test/unit/agents/bdi/behaviours/test_belief_setter.py @@ -203,6 +203,7 @@ def test_set_beliefs_success(belief_setter, mock_agent, caplog): assert "Set belief is_hot with arguments ['kitchen']" in caplog.text assert "Set belief door_opened with arguments ['front_door', 'back_door']" in caplog.text + # def test_responded_unset(belief_setter, mock_agent): # # Arrange # new_beliefs = {"user_said": ["message"]} diff --git a/test/unit/agents/belief_collector/behaviours/test_continuous_collect.py b/test/unit/agents/belief_collector/behaviours/test_continuous_collect.py index 622aefd..e842f5c 100644 --- a/test/unit/agents/belief_collector/behaviours/test_continuous_collect.py +++ b/test/unit/agents/belief_collector/behaviours/test_continuous_collect.py @@ -1,10 +1,12 @@ import json -import logging -from unittest.mock import MagicMock, AsyncMock, call +from unittest.mock import AsyncMock, MagicMock import pytest -from control_backend.agents.belief_collector.behaviours.continuous_collect import ContinuousBeliefCollector +from control_backend.agents.belief_collector.behaviours.continuous_collect import ( + ContinuousBeliefCollector, +) + @pytest.fixture def mock_agent(mocker): @@ -13,18 +15,20 @@ def mock_agent(mocker): agent.jid = "belief_collector_agent@test" return agent + @pytest.fixture def continuous_collector(mock_agent, mocker): """Fixture to create an instance of ContinuousBeliefCollector with a mocked agent.""" # Patch asyncio.sleep to prevent tests from actually waiting mocker.patch("asyncio.sleep", return_value=None) - + collector = ContinuousBeliefCollector() collector.agent = mock_agent # Mock the receive method, we will control its return value in each test collector.receive = AsyncMock() return collector + @pytest.mark.asyncio async def test_run_no_message_received(continuous_collector, mocker): """ @@ -40,6 +44,7 @@ async def test_run_no_message_received(continuous_collector, mocker): # Assert continuous_collector._process_message.assert_not_called() + @pytest.mark.asyncio async def test_run_message_received(continuous_collector, mocker): """ @@ -55,7 +60,8 @@ async def test_run_message_received(continuous_collector, mocker): # Assert continuous_collector._process_message.assert_awaited_once_with(mock_msg) - + + @pytest.mark.asyncio async def test_process_message_invalid(continuous_collector, mocker): """ @@ -66,15 +72,18 @@ async def test_process_message_invalid(continuous_collector, mocker): msg = MagicMock() msg.body = invalid_json msg.sender = "belief_text_agent_mock@test" - - logger_mock = mocker.patch("control_backend.agents.belief_collector.behaviours.continuous_collect.logger") - + + logger_mock = mocker.patch( + "control_backend.agents.belief_collector.behaviours.continuous_collect.logger" + ) + # Act await continuous_collector._process_message(msg) # Assert logger_mock.warning.assert_called_once() + def test_get_sender_from_message(continuous_collector): """ Test that _sender_node correctly extracts the sender node from the message JID. 
@@ -89,6 +98,7 @@ def test_get_sender_from_message(continuous_collector): # Assert assert sender_node == "agent_node" + @pytest.mark.asyncio async def test_routes_to_handle_belief_text_by_type(continuous_collector, mocker): msg = MagicMock() @@ -98,6 +108,7 @@ async def test_routes_to_handle_belief_text_by_type(continuous_collector, mocker await continuous_collector._process_message(msg) spy.assert_awaited_once() + @pytest.mark.asyncio async def test_routes_to_handle_belief_text_by_sender(continuous_collector, mocker): msg = MagicMock() @@ -107,6 +118,7 @@ async def test_routes_to_handle_belief_text_by_sender(continuous_collector, mock await continuous_collector._process_message(msg) spy.assert_awaited_once() + @pytest.mark.asyncio async def test_routes_to_handle_emo_text(continuous_collector, mocker): msg = MagicMock() @@ -116,50 +128,64 @@ async def test_routes_to_handle_emo_text(continuous_collector, mocker): await continuous_collector._process_message(msg) spy.assert_awaited_once() + @pytest.mark.asyncio async def test_unrecognized_message_logs_info(continuous_collector, mocker): msg = MagicMock() msg.body = json.dumps({"type": "something_else"}) msg.sender = "x@test" - logger_mock = mocker.patch("control_backend.agents.belief_collector.behaviours.continuous_collect.logger") + logger_mock = mocker.patch( + "control_backend.agents.belief_collector.behaviours.continuous_collect.logger" + ) await continuous_collector._process_message(msg) logger_mock.info.assert_any_call( - "BeliefCollector: unrecognized message (sender=%s, type=%r). Ignoring.", "x", "something_else" + "BeliefCollector: unrecognized message (sender=%s, type=%r). Ignoring.", + "x", + "something_else", ) @pytest.mark.asyncio async def test_belief_text_no_beliefs(continuous_collector, mocker): msg_payload = {"type": "belief_extraction_text"} # no 'beliefs' - logger_mock = mocker.patch("control_backend.agents.belief_collector.behaviours.continuous_collect.logger") + logger_mock = mocker.patch( + "control_backend.agents.belief_collector.behaviours.continuous_collect.logger" + ) await continuous_collector._handle_belief_text(msg_payload, "origin_node") logger_mock.info.assert_any_call("BeliefCollector: no beliefs to process.") + @pytest.mark.asyncio async def test_belief_text_beliefs_not_dict(continuous_collector, mocker): payload = {"type": "belief_extraction_text", "beliefs": ["not", "a", "dict"]} - logger_mock = mocker.patch("control_backend.agents.belief_collector.behaviours.continuous_collect.logger") + logger_mock = mocker.patch( + "control_backend.agents.belief_collector.behaviours.continuous_collect.logger" + ) await continuous_collector._handle_belief_text(payload, "origin") - logger_mock.warning.assert_any_call("BeliefCollector: 'beliefs' is not a dict: %r", ["not", "a", "dict"]) + logger_mock.warning.assert_any_call( + "BeliefCollector: 'beliefs' is not a dict: %r", ["not", "a", "dict"] + ) + @pytest.mark.asyncio async def test_belief_text_values_not_lists(continuous_collector, mocker): payload = {"type": "belief_extraction_text", "beliefs": {"user_said": "not-a-list"}} - logger_mock = mocker.patch("control_backend.agents.belief_collector.behaviours.continuous_collect.logger") + logger_mock = mocker.patch( + "control_backend.agents.belief_collector.behaviours.continuous_collect.logger" + ) await continuous_collector._handle_belief_text(payload, "origin") logger_mock.warning.assert_any_call( "BeliefCollector: 'beliefs' values are not all lists: %r", {"user_said": "not-a-list"} ) + @pytest.mark.asyncio async def 
test_belief_text_happy_path_logs_items_and_sends(continuous_collector, mocker): - payload = { - "type": "belief_extraction_text", - "beliefs": {"user_said": ["hello test", "No"]} - } - # Your code calls self.send(..); patch it (or switch implementation to self.agent.send and patch that) + payload = {"type": "belief_extraction_text", "beliefs": {"user_said": ["hello test", "No"]}} continuous_collector.send = AsyncMock() - logger_mock = mocker.patch("control_backend.agents.belief_collector.behaviours.continuous_collect.logger") + logger_mock = mocker.patch( + "control_backend.agents.belief_collector.behaviours.continuous_collect.logger" + ) await continuous_collector._handle_belief_text(payload, "belief_text_agent_mock") logger_mock.info.assert_any_call("BeliefCollector: forwarding %d beliefs.", 1) @@ -169,12 +195,14 @@ async def test_belief_text_happy_path_logs_items_and_sends(continuous_collector, # make sure we attempted a send continuous_collector.send.assert_awaited_once() + @pytest.mark.asyncio async def test_send_beliefs_noop_on_empty(continuous_collector): continuous_collector.send = AsyncMock() await continuous_collector._send_beliefs_to_bdi([], origin="o") continuous_collector.send.assert_not_awaited() + # @pytest.mark.asyncio # async def test_send_beliefs_sends_json_packet(continuous_collector): # # Patch .send and capture the message body @@ -191,19 +219,22 @@ async def test_send_beliefs_noop_on_empty(continuous_collector): # assert "belief_packet" in json.loads(sent["body"])["type"] # assert json.loads(sent["body"])["beliefs"] == beliefs + def test_sender_node_no_sender_returns_literal(continuous_collector): msg = MagicMock() msg.sender = None assert continuous_collector._sender_node(msg) == "no_sender" + def test_sender_node_without_at(continuous_collector): msg = MagicMock() msg.sender = "localpartonly" assert continuous_collector._sender_node(msg) == "localpartonly" + @pytest.mark.asyncio async def test_belief_text_coerces_non_strings(continuous_collector, mocker): payload = {"type": "belief_extraction_text", "beliefs": {"user_said": [["hi", 123]]}} continuous_collector.send = AsyncMock() await continuous_collector._handle_belief_text(payload, "origin") - continuous_collector.send.assert_awaited_once() + continuous_collector.send.assert_awaited_once() diff --git a/test/unit/agents/transcription/test_speech_recognizer.py b/test/unit/agents/transcription/test_speech_recognizer.py index 6e7cde0..88a5ac2 100644 --- a/test/unit/agents/transcription/test_speech_recognizer.py +++ b/test/unit/agents/transcription/test_speech_recognizer.py @@ -6,7 +6,7 @@ from control_backend.agents.transcription.speech_recognizer import OpenAIWhisper def test_estimate_max_tokens(): """Inputting one minute of audio, assuming 300 words per minute, expecting 400 tokens.""" - audio = np.empty(shape=(60*16_000), dtype=np.float32) + audio = np.empty(shape=(60 * 16_000), dtype=np.float32) actual = SpeechRecognizer._estimate_max_tokens(audio) @@ -16,7 +16,7 @@ def test_estimate_max_tokens(): def test_get_decode_options(): """Check whether the right decode options are given under different scenarios.""" - audio = np.empty(shape=(60*16_000), dtype=np.float32) + audio = np.empty(shape=(60 * 16_000), dtype=np.float32) # With the defaults, it should limit output length based on input size recognizer = OpenAIWhisperSpeechRecognizer() From e5bf6fd1ccb95d8ad97820b2199193c0bec0e344 Mon Sep 17 00:00:00 2001 From: Twirre Meulenbelt <43213592+TwirreM@users.noreply.github.com> Date: Sun, 2 Nov 2025 20:40:32 +0100 
Subject: [PATCH 05/18] docs: update README instructions for git hooks Removed old advice from the README to configure git to add pre-commit hooks manually. We now have `pre-commit` for this, and they conflict. Added the command to install commit message hooks. ref: N25B-241 --- README.md | 16 +++------------- 1 file changed, 3 insertions(+), 13 deletions(-) diff --git a/README.md b/README.md index 45f8f98..a79f138 100644 --- a/README.md +++ b/README.md @@ -49,19 +49,9 @@ uv run --group integration-test pytest test/integration ## GitHooks -To activate automatic commits/branch name checks run: - -```shell -git config --local core.hooksPath .githooks -``` - -If your commit fails its either: -branch name != /description-of-branch , -commit name != : description of the commit. - : N25B-Num's - -To add automatic linting and formatting, run: +To activate automatic linting, formatting, branch name checks and commit message checks, run: ```shell uv run pre-commit install -``` \ No newline at end of file +uv run pre-commit install --hook-type commit-msg +``` From e025b146100987f9c78d37c8e57ea44fc8aeb683 Mon Sep 17 00:00:00 2001 From: Twirre Meulenbelt <43213592+TwirreM@users.noreply.github.com> Date: Sun, 2 Nov 2025 20:54:38 +0100 Subject: [PATCH 06/18] docs: add suggested fix for potential issue ref: N25B-241 --- README.md | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index a79f138..d20b36d 100644 --- a/README.md +++ b/README.md @@ -47,11 +47,19 @@ Or for integration tests: uv run --group integration-test pytest test/integration ``` -## GitHooks +## Git Hooks To activate automatic linting, formatting, branch name checks and commit message checks, run: -```shell +```bash uv run pre-commit install uv run pre-commit install --hook-type commit-msg ``` + +You might get an error along the lines of `Can't install pre-commit with core.hooksPath` set. To fix this, simply unset the hooksPath by running: + +```bash +git config --local --unset core.hooksPath +``` + +Then run the pre-commit install commands again. From 020bf55772f8b8a482c62151b39de537d908c841 Mon Sep 17 00:00:00 2001 From: Kasper Date: Sun, 2 Nov 2025 22:02:32 +0100 Subject: [PATCH 07/18] fix: automated commit detection ref: N25B-241 --- .githooks/check-commit-msg.sh | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/.githooks/check-commit-msg.sh b/.githooks/check-commit-msg.sh index 82bd441..f87749a 100755 --- a/.githooks/check-commit-msg.sh +++ b/.githooks/check-commit-msg.sh @@ -23,6 +23,37 @@ NC='\033[0m' # No Color # The first argument to the hook is the path to the file containing the commit message COMMIT_MSG_FILE=$1 +# --- Automated Commit Detection --- + +# Git directory (.git/) +GIT_DIR=$(git rev-parse --git-dir) +# Check for a merge commit +if [ -f "$GIT_DIR/MERGE_HEAD" ]; then + echo "Hook: Detected a merge commit." + # Ensure the message follows a 'Merge branch...' pattern. + first_line=$(head -n1 "$COMMIT_MSG_FILE") + if [[ ! "$first_line" =~ ^Merge.* ]]; then + echo "Error: Merge commit message should start with 'Merge'." >&2 + exit 1 + fi + exit 0 + +# Check for a squash commit (from git merge --squash) +elif [ -f "$GIT_DIR/SQUASH_MSG" ]; then + echo "Hook: Detected a squash commit. Skipping validation." + exit 0 + +# Check for a revert commit +elif [ -f "$GIT_DIR/REVERT_HEAD" ]; then + echo "Hook: Detected a revert commit. Skipping validation." 
+ exit 0 + +# Check for a cherry-pick commit +elif [ -f "$GIT_DIR/CHERRY_PICK_HEAD" ]; then + echo "Hook: Detected a cherry-pick commit. Skipping validation." + exit 0 +fi + # --- Validation Functions --- # Function to print an error message and exit From 0d5e198cad506d495093785d267599b223269a16 Mon Sep 17 00:00:00 2001 From: Kasper Marinus Date: Mon, 3 Nov 2025 14:51:18 +0100 Subject: [PATCH 08/18] fix: pattern matching instead of file existence The previous method of detecting automated commits was error-prone, specifically when using VSCode to commit changes. This new method uses a simple Regex pattern match to see if the commit message matches any known auto-generated commits. ref: N25B-241 --- .githooks/check-commit-msg.sh | 43 +++++++++++++++++------------------ 1 file changed, 21 insertions(+), 22 deletions(-) diff --git a/.githooks/check-commit-msg.sh b/.githooks/check-commit-msg.sh index f87749a..6fbc251 100755 --- a/.githooks/check-commit-msg.sh +++ b/.githooks/check-commit-msg.sh @@ -1,4 +1,4 @@ -#!/bin/bash +#!/usr/bin/env bash # This script checks if a commit message follows the specified format. # It's designed to be used as a 'commit-msg' git hook. @@ -25,32 +25,31 @@ COMMIT_MSG_FILE=$1 # --- Automated Commit Detection --- -# Git directory (.git/) -GIT_DIR=$(git rev-parse --git-dir) -# Check for a merge commit -if [ -f "$GIT_DIR/MERGE_HEAD" ]; then - echo "Hook: Detected a merge commit." - # Ensure the message follows a 'Merge branch...' pattern. - first_line=$(head -n1 "$COMMIT_MSG_FILE") - if [[ ! "$first_line" =~ ^Merge.* ]]; then - echo "Error: Merge commit message should start with 'Merge'." >&2 - exit 1 - fi - exit 0 +# Read the first line (header) for initial checks +HEADER=$(head -n 1 "$COMMIT_MSG_FILE") -# Check for a squash commit (from git merge --squash) -elif [ -f "$GIT_DIR/SQUASH_MSG" ]; then - echo "Hook: Detected a squash commit. Skipping validation." +# Check for Merge commits (covers 'git merge' and PR merges from GitHub/GitLab) +# Examples: "Merge branch 'main' into ...", "Merge pull request #123 from ..." +MERGE_PATTERN="^Merge (branch|pull request|tag) .*" +if [[ "$HEADER" =~ $MERGE_PATTERN ]]; then + echo -e "${GREEN}Merge commit detected by message content. Skipping validation.${NC}" exit 0 +fi -# Check for a revert commit -elif [ -f "$GIT_DIR/REVERT_HEAD" ]; then - echo "Hook: Detected a revert commit. Skipping validation." +# Check for Revert commits +# Example: "Revert "feat: add new feature"" +REVERT_PATTERN="^Revert \".*\"" +if [[ "$HEADER" =~ $REVERT_PATTERN ]]; then + echo -e "${GREEN}Revert commit detected by message content. Skipping validation.${NC}" exit 0 +fi -# Check for a cherry-pick commit -elif [ -f "$GIT_DIR/CHERRY_PICK_HEAD" ]; then - echo "Hook: Detected a cherry-pick commit. Skipping validation." +# Check for Cherry-pick commits (this pattern appears at the end of the message) +# Example: "(cherry picked from commit deadbeef...)" +# We use grep -q to search the whole file quietly. +CHERRY_PICK_PATTERN="\(cherry picked from commit [a-f0-9]{7,40}\)" +if grep -qE "$CHERRY_PICK_PATTERN" "$COMMIT_MSG_FILE"; then + echo -e "${GREEN}Cherry-pick detected by message content. Skipping validation.${NC}" exit 0 fi From 3c8cee54eb15a9ddd6a795ab3d0cf7beee078820 Mon Sep 17 00:00:00 2001 From: Kasper Marinus Date: Mon, 3 Nov 2025 14:51:18 +0100 Subject: [PATCH 09/18] fix: pattern matching instead of file existence The previous method of detecting automated commits was error-prone, specifically when using VSCode to commit changes. 
This new method uses a simple Regex pattern match to see if the commit message matches any known auto-generated commits. ref: N25B-241 --- .githooks/check-branch-name.sh | 2 +- .githooks/check-commit-msg.sh | 43 +++++++++++++++++----------------- 2 files changed, 22 insertions(+), 23 deletions(-) diff --git a/.githooks/check-branch-name.sh b/.githooks/check-branch-name.sh index 752e199..0e71c9b 100755 --- a/.githooks/check-branch-name.sh +++ b/.githooks/check-branch-name.sh @@ -1,4 +1,4 @@ -#!/bin/bash +#!/usr/bin/env bash # This script checks if the current branch name follows the specified format. # It's designed to be used as a 'pre-commit' git hook. diff --git a/.githooks/check-commit-msg.sh b/.githooks/check-commit-msg.sh index f87749a..6fbc251 100755 --- a/.githooks/check-commit-msg.sh +++ b/.githooks/check-commit-msg.sh @@ -1,4 +1,4 @@ -#!/bin/bash +#!/usr/bin/env bash # This script checks if a commit message follows the specified format. # It's designed to be used as a 'commit-msg' git hook. @@ -25,32 +25,31 @@ COMMIT_MSG_FILE=$1 # --- Automated Commit Detection --- -# Git directory (.git/) -GIT_DIR=$(git rev-parse --git-dir) -# Check for a merge commit -if [ -f "$GIT_DIR/MERGE_HEAD" ]; then - echo "Hook: Detected a merge commit." - # Ensure the message follows a 'Merge branch...' pattern. - first_line=$(head -n1 "$COMMIT_MSG_FILE") - if [[ ! "$first_line" =~ ^Merge.* ]]; then - echo "Error: Merge commit message should start with 'Merge'." >&2 - exit 1 - fi - exit 0 +# Read the first line (header) for initial checks +HEADER=$(head -n 1 "$COMMIT_MSG_FILE") -# Check for a squash commit (from git merge --squash) -elif [ -f "$GIT_DIR/SQUASH_MSG" ]; then - echo "Hook: Detected a squash commit. Skipping validation." +# Check for Merge commits (covers 'git merge' and PR merges from GitHub/GitLab) +# Examples: "Merge branch 'main' into ...", "Merge pull request #123 from ..." +MERGE_PATTERN="^Merge (branch|pull request|tag) .*" +if [[ "$HEADER" =~ $MERGE_PATTERN ]]; then + echo -e "${GREEN}Merge commit detected by message content. Skipping validation.${NC}" exit 0 +fi -# Check for a revert commit -elif [ -f "$GIT_DIR/REVERT_HEAD" ]; then - echo "Hook: Detected a revert commit. Skipping validation." +# Check for Revert commits +# Example: "Revert "feat: add new feature"" +REVERT_PATTERN="^Revert \".*\"" +if [[ "$HEADER" =~ $REVERT_PATTERN ]]; then + echo -e "${GREEN}Revert commit detected by message content. Skipping validation.${NC}" exit 0 +fi -# Check for a cherry-pick commit -elif [ -f "$GIT_DIR/CHERRY_PICK_HEAD" ]; then - echo "Hook: Detected a cherry-pick commit. Skipping validation." +# Check for Cherry-pick commits (this pattern appears at the end of the message) +# Example: "(cherry picked from commit deadbeef...)" +# We use grep -q to search the whole file quietly. +CHERRY_PICK_PATTERN="\(cherry picked from commit [a-f0-9]{7,40}\)" +if grep -qE "$CHERRY_PICK_PATTERN" "$COMMIT_MSG_FILE"; then + echo -e "${GREEN}Cherry-pick detected by message content. Skipping validation.${NC}" exit 0 fi From 360f601d007041324aa36a727da598d52b6ce68a Mon Sep 17 00:00:00 2001 From: Kasper Marinus Date: Mon, 3 Nov 2025 15:23:11 +0100 Subject: [PATCH 10/18] feat: chore doesn't need ref If we detect a chore commit, we don't check for the correct ref/close footer. 
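For example, a hypothetical `chore: run ruff format` commit would now pass without any footer, while `feat:`/`fix:` commits still have to end with a line matching the footer pattern, e.g. `ref: N25B-123`.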
ref: N25B-241 --- .githooks/check-commit-msg.sh | 28 ++++++++++++++++------------ 1 file changed, 16 insertions(+), 12 deletions(-) diff --git a/.githooks/check-commit-msg.sh b/.githooks/check-commit-msg.sh index 6fbc251..cdf56fb 100755 --- a/.githooks/check-commit-msg.sh +++ b/.githooks/check-commit-msg.sh @@ -86,20 +86,24 @@ if ! [[ "$HEADER" =~ $HEADER_REGEX ]]; then error_exit "Invalid header format.\n\nHeader must be in the format: : \nAllowed types: ${ALLOWED_TYPES[*]}\nExample: feat: add new user authentication feature" fi -# 3. Validate the footer (last line) of the commit message -FOOTER=$(tail -n 1 "$COMMIT_MSG_FILE") +# Only validate footer if commit type is not chore +TYPE=$(echo "$HEADER" | cut -d':' -f1) +if [ "$TYPE" != "chore" ]; then + # 3. Validate the footer (last line) of the commit message + FOOTER=$(tail -n 1 "$COMMIT_MSG_FILE") -# Regex breakdown: -# ^(ref|close) - Starts with 'ref' or 'close' -# : - Followed by a literal colon -# \s - Followed by a single space -# N25B- - Followed by the literal string 'N25B-' -# [0-9]+ - Followed by one or more digits -# $ - End of the line -FOOTER_REGEX="^(ref|close): N25B-[0-9]+$" + # Regex breakdown: + # ^(ref|close) - Starts with 'ref' or 'close' + # : - Followed by a literal colon + # \s - Followed by a single space + # N25B- - Followed by the literal string 'N25B-' + # [0-9]+ - Followed by one or more digits + # $ - End of the line + FOOTER_REGEX="^(ref|close): N25B-[0-9]+$" -if ! [[ "$FOOTER" =~ $FOOTER_REGEX ]]; then - error_exit "Invalid footer format.\n\nFooter must be in the format: [ref/close]: \nExample: ref: N25B-123" + if ! [[ "$FOOTER" =~ $FOOTER_REGEX ]]; then + error_exit "Invalid footer format.\n\nFooter must be in the format: [ref/close]: \nExample: ref: N25B-123" + fi fi # 4. If the message has more than 2 lines, validate the separator From cb5457b6be1457a7e6fb118d0c7f8164e327cd6b Mon Sep 17 00:00:00 2001 From: Kasper Marinus Date: Mon, 3 Nov 2025 15:29:06 +0100 Subject: [PATCH 11/18] feat: check for squash commits ref: N25B-241 --- .githooks/check-commit-msg.sh | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/.githooks/check-commit-msg.sh b/.githooks/check-commit-msg.sh index cdf56fb..eacf2a8 100755 --- a/.githooks/check-commit-msg.sh +++ b/.githooks/check-commit-msg.sh @@ -53,6 +53,14 @@ if grep -qE "$CHERRY_PICK_PATTERN" "$COMMIT_MSG_FILE"; then exit 0 fi +# Check for Squash +# Example: "Squash commits ..." +SQUASH_PATTERN="^Squash .+" +if [[ "$HEADER" =~ $SQUASH_PATTERN ]]; then + echo -e "${GREEN}Squash commit detected by message content. Skipping validation.${NC}" + exit 0 +fi + # --- Validation Functions --- # Function to print an error message and exit From 5c228df1094454443edd8dd79b192c9b4f89edc6 Mon Sep 17 00:00:00 2001 From: Twirre Meulenbelt <43213592+TwirreM@users.noreply.github.com> Date: Wed, 5 Nov 2025 10:41:11 +0100 Subject: [PATCH 12/18] fix: allow Whisper to generate more tokens based on audio length Before, it sometimes cut off the transcription too early. 
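As a rough sanity check with the new numbers: one minute of 16 kHz audio is budgeted at 450 words, i.e. 450 * 4 / 3 = 600 tokens, plus the 10-token pad, for 610 in total, which is exactly what the updated unit test later in this series expects.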
ref: N25B-209 --- .../agents/transcription/speech_recognizer.py | 17 ++++++++++++----- .../agents/transcription/transcription_agent.py | 4 ++++ 2 files changed, 16 insertions(+), 5 deletions(-) diff --git a/src/control_backend/agents/transcription/speech_recognizer.py b/src/control_backend/agents/transcription/speech_recognizer.py index 45e42bf..9e61fd7 100644 --- a/src/control_backend/agents/transcription/speech_recognizer.py +++ b/src/control_backend/agents/transcription/speech_recognizer.py @@ -36,16 +36,16 @@ class SpeechRecognizer(abc.ABC): def _estimate_max_tokens(audio: np.ndarray) -> int: """ Estimate the maximum length of a given audio sample in tokens. Assumes a maximum speaking - rate of 300 words per minute (2x average), and assumes that 3 words is 4 tokens. + rate of 450 words per minute (3x average), and assumes that 3 words is 4 tokens. :param audio: The audio sample (16 kHz) to use for length estimation. :return: The estimated length of the transcribed audio in tokens. """ length_seconds = len(audio) / 16_000 length_minutes = length_seconds / 60 - word_count = length_minutes * 300 + word_count = length_minutes * 450 token_count = word_count / 3 * 4 - return int(token_count) + return int(token_count) + 10 def _get_decode_options(self, audio: np.ndarray) -> dict: """ @@ -84,7 +84,12 @@ class MLXWhisperSpeechRecognizer(SpeechRecognizer): def recognize_speech(self, audio: np.ndarray) -> str: self.load_model() - return mlx_whisper.transcribe(audio, path_or_hf_repo=self.model_name)["text"] + return mlx_whisper.transcribe( + audio, + path_or_hf_repo=self.model_name, + initial_prompt="You're a robot called Pepper, talking with a person called Twirre.", + **self._get_decode_options(audio), + )["text"].strip() class OpenAIWhisperSpeechRecognizer(SpeechRecognizer): @@ -101,5 +106,7 @@ class OpenAIWhisperSpeechRecognizer(SpeechRecognizer): def recognize_speech(self, audio: np.ndarray) -> str: self.load_model() return whisper.transcribe( - self.model, audio, decode_options=self._get_decode_options(audio) + self.model, + audio, + **self._get_decode_options(audio) )["text"] diff --git a/src/control_backend/agents/transcription/transcription_agent.py b/src/control_backend/agents/transcription/transcription_agent.py index 2d936c4..196fd28 100644 --- a/src/control_backend/agents/transcription/transcription_agent.py +++ b/src/control_backend/agents/transcription/transcription_agent.py @@ -59,6 +59,10 @@ class TranscriptionAgent(Agent): audio = await self.audio_in_socket.recv() audio = np.frombuffer(audio, dtype=np.float32) speech = await self._transcribe(audio) + if not speech: + logger.info("Nothing transcribed.") + return + logger.info("Transcribed speech: %s", speech) await self._share_transcription(speech) From b0085625541f059866cc8a01c2d7d6a31c932229 Mon Sep 17 00:00:00 2001 From: Twirre Meulenbelt <43213592+TwirreM@users.noreply.github.com> Date: Wed, 5 Nov 2025 12:08:07 +0100 Subject: [PATCH 13/18] fix: tests To work with the new zmq instance context. 
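Concretely, each affected test module now gets a small fixture that patches the singleton context factory (a sketch of the pattern used in the diffs below; it relies on pytest-mock's `mocker` and `unittest.mock.MagicMock`):

    @pytest.fixture
    def zmq_context(mocker):
        mock_context = mocker.patch("control_backend.agents.vad_agent.azmq.Context.instance")
        mock_context.return_value = MagicMock()
        return mock_context

Assertions then go through `zmq_context.return_value.socket.return_value` instead of monkeypatching a module-level `context.socket`.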
ref: N25B-217 --- .../agents/test_ri_commands_agent.py | 33 +++--- .../agents/test_ri_communication_agent.py | 100 +++++------------- .../agents/vad_agent/test_vad_agent.py | 22 ++-- .../api/endpoints/test_command_endpoint.py | 40 +------ 4 files changed, 61 insertions(+), 134 deletions(-) diff --git a/test/integration/agents/test_ri_commands_agent.py b/test/integration/agents/test_ri_commands_agent.py index 219d682..15498e3 100644 --- a/test/integration/agents/test_ri_commands_agent.py +++ b/test/integration/agents/test_ri_commands_agent.py @@ -7,19 +7,21 @@ from control_backend.agents.ri_command_agent import RICommandAgent from control_backend.schemas.ri_message import SpeechCommand +@pytest.fixture +def zmq_context(mocker): + mock_context = mocker.patch("control_backend.agents.vad_agent.azmq.Context.instance") + mock_context.return_value = MagicMock() + return mock_context + + @pytest.mark.asyncio -async def test_setup_bind(monkeypatch): +async def test_setup_bind(zmq_context, mocker): """Test setup with bind=True""" - fake_socket = MagicMock() - monkeypatch.setattr( - "control_backend.agents.ri_command_agent.context.socket", lambda _: fake_socket - ) + fake_socket = zmq_context.return_value.socket.return_value agent = RICommandAgent("test@server", "password", address="tcp://localhost:5555", bind=True) - monkeypatch.setattr( - "control_backend.agents.ri_command_agent.settings", - MagicMock(zmq_settings=MagicMock(internal_comm_address="tcp://internal:1234")), - ) + settings = mocker.patch("control_backend.agents.ri_command_agent.settings") + settings.zmq_settings.internal_sub_address = "tcp://internal:1234" await agent.setup() @@ -34,18 +36,13 @@ async def test_setup_bind(monkeypatch): @pytest.mark.asyncio -async def test_setup_connect(monkeypatch): +async def test_setup_connect(zmq_context, mocker): """Test setup with bind=False""" - fake_socket = MagicMock() - monkeypatch.setattr( - "control_backend.agents.ri_command_agent.context.socket", lambda _: fake_socket - ) + fake_socket = zmq_context.return_value.socket.return_value agent = RICommandAgent("test@server", "password", address="tcp://localhost:5555", bind=False) - monkeypatch.setattr( - "control_backend.agents.ri_command_agent.settings", - MagicMock(zmq_settings=MagicMock(internal_comm_address="tcp://internal:1234")), - ) + settings = mocker.patch("control_backend.agents.ri_command_agent.settings") + settings.zmq_settings.internal_sub_address = "tcp://internal:1234" await agent.setup() diff --git a/test/integration/agents/test_ri_communication_agent.py b/test/integration/agents/test_ri_communication_agent.py index 3e4a056..a641c61 100644 --- a/test/integration/agents/test_ri_communication_agent.py +++ b/test/integration/agents/test_ri_communication_agent.py @@ -82,21 +82,23 @@ def fake_json_invalid_id_negototiate(): ) +@pytest.fixture +def zmq_context(mocker): + mock_context = mocker.patch("control_backend.agents.vad_agent.azmq.Context.instance") + mock_context.return_value = MagicMock() + return mock_context + + @pytest.mark.asyncio -async def test_setup_creates_socket_and_negotiate_1(monkeypatch): +async def test_setup_creates_socket_and_negotiate_1(zmq_context): """ Test the setup of the communication agent """ # --- Arrange --- - fake_socket = MagicMock() + fake_socket = zmq_context.return_value.socket.return_value fake_socket.send_json = AsyncMock() fake_socket.recv_json = fake_json_correct_negototiate_1() - # Mock context.socket to return our fake socket - monkeypatch.setattr( - 
"control_backend.agents.ri_communication_agent.context.socket", lambda _: fake_socket - ) - # Mock RICommandAgent agent startup with patch( "control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True @@ -126,20 +128,15 @@ async def test_setup_creates_socket_and_negotiate_1(monkeypatch): @pytest.mark.asyncio -async def test_setup_creates_socket_and_negotiate_2(monkeypatch): +async def test_setup_creates_socket_and_negotiate_2(zmq_context): """ Test the setup of the communication agent """ # --- Arrange --- - fake_socket = MagicMock() + fake_socket = zmq_context.return_value.socket.return_value fake_socket.send_json = AsyncMock() fake_socket.recv_json = fake_json_correct_negototiate_2() - # Mock context.socket to return our fake socket - monkeypatch.setattr( - "control_backend.agents.ri_communication_agent.context.socket", lambda _: fake_socket - ) - # Mock RICommandAgent agent startup with patch( "control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True @@ -169,20 +166,15 @@ async def test_setup_creates_socket_and_negotiate_2(monkeypatch): @pytest.mark.asyncio -async def test_setup_creates_socket_and_negotiate_3(monkeypatch, caplog): +async def test_setup_creates_socket_and_negotiate_3(zmq_context, caplog): """ Test the functionality of setup with incorrect negotiation message """ # --- Arrange --- - fake_socket = MagicMock() + fake_socket = zmq_context.return_value.socket.return_value fake_socket.send_json = AsyncMock() fake_socket.recv_json = fake_json_wrong_negototiate_1() - # Mock context.socket to return our fake socket - monkeypatch.setattr( - "control_backend.agents.ri_communication_agent.context.socket", lambda _: fake_socket - ) - # Mock RICommandAgent agent startup # We are sending wrong negotiation info to the communication agent, so we should retry and expect a @@ -213,20 +205,15 @@ async def test_setup_creates_socket_and_negotiate_3(monkeypatch, caplog): @pytest.mark.asyncio -async def test_setup_creates_socket_and_negotiate_4(monkeypatch): +async def test_setup_creates_socket_and_negotiate_4(zmq_context): """ Test the setup of the communication agent with different bind value """ # --- Arrange --- - fake_socket = MagicMock() + fake_socket = zmq_context.return_value.socket.return_value fake_socket.send_json = AsyncMock() fake_socket.recv_json = fake_json_correct_negototiate_3() - # Mock context.socket to return our fake socket - monkeypatch.setattr( - "control_backend.agents.ri_communication_agent.context.socket", lambda _: fake_socket - ) - # Mock RICommandAgent agent startup with patch( "control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True @@ -256,20 +243,15 @@ async def test_setup_creates_socket_and_negotiate_4(monkeypatch): @pytest.mark.asyncio -async def test_setup_creates_socket_and_negotiate_5(monkeypatch): +async def test_setup_creates_socket_and_negotiate_5(zmq_context): """ Test the setup of the communication agent """ # --- Arrange --- - fake_socket = MagicMock() + fake_socket = zmq_context.return_value.socket.return_value fake_socket.send_json = AsyncMock() fake_socket.recv_json = fake_json_correct_negototiate_4() - # Mock context.socket to return our fake socket - monkeypatch.setattr( - "control_backend.agents.ri_communication_agent.context.socket", lambda _: fake_socket - ) - # Mock RICommandAgent agent startup with patch( "control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True @@ -299,20 +281,15 @@ async def test_setup_creates_socket_and_negotiate_5(monkeypatch): 
@pytest.mark.asyncio -async def test_setup_creates_socket_and_negotiate_6(monkeypatch): +async def test_setup_creates_socket_and_negotiate_6(zmq_context): """ Test the setup of the communication agent """ # --- Arrange --- - fake_socket = MagicMock() + fake_socket = zmq_context.return_value.socket.return_value fake_socket.send_json = AsyncMock() fake_socket.recv_json = fake_json_correct_negototiate_5() - # Mock context.socket to return our fake socket - monkeypatch.setattr( - "control_backend.agents.ri_communication_agent.context.socket", lambda _: fake_socket - ) - # Mock RICommandAgent agent startup with patch( "control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True @@ -342,20 +319,15 @@ async def test_setup_creates_socket_and_negotiate_6(monkeypatch): @pytest.mark.asyncio -async def test_setup_creates_socket_and_negotiate_7(monkeypatch, caplog): +async def test_setup_creates_socket_and_negotiate_7(zmq_context, caplog): """ Test the functionality of setup with incorrect id """ # --- Arrange --- - fake_socket = MagicMock() + fake_socket = zmq_context.return_value.socket.return_value fake_socket.send_json = AsyncMock() fake_socket.recv_json = fake_json_invalid_id_negototiate() - # Mock context.socket to return our fake socket - monkeypatch.setattr( - "control_backend.agents.ri_communication_agent.context.socket", lambda _: fake_socket - ) - # Mock RICommandAgent agent startup # We are sending wrong negotiation info to the communication agent, so we should retry and expect a @@ -383,20 +355,15 @@ async def test_setup_creates_socket_and_negotiate_7(monkeypatch, caplog): @pytest.mark.asyncio -async def test_setup_creates_socket_and_negotiate_timeout(monkeypatch, caplog): +async def test_setup_creates_socket_and_negotiate_timeout(zmq_context, caplog): """ Test the functionality of setup with incorrect negotiation message """ # --- Arrange --- - fake_socket = MagicMock() + fake_socket = zmq_context.return_value.socket.return_value fake_socket.send_json = AsyncMock() fake_socket.recv_json = AsyncMock(side_effect=asyncio.TimeoutError) - # Mock context.socket to return our fake socket - monkeypatch.setattr( - "control_backend.agents.ri_communication_agent.context.socket", lambda _: fake_socket - ) - with patch( "control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True ) as MockCommandAgent: @@ -478,8 +445,8 @@ async def test_listen_behaviour_ping_wrong_endpoint(caplog): @pytest.mark.asyncio -async def test_listen_behaviour_timeout(caplog): - fake_socket = AsyncMock() +async def test_listen_behaviour_timeout(zmq_context, caplog): + fake_socket = zmq_context.return_value.socket.return_value fake_socket.send_json = AsyncMock() # recv_json will never resolve, simulate timeout fake_socket.recv_json = AsyncMock(side_effect=asyncio.TimeoutError) @@ -527,16 +494,12 @@ async def test_listen_behaviour_ping_no_endpoint(caplog): @pytest.mark.asyncio -async def test_setup_unexpected_exception(monkeypatch, caplog): - fake_socket = MagicMock() +async def test_setup_unexpected_exception(zmq_context, caplog): + fake_socket = zmq_context.return_value.socket.return_value fake_socket.send_json = AsyncMock() # Simulate unexpected exception during recv_json() fake_socket.recv_json = AsyncMock(side_effect=Exception("boom!")) - monkeypatch.setattr( - "control_backend.agents.ri_communication_agent.context.socket", lambda _: fake_socket - ) - agent = RICommunicationAgent( "test@server", "password", address="tcp://localhost:5555", bind=False ) @@ -549,9 +512,9 @@ async def 
test_setup_unexpected_exception(monkeypatch, caplog): @pytest.mark.asyncio -async def test_setup_unpacking_exception(monkeypatch, caplog): +async def test_setup_unpacking_exception(zmq_context, caplog): # --- Arrange --- - fake_socket = MagicMock() + fake_socket = zmq_context.return_value.socket.return_value fake_socket.send_json = AsyncMock() # Make recv_json return malformed negotiation data to trigger unpacking exception @@ -561,11 +524,6 @@ async def test_setup_unpacking_exception(monkeypatch, caplog): } # missing 'port' and 'bind' fake_socket.recv_json = AsyncMock(return_value=malformed_data) - # Patch context.socket - monkeypatch.setattr( - "control_backend.agents.ri_communication_agent.context.socket", lambda _: fake_socket - ) - # Patch RICommandAgent so it won't actually start with patch( "control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True diff --git a/test/integration/agents/vad_agent/test_vad_agent.py b/test/integration/agents/vad_agent/test_vad_agent.py index 54c9d82..7d8e173 100644 --- a/test/integration/agents/vad_agent/test_vad_agent.py +++ b/test/integration/agents/vad_agent/test_vad_agent.py @@ -10,7 +10,9 @@ from control_backend.agents.vad_agent import VADAgent @pytest.fixture def zmq_context(mocker): - return mocker.patch("control_backend.agents.vad_agent.zmq_context") + mock_context = mocker.patch("control_backend.agents.vad_agent.azmq.Context.instance") + mock_context.return_value = MagicMock() + return mock_context @pytest.fixture @@ -54,13 +56,13 @@ def test_in_socket_creation(zmq_context, do_bind: bool): assert vad_agent.audio_in_socket is not None - zmq_context.socket.assert_called_once_with(zmq.SUB) - zmq_context.socket.return_value.setsockopt_string.assert_called_once_with(zmq.SUBSCRIBE, "") + zmq_context.return_value.socket.assert_called_once_with(zmq.SUB) + zmq_context.return_value.socket.return_value.setsockopt_string.assert_called_once_with(zmq.SUBSCRIBE, "") if do_bind: - zmq_context.socket.return_value.bind.assert_called_once_with("tcp://*:12345") + zmq_context.return_value.socket.return_value.bind.assert_called_once_with("tcp://*:12345") else: - zmq_context.socket.return_value.connect.assert_called_once_with("tcp://localhost:12345") + zmq_context.return_value.socket.return_value.connect.assert_called_once_with("tcp://localhost:12345") def test_out_socket_creation(zmq_context): @@ -73,8 +75,8 @@ def test_out_socket_creation(zmq_context): assert vad_agent.audio_out_socket is not None - zmq_context.socket.assert_called_once_with(zmq.PUB) - zmq_context.socket.return_value.bind_to_random_port.assert_called_once() + zmq_context.return_value.socket.assert_called_once_with(zmq.PUB) + zmq_context.return_value.socket.return_value.bind_to_random_port.assert_called_once() @pytest.mark.asyncio @@ -83,7 +85,7 @@ async def test_out_socket_creation_failure(zmq_context): Test setup failure when the audio output socket cannot be created. """ with patch.object(Agent, "stop", new_callable=AsyncMock) as mock_super_stop: - zmq_context.socket.return_value.bind_to_random_port.side_effect = zmq.ZMQBindError + zmq_context.return_value.socket.return_value.bind_to_random_port.side_effect = zmq.ZMQBindError vad_agent = VADAgent("tcp://localhost:12345", False) await vad_agent.setup() @@ -98,11 +100,11 @@ async def test_stop(zmq_context, transcription_agent): Test that when the VAD agent is stopped, the sockets are closed correctly. 
""" vad_agent = VADAgent("tcp://localhost:12345", False) - zmq_context.socket.return_value.bind_to_random_port.return_value = random.randint(1000, 10000) + zmq_context.return_value.socket.return_value.bind_to_random_port.return_value = random.randint(1000, 10000) await vad_agent.setup() await vad_agent.stop() - assert zmq_context.socket.return_value.close.call_count == 2 + assert zmq_context.return_value.socket.return_value.close.call_count == 2 assert vad_agent.audio_in_socket is None assert vad_agent.audio_out_socket is None diff --git a/test/integration/api/endpoints/test_command_endpoint.py b/test/integration/api/endpoints/test_command_endpoint.py index 7e38924..8ecf816 100644 --- a/test/integration/api/endpoints/test_command_endpoint.py +++ b/test/integration/api/endpoints/test_command_endpoint.py @@ -26,8 +26,8 @@ def client(app): @pytest.mark.asyncio -@patch("control_backend.api.endpoints.command.Context.instance") -async def test_receive_command_success(mock_context_instance, async_client): +@patch("control_backend.api.v1.endpoints.command.Context.instance") +async def test_receive_command_success(mock_context_instance, client): """ Test for successful reception of a command. Ensures the status code is 202 and the response body is correct. @@ -35,54 +35,24 @@ async def test_receive_command_success(mock_context_instance, async_client): """ # Arrange mock_pub_socket = AsyncMock() - mock_context_instance.return_value.socket.return_value = mock_pub_socket + client.app.state.endpoints_pub_socket = mock_pub_socket - command_data = {"command": "test_command", "text": "This is a test"} + command_data = {"endpoint": "actuate/speech", "data": "This is a test"} speech_command = SpeechCommand(**command_data) # Act - response = await async_client.post("/command", json=command_data) + response = client.post("/command", json=command_data) # Assert assert response.status_code == 202 assert response.json() == {"status": "Command received"} # Verify that the ZMQ socket was used correctly - mock_context_instance.return_value.socket.assert_called_once_with(1) # zmq.PUB is 1 - mock_pub_socket.connect.assert_called_once() mock_pub_socket.send_multipart.assert_awaited_once_with( [b"command", speech_command.model_dump_json().encode()] ) -def test_receive_command_endpoint(client, app, mocker): - """ - Test that a POST to /command sends the right multipart message - and returns a 202 with the expected JSON body. - """ - mock_socket = mocker.patch.object() - - # Prepare test payload that matches SpeechCommand - payload = {"endpoint": "actuate/speech", "data": "yooo"} - - # Send POST request - response = client.post("/command", json=payload) - - # Check response - assert response.status_code == 202 - assert response.json() == {"status": "Command received"} - - # Verify that the socket was called with the correct data - assert mock_socket.send_multipart.called, "Socket should be used to send data" - - args, kwargs = mock_socket.send_multipart.call_args - sent_data = args[0] - - assert sent_data[0] == b"command" - # Check JSON encoding roughly matches - assert isinstance(SpeechCommand.model_validate_json(sent_data[1].decode()), SpeechCommand) - - def test_receive_command_invalid_payload(client): """ Test invalid data handling (schema validation). 
From 2c867adce2641007b4e200d540dd652ffd79b5f2 Mon Sep 17 00:00:00 2001 From: Twirre Meulenbelt <43213592+TwirreM@users.noreply.github.com> Date: Wed, 5 Nov 2025 12:22:42 +0100 Subject: [PATCH 14/18] fix: go back to the working ri command endpoint test Merged the wrong version because it seemed to solve the same problem. It did not. Now using the one I commited two commits ago. ref: N25B-217 --- .../api/endpoints/test_command_endpoint.py | 37 +++++++++---------- 1 file changed, 17 insertions(+), 20 deletions(-) diff --git a/test/integration/api/endpoints/test_command_endpoint.py b/test/integration/api/endpoints/test_command_endpoint.py index 04890c1..c343f0c 100644 --- a/test/integration/api/endpoints/test_command_endpoint.py +++ b/test/integration/api/endpoints/test_command_endpoint.py @@ -1,4 +1,4 @@ -from unittest.mock import MagicMock +from unittest.mock import AsyncMock, patch import pytest from fastapi import FastAPI @@ -16,7 +16,6 @@ def app(): """ app = FastAPI() app.include_router(command.router) - app.state.internal_comm_socket = MagicMock() # mock ZMQ socket return app @@ -26,32 +25,30 @@ def client(app): return TestClient(app) -def test_receive_command_endpoint(client, app): +def test_receive_command_success(client): """ - Test that a POST to /command sends the right multipart message - and returns a 202 with the expected JSON body. + Test for successful reception of a command. + Ensures the status code is 202 and the response body is correct. + It also verifies that the ZeroMQ socket's send_multipart method is called with the expected data. """ - mock_socket = app.state.internal_comm_socket + # Arrange + mock_pub_socket = AsyncMock() + client.app.state.endpoints_pub_socket = mock_pub_socket - # Prepare test payload that matches SpeechCommand - payload = {"endpoint": "actuate/speech", "data": "yooo"} + command_data = {"endpoint": "actuate/speech", "data": "This is a test"} + speech_command = SpeechCommand(**command_data) - # Send POST request - response = client.post("/command", json=payload) + # Act + response = client.post("/command", json=command_data) - # Check response + # Assert assert response.status_code == 202 assert response.json() == {"status": "Command received"} - # Verify that the socket was called with the correct data - assert mock_socket.send_multipart.called, "Socket should be used to send data" - - args, kwargs = mock_socket.send_multipart.call_args - sent_data = args[0] - - assert sent_data[0] == b"command" - # Check JSON encoding roughly matches - assert isinstance(SpeechCommand.model_validate_json(sent_data[1].decode()), SpeechCommand) + # Verify that the ZMQ socket was used correctly + mock_pub_socket.send_multipart.assert_awaited_once_with( + [b"command", speech_command.model_dump_json().encode()] + ) def test_receive_command_invalid_payload(client): From f854a60e46d2805baea1efaf5821f4ff03504f8e Mon Sep 17 00:00:00 2001 From: Twirre Meulenbelt <43213592+TwirreM@users.noreply.github.com> Date: Wed, 5 Nov 2025 12:34:30 +0100 Subject: [PATCH 15/18] style: import order and lines too long ref: N25B-217 --- src/control_backend/agents/ri_command_agent.py | 2 +- .../agents/ri_communication_agent.py | 2 +- src/control_backend/main.py | 1 - .../agents/vad_agent/test_vad_agent.py | 18 ++++++++++++++---- .../api/endpoints/test_command_endpoint.py | 8 ++++---- 5 files changed, 20 insertions(+), 11 deletions(-) diff --git a/src/control_backend/agents/ri_command_agent.py b/src/control_backend/agents/ri_command_agent.py index 09c4299..0dcc981 100644 --- 
a/src/control_backend/agents/ri_command_agent.py +++ b/src/control_backend/agents/ri_command_agent.py @@ -1,9 +1,9 @@ import json import logging +import zmq from spade.agent import Agent from spade.behaviour import CyclicBehaviour -import zmq from zmq.asyncio import Context from control_backend.core.config import settings diff --git a/src/control_backend/agents/ri_communication_agent.py b/src/control_backend/agents/ri_communication_agent.py index 20b2a4b..638b967 100644 --- a/src/control_backend/agents/ri_communication_agent.py +++ b/src/control_backend/agents/ri_communication_agent.py @@ -1,9 +1,9 @@ import asyncio import logging +import zmq from spade.agent import Agent from spade.behaviour import CyclicBehaviour -import zmq from zmq.asyncio import Context from control_backend.agents.ri_command_agent import RICommandAgent diff --git a/src/control_backend/main.py b/src/control_backend/main.py index f1cdfa6..29f1396 100644 --- a/src/control_backend/main.py +++ b/src/control_backend/main.py @@ -19,7 +19,6 @@ from control_backend.agents.vad_agent import VADAgent from control_backend.api.v1.router import api_router from control_backend.core.config import settings - logger = logging.getLogger(__name__) logging.basicConfig(level=logging.DEBUG) diff --git a/test/integration/agents/vad_agent/test_vad_agent.py b/test/integration/agents/vad_agent/test_vad_agent.py index 7d8e173..0e1fae2 100644 --- a/test/integration/agents/vad_agent/test_vad_agent.py +++ b/test/integration/agents/vad_agent/test_vad_agent.py @@ -57,12 +57,17 @@ def test_in_socket_creation(zmq_context, do_bind: bool): assert vad_agent.audio_in_socket is not None zmq_context.return_value.socket.assert_called_once_with(zmq.SUB) - zmq_context.return_value.socket.return_value.setsockopt_string.assert_called_once_with(zmq.SUBSCRIBE, "") + zmq_context.return_value.socket.return_value.setsockopt_string.assert_called_once_with( + zmq.SUBSCRIBE, + "", + ) if do_bind: zmq_context.return_value.socket.return_value.bind.assert_called_once_with("tcp://*:12345") else: - zmq_context.return_value.socket.return_value.connect.assert_called_once_with("tcp://localhost:12345") + zmq_context.return_value.socket.return_value.connect.assert_called_once_with( + "tcp://localhost:12345" + ) def test_out_socket_creation(zmq_context): @@ -85,7 +90,9 @@ async def test_out_socket_creation_failure(zmq_context): Test setup failure when the audio output socket cannot be created. """ with patch.object(Agent, "stop", new_callable=AsyncMock) as mock_super_stop: - zmq_context.return_value.socket.return_value.bind_to_random_port.side_effect = zmq.ZMQBindError + zmq_context.return_value.socket.return_value.bind_to_random_port.side_effect = ( + zmq.ZMQBindError + ) vad_agent = VADAgent("tcp://localhost:12345", False) await vad_agent.setup() @@ -100,7 +107,10 @@ async def test_stop(zmq_context, transcription_agent): Test that when the VAD agent is stopped, the sockets are closed correctly. 
""" vad_agent = VADAgent("tcp://localhost:12345", False) - zmq_context.return_value.socket.return_value.bind_to_random_port.return_value = random.randint(1000, 10000) + zmq_context.return_value.socket.return_value.bind_to_random_port.return_value = random.randint( + 1000, + 10000, + ) await vad_agent.setup() await vad_agent.stop() diff --git a/test/integration/api/endpoints/test_command_endpoint.py b/test/integration/api/endpoints/test_command_endpoint.py index c343f0c..1c9213a 100644 --- a/test/integration/api/endpoints/test_command_endpoint.py +++ b/test/integration/api/endpoints/test_command_endpoint.py @@ -1,4 +1,4 @@ -from unittest.mock import AsyncMock, patch +from unittest.mock import AsyncMock import pytest from fastapi import FastAPI @@ -27,9 +27,9 @@ def client(app): def test_receive_command_success(client): """ - Test for successful reception of a command. - Ensures the status code is 202 and the response body is correct. - It also verifies that the ZeroMQ socket's send_multipart method is called with the expected data. + Test for successful reception of a command. Ensures the status code is 202 and the response body + is correct. It also verifies that the ZeroMQ socket's send_multipart method is called with the + expected data. """ # Arrange mock_pub_socket = AsyncMock() From 1b58549c2ab3580a35217a55750943dbdd1336c4 Mon Sep 17 00:00:00 2001 From: Twirre Meulenbelt <43213592+TwirreM@users.noreply.github.com> Date: Wed, 5 Nov 2025 12:41:48 +0100 Subject: [PATCH 16/18] test: fix expected test value after changing audio token allowance ref: N25B-209 --- test/unit/agents/transcription/test_speech_recognizer.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/test/unit/agents/transcription/test_speech_recognizer.py b/test/unit/agents/transcription/test_speech_recognizer.py index 88a5ac2..ab28dcf 100644 --- a/test/unit/agents/transcription/test_speech_recognizer.py +++ b/test/unit/agents/transcription/test_speech_recognizer.py @@ -5,12 +5,13 @@ from control_backend.agents.transcription.speech_recognizer import OpenAIWhisper def test_estimate_max_tokens(): - """Inputting one minute of audio, assuming 300 words per minute, expecting 400 tokens.""" + """Inputting one minute of audio, assuming 450 words per minute and adding a 10 token padding, + expecting 610 tokens.""" audio = np.empty(shape=(60 * 16_000), dtype=np.float32) actual = SpeechRecognizer._estimate_max_tokens(audio) - assert actual == 400 + assert actual == 610 assert isinstance(actual, int) From 06e9e4fd150311edfea46940f6c8ea8fc64cfa1e Mon Sep 17 00:00:00 2001 From: Twirre Meulenbelt <43213592+TwirreM@users.noreply.github.com> Date: Wed, 5 Nov 2025 14:15:03 +0100 Subject: [PATCH 17/18] chore: ruff format --- src/control_backend/agents/llm/llm_instructions.py | 2 +- .../agents/transcription/speech_recognizer.py | 7 +------ 2 files changed, 2 insertions(+), 7 deletions(-) diff --git a/src/control_backend/agents/llm/llm_instructions.py b/src/control_backend/agents/llm/llm_instructions.py index e3aed7e..6922fca 100644 --- a/src/control_backend/agents/llm/llm_instructions.py +++ b/src/control_backend/agents/llm/llm_instructions.py @@ -30,7 +30,7 @@ class LLMInstructions: "You are a Pepper robot engaging in natural human conversation.", "Keep responses between 1–3 sentences, unless told otherwise.\n", "You're given goals to reach. Reach them in order, but make the conversation feel " - "natural. Some turns you should not try to achieve your goals.\n" + "natural. 
Some turns you should not try to achieve your goals.\n", ] if self.norms: diff --git a/src/control_backend/agents/transcription/speech_recognizer.py b/src/control_backend/agents/transcription/speech_recognizer.py index 9e61fd7..527d371 100644 --- a/src/control_backend/agents/transcription/speech_recognizer.py +++ b/src/control_backend/agents/transcription/speech_recognizer.py @@ -87,7 +87,6 @@ class MLXWhisperSpeechRecognizer(SpeechRecognizer): return mlx_whisper.transcribe( audio, path_or_hf_repo=self.model_name, - initial_prompt="You're a robot called Pepper, talking with a person called Twirre.", **self._get_decode_options(audio), )["text"].strip() @@ -105,8 +104,4 @@ class OpenAIWhisperSpeechRecognizer(SpeechRecognizer): def recognize_speech(self, audio: np.ndarray) -> str: self.load_model() - return whisper.transcribe( - self.model, - audio, - **self._get_decode_options(audio) - )["text"] + return whisper.transcribe(self.model, audio, **self._get_decode_options(audio))["text"] From 262376fb58d3e6a867fd40a77e0c7c317dab9157 Mon Sep 17 00:00:00 2001 From: Twirre Meulenbelt <43213592+TwirreM@users.noreply.github.com> Date: Wed, 5 Nov 2025 15:01:01 +0100 Subject: [PATCH 18/18] fix: break LLM response with fewer types of punctuation ref: N25B-207 --- src/control_backend/agents/llm/llm.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/control_backend/agents/llm/llm.py b/src/control_backend/agents/llm/llm.py index 0b9d259..4487b23 100644 --- a/src/control_backend/agents/llm/llm.py +++ b/src/control_backend/agents/llm/llm.py @@ -112,9 +112,7 @@ class LLMAgent(Agent): # Stream the message in chunks separated by punctuation. # We include the delimiter in the emitted chunk for natural flow. - pattern = re.compile( - r".*?(?:,|;|:|—|–|-|\.{3}|…|\.|\?|!|\(|\)|\[|\]|/)\s*", re.DOTALL - ) + pattern = re.compile(r".*?(?:,|;|:|—|–|\.{3}|…|\.|\?|!)\s*", re.DOTALL) for m in pattern.finditer(current_chunk): chunk = m.group(0) if chunk: