Merge remote-tracking branch 'refs/remotes/origin/feat/visual-emotion-recognition' into feat/add-experiment-logs
@@ -7,6 +7,7 @@ requires-python = ">=3.13"
 dependencies = [
     "agentspeak>=0.2.2",
     "colorlog>=6.10.1",
+    "deepface>=0.0.96",
     "fastapi[all]>=0.115.6",
     "mlx-whisper>=0.4.3 ; sys_platform == 'darwin'",
     "numpy>=2.3.3",
@@ -21,6 +22,7 @@ dependencies = [
     "silero-vad>=6.0.0",
     "sphinx>=7.3.7",
     "sphinx-rtd-theme>=3.0.2",
+    "tf-keras>=2.20.1",
     "torch>=2.8.0",
     "uvicorn>=0.37.0",
 ]
@@ -22,6 +22,7 @@ from control_backend.schemas.program import (
     BaseGoal,
     BasicNorm,
     ConditionalNorm,
+    EmotionBelief,
     GestureAction,
     Goal,
     InferredBelief,
@@ -459,6 +460,10 @@ class AgentSpeakGenerator:
     @_astify.register
     def _(self, sb: SemanticBelief) -> AstExpression:
         return AstLiteral(self.slugify(sb))
 
+    @_astify.register
+    def _(self, eb: EmotionBelief) -> AstExpression:
+        return AstLiteral("emotion_detected", [AstAtom(eb.emotion)])
+
     @_astify.register
     def _(self, ib: InferredBelief) -> AstExpression:
@@ -338,7 +338,7 @@ class BDICoreAgent(BaseAgent):
             yield
 
         @self.actions.add(".reply_with_goal", 3)
-        def _reply_with_goal(agent: "BDICoreAgent", term, intention):
+        def _reply_with_goal(agent, term, intention):
            """
            Let the LLM generate a response to a user's utterance with the current norms and a
            specific goal.
@@ -512,10 +512,6 @@ class BDICoreAgent(BaseAgent):
 
             yield
 
-        @self.actions.add(".notify_ui", 0)
-        def _notify_ui(agent, term, intention):
-            pass
-
    async def _send_to_llm(self, text: str, norms: str, goals: str):
        """
        Sends a text query to the LLM agent asynchronously.
@@ -318,6 +318,9 @@ class TextBeliefExtractorAgent(BaseAgent):
         async with httpx.AsyncClient() as client:
             response = await client.post(
                 settings.llm_settings.local_llm_url,
+                headers={"Authorization": f"Bearer {settings.llm_settings.api_key}"}
+                if settings.llm_settings.api_key
+                else {},
                 json={
                     "model": settings.llm_settings.local_llm_model,
                     "messages": [{"role": "user", "content": prompt}],
@@ -8,6 +8,9 @@ from zmq.asyncio import Context
 
 from control_backend.agents import BaseAgent
 from control_backend.agents.actuation.robot_gesture_agent import RobotGestureAgent
+from control_backend.agents.perception.visual_emotion_recognition_agent.visual_emotion_recognition_agent import (  # noqa
+    VisualEmotionRecognitionAgent,
+)
 from control_backend.core.config import settings
 from control_backend.schemas.internal_message import InternalMessage
 from control_backend.schemas.ri_message import PauseCommand
@@ -209,6 +212,13 @@ class RICommunicationAgent(BaseAgent):
             case "audio":
                 vad_agent = VADAgent(audio_in_address=addr, audio_in_bind=bind)
                 await vad_agent.start()
+            case "video":
+                visual_emotion_agent = VisualEmotionRecognitionAgent(
+                    settings.agent_settings.visual_emotion_recognition_name,
+                    socket_address=addr,
+                    bind=bind,
+                )
+                await visual_emotion_agent.start()
             case _:
                 self.logger.warning("Unhandled negotiation id: %s", id)
 
@@ -1,3 +1,4 @@
+import asyncio
 import json
 import re
 import uuid
@@ -32,6 +33,10 @@ class LLMAgent(BaseAgent):
     def __init__(self, name: str):
         super().__init__(name)
         self.history = []
+        self._querying = False
+        self._interrupted = False
+        self._interrupted_message = ""
+        self._go_ahead = asyncio.Event()
 
     async def setup(self):
         self.logger.info("Setting up %s.", self.name)
@@ -50,13 +55,13 @@ class LLMAgent(BaseAgent):
             case "prompt_message":
                 try:
                     prompt_message = LLMPromptMessage.model_validate_json(msg.body)
-                    await self._process_bdi_message(prompt_message)
+                    self.add_behavior(self._process_bdi_message(prompt_message))  # no block
                 except ValidationError:
                     self.logger.debug("Prompt message from BDI core is invalid.")
             case "assistant_message":
-                self.history.append({"role": "assistant", "content": msg.body})
+                self._apply_conversation_message({"role": "assistant", "content": msg.body})
             case "user_message":
-                self.history.append({"role": "user", "content": msg.body})
+                self._apply_conversation_message({"role": "user", "content": msg.body})
         elif msg.sender == settings.agent_settings.bdi_program_manager_name:
             if msg.body == "clear_history":
                 self.logger.debug("Clearing conversation history.")
@@ -73,12 +78,45 @@ class LLMAgent(BaseAgent):
 
        :param message: The parsed prompt message containing text, norms, and goals.
        """
+        if self._querying:
+            self.logger.debug("Received another BDI prompt while processing previous message.")
+            self._interrupted = True  # interrupt the previous processing
+            await self._go_ahead.wait()  # wait until we get the go-ahead
+
+            message.text = f"{self._interrupted_message} {message.text}"
+
+        self._go_ahead.clear()
+        self._querying = True
        full_message = ""
        async for chunk in self._query_llm(message.text, message.norms, message.goals):
+            if self._interrupted:
+                self._interrupted_message = message.text
+                self.logger.debug("Interrupted processing of previous message.")
+                break
            await self._send_reply(chunk)
            full_message += chunk
-        self.logger.debug("Finished processing BDI message. Response sent in chunks to BDI core.")
-        await self._send_full_reply(full_message)
+        else:
+            self._querying = False
+
+            self._apply_conversation_message(
+                {
+                    "role": "assistant",
+                    "content": full_message,
+                }
+            )
+            self.logger.debug(
+                "Finished processing BDI message. Response sent in chunks to BDI core."
+            )
+            await self._send_full_reply(full_message)
+
+        self._go_ahead.set()
+        self._interrupted = False
+
+    def _apply_conversation_message(self, message: dict[str, str]):
+        if len(self.history) > 0 and message["role"] == self.history[-1]["role"]:
+            self.history[-1]["content"] += " " + message["content"]
+            return
+        self.history.append(message)
 
    async def _send_reply(self, msg: str):
        """
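Note on the hunk above: the interruption flow hinges on a single asyncio.Event handing control between an in-flight prompt and the prompt that interrupts it. The following standalone sketch reproduces that handshake outside the agent; InterruptibleWorker and its word-splitting stand-in for the LLM stream are illustrative inventions, not project API.

import asyncio


class InterruptibleWorker:
    """Toy version of the LLMAgent handshake: a new request interrupts
    the running one, then prepends the interrupted text to its own."""

    def __init__(self):
        self._querying = False
        self._interrupted = False
        self._interrupted_message = ""
        self._go_ahead = asyncio.Event()

    async def process(self, text: str) -> str:
        if self._querying:
            self._interrupted = True     # signal the running request
            await self._go_ahead.wait()  # wait until it winds down
            text = f"{self._interrupted_message} {text}"

        self._go_ahead.clear()
        self._querying = True
        out = ""
        for chunk in text.split():       # stand-in for the streamed LLM reply
            if self._interrupted:
                self._interrupted_message = text
                break
            out += chunk + " "
            await asyncio.sleep(0.01)
        else:
            self._querying = False       # only reached when not interrupted

        self._go_ahead.set()             # release any waiting request
        self._interrupted = False
        return out.strip()


async def main():
    worker = InterruptibleWorker()
    first = asyncio.create_task(worker.process("tell me a long story"))
    await asyncio.sleep(0.02)
    # The second request interrupts the first and absorbs its text.
    second = await worker.process("actually, keep it short")
    await first
    print(second)


asyncio.run(main())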
@@ -159,13 +197,6 @@ class LLMAgent(BaseAgent):
            # Yield any remaining tail
            if current_chunk:
                yield current_chunk
-
-            self.history.append(
-                {
-                    "role": "assistant",
-                    "content": full_message,
-                }
-            )
        except httpx.HTTPError as err:
            self.logger.error("HTTP error.", exc_info=err)
            yield "LLM service unavailable."
@@ -185,6 +216,9 @@ class LLMAgent(BaseAgent):
             async with client.stream(
                 "POST",
                 settings.llm_settings.local_llm_url,
+                headers={"Authorization": f"Bearer {settings.llm_settings.api_key}"}
+                if settings.llm_settings.api_key
+                else {},
                 json={
                     "model": settings.llm_settings.local_llm_model,
                     "messages": messages,
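Both this call site and the one in TextBeliefExtractorAgent attach an Authorization header only when an API key is configured, falling back to an empty dict. A quick standalone check of that conditional-expression pattern (build_headers is a hypothetical helper, not part of the codebase):

def build_headers(api_key: str) -> dict[str, str]:
    # Same conditional expression the diff inlines into the httpx calls.
    return {"Authorization": f"Bearer {api_key}"} if api_key else {}


assert build_headers("") == {}
assert build_headers("secret-token") == {"Authorization": "Bearer secret-token"}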
@@ -145,4 +145,6 @@ class OpenAIWhisperSpeechRecognizer(SpeechRecognizer):
 
     def recognize_speech(self, audio: np.ndarray) -> str:
         self.load_model()
-        return whisper.transcribe(self.model, audio, **self._get_decode_options(audio))["text"]
+        return whisper.transcribe(self.model, audio, **self._get_decode_options(audio))[
+            "text"
+        ].strip()
@@ -0,0 +1,166 @@
+import json
+import time
+from collections import Counter, defaultdict
+
+import cv2
+import numpy as np
+import zmq
+import zmq.asyncio as azmq
+from pydantic_core import ValidationError
+
+from control_backend.agents import BaseAgent
+from control_backend.agents.perception.visual_emotion_recognition_agent.visual_emotion_recognizer import (  # noqa
+    DeepFaceEmotionRecognizer,
+)
+from control_backend.core.agent_system import InternalMessage
+from control_backend.core.config import settings
+from control_backend.schemas.belief_message import Belief
+
+
+class VisualEmotionRecognitionAgent(BaseAgent):
+    def __init__(
+        self,
+        name: str,
+        socket_address: str,
+        bind: bool = False,
+        timeout_ms: int = 1000,
+        window_duration: int = settings.behaviour_settings.visual_emotion_recognition_window_duration_s,  # noqa
+        min_frames_required: int = settings.behaviour_settings.visual_emotion_recognition_min_frames_per_face,  # noqa
+    ):
+        """
+        Initialize the Visual Emotion Recognition Agent.
+
+        :param name: Name of the agent
+        :param socket_address: Address of the socket to connect or bind to
+        :param bind: Whether to bind to the socket address (True) or connect (False)
+        :param timeout_ms: Timeout for socket receive operations in milliseconds
+        :param window_duration: Duration in seconds over which to aggregate emotions
+        :param min_frames_required: Minimum number of frames per face required to consider a face
+            valid
+        """
+        super().__init__(name)
+        self.socket_address = socket_address
+        self.socket_bind = bind
+        self.timeout_ms = timeout_ms
+        self.window_duration = window_duration
+        self.min_frames_required = min_frames_required
+
+    async def setup(self):
+        """
+        Initialize the agent resources.
+
+        1. Initializes the :class:`VisualEmotionRecognizer`.
+        2. Connects to the video input ZMQ socket.
+        3. Starts the background emotion recognition loop.
+        """
+        self.logger.info("Setting up %s.", self.name)
+
+        self.emotion_recognizer = DeepFaceEmotionRecognizer()
+
+        self.video_in_socket = azmq.Context.instance().socket(zmq.SUB)
+
+        if self.socket_bind:
+            self.video_in_socket.bind(self.socket_address)
+        else:
+            self.video_in_socket.connect(self.socket_address)
+
+        self.video_in_socket.setsockopt_string(zmq.SUBSCRIBE, "")
+        self.video_in_socket.setsockopt(zmq.RCVTIMEO, self.timeout_ms)
+        self.video_in_socket.setsockopt(zmq.CONFLATE, 1)
+
+        self.add_behavior(self.emotion_update_loop())
+
+    async def emotion_update_loop(self):
+        """
+        Background loop to receive video frames, recognize emotions, and update beliefs.
+
+        1. Receives video frames from the ZMQ socket.
+        2. Uses the :class:`VisualEmotionRecognizer` to detect emotions.
+        3. Aggregates emotions over a time window.
+        4. Sends updates to the BDI Core Agent about detected emotions.
+        """
+        # Next time to process the window and update emotions
+        next_window_time = time.time() + self.window_duration
+
+        # Tracks counts of detected emotions per face index
+        face_stats = defaultdict(Counter)
+
+        prev_dominant_emotions = set()
+
+        while self._running:
+            try:
+                frame_bytes = await self.video_in_socket.recv()
+
+                # Convert bytes to a numpy buffer
+                nparr = np.frombuffer(frame_bytes, np.uint8)
+
+                # Decode image into the generic Numpy Array DeepFace expects
+                frame_image = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
+
+                if frame_image is None:
+                    # Could not decode image, skip this frame
+                    continue
+
+                # Get the dominant emotion from each face
+                current_emotions = self.emotion_recognizer.sorted_dominant_emotions(frame_image)
+                # Update emotion counts for each detected face
+                for i, emotion in enumerate(current_emotions):
+                    face_stats[i][emotion] += 1
+
+                # If window duration has passed, process the collected stats
+                if time.time() >= next_window_time:
+                    window_dominant_emotions = set()
+                    # Determine dominant emotion for each face in the window
+                    for _, counter in face_stats.items():
+                        total_detections = sum(counter.values())
+
+                        if total_detections >= self.min_frames_required:
+                            dominant_emotion = counter.most_common(1)[0][0]
+                            window_dominant_emotions.add(dominant_emotion)
+
+                    await self.update_emotions(prev_dominant_emotions, window_dominant_emotions)
+                    prev_dominant_emotions = window_dominant_emotions
+                    face_stats.clear()
+                    next_window_time = time.time() + self.window_duration
+
+            except zmq.Again:
+                self.logger.warning("No video frame received within timeout.")
+
+    async def update_emotions(self, prev_emotions: set[str], emotions: set[str]):
+        """
+        Compare emotions from previous window and current emotions,
+        send updates to BDI Core Agent.
+        """
+        emotions_to_remove = prev_emotions - emotions
+        emotions_to_add = emotions - prev_emotions
+
+        if not emotions_to_add and not emotions_to_remove:
+            return
+
+        emotion_beliefs_remove = []
+        for emotion in emotions_to_remove:
+            self.logger.info(f"Emotion '{emotion}' has disappeared.")
+            try:
+                emotion_beliefs_remove.append(
+                    Belief(name="emotion_detected", arguments=[emotion], remove=True)
+                )
+            except ValidationError:
+                self.logger.warning("Invalid belief for emotion removal: %s", emotion)
+
+        emotion_beliefs_add = []
+        for emotion in emotions_to_add:
+            self.logger.info(f"New emotion detected: '{emotion}'")
+            try:
+                emotion_beliefs_add.append(Belief(name="emotion_detected", arguments=[emotion]))
+            except ValidationError:
+                self.logger.warning("Invalid belief for new emotion: %s", emotion)
+
+        beliefs_list_add = [b.model_dump() for b in emotion_beliefs_add]
+        beliefs_list_remove = [b.model_dump() for b in emotion_beliefs_remove]
+        payload = {"create": beliefs_list_add, "delete": beliefs_list_remove}
+
+        message = InternalMessage(
+            to=settings.agent_settings.bdi_core_name,
+            sender=self.name,
+            body=json.dumps(payload),
+            thread="beliefs",
+        )
+        await self.send(message)
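The core of emotion_update_loop above is a windowed majority vote: per-frame dominant emotions are tallied per face index, and a face contributes its most common emotion only if it was detected in enough frames. A minimal sketch of just that aggregation step, with made-up frame data:

from collections import Counter, defaultdict

# Per-frame dominant emotions, ordered left-to-right per face (illustrative data).
frames = [
    ["happy", "sad"],
    ["happy", "neutral"],
    ["happy", "sad"],
    ["surprise"],  # second face not detected in this frame
]

MIN_FRAMES_REQUIRED = 3

# Tally detections per face index, as the agent's loop does.
face_stats: defaultdict[int, Counter] = defaultdict(Counter)
for frame in frames:
    for i, emotion in enumerate(frame):
        face_stats[i][emotion] += 1

# Keep the majority emotion of every face that was seen often enough.
window_dominant_emotions = set()
for counter in face_stats.values():
    if sum(counter.values()) >= MIN_FRAMES_REQUIRED:
        window_dominant_emotions.add(counter.most_common(1)[0][0])

print(window_dominant_emotions)  # {'happy', 'sad'}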
@@ -0,0 +1,55 @@
+import abc
+
+import numpy as np
+from deepface import DeepFace
+
+
+class VisualEmotionRecognizer(abc.ABC):
+    @abc.abstractmethod
+    def load_model(self):
+        """Load the visual emotion recognition model into memory."""
+        pass
+
+    @abc.abstractmethod
+    def sorted_dominant_emotions(self, image) -> list[str]:
+        """
+        Recognize dominant emotions from faces in the given image.
+        Emotions can be one of ['angry', 'disgust', 'fear', 'happy', 'sad', 'surprise', 'neutral'].
+        To minimize false positives, consider filtering faces with low confidence.
+
+        :param image: The input image for emotion recognition.
+        :return: List of dominant emotion detected for each face in the image,
+            sorted per face.
+        """
+        pass
+
+
+class DeepFaceEmotionRecognizer(VisualEmotionRecognizer):
+    """
+    DeepFace-based implementation of VisualEmotionRecognizer.
+    DeepFace has proven to be quite a pessimistic model, so expect sad, fear and neutral
+    emotions to be over-represented.
+    """
+    def __init__(self):
+        self.load_model()
+
+    def load_model(self):
+        print("Loading Deepface Emotion Model...")
+        dummy_img = np.zeros((224, 224, 3), dtype=np.uint8)
+        # analyze does not take a model as an argument, calling it once on a dummy image to load
+        # the model
+        DeepFace.analyze(dummy_img, actions=['emotion'], enforce_detection=False)
+        print("Deepface Emotion Model loaded.")
+
+    def sorted_dominant_emotions(self, image) -> list[str]:
+        analysis = DeepFace.analyze(image,
+                                    actions=['emotion'],
+                                    enforce_detection=False
+                                    )
+
+        # Sort faces by x coordinate to maintain left-to-right order
+        analysis.sort(key=lambda face: face['region']['x'])
+
+        analysis = [face for face in analysis if face['face_confidence'] >= 0.90]
+
+        dominant_emotions = [face['dominant_emotion'] for face in analysis]
+        return dominant_emotions
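For reference, DeepFace.analyze with actions=['emotion'] returns one dict per detected face, which is what the sorting and confidence filtering above rely on. A rough standalone usage sketch (the black dummy frame is only for a dry run; enforce_detection=False keeps analyze from raising when no face is found):

import numpy as np
from deepface import DeepFace

# A webcam frame would normally come from cv2; a black dummy image works for a dry run.
frame = np.zeros((224, 224, 3), dtype=np.uint8)

# One result dict per face: 'region', 'face_confidence', 'emotion', 'dominant_emotion'.
results = DeepFace.analyze(frame, actions=["emotion"], enforce_detection=False)

for face in results:
    if face["face_confidence"] < 0.90:
        continue  # same confidence cut-off the recognizer above applies
    print(face["region"]["x"], face["dominant_emotion"])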
@@ -50,6 +50,7 @@ class AgentSettings(BaseModel):
     # agent names
     bdi_core_name: str = "bdi_core_agent"
     bdi_program_manager_name: str = "bdi_program_manager_agent"
+    visual_emotion_recognition_name: str = "visual_emotion_recognition_agent"
     text_belief_extractor_name: str = "text_belief_extractor_agent"
     vad_name: str = "vad_agent"
     llm_name: str = "llm_agent"
@@ -77,6 +78,10 @@ class BehaviourSettings(BaseModel):
     :ivar transcription_words_per_token: Estimated words per token for transcription timing.
     :ivar transcription_token_buffer: Buffer for transcription tokens.
     :ivar conversation_history_length_limit: The maximum amount of messages to extract beliefs from.
+    :ivar visual_emotion_recognition_window_duration_s: Duration in seconds over which to aggregate
+        emotions and update emotion beliefs.
+    :ivar visual_emotion_recognition_min_frames_per_face: Minimum number of frames per face required
+        to consider a face valid.
     """
 
     # ATTENTION: When adding/removing settings, make sure to update the .env.example file
@@ -100,6 +105,9 @@ class BehaviourSettings(BaseModel):
     # Text belief extractor settings
     conversation_history_length_limit: int = 10
 
+    # Visual Emotion Recognition settings
+    visual_emotion_recognition_window_duration_s: int = 5
+    visual_emotion_recognition_min_frames_per_face: int = 3
 
 class LLMSettings(BaseModel):
     """
@@ -117,6 +125,7 @@ class LLMSettings(BaseModel):
 
     local_llm_url: str = "http://localhost:1234/v1/chat/completions"
     local_llm_model: str = "gpt-oss"
+    api_key: str = ""
     chat_temperature: float = 1.0
     code_temperature: float = 0.3
     n_parallel: int = 4
@@ -28,8 +28,8 @@ class LogicalOperator(Enum):
     OR = "OR"
 
 
-type Belief = KeywordBelief | SemanticBelief | InferredBelief
-type BasicBelief = KeywordBelief | SemanticBelief
+type Belief = KeywordBelief | SemanticBelief | InferredBelief | EmotionBelief
+type BasicBelief = KeywordBelief | SemanticBelief | EmotionBelief
 
 
 class KeywordBelief(ProgramElement):
@@ -69,6 +69,15 @@ class InferredBelief(ProgramElement):
     left: Belief
     right: Belief
 
+class EmotionBelief(ProgramElement):
+    """
+    Represents a belief that is set when a certain emotion is detected.
+
+    :ivar emotion: The emotion on which this belief gets set.
+    """
+
+    name: str = ""
+    emotion: str
 
 class Norm(ProgramElement):
     """
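To connect this schema with the AgentSpeakGenerator hunk earlier in the diff: an EmotionBelief carrying emotion 'happy' ends up as the AgentSpeak literal emotion_detected(happy), the same shape the visual agent publishes on the beliefs thread. A hypothetical plain-text rendering of that mapping (the real generator returns AstLiteral/AstAtom nodes, not strings):

from dataclasses import dataclass


@dataclass
class EmotionBelief:
    emotion: str
    name: str = ""


def astify_emotion_belief(eb: EmotionBelief) -> str:
    # Mirrors AstLiteral("emotion_detected", [AstAtom(eb.emotion)]) as plain text.
    return f"emotion_detected({eb.emotion})"


print(astify_emotion_belief(EmotionBelief(emotion="happy")))
# -> emotion_detected(happy)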
@@ -226,3 +235,9 @@ class Program(BaseModel):
     """
 
     phases: list[Phase]
+
+
+if __name__ == "__main__":
+    input = input("Enter program JSON: ")
+    program = Program.model_validate_json(input)
+    print(program)
@@ -61,8 +61,52 @@ async def test_llm_processing_success(mock_httpx_client, mock_settings):
         thread="prompt_message",  # REQUIRED: thread must match handle_message logic
     )
 
+    agent._process_bdi_message = AsyncMock()
+
     await agent.handle_message(msg)
 
+    agent._process_bdi_message.assert_called()
+
+
+@pytest.mark.asyncio
+async def test_process_bdi_message_success(mock_httpx_client, mock_settings):
+    # Setup the mock response for the stream
+    mock_response = MagicMock()
+    mock_response.raise_for_status = MagicMock()
+
+    # Simulate stream lines
+    lines = [
+        b'data: {"choices": [{"delta": {"content": "Hello"}}]}',
+        b'data: {"choices": [{"delta": {"content": " world"}}]}',
+        b'data: {"choices": [{"delta": {"content": "."}}]}',
+        b"data: [DONE]",
+    ]
+
+    async def aiter_lines_gen():
+        for line in lines:
+            yield line.decode()
+
+    mock_response.aiter_lines.side_effect = aiter_lines_gen
+
+    mock_stream_context = MagicMock()
+    mock_stream_context.__aenter__ = AsyncMock(return_value=mock_response)
+    mock_stream_context.__aexit__ = AsyncMock(return_value=None)
+
+    # Configure the client
+    mock_httpx_client.stream = MagicMock(return_value=mock_stream_context)
+
+    # Setup Agent
+    agent = LLMAgent("llm_agent")
+    agent.send = AsyncMock()  # Mock the send method to verify replies
+
+    mock_logger = MagicMock()
+    agent.logger = mock_logger
+
+    # Simulate receiving a message from BDI
+    prompt = LLMPromptMessage(text="Hi", norms=[], goals=[])
+
+    await agent._process_bdi_message(prompt)
+
     # Verification
     # "Hello world." constitutes one sentence/chunk based on punctuation split
     # The agent should call send once with the full sentence, PLUS once more for full reply
@@ -79,28 +123,16 @@ async def test_llm_processing_errors(mock_httpx_client, mock_settings):
     agent = LLMAgent("llm_agent")
     agent.send = AsyncMock()
     prompt = LLMPromptMessage(text="Hi", norms=[], goals=[])
-    msg = InternalMessage(
-        to="llm",
-        sender=mock_settings.agent_settings.bdi_core_name,
-        body=prompt.model_dump_json(),
-        thread="prompt_message",
-    )
 
     # HTTP Error: stream method RAISES exception immediately
     mock_httpx_client.stream = MagicMock(side_effect=httpx.HTTPError("Fail"))
 
-    await agent.handle_message(msg)
+    await agent._process_bdi_message(prompt)
 
     # Check that error message was sent
     assert agent.send.called
     assert "LLM service unavailable." in agent.send.call_args_list[0][0][0].body
-
-    # General Exception
-    agent.send.reset_mock()
-    mock_httpx_client.stream = MagicMock(side_effect=Exception("Boom"))
-    await agent.handle_message(msg)
-    assert "Error processing the request." in agent.send.call_args_list[0][0][0].body
 
 
 @pytest.mark.asyncio
 async def test_llm_json_error(mock_httpx_client, mock_settings):
@@ -125,13 +157,7 @@ async def test_llm_json_error(mock_httpx_client, mock_settings):
     agent.logger = MagicMock()
 
     prompt = LLMPromptMessage(text="Hi", norms=[], goals=[])
-    msg = InternalMessage(
-        to="llm",
-        sender=mock_settings.agent_settings.bdi_core_name,
-        body=prompt.model_dump_json(),
-        thread="prompt_message",
-    )
-    await agent.handle_message(msg)
+    await agent._process_bdi_message(prompt)
 
     agent.logger.error.assert_called()  # Should log JSONDecodeError