feat: subscribe instead of req/res for face detection
ref: N25B-395
This commit is contained in:
@@ -219,11 +219,10 @@ class RICommunicationAgent(BaseAgent):
|
||||
self.visual_emotion_recognition_agent = visual_emotion_agent
|
||||
await visual_emotion_agent.start()
|
||||
case "face":
|
||||
self.logger.warn("yup we here")
|
||||
face_agent = FacePerceptionAgent(
|
||||
settings.agent_settings.face_agent_name,
|
||||
address=addr,
|
||||
bind=bind,
|
||||
zmq_address=addr,
|
||||
zmq_bind=bind,
|
||||
)
|
||||
await face_agent.start()
|
||||
case _:
|
||||
|
||||
@@ -1,9 +1,7 @@
|
||||
import asyncio
|
||||
from random import random
|
||||
|
||||
import zmq
|
||||
import zmq.asyncio as azmq
|
||||
from zmq.asyncio import Context
|
||||
|
||||
from control_backend.agents import BaseAgent
|
||||
from control_backend.core.agent_system import InternalMessage
|
||||
@@ -17,37 +15,52 @@ class FacePerceptionAgent(BaseAgent):
|
||||
via the internal PUB/SUB bus.
|
||||
"""
|
||||
|
||||
def __init__(self, name: str):
|
||||
def __init__(self, name: str, zmq_address: str, zmq_bind: bool):
|
||||
"""
|
||||
:param name: The name of the agent.
|
||||
:param zmq_address: The ZMQ address to subscribe to, an endpoint which sends face presence
|
||||
updates.
|
||||
:param zmq_bind: Whether to connect to the ZMQ endpoint, or to bind.
|
||||
"""
|
||||
super().__init__(name)
|
||||
self._zmq_address = zmq_address
|
||||
self._zmq_bind = zmq_bind
|
||||
self._socket: azmq.Socket | None = None
|
||||
|
||||
self._last_face_state: bool | None = None
|
||||
self._req_socket: azmq.Socket | None = None
|
||||
|
||||
async def setup(self):
|
||||
self.logger.info("Starting FacePerceptionAgent")
|
||||
|
||||
if self._req_socket is None:
|
||||
self._req_socket = Context.instance().socket(zmq.REQ)
|
||||
self._req_socket.connect("tcp://localhost:5559")
|
||||
if self._socket is None:
|
||||
self._connect_socket()
|
||||
|
||||
self.add_behavior(self._poll_loop())
|
||||
|
||||
async def _poll_loop(self):
|
||||
poll_interval = 1.0
|
||||
def _connect_socket(self):
|
||||
if self._socket is not None:
|
||||
self.logger.warning("ZMQ socket already initialized. Did you call setup() twice?")
|
||||
return
|
||||
|
||||
if self._req_socket is None:
|
||||
self.logger.warn("REQ socket not initialized before poll loop")
|
||||
self._socket = azmq.Context.instance().socket(zmq.SUB)
|
||||
self._socket.setsockopt_string(zmq.SUBSCRIBE, "")
|
||||
if self._zmq_bind:
|
||||
self._socket.bind(self._zmq_address)
|
||||
else:
|
||||
self._socket.connect(self._zmq_address)
|
||||
|
||||
async def _poll_loop(self):
|
||||
if self._socket is None:
|
||||
self.logger.warning("Connection not initialized before poll loop. Call setup() first.")
|
||||
return
|
||||
|
||||
while self._running:
|
||||
if self._req_socket is None:
|
||||
self.logger.debug("REQ socket not initialized in poll loop")
|
||||
try:
|
||||
await self._req_socket.send_json({"endpoint": "face", "data": {}})
|
||||
|
||||
response = await asyncio.wait_for(
|
||||
self._req_socket.recv_json(), timeout=poll_interval
|
||||
self._socket.recv_json(), timeout=settings.behaviour_settings.sleep_s
|
||||
)
|
||||
|
||||
face_present = bool(response.get("data", False))
|
||||
face_present = response.get("face_detected", False)
|
||||
|
||||
if self._last_face_state is None:
|
||||
self._last_face_state = face_present
|
||||
@@ -57,14 +70,10 @@ class FacePerceptionAgent(BaseAgent):
|
||||
self._last_face_state = face_present
|
||||
self.logger.debug("Face detected" if face_present else "Face lost")
|
||||
await self._update_face_belief(face_present)
|
||||
|
||||
except TimeoutError:
|
||||
pass
|
||||
except Exception as e:
|
||||
self.logger.warning("Face polling failed")
|
||||
self.logger.warn(e)
|
||||
i = random()
|
||||
await self._update_face_belief(i > 0.5)
|
||||
|
||||
await asyncio.sleep(poll_interval)
|
||||
self.logger.error("Face polling failed", exc_info=e)
|
||||
|
||||
async def _post_face_belief(self, present: bool):
|
||||
"""
|
||||
|
||||
@@ -37,7 +37,6 @@ from control_backend.agents.communication import RICommunicationAgent
|
||||
# Emotional Agents
|
||||
# LLM Agents
|
||||
from control_backend.agents.llm import LLMAgent
|
||||
from control_backend.agents.perception.face_rec_agent import FacePerceptionAgent
|
||||
|
||||
# User Interrupt Agent
|
||||
from control_backend.agents.user_interrupt.user_interrupt_agent import UserInterruptAgent
|
||||
@@ -140,12 +139,6 @@ async def lifespan(app: FastAPI):
|
||||
"name": settings.agent_settings.user_interrupt_name,
|
||||
},
|
||||
),
|
||||
"FaceDetectionAgent": (
|
||||
FacePerceptionAgent,
|
||||
{
|
||||
"name": settings.agent_settings.face_agent_name,
|
||||
},
|
||||
),
|
||||
}
|
||||
|
||||
agents = []
|
||||
|
||||
Reference in New Issue
Block a user