feat: subscribe instead of req/res for face detection

ref: N25B-395
This commit is contained in:
Twirre Meulenbelt
2026-01-29 17:16:38 +01:00
parent 00a13426a0
commit 0f964795f3
3 changed files with 35 additions and 34 deletions

View File

@@ -219,11 +219,10 @@ class RICommunicationAgent(BaseAgent):
self.visual_emotion_recognition_agent = visual_emotion_agent self.visual_emotion_recognition_agent = visual_emotion_agent
await visual_emotion_agent.start() await visual_emotion_agent.start()
case "face": case "face":
self.logger.warn("yup we here")
face_agent = FacePerceptionAgent( face_agent = FacePerceptionAgent(
settings.agent_settings.face_agent_name, settings.agent_settings.face_agent_name,
address=addr, zmq_address=addr,
bind=bind, zmq_bind=bind,
) )
await face_agent.start() await face_agent.start()
case _: case _:

View File

@@ -1,9 +1,7 @@
import asyncio import asyncio
from random import random
import zmq import zmq
import zmq.asyncio as azmq import zmq.asyncio as azmq
from zmq.asyncio import Context
from control_backend.agents import BaseAgent from control_backend.agents import BaseAgent
from control_backend.core.agent_system import InternalMessage from control_backend.core.agent_system import InternalMessage
@@ -17,37 +15,52 @@ class FacePerceptionAgent(BaseAgent):
via the internal PUB/SUB bus. via the internal PUB/SUB bus.
""" """
def __init__(self, name: str): def __init__(self, name: str, zmq_address: str, zmq_bind: bool):
"""
:param name: The name of the agent.
:param zmq_address: The ZMQ address to subscribe to, an endpoint which sends face presence
updates.
:param zmq_bind: Whether to connect to the ZMQ endpoint, or to bind.
"""
super().__init__(name) super().__init__(name)
self._zmq_address = zmq_address
self._zmq_bind = zmq_bind
self._socket: azmq.Socket | None = None
self._last_face_state: bool | None = None self._last_face_state: bool | None = None
self._req_socket: azmq.Socket | None = None
async def setup(self): async def setup(self):
self.logger.info("Starting FacePerceptionAgent") self.logger.info("Starting FacePerceptionAgent")
if self._req_socket is None: if self._socket is None:
self._req_socket = Context.instance().socket(zmq.REQ) self._connect_socket()
self._req_socket.connect("tcp://localhost:5559")
self.add_behavior(self._poll_loop()) self.add_behavior(self._poll_loop())
async def _poll_loop(self): def _connect_socket(self):
poll_interval = 1.0 if self._socket is not None:
self.logger.warning("ZMQ socket already initialized. Did you call setup() twice?")
return
if self._req_socket is None: self._socket = azmq.Context.instance().socket(zmq.SUB)
self.logger.warn("REQ socket not initialized before poll loop") self._socket.setsockopt_string(zmq.SUBSCRIBE, "")
if self._zmq_bind:
self._socket.bind(self._zmq_address)
else:
self._socket.connect(self._zmq_address)
async def _poll_loop(self):
if self._socket is None:
self.logger.warning("Connection not initialized before poll loop. Call setup() first.")
return
while self._running: while self._running:
if self._req_socket is None:
self.logger.debug("REQ socket not initialized in poll loop")
try: try:
await self._req_socket.send_json({"endpoint": "face", "data": {}})
response = await asyncio.wait_for( response = await asyncio.wait_for(
self._req_socket.recv_json(), timeout=poll_interval self._socket.recv_json(), timeout=settings.behaviour_settings.sleep_s
) )
face_present = bool(response.get("data", False)) face_present = response.get("face_detected", False)
if self._last_face_state is None: if self._last_face_state is None:
self._last_face_state = face_present self._last_face_state = face_present
@@ -57,14 +70,10 @@ class FacePerceptionAgent(BaseAgent):
self._last_face_state = face_present self._last_face_state = face_present
self.logger.debug("Face detected" if face_present else "Face lost") self.logger.debug("Face detected" if face_present else "Face lost")
await self._update_face_belief(face_present) await self._update_face_belief(face_present)
except TimeoutError:
pass
except Exception as e: except Exception as e:
self.logger.warning("Face polling failed") self.logger.error("Face polling failed", exc_info=e)
self.logger.warn(e)
i = random()
await self._update_face_belief(i > 0.5)
await asyncio.sleep(poll_interval)
async def _post_face_belief(self, present: bool): async def _post_face_belief(self, present: bool):
""" """

View File

@@ -37,7 +37,6 @@ from control_backend.agents.communication import RICommunicationAgent
# Emotional Agents # Emotional Agents
# LLM Agents # LLM Agents
from control_backend.agents.llm import LLMAgent from control_backend.agents.llm import LLMAgent
from control_backend.agents.perception.face_rec_agent import FacePerceptionAgent
# User Interrupt Agent # User Interrupt Agent
from control_backend.agents.user_interrupt.user_interrupt_agent import UserInterruptAgent from control_backend.agents.user_interrupt.user_interrupt_agent import UserInterruptAgent
@@ -140,12 +139,6 @@ async def lifespan(app: FastAPI):
"name": settings.agent_settings.user_interrupt_name, "name": settings.agent_settings.user_interrupt_name,
}, },
), ),
"FaceDetectionAgent": (
FacePerceptionAgent,
{
"name": settings.agent_settings.face_agent_name,
},
),
} }
agents = [] agents = []