#!/usr/bin/env bash
#
# check-branch-name.sh
#
# Pre-commit hook that checks whether the current branch name follows the
# project naming convention.
#
#   Format:  <type>/<short-description>
#   Example: feat/add-user-login
#
# Exit status:
#   0  branch name is valid, the branch is ignored, or HEAD is detached
#   1  branch name violates the convention

# --- Configuration ---
# Allowed branch types.
readonly ALLOWED_TYPES=(feat fix refactor perf style test docs build chore revert)
# Branches that are never checked.
readonly IGNORED_BRANCHES=(main dev demo)
# Description pattern:
#   ^[a-z0-9]+         first word: lowercase letters/digits
#   (-[a-z0-9]+){0,5}  up to five more dash-separated words
#   $                  end of string
# i.e. 1 to 6 words in total, separated by single dashes.
readonly DESCRIPTION_REGEX='^[a-z0-9]+(-[a-z0-9]+){0,5}$'

# --- Colors for Output ---
readonly RED=$'\033[0;31m'
readonly GREEN=$'\033[0;32m'
readonly YELLOW=$'\033[1;33m'
readonly NC=$'\033[0m' # No Color

#######################################
# Print an error plus the generic "aborting" hint to stderr.
# Arguments: $1 - message; backslash escapes such as \n are expanded
# Outputs:   two lines on stderr
#######################################
report_error() {
  printf '%sERROR: %b%s\n' "$RED" "$1" "$NC" >&2
  printf '%sBranch name format is incorrect. Aborting commit.%s\n' "$YELLOW" "$NC" >&2
}

#######################################
# Test whether a value occurs in a list.
# Arguments: $1 - value to look for; remaining arguments - the list
# Returns:   0 if found, 1 otherwise
#######################################
in_list() {
  local needle=$1 item
  shift
  for item in "$@"; do
    if [[ "$item" == "$needle" ]]; then
      return 0
    fi
  done
  return 1
}

#######################################
# Validate a branch name against <type>/<short-description>.
# Globals:   ALLOWED_TYPES, DESCRIPTION_REGEX (read)
# Arguments: $1 - branch name
# Outputs:   error description on stderr when invalid
# Returns:   0 if valid, 1 otherwise
#######################################
validate_branch_name() {
  local branch=$1
  local type description

  # Overall structure: lowercase type, a slash, a non-empty remainder.
  if ! [[ "$branch" =~ ^[a-z]+/.+$ ]]; then
    report_error "Branch name must be in the format: <type>/<short-description>\nExample: feat/add-user-login"
    return 1
  fi

  # Split on the FIRST slash only; further slashes belong to the description
  # (and are rejected by DESCRIPTION_REGEX below).
  type=${branch%%/*}
  description=${branch#*/}

  if ! in_list "$type" "${ALLOWED_TYPES[@]}"; then
    report_error "Invalid type '$type'.\nAllowed types are: ${ALLOWED_TYPES[*]}"
    return 1
  fi

  if ! [[ "$description" =~ $DESCRIPTION_REGEX ]]; then
    report_error "Invalid short description '$description'.\nIt must be a maximum of 6 words, all lowercase, separated by dashes.\nExample: add-new-user-authentication-feature"
    return 1
  fi

  return 0
}

#######################################
# Hook entry point: look up the current branch and validate it.
# Returns: 0 to allow the commit, 1 to abort it
#######################################
main() {
  local branch

  # `git symbolic-ref` fails when HEAD is detached (rebase, bisect, CI
  # checkouts). There is no branch name to judge in that case, so the check is
  # skipped instead of blocking the commit with a misleading format error.
  if ! branch=$(git symbolic-ref --quiet --short HEAD 2>/dev/null) || [[ -z "$branch" ]]; then
    printf '%sNo branch checked out (detached HEAD). Skipping branch name check.%s\n' "$YELLOW" "$NC"
    return 0
  fi

  if in_list "$branch" "${IGNORED_BRANCHES[@]}"; then
    printf '%sBranch check skipped for default branch: %s%s\n' "$GREEN" "$branch" "$NC"
    return 0
  fi

  validate_branch_name "$branch" || return 1

  printf "%sBranch name '%s' is valid.%s\n" "$GREEN" "$branch" "$NC"
  return 0
}

# Keep this as the last command: its return status is the hook's exit status.
main "$@"
#!/usr/bin/env bash
#
# check-commit-msg.sh
#
# commit-msg hook that checks whether a commit message follows the project
# format:
#
#   <type>: <short description>
#
#   [optional body]
#
#   [ref/close]: <issue identifier>
#
# Usage: check-commit-msg.sh <path-to-commit-message-file>
#
# Exit status:
#   0  message is valid, or the commit is an automated one (merge, revert,
#      cherry-pick, squash)
#   1  message is invalid or the message file cannot be read

# --- Configuration ---
# Allowed commit types.
readonly ALLOWED_TYPES=(feat fix refactor perf style test docs build chore revert)
# Footer: 'ref' or 'close', a colon, one space, 'N25B-' and one or more digits.
readonly FOOTER_REGEX='^(ref|close): N25B-[0-9]+$'
# Headers of automatically generated commits that are never validated.
readonly MERGE_PATTERN='^Merge (branch|pull request|tag) .*'
readonly REVERT_PATTERN='^Revert ".*"'
readonly SQUASH_PATTERN='^Squash .+'
# Trailer that `git cherry-pick -x` appends somewhere in the message.
readonly CHERRY_PICK_PATTERN='\(cherry picked from commit [a-f0-9]{7,40}\)'
# Everything from this line on is the diff added by `git commit --verbose`.
readonly SCISSORS_LINE='# ------------------------ >8 ------------------------'

# --- Colors for Output ---
readonly RED=$'\033[0;31m'
readonly GREEN=$'\033[0;32m'
readonly YELLOW=$'\033[1;33m'
readonly NC=$'\033[0m' # No Color

# Cleaned-up message lines; filled by load_message_lines.
MSG_LINES=()

#######################################
# Print an error plus the generic "aborting" hint to stderr.
# Arguments: $1 - message; backslash escapes such as \n are expanded
#######################################
report_error() {
  printf '%sERROR: %b%s\n' "$RED" "$1" "$NC" >&2
  printf '%sCommit message format is incorrect. Aborting commit.%s\n' "$YELLOW" "$NC" >&2
}

#######################################
# Print a success/skip notice to stdout.
# Arguments: $1 - message
#######################################
report_ok() {
  printf '%s%s%s\n' "$GREEN" "$1" "$NC"
}

#######################################
# Read a commit message file into MSG_LINES, cleaned up roughly the way git
# does for editor-written messages: the scissors section and '#' comment lines
# are dropped, trailing whitespace is stripped, and leading/trailing blank
# lines are removed. Without this, the "last line" of the raw file is usually
# a comment or blank line rather than the footer.
# NOTE(review): assumes the default comment character '#'; a custom
# core.commentChar is not honoured.
# Globals:   MSG_LINES (written), SCISSORS_LINE (read)
# Arguments: $1 - path to a readable message file
#######################################
load_message_lines() {
  local file=$1
  local line trailing last

  MSG_LINES=()
  # `|| [[ -n "$line" ]]` keeps a final line that has no trailing newline.
  while IFS= read -r line || [[ -n "$line" ]]; do
    if [[ "$line" == "$SCISSORS_LINE" ]]; then
      break
    fi
    if [[ "$line" == '#'* ]]; then
      continue
    fi
    # Strip trailing whitespace (an all-whitespace line becomes empty).
    trailing=${line##*[![:space:]]}
    line=${line%"$trailing"}
    # Skip blank lines before the header.
    if ((${#MSG_LINES[@]} == 0)) && [[ -z "$line" ]]; then
      continue
    fi
    MSG_LINES+=("$line")
  done <"$file"

  # Drop blank lines after the footer.
  while ((${#MSG_LINES[@]} > 0)); do
    last=$((${#MSG_LINES[@]} - 1))
    if [[ -n "${MSG_LINES[$last]}" ]]; then
      break
    fi
    unset "MSG_LINES[$last]"
  done
  return 0
}

#######################################
# Validate the commit message stored in a file.
# Globals:   ALLOWED_TYPES and the *_PATTERN/*_REGEX constants (read),
#            MSG_LINES (written)
# Arguments: $1 - path to the commit message file
# Outputs:   result notice on stdout, error description on stderr
# Returns:   0 if the message is valid or validation is skipped, 1 otherwise
#######################################
validate_commit_message_file() {
  local file=${1:-}
  local header footer type types_regex header_regex line count

  if [[ -z "$file" || ! -r "$file" ]]; then
    report_error "Cannot read commit message file '$file'."
    return 1
  fi

  load_message_lines "$file"
  count=${#MSG_LINES[@]}
  header=${MSG_LINES[0]:-}

  # --- Automated commit detection ---
  # Merge commits (covers 'git merge' and PR merges from GitHub/GitLab).
  if [[ "$header" =~ $MERGE_PATTERN ]]; then
    report_ok "Merge commit detected by message content. Skipping validation."
    return 0
  fi
  if [[ "$header" =~ $REVERT_PATTERN ]]; then
    report_ok "Revert commit detected by message content. Skipping validation."
    return 0
  fi
  # The cherry-pick trailer can be on any line, so every line is inspected.
  for line in "${MSG_LINES[@]}"; do
    if [[ "$line" =~ $CHERRY_PICK_PATTERN ]]; then
      report_ok "Cherry-pick detected by message content. Skipping validation."
      return 0
    fi
  done
  if [[ "$header" =~ $SQUASH_PATTERN ]]; then
    report_ok "Squash commit detected by message content. Skipping validation."
    return 0
  fi

  # --- Header: "<type>: <description>" ---
  # Join the allowed types with '|' inside a subshell so IFS is not leaked.
  types_regex=$(
    IFS='|'
    printf '%s' "${ALLOWED_TYPES[*]}"
  )
  header_regex="^($types_regex): .+$"
  if ! [[ "$header" =~ $header_regex ]]; then
    report_error "Invalid header format.\n\nHeader must be in the format: <type>: <short description>\nAllowed types: ${ALLOWED_TYPES[*]}\nExample: feat: add new user authentication feature"
    return 1
  fi

  # --- Footer: required for every type except 'chore' ---
  type=${header%%:*}
  if [[ "$type" != "chore" ]]; then
    footer=${MSG_LINES[$((count - 1))]}
    if ! [[ "$footer" =~ $FOOTER_REGEX ]]; then
      report_error "Invalid footer format.\n\nFooter must be in the format: [ref/close]: <issue identifier>\nExample: ref: N25B-123"
      return 1
    fi
  fi

  # --- Separator ---
  # Header + footer = 2 lines. Anything longer has a body (or at least a
  # separator), and then the second line must be blank.
  if ((count > 2)) && [[ -n "${MSG_LINES[1]}" ]]; then
    report_error "Missing blank line between header and body.\n\nThe second line of your commit message must be empty if a body is present."
    return 1
  fi

  report_ok "Commit message is valid."
  return 0
}

#######################################
# Hook entry point.
# Arguments: $1 - path to the commit message file (supplied by git)
# Returns:   0 to allow the commit, 1 to abort it
#######################################
main() {
  validate_commit_message_file "${1:-}"
}

# Keep this as the last command: its return status is the hook's exit status.
main "$@"
- - id: ruff-format \ No newline at end of file + - repo: https://github.com/astral-sh/ruff-pre-commit + # Ruff version. + rev: v0.14.2 + hooks: + # Run the linter. + - id: ruff-check + # Run the formatter. + - id: ruff-format + # Configure local hooks + - repo: local + hooks: + - id: check-commit-msg + name: Check commit message format + entry: .githooks/check-commit-msg.sh + language: script + stages: [commit-msg] + - id: check-branch-name + name: Check branch name format + entry: .githooks/check-branch-name.sh + language: script + stages: [pre-commit] + always_run: true + pass_filenames: false diff --git a/README.md b/README.md index 45f8f98..d20b36d 100644 --- a/README.md +++ b/README.md @@ -47,21 +47,19 @@ Or for integration tests: uv run --group integration-test pytest test/integration ``` -## GitHooks +## Git Hooks -To activate automatic commits/branch name checks run: +To activate automatic linting, formatting, branch name checks and commit message checks, run: -```shell -git config --local core.hooksPath .githooks +```bash +uv run pre-commit install +uv run pre-commit install --hook-type commit-msg ``` -If your commit fails its either: -branch name != /description-of-branch , -commit name != : description of the commit. - : N25B-Num's +You might get an error along the lines of `Can't install pre-commit with core.hooksPath` set. To fix this, simply unset the hooksPath by running: -To add automatic linting and formatting, run: +```bash +git config --local --unset core.hooksPath +``` -```shell -uv run pre-commit install -``` \ No newline at end of file +Then run the pre-commit install commands again. 
diff --git a/src/control_backend/agents/bdi/behaviours/receive_llm_resp_behaviour.py b/src/control_backend/agents/bdi/behaviours/receive_llm_resp_behaviour.py index dc6e862..71e69c6 100644 --- a/src/control_backend/agents/bdi/behaviours/receive_llm_resp_behaviour.py +++ b/src/control_backend/agents/bdi/behaviours/receive_llm_resp_behaviour.py @@ -1,8 +1,10 @@ import logging from spade.behaviour import CyclicBehaviour +from spade.message import Message from control_backend.core.config import settings +from control_backend.schemas.ri_message import SpeechCommand class ReceiveLLMResponseBehaviour(CyclicBehaviour): @@ -10,7 +12,7 @@ class ReceiveLLMResponseBehaviour(CyclicBehaviour): Adds behavior to receive responses from the LLM Agent. """ - logger = logging.getLogger("BDI/LLM Reciever") + logger = logging.getLogger("BDI/LLM Receiver") async def run(self): msg = await self.receive(timeout=2) @@ -22,7 +24,20 @@ class ReceiveLLMResponseBehaviour(CyclicBehaviour): case settings.agent_settings.llm_agent_name: content = msg.body self.logger.info("Received LLM response: %s", content) - # Here the BDI can pass the message back as a response + + speech_command = SpeechCommand(data=content) + + message = Message( + to=settings.agent_settings.ri_command_agent_name + + "@" + + settings.agent_settings.host, + sender=self.agent.jid, + body=speech_command.model_dump_json(), + ) + + self.logger.debug("Sending message: %s", message) + + await self.send(message) case _: self.logger.debug("Not from the llm, discarding message") pass diff --git a/src/control_backend/agents/bdi/behaviours/text_belief_extractor.py b/src/control_backend/agents/bdi/behaviours/text_belief_extractor.py index 913bc44..ed06463 100644 --- a/src/control_backend/agents/bdi/behaviours/text_belief_extractor.py +++ b/src/control_backend/agents/bdi/behaviours/text_belief_extractor.py @@ -13,23 +13,23 @@ class BeliefFromText(CyclicBehaviour): # TODO: LLM prompt nog hardcoded llm_instruction_prompt = """ - You are an 
information extraction assistent for a BDI agent. - Your task is to extract values from a user's text to bind a list of ungrounded beliefs. Rules: - You will receive a JSON object with "beliefs" - (a list of ungrounded AgentSpeak beliefs) and "text" (user's transcript). + You are an information extraction assistent for a BDI agent. Your task is to extract values \ + from a user's text to bind a list of ungrounded beliefs. Rules: + You will receive a JSON object with "beliefs" (a list of ungrounded AgentSpeak beliefs) \ + and "text" (user's transcript). Analyze the text to find values that sematically match the variables (X,Y,Z) in the beliefs. A single piece of text might contain multiple instances that match a belief. Respond ONLY with a single JSON object. The JSON object's keys should be the belief functors (e.g., "weather"). The value for each key must be a list of lists. - Each inner list must contain the extracted arguments - (as strings) for one instance of that belief. - CRITICAL: If no information in the text matches a belief, - DO NOT include that key in your response. + Each inner list must contain the extracted arguments (as strings) for one instance \ + of that belief. + CRITICAL: If no information in the text matches a belief, DO NOT include that key \ + in your response. 
""" - # on_start agent receives message containing the beliefs to look out - # for and sets up the LLM with instruction prompt + # on_start agent receives message containing the beliefs to look out for and + # sets up the LLM with instruction prompt # async def on_start(self): # msg = await self.receive(timeout=0.1) # self.beliefs = dict uit message diff --git a/src/control_backend/agents/belief_collector/behaviours/continuous_collect.py b/src/control_backend/agents/belief_collector/behaviours/continuous_collect.py index ada1c7a..eb3ee5d 100644 --- a/src/control_backend/agents/belief_collector/behaviours/continuous_collect.py +++ b/src/control_backend/agents/belief_collector/behaviours/continuous_collect.py @@ -70,8 +70,7 @@ class ContinuousBeliefCollector(CyclicBehaviour): Expected payload: { "type": "belief_extraction_text", - "beliefs": {"user_said": ["hello"","Can you help me?", - "stop talking to me","No","Pepper do a dance"]} + "beliefs": {"user_said": ["Can you help me?"]} } diff --git a/src/control_backend/agents/llm/llm.py b/src/control_backend/agents/llm/llm.py index c3c17ab..4487b23 100644 --- a/src/control_backend/agents/llm/llm.py +++ b/src/control_backend/agents/llm/llm.py @@ -3,8 +3,10 @@ LLM Agent module for routing text queries from the BDI Core Agent to a local LLM service and returning its responses back to the BDI Core Agent. """ +import json import logging -from typing import Any +import re +from collections.abc import AsyncGenerator import httpx from spade.agent import Agent @@ -54,11 +56,16 @@ class LLMAgent(Agent): async def _process_bdi_message(self, message: Message): """ - Forwards user text to the LLM and replies with the generated text. + Forwards user text from the BDI to the LLM and replies with the generated text in chunks + separated by punctuation. 
""" user_text = message.body - llm_response = await self._query_llm(user_text) - await self._reply(llm_response) + # Consume the streaming generator and send a reply for every chunk + async for chunk in self._query_llm(user_text): + await self._reply(chunk) + self.agent.logger.debug( + "Finished processing BDI message. Response sent in chunks to BDI Core Agent." + ) async def _reply(self, msg: str): """ @@ -69,48 +76,89 @@ class LLMAgent(Agent): body=msg, ) await self.send(reply) - self.agent.logger.info("Reply sent to BDI Core Agent") - async def _query_llm(self, prompt: str) -> str: + async def _query_llm(self, prompt: str) -> AsyncGenerator[str]: """ - Sends a chat completion request to the local LLM service. + Sends a chat completion request to the local LLM service and streams the response by + yielding fragments separated by punctuation like. :param prompt: Input text prompt to pass to the LLM. - :return: LLM-generated content or fallback message. + :yield: Fragments of the LLM-generated content. """ - async with httpx.AsyncClient(timeout=120.0) as client: - # Example dynamic content for future (optional) + instructions = LLMInstructions( + "- Be friendly and respectful.\n" + "- Make the conversation feel natural and engaging.\n" + "- Speak like a pirate.\n" + "- When the user asks what you can do, tell them.", + "- Try to learn the user's name during conversation.\n" + "- Suggest playing a game of asking yes or no questions where you think of a word " + "and the user must guess it.", + ) + messages = [ + { + "role": "developer", + "content": instructions.build_developer_instruction(), + }, + { + "role": "user", + "content": prompt, + }, + ] - instructions = LLMInstructions() - developer_instruction = instructions.build_developer_instruction() + try: + current_chunk = "" + async for token in self._stream_query_llm(messages): + current_chunk += token - response = await client.post( + # Stream the message in chunks separated by punctuation. 
+ # We include the delimiter in the emitted chunk for natural flow. + pattern = re.compile(r".*?(?:,|;|:|—|–|\.{3}|…|\.|\?|!)\s*", re.DOTALL) + for m in pattern.finditer(current_chunk): + chunk = m.group(0) + if chunk: + yield current_chunk + current_chunk = "" + + # Yield any remaining tail + if current_chunk: + yield current_chunk + except httpx.HTTPError as err: + self.agent.logger.error("HTTP error.", exc_info=err) + yield "LLM service unavailable." + except Exception as err: + self.agent.logger.error("Unexpected error.", exc_info=err) + yield "Error processing the request." + + async def _stream_query_llm(self, messages) -> AsyncGenerator[str]: + """Raises httpx.HTTPError when the API gives an error.""" + async with httpx.AsyncClient(timeout=None) as client: + async with client.stream( + "POST", settings.llm_settings.local_llm_url, - headers={"Content-Type": "application/json"}, json={ "model": settings.llm_settings.local_llm_model, - "messages": [ - {"role": "developer", "content": developer_instruction}, - {"role": "user", "content": prompt}, - ], + "messages": messages, "temperature": 0.3, + "stream": True, }, - ) - - try: + ) as response: response.raise_for_status() - data: dict[str, Any] = response.json() - return ( - data.get("choices", [{}])[0] - .get("message", {}) - .get("content", "No response") - ) - except httpx.HTTPError as err: - self.agent.logger.error("HTTP error: %s", err) - return "LLM service unavailable." - except Exception as err: - self.agent.logger.error("Unexpected error: %s", err) - return "Error processing the request." 
+ + async for line in response.aiter_lines(): + if not line or not line.startswith("data: "): + continue + + data = line[len("data: ") :] + if data.strip() == "[DONE]": + break + + try: + event = json.loads(data) + delta = event.get("choices", [{}])[0].get("delta", {}).get("content") + if delta: + yield delta + except json.JSONDecodeError: + self.agent.logger.error("Failed to parse LLM response: %s", data) async def setup(self): """ diff --git a/src/control_backend/agents/llm/llm_instructions.py b/src/control_backend/agents/llm/llm_instructions.py index 9636d88..6922fca 100644 --- a/src/control_backend/agents/llm/llm_instructions.py +++ b/src/control_backend/agents/llm/llm_instructions.py @@ -28,7 +28,9 @@ class LLMInstructions: """ sections = [ "You are a Pepper robot engaging in natural human conversation.", - "Keep responses between 1–5 sentences, unless instructed otherwise.\n", + "Keep responses between 1–3 sentences, unless told otherwise.\n", + "You're given goals to reach. Reach them in order, but make the conversation feel " + "natural. 
Some turns you should not try to achieve your goals.\n", ] if self.norms: diff --git a/src/control_backend/agents/mock_agents/belief_text_mock.py b/src/control_backend/agents/mock_agents/belief_text_mock.py index 769b263..27c5e49 100644 --- a/src/control_backend/agents/mock_agents/belief_text_mock.py +++ b/src/control_backend/agents/mock_agents/belief_text_mock.py @@ -11,8 +11,9 @@ class BeliefTextAgent(Agent): class SendOnceBehaviourBlfText(OneShotBehaviour): async def run(self): to_jid = ( - f"{settings.agent_settings.belief_collector_agent_name}" - f"@{settings.agent_settings.host}" + settings.agent_settings.belief_collector_agent_name + + "@" + + settings.agent_settings.host ) # Send multiple beliefs in one JSON payload diff --git a/src/control_backend/agents/ri_command_agent.py b/src/control_backend/agents/ri_command_agent.py index 51e148f..98d3ef3 100644 --- a/src/control_backend/agents/ri_command_agent.py +++ b/src/control_backend/agents/ri_command_agent.py @@ -1,6 +1,7 @@ import json import logging +import spade.agent import zmq from spade.agent import Agent from spade.behaviour import CyclicBehaviour @@ -32,6 +33,8 @@ class RICommandAgent(Agent): self.bind = bind class SendCommandsBehaviour(CyclicBehaviour): + """Behaviour for sending commands received from the UI.""" + async def run(self): """ Run the command publishing loop indefinetely. 
@@ -50,6 +53,18 @@ class RICommandAgent(Agent): except Exception as e: logger.error("Error processing message: %s", e) + class SendPythonCommandsBehaviour(CyclicBehaviour): + """Behaviour for sending commands received from other Python agents.""" + + async def run(self): + message: spade.agent.Message = await self.receive(timeout=0.1) + if message and message.to == self.agent.jid: + try: + speech_command = SpeechCommand.model_validate_json(message.body) + await self.agent.pubsocket.send_json(speech_command.model_dump()) + except Exception as e: + logger.error("Error processing message: %s", e) + async def setup(self): """ Setup the command agent @@ -73,5 +88,6 @@ class RICommandAgent(Agent): # Add behaviour to our agent commands_behaviour = self.SendCommandsBehaviour() self.add_behaviour(commands_behaviour) + self.add_behaviour(self.SendPythonCommandsBehaviour()) logger.info("Finished setting up %s", self.jid) diff --git a/src/control_backend/agents/ri_communication_agent.py b/src/control_backend/agents/ri_communication_agent.py index 9d9170f..d467203 100644 --- a/src/control_backend/agents/ri_communication_agent.py +++ b/src/control_backend/agents/ri_communication_agent.py @@ -63,7 +63,25 @@ class RICommunicationAgent(Agent): # We didnt get a reply :( except TimeoutError: logger.info("No ping retrieved in 3 seconds, killing myself.") - self.kill() + + # Tell UI we're disconnected. + topic = b"ping" + data = json.dumps(False).encode() + if self.agent.pub_socket is None: + logger.error("communication agent pub socket not correctly initialized.") + else: + try: + await asyncio.wait_for( + self.agent.pub_socket.send_multipart([topic, data]), 5 + ) + except TimeoutError: + logger.error( + "Initial connection ping for router timed" + " out in ri_communication_agent." + ) + + # Try to reboot. 
+ self.agent.setup() logger.debug('Received message "%s"', message) if "endpoint" not in message: diff --git a/src/control_backend/agents/transcription/speech_recognizer.py b/src/control_backend/agents/transcription/speech_recognizer.py index 19d82ff..527d371 100644 --- a/src/control_backend/agents/transcription/speech_recognizer.py +++ b/src/control_backend/agents/transcription/speech_recognizer.py @@ -36,16 +36,16 @@ class SpeechRecognizer(abc.ABC): def _estimate_max_tokens(audio: np.ndarray) -> int: """ Estimate the maximum length of a given audio sample in tokens. Assumes a maximum speaking - rate of 300 words per minute (2x average), and assumes that 3 words is 4 tokens. + rate of 450 words per minute (3x average), and assumes that 3 words is 4 tokens. :param audio: The audio sample (16 kHz) to use for length estimation. :return: The estimated length of the transcribed audio in tokens. """ length_seconds = len(audio) / 16_000 length_minutes = length_seconds / 60 - word_count = length_minutes * 300 + word_count = length_minutes * 450 token_count = word_count / 3 * 4 - return int(token_count) + return int(token_count) + 10 def _get_decode_options(self, audio: np.ndarray) -> dict: """ @@ -85,9 +85,10 @@ class MLXWhisperSpeechRecognizer(SpeechRecognizer): def recognize_speech(self, audio: np.ndarray) -> str: self.load_model() return mlx_whisper.transcribe( - audio, path_or_hf_repo=self.model_name, decode_options=self._get_decode_options(audio) - )["text"] - return mlx_whisper.transcribe(audio, path_or_hf_repo=self.model_name)["text"].strip() + audio, + path_or_hf_repo=self.model_name, + **self._get_decode_options(audio), + )["text"].strip() class OpenAIWhisperSpeechRecognizer(SpeechRecognizer): @@ -103,6 +104,4 @@ class OpenAIWhisperSpeechRecognizer(SpeechRecognizer): def recognize_speech(self, audio: np.ndarray) -> str: self.load_model() - return whisper.transcribe( - self.model, audio, decode_options=self._get_decode_options(audio) - )["text"] + return 
whisper.transcribe(self.model, audio, **self._get_decode_options(audio))["text"] diff --git a/src/control_backend/agents/transcription/transcription_agent.py b/src/control_backend/agents/transcription/transcription_agent.py index 530bd68..25103a4 100644 --- a/src/control_backend/agents/transcription/transcription_agent.py +++ b/src/control_backend/agents/transcription/transcription_agent.py @@ -58,6 +58,10 @@ class TranscriptionAgent(Agent): audio = await self.audio_in_socket.recv() audio = np.frombuffer(audio, dtype=np.float32) speech = await self._transcribe(audio) + if not speech: + logger.info("Nothing transcribed.") + return + logger.info("Transcribed speech: %s", speech) await self._share_transcription(speech) diff --git a/src/control_backend/agents/vad_agent.py b/src/control_backend/agents/vad_agent.py index f16abf4..9cf2adf 100644 --- a/src/control_backend/agents/vad_agent.py +++ b/src/control_backend/agents/vad_agent.py @@ -54,8 +54,20 @@ class Streaming(CyclicBehaviour): self.audio_buffer = np.array([], dtype=np.float32) self.i_since_speech = 100 # Used to allow small pauses in speech + self._ready = False + + async def reset(self): + """Clears the ZeroMQ queue and tells this behavior to start.""" + discarded = 0 + while await self.audio_in_poller.poll(1) is not None: + discarded += 1 + logging.info(f"Discarded {discarded} audio packets before starting.") + self._ready = True async def run(self) -> None: + if not self._ready: + return + data = await self.audio_in_poller.poll() if data is None: if len(self.audio_buffer) > 0: @@ -107,6 +119,8 @@ class VADAgent(Agent): self.audio_in_socket: azmq.Socket | None = None self.audio_out_socket: azmq.Socket | None = None + self.streaming_behaviour: Streaming | None = None + async def stop(self): """ Stop listening to audio, stop publishing audio, close sockets. 
@@ -149,8 +163,8 @@ class VADAgent(Agent): return audio_out_address = f"tcp://localhost:{audio_out_port}" - streaming = Streaming(self.audio_in_socket, self.audio_out_socket) - self.add_behaviour(streaming) + self.streaming_behaviour = Streaming(self.audio_in_socket, self.audio_out_socket) + self.add_behaviour(self.streaming_behaviour) # Start agents dependent on the output audio fragments here transcriber = TranscriptionAgent(audio_out_address) diff --git a/src/control_backend/api/v1/endpoints/robot.py b/src/control_backend/api/v1/endpoints/robot.py index 7b1c2f8..b2ca053 100644 --- a/src/control_backend/api/v1/endpoints/robot.py +++ b/src/control_backend/api/v1/endpoints/robot.py @@ -22,8 +22,8 @@ async def receive_command(command: SpeechCommand, request: Request): topic = b"command" # TODO: Check with Kasper - pub_socket: Socket = request.app.state.internal_comm_socket - pub_socket.send_multipart([topic, command.model_dump_json().encode()]) + pub_socket: Socket = request.app.state.endpoints_pub_socket + await pub_socket.send_multipart([topic, command.model_dump_json().encode()]) return {"status": "Command received"} diff --git a/src/control_backend/main.py b/src/control_backend/main.py index b8e3ef3..5409f75 100644 --- a/src/control_backend/main.py +++ b/src/control_backend/main.py @@ -14,8 +14,6 @@ from control_backend.agents.bdi.bdi_core import BDICoreAgent from control_backend.agents.bdi.text_extractor import TBeliefExtractor from control_backend.agents.belief_collector.belief_collector import BeliefCollectorAgent from control_backend.agents.llm.llm import LLMAgent - -# Internal imports from control_backend.agents.ri_communication_agent import RICommunicationAgent from control_backend.agents.vad_agent import VADAgent from control_backend.api.v1.router import api_router @@ -99,6 +97,8 @@ async def lifespan(app: FastAPI): _temp_vad_agent = VADAgent("tcp://localhost:5558", False) await _temp_vad_agent.start() + logger.info("VAD agent started, now making 
ready...") + await _temp_vad_agent.streaming_behaviour.reset() yield diff --git a/test/integration/agents/test_ri_commands_agent.py b/test/integration/agents/test_ri_commands_agent.py index a4902b5..ea9fca9 100644 --- a/test/integration/agents/test_ri_commands_agent.py +++ b/test/integration/agents/test_ri_commands_agent.py @@ -7,25 +7,21 @@ import zmq from control_backend.agents.ri_command_agent import RICommandAgent -@pytest.mark.asyncio -async def test_setup_bind(monkeypatch): - """Test setup with bind=True""" - fake_socket = MagicMock() - fake_context = MagicMock() - fake_context.socket.return_value = fake_socket +@pytest.fixture +def zmq_context(mocker): + mock_context = mocker.patch("control_backend.agents.vad_agent.azmq.Context.instance") + mock_context.return_value = MagicMock() + return mock_context - # Patch Context.instance() to return fake_context - monkeypatch.setattr( - "control_backend.agents.ri_command_agent.Context", - MagicMock(instance=MagicMock(return_value=fake_context)), - ) + +@pytest.mark.asyncio +async def test_setup_bind(zmq_context, mocker): + """Test setup with bind=True""" + fake_socket = zmq_context.return_value.socket.return_value agent = RICommandAgent("test@server", "password", address="tcp://localhost:5555", bind=True) - - monkeypatch.setattr( - "control_backend.agents.ri_command_agent.settings", - MagicMock(zmq_settings=MagicMock(internal_sub_address="tcp://internal:1234")), - ) + settings = mocker.patch("control_backend.agents.ri_command_agent.settings") + settings.zmq_settings.internal_sub_address = "tcp://internal:1234" await agent.setup() @@ -36,23 +32,13 @@ async def test_setup_bind(monkeypatch): @pytest.mark.asyncio -async def test_setup_connect(monkeypatch): +async def test_setup_connect(zmq_context, mocker): """Test setup with bind=False""" - fake_socket = MagicMock() - fake_context = MagicMock() - fake_context.socket.return_value = fake_socket - - # Patch Context.instance() to return fake_context - monkeypatch.setattr( - 
"control_backend.agents.ri_command_agent.Context", - MagicMock(instance=MagicMock(return_value=fake_context)), - ) + fake_socket = zmq_context.return_value.socket.return_value agent = RICommandAgent("test@server", "password", address="tcp://localhost:5555", bind=False) - monkeypatch.setattr( - "control_backend.agents.ri_command_agent.settings", - MagicMock(zmq_settings=MagicMock(internal_sub_address="tcp://internal:1234")), - ) + settings = mocker.patch("control_backend.agents.ri_command_agent.settings") + settings.zmq_settings.internal_sub_address = "tcp://internal:1234" await agent.setup() diff --git a/test/integration/agents/test_ri_communication_agent.py b/test/integration/agents/test_ri_communication_agent.py index 9febf20..33051c8 100644 --- a/test/integration/agents/test_ri_communication_agent.py +++ b/test/integration/agents/test_ri_communication_agent.py @@ -84,25 +84,24 @@ def fake_json_invalid_id_negototiate(): ) +@pytest.fixture +def zmq_context(mocker): + mock_context = mocker.patch("control_backend.agents.vad_agent.azmq.Context.instance") + mock_context.return_value = MagicMock() + return mock_context + + @pytest.mark.asyncio -async def test_setup_creates_socket_and_negotiate_1(monkeypatch): +async def test_setup_creates_socket_and_negotiate_1(zmq_context): """ Test the setup of the communication agent """ # --- Arrange --- - fake_socket = MagicMock() + fake_socket = zmq_context.return_value.socket.return_value fake_socket.send_json = AsyncMock() fake_socket.recv_json = fake_json_correct_negototiate_1() fake_socket.send_multipart = AsyncMock() - # Mock context.socket to return our fake socket - fake_context = MagicMock() - fake_context.socket.return_value = fake_socket - monkeypatch.setattr( - "control_backend.agents.ri_communication_agent.Context", - MagicMock(instance=MagicMock(return_value=fake_context)), - ) - # Mock RICommandAgent agent startup with patch( "control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True @@ -135,24 
+134,16 @@ async def test_setup_creates_socket_and_negotiate_1(monkeypatch): @pytest.mark.asyncio -async def test_setup_creates_socket_and_negotiate_2(monkeypatch): +async def test_setup_creates_socket_and_negotiate_2(zmq_context): """ Test the setup of the communication agent """ # --- Arrange --- - fake_socket = MagicMock() + fake_socket = zmq_context.return_value.socket.return_value fake_socket.send_json = AsyncMock() fake_socket.recv_json = fake_json_correct_negototiate_2() fake_socket.send_multipart = AsyncMock() - # Mock context.socket to return our fake socket - fake_context = MagicMock() - fake_context.socket.return_value = fake_socket - monkeypatch.setattr( - "control_backend.agents.ri_communication_agent.Context", - MagicMock(instance=MagicMock(return_value=fake_context)), - ) - # Mock RICommandAgent agent startup with patch( "control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True @@ -185,24 +176,16 @@ async def test_setup_creates_socket_and_negotiate_2(monkeypatch): @pytest.mark.asyncio -async def test_setup_creates_socket_and_negotiate_3(monkeypatch, caplog): +async def test_setup_creates_socket_and_negotiate_3(zmq_context, caplog): """ Test the functionality of setup with incorrect negotiation message """ # --- Arrange --- - fake_socket = MagicMock() + fake_socket = zmq_context.return_value.socket.return_value fake_socket.send_json = AsyncMock() fake_socket.recv_json = fake_json_wrong_negototiate_1() fake_socket.send_multipart = AsyncMock() - # Mock context.socket to return our fake socket - fake_context = MagicMock() - fake_context.socket.return_value = fake_socket - monkeypatch.setattr( - "control_backend.agents.ri_communication_agent.Context", - MagicMock(instance=MagicMock(return_value=fake_context)), - ) - # Mock RICommandAgent agent startup # We are sending wrong negotiation info to the communication agent, @@ -235,24 +218,16 @@ async def test_setup_creates_socket_and_negotiate_3(monkeypatch, caplog): @pytest.mark.asyncio 
-async def test_setup_creates_socket_and_negotiate_4(monkeypatch): +async def test_setup_creates_socket_and_negotiate_4(zmq_context): """ Test the setup of the communication agent with different bind value """ # --- Arrange --- - fake_socket = MagicMock() + fake_socket = zmq_context.return_value.socket.return_value fake_socket.send_json = AsyncMock() fake_socket.recv_json = fake_json_correct_negototiate_3() fake_socket.send_multipart = AsyncMock() - # Mock context.socket to return our fake socket - fake_context = MagicMock() - fake_context.socket.return_value = fake_socket - monkeypatch.setattr( - "control_backend.agents.ri_communication_agent.Context", - MagicMock(instance=MagicMock(return_value=fake_context)), - ) - # Mock RICommandAgent agent startup with patch( "control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True @@ -284,24 +259,16 @@ async def test_setup_creates_socket_and_negotiate_4(monkeypatch): @pytest.mark.asyncio -async def test_setup_creates_socket_and_negotiate_5(monkeypatch): +async def test_setup_creates_socket_and_negotiate_5(zmq_context): """ Test the setup of the communication agent """ # --- Arrange --- - fake_socket = MagicMock() + fake_socket = zmq_context.return_value.socket.return_value fake_socket.send_json = AsyncMock() fake_socket.recv_json = fake_json_correct_negototiate_4() fake_socket.send_multipart = AsyncMock() - # Mock context.socket to return our fake socket - fake_context = MagicMock() - fake_context.socket.return_value = fake_socket - monkeypatch.setattr( - "control_backend.agents.ri_communication_agent.Context", - MagicMock(instance=MagicMock(return_value=fake_context)), - ) - # Mock RICommandAgent agent startup with patch( "control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True @@ -333,24 +300,16 @@ async def test_setup_creates_socket_and_negotiate_5(monkeypatch): @pytest.mark.asyncio -async def test_setup_creates_socket_and_negotiate_6(monkeypatch): +async def 
test_setup_creates_socket_and_negotiate_6(zmq_context): """ Test the setup of the communication agent """ # --- Arrange --- - fake_socket = MagicMock() + fake_socket = zmq_context.return_value.socket.return_value fake_socket.send_json = AsyncMock() fake_socket.recv_json = fake_json_correct_negototiate_5() fake_socket.send_multipart = AsyncMock() - # Mock context.socket to return our fake socket - fake_context = MagicMock() - fake_context.socket.return_value = fake_socket - monkeypatch.setattr( - "control_backend.agents.ri_communication_agent.Context", - MagicMock(instance=MagicMock(return_value=fake_context)), - ) - # Mock RICommandAgent agent startup with patch( "control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True @@ -382,28 +341,20 @@ async def test_setup_creates_socket_and_negotiate_6(monkeypatch): @pytest.mark.asyncio -async def test_setup_creates_socket_and_negotiate_7(monkeypatch, caplog): +async def test_setup_creates_socket_and_negotiate_7(zmq_context, caplog): """ Test the functionality of setup with incorrect id """ # --- Arrange --- - fake_socket = MagicMock() + fake_socket = zmq_context.return_value.socket.return_value fake_socket.send_json = AsyncMock() fake_socket.recv_json = fake_json_invalid_id_negototiate() fake_socket.send_multipart = AsyncMock() - # Mock context.socket to return our fake socket - fake_context = MagicMock() - fake_context.socket.return_value = fake_socket - monkeypatch.setattr( - "control_backend.agents.ri_communication_agent.Context", - MagicMock(instance=MagicMock(return_value=fake_context)), - ) - # Mock RICommandAgent agent startup # We are sending wrong negotiation info to the communication agent, - # so we should retry and expect a etter response, within a limited time. + # so we should retry and expect a better response, within a limited time. 
with patch( "control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True ) as MockCommandAgent: @@ -430,24 +381,16 @@ async def test_setup_creates_socket_and_negotiate_7(monkeypatch, caplog): @pytest.mark.asyncio -async def test_setup_creates_socket_and_negotiate_timeout(monkeypatch, caplog): +async def test_setup_creates_socket_and_negotiate_timeout(zmq_context, caplog): """ Test the functionality of setup with incorrect negotiation message """ # --- Arrange --- - fake_socket = MagicMock() + fake_socket = zmq_context.return_value.socket.return_value fake_socket.send_json = AsyncMock() fake_socket.recv_json = AsyncMock(side_effect=asyncio.TimeoutError) fake_socket.send_multipart = AsyncMock() - # Mock context.socket to return our fake socket - fake_context = MagicMock() - fake_context.socket.return_value = fake_socket - monkeypatch.setattr( - "control_backend.agents.ri_communication_agent.Context", - MagicMock(instance=MagicMock(return_value=fake_context)), - ) - with patch( "control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True ) as MockCommandAgent: @@ -534,8 +477,8 @@ async def test_listen_behaviour_ping_wrong_endpoint(caplog): @pytest.mark.asyncio -async def test_listen_behaviour_timeout(caplog): - fake_socket = AsyncMock() +async def test_listen_behaviour_timeout(zmq_context, caplog): + fake_socket = zmq_context.return_value.socket.return_value fake_socket.send_json = AsyncMock() # recv_json will never resolve, simulate timeout fake_socket.recv_json = AsyncMock(side_effect=asyncio.TimeoutError) @@ -585,20 +528,13 @@ async def test_listen_behaviour_ping_no_endpoint(caplog): @pytest.mark.asyncio -async def test_setup_unexpected_exception(monkeypatch, caplog): - fake_socket = MagicMock() +async def test_setup_unexpected_exception(zmq_context, caplog): + fake_socket = zmq_context.return_value.socket.return_value fake_socket.send_json = AsyncMock() # Simulate unexpected exception during recv_json() fake_socket.recv_json = 
AsyncMock(side_effect=Exception("boom!")) fake_socket.send_multipart = AsyncMock() - fake_context = MagicMock() - fake_context.socket.return_value = fake_socket - monkeypatch.setattr( - "control_backend.agents.ri_communication_agent.Context", - MagicMock(instance=MagicMock(return_value=fake_context)), - ) - agent = RICommunicationAgent( "test@server", "password", @@ -614,9 +550,9 @@ async def test_setup_unexpected_exception(monkeypatch, caplog): @pytest.mark.asyncio -async def test_setup_unpacking_exception(monkeypatch, caplog): +async def test_setup_unpacking_exception(zmq_context, caplog): # --- Arrange --- - fake_socket = MagicMock() + fake_socket = zmq_context.return_value.socket.return_value fake_socket.send_json = AsyncMock() fake_socket.send_multipart = AsyncMock() @@ -627,14 +563,6 @@ async def test_setup_unpacking_exception(monkeypatch, caplog): } # missing 'port' and 'bind' fake_socket.recv_json = AsyncMock(return_value=malformed_data) - # Patch context.socket - fake_context = MagicMock() - fake_context.socket.return_value = fake_socket - monkeypatch.setattr( - "control_backend.agents.ri_communication_agent.Context", - MagicMock(instance=MagicMock(return_value=fake_context)), - ) - # Patch RICommandAgent so it won't actually start with patch( "control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True diff --git a/test/integration/agents/vad_agent/test_vad_agent.py b/test/integration/agents/vad_agent/test_vad_agent.py index 54c9d82..0e1fae2 100644 --- a/test/integration/agents/vad_agent/test_vad_agent.py +++ b/test/integration/agents/vad_agent/test_vad_agent.py @@ -10,7 +10,9 @@ from control_backend.agents.vad_agent import VADAgent @pytest.fixture def zmq_context(mocker): - return mocker.patch("control_backend.agents.vad_agent.zmq_context") + mock_context = mocker.patch("control_backend.agents.vad_agent.azmq.Context.instance") + mock_context.return_value = MagicMock() + return mock_context @pytest.fixture @@ -54,13 +56,18 @@ def 
test_in_socket_creation(zmq_context, do_bind: bool): assert vad_agent.audio_in_socket is not None - zmq_context.socket.assert_called_once_with(zmq.SUB) - zmq_context.socket.return_value.setsockopt_string.assert_called_once_with(zmq.SUBSCRIBE, "") + zmq_context.return_value.socket.assert_called_once_with(zmq.SUB) + zmq_context.return_value.socket.return_value.setsockopt_string.assert_called_once_with( + zmq.SUBSCRIBE, + "", + ) if do_bind: - zmq_context.socket.return_value.bind.assert_called_once_with("tcp://*:12345") + zmq_context.return_value.socket.return_value.bind.assert_called_once_with("tcp://*:12345") else: - zmq_context.socket.return_value.connect.assert_called_once_with("tcp://localhost:12345") + zmq_context.return_value.socket.return_value.connect.assert_called_once_with( + "tcp://localhost:12345" + ) def test_out_socket_creation(zmq_context): @@ -73,8 +80,8 @@ def test_out_socket_creation(zmq_context): assert vad_agent.audio_out_socket is not None - zmq_context.socket.assert_called_once_with(zmq.PUB) - zmq_context.socket.return_value.bind_to_random_port.assert_called_once() + zmq_context.return_value.socket.assert_called_once_with(zmq.PUB) + zmq_context.return_value.socket.return_value.bind_to_random_port.assert_called_once() @pytest.mark.asyncio @@ -83,7 +90,9 @@ async def test_out_socket_creation_failure(zmq_context): Test setup failure when the audio output socket cannot be created. """ with patch.object(Agent, "stop", new_callable=AsyncMock) as mock_super_stop: - zmq_context.socket.return_value.bind_to_random_port.side_effect = zmq.ZMQBindError + zmq_context.return_value.socket.return_value.bind_to_random_port.side_effect = ( + zmq.ZMQBindError + ) vad_agent = VADAgent("tcp://localhost:12345", False) await vad_agent.setup() @@ -98,11 +107,14 @@ async def test_stop(zmq_context, transcription_agent): Test that when the VAD agent is stopped, the sockets are closed correctly. 
""" vad_agent = VADAgent("tcp://localhost:12345", False) - zmq_context.socket.return_value.bind_to_random_port.return_value = random.randint(1000, 10000) + zmq_context.return_value.socket.return_value.bind_to_random_port.return_value = random.randint( + 1000, + 10000, + ) await vad_agent.setup() await vad_agent.stop() - assert zmq_context.socket.return_value.close.call_count == 2 + assert zmq_context.return_value.socket.return_value.close.call_count == 2 assert vad_agent.audio_in_socket is None assert vad_agent.audio_out_socket is None diff --git a/test/integration/agents/vad_agent/test_vad_with_audio.py b/test/integration/agents/vad_agent/test_vad_with_audio.py index 7d10aa3..fd7d4d7 100644 --- a/test/integration/agents/vad_agent/test_vad_with_audio.py +++ b/test/integration/agents/vad_agent/test_vad_with_audio.py @@ -48,6 +48,7 @@ async def test_real_audio(mocker): audio_out_socket = AsyncMock() vad_streamer = Streaming(audio_in_socket, audio_out_socket) + vad_streamer._ready = True for _ in audio_chunks: await vad_streamer.run() diff --git a/test/integration/api/endpoints/test_robot_endpoint.py b/test/integration/api/endpoints/test_robot_endpoint.py index 3fd175f..3a2df88 100644 --- a/test/integration/api/endpoints/test_robot_endpoint.py +++ b/test/integration/api/endpoints/test_robot_endpoint.py @@ -1,4 +1,4 @@ -from unittest.mock import MagicMock +from unittest.mock import AsyncMock import pytest from fastapi import FastAPI @@ -16,7 +16,6 @@ def app(): """ app = FastAPI() app.include_router(robot.router) - app.state.internal_comm_socket = MagicMock() # mock ZMQ socket return app @@ -26,32 +25,30 @@ def client(app): return TestClient(app) -def test_receive_command_endpoint(client, app): +def test_receive_command_success(client): """ - Test that a POST to /command sends the right multipart message - and returns a 202 with the expected JSON body. + Test for successful reception of a command. Ensures the status code is 202 and the response body + is correct. 
It also verifies that the ZeroMQ socket's send_multipart method is called with the + expected data. """ - mock_socket = app.state.internal_comm_socket + # Arrange + mock_pub_socket = AsyncMock() + client.app.state.endpoints_pub_socket = mock_pub_socket - # Prepare test payload that matches SpeechCommand - payload = {"endpoint": "actuate/speech", "data": "yooo"} + command_data = {"endpoint": "actuate/speech", "data": "This is a test"} + speech_command = SpeechCommand(**command_data) - # Send POST request - response = client.post("/command", json=payload) + # Act + response = client.post("/command", json=command_data) - # Check response + # Assert assert response.status_code == 202 assert response.json() == {"status": "Command received"} - # Verify that the socket was called with the correct data - assert mock_socket.send_multipart.called, "Socket should be used to send data" - - args, kwargs = mock_socket.send_multipart.call_args - sent_data = args[0] - - assert sent_data[0] == b"command" - # Check JSON encoding roughly matches - assert isinstance(SpeechCommand.model_validate_json(sent_data[1].decode()), SpeechCommand) + # Verify that the ZMQ socket was used correctly + mock_pub_socket.send_multipart.assert_awaited_once_with( + [b"command", speech_command.model_dump_json().encode()] + ) def test_receive_command_invalid_payload(client): diff --git a/test/integration/schemas/test_ri_message.py b/test/integration/schemas/test_ri_message.py index 966b582..5078f9a 100644 --- a/test/integration/schemas/test_ri_message.py +++ b/test/integration/schemas/test_ri_message.py @@ -16,12 +16,11 @@ def test_valid_speech_command_1(): command = valid_command_1() RIMessage.model_validate(command) SpeechCommand.model_validate(command) - assert True def test_invalid_speech_command_1(): command = invalid_command_1() RIMessage.model_validate(command) + with pytest.raises(ValidationError): SpeechCommand.model_validate(command) - assert True diff --git 
a/test/unit/agents/belief_collector/behaviours/test_continuous_collect.py b/test/unit/agents/belief_collector/behaviours/test_continuous_collect.py index 79957f0..e842f5c 100644 --- a/test/unit/agents/belief_collector/behaviours/test_continuous_collect.py +++ b/test/unit/agents/belief_collector/behaviours/test_continuous_collect.py @@ -182,8 +182,6 @@ async def test_belief_text_values_not_lists(continuous_collector, mocker): @pytest.mark.asyncio async def test_belief_text_happy_path_logs_items_and_sends(continuous_collector, mocker): payload = {"type": "belief_extraction_text", "beliefs": {"user_said": ["hello test", "No"]}} - # Your code calls self.send(..); patch it - # (or switch implementation to self.agent.send and patch that) continuous_collector.send = AsyncMock() logger_mock = mocker.patch( "control_backend.agents.belief_collector.behaviours.continuous_collect.logger" diff --git a/test/unit/agents/test_vad_streaming.py b/test/unit/agents/test_vad_streaming.py index 9b38cd0..ab2da0d 100644 --- a/test/unit/agents/test_vad_streaming.py +++ b/test/unit/agents/test_vad_streaming.py @@ -21,7 +21,9 @@ def streaming(audio_in_socket, audio_out_socket): import torch torch.hub.load.return_value = (..., ...) 
# Mock - return Streaming(audio_in_socket, audio_out_socket) + streaming = Streaming(audio_in_socket, audio_out_socket) + streaming._ready = True + return streaming async def simulate_streaming_with_probabilities(streaming, probabilities: list[float]): diff --git a/test/unit/agents/transcription/test_speech_recognizer.py b/test/unit/agents/transcription/test_speech_recognizer.py index 88a5ac2..ab28dcf 100644 --- a/test/unit/agents/transcription/test_speech_recognizer.py +++ b/test/unit/agents/transcription/test_speech_recognizer.py @@ -5,12 +5,13 @@ from control_backend.agents.transcription.speech_recognizer import OpenAIWhisper def test_estimate_max_tokens(): - """Inputting one minute of audio, assuming 300 words per minute, expecting 400 tokens.""" + """Inputting one minute of audio, assuming 450 words per minute and adding a 10 token padding, + expecting 610 tokens.""" audio = np.empty(shape=(60 * 16_000), dtype=np.float32) actual = SpeechRecognizer._estimate_max_tokens(audio) - assert actual == 400 + assert actual == 610 assert isinstance(actual, int)