Compare commits
31 Commits
main
...
feat/face-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ce99dc5ec3 | ||
|
|
12b905ff22 | ||
|
|
11d7c1409e | ||
|
|
f89fb2266a | ||
|
|
0f964795f3 | ||
|
|
00a13426a0 | ||
|
|
f6477e5325 | ||
|
|
09d8cca309 | ||
|
|
0b1c2ce20a | ||
|
|
f9b807fc97 | ||
|
|
424294b0a3 | ||
|
|
bc0947fac1 | ||
|
|
cd80cdf93b | ||
|
|
985327de70 | ||
|
|
302c50934e | ||
|
|
f9c69cafb3 | ||
|
|
37da9992ba | ||
|
|
f41201dd8e | ||
|
|
2033e02116 | ||
|
|
1b0b72d63a | ||
|
|
0941b26703 | ||
|
|
48ae0c7a12 | ||
|
|
a09d8b3d9a | ||
|
|
ac20048f02 | ||
|
|
05804c158d | ||
|
|
0771b0d607 | ||
|
|
1c88ae6078 | ||
|
|
6b790de53a | ||
|
|
1932ac959b | ||
|
|
bb0c1bd383 | ||
|
|
03954bef54 |
@@ -24,7 +24,6 @@ dependencies = [
|
|||||||
"sphinx-rtd-theme>=3.0.2",
|
"sphinx-rtd-theme>=3.0.2",
|
||||||
"tf-keras>=2.20.1",
|
"tf-keras>=2.20.1",
|
||||||
"torch>=2.8.0",
|
"torch>=2.8.0",
|
||||||
"tornado ; sys_platform == 'win32'",
|
|
||||||
"uvicorn>=0.37.0",
|
"uvicorn>=0.37.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
@@ -41,7 +40,6 @@ dev = [
|
|||||||
]
|
]
|
||||||
test = [
|
test = [
|
||||||
"agentspeak>=0.2.2",
|
"agentspeak>=0.2.2",
|
||||||
"deepface>=0.0.97",
|
|
||||||
"fastapi>=0.115.6",
|
"fastapi>=0.115.6",
|
||||||
"httpx>=0.28.1",
|
"httpx>=0.28.1",
|
||||||
"mlx-whisper>=0.4.3 ; sys_platform == 'darwin'",
|
"mlx-whisper>=0.4.3 ; sys_platform == 'darwin'",
|
||||||
@@ -56,7 +54,6 @@ test = [
|
|||||||
"pyyaml>=6.0.3",
|
"pyyaml>=6.0.3",
|
||||||
"pyzmq>=27.1.0",
|
"pyzmq>=27.1.0",
|
||||||
"soundfile>=0.13.1",
|
"soundfile>=0.13.1",
|
||||||
"tf-keras>=2.20.1",
|
|
||||||
]
|
]
|
||||||
|
|
||||||
[tool.pytest.ini_options]
|
[tool.pytest.ini_options]
|
||||||
|
|||||||
@@ -4,7 +4,6 @@ University within the Software Project course.
|
|||||||
© Copyright Utrecht University (Department of Information and Computing Sciences)
|
© Copyright Utrecht University (Department of Information and Computing Sciences)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
|
||||||
from functools import singledispatchmethod
|
from functools import singledispatchmethod
|
||||||
|
|
||||||
from slugify import slugify
|
from slugify import slugify
|
||||||
@@ -31,6 +30,7 @@ from control_backend.schemas.program import (
|
|||||||
BasicNorm,
|
BasicNorm,
|
||||||
ConditionalNorm,
|
ConditionalNorm,
|
||||||
EmotionBelief,
|
EmotionBelief,
|
||||||
|
FaceBelief,
|
||||||
GestureAction,
|
GestureAction,
|
||||||
Goal,
|
Goal,
|
||||||
InferredBelief,
|
InferredBelief,
|
||||||
@@ -67,7 +67,6 @@ class AgentSpeakGenerator:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
_asp: AstProgram
|
_asp: AstProgram
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
def generate(self, program: Program) -> str:
|
def generate(self, program: Program) -> str:
|
||||||
"""
|
"""
|
||||||
@@ -107,7 +106,7 @@ class AgentSpeakGenerator:
|
|||||||
check if a keyword is a substring of the user's message.
|
check if a keyword is a substring of the user's message.
|
||||||
|
|
||||||
The generated rule has the form:
|
The generated rule has the form:
|
||||||
keyword_said(Keyword) :- user_said(Message) & .substring_case_insensitive(Keyword, Message, Pos) & Pos >= 0
|
keyword_said(Keyword) :- user_said(Message) & .substring(Keyword, Message, Pos) & Pos >= 0
|
||||||
|
|
||||||
This enables the system to trigger behaviors based on keyword detection.
|
This enables the system to trigger behaviors based on keyword detection.
|
||||||
"""
|
"""
|
||||||
@@ -119,7 +118,7 @@ class AgentSpeakGenerator:
|
|||||||
AstRule(
|
AstRule(
|
||||||
AstLiteral("keyword_said", [keyword]),
|
AstLiteral("keyword_said", [keyword]),
|
||||||
AstLiteral("user_said", [message])
|
AstLiteral("user_said", [message])
|
||||||
& AstLiteral(".substring_case_insensitive", [keyword, message, position])
|
& AstLiteral(".substring", [keyword, message, position])
|
||||||
& (position >= 0),
|
& (position >= 0),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@@ -135,6 +134,7 @@ class AgentSpeakGenerator:
|
|||||||
"""
|
"""
|
||||||
self._add_reply_with_goal_plan()
|
self._add_reply_with_goal_plan()
|
||||||
self._add_say_plan()
|
self._add_say_plan()
|
||||||
|
self._add_reply_plan()
|
||||||
self._add_notify_cycle_plan()
|
self._add_notify_cycle_plan()
|
||||||
|
|
||||||
def _add_reply_with_goal_plan(self):
|
def _add_reply_with_goal_plan(self):
|
||||||
@@ -198,6 +198,40 @@ class AgentSpeakGenerator:
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _add_reply_plan(self):
|
||||||
|
"""
|
||||||
|
Adds a plan for general reply actions.
|
||||||
|
|
||||||
|
This plan handles general reply actions where the agent needs to respond
|
||||||
|
to user input without a specific conversational goal. It:
|
||||||
|
1. Marks that the agent has responded this turn
|
||||||
|
2. Gathers all active norms
|
||||||
|
3. Generates a reply based on the user message and norms
|
||||||
|
|
||||||
|
Trigger: +!reply
|
||||||
|
Context: user_said(Message)
|
||||||
|
"""
|
||||||
|
self._asp.plans.append(
|
||||||
|
AstPlan(
|
||||||
|
TriggerType.ADDED_GOAL,
|
||||||
|
AstLiteral("reply"),
|
||||||
|
[AstLiteral("user_said", [AstVar("Message")])],
|
||||||
|
[
|
||||||
|
AstStatement(StatementType.ADD_BELIEF, AstLiteral("responded_this_turn")),
|
||||||
|
AstStatement(
|
||||||
|
StatementType.DO_ACTION,
|
||||||
|
AstLiteral(
|
||||||
|
"findall",
|
||||||
|
[AstVar("Norm"), AstLiteral("norm", [AstVar("Norm")]), AstVar("Norms")],
|
||||||
|
),
|
||||||
|
),
|
||||||
|
AstStatement(
|
||||||
|
StatementType.DO_ACTION,
|
||||||
|
AstLiteral("reply", [AstVar("Message"), AstVar("Norms")]),
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
def _add_notify_cycle_plan(self):
|
def _add_notify_cycle_plan(self):
|
||||||
"""
|
"""
|
||||||
@@ -235,39 +269,6 @@ class AgentSpeakGenerator:
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
def _add_stop_plan(self, phase: Phase):
|
|
||||||
"""
|
|
||||||
Adds a plan to stop the program. This just skips to the end phase,
|
|
||||||
where there is no behavior defined.
|
|
||||||
"""
|
|
||||||
self._asp.plans.append(
|
|
||||||
AstPlan(
|
|
||||||
TriggerType.ADDED_GOAL,
|
|
||||||
AstLiteral("stop"),
|
|
||||||
[AstLiteral("phase", [AstString(phase.id)])],
|
|
||||||
[
|
|
||||||
AstStatement(
|
|
||||||
StatementType.DO_ACTION,
|
|
||||||
AstLiteral(
|
|
||||||
"notify_transition_phase",
|
|
||||||
[
|
|
||||||
AstString(phase.id),
|
|
||||||
AstString("end")
|
|
||||||
]
|
|
||||||
)
|
|
||||||
),
|
|
||||||
AstStatement(
|
|
||||||
StatementType.REMOVE_BELIEF,
|
|
||||||
AstLiteral("phase", [AstVar("Phase")]),
|
|
||||||
),
|
|
||||||
AstStatement(
|
|
||||||
StatementType.ADD_BELIEF,
|
|
||||||
AstLiteral("phase", [AstString("end")])
|
|
||||||
)
|
|
||||||
]
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
def _process_phases(self, phases: list[Phase]) -> None:
|
def _process_phases(self, phases: list[Phase]) -> None:
|
||||||
"""
|
"""
|
||||||
Processes all phases in the program and their transitions.
|
Processes all phases in the program and their transitions.
|
||||||
@@ -284,6 +285,21 @@ class AgentSpeakGenerator:
|
|||||||
self._process_phase(curr_phase)
|
self._process_phase(curr_phase)
|
||||||
self._add_phase_transition(curr_phase, next_phase)
|
self._add_phase_transition(curr_phase, next_phase)
|
||||||
|
|
||||||
|
# End phase behavior
|
||||||
|
# When deleting this, the entire `reply` plan and action can be deleted
|
||||||
|
self._asp.plans.append(
|
||||||
|
AstPlan(
|
||||||
|
type=TriggerType.ADDED_BELIEF,
|
||||||
|
trigger_literal=AstLiteral("user_said", [AstVar("Message")]),
|
||||||
|
context=[AstLiteral("phase", [AstString("end")])],
|
||||||
|
body=[
|
||||||
|
AstStatement(
|
||||||
|
StatementType.DO_ACTION, AstLiteral("notify_user_said", [AstVar("Message")])
|
||||||
|
),
|
||||||
|
AstStatement(StatementType.ACHIEVE_GOAL, AstLiteral("reply")),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
def _process_phase(self, phase: Phase) -> None:
|
def _process_phase(self, phase: Phase) -> None:
|
||||||
"""
|
"""
|
||||||
@@ -310,9 +326,6 @@ class AgentSpeakGenerator:
|
|||||||
for trigger in phase.triggers:
|
for trigger in phase.triggers:
|
||||||
self._process_trigger(trigger, phase)
|
self._process_trigger(trigger, phase)
|
||||||
|
|
||||||
# Add force transition to end phase
|
|
||||||
self._add_stop_plan(phase)
|
|
||||||
|
|
||||||
def _add_phase_transition(self, from_phase: Phase | None, to_phase: Phase | None) -> None:
|
def _add_phase_transition(self, from_phase: Phase | None, to_phase: Phase | None) -> None:
|
||||||
"""
|
"""
|
||||||
Adds plans for transitioning between phases.
|
Adds plans for transitioning between phases.
|
||||||
@@ -488,13 +501,9 @@ class AgentSpeakGenerator:
|
|||||||
if isinstance(step, Goal):
|
if isinstance(step, Goal):
|
||||||
subgoals.append(step)
|
subgoals.append(step)
|
||||||
|
|
||||||
if not goal.can_fail:
|
if not goal.can_fail and not continues_response:
|
||||||
body.append(AstStatement(StatementType.ADD_BELIEF, self._astify(goal, achieved=True)))
|
body.append(AstStatement(StatementType.ADD_BELIEF, self._astify(goal, achieved=True)))
|
||||||
|
|
||||||
if len(body) == 0:
|
|
||||||
self.logger.warning("Goal with no plan detected: %s", goal.name)
|
|
||||||
body.append(AstStatement(StatementType.EMPTY, AstLiteral("true")))
|
|
||||||
|
|
||||||
self._asp.plans.append(AstPlan(TriggerType.ADDED_GOAL, self._astify(goal), context, body))
|
self._asp.plans.append(AstPlan(TriggerType.ADDED_GOAL, self._astify(goal), context, body))
|
||||||
|
|
||||||
self._asp.plans.append(
|
self._asp.plans.append(
|
||||||
@@ -555,10 +564,10 @@ class AgentSpeakGenerator:
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
for step in trigger.plan.steps:
|
for step in trigger.plan.steps:
|
||||||
if isinstance(step, Goal):
|
|
||||||
new_step = step.model_copy(update={"can_fail": False}) # triggers are sequence
|
|
||||||
subgoals.append(new_step)
|
|
||||||
body.append(self._step_to_statement(step))
|
body.append(self._step_to_statement(step))
|
||||||
|
if isinstance(step, Goal):
|
||||||
|
step.can_fail = False # triggers are continuous sequence
|
||||||
|
subgoals.append(step)
|
||||||
|
|
||||||
# Arbitrary wait for UI to display nicely
|
# Arbitrary wait for UI to display nicely
|
||||||
body.append(
|
body.append(
|
||||||
@@ -602,7 +611,6 @@ class AgentSpeakGenerator:
|
|||||||
- check_triggers: When no triggers are applicable
|
- check_triggers: When no triggers are applicable
|
||||||
- transition_phase: When phase transition conditions aren't met
|
- transition_phase: When phase transition conditions aren't met
|
||||||
- force_transition_phase: When forced transitions aren't possible
|
- force_transition_phase: When forced transitions aren't possible
|
||||||
- stop: When we are already in the end phase
|
|
||||||
"""
|
"""
|
||||||
# Trigger fallback
|
# Trigger fallback
|
||||||
self._asp.plans.append(
|
self._asp.plans.append(
|
||||||
@@ -634,16 +642,6 @@ class AgentSpeakGenerator:
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Stop fallback
|
|
||||||
self._asp.plans.append(
|
|
||||||
AstPlan(
|
|
||||||
TriggerType.ADDED_GOAL,
|
|
||||||
AstLiteral("stop"),
|
|
||||||
[],
|
|
||||||
[AstStatement(StatementType.EMPTY, AstLiteral("true"))],
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
@singledispatchmethod
|
@singledispatchmethod
|
||||||
def _astify(self, element: ProgramElement) -> AstExpression:
|
def _astify(self, element: ProgramElement) -> AstExpression:
|
||||||
"""
|
"""
|
||||||
@@ -690,6 +688,10 @@ class AgentSpeakGenerator:
|
|||||||
def _(self, eb: EmotionBelief) -> AstExpression:
|
def _(self, eb: EmotionBelief) -> AstExpression:
|
||||||
return AstLiteral("emotion_detected", [AstAtom(eb.emotion)])
|
return AstLiteral("emotion_detected", [AstAtom(eb.emotion)])
|
||||||
|
|
||||||
|
@_astify.register
|
||||||
|
def _(self, eb: FaceBelief) -> AstExpression:
|
||||||
|
return AstLiteral("face_present")
|
||||||
|
|
||||||
@_astify.register
|
@_astify.register
|
||||||
def _(self, ib: InferredBelief) -> AstExpression:
|
def _(self, ib: InferredBelief) -> AstExpression:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -176,8 +176,6 @@ class BDICoreAgent(BaseAgent):
|
|||||||
self._force_norm(msg.body)
|
self._force_norm(msg.body)
|
||||||
case "force_next_phase":
|
case "force_next_phase":
|
||||||
self._force_next_phase()
|
self._force_next_phase()
|
||||||
case "stop":
|
|
||||||
self._stop()
|
|
||||||
case _:
|
case _:
|
||||||
self.logger.warning("Received unknown user interruption: %s", msg)
|
self.logger.warning("Received unknown user interruption: %s", msg)
|
||||||
|
|
||||||
@@ -337,11 +335,6 @@ class BDICoreAgent(BaseAgent):
|
|||||||
|
|
||||||
self.logger.info("Manually forced phase transition.")
|
self.logger.info("Manually forced phase transition.")
|
||||||
|
|
||||||
def _stop(self):
|
|
||||||
self._set_goal("stop")
|
|
||||||
|
|
||||||
self.logger.info("Stopped the program (skipped to end phase).")
|
|
||||||
|
|
||||||
def _add_custom_actions(self) -> None:
|
def _add_custom_actions(self) -> None:
|
||||||
"""
|
"""
|
||||||
Add any custom actions here. Inside `@self.actions.add()`, the first argument is
|
Add any custom actions here. Inside `@self.actions.add()`, the first argument is
|
||||||
@@ -349,28 +342,6 @@ class BDICoreAgent(BaseAgent):
|
|||||||
the function expects (which will be located in `term.args`).
|
the function expects (which will be located in `term.args`).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@self.actions.add(".substring_case_insensitive", 3)
|
|
||||||
@agentspeak.optimizer.function_like
|
|
||||||
def _substring(agent, term, intention):
|
|
||||||
"""
|
|
||||||
Find out if a string is a substring of another (case insensitive). Copied mostly from
|
|
||||||
the agentspeak library method .substring.
|
|
||||||
"""
|
|
||||||
needle = agentspeak.asl_str(agentspeak.grounded(term.args[0], intention.scope)).lower()
|
|
||||||
haystack = agentspeak.asl_str(agentspeak.grounded(term.args[1], intention.scope)).lower()
|
|
||||||
|
|
||||||
choicepoint = object()
|
|
||||||
|
|
||||||
pos = haystack.find(needle)
|
|
||||||
while pos != -1:
|
|
||||||
intention.stack.append(choicepoint)
|
|
||||||
|
|
||||||
if agentspeak.unify(term.args[2], pos, intention.scope, intention.stack):
|
|
||||||
yield
|
|
||||||
|
|
||||||
agentspeak.reroll(intention.scope, intention.stack, choicepoint)
|
|
||||||
pos = haystack.find(needle, pos + 1)
|
|
||||||
|
|
||||||
@self.actions.add(".reply", 2)
|
@self.actions.add(".reply", 2)
|
||||||
def _reply(agent, term, intention):
|
def _reply(agent, term, intention):
|
||||||
"""
|
"""
|
||||||
@@ -496,6 +467,7 @@ class BDICoreAgent(BaseAgent):
|
|||||||
body=str(trigger_name),
|
body=str(trigger_name),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# TODO: check with Pim
|
||||||
self.add_behavior(self.send(msg))
|
self.add_behavior(self.send(msg))
|
||||||
|
|
||||||
yield
|
yield
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ from control_backend.agents.perception.visual_emotion_recognition_agent.visual_e
|
|||||||
from control_backend.core.config import settings
|
from control_backend.core.config import settings
|
||||||
|
|
||||||
from ..actuation.robot_speech_agent import RobotSpeechAgent
|
from ..actuation.robot_speech_agent import RobotSpeechAgent
|
||||||
from ..perception import VADAgent
|
from ..perception import FacePerceptionAgent, VADAgent
|
||||||
|
|
||||||
|
|
||||||
class RICommunicationAgent(BaseAgent):
|
class RICommunicationAgent(BaseAgent):
|
||||||
@@ -181,7 +181,7 @@ class RICommunicationAgent(BaseAgent):
|
|||||||
bind = port_data["bind"]
|
bind = port_data["bind"]
|
||||||
|
|
||||||
if not bind:
|
if not bind:
|
||||||
addr = f"tcp://{settings.ri_host}:{port}"
|
addr = f"tcp://localhost:{port}"
|
||||||
else:
|
else:
|
||||||
addr = f"tcp://*:{port}"
|
addr = f"tcp://*:{port}"
|
||||||
|
|
||||||
@@ -224,6 +224,13 @@ class RICommunicationAgent(BaseAgent):
|
|||||||
)
|
)
|
||||||
self.visual_emotion_recognition_agent = visual_emotion_agent
|
self.visual_emotion_recognition_agent = visual_emotion_agent
|
||||||
await visual_emotion_agent.start()
|
await visual_emotion_agent.start()
|
||||||
|
case "face":
|
||||||
|
face_agent = FacePerceptionAgent(
|
||||||
|
settings.agent_settings.face_agent_name,
|
||||||
|
zmq_address=addr,
|
||||||
|
zmq_bind=bind,
|
||||||
|
)
|
||||||
|
await face_agent.start()
|
||||||
case _:
|
case _:
|
||||||
self.logger.warning("Unhandled negotiation id: %s", id)
|
self.logger.warning("Unhandled negotiation id: %s", id)
|
||||||
|
|
||||||
@@ -338,4 +345,3 @@ class RICommunicationAgent(BaseAgent):
|
|||||||
self.logger.debug("Restarting communication negotiation.")
|
self.logger.debug("Restarting communication negotiation.")
|
||||||
if await self._negotiate_connection(max_retries=2):
|
if await self._negotiate_connection(max_retries=2):
|
||||||
self.connected = True
|
self.connected = True
|
||||||
|
|
||||||
@@ -185,9 +185,6 @@ class LLMAgent(BaseAgent):
|
|||||||
full_message = ""
|
full_message = ""
|
||||||
current_chunk = ""
|
current_chunk = ""
|
||||||
async for token in self._stream_query_llm(messages):
|
async for token in self._stream_query_llm(messages):
|
||||||
if self._interrupted:
|
|
||||||
return
|
|
||||||
|
|
||||||
full_message += token
|
full_message += token
|
||||||
current_chunk += token
|
current_chunk += token
|
||||||
|
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ Agents responsible for processing sensory input, such as audio transcription and
|
|||||||
detection.
|
detection.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from .face_rec_agent import FacePerceptionAgent as FacePerceptionAgent
|
||||||
from .transcription_agent.transcription_agent import (
|
from .transcription_agent.transcription_agent import (
|
||||||
TranscriptionAgent as TranscriptionAgent,
|
TranscriptionAgent as TranscriptionAgent,
|
||||||
)
|
)
|
||||||
|
|||||||
144
src/control_backend/agents/perception/face_rec_agent.py
Normal file
144
src/control_backend/agents/perception/face_rec_agent.py
Normal file
@@ -0,0 +1,144 @@
|
|||||||
|
"""
|
||||||
|
This program has been developed by students from the bachelor Computer Science at Utrecht
|
||||||
|
University within the Software Project course.
|
||||||
|
© Copyright Utrecht University (Department of Information and Computing Sciences)
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
import zmq
|
||||||
|
import zmq.asyncio as azmq
|
||||||
|
|
||||||
|
from control_backend.agents import BaseAgent
|
||||||
|
from control_backend.core.agent_system import InternalMessage
|
||||||
|
from control_backend.core.config import settings
|
||||||
|
from control_backend.schemas.belief_message import Belief, BeliefMessage
|
||||||
|
|
||||||
|
|
||||||
|
class FacePerceptionAgent(BaseAgent):
|
||||||
|
"""
|
||||||
|
Receives face presence updates from the RICommunicationAgent
|
||||||
|
via the internal PUB/SUB bus.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, name: str, zmq_address: str, zmq_bind: bool):
|
||||||
|
"""
|
||||||
|
:param name: The name of the agent.
|
||||||
|
:param zmq_address: The ZMQ address to subscribe to, an endpoint which sends face presence
|
||||||
|
updates.
|
||||||
|
:param zmq_bind: Whether to connect to the ZMQ endpoint, or to bind.
|
||||||
|
"""
|
||||||
|
super().__init__(name)
|
||||||
|
self._zmq_address = zmq_address
|
||||||
|
self._zmq_bind = zmq_bind
|
||||||
|
self._socket: azmq.Socket | None = None
|
||||||
|
|
||||||
|
self._last_face_state: bool | None = None
|
||||||
|
|
||||||
|
# Pause functionality
|
||||||
|
# NOTE: flag is set when running, cleared when paused
|
||||||
|
self._paused = asyncio.Event()
|
||||||
|
self._paused.set()
|
||||||
|
|
||||||
|
async def setup(self):
|
||||||
|
self.logger.info("Starting FacePerceptionAgent")
|
||||||
|
|
||||||
|
if self._socket is None:
|
||||||
|
self._connect_socket()
|
||||||
|
|
||||||
|
self.add_behavior(self._poll_loop())
|
||||||
|
self.logger.info("Finished setting up %s", self.name)
|
||||||
|
|
||||||
|
def _connect_socket(self):
|
||||||
|
if self._socket is not None:
|
||||||
|
self.logger.warning("ZMQ socket already initialized. Did you call setup() twice?")
|
||||||
|
return
|
||||||
|
|
||||||
|
self._socket = azmq.Context.instance().socket(zmq.SUB)
|
||||||
|
self._socket.setsockopt_string(zmq.SUBSCRIBE, "")
|
||||||
|
if self._zmq_bind:
|
||||||
|
self._socket.bind(self._zmq_address)
|
||||||
|
else:
|
||||||
|
self._socket.connect(self._zmq_address)
|
||||||
|
|
||||||
|
async def _poll_loop(self):
|
||||||
|
if self._socket is None:
|
||||||
|
self.logger.warning("Connection not initialized before poll loop. Call setup() first.")
|
||||||
|
return
|
||||||
|
|
||||||
|
while self._running:
|
||||||
|
try:
|
||||||
|
await self._paused.wait()
|
||||||
|
response = await asyncio.wait_for(
|
||||||
|
self._socket.recv_json(), timeout=settings.behaviour_settings.sleep_s
|
||||||
|
)
|
||||||
|
|
||||||
|
face_present = response.get("face_detected", False)
|
||||||
|
|
||||||
|
if self._last_face_state is None:
|
||||||
|
self._last_face_state = face_present
|
||||||
|
continue
|
||||||
|
|
||||||
|
if face_present != self._last_face_state:
|
||||||
|
self._last_face_state = face_present
|
||||||
|
self.logger.debug("Face detected" if face_present else "Face lost")
|
||||||
|
await self._update_face_belief(face_present)
|
||||||
|
except TimeoutError:
|
||||||
|
pass
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error("Face polling failed", exc_info=e)
|
||||||
|
|
||||||
|
async def _post_face_belief(self, present: bool):
|
||||||
|
"""
|
||||||
|
Send a face_present belief update to the BDI Core Agent.
|
||||||
|
"""
|
||||||
|
if present:
|
||||||
|
belief_msg = BeliefMessage(create=[{"name": "face_present", "arguments": []}])
|
||||||
|
else:
|
||||||
|
belief_msg = BeliefMessage(delete=[{"name": "face_present", "arguments": []}])
|
||||||
|
|
||||||
|
msg = InternalMessage(
|
||||||
|
to=settings.agent_settings.bdi_core_name,
|
||||||
|
sender=self.name,
|
||||||
|
thread="beliefs",
|
||||||
|
body=belief_msg.model_dump_json(),
|
||||||
|
)
|
||||||
|
|
||||||
|
await self.send(msg)
|
||||||
|
|
||||||
|
async def _update_face_belief(self, present: bool):
|
||||||
|
"""
|
||||||
|
Add or remove the `face_present` belief in the BDI Core Agent.
|
||||||
|
"""
|
||||||
|
if present:
|
||||||
|
payload = BeliefMessage(create=[Belief(name="face_present").model_dump()])
|
||||||
|
else:
|
||||||
|
payload = BeliefMessage(delete=[Belief(name="face_present").model_dump()])
|
||||||
|
|
||||||
|
message = InternalMessage(
|
||||||
|
to=settings.agent_settings.bdi_core_name,
|
||||||
|
sender=self.name,
|
||||||
|
thread="beliefs",
|
||||||
|
body=payload.model_dump_json(),
|
||||||
|
)
|
||||||
|
|
||||||
|
await self.send(message)
|
||||||
|
|
||||||
|
async def handle_message(self, msg: InternalMessage):
|
||||||
|
"""
|
||||||
|
Handle incoming pause/resume commands from User Interrupt Agent.
|
||||||
|
"""
|
||||||
|
sender = msg.sender
|
||||||
|
|
||||||
|
if sender == settings.agent_settings.user_interrupt_name:
|
||||||
|
if msg.body == "PAUSE":
|
||||||
|
self.logger.info("Pausing Face Perception processing.")
|
||||||
|
self._paused.clear()
|
||||||
|
self._last_face_state = None
|
||||||
|
elif msg.body == "RESUME":
|
||||||
|
self.logger.info("Resuming Face Perception processing.")
|
||||||
|
self._paused.set()
|
||||||
|
else:
|
||||||
|
self.logger.warning("Unknown command from User Interrupt Agent: %s", msg.body)
|
||||||
|
else:
|
||||||
|
self.logger.debug("Ignoring message from unknown sender: %s", sender)
|
||||||
@@ -3,6 +3,7 @@ import json
|
|||||||
import time
|
import time
|
||||||
from collections import Counter, defaultdict
|
from collections import Counter, defaultdict
|
||||||
|
|
||||||
|
import cv2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import zmq
|
import zmq
|
||||||
import zmq.asyncio as azmq
|
import zmq.asyncio as azmq
|
||||||
@@ -63,8 +64,6 @@ class VisualEmotionRecognitionAgent(BaseAgent):
|
|||||||
|
|
||||||
self.video_in_socket = azmq.Context.instance().socket(zmq.SUB)
|
self.video_in_socket = azmq.Context.instance().socket(zmq.SUB)
|
||||||
|
|
||||||
self.video_in_socket.setsockopt(zmq.RCVHWM, 3)
|
|
||||||
|
|
||||||
if self.socket_bind:
|
if self.socket_bind:
|
||||||
self.video_in_socket.bind(self.socket_address)
|
self.video_in_socket.bind(self.socket_address)
|
||||||
else:
|
else:
|
||||||
@@ -72,9 +71,12 @@ class VisualEmotionRecognitionAgent(BaseAgent):
|
|||||||
|
|
||||||
self.video_in_socket.setsockopt_string(zmq.SUBSCRIBE, "")
|
self.video_in_socket.setsockopt_string(zmq.SUBSCRIBE, "")
|
||||||
self.video_in_socket.setsockopt(zmq.RCVTIMEO, self.timeout_ms)
|
self.video_in_socket.setsockopt(zmq.RCVTIMEO, self.timeout_ms)
|
||||||
|
self.video_in_socket.setsockopt(zmq.CONFLATE, 1)
|
||||||
|
|
||||||
self.add_behavior(self.emotion_update_loop())
|
self.add_behavior(self.emotion_update_loop())
|
||||||
|
|
||||||
|
self.logger.info("Finished setting up %s", self.name)
|
||||||
|
|
||||||
async def emotion_update_loop(self):
|
async def emotion_update_loop(self):
|
||||||
"""
|
"""
|
||||||
Background loop to receive video frames, recognize emotions, and update beliefs.
|
Background loop to receive video frames, recognize emotions, and update beliefs.
|
||||||
@@ -95,18 +97,21 @@ class VisualEmotionRecognitionAgent(BaseAgent):
|
|||||||
try:
|
try:
|
||||||
await self._paused.wait()
|
await self._paused.wait()
|
||||||
|
|
||||||
width, height, image_bytes = await self.video_in_socket.recv_multipart()
|
frame_bytes = await self.video_in_socket.recv()
|
||||||
|
|
||||||
width = int.from_bytes(width, 'little')
|
|
||||||
height = int.from_bytes(height, 'little')
|
|
||||||
|
|
||||||
# Convert bytes to a numpy buffer
|
# Convert bytes to a numpy buffer
|
||||||
image_array = np.frombuffer(image_bytes, np.uint8)
|
nparr = np.frombuffer(frame_bytes, np.uint8)
|
||||||
|
|
||||||
frame = image_array.reshape((height, width, 3))
|
# Decode image into the generic Numpy Array DeepFace expects
|
||||||
|
frame_image = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
|
||||||
|
|
||||||
|
if frame_image is None:
|
||||||
|
# Could not decode image, skip this frame
|
||||||
|
self.logger.warning("Received invalid video frame, skipping.")
|
||||||
|
continue
|
||||||
|
|
||||||
# Get the dominant emotion from each face
|
# Get the dominant emotion from each face
|
||||||
current_emotions = self.emotion_recognizer.sorted_dominant_emotions(frame)
|
current_emotions = self.emotion_recognizer.sorted_dominant_emotions(frame_image)
|
||||||
# Update emotion counts for each detected face
|
# Update emotion counts for each detected face
|
||||||
for i, emotion in enumerate(current_emotions):
|
for i, emotion in enumerate(current_emotions):
|
||||||
face_stats[i][emotion] += 1
|
face_stats[i][emotion] += 1
|
||||||
@@ -128,11 +133,7 @@ class VisualEmotionRecognitionAgent(BaseAgent):
|
|||||||
next_window_time = time.time() + self.window_duration
|
next_window_time = time.time() + self.window_duration
|
||||||
|
|
||||||
except zmq.Again:
|
except zmq.Again:
|
||||||
pass
|
self.logger.warning("No video frame received within timeout.")
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
self.logger.error(f"Error in emotion recognition loop: {e}")
|
|
||||||
|
|
||||||
|
|
||||||
async def update_emotions(self, prev_emotions: set[str], emotions: set[str]):
|
async def update_emotions(self, prev_emotions: set[str], emotions: set[str]):
|
||||||
"""
|
"""
|
||||||
@@ -204,4 +205,3 @@ class VisualEmotionRecognitionAgent(BaseAgent):
|
|||||||
"""
|
"""
|
||||||
self.video_in_socket.close()
|
self.video_in_socket.close()
|
||||||
await super().stop()
|
await super().stop()
|
||||||
|
|
||||||
|
|||||||
@@ -23,31 +23,32 @@ class VisualEmotionRecognizer(abc.ABC):
|
|||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class DeepFaceEmotionRecognizer(VisualEmotionRecognizer):
|
class DeepFaceEmotionRecognizer(VisualEmotionRecognizer):
|
||||||
"""
|
"""
|
||||||
DeepFace-based implementation of VisualEmotionRecognizer.
|
DeepFace-based implementation of VisualEmotionRecognizer.
|
||||||
DeepFape has proven to be quite a pessimistic model, so expect sad, fear and neutral
|
DeepFape has proven to be quite a pessimistic model, so expect sad, fear and neutral
|
||||||
emotions to be over-represented.
|
emotions to be over-represented.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.load_model()
|
self.load_model()
|
||||||
|
|
||||||
def load_model(self):
|
def load_model(self):
|
||||||
|
print("Loading Deepface Emotion Model...")
|
||||||
dummy_img = np.zeros((224, 224, 3), dtype=np.uint8)
|
dummy_img = np.zeros((224, 224, 3), dtype=np.uint8)
|
||||||
# analyze does not take a model as an argument, calling it once on a dummy image to load
|
# analyze does not take a model as an argument, calling it once on a dummy image to load
|
||||||
# the model
|
# the model
|
||||||
DeepFace.analyze(dummy_img, actions=['emotion'], enforce_detection=False)
|
DeepFace.analyze(dummy_img, actions=["emotion"], enforce_detection=False)
|
||||||
|
print("Deepface Emotion Model loaded.")
|
||||||
|
|
||||||
def sorted_dominant_emotions(self, image) -> list[str]:
|
def sorted_dominant_emotions(self, image) -> list[str]:
|
||||||
analysis = DeepFace.analyze(image,
|
analysis = DeepFace.analyze(image, actions=["emotion"], enforce_detection=False)
|
||||||
actions=['emotion'],
|
|
||||||
enforce_detection=False
|
|
||||||
)
|
|
||||||
|
|
||||||
# Sort faces by x coordinate to maintain left-to-right order
|
# Sort faces by x coordinate to maintain left-to-right order
|
||||||
analysis.sort(key=lambda face: face['region']['x'])
|
analysis.sort(key=lambda face: face["region"]["x"])
|
||||||
|
|
||||||
analysis = [face for face in analysis if face['face_confidence'] >= 0.90]
|
analysis = [face for face in analysis if face["face_confidence"] >= 0.90]
|
||||||
|
|
||||||
dominant_emotions = [face['dominant_emotion'] for face in analysis]
|
dominant_emotions = [face["dominant_emotion"] for face in analysis]
|
||||||
return dominant_emotions
|
return dominant_emotions
|
||||||
|
|||||||
@@ -164,12 +164,6 @@ class UserInterruptAgent(BaseAgent):
|
|||||||
else:
|
else:
|
||||||
self.logger.info("Sent resume command.")
|
self.logger.info("Sent resume command.")
|
||||||
|
|
||||||
case "stop":
|
|
||||||
self.logger.debug(
|
|
||||||
"Received stop command."
|
|
||||||
)
|
|
||||||
await self._send_stop_command()
|
|
||||||
|
|
||||||
case "next_phase" | "reset_phase":
|
case "next_phase" | "reset_phase":
|
||||||
await self._send_experiment_control_to_bdi_core(event_type)
|
await self._send_experiment_control_to_bdi_core(event_type)
|
||||||
case _:
|
case _:
|
||||||
@@ -410,34 +404,28 @@ class UserInterruptAgent(BaseAgent):
|
|||||||
if pause == "true":
|
if pause == "true":
|
||||||
# Send pause to VAD and VED agent
|
# Send pause to VAD and VED agent
|
||||||
vad_message = InternalMessage(
|
vad_message = InternalMessage(
|
||||||
to=[settings.agent_settings.vad_name,
|
to=[
|
||||||
settings.agent_settings.visual_emotion_recognition_name],
|
settings.agent_settings.vad_name,
|
||||||
|
settings.agent_settings.visual_emotion_recognition_name,
|
||||||
|
settings.agent_settings.face_agent_name,
|
||||||
|
],
|
||||||
sender=self.name,
|
sender=self.name,
|
||||||
body="PAUSE",
|
body="PAUSE",
|
||||||
)
|
)
|
||||||
await self.send(vad_message)
|
await self.send(vad_message)
|
||||||
# Voice Activity Detection and Visual Emotion Recognition agents
|
# Voice Activity Detection and Visual Emotion Recognition agents
|
||||||
self.logger.info("Sent pause command to VAD and VED agents.")
|
self.logger.info("Sent pause command to perception agents.")
|
||||||
else:
|
else:
|
||||||
# Send resume to VAD and VED agents
|
# Send resume to VAD and VED agents
|
||||||
vad_message = InternalMessage(
|
vad_message = InternalMessage(
|
||||||
to=[settings.agent_settings.vad_name,
|
to=[
|
||||||
settings.agent_settings.visual_emotion_recognition_name],
|
settings.agent_settings.vad_name,
|
||||||
|
settings.agent_settings.visual_emotion_recognition_name,
|
||||||
|
settings.agent_settings.face_agent_name,
|
||||||
|
],
|
||||||
sender=self.name,
|
sender=self.name,
|
||||||
body="RESUME",
|
body="RESUME",
|
||||||
)
|
)
|
||||||
await self.send(vad_message)
|
await self.send(vad_message)
|
||||||
# Voice Activity Detection and Visual Emotion Recognition agents
|
# Voice Activity Detection and Visual Emotion Recognition agents
|
||||||
self.logger.info("Sent resume command to VAD and VED agents.")
|
self.logger.info("Sent resume command to perception agents.")
|
||||||
|
|
||||||
async def _send_stop_command(self):
|
|
||||||
"""
|
|
||||||
Send a command to the BDI to stop the program (i.e., skip to end phase).
|
|
||||||
"""
|
|
||||||
msg = InternalMessage(
|
|
||||||
to=settings.agent_settings.bdi_core_name,
|
|
||||||
body="",
|
|
||||||
thread="stop"
|
|
||||||
)
|
|
||||||
|
|
||||||
await self.send(msg)
|
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ api_router = APIRouter()
|
|||||||
|
|
||||||
api_router.include_router(message.router, tags=["Messages"])
|
api_router.include_router(message.router, tags=["Messages"])
|
||||||
|
|
||||||
api_router.include_router(robot.router, prefix="/robot", tags=["Pings", "Commands"])
|
api_router.include_router(robot.router, prefix="/robot", tags=["Pings", "Commands", "Face"])
|
||||||
|
|
||||||
api_router.include_router(logs.router, tags=["Logs"])
|
api_router.include_router(logs.router, tags=["Logs"])
|
||||||
|
|
||||||
|
|||||||
@@ -64,6 +64,7 @@ class AgentSettings(BaseModel):
|
|||||||
robot_speech_name: str = "robot_speech_agent"
|
robot_speech_name: str = "robot_speech_agent"
|
||||||
robot_gesture_name: str = "robot_gesture_agent"
|
robot_gesture_name: str = "robot_gesture_agent"
|
||||||
user_interrupt_name: str = "user_interrupt_agent"
|
user_interrupt_name: str = "user_interrupt_agent"
|
||||||
|
face_agent_name: str = "face_detection_agent"
|
||||||
|
|
||||||
|
|
||||||
class BehaviourSettings(BaseModel):
|
class BehaviourSettings(BaseModel):
|
||||||
@@ -82,12 +83,12 @@ class BehaviourSettings(BaseModel):
|
|||||||
:ivar transcription_words_per_token: Estimated words per token for transcription timing.
|
:ivar transcription_words_per_token: Estimated words per token for transcription timing.
|
||||||
:ivar transcription_token_buffer: Buffer for transcription tokens.
|
:ivar transcription_token_buffer: Buffer for transcription tokens.
|
||||||
:ivar conversation_history_length_limit: The maximum amount of messages to extract beliefs from.
|
:ivar conversation_history_length_limit: The maximum amount of messages to extract beliefs from.
|
||||||
|
:ivar trigger_time_to_wait: Amount of milliseconds to wait before informing the UI about trigger
|
||||||
|
completion.
|
||||||
:ivar visual_emotion_recognition_window_duration_s: Duration in seconds over which to aggregate
|
:ivar visual_emotion_recognition_window_duration_s: Duration in seconds over which to aggregate
|
||||||
emotions and update emotion beliefs.
|
emotions and update emotion beliefs.
|
||||||
:ivar visual_emotion_recognition_min_frames_per_face: Minimum number of frames per face required
|
:ivar visual_emotion_recognition_min_frames_per_face: Minimum number of frames per face required
|
||||||
to consider a face valid.
|
to consider a face valid.
|
||||||
:ivar trigger_time_to_wait: Amount of milliseconds to wait before informing the UI about trigger
|
|
||||||
completion.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# ATTENTION: When adding/removing settings, make sure to update the .env.example file
|
# ATTENTION: When adding/removing settings, make sure to update the .env.example file
|
||||||
@@ -111,13 +112,14 @@ class BehaviourSettings(BaseModel):
|
|||||||
# Text belief extractor settings
|
# Text belief extractor settings
|
||||||
conversation_history_length_limit: int = 10
|
conversation_history_length_limit: int = 10
|
||||||
|
|
||||||
# Visual Emotion Recognition settings
|
|
||||||
visual_emotion_recognition_window_duration_s: int = 5
|
|
||||||
visual_emotion_recognition_min_frames_per_face: int = 3
|
|
||||||
# AgentSpeak related settings
|
# AgentSpeak related settings
|
||||||
trigger_time_to_wait: int = 2000
|
trigger_time_to_wait: int = 2000
|
||||||
agentspeak_file: str = "src/control_backend/agents/bdi/agentspeak.asl"
|
agentspeak_file: str = "src/control_backend/agents/bdi/agentspeak.asl"
|
||||||
|
|
||||||
|
# Visual Emotion Recognition settings
|
||||||
|
visual_emotion_recognition_window_duration_s: int = 5
|
||||||
|
visual_emotion_recognition_min_frames_per_face: int = 3
|
||||||
|
|
||||||
|
|
||||||
class LLMSettings(BaseModel):
|
class LLMSettings(BaseModel):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ University within the Software Project course.
|
|||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Literal
|
from typing import Literal
|
||||||
|
|
||||||
from pydantic import UUID4, BaseModel, field_validator
|
from pydantic import UUID4, BaseModel
|
||||||
|
|
||||||
|
|
||||||
class ProgramElement(BaseModel):
|
class ProgramElement(BaseModel):
|
||||||
@@ -24,13 +24,6 @@ class ProgramElement(BaseModel):
|
|||||||
# To make program elements hashable
|
# To make program elements hashable
|
||||||
model_config = {"frozen": True}
|
model_config = {"frozen": True}
|
||||||
|
|
||||||
@field_validator("name")
|
|
||||||
@classmethod
|
|
||||||
def name_must_not_start_with_number(cls, v: str) -> str:
|
|
||||||
if v and v[0].isdigit():
|
|
||||||
raise ValueError('Field "name" must not start with a number.')
|
|
||||||
return v
|
|
||||||
|
|
||||||
|
|
||||||
class LogicalOperator(Enum):
|
class LogicalOperator(Enum):
|
||||||
"""
|
"""
|
||||||
@@ -48,8 +41,8 @@ class LogicalOperator(Enum):
|
|||||||
OR = "OR"
|
OR = "OR"
|
||||||
|
|
||||||
|
|
||||||
type Belief = KeywordBelief | SemanticBelief | InferredBelief | EmotionBelief
|
type Belief = KeywordBelief | SemanticBelief | InferredBelief | EmotionBelief | FaceBelief
|
||||||
type BasicBelief = KeywordBelief | SemanticBelief | EmotionBelief
|
type BasicBelief = KeywordBelief | SemanticBelief | EmotionBelief | FaceBelief
|
||||||
|
|
||||||
|
|
||||||
class KeywordBelief(ProgramElement):
|
class KeywordBelief(ProgramElement):
|
||||||
@@ -124,6 +117,16 @@ class EmotionBelief(ProgramElement):
|
|||||||
emotion: str
|
emotion: str
|
||||||
|
|
||||||
|
|
||||||
|
class FaceBelief(ProgramElement):
|
||||||
|
"""
|
||||||
|
Represents the belief that at least one face is currently detected.
|
||||||
|
This belief is maintained by a perception agent (not inferred).
|
||||||
|
"""
|
||||||
|
|
||||||
|
face_present: bool
|
||||||
|
name: str = ""
|
||||||
|
|
||||||
|
|
||||||
class Norm(ProgramElement):
|
class Norm(ProgramElement):
|
||||||
"""
|
"""
|
||||||
Base class for behavioral norms that guide the robot's interactions.
|
Base class for behavioral norms that guide the robot's interactions.
|
||||||
|
|||||||
152
test/unit/agents/perception/test_face_detection_agent.py
Normal file
152
test/unit/agents/perception/test_face_detection_agent.py
Normal file
@@ -0,0 +1,152 @@
|
|||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import zmq
|
||||||
|
|
||||||
|
import control_backend.agents.perception.face_rec_agent as face_module
|
||||||
|
from control_backend.agents.perception.face_rec_agent import FacePerceptionAgent
|
||||||
|
from control_backend.core.agent_system import InternalMessage
|
||||||
|
from control_backend.schemas.belief_message import BeliefMessage
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def agent():
|
||||||
|
"""Return a FacePerceptionAgent instance for testing."""
|
||||||
|
return FacePerceptionAgent(
|
||||||
|
name="face_agent",
|
||||||
|
zmq_address="inproc://test",
|
||||||
|
zmq_bind=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def socket():
|
||||||
|
"""Return a mocked ZMQ socket."""
|
||||||
|
sock = AsyncMock()
|
||||||
|
sock.setsockopt_string = MagicMock()
|
||||||
|
sock.connect = MagicMock()
|
||||||
|
sock.bind = MagicMock()
|
||||||
|
return sock
|
||||||
|
|
||||||
|
|
||||||
|
def test_connect_socket_connect(agent, socket, monkeypatch):
|
||||||
|
"""Test that _connect_socket properly connects when zmq_bind=False."""
|
||||||
|
ctx = MagicMock()
|
||||||
|
ctx.socket.return_value = socket
|
||||||
|
monkeypatch.setattr(face_module.azmq, "Context", MagicMock(instance=lambda: ctx))
|
||||||
|
|
||||||
|
agent._connect_socket()
|
||||||
|
socket.setsockopt_string.assert_called_once_with(zmq.SUBSCRIBE, "")
|
||||||
|
socket.connect.assert_called_once_with(agent._zmq_address)
|
||||||
|
socket.bind.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
def test_connect_socket_bind(agent, socket, monkeypatch):
|
||||||
|
"""Test that _connect_socket properly binds when zmq_bind=True."""
|
||||||
|
agent._zmq_bind = True
|
||||||
|
ctx = MagicMock()
|
||||||
|
ctx.socket.return_value = socket
|
||||||
|
monkeypatch.setattr(face_module.azmq, "Context", MagicMock(instance=lambda: ctx))
|
||||||
|
|
||||||
|
agent._connect_socket()
|
||||||
|
socket.bind.assert_called_once_with(agent._zmq_address)
|
||||||
|
socket.connect.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
def test_connect_socket_twice_is_noop(agent, socket):
|
||||||
|
"""Test that calling _connect_socket twice does not overwrite an existing socket."""
|
||||||
|
agent._socket = socket
|
||||||
|
agent._connect_socket()
|
||||||
|
assert agent._socket is socket
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_update_face_belief_present(agent):
|
||||||
|
"""Test that _update_face_belief(True) creates the 'face_present' belief."""
|
||||||
|
agent.send = AsyncMock()
|
||||||
|
await agent._update_face_belief(True)
|
||||||
|
msg = agent.send.await_args.args[0]
|
||||||
|
payload = BeliefMessage.model_validate_json(msg.body)
|
||||||
|
assert payload.create[0].name == "face_present"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_update_face_belief_absent(agent):
|
||||||
|
"""Test that _update_face_belief(False) deletes the 'face_present' belief."""
|
||||||
|
agent.send = AsyncMock()
|
||||||
|
await agent._update_face_belief(False)
|
||||||
|
msg = agent.send.await_args.args[0]
|
||||||
|
payload = BeliefMessage.model_validate_json(msg.body)
|
||||||
|
assert payload.delete[0].name == "face_present"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_post_face_belief_present(agent):
|
||||||
|
"""Test that _post_face_belief(True) sends a belief creation message."""
|
||||||
|
agent.send = AsyncMock()
|
||||||
|
await agent._post_face_belief(True)
|
||||||
|
msg = agent.send.await_args.args[0]
|
||||||
|
assert '"create"' in msg.body and '"face_present"' in msg.body
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_post_face_belief_absent(agent):
|
||||||
|
"""Test that _post_face_belief(False) sends a belief deletion message."""
|
||||||
|
agent.send = AsyncMock()
|
||||||
|
await agent._post_face_belief(False)
|
||||||
|
msg = agent.send.await_args.args[0]
|
||||||
|
assert '"delete"' in msg.body and '"face_present"' in msg.body
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_handle_pause(agent):
|
||||||
|
"""Test that a 'PAUSE' message clears _paused and resets _last_face_state."""
|
||||||
|
agent._paused.set()
|
||||||
|
agent._last_face_state = True
|
||||||
|
msg = InternalMessage(
|
||||||
|
to=agent.name,
|
||||||
|
sender=face_module.settings.agent_settings.user_interrupt_name,
|
||||||
|
thread="cmd",
|
||||||
|
body="PAUSE",
|
||||||
|
)
|
||||||
|
await agent.handle_message(msg)
|
||||||
|
assert not agent._paused.is_set()
|
||||||
|
assert agent._last_face_state is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_handle_resume(agent):
|
||||||
|
"""Test that a 'RESUME' message sets _paused."""
|
||||||
|
agent._paused.clear()
|
||||||
|
msg = InternalMessage(
|
||||||
|
to=agent.name,
|
||||||
|
sender=face_module.settings.agent_settings.user_interrupt_name,
|
||||||
|
thread="cmd",
|
||||||
|
body="RESUME",
|
||||||
|
)
|
||||||
|
await agent.handle_message(msg)
|
||||||
|
assert agent._paused.is_set()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_handle_unknown_command(agent):
|
||||||
|
"""Test that an unknown command from UserInterruptAgent is ignored (logs a warning)."""
|
||||||
|
msg = InternalMessage(
|
||||||
|
to=agent.name,
|
||||||
|
sender=face_module.settings.agent_settings.user_interrupt_name,
|
||||||
|
thread="cmd",
|
||||||
|
body="???",
|
||||||
|
)
|
||||||
|
await agent.handle_message(msg)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_handle_unknown_sender(agent):
|
||||||
|
"""Test that messages from unknown senders are ignored."""
|
||||||
|
msg = InternalMessage(
|
||||||
|
to=agent.name,
|
||||||
|
sender="someone_else",
|
||||||
|
thread="cmd",
|
||||||
|
body="PAUSE",
|
||||||
|
)
|
||||||
|
await agent.handle_message(msg)
|
||||||
@@ -1,338 +0,0 @@
|
|||||||
import asyncio
|
|
||||||
import json
|
|
||||||
import time
|
|
||||||
from unittest.mock import AsyncMock, MagicMock, patch
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import pytest
|
|
||||||
import zmq
|
|
||||||
from pydantic_core import ValidationError
|
|
||||||
|
|
||||||
# Adjust the import path to match your project structure
|
|
||||||
from control_backend.agents.perception.visual_emotion_recognition_agent.visual_emotion_recognition_agent import ( # noqa
|
|
||||||
VisualEmotionRecognitionAgent,
|
|
||||||
)
|
|
||||||
from control_backend.core.agent_system import InternalMessage
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
|
||||||
# Fixtures
|
|
||||||
# -----------------------------------------------------------------------------
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def mock_settings():
|
|
||||||
with patch("control_backend.agents.perception.visual_emotion_recognition_agent.visual_emotion_recognition_agent.settings") as mock: # noqa
|
|
||||||
# Set default values required by the agent
|
|
||||||
mock.behaviour_settings.visual_emotion_recognition_window_duration_s = 5
|
|
||||||
mock.behaviour_settings.visual_emotion_recognition_min_frames_per_face = 3
|
|
||||||
mock.agent_settings.bdi_core_name = "bdi_core_agent"
|
|
||||||
mock.agent_settings.user_interrupt_name = "user_interrupt_agent"
|
|
||||||
yield mock
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def mock_deepface():
|
|
||||||
with patch("control_backend.agents.perception.visual_emotion_recognition_agent.visual_emotion_recognition_agent.DeepFaceEmotionRecognizer") as mock: # noqa
|
|
||||||
instance = mock.return_value
|
|
||||||
instance.sorted_dominant_emotions.return_value = []
|
|
||||||
yield instance
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def mock_zmq_context():
|
|
||||||
with patch("zmq.asyncio.Context.instance") as mock_ctx:
|
|
||||||
mock_socket = MagicMock()
|
|
||||||
# Mock socket methods to return None or AsyncMock for async methods
|
|
||||||
mock_socket.bind = MagicMock()
|
|
||||||
mock_socket.connect = MagicMock()
|
|
||||||
mock_socket.setsockopt = MagicMock()
|
|
||||||
mock_socket.setsockopt_string = MagicMock()
|
|
||||||
mock_socket.recv_multipart = AsyncMock()
|
|
||||||
mock_socket.close = MagicMock()
|
|
||||||
|
|
||||||
mock_ctx.return_value.socket.return_value = mock_socket
|
|
||||||
yield mock_ctx
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def agent(mock_settings, mock_deepface, mock_zmq_context):
|
|
||||||
# Initialize agent with specific params to control testing
|
|
||||||
agent = VisualEmotionRecognitionAgent(
|
|
||||||
name="test_agent",
|
|
||||||
socket_address="tcp://localhost:5555",
|
|
||||||
bind=False,
|
|
||||||
timeout_ms=100,
|
|
||||||
window_duration=2,
|
|
||||||
min_frames_required=2
|
|
||||||
)
|
|
||||||
# Mock the internal send method from BaseAgent
|
|
||||||
agent.send = AsyncMock()
|
|
||||||
# Mock the add_behavior method from BaseAgent
|
|
||||||
agent.add_behavior = MagicMock()
|
|
||||||
# Mock the logger
|
|
||||||
agent.logger = MagicMock()
|
|
||||||
|
|
||||||
return agent
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
|
||||||
# Tests
|
|
||||||
# -----------------------------------------------------------------------------
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_initialization(agent):
|
|
||||||
"""Test that the agent initializes with correct attributes."""
|
|
||||||
assert agent.name == "test_agent"
|
|
||||||
assert agent.socket_address == "tcp://localhost:5555"
|
|
||||||
assert agent.socket_bind is False
|
|
||||||
assert agent.timeout_ms == 100
|
|
||||||
assert agent._paused.is_set()
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_setup_connect(agent, mock_zmq_context, mock_deepface):
|
|
||||||
"""Test setup routine when binding is False (connect)."""
|
|
||||||
agent.socket_bind = False
|
|
||||||
await agent.setup()
|
|
||||||
|
|
||||||
socket = agent.video_in_socket
|
|
||||||
socket.connect.assert_called_with("tcp://localhost:5555")
|
|
||||||
socket.bind.assert_not_called()
|
|
||||||
socket.setsockopt.assert_any_call(zmq.RCVHWM, 3)
|
|
||||||
socket.setsockopt.assert_any_call(zmq.RCVTIMEO, 100)
|
|
||||||
agent.add_behavior.assert_called_once()
|
|
||||||
assert agent.emotion_recognizer == mock_deepface
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_setup_bind(agent, mock_zmq_context):
|
|
||||||
"""Test setup routine when binding is True."""
|
|
||||||
agent.socket_bind = True
|
|
||||||
await agent.setup()
|
|
||||||
|
|
||||||
socket = agent.video_in_socket
|
|
||||||
socket.bind.assert_called_with("tcp://localhost:5555")
|
|
||||||
socket.connect.assert_not_called()
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_emotion_update_loop_normal_flow(agent, mock_deepface):
|
|
||||||
"""
|
|
||||||
Test the main loop logic:
|
|
||||||
1. Receive frames
|
|
||||||
2. Aggregate stats
|
|
||||||
3. Trigger window update
|
|
||||||
4. Call update_emotions
|
|
||||||
"""
|
|
||||||
# Setup dependencies
|
|
||||||
await agent.setup()
|
|
||||||
agent._running = True
|
|
||||||
|
|
||||||
# Create fake image data (10x10 pixels)
|
|
||||||
width, height = 10, 10
|
|
||||||
image_bytes = np.zeros((10, 10, 3), dtype=np.uint8).tobytes()
|
|
||||||
w_bytes = width.to_bytes(4, 'little')
|
|
||||||
h_bytes = height.to_bytes(4, 'little')
|
|
||||||
|
|
||||||
# Mock ZMQ receive to return data 3 times, then stop the loop
|
|
||||||
# We use a side_effect on recv_multipart to simulate frames and then stop the loop
|
|
||||||
async def recv_side_effect():
|
|
||||||
if agent._running:
|
|
||||||
return w_bytes, h_bytes, image_bytes
|
|
||||||
raise asyncio.CancelledError()
|
|
||||||
|
|
||||||
agent.video_in_socket.recv_multipart.side_effect = recv_side_effect
|
|
||||||
|
|
||||||
# Mock DeepFace to return emotions
|
|
||||||
# Frame 1: Happy
|
|
||||||
# Frame 2: Happy
|
|
||||||
# Frame 3: Happy (Trigger window)
|
|
||||||
mock_deepface.sorted_dominant_emotions.side_effect = [
|
|
||||||
["happy"],
|
|
||||||
["happy"],
|
|
||||||
["happy"]
|
|
||||||
]
|
|
||||||
|
|
||||||
# Mock update_emotions to verify it's called
|
|
||||||
agent.update_emotions = AsyncMock()
|
|
||||||
|
|
||||||
# Mock time.time to simulate window passage
|
|
||||||
# We need time to advance significantly after the frames are collected
|
|
||||||
start_time = time.time()
|
|
||||||
|
|
||||||
with patch("time.time") as mock_time:
|
|
||||||
# Sequence of time calls:
|
|
||||||
# 1. Init next_window_time calculation
|
|
||||||
# 2. Loop 1 check
|
|
||||||
# 3. Loop 2 check
|
|
||||||
# 4. Loop 3 check (Make this one pass the window threshold)
|
|
||||||
mock_time.side_effect = [
|
|
||||||
start_time, # init
|
|
||||||
start_time + 0.1, # frame 1 check
|
|
||||||
start_time + 0.2, # frame 2 check
|
|
||||||
start_time + 10.0, # frame 3 check (triggers window reset)
|
|
||||||
start_time + 10.1, # next init
|
|
||||||
start_time + 10.2 # break loop
|
|
||||||
]
|
|
||||||
|
|
||||||
# We need to manually break the infinite loop after the update
|
|
||||||
# We can do this by wrapping update_emotions to set _running = False
|
|
||||||
async def stop_loop(*args, **kwargs):
|
|
||||||
agent._running = False
|
|
||||||
|
|
||||||
agent.update_emotions.side_effect = stop_loop
|
|
||||||
|
|
||||||
# Run the loop
|
|
||||||
await agent.emotion_update_loop()
|
|
||||||
|
|
||||||
# Verifications
|
|
||||||
assert agent.update_emotions.called
|
|
||||||
# Check that it detected 'happy' as dominant (2 required, 3 found)
|
|
||||||
call_args = agent.update_emotions.call_args
|
|
||||||
assert call_args is not None
|
|
||||||
# args: (prev_emotions, window_dominant_emotions)
|
|
||||||
assert call_args[0][1] == {"happy"}
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_emotion_update_loop_insufficient_frames(agent, mock_deepface):
|
|
||||||
"""Test that emotions are NOT updated if min_frames_required is not met."""
|
|
||||||
await agent.setup()
|
|
||||||
agent._running = True
|
|
||||||
agent.min_frames_required = 5 # Set high requirement
|
|
||||||
|
|
||||||
width, height = 10, 10
|
|
||||||
image_bytes = np.zeros((10, 10, 3), dtype=np.uint8).tobytes()
|
|
||||||
w_bytes = width.to_bytes(4, 'little')
|
|
||||||
h_bytes = height.to_bytes(4, 'little')
|
|
||||||
|
|
||||||
agent.video_in_socket.recv_multipart.return_value = (w_bytes, h_bytes, image_bytes)
|
|
||||||
mock_deepface.sorted_dominant_emotions.return_value = ["sad"]
|
|
||||||
|
|
||||||
agent.update_emotions = AsyncMock()
|
|
||||||
|
|
||||||
with patch("time.time") as mock_time:
|
|
||||||
# Time setup to trigger window processing immediately
|
|
||||||
mock_time.side_effect = [0, 100, 101]
|
|
||||||
|
|
||||||
# Stop loop after first pass
|
|
||||||
async def stop_loop(*args, **kwargs):
|
|
||||||
agent._running = False
|
|
||||||
agent.update_emotions.side_effect = stop_loop
|
|
||||||
|
|
||||||
await agent.emotion_update_loop()
|
|
||||||
|
|
||||||
# It should call update_emotions with EMPTY set because min frames (5) > detected (1)
|
|
||||||
call_args = agent.update_emotions.call_args
|
|
||||||
assert call_args[0][1] == set()
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_emotion_update_loop_zmq_again_and_exception(agent):
|
|
||||||
"""Test that the loop handles ZMQ timeouts (Again) and generic exceptions."""
|
|
||||||
await agent.setup()
|
|
||||||
agent._running = True
|
|
||||||
|
|
||||||
# Side effect:
|
|
||||||
# 1. Raise ZMQ Again (Timeout) -> should log warning
|
|
||||||
# 2. Raise Generic Exception -> should log error
|
|
||||||
# 3. Raise CancelledError -> stop loop (simulating stop)
|
|
||||||
agent.video_in_socket.recv_multipart.side_effect = [
|
|
||||||
zmq.Again(),
|
|
||||||
RuntimeError("Random Failure"),
|
|
||||||
asyncio.CancelledError() # To break loop cleanly
|
|
||||||
]
|
|
||||||
|
|
||||||
# We need to ensure the loop doesn't block on _paused
|
|
||||||
agent._paused.set()
|
|
||||||
|
|
||||||
# Run loop
|
|
||||||
try:
|
|
||||||
await agent.emotion_update_loop()
|
|
||||||
except asyncio.CancelledError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_update_emotions_logic(agent, mock_settings):
|
|
||||||
"""Test the logic for calculating diffs and sending messages."""
|
|
||||||
agent.name = "viz_agent"
|
|
||||||
|
|
||||||
# Case 1: No change
|
|
||||||
await agent.update_emotions({"happy"}, {"happy"})
|
|
||||||
agent.send.assert_not_called()
|
|
||||||
|
|
||||||
# Case 2: Remove 'happy', Add 'sad'
|
|
||||||
await agent.update_emotions({"happy"}, {"sad"})
|
|
||||||
|
|
||||||
assert agent.send.called
|
|
||||||
call_args = agent.send.call_args
|
|
||||||
msg = call_args[0][0] # InternalMessage object
|
|
||||||
|
|
||||||
assert msg.to == mock_settings.agent_settings.bdi_core_name
|
|
||||||
assert msg.sender == "viz_agent"
|
|
||||||
assert msg.thread == "beliefs"
|
|
||||||
|
|
||||||
payload = json.loads(msg.body)
|
|
||||||
|
|
||||||
# Check Created Beliefs
|
|
||||||
assert len(payload["create"]) == 1
|
|
||||||
assert payload["create"][0]["name"] == "emotion_detected"
|
|
||||||
assert payload["create"][0]["arguments"] == ["sad"]
|
|
||||||
|
|
||||||
# Check Deleted Beliefs
|
|
||||||
assert len(payload["delete"]) == 1
|
|
||||||
assert payload["delete"][0]["name"] == "emotion_detected"
|
|
||||||
assert payload["delete"][0]["arguments"] == ["happy"]
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_update_emotions_validation_error(agent):
|
|
||||||
"""Test that ValidationErrors during Belief creation are caught."""
|
|
||||||
|
|
||||||
# We patch Belief to raise ValidationError
|
|
||||||
with patch("control_backend.agents.perception.visual_emotion_recognition_agent.visual_emotion_recognition_agent.Belief") as MockBelief: # noqa
|
|
||||||
MockBelief.side_effect = ValidationError.from_exception_data("Simulated Error", [])
|
|
||||||
|
|
||||||
# Try to update emotions
|
|
||||||
await agent.update_emotions(prev_emotions={"happy"}, emotions={"sad"})
|
|
||||||
|
|
||||||
# Verify empty payload is sent (or payload with valid ones if mixed)
|
|
||||||
# In this case both failed, so payload lists should be empty
|
|
||||||
assert agent.send.called
|
|
||||||
msg = agent.send.call_args[0][0]
|
|
||||||
payload = json.loads(msg.body)
|
|
||||||
assert payload["create"] == []
|
|
||||||
assert payload["delete"] == []
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_handle_message(agent, mock_settings):
|
|
||||||
"""Test message handling for Pause/Resume."""
|
|
||||||
|
|
||||||
# Setup
|
|
||||||
ui_name = mock_settings.agent_settings.user_interrupt_name
|
|
||||||
|
|
||||||
# 1. PAUSE message
|
|
||||||
msg_pause = InternalMessage(to="me", sender=ui_name, body="PAUSE")
|
|
||||||
await agent.handle_message(msg_pause)
|
|
||||||
assert not agent._paused.is_set() # Should be cleared (paused)
|
|
||||||
agent.logger.info.assert_called_with("Pausing Visual Emotion Recognition processing.")
|
|
||||||
|
|
||||||
# 2. RESUME message
|
|
||||||
msg_resume = InternalMessage(to="me", sender=ui_name, body="RESUME")
|
|
||||||
await agent.handle_message(msg_resume)
|
|
||||||
assert agent._paused.is_set() # Should be set (running)
|
|
||||||
|
|
||||||
# 3. Unknown command
|
|
||||||
msg_unknown = InternalMessage(to="me", sender=ui_name, body="DANCE")
|
|
||||||
await agent.handle_message(msg_unknown)
|
|
||||||
|
|
||||||
# 4. Unknown sender
|
|
||||||
msg_random = InternalMessage(to="me", sender="random_guy", body="PAUSE")
|
|
||||||
await agent.handle_message(msg_random)
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_stop(agent, mock_zmq_context):
|
|
||||||
"""Test the stop method cleans up resources."""
|
|
||||||
# We need to mock super().stop(). Since we can't easily patch super(),
|
|
||||||
# and the provided BaseAgent code shows stop() just sets _running and cancels tasks,
|
|
||||||
# we can rely on the fact that VisualEmotionRecognitionAgent calls it.
|
|
||||||
|
|
||||||
# However, since we provided a 'agent' fixture that mocks things, we should verify specific cleanups. # noqa
|
|
||||||
await agent.setup()
|
|
||||||
|
|
||||||
with patch("control_backend.agents.BaseAgent.stop", new_callable=AsyncMock) as mock_super_stop:
|
|
||||||
await agent.stop()
|
|
||||||
|
|
||||||
# Verify socket closed
|
|
||||||
agent.video_in_socket.close.assert_called_once()
|
|
||||||
# Verify parent stop called
|
|
||||||
mock_super_stop.assert_called_once()
|
|
||||||
@@ -303,33 +303,6 @@ async def test_send_experiment_control(agent):
|
|||||||
assert msg.thread == "reset_experiment"
|
assert msg.thread == "reset_experiment"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_send_pause_command(agent):
|
|
||||||
# --- Test PAUSE ---
|
|
||||||
await agent._send_pause_command("true")
|
|
||||||
|
|
||||||
# Should send exactly 1 message
|
|
||||||
assert agent.send.await_count == 1
|
|
||||||
|
|
||||||
# Extract the message object from the mock call
|
|
||||||
# call_args[0] are positional args, and [0] is the first arg (the message)
|
|
||||||
msg = agent.send.call_args[0][0]
|
|
||||||
|
|
||||||
# Verify Body
|
|
||||||
assert msg.body == "PAUSE"
|
|
||||||
|
|
||||||
# --- Test RESUME ---
|
|
||||||
agent.send.reset_mock()
|
|
||||||
await agent._send_pause_command("false")
|
|
||||||
|
|
||||||
# Should send exactly 1 message
|
|
||||||
assert agent.send.await_count == 1
|
|
||||||
|
|
||||||
msg = agent.send.call_args[0][0]
|
|
||||||
|
|
||||||
# Verify Body
|
|
||||||
assert msg.body == "RESUME"
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_setup(agent):
|
async def test_setup(agent):
|
||||||
"""Test the setup method initializes sockets correctly."""
|
"""Test the setup method initializes sockets correctly."""
|
||||||
|
|||||||
17
uv.lock
generated
17
uv.lock
generated
@@ -1524,7 +1524,6 @@ dependencies = [
|
|||||||
{ name = "sphinx-rtd-theme" },
|
{ name = "sphinx-rtd-theme" },
|
||||||
{ name = "tf-keras" },
|
{ name = "tf-keras" },
|
||||||
{ name = "torch" },
|
{ name = "torch" },
|
||||||
{ name = "tornado", marker = "sys_platform == 'win32'" },
|
|
||||||
{ name = "uvicorn" },
|
{ name = "uvicorn" },
|
||||||
]
|
]
|
||||||
|
|
||||||
@@ -1541,7 +1540,6 @@ dev = [
|
|||||||
]
|
]
|
||||||
test = [
|
test = [
|
||||||
{ name = "agentspeak" },
|
{ name = "agentspeak" },
|
||||||
{ name = "deepface" },
|
|
||||||
{ name = "fastapi" },
|
{ name = "fastapi" },
|
||||||
{ name = "httpx" },
|
{ name = "httpx" },
|
||||||
{ name = "mlx-whisper", marker = "sys_platform == 'darwin'" },
|
{ name = "mlx-whisper", marker = "sys_platform == 'darwin'" },
|
||||||
@@ -1556,7 +1554,6 @@ test = [
|
|||||||
{ name = "pyyaml" },
|
{ name = "pyyaml" },
|
||||||
{ name = "pyzmq" },
|
{ name = "pyzmq" },
|
||||||
{ name = "soundfile" },
|
{ name = "soundfile" },
|
||||||
{ name = "tf-keras" },
|
|
||||||
]
|
]
|
||||||
|
|
||||||
[package.metadata]
|
[package.metadata]
|
||||||
@@ -1580,7 +1577,6 @@ requires-dist = [
|
|||||||
{ name = "sphinx-rtd-theme", specifier = ">=3.0.2" },
|
{ name = "sphinx-rtd-theme", specifier = ">=3.0.2" },
|
||||||
{ name = "tf-keras", specifier = ">=2.20.1" },
|
{ name = "tf-keras", specifier = ">=2.20.1" },
|
||||||
{ name = "torch", specifier = ">=2.8.0" },
|
{ name = "torch", specifier = ">=2.8.0" },
|
||||||
{ name = "tornado", marker = "sys_platform == 'win32'" },
|
|
||||||
{ name = "uvicorn", specifier = ">=0.37.0" },
|
{ name = "uvicorn", specifier = ">=0.37.0" },
|
||||||
]
|
]
|
||||||
|
|
||||||
@@ -1597,7 +1593,6 @@ dev = [
|
|||||||
]
|
]
|
||||||
test = [
|
test = [
|
||||||
{ name = "agentspeak", specifier = ">=0.2.2" },
|
{ name = "agentspeak", specifier = ">=0.2.2" },
|
||||||
{ name = "deepface", specifier = ">=0.0.97" },
|
|
||||||
{ name = "fastapi", specifier = ">=0.115.6" },
|
{ name = "fastapi", specifier = ">=0.115.6" },
|
||||||
{ name = "httpx", specifier = ">=0.28.1" },
|
{ name = "httpx", specifier = ">=0.28.1" },
|
||||||
{ name = "mlx-whisper", marker = "sys_platform == 'darwin'", specifier = ">=0.4.3" },
|
{ name = "mlx-whisper", marker = "sys_platform == 'darwin'", specifier = ">=0.4.3" },
|
||||||
@@ -1612,7 +1607,6 @@ test = [
|
|||||||
{ name = "pyyaml", specifier = ">=6.0.3" },
|
{ name = "pyyaml", specifier = ">=6.0.3" },
|
||||||
{ name = "pyzmq", specifier = ">=27.1.0" },
|
{ name = "pyzmq", specifier = ">=27.1.0" },
|
||||||
{ name = "soundfile", specifier = ">=0.13.1" },
|
{ name = "soundfile", specifier = ">=0.13.1" },
|
||||||
{ name = "tf-keras", specifier = ">=2.20.1" },
|
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
@@ -2726,17 +2720,6 @@ wheels = [
|
|||||||
{ url = "https://files.pythonhosted.org/packages/52/27/7fc2d7435af044ffbe0b9b8e98d99eac096d43f128a5cde23c04825d5dcf/torchaudio-2.8.0-cp313-cp313t-win_amd64.whl", hash = "sha256:d4a715d09ac28c920d031ee1e60ecbc91e8a5079ad8c61c0277e658436c821a6", size = 2549553, upload-time = "2025-08-06T14:59:00.019Z" },
|
{ url = "https://files.pythonhosted.org/packages/52/27/7fc2d7435af044ffbe0b9b8e98d99eac096d43f128a5cde23c04825d5dcf/torchaudio-2.8.0-cp313-cp313t-win_amd64.whl", hash = "sha256:d4a715d09ac28c920d031ee1e60ecbc91e8a5079ad8c61c0277e658436c821a6", size = 2549553, upload-time = "2025-08-06T14:59:00.019Z" },
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "tornado"
|
|
||||||
version = "6.5.4"
|
|
||||||
source = { registry = "https://pypi.org/simple" }
|
|
||||||
sdist = { url = "https://files.pythonhosted.org/packages/37/1d/0a336abf618272d53f62ebe274f712e213f5a03c0b2339575430b8362ef2/tornado-6.5.4.tar.gz", hash = "sha256:a22fa9047405d03260b483980635f0b041989d8bcc9a313f8fe18b411d84b1d7", size = 513632, upload-time = "2025-12-15T19:21:03.836Z" }
|
|
||||||
wheels = [
|
|
||||||
{ url = "https://files.pythonhosted.org/packages/0c/1a/d7592328d037d36f2d2462f4bc1fbb383eec9278bc786c1b111cbbd44cfa/tornado-6.5.4-cp39-abi3-win32.whl", hash = "sha256:1768110f2411d5cd281bac0a090f707223ce77fd110424361092859e089b38d1", size = 446481, upload-time = "2025-12-15T19:21:00.008Z" },
|
|
||||||
{ url = "https://files.pythonhosted.org/packages/d6/6d/c69be695a0a64fd37a97db12355a035a6d90f79067a3cf936ec2b1dc38cd/tornado-6.5.4-cp39-abi3-win_amd64.whl", hash = "sha256:fa07d31e0cd85c60713f2b995da613588aa03e1303d75705dca6af8babc18ddc", size = 446886, upload-time = "2025-12-15T19:21:01.287Z" },
|
|
||||||
{ url = "https://files.pythonhosted.org/packages/50/49/8dc3fd90902f70084bd2cd059d576ddb4f8bb44c2c7c0e33a11422acb17e/tornado-6.5.4-cp39-abi3-win_arm64.whl", hash = "sha256:053e6e16701eb6cbe641f308f4c1a9541f91b6261991160391bfc342e8a551a1", size = 445910, upload-time = "2025-12-15T19:21:02.571Z" },
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "tqdm"
|
name = "tqdm"
|
||||||
version = "4.67.1"
|
version = "4.67.1"
|
||||||
|
|||||||
Reference in New Issue
Block a user