feat: visual emotion recognition agent
@@ -29,6 +29,7 @@ from control_backend.schemas.program import (
    BaseGoal,
    BasicNorm,
    ConditionalNorm,
    EmotionBelief,
    GestureAction,
    Goal,
    InferredBelief,
@@ -681,6 +682,10 @@ class AgentSpeakGenerator:
        :return: An AstLiteral representing the semantic belief.
        """
        return AstLiteral(self.slugify(sb))

    @_astify.register
    def _(self, eb: EmotionBelief) -> AstExpression:
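        # Illustrative mapping (assuming the usual functor(term) AgentSpeak form):
        # an EmotionBelief with emotion "happy" becomes the literal emotion_detected(happy).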
        return AstLiteral("emotion_detected", [AstAtom(eb.emotion)])

    @_astify.register
    def _(self, ib: InferredBelief) -> AstExpression:
@@ -9,14 +9,14 @@ import json

import zmq
import zmq.asyncio as azmq
from pydantic import ValidationError
from zmq.asyncio import Context

from control_backend.agents import BaseAgent
from control_backend.agents.actuation.robot_gesture_agent import RobotGestureAgent
from control_backend.agents.perception.visual_emotion_recognition_agent.visual_emotion_recognition_agent import ( # noqa
    VisualEmotionRecognitionAgent,
)
from control_backend.core.config import settings
from control_backend.schemas.internal_message import InternalMessage
from control_backend.schemas.ri_message import PauseCommand

from ..actuation.robot_speech_agent import RobotSpeechAgent
from ..perception import VADAgent
@@ -58,6 +58,7 @@ class RICommunicationAgent(BaseAgent):
        self.connected = False
        self.gesture_agent: RobotGestureAgent | None = None
        self.speech_agent: RobotSpeechAgent | None = None
        self.visual_emotion_recognition_agent: VisualEmotionRecognitionAgent | None = None

    async def setup(self):
        """
@@ -215,6 +216,14 @@ class RICommunicationAgent(BaseAgent):
            case "audio":
                vad_agent = VADAgent(audio_in_address=addr, audio_in_bind=bind)
                await vad_agent.start()
            case "video":
                visual_emotion_agent = VisualEmotionRecognitionAgent(
                    settings.agent_settings.visual_emotion_recognition_name,
                    socket_address=addr,
                    bind=bind,
                )
                self.visual_emotion_recognition_agent = visual_emotion_agent
                await visual_emotion_agent.start()
            case _:
                self.logger.warning("Unhandled negotiation id: %s", id)

@@ -319,6 +328,9 @@ class RICommunicationAgent(BaseAgent):

        if self.speech_agent is not None:
            await self.speech_agent.stop()

        if self.visual_emotion_recognition_agent is not None:
            await self.visual_emotion_recognition_agent.stop()

        if self.pub_socket is not None:
            self.pub_socket.close()
@@ -326,11 +338,4 @@ class RICommunicationAgent(BaseAgent):
        self.logger.debug("Restarting communication negotiation.")
        if await self._negotiate_connection(max_retries=2):
            self.connected = True

    async def handle_message(self, msg: InternalMessage):
        try:
            pause_command = PauseCommand.model_validate_json(msg.body)
            await self._req_socket.send_json(pause_command.model_dump())
            self.logger.debug(await self._req_socket.recv_json())
        except ValidationError:
            self.logger.warning("Incorrect message format for PauseCommand.")
@@ -0,0 +1,207 @@
import asyncio
import json
import time
from collections import Counter, defaultdict

import numpy as np
import zmq
import zmq.asyncio as azmq
from pydantic_core import ValidationError

from control_backend.agents import BaseAgent
from control_backend.agents.perception.visual_emotion_recognition_agent.visual_emotion_recognizer import ( # noqa
    DeepFaceEmotionRecognizer,
)
from control_backend.core.agent_system import InternalMessage
from control_backend.core.config import settings
from control_backend.schemas.belief_message import Belief


class VisualEmotionRecognitionAgent(BaseAgent):
    def __init__(
        self,
        name: str,
        socket_address: str,
        bind: bool = False,
        timeout_ms: int = 1000,
        window_duration: int = settings.behaviour_settings.visual_emotion_recognition_window_duration_s, # noqa
        min_frames_required: int = settings.behaviour_settings.visual_emotion_recognition_min_frames_per_face, # noqa
    ):
        """
        Initialize the Visual Emotion Recognition Agent.

        :param name: Name of the agent
        :param socket_address: Address of the socket to connect or bind to
        :param bind: Whether to bind to the socket address (True) or connect (False)
        :param timeout_ms: Timeout for socket receive operations in milliseconds
        :param window_duration: Duration in seconds over which to aggregate emotions
        :param min_frames_required: Minimum number of frames per face required to consider a face
            valid
        """
        super().__init__(name)
        self.socket_address = socket_address
        self.socket_bind = bind
        self.timeout_ms = timeout_ms
        self.window_duration = window_duration
        self.min_frames_required = min_frames_required

        # Pause functionality
        # NOTE: flag is set when running, cleared when paused
        self._paused = asyncio.Event()
        self._paused.set()

    async def setup(self):
        """
        Initialize the agent resources.
        1. Initializes the :class:`VisualEmotionRecognizer`.
        2. Connects to the video input ZMQ socket.
        3. Starts the background emotion recognition loop.
        """
        self.logger.info("Setting up %s.", self.name)

        self.emotion_recognizer = DeepFaceEmotionRecognizer()

        self.video_in_socket = azmq.Context.instance().socket(zmq.SUB)
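
        # Keep the receive queue shallow so analysis runs on recent frames rather than a
        # growing backlog; once the high-water mark is reached, excess frames are dropped.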
        self.video_in_socket.setsockopt(zmq.RCVHWM, 3)

        if self.socket_bind:
            self.video_in_socket.bind(self.socket_address)
        else:
            self.video_in_socket.connect(self.socket_address)

        self.video_in_socket.setsockopt_string(zmq.SUBSCRIBE, "")
        self.video_in_socket.setsockopt(zmq.RCVTIMEO, self.timeout_ms)

        self.add_behavior(self.emotion_update_loop())

    async def emotion_update_loop(self):
        """
        Background loop to receive video frames, recognize emotions, and update beliefs.
        1. Receives video frames from the ZMQ socket.
        2. Uses the :class:`VisualEmotionRecognizer` to detect emotions.
        3. Aggregates emotions over a time window.
        4. Sends updates to the BDI Core Agent about detected emotions.
        """
        # Next time to process the window and update emotions
        next_window_time = time.time() + self.window_duration

        # Tracks counts of detected emotions per face index
        face_stats = defaultdict(Counter)

        prev_dominant_emotions = set()

        while self._running:
            try:
                await self._paused.wait()
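
                # Expected wire format (assumption about the publisher): a three-part message of
                # [width (little-endian int), height (little-endian int), raw 3-channel frame bytes].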

                width, height, image_bytes = await self.video_in_socket.recv_multipart()

                width = int.from_bytes(width, 'little')
                height = int.from_bytes(height, 'little')

                # Convert bytes to a numpy buffer
                image_array = np.frombuffer(image_bytes, np.uint8)

                frame = image_array.reshape((height, width, 3))

                # Get the dominant emotion from each face
                current_emotions = self.emotion_recognizer.sorted_dominant_emotions(frame)
                # Update emotion counts for each detected face
                for i, emotion in enumerate(current_emotions):
                    face_stats[i][emotion] += 1
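
                # Worked example (illustrative): with min_frames_required=3, a face whose window
                # counter ends up as {"happy": 4, "neutral": 2} contributes "happy" below;
                # a face detected in only 2 frames is skipped for this window.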

                # If window duration has passed, process the collected stats
                if time.time() >= next_window_time:
                    window_dominant_emotions = set()
                    # Determine dominant emotion for each face in the window
                    for _, counter in face_stats.items():
                        total_detections = sum(counter.values())

                        if total_detections >= self.min_frames_required:
                            dominant_emotion = counter.most_common(1)[0][0]
                            window_dominant_emotions.add(dominant_emotion)

                    await self.update_emotions(prev_dominant_emotions, window_dominant_emotions)
                    prev_dominant_emotions = window_dominant_emotions
                    face_stats.clear()
                    next_window_time = time.time() + self.window_duration

            except zmq.Again:
                self.logger.warning("No video frame received within timeout.")

            except Exception as e:
                self.logger.error(f"Error in emotion recognition loop: {e}")

    async def update_emotions(self, prev_emotions: set[str], emotions: set[str]):
        """
        Compare the emotions from the previous window with the current ones and
        send belief updates to the BDI Core Agent.
        """
        emotions_to_remove = prev_emotions - emotions
        emotions_to_add = emotions - prev_emotions

        if not emotions_to_add and not emotions_to_remove:
            return

        emotion_beliefs_remove = []
        for emotion in emotions_to_remove:
            self.logger.info(f"Emotion '{emotion}' has disappeared.")
            try:
                emotion_beliefs_remove.append(
                    Belief(name="emotion_detected", arguments=[emotion], remove=True)
                )
            except ValidationError:
                self.logger.warning("Invalid belief for emotion removal: %s", emotion)

        emotion_beliefs_add = []
        for emotion in emotions_to_add:
            self.logger.info(f"New emotion detected: '{emotion}'")
            try:
                emotion_beliefs_add.append(Belief(name="emotion_detected", arguments=[emotion]))
            except ValidationError:
                self.logger.warning("Invalid belief for new emotion: %s", emotion)

        beliefs_list_add = [b.model_dump() for b in emotion_beliefs_add]
        beliefs_list_remove = [b.model_dump() for b in emotion_beliefs_remove]
        payload = {"create": beliefs_list_add, "delete": beliefs_list_remove}

        message = InternalMessage(
            to=settings.agent_settings.bdi_core_name,
            sender=self.name,
            body=json.dumps(payload),
            thread="beliefs",
        )
        await self.send(message)

    async def handle_message(self, msg: InternalMessage):
        """
        Handle incoming messages.

        Expects messages from the User Interrupt Agent that pause or resume the
        Visual Emotion Recognition processing.

        :param msg: The received internal message.
        """
        sender = msg.sender

        if sender == settings.agent_settings.user_interrupt_name:
            if msg.body == "PAUSE":
                self.logger.info("Pausing Visual Emotion Recognition processing.")
                self._paused.clear()
            elif msg.body == "RESUME":
                self.logger.info("Resuming Visual Emotion Recognition processing.")
                self._paused.set()
            else:
                self.logger.warning(f"Unknown command from User Interrupt Agent: {msg.body}")
        else:
            self.logger.debug(f"Ignoring message from unknown sender: {sender}")

    async def stop(self):
        """
        Clean up resources used by the agent.
        """
        self.video_in_socket.close()
        await super().stop()
@@ -0,0 +1,53 @@
import abc

import numpy as np
from deepface import DeepFace


class VisualEmotionRecognizer(abc.ABC):
    @abc.abstractmethod
    def load_model(self):
        """Load the visual emotion recognition model into memory."""
        pass

    @abc.abstractmethod
    def sorted_dominant_emotions(self, image) -> list[str]:
        """
        Recognize the dominant emotions of the faces in the given image.
        Emotions can be one of ['angry', 'disgust', 'fear', 'happy', 'sad', 'surprise', 'neutral'].
        To minimize false positives, consider filtering out faces with low confidence.

        :param image: The input image for emotion recognition.
        :return: The dominant emotion detected for each face in the image,
            ordered by face position.
        """
        pass


class DeepFaceEmotionRecognizer(VisualEmotionRecognizer):
    """
    DeepFace-based implementation of VisualEmotionRecognizer.
    DeepFace has proven to be quite a pessimistic model, so expect the sad, fear and neutral
    emotions to be over-represented.
    """

    def __init__(self):
        self.load_model()

    def load_model(self):
        dummy_img = np.zeros((224, 224, 3), dtype=np.uint8)
        # analyze() does not take a model argument, so call it once on a dummy image
        # to load the model into memory
        DeepFace.analyze(dummy_img, actions=['emotion'], enforce_detection=False)

    def sorted_dominant_emotions(self, image) -> list[str]:
        analysis = DeepFace.analyze(image,
                                    actions=['emotion'],
                                    enforce_detection=False
                                    )
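
        # analysis is expected to be a list with one dict per detected face, roughly of the form
        #   {"dominant_emotion": "happy", "face_confidence": 0.97,
        #    "region": {"x": ..., "y": ..., "w": ..., "h": ...}, ...}
        # (only the keys used below are shown).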

        # Sort faces by x coordinate to maintain left-to-right order
        analysis.sort(key=lambda face: face['region']['x'])

        analysis = [face for face in analysis if face['face_confidence'] >= 0.90]

        dominant_emotions = [face['dominant_emotion'] for face in analysis]
        return dominant_emotions
@@ -18,7 +18,6 @@ from control_backend.schemas.belief_message import Belief, BeliefMessage
from control_backend.schemas.program import ConditionalNorm, Goal, Program
from control_backend.schemas.ri_message import (
    GestureCommand,
    PauseCommand,
    RIEndpoint,
    SpeechCommand,
)
@@ -398,34 +397,29 @@ class UserInterruptAgent(BaseAgent):
        self.logger.debug("Sending experiment control '%s' to BDI Core.", thread)
        await self.send(out_msg)

    async def _send_pause_command(self, pause):
    async def _send_pause_command(self, pause: str):
        """
        Send a pause command to the Robot Interface via the RI Communication Agent.
        Send a pause command to the other internal agents; for now just VAD agent.
        Send a pause command to the other internal agents; for now just the VAD and VED agents.
        """
        cmd = PauseCommand(data=pause)
        message = InternalMessage(
            to=settings.agent_settings.ri_communication_name,
            sender=self.name,
            body=cmd.model_dump_json(),
        )
        await self.send(message)

        if pause == "true":
            # Send pause to VAD agent
            # Send pause to VAD and VED agent
            vad_message = InternalMessage(
                to=settings.agent_settings.vad_name,
                to=[settings.agent_settings.vad_name,
                    settings.agent_settings.visual_emotion_recognition_name],
                sender=self.name,
                body="PAUSE",
            )
            await self.send(vad_message)
            self.logger.info("Sent pause command to VAD Agent and RI Communication Agent.")
            # Voice Activity Detection and Visual Emotion Recognition agents
            self.logger.info("Sent pause command to VAD and VED agents.")
        else:
            # Send resume to VAD agent
            # Send resume to VAD and VED agents
            vad_message = InternalMessage(
                to=settings.agent_settings.vad_name,
                to=[settings.agent_settings.vad_name,
                    settings.agent_settings.visual_emotion_recognition_name],
                sender=self.name,
                body="RESUME",
            )
            await self.send(vad_message)
            self.logger.info("Sent resume command to VAD Agent and RI Communication Agent.")
            # Voice Activity Detection and Visual Emotion Recognition agents
            self.logger.info("Sent resume command to VAD and VED agents.")
@@ -54,6 +54,7 @@ class AgentSettings(BaseModel):
    # agent names
    bdi_core_name: str = "bdi_core_agent"
    bdi_program_manager_name: str = "bdi_program_manager_agent"
    visual_emotion_recognition_name: str = "visual_emotion_recognition_agent"
    text_belief_extractor_name: str = "text_belief_extractor_agent"
    vad_name: str = "vad_agent"
    llm_name: str = "llm_agent"
@@ -81,6 +82,10 @@ class BehaviourSettings(BaseModel):
    :ivar transcription_words_per_token: Estimated words per token for transcription timing.
    :ivar transcription_token_buffer: Buffer for transcription tokens.
    :ivar conversation_history_length_limit: The maximum number of messages to extract beliefs from.
    :ivar visual_emotion_recognition_window_duration_s: Duration in seconds over which to aggregate
        emotions and update emotion beliefs.
    :ivar visual_emotion_recognition_min_frames_per_face: Minimum number of frames per face required
        to consider a face valid.
    :ivar trigger_time_to_wait: Number of milliseconds to wait before informing the UI about trigger
        completion.
    """
@@ -106,6 +111,9 @@ class BehaviourSettings(BaseModel):
    # Text belief extractor settings
    conversation_history_length_limit: int = 10

    # Visual Emotion Recognition settings
    visual_emotion_recognition_window_duration_s: int = 5
    visual_emotion_recognition_min_frames_per_face: int = 3
    # AgentSpeak related settings
    trigger_time_to_wait: int = 2000
    agentspeak_file: str = "src/control_backend/agents/bdi/agentspeak.asl"
@@ -41,8 +41,8 @@ class LogicalOperator(Enum):
    OR = "OR"


type Belief = KeywordBelief | SemanticBelief | InferredBelief
type BasicBelief = KeywordBelief | SemanticBelief
type Belief = KeywordBelief | SemanticBelief | InferredBelief | EmotionBelief
type BasicBelief = KeywordBelief | SemanticBelief | EmotionBelief


class KeywordBelief(ProgramElement):
@@ -105,6 +105,15 @@ class InferredBelief(ProgramElement):
    left: Belief
    right: Belief


class EmotionBelief(ProgramElement):
    """
    Represents a belief that is set when a certain emotion is detected.

    :ivar emotion: The emotion on which this belief gets set.
    """

    name: str = ""
    emotion: str


class Norm(ProgramElement):
    """
@@ -315,3 +324,9 @@ class Program(BaseModel):
    """

    phases: list[Phase]


if __name__ == "__main__":
    input = input("Enter program JSON: ")
    program = Program.model_validate_json(input)
    print(program)