Merge branch 'dev' into 'main'
merge dev into main See merge request ics/sp/2025/n25b/pepperplus-cb!49
This commit was merged in pull request #49.
This commit is contained in:
20
.env.example
Normal file
20
.env.example
Normal file
@@ -0,0 +1,20 @@
|
||||
# Example .env file. To use, make a copy, call it ".env" (i.e. removing the ".example" suffix), then you edit values.
|
||||
|
||||
# The hostname of the Robot Interface. Change if the Control Backend and Robot Interface are running on different computers.
|
||||
RI_HOST="localhost"
|
||||
|
||||
# URL for the local LLM API. Must be an API that implements the OpenAI Chat Completions API, but most do.
|
||||
LLM_SETTINGS__LOCAL_LLM_URL="http://localhost:1234/v1/chat/completions"
|
||||
|
||||
# Name of the local LLM model to use.
|
||||
LLM_SETTINGS__LOCAL_LLM_MODEL="gpt-oss"
|
||||
|
||||
# Number of non-speech chunks to wait before speech ended. A chunk is approximately 31 ms. Increasing this number allows longer pauses in speech, but also increases response time.
|
||||
BEHAVIOUR_SETTINGS__VAD_NON_SPEECH_PATIENCE_CHUNKS=15
|
||||
|
||||
# Timeout in milliseconds for socket polling. Increase this number if network latency/jitter is high, often the case when using Wi-Fi. Perhaps 500 ms. A symptom of this issue is transcriptions getting cut off.
|
||||
BEHAVIOUR_SETTINGS__SOCKET_POLLER_TIMEOUT_MS=100
|
||||
|
||||
|
||||
|
||||
# For an exhaustive list of options, see the control_backend.core.config module in the docs.
|
||||
77
.githooks/check-branch-name.sh
Executable file
77
.githooks/check-branch-name.sh
Executable file
@@ -0,0 +1,77 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
# This script checks if the current branch name follows the specified format.
|
||||
# It's designed to be used as a 'pre-commit' git hook.
|
||||
|
||||
# Format: <type>/<short-description>
|
||||
# Example: feat/add-user-login
|
||||
|
||||
# --- Configuration ---
|
||||
# An array of allowed commit types
|
||||
ALLOWED_TYPES=(feat fix refactor perf style test docs build chore revert)
|
||||
# An array of branches to ignore
|
||||
IGNORED_BRANCHES=(main dev demo)
|
||||
|
||||
# --- Colors for Output ---
|
||||
RED='\033[0;31m'
|
||||
GREEN='\033[0;32m'
|
||||
YELLOW='\033[1;33m'
|
||||
NC='\033[0m' # No Color
|
||||
|
||||
# --- Helper Functions ---
|
||||
error_exit() {
|
||||
echo -e "${RED}ERROR: $1${NC}" >&2
|
||||
echo -e "${YELLOW}Branch name format is incorrect. Aborting commit.${NC}" >&2
|
||||
exit 1
|
||||
}
|
||||
|
||||
# --- Main Logic ---
|
||||
|
||||
# 1. Get the current branch name
|
||||
BRANCH_NAME=$(git symbolic-ref --short HEAD)
|
||||
|
||||
# 2. Check if the current branch is in the ignored list
|
||||
for ignored_branch in "${IGNORED_BRANCHES[@]}"; do
|
||||
if [ "$BRANCH_NAME" == "$ignored_branch" ]; then
|
||||
echo -e "${GREEN}Branch check skipped for default branch: $BRANCH_NAME${NC}"
|
||||
exit 0
|
||||
fi
|
||||
done
|
||||
|
||||
# 3. Validate the overall structure: <type>/<description>
|
||||
if ! [[ "$BRANCH_NAME" =~ ^[a-z]+/.+$ ]]; then
|
||||
error_exit "Branch name must be in the format: <type>/<short-description>\nExample: feat/add-user-login"
|
||||
fi
|
||||
|
||||
# 4. Extract the type and description
|
||||
TYPE=$(echo "$BRANCH_NAME" | cut -d'/' -f1)
|
||||
DESCRIPTION=$(echo "$BRANCH_NAME" | cut -d'/' -f2-)
|
||||
|
||||
# 5. Validate the <type>
|
||||
type_valid=false
|
||||
for allowed_type in "${ALLOWED_TYPES[@]}"; do
|
||||
if [ "$TYPE" == "$allowed_type" ]; then
|
||||
type_valid=true
|
||||
break
|
||||
fi
|
||||
done
|
||||
|
||||
if [ "$type_valid" == false ]; then
|
||||
error_exit "Invalid type '$TYPE'.\nAllowed types are: ${ALLOWED_TYPES[*]}"
|
||||
fi
|
||||
|
||||
# 6. Validate the <short-description>
|
||||
# Regex breakdown:
|
||||
# ^[a-z0-9]+ - Starts with one or more lowercase letters/numbers (the first word).
|
||||
# (-[a-z0-9]+){0,5} - Followed by a group of (dash + word) 0 to 5 times.
|
||||
# $ - End of the string.
|
||||
# This entire pattern enforces 1 to 6 words total, separated by dashes.
|
||||
DESCRIPTION_REGEX="^[a-z0-9]+(-[a-z0-9]+){0,5}$"
|
||||
|
||||
if ! [[ "$DESCRIPTION" =~ $DESCRIPTION_REGEX ]]; then
|
||||
error_exit "Invalid short description '$DESCRIPTION'.\nIt must be a maximum of 6 words, all lowercase, separated by dashes.\nExample: add-new-user-authentication-feature"
|
||||
fi
|
||||
|
||||
# If all checks pass, exit successfully
|
||||
echo -e "${GREEN}Branch name '$BRANCH_NAME' is valid.${NC}"
|
||||
exit 0
|
||||
135
.githooks/check-commit-msg.sh
Executable file
135
.githooks/check-commit-msg.sh
Executable file
@@ -0,0 +1,135 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
# This script checks if a commit message follows the specified format.
|
||||
# It's designed to be used as a 'commit-msg' git hook.
|
||||
|
||||
# Format:
|
||||
# <type>: <short description>
|
||||
#
|
||||
# [optional]<body>
|
||||
#
|
||||
# [ref/close]: <issue identifier>
|
||||
|
||||
# --- Configuration ---
|
||||
# An array of allowed commit types
|
||||
ALLOWED_TYPES=(feat fix refactor perf style test docs build chore revert)
|
||||
|
||||
# --- Colors for Output ---
|
||||
RED='\033[0;31m'
|
||||
GREEN='\033[0;32m'
|
||||
YELLOW='\033[1;33m'
|
||||
NC='\033[0m' # No Color
|
||||
|
||||
# The first argument to the hook is the path to the file containing the commit message
|
||||
COMMIT_MSG_FILE=$1
|
||||
|
||||
# --- Automated Commit Detection ---
|
||||
|
||||
# Read the first line (header) for initial checks
|
||||
HEADER=$(head -n 1 "$COMMIT_MSG_FILE")
|
||||
|
||||
# Check for Merge commits (covers 'git merge' and PR merges from GitHub/GitLab)
|
||||
# Examples: "Merge branch 'main' into ...", "Merge pull request #123 from ..."
|
||||
MERGE_PATTERN="^Merge (remote-tracking )?(branch|pull request|tag) .*"
|
||||
if [[ "$HEADER" =~ $MERGE_PATTERN ]]; then
|
||||
echo -e "${GREEN}Merge commit detected by message content. Skipping validation.${NC}"
|
||||
exit 0
|
||||
fi
|
||||
|
||||
# Check for Revert commits
|
||||
# Example: "Revert "feat: add new feature""
|
||||
REVERT_PATTERN="^Revert \".*\""
|
||||
if [[ "$HEADER" =~ $REVERT_PATTERN ]]; then
|
||||
echo -e "${GREEN}Revert commit detected by message content. Skipping validation.${NC}"
|
||||
exit 0
|
||||
fi
|
||||
|
||||
# Check for Cherry-pick commits (this pattern appears at the end of the message)
|
||||
# Example: "(cherry picked from commit deadbeef...)"
|
||||
# We use grep -q to search the whole file quietly.
|
||||
CHERRY_PICK_PATTERN="\(cherry picked from commit [a-f0-9]{7,40}\)"
|
||||
if grep -qE "$CHERRY_PICK_PATTERN" "$COMMIT_MSG_FILE"; then
|
||||
echo -e "${GREEN}Cherry-pick detected by message content. Skipping validation.${NC}"
|
||||
exit 0
|
||||
fi
|
||||
|
||||
# Check for Squash
|
||||
# Example: "Squash commits ..."
|
||||
SQUASH_PATTERN="^Squash .+"
|
||||
if [[ "$HEADER" =~ $SQUASH_PATTERN ]]; then
|
||||
echo -e "${GREEN}Squash commit detected by message content. Skipping validation.${NC}"
|
||||
exit 0
|
||||
fi
|
||||
|
||||
# --- Validation Functions ---
|
||||
|
||||
# Function to print an error message and exit
|
||||
# Usage: error_exit "Your error message here"
|
||||
error_exit() {
|
||||
# >&2 redirects echo to stderr
|
||||
echo -e "${RED}ERROR: $1${NC}" >&2
|
||||
echo -e "${YELLOW}Commit message format is incorrect. Aborting commit.${NC}" >&2
|
||||
exit 1
|
||||
}
|
||||
|
||||
# --- Main Logic ---
|
||||
|
||||
# 1. Read the header (first line) of the commit message
|
||||
HEADER=$(head -n 1 "$COMMIT_MSG_FILE")
|
||||
|
||||
# 2. Validate the header format: <type>: <description>
|
||||
# Regex breakdown:
|
||||
# ^(type1|type2|...) - Starts with one of the allowed types
|
||||
# : - Followed by a literal colon
|
||||
# \s - Followed by a single space
|
||||
# .+ - Followed by one or more characters for the description
|
||||
# $ - End of the line
|
||||
TYPES_REGEX=$(
|
||||
IFS="|"
|
||||
echo "${ALLOWED_TYPES[*]}"
|
||||
)
|
||||
HEADER_REGEX="^($TYPES_REGEX): .+$"
|
||||
|
||||
if ! [[ "$HEADER" =~ $HEADER_REGEX ]]; then
|
||||
error_exit "Invalid header format.\n\nHeader must be in the format: <type>: <short description>\nAllowed types: ${ALLOWED_TYPES[*]}\nExample: feat: add new user authentication feature"
|
||||
fi
|
||||
|
||||
# Only validate footer if commit type is not chore
|
||||
TYPE=$(echo "$HEADER" | cut -d':' -f1)
|
||||
if [ "$TYPE" != "chore" ]; then
|
||||
# 3. Validate the footer (last line) of the commit message
|
||||
FOOTER=$(tail -n 1 "$COMMIT_MSG_FILE")
|
||||
|
||||
# Regex breakdown:
|
||||
# ^(ref|close) - Starts with 'ref' or 'close'
|
||||
# : - Followed by a literal colon
|
||||
# \s - Followed by a single space
|
||||
# N25B- - Followed by the literal string 'N25B-'
|
||||
# [0-9]+ - Followed by one or more digits
|
||||
# $ - End of the line
|
||||
FOOTER_REGEX="^(ref|close): N25B-[0-9]+$"
|
||||
|
||||
if ! [[ "$FOOTER" =~ $FOOTER_REGEX ]]; then
|
||||
error_exit "Invalid footer format.\n\nFooter must be in the format: [ref/close]: <issue identifier>\nExample: ref: N25B-123"
|
||||
fi
|
||||
fi
|
||||
|
||||
# 4. If the message has more than 2 lines, validate the separator
|
||||
# A blank line must exist between the header and the body.
|
||||
LINE_COUNT=$(wc -l <"$COMMIT_MSG_FILE" | xargs) # xargs trims whitespace
|
||||
|
||||
# We only care if there is a body. Header + Footer = 2 lines.
|
||||
# Header + Blank Line + Body... + Footer > 2 lines.
|
||||
if [ "$LINE_COUNT" -gt 2 ]; then
|
||||
# Get the second line
|
||||
SECOND_LINE=$(sed -n '2p' "$COMMIT_MSG_FILE")
|
||||
|
||||
# Check if the second line is NOT empty. If it's not, it's an error.
|
||||
if [ -n "$SECOND_LINE" ]; then
|
||||
error_exit "Missing blank line between header and body.\n\nThe second line of your commit message must be empty if a body is present."
|
||||
fi
|
||||
fi
|
||||
|
||||
# If all checks pass, exit with success
|
||||
echo -e "${GREEN}Commit message is valid.${NC}"
|
||||
exit 0
|
||||
279
.gitignore
vendored
Normal file
279
.gitignore
vendored
Normal file
@@ -0,0 +1,279 @@
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[codz]
|
||||
*$py.class
|
||||
|
||||
# C extensions
|
||||
*.so
|
||||
|
||||
# Distribution / packaging
|
||||
.Python
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
share/python-wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
MANIFEST
|
||||
|
||||
# PyInstaller
|
||||
# Usually these files are written by a python script from a template
|
||||
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||
*.manifest
|
||||
*.spec
|
||||
|
||||
# Installer logs
|
||||
pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Unit test / coverage reports
|
||||
htmlcov/
|
||||
.tox/
|
||||
.nox/
|
||||
.coverage
|
||||
.coverage.*
|
||||
.cache
|
||||
nosetests.xml
|
||||
coverage.xml
|
||||
*.cover
|
||||
*.py.cover
|
||||
.hypothesis/
|
||||
.pytest_cache/
|
||||
cover/
|
||||
|
||||
# Translations
|
||||
*.mo
|
||||
*.pot
|
||||
|
||||
# Django stuff:
|
||||
*.log
|
||||
local_settings.py
|
||||
db.sqlite3
|
||||
db.sqlite3-journal
|
||||
|
||||
# Flask stuff:
|
||||
instance/
|
||||
.webassets-cache
|
||||
|
||||
# Scrapy stuff:
|
||||
.scrapy
|
||||
|
||||
# Sphinx documentation
|
||||
docs/_build/
|
||||
|
||||
# PyBuilder
|
||||
.pybuilder/
|
||||
target/
|
||||
|
||||
# Jupyter Notebook
|
||||
.ipynb_checkpoints
|
||||
|
||||
# IPython
|
||||
profile_default/
|
||||
ipython_config.py
|
||||
|
||||
# pyenv
|
||||
# For a library or package, you might want to ignore these files since the code is
|
||||
# intended to run in multiple environments; otherwise, check them in:
|
||||
# .python-version
|
||||
|
||||
# pipenv
|
||||
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
||||
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
||||
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
||||
# install all needed dependencies.
|
||||
# Pipfile.lock
|
||||
|
||||
# UV
|
||||
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
|
||||
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
||||
# commonly ignored for libraries.
|
||||
# uv.lock
|
||||
|
||||
# poetry
|
||||
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
||||
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
||||
# commonly ignored for libraries.
|
||||
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
||||
# poetry.lock
|
||||
# poetry.toml
|
||||
|
||||
# pdm
|
||||
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
||||
# pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python.
|
||||
# https://pdm-project.org/en/latest/usage/project/#working-with-version-control
|
||||
# pdm.lock
|
||||
# pdm.toml
|
||||
.pdm-python
|
||||
.pdm-build/
|
||||
|
||||
# pixi
|
||||
# Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control.
|
||||
# pixi.lock
|
||||
# Pixi creates a virtual environment in the .pixi directory, just like venv module creates one
|
||||
# in the .venv directory. It is recommended not to include this directory in version control.
|
||||
.pixi
|
||||
|
||||
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
||||
__pypackages__/
|
||||
|
||||
# Celery stuff
|
||||
celerybeat-schedule
|
||||
celerybeat.pid
|
||||
|
||||
# Redis
|
||||
*.rdb
|
||||
*.aof
|
||||
*.pid
|
||||
|
||||
# RabbitMQ
|
||||
mnesia/
|
||||
rabbitmq/
|
||||
rabbitmq-data/
|
||||
|
||||
# ActiveMQ
|
||||
activemq-data/
|
||||
|
||||
# SageMath parsed files
|
||||
*.sage.py
|
||||
|
||||
# Environments
|
||||
.env
|
||||
.envrc
|
||||
.venv
|
||||
env/
|
||||
venv/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
|
||||
# Spyder project settings
|
||||
.spyderproject
|
||||
.spyproject
|
||||
|
||||
# Rope project settings
|
||||
.ropeproject
|
||||
|
||||
# mkdocs documentation
|
||||
/site
|
||||
|
||||
# mypy
|
||||
.mypy_cache/
|
||||
.dmypy.json
|
||||
dmypy.json
|
||||
|
||||
# Pyre type checker
|
||||
.pyre/
|
||||
|
||||
# pytype static type analyzer
|
||||
.pytype/
|
||||
|
||||
# Cython debug symbols
|
||||
cython_debug/
|
||||
|
||||
# PyCharm
|
||||
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
||||
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
||||
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
||||
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||
.idea/
|
||||
|
||||
# Abstra
|
||||
# Abstra is an AI-powered process automation framework.
|
||||
# Ignore directories containing user credentials, local state, and settings.
|
||||
# Learn more at https://abstra.io/docs
|
||||
.abstra/
|
||||
|
||||
# Visual Studio Code
|
||||
# Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore
|
||||
# that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore
|
||||
# and can be added to the global gitignore or merged into this file. However, if you prefer,
|
||||
# you could uncomment the following to ignore the entire vscode folder
|
||||
.vscode/
|
||||
|
||||
# Ruff stuff:
|
||||
.ruff_cache/
|
||||
|
||||
# PyPI configuration file
|
||||
.pypirc
|
||||
|
||||
# Marimo
|
||||
marimo/_static/
|
||||
marimo/_lsp/
|
||||
__marimo__/
|
||||
|
||||
# Streamlit
|
||||
.streamlit/secrets.toml
|
||||
|
||||
# MacOS
|
||||
.DS_Store
|
||||
|
||||
# Docs
|
||||
docs/*
|
||||
!docs/conf.py
|
||||
|
||||
# Generated files
|
||||
*.asl
|
||||
experiment-*.log
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
25
.gitlab-ci.yml
Normal file
25
.gitlab-ci.yml
Normal file
@@ -0,0 +1,25 @@
|
||||
# ---------- GLOBAL SETUP ---------- #
|
||||
workflow:
|
||||
rules:
|
||||
- if: '$CI_PIPELINE_SOURCE == "merge_request_event"'
|
||||
|
||||
stages:
|
||||
- install
|
||||
- lint
|
||||
- test
|
||||
|
||||
variables:
|
||||
UV_VERSION: "0.9.4"
|
||||
PYTHON_VERSION: "3.13"
|
||||
BASE_LAYER: trixie-slim
|
||||
|
||||
default:
|
||||
image: ghcr.io/astral-sh/uv:$UV_VERSION-python$PYTHON_VERSION-$BASE_LAYER
|
||||
|
||||
# ---------- TESTING ---------- #
|
||||
test:
|
||||
stage: test
|
||||
tags:
|
||||
- test
|
||||
script:
|
||||
- uv run --only-group test pytest test
|
||||
65
.logging_config.yaml
Normal file
65
.logging_config.yaml
Normal file
@@ -0,0 +1,65 @@
|
||||
version: 1
|
||||
|
||||
custom_levels:
|
||||
OBSERVATION: 24
|
||||
ACTION: 25
|
||||
CHAT: 26
|
||||
LLM: 9
|
||||
|
||||
formatters:
|
||||
# Console output
|
||||
colored:
|
||||
class: colorlog.ColoredFormatter
|
||||
format: "{log_color}{asctime}.{msecs:03.0f} | {levelname:11} | {name:70} | {message}"
|
||||
style: "{"
|
||||
datefmt: "%H:%M:%S"
|
||||
|
||||
# User-facing UI (structured JSON)
|
||||
json:
|
||||
class: pythonjsonlogger.jsonlogger.JsonFormatter
|
||||
format: "{name} {levelname} {levelno} {message} {created} {relativeCreated}"
|
||||
style: "{"
|
||||
|
||||
# Experiment stream for console and file output, with optional `role` field
|
||||
experiment:
|
||||
class: control_backend.logging.OptionalFieldFormatter
|
||||
format: "%(asctime)s %(levelname)s %(role?)s %(message)s"
|
||||
defaults:
|
||||
role: "-"
|
||||
|
||||
filters:
|
||||
# Filter out any log records that have the extra field "partial" set to True, indicating that they
|
||||
# will be replaced later.
|
||||
partial:
|
||||
(): control_backend.logging.PartialFilter
|
||||
|
||||
handlers:
|
||||
console:
|
||||
class: logging.StreamHandler
|
||||
level: DEBUG
|
||||
formatter: colored
|
||||
filters: [partial]
|
||||
stream: ext://sys.stdout
|
||||
ui:
|
||||
class: zmq.log.handlers.PUBHandler
|
||||
level: LLM
|
||||
formatter: json
|
||||
file:
|
||||
class: control_backend.logging.DatedFileHandler
|
||||
formatter: experiment
|
||||
filters: [partial]
|
||||
# Directory must match config.logging_settings.experiment_log_directory
|
||||
file_prefix: experiment_logs/experiment
|
||||
|
||||
# Level for external libraries
|
||||
root:
|
||||
level: WARN
|
||||
handlers: [console]
|
||||
|
||||
loggers:
|
||||
control_backend:
|
||||
level: LLM
|
||||
handlers: [ui]
|
||||
experiment: # This name must match config.logging_settings.experiment_logger_name
|
||||
level: DEBUG
|
||||
handlers: [ui, file]
|
||||
24
.pre-commit-config.yaml
Normal file
24
.pre-commit-config.yaml
Normal file
@@ -0,0 +1,24 @@
|
||||
repos:
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
# Ruff version.
|
||||
rev: v0.14.2
|
||||
hooks:
|
||||
# Run the linter.
|
||||
- id: ruff-check
|
||||
# Run the formatter.
|
||||
- id: ruff-format
|
||||
# Configure local hooks
|
||||
- repo: local
|
||||
hooks:
|
||||
- id: check-commit-msg
|
||||
name: Check commit message format
|
||||
entry: .githooks/check-commit-msg.sh
|
||||
language: script
|
||||
stages: [commit-msg]
|
||||
- id: check-branch-name
|
||||
name: Check branch name format
|
||||
entry: .githooks/check-branch-name.sh
|
||||
language: script
|
||||
stages: [pre-commit]
|
||||
always_run: true
|
||||
pass_filenames: false
|
||||
1
.python-version
Normal file
1
.python-version
Normal file
@@ -0,0 +1 @@
|
||||
3.13
|
||||
136
README.md
136
README.md
@@ -1,93 +1,101 @@
|
||||
# PepperPlus-CB
|
||||
|
||||
|
||||
|
||||
## Getting started
|
||||
|
||||
To make it easy for you to get started with GitLab, here's a list of recommended next steps.
|
||||
|
||||
Already a pro? Just edit this README.md and make it your own. Want to make it easy? [Use the template at the bottom](#editing-this-readme)!
|
||||
|
||||
## Add your files
|
||||
|
||||
- [ ] [Create](https://docs.gitlab.com/ee/user/project/repository/web_editor.html#create-a-file) or [upload](https://docs.gitlab.com/ee/user/project/repository/web_editor.html#upload-a-file) files
|
||||
- [ ] [Add files using the command line](https://docs.gitlab.com/topics/git/add_files/#add-files-to-a-git-repository) or push an existing Git repository with the following command:
|
||||
## Development environment
|
||||
We begin by installing UV (very nice utility for managing packages and Python version):
|
||||
|
||||
```bash
|
||||
# On MacOS and Linux.
|
||||
curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||
```
|
||||
cd existing_repo
|
||||
git remote add origin https://git.science.uu.nl/ics/sp/2025/n25b/pepperplus-cb.git
|
||||
git branch -M main
|
||||
git push -uf origin main
|
||||
```bash
|
||||
# On Windows.
|
||||
powershell -ExecutionPolicy ByPass -c "irm https://astral.sh/uv/install.ps1 | iex"
|
||||
```
|
||||
|
||||
## Integrate with your tools
|
||||
Using UV, installing the packages and virtual environment is as simple as typing the following (inside the root directory of this repository):
|
||||
|
||||
- [ ] [Set up project integrations](https://git.science.uu.nl/ics/sp/2025/n25b/pepperplus-cb/-/settings/integrations)
|
||||
```bash
|
||||
uv sync
|
||||
```
|
||||
|
||||
## Collaborate with your team
|
||||
## Local LLM
|
||||
|
||||
- [ ] [Invite team members and collaborators](https://docs.gitlab.com/ee/user/project/members/)
|
||||
- [ ] [Create a new merge request](https://docs.gitlab.com/ee/user/project/merge_requests/creating_merge_requests.html)
|
||||
- [ ] [Automatically close issues from merge requests](https://docs.gitlab.com/ee/user/project/issues/managing_issues.html#closing-issues-automatically)
|
||||
- [ ] [Enable merge request approvals](https://docs.gitlab.com/ee/user/project/merge_requests/approvals/)
|
||||
- [ ] [Set auto-merge](https://docs.gitlab.com/user/project/merge_requests/auto_merge/)
|
||||
To run a LLM locally download https://lmstudio.ai
|
||||
When installing select developer mode, download a model (it will already suggest one) and run it (see developer window, status: running)
|
||||
|
||||
## Test and Deploy
|
||||
copy the url at the top right and replace local_llm_url with it + v1/chat/completions.
|
||||
This + part might differ based on what model you choose.
|
||||
|
||||
Use the built-in continuous integration in GitLab.
|
||||
copy the model name in the module loaded and replace local_llm_modelL. In settings.
|
||||
|
||||
- [ ] [Get started with GitLab CI/CD](https://docs.gitlab.com/ee/ci/quick_start/)
|
||||
- [ ] [Analyze your code for known vulnerabilities with Static Application Security Testing (SAST)](https://docs.gitlab.com/ee/user/application_security/sast/)
|
||||
- [ ] [Deploy to Kubernetes, Amazon EC2, or Amazon ECS using Auto Deploy](https://docs.gitlab.com/ee/topics/autodevops/requirements.html)
|
||||
- [ ] [Use pull-based deployments for improved Kubernetes management](https://docs.gitlab.com/ee/user/clusters/agent/)
|
||||
- [ ] [Set up protected environments](https://docs.gitlab.com/ee/ci/environments/protected_environments.html)
|
||||
|
||||
***
|
||||
|
||||
# Editing this README
|
||||
## Running
|
||||
To run the project (development server), execute the following command (while inside the root repository):
|
||||
|
||||
When you're ready to make this README your own, just edit this file and use the handy template below (or feel free to structure it however you want - this is just a starting point!). Thanks to [makeareadme.com](https://www.makeareadme.com/) for this template.
|
||||
```bash
|
||||
uv run fastapi dev src/control_backend/main.py
|
||||
```
|
||||
|
||||
## Suggestions for a good README
|
||||
### Environment Variables
|
||||
|
||||
Every project is different, so consider which of these sections apply to yours. The sections used in the template are suggestions for most open source projects. Also keep in mind that while a README can be too long and detailed, too long is better than too short. If you think your README is too long, consider utilizing another form of documentation rather than cutting out information.
|
||||
You can use environment variables to change settings. Make a copy of the [`.env.example`](.env.example) file, name it `.env` and put it in the root directory. The file itself describes how to do the configuration.
|
||||
|
||||
## Name
|
||||
Choose a self-explaining name for your project.
|
||||
For an exhaustive list of environment options, see the `control_backend.core.config` module in the docs.
|
||||
|
||||
## Description
|
||||
Let people know what your project can do specifically. Provide context and add a link to any reference visitors might be unfamiliar with. A list of Features or a Background subsection can also be added here. If there are alternatives to your project, this is a good place to list differentiating factors.
|
||||
|
||||
## Badges
|
||||
On some READMEs, you may see small images that convey metadata, such as whether or not all the tests are passing for the project. You can use Shields to add some to your README. Many services also have instructions for adding a badge.
|
||||
|
||||
## Visuals
|
||||
Depending on what you are making, it can be a good idea to include screenshots or even a video (you'll frequently see GIFs rather than actual videos). Tools like ttygif can help, but check out Asciinema for a more sophisticated method.
|
||||
## Testing
|
||||
Testing happens automatically when opening a merge request to any branch. If you want to manually run the test suite, you can do so by running the following for unit tests:
|
||||
|
||||
## Installation
|
||||
Within a particular ecosystem, there may be a common way of installing things, such as using Yarn, NuGet, or Homebrew. However, consider the possibility that whoever is reading your README is a novice and would like more guidance. Listing specific steps helps remove ambiguity and gets people to using your project as quickly as possible. If it only runs in a specific context like a particular programming language version or operating system or has dependencies that have to be installed manually, also add a Requirements subsection.
|
||||
```bash
|
||||
uv run --only-group test pytest test/unit
|
||||
```
|
||||
|
||||
## Usage
|
||||
Use examples liberally, and show the expected output if you can. It's helpful to have inline the smallest example of usage that you can demonstrate, while providing links to more sophisticated examples if they are too long to reasonably include in the README.
|
||||
Or for integration tests:
|
||||
|
||||
## Support
|
||||
Tell people where they can go to for help. It can be any combination of an issue tracker, a chat room, an email address, etc.
|
||||
```bash
|
||||
uv run --group integration-test pytest test/integration
|
||||
```
|
||||
|
||||
## Roadmap
|
||||
If you have ideas for releases in the future, it is a good idea to list them in the README.
|
||||
## Git Hooks
|
||||
|
||||
## Contributing
|
||||
State if you are open to contributions and what your requirements are for accepting them.
|
||||
To activate automatic linting, formatting, branch name checks and commit message checks, run:
|
||||
|
||||
For people who want to make changes to your project, it's helpful to have some documentation on how to get started. Perhaps there is a script that they should run or some environment variables that they need to set. Make these steps explicit. These instructions could also be useful to your future self.
|
||||
```bash
|
||||
uv run pre-commit install
|
||||
uv run pre-commit install --hook-type commit-msg
|
||||
```
|
||||
|
||||
You can also document commands to lint the code or run tests. These steps help to ensure high code quality and reduce the likelihood that the changes inadvertently break something. Having instructions for running tests is especially helpful if it requires external setup, such as starting a Selenium server for testing in a browser.
|
||||
You might get an error along the lines of `Can't install pre-commit with core.hooksPath` set. To fix this, simply unset the hooksPath by running:
|
||||
|
||||
## Authors and acknowledgment
|
||||
Show your appreciation to those who have contributed to the project.
|
||||
```bash
|
||||
git config --local --unset core.hooksPath
|
||||
```
|
||||
|
||||
## License
|
||||
For open source projects, say how it is licensed.
|
||||
Then run the pre-commit install commands again.
|
||||
|
||||
## Project status
|
||||
If you have run out of energy or time for your project, put a note at the top of the README saying that development has slowed down or stopped completely. Someone may choose to fork your project or volunteer to step in as a maintainer or owner, allowing your project to keep going. You can also make an explicit request for maintainers.
|
||||
## Documentation
|
||||
Generate documentation web pages using:
|
||||
|
||||
### Linux & macOS
|
||||
```bash
|
||||
PYTHONPATH=src sphinx-apidoc -F -o docs src/control_backend
|
||||
```
|
||||
|
||||
### Windows
|
||||
```bash
|
||||
$env:PYTHONPATH="src"; sphinx-apidoc -F -o docs src/control_backend
|
||||
```
|
||||
|
||||
Optionally, in the `conf.py` file in the `docs` folder, change preferences.
|
||||
|
||||
In the `docs` folder:
|
||||
|
||||
### Linux & macOS
|
||||
```bash
|
||||
make html
|
||||
```
|
||||
|
||||
### Windows
|
||||
```bash
|
||||
.\make.bat html
|
||||
```
|
||||
40
docs/conf.py
Normal file
40
docs/conf.py
Normal file
@@ -0,0 +1,40 @@
|
||||
# Configuration file for the Sphinx documentation builder.
|
||||
#
|
||||
# For the full list of built-in configuration values, see the documentation:
|
||||
# https://www.sphinx-doc.org/en/master/usage/configuration.html
|
||||
import os
|
||||
import sys
|
||||
|
||||
sys.path.insert(0, os.path.abspath("../src"))
|
||||
|
||||
# -- Project information -----------------------------------------------------
|
||||
# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information
|
||||
|
||||
project = "control_backend"
|
||||
copyright = "2025, Author"
|
||||
author = "Author"
|
||||
|
||||
# -- General configuration ---------------------------------------------------
|
||||
# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
|
||||
|
||||
extensions = [
|
||||
"sphinx.ext.autodoc",
|
||||
"sphinx.ext.viewcode",
|
||||
"sphinx.ext.todo",
|
||||
]
|
||||
|
||||
templates_path = ["_templates"]
|
||||
exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]
|
||||
|
||||
language = "en"
|
||||
|
||||
# -- Options for HTML output -------------------------------------------------
|
||||
# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output
|
||||
|
||||
html_theme = "sphinx_rtd_theme"
|
||||
html_static_path = ["_static"]
|
||||
|
||||
# -- Options for todo extension ----------------------------------------------
|
||||
# https://www.sphinx-doc.org/en/master/usage/extensions/todo.html#configuration
|
||||
|
||||
todo_include_todos = True
|
||||
76
pyproject.toml
Normal file
76
pyproject.toml
Normal file
@@ -0,0 +1,76 @@
|
||||
[project]
|
||||
name = "pepperplus-cb"
|
||||
version = "0.1.0"
|
||||
description = "Add your description here"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.13"
|
||||
dependencies = [
|
||||
"agentspeak>=0.2.2",
|
||||
"colorlog>=6.10.1",
|
||||
"fastapi[all]>=0.115.6",
|
||||
"mlx-whisper>=0.4.3 ; sys_platform == 'darwin'",
|
||||
"numpy>=2.3.3",
|
||||
"openai-whisper>=20250625",
|
||||
"pyaudio>=0.2.14",
|
||||
"pydantic>=2.12.0",
|
||||
"pydantic-settings>=2.11.0",
|
||||
"python-json-logger>=4.0.0",
|
||||
"python-slugify>=8.0.4",
|
||||
"pyyaml>=6.0.3",
|
||||
"pyzmq>=27.1.0",
|
||||
"silero-vad>=6.0.0",
|
||||
"sphinx>=7.3.7",
|
||||
"sphinx-rtd-theme>=3.0.2",
|
||||
"torch>=2.8.0",
|
||||
"uvicorn>=0.37.0",
|
||||
]
|
||||
|
||||
[dependency-groups]
|
||||
dev = [
|
||||
"pre-commit>=4.3.0",
|
||||
"pytest>=8.4.2",
|
||||
"pytest-asyncio>=1.2.0",
|
||||
"pytest-cov>=7.0.0",
|
||||
"pytest-mock>=3.15.1",
|
||||
"soundfile>=0.13.1",
|
||||
"ruff>=0.14.2",
|
||||
"ruff-format>=0.3.0",
|
||||
]
|
||||
test = [
|
||||
"agentspeak>=0.2.2",
|
||||
"fastapi>=0.115.6",
|
||||
"httpx>=0.28.1",
|
||||
"mlx-whisper>=0.4.3 ; sys_platform == 'darwin'",
|
||||
"openai-whisper>=20250625",
|
||||
"pydantic>=2.12.0",
|
||||
"pydantic-settings>=2.11.0",
|
||||
"pytest>=8.4.2",
|
||||
"pytest-asyncio>=1.2.0",
|
||||
"pytest-cov>=7.0.0",
|
||||
"pytest-mock>=3.15.1",
|
||||
"python-slugify>=8.0.4",
|
||||
"pyyaml>=6.0.3",
|
||||
"pyzmq>=27.1.0",
|
||||
"soundfile>=0.13.1",
|
||||
]
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
pythonpath = ["src"]
|
||||
|
||||
[tool.ruff]
|
||||
line-length = 100
|
||||
|
||||
[tool.ruff.lint]
|
||||
extend-select = [
|
||||
"E", # pycodestyle
|
||||
"F", # pyflakes
|
||||
"I", # isort (import sorting)
|
||||
"UP", # pyupgrade (modernize code)
|
||||
"B", # flake8-bugbear (common bugs)
|
||||
"C4", # flake8-comprehensions (unnecessary comprehensions)
|
||||
]
|
||||
|
||||
ignore = [
|
||||
"E226", # spaces around operators
|
||||
"E701", # multiple statements on a single line
|
||||
]
|
||||
0
src/control_backend/__init__.py
Normal file
0
src/control_backend/__init__.py
Normal file
5
src/control_backend/agents/__init__.py
Normal file
5
src/control_backend/agents/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""
|
||||
This package contains all agent implementations for the PepperPlus Control Backend.
|
||||
"""
|
||||
|
||||
from .base import BaseAgent as BaseAgent
|
||||
6
src/control_backend/agents/actuation/__init__.py
Normal file
6
src/control_backend/agents/actuation/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
"""
|
||||
Agents responsible for controlling the robot's physical actions, such as speech and gestures.
|
||||
"""
|
||||
|
||||
from .robot_gesture_agent import RobotGestureAgent as RobotGestureAgent
|
||||
from .robot_speech_agent import RobotSpeechAgent as RobotSpeechAgent
|
||||
177
src/control_backend/agents/actuation/robot_gesture_agent.py
Normal file
177
src/control_backend/agents/actuation/robot_gesture_agent.py
Normal file
@@ -0,0 +1,177 @@
|
||||
import json
|
||||
import logging
|
||||
|
||||
import zmq
|
||||
import zmq.asyncio as azmq
|
||||
|
||||
from control_backend.agents import BaseAgent
|
||||
from control_backend.core.agent_system import InternalMessage
|
||||
from control_backend.core.config import settings
|
||||
from control_backend.schemas.ri_message import GestureCommand, RIEndpoint
|
||||
|
||||
experiment_logger = logging.getLogger(settings.logging_settings.experiment_logger_name)
|
||||
|
||||
|
||||
class RobotGestureAgent(BaseAgent):
|
||||
"""
|
||||
This agent acts as a bridge between the control backend and the Robot Interface (RI).
|
||||
It receives gesture commands from other agents or from the UI,
|
||||
and forwards them to the robot via a ZMQ PUB socket.
|
||||
|
||||
:ivar subsocket: ZMQ SUB socket for receiving external commands (e.g., from UI).
|
||||
:ivar pubsocket: ZMQ PUB socket for sending commands to the Robot Interface.
|
||||
:ivar address: Address to bind/connect the PUB socket.
|
||||
:ivar bind: Whether to bind or connect the PUB socket.
|
||||
:ivar gesture_data: A list of strings for available gestures
|
||||
"""
|
||||
|
||||
subsocket: azmq.Socket
|
||||
repsocket: azmq.Socket
|
||||
pubsocket: azmq.Socket
|
||||
address = ""
|
||||
bind = False
|
||||
gesture_data = []
|
||||
single_gesture_data = []
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
address: str,
|
||||
bind=False,
|
||||
gesture_data=None,
|
||||
single_gesture_data=None,
|
||||
):
|
||||
self.gesture_data = gesture_data or []
|
||||
self.single_gesture_data = single_gesture_data or []
|
||||
super().__init__(name)
|
||||
self.address = address
|
||||
self.bind = bind
|
||||
|
||||
async def setup(self):
|
||||
"""
|
||||
Initialize the agent.
|
||||
|
||||
1. Sets up the PUB socket to talk to the robot.
|
||||
2. Sets up the SUB socket to listen for "command" topics (from UI/External).
|
||||
3. Starts the loop for handling ZMQ commands.
|
||||
"""
|
||||
self.logger.info("Setting up %s", self.name)
|
||||
|
||||
context = azmq.Context.instance()
|
||||
|
||||
# To the robot
|
||||
self.pubsocket = context.socket(zmq.PUB)
|
||||
if self.bind:
|
||||
self.pubsocket.bind(self.address)
|
||||
else:
|
||||
self.pubsocket.connect(self.address)
|
||||
|
||||
# Receive internal topics regarding commands
|
||||
self.subsocket = context.socket(zmq.SUB)
|
||||
self.subsocket.connect(settings.zmq_settings.internal_sub_address)
|
||||
self.subsocket.setsockopt(zmq.SUBSCRIBE, b"command")
|
||||
self.subsocket.setsockopt(zmq.SUBSCRIBE, b"send_gestures")
|
||||
|
||||
# REP socket for replying to gesture requests
|
||||
self.repsocket = context.socket(zmq.REP)
|
||||
self.repsocket.bind(settings.zmq_settings.internal_gesture_rep_adress)
|
||||
|
||||
self.add_behavior(self._zmq_command_loop())
|
||||
self.add_behavior(self._fetch_gestures_loop())
|
||||
|
||||
self.logger.info("Finished setting up %s", self.name)
|
||||
|
||||
async def stop(self):
|
||||
if self.subsocket:
|
||||
self.subsocket.close()
|
||||
if self.pubsocket:
|
||||
self.pubsocket.close()
|
||||
if self.repsocket:
|
||||
self.repsocket.close()
|
||||
await super().stop()
|
||||
|
||||
async def handle_message(self, msg: InternalMessage):
|
||||
"""
|
||||
Handle commands received from other internal Python agents.
|
||||
|
||||
Validates the message as a :class:`GestureCommand` and forwards it to the robot.
|
||||
|
||||
:param msg: The internal message containing the command.
|
||||
"""
|
||||
try:
|
||||
gesture_command = GestureCommand.model_validate_json(msg.body)
|
||||
if gesture_command.endpoint == RIEndpoint.GESTURE_TAG:
|
||||
if gesture_command.data not in self.gesture_data:
|
||||
self.logger.warning(
|
||||
"Received gesture tag '%s' which is not in available tags. Early returning",
|
||||
gesture_command.data,
|
||||
)
|
||||
return
|
||||
elif gesture_command.endpoint == RIEndpoint.GESTURE_SINGLE:
|
||||
if gesture_command.data not in self.single_gesture_data:
|
||||
self.logger.warning(
|
||||
"Received gesture '%s' which is not in available gestures. Early returning",
|
||||
gesture_command.data,
|
||||
)
|
||||
return
|
||||
experiment_logger.action("Gesture: %s", gesture_command.data)
|
||||
await self.pubsocket.send_json(gesture_command.model_dump())
|
||||
except Exception:
|
||||
self.logger.exception("Error processing internal message.")
|
||||
|
||||
async def _zmq_command_loop(self):
|
||||
"""
|
||||
Loop to handle commands received via ZMQ (e.g., from the UI).
|
||||
|
||||
Listens on the 'command' topic, validates the JSON and forwards it to the robot.
|
||||
"""
|
||||
while self._running:
|
||||
try:
|
||||
topic, body = await self.subsocket.recv_multipart()
|
||||
|
||||
# Don't process send_gestures here
|
||||
if topic != b"command":
|
||||
continue
|
||||
|
||||
body = json.loads(body)
|
||||
gesture_command = GestureCommand.model_validate(body)
|
||||
if gesture_command.endpoint == RIEndpoint.GESTURE_TAG:
|
||||
if gesture_command.data not in self.gesture_data:
|
||||
self.logger.warning(
|
||||
"Received gesture tag '%s' which is not in available tags.\
|
||||
Early returning",
|
||||
gesture_command.data,
|
||||
)
|
||||
continue
|
||||
await self.pubsocket.send_json(gesture_command.model_dump())
|
||||
except Exception:
|
||||
self.logger.exception("Error processing ZMQ message.")
|
||||
|
||||
async def _fetch_gestures_loop(self):
|
||||
"""
|
||||
Loop to handle fetching gestures received via ZMQ (e.g., from the UI).
|
||||
|
||||
Listens on the 'send_gestures' topic, and returns a list on the get_gestures topic.
|
||||
"""
|
||||
while self._running:
|
||||
try:
|
||||
# Get a request
|
||||
body = await self.repsocket.recv()
|
||||
|
||||
# Figure out amount, if specified
|
||||
try:
|
||||
body = json.loads(body)
|
||||
except json.JSONDecodeError:
|
||||
body = None
|
||||
|
||||
amount = None
|
||||
if isinstance(body, int):
|
||||
amount = body
|
||||
|
||||
# Fetch tags from gesture data and respond
|
||||
tags = self.gesture_data[:amount] if amount else self.gesture_data
|
||||
response = json.dumps({"tags": tags}).encode()
|
||||
await self.repsocket.send(response)
|
||||
|
||||
except Exception:
|
||||
self.logger.exception("Error fetching gesture tags.")
|
||||
103
src/control_backend/agents/actuation/robot_speech_agent.py
Normal file
103
src/control_backend/agents/actuation/robot_speech_agent.py
Normal file
@@ -0,0 +1,103 @@
|
||||
import json
|
||||
|
||||
import zmq
|
||||
import zmq.asyncio as azmq
|
||||
|
||||
from control_backend.agents import BaseAgent
|
||||
from control_backend.core.agent_system import InternalMessage
|
||||
from control_backend.core.config import settings
|
||||
from control_backend.schemas.ri_message import SpeechCommand
|
||||
|
||||
|
||||
class RobotSpeechAgent(BaseAgent):
|
||||
"""
|
||||
This agent acts as a bridge between the control backend and the Robot Interface (RI).
|
||||
It receives speech commands from other agents or from the UI,
|
||||
and forwards them to the robot via a ZMQ PUB socket.
|
||||
|
||||
:ivar subsocket: ZMQ SUB socket for receiving external commands (e.g., from UI).
|
||||
:ivar pubsocket: ZMQ PUB socket for sending commands to the Robot Interface.
|
||||
:ivar address: Address to bind/connect the PUB socket.
|
||||
:ivar bind: Whether to bind or connect the PUB socket.
|
||||
"""
|
||||
|
||||
subsocket: azmq.Socket
|
||||
pubsocket: azmq.Socket
|
||||
address = ""
|
||||
bind = False
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
address: str,
|
||||
bind=False,
|
||||
):
|
||||
super().__init__(name)
|
||||
self.address = address
|
||||
self.bind = bind
|
||||
|
||||
async def setup(self):
|
||||
"""
|
||||
Initialize the agent.
|
||||
|
||||
1. Sets up the PUB socket to talk to the robot.
|
||||
2. Sets up the SUB socket to listen for "command" topics (from UI/External).
|
||||
3. Starts the loop for handling ZMQ commands.
|
||||
"""
|
||||
self.logger.info("Setting up %s", self.name)
|
||||
|
||||
context = azmq.Context.instance()
|
||||
|
||||
# To the robot
|
||||
self.pubsocket = context.socket(zmq.PUB)
|
||||
if self.bind: # TODO: Should this ever be the case?
|
||||
self.pubsocket.bind(self.address)
|
||||
else:
|
||||
self.pubsocket.connect(self.address)
|
||||
|
||||
# Receive internal topics regarding commands
|
||||
self.subsocket = context.socket(zmq.SUB)
|
||||
self.subsocket.connect(settings.zmq_settings.internal_sub_address)
|
||||
self.subsocket.setsockopt(zmq.SUBSCRIBE, b"command")
|
||||
|
||||
self.add_behavior(self._zmq_command_loop())
|
||||
|
||||
self.logger.info("Finished setting up %s", self.name)
|
||||
|
||||
async def stop(self):
|
||||
if self.subsocket:
|
||||
self.subsocket.close()
|
||||
if self.pubsocket:
|
||||
self.pubsocket.close()
|
||||
await super().stop()
|
||||
|
||||
async def handle_message(self, msg: InternalMessage):
|
||||
"""
|
||||
Handle commands received from other internal Python agents.
|
||||
|
||||
Validates the message as a :class:`SpeechCommand` and forwards it to the robot.
|
||||
|
||||
:param msg: The internal message containing the command.
|
||||
"""
|
||||
try:
|
||||
speech_command = SpeechCommand.model_validate_json(msg.body)
|
||||
await self.pubsocket.send_json(speech_command.model_dump())
|
||||
except Exception:
|
||||
self.logger.exception("Error processing internal message.")
|
||||
|
||||
async def _zmq_command_loop(self):
|
||||
"""
|
||||
Loop to handle commands received via ZMQ (e.g., from the UI).
|
||||
|
||||
Listens on the 'command' topic, validates the JSON, and forwards it to the robot.
|
||||
"""
|
||||
while self._running:
|
||||
try:
|
||||
_, body = await self.subsocket.recv_multipart()
|
||||
|
||||
body = json.loads(body)
|
||||
message = SpeechCommand.model_validate(body)
|
||||
|
||||
await self.pubsocket.send_json(message.model_dump())
|
||||
except Exception:
|
||||
self.logger.exception("Error processing ZMQ message.")
|
||||
27
src/control_backend/agents/base.py
Normal file
27
src/control_backend/agents/base.py
Normal file
@@ -0,0 +1,27 @@
|
||||
import logging
|
||||
from abc import ABC
|
||||
|
||||
from control_backend.core.agent_system import BaseAgent as CoreBaseAgent
|
||||
|
||||
|
||||
class BaseAgent(CoreBaseAgent, ABC):
|
||||
"""
|
||||
The primary base class for all implementation agents.
|
||||
|
||||
Inherits from :class:`control_backend.core.agent_system.BaseAgent`.
|
||||
This class ensures that every agent instance is automatically equipped with a
|
||||
properly configured ``logger``.
|
||||
|
||||
:ivar logger: A logger instance named after the agent's package and class.
|
||||
"""
|
||||
|
||||
logger: logging.Logger
|
||||
|
||||
def __init_subclass__(cls, **kwargs) -> None:
|
||||
"""
|
||||
Whenever a subclass is initiated, give it the correct logger.
|
||||
:param kwargs: Keyword arguments for the subclass.
|
||||
"""
|
||||
super().__init_subclass__(**kwargs)
|
||||
|
||||
cls.logger = logging.getLogger(__package__).getChild(cls.__name__)
|
||||
10
src/control_backend/agents/bdi/__init__.py
Normal file
10
src/control_backend/agents/bdi/__init__.py
Normal file
@@ -0,0 +1,10 @@
|
||||
"""
|
||||
Agents and utilities for the BDI (Belief-Desire-Intention) reasoning system,
|
||||
implementing AgentSpeak(L) logic.
|
||||
"""
|
||||
|
||||
from control_backend.agents.bdi.bdi_core_agent import BDICoreAgent as BDICoreAgent
|
||||
|
||||
from .text_belief_extractor_agent import (
|
||||
TextBeliefExtractorAgent as TextBeliefExtractorAgent,
|
||||
)
|
||||
570
src/control_backend/agents/bdi/agentspeak_ast.py
Normal file
570
src/control_backend/agents/bdi/agentspeak_ast.py
Normal file
@@ -0,0 +1,570 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from enum import StrEnum
|
||||
|
||||
|
||||
class AstNode(ABC):
|
||||
"""
|
||||
Abstract base class for all elements of an AgentSpeak program.
|
||||
|
||||
This class serves as the foundation for all AgentSpeak abstract syntax tree (AST) nodes.
|
||||
It defines the core interface that all AST nodes must implement to generate AgentSpeak code.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def _to_agentspeak(self) -> str:
|
||||
"""
|
||||
Generates the AgentSpeak code string.
|
||||
|
||||
This method converts the AST node into its corresponding
|
||||
AgentSpeak source code representation.
|
||||
|
||||
:return: The AgentSpeak code string representation of this node.
|
||||
"""
|
||||
pass
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""
|
||||
Returns the string representation of this AST node.
|
||||
|
||||
This method provides a convenient way to get the AgentSpeak code representation
|
||||
by delegating to the _to_agentspeak method.
|
||||
|
||||
:return: The AgentSpeak code string representation of this node.
|
||||
"""
|
||||
return self._to_agentspeak()
|
||||
|
||||
|
||||
class AstExpression(AstNode, ABC):
|
||||
"""
|
||||
Intermediate class for anything that can be used in a logical expression.
|
||||
|
||||
This class extends AstNode to provide common functionality for all expressions
|
||||
that can be used in logical operations within AgentSpeak programs.
|
||||
"""
|
||||
|
||||
def __and__(self, other: ExprCoalescible) -> AstBinaryOp:
|
||||
"""
|
||||
Creates a logical AND operation between this expression and another.
|
||||
|
||||
This method allows for operator overloading of the & operator to create
|
||||
binary logical operations in a more intuitive syntax.
|
||||
|
||||
:param other: The right-hand side expression to combine with this one.
|
||||
:return: A new AstBinaryOp representing the logical AND operation.
|
||||
"""
|
||||
return AstBinaryOp(self, BinaryOperatorType.AND, _coalesce_expr(other))
|
||||
|
||||
def __or__(self, other: ExprCoalescible) -> AstBinaryOp:
|
||||
"""
|
||||
Creates a logical OR operation between this expression and another.
|
||||
|
||||
This method allows for operator overloading of the | operator to create
|
||||
binary logical operations in a more intuitive syntax.
|
||||
|
||||
:param other: The right-hand side expression to combine with this one.
|
||||
:return: A new AstBinaryOp representing the logical OR operation.
|
||||
"""
|
||||
return AstBinaryOp(self, BinaryOperatorType.OR, _coalesce_expr(other))
|
||||
|
||||
def __invert__(self) -> AstLogicalExpression:
|
||||
"""
|
||||
Creates a logical negation of this expression.
|
||||
|
||||
This method allows for operator overloading of the ~ operator to create
|
||||
negated expressions. If the expression is already a logical expression,
|
||||
it toggles the negation flag. Otherwise, it wraps the expression in a
|
||||
new AstLogicalExpression with negation set to True.
|
||||
|
||||
:return: An AstLogicalExpression representing the negated form of this expression.
|
||||
"""
|
||||
if isinstance(self, AstLogicalExpression):
|
||||
self.negated = not self.negated
|
||||
return self
|
||||
return AstLogicalExpression(self, negated=True)
|
||||
|
||||
|
||||
type ExprCoalescible = AstExpression | str | int | float
|
||||
|
||||
|
||||
def _coalesce_expr(value: ExprCoalescible) -> AstExpression:
|
||||
if isinstance(value, AstExpression):
|
||||
return value
|
||||
if isinstance(value, str):
|
||||
return AstString(value)
|
||||
if isinstance(value, (int, float)):
|
||||
return AstNumber(value)
|
||||
raise TypeError(f"Cannot coalesce type {type(value)} into an AstTerm.")
|
||||
|
||||
|
||||
@dataclass
|
||||
class AstTerm(AstExpression, ABC):
|
||||
"""
|
||||
Base class for terms appearing inside literals.
|
||||
"""
|
||||
|
||||
def __ge__(self, other: ExprCoalescible) -> AstBinaryOp:
|
||||
return AstBinaryOp(self, BinaryOperatorType.GREATER_EQUALS, _coalesce_expr(other))
|
||||
|
||||
def __gt__(self, other: ExprCoalescible) -> AstBinaryOp:
|
||||
return AstBinaryOp(self, BinaryOperatorType.GREATER_THAN, _coalesce_expr(other))
|
||||
|
||||
def __le__(self, other: ExprCoalescible) -> AstBinaryOp:
|
||||
return AstBinaryOp(self, BinaryOperatorType.LESS_EQUALS, _coalesce_expr(other))
|
||||
|
||||
def __lt__(self, other: ExprCoalescible) -> AstBinaryOp:
|
||||
return AstBinaryOp(self, BinaryOperatorType.LESS_THAN, _coalesce_expr(other))
|
||||
|
||||
def __eq__(self, other: ExprCoalescible) -> AstBinaryOp:
|
||||
return AstBinaryOp(self, BinaryOperatorType.EQUALS, _coalesce_expr(other))
|
||||
|
||||
def __ne__(self, other: ExprCoalescible) -> AstBinaryOp:
|
||||
return AstBinaryOp(self, BinaryOperatorType.NOT_EQUALS, _coalesce_expr(other))
|
||||
|
||||
|
||||
@dataclass(eq=False)
|
||||
class AstAtom(AstTerm):
|
||||
"""
|
||||
Represents a grounded atom in AgentSpeak (e.g., lowercase constants).
|
||||
|
||||
Atoms are the simplest form of terms in AgentSpeak, representing concrete,
|
||||
unchanging values. They are typically used as constants in logical expressions.
|
||||
|
||||
:ivar value: The string value of this atom, which will be converted to lowercase
|
||||
in the AgentSpeak representation.
|
||||
"""
|
||||
|
||||
value: str
|
||||
|
||||
def _to_agentspeak(self) -> str:
|
||||
"""
|
||||
Converts this atom to its AgentSpeak string representation.
|
||||
|
||||
Atoms are represented in lowercase in AgentSpeak to distinguish them
|
||||
from variables (which are capitalized).
|
||||
|
||||
:return: The lowercase string representation of this atom.
|
||||
"""
|
||||
return self.value.lower()
|
||||
|
||||
|
||||
@dataclass(eq=False)
|
||||
class AstVar(AstTerm):
|
||||
"""
|
||||
Represents an ungrounded variable in AgentSpeak (e.g., capitalized names).
|
||||
|
||||
Variables in AgentSpeak are placeholders that can be bound to specific values
|
||||
during execution. They are distinguished from atoms by their capitalization.
|
||||
|
||||
:ivar name: The name of this variable, which will be capitalized in the
|
||||
AgentSpeak representation.
|
||||
"""
|
||||
|
||||
name: str
|
||||
|
||||
def _to_agentspeak(self) -> str:
|
||||
"""
|
||||
Converts this variable to its AgentSpeak string representation.
|
||||
|
||||
Variables are represented with capitalized names in AgentSpeak to distinguish
|
||||
them from atoms (which are lowercase).
|
||||
|
||||
:return: The capitalized string representation of this variable.
|
||||
"""
|
||||
return self.name.capitalize()
|
||||
|
||||
|
||||
@dataclass(eq=False)
|
||||
class AstNumber(AstTerm):
|
||||
"""
|
||||
Represents a numeric constant in AgentSpeak.
|
||||
|
||||
Numeric constants can be either integers or floating-point numbers and are
|
||||
used in logical expressions and comparisons.
|
||||
|
||||
:ivar value: The numeric value of this constant (can be int or float).
|
||||
"""
|
||||
|
||||
value: int | float
|
||||
|
||||
def _to_agentspeak(self) -> str:
|
||||
"""
|
||||
Converts this numeric constant to its AgentSpeak string representation.
|
||||
|
||||
:return: The string representation of the numeric value.
|
||||
"""
|
||||
return str(self.value)
|
||||
|
||||
|
||||
@dataclass(eq=False)
|
||||
class AstString(AstTerm):
|
||||
"""
|
||||
Represents a string literal in AgentSpeak.
|
||||
|
||||
String literals are used to represent textual data and are enclosed in
|
||||
double quotes in the AgentSpeak representation.
|
||||
|
||||
:ivar value: The string content of this literal.
|
||||
"""
|
||||
|
||||
value: str
|
||||
|
||||
def _to_agentspeak(self) -> str:
|
||||
"""
|
||||
Converts this string literal to its AgentSpeak string representation.
|
||||
|
||||
String literals are enclosed in double quotes in AgentSpeak.
|
||||
|
||||
:return: The string literal enclosed in double quotes.
|
||||
"""
|
||||
return f'"{self.value}"'
|
||||
|
||||
|
||||
@dataclass(eq=False)
|
||||
class AstLiteral(AstTerm):
|
||||
"""
|
||||
Represents a literal (functor and terms) in AgentSpeak.
|
||||
|
||||
Literals are the fundamental building blocks of AgentSpeak programs, consisting
|
||||
of a functor (predicate name) and an optional list of terms (arguments).
|
||||
|
||||
:ivar functor: The name of the predicate or function.
|
||||
:ivar terms: A list of terms (arguments) for this literal. Defaults to an empty list.
|
||||
"""
|
||||
|
||||
functor: str
|
||||
terms: list[AstTerm] = field(default_factory=list)
|
||||
|
||||
def _to_agentspeak(self) -> str:
|
||||
"""
|
||||
Converts this literal to its AgentSpeak string representation.
|
||||
|
||||
If the literal has no terms, it returns just the functor name.
|
||||
Otherwise, it returns the functor followed by the terms in parentheses.
|
||||
|
||||
:return: The AgentSpeak string representation of this literal.
|
||||
"""
|
||||
if not self.terms:
|
||||
return self.functor
|
||||
args = ", ".join(map(str, self.terms))
|
||||
return f"{self.functor}({args})"
|
||||
|
||||
|
||||
class BinaryOperatorType(StrEnum):
|
||||
"""
|
||||
Enumeration of binary operator types used in AgentSpeak expressions.
|
||||
|
||||
These operators are used to create binary operations between expressions,
|
||||
including logical operations (AND, OR) and comparison operations.
|
||||
"""
|
||||
|
||||
AND = "&"
|
||||
OR = "|"
|
||||
GREATER_THAN = ">"
|
||||
LESS_THAN = "<"
|
||||
EQUALS = "=="
|
||||
NOT_EQUALS = "\\=="
|
||||
GREATER_EQUALS = ">="
|
||||
LESS_EQUALS = "<="
|
||||
|
||||
|
||||
@dataclass
|
||||
class AstBinaryOp(AstExpression):
|
||||
"""
|
||||
Represents a binary logical or relational operation in AgentSpeak.
|
||||
|
||||
Binary operations combine two expressions using a logical or comparison operator.
|
||||
They are used to create complex logical conditions in AgentSpeak programs.
|
||||
|
||||
:ivar left: The left-hand side expression of the operation.
|
||||
:ivar operator: The binary operator type (AND, OR, comparison operators, etc.).
|
||||
:ivar right: The right-hand side expression of the operation.
|
||||
"""
|
||||
|
||||
left: AstExpression
|
||||
operator: BinaryOperatorType
|
||||
right: AstExpression
|
||||
|
||||
def __post_init__(self):
|
||||
"""
|
||||
Post-initialization processing to ensure proper expression types.
|
||||
|
||||
This method wraps the left and right expressions in AstLogicalExpression
|
||||
instances if they aren't already, ensuring consistent handling throughout
|
||||
the AST.
|
||||
"""
|
||||
self.left = _as_logical(self.left)
|
||||
self.right = _as_logical(self.right)
|
||||
|
||||
def _to_agentspeak(self) -> str:
|
||||
"""
|
||||
Converts this binary operation to its AgentSpeak string representation.
|
||||
|
||||
The method handles proper parenthesization of sub-expressions to maintain
|
||||
correct operator precedence and readability.
|
||||
|
||||
:return: The AgentSpeak string representation of this binary operation.
|
||||
"""
|
||||
l_str = str(self.left)
|
||||
r_str = str(self.right)
|
||||
|
||||
assert isinstance(self.left, AstLogicalExpression)
|
||||
assert isinstance(self.right, AstLogicalExpression)
|
||||
|
||||
if isinstance(self.left.expression, AstBinaryOp) or self.left.negated:
|
||||
l_str = f"({l_str})"
|
||||
if isinstance(self.right.expression, AstBinaryOp) or self.right.negated:
|
||||
r_str = f"({r_str})"
|
||||
|
||||
return f"{l_str} {self.operator.value} {r_str}"
|
||||
|
||||
|
||||
@dataclass
|
||||
class AstLogicalExpression(AstExpression):
|
||||
"""
|
||||
Represents a logical expression, potentially negated, in AgentSpeak.
|
||||
|
||||
Logical expressions can be either positive or negated and form the basis
|
||||
of conditions and beliefs in AgentSpeak programs.
|
||||
|
||||
:ivar expression: The underlying expression being evaluated.
|
||||
:ivar negated: Boolean flag indicating whether this expression is negated.
|
||||
"""
|
||||
|
||||
expression: AstExpression
|
||||
negated: bool = False
|
||||
|
||||
def _to_agentspeak(self) -> str:
|
||||
"""
|
||||
Converts this logical expression to its AgentSpeak string representation.
|
||||
|
||||
If the expression is negated, it prepends 'not ' to the expression string.
|
||||
For complex expressions (binary operations), it adds parentheses when negated
|
||||
to maintain correct logical interpretation.
|
||||
|
||||
:return: The AgentSpeak string representation of this logical expression.
|
||||
"""
|
||||
expr_str = str(self.expression)
|
||||
if isinstance(self.expression, AstBinaryOp) and self.negated:
|
||||
expr_str = f"({expr_str})"
|
||||
return f"{'not ' if self.negated else ''}{expr_str}"
|
||||
|
||||
|
||||
def _as_logical(expr: AstExpression) -> AstLogicalExpression:
|
||||
"""
|
||||
Converts an expression to a logical expression if it isn't already.
|
||||
|
||||
This helper function ensures that expressions are properly wrapped in
|
||||
AstLogicalExpression instances, which is necessary for consistent handling
|
||||
of logical operations in the AST.
|
||||
|
||||
:param expr: The expression to convert.
|
||||
:return: The expression wrapped in an AstLogicalExpression if it wasn't already.
|
||||
"""
|
||||
if isinstance(expr, AstLogicalExpression):
|
||||
return expr
|
||||
return AstLogicalExpression(expr)
|
||||
|
||||
|
||||
class StatementType(StrEnum):
|
||||
"""
|
||||
Enumeration of statement types that can appear in AgentSpeak plans.
|
||||
|
||||
These statement types define the different kinds of actions and operations
|
||||
that can be performed within the body of an AgentSpeak plan.
|
||||
"""
|
||||
|
||||
EMPTY = ""
|
||||
"""Empty statement (no operation, used when evaluating a plan to true)."""
|
||||
|
||||
DO_ACTION = "."
|
||||
"""Execute an action defined in Python."""
|
||||
|
||||
ACHIEVE_GOAL = "!"
|
||||
"""Achieve a goal (add a goal to be accomplished)."""
|
||||
|
||||
TEST_GOAL = "?"
|
||||
"""Test a goal (check if a goal can be achieved)."""
|
||||
|
||||
ADD_BELIEF = "+"
|
||||
"""Add a belief to the belief base."""
|
||||
|
||||
REMOVE_BELIEF = "-"
|
||||
"""Remove a belief from the belief base."""
|
||||
|
||||
REPLACE_BELIEF = "-+"
|
||||
"""Replace a belief in the belief base."""
|
||||
|
||||
|
||||
@dataclass
|
||||
class AstStatement(AstNode):
|
||||
"""
|
||||
A statement that can appear inside a plan.
|
||||
|
||||
Statements are the executable units within AgentSpeak plans. They consist
|
||||
of a statement type (defining the operation) and an expression (defining
|
||||
what to operate on).
|
||||
|
||||
:ivar type: The type of statement (action, goal, belief operation, etc.).
|
||||
:ivar expression: The expression that this statement operates on.
|
||||
"""
|
||||
|
||||
type: StatementType
|
||||
expression: AstExpression
|
||||
|
||||
def _to_agentspeak(self) -> str:
|
||||
"""
|
||||
Converts this statement to its AgentSpeak string representation.
|
||||
|
||||
The representation consists of the statement type prefix followed by
|
||||
the expression.
|
||||
|
||||
:return: The AgentSpeak string representation of this statement.
|
||||
"""
|
||||
return f"{self.type.value}{self.expression}"
|
||||
|
||||
|
||||
@dataclass
|
||||
class AstRule(AstNode):
|
||||
"""
|
||||
Represents an inference rule in AgentSpeak. If there is no condition, it always holds.
|
||||
|
||||
Rules define logical implications in AgentSpeak programs. They consist of a
|
||||
result (conclusion) and an optional condition (premise). When the condition
|
||||
holds, the result is inferred to be true.
|
||||
|
||||
:ivar result: The conclusion or result of this rule.
|
||||
:ivar condition: The premise or condition for this rule (optional).
|
||||
"""
|
||||
|
||||
result: AstExpression
|
||||
condition: AstExpression | None = None
|
||||
|
||||
def __post_init__(self):
|
||||
"""
|
||||
Post-initialization processing to ensure proper expression types.
|
||||
|
||||
If a condition is provided, this method wraps it in an AstLogicalExpression
|
||||
to ensure consistent handling throughout the AST.
|
||||
"""
|
||||
if self.condition is not None:
|
||||
self.condition = _as_logical(self.condition)
|
||||
|
||||
def _to_agentspeak(self) -> str:
|
||||
"""
|
||||
Converts this rule to its AgentSpeak string representation.
|
||||
|
||||
If no condition is specified, the rule is represented as a simple fact.
|
||||
If a condition is specified, it's represented as an implication (result :- condition).
|
||||
|
||||
:return: The AgentSpeak string representation of this rule.
|
||||
"""
|
||||
if not self.condition:
|
||||
return f"{self.result}."
|
||||
return f"{self.result} :- {self.condition}."
|
||||
|
||||
|
||||
class TriggerType(StrEnum):
|
||||
"""
|
||||
Enumeration of trigger types for AgentSpeak plans.
|
||||
|
||||
Trigger types define what kind of events can activate an AgentSpeak plan.
|
||||
Currently, the system supports triggers for added beliefs and added goals.
|
||||
"""
|
||||
|
||||
ADDED_BELIEF = "+"
|
||||
"""Trigger when a belief is added to the belief base."""
|
||||
|
||||
# REMOVED_BELIEF = "-" # TODO
|
||||
# MODIFIED_BELIEF = "^" # TODO
|
||||
|
||||
ADDED_GOAL = "+!"
|
||||
"""Trigger when a goal is added to be achieved."""
|
||||
|
||||
# REMOVED_GOAL = "-!" # TODO
|
||||
|
||||
|
||||
@dataclass
|
||||
class AstPlan(AstNode):
|
||||
"""
|
||||
Represents a plan in AgentSpeak, consisting of a trigger, context, and body.
|
||||
|
||||
Plans define the reactive behavior of agents in AgentSpeak. They specify what
|
||||
actions to take when certain conditions are met (trigger and context).
|
||||
|
||||
:ivar type: The type of trigger that activates this plan.
|
||||
:ivar trigger_literal: The specific event or condition that triggers this plan.
|
||||
:ivar context: A list of conditions that must hold for this plan to be applicable.
|
||||
:ivar body: A list of statements to execute when this plan is triggered.
|
||||
"""
|
||||
|
||||
type: TriggerType
|
||||
trigger_literal: AstExpression
|
||||
context: list[AstExpression]
|
||||
body: list[AstStatement]
|
||||
|
||||
def _to_agentspeak(self) -> str:
|
||||
"""
|
||||
Converts this plan to its AgentSpeak string representation.
|
||||
|
||||
The representation follows the standard AgentSpeak plan format:
|
||||
trigger_type + trigger_literal
|
||||
: context_conditions
|
||||
<- body_statements.
|
||||
|
||||
:return: The AgentSpeak string representation of this plan.
|
||||
"""
|
||||
assert isinstance(self.trigger_literal, AstLiteral)
|
||||
|
||||
indent = " " * 6
|
||||
colon = " : "
|
||||
arrow = " <- "
|
||||
|
||||
lines = []
|
||||
|
||||
lines.append(f"{self.type.value}{self.trigger_literal}")
|
||||
|
||||
if self.context:
|
||||
lines.append(colon + f" &\n{indent}".join(str(c) for c in self.context))
|
||||
|
||||
if self.body:
|
||||
lines.append(arrow + f";\n{indent}".join(str(s) for s in self.body) + ".")
|
||||
|
||||
lines.append("")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AstProgram(AstNode):
|
||||
"""
|
||||
Represents a full AgentSpeak program, consisting of rules and plans.
|
||||
|
||||
This is the root node of the AgentSpeak AST, containing all the rules
|
||||
and plans that define the agent's behavior.
|
||||
|
||||
:ivar rules: A list of inference rules in this program.
|
||||
:ivar plans: A list of reactive plans in this program.
|
||||
"""
|
||||
|
||||
rules: list[AstRule] = field(default_factory=list)
|
||||
plans: list[AstPlan] = field(default_factory=list)
|
||||
|
||||
def _to_agentspeak(self) -> str:
|
||||
"""
|
||||
Converts this program to its AgentSpeak string representation.
|
||||
|
||||
The representation consists of all rules followed by all plans,
|
||||
separated by blank lines for readability.
|
||||
|
||||
:return: The complete AgentSpeak source code for this program.
|
||||
"""
|
||||
lines = []
|
||||
lines.extend(map(str, self.rules))
|
||||
|
||||
lines.extend(["", ""])
|
||||
lines.extend(map(str, self.plans))
|
||||
|
||||
return "\n".join(lines)
|
||||
881
src/control_backend/agents/bdi/agentspeak_generator.py
Normal file
881
src/control_backend/agents/bdi/agentspeak_generator.py
Normal file
@@ -0,0 +1,881 @@
|
||||
from functools import singledispatchmethod
|
||||
|
||||
from slugify import slugify
|
||||
|
||||
from control_backend.agents.bdi.agentspeak_ast import (
|
||||
AstAtom,
|
||||
AstBinaryOp,
|
||||
AstExpression,
|
||||
AstLiteral,
|
||||
AstNumber,
|
||||
AstPlan,
|
||||
AstProgram,
|
||||
AstRule,
|
||||
AstStatement,
|
||||
AstString,
|
||||
AstVar,
|
||||
BinaryOperatorType,
|
||||
StatementType,
|
||||
TriggerType,
|
||||
)
|
||||
from control_backend.core.config import settings
|
||||
from control_backend.schemas.program import (
|
||||
BaseGoal,
|
||||
BasicNorm,
|
||||
ConditionalNorm,
|
||||
GestureAction,
|
||||
Goal,
|
||||
InferredBelief,
|
||||
KeywordBelief,
|
||||
LLMAction,
|
||||
LogicalOperator,
|
||||
Norm,
|
||||
Phase,
|
||||
PlanElement,
|
||||
Program,
|
||||
ProgramElement,
|
||||
SemanticBelief,
|
||||
SpeechAction,
|
||||
Trigger,
|
||||
)
|
||||
|
||||
|
||||
class AgentSpeakGenerator:
|
||||
"""
|
||||
Generator class that translates a high-level :class:`~control_backend.schemas.program.Program`
|
||||
into AgentSpeak(L) source code.
|
||||
|
||||
It handles the conversion of phases, norms, goals, and triggers into AgentSpeak rules and plans,
|
||||
ensuring the robot follows the defined behavioral logic.
|
||||
|
||||
The generator follows a systematic approach:
|
||||
1. Sets up initial phase and cycle notification rules
|
||||
2. Adds keyword inference capabilities for natural language processing
|
||||
3. Creates default plans for common operations
|
||||
4. Processes each phase with its norms, goals, and triggers
|
||||
5. Adds fallback plans for robust execution
|
||||
|
||||
:ivar _asp: The internal AgentSpeak program representation being built.
|
||||
"""
|
||||
|
||||
_asp: AstProgram
|
||||
|
||||
def generate(self, program: Program) -> str:
|
||||
"""
|
||||
Translates a Program object into an AgentSpeak source string.
|
||||
|
||||
This is the main entry point for the code generation process. It initializes
|
||||
the AgentSpeak program structure and orchestrates the conversion of all
|
||||
program elements into their AgentSpeak representations.
|
||||
|
||||
:param program: The behavior program to translate.
|
||||
:return: The generated AgentSpeak code as a string.
|
||||
"""
|
||||
self._asp = AstProgram()
|
||||
|
||||
if program.phases:
|
||||
self._asp.rules.append(AstRule(self._astify(program.phases[0])))
|
||||
else:
|
||||
self._asp.rules.append(AstRule(AstLiteral("phase", [AstString("end")])))
|
||||
|
||||
self._asp.rules.append(AstRule(AstLiteral("!notify_cycle")))
|
||||
|
||||
self._add_keyword_inference()
|
||||
self._add_default_plans()
|
||||
|
||||
self._process_phases(program.phases)
|
||||
|
||||
self._add_fallbacks()
|
||||
|
||||
return str(self._asp)
|
||||
|
||||
def _add_keyword_inference(self) -> None:
|
||||
"""
|
||||
Adds inference rules for keyword detection in user messages.
|
||||
|
||||
This method creates rules that allow the system to detect when specific
|
||||
keywords are mentioned in user messages. It uses string operations to
|
||||
check if a keyword is a substring of the user's message.
|
||||
|
||||
The generated rule has the form:
|
||||
keyword_said(Keyword) :- user_said(Message) & .substring(Keyword, Message, Pos) & Pos >= 0
|
||||
|
||||
This enables the system to trigger behaviors based on keyword detection.
|
||||
"""
|
||||
keyword = AstVar("Keyword")
|
||||
message = AstVar("Message")
|
||||
position = AstVar("Pos")
|
||||
|
||||
self._asp.rules.append(
|
||||
AstRule(
|
||||
AstLiteral("keyword_said", [keyword]),
|
||||
AstLiteral("user_said", [message])
|
||||
& AstLiteral(".substring", [keyword, message, position])
|
||||
& (position >= 0),
|
||||
)
|
||||
)
|
||||
|
||||
def _add_default_plans(self):
|
||||
"""
|
||||
Adds default plans for common operations.
|
||||
|
||||
This method sets up the standard plans that handle fundamental operations
|
||||
like replying with goals, simple speech actions, general replies, and
|
||||
cycle notifications. These plans provide the basic infrastructure for
|
||||
the agent's reactive behavior.
|
||||
"""
|
||||
self._add_reply_with_goal_plan()
|
||||
self._add_say_plan()
|
||||
self._add_reply_plan()
|
||||
self._add_notify_cycle_plan()
|
||||
|
||||
def _add_reply_with_goal_plan(self):
|
||||
"""
|
||||
Adds a plan for replying with a specific conversational goal.
|
||||
|
||||
This plan handles the case where the agent needs to respond to user input
|
||||
while pursuing a specific conversational goal. It:
|
||||
1. Marks that the agent has responded this turn
|
||||
2. Gathers all active norms
|
||||
3. Generates a reply that considers both the user message and the goal
|
||||
|
||||
Trigger: +!reply_with_goal(Goal)
|
||||
Context: user_said(Message)
|
||||
"""
|
||||
self._asp.plans.append(
|
||||
AstPlan(
|
||||
TriggerType.ADDED_GOAL,
|
||||
AstLiteral("reply_with_goal", [AstVar("Goal")]),
|
||||
[AstLiteral("user_said", [AstVar("Message")])],
|
||||
[
|
||||
AstStatement(StatementType.ADD_BELIEF, AstLiteral("responded_this_turn")),
|
||||
AstStatement(
|
||||
StatementType.DO_ACTION,
|
||||
AstLiteral(
|
||||
"findall",
|
||||
[AstVar("Norm"), AstLiteral("norm", [AstVar("Norm")]), AstVar("Norms")],
|
||||
),
|
||||
),
|
||||
AstStatement(
|
||||
StatementType.DO_ACTION,
|
||||
AstLiteral(
|
||||
"reply_with_goal", [AstVar("Message"), AstVar("Norms"), AstVar("Goal")]
|
||||
),
|
||||
),
|
||||
],
|
||||
)
|
||||
)
|
||||
|
||||
def _add_say_plan(self):
|
||||
"""
|
||||
Adds a plan for simple speech actions.
|
||||
|
||||
This plan handles direct speech actions where the agent needs to say
|
||||
a specific text. It:
|
||||
1. Marks that the agent has responded this turn
|
||||
2. Executes the speech action
|
||||
|
||||
Trigger: +!say(Text)
|
||||
Context: None (can be executed anytime)
|
||||
"""
|
||||
self._asp.plans.append(
|
||||
AstPlan(
|
||||
TriggerType.ADDED_GOAL,
|
||||
AstLiteral("say", [AstVar("Text")]),
|
||||
[],
|
||||
[
|
||||
AstStatement(StatementType.ADD_BELIEF, AstLiteral("responded_this_turn")),
|
||||
AstStatement(StatementType.DO_ACTION, AstLiteral("say", [AstVar("Text")])),
|
||||
],
|
||||
)
|
||||
)
|
||||
|
||||
def _add_reply_plan(self):
|
||||
"""
|
||||
Adds a plan for general reply actions.
|
||||
|
||||
This plan handles general reply actions where the agent needs to respond
|
||||
to user input without a specific conversational goal. It:
|
||||
1. Marks that the agent has responded this turn
|
||||
2. Gathers all active norms
|
||||
3. Generates a reply based on the user message and norms
|
||||
|
||||
Trigger: +!reply
|
||||
Context: user_said(Message)
|
||||
"""
|
||||
self._asp.plans.append(
|
||||
AstPlan(
|
||||
TriggerType.ADDED_GOAL,
|
||||
AstLiteral("reply"),
|
||||
[AstLiteral("user_said", [AstVar("Message")])],
|
||||
[
|
||||
AstStatement(StatementType.ADD_BELIEF, AstLiteral("responded_this_turn")),
|
||||
AstStatement(
|
||||
StatementType.DO_ACTION,
|
||||
AstLiteral(
|
||||
"findall",
|
||||
[AstVar("Norm"), AstLiteral("norm", [AstVar("Norm")]), AstVar("Norms")],
|
||||
),
|
||||
),
|
||||
AstStatement(
|
||||
StatementType.DO_ACTION,
|
||||
AstLiteral("reply", [AstVar("Message"), AstVar("Norms")]),
|
||||
),
|
||||
],
|
||||
)
|
||||
)
|
||||
|
||||
def _add_notify_cycle_plan(self):
|
||||
"""
|
||||
Adds a plan for cycle notification.
|
||||
|
||||
This plan handles the periodic notification cycle that allows the system
|
||||
to monitor and report on the current state. It:
|
||||
1. Gathers all active norms
|
||||
2. Notifies the system about the current norms
|
||||
3. Waits briefly to allow processing
|
||||
4. Recursively triggers the next cycle
|
||||
|
||||
Trigger: +!notify_cycle
|
||||
Context: None (can be executed anytime)
|
||||
"""
|
||||
self._asp.plans.append(
|
||||
AstPlan(
|
||||
TriggerType.ADDED_GOAL,
|
||||
AstLiteral("notify_cycle"),
|
||||
[],
|
||||
[
|
||||
AstStatement(
|
||||
StatementType.DO_ACTION,
|
||||
AstLiteral(
|
||||
"findall",
|
||||
[AstVar("Norm"), AstLiteral("norm", [AstVar("Norm")]), AstVar("Norms")],
|
||||
),
|
||||
),
|
||||
AstStatement(
|
||||
StatementType.DO_ACTION, AstLiteral("notify_norms", [AstVar("Norms")])
|
||||
),
|
||||
AstStatement(StatementType.DO_ACTION, AstLiteral("wait", [AstNumber(100)])),
|
||||
AstStatement(StatementType.ACHIEVE_GOAL, AstLiteral("notify_cycle")),
|
||||
],
|
||||
)
|
||||
)
|
||||
|
||||
def _process_phases(self, phases: list[Phase]) -> None:
|
||||
"""
|
||||
Processes all phases in the program and their transitions.
|
||||
|
||||
This method iterates through each phase and:
|
||||
1. Processes the current phase (norms, goals, triggers)
|
||||
2. Sets up transitions between phases
|
||||
3. Adds special handling for the end phase
|
||||
|
||||
:param phases: The list of phases to process.
|
||||
"""
|
||||
for curr_phase, next_phase in zip([None] + phases, phases + [None], strict=True):
|
||||
if curr_phase:
|
||||
self._process_phase(curr_phase)
|
||||
self._add_phase_transition(curr_phase, next_phase)
|
||||
|
||||
# End phase behavior
|
||||
# When deleting this, the entire `reply` plan and action can be deleted
|
||||
self._asp.plans.append(
|
||||
AstPlan(
|
||||
type=TriggerType.ADDED_BELIEF,
|
||||
trigger_literal=AstLiteral("user_said", [AstVar("Message")]),
|
||||
context=[AstLiteral("phase", [AstString("end")])],
|
||||
body=[
|
||||
AstStatement(
|
||||
StatementType.DO_ACTION, AstLiteral("notify_user_said", [AstVar("Message")])
|
||||
),
|
||||
AstStatement(StatementType.ACHIEVE_GOAL, AstLiteral("reply")),
|
||||
],
|
||||
)
|
||||
)
|
||||
|
||||
def _process_phase(self, phase: Phase) -> None:
|
||||
"""
|
||||
Processes a single phase, including its norms, goals, and triggers.
|
||||
|
||||
This method handles the complete processing of a phase by:
|
||||
1. Processing all norms in the phase
|
||||
2. Setting up the default execution loop for the phase
|
||||
3. Processing all goals in sequence
|
||||
4. Processing all triggers for reactive behavior
|
||||
|
||||
:param phase: The phase to process.
|
||||
"""
|
||||
for norm in phase.norms:
|
||||
self._process_norm(norm, phase)
|
||||
|
||||
self._add_default_loop(phase)
|
||||
|
||||
previous_goal = None
|
||||
for goal in phase.goals:
|
||||
self._process_goal(goal, phase, previous_goal, main_goal=True)
|
||||
previous_goal = goal
|
||||
|
||||
for trigger in phase.triggers:
|
||||
self._process_trigger(trigger, phase)
|
||||
|
||||
def _add_phase_transition(self, from_phase: Phase | None, to_phase: Phase | None) -> None:
|
||||
"""
|
||||
Adds plans for transitioning between phases.
|
||||
|
||||
This method creates two plans for each phase transition:
|
||||
1. A check plan that verifies if transition conditions are met
|
||||
2. A force plan that actually performs the transition (can be forced externally)
|
||||
|
||||
The transition involves:
|
||||
- Notifying the system about the phase change
|
||||
- Removing the current phase belief
|
||||
- Adding the next phase belief
|
||||
|
||||
:param from_phase: The phase being transitioned from (or None for initial setup).
|
||||
:param to_phase: The phase being transitioned to (or None for end phase).
|
||||
"""
|
||||
if from_phase is None:
|
||||
return
|
||||
from_phase_ast = self._astify(from_phase)
|
||||
to_phase_ast = (
|
||||
self._astify(to_phase) if to_phase else AstLiteral("phase", [AstString("end")])
|
||||
)
|
||||
|
||||
check_context = [from_phase_ast]
|
||||
if from_phase:
|
||||
for goal in from_phase.goals:
|
||||
check_context.append(self._astify(goal, achieved=True))
|
||||
|
||||
force_context = [from_phase_ast]
|
||||
|
||||
body = [
|
||||
AstStatement(
|
||||
StatementType.DO_ACTION,
|
||||
AstLiteral(
|
||||
"notify_transition_phase",
|
||||
[
|
||||
AstString(str(from_phase.id)),
|
||||
AstString(str(to_phase.id) if to_phase else "end"),
|
||||
],
|
||||
),
|
||||
),
|
||||
AstStatement(StatementType.REMOVE_BELIEF, from_phase_ast),
|
||||
AstStatement(StatementType.ADD_BELIEF, to_phase_ast),
|
||||
]
|
||||
|
||||
# Check
|
||||
self._asp.plans.append(
|
||||
AstPlan(
|
||||
TriggerType.ADDED_GOAL,
|
||||
AstLiteral("transition_phase"),
|
||||
check_context,
|
||||
[
|
||||
AstStatement(StatementType.ACHIEVE_GOAL, AstLiteral("force_transition_phase")),
|
||||
],
|
||||
)
|
||||
)
|
||||
|
||||
# Force
|
||||
self._asp.plans.append(
|
||||
AstPlan(
|
||||
TriggerType.ADDED_GOAL, AstLiteral("force_transition_phase"), force_context, body
|
||||
)
|
||||
)
|
||||
|
||||
def _process_norm(self, norm: Norm, phase: Phase) -> None:
|
||||
"""
|
||||
Processes a norm and adds it as an inference rule.
|
||||
|
||||
This method converts norms into AgentSpeak rules that define when
|
||||
the norm should be active. It handles both basic norms (always active
|
||||
in their phase) and conditional norms (active only when their condition
|
||||
is met).
|
||||
|
||||
:param norm: The norm to process.
|
||||
:param phase: The phase this norm belongs to.
|
||||
"""
|
||||
rule: AstRule | None = None
|
||||
|
||||
match norm:
|
||||
case ConditionalNorm(condition=cond):
|
||||
rule = AstRule(
|
||||
self._astify(norm),
|
||||
self._astify(phase) & self._astify(cond)
|
||||
| AstAtom(f"force_{self.slugify(norm)}"),
|
||||
)
|
||||
case BasicNorm():
|
||||
rule = AstRule(self._astify(norm), self._astify(phase))
|
||||
|
||||
if not rule:
|
||||
return
|
||||
|
||||
self._asp.rules.append(rule)
|
||||
|
||||
def _add_default_loop(self, phase: Phase) -> None:
|
||||
"""
|
||||
Adds the default execution loop for a phase.
|
||||
|
||||
This method creates the main reactive loop that runs when the agent
|
||||
receives user input during a phase. The loop:
|
||||
1. Notifies the system about the user input
|
||||
2. Resets the response tracking
|
||||
3. Executes all phase goals
|
||||
4. Attempts phase transition
|
||||
|
||||
:param phase: The phase to create the loop for.
|
||||
"""
|
||||
actions = []
|
||||
|
||||
actions.append(
|
||||
AstStatement(
|
||||
StatementType.DO_ACTION, AstLiteral("notify_user_said", [AstVar("Message")])
|
||||
)
|
||||
)
|
||||
actions.append(AstStatement(StatementType.REMOVE_BELIEF, AstLiteral("responded_this_turn")))
|
||||
|
||||
for goal in phase.goals:
|
||||
actions.append(AstStatement(StatementType.ACHIEVE_GOAL, self._astify(goal)))
|
||||
|
||||
actions.append(AstStatement(StatementType.ACHIEVE_GOAL, AstLiteral("transition_phase")))
|
||||
|
||||
self._asp.plans.append(
|
||||
AstPlan(
|
||||
TriggerType.ADDED_BELIEF,
|
||||
AstLiteral("user_said", [AstVar("Message")]),
|
||||
[self._astify(phase)],
|
||||
actions,
|
||||
)
|
||||
)
|
||||
|
||||
def _process_goal(
|
||||
self,
|
||||
goal: Goal,
|
||||
phase: Phase,
|
||||
previous_goal: Goal | None = None,
|
||||
continues_response: bool = False,
|
||||
main_goal: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
Processes a goal and creates plans for achieving it.
|
||||
|
||||
This method creates two plans for each goal:
|
||||
1. A main plan that executes the goal's steps when conditions are met
|
||||
2. A fallback plan that provides a default empty implementation (prevents crashes)
|
||||
|
||||
The method also recursively processes any subgoals contained within
|
||||
the goal's plan.
|
||||
|
||||
:param goal: The goal to process.
|
||||
:param phase: The phase this goal belongs to.
|
||||
:param previous_goal: The previous goal in sequence (for dependency tracking).
|
||||
:param continues_response: Whether this goal continues an existing response.
|
||||
:param main_goal: Whether this is a main goal (for UI notification purposes).
|
||||
"""
|
||||
context: list[AstExpression] = [self._astify(phase)]
|
||||
context.append(~self._astify(goal, achieved=True))
|
||||
if previous_goal and previous_goal.can_fail:
|
||||
context.append(self._astify(previous_goal, achieved=True))
|
||||
if not continues_response:
|
||||
context.append(~AstLiteral("responded_this_turn"))
|
||||
|
||||
body = []
|
||||
if main_goal: # UI only needs to know about the main goals
|
||||
body.append(
|
||||
AstStatement(
|
||||
StatementType.DO_ACTION,
|
||||
AstLiteral("notify_goal_start", [AstString(self.slugify(goal))]),
|
||||
)
|
||||
)
|
||||
|
||||
subgoals = []
|
||||
for step in goal.plan.steps:
|
||||
body.append(self._step_to_statement(step))
|
||||
if isinstance(step, Goal):
|
||||
subgoals.append(step)
|
||||
|
||||
if not goal.can_fail and not continues_response:
|
||||
body.append(AstStatement(StatementType.ADD_BELIEF, self._astify(goal, achieved=True)))
|
||||
|
||||
self._asp.plans.append(AstPlan(TriggerType.ADDED_GOAL, self._astify(goal), context, body))
|
||||
|
||||
self._asp.plans.append(
|
||||
AstPlan(
|
||||
TriggerType.ADDED_GOAL,
|
||||
self._astify(goal),
|
||||
context=[],
|
||||
body=[AstStatement(StatementType.EMPTY, AstLiteral("true"))],
|
||||
)
|
||||
)
|
||||
|
||||
prev_goal = None
|
||||
for subgoal in subgoals:
|
||||
self._process_goal(subgoal, phase, prev_goal)
|
||||
prev_goal = subgoal
|
||||
|
||||
def _step_to_statement(self, step: PlanElement) -> AstStatement:
|
||||
"""
|
||||
Converts a plan step to an AgentSpeak statement.
|
||||
|
||||
This method transforms different types of plan elements into their
|
||||
corresponding AgentSpeak statements. Goals and speech-related actions
|
||||
become achieve-goal statements, while gesture actions become do-action
|
||||
statements.
|
||||
|
||||
:param step: The plan element to convert.
|
||||
:return: The corresponding AgentSpeak statement.
|
||||
"""
|
||||
match step:
|
||||
# Note that SpeechAction gets included in the ACHIEVE_GOAL, since it's a goal internally
|
||||
case Goal() | SpeechAction() | LLMAction() as a:
|
||||
return AstStatement(StatementType.ACHIEVE_GOAL, self._astify(a))
|
||||
case GestureAction() as a:
|
||||
return AstStatement(StatementType.DO_ACTION, self._astify(a))
|
||||
|
||||
def _process_trigger(self, trigger: Trigger, phase: Phase) -> None:
|
||||
"""
|
||||
Processes a trigger and creates plans for its execution.
|
||||
|
||||
This method creates plans that execute when trigger conditions are met.
|
||||
It handles both automatic triggering (when conditions are detected) and
|
||||
manual forcing (from UI). The trigger execution includes:
|
||||
1. Notifying the system about trigger start
|
||||
2. Executing all trigger steps
|
||||
3. Waiting briefly for UI display
|
||||
4. Notifying the system about trigger end
|
||||
|
||||
:param trigger: The trigger to process.
|
||||
:param phase: The phase this trigger belongs to.
|
||||
"""
|
||||
body = []
|
||||
subgoals = []
|
||||
|
||||
body.append(
|
||||
AstStatement(
|
||||
StatementType.DO_ACTION,
|
||||
AstLiteral("notify_trigger_start", [AstString(self.slugify(trigger))]),
|
||||
)
|
||||
)
|
||||
for step in trigger.plan.steps:
|
||||
body.append(self._step_to_statement(step))
|
||||
if isinstance(step, Goal):
|
||||
step.can_fail = False # triggers are continuous sequence
|
||||
subgoals.append(step)
|
||||
|
||||
# Arbitrary wait for UI to display nicely
|
||||
body.append(
|
||||
AstStatement(
|
||||
StatementType.DO_ACTION,
|
||||
AstLiteral("wait", [AstNumber(settings.behaviour_settings.trigger_time_to_wait)]),
|
||||
)
|
||||
)
|
||||
|
||||
body.append(
|
||||
AstStatement(
|
||||
StatementType.DO_ACTION,
|
||||
AstLiteral("notify_trigger_end", [AstString(self.slugify(trigger))]),
|
||||
)
|
||||
)
|
||||
|
||||
self._asp.plans.append(
|
||||
AstPlan(
|
||||
TriggerType.ADDED_GOAL,
|
||||
AstLiteral("check_triggers"),
|
||||
[self._astify(phase), self._astify(trigger.condition)],
|
||||
body,
|
||||
)
|
||||
)
|
||||
|
||||
# Force trigger (from UI)
|
||||
self._asp.plans.append(AstPlan(TriggerType.ADDED_GOAL, self._astify(trigger), [], body))
|
||||
|
||||
for subgoal in subgoals:
|
||||
self._process_goal(subgoal, phase, continues_response=True)
|
||||
|
||||
def _add_fallbacks(self):
|
||||
"""
|
||||
Adds fallback plans for robust execution, preventing crashes.
|
||||
|
||||
This method creates fallback plans that provide default empty implementations
|
||||
for key goals. These fallbacks ensure that the system can continue execution
|
||||
even when no specific plans are applicable, preventing crashes.
|
||||
|
||||
The fallbacks are created for:
|
||||
- check_triggers: When no triggers are applicable
|
||||
- transition_phase: When phase transition conditions aren't met
|
||||
- force_transition_phase: When forced transitions aren't possible
|
||||
"""
|
||||
# Trigger fallback
|
||||
self._asp.plans.append(
|
||||
AstPlan(
|
||||
TriggerType.ADDED_GOAL,
|
||||
AstLiteral("check_triggers"),
|
||||
[],
|
||||
[AstStatement(StatementType.EMPTY, AstLiteral("true"))],
|
||||
)
|
||||
)
|
||||
|
||||
# Phase transition fallback
|
||||
self._asp.plans.append(
|
||||
AstPlan(
|
||||
TriggerType.ADDED_GOAL,
|
||||
AstLiteral("transition_phase"),
|
||||
[],
|
||||
[AstStatement(StatementType.EMPTY, AstLiteral("true"))],
|
||||
)
|
||||
)
|
||||
|
||||
# Force phase transition fallback
|
||||
self._asp.plans.append(
|
||||
AstPlan(
|
||||
TriggerType.ADDED_GOAL,
|
||||
AstLiteral("force_transition_phase"),
|
||||
[],
|
||||
[AstStatement(StatementType.EMPTY, AstLiteral("true"))],
|
||||
)
|
||||
)
|
||||
|
||||
@singledispatchmethod
|
||||
def _astify(self, element: ProgramElement) -> AstExpression:
|
||||
"""
|
||||
Converts program elements to AgentSpeak expressions (base method).
|
||||
|
||||
This is the base method for the singledispatch mechanism that handles
|
||||
conversion of different program element types to their AgentSpeak
|
||||
representations. Specific implementations are provided for each
|
||||
element type through registered methods.
|
||||
|
||||
:param element: The program element to convert.
|
||||
:return: The corresponding AgentSpeak expression.
|
||||
:raises NotImplementedError: If no specific implementation exists for the element type.
|
||||
"""
|
||||
raise NotImplementedError(f"Cannot convert element {element} to an AgentSpeak expression.")
|
||||
|
||||
@_astify.register
|
||||
def _(self, kwb: KeywordBelief) -> AstExpression:
|
||||
"""
|
||||
Converts a KeywordBelief to an AgentSpeak expression.
|
||||
|
||||
Keyword beliefs are converted to keyword_said literals that check
|
||||
if the keyword was mentioned in user input.
|
||||
|
||||
:param kwb: The KeywordBelief to convert.
|
||||
:return: An AstLiteral representing the keyword detection.
|
||||
"""
|
||||
return AstLiteral("keyword_said", [AstString(kwb.keyword)])
|
||||
|
||||
@_astify.register
|
||||
def _(self, sb: SemanticBelief) -> AstExpression:
|
||||
"""
|
||||
Converts a SemanticBelief to an AgentSpeak expression.
|
||||
|
||||
Semantic beliefs are converted to literals using their slugified names,
|
||||
which are used for LLM-based belief evaluation.
|
||||
|
||||
:param sb: The SemanticBelief to convert.
|
||||
:return: An AstLiteral representing the semantic belief.
|
||||
"""
|
||||
return AstLiteral(self.slugify(sb))
|
||||
|
||||
@_astify.register
|
||||
def _(self, ib: InferredBelief) -> AstExpression:
|
||||
"""
|
||||
Converts an InferredBelief to an AgentSpeak expression.
|
||||
|
||||
Inferred beliefs are converted to binary operations that combine
|
||||
their left and right operands using the appropriate logical operator.
|
||||
|
||||
:param ib: The InferredBelief to convert.
|
||||
:return: An AstBinaryOp representing the logical combination.
|
||||
"""
|
||||
return AstBinaryOp(
|
||||
self._astify(ib.left),
|
||||
BinaryOperatorType.AND if ib.operator == LogicalOperator.AND else BinaryOperatorType.OR,
|
||||
self._astify(ib.right),
|
||||
)
|
||||
|
||||
@_astify.register
|
||||
def _(self, norm: Norm) -> AstExpression:
|
||||
"""
|
||||
Converts a Norm to an AgentSpeak expression.
|
||||
|
||||
Norms are converted to literals with either 'norm' or 'critical_norm'
|
||||
functors depending on their critical flag, with the norm text as an argument.
|
||||
|
||||
Note that currently, critical norms are not yet functionally supported. They are possible
|
||||
to astify for future use.
|
||||
|
||||
:param norm: The Norm to convert.
|
||||
:return: An AstLiteral representing the norm.
|
||||
"""
|
||||
functor = "critical_norm" if norm.critical else "norm"
|
||||
return AstLiteral(functor, [AstString(norm.norm)])
|
||||
|
||||
@_astify.register
|
||||
def _(self, phase: Phase) -> AstExpression:
|
||||
"""
|
||||
Converts a Phase to an AgentSpeak expression.
|
||||
|
||||
Phases are converted to phase literals with their unique identifier
|
||||
as an argument, which is used for phase tracking and transitions.
|
||||
|
||||
:param phase: The Phase to convert.
|
||||
:return: An AstLiteral representing the phase.
|
||||
"""
|
||||
return AstLiteral("phase", [AstString(str(phase.id))])
|
||||
|
||||
@_astify.register
|
||||
def _(self, goal: Goal, achieved: bool = False) -> AstExpression:
|
||||
"""
|
||||
Converts a Goal to an AgentSpeak expression.
|
||||
|
||||
Goals are converted to literals using their slugified names. If the
|
||||
achieved parameter is True, the literal is prefixed with 'achieved_'.
|
||||
|
||||
:param goal: The Goal to convert.
|
||||
:param achieved: Whether to represent this as an achieved goal.
|
||||
:return: An AstLiteral representing the goal.
|
||||
"""
|
||||
return AstLiteral(f"{'achieved_' if achieved else ''}{self._slugify_str(goal.name)}")
|
||||
|
||||
@_astify.register
|
||||
def _(self, trigger: Trigger) -> AstExpression:
|
||||
"""
|
||||
Converts a Trigger to an AgentSpeak expression.
|
||||
|
||||
Triggers are converted to literals using their slugified names,
|
||||
which are used to identify and execute trigger plans.
|
||||
|
||||
:param trigger: The Trigger to convert.
|
||||
:return: An AstLiteral representing the trigger.
|
||||
"""
|
||||
return AstLiteral(self.slugify(trigger))
|
||||
|
||||
@_astify.register
|
||||
def _(self, sa: SpeechAction) -> AstExpression:
|
||||
"""
|
||||
Converts a SpeechAction to an AgentSpeak expression.
|
||||
|
||||
Speech actions are converted to say literals with the text content
|
||||
as an argument, which are used for direct speech output.
|
||||
|
||||
:param sa: The SpeechAction to convert.
|
||||
:return: An AstLiteral representing the speech action.
|
||||
"""
|
||||
return AstLiteral("say", [AstString(sa.text)])
|
||||
|
||||
@_astify.register
|
||||
def _(self, ga: GestureAction) -> AstExpression:
|
||||
"""
|
||||
Converts a GestureAction to an AgentSpeak expression.
|
||||
|
||||
Gesture actions are converted to gesture literals with the gesture
|
||||
type and name as arguments, which are used for physical robot gestures.
|
||||
|
||||
:param ga: The GestureAction to convert.
|
||||
:return: An AstLiteral representing the gesture action.
|
||||
"""
|
||||
gesture = ga.gesture
|
||||
return AstLiteral("gesture", [AstString(gesture.type), AstString(gesture.name)])
|
||||
|
||||
@_astify.register
|
||||
def _(self, la: LLMAction) -> AstExpression:
|
||||
"""
|
||||
Converts an LLMAction to an AgentSpeak expression.
|
||||
|
||||
LLM actions are converted to reply_with_goal literals with the
|
||||
conversational goal as an argument, which are used for LLM-generated
|
||||
responses guided by specific goals.
|
||||
|
||||
:param la: The LLMAction to convert.
|
||||
:return: An AstLiteral representing the LLM action.
|
||||
"""
|
||||
return AstLiteral("reply_with_goal", [AstString(la.goal)])
|
||||
|
||||
@singledispatchmethod
|
||||
@staticmethod
|
||||
def slugify(element: ProgramElement) -> str:
|
||||
"""
|
||||
Converts program elements to slugs (base method).
|
||||
|
||||
This is the base method for the singledispatch mechanism that handles
|
||||
conversion of different program element types to their slug representations.
|
||||
Specific implementations are provided for each element type through
|
||||
registered methods.
|
||||
|
||||
Slugs are used outside of AgentSpeak, mostly for identifying what to send to the AgentSpeak
|
||||
program as beliefs.
|
||||
|
||||
:param element: The program element to convert to a slug.
|
||||
:return: The slug string representation.
|
||||
:raises NotImplementedError: If no specific implementation exists for the element type.
|
||||
"""
|
||||
raise NotImplementedError(f"Cannot convert element {element} to a slug.")
|
||||
|
||||
@slugify.register
|
||||
@staticmethod
|
||||
def _(n: Norm) -> str:
|
||||
"""
|
||||
Converts a Norm to a slug.
|
||||
|
||||
Norms are converted to slugs with the 'norm_' prefix followed by
|
||||
the slugified norm text.
|
||||
|
||||
:param n: The Norm to convert.
|
||||
:return: The slug string representation.
|
||||
"""
|
||||
return f"norm_{AgentSpeakGenerator._slugify_str(n.norm)}"
|
||||
|
||||
@slugify.register
|
||||
@staticmethod
|
||||
def _(sb: SemanticBelief) -> str:
|
||||
"""
|
||||
Converts a SemanticBelief to a slug.
|
||||
|
||||
Semantic beliefs are converted to slugs with the 'semantic_' prefix
|
||||
followed by the slugified belief name.
|
||||
|
||||
:param sb: The SemanticBelief to convert.
|
||||
:return: The slug string representation.
|
||||
"""
|
||||
return f"semantic_{AgentSpeakGenerator._slugify_str(sb.name)}"
|
||||
|
||||
@slugify.register
|
||||
@staticmethod
|
||||
def _(g: BaseGoal) -> str:
|
||||
"""
|
||||
Converts a BaseGoal to a slug.
|
||||
|
||||
Goals are converted to slugs using their slugified names directly.
|
||||
|
||||
:param g: The BaseGoal to convert.
|
||||
:return: The slug string representation.
|
||||
"""
|
||||
return AgentSpeakGenerator._slugify_str(g.name)
|
||||
|
||||
@slugify.register
|
||||
@staticmethod
|
||||
def _(t: Trigger) -> str:
|
||||
"""
|
||||
Converts a Trigger to a slug.
|
||||
|
||||
Triggers are converted to slugs with the 'trigger_' prefix followed by
|
||||
the slugified trigger name.
|
||||
|
||||
:param t: The Trigger to convert.
|
||||
:return: The slug string representation.
|
||||
"""
|
||||
return f"trigger_{AgentSpeakGenerator._slugify_str(t.name)}"
|
||||
|
||||
@staticmethod
|
||||
def _slugify_str(text: str) -> str:
|
||||
"""
|
||||
Converts a text string to a slug.
|
||||
|
||||
This helper method converts arbitrary text to a URL-friendly slug format
|
||||
by converting to lowercase, removing special characters, and replacing
|
||||
spaces with underscores. It also removes common stopwords.
|
||||
|
||||
:param text: The text string to convert.
|
||||
:return: The slugified string.
|
||||
"""
|
||||
return slugify(text, separator="_", stopwords=["a", "an", "the", "we", "you", "I"])
|
||||
546
src/control_backend/agents/bdi/bdi_core_agent.py
Normal file
546
src/control_backend/agents/bdi/bdi_core_agent.py
Normal file
@@ -0,0 +1,546 @@
|
||||
import asyncio
|
||||
import copy
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Iterable
|
||||
|
||||
import agentspeak
|
||||
import agentspeak.runtime
|
||||
import agentspeak.stdlib
|
||||
from pydantic import ValidationError
|
||||
|
||||
from control_backend.agents.base import BaseAgent
|
||||
from control_backend.core.agent_system import InternalMessage
|
||||
from control_backend.core.config import settings
|
||||
from control_backend.schemas.belief_message import BeliefMessage
|
||||
from control_backend.schemas.llm_prompt_message import LLMPromptMessage
|
||||
from control_backend.schemas.ri_message import GestureCommand, RIEndpoint, SpeechCommand
|
||||
|
||||
DELIMITER = ";\n" # TODO: temporary until we support lists in AgentSpeak
|
||||
|
||||
|
||||
experiment_logger = logging.getLogger(settings.logging_settings.experiment_logger_name)
|
||||
|
||||
|
||||
class BDICoreAgent(BaseAgent):
|
||||
"""
|
||||
BDI Core Agent.
|
||||
|
||||
This is the central reasoning agent of the system, powered by the **AgentSpeak(L)** language.
|
||||
It maintains a belief base (representing the state of the world) and a set of plans (rules).
|
||||
|
||||
It runs an internal BDI (Belief-Desire-Intention) cycle using the ``agentspeak`` library.
|
||||
When beliefs change (e.g., via :meth:`_apply_beliefs`), the agent evaluates its plans to
|
||||
determine the best course of action.
|
||||
|
||||
**Custom Actions:**
|
||||
It defines custom actions (like ``.reply``) that allow the AgentSpeak code to interact with
|
||||
external Python agents (e.g., querying the LLM).
|
||||
|
||||
:ivar bdi_agent: The internal AgentSpeak agent instance.
|
||||
:ivar asl_file: Path to the AgentSpeak source file (.asl).
|
||||
:ivar env: The AgentSpeak environment.
|
||||
:ivar actions: A registry of custom actions available to the AgentSpeak code.
|
||||
:ivar _wake_bdi_loop: Event used to wake up the reasoning loop when new beliefs arrive.
|
||||
"""
|
||||
|
||||
bdi_agent: agentspeak.runtime.Agent
|
||||
|
||||
def __init__(self, name: str):
|
||||
super().__init__(name)
|
||||
self.env = agentspeak.runtime.Environment()
|
||||
# Deep copy because we don't actually want to modify the standard actions globally
|
||||
self.actions = copy.deepcopy(agentspeak.stdlib.actions)
|
||||
self._wake_bdi_loop = asyncio.Event()
|
||||
self._bdi_loop_task = None
|
||||
|
||||
async def setup(self) -> None:
|
||||
"""
|
||||
Initialize the BDI agent.
|
||||
|
||||
1. Registers custom actions (like ``.reply``).
|
||||
2. Loads the .asl source file.
|
||||
3. Starts the reasoning loop (:meth:`_bdi_loop`) in the background.
|
||||
"""
|
||||
self.logger.debug("Setup started.")
|
||||
|
||||
self._add_custom_actions()
|
||||
|
||||
await self._load_asl()
|
||||
|
||||
# Start the BDI cycle loop
|
||||
self._bdi_loop_task = self.add_behavior(self._bdi_loop())
|
||||
self._wake_bdi_loop.set()
|
||||
self.logger.debug("Setup complete.")
|
||||
|
||||
async def _load_asl(self, file_name: str | None = None) -> None:
|
||||
"""
|
||||
Load and parse the AgentSpeak source file.
|
||||
"""
|
||||
file_name = file_name or "src/control_backend/agents/bdi/default_behavior.asl"
|
||||
|
||||
try:
|
||||
with open(file_name) as source:
|
||||
self.bdi_agent = self.env.build_agent(source, self.actions)
|
||||
self.logger.info(f"Loaded new ASL from {file_name}.")
|
||||
except FileNotFoundError:
|
||||
self.logger.warning(f"Could not find the specified ASL file at {file_name}.")
|
||||
self.bdi_agent = agentspeak.runtime.Agent(self.env, self.name)
|
||||
|
||||
async def _bdi_loop(self):
|
||||
"""
|
||||
The main BDI reasoning loop.
|
||||
|
||||
It waits for the ``_wake_bdi_loop`` event (set when beliefs change or actions complete).
|
||||
When awake, it steps through the AgentSpeak interpreter. It also handles sleeping if
|
||||
the agent has deferred intentions (deadlines).
|
||||
"""
|
||||
while self._running:
|
||||
await (
|
||||
self._wake_bdi_loop.wait()
|
||||
) # gets set whenever there's an update to the belief base
|
||||
|
||||
# Agent knows when it's expected to have to do its next thing
|
||||
maybe_more_work = True
|
||||
while maybe_more_work:
|
||||
maybe_more_work = False
|
||||
if self.bdi_agent.step():
|
||||
maybe_more_work = True
|
||||
|
||||
if not maybe_more_work:
|
||||
deadline = self.bdi_agent.shortest_deadline()
|
||||
if deadline:
|
||||
await asyncio.sleep(deadline - time.time())
|
||||
maybe_more_work = True
|
||||
else:
|
||||
self._wake_bdi_loop.clear()
|
||||
self.logger.debug("No more deadlines. Halting BDI loop.")
|
||||
|
||||
async def handle_message(self, msg: InternalMessage):
|
||||
"""
|
||||
Handle incoming messages.
|
||||
|
||||
- **Beliefs**: Updates the internal belief base.
|
||||
- **Program**: Updates the internal agentspeak file to match the current program.
|
||||
- **LLM Responses**: Forwards the generated text to the Robot Speech Agent (actuation).
|
||||
|
||||
:param msg: The received internal message.
|
||||
"""
|
||||
self.logger.debug("Processing message from %s.", msg.sender)
|
||||
|
||||
if msg.thread == "beliefs":
|
||||
try:
|
||||
belief_changes = BeliefMessage.model_validate_json(msg.body)
|
||||
self._apply_belief_changes(belief_changes)
|
||||
except ValidationError:
|
||||
self.logger.exception("Error processing belief.")
|
||||
return
|
||||
|
||||
# New agentspeak file
|
||||
if msg.thread == "new_program":
|
||||
if self._bdi_loop_task:
|
||||
self._bdi_loop_task.cancel()
|
||||
await self._load_asl(msg.body)
|
||||
self.add_behavior(self._bdi_loop())
|
||||
|
||||
# The message was not a belief, handle special cases based on sender
|
||||
match msg.sender:
|
||||
case settings.agent_settings.llm_name:
|
||||
content = msg.body
|
||||
self.logger.info("Received LLM response: %s", content)
|
||||
|
||||
# Forward to Robot Speech Agent
|
||||
cmd = SpeechCommand(data=content)
|
||||
out_msg = InternalMessage(
|
||||
to=settings.agent_settings.robot_speech_name,
|
||||
sender=self.name,
|
||||
body=cmd.model_dump_json(),
|
||||
)
|
||||
await self.send(out_msg)
|
||||
case settings.agent_settings.user_interrupt_name:
|
||||
self.logger.debug("Received user interruption: %s", msg)
|
||||
|
||||
match msg.thread:
|
||||
case "force_phase_transition":
|
||||
self._set_goal("transition_phase")
|
||||
case "force_trigger":
|
||||
self._force_trigger(msg.body)
|
||||
case "force_norm":
|
||||
self._force_norm(msg.body)
|
||||
case "force_next_phase":
|
||||
self._force_next_phase()
|
||||
case _:
|
||||
self.logger.warning("Received unknown user interruption: %s", msg)
|
||||
|
||||
def _apply_belief_changes(self, belief_changes: BeliefMessage):
|
||||
"""
|
||||
Update the belief base with a list of new beliefs.
|
||||
|
||||
For beliefs in ``belief_changes.replace``, it removes all existing beliefs with that name
|
||||
before adding one new one.
|
||||
|
||||
:param belief_changes: The changes in beliefs to apply.
|
||||
"""
|
||||
if not belief_changes.create and not belief_changes.replace and not belief_changes.delete:
|
||||
return
|
||||
|
||||
for belief in belief_changes.create:
|
||||
self._add_belief(belief.name, belief.arguments)
|
||||
|
||||
for belief in belief_changes.replace:
|
||||
self._remove_all_with_name(belief.name)
|
||||
self._add_belief(belief.name, belief.arguments)
|
||||
|
||||
for belief in belief_changes.delete:
|
||||
self._remove_belief(belief.name, belief.arguments)
|
||||
|
||||
def _add_belief(self, name: str, args: list[str] = None):
|
||||
"""
|
||||
Add a single belief to the BDI agent.
|
||||
|
||||
:param name: The functor/name of the belief (e.g., "user_said").
|
||||
:param args: Arguments for the belief.
|
||||
"""
|
||||
# new_args = (agentspeak.Literal(arg) for arg in args) # TODO: Eventually support multiple
|
||||
args = args or []
|
||||
if args:
|
||||
merged_args = DELIMITER.join(arg for arg in args)
|
||||
new_args = (agentspeak.Literal(merged_args),)
|
||||
term = agentspeak.Literal(name, new_args)
|
||||
else:
|
||||
term = agentspeak.Literal(name)
|
||||
|
||||
if name != "user_said":
|
||||
experiment_logger.observation(f"Formed new belief: {name}{f'={args}' if args else ''}")
|
||||
|
||||
self.bdi_agent.call(
|
||||
agentspeak.Trigger.addition,
|
||||
agentspeak.GoalType.belief,
|
||||
term,
|
||||
agentspeak.runtime.Intention(),
|
||||
)
|
||||
|
||||
# Check for transitions
|
||||
self.bdi_agent.call(
|
||||
agentspeak.Trigger.addition,
|
||||
agentspeak.GoalType.achievement,
|
||||
agentspeak.Literal("transition_phase"),
|
||||
agentspeak.runtime.Intention(),
|
||||
)
|
||||
|
||||
# Check triggers
|
||||
self.bdi_agent.call(
|
||||
agentspeak.Trigger.addition,
|
||||
agentspeak.GoalType.achievement,
|
||||
agentspeak.Literal("check_triggers"),
|
||||
agentspeak.runtime.Intention(),
|
||||
)
|
||||
|
||||
self._wake_bdi_loop.set()
|
||||
|
||||
self.logger.debug(f"Added belief {self.format_belief_string(name, args)}")
|
||||
|
||||
def _remove_belief(self, name: str, args: Iterable[str] | None):
|
||||
"""
|
||||
Removes a specific belief (with arguments), if it exists.
|
||||
"""
|
||||
if args is None:
|
||||
term = agentspeak.Literal(name)
|
||||
else:
|
||||
new_args = (agentspeak.Literal(arg) for arg in args)
|
||||
term = agentspeak.Literal(name, new_args)
|
||||
|
||||
if name != "user_said":
|
||||
experiment_logger.observation(f"Removed belief: {name}{f'={args}' if args else ''}")
|
||||
|
||||
result = self.bdi_agent.call(
|
||||
agentspeak.Trigger.removal,
|
||||
agentspeak.GoalType.belief,
|
||||
term,
|
||||
agentspeak.runtime.Intention(),
|
||||
)
|
||||
|
||||
if result:
|
||||
self.logger.debug(f"Removed belief {self.format_belief_string(name, args)}")
|
||||
self._wake_bdi_loop.set()
|
||||
else:
|
||||
self.logger.debug("Failed to remove belief (it was not in the belief base).")
|
||||
|
||||
def _remove_all_with_name(self, name: str):
|
||||
"""
|
||||
Removes all beliefs that match the given `name`.
|
||||
"""
|
||||
relevant_groups = []
|
||||
for key in self.bdi_agent.beliefs:
|
||||
if key[0] == name:
|
||||
relevant_groups.append(key)
|
||||
|
||||
removed_count = 0
|
||||
for group in relevant_groups:
|
||||
beliefs_to_remove = list(self.bdi_agent.beliefs[group])
|
||||
for belief in beliefs_to_remove:
|
||||
self.bdi_agent.call(
|
||||
agentspeak.Trigger.removal,
|
||||
agentspeak.GoalType.belief,
|
||||
belief,
|
||||
agentspeak.runtime.Intention(),
|
||||
)
|
||||
removed_count += 1
|
||||
|
||||
self._wake_bdi_loop.set()
|
||||
|
||||
self.logger.debug(f"Removed {removed_count} beliefs.")
|
||||
|
||||
def _set_goal(self, name: str, args: Iterable[str] | None = None):
|
||||
args = args or []
|
||||
|
||||
if args:
|
||||
merged_args = DELIMITER.join(arg for arg in args)
|
||||
new_args = (agentspeak.Literal(merged_args),)
|
||||
term = agentspeak.Literal(name, new_args)
|
||||
else:
|
||||
term = agentspeak.Literal(name)
|
||||
|
||||
self.bdi_agent.call(
|
||||
agentspeak.Trigger.addition,
|
||||
agentspeak.GoalType.achievement,
|
||||
term,
|
||||
agentspeak.runtime.Intention(),
|
||||
)
|
||||
|
||||
self._wake_bdi_loop.set()
|
||||
|
||||
self.logger.debug(f"Set goal !{self.format_belief_string(name, args)}.")
|
||||
|
||||
def _force_trigger(self, name: str):
|
||||
self._set_goal(name)
|
||||
|
||||
self.logger.info("Manually forced trigger %s.", name)
|
||||
|
||||
# TODO: make this compatible for critical norms
|
||||
def _force_norm(self, name: str):
|
||||
self._add_belief(f"force_{name}")
|
||||
|
||||
self.logger.info("Manually forced norm %s.", name)
|
||||
|
||||
def _force_next_phase(self):
|
||||
self._set_goal("force_transition_phase")
|
||||
|
||||
self.logger.info("Manually forced phase transition.")
|
||||
|
||||
def _add_custom_actions(self) -> None:
|
||||
"""
|
||||
Add any custom actions here. Inside `@self.actions.add()`, the first argument is
|
||||
the name of the function in the ASL file, and the second the amount of arguments
|
||||
the function expects (which will be located in `term.args`).
|
||||
"""
|
||||
|
||||
@self.actions.add(".reply", 2)
|
||||
def _reply(agent, term, intention):
|
||||
"""
|
||||
Let the LLM generate a response to a user's utterance with the current norms and goals.
|
||||
"""
|
||||
message_text = agentspeak.grounded(term.args[0], intention.scope)
|
||||
norms = agentspeak.grounded(term.args[1], intention.scope)
|
||||
|
||||
self.add_behavior(self._send_to_llm(str(message_text), str(norms), ""))
|
||||
yield
|
||||
|
||||
@self.actions.add(".reply_with_goal", 3)
|
||||
def _reply_with_goal(agent, term, intention):
|
||||
"""
|
||||
Let the LLM generate a response to a user's utterance with the current norms and a
|
||||
specific goal.
|
||||
"""
|
||||
message_text = agentspeak.grounded(term.args[0], intention.scope)
|
||||
norms = agentspeak.grounded(term.args[1], intention.scope)
|
||||
goal = agentspeak.grounded(term.args[2], intention.scope)
|
||||
self.add_behavior(self._send_to_llm(str(message_text), str(norms), str(goal)))
|
||||
yield
|
||||
|
||||
@self.actions.add(".notify_norms", 1)
|
||||
def _notify_norms(agent, term, intention):
|
||||
norms = agentspeak.grounded(term.args[0], intention.scope)
|
||||
|
||||
norm_update_message = InternalMessage(
|
||||
to=settings.agent_settings.user_interrupt_name,
|
||||
thread="active_norms_update",
|
||||
body=str(norms),
|
||||
)
|
||||
|
||||
self.add_behavior(self.send(norm_update_message, should_log=False))
|
||||
yield
|
||||
|
||||
@self.actions.add(".say", 1)
|
||||
def _say(agent, term, intention):
|
||||
"""
|
||||
Make the robot say the given text instantly.
|
||||
"""
|
||||
message_text = agentspeak.grounded(term.args[0], intention.scope)
|
||||
|
||||
self.logger.debug('"say" action called with text=%s', message_text)
|
||||
|
||||
speech_command = SpeechCommand(data=message_text)
|
||||
speech_message = InternalMessage(
|
||||
to=settings.agent_settings.robot_speech_name,
|
||||
sender=settings.agent_settings.bdi_core_name,
|
||||
body=speech_command.model_dump_json(),
|
||||
)
|
||||
|
||||
self.add_behavior(self.send(speech_message))
|
||||
|
||||
chat_history_message = InternalMessage(
|
||||
to=settings.agent_settings.llm_name,
|
||||
thread="assistant_message",
|
||||
body=str(message_text),
|
||||
)
|
||||
|
||||
experiment_logger.chat(str(message_text), extra={"role": "assistant"})
|
||||
|
||||
self.add_behavior(self.send(chat_history_message))
|
||||
|
||||
yield
|
||||
|
||||
@self.actions.add(".gesture", 2)
|
||||
def _gesture(agent, term, intention):
|
||||
"""
|
||||
Make the robot perform the given gesture instantly.
|
||||
"""
|
||||
gesture_type = agentspeak.grounded(term.args[0], intention.scope)
|
||||
gesture_name = agentspeak.grounded(term.args[1], intention.scope)
|
||||
|
||||
self.logger.debug(
|
||||
'"gesture" action called with type=%s, name=%s',
|
||||
gesture_type,
|
||||
gesture_name,
|
||||
)
|
||||
|
||||
if str(gesture_type) == "single":
|
||||
endpoint = RIEndpoint.GESTURE_SINGLE
|
||||
elif str(gesture_type) == "tag":
|
||||
endpoint = RIEndpoint.GESTURE_TAG
|
||||
else:
|
||||
self.logger.warning("Gesture type %s could not be resolved.", gesture_type)
|
||||
endpoint = RIEndpoint.GESTURE_SINGLE
|
||||
|
||||
gesture_command = GestureCommand(endpoint=endpoint, data=gesture_name)
|
||||
gesture_message = InternalMessage(
|
||||
to=settings.agent_settings.robot_gesture_name,
|
||||
sender=settings.agent_settings.bdi_core_name,
|
||||
body=gesture_command.model_dump_json(),
|
||||
)
|
||||
self.add_behavior(self.send(gesture_message))
|
||||
yield
|
||||
|
||||
@self.actions.add(".notify_user_said", 1)
|
||||
def _notify_user_said(agent, term, intention):
|
||||
user_said = agentspeak.grounded(term.args[0], intention.scope)
|
||||
|
||||
msg = InternalMessage(
|
||||
to=settings.agent_settings.llm_name, thread="user_message", body=str(user_said)
|
||||
)
|
||||
|
||||
self.add_behavior(self.send(msg))
|
||||
|
||||
yield
|
||||
|
||||
@self.actions.add(".notify_trigger_start", 1)
|
||||
def _notify_trigger_start(agent, term, intention):
|
||||
"""
|
||||
Notify the UI about the trigger we just started doing.
|
||||
"""
|
||||
trigger_name = agentspeak.grounded(term.args[0], intention.scope)
|
||||
|
||||
self.logger.debug("Started trigger %s", trigger_name)
|
||||
|
||||
msg = InternalMessage(
|
||||
to=settings.agent_settings.user_interrupt_name,
|
||||
sender=self.name,
|
||||
thread="trigger_start",
|
||||
body=str(trigger_name),
|
||||
)
|
||||
|
||||
# TODO: check with Pim
|
||||
self.add_behavior(self.send(msg))
|
||||
|
||||
yield
|
||||
|
||||
@self.actions.add(".notify_trigger_end", 1)
|
||||
def _notify_trigger_end(agent, term, intention):
|
||||
"""
|
||||
Notify the UI about the trigger we just started doing.
|
||||
"""
|
||||
trigger_name = agentspeak.grounded(term.args[0], intention.scope)
|
||||
|
||||
self.logger.debug("Finished trigger %s", trigger_name)
|
||||
|
||||
msg = InternalMessage(
|
||||
to=settings.agent_settings.user_interrupt_name,
|
||||
sender=self.name,
|
||||
thread="trigger_end",
|
||||
body=str(trigger_name),
|
||||
)
|
||||
|
||||
self.add_behavior(self.send(msg))
|
||||
|
||||
yield
|
||||
|
||||
@self.actions.add(".notify_goal_start", 1)
|
||||
def _notify_goal_start(agent, term, intention):
|
||||
"""
|
||||
Notify the UI about the goal we just started chasing.
|
||||
"""
|
||||
goal_name = agentspeak.grounded(term.args[0], intention.scope)
|
||||
|
||||
self.logger.debug("Started chasing goal %s", goal_name)
|
||||
|
||||
msg = InternalMessage(
|
||||
to=settings.agent_settings.user_interrupt_name,
|
||||
sender=self.name,
|
||||
thread="goal_start",
|
||||
body=str(goal_name),
|
||||
)
|
||||
|
||||
self.add_behavior(self.send(msg))
|
||||
|
||||
yield
|
||||
|
||||
@self.actions.add(".notify_transition_phase", 2)
|
||||
def _notify_transition_phase(agent, term, intention):
|
||||
"""
|
||||
Notify the BDI program manager about a phase transition.
|
||||
"""
|
||||
old = agentspeak.grounded(term.args[0], intention.scope)
|
||||
new = agentspeak.grounded(term.args[1], intention.scope)
|
||||
|
||||
msg = InternalMessage(
|
||||
to=settings.agent_settings.bdi_program_manager_name,
|
||||
thread="transition_phase",
|
||||
body=json.dumps({"old": str(old), "new": str(new)}),
|
||||
)
|
||||
|
||||
self.add_behavior(self.send(msg))
|
||||
|
||||
yield
|
||||
|
||||
async def _send_to_llm(self, text: str, norms: str, goals: str):
|
||||
"""
|
||||
Sends a text query to the LLM agent asynchronously.
|
||||
"""
|
||||
prompt = LLMPromptMessage(text=text, norms=norms.split("\n"), goals=goals.split("\n"))
|
||||
msg = InternalMessage(
|
||||
to=settings.agent_settings.llm_name,
|
||||
sender=self.name,
|
||||
body=prompt.model_dump_json(),
|
||||
thread="prompt_message",
|
||||
)
|
||||
await self.send(msg)
|
||||
self.logger.info("Message sent to LLM agent: %s", text)
|
||||
|
||||
@staticmethod
|
||||
def format_belief_string(name: str, args: Iterable[str] | None = []):
|
||||
"""
|
||||
Given a belief's name and its args, return a string of the form "name(*args)"
|
||||
"""
|
||||
return f"{name}{'(' if args else ''}{','.join(args or [])}{')' if args else ''}"
|
||||
354
src/control_backend/agents/bdi/bdi_program_manager.py
Normal file
354
src/control_backend/agents/bdi/bdi_program_manager.py
Normal file
@@ -0,0 +1,354 @@
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
|
||||
import zmq
|
||||
from pydantic import ValidationError
|
||||
from zmq.asyncio import Context
|
||||
|
||||
import control_backend
|
||||
from control_backend.agents import BaseAgent
|
||||
from control_backend.agents.bdi.agentspeak_generator import AgentSpeakGenerator
|
||||
from control_backend.core.config import settings
|
||||
from control_backend.schemas.belief_list import BeliefList, GoalList
|
||||
from control_backend.schemas.internal_message import InternalMessage
|
||||
from control_backend.schemas.program import (
|
||||
Belief,
|
||||
ConditionalNorm,
|
||||
Goal,
|
||||
InferredBelief,
|
||||
Phase,
|
||||
Program,
|
||||
)
|
||||
|
||||
experiment_logger = logging.getLogger(settings.logging_settings.experiment_logger_name)
|
||||
|
||||
|
||||
class BDIProgramManager(BaseAgent):
|
||||
"""
|
||||
BDI Program Manager Agent.
|
||||
|
||||
This agent is responsible for receiving high-level programs (sequences of instructions/goals)
|
||||
from the external HTTP API (via ZMQ), transforming it into an AgentSpeak program, sharing the
|
||||
program and its components to other agents, and keeping agents informed of the current state.
|
||||
|
||||
:ivar sub_socket: The ZMQ SUB socket used to receive program updates.
|
||||
:ivar _program: The current Program.
|
||||
:ivar _phase: The current Phase.
|
||||
:ivar _goal_mapping: A mapping of goal IDs to goals.
|
||||
"""
|
||||
|
||||
_program: Program
|
||||
_phase: Phase | None
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.sub_socket = None
|
||||
self._goal_mapping: dict[str, Goal] = {}
|
||||
|
||||
def _initialize_internal_state(self, program: Program):
|
||||
"""
|
||||
Initialize the state of the program manager given a new Program. Reset the tracking of the
|
||||
current phase to the first phase, make a mapping of goal IDs to goals, used during the life
|
||||
of the program.
|
||||
:param program: The new program.
|
||||
"""
|
||||
self._program = program
|
||||
self._phase = program.phases[0] # start in first phase
|
||||
self._goal_mapping = {}
|
||||
for phase in program.phases:
|
||||
for goal in phase.goals:
|
||||
self._populate_goal_mapping_with_goal(goal)
|
||||
|
||||
def _populate_goal_mapping_with_goal(self, goal: Goal):
|
||||
"""
|
||||
Recurse through the given goal and its subgoals and add all goals found to the
|
||||
``self._goal_mapping``.
|
||||
:param goal: The goal to add to the ``self._goal_mapping``, including subgoals.
|
||||
"""
|
||||
self._goal_mapping[str(goal.id)] = goal
|
||||
for step in goal.plan.steps:
|
||||
if isinstance(step, Goal):
|
||||
self._populate_goal_mapping_with_goal(step)
|
||||
|
||||
async def _create_agentspeak_and_send_to_bdi(self, program: Program):
|
||||
"""
|
||||
Convert a received program into an AgentSpeak file and send it to the BDI Core Agent.
|
||||
|
||||
:param program: The program object received from the API.
|
||||
"""
|
||||
asg = AgentSpeakGenerator()
|
||||
|
||||
asl_str = asg.generate(program)
|
||||
|
||||
file_name = settings.behaviour_settings.agentspeak_file
|
||||
|
||||
with open(file_name, "w") as f:
|
||||
f.write(asl_str)
|
||||
|
||||
msg = InternalMessage(
|
||||
sender=self.name,
|
||||
to=settings.agent_settings.bdi_core_name,
|
||||
body=file_name,
|
||||
thread="new_program",
|
||||
)
|
||||
|
||||
await self.send(msg)
|
||||
|
||||
async def handle_message(self, msg: InternalMessage):
|
||||
match msg.thread:
|
||||
case "transition_phase":
|
||||
phases = json.loads(msg.body)
|
||||
|
||||
await self._transition_phase(phases["old"], phases["new"])
|
||||
case "achieve_goal":
|
||||
goal_id = msg.body
|
||||
await self._send_achieved_goal_to_semantic_belief_extractor(goal_id)
|
||||
|
||||
async def _transition_phase(self, old: str, new: str):
|
||||
"""
|
||||
When receiving a signal from the BDI core that the phase has changed, apply this change to
|
||||
the current state and inform other agents about the change.
|
||||
|
||||
:param old: The ID of the old phase.
|
||||
:param new: The ID of the new phase.
|
||||
"""
|
||||
if self._phase is None:
|
||||
return
|
||||
|
||||
if old != str(self._phase.id):
|
||||
self.logger.warning(
|
||||
f"Phase transition desync detected! ASL requested move from '{old}', "
|
||||
f"but Python is currently in '{self._phase.id}'. Request ignored."
|
||||
)
|
||||
return
|
||||
|
||||
if new == "end":
|
||||
self._phase = None
|
||||
# Notify user interaction agent
|
||||
msg = InternalMessage(
|
||||
to=settings.agent_settings.user_interrupt_name,
|
||||
thread="transition_phase",
|
||||
body="end",
|
||||
)
|
||||
self.logger.info("Transitioned to end phase, notifying UserInterruptAgent.")
|
||||
|
||||
self.add_behavior(self.send(msg))
|
||||
return
|
||||
|
||||
for phase in self._program.phases:
|
||||
if str(phase.id) == new:
|
||||
self._phase = phase
|
||||
|
||||
await self._send_beliefs_to_semantic_belief_extractor()
|
||||
await self._send_goals_to_semantic_belief_extractor()
|
||||
|
||||
# Notify user interaction agent
|
||||
msg = InternalMessage(
|
||||
to=settings.agent_settings.user_interrupt_name,
|
||||
thread="transition_phase",
|
||||
body=str(self._phase.id),
|
||||
)
|
||||
self.logger.info(f"Transitioned to phase {new}, notifying UserInterruptAgent.")
|
||||
|
||||
self.add_behavior(self.send(msg))
|
||||
|
||||
def _extract_current_beliefs(self) -> list[Belief]:
|
||||
"""Extract beliefs from the current phase."""
|
||||
assert self._phase is not None, (
|
||||
"Invalid state, no phase set. Call this method only when "
|
||||
"a program has been received and the end-phase has not "
|
||||
"been reached."
|
||||
)
|
||||
|
||||
beliefs: list[Belief] = []
|
||||
|
||||
for norm in self._phase.norms:
|
||||
if isinstance(norm, ConditionalNorm):
|
||||
beliefs += self._extract_beliefs_from_belief(norm.condition)
|
||||
|
||||
for trigger in self._phase.triggers:
|
||||
beliefs += self._extract_beliefs_from_belief(trigger.condition)
|
||||
|
||||
return beliefs
|
||||
|
||||
@staticmethod
|
||||
def _extract_beliefs_from_belief(belief: Belief) -> list[Belief]:
|
||||
"""Recursively extract beliefs from the given belief."""
|
||||
if isinstance(belief, InferredBelief):
|
||||
return BDIProgramManager._extract_beliefs_from_belief(
|
||||
belief.left
|
||||
) + BDIProgramManager._extract_beliefs_from_belief(belief.right)
|
||||
return [belief]
|
||||
|
||||
async def _send_beliefs_to_semantic_belief_extractor(self):
|
||||
"""Extract beliefs from the program and send them to the Semantic Belief Extractor Agent."""
|
||||
beliefs = BeliefList(beliefs=self._extract_current_beliefs())
|
||||
|
||||
message = InternalMessage(
|
||||
to=settings.agent_settings.text_belief_extractor_name,
|
||||
sender=self.name,
|
||||
body=beliefs.model_dump_json(),
|
||||
thread="beliefs",
|
||||
)
|
||||
|
||||
await self.send(message)
|
||||
|
||||
@staticmethod
|
||||
def _extract_goals_from_goal(goal: Goal) -> list[Goal]:
|
||||
"""
|
||||
Extract all goals from a given goal, that is: the goal itself and any subgoals.
|
||||
|
||||
:return: All goals within and including the given goal.
|
||||
"""
|
||||
goals: list[Goal] = [goal]
|
||||
for step in goal.plan.steps:
|
||||
if isinstance(step, Goal):
|
||||
goals.extend(BDIProgramManager._extract_goals_from_goal(step))
|
||||
return goals
|
||||
|
||||
def _extract_current_goals(self) -> list[Goal]:
|
||||
"""
|
||||
Extract all goals from the program, including subgoals.
|
||||
|
||||
:return: A list of Goal objects.
|
||||
"""
|
||||
assert self._phase is not None, (
|
||||
"Invalid state, no phase set. Call this method only when "
|
||||
"a program has been received and the end-phase has not "
|
||||
"been reached."
|
||||
)
|
||||
|
||||
goals: list[Goal] = []
|
||||
|
||||
for goal in self._phase.goals:
|
||||
goals.extend(self._extract_goals_from_goal(goal))
|
||||
|
||||
return goals
|
||||
|
||||
async def _send_goals_to_semantic_belief_extractor(self):
|
||||
"""
|
||||
Extract goals for the current phase and send them to the Semantic Belief Extractor Agent.
|
||||
"""
|
||||
goals = GoalList(goals=self._extract_current_goals())
|
||||
|
||||
message = InternalMessage(
|
||||
to=settings.agent_settings.text_belief_extractor_name,
|
||||
sender=self.name,
|
||||
body=goals.model_dump_json(),
|
||||
thread="goals",
|
||||
)
|
||||
|
||||
await self.send(message)
|
||||
|
||||
async def _send_achieved_goal_to_semantic_belief_extractor(self, achieved_goal_id: str):
|
||||
"""
|
||||
Inform the semantic belief extractor when a goal is marked achieved.
|
||||
|
||||
:param achieved_goal_id: The id of the achieved goal.
|
||||
"""
|
||||
goal = self._goal_mapping.get(achieved_goal_id)
|
||||
if goal is None:
|
||||
self.logger.debug(f"Goal with ID {achieved_goal_id} marked achieved but was not found.")
|
||||
return
|
||||
|
||||
goals = self._extract_goals_from_goal(goal)
|
||||
message = InternalMessage(
|
||||
to=settings.agent_settings.text_belief_extractor_name,
|
||||
body=GoalList(goals=goals).model_dump_json(),
|
||||
thread="achieved_goals",
|
||||
)
|
||||
await self.send(message)
|
||||
|
||||
async def _send_clear_llm_history(self):
|
||||
"""
|
||||
Clear the LLM Agent's conversation history.
|
||||
|
||||
Sends an empty history to the LLM Agent to reset its state.
|
||||
"""
|
||||
message = InternalMessage(
|
||||
to=settings.agent_settings.llm_name,
|
||||
body="clear_history",
|
||||
)
|
||||
await self.send(message)
|
||||
self.logger.debug("Sent message to LLM agent to clear history.")
|
||||
|
||||
extractor_msg = InternalMessage(
|
||||
to=settings.agent_settings.text_belief_extractor_name,
|
||||
thread="conversation_history",
|
||||
body="reset",
|
||||
)
|
||||
await self.send(extractor_msg)
|
||||
self.logger.debug("Sent message to extractor agent to clear history.")
|
||||
|
||||
@staticmethod
|
||||
def _rollover_experiment_logs():
|
||||
"""
|
||||
A new experiment program started; make a new experiment log file.
|
||||
"""
|
||||
handlers = logging.getLogger(settings.logging_settings.experiment_logger_name).handlers
|
||||
for handler in handlers:
|
||||
if isinstance(handler, control_backend.logging.DatedFileHandler):
|
||||
experiment_logger.action("Doing rollover...")
|
||||
handler.do_rollover()
|
||||
experiment_logger.debug("Finished rollover.")
|
||||
|
||||
async def _receive_programs(self):
|
||||
"""
|
||||
Continuous loop that receives program updates from the HTTP endpoint.
|
||||
|
||||
It listens to the ``program`` topic on the internal ZMQ SUB socket.
|
||||
When a program is received, it is validated and forwarded to BDI via :meth:`_send_to_bdi`.
|
||||
Additionally, the LLM history is cleared via :meth:`_send_clear_llm_history`.
|
||||
"""
|
||||
while True:
|
||||
topic, body = await self.sub_socket.recv_multipart()
|
||||
|
||||
try:
|
||||
program = Program.model_validate_json(body)
|
||||
except ValidationError:
|
||||
self.logger.warning("Received an invalid program.")
|
||||
continue
|
||||
|
||||
self._initialize_internal_state(program)
|
||||
await self._send_program_to_user_interrupt(program)
|
||||
await self._send_clear_llm_history()
|
||||
self._rollover_experiment_logs()
|
||||
|
||||
await asyncio.gather(
|
||||
self._create_agentspeak_and_send_to_bdi(program),
|
||||
self._send_beliefs_to_semantic_belief_extractor(),
|
||||
self._send_goals_to_semantic_belief_extractor(),
|
||||
)
|
||||
|
||||
async def _send_program_to_user_interrupt(self, program: Program):
|
||||
"""
|
||||
Send the received program to the User Interrupt Agent.
|
||||
|
||||
:param program: The program object received from the API.
|
||||
"""
|
||||
msg = InternalMessage(
|
||||
sender=self.name,
|
||||
to=settings.agent_settings.user_interrupt_name,
|
||||
body=program.model_dump_json(),
|
||||
thread="new_program",
|
||||
)
|
||||
|
||||
await self.send(msg)
|
||||
|
||||
async def setup(self):
|
||||
"""
|
||||
Initialize the agent.
|
||||
|
||||
Connects the internal ZMQ SUB socket and subscribes to the 'program' topic.
|
||||
Starts the background behavior to receive programs. Initializes a default program.
|
||||
"""
|
||||
await self._create_agentspeak_and_send_to_bdi(Program(phases=[]))
|
||||
|
||||
context = Context.instance()
|
||||
|
||||
self.sub_socket = context.socket(zmq.SUB)
|
||||
self.sub_socket.connect(settings.zmq_settings.internal_sub_address)
|
||||
self.sub_socket.subscribe("program")
|
||||
|
||||
self.add_behavior(self._receive_programs())
|
||||
34
src/control_backend/agents/bdi/default_behavior.asl
Normal file
34
src/control_backend/agents/bdi/default_behavior.asl
Normal file
@@ -0,0 +1,34 @@
|
||||
phase("end").
|
||||
keyword_said(Keyword) :- (user_said(Message) & .substring(Keyword, Message, Pos)) & (Pos >= 0).
|
||||
|
||||
|
||||
+!reply_with_goal(Goal)
|
||||
: user_said(Message)
|
||||
<- +responded_this_turn;
|
||||
.findall(Norm, norm(Norm), Norms);
|
||||
.reply_with_goal(Message, Norms, Goal).
|
||||
|
||||
+!say(Text)
|
||||
<- +responded_this_turn;
|
||||
.say(Text).
|
||||
|
||||
+!reply
|
||||
: user_said(Message)
|
||||
<- +responded_this_turn;
|
||||
.findall(Norm, norm(Norm), Norms);
|
||||
.reply(Message, Norms).
|
||||
|
||||
+!notify_cycle
|
||||
<- .notify_ui;
|
||||
.wait(1).
|
||||
|
||||
+user_said(Message)
|
||||
: phase("end")
|
||||
<- .notify_user_said(Message);
|
||||
!reply.
|
||||
|
||||
+!check_triggers
|
||||
<- true.
|
||||
|
||||
+!transition_phase
|
||||
<- true.
|
||||
546
src/control_backend/agents/bdi/text_belief_extractor_agent.py
Normal file
546
src/control_backend/agents/bdi/text_belief_extractor_agent.py
Normal file
@@ -0,0 +1,546 @@
|
||||
import asyncio
|
||||
import json
|
||||
|
||||
import httpx
|
||||
from pydantic import BaseModel, ValidationError
|
||||
|
||||
from control_backend.agents.base import BaseAgent
|
||||
from control_backend.agents.bdi.agentspeak_generator import AgentSpeakGenerator
|
||||
from control_backend.core.agent_system import InternalMessage
|
||||
from control_backend.core.config import settings
|
||||
from control_backend.schemas.belief_list import BeliefList, GoalList
|
||||
from control_backend.schemas.belief_message import Belief as InternalBelief
|
||||
from control_backend.schemas.belief_message import BeliefMessage
|
||||
from control_backend.schemas.chat_history import ChatHistory, ChatMessage
|
||||
from control_backend.schemas.program import BaseGoal, SemanticBelief
|
||||
|
||||
type JSONLike = None | bool | int | float | str | list["JSONLike"] | dict[str, "JSONLike"]
|
||||
|
||||
|
||||
class BeliefState(BaseModel):
|
||||
"""
|
||||
Represents the state of inferred semantic beliefs.
|
||||
|
||||
Maintains sets of beliefs that are currently considered true or false.
|
||||
"""
|
||||
|
||||
true: set[InternalBelief] = set()
|
||||
false: set[InternalBelief] = set()
|
||||
|
||||
def difference(self, other: "BeliefState") -> "BeliefState":
|
||||
return BeliefState(
|
||||
true=self.true - other.true,
|
||||
false=self.false - other.false,
|
||||
)
|
||||
|
||||
def union(self, other: "BeliefState") -> "BeliefState":
|
||||
return BeliefState(
|
||||
true=self.true | other.true,
|
||||
false=self.false | other.false,
|
||||
)
|
||||
|
||||
def __sub__(self, other):
|
||||
return self.difference(other)
|
||||
|
||||
def __or__(self, other):
|
||||
return self.union(other)
|
||||
|
||||
def __bool__(self):
|
||||
return bool(self.true) or bool(self.false)
|
||||
|
||||
|
||||
class TextBeliefExtractorAgent(BaseAgent):
|
||||
"""
|
||||
Text Belief Extractor Agent.
|
||||
|
||||
This agent is responsible for processing raw text (e.g., from speech transcription) and
|
||||
extracting semantic beliefs from it.
|
||||
|
||||
It uses the available beliefs received from the program manager to try to extract beliefs from a
|
||||
user's message, sends and updated beliefs to the BDI core, and forms a ``user_said`` belief from
|
||||
the message itself.
|
||||
"""
|
||||
|
||||
def __init__(self, name: str):
|
||||
super().__init__(name)
|
||||
self._llm = self.LLM(self, settings.llm_settings.n_parallel)
|
||||
self.belief_inferrer = SemanticBeliefInferrer(self._llm)
|
||||
self.goal_inferrer = GoalAchievementInferrer(self._llm)
|
||||
self._current_beliefs = BeliefState()
|
||||
self._current_goal_completions: dict[str, bool] = {}
|
||||
self._force_completed_goals: set[BaseGoal] = set()
|
||||
self.conversation = ChatHistory(messages=[])
|
||||
|
||||
async def setup(self):
|
||||
"""
|
||||
Initialize the agent and its resources.
|
||||
"""
|
||||
self.logger.info("Setting up %s.", self.name)
|
||||
|
||||
async def handle_message(self, msg: InternalMessage):
|
||||
"""
|
||||
Handle incoming messages. Expect messages from the Transcriber agent, LLM agent, and the
|
||||
Program manager agent.
|
||||
|
||||
:param msg: The received message.
|
||||
"""
|
||||
sender = msg.sender
|
||||
|
||||
match sender:
|
||||
case settings.agent_settings.transcription_name:
|
||||
self.logger.debug("Received text from transcriber: %s", msg.body)
|
||||
self._apply_conversation_message(ChatMessage(role="user", content=msg.body))
|
||||
await self._user_said(msg.body)
|
||||
await self._infer_new_beliefs()
|
||||
await self._infer_goal_completions()
|
||||
case settings.agent_settings.llm_name:
|
||||
self.logger.debug("Received text from LLM: %s", msg.body)
|
||||
self._apply_conversation_message(ChatMessage(role="assistant", content=msg.body))
|
||||
case settings.agent_settings.bdi_program_manager_name:
|
||||
await self._handle_program_manager_message(msg)
|
||||
case _:
|
||||
self.logger.info("Discarding message from %s", sender)
|
||||
return
|
||||
|
||||
def _apply_conversation_message(self, message: ChatMessage):
|
||||
"""
|
||||
Save the chat message to our conversation history, taking into account the conversation
|
||||
length limit.
|
||||
|
||||
:param message: The chat message to add to the conversation history.
|
||||
"""
|
||||
length_limit = settings.behaviour_settings.conversation_history_length_limit
|
||||
self.conversation.messages = (self.conversation.messages + [message])[-length_limit:]
|
||||
|
||||
async def _handle_program_manager_message(self, msg: InternalMessage):
|
||||
"""
|
||||
Handle a message from the program manager: extract available beliefs and goals from it.
|
||||
|
||||
:param msg: The received message from the program manager.
|
||||
"""
|
||||
match msg.thread:
|
||||
case "beliefs":
|
||||
self._handle_beliefs_message(msg)
|
||||
await self._infer_new_beliefs()
|
||||
case "goals":
|
||||
self._handle_goals_message(msg)
|
||||
await self._infer_goal_completions()
|
||||
case "achieved_goals":
|
||||
self._handle_goal_achieved_message(msg)
|
||||
case "conversation_history":
|
||||
if msg.body == "reset":
|
||||
self._reset_phase()
|
||||
case _:
|
||||
self.logger.warning("Received unexpected message from %s", msg.sender)
|
||||
|
||||
def _reset_phase(self):
|
||||
"""
|
||||
Delete all state about the current phase, such as what beliefs exist and which ones are
|
||||
true.
|
||||
"""
|
||||
self.conversation = ChatHistory(messages=[])
|
||||
self.belief_inferrer.available_beliefs.clear()
|
||||
self._current_beliefs = BeliefState()
|
||||
self.goal_inferrer.goals.clear()
|
||||
self._current_goal_completions = {}
|
||||
|
||||
def _handle_beliefs_message(self, msg: InternalMessage):
|
||||
"""
|
||||
Handle the message from the Program Manager agent containing the beliefs that exist for this
|
||||
phase.
|
||||
:param msg: A list of beliefs.
|
||||
"""
|
||||
try:
|
||||
belief_list = BeliefList.model_validate_json(msg.body)
|
||||
except ValidationError:
|
||||
self.logger.warning(
|
||||
"Received message from program manager but it is not a valid list of beliefs."
|
||||
)
|
||||
return
|
||||
|
||||
available_beliefs = [b for b in belief_list.beliefs if isinstance(b, SemanticBelief)]
|
||||
self.belief_inferrer.available_beliefs = available_beliefs
|
||||
self.logger.debug(
|
||||
"Received %d semantic beliefs from the program manager: %s",
|
||||
len(available_beliefs),
|
||||
", ".join(b.name for b in available_beliefs),
|
||||
)
|
||||
|
||||
def _handle_goals_message(self, msg: InternalMessage):
|
||||
"""
|
||||
Handle the message from the Program Manager agent containing the goals that exist for this
|
||||
phase.
|
||||
:param msg: A list of goals.
|
||||
"""
|
||||
try:
|
||||
goals_list = GoalList.model_validate_json(msg.body)
|
||||
except ValidationError:
|
||||
self.logger.warning(
|
||||
"Received message from program manager but it is not a valid list of goals."
|
||||
)
|
||||
return
|
||||
|
||||
# Use only goals that can fail, as the others are always assumed to be completed
|
||||
available_goals = {g for g in goals_list.goals if g.can_fail}
|
||||
available_goals -= self._force_completed_goals
|
||||
self.goal_inferrer.goals = available_goals
|
||||
self.logger.debug(
|
||||
"Received %d failable goals from the program manager: %s",
|
||||
len(available_goals),
|
||||
", ".join(g.name for g in available_goals),
|
||||
)
|
||||
|
||||
def _handle_goal_achieved_message(self, msg: InternalMessage):
|
||||
"""
|
||||
Handle message that gets sent when goals are marked achieved from a user interrupt. This
|
||||
goal should then not be changed by this agent anymore.
|
||||
:param msg: List of goals that are marked achieved.
|
||||
"""
|
||||
# NOTE: When goals can be marked unachieved, remember to re-add them to the goal_inferrer
|
||||
try:
|
||||
goals_list = GoalList.model_validate_json(msg.body)
|
||||
except ValidationError:
|
||||
self.logger.warning(
|
||||
"Received goal achieved message from the program manager, "
|
||||
"but it is not a valid list of goals."
|
||||
)
|
||||
return
|
||||
|
||||
for goal in goals_list.goals:
|
||||
self._force_completed_goals.add(goal)
|
||||
self._current_goal_completions[f"achieved_{AgentSpeakGenerator.slugify(goal)}"] = True
|
||||
|
||||
self.goal_inferrer.goals -= self._force_completed_goals
|
||||
|
||||
async def _user_said(self, text: str):
|
||||
"""
|
||||
Create a belief for the user's full speech.
|
||||
|
||||
:param text: User's transcribed text.
|
||||
"""
|
||||
belief_msg = InternalMessage(
|
||||
to=settings.agent_settings.bdi_core_name,
|
||||
sender=self.name,
|
||||
body=BeliefMessage(
|
||||
replace=[InternalBelief(name="user_said", arguments=[text])],
|
||||
).model_dump_json(),
|
||||
thread="beliefs",
|
||||
)
|
||||
await self.send(belief_msg)
|
||||
|
||||
async def _infer_new_beliefs(self):
|
||||
"""
|
||||
Determine which beliefs hold and do not hold for the current conversation state. When
|
||||
beliefs change, a message is sent to the BDI core.
|
||||
"""
|
||||
conversation_beliefs = await self.belief_inferrer.infer_from_conversation(self.conversation)
|
||||
|
||||
new_beliefs = conversation_beliefs - self._current_beliefs
|
||||
if not new_beliefs:
|
||||
self.logger.debug("No new beliefs detected.")
|
||||
return
|
||||
|
||||
self._current_beliefs |= new_beliefs
|
||||
|
||||
belief_changes = BeliefMessage(
|
||||
create=list(new_beliefs.true),
|
||||
delete=list(new_beliefs.false),
|
||||
)
|
||||
|
||||
message = InternalMessage(
|
||||
to=settings.agent_settings.bdi_core_name,
|
||||
sender=self.name,
|
||||
body=belief_changes.model_dump_json(),
|
||||
thread="beliefs",
|
||||
)
|
||||
await self.send(message)
|
||||
|
||||
async def _infer_goal_completions(self):
|
||||
"""
|
||||
Determine which goals have been achieved given the current conversation state. When
|
||||
a goal's achieved state changes, a message is sent to the BDI core.
|
||||
"""
|
||||
goal_completions = await self.goal_inferrer.infer_from_conversation(self.conversation)
|
||||
|
||||
new_achieved = [
|
||||
InternalBelief(name=goal, arguments=None)
|
||||
for goal, achieved in goal_completions.items()
|
||||
if achieved and self._current_goal_completions.get(goal) != achieved
|
||||
]
|
||||
new_not_achieved = [
|
||||
InternalBelief(name=goal, arguments=None)
|
||||
for goal, achieved in goal_completions.items()
|
||||
if not achieved and self._current_goal_completions.get(goal) != achieved
|
||||
]
|
||||
for goal, achieved in goal_completions.items():
|
||||
self._current_goal_completions[goal] = achieved
|
||||
|
||||
if not new_achieved and not new_not_achieved:
|
||||
self.logger.debug("No goal achievement changes detected.")
|
||||
return
|
||||
|
||||
belief_changes = BeliefMessage(
|
||||
create=new_achieved,
|
||||
delete=new_not_achieved,
|
||||
)
|
||||
message = InternalMessage(
|
||||
to=settings.agent_settings.bdi_core_name,
|
||||
sender=self.name,
|
||||
body=belief_changes.model_dump_json(),
|
||||
thread="beliefs",
|
||||
)
|
||||
await self.send(message)
|
||||
|
||||
class LLM:
|
||||
"""
|
||||
Class that handles sending structured generation requests to an LLM.
|
||||
"""
|
||||
|
||||
def __init__(self, agent: "TextBeliefExtractorAgent", n_parallel: int):
|
||||
self._agent = agent
|
||||
self._semaphore = asyncio.Semaphore(n_parallel)
|
||||
|
||||
async def query(self, prompt: str, schema: dict, tries: int = 3) -> JSONLike | None:
|
||||
"""
|
||||
Query the LLM with the given prompt and schema, return an instance of a dict conforming
|
||||
to this schema. Try ``tries`` times, or return None.
|
||||
|
||||
:param prompt: Prompt to be queried.
|
||||
:param schema: Schema to be queried.
|
||||
:param tries: Number of times to try to query the LLM.
|
||||
:return: An instance of a dict conforming to this schema, or None if failed.
|
||||
"""
|
||||
try_count = 0
|
||||
while try_count < tries:
|
||||
try_count += 1
|
||||
|
||||
try:
|
||||
return await self._query_llm(prompt, schema)
|
||||
except (httpx.HTTPError, json.JSONDecodeError, KeyError) as e:
|
||||
if try_count < tries:
|
||||
continue
|
||||
self._agent.logger.exception(
|
||||
"Failed to get LLM response after %d tries.",
|
||||
try_count,
|
||||
exc_info=e,
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
async def _query_llm(self, prompt: str, schema: dict) -> JSONLike:
|
||||
"""
|
||||
Query an LLM with the given prompt and schema, return an instance of a dict conforming
|
||||
to that schema.
|
||||
|
||||
:param prompt: The prompt to be queried.
|
||||
:param schema: Schema to use during response.
|
||||
:return: A dict conforming to this schema.
|
||||
:raises httpx.HTTPStatusError: If the LLM server responded with an error.
|
||||
:raises json.JSONDecodeError: If the LLM response was not valid JSON. May happen if the
|
||||
response was cut off early due to length limitations.
|
||||
:raises KeyError: If the LLM server responded with no error, but the response was
|
||||
invalid.
|
||||
"""
|
||||
async with self._semaphore:
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
settings.llm_settings.local_llm_url,
|
||||
headers={"Authorization": f"Bearer {settings.llm_settings.api_key}"}
|
||||
if settings.llm_settings.api_key
|
||||
else {},
|
||||
json={
|
||||
"model": settings.llm_settings.local_llm_model,
|
||||
"messages": [{"role": "user", "content": prompt}],
|
||||
"response_format": {
|
||||
"type": "json_schema",
|
||||
"json_schema": {
|
||||
"name": "Beliefs",
|
||||
"strict": True,
|
||||
"schema": schema,
|
||||
},
|
||||
},
|
||||
"reasoning_effort": "low",
|
||||
"temperature": settings.llm_settings.code_temperature,
|
||||
"stream": False,
|
||||
},
|
||||
timeout=30.0,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
response_json = response.json()
|
||||
json_message = response_json["choices"][0]["message"]["content"]
|
||||
return json.loads(json_message)
|
||||
|
||||
|
||||
class SemanticBeliefInferrer:
|
||||
"""
|
||||
Infers semantic beliefs from conversation history using an LLM.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
llm: "TextBeliefExtractorAgent.LLM",
|
||||
available_beliefs: list[SemanticBelief] | None = None,
|
||||
):
|
||||
self._llm = llm
|
||||
self.available_beliefs: list[SemanticBelief] = available_beliefs or []
|
||||
|
||||
async def infer_from_conversation(self, conversation: ChatHistory) -> BeliefState:
|
||||
"""
|
||||
Process conversation history to extract beliefs, semantically. The result is an object that
|
||||
describes all beliefs that hold or don't hold based on the full conversation.
|
||||
|
||||
:param conversation: The conversation history to be processed.
|
||||
:return: An object that describes beliefs.
|
||||
"""
|
||||
# Return instantly if there are no beliefs to infer
|
||||
if not self.available_beliefs:
|
||||
return BeliefState()
|
||||
|
||||
n_parallel = max(1, min(settings.llm_settings.n_parallel - 1, len(self.available_beliefs)))
|
||||
all_beliefs: list[dict[str, bool | None] | None] = await asyncio.gather(
|
||||
*[
|
||||
self._infer_beliefs(conversation, beliefs)
|
||||
for beliefs in self._split_into_chunks(self.available_beliefs, n_parallel)
|
||||
]
|
||||
)
|
||||
new_beliefs = BeliefState()
|
||||
# Collect beliefs from all parallel calls
|
||||
for beliefs in all_beliefs:
|
||||
if beliefs is None:
|
||||
continue
|
||||
# For each, convert them to InternalBeliefs
|
||||
for belief_name, belief_holds in beliefs.items():
|
||||
# Skip beliefs that were marked not possible to determine
|
||||
if belief_holds is None:
|
||||
continue
|
||||
belief = InternalBelief(name=belief_name, arguments=None)
|
||||
if belief_holds:
|
||||
new_beliefs.true.add(belief)
|
||||
else:
|
||||
new_beliefs.false.add(belief)
|
||||
return new_beliefs
|
||||
|
||||
@staticmethod
|
||||
def _split_into_chunks[T](items: list[T], n: int) -> list[list[T]]:
|
||||
"""
|
||||
Split a list into ``n`` chunks, making each chunk approximately ``len(items) / n`` long.
|
||||
|
||||
:param items: The list of items to split.
|
||||
:param n: The number of desired chunks.
|
||||
:return: A list of chunks each approximately ``len(items) / n`` long.
|
||||
"""
|
||||
k, m = divmod(len(items), n)
|
||||
return [items[i * k + min(i, m) : (i + 1) * k + min(i + 1, m)] for i in range(n)]
|
||||
|
||||
async def _infer_beliefs(
|
||||
self,
|
||||
conversation: ChatHistory,
|
||||
beliefs: list[SemanticBelief],
|
||||
) -> dict[str, bool | None] | None:
|
||||
"""
|
||||
Infer given beliefs based on the given conversation.
|
||||
:param conversation: The conversation to infer beliefs from.
|
||||
:param beliefs: The beliefs to infer.
|
||||
:return: A dict containing belief names and a boolean whether they hold, or None if the
|
||||
belief cannot be inferred based on the given conversation.
|
||||
"""
|
||||
example = {
|
||||
"example_belief": True,
|
||||
}
|
||||
|
||||
prompt = f"""{self._format_conversation(conversation)}
|
||||
|
||||
Given the above conversation, what beliefs can be inferred?
|
||||
If there is no relevant information about a belief belief, give null.
|
||||
In case messages conflict, prefer using the most recent messages for inference.
|
||||
|
||||
Choose from the following list of beliefs, formatted as `- <belief_name>: <description>`:
|
||||
{self._format_beliefs(beliefs)}
|
||||
|
||||
Respond with a JSON similar to the following, but with the property names as given above:
|
||||
{json.dumps(example, indent=2)}
|
||||
"""
|
||||
|
||||
schema = self._create_beliefs_schema(beliefs)
|
||||
|
||||
return await self._llm.query(prompt, schema)
|
||||
|
||||
@staticmethod
|
||||
def _create_belief_schema(belief: SemanticBelief) -> tuple[str, dict]:
|
||||
return AgentSpeakGenerator.slugify(belief), {
|
||||
"type": ["boolean", "null"],
|
||||
"description": belief.description,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _create_beliefs_schema(beliefs: list[SemanticBelief]) -> dict:
|
||||
belief_schemas = [
|
||||
SemanticBeliefInferrer._create_belief_schema(belief) for belief in beliefs
|
||||
]
|
||||
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": dict(belief_schemas),
|
||||
"required": [name for name, _ in belief_schemas],
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _format_message(message: ChatMessage):
|
||||
return f"{message.role.upper()}:\n{message.content}"
|
||||
|
||||
@staticmethod
|
||||
def _format_conversation(conversation: ChatHistory):
|
||||
return "\n\n".join(
|
||||
[SemanticBeliefInferrer._format_message(message) for message in conversation.messages]
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _format_beliefs(beliefs: list[SemanticBelief]):
|
||||
return "\n".join(
|
||||
[f"- {AgentSpeakGenerator.slugify(belief)}: {belief.description}" for belief in beliefs]
|
||||
)
|
||||
|
||||
|
||||
class GoalAchievementInferrer(SemanticBeliefInferrer):
|
||||
"""
|
||||
Infers whether specific conversational goals have been achieved using an LLM.
|
||||
"""
|
||||
|
||||
def __init__(self, llm: TextBeliefExtractorAgent.LLM):
|
||||
super().__init__(llm)
|
||||
self.goals: set[BaseGoal] = set()
|
||||
|
||||
async def infer_from_conversation(self, conversation: ChatHistory) -> dict[str, bool]:
|
||||
"""
|
||||
Determine which goals have been achieved based on the given conversation.
|
||||
|
||||
:param conversation: The conversation to infer goal completion from.
|
||||
:return: A mapping of goals and a boolean whether they have been achieved.
|
||||
"""
|
||||
if not self.goals:
|
||||
return {}
|
||||
|
||||
goals_achieved = await asyncio.gather(
|
||||
*[self._infer_goal(conversation, g) for g in self.goals]
|
||||
)
|
||||
return {
|
||||
f"achieved_{AgentSpeakGenerator.slugify(goal)}": achieved
|
||||
for goal, achieved in zip(self.goals, goals_achieved, strict=True)
|
||||
}
|
||||
|
||||
async def _infer_goal(self, conversation: ChatHistory, goal: BaseGoal) -> bool:
|
||||
prompt = f"""{self._format_conversation(conversation)}
|
||||
|
||||
Given the above conversation, what has the following goal been achieved?
|
||||
|
||||
The name of the goal: {goal.name}
|
||||
Description of the goal: {goal.description}
|
||||
|
||||
Answer with literally only `true` or `false` (without backticks)."""
|
||||
|
||||
schema = {
|
||||
"type": "boolean",
|
||||
}
|
||||
|
||||
return await self._llm.query(prompt, schema)
|
||||
5
src/control_backend/agents/communication/__init__.py
Normal file
5
src/control_backend/agents/communication/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""
|
||||
Agents responsible for external communication and service discovery.
|
||||
"""
|
||||
|
||||
from .ri_communication_agent import RICommunicationAgent as RICommunicationAgent
|
||||
@@ -0,0 +1,330 @@
|
||||
import asyncio
|
||||
import json
|
||||
|
||||
import zmq
|
||||
import zmq.asyncio as azmq
|
||||
from pydantic import ValidationError
|
||||
from zmq.asyncio import Context
|
||||
|
||||
from control_backend.agents import BaseAgent
|
||||
from control_backend.agents.actuation.robot_gesture_agent import RobotGestureAgent
|
||||
from control_backend.core.config import settings
|
||||
from control_backend.schemas.internal_message import InternalMessage
|
||||
from control_backend.schemas.ri_message import PauseCommand
|
||||
|
||||
from ..actuation.robot_speech_agent import RobotSpeechAgent
|
||||
from ..perception import VADAgent
|
||||
|
||||
|
||||
class RICommunicationAgent(BaseAgent):
|
||||
"""
|
||||
Robot Interface (RI) Communication Agent.
|
||||
|
||||
This agent manages the high-level connection negotiation and health checking (heartbeat)
|
||||
between the Control Backend and the Robot Interface (or UI).
|
||||
|
||||
It acts as a service discovery mechanism:
|
||||
1. It initiates a handshake (negotiation) to discover where other services (like the robot
|
||||
command listener) are listening.
|
||||
2. It spawns specific agents
|
||||
(like :class:`~control_backend.agents.actuation.robot_speech_agent.RobotSpeechAgent`)
|
||||
once the connection details are established.
|
||||
3. It maintains a "ping" loop to ensure the connection remains active.
|
||||
|
||||
:ivar _address: The ZMQ address to attempt the initial connection negotiation.
|
||||
:ivar _bind: Whether to bind or connect the negotiation socket.
|
||||
:ivar _req_socket: ZMQ REQ socket for negotiation and pings.
|
||||
:ivar pub_socket: ZMQ PUB socket for internal notifications (e.g., ping status).
|
||||
:ivar connected: Boolean flag indicating active connection status.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
address=settings.zmq_settings.ri_communication_address,
|
||||
bind=False,
|
||||
):
|
||||
super().__init__(name)
|
||||
self._address = address
|
||||
self._bind = bind
|
||||
self._req_socket: azmq.Socket | None = None
|
||||
self.pub_socket: azmq.Socket | None = None
|
||||
self.connected = False
|
||||
self.gesture_agent: RobotGestureAgent | None = None
|
||||
self.speech_agent: RobotSpeechAgent | None = None
|
||||
|
||||
async def setup(self):
|
||||
"""
|
||||
Initialize the agent and attempt connection.
|
||||
|
||||
Tries to negotiate connection up to ``behaviour_settings.comm_setup_max_retries`` times.
|
||||
If successful, starts the :meth:`_listen_loop`.
|
||||
"""
|
||||
self.logger.info("Setting up %s", self.name)
|
||||
|
||||
# Bind request socket
|
||||
await self._setup_sockets()
|
||||
|
||||
if await self._negotiate_connection():
|
||||
self.connected = True
|
||||
self.add_behavior(self._listen_loop())
|
||||
else:
|
||||
self.logger.warning("Failed to negotiate connection during setup.")
|
||||
|
||||
self.logger.info("Finished setting up %s", self.name)
|
||||
|
||||
async def _setup_sockets(self, force=False):
|
||||
"""
|
||||
Initialize ZMQ sockets (REQ for negotiation, PUB for internal updates).
|
||||
"""
|
||||
# Bind request socket
|
||||
if self._req_socket is None or force:
|
||||
self._req_socket = Context.instance().socket(zmq.REQ)
|
||||
if self._bind:
|
||||
self._req_socket.bind(self._address)
|
||||
else:
|
||||
self._req_socket.connect(self._address)
|
||||
|
||||
if self.pub_socket is None or force:
|
||||
self.pub_socket = Context.instance().socket(zmq.PUB)
|
||||
self.pub_socket.connect(settings.zmq_settings.internal_pub_address)
|
||||
|
||||
async def _negotiate_connection(
|
||||
self, max_retries: int = settings.behaviour_settings.comm_setup_max_retries
|
||||
):
|
||||
"""
|
||||
Perform the handshake protocol with the Robot Interface.
|
||||
|
||||
Sends a ``negotiate/ports`` request and expects a configuration response containing
|
||||
port assignments for various services (e.g., actuation).
|
||||
|
||||
:param max_retries: Number of attempts before giving up.
|
||||
:return: True if negotiation succeeded, False otherwise.
|
||||
"""
|
||||
retries = 0
|
||||
while retries < max_retries:
|
||||
if self._req_socket is None:
|
||||
retries += 1
|
||||
continue
|
||||
|
||||
# Send our message and receive one back
|
||||
message = {"endpoint": "negotiate/ports", "data": {}}
|
||||
await self._req_socket.send_json(message)
|
||||
|
||||
retry_frequency = 1.0
|
||||
try:
|
||||
received_message = await asyncio.wait_for(
|
||||
self._req_socket.recv_json(), timeout=retry_frequency
|
||||
)
|
||||
except TimeoutError:
|
||||
self.logger.warning(
|
||||
"No connection established in %d seconds (attempt %d/%d)",
|
||||
retries * retry_frequency,
|
||||
retries + 1,
|
||||
max_retries,
|
||||
)
|
||||
retries += 1
|
||||
continue
|
||||
except Exception as e:
|
||||
self.logger.warning("Unexpected error during negotiation: %s", e)
|
||||
retries += 1
|
||||
continue
|
||||
|
||||
# Validate endpoint
|
||||
endpoint = received_message.get("endpoint")
|
||||
if endpoint != "negotiate/ports":
|
||||
self.logger.warning(
|
||||
"Invalid endpoint '%s' received (attempt %d/%d)",
|
||||
endpoint,
|
||||
retries + 1,
|
||||
max_retries,
|
||||
)
|
||||
retries += 1
|
||||
await asyncio.sleep(1)
|
||||
continue
|
||||
|
||||
# At this point, we have a valid response
|
||||
try:
|
||||
self.logger.debug("Negotiation successful.")
|
||||
await self._handle_negotiation_response(received_message)
|
||||
# Let UI know that we're connected
|
||||
topic = b"ping"
|
||||
data = json.dumps(True).encode()
|
||||
if self.pub_socket:
|
||||
await self.pub_socket.send_multipart([topic, data])
|
||||
return True
|
||||
except Exception as e:
|
||||
self.logger.warning("Error unpacking negotiation data: %s", e)
|
||||
retries += 1
|
||||
await asyncio.sleep(settings.behaviour_settings.sleep_s)
|
||||
continue
|
||||
|
||||
return False
|
||||
|
||||
async def _handle_negotiation_response(self, received_message):
|
||||
"""
|
||||
Parse the negotiation response and initialize services.
|
||||
|
||||
Based on the response, it might re-connect the main socket or spawn new agents
|
||||
(e.g., for robot actuation).
|
||||
"""
|
||||
for port_data in received_message["data"]:
|
||||
id = port_data["id"]
|
||||
port = port_data["port"]
|
||||
bind = port_data["bind"]
|
||||
|
||||
if not bind:
|
||||
addr = f"tcp://{settings.ri_host}:{port}"
|
||||
else:
|
||||
addr = f"tcp://*:{port}"
|
||||
|
||||
match id:
|
||||
case "main":
|
||||
if addr != self._address:
|
||||
assert self._req_socket is not None
|
||||
if not bind:
|
||||
self._req_socket.connect(addr)
|
||||
else:
|
||||
self._req_socket.bind(addr)
|
||||
case "actuation":
|
||||
gesture_data = port_data.get("gestures", [])
|
||||
single_gesture_data = port_data.get("single_gestures", [])
|
||||
robot_speech_agent = RobotSpeechAgent(
|
||||
settings.agent_settings.robot_speech_name,
|
||||
address=addr,
|
||||
bind=bind,
|
||||
)
|
||||
self.speech_agent = robot_speech_agent
|
||||
robot_gesture_agent = RobotGestureAgent(
|
||||
settings.agent_settings.robot_gesture_name,
|
||||
address=addr,
|
||||
bind=bind,
|
||||
gesture_data=gesture_data,
|
||||
single_gesture_data=single_gesture_data,
|
||||
)
|
||||
self.gesture_agent = robot_gesture_agent
|
||||
await robot_speech_agent.start()
|
||||
await asyncio.sleep(0.1) # Small delay
|
||||
await robot_gesture_agent.start()
|
||||
case "audio":
|
||||
vad_agent = VADAgent(audio_in_address=addr, audio_in_bind=bind)
|
||||
await vad_agent.start()
|
||||
case _:
|
||||
self.logger.warning("Unhandled negotiation id: %s", id)
|
||||
|
||||
async def stop(self):
|
||||
"""
|
||||
Closes all sockets.
|
||||
:return:
|
||||
"""
|
||||
if self._req_socket:
|
||||
self._req_socket.close()
|
||||
if self.pub_socket:
|
||||
self.pub_socket.close()
|
||||
await super().stop()
|
||||
|
||||
async def _listen_loop(self):
|
||||
"""
|
||||
Maintain the connection via a heartbeat (ping) loop.
|
||||
|
||||
Sends a ``ping`` request periodically and waits for a reply.
|
||||
If pings fail repeatedly, it triggers a disconnection handler to restart negotiation.
|
||||
"""
|
||||
while self._running:
|
||||
if not self.connected:
|
||||
await asyncio.sleep(settings.behaviour_settings.sleep_s)
|
||||
self.logger.debug("Not connected, skipping ping loop iteration.")
|
||||
continue
|
||||
|
||||
# We need to listen and send pings.
|
||||
message = {"endpoint": "ping", "data": {"id": "e.g. some reference id"}}
|
||||
seconds_to_wait_total = settings.behaviour_settings.sleep_s
|
||||
try:
|
||||
assert self._req_socket is not None
|
||||
await asyncio.wait_for(
|
||||
self._req_socket.send_json(message), timeout=seconds_to_wait_total / 2
|
||||
)
|
||||
except TimeoutError:
|
||||
self.logger.debug(
|
||||
"Waited too long to send message - "
|
||||
"we probably dont have any receivers... but let's check!"
|
||||
)
|
||||
|
||||
# Wait up to {seconds_to_wait_total/2} seconds for a reply
|
||||
try:
|
||||
assert self._req_socket is not None
|
||||
message = await asyncio.wait_for(
|
||||
self._req_socket.recv_json(), timeout=seconds_to_wait_total / 2
|
||||
)
|
||||
|
||||
if "endpoint" in message and message["endpoint"] != "ping":
|
||||
self.logger.debug(f'Received message "{message}" from RI.')
|
||||
if "endpoint" not in message:
|
||||
self.logger.warning("No received endpoint in message, expected ping endpoint.")
|
||||
continue
|
||||
|
||||
# See what endpoint we received
|
||||
match message["endpoint"]:
|
||||
case "ping":
|
||||
topic = b"ping"
|
||||
data = json.dumps(True).encode()
|
||||
if self.pub_socket is not None:
|
||||
await self.pub_socket.send_multipart([topic, data])
|
||||
await asyncio.sleep(settings.behaviour_settings.sleep_s)
|
||||
case _:
|
||||
self.logger.debug(
|
||||
"Received message with topic different than ping, while ping expected."
|
||||
)
|
||||
# We didnt get a reply
|
||||
except TimeoutError:
|
||||
self.logger.info(
|
||||
f"No ping retrieved in {seconds_to_wait_total} seconds, "
|
||||
"sending UI disconnection event and attempting to restart."
|
||||
)
|
||||
await self._handle_disconnection()
|
||||
continue
|
||||
except Exception:
|
||||
self.logger.error("Error while waiting for ping message.", exc_info=True)
|
||||
raise
|
||||
|
||||
async def _handle_disconnection(self):
|
||||
"""
|
||||
Handle connection loss.
|
||||
|
||||
Notifies the UI of disconnection (via internal PUB) and attempts to restart negotiation.
|
||||
"""
|
||||
self.connected = False
|
||||
|
||||
# Tell UI we're disconnected.
|
||||
topic = b"ping"
|
||||
data = json.dumps(False).encode()
|
||||
self.logger.debug("1")
|
||||
if self.pub_socket:
|
||||
try:
|
||||
self.logger.debug("2")
|
||||
await asyncio.wait_for(self.pub_socket.send_multipart([topic, data]), 5)
|
||||
except TimeoutError:
|
||||
self.logger.debug("3")
|
||||
self.logger.warning("Connection ping for router timed out.")
|
||||
|
||||
# Try to reboot/renegotiate
|
||||
if self.gesture_agent is not None:
|
||||
await self.gesture_agent.stop()
|
||||
|
||||
if self.speech_agent is not None:
|
||||
await self.speech_agent.stop()
|
||||
|
||||
if self.pub_socket is not None:
|
||||
self.pub_socket.close()
|
||||
|
||||
self.logger.debug("Restarting communication negotiation.")
|
||||
if await self._negotiate_connection(max_retries=2):
|
||||
self.connected = True
|
||||
|
||||
async def handle_message(self, msg: InternalMessage):
|
||||
try:
|
||||
pause_command = PauseCommand.model_validate_json(msg.body)
|
||||
await self._req_socket.send_json(pause_command.model_dump())
|
||||
self.logger.debug(await self._req_socket.recv_json())
|
||||
except ValidationError:
|
||||
self.logger.warning("Incorrect message format for PauseCommand.")
|
||||
5
src/control_backend/agents/llm/__init__.py
Normal file
5
src/control_backend/agents/llm/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""
|
||||
Agents that interface with Large Language Models for natural language processing and generation.
|
||||
"""
|
||||
|
||||
from .llm_agent import LLMAgent as LLMAgent
|
||||
252
src/control_backend/agents/llm/llm_agent.py
Normal file
252
src/control_backend/agents/llm/llm_agent.py
Normal file
@@ -0,0 +1,252 @@
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
import uuid
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
import httpx
|
||||
from pydantic import ValidationError
|
||||
|
||||
from control_backend.agents import BaseAgent
|
||||
from control_backend.core.agent_system import InternalMessage
|
||||
from control_backend.core.config import settings
|
||||
|
||||
from ...schemas.llm_prompt_message import LLMPromptMessage
|
||||
from .llm_instructions import LLMInstructions
|
||||
|
||||
experiment_logger = logging.getLogger(settings.logging_settings.experiment_logger_name)
|
||||
|
||||
|
||||
class LLMAgent(BaseAgent):
|
||||
"""
|
||||
LLM Agent.
|
||||
|
||||
This agent is responsible for processing user text input and querying a locally
|
||||
hosted LLM for text generation. It acts as the conversational brain of the system.
|
||||
|
||||
It receives :class:`~control_backend.schemas.llm_prompt_message.LLMPromptMessage`
|
||||
payloads from the BDI Core Agent, constructs a conversation history, queries the
|
||||
LLM via HTTP, and streams the response back to the BDI agent in natural chunks
|
||||
(e.g., sentence by sentence).
|
||||
|
||||
:ivar history: A list of dictionaries representing the conversation history (Role/Content).
|
||||
"""
|
||||
|
||||
def __init__(self, name: str):
|
||||
super().__init__(name)
|
||||
self.history = []
|
||||
self._querying = False
|
||||
self._interrupted = False
|
||||
self._interrupted_message = ""
|
||||
self._go_ahead = asyncio.Event()
|
||||
|
||||
async def setup(self):
|
||||
self.logger.info("Setting up %s.", self.name)
|
||||
|
||||
async def handle_message(self, msg: InternalMessage):
|
||||
"""
|
||||
Handle incoming messages.
|
||||
|
||||
Expects messages from :attr:`settings.agent_settings.bdi_core_name` containing
|
||||
an :class:`LLMPromptMessage` in the body.
|
||||
|
||||
:param msg: The received internal message.
|
||||
"""
|
||||
if msg.sender == settings.agent_settings.bdi_core_name:
|
||||
match msg.thread:
|
||||
case "prompt_message":
|
||||
try:
|
||||
prompt_message = LLMPromptMessage.model_validate_json(msg.body)
|
||||
self.add_behavior(self._process_bdi_message(prompt_message)) # no block
|
||||
except ValidationError:
|
||||
self.logger.debug("Prompt message from BDI core is invalid.")
|
||||
case "assistant_message":
|
||||
self._apply_conversation_message({"role": "assistant", "content": msg.body})
|
||||
case "user_message":
|
||||
self._apply_conversation_message({"role": "user", "content": msg.body})
|
||||
elif msg.sender == settings.agent_settings.bdi_program_manager_name:
|
||||
if msg.body == "clear_history":
|
||||
self.logger.debug("Clearing conversation history.")
|
||||
self.history.clear()
|
||||
else:
|
||||
self.logger.debug("Message ignored.")
|
||||
|
||||
async def _process_bdi_message(self, message: LLMPromptMessage):
|
||||
"""
|
||||
Orchestrate the LLM query and response streaming.
|
||||
|
||||
Iterates over the chunks yielded by :meth:`_query_llm` and forwards them
|
||||
individually to the BDI agent via :meth:`_send_reply`.
|
||||
|
||||
:param message: The parsed prompt message containing text, norms, and goals.
|
||||
"""
|
||||
if self._querying:
|
||||
self.logger.debug("Received another BDI prompt while processing previous message.")
|
||||
self._interrupted = True # interrupt the previous processing
|
||||
await self._go_ahead.wait() # wait until we get the go-ahead
|
||||
|
||||
message.text = f"{self._interrupted_message} {message.text}"
|
||||
|
||||
self._go_ahead.clear()
|
||||
self._querying = True
|
||||
full_message = ""
|
||||
async for chunk in self._query_llm(message.text, message.norms, message.goals):
|
||||
if self._interrupted:
|
||||
self._interrupted_message = message.text
|
||||
self.logger.debug("Interrupted processing of previous message.")
|
||||
break
|
||||
await self._send_reply(chunk)
|
||||
full_message += chunk
|
||||
else:
|
||||
self._querying = False
|
||||
|
||||
self._apply_conversation_message(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": full_message,
|
||||
}
|
||||
)
|
||||
self.logger.debug(
|
||||
"Finished processing BDI message. Response sent in chunks to BDI core."
|
||||
)
|
||||
await self._send_full_reply(full_message)
|
||||
|
||||
self._go_ahead.set()
|
||||
self._interrupted = False
|
||||
|
||||
def _apply_conversation_message(self, message: dict[str, str]):
|
||||
if len(self.history) > 0 and message["role"] == self.history[-1]["role"]:
|
||||
self.history[-1]["content"] += " " + message["content"]
|
||||
return
|
||||
self.history.append(message)
|
||||
|
||||
async def _send_reply(self, msg: str):
|
||||
"""
|
||||
Sends a response message (chunk) back to the BDI Core Agent.
|
||||
|
||||
:param msg: The text content of the chunk.
|
||||
"""
|
||||
reply = InternalMessage(
|
||||
to=settings.agent_settings.bdi_core_name,
|
||||
sender=self.name,
|
||||
body=msg,
|
||||
)
|
||||
await self.send(reply)
|
||||
|
||||
async def _send_full_reply(self, msg: str):
|
||||
"""
|
||||
Sends a response message (full) to agents that need it.
|
||||
|
||||
:param msg: The text content of the message.
|
||||
"""
|
||||
message = InternalMessage(
|
||||
to=settings.agent_settings.text_belief_extractor_name,
|
||||
sender=self.name,
|
||||
body=msg,
|
||||
)
|
||||
await self.send(message)
|
||||
|
||||
async def _query_llm(
|
||||
self, prompt: str, norms: list[str], goals: list[str]
|
||||
) -> AsyncGenerator[str]:
|
||||
"""
|
||||
Send a chat completion request to the local LLM service and stream the response.
|
||||
|
||||
It constructs the full prompt using
|
||||
:class:`~control_backend.agents.llm.llm_instructions.LLMInstructions`.
|
||||
It streams the response from the LLM and buffers tokens until a natural break (punctuation)
|
||||
is reached, then yields the chunk. This ensures that the robot speaks in complete phrases
|
||||
rather than individual tokens.
|
||||
|
||||
:param prompt: Input text prompt to pass to the LLM.
|
||||
:param norms: Norms the LLM should hold itself to.
|
||||
:param goals: Goals the LLM should achieve.
|
||||
:yield: Fragments of the LLM-generated content (e.g., sentences/phrases).
|
||||
"""
|
||||
instructions = LLMInstructions(norms if norms else None, goals if goals else None)
|
||||
messages = [
|
||||
{
|
||||
"role": "developer",
|
||||
"content": instructions.build_developer_instruction(),
|
||||
},
|
||||
*self.history,
|
||||
]
|
||||
|
||||
message_id = str(uuid.uuid4())
|
||||
|
||||
try:
|
||||
full_message = ""
|
||||
current_chunk = ""
|
||||
async for token in self._stream_query_llm(messages):
|
||||
full_message += token
|
||||
current_chunk += token
|
||||
|
||||
experiment_logger.chat(
|
||||
full_message,
|
||||
extra={"role": "assistant", "reference": message_id, "partial": True},
|
||||
)
|
||||
|
||||
# Stream the message in chunks separated by punctuation.
|
||||
# We include the delimiter in the emitted chunk for natural flow.
|
||||
pattern = re.compile(r".*?(?:,|;|:|—|–|\.{3}|…|\.|\?|!)\s*", re.DOTALL)
|
||||
for m in pattern.finditer(current_chunk):
|
||||
chunk = m.group(0)
|
||||
if chunk:
|
||||
yield current_chunk
|
||||
current_chunk = ""
|
||||
|
||||
# Yield any remaining tail
|
||||
if current_chunk:
|
||||
yield current_chunk
|
||||
|
||||
experiment_logger.chat(
|
||||
full_message,
|
||||
extra={"role": "assistant", "reference": message_id, "partial": False},
|
||||
)
|
||||
except httpx.HTTPError as err:
|
||||
self.logger.error("HTTP error.", exc_info=err)
|
||||
yield "LLM service unavailable."
|
||||
except Exception as err:
|
||||
self.logger.error("Unexpected error.", exc_info=err)
|
||||
yield "Error processing the request."
|
||||
|
||||
async def _stream_query_llm(self, messages) -> AsyncGenerator[str]:
|
||||
"""
|
||||
Perform the raw HTTP streaming request to the LLM API.
|
||||
|
||||
:param messages: The list of message dictionaries (role/content).
|
||||
:yield: Raw text tokens (deltas) from the SSE stream.
|
||||
:raises httpx.HTTPError: If the API returns a non-200 status.
|
||||
"""
|
||||
async with httpx.AsyncClient() as client:
|
||||
async with client.stream(
|
||||
"POST",
|
||||
settings.llm_settings.local_llm_url,
|
||||
headers={"Authorization": f"Bearer {settings.llm_settings.api_key}"}
|
||||
if settings.llm_settings.api_key
|
||||
else {},
|
||||
json={
|
||||
"model": settings.llm_settings.local_llm_model,
|
||||
"messages": messages,
|
||||
"temperature": settings.llm_settings.chat_temperature,
|
||||
"stream": True,
|
||||
},
|
||||
) as response:
|
||||
response.raise_for_status()
|
||||
|
||||
async for line in response.aiter_lines():
|
||||
if not line or not line.startswith("data: "):
|
||||
continue
|
||||
|
||||
data = line[len("data: ") :]
|
||||
if data.strip() == "[DONE]":
|
||||
break
|
||||
|
||||
try:
|
||||
event = json.loads(data)
|
||||
delta = event.get("choices", [{}])[0].get("delta", {}).get("content")
|
||||
if delta:
|
||||
yield delta
|
||||
except json.JSONDecodeError:
|
||||
self.logger.error("Failed to parse LLM response: %s", data)
|
||||
63
src/control_backend/agents/llm/llm_instructions.py
Normal file
63
src/control_backend/agents/llm/llm_instructions.py
Normal file
@@ -0,0 +1,63 @@
|
||||
class LLMInstructions:
|
||||
"""
|
||||
Helper class to construct the system instructions for the LLM.
|
||||
|
||||
It combines the base persona (Pepper robot) with dynamic norms and goals
|
||||
provided by the BDI system.
|
||||
|
||||
If no norms/goals are given it assumes empty lists.
|
||||
|
||||
:ivar norms: A list of behavioral norms.
|
||||
:ivar goals: A list of specific conversational goals.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def default_norms() -> list[str]:
|
||||
return [
|
||||
"Be friendly and respectful.",
|
||||
"Make the conversation feel natural and engaging.",
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def default_goals() -> list[str]:
|
||||
return [
|
||||
"Try to learn the user's name during conversation.",
|
||||
]
|
||||
|
||||
def __init__(self, norms: list[str] | None = None, goals: list[str] | None = None):
|
||||
self.norms = norms or self.default_norms()
|
||||
self.goals = goals or self.default_goals()
|
||||
|
||||
def build_developer_instruction(self) -> str:
|
||||
"""
|
||||
Builds the final system prompt string.
|
||||
|
||||
The prompt includes:
|
||||
1. Persona definition.
|
||||
2. Constraint on response length.
|
||||
3. Instructions on how to handle goals (reach them in order, but prioritize natural flow).
|
||||
4. The specific list of norms.
|
||||
5. The specific list of goals.
|
||||
|
||||
:return: The formatted system prompt string.
|
||||
"""
|
||||
sections = [
|
||||
"You are a Pepper robot engaging in natural human conversation.",
|
||||
"Keep responses between 1–3 sentences, unless told otherwise.\n",
|
||||
"You're given goals to reach. Reach them in order, but make the conversation feel "
|
||||
"natural. Some turns you should not try to achieve your goals.\n",
|
||||
]
|
||||
|
||||
if self.norms:
|
||||
sections.append("Norms to follow:")
|
||||
for norm in self.norms:
|
||||
sections.append("- " + norm)
|
||||
sections.append("")
|
||||
|
||||
if self.goals:
|
||||
sections.append("Goals to reach:")
|
||||
for goal in self.goals:
|
||||
sections.append("- " + goal)
|
||||
sections.append("")
|
||||
|
||||
return "\n".join(sections).strip()
|
||||
9
src/control_backend/agents/perception/__init__.py
Normal file
9
src/control_backend/agents/perception/__init__.py
Normal file
@@ -0,0 +1,9 @@
|
||||
"""
|
||||
Agents responsible for processing sensory input, such as audio transcription and voice activity
|
||||
detection.
|
||||
"""
|
||||
|
||||
from .transcription_agent.transcription_agent import (
|
||||
TranscriptionAgent as TranscriptionAgent,
|
||||
)
|
||||
from .vad_agent import VADAgent as VADAgent
|
||||
@@ -0,0 +1,150 @@
|
||||
import abc
|
||||
import sys
|
||||
|
||||
if sys.platform == "darwin":
|
||||
import mlx.core as mx
|
||||
import mlx_whisper
|
||||
from mlx_whisper.transcribe import ModelHolder
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import whisper
|
||||
|
||||
from control_backend.core.config import settings
|
||||
|
||||
|
||||
class SpeechRecognizer(abc.ABC):
|
||||
"""
|
||||
Abstract base class for speech recognition backends.
|
||||
|
||||
Provides a common interface for loading models and transcribing audio,
|
||||
as well as heuristics for estimating token counts to optimize decoding.
|
||||
|
||||
:ivar limit_output_length: If True, limits the generated text length based on audio duration.
|
||||
"""
|
||||
|
||||
def __init__(self, limit_output_length=True):
|
||||
"""
|
||||
:param limit_output_length: When ``True``, the length of the generated speech will be
|
||||
limited by the length of the input audio and some heuristics.
|
||||
"""
|
||||
self.limit_output_length = limit_output_length
|
||||
|
||||
@abc.abstractmethod
|
||||
def load_model(self):
|
||||
"""
|
||||
Load the speech recognition model into memory.
|
||||
"""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
def recognize_speech(self, audio: np.ndarray) -> str:
|
||||
"""
|
||||
Recognize speech from the given audio sample.
|
||||
|
||||
:param audio: A full utterance sample. Audio must be 16 kHz, mono, np.float32, values in the
|
||||
range [-1.0, 1.0].
|
||||
:return: The recognized speech text.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def _estimate_max_tokens(audio: np.ndarray) -> int:
|
||||
"""
|
||||
Estimate the maximum length of a given audio sample in tokens.
|
||||
|
||||
Assumes a maximum speaking rate of 450 words per minute (3x average), and assumes that
|
||||
3 words is approx. 4 tokens.
|
||||
|
||||
:param audio: The audio sample (16 kHz) to use for length estimation.
|
||||
:return: The estimated length of the transcribed audio in tokens.
|
||||
"""
|
||||
length_seconds = len(audio) / settings.vad_settings.sample_rate_hz
|
||||
length_minutes = length_seconds / 60
|
||||
word_count = length_minutes * settings.behaviour_settings.transcription_words_per_minute
|
||||
token_count = word_count / settings.behaviour_settings.transcription_words_per_token
|
||||
return int(token_count) + settings.behaviour_settings.transcription_token_buffer
|
||||
|
||||
def _get_decode_options(self, audio: np.ndarray) -> dict:
|
||||
"""
|
||||
Construct decoding options for the Whisper model.
|
||||
|
||||
:param audio: The audio sample (16 kHz) to use to determine options like max decode length.
|
||||
:return: A dict that can be used to construct ``whisper.DecodingOptions`` (or equivalent).
|
||||
"""
|
||||
options = {}
|
||||
if self.limit_output_length:
|
||||
options["sample_len"] = self._estimate_max_tokens(audio)
|
||||
return options
|
||||
|
||||
@staticmethod
|
||||
def best_type():
|
||||
"""
|
||||
Factory method to get the best available `SpeechRecognizer`.
|
||||
|
||||
:return: An instance of :class:`MLXWhisperSpeechRecognizer` if on macOS with Apple Silicon,
|
||||
otherwise :class:`OpenAIWhisperSpeechRecognizer`.
|
||||
"""
|
||||
if torch.mps.is_available():
|
||||
print("Choosing MLX Whisper model.")
|
||||
return MLXWhisperSpeechRecognizer()
|
||||
else:
|
||||
print("Choosing reference Whisper model.")
|
||||
return OpenAIWhisperSpeechRecognizer()
|
||||
|
||||
|
||||
class MLXWhisperSpeechRecognizer(SpeechRecognizer):
|
||||
"""
|
||||
Speech recognizer using the MLX framework (optimized for Apple Silicon).
|
||||
"""
|
||||
|
||||
def __init__(self, limit_output_length=True):
|
||||
super().__init__(limit_output_length)
|
||||
self.was_loaded = False
|
||||
self.model_name = settings.speech_model_settings.mlx_model_name
|
||||
|
||||
def load_model(self):
|
||||
"""
|
||||
Ensures the model is downloaded and cached. MLX loads dynamically, so this
|
||||
pre-fetches the model.
|
||||
"""
|
||||
if self.was_loaded:
|
||||
return
|
||||
# There appears to be no dedicated mechanism to preload a model, but this `get_model` does
|
||||
# store it in memory for later usage
|
||||
ModelHolder.get_model(self.model_name, mx.float16)
|
||||
self.was_loaded = True
|
||||
|
||||
def recognize_speech(self, audio: np.ndarray) -> str:
|
||||
self.load_model()
|
||||
return mlx_whisper.transcribe(
|
||||
audio,
|
||||
path_or_hf_repo=self.model_name,
|
||||
**self._get_decode_options(audio),
|
||||
)["text"].strip()
|
||||
|
||||
|
||||
class OpenAIWhisperSpeechRecognizer(SpeechRecognizer):
|
||||
"""
|
||||
Speech recognizer using the standard OpenAI Whisper library (PyTorch).
|
||||
"""
|
||||
|
||||
def __init__(self, limit_output_length=True):
|
||||
super().__init__(limit_output_length)
|
||||
self.model = None
|
||||
|
||||
def load_model(self):
|
||||
"""
|
||||
Loads the OpenAI Whisper model onto the available device (CUDA or CPU).
|
||||
"""
|
||||
if self.model is not None:
|
||||
return
|
||||
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||
self.model = whisper.load_model(
|
||||
settings.speech_model_settings.openai_model_name, device=device
|
||||
)
|
||||
|
||||
def recognize_speech(self, audio: np.ndarray) -> str:
|
||||
self.load_model()
|
||||
return whisper.transcribe(self.model, audio, **self._get_decode_options(audio))[
|
||||
"text"
|
||||
].strip()
|
||||
@@ -0,0 +1,148 @@
|
||||
import asyncio
|
||||
import logging
|
||||
|
||||
import numpy as np
|
||||
import zmq
|
||||
import zmq.asyncio as azmq
|
||||
|
||||
from control_backend.agents import BaseAgent
|
||||
from control_backend.core.agent_system import InternalMessage
|
||||
from control_backend.core.config import settings
|
||||
|
||||
from .speech_recognizer import SpeechRecognizer
|
||||
|
||||
experiment_logger = logging.getLogger(settings.logging_settings.experiment_logger_name)
|
||||
|
||||
|
||||
class TranscriptionAgent(BaseAgent):
|
||||
"""
|
||||
Transcription Agent.
|
||||
|
||||
This agent listens to audio fragments (containing speech) on a ZMQ SUB socket,
|
||||
transcribes them using the configured :class:`SpeechRecognizer`, and sends the
|
||||
resulting text to other agents (e.g., the Text Belief Extractor).
|
||||
|
||||
It uses an internal semaphore to limit the number of concurrent transcription tasks.
|
||||
|
||||
:ivar audio_in_address: The ZMQ address to receive audio from (usually from VAD Agent).
|
||||
:ivar audio_in_socket: The ZMQ SUB socket instance.
|
||||
:ivar speech_recognizer: The speech recognition engine instance.
|
||||
:ivar _concurrency: Semaphore to limit concurrent transcriptions.
|
||||
:ivar _current_speech_reference: The reference of the current user utterance, for synchronising
|
||||
experiment logs.
|
||||
"""
|
||||
|
||||
def __init__(self, audio_in_address: str):
|
||||
"""
|
||||
Initialize the Transcription Agent.
|
||||
|
||||
:param audio_in_address: The ZMQ address of the audio source (e.g., VAD output).
|
||||
"""
|
||||
super().__init__(settings.agent_settings.transcription_name)
|
||||
|
||||
self.audio_in_address = audio_in_address
|
||||
self.audio_in_socket: azmq.Socket | None = None
|
||||
self.speech_recognizer = None
|
||||
self._concurrency = None
|
||||
self._current_speech_reference: str | None = None
|
||||
|
||||
async def setup(self):
|
||||
"""
|
||||
Initialize the agent resources.
|
||||
|
||||
1. Connects to the audio input ZMQ socket.
|
||||
2. Initializes the :class:`SpeechRecognizer` (choosing the best available backend).
|
||||
3. Starts the background transcription loop.
|
||||
"""
|
||||
self.logger.info("Setting up %s", self.name)
|
||||
|
||||
self._connect_audio_in_socket()
|
||||
|
||||
# Initialize recognizer and semaphore
|
||||
max_concurrent_tasks = settings.behaviour_settings.transcription_max_concurrent_tasks
|
||||
self._concurrency = asyncio.Semaphore(max_concurrent_tasks)
|
||||
self.speech_recognizer = SpeechRecognizer.best_type()
|
||||
self.speech_recognizer.load_model() # Warmup
|
||||
|
||||
# Start background loop
|
||||
self.add_behavior(self._transcribing_loop())
|
||||
|
||||
self.logger.info("Finished setting up %s", self.name)
|
||||
|
||||
async def handle_message(self, msg: InternalMessage):
|
||||
if msg.thread == "voice_activity":
|
||||
self._current_speech_reference = msg.body
|
||||
|
||||
async def stop(self):
|
||||
"""
|
||||
Stop the agent and close sockets.
|
||||
"""
|
||||
assert self.audio_in_socket is not None
|
||||
self.audio_in_socket.close()
|
||||
self.audio_in_socket = None
|
||||
return await super().stop()
|
||||
|
||||
def _connect_audio_in_socket(self):
|
||||
"""
|
||||
Connects the ZMQ SUB socket for receiving audio data.
|
||||
"""
|
||||
self.audio_in_socket = azmq.Context.instance().socket(zmq.SUB)
|
||||
self.audio_in_socket.setsockopt_string(zmq.SUBSCRIBE, "")
|
||||
self.audio_in_socket.connect(self.audio_in_address)
|
||||
|
||||
async def _transcribe(self, audio: np.ndarray) -> str:
|
||||
"""
|
||||
Run the speech recognition on the audio data.
|
||||
|
||||
This runs in a separate thread (via `asyncio.to_thread`) to avoid blocking the event loop,
|
||||
constrained by the concurrency semaphore.
|
||||
|
||||
:param audio: The audio data as a numpy array.
|
||||
:return: The transcribed text string.
|
||||
"""
|
||||
assert self._concurrency is not None and self.speech_recognizer is not None
|
||||
async with self._concurrency:
|
||||
return await asyncio.to_thread(self.speech_recognizer.recognize_speech, audio)
|
||||
|
||||
async def _share_transcription(self, transcription: str):
|
||||
"""
|
||||
Share a transcription to the other agents that depend on it, and to experiment logs.
|
||||
|
||||
Currently sends to:
|
||||
- :attr:`settings.agent_settings.text_belief_extractor_name`
|
||||
- The UI via the experiment logger
|
||||
|
||||
:param transcription: The transcribed text.
|
||||
"""
|
||||
experiment_logger.chat(
|
||||
transcription,
|
||||
extra={"role": "user", "reference": self._current_speech_reference, "partial": False},
|
||||
)
|
||||
|
||||
message = InternalMessage(
|
||||
to=settings.agent_settings.text_belief_extractor_name,
|
||||
sender=self.name,
|
||||
body=transcription,
|
||||
)
|
||||
await self.send(message)
|
||||
|
||||
async def _transcribing_loop(self) -> None:
|
||||
"""
|
||||
The main loop for receiving audio and triggering transcription.
|
||||
|
||||
Receives audio chunks from ZMQ, decodes them to float32, and calls :meth:`_transcribe`.
|
||||
If speech is found, it calls :meth:`_share_transcription`.
|
||||
"""
|
||||
while self._running:
|
||||
try:
|
||||
assert self.audio_in_socket is not None
|
||||
audio_data = await self.audio_in_socket.recv()
|
||||
audio = np.frombuffer(audio_data, dtype=np.float32)
|
||||
speech = await self._transcribe(audio)
|
||||
if not speech:
|
||||
self.logger.debug("Nothing transcribed.")
|
||||
continue
|
||||
|
||||
await self._share_transcription(speech)
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error in transcription loop: {e}")
|
||||
315
src/control_backend/agents/perception/vad_agent.py
Normal file
315
src/control_backend/agents/perception/vad_agent.py
Normal file
@@ -0,0 +1,315 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import zmq
|
||||
import zmq.asyncio as azmq
|
||||
|
||||
from control_backend.agents import BaseAgent
|
||||
from control_backend.core.config import settings
|
||||
from control_backend.schemas.internal_message import InternalMessage
|
||||
|
||||
from ...schemas.program_status import PROGRAM_STATUS, ProgramStatus
|
||||
from .transcription_agent.transcription_agent import TranscriptionAgent
|
||||
|
||||
experiment_logger = logging.getLogger(settings.logging_settings.experiment_logger_name)
|
||||
|
||||
|
||||
class SocketPoller[T]:
|
||||
"""
|
||||
Convenience class for polling a socket for data with a timeout, persisting a zmq.Poller for
|
||||
multiple usages.
|
||||
|
||||
:param T: The type of data returned by the socket.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
socket: azmq.Socket,
|
||||
timeout_ms: int = settings.behaviour_settings.socket_poller_timeout_ms,
|
||||
):
|
||||
"""
|
||||
:param socket: The socket to poll and get data from.
|
||||
:param timeout_ms: A timeout in milliseconds to wait for data.
|
||||
"""
|
||||
self.socket = socket
|
||||
self.poller = azmq.Poller()
|
||||
self.poller.register(self.socket, zmq.POLLIN)
|
||||
self.timeout_ms = timeout_ms
|
||||
|
||||
async def poll(self, timeout_ms: int | None = None) -> T | None:
|
||||
"""
|
||||
Get data from the socket, or None if the timeout is reached.
|
||||
|
||||
:param timeout_ms: If given, the timeout. Otherwise, ``self.timeout_ms`` is used.
|
||||
:return: Data from the socket or None.
|
||||
"""
|
||||
timeout_ms = timeout_ms or self.timeout_ms
|
||||
socks = dict(await self.poller.poll(timeout_ms))
|
||||
if socks.get(self.socket) == zmq.POLLIN:
|
||||
return await self.socket.recv()
|
||||
return None
|
||||
|
||||
|
||||
class VADAgent(BaseAgent):
|
||||
"""
|
||||
Voice Activity Detection (VAD) Agent.
|
||||
|
||||
This agent:
|
||||
1. Receives an audio stream (via ZMQ).
|
||||
2. Processes the audio using the Silero VAD model to detect speech.
|
||||
3. Buffers potential speech segments.
|
||||
4. Publishes valid speech fragments (containing speech plus small buffer) to a ZMQ PUB socket.
|
||||
5. Instantiates and starts agents (like :class:`TranscriptionAgent`) that use this output.
|
||||
|
||||
:ivar audio_in_address: Address of the input audio stream.
|
||||
:ivar audio_in_bind: Whether to bind or connect to the input address.
|
||||
:ivar audio_out_socket: ZMQ PUB socket for sending speech fragments.
|
||||
:ivar program_sub_socket: ZMQ SUB socket for receiving program status updates.
|
||||
"""
|
||||
|
||||
def __init__(self, audio_in_address: str, audio_in_bind: bool):
|
||||
"""
|
||||
Initialize the VAD Agent.
|
||||
|
||||
:param audio_in_address: ZMQ address for input audio.
|
||||
:param audio_in_bind: True if this agent should bind to the input address, False to connect.
|
||||
"""
|
||||
super().__init__(settings.agent_settings.vad_name)
|
||||
|
||||
self.audio_in_address = audio_in_address
|
||||
self.audio_in_bind = audio_in_bind
|
||||
|
||||
self.audio_in_socket: azmq.Socket | None = None
|
||||
self.audio_out_socket: azmq.Socket | None = None
|
||||
self.audio_in_poller: SocketPoller | None = None
|
||||
|
||||
self.program_sub_socket: azmq.Socket | None = None
|
||||
|
||||
self.audio_buffer = np.array([], dtype=np.float32)
|
||||
self.i_since_speech = settings.behaviour_settings.vad_initial_since_speech
|
||||
self._ready = asyncio.Event()
|
||||
|
||||
# Pause control
|
||||
self._reset_needed = False
|
||||
self._paused = asyncio.Event()
|
||||
self._paused.set() # Not paused at start
|
||||
|
||||
self.model = None
|
||||
|
||||
async def setup(self):
|
||||
"""
|
||||
Initialize resources.
|
||||
|
||||
1. Connects audio input socket.
|
||||
2. Binds audio output socket (random port).
|
||||
3. Connects to program communication socket.
|
||||
4. Loads VAD model from Torch Hub.
|
||||
5. Starts the streaming loop.
|
||||
6. Instantiates and starts the :class:`TranscriptionAgent` with the output address.
|
||||
"""
|
||||
self.logger.info("Setting up %s", self.name)
|
||||
|
||||
self._connect_audio_in_socket()
|
||||
|
||||
audio_out_address = self._connect_audio_out_socket()
|
||||
if audio_out_address is None:
|
||||
self.logger.error("Could not bind output socket, stopping.")
|
||||
await self.stop()
|
||||
return
|
||||
|
||||
# Connect to internal communication socket
|
||||
self.program_sub_socket = azmq.Context.instance().socket(zmq.SUB)
|
||||
self.program_sub_socket.connect(settings.zmq_settings.internal_sub_address)
|
||||
self.program_sub_socket.subscribe(PROGRAM_STATUS)
|
||||
|
||||
# Initialize VAD model
|
||||
try:
|
||||
self.model, _ = torch.hub.load(
|
||||
repo_or_dir=settings.vad_settings.repo_or_dir,
|
||||
model=settings.vad_settings.model_name,
|
||||
force_reload=False,
|
||||
)
|
||||
except Exception:
|
||||
self.logger.exception("Failed to load VAD model.")
|
||||
await self.stop()
|
||||
return
|
||||
|
||||
self.add_behavior(self._streaming_loop())
|
||||
self.add_behavior(self._status_loop())
|
||||
|
||||
# Start agents dependent on the output audio fragments here
|
||||
transcriber = TranscriptionAgent(audio_out_address)
|
||||
await transcriber.start()
|
||||
|
||||
self.logger.info("Finished setting up %s", self.name)
|
||||
|
||||
async def stop(self):
|
||||
"""
|
||||
Stop listening to audio, stop publishing audio, close sockets.
|
||||
"""
|
||||
if self.audio_in_socket is not None:
|
||||
self.audio_in_socket.close()
|
||||
self.audio_in_socket = None
|
||||
if self.audio_out_socket is not None:
|
||||
self.audio_out_socket.close()
|
||||
self.audio_out_socket = None
|
||||
await super().stop()
|
||||
|
||||
def _connect_audio_in_socket(self):
|
||||
"""
|
||||
Connects (or binds) the socket for listening to audio from RI.
|
||||
:return:
|
||||
"""
|
||||
self.audio_in_socket = azmq.Context.instance().socket(zmq.SUB)
|
||||
self.audio_in_socket.setsockopt_string(zmq.SUBSCRIBE, "")
|
||||
if self.audio_in_bind:
|
||||
self.audio_in_socket.bind(self.audio_in_address)
|
||||
else:
|
||||
self.audio_in_socket.connect(self.audio_in_address)
|
||||
self.audio_in_poller = SocketPoller[bytes](self.audio_in_socket)
|
||||
|
||||
def _connect_audio_out_socket(self) -> str | None:
|
||||
"""
|
||||
Returns the address that was bound to, or None if binding failed.
|
||||
"""
|
||||
try:
|
||||
self.audio_out_socket = azmq.Context.instance().socket(zmq.PUB)
|
||||
self.audio_out_socket.bind(settings.zmq_settings.vad_pub_address)
|
||||
return settings.zmq_settings.vad_pub_address
|
||||
except zmq.ZMQBindError:
|
||||
self.logger.error("Failed to bind an audio output socket after 100 tries.")
|
||||
self.audio_out_socket = None
|
||||
return None
|
||||
|
||||
async def _reset_stream(self):
|
||||
"""
|
||||
Clears the ZeroMQ queue and sets ready state.
|
||||
"""
|
||||
discarded = 0
|
||||
assert self.audio_in_poller is not None
|
||||
while await self.audio_in_poller.poll(1) is not None:
|
||||
discarded += 1
|
||||
self.logger.info(f"Discarded {discarded} audio packets before starting.")
|
||||
self._ready.set()
|
||||
|
||||
async def _status_loop(self):
|
||||
"""Loop for checking program status. Only start listening if program is RUNNING."""
|
||||
while self._running:
|
||||
topic, body = await self.program_sub_socket.recv_multipart()
|
||||
|
||||
if topic != PROGRAM_STATUS:
|
||||
continue
|
||||
if body != ProgramStatus.RUNNING.value:
|
||||
continue
|
||||
|
||||
# Program is now running, we can start our stream
|
||||
await self._reset_stream()
|
||||
|
||||
# We don't care about further status updates
|
||||
self.program_sub_socket.close()
|
||||
break
|
||||
|
||||
async def _streaming_loop(self):
|
||||
"""
|
||||
Main loop for processing audio stream.
|
||||
|
||||
1. Polls for new audio chunks.
|
||||
2. Passes chunk to VAD model.
|
||||
3. Manages `i_since_speech` counter to determine start/end of speech.
|
||||
4. Buffers speech + context.
|
||||
5. Sends complete speech segment to output socket when silence is detected.
|
||||
"""
|
||||
await self._ready.wait()
|
||||
while self._running:
|
||||
await self._paused.wait()
|
||||
|
||||
# After being unpaused, reset stream and buffers
|
||||
if self._reset_needed:
|
||||
self.logger.debug("Resuming: resetting stream and buffers.")
|
||||
await self._reset_stream()
|
||||
self.audio_buffer = np.array([], dtype=np.float32)
|
||||
self.i_since_speech = settings.behaviour_settings.vad_initial_since_speech
|
||||
self._reset_needed = False
|
||||
|
||||
assert self.audio_in_poller is not None
|
||||
data = await self.audio_in_poller.poll()
|
||||
if data is None:
|
||||
if len(self.audio_buffer) > 0:
|
||||
self.logger.debug(
|
||||
"No audio data received. Discarding buffer until new data arrives."
|
||||
)
|
||||
self.audio_buffer = np.array([], dtype=np.float32)
|
||||
self.i_since_speech = settings.behaviour_settings.vad_initial_since_speech
|
||||
continue
|
||||
|
||||
# copy otherwise Torch will be sad that it's immutable
|
||||
chunk = np.frombuffer(data, dtype=np.float32).copy()
|
||||
assert self.model is not None
|
||||
prob = self.model(torch.from_numpy(chunk), settings.vad_settings.sample_rate_hz).item()
|
||||
non_speech_patience = settings.behaviour_settings.vad_non_speech_patience_chunks
|
||||
begin_silence_length = settings.behaviour_settings.vad_begin_silence_chunks
|
||||
prob_threshold = settings.behaviour_settings.vad_prob_threshold
|
||||
|
||||
if prob > prob_threshold:
|
||||
if self.i_since_speech > non_speech_patience + begin_silence_length:
|
||||
self.logger.debug("Speech started.")
|
||||
reference = str(uuid.uuid4())
|
||||
experiment_logger.chat(
|
||||
"...",
|
||||
extra={"role": "user", "reference": reference, "partial": True},
|
||||
)
|
||||
await self.send(
|
||||
InternalMessage(
|
||||
to=settings.agent_settings.transcription_name,
|
||||
body=reference,
|
||||
thread="voice_activity",
|
||||
)
|
||||
)
|
||||
self.audio_buffer = np.append(self.audio_buffer, chunk)
|
||||
self.i_since_speech = 0
|
||||
continue
|
||||
|
||||
self.i_since_speech += 1
|
||||
|
||||
# prob < threshold, so speech maybe ended. Wait a bit more before to be more certain
|
||||
if self.i_since_speech <= non_speech_patience:
|
||||
self.audio_buffer = np.append(self.audio_buffer, chunk)
|
||||
continue
|
||||
|
||||
# Speech probably ended. Make sure we have a usable amount of data.
|
||||
if len(self.audio_buffer) > begin_silence_length * len(chunk):
|
||||
self.logger.debug("Speech ended.")
|
||||
assert self.audio_out_socket is not None
|
||||
await self.audio_out_socket.send(self.audio_buffer[: -2 * len(chunk)].tobytes())
|
||||
|
||||
# At this point, we know that there is no speech.
|
||||
# Prepend the last few chunks that had no speech, for a more fluent boundary.
|
||||
self.audio_buffer = np.append(self.audio_buffer, chunk)
|
||||
self.audio_buffer = self.audio_buffer[-begin_silence_length * len(chunk) :]
|
||||
|
||||
async def handle_message(self, msg: InternalMessage):
|
||||
"""
|
||||
Handle incoming messages.
|
||||
|
||||
Expects messages to pause or resume the VAD processing from User Interrupt Agent.
|
||||
|
||||
:param msg: The received internal message.
|
||||
"""
|
||||
sender = msg.sender
|
||||
|
||||
if sender == settings.agent_settings.user_interrupt_name:
|
||||
if msg.body == "PAUSE":
|
||||
self.logger.info("Pausing VAD processing.")
|
||||
self._paused.clear()
|
||||
# If the robot needs to pick up speaking where it left off, do not set _reset_needed
|
||||
self._reset_needed = True
|
||||
elif msg.body == "RESUME":
|
||||
self.logger.info("Resuming VAD processing.")
|
||||
self._paused.set()
|
||||
else:
|
||||
self.logger.warning(f"Unknown command from User Interrupt Agent: {msg.body}")
|
||||
else:
|
||||
self.logger.debug(f"Ignoring message from unknown sender: {sender}")
|
||||
@@ -0,0 +1,425 @@
|
||||
import json
|
||||
import logging
|
||||
|
||||
import zmq
|
||||
from zmq.asyncio import Context
|
||||
|
||||
from control_backend.agents import BaseAgent
|
||||
from control_backend.agents.bdi.agentspeak_generator import AgentSpeakGenerator
|
||||
from control_backend.core.agent_system import InternalMessage
|
||||
from control_backend.core.config import settings
|
||||
from control_backend.schemas.belief_message import Belief, BeliefMessage
|
||||
from control_backend.schemas.program import ConditionalNorm, Goal, Program
|
||||
from control_backend.schemas.ri_message import (
|
||||
GestureCommand,
|
||||
PauseCommand,
|
||||
RIEndpoint,
|
||||
SpeechCommand,
|
||||
)
|
||||
|
||||
experiment_logger = logging.getLogger(settings.logging_settings.experiment_logger_name)
|
||||
|
||||
|
||||
class UserInterruptAgent(BaseAgent):
|
||||
"""
|
||||
User Interrupt Agent.
|
||||
|
||||
This agent receives button_pressed events from the external HTTP API
|
||||
(via ZMQ) and uses the associated context to trigger one of the following actions:
|
||||
|
||||
- Send a prioritized message to the `RobotSpeechAgent`
|
||||
- Send a prioritized gesture to the `RobotGestureAgent`
|
||||
- Send a belief override to the `BDI Core` in order to activate a
|
||||
trigger/conditional norm or complete a goal.
|
||||
|
||||
Prioritized actions clear the current RI queue before inserting the new item,
|
||||
ensuring they are executed immediately after Pepper's current action has been fulfilled.
|
||||
|
||||
:ivar sub_socket: The ZMQ SUB socket used to receive user interrupts.
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.sub_socket = None
|
||||
self.pub_socket = None
|
||||
self._trigger_map = {}
|
||||
self._trigger_reverse_map = {}
|
||||
|
||||
self._goal_map = {} # id -> sluggified goal
|
||||
self._goal_reverse_map = {} # sluggified goal -> id
|
||||
|
||||
self._cond_norm_map = {} # id -> sluggified cond norm
|
||||
self._cond_norm_reverse_map = {} # sluggified cond norm -> id
|
||||
|
||||
async def setup(self):
|
||||
"""
|
||||
Initialize the agent by setting up ZMQ sockets for receiving button events and
|
||||
publishing updates.
|
||||
"""
|
||||
context = Context.instance()
|
||||
|
||||
self.sub_socket = context.socket(zmq.SUB)
|
||||
self.sub_socket.connect(settings.zmq_settings.internal_sub_address)
|
||||
self.sub_socket.subscribe("button_pressed")
|
||||
|
||||
self.pub_socket = context.socket(zmq.PUB)
|
||||
self.pub_socket.connect(settings.zmq_settings.internal_pub_address)
|
||||
|
||||
self.add_behavior(self._receive_button_event())
|
||||
|
||||
async def _receive_button_event(self):
|
||||
"""
|
||||
Main loop to receive and process button press events from the UI.
|
||||
|
||||
Handles different event types:
|
||||
- `speech`: Triggers immediate robot speech.
|
||||
- `gesture`: Triggers an immediate robot gesture.
|
||||
- `override`: Forces a belief, trigger, or goal completion in the BDI core.
|
||||
- `override_unachieve`: Removes a belief from the BDI core.
|
||||
- `pause`: Toggles the system's pause state.
|
||||
- `next_phase` / `reset_phase`: Controls experiment flow.
|
||||
"""
|
||||
while True:
|
||||
topic, body = await self.sub_socket.recv_multipart()
|
||||
|
||||
try:
|
||||
event_data = json.loads(body)
|
||||
event_type = event_data.get("type") # e.g., "speech", "gesture"
|
||||
event_context = event_data.get("context") # e.g., "Hello, I am Pepper!"
|
||||
except json.JSONDecodeError:
|
||||
self.logger.error("Received invalid JSON payload on topic %s", topic)
|
||||
continue
|
||||
|
||||
self.logger.debug("Received event type %s", event_type)
|
||||
|
||||
match event_type:
|
||||
case "speech":
|
||||
await self._send_to_speech_agent(event_context)
|
||||
self.logger.info(
|
||||
"Forwarded button press (speech) with context '%s' to RobotSpeechAgent.",
|
||||
event_context,
|
||||
)
|
||||
case "gesture":
|
||||
await self._send_to_gesture_agent(event_context)
|
||||
self.logger.info(
|
||||
"Forwarded button press (gesture) with context '%s' to RobotGestureAgent.",
|
||||
event_context,
|
||||
)
|
||||
case "override":
|
||||
ui_id = str(event_context)
|
||||
if asl_trigger := self._trigger_map.get(ui_id):
|
||||
await self._send_to_bdi("force_trigger", asl_trigger)
|
||||
self.logger.info(
|
||||
"Forwarded button press (override) with context '%s' to BDI Core.",
|
||||
event_context,
|
||||
)
|
||||
elif asl_cond_norm := self._cond_norm_map.get(ui_id):
|
||||
await self._send_to_bdi_belief(asl_cond_norm, "cond_norm")
|
||||
self.logger.info(
|
||||
"Forwarded button press (override) with context '%s' to BDI Core.",
|
||||
event_context,
|
||||
)
|
||||
elif asl_goal := self._goal_map.get(ui_id):
|
||||
await self._send_to_bdi_belief(asl_goal, "goal")
|
||||
self.logger.info(
|
||||
"Forwarded button press (override) with context '%s' to BDI Core.",
|
||||
event_context,
|
||||
)
|
||||
# Send achieve_goal to program manager to update semantic belief extractor
|
||||
goal_achieve_msg = InternalMessage(
|
||||
to=settings.agent_settings.bdi_program_manager_name,
|
||||
thread="achieve_goal",
|
||||
body=ui_id,
|
||||
)
|
||||
|
||||
await self.send(goal_achieve_msg)
|
||||
else:
|
||||
self.logger.warning("Could not determine which element to override.")
|
||||
case "override_unachieve":
|
||||
ui_id = str(event_context)
|
||||
if asl_cond_norm := self._cond_norm_map.get(ui_id):
|
||||
await self._send_to_bdi_belief(asl_cond_norm, "cond_norm", True)
|
||||
self.logger.info(
|
||||
"Forwarded button press (override_unachieve)"
|
||||
"with context '%s' to BDI Core.",
|
||||
event_context,
|
||||
)
|
||||
else:
|
||||
self.logger.warning(
|
||||
"Could not determine which conditional norm to unachieve."
|
||||
)
|
||||
|
||||
case "pause":
|
||||
self.logger.debug(
|
||||
"Received pause/resume button press with context '%s'.", event_context
|
||||
)
|
||||
await self._send_pause_command(event_context)
|
||||
if event_context:
|
||||
self.logger.info("Sent pause command.")
|
||||
else:
|
||||
self.logger.info("Sent resume command.")
|
||||
|
||||
case "next_phase" | "reset_phase":
|
||||
await self._send_experiment_control_to_bdi_core(event_type)
|
||||
case _:
|
||||
self.logger.warning(
|
||||
"Received button press with unknown type '%s' (context: '%s').",
|
||||
event_type,
|
||||
event_context,
|
||||
)
|
||||
|
||||
async def handle_message(self, msg: InternalMessage):
|
||||
"""
|
||||
Handles internal messages from other agents, such as program updates or trigger
|
||||
notifications.
|
||||
|
||||
:param msg: The incoming :class:`~control_backend.core.agent_system.InternalMessage`.
|
||||
"""
|
||||
match msg.thread:
|
||||
case "new_program":
|
||||
self._create_mapping(msg.body)
|
||||
case "trigger_start":
|
||||
# msg.body is the sluggified trigger
|
||||
asl_slug = msg.body
|
||||
ui_id = self._trigger_reverse_map.get(asl_slug)
|
||||
|
||||
if ui_id:
|
||||
payload = {"type": "trigger_update", "id": ui_id, "achieved": True}
|
||||
await self._send_experiment_update(payload)
|
||||
self.logger.info(f"UI Update: Trigger {asl_slug} started (ID: {ui_id})")
|
||||
case "trigger_end":
|
||||
asl_slug = msg.body
|
||||
ui_id = self._trigger_reverse_map.get(asl_slug)
|
||||
if ui_id:
|
||||
payload = {"type": "trigger_update", "id": ui_id, "achieved": False}
|
||||
await self._send_experiment_update(payload)
|
||||
self.logger.info(f"UI Update: Trigger {asl_slug} ended (ID: {ui_id})")
|
||||
case "transition_phase":
|
||||
new_phase_id = msg.body
|
||||
self.logger.info(f"Phase transition detected: {new_phase_id}")
|
||||
|
||||
payload = {"type": "phase_update", "id": new_phase_id}
|
||||
|
||||
await self._send_experiment_update(payload)
|
||||
case "goal_start":
|
||||
goal_name = msg.body
|
||||
ui_id = self._goal_reverse_map.get(goal_name)
|
||||
if ui_id:
|
||||
payload = {"type": "goal_update", "id": ui_id, "active": True}
|
||||
await self._send_experiment_update(payload)
|
||||
self.logger.info(f"UI Update: Goal {goal_name} started (ID: {ui_id})")
|
||||
case "active_norms_update":
|
||||
active_norms_asl = [
|
||||
s.strip("() '\",") for s in msg.body.split(",") if s.strip("() '\",")
|
||||
]
|
||||
await self._broadcast_cond_norms(active_norms_asl)
|
||||
case _:
|
||||
self.logger.debug(f"Received internal message on unhandled thread: {msg.thread}")
|
||||
|
||||
async def _broadcast_cond_norms(self, active_slugs: list[str]):
|
||||
"""
|
||||
Broadcasts the current activation state of all conditional norms to the UI.
|
||||
|
||||
:param active_slugs: A list of sluggified norm names currently active in the BDI core.
|
||||
"""
|
||||
updates = []
|
||||
for asl_slug, ui_id in self._cond_norm_reverse_map.items():
|
||||
is_active = asl_slug in active_slugs
|
||||
updates.append({"id": ui_id, "active": is_active})
|
||||
|
||||
payload = {"type": "cond_norms_state_update", "norms": updates}
|
||||
|
||||
if self.pub_socket:
|
||||
topic = b"status"
|
||||
body = json.dumps(payload).encode("utf-8")
|
||||
await self.pub_socket.send_multipart([topic, body])
|
||||
# self.logger.info(f"UI Update: Active norms {updates}")
|
||||
|
||||
def _create_mapping(self, program_json: str):
|
||||
"""
|
||||
Creates a bidirectional mapping between UI identifiers and AgentSpeak slugs.
|
||||
|
||||
:param program_json: The JSON representation of the behavioral program.
|
||||
"""
|
||||
try:
|
||||
program = Program.model_validate_json(program_json)
|
||||
self._trigger_map = {}
|
||||
self._trigger_reverse_map = {}
|
||||
self._goal_map = {}
|
||||
self._cond_norm_map = {}
|
||||
self._cond_norm_reverse_map = {}
|
||||
|
||||
def _register_goal(goal: Goal):
|
||||
"""Recursively register goals and their subgoals."""
|
||||
slug = AgentSpeakGenerator.slugify(goal)
|
||||
self._goal_map[str(goal.id)] = slug
|
||||
self._goal_reverse_map[slug] = str(goal.id)
|
||||
|
||||
for step in goal.plan.steps:
|
||||
if isinstance(step, Goal):
|
||||
_register_goal(step)
|
||||
|
||||
for phase in program.phases:
|
||||
for trigger in phase.triggers:
|
||||
slug = AgentSpeakGenerator.slugify(trigger)
|
||||
self._trigger_map[str(trigger.id)] = slug
|
||||
self._trigger_reverse_map[slug] = str(trigger.id)
|
||||
|
||||
for goal in phase.goals:
|
||||
_register_goal(goal)
|
||||
|
||||
for goal, id in self._goal_reverse_map.items():
|
||||
self.logger.debug(f"Goal mapping: UI ID {goal} -> {id}")
|
||||
|
||||
for norm in phase.norms:
|
||||
if isinstance(norm, ConditionalNorm):
|
||||
asl_slug = AgentSpeakGenerator.slugify(norm)
|
||||
|
||||
norm_id = str(norm.id)
|
||||
|
||||
self._cond_norm_map[norm_id] = asl_slug
|
||||
self._cond_norm_reverse_map[norm.norm] = norm_id
|
||||
self.logger.debug("Added conditional norm %s", asl_slug)
|
||||
|
||||
self.logger.info(
|
||||
f"Mapped {len(self._trigger_map)} triggers and {len(self._goal_map)} goals "
|
||||
f"and {len(self._cond_norm_map)} conditional norms for UserInterruptAgent."
|
||||
)
|
||||
except Exception as e:
|
||||
self.logger.error(f"Mapping failed: {e}")
|
||||
|
||||
async def _send_experiment_update(self, data, should_log: bool = True):
|
||||
"""
|
||||
Publishes an experiment state update to the internal ZMQ bus for the UI.
|
||||
|
||||
:param data: The update payload.
|
||||
:param should_log: Whether to log the update.
|
||||
"""
|
||||
if self.pub_socket:
|
||||
topic = b"experiment"
|
||||
body = json.dumps(data).encode("utf-8")
|
||||
await self.pub_socket.send_multipart([topic, body])
|
||||
if should_log:
|
||||
self.logger.debug(f"Sent experiment update: {data}")
|
||||
|
||||
async def _send_to_speech_agent(self, text_to_say: str):
|
||||
"""
|
||||
method to send prioritized speech command to RobotSpeechAgent.
|
||||
|
||||
:param text_to_say: The string that the robot has to say.
|
||||
"""
|
||||
experiment_logger.chat(text_to_say, extra={"role": "assistant"})
|
||||
cmd = SpeechCommand(data=text_to_say, is_priority=True)
|
||||
out_msg = InternalMessage(
|
||||
to=settings.agent_settings.robot_speech_name,
|
||||
sender=self.name,
|
||||
body=cmd.model_dump_json(),
|
||||
)
|
||||
await self.send(out_msg)
|
||||
|
||||
async def _send_to_gesture_agent(self, single_gesture_name: str):
|
||||
"""
|
||||
method to send prioritized gesture command to RobotGestureAgent.
|
||||
|
||||
:param single_gesture_name: The gesture tag that the robot has to perform.
|
||||
"""
|
||||
# the endpoint is set to always be GESTURE_SINGLE for user interrupts
|
||||
cmd = GestureCommand(
|
||||
endpoint=RIEndpoint.GESTURE_SINGLE, data=single_gesture_name, is_priority=True
|
||||
)
|
||||
out_msg = InternalMessage(
|
||||
to=settings.agent_settings.robot_gesture_name,
|
||||
sender=self.name,
|
||||
body=cmd.model_dump_json(),
|
||||
)
|
||||
await self.send(out_msg)
|
||||
|
||||
async def _send_to_bdi(self, thread: str, body: str):
|
||||
"""Send slug of trigger to BDI"""
|
||||
msg = InternalMessage(to=settings.agent_settings.bdi_core_name, thread=thread, body=body)
|
||||
await self.send(msg)
|
||||
self.logger.info(f"Directly forced {thread} in BDI: {body}")
|
||||
|
||||
async def _send_to_bdi_belief(self, asl: str, asl_type: str, unachieve: bool = False):
|
||||
"""Send belief to BDI Core"""
|
||||
if asl_type == "goal":
|
||||
belief_name = f"achieved_{asl}"
|
||||
elif asl_type == "cond_norm":
|
||||
belief_name = f"force_{asl}"
|
||||
else:
|
||||
self.logger.warning("Tried to send belief with unknown type")
|
||||
return
|
||||
belief = Belief(name=belief_name, arguments=None)
|
||||
self.logger.debug(f"Sending belief to BDI Core: {belief_name}")
|
||||
# Conditional norms are unachieved by removing the belief
|
||||
belief_message = (
|
||||
BeliefMessage(delete=[belief]) if unachieve else BeliefMessage(create=[belief])
|
||||
)
|
||||
msg = InternalMessage(
|
||||
to=settings.agent_settings.bdi_core_name,
|
||||
thread="beliefs",
|
||||
body=belief_message.model_dump_json(),
|
||||
)
|
||||
await self.send(msg)
|
||||
|
||||
async def _send_experiment_control_to_bdi_core(self, type):
|
||||
"""
|
||||
method to send experiment control buttons to bdi core.
|
||||
|
||||
:param type: the type of control button we should send to the bdi core.
|
||||
"""
|
||||
# Switch which thread we should send to bdi core
|
||||
thread = ""
|
||||
match type:
|
||||
case "next_phase":
|
||||
thread = "force_next_phase"
|
||||
case "reset_phase":
|
||||
thread = "reset_current_phase"
|
||||
case "reset_experiment":
|
||||
thread = "reset_experiment"
|
||||
case _:
|
||||
self.logger.warning(
|
||||
"Received unknown experiment control type '%s' to send to BDI Core.",
|
||||
type,
|
||||
)
|
||||
|
||||
out_msg = InternalMessage(
|
||||
to=settings.agent_settings.bdi_core_name,
|
||||
sender=self.name,
|
||||
thread=thread,
|
||||
body="",
|
||||
)
|
||||
self.logger.debug("Sending experiment control '%s' to BDI Core.", thread)
|
||||
await self.send(out_msg)
|
||||
|
||||
async def _send_pause_command(self, pause):
|
||||
"""
|
||||
Send a pause command to the Robot Interface via the RI Communication Agent.
|
||||
Send a pause command to the other internal agents; for now just VAD agent.
|
||||
"""
|
||||
cmd = PauseCommand(data=pause)
|
||||
message = InternalMessage(
|
||||
to=settings.agent_settings.ri_communication_name,
|
||||
sender=self.name,
|
||||
body=cmd.model_dump_json(),
|
||||
)
|
||||
await self.send(message)
|
||||
|
||||
if pause == "true":
|
||||
# Send pause to VAD agent
|
||||
vad_message = InternalMessage(
|
||||
to=settings.agent_settings.vad_name,
|
||||
sender=self.name,
|
||||
body="PAUSE",
|
||||
)
|
||||
await self.send(vad_message)
|
||||
self.logger.info("Sent pause command to VAD Agent and RI Communication Agent.")
|
||||
else:
|
||||
# Send resume to VAD agent
|
||||
vad_message = InternalMessage(
|
||||
to=settings.agent_settings.vad_name,
|
||||
sender=self.name,
|
||||
body="RESUME",
|
||||
)
|
||||
await self.send(vad_message)
|
||||
self.logger.info("Sent resume command to VAD Agent and RI Communication Agent.")
|
||||
0
src/control_backend/api/__init__.py
Normal file
0
src/control_backend/api/__init__.py
Normal file
0
src/control_backend/api/v1/__init__.py
Normal file
0
src/control_backend/api/v1/__init__.py
Normal file
0
src/control_backend/api/v1/endpoints/__init__.py
Normal file
0
src/control_backend/api/v1/endpoints/__init__.py
Normal file
67
src/control_backend/api/v1/endpoints/logs.py
Normal file
67
src/control_backend/api/v1/endpoints/logs.py
Normal file
@@ -0,0 +1,67 @@
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import zmq
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from fastapi.responses import FileResponse, StreamingResponse
|
||||
from zmq.asyncio import Context
|
||||
|
||||
from control_backend.core.config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
# DO NOT LOG INSIDE THIS FUNCTION
|
||||
@router.get("/logs/stream")
|
||||
async def log_stream():
|
||||
"""
|
||||
Server-Sent Events (SSE) endpoint for real-time log streaming.
|
||||
|
||||
Subscribes to the internal ZMQ logging topic and forwards log records to the client.
|
||||
Allows the frontend to display live logs from the backend.
|
||||
|
||||
:return: A StreamingResponse yielding SSE data.
|
||||
"""
|
||||
context = Context.instance()
|
||||
socket = context.socket(zmq.SUB)
|
||||
|
||||
for level in logging.getLevelNamesMapping():
|
||||
socket.subscribe(topic=level)
|
||||
|
||||
socket.connect(settings.zmq_settings.internal_sub_address)
|
||||
|
||||
async def gen():
|
||||
while True:
|
||||
_, message = await socket.recv_multipart()
|
||||
message = message.decode().strip()
|
||||
yield f"data: {message}\n\n"
|
||||
|
||||
return StreamingResponse(gen(), media_type="text/event-stream")
|
||||
|
||||
|
||||
LOGGING_DIR = Path(settings.logging_settings.experiment_log_directory).resolve()
|
||||
|
||||
|
||||
@router.get("/logs/files")
|
||||
@router.get("/api/logs/files")
|
||||
async def log_directory():
|
||||
"""
|
||||
Get a list of all log files stored in the experiment log file directory.
|
||||
"""
|
||||
return [f.name for f in LOGGING_DIR.glob("*.log")]
|
||||
|
||||
|
||||
@router.get("/logs/files/{filename}")
|
||||
@router.get("/api/logs/files/{filename}")
|
||||
async def log_file(filename: str):
|
||||
# Prevent path-traversal
|
||||
file_path = (LOGGING_DIR / filename).resolve() # This .resolve() is important
|
||||
if not file_path.is_relative_to(LOGGING_DIR):
|
||||
raise HTTPException(status_code=400, detail="Invalid filename.")
|
||||
|
||||
if not file_path.is_file():
|
||||
raise HTTPException(status_code=404, detail="File not found.")
|
||||
|
||||
return FileResponse(file_path, filename=file_path.name)
|
||||
30
src/control_backend/api/v1/endpoints/message.py
Normal file
30
src/control_backend/api/v1/endpoints/message.py
Normal file
@@ -0,0 +1,30 @@
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter, Request
|
||||
|
||||
from control_backend.schemas.message import Message
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post("/message", status_code=202)
|
||||
async def receive_message(message: Message, request: Request):
|
||||
"""
|
||||
Generic endpoint to receive text messages.
|
||||
|
||||
Publishes the message to the internal 'message' topic via ZMQ.
|
||||
|
||||
:param message: The message payload.
|
||||
:param request: The FastAPI request object (used to access app state).
|
||||
"""
|
||||
logger.info("Received message: %s", message.message)
|
||||
|
||||
topic = b"message"
|
||||
body = message.model_dump_json().encode("utf-8")
|
||||
|
||||
pub_socket = request.app.state.endpoints_pub_socket
|
||||
await pub_socket.send_multipart([topic, body])
|
||||
|
||||
return {"status": "Message received"}
|
||||
31
src/control_backend/api/v1/endpoints/program.py
Normal file
31
src/control_backend/api/v1/endpoints/program.py
Normal file
@@ -0,0 +1,31 @@
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter, Request
|
||||
|
||||
from control_backend.schemas.program import Program
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post("/program", status_code=202)
|
||||
async def receive_message(program: Program, request: Request):
|
||||
"""
|
||||
Endpoint to upload a new Behavior Program.
|
||||
|
||||
Validates the program structure (phases, norms, goals) and publishes it to the internal
|
||||
'program' topic. The :class:`~control_backend.agents.bdi.bdi_program_manager.BDIProgramManager`
|
||||
will pick this up and update the BDI agent.
|
||||
|
||||
:param program: The parsed Program object.
|
||||
:param request: The FastAPI request object.
|
||||
"""
|
||||
logger.debug("Received raw program: %s", program)
|
||||
|
||||
# send away
|
||||
topic = b"program"
|
||||
body = program.model_dump_json().encode()
|
||||
pub_socket = request.app.state.endpoints_pub_socket
|
||||
await pub_socket.send_multipart([topic, body])
|
||||
|
||||
return {"status": "Program parsed"}
|
||||
143
src/control_backend/api/v1/endpoints/robot.py
Normal file
143
src/control_backend/api/v1/endpoints/robot.py
Normal file
@@ -0,0 +1,143 @@
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
|
||||
import zmq.asyncio
|
||||
from fastapi import APIRouter, Request
|
||||
from fastapi.responses import StreamingResponse
|
||||
from zmq.asyncio import Context, Socket
|
||||
|
||||
from control_backend.core.config import settings
|
||||
from control_backend.schemas.ri_message import GestureCommand, SpeechCommand
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post("/command/speech", status_code=202)
|
||||
async def receive_command_speech(command: SpeechCommand, request: Request):
|
||||
"""
|
||||
Send a direct speech command to the robot.
|
||||
|
||||
Publishes the command to the internal 'command' topic. The
|
||||
:class:`~control_backend.agents.actuation.robot_speech_agent.RobotSpeechAgent`
|
||||
will forward this to the robot.
|
||||
|
||||
:param command: The speech command payload.
|
||||
:param request: The FastAPI request object.
|
||||
"""
|
||||
topic = b"command"
|
||||
|
||||
pub_socket: Socket = request.app.state.endpoints_pub_socket
|
||||
await pub_socket.send_multipart([topic, command.model_dump_json().encode()])
|
||||
|
||||
return {"status": "Speech command received"}
|
||||
|
||||
|
||||
@router.post("/command/gesture", status_code=202)
|
||||
async def receive_command_gesture(command: GestureCommand, request: Request):
|
||||
"""
|
||||
Send a direct gesture command to the robot.
|
||||
|
||||
Publishes the command to the internal 'command' topic. The
|
||||
:class:`~control_backend.agents.actuation.robot_speech_agent.RobotGestureAgent`
|
||||
will forward this to the robot.
|
||||
|
||||
:param command: The speech command payload.
|
||||
:param request: The FastAPI request object.
|
||||
"""
|
||||
topic = b"command"
|
||||
|
||||
pub_socket: Socket = request.app.state.endpoints_pub_socket
|
||||
await pub_socket.send_multipart([topic, command.model_dump_json().encode()])
|
||||
|
||||
return {"status": "Gesture command received"}
|
||||
|
||||
|
||||
@router.get("/ping_check")
|
||||
async def ping(request: Request):
|
||||
"""
|
||||
Simple HTTP ping endpoint to check if the backend is reachable.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
@router.get("/commands/gesture/tags")
|
||||
async def get_available_gesture_tags(request: Request, count=0):
|
||||
"""
|
||||
Endpoint to retrieve the available gesture tags for the robot.
|
||||
|
||||
:param request: The FastAPI request object.
|
||||
:return: A list of available gesture tags.
|
||||
"""
|
||||
req_socket = Context.instance().socket(zmq.REQ)
|
||||
req_socket.connect(settings.zmq_settings.internal_gesture_rep_adress)
|
||||
|
||||
# Check to see if we've got any count given in the query parameter
|
||||
amount = count or None
|
||||
timeout = 5 # seconds
|
||||
|
||||
await req_socket.send(f"{amount}".encode() if amount else b"None")
|
||||
try:
|
||||
body = await asyncio.wait_for(req_socket.recv(), timeout=timeout)
|
||||
except TimeoutError:
|
||||
body = '{"tags": []}'
|
||||
logger.debug("Got timeout error fetching gestures.")
|
||||
|
||||
# Handle empty response and JSON decode errors
|
||||
available_tags = []
|
||||
if body:
|
||||
try:
|
||||
available_tags = json.loads(body).get("tags", [])
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"Failed to parse gesture tags JSON: {e}, body: {body}")
|
||||
# Return empty list on JSON error
|
||||
available_tags = []
|
||||
|
||||
return {"available_gesture_tags": available_tags}
|
||||
|
||||
|
||||
@router.get("/ping_stream")
|
||||
async def ping_stream(request: Request):
|
||||
"""
|
||||
SSE endpoint for monitoring the Robot Interface connection status.
|
||||
|
||||
Subscribes to the internal 'ping' topic (published by the RI Communication Agent)
|
||||
and yields status updates to the client.
|
||||
|
||||
:return: A StreamingResponse of connection status events.
|
||||
"""
|
||||
|
||||
async def event_stream():
|
||||
# Set up internal socket to receive ping updates
|
||||
|
||||
sub_socket = Context.instance().socket(zmq.SUB)
|
||||
sub_socket.connect(settings.zmq_settings.internal_sub_address)
|
||||
sub_socket.setsockopt(zmq.SUBSCRIBE, b"ping")
|
||||
connected = False
|
||||
|
||||
ping_frequency = 2
|
||||
|
||||
# Even though its most likely the updates should alternate
|
||||
# (So, True - False - True - False for connectivity),
|
||||
# let's still check.
|
||||
while True:
|
||||
try:
|
||||
topic, body = await asyncio.wait_for(
|
||||
sub_socket.recv_multipart(), timeout=ping_frequency
|
||||
)
|
||||
connected = json.loads(body)
|
||||
except TimeoutError:
|
||||
logger.debug("got timeout error in ping loop in ping router")
|
||||
connected = False
|
||||
|
||||
# Stop if client disconnected
|
||||
if await request.is_disconnected():
|
||||
logger.info("Client disconnected from SSE")
|
||||
break
|
||||
|
||||
connectedJson = json.dumps(connected)
|
||||
yield (f"data: {connectedJson}\n\n")
|
||||
|
||||
return StreamingResponse(event_stream(), media_type="text/event-stream")
|
||||
94
src/control_backend/api/v1/endpoints/user_interact.py
Normal file
94
src/control_backend/api/v1/endpoints/user_interact.py
Normal file
@@ -0,0 +1,94 @@
|
||||
import asyncio
|
||||
import logging
|
||||
|
||||
import zmq
|
||||
import zmq.asyncio
|
||||
from fastapi import APIRouter, Request
|
||||
from fastapi.responses import StreamingResponse
|
||||
from zmq.asyncio import Context
|
||||
|
||||
from control_backend.core.config import settings
|
||||
from control_backend.schemas.events import ButtonPressedEvent
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post("/button_pressed", status_code=202)
|
||||
async def receive_button_event(event: ButtonPressedEvent, request: Request):
|
||||
"""
|
||||
Endpoint to handle external button press events.
|
||||
|
||||
Validates the event payload and publishes it to the internal 'button_pressed' topic.
|
||||
Subscribers (in this case user_interrupt_agent) will pick this up to trigger
|
||||
specific behaviors or state changes.
|
||||
|
||||
:param event: The parsed ButtonPressedEvent object.
|
||||
:param request: The FastAPI request object.
|
||||
"""
|
||||
logger.debug("Received button event: %s | %s", event.type, event.context)
|
||||
|
||||
topic = b"button_pressed"
|
||||
body = event.model_dump_json().encode()
|
||||
|
||||
pub_socket = request.app.state.endpoints_pub_socket
|
||||
await pub_socket.send_multipart([topic, body])
|
||||
|
||||
return {"status": "Event received"}
|
||||
|
||||
|
||||
@router.get("/experiment_stream")
|
||||
async def experiment_stream(request: Request):
|
||||
# Use the asyncio-compatible context
|
||||
context = Context.instance()
|
||||
socket = context.socket(zmq.SUB)
|
||||
|
||||
# Connect and subscribe
|
||||
socket.connect(settings.zmq_settings.internal_sub_address)
|
||||
socket.subscribe(b"experiment")
|
||||
|
||||
async def gen():
|
||||
try:
|
||||
while True:
|
||||
# Check if client closed the tab
|
||||
if await request.is_disconnected():
|
||||
logger.error("Client disconnected from experiment stream.")
|
||||
break
|
||||
|
||||
try:
|
||||
parts = await asyncio.wait_for(socket.recv_multipart(), timeout=10.0)
|
||||
_, message = parts
|
||||
yield f"data: {message.decode().strip()}\n\n"
|
||||
except TimeoutError:
|
||||
continue
|
||||
finally:
|
||||
socket.close()
|
||||
|
||||
return StreamingResponse(gen(), media_type="text/event-stream")
|
||||
|
||||
|
||||
@router.get("/status_stream")
|
||||
async def status_stream(request: Request):
|
||||
context = Context.instance()
|
||||
socket = context.socket(zmq.SUB)
|
||||
socket.connect(settings.zmq_settings.internal_sub_address)
|
||||
|
||||
socket.subscribe(b"status")
|
||||
|
||||
async def gen():
|
||||
try:
|
||||
while True:
|
||||
if await request.is_disconnected():
|
||||
break
|
||||
try:
|
||||
# Shorter timeout since this is frequent
|
||||
parts = await asyncio.wait_for(socket.recv_multipart(), timeout=0.5)
|
||||
_, message = parts
|
||||
yield f"data: {message.decode().strip()}\n\n"
|
||||
except TimeoutError:
|
||||
yield ": ping\n\n" # Keep the connection alive
|
||||
continue
|
||||
finally:
|
||||
socket.close()
|
||||
|
||||
return StreamingResponse(gen(), media_type="text/event-stream")
|
||||
15
src/control_backend/api/v1/router.py
Normal file
15
src/control_backend/api/v1/router.py
Normal file
@@ -0,0 +1,15 @@
|
||||
from fastapi.routing import APIRouter
|
||||
|
||||
from control_backend.api.v1.endpoints import logs, message, program, robot, user_interact
|
||||
|
||||
api_router = APIRouter()
|
||||
|
||||
api_router.include_router(message.router, tags=["Messages"])
|
||||
|
||||
api_router.include_router(robot.router, prefix="/robot", tags=["Pings", "Commands"])
|
||||
|
||||
api_router.include_router(logs.router, tags=["Logs"])
|
||||
|
||||
api_router.include_router(program.router, tags=["Program"])
|
||||
|
||||
api_router.include_router(user_interact.router, tags=["Button Pressed Events"])
|
||||
0
src/control_backend/core/__init__.py
Normal file
0
src/control_backend/core/__init__.py
Normal file
231
src/control_backend/core/agent_system.py
Normal file
231
src/control_backend/core/agent_system.py
Normal file
@@ -0,0 +1,231 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from asyncio import Task
|
||||
from collections.abc import Coroutine
|
||||
|
||||
import zmq
|
||||
import zmq.asyncio as azmq
|
||||
|
||||
from control_backend.core.config import settings
|
||||
from control_backend.schemas.internal_message import InternalMessage
|
||||
|
||||
# Central directory to resolve agent names to instances
|
||||
_agent_directory: dict[str, "BaseAgent"] = {}
|
||||
|
||||
|
||||
class AgentDirectory:
|
||||
"""
|
||||
Helper class to keep track of which agents are registered.
|
||||
Used for handling message routing.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def register(name: str, agent: "BaseAgent"):
|
||||
"""
|
||||
Registers an agent instance with a unique name.
|
||||
|
||||
:param name: The name of the agent.
|
||||
:param agent: The :class:`BaseAgent` instance.
|
||||
"""
|
||||
_agent_directory[name] = agent
|
||||
|
||||
@staticmethod
|
||||
def get(name: str) -> "BaseAgent | None":
|
||||
"""
|
||||
Retrieves a registered agent instance by name.
|
||||
|
||||
:param name: The name of the agent to retrieve.
|
||||
:return: The :class:`BaseAgent` instance, or None if not found.
|
||||
"""
|
||||
return _agent_directory.get(name)
|
||||
|
||||
|
||||
class BaseAgent(ABC):
|
||||
"""
|
||||
Abstract base class for all agents in the system.
|
||||
|
||||
This class provides the foundational infrastructure for agent lifecycle management, messaging
|
||||
(both intra-process and inter-process via ZMQ), and asynchronous behavior execution.
|
||||
|
||||
.. warning::
|
||||
Do not inherit from this class directly for creating new agents. Instead, inherit from
|
||||
:class:`control_backend.agents.base.BaseAgent`, which ensures proper logger configuration.
|
||||
|
||||
:ivar name: The unique name of the agent.
|
||||
:ivar inbox: The queue for receiving internal messages.
|
||||
:ivar _tasks: A set of currently running asynchronous tasks/behaviors.
|
||||
:ivar _running: A boolean flag indicating if the agent is currently running.
|
||||
:ivar logger: The logger instance for the agent.
|
||||
"""
|
||||
|
||||
logger: logging.Logger
|
||||
|
||||
def __init__(self, name: str):
|
||||
"""
|
||||
Initialize the BaseAgent.
|
||||
|
||||
:param name: The unique identifier for this agent.
|
||||
"""
|
||||
self.name = name
|
||||
self.inbox: asyncio.Queue[InternalMessage] = asyncio.Queue()
|
||||
self._tasks: set[asyncio.Task] = set()
|
||||
self._running = False
|
||||
|
||||
self._internal_pub_socket: None | azmq.Socket = None
|
||||
self._internal_sub_socket: None | azmq.Socket = None
|
||||
|
||||
# Register immediately
|
||||
AgentDirectory.register(name, self)
|
||||
|
||||
@abstractmethod
|
||||
async def setup(self):
|
||||
"""
|
||||
Initialize agent-specific resources.
|
||||
|
||||
This method must be overridden by subclasses. It is called after the agent has started
|
||||
and the ZMQ sockets have been initialized. Use this method to:
|
||||
|
||||
* Initialize connections (databases, APIs, etc.)
|
||||
* Add initial behaviors using :meth:`add_behavior`
|
||||
"""
|
||||
pass
|
||||
|
||||
async def start(self):
|
||||
"""
|
||||
Start the agent and its internal loops.
|
||||
|
||||
This method:
|
||||
1. Sets the running state to True.
|
||||
2. Initializes ZeroMQ PUB/SUB sockets for inter-process communication.
|
||||
3. Calls the user-defined :meth:`setup` method.
|
||||
4. Starts the inbox processing loop and the ZMQ receiver loop.
|
||||
"""
|
||||
self.logger.info(f"Starting agent {self.name}")
|
||||
self._running = True
|
||||
|
||||
context = azmq.Context.instance()
|
||||
|
||||
# Setup the internal publishing socket
|
||||
self._internal_pub_socket = context.socket(zmq.PUB)
|
||||
self._internal_pub_socket.connect(settings.zmq_settings.internal_pub_address)
|
||||
|
||||
# Setup the internal receiving socket
|
||||
self._internal_sub_socket = context.socket(zmq.SUB)
|
||||
self._internal_sub_socket.connect(settings.zmq_settings.internal_sub_address)
|
||||
self._internal_sub_socket.subscribe(f"internal/{self.name}")
|
||||
|
||||
await self.setup()
|
||||
|
||||
# Start processing inbox and ZMQ messages
|
||||
self.add_behavior(self._process_inbox())
|
||||
self.add_behavior(self._receive_internal_zmq_loop())
|
||||
|
||||
async def stop(self):
|
||||
"""
|
||||
Stop the agent.
|
||||
|
||||
Sets the running state to False and cancels all running background tasks.
|
||||
"""
|
||||
self._running = False
|
||||
for task in self._tasks:
|
||||
task.cancel()
|
||||
self.logger.info(f"Agent {self.name} stopped")
|
||||
|
||||
async def send(self, message: InternalMessage, should_log: bool = True):
|
||||
"""
|
||||
Send a message to another agent.
|
||||
|
||||
This method intelligently routes the message:
|
||||
|
||||
* If the target agent is in the same process (found in :class:`AgentDirectory`),
|
||||
the message is put directly into its inbox.
|
||||
* If the target agent is not found locally, the message is serialized and sent
|
||||
via ZeroMQ to the internal publication address.
|
||||
|
||||
:param message: The message to send.
|
||||
"""
|
||||
message.sender = self.name
|
||||
to = message.to
|
||||
receivers = [to] if isinstance(to, str) else to
|
||||
|
||||
for receiver in receivers:
|
||||
target = AgentDirectory.get(receiver)
|
||||
|
||||
if target:
|
||||
await target.inbox.put(message)
|
||||
if should_log:
|
||||
self.logger.debug(
|
||||
f"Sent message {message.body} to {message.to} via regular inbox."
|
||||
)
|
||||
else:
|
||||
# Apparently target agent is on a different process, send via ZMQ
|
||||
topic = f"internal/{receiver}".encode()
|
||||
body = message.model_dump_json().encode()
|
||||
await self._internal_pub_socket.send_multipart([topic, body])
|
||||
if should_log:
|
||||
self.logger.debug(f"Sent message {message.body} to {message.to} via ZMQ.")
|
||||
|
||||
async def _process_inbox(self):
|
||||
"""
|
||||
Internal loop that processes messages from the inbox.
|
||||
|
||||
Reads messages from ``self.inbox`` and passes them to :meth:`handle_message`.
|
||||
"""
|
||||
while self._running:
|
||||
msg = await self.inbox.get()
|
||||
await self.handle_message(msg)
|
||||
|
||||
async def _receive_internal_zmq_loop(self):
|
||||
"""
|
||||
Internal loop that listens for ZMQ messages.
|
||||
|
||||
Subscribes to ``internal/<agent_name>`` topics. When a message is received,
|
||||
it is deserialized into an :class:`InternalMessage` and put into the local inbox.
|
||||
This bridges the gap between inter-process ZMQ communication and the intra-process inbox.
|
||||
"""
|
||||
while self._running:
|
||||
try:
|
||||
_, body = await self._internal_sub_socket.recv_multipart()
|
||||
|
||||
msg = InternalMessage.model_validate_json(body)
|
||||
|
||||
await self.inbox.put(msg)
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception:
|
||||
self.logger.exception("Could not process ZMQ message.")
|
||||
|
||||
async def handle_message(self, msg: InternalMessage):
|
||||
"""
|
||||
Handle an incoming message.
|
||||
|
||||
This method must be overridden by subclasses to define how the agent reacts to messages.
|
||||
|
||||
:param msg: The received message.
|
||||
:raises NotImplementedError: If not overridden by the subclass.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def add_behavior(self, coro: Coroutine) -> Task:
|
||||
"""
|
||||
Add a background behavior (task) to the agent.
|
||||
|
||||
This is the preferred way to run continuous loops or long-running tasks within an agent.
|
||||
The task is tracked and will be automatically cancelled when :meth:`stop` is called.
|
||||
|
||||
:param coro: The coroutine to execute as a task.
|
||||
"""
|
||||
|
||||
async def try_coro(coro_: Coroutine):
|
||||
try:
|
||||
await coro_
|
||||
except asyncio.CancelledError:
|
||||
self.logger.debug("A behavior was canceled successfully: %s", coro_)
|
||||
except Exception:
|
||||
self.logger.warning("An exception occurred in a behavior.", exc_info=True)
|
||||
|
||||
task = asyncio.create_task(try_coro(coro))
|
||||
self._tasks.add(task)
|
||||
task.add_done_callback(self._tasks.discard)
|
||||
return task
|
||||
215
src/control_backend/core/config.py
Normal file
215
src/control_backend/core/config.py
Normal file
@@ -0,0 +1,215 @@
|
||||
"""
|
||||
An exhaustive overview of configurable options. All of these can be set using environment variables
|
||||
by nesting with double underscores (__). Start from the ``Settings`` class.
|
||||
|
||||
For example, ``settings.ri_host`` becomes ``RI_HOST``, and
|
||||
``settings.zmq_settings.ri_communication_address`` becomes
|
||||
``ZMQ_SETTINGS__RI_COMMUNICATION_ADDRESS``.
|
||||
"""
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
|
||||
class ZMQSettings(BaseModel):
|
||||
"""
|
||||
Configuration for ZeroMQ (ZMQ) addresses used for inter-process communication.
|
||||
|
||||
:ivar internal_pub_address: Address for the internal PUB socket.
|
||||
:ivar internal_sub_address: Address for the internal SUB socket.
|
||||
:ivar ri_communication_address: Address for the endpoint that the Robot Interface connects to.
|
||||
:ivar vad_pub_address: Address that the VAD agent binds to and publishes audio segments to.
|
||||
"""
|
||||
|
||||
# ATTENTION: When adding/removing settings, make sure to update the .env.example file
|
||||
|
||||
internal_pub_address: str = "tcp://localhost:5560"
|
||||
internal_sub_address: str = "tcp://localhost:5561"
|
||||
ri_communication_address: str = "tcp://*:5555"
|
||||
internal_gesture_rep_adress: str = "tcp://localhost:7788"
|
||||
vad_pub_address: str = "inproc://vad_stream"
|
||||
|
||||
|
||||
class AgentSettings(BaseModel):
|
||||
"""
|
||||
Names of the various agents in the system. These names are used for routing messages.
|
||||
|
||||
:ivar bdi_core_name: Name of the BDI Core Agent.
|
||||
:ivar bdi_program_manager_name: Name of the BDI Program Manager Agent.
|
||||
:ivar text_belief_extractor_name: Name of the Text Belief Extractor Agent.
|
||||
:ivar vad_name: Name of the Voice Activity Detection (VAD) Agent.
|
||||
:ivar llm_name: Name of the Large Language Model (LLM) Agent.
|
||||
:ivar test_name: Name of the Test Agent.
|
||||
:ivar transcription_name: Name of the Transcription Agent.
|
||||
:ivar ri_communication_name: Name of the RI Communication Agent.
|
||||
:ivar robot_speech_name: Name of the Robot Speech Agent.
|
||||
"""
|
||||
|
||||
# ATTENTION: When adding/removing settings, make sure to update the .env.example file
|
||||
|
||||
# agent names
|
||||
bdi_core_name: str = "bdi_core_agent"
|
||||
bdi_program_manager_name: str = "bdi_program_manager_agent"
|
||||
text_belief_extractor_name: str = "text_belief_extractor_agent"
|
||||
vad_name: str = "vad_agent"
|
||||
llm_name: str = "llm_agent"
|
||||
test_name: str = "test_agent"
|
||||
transcription_name: str = "transcription_agent"
|
||||
ri_communication_name: str = "ri_communication_agent"
|
||||
robot_speech_name: str = "robot_speech_agent"
|
||||
robot_gesture_name: str = "robot_gesture_agent"
|
||||
user_interrupt_name: str = "user_interrupt_agent"
|
||||
|
||||
|
||||
class BehaviourSettings(BaseModel):
|
||||
"""
|
||||
Configuration for agent behaviors and parameters.
|
||||
|
||||
:ivar sleep_s: Default sleep time in seconds for loops.
|
||||
:ivar comm_setup_max_retries: Maximum number of retries for setting up communication.
|
||||
:ivar socket_poller_timeout_ms: Timeout in milliseconds for socket polling.
|
||||
:ivar vad_prob_threshold: Probability threshold for Voice Activity Detection.
|
||||
:ivar vad_initial_since_speech: Initial value for 'since speech' counter in VAD.
|
||||
:ivar vad_non_speech_patience_chunks: Number of non-speech chunks to wait before speech ended.
|
||||
:ivar vad_begin_silence_chunks: The number of chunks of silence to prepend to speech chunks.
|
||||
:ivar transcription_max_concurrent_tasks: Maximum number of concurrent transcription tasks.
|
||||
:ivar transcription_words_per_minute: Estimated words per minute for transcription timing.
|
||||
:ivar transcription_words_per_token: Estimated words per token for transcription timing.
|
||||
:ivar transcription_token_buffer: Buffer for transcription tokens.
|
||||
:ivar conversation_history_length_limit: The maximum amount of messages to extract beliefs from.
|
||||
:ivar trigger_time_to_wait: Amount of milliseconds to wait before informing the UI about trigger
|
||||
completion.
|
||||
"""
|
||||
|
||||
# ATTENTION: When adding/removing settings, make sure to update the .env.example file
|
||||
|
||||
sleep_s: float = 1.0
|
||||
comm_setup_max_retries: int = 5
|
||||
socket_poller_timeout_ms: int = 100
|
||||
|
||||
# VAD settings
|
||||
vad_prob_threshold: float = 0.5
|
||||
vad_initial_since_speech: int = 100
|
||||
vad_non_speech_patience_chunks: int = 15
|
||||
vad_begin_silence_chunks: int = 6
|
||||
|
||||
# transcription behaviour
|
||||
transcription_max_concurrent_tasks: int = 3
|
||||
transcription_words_per_minute: int = 300
|
||||
transcription_words_per_token: float = 0.75 # (3 words = 4 tokens)
|
||||
transcription_token_buffer: int = 10
|
||||
|
||||
# Text belief extractor settings
|
||||
conversation_history_length_limit: int = 10
|
||||
|
||||
# AgentSpeak related settings
|
||||
trigger_time_to_wait: int = 2000
|
||||
agentspeak_file: str = "src/control_backend/agents/bdi/agentspeak.asl"
|
||||
|
||||
|
||||
class LLMSettings(BaseModel):
|
||||
"""
|
||||
Configuration for the Large Language Model (LLM).
|
||||
|
||||
:ivar local_llm_url: URL for the local LLM API.
|
||||
:ivar local_llm_model: Name of the local LLM model to use.
|
||||
:ivar chat_temperature: The temperature to use while generating chat responses.
|
||||
:ivar code_temperature: The temperature to use while generating code-like responses like during
|
||||
belief inference.
|
||||
:ivar n_parallel: The number of parallel calls allowed to be made to the LLM.
|
||||
"""
|
||||
|
||||
# ATTENTION: When adding/removing settings, make sure to update the .env.example file
|
||||
|
||||
local_llm_url: str = "http://localhost:1234/v1/chat/completions"
|
||||
local_llm_model: str = "gpt-oss"
|
||||
api_key: str = ""
|
||||
chat_temperature: float = 1.0
|
||||
code_temperature: float = 0.3
|
||||
n_parallel: int = 4
|
||||
|
||||
|
||||
class VADSettings(BaseModel):
|
||||
"""
|
||||
Configuration for Voice Activity Detection (VAD) model.
|
||||
|
||||
:ivar repo_or_dir: Repository or directory for the VAD model.
|
||||
:ivar model_name: Name of the VAD model.
|
||||
:ivar sample_rate_hz: Sample rate in Hz for the VAD model.
|
||||
"""
|
||||
|
||||
# ATTENTION: When adding/removing settings, make sure to update the .env.example file
|
||||
|
||||
repo_or_dir: str = "snakers4/silero-vad"
|
||||
model_name: str = "silero_vad"
|
||||
sample_rate_hz: int = 16000
|
||||
|
||||
|
||||
class SpeechModelSettings(BaseModel):
|
||||
"""
|
||||
Configuration for speech recognition models.
|
||||
|
||||
:ivar mlx_model_name: Model name for MLX-based speech recognition.
|
||||
:ivar openai_model_name: Model name for OpenAI-based speech recognition.
|
||||
"""
|
||||
|
||||
# ATTENTION: When adding/removing settings, make sure to update the .env.example file
|
||||
|
||||
# model identifiers for speech recognition
|
||||
mlx_model_name: str = "mlx-community/whisper-small.en-mlx"
|
||||
openai_model_name: str = "small.en"
|
||||
|
||||
|
||||
class LoggingSettings(BaseModel):
|
||||
"""
|
||||
Configuration for logging.
|
||||
|
||||
:ivar logging_config_file: Path to the logging configuration file.
|
||||
:ivar experiment_log_directory: Location of the experiment logs. Must match the logging config.
|
||||
:ivar experiment_logger_name: Name of the experiment logger. Must match the logging config.
|
||||
"""
|
||||
|
||||
logging_config_file: str = ".logging_config.yaml"
|
||||
experiment_log_directory: str = "experiment_logs"
|
||||
experiment_logger_name: str = "experiment"
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
"""
|
||||
Global application settings.
|
||||
|
||||
:ivar app_title: Title of the application.
|
||||
:ivar ui_url: URL of the frontend UI.
|
||||
:ivar ri_host: The hostname of the Robot Interface.
|
||||
:ivar zmq_settings: ZMQ configuration.
|
||||
:ivar agent_settings: Agent name configuration.
|
||||
:ivar behaviour_settings: Behavior configuration.
|
||||
:ivar vad_settings: VAD model configuration.
|
||||
:ivar speech_model_settings: Speech model configuration.
|
||||
:ivar llm_settings: LLM configuration.
|
||||
"""
|
||||
|
||||
app_title: str = "PepperPlus"
|
||||
|
||||
ui_url: str = "http://localhost:5173"
|
||||
|
||||
ri_host: str = "localhost"
|
||||
|
||||
logging_settings: LoggingSettings = LoggingSettings()
|
||||
|
||||
zmq_settings: ZMQSettings = ZMQSettings()
|
||||
|
||||
agent_settings: AgentSettings = AgentSettings()
|
||||
|
||||
behaviour_settings: BehaviourSettings = BehaviourSettings()
|
||||
|
||||
vad_settings: VADSettings = VADSettings()
|
||||
|
||||
speech_model_settings: SpeechModelSettings = SpeechModelSettings()
|
||||
|
||||
llm_settings: LLMSettings = LLMSettings()
|
||||
|
||||
model_config = SettingsConfigDict(env_file=".env", env_nested_delimiter="__")
|
||||
|
||||
|
||||
settings = Settings()
|
||||
4
src/control_backend/logging/__init__.py
Normal file
4
src/control_backend/logging/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from .dated_file_handler import DatedFileHandler as DatedFileHandler
|
||||
from .optional_field_formatter import OptionalFieldFormatter as OptionalFieldFormatter
|
||||
from .partial_filter import PartialFilter as PartialFilter
|
||||
from .setup_logging import setup_logging as setup_logging
|
||||
38
src/control_backend/logging/dated_file_handler.py
Normal file
38
src/control_backend/logging/dated_file_handler.py
Normal file
@@ -0,0 +1,38 @@
|
||||
from datetime import datetime
|
||||
from logging import FileHandler
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
class DatedFileHandler(FileHandler):
|
||||
def __init__(self, file_prefix: str, **kwargs):
|
||||
if not file_prefix:
|
||||
raise ValueError("`file_prefix` argument cannot be empty.")
|
||||
self._file_prefix = file_prefix
|
||||
kwargs["filename"] = self._make_filename()
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def _make_filename(self) -> str:
|
||||
"""
|
||||
Create the filename for the current logfile, using the configured file prefix and the
|
||||
current date and time. If the directory does not exist, it gets created.
|
||||
|
||||
:return: A filepath.
|
||||
"""
|
||||
filepath = Path(f"{self._file_prefix}-{datetime.now():%Y%m%d-%H%M%S}.log")
|
||||
if not filepath.parent.is_dir():
|
||||
filepath.parent.mkdir(parents=True, exist_ok=True)
|
||||
return str(filepath)
|
||||
|
||||
def do_rollover(self):
|
||||
"""
|
||||
Close the current logfile and create a new one with the current date and time.
|
||||
"""
|
||||
self.acquire()
|
||||
try:
|
||||
if self.stream:
|
||||
self.stream.close()
|
||||
|
||||
self.baseFilename = self._make_filename()
|
||||
self.stream = self._open()
|
||||
finally:
|
||||
self.release()
|
||||
67
src/control_backend/logging/optional_field_formatter.py
Normal file
67
src/control_backend/logging/optional_field_formatter.py
Normal file
@@ -0,0 +1,67 @@
|
||||
import logging
|
||||
import re
|
||||
|
||||
|
||||
class OptionalFieldFormatter(logging.Formatter):
|
||||
"""
|
||||
Logging formatter that supports optional fields marked by `?`.
|
||||
|
||||
Optional fields are denoted by placing a `?` after the field name inside
|
||||
the parentheses, e.g., `%(role?)s`. If the field is not provided in the
|
||||
log call's `extra` dict, it will use the default value from `defaults`
|
||||
or `None` if no default is specified.
|
||||
|
||||
:param fmt: Format string with optional `%(name?)s` style fields.
|
||||
:type fmt: str or None
|
||||
:param datefmt: Date format string, passed to parent :class:`logging.Formatter`.
|
||||
:type datefmt: str or None
|
||||
:param style: Formatting style, must be '%'. Passed to parent.
|
||||
:type style: str
|
||||
:param defaults: Default values for optional fields when not provided.
|
||||
:type defaults: dict or None
|
||||
|
||||
:example:
|
||||
|
||||
>>> formatter = OptionalFieldFormatter(
|
||||
... fmt="%(asctime)s %(levelname)s %(role?)s %(message)s",
|
||||
... defaults={"role": ""-""}
|
||||
... )
|
||||
>>> handler = logging.StreamHandler()
|
||||
>>> handler.setFormatter(formatter)
|
||||
>>> logger = logging.getLogger(__name__)
|
||||
>>> logger.addHandler(handler)
|
||||
>>>
|
||||
>>> logger.chat("Hello there!", extra={"role": "USER"})
|
||||
2025-01-15 10:30:00 CHAT USER Hello there!
|
||||
>>>
|
||||
>>> logger.info("A logging message")
|
||||
2025-01-15 10:30:01 INFO - A logging message
|
||||
|
||||
.. note::
|
||||
Only `%`-style formatting is supported. The `{` and `$` styles are not
|
||||
compatible with this formatter.
|
||||
|
||||
.. seealso::
|
||||
:class:`logging.Formatter` for base formatter documentation.
|
||||
"""
|
||||
|
||||
# Match %(name?)s or %(name?)d etc.
|
||||
OPTIONAL_PATTERN = re.compile(r"%\((\w+)\?\)([sdifFeEgGxXocrba%])")
|
||||
|
||||
def __init__(self, fmt=None, datefmt=None, style="%", defaults=None):
|
||||
self.defaults = defaults or {}
|
||||
|
||||
self.optional_fields = set(self.OPTIONAL_PATTERN.findall(fmt or ""))
|
||||
|
||||
# Convert %(name?)s to %(name)s for standard formatting
|
||||
normalized_fmt = self.OPTIONAL_PATTERN.sub(r"%(\1)\2", fmt or "")
|
||||
|
||||
super().__init__(normalized_fmt, datefmt, style)
|
||||
|
||||
def format(self, record):
|
||||
for field, _ in self.optional_fields:
|
||||
if not hasattr(record, field):
|
||||
default = self.defaults.get(field, None)
|
||||
setattr(record, field, default)
|
||||
|
||||
return super().format(record)
|
||||
10
src/control_backend/logging/partial_filter.py
Normal file
10
src/control_backend/logging/partial_filter.py
Normal file
@@ -0,0 +1,10 @@
|
||||
import logging
|
||||
|
||||
|
||||
class PartialFilter(logging.Filter):
|
||||
"""
|
||||
Class to filter any log records that have the "partial" attribute set to ``True``.
|
||||
"""
|
||||
|
||||
def filter(self, record):
|
||||
return getattr(record, "partial", False) is not True
|
||||
78
src/control_backend/logging/setup_logging.py
Normal file
78
src/control_backend/logging/setup_logging.py
Normal file
@@ -0,0 +1,78 @@
|
||||
import logging
|
||||
import logging.config
|
||||
import os
|
||||
|
||||
import yaml
|
||||
import zmq
|
||||
from zmq.log.handlers import PUBHandler
|
||||
|
||||
from control_backend.core.config import settings
|
||||
|
||||
|
||||
def add_logging_level(level_name: str, level_num: int, method_name: str | None = None) -> None:
|
||||
"""
|
||||
Adds a logging level to the `logging` module and the
|
||||
currently configured logging class.
|
||||
"""
|
||||
if not method_name:
|
||||
method_name = level_name.lower()
|
||||
|
||||
if hasattr(logging, level_name):
|
||||
raise AttributeError(f"{level_name} already defined in logging module")
|
||||
if hasattr(logging, method_name):
|
||||
raise AttributeError(f"{method_name} already defined in logging module")
|
||||
if hasattr(logging.getLoggerClass(), method_name):
|
||||
raise AttributeError(f"{method_name} already defined in logger class")
|
||||
|
||||
def log_for_level(self, message, *args, **kwargs):
|
||||
if self.isEnabledFor(level_num):
|
||||
self._log(level_num, message, args, **kwargs)
|
||||
|
||||
def log_to_root(message, *args, **kwargs):
|
||||
logging.log(level_num, message, *args, **kwargs)
|
||||
|
||||
logging.addLevelName(level_num, level_name)
|
||||
setattr(logging, level_name, level_num)
|
||||
setattr(logging.getLoggerClass(), method_name, log_for_level)
|
||||
setattr(logging, method_name, log_to_root)
|
||||
|
||||
|
||||
def setup_logging(path: str = settings.logging_settings.logging_config_file) -> None:
|
||||
"""
|
||||
Setup logging configuration of the CB. Tries to load the logging configuration from a file,
|
||||
in which we specify custom loggers, formatters, handlers, etc.
|
||||
:param path:
|
||||
:return:
|
||||
"""
|
||||
if os.path.exists(path):
|
||||
with open(path) as f:
|
||||
try:
|
||||
config = yaml.safe_load(f.read())
|
||||
except (AttributeError, yaml.YAMLError) as e:
|
||||
logging.warning(f"Could not load logging configuration: {e}")
|
||||
config = {}
|
||||
|
||||
custom_levels = config.get("custom_levels", {}) or {}
|
||||
for level_name, level_num in custom_levels.items():
|
||||
add_logging_level(level_name, level_num)
|
||||
|
||||
if config.get("handlers") is not None and config.get("handlers").get("ui"):
|
||||
pub_socket = zmq.Context.instance().socket(zmq.PUB)
|
||||
pub_socket.connect(settings.zmq_settings.internal_pub_address)
|
||||
config["handlers"]["ui"]["interface_or_socket"] = pub_socket
|
||||
|
||||
logging.config.dictConfig(config)
|
||||
|
||||
# Patch ZMQ PUBHandler to know about custom levels
|
||||
if custom_levels:
|
||||
for logger_name in config.get("loggers", {}):
|
||||
logger = logging.getLogger(logger_name)
|
||||
for handler in logger.handlers:
|
||||
if isinstance(handler, PUBHandler):
|
||||
# Use the INFO formatter as the default template
|
||||
default_fmt = handler.formatters[logging.INFO]
|
||||
for level_num in custom_levels.values():
|
||||
handler.setFormatter(default_fmt, level=level_num)
|
||||
|
||||
else:
|
||||
logging.warning("Logging config file not found. Using default logging configuration.")
|
||||
189
src/control_backend/main.py
Normal file
189
src/control_backend/main.py
Normal file
@@ -0,0 +1,189 @@
|
||||
"""
|
||||
Control Backend Main Application.
|
||||
|
||||
This module defines the FastAPI application that serves as the entry point for the
|
||||
Control Backend. It manages the lifecycle of the entire system, including:
|
||||
|
||||
1. **Socket Initialization**: Setting up the internal ZeroMQ PUB/SUB proxy for agent communication.
|
||||
2. **Agent Management**: Instantiating and starting all agents.
|
||||
3. **API Routing**: Exposing REST endpoints for external interaction.
|
||||
|
||||
Lifecycle Manager
|
||||
-----------------
|
||||
The :func:`lifespan` context manager handles the startup and shutdown sequences:
|
||||
- **Startup**: Configures logging, starts the ZMQ proxy, connects sockets, and launches agents.
|
||||
- **Shutdown**: Handles graceful cleanup (though currently minimal).
|
||||
"""
|
||||
|
||||
import contextlib
|
||||
import logging
|
||||
import threading
|
||||
|
||||
import zmq
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from zmq.asyncio import Context
|
||||
|
||||
# BDI agents
|
||||
from control_backend.agents.bdi import (
|
||||
BDICoreAgent,
|
||||
TextBeliefExtractorAgent,
|
||||
)
|
||||
from control_backend.agents.bdi.bdi_program_manager import BDIProgramManager
|
||||
|
||||
# Communication agents
|
||||
from control_backend.agents.communication import RICommunicationAgent
|
||||
|
||||
# Emotional Agents
|
||||
# LLM Agents
|
||||
from control_backend.agents.llm import LLMAgent
|
||||
|
||||
# User Interrupt Agent
|
||||
from control_backend.agents.user_interrupt.user_interrupt_agent import UserInterruptAgent
|
||||
|
||||
# Other backend imports
|
||||
from control_backend.api.v1.router import api_router
|
||||
from control_backend.core.config import settings
|
||||
from control_backend.logging import setup_logging
|
||||
from control_backend.schemas.program_status import PROGRAM_STATUS, ProgramStatus
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def setup_sockets():
|
||||
"""
|
||||
Initialize and run the internal ZeroMQ Proxy (XPUB/XSUB).
|
||||
|
||||
This proxy acts as the central message bus, forwarding messages published on the
|
||||
internal PUB address to all subscribers on the internal SUB address.
|
||||
"""
|
||||
context = Context.instance()
|
||||
|
||||
internal_pub_socket = context.socket(zmq.XPUB)
|
||||
internal_pub_socket.bind(settings.zmq_settings.internal_sub_address)
|
||||
logger.debug("Internal publishing socket bound to %s", internal_pub_socket)
|
||||
|
||||
internal_sub_socket = context.socket(zmq.XSUB)
|
||||
internal_sub_socket.bind(settings.zmq_settings.internal_pub_address)
|
||||
logger.debug("Internal subscribing socket bound to %s", internal_sub_socket)
|
||||
try:
|
||||
zmq.proxy(internal_sub_socket, internal_pub_socket)
|
||||
except zmq.ZMQError:
|
||||
logger.warning("Error while handling PUB/SUB proxy. Closing sockets.")
|
||||
finally:
|
||||
internal_pub_socket.close()
|
||||
internal_sub_socket.close()
|
||||
|
||||
|
||||
@contextlib.asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
"""
|
||||
Application lifespan context manager to handle startup and shutdown events.
|
||||
"""
|
||||
# --- APPLICATION STARTUP ---
|
||||
setup_logging()
|
||||
logger.info("%s is starting up.", app.title)
|
||||
|
||||
# Initiate sockets
|
||||
proxy_thread = threading.Thread(target=setup_sockets)
|
||||
proxy_thread.daemon = True
|
||||
proxy_thread.start()
|
||||
|
||||
context = Context.instance()
|
||||
|
||||
endpoints_pub_socket = context.socket(zmq.PUB)
|
||||
endpoints_pub_socket.connect(settings.zmq_settings.internal_pub_address)
|
||||
app.state.endpoints_pub_socket = endpoints_pub_socket
|
||||
|
||||
await endpoints_pub_socket.send_multipart([PROGRAM_STATUS, ProgramStatus.STARTING.value])
|
||||
|
||||
# --- Initialize Agents ---
|
||||
logger.info("Initializing and starting agents.")
|
||||
|
||||
agents_to_start = {
|
||||
"RICommunicationAgent": (
|
||||
RICommunicationAgent,
|
||||
{
|
||||
"name": settings.agent_settings.ri_communication_name,
|
||||
"address": settings.zmq_settings.ri_communication_address,
|
||||
"bind": True,
|
||||
},
|
||||
),
|
||||
"LLMAgent": (
|
||||
LLMAgent,
|
||||
{
|
||||
"name": settings.agent_settings.llm_name,
|
||||
},
|
||||
),
|
||||
"BDICoreAgent": (
|
||||
BDICoreAgent,
|
||||
{
|
||||
"name": settings.agent_settings.bdi_core_name,
|
||||
},
|
||||
),
|
||||
"TextBeliefExtractorAgent": (
|
||||
TextBeliefExtractorAgent,
|
||||
{
|
||||
"name": settings.agent_settings.text_belief_extractor_name,
|
||||
},
|
||||
),
|
||||
"ProgramManagerAgent": (
|
||||
BDIProgramManager,
|
||||
{
|
||||
"name": settings.agent_settings.bdi_program_manager_name,
|
||||
},
|
||||
),
|
||||
"UserInterruptAgent": (
|
||||
UserInterruptAgent,
|
||||
{
|
||||
"name": settings.agent_settings.user_interrupt_name,
|
||||
},
|
||||
),
|
||||
}
|
||||
|
||||
agents = []
|
||||
|
||||
for name, (agent_class, kwargs) in agents_to_start.items():
|
||||
try:
|
||||
logger.debug("Starting agent: %s", name)
|
||||
agent_instance = agent_class(**kwargs)
|
||||
await agent_instance.start()
|
||||
agents.append(agent_instance)
|
||||
logger.info("Agent '%s' started successfully.", name)
|
||||
except Exception as e:
|
||||
logger.error("Failed to start agent '%s': %s", name, e, exc_info=True)
|
||||
raise
|
||||
|
||||
logger.info("Application startup complete.")
|
||||
|
||||
await endpoints_pub_socket.send_multipart([PROGRAM_STATUS, ProgramStatus.RUNNING.value])
|
||||
|
||||
yield
|
||||
|
||||
# --- APPLICATION SHUTDOWN ---
|
||||
logger.info("%s is shutting down.", app.title)
|
||||
|
||||
await endpoints_pub_socket.send_multipart([PROGRAM_STATUS, ProgramStatus.STOPPING.value])
|
||||
# Additional shutdown logic goes here
|
||||
for agent in agents:
|
||||
await agent.stop()
|
||||
|
||||
logger.info("Application shutdown complete.")
|
||||
|
||||
|
||||
# if __name__ == "__main__":
|
||||
app = FastAPI(title=settings.app_title, lifespan=lifespan)
|
||||
|
||||
# This middleware allows other origins to communicate with us
|
||||
app.add_middleware(
|
||||
CORSMiddleware, # https://developer.mozilla.org/en-US/docs/Web/HTTP/Guides/CORS
|
||||
allow_origins=[settings.ui_url], # address of our UI application
|
||||
allow_methods=["*"], # GET, POST, etc.
|
||||
)
|
||||
|
||||
app.include_router(api_router, prefix="") # TODO: make prefix /api/v1
|
||||
|
||||
|
||||
@app.get("/")
|
||||
async def root():
|
||||
return {"status": "ok"}
|
||||
0
src/control_backend/schemas/__init__.py
Normal file
0
src/control_backend/schemas/__init__.py
Normal file
25
src/control_backend/schemas/belief_list.py
Normal file
25
src/control_backend/schemas/belief_list.py
Normal file
@@ -0,0 +1,25 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
from control_backend.schemas.program import BaseGoal
|
||||
from control_backend.schemas.program import Belief as ProgramBelief
|
||||
|
||||
|
||||
class BeliefList(BaseModel):
|
||||
"""
|
||||
Represents a list of beliefs, separated from a program. Useful in agents which need to
|
||||
communicate beliefs.
|
||||
|
||||
:ivar: beliefs: The list of beliefs.
|
||||
"""
|
||||
|
||||
beliefs: list[ProgramBelief]
|
||||
|
||||
|
||||
class GoalList(BaseModel):
|
||||
"""
|
||||
Represents a list of goals, used for communicating multiple goals between agents.
|
||||
|
||||
:ivar goals: The list of goals.
|
||||
"""
|
||||
|
||||
goals: list[BaseGoal]
|
||||
35
src/control_backend/schemas/belief_message.py
Normal file
35
src/control_backend/schemas/belief_message.py
Normal file
@@ -0,0 +1,35 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class Belief(BaseModel):
|
||||
"""
|
||||
Represents a single belief in the BDI system.
|
||||
|
||||
:ivar name: The functor or name of the belief (e.g., 'user_said').
|
||||
:ivar arguments: A list of string arguments for the belief, or None if the belief has no
|
||||
arguments.
|
||||
"""
|
||||
|
||||
name: str
|
||||
arguments: list[str] | None = None
|
||||
|
||||
# To make it hashable
|
||||
model_config = {"frozen": True}
|
||||
|
||||
|
||||
class BeliefMessage(BaseModel):
|
||||
"""
|
||||
A container for communicating beliefs between agents.
|
||||
|
||||
:ivar create: Beliefs to create.
|
||||
:ivar delete: Beliefs to delete.
|
||||
:ivar replace: Beliefs to replace. Deletes all beliefs with the same name, replacing them with
|
||||
one new belief.
|
||||
"""
|
||||
|
||||
create: list[Belief] = []
|
||||
delete: list[Belief] = []
|
||||
replace: list[Belief] = []
|
||||
|
||||
def has_values(self) -> bool:
|
||||
return len(self.create) > 0 or len(self.delete) > 0 or len(self.replace) > 0
|
||||
23
src/control_backend/schemas/chat_history.py
Normal file
23
src/control_backend/schemas/chat_history.py
Normal file
@@ -0,0 +1,23 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class ChatMessage(BaseModel):
|
||||
"""
|
||||
Represents a single message in a conversation.
|
||||
|
||||
:ivar role: The role of the speaker (e.g., 'user', 'assistant').
|
||||
:ivar content: The text content of the message.
|
||||
"""
|
||||
|
||||
role: str
|
||||
content: str
|
||||
|
||||
|
||||
class ChatHistory(BaseModel):
|
||||
"""
|
||||
Represents a sequence of chat messages, forming a conversation history.
|
||||
|
||||
:ivar messages: An ordered list of :class:`ChatMessage` objects.
|
||||
"""
|
||||
|
||||
messages: list[ChatMessage]
|
||||
14
src/control_backend/schemas/events.py
Normal file
14
src/control_backend/schemas/events.py
Normal file
@@ -0,0 +1,14 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class ButtonPressedEvent(BaseModel):
|
||||
"""
|
||||
Represents a button press event from the UI.
|
||||
|
||||
:ivar type: The type of event (e.g., 'speech', 'gesture', 'override').
|
||||
:ivar context: Additional data associated with the event (e.g., speech text, gesture name,
|
||||
or ID).
|
||||
"""
|
||||
|
||||
type: str
|
||||
context: str
|
||||
19
src/control_backend/schemas/internal_message.py
Normal file
19
src/control_backend/schemas/internal_message.py
Normal file
@@ -0,0 +1,19 @@
|
||||
from collections.abc import Iterable
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class InternalMessage(BaseModel):
|
||||
"""
|
||||
Standard message envelope for communication between agents within the Control Backend.
|
||||
|
||||
:ivar to: The name(s) of the destination agent(s).
|
||||
:ivar sender: The name of the sending agent.
|
||||
:ivar body: The string payload (often a JSON-serialized model).
|
||||
:ivar thread: An optional thread identifier/topic to categorize the message (e.g., 'beliefs').
|
||||
"""
|
||||
|
||||
to: str | Iterable[str]
|
||||
sender: str | None = None
|
||||
body: str
|
||||
thread: str | None = None
|
||||
18
src/control_backend/schemas/llm_prompt_message.py
Normal file
18
src/control_backend/schemas/llm_prompt_message.py
Normal file
@@ -0,0 +1,18 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class LLMPromptMessage(BaseModel):
|
||||
"""
|
||||
Payload sent from the BDI agent to the LLM agent.
|
||||
|
||||
Contains the user's text input along with the dynamic context (norms and goals)
|
||||
that the LLM should use to generate a response.
|
||||
|
||||
:ivar text: The user's input text.
|
||||
:ivar norms: A list of active behavioral norms.
|
||||
:ivar goals: A list of active goals to pursue.
|
||||
"""
|
||||
|
||||
text: str
|
||||
norms: list[str]
|
||||
goals: list[str]
|
||||
9
src/control_backend/schemas/message.py
Normal file
9
src/control_backend/schemas/message.py
Normal file
@@ -0,0 +1,9 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class Message(BaseModel):
|
||||
"""
|
||||
A simple generic message wrapper, typically used for simple API responses.
|
||||
"""
|
||||
|
||||
message: str
|
||||
311
src/control_backend/schemas/program.py
Normal file
311
src/control_backend/schemas/program.py
Normal file
@@ -0,0 +1,311 @@
|
||||
from enum import Enum
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import UUID4, BaseModel
|
||||
|
||||
|
||||
class ProgramElement(BaseModel):
|
||||
"""
|
||||
Represents a basic element of our behavior program.
|
||||
|
||||
:ivar name: The researcher-assigned name of the element.
|
||||
:ivar id: Unique identifier.
|
||||
"""
|
||||
|
||||
name: str
|
||||
id: UUID4
|
||||
|
||||
# To make program elements hashable
|
||||
model_config = {"frozen": True}
|
||||
|
||||
|
||||
class LogicalOperator(Enum):
|
||||
"""
|
||||
Logical operators for combining beliefs.
|
||||
|
||||
These operators define how beliefs can be combined to form more complex
|
||||
logical conditions. They are used in inferred beliefs to create compound
|
||||
beliefs from simpler ones.
|
||||
|
||||
AND: Both operands must be true for the result to be true.
|
||||
OR: At least one operand must be true for the result to be true.
|
||||
"""
|
||||
|
||||
AND = "AND"
|
||||
OR = "OR"
|
||||
|
||||
|
||||
type Belief = KeywordBelief | SemanticBelief | InferredBelief
|
||||
type BasicBelief = KeywordBelief | SemanticBelief
|
||||
|
||||
|
||||
class KeywordBelief(ProgramElement):
|
||||
"""
|
||||
Represents a belief that is activated when a specific keyword is detected in the user's speech.
|
||||
|
||||
Keyword beliefs provide a simple but effective way to detect specific topics
|
||||
or intentions in user speech. They are triggered when the exact keyword
|
||||
string appears in the transcribed user input.
|
||||
|
||||
:ivar keyword: The string to look for in the transcription.
|
||||
|
||||
Example:
|
||||
A keyword belief with keyword="robot" would be activated when the user
|
||||
says "I like the robot" or "Tell me about robots".
|
||||
"""
|
||||
|
||||
name: str = ""
|
||||
keyword: str
|
||||
|
||||
|
||||
class SemanticBelief(ProgramElement):
|
||||
"""
|
||||
Represents a belief whose truth value is determined by an LLM analyzing the conversation
|
||||
context.
|
||||
|
||||
Semantic beliefs provide more sophisticated belief detection by using
|
||||
an LLM to analyze the conversation context and determine
|
||||
if the belief should be considered true. This allows for more nuanced
|
||||
and context-aware belief evaluation.
|
||||
|
||||
:ivar description: A natural language description of what this belief represents,
|
||||
used as a prompt for the LLM.
|
||||
|
||||
Example:
|
||||
A semantic belief with description="The user is expressing frustration"
|
||||
would be activated when the LLM determines that the user's statements
|
||||
indicate frustration, even if no specific keywords are used.
|
||||
"""
|
||||
|
||||
description: str
|
||||
|
||||
|
||||
class InferredBelief(ProgramElement):
|
||||
"""
|
||||
Represents a belief derived from other beliefs using logical operators.
|
||||
|
||||
Inferred beliefs allow for the creation of complex belief structures by
|
||||
combining simpler beliefs using logical operators. This enables the
|
||||
representation of sophisticated conditions and relationships between
|
||||
different aspects of the conversation or context.
|
||||
|
||||
:ivar operator: The :class:`LogicalOperator` (AND/OR) to apply.
|
||||
:ivar left: The left operand (another belief).
|
||||
:ivar right: The right operand (another belief).
|
||||
"""
|
||||
|
||||
name: str = ""
|
||||
operator: LogicalOperator
|
||||
left: Belief
|
||||
right: Belief
|
||||
|
||||
|
||||
class Norm(ProgramElement):
|
||||
"""
|
||||
Base class for behavioral norms that guide the robot's interactions.
|
||||
|
||||
Norms represent guidelines, principles, or rules that should govern the
|
||||
robot's behavior during interactions. They can be either basic (always
|
||||
active in their phase) or conditional (active only when specific beliefs
|
||||
are true).
|
||||
|
||||
:ivar norm: The textual description of the norm.
|
||||
:ivar critical: Whether this norm is considered critical and should be strictly enforced.
|
||||
|
||||
Critical norms are currently not supported yet, but are intended for norms that should
|
||||
ABSOLUTELY NOT be violated, possible cheched by additional validator agents.
|
||||
"""
|
||||
|
||||
name: str = ""
|
||||
norm: str
|
||||
critical: bool = False
|
||||
|
||||
|
||||
class BasicNorm(Norm):
|
||||
"""
|
||||
A simple behavioral norm that is always considered for activation when its phase is active.
|
||||
|
||||
Basic norms are the most straightforward type of norms. They are active
|
||||
throughout their assigned phase and provide consistent behavioral guidance
|
||||
without any additional conditions.
|
||||
|
||||
These norms are suitable for general principles that should always apply
|
||||
during a particular interaction phase.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class ConditionalNorm(Norm):
|
||||
"""
|
||||
A behavioral norm that is only active when a specific condition (belief) is met.
|
||||
|
||||
Conditional norms provide context-sensitive behavioral guidance. They are
|
||||
only active and considered for activation when their associated condition
|
||||
(belief) is true. This allows for more nuanced and adaptive behavior that
|
||||
responds to the specific context of the interaction.
|
||||
|
||||
An important note, is that the current implementation of these norms for keyword-based beliefs
|
||||
is that they only hold for 1 turn, as keyword-based conditions often express temporary
|
||||
conditions.
|
||||
|
||||
:ivar condition: The :class:`Belief` that must hold for this norm to be active.
|
||||
|
||||
Example:
|
||||
A conditional norm with the condition "user is frustrated" might specify
|
||||
that the robot should use more empathetic language and avoid complex topics.
|
||||
"""
|
||||
|
||||
condition: Belief
|
||||
|
||||
|
||||
type PlanElement = Goal | Action
|
||||
|
||||
|
||||
class Plan(ProgramElement):
|
||||
"""
|
||||
Represents a list of steps to execute. Each of these steps can be a goal (with its own plan)
|
||||
or a simple action.
|
||||
|
||||
Plans define sequences of actions and subgoals that the robot should execute
|
||||
to achieve a particular objective. They form the procedural knowledge of
|
||||
the behavior program, specifying what the robot should do in different
|
||||
situations.
|
||||
|
||||
:ivar steps: The actions or subgoals to execute, in order.
|
||||
"""
|
||||
|
||||
name: str = ""
|
||||
steps: list[PlanElement]
|
||||
|
||||
|
||||
class BaseGoal(ProgramElement):
|
||||
"""
|
||||
Represents an objective to be achieved. This base version does not include a plan to achieve
|
||||
this goal, and is used in semantic belief extraction.
|
||||
|
||||
:ivar description: A description of the goal, used to determine if it has been achieved.
|
||||
:ivar can_fail: Whether we can fail to achieve the goal after executing the plan.
|
||||
|
||||
The can_fail attribute determines whether goal achievement is binary
|
||||
(success/failure) or whether it can be determined through conversation
|
||||
analysis.
|
||||
"""
|
||||
|
||||
description: str = ""
|
||||
can_fail: bool = True
|
||||
|
||||
|
||||
class Goal(BaseGoal):
|
||||
"""
|
||||
Represents an objective to be achieved. To reach the goal, we should execute the corresponding
|
||||
plan. It inherits from the BaseGoal a variable `can_fail`, which, if true, will cause the
|
||||
completion to be determined based on the conversation.
|
||||
|
||||
Goals extend base goals by including a specific plan to achieve the objective.
|
||||
They form the core of the robot's proactive behavior, defining both what
|
||||
should be accomplished and how to accomplish it.
|
||||
|
||||
Instances of this goal are not hashable because a plan is not hashable.
|
||||
|
||||
:ivar plan: The plan to execute.
|
||||
"""
|
||||
|
||||
plan: Plan
|
||||
|
||||
|
||||
type Action = SpeechAction | GestureAction | LLMAction
|
||||
|
||||
|
||||
class SpeechAction(ProgramElement):
|
||||
"""
|
||||
An action where the robot speaks a predefined literal text.
|
||||
|
||||
:ivar text: The text content to be spoken.
|
||||
"""
|
||||
|
||||
name: str = ""
|
||||
text: str
|
||||
|
||||
|
||||
class Gesture(BaseModel):
|
||||
"""
|
||||
Defines a physical gesture for the robot to perform.
|
||||
|
||||
:ivar type: Whether to use a specific "single" gesture or a random one from a "tag" category.
|
||||
:ivar name: The identifier for the gesture or tag.
|
||||
|
||||
The type field determines how the gesture is selected:
|
||||
- "single": Use the specific gesture identified by name
|
||||
- "tag": Select a random gesture from the category identified by name
|
||||
"""
|
||||
|
||||
type: Literal["tag", "single"]
|
||||
name: str
|
||||
|
||||
|
||||
class GestureAction(ProgramElement):
|
||||
"""
|
||||
An action where the robot performs a physical gesture.
|
||||
|
||||
:ivar gesture: The :class:`Gesture` definition.
|
||||
"""
|
||||
|
||||
name: str = ""
|
||||
gesture: Gesture
|
||||
|
||||
|
||||
class LLMAction(ProgramElement):
|
||||
"""
|
||||
An action that triggers an LLM-generated conversational response.
|
||||
|
||||
:ivar goal: A temporary conversational goal to guide the LLM's response generation.
|
||||
|
||||
The goal parameter provides high-level guidance to the LLM about what
|
||||
the response should aim to achieve, while allowing the LLM flexibility
|
||||
in how to express it.
|
||||
"""
|
||||
|
||||
name: str = ""
|
||||
goal: str
|
||||
|
||||
|
||||
class Trigger(ProgramElement):
|
||||
"""
|
||||
Defines a reactive behavior: when the condition (belief) is met, the plan is executed.
|
||||
|
||||
:ivar condition: The :class:`Belief` that triggers this behavior.
|
||||
:ivar plan: The :class:`Plan` to execute upon activation.
|
||||
"""
|
||||
|
||||
condition: Belief
|
||||
plan: Plan
|
||||
|
||||
|
||||
class Phase(ProgramElement):
|
||||
"""
|
||||
A logical stage in the interaction program, grouping norms, goals, and triggers.
|
||||
|
||||
:ivar norms: List of norms active during this phase.
|
||||
:ivar goals: List of goals the robot pursues in this phase.
|
||||
:ivar triggers: List of reactive behaviors defined for this phase.
|
||||
"""
|
||||
|
||||
name: str = ""
|
||||
norms: list[BasicNorm | ConditionalNorm]
|
||||
goals: list[Goal]
|
||||
triggers: list[Trigger]
|
||||
|
||||
|
||||
class Program(BaseModel):
|
||||
"""
|
||||
The top-level container for a complete robot behavior definition.
|
||||
|
||||
The Program class represents the complete specification of a robot's
|
||||
behavioral logic. It contains all the phases, norms, goals, triggers,
|
||||
and actions that define how the robot should behave during interactions.
|
||||
|
||||
:ivar phases: An ordered list of :class:`Phase` objects defining the interaction flow.
|
||||
"""
|
||||
|
||||
phases: list[Phase]
|
||||
16
src/control_backend/schemas/program_status.py
Normal file
16
src/control_backend/schemas/program_status.py
Normal file
@@ -0,0 +1,16 @@
|
||||
from enum import Enum
|
||||
|
||||
PROGRAM_STATUS = b"internal/program_status"
|
||||
"""A topic key for the program status."""
|
||||
|
||||
|
||||
class ProgramStatus(Enum):
|
||||
"""
|
||||
Used in internal communication, to tell agents what the status of the program is.
|
||||
|
||||
For example, the VAD agent only starts listening when the program is RUNNING.
|
||||
"""
|
||||
|
||||
STARTING = b"starting"
|
||||
RUNNING = b"running"
|
||||
STOPPING = b"stopping"
|
||||
79
src/control_backend/schemas/ri_message.py
Normal file
79
src/control_backend/schemas/ri_message.py
Normal file
@@ -0,0 +1,79 @@
|
||||
from enum import Enum
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel, model_validator
|
||||
|
||||
|
||||
class RIEndpoint(str, Enum):
|
||||
"""
|
||||
Enumeration of valid endpoints for the Robot Interface (RI).
|
||||
"""
|
||||
|
||||
SPEECH = "actuate/speech"
|
||||
GESTURE_SINGLE = "actuate/gesture/single"
|
||||
GESTURE_TAG = "actuate/gesture/tag"
|
||||
PING = "ping"
|
||||
NEGOTIATE_PORTS = "negotiate/ports"
|
||||
PAUSE = ""
|
||||
|
||||
|
||||
class RIMessage(BaseModel):
|
||||
"""
|
||||
Base schema for messages sent to the Robot Interface.
|
||||
|
||||
:ivar endpoint: The target endpoint/action on the RI.
|
||||
:ivar data: The payload associated with the action.
|
||||
"""
|
||||
|
||||
endpoint: RIEndpoint
|
||||
data: Any
|
||||
|
||||
|
||||
class SpeechCommand(RIMessage):
|
||||
"""
|
||||
A specific command to make the robot speak.
|
||||
|
||||
:ivar endpoint: Fixed to ``RIEndpoint.SPEECH``.
|
||||
:ivar data: The text string to be spoken.
|
||||
"""
|
||||
|
||||
endpoint: RIEndpoint = RIEndpoint(RIEndpoint.SPEECH)
|
||||
data: str
|
||||
is_priority: bool = False
|
||||
|
||||
|
||||
class GestureCommand(RIMessage):
|
||||
"""
|
||||
A specific command to make the robot do a gesture.
|
||||
|
||||
:ivar endpoint: Should be ``RIEndpoint.GESTURE_SINGLE`` or ``RIEndpoint.GESTURE_TAG``.
|
||||
:ivar data: The id of the gesture to be executed.
|
||||
"""
|
||||
|
||||
endpoint: Literal[ # pyright: ignore[reportIncompatibleVariableOverride] - We validate this stricter rule ourselves
|
||||
RIEndpoint.GESTURE_SINGLE, RIEndpoint.GESTURE_TAG
|
||||
]
|
||||
data: str
|
||||
is_priority: bool = False
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_endpoint(self):
|
||||
allowed = {
|
||||
RIEndpoint.GESTURE_SINGLE,
|
||||
RIEndpoint.GESTURE_TAG,
|
||||
}
|
||||
if self.endpoint not in allowed:
|
||||
raise ValueError("endpoint must be GESTURE_SINGLE or GESTURE_TAG")
|
||||
return self
|
||||
|
||||
|
||||
class PauseCommand(RIMessage):
|
||||
"""
|
||||
A specific command to pause or unpause the robot's actions.
|
||||
|
||||
:ivar endpoint: Fixed to ``RIEndpoint.PAUSE``.
|
||||
:ivar data: A boolean indicating whether to pause (True) or unpause (False).
|
||||
"""
|
||||
|
||||
endpoint: RIEndpoint = RIEndpoint(RIEndpoint.PAUSE)
|
||||
data: bool
|
||||
Binary file not shown.
207
test/integration/agents/perception/vad_agent/test_vad_agent.py
Normal file
207
test/integration/agents/perception/vad_agent/test_vad_agent.py
Normal file
@@ -0,0 +1,207 @@
|
||||
import random
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
import zmq
|
||||
|
||||
from control_backend.agents.perception.vad_agent import VADAgent
|
||||
from control_backend.schemas.program_status import PROGRAM_STATUS, ProgramStatus
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def zmq_context(mocker):
|
||||
mock_context = mocker.patch("control_backend.agents.perception.vad_agent.azmq.Context.instance")
|
||||
mock_context.return_value = MagicMock()
|
||||
return mock_context
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def per_transcription_agent(mocker):
|
||||
return mocker.patch(
|
||||
"control_backend.agents.perception.vad_agent.TranscriptionAgent", autospec=True
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def torch_load(mocker):
|
||||
mock_torch = mocker.patch("control_backend.agents.perception.vad_agent.torch")
|
||||
model = MagicMock()
|
||||
mock_torch.hub.load.return_value = (model, None)
|
||||
mock_torch.from_numpy.side_effect = lambda arr: arr
|
||||
return mock_torch
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_normal_setup(per_transcription_agent):
|
||||
"""
|
||||
Test that during normal setup, the VAD agent creates a Streaming behavior and creates audio
|
||||
sockets, and starts the TranscriptionAgent without loading real models.
|
||||
"""
|
||||
per_vad_agent = VADAgent("tcp://localhost:12345", False)
|
||||
per_vad_agent._streaming_loop = AsyncMock()
|
||||
|
||||
def swallow_background_task(coro):
|
||||
coro.close()
|
||||
|
||||
per_vad_agent.add_behavior = swallow_background_task
|
||||
|
||||
await per_vad_agent.setup()
|
||||
|
||||
per_transcription_agent.assert_called_once()
|
||||
per_transcription_agent.return_value.start.assert_called_once()
|
||||
per_vad_agent._streaming_loop.assert_called_once()
|
||||
assert per_vad_agent.audio_in_socket is not None
|
||||
assert per_vad_agent.audio_out_socket is not None
|
||||
|
||||
|
||||
@pytest.mark.parametrize("do_bind", [True, False])
|
||||
def test_in_socket_creation(zmq_context, do_bind: bool):
|
||||
"""
|
||||
Test that the VAD agent creates an audio input socket, differentiating between binding and
|
||||
connecting.
|
||||
"""
|
||||
per_vad_agent = VADAgent(f"tcp://{'*' if do_bind else 'localhost'}:12345", do_bind)
|
||||
|
||||
per_vad_agent._connect_audio_in_socket()
|
||||
|
||||
assert per_vad_agent.audio_in_socket is not None
|
||||
|
||||
zmq_context.return_value.socket.assert_called_once_with(zmq.SUB)
|
||||
zmq_context.return_value.socket.return_value.setsockopt_string.assert_called_once_with(
|
||||
zmq.SUBSCRIBE,
|
||||
"",
|
||||
)
|
||||
|
||||
if do_bind:
|
||||
zmq_context.return_value.socket.return_value.bind.assert_called_once_with("tcp://*:12345")
|
||||
else:
|
||||
zmq_context.return_value.socket.return_value.connect.assert_called_once_with(
|
||||
"tcp://localhost:12345"
|
||||
)
|
||||
|
||||
|
||||
def test_out_socket_creation(zmq_context):
|
||||
"""
|
||||
Test that the VAD agent creates an audio output socket correctly.
|
||||
"""
|
||||
per_vad_agent = VADAgent("tcp://localhost:12345", False)
|
||||
|
||||
per_vad_agent._connect_audio_out_socket()
|
||||
|
||||
assert per_vad_agent.audio_out_socket is not None
|
||||
|
||||
zmq_context.return_value.socket.assert_called_once_with(zmq.PUB)
|
||||
zmq_context.return_value.socket.return_value.bind.assert_called_once_with("inproc://vad_stream")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_out_socket_creation_failure(zmq_context):
|
||||
"""
|
||||
Test setup failure when the audio output socket cannot be created.
|
||||
"""
|
||||
zmq_context.return_value.socket.return_value.bind_to_random_port.side_effect = zmq.ZMQBindError
|
||||
per_vad_agent = VADAgent("tcp://localhost:12345", False)
|
||||
per_vad_agent.stop = AsyncMock()
|
||||
per_vad_agent._reset_stream = AsyncMock()
|
||||
per_vad_agent._streaming_loop = AsyncMock()
|
||||
per_vad_agent._connect_audio_out_socket = MagicMock(return_value=None)
|
||||
|
||||
def swallow_background_task(coro):
|
||||
coro.close()
|
||||
|
||||
per_vad_agent.add_behavior = swallow_background_task
|
||||
|
||||
await per_vad_agent.setup()
|
||||
|
||||
assert per_vad_agent.audio_out_socket is None
|
||||
per_vad_agent.stop.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stop(zmq_context, per_transcription_agent):
|
||||
"""
|
||||
Test that when the VAD agent is stopped, the sockets are closed correctly.
|
||||
"""
|
||||
per_vad_agent = VADAgent("tcp://localhost:12345", False)
|
||||
per_vad_agent._reset_stream = AsyncMock()
|
||||
per_vad_agent._streaming_loop = AsyncMock()
|
||||
|
||||
def swallow_background_task(coro):
|
||||
coro.close()
|
||||
|
||||
per_vad_agent.add_behavior = swallow_background_task
|
||||
zmq_context.return_value.socket.return_value.bind_to_random_port.return_value = random.randint(
|
||||
1000,
|
||||
10000,
|
||||
)
|
||||
|
||||
await per_vad_agent.setup()
|
||||
await per_vad_agent.stop()
|
||||
|
||||
assert zmq_context.return_value.socket.return_value.close.call_count == 2
|
||||
assert per_vad_agent.audio_in_socket is None
|
||||
assert per_vad_agent.audio_out_socket is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_application_startup_complete(zmq_context):
|
||||
"""Check that it resets the stream when the program finishes startup."""
|
||||
vad_agent = VADAgent("tcp://localhost:12345", False)
|
||||
vad_agent._running = True
|
||||
vad_agent._reset_stream = AsyncMock()
|
||||
vad_agent.program_sub_socket = AsyncMock()
|
||||
vad_agent.program_sub_socket.close = MagicMock()
|
||||
vad_agent.program_sub_socket.recv_multipart.side_effect = [
|
||||
(PROGRAM_STATUS, ProgramStatus.RUNNING.value),
|
||||
]
|
||||
|
||||
await vad_agent._status_loop()
|
||||
|
||||
vad_agent._reset_stream.assert_called_once()
|
||||
vad_agent.program_sub_socket.close.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_application_other_status(zmq_context):
|
||||
"""
|
||||
Check that it does nothing when the internal communication message is a status update, but not
|
||||
running.
|
||||
"""
|
||||
vad_agent = VADAgent("tcp://localhost:12345", False)
|
||||
vad_agent._running = True
|
||||
vad_agent._reset_stream = AsyncMock()
|
||||
vad_agent.program_sub_socket = AsyncMock()
|
||||
|
||||
vad_agent.program_sub_socket.recv_multipart.side_effect = [
|
||||
(PROGRAM_STATUS, ProgramStatus.STARTING.value),
|
||||
(PROGRAM_STATUS, ProgramStatus.STOPPING.value),
|
||||
]
|
||||
try:
|
||||
# Raises StopAsyncIteration the third time it calls `program_sub_socket.recv_multipart`
|
||||
await vad_agent._status_loop()
|
||||
except StopAsyncIteration:
|
||||
pass
|
||||
|
||||
vad_agent._reset_stream.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_application_message_other(zmq_context):
|
||||
"""
|
||||
Check that it does nothing when there's an internal communication message that is not a status
|
||||
update.
|
||||
"""
|
||||
vad_agent = VADAgent("tcp://localhost:12345", False)
|
||||
vad_agent._running = True
|
||||
vad_agent._reset_stream = AsyncMock()
|
||||
vad_agent.program_sub_socket = AsyncMock()
|
||||
|
||||
vad_agent.program_sub_socket.recv_multipart.side_effect = [(b"internal/other", b"Whatever")]
|
||||
|
||||
try:
|
||||
# Raises StopAsyncIteration the second time it calls `program_sub_socket.recv_multipart`
|
||||
await vad_agent._status_loop()
|
||||
except StopAsyncIteration:
|
||||
pass
|
||||
|
||||
vad_agent._reset_stream.assert_not_called()
|
||||
@@ -0,0 +1,99 @@
|
||||
import os
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
import soundfile as sf
|
||||
import zmq
|
||||
|
||||
from control_backend.agents.perception.vad_agent import VADAgent
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def patch_settings():
|
||||
from control_backend.agents.perception import vad_agent
|
||||
|
||||
vad_agent.settings.behaviour_settings.vad_prob_threshold = 0.5
|
||||
vad_agent.settings.behaviour_settings.vad_non_speech_patience_chunks = 3
|
||||
vad_agent.settings.behaviour_settings.vad_initial_since_speech = 0
|
||||
vad_agent.settings.vad_settings.sample_rate_hz = 16_000
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_torch(mocker):
|
||||
mock_torch = mocker.patch("control_backend.agents.perception.vad_agent.torch")
|
||||
mock_torch.from_numpy.side_effect = lambda arr: arr
|
||||
return mock_torch
|
||||
|
||||
|
||||
def get_audio_chunks() -> list[bytes]:
|
||||
curr_file = os.path.realpath(__file__)
|
||||
curr_dir = os.path.dirname(curr_file)
|
||||
file = f"{curr_dir}/speech_with_pauses_16k_1c_float32.wav"
|
||||
|
||||
chunk_size = 512
|
||||
|
||||
chunks = []
|
||||
|
||||
with sf.SoundFile(file, "r") as f:
|
||||
assert f.samplerate == 16000
|
||||
assert f.channels == 1
|
||||
assert f.subtype == "FLOAT"
|
||||
|
||||
while True:
|
||||
data = f.read(chunk_size, dtype="float32")
|
||||
if len(data) != chunk_size:
|
||||
break
|
||||
|
||||
chunks.append(data.tobytes())
|
||||
|
||||
return chunks
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_real_audio(mocker):
|
||||
"""
|
||||
Test the VAD agent with only input and output mocked. Using the real model, using real audio as
|
||||
input. Ensure that it outputs some fragments with audio.
|
||||
"""
|
||||
audio_chunks = get_audio_chunks()
|
||||
audio_in_socket = AsyncMock()
|
||||
audio_in_socket.recv.side_effect = audio_chunks
|
||||
|
||||
mock_poller: MagicMock = mocker.patch("control_backend.agents.perception.vad_agent.azmq.Poller")
|
||||
mock_poller.return_value.poll = AsyncMock(return_value=[(audio_in_socket, zmq.POLLIN)])
|
||||
|
||||
audio_out_socket = AsyncMock()
|
||||
|
||||
vad_agent = VADAgent("tcp://localhost:12345", False)
|
||||
vad_agent.audio_out_socket = audio_out_socket
|
||||
|
||||
# Use a fake model that marks most chunks as speech and ends with a few silences
|
||||
silence_padding = 5
|
||||
probabilities = [1.0] * len(audio_chunks) + [0.0] * silence_padding
|
||||
chunk_bytes = audio_chunks + [b"\x00" * len(audio_chunks[0])] * silence_padding
|
||||
model_item = MagicMock()
|
||||
model_item.item.side_effect = probabilities
|
||||
vad_agent.model = MagicMock(return_value=model_item)
|
||||
|
||||
class DummyPoller:
|
||||
def __init__(self, data, agent):
|
||||
self.data = data
|
||||
self.agent = agent
|
||||
|
||||
async def poll(self, timeout_ms=None):
|
||||
if self.data:
|
||||
return self.data.pop(0)
|
||||
self.agent._running = False
|
||||
return None
|
||||
|
||||
vad_agent.audio_in_poller = DummyPoller(chunk_bytes, vad_agent)
|
||||
vad_agent._ready = AsyncMock()
|
||||
vad_agent._running = True
|
||||
vad_agent.i_since_speech = 0
|
||||
|
||||
await vad_agent._streaming_loop()
|
||||
|
||||
audio_out_socket.send.assert_called()
|
||||
for args in audio_out_socket.send.call_args_list:
|
||||
assert isinstance(args[0][0], bytes)
|
||||
assert len(args[0][0]) >= 512 * 4 * 3 # Should be at least 3 chunks of audio
|
||||
520
test/unit/agents/actuation/test_robot_gesture_agent.py
Normal file
520
test/unit/agents/actuation/test_robot_gesture_agent.py
Normal file
@@ -0,0 +1,520 @@
|
||||
import json
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
import zmq
|
||||
|
||||
from control_backend.agents.actuation.robot_gesture_agent import RobotGestureAgent
|
||||
from control_backend.core.agent_system import InternalMessage
|
||||
from control_backend.schemas.ri_message import RIEndpoint
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def zmq_context(mocker):
|
||||
"""Mock the ZMQ context."""
|
||||
mock_context = mocker.patch(
|
||||
"control_backend.agents.actuation.robot_gesture_agent.azmq.Context.instance"
|
||||
)
|
||||
mock_context.return_value = MagicMock()
|
||||
return mock_context
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_experiment_logger():
|
||||
with patch("control_backend.agents.actuation.robot_gesture_agent.experiment_logger") as logger:
|
||||
yield logger
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_setup_bind(zmq_context, mocker):
|
||||
"""Setup binds and subscribes to internal commands."""
|
||||
fake_socket = zmq_context.return_value.socket.return_value
|
||||
agent = RobotGestureAgent("robot_gesture", address="tcp://localhost:5556", bind=True)
|
||||
|
||||
settings = mocker.patch("control_backend.agents.actuation.robot_gesture_agent.settings")
|
||||
settings.zmq_settings.internal_sub_address = "tcp://internal:1234"
|
||||
|
||||
def close_coro(coro):
|
||||
coro.close()
|
||||
return MagicMock()
|
||||
|
||||
agent.add_behavior = MagicMock(side_effect=close_coro)
|
||||
|
||||
await agent.setup()
|
||||
|
||||
# Check PUB socket binding
|
||||
fake_socket.bind.assert_any_call("tcp://localhost:5556")
|
||||
# Check REP socket binding
|
||||
fake_socket.bind.assert_called()
|
||||
|
||||
# Check SUB socket connection and subscriptions
|
||||
fake_socket.connect.assert_any_call("tcp://internal:1234")
|
||||
fake_socket.setsockopt.assert_any_call(zmq.SUBSCRIBE, b"command")
|
||||
fake_socket.setsockopt.assert_any_call(zmq.SUBSCRIBE, b"send_gestures")
|
||||
|
||||
# Check behavior was added (twice: once for command loop, once for fetch gestures loop)
|
||||
assert agent.add_behavior.call_count == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_setup_connect(zmq_context, mocker):
|
||||
"""Setup connects when bind=False."""
|
||||
fake_socket = zmq_context.return_value.socket.return_value
|
||||
agent = RobotGestureAgent("robot_gesture", address="tcp://localhost:5556", bind=False)
|
||||
|
||||
settings = mocker.patch("control_backend.agents.actuation.robot_gesture_agent.settings")
|
||||
settings.zmq_settings.internal_sub_address = "tcp://internal:1234"
|
||||
|
||||
def close_coro(coro):
|
||||
coro.close()
|
||||
return MagicMock()
|
||||
|
||||
agent.add_behavior = MagicMock(side_effect=close_coro)
|
||||
|
||||
await agent.setup()
|
||||
|
||||
# Check PUB socket connection (not binding)
|
||||
fake_socket.connect.assert_any_call("tcp://localhost:5556")
|
||||
fake_socket.connect.assert_any_call("tcp://internal:1234")
|
||||
# Check REP socket binding (always binds)
|
||||
fake_socket.bind.assert_called()
|
||||
|
||||
# Check behavior was added (twice)
|
||||
assert agent.add_behavior.call_count == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_message_sends_valid_gesture_command():
|
||||
"""Internal message with valid gesture tag is forwarded to robot pub socket."""
|
||||
pubsocket = AsyncMock()
|
||||
agent = RobotGestureAgent("robot_gesture", gesture_data=["hello", "yes", "no"], address="")
|
||||
agent.pubsocket = pubsocket
|
||||
|
||||
payload = {
|
||||
"endpoint": RIEndpoint.GESTURE_TAG,
|
||||
"data": "hello", # "hello" is in gesture_data
|
||||
}
|
||||
msg = InternalMessage(to="robot", sender="tester", body=json.dumps(payload))
|
||||
|
||||
await agent.handle_message(msg)
|
||||
|
||||
pubsocket.send_json.assert_awaited_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_message_sends_non_gesture_command():
|
||||
"""Internal message with non-gesture endpoint is not forwarded by this agent."""
|
||||
pubsocket = AsyncMock()
|
||||
agent = RobotGestureAgent("robot_gesture", gesture_data=["hello", "yes", "no"], address="")
|
||||
agent.pubsocket = pubsocket
|
||||
|
||||
payload = {"endpoint": "some_other_endpoint", "data": "invalid_tag_not_in_list"}
|
||||
msg = InternalMessage(to="robot", sender="tester", body=json.dumps(payload))
|
||||
|
||||
await agent.handle_message(msg)
|
||||
|
||||
# Non-gesture endpoints should not be forwarded by this agent
|
||||
pubsocket.send_json.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_message_rejects_invalid_gesture_tag():
|
||||
"""Internal message with invalid gesture tag is not forwarded."""
|
||||
pubsocket = AsyncMock()
|
||||
agent = RobotGestureAgent("robot_gesture", gesture_data=["hello", "yes", "no"], address="")
|
||||
agent.pubsocket = pubsocket
|
||||
|
||||
# Use a tag that's not in gesture_data
|
||||
payload = {"endpoint": RIEndpoint.GESTURE_TAG, "data": "invalid_tag_not_in_list"}
|
||||
msg = InternalMessage(to="robot", sender="tester", body=json.dumps(payload))
|
||||
|
||||
await agent.handle_message(msg)
|
||||
|
||||
pubsocket.send_json.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_message_sends_valid_single_gesture_command():
|
||||
"""Internal message with valid single gesture is forwarded."""
|
||||
pubsocket = AsyncMock()
|
||||
agent = RobotGestureAgent("robot_gesture", single_gesture_data=["wave", "point"], address="")
|
||||
agent.pubsocket = pubsocket
|
||||
|
||||
payload = {
|
||||
"endpoint": RIEndpoint.GESTURE_SINGLE,
|
||||
"data": "wave",
|
||||
}
|
||||
msg = InternalMessage(to="robot", sender="tester", body=json.dumps(payload))
|
||||
|
||||
await agent.handle_message(msg)
|
||||
|
||||
pubsocket.send_json.assert_awaited_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_message_rejects_invalid_single_gesture():
|
||||
"""Internal message with invalid single gesture is not forwarded."""
|
||||
pubsocket = AsyncMock()
|
||||
agent = RobotGestureAgent("robot_gesture", single_gesture_data=["wave", "point"], address="")
|
||||
agent.pubsocket = pubsocket
|
||||
|
||||
payload = {
|
||||
"endpoint": RIEndpoint.GESTURE_SINGLE,
|
||||
"data": "dance",
|
||||
}
|
||||
msg = InternalMessage(to="robot", sender="tester", body=json.dumps(payload))
|
||||
|
||||
await agent.handle_message(msg)
|
||||
|
||||
pubsocket.send_json.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_zmq_command_loop_valid_single_gesture_payload():
|
||||
"""UI command with valid single gesture is read from SUB and published."""
|
||||
command = {"endpoint": RIEndpoint.GESTURE_SINGLE, "data": "wave"}
|
||||
fake_socket = AsyncMock()
|
||||
|
||||
async def recv_once():
|
||||
agent._running = False
|
||||
return b"command", json.dumps(command).encode("utf-8")
|
||||
|
||||
fake_socket.recv_multipart = recv_once
|
||||
fake_socket.send_json = AsyncMock()
|
||||
|
||||
agent = RobotGestureAgent("robot_gesture", single_gesture_data=["wave", "point"], address="")
|
||||
agent.subsocket = fake_socket
|
||||
agent.pubsocket = fake_socket
|
||||
agent._running = True
|
||||
|
||||
await agent._zmq_command_loop()
|
||||
|
||||
fake_socket.send_json.assert_awaited_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_message_invalid_payload():
|
||||
"""Invalid payload is caught and does not send."""
|
||||
pubsocket = AsyncMock()
|
||||
agent = RobotGestureAgent("robot_gesture", gesture_data=["hello", "yes", "no"], address="")
|
||||
agent.pubsocket = pubsocket
|
||||
|
||||
msg = InternalMessage(to="robot", sender="tester", body=json.dumps({"bad": "data"}))
|
||||
|
||||
await agent.handle_message(msg)
|
||||
|
||||
pubsocket.send_json.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_zmq_command_loop_valid_gesture_payload():
|
||||
"""UI command with valid gesture tag is read from SUB and published."""
|
||||
command = {"endpoint": RIEndpoint.GESTURE_TAG, "data": "hello"}
|
||||
fake_socket = AsyncMock()
|
||||
|
||||
async def recv_once():
|
||||
# stop after first iteration
|
||||
agent._running = False
|
||||
return b"command", json.dumps(command).encode("utf-8")
|
||||
|
||||
fake_socket.recv_multipart = recv_once
|
||||
fake_socket.send_json = AsyncMock()
|
||||
|
||||
agent = RobotGestureAgent("robot_gesture", gesture_data=["hello", "yes", "no"], address="")
|
||||
agent.subsocket = fake_socket
|
||||
agent.pubsocket = fake_socket
|
||||
agent._running = True
|
||||
|
||||
await agent._zmq_command_loop()
|
||||
|
||||
fake_socket.send_json.assert_awaited_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_zmq_command_loop_valid_non_gesture_payload():
|
||||
"""UI command with non-gesture endpoint is not forwarded by this agent."""
|
||||
command = {"endpoint": "some_other_endpoint", "data": "anything"}
|
||||
fake_socket = AsyncMock()
|
||||
|
||||
async def recv_once():
|
||||
agent._running = False
|
||||
return b"command", json.dumps(command).encode("utf-8")
|
||||
|
||||
fake_socket.recv_multipart = recv_once
|
||||
fake_socket.send_json = AsyncMock()
|
||||
|
||||
agent = RobotGestureAgent("robot_gesture", gesture_data=["hello", "yes", "no"], address="")
|
||||
agent.subsocket = fake_socket
|
||||
agent.pubsocket = fake_socket
|
||||
agent._running = True
|
||||
|
||||
await agent._zmq_command_loop()
|
||||
|
||||
fake_socket.send_json.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_zmq_command_loop_invalid_gesture_tag():
|
||||
"""UI command with invalid gesture tag is not forwarded."""
|
||||
command = {"endpoint": RIEndpoint.GESTURE_TAG, "data": "invalid_tag_not_in_list"}
|
||||
fake_socket = AsyncMock()
|
||||
|
||||
async def recv_once():
|
||||
agent._running = False
|
||||
return b"command", json.dumps(command).encode("utf-8")
|
||||
|
||||
fake_socket.recv_multipart = recv_once
|
||||
fake_socket.send_json = AsyncMock()
|
||||
|
||||
agent = RobotGestureAgent("robot_gesture", gesture_data=["hello", "yes", "no"], address="")
|
||||
agent.subsocket = fake_socket
|
||||
agent.pubsocket = fake_socket
|
||||
agent._running = True
|
||||
|
||||
await agent._zmq_command_loop()
|
||||
|
||||
fake_socket.send_json.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_zmq_command_loop_invalid_json():
|
||||
"""Invalid JSON is ignored without sending."""
|
||||
fake_socket = AsyncMock()
|
||||
|
||||
async def recv_once():
|
||||
agent._running = False
|
||||
return b"command", b"{not_json}"
|
||||
|
||||
fake_socket.recv_multipart = recv_once
|
||||
fake_socket.send_json = AsyncMock()
|
||||
|
||||
agent = RobotGestureAgent("robot_gesture", gesture_data=["hello", "yes", "no"], address="")
|
||||
agent.subsocket = fake_socket
|
||||
agent.pubsocket = fake_socket
|
||||
agent._running = True
|
||||
|
||||
await agent._zmq_command_loop()
|
||||
|
||||
fake_socket.send_json.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_zmq_command_loop_ignores_send_gestures_topic():
|
||||
"""send_gestures topic is ignored in command loop."""
|
||||
fake_socket = AsyncMock()
|
||||
|
||||
async def recv_once():
|
||||
agent._running = False
|
||||
return b"send_gestures", b"{}"
|
||||
|
||||
fake_socket.recv_multipart = recv_once
|
||||
fake_socket.send_json = AsyncMock()
|
||||
|
||||
agent = RobotGestureAgent("robot_gesture", gesture_data=["hello", "yes", "no"], address="")
|
||||
agent.subsocket = fake_socket
|
||||
agent.pubsocket = fake_socket
|
||||
agent._running = True
|
||||
|
||||
await agent._zmq_command_loop()
|
||||
|
||||
fake_socket.send_json.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fetch_gestures_loop_without_amount():
|
||||
"""Fetch gestures request without amount returns all tags."""
|
||||
fake_repsocket = AsyncMock()
|
||||
|
||||
async def recv_once():
|
||||
agent._running = False
|
||||
return b"{}" # Empty JSON request
|
||||
|
||||
fake_repsocket.recv = recv_once
|
||||
fake_repsocket.send = AsyncMock()
|
||||
|
||||
agent = RobotGestureAgent(
|
||||
"robot_gesture", gesture_data=["hello", "yes", "no", "wave", "point"], address=""
|
||||
)
|
||||
agent.repsocket = fake_repsocket
|
||||
agent._running = True
|
||||
|
||||
await agent._fetch_gestures_loop()
|
||||
|
||||
fake_repsocket.send.assert_awaited_once()
|
||||
|
||||
# Check the response contains all tags
|
||||
args, kwargs = fake_repsocket.send.call_args
|
||||
response = json.loads(args[0])
|
||||
assert "tags" in response
|
||||
assert response["tags"] == ["hello", "yes", "no", "wave", "point"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fetch_gestures_loop_with_amount():
|
||||
"""Fetch gestures request with amount returns limited tags."""
|
||||
fake_repsocket = AsyncMock()
|
||||
amount = 3
|
||||
|
||||
async def recv_once():
|
||||
agent._running = False
|
||||
return json.dumps(amount).encode()
|
||||
|
||||
fake_repsocket.recv = recv_once
|
||||
fake_repsocket.send = AsyncMock()
|
||||
|
||||
agent = RobotGestureAgent(
|
||||
"robot_gesture", gesture_data=["hello", "yes", "no", "wave", "point"], address=""
|
||||
)
|
||||
agent.repsocket = fake_repsocket
|
||||
agent._running = True
|
||||
|
||||
await agent._fetch_gestures_loop()
|
||||
|
||||
fake_repsocket.send.assert_awaited_once()
|
||||
|
||||
args, kwargs = fake_repsocket.send.call_args
|
||||
response = json.loads(args[0])
|
||||
assert "tags" in response
|
||||
assert len(response["tags"]) == amount
|
||||
assert response["tags"] == ["hello", "yes", "no"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fetch_gestures_loop_with_integer_request():
|
||||
"""Fetch gestures request with integer amount."""
|
||||
fake_repsocket = AsyncMock()
|
||||
amount = 2
|
||||
|
||||
async def recv_once():
|
||||
agent._running = False
|
||||
return json.dumps(amount).encode()
|
||||
|
||||
fake_repsocket.recv = recv_once
|
||||
fake_repsocket.send = AsyncMock()
|
||||
|
||||
agent = RobotGestureAgent("robot_gesture", gesture_data=["hello", "yes", "no"], address="")
|
||||
agent.repsocket = fake_repsocket
|
||||
agent._running = True
|
||||
|
||||
await agent._fetch_gestures_loop()
|
||||
|
||||
fake_repsocket.send.assert_awaited_once()
|
||||
|
||||
args, kwargs = fake_repsocket.send.call_args
|
||||
response = json.loads(args[0])
|
||||
assert response["tags"] == ["hello", "yes"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fetch_gestures_loop_with_invalid_json():
|
||||
"""Invalid JSON request returns all tags."""
|
||||
fake_repsocket = AsyncMock()
|
||||
|
||||
async def recv_once():
|
||||
agent._running = False
|
||||
return b"not_json"
|
||||
|
||||
fake_repsocket.recv = recv_once
|
||||
fake_repsocket.send = AsyncMock()
|
||||
|
||||
agent = RobotGestureAgent("robot_gesture", gesture_data=["hello", "yes", "no"], address="")
|
||||
agent.repsocket = fake_repsocket
|
||||
agent._running = True
|
||||
|
||||
await agent._fetch_gestures_loop()
|
||||
|
||||
fake_repsocket.send.assert_awaited_once()
|
||||
|
||||
args, kwargs = fake_repsocket.send.call_args
|
||||
response = json.loads(args[0])
|
||||
assert response["tags"] == ["hello", "yes", "no"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fetch_gestures_loop_with_non_integer_json():
|
||||
"""Non-integer JSON request returns all tags."""
|
||||
fake_repsocket = AsyncMock()
|
||||
|
||||
async def recv_once():
|
||||
agent._running = False
|
||||
return json.dumps({"not": "an_integer"}).encode()
|
||||
|
||||
fake_repsocket.recv = recv_once
|
||||
fake_repsocket.send = AsyncMock()
|
||||
|
||||
agent = RobotGestureAgent("robot_gesture", gesture_data=["hello", "yes", "no"], address="")
|
||||
agent.repsocket = fake_repsocket
|
||||
agent._running = True
|
||||
|
||||
await agent._fetch_gestures_loop()
|
||||
|
||||
fake_repsocket.send.assert_awaited_once()
|
||||
|
||||
args, kwargs = fake_repsocket.send.call_args
|
||||
response = json.loads(args[0])
|
||||
assert response["tags"] == ["hello", "yes", "no"]
|
||||
|
||||
|
||||
def test_gesture_data_attribute():
|
||||
"""Test that gesture_data returns the expected list."""
|
||||
gesture_data = ["hello", "yes", "no", "wave"]
|
||||
agent = RobotGestureAgent("robot_gesture", gesture_data=gesture_data, address="")
|
||||
|
||||
assert agent.gesture_data == gesture_data
|
||||
assert isinstance(agent.gesture_data, list)
|
||||
assert len(agent.gesture_data) == 4
|
||||
assert "hello" in agent.gesture_data
|
||||
assert "yes" in agent.gesture_data
|
||||
assert "no" in agent.gesture_data
|
||||
assert "invalid_tag_not_in_list" not in agent.gesture_data
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stop_closes_sockets():
|
||||
"""Stop method closes all sockets."""
|
||||
pubsocket = MagicMock()
|
||||
subsocket = MagicMock()
|
||||
repsocket = MagicMock()
|
||||
agent = RobotGestureAgent("robot_gesture", address="")
|
||||
agent.pubsocket = pubsocket
|
||||
agent.subsocket = subsocket
|
||||
agent.repsocket = repsocket
|
||||
|
||||
await agent.stop()
|
||||
|
||||
pubsocket.close.assert_called_once()
|
||||
subsocket.close.assert_called_once()
|
||||
repsocket.close.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_initialization_with_custom_gesture_data():
|
||||
"""Agent can be initialized with custom gesture data."""
|
||||
custom_gestures = ["custom1", "custom2", "custom3"]
|
||||
agent = RobotGestureAgent("robot_gesture", gesture_data=custom_gestures, address="")
|
||||
|
||||
assert agent.gesture_data == custom_gestures
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fetch_gestures_loop_handles_exception():
|
||||
"""Exception in fetch gestures loop is caught and logged."""
|
||||
fake_repsocket = AsyncMock()
|
||||
|
||||
async def recv_once():
|
||||
agent._running = False
|
||||
raise Exception("Test exception")
|
||||
|
||||
fake_repsocket.recv = recv_once
|
||||
fake_repsocket.send = AsyncMock()
|
||||
|
||||
agent = RobotGestureAgent("robot_gesture", gesture_data=["hello", "yes", "no"], address="")
|
||||
agent.repsocket = fake_repsocket
|
||||
agent.logger = MagicMock()
|
||||
agent._running = True
|
||||
|
||||
# Should not raise exception
|
||||
await agent._fetch_gestures_loop()
|
||||
|
||||
# Exception should be logged
|
||||
agent.logger.exception.assert_called_once()
|
||||
152
test/unit/agents/actuation/test_robot_speech_agent.py
Normal file
152
test/unit/agents/actuation/test_robot_speech_agent.py
Normal file
@@ -0,0 +1,152 @@
|
||||
import json
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
import zmq
|
||||
|
||||
from control_backend.agents.actuation.robot_speech_agent import RobotSpeechAgent
|
||||
from control_backend.core.agent_system import InternalMessage
|
||||
|
||||
|
||||
def mock_speech_agent():
|
||||
agent = RobotSpeechAgent("robot_speech", address="tcp://localhost:5555", bind=False)
|
||||
return agent
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def zmq_context(mocker):
|
||||
mock_context = mocker.patch(
|
||||
"control_backend.agents.actuation.robot_speech_agent.azmq.Context.instance"
|
||||
)
|
||||
mock_context.return_value = MagicMock()
|
||||
return mock_context
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_setup_bind(zmq_context, mocker):
|
||||
"""Setup binds and subscribes to internal commands."""
|
||||
fake_socket = zmq_context.return_value.socket.return_value
|
||||
agent = RobotSpeechAgent("robot_speech", address="tcp://localhost:5555", bind=True)
|
||||
settings = mocker.patch("control_backend.agents.actuation.robot_speech_agent.settings")
|
||||
settings.zmq_settings.internal_sub_address = "tcp://internal:1234"
|
||||
|
||||
def close_coro(coro):
|
||||
coro.close()
|
||||
return MagicMock()
|
||||
|
||||
agent.add_behavior = MagicMock(side_effect=close_coro)
|
||||
|
||||
await agent.setup()
|
||||
|
||||
fake_socket.bind.assert_any_call("tcp://localhost:5555")
|
||||
fake_socket.connect.assert_any_call("tcp://internal:1234")
|
||||
fake_socket.setsockopt.assert_any_call(zmq.SUBSCRIBE, b"command")
|
||||
agent.add_behavior.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_setup_connect(zmq_context, mocker):
|
||||
"""Setup connects when bind=False."""
|
||||
fake_socket = zmq_context.return_value.socket.return_value
|
||||
agent = RobotSpeechAgent("robot_speech", address="tcp://localhost:5555", bind=False)
|
||||
settings = mocker.patch("control_backend.agents.actuation.robot_speech_agent.settings")
|
||||
settings.zmq_settings.internal_sub_address = "tcp://internal:1234"
|
||||
|
||||
def close_coro(coro):
|
||||
coro.close()
|
||||
return MagicMock()
|
||||
|
||||
agent.add_behavior = MagicMock(side_effect=close_coro)
|
||||
|
||||
await agent.setup()
|
||||
|
||||
fake_socket.connect.assert_any_call("tcp://localhost:5555")
|
||||
fake_socket.connect.assert_any_call("tcp://internal:1234")
|
||||
agent.add_behavior.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_message_sends_command():
|
||||
"""Internal message is forwarded to robot pub socket as JSON."""
|
||||
pubsocket = AsyncMock()
|
||||
agent = mock_speech_agent()
|
||||
agent.pubsocket = pubsocket
|
||||
|
||||
payload = {"endpoint": "actuate/speech", "data": "hello", "is_priority": False}
|
||||
msg = InternalMessage(to="robot", sender="tester", body=json.dumps(payload))
|
||||
|
||||
await agent.handle_message(msg)
|
||||
|
||||
pubsocket.send_json.assert_awaited_once_with(payload)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_zmq_command_loop_valid_payload(zmq_context):
|
||||
"""UI command is read from SUB and published."""
|
||||
command = {"endpoint": "actuate/speech", "data": "hello", "is_priority": False}
|
||||
fake_socket = AsyncMock()
|
||||
|
||||
async def recv_once():
|
||||
# stop after first iteration
|
||||
agent._running = False
|
||||
return (b"command", json.dumps(command).encode("utf-8"))
|
||||
|
||||
fake_socket.recv_multipart = recv_once
|
||||
fake_socket.send_json = AsyncMock()
|
||||
agent = mock_speech_agent()
|
||||
agent.subsocket = fake_socket
|
||||
agent.pubsocket = fake_socket
|
||||
agent._running = True
|
||||
|
||||
await agent._zmq_command_loop()
|
||||
|
||||
fake_socket.send_json.assert_awaited_once_with(command)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_zmq_command_loop_invalid_json():
|
||||
"""Invalid JSON is ignored without sending."""
|
||||
fake_socket = AsyncMock()
|
||||
|
||||
async def recv_once():
|
||||
agent._running = False
|
||||
return (b"command", b"{not_json}")
|
||||
|
||||
fake_socket.recv_multipart = recv_once
|
||||
fake_socket.send_json = AsyncMock()
|
||||
agent = mock_speech_agent()
|
||||
agent.subsocket = fake_socket
|
||||
agent.pubsocket = fake_socket
|
||||
agent._running = True
|
||||
|
||||
await agent._zmq_command_loop()
|
||||
|
||||
fake_socket.send_json.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_message_invalid_payload():
|
||||
"""Invalid payload is caught and does not send."""
|
||||
pubsocket = AsyncMock()
|
||||
agent = mock_speech_agent()
|
||||
agent.pubsocket = pubsocket
|
||||
|
||||
msg = InternalMessage(to="robot", sender="tester", body=json.dumps({"bad": "data"}))
|
||||
|
||||
await agent.handle_message(msg)
|
||||
|
||||
pubsocket.send_json.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stop_closes_sockets():
|
||||
pubsocket = MagicMock()
|
||||
subsocket = MagicMock()
|
||||
agent = mock_speech_agent()
|
||||
agent.pubsocket = pubsocket
|
||||
agent.subsocket = subsocket
|
||||
|
||||
await agent.stop()
|
||||
|
||||
pubsocket.close.assert_called_once()
|
||||
subsocket.close.assert_called_once()
|
||||
186
test/unit/agents/bdi/test_agentspeak_ast.py
Normal file
186
test/unit/agents/bdi/test_agentspeak_ast.py
Normal file
@@ -0,0 +1,186 @@
|
||||
import pytest
|
||||
|
||||
from control_backend.agents.bdi.agentspeak_ast import (
|
||||
AstAtom,
|
||||
AstBinaryOp,
|
||||
AstLiteral,
|
||||
AstLogicalExpression,
|
||||
AstNumber,
|
||||
AstPlan,
|
||||
AstProgram,
|
||||
AstRule,
|
||||
AstStatement,
|
||||
AstString,
|
||||
AstVar,
|
||||
BinaryOperatorType,
|
||||
StatementType,
|
||||
TriggerType,
|
||||
_coalesce_expr,
|
||||
)
|
||||
|
||||
|
||||
def test_ast_atom():
|
||||
atom = AstAtom("test")
|
||||
assert str(atom) == "test"
|
||||
assert atom._to_agentspeak() == "test"
|
||||
|
||||
|
||||
def test_ast_var():
|
||||
var = AstVar("Variable")
|
||||
assert str(var) == "Variable"
|
||||
assert var._to_agentspeak() == "Variable"
|
||||
|
||||
|
||||
def test_ast_number():
|
||||
num = AstNumber(42)
|
||||
assert str(num) == "42"
|
||||
num_float = AstNumber(3.14)
|
||||
assert str(num_float) == "3.14"
|
||||
|
||||
|
||||
def test_ast_string():
|
||||
s = AstString("hello")
|
||||
assert str(s) == '"hello"'
|
||||
|
||||
|
||||
def test_ast_literal():
|
||||
lit = AstLiteral("functor", [AstAtom("atom"), AstNumber(1)])
|
||||
assert str(lit) == "functor(atom, 1)"
|
||||
lit_empty = AstLiteral("functor")
|
||||
assert str(lit_empty) == "functor"
|
||||
|
||||
|
||||
def test_ast_binary_op():
|
||||
left = AstNumber(1)
|
||||
right = AstNumber(2)
|
||||
op = AstBinaryOp(left, BinaryOperatorType.GREATER_THAN, right)
|
||||
assert str(op) == "1 > 2"
|
||||
|
||||
# Test logical wrapper
|
||||
assert isinstance(op.left, AstLogicalExpression)
|
||||
assert isinstance(op.right, AstLogicalExpression)
|
||||
|
||||
|
||||
def test_ast_binary_op_parens():
|
||||
# 1 > 2
|
||||
inner = AstBinaryOp(AstNumber(1), BinaryOperatorType.GREATER_THAN, AstNumber(2))
|
||||
# (1 > 2) & 3
|
||||
outer = AstBinaryOp(inner, BinaryOperatorType.AND, AstNumber(3))
|
||||
assert str(outer) == "(1 > 2) & 3"
|
||||
|
||||
# 3 & (1 > 2)
|
||||
outer_right = AstBinaryOp(AstNumber(3), BinaryOperatorType.AND, inner)
|
||||
assert str(outer_right) == "3 & (1 > 2)"
|
||||
|
||||
|
||||
def test_ast_binary_op_parens_negated():
|
||||
inner = AstLogicalExpression(AstAtom("foo"), negated=True)
|
||||
outer = AstBinaryOp(inner, BinaryOperatorType.AND, AstAtom("bar"))
|
||||
# The current implementation checks `if self.left.negated: l_str = f"({l_str})"`
|
||||
# str(inner) is "not foo"
|
||||
# so we expect "(not foo) & bar"
|
||||
assert str(outer) == "(not foo) & bar"
|
||||
|
||||
outer_right = AstBinaryOp(AstAtom("bar"), BinaryOperatorType.AND, inner)
|
||||
assert str(outer_right) == "bar & (not foo)"
|
||||
|
||||
|
||||
def test_ast_logical_expression_negation():
|
||||
expr = AstLogicalExpression(AstAtom("true"), negated=True)
|
||||
assert str(expr) == "not true"
|
||||
|
||||
expr_neg_neg = ~expr
|
||||
assert str(expr_neg_neg) == "true"
|
||||
assert not expr_neg_neg.negated
|
||||
|
||||
# Invert a non-logical expression (wraps it)
|
||||
term = AstAtom("true")
|
||||
inverted = ~term
|
||||
assert isinstance(inverted, AstLogicalExpression)
|
||||
assert inverted.negated
|
||||
assert str(inverted) == "not true"
|
||||
|
||||
|
||||
def test_ast_logical_expression_no_negation():
|
||||
# _as_logical on already logical expression
|
||||
expr = AstLogicalExpression(AstAtom("x"))
|
||||
# Doing binary op will call _as_logical
|
||||
op = AstBinaryOp(expr, BinaryOperatorType.AND, AstAtom("y"))
|
||||
assert isinstance(op.left, AstLogicalExpression)
|
||||
assert op.left is expr # Should reuse instance
|
||||
|
||||
|
||||
def test_ast_operators():
|
||||
t1 = AstAtom("a")
|
||||
t2 = AstAtom("b")
|
||||
|
||||
assert str(t1 & t2) == "a & b"
|
||||
assert str(t1 | t2) == "a | b"
|
||||
assert str(t1 >= t2) == "a >= b"
|
||||
assert str(t1 > t2) == "a > b"
|
||||
assert str(t1 <= t2) == "a <= b"
|
||||
assert str(t1 < t2) == "a < b"
|
||||
assert str(t1 == t2) == "a == b"
|
||||
assert str(t1 != t2) == r"a \== b"
|
||||
|
||||
|
||||
def test_coalesce_expr():
|
||||
t = AstAtom("a")
|
||||
assert str(t & "b") == 'a & "b"'
|
||||
assert str(t & 1) == "a & 1"
|
||||
assert str(t & 1.5) == "a & 1.5"
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
_coalesce_expr(None)
|
||||
|
||||
|
||||
def test_ast_statement():
|
||||
stmt = AstStatement(StatementType.DO_ACTION, AstLiteral("action"))
|
||||
assert str(stmt) == ".action"
|
||||
|
||||
|
||||
def test_ast_rule():
|
||||
# Rule with condition
|
||||
rule = AstRule(AstLiteral("head"), AstLiteral("body"))
|
||||
assert str(rule) == "head :- body."
|
||||
|
||||
# Rule without condition
|
||||
rule_simple = AstRule(AstLiteral("fact"))
|
||||
assert str(rule_simple) == "fact."
|
||||
|
||||
|
||||
def test_ast_plan():
|
||||
plan = AstPlan(
|
||||
TriggerType.ADDED_GOAL,
|
||||
AstLiteral("goal"),
|
||||
[AstLiteral("context")],
|
||||
[AstStatement(StatementType.DO_ACTION, AstLiteral("action"))],
|
||||
)
|
||||
output = str(plan)
|
||||
# verify parts exist
|
||||
assert "+!goal" in output
|
||||
assert ": context" in output
|
||||
assert "<- .action." in output
|
||||
|
||||
|
||||
def test_ast_plan_no_context():
|
||||
plan = AstPlan(
|
||||
TriggerType.ADDED_GOAL,
|
||||
AstLiteral("goal"),
|
||||
[],
|
||||
[AstStatement(StatementType.DO_ACTION, AstLiteral("action"))],
|
||||
)
|
||||
output = str(plan)
|
||||
assert "+!goal" in output
|
||||
assert ": " not in output
|
||||
assert "<- .action." in output
|
||||
|
||||
|
||||
def test_ast_program():
|
||||
prog = AstProgram(
|
||||
rules=[AstRule(AstLiteral("fact"))],
|
||||
plans=[AstPlan(TriggerType.ADDED_BELIEF, AstLiteral("b"), [], [])],
|
||||
)
|
||||
output = str(prog)
|
||||
assert "fact." in output
|
||||
assert "+b" in output
|
||||
187
test/unit/agents/bdi/test_agentspeak_generator.py
Normal file
187
test/unit/agents/bdi/test_agentspeak_generator.py
Normal file
@@ -0,0 +1,187 @@
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
|
||||
from control_backend.agents.bdi.agentspeak_ast import AstProgram
|
||||
from control_backend.agents.bdi.agentspeak_generator import AgentSpeakGenerator
|
||||
from control_backend.schemas.program import (
|
||||
BasicNorm,
|
||||
ConditionalNorm,
|
||||
Gesture,
|
||||
GestureAction,
|
||||
Goal,
|
||||
InferredBelief,
|
||||
KeywordBelief,
|
||||
LLMAction,
|
||||
LogicalOperator,
|
||||
Phase,
|
||||
Plan,
|
||||
Program,
|
||||
SemanticBelief,
|
||||
SpeechAction,
|
||||
Trigger,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def generator():
|
||||
return AgentSpeakGenerator()
|
||||
|
||||
|
||||
def test_generate_empty_program(generator):
|
||||
prog = Program(phases=[])
|
||||
code = generator.generate(prog)
|
||||
assert 'phase("end").' in code
|
||||
assert "!notify_cycle" in code
|
||||
|
||||
|
||||
def test_generate_basic_norm(generator):
|
||||
norm = BasicNorm(id=uuid.uuid4(), name="n1", norm="be nice")
|
||||
phase = Phase(id=uuid.uuid4(), norms=[norm], goals=[], triggers=[])
|
||||
prog = Program(phases=[phase])
|
||||
|
||||
code = generator.generate(prog)
|
||||
assert f'norm("be nice") :- phase("{phase.id}").' in code
|
||||
|
||||
|
||||
def test_generate_critical_norm(generator):
|
||||
norm = BasicNorm(id=uuid.uuid4(), name="n1", norm="safety", critical=True)
|
||||
phase = Phase(id=uuid.uuid4(), norms=[norm], goals=[], triggers=[])
|
||||
prog = Program(phases=[phase])
|
||||
|
||||
code = generator.generate(prog)
|
||||
assert f'critical_norm("safety") :- phase("{phase.id}").' in code
|
||||
|
||||
|
||||
def test_generate_conditional_norm(generator):
|
||||
cond = KeywordBelief(id=uuid.uuid4(), name="k1", keyword="please")
|
||||
norm = ConditionalNorm(id=uuid.uuid4(), name="n1", norm="help", condition=cond)
|
||||
phase = Phase(id=uuid.uuid4(), norms=[norm], goals=[], triggers=[])
|
||||
prog = Program(phases=[phase])
|
||||
|
||||
code = generator.generate(prog)
|
||||
assert 'norm("help")' in code
|
||||
assert 'keyword_said("please")' in code
|
||||
assert f"force_norm_{generator._slugify_str(norm.norm)}" in code
|
||||
|
||||
|
||||
def test_generate_goal_and_plan(generator):
|
||||
action = SpeechAction(id=uuid.uuid4(), name="s1", text="hello")
|
||||
plan = Plan(id=uuid.uuid4(), name="p1", steps=[action])
|
||||
# IMPORTANT: can_fail must be False for +achieved_ belief to be added
|
||||
goal = Goal(id=uuid.uuid4(), name="g1", description="desc", plan=plan, can_fail=False)
|
||||
phase = Phase(id=uuid.uuid4(), norms=[], goals=[goal], triggers=[])
|
||||
prog = Program(phases=[phase])
|
||||
|
||||
code = generator.generate(prog)
|
||||
# Check trigger for goal
|
||||
goal_slug = generator._slugify_str(goal.name)
|
||||
assert f"+!{goal_slug}" in code
|
||||
assert f'phase("{phase.id}")' in code
|
||||
assert '!say("hello")' in code
|
||||
|
||||
# Check success belief addition
|
||||
assert f"+achieved_{goal_slug}" in code
|
||||
|
||||
|
||||
def test_generate_subgoal(generator):
|
||||
subplan = Plan(id=uuid.uuid4(), name="p2", steps=[])
|
||||
subgoal = Goal(id=uuid.uuid4(), name="sub1", description="sub", plan=subplan)
|
||||
|
||||
plan = Plan(id=uuid.uuid4(), name="p1", steps=[subgoal])
|
||||
goal = Goal(id=uuid.uuid4(), name="g1", description="main", plan=plan)
|
||||
phase = Phase(id=uuid.uuid4(), norms=[], goals=[goal], triggers=[])
|
||||
prog = Program(phases=[phase])
|
||||
|
||||
code = generator.generate(prog)
|
||||
subgoal_slug = generator._slugify_str(subgoal.name)
|
||||
# Main goal calls subgoal
|
||||
assert f"!{subgoal_slug}" in code
|
||||
# Subgoal plan exists
|
||||
assert f"+!{subgoal_slug}" in code
|
||||
|
||||
|
||||
def test_generate_trigger(generator):
|
||||
cond = SemanticBelief(id=uuid.uuid4(), name="s1", description="desc")
|
||||
plan = Plan(id=uuid.uuid4(), name="p1", steps=[])
|
||||
trigger = Trigger(id=uuid.uuid4(), name="t1", condition=cond, plan=plan)
|
||||
phase = Phase(id=uuid.uuid4(), norms=[], goals=[], triggers=[trigger])
|
||||
prog = Program(phases=[phase])
|
||||
|
||||
code = generator.generate(prog)
|
||||
# Trigger logic is added to check_triggers
|
||||
assert f"{generator.slugify(cond)}" in code
|
||||
assert f'notify_trigger_start("{generator.slugify(trigger)}")' in code
|
||||
assert f'notify_trigger_end("{generator.slugify(trigger)}")' in code
|
||||
|
||||
|
||||
def test_phase_transition(generator):
|
||||
phase1 = Phase(id=uuid.uuid4(), name="p1", norms=[], goals=[], triggers=[])
|
||||
phase2 = Phase(id=uuid.uuid4(), name="p2", norms=[], goals=[], triggers=[])
|
||||
prog = Program(phases=[phase1, phase2])
|
||||
|
||||
code = generator.generate(prog)
|
||||
assert "transition_phase" in code
|
||||
assert f'phase("{phase1.id}")' in code
|
||||
assert f'phase("{phase2.id}")' in code
|
||||
assert "force_transition_phase" in code
|
||||
|
||||
|
||||
def test_astify_gesture(generator):
|
||||
gesture = Gesture(type="single", name="wave")
|
||||
action = GestureAction(id=uuid.uuid4(), name="g1", gesture=gesture)
|
||||
ast = generator._astify(action)
|
||||
assert str(ast) == 'gesture("single", "wave")'
|
||||
|
||||
|
||||
def test_astify_llm_action(generator):
|
||||
action = LLMAction(id=uuid.uuid4(), name="l1", goal="be funny")
|
||||
ast = generator._astify(action)
|
||||
assert str(ast) == 'reply_with_goal("be funny")'
|
||||
|
||||
|
||||
def test_astify_inferred_belief_and(generator):
|
||||
left = KeywordBelief(id=uuid.uuid4(), name="k1", keyword="a")
|
||||
right = KeywordBelief(id=uuid.uuid4(), name="k2", keyword="b")
|
||||
inf = InferredBelief(
|
||||
id=uuid.uuid4(), name="i1", operator=LogicalOperator.AND, left=left, right=right
|
||||
)
|
||||
|
||||
ast = generator._astify(inf)
|
||||
assert 'keyword_said("a") & keyword_said("b")' == str(ast)
|
||||
|
||||
|
||||
def test_astify_inferred_belief_or(generator):
|
||||
left = KeywordBelief(id=uuid.uuid4(), name="k1", keyword="a")
|
||||
right = KeywordBelief(id=uuid.uuid4(), name="k2", keyword="b")
|
||||
inf = InferredBelief(
|
||||
id=uuid.uuid4(), name="i1", operator=LogicalOperator.OR, left=left, right=right
|
||||
)
|
||||
|
||||
ast = generator._astify(inf)
|
||||
assert 'keyword_said("a") | keyword_said("b")' == str(ast)
|
||||
|
||||
|
||||
def test_astify_semantic_belief(generator):
|
||||
sb = SemanticBelief(id=uuid.uuid4(), name="s1", description="desc")
|
||||
ast = generator._astify(sb)
|
||||
assert str(ast) == f"semantic_{generator._slugify_str(sb.name)}"
|
||||
|
||||
|
||||
def test_slugify_not_implemented(generator):
|
||||
with pytest.raises(NotImplementedError):
|
||||
generator.slugify("not a program element")
|
||||
|
||||
|
||||
def test_astify_not_implemented(generator):
|
||||
with pytest.raises(NotImplementedError):
|
||||
generator._astify("not a program element")
|
||||
|
||||
|
||||
def test_process_phase_transition_from_none(generator):
|
||||
# Initialize AstProgram manually as we are bypassing generate()
|
||||
generator._asp = AstProgram()
|
||||
# Should safely return doing nothing
|
||||
generator._add_phase_transition(None, None)
|
||||
|
||||
assert len(generator._asp.plans) == 0
|
||||
532
test/unit/agents/bdi/test_bdi_core_agent.py
Normal file
532
test/unit/agents/bdi/test_bdi_core_agent.py
Normal file
@@ -0,0 +1,532 @@
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
from unittest.mock import AsyncMock, MagicMock, mock_open, patch
|
||||
|
||||
import agentspeak
|
||||
import pytest
|
||||
|
||||
from control_backend.agents.bdi.bdi_core_agent import BDICoreAgent
|
||||
from control_backend.core.agent_system import InternalMessage
|
||||
from control_backend.core.config import settings
|
||||
from control_backend.schemas.belief_message import Belief, BeliefMessage
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_agentspeak_env():
|
||||
with patch("agentspeak.runtime.Environment") as mock_env:
|
||||
yield mock_env
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def agent():
|
||||
agent = BDICoreAgent("bdi_agent")
|
||||
agent.send = AsyncMock()
|
||||
agent.bdi_agent = MagicMock()
|
||||
return agent
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_experiment_logger():
|
||||
with patch("control_backend.agents.bdi.bdi_core_agent.experiment_logger") as logger:
|
||||
yield logger
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_setup_loads_asl(mock_agentspeak_env, agent):
|
||||
# Mock file opening
|
||||
with patch("builtins.open", mock_open(read_data="+initial_goal.")):
|
||||
await agent.setup()
|
||||
|
||||
# Check if environment tried to build agent
|
||||
mock_agentspeak_env.return_value.build_agent.assert_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_setup_no_asl(mock_agentspeak_env, agent):
|
||||
with patch("builtins.open", side_effect=FileNotFoundError):
|
||||
await agent.setup()
|
||||
|
||||
mock_agentspeak_env.return_value.build_agent.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_belief_message(agent, mock_settings):
|
||||
"""Test that incoming beliefs are added to the BDI agent"""
|
||||
beliefs = [Belief(name="user_said", arguments=["Hello"])]
|
||||
msg = InternalMessage(
|
||||
to="bdi_agent",
|
||||
sender=mock_settings.agent_settings.text_belief_extractor_name,
|
||||
body=BeliefMessage(create=beliefs).model_dump_json(),
|
||||
thread="beliefs",
|
||||
)
|
||||
|
||||
await agent.handle_message(msg)
|
||||
|
||||
# Check for the specific call we expect among all calls
|
||||
# bdi_agent.call is called multiple times (for transition_phase, check_triggers)
|
||||
# We want to confirm the belief addition call exists
|
||||
found_call = False
|
||||
for call in agent.bdi_agent.call.call_args_list:
|
||||
args = call.args
|
||||
if (
|
||||
args[0] == agentspeak.Trigger.addition
|
||||
and args[1] == agentspeak.GoalType.belief
|
||||
and args[2].functor == "user_said"
|
||||
and args[2].args[0].functor == "Hello"
|
||||
):
|
||||
found_call = True
|
||||
break
|
||||
|
||||
assert found_call, "Expected belief addition call not found in bdi_agent.call history"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_delete_belief_message(agent, mock_settings):
|
||||
"""Test that incoming beliefs to be deleted are removed from the BDI agent"""
|
||||
beliefs = [Belief(name="user_said", arguments=["Hello"])]
|
||||
|
||||
msg = InternalMessage(
|
||||
to="bdi_agent",
|
||||
sender=mock_settings.agent_settings.text_belief_extractor_name,
|
||||
body=BeliefMessage(delete=beliefs).model_dump_json(),
|
||||
thread="beliefs",
|
||||
)
|
||||
await agent.handle_message(msg)
|
||||
|
||||
found_call = False
|
||||
for call in agent.bdi_agent.call.call_args_list:
|
||||
args = call.args
|
||||
if (
|
||||
args[0] == agentspeak.Trigger.removal
|
||||
and args[1] == agentspeak.GoalType.belief
|
||||
and args[2].functor == "user_said"
|
||||
and args[2].args[0].functor == "Hello"
|
||||
):
|
||||
found_call = True
|
||||
break
|
||||
|
||||
assert found_call
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_incorrect_belief_message(agent, mock_settings):
|
||||
"""Test that incorrect message format triggers an exception."""
|
||||
msg = InternalMessage(
|
||||
to="bdi_agent",
|
||||
sender=mock_settings.agent_settings.text_belief_extractor_name,
|
||||
body=json.dumps({"bad_format": "bad_format"}),
|
||||
thread="beliefs",
|
||||
)
|
||||
|
||||
await agent.handle_message(msg)
|
||||
|
||||
agent.bdi_agent.call.assert_not_called() # did not set belief
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_llm_response(agent):
|
||||
"""Test that LLM responses are forwarded to the Robot Speech Agent"""
|
||||
msg = InternalMessage(
|
||||
to="bdi_agent", sender=settings.agent_settings.llm_name, body="This is the LLM reply"
|
||||
)
|
||||
|
||||
await agent.handle_message(msg)
|
||||
|
||||
# Verify forward
|
||||
assert agent.send.called
|
||||
sent_msg = agent.send.call_args[0][0]
|
||||
assert sent_msg.to == settings.agent_settings.robot_speech_name
|
||||
assert "This is the LLM reply" in sent_msg.body
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_custom_actions(agent):
|
||||
agent._send_to_llm = MagicMock(side_effect=agent.send) # Mock specific method
|
||||
|
||||
# Initialize actions manually since we didn't call setup with real file
|
||||
agent._add_custom_actions()
|
||||
|
||||
# Find the action
|
||||
action_fn = None
|
||||
for (functor, _), fn in agent.actions.actions.items():
|
||||
if functor == ".reply":
|
||||
action_fn = fn
|
||||
break
|
||||
|
||||
assert action_fn is not None
|
||||
|
||||
# Invoke action
|
||||
mock_term = MagicMock()
|
||||
mock_term.args = ["Hello", "Norm"]
|
||||
mock_intention = MagicMock()
|
||||
|
||||
# Run generator
|
||||
gen = action_fn(agent, mock_term, mock_intention)
|
||||
next(gen) # Execute
|
||||
|
||||
agent._send_to_llm.assert_called_with("Hello", "Norm", "")
|
||||
|
||||
|
||||
def test_add_belief_sets_event(agent):
|
||||
"""Test that a belief triggers wake event and call()"""
|
||||
agent._wake_bdi_loop = MagicMock()
|
||||
|
||||
belief = Belief(name="test_belief", arguments=["a", "b"])
|
||||
belief_changes = BeliefMessage(replace=[belief])
|
||||
agent._apply_belief_changes(belief_changes)
|
||||
|
||||
assert agent.bdi_agent.call.called
|
||||
agent._wake_bdi_loop.set.assert_called()
|
||||
|
||||
|
||||
def test_apply_beliefs_empty_returns(agent):
|
||||
"""Line: if not beliefs: return"""
|
||||
agent._wake_bdi_loop = MagicMock()
|
||||
agent._apply_belief_changes(BeliefMessage())
|
||||
agent.bdi_agent.call.assert_not_called()
|
||||
agent._wake_bdi_loop.set.assert_not_called()
|
||||
|
||||
|
||||
def test_remove_belief_success_wakes_loop(agent):
|
||||
"""Line: if result: wake set"""
|
||||
agent._wake_bdi_loop = MagicMock()
|
||||
agent.bdi_agent.call.return_value = True
|
||||
|
||||
agent._remove_belief("remove_me", ["x"])
|
||||
|
||||
assert agent.bdi_agent.call.called
|
||||
|
||||
call_args = agent.bdi_agent.call.call_args.args
|
||||
trigger = call_args[0]
|
||||
goaltype = call_args[1]
|
||||
literal = call_args[2]
|
||||
|
||||
assert trigger == agentspeak.Trigger.removal
|
||||
assert goaltype == agentspeak.GoalType.belief
|
||||
assert literal.functor == "remove_me"
|
||||
assert literal.args[0].functor == "x"
|
||||
|
||||
agent._wake_bdi_loop.set.assert_called()
|
||||
|
||||
|
||||
def test_remove_belief_failure_does_not_wake(agent):
|
||||
"""Line: else result is False"""
|
||||
agent._wake_bdi_loop = MagicMock()
|
||||
agent.bdi_agent.call.return_value = False
|
||||
|
||||
agent._remove_belief("not_there", ["y"])
|
||||
|
||||
assert agent.bdi_agent.call.called # removal was attempted
|
||||
agent._wake_bdi_loop.set.assert_not_called()
|
||||
|
||||
|
||||
def test_remove_all_with_name_wakes_loop(agent):
|
||||
"""Cover _remove_all_with_name() removed counter + wake"""
|
||||
agent._wake_bdi_loop = MagicMock()
|
||||
|
||||
fake_literal = agentspeak.Literal("delete_me", (agentspeak.Literal("arg1"),))
|
||||
fake_key = ("delete_me", 1)
|
||||
agent.bdi_agent.beliefs = {fake_key: {fake_literal}}
|
||||
|
||||
agent._remove_all_with_name("delete_me")
|
||||
|
||||
assert agent.bdi_agent.call.called
|
||||
agent._wake_bdi_loop.set.assert_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bdi_step_true_branch_hits_line_67(agent):
|
||||
"""Force step() to return True once so line 67 is actually executed"""
|
||||
# counter that isn't tied to MagicMock.call_count ordering
|
||||
counter = {"i": 0}
|
||||
|
||||
def fake_step():
|
||||
counter["i"] += 1
|
||||
return counter["i"] == 1 # True only first time
|
||||
|
||||
# Important: wrap fake_step into another mock so `.called` still exists
|
||||
agent.bdi_agent.step = MagicMock(side_effect=fake_step)
|
||||
agent.bdi_agent.shortest_deadline = MagicMock(return_value=None)
|
||||
|
||||
agent._running = True
|
||||
agent._wake_bdi_loop = asyncio.Event()
|
||||
agent._wake_bdi_loop.set()
|
||||
|
||||
task = asyncio.create_task(agent._bdi_loop())
|
||||
await asyncio.sleep(0.01)
|
||||
task.cancel()
|
||||
try:
|
||||
await task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
assert agent.bdi_agent.step.called
|
||||
assert counter["i"] >= 1 # proves True branch ran
|
||||
|
||||
|
||||
def test_replace_belief_calls_remove_all(agent):
|
||||
"""Cover: if belief.replace: self._remove_all_with_name()"""
|
||||
agent._remove_all_with_name = MagicMock()
|
||||
agent._wake_bdi_loop = MagicMock()
|
||||
|
||||
belief = Belief(name="user_said", arguments=["Hello"])
|
||||
belief_changes = BeliefMessage(replace=[belief])
|
||||
agent._apply_belief_changes(belief_changes)
|
||||
|
||||
agent._remove_all_with_name.assert_called_with("user_said")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_to_llm_creates_prompt_and_sends(agent):
|
||||
"""Cover entire _send_to_llm() including message send and logger.info"""
|
||||
agent.bdi_agent = MagicMock() # ensure mocked BDI does not interfere
|
||||
agent._wake_bdi_loop = MagicMock()
|
||||
|
||||
await agent._send_to_llm("hello world", "n1\nn2", "g1")
|
||||
|
||||
# send() was called
|
||||
assert agent.send.called
|
||||
sent_msg: InternalMessage = agent.send.call_args.args[0]
|
||||
|
||||
# Message routing values correct
|
||||
assert sent_msg.to == settings.agent_settings.llm_name
|
||||
assert "hello world" in sent_msg.body
|
||||
|
||||
# JSON contains split norms/goals
|
||||
body = json.loads(sent_msg.body)
|
||||
assert body["norms"] == ["n1", "n2"]
|
||||
assert body["goals"] == ["g1"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_deadline_sleep_branch(agent):
|
||||
"""Specifically assert the if deadline: sleep → maybe_more_work=True branch"""
|
||||
future_deadline = time.time() + 0.005
|
||||
agent.bdi_agent.step.return_value = False
|
||||
agent.bdi_agent.shortest_deadline.return_value = future_deadline
|
||||
|
||||
start_time = time.time()
|
||||
agent._running = True
|
||||
agent._wake_bdi_loop = asyncio.Event()
|
||||
agent._wake_bdi_loop.set()
|
||||
|
||||
task = asyncio.create_task(agent._bdi_loop())
|
||||
await asyncio.sleep(0.01)
|
||||
task.cancel()
|
||||
|
||||
duration = time.time() - start_time
|
||||
assert duration >= 0.004 # loop slept until deadline
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_new_program(agent):
|
||||
agent._load_asl = AsyncMock()
|
||||
agent.add_behavior = MagicMock()
|
||||
# Mock existing loop task so it can be cancelled
|
||||
mock_task = MagicMock()
|
||||
mock_task.cancel = MagicMock()
|
||||
agent._bdi_loop_task = mock_task
|
||||
|
||||
def close_coro(coro):
|
||||
coro.close()
|
||||
return MagicMock()
|
||||
|
||||
agent.add_behavior = MagicMock(side_effect=close_coro)
|
||||
|
||||
msg = InternalMessage(to="bdi_agent", thread="new_program", body="path/to/asl.asl")
|
||||
|
||||
await agent.handle_message(msg)
|
||||
|
||||
mock_task.cancel.assert_called_once()
|
||||
agent._load_asl.assert_awaited_once_with("path/to/asl.asl")
|
||||
agent.add_behavior.assert_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_user_interrupts(agent, mock_settings):
|
||||
mock_settings.agent_settings.user_interrupt_name = "user_interrupt_agent"
|
||||
# force_phase_transition
|
||||
agent._set_goal = MagicMock()
|
||||
msg = InternalMessage(
|
||||
to="bdi_agent",
|
||||
sender=mock_settings.agent_settings.user_interrupt_name,
|
||||
thread="force_phase_transition",
|
||||
body="",
|
||||
)
|
||||
await agent.handle_message(msg)
|
||||
agent._set_goal.assert_called_with("transition_phase")
|
||||
|
||||
# force_trigger
|
||||
agent._force_trigger = MagicMock()
|
||||
msg.thread = "force_trigger"
|
||||
msg.body = "trigger_x"
|
||||
await agent.handle_message(msg)
|
||||
agent._force_trigger.assert_called_with("trigger_x")
|
||||
|
||||
# force_norm
|
||||
agent._force_norm = MagicMock()
|
||||
msg.thread = "force_norm"
|
||||
msg.body = "norm_y"
|
||||
await agent.handle_message(msg)
|
||||
agent._force_norm.assert_called_with("norm_y")
|
||||
|
||||
# force_next_phase
|
||||
agent._force_next_phase = MagicMock()
|
||||
msg.thread = "force_next_phase"
|
||||
msg.body = ""
|
||||
await agent.handle_message(msg)
|
||||
agent._force_next_phase.assert_called_once()
|
||||
|
||||
# unknown interrupt
|
||||
agent.logger = MagicMock()
|
||||
msg.thread = "unknown_thing"
|
||||
await agent.handle_message(msg)
|
||||
agent.logger.warning.assert_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_custom_action_reply_with_goal(agent):
|
||||
agent._send_to_llm = MagicMock(side_effect=agent.send)
|
||||
agent._add_custom_actions()
|
||||
action_fn = agent.actions.actions[(".reply_with_goal", 3)]
|
||||
|
||||
mock_term = MagicMock(args=["msg", "norms", "goal"])
|
||||
gen = action_fn(agent, mock_term, MagicMock())
|
||||
next(gen)
|
||||
agent._send_to_llm.assert_called_with("msg", "norms", "goal")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_custom_action_notify_norms(agent):
|
||||
agent._add_custom_actions()
|
||||
action_fn = agent.actions.actions[(".notify_norms", 1)]
|
||||
|
||||
mock_term = MagicMock(args=["norms_list"])
|
||||
gen = action_fn(agent, mock_term, MagicMock())
|
||||
next(gen)
|
||||
|
||||
agent.send.assert_called()
|
||||
msg = agent.send.call_args[0][0]
|
||||
assert msg.thread == "active_norms_update"
|
||||
assert msg.body == "norms_list"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_custom_action_say(agent):
|
||||
agent._add_custom_actions()
|
||||
action_fn = agent.actions.actions[(".say", 1)]
|
||||
|
||||
mock_term = MagicMock(args=["hello"])
|
||||
gen = action_fn(agent, mock_term, MagicMock())
|
||||
next(gen)
|
||||
|
||||
assert agent.send.call_count == 2
|
||||
msgs = [c[0][0] for c in agent.send.call_args_list]
|
||||
assert any(m.to == settings.agent_settings.robot_speech_name for m in msgs)
|
||||
assert any(
|
||||
m.to == settings.agent_settings.llm_name and m.thread == "assistant_message" for m in msgs
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_custom_action_gesture(agent):
|
||||
agent._add_custom_actions()
|
||||
# Test single
|
||||
action_fn = agent.actions.actions[(".gesture", 2)]
|
||||
mock_term = MagicMock(args=["single", "wave"])
|
||||
gen = action_fn(agent, mock_term, MagicMock())
|
||||
next(gen)
|
||||
msg = agent.send.call_args[0][0]
|
||||
assert "actuate/gesture/single" in msg.body
|
||||
|
||||
# Test tag
|
||||
mock_term.args = ["tag", "happy"]
|
||||
gen = action_fn(agent, mock_term, MagicMock())
|
||||
next(gen)
|
||||
msg = agent.send.call_args[0][0]
|
||||
assert "actuate/gesture/tag" in msg.body
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_custom_action_notify_user_said(agent):
|
||||
agent._add_custom_actions()
|
||||
action_fn = agent.actions.actions[(".notify_user_said", 1)]
|
||||
mock_term = MagicMock(args=["hello"])
|
||||
gen = action_fn(agent, mock_term, MagicMock())
|
||||
next(gen)
|
||||
msg = agent.send.call_args[0][0]
|
||||
assert msg.to == settings.agent_settings.llm_name
|
||||
assert msg.thread == "user_message"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_custom_action_notify_trigger_start_end(agent):
|
||||
agent._add_custom_actions()
|
||||
# Start
|
||||
action_fn = agent.actions.actions[(".notify_trigger_start", 1)]
|
||||
gen = action_fn(agent, MagicMock(args=["t1"]), MagicMock())
|
||||
next(gen)
|
||||
assert agent.send.call_args[0][0].thread == "trigger_start"
|
||||
|
||||
# End
|
||||
action_fn = agent.actions.actions[(".notify_trigger_end", 1)]
|
||||
gen = action_fn(agent, MagicMock(args=["t1"]), MagicMock())
|
||||
next(gen)
|
||||
assert agent.send.call_args[0][0].thread == "trigger_end"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_custom_action_notify_goal_start(agent):
|
||||
agent._add_custom_actions()
|
||||
action_fn = agent.actions.actions[(".notify_goal_start", 1)]
|
||||
gen = action_fn(agent, MagicMock(args=["g1"]), MagicMock())
|
||||
next(gen)
|
||||
assert agent.send.call_args[0][0].thread == "goal_start"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_custom_action_notify_transition_phase(agent):
|
||||
agent._add_custom_actions()
|
||||
action_fn = agent.actions.actions[(".notify_transition_phase", 2)]
|
||||
gen = action_fn(agent, MagicMock(args=["old", "new"]), MagicMock())
|
||||
next(gen)
|
||||
msg = agent.send.call_args[0][0]
|
||||
assert msg.thread == "transition_phase"
|
||||
assert "old" in msg.body and "new" in msg.body
|
||||
|
||||
|
||||
def test_remove_belief_no_args(agent):
|
||||
agent._wake_bdi_loop = MagicMock()
|
||||
agent.bdi_agent.call.return_value = True
|
||||
agent._remove_belief("fact", None)
|
||||
assert agent.bdi_agent.call.called
|
||||
|
||||
|
||||
def test_set_goal_with_args(agent):
|
||||
agent._wake_bdi_loop = MagicMock()
|
||||
agent._set_goal("goal", ["arg1", "arg2"])
|
||||
assert agent.bdi_agent.call.called
|
||||
|
||||
|
||||
def test_format_belief_string():
|
||||
assert BDICoreAgent.format_belief_string("b") == "b"
|
||||
assert BDICoreAgent.format_belief_string("b", ["a1", "a2"]) == "b(a1,a2)"
|
||||
|
||||
|
||||
def test_force_norm(agent):
|
||||
agent._add_belief = MagicMock()
|
||||
agent._force_norm("be_polite")
|
||||
agent._add_belief.assert_called_with("force_be_polite")
|
||||
|
||||
|
||||
def test_force_trigger(agent):
|
||||
agent._set_goal = MagicMock()
|
||||
agent._force_trigger("trig")
|
||||
agent._set_goal.assert_called_with("trig")
|
||||
|
||||
|
||||
def test_force_next_phase(agent):
|
||||
agent._set_goal = MagicMock()
|
||||
agent._force_next_phase()
|
||||
agent._set_goal.assert_called_with("force_transition_phase")
|
||||
402
test/unit/agents/bdi/test_bdi_program_manager.py
Normal file
402
test/unit/agents/bdi/test_bdi_program_manager.py
Normal file
@@ -0,0 +1,402 @@
|
||||
import asyncio
|
||||
import json
|
||||
import sys
|
||||
import uuid
|
||||
from unittest.mock import AsyncMock, MagicMock, mock_open, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from control_backend.agents.bdi.bdi_program_manager import BDIProgramManager
|
||||
from control_backend.core.agent_system import InternalMessage
|
||||
from control_backend.schemas.program import (
|
||||
BasicNorm,
|
||||
ConditionalNorm,
|
||||
Goal,
|
||||
InferredBelief,
|
||||
KeywordBelief,
|
||||
Phase,
|
||||
Plan,
|
||||
Program,
|
||||
Trigger,
|
||||
)
|
||||
|
||||
# Fix Windows Proactor loop for zmq
|
||||
if sys.platform.startswith("win"):
|
||||
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
|
||||
|
||||
|
||||
def make_valid_program_json(norm="N1", goal="G1") -> str:
|
||||
return Program(
|
||||
phases=[
|
||||
Phase(
|
||||
id=uuid.uuid4(),
|
||||
name="Basic Phase",
|
||||
norms=[
|
||||
BasicNorm(
|
||||
id=uuid.uuid4(),
|
||||
name=norm,
|
||||
norm=norm,
|
||||
),
|
||||
],
|
||||
goals=[
|
||||
Goal(
|
||||
id=uuid.uuid4(),
|
||||
name=goal,
|
||||
description="This description can be used to determine whether the goal "
|
||||
"has been achieved.",
|
||||
plan=Plan(
|
||||
id=uuid.uuid4(),
|
||||
name="Goal Plan",
|
||||
steps=[],
|
||||
),
|
||||
can_fail=False,
|
||||
),
|
||||
],
|
||||
triggers=[],
|
||||
),
|
||||
],
|
||||
).model_dump_json()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_agentspeak_and_send_to_bdi(mock_settings):
|
||||
manager = BDIProgramManager(name="program_manager_test")
|
||||
manager.send = AsyncMock()
|
||||
|
||||
program = Program.model_validate_json(make_valid_program_json())
|
||||
|
||||
with patch("builtins.open", mock_open()) as mock_file:
|
||||
await manager._create_agentspeak_and_send_to_bdi(program)
|
||||
|
||||
# Check file writing
|
||||
mock_file.assert_called_with(mock_settings.behaviour_settings.agentspeak_file, "w")
|
||||
handle = mock_file()
|
||||
handle.write.assert_called()
|
||||
|
||||
assert manager.send.await_count == 1
|
||||
msg: InternalMessage = manager.send.await_args[0][0]
|
||||
assert msg.thread == "new_program"
|
||||
assert msg.to == mock_settings.agent_settings.bdi_core_name
|
||||
assert msg.body == mock_settings.behaviour_settings.agentspeak_file
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_receive_programs_valid_and_invalid():
|
||||
sub = AsyncMock()
|
||||
sub.recv_multipart.side_effect = [
|
||||
(b"program", b"{bad json"),
|
||||
(b"program", make_valid_program_json().encode()),
|
||||
]
|
||||
|
||||
manager = BDIProgramManager(name="program_manager_test")
|
||||
manager._internal_pub_socket = AsyncMock()
|
||||
manager.sub_socket = sub
|
||||
manager._create_agentspeak_and_send_to_bdi = AsyncMock()
|
||||
manager._send_clear_llm_history = AsyncMock()
|
||||
manager._send_program_to_user_interrupt = AsyncMock()
|
||||
manager._send_beliefs_to_semantic_belief_extractor = AsyncMock()
|
||||
manager._send_goals_to_semantic_belief_extractor = AsyncMock()
|
||||
|
||||
try:
|
||||
# Will give StopAsyncIteration when the predefined `sub.recv_multipart` side-effects run out
|
||||
await manager._receive_programs()
|
||||
except StopAsyncIteration:
|
||||
pass
|
||||
|
||||
# Only valid Program should have triggered _send_to_bdi
|
||||
assert manager._create_agentspeak_and_send_to_bdi.await_count == 1
|
||||
forwarded: Program = manager._create_agentspeak_and_send_to_bdi.await_args[0][0]
|
||||
assert forwarded.phases[0].norms[0].name == "N1"
|
||||
assert forwarded.phases[0].goals[0].name == "G1"
|
||||
|
||||
# Verify history clear was triggered exactly once (for the valid program)
|
||||
# The invalid program loop `continue`s before calling _send_clear_llm_history
|
||||
assert manager._send_clear_llm_history.await_count == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_clear_llm_history(mock_settings):
|
||||
# Ensure the mock returns a string for the agent name (just like in your LLM tests)
|
||||
mock_settings.agent_settings.llm_agent_name = "llm_agent"
|
||||
|
||||
manager = BDIProgramManager(name="program_manager_test")
|
||||
manager.send = AsyncMock()
|
||||
|
||||
await manager._send_clear_llm_history()
|
||||
|
||||
assert manager.send.await_count == 2
|
||||
msg: InternalMessage = manager.send.await_args_list[0][0][0]
|
||||
|
||||
# Verify the content and recipient
|
||||
assert msg.body == "clear_history"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_message_transition_phase(mock_settings):
|
||||
mock_settings.agent_settings.user_interrupt_name = "user_interrupt_agent"
|
||||
manager = BDIProgramManager(name="program_manager_test")
|
||||
manager.send = AsyncMock()
|
||||
|
||||
# Setup state
|
||||
prog = Program.model_validate_json(make_valid_program_json(norm="N1", goal="G1"))
|
||||
manager._initialize_internal_state(prog)
|
||||
|
||||
# Test valid transition (to same phase for simplicity, or we need 2 phases)
|
||||
# Let's create a program with 2 phases
|
||||
phase2_id = uuid.uuid4()
|
||||
phase2 = Phase(id=phase2_id, name="Phase 2", norms=[], goals=[], triggers=[])
|
||||
prog.phases.append(phase2)
|
||||
manager._initialize_internal_state(prog)
|
||||
|
||||
current_phase_id = str(prog.phases[0].id)
|
||||
next_phase_id = str(phase2_id)
|
||||
|
||||
payload = json.dumps({"old": current_phase_id, "new": next_phase_id})
|
||||
msg = InternalMessage(to="me", sender="bdi", body=payload, thread="transition_phase")
|
||||
|
||||
await manager.handle_message(msg)
|
||||
|
||||
assert str(manager._phase.id) == next_phase_id
|
||||
|
||||
# Allow background tasks to run (add_behavior)
|
||||
await asyncio.sleep(0)
|
||||
|
||||
# Check notifications sent
|
||||
# 1. beliefs to extractor
|
||||
# 2. goals to extractor
|
||||
# 3. notification to user interrupt
|
||||
|
||||
assert manager.send.await_count >= 3
|
||||
|
||||
# Verify user interrupt notification
|
||||
calls = manager.send.await_args_list
|
||||
ui_msgs = [
|
||||
c[0][0] for c in calls if c[0][0].to == mock_settings.agent_settings.user_interrupt_name
|
||||
]
|
||||
assert len(ui_msgs) > 0
|
||||
assert ui_msgs[-1].body == next_phase_id
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_message_transition_phase_desync():
|
||||
manager = BDIProgramManager(name="program_manager_test")
|
||||
manager.logger = MagicMock()
|
||||
|
||||
prog = Program.model_validate_json(make_valid_program_json())
|
||||
manager._initialize_internal_state(prog)
|
||||
|
||||
current_phase_id = str(prog.phases[0].id)
|
||||
|
||||
# Request transition from WRONG old phase
|
||||
payload = json.dumps({"old": "wrong_id", "new": "some_new_id"})
|
||||
msg = InternalMessage(to="me", sender="bdi", body=payload, thread="transition_phase")
|
||||
|
||||
await manager.handle_message(msg)
|
||||
|
||||
# Should warn and do nothing
|
||||
manager.logger.warning.assert_called_once()
|
||||
assert "Phase transition desync detected" in manager.logger.warning.call_args[0][0]
|
||||
assert str(manager._phase.id) == current_phase_id
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_message_transition_phase_end(mock_settings):
|
||||
mock_settings.agent_settings.user_interrupt_name = "user_interrupt_agent"
|
||||
manager = BDIProgramManager(name="program_manager_test")
|
||||
manager.send = AsyncMock()
|
||||
|
||||
prog = Program.model_validate_json(make_valid_program_json())
|
||||
manager._initialize_internal_state(prog)
|
||||
current_phase_id = str(prog.phases[0].id)
|
||||
|
||||
payload = json.dumps({"old": current_phase_id, "new": "end"})
|
||||
msg = InternalMessage(to="me", sender="bdi", body=payload, thread="transition_phase")
|
||||
|
||||
await manager.handle_message(msg)
|
||||
|
||||
assert manager._phase is None
|
||||
|
||||
# Allow background tasks to run (add_behavior)
|
||||
await asyncio.sleep(0)
|
||||
|
||||
# Verify notification to user interrupt
|
||||
assert manager.send.await_count == 1
|
||||
msg_sent = manager.send.await_args[0][0]
|
||||
assert msg_sent.to == mock_settings.agent_settings.user_interrupt_name
|
||||
assert msg_sent.body == "end"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_message_achieve_goal(mock_settings):
|
||||
mock_settings.agent_settings.text_belief_extractor_name = "text_belief_extractor_agent"
|
||||
manager = BDIProgramManager(name="program_manager_test")
|
||||
manager.send = AsyncMock()
|
||||
|
||||
prog = Program.model_validate_json(make_valid_program_json(goal="TargetGoal"))
|
||||
manager._initialize_internal_state(prog)
|
||||
|
||||
goal_id = str(prog.phases[0].goals[0].id)
|
||||
|
||||
msg = InternalMessage(to="me", sender="ui", body=goal_id, thread="achieve_goal")
|
||||
|
||||
await manager.handle_message(msg)
|
||||
|
||||
# Should send achieved goals to text extractor
|
||||
assert manager.send.await_count == 1
|
||||
msg_sent = manager.send.await_args[0][0]
|
||||
assert msg_sent.to == mock_settings.agent_settings.text_belief_extractor_name
|
||||
assert msg_sent.thread == "achieved_goals"
|
||||
|
||||
# Verify body
|
||||
from control_backend.schemas.belief_list import GoalList
|
||||
|
||||
gl = GoalList.model_validate_json(msg_sent.body)
|
||||
assert len(gl.goals) == 1
|
||||
assert gl.goals[0].name == "TargetGoal"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_message_achieve_goal_not_found():
|
||||
manager = BDIProgramManager(name="program_manager_test")
|
||||
manager.send = AsyncMock()
|
||||
manager.logger = MagicMock()
|
||||
|
||||
prog = Program.model_validate_json(make_valid_program_json())
|
||||
manager._initialize_internal_state(prog)
|
||||
|
||||
msg = InternalMessage(to="me", sender="ui", body="non_existent_id", thread="achieve_goal")
|
||||
|
||||
await manager.handle_message(msg)
|
||||
|
||||
manager.send.assert_not_called()
|
||||
manager.logger.debug.assert_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_setup(mock_settings):
|
||||
manager = BDIProgramManager(name="program_manager_test")
|
||||
manager.send = AsyncMock()
|
||||
|
||||
def close_coro(coro):
|
||||
coro.close()
|
||||
return MagicMock()
|
||||
|
||||
manager.add_behavior = MagicMock(side_effect=close_coro)
|
||||
|
||||
mock_context = MagicMock()
|
||||
mock_sub = MagicMock()
|
||||
mock_context.socket.return_value = mock_sub
|
||||
|
||||
with patch(
|
||||
"control_backend.agents.bdi.bdi_program_manager.Context.instance", return_value=mock_context
|
||||
):
|
||||
# We also need to mock file writing in _create_agentspeak_and_send_to_bdi
|
||||
with patch("builtins.open", new_callable=MagicMock):
|
||||
await manager.setup()
|
||||
|
||||
# Check logic
|
||||
# 1. Sends default empty program to BDI
|
||||
assert manager.send.await_count == 1
|
||||
assert manager.send.await_args[0][0].to == mock_settings.agent_settings.bdi_core_name
|
||||
|
||||
# 2. Connects SUB socket
|
||||
mock_sub.connect.assert_called_with(mock_settings.zmq_settings.internal_sub_address)
|
||||
mock_sub.subscribe.assert_called_with("program")
|
||||
|
||||
# 3. Adds behavior
|
||||
manager.add_behavior.assert_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_program_to_user_interrupt(mock_settings):
|
||||
"""Test directly sending the program to the user interrupt agent."""
|
||||
mock_settings.agent_settings.user_interrupt_name = "user_interrupt_agent"
|
||||
|
||||
manager = BDIProgramManager(name="program_manager_test")
|
||||
manager.send = AsyncMock()
|
||||
|
||||
program = Program.model_validate_json(make_valid_program_json())
|
||||
|
||||
await manager._send_program_to_user_interrupt(program)
|
||||
|
||||
assert manager.send.await_count == 1
|
||||
msg = manager.send.await_args[0][0]
|
||||
assert msg.to == "user_interrupt_agent"
|
||||
assert msg.thread == "new_program"
|
||||
assert "Basic Phase" in msg.body
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_complex_program_extraction():
|
||||
manager = BDIProgramManager(name="program_manager_test")
|
||||
|
||||
# 1. Create Complex Components
|
||||
|
||||
# Inferred Belief (A & B)
|
||||
belief_left = KeywordBelief(id=uuid.uuid4(), name="b1", keyword="hot")
|
||||
belief_right = KeywordBelief(id=uuid.uuid4(), name="b2", keyword="sunny")
|
||||
inferred_belief = InferredBelief(
|
||||
id=uuid.uuid4(), name="b_inf", operator="AND", left=belief_left, right=belief_right
|
||||
)
|
||||
|
||||
# Conditional Norm
|
||||
cond_norm = ConditionalNorm(
|
||||
id=uuid.uuid4(), name="norm_cond", norm="wear_hat", condition=inferred_belief
|
||||
)
|
||||
|
||||
# Trigger with Inferred Belief condition
|
||||
dummy_plan = Plan(id=uuid.uuid4(), name="dummy_plan", steps=[])
|
||||
trigger = Trigger(id=uuid.uuid4(), name="trigger_1", condition=inferred_belief, plan=dummy_plan)
|
||||
|
||||
# Nested Goal
|
||||
sub_goal = Goal(
|
||||
id=uuid.uuid4(),
|
||||
name="sub_goal",
|
||||
description="desc",
|
||||
plan=Plan(id=uuid.uuid4(), name="empty", steps=[]),
|
||||
can_fail=True,
|
||||
)
|
||||
|
||||
parent_goal = Goal(
|
||||
id=uuid.uuid4(),
|
||||
name="parent_goal",
|
||||
description="desc",
|
||||
# The plan contains the sub_goal as a step
|
||||
plan=Plan(id=uuid.uuid4(), name="parent_plan", steps=[sub_goal]),
|
||||
can_fail=False,
|
||||
)
|
||||
|
||||
# 2. Assemble Program
|
||||
phase = Phase(
|
||||
id=uuid.uuid4(),
|
||||
name="Complex Phase",
|
||||
norms=[cond_norm],
|
||||
goals=[parent_goal],
|
||||
triggers=[trigger],
|
||||
)
|
||||
program = Program(phases=[phase])
|
||||
|
||||
# 3. Initialize Internal State (Triggers _populate_goal_mapping -> Nested Goal logic)
|
||||
manager._initialize_internal_state(program)
|
||||
|
||||
# Assertion for Line 53-54 (Mapping population)
|
||||
# Both parent and sub-goal should be mapped
|
||||
assert str(parent_goal.id) in manager._goal_mapping
|
||||
assert str(sub_goal.id) in manager._goal_mapping
|
||||
|
||||
# 4. Test Belief Extraction (Triggers lines 132-133, 142-146)
|
||||
beliefs = manager._extract_current_beliefs()
|
||||
|
||||
# Should extract recursive beliefs from cond_norm and trigger
|
||||
# Inferred belief splits into Left + Right. Since we use it twice, we get duplicates
|
||||
# checking existence is enough.
|
||||
belief_names = [b.name for b in beliefs]
|
||||
assert "b1" in belief_names
|
||||
assert "b2" in belief_names
|
||||
|
||||
# 5. Test Goal Extraction (Triggers lines 173, 185)
|
||||
goals = manager._extract_current_goals()
|
||||
|
||||
goal_names = [g.name for g in goals]
|
||||
assert "parent_goal" in goal_names
|
||||
assert "sub_goal" in goal_names
|
||||
554
test/unit/agents/bdi/test_text_belief_extractor.py
Normal file
554
test/unit/agents/bdi/test_text_belief_extractor.py
Normal file
@@ -0,0 +1,554 @@
|
||||
import json
|
||||
import uuid
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from control_backend.agents.bdi import TextBeliefExtractorAgent
|
||||
from control_backend.agents.bdi.text_belief_extractor_agent import BeliefState
|
||||
from control_backend.core.agent_system import InternalMessage
|
||||
from control_backend.core.config import settings
|
||||
from control_backend.schemas.belief_list import BeliefList
|
||||
from control_backend.schemas.belief_message import Belief as InternalBelief
|
||||
from control_backend.schemas.belief_message import BeliefMessage
|
||||
from control_backend.schemas.chat_history import ChatHistory, ChatMessage
|
||||
from control_backend.schemas.program import (
|
||||
BaseGoal, # Changed from Goal
|
||||
ConditionalNorm,
|
||||
KeywordBelief,
|
||||
LLMAction,
|
||||
Phase,
|
||||
Plan,
|
||||
Program,
|
||||
SemanticBelief,
|
||||
Trigger,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def llm():
|
||||
llm = TextBeliefExtractorAgent.LLM(MagicMock(), 4)
|
||||
# We must ensure _query_llm returns a dictionary so iterating it doesn't fail
|
||||
llm._query_llm = AsyncMock(return_value={})
|
||||
return llm
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def agent(llm):
|
||||
with patch(
|
||||
"control_backend.agents.bdi.text_belief_extractor_agent.TextBeliefExtractorAgent.LLM",
|
||||
return_value=llm,
|
||||
):
|
||||
agent = TextBeliefExtractorAgent("text_belief_agent")
|
||||
agent.send = AsyncMock()
|
||||
return agent
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_program():
|
||||
return Program(
|
||||
phases=[
|
||||
Phase(
|
||||
name="Some phase",
|
||||
id=uuid.uuid4(),
|
||||
norms=[
|
||||
ConditionalNorm(
|
||||
name="Some norm",
|
||||
id=uuid.uuid4(),
|
||||
norm="Use nautical terms.",
|
||||
critical=False,
|
||||
condition=SemanticBelief(
|
||||
name="is_pirate",
|
||||
id=uuid.uuid4(),
|
||||
description="The user is a pirate. Perhaps because they say "
|
||||
"they are, or because they speak like a pirate "
|
||||
'with terms like "arr".',
|
||||
),
|
||||
),
|
||||
],
|
||||
goals=[],
|
||||
triggers=[
|
||||
Trigger(
|
||||
name="Some trigger",
|
||||
id=uuid.uuid4(),
|
||||
condition=SemanticBelief(
|
||||
name="no_more_booze",
|
||||
id=uuid.uuid4(),
|
||||
description="There is no more alcohol.",
|
||||
),
|
||||
plan=Plan(
|
||||
name="Some plan",
|
||||
id=uuid.uuid4(),
|
||||
steps=[
|
||||
LLMAction(
|
||||
name="Some action",
|
||||
id=uuid.uuid4(),
|
||||
goal="Suggest eating chocolate instead.",
|
||||
),
|
||||
],
|
||||
),
|
||||
),
|
||||
],
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def make_msg(sender: str, body: str, thread: str | None = None) -> InternalMessage:
|
||||
return InternalMessage(to="unused", sender=sender, body=body, thread=thread)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_message_ignores_other_agents(agent):
|
||||
msg = make_msg("unknown", "some data", None)
|
||||
|
||||
await agent.handle_message(msg)
|
||||
|
||||
agent.send.assert_not_called() # noqa # `agent.send` has no such property, but we mock it.
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_message_from_transcriber(agent, mock_settings):
|
||||
transcription = "hello world"
|
||||
msg = make_msg(mock_settings.agent_settings.transcription_name, transcription, None)
|
||||
|
||||
await agent.handle_message(msg)
|
||||
|
||||
agent.send.assert_awaited_once() # noqa # `agent.send` has no such property, but we mock it.
|
||||
sent: InternalMessage = agent.send.call_args.args[0] # noqa
|
||||
assert sent.to == mock_settings.agent_settings.bdi_core_name
|
||||
assert sent.thread == "beliefs"
|
||||
parsed = BeliefMessage.model_validate_json(sent.body)
|
||||
replaced_last = parsed.replace.pop()
|
||||
assert replaced_last.name == "user_said"
|
||||
assert replaced_last.arguments == [transcription]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_llm():
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {
|
||||
"choices": [
|
||||
{
|
||||
"message": {
|
||||
"content": "null",
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
mock_client = AsyncMock()
|
||||
mock_client.post.return_value = mock_response
|
||||
mock_async_client = MagicMock()
|
||||
mock_async_client.__aenter__.return_value = mock_client
|
||||
mock_async_client.__aexit__.return_value = None
|
||||
|
||||
with patch(
|
||||
"control_backend.agents.bdi.text_belief_extractor_agent.httpx.AsyncClient",
|
||||
return_value=mock_async_client,
|
||||
):
|
||||
llm = TextBeliefExtractorAgent.LLM(MagicMock(), 4)
|
||||
|
||||
res = await llm._query_llm("hello world", {"type": "null"})
|
||||
# Response content was set as "null", so should be deserialized as None
|
||||
assert res is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retry_query_llm_success(llm):
|
||||
llm._query_llm.return_value = None
|
||||
res = await llm.query("hello world", {"type": "null"})
|
||||
|
||||
llm._query_llm.assert_called_once()
|
||||
assert res is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retry_query_llm_success_after_failure(llm):
|
||||
llm._query_llm.side_effect = [KeyError(), "real value"]
|
||||
res = await llm.query("hello world", {"type": "string"})
|
||||
|
||||
assert llm._query_llm.call_count == 2
|
||||
assert res == "real value"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retry_query_llm_failures(llm):
|
||||
llm._query_llm.side_effect = [KeyError(), KeyError(), KeyError(), "real value"]
|
||||
res = await llm.query("hello world", {"type": "string"})
|
||||
|
||||
assert llm._query_llm.call_count == 3
|
||||
assert res is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retry_query_llm_fail_immediately(llm):
|
||||
llm._query_llm.side_effect = [KeyError(), "real value"]
|
||||
res = await llm.query("hello world", {"type": "string"}, tries=1)
|
||||
|
||||
assert llm._query_llm.call_count == 1
|
||||
assert res is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extracting_semantic_beliefs(agent):
|
||||
"""
|
||||
The Program Manager sends beliefs to this agent. Test whether the agent handles them correctly.
|
||||
"""
|
||||
assert len(agent.belief_inferrer.available_beliefs) == 0
|
||||
beliefs = BeliefList(
|
||||
beliefs=[
|
||||
KeywordBelief(
|
||||
id=uuid.uuid4(),
|
||||
name="keyword_hello",
|
||||
keyword="hello",
|
||||
),
|
||||
SemanticBelief(
|
||||
id=uuid.uuid4(), name="semantic_hello_1", description="Some semantic belief 1"
|
||||
),
|
||||
SemanticBelief(
|
||||
id=uuid.uuid4(), name="semantic_hello_2", description="Some semantic belief 2"
|
||||
),
|
||||
]
|
||||
)
|
||||
await agent.handle_message(
|
||||
InternalMessage(
|
||||
to=settings.agent_settings.text_belief_extractor_name,
|
||||
sender=settings.agent_settings.bdi_program_manager_name,
|
||||
body=beliefs.model_dump_json(),
|
||||
thread="beliefs",
|
||||
),
|
||||
)
|
||||
assert len(agent.belief_inferrer.available_beliefs) == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_invalid_beliefs(agent, sample_program):
|
||||
agent.belief_inferrer.available_beliefs.append(sample_program.phases[0].norms[0].condition)
|
||||
agent.belief_inferrer.available_beliefs.append(sample_program.phases[0].triggers[0].condition)
|
||||
assert len(agent.belief_inferrer.available_beliefs) == 2
|
||||
|
||||
await agent.handle_message(
|
||||
InternalMessage(
|
||||
to=settings.agent_settings.text_belief_extractor_name,
|
||||
sender=settings.agent_settings.bdi_program_manager_name,
|
||||
body=json.dumps({"phases": "Invalid"}),
|
||||
thread="beliefs",
|
||||
),
|
||||
)
|
||||
|
||||
assert len(agent.belief_inferrer.available_beliefs) == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_robot_response(agent):
|
||||
initial_length = len(agent.conversation.messages)
|
||||
response = "Hi, I'm Pepper. What's your name?"
|
||||
|
||||
await agent.handle_message(
|
||||
InternalMessage(
|
||||
to=settings.agent_settings.text_belief_extractor_name,
|
||||
sender=settings.agent_settings.llm_name,
|
||||
body=response,
|
||||
),
|
||||
)
|
||||
|
||||
assert len(agent.conversation.messages) == initial_length + 1
|
||||
assert agent.conversation.messages[-1].role == "assistant"
|
||||
assert agent.conversation.messages[-1].content == response
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_simulated_real_turn_with_beliefs(agent, llm, sample_program):
|
||||
"""Test sending user message to extract beliefs from."""
|
||||
agent.belief_inferrer.available_beliefs.append(sample_program.phases[0].norms[0].condition)
|
||||
agent.belief_inferrer.available_beliefs.append(sample_program.phases[0].triggers[0].condition)
|
||||
|
||||
# Send a user message with the belief that there's no more booze
|
||||
llm._query_llm.return_value = {"is_pirate": None, "no_more_booze": True}
|
||||
assert len(agent.conversation.messages) == 0
|
||||
await agent.handle_message(
|
||||
InternalMessage(
|
||||
to=settings.agent_settings.text_belief_extractor_name,
|
||||
sender=settings.agent_settings.transcription_name,
|
||||
body="We're all out of schnaps.",
|
||||
),
|
||||
)
|
||||
assert len(agent.conversation.messages) == 1
|
||||
|
||||
# There should be a belief set and sent to the BDI core, as well as the user_said belief
|
||||
assert agent.send.call_count == 2
|
||||
|
||||
# First should be the beliefs message
|
||||
message: InternalMessage = agent.send.call_args_list[1].args[0]
|
||||
beliefs = BeliefMessage.model_validate_json(message.body)
|
||||
assert len(beliefs.create) == 1
|
||||
assert beliefs.create[0].name == "no_more_booze"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_simulated_real_turn_no_beliefs(agent, llm, sample_program):
|
||||
"""Test a user message to extract beliefs from, but no beliefs are formed."""
|
||||
agent.belief_inferrer.available_beliefs.append(sample_program.phases[0].norms[0].condition)
|
||||
agent.belief_inferrer.available_beliefs.append(sample_program.phases[0].triggers[0].condition)
|
||||
|
||||
# Send a user message with no new beliefs
|
||||
llm._query_llm.return_value = {"is_pirate": None, "no_more_booze": None}
|
||||
await agent.handle_message(
|
||||
InternalMessage(
|
||||
to=settings.agent_settings.text_belief_extractor_name,
|
||||
sender=settings.agent_settings.transcription_name,
|
||||
body="Hello there!",
|
||||
),
|
||||
)
|
||||
|
||||
# Only the user_said belief should've been sent
|
||||
agent.send.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_simulated_real_turn_no_new_beliefs(agent, llm, sample_program):
|
||||
"""
|
||||
Test a user message to extract beliefs from, but no new beliefs are formed because they already
|
||||
existed.
|
||||
"""
|
||||
agent.belief_inferrer.available_beliefs.append(sample_program.phases[0].norms[0].condition)
|
||||
agent.belief_inferrer.available_beliefs.append(sample_program.phases[0].triggers[0].condition)
|
||||
agent._current_beliefs = BeliefState(true={InternalBelief(name="is_pirate", arguments=None)})
|
||||
|
||||
# Send a user message with the belief the user is a pirate, still
|
||||
llm._query_llm.return_value = {"is_pirate": True, "no_more_booze": None}
|
||||
await agent.handle_message(
|
||||
InternalMessage(
|
||||
to=settings.agent_settings.text_belief_extractor_name,
|
||||
sender=settings.agent_settings.transcription_name,
|
||||
body="Arr, nice to meet you, matey.",
|
||||
),
|
||||
)
|
||||
|
||||
# Only the user_said belief should've been sent, as no beliefs have changed
|
||||
agent.send.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_simulated_real_turn_remove_belief(agent, llm, sample_program):
|
||||
"""
|
||||
Test a user message to extract beliefs from, but an existing belief is determined no longer to
|
||||
hold.
|
||||
"""
|
||||
agent.belief_inferrer.available_beliefs.append(sample_program.phases[0].norms[0].condition)
|
||||
agent.belief_inferrer.available_beliefs.append(sample_program.phases[0].triggers[0].condition)
|
||||
agent._current_beliefs = BeliefState(
|
||||
true={InternalBelief(name="no_more_booze", arguments=None)},
|
||||
)
|
||||
|
||||
# Send a user message with the belief the user is a pirate, still
|
||||
llm._query_llm.return_value = {"is_pirate": None, "no_more_booze": False}
|
||||
await agent.handle_message(
|
||||
InternalMessage(
|
||||
to=settings.agent_settings.text_belief_extractor_name,
|
||||
sender=settings.agent_settings.transcription_name,
|
||||
body="I found an untouched barrel of wine!",
|
||||
),
|
||||
)
|
||||
|
||||
# Both user_said and belief change should've been sent
|
||||
assert agent.send.call_count == 2
|
||||
|
||||
# Agent's current beliefs should've changed
|
||||
assert any(b.name == "no_more_booze" for b in agent._current_beliefs.false)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_infer_goal_completions_sends_beliefs(agent, llm):
|
||||
"""Test that inferred goal completions are sent to the BDI core."""
|
||||
goal = BaseGoal(
|
||||
id=uuid.uuid4(), name="Say Hello", description="The user said hello", can_fail=True
|
||||
)
|
||||
agent.goal_inferrer.goals = {goal}
|
||||
|
||||
# Mock goal inference: goal is achieved
|
||||
llm.query = AsyncMock(return_value=True)
|
||||
|
||||
await agent._infer_goal_completions()
|
||||
|
||||
# Should send belief change to BDI core
|
||||
agent.send.assert_awaited_once()
|
||||
sent: InternalMessage = agent.send.call_args.args[0]
|
||||
assert sent.to == settings.agent_settings.bdi_core_name
|
||||
assert sent.thread == "beliefs"
|
||||
|
||||
parsed = BeliefMessage.model_validate_json(sent.body)
|
||||
assert len(parsed.create) == 1
|
||||
assert parsed.create[0].name == "achieved_say_hello"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_llm_failure_handling(agent, llm, sample_program):
|
||||
"""
|
||||
Check that the agent handles failures gracefully without crashing.
|
||||
"""
|
||||
llm._query_llm.side_effect = httpx.HTTPError("")
|
||||
agent.belief_inferrer.available_beliefs.append(sample_program.phases[0].norms[0].condition)
|
||||
agent.belief_inferrer.available_beliefs.append(sample_program.phases[0].triggers[0].condition)
|
||||
|
||||
belief_changes = await agent.belief_inferrer.infer_from_conversation(
|
||||
ChatHistory(
|
||||
messages=[ChatMessage(role="user", content="Good day!")],
|
||||
),
|
||||
)
|
||||
|
||||
assert len(belief_changes.true) == 0
|
||||
assert len(belief_changes.false) == 0
|
||||
|
||||
|
||||
def test_belief_state_bool():
|
||||
# Empty
|
||||
bs = BeliefState()
|
||||
assert not bs
|
||||
|
||||
# True set
|
||||
bs_true = BeliefState(true={InternalBelief(name="a", arguments=None)})
|
||||
assert bs_true
|
||||
|
||||
# False set
|
||||
bs_false = BeliefState(false={InternalBelief(name="a", arguments=None)})
|
||||
assert bs_false
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_beliefs_message_validation_error(agent, mock_settings):
|
||||
# Invalid JSON
|
||||
mock_settings.agent_settings.bdi_program_manager_name = "bdi_program_manager_agent"
|
||||
msg = InternalMessage(
|
||||
to="me",
|
||||
sender=mock_settings.agent_settings.bdi_program_manager_name,
|
||||
thread="beliefs",
|
||||
body="invalid json",
|
||||
)
|
||||
# Should log warning and return
|
||||
agent.logger = MagicMock()
|
||||
await agent.handle_message(msg)
|
||||
agent.logger.warning.assert_called()
|
||||
|
||||
# Invalid Model
|
||||
msg.body = json.dumps({"beliefs": [{"invalid": "obj"}]})
|
||||
await agent.handle_message(msg)
|
||||
agent.logger.warning.assert_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_goals_message_validation_error(agent, mock_settings):
|
||||
mock_settings.agent_settings.bdi_program_manager_name = "bdi_program_manager_agent"
|
||||
msg = InternalMessage(
|
||||
to="me",
|
||||
sender=mock_settings.agent_settings.bdi_program_manager_name,
|
||||
thread="goals",
|
||||
body="invalid json",
|
||||
)
|
||||
agent.logger = MagicMock()
|
||||
await agent.handle_message(msg)
|
||||
agent.logger.warning.assert_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_goal_achieved_message_validation_error(agent, mock_settings):
|
||||
mock_settings.agent_settings.bdi_program_manager_name = "bdi_program_manager_agent"
|
||||
msg = InternalMessage(
|
||||
to="me",
|
||||
sender=mock_settings.agent_settings.bdi_program_manager_name,
|
||||
thread="achieved_goals",
|
||||
body="invalid json",
|
||||
)
|
||||
agent.logger = MagicMock()
|
||||
await agent.handle_message(msg)
|
||||
agent.logger.warning.assert_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_goal_inferrer_infer_from_conversation(agent, llm):
|
||||
# Setup goals
|
||||
# Use BaseGoal object as typically received by the extractor
|
||||
g1 = BaseGoal(id=uuid.uuid4(), name="g1", description="desc", can_fail=True)
|
||||
|
||||
# Use real GoalAchievementInferrer
|
||||
from control_backend.agents.bdi.text_belief_extractor_agent import GoalAchievementInferrer
|
||||
|
||||
inferrer = GoalAchievementInferrer(llm)
|
||||
inferrer.goals = {g1}
|
||||
|
||||
# Mock LLM response
|
||||
llm._query_llm.return_value = True
|
||||
|
||||
completions = await inferrer.infer_from_conversation(ChatHistory(messages=[]))
|
||||
assert completions
|
||||
# slugify uses slugify library, hard to predict exact string without it,
|
||||
# but we can check values
|
||||
assert list(completions.values())[0] is True
|
||||
|
||||
|
||||
def test_apply_conversation_message_limit(agent):
|
||||
with patch("control_backend.agents.bdi.text_belief_extractor_agent.settings") as mock_s:
|
||||
mock_s.behaviour_settings.conversation_history_length_limit = 2
|
||||
agent.conversation.messages = []
|
||||
|
||||
agent._apply_conversation_message(ChatMessage(role="user", content="1"))
|
||||
agent._apply_conversation_message(ChatMessage(role="assistant", content="2"))
|
||||
agent._apply_conversation_message(ChatMessage(role="user", content="3"))
|
||||
|
||||
assert len(agent.conversation.messages) == 2
|
||||
assert agent.conversation.messages[0].content == "2"
|
||||
assert agent.conversation.messages[1].content == "3"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_program_manager_reset(agent):
|
||||
with patch("control_backend.agents.bdi.text_belief_extractor_agent.settings") as mock_s:
|
||||
mock_s.agent_settings.bdi_program_manager_name = "pm"
|
||||
agent.conversation.messages = [ChatMessage(role="user", content="hi")]
|
||||
agent.belief_inferrer.available_beliefs = [
|
||||
SemanticBelief(id=uuid.uuid4(), name="b", description="d")
|
||||
]
|
||||
|
||||
msg = InternalMessage(to="me", sender="pm", thread="conversation_history", body="reset")
|
||||
await agent.handle_message(msg)
|
||||
|
||||
assert len(agent.conversation.messages) == 0
|
||||
assert len(agent.belief_inferrer.available_beliefs) == 0
|
||||
|
||||
|
||||
def test_split_into_chunks():
|
||||
from control_backend.agents.bdi.text_belief_extractor_agent import SemanticBeliefInferrer
|
||||
|
||||
items = [1, 2, 3, 4, 5]
|
||||
chunks = SemanticBeliefInferrer._split_into_chunks(items, 2)
|
||||
assert len(chunks) == 2
|
||||
assert len(chunks[0]) + len(chunks[1]) == 5
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_infer_beliefs_call(agent, llm):
|
||||
from control_backend.agents.bdi.text_belief_extractor_agent import SemanticBeliefInferrer
|
||||
|
||||
inferrer = SemanticBeliefInferrer(llm)
|
||||
sb = SemanticBelief(id=uuid.uuid4(), name="is_happy", description="User is happy")
|
||||
|
||||
llm.query = AsyncMock(return_value={"is_happy": True})
|
||||
|
||||
res = await inferrer._infer_beliefs(ChatHistory(messages=[]), [sb])
|
||||
assert res == {"is_happy": True}
|
||||
llm.query.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_infer_goal_call(agent, llm):
|
||||
from control_backend.agents.bdi.text_belief_extractor_agent import GoalAchievementInferrer
|
||||
|
||||
inferrer = GoalAchievementInferrer(llm)
|
||||
goal = BaseGoal(id=uuid.uuid4(), name="g1", description="d")
|
||||
|
||||
llm.query = AsyncMock(return_value=True)
|
||||
|
||||
res = await inferrer._infer_goal(ChatHistory(messages=[]), goal)
|
||||
assert res is True
|
||||
llm.query.assert_called_once()
|
||||
433
test/unit/agents/communication/test_ri_communication_agent.py
Normal file
433
test/unit/agents/communication/test_ri_communication_agent.py
Normal file
@@ -0,0 +1,433 @@
|
||||
import asyncio
|
||||
from unittest.mock import ANY, AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from control_backend.agents.communication.ri_communication_agent import RICommunicationAgent
|
||||
from control_backend.core.agent_system import InternalMessage
|
||||
from control_backend.schemas.ri_message import PauseCommand, RIEndpoint
|
||||
|
||||
|
||||
def speech_agent_path():
|
||||
return "control_backend.agents.communication.ri_communication_agent.RobotSpeechAgent"
|
||||
|
||||
|
||||
def gesture_agent_path():
|
||||
return "control_backend.agents.communication.ri_communication_agent.RobotGestureAgent"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def zmq_context(mocker):
|
||||
mock_context = mocker.patch(
|
||||
"control_backend.agents.communication.ri_communication_agent.Context.instance"
|
||||
)
|
||||
mock_context.return_value = MagicMock()
|
||||
return mock_context
|
||||
|
||||
|
||||
def negotiation_message(
|
||||
actuation_port: int = 5556,
|
||||
bind_main: bool = False,
|
||||
bind_actuation: bool = False,
|
||||
main_port: int = 5555,
|
||||
):
|
||||
return {
|
||||
"endpoint": "negotiate/ports",
|
||||
"data": [
|
||||
{"id": "main", "port": main_port, "bind": bind_main},
|
||||
{"id": "actuation", "port": actuation_port, "bind": bind_actuation},
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_setup_success_connects_and_starts_robot(zmq_context):
|
||||
fake_socket = zmq_context.return_value.socket.return_value
|
||||
fake_socket.send_json = AsyncMock()
|
||||
fake_socket.recv_json = AsyncMock(return_value=negotiation_message())
|
||||
fake_socket.send_multipart = AsyncMock()
|
||||
|
||||
with (
|
||||
patch(speech_agent_path(), autospec=True) as MockSpeech,
|
||||
patch(gesture_agent_path(), autospec=True) as MockGesture,
|
||||
):
|
||||
MockSpeech.return_value.start = AsyncMock()
|
||||
MockGesture.return_value.start = AsyncMock()
|
||||
agent = RICommunicationAgent("ri_comm", address="tcp://localhost:5555", bind=False)
|
||||
|
||||
def close_coro(coro):
|
||||
coro.close()
|
||||
return MagicMock()
|
||||
|
||||
agent.add_behavior = MagicMock(side_effect=close_coro)
|
||||
|
||||
await agent.setup()
|
||||
|
||||
fake_socket.connect.assert_any_call("tcp://localhost:5555")
|
||||
fake_socket.send_json.assert_any_call({"endpoint": "negotiate/ports", "data": {}})
|
||||
MockSpeech.return_value.start.assert_awaited_once()
|
||||
MockGesture.return_value.start.assert_awaited_once()
|
||||
MockSpeech.assert_called_once_with(ANY, address="tcp://localhost:5556", bind=False)
|
||||
MockGesture.assert_called_once_with(
|
||||
ANY,
|
||||
address="tcp://localhost:5556",
|
||||
bind=False,
|
||||
gesture_data=[],
|
||||
single_gesture_data=[],
|
||||
)
|
||||
agent.add_behavior.assert_called_once()
|
||||
|
||||
assert agent.connected is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_setup_binds_when_requested(zmq_context):
|
||||
fake_socket = zmq_context.return_value.socket.return_value
|
||||
fake_socket.send_json = AsyncMock()
|
||||
fake_socket.recv_json = AsyncMock(return_value=negotiation_message(bind_main=True))
|
||||
fake_socket.send_multipart = AsyncMock()
|
||||
|
||||
agent = RICommunicationAgent("ri_comm", address="tcp://localhost:5555", bind=True)
|
||||
|
||||
def close_coro(coro):
|
||||
coro.close()
|
||||
return MagicMock()
|
||||
|
||||
agent.add_behavior = MagicMock(side_effect=close_coro)
|
||||
|
||||
with (
|
||||
patch(speech_agent_path(), autospec=True) as MockSpeech,
|
||||
patch(gesture_agent_path(), autospec=True) as MockGesture,
|
||||
):
|
||||
MockSpeech.return_value.start = AsyncMock()
|
||||
MockGesture.return_value.start = AsyncMock()
|
||||
await agent.setup()
|
||||
fake_socket.bind.assert_any_call("tcp://localhost:5555")
|
||||
agent.add_behavior.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_negotiate_invalid_endpoint_retries(zmq_context):
|
||||
fake_socket = zmq_context.return_value.socket.return_value
|
||||
fake_socket.send_json = AsyncMock()
|
||||
fake_socket.recv_json = AsyncMock(return_value={"endpoint": "ping", "data": {}})
|
||||
fake_socket.send_multipart = AsyncMock()
|
||||
|
||||
agent = RICommunicationAgent("ri_comm", address="tcp://localhost:5555", bind=False)
|
||||
agent._req_socket = fake_socket
|
||||
|
||||
success = await agent._negotiate_connection(max_retries=1)
|
||||
assert success is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_negotiate_timeout(zmq_context):
|
||||
fake_socket = zmq_context.return_value.socket.return_value
|
||||
fake_socket.send_json = AsyncMock()
|
||||
fake_socket.recv_json = AsyncMock(side_effect=asyncio.TimeoutError)
|
||||
fake_socket.send_multipart = AsyncMock()
|
||||
|
||||
agent = RICommunicationAgent("ri_comm", address="tcp://localhost:5555", bind=False)
|
||||
agent._req_socket = fake_socket
|
||||
|
||||
success = await agent._negotiate_connection(max_retries=1)
|
||||
|
||||
assert success is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_negotiation_response_updates_req_socket(zmq_context):
|
||||
fake_socket = zmq_context.return_value.socket.return_value
|
||||
agent = RICommunicationAgent("ri_comm", address="tcp://localhost:5555", bind=False)
|
||||
agent._req_socket = fake_socket
|
||||
with (
|
||||
patch(speech_agent_path(), autospec=True) as MockSpeech,
|
||||
patch(gesture_agent_path(), autospec=True) as MockGesture,
|
||||
):
|
||||
MockSpeech.return_value.start = AsyncMock()
|
||||
MockGesture.return_value.start = AsyncMock()
|
||||
await agent._handle_negotiation_response(
|
||||
negotiation_message(
|
||||
main_port=6000,
|
||||
actuation_port=6001,
|
||||
bind_main=False,
|
||||
bind_actuation=False,
|
||||
)
|
||||
)
|
||||
|
||||
fake_socket.connect.assert_any_call("tcp://localhost:6000")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_disconnection_publishes_and_reconnects():
|
||||
pub_socket = AsyncMock()
|
||||
pub_socket.close = MagicMock()
|
||||
agent = RICommunicationAgent("ri_comm")
|
||||
agent.pub_socket = pub_socket
|
||||
agent.connected = True
|
||||
agent._negotiate_connection = AsyncMock(return_value=True)
|
||||
|
||||
await agent._handle_disconnection()
|
||||
pub_socket.send_multipart.assert_awaited()
|
||||
assert agent.connected is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_listen_loop_handles_non_ping(zmq_context):
|
||||
fake_socket = zmq_context.return_value.socket.return_value
|
||||
fake_socket.send_json = AsyncMock()
|
||||
|
||||
async def recv_once():
|
||||
agent._running = False
|
||||
return {"endpoint": "negotiate/ports", "data": {}}
|
||||
|
||||
fake_socket.recv_json = recv_once
|
||||
agent = RICommunicationAgent("ri_comm")
|
||||
agent._req_socket = fake_socket
|
||||
agent.pub_socket = AsyncMock()
|
||||
agent.connected = True
|
||||
agent._running = True
|
||||
|
||||
await agent._listen_loop()
|
||||
|
||||
fake_socket.send_json.assert_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_negotiate_unexpected_error(zmq_context):
|
||||
fake_socket = zmq_context.return_value.socket.return_value
|
||||
fake_socket.send_json = AsyncMock()
|
||||
fake_socket.recv_json = AsyncMock(side_effect=Exception("boom"))
|
||||
agent = RICommunicationAgent("ri_comm")
|
||||
agent._req_socket = fake_socket
|
||||
|
||||
assert await agent._negotiate_connection(max_retries=1) is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_negotiate_handle_response_error(zmq_context):
|
||||
fake_socket = zmq_context.return_value.socket.return_value
|
||||
fake_socket.send_json = AsyncMock()
|
||||
fake_socket.recv_json = AsyncMock(return_value=negotiation_message())
|
||||
|
||||
agent = RICommunicationAgent("ri_comm")
|
||||
agent._req_socket = fake_socket
|
||||
agent._handle_negotiation_response = AsyncMock(side_effect=Exception("bad response"))
|
||||
|
||||
assert await agent._negotiate_connection(max_retries=1) is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_setup_warns_on_failed_negotiate(zmq_context, mocker):
|
||||
fake_socket = zmq_context.return_value.socket.return_value
|
||||
fake_socket.send_json = AsyncMock()
|
||||
fake_socket.recv_json = AsyncMock()
|
||||
agent = RICommunicationAgent("ri_comm")
|
||||
|
||||
def swallow(coro):
|
||||
coro.close()
|
||||
|
||||
agent.add_behavior = swallow
|
||||
agent._negotiate_connection = AsyncMock(return_value=False)
|
||||
|
||||
await agent.setup()
|
||||
|
||||
assert agent.connected is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_negotiation_response_unhandled_id():
|
||||
agent = RICommunicationAgent("ri_comm")
|
||||
|
||||
await agent._handle_negotiation_response(
|
||||
{"data": [{"id": "other", "port": 5000, "bind": False}]}
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_negotiation_response_audio(zmq_context):
|
||||
agent = RICommunicationAgent("ri_comm")
|
||||
|
||||
with patch(
|
||||
"control_backend.agents.communication.ri_communication_agent.VADAgent", autospec=True
|
||||
) as MockVAD:
|
||||
MockVAD.return_value.start = AsyncMock()
|
||||
|
||||
await agent._handle_negotiation_response(
|
||||
{"data": [{"id": "audio", "port": 7000, "bind": False}]}
|
||||
)
|
||||
|
||||
MockVAD.assert_called_once_with(
|
||||
audio_in_address="tcp://localhost:7000", audio_in_bind=False
|
||||
)
|
||||
MockVAD.return_value.start.assert_awaited_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stop_closes_sockets():
|
||||
req = MagicMock()
|
||||
pub = MagicMock()
|
||||
agent = RICommunicationAgent("ri_comm")
|
||||
agent._req_socket = req
|
||||
agent.pub_socket = pub
|
||||
|
||||
await agent.stop()
|
||||
|
||||
req.close.assert_called_once()
|
||||
pub.close.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_listen_loop_not_connected(monkeypatch):
|
||||
agent = RICommunicationAgent("ri_comm")
|
||||
agent._running = True
|
||||
agent.connected = False
|
||||
agent._req_socket = AsyncMock()
|
||||
|
||||
async def fake_sleep(duration):
|
||||
agent._running = False
|
||||
|
||||
monkeypatch.setattr("asyncio.sleep", fake_sleep)
|
||||
|
||||
await agent._listen_loop()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_listen_loop_send_and_recv_timeout():
|
||||
req = AsyncMock()
|
||||
req.send_json = AsyncMock(side_effect=TimeoutError)
|
||||
req.recv_json = AsyncMock(side_effect=TimeoutError)
|
||||
|
||||
agent = RICommunicationAgent("ri_comm")
|
||||
agent._req_socket = req
|
||||
agent.pub_socket = AsyncMock()
|
||||
agent.connected = True
|
||||
agent._running = True
|
||||
|
||||
async def stop_run():
|
||||
agent._running = False
|
||||
|
||||
agent._handle_disconnection = AsyncMock(side_effect=stop_run)
|
||||
|
||||
await agent._listen_loop()
|
||||
|
||||
agent._handle_disconnection.assert_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_listen_loop_missing_endpoint(monkeypatch):
|
||||
req = AsyncMock()
|
||||
req.send_json = AsyncMock()
|
||||
|
||||
async def recv_once():
|
||||
agent._running = False
|
||||
return {"data": {}}
|
||||
|
||||
req.recv_json = recv_once
|
||||
|
||||
agent = RICommunicationAgent("ri_comm")
|
||||
agent._req_socket = req
|
||||
agent.pub_socket = AsyncMock()
|
||||
agent.connected = True
|
||||
agent._running = True
|
||||
|
||||
await agent._listen_loop()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_listen_loop_generic_exception():
|
||||
req = AsyncMock()
|
||||
req.send_json = AsyncMock()
|
||||
req.recv_json = AsyncMock(side_effect=ValueError("boom"))
|
||||
|
||||
agent = RICommunicationAgent("ri_comm")
|
||||
agent._req_socket = req
|
||||
agent.pub_socket = AsyncMock()
|
||||
agent.connected = True
|
||||
agent._running = True
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
await agent._listen_loop()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_disconnection_timeout(monkeypatch):
|
||||
pub = AsyncMock()
|
||||
pub.close = MagicMock()
|
||||
pub.send_multipart = AsyncMock(side_effect=TimeoutError)
|
||||
|
||||
agent = RICommunicationAgent("ri_comm")
|
||||
agent.pub_socket = pub
|
||||
agent._negotiate_connection = AsyncMock(return_value=False)
|
||||
|
||||
await agent._handle_disconnection()
|
||||
|
||||
pub.send_multipart.assert_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_listen_loop_ping_sends_internal(zmq_context):
|
||||
fake_socket = zmq_context.return_value.socket.return_value
|
||||
fake_socket.send_json = AsyncMock()
|
||||
pub_socket = AsyncMock()
|
||||
|
||||
agent = RICommunicationAgent("ri_comm")
|
||||
agent._req_socket = fake_socket
|
||||
agent.pub_socket = pub_socket
|
||||
agent.connected = True
|
||||
agent._running = True
|
||||
|
||||
async def recv_once():
|
||||
agent._running = False
|
||||
return {"endpoint": "ping", "data": {}}
|
||||
|
||||
fake_socket.recv_json = recv_once
|
||||
|
||||
await agent._listen_loop()
|
||||
|
||||
pub_socket.send_multipart.assert_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_negotiate_req_socket_none_causes_retry(zmq_context):
|
||||
agent = RICommunicationAgent("ri_comm")
|
||||
agent._req_socket = None
|
||||
|
||||
result = await agent._negotiate_connection(max_retries=1)
|
||||
|
||||
assert result is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_message_pause_command(zmq_context):
|
||||
"""Test handle_message with a valid PauseCommand."""
|
||||
agent = RICommunicationAgent("ri_comm")
|
||||
agent._req_socket = AsyncMock()
|
||||
agent.logger = MagicMock()
|
||||
|
||||
agent._req_socket.recv_json.return_value = {"status": "ok"}
|
||||
|
||||
pause_cmd = PauseCommand(data=True)
|
||||
msg = InternalMessage(to="ri_comm", sender="user_int", body=pause_cmd.model_dump_json())
|
||||
|
||||
await agent.handle_message(msg)
|
||||
|
||||
agent._req_socket.send_json.assert_awaited_once()
|
||||
args = agent._req_socket.send_json.await_args[0][0]
|
||||
assert args["endpoint"] == RIEndpoint.PAUSE.value
|
||||
assert args["data"] is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_message_invalid_pause_command(zmq_context):
|
||||
"""Test handle_message with invalid JSON."""
|
||||
agent = RICommunicationAgent("ri_comm")
|
||||
agent._req_socket = AsyncMock()
|
||||
agent.logger = MagicMock()
|
||||
|
||||
msg = InternalMessage(to="ri_comm", sender="user_int", body="invalid json")
|
||||
|
||||
await agent.handle_message(msg)
|
||||
|
||||
agent.logger.warning.assert_called_with("Incorrect message format for PauseCommand.")
|
||||
agent._req_socket.send_json.assert_not_called()
|
||||
356
test/unit/agents/llm/test_llm_agent.py
Normal file
356
test/unit/agents/llm/test_llm_agent.py
Normal file
@@ -0,0 +1,356 @@
|
||||
"""Mocks `httpx` and tests chunking logic."""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from control_backend.agents.llm.llm_agent import LLMAgent, LLMInstructions
|
||||
from control_backend.core.agent_system import InternalMessage
|
||||
from control_backend.schemas.llm_prompt_message import LLMPromptMessage
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_httpx_client():
|
||||
with patch("httpx.AsyncClient") as mock_cls:
|
||||
mock_client = AsyncMock()
|
||||
mock_cls.return_value.__aenter__.return_value = mock_client
|
||||
yield mock_client
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_experiment_logger():
|
||||
with patch("control_backend.agents.llm.llm_agent.experiment_logger") as logger:
|
||||
yield logger
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_llm_processing_success(mock_httpx_client, mock_settings):
|
||||
# Setup the mock response for the stream
|
||||
mock_response = MagicMock()
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
|
||||
# Simulate stream lines
|
||||
lines = [
|
||||
b'data: {"choices": [{"delta": {"content": "Hello"}}]}',
|
||||
b'data: {"choices": [{"delta": {"content": " world"}}]}',
|
||||
b'data: {"choices": [{"delta": {"content": "."}}]}',
|
||||
b"data: [DONE]",
|
||||
]
|
||||
|
||||
async def aiter_lines_gen():
|
||||
for line in lines:
|
||||
yield line.decode()
|
||||
|
||||
mock_response.aiter_lines.side_effect = aiter_lines_gen
|
||||
|
||||
mock_stream_context = MagicMock()
|
||||
mock_stream_context.__aenter__ = AsyncMock(return_value=mock_response)
|
||||
mock_stream_context.__aexit__ = AsyncMock(return_value=None)
|
||||
|
||||
# Configure the client
|
||||
mock_httpx_client.stream = MagicMock(return_value=mock_stream_context)
|
||||
|
||||
# Setup Agent
|
||||
agent = LLMAgent("llm_agent")
|
||||
agent.send = AsyncMock() # Mock the send method to verify replies
|
||||
|
||||
mock_logger = MagicMock()
|
||||
agent.logger = mock_logger
|
||||
|
||||
# Simulate receiving a message from BDI
|
||||
prompt = LLMPromptMessage(text="Hi", norms=[], goals=[])
|
||||
msg = InternalMessage(
|
||||
to="llm_agent",
|
||||
sender=mock_settings.agent_settings.bdi_core_name,
|
||||
body=prompt.model_dump_json(),
|
||||
thread="prompt_message", # REQUIRED: thread must match handle_message logic
|
||||
)
|
||||
|
||||
agent._process_bdi_message = AsyncMock()
|
||||
|
||||
await agent.handle_message(msg)
|
||||
|
||||
agent._process_bdi_message.assert_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_bdi_message_success(mock_httpx_client, mock_settings):
|
||||
# Setup the mock response for the stream
|
||||
mock_response = MagicMock()
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
|
||||
# Simulate stream lines
|
||||
lines = [
|
||||
b'data: {"choices": [{"delta": {"content": "Hello"}}]}',
|
||||
b'data: {"choices": [{"delta": {"content": " world"}}]}',
|
||||
b'data: {"choices": [{"delta": {"content": "."}}]}',
|
||||
b"data: [DONE]",
|
||||
]
|
||||
|
||||
async def aiter_lines_gen():
|
||||
for line in lines:
|
||||
yield line.decode()
|
||||
|
||||
mock_response.aiter_lines.side_effect = aiter_lines_gen
|
||||
|
||||
mock_stream_context = MagicMock()
|
||||
mock_stream_context.__aenter__ = AsyncMock(return_value=mock_response)
|
||||
mock_stream_context.__aexit__ = AsyncMock(return_value=None)
|
||||
|
||||
# Configure the client
|
||||
mock_httpx_client.stream = MagicMock(return_value=mock_stream_context)
|
||||
|
||||
# Setup Agent
|
||||
agent = LLMAgent("llm_agent")
|
||||
agent.send = AsyncMock() # Mock the send method to verify replies
|
||||
|
||||
mock_logger = MagicMock()
|
||||
agent.logger = mock_logger
|
||||
|
||||
# Simulate receiving a message from BDI
|
||||
prompt = LLMPromptMessage(text="Hi", norms=[], goals=[])
|
||||
|
||||
await agent._process_bdi_message(prompt)
|
||||
|
||||
# Verification
|
||||
# "Hello world." constitutes one sentence/chunk based on punctuation split
|
||||
# The agent should call send once with the full sentence, PLUS once more for full reply
|
||||
assert agent.send.called
|
||||
|
||||
# Check args. We expect at least one call sending "Hello world."
|
||||
calls = agent.send.call_args_list
|
||||
bodies = [c[0][0].body for c in calls]
|
||||
assert any("Hello world." in b for b in bodies)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_llm_processing_errors(mock_httpx_client, mock_settings):
|
||||
agent = LLMAgent("llm_agent")
|
||||
agent.send = AsyncMock()
|
||||
prompt = LLMPromptMessage(text="Hi", norms=[], goals=[])
|
||||
|
||||
# HTTP Error: stream method RAISES exception immediately
|
||||
mock_httpx_client.stream = MagicMock(side_effect=httpx.HTTPError("Fail"))
|
||||
|
||||
await agent._process_bdi_message(prompt)
|
||||
|
||||
# Check that error message was sent
|
||||
assert agent.send.called
|
||||
assert "LLM service unavailable." in agent.send.call_args_list[0][0][0].body
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_llm_json_error(mock_httpx_client, mock_settings):
|
||||
# Test malformed JSON in stream
|
||||
mock_response = MagicMock()
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
|
||||
async def aiter_lines_gen():
|
||||
yield "data: {bad_json"
|
||||
yield "data: [DONE]"
|
||||
|
||||
mock_response.aiter_lines.side_effect = aiter_lines_gen
|
||||
|
||||
mock_stream_context = MagicMock()
|
||||
mock_stream_context.__aenter__ = AsyncMock(return_value=mock_response)
|
||||
mock_stream_context.__aexit__ = AsyncMock(return_value=None)
|
||||
mock_httpx_client.stream = MagicMock(return_value=mock_stream_context)
|
||||
|
||||
agent = LLMAgent("llm_agent")
|
||||
agent.send = AsyncMock()
|
||||
# Ensure logger is mocked
|
||||
agent.logger = MagicMock()
|
||||
|
||||
prompt = LLMPromptMessage(text="Hi", norms=[], goals=[])
|
||||
await agent._process_bdi_message(prompt)
|
||||
|
||||
agent.logger.error.assert_called() # Should log JSONDecodeError
|
||||
|
||||
|
||||
def test_llm_instructions():
|
||||
# Full custom
|
||||
instr = LLMInstructions(norms=["N1", "N2"], goals=["G1", "G2"])
|
||||
text = instr.build_developer_instruction()
|
||||
assert "Norms to follow:\n- N1\n- N2" in text
|
||||
assert "Goals to reach:\n- G1\n- G2" in text
|
||||
|
||||
# Defaults
|
||||
instr_def = LLMInstructions()
|
||||
text_def = instr_def.build_developer_instruction()
|
||||
assert "Norms to follow" in text_def
|
||||
assert "Goals to reach" in text_def
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_message_validation_error_branch_no_send(mock_httpx_client, mock_settings):
|
||||
"""
|
||||
Covers the ValidationError branch:
|
||||
except ValidationError:
|
||||
self.logger.debug("Prompt message from BDI core is invalid.")
|
||||
Assert: no message is sent.
|
||||
"""
|
||||
agent = LLMAgent("llm_agent")
|
||||
agent.send = AsyncMock()
|
||||
|
||||
# Invalid JSON that triggers ValidationError in LLMPromptMessage
|
||||
invalid_json = '{"text": "Hi", "wrong_field": 123}' # field not in schema
|
||||
|
||||
msg = InternalMessage(
|
||||
to="llm_agent",
|
||||
sender=mock_settings.agent_settings.bdi_core_name,
|
||||
body=invalid_json,
|
||||
thread="prompt_message",
|
||||
)
|
||||
|
||||
await agent.handle_message(msg)
|
||||
|
||||
# Should not send any reply
|
||||
agent.send.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_message_ignored_sender_branch_no_send(mock_httpx_client, mock_settings):
|
||||
"""
|
||||
Covers the else branch for messages not from BDI core:
|
||||
else:
|
||||
self.logger.debug("Message ignored (not from BDI core.")
|
||||
Assert: no message is sent.
|
||||
"""
|
||||
agent = LLMAgent("llm_agent")
|
||||
agent.send = AsyncMock()
|
||||
|
||||
msg = InternalMessage(
|
||||
to="llm_agent",
|
||||
sender="some_other_agent", # Not BDI core
|
||||
body='{"text": "Hi"}',
|
||||
)
|
||||
|
||||
await agent.handle_message(msg)
|
||||
|
||||
# Should not send any reply
|
||||
agent.send.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_llm_yields_final_tail_chunk(mock_settings):
|
||||
"""
|
||||
Covers the branch: if current_chunk: yield current_chunk
|
||||
Ensure that the last partial chunk is emitted.
|
||||
"""
|
||||
agent = LLMAgent("llm_agent")
|
||||
agent.send = AsyncMock()
|
||||
|
||||
agent.logger = MagicMock()
|
||||
agent.logger.llm = MagicMock()
|
||||
|
||||
# Patch _stream_query_llm to yield tokens that do NOT end with punctuation
|
||||
async def fake_stream(messages):
|
||||
yield "Hello"
|
||||
yield " world" # No punctuation to trigger the normal chunking
|
||||
|
||||
agent._stream_query_llm = fake_stream
|
||||
|
||||
prompt = LLMPromptMessage(text="Hi", norms=[], goals=[])
|
||||
|
||||
# Collect chunks yielded
|
||||
chunks = []
|
||||
async for chunk in agent._query_llm(prompt.text, prompt.norms, prompt.goals):
|
||||
chunks.append(chunk)
|
||||
|
||||
# The final chunk should be yielded
|
||||
assert chunks[-1] == "Hello world"
|
||||
assert any("Hello" in c for c in chunks)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_query_llm_skips_non_data_lines(mock_httpx_client, mock_settings):
|
||||
"""
|
||||
Covers: if not line or not line.startswith("data: "): continue
|
||||
Feed lines that are empty or do not start with 'data:' and check they are skipped.
|
||||
"""
|
||||
# Mock response
|
||||
mock_response = MagicMock()
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
|
||||
lines = [
|
||||
"", # empty line
|
||||
"not data", # invalid prefix
|
||||
'data: {"choices": [{"delta": {"content": "Hi"}}]}',
|
||||
"data: [DONE]",
|
||||
]
|
||||
|
||||
async def aiter_lines_gen():
|
||||
for line in lines:
|
||||
yield line
|
||||
|
||||
mock_response.aiter_lines.side_effect = aiter_lines_gen
|
||||
|
||||
# Proper async context manager for stream
|
||||
mock_stream_context = MagicMock()
|
||||
mock_stream_context.__aenter__ = AsyncMock(return_value=mock_response)
|
||||
mock_stream_context.__aexit__ = AsyncMock(return_value=None)
|
||||
|
||||
# Make stream return the async context manager
|
||||
mock_httpx_client.stream = MagicMock(return_value=mock_stream_context)
|
||||
|
||||
agent = LLMAgent("llm_agent")
|
||||
agent.send = AsyncMock()
|
||||
|
||||
# Patch settings for local LLM URL
|
||||
with patch("control_backend.agents.llm.llm_agent.settings") as mock_sett:
|
||||
mock_sett.llm_settings.local_llm_url = "http://localhost"
|
||||
mock_sett.llm_settings.local_llm_model = "test-model"
|
||||
|
||||
# Collect tokens
|
||||
tokens = []
|
||||
async for token in agent._stream_query_llm([]):
|
||||
tokens.append(token)
|
||||
|
||||
# Only the valid 'data:' line should yield content
|
||||
assert tokens == ["Hi"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_clear_history_command(mock_settings):
|
||||
"""Test that the 'clear_history' message clears the agent's memory."""
|
||||
# setup LLM to have some history
|
||||
mock_settings.agent_settings.bdi_program_manager_name = "bdi_program_manager_agent"
|
||||
agent = LLMAgent("llm_agent")
|
||||
agent.history = [
|
||||
{"role": "user", "content": "Old conversation context"},
|
||||
{"role": "assistant", "content": "Old response"},
|
||||
]
|
||||
assert len(agent.history) == 2
|
||||
msg = InternalMessage(
|
||||
to="llm_agent",
|
||||
sender=mock_settings.agent_settings.bdi_program_manager_name,
|
||||
body="clear_history",
|
||||
)
|
||||
await agent.handle_message(msg)
|
||||
assert len(agent.history) == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_assistant_and_user_messages(mock_settings):
|
||||
agent = LLMAgent("llm_agent")
|
||||
|
||||
# Assistant message
|
||||
msg_ast = InternalMessage(
|
||||
to="llm_agent",
|
||||
sender=mock_settings.agent_settings.bdi_core_name,
|
||||
thread="assistant_message",
|
||||
body="I said this",
|
||||
)
|
||||
await agent.handle_message(msg_ast)
|
||||
assert agent.history[-1] == {"role": "assistant", "content": "I said this"}
|
||||
|
||||
# User message
|
||||
msg_usr = InternalMessage(
|
||||
to="llm_agent",
|
||||
sender=mock_settings.agent_settings.bdi_core_name,
|
||||
thread="user_message",
|
||||
body="User said this",
|
||||
)
|
||||
await agent.handle_message(msg_usr)
|
||||
assert agent.history[-1] == {"role": "user", "content": "User said this"}
|
||||
@@ -0,0 +1,60 @@
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from control_backend.agents.perception.transcription_agent.speech_recognizer import (
|
||||
OpenAIWhisperSpeechRecognizer,
|
||||
SpeechRecognizer,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def patch_sr_settings(monkeypatch):
|
||||
# Patch the *module-local* settings that SpeechRecognizer imported
|
||||
from control_backend.agents.perception.transcription_agent import speech_recognizer as sr
|
||||
|
||||
# Provide real numbers for everything _estimate_max_tokens() reads
|
||||
monkeypatch.setattr(sr.settings.vad_settings, "sample_rate_hz", 16_000, raising=False)
|
||||
monkeypatch.setattr(
|
||||
sr.settings.behaviour_settings, "transcription_words_per_minute", 450, raising=False
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
sr.settings.behaviour_settings, "transcription_words_per_token", 0.75, raising=False
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
sr.settings.behaviour_settings, "transcription_token_buffer", 10, raising=False
|
||||
)
|
||||
|
||||
|
||||
def test_estimate_max_tokens():
|
||||
"""Inputting one minute of audio, assuming 450 words per minute and adding a 10 token padding,
|
||||
expecting 610 tokens."""
|
||||
audio = np.empty(shape=(60 * 16_000), dtype=np.float32)
|
||||
|
||||
actual = SpeechRecognizer._estimate_max_tokens(audio)
|
||||
|
||||
assert actual == 610
|
||||
assert isinstance(actual, int)
|
||||
|
||||
|
||||
def test_get_decode_options():
|
||||
"""Check whether the right decode options are given under different scenarios."""
|
||||
audio = np.empty(shape=(60 * 16_000), dtype=np.float32)
|
||||
|
||||
# With the defaults, it should limit output length based on input size
|
||||
recognizer = OpenAIWhisperSpeechRecognizer()
|
||||
options = recognizer._get_decode_options(audio)
|
||||
|
||||
assert "sample_len" in options
|
||||
assert isinstance(options["sample_len"], int)
|
||||
|
||||
# When explicitly enabled, it should limit output length based on input size
|
||||
recognizer = OpenAIWhisperSpeechRecognizer(limit_output_length=True)
|
||||
options = recognizer._get_decode_options(audio)
|
||||
|
||||
assert "sample_len" in options
|
||||
assert isinstance(options["sample_len"], int)
|
||||
|
||||
# When disabled, it should not limit output length based on input size
|
||||
recognizer = OpenAIWhisperSpeechRecognizer(limit_output_length=False)
|
||||
options = recognizer._get_decode_options(audio)
|
||||
assert "sample_len" not in options
|
||||
@@ -0,0 +1,226 @@
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from control_backend.agents.perception.transcription_agent.speech_recognizer import (
|
||||
MLXWhisperSpeechRecognizer,
|
||||
OpenAIWhisperSpeechRecognizer,
|
||||
SpeechRecognizer,
|
||||
)
|
||||
from control_backend.agents.perception.transcription_agent.transcription_agent import (
|
||||
TranscriptionAgent,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_experiment_logger():
|
||||
with patch(
|
||||
"control_backend.agents.perception"
|
||||
".transcription_agent.transcription_agent.experiment_logger"
|
||||
) as logger:
|
||||
yield logger
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_transcription_agent_flow(mock_zmq_context):
|
||||
mock_sub = MagicMock()
|
||||
mock_sub.recv = AsyncMock()
|
||||
|
||||
# Setup context to return this specific mock socket
|
||||
mock_zmq_context.instance.return_value.socket.return_value = mock_sub
|
||||
|
||||
# Data: [Audio Bytes, Cancel Loop]
|
||||
fake_audio = np.zeros(16000, dtype=np.float32).tobytes()
|
||||
mock_sub.recv.side_effect = [fake_audio, asyncio.CancelledError()]
|
||||
|
||||
# Mock Recognizer
|
||||
with patch.object(SpeechRecognizer, "best_type") as mock_best:
|
||||
mock_recognizer = MagicMock()
|
||||
mock_recognizer.recognize_speech.return_value = "Hello"
|
||||
mock_best.return_value = mock_recognizer
|
||||
|
||||
agent = TranscriptionAgent("tcp://in")
|
||||
agent.send = AsyncMock()
|
||||
|
||||
agent._running = True
|
||||
|
||||
def close_coro(coro):
|
||||
coro.close()
|
||||
return MagicMock()
|
||||
|
||||
agent.add_behavior = MagicMock(side_effect=close_coro)
|
||||
|
||||
await agent.setup()
|
||||
|
||||
try:
|
||||
await agent._transcribing_loop()
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
# Check transcription happened
|
||||
assert mock_recognizer.recognize_speech.called
|
||||
# Check sending
|
||||
assert agent.send.called
|
||||
assert agent.send.call_args[0][0].body == "Hello"
|
||||
|
||||
await agent.stop()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_transcription_empty(mock_zmq_context):
|
||||
mock_sub = MagicMock()
|
||||
mock_sub.recv = AsyncMock()
|
||||
mock_zmq_context.instance.return_value.socket.return_value = mock_sub
|
||||
|
||||
# Return valid audio, but recognizer returns empty string
|
||||
fake_audio = np.zeros(10, dtype=np.float32).tobytes()
|
||||
mock_sub.recv.side_effect = [fake_audio, asyncio.CancelledError()]
|
||||
|
||||
with patch.object(SpeechRecognizer, "best_type") as mock_best:
|
||||
mock_recognizer = MagicMock()
|
||||
mock_recognizer.recognize_speech.return_value = ""
|
||||
mock_best.return_value = mock_recognizer
|
||||
|
||||
agent = TranscriptionAgent("tcp://in")
|
||||
agent.send = AsyncMock()
|
||||
await agent.setup()
|
||||
|
||||
try:
|
||||
await agent._transcribing_loop()
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
# Should NOT send message
|
||||
agent.send.assert_not_called()
|
||||
|
||||
|
||||
def test_speech_recognizer_factory():
|
||||
# Test Factory Logic
|
||||
with patch("torch.mps.is_available", return_value=True):
|
||||
assert isinstance(SpeechRecognizer.best_type(), MLXWhisperSpeechRecognizer)
|
||||
|
||||
with patch("torch.mps.is_available", return_value=False):
|
||||
assert isinstance(SpeechRecognizer.best_type(), OpenAIWhisperSpeechRecognizer)
|
||||
|
||||
|
||||
def test_openai_recognizer():
|
||||
with patch("whisper.load_model") as load_mock:
|
||||
with patch("whisper.transcribe") as trans_mock:
|
||||
rec = OpenAIWhisperSpeechRecognizer()
|
||||
rec.load_model()
|
||||
load_mock.assert_called()
|
||||
|
||||
trans_mock.return_value = {"text": "Hi"}
|
||||
res = rec.recognize_speech(np.zeros(10))
|
||||
assert res == "Hi"
|
||||
|
||||
|
||||
def test_mlx_recognizer():
|
||||
# Fix: On Linux, 'mlx_whisper' isn't imported by the module, so it's missing from dir().
|
||||
# We must use create=True to inject it into the module namespace during the test.
|
||||
module_path = "control_backend.agents.perception.transcription_agent.speech_recognizer"
|
||||
|
||||
with patch("sys.platform", "darwin"):
|
||||
with patch(f"{module_path}.mlx_whisper", create=True) as mlx_mock:
|
||||
with patch(f"{module_path}.ModelHolder", create=True) as holder_mock:
|
||||
# We also need to mock mlx.core if it's used for types/constants
|
||||
with patch(f"{module_path}.mx", create=True):
|
||||
rec = MLXWhisperSpeechRecognizer()
|
||||
rec.load_model()
|
||||
holder_mock.get_model.assert_called()
|
||||
|
||||
mlx_mock.transcribe.return_value = {"text": "Hi"}
|
||||
res = rec.recognize_speech(np.zeros(10))
|
||||
assert res == "Hi"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_transcription_loop_continues_after_error(mock_zmq_context):
|
||||
mock_sub = MagicMock()
|
||||
mock_sub.recv = AsyncMock()
|
||||
mock_zmq_context.instance.return_value.socket.return_value = mock_sub
|
||||
|
||||
fake_audio = np.zeros(16000, dtype=np.float32).tobytes()
|
||||
|
||||
mock_sub.recv.side_effect = [
|
||||
fake_audio, # first iteration → recognizer fails
|
||||
asyncio.CancelledError(), # second iteration → stop loop
|
||||
]
|
||||
|
||||
with patch.object(SpeechRecognizer, "best_type") as mock_best:
|
||||
mock_recognizer = MagicMock()
|
||||
mock_recognizer.recognize_speech.side_effect = RuntimeError("fail")
|
||||
mock_best.return_value = mock_recognizer
|
||||
|
||||
agent = TranscriptionAgent("tcp://in")
|
||||
agent._running = True # ← REQUIRED to enter the loop
|
||||
agent.send = AsyncMock() # should never be called
|
||||
|
||||
def close_coro(coro):
|
||||
coro.close()
|
||||
return MagicMock()
|
||||
|
||||
agent.add_behavior = MagicMock(side_effect=close_coro) # match other tests
|
||||
|
||||
await agent.setup()
|
||||
|
||||
try:
|
||||
await agent._transcribing_loop()
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
# recognizer failed, so we should never send anything
|
||||
agent.send.assert_not_called()
|
||||
|
||||
# recv must have been called twice (audio then CancelledError)
|
||||
assert mock_sub.recv.call_count == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_transcription_continue_branch_when_empty(mock_zmq_context):
|
||||
mock_sub = MagicMock()
|
||||
mock_sub.recv = AsyncMock()
|
||||
mock_zmq_context.instance.return_value.socket.return_value = mock_sub
|
||||
|
||||
# First recv → audio chunk
|
||||
# Second recv → Cancel loop → stop iteration
|
||||
fake_audio = np.zeros(16000, dtype=np.float32).tobytes()
|
||||
mock_sub.recv.side_effect = [fake_audio, asyncio.CancelledError()]
|
||||
|
||||
with patch.object(SpeechRecognizer, "best_type") as mock_best:
|
||||
mock_recognizer = MagicMock()
|
||||
mock_recognizer.recognize_speech.return_value = "" # <— triggers the continue branch
|
||||
mock_best.return_value = mock_recognizer
|
||||
|
||||
agent = TranscriptionAgent("tcp://in")
|
||||
|
||||
# Make loop runnable
|
||||
agent._running = True
|
||||
agent.send = AsyncMock()
|
||||
|
||||
def close_coro(coro):
|
||||
coro.close()
|
||||
return MagicMock()
|
||||
|
||||
agent.add_behavior = MagicMock(side_effect=close_coro)
|
||||
|
||||
await agent.setup()
|
||||
|
||||
# Execute loop manually
|
||||
try:
|
||||
await agent._transcribing_loop()
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
# → Because of "continue", NO sending should occur
|
||||
agent.send.assert_not_called()
|
||||
|
||||
# → Continue was hit, so we must have read exactly 2 times:
|
||||
# - first audio
|
||||
# - second CancelledError
|
||||
assert mock_sub.recv.call_count == 2
|
||||
|
||||
# → recognizer was called once (first iteration)
|
||||
assert mock_recognizer.recognize_speech.call_count == 1
|
||||
152
test/unit/agents/perception/vad_agent/test_vad_agent_unit.py
Normal file
152
test/unit/agents/perception/vad_agent/test_vad_agent_unit.py
Normal file
@@ -0,0 +1,152 @@
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from control_backend.agents.perception.vad_agent import VADAgent
|
||||
from control_backend.core.agent_system import InternalMessage
|
||||
from control_backend.schemas.program_status import PROGRAM_STATUS, ProgramStatus
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_zmq():
|
||||
with patch("zmq.asyncio.Context") as mock:
|
||||
mock.instance.return_value = MagicMock()
|
||||
yield mock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def agent():
|
||||
return VADAgent("tcp://localhost:5555", False)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_message_pause(agent):
|
||||
agent._paused = MagicMock()
|
||||
# It starts set (not paused)
|
||||
|
||||
msg = InternalMessage(to="vad", sender="user_interrupt_agent", body="PAUSE")
|
||||
|
||||
# We need to mock settings to match sender name
|
||||
with patch("control_backend.agents.perception.vad_agent.settings") as mock_settings:
|
||||
mock_settings.agent_settings.user_interrupt_name = "user_interrupt_agent"
|
||||
|
||||
await agent.handle_message(msg)
|
||||
|
||||
agent._paused.clear.assert_called_once()
|
||||
assert agent._reset_needed is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_message_resume(agent):
|
||||
agent._paused = MagicMock()
|
||||
msg = InternalMessage(to="vad", sender="user_interrupt_agent", body="RESUME")
|
||||
|
||||
with patch("control_backend.agents.perception.vad_agent.settings") as mock_settings:
|
||||
mock_settings.agent_settings.user_interrupt_name = "user_interrupt_agent"
|
||||
|
||||
await agent.handle_message(msg)
|
||||
|
||||
agent._paused.set.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_message_unknown_command(agent):
|
||||
agent._paused = MagicMock()
|
||||
msg = InternalMessage(to="vad", sender="user_interrupt_agent", body="UNKNOWN")
|
||||
|
||||
with patch("control_backend.agents.perception.vad_agent.settings") as mock_settings:
|
||||
mock_settings.agent_settings.user_interrupt_name = "user_interrupt_agent"
|
||||
agent.logger = MagicMock()
|
||||
|
||||
await agent.handle_message(msg)
|
||||
|
||||
agent._paused.clear.assert_not_called()
|
||||
agent._paused.set.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_message_unknown_sender(agent):
|
||||
agent._paused = MagicMock()
|
||||
msg = InternalMessage(to="vad", sender="other_agent", body="PAUSE")
|
||||
|
||||
with patch("control_backend.agents.perception.vad_agent.settings") as mock_settings:
|
||||
mock_settings.agent_settings.user_interrupt_name = "user_interrupt_agent"
|
||||
|
||||
await agent.handle_message(msg)
|
||||
|
||||
agent._paused.clear.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_status_loop_waits_for_running(agent):
|
||||
agent._running = True
|
||||
agent.program_sub_socket = AsyncMock()
|
||||
agent.program_sub_socket.close = MagicMock()
|
||||
agent._reset_stream = AsyncMock()
|
||||
|
||||
# Sequence of messages:
|
||||
# 1. Wrong topic
|
||||
# 2. Right topic, wrong status (STARTING)
|
||||
# 3. Right topic, RUNNING -> Should break loop
|
||||
|
||||
agent.program_sub_socket.recv_multipart.side_effect = [
|
||||
(b"wrong_topic", b"whatever"),
|
||||
(PROGRAM_STATUS, ProgramStatus.STARTING.value),
|
||||
(PROGRAM_STATUS, ProgramStatus.RUNNING.value),
|
||||
]
|
||||
|
||||
await agent._status_loop()
|
||||
|
||||
assert agent._reset_stream.await_count == 1
|
||||
agent.program_sub_socket.close.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_setup_success(agent, mock_zmq):
|
||||
def close_coro(coro):
|
||||
coro.close()
|
||||
return MagicMock()
|
||||
|
||||
agent.add_behavior = MagicMock(side_effect=close_coro)
|
||||
|
||||
mock_context = mock_zmq.instance.return_value
|
||||
mock_sub = MagicMock()
|
||||
mock_pub = MagicMock()
|
||||
|
||||
# We expect multiple socket calls:
|
||||
# 1. audio_in (SUB)
|
||||
# 2. audio_out (PUB)
|
||||
# 3. program_sub (SUB)
|
||||
mock_context.socket.side_effect = [mock_sub, mock_pub, mock_sub]
|
||||
|
||||
with patch("control_backend.agents.perception.vad_agent.torch.hub.load") as mock_load:
|
||||
mock_load.return_value = (MagicMock(), None)
|
||||
|
||||
with patch("control_backend.agents.perception.vad_agent.TranscriptionAgent") as MockTrans:
|
||||
mock_trans_instance = MockTrans.return_value
|
||||
mock_trans_instance.start = AsyncMock()
|
||||
|
||||
await agent.setup()
|
||||
|
||||
mock_trans_instance.start.assert_awaited_once()
|
||||
|
||||
assert agent.add_behavior.call_count == 2 # streaming_loop + status_loop
|
||||
assert agent.audio_in_socket is not None
|
||||
assert agent.audio_out_socket is not None
|
||||
assert agent.program_sub_socket is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reset_stream(agent):
|
||||
mock_poller = MagicMock()
|
||||
agent.audio_in_poller = mock_poller
|
||||
|
||||
# poll(1) returns not None twice, then None
|
||||
mock_poller.poll = AsyncMock(side_effect=[b"data", b"data", None])
|
||||
|
||||
agent._ready = MagicMock()
|
||||
|
||||
await agent._reset_stream()
|
||||
|
||||
assert mock_poller.poll.await_count == 3
|
||||
agent._ready.set.assert_called_once()
|
||||
@@ -0,0 +1,46 @@
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
import zmq
|
||||
|
||||
from control_backend.agents.perception.vad_agent import SocketPoller
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def socket():
|
||||
return AsyncMock()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_socket_poller_with_data(socket, mocker):
|
||||
socket_data = b"test"
|
||||
socket.recv.return_value = socket_data
|
||||
|
||||
mock_poller: MagicMock = mocker.patch("control_backend.agents.perception.vad_agent.azmq.Poller")
|
||||
mock_poller.return_value.poll = AsyncMock(return_value=[(socket, zmq.POLLIN)])
|
||||
|
||||
poller = SocketPoller(socket)
|
||||
# Calling `poll` twice to be able to check that the poller is reused
|
||||
await poller.poll()
|
||||
data = await poller.poll()
|
||||
|
||||
assert data == socket_data
|
||||
|
||||
# Ensure that the poller was reused
|
||||
mock_poller.assert_called_once_with()
|
||||
mock_poller.return_value.register.assert_called_once_with(socket, zmq.POLLIN)
|
||||
|
||||
assert socket.recv.call_count == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_socket_poller_no_data(socket, mocker):
|
||||
mock_poller: MagicMock = mocker.patch("control_backend.agents.perception.vad_agent.azmq.Poller")
|
||||
mock_poller.return_value.poll = AsyncMock(return_value=[])
|
||||
|
||||
poller = SocketPoller(socket)
|
||||
data = await poller.poll()
|
||||
|
||||
assert data is None
|
||||
|
||||
socket.recv.assert_not_called()
|
||||
233
test/unit/agents/perception/vad_agent/test_vad_streaming.py
Normal file
233
test/unit/agents/perception/vad_agent/test_vad_streaming.py
Normal file
@@ -0,0 +1,233 @@
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import zmq
|
||||
|
||||
from control_backend.agents.perception.vad_agent import VADAgent
|
||||
from control_backend.core.config import settings
|
||||
|
||||
|
||||
# We don't want to use real ZMQ in unit tests, for example because it can give errors when sockets
|
||||
# aren't closed properly.
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_zmq():
|
||||
with patch("zmq.asyncio.Context") as mock:
|
||||
mock.instance.return_value = MagicMock()
|
||||
yield mock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def audio_out_socket():
|
||||
return AsyncMock()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def vad_agent(audio_out_socket):
|
||||
agent = VADAgent("tcp://localhost:5555", False)
|
||||
agent._internal_pub_socket = AsyncMock()
|
||||
return agent
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def patch_settings(monkeypatch):
|
||||
# Patch the settings that vad_agent.run() reads
|
||||
from control_backend.agents.perception import vad_agent
|
||||
|
||||
monkeypatch.setattr(
|
||||
vad_agent.settings.behaviour_settings, "vad_prob_threshold", 0.5, raising=False
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
vad_agent.settings.behaviour_settings, "vad_non_speech_patience_chunks", 2, raising=False
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
vad_agent.settings.behaviour_settings, "vad_initial_since_speech", 0, raising=False
|
||||
)
|
||||
monkeypatch.setattr(vad_agent.settings.vad_settings, "sample_rate_hz", 16_000, raising=False)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_experiment_logger():
|
||||
with patch("control_backend.agents.perception.vad_agent.experiment_logger") as logger:
|
||||
yield logger
|
||||
|
||||
|
||||
async def simulate_streaming_with_probabilities(streaming, probabilities: list[float]):
|
||||
"""
|
||||
Simulates a streaming scenario with given VAD model probabilities for testing purposes.
|
||||
|
||||
:param streaming: The streaming component to be tested.
|
||||
:param probabilities: A list of probabilities representing the outputs of the VAD model.
|
||||
"""
|
||||
model_item = MagicMock()
|
||||
model_item.item.side_effect = probabilities
|
||||
streaming.model = MagicMock(return_value=model_item)
|
||||
|
||||
# Prepare deterministic audio chunks and a poller that stops the loop when exhausted
|
||||
chunk_bytes = np.empty(shape=512, dtype=np.float32).tobytes()
|
||||
chunks = [chunk_bytes for _ in probabilities]
|
||||
|
||||
class DummyPoller:
|
||||
def __init__(self, data, agent):
|
||||
self.data = data
|
||||
self.agent = agent
|
||||
|
||||
async def poll(self, timeout_ms=None):
|
||||
if self.data:
|
||||
return self.data.pop(0)
|
||||
# Stop the loop cleanly once we've consumed all chunks
|
||||
self.agent._running = False
|
||||
return None
|
||||
|
||||
streaming.audio_in_poller = DummyPoller(chunks, streaming)
|
||||
streaming._ready = AsyncMock()
|
||||
streaming._running = True
|
||||
|
||||
await streaming._streaming_loop()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_voice_activity_detected(audio_out_socket, vad_agent):
|
||||
"""
|
||||
Test a scenario where there is voice activity detected between silences.
|
||||
"""
|
||||
speech_chunk_count = 5
|
||||
begin_silence_chunks = settings.behaviour_settings.vad_begin_silence_chunks
|
||||
probabilities = [0.0] * 15 + [1.0] * speech_chunk_count + [0.0] * 5
|
||||
vad_agent.audio_out_socket = audio_out_socket
|
||||
await simulate_streaming_with_probabilities(vad_agent, probabilities)
|
||||
|
||||
audio_out_socket.send.assert_called_once()
|
||||
data = audio_out_socket.send.call_args[0][0]
|
||||
assert isinstance(data, bytes)
|
||||
assert len(data) == 512 * 4 * (begin_silence_chunks + speech_chunk_count)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_voice_activity_short_pause(audio_out_socket, vad_agent):
|
||||
"""
|
||||
Test a scenario where there is a short pause between speech, checking whether it ignores the
|
||||
short pause.
|
||||
"""
|
||||
speech_chunk_count = 5
|
||||
begin_silence_chunks = settings.behaviour_settings.vad_begin_silence_chunks
|
||||
probabilities = (
|
||||
[0.0] * 15 + [1.0] * speech_chunk_count + [0.0] + [1.0] * speech_chunk_count + [0.0] * 5
|
||||
)
|
||||
vad_agent.audio_out_socket = audio_out_socket
|
||||
await simulate_streaming_with_probabilities(vad_agent, probabilities)
|
||||
|
||||
audio_out_socket.send.assert_called_once()
|
||||
data = audio_out_socket.send.call_args[0][0]
|
||||
assert isinstance(data, bytes)
|
||||
# Expecting 13 chunks (2*5 with speech, 1 pause between, begin_silence_chunks as padding)
|
||||
assert len(data) == 512 * 4 * (speech_chunk_count * 2 + 1 + begin_silence_chunks)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_data(audio_out_socket, vad_agent):
|
||||
"""
|
||||
Test a scenario where there is no data received. This should not cause errors.
|
||||
"""
|
||||
|
||||
class DummyPoller:
|
||||
async def poll(self, timeout_ms=None):
|
||||
vad_agent._running = False
|
||||
return None
|
||||
|
||||
vad_agent.audio_out_socket = audio_out_socket
|
||||
vad_agent.audio_in_poller = DummyPoller()
|
||||
vad_agent._ready = AsyncMock()
|
||||
vad_agent._running = True
|
||||
|
||||
await vad_agent._streaming_loop()
|
||||
|
||||
audio_out_socket.send.assert_not_called()
|
||||
assert len(vad_agent.audio_buffer) == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_streaming_loop_reset_needed(audio_out_socket, vad_agent):
|
||||
"""Test that _reset_needed branch works as expected."""
|
||||
vad_agent._reset_needed = True
|
||||
vad_agent._ready.set()
|
||||
vad_agent._paused.set()
|
||||
vad_agent._running = True
|
||||
vad_agent.audio_buffer = np.array([1.0], dtype=np.float32)
|
||||
vad_agent.i_since_speech = 0
|
||||
|
||||
# Mock _reset_stream to stop the loop by setting _running=False
|
||||
async def mock_reset():
|
||||
vad_agent._running = False
|
||||
|
||||
vad_agent._reset_stream = mock_reset
|
||||
|
||||
# Needs a poller to avoid AssertionError
|
||||
vad_agent.audio_in_poller = AsyncMock()
|
||||
vad_agent.audio_in_poller.poll.return_value = None
|
||||
|
||||
await vad_agent._streaming_loop()
|
||||
|
||||
assert vad_agent._reset_needed is False
|
||||
assert len(vad_agent.audio_buffer) == 0
|
||||
assert vad_agent.i_since_speech == settings.behaviour_settings.vad_initial_since_speech
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_streaming_loop_no_data_clears_buffer(audio_out_socket, vad_agent):
|
||||
"""Test that if poll returns None, buffer is cleared if not empty."""
|
||||
vad_agent.audio_buffer = np.array([1.0], dtype=np.float32)
|
||||
vad_agent._ready.set()
|
||||
vad_agent._paused.set()
|
||||
vad_agent._running = True
|
||||
|
||||
class MockPoller:
|
||||
async def poll(self, timeout_ms=None):
|
||||
vad_agent._running = False # stop after one poll
|
||||
return None
|
||||
|
||||
vad_agent.audio_in_poller = MockPoller()
|
||||
|
||||
await vad_agent._streaming_loop()
|
||||
|
||||
assert len(vad_agent.audio_buffer) == 0
|
||||
assert vad_agent.i_since_speech == settings.behaviour_settings.vad_initial_since_speech
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_vad_model_load_failure_stops_agent(vad_agent):
|
||||
"""
|
||||
Test that if loading the VAD model raises an Exception, it is caught,
|
||||
the agent logs an exception, stops itself, and setup returns.
|
||||
"""
|
||||
# Patch torch.hub.load to raise an exception
|
||||
with patch(
|
||||
"control_backend.agents.perception.vad_agent.torch.hub.load",
|
||||
side_effect=Exception("model fail"),
|
||||
):
|
||||
# Patch stop to an AsyncMock so we can check it was awaited
|
||||
vad_agent.stop = AsyncMock()
|
||||
|
||||
await vad_agent.setup()
|
||||
|
||||
# Assert stop was called
|
||||
vad_agent.stop.assert_awaited_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_audio_out_bind_failure_sets_none_and_logs(vad_agent, caplog):
|
||||
"""
|
||||
Test that if binding the output socket raises ZMQBindError,
|
||||
audio_out_socket is set to None, None is returned, and an error is logged.
|
||||
"""
|
||||
mock_socket = MagicMock()
|
||||
mock_socket.bind.side_effect = zmq.ZMQBindError()
|
||||
with patch("control_backend.agents.perception.vad_agent.azmq.Context.instance") as mock_ctx:
|
||||
mock_ctx.return_value.socket.return_value = mock_socket
|
||||
|
||||
with caplog.at_level("ERROR"):
|
||||
port = vad_agent._connect_audio_out_socket()
|
||||
|
||||
assert port is None
|
||||
assert vad_agent.audio_out_socket is None
|
||||
assert caplog.text is not None
|
||||
24
test/unit/agents/test_base.py
Normal file
24
test/unit/agents/test_base.py
Normal file
@@ -0,0 +1,24 @@
|
||||
import logging
|
||||
|
||||
from control_backend.agents.base import BaseAgent
|
||||
|
||||
|
||||
class MyAgent(BaseAgent):
|
||||
async def setup(self):
|
||||
pass
|
||||
|
||||
async def handle_message(self, msg):
|
||||
pass
|
||||
|
||||
|
||||
def test_base_agent_logger_init():
|
||||
# When defining a subclass, __init_subclass__ runs
|
||||
# The BaseAgent in agents/base.py sets the logger
|
||||
assert hasattr(MyAgent, "logger")
|
||||
assert isinstance(MyAgent.logger, logging.Logger)
|
||||
# The logger name depends on the package.
|
||||
# Since this test file is running as a module, __package__ might be None or the test package.
|
||||
# In 'src/control_backend/agents/base.py', it uses __package__ of base.py which is
|
||||
# 'control_backend.agents'.
|
||||
# So logger name should be control_backend.agents.MyAgent
|
||||
assert MyAgent.logger.name == "control_backend.agents.MyAgent"
|
||||
692
test/unit/agents/user_interrupt/test_user_interrupt.py
Normal file
692
test/unit/agents/user_interrupt/test_user_interrupt.py
Normal file
@@ -0,0 +1,692 @@
|
||||
import asyncio
|
||||
import json
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from control_backend.agents.user_interrupt.user_interrupt_agent import UserInterruptAgent
|
||||
from control_backend.core.agent_system import InternalMessage
|
||||
from control_backend.core.config import settings
|
||||
from control_backend.schemas.belief_message import BeliefMessage
|
||||
from control_backend.schemas.program import (
|
||||
ConditionalNorm,
|
||||
Goal,
|
||||
KeywordBelief,
|
||||
Phase,
|
||||
Plan,
|
||||
Program,
|
||||
Trigger,
|
||||
)
|
||||
from control_backend.schemas.ri_message import RIEndpoint
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def agent():
|
||||
agent = UserInterruptAgent(name="user_interrupt_agent")
|
||||
agent.send = AsyncMock()
|
||||
agent.logger = MagicMock()
|
||||
agent.sub_socket = AsyncMock()
|
||||
agent.pub_socket = AsyncMock()
|
||||
return agent
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_experiment_logger():
|
||||
with patch(
|
||||
"control_backend.agents.user_interrupt.user_interrupt_agent.experiment_logger"
|
||||
) as logger:
|
||||
yield logger
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_to_speech_agent(agent):
|
||||
"""Verify speech command format."""
|
||||
await agent._send_to_speech_agent("Hello World")
|
||||
|
||||
agent.send.assert_awaited_once()
|
||||
sent_msg: InternalMessage = agent.send.call_args.args[0]
|
||||
|
||||
assert sent_msg.to == settings.agent_settings.robot_speech_name
|
||||
body = json.loads(sent_msg.body)
|
||||
assert body["data"] == "Hello World"
|
||||
assert body["is_priority"] is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_to_gesture_agent(agent):
|
||||
"""Verify gesture command format."""
|
||||
await agent._send_to_gesture_agent("wave_hand")
|
||||
|
||||
agent.send.assert_awaited_once()
|
||||
sent_msg: InternalMessage = agent.send.call_args.args[0]
|
||||
|
||||
assert sent_msg.to == settings.agent_settings.robot_gesture_name
|
||||
body = json.loads(sent_msg.body)
|
||||
assert body["data"] == "wave_hand"
|
||||
assert body["is_priority"] is True
|
||||
assert body["endpoint"] == RIEndpoint.GESTURE_SINGLE.value
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_to_bdi_belief(agent):
|
||||
"""Verify belief update format."""
|
||||
context_str = "some_goal"
|
||||
|
||||
await agent._send_to_bdi_belief(context_str, "goal")
|
||||
|
||||
assert agent.send.await_count == 1
|
||||
sent_msg = agent.send.call_args.args[0]
|
||||
|
||||
assert sent_msg.to == settings.agent_settings.bdi_core_name
|
||||
assert sent_msg.thread == "beliefs"
|
||||
assert "achieved_some_goal" in sent_msg.body
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_receive_loop_routing_success(agent):
|
||||
"""
|
||||
Test that the loop correctly:
|
||||
1. Receives 'button_pressed' topic from ZMQ
|
||||
2. Parses the JSON payload to find 'type' and 'context'
|
||||
3. Calls the correct handler method based on 'type'
|
||||
"""
|
||||
# Prepare JSON payloads as bytes
|
||||
payload_speech = json.dumps({"type": "speech", "context": "Hello Speech"}).encode()
|
||||
payload_gesture = json.dumps({"type": "gesture", "context": "Hello Gesture"}).encode()
|
||||
# override calls _send_to_bdi (for trigger/norm) OR _send_to_bdi_belief (for goal).
|
||||
|
||||
# To test routing, we need to populate the maps
|
||||
agent._goal_map["Hello Override"] = "some_goal_slug"
|
||||
payload_override = json.dumps({"type": "override", "context": "Hello Override"}).encode()
|
||||
|
||||
agent.sub_socket.recv_multipart.side_effect = [
|
||||
(b"button_pressed", payload_speech),
|
||||
(b"button_pressed", payload_gesture),
|
||||
(b"button_pressed", payload_override),
|
||||
asyncio.CancelledError, # Stop the infinite loop
|
||||
]
|
||||
|
||||
agent._send_to_speech_agent = AsyncMock()
|
||||
agent._send_to_gesture_agent = AsyncMock()
|
||||
agent._send_to_bdi_belief = AsyncMock()
|
||||
|
||||
try:
|
||||
await agent._receive_button_event()
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
await asyncio.sleep(0)
|
||||
|
||||
# Speech
|
||||
agent._send_to_speech_agent.assert_awaited_once_with("Hello Speech")
|
||||
|
||||
# Gesture
|
||||
agent._send_to_gesture_agent.assert_awaited_once_with("Hello Gesture")
|
||||
|
||||
# Override (since we mapped it to a goal)
|
||||
agent._send_to_bdi_belief.assert_awaited_once_with("some_goal_slug", "goal")
|
||||
|
||||
assert agent._send_to_speech_agent.await_count == 1
|
||||
assert agent._send_to_gesture_agent.await_count == 1
|
||||
assert agent._send_to_bdi_belief.await_count == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_receive_loop_unknown_type(agent):
|
||||
"""Test that unknown 'type' values in the JSON log a warning and do not crash."""
|
||||
|
||||
# Prepare a payload with an unknown type
|
||||
payload_unknown = json.dumps({"type": "unknown_thing", "context": "some_data"}).encode()
|
||||
|
||||
agent.sub_socket.recv_multipart.side_effect = [
|
||||
(b"button_pressed", payload_unknown),
|
||||
asyncio.CancelledError,
|
||||
]
|
||||
|
||||
agent._send_to_speech_agent = AsyncMock()
|
||||
agent._send_to_gesture_agent = AsyncMock()
|
||||
|
||||
try:
|
||||
await agent._receive_button_event()
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
await asyncio.sleep(0)
|
||||
|
||||
# Ensure no handlers were called
|
||||
agent._send_to_speech_agent.assert_not_called()
|
||||
agent._send_to_gesture_agent.assert_not_called()
|
||||
|
||||
agent.logger.warning.assert_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_mapping(agent):
|
||||
# Create a program with a trigger, goal, and conditional norm
|
||||
import uuid
|
||||
|
||||
trigger_id = uuid.uuid4()
|
||||
goal_id = uuid.uuid4()
|
||||
norm_id = uuid.uuid4()
|
||||
|
||||
cond = KeywordBelief(id=uuid.uuid4(), name="k1", keyword="key")
|
||||
plan = Plan(id=uuid.uuid4(), name="p1", steps=[])
|
||||
|
||||
trigger = Trigger(id=trigger_id, name="my_trigger", condition=cond, plan=plan)
|
||||
goal = Goal(id=goal_id, name="my_goal", description="desc", plan=plan)
|
||||
|
||||
cn = ConditionalNorm(id=norm_id, name="my_norm", norm="be polite", condition=cond)
|
||||
|
||||
phase = Phase(id=uuid.uuid4(), name="phase1", norms=[cn], goals=[goal], triggers=[trigger])
|
||||
prog = Program(phases=[phase])
|
||||
|
||||
# Call create_mapping via handle_message
|
||||
msg = InternalMessage(to="me", thread="new_program", body=prog.model_dump_json())
|
||||
await agent.handle_message(msg)
|
||||
|
||||
# Check maps
|
||||
assert str(trigger_id) in agent._trigger_map
|
||||
assert agent._trigger_map[str(trigger_id)] == "trigger_my_trigger"
|
||||
|
||||
assert str(goal_id) in agent._goal_map
|
||||
assert agent._goal_map[str(goal_id)] == "my_goal"
|
||||
|
||||
assert str(norm_id) in agent._cond_norm_map
|
||||
assert agent._cond_norm_map[str(norm_id)] == "norm_be_polite"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_mapping_invalid_json(agent):
|
||||
# Pass invalid json to handle_message thread "new_program"
|
||||
msg = InternalMessage(to="me", thread="new_program", body="invalid json")
|
||||
await agent.handle_message(msg)
|
||||
|
||||
# Should log error and maps should remain empty or cleared
|
||||
agent.logger.error.assert_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_message_trigger_start(agent):
|
||||
# Setup reverse map manually
|
||||
agent._trigger_reverse_map["trigger_slug"] = "ui_id_123"
|
||||
|
||||
msg = InternalMessage(to="me", thread="trigger_start", body="trigger_slug")
|
||||
await agent.handle_message(msg)
|
||||
|
||||
agent.pub_socket.send_multipart.assert_awaited_once()
|
||||
args = agent.pub_socket.send_multipart.call_args[0][0]
|
||||
assert args[0] == b"experiment"
|
||||
payload = json.loads(args[1])
|
||||
assert payload["type"] == "trigger_update"
|
||||
assert payload["id"] == "ui_id_123"
|
||||
assert payload["achieved"] is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_message_trigger_end(agent):
|
||||
agent._trigger_reverse_map["trigger_slug"] = "ui_id_123"
|
||||
|
||||
msg = InternalMessage(to="me", thread="trigger_end", body="trigger_slug")
|
||||
await agent.handle_message(msg)
|
||||
|
||||
agent.pub_socket.send_multipart.assert_awaited_once()
|
||||
payload = json.loads(agent.pub_socket.send_multipart.call_args[0][0][1])
|
||||
assert payload["type"] == "trigger_update"
|
||||
assert payload["achieved"] is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_message_transition_phase(agent):
|
||||
msg = InternalMessage(to="me", thread="transition_phase", body="phase_id_123")
|
||||
await agent.handle_message(msg)
|
||||
|
||||
agent.pub_socket.send_multipart.assert_awaited_once()
|
||||
payload = json.loads(agent.pub_socket.send_multipart.call_args[0][0][1])
|
||||
assert payload["type"] == "phase_update"
|
||||
assert payload["id"] == "phase_id_123"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_message_goal_start(agent):
|
||||
agent._goal_reverse_map["goal_slug"] = "goal_id_123"
|
||||
|
||||
msg = InternalMessage(to="me", thread="goal_start", body="goal_slug")
|
||||
await agent.handle_message(msg)
|
||||
|
||||
agent.pub_socket.send_multipart.assert_awaited_once()
|
||||
payload = json.loads(agent.pub_socket.send_multipart.call_args[0][0][1])
|
||||
assert payload["type"] == "goal_update"
|
||||
assert payload["id"] == "goal_id_123"
|
||||
assert payload["active"] is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_message_active_norms_update(agent):
|
||||
agent._cond_norm_reverse_map["norm_active"] = "id_1"
|
||||
agent._cond_norm_reverse_map["norm_inactive"] = "id_2"
|
||||
|
||||
# Body is like: "('norm_active', 'other')"
|
||||
# The split logic handles quotes etc.
|
||||
msg = InternalMessage(to="me", thread="active_norms_update", body="'norm_active', 'other'")
|
||||
await agent.handle_message(msg)
|
||||
|
||||
agent.pub_socket.send_multipart.assert_awaited_once()
|
||||
payload = json.loads(agent.pub_socket.send_multipart.call_args[0][0][1])
|
||||
assert payload["type"] == "cond_norms_state_update"
|
||||
norms = {n["id"]: n["active"] for n in payload["norms"]}
|
||||
assert norms["id_1"] is True
|
||||
assert norms["id_2"] is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_experiment_control(agent):
|
||||
# Test next_phase
|
||||
await agent._send_experiment_control_to_bdi_core("next_phase")
|
||||
agent.send.assert_awaited()
|
||||
msg = agent.send.call_args[0][0]
|
||||
assert msg.thread == "force_next_phase"
|
||||
|
||||
# Test reset_phase
|
||||
await agent._send_experiment_control_to_bdi_core("reset_phase")
|
||||
msg = agent.send.call_args[0][0]
|
||||
assert msg.thread == "reset_current_phase"
|
||||
|
||||
# Test reset_experiment
|
||||
await agent._send_experiment_control_to_bdi_core("reset_experiment")
|
||||
msg = agent.send.call_args[0][0]
|
||||
assert msg.thread == "reset_experiment"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_pause_command(agent):
|
||||
await agent._send_pause_command("true")
|
||||
# Sends to RI and VAD
|
||||
assert agent.send.await_count == 2
|
||||
msgs = [call.args[0] for call in agent.send.call_args_list]
|
||||
|
||||
ri_msg = next(m for m in msgs if m.to == settings.agent_settings.ri_communication_name)
|
||||
assert json.loads(ri_msg.body)["endpoint"] == "" # PAUSE endpoint
|
||||
assert json.loads(ri_msg.body)["data"] is True
|
||||
|
||||
vad_msg = next(m for m in msgs if m.to == settings.agent_settings.vad_name)
|
||||
assert vad_msg.body == "PAUSE"
|
||||
|
||||
agent.send.reset_mock()
|
||||
await agent._send_pause_command("false")
|
||||
assert agent.send.await_count == 2
|
||||
vad_msg = next(
|
||||
m for m in agent.send.call_args_list if m.args[0].to == settings.agent_settings.vad_name
|
||||
).args[0]
|
||||
assert vad_msg.body == "RESUME"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_setup(agent):
|
||||
"""Test the setup method initializes sockets correctly."""
|
||||
with patch("control_backend.agents.user_interrupt.user_interrupt_agent.Context") as MockContext:
|
||||
mock_ctx_instance = MagicMock()
|
||||
MockContext.instance.return_value = mock_ctx_instance
|
||||
|
||||
mock_sub = MagicMock()
|
||||
mock_pub = MagicMock()
|
||||
mock_ctx_instance.socket.side_effect = [mock_sub, mock_pub]
|
||||
|
||||
# MOCK add_behavior so we don't rely on internal attributes
|
||||
agent.add_behavior = MagicMock()
|
||||
|
||||
await agent.setup()
|
||||
|
||||
# Check sockets
|
||||
mock_sub.connect.assert_called_with(settings.zmq_settings.internal_sub_address)
|
||||
mock_pub.connect.assert_called_with(settings.zmq_settings.internal_pub_address)
|
||||
|
||||
# Verify add_behavior was called
|
||||
agent.add_behavior.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_receive_loop_json_error(agent):
|
||||
"""Verify that malformed JSON is caught and logged without crashing the loop."""
|
||||
agent.sub_socket.recv_multipart.side_effect = [
|
||||
(b"topic", b"INVALID{JSON"),
|
||||
asyncio.CancelledError,
|
||||
]
|
||||
|
||||
try:
|
||||
await agent._receive_button_event()
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
agent.logger.error.assert_called_with("Received invalid JSON payload on topic %s", b"topic")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_receive_loop_override_trigger(agent):
|
||||
"""Verify routing 'override' to a Trigger."""
|
||||
agent._trigger_map["101"] = "trigger_slug"
|
||||
payload = json.dumps({"type": "override", "context": "101"}).encode()
|
||||
|
||||
agent.sub_socket.recv_multipart.side_effect = [(b"topic", payload), asyncio.CancelledError]
|
||||
agent._send_to_bdi = AsyncMock()
|
||||
|
||||
try:
|
||||
await agent._receive_button_event()
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
agent._send_to_bdi.assert_awaited_once_with("force_trigger", "trigger_slug")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_receive_loop_override_norm(agent):
|
||||
"""Verify routing 'override' to a Conditional Norm."""
|
||||
agent._cond_norm_map["202"] = "norm_slug"
|
||||
payload = json.dumps({"type": "override", "context": "202"}).encode()
|
||||
|
||||
agent.sub_socket.recv_multipart.side_effect = [(b"topic", payload), asyncio.CancelledError]
|
||||
agent._send_to_bdi_belief = AsyncMock()
|
||||
|
||||
try:
|
||||
await agent._receive_button_event()
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
agent._send_to_bdi_belief.assert_awaited_once_with("norm_slug", "cond_norm")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_receive_loop_override_missing(agent):
|
||||
"""Verify warning log when an override ID is not found in any map."""
|
||||
payload = json.dumps({"type": "override", "context": "999"}).encode()
|
||||
|
||||
agent.sub_socket.recv_multipart.side_effect = [(b"topic", payload), asyncio.CancelledError]
|
||||
|
||||
try:
|
||||
await agent._receive_button_event()
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
agent.logger.warning.assert_called_with("Could not determine which element to override.")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_receive_loop_unachieve_logic(agent):
|
||||
"""Verify success and failure paths for override_unachieve."""
|
||||
agent._cond_norm_map["202"] = "norm_slug"
|
||||
|
||||
success_payload = json.dumps({"type": "override_unachieve", "context": "202"}).encode()
|
||||
fail_payload = json.dumps({"type": "override_unachieve", "context": "999"}).encode()
|
||||
|
||||
agent.sub_socket.recv_multipart.side_effect = [
|
||||
(b"topic", success_payload),
|
||||
(b"topic", fail_payload),
|
||||
asyncio.CancelledError,
|
||||
]
|
||||
agent._send_to_bdi_belief = AsyncMock()
|
||||
|
||||
try:
|
||||
await agent._receive_button_event()
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
# Assert success call (True flag for unachieve)
|
||||
agent._send_to_bdi_belief.assert_any_call("norm_slug", "cond_norm", True)
|
||||
# Assert failure log
|
||||
agent.logger.warning.assert_called_with(
|
||||
"Could not determine which conditional norm to unachieve."
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_receive_loop_pause_resume(agent):
|
||||
"""Verify pause and resume toggle logic and logging."""
|
||||
pause_payload = json.dumps({"type": "pause", "context": "true"}).encode()
|
||||
resume_payload = json.dumps({"type": "pause", "context": ""}).encode()
|
||||
|
||||
agent.sub_socket.recv_multipart.side_effect = [
|
||||
(b"topic", pause_payload),
|
||||
(b"topic", resume_payload),
|
||||
asyncio.CancelledError,
|
||||
]
|
||||
agent._send_pause_command = AsyncMock()
|
||||
|
||||
try:
|
||||
await agent._receive_button_event()
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
agent._send_pause_command.assert_any_call("true")
|
||||
agent._send_pause_command.assert_any_call("")
|
||||
agent.logger.info.assert_any_call("Sent pause command.")
|
||||
agent.logger.info.assert_any_call("Sent resume command.")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_receive_loop_phase_control(agent):
|
||||
"""Verify experiment flow control (next_phase)."""
|
||||
payload = json.dumps({"type": "next_phase", "context": ""}).encode()
|
||||
|
||||
agent.sub_socket.recv_multipart.side_effect = [(b"topic", payload), asyncio.CancelledError]
|
||||
agent._send_experiment_control_to_bdi_core = AsyncMock()
|
||||
|
||||
try:
|
||||
await agent._receive_button_event()
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
agent._send_experiment_control_to_bdi_core.assert_awaited_once_with("next_phase")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_message_unknown_thread(agent):
|
||||
"""Test handling of an unknown message thread (lines 213-214)."""
|
||||
msg = InternalMessage(to="me", thread="unknown_thread", body="test")
|
||||
await agent.handle_message(msg)
|
||||
|
||||
agent.logger.debug.assert_called_with(
|
||||
"Received internal message on unhandled thread: unknown_thread"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_to_bdi_belief_edge_cases(agent):
|
||||
"""
|
||||
Covers:
|
||||
- Unknown asl_type warning (lines 326-328)
|
||||
- unachieve=True logic (lines 334-337)
|
||||
"""
|
||||
# 1. Unknown Type
|
||||
await agent._send_to_bdi_belief("slug", "unknown_type")
|
||||
|
||||
agent.logger.warning.assert_called_with("Tried to send belief with unknown type")
|
||||
agent.send.assert_not_called()
|
||||
|
||||
# Reset mock for part 2
|
||||
agent.send.reset_mock()
|
||||
|
||||
# 2. Unachieve = True
|
||||
await agent._send_to_bdi_belief("slug", "cond_norm", unachieve=True)
|
||||
|
||||
agent.send.assert_awaited()
|
||||
sent_msg = agent.send.call_args.args[0]
|
||||
|
||||
# Verify it is a delete operation
|
||||
body_obj = BeliefMessage.model_validate_json(sent_msg.body)
|
||||
|
||||
# Verify 'delete' has content
|
||||
assert body_obj.delete is not None
|
||||
assert len(body_obj.delete) == 1
|
||||
assert body_obj.delete[0].name == "force_slug"
|
||||
|
||||
# Verify 'create' is empty (handling both None and [])
|
||||
assert not body_obj.create
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_experiment_control_unknown(agent):
|
||||
"""Test sending an unknown experiment control type (lines 366-367)."""
|
||||
await agent._send_experiment_control_to_bdi_core("invalid_command")
|
||||
|
||||
agent.logger.warning.assert_called_with(
|
||||
"Received unknown experiment control type '%s' to send to BDI Core.", "invalid_command"
|
||||
)
|
||||
|
||||
# Ensure it still sends an empty message (as per code logic, though thread is empty)
|
||||
agent.send.assert_awaited()
|
||||
msg = agent.send.call_args[0][0]
|
||||
assert msg.thread == ""
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_mapping_recursive_goals(agent):
|
||||
"""Verify that nested subgoals are correctly registered in the mapping."""
|
||||
import uuid
|
||||
|
||||
# 1. Setup IDs
|
||||
parent_goal_id = uuid.uuid4()
|
||||
child_goal_id = uuid.uuid4()
|
||||
|
||||
# 2. Create the child goal
|
||||
child_goal = Goal(
|
||||
id=child_goal_id,
|
||||
name="child_goal",
|
||||
description="I am a subgoal",
|
||||
plan=Plan(id=uuid.uuid4(), name="p_child", steps=[]),
|
||||
)
|
||||
|
||||
# 3. Create the parent goal and put the child goal inside its plan steps
|
||||
parent_goal = Goal(
|
||||
id=parent_goal_id,
|
||||
name="parent_goal",
|
||||
description="I am a parent",
|
||||
plan=Plan(id=uuid.uuid4(), name="p_parent", steps=[child_goal]), # Nested here
|
||||
)
|
||||
|
||||
# 4. Build the program
|
||||
phase = Phase(
|
||||
id=uuid.uuid4(),
|
||||
name="phase1",
|
||||
norms=[],
|
||||
goals=[parent_goal], # Only the parent is top-level
|
||||
triggers=[],
|
||||
)
|
||||
prog = Program(phases=[phase])
|
||||
|
||||
# 5. Execute mapping
|
||||
msg = InternalMessage(to="me", thread="new_program", body=prog.model_dump_json())
|
||||
await agent.handle_message(msg)
|
||||
|
||||
# 6. Assertions
|
||||
# Check parent
|
||||
assert str(parent_goal_id) in agent._goal_map
|
||||
assert agent._goal_map[str(parent_goal_id)] == "parent_goal"
|
||||
|
||||
# Check child (This confirms the recursion worked)
|
||||
assert str(child_goal_id) in agent._goal_map
|
||||
assert agent._goal_map[str(child_goal_id)] == "child_goal"
|
||||
assert agent._goal_reverse_map["child_goal"] == str(child_goal_id)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_receive_loop_advanced_scenarios(agent):
|
||||
"""
|
||||
Covers:
|
||||
- JSONDecodeError (lines 86-88)
|
||||
- Override: Trigger found (lines 108-109)
|
||||
- Override: Norm found (lines 114-115)
|
||||
- Override: Nothing found (line 134)
|
||||
- Override Unachieve: Success & Fail (lines 136-145)
|
||||
- Pause: Context true/false logs (lines 150-157)
|
||||
- Next Phase (line 160)
|
||||
"""
|
||||
# 1. Setup Data Maps
|
||||
agent._trigger_map["101"] = "trigger_slug"
|
||||
agent._cond_norm_map["202"] = "norm_slug"
|
||||
|
||||
# 2. Define Payloads
|
||||
# A. Invalid JSON
|
||||
bad_json = b"INVALID{JSON"
|
||||
|
||||
# B. Override -> Trigger
|
||||
override_trigger = json.dumps({"type": "override", "context": "101"}).encode()
|
||||
|
||||
# C. Override -> Norm
|
||||
override_norm = json.dumps({"type": "override", "context": "202"}).encode()
|
||||
|
||||
# D. Override -> Unknown
|
||||
override_fail = json.dumps({"type": "override", "context": "999"}).encode()
|
||||
|
||||
# E. Unachieve -> Success
|
||||
unachieve_success = json.dumps({"type": "override_unachieve", "context": "202"}).encode()
|
||||
|
||||
# F. Unachieve -> Fail
|
||||
unachieve_fail = json.dumps({"type": "override_unachieve", "context": "999"}).encode()
|
||||
|
||||
# G. Pause (True)
|
||||
pause_true = json.dumps({"type": "pause", "context": "true"}).encode()
|
||||
|
||||
# H. Pause (False/Resume)
|
||||
pause_false = json.dumps({"type": "pause", "context": ""}).encode()
|
||||
|
||||
# I. Next Phase
|
||||
next_phase = json.dumps({"type": "next_phase", "context": ""}).encode()
|
||||
|
||||
# 3. Setup Socket
|
||||
agent.sub_socket.recv_multipart.side_effect = [
|
||||
(b"topic", bad_json),
|
||||
(b"topic", override_trigger),
|
||||
(b"topic", override_norm),
|
||||
(b"topic", override_fail),
|
||||
(b"topic", unachieve_success),
|
||||
(b"topic", unachieve_fail),
|
||||
(b"topic", pause_true),
|
||||
(b"topic", pause_false),
|
||||
(b"topic", next_phase),
|
||||
asyncio.CancelledError, # End loop
|
||||
]
|
||||
|
||||
# Mock internal helpers to verify calls
|
||||
agent._send_to_bdi = AsyncMock()
|
||||
agent._send_to_bdi_belief = AsyncMock()
|
||||
agent._send_pause_command = AsyncMock()
|
||||
agent._send_experiment_control_to_bdi_core = AsyncMock()
|
||||
|
||||
# 4. Run Loop
|
||||
try:
|
||||
await agent._receive_button_event()
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
# 5. Assertions
|
||||
|
||||
# JSON Error
|
||||
agent.logger.error.assert_called_with("Received invalid JSON payload on topic %s", b"topic")
|
||||
|
||||
# Override Trigger
|
||||
agent._send_to_bdi.assert_awaited_with("force_trigger", "trigger_slug")
|
||||
|
||||
# Override Norm
|
||||
# We expect _send_to_bdi_belief to be called for the norm
|
||||
# Note: The loop calls _send_to_bdi_belief(asl_cond_norm, "cond_norm")
|
||||
agent._send_to_bdi_belief.assert_any_call("norm_slug", "cond_norm")
|
||||
|
||||
# Override Fail (Warning log)
|
||||
agent.logger.warning.assert_any_call("Could not determine which element to override.")
|
||||
|
||||
# Unachieve Success
|
||||
# Loop calls _send_to_bdi_belief(asl_cond_norm, "cond_norm", True)
|
||||
agent._send_to_bdi_belief.assert_any_call("norm_slug", "cond_norm", True)
|
||||
|
||||
# Unachieve Fail
|
||||
agent.logger.warning.assert_any_call("Could not determine which conditional norm to unachieve.")
|
||||
|
||||
# Pause Logic
|
||||
agent._send_pause_command.assert_any_call("true")
|
||||
agent.logger.info.assert_any_call("Sent pause command.")
|
||||
|
||||
# Resume Logic
|
||||
agent._send_pause_command.assert_any_call("")
|
||||
agent.logger.info.assert_any_call("Sent resume command.")
|
||||
|
||||
# Next Phase
|
||||
agent._send_experiment_control_to_bdi_core.assert_awaited_with("next_phase")
|
||||
127
test/unit/api/v1/endpoints/test_logs_endpoint.py
Normal file
127
test/unit/api/v1/endpoints/test_logs_endpoint.py
Normal file
@@ -0,0 +1,127 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi.testclient import TestClient
|
||||
from starlette.responses import StreamingResponse
|
||||
|
||||
from control_backend.api.v1.endpoints import logs
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client():
|
||||
"""TestClient with logs router included."""
|
||||
app = FastAPI()
|
||||
app.include_router(logs.router)
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_log_stream_endpoint_lines(client):
|
||||
"""Call /logs/stream with a mocked ZMQ socket to cover all lines."""
|
||||
|
||||
# Dummy socket to mock ZMQ behavior
|
||||
class DummySocket:
|
||||
def __init__(self):
|
||||
self.subscribed = []
|
||||
self.connected = False
|
||||
self.recv_count = 0
|
||||
|
||||
def subscribe(self, topic):
|
||||
self.subscribed.append(topic)
|
||||
|
||||
def connect(self, addr):
|
||||
self.connected = True
|
||||
|
||||
async def recv_multipart(self):
|
||||
# Return one message, then stop generator
|
||||
if self.recv_count == 0:
|
||||
self.recv_count += 1
|
||||
return (b"INFO", b"test message")
|
||||
else:
|
||||
raise StopAsyncIteration
|
||||
|
||||
dummy_socket = DummySocket()
|
||||
|
||||
# Patch Context.instance().socket() to return dummy socket
|
||||
with patch("control_backend.api.v1.endpoints.logs.Context.instance") as mock_context:
|
||||
mock_context.return_value.socket.return_value = dummy_socket
|
||||
|
||||
# Call the endpoint directly
|
||||
response = await logs.log_stream()
|
||||
assert isinstance(response, StreamingResponse)
|
||||
|
||||
# Fetch one chunk from the generator
|
||||
gen = response.body_iterator
|
||||
chunk = await gen.__anext__()
|
||||
if isinstance(chunk, bytes):
|
||||
chunk = chunk.decode("utf-8")
|
||||
assert "data:" in chunk
|
||||
|
||||
# Optional: assert subscribe/connect were called
|
||||
assert dummy_socket.subscribed # at least some log levels subscribed
|
||||
assert dummy_socket.connected # connect was called
|
||||
|
||||
|
||||
@patch("control_backend.api.v1.endpoints.logs.LOGGING_DIR")
|
||||
def test_files_endpoint(LOGGING_DIR, client):
|
||||
file_1, file_2 = MagicMock(), MagicMock()
|
||||
file_1.name = "file_1"
|
||||
file_2.name = "file_2"
|
||||
LOGGING_DIR.glob.return_value = [file_1, file_2]
|
||||
result = client.get("/api/logs/files")
|
||||
|
||||
assert result.status_code == 200
|
||||
assert result.json() == ["file_1", "file_2"]
|
||||
|
||||
|
||||
@patch("control_backend.api.v1.endpoints.logs.FileResponse")
|
||||
@patch("control_backend.api.v1.endpoints.logs.LOGGING_DIR")
|
||||
def test_log_file_endpoint_success(LOGGING_DIR, MockFileResponse, client):
|
||||
mock_file_path = MagicMock()
|
||||
mock_file_path.is_relative_to.return_value = True
|
||||
mock_file_path.is_file.return_value = True
|
||||
mock_file_path.name = "test.log"
|
||||
|
||||
LOGGING_DIR.__truediv__ = MagicMock(return_value=mock_file_path)
|
||||
mock_file_path.resolve.return_value = mock_file_path
|
||||
|
||||
MockFileResponse.return_value = MagicMock()
|
||||
|
||||
result = client.get("/api/logs/files/test.log")
|
||||
|
||||
assert result.status_code == 200
|
||||
MockFileResponse.assert_called_once_with(mock_file_path, filename="test.log")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("control_backend.api.v1.endpoints.logs.LOGGING_DIR")
|
||||
async def test_log_file_endpoint_path_traversal(LOGGING_DIR):
|
||||
from control_backend.api.v1.endpoints.logs import log_file
|
||||
|
||||
mock_file_path = MagicMock()
|
||||
mock_file_path.is_relative_to.return_value = False
|
||||
|
||||
LOGGING_DIR.__truediv__ = MagicMock(return_value=mock_file_path)
|
||||
mock_file_path.resolve.return_value = mock_file_path
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await log_file("../secret.txt")
|
||||
|
||||
assert exc_info.value.status_code == 400
|
||||
assert exc_info.value.detail == "Invalid filename."
|
||||
|
||||
|
||||
@patch("control_backend.api.v1.endpoints.logs.LOGGING_DIR")
|
||||
def test_log_file_endpoint_file_not_found(LOGGING_DIR, client):
|
||||
mock_file_path = MagicMock()
|
||||
mock_file_path.is_relative_to.return_value = True
|
||||
mock_file_path.is_file.return_value = False
|
||||
|
||||
LOGGING_DIR.__truediv__ = MagicMock(return_value=mock_file_path)
|
||||
mock_file_path.resolve.return_value = mock_file_path
|
||||
|
||||
result = client.get("/api/logs/files/nonexistent.log")
|
||||
|
||||
assert result.status_code == 404
|
||||
assert result.json()["detail"] == "File not found."
|
||||
45
test/unit/api/v1/endpoints/test_message_endpoint.py
Normal file
45
test/unit/api/v1/endpoints/test_message_endpoint.py
Normal file
@@ -0,0 +1,45 @@
|
||||
import json
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from control_backend.api.v1.endpoints import message
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client():
|
||||
"""FastAPI TestClient for the message router."""
|
||||
from fastapi import FastAPI
|
||||
|
||||
app = FastAPI()
|
||||
app.include_router(message.router)
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
def test_receive_message_post(client, monkeypatch):
|
||||
"""Test POST /message endpoint sends message to pub socket."""
|
||||
|
||||
# Dummy pub socket to capture sent messages
|
||||
class DummyPubSocket:
|
||||
def __init__(self):
|
||||
self.sent = []
|
||||
|
||||
async def send_multipart(self, msg):
|
||||
self.sent.append(msg)
|
||||
|
||||
dummy_socket = DummyPubSocket()
|
||||
|
||||
# Patch app.state.endpoints_pub_socket
|
||||
client.app.state.endpoints_pub_socket = dummy_socket
|
||||
|
||||
data = {"message": "Hello world"}
|
||||
response = client.post("/message", json=data)
|
||||
|
||||
assert response.status_code == 202
|
||||
assert response.json() == {"status": "Message received"}
|
||||
|
||||
# Ensure the message was sent via pub_socket
|
||||
assert len(dummy_socket.sent) == 1
|
||||
topic, body = dummy_socket.sent[0]
|
||||
parsed = json.loads(body.decode("utf-8"))
|
||||
assert parsed["message"] == "Hello world"
|
||||
137
test/unit/api/v1/endpoints/test_program_endpoint.py
Normal file
137
test/unit/api/v1/endpoints/test_program_endpoint.py
Normal file
@@ -0,0 +1,137 @@
|
||||
import json
|
||||
import uuid
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from control_backend.api.v1.endpoints import program
|
||||
from control_backend.schemas.program import BasicNorm, Goal, Phase, Plan, Program
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app():
|
||||
"""Create a FastAPI app with the /program route and mock socket."""
|
||||
app = FastAPI()
|
||||
app.include_router(program.router)
|
||||
return app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client(app):
|
||||
"""Create a TestClient."""
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
def make_valid_program_dict():
|
||||
"""Helper to create a valid Program JSON structure."""
|
||||
# Converting to JSON using Pydantic because it knows how to convert a UUID object
|
||||
program_json_str = Program(
|
||||
phases=[
|
||||
Phase(
|
||||
id=uuid.uuid4(),
|
||||
name="Basic Phase",
|
||||
norms=[
|
||||
BasicNorm(
|
||||
id=uuid.uuid4(),
|
||||
name="Some norm",
|
||||
norm="Do normal.",
|
||||
),
|
||||
],
|
||||
goals=[
|
||||
Goal(
|
||||
id=uuid.uuid4(),
|
||||
name="Some goal",
|
||||
description="This description can be used to determine whether the goal "
|
||||
"has been achieved.",
|
||||
plan=Plan(
|
||||
id=uuid.uuid4(),
|
||||
name="Goal Plan",
|
||||
steps=[],
|
||||
),
|
||||
can_fail=False,
|
||||
),
|
||||
],
|
||||
triggers=[],
|
||||
),
|
||||
],
|
||||
).model_dump_json()
|
||||
# Converting back to a dict because that's what's expected
|
||||
return json.loads(program_json_str)
|
||||
|
||||
|
||||
def test_receive_program_success(client):
|
||||
"""Valid Program JSON should be parsed and sent through the socket."""
|
||||
mock_pub_socket = AsyncMock()
|
||||
client.app.state.endpoints_pub_socket = mock_pub_socket
|
||||
|
||||
program_dict = make_valid_program_dict()
|
||||
|
||||
response = client.post("/program", json=program_dict)
|
||||
|
||||
assert response.status_code == 202
|
||||
assert response.json() == {"status": "Program parsed"}
|
||||
|
||||
# Verify socket call
|
||||
mock_pub_socket.send_multipart.assert_awaited_once()
|
||||
args, kwargs = mock_pub_socket.send_multipart.await_args
|
||||
|
||||
assert args[0][0] == b"program"
|
||||
|
||||
sent_bytes = args[0][1]
|
||||
sent_obj = json.loads(sent_bytes.decode())
|
||||
|
||||
# Converting to JSON using Pydantic because it knows how to handle UUIDs
|
||||
expected_obj = json.loads(Program.model_validate(program_dict).model_dump_json())
|
||||
assert sent_obj == expected_obj
|
||||
|
||||
|
||||
def test_receive_program_invalid_json(client):
|
||||
"""
|
||||
Invalid JSON (malformed) -> FastAPI never calls endpoint.
|
||||
It returns a 422 Unprocessable Entity.
|
||||
"""
|
||||
mock_pub_socket = AsyncMock()
|
||||
client.app.state.endpoints_pub_socket = mock_pub_socket
|
||||
|
||||
# FastAPI only accepts valid JSON bodies, so send raw string
|
||||
response = client.post("/program", content="{invalid json}")
|
||||
|
||||
assert response.status_code == 422
|
||||
mock_pub_socket.send_multipart.assert_not_called()
|
||||
|
||||
|
||||
def test_receive_program_invalid_deep_structure(client):
|
||||
"""
|
||||
Valid JSON but schema invalid -> Pydantic throws validation error -> 422.
|
||||
"""
|
||||
mock_pub_socket = AsyncMock()
|
||||
client.app.state.endpoints_pub_socket = mock_pub_socket
|
||||
|
||||
# Missing "value" in norms element
|
||||
bad_program = {
|
||||
"phases": [
|
||||
{
|
||||
"id": "phase1",
|
||||
"name": "deepfail",
|
||||
"nextPhaseId": "phase2",
|
||||
"phaseData": {
|
||||
"norms": [
|
||||
{"id": "n1", "name": "norm"} # INVALID: missing "value"
|
||||
],
|
||||
"goals": [
|
||||
{"id": "g1", "name": "goal", "description": "desc", "achieved": False}
|
||||
],
|
||||
"triggers": [
|
||||
{"id": "t1", "label": "trigger", "type": "keyword", "value": ["start"]}
|
||||
],
|
||||
},
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
response = client.post("/program", json=bad_program)
|
||||
|
||||
assert response.status_code == 422
|
||||
mock_pub_socket.send_multipart.assert_not_called()
|
||||
415
test/unit/api/v1/endpoints/test_robot_endpoint.py
Normal file
415
test/unit/api/v1/endpoints/test_robot_endpoint.py
Normal file
@@ -0,0 +1,415 @@
|
||||
# tests/test_robot_endpoints.py
|
||||
import json
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
import zmq.asyncio
|
||||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from control_backend.api.v1.endpoints import robot
|
||||
from control_backend.schemas.ri_message import GestureCommand, SpeechCommand
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app():
|
||||
"""
|
||||
Creates a FastAPI test app and attaches the router under test.
|
||||
Also sets up a mock internal_comm_socket.
|
||||
"""
|
||||
app = FastAPI()
|
||||
app.include_router(robot.router)
|
||||
return app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client(app):
|
||||
"""Create a test client for the app."""
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_zmq_context():
|
||||
"""Mock the ZMQ context used by the endpoint module."""
|
||||
with patch("control_backend.api.v1.endpoints.robot.Context.instance") as mock_context:
|
||||
context_instance = MagicMock()
|
||||
mock_context.return_value = context_instance
|
||||
yield context_instance
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_sockets(mock_zmq_context):
|
||||
"""Optional helper if you want both a sub and req/push socket available."""
|
||||
mock_sub_socket = AsyncMock(spec=zmq.asyncio.Socket)
|
||||
mock_req_socket = AsyncMock(spec=zmq.asyncio.Socket)
|
||||
|
||||
mock_zmq_context.socket.return_value = mock_sub_socket
|
||||
|
||||
return {"sub": mock_sub_socket, "req": mock_req_socket}
|
||||
|
||||
|
||||
def test_receive_speech_command_success(client):
|
||||
"""
|
||||
Test for successful reception of a command. Ensures the status code is 202 and the response body
|
||||
is correct. It also verifies that the ZeroMQ socket's send_multipart method is called with the
|
||||
expected data.
|
||||
"""
|
||||
# Arrange
|
||||
mock_pub_socket = AsyncMock()
|
||||
client.app.state.endpoints_pub_socket = mock_pub_socket
|
||||
|
||||
command_data = {"endpoint": "actuate/speech", "data": "This is a test"}
|
||||
speech_command = SpeechCommand(**command_data)
|
||||
|
||||
# Act
|
||||
response = client.post("/command/speech", json=command_data)
|
||||
|
||||
# Assert
|
||||
assert response.status_code == 202
|
||||
assert response.json() == {"status": "Speech command received"}
|
||||
|
||||
# Verify that the ZMQ socket was used correctly
|
||||
mock_pub_socket.send_multipart.assert_awaited_once_with(
|
||||
[b"command", speech_command.model_dump_json().encode()]
|
||||
)
|
||||
|
||||
|
||||
def test_receive_gesture_command_success(client):
|
||||
"""
|
||||
Test for successful reception of a command that is a gesture command.
|
||||
Ensures the status code is 202 and the response body is correct.
|
||||
"""
|
||||
# Arrange
|
||||
mock_pub_socket = AsyncMock()
|
||||
client.app.state.endpoints_pub_socket = mock_pub_socket
|
||||
|
||||
command_data = {"endpoint": "actuate/gesture/tag", "data": "happy"}
|
||||
gesture_command = GestureCommand(**command_data)
|
||||
|
||||
# Act
|
||||
response = client.post("/command/gesture", json=command_data)
|
||||
|
||||
# Assert
|
||||
assert response.status_code == 202
|
||||
assert response.json() == {"status": "Gesture command received"}
|
||||
|
||||
# Verify that the ZMQ socket was used correctly
|
||||
mock_pub_socket.send_multipart.assert_awaited_once_with(
|
||||
[b"command", gesture_command.model_dump_json().encode()]
|
||||
)
|
||||
|
||||
|
||||
def test_receive_speech_command_invalid_payload(client):
|
||||
"""
|
||||
Test invalid data handling (schema validation).
|
||||
"""
|
||||
# Missing required field(s)
|
||||
bad_payload = {"invalid": "data"}
|
||||
response = client.post("/command/speech", json=bad_payload)
|
||||
assert response.status_code == 422 # validation error
|
||||
|
||||
|
||||
def test_receive_gesture_command_invalid_payload(client):
|
||||
"""
|
||||
Test invalid data handling (schema validation).
|
||||
"""
|
||||
# Missing required field(s)
|
||||
bad_payload = {"invalid": "data"}
|
||||
response = client.post("/command/gesture", json=bad_payload)
|
||||
assert response.status_code == 422 # validation error
|
||||
|
||||
|
||||
def test_ping_check_returns_none(client):
|
||||
"""Ensure /ping_check returns 200 and None (currently unimplemented)."""
|
||||
response = client.get("/ping_check")
|
||||
assert response.status_code == 200
|
||||
assert response.json() is None
|
||||
|
||||
|
||||
# ----------------------------
|
||||
# ping_stream tests (unchanged behavior)
|
||||
# ----------------------------
|
||||
@pytest.mark.asyncio
|
||||
async def test_ping_stream_yields_ping_event(monkeypatch):
|
||||
"""Test that ping_stream yields a proper SSE message when a ping is received."""
|
||||
mock_sub_socket = AsyncMock()
|
||||
mock_sub_socket.connect = MagicMock()
|
||||
mock_sub_socket.setsockopt = MagicMock()
|
||||
mock_sub_socket.recv_multipart = AsyncMock(return_value=[b"ping", b"true"])
|
||||
|
||||
mock_context = MagicMock()
|
||||
mock_context.socket.return_value = mock_sub_socket
|
||||
monkeypatch.setattr(robot.Context, "instance", lambda: mock_context)
|
||||
|
||||
# patch settings address used by ping_stream
|
||||
mock_settings = MagicMock()
|
||||
mock_settings.zmq_settings.internal_sub_address = "tcp://localhost:5555"
|
||||
monkeypatch.setattr(robot, "settings", mock_settings)
|
||||
|
||||
mock_request = AsyncMock()
|
||||
mock_request.is_disconnected = AsyncMock(side_effect=[False, True])
|
||||
|
||||
response = await robot.ping_stream(mock_request)
|
||||
generator = aiter(response.body_iterator)
|
||||
|
||||
event = await anext(generator)
|
||||
event_text = event.decode() if isinstance(event, bytes) else str(event)
|
||||
assert event_text.strip() == "data: true"
|
||||
|
||||
with pytest.raises(StopAsyncIteration):
|
||||
await anext(generator)
|
||||
|
||||
mock_sub_socket.connect.assert_called_once_with("tcp://localhost:5555")
|
||||
mock_sub_socket.setsockopt.assert_called_once_with(robot.zmq.SUBSCRIBE, b"ping")
|
||||
mock_sub_socket.recv_multipart.assert_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ping_stream_handles_timeout(monkeypatch):
|
||||
"""Test that ping_stream continues looping on TimeoutError."""
|
||||
mock_sub_socket = AsyncMock()
|
||||
mock_sub_socket.connect = MagicMock()
|
||||
mock_sub_socket.setsockopt = MagicMock()
|
||||
mock_sub_socket.recv_multipart.side_effect = TimeoutError()
|
||||
|
||||
mock_context = MagicMock()
|
||||
mock_context.socket.return_value = mock_sub_socket
|
||||
monkeypatch.setattr(robot.Context, "instance", lambda: mock_context)
|
||||
|
||||
mock_settings = MagicMock()
|
||||
mock_settings.zmq_settings.internal_sub_address = "tcp://localhost:5555"
|
||||
monkeypatch.setattr(robot, "settings", mock_settings)
|
||||
|
||||
mock_request = AsyncMock()
|
||||
mock_request.is_disconnected = AsyncMock(return_value=True)
|
||||
|
||||
response = await robot.ping_stream(mock_request)
|
||||
generator = aiter(response.body_iterator)
|
||||
|
||||
with pytest.raises(StopAsyncIteration):
|
||||
await anext(generator)
|
||||
|
||||
mock_sub_socket.connect.assert_called_once_with("tcp://localhost:5555")
|
||||
mock_sub_socket.setsockopt.assert_called_once_with(robot.zmq.SUBSCRIBE, b"ping")
|
||||
mock_sub_socket.recv_multipart.assert_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ping_stream_yields_json_values(monkeypatch):
|
||||
"""Ensure ping_stream correctly parses and yields JSON body values."""
|
||||
mock_sub_socket = AsyncMock()
|
||||
mock_sub_socket.connect = MagicMock()
|
||||
mock_sub_socket.setsockopt = MagicMock()
|
||||
mock_sub_socket.recv_multipart = AsyncMock(
|
||||
return_value=[b"ping", json.dumps({"connected": True}).encode()]
|
||||
)
|
||||
|
||||
mock_context = MagicMock()
|
||||
mock_context.socket.return_value = mock_sub_socket
|
||||
monkeypatch.setattr(robot.Context, "instance", lambda: mock_context)
|
||||
|
||||
mock_settings = MagicMock()
|
||||
mock_settings.zmq_settings.internal_sub_address = "tcp://localhost:5555"
|
||||
monkeypatch.setattr(robot, "settings", mock_settings)
|
||||
|
||||
mock_request = AsyncMock()
|
||||
mock_request.is_disconnected = AsyncMock(side_effect=[False, True])
|
||||
|
||||
response = await robot.ping_stream(mock_request)
|
||||
generator = aiter(response.body_iterator)
|
||||
|
||||
event = await anext(generator)
|
||||
event_text = event.decode() if isinstance(event, bytes) else str(event)
|
||||
|
||||
assert "connected" in event_text
|
||||
assert "true" in event_text
|
||||
|
||||
mock_sub_socket.connect.assert_called_once_with("tcp://localhost:5555")
|
||||
mock_sub_socket.setsockopt.assert_called_once_with(robot.zmq.SUBSCRIBE, b"ping")
|
||||
mock_sub_socket.recv_multipart.assert_awaited()
|
||||
|
||||
|
||||
# ----------------------------
|
||||
# Updated get_available_gesture_tags tests (REQ socket on tcp://localhost:7788)
|
||||
# ----------------------------
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_available_gesture_tags_success(client, monkeypatch):
|
||||
"""
|
||||
Test successful retrieval of available gesture tags using a REQ socket.
|
||||
"""
|
||||
# Arrange
|
||||
mock_req_socket = AsyncMock(spec=zmq.asyncio.Socket)
|
||||
mock_req_socket.connect = MagicMock()
|
||||
mock_req_socket.send = AsyncMock()
|
||||
response_data = {"tags": ["wave", "nod", "point", "dance"]}
|
||||
mock_req_socket.recv = AsyncMock(return_value=json.dumps(response_data).encode())
|
||||
|
||||
mock_context = MagicMock()
|
||||
mock_context.socket.return_value = mock_req_socket
|
||||
monkeypatch.setattr(robot.Context, "instance", lambda: mock_context)
|
||||
|
||||
# Replace logger methods to avoid noisy logs in tests
|
||||
monkeypatch.setattr(robot.logger, "debug", MagicMock())
|
||||
monkeypatch.setattr(robot.logger, "error", MagicMock())
|
||||
|
||||
# Act
|
||||
response = client.get("/commands/gesture/tags")
|
||||
|
||||
# Assert
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"available_gesture_tags": ["wave", "nod", "point", "dance"]}
|
||||
|
||||
# Verify ZeroMQ REQ interactions
|
||||
mock_req_socket.connect.assert_called_once_with("tcp://localhost:7788")
|
||||
mock_req_socket.send.assert_awaited_once_with(b"None")
|
||||
mock_req_socket.recv.assert_awaited_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_available_gesture_tags_with_amount(client, monkeypatch):
|
||||
"""
|
||||
The endpoint currently ignores the 'amount' TODO, so behavior is the same as 'success'.
|
||||
This test asserts that the endpoint still sends b"None" and returns the tags.
|
||||
"""
|
||||
# Arrange
|
||||
mock_req_socket = AsyncMock(spec=zmq.asyncio.Socket)
|
||||
mock_req_socket.connect = MagicMock()
|
||||
mock_req_socket.send = AsyncMock()
|
||||
response_data = {"tags": ["wave", "nod"]}
|
||||
mock_req_socket.recv = AsyncMock(return_value=json.dumps(response_data).encode())
|
||||
|
||||
mock_context = MagicMock()
|
||||
mock_context.socket.return_value = mock_req_socket
|
||||
monkeypatch.setattr(robot.Context, "instance", lambda: mock_context)
|
||||
|
||||
monkeypatch.setattr(robot.logger, "debug", MagicMock())
|
||||
monkeypatch.setattr(robot.logger, "error", MagicMock())
|
||||
|
||||
# Act
|
||||
response = client.get("/commands/gesture/tags")
|
||||
|
||||
# Assert
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"available_gesture_tags": ["wave", "nod"]}
|
||||
|
||||
mock_req_socket.connect.assert_called_once_with("tcp://localhost:7788")
|
||||
mock_req_socket.send.assert_awaited_once_with(b"None")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_available_gesture_tags_timeout(client, monkeypatch):
|
||||
"""
|
||||
Test timeout scenario when fetching gesture tags. Endpoint should handle TimeoutError
|
||||
and return an empty list while logging the timeout.
|
||||
"""
|
||||
# Arrange
|
||||
mock_req_socket = AsyncMock(spec=zmq.asyncio.Socket)
|
||||
mock_req_socket.connect = MagicMock()
|
||||
mock_req_socket.send = AsyncMock()
|
||||
mock_req_socket.recv = AsyncMock(side_effect=TimeoutError)
|
||||
|
||||
mock_context = MagicMock()
|
||||
mock_context.socket.return_value = mock_req_socket
|
||||
monkeypatch.setattr(robot.Context, "instance", lambda: mock_context)
|
||||
|
||||
# Patch logger.debug so we can assert it was called with the expected message
|
||||
mock_debug = MagicMock()
|
||||
monkeypatch.setattr(robot.logger, "debug", mock_debug)
|
||||
monkeypatch.setattr(robot.logger, "error", MagicMock())
|
||||
|
||||
# Act
|
||||
response = client.get("/commands/gesture/tags")
|
||||
|
||||
# Assert
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"available_gesture_tags": []}
|
||||
|
||||
# Verify the timeout was logged using the exact string from the endpoint code
|
||||
mock_debug.assert_called_once_with("Got timeout error fetching gestures.")
|
||||
|
||||
mock_req_socket.connect.assert_called_once_with("tcp://localhost:7788")
|
||||
mock_req_socket.send.assert_awaited_once_with(b"None")
|
||||
mock_req_socket.recv.assert_awaited_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_available_gesture_tags_empty_response(client, monkeypatch):
|
||||
"""
|
||||
Test scenario when response contains an empty 'tags' list.
|
||||
"""
|
||||
# Arrange
|
||||
mock_req_socket = AsyncMock(spec=zmq.asyncio.Socket)
|
||||
mock_req_socket.connect = MagicMock()
|
||||
mock_req_socket.send = AsyncMock()
|
||||
response_data = {"tags": []}
|
||||
mock_req_socket.recv = AsyncMock(return_value=json.dumps(response_data).encode())
|
||||
|
||||
mock_context = MagicMock()
|
||||
mock_context.socket.return_value = mock_req_socket
|
||||
monkeypatch.setattr(robot.Context, "instance", lambda: mock_context)
|
||||
|
||||
monkeypatch.setattr(robot.logger, "debug", MagicMock())
|
||||
monkeypatch.setattr(robot.logger, "error", MagicMock())
|
||||
|
||||
# Act
|
||||
response = client.get("/commands/gesture/tags")
|
||||
|
||||
# Assert
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"available_gesture_tags": []}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_available_gesture_tags_missing_tags_key(client, monkeypatch):
|
||||
"""
|
||||
Test scenario when response JSON doesn't contain 'tags' key.
|
||||
"""
|
||||
# Arrange
|
||||
mock_req_socket = AsyncMock(spec=zmq.asyncio.Socket)
|
||||
mock_req_socket.connect = MagicMock()
|
||||
mock_req_socket.send = AsyncMock()
|
||||
response_data = {"some_other_key": "value"}
|
||||
mock_req_socket.recv = AsyncMock(return_value=json.dumps(response_data).encode())
|
||||
|
||||
mock_context = MagicMock()
|
||||
mock_context.socket.return_value = mock_req_socket
|
||||
monkeypatch.setattr(robot.Context, "instance", lambda: mock_context)
|
||||
|
||||
monkeypatch.setattr(robot.logger, "debug", MagicMock())
|
||||
monkeypatch.setattr(robot.logger, "error", MagicMock())
|
||||
|
||||
# Act
|
||||
response = client.get("/commands/gesture/tags")
|
||||
|
||||
# Assert
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"available_gesture_tags": []}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_available_gesture_tags_invalid_json(client, monkeypatch):
|
||||
"""
|
||||
Test scenario when response contains invalid JSON. Endpoint should log the error
|
||||
and return an empty list.
|
||||
"""
|
||||
# Arrange
|
||||
mock_req_socket = AsyncMock(spec=zmq.asyncio.Socket)
|
||||
mock_req_socket.connect = MagicMock()
|
||||
mock_req_socket.send = AsyncMock()
|
||||
mock_req_socket.recv = AsyncMock(return_value=b"invalid json")
|
||||
|
||||
mock_context = MagicMock()
|
||||
mock_context.socket.return_value = mock_req_socket
|
||||
monkeypatch.setattr(robot.Context, "instance", lambda: mock_context)
|
||||
|
||||
mock_error = MagicMock()
|
||||
monkeypatch.setattr(robot.logger, "error", mock_error)
|
||||
monkeypatch.setattr(robot.logger, "debug", MagicMock())
|
||||
|
||||
# Act
|
||||
response = client.get("/commands/gesture/tags")
|
||||
|
||||
# Assert - invalid JSON should lead to empty list and error log invocation
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"available_gesture_tags": []}
|
||||
assert mock_error.call_count == 1
|
||||
15
test/unit/api/v1/endpoints/test_router.py
Normal file
15
test/unit/api/v1/endpoints/test_router.py
Normal file
@@ -0,0 +1,15 @@
|
||||
from fastapi.routing import APIRoute
|
||||
|
||||
from control_backend.api.v1.router import api_router # <--- corrected import
|
||||
|
||||
|
||||
def test_router_includes_expected_paths():
|
||||
"""Ensure api_router includes main router prefixes."""
|
||||
routes = [r for r in api_router.routes if isinstance(r, APIRoute)]
|
||||
paths = [r.path for r in routes]
|
||||
|
||||
# Ensure at least one route under each prefix exists
|
||||
assert any(p.startswith("/robot") for p in paths)
|
||||
assert any(p.startswith("/message") for p in paths)
|
||||
assert any(p.startswith("/logs") for p in paths)
|
||||
assert any(p.startswith("/program") for p in paths)
|
||||
148
test/unit/api/v1/endpoints/test_user_interact.py
Normal file
148
test/unit/api/v1/endpoints/test_user_interact.py
Normal file
@@ -0,0 +1,148 @@
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from control_backend.api.v1.endpoints import user_interact
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app():
|
||||
app = FastAPI()
|
||||
app.include_router(user_interact.router)
|
||||
return app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client(app):
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_receive_button_event(client):
|
||||
mock_pub_socket = AsyncMock()
|
||||
client.app.state.endpoints_pub_socket = mock_pub_socket
|
||||
|
||||
payload = {"type": "speech", "context": "hello"}
|
||||
response = client.post("/button_pressed", json=payload)
|
||||
|
||||
assert response.status_code == 202
|
||||
assert response.json() == {"status": "Event received"}
|
||||
|
||||
mock_pub_socket.send_multipart.assert_awaited_once()
|
||||
args = mock_pub_socket.send_multipart.call_args[0][0]
|
||||
assert args[0] == b"button_pressed"
|
||||
assert "speech" in args[1].decode()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_receive_button_event_invalid_payload(client):
|
||||
mock_pub_socket = AsyncMock()
|
||||
client.app.state.endpoints_pub_socket = mock_pub_socket
|
||||
|
||||
# Missing context
|
||||
payload = {"type": "speech"}
|
||||
response = client.post("/button_pressed", json=payload)
|
||||
|
||||
assert response.status_code == 422
|
||||
mock_pub_socket.send_multipart.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_experiment_stream_direct_call():
|
||||
"""
|
||||
Directly calling the endpoint function to test the streaming logic
|
||||
without dealing with TestClient streaming limitations.
|
||||
"""
|
||||
mock_socket = AsyncMock()
|
||||
# 1. recv data
|
||||
# 2. recv timeout
|
||||
# 3. disconnect (request.is_disconnected returns True)
|
||||
mock_socket.recv_multipart.side_effect = [
|
||||
(b"topic", b"message1"),
|
||||
TimeoutError(),
|
||||
(b"topic", b"message2"), # Should not be reached if disconnect checks work
|
||||
]
|
||||
mock_socket.close = MagicMock()
|
||||
mock_socket.connect = MagicMock()
|
||||
mock_socket.subscribe = MagicMock()
|
||||
|
||||
mock_context = MagicMock()
|
||||
mock_context.socket.return_value = mock_socket
|
||||
|
||||
with patch(
|
||||
"control_backend.api.v1.endpoints.user_interact.Context.instance", return_value=mock_context
|
||||
):
|
||||
mock_request = AsyncMock()
|
||||
# is_disconnected sequence:
|
||||
# 1. False (before first recv) -> reads message1
|
||||
# 2. False (before second recv) -> triggers TimeoutError, continues
|
||||
# 3. True (before third recv) -> break loop
|
||||
mock_request.is_disconnected.side_effect = [False, False, True]
|
||||
|
||||
response = await user_interact.experiment_stream(mock_request)
|
||||
|
||||
lines = []
|
||||
# Consume the generator
|
||||
async for line in response.body_iterator:
|
||||
lines.append(line)
|
||||
|
||||
assert "data: message1\n\n" in lines
|
||||
assert len(lines) == 1
|
||||
|
||||
mock_socket.connect.assert_called()
|
||||
mock_socket.subscribe.assert_called_with(b"experiment")
|
||||
mock_socket.close.assert_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_status_stream_direct_call():
|
||||
"""
|
||||
Test the status stream, ensuring it handles messages and sends pings on timeout.
|
||||
"""
|
||||
mock_socket = AsyncMock()
|
||||
|
||||
# Define the sequence of events for the socket:
|
||||
# 1. Successfully receive a message
|
||||
# 2. Timeout (which should trigger the ': ping' yield)
|
||||
# 3. Another message (which won't be reached because we'll simulate disconnect)
|
||||
mock_socket.recv_multipart.side_effect = [
|
||||
(b"topic", b"status_update"),
|
||||
TimeoutError(),
|
||||
(b"topic", b"ignored_msg"),
|
||||
]
|
||||
|
||||
mock_socket.close = MagicMock()
|
||||
mock_socket.connect = MagicMock()
|
||||
mock_socket.subscribe = MagicMock()
|
||||
|
||||
mock_context = MagicMock()
|
||||
mock_context.socket.return_value = mock_socket
|
||||
|
||||
# Mock the ZMQ Context to return our mock_socket
|
||||
with patch(
|
||||
"control_backend.api.v1.endpoints.user_interact.Context.instance", return_value=mock_context
|
||||
):
|
||||
mock_request = AsyncMock()
|
||||
|
||||
# is_disconnected sequence:
|
||||
# 1. False -> Process "status_update"
|
||||
# 2. False -> Process TimeoutError (yield ping)
|
||||
# 3. True -> Break loop (client disconnected)
|
||||
mock_request.is_disconnected.side_effect = [False, False, True]
|
||||
|
||||
# Call the status_stream function explicitly
|
||||
response = await user_interact.status_stream(mock_request)
|
||||
|
||||
lines = []
|
||||
async for line in response.body_iterator:
|
||||
lines.append(line)
|
||||
|
||||
# Assertions
|
||||
assert "data: status_update\n\n" in lines
|
||||
assert ": ping\n\n" in lines # Verify lines 91-92 (ping logic)
|
||||
|
||||
mock_socket.connect.assert_called()
|
||||
mock_socket.subscribe.assert_called_with(b"status")
|
||||
mock_socket.close.assert_called()
|
||||
43
test/unit/conftest.py
Normal file
43
test/unit/conftest.py
Normal file
@@ -0,0 +1,43 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from control_backend.core.agent_system import _agent_directory
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_agent_directory():
|
||||
"""
|
||||
Automatically clears the global agent directory before and after each test
|
||||
to prevent state leakage between tests.
|
||||
"""
|
||||
_agent_directory.clear()
|
||||
yield
|
||||
_agent_directory.clear()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_settings():
|
||||
with patch("control_backend.core.config.settings") as mock:
|
||||
# Set default values that match the pydantic model defaults
|
||||
# to avoid AttributeErrors during tests
|
||||
mock.zmq_settings.internal_pub_address = "tcp://localhost:5560"
|
||||
mock.zmq_settings.internal_sub_address = "tcp://localhost:5561"
|
||||
mock.zmq_settings.ri_command_address = "tcp://localhost:0000"
|
||||
mock.agent_settings.bdi_core_name = "bdi_core_agent"
|
||||
mock.agent_settings.llm_name = "llm_agent"
|
||||
mock.agent_settings.robot_speech_name = "robot_speech_agent"
|
||||
mock.agent_settings.transcription_name = "transcription_agent"
|
||||
mock.agent_settings.text_belief_extractor_name = "text_belief_extractor_agent"
|
||||
mock.agent_settings.vad_name = "vad_agent"
|
||||
mock.behaviour_settings.sleep_s = 0.01 # Speed up tests
|
||||
mock.behaviour_settings.comm_setup_max_retries = 1
|
||||
mock.behaviour_settings.agentspeak_file = "src/control_backend/agents/bdi/agentspeak.asl"
|
||||
yield mock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_zmq_context():
|
||||
with patch("zmq.asyncio.Context") as mock:
|
||||
mock.instance.return_value = MagicMock()
|
||||
yield mock
|
||||
274
test/unit/core/test_agent_system.py
Normal file
274
test/unit/core/test_agent_system.py
Normal file
@@ -0,0 +1,274 @@
|
||||
"""Test the base class logic, message passing and background task handling."""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from control_backend.core.agent_system import AgentDirectory, BaseAgent, InternalMessage
|
||||
|
||||
|
||||
class ConcreteTestAgent(BaseAgent):
|
||||
logger = logging.getLogger("test")
|
||||
|
||||
def __init__(self, name: str):
|
||||
super().__init__(name)
|
||||
self.received = []
|
||||
|
||||
async def setup(self):
|
||||
pass
|
||||
|
||||
async def handle_message(self, msg: InternalMessage):
|
||||
self.received.append(msg)
|
||||
if msg.body == "stop":
|
||||
await self.stop()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_lifecycle():
|
||||
agent = ConcreteTestAgent("lifecycle_agent")
|
||||
await agent.start()
|
||||
assert agent._running is True
|
||||
|
||||
# Test background task
|
||||
async def dummy_task():
|
||||
pass
|
||||
|
||||
task = agent.add_behavior(dummy_task())
|
||||
assert task in agent._tasks
|
||||
|
||||
await task
|
||||
|
||||
# Wait for task to finish
|
||||
assert task not in agent._tasks
|
||||
assert len(agent._tasks) == 2 # message handling tasks are still running
|
||||
|
||||
await agent.stop()
|
||||
assert agent._running is False
|
||||
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
# Tasks should be cancelled
|
||||
assert len(agent._tasks) == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_unknown_agent():
|
||||
agent = ConcreteTestAgent("sender")
|
||||
msg = InternalMessage(to="unknown_receiver", sender="sender", body="boo")
|
||||
|
||||
agent._internal_pub_socket = AsyncMock()
|
||||
|
||||
await agent.send(msg)
|
||||
|
||||
agent._internal_pub_socket.send_multipart.assert_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_agent():
|
||||
agent = ConcreteTestAgent("registrant")
|
||||
assert AgentDirectory.get("registrant") == agent
|
||||
assert AgentDirectory.get("non_existent") is None
|
||||
|
||||
|
||||
class DummyAgent(BaseAgent):
|
||||
async def setup(self):
|
||||
pass # we will test this separately
|
||||
|
||||
async def handle_message(self, msg: InternalMessage):
|
||||
self.last_handled = msg
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_base_agent_setup_is_noop():
|
||||
agent = DummyAgent("dummy")
|
||||
|
||||
# Should simply return without error
|
||||
assert await agent.setup() is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_to_local_agent(monkeypatch):
|
||||
sender = DummyAgent("sender")
|
||||
target = DummyAgent("receiver")
|
||||
|
||||
# Fake logger
|
||||
sender.logger = MagicMock()
|
||||
|
||||
# Patch inbox.put
|
||||
target.inbox.put = AsyncMock()
|
||||
|
||||
message = InternalMessage(to=target.name, sender=sender.name, body="hello")
|
||||
|
||||
await sender.send(message)
|
||||
|
||||
target.inbox.put.assert_awaited_once_with(message)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_to_zmq_agent(monkeypatch):
|
||||
sender = DummyAgent("sender")
|
||||
target = "remote_receiver"
|
||||
|
||||
# Fake logger
|
||||
sender.logger = MagicMock()
|
||||
|
||||
# Fake zmq
|
||||
sender._internal_pub_socket = AsyncMock()
|
||||
|
||||
message = InternalMessage(to=target, sender=sender.name, body="hello")
|
||||
|
||||
await sender.send(message)
|
||||
|
||||
zmq_calls = sender._internal_pub_socket.send_multipart.call_args[0][0]
|
||||
assert zmq_calls[0] == f"internal/{target}".encode()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_to_multiple_local_agents(monkeypatch):
|
||||
sender = DummyAgent("sender")
|
||||
target1 = DummyAgent("receiver1")
|
||||
target2 = DummyAgent("receiver2")
|
||||
|
||||
# Fake logger
|
||||
sender.logger = MagicMock()
|
||||
|
||||
# Patch inbox.put
|
||||
target1.inbox.put = AsyncMock()
|
||||
target2.inbox.put = AsyncMock()
|
||||
|
||||
message = InternalMessage(to=[target1.name, target2.name], sender=sender.name, body="hello")
|
||||
|
||||
await sender.send(message)
|
||||
|
||||
target1.inbox.put.assert_awaited_once_with(message)
|
||||
target2.inbox.put.assert_awaited_once_with(message)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_to_multiple_agents(monkeypatch):
|
||||
sender = DummyAgent("sender")
|
||||
target1 = DummyAgent("receiver1")
|
||||
target2 = "remote_receiver"
|
||||
|
||||
# Fake logger
|
||||
sender.logger = MagicMock()
|
||||
|
||||
# Fake zmq
|
||||
sender._internal_pub_socket = AsyncMock()
|
||||
|
||||
# Patch inbox.put
|
||||
target1.inbox.put = AsyncMock()
|
||||
|
||||
message = InternalMessage(to=[target1.name, target2], sender=sender.name, body="hello")
|
||||
|
||||
await sender.send(message)
|
||||
|
||||
target1.inbox.put.assert_awaited_once_with(message)
|
||||
zmq_calls = sender._internal_pub_socket.send_multipart.call_args[0][0]
|
||||
assert zmq_calls[0] == f"internal/{target2}".encode()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_inbox_calls_handle_message(monkeypatch):
|
||||
agent = DummyAgent("dummy")
|
||||
agent.logger = MagicMock()
|
||||
|
||||
# Make agent running so loop triggers
|
||||
agent._running = True
|
||||
|
||||
# Prepare inbox to give one message then stop
|
||||
msg = InternalMessage(to="dummy", sender="x", body="test")
|
||||
|
||||
async def get_once():
|
||||
agent._running = False # stop after first iteration
|
||||
return msg
|
||||
|
||||
agent.inbox.get = AsyncMock(side_effect=get_once)
|
||||
agent.handle_message = AsyncMock()
|
||||
|
||||
await agent._process_inbox()
|
||||
|
||||
agent.handle_message.assert_awaited_once_with(msg)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_receive_internal_zmq_loop_success(monkeypatch):
|
||||
agent = DummyAgent("dummy")
|
||||
agent.logger = MagicMock()
|
||||
agent._running = True
|
||||
|
||||
mock_socket = MagicMock()
|
||||
mock_socket.recv_multipart = AsyncMock(
|
||||
side_effect=[
|
||||
(
|
||||
b"topic",
|
||||
InternalMessage(to="dummy", sender="x", body="hi").model_dump_json().encode(),
|
||||
),
|
||||
asyncio.CancelledError(), # stop loop
|
||||
]
|
||||
)
|
||||
agent._internal_sub_socket = mock_socket
|
||||
|
||||
agent.inbox.put = AsyncMock()
|
||||
|
||||
await agent._receive_internal_zmq_loop()
|
||||
|
||||
agent.inbox.put.assert_awaited() # message forwarded
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_receive_internal_zmq_loop_exception_logs_error():
|
||||
agent = DummyAgent("dummy")
|
||||
agent.logger = MagicMock()
|
||||
agent._running = True
|
||||
|
||||
mock_socket = MagicMock()
|
||||
mock_socket.recv_multipart = AsyncMock(
|
||||
side_effect=[Exception("boom"), asyncio.CancelledError()]
|
||||
)
|
||||
agent._internal_sub_socket = mock_socket
|
||||
|
||||
agent.inbox.put = AsyncMock()
|
||||
|
||||
await agent._receive_internal_zmq_loop()
|
||||
|
||||
agent.logger.exception.assert_called_once()
|
||||
assert "Could not process ZMQ message." in agent.logger.exception.call_args[0][0]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_base_agent_handle_message_not_implemented():
|
||||
class RawAgent(BaseAgent):
|
||||
async def setup(self):
|
||||
pass
|
||||
|
||||
agent = RawAgent("raw")
|
||||
|
||||
msg = InternalMessage(to="raw", sender="x", body="hi")
|
||||
|
||||
with pytest.raises(NotImplementedError):
|
||||
await BaseAgent.handle_message(agent, msg)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_base_agent_setup_abstract_method_body_executes():
|
||||
"""
|
||||
Covers the 'pass' inside BaseAgent.setup().
|
||||
Since BaseAgent is abstract, we do NOT instantiate it.
|
||||
We call the coroutine function directly on BaseAgent and pass a dummy self.
|
||||
"""
|
||||
|
||||
class Dummy:
|
||||
"""Minimal stub to act as 'self'."""
|
||||
|
||||
pass
|
||||
|
||||
stub = Dummy()
|
||||
|
||||
# Call BaseAgent.setup() as an unbound coroutine, passing stub as 'self'
|
||||
result = await BaseAgent.setup(stub)
|
||||
|
||||
# The method contains only 'pass', so it returns None
|
||||
assert result is None
|
||||
14
test/unit/core/test_config.py
Normal file
14
test/unit/core/test_config.py
Normal file
@@ -0,0 +1,14 @@
|
||||
"""Test if settings load correctly and environment variables override defaults."""
|
||||
|
||||
from control_backend.core.config import Settings
|
||||
|
||||
|
||||
def test_default_settings():
|
||||
settings = Settings()
|
||||
assert settings.app_title == "PepperPlus"
|
||||
|
||||
|
||||
def test_env_override(monkeypatch):
|
||||
monkeypatch.setenv("APP_TITLE", "TestPepper")
|
||||
settings = Settings()
|
||||
assert settings.app_title == "TestPepper"
|
||||
119
test/unit/core/test_logging.py
Normal file
119
test/unit/core/test_logging.py
Normal file
@@ -0,0 +1,119 @@
|
||||
import logging
|
||||
from unittest.mock import mock_open, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from control_backend.logging.setup_logging import add_logging_level, setup_logging
|
||||
|
||||
|
||||
def test_add_logging_level():
|
||||
# Add a unique level to avoid conflicts with other tests/libraries
|
||||
level_name = "TESTLEVEL"
|
||||
level_num = 35
|
||||
|
||||
add_logging_level(level_name, level_num)
|
||||
|
||||
assert logging.getLevelName(level_num) == level_name
|
||||
assert hasattr(logging, level_name)
|
||||
assert hasattr(logging.getLoggerClass(), level_name.lower())
|
||||
|
||||
# Test functionality
|
||||
logger = logging.getLogger("test_custom_level")
|
||||
with patch.object(logger, "_log") as mock_log:
|
||||
getattr(logger, level_name.lower())("message")
|
||||
mock_log.assert_called_with(level_num, "message", ())
|
||||
|
||||
# Test duplicates
|
||||
with pytest.raises(AttributeError):
|
||||
add_logging_level(level_name, level_num)
|
||||
|
||||
with pytest.raises(AttributeError):
|
||||
add_logging_level("INFO", 20) # Existing level
|
||||
|
||||
|
||||
def test_setup_logging_no_file(caplog):
|
||||
with patch("os.path.exists", return_value=False):
|
||||
setup_logging("dummy.yaml")
|
||||
assert "Logging config file not found" in caplog.text
|
||||
|
||||
|
||||
def test_setup_logging_yaml_error(caplog):
|
||||
with patch("os.path.exists", return_value=True):
|
||||
with patch("builtins.open", mock_open(read_data="invalid: [yaml")):
|
||||
with patch("logging.config.dictConfig") as mock_dict_config:
|
||||
setup_logging("config.yaml")
|
||||
|
||||
# Verify we logged the warning
|
||||
assert "Could not load logging configuration" in caplog.text
|
||||
# Verify dictConfig was called with empty dict (which would crash real dictConfig)
|
||||
mock_dict_config.assert_called_with({})
|
||||
assert "Could not load logging configuration" in caplog.text
|
||||
|
||||
|
||||
def test_setup_logging_success():
|
||||
config_data = """
|
||||
version: 1
|
||||
handlers:
|
||||
console:
|
||||
class: logging.StreamHandler
|
||||
root:
|
||||
handlers: [console]
|
||||
level: INFO
|
||||
custom_levels:
|
||||
MYLEVEL: 15
|
||||
"""
|
||||
with patch("os.path.exists", return_value=True):
|
||||
with patch("builtins.open", mock_open(read_data=config_data)):
|
||||
with patch("logging.config.dictConfig") as mock_dict_config:
|
||||
setup_logging("config.yaml")
|
||||
mock_dict_config.assert_called()
|
||||
assert hasattr(logging, "MYLEVEL")
|
||||
|
||||
|
||||
def test_setup_logging_zmq_handler(mock_zmq_context):
|
||||
config_data = """
|
||||
version: 1
|
||||
handlers:
|
||||
ui:
|
||||
class: logging.NullHandler
|
||||
# In real config this would be a zmq handler, but for unit test logic
|
||||
# we just want to see if the socket injection happens
|
||||
"""
|
||||
with patch("os.path.exists", return_value=True):
|
||||
with patch("builtins.open", mock_open(read_data=config_data)):
|
||||
with patch("logging.config.dictConfig") as mock_dict_config:
|
||||
setup_logging("config.yaml")
|
||||
|
||||
args = mock_dict_config.call_args[0][0]
|
||||
assert "interface_or_socket" in args["handlers"]["ui"]
|
||||
|
||||
|
||||
def test_add_logging_level_method_name_exists_in_logging():
|
||||
# method_name explicitly set to an existing logging method → triggers first hasattr branch
|
||||
with pytest.raises(AttributeError) as exc:
|
||||
add_logging_level("NEWDUPLEVEL", 37, method_name="info")
|
||||
assert "info already defined in logging module" in str(exc.value)
|
||||
|
||||
|
||||
def test_add_logging_level_method_name_exists_in_logger_class():
|
||||
# 'makeRecord' exists on Logger class but not on the logging module
|
||||
with pytest.raises(AttributeError) as exc:
|
||||
add_logging_level("ANOTHERLEVEL", 38, method_name="makeRecord")
|
||||
assert "makeRecord already defined in logger class" in str(exc.value)
|
||||
|
||||
|
||||
def test_add_logging_level_log_to_root_path_executes_without_error():
|
||||
# Verify log_to_root is installed and callable — without asserting logging output
|
||||
level_name = "ROOTTEST"
|
||||
level_num = 36
|
||||
|
||||
add_logging_level(level_name, level_num)
|
||||
|
||||
# Simply call the injected root logger method
|
||||
# The line is executed even if we don't validate output
|
||||
root_logging_method = getattr(logging, level_name.lower(), None)
|
||||
assert callable(root_logging_method)
|
||||
|
||||
# Execute the method to hit log_to_root in coverage.
|
||||
# No need to verify log output.
|
||||
root_logging_method("some message")
|
||||
45
test/unit/logging/test_dated_file_handler.py
Normal file
45
test/unit/logging/test_dated_file_handler.py
Normal file
@@ -0,0 +1,45 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from control_backend.logging.dated_file_handler import DatedFileHandler
|
||||
|
||||
|
||||
@patch("control_backend.logging.dated_file_handler.DatedFileHandler._open")
|
||||
def test_reset(open_):
|
||||
stream = MagicMock()
|
||||
open_.return_value = stream
|
||||
|
||||
# A file should be opened when the logger is created
|
||||
handler = DatedFileHandler(file_prefix="anything")
|
||||
assert open_.call_count == 1
|
||||
|
||||
# Upon reset, the current file should be closed, and a new one should be opened
|
||||
handler.do_rollover()
|
||||
assert stream.close.call_count == 1
|
||||
assert open_.call_count == 2
|
||||
|
||||
|
||||
@patch("control_backend.logging.dated_file_handler.Path")
|
||||
@patch("control_backend.logging.dated_file_handler.DatedFileHandler._open")
|
||||
def test_creates_dir(open_, Path_):
|
||||
stream = MagicMock()
|
||||
open_.return_value = stream
|
||||
|
||||
test_path = MagicMock()
|
||||
test_path.parent.is_dir.return_value = False
|
||||
Path_.return_value = test_path
|
||||
|
||||
DatedFileHandler(file_prefix="anything")
|
||||
|
||||
# The directory should've been created
|
||||
test_path.parent.mkdir.assert_called_once()
|
||||
|
||||
|
||||
@patch("control_backend.logging.dated_file_handler.DatedFileHandler._open")
|
||||
def test_invalid_constructor(_):
|
||||
with pytest.raises(ValueError):
|
||||
DatedFileHandler(file_prefix=None)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
DatedFileHandler(file_prefix="")
|
||||
218
test/unit/logging/test_optional_field_formatter.py
Normal file
218
test/unit/logging/test_optional_field_formatter.py
Normal file
@@ -0,0 +1,218 @@
|
||||
import logging
|
||||
|
||||
import pytest
|
||||
|
||||
from control_backend.logging.optional_field_formatter import OptionalFieldFormatter
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def logger():
|
||||
"""Create a fresh logger for each test."""
|
||||
logger = logging.getLogger(f"test_{id(object())}")
|
||||
logger.setLevel(logging.DEBUG)
|
||||
logger.handlers = []
|
||||
return logger
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def log_output(logger):
|
||||
"""Capture log output and return a function to get it."""
|
||||
|
||||
class ListHandler(logging.Handler):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.records = []
|
||||
|
||||
def emit(self, record):
|
||||
self.records.append(self.format(record))
|
||||
|
||||
handler = ListHandler()
|
||||
logger.addHandler(handler)
|
||||
|
||||
def get_output():
|
||||
return handler.records
|
||||
|
||||
return get_output
|
||||
|
||||
|
||||
def test_optional_field_present(logger, log_output):
|
||||
"""Optional field should appear when provided in extra."""
|
||||
formatter = OptionalFieldFormatter("%(levelname)s - %(role?)s - %(message)s")
|
||||
logger.handlers[0].setFormatter(formatter)
|
||||
|
||||
logger.info("test message", extra={"role": "user"})
|
||||
|
||||
assert log_output() == ["INFO - user - test message"]
|
||||
|
||||
|
||||
def test_optional_field_missing_no_default(logger, log_output):
|
||||
"""Missing optional field with no default should be None."""
|
||||
formatter = OptionalFieldFormatter("%(levelname)s - %(role?)s - %(message)s")
|
||||
logger.handlers[0].setFormatter(formatter)
|
||||
|
||||
logger.info("test message")
|
||||
|
||||
assert log_output() == ["INFO - None - test message"]
|
||||
|
||||
|
||||
def test_optional_field_missing_with_default(logger, log_output):
|
||||
"""Missing optional field should use provided default."""
|
||||
formatter = OptionalFieldFormatter(
|
||||
"%(levelname)s - %(role?)s - %(message)s", defaults={"role": "assistant"}
|
||||
)
|
||||
logger.handlers[0].setFormatter(formatter)
|
||||
|
||||
logger.info("test message")
|
||||
|
||||
assert log_output() == ["INFO - assistant - test message"]
|
||||
|
||||
|
||||
def test_optional_field_overrides_default(logger, log_output):
|
||||
"""Provided extra value should override default."""
|
||||
formatter = OptionalFieldFormatter(
|
||||
"%(levelname)s - %(role?)s - %(message)s", defaults={"role": "assistant"}
|
||||
)
|
||||
logger.handlers[0].setFormatter(formatter)
|
||||
|
||||
logger.info("test message", extra={"role": "user"})
|
||||
|
||||
assert log_output() == ["INFO - user - test message"]
|
||||
|
||||
|
||||
def test_multiple_optional_fields(logger, log_output):
|
||||
"""Multiple optional fields should work independently."""
|
||||
formatter = OptionalFieldFormatter(
|
||||
"%(levelname)s - %(role?)s - %(request_id?)s - %(message)s", defaults={"role": "assistant"}
|
||||
)
|
||||
logger.handlers[0].setFormatter(formatter)
|
||||
|
||||
logger.info("test", extra={"request_id": "123"})
|
||||
|
||||
assert log_output() == ["INFO - assistant - 123 - test"]
|
||||
|
||||
|
||||
def test_mixed_optional_and_required_fields(logger, log_output):
|
||||
"""Standard fields should work alongside optional fields."""
|
||||
formatter = OptionalFieldFormatter("%(levelname)s %(name)s %(role?)s %(message)s")
|
||||
logger.handlers[0].setFormatter(formatter)
|
||||
|
||||
logger.info("test", extra={"role": "user"})
|
||||
|
||||
output = log_output()[0]
|
||||
assert "INFO" in output
|
||||
assert "user" in output
|
||||
assert "test" in output
|
||||
|
||||
|
||||
def test_no_optional_fields(logger, log_output):
|
||||
"""Formatter should work normally with no optional fields."""
|
||||
formatter = OptionalFieldFormatter("%(levelname)s %(message)s")
|
||||
logger.handlers[0].setFormatter(formatter)
|
||||
|
||||
logger.info("test message")
|
||||
|
||||
assert log_output() == ["INFO test message"]
|
||||
|
||||
|
||||
def test_integer_format_specifier(logger, log_output):
|
||||
"""Optional fields with %d specifier should work."""
|
||||
formatter = OptionalFieldFormatter(
|
||||
"%(levelname)s %(count?)d %(message)s", defaults={"count": 0}
|
||||
)
|
||||
logger.handlers[0].setFormatter(formatter)
|
||||
|
||||
logger.info("test", extra={"count": 42})
|
||||
|
||||
assert log_output() == ["INFO 42 test"]
|
||||
|
||||
|
||||
def test_float_format_specifier(logger, log_output):
|
||||
"""Optional fields with %f specifier should work."""
|
||||
formatter = OptionalFieldFormatter(
|
||||
"%(levelname)s %(duration?)f %(message)s", defaults={"duration": 0.0}
|
||||
)
|
||||
logger.handlers[0].setFormatter(formatter)
|
||||
|
||||
logger.info("test", extra={"duration": 1.5})
|
||||
|
||||
assert "1.5" in log_output()[0]
|
||||
|
||||
|
||||
def test_empty_string_default(logger, log_output):
|
||||
"""Empty string default should work."""
|
||||
formatter = OptionalFieldFormatter("%(levelname)s %(role?)s %(message)s", defaults={"role": ""})
|
||||
logger.handlers[0].setFormatter(formatter)
|
||||
|
||||
logger.info("test")
|
||||
|
||||
assert log_output() == ["INFO test"]
|
||||
|
||||
|
||||
def test_none_format_string():
|
||||
"""None format string should not raise."""
|
||||
formatter = OptionalFieldFormatter(fmt=None)
|
||||
assert formatter.optional_fields == set()
|
||||
|
||||
|
||||
def test_optional_fields_parsed_correctly():
|
||||
"""Check that optional fields are correctly identified."""
|
||||
formatter = OptionalFieldFormatter("%(asctime)s %(role?)s %(level?)d %(name)s")
|
||||
|
||||
assert formatter.optional_fields == {("role", "s"), ("level", "d")}
|
||||
|
||||
|
||||
def test_format_string_normalized():
|
||||
"""Check that ? is removed from format string."""
|
||||
formatter = OptionalFieldFormatter("%(role?)s %(message)s")
|
||||
|
||||
assert "?" not in formatter._style._fmt
|
||||
assert "%(role)s" in formatter._style._fmt
|
||||
|
||||
|
||||
def test_field_with_underscore(logger, log_output):
|
||||
"""Field names with underscores should work."""
|
||||
formatter = OptionalFieldFormatter("%(levelname)s %(user_id?)s %(message)s")
|
||||
logger.handlers[0].setFormatter(formatter)
|
||||
|
||||
logger.info("test", extra={"user_id": "abc123"})
|
||||
|
||||
assert log_output() == ["INFO abc123 test"]
|
||||
|
||||
|
||||
def test_field_with_numbers(logger, log_output):
|
||||
"""Field names with numbers should work."""
|
||||
formatter = OptionalFieldFormatter("%(levelname)s %(field2?)s %(message)s")
|
||||
logger.handlers[0].setFormatter(formatter)
|
||||
|
||||
logger.info("test", extra={"field2": "value"})
|
||||
|
||||
assert log_output() == ["INFO value test"]
|
||||
|
||||
|
||||
def test_multiple_log_calls(logger, log_output):
|
||||
"""Formatter should work correctly across multiple log calls."""
|
||||
formatter = OptionalFieldFormatter(
|
||||
"%(levelname)s %(role?)s %(message)s", defaults={"role": "other"}
|
||||
)
|
||||
logger.handlers[0].setFormatter(formatter)
|
||||
|
||||
logger.info("first", extra={"role": "assistant"})
|
||||
logger.info("second")
|
||||
logger.info("third", extra={"role": "user"})
|
||||
|
||||
assert log_output() == [
|
||||
"INFO assistant first",
|
||||
"INFO other second",
|
||||
"INFO user third",
|
||||
]
|
||||
|
||||
|
||||
def test_default_not_mutated(logger, log_output):
|
||||
"""Original defaults dict should not be mutated."""
|
||||
defaults = {"role": "other"}
|
||||
formatter = OptionalFieldFormatter("%(levelname)s %(role?)s %(message)s", defaults=defaults)
|
||||
logger.handlers[0].setFormatter(formatter)
|
||||
|
||||
logger.info("test")
|
||||
|
||||
assert defaults == {"role": "other"}
|
||||
83
test/unit/logging/test_partial_filter.py
Normal file
83
test/unit/logging/test_partial_filter.py
Normal file
@@ -0,0 +1,83 @@
|
||||
import logging
|
||||
|
||||
import pytest
|
||||
|
||||
from control_backend.logging import PartialFilter
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def logger():
|
||||
"""Create a fresh logger for each test."""
|
||||
logger = logging.getLogger(f"test_{id(object())}")
|
||||
logger.setLevel(logging.DEBUG)
|
||||
logger.handlers = []
|
||||
return logger
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def log_output(logger):
|
||||
"""Capture log output and return a function to get it."""
|
||||
|
||||
class ListHandler(logging.Handler):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.records = []
|
||||
|
||||
def emit(self, record):
|
||||
self.records.append(self.format(record))
|
||||
|
||||
handler = ListHandler()
|
||||
handler.addFilter(PartialFilter())
|
||||
handler.setFormatter(logging.Formatter("%(message)s"))
|
||||
logger.addHandler(handler)
|
||||
|
||||
return lambda: handler.records
|
||||
|
||||
|
||||
def test_no_partial_attribute(logger, log_output):
|
||||
"""Records without partial attribute should pass through."""
|
||||
logger.info("normal message")
|
||||
|
||||
assert log_output() == ["normal message"]
|
||||
|
||||
|
||||
def test_partial_true_filtered(logger, log_output):
|
||||
"""Records with partial=True should be filtered out."""
|
||||
logger.info("partial message", extra={"partial": True})
|
||||
|
||||
assert log_output() == []
|
||||
|
||||
|
||||
def test_partial_false_passes(logger, log_output):
|
||||
"""Records with partial=False should pass through."""
|
||||
logger.info("complete message", extra={"partial": False})
|
||||
|
||||
assert log_output() == ["complete message"]
|
||||
|
||||
|
||||
def test_partial_none_passes(logger, log_output):
|
||||
"""Records with partial=None should pass through."""
|
||||
logger.info("message", extra={"partial": None})
|
||||
|
||||
assert log_output() == ["message"]
|
||||
|
||||
|
||||
def test_partial_truthy_value_passes(logger, log_output):
|
||||
"""
|
||||
Records with truthy but non-True partial should pass through, that is, only when it's exactly
|
||||
``True`` should it pass.
|
||||
"""
|
||||
logger.info("message", extra={"partial": "yes"})
|
||||
|
||||
assert log_output() == ["message"]
|
||||
|
||||
|
||||
def test_multiple_records_mixed(logger, log_output):
|
||||
"""Filter should handle mixed records correctly."""
|
||||
logger.info("first")
|
||||
logger.info("second", extra={"partial": True})
|
||||
logger.info("third", extra={"partial": False})
|
||||
logger.info("fourth", extra={"partial": True})
|
||||
logger.info("fifth")
|
||||
|
||||
assert log_output() == ["first", "third", "fifth"]
|
||||
12
test/unit/schemas/test_message.py
Normal file
12
test/unit/schemas/test_message.py
Normal file
@@ -0,0 +1,12 @@
|
||||
from control_backend.schemas.message import Message
|
||||
|
||||
|
||||
def base_message() -> Message:
|
||||
return Message(message="Example")
|
||||
|
||||
|
||||
def test_valid_message():
|
||||
mess = base_message()
|
||||
validated = Message.model_validate(mess)
|
||||
assert isinstance(validated, Message)
|
||||
assert validated.message == "Example"
|
||||
88
test/unit/schemas/test_ri_message.py
Normal file
88
test/unit/schemas/test_ri_message.py
Normal file
@@ -0,0 +1,88 @@
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from control_backend.schemas.ri_message import GestureCommand, RIEndpoint, RIMessage, SpeechCommand
|
||||
|
||||
|
||||
def valid_command_1():
|
||||
return SpeechCommand(data="Hallo?")
|
||||
|
||||
|
||||
def valid_command_2():
|
||||
return GestureCommand(endpoint=RIEndpoint.GESTURE_TAG, data="happy")
|
||||
|
||||
|
||||
def valid_command_3():
|
||||
return GestureCommand(endpoint=RIEndpoint.GESTURE_SINGLE, data="happy_1")
|
||||
|
||||
|
||||
def invalid_command_1():
|
||||
return RIMessage(endpoint=RIEndpoint.PING, data="Hello again.")
|
||||
|
||||
|
||||
def invalid_command_2():
|
||||
return RIMessage(endpoint=RIEndpoint.PING, data="Hey!")
|
||||
|
||||
|
||||
def invalid_command_3():
|
||||
return RIMessage(endpoint=RIEndpoint.GESTURE_SINGLE, data={1, 2, 3})
|
||||
|
||||
|
||||
def invalid_command_4():
|
||||
test: RIMessage = GestureCommand(endpoint=RIEndpoint.GESTURE_SINGLE, data="asdsad")
|
||||
|
||||
def change_endpoint(msg: RIMessage):
|
||||
msg.endpoint = RIEndpoint.PING
|
||||
|
||||
change_endpoint(test)
|
||||
return test
|
||||
|
||||
|
||||
def test_valid_speech_command_1():
|
||||
command = valid_command_1()
|
||||
RIMessage.model_validate(command)
|
||||
SpeechCommand.model_validate(command)
|
||||
|
||||
|
||||
def test_valid_gesture_command_1():
|
||||
command = valid_command_2()
|
||||
RIMessage.model_validate(command)
|
||||
GestureCommand.model_validate(command)
|
||||
|
||||
|
||||
def test_valid_gesture_command_2():
|
||||
command = valid_command_3()
|
||||
RIMessage.model_validate(command)
|
||||
GestureCommand.model_validate(command)
|
||||
|
||||
|
||||
def test_invalid_speech_command_1():
|
||||
command = invalid_command_1()
|
||||
RIMessage.model_validate(command)
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
SpeechCommand.model_validate(command)
|
||||
|
||||
|
||||
def test_invalid_gesture_command_1():
|
||||
command = invalid_command_2()
|
||||
RIMessage.model_validate(command)
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
GestureCommand.model_validate(command)
|
||||
|
||||
|
||||
def test_invalid_gesture_command_2():
|
||||
command = invalid_command_3()
|
||||
RIMessage.model_validate(command)
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
GestureCommand.model_validate(command)
|
||||
|
||||
|
||||
def test_invalid_gesture_command_3():
|
||||
command = invalid_command_4()
|
||||
RIMessage.model_validate(command)
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
GestureCommand.model_validate(command)
|
||||
205
test/unit/schemas/test_ui_program_message.py
Normal file
205
test/unit/schemas/test_ui_program_message.py
Normal file
@@ -0,0 +1,205 @@
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from control_backend.schemas.program import (
|
||||
BasicNorm,
|
||||
ConditionalNorm,
|
||||
Goal,
|
||||
InferredBelief,
|
||||
KeywordBelief,
|
||||
LogicalOperator,
|
||||
Phase,
|
||||
Plan,
|
||||
Program,
|
||||
SemanticBelief,
|
||||
Trigger,
|
||||
)
|
||||
|
||||
|
||||
def base_norm() -> BasicNorm:
|
||||
return BasicNorm(
|
||||
id=uuid.uuid4(),
|
||||
name="testNormName",
|
||||
norm="testNormNorm",
|
||||
critical=False,
|
||||
)
|
||||
|
||||
|
||||
def base_goal() -> Goal:
|
||||
return Goal(
|
||||
id=uuid.uuid4(),
|
||||
name="testGoalName",
|
||||
description="This description can be used to determine whether the goal has been achieved.",
|
||||
plan=Plan(
|
||||
id=uuid.uuid4(),
|
||||
name="testGoalPlanName",
|
||||
steps=[],
|
||||
),
|
||||
can_fail=False,
|
||||
)
|
||||
|
||||
|
||||
def base_trigger() -> Trigger:
|
||||
return Trigger(
|
||||
id=uuid.uuid4(),
|
||||
name="testTriggerName",
|
||||
condition=KeywordBelief(
|
||||
id=uuid.uuid4(),
|
||||
name="testTriggerKeywordBeliefTriggerName",
|
||||
keyword="Keyword",
|
||||
),
|
||||
plan=Plan(
|
||||
id=uuid.uuid4(),
|
||||
name="testTriggerPlanName",
|
||||
steps=[],
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def base_phase() -> Phase:
|
||||
return Phase(
|
||||
id=uuid.uuid4(),
|
||||
norms=[base_norm()],
|
||||
goals=[base_goal()],
|
||||
triggers=[base_trigger()],
|
||||
)
|
||||
|
||||
|
||||
def base_program() -> Program:
|
||||
return Program(phases=[base_phase()])
|
||||
|
||||
|
||||
def invalid_program() -> dict:
|
||||
# wrong types inside phases list (not Phase objects)
|
||||
return {
|
||||
"phases": [
|
||||
{"id": uuid.uuid4()}, # incomplete
|
||||
{"not_a_phase": True},
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
def test_valid_program():
|
||||
program = base_program()
|
||||
validated = Program.model_validate(program)
|
||||
assert isinstance(validated, Program)
|
||||
assert validated.phases[0].norms[0].norm == "testNormNorm"
|
||||
|
||||
|
||||
def test_valid_deepprogram():
|
||||
program = base_program()
|
||||
validated = Program.model_validate(program)
|
||||
# validate nested components directly
|
||||
phase = validated.phases[0]
|
||||
assert isinstance(phase.goals[0], Goal)
|
||||
assert isinstance(phase.triggers[0], Trigger)
|
||||
assert isinstance(phase.norms[0], BasicNorm)
|
||||
|
||||
|
||||
def test_invalid_program():
|
||||
bad = invalid_program()
|
||||
with pytest.raises(ValidationError):
|
||||
Program.model_validate(bad)
|
||||
|
||||
|
||||
def test_conditional_norm_parsing():
|
||||
"""
|
||||
Check that pydantic is able to preserve the type of the norm, that it doesn't lose its
|
||||
"condition" field when serializing and deserializing.
|
||||
"""
|
||||
norm = ConditionalNorm(
|
||||
name="testNormName",
|
||||
id=uuid.uuid4(),
|
||||
norm="testNormNorm",
|
||||
critical=False,
|
||||
condition=KeywordBelief(
|
||||
name="testKeywordBelief",
|
||||
id=uuid.uuid4(),
|
||||
keyword="testKeywordBelief",
|
||||
),
|
||||
)
|
||||
program = Program(
|
||||
phases=[
|
||||
Phase(
|
||||
name="Some phase",
|
||||
id=uuid.uuid4(),
|
||||
norms=[norm],
|
||||
goals=[],
|
||||
triggers=[],
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
parsed_program = Program.model_validate_json(program.model_dump_json())
|
||||
parsed_norm = parsed_program.phases[0].norms[0]
|
||||
|
||||
assert hasattr(parsed_norm, "condition")
|
||||
assert isinstance(parsed_norm, ConditionalNorm)
|
||||
|
||||
|
||||
def test_belief_type_parsing():
|
||||
"""
|
||||
Check that pydantic is able to discern between the different types of beliefs when serializing
|
||||
and deserializing.
|
||||
"""
|
||||
keyword_belief = KeywordBelief(
|
||||
name="testKeywordBelief",
|
||||
id=uuid.uuid4(),
|
||||
keyword="something",
|
||||
)
|
||||
semantic_belief = SemanticBelief(
|
||||
name="testSemanticBelief",
|
||||
id=uuid.uuid4(),
|
||||
description="something",
|
||||
)
|
||||
inferred_belief = InferredBelief(
|
||||
name="testInferredBelief",
|
||||
id=uuid.uuid4(),
|
||||
operator=LogicalOperator.OR,
|
||||
left=keyword_belief,
|
||||
right=semantic_belief,
|
||||
)
|
||||
|
||||
program = Program(
|
||||
phases=[
|
||||
Phase(
|
||||
name="Some phase",
|
||||
id=uuid.uuid4(),
|
||||
norms=[],
|
||||
goals=[],
|
||||
triggers=[
|
||||
Trigger(
|
||||
name="testTriggerKeywordTrigger",
|
||||
id=uuid.uuid4(),
|
||||
condition=keyword_belief,
|
||||
plan=Plan(name="testTriggerPlanName", id=uuid.uuid4(), steps=[]),
|
||||
),
|
||||
Trigger(
|
||||
name="testTriggerSemanticTrigger",
|
||||
id=uuid.uuid4(),
|
||||
condition=semantic_belief,
|
||||
plan=Plan(name="testTriggerPlanName", id=uuid.uuid4(), steps=[]),
|
||||
),
|
||||
Trigger(
|
||||
name="testTriggerInferredTrigger",
|
||||
id=uuid.uuid4(),
|
||||
condition=inferred_belief,
|
||||
plan=Plan(name="testTriggerPlanName", id=uuid.uuid4(), steps=[]),
|
||||
),
|
||||
],
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
parsed_program = Program.model_validate_json(program.model_dump_json())
|
||||
|
||||
parsed_keyword_belief = parsed_program.phases[0].triggers[0].condition
|
||||
assert isinstance(parsed_keyword_belief, KeywordBelief)
|
||||
|
||||
parsed_semantic_belief = parsed_program.phases[0].triggers[1].condition
|
||||
assert isinstance(parsed_semantic_belief, SemanticBelief)
|
||||
|
||||
parsed_inferred_belief = parsed_program.phases[0].triggers[2].condition
|
||||
assert isinstance(parsed_inferred_belief, InferredBelief)
|
||||
73
test/unit/test_main.py
Normal file
73
test/unit/test_main.py
Normal file
@@ -0,0 +1,73 @@
|
||||
import asyncio
|
||||
import sys
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from control_backend.api.v1.router import api_router
|
||||
from control_backend.main import app, lifespan
|
||||
|
||||
# Fix event loop on Windows
|
||||
if sys.platform == "win32":
|
||||
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client():
|
||||
# Patch setup_logging so it does nothing
|
||||
with patch("control_backend.main.setup_logging"):
|
||||
with TestClient(app) as c:
|
||||
yield c
|
||||
|
||||
|
||||
def test_root_fast():
|
||||
# Patch heavy startup code so it doesn’t slow down
|
||||
with patch("control_backend.main.setup_logging"), patch("control_backend.main.lifespan"):
|
||||
client = TestClient(app)
|
||||
resp = client.get("/")
|
||||
assert resp.status_code == 200
|
||||
assert resp.json() == {"status": "ok"}
|
||||
|
||||
|
||||
def test_cors_middleware_added():
|
||||
"""Test that CORSMiddleware is correctly added to the app."""
|
||||
from starlette.middleware.cors import CORSMiddleware
|
||||
|
||||
middleware_classes = [m.cls for m in app.user_middleware]
|
||||
assert CORSMiddleware in middleware_classes
|
||||
|
||||
|
||||
def test_api_router_included():
|
||||
"""Test that the API router is included in the FastAPI app."""
|
||||
|
||||
route_paths = [r.path for r in app.routes]
|
||||
for route in api_router.routes:
|
||||
assert route.path in route_paths
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_lifespan_agent_start_exception():
|
||||
"""
|
||||
Trigger an exception during agent startup to cover the error logging branch.
|
||||
Ensures exceptions are logged properly and re-raised.
|
||||
"""
|
||||
with (
|
||||
patch(
|
||||
"control_backend.main.RICommunicationAgent.start", new_callable=AsyncMock
|
||||
) as ri_start,
|
||||
patch("control_backend.main.setup_logging"),
|
||||
patch("control_backend.main.threading.Thread"),
|
||||
):
|
||||
# Force RICommunicationAgent.start to raise an exception
|
||||
ri_start.side_effect = Exception("Test exception")
|
||||
|
||||
with patch("control_backend.main.logger") as mock_logger:
|
||||
with pytest.raises(Exception, match="Test exception"):
|
||||
async with lifespan(app):
|
||||
pass
|
||||
|
||||
# Verify the error was logged correctly
|
||||
assert mock_logger.error.called
|
||||
args, _ = mock_logger.error.call_args
|
||||
assert isinstance(args[2], Exception)
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user