feat: fully implement visual emotion recognition agent in pipeline
ref: N25B-393
```diff
@@ -158,6 +158,9 @@ class BDICoreAgent(BaseAgent):
         for belief in beliefs:
             if belief.replace:
                 self._remove_all_with_name(belief.name)
+            elif belief.remove:
+                self._remove_belief(belief.name, belief.arguments)
+                continue
             self._add_belief(belief.name, belief.arguments)

     def _add_belief(self, name: str, args: Iterable[str] = []):
```
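This hunk gives the BDI core three belief operations: `replace` drops every belief sharing the functor before re-adding, the new `remove` deletes one specific instance and adds nothing (hence the `continue`), and the default path simply adds. A minimal self-contained sketch of that dispatch, with a plain list of `(name, arguments)` tuples standing in for the agent's belief base:

```python
# Standalone sketch of the dispatch above; a list of (name, args) tuples
# stands in for the agent's belief base and its Belief objects.
belief_base: list[tuple[str, tuple[str, ...]]] = []

def apply_belief(name: str, arguments: list[str],
                 replace: bool = False, remove: bool = False) -> None:
    global belief_base
    if replace:
        # Drop every belief with this functor, then fall through to add.
        belief_base = [b for b in belief_base if b[0] != name]
    elif remove:
        # Remove only the exact (name, arguments) instance; add nothing.
        belief_base = [b for b in belief_base if b != (name, tuple(arguments))]
        return
    belief_base.append((name, tuple(arguments)))

apply_belief("emotion", ["happy"])
apply_belief("emotion", ["sad"])
apply_belief("emotion", ["happy"], remove=True)
assert belief_base == [("emotion", ("sad",))]
```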
```diff
@@ -7,6 +7,9 @@ from zmq.asyncio import Context

 from control_backend.agents import BaseAgent
 from control_backend.agents.actuation.robot_gesture_agent import RobotGestureAgent
+from control_backend.agents.perception.visual_emotion_detection_agent.visual_emotion_recognition_agent import (
+    VisualEmotionRecognitionAgent,
+)
 from control_backend.core.config import settings

 from ..actuation.robot_speech_agent import RobotSpeechAgent
@@ -201,6 +204,13 @@ class RICommunicationAgent(BaseAgent):
             case "audio":
                 vad_agent = VADAgent(audio_in_address=addr, audio_in_bind=bind)
                 await vad_agent.start()
+            case "video":
+                visual_emotion_agent = VisualEmotionRecognitionAgent(
+                    settings.agent_settings.visual_emotion_recognition_name,
+                    socket_address=addr,
+                    bind=bind,
+                )
+                await visual_emotion_agent.start()
             case _:
                 self.logger.warning("Unhandled negotiation id: %s", id)
```
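The `video` case mirrors the existing `audio` one: once a stream is negotiated, a dedicated agent is spawned on the negotiated endpoint. The diff does not show how the spawned agent opens that endpoint; a sketch under the assumption of a ZMQ PUB/SUB video stream, where `bind` decides which side binds:

```python
import zmq
import zmq.asyncio as azmq

async def open_video_socket(addr: str, bind: bool) -> azmq.Socket:
    # Assumed PUB/SUB transport: the diff only threads addr/bind through,
    # so the socket type and roles here are illustrative.
    ctx = azmq.Context.instance()
    sock = ctx.socket(zmq.SUB)
    sock.setsockopt_string(zmq.SUBSCRIBE, "")  # subscribe to all frames
    if bind:
        sock.bind(addr)
    else:
        sock.connect(addr)
    return sock
```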
```diff
@@ -1,23 +1,28 @@
-import asyncio
+import json
+import time
+from collections import Counter, defaultdict
+
+import cv2
+import numpy as np
+from pydantic_core import ValidationError
 import zmq
 import zmq.asyncio as azmq
-import numpy as np
-import cv2
-from collections import defaultdict, Counter
-import time

 from control_backend.agents import BaseAgent
-from control_backend.agents.perception.visual_emotion_detection_agent.visual_emotion_recognizer import DeepFaceEmotionRecognizer
+from control_backend.agents.perception.visual_emotion_detection_agent.visual_emotion_recognizer import (
+    DeepFaceEmotionRecognizer,
+)
 from control_backend.core.agent_system import InternalMessage
 from control_backend.core.config import settings
+from control_backend.schemas.belief_message import Belief

 # START FROM RI COMMUNICATION AGENT?

 class VisualEmotionRecognitionAgent(BaseAgent):
-    def __init__(self, socket_address: str, socket_bind: bool = False, timeout_ms: int = 1000):
-        super().__init__(settings.agent_settings.visual_emotion_recognition_name)
+    def __init__(self, name, socket_address: str, bind: bool = False, timeout_ms: int = 1000):
+        super().__init__(name)
         self.socket_address = socket_address
-        self.socket_bind = socket_bind
+        self.socket_bind = bind
         self.timeout_ms = timeout_ms

     async def setup(self):
@@ -41,8 +46,6 @@ class VisualEmotionRecognitionAgent(BaseAgent):
     async def emotion_update_loop(self):
         """
         Retrieve a video frame from the input socket.
-
-        :return: The received video frame, or None if timeout occurs.
         """
         window_duration = 1  # seconds
         next_window_time = time.time() + window_duration
@@ -70,7 +73,7 @@ class VisualEmotionRecognitionAgent(BaseAgent):
                 if frame_image is None:
                     # Could not decode image, skip this frame
                     continue

                 # Get the dominant emotion from each face
                 current_emotions = self.emotion_recognizer.sorted_dominant_emotions(frame_image)
                 # Update emotion counts for each detected face
```
```diff
@@ -90,7 +93,6 @@ class VisualEmotionRecognitionAgent(BaseAgent):
                     window_dominant_emotions.add(dominant_emotion)

                 await self.update_emotions(prev_dominant_emotions, window_dominant_emotions)
-
                 prev_dominant_emotions = window_dominant_emotions
                 face_stats.clear()
                 next_window_time = time.time() + window_duration
```
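Taken together, the loop decodes each frame, tallies per-face emotions, and once per one-second window promotes each face's most frequent emotion into the window's dominant set before the counters roll over. An offline sketch of that windowing, assuming `face_stats` maps a face index to a `Counter` of emotion labels (the hunks only show the rollover):

```python
import time
from collections import Counter, defaultdict

def replay_windows(observations: list[tuple[int, str]]) -> set[str]:
    """Replay (face_id, emotion) observations through the 1 s window scheme."""
    window_duration = 1  # seconds, as in emotion_update_loop
    next_window_time = time.time() + window_duration
    face_stats: defaultdict[int, Counter] = defaultdict(Counter)
    prev_dominant: set[str] = set()
    for face_id, emotion in observations:
        face_stats[face_id][emotion] += 1
        if time.time() >= next_window_time:
            # Promote each face's most frequent emotion this window.
            window_dominant = {stats.most_common(1)[0][0] for stats in face_stats.values()}
            # update_emotions(prev_dominant, window_dominant) would be awaited here.
            prev_dominant = window_dominant
            face_stats.clear()
            next_window_time = time.time() + window_duration
    return prev_dominant
```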
```diff
@@ -98,14 +100,40 @@ class VisualEmotionRecognitionAgent(BaseAgent):
             except zmq.Again:
                 self.logger.warning("No video frame received within timeout.")

-    async def update_emotions(self, prev_emotions, emotions):
+    async def update_emotions(self, prev_emotions: set[str], emotions: set[str]):
+        """
+        Compare emotions from previous window and current emotions,
+        send updates to BDI Core Agent.
+        """
         # Remove emotions that are no longer present
         emotions_to_remove = prev_emotions - emotions
+        new_emotions = emotions - prev_emotions
+
+        if not new_emotions and not emotions_to_remove:
+            return
+
+        emotion_beliefs = []
+        # Remove emotions that have disappeared
         for emotion in emotions_to_remove:
             self.logger.info(f"Emotion '{emotion}' has disappeared.")
+            try:
+                emotion_beliefs.append(Belief(name="emotion", arguments=[emotion], remove=True))
+            except ValidationError:
+                self.logger.warning("Invalid belief for emotion removal: %s", emotion)

         # Add new emotions that have appeared
-        new_emotions = emotions - prev_emotions
         for emotion in new_emotions:
             self.logger.info(f"New emotion detected: '{emotion}'")
+            try:
+                emotion_beliefs.append(Belief(name="emotion", arguments=[emotion]))
+            except ValidationError:
+                self.logger.warning("Invalid belief for new emotion: %s", emotion)

+        message = InternalMessage(
+            to=settings.agent_settings.bdi_core_name,
+            sender=self.name,
+            body=json.dumps(emotion_beliefs),
+            thread="beliefs",
+        )
+        self.logger.debug("Sending emotion beliefs update: %s", emotion_beliefs)
+        await self.send(message)
```
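`update_emotions` reduces the two windows to set differences and ships one batch of `Belief` updates on the `beliefs` thread. Note that pydantic models are not directly `json.dumps`-serializable, so the `body=json.dumps(emotion_beliefs)` line implies a dump step somewhere; the sketch below makes that step explicit with `model_dump`, which is an assumption rather than the committed code:

```python
import json
from pydantic import BaseModel

class Belief(BaseModel):  # mirrors control_backend.schemas.belief_message.Belief
    name: str
    arguments: list[str]
    replace: bool = False
    remove: bool = False

prev_window = {"happy", "neutral"}
curr_window = {"neutral", "surprise"}

beliefs = [Belief(name="emotion", arguments=[e], remove=True)
           for e in prev_window - curr_window]   # disappeared: happy
beliefs += [Belief(name="emotion", arguments=[e])
            for e in curr_window - prev_window]  # appeared: surprise

body = json.dumps([b.model_dump() for b in beliefs])  # explicit serialization
print(body)
```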
```diff
@@ -1,7 +1,8 @@
 import abc
-from deepface import DeepFace
+
 import numpy as np
-from collections import Counter
+from deepface import DeepFace


 class VisualEmotionRecognizer(abc.ABC):
     @abc.abstractmethod
@@ -42,7 +43,6 @@ class DeepFaceEmotionRecognizer(VisualEmotionRecognizer):

         analysis = [face for face in analysis if face['face_confidence'] >= 0.90]

-        # Return list of (dominant_emotion, face_confidence) tuples
         dominant_emotions = [face['dominant_emotion'] for face in analysis]
         return dominant_emotions
```
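The recognizer keeps only faces DeepFace detected with at least 0.90 confidence and returns their dominant emotions. A minimal sketch of the surrounding method, assuming the analysis comes from `DeepFace.analyze` with `actions=["emotion"]` and `enforce_detection=False` (the call itself is outside this hunk):

```python
import numpy as np
from deepface import DeepFace

def dominant_emotions_for_frame(frame: np.ndarray) -> list[str]:
    # enforce_detection=False keeps analyze from raising when no face is
    # visible; whether the committed code passes it is an assumption.
    analysis = DeepFace.analyze(frame, actions=["emotion"], enforce_detection=False)
    confident = [face for face in analysis if face["face_confidence"] >= 0.90]
    return [face["dominant_emotion"] for face in confident]
```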
```diff
@@ -40,7 +40,9 @@ from control_backend.agents.communication import RICommunicationAgent
 from control_backend.agents.llm import LLMAgent

 # User Interrupt Agent
-from control_backend.agents.perception.visual_emotion_detection_agent.visual_emotion_recognition_agent import VisualEmotionRecognitionAgent
+from control_backend.agents.perception.visual_emotion_detection_agent.visual_emotion_recognition_agent import (
+    VisualEmotionRecognitionAgent,
+)
 from control_backend.agents.user_interrupt.user_interrupt_agent import UserInterruptAgent

 # Other backend imports
@@ -148,13 +150,6 @@ async def lifespan(app: FastAPI):
                 "name": settings.agent_settings.user_interrupt_name,
             },
         ),
-        # TODO: Spawn agent from RI Communication Agent
-        "VisualEmotionRecognitionAgent": (
-            VisualEmotionRecognitionAgent,
-            {
-                "socket_address": "tcp://localhost:5556",  # TODO: move to settings
-            },
-        ),
     }

     agents = []
```
```diff
@@ -8,11 +8,13 @@ class Belief(BaseModel):
     :ivar name: The functor or name of the belief (e.g., 'user_said').
     :ivar arguments: A list of string arguments for the belief.
     :ivar replace: If True, existing beliefs with this name should be replaced by this one.
+    :ivar remove: If True, this belief should be removed from the belief base.
     """

     name: str
     arguments: list[str]
     replace: bool = False
+    remove: bool = False


 class BeliefMessage(BaseModel):
```
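On the receiving side, the new `remove` flag round-trips through plain JSON; a quick parsing sketch assuming pydantic v2's `model_validate_json`:

```python
from pydantic import BaseModel

class Belief(BaseModel):  # as defined above
    name: str
    arguments: list[str]
    replace: bool = False
    remove: bool = False

raw = '{"name": "emotion", "arguments": ["happy"], "remove": true}'
belief = Belief.model_validate_json(raw)
assert belief.remove is True and belief.replace is False
```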