fix: fixed new tests and merged dev into branch
ref: N25B-256
This commit is contained in:
77
.githooks/check-branch-name.sh
Executable file
77
.githooks/check-branch-name.sh
Executable file
@@ -0,0 +1,77 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
# This script checks if the current branch name follows the specified format.
|
||||
# It's designed to be used as a 'pre-commit' git hook.
|
||||
|
||||
# Format: <type>/<short-description>
|
||||
# Example: feat/add-user-login
|
||||
|
||||
# --- Configuration ---
|
||||
# An array of allowed commit types
|
||||
ALLOWED_TYPES=(feat fix refactor perf style test docs build chore revert)
|
||||
# An array of branches to ignore
|
||||
IGNORED_BRANCHES=(main dev demo)
|
||||
|
||||
# --- Colors for Output ---
|
||||
RED='\033[0;31m'
|
||||
GREEN='\033[0;32m'
|
||||
YELLOW='\033[1;33m'
|
||||
NC='\033[0m' # No Color
|
||||
|
||||
# --- Helper Functions ---
|
||||
error_exit() {
|
||||
echo -e "${RED}ERROR: $1${NC}" >&2
|
||||
echo -e "${YELLOW}Branch name format is incorrect. Aborting commit.${NC}" >&2
|
||||
exit 1
|
||||
}
|
||||
|
||||
# --- Main Logic ---
|
||||
|
||||
# 1. Get the current branch name
|
||||
BRANCH_NAME=$(git symbolic-ref --short HEAD)
|
||||
|
||||
# 2. Check if the current branch is in the ignored list
|
||||
for ignored_branch in "${IGNORED_BRANCHES[@]}"; do
|
||||
if [ "$BRANCH_NAME" == "$ignored_branch" ]; then
|
||||
echo -e "${GREEN}Branch check skipped for default branch: $BRANCH_NAME${NC}"
|
||||
exit 0
|
||||
fi
|
||||
done
|
||||
|
||||
# 3. Validate the overall structure: <type>/<description>
|
||||
if ! [[ "$BRANCH_NAME" =~ ^[a-z]+/.+$ ]]; then
|
||||
error_exit "Branch name must be in the format: <type>/<short-description>\nExample: feat/add-user-login"
|
||||
fi
|
||||
|
||||
# 4. Extract the type and description
|
||||
TYPE=$(echo "$BRANCH_NAME" | cut -d'/' -f1)
|
||||
DESCRIPTION=$(echo "$BRANCH_NAME" | cut -d'/' -f2-)
|
||||
|
||||
# 5. Validate the <type>
|
||||
type_valid=false
|
||||
for allowed_type in "${ALLOWED_TYPES[@]}"; do
|
||||
if [ "$TYPE" == "$allowed_type" ]; then
|
||||
type_valid=true
|
||||
break
|
||||
fi
|
||||
done
|
||||
|
||||
if [ "$type_valid" == false ]; then
|
||||
error_exit "Invalid type '$TYPE'.\nAllowed types are: ${ALLOWED_TYPES[*]}"
|
||||
fi
|
||||
|
||||
# 6. Validate the <short-description>
|
||||
# Regex breakdown:
|
||||
# ^[a-z0-9]+ - Starts with one or more lowercase letters/numbers (the first word).
|
||||
# (-[a-z0-9]+){0,5} - Followed by a group of (dash + word) 0 to 5 times.
|
||||
# $ - End of the string.
|
||||
# This entire pattern enforces 1 to 6 words total, separated by dashes.
|
||||
DESCRIPTION_REGEX="^[a-z0-9]+(-[a-z0-9]+){0,5}$"
|
||||
|
||||
if ! [[ "$DESCRIPTION" =~ $DESCRIPTION_REGEX ]]; then
|
||||
error_exit "Invalid short description '$DESCRIPTION'.\nIt must be a maximum of 6 words, all lowercase, separated by dashes.\nExample: add-new-user-authentication-feature"
|
||||
fi
|
||||
|
||||
# If all checks pass, exit successfully
|
||||
echo -e "${GREEN}Branch name '$BRANCH_NAME' is valid.${NC}"
|
||||
exit 0
|
||||
135
.githooks/check-commit-msg.sh
Executable file
135
.githooks/check-commit-msg.sh
Executable file
@@ -0,0 +1,135 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
# This script checks if a commit message follows the specified format.
|
||||
# It's designed to be used as a 'commit-msg' git hook.
|
||||
|
||||
# Format:
|
||||
# <type>: <short description>
|
||||
#
|
||||
# [optional]<body>
|
||||
#
|
||||
# [ref/close]: <issue identifier>
|
||||
|
||||
# --- Configuration ---
|
||||
# An array of allowed commit types
|
||||
ALLOWED_TYPES=(feat fix refactor perf style test docs build chore revert)
|
||||
|
||||
# --- Colors for Output ---
|
||||
RED='\033[0;31m'
|
||||
GREEN='\033[0;32m'
|
||||
YELLOW='\033[1;33m'
|
||||
NC='\033[0m' # No Color
|
||||
|
||||
# The first argument to the hook is the path to the file containing the commit message
|
||||
COMMIT_MSG_FILE=$1
|
||||
|
||||
# --- Automated Commit Detection ---
|
||||
|
||||
# Read the first line (header) for initial checks
|
||||
HEADER=$(head -n 1 "$COMMIT_MSG_FILE")
|
||||
|
||||
# Check for Merge commits (covers 'git merge' and PR merges from GitHub/GitLab)
|
||||
# Examples: "Merge branch 'main' into ...", "Merge pull request #123 from ..."
|
||||
MERGE_PATTERN="^Merge (branch|pull request|tag) .*"
|
||||
if [[ "$HEADER" =~ $MERGE_PATTERN ]]; then
|
||||
echo -e "${GREEN}Merge commit detected by message content. Skipping validation.${NC}"
|
||||
exit 0
|
||||
fi
|
||||
|
||||
# Check for Revert commits
|
||||
# Example: "Revert "feat: add new feature""
|
||||
REVERT_PATTERN="^Revert \".*\""
|
||||
if [[ "$HEADER" =~ $REVERT_PATTERN ]]; then
|
||||
echo -e "${GREEN}Revert commit detected by message content. Skipping validation.${NC}"
|
||||
exit 0
|
||||
fi
|
||||
|
||||
# Check for Cherry-pick commits (this pattern appears at the end of the message)
|
||||
# Example: "(cherry picked from commit deadbeef...)"
|
||||
# We use grep -q to search the whole file quietly.
|
||||
CHERRY_PICK_PATTERN="\(cherry picked from commit [a-f0-9]{7,40}\)"
|
||||
if grep -qE "$CHERRY_PICK_PATTERN" "$COMMIT_MSG_FILE"; then
|
||||
echo -e "${GREEN}Cherry-pick detected by message content. Skipping validation.${NC}"
|
||||
exit 0
|
||||
fi
|
||||
|
||||
# Check for Squash
|
||||
# Example: "Squash commits ..."
|
||||
SQUASH_PATTERN="^Squash .+"
|
||||
if [[ "$HEADER" =~ $SQUASH_PATTERN ]]; then
|
||||
echo -e "${GREEN}Squash commit detected by message content. Skipping validation.${NC}"
|
||||
exit 0
|
||||
fi
|
||||
|
||||
# --- Validation Functions ---
|
||||
|
||||
# Function to print an error message and exit
|
||||
# Usage: error_exit "Your error message here"
|
||||
error_exit() {
|
||||
# >&2 redirects echo to stderr
|
||||
echo -e "${RED}ERROR: $1${NC}" >&2
|
||||
echo -e "${YELLOW}Commit message format is incorrect. Aborting commit.${NC}" >&2
|
||||
exit 1
|
||||
}
|
||||
|
||||
# --- Main Logic ---
|
||||
|
||||
# 1. Read the header (first line) of the commit message
|
||||
HEADER=$(head -n 1 "$COMMIT_MSG_FILE")
|
||||
|
||||
# 2. Validate the header format: <type>: <description>
|
||||
# Regex breakdown:
|
||||
# ^(type1|type2|...) - Starts with one of the allowed types
|
||||
# : - Followed by a literal colon
|
||||
# \s - Followed by a single space
|
||||
# .+ - Followed by one or more characters for the description
|
||||
# $ - End of the line
|
||||
TYPES_REGEX=$(
|
||||
IFS="|"
|
||||
echo "${ALLOWED_TYPES[*]}"
|
||||
)
|
||||
HEADER_REGEX="^($TYPES_REGEX): .+$"
|
||||
|
||||
if ! [[ "$HEADER" =~ $HEADER_REGEX ]]; then
|
||||
error_exit "Invalid header format.\n\nHeader must be in the format: <type>: <short description>\nAllowed types: ${ALLOWED_TYPES[*]}\nExample: feat: add new user authentication feature"
|
||||
fi
|
||||
|
||||
# Only validate footer if commit type is not chore
|
||||
TYPE=$(echo "$HEADER" | cut -d':' -f1)
|
||||
if [ "$TYPE" != "chore" ]; then
|
||||
# 3. Validate the footer (last line) of the commit message
|
||||
FOOTER=$(tail -n 1 "$COMMIT_MSG_FILE")
|
||||
|
||||
# Regex breakdown:
|
||||
# ^(ref|close) - Starts with 'ref' or 'close'
|
||||
# : - Followed by a literal colon
|
||||
# \s - Followed by a single space
|
||||
# N25B- - Followed by the literal string 'N25B-'
|
||||
# [0-9]+ - Followed by one or more digits
|
||||
# $ - End of the line
|
||||
FOOTER_REGEX="^(ref|close): N25B-[0-9]+$"
|
||||
|
||||
if ! [[ "$FOOTER" =~ $FOOTER_REGEX ]]; then
|
||||
error_exit "Invalid footer format.\n\nFooter must be in the format: [ref/close]: <issue identifier>\nExample: ref: N25B-123"
|
||||
fi
|
||||
fi
|
||||
|
||||
# 4. If the message has more than 2 lines, validate the separator
|
||||
# A blank line must exist between the header and the body.
|
||||
LINE_COUNT=$(wc -l <"$COMMIT_MSG_FILE" | xargs) # xargs trims whitespace
|
||||
|
||||
# We only care if there is a body. Header + Footer = 2 lines.
|
||||
# Header + Blank Line + Body... + Footer > 2 lines.
|
||||
if [ "$LINE_COUNT" -gt 2 ]; then
|
||||
# Get the second line
|
||||
SECOND_LINE=$(sed -n '2p' "$COMMIT_MSG_FILE")
|
||||
|
||||
# Check if the second line is NOT empty. If it's not, it's an error.
|
||||
if [ -n "$SECOND_LINE" ]; then
|
||||
error_exit "Missing blank line between header and body.\n\nThe second line of your commit message must be empty if a body is present."
|
||||
fi
|
||||
fi
|
||||
|
||||
# If all checks pass, exit with success
|
||||
echo -e "${GREEN}Commit message is valid.${NC}"
|
||||
exit 0
|
||||
@@ -1,16 +0,0 @@
|
||||
#!/bin/sh
|
||||
|
||||
commit_msg_file=$1
|
||||
commit_msg=$(cat "$commit_msg_file")
|
||||
|
||||
if echo "$commit_msg" | grep -Eq "^(feat|fix|refactor|perf|style|test|docs|build|chore|revert): .+"; then
|
||||
if echo "$commit_msg" | grep -Eq "^(ref|close):\sN25B-.+"; then
|
||||
exit 0
|
||||
else
|
||||
echo "❌ Commit message invalid! Must end with [ref/close]: N25B-000"
|
||||
exit 1
|
||||
fi
|
||||
else
|
||||
echo "❌ Commit message invalid! Must start with <type>: <description>"
|
||||
exit 1
|
||||
fi
|
||||
@@ -1,17 +0,0 @@
|
||||
#!/bin/sh
|
||||
|
||||
# Get current branch
|
||||
branch=$(git rev-parse --abbrev-ref HEAD)
|
||||
|
||||
if echo "$branch" | grep -Eq "(dev|main)"; then
|
||||
echo 0
|
||||
fi
|
||||
|
||||
# allowed pattern <type/>
|
||||
if echo "$branch" | grep -Eq "^(feat|fix|refactor|perf|style|test|docs|build|chore|revert)\/\w+(-\w+){0,5}$"; then
|
||||
exit 0
|
||||
else
|
||||
echo "❌ Invalid branch name: $branch"
|
||||
echo "Branch must be named <type>/<description-of-branch> (must have one to six words separated by a dash)"
|
||||
exit 1
|
||||
fi
|
||||
@@ -1,9 +0,0 @@
|
||||
#!/bin/sh
|
||||
|
||||
echo "#<type>: <description>
|
||||
|
||||
#[optional body]
|
||||
|
||||
#[optional footer(s)]
|
||||
|
||||
#[ref/close]: <issue identifier>" > $1
|
||||
@@ -1,10 +1,24 @@
|
||||
repos:
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
# Ruff version.
|
||||
rev: v0.14.2
|
||||
hooks:
|
||||
# Run the linter.
|
||||
- id: ruff-check
|
||||
args: [ --fix ]
|
||||
# Run the formatter.
|
||||
- id: ruff-format
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
# Ruff version.
|
||||
rev: v0.14.2
|
||||
hooks:
|
||||
# Run the linter.
|
||||
- id: ruff-check
|
||||
# Run the formatter.
|
||||
- id: ruff-format
|
||||
# Configure local hooks
|
||||
- repo: local
|
||||
hooks:
|
||||
- id: check-commit-msg
|
||||
name: Check commit message format
|
||||
entry: .githooks/check-commit-msg.sh
|
||||
language: script
|
||||
stages: [commit-msg]
|
||||
- id: check-branch-name
|
||||
name: Check branch name format
|
||||
entry: .githooks/check-branch-name.sh
|
||||
language: script
|
||||
stages: [pre-commit]
|
||||
always_run: true
|
||||
pass_filenames: false
|
||||
|
||||
22
README.md
22
README.md
@@ -47,21 +47,19 @@ Or for integration tests:
|
||||
uv run --group integration-test pytest test/integration
|
||||
```
|
||||
|
||||
## GitHooks
|
||||
## Git Hooks
|
||||
|
||||
To activate automatic commits/branch name checks run:
|
||||
To activate automatic linting, formatting, branch name checks and commit message checks, run:
|
||||
|
||||
```shell
|
||||
git config --local core.hooksPath .githooks
|
||||
```bash
|
||||
uv run pre-commit install
|
||||
uv run pre-commit install --hook-type commit-msg
|
||||
```
|
||||
|
||||
If your commit fails its either:
|
||||
branch name != <type>/description-of-branch ,
|
||||
commit name != <type>: description of the commit.
|
||||
<ref>: N25B-Num's
|
||||
You might get an error along the lines of `Can't install pre-commit with core.hooksPath` set. To fix this, simply unset the hooksPath by running:
|
||||
|
||||
To add automatic linting and formatting, run:
|
||||
```bash
|
||||
git config --local --unset core.hooksPath
|
||||
```
|
||||
|
||||
```shell
|
||||
uv run pre-commit install
|
||||
```
|
||||
Then run the pre-commit install commands again.
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
import logging
|
||||
|
||||
from spade.behaviour import CyclicBehaviour
|
||||
from spade.message import Message
|
||||
|
||||
from control_backend.core.config import settings
|
||||
from control_backend.schemas.ri_message import SpeechCommand
|
||||
|
||||
|
||||
class ReceiveLLMResponseBehaviour(CyclicBehaviour):
|
||||
@@ -10,7 +12,7 @@ class ReceiveLLMResponseBehaviour(CyclicBehaviour):
|
||||
Adds behavior to receive responses from the LLM Agent.
|
||||
"""
|
||||
|
||||
logger = logging.getLogger("BDI/LLM Reciever")
|
||||
logger = logging.getLogger("BDI/LLM Receiver")
|
||||
|
||||
async def run(self):
|
||||
msg = await self.receive(timeout=2)
|
||||
@@ -22,7 +24,20 @@ class ReceiveLLMResponseBehaviour(CyclicBehaviour):
|
||||
case settings.agent_settings.llm_agent_name:
|
||||
content = msg.body
|
||||
self.logger.info("Received LLM response: %s", content)
|
||||
# Here the BDI can pass the message back as a response
|
||||
|
||||
speech_command = SpeechCommand(data=content)
|
||||
|
||||
message = Message(
|
||||
to=settings.agent_settings.ri_command_agent_name
|
||||
+ "@"
|
||||
+ settings.agent_settings.host,
|
||||
sender=self.agent.jid,
|
||||
body=speech_command.model_dump_json(),
|
||||
)
|
||||
|
||||
self.logger.debug("Sending message: %s", message)
|
||||
|
||||
await self.send(message)
|
||||
case _:
|
||||
self.logger.debug("Not from the llm, discarding message")
|
||||
pass
|
||||
|
||||
@@ -13,23 +13,23 @@ class BeliefFromText(CyclicBehaviour):
|
||||
|
||||
# TODO: LLM prompt nog hardcoded
|
||||
llm_instruction_prompt = """
|
||||
You are an information extraction assistent for a BDI agent.
|
||||
Your task is to extract values from a user's text to bind a list of ungrounded beliefs. Rules:
|
||||
You will receive a JSON object with "beliefs"
|
||||
(a list of ungrounded AgentSpeak beliefs) and "text" (user's transcript).
|
||||
You are an information extraction assistent for a BDI agent. Your task is to extract values \
|
||||
from a user's text to bind a list of ungrounded beliefs. Rules:
|
||||
You will receive a JSON object with "beliefs" (a list of ungrounded AgentSpeak beliefs) \
|
||||
and "text" (user's transcript).
|
||||
Analyze the text to find values that sematically match the variables (X,Y,Z) in the beliefs.
|
||||
A single piece of text might contain multiple instances that match a belief.
|
||||
Respond ONLY with a single JSON object.
|
||||
The JSON object's keys should be the belief functors (e.g., "weather").
|
||||
The value for each key must be a list of lists.
|
||||
Each inner list must contain the extracted arguments
|
||||
(as strings) for one instance of that belief.
|
||||
CRITICAL: If no information in the text matches a belief,
|
||||
DO NOT include that key in your response.
|
||||
Each inner list must contain the extracted arguments (as strings) for one instance \
|
||||
of that belief.
|
||||
CRITICAL: If no information in the text matches a belief, DO NOT include that key \
|
||||
in your response.
|
||||
"""
|
||||
|
||||
# on_start agent receives message containing the beliefs to look out
|
||||
# for and sets up the LLM with instruction prompt
|
||||
# on_start agent receives message containing the beliefs to look out for and
|
||||
# sets up the LLM with instruction prompt
|
||||
# async def on_start(self):
|
||||
# msg = await self.receive(timeout=0.1)
|
||||
# self.beliefs = dict uit message
|
||||
|
||||
@@ -70,8 +70,7 @@ class ContinuousBeliefCollector(CyclicBehaviour):
|
||||
Expected payload:
|
||||
{
|
||||
"type": "belief_extraction_text",
|
||||
"beliefs": {"user_said": ["hello"","Can you help me?",
|
||||
"stop talking to me","No","Pepper do a dance"]}
|
||||
"beliefs": {"user_said": ["Can you help me?"]}
|
||||
|
||||
}
|
||||
|
||||
|
||||
@@ -3,8 +3,10 @@ LLM Agent module for routing text queries from the BDI Core Agent to a local LLM
|
||||
service and returning its responses back to the BDI Core Agent.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Any
|
||||
import re
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
import httpx
|
||||
from spade.agent import Agent
|
||||
@@ -54,11 +56,16 @@ class LLMAgent(Agent):
|
||||
|
||||
async def _process_bdi_message(self, message: Message):
|
||||
"""
|
||||
Forwards user text to the LLM and replies with the generated text.
|
||||
Forwards user text from the BDI to the LLM and replies with the generated text in chunks
|
||||
separated by punctuation.
|
||||
"""
|
||||
user_text = message.body
|
||||
llm_response = await self._query_llm(user_text)
|
||||
await self._reply(llm_response)
|
||||
# Consume the streaming generator and send a reply for every chunk
|
||||
async for chunk in self._query_llm(user_text):
|
||||
await self._reply(chunk)
|
||||
self.agent.logger.debug(
|
||||
"Finished processing BDI message. Response sent in chunks to BDI Core Agent."
|
||||
)
|
||||
|
||||
async def _reply(self, msg: str):
|
||||
"""
|
||||
@@ -69,48 +76,89 @@ class LLMAgent(Agent):
|
||||
body=msg,
|
||||
)
|
||||
await self.send(reply)
|
||||
self.agent.logger.info("Reply sent to BDI Core Agent")
|
||||
|
||||
async def _query_llm(self, prompt: str) -> str:
|
||||
async def _query_llm(self, prompt: str) -> AsyncGenerator[str]:
|
||||
"""
|
||||
Sends a chat completion request to the local LLM service.
|
||||
Sends a chat completion request to the local LLM service and streams the response by
|
||||
yielding fragments separated by punctuation like.
|
||||
|
||||
:param prompt: Input text prompt to pass to the LLM.
|
||||
:return: LLM-generated content or fallback message.
|
||||
:yield: Fragments of the LLM-generated content.
|
||||
"""
|
||||
async with httpx.AsyncClient(timeout=120.0) as client:
|
||||
# Example dynamic content for future (optional)
|
||||
instructions = LLMInstructions(
|
||||
"- Be friendly and respectful.\n"
|
||||
"- Make the conversation feel natural and engaging.\n"
|
||||
"- Speak like a pirate.\n"
|
||||
"- When the user asks what you can do, tell them.",
|
||||
"- Try to learn the user's name during conversation.\n"
|
||||
"- Suggest playing a game of asking yes or no questions where you think of a word "
|
||||
"and the user must guess it.",
|
||||
)
|
||||
messages = [
|
||||
{
|
||||
"role": "developer",
|
||||
"content": instructions.build_developer_instruction(),
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": prompt,
|
||||
},
|
||||
]
|
||||
|
||||
instructions = LLMInstructions()
|
||||
developer_instruction = instructions.build_developer_instruction()
|
||||
try:
|
||||
current_chunk = ""
|
||||
async for token in self._stream_query_llm(messages):
|
||||
current_chunk += token
|
||||
|
||||
response = await client.post(
|
||||
# Stream the message in chunks separated by punctuation.
|
||||
# We include the delimiter in the emitted chunk for natural flow.
|
||||
pattern = re.compile(r".*?(?:,|;|:|—|–|\.{3}|…|\.|\?|!)\s*", re.DOTALL)
|
||||
for m in pattern.finditer(current_chunk):
|
||||
chunk = m.group(0)
|
||||
if chunk:
|
||||
yield current_chunk
|
||||
current_chunk = ""
|
||||
|
||||
# Yield any remaining tail
|
||||
if current_chunk:
|
||||
yield current_chunk
|
||||
except httpx.HTTPError as err:
|
||||
self.agent.logger.error("HTTP error.", exc_info=err)
|
||||
yield "LLM service unavailable."
|
||||
except Exception as err:
|
||||
self.agent.logger.error("Unexpected error.", exc_info=err)
|
||||
yield "Error processing the request."
|
||||
|
||||
async def _stream_query_llm(self, messages) -> AsyncGenerator[str]:
|
||||
"""Raises httpx.HTTPError when the API gives an error."""
|
||||
async with httpx.AsyncClient(timeout=None) as client:
|
||||
async with client.stream(
|
||||
"POST",
|
||||
settings.llm_settings.local_llm_url,
|
||||
headers={"Content-Type": "application/json"},
|
||||
json={
|
||||
"model": settings.llm_settings.local_llm_model,
|
||||
"messages": [
|
||||
{"role": "developer", "content": developer_instruction},
|
||||
{"role": "user", "content": prompt},
|
||||
],
|
||||
"messages": messages,
|
||||
"temperature": 0.3,
|
||||
"stream": True,
|
||||
},
|
||||
)
|
||||
|
||||
try:
|
||||
) as response:
|
||||
response.raise_for_status()
|
||||
data: dict[str, Any] = response.json()
|
||||
return (
|
||||
data.get("choices", [{}])[0]
|
||||
.get("message", {})
|
||||
.get("content", "No response")
|
||||
)
|
||||
except httpx.HTTPError as err:
|
||||
self.agent.logger.error("HTTP error: %s", err)
|
||||
return "LLM service unavailable."
|
||||
except Exception as err:
|
||||
self.agent.logger.error("Unexpected error: %s", err)
|
||||
return "Error processing the request."
|
||||
|
||||
async for line in response.aiter_lines():
|
||||
if not line or not line.startswith("data: "):
|
||||
continue
|
||||
|
||||
data = line[len("data: ") :]
|
||||
if data.strip() == "[DONE]":
|
||||
break
|
||||
|
||||
try:
|
||||
event = json.loads(data)
|
||||
delta = event.get("choices", [{}])[0].get("delta", {}).get("content")
|
||||
if delta:
|
||||
yield delta
|
||||
except json.JSONDecodeError:
|
||||
self.agent.logger.error("Failed to parse LLM response: %s", data)
|
||||
|
||||
async def setup(self):
|
||||
"""
|
||||
|
||||
@@ -28,7 +28,9 @@ class LLMInstructions:
|
||||
"""
|
||||
sections = [
|
||||
"You are a Pepper robot engaging in natural human conversation.",
|
||||
"Keep responses between 1–5 sentences, unless instructed otherwise.\n",
|
||||
"Keep responses between 1–3 sentences, unless told otherwise.\n",
|
||||
"You're given goals to reach. Reach them in order, but make the conversation feel "
|
||||
"natural. Some turns you should not try to achieve your goals.\n",
|
||||
]
|
||||
|
||||
if self.norms:
|
||||
|
||||
@@ -11,8 +11,9 @@ class BeliefTextAgent(Agent):
|
||||
class SendOnceBehaviourBlfText(OneShotBehaviour):
|
||||
async def run(self):
|
||||
to_jid = (
|
||||
f"{settings.agent_settings.belief_collector_agent_name}"
|
||||
f"@{settings.agent_settings.host}"
|
||||
settings.agent_settings.belief_collector_agent_name
|
||||
+ "@"
|
||||
+ settings.agent_settings.host
|
||||
)
|
||||
|
||||
# Send multiple beliefs in one JSON payload
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import json
|
||||
import logging
|
||||
|
||||
import spade.agent
|
||||
import zmq
|
||||
from spade.agent import Agent
|
||||
from spade.behaviour import CyclicBehaviour
|
||||
@@ -32,6 +33,8 @@ class RICommandAgent(Agent):
|
||||
self.bind = bind
|
||||
|
||||
class SendCommandsBehaviour(CyclicBehaviour):
|
||||
"""Behaviour for sending commands received from the UI."""
|
||||
|
||||
async def run(self):
|
||||
"""
|
||||
Run the command publishing loop indefinetely.
|
||||
@@ -50,6 +53,18 @@ class RICommandAgent(Agent):
|
||||
except Exception as e:
|
||||
logger.error("Error processing message: %s", e)
|
||||
|
||||
class SendPythonCommandsBehaviour(CyclicBehaviour):
|
||||
"""Behaviour for sending commands received from other Python agents."""
|
||||
|
||||
async def run(self):
|
||||
message: spade.agent.Message = await self.receive(timeout=0.1)
|
||||
if message and message.to == self.agent.jid:
|
||||
try:
|
||||
speech_command = SpeechCommand.model_validate_json(message.body)
|
||||
await self.agent.pubsocket.send_json(speech_command.model_dump())
|
||||
except Exception as e:
|
||||
logger.error("Error processing message: %s", e)
|
||||
|
||||
async def setup(self):
|
||||
"""
|
||||
Setup the command agent
|
||||
@@ -73,5 +88,6 @@ class RICommandAgent(Agent):
|
||||
# Add behaviour to our agent
|
||||
commands_behaviour = self.SendCommandsBehaviour()
|
||||
self.add_behaviour(commands_behaviour)
|
||||
self.add_behaviour(self.SendPythonCommandsBehaviour())
|
||||
|
||||
logger.info("Finished setting up %s", self.jid)
|
||||
|
||||
@@ -63,7 +63,25 @@ class RICommunicationAgent(Agent):
|
||||
# We didnt get a reply :(
|
||||
except TimeoutError:
|
||||
logger.info("No ping retrieved in 3 seconds, killing myself.")
|
||||
self.kill()
|
||||
|
||||
# Tell UI we're disconnected.
|
||||
topic = b"ping"
|
||||
data = json.dumps(False).encode()
|
||||
if self.agent.pub_socket is None:
|
||||
logger.error("communication agent pub socket not correctly initialized.")
|
||||
else:
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
self.agent.pub_socket.send_multipart([topic, data]), 5
|
||||
)
|
||||
except TimeoutError:
|
||||
logger.error(
|
||||
"Initial connection ping for router timed"
|
||||
" out in ri_communication_agent."
|
||||
)
|
||||
|
||||
# Try to reboot.
|
||||
self.agent.setup()
|
||||
|
||||
logger.debug('Received message "%s"', message)
|
||||
if "endpoint" not in message:
|
||||
|
||||
@@ -36,16 +36,16 @@ class SpeechRecognizer(abc.ABC):
|
||||
def _estimate_max_tokens(audio: np.ndarray) -> int:
|
||||
"""
|
||||
Estimate the maximum length of a given audio sample in tokens. Assumes a maximum speaking
|
||||
rate of 300 words per minute (2x average), and assumes that 3 words is 4 tokens.
|
||||
rate of 450 words per minute (3x average), and assumes that 3 words is 4 tokens.
|
||||
|
||||
:param audio: The audio sample (16 kHz) to use for length estimation.
|
||||
:return: The estimated length of the transcribed audio in tokens.
|
||||
"""
|
||||
length_seconds = len(audio) / 16_000
|
||||
length_minutes = length_seconds / 60
|
||||
word_count = length_minutes * 300
|
||||
word_count = length_minutes * 450
|
||||
token_count = word_count / 3 * 4
|
||||
return int(token_count)
|
||||
return int(token_count) + 10
|
||||
|
||||
def _get_decode_options(self, audio: np.ndarray) -> dict:
|
||||
"""
|
||||
@@ -85,9 +85,10 @@ class MLXWhisperSpeechRecognizer(SpeechRecognizer):
|
||||
def recognize_speech(self, audio: np.ndarray) -> str:
|
||||
self.load_model()
|
||||
return mlx_whisper.transcribe(
|
||||
audio, path_or_hf_repo=self.model_name, decode_options=self._get_decode_options(audio)
|
||||
)["text"]
|
||||
return mlx_whisper.transcribe(audio, path_or_hf_repo=self.model_name)["text"].strip()
|
||||
audio,
|
||||
path_or_hf_repo=self.model_name,
|
||||
**self._get_decode_options(audio),
|
||||
)["text"].strip()
|
||||
|
||||
|
||||
class OpenAIWhisperSpeechRecognizer(SpeechRecognizer):
|
||||
@@ -103,6 +104,4 @@ class OpenAIWhisperSpeechRecognizer(SpeechRecognizer):
|
||||
|
||||
def recognize_speech(self, audio: np.ndarray) -> str:
|
||||
self.load_model()
|
||||
return whisper.transcribe(
|
||||
self.model, audio, decode_options=self._get_decode_options(audio)
|
||||
)["text"]
|
||||
return whisper.transcribe(self.model, audio, **self._get_decode_options(audio))["text"]
|
||||
|
||||
@@ -58,6 +58,10 @@ class TranscriptionAgent(Agent):
|
||||
audio = await self.audio_in_socket.recv()
|
||||
audio = np.frombuffer(audio, dtype=np.float32)
|
||||
speech = await self._transcribe(audio)
|
||||
if not speech:
|
||||
logger.info("Nothing transcribed.")
|
||||
return
|
||||
|
||||
logger.info("Transcribed speech: %s", speech)
|
||||
|
||||
await self._share_transcription(speech)
|
||||
|
||||
@@ -54,8 +54,20 @@ class Streaming(CyclicBehaviour):
|
||||
|
||||
self.audio_buffer = np.array([], dtype=np.float32)
|
||||
self.i_since_speech = 100 # Used to allow small pauses in speech
|
||||
self._ready = False
|
||||
|
||||
async def reset(self):
|
||||
"""Clears the ZeroMQ queue and tells this behavior to start."""
|
||||
discarded = 0
|
||||
while await self.audio_in_poller.poll(1) is not None:
|
||||
discarded += 1
|
||||
logging.info(f"Discarded {discarded} audio packets before starting.")
|
||||
self._ready = True
|
||||
|
||||
async def run(self) -> None:
|
||||
if not self._ready:
|
||||
return
|
||||
|
||||
data = await self.audio_in_poller.poll()
|
||||
if data is None:
|
||||
if len(self.audio_buffer) > 0:
|
||||
@@ -107,6 +119,8 @@ class VADAgent(Agent):
|
||||
self.audio_in_socket: azmq.Socket | None = None
|
||||
self.audio_out_socket: azmq.Socket | None = None
|
||||
|
||||
self.streaming_behaviour: Streaming | None = None
|
||||
|
||||
async def stop(self):
|
||||
"""
|
||||
Stop listening to audio, stop publishing audio, close sockets.
|
||||
@@ -149,8 +163,8 @@ class VADAgent(Agent):
|
||||
return
|
||||
audio_out_address = f"tcp://localhost:{audio_out_port}"
|
||||
|
||||
streaming = Streaming(self.audio_in_socket, self.audio_out_socket)
|
||||
self.add_behaviour(streaming)
|
||||
self.streaming_behaviour = Streaming(self.audio_in_socket, self.audio_out_socket)
|
||||
self.add_behaviour(self.streaming_behaviour)
|
||||
|
||||
# Start agents dependent on the output audio fragments here
|
||||
transcriber = TranscriptionAgent(audio_out_address)
|
||||
|
||||
@@ -22,8 +22,8 @@ async def receive_command(command: SpeechCommand, request: Request):
|
||||
topic = b"command"
|
||||
|
||||
# TODO: Check with Kasper
|
||||
pub_socket: Socket = request.app.state.internal_comm_socket
|
||||
pub_socket.send_multipart([topic, command.model_dump_json().encode()])
|
||||
pub_socket: Socket = request.app.state.endpoints_pub_socket
|
||||
await pub_socket.send_multipart([topic, command.model_dump_json().encode()])
|
||||
|
||||
return {"status": "Command received"}
|
||||
|
||||
|
||||
@@ -14,8 +14,6 @@ from control_backend.agents.bdi.bdi_core import BDICoreAgent
|
||||
from control_backend.agents.bdi.text_extractor import TBeliefExtractor
|
||||
from control_backend.agents.belief_collector.belief_collector import BeliefCollectorAgent
|
||||
from control_backend.agents.llm.llm import LLMAgent
|
||||
|
||||
# Internal imports
|
||||
from control_backend.agents.ri_communication_agent import RICommunicationAgent
|
||||
from control_backend.agents.vad_agent import VADAgent
|
||||
from control_backend.api.v1.router import api_router
|
||||
@@ -99,6 +97,8 @@ async def lifespan(app: FastAPI):
|
||||
|
||||
_temp_vad_agent = VADAgent("tcp://localhost:5558", False)
|
||||
await _temp_vad_agent.start()
|
||||
logger.info("VAD agent started, now making ready...")
|
||||
await _temp_vad_agent.streaming_behaviour.reset()
|
||||
|
||||
yield
|
||||
|
||||
|
||||
@@ -7,25 +7,21 @@ import zmq
|
||||
from control_backend.agents.ri_command_agent import RICommandAgent
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_setup_bind(monkeypatch):
|
||||
"""Test setup with bind=True"""
|
||||
fake_socket = MagicMock()
|
||||
fake_context = MagicMock()
|
||||
fake_context.socket.return_value = fake_socket
|
||||
@pytest.fixture
|
||||
def zmq_context(mocker):
|
||||
mock_context = mocker.patch("control_backend.agents.vad_agent.azmq.Context.instance")
|
||||
mock_context.return_value = MagicMock()
|
||||
return mock_context
|
||||
|
||||
# Patch Context.instance() to return fake_context
|
||||
monkeypatch.setattr(
|
||||
"control_backend.agents.ri_command_agent.Context",
|
||||
MagicMock(instance=MagicMock(return_value=fake_context)),
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_setup_bind(zmq_context, mocker):
|
||||
"""Test setup with bind=True"""
|
||||
fake_socket = zmq_context.return_value.socket.return_value
|
||||
|
||||
agent = RICommandAgent("test@server", "password", address="tcp://localhost:5555", bind=True)
|
||||
|
||||
monkeypatch.setattr(
|
||||
"control_backend.agents.ri_command_agent.settings",
|
||||
MagicMock(zmq_settings=MagicMock(internal_sub_address="tcp://internal:1234")),
|
||||
)
|
||||
settings = mocker.patch("control_backend.agents.ri_command_agent.settings")
|
||||
settings.zmq_settings.internal_sub_address = "tcp://internal:1234"
|
||||
|
||||
await agent.setup()
|
||||
|
||||
@@ -36,23 +32,13 @@ async def test_setup_bind(monkeypatch):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_setup_connect(monkeypatch):
|
||||
async def test_setup_connect(zmq_context, mocker):
|
||||
"""Test setup with bind=False"""
|
||||
fake_socket = MagicMock()
|
||||
fake_context = MagicMock()
|
||||
fake_context.socket.return_value = fake_socket
|
||||
|
||||
# Patch Context.instance() to return fake_context
|
||||
monkeypatch.setattr(
|
||||
"control_backend.agents.ri_command_agent.Context",
|
||||
MagicMock(instance=MagicMock(return_value=fake_context)),
|
||||
)
|
||||
fake_socket = zmq_context.return_value.socket.return_value
|
||||
|
||||
agent = RICommandAgent("test@server", "password", address="tcp://localhost:5555", bind=False)
|
||||
monkeypatch.setattr(
|
||||
"control_backend.agents.ri_command_agent.settings",
|
||||
MagicMock(zmq_settings=MagicMock(internal_sub_address="tcp://internal:1234")),
|
||||
)
|
||||
settings = mocker.patch("control_backend.agents.ri_command_agent.settings")
|
||||
settings.zmq_settings.internal_sub_address = "tcp://internal:1234"
|
||||
|
||||
await agent.setup()
|
||||
|
||||
|
||||
@@ -84,25 +84,24 @@ def fake_json_invalid_id_negototiate():
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def zmq_context(mocker):
|
||||
mock_context = mocker.patch("control_backend.agents.vad_agent.azmq.Context.instance")
|
||||
mock_context.return_value = MagicMock()
|
||||
return mock_context
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_setup_creates_socket_and_negotiate_1(monkeypatch):
|
||||
async def test_setup_creates_socket_and_negotiate_1(zmq_context):
|
||||
"""
|
||||
Test the setup of the communication agent
|
||||
"""
|
||||
# --- Arrange ---
|
||||
fake_socket = MagicMock()
|
||||
fake_socket = zmq_context.return_value.socket.return_value
|
||||
fake_socket.send_json = AsyncMock()
|
||||
fake_socket.recv_json = fake_json_correct_negototiate_1()
|
||||
fake_socket.send_multipart = AsyncMock()
|
||||
|
||||
# Mock context.socket to return our fake socket
|
||||
fake_context = MagicMock()
|
||||
fake_context.socket.return_value = fake_socket
|
||||
monkeypatch.setattr(
|
||||
"control_backend.agents.ri_communication_agent.Context",
|
||||
MagicMock(instance=MagicMock(return_value=fake_context)),
|
||||
)
|
||||
|
||||
# Mock RICommandAgent agent startup
|
||||
with patch(
|
||||
"control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True
|
||||
@@ -135,24 +134,16 @@ async def test_setup_creates_socket_and_negotiate_1(monkeypatch):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_setup_creates_socket_and_negotiate_2(monkeypatch):
|
||||
async def test_setup_creates_socket_and_negotiate_2(zmq_context):
|
||||
"""
|
||||
Test the setup of the communication agent
|
||||
"""
|
||||
# --- Arrange ---
|
||||
fake_socket = MagicMock()
|
||||
fake_socket = zmq_context.return_value.socket.return_value
|
||||
fake_socket.send_json = AsyncMock()
|
||||
fake_socket.recv_json = fake_json_correct_negototiate_2()
|
||||
fake_socket.send_multipart = AsyncMock()
|
||||
|
||||
# Mock context.socket to return our fake socket
|
||||
fake_context = MagicMock()
|
||||
fake_context.socket.return_value = fake_socket
|
||||
monkeypatch.setattr(
|
||||
"control_backend.agents.ri_communication_agent.Context",
|
||||
MagicMock(instance=MagicMock(return_value=fake_context)),
|
||||
)
|
||||
|
||||
# Mock RICommandAgent agent startup
|
||||
with patch(
|
||||
"control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True
|
||||
@@ -185,24 +176,16 @@ async def test_setup_creates_socket_and_negotiate_2(monkeypatch):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_setup_creates_socket_and_negotiate_3(monkeypatch, caplog):
|
||||
async def test_setup_creates_socket_and_negotiate_3(zmq_context, caplog):
|
||||
"""
|
||||
Test the functionality of setup with incorrect negotiation message
|
||||
"""
|
||||
# --- Arrange ---
|
||||
fake_socket = MagicMock()
|
||||
fake_socket = zmq_context.return_value.socket.return_value
|
||||
fake_socket.send_json = AsyncMock()
|
||||
fake_socket.recv_json = fake_json_wrong_negototiate_1()
|
||||
fake_socket.send_multipart = AsyncMock()
|
||||
|
||||
# Mock context.socket to return our fake socket
|
||||
fake_context = MagicMock()
|
||||
fake_context.socket.return_value = fake_socket
|
||||
monkeypatch.setattr(
|
||||
"control_backend.agents.ri_communication_agent.Context",
|
||||
MagicMock(instance=MagicMock(return_value=fake_context)),
|
||||
)
|
||||
|
||||
# Mock RICommandAgent agent startup
|
||||
|
||||
# We are sending wrong negotiation info to the communication agent,
|
||||
@@ -235,24 +218,16 @@ async def test_setup_creates_socket_and_negotiate_3(monkeypatch, caplog):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_setup_creates_socket_and_negotiate_4(monkeypatch):
|
||||
async def test_setup_creates_socket_and_negotiate_4(zmq_context):
|
||||
"""
|
||||
Test the setup of the communication agent with different bind value
|
||||
"""
|
||||
# --- Arrange ---
|
||||
fake_socket = MagicMock()
|
||||
fake_socket = zmq_context.return_value.socket.return_value
|
||||
fake_socket.send_json = AsyncMock()
|
||||
fake_socket.recv_json = fake_json_correct_negototiate_3()
|
||||
fake_socket.send_multipart = AsyncMock()
|
||||
|
||||
# Mock context.socket to return our fake socket
|
||||
fake_context = MagicMock()
|
||||
fake_context.socket.return_value = fake_socket
|
||||
monkeypatch.setattr(
|
||||
"control_backend.agents.ri_communication_agent.Context",
|
||||
MagicMock(instance=MagicMock(return_value=fake_context)),
|
||||
)
|
||||
|
||||
# Mock RICommandAgent agent startup
|
||||
with patch(
|
||||
"control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True
|
||||
@@ -284,24 +259,16 @@ async def test_setup_creates_socket_and_negotiate_4(monkeypatch):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_setup_creates_socket_and_negotiate_5(monkeypatch):
|
||||
async def test_setup_creates_socket_and_negotiate_5(zmq_context):
|
||||
"""
|
||||
Test the setup of the communication agent
|
||||
"""
|
||||
# --- Arrange ---
|
||||
fake_socket = MagicMock()
|
||||
fake_socket = zmq_context.return_value.socket.return_value
|
||||
fake_socket.send_json = AsyncMock()
|
||||
fake_socket.recv_json = fake_json_correct_negototiate_4()
|
||||
fake_socket.send_multipart = AsyncMock()
|
||||
|
||||
# Mock context.socket to return our fake socket
|
||||
fake_context = MagicMock()
|
||||
fake_context.socket.return_value = fake_socket
|
||||
monkeypatch.setattr(
|
||||
"control_backend.agents.ri_communication_agent.Context",
|
||||
MagicMock(instance=MagicMock(return_value=fake_context)),
|
||||
)
|
||||
|
||||
# Mock RICommandAgent agent startup
|
||||
with patch(
|
||||
"control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True
|
||||
@@ -333,24 +300,16 @@ async def test_setup_creates_socket_and_negotiate_5(monkeypatch):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_setup_creates_socket_and_negotiate_6(monkeypatch):
|
||||
async def test_setup_creates_socket_and_negotiate_6(zmq_context):
|
||||
"""
|
||||
Test the setup of the communication agent
|
||||
"""
|
||||
# --- Arrange ---
|
||||
fake_socket = MagicMock()
|
||||
fake_socket = zmq_context.return_value.socket.return_value
|
||||
fake_socket.send_json = AsyncMock()
|
||||
fake_socket.recv_json = fake_json_correct_negototiate_5()
|
||||
fake_socket.send_multipart = AsyncMock()
|
||||
|
||||
# Mock context.socket to return our fake socket
|
||||
fake_context = MagicMock()
|
||||
fake_context.socket.return_value = fake_socket
|
||||
monkeypatch.setattr(
|
||||
"control_backend.agents.ri_communication_agent.Context",
|
||||
MagicMock(instance=MagicMock(return_value=fake_context)),
|
||||
)
|
||||
|
||||
# Mock RICommandAgent agent startup
|
||||
with patch(
|
||||
"control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True
|
||||
@@ -382,28 +341,20 @@ async def test_setup_creates_socket_and_negotiate_6(monkeypatch):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_setup_creates_socket_and_negotiate_7(monkeypatch, caplog):
|
||||
async def test_setup_creates_socket_and_negotiate_7(zmq_context, caplog):
|
||||
"""
|
||||
Test the functionality of setup with incorrect id
|
||||
"""
|
||||
# --- Arrange ---
|
||||
fake_socket = MagicMock()
|
||||
fake_socket = zmq_context.return_value.socket.return_value
|
||||
fake_socket.send_json = AsyncMock()
|
||||
fake_socket.recv_json = fake_json_invalid_id_negototiate()
|
||||
fake_socket.send_multipart = AsyncMock()
|
||||
|
||||
# Mock context.socket to return our fake socket
|
||||
fake_context = MagicMock()
|
||||
fake_context.socket.return_value = fake_socket
|
||||
monkeypatch.setattr(
|
||||
"control_backend.agents.ri_communication_agent.Context",
|
||||
MagicMock(instance=MagicMock(return_value=fake_context)),
|
||||
)
|
||||
|
||||
# Mock RICommandAgent agent startup
|
||||
|
||||
# We are sending wrong negotiation info to the communication agent,
|
||||
# so we should retry and expect a etter response, within a limited time.
|
||||
# so we should retry and expect a better response, within a limited time.
|
||||
with patch(
|
||||
"control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True
|
||||
) as MockCommandAgent:
|
||||
@@ -430,24 +381,16 @@ async def test_setup_creates_socket_and_negotiate_7(monkeypatch, caplog):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_setup_creates_socket_and_negotiate_timeout(monkeypatch, caplog):
|
||||
async def test_setup_creates_socket_and_negotiate_timeout(zmq_context, caplog):
|
||||
"""
|
||||
Test the functionality of setup with incorrect negotiation message
|
||||
"""
|
||||
# --- Arrange ---
|
||||
fake_socket = MagicMock()
|
||||
fake_socket = zmq_context.return_value.socket.return_value
|
||||
fake_socket.send_json = AsyncMock()
|
||||
fake_socket.recv_json = AsyncMock(side_effect=asyncio.TimeoutError)
|
||||
fake_socket.send_multipart = AsyncMock()
|
||||
|
||||
# Mock context.socket to return our fake socket
|
||||
fake_context = MagicMock()
|
||||
fake_context.socket.return_value = fake_socket
|
||||
monkeypatch.setattr(
|
||||
"control_backend.agents.ri_communication_agent.Context",
|
||||
MagicMock(instance=MagicMock(return_value=fake_context)),
|
||||
)
|
||||
|
||||
with patch(
|
||||
"control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True
|
||||
) as MockCommandAgent:
|
||||
@@ -534,8 +477,8 @@ async def test_listen_behaviour_ping_wrong_endpoint(caplog):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_listen_behaviour_timeout(caplog):
|
||||
fake_socket = AsyncMock()
|
||||
async def test_listen_behaviour_timeout(zmq_context, caplog):
|
||||
fake_socket = zmq_context.return_value.socket.return_value
|
||||
fake_socket.send_json = AsyncMock()
|
||||
# recv_json will never resolve, simulate timeout
|
||||
fake_socket.recv_json = AsyncMock(side_effect=asyncio.TimeoutError)
|
||||
@@ -585,20 +528,13 @@ async def test_listen_behaviour_ping_no_endpoint(caplog):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_setup_unexpected_exception(monkeypatch, caplog):
|
||||
fake_socket = MagicMock()
|
||||
async def test_setup_unexpected_exception(zmq_context, caplog):
|
||||
fake_socket = zmq_context.return_value.socket.return_value
|
||||
fake_socket.send_json = AsyncMock()
|
||||
# Simulate unexpected exception during recv_json()
|
||||
fake_socket.recv_json = AsyncMock(side_effect=Exception("boom!"))
|
||||
fake_socket.send_multipart = AsyncMock()
|
||||
|
||||
fake_context = MagicMock()
|
||||
fake_context.socket.return_value = fake_socket
|
||||
monkeypatch.setattr(
|
||||
"control_backend.agents.ri_communication_agent.Context",
|
||||
MagicMock(instance=MagicMock(return_value=fake_context)),
|
||||
)
|
||||
|
||||
agent = RICommunicationAgent(
|
||||
"test@server",
|
||||
"password",
|
||||
@@ -614,9 +550,9 @@ async def test_setup_unexpected_exception(monkeypatch, caplog):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_setup_unpacking_exception(monkeypatch, caplog):
|
||||
async def test_setup_unpacking_exception(zmq_context, caplog):
|
||||
# --- Arrange ---
|
||||
fake_socket = MagicMock()
|
||||
fake_socket = zmq_context.return_value.socket.return_value
|
||||
fake_socket.send_json = AsyncMock()
|
||||
fake_socket.send_multipart = AsyncMock()
|
||||
|
||||
@@ -627,14 +563,6 @@ async def test_setup_unpacking_exception(monkeypatch, caplog):
|
||||
} # missing 'port' and 'bind'
|
||||
fake_socket.recv_json = AsyncMock(return_value=malformed_data)
|
||||
|
||||
# Patch context.socket
|
||||
fake_context = MagicMock()
|
||||
fake_context.socket.return_value = fake_socket
|
||||
monkeypatch.setattr(
|
||||
"control_backend.agents.ri_communication_agent.Context",
|
||||
MagicMock(instance=MagicMock(return_value=fake_context)),
|
||||
)
|
||||
|
||||
# Patch RICommandAgent so it won't actually start
|
||||
with patch(
|
||||
"control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True
|
||||
|
||||
@@ -10,7 +10,9 @@ from control_backend.agents.vad_agent import VADAgent
|
||||
|
||||
@pytest.fixture
|
||||
def zmq_context(mocker):
|
||||
return mocker.patch("control_backend.agents.vad_agent.zmq_context")
|
||||
mock_context = mocker.patch("control_backend.agents.vad_agent.azmq.Context.instance")
|
||||
mock_context.return_value = MagicMock()
|
||||
return mock_context
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -54,13 +56,18 @@ def test_in_socket_creation(zmq_context, do_bind: bool):
|
||||
|
||||
assert vad_agent.audio_in_socket is not None
|
||||
|
||||
zmq_context.socket.assert_called_once_with(zmq.SUB)
|
||||
zmq_context.socket.return_value.setsockopt_string.assert_called_once_with(zmq.SUBSCRIBE, "")
|
||||
zmq_context.return_value.socket.assert_called_once_with(zmq.SUB)
|
||||
zmq_context.return_value.socket.return_value.setsockopt_string.assert_called_once_with(
|
||||
zmq.SUBSCRIBE,
|
||||
"",
|
||||
)
|
||||
|
||||
if do_bind:
|
||||
zmq_context.socket.return_value.bind.assert_called_once_with("tcp://*:12345")
|
||||
zmq_context.return_value.socket.return_value.bind.assert_called_once_with("tcp://*:12345")
|
||||
else:
|
||||
zmq_context.socket.return_value.connect.assert_called_once_with("tcp://localhost:12345")
|
||||
zmq_context.return_value.socket.return_value.connect.assert_called_once_with(
|
||||
"tcp://localhost:12345"
|
||||
)
|
||||
|
||||
|
||||
def test_out_socket_creation(zmq_context):
|
||||
@@ -73,8 +80,8 @@ def test_out_socket_creation(zmq_context):
|
||||
|
||||
assert vad_agent.audio_out_socket is not None
|
||||
|
||||
zmq_context.socket.assert_called_once_with(zmq.PUB)
|
||||
zmq_context.socket.return_value.bind_to_random_port.assert_called_once()
|
||||
zmq_context.return_value.socket.assert_called_once_with(zmq.PUB)
|
||||
zmq_context.return_value.socket.return_value.bind_to_random_port.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -83,7 +90,9 @@ async def test_out_socket_creation_failure(zmq_context):
|
||||
Test setup failure when the audio output socket cannot be created.
|
||||
"""
|
||||
with patch.object(Agent, "stop", new_callable=AsyncMock) as mock_super_stop:
|
||||
zmq_context.socket.return_value.bind_to_random_port.side_effect = zmq.ZMQBindError
|
||||
zmq_context.return_value.socket.return_value.bind_to_random_port.side_effect = (
|
||||
zmq.ZMQBindError
|
||||
)
|
||||
vad_agent = VADAgent("tcp://localhost:12345", False)
|
||||
|
||||
await vad_agent.setup()
|
||||
@@ -98,11 +107,14 @@ async def test_stop(zmq_context, transcription_agent):
|
||||
Test that when the VAD agent is stopped, the sockets are closed correctly.
|
||||
"""
|
||||
vad_agent = VADAgent("tcp://localhost:12345", False)
|
||||
zmq_context.socket.return_value.bind_to_random_port.return_value = random.randint(1000, 10000)
|
||||
zmq_context.return_value.socket.return_value.bind_to_random_port.return_value = random.randint(
|
||||
1000,
|
||||
10000,
|
||||
)
|
||||
|
||||
await vad_agent.setup()
|
||||
await vad_agent.stop()
|
||||
|
||||
assert zmq_context.socket.return_value.close.call_count == 2
|
||||
assert zmq_context.return_value.socket.return_value.close.call_count == 2
|
||||
assert vad_agent.audio_in_socket is None
|
||||
assert vad_agent.audio_out_socket is None
|
||||
|
||||
@@ -48,6 +48,7 @@ async def test_real_audio(mocker):
|
||||
audio_out_socket = AsyncMock()
|
||||
|
||||
vad_streamer = Streaming(audio_in_socket, audio_out_socket)
|
||||
vad_streamer._ready = True
|
||||
for _ in audio_chunks:
|
||||
await vad_streamer.run()
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI
|
||||
@@ -16,7 +16,6 @@ def app():
|
||||
"""
|
||||
app = FastAPI()
|
||||
app.include_router(robot.router)
|
||||
app.state.internal_comm_socket = MagicMock() # mock ZMQ socket
|
||||
return app
|
||||
|
||||
|
||||
@@ -26,32 +25,30 @@ def client(app):
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
def test_receive_command_endpoint(client, app):
|
||||
def test_receive_command_success(client):
|
||||
"""
|
||||
Test that a POST to /command sends the right multipart message
|
||||
and returns a 202 with the expected JSON body.
|
||||
Test for successful reception of a command. Ensures the status code is 202 and the response body
|
||||
is correct. It also verifies that the ZeroMQ socket's send_multipart method is called with the
|
||||
expected data.
|
||||
"""
|
||||
mock_socket = app.state.internal_comm_socket
|
||||
# Arrange
|
||||
mock_pub_socket = AsyncMock()
|
||||
client.app.state.endpoints_pub_socket = mock_pub_socket
|
||||
|
||||
# Prepare test payload that matches SpeechCommand
|
||||
payload = {"endpoint": "actuate/speech", "data": "yooo"}
|
||||
command_data = {"endpoint": "actuate/speech", "data": "This is a test"}
|
||||
speech_command = SpeechCommand(**command_data)
|
||||
|
||||
# Send POST request
|
||||
response = client.post("/command", json=payload)
|
||||
# Act
|
||||
response = client.post("/command", json=command_data)
|
||||
|
||||
# Check response
|
||||
# Assert
|
||||
assert response.status_code == 202
|
||||
assert response.json() == {"status": "Command received"}
|
||||
|
||||
# Verify that the socket was called with the correct data
|
||||
assert mock_socket.send_multipart.called, "Socket should be used to send data"
|
||||
|
||||
args, kwargs = mock_socket.send_multipart.call_args
|
||||
sent_data = args[0]
|
||||
|
||||
assert sent_data[0] == b"command"
|
||||
# Check JSON encoding roughly matches
|
||||
assert isinstance(SpeechCommand.model_validate_json(sent_data[1].decode()), SpeechCommand)
|
||||
# Verify that the ZMQ socket was used correctly
|
||||
mock_pub_socket.send_multipart.assert_awaited_once_with(
|
||||
[b"command", speech_command.model_dump_json().encode()]
|
||||
)
|
||||
|
||||
|
||||
def test_receive_command_invalid_payload(client):
|
||||
|
||||
@@ -16,12 +16,11 @@ def test_valid_speech_command_1():
|
||||
command = valid_command_1()
|
||||
RIMessage.model_validate(command)
|
||||
SpeechCommand.model_validate(command)
|
||||
assert True
|
||||
|
||||
|
||||
def test_invalid_speech_command_1():
|
||||
command = invalid_command_1()
|
||||
RIMessage.model_validate(command)
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
SpeechCommand.model_validate(command)
|
||||
assert True
|
||||
|
||||
@@ -182,8 +182,6 @@ async def test_belief_text_values_not_lists(continuous_collector, mocker):
|
||||
@pytest.mark.asyncio
|
||||
async def test_belief_text_happy_path_logs_items_and_sends(continuous_collector, mocker):
|
||||
payload = {"type": "belief_extraction_text", "beliefs": {"user_said": ["hello test", "No"]}}
|
||||
# Your code calls self.send(..); patch it
|
||||
# (or switch implementation to self.agent.send and patch that)
|
||||
continuous_collector.send = AsyncMock()
|
||||
logger_mock = mocker.patch(
|
||||
"control_backend.agents.belief_collector.behaviours.continuous_collect.logger"
|
||||
|
||||
@@ -21,7 +21,9 @@ def streaming(audio_in_socket, audio_out_socket):
|
||||
import torch
|
||||
|
||||
torch.hub.load.return_value = (..., ...) # Mock
|
||||
return Streaming(audio_in_socket, audio_out_socket)
|
||||
streaming = Streaming(audio_in_socket, audio_out_socket)
|
||||
streaming._ready = True
|
||||
return streaming
|
||||
|
||||
|
||||
async def simulate_streaming_with_probabilities(streaming, probabilities: list[float]):
|
||||
|
||||
@@ -5,12 +5,13 @@ from control_backend.agents.transcription.speech_recognizer import OpenAIWhisper
|
||||
|
||||
|
||||
def test_estimate_max_tokens():
|
||||
"""Inputting one minute of audio, assuming 300 words per minute, expecting 400 tokens."""
|
||||
"""Inputting one minute of audio, assuming 450 words per minute and adding a 10 token padding,
|
||||
expecting 610 tokens."""
|
||||
audio = np.empty(shape=(60 * 16_000), dtype=np.float32)
|
||||
|
||||
actual = SpeechRecognizer._estimate_max_tokens(audio)
|
||||
|
||||
assert actual == 400
|
||||
assert actual == 610
|
||||
assert isinstance(actual, int)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user