Compare commits
15 Commits
feat/face-
...
demo
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9248eaadbc | ||
|
|
3095cb437b | ||
|
|
482c6b1082 | ||
|
|
b9a47eeb0c | ||
|
|
b3721322b9 | ||
| 8575ddcbcf | |||
|
|
4fb10730a4 | ||
| 59b35b31b2 | |||
|
|
7516667545 | ||
| 651f1b74a6 | |||
| 5ed751de8c | |||
| 89ebe45724 | |||
|
|
58881b5914 | ||
|
|
ba79d09c5d | ||
|
|
4cda4e5e70 |
2
.gitignore
vendored
2
.gitignore
vendored
@@ -224,7 +224,7 @@ docs/*
|
|||||||
|
|
||||||
# Generated files
|
# Generated files
|
||||||
agentspeak.asl
|
agentspeak.asl
|
||||||
|
experiment-*.log
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,36 +1,57 @@
|
|||||||
version: 1
|
version: 1
|
||||||
|
|
||||||
custom_levels:
|
custom_levels:
|
||||||
OBSERVATION: 25
|
OBSERVATION: 24
|
||||||
ACTION: 26
|
ACTION: 25
|
||||||
|
CHAT: 26
|
||||||
LLM: 9
|
LLM: 9
|
||||||
|
|
||||||
formatters:
|
formatters:
|
||||||
# Console output
|
# Console output
|
||||||
colored:
|
colored:
|
||||||
(): "colorlog.ColoredFormatter"
|
class: colorlog.ColoredFormatter
|
||||||
format: "{log_color}{asctime}.{msecs:03.0f} | {levelname:11} | {name:70} | {message}"
|
format: "{log_color}{asctime}.{msecs:03.0f} | {levelname:11} | {name:70} | {message}"
|
||||||
style: "{"
|
style: "{"
|
||||||
datefmt: "%H:%M:%S"
|
datefmt: "%H:%M:%S"
|
||||||
|
|
||||||
# User-facing UI (structured JSON)
|
# User-facing UI (structured JSON)
|
||||||
json_experiment:
|
json:
|
||||||
(): "pythonjsonlogger.jsonlogger.JsonFormatter"
|
class: pythonjsonlogger.jsonlogger.JsonFormatter
|
||||||
format: "{name} {levelname} {levelno} {message} {created} {relativeCreated}"
|
format: "{name} {levelname} {levelno} {message} {created} {relativeCreated}"
|
||||||
style: "{"
|
style: "{"
|
||||||
|
|
||||||
|
# Experiment stream for console and file output, with optional `role` field
|
||||||
|
experiment:
|
||||||
|
class: control_backend.logging.OptionalFieldFormatter
|
||||||
|
format: "%(asctime)s %(levelname)s %(role?)s %(message)s"
|
||||||
|
defaults:
|
||||||
|
role: "-"
|
||||||
|
|
||||||
|
filters:
|
||||||
|
# Filter out any log records that have the extra field "partial" set to True, indicating that they
|
||||||
|
# will be replaced later.
|
||||||
|
partial:
|
||||||
|
(): control_backend.logging.PartialFilter
|
||||||
|
|
||||||
handlers:
|
handlers:
|
||||||
console:
|
console:
|
||||||
class: logging.StreamHandler
|
class: logging.StreamHandler
|
||||||
level: DEBUG
|
level: DEBUG
|
||||||
formatter: colored
|
formatter: colored
|
||||||
|
filters: [partial]
|
||||||
stream: ext://sys.stdout
|
stream: ext://sys.stdout
|
||||||
ui:
|
ui:
|
||||||
class: zmq.log.handlers.PUBHandler
|
class: zmq.log.handlers.PUBHandler
|
||||||
level: LLM
|
level: LLM
|
||||||
formatter: json_experiment
|
formatter: json
|
||||||
|
file:
|
||||||
|
class: control_backend.logging.DatedFileHandler
|
||||||
|
formatter: experiment
|
||||||
|
filters: [partial]
|
||||||
|
# Directory must match config.logging_settings.experiment_log_directory
|
||||||
|
file_prefix: experiment_logs/experiment
|
||||||
|
|
||||||
# Level of external libraries
|
# Level for external libraries
|
||||||
root:
|
root:
|
||||||
level: WARN
|
level: WARN
|
||||||
handlers: [console]
|
handlers: [console]
|
||||||
@@ -39,3 +60,6 @@ loggers:
|
|||||||
control_backend:
|
control_backend:
|
||||||
level: LLM
|
level: LLM
|
||||||
handlers: [ui]
|
handlers: [ui]
|
||||||
|
experiment: # This name must match config.logging_settings.experiment_logger_name
|
||||||
|
level: DEBUG
|
||||||
|
handlers: [ui, file]
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
import json
|
import json
|
||||||
|
import logging
|
||||||
|
|
||||||
import zmq
|
import zmq
|
||||||
import zmq.asyncio as azmq
|
import zmq.asyncio as azmq
|
||||||
@@ -8,6 +9,8 @@ from control_backend.core.agent_system import InternalMessage
|
|||||||
from control_backend.core.config import settings
|
from control_backend.core.config import settings
|
||||||
from control_backend.schemas.ri_message import GestureCommand, RIEndpoint
|
from control_backend.schemas.ri_message import GestureCommand, RIEndpoint
|
||||||
|
|
||||||
|
experiment_logger = logging.getLogger(settings.logging_settings.experiment_logger_name)
|
||||||
|
|
||||||
|
|
||||||
class RobotGestureAgent(BaseAgent):
|
class RobotGestureAgent(BaseAgent):
|
||||||
"""
|
"""
|
||||||
@@ -111,6 +114,7 @@ class RobotGestureAgent(BaseAgent):
|
|||||||
gesture_command.data,
|
gesture_command.data,
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
experiment_logger.action("Gesture: %s", gesture_command.data)
|
||||||
await self.pubsocket.send_json(gesture_command.model_dump())
|
await self.pubsocket.send_json(gesture_command.model_dump())
|
||||||
except Exception:
|
except Exception:
|
||||||
self.logger.exception("Error processing internal message.")
|
self.logger.exception("Error processing internal message.")
|
||||||
|
|||||||
@@ -1,9 +1,10 @@
|
|||||||
import logging
|
import logging
|
||||||
|
from abc import ABC
|
||||||
|
|
||||||
from control_backend.core.agent_system import BaseAgent as CoreBaseAgent
|
from control_backend.core.agent_system import BaseAgent as CoreBaseAgent
|
||||||
|
|
||||||
|
|
||||||
class BaseAgent(CoreBaseAgent):
|
class BaseAgent(CoreBaseAgent, ABC):
|
||||||
"""
|
"""
|
||||||
The primary base class for all implementation agents.
|
The primary base class for all implementation agents.
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import copy
|
import copy
|
||||||
import json
|
import json
|
||||||
|
import logging
|
||||||
import time
|
import time
|
||||||
from collections.abc import Iterable
|
from collections.abc import Iterable
|
||||||
|
|
||||||
@@ -19,6 +20,9 @@ from control_backend.schemas.ri_message import GestureCommand, RIEndpoint, Speec
|
|||||||
DELIMITER = ";\n" # TODO: temporary until we support lists in AgentSpeak
|
DELIMITER = ";\n" # TODO: temporary until we support lists in AgentSpeak
|
||||||
|
|
||||||
|
|
||||||
|
experiment_logger = logging.getLogger(settings.logging_settings.experiment_logger_name)
|
||||||
|
|
||||||
|
|
||||||
class BDICoreAgent(BaseAgent):
|
class BDICoreAgent(BaseAgent):
|
||||||
"""
|
"""
|
||||||
BDI Core Agent.
|
BDI Core Agent.
|
||||||
@@ -207,6 +211,9 @@ class BDICoreAgent(BaseAgent):
|
|||||||
else:
|
else:
|
||||||
term = agentspeak.Literal(name)
|
term = agentspeak.Literal(name)
|
||||||
|
|
||||||
|
if name != "user_said":
|
||||||
|
experiment_logger.observation(f"Formed new belief: {name}{f'={args}' if args else ''}")
|
||||||
|
|
||||||
self.bdi_agent.call(
|
self.bdi_agent.call(
|
||||||
agentspeak.Trigger.addition,
|
agentspeak.Trigger.addition,
|
||||||
agentspeak.GoalType.belief,
|
agentspeak.GoalType.belief,
|
||||||
@@ -244,6 +251,9 @@ class BDICoreAgent(BaseAgent):
|
|||||||
new_args = (agentspeak.Literal(arg) for arg in args)
|
new_args = (agentspeak.Literal(arg) for arg in args)
|
||||||
term = agentspeak.Literal(name, new_args)
|
term = agentspeak.Literal(name, new_args)
|
||||||
|
|
||||||
|
if name != "user_said":
|
||||||
|
experiment_logger.observation(f"Removed belief: {name}{f'={args}' if args else ''}")
|
||||||
|
|
||||||
result = self.bdi_agent.call(
|
result = self.bdi_agent.call(
|
||||||
agentspeak.Trigger.removal,
|
agentspeak.Trigger.removal,
|
||||||
agentspeak.GoalType.belief,
|
agentspeak.GoalType.belief,
|
||||||
@@ -386,6 +396,8 @@ class BDICoreAgent(BaseAgent):
|
|||||||
body=str(message_text),
|
body=str(message_text),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
experiment_logger.chat(str(message_text), extra={"role": "assistant"})
|
||||||
|
|
||||||
self.add_behavior(self.send(chat_history_message))
|
self.add_behavior(self.send(chat_history_message))
|
||||||
|
|
||||||
yield
|
yield
|
||||||
@@ -441,6 +453,7 @@ class BDICoreAgent(BaseAgent):
|
|||||||
trigger_name = agentspeak.grounded(term.args[0], intention.scope)
|
trigger_name = agentspeak.grounded(term.args[0], intention.scope)
|
||||||
|
|
||||||
self.logger.debug("Started trigger %s", trigger_name)
|
self.logger.debug("Started trigger %s", trigger_name)
|
||||||
|
experiment_logger.observation("Triggered: %s", trigger_name)
|
||||||
|
|
||||||
msg = InternalMessage(
|
msg = InternalMessage(
|
||||||
to=settings.agent_settings.user_interrupt_name,
|
to=settings.agent_settings.user_interrupt_name,
|
||||||
|
|||||||
@@ -1,10 +1,12 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
|
import logging
|
||||||
|
|
||||||
import zmq
|
import zmq
|
||||||
from pydantic import ValidationError
|
from pydantic import ValidationError
|
||||||
from zmq.asyncio import Context
|
from zmq.asyncio import Context
|
||||||
|
|
||||||
|
import control_backend
|
||||||
from control_backend.agents import BaseAgent
|
from control_backend.agents import BaseAgent
|
||||||
from control_backend.agents.bdi.agentspeak_generator import AgentSpeakGenerator
|
from control_backend.agents.bdi.agentspeak_generator import AgentSpeakGenerator
|
||||||
from control_backend.core.config import settings
|
from control_backend.core.config import settings
|
||||||
@@ -19,6 +21,8 @@ from control_backend.schemas.program import (
|
|||||||
Program,
|
Program,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
experiment_logger = logging.getLogger(settings.logging_settings.experiment_logger_name)
|
||||||
|
|
||||||
|
|
||||||
class BDIProgramManager(BaseAgent):
|
class BDIProgramManager(BaseAgent):
|
||||||
"""
|
"""
|
||||||
@@ -241,6 +245,18 @@ class BDIProgramManager(BaseAgent):
|
|||||||
await self.send(extractor_msg)
|
await self.send(extractor_msg)
|
||||||
self.logger.debug("Sent message to extractor agent to clear history.")
|
self.logger.debug("Sent message to extractor agent to clear history.")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _rollover_experiment_logs():
|
||||||
|
"""
|
||||||
|
A new experiment program started; make a new experiment log file.
|
||||||
|
"""
|
||||||
|
handlers = logging.getLogger("experiment").handlers
|
||||||
|
for handler in handlers:
|
||||||
|
if isinstance(handler, control_backend.logging.DatedFileHandler):
|
||||||
|
experiment_logger.action("Doing rollover...")
|
||||||
|
handler.do_rollover()
|
||||||
|
experiment_logger.debug("Finished rollover.")
|
||||||
|
|
||||||
async def _receive_programs(self):
|
async def _receive_programs(self):
|
||||||
"""
|
"""
|
||||||
Continuous loop that receives program updates from the HTTP endpoint.
|
Continuous loop that receives program updates from the HTTP endpoint.
|
||||||
@@ -261,6 +277,7 @@ class BDIProgramManager(BaseAgent):
|
|||||||
self._initialize_internal_state(program)
|
self._initialize_internal_state(program)
|
||||||
await self._send_program_to_user_interrupt(program)
|
await self._send_program_to_user_interrupt(program)
|
||||||
await self._send_clear_llm_history()
|
await self._send_clear_llm_history()
|
||||||
|
self._rollover_experiment_logs()
|
||||||
|
|
||||||
await asyncio.gather(
|
await asyncio.gather(
|
||||||
self._create_agentspeak_and_send_to_bdi(program),
|
self._create_agentspeak_and_send_to_bdi(program),
|
||||||
|
|||||||
@@ -150,6 +150,9 @@ class TextBeliefExtractorAgent(BaseAgent):
|
|||||||
return
|
return
|
||||||
|
|
||||||
available_beliefs = [b for b in belief_list.beliefs if isinstance(b, SemanticBelief)]
|
available_beliefs = [b for b in belief_list.beliefs if isinstance(b, SemanticBelief)]
|
||||||
|
self._current_beliefs = BeliefState(
|
||||||
|
false={InternalBelief(name=b.name, arguments=None) for b in available_beliefs},
|
||||||
|
)
|
||||||
self.belief_inferrer.available_beliefs = available_beliefs
|
self.belief_inferrer.available_beliefs = available_beliefs
|
||||||
self.logger.debug(
|
self.logger.debug(
|
||||||
"Received %d semantic beliefs from the program manager: %s",
|
"Received %d semantic beliefs from the program manager: %s",
|
||||||
@@ -170,6 +173,9 @@ class TextBeliefExtractorAgent(BaseAgent):
|
|||||||
available_goals = {g for g in goals_list.goals if g.can_fail}
|
available_goals = {g for g in goals_list.goals if g.can_fail}
|
||||||
available_goals -= self._force_completed_goals
|
available_goals -= self._force_completed_goals
|
||||||
self.goal_inferrer.goals = available_goals
|
self.goal_inferrer.goals = available_goals
|
||||||
|
self._current_goal_completions = {
|
||||||
|
f"achieved_{AgentSpeakGenerator.slugify(goal)}": False for goal in available_goals
|
||||||
|
}
|
||||||
self.logger.debug(
|
self.logger.debug(
|
||||||
"Received %d failable goals from the program manager: %s",
|
"Received %d failable goals from the program manager: %s",
|
||||||
len(available_goals),
|
len(available_goals),
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
|
import logging
|
||||||
import re
|
import re
|
||||||
import uuid
|
import uuid
|
||||||
from collections.abc import AsyncGenerator
|
from collections.abc import AsyncGenerator
|
||||||
@@ -14,6 +15,8 @@ from control_backend.core.config import settings
|
|||||||
from ...schemas.llm_prompt_message import LLMPromptMessage
|
from ...schemas.llm_prompt_message import LLMPromptMessage
|
||||||
from .llm_instructions import LLMInstructions
|
from .llm_instructions import LLMInstructions
|
||||||
|
|
||||||
|
experiment_logger = logging.getLogger(settings.logging_settings.experiment_logger_name)
|
||||||
|
|
||||||
|
|
||||||
class LLMAgent(BaseAgent):
|
class LLMAgent(BaseAgent):
|
||||||
"""
|
"""
|
||||||
@@ -170,7 +173,7 @@ class LLMAgent(BaseAgent):
|
|||||||
*self.history,
|
*self.history,
|
||||||
]
|
]
|
||||||
|
|
||||||
message_id = str(uuid.uuid4()) # noqa
|
message_id = str(uuid.uuid4())
|
||||||
|
|
||||||
try:
|
try:
|
||||||
full_message = ""
|
full_message = ""
|
||||||
@@ -179,10 +182,9 @@ class LLMAgent(BaseAgent):
|
|||||||
full_message += token
|
full_message += token
|
||||||
current_chunk += token
|
current_chunk += token
|
||||||
|
|
||||||
self.logger.llm(
|
experiment_logger.chat(
|
||||||
"Received token: %s",
|
|
||||||
full_message,
|
full_message,
|
||||||
extra={"reference": message_id}, # Used in the UI to update old logs
|
extra={"role": "assistant", "reference": message_id, "partial": True},
|
||||||
)
|
)
|
||||||
|
|
||||||
# Stream the message in chunks separated by punctuation.
|
# Stream the message in chunks separated by punctuation.
|
||||||
@@ -197,6 +199,11 @@ class LLMAgent(BaseAgent):
|
|||||||
# Yield any remaining tail
|
# Yield any remaining tail
|
||||||
if current_chunk:
|
if current_chunk:
|
||||||
yield current_chunk
|
yield current_chunk
|
||||||
|
|
||||||
|
experiment_logger.chat(
|
||||||
|
full_message,
|
||||||
|
extra={"role": "assistant", "reference": message_id, "partial": False},
|
||||||
|
)
|
||||||
except httpx.HTTPError as err:
|
except httpx.HTTPError as err:
|
||||||
self.logger.error("HTTP error.", exc_info=err)
|
self.logger.error("HTTP error.", exc_info=err)
|
||||||
yield "LLM service unavailable."
|
yield "LLM service unavailable."
|
||||||
@@ -212,7 +219,7 @@ class LLMAgent(BaseAgent):
|
|||||||
:yield: Raw text tokens (deltas) from the SSE stream.
|
:yield: Raw text tokens (deltas) from the SSE stream.
|
||||||
:raises httpx.HTTPError: If the API returns a non-200 status.
|
:raises httpx.HTTPError: If the API returns a non-200 status.
|
||||||
"""
|
"""
|
||||||
async with httpx.AsyncClient() as client:
|
async with httpx.AsyncClient(timeout=httpx.Timeout(20.0)) as client:
|
||||||
async with client.stream(
|
async with client.stream(
|
||||||
"POST",
|
"POST",
|
||||||
settings.llm_settings.local_llm_url,
|
settings.llm_settings.local_llm_url,
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
|
import logging
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import zmq
|
import zmq
|
||||||
@@ -10,6 +11,8 @@ from control_backend.core.config import settings
|
|||||||
|
|
||||||
from .speech_recognizer import SpeechRecognizer
|
from .speech_recognizer import SpeechRecognizer
|
||||||
|
|
||||||
|
experiment_logger = logging.getLogger(settings.logging_settings.experiment_logger_name)
|
||||||
|
|
||||||
|
|
||||||
class TranscriptionAgent(BaseAgent):
|
class TranscriptionAgent(BaseAgent):
|
||||||
"""
|
"""
|
||||||
@@ -25,6 +28,8 @@ class TranscriptionAgent(BaseAgent):
|
|||||||
:ivar audio_in_socket: The ZMQ SUB socket instance.
|
:ivar audio_in_socket: The ZMQ SUB socket instance.
|
||||||
:ivar speech_recognizer: The speech recognition engine instance.
|
:ivar speech_recognizer: The speech recognition engine instance.
|
||||||
:ivar _concurrency: Semaphore to limit concurrent transcriptions.
|
:ivar _concurrency: Semaphore to limit concurrent transcriptions.
|
||||||
|
:ivar _current_speech_reference: The reference of the current user utterance, for synchronising
|
||||||
|
experiment logs.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, audio_in_address: str):
|
def __init__(self, audio_in_address: str):
|
||||||
@@ -39,6 +44,7 @@ class TranscriptionAgent(BaseAgent):
|
|||||||
self.audio_in_socket: azmq.Socket | None = None
|
self.audio_in_socket: azmq.Socket | None = None
|
||||||
self.speech_recognizer = None
|
self.speech_recognizer = None
|
||||||
self._concurrency = None
|
self._concurrency = None
|
||||||
|
self._current_speech_reference: str | None = None
|
||||||
|
|
||||||
async def setup(self):
|
async def setup(self):
|
||||||
"""
|
"""
|
||||||
@@ -63,6 +69,10 @@ class TranscriptionAgent(BaseAgent):
|
|||||||
|
|
||||||
self.logger.info("Finished setting up %s", self.name)
|
self.logger.info("Finished setting up %s", self.name)
|
||||||
|
|
||||||
|
async def handle_message(self, msg: InternalMessage):
|
||||||
|
if msg.thread == "voice_activity":
|
||||||
|
self._current_speech_reference = msg.body
|
||||||
|
|
||||||
async def stop(self):
|
async def stop(self):
|
||||||
"""
|
"""
|
||||||
Stop the agent and close sockets.
|
Stop the agent and close sockets.
|
||||||
@@ -96,24 +106,25 @@ class TranscriptionAgent(BaseAgent):
|
|||||||
|
|
||||||
async def _share_transcription(self, transcription: str):
|
async def _share_transcription(self, transcription: str):
|
||||||
"""
|
"""
|
||||||
Share a transcription to the other agents that depend on it.
|
Share a transcription to the other agents that depend on it, and to experiment logs.
|
||||||
|
|
||||||
Currently sends to:
|
Currently sends to:
|
||||||
- :attr:`settings.agent_settings.text_belief_extractor_name`
|
- :attr:`settings.agent_settings.text_belief_extractor_name`
|
||||||
|
- The UI via the experiment logger
|
||||||
|
|
||||||
:param transcription: The transcribed text.
|
:param transcription: The transcribed text.
|
||||||
"""
|
"""
|
||||||
receiver_names = [
|
experiment_logger.chat(
|
||||||
settings.agent_settings.text_belief_extractor_name,
|
transcription,
|
||||||
]
|
extra={"role": "user", "reference": self._current_speech_reference, "partial": False},
|
||||||
|
)
|
||||||
|
|
||||||
for receiver_name in receiver_names:
|
message = InternalMessage(
|
||||||
message = InternalMessage(
|
to=settings.agent_settings.text_belief_extractor_name,
|
||||||
to=receiver_name,
|
sender=self.name,
|
||||||
sender=self.name,
|
body=transcription,
|
||||||
body=transcription,
|
)
|
||||||
)
|
await self.send(message)
|
||||||
await self.send(message)
|
|
||||||
|
|
||||||
async def _transcribing_loop(self) -> None:
|
async def _transcribing_loop(self) -> None:
|
||||||
"""
|
"""
|
||||||
@@ -129,10 +140,9 @@ class TranscriptionAgent(BaseAgent):
|
|||||||
audio = np.frombuffer(audio_data, dtype=np.float32)
|
audio = np.frombuffer(audio_data, dtype=np.float32)
|
||||||
speech = await self._transcribe(audio)
|
speech = await self._transcribe(audio)
|
||||||
if not speech:
|
if not speech:
|
||||||
self.logger.info("Nothing transcribed.")
|
self.logger.debug("Nothing transcribed.")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
self.logger.info("Transcribed speech: %s", speech)
|
|
||||||
await self._share_transcription(speech)
|
await self._share_transcription(speech)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.logger.error(f"Error in transcription loop: {e}")
|
self.logger.error(f"Error in transcription loop: {e}")
|
||||||
|
|||||||
@@ -1,4 +1,6 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import uuid
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@@ -12,6 +14,8 @@ from control_backend.schemas.internal_message import InternalMessage
|
|||||||
from ...schemas.program_status import PROGRAM_STATUS, ProgramStatus
|
from ...schemas.program_status import PROGRAM_STATUS, ProgramStatus
|
||||||
from .transcription_agent.transcription_agent import TranscriptionAgent
|
from .transcription_agent.transcription_agent import TranscriptionAgent
|
||||||
|
|
||||||
|
experiment_logger = logging.getLogger(settings.logging_settings.experiment_logger_name)
|
||||||
|
|
||||||
|
|
||||||
class SocketPoller[T]:
|
class SocketPoller[T]:
|
||||||
"""
|
"""
|
||||||
@@ -252,6 +256,18 @@ class VADAgent(BaseAgent):
|
|||||||
if prob > prob_threshold:
|
if prob > prob_threshold:
|
||||||
if self.i_since_speech > non_speech_patience + begin_silence_length:
|
if self.i_since_speech > non_speech_patience + begin_silence_length:
|
||||||
self.logger.debug("Speech started.")
|
self.logger.debug("Speech started.")
|
||||||
|
reference = str(uuid.uuid4())
|
||||||
|
experiment_logger.chat(
|
||||||
|
"...",
|
||||||
|
extra={"role": "user", "reference": reference, "partial": True},
|
||||||
|
)
|
||||||
|
await self.send(
|
||||||
|
InternalMessage(
|
||||||
|
to=settings.agent_settings.transcription_name,
|
||||||
|
body=reference,
|
||||||
|
thread="voice_activity",
|
||||||
|
)
|
||||||
|
)
|
||||||
self.audio_buffer = np.append(self.audio_buffer, chunk)
|
self.audio_buffer = np.append(self.audio_buffer, chunk)
|
||||||
self.i_since_speech = 0
|
self.i_since_speech = 0
|
||||||
continue
|
continue
|
||||||
@@ -269,9 +285,10 @@ class VADAgent(BaseAgent):
|
|||||||
assert self.audio_out_socket is not None
|
assert self.audio_out_socket is not None
|
||||||
await self.audio_out_socket.send(self.audio_buffer[: -2 * len(chunk)].tobytes())
|
await self.audio_out_socket.send(self.audio_buffer[: -2 * len(chunk)].tobytes())
|
||||||
|
|
||||||
# At this point, we know that the speech has ended.
|
# At this point, we know that there is no speech.
|
||||||
# Prepend the last chunk that had no speech, for a more fluent boundary
|
# Prepend the last few chunks that had no speech, for a more fluent boundary.
|
||||||
self.audio_buffer = chunk
|
self.audio_buffer = np.append(self.audio_buffer, chunk)
|
||||||
|
self.audio_buffer = self.audio_buffer[-begin_silence_length * len(chunk) :]
|
||||||
|
|
||||||
async def handle_message(self, msg: InternalMessage):
|
async def handle_message(self, msg: InternalMessage):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -7,7 +7,6 @@ import numpy as np
|
|||||||
import zmq
|
import zmq
|
||||||
import zmq.asyncio as azmq
|
import zmq.asyncio as azmq
|
||||||
from pydantic_core import ValidationError
|
from pydantic_core import ValidationError
|
||||||
import struct
|
|
||||||
|
|
||||||
from control_backend.agents import BaseAgent
|
from control_backend.agents import BaseAgent
|
||||||
from control_backend.agents.perception.visual_emotion_recognition_agent.visual_emotion_recognizer import ( # noqa
|
from control_backend.agents.perception.visual_emotion_recognition_agent.visual_emotion_recognizer import ( # noqa
|
||||||
@@ -89,7 +88,7 @@ class VisualEmotionRecognitionAgent(BaseAgent):
|
|||||||
while self._running:
|
while self._running:
|
||||||
try:
|
try:
|
||||||
frame_bytes = await self.video_in_socket.recv()
|
frame_bytes = await self.video_in_socket.recv()
|
||||||
|
|
||||||
# Convert bytes to a numpy buffer
|
# Convert bytes to a numpy buffer
|
||||||
nparr = np.frombuffer(frame_bytes, np.uint8)
|
nparr = np.frombuffer(frame_bytes, np.uint8)
|
||||||
|
|
||||||
@@ -126,7 +125,6 @@ class VisualEmotionRecognitionAgent(BaseAgent):
|
|||||||
except zmq.Again:
|
except zmq.Again:
|
||||||
self.logger.warning("No video frame received within timeout.")
|
self.logger.warning("No video frame received within timeout.")
|
||||||
|
|
||||||
|
|
||||||
async def update_emotions(self, prev_emotions: set[str], emotions: set[str]):
|
async def update_emotions(self, prev_emotions: set[str], emotions: set[str]):
|
||||||
"""
|
"""
|
||||||
Compare emotions from previous window and current emotions,
|
Compare emotions from previous window and current emotions,
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
import json
|
import json
|
||||||
|
import logging
|
||||||
|
|
||||||
import zmq
|
import zmq
|
||||||
from zmq.asyncio import Context
|
from zmq.asyncio import Context
|
||||||
@@ -16,6 +17,8 @@ from control_backend.schemas.ri_message import (
|
|||||||
SpeechCommand,
|
SpeechCommand,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
experiment_logger = logging.getLogger(settings.logging_settings.experiment_logger_name)
|
||||||
|
|
||||||
|
|
||||||
class UserInterruptAgent(BaseAgent):
|
class UserInterruptAgent(BaseAgent):
|
||||||
"""
|
"""
|
||||||
@@ -194,6 +197,7 @@ class UserInterruptAgent(BaseAgent):
|
|||||||
case "transition_phase":
|
case "transition_phase":
|
||||||
new_phase_id = msg.body
|
new_phase_id = msg.body
|
||||||
self.logger.info(f"Phase transition detected: {new_phase_id}")
|
self.logger.info(f"Phase transition detected: {new_phase_id}")
|
||||||
|
experiment_logger.observation("Transitioned to next phase.")
|
||||||
|
|
||||||
payload = {"type": "phase_update", "id": new_phase_id}
|
payload = {"type": "phase_update", "id": new_phase_id}
|
||||||
|
|
||||||
@@ -296,6 +300,7 @@ class UserInterruptAgent(BaseAgent):
|
|||||||
|
|
||||||
:param text_to_say: The string that the robot has to say.
|
:param text_to_say: The string that the robot has to say.
|
||||||
"""
|
"""
|
||||||
|
experiment_logger.chat(text_to_say, extra={"role": "assistant"})
|
||||||
cmd = SpeechCommand(data=text_to_say, is_priority=True)
|
cmd = SpeechCommand(data=text_to_say, is_priority=True)
|
||||||
out_msg = InternalMessage(
|
out_msg = InternalMessage(
|
||||||
to=settings.agent_settings.robot_speech_name,
|
to=settings.agent_settings.robot_speech_name,
|
||||||
|
|||||||
@@ -1,8 +1,9 @@
|
|||||||
import logging
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
import zmq
|
import zmq
|
||||||
from fastapi import APIRouter
|
from fastapi import APIRouter, HTTPException
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import FileResponse, StreamingResponse
|
||||||
from zmq.asyncio import Context
|
from zmq.asyncio import Context
|
||||||
|
|
||||||
from control_backend.core.config import settings
|
from control_backend.core.config import settings
|
||||||
@@ -38,3 +39,29 @@ async def log_stream():
|
|||||||
yield f"data: {message}\n\n"
|
yield f"data: {message}\n\n"
|
||||||
|
|
||||||
return StreamingResponse(gen(), media_type="text/event-stream")
|
return StreamingResponse(gen(), media_type="text/event-stream")
|
||||||
|
|
||||||
|
|
||||||
|
LOGGING_DIR = Path(settings.logging_settings.experiment_log_directory).resolve()
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/logs/files")
|
||||||
|
@router.get("/api/logs/files")
|
||||||
|
async def log_directory():
|
||||||
|
"""
|
||||||
|
Get a list of all log files stored in the experiment log file directory.
|
||||||
|
"""
|
||||||
|
return [f.name for f in LOGGING_DIR.glob("*.log")]
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/logs/files/{filename}")
|
||||||
|
@router.get("/api/logs/files/{filename}")
|
||||||
|
async def log_file(filename: str):
|
||||||
|
# Prevent path-traversal
|
||||||
|
file_path = (LOGGING_DIR / filename).resolve() # This .resolve() is important
|
||||||
|
if not file_path.is_relative_to(LOGGING_DIR):
|
||||||
|
raise HTTPException(status_code=400, detail="Invalid filename.")
|
||||||
|
|
||||||
|
if not file_path.is_file():
|
||||||
|
raise HTTPException(status_code=404, detail="File not found.")
|
||||||
|
|
||||||
|
return FileResponse(file_path, filename=file_path.name)
|
||||||
|
|||||||
@@ -162,6 +162,20 @@ class SpeechModelSettings(BaseModel):
|
|||||||
openai_model_name: str = "small.en"
|
openai_model_name: str = "small.en"
|
||||||
|
|
||||||
|
|
||||||
|
class LoggingSettings(BaseModel):
|
||||||
|
"""
|
||||||
|
Configuration for logging.
|
||||||
|
|
||||||
|
:ivar logging_config_file: Path to the logging configuration file.
|
||||||
|
:ivar experiment_log_directory: Location of the experiment logs. Must match the logging config.
|
||||||
|
:ivar experiment_logger_name: Name of the experiment logger. Must match the logging config.
|
||||||
|
"""
|
||||||
|
|
||||||
|
logging_config_file: str = ".logging_config.yaml"
|
||||||
|
experiment_log_directory: str = "experiment_logs"
|
||||||
|
experiment_logger_name: str = "experiment"
|
||||||
|
|
||||||
|
|
||||||
class Settings(BaseSettings):
|
class Settings(BaseSettings):
|
||||||
"""
|
"""
|
||||||
Global application settings.
|
Global application settings.
|
||||||
@@ -183,6 +197,8 @@ class Settings(BaseSettings):
|
|||||||
|
|
||||||
ri_host: str = "localhost"
|
ri_host: str = "localhost"
|
||||||
|
|
||||||
|
logging_settings: LoggingSettings = LoggingSettings()
|
||||||
|
|
||||||
zmq_settings: ZMQSettings = ZMQSettings()
|
zmq_settings: ZMQSettings = ZMQSettings()
|
||||||
|
|
||||||
agent_settings: AgentSettings = AgentSettings()
|
agent_settings: AgentSettings = AgentSettings()
|
||||||
|
|||||||
@@ -1 +1,4 @@
|
|||||||
|
from .dated_file_handler import DatedFileHandler as DatedFileHandler
|
||||||
|
from .optional_field_formatter import OptionalFieldFormatter as OptionalFieldFormatter
|
||||||
|
from .partial_filter import PartialFilter as PartialFilter
|
||||||
from .setup_logging import setup_logging as setup_logging
|
from .setup_logging import setup_logging as setup_logging
|
||||||
|
|||||||
29
src/control_backend/logging/dated_file_handler.py
Normal file
29
src/control_backend/logging/dated_file_handler.py
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
from datetime import datetime
|
||||||
|
from logging import FileHandler
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
|
class DatedFileHandler(FileHandler):
|
||||||
|
def __init__(self, file_prefix: str, **kwargs):
|
||||||
|
if not file_prefix:
|
||||||
|
raise ValueError("`file_prefix` argument cannot be empty.")
|
||||||
|
self._file_prefix = file_prefix
|
||||||
|
kwargs["filename"] = self._make_filename()
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
|
def _make_filename(self) -> str:
|
||||||
|
filepath = Path(f"{self._file_prefix}-{datetime.now():%Y%m%d-%H%M%S}.log")
|
||||||
|
if not filepath.parent.is_dir():
|
||||||
|
filepath.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
return str(filepath)
|
||||||
|
|
||||||
|
def do_rollover(self):
|
||||||
|
self.acquire()
|
||||||
|
try:
|
||||||
|
if self.stream:
|
||||||
|
self.stream.close()
|
||||||
|
|
||||||
|
self.baseFilename = self._make_filename()
|
||||||
|
self.stream = self._open()
|
||||||
|
finally:
|
||||||
|
self.release()
|
||||||
67
src/control_backend/logging/optional_field_formatter.py
Normal file
67
src/control_backend/logging/optional_field_formatter.py
Normal file
@@ -0,0 +1,67 @@
|
|||||||
|
import logging
|
||||||
|
import re
|
||||||
|
|
||||||
|
|
||||||
|
class OptionalFieldFormatter(logging.Formatter):
|
||||||
|
"""
|
||||||
|
Logging formatter that supports optional fields marked by `?`.
|
||||||
|
|
||||||
|
Optional fields are denoted by placing a `?` after the field name inside
|
||||||
|
the parentheses, e.g., `%(role?)s`. If the field is not provided in the
|
||||||
|
log call's `extra` dict, it will use the default value from `defaults`
|
||||||
|
or `None` if no default is specified.
|
||||||
|
|
||||||
|
:param fmt: Format string with optional `%(name?)s` style fields.
|
||||||
|
:type fmt: str or None
|
||||||
|
:param datefmt: Date format string, passed to parent :class:`logging.Formatter`.
|
||||||
|
:type datefmt: str or None
|
||||||
|
:param style: Formatting style, must be '%'. Passed to parent.
|
||||||
|
:type style: str
|
||||||
|
:param defaults: Default values for optional fields when not provided.
|
||||||
|
:type defaults: dict or None
|
||||||
|
|
||||||
|
:example:
|
||||||
|
|
||||||
|
>>> formatter = OptionalFieldFormatter(
|
||||||
|
... fmt="%(asctime)s %(levelname)s %(role?)s %(message)s",
|
||||||
|
... defaults={"role": ""-""}
|
||||||
|
... )
|
||||||
|
>>> handler = logging.StreamHandler()
|
||||||
|
>>> handler.setFormatter(formatter)
|
||||||
|
>>> logger = logging.getLogger(__name__)
|
||||||
|
>>> logger.addHandler(handler)
|
||||||
|
>>>
|
||||||
|
>>> logger.chat("Hello there!", extra={"role": "USER"})
|
||||||
|
2025-01-15 10:30:00 CHAT USER Hello there!
|
||||||
|
>>>
|
||||||
|
>>> logger.info("A logging message")
|
||||||
|
2025-01-15 10:30:01 INFO - A logging message
|
||||||
|
|
||||||
|
.. note::
|
||||||
|
Only `%`-style formatting is supported. The `{` and `$` styles are not
|
||||||
|
compatible with this formatter.
|
||||||
|
|
||||||
|
.. seealso::
|
||||||
|
:class:`logging.Formatter` for base formatter documentation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Match %(name?)s or %(name?)d etc.
|
||||||
|
OPTIONAL_PATTERN = re.compile(r"%\((\w+)\?\)([sdifFeEgGxXocrba%])")
|
||||||
|
|
||||||
|
def __init__(self, fmt=None, datefmt=None, style="%", defaults=None):
|
||||||
|
self.defaults = defaults or {}
|
||||||
|
|
||||||
|
self.optional_fields = set(self.OPTIONAL_PATTERN.findall(fmt or ""))
|
||||||
|
|
||||||
|
# Convert %(name?)s to %(name)s for standard formatting
|
||||||
|
normalized_fmt = self.OPTIONAL_PATTERN.sub(r"%(\1)\2", fmt or "")
|
||||||
|
|
||||||
|
super().__init__(normalized_fmt, datefmt, style)
|
||||||
|
|
||||||
|
def format(self, record):
|
||||||
|
for field, _ in self.optional_fields:
|
||||||
|
if not hasattr(record, field):
|
||||||
|
default = self.defaults.get(field, None)
|
||||||
|
setattr(record, field, default)
|
||||||
|
|
||||||
|
return super().format(record)
|
||||||
10
src/control_backend/logging/partial_filter.py
Normal file
10
src/control_backend/logging/partial_filter.py
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
import logging
|
||||||
|
|
||||||
|
|
||||||
|
class PartialFilter(logging.Filter):
|
||||||
|
"""
|
||||||
|
Class to filter any log records that have the "partial" attribute set to ``True``.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def filter(self, record):
|
||||||
|
return getattr(record, "partial", False) is not True
|
||||||
@@ -37,7 +37,7 @@ def add_logging_level(level_name: str, level_num: int, method_name: str | None =
|
|||||||
setattr(logging, method_name, log_to_root)
|
setattr(logging, method_name, log_to_root)
|
||||||
|
|
||||||
|
|
||||||
def setup_logging(path: str = ".logging_config.yaml") -> None:
|
def setup_logging(path: str = settings.logging_settings.logging_config_file) -> None:
|
||||||
"""
|
"""
|
||||||
Setup logging configuration of the CB. Tries to load the logging configuration from a file,
|
Setup logging configuration of the CB. Tries to load the logging configuration from a file,
|
||||||
in which we specify custom loggers, formatters, handlers, etc.
|
in which we specify custom loggers, formatters, handlers, etc.
|
||||||
@@ -65,7 +65,7 @@ def setup_logging(path: str = ".logging_config.yaml") -> None:
|
|||||||
|
|
||||||
# Patch ZMQ PUBHandler to know about custom levels
|
# Patch ZMQ PUBHandler to know about custom levels
|
||||||
if custom_levels:
|
if custom_levels:
|
||||||
for logger_name in ("control_backend",):
|
for logger_name in config.get("loggers", {}):
|
||||||
logger = logging.getLogger(logger_name)
|
logger = logging.getLogger(logger_name)
|
||||||
for handler in logger.handlers:
|
for handler in logger.handlers:
|
||||||
if isinstance(handler, PUBHandler):
|
if isinstance(handler, PUBHandler):
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
from unittest.mock import patch
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI, HTTPException
|
||||||
from fastapi.testclient import TestClient
|
from fastapi.testclient import TestClient
|
||||||
from starlette.responses import StreamingResponse
|
from starlette.responses import StreamingResponse
|
||||||
|
|
||||||
@@ -61,3 +61,67 @@ async def test_log_stream_endpoint_lines(client):
|
|||||||
# Optional: assert subscribe/connect were called
|
# Optional: assert subscribe/connect were called
|
||||||
assert dummy_socket.subscribed # at least some log levels subscribed
|
assert dummy_socket.subscribed # at least some log levels subscribed
|
||||||
assert dummy_socket.connected # connect was called
|
assert dummy_socket.connected # connect was called
|
||||||
|
|
||||||
|
|
||||||
|
@patch("control_backend.api.v1.endpoints.logs.LOGGING_DIR")
|
||||||
|
def test_files_endpoint(LOGGING_DIR, client):
|
||||||
|
file_1, file_2 = MagicMock(), MagicMock()
|
||||||
|
file_1.name = "file_1"
|
||||||
|
file_2.name = "file_2"
|
||||||
|
LOGGING_DIR.glob.return_value = [file_1, file_2]
|
||||||
|
result = client.get("/api/logs/files")
|
||||||
|
|
||||||
|
assert result.status_code == 200
|
||||||
|
assert result.json() == ["file_1", "file_2"]
|
||||||
|
|
||||||
|
|
||||||
|
@patch("control_backend.api.v1.endpoints.logs.FileResponse")
|
||||||
|
@patch("control_backend.api.v1.endpoints.logs.LOGGING_DIR")
|
||||||
|
def test_log_file_endpoint_success(LOGGING_DIR, MockFileResponse, client):
|
||||||
|
mock_file_path = MagicMock()
|
||||||
|
mock_file_path.is_relative_to.return_value = True
|
||||||
|
mock_file_path.is_file.return_value = True
|
||||||
|
mock_file_path.name = "test.log"
|
||||||
|
|
||||||
|
LOGGING_DIR.__truediv__ = MagicMock(return_value=mock_file_path)
|
||||||
|
mock_file_path.resolve.return_value = mock_file_path
|
||||||
|
|
||||||
|
MockFileResponse.return_value = MagicMock()
|
||||||
|
|
||||||
|
result = client.get("/api/logs/files/test.log")
|
||||||
|
|
||||||
|
assert result.status_code == 200
|
||||||
|
MockFileResponse.assert_called_once_with(mock_file_path, filename="test.log")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@patch("control_backend.api.v1.endpoints.logs.LOGGING_DIR")
|
||||||
|
async def test_log_file_endpoint_path_traversal(LOGGING_DIR):
|
||||||
|
from control_backend.api.v1.endpoints.logs import log_file
|
||||||
|
|
||||||
|
mock_file_path = MagicMock()
|
||||||
|
mock_file_path.is_relative_to.return_value = False
|
||||||
|
|
||||||
|
LOGGING_DIR.__truediv__ = MagicMock(return_value=mock_file_path)
|
||||||
|
mock_file_path.resolve.return_value = mock_file_path
|
||||||
|
|
||||||
|
with pytest.raises(HTTPException) as exc_info:
|
||||||
|
await log_file("../secret.txt")
|
||||||
|
|
||||||
|
assert exc_info.value.status_code == 400
|
||||||
|
assert exc_info.value.detail == "Invalid filename."
|
||||||
|
|
||||||
|
|
||||||
|
@patch("control_backend.api.v1.endpoints.logs.LOGGING_DIR")
|
||||||
|
def test_log_file_endpoint_file_not_found(LOGGING_DIR, client):
|
||||||
|
mock_file_path = MagicMock()
|
||||||
|
mock_file_path.is_relative_to.return_value = True
|
||||||
|
mock_file_path.is_file.return_value = False
|
||||||
|
|
||||||
|
LOGGING_DIR.__truediv__ = MagicMock(return_value=mock_file_path)
|
||||||
|
mock_file_path.resolve.return_value = mock_file_path
|
||||||
|
|
||||||
|
result = client.get("/api/logs/files/nonexistent.log")
|
||||||
|
|
||||||
|
assert result.status_code == 404
|
||||||
|
assert result.json()["detail"] == "File not found."
|
||||||
|
|||||||
45
test/unit/logging/test_dated_file_handler.py
Normal file
45
test/unit/logging/test_dated_file_handler.py
Normal file
@@ -0,0 +1,45 @@
|
|||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from control_backend.logging.dated_file_handler import DatedFileHandler
|
||||||
|
|
||||||
|
|
||||||
|
@patch("control_backend.logging.dated_file_handler.DatedFileHandler._open")
|
||||||
|
def test_reset(open_):
|
||||||
|
stream = MagicMock()
|
||||||
|
open_.return_value = stream
|
||||||
|
|
||||||
|
# A file should be opened when the logger is created
|
||||||
|
handler = DatedFileHandler(file_prefix="anything")
|
||||||
|
assert open_.call_count == 1
|
||||||
|
|
||||||
|
# Upon reset, the current file should be closed, and a new one should be opened
|
||||||
|
handler.do_rollover()
|
||||||
|
assert stream.close.call_count == 1
|
||||||
|
assert open_.call_count == 2
|
||||||
|
|
||||||
|
|
||||||
|
@patch("control_backend.logging.dated_file_handler.Path")
|
||||||
|
@patch("control_backend.logging.dated_file_handler.DatedFileHandler._open")
|
||||||
|
def test_creates_dir(open_, Path_):
|
||||||
|
stream = MagicMock()
|
||||||
|
open_.return_value = stream
|
||||||
|
|
||||||
|
test_path = MagicMock()
|
||||||
|
test_path.parent.is_dir.return_value = False
|
||||||
|
Path_.return_value = test_path
|
||||||
|
|
||||||
|
DatedFileHandler(file_prefix="anything")
|
||||||
|
|
||||||
|
# The directory should've been created
|
||||||
|
test_path.parent.mkdir.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
@patch("control_backend.logging.dated_file_handler.DatedFileHandler._open")
|
||||||
|
def test_invalid_constructor(_):
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
DatedFileHandler(file_prefix=None)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
DatedFileHandler(file_prefix="")
|
||||||
218
test/unit/logging/test_optional_field_formatter.py
Normal file
218
test/unit/logging/test_optional_field_formatter.py
Normal file
@@ -0,0 +1,218 @@
|
|||||||
|
import logging
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from control_backend.logging.optional_field_formatter import OptionalFieldFormatter
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def logger():
|
||||||
|
"""Create a fresh logger for each test."""
|
||||||
|
logger = logging.getLogger(f"test_{id(object())}")
|
||||||
|
logger.setLevel(logging.DEBUG)
|
||||||
|
logger.handlers = []
|
||||||
|
return logger
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def log_output(logger):
|
||||||
|
"""Capture log output and return a function to get it."""
|
||||||
|
|
||||||
|
class ListHandler(logging.Handler):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.records = []
|
||||||
|
|
||||||
|
def emit(self, record):
|
||||||
|
self.records.append(self.format(record))
|
||||||
|
|
||||||
|
handler = ListHandler()
|
||||||
|
logger.addHandler(handler)
|
||||||
|
|
||||||
|
def get_output():
|
||||||
|
return handler.records
|
||||||
|
|
||||||
|
return get_output
|
||||||
|
|
||||||
|
|
||||||
|
def test_optional_field_present(logger, log_output):
|
||||||
|
"""Optional field should appear when provided in extra."""
|
||||||
|
formatter = OptionalFieldFormatter("%(levelname)s - %(role?)s - %(message)s")
|
||||||
|
logger.handlers[0].setFormatter(formatter)
|
||||||
|
|
||||||
|
logger.info("test message", extra={"role": "user"})
|
||||||
|
|
||||||
|
assert log_output() == ["INFO - user - test message"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_optional_field_missing_no_default(logger, log_output):
|
||||||
|
"""Missing optional field with no default should be None."""
|
||||||
|
formatter = OptionalFieldFormatter("%(levelname)s - %(role?)s - %(message)s")
|
||||||
|
logger.handlers[0].setFormatter(formatter)
|
||||||
|
|
||||||
|
logger.info("test message")
|
||||||
|
|
||||||
|
assert log_output() == ["INFO - None - test message"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_optional_field_missing_with_default(logger, log_output):
|
||||||
|
"""Missing optional field should use provided default."""
|
||||||
|
formatter = OptionalFieldFormatter(
|
||||||
|
"%(levelname)s - %(role?)s - %(message)s", defaults={"role": "assistant"}
|
||||||
|
)
|
||||||
|
logger.handlers[0].setFormatter(formatter)
|
||||||
|
|
||||||
|
logger.info("test message")
|
||||||
|
|
||||||
|
assert log_output() == ["INFO - assistant - test message"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_optional_field_overrides_default(logger, log_output):
|
||||||
|
"""Provided extra value should override default."""
|
||||||
|
formatter = OptionalFieldFormatter(
|
||||||
|
"%(levelname)s - %(role?)s - %(message)s", defaults={"role": "assistant"}
|
||||||
|
)
|
||||||
|
logger.handlers[0].setFormatter(formatter)
|
||||||
|
|
||||||
|
logger.info("test message", extra={"role": "user"})
|
||||||
|
|
||||||
|
assert log_output() == ["INFO - user - test message"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_multiple_optional_fields(logger, log_output):
|
||||||
|
"""Multiple optional fields should work independently."""
|
||||||
|
formatter = OptionalFieldFormatter(
|
||||||
|
"%(levelname)s - %(role?)s - %(request_id?)s - %(message)s", defaults={"role": "assistant"}
|
||||||
|
)
|
||||||
|
logger.handlers[0].setFormatter(formatter)
|
||||||
|
|
||||||
|
logger.info("test", extra={"request_id": "123"})
|
||||||
|
|
||||||
|
assert log_output() == ["INFO - assistant - 123 - test"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_mixed_optional_and_required_fields(logger, log_output):
|
||||||
|
"""Standard fields should work alongside optional fields."""
|
||||||
|
formatter = OptionalFieldFormatter("%(levelname)s %(name)s %(role?)s %(message)s")
|
||||||
|
logger.handlers[0].setFormatter(formatter)
|
||||||
|
|
||||||
|
logger.info("test", extra={"role": "user"})
|
||||||
|
|
||||||
|
output = log_output()[0]
|
||||||
|
assert "INFO" in output
|
||||||
|
assert "user" in output
|
||||||
|
assert "test" in output
|
||||||
|
|
||||||
|
|
||||||
|
def test_no_optional_fields(logger, log_output):
|
||||||
|
"""Formatter should work normally with no optional fields."""
|
||||||
|
formatter = OptionalFieldFormatter("%(levelname)s %(message)s")
|
||||||
|
logger.handlers[0].setFormatter(formatter)
|
||||||
|
|
||||||
|
logger.info("test message")
|
||||||
|
|
||||||
|
assert log_output() == ["INFO test message"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_integer_format_specifier(logger, log_output):
|
||||||
|
"""Optional fields with %d specifier should work."""
|
||||||
|
formatter = OptionalFieldFormatter(
|
||||||
|
"%(levelname)s %(count?)d %(message)s", defaults={"count": 0}
|
||||||
|
)
|
||||||
|
logger.handlers[0].setFormatter(formatter)
|
||||||
|
|
||||||
|
logger.info("test", extra={"count": 42})
|
||||||
|
|
||||||
|
assert log_output() == ["INFO 42 test"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_float_format_specifier(logger, log_output):
|
||||||
|
"""Optional fields with %f specifier should work."""
|
||||||
|
formatter = OptionalFieldFormatter(
|
||||||
|
"%(levelname)s %(duration?)f %(message)s", defaults={"duration": 0.0}
|
||||||
|
)
|
||||||
|
logger.handlers[0].setFormatter(formatter)
|
||||||
|
|
||||||
|
logger.info("test", extra={"duration": 1.5})
|
||||||
|
|
||||||
|
assert "1.5" in log_output()[0]
|
||||||
|
|
||||||
|
|
||||||
|
def test_empty_string_default(logger, log_output):
|
||||||
|
"""Empty string default should work."""
|
||||||
|
formatter = OptionalFieldFormatter("%(levelname)s %(role?)s %(message)s", defaults={"role": ""})
|
||||||
|
logger.handlers[0].setFormatter(formatter)
|
||||||
|
|
||||||
|
logger.info("test")
|
||||||
|
|
||||||
|
assert log_output() == ["INFO test"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_none_format_string():
|
||||||
|
"""None format string should not raise."""
|
||||||
|
formatter = OptionalFieldFormatter(fmt=None)
|
||||||
|
assert formatter.optional_fields == set()
|
||||||
|
|
||||||
|
|
||||||
|
def test_optional_fields_parsed_correctly():
|
||||||
|
"""Check that optional fields are correctly identified."""
|
||||||
|
formatter = OptionalFieldFormatter("%(asctime)s %(role?)s %(level?)d %(name)s")
|
||||||
|
|
||||||
|
assert formatter.optional_fields == {("role", "s"), ("level", "d")}
|
||||||
|
|
||||||
|
|
||||||
|
def test_format_string_normalized():
|
||||||
|
"""Check that ? is removed from format string."""
|
||||||
|
formatter = OptionalFieldFormatter("%(role?)s %(message)s")
|
||||||
|
|
||||||
|
assert "?" not in formatter._style._fmt
|
||||||
|
assert "%(role)s" in formatter._style._fmt
|
||||||
|
|
||||||
|
|
||||||
|
def test_field_with_underscore(logger, log_output):
|
||||||
|
"""Field names with underscores should work."""
|
||||||
|
formatter = OptionalFieldFormatter("%(levelname)s %(user_id?)s %(message)s")
|
||||||
|
logger.handlers[0].setFormatter(formatter)
|
||||||
|
|
||||||
|
logger.info("test", extra={"user_id": "abc123"})
|
||||||
|
|
||||||
|
assert log_output() == ["INFO abc123 test"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_field_with_numbers(logger, log_output):
|
||||||
|
"""Field names with numbers should work."""
|
||||||
|
formatter = OptionalFieldFormatter("%(levelname)s %(field2?)s %(message)s")
|
||||||
|
logger.handlers[0].setFormatter(formatter)
|
||||||
|
|
||||||
|
logger.info("test", extra={"field2": "value"})
|
||||||
|
|
||||||
|
assert log_output() == ["INFO value test"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_multiple_log_calls(logger, log_output):
|
||||||
|
"""Formatter should work correctly across multiple log calls."""
|
||||||
|
formatter = OptionalFieldFormatter(
|
||||||
|
"%(levelname)s %(role?)s %(message)s", defaults={"role": "other"}
|
||||||
|
)
|
||||||
|
logger.handlers[0].setFormatter(formatter)
|
||||||
|
|
||||||
|
logger.info("first", extra={"role": "assistant"})
|
||||||
|
logger.info("second")
|
||||||
|
logger.info("third", extra={"role": "user"})
|
||||||
|
|
||||||
|
assert log_output() == [
|
||||||
|
"INFO assistant first",
|
||||||
|
"INFO other second",
|
||||||
|
"INFO user third",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def test_default_not_mutated(logger, log_output):
|
||||||
|
"""Original defaults dict should not be mutated."""
|
||||||
|
defaults = {"role": "other"}
|
||||||
|
formatter = OptionalFieldFormatter("%(levelname)s %(role?)s %(message)s", defaults=defaults)
|
||||||
|
logger.handlers[0].setFormatter(formatter)
|
||||||
|
|
||||||
|
logger.info("test")
|
||||||
|
|
||||||
|
assert defaults == {"role": "other"}
|
||||||
83
test/unit/logging/test_partial_filter.py
Normal file
83
test/unit/logging/test_partial_filter.py
Normal file
@@ -0,0 +1,83 @@
|
|||||||
|
import logging
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from control_backend.logging import PartialFilter
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def logger():
|
||||||
|
"""Create a fresh logger for each test."""
|
||||||
|
logger = logging.getLogger(f"test_{id(object())}")
|
||||||
|
logger.setLevel(logging.DEBUG)
|
||||||
|
logger.handlers = []
|
||||||
|
return logger
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def log_output(logger):
|
||||||
|
"""Capture log output and return a function to get it."""
|
||||||
|
|
||||||
|
class ListHandler(logging.Handler):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.records = []
|
||||||
|
|
||||||
|
def emit(self, record):
|
||||||
|
self.records.append(self.format(record))
|
||||||
|
|
||||||
|
handler = ListHandler()
|
||||||
|
handler.addFilter(PartialFilter())
|
||||||
|
handler.setFormatter(logging.Formatter("%(message)s"))
|
||||||
|
logger.addHandler(handler)
|
||||||
|
|
||||||
|
return lambda: handler.records
|
||||||
|
|
||||||
|
|
||||||
|
def test_no_partial_attribute(logger, log_output):
|
||||||
|
"""Records without partial attribute should pass through."""
|
||||||
|
logger.info("normal message")
|
||||||
|
|
||||||
|
assert log_output() == ["normal message"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_partial_true_filtered(logger, log_output):
|
||||||
|
"""Records with partial=True should be filtered out."""
|
||||||
|
logger.info("partial message", extra={"partial": True})
|
||||||
|
|
||||||
|
assert log_output() == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_partial_false_passes(logger, log_output):
|
||||||
|
"""Records with partial=False should pass through."""
|
||||||
|
logger.info("complete message", extra={"partial": False})
|
||||||
|
|
||||||
|
assert log_output() == ["complete message"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_partial_none_passes(logger, log_output):
|
||||||
|
"""Records with partial=None should pass through."""
|
||||||
|
logger.info("message", extra={"partial": None})
|
||||||
|
|
||||||
|
assert log_output() == ["message"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_partial_truthy_value_passes(logger, log_output):
|
||||||
|
"""
|
||||||
|
Records with truthy but non-True partial should pass through, that is, only when it's exactly
|
||||||
|
``True`` should it pass.
|
||||||
|
"""
|
||||||
|
logger.info("message", extra={"partial": "yes"})
|
||||||
|
|
||||||
|
assert log_output() == ["message"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_multiple_records_mixed(logger, log_output):
|
||||||
|
"""Filter should handle mixed records correctly."""
|
||||||
|
logger.info("first")
|
||||||
|
logger.info("second", extra={"partial": True})
|
||||||
|
logger.info("third", extra={"partial": False})
|
||||||
|
logger.info("fourth", extra={"partial": True})
|
||||||
|
logger.info("fifth")
|
||||||
|
|
||||||
|
assert log_output() == ["first", "third", "fifth"]
|
||||||
Reference in New Issue
Block a user