feat: face recognition agent #53

Open
2584433 wants to merge 31 commits from feat/face-recognition into main
5 changed files with 228 additions and 24 deletions
Showing only changes of commit f89fb2266a - Show all commits

View File

@@ -1,3 +1,9 @@
"""
This program has been developed by students from the bachelor Computer Science at Utrecht
University within the Software Project course.
© Copyright Utrecht University (Department of Information and Computing Sciences)
"""
import asyncio
import zmq
@@ -29,6 +35,11 @@ class FacePerceptionAgent(BaseAgent):
self._last_face_state: bool | None = None
# Pause functionality
# NOTE: flag is set when running, cleared when paused
self._paused = asyncio.Event()
self._paused.set()
async def setup(self):
self.logger.info("Starting FacePerceptionAgent")
@@ -36,6 +47,7 @@ class FacePerceptionAgent(BaseAgent):
self._connect_socket()
self.add_behavior(self._poll_loop())
self.logger.info("Finished setting up %s", self.name)
def _connect_socket(self):
if self._socket is not None:
@@ -56,6 +68,7 @@ class FacePerceptionAgent(BaseAgent):
while self._running:
try:
await self._paused.wait()
response = await asyncio.wait_for(
self._socket.recv_json(), timeout=settings.behaviour_settings.sleep_s
)
@@ -110,3 +123,22 @@ class FacePerceptionAgent(BaseAgent):
)
await self.send(message)
async def handle_message(self, msg: InternalMessage):
    """
    React to control commands sent by the User Interrupt Agent.

    A "PAUSE" body clears the ``_paused`` event and forgets the last known face
    state; a "RESUME" body sets the event again. Any other body from the User
    Interrupt Agent is logged as a warning. Messages from any other sender are
    logged at debug level and otherwise ignored.

    :param msg: The incoming internal message.
    """
    origin = msg.sender

    # Guard clause: only the User Interrupt Agent may control this agent.
    if origin != settings.agent_settings.user_interrupt_name:
        self.logger.debug("Ignoring message from unknown sender: %s", origin)
        return

    if msg.body == "PAUSE":
        self.logger.info("Pausing Face Perception processing.")
        # The event is set while running, so clearing it is what pauses.
        self._paused.clear()
        self._last_face_state = None
        return

    if msg.body == "RESUME":
        self.logger.info("Resuming Face Perception processing.")
        self._paused.set()
        return

    self.logger.warning("Unknown command from User Interrupt Agent: %s", msg.body)

View File

@@ -75,6 +75,8 @@ class VisualEmotionRecognitionAgent(BaseAgent):
self.add_behavior(self.emotion_update_loop())
self.logger.info("Finished setting up %s", self.name)
async def emotion_update_loop(self):
"""
Background loop to receive video frames, recognize emotions, and update beliefs.
@@ -133,7 +135,6 @@ class VisualEmotionRecognitionAgent(BaseAgent):
except zmq.Again:
self.logger.warning("No video frame received within timeout.")
async def update_emotions(self, prev_emotions: set[str], emotions: set[str]):
"""
Compare emotions from previous window and current emotions,
@@ -198,7 +199,6 @@ class VisualEmotionRecognitionAgent(BaseAgent):
else:
self.logger.debug(f"Ignoring message from unknown sender: {sender}")
async def stop(self):
"""
Clean up resources used by the agent.

View File

@@ -23,12 +23,14 @@ class VisualEmotionRecognizer(abc.ABC):
"""
pass
class DeepFaceEmotionRecognizer(VisualEmotionRecognizer):
"""
DeepFace-based implementation of VisualEmotionRecognizer.
DeepFace has proven to be quite a pessimistic model, so expect sad, fear and neutral
emotions to be over-represented.
"""
def __init__(self):
self.load_model()
@@ -37,19 +39,16 @@ class DeepFaceEmotionRecognizer(VisualEmotionRecognizer):
dummy_img = np.zeros((224, 224, 3), dtype=np.uint8)
# analyze does not take a model as an argument, calling it once on a dummy image to load
# the model
DeepFace.analyze(dummy_img, actions=['emotion'], enforce_detection=False)
DeepFace.analyze(dummy_img, actions=["emotion"], enforce_detection=False)
print("Deepface Emotion Model loaded.")
def sorted_dominant_emotions(self, image) -> list[str]:
analysis = DeepFace.analyze(image,
actions=['emotion'],
enforce_detection=False
)
analysis = DeepFace.analyze(image, actions=["emotion"], enforce_detection=False)
# Sort faces by x coordinate to maintain left-to-right order
analysis.sort(key=lambda face: face['region']['x'])
analysis.sort(key=lambda face: face["region"]["x"])
analysis = [face for face in analysis if face['face_confidence'] >= 0.90]
analysis = [face for face in analysis if face["face_confidence"] >= 0.90]
dominant_emotions = [face['dominant_emotion'] for face in analysis]
dominant_emotions = [face["dominant_emotion"] for face in analysis]
return dominant_emotions

View File

@@ -401,23 +401,25 @@ class UserInterruptAgent(BaseAgent):
to=[
settings.agent_settings.vad_name,
settings.agent_settings.visual_emotion_recognition_name,
settings.agent_settings.face_agent_name,
],
sender=self.name,
body="PAUSE",
)
await self.send(vad_message)
# Voice Activity Detection, Visual Emotion Recognition and Face Perception agents
self.logger.info("Sent pause command to VAD and VED agents.")
self.logger.info("Sent pause command to perception agents.")
else:
# Send resume to the VAD, VED and face perception agents
vad_message = InternalMessage(
to=[
settings.agent_settings.vad_name,
settings.agent_settings.visual_emotion_recognition_name,
settings.agent_settings.face_agent_name,
],
sender=self.name,
body="RESUME",
)
await self.send(vad_message)
# Voice Activity Detection, Visual Emotion Recognition and Face Perception agents
self.logger.info("Sent resume command to VAD and VED agents.")
self.logger.info("Sent resume command to perception agents.")

View File

@@ -0,0 +1,171 @@
from unittest.mock import AsyncMock, MagicMock
import pytest
import zmq
import control_backend.agents.perception.face_rec_agent as face_module
from control_backend.agents.perception.face_rec_agent import FacePerceptionAgent
from control_backend.core.agent_system import InternalMessage
from control_backend.schemas.belief_message import BeliefMessage
# -------------------------
# Fixtures
# -------------------------
@pytest.fixture
def agent():
    """Provide a fresh FacePerceptionAgent built with zmq_bind=False for each test."""
    constructor_kwargs = {
        "name": "face_agent",
        "zmq_address": "inproc://test",
        "zmq_bind": False,
    }
    return FacePerceptionAgent(**constructor_kwargs)
@pytest.fixture
def socket():
    """Provide a stand-in ZMQ socket: async by default, with plain mocks for setup calls."""
    fake_socket = AsyncMock()
    # Attributes of an AsyncMock are async mocks by default; these three are replaced
    # with MagicMock so that calling them does not produce a coroutine.
    for sync_method in ("setsockopt_string", "connect", "bind"):
        setattr(fake_socket, sync_method, MagicMock())
    return fake_socket
# -------------------------
# Socket setup tests
# -------------------------
def test_connect_socket_connect(agent, socket, monkeypatch):
    """With zmq_bind=False, _connect_socket subscribes to everything and connects."""
    fake_context = MagicMock()
    fake_context.socket.return_value = socket
    # Replace the Context class so that Context.instance() hands back the fake context.
    context_cls = MagicMock(instance=lambda: fake_context)
    monkeypatch.setattr(face_module.azmq, "Context", context_cls)

    agent._connect_socket()

    socket.bind.assert_not_called()
    socket.connect.assert_called_once_with(agent._zmq_address)
    socket.setsockopt_string.assert_called_once_with(zmq.SUBSCRIBE, "")
def test_connect_socket_bind(agent, socket, monkeypatch):
    """With zmq_bind=True, _connect_socket binds to the address and never connects."""
    fake_context = MagicMock()
    fake_context.socket.return_value = socket
    # Replace the Context class so that Context.instance() hands back the fake context.
    context_cls = MagicMock(instance=lambda: fake_context)
    monkeypatch.setattr(face_module.azmq, "Context", context_cls)
    agent._zmq_bind = True

    agent._connect_socket()

    socket.connect.assert_not_called()
    socket.bind.assert_called_once_with(agent._zmq_address)
def test_connect_socket_twice_is_noop(agent, socket):
    """An agent that already holds a socket keeps that same object after _connect_socket."""
    existing_socket = socket
    agent._socket = existing_socket

    agent._connect_socket()

    assert existing_socket is agent._socket
# -------------------------
# Belief update tests
# -------------------------
@pytest.mark.asyncio
async def test_update_face_belief_present(agent):
    """_update_face_belief(True) sends a message whose first created belief is 'face_present'."""
    agent.send = AsyncMock()

    await agent._update_face_belief(True)

    sent_message = agent.send.await_args.args[0]
    belief_update = BeliefMessage.model_validate_json(sent_message.body)
    first_created = belief_update.create[0]
    assert first_created.name == "face_present"
@pytest.mark.asyncio
async def test_update_face_belief_absent(agent):
    """_update_face_belief(False) sends a message whose first deleted belief is 'face_present'."""
    agent.send = AsyncMock()

    await agent._update_face_belief(False)

    sent_message = agent.send.await_args.args[0]
    belief_update = BeliefMessage.model_validate_json(sent_message.body)
    first_deleted = belief_update.delete[0]
    assert first_deleted.name == "face_present"
@pytest.mark.asyncio
async def test_post_face_belief_present(agent):
    """_post_face_belief(True) sends a body mentioning both "create" and "face_present"."""
    agent.send = AsyncMock()

    await agent._post_face_belief(True)

    sent_message = agent.send.await_args.args[0]
    expected_fragments = ('"create"', '"face_present"')
    assert all(fragment in sent_message.body for fragment in expected_fragments)
@pytest.mark.asyncio
async def test_post_face_belief_absent(agent):
    """_post_face_belief(False) sends a body mentioning both "delete" and "face_present"."""
    agent.send = AsyncMock()

    await agent._post_face_belief(False)

    sent_message = agent.send.await_args.args[0]
    expected_fragments = ('"delete"', '"face_present"')
    assert all(fragment in sent_message.body for fragment in expected_fragments)
# -------------------------
# Message handling tests
# -------------------------
@pytest.mark.asyncio
async def test_handle_pause(agent):
    """A 'PAUSE' from the User Interrupt Agent clears _paused and resets _last_face_state."""
    # Start in the running state with a known face state so both effects are observable.
    agent._last_face_state = True
    agent._paused.set()
    pause_command = InternalMessage(
        body="PAUSE",
        thread="cmd",
        sender=face_module.settings.agent_settings.user_interrupt_name,
        to=agent.name,
    )

    await agent.handle_message(pause_command)

    assert agent._last_face_state is None
    assert not agent._paused.is_set()
@pytest.mark.asyncio
async def test_handle_resume(agent):
    """A 'RESUME' from the User Interrupt Agent sets the _paused event again."""
    # Start in the paused state (event cleared) so the transition is observable.
    agent._paused.clear()
    resume_command = InternalMessage(
        body="RESUME",
        thread="cmd",
        sender=face_module.settings.agent_settings.user_interrupt_name,
        to=agent.name,
    )

    await agent.handle_message(resume_command)

    assert agent._paused.is_set()
@pytest.mark.asyncio
async def test_handle_unknown_command(agent):
    """
    Test that an unknown command from UserInterruptAgent is ignored.

    The handler only logs a warning for such a body, so the pause event and the
    last face state must be exactly what they were before the message arrived.
    """
    # Put the agent in a known running state so any accidental change is detectable.
    agent._paused.set()
    agent._last_face_state = True
    msg = InternalMessage(
        to=agent.name,
        sender=face_module.settings.agent_settings.user_interrupt_name,
        thread="cmd",
        body="???",
    )

    await agent.handle_message(msg)

    # Previously this test asserted nothing and could never fail on a regression.
    assert agent._paused.is_set()
    assert agent._last_face_state is True
@pytest.mark.asyncio
async def test_handle_unknown_sender(agent):
    """
    Test that messages from unknown senders are ignored.

    A "PAUSE" body is used on purpose: if the sender check were skipped, the
    agent would pause and forget its face state, which the assertions catch.
    """
    # Put the agent in a known running state so any accidental change is detectable.
    agent._paused.set()
    agent._last_face_state = True
    msg = InternalMessage(
        to=agent.name,
        sender="someone_else",
        thread="cmd",
        body="PAUSE",
    )

    await agent.handle_message(msg)

    # Previously this test asserted nothing and could never fail on a regression.
    assert agent._paused.is_set()
    assert agent._last_face_state is True