feat: visual emotion recognition agent
This commit is contained in:
@@ -10,8 +10,6 @@ from unittest.mock import ANY, AsyncMock, MagicMock, patch
|
||||
import pytest
|
||||
|
||||
from control_backend.agents.communication.ri_communication_agent import RICommunicationAgent
|
||||
from control_backend.core.agent_system import InternalMessage
|
||||
from control_backend.schemas.ri_message import PauseCommand, RIEndpoint
|
||||
|
||||
|
||||
def speech_agent_path():
|
||||
@@ -402,38 +400,3 @@ async def test_negotiate_req_socket_none_causes_retry(zmq_context):
|
||||
result = await agent._negotiate_connection(max_retries=1)
|
||||
|
||||
assert result is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_message_pause_command(zmq_context):
|
||||
"""Test handle_message with a valid PauseCommand."""
|
||||
agent = RICommunicationAgent("ri_comm")
|
||||
agent._req_socket = AsyncMock()
|
||||
agent.logger = MagicMock()
|
||||
|
||||
agent._req_socket.recv_json.return_value = {"status": "ok"}
|
||||
|
||||
pause_cmd = PauseCommand(data=True)
|
||||
msg = InternalMessage(to="ri_comm", sender="user_int", body=pause_cmd.model_dump_json())
|
||||
|
||||
await agent.handle_message(msg)
|
||||
|
||||
agent._req_socket.send_json.assert_awaited_once()
|
||||
args = agent._req_socket.send_json.await_args[0][0]
|
||||
assert args["endpoint"] == RIEndpoint.PAUSE.value
|
||||
assert args["data"] is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_message_invalid_pause_command(zmq_context):
|
||||
"""Test handle_message with invalid JSON."""
|
||||
agent = RICommunicationAgent("ri_comm")
|
||||
agent._req_socket = AsyncMock()
|
||||
agent.logger = MagicMock()
|
||||
|
||||
msg = InternalMessage(to="ri_comm", sender="user_int", body="invalid json")
|
||||
|
||||
await agent.handle_message(msg)
|
||||
|
||||
agent.logger.warning.assert_called_with("Incorrect message format for PauseCommand.")
|
||||
agent._req_socket.send_json.assert_not_called()
|
||||
|
||||
@@ -0,0 +1,338 @@
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import zmq
|
||||
from pydantic_core import ValidationError
|
||||
|
||||
# Adjust the import path to match your project structure
|
||||
from control_backend.agents.perception.visual_emotion_recognition_agent.visual_emotion_recognition_agent import ( # noqa
|
||||
VisualEmotionRecognitionAgent,
|
||||
)
|
||||
from control_backend.core.agent_system import InternalMessage
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
@pytest.fixture
|
||||
def mock_settings():
|
||||
with patch("control_backend.agents.perception.visual_emotion_recognition_agent.visual_emotion_recognition_agent.settings") as mock: # noqa
|
||||
# Set default values required by the agent
|
||||
mock.behaviour_settings.visual_emotion_recognition_window_duration_s = 5
|
||||
mock.behaviour_settings.visual_emotion_recognition_min_frames_per_face = 3
|
||||
mock.agent_settings.bdi_core_name = "bdi_core_agent"
|
||||
mock.agent_settings.user_interrupt_name = "user_interrupt_agent"
|
||||
yield mock
|
||||
|
||||
@pytest.fixture
|
||||
def mock_deepface():
|
||||
with patch("control_backend.agents.perception.visual_emotion_recognition_agent.visual_emotion_recognition_agent.DeepFaceEmotionRecognizer") as mock: # noqa
|
||||
instance = mock.return_value
|
||||
instance.sorted_dominant_emotions.return_value = []
|
||||
yield instance
|
||||
|
||||
@pytest.fixture
|
||||
def mock_zmq_context():
|
||||
with patch("zmq.asyncio.Context.instance") as mock_ctx:
|
||||
mock_socket = MagicMock()
|
||||
# Mock socket methods to return None or AsyncMock for async methods
|
||||
mock_socket.bind = MagicMock()
|
||||
mock_socket.connect = MagicMock()
|
||||
mock_socket.setsockopt = MagicMock()
|
||||
mock_socket.setsockopt_string = MagicMock()
|
||||
mock_socket.recv_multipart = AsyncMock()
|
||||
mock_socket.close = MagicMock()
|
||||
|
||||
mock_ctx.return_value.socket.return_value = mock_socket
|
||||
yield mock_ctx
|
||||
|
||||
@pytest.fixture
|
||||
def agent(mock_settings, mock_deepface, mock_zmq_context):
|
||||
# Initialize agent with specific params to control testing
|
||||
agent = VisualEmotionRecognitionAgent(
|
||||
name="test_agent",
|
||||
socket_address="tcp://localhost:5555",
|
||||
bind=False,
|
||||
timeout_ms=100,
|
||||
window_duration=2,
|
||||
min_frames_required=2
|
||||
)
|
||||
# Mock the internal send method from BaseAgent
|
||||
agent.send = AsyncMock()
|
||||
# Mock the add_behavior method from BaseAgent
|
||||
agent.add_behavior = MagicMock()
|
||||
# Mock the logger
|
||||
agent.logger = MagicMock()
|
||||
|
||||
return agent
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Tests
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_initialization(agent):
|
||||
"""Test that the agent initializes with correct attributes."""
|
||||
assert agent.name == "test_agent"
|
||||
assert agent.socket_address == "tcp://localhost:5555"
|
||||
assert agent.socket_bind is False
|
||||
assert agent.timeout_ms == 100
|
||||
assert agent._paused.is_set()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_setup_connect(agent, mock_zmq_context, mock_deepface):
|
||||
"""Test setup routine when binding is False (connect)."""
|
||||
agent.socket_bind = False
|
||||
await agent.setup()
|
||||
|
||||
socket = agent.video_in_socket
|
||||
socket.connect.assert_called_with("tcp://localhost:5555")
|
||||
socket.bind.assert_not_called()
|
||||
socket.setsockopt.assert_any_call(zmq.RCVHWM, 3)
|
||||
socket.setsockopt.assert_any_call(zmq.RCVTIMEO, 100)
|
||||
agent.add_behavior.assert_called_once()
|
||||
assert agent.emotion_recognizer == mock_deepface
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_setup_bind(agent, mock_zmq_context):
|
||||
"""Test setup routine when binding is True."""
|
||||
agent.socket_bind = True
|
||||
await agent.setup()
|
||||
|
||||
socket = agent.video_in_socket
|
||||
socket.bind.assert_called_with("tcp://localhost:5555")
|
||||
socket.connect.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_emotion_update_loop_normal_flow(agent, mock_deepface):
|
||||
"""
|
||||
Test the main loop logic:
|
||||
1. Receive frames
|
||||
2. Aggregate stats
|
||||
3. Trigger window update
|
||||
4. Call update_emotions
|
||||
"""
|
||||
# Setup dependencies
|
||||
await agent.setup()
|
||||
agent._running = True
|
||||
|
||||
# Create fake image data (10x10 pixels)
|
||||
width, height = 10, 10
|
||||
image_bytes = np.zeros((10, 10, 3), dtype=np.uint8).tobytes()
|
||||
w_bytes = width.to_bytes(4, 'little')
|
||||
h_bytes = height.to_bytes(4, 'little')
|
||||
|
||||
# Mock ZMQ receive to return data 3 times, then stop the loop
|
||||
# We use a side_effect on recv_multipart to simulate frames and then stop the loop
|
||||
async def recv_side_effect():
|
||||
if agent._running:
|
||||
return w_bytes, h_bytes, image_bytes
|
||||
raise asyncio.CancelledError()
|
||||
|
||||
agent.video_in_socket.recv_multipart.side_effect = recv_side_effect
|
||||
|
||||
# Mock DeepFace to return emotions
|
||||
# Frame 1: Happy
|
||||
# Frame 2: Happy
|
||||
# Frame 3: Happy (Trigger window)
|
||||
mock_deepface.sorted_dominant_emotions.side_effect = [
|
||||
["happy"],
|
||||
["happy"],
|
||||
["happy"]
|
||||
]
|
||||
|
||||
# Mock update_emotions to verify it's called
|
||||
agent.update_emotions = AsyncMock()
|
||||
|
||||
# Mock time.time to simulate window passage
|
||||
# We need time to advance significantly after the frames are collected
|
||||
start_time = time.time()
|
||||
|
||||
with patch("time.time") as mock_time:
|
||||
# Sequence of time calls:
|
||||
# 1. Init next_window_time calculation
|
||||
# 2. Loop 1 check
|
||||
# 3. Loop 2 check
|
||||
# 4. Loop 3 check (Make this one pass the window threshold)
|
||||
mock_time.side_effect = [
|
||||
start_time, # init
|
||||
start_time + 0.1, # frame 1 check
|
||||
start_time + 0.2, # frame 2 check
|
||||
start_time + 10.0, # frame 3 check (triggers window reset)
|
||||
start_time + 10.1, # next init
|
||||
start_time + 10.2 # break loop
|
||||
]
|
||||
|
||||
# We need to manually break the infinite loop after the update
|
||||
# We can do this by wrapping update_emotions to set _running = False
|
||||
async def stop_loop(*args, **kwargs):
|
||||
agent._running = False
|
||||
|
||||
agent.update_emotions.side_effect = stop_loop
|
||||
|
||||
# Run the loop
|
||||
await agent.emotion_update_loop()
|
||||
|
||||
# Verifications
|
||||
assert agent.update_emotions.called
|
||||
# Check that it detected 'happy' as dominant (2 required, 3 found)
|
||||
call_args = agent.update_emotions.call_args
|
||||
assert call_args is not None
|
||||
# args: (prev_emotions, window_dominant_emotions)
|
||||
assert call_args[0][1] == {"happy"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_emotion_update_loop_insufficient_frames(agent, mock_deepface):
|
||||
"""Test that emotions are NOT updated if min_frames_required is not met."""
|
||||
await agent.setup()
|
||||
agent._running = True
|
||||
agent.min_frames_required = 5 # Set high requirement
|
||||
|
||||
width, height = 10, 10
|
||||
image_bytes = np.zeros((10, 10, 3), dtype=np.uint8).tobytes()
|
||||
w_bytes = width.to_bytes(4, 'little')
|
||||
h_bytes = height.to_bytes(4, 'little')
|
||||
|
||||
agent.video_in_socket.recv_multipart.return_value = (w_bytes, h_bytes, image_bytes)
|
||||
mock_deepface.sorted_dominant_emotions.return_value = ["sad"]
|
||||
|
||||
agent.update_emotions = AsyncMock()
|
||||
|
||||
with patch("time.time") as mock_time:
|
||||
# Time setup to trigger window processing immediately
|
||||
mock_time.side_effect = [0, 100, 101]
|
||||
|
||||
# Stop loop after first pass
|
||||
async def stop_loop(*args, **kwargs):
|
||||
agent._running = False
|
||||
agent.update_emotions.side_effect = stop_loop
|
||||
|
||||
await agent.emotion_update_loop()
|
||||
|
||||
# It should call update_emotions with EMPTY set because min frames (5) > detected (1)
|
||||
call_args = agent.update_emotions.call_args
|
||||
assert call_args[0][1] == set()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_emotion_update_loop_zmq_again_and_exception(agent):
|
||||
"""Test that the loop handles ZMQ timeouts (Again) and generic exceptions."""
|
||||
await agent.setup()
|
||||
agent._running = True
|
||||
|
||||
# Side effect:
|
||||
# 1. Raise ZMQ Again (Timeout) -> should log warning
|
||||
# 2. Raise Generic Exception -> should log error
|
||||
# 3. Raise CancelledError -> stop loop (simulating stop)
|
||||
agent.video_in_socket.recv_multipart.side_effect = [
|
||||
zmq.Again(),
|
||||
RuntimeError("Random Failure"),
|
||||
asyncio.CancelledError() # To break loop cleanly
|
||||
]
|
||||
|
||||
# We need to ensure the loop doesn't block on _paused
|
||||
agent._paused.set()
|
||||
|
||||
# Run loop
|
||||
try:
|
||||
await agent.emotion_update_loop()
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_emotions_logic(agent, mock_settings):
|
||||
"""Test the logic for calculating diffs and sending messages."""
|
||||
agent.name = "viz_agent"
|
||||
|
||||
# Case 1: No change
|
||||
await agent.update_emotions({"happy"}, {"happy"})
|
||||
agent.send.assert_not_called()
|
||||
|
||||
# Case 2: Remove 'happy', Add 'sad'
|
||||
await agent.update_emotions({"happy"}, {"sad"})
|
||||
|
||||
assert agent.send.called
|
||||
call_args = agent.send.call_args
|
||||
msg = call_args[0][0] # InternalMessage object
|
||||
|
||||
assert msg.to == mock_settings.agent_settings.bdi_core_name
|
||||
assert msg.sender == "viz_agent"
|
||||
assert msg.thread == "beliefs"
|
||||
|
||||
payload = json.loads(msg.body)
|
||||
|
||||
# Check Created Beliefs
|
||||
assert len(payload["create"]) == 1
|
||||
assert payload["create"][0]["name"] == "emotion_detected"
|
||||
assert payload["create"][0]["arguments"] == ["sad"]
|
||||
|
||||
# Check Deleted Beliefs
|
||||
assert len(payload["delete"]) == 1
|
||||
assert payload["delete"][0]["name"] == "emotion_detected"
|
||||
assert payload["delete"][0]["arguments"] == ["happy"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_emotions_validation_error(agent):
|
||||
"""Test that ValidationErrors during Belief creation are caught."""
|
||||
|
||||
# We patch Belief to raise ValidationError
|
||||
with patch("control_backend.agents.perception.visual_emotion_recognition_agent.visual_emotion_recognition_agent.Belief") as MockBelief: # noqa
|
||||
MockBelief.side_effect = ValidationError.from_exception_data("Simulated Error", [])
|
||||
|
||||
# Try to update emotions
|
||||
await agent.update_emotions(prev_emotions={"happy"}, emotions={"sad"})
|
||||
|
||||
# Verify empty payload is sent (or payload with valid ones if mixed)
|
||||
# In this case both failed, so payload lists should be empty
|
||||
assert agent.send.called
|
||||
msg = agent.send.call_args[0][0]
|
||||
payload = json.loads(msg.body)
|
||||
assert payload["create"] == []
|
||||
assert payload["delete"] == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_message(agent, mock_settings):
|
||||
"""Test message handling for Pause/Resume."""
|
||||
|
||||
# Setup
|
||||
ui_name = mock_settings.agent_settings.user_interrupt_name
|
||||
|
||||
# 1. PAUSE message
|
||||
msg_pause = InternalMessage(to="me", sender=ui_name, body="PAUSE")
|
||||
await agent.handle_message(msg_pause)
|
||||
assert not agent._paused.is_set() # Should be cleared (paused)
|
||||
agent.logger.info.assert_called_with("Pausing Visual Emotion Recognition processing.")
|
||||
|
||||
# 2. RESUME message
|
||||
msg_resume = InternalMessage(to="me", sender=ui_name, body="RESUME")
|
||||
await agent.handle_message(msg_resume)
|
||||
assert agent._paused.is_set() # Should be set (running)
|
||||
|
||||
# 3. Unknown command
|
||||
msg_unknown = InternalMessage(to="me", sender=ui_name, body="DANCE")
|
||||
await agent.handle_message(msg_unknown)
|
||||
|
||||
# 4. Unknown sender
|
||||
msg_random = InternalMessage(to="me", sender="random_guy", body="PAUSE")
|
||||
await agent.handle_message(msg_random)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stop(agent, mock_zmq_context):
|
||||
"""Test the stop method cleans up resources."""
|
||||
# We need to mock super().stop(). Since we can't easily patch super(),
|
||||
# and the provided BaseAgent code shows stop() just sets _running and cancels tasks,
|
||||
# we can rely on the fact that VisualEmotionRecognitionAgent calls it.
|
||||
|
||||
# However, since we provided a 'agent' fixture that mocks things, we should verify specific cleanups. # noqa
|
||||
await agent.setup()
|
||||
|
||||
with patch("control_backend.agents.BaseAgent.stop", new_callable=AsyncMock) as mock_super_stop:
|
||||
await agent.stop()
|
||||
|
||||
# Verify socket closed
|
||||
agent.video_in_socket.close.assert_called_once()
|
||||
# Verify parent stop called
|
||||
mock_super_stop.assert_called_once()
|
||||
@@ -305,26 +305,30 @@ async def test_send_experiment_control(agent):
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_pause_command(agent):
|
||||
# --- Test PAUSE ---
|
||||
await agent._send_pause_command("true")
|
||||
# Sends to RI and VAD
|
||||
assert agent.send.await_count == 2
|
||||
msgs = [call.args[0] for call in agent.send.call_args_list]
|
||||
|
||||
ri_msg = next(m for m in msgs if m.to == settings.agent_settings.ri_communication_name)
|
||||
assert json.loads(ri_msg.body)["endpoint"] == "" # PAUSE endpoint
|
||||
assert json.loads(ri_msg.body)["data"] is True
|
||||
# Should send exactly 1 message
|
||||
assert agent.send.await_count == 1
|
||||
|
||||
# Extract the message object from the mock call
|
||||
# call_args[0] are positional args, and [0] is the first arg (the message)
|
||||
msg = agent.send.call_args[0][0]
|
||||
|
||||
vad_msg = next(m for m in msgs if m.to == settings.agent_settings.vad_name)
|
||||
assert vad_msg.body == "PAUSE"
|
||||
# Verify Body
|
||||
assert msg.body == "PAUSE"
|
||||
|
||||
# --- Test RESUME ---
|
||||
agent.send.reset_mock()
|
||||
await agent._send_pause_command("false")
|
||||
assert agent.send.await_count == 2
|
||||
vad_msg = next(
|
||||
m for m in agent.send.call_args_list if m.args[0].to == settings.agent_settings.vad_name
|
||||
).args[0]
|
||||
assert vad_msg.body == "RESUME"
|
||||
|
||||
# Should send exactly 1 message
|
||||
assert agent.send.await_count == 1
|
||||
|
||||
msg = agent.send.call_args[0][0]
|
||||
|
||||
# Verify Body
|
||||
assert msg.body == "RESUME"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_setup(agent):
|
||||
|
||||
Reference in New Issue
Block a user