fix: fixed new tests and merged dev into branch

ref: N25B-256
This commit is contained in:
Björn Otgaar
2025-11-05 16:29:56 +01:00
29 changed files with 520 additions and 298 deletions

77
.githooks/check-branch-name.sh Executable file
View File

@@ -0,0 +1,77 @@
#!/usr/bin/env bash
# This script checks if the current branch name follows the specified format.
# It's designed to be used as a 'pre-commit' git hook.
# Format: <type>/<short-description>
# Example: feat/add-user-login
# --- Configuration ---
# An array of allowed commit types
ALLOWED_TYPES=(feat fix refactor perf style test docs build chore revert)
# An array of branches to ignore
IGNORED_BRANCHES=(main dev demo)
# --- Colors for Output ---
RED='\033[0;31m'
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
NC='\033[0m' # No Color
# --- Helper Functions ---
error_exit() {
echo -e "${RED}ERROR: $1${NC}" >&2
echo -e "${YELLOW}Branch name format is incorrect. Aborting commit.${NC}" >&2
exit 1
}
# --- Main Logic ---
# 1. Get the current branch name
BRANCH_NAME=$(git symbolic-ref --short HEAD)
# 2. Check if the current branch is in the ignored list
for ignored_branch in "${IGNORED_BRANCHES[@]}"; do
if [ "$BRANCH_NAME" == "$ignored_branch" ]; then
echo -e "${GREEN}Branch check skipped for default branch: $BRANCH_NAME${NC}"
exit 0
fi
done
# 3. Validate the overall structure: <type>/<description>
if ! [[ "$BRANCH_NAME" =~ ^[a-z]+/.+$ ]]; then
error_exit "Branch name must be in the format: <type>/<short-description>\nExample: feat/add-user-login"
fi
# 4. Extract the type and description
TYPE=$(echo "$BRANCH_NAME" | cut -d'/' -f1)
DESCRIPTION=$(echo "$BRANCH_NAME" | cut -d'/' -f2-)
# 5. Validate the <type>
type_valid=false
for allowed_type in "${ALLOWED_TYPES[@]}"; do
if [ "$TYPE" == "$allowed_type" ]; then
type_valid=true
break
fi
done
if [ "$type_valid" == false ]; then
error_exit "Invalid type '$TYPE'.\nAllowed types are: ${ALLOWED_TYPES[*]}"
fi
# 6. Validate the <short-description>
# Regex breakdown:
# ^[a-z0-9]+ - Starts with one or more lowercase letters/numbers (the first word).
# (-[a-z0-9]+){0,5} - Followed by a group of (dash + word) 0 to 5 times.
# $ - End of the string.
# This entire pattern enforces 1 to 6 words total, separated by dashes.
DESCRIPTION_REGEX="^[a-z0-9]+(-[a-z0-9]+){0,5}$"
if ! [[ "$DESCRIPTION" =~ $DESCRIPTION_REGEX ]]; then
error_exit "Invalid short description '$DESCRIPTION'.\nIt must be a maximum of 6 words, all lowercase, separated by dashes.\nExample: add-new-user-authentication-feature"
fi
# If all checks pass, exit successfully
echo -e "${GREEN}Branch name '$BRANCH_NAME' is valid.${NC}"
exit 0

135
.githooks/check-commit-msg.sh Executable file
View File

@@ -0,0 +1,135 @@
#!/usr/bin/env bash
# This script checks if a commit message follows the specified format.
# It's designed to be used as a 'commit-msg' git hook.
# Format:
# <type>: <short description>
#
# [optional]<body>
#
# [ref/close]: <issue identifier>
# --- Configuration ---
# An array of allowed commit types
ALLOWED_TYPES=(feat fix refactor perf style test docs build chore revert)
# --- Colors for Output ---
RED='\033[0;31m'
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
NC='\033[0m' # No Color
# The first argument to the hook is the path to the file containing the commit message
COMMIT_MSG_FILE=$1
# --- Automated Commit Detection ---
# Read the first line (header) for initial checks
HEADER=$(head -n 1 "$COMMIT_MSG_FILE")
# Check for Merge commits (covers 'git merge' and PR merges from GitHub/GitLab)
# Examples: "Merge branch 'main' into ...", "Merge pull request #123 from ..."
MERGE_PATTERN="^Merge (branch|pull request|tag) .*"
if [[ "$HEADER" =~ $MERGE_PATTERN ]]; then
echo -e "${GREEN}Merge commit detected by message content. Skipping validation.${NC}"
exit 0
fi
# Check for Revert commits
# Example: "Revert "feat: add new feature""
REVERT_PATTERN="^Revert \".*\""
if [[ "$HEADER" =~ $REVERT_PATTERN ]]; then
echo -e "${GREEN}Revert commit detected by message content. Skipping validation.${NC}"
exit 0
fi
# Check for Cherry-pick commits (this pattern appears at the end of the message)
# Example: "(cherry picked from commit deadbeef...)"
# We use grep -q to search the whole file quietly.
CHERRY_PICK_PATTERN="\(cherry picked from commit [a-f0-9]{7,40}\)"
if grep -qE "$CHERRY_PICK_PATTERN" "$COMMIT_MSG_FILE"; then
echo -e "${GREEN}Cherry-pick detected by message content. Skipping validation.${NC}"
exit 0
fi
# Check for Squash
# Example: "Squash commits ..."
SQUASH_PATTERN="^Squash .+"
if [[ "$HEADER" =~ $SQUASH_PATTERN ]]; then
echo -e "${GREEN}Squash commit detected by message content. Skipping validation.${NC}"
exit 0
fi
# --- Validation Functions ---
# Function to print an error message and exit
# Usage: error_exit "Your error message here"
error_exit() {
# >&2 redirects echo to stderr
echo -e "${RED}ERROR: $1${NC}" >&2
echo -e "${YELLOW}Commit message format is incorrect. Aborting commit.${NC}" >&2
exit 1
}
# --- Main Logic ---
# 1. Read the header (first line) of the commit message
HEADER=$(head -n 1 "$COMMIT_MSG_FILE")
# 2. Validate the header format: <type>: <description>
# Regex breakdown:
# ^(type1|type2|...) - Starts with one of the allowed types
# : - Followed by a literal colon
# \s - Followed by a single space
# .+ - Followed by one or more characters for the description
# $ - End of the line
TYPES_REGEX=$(
IFS="|"
echo "${ALLOWED_TYPES[*]}"
)
HEADER_REGEX="^($TYPES_REGEX): .+$"
if ! [[ "$HEADER" =~ $HEADER_REGEX ]]; then
error_exit "Invalid header format.\n\nHeader must be in the format: <type>: <short description>\nAllowed types: ${ALLOWED_TYPES[*]}\nExample: feat: add new user authentication feature"
fi
# Only validate footer if commit type is not chore
TYPE=$(echo "$HEADER" | cut -d':' -f1)
if [ "$TYPE" != "chore" ]; then
# 3. Validate the footer (last line) of the commit message
FOOTER=$(tail -n 1 "$COMMIT_MSG_FILE")
# Regex breakdown:
# ^(ref|close) - Starts with 'ref' or 'close'
# : - Followed by a literal colon
# \s - Followed by a single space
# N25B- - Followed by the literal string 'N25B-'
# [0-9]+ - Followed by one or more digits
# $ - End of the line
FOOTER_REGEX="^(ref|close): N25B-[0-9]+$"
if ! [[ "$FOOTER" =~ $FOOTER_REGEX ]]; then
error_exit "Invalid footer format.\n\nFooter must be in the format: [ref/close]: <issue identifier>\nExample: ref: N25B-123"
fi
fi
# 4. If the message has more than 2 lines, validate the separator
# A blank line must exist between the header and the body.
LINE_COUNT=$(wc -l <"$COMMIT_MSG_FILE" | xargs) # xargs trims whitespace
# We only care if there is a body. Header + Footer = 2 lines.
# Header + Blank Line + Body... + Footer > 2 lines.
if [ "$LINE_COUNT" -gt 2 ]; then
# Get the second line
SECOND_LINE=$(sed -n '2p' "$COMMIT_MSG_FILE")
# Check if the second line is NOT empty. If it's not, it's an error.
if [ -n "$SECOND_LINE" ]; then
error_exit "Missing blank line between header and body.\n\nThe second line of your commit message must be empty if a body is present."
fi
fi
# If all checks pass, exit with success
echo -e "${GREEN}Commit message is valid.${NC}"
exit 0

View File

@@ -1,16 +0,0 @@
#!/bin/sh
commit_msg_file=$1
commit_msg=$(cat "$commit_msg_file")
if echo "$commit_msg" | grep -Eq "^(feat|fix|refactor|perf|style|test|docs|build|chore|revert): .+"; then
if echo "$commit_msg" | grep -Eq "^(ref|close):\sN25B-.+"; then
exit 0
else
echo "❌ Commit message invalid! Must end with [ref/close]: N25B-000"
exit 1
fi
else
echo "❌ Commit message invalid! Must start with <type>: <description>"
exit 1
fi

View File

@@ -1,17 +0,0 @@
#!/bin/sh
# Get current branch
branch=$(git rev-parse --abbrev-ref HEAD)
if echo "$branch" | grep -Eq "(dev|main)"; then
echo 0
fi
# allowed pattern <type/>
if echo "$branch" | grep -Eq "^(feat|fix|refactor|perf|style|test|docs|build|chore|revert)\/\w+(-\w+){0,5}$"; then
exit 0
else
echo "❌ Invalid branch name: $branch"
echo "Branch must be named <type>/<description-of-branch> (must have one to six words separated by a dash)"
exit 1
fi

View File

@@ -1,9 +0,0 @@
#!/bin/sh
echo "#<type>: <description>
#[optional body]
#[optional footer(s)]
#[ref/close]: <issue identifier>" > $1

View File

@@ -1,10 +1,24 @@
repos: repos:
- repo: https://github.com/astral-sh/ruff-pre-commit - repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version. # Ruff version.
rev: v0.14.2 rev: v0.14.2
hooks: hooks:
# Run the linter. # Run the linter.
- id: ruff-check - id: ruff-check
args: [ --fix ] # Run the formatter.
# Run the formatter. - id: ruff-format
- id: ruff-format # Configure local hooks
- repo: local
hooks:
- id: check-commit-msg
name: Check commit message format
entry: .githooks/check-commit-msg.sh
language: script
stages: [commit-msg]
- id: check-branch-name
name: Check branch name format
entry: .githooks/check-branch-name.sh
language: script
stages: [pre-commit]
always_run: true
pass_filenames: false

View File

@@ -47,21 +47,19 @@ Or for integration tests:
uv run --group integration-test pytest test/integration uv run --group integration-test pytest test/integration
``` ```
## GitHooks ## Git Hooks
To activate automatic commits/branch name checks run: To activate automatic linting, formatting, branch name checks and commit message checks, run:
```shell ```bash
git config --local core.hooksPath .githooks uv run pre-commit install
uv run pre-commit install --hook-type commit-msg
``` ```
If your commit fails its either: You might get an error along the lines of `Can't install pre-commit with core.hooksPath` set. To fix this, simply unset the hooksPath by running:
branch name != <type>/description-of-branch ,
commit name != <type>: description of the commit.
<ref>: N25B-Num's
To add automatic linting and formatting, run: ```bash
git config --local --unset core.hooksPath
```
```shell Then run the pre-commit install commands again.
uv run pre-commit install
```

View File

@@ -1,8 +1,10 @@
import logging import logging
from spade.behaviour import CyclicBehaviour from spade.behaviour import CyclicBehaviour
from spade.message import Message
from control_backend.core.config import settings from control_backend.core.config import settings
from control_backend.schemas.ri_message import SpeechCommand
class ReceiveLLMResponseBehaviour(CyclicBehaviour): class ReceiveLLMResponseBehaviour(CyclicBehaviour):
@@ -10,7 +12,7 @@ class ReceiveLLMResponseBehaviour(CyclicBehaviour):
Adds behavior to receive responses from the LLM Agent. Adds behavior to receive responses from the LLM Agent.
""" """
logger = logging.getLogger("BDI/LLM Reciever") logger = logging.getLogger("BDI/LLM Receiver")
async def run(self): async def run(self):
msg = await self.receive(timeout=2) msg = await self.receive(timeout=2)
@@ -22,7 +24,20 @@ class ReceiveLLMResponseBehaviour(CyclicBehaviour):
case settings.agent_settings.llm_agent_name: case settings.agent_settings.llm_agent_name:
content = msg.body content = msg.body
self.logger.info("Received LLM response: %s", content) self.logger.info("Received LLM response: %s", content)
# Here the BDI can pass the message back as a response
speech_command = SpeechCommand(data=content)
message = Message(
to=settings.agent_settings.ri_command_agent_name
+ "@"
+ settings.agent_settings.host,
sender=self.agent.jid,
body=speech_command.model_dump_json(),
)
self.logger.debug("Sending message: %s", message)
await self.send(message)
case _: case _:
self.logger.debug("Not from the llm, discarding message") self.logger.debug("Not from the llm, discarding message")
pass pass

View File

@@ -13,23 +13,23 @@ class BeliefFromText(CyclicBehaviour):
# TODO: LLM prompt nog hardcoded # TODO: LLM prompt nog hardcoded
llm_instruction_prompt = """ llm_instruction_prompt = """
You are an information extraction assistent for a BDI agent. You are an information extraction assistent for a BDI agent. Your task is to extract values \
Your task is to extract values from a user's text to bind a list of ungrounded beliefs. Rules: from a user's text to bind a list of ungrounded beliefs. Rules:
You will receive a JSON object with "beliefs" You will receive a JSON object with "beliefs" (a list of ungrounded AgentSpeak beliefs) \
(a list of ungrounded AgentSpeak beliefs) and "text" (user's transcript). and "text" (user's transcript).
Analyze the text to find values that sematically match the variables (X,Y,Z) in the beliefs. Analyze the text to find values that sematically match the variables (X,Y,Z) in the beliefs.
A single piece of text might contain multiple instances that match a belief. A single piece of text might contain multiple instances that match a belief.
Respond ONLY with a single JSON object. Respond ONLY with a single JSON object.
The JSON object's keys should be the belief functors (e.g., "weather"). The JSON object's keys should be the belief functors (e.g., "weather").
The value for each key must be a list of lists. The value for each key must be a list of lists.
Each inner list must contain the extracted arguments Each inner list must contain the extracted arguments (as strings) for one instance \
(as strings) for one instance of that belief. of that belief.
CRITICAL: If no information in the text matches a belief, CRITICAL: If no information in the text matches a belief, DO NOT include that key \
DO NOT include that key in your response. in your response.
""" """
# on_start agent receives message containing the beliefs to look out # on_start agent receives message containing the beliefs to look out for and
# for and sets up the LLM with instruction prompt # sets up the LLM with instruction prompt
# async def on_start(self): # async def on_start(self):
# msg = await self.receive(timeout=0.1) # msg = await self.receive(timeout=0.1)
# self.beliefs = dict uit message # self.beliefs = dict uit message

View File

@@ -70,8 +70,7 @@ class ContinuousBeliefCollector(CyclicBehaviour):
Expected payload: Expected payload:
{ {
"type": "belief_extraction_text", "type": "belief_extraction_text",
"beliefs": {"user_said": ["hello"","Can you help me?", "beliefs": {"user_said": ["Can you help me?"]}
"stop talking to me","No","Pepper do a dance"]}
} }

View File

@@ -3,8 +3,10 @@ LLM Agent module for routing text queries from the BDI Core Agent to a local LLM
service and returning its responses back to the BDI Core Agent. service and returning its responses back to the BDI Core Agent.
""" """
import json
import logging import logging
from typing import Any import re
from collections.abc import AsyncGenerator
import httpx import httpx
from spade.agent import Agent from spade.agent import Agent
@@ -54,11 +56,16 @@ class LLMAgent(Agent):
async def _process_bdi_message(self, message: Message): async def _process_bdi_message(self, message: Message):
""" """
Forwards user text to the LLM and replies with the generated text. Forwards user text from the BDI to the LLM and replies with the generated text in chunks
separated by punctuation.
""" """
user_text = message.body user_text = message.body
llm_response = await self._query_llm(user_text) # Consume the streaming generator and send a reply for every chunk
await self._reply(llm_response) async for chunk in self._query_llm(user_text):
await self._reply(chunk)
self.agent.logger.debug(
"Finished processing BDI message. Response sent in chunks to BDI Core Agent."
)
async def _reply(self, msg: str): async def _reply(self, msg: str):
""" """
@@ -69,48 +76,89 @@ class LLMAgent(Agent):
body=msg, body=msg,
) )
await self.send(reply) await self.send(reply)
self.agent.logger.info("Reply sent to BDI Core Agent")
async def _query_llm(self, prompt: str) -> str: async def _query_llm(self, prompt: str) -> AsyncGenerator[str]:
""" """
Sends a chat completion request to the local LLM service. Sends a chat completion request to the local LLM service and streams the response by
yielding fragments separated by punctuation like.
:param prompt: Input text prompt to pass to the LLM. :param prompt: Input text prompt to pass to the LLM.
:return: LLM-generated content or fallback message. :yield: Fragments of the LLM-generated content.
""" """
async with httpx.AsyncClient(timeout=120.0) as client: instructions = LLMInstructions(
# Example dynamic content for future (optional) "- Be friendly and respectful.\n"
"- Make the conversation feel natural and engaging.\n"
"- Speak like a pirate.\n"
"- When the user asks what you can do, tell them.",
"- Try to learn the user's name during conversation.\n"
"- Suggest playing a game of asking yes or no questions where you think of a word "
"and the user must guess it.",
)
messages = [
{
"role": "developer",
"content": instructions.build_developer_instruction(),
},
{
"role": "user",
"content": prompt,
},
]
instructions = LLMInstructions() try:
developer_instruction = instructions.build_developer_instruction() current_chunk = ""
async for token in self._stream_query_llm(messages):
current_chunk += token
response = await client.post( # Stream the message in chunks separated by punctuation.
# We include the delimiter in the emitted chunk for natural flow.
pattern = re.compile(r".*?(?:,|;|:|—||\.{3}|…|\.|\?|!)\s*", re.DOTALL)
for m in pattern.finditer(current_chunk):
chunk = m.group(0)
if chunk:
yield current_chunk
current_chunk = ""
# Yield any remaining tail
if current_chunk:
yield current_chunk
except httpx.HTTPError as err:
self.agent.logger.error("HTTP error.", exc_info=err)
yield "LLM service unavailable."
except Exception as err:
self.agent.logger.error("Unexpected error.", exc_info=err)
yield "Error processing the request."
async def _stream_query_llm(self, messages) -> AsyncGenerator[str]:
"""Raises httpx.HTTPError when the API gives an error."""
async with httpx.AsyncClient(timeout=None) as client:
async with client.stream(
"POST",
settings.llm_settings.local_llm_url, settings.llm_settings.local_llm_url,
headers={"Content-Type": "application/json"},
json={ json={
"model": settings.llm_settings.local_llm_model, "model": settings.llm_settings.local_llm_model,
"messages": [ "messages": messages,
{"role": "developer", "content": developer_instruction},
{"role": "user", "content": prompt},
],
"temperature": 0.3, "temperature": 0.3,
"stream": True,
}, },
) ) as response:
try:
response.raise_for_status() response.raise_for_status()
data: dict[str, Any] = response.json()
return ( async for line in response.aiter_lines():
data.get("choices", [{}])[0] if not line or not line.startswith("data: "):
.get("message", {}) continue
.get("content", "No response")
) data = line[len("data: ") :]
except httpx.HTTPError as err: if data.strip() == "[DONE]":
self.agent.logger.error("HTTP error: %s", err) break
return "LLM service unavailable."
except Exception as err: try:
self.agent.logger.error("Unexpected error: %s", err) event = json.loads(data)
return "Error processing the request." delta = event.get("choices", [{}])[0].get("delta", {}).get("content")
if delta:
yield delta
except json.JSONDecodeError:
self.agent.logger.error("Failed to parse LLM response: %s", data)
async def setup(self): async def setup(self):
""" """

View File

@@ -28,7 +28,9 @@ class LLMInstructions:
""" """
sections = [ sections = [
"You are a Pepper robot engaging in natural human conversation.", "You are a Pepper robot engaging in natural human conversation.",
"Keep responses between 15 sentences, unless instructed otherwise.\n", "Keep responses between 13 sentences, unless told otherwise.\n",
"You're given goals to reach. Reach them in order, but make the conversation feel "
"natural. Some turns you should not try to achieve your goals.\n",
] ]
if self.norms: if self.norms:

View File

@@ -11,8 +11,9 @@ class BeliefTextAgent(Agent):
class SendOnceBehaviourBlfText(OneShotBehaviour): class SendOnceBehaviourBlfText(OneShotBehaviour):
async def run(self): async def run(self):
to_jid = ( to_jid = (
f"{settings.agent_settings.belief_collector_agent_name}" settings.agent_settings.belief_collector_agent_name
f"@{settings.agent_settings.host}" + "@"
+ settings.agent_settings.host
) )
# Send multiple beliefs in one JSON payload # Send multiple beliefs in one JSON payload

View File

@@ -1,6 +1,7 @@
import json import json
import logging import logging
import spade.agent
import zmq import zmq
from spade.agent import Agent from spade.agent import Agent
from spade.behaviour import CyclicBehaviour from spade.behaviour import CyclicBehaviour
@@ -32,6 +33,8 @@ class RICommandAgent(Agent):
self.bind = bind self.bind = bind
class SendCommandsBehaviour(CyclicBehaviour): class SendCommandsBehaviour(CyclicBehaviour):
"""Behaviour for sending commands received from the UI."""
async def run(self): async def run(self):
""" """
Run the command publishing loop indefinetely. Run the command publishing loop indefinetely.
@@ -50,6 +53,18 @@ class RICommandAgent(Agent):
except Exception as e: except Exception as e:
logger.error("Error processing message: %s", e) logger.error("Error processing message: %s", e)
class SendPythonCommandsBehaviour(CyclicBehaviour):
"""Behaviour for sending commands received from other Python agents."""
async def run(self):
message: spade.agent.Message = await self.receive(timeout=0.1)
if message and message.to == self.agent.jid:
try:
speech_command = SpeechCommand.model_validate_json(message.body)
await self.agent.pubsocket.send_json(speech_command.model_dump())
except Exception as e:
logger.error("Error processing message: %s", e)
async def setup(self): async def setup(self):
""" """
Setup the command agent Setup the command agent
@@ -73,5 +88,6 @@ class RICommandAgent(Agent):
# Add behaviour to our agent # Add behaviour to our agent
commands_behaviour = self.SendCommandsBehaviour() commands_behaviour = self.SendCommandsBehaviour()
self.add_behaviour(commands_behaviour) self.add_behaviour(commands_behaviour)
self.add_behaviour(self.SendPythonCommandsBehaviour())
logger.info("Finished setting up %s", self.jid) logger.info("Finished setting up %s", self.jid)

View File

@@ -63,7 +63,25 @@ class RICommunicationAgent(Agent):
# We didnt get a reply :( # We didnt get a reply :(
except TimeoutError: except TimeoutError:
logger.info("No ping retrieved in 3 seconds, killing myself.") logger.info("No ping retrieved in 3 seconds, killing myself.")
self.kill()
# Tell UI we're disconnected.
topic = b"ping"
data = json.dumps(False).encode()
if self.agent.pub_socket is None:
logger.error("communication agent pub socket not correctly initialized.")
else:
try:
await asyncio.wait_for(
self.agent.pub_socket.send_multipart([topic, data]), 5
)
except TimeoutError:
logger.error(
"Initial connection ping for router timed"
" out in ri_communication_agent."
)
# Try to reboot.
self.agent.setup()
logger.debug('Received message "%s"', message) logger.debug('Received message "%s"', message)
if "endpoint" not in message: if "endpoint" not in message:

View File

@@ -36,16 +36,16 @@ class SpeechRecognizer(abc.ABC):
def _estimate_max_tokens(audio: np.ndarray) -> int: def _estimate_max_tokens(audio: np.ndarray) -> int:
""" """
Estimate the maximum length of a given audio sample in tokens. Assumes a maximum speaking Estimate the maximum length of a given audio sample in tokens. Assumes a maximum speaking
rate of 300 words per minute (2x average), and assumes that 3 words is 4 tokens. rate of 450 words per minute (3x average), and assumes that 3 words is 4 tokens.
:param audio: The audio sample (16 kHz) to use for length estimation. :param audio: The audio sample (16 kHz) to use for length estimation.
:return: The estimated length of the transcribed audio in tokens. :return: The estimated length of the transcribed audio in tokens.
""" """
length_seconds = len(audio) / 16_000 length_seconds = len(audio) / 16_000
length_minutes = length_seconds / 60 length_minutes = length_seconds / 60
word_count = length_minutes * 300 word_count = length_minutes * 450
token_count = word_count / 3 * 4 token_count = word_count / 3 * 4
return int(token_count) return int(token_count) + 10
def _get_decode_options(self, audio: np.ndarray) -> dict: def _get_decode_options(self, audio: np.ndarray) -> dict:
""" """
@@ -85,9 +85,10 @@ class MLXWhisperSpeechRecognizer(SpeechRecognizer):
def recognize_speech(self, audio: np.ndarray) -> str: def recognize_speech(self, audio: np.ndarray) -> str:
self.load_model() self.load_model()
return mlx_whisper.transcribe( return mlx_whisper.transcribe(
audio, path_or_hf_repo=self.model_name, decode_options=self._get_decode_options(audio) audio,
)["text"] path_or_hf_repo=self.model_name,
return mlx_whisper.transcribe(audio, path_or_hf_repo=self.model_name)["text"].strip() **self._get_decode_options(audio),
)["text"].strip()
class OpenAIWhisperSpeechRecognizer(SpeechRecognizer): class OpenAIWhisperSpeechRecognizer(SpeechRecognizer):
@@ -103,6 +104,4 @@ class OpenAIWhisperSpeechRecognizer(SpeechRecognizer):
def recognize_speech(self, audio: np.ndarray) -> str: def recognize_speech(self, audio: np.ndarray) -> str:
self.load_model() self.load_model()
return whisper.transcribe( return whisper.transcribe(self.model, audio, **self._get_decode_options(audio))["text"]
self.model, audio, decode_options=self._get_decode_options(audio)
)["text"]

View File

@@ -58,6 +58,10 @@ class TranscriptionAgent(Agent):
audio = await self.audio_in_socket.recv() audio = await self.audio_in_socket.recv()
audio = np.frombuffer(audio, dtype=np.float32) audio = np.frombuffer(audio, dtype=np.float32)
speech = await self._transcribe(audio) speech = await self._transcribe(audio)
if not speech:
logger.info("Nothing transcribed.")
return
logger.info("Transcribed speech: %s", speech) logger.info("Transcribed speech: %s", speech)
await self._share_transcription(speech) await self._share_transcription(speech)

View File

@@ -54,8 +54,20 @@ class Streaming(CyclicBehaviour):
self.audio_buffer = np.array([], dtype=np.float32) self.audio_buffer = np.array([], dtype=np.float32)
self.i_since_speech = 100 # Used to allow small pauses in speech self.i_since_speech = 100 # Used to allow small pauses in speech
self._ready = False
async def reset(self):
"""Clears the ZeroMQ queue and tells this behavior to start."""
discarded = 0
while await self.audio_in_poller.poll(1) is not None:
discarded += 1
logging.info(f"Discarded {discarded} audio packets before starting.")
self._ready = True
async def run(self) -> None: async def run(self) -> None:
if not self._ready:
return
data = await self.audio_in_poller.poll() data = await self.audio_in_poller.poll()
if data is None: if data is None:
if len(self.audio_buffer) > 0: if len(self.audio_buffer) > 0:
@@ -107,6 +119,8 @@ class VADAgent(Agent):
self.audio_in_socket: azmq.Socket | None = None self.audio_in_socket: azmq.Socket | None = None
self.audio_out_socket: azmq.Socket | None = None self.audio_out_socket: azmq.Socket | None = None
self.streaming_behaviour: Streaming | None = None
async def stop(self): async def stop(self):
""" """
Stop listening to audio, stop publishing audio, close sockets. Stop listening to audio, stop publishing audio, close sockets.
@@ -149,8 +163,8 @@ class VADAgent(Agent):
return return
audio_out_address = f"tcp://localhost:{audio_out_port}" audio_out_address = f"tcp://localhost:{audio_out_port}"
streaming = Streaming(self.audio_in_socket, self.audio_out_socket) self.streaming_behaviour = Streaming(self.audio_in_socket, self.audio_out_socket)
self.add_behaviour(streaming) self.add_behaviour(self.streaming_behaviour)
# Start agents dependent on the output audio fragments here # Start agents dependent on the output audio fragments here
transcriber = TranscriptionAgent(audio_out_address) transcriber = TranscriptionAgent(audio_out_address)

View File

@@ -22,8 +22,8 @@ async def receive_command(command: SpeechCommand, request: Request):
topic = b"command" topic = b"command"
# TODO: Check with Kasper # TODO: Check with Kasper
pub_socket: Socket = request.app.state.internal_comm_socket pub_socket: Socket = request.app.state.endpoints_pub_socket
pub_socket.send_multipart([topic, command.model_dump_json().encode()]) await pub_socket.send_multipart([topic, command.model_dump_json().encode()])
return {"status": "Command received"} return {"status": "Command received"}

View File

@@ -14,8 +14,6 @@ from control_backend.agents.bdi.bdi_core import BDICoreAgent
from control_backend.agents.bdi.text_extractor import TBeliefExtractor from control_backend.agents.bdi.text_extractor import TBeliefExtractor
from control_backend.agents.belief_collector.belief_collector import BeliefCollectorAgent from control_backend.agents.belief_collector.belief_collector import BeliefCollectorAgent
from control_backend.agents.llm.llm import LLMAgent from control_backend.agents.llm.llm import LLMAgent
# Internal imports
from control_backend.agents.ri_communication_agent import RICommunicationAgent from control_backend.agents.ri_communication_agent import RICommunicationAgent
from control_backend.agents.vad_agent import VADAgent from control_backend.agents.vad_agent import VADAgent
from control_backend.api.v1.router import api_router from control_backend.api.v1.router import api_router
@@ -99,6 +97,8 @@ async def lifespan(app: FastAPI):
_temp_vad_agent = VADAgent("tcp://localhost:5558", False) _temp_vad_agent = VADAgent("tcp://localhost:5558", False)
await _temp_vad_agent.start() await _temp_vad_agent.start()
logger.info("VAD agent started, now making ready...")
await _temp_vad_agent.streaming_behaviour.reset()
yield yield

View File

@@ -7,25 +7,21 @@ import zmq
from control_backend.agents.ri_command_agent import RICommandAgent from control_backend.agents.ri_command_agent import RICommandAgent
@pytest.mark.asyncio @pytest.fixture
async def test_setup_bind(monkeypatch): def zmq_context(mocker):
"""Test setup with bind=True""" mock_context = mocker.patch("control_backend.agents.vad_agent.azmq.Context.instance")
fake_socket = MagicMock() mock_context.return_value = MagicMock()
fake_context = MagicMock() return mock_context
fake_context.socket.return_value = fake_socket
# Patch Context.instance() to return fake_context
monkeypatch.setattr( @pytest.mark.asyncio
"control_backend.agents.ri_command_agent.Context", async def test_setup_bind(zmq_context, mocker):
MagicMock(instance=MagicMock(return_value=fake_context)), """Test setup with bind=True"""
) fake_socket = zmq_context.return_value.socket.return_value
agent = RICommandAgent("test@server", "password", address="tcp://localhost:5555", bind=True) agent = RICommandAgent("test@server", "password", address="tcp://localhost:5555", bind=True)
settings = mocker.patch("control_backend.agents.ri_command_agent.settings")
monkeypatch.setattr( settings.zmq_settings.internal_sub_address = "tcp://internal:1234"
"control_backend.agents.ri_command_agent.settings",
MagicMock(zmq_settings=MagicMock(internal_sub_address="tcp://internal:1234")),
)
await agent.setup() await agent.setup()
@@ -36,23 +32,13 @@ async def test_setup_bind(monkeypatch):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_setup_connect(monkeypatch): async def test_setup_connect(zmq_context, mocker):
"""Test setup with bind=False""" """Test setup with bind=False"""
fake_socket = MagicMock() fake_socket = zmq_context.return_value.socket.return_value
fake_context = MagicMock()
fake_context.socket.return_value = fake_socket
# Patch Context.instance() to return fake_context
monkeypatch.setattr(
"control_backend.agents.ri_command_agent.Context",
MagicMock(instance=MagicMock(return_value=fake_context)),
)
agent = RICommandAgent("test@server", "password", address="tcp://localhost:5555", bind=False) agent = RICommandAgent("test@server", "password", address="tcp://localhost:5555", bind=False)
monkeypatch.setattr( settings = mocker.patch("control_backend.agents.ri_command_agent.settings")
"control_backend.agents.ri_command_agent.settings", settings.zmq_settings.internal_sub_address = "tcp://internal:1234"
MagicMock(zmq_settings=MagicMock(internal_sub_address="tcp://internal:1234")),
)
await agent.setup() await agent.setup()

View File

@@ -84,25 +84,24 @@ def fake_json_invalid_id_negototiate():
) )
@pytest.fixture
def zmq_context(mocker):
mock_context = mocker.patch("control_backend.agents.vad_agent.azmq.Context.instance")
mock_context.return_value = MagicMock()
return mock_context
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_setup_creates_socket_and_negotiate_1(monkeypatch): async def test_setup_creates_socket_and_negotiate_1(zmq_context):
""" """
Test the setup of the communication agent Test the setup of the communication agent
""" """
# --- Arrange --- # --- Arrange ---
fake_socket = MagicMock() fake_socket = zmq_context.return_value.socket.return_value
fake_socket.send_json = AsyncMock() fake_socket.send_json = AsyncMock()
fake_socket.recv_json = fake_json_correct_negototiate_1() fake_socket.recv_json = fake_json_correct_negototiate_1()
fake_socket.send_multipart = AsyncMock() fake_socket.send_multipart = AsyncMock()
# Mock context.socket to return our fake socket
fake_context = MagicMock()
fake_context.socket.return_value = fake_socket
monkeypatch.setattr(
"control_backend.agents.ri_communication_agent.Context",
MagicMock(instance=MagicMock(return_value=fake_context)),
)
# Mock RICommandAgent agent startup # Mock RICommandAgent agent startup
with patch( with patch(
"control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True "control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True
@@ -135,24 +134,16 @@ async def test_setup_creates_socket_and_negotiate_1(monkeypatch):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_setup_creates_socket_and_negotiate_2(monkeypatch): async def test_setup_creates_socket_and_negotiate_2(zmq_context):
""" """
Test the setup of the communication agent Test the setup of the communication agent
""" """
# --- Arrange --- # --- Arrange ---
fake_socket = MagicMock() fake_socket = zmq_context.return_value.socket.return_value
fake_socket.send_json = AsyncMock() fake_socket.send_json = AsyncMock()
fake_socket.recv_json = fake_json_correct_negototiate_2() fake_socket.recv_json = fake_json_correct_negototiate_2()
fake_socket.send_multipart = AsyncMock() fake_socket.send_multipart = AsyncMock()
# Mock context.socket to return our fake socket
fake_context = MagicMock()
fake_context.socket.return_value = fake_socket
monkeypatch.setattr(
"control_backend.agents.ri_communication_agent.Context",
MagicMock(instance=MagicMock(return_value=fake_context)),
)
# Mock RICommandAgent agent startup # Mock RICommandAgent agent startup
with patch( with patch(
"control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True "control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True
@@ -185,24 +176,16 @@ async def test_setup_creates_socket_and_negotiate_2(monkeypatch):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_setup_creates_socket_and_negotiate_3(monkeypatch, caplog): async def test_setup_creates_socket_and_negotiate_3(zmq_context, caplog):
""" """
Test the functionality of setup with incorrect negotiation message Test the functionality of setup with incorrect negotiation message
""" """
# --- Arrange --- # --- Arrange ---
fake_socket = MagicMock() fake_socket = zmq_context.return_value.socket.return_value
fake_socket.send_json = AsyncMock() fake_socket.send_json = AsyncMock()
fake_socket.recv_json = fake_json_wrong_negototiate_1() fake_socket.recv_json = fake_json_wrong_negototiate_1()
fake_socket.send_multipart = AsyncMock() fake_socket.send_multipart = AsyncMock()
# Mock context.socket to return our fake socket
fake_context = MagicMock()
fake_context.socket.return_value = fake_socket
monkeypatch.setattr(
"control_backend.agents.ri_communication_agent.Context",
MagicMock(instance=MagicMock(return_value=fake_context)),
)
# Mock RICommandAgent agent startup # Mock RICommandAgent agent startup
# We are sending wrong negotiation info to the communication agent, # We are sending wrong negotiation info to the communication agent,
@@ -235,24 +218,16 @@ async def test_setup_creates_socket_and_negotiate_3(monkeypatch, caplog):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_setup_creates_socket_and_negotiate_4(monkeypatch): async def test_setup_creates_socket_and_negotiate_4(zmq_context):
""" """
Test the setup of the communication agent with different bind value Test the setup of the communication agent with different bind value
""" """
# --- Arrange --- # --- Arrange ---
fake_socket = MagicMock() fake_socket = zmq_context.return_value.socket.return_value
fake_socket.send_json = AsyncMock() fake_socket.send_json = AsyncMock()
fake_socket.recv_json = fake_json_correct_negototiate_3() fake_socket.recv_json = fake_json_correct_negototiate_3()
fake_socket.send_multipart = AsyncMock() fake_socket.send_multipart = AsyncMock()
# Mock context.socket to return our fake socket
fake_context = MagicMock()
fake_context.socket.return_value = fake_socket
monkeypatch.setattr(
"control_backend.agents.ri_communication_agent.Context",
MagicMock(instance=MagicMock(return_value=fake_context)),
)
# Mock RICommandAgent agent startup # Mock RICommandAgent agent startup
with patch( with patch(
"control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True "control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True
@@ -284,24 +259,16 @@ async def test_setup_creates_socket_and_negotiate_4(monkeypatch):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_setup_creates_socket_and_negotiate_5(monkeypatch): async def test_setup_creates_socket_and_negotiate_5(zmq_context):
""" """
Test the setup of the communication agent Test the setup of the communication agent
""" """
# --- Arrange --- # --- Arrange ---
fake_socket = MagicMock() fake_socket = zmq_context.return_value.socket.return_value
fake_socket.send_json = AsyncMock() fake_socket.send_json = AsyncMock()
fake_socket.recv_json = fake_json_correct_negototiate_4() fake_socket.recv_json = fake_json_correct_negototiate_4()
fake_socket.send_multipart = AsyncMock() fake_socket.send_multipart = AsyncMock()
# Mock context.socket to return our fake socket
fake_context = MagicMock()
fake_context.socket.return_value = fake_socket
monkeypatch.setattr(
"control_backend.agents.ri_communication_agent.Context",
MagicMock(instance=MagicMock(return_value=fake_context)),
)
# Mock RICommandAgent agent startup # Mock RICommandAgent agent startup
with patch( with patch(
"control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True "control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True
@@ -333,24 +300,16 @@ async def test_setup_creates_socket_and_negotiate_5(monkeypatch):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_setup_creates_socket_and_negotiate_6(monkeypatch): async def test_setup_creates_socket_and_negotiate_6(zmq_context):
""" """
Test the setup of the communication agent Test the setup of the communication agent
""" """
# --- Arrange --- # --- Arrange ---
fake_socket = MagicMock() fake_socket = zmq_context.return_value.socket.return_value
fake_socket.send_json = AsyncMock() fake_socket.send_json = AsyncMock()
fake_socket.recv_json = fake_json_correct_negototiate_5() fake_socket.recv_json = fake_json_correct_negototiate_5()
fake_socket.send_multipart = AsyncMock() fake_socket.send_multipart = AsyncMock()
# Mock context.socket to return our fake socket
fake_context = MagicMock()
fake_context.socket.return_value = fake_socket
monkeypatch.setattr(
"control_backend.agents.ri_communication_agent.Context",
MagicMock(instance=MagicMock(return_value=fake_context)),
)
# Mock RICommandAgent agent startup # Mock RICommandAgent agent startup
with patch( with patch(
"control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True "control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True
@@ -382,28 +341,20 @@ async def test_setup_creates_socket_and_negotiate_6(monkeypatch):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_setup_creates_socket_and_negotiate_7(monkeypatch, caplog): async def test_setup_creates_socket_and_negotiate_7(zmq_context, caplog):
""" """
Test the functionality of setup with incorrect id Test the functionality of setup with incorrect id
""" """
# --- Arrange --- # --- Arrange ---
fake_socket = MagicMock() fake_socket = zmq_context.return_value.socket.return_value
fake_socket.send_json = AsyncMock() fake_socket.send_json = AsyncMock()
fake_socket.recv_json = fake_json_invalid_id_negototiate() fake_socket.recv_json = fake_json_invalid_id_negototiate()
fake_socket.send_multipart = AsyncMock() fake_socket.send_multipart = AsyncMock()
# Mock context.socket to return our fake socket
fake_context = MagicMock()
fake_context.socket.return_value = fake_socket
monkeypatch.setattr(
"control_backend.agents.ri_communication_agent.Context",
MagicMock(instance=MagicMock(return_value=fake_context)),
)
# Mock RICommandAgent agent startup # Mock RICommandAgent agent startup
# We are sending wrong negotiation info to the communication agent, # We are sending wrong negotiation info to the communication agent,
# so we should retry and expect a etter response, within a limited time. # so we should retry and expect a better response, within a limited time.
with patch( with patch(
"control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True "control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True
) as MockCommandAgent: ) as MockCommandAgent:
@@ -430,24 +381,16 @@ async def test_setup_creates_socket_and_negotiate_7(monkeypatch, caplog):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_setup_creates_socket_and_negotiate_timeout(monkeypatch, caplog): async def test_setup_creates_socket_and_negotiate_timeout(zmq_context, caplog):
""" """
Test the functionality of setup with incorrect negotiation message Test the functionality of setup with incorrect negotiation message
""" """
# --- Arrange --- # --- Arrange ---
fake_socket = MagicMock() fake_socket = zmq_context.return_value.socket.return_value
fake_socket.send_json = AsyncMock() fake_socket.send_json = AsyncMock()
fake_socket.recv_json = AsyncMock(side_effect=asyncio.TimeoutError) fake_socket.recv_json = AsyncMock(side_effect=asyncio.TimeoutError)
fake_socket.send_multipart = AsyncMock() fake_socket.send_multipart = AsyncMock()
# Mock context.socket to return our fake socket
fake_context = MagicMock()
fake_context.socket.return_value = fake_socket
monkeypatch.setattr(
"control_backend.agents.ri_communication_agent.Context",
MagicMock(instance=MagicMock(return_value=fake_context)),
)
with patch( with patch(
"control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True "control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True
) as MockCommandAgent: ) as MockCommandAgent:
@@ -534,8 +477,8 @@ async def test_listen_behaviour_ping_wrong_endpoint(caplog):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_listen_behaviour_timeout(caplog): async def test_listen_behaviour_timeout(zmq_context, caplog):
fake_socket = AsyncMock() fake_socket = zmq_context.return_value.socket.return_value
fake_socket.send_json = AsyncMock() fake_socket.send_json = AsyncMock()
# recv_json will never resolve, simulate timeout # recv_json will never resolve, simulate timeout
fake_socket.recv_json = AsyncMock(side_effect=asyncio.TimeoutError) fake_socket.recv_json = AsyncMock(side_effect=asyncio.TimeoutError)
@@ -585,20 +528,13 @@ async def test_listen_behaviour_ping_no_endpoint(caplog):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_setup_unexpected_exception(monkeypatch, caplog): async def test_setup_unexpected_exception(zmq_context, caplog):
fake_socket = MagicMock() fake_socket = zmq_context.return_value.socket.return_value
fake_socket.send_json = AsyncMock() fake_socket.send_json = AsyncMock()
# Simulate unexpected exception during recv_json() # Simulate unexpected exception during recv_json()
fake_socket.recv_json = AsyncMock(side_effect=Exception("boom!")) fake_socket.recv_json = AsyncMock(side_effect=Exception("boom!"))
fake_socket.send_multipart = AsyncMock() fake_socket.send_multipart = AsyncMock()
fake_context = MagicMock()
fake_context.socket.return_value = fake_socket
monkeypatch.setattr(
"control_backend.agents.ri_communication_agent.Context",
MagicMock(instance=MagicMock(return_value=fake_context)),
)
agent = RICommunicationAgent( agent = RICommunicationAgent(
"test@server", "test@server",
"password", "password",
@@ -614,9 +550,9 @@ async def test_setup_unexpected_exception(monkeypatch, caplog):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_setup_unpacking_exception(monkeypatch, caplog): async def test_setup_unpacking_exception(zmq_context, caplog):
# --- Arrange --- # --- Arrange ---
fake_socket = MagicMock() fake_socket = zmq_context.return_value.socket.return_value
fake_socket.send_json = AsyncMock() fake_socket.send_json = AsyncMock()
fake_socket.send_multipart = AsyncMock() fake_socket.send_multipart = AsyncMock()
@@ -627,14 +563,6 @@ async def test_setup_unpacking_exception(monkeypatch, caplog):
} # missing 'port' and 'bind' } # missing 'port' and 'bind'
fake_socket.recv_json = AsyncMock(return_value=malformed_data) fake_socket.recv_json = AsyncMock(return_value=malformed_data)
# Patch context.socket
fake_context = MagicMock()
fake_context.socket.return_value = fake_socket
monkeypatch.setattr(
"control_backend.agents.ri_communication_agent.Context",
MagicMock(instance=MagicMock(return_value=fake_context)),
)
# Patch RICommandAgent so it won't actually start # Patch RICommandAgent so it won't actually start
with patch( with patch(
"control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True "control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True

View File

@@ -10,7 +10,9 @@ from control_backend.agents.vad_agent import VADAgent
@pytest.fixture @pytest.fixture
def zmq_context(mocker): def zmq_context(mocker):
return mocker.patch("control_backend.agents.vad_agent.zmq_context") mock_context = mocker.patch("control_backend.agents.vad_agent.azmq.Context.instance")
mock_context.return_value = MagicMock()
return mock_context
@pytest.fixture @pytest.fixture
@@ -54,13 +56,18 @@ def test_in_socket_creation(zmq_context, do_bind: bool):
assert vad_agent.audio_in_socket is not None assert vad_agent.audio_in_socket is not None
zmq_context.socket.assert_called_once_with(zmq.SUB) zmq_context.return_value.socket.assert_called_once_with(zmq.SUB)
zmq_context.socket.return_value.setsockopt_string.assert_called_once_with(zmq.SUBSCRIBE, "") zmq_context.return_value.socket.return_value.setsockopt_string.assert_called_once_with(
zmq.SUBSCRIBE,
"",
)
if do_bind: if do_bind:
zmq_context.socket.return_value.bind.assert_called_once_with("tcp://*:12345") zmq_context.return_value.socket.return_value.bind.assert_called_once_with("tcp://*:12345")
else: else:
zmq_context.socket.return_value.connect.assert_called_once_with("tcp://localhost:12345") zmq_context.return_value.socket.return_value.connect.assert_called_once_with(
"tcp://localhost:12345"
)
def test_out_socket_creation(zmq_context): def test_out_socket_creation(zmq_context):
@@ -73,8 +80,8 @@ def test_out_socket_creation(zmq_context):
assert vad_agent.audio_out_socket is not None assert vad_agent.audio_out_socket is not None
zmq_context.socket.assert_called_once_with(zmq.PUB) zmq_context.return_value.socket.assert_called_once_with(zmq.PUB)
zmq_context.socket.return_value.bind_to_random_port.assert_called_once() zmq_context.return_value.socket.return_value.bind_to_random_port.assert_called_once()
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -83,7 +90,9 @@ async def test_out_socket_creation_failure(zmq_context):
Test setup failure when the audio output socket cannot be created. Test setup failure when the audio output socket cannot be created.
""" """
with patch.object(Agent, "stop", new_callable=AsyncMock) as mock_super_stop: with patch.object(Agent, "stop", new_callable=AsyncMock) as mock_super_stop:
zmq_context.socket.return_value.bind_to_random_port.side_effect = zmq.ZMQBindError zmq_context.return_value.socket.return_value.bind_to_random_port.side_effect = (
zmq.ZMQBindError
)
vad_agent = VADAgent("tcp://localhost:12345", False) vad_agent = VADAgent("tcp://localhost:12345", False)
await vad_agent.setup() await vad_agent.setup()
@@ -98,11 +107,14 @@ async def test_stop(zmq_context, transcription_agent):
Test that when the VAD agent is stopped, the sockets are closed correctly. Test that when the VAD agent is stopped, the sockets are closed correctly.
""" """
vad_agent = VADAgent("tcp://localhost:12345", False) vad_agent = VADAgent("tcp://localhost:12345", False)
zmq_context.socket.return_value.bind_to_random_port.return_value = random.randint(1000, 10000) zmq_context.return_value.socket.return_value.bind_to_random_port.return_value = random.randint(
1000,
10000,
)
await vad_agent.setup() await vad_agent.setup()
await vad_agent.stop() await vad_agent.stop()
assert zmq_context.socket.return_value.close.call_count == 2 assert zmq_context.return_value.socket.return_value.close.call_count == 2
assert vad_agent.audio_in_socket is None assert vad_agent.audio_in_socket is None
assert vad_agent.audio_out_socket is None assert vad_agent.audio_out_socket is None

View File

@@ -48,6 +48,7 @@ async def test_real_audio(mocker):
audio_out_socket = AsyncMock() audio_out_socket = AsyncMock()
vad_streamer = Streaming(audio_in_socket, audio_out_socket) vad_streamer = Streaming(audio_in_socket, audio_out_socket)
vad_streamer._ready = True
for _ in audio_chunks: for _ in audio_chunks:
await vad_streamer.run() await vad_streamer.run()

View File

@@ -1,4 +1,4 @@
from unittest.mock import MagicMock from unittest.mock import AsyncMock
import pytest import pytest
from fastapi import FastAPI from fastapi import FastAPI
@@ -16,7 +16,6 @@ def app():
""" """
app = FastAPI() app = FastAPI()
app.include_router(robot.router) app.include_router(robot.router)
app.state.internal_comm_socket = MagicMock() # mock ZMQ socket
return app return app
@@ -26,32 +25,30 @@ def client(app):
return TestClient(app) return TestClient(app)
def test_receive_command_endpoint(client, app): def test_receive_command_success(client):
""" """
Test that a POST to /command sends the right multipart message Test for successful reception of a command. Ensures the status code is 202 and the response body
and returns a 202 with the expected JSON body. is correct. It also verifies that the ZeroMQ socket's send_multipart method is called with the
expected data.
""" """
mock_socket = app.state.internal_comm_socket # Arrange
mock_pub_socket = AsyncMock()
client.app.state.endpoints_pub_socket = mock_pub_socket
# Prepare test payload that matches SpeechCommand command_data = {"endpoint": "actuate/speech", "data": "This is a test"}
payload = {"endpoint": "actuate/speech", "data": "yooo"} speech_command = SpeechCommand(**command_data)
# Send POST request # Act
response = client.post("/command", json=payload) response = client.post("/command", json=command_data)
# Check response # Assert
assert response.status_code == 202 assert response.status_code == 202
assert response.json() == {"status": "Command received"} assert response.json() == {"status": "Command received"}
# Verify that the socket was called with the correct data # Verify that the ZMQ socket was used correctly
assert mock_socket.send_multipart.called, "Socket should be used to send data" mock_pub_socket.send_multipart.assert_awaited_once_with(
[b"command", speech_command.model_dump_json().encode()]
args, kwargs = mock_socket.send_multipart.call_args )
sent_data = args[0]
assert sent_data[0] == b"command"
# Check JSON encoding roughly matches
assert isinstance(SpeechCommand.model_validate_json(sent_data[1].decode()), SpeechCommand)
def test_receive_command_invalid_payload(client): def test_receive_command_invalid_payload(client):

View File

@@ -16,12 +16,11 @@ def test_valid_speech_command_1():
command = valid_command_1() command = valid_command_1()
RIMessage.model_validate(command) RIMessage.model_validate(command)
SpeechCommand.model_validate(command) SpeechCommand.model_validate(command)
assert True
def test_invalid_speech_command_1(): def test_invalid_speech_command_1():
command = invalid_command_1() command = invalid_command_1()
RIMessage.model_validate(command) RIMessage.model_validate(command)
with pytest.raises(ValidationError): with pytest.raises(ValidationError):
SpeechCommand.model_validate(command) SpeechCommand.model_validate(command)
assert True

View File

@@ -182,8 +182,6 @@ async def test_belief_text_values_not_lists(continuous_collector, mocker):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_belief_text_happy_path_logs_items_and_sends(continuous_collector, mocker): async def test_belief_text_happy_path_logs_items_and_sends(continuous_collector, mocker):
payload = {"type": "belief_extraction_text", "beliefs": {"user_said": ["hello test", "No"]}} payload = {"type": "belief_extraction_text", "beliefs": {"user_said": ["hello test", "No"]}}
# Your code calls self.send(..); patch it
# (or switch implementation to self.agent.send and patch that)
continuous_collector.send = AsyncMock() continuous_collector.send = AsyncMock()
logger_mock = mocker.patch( logger_mock = mocker.patch(
"control_backend.agents.belief_collector.behaviours.continuous_collect.logger" "control_backend.agents.belief_collector.behaviours.continuous_collect.logger"

View File

@@ -21,7 +21,9 @@ def streaming(audio_in_socket, audio_out_socket):
import torch import torch
torch.hub.load.return_value = (..., ...) # Mock torch.hub.load.return_value = (..., ...) # Mock
return Streaming(audio_in_socket, audio_out_socket) streaming = Streaming(audio_in_socket, audio_out_socket)
streaming._ready = True
return streaming
async def simulate_streaming_with_probabilities(streaming, probabilities: list[float]): async def simulate_streaming_with_probabilities(streaming, probabilities: list[float]):

View File

@@ -5,12 +5,13 @@ from control_backend.agents.transcription.speech_recognizer import OpenAIWhisper
def test_estimate_max_tokens(): def test_estimate_max_tokens():
"""Inputting one minute of audio, assuming 300 words per minute, expecting 400 tokens.""" """Inputting one minute of audio, assuming 450 words per minute and adding a 10 token padding,
expecting 610 tokens."""
audio = np.empty(shape=(60 * 16_000), dtype=np.float32) audio = np.empty(shape=(60 * 16_000), dtype=np.float32)
actual = SpeechRecognizer._estimate_max_tokens(audio) actual = SpeechRecognizer._estimate_max_tokens(audio)
assert actual == 400 assert actual == 610
assert isinstance(actual, int) assert isinstance(actual, int)