feat: added face recognition and tests
ref: N25B-397
This commit is contained in:
@@ -1,3 +1,9 @@
|
||||
"""
|
||||
This program has been developed by students from the bachelor Computer Science at Utrecht
|
||||
University within the Software Project course.
|
||||
© Copyright Utrecht University (Department of Information and Computing Sciences)
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
|
||||
import zmq
|
||||
@@ -29,6 +35,11 @@ class FacePerceptionAgent(BaseAgent):
|
||||
|
||||
self._last_face_state: bool | None = None
|
||||
|
||||
# Pause functionality
|
||||
# NOTE: flag is set when running, cleared when paused
|
||||
self._paused = asyncio.Event()
|
||||
self._paused.set()
|
||||
|
||||
async def setup(self):
|
||||
self.logger.info("Starting FacePerceptionAgent")
|
||||
|
||||
@@ -36,6 +47,7 @@ class FacePerceptionAgent(BaseAgent):
|
||||
self._connect_socket()
|
||||
|
||||
self.add_behavior(self._poll_loop())
|
||||
self.logger.info("Finished setting up %s", self.name)
|
||||
|
||||
def _connect_socket(self):
|
||||
if self._socket is not None:
|
||||
@@ -56,6 +68,7 @@ class FacePerceptionAgent(BaseAgent):
|
||||
|
||||
while self._running:
|
||||
try:
|
||||
await self._paused.wait()
|
||||
response = await asyncio.wait_for(
|
||||
self._socket.recv_json(), timeout=settings.behaviour_settings.sleep_s
|
||||
)
|
||||
@@ -110,3 +123,22 @@ class FacePerceptionAgent(BaseAgent):
|
||||
)
|
||||
|
||||
await self.send(message)
|
||||
|
||||
async def handle_message(self, msg: InternalMessage):
|
||||
"""
|
||||
Handle incoming pause/resume commands from User Interrupt Agent.
|
||||
"""
|
||||
sender = msg.sender
|
||||
|
||||
if sender == settings.agent_settings.user_interrupt_name:
|
||||
if msg.body == "PAUSE":
|
||||
self.logger.info("Pausing Face Perception processing.")
|
||||
self._paused.clear()
|
||||
self._last_face_state = None
|
||||
elif msg.body == "RESUME":
|
||||
self.logger.info("Resuming Face Perception processing.")
|
||||
self._paused.set()
|
||||
else:
|
||||
self.logger.warning("Unknown command from User Interrupt Agent: %s", msg.body)
|
||||
else:
|
||||
self.logger.debug("Ignoring message from unknown sender: %s", sender)
|
||||
|
||||
@@ -75,6 +75,8 @@ class VisualEmotionRecognitionAgent(BaseAgent):
|
||||
|
||||
self.add_behavior(self.emotion_update_loop())
|
||||
|
||||
self.logger.info("Finished setting up %s", self.name)
|
||||
|
||||
async def emotion_update_loop(self):
|
||||
"""
|
||||
Background loop to receive video frames, recognize emotions, and update beliefs.
|
||||
@@ -133,7 +135,6 @@ class VisualEmotionRecognitionAgent(BaseAgent):
|
||||
except zmq.Again:
|
||||
self.logger.warning("No video frame received within timeout.")
|
||||
|
||||
|
||||
async def update_emotions(self, prev_emotions: set[str], emotions: set[str]):
|
||||
"""
|
||||
Compare emotions from previous window and current emotions,
|
||||
@@ -198,7 +199,6 @@ class VisualEmotionRecognitionAgent(BaseAgent):
|
||||
else:
|
||||
self.logger.debug(f"Ignoring message from unknown sender: {sender}")
|
||||
|
||||
|
||||
async def stop(self):
|
||||
"""
|
||||
Clean up resources used by the agent.
|
||||
|
||||
@@ -23,12 +23,14 @@ class VisualEmotionRecognizer(abc.ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class DeepFaceEmotionRecognizer(VisualEmotionRecognizer):
|
||||
"""
|
||||
DeepFace-based implementation of VisualEmotionRecognizer.
|
||||
DeepFape has proven to be quite a pessimistic model, so expect sad, fear and neutral
|
||||
emotions to be over-represented.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.load_model()
|
||||
|
||||
@@ -37,19 +39,16 @@ class DeepFaceEmotionRecognizer(VisualEmotionRecognizer):
|
||||
dummy_img = np.zeros((224, 224, 3), dtype=np.uint8)
|
||||
# analyze does not take a model as an argument, calling it once on a dummy image to load
|
||||
# the model
|
||||
DeepFace.analyze(dummy_img, actions=['emotion'], enforce_detection=False)
|
||||
DeepFace.analyze(dummy_img, actions=["emotion"], enforce_detection=False)
|
||||
print("Deepface Emotion Model loaded.")
|
||||
|
||||
def sorted_dominant_emotions(self, image) -> list[str]:
|
||||
analysis = DeepFace.analyze(image,
|
||||
actions=['emotion'],
|
||||
enforce_detection=False
|
||||
)
|
||||
analysis = DeepFace.analyze(image, actions=["emotion"], enforce_detection=False)
|
||||
|
||||
# Sort faces by x coordinate to maintain left-to-right order
|
||||
analysis.sort(key=lambda face: face['region']['x'])
|
||||
analysis.sort(key=lambda face: face["region"]["x"])
|
||||
|
||||
analysis = [face for face in analysis if face['face_confidence'] >= 0.90]
|
||||
analysis = [face for face in analysis if face["face_confidence"] >= 0.90]
|
||||
|
||||
dominant_emotions = [face['dominant_emotion'] for face in analysis]
|
||||
dominant_emotions = [face["dominant_emotion"] for face in analysis]
|
||||
return dominant_emotions
|
||||
|
||||
@@ -401,23 +401,25 @@ class UserInterruptAgent(BaseAgent):
|
||||
to=[
|
||||
settings.agent_settings.vad_name,
|
||||
settings.agent_settings.visual_emotion_recognition_name,
|
||||
settings.agent_settings.face_agent_name,
|
||||
],
|
||||
sender=self.name,
|
||||
body="PAUSE",
|
||||
)
|
||||
await self.send(vad_message)
|
||||
# Voice Activity Detection and Visual Emotion Recognition agents
|
||||
self.logger.info("Sent pause command to VAD and VED agents.")
|
||||
self.logger.info("Sent pause command to perception agents.")
|
||||
else:
|
||||
# Send resume to VAD and VED agents
|
||||
vad_message = InternalMessage(
|
||||
to=[
|
||||
settings.agent_settings.vad_name,
|
||||
settings.agent_settings.visual_emotion_recognition_name,
|
||||
settings.agent_settings.face_agent_name,
|
||||
],
|
||||
sender=self.name,
|
||||
body="RESUME",
|
||||
)
|
||||
await self.send(vad_message)
|
||||
# Voice Activity Detection and Visual Emotion Recognition agents
|
||||
self.logger.info("Sent resume command to VAD and VED agents.")
|
||||
self.logger.info("Sent resume command to perception agents.")
|
||||
|
||||
171
test/unit/agents/perception/test_face_detection_agent.py
Normal file
171
test/unit/agents/perception/test_face_detection_agent.py
Normal file
@@ -0,0 +1,171 @@
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
import zmq
|
||||
|
||||
import control_backend.agents.perception.face_rec_agent as face_module
|
||||
from control_backend.agents.perception.face_rec_agent import FacePerceptionAgent
|
||||
from control_backend.core.agent_system import InternalMessage
|
||||
from control_backend.schemas.belief_message import BeliefMessage
|
||||
|
||||
# -------------------------
|
||||
# Fixtures
|
||||
# -------------------------
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def agent():
|
||||
"""Return a FacePerceptionAgent instance for testing."""
|
||||
return FacePerceptionAgent(
|
||||
name="face_agent",
|
||||
zmq_address="inproc://test",
|
||||
zmq_bind=False,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def socket():
|
||||
"""Return a mocked ZMQ socket."""
|
||||
sock = AsyncMock()
|
||||
sock.setsockopt_string = MagicMock()
|
||||
sock.connect = MagicMock()
|
||||
sock.bind = MagicMock()
|
||||
return sock
|
||||
|
||||
|
||||
# -------------------------
|
||||
# Socket setup tests
|
||||
# -------------------------
|
||||
|
||||
|
||||
def test_connect_socket_connect(agent, socket, monkeypatch):
|
||||
"""Test that _connect_socket properly connects when zmq_bind=False."""
|
||||
ctx = MagicMock()
|
||||
ctx.socket.return_value = socket
|
||||
monkeypatch.setattr(face_module.azmq, "Context", MagicMock(instance=lambda: ctx))
|
||||
|
||||
agent._connect_socket()
|
||||
socket.setsockopt_string.assert_called_once_with(zmq.SUBSCRIBE, "")
|
||||
socket.connect.assert_called_once_with(agent._zmq_address)
|
||||
socket.bind.assert_not_called()
|
||||
|
||||
|
||||
def test_connect_socket_bind(agent, socket, monkeypatch):
|
||||
"""Test that _connect_socket properly binds when zmq_bind=True."""
|
||||
agent._zmq_bind = True
|
||||
ctx = MagicMock()
|
||||
ctx.socket.return_value = socket
|
||||
monkeypatch.setattr(face_module.azmq, "Context", MagicMock(instance=lambda: ctx))
|
||||
|
||||
agent._connect_socket()
|
||||
socket.bind.assert_called_once_with(agent._zmq_address)
|
||||
socket.connect.assert_not_called()
|
||||
|
||||
|
||||
def test_connect_socket_twice_is_noop(agent, socket):
|
||||
"""Test that calling _connect_socket twice does not overwrite an existing socket."""
|
||||
agent._socket = socket
|
||||
agent._connect_socket()
|
||||
assert agent._socket is socket
|
||||
|
||||
|
||||
# -------------------------
|
||||
# Belief update tests
|
||||
# -------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_face_belief_present(agent):
|
||||
"""Test that _update_face_belief(True) creates the 'face_present' belief."""
|
||||
agent.send = AsyncMock()
|
||||
await agent._update_face_belief(True)
|
||||
msg = agent.send.await_args.args[0]
|
||||
payload = BeliefMessage.model_validate_json(msg.body)
|
||||
assert payload.create[0].name == "face_present"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_face_belief_absent(agent):
|
||||
"""Test that _update_face_belief(False) deletes the 'face_present' belief."""
|
||||
agent.send = AsyncMock()
|
||||
await agent._update_face_belief(False)
|
||||
msg = agent.send.await_args.args[0]
|
||||
payload = BeliefMessage.model_validate_json(msg.body)
|
||||
assert payload.delete[0].name == "face_present"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_post_face_belief_present(agent):
|
||||
"""Test that _post_face_belief(True) sends a belief creation message."""
|
||||
agent.send = AsyncMock()
|
||||
await agent._post_face_belief(True)
|
||||
msg = agent.send.await_args.args[0]
|
||||
assert '"create"' in msg.body and '"face_present"' in msg.body
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_post_face_belief_absent(agent):
|
||||
"""Test that _post_face_belief(False) sends a belief deletion message."""
|
||||
agent.send = AsyncMock()
|
||||
await agent._post_face_belief(False)
|
||||
msg = agent.send.await_args.args[0]
|
||||
assert '"delete"' in msg.body and '"face_present"' in msg.body
|
||||
|
||||
|
||||
# -------------------------
|
||||
# Message handling tests
|
||||
# -------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_pause(agent):
|
||||
"""Test that a 'PAUSE' message clears _paused and resets _last_face_state."""
|
||||
agent._paused.set()
|
||||
agent._last_face_state = True
|
||||
msg = InternalMessage(
|
||||
to=agent.name,
|
||||
sender=face_module.settings.agent_settings.user_interrupt_name,
|
||||
thread="cmd",
|
||||
body="PAUSE",
|
||||
)
|
||||
await agent.handle_message(msg)
|
||||
assert not agent._paused.is_set()
|
||||
assert agent._last_face_state is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_resume(agent):
|
||||
"""Test that a 'RESUME' message sets _paused."""
|
||||
agent._paused.clear()
|
||||
msg = InternalMessage(
|
||||
to=agent.name,
|
||||
sender=face_module.settings.agent_settings.user_interrupt_name,
|
||||
thread="cmd",
|
||||
body="RESUME",
|
||||
)
|
||||
await agent.handle_message(msg)
|
||||
assert agent._paused.is_set()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_unknown_command(agent):
|
||||
"""Test that an unknown command from UserInterruptAgent is ignored (logs a warning)."""
|
||||
msg = InternalMessage(
|
||||
to=agent.name,
|
||||
sender=face_module.settings.agent_settings.user_interrupt_name,
|
||||
thread="cmd",
|
||||
body="???",
|
||||
)
|
||||
await agent.handle_message(msg)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_unknown_sender(agent):
|
||||
"""Test that messages from unknown senders are ignored."""
|
||||
msg = InternalMessage(
|
||||
to=agent.name,
|
||||
sender="someone_else",
|
||||
thread="cmd",
|
||||
body="PAUSE",
|
||||
)
|
||||
await agent.handle_message(msg)
|
||||
Reference in New Issue
Block a user