fix: fixed new tests and merged dev into branch
ref: N25B-256
This commit is contained in:
77
.githooks/check-branch-name.sh
Executable file
77
.githooks/check-branch-name.sh
Executable file
@@ -0,0 +1,77 @@
|
|||||||
|
#!/usr/bin/env bash
|
||||||
|
|
||||||
|
# This script checks if the current branch name follows the specified format.
|
||||||
|
# It's designed to be used as a 'pre-commit' git hook.
|
||||||
|
|
||||||
|
# Format: <type>/<short-description>
|
||||||
|
# Example: feat/add-user-login
|
||||||
|
|
||||||
|
# --- Configuration ---
|
||||||
|
# An array of allowed commit types
|
||||||
|
ALLOWED_TYPES=(feat fix refactor perf style test docs build chore revert)
|
||||||
|
# An array of branches to ignore
|
||||||
|
IGNORED_BRANCHES=(main dev demo)
|
||||||
|
|
||||||
|
# --- Colors for Output ---
|
||||||
|
RED='\033[0;31m'
|
||||||
|
GREEN='\033[0;32m'
|
||||||
|
YELLOW='\033[1;33m'
|
||||||
|
NC='\033[0m' # No Color
|
||||||
|
|
||||||
|
# --- Helper Functions ---
|
||||||
|
error_exit() {
|
||||||
|
echo -e "${RED}ERROR: $1${NC}" >&2
|
||||||
|
echo -e "${YELLOW}Branch name format is incorrect. Aborting commit.${NC}" >&2
|
||||||
|
exit 1
|
||||||
|
}
|
||||||
|
|
||||||
|
# --- Main Logic ---
|
||||||
|
|
||||||
|
# 1. Get the current branch name
|
||||||
|
BRANCH_NAME=$(git symbolic-ref --short HEAD)
|
||||||
|
|
||||||
|
# 2. Check if the current branch is in the ignored list
|
||||||
|
for ignored_branch in "${IGNORED_BRANCHES[@]}"; do
|
||||||
|
if [ "$BRANCH_NAME" == "$ignored_branch" ]; then
|
||||||
|
echo -e "${GREEN}Branch check skipped for default branch: $BRANCH_NAME${NC}"
|
||||||
|
exit 0
|
||||||
|
fi
|
||||||
|
done
|
||||||
|
|
||||||
|
# 3. Validate the overall structure: <type>/<description>
|
||||||
|
if ! [[ "$BRANCH_NAME" =~ ^[a-z]+/.+$ ]]; then
|
||||||
|
error_exit "Branch name must be in the format: <type>/<short-description>\nExample: feat/add-user-login"
|
||||||
|
fi
|
||||||
|
|
||||||
|
# 4. Extract the type and description
|
||||||
|
TYPE=$(echo "$BRANCH_NAME" | cut -d'/' -f1)
|
||||||
|
DESCRIPTION=$(echo "$BRANCH_NAME" | cut -d'/' -f2-)
|
||||||
|
|
||||||
|
# 5. Validate the <type>
|
||||||
|
type_valid=false
|
||||||
|
for allowed_type in "${ALLOWED_TYPES[@]}"; do
|
||||||
|
if [ "$TYPE" == "$allowed_type" ]; then
|
||||||
|
type_valid=true
|
||||||
|
break
|
||||||
|
fi
|
||||||
|
done
|
||||||
|
|
||||||
|
if [ "$type_valid" == false ]; then
|
||||||
|
error_exit "Invalid type '$TYPE'.\nAllowed types are: ${ALLOWED_TYPES[*]}"
|
||||||
|
fi
|
||||||
|
|
||||||
|
# 6. Validate the <short-description>
|
||||||
|
# Regex breakdown:
|
||||||
|
# ^[a-z0-9]+ - Starts with one or more lowercase letters/numbers (the first word).
|
||||||
|
# (-[a-z0-9]+){0,5} - Followed by a group of (dash + word) 0 to 5 times.
|
||||||
|
# $ - End of the string.
|
||||||
|
# This entire pattern enforces 1 to 6 words total, separated by dashes.
|
||||||
|
DESCRIPTION_REGEX="^[a-z0-9]+(-[a-z0-9]+){0,5}$"
|
||||||
|
|
||||||
|
if ! [[ "$DESCRIPTION" =~ $DESCRIPTION_REGEX ]]; then
|
||||||
|
error_exit "Invalid short description '$DESCRIPTION'.\nIt must be a maximum of 6 words, all lowercase, separated by dashes.\nExample: add-new-user-authentication-feature"
|
||||||
|
fi
|
||||||
|
|
||||||
|
# If all checks pass, exit successfully
|
||||||
|
echo -e "${GREEN}Branch name '$BRANCH_NAME' is valid.${NC}"
|
||||||
|
exit 0
|
||||||
135
.githooks/check-commit-msg.sh
Executable file
135
.githooks/check-commit-msg.sh
Executable file
@@ -0,0 +1,135 @@
|
|||||||
|
#!/usr/bin/env bash
|
||||||
|
|
||||||
|
# This script checks if a commit message follows the specified format.
|
||||||
|
# It's designed to be used as a 'commit-msg' git hook.
|
||||||
|
|
||||||
|
# Format:
|
||||||
|
# <type>: <short description>
|
||||||
|
#
|
||||||
|
# [optional]<body>
|
||||||
|
#
|
||||||
|
# [ref/close]: <issue identifier>
|
||||||
|
|
||||||
|
# --- Configuration ---
|
||||||
|
# An array of allowed commit types
|
||||||
|
ALLOWED_TYPES=(feat fix refactor perf style test docs build chore revert)
|
||||||
|
|
||||||
|
# --- Colors for Output ---
|
||||||
|
RED='\033[0;31m'
|
||||||
|
GREEN='\033[0;32m'
|
||||||
|
YELLOW='\033[1;33m'
|
||||||
|
NC='\033[0m' # No Color
|
||||||
|
|
||||||
|
# The first argument to the hook is the path to the file containing the commit message
|
||||||
|
COMMIT_MSG_FILE=$1
|
||||||
|
|
||||||
|
# --- Automated Commit Detection ---
|
||||||
|
|
||||||
|
# Read the first line (header) for initial checks
|
||||||
|
HEADER=$(head -n 1 "$COMMIT_MSG_FILE")
|
||||||
|
|
||||||
|
# Check for Merge commits (covers 'git merge' and PR merges from GitHub/GitLab)
|
||||||
|
# Examples: "Merge branch 'main' into ...", "Merge pull request #123 from ..."
|
||||||
|
MERGE_PATTERN="^Merge (branch|pull request|tag) .*"
|
||||||
|
if [[ "$HEADER" =~ $MERGE_PATTERN ]]; then
|
||||||
|
echo -e "${GREEN}Merge commit detected by message content. Skipping validation.${NC}"
|
||||||
|
exit 0
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Check for Revert commits
|
||||||
|
# Example: "Revert "feat: add new feature""
|
||||||
|
REVERT_PATTERN="^Revert \".*\""
|
||||||
|
if [[ "$HEADER" =~ $REVERT_PATTERN ]]; then
|
||||||
|
echo -e "${GREEN}Revert commit detected by message content. Skipping validation.${NC}"
|
||||||
|
exit 0
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Check for Cherry-pick commits (this pattern appears at the end of the message)
|
||||||
|
# Example: "(cherry picked from commit deadbeef...)"
|
||||||
|
# We use grep -q to search the whole file quietly.
|
||||||
|
CHERRY_PICK_PATTERN="\(cherry picked from commit [a-f0-9]{7,40}\)"
|
||||||
|
if grep -qE "$CHERRY_PICK_PATTERN" "$COMMIT_MSG_FILE"; then
|
||||||
|
echo -e "${GREEN}Cherry-pick detected by message content. Skipping validation.${NC}"
|
||||||
|
exit 0
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Check for Squash
|
||||||
|
# Example: "Squash commits ..."
|
||||||
|
SQUASH_PATTERN="^Squash .+"
|
||||||
|
if [[ "$HEADER" =~ $SQUASH_PATTERN ]]; then
|
||||||
|
echo -e "${GREEN}Squash commit detected by message content. Skipping validation.${NC}"
|
||||||
|
exit 0
|
||||||
|
fi
|
||||||
|
|
||||||
|
# --- Validation Functions ---
|
||||||
|
|
||||||
|
# Function to print an error message and exit
|
||||||
|
# Usage: error_exit "Your error message here"
|
||||||
|
error_exit() {
|
||||||
|
# >&2 redirects echo to stderr
|
||||||
|
echo -e "${RED}ERROR: $1${NC}" >&2
|
||||||
|
echo -e "${YELLOW}Commit message format is incorrect. Aborting commit.${NC}" >&2
|
||||||
|
exit 1
|
||||||
|
}
|
||||||
|
|
||||||
|
# --- Main Logic ---
|
||||||
|
|
||||||
|
# 1. Read the header (first line) of the commit message
|
||||||
|
HEADER=$(head -n 1 "$COMMIT_MSG_FILE")
|
||||||
|
|
||||||
|
# 2. Validate the header format: <type>: <description>
|
||||||
|
# Regex breakdown:
|
||||||
|
# ^(type1|type2|...) - Starts with one of the allowed types
|
||||||
|
# : - Followed by a literal colon
|
||||||
|
# \s - Followed by a single space
|
||||||
|
# .+ - Followed by one or more characters for the description
|
||||||
|
# $ - End of the line
|
||||||
|
TYPES_REGEX=$(
|
||||||
|
IFS="|"
|
||||||
|
echo "${ALLOWED_TYPES[*]}"
|
||||||
|
)
|
||||||
|
HEADER_REGEX="^($TYPES_REGEX): .+$"
|
||||||
|
|
||||||
|
if ! [[ "$HEADER" =~ $HEADER_REGEX ]]; then
|
||||||
|
error_exit "Invalid header format.\n\nHeader must be in the format: <type>: <short description>\nAllowed types: ${ALLOWED_TYPES[*]}\nExample: feat: add new user authentication feature"
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Only validate footer if commit type is not chore
|
||||||
|
TYPE=$(echo "$HEADER" | cut -d':' -f1)
|
||||||
|
if [ "$TYPE" != "chore" ]; then
|
||||||
|
# 3. Validate the footer (last line) of the commit message
|
||||||
|
FOOTER=$(tail -n 1 "$COMMIT_MSG_FILE")
|
||||||
|
|
||||||
|
# Regex breakdown:
|
||||||
|
# ^(ref|close) - Starts with 'ref' or 'close'
|
||||||
|
# : - Followed by a literal colon
|
||||||
|
# \s - Followed by a single space
|
||||||
|
# N25B- - Followed by the literal string 'N25B-'
|
||||||
|
# [0-9]+ - Followed by one or more digits
|
||||||
|
# $ - End of the line
|
||||||
|
FOOTER_REGEX="^(ref|close): N25B-[0-9]+$"
|
||||||
|
|
||||||
|
if ! [[ "$FOOTER" =~ $FOOTER_REGEX ]]; then
|
||||||
|
error_exit "Invalid footer format.\n\nFooter must be in the format: [ref/close]: <issue identifier>\nExample: ref: N25B-123"
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
|
||||||
|
# 4. If the message has more than 2 lines, validate the separator
|
||||||
|
# A blank line must exist between the header and the body.
|
||||||
|
LINE_COUNT=$(wc -l <"$COMMIT_MSG_FILE" | xargs) # xargs trims whitespace
|
||||||
|
|
||||||
|
# We only care if there is a body. Header + Footer = 2 lines.
|
||||||
|
# Header + Blank Line + Body... + Footer > 2 lines.
|
||||||
|
if [ "$LINE_COUNT" -gt 2 ]; then
|
||||||
|
# Get the second line
|
||||||
|
SECOND_LINE=$(sed -n '2p' "$COMMIT_MSG_FILE")
|
||||||
|
|
||||||
|
# Check if the second line is NOT empty. If it's not, it's an error.
|
||||||
|
if [ -n "$SECOND_LINE" ]; then
|
||||||
|
error_exit "Missing blank line between header and body.\n\nThe second line of your commit message must be empty if a body is present."
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
|
||||||
|
# If all checks pass, exit with success
|
||||||
|
echo -e "${GREEN}Commit message is valid.${NC}"
|
||||||
|
exit 0
|
||||||
@@ -1,16 +0,0 @@
|
|||||||
#!/bin/sh
|
|
||||||
|
|
||||||
commit_msg_file=$1
|
|
||||||
commit_msg=$(cat "$commit_msg_file")
|
|
||||||
|
|
||||||
if echo "$commit_msg" | grep -Eq "^(feat|fix|refactor|perf|style|test|docs|build|chore|revert): .+"; then
|
|
||||||
if echo "$commit_msg" | grep -Eq "^(ref|close):\sN25B-.+"; then
|
|
||||||
exit 0
|
|
||||||
else
|
|
||||||
echo "❌ Commit message invalid! Must end with [ref/close]: N25B-000"
|
|
||||||
exit 1
|
|
||||||
fi
|
|
||||||
else
|
|
||||||
echo "❌ Commit message invalid! Must start with <type>: <description>"
|
|
||||||
exit 1
|
|
||||||
fi
|
|
||||||
@@ -1,17 +0,0 @@
|
|||||||
#!/bin/sh
|
|
||||||
|
|
||||||
# Get current branch
|
|
||||||
branch=$(git rev-parse --abbrev-ref HEAD)
|
|
||||||
|
|
||||||
if echo "$branch" | grep -Eq "(dev|main)"; then
|
|
||||||
echo 0
|
|
||||||
fi
|
|
||||||
|
|
||||||
# allowed pattern <type/>
|
|
||||||
if echo "$branch" | grep -Eq "^(feat|fix|refactor|perf|style|test|docs|build|chore|revert)\/\w+(-\w+){0,5}$"; then
|
|
||||||
exit 0
|
|
||||||
else
|
|
||||||
echo "❌ Invalid branch name: $branch"
|
|
||||||
echo "Branch must be named <type>/<description-of-branch> (must have one to six words separated by a dash)"
|
|
||||||
exit 1
|
|
||||||
fi
|
|
||||||
@@ -1,9 +0,0 @@
|
|||||||
#!/bin/sh
|
|
||||||
|
|
||||||
echo "#<type>: <description>
|
|
||||||
|
|
||||||
#[optional body]
|
|
||||||
|
|
||||||
#[optional footer(s)]
|
|
||||||
|
|
||||||
#[ref/close]: <issue identifier>" > $1
|
|
||||||
@@ -1,10 +1,24 @@
|
|||||||
repos:
|
repos:
|
||||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||||
# Ruff version.
|
# Ruff version.
|
||||||
rev: v0.14.2
|
rev: v0.14.2
|
||||||
hooks:
|
hooks:
|
||||||
# Run the linter.
|
# Run the linter.
|
||||||
- id: ruff-check
|
- id: ruff-check
|
||||||
args: [ --fix ]
|
# Run the formatter.
|
||||||
# Run the formatter.
|
- id: ruff-format
|
||||||
- id: ruff-format
|
# Configure local hooks
|
||||||
|
- repo: local
|
||||||
|
hooks:
|
||||||
|
- id: check-commit-msg
|
||||||
|
name: Check commit message format
|
||||||
|
entry: .githooks/check-commit-msg.sh
|
||||||
|
language: script
|
||||||
|
stages: [commit-msg]
|
||||||
|
- id: check-branch-name
|
||||||
|
name: Check branch name format
|
||||||
|
entry: .githooks/check-branch-name.sh
|
||||||
|
language: script
|
||||||
|
stages: [pre-commit]
|
||||||
|
always_run: true
|
||||||
|
pass_filenames: false
|
||||||
|
|||||||
22
README.md
22
README.md
@@ -47,21 +47,19 @@ Or for integration tests:
|
|||||||
uv run --group integration-test pytest test/integration
|
uv run --group integration-test pytest test/integration
|
||||||
```
|
```
|
||||||
|
|
||||||
## GitHooks
|
## Git Hooks
|
||||||
|
|
||||||
To activate automatic commits/branch name checks run:
|
To activate automatic linting, formatting, branch name checks and commit message checks, run:
|
||||||
|
|
||||||
```shell
|
```bash
|
||||||
git config --local core.hooksPath .githooks
|
uv run pre-commit install
|
||||||
|
uv run pre-commit install --hook-type commit-msg
|
||||||
```
|
```
|
||||||
|
|
||||||
If your commit fails its either:
|
You might get an error along the lines of `Can't install pre-commit with core.hooksPath` set. To fix this, simply unset the hooksPath by running:
|
||||||
branch name != <type>/description-of-branch ,
|
|
||||||
commit name != <type>: description of the commit.
|
|
||||||
<ref>: N25B-Num's
|
|
||||||
|
|
||||||
To add automatic linting and formatting, run:
|
```bash
|
||||||
|
git config --local --unset core.hooksPath
|
||||||
|
```
|
||||||
|
|
||||||
```shell
|
Then run the pre-commit install commands again.
|
||||||
uv run pre-commit install
|
|
||||||
```
|
|
||||||
|
|||||||
@@ -1,8 +1,10 @@
|
|||||||
import logging
|
import logging
|
||||||
|
|
||||||
from spade.behaviour import CyclicBehaviour
|
from spade.behaviour import CyclicBehaviour
|
||||||
|
from spade.message import Message
|
||||||
|
|
||||||
from control_backend.core.config import settings
|
from control_backend.core.config import settings
|
||||||
|
from control_backend.schemas.ri_message import SpeechCommand
|
||||||
|
|
||||||
|
|
||||||
class ReceiveLLMResponseBehaviour(CyclicBehaviour):
|
class ReceiveLLMResponseBehaviour(CyclicBehaviour):
|
||||||
@@ -10,7 +12,7 @@ class ReceiveLLMResponseBehaviour(CyclicBehaviour):
|
|||||||
Adds behavior to receive responses from the LLM Agent.
|
Adds behavior to receive responses from the LLM Agent.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
logger = logging.getLogger("BDI/LLM Reciever")
|
logger = logging.getLogger("BDI/LLM Receiver")
|
||||||
|
|
||||||
async def run(self):
|
async def run(self):
|
||||||
msg = await self.receive(timeout=2)
|
msg = await self.receive(timeout=2)
|
||||||
@@ -22,7 +24,20 @@ class ReceiveLLMResponseBehaviour(CyclicBehaviour):
|
|||||||
case settings.agent_settings.llm_agent_name:
|
case settings.agent_settings.llm_agent_name:
|
||||||
content = msg.body
|
content = msg.body
|
||||||
self.logger.info("Received LLM response: %s", content)
|
self.logger.info("Received LLM response: %s", content)
|
||||||
# Here the BDI can pass the message back as a response
|
|
||||||
|
speech_command = SpeechCommand(data=content)
|
||||||
|
|
||||||
|
message = Message(
|
||||||
|
to=settings.agent_settings.ri_command_agent_name
|
||||||
|
+ "@"
|
||||||
|
+ settings.agent_settings.host,
|
||||||
|
sender=self.agent.jid,
|
||||||
|
body=speech_command.model_dump_json(),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.logger.debug("Sending message: %s", message)
|
||||||
|
|
||||||
|
await self.send(message)
|
||||||
case _:
|
case _:
|
||||||
self.logger.debug("Not from the llm, discarding message")
|
self.logger.debug("Not from the llm, discarding message")
|
||||||
pass
|
pass
|
||||||
|
|||||||
@@ -13,23 +13,23 @@ class BeliefFromText(CyclicBehaviour):
|
|||||||
|
|
||||||
# TODO: LLM prompt nog hardcoded
|
# TODO: LLM prompt nog hardcoded
|
||||||
llm_instruction_prompt = """
|
llm_instruction_prompt = """
|
||||||
You are an information extraction assistent for a BDI agent.
|
You are an information extraction assistent for a BDI agent. Your task is to extract values \
|
||||||
Your task is to extract values from a user's text to bind a list of ungrounded beliefs. Rules:
|
from a user's text to bind a list of ungrounded beliefs. Rules:
|
||||||
You will receive a JSON object with "beliefs"
|
You will receive a JSON object with "beliefs" (a list of ungrounded AgentSpeak beliefs) \
|
||||||
(a list of ungrounded AgentSpeak beliefs) and "text" (user's transcript).
|
and "text" (user's transcript).
|
||||||
Analyze the text to find values that sematically match the variables (X,Y,Z) in the beliefs.
|
Analyze the text to find values that sematically match the variables (X,Y,Z) in the beliefs.
|
||||||
A single piece of text might contain multiple instances that match a belief.
|
A single piece of text might contain multiple instances that match a belief.
|
||||||
Respond ONLY with a single JSON object.
|
Respond ONLY with a single JSON object.
|
||||||
The JSON object's keys should be the belief functors (e.g., "weather").
|
The JSON object's keys should be the belief functors (e.g., "weather").
|
||||||
The value for each key must be a list of lists.
|
The value for each key must be a list of lists.
|
||||||
Each inner list must contain the extracted arguments
|
Each inner list must contain the extracted arguments (as strings) for one instance \
|
||||||
(as strings) for one instance of that belief.
|
of that belief.
|
||||||
CRITICAL: If no information in the text matches a belief,
|
CRITICAL: If no information in the text matches a belief, DO NOT include that key \
|
||||||
DO NOT include that key in your response.
|
in your response.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# on_start agent receives message containing the beliefs to look out
|
# on_start agent receives message containing the beliefs to look out for and
|
||||||
# for and sets up the LLM with instruction prompt
|
# sets up the LLM with instruction prompt
|
||||||
# async def on_start(self):
|
# async def on_start(self):
|
||||||
# msg = await self.receive(timeout=0.1)
|
# msg = await self.receive(timeout=0.1)
|
||||||
# self.beliefs = dict uit message
|
# self.beliefs = dict uit message
|
||||||
|
|||||||
@@ -70,8 +70,7 @@ class ContinuousBeliefCollector(CyclicBehaviour):
|
|||||||
Expected payload:
|
Expected payload:
|
||||||
{
|
{
|
||||||
"type": "belief_extraction_text",
|
"type": "belief_extraction_text",
|
||||||
"beliefs": {"user_said": ["hello"","Can you help me?",
|
"beliefs": {"user_said": ["Can you help me?"]}
|
||||||
"stop talking to me","No","Pepper do a dance"]}
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -3,8 +3,10 @@ LLM Agent module for routing text queries from the BDI Core Agent to a local LLM
|
|||||||
service and returning its responses back to the BDI Core Agent.
|
service and returning its responses back to the BDI Core Agent.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
from typing import Any
|
import re
|
||||||
|
from collections.abc import AsyncGenerator
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
from spade.agent import Agent
|
from spade.agent import Agent
|
||||||
@@ -54,11 +56,16 @@ class LLMAgent(Agent):
|
|||||||
|
|
||||||
async def _process_bdi_message(self, message: Message):
|
async def _process_bdi_message(self, message: Message):
|
||||||
"""
|
"""
|
||||||
Forwards user text to the LLM and replies with the generated text.
|
Forwards user text from the BDI to the LLM and replies with the generated text in chunks
|
||||||
|
separated by punctuation.
|
||||||
"""
|
"""
|
||||||
user_text = message.body
|
user_text = message.body
|
||||||
llm_response = await self._query_llm(user_text)
|
# Consume the streaming generator and send a reply for every chunk
|
||||||
await self._reply(llm_response)
|
async for chunk in self._query_llm(user_text):
|
||||||
|
await self._reply(chunk)
|
||||||
|
self.agent.logger.debug(
|
||||||
|
"Finished processing BDI message. Response sent in chunks to BDI Core Agent."
|
||||||
|
)
|
||||||
|
|
||||||
async def _reply(self, msg: str):
|
async def _reply(self, msg: str):
|
||||||
"""
|
"""
|
||||||
@@ -69,48 +76,89 @@ class LLMAgent(Agent):
|
|||||||
body=msg,
|
body=msg,
|
||||||
)
|
)
|
||||||
await self.send(reply)
|
await self.send(reply)
|
||||||
self.agent.logger.info("Reply sent to BDI Core Agent")
|
|
||||||
|
|
||||||
async def _query_llm(self, prompt: str) -> str:
|
async def _query_llm(self, prompt: str) -> AsyncGenerator[str]:
|
||||||
"""
|
"""
|
||||||
Sends a chat completion request to the local LLM service.
|
Sends a chat completion request to the local LLM service and streams the response by
|
||||||
|
yielding fragments separated by punctuation like.
|
||||||
|
|
||||||
:param prompt: Input text prompt to pass to the LLM.
|
:param prompt: Input text prompt to pass to the LLM.
|
||||||
:return: LLM-generated content or fallback message.
|
:yield: Fragments of the LLM-generated content.
|
||||||
"""
|
"""
|
||||||
async with httpx.AsyncClient(timeout=120.0) as client:
|
instructions = LLMInstructions(
|
||||||
# Example dynamic content for future (optional)
|
"- Be friendly and respectful.\n"
|
||||||
|
"- Make the conversation feel natural and engaging.\n"
|
||||||
|
"- Speak like a pirate.\n"
|
||||||
|
"- When the user asks what you can do, tell them.",
|
||||||
|
"- Try to learn the user's name during conversation.\n"
|
||||||
|
"- Suggest playing a game of asking yes or no questions where you think of a word "
|
||||||
|
"and the user must guess it.",
|
||||||
|
)
|
||||||
|
messages = [
|
||||||
|
{
|
||||||
|
"role": "developer",
|
||||||
|
"content": instructions.build_developer_instruction(),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": prompt,
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
instructions = LLMInstructions()
|
try:
|
||||||
developer_instruction = instructions.build_developer_instruction()
|
current_chunk = ""
|
||||||
|
async for token in self._stream_query_llm(messages):
|
||||||
|
current_chunk += token
|
||||||
|
|
||||||
response = await client.post(
|
# Stream the message in chunks separated by punctuation.
|
||||||
|
# We include the delimiter in the emitted chunk for natural flow.
|
||||||
|
pattern = re.compile(r".*?(?:,|;|:|—|–|\.{3}|…|\.|\?|!)\s*", re.DOTALL)
|
||||||
|
for m in pattern.finditer(current_chunk):
|
||||||
|
chunk = m.group(0)
|
||||||
|
if chunk:
|
||||||
|
yield current_chunk
|
||||||
|
current_chunk = ""
|
||||||
|
|
||||||
|
# Yield any remaining tail
|
||||||
|
if current_chunk:
|
||||||
|
yield current_chunk
|
||||||
|
except httpx.HTTPError as err:
|
||||||
|
self.agent.logger.error("HTTP error.", exc_info=err)
|
||||||
|
yield "LLM service unavailable."
|
||||||
|
except Exception as err:
|
||||||
|
self.agent.logger.error("Unexpected error.", exc_info=err)
|
||||||
|
yield "Error processing the request."
|
||||||
|
|
||||||
|
async def _stream_query_llm(self, messages) -> AsyncGenerator[str]:
|
||||||
|
"""Raises httpx.HTTPError when the API gives an error."""
|
||||||
|
async with httpx.AsyncClient(timeout=None) as client:
|
||||||
|
async with client.stream(
|
||||||
|
"POST",
|
||||||
settings.llm_settings.local_llm_url,
|
settings.llm_settings.local_llm_url,
|
||||||
headers={"Content-Type": "application/json"},
|
|
||||||
json={
|
json={
|
||||||
"model": settings.llm_settings.local_llm_model,
|
"model": settings.llm_settings.local_llm_model,
|
||||||
"messages": [
|
"messages": messages,
|
||||||
{"role": "developer", "content": developer_instruction},
|
|
||||||
{"role": "user", "content": prompt},
|
|
||||||
],
|
|
||||||
"temperature": 0.3,
|
"temperature": 0.3,
|
||||||
|
"stream": True,
|
||||||
},
|
},
|
||||||
)
|
) as response:
|
||||||
|
|
||||||
try:
|
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
data: dict[str, Any] = response.json()
|
|
||||||
return (
|
async for line in response.aiter_lines():
|
||||||
data.get("choices", [{}])[0]
|
if not line or not line.startswith("data: "):
|
||||||
.get("message", {})
|
continue
|
||||||
.get("content", "No response")
|
|
||||||
)
|
data = line[len("data: ") :]
|
||||||
except httpx.HTTPError as err:
|
if data.strip() == "[DONE]":
|
||||||
self.agent.logger.error("HTTP error: %s", err)
|
break
|
||||||
return "LLM service unavailable."
|
|
||||||
except Exception as err:
|
try:
|
||||||
self.agent.logger.error("Unexpected error: %s", err)
|
event = json.loads(data)
|
||||||
return "Error processing the request."
|
delta = event.get("choices", [{}])[0].get("delta", {}).get("content")
|
||||||
|
if delta:
|
||||||
|
yield delta
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
self.agent.logger.error("Failed to parse LLM response: %s", data)
|
||||||
|
|
||||||
async def setup(self):
|
async def setup(self):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -28,7 +28,9 @@ class LLMInstructions:
|
|||||||
"""
|
"""
|
||||||
sections = [
|
sections = [
|
||||||
"You are a Pepper robot engaging in natural human conversation.",
|
"You are a Pepper robot engaging in natural human conversation.",
|
||||||
"Keep responses between 1–5 sentences, unless instructed otherwise.\n",
|
"Keep responses between 1–3 sentences, unless told otherwise.\n",
|
||||||
|
"You're given goals to reach. Reach them in order, but make the conversation feel "
|
||||||
|
"natural. Some turns you should not try to achieve your goals.\n",
|
||||||
]
|
]
|
||||||
|
|
||||||
if self.norms:
|
if self.norms:
|
||||||
|
|||||||
@@ -11,8 +11,9 @@ class BeliefTextAgent(Agent):
|
|||||||
class SendOnceBehaviourBlfText(OneShotBehaviour):
|
class SendOnceBehaviourBlfText(OneShotBehaviour):
|
||||||
async def run(self):
|
async def run(self):
|
||||||
to_jid = (
|
to_jid = (
|
||||||
f"{settings.agent_settings.belief_collector_agent_name}"
|
settings.agent_settings.belief_collector_agent_name
|
||||||
f"@{settings.agent_settings.host}"
|
+ "@"
|
||||||
|
+ settings.agent_settings.host
|
||||||
)
|
)
|
||||||
|
|
||||||
# Send multiple beliefs in one JSON payload
|
# Send multiple beliefs in one JSON payload
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
import spade.agent
|
||||||
import zmq
|
import zmq
|
||||||
from spade.agent import Agent
|
from spade.agent import Agent
|
||||||
from spade.behaviour import CyclicBehaviour
|
from spade.behaviour import CyclicBehaviour
|
||||||
@@ -32,6 +33,8 @@ class RICommandAgent(Agent):
|
|||||||
self.bind = bind
|
self.bind = bind
|
||||||
|
|
||||||
class SendCommandsBehaviour(CyclicBehaviour):
|
class SendCommandsBehaviour(CyclicBehaviour):
|
||||||
|
"""Behaviour for sending commands received from the UI."""
|
||||||
|
|
||||||
async def run(self):
|
async def run(self):
|
||||||
"""
|
"""
|
||||||
Run the command publishing loop indefinetely.
|
Run the command publishing loop indefinetely.
|
||||||
@@ -50,6 +53,18 @@ class RICommandAgent(Agent):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("Error processing message: %s", e)
|
logger.error("Error processing message: %s", e)
|
||||||
|
|
||||||
|
class SendPythonCommandsBehaviour(CyclicBehaviour):
|
||||||
|
"""Behaviour for sending commands received from other Python agents."""
|
||||||
|
|
||||||
|
async def run(self):
|
||||||
|
message: spade.agent.Message = await self.receive(timeout=0.1)
|
||||||
|
if message and message.to == self.agent.jid:
|
||||||
|
try:
|
||||||
|
speech_command = SpeechCommand.model_validate_json(message.body)
|
||||||
|
await self.agent.pubsocket.send_json(speech_command.model_dump())
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Error processing message: %s", e)
|
||||||
|
|
||||||
async def setup(self):
|
async def setup(self):
|
||||||
"""
|
"""
|
||||||
Setup the command agent
|
Setup the command agent
|
||||||
@@ -73,5 +88,6 @@ class RICommandAgent(Agent):
|
|||||||
# Add behaviour to our agent
|
# Add behaviour to our agent
|
||||||
commands_behaviour = self.SendCommandsBehaviour()
|
commands_behaviour = self.SendCommandsBehaviour()
|
||||||
self.add_behaviour(commands_behaviour)
|
self.add_behaviour(commands_behaviour)
|
||||||
|
self.add_behaviour(self.SendPythonCommandsBehaviour())
|
||||||
|
|
||||||
logger.info("Finished setting up %s", self.jid)
|
logger.info("Finished setting up %s", self.jid)
|
||||||
|
|||||||
@@ -63,7 +63,25 @@ class RICommunicationAgent(Agent):
|
|||||||
# We didnt get a reply :(
|
# We didnt get a reply :(
|
||||||
except TimeoutError:
|
except TimeoutError:
|
||||||
logger.info("No ping retrieved in 3 seconds, killing myself.")
|
logger.info("No ping retrieved in 3 seconds, killing myself.")
|
||||||
self.kill()
|
|
||||||
|
# Tell UI we're disconnected.
|
||||||
|
topic = b"ping"
|
||||||
|
data = json.dumps(False).encode()
|
||||||
|
if self.agent.pub_socket is None:
|
||||||
|
logger.error("communication agent pub socket not correctly initialized.")
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(
|
||||||
|
self.agent.pub_socket.send_multipart([topic, data]), 5
|
||||||
|
)
|
||||||
|
except TimeoutError:
|
||||||
|
logger.error(
|
||||||
|
"Initial connection ping for router timed"
|
||||||
|
" out in ri_communication_agent."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Try to reboot.
|
||||||
|
self.agent.setup()
|
||||||
|
|
||||||
logger.debug('Received message "%s"', message)
|
logger.debug('Received message "%s"', message)
|
||||||
if "endpoint" not in message:
|
if "endpoint" not in message:
|
||||||
|
|||||||
@@ -36,16 +36,16 @@ class SpeechRecognizer(abc.ABC):
|
|||||||
def _estimate_max_tokens(audio: np.ndarray) -> int:
|
def _estimate_max_tokens(audio: np.ndarray) -> int:
|
||||||
"""
|
"""
|
||||||
Estimate the maximum length of a given audio sample in tokens. Assumes a maximum speaking
|
Estimate the maximum length of a given audio sample in tokens. Assumes a maximum speaking
|
||||||
rate of 300 words per minute (2x average), and assumes that 3 words is 4 tokens.
|
rate of 450 words per minute (3x average), and assumes that 3 words is 4 tokens.
|
||||||
|
|
||||||
:param audio: The audio sample (16 kHz) to use for length estimation.
|
:param audio: The audio sample (16 kHz) to use for length estimation.
|
||||||
:return: The estimated length of the transcribed audio in tokens.
|
:return: The estimated length of the transcribed audio in tokens.
|
||||||
"""
|
"""
|
||||||
length_seconds = len(audio) / 16_000
|
length_seconds = len(audio) / 16_000
|
||||||
length_minutes = length_seconds / 60
|
length_minutes = length_seconds / 60
|
||||||
word_count = length_minutes * 300
|
word_count = length_minutes * 450
|
||||||
token_count = word_count / 3 * 4
|
token_count = word_count / 3 * 4
|
||||||
return int(token_count)
|
return int(token_count) + 10
|
||||||
|
|
||||||
def _get_decode_options(self, audio: np.ndarray) -> dict:
|
def _get_decode_options(self, audio: np.ndarray) -> dict:
|
||||||
"""
|
"""
|
||||||
@@ -85,9 +85,10 @@ class MLXWhisperSpeechRecognizer(SpeechRecognizer):
|
|||||||
def recognize_speech(self, audio: np.ndarray) -> str:
|
def recognize_speech(self, audio: np.ndarray) -> str:
|
||||||
self.load_model()
|
self.load_model()
|
||||||
return mlx_whisper.transcribe(
|
return mlx_whisper.transcribe(
|
||||||
audio, path_or_hf_repo=self.model_name, decode_options=self._get_decode_options(audio)
|
audio,
|
||||||
)["text"]
|
path_or_hf_repo=self.model_name,
|
||||||
return mlx_whisper.transcribe(audio, path_or_hf_repo=self.model_name)["text"].strip()
|
**self._get_decode_options(audio),
|
||||||
|
)["text"].strip()
|
||||||
|
|
||||||
|
|
||||||
class OpenAIWhisperSpeechRecognizer(SpeechRecognizer):
|
class OpenAIWhisperSpeechRecognizer(SpeechRecognizer):
|
||||||
@@ -103,6 +104,4 @@ class OpenAIWhisperSpeechRecognizer(SpeechRecognizer):
|
|||||||
|
|
||||||
def recognize_speech(self, audio: np.ndarray) -> str:
|
def recognize_speech(self, audio: np.ndarray) -> str:
|
||||||
self.load_model()
|
self.load_model()
|
||||||
return whisper.transcribe(
|
return whisper.transcribe(self.model, audio, **self._get_decode_options(audio))["text"]
|
||||||
self.model, audio, decode_options=self._get_decode_options(audio)
|
|
||||||
)["text"]
|
|
||||||
|
|||||||
@@ -58,6 +58,10 @@ class TranscriptionAgent(Agent):
|
|||||||
audio = await self.audio_in_socket.recv()
|
audio = await self.audio_in_socket.recv()
|
||||||
audio = np.frombuffer(audio, dtype=np.float32)
|
audio = np.frombuffer(audio, dtype=np.float32)
|
||||||
speech = await self._transcribe(audio)
|
speech = await self._transcribe(audio)
|
||||||
|
if not speech:
|
||||||
|
logger.info("Nothing transcribed.")
|
||||||
|
return
|
||||||
|
|
||||||
logger.info("Transcribed speech: %s", speech)
|
logger.info("Transcribed speech: %s", speech)
|
||||||
|
|
||||||
await self._share_transcription(speech)
|
await self._share_transcription(speech)
|
||||||
|
|||||||
@@ -54,8 +54,20 @@ class Streaming(CyclicBehaviour):
|
|||||||
|
|
||||||
self.audio_buffer = np.array([], dtype=np.float32)
|
self.audio_buffer = np.array([], dtype=np.float32)
|
||||||
self.i_since_speech = 100 # Used to allow small pauses in speech
|
self.i_since_speech = 100 # Used to allow small pauses in speech
|
||||||
|
self._ready = False
|
||||||
|
|
||||||
|
async def reset(self):
|
||||||
|
"""Clears the ZeroMQ queue and tells this behavior to start."""
|
||||||
|
discarded = 0
|
||||||
|
while await self.audio_in_poller.poll(1) is not None:
|
||||||
|
discarded += 1
|
||||||
|
logging.info(f"Discarded {discarded} audio packets before starting.")
|
||||||
|
self._ready = True
|
||||||
|
|
||||||
async def run(self) -> None:
|
async def run(self) -> None:
|
||||||
|
if not self._ready:
|
||||||
|
return
|
||||||
|
|
||||||
data = await self.audio_in_poller.poll()
|
data = await self.audio_in_poller.poll()
|
||||||
if data is None:
|
if data is None:
|
||||||
if len(self.audio_buffer) > 0:
|
if len(self.audio_buffer) > 0:
|
||||||
@@ -107,6 +119,8 @@ class VADAgent(Agent):
|
|||||||
self.audio_in_socket: azmq.Socket | None = None
|
self.audio_in_socket: azmq.Socket | None = None
|
||||||
self.audio_out_socket: azmq.Socket | None = None
|
self.audio_out_socket: azmq.Socket | None = None
|
||||||
|
|
||||||
|
self.streaming_behaviour: Streaming | None = None
|
||||||
|
|
||||||
async def stop(self):
|
async def stop(self):
|
||||||
"""
|
"""
|
||||||
Stop listening to audio, stop publishing audio, close sockets.
|
Stop listening to audio, stop publishing audio, close sockets.
|
||||||
@@ -149,8 +163,8 @@ class VADAgent(Agent):
|
|||||||
return
|
return
|
||||||
audio_out_address = f"tcp://localhost:{audio_out_port}"
|
audio_out_address = f"tcp://localhost:{audio_out_port}"
|
||||||
|
|
||||||
streaming = Streaming(self.audio_in_socket, self.audio_out_socket)
|
self.streaming_behaviour = Streaming(self.audio_in_socket, self.audio_out_socket)
|
||||||
self.add_behaviour(streaming)
|
self.add_behaviour(self.streaming_behaviour)
|
||||||
|
|
||||||
# Start agents dependent on the output audio fragments here
|
# Start agents dependent on the output audio fragments here
|
||||||
transcriber = TranscriptionAgent(audio_out_address)
|
transcriber = TranscriptionAgent(audio_out_address)
|
||||||
|
|||||||
@@ -22,8 +22,8 @@ async def receive_command(command: SpeechCommand, request: Request):
|
|||||||
topic = b"command"
|
topic = b"command"
|
||||||
|
|
||||||
# TODO: Check with Kasper
|
# TODO: Check with Kasper
|
||||||
pub_socket: Socket = request.app.state.internal_comm_socket
|
pub_socket: Socket = request.app.state.endpoints_pub_socket
|
||||||
pub_socket.send_multipart([topic, command.model_dump_json().encode()])
|
await pub_socket.send_multipart([topic, command.model_dump_json().encode()])
|
||||||
|
|
||||||
return {"status": "Command received"}
|
return {"status": "Command received"}
|
||||||
|
|
||||||
|
|||||||
@@ -14,8 +14,6 @@ from control_backend.agents.bdi.bdi_core import BDICoreAgent
|
|||||||
from control_backend.agents.bdi.text_extractor import TBeliefExtractor
|
from control_backend.agents.bdi.text_extractor import TBeliefExtractor
|
||||||
from control_backend.agents.belief_collector.belief_collector import BeliefCollectorAgent
|
from control_backend.agents.belief_collector.belief_collector import BeliefCollectorAgent
|
||||||
from control_backend.agents.llm.llm import LLMAgent
|
from control_backend.agents.llm.llm import LLMAgent
|
||||||
|
|
||||||
# Internal imports
|
|
||||||
from control_backend.agents.ri_communication_agent import RICommunicationAgent
|
from control_backend.agents.ri_communication_agent import RICommunicationAgent
|
||||||
from control_backend.agents.vad_agent import VADAgent
|
from control_backend.agents.vad_agent import VADAgent
|
||||||
from control_backend.api.v1.router import api_router
|
from control_backend.api.v1.router import api_router
|
||||||
@@ -99,6 +97,8 @@ async def lifespan(app: FastAPI):
|
|||||||
|
|
||||||
_temp_vad_agent = VADAgent("tcp://localhost:5558", False)
|
_temp_vad_agent = VADAgent("tcp://localhost:5558", False)
|
||||||
await _temp_vad_agent.start()
|
await _temp_vad_agent.start()
|
||||||
|
logger.info("VAD agent started, now making ready...")
|
||||||
|
await _temp_vad_agent.streaming_behaviour.reset()
|
||||||
|
|
||||||
yield
|
yield
|
||||||
|
|
||||||
|
|||||||
@@ -7,25 +7,21 @@ import zmq
|
|||||||
from control_backend.agents.ri_command_agent import RICommandAgent
|
from control_backend.agents.ri_command_agent import RICommandAgent
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.fixture
|
||||||
async def test_setup_bind(monkeypatch):
|
def zmq_context(mocker):
|
||||||
"""Test setup with bind=True"""
|
mock_context = mocker.patch("control_backend.agents.vad_agent.azmq.Context.instance")
|
||||||
fake_socket = MagicMock()
|
mock_context.return_value = MagicMock()
|
||||||
fake_context = MagicMock()
|
return mock_context
|
||||||
fake_context.socket.return_value = fake_socket
|
|
||||||
|
|
||||||
# Patch Context.instance() to return fake_context
|
|
||||||
monkeypatch.setattr(
|
@pytest.mark.asyncio
|
||||||
"control_backend.agents.ri_command_agent.Context",
|
async def test_setup_bind(zmq_context, mocker):
|
||||||
MagicMock(instance=MagicMock(return_value=fake_context)),
|
"""Test setup with bind=True"""
|
||||||
)
|
fake_socket = zmq_context.return_value.socket.return_value
|
||||||
|
|
||||||
agent = RICommandAgent("test@server", "password", address="tcp://localhost:5555", bind=True)
|
agent = RICommandAgent("test@server", "password", address="tcp://localhost:5555", bind=True)
|
||||||
|
settings = mocker.patch("control_backend.agents.ri_command_agent.settings")
|
||||||
monkeypatch.setattr(
|
settings.zmq_settings.internal_sub_address = "tcp://internal:1234"
|
||||||
"control_backend.agents.ri_command_agent.settings",
|
|
||||||
MagicMock(zmq_settings=MagicMock(internal_sub_address="tcp://internal:1234")),
|
|
||||||
)
|
|
||||||
|
|
||||||
await agent.setup()
|
await agent.setup()
|
||||||
|
|
||||||
@@ -36,23 +32,13 @@ async def test_setup_bind(monkeypatch):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_setup_connect(monkeypatch):
|
async def test_setup_connect(zmq_context, mocker):
|
||||||
"""Test setup with bind=False"""
|
"""Test setup with bind=False"""
|
||||||
fake_socket = MagicMock()
|
fake_socket = zmq_context.return_value.socket.return_value
|
||||||
fake_context = MagicMock()
|
|
||||||
fake_context.socket.return_value = fake_socket
|
|
||||||
|
|
||||||
# Patch Context.instance() to return fake_context
|
|
||||||
monkeypatch.setattr(
|
|
||||||
"control_backend.agents.ri_command_agent.Context",
|
|
||||||
MagicMock(instance=MagicMock(return_value=fake_context)),
|
|
||||||
)
|
|
||||||
|
|
||||||
agent = RICommandAgent("test@server", "password", address="tcp://localhost:5555", bind=False)
|
agent = RICommandAgent("test@server", "password", address="tcp://localhost:5555", bind=False)
|
||||||
monkeypatch.setattr(
|
settings = mocker.patch("control_backend.agents.ri_command_agent.settings")
|
||||||
"control_backend.agents.ri_command_agent.settings",
|
settings.zmq_settings.internal_sub_address = "tcp://internal:1234"
|
||||||
MagicMock(zmq_settings=MagicMock(internal_sub_address="tcp://internal:1234")),
|
|
||||||
)
|
|
||||||
|
|
||||||
await agent.setup()
|
await agent.setup()
|
||||||
|
|
||||||
|
|||||||
@@ -84,25 +84,24 @@ def fake_json_invalid_id_negototiate():
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def zmq_context(mocker):
|
||||||
|
mock_context = mocker.patch("control_backend.agents.vad_agent.azmq.Context.instance")
|
||||||
|
mock_context.return_value = MagicMock()
|
||||||
|
return mock_context
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_setup_creates_socket_and_negotiate_1(monkeypatch):
|
async def test_setup_creates_socket_and_negotiate_1(zmq_context):
|
||||||
"""
|
"""
|
||||||
Test the setup of the communication agent
|
Test the setup of the communication agent
|
||||||
"""
|
"""
|
||||||
# --- Arrange ---
|
# --- Arrange ---
|
||||||
fake_socket = MagicMock()
|
fake_socket = zmq_context.return_value.socket.return_value
|
||||||
fake_socket.send_json = AsyncMock()
|
fake_socket.send_json = AsyncMock()
|
||||||
fake_socket.recv_json = fake_json_correct_negototiate_1()
|
fake_socket.recv_json = fake_json_correct_negototiate_1()
|
||||||
fake_socket.send_multipart = AsyncMock()
|
fake_socket.send_multipart = AsyncMock()
|
||||||
|
|
||||||
# Mock context.socket to return our fake socket
|
|
||||||
fake_context = MagicMock()
|
|
||||||
fake_context.socket.return_value = fake_socket
|
|
||||||
monkeypatch.setattr(
|
|
||||||
"control_backend.agents.ri_communication_agent.Context",
|
|
||||||
MagicMock(instance=MagicMock(return_value=fake_context)),
|
|
||||||
)
|
|
||||||
|
|
||||||
# Mock RICommandAgent agent startup
|
# Mock RICommandAgent agent startup
|
||||||
with patch(
|
with patch(
|
||||||
"control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True
|
"control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True
|
||||||
@@ -135,24 +134,16 @@ async def test_setup_creates_socket_and_negotiate_1(monkeypatch):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_setup_creates_socket_and_negotiate_2(monkeypatch):
|
async def test_setup_creates_socket_and_negotiate_2(zmq_context):
|
||||||
"""
|
"""
|
||||||
Test the setup of the communication agent
|
Test the setup of the communication agent
|
||||||
"""
|
"""
|
||||||
# --- Arrange ---
|
# --- Arrange ---
|
||||||
fake_socket = MagicMock()
|
fake_socket = zmq_context.return_value.socket.return_value
|
||||||
fake_socket.send_json = AsyncMock()
|
fake_socket.send_json = AsyncMock()
|
||||||
fake_socket.recv_json = fake_json_correct_negototiate_2()
|
fake_socket.recv_json = fake_json_correct_negototiate_2()
|
||||||
fake_socket.send_multipart = AsyncMock()
|
fake_socket.send_multipart = AsyncMock()
|
||||||
|
|
||||||
# Mock context.socket to return our fake socket
|
|
||||||
fake_context = MagicMock()
|
|
||||||
fake_context.socket.return_value = fake_socket
|
|
||||||
monkeypatch.setattr(
|
|
||||||
"control_backend.agents.ri_communication_agent.Context",
|
|
||||||
MagicMock(instance=MagicMock(return_value=fake_context)),
|
|
||||||
)
|
|
||||||
|
|
||||||
# Mock RICommandAgent agent startup
|
# Mock RICommandAgent agent startup
|
||||||
with patch(
|
with patch(
|
||||||
"control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True
|
"control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True
|
||||||
@@ -185,24 +176,16 @@ async def test_setup_creates_socket_and_negotiate_2(monkeypatch):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_setup_creates_socket_and_negotiate_3(monkeypatch, caplog):
|
async def test_setup_creates_socket_and_negotiate_3(zmq_context, caplog):
|
||||||
"""
|
"""
|
||||||
Test the functionality of setup with incorrect negotiation message
|
Test the functionality of setup with incorrect negotiation message
|
||||||
"""
|
"""
|
||||||
# --- Arrange ---
|
# --- Arrange ---
|
||||||
fake_socket = MagicMock()
|
fake_socket = zmq_context.return_value.socket.return_value
|
||||||
fake_socket.send_json = AsyncMock()
|
fake_socket.send_json = AsyncMock()
|
||||||
fake_socket.recv_json = fake_json_wrong_negototiate_1()
|
fake_socket.recv_json = fake_json_wrong_negototiate_1()
|
||||||
fake_socket.send_multipart = AsyncMock()
|
fake_socket.send_multipart = AsyncMock()
|
||||||
|
|
||||||
# Mock context.socket to return our fake socket
|
|
||||||
fake_context = MagicMock()
|
|
||||||
fake_context.socket.return_value = fake_socket
|
|
||||||
monkeypatch.setattr(
|
|
||||||
"control_backend.agents.ri_communication_agent.Context",
|
|
||||||
MagicMock(instance=MagicMock(return_value=fake_context)),
|
|
||||||
)
|
|
||||||
|
|
||||||
# Mock RICommandAgent agent startup
|
# Mock RICommandAgent agent startup
|
||||||
|
|
||||||
# We are sending wrong negotiation info to the communication agent,
|
# We are sending wrong negotiation info to the communication agent,
|
||||||
@@ -235,24 +218,16 @@ async def test_setup_creates_socket_and_negotiate_3(monkeypatch, caplog):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_setup_creates_socket_and_negotiate_4(monkeypatch):
|
async def test_setup_creates_socket_and_negotiate_4(zmq_context):
|
||||||
"""
|
"""
|
||||||
Test the setup of the communication agent with different bind value
|
Test the setup of the communication agent with different bind value
|
||||||
"""
|
"""
|
||||||
# --- Arrange ---
|
# --- Arrange ---
|
||||||
fake_socket = MagicMock()
|
fake_socket = zmq_context.return_value.socket.return_value
|
||||||
fake_socket.send_json = AsyncMock()
|
fake_socket.send_json = AsyncMock()
|
||||||
fake_socket.recv_json = fake_json_correct_negototiate_3()
|
fake_socket.recv_json = fake_json_correct_negototiate_3()
|
||||||
fake_socket.send_multipart = AsyncMock()
|
fake_socket.send_multipart = AsyncMock()
|
||||||
|
|
||||||
# Mock context.socket to return our fake socket
|
|
||||||
fake_context = MagicMock()
|
|
||||||
fake_context.socket.return_value = fake_socket
|
|
||||||
monkeypatch.setattr(
|
|
||||||
"control_backend.agents.ri_communication_agent.Context",
|
|
||||||
MagicMock(instance=MagicMock(return_value=fake_context)),
|
|
||||||
)
|
|
||||||
|
|
||||||
# Mock RICommandAgent agent startup
|
# Mock RICommandAgent agent startup
|
||||||
with patch(
|
with patch(
|
||||||
"control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True
|
"control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True
|
||||||
@@ -284,24 +259,16 @@ async def test_setup_creates_socket_and_negotiate_4(monkeypatch):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_setup_creates_socket_and_negotiate_5(monkeypatch):
|
async def test_setup_creates_socket_and_negotiate_5(zmq_context):
|
||||||
"""
|
"""
|
||||||
Test the setup of the communication agent
|
Test the setup of the communication agent
|
||||||
"""
|
"""
|
||||||
# --- Arrange ---
|
# --- Arrange ---
|
||||||
fake_socket = MagicMock()
|
fake_socket = zmq_context.return_value.socket.return_value
|
||||||
fake_socket.send_json = AsyncMock()
|
fake_socket.send_json = AsyncMock()
|
||||||
fake_socket.recv_json = fake_json_correct_negototiate_4()
|
fake_socket.recv_json = fake_json_correct_negototiate_4()
|
||||||
fake_socket.send_multipart = AsyncMock()
|
fake_socket.send_multipart = AsyncMock()
|
||||||
|
|
||||||
# Mock context.socket to return our fake socket
|
|
||||||
fake_context = MagicMock()
|
|
||||||
fake_context.socket.return_value = fake_socket
|
|
||||||
monkeypatch.setattr(
|
|
||||||
"control_backend.agents.ri_communication_agent.Context",
|
|
||||||
MagicMock(instance=MagicMock(return_value=fake_context)),
|
|
||||||
)
|
|
||||||
|
|
||||||
# Mock RICommandAgent agent startup
|
# Mock RICommandAgent agent startup
|
||||||
with patch(
|
with patch(
|
||||||
"control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True
|
"control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True
|
||||||
@@ -333,24 +300,16 @@ async def test_setup_creates_socket_and_negotiate_5(monkeypatch):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_setup_creates_socket_and_negotiate_6(monkeypatch):
|
async def test_setup_creates_socket_and_negotiate_6(zmq_context):
|
||||||
"""
|
"""
|
||||||
Test the setup of the communication agent
|
Test the setup of the communication agent
|
||||||
"""
|
"""
|
||||||
# --- Arrange ---
|
# --- Arrange ---
|
||||||
fake_socket = MagicMock()
|
fake_socket = zmq_context.return_value.socket.return_value
|
||||||
fake_socket.send_json = AsyncMock()
|
fake_socket.send_json = AsyncMock()
|
||||||
fake_socket.recv_json = fake_json_correct_negototiate_5()
|
fake_socket.recv_json = fake_json_correct_negototiate_5()
|
||||||
fake_socket.send_multipart = AsyncMock()
|
fake_socket.send_multipart = AsyncMock()
|
||||||
|
|
||||||
# Mock context.socket to return our fake socket
|
|
||||||
fake_context = MagicMock()
|
|
||||||
fake_context.socket.return_value = fake_socket
|
|
||||||
monkeypatch.setattr(
|
|
||||||
"control_backend.agents.ri_communication_agent.Context",
|
|
||||||
MagicMock(instance=MagicMock(return_value=fake_context)),
|
|
||||||
)
|
|
||||||
|
|
||||||
# Mock RICommandAgent agent startup
|
# Mock RICommandAgent agent startup
|
||||||
with patch(
|
with patch(
|
||||||
"control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True
|
"control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True
|
||||||
@@ -382,28 +341,20 @@ async def test_setup_creates_socket_and_negotiate_6(monkeypatch):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_setup_creates_socket_and_negotiate_7(monkeypatch, caplog):
|
async def test_setup_creates_socket_and_negotiate_7(zmq_context, caplog):
|
||||||
"""
|
"""
|
||||||
Test the functionality of setup with incorrect id
|
Test the functionality of setup with incorrect id
|
||||||
"""
|
"""
|
||||||
# --- Arrange ---
|
# --- Arrange ---
|
||||||
fake_socket = MagicMock()
|
fake_socket = zmq_context.return_value.socket.return_value
|
||||||
fake_socket.send_json = AsyncMock()
|
fake_socket.send_json = AsyncMock()
|
||||||
fake_socket.recv_json = fake_json_invalid_id_negototiate()
|
fake_socket.recv_json = fake_json_invalid_id_negototiate()
|
||||||
fake_socket.send_multipart = AsyncMock()
|
fake_socket.send_multipart = AsyncMock()
|
||||||
|
|
||||||
# Mock context.socket to return our fake socket
|
|
||||||
fake_context = MagicMock()
|
|
||||||
fake_context.socket.return_value = fake_socket
|
|
||||||
monkeypatch.setattr(
|
|
||||||
"control_backend.agents.ri_communication_agent.Context",
|
|
||||||
MagicMock(instance=MagicMock(return_value=fake_context)),
|
|
||||||
)
|
|
||||||
|
|
||||||
# Mock RICommandAgent agent startup
|
# Mock RICommandAgent agent startup
|
||||||
|
|
||||||
# We are sending wrong negotiation info to the communication agent,
|
# We are sending wrong negotiation info to the communication agent,
|
||||||
# so we should retry and expect a etter response, within a limited time.
|
# so we should retry and expect a better response, within a limited time.
|
||||||
with patch(
|
with patch(
|
||||||
"control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True
|
"control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True
|
||||||
) as MockCommandAgent:
|
) as MockCommandAgent:
|
||||||
@@ -430,24 +381,16 @@ async def test_setup_creates_socket_and_negotiate_7(monkeypatch, caplog):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_setup_creates_socket_and_negotiate_timeout(monkeypatch, caplog):
|
async def test_setup_creates_socket_and_negotiate_timeout(zmq_context, caplog):
|
||||||
"""
|
"""
|
||||||
Test the functionality of setup with incorrect negotiation message
|
Test the functionality of setup with incorrect negotiation message
|
||||||
"""
|
"""
|
||||||
# --- Arrange ---
|
# --- Arrange ---
|
||||||
fake_socket = MagicMock()
|
fake_socket = zmq_context.return_value.socket.return_value
|
||||||
fake_socket.send_json = AsyncMock()
|
fake_socket.send_json = AsyncMock()
|
||||||
fake_socket.recv_json = AsyncMock(side_effect=asyncio.TimeoutError)
|
fake_socket.recv_json = AsyncMock(side_effect=asyncio.TimeoutError)
|
||||||
fake_socket.send_multipart = AsyncMock()
|
fake_socket.send_multipart = AsyncMock()
|
||||||
|
|
||||||
# Mock context.socket to return our fake socket
|
|
||||||
fake_context = MagicMock()
|
|
||||||
fake_context.socket.return_value = fake_socket
|
|
||||||
monkeypatch.setattr(
|
|
||||||
"control_backend.agents.ri_communication_agent.Context",
|
|
||||||
MagicMock(instance=MagicMock(return_value=fake_context)),
|
|
||||||
)
|
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
"control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True
|
"control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True
|
||||||
) as MockCommandAgent:
|
) as MockCommandAgent:
|
||||||
@@ -534,8 +477,8 @@ async def test_listen_behaviour_ping_wrong_endpoint(caplog):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_listen_behaviour_timeout(caplog):
|
async def test_listen_behaviour_timeout(zmq_context, caplog):
|
||||||
fake_socket = AsyncMock()
|
fake_socket = zmq_context.return_value.socket.return_value
|
||||||
fake_socket.send_json = AsyncMock()
|
fake_socket.send_json = AsyncMock()
|
||||||
# recv_json will never resolve, simulate timeout
|
# recv_json will never resolve, simulate timeout
|
||||||
fake_socket.recv_json = AsyncMock(side_effect=asyncio.TimeoutError)
|
fake_socket.recv_json = AsyncMock(side_effect=asyncio.TimeoutError)
|
||||||
@@ -585,20 +528,13 @@ async def test_listen_behaviour_ping_no_endpoint(caplog):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_setup_unexpected_exception(monkeypatch, caplog):
|
async def test_setup_unexpected_exception(zmq_context, caplog):
|
||||||
fake_socket = MagicMock()
|
fake_socket = zmq_context.return_value.socket.return_value
|
||||||
fake_socket.send_json = AsyncMock()
|
fake_socket.send_json = AsyncMock()
|
||||||
# Simulate unexpected exception during recv_json()
|
# Simulate unexpected exception during recv_json()
|
||||||
fake_socket.recv_json = AsyncMock(side_effect=Exception("boom!"))
|
fake_socket.recv_json = AsyncMock(side_effect=Exception("boom!"))
|
||||||
fake_socket.send_multipart = AsyncMock()
|
fake_socket.send_multipart = AsyncMock()
|
||||||
|
|
||||||
fake_context = MagicMock()
|
|
||||||
fake_context.socket.return_value = fake_socket
|
|
||||||
monkeypatch.setattr(
|
|
||||||
"control_backend.agents.ri_communication_agent.Context",
|
|
||||||
MagicMock(instance=MagicMock(return_value=fake_context)),
|
|
||||||
)
|
|
||||||
|
|
||||||
agent = RICommunicationAgent(
|
agent = RICommunicationAgent(
|
||||||
"test@server",
|
"test@server",
|
||||||
"password",
|
"password",
|
||||||
@@ -614,9 +550,9 @@ async def test_setup_unexpected_exception(monkeypatch, caplog):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_setup_unpacking_exception(monkeypatch, caplog):
|
async def test_setup_unpacking_exception(zmq_context, caplog):
|
||||||
# --- Arrange ---
|
# --- Arrange ---
|
||||||
fake_socket = MagicMock()
|
fake_socket = zmq_context.return_value.socket.return_value
|
||||||
fake_socket.send_json = AsyncMock()
|
fake_socket.send_json = AsyncMock()
|
||||||
fake_socket.send_multipart = AsyncMock()
|
fake_socket.send_multipart = AsyncMock()
|
||||||
|
|
||||||
@@ -627,14 +563,6 @@ async def test_setup_unpacking_exception(monkeypatch, caplog):
|
|||||||
} # missing 'port' and 'bind'
|
} # missing 'port' and 'bind'
|
||||||
fake_socket.recv_json = AsyncMock(return_value=malformed_data)
|
fake_socket.recv_json = AsyncMock(return_value=malformed_data)
|
||||||
|
|
||||||
# Patch context.socket
|
|
||||||
fake_context = MagicMock()
|
|
||||||
fake_context.socket.return_value = fake_socket
|
|
||||||
monkeypatch.setattr(
|
|
||||||
"control_backend.agents.ri_communication_agent.Context",
|
|
||||||
MagicMock(instance=MagicMock(return_value=fake_context)),
|
|
||||||
)
|
|
||||||
|
|
||||||
# Patch RICommandAgent so it won't actually start
|
# Patch RICommandAgent so it won't actually start
|
||||||
with patch(
|
with patch(
|
||||||
"control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True
|
"control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True
|
||||||
|
|||||||
@@ -10,7 +10,9 @@ from control_backend.agents.vad_agent import VADAgent
|
|||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def zmq_context(mocker):
|
def zmq_context(mocker):
|
||||||
return mocker.patch("control_backend.agents.vad_agent.zmq_context")
|
mock_context = mocker.patch("control_backend.agents.vad_agent.azmq.Context.instance")
|
||||||
|
mock_context.return_value = MagicMock()
|
||||||
|
return mock_context
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@@ -54,13 +56,18 @@ def test_in_socket_creation(zmq_context, do_bind: bool):
|
|||||||
|
|
||||||
assert vad_agent.audio_in_socket is not None
|
assert vad_agent.audio_in_socket is not None
|
||||||
|
|
||||||
zmq_context.socket.assert_called_once_with(zmq.SUB)
|
zmq_context.return_value.socket.assert_called_once_with(zmq.SUB)
|
||||||
zmq_context.socket.return_value.setsockopt_string.assert_called_once_with(zmq.SUBSCRIBE, "")
|
zmq_context.return_value.socket.return_value.setsockopt_string.assert_called_once_with(
|
||||||
|
zmq.SUBSCRIBE,
|
||||||
|
"",
|
||||||
|
)
|
||||||
|
|
||||||
if do_bind:
|
if do_bind:
|
||||||
zmq_context.socket.return_value.bind.assert_called_once_with("tcp://*:12345")
|
zmq_context.return_value.socket.return_value.bind.assert_called_once_with("tcp://*:12345")
|
||||||
else:
|
else:
|
||||||
zmq_context.socket.return_value.connect.assert_called_once_with("tcp://localhost:12345")
|
zmq_context.return_value.socket.return_value.connect.assert_called_once_with(
|
||||||
|
"tcp://localhost:12345"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_out_socket_creation(zmq_context):
|
def test_out_socket_creation(zmq_context):
|
||||||
@@ -73,8 +80,8 @@ def test_out_socket_creation(zmq_context):
|
|||||||
|
|
||||||
assert vad_agent.audio_out_socket is not None
|
assert vad_agent.audio_out_socket is not None
|
||||||
|
|
||||||
zmq_context.socket.assert_called_once_with(zmq.PUB)
|
zmq_context.return_value.socket.assert_called_once_with(zmq.PUB)
|
||||||
zmq_context.socket.return_value.bind_to_random_port.assert_called_once()
|
zmq_context.return_value.socket.return_value.bind_to_random_port.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@@ -83,7 +90,9 @@ async def test_out_socket_creation_failure(zmq_context):
|
|||||||
Test setup failure when the audio output socket cannot be created.
|
Test setup failure when the audio output socket cannot be created.
|
||||||
"""
|
"""
|
||||||
with patch.object(Agent, "stop", new_callable=AsyncMock) as mock_super_stop:
|
with patch.object(Agent, "stop", new_callable=AsyncMock) as mock_super_stop:
|
||||||
zmq_context.socket.return_value.bind_to_random_port.side_effect = zmq.ZMQBindError
|
zmq_context.return_value.socket.return_value.bind_to_random_port.side_effect = (
|
||||||
|
zmq.ZMQBindError
|
||||||
|
)
|
||||||
vad_agent = VADAgent("tcp://localhost:12345", False)
|
vad_agent = VADAgent("tcp://localhost:12345", False)
|
||||||
|
|
||||||
await vad_agent.setup()
|
await vad_agent.setup()
|
||||||
@@ -98,11 +107,14 @@ async def test_stop(zmq_context, transcription_agent):
|
|||||||
Test that when the VAD agent is stopped, the sockets are closed correctly.
|
Test that when the VAD agent is stopped, the sockets are closed correctly.
|
||||||
"""
|
"""
|
||||||
vad_agent = VADAgent("tcp://localhost:12345", False)
|
vad_agent = VADAgent("tcp://localhost:12345", False)
|
||||||
zmq_context.socket.return_value.bind_to_random_port.return_value = random.randint(1000, 10000)
|
zmq_context.return_value.socket.return_value.bind_to_random_port.return_value = random.randint(
|
||||||
|
1000,
|
||||||
|
10000,
|
||||||
|
)
|
||||||
|
|
||||||
await vad_agent.setup()
|
await vad_agent.setup()
|
||||||
await vad_agent.stop()
|
await vad_agent.stop()
|
||||||
|
|
||||||
assert zmq_context.socket.return_value.close.call_count == 2
|
assert zmq_context.return_value.socket.return_value.close.call_count == 2
|
||||||
assert vad_agent.audio_in_socket is None
|
assert vad_agent.audio_in_socket is None
|
||||||
assert vad_agent.audio_out_socket is None
|
assert vad_agent.audio_out_socket is None
|
||||||
|
|||||||
@@ -48,6 +48,7 @@ async def test_real_audio(mocker):
|
|||||||
audio_out_socket = AsyncMock()
|
audio_out_socket = AsyncMock()
|
||||||
|
|
||||||
vad_streamer = Streaming(audio_in_socket, audio_out_socket)
|
vad_streamer = Streaming(audio_in_socket, audio_out_socket)
|
||||||
|
vad_streamer._ready = True
|
||||||
for _ in audio_chunks:
|
for _ in audio_chunks:
|
||||||
await vad_streamer.run()
|
await vad_streamer.run()
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
from unittest.mock import MagicMock
|
from unittest.mock import AsyncMock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
@@ -16,7 +16,6 @@ def app():
|
|||||||
"""
|
"""
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
app.include_router(robot.router)
|
app.include_router(robot.router)
|
||||||
app.state.internal_comm_socket = MagicMock() # mock ZMQ socket
|
|
||||||
return app
|
return app
|
||||||
|
|
||||||
|
|
||||||
@@ -26,32 +25,30 @@ def client(app):
|
|||||||
return TestClient(app)
|
return TestClient(app)
|
||||||
|
|
||||||
|
|
||||||
def test_receive_command_endpoint(client, app):
|
def test_receive_command_success(client):
|
||||||
"""
|
"""
|
||||||
Test that a POST to /command sends the right multipart message
|
Test for successful reception of a command. Ensures the status code is 202 and the response body
|
||||||
and returns a 202 with the expected JSON body.
|
is correct. It also verifies that the ZeroMQ socket's send_multipart method is called with the
|
||||||
|
expected data.
|
||||||
"""
|
"""
|
||||||
mock_socket = app.state.internal_comm_socket
|
# Arrange
|
||||||
|
mock_pub_socket = AsyncMock()
|
||||||
|
client.app.state.endpoints_pub_socket = mock_pub_socket
|
||||||
|
|
||||||
# Prepare test payload that matches SpeechCommand
|
command_data = {"endpoint": "actuate/speech", "data": "This is a test"}
|
||||||
payload = {"endpoint": "actuate/speech", "data": "yooo"}
|
speech_command = SpeechCommand(**command_data)
|
||||||
|
|
||||||
# Send POST request
|
# Act
|
||||||
response = client.post("/command", json=payload)
|
response = client.post("/command", json=command_data)
|
||||||
|
|
||||||
# Check response
|
# Assert
|
||||||
assert response.status_code == 202
|
assert response.status_code == 202
|
||||||
assert response.json() == {"status": "Command received"}
|
assert response.json() == {"status": "Command received"}
|
||||||
|
|
||||||
# Verify that the socket was called with the correct data
|
# Verify that the ZMQ socket was used correctly
|
||||||
assert mock_socket.send_multipart.called, "Socket should be used to send data"
|
mock_pub_socket.send_multipart.assert_awaited_once_with(
|
||||||
|
[b"command", speech_command.model_dump_json().encode()]
|
||||||
args, kwargs = mock_socket.send_multipart.call_args
|
)
|
||||||
sent_data = args[0]
|
|
||||||
|
|
||||||
assert sent_data[0] == b"command"
|
|
||||||
# Check JSON encoding roughly matches
|
|
||||||
assert isinstance(SpeechCommand.model_validate_json(sent_data[1].decode()), SpeechCommand)
|
|
||||||
|
|
||||||
|
|
||||||
def test_receive_command_invalid_payload(client):
|
def test_receive_command_invalid_payload(client):
|
||||||
|
|||||||
@@ -16,12 +16,11 @@ def test_valid_speech_command_1():
|
|||||||
command = valid_command_1()
|
command = valid_command_1()
|
||||||
RIMessage.model_validate(command)
|
RIMessage.model_validate(command)
|
||||||
SpeechCommand.model_validate(command)
|
SpeechCommand.model_validate(command)
|
||||||
assert True
|
|
||||||
|
|
||||||
|
|
||||||
def test_invalid_speech_command_1():
|
def test_invalid_speech_command_1():
|
||||||
command = invalid_command_1()
|
command = invalid_command_1()
|
||||||
RIMessage.model_validate(command)
|
RIMessage.model_validate(command)
|
||||||
|
|
||||||
with pytest.raises(ValidationError):
|
with pytest.raises(ValidationError):
|
||||||
SpeechCommand.model_validate(command)
|
SpeechCommand.model_validate(command)
|
||||||
assert True
|
|
||||||
|
|||||||
@@ -182,8 +182,6 @@ async def test_belief_text_values_not_lists(continuous_collector, mocker):
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_belief_text_happy_path_logs_items_and_sends(continuous_collector, mocker):
|
async def test_belief_text_happy_path_logs_items_and_sends(continuous_collector, mocker):
|
||||||
payload = {"type": "belief_extraction_text", "beliefs": {"user_said": ["hello test", "No"]}}
|
payload = {"type": "belief_extraction_text", "beliefs": {"user_said": ["hello test", "No"]}}
|
||||||
# Your code calls self.send(..); patch it
|
|
||||||
# (or switch implementation to self.agent.send and patch that)
|
|
||||||
continuous_collector.send = AsyncMock()
|
continuous_collector.send = AsyncMock()
|
||||||
logger_mock = mocker.patch(
|
logger_mock = mocker.patch(
|
||||||
"control_backend.agents.belief_collector.behaviours.continuous_collect.logger"
|
"control_backend.agents.belief_collector.behaviours.continuous_collect.logger"
|
||||||
|
|||||||
@@ -21,7 +21,9 @@ def streaming(audio_in_socket, audio_out_socket):
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
torch.hub.load.return_value = (..., ...) # Mock
|
torch.hub.load.return_value = (..., ...) # Mock
|
||||||
return Streaming(audio_in_socket, audio_out_socket)
|
streaming = Streaming(audio_in_socket, audio_out_socket)
|
||||||
|
streaming._ready = True
|
||||||
|
return streaming
|
||||||
|
|
||||||
|
|
||||||
async def simulate_streaming_with_probabilities(streaming, probabilities: list[float]):
|
async def simulate_streaming_with_probabilities(streaming, probabilities: list[float]):
|
||||||
|
|||||||
@@ -5,12 +5,13 @@ from control_backend.agents.transcription.speech_recognizer import OpenAIWhisper
|
|||||||
|
|
||||||
|
|
||||||
def test_estimate_max_tokens():
|
def test_estimate_max_tokens():
|
||||||
"""Inputting one minute of audio, assuming 300 words per minute, expecting 400 tokens."""
|
"""Inputting one minute of audio, assuming 450 words per minute and adding a 10 token padding,
|
||||||
|
expecting 610 tokens."""
|
||||||
audio = np.empty(shape=(60 * 16_000), dtype=np.float32)
|
audio = np.empty(shape=(60 * 16_000), dtype=np.float32)
|
||||||
|
|
||||||
actual = SpeechRecognizer._estimate_max_tokens(audio)
|
actual = SpeechRecognizer._estimate_max_tokens(audio)
|
||||||
|
|
||||||
assert actual == 400
|
assert actual == 610
|
||||||
assert isinstance(actual, int)
|
assert isinstance(actual, int)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user