diff --git a/.githooks/check-branch-name.sh b/.githooks/check-branch-name.sh new file mode 100755 index 0000000..0e71c9b --- /dev/null +++ b/.githooks/check-branch-name.sh @@ -0,0 +1,77 @@ +#!/usr/bin/env bash + +# This script checks if the current branch name follows the specified format. +# It's designed to be used as a 'pre-commit' git hook. + +# Format: / +# Example: feat/add-user-login + +# --- Configuration --- +# An array of allowed commit types +ALLOWED_TYPES=(feat fix refactor perf style test docs build chore revert) +# An array of branches to ignore +IGNORED_BRANCHES=(main dev) + +# --- Colors for Output --- +RED='\033[0;31m' +GREEN='\033[0;32m' +YELLOW='\033[1;33m' +NC='\033[0m' # No Color + +# --- Helper Functions --- +error_exit() { + echo -e "${RED}ERROR: $1${NC}" >&2 + echo -e "${YELLOW}Branch name format is incorrect. Aborting commit.${NC}" >&2 + exit 1 +} + +# --- Main Logic --- + +# 1. Get the current branch name +BRANCH_NAME=$(git symbolic-ref --short HEAD) + +# 2. Check if the current branch is in the ignored list +for ignored_branch in "${IGNORED_BRANCHES[@]}"; do + if [ "$BRANCH_NAME" == "$ignored_branch" ]; then + echo -e "${GREEN}Branch check skipped for default branch: $BRANCH_NAME${NC}" + exit 0 + fi +done + +# 3. Validate the overall structure: / +if ! [[ "$BRANCH_NAME" =~ ^[a-z]+/.+$ ]]; then + error_exit "Branch name must be in the format: /\nExample: feat/add-user-login" +fi + +# 4. Extract the type and description +TYPE=$(echo "$BRANCH_NAME" | cut -d'/' -f1) +DESCRIPTION=$(echo "$BRANCH_NAME" | cut -d'/' -f2-) + +# 5. Validate the +type_valid=false +for allowed_type in "${ALLOWED_TYPES[@]}"; do + if [ "$TYPE" == "$allowed_type" ]; then + type_valid=true + break + fi +done + +if [ "$type_valid" == false ]; then + error_exit "Invalid type '$TYPE'.\nAllowed types are: ${ALLOWED_TYPES[*]}" +fi + +# 6. Validate the +# Regex breakdown: +# ^[a-z0-9]+ - Starts with one or more lowercase letters/numbers (the first word). +# (-[a-z0-9]+){0,5} - Followed by a group of (dash + word) 0 to 5 times. +# $ - End of the string. +# This entire pattern enforces 1 to 6 words total, separated by dashes. +DESCRIPTION_REGEX="^[a-z0-9]+(-[a-z0-9]+){0,5}$" + +if ! [[ "$DESCRIPTION" =~ $DESCRIPTION_REGEX ]]; then + error_exit "Invalid short description '$DESCRIPTION'.\nIt must be a maximum of 6 words, all lowercase, separated by dashes.\nExample: add-new-user-authentication-feature" +fi + +# If all checks pass, exit successfully +echo -e "${GREEN}Branch name '$BRANCH_NAME' is valid.${NC}" +exit 0 diff --git a/.githooks/check-commit-msg.sh b/.githooks/check-commit-msg.sh new file mode 100755 index 0000000..eacf2a8 --- /dev/null +++ b/.githooks/check-commit-msg.sh @@ -0,0 +1,135 @@ +#!/usr/bin/env bash + +# This script checks if a commit message follows the specified format. +# It's designed to be used as a 'commit-msg' git hook. + +# Format: +# : +# +# [optional] +# +# [ref/close]: + +# --- Configuration --- +# An array of allowed commit types +ALLOWED_TYPES=(feat fix refactor perf style test docs build chore revert) + +# --- Colors for Output --- +RED='\033[0;31m' +GREEN='\033[0;32m' +YELLOW='\033[1;33m' +NC='\033[0m' # No Color + +# The first argument to the hook is the path to the file containing the commit message +COMMIT_MSG_FILE=$1 + +# --- Automated Commit Detection --- + +# Read the first line (header) for initial checks +HEADER=$(head -n 1 "$COMMIT_MSG_FILE") + +# Check for Merge commits (covers 'git merge' and PR merges from GitHub/GitLab) +# Examples: "Merge branch 'main' into ...", "Merge pull request #123 from ..." +MERGE_PATTERN="^Merge (branch|pull request|tag) .*" +if [[ "$HEADER" =~ $MERGE_PATTERN ]]; then + echo -e "${GREEN}Merge commit detected by message content. Skipping validation.${NC}" + exit 0 +fi + +# Check for Revert commits +# Example: "Revert "feat: add new feature"" +REVERT_PATTERN="^Revert \".*\"" +if [[ "$HEADER" =~ $REVERT_PATTERN ]]; then + echo -e "${GREEN}Revert commit detected by message content. Skipping validation.${NC}" + exit 0 +fi + +# Check for Cherry-pick commits (this pattern appears at the end of the message) +# Example: "(cherry picked from commit deadbeef...)" +# We use grep -q to search the whole file quietly. +CHERRY_PICK_PATTERN="\(cherry picked from commit [a-f0-9]{7,40}\)" +if grep -qE "$CHERRY_PICK_PATTERN" "$COMMIT_MSG_FILE"; then + echo -e "${GREEN}Cherry-pick detected by message content. Skipping validation.${NC}" + exit 0 +fi + +# Check for Squash +# Example: "Squash commits ..." +SQUASH_PATTERN="^Squash .+" +if [[ "$HEADER" =~ $SQUASH_PATTERN ]]; then + echo -e "${GREEN}Squash commit detected by message content. Skipping validation.${NC}" + exit 0 +fi + +# --- Validation Functions --- + +# Function to print an error message and exit +# Usage: error_exit "Your error message here" +error_exit() { + # >&2 redirects echo to stderr + echo -e "${RED}ERROR: $1${NC}" >&2 + echo -e "${YELLOW}Commit message format is incorrect. Aborting commit.${NC}" >&2 + exit 1 +} + +# --- Main Logic --- + +# 1. Read the header (first line) of the commit message +HEADER=$(head -n 1 "$COMMIT_MSG_FILE") + +# 2. Validate the header format: : +# Regex breakdown: +# ^(type1|type2|...) - Starts with one of the allowed types +# : - Followed by a literal colon +# \s - Followed by a single space +# .+ - Followed by one or more characters for the description +# $ - End of the line +TYPES_REGEX=$( + IFS="|" + echo "${ALLOWED_TYPES[*]}" +) +HEADER_REGEX="^($TYPES_REGEX): .+$" + +if ! [[ "$HEADER" =~ $HEADER_REGEX ]]; then + error_exit "Invalid header format.\n\nHeader must be in the format: : \nAllowed types: ${ALLOWED_TYPES[*]}\nExample: feat: add new user authentication feature" +fi + +# Only validate footer if commit type is not chore +TYPE=$(echo "$HEADER" | cut -d':' -f1) +if [ "$TYPE" != "chore" ]; then + # 3. Validate the footer (last line) of the commit message + FOOTER=$(tail -n 1 "$COMMIT_MSG_FILE") + + # Regex breakdown: + # ^(ref|close) - Starts with 'ref' or 'close' + # : - Followed by a literal colon + # \s - Followed by a single space + # N25B- - Followed by the literal string 'N25B-' + # [0-9]+ - Followed by one or more digits + # $ - End of the line + FOOTER_REGEX="^(ref|close): N25B-[0-9]+$" + + if ! [[ "$FOOTER" =~ $FOOTER_REGEX ]]; then + error_exit "Invalid footer format.\n\nFooter must be in the format: [ref/close]: \nExample: ref: N25B-123" + fi +fi + +# 4. If the message has more than 2 lines, validate the separator +# A blank line must exist between the header and the body. +LINE_COUNT=$(wc -l <"$COMMIT_MSG_FILE" | xargs) # xargs trims whitespace + +# We only care if there is a body. Header + Footer = 2 lines. +# Header + Blank Line + Body... + Footer > 2 lines. +if [ "$LINE_COUNT" -gt 2 ]; then + # Get the second line + SECOND_LINE=$(sed -n '2p' "$COMMIT_MSG_FILE") + + # Check if the second line is NOT empty. If it's not, it's an error. + if [ -n "$SECOND_LINE" ]; then + error_exit "Missing blank line between header and body.\n\nThe second line of your commit message must be empty if a body is present." + fi +fi + +# If all checks pass, exit with success +echo -e "${GREEN}Commit message is valid.${NC}" +exit 0 diff --git a/.githooks/commit-msg b/.githooks/commit-msg deleted file mode 100644 index 41992ad..0000000 --- a/.githooks/commit-msg +++ /dev/null @@ -1,16 +0,0 @@ -#!/bin/sh - -commit_msg_file=$1 -commit_msg=$(cat "$commit_msg_file") - -if echo "$commit_msg" | grep -Eq "^(feat|fix|refactor|perf|style|test|docs|build|chore|revert): .+"; then - if echo "$commit_msg" | grep -Eq "^(ref|close):\sN25B-.+"; then - exit 0 - else - echo "❌ Commit message invalid! Must end with [ref/close]: N25B-000" - exit 1 - fi -else - echo "❌ Commit message invalid! Must start with : " - exit 1 -fi \ No newline at end of file diff --git a/.githooks/pre-commit b/.githooks/pre-commit deleted file mode 100644 index 7e94937..0000000 --- a/.githooks/pre-commit +++ /dev/null @@ -1,17 +0,0 @@ -#!/bin/sh - -# Get current branch -branch=$(git rev-parse --abbrev-ref HEAD) - -if echo "$branch" | grep -Eq "(dev|main)"; then - echo 0 -fi - -# allowed pattern -if echo "$branch" | grep -Eq "^(feat|fix|refactor|perf|style|test|docs|build|chore|revert)\/\w+(-\w+){0,5}$"; then - exit 0 -else - echo "❌ Invalid branch name: $branch" - echo "Branch must be named / (must have one to six words separated by a dash)" - exit 1 -fi \ No newline at end of file diff --git a/.githooks/prepare-commit-msg b/.githooks/prepare-commit-msg deleted file mode 100644 index 5b706c1..0000000 --- a/.githooks/prepare-commit-msg +++ /dev/null @@ -1,9 +0,0 @@ -#!/bin/sh - -echo "#: - -#[optional body] - -#[optional footer(s)] - -#[ref/close]: " > $1 \ No newline at end of file diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index c6ed188..41710dc 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,10 +1,24 @@ repos: -- repo: https://github.com/astral-sh/ruff-pre-commit - # Ruff version. - rev: v0.14.2 - hooks: - # Run the linter. - - id: ruff-check - args: [ --fix ] - # Run the formatter. - - id: ruff-format \ No newline at end of file + - repo: https://github.com/astral-sh/ruff-pre-commit + # Ruff version. + rev: v0.14.2 + hooks: + # Run the linter. + - id: ruff-check + # Run the formatter. + - id: ruff-format + # Configure local hooks + - repo: local + hooks: + - id: check-commit-msg + name: Check commit message format + entry: .githooks/check-commit-msg.sh + language: script + stages: [commit-msg] + - id: check-branch-name + name: Check branch name format + entry: .githooks/check-branch-name.sh + language: script + stages: [pre-commit] + always_run: true + pass_filenames: false diff --git a/README.md b/README.md index 45f8f98..d20b36d 100644 --- a/README.md +++ b/README.md @@ -47,21 +47,19 @@ Or for integration tests: uv run --group integration-test pytest test/integration ``` -## GitHooks +## Git Hooks -To activate automatic commits/branch name checks run: +To activate automatic linting, formatting, branch name checks and commit message checks, run: -```shell -git config --local core.hooksPath .githooks +```bash +uv run pre-commit install +uv run pre-commit install --hook-type commit-msg ``` -If your commit fails its either: -branch name != /description-of-branch , -commit name != : description of the commit. - : N25B-Num's +You might get an error along the lines of `Can't install pre-commit with core.hooksPath` set. To fix this, simply unset the hooksPath by running: -To add automatic linting and formatting, run: +```bash +git config --local --unset core.hooksPath +``` -```shell -uv run pre-commit install -``` \ No newline at end of file +Then run the pre-commit install commands again. diff --git a/src/control_backend/agents/bdi/behaviours/text_belief_extractor.py b/src/control_backend/agents/bdi/behaviours/text_belief_extractor.py index 549fb0c..bc98bf1 100644 --- a/src/control_backend/agents/bdi/behaviours/text_belief_extractor.py +++ b/src/control_backend/agents/bdi/behaviours/text_belief_extractor.py @@ -9,18 +9,23 @@ from control_backend.core.config import settings class BeliefFromText(CyclicBehaviour): # TODO: LLM prompt nog hardcoded llm_instruction_prompt = """ - You are an information extraction assistent for a BDI agent. Your task is to extract values from a user's text to bind a list of ungrounded beliefs. Rules: - You will receive a JSON object with "beliefs" (a list of ungrounded AgentSpeak beliefs) and "text" (user's transcript). + You are an information extraction assistent for a BDI agent. Your task is to extract values \ + from a user's text to bind a list of ungrounded beliefs. Rules: + You will receive a JSON object with "beliefs" (a list of ungrounded AgentSpeak beliefs) \ + and "text" (user's transcript). Analyze the text to find values that sematically match the variables (X,Y,Z) in the beliefs. A single piece of text might contain multiple instances that match a belief. Respond ONLY with a single JSON object. The JSON object's keys should be the belief functors (e.g., "weather"). The value for each key must be a list of lists. - Each inner list must contain the extracted arguments (as strings) for one instance of that belief. - CRITICAL: If no information in the text matches a belief, DO NOT include that key in your response. + Each inner list must contain the extracted arguments (as strings) for one instance \ + of that belief. + CRITICAL: If no information in the text matches a belief, DO NOT include that key \ + in your response. """ - # on_start agent receives message containing the beliefs to look out for and sets up the LLM with instruction prompt + # on_start agent receives message containing the beliefs to look out for and + # sets up the LLM with instruction prompt # async def on_start(self): # msg = await self.receive(timeout=0.1) # self.beliefs = dict uit message diff --git a/src/control_backend/agents/belief_collector/behaviours/continuous_collect.py b/src/control_backend/agents/belief_collector/behaviours/continuous_collect.py index 83e381d..621eb20 100644 --- a/src/control_backend/agents/belief_collector/behaviours/continuous_collect.py +++ b/src/control_backend/agents/belief_collector/behaviours/continuous_collect.py @@ -7,6 +7,7 @@ from spade.behaviour import CyclicBehaviour from control_backend.core.config import settings + class ContinuousBeliefCollector(CyclicBehaviour): """ Continuously collects beliefs/emotions from extractor agents: @@ -23,9 +24,12 @@ class ContinuousBeliefCollector(CyclicBehaviour): # Parse JSON payload try: payload = json.loads(msg.body) - except JSONDecodeError as e: - self.agent.logger.warning( - "Failed to parse JSON from %s. Body=%r Error=%s", sender_node, msg.body, e + except Exception as e: + logger.warning( + "BeliefCollector: failed to parse JSON from %s. Body=%r Error=%s", + sender_node, + msg.body, + e, ) return @@ -51,7 +55,7 @@ class ContinuousBeliefCollector(CyclicBehaviour): Expected payload: { "type": "belief_extraction_text", - "beliefs": {"user_said": ["hello"","Can you help me?","stop talking to me","No","Pepper do a dance"]} + "beliefs": {"user_said": ["Can you help me?"]} } diff --git a/src/control_backend/agents/mock_agents/belief_text_mock.py b/src/control_backend/agents/mock_agents/belief_text_mock.py index ea896fb..27c5e49 100644 --- a/src/control_backend/agents/mock_agents/belief_text_mock.py +++ b/src/control_backend/agents/mock_agents/belief_text_mock.py @@ -10,7 +10,11 @@ from control_backend.core.config import settings class BeliefTextAgent(Agent): class SendOnceBehaviourBlfText(OneShotBehaviour): async def run(self): - to_jid = f"{settings.agent_settings.belief_collector_agent_name}@{settings.agent_settings.host}" + to_jid = ( + settings.agent_settings.belief_collector_agent_name + + "@" + + settings.agent_settings.host + ) # Send multiple beliefs in one JSON payload payload = { diff --git a/src/control_backend/agents/transcription/speech_recognizer.py b/src/control_backend/agents/transcription/speech_recognizer.py index f316cda..19d82ff 100644 --- a/src/control_backend/agents/transcription/speech_recognizer.py +++ b/src/control_backend/agents/transcription/speech_recognizer.py @@ -75,7 +75,8 @@ class MLXWhisperSpeechRecognizer(SpeechRecognizer): self.model_name = "mlx-community/whisper-small.en-mlx" def load_model(self): - if self.was_loaded: return + if self.was_loaded: + return # There appears to be no dedicated mechanism to preload a model, but this `get_model` does # store it in memory for later usage ModelHolder.get_model(self.model_name, mx.float16) @@ -83,9 +84,9 @@ class MLXWhisperSpeechRecognizer(SpeechRecognizer): def recognize_speech(self, audio: np.ndarray) -> str: self.load_model() - return mlx_whisper.transcribe(audio, - path_or_hf_repo=self.model_name, - decode_options=self._get_decode_options(audio))["text"] + return mlx_whisper.transcribe( + audio, path_or_hf_repo=self.model_name, decode_options=self._get_decode_options(audio) + )["text"] return mlx_whisper.transcribe(audio, path_or_hf_repo=self.model_name)["text"].strip() @@ -95,12 +96,13 @@ class OpenAIWhisperSpeechRecognizer(SpeechRecognizer): self.model = None def load_model(self): - if self.model is not None: return + if self.model is not None: + return device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") self.model = whisper.load_model("small.en", device=device) def recognize_speech(self, audio: np.ndarray) -> str: self.load_model() - return whisper.transcribe(self.model, - audio, - decode_options=self._get_decode_options(audio))["text"] + return whisper.transcribe( + self.model, audio, decode_options=self._get_decode_options(audio) + )["text"] diff --git a/src/control_backend/api/v1/endpoints/command.py b/src/control_backend/api/v1/endpoints/command.py index badaf90..e19290f 100644 --- a/src/control_backend/api/v1/endpoints/command.py +++ b/src/control_backend/api/v1/endpoints/command.py @@ -1,9 +1,9 @@ -from fastapi import APIRouter, Request import logging +from fastapi import APIRouter, Request from zmq import Socket -from control_backend.schemas.ri_message import SpeechCommand, RIEndpoint +from control_backend.schemas.ri_message import SpeechCommand logger = logging.getLogger(__name__) @@ -17,6 +17,5 @@ async def receive_command(command: SpeechCommand, request: Request): topic = b"command" pub_socket: Socket = request.app.state.internal_comm_socket pub_socket.send_multipart([topic, command.model_dump_json().encode()]) - return {"status": "Command received"} diff --git a/src/control_backend/api/v1/router.py b/src/control_backend/api/v1/router.py index dc7aea9..a23b3b3 100644 --- a/src/control_backend/api/v1/router.py +++ b/src/control_backend/api/v1/router.py @@ -1,6 +1,6 @@ from fastapi.routing import APIRouter -from control_backend.api.v1.endpoints import message, sse, command +from control_backend.api.v1.endpoints import command, message, sse api_router = APIRouter() diff --git a/src/control_backend/core/config.py b/src/control_backend/core/config.py index 5e4b764..2fd16b8 100644 --- a/src/control_backend/core/config.py +++ b/src/control_backend/core/config.py @@ -24,6 +24,7 @@ class LLMSettings(BaseModel): local_llm_url: str = "http://localhost:1234/v1/chat/completions" local_llm_model: str = "openai/gpt-oss-20b" + class Settings(BaseSettings): app_title: str = "PepperPlus" @@ -37,4 +38,5 @@ class Settings(BaseSettings): model_config = SettingsConfigDict(env_file=".env") + settings = Settings() diff --git a/src/control_backend/schemas/ri_message.py b/src/control_backend/schemas/ri_message.py index 97b7930..488b823 100644 --- a/src/control_backend/schemas/ri_message.py +++ b/src/control_backend/schemas/ri_message.py @@ -1,7 +1,7 @@ from enum import Enum -from typing import Any, Literal +from typing import Any -from pydantic import BaseModel, Field, ValidationError +from pydantic import BaseModel class RIEndpoint(str, Enum): diff --git a/test/integration/agents/test_ri_commands_agent.py b/test/integration/agents/test_ri_commands_agent.py index 219d682..4249401 100644 --- a/test/integration/agents/test_ri_commands_agent.py +++ b/test/integration/agents/test_ri_commands_agent.py @@ -1,10 +1,10 @@ -import asyncio -import zmq import json -import pytest from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +import zmq + from control_backend.agents.ri_command_agent import RICommandAgent -from control_backend.schemas.ri_message import SpeechCommand @pytest.mark.asyncio diff --git a/test/integration/agents/test_ri_communication_agent.py b/test/integration/agents/test_ri_communication_agent.py index 3e4a056..fd555e1 100644 --- a/test/integration/agents/test_ri_communication_agent.py +++ b/test/integration/agents/test_ri_communication_agent.py @@ -1,6 +1,8 @@ import asyncio +from unittest.mock import ANY, AsyncMock, MagicMock, patch + import pytest -from unittest.mock import AsyncMock, MagicMock, patch, ANY + from control_backend.agents.ri_communication_agent import RICommunicationAgent @@ -185,8 +187,8 @@ async def test_setup_creates_socket_and_negotiate_3(monkeypatch, caplog): # Mock RICommandAgent agent startup - # We are sending wrong negotiation info to the communication agent, so we should retry and expect a - # better response, within a limited time. + # We are sending wrong negotiation info to the communication agent, + # so we should retry and expect a better response, within a limited time. with patch( "control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True ) as MockCommandAgent: @@ -358,8 +360,8 @@ async def test_setup_creates_socket_and_negotiate_7(monkeypatch, caplog): # Mock RICommandAgent agent startup - # We are sending wrong negotiation info to the communication agent, so we should retry and expect a - # better response, within a limited time. + # We are sending wrong negotiation info to the communication agent, + # so we should retry and expect a better response, within a limited time. with patch( "control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True ) as MockCommandAgent: diff --git a/test/integration/api/endpoints/test_command_endpoint.py b/test/integration/api/endpoints/test_command_endpoint.py index 07bd866..04890c1 100644 --- a/test/integration/api/endpoints/test_command_endpoint.py +++ b/test/integration/api/endpoints/test_command_endpoint.py @@ -1,7 +1,8 @@ +from unittest.mock import MagicMock + import pytest from fastapi import FastAPI from fastapi.testclient import TestClient -from unittest.mock import MagicMock from control_backend.api.v1.endpoints import command from control_backend.schemas.ri_message import SpeechCommand diff --git a/test/integration/schemas/test_ri_message.py b/test/integration/schemas/test_ri_message.py index aef9ae6..5078f9a 100644 --- a/test/integration/schemas/test_ri_message.py +++ b/test/integration/schemas/test_ri_message.py @@ -1,7 +1,8 @@ import pytest -from control_backend.schemas.ri_message import RIMessage, RIEndpoint, SpeechCommand from pydantic import ValidationError +from control_backend.schemas.ri_message import RIEndpoint, RIMessage, SpeechCommand + def valid_command_1(): return SpeechCommand(data="Hallo?") @@ -13,24 +14,13 @@ def invalid_command_1(): def test_valid_speech_command_1(): command = valid_command_1() - try: - RIMessage.model_validate(command) - SpeechCommand.model_validate(command) - assert True - except ValidationError: - assert False + RIMessage.model_validate(command) + SpeechCommand.model_validate(command) def test_invalid_speech_command_1(): command = invalid_command_1() - passed_ri_message_validation = False - try: - # Should succeed, still. - RIMessage.model_validate(command) - passed_ri_message_validation = True + RIMessage.model_validate(command) - # Should fail. + with pytest.raises(ValidationError): SpeechCommand.model_validate(command) - assert False - except ValidationError: - assert passed_ri_message_validation diff --git a/test/unit/agents/bdi/behaviours/test_belief_setter.py b/test/unit/agents/bdi/behaviours/test_belief_setter.py index 788e95a..c7bb0e9 100644 --- a/test/unit/agents/bdi/behaviours/test_belief_setter.py +++ b/test/unit/agents/bdi/behaviours/test_belief_setter.py @@ -203,6 +203,7 @@ def test_set_beliefs_success(belief_setter, mock_agent, caplog): assert "Set belief is_hot with arguments ['kitchen']" in caplog.text assert "Set belief door_opened with arguments ['front_door', 'back_door']" in caplog.text + # def test_responded_unset(belief_setter, mock_agent): # # Arrange # new_beliefs = {"user_said": ["message"]} diff --git a/test/unit/agents/belief_collector/behaviours/test_continuous_collect.py b/test/unit/agents/belief_collector/behaviours/test_continuous_collect.py index 622aefd..e842f5c 100644 --- a/test/unit/agents/belief_collector/behaviours/test_continuous_collect.py +++ b/test/unit/agents/belief_collector/behaviours/test_continuous_collect.py @@ -1,10 +1,12 @@ import json -import logging -from unittest.mock import MagicMock, AsyncMock, call +from unittest.mock import AsyncMock, MagicMock import pytest -from control_backend.agents.belief_collector.behaviours.continuous_collect import ContinuousBeliefCollector +from control_backend.agents.belief_collector.behaviours.continuous_collect import ( + ContinuousBeliefCollector, +) + @pytest.fixture def mock_agent(mocker): @@ -13,18 +15,20 @@ def mock_agent(mocker): agent.jid = "belief_collector_agent@test" return agent + @pytest.fixture def continuous_collector(mock_agent, mocker): """Fixture to create an instance of ContinuousBeliefCollector with a mocked agent.""" # Patch asyncio.sleep to prevent tests from actually waiting mocker.patch("asyncio.sleep", return_value=None) - + collector = ContinuousBeliefCollector() collector.agent = mock_agent # Mock the receive method, we will control its return value in each test collector.receive = AsyncMock() return collector + @pytest.mark.asyncio async def test_run_no_message_received(continuous_collector, mocker): """ @@ -40,6 +44,7 @@ async def test_run_no_message_received(continuous_collector, mocker): # Assert continuous_collector._process_message.assert_not_called() + @pytest.mark.asyncio async def test_run_message_received(continuous_collector, mocker): """ @@ -55,7 +60,8 @@ async def test_run_message_received(continuous_collector, mocker): # Assert continuous_collector._process_message.assert_awaited_once_with(mock_msg) - + + @pytest.mark.asyncio async def test_process_message_invalid(continuous_collector, mocker): """ @@ -66,15 +72,18 @@ async def test_process_message_invalid(continuous_collector, mocker): msg = MagicMock() msg.body = invalid_json msg.sender = "belief_text_agent_mock@test" - - logger_mock = mocker.patch("control_backend.agents.belief_collector.behaviours.continuous_collect.logger") - + + logger_mock = mocker.patch( + "control_backend.agents.belief_collector.behaviours.continuous_collect.logger" + ) + # Act await continuous_collector._process_message(msg) # Assert logger_mock.warning.assert_called_once() + def test_get_sender_from_message(continuous_collector): """ Test that _sender_node correctly extracts the sender node from the message JID. @@ -89,6 +98,7 @@ def test_get_sender_from_message(continuous_collector): # Assert assert sender_node == "agent_node" + @pytest.mark.asyncio async def test_routes_to_handle_belief_text_by_type(continuous_collector, mocker): msg = MagicMock() @@ -98,6 +108,7 @@ async def test_routes_to_handle_belief_text_by_type(continuous_collector, mocker await continuous_collector._process_message(msg) spy.assert_awaited_once() + @pytest.mark.asyncio async def test_routes_to_handle_belief_text_by_sender(continuous_collector, mocker): msg = MagicMock() @@ -107,6 +118,7 @@ async def test_routes_to_handle_belief_text_by_sender(continuous_collector, mock await continuous_collector._process_message(msg) spy.assert_awaited_once() + @pytest.mark.asyncio async def test_routes_to_handle_emo_text(continuous_collector, mocker): msg = MagicMock() @@ -116,50 +128,64 @@ async def test_routes_to_handle_emo_text(continuous_collector, mocker): await continuous_collector._process_message(msg) spy.assert_awaited_once() + @pytest.mark.asyncio async def test_unrecognized_message_logs_info(continuous_collector, mocker): msg = MagicMock() msg.body = json.dumps({"type": "something_else"}) msg.sender = "x@test" - logger_mock = mocker.patch("control_backend.agents.belief_collector.behaviours.continuous_collect.logger") + logger_mock = mocker.patch( + "control_backend.agents.belief_collector.behaviours.continuous_collect.logger" + ) await continuous_collector._process_message(msg) logger_mock.info.assert_any_call( - "BeliefCollector: unrecognized message (sender=%s, type=%r). Ignoring.", "x", "something_else" + "BeliefCollector: unrecognized message (sender=%s, type=%r). Ignoring.", + "x", + "something_else", ) @pytest.mark.asyncio async def test_belief_text_no_beliefs(continuous_collector, mocker): msg_payload = {"type": "belief_extraction_text"} # no 'beliefs' - logger_mock = mocker.patch("control_backend.agents.belief_collector.behaviours.continuous_collect.logger") + logger_mock = mocker.patch( + "control_backend.agents.belief_collector.behaviours.continuous_collect.logger" + ) await continuous_collector._handle_belief_text(msg_payload, "origin_node") logger_mock.info.assert_any_call("BeliefCollector: no beliefs to process.") + @pytest.mark.asyncio async def test_belief_text_beliefs_not_dict(continuous_collector, mocker): payload = {"type": "belief_extraction_text", "beliefs": ["not", "a", "dict"]} - logger_mock = mocker.patch("control_backend.agents.belief_collector.behaviours.continuous_collect.logger") + logger_mock = mocker.patch( + "control_backend.agents.belief_collector.behaviours.continuous_collect.logger" + ) await continuous_collector._handle_belief_text(payload, "origin") - logger_mock.warning.assert_any_call("BeliefCollector: 'beliefs' is not a dict: %r", ["not", "a", "dict"]) + logger_mock.warning.assert_any_call( + "BeliefCollector: 'beliefs' is not a dict: %r", ["not", "a", "dict"] + ) + @pytest.mark.asyncio async def test_belief_text_values_not_lists(continuous_collector, mocker): payload = {"type": "belief_extraction_text", "beliefs": {"user_said": "not-a-list"}} - logger_mock = mocker.patch("control_backend.agents.belief_collector.behaviours.continuous_collect.logger") + logger_mock = mocker.patch( + "control_backend.agents.belief_collector.behaviours.continuous_collect.logger" + ) await continuous_collector._handle_belief_text(payload, "origin") logger_mock.warning.assert_any_call( "BeliefCollector: 'beliefs' values are not all lists: %r", {"user_said": "not-a-list"} ) + @pytest.mark.asyncio async def test_belief_text_happy_path_logs_items_and_sends(continuous_collector, mocker): - payload = { - "type": "belief_extraction_text", - "beliefs": {"user_said": ["hello test", "No"]} - } - # Your code calls self.send(..); patch it (or switch implementation to self.agent.send and patch that) + payload = {"type": "belief_extraction_text", "beliefs": {"user_said": ["hello test", "No"]}} continuous_collector.send = AsyncMock() - logger_mock = mocker.patch("control_backend.agents.belief_collector.behaviours.continuous_collect.logger") + logger_mock = mocker.patch( + "control_backend.agents.belief_collector.behaviours.continuous_collect.logger" + ) await continuous_collector._handle_belief_text(payload, "belief_text_agent_mock") logger_mock.info.assert_any_call("BeliefCollector: forwarding %d beliefs.", 1) @@ -169,12 +195,14 @@ async def test_belief_text_happy_path_logs_items_and_sends(continuous_collector, # make sure we attempted a send continuous_collector.send.assert_awaited_once() + @pytest.mark.asyncio async def test_send_beliefs_noop_on_empty(continuous_collector): continuous_collector.send = AsyncMock() await continuous_collector._send_beliefs_to_bdi([], origin="o") continuous_collector.send.assert_not_awaited() + # @pytest.mark.asyncio # async def test_send_beliefs_sends_json_packet(continuous_collector): # # Patch .send and capture the message body @@ -191,19 +219,22 @@ async def test_send_beliefs_noop_on_empty(continuous_collector): # assert "belief_packet" in json.loads(sent["body"])["type"] # assert json.loads(sent["body"])["beliefs"] == beliefs + def test_sender_node_no_sender_returns_literal(continuous_collector): msg = MagicMock() msg.sender = None assert continuous_collector._sender_node(msg) == "no_sender" + def test_sender_node_without_at(continuous_collector): msg = MagicMock() msg.sender = "localpartonly" assert continuous_collector._sender_node(msg) == "localpartonly" + @pytest.mark.asyncio async def test_belief_text_coerces_non_strings(continuous_collector, mocker): payload = {"type": "belief_extraction_text", "beliefs": {"user_said": [["hi", 123]]}} continuous_collector.send = AsyncMock() await continuous_collector._handle_belief_text(payload, "origin") - continuous_collector.send.assert_awaited_once() + continuous_collector.send.assert_awaited_once() diff --git a/test/unit/agents/transcription/test_speech_recognizer.py b/test/unit/agents/transcription/test_speech_recognizer.py index 6e7cde0..88a5ac2 100644 --- a/test/unit/agents/transcription/test_speech_recognizer.py +++ b/test/unit/agents/transcription/test_speech_recognizer.py @@ -6,7 +6,7 @@ from control_backend.agents.transcription.speech_recognizer import OpenAIWhisper def test_estimate_max_tokens(): """Inputting one minute of audio, assuming 300 words per minute, expecting 400 tokens.""" - audio = np.empty(shape=(60*16_000), dtype=np.float32) + audio = np.empty(shape=(60 * 16_000), dtype=np.float32) actual = SpeechRecognizer._estimate_max_tokens(audio) @@ -16,7 +16,7 @@ def test_estimate_max_tokens(): def test_get_decode_options(): """Check whether the right decode options are given under different scenarios.""" - audio = np.empty(shape=(60*16_000), dtype=np.float32) + audio = np.empty(shape=(60 * 16_000), dtype=np.float32) # With the defaults, it should limit output length based on input size recognizer = OpenAIWhisperSpeechRecognizer()