Merge branch 'dev' into feat/visual-emotion-recognition
@@ -1,5 +1,5 @@
|
||||
import json
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
import zmq
|
||||
@@ -19,6 +19,12 @@ def zmq_context(mocker):
|
||||
return mock_context
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_experiment_logger():
|
||||
with patch("control_backend.agents.actuation.robot_gesture_agent.experiment_logger") as logger:
|
||||
yield logger
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_setup_bind(zmq_context, mocker):
|
||||
"""Setup binds and subscribes to internal commands."""
|
||||
|
||||
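The same autouse experiment_logger patch is repeated in several test modules below. A minimal sketch of how a shared factory in a conftest.py could remove that duplication (the helper name and location are assumptions, not part of this change):

from unittest.mock import patch

import pytest


def make_experiment_logger_fixture(target: str):
    """Return an autouse fixture that patches `experiment_logger` at the given import path."""

    @pytest.fixture(autouse=True)
    def mock_experiment_logger():
        with patch(target) as logger:
            yield logger

    return mock_experiment_logger


# Hypothetical usage in one of the test modules:
# mock_experiment_logger = make_experiment_logger_fixture(
#     "control_backend.agents.actuation.robot_gesture_agent.experiment_logger"
# )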
@@ -26,6 +26,12 @@ def agent():
|
||||
return agent
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_experiment_logger():
|
||||
with patch("control_backend.agents.bdi.bdi_core_agent.experiment_logger") as logger:
|
||||
yield logger
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_setup_loads_asl(mock_agentspeak_env, agent):
|
||||
# Mock file opening
|
||||
|
||||
@@ -8,7 +8,17 @@ import pytest
|
||||
|
||||
from control_backend.agents.bdi.bdi_program_manager import BDIProgramManager
|
||||
from control_backend.core.agent_system import InternalMessage
|
||||
from control_backend.schemas.program import BasicNorm, Goal, Phase, Plan, Program
|
||||
from control_backend.schemas.program import (
|
||||
BasicNorm,
|
||||
ConditionalNorm,
|
||||
Goal,
|
||||
InferredBelief,
|
||||
KeywordBelief,
|
||||
Phase,
|
||||
Plan,
|
||||
Program,
|
||||
Trigger,
|
||||
)
|
||||
|
||||
# Fix Windows Proactor loop for zmq
|
||||
if sys.platform.startswith("win"):
|
||||
@@ -59,7 +69,7 @@ async def test_create_agentspeak_and_send_to_bdi(mock_settings):
|
||||
await manager._create_agentspeak_and_send_to_bdi(program)
|
||||
|
||||
# Check file writing
|
||||
mock_file.assert_called_with("src/control_backend/agents/bdi/agentspeak.asl", "w")
|
||||
mock_file.assert_called_with(mock_settings.behaviour_settings.agentspeak_file, "w")
|
||||
handle = mock_file()
|
||||
handle.write.assert_called()
|
||||
|
||||
@@ -67,7 +77,7 @@ async def test_create_agentspeak_and_send_to_bdi(mock_settings):
|
||||
msg: InternalMessage = manager.send.await_args[0][0]
|
||||
assert msg.thread == "new_program"
|
||||
assert msg.to == mock_settings.agent_settings.bdi_core_name
|
||||
assert msg.body == "src/control_backend/agents/bdi/agentspeak.asl"
|
||||
assert msg.body == mock_settings.behaviour_settings.agentspeak_file
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -295,3 +305,98 @@ async def test_setup(mock_settings):
|
||||
|
||||
# 3. Adds behavior
|
||||
manager.add_behavior.assert_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_program_to_user_interrupt(mock_settings):
|
||||
"""Test directly sending the program to the user interrupt agent."""
|
||||
mock_settings.agent_settings.user_interrupt_name = "user_interrupt_agent"
|
||||
|
||||
manager = BDIProgramManager(name="program_manager_test")
|
||||
manager.send = AsyncMock()
|
||||
|
||||
program = Program.model_validate_json(make_valid_program_json())
|
||||
|
||||
await manager._send_program_to_user_interrupt(program)
|
||||
|
||||
assert manager.send.await_count == 1
|
||||
msg = manager.send.await_args[0][0]
|
||||
assert msg.to == "user_interrupt_agent"
|
||||
assert msg.thread == "new_program"
|
||||
assert "Basic Phase" in msg.body
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_complex_program_extraction():
|
||||
manager = BDIProgramManager(name="program_manager_test")
|
||||
|
||||
# 1. Create Complex Components
|
||||
|
||||
# Inferred Belief (A & B)
|
||||
belief_left = KeywordBelief(id=uuid.uuid4(), name="b1", keyword="hot")
|
||||
belief_right = KeywordBelief(id=uuid.uuid4(), name="b2", keyword="sunny")
|
||||
inferred_belief = InferredBelief(
|
||||
id=uuid.uuid4(), name="b_inf", operator="AND", left=belief_left, right=belief_right
|
||||
)
|
||||
|
||||
# Conditional Norm
|
||||
cond_norm = ConditionalNorm(
|
||||
id=uuid.uuid4(), name="norm_cond", norm="wear_hat", condition=inferred_belief
|
||||
)
|
||||
|
||||
# Trigger with Inferred Belief condition
|
||||
dummy_plan = Plan(id=uuid.uuid4(), name="dummy_plan", steps=[])
|
||||
trigger = Trigger(id=uuid.uuid4(), name="trigger_1", condition=inferred_belief, plan=dummy_plan)
|
||||
|
||||
# Nested Goal
|
||||
sub_goal = Goal(
|
||||
id=uuid.uuid4(),
|
||||
name="sub_goal",
|
||||
description="desc",
|
||||
plan=Plan(id=uuid.uuid4(), name="empty", steps=[]),
|
||||
can_fail=True,
|
||||
)
|
||||
|
||||
parent_goal = Goal(
|
||||
id=uuid.uuid4(),
|
||||
name="parent_goal",
|
||||
description="desc",
|
||||
# The plan contains the sub_goal as a step
|
||||
plan=Plan(id=uuid.uuid4(), name="parent_plan", steps=[sub_goal]),
|
||||
can_fail=False,
|
||||
)
|
||||
|
||||
# 2. Assemble Program
|
||||
phase = Phase(
|
||||
id=uuid.uuid4(),
|
||||
name="Complex Phase",
|
||||
norms=[cond_norm],
|
||||
goals=[parent_goal],
|
||||
triggers=[trigger],
|
||||
)
|
||||
program = Program(phases=[phase])
|
||||
|
||||
# 3. Initialize Internal State (Triggers _populate_goal_mapping -> Nested Goal logic)
|
||||
manager._initialize_internal_state(program)
|
||||
|
||||
# Assertion for Line 53-54 (Mapping population)
|
||||
# Both parent and sub-goal should be mapped
|
||||
assert str(parent_goal.id) in manager._goal_mapping
|
||||
assert str(sub_goal.id) in manager._goal_mapping
|
||||
|
||||
# 4. Test Belief Extraction (Triggers lines 132-133, 142-146)
|
||||
beliefs = manager._extract_current_beliefs()
|
||||
|
||||
# Should extract recursive beliefs from cond_norm and trigger
|
||||
# The inferred belief splits into left + right; since it is used twice we get duplicates,
# so checking existence is enough.
|
||||
belief_names = [b.name for b in beliefs]
|
||||
assert "b1" in belief_names
|
||||
assert "b2" in belief_names
|
||||
|
||||
# 5. Test Goal Extraction (Triggers lines 173, 185)
|
||||
goals = manager._extract_current_goals()
|
||||
|
||||
goal_names = [g.name for g in goals]
|
||||
assert "parent_goal" in goal_names
|
||||
assert "sub_goal" in goal_names
|
||||
|
||||
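test_complex_program_extraction drives recursive traversal of InferredBelief trees and of goals nested inside plan steps. The production methods are not shown in this diff; a minimal sketch of the kind of recursion the assertions imply, with hypothetical helper names:

def collect_leaf_beliefs(belief):
    """Flatten an InferredBelief tree into its KeywordBelief leaves."""
    if hasattr(belief, "left") and hasattr(belief, "right"):  # an InferredBelief node
        return collect_leaf_beliefs(belief.left) + collect_leaf_beliefs(belief.right)
    return [belief]  # a KeywordBelief leaf


def collect_goals(goal):
    """Return a goal together with every sub-goal nested in its plan steps."""
    goals = [goal]
    for step in goal.plan.steps:
        if hasattr(step, "plan"):  # the step is itself a Goal rather than a plain action
            goals.extend(collect_goals(step))
    return goals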
@@ -18,6 +18,12 @@ def mock_httpx_client():
|
||||
yield mock_client
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_experiment_logger():
|
||||
with patch("control_backend.agents.llm.llm_agent.experiment_logger") as logger:
|
||||
yield logger
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_llm_processing_success(mock_httpx_client, mock_settings):
|
||||
# Setup the mock response for the stream
|
||||
|
||||
@@ -14,6 +14,15 @@ from control_backend.agents.perception.transcription_agent.transcription_agent i
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_experiment_logger():
|
||||
with patch(
|
||||
"control_backend.agents.perception"
|
||||
".transcription_agent.transcription_agent.experiment_logger"
|
||||
) as logger:
|
||||
yield logger
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_transcription_agent_flow(mock_zmq_context):
|
||||
mock_sub = MagicMock()
|
||||
|
||||
@@ -24,7 +24,9 @@ def audio_out_socket():
|
||||
|
||||
@pytest.fixture
|
||||
def vad_agent(audio_out_socket):
|
||||
return VADAgent("tcp://localhost:5555", False)
|
||||
agent = VADAgent("tcp://localhost:5555", False)
|
||||
agent._internal_pub_socket = AsyncMock()
|
||||
return agent
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
@@ -44,6 +46,12 @@ def patch_settings(monkeypatch):
|
||||
monkeypatch.setattr(vad_agent.settings.vad_settings, "sample_rate_hz", 16_000, raising=False)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_experiment_logger():
|
||||
with patch("control_backend.agents.perception.vad_agent.experiment_logger") as logger:
|
||||
yield logger
|
||||
|
||||
|
||||
async def simulate_streaming_with_probabilities(streaming, probabilities: list[float]):
|
||||
"""
|
||||
Simulates a streaming scenario with given VAD model probabilities for testing purposes.
|
||||
@@ -84,14 +92,15 @@ async def test_voice_activity_detected(audio_out_socket, vad_agent):
|
||||
Test a scenario where there is voice activity detected between silences.
|
||||
"""
|
||||
speech_chunk_count = 5
|
||||
probabilities = [0.0] * 5 + [1.0] * speech_chunk_count + [0.0] * 5
|
||||
begin_silence_chunks = settings.behaviour_settings.vad_begin_silence_chunks
|
||||
probabilities = [0.0] * 15 + [1.0] * speech_chunk_count + [0.0] * 5
|
||||
vad_agent.audio_out_socket = audio_out_socket
|
||||
await simulate_streaming_with_probabilities(vad_agent, probabilities)
|
||||
|
||||
audio_out_socket.send.assert_called_once()
|
||||
data = audio_out_socket.send.call_args[0][0]
|
||||
assert isinstance(data, bytes)
|
||||
assert len(data) == 512 * 4 * (speech_chunk_count + 1)
|
||||
assert len(data) == 512 * 4 * (begin_silence_chunks + speech_chunk_count)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -101,8 +110,9 @@ async def test_voice_activity_short_pause(audio_out_socket, vad_agent):
|
||||
short pause.
|
||||
"""
|
||||
speech_chunk_count = 5
|
||||
begin_silence_chunks = settings.behaviour_settings.vad_begin_silence_chunks
|
||||
probabilities = (
|
||||
[0.0] * 5 + [1.0] * speech_chunk_count + [0.0] + [1.0] * speech_chunk_count + [0.0] * 5
|
||||
[0.0] * 15 + [1.0] * speech_chunk_count + [0.0] + [1.0] * speech_chunk_count + [0.0] * 5
|
||||
)
|
||||
vad_agent.audio_out_socket = audio_out_socket
|
||||
await simulate_streaming_with_probabilities(vad_agent, probabilities)
|
||||
@@ -110,8 +120,8 @@ async def test_voice_activity_short_pause(audio_out_socket, vad_agent):
|
||||
audio_out_socket.send.assert_called_once()
|
||||
data = audio_out_socket.send.call_args[0][0]
|
||||
assert isinstance(data, bytes)
|
||||
# Expecting 13 chunks (2*5 with speech, 1 pause between, 1 as padding)
|
||||
assert len(data) == 512 * 4 * (speech_chunk_count * 2 + 1 + 1)
|
||||
# Expecting 2 * speech_chunk_count speech chunks, 1 pause between, and begin_silence_chunks of padding
|
||||
assert len(data) == 512 * 4 * (speech_chunk_count * 2 + 1 + begin_silence_chunks)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
||||
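The updated assertions size the buffered audio as 512 samples per chunk times 4 bytes per sample times (begin_silence_chunks of leading padding plus the speech chunks). A quick arithmetic check with an assumed padding value (the real number comes from settings.behaviour_settings):

CHUNK_SAMPLES = 512        # samples per VAD chunk
BYTES_PER_SAMPLE = 4       # float32 audio
speech_chunk_count = 5
begin_silence_chunks = 10  # assumed for illustration only

expected_bytes = CHUNK_SAMPLES * BYTES_PER_SAMPLE * (begin_silence_chunks + speech_chunk_count)
print(expected_bytes)  # 2048 * 15 = 30720 with these assumed numbers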
@@ -1,12 +1,13 @@
|
||||
import asyncio
|
||||
import json
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from control_backend.agents.user_interrupt.user_interrupt_agent import UserInterruptAgent
|
||||
from control_backend.core.agent_system import InternalMessage
|
||||
from control_backend.core.config import settings
|
||||
from control_backend.schemas.belief_message import BeliefMessage
|
||||
from control_backend.schemas.program import (
|
||||
ConditionalNorm,
|
||||
Goal,
|
||||
@@ -29,6 +30,14 @@ def agent():
|
||||
return agent
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_experiment_logger():
|
||||
with patch(
|
||||
"control_backend.agents.user_interrupt.user_interrupt_agent.experiment_logger"
|
||||
) as logger:
|
||||
yield logger
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_to_speech_agent(agent):
|
||||
"""Verify speech command format."""
|
||||
@@ -309,3 +318,375 @@ async def test_send_pause_command(agent):
|
||||
m for m in agent.send.call_args_list if m.args[0].to == settings.agent_settings.vad_name
|
||||
).args[0]
|
||||
assert vad_msg.body == "RESUME"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_setup(agent):
|
||||
"""Test the setup method initializes sockets correctly."""
|
||||
with patch("control_backend.agents.user_interrupt.user_interrupt_agent.Context") as MockContext:
|
||||
mock_ctx_instance = MagicMock()
|
||||
MockContext.instance.return_value = mock_ctx_instance
|
||||
|
||||
mock_sub = MagicMock()
|
||||
mock_pub = MagicMock()
|
||||
mock_ctx_instance.socket.side_effect = [mock_sub, mock_pub]
|
||||
|
||||
# MOCK add_behavior so we don't rely on internal attributes
|
||||
agent.add_behavior = MagicMock()
|
||||
|
||||
await agent.setup()
|
||||
|
||||
# Check sockets
|
||||
mock_sub.connect.assert_called_with(settings.zmq_settings.internal_sub_address)
|
||||
mock_pub.connect.assert_called_with(settings.zmq_settings.internal_pub_address)
|
||||
|
||||
# Verify add_behavior was called
|
||||
agent.add_behavior.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_receive_loop_json_error(agent):
|
||||
"""Verify that malformed JSON is caught and logged without crashing the loop."""
|
||||
agent.sub_socket.recv_multipart.side_effect = [
|
||||
(b"topic", b"INVALID{JSON"),
|
||||
asyncio.CancelledError,
|
||||
]
|
||||
|
||||
try:
|
||||
await agent._receive_button_event()
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
agent.logger.error.assert_called_with("Received invalid JSON payload on topic %s", b"topic")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_receive_loop_override_trigger(agent):
|
||||
"""Verify routing 'override' to a Trigger."""
|
||||
agent._trigger_map["101"] = "trigger_slug"
|
||||
payload = json.dumps({"type": "override", "context": "101"}).encode()
|
||||
|
||||
agent.sub_socket.recv_multipart.side_effect = [(b"topic", payload), asyncio.CancelledError]
|
||||
agent._send_to_bdi = AsyncMock()
|
||||
|
||||
try:
|
||||
await agent._receive_button_event()
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
agent._send_to_bdi.assert_awaited_once_with("force_trigger", "trigger_slug")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_receive_loop_override_norm(agent):
|
||||
"""Verify routing 'override' to a Conditional Norm."""
|
||||
agent._cond_norm_map["202"] = "norm_slug"
|
||||
payload = json.dumps({"type": "override", "context": "202"}).encode()
|
||||
|
||||
agent.sub_socket.recv_multipart.side_effect = [(b"topic", payload), asyncio.CancelledError]
|
||||
agent._send_to_bdi_belief = AsyncMock()
|
||||
|
||||
try:
|
||||
await agent._receive_button_event()
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
agent._send_to_bdi_belief.assert_awaited_once_with("norm_slug", "cond_norm")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_receive_loop_override_missing(agent):
|
||||
"""Verify warning log when an override ID is not found in any map."""
|
||||
payload = json.dumps({"type": "override", "context": "999"}).encode()
|
||||
|
||||
agent.sub_socket.recv_multipart.side_effect = [(b"topic", payload), asyncio.CancelledError]
|
||||
|
||||
try:
|
||||
await agent._receive_button_event()
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
agent.logger.warning.assert_called_with("Could not determine which element to override.")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_receive_loop_unachieve_logic(agent):
|
||||
"""Verify success and failure paths for override_unachieve."""
|
||||
agent._cond_norm_map["202"] = "norm_slug"
|
||||
|
||||
success_payload = json.dumps({"type": "override_unachieve", "context": "202"}).encode()
|
||||
fail_payload = json.dumps({"type": "override_unachieve", "context": "999"}).encode()
|
||||
|
||||
agent.sub_socket.recv_multipart.side_effect = [
|
||||
(b"topic", success_payload),
|
||||
(b"topic", fail_payload),
|
||||
asyncio.CancelledError,
|
||||
]
|
||||
agent._send_to_bdi_belief = AsyncMock()
|
||||
|
||||
try:
|
||||
await agent._receive_button_event()
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
# Assert success call (True flag for unachieve)
|
||||
agent._send_to_bdi_belief.assert_any_call("norm_slug", "cond_norm", True)
|
||||
# Assert failure log
|
||||
agent.logger.warning.assert_called_with(
|
||||
"Could not determine which conditional norm to unachieve."
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_receive_loop_pause_resume(agent):
|
||||
"""Verify pause and resume toggle logic and logging."""
|
||||
pause_payload = json.dumps({"type": "pause", "context": "true"}).encode()
|
||||
resume_payload = json.dumps({"type": "pause", "context": ""}).encode()
|
||||
|
||||
agent.sub_socket.recv_multipart.side_effect = [
|
||||
(b"topic", pause_payload),
|
||||
(b"topic", resume_payload),
|
||||
asyncio.CancelledError,
|
||||
]
|
||||
agent._send_pause_command = AsyncMock()
|
||||
|
||||
try:
|
||||
await agent._receive_button_event()
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
agent._send_pause_command.assert_any_call("true")
|
||||
agent._send_pause_command.assert_any_call("")
|
||||
agent.logger.info.assert_any_call("Sent pause command.")
|
||||
agent.logger.info.assert_any_call("Sent resume command.")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_receive_loop_phase_control(agent):
|
||||
"""Verify experiment flow control (next_phase)."""
|
||||
payload = json.dumps({"type": "next_phase", "context": ""}).encode()
|
||||
|
||||
agent.sub_socket.recv_multipart.side_effect = [(b"topic", payload), asyncio.CancelledError]
|
||||
agent._send_experiment_control_to_bdi_core = AsyncMock()
|
||||
|
||||
try:
|
||||
await agent._receive_button_event()
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
agent._send_experiment_control_to_bdi_core.assert_awaited_once_with("next_phase")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_message_unknown_thread(agent):
|
||||
"""Test handling of an unknown message thread (lines 213-214)."""
|
||||
msg = InternalMessage(to="me", thread="unknown_thread", body="test")
|
||||
await agent.handle_message(msg)
|
||||
|
||||
agent.logger.debug.assert_called_with(
|
||||
"Received internal message on unhandled thread: unknown_thread"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_to_bdi_belief_edge_cases(agent):
|
||||
"""
|
||||
Covers:
|
||||
- Unknown asl_type warning (lines 326-328)
|
||||
- unachieve=True logic (lines 334-337)
|
||||
"""
|
||||
# 1. Unknown Type
|
||||
await agent._send_to_bdi_belief("slug", "unknown_type")
|
||||
|
||||
agent.logger.warning.assert_called_with("Tried to send belief with unknown type")
|
||||
agent.send.assert_not_called()
|
||||
|
||||
# Reset mock for part 2
|
||||
agent.send.reset_mock()
|
||||
|
||||
# 2. Unachieve = True
|
||||
await agent._send_to_bdi_belief("slug", "cond_norm", unachieve=True)
|
||||
|
||||
agent.send.assert_awaited()
|
||||
sent_msg = agent.send.call_args.args[0]
|
||||
|
||||
# Verify it is a delete operation
|
||||
body_obj = BeliefMessage.model_validate_json(sent_msg.body)
|
||||
|
||||
# Verify 'delete' has content
|
||||
assert body_obj.delete is not None
|
||||
assert len(body_obj.delete) == 1
|
||||
assert body_obj.delete[0].name == "force_slug"
|
||||
|
||||
# Verify 'create' is empty (handling both None and [])
|
||||
assert not body_obj.create
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_experiment_control_unknown(agent):
|
||||
"""Test sending an unknown experiment control type (lines 366-367)."""
|
||||
await agent._send_experiment_control_to_bdi_core("invalid_command")
|
||||
|
||||
agent.logger.warning.assert_called_with(
|
||||
"Received unknown experiment control type '%s' to send to BDI Core.", "invalid_command"
|
||||
)
|
||||
|
||||
# Ensure it still sends an empty message (as per code logic, though thread is empty)
|
||||
agent.send.assert_awaited()
|
||||
msg = agent.send.call_args[0][0]
|
||||
assert msg.thread == ""
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_mapping_recursive_goals(agent):
|
||||
"""Verify that nested subgoals are correctly registered in the mapping."""
|
||||
import uuid
|
||||
|
||||
# 1. Setup IDs
|
||||
parent_goal_id = uuid.uuid4()
|
||||
child_goal_id = uuid.uuid4()
|
||||
|
||||
# 2. Create the child goal
|
||||
child_goal = Goal(
|
||||
id=child_goal_id,
|
||||
name="child_goal",
|
||||
description="I am a subgoal",
|
||||
plan=Plan(id=uuid.uuid4(), name="p_child", steps=[]),
|
||||
)
|
||||
|
||||
# 3. Create the parent goal and put the child goal inside its plan steps
|
||||
parent_goal = Goal(
|
||||
id=parent_goal_id,
|
||||
name="parent_goal",
|
||||
description="I am a parent",
|
||||
plan=Plan(id=uuid.uuid4(), name="p_parent", steps=[child_goal]), # Nested here
|
||||
)
|
||||
|
||||
# 4. Build the program
|
||||
phase = Phase(
|
||||
id=uuid.uuid4(),
|
||||
name="phase1",
|
||||
norms=[],
|
||||
goals=[parent_goal], # Only the parent is top-level
|
||||
triggers=[],
|
||||
)
|
||||
prog = Program(phases=[phase])
|
||||
|
||||
# 5. Execute mapping
|
||||
msg = InternalMessage(to="me", thread="new_program", body=prog.model_dump_json())
|
||||
await agent.handle_message(msg)
|
||||
|
||||
# 6. Assertions
|
||||
# Check parent
|
||||
assert str(parent_goal_id) in agent._goal_map
|
||||
assert agent._goal_map[str(parent_goal_id)] == "parent_goal"
|
||||
|
||||
# Check child (This confirms the recursion worked)
|
||||
assert str(child_goal_id) in agent._goal_map
|
||||
assert agent._goal_map[str(child_goal_id)] == "child_goal"
|
||||
assert agent._goal_reverse_map["child_goal"] == str(child_goal_id)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_receive_loop_advanced_scenarios(agent):
|
||||
"""
|
||||
Covers:
|
||||
- JSONDecodeError (lines 86-88)
|
||||
- Override: Trigger found (lines 108-109)
|
||||
- Override: Norm found (lines 114-115)
|
||||
- Override: Nothing found (line 134)
|
||||
- Override Unachieve: Success & Fail (lines 136-145)
|
||||
- Pause: Context true/false logs (lines 150-157)
|
||||
- Next Phase (line 160)
|
||||
"""
|
||||
# 1. Setup Data Maps
|
||||
agent._trigger_map["101"] = "trigger_slug"
|
||||
agent._cond_norm_map["202"] = "norm_slug"
|
||||
|
||||
# 2. Define Payloads
|
||||
# A. Invalid JSON
|
||||
bad_json = b"INVALID{JSON"
|
||||
|
||||
# B. Override -> Trigger
|
||||
override_trigger = json.dumps({"type": "override", "context": "101"}).encode()
|
||||
|
||||
# C. Override -> Norm
|
||||
override_norm = json.dumps({"type": "override", "context": "202"}).encode()
|
||||
|
||||
# D. Override -> Unknown
|
||||
override_fail = json.dumps({"type": "override", "context": "999"}).encode()
|
||||
|
||||
# E. Unachieve -> Success
|
||||
unachieve_success = json.dumps({"type": "override_unachieve", "context": "202"}).encode()
|
||||
|
||||
# F. Unachieve -> Fail
|
||||
unachieve_fail = json.dumps({"type": "override_unachieve", "context": "999"}).encode()
|
||||
|
||||
# G. Pause (True)
|
||||
pause_true = json.dumps({"type": "pause", "context": "true"}).encode()
|
||||
|
||||
# H. Pause (False/Resume)
|
||||
pause_false = json.dumps({"type": "pause", "context": ""}).encode()
|
||||
|
||||
# I. Next Phase
|
||||
next_phase = json.dumps({"type": "next_phase", "context": ""}).encode()
|
||||
|
||||
# 3. Setup Socket
|
||||
agent.sub_socket.recv_multipart.side_effect = [
|
||||
(b"topic", bad_json),
|
||||
(b"topic", override_trigger),
|
||||
(b"topic", override_norm),
|
||||
(b"topic", override_fail),
|
||||
(b"topic", unachieve_success),
|
||||
(b"topic", unachieve_fail),
|
||||
(b"topic", pause_true),
|
||||
(b"topic", pause_false),
|
||||
(b"topic", next_phase),
|
||||
asyncio.CancelledError, # End loop
|
||||
]
|
||||
|
||||
# Mock internal helpers to verify calls
|
||||
agent._send_to_bdi = AsyncMock()
|
||||
agent._send_to_bdi_belief = AsyncMock()
|
||||
agent._send_pause_command = AsyncMock()
|
||||
agent._send_experiment_control_to_bdi_core = AsyncMock()
|
||||
|
||||
# 4. Run Loop
|
||||
try:
|
||||
await agent._receive_button_event()
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
# 5. Assertions
|
||||
|
||||
# JSON Error
|
||||
agent.logger.error.assert_called_with("Received invalid JSON payload on topic %s", b"topic")
|
||||
|
||||
# Override Trigger
|
||||
agent._send_to_bdi.assert_awaited_with("force_trigger", "trigger_slug")
|
||||
|
||||
# Override Norm
|
||||
# We expect _send_to_bdi_belief to be called for the norm
|
||||
# Note: The loop calls _send_to_bdi_belief(asl_cond_norm, "cond_norm")
|
||||
agent._send_to_bdi_belief.assert_any_call("norm_slug", "cond_norm")
|
||||
|
||||
# Override Fail (Warning log)
|
||||
agent.logger.warning.assert_any_call("Could not determine which element to override.")
|
||||
|
||||
# Unachieve Success
|
||||
# Loop calls _send_to_bdi_belief(asl_cond_norm, "cond_norm", True)
|
||||
agent._send_to_bdi_belief.assert_any_call("norm_slug", "cond_norm", True)
|
||||
|
||||
# Unachieve Fail
|
||||
agent.logger.warning.assert_any_call("Could not determine which conditional norm to unachieve.")
|
||||
|
||||
# Pause Logic
|
||||
agent._send_pause_command.assert_any_call("true")
|
||||
agent.logger.info.assert_any_call("Sent pause command.")
|
||||
|
||||
# Resume Logic
|
||||
agent._send_pause_command.assert_any_call("")
|
||||
agent.logger.info.assert_any_call("Sent resume command.")
|
||||
|
||||
# Next Phase
|
||||
agent._send_experiment_control_to_bdi_core.assert_awaited_with("next_phase")
|
||||
|
||||
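The scenario test above exercises a single receive loop that dispatches button events by their "type" field. A plausible shape of that loop, inferred only from the assertions (the map and helper names mirror the test; everything else is an assumption, not the project's actual code):

import json


async def receive_button_event(agent):
    """Hypothetical stand-in for UserInterruptAgent._receive_button_event."""
    while True:
        topic, payload = await agent.sub_socket.recv_multipart()
        try:
            event = json.loads(payload)
        except json.JSONDecodeError:
            agent.logger.error("Received invalid JSON payload on topic %s", topic)
            continue

        kind, context = event.get("type"), event.get("context")
        if kind == "override":
            if context in agent._trigger_map:
                await agent._send_to_bdi("force_trigger", agent._trigger_map[context])
            elif context in agent._cond_norm_map:
                await agent._send_to_bdi_belief(agent._cond_norm_map[context], "cond_norm")
            else:
                agent.logger.warning("Could not determine which element to override.")
        elif kind == "override_unachieve":
            if context in agent._cond_norm_map:
                await agent._send_to_bdi_belief(agent._cond_norm_map[context], "cond_norm", True)
            else:
                agent.logger.warning("Could not determine which conditional norm to unachieve.")
        elif kind == "pause":
            await agent._send_pause_command(context)
            agent.logger.info("Sent pause command." if context else "Sent resume command.")
        elif kind == "next_phase":
            await agent._send_experiment_control_to_bdi_core("next_phase")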
@@ -1,7 +1,7 @@
|
||||
from unittest.mock import patch
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi.testclient import TestClient
|
||||
from starlette.responses import StreamingResponse
|
||||
|
||||
@@ -61,3 +61,67 @@ async def test_log_stream_endpoint_lines(client):
|
||||
# Optional: assert subscribe/connect were called
|
||||
assert dummy_socket.subscribed # at least some log levels subscribed
|
||||
assert dummy_socket.connected # connect was called
|
||||
|
||||
|
||||
@patch("control_backend.api.v1.endpoints.logs.LOGGING_DIR")
|
||||
def test_files_endpoint(LOGGING_DIR, client):
|
||||
file_1, file_2 = MagicMock(), MagicMock()
|
||||
file_1.name = "file_1"
|
||||
file_2.name = "file_2"
|
||||
LOGGING_DIR.glob.return_value = [file_1, file_2]
|
||||
result = client.get("/api/logs/files")
|
||||
|
||||
assert result.status_code == 200
|
||||
assert result.json() == ["file_1", "file_2"]
|
||||
|
||||
|
||||
@patch("control_backend.api.v1.endpoints.logs.FileResponse")
|
||||
@patch("control_backend.api.v1.endpoints.logs.LOGGING_DIR")
|
||||
def test_log_file_endpoint_success(LOGGING_DIR, MockFileResponse, client):
|
||||
mock_file_path = MagicMock()
|
||||
mock_file_path.is_relative_to.return_value = True
|
||||
mock_file_path.is_file.return_value = True
|
||||
mock_file_path.name = "test.log"
|
||||
|
||||
LOGGING_DIR.__truediv__ = MagicMock(return_value=mock_file_path)
|
||||
mock_file_path.resolve.return_value = mock_file_path
|
||||
|
||||
MockFileResponse.return_value = MagicMock()
|
||||
|
||||
result = client.get("/api/logs/files/test.log")
|
||||
|
||||
assert result.status_code == 200
|
||||
MockFileResponse.assert_called_once_with(mock_file_path, filename="test.log")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("control_backend.api.v1.endpoints.logs.LOGGING_DIR")
|
||||
async def test_log_file_endpoint_path_traversal(LOGGING_DIR):
|
||||
from control_backend.api.v1.endpoints.logs import log_file
|
||||
|
||||
mock_file_path = MagicMock()
|
||||
mock_file_path.is_relative_to.return_value = False
|
||||
|
||||
LOGGING_DIR.__truediv__ = MagicMock(return_value=mock_file_path)
|
||||
mock_file_path.resolve.return_value = mock_file_path
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await log_file("../secret.txt")
|
||||
|
||||
assert exc_info.value.status_code == 400
|
||||
assert exc_info.value.detail == "Invalid filename."
|
||||
|
||||
|
||||
@patch("control_backend.api.v1.endpoints.logs.LOGGING_DIR")
|
||||
def test_log_file_endpoint_file_not_found(LOGGING_DIR, client):
|
||||
mock_file_path = MagicMock()
|
||||
mock_file_path.is_relative_to.return_value = True
|
||||
mock_file_path.is_file.return_value = False
|
||||
|
||||
LOGGING_DIR.__truediv__ = MagicMock(return_value=mock_file_path)
|
||||
mock_file_path.resolve.return_value = mock_file_path
|
||||
|
||||
result = client.get("/api/logs/files/nonexistent.log")
|
||||
|
||||
assert result.status_code == 404
|
||||
assert result.json()["detail"] == "File not found."
|
||||
|
||||
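These tests pin down the contract of the log-file endpoint: resolve the requested name under LOGGING_DIR, reject anything that escapes the directory with a 400, return 404 for missing files, and otherwise serve a FileResponse. A minimal sketch under those assumptions (route prefix and module layout are guesses):

from pathlib import Path

from fastapi import APIRouter, HTTPException
from fastapi.responses import FileResponse

LOGGING_DIR = Path("logs")
router = APIRouter()


@router.get("/files/{filename}")
async def log_file(filename: str):
    path = (LOGGING_DIR / filename).resolve()
    if not path.is_relative_to(LOGGING_DIR.resolve()):  # reject path traversal
        raise HTTPException(status_code=400, detail="Invalid filename.")
    if not path.is_file():
        raise HTTPException(status_code=404, detail="File not found.")
    return FileResponse(path, filename=path.name)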
@@ -94,3 +94,55 @@ async def test_experiment_stream_direct_call():
|
||||
mock_socket.connect.assert_called()
|
||||
mock_socket.subscribe.assert_called_with(b"experiment")
|
||||
mock_socket.close.assert_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_status_stream_direct_call():
|
||||
"""
|
||||
Test the status stream, ensuring it handles messages and sends pings on timeout.
|
||||
"""
|
||||
mock_socket = AsyncMock()
|
||||
|
||||
# Define the sequence of events for the socket:
|
||||
# 1. Successfully receive a message
|
||||
# 2. Timeout (which should trigger the ': ping' yield)
|
||||
# 3. Another message (which won't be reached because we'll simulate disconnect)
|
||||
mock_socket.recv_multipart.side_effect = [
|
||||
(b"topic", b"status_update"),
|
||||
TimeoutError(),
|
||||
(b"topic", b"ignored_msg"),
|
||||
]
|
||||
|
||||
mock_socket.close = MagicMock()
|
||||
mock_socket.connect = MagicMock()
|
||||
mock_socket.subscribe = MagicMock()
|
||||
|
||||
mock_context = MagicMock()
|
||||
mock_context.socket.return_value = mock_socket
|
||||
|
||||
# Mock the ZMQ Context to return our mock_socket
|
||||
with patch(
|
||||
"control_backend.api.v1.endpoints.user_interact.Context.instance", return_value=mock_context
|
||||
):
|
||||
mock_request = AsyncMock()
|
||||
|
||||
# is_disconnected sequence:
|
||||
# 1. False -> Process "status_update"
|
||||
# 2. False -> Process TimeoutError (yield ping)
|
||||
# 3. True -> Break loop (client disconnected)
|
||||
mock_request.is_disconnected.side_effect = [False, False, True]
|
||||
|
||||
# Call the status_stream function explicitly
|
||||
response = await user_interact.status_stream(mock_request)
|
||||
|
||||
lines = []
|
||||
async for line in response.body_iterator:
|
||||
lines.append(line)
|
||||
|
||||
# Assertions
|
||||
assert "data: status_update\n\n" in lines
|
||||
assert ": ping\n\n" in lines # Verify lines 91-92 (ping logic)
|
||||
|
||||
mock_socket.connect.assert_called()
|
||||
mock_socket.subscribe.assert_called_with(b"status")
|
||||
mock_socket.close.assert_called()
|
||||
|
||||
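The stream tests above assert an SSE-style contract: ZMQ payloads are forwarded as "data:" events, a comment ping is emitted when a receive times out, and the socket is closed once the client disconnects. A rough sketch of a generator with that behaviour (assumed, not the project's actual implementation):

async def sse_events(request, socket):
    """Yield SSE frames from a ZMQ SUB socket until the client disconnects."""
    try:
        while not await request.is_disconnected():
            try:
                _topic, payload = await socket.recv_multipart()
                yield f"data: {payload.decode()}\n\n"
            except TimeoutError:
                yield ": ping\n\n"  # keep-alive comment frame
    finally:
        socket.close()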
@@ -32,6 +32,7 @@ def mock_settings():
|
||||
mock.agent_settings.vad_name = "vad_agent"
|
||||
mock.behaviour_settings.sleep_s = 0.01 # Speed up tests
|
||||
mock.behaviour_settings.comm_setup_max_retries = 1
|
||||
mock.behaviour_settings.agentspeak_file = "src/control_backend/agents/bdi/agentspeak.asl"
|
||||
yield mock
|
||||
|
||||
|
||||
|
||||
test/unit/logging/test_dated_file_handler.py (new file, 45 lines)
@@ -0,0 +1,45 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from control_backend.logging.dated_file_handler import DatedFileHandler
|
||||
|
||||
|
||||
@patch("control_backend.logging.dated_file_handler.DatedFileHandler._open")
|
||||
def test_reset(open_):
|
||||
stream = MagicMock()
|
||||
open_.return_value = stream
|
||||
|
||||
# A file should be opened when the handler is created
|
||||
handler = DatedFileHandler(file_prefix="anything")
|
||||
assert open_.call_count == 1
|
||||
|
||||
# Upon reset, the current file should be closed, and a new one should be opened
|
||||
handler.do_rollover()
|
||||
assert stream.close.call_count == 1
|
||||
assert open_.call_count == 2
|
||||
|
||||
|
||||
@patch("control_backend.logging.dated_file_handler.Path")
|
||||
@patch("control_backend.logging.dated_file_handler.DatedFileHandler._open")
|
||||
def test_creates_dir(open_, Path_):
|
||||
stream = MagicMock()
|
||||
open_.return_value = stream
|
||||
|
||||
test_path = MagicMock()
|
||||
test_path.parent.is_dir.return_value = False
|
||||
Path_.return_value = test_path
|
||||
|
||||
DatedFileHandler(file_prefix="anything")
|
||||
|
||||
# The directory should've been created
|
||||
test_path.parent.mkdir.assert_called_once()
|
||||
|
||||
|
||||
@patch("control_backend.logging.dated_file_handler.DatedFileHandler._open")
|
||||
def test_invalid_constructor(_):
|
||||
with pytest.raises(ValueError):
|
||||
DatedFileHandler(file_prefix=None)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
DatedFileHandler(file_prefix="")
|
||||
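The new tests pin down three behaviours of DatedFileHandler: a file is opened on construction, do_rollover closes the current stream and opens a new one, the parent directory is created when missing, and an empty or None file_prefix raises ValueError. One way such a handler could look, sketched purely for orientation (file naming and all details are assumptions):

import logging
from datetime import date
from pathlib import Path


class DatedFileHandler(logging.FileHandler):
    """Sketch: write to <prefix>_<YYYY-MM-DD>.log and support manual rollover."""

    def __init__(self, file_prefix: str):
        if not file_prefix:
            raise ValueError("file_prefix must be a non-empty string")
        self.file_prefix = file_prefix
        path = Path(f"{file_prefix}_{date.today().isoformat()}.log")
        if not path.parent.is_dir():
            path.parent.mkdir(parents=True, exist_ok=True)
        super().__init__(path)  # FileHandler opens the stream via self._open()

    def do_rollover(self):
        # Close the current file and open a fresh one for the current date.
        if self.stream:
            self.stream.close()
        self.baseFilename = f"{self.file_prefix}_{date.today().isoformat()}.log"
        self.stream = self._open()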
test/unit/logging/test_optional_field_formatter.py (new file, 218 lines)
@@ -0,0 +1,218 @@
|
||||
import logging
|
||||
|
||||
import pytest
|
||||
|
||||
from control_backend.logging.optional_field_formatter import OptionalFieldFormatter
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def logger():
|
||||
"""Create a fresh logger for each test."""
|
||||
logger = logging.getLogger(f"test_{id(object())}")
|
||||
logger.setLevel(logging.DEBUG)
|
||||
logger.handlers = []
|
||||
return logger
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def log_output(logger):
|
||||
"""Capture log output and return a function to get it."""
|
||||
|
||||
class ListHandler(logging.Handler):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.records = []
|
||||
|
||||
def emit(self, record):
|
||||
self.records.append(self.format(record))
|
||||
|
||||
handler = ListHandler()
|
||||
logger.addHandler(handler)
|
||||
|
||||
def get_output():
|
||||
return handler.records
|
||||
|
||||
return get_output
|
||||
|
||||
|
||||
def test_optional_field_present(logger, log_output):
|
||||
"""Optional field should appear when provided in extra."""
|
||||
formatter = OptionalFieldFormatter("%(levelname)s - %(role?)s - %(message)s")
|
||||
logger.handlers[0].setFormatter(formatter)
|
||||
|
||||
logger.info("test message", extra={"role": "user"})
|
||||
|
||||
assert log_output() == ["INFO - user - test message"]
|
||||
|
||||
|
||||
def test_optional_field_missing_no_default(logger, log_output):
|
||||
"""Missing optional field with no default should be None."""
|
||||
formatter = OptionalFieldFormatter("%(levelname)s - %(role?)s - %(message)s")
|
||||
logger.handlers[0].setFormatter(formatter)
|
||||
|
||||
logger.info("test message")
|
||||
|
||||
assert log_output() == ["INFO - None - test message"]
|
||||
|
||||
|
||||
def test_optional_field_missing_with_default(logger, log_output):
|
||||
"""Missing optional field should use provided default."""
|
||||
formatter = OptionalFieldFormatter(
|
||||
"%(levelname)s - %(role?)s - %(message)s", defaults={"role": "assistant"}
|
||||
)
|
||||
logger.handlers[0].setFormatter(formatter)
|
||||
|
||||
logger.info("test message")
|
||||
|
||||
assert log_output() == ["INFO - assistant - test message"]
|
||||
|
||||
|
||||
def test_optional_field_overrides_default(logger, log_output):
|
||||
"""Provided extra value should override default."""
|
||||
formatter = OptionalFieldFormatter(
|
||||
"%(levelname)s - %(role?)s - %(message)s", defaults={"role": "assistant"}
|
||||
)
|
||||
logger.handlers[0].setFormatter(formatter)
|
||||
|
||||
logger.info("test message", extra={"role": "user"})
|
||||
|
||||
assert log_output() == ["INFO - user - test message"]
|
||||
|
||||
|
||||
def test_multiple_optional_fields(logger, log_output):
|
||||
"""Multiple optional fields should work independently."""
|
||||
formatter = OptionalFieldFormatter(
|
||||
"%(levelname)s - %(role?)s - %(request_id?)s - %(message)s", defaults={"role": "assistant"}
|
||||
)
|
||||
logger.handlers[0].setFormatter(formatter)
|
||||
|
||||
logger.info("test", extra={"request_id": "123"})
|
||||
|
||||
assert log_output() == ["INFO - assistant - 123 - test"]
|
||||
|
||||
|
||||
def test_mixed_optional_and_required_fields(logger, log_output):
|
||||
"""Standard fields should work alongside optional fields."""
|
||||
formatter = OptionalFieldFormatter("%(levelname)s %(name)s %(role?)s %(message)s")
|
||||
logger.handlers[0].setFormatter(formatter)
|
||||
|
||||
logger.info("test", extra={"role": "user"})
|
||||
|
||||
output = log_output()[0]
|
||||
assert "INFO" in output
|
||||
assert "user" in output
|
||||
assert "test" in output
|
||||
|
||||
|
||||
def test_no_optional_fields(logger, log_output):
|
||||
"""Formatter should work normally with no optional fields."""
|
||||
formatter = OptionalFieldFormatter("%(levelname)s %(message)s")
|
||||
logger.handlers[0].setFormatter(formatter)
|
||||
|
||||
logger.info("test message")
|
||||
|
||||
assert log_output() == ["INFO test message"]
|
||||
|
||||
|
||||
def test_integer_format_specifier(logger, log_output):
|
||||
"""Optional fields with %d specifier should work."""
|
||||
formatter = OptionalFieldFormatter(
|
||||
"%(levelname)s %(count?)d %(message)s", defaults={"count": 0}
|
||||
)
|
||||
logger.handlers[0].setFormatter(formatter)
|
||||
|
||||
logger.info("test", extra={"count": 42})
|
||||
|
||||
assert log_output() == ["INFO 42 test"]
|
||||
|
||||
|
||||
def test_float_format_specifier(logger, log_output):
|
||||
"""Optional fields with %f specifier should work."""
|
||||
formatter = OptionalFieldFormatter(
|
||||
"%(levelname)s %(duration?)f %(message)s", defaults={"duration": 0.0}
|
||||
)
|
||||
logger.handlers[0].setFormatter(formatter)
|
||||
|
||||
logger.info("test", extra={"duration": 1.5})
|
||||
|
||||
assert "1.5" in log_output()[0]
|
||||
|
||||
|
||||
def test_empty_string_default(logger, log_output):
|
||||
"""Empty string default should work."""
|
||||
formatter = OptionalFieldFormatter("%(levelname)s %(role?)s %(message)s", defaults={"role": ""})
|
||||
logger.handlers[0].setFormatter(formatter)
|
||||
|
||||
logger.info("test")
|
||||
|
||||
assert log_output() == ["INFO test"]
|
||||
|
||||
|
||||
def test_none_format_string():
|
||||
"""None format string should not raise."""
|
||||
formatter = OptionalFieldFormatter(fmt=None)
|
||||
assert formatter.optional_fields == set()
|
||||
|
||||
|
||||
def test_optional_fields_parsed_correctly():
|
||||
"""Check that optional fields are correctly identified."""
|
||||
formatter = OptionalFieldFormatter("%(asctime)s %(role?)s %(level?)d %(name)s")
|
||||
|
||||
assert formatter.optional_fields == {("role", "s"), ("level", "d")}
|
||||
|
||||
|
||||
def test_format_string_normalized():
|
||||
"""Check that ? is removed from format string."""
|
||||
formatter = OptionalFieldFormatter("%(role?)s %(message)s")
|
||||
|
||||
assert "?" not in formatter._style._fmt
|
||||
assert "%(role)s" in formatter._style._fmt
|
||||
|
||||
|
||||
def test_field_with_underscore(logger, log_output):
|
||||
"""Field names with underscores should work."""
|
||||
formatter = OptionalFieldFormatter("%(levelname)s %(user_id?)s %(message)s")
|
||||
logger.handlers[0].setFormatter(formatter)
|
||||
|
||||
logger.info("test", extra={"user_id": "abc123"})
|
||||
|
||||
assert log_output() == ["INFO abc123 test"]
|
||||
|
||||
|
||||
def test_field_with_numbers(logger, log_output):
|
||||
"""Field names with numbers should work."""
|
||||
formatter = OptionalFieldFormatter("%(levelname)s %(field2?)s %(message)s")
|
||||
logger.handlers[0].setFormatter(formatter)
|
||||
|
||||
logger.info("test", extra={"field2": "value"})
|
||||
|
||||
assert log_output() == ["INFO value test"]
|
||||
|
||||
|
||||
def test_multiple_log_calls(logger, log_output):
|
||||
"""Formatter should work correctly across multiple log calls."""
|
||||
formatter = OptionalFieldFormatter(
|
||||
"%(levelname)s %(role?)s %(message)s", defaults={"role": "other"}
|
||||
)
|
||||
logger.handlers[0].setFormatter(formatter)
|
||||
|
||||
logger.info("first", extra={"role": "assistant"})
|
||||
logger.info("second")
|
||||
logger.info("third", extra={"role": "user"})
|
||||
|
||||
assert log_output() == [
|
||||
"INFO assistant first",
|
||||
"INFO other second",
|
||||
"INFO user third",
|
||||
]
|
||||
|
||||
|
||||
def test_default_not_mutated(logger, log_output):
|
||||
"""Original defaults dict should not be mutated."""
|
||||
defaults = {"role": "other"}
|
||||
formatter = OptionalFieldFormatter("%(levelname)s %(role?)s %(message)s", defaults=defaults)
|
||||
logger.handlers[0].setFormatter(formatter)
|
||||
|
||||
logger.info("test")
|
||||
|
||||
assert defaults == {"role": "other"}
|
||||
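The formatter under test accepts placeholders like %(role?)s: the '?' marks a field as optional, the optional (name, conversion) pairs are exposed as optional_fields, the '?' is stripped from the stored format string, and missing values fall back to per-field defaults (or None) without mutating the caller's dict. A minimal sketch of that idea; it does not reproduce the whitespace collapsing implied by test_empty_string_default, and the real class will differ:

import logging
import re


class OptionalFieldFormatter(logging.Formatter):
    """Sketch: support %(field?)s placeholders with per-field defaults."""

    _OPTIONAL = re.compile(r"%\((\w+)\?\)([sdf])")

    def __init__(self, fmt=None, defaults=None, **kwargs):
        self.optional_fields = set(self._OPTIONAL.findall(fmt)) if fmt else set()
        self._defaults = dict(defaults or {})  # copy so the caller's dict is never mutated
        if fmt:
            fmt = self._OPTIONAL.sub(r"%(\1)\2", fmt)  # normalize: drop the '?' marker
        super().__init__(fmt, **kwargs)

    def format(self, record):
        for name, _conversion in self.optional_fields:
            if not hasattr(record, name):
                setattr(record, name, self._defaults.get(name))
        return super().format(record)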
test/unit/logging/test_partial_filter.py (new file, 83 lines)
@@ -0,0 +1,83 @@
|
||||
import logging
|
||||
|
||||
import pytest
|
||||
|
||||
from control_backend.logging import PartialFilter
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def logger():
|
||||
"""Create a fresh logger for each test."""
|
||||
logger = logging.getLogger(f"test_{id(object())}")
|
||||
logger.setLevel(logging.DEBUG)
|
||||
logger.handlers = []
|
||||
return logger
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def log_output(logger):
|
||||
"""Capture log output and return a function to get it."""
|
||||
|
||||
class ListHandler(logging.Handler):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.records = []
|
||||
|
||||
def emit(self, record):
|
||||
self.records.append(self.format(record))
|
||||
|
||||
handler = ListHandler()
|
||||
handler.addFilter(PartialFilter())
|
||||
handler.setFormatter(logging.Formatter("%(message)s"))
|
||||
logger.addHandler(handler)
|
||||
|
||||
return lambda: handler.records
|
||||
|
||||
|
||||
def test_no_partial_attribute(logger, log_output):
|
||||
"""Records without partial attribute should pass through."""
|
||||
logger.info("normal message")
|
||||
|
||||
assert log_output() == ["normal message"]
|
||||
|
||||
|
||||
def test_partial_true_filtered(logger, log_output):
|
||||
"""Records with partial=True should be filtered out."""
|
||||
logger.info("partial message", extra={"partial": True})
|
||||
|
||||
assert log_output() == []
|
||||
|
||||
|
||||
def test_partial_false_passes(logger, log_output):
|
||||
"""Records with partial=False should pass through."""
|
||||
logger.info("complete message", extra={"partial": False})
|
||||
|
||||
assert log_output() == ["complete message"]
|
||||
|
||||
|
||||
def test_partial_none_passes(logger, log_output):
|
||||
"""Records with partial=None should pass through."""
|
||||
logger.info("message", extra={"partial": None})
|
||||
|
||||
assert log_output() == ["message"]
|
||||
|
||||
|
||||
def test_partial_truthy_value_passes(logger, log_output):
|
||||
"""
|
||||
Records with a truthy but non-``True`` partial value should pass through; only a value that is
exactly ``True`` causes the record to be filtered out.
|
||||
"""
|
||||
logger.info("message", extra={"partial": "yes"})
|
||||
|
||||
assert log_output() == ["message"]
|
||||
|
||||
|
||||
def test_multiple_records_mixed(logger, log_output):
|
||||
"""Filter should handle mixed records correctly."""
|
||||
logger.info("first")
|
||||
logger.info("second", extra={"partial": True})
|
||||
logger.info("third", extra={"partial": False})
|
||||
logger.info("fourth", extra={"partial": True})
|
||||
logger.info("fifth")
|
||||
|
||||
assert log_output() == ["first", "third", "fifth"]
|
||||
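These tests fix the PartialFilter contract: a record is dropped only when its partial attribute is exactly True; missing, falsy, or merely truthy values all pass through. A sketch of a filter with that behaviour (the real implementation may differ):

import logging


class PartialFilter(logging.Filter):
    """Sketch: suppress records explicitly marked as partial output."""

    def filter(self, record: logging.LogRecord) -> bool:
        # Keep the record unless it was logged with extra={"partial": True}.
        return getattr(record, "partial", None) is not True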