feat: added face recognition and tests
ref: N25B-397
This commit is contained in:
@@ -1,3 +1,9 @@
|
|||||||
|
"""
|
||||||
|
This program has been developed by students from the bachelor Computer Science at Utrecht
|
||||||
|
University within the Software Project course.
|
||||||
|
© Copyright Utrecht University (Department of Information and Computing Sciences)
|
||||||
|
"""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
import zmq
|
import zmq
|
||||||
@@ -29,6 +35,11 @@ class FacePerceptionAgent(BaseAgent):
|
|||||||
|
|
||||||
self._last_face_state: bool | None = None
|
self._last_face_state: bool | None = None
|
||||||
|
|
||||||
|
# Pause functionality
|
||||||
|
# NOTE: flag is set when running, cleared when paused
|
||||||
|
self._paused = asyncio.Event()
|
||||||
|
self._paused.set()
|
||||||
|
|
||||||
async def setup(self):
|
async def setup(self):
|
||||||
self.logger.info("Starting FacePerceptionAgent")
|
self.logger.info("Starting FacePerceptionAgent")
|
||||||
|
|
||||||
@@ -36,6 +47,7 @@ class FacePerceptionAgent(BaseAgent):
|
|||||||
self._connect_socket()
|
self._connect_socket()
|
||||||
|
|
||||||
self.add_behavior(self._poll_loop())
|
self.add_behavior(self._poll_loop())
|
||||||
|
self.logger.info("Finished setting up %s", self.name)
|
||||||
|
|
||||||
def _connect_socket(self):
|
def _connect_socket(self):
|
||||||
if self._socket is not None:
|
if self._socket is not None:
|
||||||
@@ -56,6 +68,7 @@ class FacePerceptionAgent(BaseAgent):
|
|||||||
|
|
||||||
while self._running:
|
while self._running:
|
||||||
try:
|
try:
|
||||||
|
await self._paused.wait()
|
||||||
response = await asyncio.wait_for(
|
response = await asyncio.wait_for(
|
||||||
self._socket.recv_json(), timeout=settings.behaviour_settings.sleep_s
|
self._socket.recv_json(), timeout=settings.behaviour_settings.sleep_s
|
||||||
)
|
)
|
||||||
@@ -110,3 +123,22 @@ class FacePerceptionAgent(BaseAgent):
|
|||||||
)
|
)
|
||||||
|
|
||||||
await self.send(message)
|
await self.send(message)
|
||||||
|
|
||||||
|
async def handle_message(self, msg: InternalMessage):
|
||||||
|
"""
|
||||||
|
Handle incoming pause/resume commands from User Interrupt Agent.
|
||||||
|
"""
|
||||||
|
sender = msg.sender
|
||||||
|
|
||||||
|
if sender == settings.agent_settings.user_interrupt_name:
|
||||||
|
if msg.body == "PAUSE":
|
||||||
|
self.logger.info("Pausing Face Perception processing.")
|
||||||
|
self._paused.clear()
|
||||||
|
self._last_face_state = None
|
||||||
|
elif msg.body == "RESUME":
|
||||||
|
self.logger.info("Resuming Face Perception processing.")
|
||||||
|
self._paused.set()
|
||||||
|
else:
|
||||||
|
self.logger.warning("Unknown command from User Interrupt Agent: %s", msg.body)
|
||||||
|
else:
|
||||||
|
self.logger.debug("Ignoring message from unknown sender: %s", sender)
|
||||||
|
|||||||
@@ -75,6 +75,8 @@ class VisualEmotionRecognitionAgent(BaseAgent):
|
|||||||
|
|
||||||
self.add_behavior(self.emotion_update_loop())
|
self.add_behavior(self.emotion_update_loop())
|
||||||
|
|
||||||
|
self.logger.info("Finished setting up %s", self.name)
|
||||||
|
|
||||||
async def emotion_update_loop(self):
|
async def emotion_update_loop(self):
|
||||||
"""
|
"""
|
||||||
Background loop to receive video frames, recognize emotions, and update beliefs.
|
Background loop to receive video frames, recognize emotions, and update beliefs.
|
||||||
@@ -133,7 +135,6 @@ class VisualEmotionRecognitionAgent(BaseAgent):
|
|||||||
except zmq.Again:
|
except zmq.Again:
|
||||||
self.logger.warning("No video frame received within timeout.")
|
self.logger.warning("No video frame received within timeout.")
|
||||||
|
|
||||||
|
|
||||||
async def update_emotions(self, prev_emotions: set[str], emotions: set[str]):
|
async def update_emotions(self, prev_emotions: set[str], emotions: set[str]):
|
||||||
"""
|
"""
|
||||||
Compare emotions from previous window and current emotions,
|
Compare emotions from previous window and current emotions,
|
||||||
@@ -198,7 +199,6 @@ class VisualEmotionRecognitionAgent(BaseAgent):
|
|||||||
else:
|
else:
|
||||||
self.logger.debug(f"Ignoring message from unknown sender: {sender}")
|
self.logger.debug(f"Ignoring message from unknown sender: {sender}")
|
||||||
|
|
||||||
|
|
||||||
async def stop(self):
|
async def stop(self):
|
||||||
"""
|
"""
|
||||||
Clean up resources used by the agent.
|
Clean up resources used by the agent.
|
||||||
|
|||||||
@@ -23,12 +23,14 @@ class VisualEmotionRecognizer(abc.ABC):
|
|||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class DeepFaceEmotionRecognizer(VisualEmotionRecognizer):
|
class DeepFaceEmotionRecognizer(VisualEmotionRecognizer):
|
||||||
"""
|
"""
|
||||||
DeepFace-based implementation of VisualEmotionRecognizer.
|
DeepFace-based implementation of VisualEmotionRecognizer.
|
||||||
DeepFape has proven to be quite a pessimistic model, so expect sad, fear and neutral
|
DeepFape has proven to be quite a pessimistic model, so expect sad, fear and neutral
|
||||||
emotions to be over-represented.
|
emotions to be over-represented.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.load_model()
|
self.load_model()
|
||||||
|
|
||||||
@@ -37,19 +39,16 @@ class DeepFaceEmotionRecognizer(VisualEmotionRecognizer):
|
|||||||
dummy_img = np.zeros((224, 224, 3), dtype=np.uint8)
|
dummy_img = np.zeros((224, 224, 3), dtype=np.uint8)
|
||||||
# analyze does not take a model as an argument, calling it once on a dummy image to load
|
# analyze does not take a model as an argument, calling it once on a dummy image to load
|
||||||
# the model
|
# the model
|
||||||
DeepFace.analyze(dummy_img, actions=['emotion'], enforce_detection=False)
|
DeepFace.analyze(dummy_img, actions=["emotion"], enforce_detection=False)
|
||||||
print("Deepface Emotion Model loaded.")
|
print("Deepface Emotion Model loaded.")
|
||||||
|
|
||||||
def sorted_dominant_emotions(self, image) -> list[str]:
|
def sorted_dominant_emotions(self, image) -> list[str]:
|
||||||
analysis = DeepFace.analyze(image,
|
analysis = DeepFace.analyze(image, actions=["emotion"], enforce_detection=False)
|
||||||
actions=['emotion'],
|
|
||||||
enforce_detection=False
|
|
||||||
)
|
|
||||||
|
|
||||||
# Sort faces by x coordinate to maintain left-to-right order
|
# Sort faces by x coordinate to maintain left-to-right order
|
||||||
analysis.sort(key=lambda face: face['region']['x'])
|
analysis.sort(key=lambda face: face["region"]["x"])
|
||||||
|
|
||||||
analysis = [face for face in analysis if face['face_confidence'] >= 0.90]
|
analysis = [face for face in analysis if face["face_confidence"] >= 0.90]
|
||||||
|
|
||||||
dominant_emotions = [face['dominant_emotion'] for face in analysis]
|
dominant_emotions = [face["dominant_emotion"] for face in analysis]
|
||||||
return dominant_emotions
|
return dominant_emotions
|
||||||
|
|||||||
@@ -401,23 +401,25 @@ class UserInterruptAgent(BaseAgent):
|
|||||||
to=[
|
to=[
|
||||||
settings.agent_settings.vad_name,
|
settings.agent_settings.vad_name,
|
||||||
settings.agent_settings.visual_emotion_recognition_name,
|
settings.agent_settings.visual_emotion_recognition_name,
|
||||||
|
settings.agent_settings.face_agent_name,
|
||||||
],
|
],
|
||||||
sender=self.name,
|
sender=self.name,
|
||||||
body="PAUSE",
|
body="PAUSE",
|
||||||
)
|
)
|
||||||
await self.send(vad_message)
|
await self.send(vad_message)
|
||||||
# Voice Activity Detection and Visual Emotion Recognition agents
|
# Voice Activity Detection and Visual Emotion Recognition agents
|
||||||
self.logger.info("Sent pause command to VAD and VED agents.")
|
self.logger.info("Sent pause command to perception agents.")
|
||||||
else:
|
else:
|
||||||
# Send resume to VAD and VED agents
|
# Send resume to VAD and VED agents
|
||||||
vad_message = InternalMessage(
|
vad_message = InternalMessage(
|
||||||
to=[
|
to=[
|
||||||
settings.agent_settings.vad_name,
|
settings.agent_settings.vad_name,
|
||||||
settings.agent_settings.visual_emotion_recognition_name,
|
settings.agent_settings.visual_emotion_recognition_name,
|
||||||
|
settings.agent_settings.face_agent_name,
|
||||||
],
|
],
|
||||||
sender=self.name,
|
sender=self.name,
|
||||||
body="RESUME",
|
body="RESUME",
|
||||||
)
|
)
|
||||||
await self.send(vad_message)
|
await self.send(vad_message)
|
||||||
# Voice Activity Detection and Visual Emotion Recognition agents
|
# Voice Activity Detection and Visual Emotion Recognition agents
|
||||||
self.logger.info("Sent resume command to VAD and VED agents.")
|
self.logger.info("Sent resume command to perception agents.")
|
||||||
|
|||||||
171
test/unit/agents/perception/test_face_detection_agent.py
Normal file
171
test/unit/agents/perception/test_face_detection_agent.py
Normal file
@@ -0,0 +1,171 @@
|
|||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import zmq
|
||||||
|
|
||||||
|
import control_backend.agents.perception.face_rec_agent as face_module
|
||||||
|
from control_backend.agents.perception.face_rec_agent import FacePerceptionAgent
|
||||||
|
from control_backend.core.agent_system import InternalMessage
|
||||||
|
from control_backend.schemas.belief_message import BeliefMessage
|
||||||
|
|
||||||
|
# -------------------------
|
||||||
|
# Fixtures
|
||||||
|
# -------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def agent():
|
||||||
|
"""Return a FacePerceptionAgent instance for testing."""
|
||||||
|
return FacePerceptionAgent(
|
||||||
|
name="face_agent",
|
||||||
|
zmq_address="inproc://test",
|
||||||
|
zmq_bind=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def socket():
|
||||||
|
"""Return a mocked ZMQ socket."""
|
||||||
|
sock = AsyncMock()
|
||||||
|
sock.setsockopt_string = MagicMock()
|
||||||
|
sock.connect = MagicMock()
|
||||||
|
sock.bind = MagicMock()
|
||||||
|
return sock
|
||||||
|
|
||||||
|
|
||||||
|
# -------------------------
|
||||||
|
# Socket setup tests
|
||||||
|
# -------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_connect_socket_connect(agent, socket, monkeypatch):
|
||||||
|
"""Test that _connect_socket properly connects when zmq_bind=False."""
|
||||||
|
ctx = MagicMock()
|
||||||
|
ctx.socket.return_value = socket
|
||||||
|
monkeypatch.setattr(face_module.azmq, "Context", MagicMock(instance=lambda: ctx))
|
||||||
|
|
||||||
|
agent._connect_socket()
|
||||||
|
socket.setsockopt_string.assert_called_once_with(zmq.SUBSCRIBE, "")
|
||||||
|
socket.connect.assert_called_once_with(agent._zmq_address)
|
||||||
|
socket.bind.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
def test_connect_socket_bind(agent, socket, monkeypatch):
|
||||||
|
"""Test that _connect_socket properly binds when zmq_bind=True."""
|
||||||
|
agent._zmq_bind = True
|
||||||
|
ctx = MagicMock()
|
||||||
|
ctx.socket.return_value = socket
|
||||||
|
monkeypatch.setattr(face_module.azmq, "Context", MagicMock(instance=lambda: ctx))
|
||||||
|
|
||||||
|
agent._connect_socket()
|
||||||
|
socket.bind.assert_called_once_with(agent._zmq_address)
|
||||||
|
socket.connect.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
def test_connect_socket_twice_is_noop(agent, socket):
|
||||||
|
"""Test that calling _connect_socket twice does not overwrite an existing socket."""
|
||||||
|
agent._socket = socket
|
||||||
|
agent._connect_socket()
|
||||||
|
assert agent._socket is socket
|
||||||
|
|
||||||
|
|
||||||
|
# -------------------------
|
||||||
|
# Belief update tests
|
||||||
|
# -------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_update_face_belief_present(agent):
|
||||||
|
"""Test that _update_face_belief(True) creates the 'face_present' belief."""
|
||||||
|
agent.send = AsyncMock()
|
||||||
|
await agent._update_face_belief(True)
|
||||||
|
msg = agent.send.await_args.args[0]
|
||||||
|
payload = BeliefMessage.model_validate_json(msg.body)
|
||||||
|
assert payload.create[0].name == "face_present"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_update_face_belief_absent(agent):
|
||||||
|
"""Test that _update_face_belief(False) deletes the 'face_present' belief."""
|
||||||
|
agent.send = AsyncMock()
|
||||||
|
await agent._update_face_belief(False)
|
||||||
|
msg = agent.send.await_args.args[0]
|
||||||
|
payload = BeliefMessage.model_validate_json(msg.body)
|
||||||
|
assert payload.delete[0].name == "face_present"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_post_face_belief_present(agent):
|
||||||
|
"""Test that _post_face_belief(True) sends a belief creation message."""
|
||||||
|
agent.send = AsyncMock()
|
||||||
|
await agent._post_face_belief(True)
|
||||||
|
msg = agent.send.await_args.args[0]
|
||||||
|
assert '"create"' in msg.body and '"face_present"' in msg.body
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_post_face_belief_absent(agent):
|
||||||
|
"""Test that _post_face_belief(False) sends a belief deletion message."""
|
||||||
|
agent.send = AsyncMock()
|
||||||
|
await agent._post_face_belief(False)
|
||||||
|
msg = agent.send.await_args.args[0]
|
||||||
|
assert '"delete"' in msg.body and '"face_present"' in msg.body
|
||||||
|
|
||||||
|
|
||||||
|
# -------------------------
|
||||||
|
# Message handling tests
|
||||||
|
# -------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_handle_pause(agent):
|
||||||
|
"""Test that a 'PAUSE' message clears _paused and resets _last_face_state."""
|
||||||
|
agent._paused.set()
|
||||||
|
agent._last_face_state = True
|
||||||
|
msg = InternalMessage(
|
||||||
|
to=agent.name,
|
||||||
|
sender=face_module.settings.agent_settings.user_interrupt_name,
|
||||||
|
thread="cmd",
|
||||||
|
body="PAUSE",
|
||||||
|
)
|
||||||
|
await agent.handle_message(msg)
|
||||||
|
assert not agent._paused.is_set()
|
||||||
|
assert agent._last_face_state is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_handle_resume(agent):
|
||||||
|
"""Test that a 'RESUME' message sets _paused."""
|
||||||
|
agent._paused.clear()
|
||||||
|
msg = InternalMessage(
|
||||||
|
to=agent.name,
|
||||||
|
sender=face_module.settings.agent_settings.user_interrupt_name,
|
||||||
|
thread="cmd",
|
||||||
|
body="RESUME",
|
||||||
|
)
|
||||||
|
await agent.handle_message(msg)
|
||||||
|
assert agent._paused.is_set()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_handle_unknown_command(agent):
|
||||||
|
"""Test that an unknown command from UserInterruptAgent is ignored (logs a warning)."""
|
||||||
|
msg = InternalMessage(
|
||||||
|
to=agent.name,
|
||||||
|
sender=face_module.settings.agent_settings.user_interrupt_name,
|
||||||
|
thread="cmd",
|
||||||
|
body="???",
|
||||||
|
)
|
||||||
|
await agent.handle_message(msg)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_handle_unknown_sender(agent):
|
||||||
|
"""Test that messages from unknown senders are ignored."""
|
||||||
|
msg = InternalMessage(
|
||||||
|
to=agent.name,
|
||||||
|
sender="someone_else",
|
||||||
|
thread="cmd",
|
||||||
|
body="PAUSE",
|
||||||
|
)
|
||||||
|
await agent.handle_message(msg)
|
||||||
Reference in New Issue
Block a user