feat: visual emotion recognition agent #54
@@ -7,6 +7,7 @@ requires-python = ">=3.13"
|
|||||||
dependencies = [
|
dependencies = [
|
||||||
"agentspeak>=0.2.2",
|
"agentspeak>=0.2.2",
|
||||||
"colorlog>=6.10.1",
|
"colorlog>=6.10.1",
|
||||||
|
"deepface>=0.0.96",
|
||||||
"fastapi[all]>=0.115.6",
|
"fastapi[all]>=0.115.6",
|
||||||
"mlx-whisper>=0.4.3 ; sys_platform == 'darwin'",
|
"mlx-whisper>=0.4.3 ; sys_platform == 'darwin'",
|
||||||
"numpy>=2.3.3",
|
"numpy>=2.3.3",
|
||||||
@@ -21,6 +22,7 @@ dependencies = [
|
|||||||
"silero-vad>=6.0.0",
|
"silero-vad>=6.0.0",
|
||||||
"sphinx>=7.3.7",
|
"sphinx>=7.3.7",
|
||||||
"sphinx-rtd-theme>=3.0.2",
|
"sphinx-rtd-theme>=3.0.2",
|
||||||
|
"tf-keras>=2.20.1",
|
||||||
"torch>=2.8.0",
|
"torch>=2.8.0",
|
||||||
"uvicorn>=0.37.0",
|
"uvicorn>=0.37.0",
|
||||||
]
|
]
|
||||||
@@ -38,6 +40,7 @@ dev = [
|
|||||||
]
|
]
|
||||||
test = [
|
test = [
|
||||||
"agentspeak>=0.2.2",
|
"agentspeak>=0.2.2",
|
||||||
|
"deepface>=0.0.97",
|
||||||
"fastapi>=0.115.6",
|
"fastapi>=0.115.6",
|
||||||
"httpx>=0.28.1",
|
"httpx>=0.28.1",
|
||||||
"mlx-whisper>=0.4.3 ; sys_platform == 'darwin'",
|
"mlx-whisper>=0.4.3 ; sys_platform == 'darwin'",
|
||||||
@@ -52,6 +55,7 @@ test = [
|
|||||||
"pyyaml>=6.0.3",
|
"pyyaml>=6.0.3",
|
||||||
"pyzmq>=27.1.0",
|
"pyzmq>=27.1.0",
|
||||||
"soundfile>=0.13.1",
|
"soundfile>=0.13.1",
|
||||||
|
"tf-keras>=2.20.1",
|
||||||
]
|
]
|
||||||
|
|
||||||
[tool.pytest.ini_options]
|
[tool.pytest.ini_options]
|
||||||
|
|||||||
@@ -29,6 +29,7 @@ from control_backend.schemas.program import (
|
|||||||
BaseGoal,
|
BaseGoal,
|
||||||
BasicNorm,
|
BasicNorm,
|
||||||
ConditionalNorm,
|
ConditionalNorm,
|
||||||
|
EmotionBelief,
|
||||||
GestureAction,
|
GestureAction,
|
||||||
Goal,
|
Goal,
|
||||||
InferredBelief,
|
InferredBelief,
|
||||||
@@ -682,6 +683,10 @@ class AgentSpeakGenerator:
|
|||||||
"""
|
"""
|
||||||
return AstLiteral(self.slugify(sb))
|
return AstLiteral(self.slugify(sb))
|
||||||
|
|
||||||
|
@_astify.register
|
||||||
|
def _(self, eb: EmotionBelief) -> AstExpression:
|
||||||
|
return AstLiteral("emotion_detected", [AstAtom(eb.emotion)])
|
||||||
|
|
||||||
@_astify.register
|
@_astify.register
|
||||||
def _(self, ib: InferredBelief) -> AstExpression:
|
def _(self, ib: InferredBelief) -> AstExpression:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -9,14 +9,14 @@ import json
|
|||||||
|
|
||||||
import zmq
|
import zmq
|
||||||
import zmq.asyncio as azmq
|
import zmq.asyncio as azmq
|
||||||
from pydantic import ValidationError
|
|
||||||
from zmq.asyncio import Context
|
from zmq.asyncio import Context
|
||||||
|
|
||||||
from control_backend.agents import BaseAgent
|
from control_backend.agents import BaseAgent
|
||||||
from control_backend.agents.actuation.robot_gesture_agent import RobotGestureAgent
|
from control_backend.agents.actuation.robot_gesture_agent import RobotGestureAgent
|
||||||
|
from control_backend.agents.perception.visual_emotion_recognition_agent.visual_emotion_recognition_agent import ( # noqa
|
||||||
|
VisualEmotionRecognitionAgent,
|
||||||
|
)
|
||||||
from control_backend.core.config import settings
|
from control_backend.core.config import settings
|
||||||
from control_backend.schemas.internal_message import InternalMessage
|
|
||||||
from control_backend.schemas.ri_message import PauseCommand
|
|
||||||
|
|
||||||
from ..actuation.robot_speech_agent import RobotSpeechAgent
|
from ..actuation.robot_speech_agent import RobotSpeechAgent
|
||||||
from ..perception import VADAgent
|
from ..perception import VADAgent
|
||||||
@@ -58,6 +58,7 @@ class RICommunicationAgent(BaseAgent):
|
|||||||
self.connected = False
|
self.connected = False
|
||||||
self.gesture_agent: RobotGestureAgent | None = None
|
self.gesture_agent: RobotGestureAgent | None = None
|
||||||
self.speech_agent: RobotSpeechAgent | None = None
|
self.speech_agent: RobotSpeechAgent | None = None
|
||||||
|
self.visual_emotion_recognition_agent: VisualEmotionRecognitionAgent | None = None
|
||||||
|
|
||||||
async def setup(self):
|
async def setup(self):
|
||||||
"""
|
"""
|
||||||
@@ -215,6 +216,14 @@ class RICommunicationAgent(BaseAgent):
|
|||||||
case "audio":
|
case "audio":
|
||||||
vad_agent = VADAgent(audio_in_address=addr, audio_in_bind=bind)
|
vad_agent = VADAgent(audio_in_address=addr, audio_in_bind=bind)
|
||||||
await vad_agent.start()
|
await vad_agent.start()
|
||||||
|
case "video":
|
||||||
|
visual_emotion_agent = VisualEmotionRecognitionAgent(
|
||||||
|
settings.agent_settings.visual_emotion_recognition_name,
|
||||||
|
socket_address=addr,
|
||||||
|
bind=bind,
|
||||||
|
)
|
||||||
|
self.visual_emotion_recognition_agent = visual_emotion_agent
|
||||||
|
await visual_emotion_agent.start()
|
||||||
case _:
|
case _:
|
||||||
self.logger.warning("Unhandled negotiation id: %s", id)
|
self.logger.warning("Unhandled negotiation id: %s", id)
|
||||||
|
|
||||||
@@ -320,6 +329,9 @@ class RICommunicationAgent(BaseAgent):
|
|||||||
if self.speech_agent is not None:
|
if self.speech_agent is not None:
|
||||||
await self.speech_agent.stop()
|
await self.speech_agent.stop()
|
||||||
|
|
||||||
|
if self.visual_emotion_recognition_agent is not None:
|
||||||
|
await self.visual_emotion_recognition_agent.stop()
|
||||||
|
|
||||||
if self.pub_socket is not None:
|
if self.pub_socket is not None:
|
||||||
self.pub_socket.close()
|
self.pub_socket.close()
|
||||||
|
|
||||||
@@ -327,10 +339,3 @@ class RICommunicationAgent(BaseAgent):
|
|||||||
if await self._negotiate_connection(max_retries=2):
|
if await self._negotiate_connection(max_retries=2):
|
||||||
self.connected = True
|
self.connected = True
|
||||||
|
|
||||||
async def handle_message(self, msg: InternalMessage):
|
|
||||||
try:
|
|
||||||
pause_command = PauseCommand.model_validate_json(msg.body)
|
|
||||||
await self._req_socket.send_json(pause_command.model_dump())
|
|
||||||
self.logger.debug(await self._req_socket.recv_json())
|
|
||||||
except ValidationError:
|
|
||||||
self.logger.warning("Incorrect message format for PauseCommand.")
|
|
||||||
|
|||||||
@@ -0,0 +1,207 @@
|
|||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
from collections import Counter, defaultdict
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import zmq
|
||||||
|
import zmq.asyncio as azmq
|
||||||
|
from pydantic_core import ValidationError
|
||||||
|
|
||||||
|
from control_backend.agents import BaseAgent
|
||||||
|
from control_backend.agents.perception.visual_emotion_recognition_agent.visual_emotion_recognizer import ( # noqa
|
||||||
|
DeepFaceEmotionRecognizer,
|
||||||
|
)
|
||||||
|
from control_backend.core.agent_system import InternalMessage
|
||||||
|
from control_backend.core.config import settings
|
||||||
|
from control_backend.schemas.belief_message import Belief
|
||||||
|
|
||||||
|
|
||||||
|
class VisualEmotionRecognitionAgent(BaseAgent):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
name: str,
|
||||||
|
socket_address: str,
|
||||||
|
bind: bool = False,
|
||||||
|
timeout_ms: int = 1000,
|
||||||
|
window_duration: int = settings.behaviour_settings.visual_emotion_recognition_window_duration_s, # noqa
|
||||||
|
min_frames_required: int = settings.behaviour_settings.visual_emotion_recognition_min_frames_per_face, # noqa
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Initialize the Visual Emotion Recognition Agent.
|
||||||
|
|
||||||
|
:param name: Name of the agent
|
||||||
|
:param socket_address: Address of the socket to connect or bind to
|
||||||
|
:param bind: Whether to bind to the socket address (True) or connect (False)
|
||||||
|
:param timeout_ms: Timeout for socket receive operations in milliseconds
|
||||||
|
:param window_duration: Duration in seconds over which to aggregate emotions
|
||||||
|
:param min_frames_required: Minimum number of frames per face required to consider a face
|
||||||
|
valid
|
||||||
|
"""
|
||||||
|
super().__init__(name)
|
||||||
|
self.socket_address = socket_address
|
||||||
|
self.socket_bind = bind
|
||||||
|
self.timeout_ms = timeout_ms
|
||||||
|
self.window_duration = window_duration
|
||||||
|
self.min_frames_required = min_frames_required
|
||||||
|
|
||||||
|
# Pause functionality
|
||||||
|
# NOTE: flag is set when running, cleared when paused
|
||||||
|
self._paused = asyncio.Event()
|
||||||
|
self._paused.set()
|
||||||
|
|
||||||
|
async def setup(self):
|
||||||
|
"""
|
||||||
|
Initialize the agent resources.
|
||||||
|
1. Initializes the :class:`VisualEmotionRecognizer`.
|
||||||
|
2. Connects to the video input ZMQ socket.
|
||||||
|
3. Starts the background emotion recognition loop.
|
||||||
|
"""
|
||||||
|
self.logger.info("Setting up %s.", self.name)
|
||||||
|
|
||||||
|
self.emotion_recognizer = DeepFaceEmotionRecognizer()
|
||||||
|
|
||||||
|
self.video_in_socket = azmq.Context.instance().socket(zmq.SUB)
|
||||||
|
|
||||||
|
self.video_in_socket.setsockopt(zmq.RCVHWM, 3)
|
||||||
|
|
||||||
|
if self.socket_bind:
|
||||||
|
self.video_in_socket.bind(self.socket_address)
|
||||||
|
else:
|
||||||
|
self.video_in_socket.connect(self.socket_address)
|
||||||
|
|
||||||
|
self.video_in_socket.setsockopt_string(zmq.SUBSCRIBE, "")
|
||||||
|
self.video_in_socket.setsockopt(zmq.RCVTIMEO, self.timeout_ms)
|
||||||
|
|
||||||
|
self.add_behavior(self.emotion_update_loop())
|
||||||
|
|
||||||
|
async def emotion_update_loop(self):
|
||||||
|
"""
|
||||||
|
Background loop to receive video frames, recognize emotions, and update beliefs.
|
||||||
|
1. Receives video frames from the ZMQ socket.
|
||||||
|
2. Uses the :class:`VisualEmotionRecognizer` to detect emotions.
|
||||||
|
3. Aggregates emotions over a time window.
|
||||||
|
4. Sends updates to the BDI Core Agent about detected emotions.
|
||||||
|
"""
|
||||||
|
# Next time to process the window and update emotions
|
||||||
|
next_window_time = time.time() + self.window_duration
|
||||||
|
|
||||||
|
# Tracks counts of detected emotions per face index
|
||||||
|
face_stats = defaultdict(Counter)
|
||||||
|
|
||||||
|
prev_dominant_emotions = set()
|
||||||
|
|
||||||
|
while self._running:
|
||||||
|
try:
|
||||||
|
await self._paused.wait()
|
||||||
|
|
||||||
|
width, height, image_bytes = await self.video_in_socket.recv_multipart()
|
||||||
|
|
||||||
|
width = int.from_bytes(width, 'little')
|
||||||
|
height = int.from_bytes(height, 'little')
|
||||||
|
|
||||||
|
# Convert bytes to a numpy buffer
|
||||||
|
image_array = np.frombuffer(image_bytes, np.uint8)
|
||||||
|
|
||||||
|
frame = image_array.reshape((height, width, 3))
|
||||||
|
|
||||||
|
# Get the dominant emotion from each face
|
||||||
|
current_emotions = self.emotion_recognizer.sorted_dominant_emotions(frame)
|
||||||
|
# Update emotion counts for each detected face
|
||||||
|
for i, emotion in enumerate(current_emotions):
|
||||||
|
face_stats[i][emotion] += 1
|
||||||
|
|
||||||
|
# If window duration has passed, process the collected stats
|
||||||
|
if time.time() >= next_window_time:
|
||||||
|
window_dominant_emotions = set()
|
||||||
|
# Determine dominant emotion for each face in the window
|
||||||
|
for _, counter in face_stats.items():
|
||||||
|
total_detections = sum(counter.values())
|
||||||
|
|
||||||
|
if total_detections >= self.min_frames_required:
|
||||||
|
dominant_emotion = counter.most_common(1)[0][0]
|
||||||
|
window_dominant_emotions.add(dominant_emotion)
|
||||||
|
|
||||||
|
await self.update_emotions(prev_dominant_emotions, window_dominant_emotions)
|
||||||
|
prev_dominant_emotions = window_dominant_emotions
|
||||||
|
face_stats.clear()
|
||||||
|
next_window_time = time.time() + self.window_duration
|
||||||
|
|
||||||
|
except zmq.Again:
|
||||||
|
self.logger.warning("No video frame received within timeout.")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"Error in emotion recognition loop: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
async def update_emotions(self, prev_emotions: set[str], emotions: set[str]):
|
||||||
|
"""
|
||||||
|
Compare emotions from previous window and current emotions,
|
||||||
|
send updates to BDI Core Agent.
|
||||||
|
"""
|
||||||
|
emotions_to_remove = prev_emotions - emotions
|
||||||
|
emotions_to_add = emotions - prev_emotions
|
||||||
|
|
||||||
|
if not emotions_to_add and not emotions_to_remove:
|
||||||
|
return
|
||||||
|
|
||||||
|
emotion_beliefs_remove = []
|
||||||
|
for emotion in emotions_to_remove:
|
||||||
|
self.logger.info(f"Emotion '{emotion}' has disappeared.")
|
||||||
|
try:
|
||||||
|
emotion_beliefs_remove.append(
|
||||||
|
Belief(name="emotion_detected", arguments=[emotion], remove=True)
|
||||||
|
)
|
||||||
|
except ValidationError:
|
||||||
|
self.logger.warning("Invalid belief for emotion removal: %s", emotion)
|
||||||
|
|
||||||
|
emotion_beliefs_add = []
|
||||||
|
for emotion in emotions_to_add:
|
||||||
|
self.logger.info(f"New emotion detected: '{emotion}'")
|
||||||
|
try:
|
||||||
|
emotion_beliefs_add.append(Belief(name="emotion_detected", arguments=[emotion]))
|
||||||
|
except ValidationError:
|
||||||
|
self.logger.warning("Invalid belief for new emotion: %s", emotion)
|
||||||
|
|
||||||
|
beliefs_list_add = [b.model_dump() for b in emotion_beliefs_add]
|
||||||
|
beliefs_list_remove = [b.model_dump() for b in emotion_beliefs_remove]
|
||||||
|
payload = {"create": beliefs_list_add, "delete": beliefs_list_remove}
|
||||||
|
|
||||||
|
message = InternalMessage(
|
||||||
|
to=settings.agent_settings.bdi_core_name,
|
||||||
|
sender=self.name,
|
||||||
|
body=json.dumps(payload),
|
||||||
|
thread="beliefs",
|
||||||
|
)
|
||||||
|
await self.send(message)
|
||||||
|
|
||||||
|
async def handle_message(self, msg: InternalMessage):
|
||||||
|
"""
|
||||||
|
Handle incoming messages.
|
||||||
|
|
||||||
|
Expects messages to pause or resume the Visual Emotion Recognition
|
||||||
|
processing from User Interrupt Agent.
|
||||||
|
|
||||||
|
:param msg: The received internal message.
|
||||||
|
"""
|
||||||
|
sender = msg.sender
|
||||||
|
|
||||||
|
if sender == settings.agent_settings.user_interrupt_name:
|
||||||
|
if msg.body == "PAUSE":
|
||||||
|
self.logger.info("Pausing Visual Emotion Recognition processing.")
|
||||||
|
self._paused.clear()
|
||||||
|
elif msg.body == "RESUME":
|
||||||
|
self.logger.info("Resuming Visual Emotion Recognition processing.")
|
||||||
|
self._paused.set()
|
||||||
|
else:
|
||||||
|
self.logger.warning(f"Unknown command from User Interrupt Agent: {msg.body}")
|
||||||
|
else:
|
||||||
|
self.logger.debug(f"Ignoring message from unknown sender: {sender}")
|
||||||
|
|
||||||
|
async def stop(self):
|
||||||
|
"""
|
||||||
|
Clean up resources used by the agent.
|
||||||
|
"""
|
||||||
|
self.video_in_socket.close()
|
||||||
|
await super().stop()
|
||||||
|
|
||||||
@@ -0,0 +1,53 @@
|
|||||||
|
import abc
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from deepface import DeepFace
|
||||||
|
|
||||||
|
|
||||||
|
class VisualEmotionRecognizer(abc.ABC):
|
||||||
|
@abc.abstractmethod
|
||||||
|
def load_model(self):
|
||||||
|
"""Load the visual emotion recognition model into memory."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def sorted_dominant_emotions(self, image) -> list[str]:
|
||||||
|
"""
|
||||||
|
Recognize dominant emotions from faces in the given image.
|
||||||
|
Emotions can be one of ['angry', 'disgust', 'fear', 'happy', 'sad', 'surprise', 'neutral'].
|
||||||
|
To minimize false positives, consider filtering faces with low confidence.
|
||||||
|
|
||||||
|
:param image: The input image for emotion recognition.
|
||||||
|
:return: List of dominant emotion detected for each face in the image,
|
||||||
|
sorted per face.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
class DeepFaceEmotionRecognizer(VisualEmotionRecognizer):
|
||||||
|
"""
|
||||||
|
DeepFace-based implementation of VisualEmotionRecognizer.
|
||||||
|
DeepFape has proven to be quite a pessimistic model, so expect sad, fear and neutral
|
||||||
|
emotions to be over-represented.
|
||||||
|
"""
|
||||||
|
def __init__(self):
|
||||||
|
self.load_model()
|
||||||
|
|
||||||
|
def load_model(self):
|
||||||
|
dummy_img = np.zeros((224, 224, 3), dtype=np.uint8)
|
||||||
|
# analyze does not take a model as an argument, calling it once on a dummy image to load
|
||||||
|
# the model
|
||||||
|
DeepFace.analyze(dummy_img, actions=['emotion'], enforce_detection=False)
|
||||||
|
|
||||||
|
def sorted_dominant_emotions(self, image) -> list[str]:
|
||||||
|
analysis = DeepFace.analyze(image,
|
||||||
|
actions=['emotion'],
|
||||||
|
enforce_detection=False
|
||||||
|
)
|
||||||
|
|
||||||
|
# Sort faces by x coordinate to maintain left-to-right order
|
||||||
|
analysis.sort(key=lambda face: face['region']['x'])
|
||||||
|
|
||||||
|
analysis = [face for face in analysis if face['face_confidence'] >= 0.90]
|
||||||
|
|
||||||
|
dominant_emotions = [face['dominant_emotion'] for face in analysis]
|
||||||
|
return dominant_emotions
|
||||||
@@ -18,7 +18,6 @@ from control_backend.schemas.belief_message import Belief, BeliefMessage
|
|||||||
from control_backend.schemas.program import ConditionalNorm, Goal, Program
|
from control_backend.schemas.program import ConditionalNorm, Goal, Program
|
||||||
from control_backend.schemas.ri_message import (
|
from control_backend.schemas.ri_message import (
|
||||||
GestureCommand,
|
GestureCommand,
|
||||||
PauseCommand,
|
|
||||||
RIEndpoint,
|
RIEndpoint,
|
||||||
SpeechCommand,
|
SpeechCommand,
|
||||||
)
|
)
|
||||||
@@ -398,34 +397,29 @@ class UserInterruptAgent(BaseAgent):
|
|||||||
self.logger.debug("Sending experiment control '%s' to BDI Core.", thread)
|
self.logger.debug("Sending experiment control '%s' to BDI Core.", thread)
|
||||||
await self.send(out_msg)
|
await self.send(out_msg)
|
||||||
|
|
||||||
async def _send_pause_command(self, pause):
|
async def _send_pause_command(self, pause: str):
|
||||||
"""
|
"""
|
||||||
Send a pause command to the Robot Interface via the RI Communication Agent.
|
Send a pause command to the other internal agents; for now just VAD and VED agent.
|
||||||
Send a pause command to the other internal agents; for now just VAD agent.
|
|
||||||
"""
|
"""
|
||||||
cmd = PauseCommand(data=pause)
|
|
||||||
message = InternalMessage(
|
|
||||||
to=settings.agent_settings.ri_communication_name,
|
|
||||||
sender=self.name,
|
|
||||||
body=cmd.model_dump_json(),
|
|
||||||
)
|
|
||||||
await self.send(message)
|
|
||||||
|
|
||||||
if pause == "true":
|
if pause == "true":
|
||||||
# Send pause to VAD agent
|
# Send pause to VAD and VED agent
|
||||||
vad_message = InternalMessage(
|
vad_message = InternalMessage(
|
||||||
to=settings.agent_settings.vad_name,
|
to=[settings.agent_settings.vad_name,
|
||||||
|
settings.agent_settings.visual_emotion_recognition_name],
|
||||||
sender=self.name,
|
sender=self.name,
|
||||||
body="PAUSE",
|
body="PAUSE",
|
||||||
)
|
)
|
||||||
await self.send(vad_message)
|
await self.send(vad_message)
|
||||||
self.logger.info("Sent pause command to VAD Agent and RI Communication Agent.")
|
# Voice Activity Detection and Visual Emotion Recognition agents
|
||||||
|
self.logger.info("Sent pause command to VAD and VED agents.")
|
||||||
else:
|
else:
|
||||||
# Send resume to VAD agent
|
# Send resume to VAD and VED agents
|
||||||
vad_message = InternalMessage(
|
vad_message = InternalMessage(
|
||||||
to=settings.agent_settings.vad_name,
|
to=[settings.agent_settings.vad_name,
|
||||||
|
settings.agent_settings.visual_emotion_recognition_name],
|
||||||
sender=self.name,
|
sender=self.name,
|
||||||
body="RESUME",
|
body="RESUME",
|
||||||
)
|
)
|
||||||
await self.send(vad_message)
|
await self.send(vad_message)
|
||||||
self.logger.info("Sent resume command to VAD Agent and RI Communication Agent.")
|
# Voice Activity Detection and Visual Emotion Recognition agents
|
||||||
|
self.logger.info("Sent resume command to VAD and VED agents.")
|
||||||
@@ -54,6 +54,7 @@ class AgentSettings(BaseModel):
|
|||||||
# agent names
|
# agent names
|
||||||
bdi_core_name: str = "bdi_core_agent"
|
bdi_core_name: str = "bdi_core_agent"
|
||||||
bdi_program_manager_name: str = "bdi_program_manager_agent"
|
bdi_program_manager_name: str = "bdi_program_manager_agent"
|
||||||
|
visual_emotion_recognition_name: str = "visual_emotion_recognition_agent"
|
||||||
text_belief_extractor_name: str = "text_belief_extractor_agent"
|
text_belief_extractor_name: str = "text_belief_extractor_agent"
|
||||||
vad_name: str = "vad_agent"
|
vad_name: str = "vad_agent"
|
||||||
llm_name: str = "llm_agent"
|
llm_name: str = "llm_agent"
|
||||||
@@ -81,6 +82,10 @@ class BehaviourSettings(BaseModel):
|
|||||||
:ivar transcription_words_per_token: Estimated words per token for transcription timing.
|
:ivar transcription_words_per_token: Estimated words per token for transcription timing.
|
||||||
:ivar transcription_token_buffer: Buffer for transcription tokens.
|
:ivar transcription_token_buffer: Buffer for transcription tokens.
|
||||||
:ivar conversation_history_length_limit: The maximum amount of messages to extract beliefs from.
|
:ivar conversation_history_length_limit: The maximum amount of messages to extract beliefs from.
|
||||||
|
:ivar visual_emotion_recognition_window_duration_s: Duration in seconds over which to aggregate
|
||||||
|
emotions and update emotion beliefs.
|
||||||
|
:ivar visual_emotion_recognition_min_frames_per_face: Minimum number of frames per face required
|
||||||
|
to consider a face valid.
|
||||||
:ivar trigger_time_to_wait: Amount of milliseconds to wait before informing the UI about trigger
|
:ivar trigger_time_to_wait: Amount of milliseconds to wait before informing the UI about trigger
|
||||||
completion.
|
completion.
|
||||||
"""
|
"""
|
||||||
@@ -106,6 +111,9 @@ class BehaviourSettings(BaseModel):
|
|||||||
# Text belief extractor settings
|
# Text belief extractor settings
|
||||||
conversation_history_length_limit: int = 10
|
conversation_history_length_limit: int = 10
|
||||||
|
|
||||||
|
# Visual Emotion Recognition settings
|
||||||
|
visual_emotion_recognition_window_duration_s: int = 5
|
||||||
|
visual_emotion_recognition_min_frames_per_face: int = 3
|
||||||
# AgentSpeak related settings
|
# AgentSpeak related settings
|
||||||
trigger_time_to_wait: int = 2000
|
trigger_time_to_wait: int = 2000
|
||||||
agentspeak_file: str = "src/control_backend/agents/bdi/agentspeak.asl"
|
agentspeak_file: str = "src/control_backend/agents/bdi/agentspeak.asl"
|
||||||
|
|||||||
@@ -41,8 +41,8 @@ class LogicalOperator(Enum):
|
|||||||
OR = "OR"
|
OR = "OR"
|
||||||
|
|
||||||
|
|
||||||
type Belief = KeywordBelief | SemanticBelief | InferredBelief
|
type Belief = KeywordBelief | SemanticBelief | InferredBelief | EmotionBelief
|
||||||
type BasicBelief = KeywordBelief | SemanticBelief
|
type BasicBelief = KeywordBelief | SemanticBelief | EmotionBelief
|
||||||
|
|
||||||
|
|
||||||
class KeywordBelief(ProgramElement):
|
class KeywordBelief(ProgramElement):
|
||||||
@@ -105,6 +105,15 @@ class InferredBelief(ProgramElement):
|
|||||||
left: Belief
|
left: Belief
|
||||||
right: Belief
|
right: Belief
|
||||||
|
|
||||||
|
class EmotionBelief(ProgramElement):
|
||||||
|
"""
|
||||||
|
Represents a belief that is set when a certain emotion is detected.
|
||||||
|
|
||||||
|
:ivar emotion: The emotion on which this belief gets set.
|
||||||
|
"""
|
||||||
|
|
||||||
|
name: str = ""
|
||||||
|
emotion: str
|
||||||
|
|
||||||
class Norm(ProgramElement):
|
class Norm(ProgramElement):
|
||||||
"""
|
"""
|
||||||
@@ -315,3 +324,9 @@ class Program(BaseModel):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
phases: list[Phase]
|
phases: list[Phase]
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
input = input("Enter program JSON: ")
|
||||||
|
program = Program.model_validate_json(input)
|
||||||
|
print(program)
|
||||||
@@ -10,8 +10,6 @@ from unittest.mock import ANY, AsyncMock, MagicMock, patch
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from control_backend.agents.communication.ri_communication_agent import RICommunicationAgent
|
from control_backend.agents.communication.ri_communication_agent import RICommunicationAgent
|
||||||
from control_backend.core.agent_system import InternalMessage
|
|
||||||
from control_backend.schemas.ri_message import PauseCommand, RIEndpoint
|
|
||||||
|
|
||||||
|
|
||||||
def speech_agent_path():
|
def speech_agent_path():
|
||||||
@@ -402,38 +400,3 @@ async def test_negotiate_req_socket_none_causes_retry(zmq_context):
|
|||||||
result = await agent._negotiate_connection(max_retries=1)
|
result = await agent._negotiate_connection(max_retries=1)
|
||||||
|
|
||||||
assert result is False
|
assert result is False
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_handle_message_pause_command(zmq_context):
|
|
||||||
"""Test handle_message with a valid PauseCommand."""
|
|
||||||
agent = RICommunicationAgent("ri_comm")
|
|
||||||
agent._req_socket = AsyncMock()
|
|
||||||
agent.logger = MagicMock()
|
|
||||||
|
|
||||||
agent._req_socket.recv_json.return_value = {"status": "ok"}
|
|
||||||
|
|
||||||
pause_cmd = PauseCommand(data=True)
|
|
||||||
msg = InternalMessage(to="ri_comm", sender="user_int", body=pause_cmd.model_dump_json())
|
|
||||||
|
|
||||||
await agent.handle_message(msg)
|
|
||||||
|
|
||||||
agent._req_socket.send_json.assert_awaited_once()
|
|
||||||
args = agent._req_socket.send_json.await_args[0][0]
|
|
||||||
assert args["endpoint"] == RIEndpoint.PAUSE.value
|
|
||||||
assert args["data"] is True
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_handle_message_invalid_pause_command(zmq_context):
|
|
||||||
"""Test handle_message with invalid JSON."""
|
|
||||||
agent = RICommunicationAgent("ri_comm")
|
|
||||||
agent._req_socket = AsyncMock()
|
|
||||||
agent.logger = MagicMock()
|
|
||||||
|
|
||||||
msg = InternalMessage(to="ri_comm", sender="user_int", body="invalid json")
|
|
||||||
|
|
||||||
await agent.handle_message(msg)
|
|
||||||
|
|
||||||
agent.logger.warning.assert_called_with("Incorrect message format for PauseCommand.")
|
|
||||||
agent._req_socket.send_json.assert_not_called()
|
|
||||||
|
|||||||
@@ -0,0 +1,338 @@
|
|||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
import zmq
|
||||||
|
from pydantic_core import ValidationError
|
||||||
|
|
||||||
|
# Adjust the import path to match your project structure
|
||||||
|
from control_backend.agents.perception.visual_emotion_recognition_agent.visual_emotion_recognition_agent import ( # noqa
|
||||||
|
VisualEmotionRecognitionAgent,
|
||||||
|
)
|
||||||
|
from control_backend.core.agent_system import InternalMessage
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
# Fixtures
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_settings():
|
||||||
|
with patch("control_backend.agents.perception.visual_emotion_recognition_agent.visual_emotion_recognition_agent.settings") as mock: # noqa
|
||||||
|
# Set default values required by the agent
|
||||||
|
mock.behaviour_settings.visual_emotion_recognition_window_duration_s = 5
|
||||||
|
mock.behaviour_settings.visual_emotion_recognition_min_frames_per_face = 3
|
||||||
|
mock.agent_settings.bdi_core_name = "bdi_core_agent"
|
||||||
|
mock.agent_settings.user_interrupt_name = "user_interrupt_agent"
|
||||||
|
yield mock
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_deepface():
|
||||||
|
with patch("control_backend.agents.perception.visual_emotion_recognition_agent.visual_emotion_recognition_agent.DeepFaceEmotionRecognizer") as mock: # noqa
|
||||||
|
instance = mock.return_value
|
||||||
|
instance.sorted_dominant_emotions.return_value = []
|
||||||
|
yield instance
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_zmq_context():
|
||||||
|
with patch("zmq.asyncio.Context.instance") as mock_ctx:
|
||||||
|
mock_socket = MagicMock()
|
||||||
|
# Mock socket methods to return None or AsyncMock for async methods
|
||||||
|
mock_socket.bind = MagicMock()
|
||||||
|
mock_socket.connect = MagicMock()
|
||||||
|
mock_socket.setsockopt = MagicMock()
|
||||||
|
mock_socket.setsockopt_string = MagicMock()
|
||||||
|
mock_socket.recv_multipart = AsyncMock()
|
||||||
|
mock_socket.close = MagicMock()
|
||||||
|
|
||||||
|
mock_ctx.return_value.socket.return_value = mock_socket
|
||||||
|
yield mock_ctx
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def agent(mock_settings, mock_deepface, mock_zmq_context):
|
||||||
|
# Initialize agent with specific params to control testing
|
||||||
|
agent = VisualEmotionRecognitionAgent(
|
||||||
|
name="test_agent",
|
||||||
|
socket_address="tcp://localhost:5555",
|
||||||
|
bind=False,
|
||||||
|
timeout_ms=100,
|
||||||
|
window_duration=2,
|
||||||
|
min_frames_required=2
|
||||||
|
)
|
||||||
|
# Mock the internal send method from BaseAgent
|
||||||
|
agent.send = AsyncMock()
|
||||||
|
# Mock the add_behavior method from BaseAgent
|
||||||
|
agent.add_behavior = MagicMock()
|
||||||
|
# Mock the logger
|
||||||
|
agent.logger = MagicMock()
|
||||||
|
|
||||||
|
return agent
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
# Tests
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_initialization(agent):
|
||||||
|
"""Test that the agent initializes with correct attributes."""
|
||||||
|
assert agent.name == "test_agent"
|
||||||
|
assert agent.socket_address == "tcp://localhost:5555"
|
||||||
|
assert agent.socket_bind is False
|
||||||
|
assert agent.timeout_ms == 100
|
||||||
|
assert agent._paused.is_set()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_setup_connect(agent, mock_zmq_context, mock_deepface):
|
||||||
|
"""Test setup routine when binding is False (connect)."""
|
||||||
|
agent.socket_bind = False
|
||||||
|
await agent.setup()
|
||||||
|
|
||||||
|
socket = agent.video_in_socket
|
||||||
|
socket.connect.assert_called_with("tcp://localhost:5555")
|
||||||
|
socket.bind.assert_not_called()
|
||||||
|
socket.setsockopt.assert_any_call(zmq.RCVHWM, 3)
|
||||||
|
socket.setsockopt.assert_any_call(zmq.RCVTIMEO, 100)
|
||||||
|
agent.add_behavior.assert_called_once()
|
||||||
|
assert agent.emotion_recognizer == mock_deepface
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_setup_bind(agent, mock_zmq_context):
|
||||||
|
"""Test setup routine when binding is True."""
|
||||||
|
agent.socket_bind = True
|
||||||
|
await agent.setup()
|
||||||
|
|
||||||
|
socket = agent.video_in_socket
|
||||||
|
socket.bind.assert_called_with("tcp://localhost:5555")
|
||||||
|
socket.connect.assert_not_called()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_emotion_update_loop_normal_flow(agent, mock_deepface):
|
||||||
|
"""
|
||||||
|
Test the main loop logic:
|
||||||
|
1. Receive frames
|
||||||
|
2. Aggregate stats
|
||||||
|
3. Trigger window update
|
||||||
|
4. Call update_emotions
|
||||||
|
"""
|
||||||
|
# Setup dependencies
|
||||||
|
await agent.setup()
|
||||||
|
agent._running = True
|
||||||
|
|
||||||
|
# Create fake image data (10x10 pixels)
|
||||||
|
width, height = 10, 10
|
||||||
|
image_bytes = np.zeros((10, 10, 3), dtype=np.uint8).tobytes()
|
||||||
|
w_bytes = width.to_bytes(4, 'little')
|
||||||
|
h_bytes = height.to_bytes(4, 'little')
|
||||||
|
|
||||||
|
# Mock ZMQ receive to return data 3 times, then stop the loop
|
||||||
|
# We use a side_effect on recv_multipart to simulate frames and then stop the loop
|
||||||
|
async def recv_side_effect():
|
||||||
|
if agent._running:
|
||||||
|
return w_bytes, h_bytes, image_bytes
|
||||||
|
raise asyncio.CancelledError()
|
||||||
|
|
||||||
|
agent.video_in_socket.recv_multipart.side_effect = recv_side_effect
|
||||||
|
|
||||||
|
# Mock DeepFace to return emotions
|
||||||
|
# Frame 1: Happy
|
||||||
|
# Frame 2: Happy
|
||||||
|
# Frame 3: Happy (Trigger window)
|
||||||
|
mock_deepface.sorted_dominant_emotions.side_effect = [
|
||||||
|
["happy"],
|
||||||
|
["happy"],
|
||||||
|
["happy"]
|
||||||
|
]
|
||||||
|
|
||||||
|
# Mock update_emotions to verify it's called
|
||||||
|
agent.update_emotions = AsyncMock()
|
||||||
|
|
||||||
|
# Mock time.time to simulate window passage
|
||||||
|
# We need time to advance significantly after the frames are collected
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
with patch("time.time") as mock_time:
|
||||||
|
# Sequence of time calls:
|
||||||
|
# 1. Init next_window_time calculation
|
||||||
|
# 2. Loop 1 check
|
||||||
|
# 3. Loop 2 check
|
||||||
|
# 4. Loop 3 check (Make this one pass the window threshold)
|
||||||
|
mock_time.side_effect = [
|
||||||
|
start_time, # init
|
||||||
|
start_time + 0.1, # frame 1 check
|
||||||
|
start_time + 0.2, # frame 2 check
|
||||||
|
start_time + 10.0, # frame 3 check (triggers window reset)
|
||||||
|
start_time + 10.1, # next init
|
||||||
|
start_time + 10.2 # break loop
|
||||||
|
]
|
||||||
|
|
||||||
|
# We need to manually break the infinite loop after the update
|
||||||
|
# We can do this by wrapping update_emotions to set _running = False
|
||||||
|
async def stop_loop(*args, **kwargs):
|
||||||
|
agent._running = False
|
||||||
|
|
||||||
|
agent.update_emotions.side_effect = stop_loop
|
||||||
|
|
||||||
|
# Run the loop
|
||||||
|
await agent.emotion_update_loop()
|
||||||
|
|
||||||
|
# Verifications
|
||||||
|
assert agent.update_emotions.called
|
||||||
|
# Check that it detected 'happy' as dominant (2 required, 3 found)
|
||||||
|
call_args = agent.update_emotions.call_args
|
||||||
|
assert call_args is not None
|
||||||
|
# args: (prev_emotions, window_dominant_emotions)
|
||||||
|
assert call_args[0][1] == {"happy"}
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_emotion_update_loop_insufficient_frames(agent, mock_deepface):
|
||||||
|
"""Test that emotions are NOT updated if min_frames_required is not met."""
|
||||||
|
await agent.setup()
|
||||||
|
agent._running = True
|
||||||
|
agent.min_frames_required = 5 # Set high requirement
|
||||||
|
|
||||||
|
width, height = 10, 10
|
||||||
|
image_bytes = np.zeros((10, 10, 3), dtype=np.uint8).tobytes()
|
||||||
|
w_bytes = width.to_bytes(4, 'little')
|
||||||
|
h_bytes = height.to_bytes(4, 'little')
|
||||||
|
|
||||||
|
agent.video_in_socket.recv_multipart.return_value = (w_bytes, h_bytes, image_bytes)
|
||||||
|
mock_deepface.sorted_dominant_emotions.return_value = ["sad"]
|
||||||
|
|
||||||
|
agent.update_emotions = AsyncMock()
|
||||||
|
|
||||||
|
with patch("time.time") as mock_time:
|
||||||
|
# Time setup to trigger window processing immediately
|
||||||
|
mock_time.side_effect = [0, 100, 101]
|
||||||
|
|
||||||
|
# Stop loop after first pass
|
||||||
|
async def stop_loop(*args, **kwargs):
|
||||||
|
agent._running = False
|
||||||
|
agent.update_emotions.side_effect = stop_loop
|
||||||
|
|
||||||
|
await agent.emotion_update_loop()
|
||||||
|
|
||||||
|
# It should call update_emotions with EMPTY set because min frames (5) > detected (1)
|
||||||
|
call_args = agent.update_emotions.call_args
|
||||||
|
assert call_args[0][1] == set()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_emotion_update_loop_zmq_again_and_exception(agent):
|
||||||
|
"""Test that the loop handles ZMQ timeouts (Again) and generic exceptions."""
|
||||||
|
await agent.setup()
|
||||||
|
agent._running = True
|
||||||
|
|
||||||
|
# Side effect:
|
||||||
|
# 1. Raise ZMQ Again (Timeout) -> should log warning
|
||||||
|
# 2. Raise Generic Exception -> should log error
|
||||||
|
# 3. Raise CancelledError -> stop loop (simulating stop)
|
||||||
|
agent.video_in_socket.recv_multipart.side_effect = [
|
||||||
|
zmq.Again(),
|
||||||
|
RuntimeError("Random Failure"),
|
||||||
|
asyncio.CancelledError() # To break loop cleanly
|
||||||
|
]
|
||||||
|
|
||||||
|
# We need to ensure the loop doesn't block on _paused
|
||||||
|
agent._paused.set()
|
||||||
|
|
||||||
|
# Run loop
|
||||||
|
try:
|
||||||
|
await agent.emotion_update_loop()
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_update_emotions_logic(agent, mock_settings):
|
||||||
|
"""Test the logic for calculating diffs and sending messages."""
|
||||||
|
agent.name = "viz_agent"
|
||||||
|
|
||||||
|
# Case 1: No change
|
||||||
|
await agent.update_emotions({"happy"}, {"happy"})
|
||||||
|
agent.send.assert_not_called()
|
||||||
|
|
||||||
|
# Case 2: Remove 'happy', Add 'sad'
|
||||||
|
await agent.update_emotions({"happy"}, {"sad"})
|
||||||
|
|
||||||
|
assert agent.send.called
|
||||||
|
call_args = agent.send.call_args
|
||||||
|
msg = call_args[0][0] # InternalMessage object
|
||||||
|
|
||||||
|
assert msg.to == mock_settings.agent_settings.bdi_core_name
|
||||||
|
assert msg.sender == "viz_agent"
|
||||||
|
assert msg.thread == "beliefs"
|
||||||
|
|
||||||
|
payload = json.loads(msg.body)
|
||||||
|
|
||||||
|
# Check Created Beliefs
|
||||||
|
assert len(payload["create"]) == 1
|
||||||
|
assert payload["create"][0]["name"] == "emotion_detected"
|
||||||
|
assert payload["create"][0]["arguments"] == ["sad"]
|
||||||
|
|
||||||
|
# Check Deleted Beliefs
|
||||||
|
assert len(payload["delete"]) == 1
|
||||||
|
assert payload["delete"][0]["name"] == "emotion_detected"
|
||||||
|
assert payload["delete"][0]["arguments"] == ["happy"]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_update_emotions_validation_error(agent):
|
||||||
|
"""Test that ValidationErrors during Belief creation are caught."""
|
||||||
|
|
||||||
|
# We patch Belief to raise ValidationError
|
||||||
|
with patch("control_backend.agents.perception.visual_emotion_recognition_agent.visual_emotion_recognition_agent.Belief") as MockBelief: # noqa
|
||||||
|
MockBelief.side_effect = ValidationError.from_exception_data("Simulated Error", [])
|
||||||
|
|
||||||
|
# Try to update emotions
|
||||||
|
await agent.update_emotions(prev_emotions={"happy"}, emotions={"sad"})
|
||||||
|
|
||||||
|
# Verify empty payload is sent (or payload with valid ones if mixed)
|
||||||
|
# In this case both failed, so payload lists should be empty
|
||||||
|
assert agent.send.called
|
||||||
|
msg = agent.send.call_args[0][0]
|
||||||
|
payload = json.loads(msg.body)
|
||||||
|
assert payload["create"] == []
|
||||||
|
assert payload["delete"] == []
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_handle_message(agent, mock_settings):
|
||||||
|
"""Test message handling for Pause/Resume."""
|
||||||
|
|
||||||
|
# Setup
|
||||||
|
ui_name = mock_settings.agent_settings.user_interrupt_name
|
||||||
|
|
||||||
|
# 1. PAUSE message
|
||||||
|
msg_pause = InternalMessage(to="me", sender=ui_name, body="PAUSE")
|
||||||
|
await agent.handle_message(msg_pause)
|
||||||
|
assert not agent._paused.is_set() # Should be cleared (paused)
|
||||||
|
agent.logger.info.assert_called_with("Pausing Visual Emotion Recognition processing.")
|
||||||
|
|
||||||
|
# 2. RESUME message
|
||||||
|
msg_resume = InternalMessage(to="me", sender=ui_name, body="RESUME")
|
||||||
|
await agent.handle_message(msg_resume)
|
||||||
|
assert agent._paused.is_set() # Should be set (running)
|
||||||
|
|
||||||
|
# 3. Unknown command
|
||||||
|
msg_unknown = InternalMessage(to="me", sender=ui_name, body="DANCE")
|
||||||
|
await agent.handle_message(msg_unknown)
|
||||||
|
|
||||||
|
# 4. Unknown sender
|
||||||
|
msg_random = InternalMessage(to="me", sender="random_guy", body="PAUSE")
|
||||||
|
await agent.handle_message(msg_random)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_stop(agent, mock_zmq_context):
|
||||||
|
"""Test the stop method cleans up resources."""
|
||||||
|
# We need to mock super().stop(). Since we can't easily patch super(),
|
||||||
|
# and the provided BaseAgent code shows stop() just sets _running and cancels tasks,
|
||||||
|
# we can rely on the fact that VisualEmotionRecognitionAgent calls it.
|
||||||
|
|
||||||
|
# However, since we provided a 'agent' fixture that mocks things, we should verify specific cleanups. # noqa
|
||||||
|
await agent.setup()
|
||||||
|
|
||||||
|
with patch("control_backend.agents.BaseAgent.stop", new_callable=AsyncMock) as mock_super_stop:
|
||||||
|
await agent.stop()
|
||||||
|
|
||||||
|
# Verify socket closed
|
||||||
|
agent.video_in_socket.close.assert_called_once()
|
||||||
|
# Verify parent stop called
|
||||||
|
mock_super_stop.assert_called_once()
|
||||||
@@ -305,26 +305,30 @@ async def test_send_experiment_control(agent):
|
|||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_send_pause_command(agent):
|
async def test_send_pause_command(agent):
|
||||||
|
# --- Test PAUSE ---
|
||||||
await agent._send_pause_command("true")
|
await agent._send_pause_command("true")
|
||||||
# Sends to RI and VAD
|
|
||||||
assert agent.send.await_count == 2
|
|
||||||
msgs = [call.args[0] for call in agent.send.call_args_list]
|
|
||||||
|
|
||||||
ri_msg = next(m for m in msgs if m.to == settings.agent_settings.ri_communication_name)
|
# Should send exactly 1 message
|
||||||
assert json.loads(ri_msg.body)["endpoint"] == "" # PAUSE endpoint
|
assert agent.send.await_count == 1
|
||||||
assert json.loads(ri_msg.body)["data"] is True
|
|
||||||
|
|
||||||
vad_msg = next(m for m in msgs if m.to == settings.agent_settings.vad_name)
|
# Extract the message object from the mock call
|
||||||
assert vad_msg.body == "PAUSE"
|
# call_args[0] are positional args, and [0] is the first arg (the message)
|
||||||
|
msg = agent.send.call_args[0][0]
|
||||||
|
|
||||||
|
# Verify Body
|
||||||
|
assert msg.body == "PAUSE"
|
||||||
|
|
||||||
|
# --- Test RESUME ---
|
||||||
agent.send.reset_mock()
|
agent.send.reset_mock()
|
||||||
await agent._send_pause_command("false")
|
await agent._send_pause_command("false")
|
||||||
assert agent.send.await_count == 2
|
|
||||||
vad_msg = next(
|
|
||||||
m for m in agent.send.call_args_list if m.args[0].to == settings.agent_settings.vad_name
|
|
||||||
).args[0]
|
|
||||||
assert vad_msg.body == "RESUME"
|
|
||||||
|
|
||||||
|
# Should send exactly 1 message
|
||||||
|
assert agent.send.await_count == 1
|
||||||
|
|
||||||
|
msg = agent.send.call_args[0][0]
|
||||||
|
|
||||||
|
# Verify Body
|
||||||
|
assert msg.body == "RESUME"
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_setup(agent):
|
async def test_setup(agent):
|
||||||
|
|||||||
Reference in New Issue
Block a user