Merge remote-tracking branch 'origin/dev' into refactor/zmq-internal-socket-behaviour

# Conflicts:
#	src/control_backend/agents/ri_command_agent.py
#	src/control_backend/agents/ri_communication_agent.py
#	src/control_backend/api/v1/endpoints/command.py
#	src/control_backend/main.py
#	test/integration/api/endpoints/test_command_endpoint.py
Author: Twirre Meulenbelt
Date:   2025-11-05 12:16:18 +01:00
30 changed files with 443 additions and 211 deletions

.githooks/check-branch-name.sh (new executable file, 77 lines)

@@ -0,0 +1,77 @@
#!/usr/bin/env bash
# This script checks if the current branch name follows the specified format.
# It's designed to be used as a 'pre-commit' git hook.
# Format: <type>/<short-description>
# Example: feat/add-user-login
# --- Configuration ---
# An array of allowed commit types
ALLOWED_TYPES=(feat fix refactor perf style test docs build chore revert)
# An array of branches to ignore
IGNORED_BRANCHES=(main dev)
# --- Colors for Output ---
RED='\033[0;31m'
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
NC='\033[0m' # No Color
# --- Helper Functions ---
error_exit() {
echo -e "${RED}ERROR: $1${NC}" >&2
echo -e "${YELLOW}Branch name format is incorrect. Aborting commit.${NC}" >&2
exit 1
}
# --- Main Logic ---
# 1. Get the current branch name
BRANCH_NAME=$(git symbolic-ref --short HEAD)
# 2. Check if the current branch is in the ignored list
for ignored_branch in "${IGNORED_BRANCHES[@]}"; do
if [ "$BRANCH_NAME" == "$ignored_branch" ]; then
echo -e "${GREEN}Branch check skipped for default branch: $BRANCH_NAME${NC}"
exit 0
fi
done
# 3. Validate the overall structure: <type>/<description>
if ! [[ "$BRANCH_NAME" =~ ^[a-z]+/.+$ ]]; then
error_exit "Branch name must be in the format: <type>/<short-description>\nExample: feat/add-user-login"
fi
# 4. Extract the type and description
TYPE=$(echo "$BRANCH_NAME" | cut -d'/' -f1)
DESCRIPTION=$(echo "$BRANCH_NAME" | cut -d'/' -f2-)
# 5. Validate the <type>
type_valid=false
for allowed_type in "${ALLOWED_TYPES[@]}"; do
if [ "$TYPE" == "$allowed_type" ]; then
type_valid=true
break
fi
done
if [ "$type_valid" == false ]; then
error_exit "Invalid type '$TYPE'.\nAllowed types are: ${ALLOWED_TYPES[*]}"
fi
# 6. Validate the <short-description>
# Regex breakdown:
# ^[a-z0-9]+ - Starts with one or more lowercase letters/numbers (the first word).
# (-[a-z0-9]+){0,5} - Followed by a group of (dash + word) 0 to 5 times.
# $ - End of the string.
# This entire pattern enforces 1 to 6 words total, separated by dashes.
DESCRIPTION_REGEX="^[a-z0-9]+(-[a-z0-9]+){0,5}$"
if ! [[ "$DESCRIPTION" =~ $DESCRIPTION_REGEX ]]; then
error_exit "Invalid short description '$DESCRIPTION'.\nIt must be a maximum of 6 words, all lowercase, separated by dashes.\nExample: add-new-user-authentication-feature"
fi
# If all checks pass, exit successfully
echo -e "${GREEN}Branch name '$BRANCH_NAME' is valid.${NC}"
exit 0
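
For reference, the same branch rules can be checked outside git; the Python sketch below mirrors the hook's allowed types and description regex. The branch names in the asserts are made-up examples, not taken from the repository.

```python
import re

# Standalone sketch of the hook's rules: an allowed type, a slash, then 1-6
# lowercase words joined by dashes. Example branch names below are invented.
ALLOWED_TYPES = {"feat", "fix", "refactor", "perf", "style", "test", "docs", "build", "chore", "revert"}
DESCRIPTION_REGEX = re.compile(r"^[a-z0-9]+(-[a-z0-9]+){0,5}$")


def branch_name_is_valid(name: str) -> bool:
    branch_type, sep, description = name.partition("/")
    return sep == "/" and branch_type in ALLOWED_TYPES and DESCRIPTION_REGEX.match(description) is not None


assert branch_name_is_valid("feat/add-user-login")
assert not branch_name_is_valid("feature/add-user-login")               # 'feature' is not an allowed type
assert not branch_name_is_valid("fix/one-too-many-words-in-this-name")  # seven words exceed the six-word limit
```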

.githooks/check-commit-msg.sh (new executable file, 135 lines)

@@ -0,0 +1,135 @@
#!/usr/bin/env bash
# This script checks if a commit message follows the specified format.
# It's designed to be used as a 'commit-msg' git hook.
# Format:
# <type>: <short description>
#
# [optional]<body>
#
# [ref/close]: <issue identifier>
# --- Configuration ---
# An array of allowed commit types
ALLOWED_TYPES=(feat fix refactor perf style test docs build chore revert)
# --- Colors for Output ---
RED='\033[0;31m'
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
NC='\033[0m' # No Color
# The first argument to the hook is the path to the file containing the commit message
COMMIT_MSG_FILE=$1
# --- Automated Commit Detection ---
# Read the first line (header) for initial checks
HEADER=$(head -n 1 "$COMMIT_MSG_FILE")
# Check for Merge commits (covers 'git merge' and PR merges from GitHub/GitLab)
# Examples: "Merge branch 'main' into ...", "Merge pull request #123 from ..."
MERGE_PATTERN="^Merge (branch|pull request|tag) .*"
if [[ "$HEADER" =~ $MERGE_PATTERN ]]; then
echo -e "${GREEN}Merge commit detected by message content. Skipping validation.${NC}"
exit 0
fi
# Check for Revert commits
# Example: "Revert "feat: add new feature""
REVERT_PATTERN="^Revert \".*\""
if [[ "$HEADER" =~ $REVERT_PATTERN ]]; then
echo -e "${GREEN}Revert commit detected by message content. Skipping validation.${NC}"
exit 0
fi
# Check for Cherry-pick commits (this pattern appears at the end of the message)
# Example: "(cherry picked from commit deadbeef...)"
# We use grep -q to search the whole file quietly.
CHERRY_PICK_PATTERN="\(cherry picked from commit [a-f0-9]{7,40}\)"
if grep -qE "$CHERRY_PICK_PATTERN" "$COMMIT_MSG_FILE"; then
echo -e "${GREEN}Cherry-pick detected by message content. Skipping validation.${NC}"
exit 0
fi
# Check for Squash
# Example: "Squash commits ..."
SQUASH_PATTERN="^Squash .+"
if [[ "$HEADER" =~ $SQUASH_PATTERN ]]; then
echo -e "${GREEN}Squash commit detected by message content. Skipping validation.${NC}"
exit 0
fi
# --- Validation Functions ---
# Function to print an error message and exit
# Usage: error_exit "Your error message here"
error_exit() {
# >&2 redirects echo to stderr
echo -e "${RED}ERROR: $1${NC}" >&2
echo -e "${YELLOW}Commit message format is incorrect. Aborting commit.${NC}" >&2
exit 1
}
# --- Main Logic ---
# 1. Read the header (first line) of the commit message
HEADER=$(head -n 1 "$COMMIT_MSG_FILE")
# 2. Validate the header format: <type>: <description>
# Regex breakdown:
# ^(type1|type2|...) - Starts with one of the allowed types
# : - Followed by a literal colon
# \s - Followed by a single space
# .+ - Followed by one or more characters for the description
# $ - End of the line
TYPES_REGEX=$(
IFS="|"
echo "${ALLOWED_TYPES[*]}"
)
HEADER_REGEX="^($TYPES_REGEX): .+$"
if ! [[ "$HEADER" =~ $HEADER_REGEX ]]; then
error_exit "Invalid header format.\n\nHeader must be in the format: <type>: <short description>\nAllowed types: ${ALLOWED_TYPES[*]}\nExample: feat: add new user authentication feature"
fi
# Only validate footer if commit type is not chore
TYPE=$(echo "$HEADER" | cut -d':' -f1)
if [ "$TYPE" != "chore" ]; then
# 3. Validate the footer (last line) of the commit message
FOOTER=$(tail -n 1 "$COMMIT_MSG_FILE")
# Regex breakdown:
# ^(ref|close) - Starts with 'ref' or 'close'
# : - Followed by a literal colon
# \s - Followed by a single space
# N25B- - Followed by the literal string 'N25B-'
# [0-9]+ - Followed by one or more digits
# $ - End of the line
FOOTER_REGEX="^(ref|close): N25B-[0-9]+$"
if ! [[ "$FOOTER" =~ $FOOTER_REGEX ]]; then
error_exit "Invalid footer format.\n\nFooter must be in the format: [ref/close]: <issue identifier>\nExample: ref: N25B-123"
fi
fi
# 4. If the message has more than 2 lines, validate the separator
# A blank line must exist between the header and the body.
LINE_COUNT=$(wc -l <"$COMMIT_MSG_FILE" | xargs) # xargs trims whitespace
# We only care if there is a body. Header + Footer = 2 lines.
# Header + Blank Line + Body... + Footer > 2 lines.
if [ "$LINE_COUNT" -gt 2 ]; then
# Get the second line
SECOND_LINE=$(sed -n '2p' "$COMMIT_MSG_FILE")
# Check if the second line is NOT empty. If it's not, it's an error.
if [ -n "$SECOND_LINE" ]; then
error_exit "Missing blank line between header and body.\n\nThe second line of your commit message must be empty if a body is present."
fi
fi
# If all checks pass, exit with success
echo -e "${GREEN}Commit message is valid.${NC}"
exit 0
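
As a sketch of the message shape the hook enforces, the snippet below replicates its header and footer regexes in Python and applies them to a sample commit message. The header and footer reuse the script's own examples; the body line is made up, and the 'chore' footer exception is left out for brevity.

```python
import re

# Sketch of the hook's header/footer checks applied to a sample message.
ALLOWED_TYPES = ["feat", "fix", "refactor", "perf", "style", "test", "docs", "build", "chore", "revert"]
HEADER_REGEX = re.compile(rf"^({'|'.join(ALLOWED_TYPES)}): .+$")
FOOTER_REGEX = re.compile(r"^(ref|close): N25B-[0-9]+$")

message = """feat: add new user authentication feature

Describe the change in one or more body lines.

ref: N25B-123"""

lines = message.splitlines()
assert HEADER_REGEX.match(lines[0])       # <type>: <short description>
assert FOOTER_REGEX.match(lines[-1])      # [ref/close]: <issue identifier>
assert len(lines) <= 2 or lines[1] == ""  # blank line required between header and body
```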


@@ -1,16 +0,0 @@
#!/bin/sh
commit_msg_file=$1
commit_msg=$(cat "$commit_msg_file")
if echo "$commit_msg" | grep -Eq "^(feat|fix|refactor|perf|style|test|docs|build|chore|revert): .+"; then
if echo "$commit_msg" | grep -Eq "^(ref|close):\sN25B-.+"; then
exit 0
else
echo "❌ Commit message invalid! Must end with [ref/close]: N25B-000"
exit 1
fi
else
echo "❌ Commit message invalid! Must start with <type>: <description>"
exit 1
fi


@@ -1,17 +0,0 @@
#!/bin/sh
# Get current branch
branch=$(git rev-parse --abbrev-ref HEAD)
if echo "$branch" | grep -Eq "(dev|main)"; then
echo 0
fi
# allowed pattern <type/>
if echo "$branch" | grep -Eq "^(feat|fix|refactor|perf|style|test|docs|build|chore|revert)\/\w+(-\w+){0,5}$"; then
exit 0
else
echo "❌ Invalid branch name: $branch"
echo "Branch must be named <type>/<description-of-branch> (must have one to six words separated by a dash)"
exit 1
fi


@@ -1,9 +0,0 @@
#!/bin/sh
echo "#<type>: <description>
#[optional body]
#[optional footer(s)]
#[ref/close]: <issue identifier>" > $1


@@ -1,10 +1,24 @@
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.14.2
hooks:
# Run the linter.
- id: ruff-check
args: [ --fix ]
# Run the formatter.
- id: ruff-format
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.14.2
hooks:
# Run the linter.
- id: ruff-check
# Run the formatter.
- id: ruff-format
# Configure local hooks
- repo: local
hooks:
- id: check-commit-msg
name: Check commit message format
entry: .githooks/check-commit-msg.sh
language: script
stages: [commit-msg]
- id: check-branch-name
name: Check branch name format
entry: .githooks/check-branch-name.sh
language: script
stages: [pre-commit]
always_run: true
pass_filenames: false


@@ -47,21 +47,19 @@ Or for integration tests:
uv run --group integration-test pytest test/integration
```
## GitHooks
## Git Hooks
To activate automatic commits/branch name checks run:
To activate automatic linting, formatting, branch name checks and commit message checks, run:
```shell
git config --local core.hooksPath .githooks
```bash
uv run pre-commit install
uv run pre-commit install --hook-type commit-msg
```
If your commit fails its either:
branch name != <type>/description-of-branch ,
commit name != <type>: description of the commit.
<ref>: N25B-Num's
You might get an error along the lines of `Can't install pre-commit with core.hooksPath` set. To fix this, simply unset the hooksPath by running:
To add automatic linting and formatting, run:
```bash
git config --local --unset core.hooksPath
```
```shell
uv run pre-commit install
```
Then run the pre-commit install commands again.


@@ -58,11 +58,11 @@ class BDICoreAgent(BDIAgent):
class SendBehaviour(OneShotBehaviour):
async def run(self) -> None:
msg = Message(
to= settings.agent_settings.llm_agent_name + '@' + settings.agent_settings.host,
body= text
to=settings.agent_settings.llm_agent_name + "@" + settings.agent_settings.host,
body=text,
)
await self.send(msg)
self.agent.logger.info("Message sent to LLM: %s", text)
self.add_behaviour(SendBehaviour())
self.add_behaviour(SendBehaviour())


@@ -3,7 +3,7 @@ import logging
from spade.agent import Message
from spade.behaviour import CyclicBehaviour
from spade_bdi.bdi import BDIAgent, BeliefNotInitiated
from spade_bdi.bdi import BDIAgent
from control_backend.core.config import settings
@@ -23,7 +23,6 @@ class BeliefSetterBehaviour(CyclicBehaviour):
self.logger.info(f"Received message {msg.body}")
self._process_message(msg)
def _process_message(self, message: Message):
sender = message.sender.node # removes host from jid and converts to str
self.logger.debug("Sender: %s", sender)
@@ -61,6 +60,7 @@ class BeliefSetterBehaviour(CyclicBehaviour):
self.agent.bdi.set_belief(belief, *arguments)
# Special case: if there's a new user message, flag that we haven't responded yet
if belief == "user_said": self.agent.bdi.set_belief("new_message")
if belief == "user_said":
self.agent.bdi.set_belief("new_message")
self.logger.info("Set belief %s with arguments %s", belief, arguments)


@@ -9,18 +9,20 @@ class ReceiveLLMResponseBehaviour(CyclicBehaviour):
"""
Adds behavior to receive responses from the LLM Agent.
"""
logger = logging.getLogger("BDI/LLM Receiver")
async def run(self):
msg = await self.receive(timeout=2)
if not msg:
return
sender = msg.sender.node
sender = msg.sender.node
match sender:
case settings.agent_settings.llm_agent_name:
content = msg.body
self.logger.info("Received LLM response: %s", content)
#Here the BDI can pass the message back as a response
# Here the BDI can pass the message back as a response
case _:
self.logger.debug("Not from the llm, discarding message")
pass
pass


@@ -13,28 +13,30 @@ class BeliefFromText(CyclicBehaviour):
# TODO: LLM prompt still hardcoded
llm_instruction_prompt = """
You are an information extraction assistant for a BDI agent. Your task is to extract values from a user's text to bind a list of ungrounded beliefs. Rules:
You will receive a JSON object with "beliefs" (a list of ungrounded AgentSpeak beliefs) and "text" (user's transcript).
You are an information extraction assistant for a BDI agent. Your task is to extract values \
from a user's text to bind a list of ungrounded beliefs. Rules:
You will receive a JSON object with "beliefs" (a list of ungrounded AgentSpeak beliefs) \
and "text" (user's transcript).
Analyze the text to find values that semantically match the variables (X,Y,Z) in the beliefs.
A single piece of text might contain multiple instances that match a belief.
Respond ONLY with a single JSON object.
The JSON object's keys should be the belief functors (e.g., "weather").
The value for each key must be a list of lists.
Each inner list must contain the extracted arguments (as strings) for one instance of that belief.
CRITICAL: If no information in the text matches a belief, DO NOT include that key in your response.
Each inner list must contain the extracted arguments (as strings) for one instance \
of that belief.
CRITICAL: If no information in the text matches a belief, DO NOT include that key \
in your response.
"""
# on_start agent receives message containing the beliefs to look out for and sets up the LLM with instruction prompt
#async def on_start(self):
# on_start agent receives message containing the beliefs to look out for and
# sets up the LLM with instruction prompt
# async def on_start(self):
# msg = await self.receive(timeout=0.1)
# self.beliefs = dict from message
# send instruction prompt to LLM
beliefs: dict[str, list[str]]
beliefs = {
"mood": ["X"],
"car": ["Y"]
}
beliefs = {"mood": ["X"], "car": ["Y"]}
async def run(self):
msg = await self.receive(timeout=0.1)
@@ -58,8 +60,8 @@ class BeliefFromText(CyclicBehaviour):
prompt = text_prompt + beliefs_prompt
self.logger.info(prompt)
#prompt_msg = Message(to="LLMAgent@whatever")
#response = self.send(prompt_msg)
# prompt_msg = Message(to="LLMAgent@whatever")
# response = self.send(prompt_msg)
# Mock response; the response is beliefs in JSON format, parsed to dict[str, list[list[str]]]
response = '{"mood": [["happy"]]}'
@@ -67,8 +69,9 @@ class BeliefFromText(CyclicBehaviour):
try:
json.loads(response)
belief_message = Message(
to=settings.agent_settings.bdi_core_agent_name + '@' + settings.agent_settings.host,
body=response)
to=settings.agent_settings.bdi_core_agent_name + "@" + settings.agent_settings.host,
body=response,
)
belief_message.thread = "beliefs"
await self.send(belief_message)
@@ -85,9 +88,12 @@ class BeliefFromText(CyclicBehaviour):
"""
belief = {"beliefs": {"user_said": [txt]}, "type": "belief_extraction_text"}
payload = json.dumps(belief)
belief_msg = Message(to=settings.agent_settings.belief_collector_agent_name
+ '@' + settings.agent_settings.host,
body=payload)
belief_msg = Message(
to=settings.agent_settings.belief_collector_agent_name
+ "@"
+ settings.agent_settings.host,
body=payload,
)
belief_msg.thread = "beliefs"
await self.send(belief_msg)
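
To make the exchange described by the instruction prompt above concrete, here is a hedged example of the JSON the behaviour would send to the LLM and the response shape the prompt asks for, reusing the hardcoded "mood"/"car" beliefs and the mocked '{"mood": [["happy"]]}' reply; the user sentence is invented for illustration.

```python
import json

# Request sent to the LLM: the ungrounded beliefs plus the user's transcript.
# The sentence is a made-up example.
request_payload = {
    "beliefs": {"mood": ["X"], "car": ["Y"]},
    "text": "I'm feeling really happy today.",
}
prompt = json.dumps(request_payload)

# Response shape the instruction prompt asks for: one key per matched functor,
# each mapped to a list of argument lists (one inner list per extracted instance).
# "car" is omitted because nothing in the text matched it; this mirrors the
# mocked '{"mood": [["happy"]]}' response used in the behaviour above.
response = json.loads('{"mood": [["happy"]]}')
assert response["mood"] == [["happy"]]
```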


@@ -6,4 +6,4 @@ from control_backend.agents.bdi.behaviours.text_belief_extractor import BeliefFr
class TBeliefExtractor(Agent):
async def setup(self):
self.b = BeliefFromText()
self.add_behaviour(self.b)
self.add_behaviour(self.b)


@@ -1,11 +1,14 @@
import json
import logging
from spade.behaviour import CyclicBehaviour
from spade.agent import Message
from spade.behaviour import CyclicBehaviour
from control_backend.core.config import settings
logger = logging.getLogger(__name__)
class ContinuousBeliefCollector(CyclicBehaviour):
"""
Continuously collects beliefs/emotions from extractor agents:
@@ -17,7 +20,6 @@ class ContinuousBeliefCollector(CyclicBehaviour):
if msg:
await self._process_message(msg)
async def _process_message(self, msg: Message):
sender_node = self._sender_node(msg)
@@ -27,7 +29,9 @@ class ContinuousBeliefCollector(CyclicBehaviour):
except Exception as e:
logger.warning(
"BeliefCollector: failed to parse JSON from %s. Body=%r Error=%s",
sender_node, msg.body, e
sender_node,
msg.body,
e,
)
return
@@ -35,16 +39,21 @@ class ContinuousBeliefCollector(CyclicBehaviour):
# Prefer explicit 'type' field
if msg_type == "belief_extraction_text" or sender_node == "belief_text_agent_mock":
logger.info("BeliefCollector: message routed to _handle_belief_text (sender=%s)", sender_node)
logger.info(
"BeliefCollector: message routed to _handle_belief_text (sender=%s)", sender_node
)
await self._handle_belief_text(payload, sender_node)
#This is not implemented yet, but we keep the structure for future use
elif msg_type == "emotion_extraction_text" or sender_node == "emo_text_agent_mock":
logger.info("BeliefCollector: message routed to _handle_emo_text (sender=%s)", sender_node)
# This is not implemented yet, but we keep the structure for future use
elif msg_type == "emotion_extraction_text" or sender_node == "emo_text_agent_mock":
logger.info(
"BeliefCollector: message routed to _handle_emo_text (sender=%s)", sender_node
)
await self._handle_emo_text(payload, sender_node)
else:
logger.info(
"BeliefCollector: unrecognized message (sender=%s, type=%r). Ignoring.",
sender_node, msg_type
sender_node,
msg_type,
)
@staticmethod
@@ -56,13 +65,12 @@ class ContinuousBeliefCollector(CyclicBehaviour):
s = str(msg.sender) if msg.sender is not None else "no_sender"
return s.split("@", 1)[0] if "@" in s else s
async def _handle_belief_text(self, payload: dict, origin: str):
"""
Expected payload:
{
"type": "belief_extraction_text",
"beliefs": {"user_said": ["hello"","Can you help me?","stop talking to me","No","Pepper do a dance"]}
"beliefs": {"user_said": ["Can you help me?"]}
}
@@ -72,11 +80,11 @@ class ContinuousBeliefCollector(CyclicBehaviour):
if not beliefs:
logger.info("BeliefCollector: no beliefs to process.")
return
if not isinstance(beliefs, dict):
logger.warning("BeliefCollector: 'beliefs' is not a dict: %r", beliefs)
return
if not all(isinstance(v, list) for v in beliefs.values()):
logger.warning("BeliefCollector: 'beliefs' values are not all lists: %r", beliefs)
return
@@ -84,17 +92,14 @@ class ContinuousBeliefCollector(CyclicBehaviour):
logger.info("BeliefCollector: forwarding %d beliefs.", len(beliefs))
for belief_name, belief_list in beliefs.items():
for belief in belief_list:
logger.info(" - %s %s", belief_name,str(belief))
logger.info(" - %s %s", belief_name, str(belief))
await self._send_beliefs_to_bdi(beliefs, origin=origin)
async def _handle_emo_text(self, payload: dict, origin: str):
"""TODO: implement (after we have emotional recogntion)"""
pass
async def _send_beliefs_to_bdi(self, beliefs: list[str], origin: str | None = None):
"""
Sends a unified belief packet to the BDI agent.
@@ -107,6 +112,5 @@ class ContinuousBeliefCollector(CyclicBehaviour):
msg = Message(to=to_jid, sender=self.agent.jid, thread="beliefs")
msg.body = json.dumps(beliefs)
await self.send(msg)
logger.info("BeliefCollector: sent %d belief(s) to BDI at %s", len(beliefs), to_jid)


@@ -1,13 +1,15 @@
import logging
from spade.agent import Agent
from .behaviours.continuous_collect import ContinuousBeliefCollector
logger = logging.getLogger(__name__)
class BeliefCollectorAgent(Agent):
async def setup(self):
logger.info("BeliefCollectorAgent starting (%s)", self.jid)
# Attach the continuous collector behaviour (listens and forwards to BDI)
self.add_behaviour(ContinuousBeliefCollector())
logger.info("BeliefCollectorAgent ready.")
logger.info("BeliefCollectorAgent ready.")


@@ -65,8 +65,8 @@ class LLMAgent(Agent):
Sends a response message back to the BDI Core Agent.
"""
reply = Message(
to=settings.agent_settings.bdi_core_agent_name + '@' + settings.agent_settings.host,
body=msg
to=settings.agent_settings.bdi_core_agent_name + "@" + settings.agent_settings.host,
body=msg,
)
await self.send(reply)
self.agent.logger.info("Reply sent to BDI Core Agent")
@@ -80,35 +80,31 @@ class LLMAgent(Agent):
"""
async with httpx.AsyncClient(timeout=120.0) as client:
# Example dynamic content for future (optional)
instructions = LLMInstructions()
developer_instruction = instructions.build_developer_instruction()
response = await client.post(
settings.llm_settings.local_llm_url,
headers={"Content-Type": "application/json"},
json={
"model": settings.llm_settings.local_llm_model,
"messages": [
{
"role": "developer",
"content": developer_instruction
},
{
"role": "user",
"content": prompt
}
{"role": "developer", "content": developer_instruction},
{"role": "user", "content": prompt},
],
"temperature": 0.3
"temperature": 0.3,
},
)
try:
response.raise_for_status()
data: dict[str, Any] = response.json()
return data.get("choices", [{}])[0].get(
"message", {}
).get("content", "No response")
return (
data.get("choices", [{}])[0]
.get("message", {})
.get("content", "No response")
)
except httpx.HTTPError as err:
self.agent.logger.error("HTTP error: %s", err)
return "LLM service unavailable."


@@ -1,18 +1,33 @@
import json
from spade.agent import Agent
from spade.behaviour import OneShotBehaviour
from spade.message import Message
from control_backend.core.config import settings
class BeliefTextAgent(Agent):
class SendOnceBehaviourBlfText(OneShotBehaviour):
async def run(self):
to_jid = f"{settings.agent_settings.belief_collector_agent_name}@{settings.agent_settings.host}"
to_jid = (
settings.agent_settings.belief_collector_agent_name
+ "@"
+ settings.agent_settings.host
)
# Send multiple beliefs in one JSON payload
payload = {
"type": "belief_extraction_text",
"beliefs": {"user_said": ["hello test","Can you help me?","stop talking to me","No","Pepper do a dance"]}
"beliefs": {
"user_said": [
"hello test",
"Can you help me?",
"stop talking to me",
"No",
"Pepper do a dance",
]
},
}
msg = Message(to=to_jid)


@@ -1,9 +1,9 @@
import json
import logging
import zmq
from spade.agent import Agent
from spade.behaviour import CyclicBehaviour
import zmq
from zmq.asyncio import Context
from control_backend.core.config import settings


@@ -1,9 +1,9 @@
import asyncio
import logging
import zmq
from spade.agent import Agent
from spade.behaviour import CyclicBehaviour
import zmq
from zmq.asyncio import Context
from control_backend.agents.ri_command_agent import RICommandAgent


@@ -75,7 +75,8 @@ class MLXWhisperSpeechRecognizer(SpeechRecognizer):
self.model_name = "mlx-community/whisper-small.en-mlx"
def load_model(self):
if self.was_loaded: return
if self.was_loaded:
return
# There appears to be no dedicated mechanism to preload a model, but this `get_model` does
# store it in memory for later usage
ModelHolder.get_model(self.model_name, mx.float16)
@@ -83,9 +84,9 @@ class MLXWhisperSpeechRecognizer(SpeechRecognizer):
def recognize_speech(self, audio: np.ndarray) -> str:
self.load_model()
return mlx_whisper.transcribe(audio,
path_or_hf_repo=self.model_name,
decode_options=self._get_decode_options(audio))["text"]
return mlx_whisper.transcribe(
audio, path_or_hf_repo=self.model_name, decode_options=self._get_decode_options(audio)
)["text"]
return mlx_whisper.transcribe(audio, path_or_hf_repo=self.model_name)["text"].strip()
@@ -95,12 +96,13 @@ class OpenAIWhisperSpeechRecognizer(SpeechRecognizer):
self.model = None
def load_model(self):
if self.model is not None: return
if self.model is not None:
return
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
self.model = whisper.load_model("small.en", device=device)
def recognize_speech(self, audio: np.ndarray) -> str:
self.load_model()
return whisper.transcribe(self.model,
audio,
decode_options=self._get_decode_options(audio))["text"]
return whisper.transcribe(
self.model, audio, decode_options=self._get_decode_options(audio)
)["text"]


@@ -1,10 +1,7 @@
import logging
import zmq
from fastapi import APIRouter, Request
from zmq.asyncio import Context
from control_backend.core.config import settings
from control_backend.schemas.ri_message import SpeechCommand
logger = logging.getLogger(__name__)


@@ -1,6 +1,6 @@
from fastapi.routing import APIRouter
from control_backend.api.v1.endpoints import message, sse, command
from control_backend.api.v1.endpoints import command, message, sse
api_router = APIRouter()


@@ -14,16 +14,16 @@ from control_backend.agents.bdi.bdi_core import BDICoreAgent
from control_backend.agents.bdi.text_extractor import TBeliefExtractor
from control_backend.agents.belief_collector.belief_collector import BeliefCollectorAgent
from control_backend.agents.llm.llm import LLMAgent
# Internal imports
from control_backend.agents.ri_communication_agent import RICommunicationAgent
from control_backend.agents.vad_agent import VADAgent
from control_backend.api.v1.router import api_router
from control_backend.core.config import settings
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.DEBUG)
def setup_sockets():
context = Context.instance()
@@ -42,6 +42,7 @@ def setup_sockets():
internal_pub_socket.close()
internal_sub_socket.close()
@contextlib.asynccontextmanager
async def lifespan(app: FastAPI):
logger.info("%s starting up.", app.title)


@@ -1,7 +1,7 @@
from enum import Enum
from typing import Any, Literal
from typing import Any
from pydantic import BaseModel, Field, ValidationError
from pydantic import BaseModel
class RIEndpoint(str, Enum):


@@ -1,10 +1,10 @@
import asyncio
import zmq
import json
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
import zmq
from control_backend.agents.ri_command_agent import RICommandAgent
from control_backend.schemas.ri_message import SpeechCommand
@pytest.fixture


@@ -1,6 +1,8 @@
import asyncio
from unittest.mock import ANY, AsyncMock, MagicMock, patch
import pytest
from unittest.mock import AsyncMock, MagicMock, patch, ANY
from control_backend.agents.ri_communication_agent import RICommunicationAgent
@@ -177,8 +179,8 @@ async def test_setup_creates_socket_and_negotiate_3(zmq_context, caplog):
# Mock RICommandAgent agent startup
# We are sending wrong negotiation info to the communication agent, so we should retry and expect a
# better response, within a limited time.
# We are sending wrong negotiation info to the communication agent,
# so we should retry and expect a better response, within a limited time.
with patch(
"control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True
) as MockCommandAgent:
@@ -330,8 +332,8 @@ async def test_setup_creates_socket_and_negotiate_7(zmq_context, caplog):
# Mock RICommandAgent agent startup
# We are sending wrong negotiation info to the communication agent, so we should retry and expect a
# better response, within a limited time.
# We are sending wrong negotiation info to the communication agent,
# so we should retry and expect a better response, within a limited time.
with patch(
"control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True
) as MockCommandAgent:


@@ -1,4 +1,4 @@
from unittest.mock import AsyncMock, patch
from unittest.mock import MagicMock
import pytest
from fastapi import FastAPI
@@ -16,6 +16,7 @@ def app():
"""
app = FastAPI()
app.include_router(command.router)
app.state.internal_comm_socket = MagicMock() # mock ZMQ socket
return app
@@ -25,32 +26,32 @@ def client(app):
return TestClient(app)
@pytest.mark.asyncio
@patch("control_backend.api.v1.endpoints.command.Context.instance")
async def test_receive_command_success(mock_context_instance, client):
def test_receive_command_endpoint(client, app):
"""
Test for successful reception of a command.
Ensures the status code is 202 and the response body is correct.
It also verifies that the ZeroMQ socket's send_multipart method is called with the expected data.
Test that a POST to /command sends the right multipart message
and returns a 202 with the expected JSON body.
"""
# Arrange
mock_pub_socket = AsyncMock()
client.app.state.endpoints_pub_socket = mock_pub_socket
mock_socket = app.state.internal_comm_socket
command_data = {"endpoint": "actuate/speech", "data": "This is a test"}
speech_command = SpeechCommand(**command_data)
# Prepare test payload that matches SpeechCommand
payload = {"endpoint": "actuate/speech", "data": "yooo"}
# Act
response = client.post("/command", json=command_data)
# Send POST request
response = client.post("/command", json=payload)
# Assert
# Check response
assert response.status_code == 202
assert response.json() == {"status": "Command received"}
# Verify that the ZMQ socket was used correctly
mock_pub_socket.send_multipart.assert_awaited_once_with(
[b"command", speech_command.model_dump_json().encode()]
)
# Verify that the socket was called with the correct data
assert mock_socket.send_multipart.called, "Socket should be used to send data"
args, kwargs = mock_socket.send_multipart.call_args
sent_data = args[0]
assert sent_data[0] == b"command"
# Check JSON encoding roughly matches
assert isinstance(SpeechCommand.model_validate_json(sent_data[1].decode()), SpeechCommand)
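
For orientation only: judging from what test_receive_command_endpoint asserts (the POST route, the 202 response body, and the b"command" topic frame), an endpoint satisfying it could look roughly like the sketch below. This is inferred from the test, not copied from command.py, and the real handler may differ, for instance in whether the send is awaited.

```python
# Hypothetical sketch inferred from the test's assertions; not the actual command.py.
from fastapi import APIRouter, Request, status

from control_backend.schemas.ri_message import SpeechCommand

router = APIRouter()


@router.post("/command", status_code=status.HTTP_202_ACCEPTED)
def receive_command(command: SpeechCommand, request: Request):
    # The ZMQ publisher socket lives on app state; the test injects a MagicMock here
    # and only checks that send_multipart was called with the topic and JSON frames.
    socket = request.app.state.internal_comm_socket
    socket.send_multipart([b"command", command.model_dump_json().encode()])
    return {"status": "Command received"}
```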
def test_receive_command_invalid_payload(client):


@@ -1,7 +1,8 @@
import pytest
from control_backend.schemas.ri_message import RIMessage, RIEndpoint, SpeechCommand
from pydantic import ValidationError
from control_backend.schemas.ri_message import RIEndpoint, RIMessage, SpeechCommand
def valid_command_1():
return SpeechCommand(data="Hallo?")
@@ -13,24 +14,13 @@ def invalid_command_1():
def test_valid_speech_command_1():
command = valid_command_1()
try:
RIMessage.model_validate(command)
SpeechCommand.model_validate(command)
assert True
except ValidationError:
assert False
RIMessage.model_validate(command)
SpeechCommand.model_validate(command)
def test_invalid_speech_command_1():
command = invalid_command_1()
passed_ri_message_validation = False
try:
# Should succeed, still.
RIMessage.model_validate(command)
passed_ri_message_validation = True
RIMessage.model_validate(command)
# Should fail.
with pytest.raises(ValidationError):
SpeechCommand.model_validate(command)
assert False
except ValidationError:
assert passed_ri_message_validation


@@ -203,6 +203,7 @@ def test_set_beliefs_success(belief_setter, mock_agent, caplog):
assert "Set belief is_hot with arguments ['kitchen']" in caplog.text
assert "Set belief door_opened with arguments ['front_door', 'back_door']" in caplog.text
# def test_responded_unset(belief_setter, mock_agent):
# # Arrange
# new_beliefs = {"user_said": ["message"]}


@@ -1,10 +1,12 @@
import json
import logging
from unittest.mock import MagicMock, AsyncMock, call
from unittest.mock import AsyncMock, MagicMock
import pytest
from control_backend.agents.belief_collector.behaviours.continuous_collect import ContinuousBeliefCollector
from control_backend.agents.belief_collector.behaviours.continuous_collect import (
ContinuousBeliefCollector,
)
@pytest.fixture
def mock_agent(mocker):
@@ -13,18 +15,20 @@ def mock_agent(mocker):
agent.jid = "belief_collector_agent@test"
return agent
@pytest.fixture
def continuous_collector(mock_agent, mocker):
"""Fixture to create an instance of ContinuousBeliefCollector with a mocked agent."""
# Patch asyncio.sleep to prevent tests from actually waiting
mocker.patch("asyncio.sleep", return_value=None)
collector = ContinuousBeliefCollector()
collector.agent = mock_agent
# Mock the receive method, we will control its return value in each test
collector.receive = AsyncMock()
return collector
@pytest.mark.asyncio
async def test_run_no_message_received(continuous_collector, mocker):
"""
@@ -40,6 +44,7 @@ async def test_run_no_message_received(continuous_collector, mocker):
# Assert
continuous_collector._process_message.assert_not_called()
@pytest.mark.asyncio
async def test_run_message_received(continuous_collector, mocker):
"""
@@ -55,7 +60,8 @@ async def test_run_message_received(continuous_collector, mocker):
# Assert
continuous_collector._process_message.assert_awaited_once_with(mock_msg)
@pytest.mark.asyncio
async def test_process_message_invalid(continuous_collector, mocker):
"""
@@ -66,15 +72,18 @@ async def test_process_message_invalid(continuous_collector, mocker):
msg = MagicMock()
msg.body = invalid_json
msg.sender = "belief_text_agent_mock@test"
logger_mock = mocker.patch("control_backend.agents.belief_collector.behaviours.continuous_collect.logger")
logger_mock = mocker.patch(
"control_backend.agents.belief_collector.behaviours.continuous_collect.logger"
)
# Act
await continuous_collector._process_message(msg)
# Assert
logger_mock.warning.assert_called_once()
def test_get_sender_from_message(continuous_collector):
"""
Test that _sender_node correctly extracts the sender node from the message JID.
@@ -89,6 +98,7 @@ def test_get_sender_from_message(continuous_collector):
# Assert
assert sender_node == "agent_node"
@pytest.mark.asyncio
async def test_routes_to_handle_belief_text_by_type(continuous_collector, mocker):
msg = MagicMock()
@@ -98,6 +108,7 @@ async def test_routes_to_handle_belief_text_by_type(continuous_collector, mocker
await continuous_collector._process_message(msg)
spy.assert_awaited_once()
@pytest.mark.asyncio
async def test_routes_to_handle_belief_text_by_sender(continuous_collector, mocker):
msg = MagicMock()
@@ -107,6 +118,7 @@ async def test_routes_to_handle_belief_text_by_sender(continuous_collector, mock
await continuous_collector._process_message(msg)
spy.assert_awaited_once()
@pytest.mark.asyncio
async def test_routes_to_handle_emo_text(continuous_collector, mocker):
msg = MagicMock()
@@ -116,50 +128,64 @@ async def test_routes_to_handle_emo_text(continuous_collector, mocker):
await continuous_collector._process_message(msg)
spy.assert_awaited_once()
@pytest.mark.asyncio
async def test_unrecognized_message_logs_info(continuous_collector, mocker):
msg = MagicMock()
msg.body = json.dumps({"type": "something_else"})
msg.sender = "x@test"
logger_mock = mocker.patch("control_backend.agents.belief_collector.behaviours.continuous_collect.logger")
logger_mock = mocker.patch(
"control_backend.agents.belief_collector.behaviours.continuous_collect.logger"
)
await continuous_collector._process_message(msg)
logger_mock.info.assert_any_call(
"BeliefCollector: unrecognized message (sender=%s, type=%r). Ignoring.", "x", "something_else"
"BeliefCollector: unrecognized message (sender=%s, type=%r). Ignoring.",
"x",
"something_else",
)
@pytest.mark.asyncio
async def test_belief_text_no_beliefs(continuous_collector, mocker):
msg_payload = {"type": "belief_extraction_text"} # no 'beliefs'
logger_mock = mocker.patch("control_backend.agents.belief_collector.behaviours.continuous_collect.logger")
logger_mock = mocker.patch(
"control_backend.agents.belief_collector.behaviours.continuous_collect.logger"
)
await continuous_collector._handle_belief_text(msg_payload, "origin_node")
logger_mock.info.assert_any_call("BeliefCollector: no beliefs to process.")
@pytest.mark.asyncio
async def test_belief_text_beliefs_not_dict(continuous_collector, mocker):
payload = {"type": "belief_extraction_text", "beliefs": ["not", "a", "dict"]}
logger_mock = mocker.patch("control_backend.agents.belief_collector.behaviours.continuous_collect.logger")
logger_mock = mocker.patch(
"control_backend.agents.belief_collector.behaviours.continuous_collect.logger"
)
await continuous_collector._handle_belief_text(payload, "origin")
logger_mock.warning.assert_any_call("BeliefCollector: 'beliefs' is not a dict: %r", ["not", "a", "dict"])
logger_mock.warning.assert_any_call(
"BeliefCollector: 'beliefs' is not a dict: %r", ["not", "a", "dict"]
)
@pytest.mark.asyncio
async def test_belief_text_values_not_lists(continuous_collector, mocker):
payload = {"type": "belief_extraction_text", "beliefs": {"user_said": "not-a-list"}}
logger_mock = mocker.patch("control_backend.agents.belief_collector.behaviours.continuous_collect.logger")
logger_mock = mocker.patch(
"control_backend.agents.belief_collector.behaviours.continuous_collect.logger"
)
await continuous_collector._handle_belief_text(payload, "origin")
logger_mock.warning.assert_any_call(
"BeliefCollector: 'beliefs' values are not all lists: %r", {"user_said": "not-a-list"}
)
@pytest.mark.asyncio
async def test_belief_text_happy_path_logs_items_and_sends(continuous_collector, mocker):
payload = {
"type": "belief_extraction_text",
"beliefs": {"user_said": ["hello test", "No"]}
}
# Your code calls self.send(..); patch it (or switch implementation to self.agent.send and patch that)
payload = {"type": "belief_extraction_text", "beliefs": {"user_said": ["hello test", "No"]}}
continuous_collector.send = AsyncMock()
logger_mock = mocker.patch("control_backend.agents.belief_collector.behaviours.continuous_collect.logger")
logger_mock = mocker.patch(
"control_backend.agents.belief_collector.behaviours.continuous_collect.logger"
)
await continuous_collector._handle_belief_text(payload, "belief_text_agent_mock")
logger_mock.info.assert_any_call("BeliefCollector: forwarding %d beliefs.", 1)
@@ -169,12 +195,14 @@ async def test_belief_text_happy_path_logs_items_and_sends(continuous_collector,
# make sure we attempted a send
continuous_collector.send.assert_awaited_once()
@pytest.mark.asyncio
async def test_send_beliefs_noop_on_empty(continuous_collector):
continuous_collector.send = AsyncMock()
await continuous_collector._send_beliefs_to_bdi([], origin="o")
continuous_collector.send.assert_not_awaited()
# @pytest.mark.asyncio
# async def test_send_beliefs_sends_json_packet(continuous_collector):
# # Patch .send and capture the message body
@@ -191,19 +219,22 @@ async def test_send_beliefs_noop_on_empty(continuous_collector):
# assert "belief_packet" in json.loads(sent["body"])["type"]
# assert json.loads(sent["body"])["beliefs"] == beliefs
def test_sender_node_no_sender_returns_literal(continuous_collector):
msg = MagicMock()
msg.sender = None
assert continuous_collector._sender_node(msg) == "no_sender"
def test_sender_node_without_at(continuous_collector):
msg = MagicMock()
msg.sender = "localpartonly"
assert continuous_collector._sender_node(msg) == "localpartonly"
@pytest.mark.asyncio
async def test_belief_text_coerces_non_strings(continuous_collector, mocker):
payload = {"type": "belief_extraction_text", "beliefs": {"user_said": [["hi", 123]]}}
continuous_collector.send = AsyncMock()
await continuous_collector._handle_belief_text(payload, "origin")
continuous_collector.send.assert_awaited_once()
continuous_collector.send.assert_awaited_once()


@@ -6,7 +6,7 @@ from control_backend.agents.transcription.speech_recognizer import OpenAIWhisper
def test_estimate_max_tokens():
"""Inputting one minute of audio, assuming 300 words per minute, expecting 400 tokens."""
audio = np.empty(shape=(60*16_000), dtype=np.float32)
audio = np.empty(shape=(60 * 16_000), dtype=np.float32)
actual = SpeechRecognizer._estimate_max_tokens(audio)
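
The arithmetic behind test_estimate_max_tokens: one minute of 16 kHz audio, at the docstring's assumed 300 words per minute and roughly 4/3 tokens per word (the ratio is an assumption here, not stated in the source), comes out to 400 tokens. A minimal sketch of that calculation; the real _estimate_max_tokens may differ in constants or rounding.

```python
import numpy as np

SAMPLE_RATE = 16_000    # samples per second, matching the test fixture (60 * 16_000 samples)
WORDS_PER_MINUTE = 300  # speaking-rate assumption from the test docstring


def estimate_max_tokens(audio: np.ndarray) -> int:
    """Rough upper bound on output tokens for a mono audio buffer."""
    minutes = len(audio) / SAMPLE_RATE / 60
    words = minutes * WORDS_PER_MINUTE
    return round(words * 4 / 3)  # ~4/3 tokens per word: an assumed ratio, not from the source


audio = np.empty(shape=(60 * SAMPLE_RATE), dtype=np.float32)
assert estimate_max_tokens(audio) == 400  # one minute of audio -> 400 tokens, as the test expects
```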
@@ -16,7 +16,7 @@ def test_estimate_max_tokens():
def test_get_decode_options():
"""Check whether the right decode options are given under different scenarios."""
audio = np.empty(shape=(60*16_000), dtype=np.float32)
audio = np.empty(shape=(60 * 16_000), dtype=np.float32)
# With the defaults, it should limit output length based on input size
recognizer = OpenAIWhisperSpeechRecognizer()