Compare commits

..

39 Commits

Author SHA1 Message Date
Pim Hutting
02063a73b2 feat: added recursive mapping
ref: N25B-400
2026-01-22 11:15:13 +01:00
Storm
b9a47eeb0c Merge branch 'feat/visual-emotion-recognition' into demo 2026-01-20 12:48:27 +01:00
Storm
f9b807fc97 chore: quick push before demo; fixed image receiving from RI 2026-01-20 12:46:30 +01:00
8575ddcbcf feat: add experiment log for phase transition
ref: N25B-453
2026-01-20 12:30:47 +01:00
59b35b31b2 feat: add UI log statement for triggers
ref: N25B-453
2026-01-20 12:08:31 +01:00
Twirre Meulenbelt
7516667545 feat: add useful experiment logs to various agents
ref: N25B-401
2026-01-20 11:58:30 +01:00
651f1b74a6 chore: long timeout for non-local LLM 2026-01-20 11:55:00 +01:00
5ed751de8c chore: add logs to .gitignore 2026-01-20 11:05:32 +01:00
89ebe45724 Merge remote-tracking branch 'origin/feat/experiment-logging' into demo 2026-01-20 11:04:31 +01:00
Storm
424294b0a3 Merged feat/longer-pauses-possible into feat/visual-emotion-recognition 2026-01-19 18:35:07 +01:00
Pim Hutting
bc0947fac1 chore: added a dot 2026-01-19 18:26:15 +01:00
Storm
cd80cdf93b Merge branch 'feat/longer-pauses-possible' into feat/visual-emotion-recognition 2026-01-19 18:24:31 +01:00
230afef16f test: fix tests
ref: N25B-452
2026-01-19 16:06:17 +01:00
1cd5b46f97 fix: should work now
Also added trimming to Windows transcription.

ref: N25B-452
2026-01-19 15:03:59 +01:00
c0789e82a9 feat: add previously interrupted message to current
ref: N25B-452
2026-01-19 14:47:11 +01:00
04d19cee5c feat: (maybe) stop response when new user message
If we get a new message before the LLM is done responding, interrupt it.

ref: N25B-452
2026-01-19 14:08:26 +01:00
Storm
985327de70 docs: updated docstrings and fixed styling
ref: N25B-393
2026-01-19 12:52:00 +01:00
Twirre Meulenbelt
58881b5914 test: add test cases
ref: N25B-401
2026-01-19 12:47:59 +01:00
Storm
302c50934e feat: implemented emotion recognition functionality in AgentSpeak
ref: N25B-393
2026-01-19 12:10:58 +01:00
Storm
f9c69cafb3 Merge branch 'feat/reset-experiment-and-phase' into feat/visual-emotion-recognition 2026-01-19 11:45:31 +01:00
Twirre Meulenbelt
ba79d09c5d feat: log download endpoints
ref: N25B-401
2026-01-16 16:32:51 +01:00
db64eaeb0b fix: failing tests and warnings
ref: N25B-449
2026-01-16 16:18:36 +01:00
7f7e0c542e docs: add missing docs
ref: N25B-115
2026-01-16 15:35:41 +01:00
Storm
1b0b72d63a chore: fixed broken uv.lock file 2026-01-16 15:10:55 +01:00
41bd3ffc50 Merge branch 'test/increase-coverage' into feat/reset-experiment-and-phase 2026-01-16 15:08:34 +01:00
Storm
0941b26703 refactor: updated how changes are passed to bdi_core_agent after merge
ref: N25B-393
2026-01-16 15:05:13 +01:00
Storm
48ae0c7a12 Merge remote-tracking branch 'origin/feat/reset-experiment-and-phase' into feat/visual-emotion-recognition 2026-01-16 14:45:16 +01:00
Storm
a09d8b3d9a chore: small changes 2026-01-16 14:40:59 +01:00
Pim Hutting
7c10c50336 chore: removed resetExperiment from backened
now it happens in UI

ref: N25B-400
2026-01-16 14:29:46 +01:00
Pim Hutting
6d03ba8a41 feat: added extra endpoint for norm pings
also made sure that you cannot skip phase on end phase

ref: N25B-400
2026-01-16 14:28:27 +01:00
Storm
ac20048f02 Merge branch 'dev' into feat/visual-emotion-recognition 2026-01-16 14:16:28 +01:00
Storm
05804c158d feat: fully implemented visual emotion recognition agent in pipeline
ref: N25B-393
2026-01-16 13:26:53 +01:00
Storm
0771b0d607 feat: implemented visual emotion recogntion agent
ref: N25B-393
2026-01-16 09:50:59 +01:00
Twirre Meulenbelt
4cda4e5e70 feat: experiment log stream, to file and UI
Adds a few new logging utility classes. One to save to files with a date, one to support optional fields in formats, last to filter partial log messages.

ref: N25B-401
2026-01-15 17:07:49 +01:00
Luijkx,S.O.H. (Storm)
a9df9208bc Merge branch 'feat/multiple-receivers' into 'dev'
feat: able to send to multiple receivers

See merge request ics/sp/2025/n25b/pepperplus-cb!42
2026-01-15 09:26:12 +00:00
Pim Hutting
041fc4ab6e chore: cond_norms unachieve and via belief msg 2026-01-15 09:02:52 +01:00
Twirre Meulenbelt
d7d697b293 docs: update to docstring
ref: N25B-441
2026-01-13 17:09:26 +01:00
Twirre Meulenbelt
9a55067a13 fix: set sender for internal messages
ref: N25B-441
2026-01-13 17:07:17 +01:00
Storm
1c88ae6078 feat: visual emotion recognition agent
ref: N25B-393
2026-01-13 12:41:18 +01:00
50 changed files with 2239 additions and 279 deletions

2
.gitignore vendored
View File

@@ -224,7 +224,7 @@ docs/*
# Generated files
agentspeak.asl
experiment-*.log

View File

@@ -1,36 +1,57 @@
version: 1
custom_levels:
OBSERVATION: 25
ACTION: 26
OBSERVATION: 24
ACTION: 25
CHAT: 26
LLM: 9
formatters:
# Console output
colored:
(): "colorlog.ColoredFormatter"
class: colorlog.ColoredFormatter
format: "{log_color}{asctime}.{msecs:03.0f} | {levelname:11} | {name:70} | {message}"
style: "{"
datefmt: "%H:%M:%S"
# User-facing UI (structured JSON)
json_experiment:
(): "pythonjsonlogger.jsonlogger.JsonFormatter"
json:
class: pythonjsonlogger.jsonlogger.JsonFormatter
format: "{name} {levelname} {levelno} {message} {created} {relativeCreated}"
style: "{"
# Experiment stream for console and file output, with optional `role` field
experiment:
class: control_backend.logging.OptionalFieldFormatter
format: "%(asctime)s %(levelname)s %(role?)s %(message)s"
defaults:
role: "-"
filters:
# Filter out any log records that have the extra field "partial" set to True, indicating that they
# will be replaced later.
partial:
(): control_backend.logging.PartialFilter
handlers:
console:
class: logging.StreamHandler
level: DEBUG
formatter: colored
filters: [partial]
stream: ext://sys.stdout
ui:
class: zmq.log.handlers.PUBHandler
level: LLM
formatter: json_experiment
formatter: json
file:
class: control_backend.logging.DatedFileHandler
formatter: experiment
filters: [partial]
# Directory must match config.logging_settings.experiment_log_directory
file_prefix: experiment_logs/experiment
# Level of external libraries
# Level for external libraries
root:
level: WARN
handlers: [console]
@@ -39,3 +60,6 @@ loggers:
control_backend:
level: LLM
handlers: [ui]
experiment: # This name must match config.logging_settings.experiment_logger_name
level: DEBUG
handlers: [ui, file]

View File

@@ -7,6 +7,7 @@ requires-python = ">=3.13"
dependencies = [
"agentspeak>=0.2.2",
"colorlog>=6.10.1",
"deepface>=0.0.96",
"fastapi[all]>=0.115.6",
"mlx-whisper>=0.4.3 ; sys_platform == 'darwin'",
"numpy>=2.3.3",
@@ -21,6 +22,7 @@ dependencies = [
"silero-vad>=6.0.0",
"sphinx>=7.3.7",
"sphinx-rtd-theme>=3.0.2",
"tf-keras>=2.20.1",
"torch>=2.8.0",
"uvicorn>=0.37.0",
]
@@ -48,6 +50,7 @@ test = [
"pytest-asyncio>=1.2.0",
"pytest-cov>=7.0.0",
"pytest-mock>=3.15.1",
"python-slugify>=8.0.4",
"pyyaml>=6.0.3",
"pyzmq>=27.1.0",
"soundfile>=0.13.1",

View File

@@ -1 +1,5 @@
"""
This package contains all agent implementations for the PepperPlus Control Backend.
"""
from .base import BaseAgent as BaseAgent

View File

@@ -1,2 +1,6 @@
"""
Agents responsible for controlling the robot's physical actions, such as speech and gestures.
"""
from .robot_gesture_agent import RobotGestureAgent as RobotGestureAgent
from .robot_speech_agent import RobotSpeechAgent as RobotSpeechAgent

View File

@@ -1,4 +1,5 @@
import json
import logging
import zmq
import zmq.asyncio as azmq
@@ -8,6 +9,8 @@ from control_backend.core.agent_system import InternalMessage
from control_backend.core.config import settings
from control_backend.schemas.ri_message import GestureCommand, RIEndpoint
experiment_logger = logging.getLogger(settings.logging_settings.experiment_logger_name)
class RobotGestureAgent(BaseAgent):
"""
@@ -111,6 +114,7 @@ class RobotGestureAgent(BaseAgent):
gesture_command.data,
)
return
experiment_logger.action("Gesture: %s", gesture_command.data)
await self.pubsocket.send_json(gesture_command.model_dump())
except Exception:
self.logger.exception("Error processing internal message.")

View File

@@ -1,9 +1,10 @@
import logging
from abc import ABC
from control_backend.core.agent_system import BaseAgent as CoreBaseAgent
class BaseAgent(CoreBaseAgent):
class BaseAgent(CoreBaseAgent, ABC):
"""
The primary base class for all implementation agents.

View File

@@ -1,3 +1,8 @@
"""
Agents and utilities for the BDI (Belief-Desire-Intention) reasoning system,
implementing AgentSpeak(L) logic.
"""
from control_backend.agents.bdi.bdi_core_agent import BDICoreAgent as BDICoreAgent
from .text_belief_extractor_agent import (

View File

@@ -80,7 +80,7 @@ class AstTerm(AstExpression, ABC):
@dataclass(eq=False)
class AstAtom(AstTerm):
"""
Grounded expression in all lowercase.
Represents a grounded atom in AgentSpeak (e.g., lowercase constants).
"""
value: str
@@ -92,7 +92,7 @@ class AstAtom(AstTerm):
@dataclass(eq=False)
class AstVar(AstTerm):
"""
Ungrounded variable expression. First letter capitalized.
Represents an ungrounded variable in AgentSpeak (e.g., capitalized names).
"""
name: str
@@ -103,6 +103,10 @@ class AstVar(AstTerm):
@dataclass(eq=False)
class AstNumber(AstTerm):
"""
Represents a numeric constant in AgentSpeak.
"""
value: int | float
def _to_agentspeak(self) -> str:
@@ -111,6 +115,10 @@ class AstNumber(AstTerm):
@dataclass(eq=False)
class AstString(AstTerm):
"""
Represents a string literal in AgentSpeak.
"""
value: str
def _to_agentspeak(self) -> str:
@@ -119,6 +127,10 @@ class AstString(AstTerm):
@dataclass(eq=False)
class AstLiteral(AstTerm):
"""
Represents a literal (functor and terms) in AgentSpeak.
"""
functor: str
terms: list[AstTerm] = field(default_factory=list)
@@ -142,6 +154,10 @@ class BinaryOperatorType(StrEnum):
@dataclass
class AstBinaryOp(AstExpression):
"""
Represents a binary logical or relational operation in AgentSpeak.
"""
left: AstExpression
operator: BinaryOperatorType
right: AstExpression
@@ -167,6 +183,10 @@ class AstBinaryOp(AstExpression):
@dataclass
class AstLogicalExpression(AstExpression):
"""
Represents a logical expression, potentially negated, in AgentSpeak.
"""
expression: AstExpression
negated: bool = False
@@ -208,6 +228,10 @@ class AstStatement(AstNode):
@dataclass
class AstRule(AstNode):
"""
Represents an inference rule in AgentSpeak. If there is no condition, it always holds.
"""
result: AstExpression
condition: AstExpression | None = None
@@ -231,6 +255,10 @@ class TriggerType(StrEnum):
@dataclass
class AstPlan(AstNode):
"""
Represents a plan in AgentSpeak, consisting of a trigger, context, and body.
"""
type: TriggerType
trigger_literal: AstExpression
context: list[AstExpression]
@@ -260,6 +288,10 @@ class AstPlan(AstNode):
@dataclass
class AstProgram(AstNode):
"""
Represents a full AgentSpeak program, consisting of rules and plans.
"""
rules: list[AstRule] = field(default_factory=list)
plans: list[AstPlan] = field(default_factory=list)

View File

@@ -22,6 +22,7 @@ from control_backend.schemas.program import (
BaseGoal,
BasicNorm,
ConditionalNorm,
EmotionBelief,
GestureAction,
Goal,
InferredBelief,
@@ -40,9 +41,23 @@ from control_backend.schemas.program import (
class AgentSpeakGenerator:
"""
Generator class that translates a high-level :class:`~control_backend.schemas.program.Program`
into AgentSpeak(L) source code.
It handles the conversion of phases, norms, goals, and triggers into AgentSpeak rules and plans,
ensuring the robot follows the defined behavioral logic.
"""
_asp: AstProgram
def generate(self, program: Program) -> str:
"""
Translates a Program object into an AgentSpeak source string.
:param program: The behavior program to translate.
:return: The generated AgentSpeak code as a string.
"""
self._asp = AstProgram()
if program.phases:
@@ -424,6 +439,16 @@ class AgentSpeakGenerator:
)
)
# Force phase transition fallback
self._asp.plans.append(
AstPlan(
TriggerType.ADDED_GOAL,
AstLiteral("force_transition_phase"),
[],
[AstStatement(StatementType.EMPTY, AstLiteral("true"))],
)
)
@singledispatchmethod
def _astify(self, element: ProgramElement) -> AstExpression:
raise NotImplementedError(f"Cannot convert element {element} to an AgentSpeak expression.")
@@ -435,6 +460,10 @@ class AgentSpeakGenerator:
@_astify.register
def _(self, sb: SemanticBelief) -> AstExpression:
return AstLiteral(self.slugify(sb))
@_astify.register
def _(self, eb: EmotionBelief) -> AstExpression:
return AstLiteral("emotion_detected", [AstAtom(eb.emotion)])
@_astify.register
def _(self, ib: InferredBelief) -> AstExpression:

View File

@@ -1,6 +1,7 @@
import asyncio
import copy
import json
import logging
import time
from collections.abc import Iterable
@@ -19,6 +20,9 @@ from control_backend.schemas.ri_message import GestureCommand, RIEndpoint, Speec
DELIMITER = ";\n" # TODO: temporary until we support lists in AgentSpeak
experiment_logger = logging.getLogger(settings.logging_settings.experiment_logger_name)
class BDICoreAgent(BaseAgent):
"""
BDI Core Agent.
@@ -207,6 +211,9 @@ class BDICoreAgent(BaseAgent):
else:
term = agentspeak.Literal(name)
if name != "user_said":
experiment_logger.observation(f"Formed new belief: {name}{f'={args}' if args else ''}")
self.bdi_agent.call(
agentspeak.Trigger.addition,
agentspeak.GoalType.belief,
@@ -244,6 +251,9 @@ class BDICoreAgent(BaseAgent):
new_args = (agentspeak.Literal(arg) for arg in args)
term = agentspeak.Literal(name, new_args)
if name != "user_said":
experiment_logger.observation(f"Removed belief: {name}{f'={args}' if args else ''}")
result = self.bdi_agent.call(
agentspeak.Trigger.removal,
agentspeak.GoalType.belief,
@@ -338,7 +348,7 @@ class BDICoreAgent(BaseAgent):
yield
@self.actions.add(".reply_with_goal", 3)
def _reply_with_goal(agent: "BDICoreAgent", term, intention):
def _reply_with_goal(agent, term, intention):
"""
Let the LLM generate a response to a user's utterance with the current norms and a
specific goal.
@@ -386,6 +396,8 @@ class BDICoreAgent(BaseAgent):
body=str(message_text),
)
experiment_logger.chat(str(message_text), extra={"role": "assistant"})
self.add_behavior(self.send(chat_history_message))
yield
@@ -441,6 +453,7 @@ class BDICoreAgent(BaseAgent):
trigger_name = agentspeak.grounded(term.args[0], intention.scope)
self.logger.debug("Started trigger %s", trigger_name)
experiment_logger.observation("Triggered: %s", trigger_name)
msg = InternalMessage(
to=settings.agent_settings.user_interrupt_name,
@@ -512,10 +525,6 @@ class BDICoreAgent(BaseAgent):
yield
@self.actions.add(".notify_ui", 0)
def _notify_ui(agent, term, intention):
pass
async def _send_to_llm(self, text: str, norms: str, goals: str):
"""
Sends a text query to the LLM agent asynchronously.

View File

@@ -18,6 +18,12 @@ type JSONLike = None | bool | int | float | str | list["JSONLike"] | dict[str, "
class BeliefState(BaseModel):
"""
Represents the state of inferred semantic beliefs.
Maintains sets of beliefs that are currently considered true or false.
"""
true: set[InternalBelief] = set()
false: set[InternalBelief] = set()
@@ -312,6 +318,9 @@ class TextBeliefExtractorAgent(BaseAgent):
async with httpx.AsyncClient() as client:
response = await client.post(
settings.llm_settings.local_llm_url,
headers={"Authorization": f"Bearer {settings.llm_settings.api_key}"}
if settings.llm_settings.api_key
else {},
json={
"model": settings.llm_settings.local_llm_model,
"messages": [{"role": "user", "content": prompt}],
@@ -338,7 +347,7 @@ class TextBeliefExtractorAgent(BaseAgent):
class SemanticBeliefInferrer:
"""
Class that handles only prompting an LLM for semantic beliefs.
Infers semantic beliefs from conversation history using an LLM.
"""
def __init__(
@@ -464,6 +473,10 @@ Respond with a JSON similar to the following, but with the property names as giv
class GoalAchievementInferrer(SemanticBeliefInferrer):
"""
Infers whether specific conversational goals have been achieved using an LLM.
"""
def __init__(self, llm: TextBeliefExtractorAgent.LLM):
super().__init__(llm)
self.goals: set[BaseGoal] = set()

View File

@@ -1 +1,5 @@
"""
Agents responsible for external communication and service discovery.
"""
from .ri_communication_agent import RICommunicationAgent as RICommunicationAgent

View File

@@ -8,6 +8,9 @@ from zmq.asyncio import Context
from control_backend.agents import BaseAgent
from control_backend.agents.actuation.robot_gesture_agent import RobotGestureAgent
from control_backend.agents.perception.visual_emotion_recognition_agent.visual_emotion_recognition_agent import ( # noqa
VisualEmotionRecognitionAgent,
)
from control_backend.core.config import settings
from control_backend.schemas.internal_message import InternalMessage
from control_backend.schemas.ri_message import PauseCommand
@@ -52,6 +55,7 @@ class RICommunicationAgent(BaseAgent):
self.connected = False
self.gesture_agent: RobotGestureAgent | None = None
self.speech_agent: RobotSpeechAgent | None = None
self.visual_emotion_recognition_agent: VisualEmotionRecognitionAgent | None = None
async def setup(self):
"""
@@ -209,6 +213,14 @@ class RICommunicationAgent(BaseAgent):
case "audio":
vad_agent = VADAgent(audio_in_address=addr, audio_in_bind=bind)
await vad_agent.start()
case "video":
visual_emotion_agent = VisualEmotionRecognitionAgent(
settings.agent_settings.visual_emotion_recognition_name,
socket_address=addr,
bind=bind,
)
self.visual_emotion_recognition_agent = visual_emotion_agent
await visual_emotion_agent.start()
case _:
self.logger.warning("Unhandled negotiation id: %s", id)
@@ -313,6 +325,9 @@ class RICommunicationAgent(BaseAgent):
if self.speech_agent is not None:
await self.speech_agent.stop()
if self.visual_emotion_recognition_agent is not None:
await self.visual_emotion_recognition_agent.stop()
if self.pub_socket is not None:
self.pub_socket.close()
@@ -322,6 +337,7 @@ class RICommunicationAgent(BaseAgent):
self.connected = True
async def handle_message(self, msg: InternalMessage):
return
try:
pause_command = PauseCommand.model_validate_json(msg.body)
await self._req_socket.send_json(pause_command.model_dump())

View File

@@ -1 +1,5 @@
"""
Agents that interface with Large Language Models for natural language processing and generation.
"""
from .llm_agent import LLMAgent as LLMAgent

View File

@@ -1,4 +1,6 @@
import asyncio
import json
import logging
import re
import uuid
from collections.abc import AsyncGenerator
@@ -13,6 +15,8 @@ from control_backend.core.config import settings
from ...schemas.llm_prompt_message import LLMPromptMessage
from .llm_instructions import LLMInstructions
experiment_logger = logging.getLogger(settings.logging_settings.experiment_logger_name)
class LLMAgent(BaseAgent):
"""
@@ -32,6 +36,10 @@ class LLMAgent(BaseAgent):
def __init__(self, name: str):
super().__init__(name)
self.history = []
self._querying = False
self._interrupted = False
self._interrupted_message = ""
self._go_ahead = asyncio.Event()
async def setup(self):
self.logger.info("Setting up %s.", self.name)
@@ -50,13 +58,13 @@ class LLMAgent(BaseAgent):
case "prompt_message":
try:
prompt_message = LLMPromptMessage.model_validate_json(msg.body)
await self._process_bdi_message(prompt_message)
self.add_behavior(self._process_bdi_message(prompt_message)) # no block
except ValidationError:
self.logger.debug("Prompt message from BDI core is invalid.")
case "assistant_message":
self.history.append({"role": "assistant", "content": msg.body})
self._apply_conversation_message({"role": "assistant", "content": msg.body})
case "user_message":
self.history.append({"role": "user", "content": msg.body})
self._apply_conversation_message({"role": "user", "content": msg.body})
elif msg.sender == settings.agent_settings.bdi_program_manager_name:
if msg.body == "clear_history":
self.logger.debug("Clearing conversation history.")
@@ -73,12 +81,45 @@ class LLMAgent(BaseAgent):
:param message: The parsed prompt message containing text, norms, and goals.
"""
if self._querying:
self.logger.debug("Received another BDI prompt while processing previous message.")
self._interrupted = True # interrupt the previous processing
await self._go_ahead.wait() # wait until we get the go-ahead
message.text = f"{self._interrupted_message} {message.text}"
self._go_ahead.clear()
self._querying = True
full_message = ""
async for chunk in self._query_llm(message.text, message.norms, message.goals):
if self._interrupted:
self._interrupted_message = message.text
self.logger.debug("Interrupted processing of previous message.")
break
await self._send_reply(chunk)
full_message += chunk
self.logger.debug("Finished processing BDI message. Response sent in chunks to BDI core.")
await self._send_full_reply(full_message)
else:
self._querying = False
self._apply_conversation_message(
{
"role": "assistant",
"content": full_message,
}
)
self.logger.debug(
"Finished processing BDI message. Response sent in chunks to BDI core."
)
await self._send_full_reply(full_message)
self._go_ahead.set()
self._interrupted = False
def _apply_conversation_message(self, message: dict[str, str]):
if len(self.history) > 0 and message["role"] == self.history[-1]["role"]:
self.history[-1]["content"] += " " + message["content"]
return
self.history.append(message)
async def _send_reply(self, msg: str):
"""
@@ -132,7 +173,7 @@ class LLMAgent(BaseAgent):
*self.history,
]
message_id = str(uuid.uuid4()) # noqa
message_id = str(uuid.uuid4())
try:
full_message = ""
@@ -141,10 +182,9 @@ class LLMAgent(BaseAgent):
full_message += token
current_chunk += token
self.logger.llm(
"Received token: %s",
experiment_logger.chat(
full_message,
extra={"reference": message_id}, # Used in the UI to update old logs
extra={"role": "assistant", "reference": message_id, "partial": True},
)
# Stream the message in chunks separated by punctuation.
@@ -160,11 +200,9 @@ class LLMAgent(BaseAgent):
if current_chunk:
yield current_chunk
self.history.append(
{
"role": "assistant",
"content": full_message,
}
experiment_logger.chat(
full_message,
extra={"role": "assistant", "reference": message_id, "partial": False},
)
except httpx.HTTPError as err:
self.logger.error("HTTP error.", exc_info=err)
@@ -181,10 +219,13 @@ class LLMAgent(BaseAgent):
:yield: Raw text tokens (deltas) from the SSE stream.
:raises httpx.HTTPError: If the API returns a non-200 status.
"""
async with httpx.AsyncClient() as client:
async with httpx.AsyncClient(timeout=httpx.Timeout(20.0)) as client:
async with client.stream(
"POST",
settings.llm_settings.local_llm_url,
headers={"Authorization": f"Bearer {settings.llm_settings.api_key}"}
if settings.llm_settings.api_key
else {},
json={
"model": settings.llm_settings.local_llm_model,
"messages": messages,

View File

@@ -1,3 +1,8 @@
"""
Agents responsible for processing sensory input, such as audio transcription and voice activity
detection.
"""
from .transcription_agent.transcription_agent import (
TranscriptionAgent as TranscriptionAgent,
)

View File

@@ -145,4 +145,6 @@ class OpenAIWhisperSpeechRecognizer(SpeechRecognizer):
def recognize_speech(self, audio: np.ndarray) -> str:
self.load_model()
return whisper.transcribe(self.model, audio, **self._get_decode_options(audio))["text"]
return whisper.transcribe(self.model, audio, **self._get_decode_options(audio))[
"text"
].strip()

View File

@@ -1,4 +1,5 @@
import asyncio
import logging
import numpy as np
import zmq
@@ -10,6 +11,8 @@ from control_backend.core.config import settings
from .speech_recognizer import SpeechRecognizer
experiment_logger = logging.getLogger(settings.logging_settings.experiment_logger_name)
class TranscriptionAgent(BaseAgent):
"""
@@ -25,6 +28,8 @@ class TranscriptionAgent(BaseAgent):
:ivar audio_in_socket: The ZMQ SUB socket instance.
:ivar speech_recognizer: The speech recognition engine instance.
:ivar _concurrency: Semaphore to limit concurrent transcriptions.
:ivar _current_speech_reference: The reference of the current user utterance, for synchronising
experiment logs.
"""
def __init__(self, audio_in_address: str):
@@ -39,6 +44,7 @@ class TranscriptionAgent(BaseAgent):
self.audio_in_socket: azmq.Socket | None = None
self.speech_recognizer = None
self._concurrency = None
self._current_speech_reference: str | None = None
async def setup(self):
"""
@@ -63,6 +69,10 @@ class TranscriptionAgent(BaseAgent):
self.logger.info("Finished setting up %s", self.name)
async def handle_message(self, msg: InternalMessage):
if msg.thread == "voice_activity":
self._current_speech_reference = msg.body
async def stop(self):
"""
Stop the agent and close sockets.
@@ -74,7 +84,7 @@ class TranscriptionAgent(BaseAgent):
def _connect_audio_in_socket(self):
"""
Helper to connect the ZMQ SUB socket for audio input.
Connects the ZMQ SUB socket for receiving audio data.
"""
self.audio_in_socket = azmq.Context.instance().socket(zmq.SUB)
self.audio_in_socket.setsockopt_string(zmq.SUBSCRIBE, "")
@@ -96,24 +106,25 @@ class TranscriptionAgent(BaseAgent):
async def _share_transcription(self, transcription: str):
"""
Share a transcription to the other agents that depend on it.
Share a transcription to the other agents that depend on it, and to experiment logs.
Currently sends to:
- :attr:`settings.agent_settings.text_belief_extractor_name`
- The UI via the experiment logger
:param transcription: The transcribed text.
"""
receiver_names = [
settings.agent_settings.text_belief_extractor_name,
]
experiment_logger.chat(
transcription,
extra={"role": "user", "reference": self._current_speech_reference, "partial": False},
)
for receiver_name in receiver_names:
message = InternalMessage(
to=receiver_name,
sender=self.name,
body=transcription,
)
await self.send(message)
message = InternalMessage(
to=settings.agent_settings.text_belief_extractor_name,
sender=self.name,
body=transcription,
)
await self.send(message)
async def _transcribing_loop(self) -> None:
"""
@@ -129,10 +140,9 @@ class TranscriptionAgent(BaseAgent):
audio = np.frombuffer(audio_data, dtype=np.float32)
speech = await self._transcribe(audio)
if not speech:
self.logger.info("Nothing transcribed.")
self.logger.debug("Nothing transcribed.")
continue
self.logger.info("Transcribed speech: %s", speech)
await self._share_transcription(speech)
except Exception as e:
self.logger.error(f"Error in transcription loop: {e}")

View File

@@ -1,4 +1,6 @@
import asyncio
import logging
import uuid
import numpy as np
import torch
@@ -12,6 +14,8 @@ from control_backend.schemas.internal_message import InternalMessage
from ...schemas.program_status import PROGRAM_STATUS, ProgramStatus
from .transcription_agent.transcription_agent import TranscriptionAgent
experiment_logger = logging.getLogger(settings.logging_settings.experiment_logger_name)
class SocketPoller[T]:
"""
@@ -252,6 +256,18 @@ class VADAgent(BaseAgent):
if prob > prob_threshold:
if self.i_since_speech > non_speech_patience + begin_silence_length:
self.logger.debug("Speech started.")
reference = str(uuid.uuid4())
experiment_logger.chat(
"...",
extra={"role": "user", "reference": reference, "partial": True},
)
await self.send(
InternalMessage(
to=settings.agent_settings.transcription_name,
body=reference,
thread="voice_activity",
)
)
self.audio_buffer = np.append(self.audio_buffer, chunk)
self.i_since_speech = 0
continue

View File

@@ -0,0 +1,167 @@
import json
import time
from collections import Counter, defaultdict
import cv2
import numpy as np
import zmq
import zmq.asyncio as azmq
from pydantic_core import ValidationError
from control_backend.agents import BaseAgent
from control_backend.agents.perception.visual_emotion_recognition_agent.visual_emotion_recognizer import ( # noqa
DeepFaceEmotionRecognizer,
)
from control_backend.core.agent_system import InternalMessage
from control_backend.core.config import settings
from control_backend.schemas.belief_message import Belief
class VisualEmotionRecognitionAgent(BaseAgent):
def __init__(
self,
name: str,
socket_address: str,
bind: bool = False,
timeout_ms: int = 1000,
window_duration: int = settings.behaviour_settings.visual_emotion_recognition_window_duration_s, # noqa
min_frames_required: int = settings.behaviour_settings.visual_emotion_recognition_min_frames_per_face, # noqa
):
"""
Initialize the Visual Emotion Recognition Agent.
:param name: Name of the agent
:param socket_address: Address of the socket to connect or bind to
:param bind: Whether to bind to the socket address (True) or connect (False)
:param timeout_ms: Timeout for socket receive operations in milliseconds
:param window_duration: Duration in seconds over which to aggregate emotions
:param min_frames_required: Minimum number of frames per face required to consider a face
valid
"""
super().__init__(name)
self.socket_address = socket_address
self.socket_bind = bind
self.timeout_ms = timeout_ms
self.window_duration = window_duration
self.min_frames_required = min_frames_required
async def setup(self):
"""
Initialize the agent resources.
1. Initializes the :class:`VisualEmotionRecognizer`.
2. Connects to the video input ZMQ socket.
3. Starts the background emotion recognition loop.
"""
self.logger.info("Setting up %s.", self.name)
self.emotion_recognizer = DeepFaceEmotionRecognizer()
self.video_in_socket = azmq.Context.instance().socket(zmq.SUB)
if self.socket_bind:
self.video_in_socket.bind(self.socket_address)
else:
self.video_in_socket.connect(self.socket_address)
self.video_in_socket.setsockopt_string(zmq.SUBSCRIBE, "")
self.video_in_socket.setsockopt(zmq.RCVTIMEO, self.timeout_ms)
self.video_in_socket.setsockopt(zmq.CONFLATE, 1)
self.add_behavior(self.emotion_update_loop())
async def emotion_update_loop(self):
"""
Background loop to receive video frames, recognize emotions, and update beliefs.
1. Receives video frames from the ZMQ socket.
2. Uses the :class:`VisualEmotionRecognizer` to detect emotions.
3. Aggregates emotions over a time window.
4. Sends updates to the BDI Core Agent about detected emotions.
"""
# Next time to process the window and update emotions
next_window_time = time.time() + self.window_duration
# Tracks counts of detected emotions per face index
face_stats = defaultdict(Counter)
prev_dominant_emotions = set()
while self._running:
try:
frame_bytes = await self.video_in_socket.recv()
# Convert bytes to a numpy buffer
nparr = np.frombuffer(frame_bytes, np.uint8)
# Decode image into the generic Numpy Array DeepFace expects
frame_image = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
if frame_image is None:
# Could not decode image, skip this frame
self.logger.warning("Received invalid video frame, skipping.")
continue
# Get the dominant emotion from each face
current_emotions = self.emotion_recognizer.sorted_dominant_emotions(frame_image)
# Update emotion counts for each detected face
for i, emotion in enumerate(current_emotions):
face_stats[i][emotion] += 1
# If window duration has passed, process the collected stats
if time.time() >= next_window_time:
window_dominant_emotions = set()
# Determine dominant emotion for each face in the window
for _, counter in face_stats.items():
total_detections = sum(counter.values())
if total_detections >= self.min_frames_required:
dominant_emotion = counter.most_common(1)[0][0]
window_dominant_emotions.add(dominant_emotion)
await self.update_emotions(prev_dominant_emotions, window_dominant_emotions)
prev_dominant_emotions = window_dominant_emotions
face_stats.clear()
next_window_time = time.time() + self.window_duration
except zmq.Again:
self.logger.warning("No video frame received within timeout.")
async def update_emotions(self, prev_emotions: set[str], emotions: set[str]):
"""
Compare emotions from previous window and current emotions,
send updates to BDI Core Agent.
"""
emotions_to_remove = prev_emotions - emotions
emotions_to_add = emotions - prev_emotions
if not emotions_to_add and not emotions_to_remove:
return
emotion_beliefs_remove = []
for emotion in emotions_to_remove:
self.logger.info(f"Emotion '{emotion}' has disappeared.")
try:
emotion_beliefs_remove.append(
Belief(name="emotion_detected", arguments=[emotion], remove=True)
)
except ValidationError:
self.logger.warning("Invalid belief for emotion removal: %s", emotion)
emotion_beliefs_add = []
for emotion in emotions_to_add:
self.logger.info(f"New emotion detected: '{emotion}'")
try:
emotion_beliefs_add.append(Belief(name="emotion_detected", arguments=[emotion]))
except ValidationError:
self.logger.warning("Invalid belief for new emotion: %s", emotion)
beliefs_list_add = [b.model_dump() for b in emotion_beliefs_add]
beliefs_list_remove = [b.model_dump() for b in emotion_beliefs_remove]
payload = {"create": beliefs_list_add, "delete": beliefs_list_remove}
message = InternalMessage(
to=settings.agent_settings.bdi_core_name,
sender=self.name,
body=json.dumps(payload),
thread="beliefs",
)
await self.send(message)

View File

@@ -0,0 +1,55 @@
import abc
import numpy as np
from deepface import DeepFace
class VisualEmotionRecognizer(abc.ABC):
@abc.abstractmethod
def load_model(self):
"""Load the visual emotion recognition model into memory."""
pass
@abc.abstractmethod
def sorted_dominant_emotions(self, image) -> list[str]:
"""
Recognize dominant emotions from faces in the given image.
Emotions can be one of ['angry', 'disgust', 'fear', 'happy', 'sad', 'surprise', 'neutral'].
To minimize false positives, consider filtering faces with low confidence.
:param image: The input image for emotion recognition.
:return: List of dominant emotion detected for each face in the image,
sorted per face.
"""
pass
class DeepFaceEmotionRecognizer(VisualEmotionRecognizer):
"""
DeepFace-based implementation of VisualEmotionRecognizer.
DeepFape has proven to be quite a pessimistic model, so expect sad, fear and neutral
emotions to be over-represented.
"""
def __init__(self):
self.load_model()
def load_model(self):
print("Loading Deepface Emotion Model...")
dummy_img = np.zeros((224, 224, 3), dtype=np.uint8)
# analyze does not take a model as an argument, calling it once on a dummy image to load
# the model
DeepFace.analyze(dummy_img, actions=['emotion'], enforce_detection=False)
print("Deepface Emotion Model loaded.")
def sorted_dominant_emotions(self, image) -> list[str]:
analysis = DeepFace.analyze(image,
actions=['emotion'],
enforce_detection=False
)
# Sort faces by x coordinate to maintain left-to-right order
analysis.sort(key=lambda face: face['region']['x'])
analysis = [face for face in analysis if face['face_confidence'] >= 0.90]
dominant_emotions = [face['dominant_emotion'] for face in analysis]
return dominant_emotions

View File

@@ -1,4 +1,5 @@
import json
import logging
import zmq
from zmq.asyncio import Context
@@ -8,7 +9,7 @@ from control_backend.agents.bdi.agentspeak_generator import AgentSpeakGenerator
from control_backend.core.agent_system import InternalMessage
from control_backend.core.config import settings
from control_backend.schemas.belief_message import Belief, BeliefMessage
from control_backend.schemas.program import ConditionalNorm, Program
from control_backend.schemas.program import ConditionalNorm, Goal, Program
from control_backend.schemas.ri_message import (
GestureCommand,
PauseCommand,
@@ -16,6 +17,8 @@ from control_backend.schemas.ri_message import (
SpeechCommand,
)
experiment_logger = logging.getLogger(settings.logging_settings.experiment_logger_name)
class UserInterruptAgent(BaseAgent):
"""
@@ -26,7 +29,7 @@ class UserInterruptAgent(BaseAgent):
- Send a prioritized message to the `RobotSpeechAgent`
- Send a prioritized gesture to the `RobotGestureAgent`
- Send a belief override to the `BDIProgramManager`in order to activate a
- Send a belief override to the `BDI Core` in order to activate a
trigger/conditional norm or complete a goal.
Prioritized actions clear the current RI queue before inserting the new item,
@@ -50,10 +53,8 @@ class UserInterruptAgent(BaseAgent):
async def setup(self):
"""
Initialize the agent.
Connects the internal ZMQ SUB socket and subscribes to the 'button_pressed' topic.
Starts the background behavior to receive the user interrupts.
Initialize the agent by setting up ZMQ sockets for receiving button events and
publishing updates.
"""
context = Context.instance()
@@ -68,17 +69,15 @@ class UserInterruptAgent(BaseAgent):
async def _receive_button_event(self):
"""
The behaviour of the UserInterruptAgent.
Continuous loop that receives button_pressed events from the button_pressed HTTP endpoint.
These events contain a type and a context.
Main loop to receive and process button press events from the UI.
These are the different types and contexts:
- type: "speech", context: string that the robot has to say.
- type: "gesture", context: single gesture name that the robot has to perform.
- type: "override", context: belief_id that overrides the goal/trigger/conditional norm.
- type: "pause", context: boolean indicating whether to pause
- type: "reset_phase", context: None, indicates to the BDI Core to
- type: "reset_experiment", context: None, indicates to the BDI Core to
Handles different event types:
- `speech`: Triggers immediate robot speech.
- `gesture`: Triggers an immediate robot gesture.
- `override`: Forces a belief, trigger, or goal completion in the BDI core.
- `override_unachieve`: Removes a belief from the BDI core.
- `pause`: Toggles the system's pause state.
- `next_phase` / `reset_phase`: Controls experiment flow.
"""
while True:
topic, body = await self.sub_socket.recv_multipart()
@@ -93,71 +92,88 @@ class UserInterruptAgent(BaseAgent):
self.logger.debug("Received event type %s", event_type)
if event_type == "speech":
await self._send_to_speech_agent(event_context)
self.logger.info(
"Forwarded button press (speech) with context '%s' to RobotSpeechAgent.",
event_context,
)
elif event_type == "gesture":
await self._send_to_gesture_agent(event_context)
self.logger.info(
"Forwarded button press (gesture) with context '%s' to RobotGestureAgent.",
event_context,
)
elif event_type == "override":
ui_id = str(event_context)
if asl_trigger := self._trigger_map.get(ui_id):
await self._send_to_bdi("force_trigger", asl_trigger)
match event_type:
case "speech":
await self._send_to_speech_agent(event_context)
self.logger.info(
"Forwarded button press (override) with context '%s' to BDI Core.",
"Forwarded button press (speech) with context '%s' to RobotSpeechAgent.",
event_context,
)
elif asl_cond_norm := self._cond_norm_map.get(ui_id):
await self._send_to_bdi("force_norm", asl_cond_norm)
case "gesture":
await self._send_to_gesture_agent(event_context)
self.logger.info(
"Forwarded button press (override) with context '%s' to BDIProgramManager.",
"Forwarded button press (gesture) with context '%s' to RobotGestureAgent.",
event_context,
)
elif asl_goal := self._goal_map.get(ui_id):
await self._send_to_bdi_belief(asl_goal)
self.logger.info(
"Forwarded button press (override) with context '%s' to BDI Core.",
case "override":
ui_id = str(event_context)
if asl_trigger := self._trigger_map.get(ui_id):
await self._send_to_bdi("force_trigger", asl_trigger)
self.logger.info(
"Forwarded button press (override) with context '%s' to BDI Core.",
event_context,
)
elif asl_cond_norm := self._cond_norm_map.get(ui_id):
await self._send_to_bdi_belief(asl_cond_norm, "cond_norm")
self.logger.info(
"Forwarded button press (override) with context '%s' to BDI Core.",
event_context,
)
elif asl_goal := self._goal_map.get(ui_id):
await self._send_to_bdi_belief(asl_goal, "goal")
self.logger.info(
"Forwarded button press (override) with context '%s' to BDI Core.",
event_context,
)
# Send achieve_goal to program manager to update semantic belief extractor
goal_achieve_msg = InternalMessage(
to=settings.agent_settings.bdi_program_manager_name,
thread="achieve_goal",
body=ui_id,
)
await self.send(goal_achieve_msg)
else:
self.logger.warning("Could not determine which element to override.")
case "override_unachieve":
ui_id = str(event_context)
if asl_cond_norm := self._cond_norm_map.get(ui_id):
await self._send_to_bdi_belief(asl_cond_norm, "cond_norm", True)
self.logger.info(
"Forwarded button press (override_unachieve)"
"with context '%s' to BDI Core.",
event_context,
)
else:
self.logger.warning(
"Could not determine which conditional norm to unachieve."
)
case "pause":
self.logger.debug(
"Received pause/resume button press with context '%s'.", event_context
)
await self._send_pause_command(event_context)
if event_context:
self.logger.info("Sent pause command.")
else:
self.logger.info("Sent resume command.")
case "next_phase" | "reset_phase":
await self._send_experiment_control_to_bdi_core(event_type)
case _:
self.logger.warning(
"Received button press with unknown type '%s' (context: '%s').",
event_type,
event_context,
)
goal_achieve_msg = InternalMessage(
to=settings.agent_settings.bdi_program_manager_name,
thread="achieve_goal",
body=ui_id,
)
await self.send(goal_achieve_msg)
else:
self.logger.warning("Could not determine which element to override.")
elif event_type == "pause":
self.logger.debug(
"Received pause/resume button press with context '%s'.", event_context
)
await self._send_pause_command(event_context)
if event_context:
self.logger.info("Sent pause command.")
else:
self.logger.info("Sent resume command.")
elif event_type in ["next_phase", "reset_phase", "reset_experiment"]:
await self._send_experiment_control_to_bdi_core(event_type)
else:
self.logger.warning(
"Received button press with unknown type '%s' (context: '%s').",
event_type,
event_context,
)
async def handle_message(self, msg: InternalMessage):
"""
Handle commands received from other internal Python agents.
Handles internal messages from other agents, such as program updates or trigger
notifications.
:param msg: The incoming :class:`~control_backend.core.agent_system.InternalMessage`.
"""
match msg.thread:
case "new_program":
@@ -171,11 +187,9 @@ class UserInterruptAgent(BaseAgent):
payload = {"type": "trigger_update", "id": ui_id, "achieved": True}
await self._send_experiment_update(payload)
self.logger.info(f"UI Update: Trigger {asl_slug} started (ID: {ui_id})")
case "trigger_end":
asl_slug = msg.body
ui_id = self._trigger_reverse_map.get(asl_slug)
if ui_id:
payload = {"type": "trigger_update", "id": ui_id, "achieved": False}
await self._send_experiment_update(payload)
@@ -183,6 +197,7 @@ class UserInterruptAgent(BaseAgent):
case "transition_phase":
new_phase_id = msg.body
self.logger.info(f"Phase transition detected: {new_phase_id}")
experiment_logger.observation("Transitioned to next phase.")
payload = {"type": "phase_update", "id": new_phase_id}
@@ -195,31 +210,37 @@ class UserInterruptAgent(BaseAgent):
await self._send_experiment_update(payload)
self.logger.info(f"UI Update: Goal {goal_name} started (ID: {ui_id})")
case "active_norms_update":
norm_list = [s.strip("() '\",") for s in msg.body.split(",") if s.strip("() '\",")]
await self._broadcast_cond_norms(norm_list)
active_norms_asl = [
s.strip("() '\",") for s in msg.body.split(",") if s.strip("() '\",")
]
await self._broadcast_cond_norms(active_norms_asl)
case _:
self.logger.debug(f"Received internal message on unhandled thread: {msg.thread}")
async def _broadcast_cond_norms(self, active_slugs: list[str]):
"""
Sends the current state of all conditional norms to the UI.
:param active_slugs: A list of slugs (strings) currently active in the BDI core.
Broadcasts the current activation state of all conditional norms to the UI.
:param active_slugs: A list of sluggified norm names currently active in the BDI core.
"""
updates = []
for asl_slug, ui_id in self._cond_norm_reverse_map.items():
is_active = asl_slug in active_slugs
updates.append({"id": ui_id, "name": asl_slug, "active": is_active})
updates.append({"id": ui_id, "active": is_active})
payload = {"type": "cond_norms_state_update", "norms": updates}
await self._send_experiment_update(payload, should_log=False)
# self.logger.debug(f"Broadcasted state for {len(updates)} conditional norms.")
if self.pub_socket:
topic = b"status"
body = json.dumps(payload).encode("utf-8")
await self.pub_socket.send_multipart([topic, body])
# self.logger.info(f"UI Update: Active norms {updates}")
def _create_mapping(self, program_json: str):
"""
Create mappings between UI IDs and ASL slugs for triggers, goals, and conditional norms
Creates a bidirectional mapping between UI identifiers and AgentSpeak slugs.
:param program_json: The JSON representation of the behavioral program.
"""
try:
program = Program.model_validate_json(program_json)
@@ -229,6 +250,18 @@ class UserInterruptAgent(BaseAgent):
self._cond_norm_map = {}
self._cond_norm_reverse_map = {}
def _register_goal(goal: Goal):
"""Recursively register goals and their subgoals."""
slug = AgentSpeakGenerator.slugify(goal)
self._goal_map[str(goal.id)] = slug
self._goal_reverse_map[slug] = str(goal.id)
# Recursively check steps for subgoals
if goal.plan and goal.plan.steps:
for step in goal.plan.steps:
if isinstance(step, Goal):
_register_goal(step)
for phase in program.phases:
for trigger in phase.triggers:
slug = AgentSpeakGenerator.slugify(trigger)
@@ -236,8 +269,7 @@ class UserInterruptAgent(BaseAgent):
self._trigger_reverse_map[slug] = str(trigger.id)
for goal in phase.goals:
self._goal_map[str(goal.id)] = AgentSpeakGenerator.slugify(goal)
self._goal_reverse_map[AgentSpeakGenerator.slugify(goal)] = str(goal.id)
_register_goal(goal)
for goal, id in self._goal_reverse_map.items():
self.logger.debug(f"Goal mapping: UI ID {goal} -> {id}")
@@ -261,8 +293,10 @@ class UserInterruptAgent(BaseAgent):
async def _send_experiment_update(self, data, should_log: bool = True):
"""
Sends an update to the 'experiment' topic.
The SSE endpoint will pick this up and push it to the UI.
Publishes an experiment state update to the internal ZMQ bus for the UI.
:param data: The update payload.
:param should_log: Whether to log the update.
"""
if self.pub_socket:
topic = b"experiment"
@@ -277,6 +311,7 @@ class UserInterruptAgent(BaseAgent):
:param text_to_say: The string that the robot has to say.
"""
experiment_logger.chat(text_to_say, extra={"role": "user"})
cmd = SpeechCommand(data=text_to_say, is_priority=True)
out_msg = InternalMessage(
to=settings.agent_settings.robot_speech_name,
@@ -308,12 +343,20 @@ class UserInterruptAgent(BaseAgent):
await self.send(msg)
self.logger.info(f"Directly forced {thread} in BDI: {body}")
async def _send_to_bdi_belief(self, asl_goal: str):
async def _send_to_bdi_belief(self, asl: str, asl_type: str, unachieve: bool = False):
"""Send belief to BDI Core"""
belief_name = f"achieved_{asl_goal}"
if asl_type == "goal":
belief_name = f"achieved_{asl}"
elif asl_type == "cond_norm":
belief_name = f"force_{asl}"
else:
self.logger.warning("Tried to send belief with unknown type")
belief = Belief(name=belief_name, arguments=None)
self.logger.debug(f"Sending belief to BDI Core: {belief_name}")
belief_message = BeliefMessage(create=[belief])
# Conditional norms are unachieved by removing the belief
belief_message = (
BeliefMessage(delete=[belief]) if unachieve else BeliefMessage(create=[belief])
)
msg = InternalMessage(
to=settings.agent_settings.bdi_core_name,
thread="beliefs",

View File

@@ -1,8 +1,9 @@
import logging
from pathlib import Path
import zmq
from fastapi import APIRouter
from fastapi.responses import StreamingResponse
from fastapi import APIRouter, HTTPException
from fastapi.responses import FileResponse, StreamingResponse
from zmq.asyncio import Context
from control_backend.core.config import settings
@@ -38,3 +39,29 @@ async def log_stream():
yield f"data: {message}\n\n"
return StreamingResponse(gen(), media_type="text/event-stream")
LOGGING_DIR = Path(settings.logging_settings.experiment_log_directory).resolve()
@router.get("/logs/files")
@router.get("/api/logs/files")
async def log_directory():
"""
Get a list of all log files stored in the experiment log file directory.
"""
return [f.name for f in LOGGING_DIR.glob("*.log")]
@router.get("/logs/files/{filename}")
@router.get("/api/logs/files/{filename}")
async def log_file(filename: str):
# Prevent path-traversal
file_path = (LOGGING_DIR / filename).resolve() # This .resolve() is important
if not file_path.is_relative_to(LOGGING_DIR):
raise HTTPException(status_code=400, detail="Invalid filename.")
if not file_path.is_file():
raise HTTPException(status_code=404, detail="File not found.")
return FileResponse(file_path, filename=file_path.name)

View File

@@ -1,12 +0,0 @@
from fastapi import APIRouter, Request
router = APIRouter()
# TODO: implement
@router.get("/sse")
async def sse(request: Request):
"""
Placeholder for future Server-Sent Events endpoint.
"""
pass

View File

@@ -52,11 +52,11 @@ async def experiment_stream(request: Request):
while True:
# Check if client closed the tab
if await request.is_disconnected():
logger.info("Client disconnected from experiment stream.")
logger.error("Client disconnected from experiment stream.")
break
try:
parts = await asyncio.wait_for(socket.recv_multipart(), timeout=1.0)
parts = await asyncio.wait_for(socket.recv_multipart(), timeout=10.0)
_, message = parts
yield f"data: {message.decode().strip()}\n\n"
except TimeoutError:
@@ -65,3 +65,30 @@ async def experiment_stream(request: Request):
socket.close()
return StreamingResponse(gen(), media_type="text/event-stream")
@router.get("/status_stream")
async def status_stream(request: Request):
context = Context.instance()
socket = context.socket(zmq.SUB)
socket.connect(settings.zmq_settings.internal_sub_address)
socket.subscribe(b"status")
async def gen():
try:
while True:
if await request.is_disconnected():
break
try:
# Shorter timeout since this is frequent
parts = await asyncio.wait_for(socket.recv_multipart(), timeout=0.5)
_, message = parts
yield f"data: {message.decode().strip()}\n\n"
except TimeoutError:
yield ": ping\n\n" # Keep the connection alive
continue
finally:
socket.close()
return StreamingResponse(gen(), media_type="text/event-stream")

View File

@@ -1,13 +1,11 @@
from fastapi.routing import APIRouter
from control_backend.api.v1.endpoints import logs, message, program, robot, sse, user_interact
from control_backend.api.v1.endpoints import logs, message, program, robot, user_interact
api_router = APIRouter()
api_router.include_router(message.router, tags=["Messages"])
api_router.include_router(sse.router, tags=["SSE"])
api_router.include_router(robot.router, prefix="/robot", tags=["Pings", "Commands"])
api_router.include_router(logs.router, tags=["Logs"])

View File

@@ -22,10 +22,22 @@ class AgentDirectory:
@staticmethod
def register(name: str, agent: "BaseAgent"):
"""
Registers an agent instance with a unique name.
:param name: The name of the agent.
:param agent: The :class:`BaseAgent` instance.
"""
_agent_directory[name] = agent
@staticmethod
def get(name: str) -> "BaseAgent | None":
"""
Retrieves a registered agent instance by name.
:param name: The name of the agent to retrieve.
:return: The :class:`BaseAgent` instance, or None if not found.
"""
return _agent_directory.get(name)

View File

@@ -50,6 +50,7 @@ class AgentSettings(BaseModel):
# agent names
bdi_core_name: str = "bdi_core_agent"
bdi_program_manager_name: str = "bdi_program_manager_agent"
visual_emotion_recognition_name: str = "visual_emotion_recognition_agent"
text_belief_extractor_name: str = "text_belief_extractor_agent"
vad_name: str = "vad_agent"
llm_name: str = "llm_agent"
@@ -77,6 +78,10 @@ class BehaviourSettings(BaseModel):
:ivar transcription_words_per_token: Estimated words per token for transcription timing.
:ivar transcription_token_buffer: Buffer for transcription tokens.
:ivar conversation_history_length_limit: The maximum amount of messages to extract beliefs from.
:ivar visual_emotion_recognition_window_duration_s: Duration in seconds over which to aggregate
emotions and update emotion beliefs.
:ivar visual_emotion_recognition_min_frames_per_face: Minimum number of frames per face required
to consider a face valid.
"""
# ATTENTION: When adding/removing settings, make sure to update the .env.example file
@@ -100,6 +105,9 @@ class BehaviourSettings(BaseModel):
# Text belief extractor settings
conversation_history_length_limit: int = 10
# Visual Emotion Recognition settings
visual_emotion_recognition_window_duration_s: int = 5
visual_emotion_recognition_min_frames_per_face: int = 3
class LLMSettings(BaseModel):
"""
@@ -117,6 +125,7 @@ class LLMSettings(BaseModel):
local_llm_url: str = "http://localhost:1234/v1/chat/completions"
local_llm_model: str = "gpt-oss"
api_key: str = ""
chat_temperature: float = 1.0
code_temperature: float = 0.3
n_parallel: int = 4
@@ -153,6 +162,20 @@ class SpeechModelSettings(BaseModel):
openai_model_name: str = "small.en"
class LoggingSettings(BaseModel):
"""
Configuration for logging.
:ivar logging_config_file: Path to the logging configuration file.
:ivar experiment_log_directory: Location of the experiment logs. Must match the logging config.
:ivar experiment_logger_name: Name of the experiment logger. Must match the logging config.
"""
logging_config_file: str = ".logging_config.yaml"
experiment_log_directory: str = "experiment_logs"
experiment_logger_name: str = "experiment"
class Settings(BaseSettings):
"""
Global application settings.
@@ -174,6 +197,8 @@ class Settings(BaseSettings):
ri_host: str = "localhost"
logging_settings: LoggingSettings = LoggingSettings()
zmq_settings: ZMQSettings = ZMQSettings()
agent_settings: AgentSettings = AgentSettings()

View File

@@ -1 +1,4 @@
from .dated_file_handler import DatedFileHandler as DatedFileHandler
from .optional_field_formatter import OptionalFieldFormatter as OptionalFieldFormatter
from .partial_filter import PartialFilter as PartialFilter
from .setup_logging import setup_logging as setup_logging

View File

@@ -0,0 +1,29 @@
from datetime import datetime
from logging import FileHandler
from pathlib import Path
class DatedFileHandler(FileHandler):
def __init__(self, file_prefix: str, **kwargs):
if not file_prefix:
raise ValueError("`file_prefix` argument cannot be empty.")
self._file_prefix = file_prefix
kwargs["filename"] = self._make_filename()
super().__init__(**kwargs)
def _make_filename(self) -> str:
filepath = Path(f"{self._file_prefix}-{datetime.now():%Y%m%d-%H%M%S}.log")
if not filepath.parent.is_dir():
filepath.parent.mkdir(parents=True, exist_ok=True)
return str(filepath)
def do_rollover(self):
self.acquire()
try:
if self.stream:
self.stream.close()
self.baseFilename = self._make_filename()
self.stream = self._open()
finally:
self.release()

View File

@@ -0,0 +1,67 @@
import logging
import re
class OptionalFieldFormatter(logging.Formatter):
"""
Logging formatter that supports optional fields marked by `?`.
Optional fields are denoted by placing a `?` after the field name inside
the parentheses, e.g., `%(role?)s`. If the field is not provided in the
log call's `extra` dict, it will use the default value from `defaults`
or `None` if no default is specified.
:param fmt: Format string with optional `%(name?)s` style fields.
:type fmt: str or None
:param datefmt: Date format string, passed to parent :class:`logging.Formatter`.
:type datefmt: str or None
:param style: Formatting style, must be '%'. Passed to parent.
:type style: str
:param defaults: Default values for optional fields when not provided.
:type defaults: dict or None
:example:
>>> formatter = OptionalFieldFormatter(
... fmt="%(asctime)s %(levelname)s %(role?)s %(message)s",
... defaults={"role": ""-""}
... )
>>> handler = logging.StreamHandler()
>>> handler.setFormatter(formatter)
>>> logger = logging.getLogger(__name__)
>>> logger.addHandler(handler)
>>>
>>> logger.chat("Hello there!", extra={"role": "USER"})
2025-01-15 10:30:00 CHAT USER Hello there!
>>>
>>> logger.info("A logging message")
2025-01-15 10:30:01 INFO - A logging message
.. note::
Only `%`-style formatting is supported. The `{` and `$` styles are not
compatible with this formatter.
.. seealso::
:class:`logging.Formatter` for base formatter documentation.
"""
# Match %(name?)s or %(name?)d etc.
OPTIONAL_PATTERN = re.compile(r"%\((\w+)\?\)([sdifFeEgGxXocrba%])")
def __init__(self, fmt=None, datefmt=None, style="%", defaults=None):
self.defaults = defaults or {}
self.optional_fields = set(self.OPTIONAL_PATTERN.findall(fmt or ""))
# Convert %(name?)s to %(name)s for standard formatting
normalized_fmt = self.OPTIONAL_PATTERN.sub(r"%(\1)\2", fmt or "")
super().__init__(normalized_fmt, datefmt, style)
def format(self, record):
for field, _ in self.optional_fields:
if not hasattr(record, field):
default = self.defaults.get(field, None)
setattr(record, field, default)
return super().format(record)

View File

@@ -0,0 +1,10 @@
import logging
class PartialFilter(logging.Filter):
"""
Class to filter any log records that have the "partial" attribute set to ``True``.
"""
def filter(self, record):
return getattr(record, "partial", False) is not True

View File

@@ -37,7 +37,7 @@ def add_logging_level(level_name: str, level_num: int, method_name: str | None =
setattr(logging, method_name, log_to_root)
def setup_logging(path: str = ".logging_config.yaml") -> None:
def setup_logging(path: str = settings.logging_settings.logging_config_file) -> None:
"""
Setup logging configuration of the CB. Tries to load the logging configuration from a file,
in which we specify custom loggers, formatters, handlers, etc.
@@ -65,7 +65,7 @@ def setup_logging(path: str = ".logging_config.yaml") -> None:
# Patch ZMQ PUBHandler to know about custom levels
if custom_levels:
for logger_name in ("control_backend",):
for logger_name in config.get("loggers", {}):
logger = logging.getLogger(logger_name)
for handler in logger.handlers:
if isinstance(handler, PUBHandler):

View File

@@ -16,4 +16,10 @@ class BeliefList(BaseModel):
class GoalList(BaseModel):
"""
Represents a list of goals, used for communicating multiple goals between agents.
:ivar goals: The list of goals.
"""
goals: list[BaseGoal]

View File

@@ -2,9 +2,22 @@ from pydantic import BaseModel
class ChatMessage(BaseModel):
"""
Represents a single message in a conversation.
:ivar role: The role of the speaker (e.g., 'user', 'assistant').
:ivar content: The text content of the message.
"""
role: str
content: str
class ChatHistory(BaseModel):
"""
Represents a sequence of chat messages, forming a conversation history.
:ivar messages: An ordered list of :class:`ChatMessage` objects.
"""
messages: list[ChatMessage]

View File

@@ -2,5 +2,13 @@ from pydantic import BaseModel
class ButtonPressedEvent(BaseModel):
"""
Represents a button press event from the UI.
:ivar type: The type of event (e.g., 'speech', 'gesture', 'override').
:ivar context: Additional data associated with the event (e.g., speech text, gesture name,
or ID).
"""
type: str
context: str

View File

@@ -7,7 +7,7 @@ class InternalMessage(BaseModel):
"""
Standard message envelope for communication between agents within the Control Backend.
:ivar to: The name of the destination agent.
:ivar to: The name(s) of the destination agent(s).
:ivar sender: The name of the sending agent.
:ivar body: The string payload (often a JSON-serialized model).
:ivar thread: An optional thread identifier/topic to categorize the message (e.g., 'beliefs').

View File

@@ -20,19 +20,23 @@ class ProgramElement(BaseModel):
class LogicalOperator(Enum):
"""
Logical operators for combining beliefs.
"""
AND = "AND"
OR = "OR"
type Belief = KeywordBelief | SemanticBelief | InferredBelief
type BasicBelief = KeywordBelief | SemanticBelief
type Belief = KeywordBelief | SemanticBelief | InferredBelief | EmotionBelief
type BasicBelief = KeywordBelief | SemanticBelief | EmotionBelief
class KeywordBelief(ProgramElement):
"""
Represents a belief that is set when the user spoken text contains a certain keyword.
Represents a belief that is activated when a specific keyword is detected in the user's speech.
:ivar keyword: The keyword on which this belief gets set.
:ivar keyword: The string to look for in the transcription.
"""
name: str = ""
@@ -41,9 +45,11 @@ class KeywordBelief(ProgramElement):
class SemanticBelief(ProgramElement):
"""
Represents a belief that is set by semantic LLM validation.
Represents a belief whose truth value is determined by an LLM analyzing the conversation
context.
:ivar description: Description of how to form the belief, used by the LLM.
:ivar description: A natural language description of what this belief represents,
used as a prompt for the LLM.
"""
description: str
@@ -51,13 +57,11 @@ class SemanticBelief(ProgramElement):
class InferredBelief(ProgramElement):
"""
Represents a belief that gets formed by combining two beliefs with a logical AND or OR.
Represents a belief derived from other beliefs using logical operators.
These beliefs can also be :class:`InferredBelief`, leading to arbitrarily deep nesting.
:ivar operator: The logical operator to apply.
:ivar left: The left part of the logical expression.
:ivar right: The right part of the logical expression.
:ivar operator: The :class:`LogicalOperator` (AND/OR) to apply.
:ivar left: The left operand (another belief).
:ivar right: The right operand (another belief).
"""
name: str = ""
@@ -65,8 +69,24 @@ class InferredBelief(ProgramElement):
left: Belief
right: Belief
class EmotionBelief(ProgramElement):
"""
Represents a belief that is set when a certain emotion is detected.
:ivar emotion: The emotion on which this belief gets set.
"""
name: str = ""
emotion: str
class Norm(ProgramElement):
"""
Base class for behavioral norms that guide the robot's interactions.
:ivar norm: The textual description of the norm.
:ivar critical: Whether this norm is considered critical and should be strictly enforced.
"""
name: str = ""
norm: str
critical: bool = False
@@ -74,10 +94,7 @@ class Norm(ProgramElement):
class BasicNorm(Norm):
"""
Represents a behavioral norm.
:ivar norm: The actual norm text describing the behavior.
:ivar critical: When true, this norm should absolutely not be violated (checked separately).
A simple behavioral norm that is always considered for activation when its phase is active.
"""
pass
@@ -85,9 +102,9 @@ class BasicNorm(Norm):
class ConditionalNorm(Norm):
"""
Represents a norm that is only active when a condition is met (i.e., a certain belief holds).
A behavioral norm that is only active when a specific condition (belief) is met.
:ivar condition: When to activate this norm.
:ivar condition: The :class:`Belief` that must hold for this norm to be active.
"""
condition: Belief
@@ -140,9 +157,9 @@ type Action = SpeechAction | GestureAction | LLMAction
class SpeechAction(ProgramElement):
"""
Represents the action of the robot speaking a literal text.
An action where the robot speaks a predefined literal text.
:ivar text: The text to speak.
:ivar text: The text content to be spoken.
"""
name: str = ""
@@ -151,11 +168,10 @@ class SpeechAction(ProgramElement):
class Gesture(BaseModel):
"""
Represents a gesture to be performed. Can be either a single gesture,
or a random gesture from a category (tag).
Defines a physical gesture for the robot to perform.
:ivar type: The type of the gesture, "tag" or "single".
:ivar name: The name of the single gesture or tag.
:ivar type: Whether to use a specific "single" gesture or a random one from a "tag" category.
:ivar name: The identifier for the gesture or tag.
"""
type: Literal["tag", "single"]
@@ -164,9 +180,9 @@ class Gesture(BaseModel):
class GestureAction(ProgramElement):
"""
Represents the action of the robot performing a gesture.
An action where the robot performs a physical gesture.
:ivar gesture: The gesture to perform.
:ivar gesture: The :class:`Gesture` definition.
"""
name: str = ""
@@ -175,10 +191,9 @@ class GestureAction(ProgramElement):
class LLMAction(ProgramElement):
"""
Represents the action of letting an LLM generate a reply based on its chat history
and an additional goal added in the prompt.
An action that triggers an LLM-generated conversational response.
:ivar goal: The extra (temporary) goal to add to the LLM.
:ivar goal: A temporary conversational goal to guide the LLM's response generation.
"""
name: str = ""
@@ -187,10 +202,10 @@ class LLMAction(ProgramElement):
class Trigger(ProgramElement):
"""
Represents a belief-based trigger. When a belief is set, the corresponding plan is executed.
Defines a reactive behavior: when the condition (belief) is met, the plan is executed.
:ivar condition: When to activate the trigger.
:ivar plan: The plan to execute.
:ivar condition: The :class:`Belief` that triggers this behavior.
:ivar plan: The :class:`Plan` to execute upon activation.
"""
condition: Belief
@@ -199,11 +214,11 @@ class Trigger(ProgramElement):
class Phase(ProgramElement):
"""
A distinct phase within a program, containing norms, goals, and triggers.
A logical stage in the interaction program, grouping norms, goals, and triggers.
:ivar norms: List of norms active in this phase.
:ivar goals: List of goals to pursue in this phase.
:ivar triggers: List of triggers that define transitions out of this phase.
:ivar norms: List of norms active during this phase.
:ivar goals: List of goals the robot pursues in this phase.
:ivar triggers: List of reactive behaviors defined for this phase.
"""
name: str = ""
@@ -214,9 +229,15 @@ class Phase(ProgramElement):
class Program(BaseModel):
"""
Represents a complete interaction program, consisting of a sequence or set of phases.
The top-level container for a complete robot behavior definition.
:ivar phases: The list of phases that make up the program.
:ivar phases: An ordered list of :class:`Phase` objects defining the interaction flow.
"""
phases: list[Phase]
if __name__ == "__main__":
input = input("Enter program JSON: ")
program = Program.model_validate_json(input)
print(program)

View File

@@ -40,7 +40,7 @@ async def test_normal_setup(per_transcription_agent):
per_vad_agent = VADAgent("tcp://localhost:12345", False)
per_vad_agent._streaming_loop = AsyncMock()
async def swallow_background_task(coro):
def swallow_background_task(coro):
coro.close()
per_vad_agent.add_behavior = swallow_background_task
@@ -106,7 +106,7 @@ async def test_out_socket_creation_failure(zmq_context):
per_vad_agent._streaming_loop = AsyncMock()
per_vad_agent._connect_audio_out_socket = MagicMock(return_value=None)
async def swallow_background_task(coro):
def swallow_background_task(coro):
coro.close()
per_vad_agent.add_behavior = swallow_background_task
@@ -126,7 +126,7 @@ async def test_stop(zmq_context, per_transcription_agent):
per_vad_agent._reset_stream = AsyncMock()
per_vad_agent._streaming_loop = AsyncMock()
async def swallow_background_task(coro):
def swallow_background_task(coro):
coro.close()
per_vad_agent.add_behavior = swallow_background_task
@@ -150,6 +150,7 @@ async def test_application_startup_complete(zmq_context):
vad_agent._running = True
vad_agent._reset_stream = AsyncMock()
vad_agent.program_sub_socket = AsyncMock()
vad_agent.program_sub_socket.close = MagicMock()
vad_agent.program_sub_socket.recv_multipart.side_effect = [
(PROGRAM_STATUS, ProgramStatus.RUNNING.value),
]

View File

@@ -61,8 +61,52 @@ async def test_llm_processing_success(mock_httpx_client, mock_settings):
thread="prompt_message", # REQUIRED: thread must match handle_message logic
)
agent._process_bdi_message = AsyncMock()
await agent.handle_message(msg)
agent._process_bdi_message.assert_called()
@pytest.mark.asyncio
async def test_process_bdi_message_success(mock_httpx_client, mock_settings):
# Setup the mock response for the stream
mock_response = MagicMock()
mock_response.raise_for_status = MagicMock()
# Simulate stream lines
lines = [
b'data: {"choices": [{"delta": {"content": "Hello"}}]}',
b'data: {"choices": [{"delta": {"content": " world"}}]}',
b'data: {"choices": [{"delta": {"content": "."}}]}',
b"data: [DONE]",
]
async def aiter_lines_gen():
for line in lines:
yield line.decode()
mock_response.aiter_lines.side_effect = aiter_lines_gen
mock_stream_context = MagicMock()
mock_stream_context.__aenter__ = AsyncMock(return_value=mock_response)
mock_stream_context.__aexit__ = AsyncMock(return_value=None)
# Configure the client
mock_httpx_client.stream = MagicMock(return_value=mock_stream_context)
# Setup Agent
agent = LLMAgent("llm_agent")
agent.send = AsyncMock() # Mock the send method to verify replies
mock_logger = MagicMock()
agent.logger = mock_logger
# Simulate receiving a message from BDI
prompt = LLMPromptMessage(text="Hi", norms=[], goals=[])
await agent._process_bdi_message(prompt)
# Verification
# "Hello world." constitutes one sentence/chunk based on punctuation split
# The agent should call send once with the full sentence, PLUS once more for full reply
@@ -79,28 +123,16 @@ async def test_llm_processing_errors(mock_httpx_client, mock_settings):
agent = LLMAgent("llm_agent")
agent.send = AsyncMock()
prompt = LLMPromptMessage(text="Hi", norms=[], goals=[])
msg = InternalMessage(
to="llm",
sender=mock_settings.agent_settings.bdi_core_name,
body=prompt.model_dump_json(),
thread="prompt_message",
)
# HTTP Error: stream method RAISES exception immediately
mock_httpx_client.stream = MagicMock(side_effect=httpx.HTTPError("Fail"))
await agent.handle_message(msg)
await agent._process_bdi_message(prompt)
# Check that error message was sent
assert agent.send.called
assert "LLM service unavailable." in agent.send.call_args_list[0][0][0].body
# General Exception
agent.send.reset_mock()
mock_httpx_client.stream = MagicMock(side_effect=Exception("Boom"))
await agent.handle_message(msg)
assert "Error processing the request." in agent.send.call_args_list[0][0][0].body
@pytest.mark.asyncio
async def test_llm_json_error(mock_httpx_client, mock_settings):
@@ -125,13 +157,7 @@ async def test_llm_json_error(mock_httpx_client, mock_settings):
agent.logger = MagicMock()
prompt = LLMPromptMessage(text="Hi", norms=[], goals=[])
msg = InternalMessage(
to="llm",
sender=mock_settings.agent_settings.bdi_core_name,
body=prompt.model_dump_json(),
thread="prompt_message",
)
await agent.handle_message(msg)
await agent._process_bdi_message(prompt)
agent.logger.error.assert_called() # Should log JSONDecodeError

View File

@@ -63,7 +63,7 @@ async def test_send_to_bdi_belief(agent):
"""Verify belief update format."""
context_str = "some_goal"
await agent._send_to_bdi_belief(context_str)
await agent._send_to_bdi_belief(context_str, "goal")
assert agent.send.await_count == 1
sent_msg = agent.send.call_args.args[0]
@@ -115,7 +115,7 @@ async def test_receive_loop_routing_success(agent):
agent._send_to_gesture_agent.assert_awaited_once_with("Hello Gesture")
# Override (since we mapped it to a goal)
agent._send_to_bdi_belief.assert_awaited_once_with("some_goal_slug")
agent._send_to_bdi_belief.assert_awaited_once_with("some_goal_slug", "goal")
assert agent._send_to_speech_agent.await_count == 1
assert agent._send_to_gesture_agent.await_count == 1

View File

@@ -1,7 +1,7 @@
from unittest.mock import patch
from unittest.mock import MagicMock, patch
import pytest
from fastapi import FastAPI
from fastapi import FastAPI, HTTPException
from fastapi.testclient import TestClient
from starlette.responses import StreamingResponse
@@ -61,3 +61,67 @@ async def test_log_stream_endpoint_lines(client):
# Optional: assert subscribe/connect were called
assert dummy_socket.subscribed # at least some log levels subscribed
assert dummy_socket.connected # connect was called
@patch("control_backend.api.v1.endpoints.logs.LOGGING_DIR")
def test_files_endpoint(LOGGING_DIR, client):
file_1, file_2 = MagicMock(), MagicMock()
file_1.name = "file_1"
file_2.name = "file_2"
LOGGING_DIR.glob.return_value = [file_1, file_2]
result = client.get("/api/logs/files")
assert result.status_code == 200
assert result.json() == ["file_1", "file_2"]
@patch("control_backend.api.v1.endpoints.logs.FileResponse")
@patch("control_backend.api.v1.endpoints.logs.LOGGING_DIR")
def test_log_file_endpoint_success(LOGGING_DIR, MockFileResponse, client):
mock_file_path = MagicMock()
mock_file_path.is_relative_to.return_value = True
mock_file_path.is_file.return_value = True
mock_file_path.name = "test.log"
LOGGING_DIR.__truediv__ = MagicMock(return_value=mock_file_path)
mock_file_path.resolve.return_value = mock_file_path
MockFileResponse.return_value = MagicMock()
result = client.get("/api/logs/files/test.log")
assert result.status_code == 200
MockFileResponse.assert_called_once_with(mock_file_path, filename="test.log")
@pytest.mark.asyncio
@patch("control_backend.api.v1.endpoints.logs.LOGGING_DIR")
async def test_log_file_endpoint_path_traversal(LOGGING_DIR):
from control_backend.api.v1.endpoints.logs import log_file
mock_file_path = MagicMock()
mock_file_path.is_relative_to.return_value = False
LOGGING_DIR.__truediv__ = MagicMock(return_value=mock_file_path)
mock_file_path.resolve.return_value = mock_file_path
with pytest.raises(HTTPException) as exc_info:
await log_file("../secret.txt")
assert exc_info.value.status_code == 400
assert exc_info.value.detail == "Invalid filename."
@patch("control_backend.api.v1.endpoints.logs.LOGGING_DIR")
def test_log_file_endpoint_file_not_found(LOGGING_DIR, client):
mock_file_path = MagicMock()
mock_file_path.is_relative_to.return_value = True
mock_file_path.is_file.return_value = False
LOGGING_DIR.__truediv__ = MagicMock(return_value=mock_file_path)
mock_file_path.resolve.return_value = mock_file_path
result = client.get("/api/logs/files/nonexistent.log")
assert result.status_code == 404
assert result.json()["detail"] == "File not found."

View File

@@ -11,6 +11,5 @@ def test_router_includes_expected_paths():
# Ensure at least one route under each prefix exists
assert any(p.startswith("/robot") for p in paths)
assert any(p.startswith("/message") for p in paths)
assert any(p.startswith("/sse") for p in paths)
assert any(p.startswith("/logs") for p in paths)
assert any(p.startswith("/program") for p in paths)

View File

@@ -1,24 +0,0 @@
import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
from control_backend.api.v1.endpoints import sse
@pytest.fixture
def app():
app = FastAPI()
app.include_router(sse.router)
return app
@pytest.fixture
def client(app):
return TestClient(app)
def test_sse_route_exists(client):
"""Minimal smoke test to ensure /sse route exists and responds."""
response = client.get("/sse")
# Since implementation is not done, we only assert it doesn't crash
assert response.status_code == 200

View File

@@ -0,0 +1,45 @@
from unittest.mock import MagicMock, patch
import pytest
from control_backend.logging.dated_file_handler import DatedFileHandler
@patch("control_backend.logging.dated_file_handler.DatedFileHandler._open")
def test_reset(open_):
stream = MagicMock()
open_.return_value = stream
# A file should be opened when the logger is created
handler = DatedFileHandler(file_prefix="anything")
assert open_.call_count == 1
# Upon reset, the current file should be closed, and a new one should be opened
handler.do_rollover()
assert stream.close.call_count == 1
assert open_.call_count == 2
@patch("control_backend.logging.dated_file_handler.Path")
@patch("control_backend.logging.dated_file_handler.DatedFileHandler._open")
def test_creates_dir(open_, Path_):
stream = MagicMock()
open_.return_value = stream
test_path = MagicMock()
test_path.parent.is_dir.return_value = False
Path_.return_value = test_path
DatedFileHandler(file_prefix="anything")
# The directory should've been created
test_path.parent.mkdir.assert_called_once()
@patch("control_backend.logging.dated_file_handler.DatedFileHandler._open")
def test_invalid_constructor(_):
with pytest.raises(ValueError):
DatedFileHandler(file_prefix=None)
with pytest.raises(ValueError):
DatedFileHandler(file_prefix="")

View File

@@ -0,0 +1,218 @@
import logging
import pytest
from control_backend.logging.optional_field_formatter import OptionalFieldFormatter
@pytest.fixture
def logger():
"""Create a fresh logger for each test."""
logger = logging.getLogger(f"test_{id(object())}")
logger.setLevel(logging.DEBUG)
logger.handlers = []
return logger
@pytest.fixture
def log_output(logger):
"""Capture log output and return a function to get it."""
class ListHandler(logging.Handler):
def __init__(self):
super().__init__()
self.records = []
def emit(self, record):
self.records.append(self.format(record))
handler = ListHandler()
logger.addHandler(handler)
def get_output():
return handler.records
return get_output
def test_optional_field_present(logger, log_output):
"""Optional field should appear when provided in extra."""
formatter = OptionalFieldFormatter("%(levelname)s - %(role?)s - %(message)s")
logger.handlers[0].setFormatter(formatter)
logger.info("test message", extra={"role": "user"})
assert log_output() == ["INFO - user - test message"]
def test_optional_field_missing_no_default(logger, log_output):
"""Missing optional field with no default should be None."""
formatter = OptionalFieldFormatter("%(levelname)s - %(role?)s - %(message)s")
logger.handlers[0].setFormatter(formatter)
logger.info("test message")
assert log_output() == ["INFO - None - test message"]
def test_optional_field_missing_with_default(logger, log_output):
"""Missing optional field should use provided default."""
formatter = OptionalFieldFormatter(
"%(levelname)s - %(role?)s - %(message)s", defaults={"role": "assistant"}
)
logger.handlers[0].setFormatter(formatter)
logger.info("test message")
assert log_output() == ["INFO - assistant - test message"]
def test_optional_field_overrides_default(logger, log_output):
"""Provided extra value should override default."""
formatter = OptionalFieldFormatter(
"%(levelname)s - %(role?)s - %(message)s", defaults={"role": "assistant"}
)
logger.handlers[0].setFormatter(formatter)
logger.info("test message", extra={"role": "user"})
assert log_output() == ["INFO - user - test message"]
def test_multiple_optional_fields(logger, log_output):
"""Multiple optional fields should work independently."""
formatter = OptionalFieldFormatter(
"%(levelname)s - %(role?)s - %(request_id?)s - %(message)s", defaults={"role": "assistant"}
)
logger.handlers[0].setFormatter(formatter)
logger.info("test", extra={"request_id": "123"})
assert log_output() == ["INFO - assistant - 123 - test"]
def test_mixed_optional_and_required_fields(logger, log_output):
"""Standard fields should work alongside optional fields."""
formatter = OptionalFieldFormatter("%(levelname)s %(name)s %(role?)s %(message)s")
logger.handlers[0].setFormatter(formatter)
logger.info("test", extra={"role": "user"})
output = log_output()[0]
assert "INFO" in output
assert "user" in output
assert "test" in output
def test_no_optional_fields(logger, log_output):
"""Formatter should work normally with no optional fields."""
formatter = OptionalFieldFormatter("%(levelname)s %(message)s")
logger.handlers[0].setFormatter(formatter)
logger.info("test message")
assert log_output() == ["INFO test message"]
def test_integer_format_specifier(logger, log_output):
"""Optional fields with %d specifier should work."""
formatter = OptionalFieldFormatter(
"%(levelname)s %(count?)d %(message)s", defaults={"count": 0}
)
logger.handlers[0].setFormatter(formatter)
logger.info("test", extra={"count": 42})
assert log_output() == ["INFO 42 test"]
def test_float_format_specifier(logger, log_output):
"""Optional fields with %f specifier should work."""
formatter = OptionalFieldFormatter(
"%(levelname)s %(duration?)f %(message)s", defaults={"duration": 0.0}
)
logger.handlers[0].setFormatter(formatter)
logger.info("test", extra={"duration": 1.5})
assert "1.5" in log_output()[0]
def test_empty_string_default(logger, log_output):
"""Empty string default should work."""
formatter = OptionalFieldFormatter("%(levelname)s %(role?)s %(message)s", defaults={"role": ""})
logger.handlers[0].setFormatter(formatter)
logger.info("test")
assert log_output() == ["INFO test"]
def test_none_format_string():
"""None format string should not raise."""
formatter = OptionalFieldFormatter(fmt=None)
assert formatter.optional_fields == set()
def test_optional_fields_parsed_correctly():
"""Check that optional fields are correctly identified."""
formatter = OptionalFieldFormatter("%(asctime)s %(role?)s %(level?)d %(name)s")
assert formatter.optional_fields == {("role", "s"), ("level", "d")}
def test_format_string_normalized():
"""Check that ? is removed from format string."""
formatter = OptionalFieldFormatter("%(role?)s %(message)s")
assert "?" not in formatter._style._fmt
assert "%(role)s" in formatter._style._fmt
def test_field_with_underscore(logger, log_output):
"""Field names with underscores should work."""
formatter = OptionalFieldFormatter("%(levelname)s %(user_id?)s %(message)s")
logger.handlers[0].setFormatter(formatter)
logger.info("test", extra={"user_id": "abc123"})
assert log_output() == ["INFO abc123 test"]
def test_field_with_numbers(logger, log_output):
"""Field names with numbers should work."""
formatter = OptionalFieldFormatter("%(levelname)s %(field2?)s %(message)s")
logger.handlers[0].setFormatter(formatter)
logger.info("test", extra={"field2": "value"})
assert log_output() == ["INFO value test"]
def test_multiple_log_calls(logger, log_output):
"""Formatter should work correctly across multiple log calls."""
formatter = OptionalFieldFormatter(
"%(levelname)s %(role?)s %(message)s", defaults={"role": "other"}
)
logger.handlers[0].setFormatter(formatter)
logger.info("first", extra={"role": "assistant"})
logger.info("second")
logger.info("third", extra={"role": "user"})
assert log_output() == [
"INFO assistant first",
"INFO other second",
"INFO user third",
]
def test_default_not_mutated(logger, log_output):
"""Original defaults dict should not be mutated."""
defaults = {"role": "other"}
formatter = OptionalFieldFormatter("%(levelname)s %(role?)s %(message)s", defaults=defaults)
logger.handlers[0].setFormatter(formatter)
logger.info("test")
assert defaults == {"role": "other"}

View File

@@ -0,0 +1,83 @@
import logging
import pytest
from control_backend.logging import PartialFilter
@pytest.fixture
def logger():
"""Create a fresh logger for each test."""
logger = logging.getLogger(f"test_{id(object())}")
logger.setLevel(logging.DEBUG)
logger.handlers = []
return logger
@pytest.fixture
def log_output(logger):
"""Capture log output and return a function to get it."""
class ListHandler(logging.Handler):
def __init__(self):
super().__init__()
self.records = []
def emit(self, record):
self.records.append(self.format(record))
handler = ListHandler()
handler.addFilter(PartialFilter())
handler.setFormatter(logging.Formatter("%(message)s"))
logger.addHandler(handler)
return lambda: handler.records
def test_no_partial_attribute(logger, log_output):
"""Records without partial attribute should pass through."""
logger.info("normal message")
assert log_output() == ["normal message"]
def test_partial_true_filtered(logger, log_output):
"""Records with partial=True should be filtered out."""
logger.info("partial message", extra={"partial": True})
assert log_output() == []
def test_partial_false_passes(logger, log_output):
"""Records with partial=False should pass through."""
logger.info("complete message", extra={"partial": False})
assert log_output() == ["complete message"]
def test_partial_none_passes(logger, log_output):
"""Records with partial=None should pass through."""
logger.info("message", extra={"partial": None})
assert log_output() == ["message"]
def test_partial_truthy_value_passes(logger, log_output):
"""
Records with truthy but non-True partial should pass through, that is, only when it's exactly
``True`` should it pass.
"""
logger.info("message", extra={"partial": "yes"})
assert log_output() == ["message"]
def test_multiple_records_mixed(logger, log_output):
"""Filter should handle mixed records correctly."""
logger.info("first")
logger.info("second", extra={"partial": True})
logger.info("third", extra={"partial": False})
logger.info("fourth", extra={"partial": True})
logger.info("fifth")
assert log_output() == ["first", "third", "fifth"]

885
uv.lock generated

File diff suppressed because it is too large Load Diff