Compare commits

...

23 Commits

Author SHA1 Message Date
Twirre Meulenbelt
a87ac35201 docs: add docstrings to dated file handler
ref: N25B-401
2026-01-22 11:34:51 +01:00
Twirre Meulenbelt
3fed2f95b0 Merge remote-tracking branch 'refs/remotes/origin/feat/visual-emotion-recognition' into feat/add-experiment-logs 2026-01-20 10:11:43 +01:00
Twirre Meulenbelt
ae39298f9c Merge branch 'feat/experiment-logging' into feat/add-experiment-logs
# Conflicts:
#	.gitignore
2026-01-20 09:37:34 +01:00
Storm
424294b0a3 Merged feat/longer-pauses-possible into feat/visual-emotion-recognition 2026-01-19 18:35:07 +01:00
Pim Hutting
bc0947fac1 chore: added a dot 2026-01-19 18:26:15 +01:00
Storm
cd80cdf93b Merge branch 'feat/longer-pauses-possible' into feat/visual-emotion-recognition 2026-01-19 18:24:31 +01:00
Storm
985327de70 docs: updated docstrings and fixed styling
ref: N25B-393
2026-01-19 12:52:00 +01:00
Twirre Meulenbelt
58881b5914 test: add test cases
ref: N25B-401
2026-01-19 12:47:59 +01:00
Storm
302c50934e feat: implemented emotion recognition functionality in AgentSpeak
ref: N25B-393
2026-01-19 12:10:58 +01:00
Storm
f9c69cafb3 Merge branch 'feat/reset-experiment-and-phase' into feat/visual-emotion-recognition 2026-01-19 11:45:31 +01:00
Twirre Meulenbelt
ba79d09c5d feat: log download endpoints
ref: N25B-401
2026-01-16 16:32:51 +01:00
Storm
1b0b72d63a chore: fixed broken uv.lock file 2026-01-16 15:10:55 +01:00
Storm
0941b26703 refactor: updated how changes are passed to bdi_core_agent after merge
ref: N25B-393
2026-01-16 15:05:13 +01:00
Storm
48ae0c7a12 Merge remote-tracking branch 'origin/feat/reset-experiment-and-phase' into feat/visual-emotion-recognition 2026-01-16 14:45:16 +01:00
Storm
a09d8b3d9a chore: small changes 2026-01-16 14:40:59 +01:00
Storm
ac20048f02 Merge branch 'dev' into feat/visual-emotion-recognition 2026-01-16 14:16:28 +01:00
Storm
05804c158d feat: fully implemented visual emotion recognition agent in pipeline
ref: N25B-393
2026-01-16 13:26:53 +01:00
Storm
0771b0d607 feat: implemented visual emotion recogntion agent
ref: N25B-393
2026-01-16 09:50:59 +01:00
Twirre Meulenbelt
4cda4e5e70 feat: experiment log stream, to file and UI
Adds a few new logging utility classes. One to save to files with a date, one to support optional fields in formats, last to filter partial log messages.

ref: N25B-401
2026-01-15 17:07:49 +01:00
Luijkx,S.O.H. (Storm)
a9df9208bc Merge branch 'feat/multiple-receivers' into 'dev'
feat: able to send to multiple receivers

See merge request ics/sp/2025/n25b/pepperplus-cb!42
2026-01-15 09:26:12 +00:00
Twirre Meulenbelt
d7d697b293 docs: update to docstring
ref: N25B-441
2026-01-13 17:09:26 +01:00
Twirre Meulenbelt
9a55067a13 fix: set sender for internal messages
ref: N25B-441
2026-01-13 17:07:17 +01:00
Storm
1c88ae6078 feat: visual emotion recognition agent
ref: N25B-393
2026-01-13 12:41:18 +01:00
22 changed files with 1727 additions and 47 deletions

2
.gitignore vendored
View File

@@ -224,6 +224,8 @@ docs/*
# Generated files
agentspeak.asl
experiment-*.log

View File

@@ -1,36 +1,57 @@
version: 1
custom_levels:
OBSERVATION: 25
ACTION: 26
OBSERVATION: 24
ACTION: 25
CHAT: 26
LLM: 9
formatters:
# Console output
colored:
(): "colorlog.ColoredFormatter"
class: colorlog.ColoredFormatter
format: "{log_color}{asctime}.{msecs:03.0f} | {levelname:11} | {name:70} | {message}"
style: "{"
datefmt: "%H:%M:%S"
# User-facing UI (structured JSON)
json_experiment:
(): "pythonjsonlogger.jsonlogger.JsonFormatter"
json:
class: pythonjsonlogger.jsonlogger.JsonFormatter
format: "{name} {levelname} {levelno} {message} {created} {relativeCreated}"
style: "{"
# Experiment stream for console and file output, with optional `role` field
experiment:
class: control_backend.logging.OptionalFieldFormatter
format: "%(asctime)s %(levelname)s %(role?)s %(message)s"
defaults:
role: "-"
filters:
# Filter out any log records that have the extra field "partial" set to True, indicating that they
# will be replaced later.
partial:
(): control_backend.logging.PartialFilter
handlers:
console:
class: logging.StreamHandler
level: DEBUG
formatter: colored
filters: [partial]
stream: ext://sys.stdout
ui:
class: zmq.log.handlers.PUBHandler
level: LLM
formatter: json_experiment
formatter: json
file:
class: control_backend.logging.DatedFileHandler
formatter: experiment
filters: [partial]
# Directory must match config.logging_settings.experiment_log_directory
file_prefix: experiment_logs/experiment
# Level of external libraries
# Level for external libraries
root:
level: WARN
handlers: [console]
@@ -39,3 +60,6 @@ loggers:
control_backend:
level: LLM
handlers: [ui]
experiment: # This name must match config.logging_settings.experiment_logger_name
level: DEBUG
handlers: [ui, file]

View File

@@ -7,6 +7,7 @@ requires-python = ">=3.13"
dependencies = [
"agentspeak>=0.2.2",
"colorlog>=6.10.1",
"deepface>=0.0.96",
"fastapi[all]>=0.115.6",
"mlx-whisper>=0.4.3 ; sys_platform == 'darwin'",
"numpy>=2.3.3",
@@ -21,6 +22,7 @@ dependencies = [
"silero-vad>=6.0.0",
"sphinx>=7.3.7",
"sphinx-rtd-theme>=3.0.2",
"tf-keras>=2.20.1",
"torch>=2.8.0",
"uvicorn>=0.37.0",
]

View File

@@ -1,9 +1,10 @@
import logging
from abc import ABC
from control_backend.core.agent_system import BaseAgent as CoreBaseAgent
class BaseAgent(CoreBaseAgent):
class BaseAgent(CoreBaseAgent, ABC):
"""
The primary base class for all implementation agents.

View File

@@ -22,6 +22,7 @@ from control_backend.schemas.program import (
BaseGoal,
BasicNorm,
ConditionalNorm,
EmotionBelief,
GestureAction,
Goal,
InferredBelief,
@@ -459,6 +460,10 @@ class AgentSpeakGenerator:
@_astify.register
def _(self, sb: SemanticBelief) -> AstExpression:
return AstLiteral(self.slugify(sb))
@_astify.register
def _(self, eb: EmotionBelief) -> AstExpression:
return AstLiteral("emotion_detected", [AstAtom(eb.emotion)])
@_astify.register
def _(self, ib: InferredBelief) -> AstExpression:

View File

@@ -8,6 +8,9 @@ from zmq.asyncio import Context
from control_backend.agents import BaseAgent
from control_backend.agents.actuation.robot_gesture_agent import RobotGestureAgent
from control_backend.agents.perception.visual_emotion_recognition_agent.visual_emotion_recognition_agent import ( # noqa
VisualEmotionRecognitionAgent,
)
from control_backend.core.config import settings
from control_backend.schemas.internal_message import InternalMessage
from control_backend.schemas.ri_message import PauseCommand
@@ -209,6 +212,13 @@ class RICommunicationAgent(BaseAgent):
case "audio":
vad_agent = VADAgent(audio_in_address=addr, audio_in_bind=bind)
await vad_agent.start()
case "video":
visual_emotion_agent = VisualEmotionRecognitionAgent(
settings.agent_settings.visual_emotion_recognition_name,
socket_address=addr,
bind=bind,
)
await visual_emotion_agent.start()
case _:
self.logger.warning("Unhandled negotiation id: %s", id)

View File

@@ -0,0 +1,166 @@
import json
import time
from collections import Counter, defaultdict
import cv2
import numpy as np
import zmq
import zmq.asyncio as azmq
from pydantic_core import ValidationError
from control_backend.agents import BaseAgent
from control_backend.agents.perception.visual_emotion_recognition_agent.visual_emotion_recognizer import ( # noqa
DeepFaceEmotionRecognizer,
)
from control_backend.core.agent_system import InternalMessage
from control_backend.core.config import settings
from control_backend.schemas.belief_message import Belief
class VisualEmotionRecognitionAgent(BaseAgent):
def __init__(
self,
name: str,
socket_address: str,
bind: bool = False,
timeout_ms: int = 1000,
window_duration: int = settings.behaviour_settings.visual_emotion_recognition_window_duration_s, # noqa
min_frames_required: int = settings.behaviour_settings.visual_emotion_recognition_min_frames_per_face, # noqa
):
"""
Initialize the Visual Emotion Recognition Agent.
:param name: Name of the agent
:param socket_address: Address of the socket to connect or bind to
:param bind: Whether to bind to the socket address (True) or connect (False)
:param timeout_ms: Timeout for socket receive operations in milliseconds
:param window_duration: Duration in seconds over which to aggregate emotions
:param min_frames_required: Minimum number of frames per face required to consider a face
valid
"""
super().__init__(name)
self.socket_address = socket_address
self.socket_bind = bind
self.timeout_ms = timeout_ms
self.window_duration = window_duration
self.min_frames_required = min_frames_required
async def setup(self):
"""
Initialize the agent resources.
1. Initializes the :class:`VisualEmotionRecognizer`.
2. Connects to the video input ZMQ socket.
3. Starts the background emotion recognition loop.
"""
self.logger.info("Setting up %s.", self.name)
self.emotion_recognizer = DeepFaceEmotionRecognizer()
self.video_in_socket = azmq.Context.instance().socket(zmq.SUB)
if self.socket_bind:
self.video_in_socket.bind(self.socket_address)
else:
self.video_in_socket.connect(self.socket_address)
self.video_in_socket.setsockopt_string(zmq.SUBSCRIBE, "")
self.video_in_socket.setsockopt(zmq.RCVTIMEO, self.timeout_ms)
self.video_in_socket.setsockopt(zmq.CONFLATE, 1)
self.add_behavior(self.emotion_update_loop())
async def emotion_update_loop(self):
"""
Background loop to receive video frames, recognize emotions, and update beliefs.
1. Receives video frames from the ZMQ socket.
2. Uses the :class:`VisualEmotionRecognizer` to detect emotions.
3. Aggregates emotions over a time window.
4. Sends updates to the BDI Core Agent about detected emotions.
"""
# Next time to process the window and update emotions
next_window_time = time.time() + self.window_duration
# Tracks counts of detected emotions per face index
face_stats = defaultdict(Counter)
prev_dominant_emotions = set()
while self._running:
try:
frame_bytes = await self.video_in_socket.recv()
# Convert bytes to a numpy buffer
nparr = np.frombuffer(frame_bytes, np.uint8)
# Decode image into the generic Numpy Array DeepFace expects
frame_image = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
if frame_image is None:
# Could not decode image, skip this frame
continue
# Get the dominant emotion from each face
current_emotions = self.emotion_recognizer.sorted_dominant_emotions(frame_image)
# Update emotion counts for each detected face
for i, emotion in enumerate(current_emotions):
face_stats[i][emotion] += 1
# If window duration has passed, process the collected stats
if time.time() >= next_window_time:
window_dominant_emotions = set()
# Determine dominant emotion for each face in the window
for _, counter in face_stats.items():
total_detections = sum(counter.values())
if total_detections >= self.min_frames_required:
dominant_emotion = counter.most_common(1)[0][0]
window_dominant_emotions.add(dominant_emotion)
await self.update_emotions(prev_dominant_emotions, window_dominant_emotions)
prev_dominant_emotions = window_dominant_emotions
face_stats.clear()
next_window_time = time.time() + self.window_duration
except zmq.Again:
self.logger.warning("No video frame received within timeout.")
async def update_emotions(self, prev_emotions: set[str], emotions: set[str]):
"""
Compare emotions from previous window and current emotions,
send updates to BDI Core Agent.
"""
emotions_to_remove = prev_emotions - emotions
emotions_to_add = emotions - prev_emotions
if not emotions_to_add and not emotions_to_remove:
return
emotion_beliefs_remove = []
for emotion in emotions_to_remove:
self.logger.info(f"Emotion '{emotion}' has disappeared.")
try:
emotion_beliefs_remove.append(
Belief(name="emotion_detected", arguments=[emotion], remove=True)
)
except ValidationError:
self.logger.warning("Invalid belief for emotion removal: %s", emotion)
emotion_beliefs_add = []
for emotion in emotions_to_add:
self.logger.info(f"New emotion detected: '{emotion}'")
try:
emotion_beliefs_add.append(Belief(name="emotion_detected", arguments=[emotion]))
except ValidationError:
self.logger.warning("Invalid belief for new emotion: %s", emotion)
beliefs_list_add = [b.model_dump() for b in emotion_beliefs_add]
beliefs_list_remove = [b.model_dump() for b in emotion_beliefs_remove]
payload = {"create": beliefs_list_add, "delete": beliefs_list_remove}
message = InternalMessage(
to=settings.agent_settings.bdi_core_name,
sender=self.name,
body=json.dumps(payload),
thread="beliefs",
)
await self.send(message)

View File

@@ -0,0 +1,55 @@
import abc
import numpy as np
from deepface import DeepFace
class VisualEmotionRecognizer(abc.ABC):
@abc.abstractmethod
def load_model(self):
"""Load the visual emotion recognition model into memory."""
pass
@abc.abstractmethod
def sorted_dominant_emotions(self, image) -> list[str]:
"""
Recognize dominant emotions from faces in the given image.
Emotions can be one of ['angry', 'disgust', 'fear', 'happy', 'sad', 'surprise', 'neutral'].
To minimize false positives, consider filtering faces with low confidence.
:param image: The input image for emotion recognition.
:return: List of dominant emotion detected for each face in the image,
sorted per face.
"""
pass
class DeepFaceEmotionRecognizer(VisualEmotionRecognizer):
"""
DeepFace-based implementation of VisualEmotionRecognizer.
DeepFape has proven to be quite a pessimistic model, so expect sad, fear and neutral
emotions to be over-represented.
"""
def __init__(self):
self.load_model()
def load_model(self):
print("Loading Deepface Emotion Model...")
dummy_img = np.zeros((224, 224, 3), dtype=np.uint8)
# analyze does not take a model as an argument, calling it once on a dummy image to load
# the model
DeepFace.analyze(dummy_img, actions=['emotion'], enforce_detection=False)
print("Deepface Emotion Model loaded.")
def sorted_dominant_emotions(self, image) -> list[str]:
analysis = DeepFace.analyze(image,
actions=['emotion'],
enforce_detection=False
)
# Sort faces by x coordinate to maintain left-to-right order
analysis.sort(key=lambda face: face['region']['x'])
analysis = [face for face in analysis if face['face_confidence'] >= 0.90]
dominant_emotions = [face['dominant_emotion'] for face in analysis]
return dominant_emotions

View File

@@ -1,8 +1,9 @@
import logging
from pathlib import Path
import zmq
from fastapi import APIRouter
from fastapi.responses import StreamingResponse
from fastapi import APIRouter, HTTPException
from fastapi.responses import FileResponse, StreamingResponse
from zmq.asyncio import Context
from control_backend.core.config import settings
@@ -38,3 +39,29 @@ async def log_stream():
yield f"data: {message}\n\n"
return StreamingResponse(gen(), media_type="text/event-stream")
LOGGING_DIR = Path(settings.logging_settings.experiment_log_directory).resolve()
@router.get("/logs/files")
@router.get("/api/logs/files")
async def log_directory():
"""
Get a list of all log files stored in the experiment log file directory.
"""
return [f.name for f in LOGGING_DIR.glob("*.log")]
@router.get("/logs/files/{filename}")
@router.get("/api/logs/files/{filename}")
async def log_file(filename: str):
# Prevent path-traversal
file_path = (LOGGING_DIR / filename).resolve() # This .resolve() is important
if not file_path.is_relative_to(LOGGING_DIR):
raise HTTPException(status_code=400, detail="Invalid filename.")
if not file_path.is_file():
raise HTTPException(status_code=404, detail="File not found.")
return FileResponse(file_path, filename=file_path.name)

View File

@@ -50,6 +50,7 @@ class AgentSettings(BaseModel):
# agent names
bdi_core_name: str = "bdi_core_agent"
bdi_program_manager_name: str = "bdi_program_manager_agent"
visual_emotion_recognition_name: str = "visual_emotion_recognition_agent"
text_belief_extractor_name: str = "text_belief_extractor_agent"
vad_name: str = "vad_agent"
llm_name: str = "llm_agent"
@@ -77,6 +78,10 @@ class BehaviourSettings(BaseModel):
:ivar transcription_words_per_token: Estimated words per token for transcription timing.
:ivar transcription_token_buffer: Buffer for transcription tokens.
:ivar conversation_history_length_limit: The maximum amount of messages to extract beliefs from.
:ivar visual_emotion_recognition_window_duration_s: Duration in seconds over which to aggregate
emotions and update emotion beliefs.
:ivar visual_emotion_recognition_min_frames_per_face: Minimum number of frames per face required
to consider a face valid.
"""
# ATTENTION: When adding/removing settings, make sure to update the .env.example file
@@ -100,6 +105,9 @@ class BehaviourSettings(BaseModel):
# Text belief extractor settings
conversation_history_length_limit: int = 10
# Visual Emotion Recognition settings
visual_emotion_recognition_window_duration_s: int = 5
visual_emotion_recognition_min_frames_per_face: int = 3
class LLMSettings(BaseModel):
"""
@@ -154,6 +162,20 @@ class SpeechModelSettings(BaseModel):
openai_model_name: str = "small.en"
class LoggingSettings(BaseModel):
"""
Configuration for logging.
:ivar logging_config_file: Path to the logging configuration file.
:ivar experiment_log_directory: Location of the experiment logs. Must match the logging config.
:ivar experiment_logger_name: Name of the experiment logger. Must match the logging config.
"""
logging_config_file: str = ".logging_config.yaml"
experiment_log_directory: str = "experiment_logs"
experiment_logger_name: str = "experiment"
class Settings(BaseSettings):
"""
Global application settings.
@@ -175,6 +197,8 @@ class Settings(BaseSettings):
ri_host: str = "localhost"
logging_settings: LoggingSettings = LoggingSettings()
zmq_settings: ZMQSettings = ZMQSettings()
agent_settings: AgentSettings = AgentSettings()

View File

@@ -1 +1,4 @@
from .dated_file_handler import DatedFileHandler as DatedFileHandler
from .optional_field_formatter import OptionalFieldFormatter as OptionalFieldFormatter
from .partial_filter import PartialFilter as PartialFilter
from .setup_logging import setup_logging as setup_logging

View File

@@ -0,0 +1,38 @@
from datetime import datetime
from logging import FileHandler
from pathlib import Path
class DatedFileHandler(FileHandler):
def __init__(self, file_prefix: str, **kwargs):
if not file_prefix:
raise ValueError("`file_prefix` argument cannot be empty.")
self._file_prefix = file_prefix
kwargs["filename"] = self._make_filename()
super().__init__(**kwargs)
def _make_filename(self) -> str:
"""
Create the filename for the current logfile, using the configured file prefix and the
current date and time. If the directory does not exist, it gets created.
:return: A filepath.
"""
filepath = Path(f"{self._file_prefix}-{datetime.now():%Y%m%d-%H%M%S}.log")
if not filepath.parent.is_dir():
filepath.parent.mkdir(parents=True, exist_ok=True)
return str(filepath)
def do_rollover(self):
"""
Close the current logfile and create a new one with the current date and time.
"""
self.acquire()
try:
if self.stream:
self.stream.close()
self.baseFilename = self._make_filename()
self.stream = self._open()
finally:
self.release()

View File

@@ -0,0 +1,67 @@
import logging
import re
class OptionalFieldFormatter(logging.Formatter):
"""
Logging formatter that supports optional fields marked by `?`.
Optional fields are denoted by placing a `?` after the field name inside
the parentheses, e.g., `%(role?)s`. If the field is not provided in the
log call's `extra` dict, it will use the default value from `defaults`
or `None` if no default is specified.
:param fmt: Format string with optional `%(name?)s` style fields.
:type fmt: str or None
:param datefmt: Date format string, passed to parent :class:`logging.Formatter`.
:type datefmt: str or None
:param style: Formatting style, must be '%'. Passed to parent.
:type style: str
:param defaults: Default values for optional fields when not provided.
:type defaults: dict or None
:example:
>>> formatter = OptionalFieldFormatter(
... fmt="%(asctime)s %(levelname)s %(role?)s %(message)s",
... defaults={"role": ""-""}
... )
>>> handler = logging.StreamHandler()
>>> handler.setFormatter(formatter)
>>> logger = logging.getLogger(__name__)
>>> logger.addHandler(handler)
>>>
>>> logger.chat("Hello there!", extra={"role": "USER"})
2025-01-15 10:30:00 CHAT USER Hello there!
>>>
>>> logger.info("A logging message")
2025-01-15 10:30:01 INFO - A logging message
.. note::
Only `%`-style formatting is supported. The `{` and `$` styles are not
compatible with this formatter.
.. seealso::
:class:`logging.Formatter` for base formatter documentation.
"""
# Match %(name?)s or %(name?)d etc.
OPTIONAL_PATTERN = re.compile(r"%\((\w+)\?\)([sdifFeEgGxXocrba%])")
def __init__(self, fmt=None, datefmt=None, style="%", defaults=None):
self.defaults = defaults or {}
self.optional_fields = set(self.OPTIONAL_PATTERN.findall(fmt or ""))
# Convert %(name?)s to %(name)s for standard formatting
normalized_fmt = self.OPTIONAL_PATTERN.sub(r"%(\1)\2", fmt or "")
super().__init__(normalized_fmt, datefmt, style)
def format(self, record):
for field, _ in self.optional_fields:
if not hasattr(record, field):
default = self.defaults.get(field, None)
setattr(record, field, default)
return super().format(record)

View File

@@ -0,0 +1,10 @@
import logging
class PartialFilter(logging.Filter):
"""
Class to filter any log records that have the "partial" attribute set to ``True``.
"""
def filter(self, record):
return getattr(record, "partial", False) is not True

View File

@@ -37,7 +37,7 @@ def add_logging_level(level_name: str, level_num: int, method_name: str | None =
setattr(logging, method_name, log_to_root)
def setup_logging(path: str = ".logging_config.yaml") -> None:
def setup_logging(path: str = settings.logging_settings.logging_config_file) -> None:
"""
Setup logging configuration of the CB. Tries to load the logging configuration from a file,
in which we specify custom loggers, formatters, handlers, etc.
@@ -65,7 +65,7 @@ def setup_logging(path: str = ".logging_config.yaml") -> None:
# Patch ZMQ PUBHandler to know about custom levels
if custom_levels:
for logger_name in ("control_backend",):
for logger_name in config.get("loggers", {}):
logger = logging.getLogger(logger_name)
for handler in logger.handlers:
if isinstance(handler, PUBHandler):

View File

@@ -7,7 +7,7 @@ class InternalMessage(BaseModel):
"""
Standard message envelope for communication between agents within the Control Backend.
:ivar to: The name of the destination agent.
:ivar to: The name(s) of the destination agent(s).
:ivar sender: The name of the sending agent.
:ivar body: The string payload (often a JSON-serialized model).
:ivar thread: An optional thread identifier/topic to categorize the message (e.g., 'beliefs').

View File

@@ -28,8 +28,8 @@ class LogicalOperator(Enum):
OR = "OR"
type Belief = KeywordBelief | SemanticBelief | InferredBelief
type BasicBelief = KeywordBelief | SemanticBelief
type Belief = KeywordBelief | SemanticBelief | InferredBelief | EmotionBelief
type BasicBelief = KeywordBelief | SemanticBelief | EmotionBelief
class KeywordBelief(ProgramElement):
@@ -69,6 +69,15 @@ class InferredBelief(ProgramElement):
left: Belief
right: Belief
class EmotionBelief(ProgramElement):
"""
Represents a belief that is set when a certain emotion is detected.
:ivar emotion: The emotion on which this belief gets set.
"""
name: str = ""
emotion: str
class Norm(ProgramElement):
"""
@@ -226,3 +235,9 @@ class Program(BaseModel):
"""
phases: list[Phase]
if __name__ == "__main__":
input = input("Enter program JSON: ")
program = Program.model_validate_json(input)
print(program)

View File

@@ -1,7 +1,7 @@
from unittest.mock import patch
from unittest.mock import MagicMock, patch
import pytest
from fastapi import FastAPI
from fastapi import FastAPI, HTTPException
from fastapi.testclient import TestClient
from starlette.responses import StreamingResponse
@@ -61,3 +61,67 @@ async def test_log_stream_endpoint_lines(client):
# Optional: assert subscribe/connect were called
assert dummy_socket.subscribed # at least some log levels subscribed
assert dummy_socket.connected # connect was called
@patch("control_backend.api.v1.endpoints.logs.LOGGING_DIR")
def test_files_endpoint(LOGGING_DIR, client):
file_1, file_2 = MagicMock(), MagicMock()
file_1.name = "file_1"
file_2.name = "file_2"
LOGGING_DIR.glob.return_value = [file_1, file_2]
result = client.get("/api/logs/files")
assert result.status_code == 200
assert result.json() == ["file_1", "file_2"]
@patch("control_backend.api.v1.endpoints.logs.FileResponse")
@patch("control_backend.api.v1.endpoints.logs.LOGGING_DIR")
def test_log_file_endpoint_success(LOGGING_DIR, MockFileResponse, client):
mock_file_path = MagicMock()
mock_file_path.is_relative_to.return_value = True
mock_file_path.is_file.return_value = True
mock_file_path.name = "test.log"
LOGGING_DIR.__truediv__ = MagicMock(return_value=mock_file_path)
mock_file_path.resolve.return_value = mock_file_path
MockFileResponse.return_value = MagicMock()
result = client.get("/api/logs/files/test.log")
assert result.status_code == 200
MockFileResponse.assert_called_once_with(mock_file_path, filename="test.log")
@pytest.mark.asyncio
@patch("control_backend.api.v1.endpoints.logs.LOGGING_DIR")
async def test_log_file_endpoint_path_traversal(LOGGING_DIR):
from control_backend.api.v1.endpoints.logs import log_file
mock_file_path = MagicMock()
mock_file_path.is_relative_to.return_value = False
LOGGING_DIR.__truediv__ = MagicMock(return_value=mock_file_path)
mock_file_path.resolve.return_value = mock_file_path
with pytest.raises(HTTPException) as exc_info:
await log_file("../secret.txt")
assert exc_info.value.status_code == 400
assert exc_info.value.detail == "Invalid filename."
@patch("control_backend.api.v1.endpoints.logs.LOGGING_DIR")
def test_log_file_endpoint_file_not_found(LOGGING_DIR, client):
mock_file_path = MagicMock()
mock_file_path.is_relative_to.return_value = True
mock_file_path.is_file.return_value = False
LOGGING_DIR.__truediv__ = MagicMock(return_value=mock_file_path)
mock_file_path.resolve.return_value = mock_file_path
result = client.get("/api/logs/files/nonexistent.log")
assert result.status_code == 404
assert result.json()["detail"] == "File not found."

View File

@@ -0,0 +1,45 @@
from unittest.mock import MagicMock, patch
import pytest
from control_backend.logging.dated_file_handler import DatedFileHandler
@patch("control_backend.logging.dated_file_handler.DatedFileHandler._open")
def test_reset(open_):
stream = MagicMock()
open_.return_value = stream
# A file should be opened when the logger is created
handler = DatedFileHandler(file_prefix="anything")
assert open_.call_count == 1
# Upon reset, the current file should be closed, and a new one should be opened
handler.do_rollover()
assert stream.close.call_count == 1
assert open_.call_count == 2
@patch("control_backend.logging.dated_file_handler.Path")
@patch("control_backend.logging.dated_file_handler.DatedFileHandler._open")
def test_creates_dir(open_, Path_):
stream = MagicMock()
open_.return_value = stream
test_path = MagicMock()
test_path.parent.is_dir.return_value = False
Path_.return_value = test_path
DatedFileHandler(file_prefix="anything")
# The directory should've been created
test_path.parent.mkdir.assert_called_once()
@patch("control_backend.logging.dated_file_handler.DatedFileHandler._open")
def test_invalid_constructor(_):
with pytest.raises(ValueError):
DatedFileHandler(file_prefix=None)
with pytest.raises(ValueError):
DatedFileHandler(file_prefix="")

View File

@@ -0,0 +1,218 @@
import logging
import pytest
from control_backend.logging.optional_field_formatter import OptionalFieldFormatter
@pytest.fixture
def logger():
"""Create a fresh logger for each test."""
logger = logging.getLogger(f"test_{id(object())}")
logger.setLevel(logging.DEBUG)
logger.handlers = []
return logger
@pytest.fixture
def log_output(logger):
"""Capture log output and return a function to get it."""
class ListHandler(logging.Handler):
def __init__(self):
super().__init__()
self.records = []
def emit(self, record):
self.records.append(self.format(record))
handler = ListHandler()
logger.addHandler(handler)
def get_output():
return handler.records
return get_output
def test_optional_field_present(logger, log_output):
"""Optional field should appear when provided in extra."""
formatter = OptionalFieldFormatter("%(levelname)s - %(role?)s - %(message)s")
logger.handlers[0].setFormatter(formatter)
logger.info("test message", extra={"role": "user"})
assert log_output() == ["INFO - user - test message"]
def test_optional_field_missing_no_default(logger, log_output):
"""Missing optional field with no default should be None."""
formatter = OptionalFieldFormatter("%(levelname)s - %(role?)s - %(message)s")
logger.handlers[0].setFormatter(formatter)
logger.info("test message")
assert log_output() == ["INFO - None - test message"]
def test_optional_field_missing_with_default(logger, log_output):
"""Missing optional field should use provided default."""
formatter = OptionalFieldFormatter(
"%(levelname)s - %(role?)s - %(message)s", defaults={"role": "assistant"}
)
logger.handlers[0].setFormatter(formatter)
logger.info("test message")
assert log_output() == ["INFO - assistant - test message"]
def test_optional_field_overrides_default(logger, log_output):
"""Provided extra value should override default."""
formatter = OptionalFieldFormatter(
"%(levelname)s - %(role?)s - %(message)s", defaults={"role": "assistant"}
)
logger.handlers[0].setFormatter(formatter)
logger.info("test message", extra={"role": "user"})
assert log_output() == ["INFO - user - test message"]
def test_multiple_optional_fields(logger, log_output):
"""Multiple optional fields should work independently."""
formatter = OptionalFieldFormatter(
"%(levelname)s - %(role?)s - %(request_id?)s - %(message)s", defaults={"role": "assistant"}
)
logger.handlers[0].setFormatter(formatter)
logger.info("test", extra={"request_id": "123"})
assert log_output() == ["INFO - assistant - 123 - test"]
def test_mixed_optional_and_required_fields(logger, log_output):
"""Standard fields should work alongside optional fields."""
formatter = OptionalFieldFormatter("%(levelname)s %(name)s %(role?)s %(message)s")
logger.handlers[0].setFormatter(formatter)
logger.info("test", extra={"role": "user"})
output = log_output()[0]
assert "INFO" in output
assert "user" in output
assert "test" in output
def test_no_optional_fields(logger, log_output):
"""Formatter should work normally with no optional fields."""
formatter = OptionalFieldFormatter("%(levelname)s %(message)s")
logger.handlers[0].setFormatter(formatter)
logger.info("test message")
assert log_output() == ["INFO test message"]
def test_integer_format_specifier(logger, log_output):
"""Optional fields with %d specifier should work."""
formatter = OptionalFieldFormatter(
"%(levelname)s %(count?)d %(message)s", defaults={"count": 0}
)
logger.handlers[0].setFormatter(formatter)
logger.info("test", extra={"count": 42})
assert log_output() == ["INFO 42 test"]
def test_float_format_specifier(logger, log_output):
"""Optional fields with %f specifier should work."""
formatter = OptionalFieldFormatter(
"%(levelname)s %(duration?)f %(message)s", defaults={"duration": 0.0}
)
logger.handlers[0].setFormatter(formatter)
logger.info("test", extra={"duration": 1.5})
assert "1.5" in log_output()[0]
def test_empty_string_default(logger, log_output):
"""Empty string default should work."""
formatter = OptionalFieldFormatter("%(levelname)s %(role?)s %(message)s", defaults={"role": ""})
logger.handlers[0].setFormatter(formatter)
logger.info("test")
assert log_output() == ["INFO test"]
def test_none_format_string():
"""None format string should not raise."""
formatter = OptionalFieldFormatter(fmt=None)
assert formatter.optional_fields == set()
def test_optional_fields_parsed_correctly():
"""Check that optional fields are correctly identified."""
formatter = OptionalFieldFormatter("%(asctime)s %(role?)s %(level?)d %(name)s")
assert formatter.optional_fields == {("role", "s"), ("level", "d")}
def test_format_string_normalized():
"""Check that ? is removed from format string."""
formatter = OptionalFieldFormatter("%(role?)s %(message)s")
assert "?" not in formatter._style._fmt
assert "%(role)s" in formatter._style._fmt
def test_field_with_underscore(logger, log_output):
"""Field names with underscores should work."""
formatter = OptionalFieldFormatter("%(levelname)s %(user_id?)s %(message)s")
logger.handlers[0].setFormatter(formatter)
logger.info("test", extra={"user_id": "abc123"})
assert log_output() == ["INFO abc123 test"]
def test_field_with_numbers(logger, log_output):
"""Field names with numbers should work."""
formatter = OptionalFieldFormatter("%(levelname)s %(field2?)s %(message)s")
logger.handlers[0].setFormatter(formatter)
logger.info("test", extra={"field2": "value"})
assert log_output() == ["INFO value test"]
def test_multiple_log_calls(logger, log_output):
"""Formatter should work correctly across multiple log calls."""
formatter = OptionalFieldFormatter(
"%(levelname)s %(role?)s %(message)s", defaults={"role": "other"}
)
logger.handlers[0].setFormatter(formatter)
logger.info("first", extra={"role": "assistant"})
logger.info("second")
logger.info("third", extra={"role": "user"})
assert log_output() == [
"INFO assistant first",
"INFO other second",
"INFO user third",
]
def test_default_not_mutated(logger, log_output):
"""Original defaults dict should not be mutated."""
defaults = {"role": "other"}
formatter = OptionalFieldFormatter("%(levelname)s %(role?)s %(message)s", defaults=defaults)
logger.handlers[0].setFormatter(formatter)
logger.info("test")
assert defaults == {"role": "other"}

View File

@@ -0,0 +1,83 @@
import logging
import pytest
from control_backend.logging import PartialFilter
@pytest.fixture
def logger():
"""Create a fresh logger for each test."""
logger = logging.getLogger(f"test_{id(object())}")
logger.setLevel(logging.DEBUG)
logger.handlers = []
return logger
@pytest.fixture
def log_output(logger):
"""Capture log output and return a function to get it."""
class ListHandler(logging.Handler):
def __init__(self):
super().__init__()
self.records = []
def emit(self, record):
self.records.append(self.format(record))
handler = ListHandler()
handler.addFilter(PartialFilter())
handler.setFormatter(logging.Formatter("%(message)s"))
logger.addHandler(handler)
return lambda: handler.records
def test_no_partial_attribute(logger, log_output):
"""Records without partial attribute should pass through."""
logger.info("normal message")
assert log_output() == ["normal message"]
def test_partial_true_filtered(logger, log_output):
"""Records with partial=True should be filtered out."""
logger.info("partial message", extra={"partial": True})
assert log_output() == []
def test_partial_false_passes(logger, log_output):
"""Records with partial=False should pass through."""
logger.info("complete message", extra={"partial": False})
assert log_output() == ["complete message"]
def test_partial_none_passes(logger, log_output):
"""Records with partial=None should pass through."""
logger.info("message", extra={"partial": None})
assert log_output() == ["message"]
def test_partial_truthy_value_passes(logger, log_output):
"""
Records with truthy but non-True partial should pass through, that is, only when it's exactly
``True`` should it pass.
"""
logger.info("message", extra={"partial": "yes"})
assert log_output() == ["message"]
def test_multiple_records_mixed(logger, log_output):
"""Filter should handle mixed records correctly."""
logger.info("first")
logger.info("second", extra={"partial": True})
logger.info("third", extra={"partial": False})
logger.info("fourth", extra={"partial": True})
logger.info("fifth")
assert log_output() == ["first", "third", "fifth"]

881
uv.lock generated

File diff suppressed because it is too large Load Diff