Merge branch 'dev' into refactor/logging
This commit is contained in:
77
.githooks/check-branch-name.sh
Executable file
77
.githooks/check-branch-name.sh
Executable file
@@ -0,0 +1,77 @@
|
|||||||
|
#!/usr/bin/env bash
|
||||||
|
|
||||||
|
# This script checks if the current branch name follows the specified format.
|
||||||
|
# It's designed to be used as a 'pre-commit' git hook.
|
||||||
|
|
||||||
|
# Format: <type>/<short-description>
|
||||||
|
# Example: feat/add-user-login
|
||||||
|
|
||||||
|
# --- Configuration ---
|
||||||
|
# An array of allowed commit types
|
||||||
|
ALLOWED_TYPES=(feat fix refactor perf style test docs build chore revert)
|
||||||
|
# An array of branches to ignore
|
||||||
|
IGNORED_BRANCHES=(main dev)
|
||||||
|
|
||||||
|
# --- Colors for Output ---
|
||||||
|
RED='\033[0;31m'
|
||||||
|
GREEN='\033[0;32m'
|
||||||
|
YELLOW='\033[1;33m'
|
||||||
|
NC='\033[0m' # No Color
|
||||||
|
|
||||||
|
# --- Helper Functions ---
|
||||||
|
error_exit() {
|
||||||
|
echo -e "${RED}ERROR: $1${NC}" >&2
|
||||||
|
echo -e "${YELLOW}Branch name format is incorrect. Aborting commit.${NC}" >&2
|
||||||
|
exit 1
|
||||||
|
}
|
||||||
|
|
||||||
|
# --- Main Logic ---
|
||||||
|
|
||||||
|
# 1. Get the current branch name
|
||||||
|
BRANCH_NAME=$(git symbolic-ref --short HEAD)
|
||||||
|
|
||||||
|
# 2. Check if the current branch is in the ignored list
|
||||||
|
for ignored_branch in "${IGNORED_BRANCHES[@]}"; do
|
||||||
|
if [ "$BRANCH_NAME" == "$ignored_branch" ]; then
|
||||||
|
echo -e "${GREEN}Branch check skipped for default branch: $BRANCH_NAME${NC}"
|
||||||
|
exit 0
|
||||||
|
fi
|
||||||
|
done
|
||||||
|
|
||||||
|
# 3. Validate the overall structure: <type>/<description>
|
||||||
|
if ! [[ "$BRANCH_NAME" =~ ^[a-z]+/.+$ ]]; then
|
||||||
|
error_exit "Branch name must be in the format: <type>/<short-description>\nExample: feat/add-user-login"
|
||||||
|
fi
|
||||||
|
|
||||||
|
# 4. Extract the type and description
|
||||||
|
TYPE=$(echo "$BRANCH_NAME" | cut -d'/' -f1)
|
||||||
|
DESCRIPTION=$(echo "$BRANCH_NAME" | cut -d'/' -f2-)
|
||||||
|
|
||||||
|
# 5. Validate the <type>
|
||||||
|
type_valid=false
|
||||||
|
for allowed_type in "${ALLOWED_TYPES[@]}"; do
|
||||||
|
if [ "$TYPE" == "$allowed_type" ]; then
|
||||||
|
type_valid=true
|
||||||
|
break
|
||||||
|
fi
|
||||||
|
done
|
||||||
|
|
||||||
|
if [ "$type_valid" == false ]; then
|
||||||
|
error_exit "Invalid type '$TYPE'.\nAllowed types are: ${ALLOWED_TYPES[*]}"
|
||||||
|
fi
|
||||||
|
|
||||||
|
# 6. Validate the <short-description>
|
||||||
|
# Regex breakdown:
|
||||||
|
# ^[a-z0-9]+ - Starts with one or more lowercase letters/numbers (the first word).
|
||||||
|
# (-[a-z0-9]+){0,5} - Followed by a group of (dash + word) 0 to 5 times.
|
||||||
|
# $ - End of the string.
|
||||||
|
# This entire pattern enforces 1 to 6 words total, separated by dashes.
|
||||||
|
DESCRIPTION_REGEX="^[a-z0-9]+(-[a-z0-9]+){0,5}$"
|
||||||
|
|
||||||
|
if ! [[ "$DESCRIPTION" =~ $DESCRIPTION_REGEX ]]; then
|
||||||
|
error_exit "Invalid short description '$DESCRIPTION'.\nIt must be a maximum of 6 words, all lowercase, separated by dashes.\nExample: add-new-user-authentication-feature"
|
||||||
|
fi
|
||||||
|
|
||||||
|
# If all checks pass, exit successfully
|
||||||
|
echo -e "${GREEN}Branch name '$BRANCH_NAME' is valid.${NC}"
|
||||||
|
exit 0
|
||||||
135
.githooks/check-commit-msg.sh
Executable file
135
.githooks/check-commit-msg.sh
Executable file
@@ -0,0 +1,135 @@
|
|||||||
|
#!/usr/bin/env bash
|
||||||
|
|
||||||
|
# This script checks if a commit message follows the specified format.
|
||||||
|
# It's designed to be used as a 'commit-msg' git hook.
|
||||||
|
|
||||||
|
# Format:
|
||||||
|
# <type>: <short description>
|
||||||
|
#
|
||||||
|
# [optional]<body>
|
||||||
|
#
|
||||||
|
# [ref/close]: <issue identifier>
|
||||||
|
|
||||||
|
# --- Configuration ---
|
||||||
|
# An array of allowed commit types
|
||||||
|
ALLOWED_TYPES=(feat fix refactor perf style test docs build chore revert)
|
||||||
|
|
||||||
|
# --- Colors for Output ---
|
||||||
|
RED='\033[0;31m'
|
||||||
|
GREEN='\033[0;32m'
|
||||||
|
YELLOW='\033[1;33m'
|
||||||
|
NC='\033[0m' # No Color
|
||||||
|
|
||||||
|
# The first argument to the hook is the path to the file containing the commit message
|
||||||
|
COMMIT_MSG_FILE=$1
|
||||||
|
|
||||||
|
# --- Automated Commit Detection ---
|
||||||
|
|
||||||
|
# Read the first line (header) for initial checks
|
||||||
|
HEADER=$(head -n 1 "$COMMIT_MSG_FILE")
|
||||||
|
|
||||||
|
# Check for Merge commits (covers 'git merge' and PR merges from GitHub/GitLab)
|
||||||
|
# Examples: "Merge branch 'main' into ...", "Merge pull request #123 from ..."
|
||||||
|
MERGE_PATTERN="^Merge (branch|pull request|tag) .*"
|
||||||
|
if [[ "$HEADER" =~ $MERGE_PATTERN ]]; then
|
||||||
|
echo -e "${GREEN}Merge commit detected by message content. Skipping validation.${NC}"
|
||||||
|
exit 0
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Check for Revert commits
|
||||||
|
# Example: "Revert "feat: add new feature""
|
||||||
|
REVERT_PATTERN="^Revert \".*\""
|
||||||
|
if [[ "$HEADER" =~ $REVERT_PATTERN ]]; then
|
||||||
|
echo -e "${GREEN}Revert commit detected by message content. Skipping validation.${NC}"
|
||||||
|
exit 0
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Check for Cherry-pick commits (this pattern appears at the end of the message)
|
||||||
|
# Example: "(cherry picked from commit deadbeef...)"
|
||||||
|
# We use grep -q to search the whole file quietly.
|
||||||
|
CHERRY_PICK_PATTERN="\(cherry picked from commit [a-f0-9]{7,40}\)"
|
||||||
|
if grep -qE "$CHERRY_PICK_PATTERN" "$COMMIT_MSG_FILE"; then
|
||||||
|
echo -e "${GREEN}Cherry-pick detected by message content. Skipping validation.${NC}"
|
||||||
|
exit 0
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Check for Squash
|
||||||
|
# Example: "Squash commits ..."
|
||||||
|
SQUASH_PATTERN="^Squash .+"
|
||||||
|
if [[ "$HEADER" =~ $SQUASH_PATTERN ]]; then
|
||||||
|
echo -e "${GREEN}Squash commit detected by message content. Skipping validation.${NC}"
|
||||||
|
exit 0
|
||||||
|
fi
|
||||||
|
|
||||||
|
# --- Validation Functions ---
|
||||||
|
|
||||||
|
# Function to print an error message and exit
|
||||||
|
# Usage: error_exit "Your error message here"
|
||||||
|
error_exit() {
|
||||||
|
# >&2 redirects echo to stderr
|
||||||
|
echo -e "${RED}ERROR: $1${NC}" >&2
|
||||||
|
echo -e "${YELLOW}Commit message format is incorrect. Aborting commit.${NC}" >&2
|
||||||
|
exit 1
|
||||||
|
}
|
||||||
|
|
||||||
|
# --- Main Logic ---
|
||||||
|
|
||||||
|
# 1. Read the header (first line) of the commit message
|
||||||
|
HEADER=$(head -n 1 "$COMMIT_MSG_FILE")
|
||||||
|
|
||||||
|
# 2. Validate the header format: <type>: <description>
|
||||||
|
# Regex breakdown:
|
||||||
|
# ^(type1|type2|...) - Starts with one of the allowed types
|
||||||
|
# : - Followed by a literal colon
|
||||||
|
# \s - Followed by a single space
|
||||||
|
# .+ - Followed by one or more characters for the description
|
||||||
|
# $ - End of the line
|
||||||
|
TYPES_REGEX=$(
|
||||||
|
IFS="|"
|
||||||
|
echo "${ALLOWED_TYPES[*]}"
|
||||||
|
)
|
||||||
|
HEADER_REGEX="^($TYPES_REGEX): .+$"
|
||||||
|
|
||||||
|
if ! [[ "$HEADER" =~ $HEADER_REGEX ]]; then
|
||||||
|
error_exit "Invalid header format.\n\nHeader must be in the format: <type>: <short description>\nAllowed types: ${ALLOWED_TYPES[*]}\nExample: feat: add new user authentication feature"
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Only validate footer if commit type is not chore
|
||||||
|
TYPE=$(echo "$HEADER" | cut -d':' -f1)
|
||||||
|
if [ "$TYPE" != "chore" ]; then
|
||||||
|
# 3. Validate the footer (last line) of the commit message
|
||||||
|
FOOTER=$(tail -n 1 "$COMMIT_MSG_FILE")
|
||||||
|
|
||||||
|
# Regex breakdown:
|
||||||
|
# ^(ref|close) - Starts with 'ref' or 'close'
|
||||||
|
# : - Followed by a literal colon
|
||||||
|
# \s - Followed by a single space
|
||||||
|
# N25B- - Followed by the literal string 'N25B-'
|
||||||
|
# [0-9]+ - Followed by one or more digits
|
||||||
|
# $ - End of the line
|
||||||
|
FOOTER_REGEX="^(ref|close): N25B-[0-9]+$"
|
||||||
|
|
||||||
|
if ! [[ "$FOOTER" =~ $FOOTER_REGEX ]]; then
|
||||||
|
error_exit "Invalid footer format.\n\nFooter must be in the format: [ref/close]: <issue identifier>\nExample: ref: N25B-123"
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
|
||||||
|
# 4. If the message has more than 2 lines, validate the separator
|
||||||
|
# A blank line must exist between the header and the body.
|
||||||
|
LINE_COUNT=$(wc -l <"$COMMIT_MSG_FILE" | xargs) # xargs trims whitespace
|
||||||
|
|
||||||
|
# We only care if there is a body. Header + Footer = 2 lines.
|
||||||
|
# Header + Blank Line + Body... + Footer > 2 lines.
|
||||||
|
if [ "$LINE_COUNT" -gt 2 ]; then
|
||||||
|
# Get the second line
|
||||||
|
SECOND_LINE=$(sed -n '2p' "$COMMIT_MSG_FILE")
|
||||||
|
|
||||||
|
# Check if the second line is NOT empty. If it's not, it's an error.
|
||||||
|
if [ -n "$SECOND_LINE" ]; then
|
||||||
|
error_exit "Missing blank line between header and body.\n\nThe second line of your commit message must be empty if a body is present."
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
|
||||||
|
# If all checks pass, exit with success
|
||||||
|
echo -e "${GREEN}Commit message is valid.${NC}"
|
||||||
|
exit 0
|
||||||
@@ -1,16 +0,0 @@
|
|||||||
#!/bin/sh
|
|
||||||
|
|
||||||
commit_msg_file=$1
|
|
||||||
commit_msg=$(cat "$commit_msg_file")
|
|
||||||
|
|
||||||
if echo "$commit_msg" | grep -Eq "^(feat|fix|refactor|perf|style|test|docs|build|chore|revert): .+"; then
|
|
||||||
if echo "$commit_msg" | grep -Eq "^(ref|close):\sN25B-.+"; then
|
|
||||||
exit 0
|
|
||||||
else
|
|
||||||
echo "❌ Commit message invalid! Must end with [ref/close]: N25B-000"
|
|
||||||
exit 1
|
|
||||||
fi
|
|
||||||
else
|
|
||||||
echo "❌ Commit message invalid! Must start with <type>: <description>"
|
|
||||||
exit 1
|
|
||||||
fi
|
|
||||||
@@ -1,17 +0,0 @@
|
|||||||
#!/bin/sh
|
|
||||||
|
|
||||||
# Get current branch
|
|
||||||
branch=$(git rev-parse --abbrev-ref HEAD)
|
|
||||||
|
|
||||||
if echo "$branch" | grep -Eq "(dev|main)"; then
|
|
||||||
echo 0
|
|
||||||
fi
|
|
||||||
|
|
||||||
# allowed pattern <type/>
|
|
||||||
if echo "$branch" | grep -Eq "^(feat|fix|refactor|perf|style|test|docs|build|chore|revert)\/\w+(-\w+){0,5}$"; then
|
|
||||||
exit 0
|
|
||||||
else
|
|
||||||
echo "❌ Invalid branch name: $branch"
|
|
||||||
echo "Branch must be named <type>/<description-of-branch> (must have one to six words separated by a dash)"
|
|
||||||
exit 1
|
|
||||||
fi
|
|
||||||
@@ -1,9 +0,0 @@
|
|||||||
#!/bin/sh
|
|
||||||
|
|
||||||
echo "#<type>: <description>
|
|
||||||
|
|
||||||
#[optional body]
|
|
||||||
|
|
||||||
#[optional footer(s)]
|
|
||||||
|
|
||||||
#[ref/close]: <issue identifier>" > $1
|
|
||||||
@@ -1,10 +1,24 @@
|
|||||||
repos:
|
repos:
|
||||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||||
# Ruff version.
|
# Ruff version.
|
||||||
rev: v0.14.2
|
rev: v0.14.2
|
||||||
hooks:
|
hooks:
|
||||||
# Run the linter.
|
# Run the linter.
|
||||||
- id: ruff-check
|
- id: ruff-check
|
||||||
args: [ --fix ]
|
# Run the formatter.
|
||||||
# Run the formatter.
|
- id: ruff-format
|
||||||
- id: ruff-format
|
# Configure local hooks
|
||||||
|
- repo: local
|
||||||
|
hooks:
|
||||||
|
- id: check-commit-msg
|
||||||
|
name: Check commit message format
|
||||||
|
entry: .githooks/check-commit-msg.sh
|
||||||
|
language: script
|
||||||
|
stages: [commit-msg]
|
||||||
|
- id: check-branch-name
|
||||||
|
name: Check branch name format
|
||||||
|
entry: .githooks/check-branch-name.sh
|
||||||
|
language: script
|
||||||
|
stages: [pre-commit]
|
||||||
|
always_run: true
|
||||||
|
pass_filenames: false
|
||||||
|
|||||||
22
README.md
22
README.md
@@ -47,21 +47,19 @@ Or for integration tests:
|
|||||||
uv run --group integration-test pytest test/integration
|
uv run --group integration-test pytest test/integration
|
||||||
```
|
```
|
||||||
|
|
||||||
## GitHooks
|
## Git Hooks
|
||||||
|
|
||||||
To activate automatic commits/branch name checks run:
|
To activate automatic linting, formatting, branch name checks and commit message checks, run:
|
||||||
|
|
||||||
```shell
|
```bash
|
||||||
git config --local core.hooksPath .githooks
|
uv run pre-commit install
|
||||||
|
uv run pre-commit install --hook-type commit-msg
|
||||||
```
|
```
|
||||||
|
|
||||||
If your commit fails its either:
|
You might get an error along the lines of `Can't install pre-commit with core.hooksPath` set. To fix this, simply unset the hooksPath by running:
|
||||||
branch name != <type>/description-of-branch ,
|
|
||||||
commit name != <type>: description of the commit.
|
|
||||||
<ref>: N25B-Num's
|
|
||||||
|
|
||||||
To add automatic linting and formatting, run:
|
```bash
|
||||||
|
git config --local --unset core.hooksPath
|
||||||
|
```
|
||||||
|
|
||||||
```shell
|
Then run the pre-commit install commands again.
|
||||||
uv run pre-commit install
|
|
||||||
```
|
|
||||||
|
|||||||
@@ -9,18 +9,23 @@ from control_backend.core.config import settings
|
|||||||
class BeliefFromText(CyclicBehaviour):
|
class BeliefFromText(CyclicBehaviour):
|
||||||
# TODO: LLM prompt nog hardcoded
|
# TODO: LLM prompt nog hardcoded
|
||||||
llm_instruction_prompt = """
|
llm_instruction_prompt = """
|
||||||
You are an information extraction assistent for a BDI agent. Your task is to extract values from a user's text to bind a list of ungrounded beliefs. Rules:
|
You are an information extraction assistent for a BDI agent. Your task is to extract values \
|
||||||
You will receive a JSON object with "beliefs" (a list of ungrounded AgentSpeak beliefs) and "text" (user's transcript).
|
from a user's text to bind a list of ungrounded beliefs. Rules:
|
||||||
|
You will receive a JSON object with "beliefs" (a list of ungrounded AgentSpeak beliefs) \
|
||||||
|
and "text" (user's transcript).
|
||||||
Analyze the text to find values that sematically match the variables (X,Y,Z) in the beliefs.
|
Analyze the text to find values that sematically match the variables (X,Y,Z) in the beliefs.
|
||||||
A single piece of text might contain multiple instances that match a belief.
|
A single piece of text might contain multiple instances that match a belief.
|
||||||
Respond ONLY with a single JSON object.
|
Respond ONLY with a single JSON object.
|
||||||
The JSON object's keys should be the belief functors (e.g., "weather").
|
The JSON object's keys should be the belief functors (e.g., "weather").
|
||||||
The value for each key must be a list of lists.
|
The value for each key must be a list of lists.
|
||||||
Each inner list must contain the extracted arguments (as strings) for one instance of that belief.
|
Each inner list must contain the extracted arguments (as strings) for one instance \
|
||||||
CRITICAL: If no information in the text matches a belief, DO NOT include that key in your response.
|
of that belief.
|
||||||
|
CRITICAL: If no information in the text matches a belief, DO NOT include that key \
|
||||||
|
in your response.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# on_start agent receives message containing the beliefs to look out for and sets up the LLM with instruction prompt
|
# on_start agent receives message containing the beliefs to look out for and
|
||||||
|
# sets up the LLM with instruction prompt
|
||||||
# async def on_start(self):
|
# async def on_start(self):
|
||||||
# msg = await self.receive(timeout=0.1)
|
# msg = await self.receive(timeout=0.1)
|
||||||
# self.beliefs = dict uit message
|
# self.beliefs = dict uit message
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ from spade.behaviour import CyclicBehaviour
|
|||||||
from control_backend.core.config import settings
|
from control_backend.core.config import settings
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class ContinuousBeliefCollector(CyclicBehaviour):
|
class ContinuousBeliefCollector(CyclicBehaviour):
|
||||||
"""
|
"""
|
||||||
Continuously collects beliefs/emotions from extractor agents:
|
Continuously collects beliefs/emotions from extractor agents:
|
||||||
@@ -23,9 +24,12 @@ class ContinuousBeliefCollector(CyclicBehaviour):
|
|||||||
# Parse JSON payload
|
# Parse JSON payload
|
||||||
try:
|
try:
|
||||||
payload = json.loads(msg.body)
|
payload = json.loads(msg.body)
|
||||||
except JSONDecodeError as e:
|
except Exception as e:
|
||||||
self.agent.logger.warning(
|
logger.warning(
|
||||||
"Failed to parse JSON from %s. Body=%r Error=%s", sender_node, msg.body, e
|
"BeliefCollector: failed to parse JSON from %s. Body=%r Error=%s",
|
||||||
|
sender_node,
|
||||||
|
msg.body,
|
||||||
|
e,
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -51,7 +55,7 @@ class ContinuousBeliefCollector(CyclicBehaviour):
|
|||||||
Expected payload:
|
Expected payload:
|
||||||
{
|
{
|
||||||
"type": "belief_extraction_text",
|
"type": "belief_extraction_text",
|
||||||
"beliefs": {"user_said": ["hello"","Can you help me?","stop talking to me","No","Pepper do a dance"]}
|
"beliefs": {"user_said": ["Can you help me?"]}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -10,7 +10,11 @@ from control_backend.core.config import settings
|
|||||||
class BeliefTextAgent(Agent):
|
class BeliefTextAgent(Agent):
|
||||||
class SendOnceBehaviourBlfText(OneShotBehaviour):
|
class SendOnceBehaviourBlfText(OneShotBehaviour):
|
||||||
async def run(self):
|
async def run(self):
|
||||||
to_jid = f"{settings.agent_settings.belief_collector_agent_name}@{settings.agent_settings.host}"
|
to_jid = (
|
||||||
|
settings.agent_settings.belief_collector_agent_name
|
||||||
|
+ "@"
|
||||||
|
+ settings.agent_settings.host
|
||||||
|
)
|
||||||
|
|
||||||
# Send multiple beliefs in one JSON payload
|
# Send multiple beliefs in one JSON payload
|
||||||
payload = {
|
payload = {
|
||||||
|
|||||||
@@ -75,7 +75,8 @@ class MLXWhisperSpeechRecognizer(SpeechRecognizer):
|
|||||||
self.model_name = "mlx-community/whisper-small.en-mlx"
|
self.model_name = "mlx-community/whisper-small.en-mlx"
|
||||||
|
|
||||||
def load_model(self):
|
def load_model(self):
|
||||||
if self.was_loaded: return
|
if self.was_loaded:
|
||||||
|
return
|
||||||
# There appears to be no dedicated mechanism to preload a model, but this `get_model` does
|
# There appears to be no dedicated mechanism to preload a model, but this `get_model` does
|
||||||
# store it in memory for later usage
|
# store it in memory for later usage
|
||||||
ModelHolder.get_model(self.model_name, mx.float16)
|
ModelHolder.get_model(self.model_name, mx.float16)
|
||||||
@@ -83,9 +84,9 @@ class MLXWhisperSpeechRecognizer(SpeechRecognizer):
|
|||||||
|
|
||||||
def recognize_speech(self, audio: np.ndarray) -> str:
|
def recognize_speech(self, audio: np.ndarray) -> str:
|
||||||
self.load_model()
|
self.load_model()
|
||||||
return mlx_whisper.transcribe(audio,
|
return mlx_whisper.transcribe(
|
||||||
path_or_hf_repo=self.model_name,
|
audio, path_or_hf_repo=self.model_name, decode_options=self._get_decode_options(audio)
|
||||||
decode_options=self._get_decode_options(audio))["text"]
|
)["text"]
|
||||||
return mlx_whisper.transcribe(audio, path_or_hf_repo=self.model_name)["text"].strip()
|
return mlx_whisper.transcribe(audio, path_or_hf_repo=self.model_name)["text"].strip()
|
||||||
|
|
||||||
|
|
||||||
@@ -95,12 +96,13 @@ class OpenAIWhisperSpeechRecognizer(SpeechRecognizer):
|
|||||||
self.model = None
|
self.model = None
|
||||||
|
|
||||||
def load_model(self):
|
def load_model(self):
|
||||||
if self.model is not None: return
|
if self.model is not None:
|
||||||
|
return
|
||||||
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||||
self.model = whisper.load_model("small.en", device=device)
|
self.model = whisper.load_model("small.en", device=device)
|
||||||
|
|
||||||
def recognize_speech(self, audio: np.ndarray) -> str:
|
def recognize_speech(self, audio: np.ndarray) -> str:
|
||||||
self.load_model()
|
self.load_model()
|
||||||
return whisper.transcribe(self.model,
|
return whisper.transcribe(
|
||||||
audio,
|
self.model, audio, decode_options=self._get_decode_options(audio)
|
||||||
decode_options=self._get_decode_options(audio))["text"]
|
)["text"]
|
||||||
|
|||||||
@@ -1,9 +1,9 @@
|
|||||||
from fastapi import APIRouter, Request
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Request
|
||||||
from zmq import Socket
|
from zmq import Socket
|
||||||
|
|
||||||
from control_backend.schemas.ri_message import SpeechCommand, RIEndpoint
|
from control_backend.schemas.ri_message import SpeechCommand
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -17,6 +17,5 @@ async def receive_command(command: SpeechCommand, request: Request):
|
|||||||
topic = b"command"
|
topic = b"command"
|
||||||
pub_socket: Socket = request.app.state.internal_comm_socket
|
pub_socket: Socket = request.app.state.internal_comm_socket
|
||||||
pub_socket.send_multipart([topic, command.model_dump_json().encode()])
|
pub_socket.send_multipart([topic, command.model_dump_json().encode()])
|
||||||
|
|
||||||
|
|
||||||
return {"status": "Command received"}
|
return {"status": "Command received"}
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
from fastapi.routing import APIRouter
|
from fastapi.routing import APIRouter
|
||||||
|
|
||||||
from control_backend.api.v1.endpoints import message, sse, command
|
from control_backend.api.v1.endpoints import command, message, sse
|
||||||
|
|
||||||
api_router = APIRouter()
|
api_router = APIRouter()
|
||||||
|
|
||||||
|
|||||||
@@ -24,6 +24,7 @@ class LLMSettings(BaseModel):
|
|||||||
local_llm_url: str = "http://localhost:1234/v1/chat/completions"
|
local_llm_url: str = "http://localhost:1234/v1/chat/completions"
|
||||||
local_llm_model: str = "openai/gpt-oss-20b"
|
local_llm_model: str = "openai/gpt-oss-20b"
|
||||||
|
|
||||||
|
|
||||||
class Settings(BaseSettings):
|
class Settings(BaseSettings):
|
||||||
app_title: str = "PepperPlus"
|
app_title: str = "PepperPlus"
|
||||||
|
|
||||||
@@ -37,4 +38,5 @@ class Settings(BaseSettings):
|
|||||||
|
|
||||||
model_config = SettingsConfigDict(env_file=".env")
|
model_config = SettingsConfigDict(env_file=".env")
|
||||||
|
|
||||||
|
|
||||||
settings = Settings()
|
settings = Settings()
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any, Literal
|
from typing import Any
|
||||||
|
|
||||||
from pydantic import BaseModel, Field, ValidationError
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
class RIEndpoint(str, Enum):
|
class RIEndpoint(str, Enum):
|
||||||
|
|||||||
@@ -1,10 +1,10 @@
|
|||||||
import asyncio
|
|
||||||
import zmq
|
|
||||||
import json
|
import json
|
||||||
import pytest
|
|
||||||
from unittest.mock import AsyncMock, MagicMock, patch
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import zmq
|
||||||
|
|
||||||
from control_backend.agents.ri_command_agent import RICommandAgent
|
from control_backend.agents.ri_command_agent import RICommandAgent
|
||||||
from control_backend.schemas.ri_message import SpeechCommand
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
|||||||
@@ -1,6 +1,8 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
|
from unittest.mock import ANY, AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from unittest.mock import AsyncMock, MagicMock, patch, ANY
|
|
||||||
from control_backend.agents.ri_communication_agent import RICommunicationAgent
|
from control_backend.agents.ri_communication_agent import RICommunicationAgent
|
||||||
|
|
||||||
|
|
||||||
@@ -185,8 +187,8 @@ async def test_setup_creates_socket_and_negotiate_3(monkeypatch, caplog):
|
|||||||
|
|
||||||
# Mock RICommandAgent agent startup
|
# Mock RICommandAgent agent startup
|
||||||
|
|
||||||
# We are sending wrong negotiation info to the communication agent, so we should retry and expect a
|
# We are sending wrong negotiation info to the communication agent,
|
||||||
# better response, within a limited time.
|
# so we should retry and expect a better response, within a limited time.
|
||||||
with patch(
|
with patch(
|
||||||
"control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True
|
"control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True
|
||||||
) as MockCommandAgent:
|
) as MockCommandAgent:
|
||||||
@@ -358,8 +360,8 @@ async def test_setup_creates_socket_and_negotiate_7(monkeypatch, caplog):
|
|||||||
|
|
||||||
# Mock RICommandAgent agent startup
|
# Mock RICommandAgent agent startup
|
||||||
|
|
||||||
# We are sending wrong negotiation info to the communication agent, so we should retry and expect a
|
# We are sending wrong negotiation info to the communication agent,
|
||||||
# better response, within a limited time.
|
# so we should retry and expect a better response, within a limited time.
|
||||||
with patch(
|
with patch(
|
||||||
"control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True
|
"control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True
|
||||||
) as MockCommandAgent:
|
) as MockCommandAgent:
|
||||||
|
|||||||
@@ -1,7 +1,8 @@
|
|||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
from fastapi.testclient import TestClient
|
from fastapi.testclient import TestClient
|
||||||
from unittest.mock import MagicMock
|
|
||||||
|
|
||||||
from control_backend.api.v1.endpoints import command
|
from control_backend.api.v1.endpoints import command
|
||||||
from control_backend.schemas.ri_message import SpeechCommand
|
from control_backend.schemas.ri_message import SpeechCommand
|
||||||
|
|||||||
@@ -1,7 +1,8 @@
|
|||||||
import pytest
|
import pytest
|
||||||
from control_backend.schemas.ri_message import RIMessage, RIEndpoint, SpeechCommand
|
|
||||||
from pydantic import ValidationError
|
from pydantic import ValidationError
|
||||||
|
|
||||||
|
from control_backend.schemas.ri_message import RIEndpoint, RIMessage, SpeechCommand
|
||||||
|
|
||||||
|
|
||||||
def valid_command_1():
|
def valid_command_1():
|
||||||
return SpeechCommand(data="Hallo?")
|
return SpeechCommand(data="Hallo?")
|
||||||
@@ -13,24 +14,13 @@ def invalid_command_1():
|
|||||||
|
|
||||||
def test_valid_speech_command_1():
|
def test_valid_speech_command_1():
|
||||||
command = valid_command_1()
|
command = valid_command_1()
|
||||||
try:
|
RIMessage.model_validate(command)
|
||||||
RIMessage.model_validate(command)
|
SpeechCommand.model_validate(command)
|
||||||
SpeechCommand.model_validate(command)
|
|
||||||
assert True
|
|
||||||
except ValidationError:
|
|
||||||
assert False
|
|
||||||
|
|
||||||
|
|
||||||
def test_invalid_speech_command_1():
|
def test_invalid_speech_command_1():
|
||||||
command = invalid_command_1()
|
command = invalid_command_1()
|
||||||
passed_ri_message_validation = False
|
RIMessage.model_validate(command)
|
||||||
try:
|
|
||||||
# Should succeed, still.
|
|
||||||
RIMessage.model_validate(command)
|
|
||||||
passed_ri_message_validation = True
|
|
||||||
|
|
||||||
# Should fail.
|
with pytest.raises(ValidationError):
|
||||||
SpeechCommand.model_validate(command)
|
SpeechCommand.model_validate(command)
|
||||||
assert False
|
|
||||||
except ValidationError:
|
|
||||||
assert passed_ri_message_validation
|
|
||||||
|
|||||||
@@ -203,6 +203,7 @@ def test_set_beliefs_success(belief_setter, mock_agent, caplog):
|
|||||||
assert "Set belief is_hot with arguments ['kitchen']" in caplog.text
|
assert "Set belief is_hot with arguments ['kitchen']" in caplog.text
|
||||||
assert "Set belief door_opened with arguments ['front_door', 'back_door']" in caplog.text
|
assert "Set belief door_opened with arguments ['front_door', 'back_door']" in caplog.text
|
||||||
|
|
||||||
|
|
||||||
# def test_responded_unset(belief_setter, mock_agent):
|
# def test_responded_unset(belief_setter, mock_agent):
|
||||||
# # Arrange
|
# # Arrange
|
||||||
# new_beliefs = {"user_said": ["message"]}
|
# new_beliefs = {"user_said": ["message"]}
|
||||||
|
|||||||
@@ -1,10 +1,12 @@
|
|||||||
import json
|
import json
|
||||||
import logging
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
from unittest.mock import MagicMock, AsyncMock, call
|
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from control_backend.agents.belief_collector.behaviours.continuous_collect import ContinuousBeliefCollector
|
from control_backend.agents.belief_collector.behaviours.continuous_collect import (
|
||||||
|
ContinuousBeliefCollector,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_agent(mocker):
|
def mock_agent(mocker):
|
||||||
@@ -13,18 +15,20 @@ def mock_agent(mocker):
|
|||||||
agent.jid = "belief_collector_agent@test"
|
agent.jid = "belief_collector_agent@test"
|
||||||
return agent
|
return agent
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def continuous_collector(mock_agent, mocker):
|
def continuous_collector(mock_agent, mocker):
|
||||||
"""Fixture to create an instance of ContinuousBeliefCollector with a mocked agent."""
|
"""Fixture to create an instance of ContinuousBeliefCollector with a mocked agent."""
|
||||||
# Patch asyncio.sleep to prevent tests from actually waiting
|
# Patch asyncio.sleep to prevent tests from actually waiting
|
||||||
mocker.patch("asyncio.sleep", return_value=None)
|
mocker.patch("asyncio.sleep", return_value=None)
|
||||||
|
|
||||||
collector = ContinuousBeliefCollector()
|
collector = ContinuousBeliefCollector()
|
||||||
collector.agent = mock_agent
|
collector.agent = mock_agent
|
||||||
# Mock the receive method, we will control its return value in each test
|
# Mock the receive method, we will control its return value in each test
|
||||||
collector.receive = AsyncMock()
|
collector.receive = AsyncMock()
|
||||||
return collector
|
return collector
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_run_no_message_received(continuous_collector, mocker):
|
async def test_run_no_message_received(continuous_collector, mocker):
|
||||||
"""
|
"""
|
||||||
@@ -40,6 +44,7 @@ async def test_run_no_message_received(continuous_collector, mocker):
|
|||||||
# Assert
|
# Assert
|
||||||
continuous_collector._process_message.assert_not_called()
|
continuous_collector._process_message.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_run_message_received(continuous_collector, mocker):
|
async def test_run_message_received(continuous_collector, mocker):
|
||||||
"""
|
"""
|
||||||
@@ -55,7 +60,8 @@ async def test_run_message_received(continuous_collector, mocker):
|
|||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
continuous_collector._process_message.assert_awaited_once_with(mock_msg)
|
continuous_collector._process_message.assert_awaited_once_with(mock_msg)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_process_message_invalid(continuous_collector, mocker):
|
async def test_process_message_invalid(continuous_collector, mocker):
|
||||||
"""
|
"""
|
||||||
@@ -66,15 +72,18 @@ async def test_process_message_invalid(continuous_collector, mocker):
|
|||||||
msg = MagicMock()
|
msg = MagicMock()
|
||||||
msg.body = invalid_json
|
msg.body = invalid_json
|
||||||
msg.sender = "belief_text_agent_mock@test"
|
msg.sender = "belief_text_agent_mock@test"
|
||||||
|
|
||||||
logger_mock = mocker.patch("control_backend.agents.belief_collector.behaviours.continuous_collect.logger")
|
logger_mock = mocker.patch(
|
||||||
|
"control_backend.agents.belief_collector.behaviours.continuous_collect.logger"
|
||||||
|
)
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
await continuous_collector._process_message(msg)
|
await continuous_collector._process_message(msg)
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
logger_mock.warning.assert_called_once()
|
logger_mock.warning.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
def test_get_sender_from_message(continuous_collector):
|
def test_get_sender_from_message(continuous_collector):
|
||||||
"""
|
"""
|
||||||
Test that _sender_node correctly extracts the sender node from the message JID.
|
Test that _sender_node correctly extracts the sender node from the message JID.
|
||||||
@@ -89,6 +98,7 @@ def test_get_sender_from_message(continuous_collector):
|
|||||||
# Assert
|
# Assert
|
||||||
assert sender_node == "agent_node"
|
assert sender_node == "agent_node"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_routes_to_handle_belief_text_by_type(continuous_collector, mocker):
|
async def test_routes_to_handle_belief_text_by_type(continuous_collector, mocker):
|
||||||
msg = MagicMock()
|
msg = MagicMock()
|
||||||
@@ -98,6 +108,7 @@ async def test_routes_to_handle_belief_text_by_type(continuous_collector, mocker
|
|||||||
await continuous_collector._process_message(msg)
|
await continuous_collector._process_message(msg)
|
||||||
spy.assert_awaited_once()
|
spy.assert_awaited_once()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_routes_to_handle_belief_text_by_sender(continuous_collector, mocker):
|
async def test_routes_to_handle_belief_text_by_sender(continuous_collector, mocker):
|
||||||
msg = MagicMock()
|
msg = MagicMock()
|
||||||
@@ -107,6 +118,7 @@ async def test_routes_to_handle_belief_text_by_sender(continuous_collector, mock
|
|||||||
await continuous_collector._process_message(msg)
|
await continuous_collector._process_message(msg)
|
||||||
spy.assert_awaited_once()
|
spy.assert_awaited_once()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_routes_to_handle_emo_text(continuous_collector, mocker):
|
async def test_routes_to_handle_emo_text(continuous_collector, mocker):
|
||||||
msg = MagicMock()
|
msg = MagicMock()
|
||||||
@@ -116,50 +128,64 @@ async def test_routes_to_handle_emo_text(continuous_collector, mocker):
|
|||||||
await continuous_collector._process_message(msg)
|
await continuous_collector._process_message(msg)
|
||||||
spy.assert_awaited_once()
|
spy.assert_awaited_once()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_unrecognized_message_logs_info(continuous_collector, mocker):
|
async def test_unrecognized_message_logs_info(continuous_collector, mocker):
|
||||||
msg = MagicMock()
|
msg = MagicMock()
|
||||||
msg.body = json.dumps({"type": "something_else"})
|
msg.body = json.dumps({"type": "something_else"})
|
||||||
msg.sender = "x@test"
|
msg.sender = "x@test"
|
||||||
logger_mock = mocker.patch("control_backend.agents.belief_collector.behaviours.continuous_collect.logger")
|
logger_mock = mocker.patch(
|
||||||
|
"control_backend.agents.belief_collector.behaviours.continuous_collect.logger"
|
||||||
|
)
|
||||||
await continuous_collector._process_message(msg)
|
await continuous_collector._process_message(msg)
|
||||||
logger_mock.info.assert_any_call(
|
logger_mock.info.assert_any_call(
|
||||||
"BeliefCollector: unrecognized message (sender=%s, type=%r). Ignoring.", "x", "something_else"
|
"BeliefCollector: unrecognized message (sender=%s, type=%r). Ignoring.",
|
||||||
|
"x",
|
||||||
|
"something_else",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_belief_text_no_beliefs(continuous_collector, mocker):
|
async def test_belief_text_no_beliefs(continuous_collector, mocker):
|
||||||
msg_payload = {"type": "belief_extraction_text"} # no 'beliefs'
|
msg_payload = {"type": "belief_extraction_text"} # no 'beliefs'
|
||||||
logger_mock = mocker.patch("control_backend.agents.belief_collector.behaviours.continuous_collect.logger")
|
logger_mock = mocker.patch(
|
||||||
|
"control_backend.agents.belief_collector.behaviours.continuous_collect.logger"
|
||||||
|
)
|
||||||
await continuous_collector._handle_belief_text(msg_payload, "origin_node")
|
await continuous_collector._handle_belief_text(msg_payload, "origin_node")
|
||||||
logger_mock.info.assert_any_call("BeliefCollector: no beliefs to process.")
|
logger_mock.info.assert_any_call("BeliefCollector: no beliefs to process.")
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_belief_text_beliefs_not_dict(continuous_collector, mocker):
|
async def test_belief_text_beliefs_not_dict(continuous_collector, mocker):
|
||||||
payload = {"type": "belief_extraction_text", "beliefs": ["not", "a", "dict"]}
|
payload = {"type": "belief_extraction_text", "beliefs": ["not", "a", "dict"]}
|
||||||
logger_mock = mocker.patch("control_backend.agents.belief_collector.behaviours.continuous_collect.logger")
|
logger_mock = mocker.patch(
|
||||||
|
"control_backend.agents.belief_collector.behaviours.continuous_collect.logger"
|
||||||
|
)
|
||||||
await continuous_collector._handle_belief_text(payload, "origin")
|
await continuous_collector._handle_belief_text(payload, "origin")
|
||||||
logger_mock.warning.assert_any_call("BeliefCollector: 'beliefs' is not a dict: %r", ["not", "a", "dict"])
|
logger_mock.warning.assert_any_call(
|
||||||
|
"BeliefCollector: 'beliefs' is not a dict: %r", ["not", "a", "dict"]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_belief_text_values_not_lists(continuous_collector, mocker):
|
async def test_belief_text_values_not_lists(continuous_collector, mocker):
|
||||||
payload = {"type": "belief_extraction_text", "beliefs": {"user_said": "not-a-list"}}
|
payload = {"type": "belief_extraction_text", "beliefs": {"user_said": "not-a-list"}}
|
||||||
logger_mock = mocker.patch("control_backend.agents.belief_collector.behaviours.continuous_collect.logger")
|
logger_mock = mocker.patch(
|
||||||
|
"control_backend.agents.belief_collector.behaviours.continuous_collect.logger"
|
||||||
|
)
|
||||||
await continuous_collector._handle_belief_text(payload, "origin")
|
await continuous_collector._handle_belief_text(payload, "origin")
|
||||||
logger_mock.warning.assert_any_call(
|
logger_mock.warning.assert_any_call(
|
||||||
"BeliefCollector: 'beliefs' values are not all lists: %r", {"user_said": "not-a-list"}
|
"BeliefCollector: 'beliefs' values are not all lists: %r", {"user_said": "not-a-list"}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_belief_text_happy_path_logs_items_and_sends(continuous_collector, mocker):
|
async def test_belief_text_happy_path_logs_items_and_sends(continuous_collector, mocker):
|
||||||
payload = {
|
payload = {"type": "belief_extraction_text", "beliefs": {"user_said": ["hello test", "No"]}}
|
||||||
"type": "belief_extraction_text",
|
|
||||||
"beliefs": {"user_said": ["hello test", "No"]}
|
|
||||||
}
|
|
||||||
# Your code calls self.send(..); patch it (or switch implementation to self.agent.send and patch that)
|
|
||||||
continuous_collector.send = AsyncMock()
|
continuous_collector.send = AsyncMock()
|
||||||
logger_mock = mocker.patch("control_backend.agents.belief_collector.behaviours.continuous_collect.logger")
|
logger_mock = mocker.patch(
|
||||||
|
"control_backend.agents.belief_collector.behaviours.continuous_collect.logger"
|
||||||
|
)
|
||||||
await continuous_collector._handle_belief_text(payload, "belief_text_agent_mock")
|
await continuous_collector._handle_belief_text(payload, "belief_text_agent_mock")
|
||||||
|
|
||||||
logger_mock.info.assert_any_call("BeliefCollector: forwarding %d beliefs.", 1)
|
logger_mock.info.assert_any_call("BeliefCollector: forwarding %d beliefs.", 1)
|
||||||
@@ -169,12 +195,14 @@ async def test_belief_text_happy_path_logs_items_and_sends(continuous_collector,
|
|||||||
# make sure we attempted a send
|
# make sure we attempted a send
|
||||||
continuous_collector.send.assert_awaited_once()
|
continuous_collector.send.assert_awaited_once()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_send_beliefs_noop_on_empty(continuous_collector):
|
async def test_send_beliefs_noop_on_empty(continuous_collector):
|
||||||
continuous_collector.send = AsyncMock()
|
continuous_collector.send = AsyncMock()
|
||||||
await continuous_collector._send_beliefs_to_bdi([], origin="o")
|
await continuous_collector._send_beliefs_to_bdi([], origin="o")
|
||||||
continuous_collector.send.assert_not_awaited()
|
continuous_collector.send.assert_not_awaited()
|
||||||
|
|
||||||
|
|
||||||
# @pytest.mark.asyncio
|
# @pytest.mark.asyncio
|
||||||
# async def test_send_beliefs_sends_json_packet(continuous_collector):
|
# async def test_send_beliefs_sends_json_packet(continuous_collector):
|
||||||
# # Patch .send and capture the message body
|
# # Patch .send and capture the message body
|
||||||
@@ -191,19 +219,22 @@ async def test_send_beliefs_noop_on_empty(continuous_collector):
|
|||||||
# assert "belief_packet" in json.loads(sent["body"])["type"]
|
# assert "belief_packet" in json.loads(sent["body"])["type"]
|
||||||
# assert json.loads(sent["body"])["beliefs"] == beliefs
|
# assert json.loads(sent["body"])["beliefs"] == beliefs
|
||||||
|
|
||||||
|
|
||||||
def test_sender_node_no_sender_returns_literal(continuous_collector):
|
def test_sender_node_no_sender_returns_literal(continuous_collector):
|
||||||
msg = MagicMock()
|
msg = MagicMock()
|
||||||
msg.sender = None
|
msg.sender = None
|
||||||
assert continuous_collector._sender_node(msg) == "no_sender"
|
assert continuous_collector._sender_node(msg) == "no_sender"
|
||||||
|
|
||||||
|
|
||||||
def test_sender_node_without_at(continuous_collector):
|
def test_sender_node_without_at(continuous_collector):
|
||||||
msg = MagicMock()
|
msg = MagicMock()
|
||||||
msg.sender = "localpartonly"
|
msg.sender = "localpartonly"
|
||||||
assert continuous_collector._sender_node(msg) == "localpartonly"
|
assert continuous_collector._sender_node(msg) == "localpartonly"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_belief_text_coerces_non_strings(continuous_collector, mocker):
|
async def test_belief_text_coerces_non_strings(continuous_collector, mocker):
|
||||||
payload = {"type": "belief_extraction_text", "beliefs": {"user_said": [["hi", 123]]}}
|
payload = {"type": "belief_extraction_text", "beliefs": {"user_said": [["hi", 123]]}}
|
||||||
continuous_collector.send = AsyncMock()
|
continuous_collector.send = AsyncMock()
|
||||||
await continuous_collector._handle_belief_text(payload, "origin")
|
await continuous_collector._handle_belief_text(payload, "origin")
|
||||||
continuous_collector.send.assert_awaited_once()
|
continuous_collector.send.assert_awaited_once()
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ from control_backend.agents.transcription.speech_recognizer import OpenAIWhisper
|
|||||||
|
|
||||||
def test_estimate_max_tokens():
|
def test_estimate_max_tokens():
|
||||||
"""Inputting one minute of audio, assuming 300 words per minute, expecting 400 tokens."""
|
"""Inputting one minute of audio, assuming 300 words per minute, expecting 400 tokens."""
|
||||||
audio = np.empty(shape=(60*16_000), dtype=np.float32)
|
audio = np.empty(shape=(60 * 16_000), dtype=np.float32)
|
||||||
|
|
||||||
actual = SpeechRecognizer._estimate_max_tokens(audio)
|
actual = SpeechRecognizer._estimate_max_tokens(audio)
|
||||||
|
|
||||||
@@ -16,7 +16,7 @@ def test_estimate_max_tokens():
|
|||||||
|
|
||||||
def test_get_decode_options():
|
def test_get_decode_options():
|
||||||
"""Check whether the right decode options are given under different scenarios."""
|
"""Check whether the right decode options are given under different scenarios."""
|
||||||
audio = np.empty(shape=(60*16_000), dtype=np.float32)
|
audio = np.empty(shape=(60 * 16_000), dtype=np.float32)
|
||||||
|
|
||||||
# With the defaults, it should limit output length based on input size
|
# With the defaults, it should limit output length based on input size
|
||||||
recognizer = OpenAIWhisperSpeechRecognizer()
|
recognizer = OpenAIWhisperSpeechRecognizer()
|
||||||
|
|||||||
Reference in New Issue
Block a user