diff --git a/src/control_backend/agents/communication/ri_communication_agent.py b/src/control_backend/agents/communication/ri_communication_agent.py index 0a6a12f..94761c7 100644 --- a/src/control_backend/agents/communication/ri_communication_agent.py +++ b/src/control_backend/agents/communication/ri_communication_agent.py @@ -219,11 +219,10 @@ class RICommunicationAgent(BaseAgent): self.visual_emotion_recognition_agent = visual_emotion_agent await visual_emotion_agent.start() case "face": - self.logger.warn("yup we here") face_agent = FacePerceptionAgent( settings.agent_settings.face_agent_name, - address=addr, - bind=bind, + zmq_address=addr, + zmq_bind=bind, ) await face_agent.start() case _: diff --git a/src/control_backend/agents/perception/face_rec_agent.py b/src/control_backend/agents/perception/face_rec_agent.py index 9fd0098..eda723c 100644 --- a/src/control_backend/agents/perception/face_rec_agent.py +++ b/src/control_backend/agents/perception/face_rec_agent.py @@ -1,9 +1,7 @@ import asyncio -from random import random import zmq import zmq.asyncio as azmq -from zmq.asyncio import Context from control_backend.agents import BaseAgent from control_backend.core.agent_system import InternalMessage @@ -17,37 +15,52 @@ class FacePerceptionAgent(BaseAgent): via the internal PUB/SUB bus. """ - def __init__(self, name: str): + def __init__(self, name: str, zmq_address: str, zmq_bind: bool): + """ + :param name: The name of the agent. + :param zmq_address: The ZMQ address to subscribe to, an endpoint which sends face presence + updates. + :param zmq_bind: Whether to connect to the ZMQ endpoint, or to bind. + """ super().__init__(name) + self._zmq_address = zmq_address + self._zmq_bind = zmq_bind + self._socket: azmq.Socket | None = None + self._last_face_state: bool | None = None - self._req_socket: azmq.Socket | None = None async def setup(self): self.logger.info("Starting FacePerceptionAgent") - if self._req_socket is None: - self._req_socket = Context.instance().socket(zmq.REQ) - self._req_socket.connect("tcp://localhost:5559") + if self._socket is None: + self._connect_socket() self.add_behavior(self._poll_loop()) - async def _poll_loop(self): - poll_interval = 1.0 + def _connect_socket(self): + if self._socket is not None: + self.logger.warning("ZMQ socket already initialized. Did you call setup() twice?") + return - if self._req_socket is None: - self.logger.warn("REQ socket not initialized before poll loop") + self._socket = azmq.Context.instance().socket(zmq.SUB) + self._socket.setsockopt_string(zmq.SUBSCRIBE, "") + if self._zmq_bind: + self._socket.bind(self._zmq_address) + else: + self._socket.connect(self._zmq_address) + + async def _poll_loop(self): + if self._socket is None: + self.logger.warning("Connection not initialized before poll loop. Call setup() first.") + return while self._running: - if self._req_socket is None: - self.logger.debug("REQ socket not initialized in poll loop") try: - await self._req_socket.send_json({"endpoint": "face", "data": {}}) - response = await asyncio.wait_for( - self._req_socket.recv_json(), timeout=poll_interval + self._socket.recv_json(), timeout=settings.behaviour_settings.sleep_s ) - face_present = bool(response.get("data", False)) + face_present = response.get("face_detected", False) if self._last_face_state is None: self._last_face_state = face_present @@ -57,14 +70,10 @@ class FacePerceptionAgent(BaseAgent): self._last_face_state = face_present self.logger.debug("Face detected" if face_present else "Face lost") await self._update_face_belief(face_present) - + except TimeoutError: + pass except Exception as e: - self.logger.warning("Face polling failed") - self.logger.warn(e) - i = random() - await self._update_face_belief(i > 0.5) - - await asyncio.sleep(poll_interval) + self.logger.error("Face polling failed", exc_info=e) async def _post_face_belief(self, present: bool): """ diff --git a/src/control_backend/main.py b/src/control_backend/main.py index 49b725e..a0136bd 100644 --- a/src/control_backend/main.py +++ b/src/control_backend/main.py @@ -37,7 +37,6 @@ from control_backend.agents.communication import RICommunicationAgent # Emotional Agents # LLM Agents from control_backend.agents.llm import LLMAgent -from control_backend.agents.perception.face_rec_agent import FacePerceptionAgent # User Interrupt Agent from control_backend.agents.user_interrupt.user_interrupt_agent import UserInterruptAgent @@ -140,12 +139,6 @@ async def lifespan(app: FastAPI): "name": settings.agent_settings.user_interrupt_name, }, ), - "FaceDetectionAgent": ( - FacePerceptionAgent, - { - "name": settings.agent_settings.face_agent_name, - }, - ), } agents = []