fix: fixed new tests and merged dev into branch

ref: N25B-256
Author: Björn Otgaar
Date: 2025-11-05 16:29:56 +01:00
29 changed files with 520 additions and 298 deletions

.githooks/check-branch-name.sh (new executable file, +77 lines)

@@ -0,0 +1,77 @@
#!/usr/bin/env bash
# This script checks if the current branch name follows the specified format.
# It's designed to be used as a 'pre-commit' git hook.
# Format: <type>/<short-description>
# Example: feat/add-user-login
# --- Configuration ---
# An array of allowed commit types
ALLOWED_TYPES=(feat fix refactor perf style test docs build chore revert)
# An array of branches to ignore
IGNORED_BRANCHES=(main dev demo)
# --- Colors for Output ---
RED='\033[0;31m'
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
NC='\033[0m' # No Color
# --- Helper Functions ---
error_exit() {
echo -e "${RED}ERROR: $1${NC}" >&2
echo -e "${YELLOW}Branch name format is incorrect. Aborting commit.${NC}" >&2
exit 1
}
# --- Main Logic ---
# 1. Get the current branch name
BRANCH_NAME=$(git symbolic-ref --short HEAD)
# 2. Check if the current branch is in the ignored list
for ignored_branch in "${IGNORED_BRANCHES[@]}"; do
if [ "$BRANCH_NAME" == "$ignored_branch" ]; then
echo -e "${GREEN}Branch check skipped for default branch: $BRANCH_NAME${NC}"
exit 0
fi
done
# 3. Validate the overall structure: <type>/<description>
if ! [[ "$BRANCH_NAME" =~ ^[a-z]+/.+$ ]]; then
error_exit "Branch name must be in the format: <type>/<short-description>\nExample: feat/add-user-login"
fi
# 4. Extract the type and description
TYPE=$(echo "$BRANCH_NAME" | cut -d'/' -f1)
DESCRIPTION=$(echo "$BRANCH_NAME" | cut -d'/' -f2-)
# 5. Validate the <type>
type_valid=false
for allowed_type in "${ALLOWED_TYPES[@]}"; do
if [ "$TYPE" == "$allowed_type" ]; then
type_valid=true
break
fi
done
if [ "$type_valid" == false ]; then
error_exit "Invalid type '$TYPE'.\nAllowed types are: ${ALLOWED_TYPES[*]}"
fi
# 6. Validate the <short-description>
# Regex breakdown:
# ^[a-z0-9]+ - Starts with one or more lowercase letters/numbers (the first word).
# (-[a-z0-9]+){0,5} - Followed by a group of (dash + word) 0 to 5 times.
# $ - End of the string.
# This entire pattern enforces 1 to 6 words total, separated by dashes.
DESCRIPTION_REGEX="^[a-z0-9]+(-[a-z0-9]+){0,5}$"
if ! [[ "$DESCRIPTION" =~ $DESCRIPTION_REGEX ]]; then
error_exit "Invalid short description '$DESCRIPTION'.\nIt must be a maximum of 6 words, all lowercase, separated by dashes.\nExample: add-new-user-authentication-feature"
fi
# If all checks pass, exit successfully
echo -e "${GREEN}Branch name '$BRANCH_NAME' is valid.${NC}"
exit 0
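
For a quick sanity check outside the hook, the same description rule can be exercised from Python, since this ERE is also valid for the `re` module; a minimal sketch with made-up branch descriptions:

```python
import re

# Same pattern as DESCRIPTION_REGEX in the hook above.
description_regex = re.compile(r"^[a-z0-9]+(-[a-z0-9]+){0,5}$")

# Hypothetical examples, for illustration only.
assert description_regex.match("add-user-login")                          # 3 words: valid
assert description_regex.match("fix")                                     # 1 word: valid
assert not description_regex.match("one-two-three-four-five-six-seven")   # 7 words: too many
assert not description_regex.match("Add-User-Login")                      # uppercase: invalid
```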

.githooks/check-commit-msg.sh (new executable file, +135 lines)

@@ -0,0 +1,135 @@
#!/usr/bin/env bash
# This script checks if a commit message follows the specified format.
# It's designed to be used as a 'commit-msg' git hook.
# Format:
# <type>: <short description>
#
# [optional]<body>
#
# [ref/close]: <issue identifier>
# --- Configuration ---
# An array of allowed commit types
ALLOWED_TYPES=(feat fix refactor perf style test docs build chore revert)
# --- Colors for Output ---
RED='\033[0;31m'
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
NC='\033[0m' # No Color
# The first argument to the hook is the path to the file containing the commit message
COMMIT_MSG_FILE=$1
# --- Automated Commit Detection ---
# Read the first line (header) for initial checks
HEADER=$(head -n 1 "$COMMIT_MSG_FILE")
# Check for Merge commits (covers 'git merge' and PR merges from GitHub/GitLab)
# Examples: "Merge branch 'main' into ...", "Merge pull request #123 from ..."
MERGE_PATTERN="^Merge (branch|pull request|tag) .*"
if [[ "$HEADER" =~ $MERGE_PATTERN ]]; then
echo -e "${GREEN}Merge commit detected by message content. Skipping validation.${NC}"
exit 0
fi
# Check for Revert commits
# Example: "Revert "feat: add new feature""
REVERT_PATTERN="^Revert \".*\""
if [[ "$HEADER" =~ $REVERT_PATTERN ]]; then
echo -e "${GREEN}Revert commit detected by message content. Skipping validation.${NC}"
exit 0
fi
# Check for Cherry-pick commits (this pattern appears at the end of the message)
# Example: "(cherry picked from commit deadbeef...)"
# We use grep -q to search the whole file quietly.
CHERRY_PICK_PATTERN="\(cherry picked from commit [a-f0-9]{7,40}\)"
if grep -qE "$CHERRY_PICK_PATTERN" "$COMMIT_MSG_FILE"; then
echo -e "${GREEN}Cherry-pick detected by message content. Skipping validation.${NC}"
exit 0
fi
# Check for Squash
# Example: "Squash commits ..."
SQUASH_PATTERN="^Squash .+"
if [[ "$HEADER" =~ $SQUASH_PATTERN ]]; then
echo -e "${GREEN}Squash commit detected by message content. Skipping validation.${NC}"
exit 0
fi
# --- Validation Functions ---
# Function to print an error message and exit
# Usage: error_exit "Your error message here"
error_exit() {
# >&2 redirects echo to stderr
echo -e "${RED}ERROR: $1${NC}" >&2
echo -e "${YELLOW}Commit message format is incorrect. Aborting commit.${NC}" >&2
exit 1
}
# --- Main Logic ---
# 1. Read the header (first line) of the commit message
HEADER=$(head -n 1 "$COMMIT_MSG_FILE")
# 2. Validate the header format: <type>: <description>
# Regex breakdown:
# ^(type1|type2|...) - Starts with one of the allowed types
# : - Followed by a literal colon
# \s - Followed by a single space
# .+ - Followed by one or more characters for the description
# $ - End of the line
TYPES_REGEX=$(
IFS="|"
echo "${ALLOWED_TYPES[*]}"
)
HEADER_REGEX="^($TYPES_REGEX): .+$"
if ! [[ "$HEADER" =~ $HEADER_REGEX ]]; then
error_exit "Invalid header format.\n\nHeader must be in the format: <type>: <short description>\nAllowed types: ${ALLOWED_TYPES[*]}\nExample: feat: add new user authentication feature"
fi
# Only validate footer if commit type is not chore
TYPE=$(echo "$HEADER" | cut -d':' -f1)
if [ "$TYPE" != "chore" ]; then
# 3. Validate the footer (last line) of the commit message
FOOTER=$(tail -n 1 "$COMMIT_MSG_FILE")
# Regex breakdown:
# ^(ref|close) - Starts with 'ref' or 'close'
# : - Followed by a literal colon
# \s - Followed by a single space
# N25B- - Followed by the literal string 'N25B-'
# [0-9]+ - Followed by one or more digits
# $ - End of the line
FOOTER_REGEX="^(ref|close): N25B-[0-9]+$"
if ! [[ "$FOOTER" =~ $FOOTER_REGEX ]]; then
error_exit "Invalid footer format.\n\nFooter must be in the format: [ref/close]: <issue identifier>\nExample: ref: N25B-123"
fi
fi
# 4. If the message has more than 2 lines, validate the separator
# A blank line must exist between the header and the body.
LINE_COUNT=$(wc -l <"$COMMIT_MSG_FILE" | xargs) # xargs trims whitespace
# We only care if there is a body. Header + Footer = 2 lines.
# Header + Blank Line + Body... + Footer > 2 lines.
if [ "$LINE_COUNT" -gt 2 ]; then
# Get the second line
SECOND_LINE=$(sed -n '2p' "$COMMIT_MSG_FILE")
# Check if the second line is NOT empty. If it's not, it's an error.
if [ -n "$SECOND_LINE" ]; then
error_exit "Missing blank line between header and body.\n\nThe second line of your commit message must be empty if a body is present."
fi
fi
# If all checks pass, exit with success
echo -e "${GREEN}Commit message is valid.${NC}"
exit 0
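
The header and footer rules translate to Python's `re` module in the same way; a small sketch for checking example messages (illustrative only, not part of the hook):

```python
import re

allowed_types = "feat|fix|refactor|perf|style|test|docs|build|chore|revert"
header_regex = re.compile(rf"^({allowed_types}): .+$")
footer_regex = re.compile(r"^(ref|close): N25B-[0-9]+$")

# Made-up commit message lines.
assert header_regex.match("feat: add new user authentication feature")
assert not header_regex.match("feature: add login")    # 'feature' is not an allowed type
assert footer_regex.match("ref: N25B-123")
assert not footer_regex.match("ref N25B-123")           # missing colon after 'ref'
```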


@@ -1,16 +0,0 @@
#!/bin/sh
commit_msg_file=$1
commit_msg=$(cat "$commit_msg_file")
if echo "$commit_msg" | grep -Eq "^(feat|fix|refactor|perf|style|test|docs|build|chore|revert): .+"; then
if echo "$commit_msg" | grep -Eq "^(ref|close):\sN25B-.+"; then
exit 0
else
echo "❌ Commit message invalid! Must end with [ref/close]: N25B-000"
exit 1
fi
else
echo "❌ Commit message invalid! Must start with <type>: <description>"
exit 1
fi


@@ -1,17 +0,0 @@
#!/bin/sh
# Get current branch
branch=$(git rev-parse --abbrev-ref HEAD)
if echo "$branch" | grep -Eq "(dev|main)"; then
echo 0
fi
# allowed pattern <type/>
if echo "$branch" | grep -Eq "^(feat|fix|refactor|perf|style|test|docs|build|chore|revert)\/\w+(-\w+){0,5}$"; then
exit 0
else
echo "❌ Invalid branch name: $branch"
echo "Branch must be named <type>/<description-of-branch> (must have one to six words separated by a dash)"
exit 1
fi


@@ -1,9 +0,0 @@
#!/bin/sh
echo "#<type>: <description>
#[optional body]
#[optional footer(s)]
#[ref/close]: <issue identifier>" > $1


@@ -1,10 +1,24 @@
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.14.2
hooks:
# Run the linter.
- id: ruff-check
args: [ --fix ]
# Run the formatter.
- id: ruff-format
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.14.2
hooks:
# Run the linter.
- id: ruff-check
# Run the formatter.
- id: ruff-format
# Configure local hooks
- repo: local
hooks:
- id: check-commit-msg
name: Check commit message format
entry: .githooks/check-commit-msg.sh
language: script
stages: [commit-msg]
- id: check-branch-name
name: Check branch name format
entry: .githooks/check-branch-name.sh
language: script
stages: [pre-commit]
always_run: true
pass_filenames: false


@@ -47,21 +47,19 @@ Or for integration tests:
uv run --group integration-test pytest test/integration
```
## GitHooks
## Git Hooks
To activate automatic commits/branch name checks run:
To activate automatic linting, formatting, branch name checks and commit message checks, run:
```shell
git config --local core.hooksPath .githooks
```bash
uv run pre-commit install
uv run pre-commit install --hook-type commit-msg
```
If your commit fails its either:
branch name != <type>/description-of-branch ,
commit name != <type>: description of the commit.
<ref>: N25B-Num's
You might get an error along the lines of `Can't install pre-commit with core.hooksPath set`. To fix this, simply unset the hooksPath by running:
To add automatic linting and formatting, run:
```bash
git config --local --unset core.hooksPath
```
```shell
uv run pre-commit install
```
Then run the pre-commit install commands again.


@@ -1,8 +1,10 @@
import logging
from spade.behaviour import CyclicBehaviour
from spade.message import Message
from control_backend.core.config import settings
from control_backend.schemas.ri_message import SpeechCommand
class ReceiveLLMResponseBehaviour(CyclicBehaviour):
@@ -10,7 +12,7 @@ class ReceiveLLMResponseBehaviour(CyclicBehaviour):
Adds behavior to receive responses from the LLM Agent.
"""
logger = logging.getLogger("BDI/LLM Reciever")
logger = logging.getLogger("BDI/LLM Receiver")
async def run(self):
msg = await self.receive(timeout=2)
@@ -22,7 +24,20 @@ class ReceiveLLMResponseBehaviour(CyclicBehaviour):
case settings.agent_settings.llm_agent_name:
content = msg.body
self.logger.info("Received LLM response: %s", content)
# Here the BDI can pass the message back as a response
speech_command = SpeechCommand(data=content)
message = Message(
to=settings.agent_settings.ri_command_agent_name
+ "@"
+ settings.agent_settings.host,
sender=self.agent.jid,
body=speech_command.model_dump_json(),
)
self.logger.debug("Sending message: %s", message)
await self.send(message)
case _:
self.logger.debug("Not from the llm, discarding message")
pass


@@ -13,23 +13,23 @@ class BeliefFromText(CyclicBehaviour):
# TODO: LLM prompt still hardcoded
llm_instruction_prompt = """
You are an information extraction assistant for a BDI agent.
Your task is to extract values from a user's text to bind a list of ungrounded beliefs. Rules:
You will receive a JSON object with "beliefs"
(a list of ungrounded AgentSpeak beliefs) and "text" (user's transcript).
You are an information extraction assistant for a BDI agent. Your task is to extract values \
from a user's text to bind a list of ungrounded beliefs. Rules:
You will receive a JSON object with "beliefs" (a list of ungrounded AgentSpeak beliefs) \
and "text" (user's transcript).
Analyze the text to find values that semantically match the variables (X,Y,Z) in the beliefs.
A single piece of text might contain multiple instances that match a belief.
Respond ONLY with a single JSON object.
The JSON object's keys should be the belief functors (e.g., "weather").
The value for each key must be a list of lists.
Each inner list must contain the extracted arguments
(as strings) for one instance of that belief.
CRITICAL: If no information in the text matches a belief,
DO NOT include that key in your response.
Each inner list must contain the extracted arguments (as strings) for one instance \
of that belief.
CRITICAL: If no information in the text matches a belief, DO NOT include that key \
in your response.
"""
# on_start agent receives message containing the beliefs to look out
# for and sets up the LLM with instruction prompt
# on_start agent receives message containing the beliefs to look out for and
# sets up the LLM with instruction prompt
# async def on_start(self):
# msg = await self.receive(timeout=0.1)
# self.beliefs = dict uit message


@@ -70,8 +70,7 @@ class ContinuousBeliefCollector(CyclicBehaviour):
Expected payload:
{
"type": "belief_extraction_text",
"beliefs": {"user_said": ["hello"","Can you help me?",
"stop talking to me","No","Pepper do a dance"]}
"beliefs": {"user_said": ["Can you help me?"]}
}


@@ -3,8 +3,10 @@ LLM Agent module for routing text queries from the BDI Core Agent to a local LLM
service and returning its responses back to the BDI Core Agent.
"""
import json
import logging
from typing import Any
import re
from collections.abc import AsyncGenerator
import httpx
from spade.agent import Agent
@@ -54,11 +56,16 @@ class LLMAgent(Agent):
async def _process_bdi_message(self, message: Message):
"""
Forwards user text to the LLM and replies with the generated text.
Forwards user text from the BDI to the LLM and replies with the generated text in chunks
separated by punctuation.
"""
user_text = message.body
llm_response = await self._query_llm(user_text)
await self._reply(llm_response)
# Consume the streaming generator and send a reply for every chunk
async for chunk in self._query_llm(user_text):
await self._reply(chunk)
self.agent.logger.debug(
"Finished processing BDI message. Response sent in chunks to BDI Core Agent."
)
async def _reply(self, msg: str):
"""
@@ -69,48 +76,89 @@ class LLMAgent(Agent):
body=msg,
)
await self.send(reply)
self.agent.logger.info("Reply sent to BDI Core Agent")
async def _query_llm(self, prompt: str) -> str:
async def _query_llm(self, prompt: str) -> AsyncGenerator[str]:
"""
Sends a chat completion request to the local LLM service.
Sends a chat completion request to the local LLM service and streams the response by
yielding fragments separated by punctuation.
:param prompt: Input text prompt to pass to the LLM.
:return: LLM-generated content or fallback message.
:yield: Fragments of the LLM-generated content.
"""
async with httpx.AsyncClient(timeout=120.0) as client:
# Example dynamic content for future (optional)
instructions = LLMInstructions(
"- Be friendly and respectful.\n"
"- Make the conversation feel natural and engaging.\n"
"- Speak like a pirate.\n"
"- When the user asks what you can do, tell them.",
"- Try to learn the user's name during conversation.\n"
"- Suggest playing a game of asking yes or no questions where you think of a word "
"and the user must guess it.",
)
messages = [
{
"role": "developer",
"content": instructions.build_developer_instruction(),
},
{
"role": "user",
"content": prompt,
},
]
instructions = LLMInstructions()
developer_instruction = instructions.build_developer_instruction()
try:
current_chunk = ""
async for token in self._stream_query_llm(messages):
current_chunk += token
response = await client.post(
# Stream the message in chunks separated by punctuation.
# We include the delimiter in the emitted chunk for natural flow.
pattern = re.compile(r".*?(?:,|;|:|—|–|\.{3}|…|\.|\?|!)\s*", re.DOTALL)
for m in pattern.finditer(current_chunk):
chunk = m.group(0)
if chunk:
yield current_chunk
current_chunk = ""
# Yield any remaining tail
if current_chunk:
yield current_chunk
except httpx.HTTPError as err:
self.agent.logger.error("HTTP error.", exc_info=err)
yield "LLM service unavailable."
except Exception as err:
self.agent.logger.error("Unexpected error.", exc_info=err)
yield "Error processing the request."
async def _stream_query_llm(self, messages) -> AsyncGenerator[str]:
"""Raises httpx.HTTPError when the API gives an error."""
async with httpx.AsyncClient(timeout=None) as client:
async with client.stream(
"POST",
settings.llm_settings.local_llm_url,
headers={"Content-Type": "application/json"},
json={
"model": settings.llm_settings.local_llm_model,
"messages": [
{"role": "developer", "content": developer_instruction},
{"role": "user", "content": prompt},
],
"messages": messages,
"temperature": 0.3,
"stream": True,
},
)
try:
) as response:
response.raise_for_status()
data: dict[str, Any] = response.json()
return (
data.get("choices", [{}])[0]
.get("message", {})
.get("content", "No response")
)
except httpx.HTTPError as err:
self.agent.logger.error("HTTP error: %s", err)
return "LLM service unavailable."
except Exception as err:
self.agent.logger.error("Unexpected error: %s", err)
return "Error processing the request."
async for line in response.aiter_lines():
if not line or not line.startswith("data: "):
continue
data = line[len("data: ") :]
if data.strip() == "[DONE]":
break
try:
event = json.loads(data)
delta = event.get("choices", [{}])[0].get("delta", {}).get("content")
if delta:
yield delta
except json.JSONDecodeError:
self.agent.logger.error("Failed to parse LLM response: %s", data)
async def setup(self):
"""


@@ -28,7 +28,9 @@ class LLMInstructions:
"""
sections = [
"You are a Pepper robot engaging in natural human conversation.",
"Keep responses between 15 sentences, unless instructed otherwise.\n",
"Keep responses between 13 sentences, unless told otherwise.\n",
"You're given goals to reach. Reach them in order, but make the conversation feel "
"natural. Some turns you should not try to achieve your goals.\n",
]
if self.norms:


@@ -11,8 +11,9 @@ class BeliefTextAgent(Agent):
class SendOnceBehaviourBlfText(OneShotBehaviour):
async def run(self):
to_jid = (
f"{settings.agent_settings.belief_collector_agent_name}"
f"@{settings.agent_settings.host}"
settings.agent_settings.belief_collector_agent_name
+ "@"
+ settings.agent_settings.host
)
# Send multiple beliefs in one JSON payload


@@ -1,6 +1,7 @@
import json
import logging
import spade.agent
import zmq
from spade.agent import Agent
from spade.behaviour import CyclicBehaviour
@@ -32,6 +33,8 @@ class RICommandAgent(Agent):
self.bind = bind
class SendCommandsBehaviour(CyclicBehaviour):
"""Behaviour for sending commands received from the UI."""
async def run(self):
"""
Run the command publishing loop indefinitely.
@@ -50,6 +53,18 @@ class RICommandAgent(Agent):
except Exception as e:
logger.error("Error processing message: %s", e)
class SendPythonCommandsBehaviour(CyclicBehaviour):
"""Behaviour for sending commands received from other Python agents."""
async def run(self):
message: spade.agent.Message = await self.receive(timeout=0.1)
if message and message.to == self.agent.jid:
try:
speech_command = SpeechCommand.model_validate_json(message.body)
await self.agent.pubsocket.send_json(speech_command.model_dump())
except Exception as e:
logger.error("Error processing message: %s", e)
async def setup(self):
"""
Setup the command agent
@@ -73,5 +88,6 @@ class RICommandAgent(Agent):
# Add behaviour to our agent
commands_behaviour = self.SendCommandsBehaviour()
self.add_behaviour(commands_behaviour)
self.add_behaviour(self.SendPythonCommandsBehaviour())
logger.info("Finished setting up %s", self.jid)


@@ -63,7 +63,25 @@ class RICommunicationAgent(Agent):
# We didn't get a reply :(
except TimeoutError:
logger.info("No ping retrieved in 3 seconds, killing myself.")
self.kill()
# Tell UI we're disconnected.
topic = b"ping"
data = json.dumps(False).encode()
if self.agent.pub_socket is None:
logger.error("communication agent pub socket not correctly initialized.")
else:
try:
await asyncio.wait_for(
self.agent.pub_socket.send_multipart([topic, data]), 5
)
except TimeoutError:
logger.error(
"Initial connection ping for router timed"
" out in ri_communication_agent."
)
# Try to reboot.
await self.agent.setup()
logger.debug('Received message "%s"', message)
if "endpoint" not in message:


@@ -36,16 +36,16 @@ class SpeechRecognizer(abc.ABC):
def _estimate_max_tokens(audio: np.ndarray) -> int:
"""
Estimate the maximum length of a given audio sample in tokens. Assumes a maximum speaking
rate of 300 words per minute (2x average), and assumes that 3 words is 4 tokens.
rate of 450 words per minute (3x average), and assumes that 3 words is 4 tokens.
:param audio: The audio sample (16 kHz) to use for length estimation.
:return: The estimated length of the transcribed audio in tokens.
"""
length_seconds = len(audio) / 16_000
length_minutes = length_seconds / 60
word_count = length_minutes * 300
word_count = length_minutes * 450
token_count = word_count / 3 * 4
return int(token_count)
return int(token_count) + 10
def _get_decode_options(self, audio: np.ndarray) -> dict:
"""
@@ -85,9 +85,10 @@ class MLXWhisperSpeechRecognizer(SpeechRecognizer):
def recognize_speech(self, audio: np.ndarray) -> str:
self.load_model()
return mlx_whisper.transcribe(
audio, path_or_hf_repo=self.model_name, decode_options=self._get_decode_options(audio)
)["text"]
return mlx_whisper.transcribe(audio, path_or_hf_repo=self.model_name)["text"].strip()
audio,
path_or_hf_repo=self.model_name,
**self._get_decode_options(audio),
)["text"].strip()
class OpenAIWhisperSpeechRecognizer(SpeechRecognizer):
@@ -103,6 +104,4 @@ class OpenAIWhisperSpeechRecognizer(SpeechRecognizer):
def recognize_speech(self, audio: np.ndarray) -> str:
self.load_model()
return whisper.transcribe(
self.model, audio, decode_options=self._get_decode_options(audio)
)["text"]
return whisper.transcribe(self.model, audio, **self._get_decode_options(audio))["text"]
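
As a worked example of the updated estimate: one minute of 16 kHz audio now yields 610 tokens (450 words at the assumed maximum rate, 4 tokens per 3 words, plus 10 tokens of padding), which matches the updated unit test later in this diff:

```python
import numpy as np

# One minute of silence at 16 kHz, just to exercise the length-based estimate.
audio = np.zeros(60 * 16_000, dtype=np.float32)

length_minutes = len(audio) / 16_000 / 60      # 1.0 minute
word_count = length_minutes * 450              # 450 words at the assumed maximum rate
token_count = int(word_count / 3 * 4) + 10     # 600 tokens + 10 padding
print(token_count)                             # 610
```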


@@ -58,6 +58,10 @@ class TranscriptionAgent(Agent):
audio = await self.audio_in_socket.recv()
audio = np.frombuffer(audio, dtype=np.float32)
speech = await self._transcribe(audio)
if not speech:
logger.info("Nothing transcribed.")
return
logger.info("Transcribed speech: %s", speech)
await self._share_transcription(speech)


@@ -54,8 +54,20 @@ class Streaming(CyclicBehaviour):
self.audio_buffer = np.array([], dtype=np.float32)
self.i_since_speech = 100 # Used to allow small pauses in speech
self._ready = False
async def reset(self):
"""Clears the ZeroMQ queue and tells this behavior to start."""
discarded = 0
while await self.audio_in_poller.poll(1) is not None:
discarded += 1
logging.info(f"Discarded {discarded} audio packets before starting.")
self._ready = True
async def run(self) -> None:
if not self._ready:
return
data = await self.audio_in_poller.poll()
if data is None:
if len(self.audio_buffer) > 0:
@@ -107,6 +119,8 @@ class VADAgent(Agent):
self.audio_in_socket: azmq.Socket | None = None
self.audio_out_socket: azmq.Socket | None = None
self.streaming_behaviour: Streaming | None = None
async def stop(self):
"""
Stop listening to audio, stop publishing audio, close sockets.
@@ -149,8 +163,8 @@ class VADAgent(Agent):
return
audio_out_address = f"tcp://localhost:{audio_out_port}"
streaming = Streaming(self.audio_in_socket, self.audio_out_socket)
self.add_behaviour(streaming)
self.streaming_behaviour = Streaming(self.audio_in_socket, self.audio_out_socket)
self.add_behaviour(self.streaming_behaviour)
# Start agents dependent on the output audio fragments here
transcriber = TranscriptionAgent(audio_out_address)


@@ -22,8 +22,8 @@ async def receive_command(command: SpeechCommand, request: Request):
topic = b"command"
# TODO: Check with Kasper
pub_socket: Socket = request.app.state.internal_comm_socket
pub_socket.send_multipart([topic, command.model_dump_json().encode()])
pub_socket: Socket = request.app.state.endpoints_pub_socket
await pub_socket.send_multipart([topic, command.model_dump_json().encode()])
return {"status": "Command received"}


@@ -14,8 +14,6 @@ from control_backend.agents.bdi.bdi_core import BDICoreAgent
from control_backend.agents.bdi.text_extractor import TBeliefExtractor
from control_backend.agents.belief_collector.belief_collector import BeliefCollectorAgent
from control_backend.agents.llm.llm import LLMAgent
# Internal imports
from control_backend.agents.ri_communication_agent import RICommunicationAgent
from control_backend.agents.vad_agent import VADAgent
from control_backend.api.v1.router import api_router
@@ -99,6 +97,8 @@ async def lifespan(app: FastAPI):
_temp_vad_agent = VADAgent("tcp://localhost:5558", False)
await _temp_vad_agent.start()
logger.info("VAD agent started, now making ready...")
await _temp_vad_agent.streaming_behaviour.reset()
yield


@@ -7,25 +7,21 @@ import zmq
from control_backend.agents.ri_command_agent import RICommandAgent
@pytest.mark.asyncio
async def test_setup_bind(monkeypatch):
"""Test setup with bind=True"""
fake_socket = MagicMock()
fake_context = MagicMock()
fake_context.socket.return_value = fake_socket
@pytest.fixture
def zmq_context(mocker):
mock_context = mocker.patch("control_backend.agents.vad_agent.azmq.Context.instance")
mock_context.return_value = MagicMock()
return mock_context
# Patch Context.instance() to return fake_context
monkeypatch.setattr(
"control_backend.agents.ri_command_agent.Context",
MagicMock(instance=MagicMock(return_value=fake_context)),
)
@pytest.mark.asyncio
async def test_setup_bind(zmq_context, mocker):
"""Test setup with bind=True"""
fake_socket = zmq_context.return_value.socket.return_value
agent = RICommandAgent("test@server", "password", address="tcp://localhost:5555", bind=True)
monkeypatch.setattr(
"control_backend.agents.ri_command_agent.settings",
MagicMock(zmq_settings=MagicMock(internal_sub_address="tcp://internal:1234")),
)
settings = mocker.patch("control_backend.agents.ri_command_agent.settings")
settings.zmq_settings.internal_sub_address = "tcp://internal:1234"
await agent.setup()
@@ -36,23 +32,13 @@ async def test_setup_bind(monkeypatch):
@pytest.mark.asyncio
async def test_setup_connect(monkeypatch):
async def test_setup_connect(zmq_context, mocker):
"""Test setup with bind=False"""
fake_socket = MagicMock()
fake_context = MagicMock()
fake_context.socket.return_value = fake_socket
# Patch Context.instance() to return fake_context
monkeypatch.setattr(
"control_backend.agents.ri_command_agent.Context",
MagicMock(instance=MagicMock(return_value=fake_context)),
)
fake_socket = zmq_context.return_value.socket.return_value
agent = RICommandAgent("test@server", "password", address="tcp://localhost:5555", bind=False)
monkeypatch.setattr(
"control_backend.agents.ri_command_agent.settings",
MagicMock(zmq_settings=MagicMock(internal_sub_address="tcp://internal:1234")),
)
settings = mocker.patch("control_backend.agents.ri_command_agent.settings")
settings.zmq_settings.internal_sub_address = "tcp://internal:1234"
await agent.setup()


@@ -84,25 +84,24 @@ def fake_json_invalid_id_negototiate():
)
@pytest.fixture
def zmq_context(mocker):
mock_context = mocker.patch("control_backend.agents.vad_agent.azmq.Context.instance")
mock_context.return_value = MagicMock()
return mock_context
@pytest.mark.asyncio
async def test_setup_creates_socket_and_negotiate_1(monkeypatch):
async def test_setup_creates_socket_and_negotiate_1(zmq_context):
"""
Test the setup of the communication agent
"""
# --- Arrange ---
fake_socket = MagicMock()
fake_socket = zmq_context.return_value.socket.return_value
fake_socket.send_json = AsyncMock()
fake_socket.recv_json = fake_json_correct_negototiate_1()
fake_socket.send_multipart = AsyncMock()
# Mock context.socket to return our fake socket
fake_context = MagicMock()
fake_context.socket.return_value = fake_socket
monkeypatch.setattr(
"control_backend.agents.ri_communication_agent.Context",
MagicMock(instance=MagicMock(return_value=fake_context)),
)
# Mock RICommandAgent agent startup
with patch(
"control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True
@@ -135,24 +134,16 @@ async def test_setup_creates_socket_and_negotiate_1(monkeypatch):
@pytest.mark.asyncio
async def test_setup_creates_socket_and_negotiate_2(monkeypatch):
async def test_setup_creates_socket_and_negotiate_2(zmq_context):
"""
Test the setup of the communication agent
"""
# --- Arrange ---
fake_socket = MagicMock()
fake_socket = zmq_context.return_value.socket.return_value
fake_socket.send_json = AsyncMock()
fake_socket.recv_json = fake_json_correct_negototiate_2()
fake_socket.send_multipart = AsyncMock()
# Mock context.socket to return our fake socket
fake_context = MagicMock()
fake_context.socket.return_value = fake_socket
monkeypatch.setattr(
"control_backend.agents.ri_communication_agent.Context",
MagicMock(instance=MagicMock(return_value=fake_context)),
)
# Mock RICommandAgent agent startup
with patch(
"control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True
@@ -185,24 +176,16 @@ async def test_setup_creates_socket_and_negotiate_2(monkeypatch):
@pytest.mark.asyncio
async def test_setup_creates_socket_and_negotiate_3(monkeypatch, caplog):
async def test_setup_creates_socket_and_negotiate_3(zmq_context, caplog):
"""
Test the functionality of setup with incorrect negotiation message
"""
# --- Arrange ---
fake_socket = MagicMock()
fake_socket = zmq_context.return_value.socket.return_value
fake_socket.send_json = AsyncMock()
fake_socket.recv_json = fake_json_wrong_negototiate_1()
fake_socket.send_multipart = AsyncMock()
# Mock context.socket to return our fake socket
fake_context = MagicMock()
fake_context.socket.return_value = fake_socket
monkeypatch.setattr(
"control_backend.agents.ri_communication_agent.Context",
MagicMock(instance=MagicMock(return_value=fake_context)),
)
# Mock RICommandAgent agent startup
# We are sending wrong negotiation info to the communication agent,
@@ -235,24 +218,16 @@ async def test_setup_creates_socket_and_negotiate_3(monkeypatch, caplog):
@pytest.mark.asyncio
async def test_setup_creates_socket_and_negotiate_4(monkeypatch):
async def test_setup_creates_socket_and_negotiate_4(zmq_context):
"""
Test the setup of the communication agent with different bind value
"""
# --- Arrange ---
fake_socket = MagicMock()
fake_socket = zmq_context.return_value.socket.return_value
fake_socket.send_json = AsyncMock()
fake_socket.recv_json = fake_json_correct_negototiate_3()
fake_socket.send_multipart = AsyncMock()
# Mock context.socket to return our fake socket
fake_context = MagicMock()
fake_context.socket.return_value = fake_socket
monkeypatch.setattr(
"control_backend.agents.ri_communication_agent.Context",
MagicMock(instance=MagicMock(return_value=fake_context)),
)
# Mock RICommandAgent agent startup
with patch(
"control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True
@@ -284,24 +259,16 @@ async def test_setup_creates_socket_and_negotiate_4(monkeypatch):
@pytest.mark.asyncio
async def test_setup_creates_socket_and_negotiate_5(monkeypatch):
async def test_setup_creates_socket_and_negotiate_5(zmq_context):
"""
Test the setup of the communication agent
"""
# --- Arrange ---
fake_socket = MagicMock()
fake_socket = zmq_context.return_value.socket.return_value
fake_socket.send_json = AsyncMock()
fake_socket.recv_json = fake_json_correct_negototiate_4()
fake_socket.send_multipart = AsyncMock()
# Mock context.socket to return our fake socket
fake_context = MagicMock()
fake_context.socket.return_value = fake_socket
monkeypatch.setattr(
"control_backend.agents.ri_communication_agent.Context",
MagicMock(instance=MagicMock(return_value=fake_context)),
)
# Mock RICommandAgent agent startup
with patch(
"control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True
@@ -333,24 +300,16 @@ async def test_setup_creates_socket_and_negotiate_5(monkeypatch):
@pytest.mark.asyncio
async def test_setup_creates_socket_and_negotiate_6(monkeypatch):
async def test_setup_creates_socket_and_negotiate_6(zmq_context):
"""
Test the setup of the communication agent
"""
# --- Arrange ---
fake_socket = MagicMock()
fake_socket = zmq_context.return_value.socket.return_value
fake_socket.send_json = AsyncMock()
fake_socket.recv_json = fake_json_correct_negototiate_5()
fake_socket.send_multipart = AsyncMock()
# Mock context.socket to return our fake socket
fake_context = MagicMock()
fake_context.socket.return_value = fake_socket
monkeypatch.setattr(
"control_backend.agents.ri_communication_agent.Context",
MagicMock(instance=MagicMock(return_value=fake_context)),
)
# Mock RICommandAgent agent startup
with patch(
"control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True
@@ -382,28 +341,20 @@ async def test_setup_creates_socket_and_negotiate_6(monkeypatch):
@pytest.mark.asyncio
async def test_setup_creates_socket_and_negotiate_7(monkeypatch, caplog):
async def test_setup_creates_socket_and_negotiate_7(zmq_context, caplog):
"""
Test the functionality of setup with incorrect id
"""
# --- Arrange ---
fake_socket = MagicMock()
fake_socket = zmq_context.return_value.socket.return_value
fake_socket.send_json = AsyncMock()
fake_socket.recv_json = fake_json_invalid_id_negototiate()
fake_socket.send_multipart = AsyncMock()
# Mock context.socket to return our fake socket
fake_context = MagicMock()
fake_context.socket.return_value = fake_socket
monkeypatch.setattr(
"control_backend.agents.ri_communication_agent.Context",
MagicMock(instance=MagicMock(return_value=fake_context)),
)
# Mock RICommandAgent agent startup
# We are sending wrong negotiation info to the communication agent,
# so we should retry and expect a etter response, within a limited time.
# so we should retry and expect a better response, within a limited time.
with patch(
"control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True
) as MockCommandAgent:
@@ -430,24 +381,16 @@ async def test_setup_creates_socket_and_negotiate_7(monkeypatch, caplog):
@pytest.mark.asyncio
async def test_setup_creates_socket_and_negotiate_timeout(monkeypatch, caplog):
async def test_setup_creates_socket_and_negotiate_timeout(zmq_context, caplog):
"""
Test the functionality of setup with incorrect negotiation message
"""
# --- Arrange ---
fake_socket = MagicMock()
fake_socket = zmq_context.return_value.socket.return_value
fake_socket.send_json = AsyncMock()
fake_socket.recv_json = AsyncMock(side_effect=asyncio.TimeoutError)
fake_socket.send_multipart = AsyncMock()
# Mock context.socket to return our fake socket
fake_context = MagicMock()
fake_context.socket.return_value = fake_socket
monkeypatch.setattr(
"control_backend.agents.ri_communication_agent.Context",
MagicMock(instance=MagicMock(return_value=fake_context)),
)
with patch(
"control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True
) as MockCommandAgent:
@@ -534,8 +477,8 @@ async def test_listen_behaviour_ping_wrong_endpoint(caplog):
@pytest.mark.asyncio
async def test_listen_behaviour_timeout(caplog):
fake_socket = AsyncMock()
async def test_listen_behaviour_timeout(zmq_context, caplog):
fake_socket = zmq_context.return_value.socket.return_value
fake_socket.send_json = AsyncMock()
# recv_json will never resolve, simulate timeout
fake_socket.recv_json = AsyncMock(side_effect=asyncio.TimeoutError)
@@ -585,20 +528,13 @@ async def test_listen_behaviour_ping_no_endpoint(caplog):
@pytest.mark.asyncio
async def test_setup_unexpected_exception(monkeypatch, caplog):
fake_socket = MagicMock()
async def test_setup_unexpected_exception(zmq_context, caplog):
fake_socket = zmq_context.return_value.socket.return_value
fake_socket.send_json = AsyncMock()
# Simulate unexpected exception during recv_json()
fake_socket.recv_json = AsyncMock(side_effect=Exception("boom!"))
fake_socket.send_multipart = AsyncMock()
fake_context = MagicMock()
fake_context.socket.return_value = fake_socket
monkeypatch.setattr(
"control_backend.agents.ri_communication_agent.Context",
MagicMock(instance=MagicMock(return_value=fake_context)),
)
agent = RICommunicationAgent(
"test@server",
"password",
@@ -614,9 +550,9 @@ async def test_setup_unexpected_exception(monkeypatch, caplog):
@pytest.mark.asyncio
async def test_setup_unpacking_exception(monkeypatch, caplog):
async def test_setup_unpacking_exception(zmq_context, caplog):
# --- Arrange ---
fake_socket = MagicMock()
fake_socket = zmq_context.return_value.socket.return_value
fake_socket.send_json = AsyncMock()
fake_socket.send_multipart = AsyncMock()
@@ -627,14 +563,6 @@ async def test_setup_unpacking_exception(monkeypatch, caplog):
} # missing 'port' and 'bind'
fake_socket.recv_json = AsyncMock(return_value=malformed_data)
# Patch context.socket
fake_context = MagicMock()
fake_context.socket.return_value = fake_socket
monkeypatch.setattr(
"control_backend.agents.ri_communication_agent.Context",
MagicMock(instance=MagicMock(return_value=fake_context)),
)
# Patch RICommandAgent so it won't actually start
with patch(
"control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True


@@ -10,7 +10,9 @@ from control_backend.agents.vad_agent import VADAgent
@pytest.fixture
def zmq_context(mocker):
return mocker.patch("control_backend.agents.vad_agent.zmq_context")
mock_context = mocker.patch("control_backend.agents.vad_agent.azmq.Context.instance")
mock_context.return_value = MagicMock()
return mock_context
@pytest.fixture
@@ -54,13 +56,18 @@ def test_in_socket_creation(zmq_context, do_bind: bool):
assert vad_agent.audio_in_socket is not None
zmq_context.socket.assert_called_once_with(zmq.SUB)
zmq_context.socket.return_value.setsockopt_string.assert_called_once_with(zmq.SUBSCRIBE, "")
zmq_context.return_value.socket.assert_called_once_with(zmq.SUB)
zmq_context.return_value.socket.return_value.setsockopt_string.assert_called_once_with(
zmq.SUBSCRIBE,
"",
)
if do_bind:
zmq_context.socket.return_value.bind.assert_called_once_with("tcp://*:12345")
zmq_context.return_value.socket.return_value.bind.assert_called_once_with("tcp://*:12345")
else:
zmq_context.socket.return_value.connect.assert_called_once_with("tcp://localhost:12345")
zmq_context.return_value.socket.return_value.connect.assert_called_once_with(
"tcp://localhost:12345"
)
def test_out_socket_creation(zmq_context):
@@ -73,8 +80,8 @@ def test_out_socket_creation(zmq_context):
assert vad_agent.audio_out_socket is not None
zmq_context.socket.assert_called_once_with(zmq.PUB)
zmq_context.socket.return_value.bind_to_random_port.assert_called_once()
zmq_context.return_value.socket.assert_called_once_with(zmq.PUB)
zmq_context.return_value.socket.return_value.bind_to_random_port.assert_called_once()
@pytest.mark.asyncio
@@ -83,7 +90,9 @@ async def test_out_socket_creation_failure(zmq_context):
Test setup failure when the audio output socket cannot be created.
"""
with patch.object(Agent, "stop", new_callable=AsyncMock) as mock_super_stop:
zmq_context.socket.return_value.bind_to_random_port.side_effect = zmq.ZMQBindError
zmq_context.return_value.socket.return_value.bind_to_random_port.side_effect = (
zmq.ZMQBindError
)
vad_agent = VADAgent("tcp://localhost:12345", False)
await vad_agent.setup()
@@ -98,11 +107,14 @@ async def test_stop(zmq_context, transcription_agent):
Test that when the VAD agent is stopped, the sockets are closed correctly.
"""
vad_agent = VADAgent("tcp://localhost:12345", False)
zmq_context.socket.return_value.bind_to_random_port.return_value = random.randint(1000, 10000)
zmq_context.return_value.socket.return_value.bind_to_random_port.return_value = random.randint(
1000,
10000,
)
await vad_agent.setup()
await vad_agent.stop()
assert zmq_context.socket.return_value.close.call_count == 2
assert zmq_context.return_value.socket.return_value.close.call_count == 2
assert vad_agent.audio_in_socket is None
assert vad_agent.audio_out_socket is None


@@ -48,6 +48,7 @@ async def test_real_audio(mocker):
audio_out_socket = AsyncMock()
vad_streamer = Streaming(audio_in_socket, audio_out_socket)
vad_streamer._ready = True
for _ in audio_chunks:
await vad_streamer.run()


@@ -1,4 +1,4 @@
from unittest.mock import MagicMock
from unittest.mock import AsyncMock
import pytest
from fastapi import FastAPI
@@ -16,7 +16,6 @@ def app():
"""
app = FastAPI()
app.include_router(robot.router)
app.state.internal_comm_socket = MagicMock() # mock ZMQ socket
return app
@@ -26,32 +25,30 @@ def client(app):
return TestClient(app)
def test_receive_command_endpoint(client, app):
def test_receive_command_success(client):
"""
Test that a POST to /command sends the right multipart message
and returns a 202 with the expected JSON body.
Test for successful reception of a command. Ensures the status code is 202 and the response body
is correct. It also verifies that the ZeroMQ socket's send_multipart method is called with the
expected data.
"""
mock_socket = app.state.internal_comm_socket
# Arrange
mock_pub_socket = AsyncMock()
client.app.state.endpoints_pub_socket = mock_pub_socket
# Prepare test payload that matches SpeechCommand
payload = {"endpoint": "actuate/speech", "data": "yooo"}
command_data = {"endpoint": "actuate/speech", "data": "This is a test"}
speech_command = SpeechCommand(**command_data)
# Send POST request
response = client.post("/command", json=payload)
# Act
response = client.post("/command", json=command_data)
# Check response
# Assert
assert response.status_code == 202
assert response.json() == {"status": "Command received"}
# Verify that the socket was called with the correct data
assert mock_socket.send_multipart.called, "Socket should be used to send data"
args, kwargs = mock_socket.send_multipart.call_args
sent_data = args[0]
assert sent_data[0] == b"command"
# Check JSON encoding roughly matches
assert isinstance(SpeechCommand.model_validate_json(sent_data[1].decode()), SpeechCommand)
# Verify that the ZMQ socket was used correctly
mock_pub_socket.send_multipart.assert_awaited_once_with(
[b"command", speech_command.model_dump_json().encode()]
)
def test_receive_command_invalid_payload(client):


@@ -16,12 +16,11 @@ def test_valid_speech_command_1():
command = valid_command_1()
RIMessage.model_validate(command)
SpeechCommand.model_validate(command)
assert True
def test_invalid_speech_command_1():
command = invalid_command_1()
RIMessage.model_validate(command)
with pytest.raises(ValidationError):
SpeechCommand.model_validate(command)
assert True


@@ -182,8 +182,6 @@ async def test_belief_text_values_not_lists(continuous_collector, mocker):
@pytest.mark.asyncio
async def test_belief_text_happy_path_logs_items_and_sends(continuous_collector, mocker):
payload = {"type": "belief_extraction_text", "beliefs": {"user_said": ["hello test", "No"]}}
# Your code calls self.send(..); patch it
# (or switch implementation to self.agent.send and patch that)
continuous_collector.send = AsyncMock()
logger_mock = mocker.patch(
"control_backend.agents.belief_collector.behaviours.continuous_collect.logger"


@@ -21,7 +21,9 @@ def streaming(audio_in_socket, audio_out_socket):
import torch
torch.hub.load.return_value = (..., ...) # Mock
return Streaming(audio_in_socket, audio_out_socket)
streaming = Streaming(audio_in_socket, audio_out_socket)
streaming._ready = True
return streaming
async def simulate_streaming_with_probabilities(streaming, probabilities: list[float]):


@@ -5,12 +5,13 @@ from control_backend.agents.transcription.speech_recognizer import OpenAIWhisper
def test_estimate_max_tokens():
"""Inputting one minute of audio, assuming 300 words per minute, expecting 400 tokens."""
"""Inputting one minute of audio, assuming 450 words per minute and adding a 10 token padding,
expecting 610 tokens."""
audio = np.empty(shape=(60 * 16_000), dtype=np.float32)
actual = SpeechRecognizer._estimate_max_tokens(audio)
assert actual == 400
assert actual == 610
assert isinstance(actual, int)