Files
pepperplus-cb/src/control_backend/agents/communication/ri_communication_agent.py

303 lines
12 KiB
Python

import asyncio
import json
import zmq
import zmq.asyncio as azmq
from zmq.asyncio import Context
from control_backend.agents import BaseAgent
from control_backend.agents.actuation.robot_gesture_agent import RobotGestureAgent
from control_backend.core.config import settings
from ..actuation.robot_speech_agent import RobotSpeechAgent
from ..perception import VADAgent
class RICommunicationAgent(BaseAgent):
"""
Robot Interface (RI) Communication Agent.
This agent manages the high-level connection negotiation and health checking (heartbeat)
between the Control Backend and the Robot Interface (or UI).
It acts as a service discovery mechanism:
1. It initiates a handshake (negotiation) to discover where other services (like the robot
command listener) are listening.
2. It spawns specific agents
(like :class:`~control_backend.agents.actuation.robot_speech_agent.RobotSpeechAgent`)
once the connection details are established.
3. It maintains a "ping" loop to ensure the connection remains active.
:ivar _address: The ZMQ address to attempt the initial connection negotiation.
:ivar _bind: Whether to bind or connect the negotiation socket.
:ivar _req_socket: ZMQ REQ socket for negotiation and pings.
:ivar pub_socket: ZMQ PUB socket for internal notifications (e.g., ping status).
:ivar connected: Boolean flag indicating active connection status.
"""
def __init__(
self,
name: str,
address=settings.zmq_settings.ri_communication_address,
bind=False,
):
super().__init__(name)
self._address = address
self._bind = bind
self._req_socket: azmq.Socket | None = None
self.pub_socket: azmq.Socket | None = None
self.connected = False
async def setup(self):
"""
Initialize the agent and attempt connection.
Tries to negotiate connection up to ``behaviour_settings.comm_setup_max_retries`` times.
If successful, starts the :meth:`_listen_loop`.
"""
self.logger.info("Setting up %s", self.name)
# Bind request socket
await self._setup_sockets()
if await self._negotiate_connection():
self.connected = True
self.add_behavior(self._listen_loop())
else:
self.logger.warning("Failed to negotiate connection during setup.")
self.logger.info("Finished setting up %s", self.name)
async def _setup_sockets(self, force=False):
"""
Initialize ZMQ sockets (REQ for negotiation, PUB for internal updates).
"""
# Bind request socket
if self._req_socket is None or force:
self._req_socket = Context.instance().socket(zmq.REQ)
if self._bind:
self._req_socket.bind(self._address)
else:
self._req_socket.connect(self._address)
if self.pub_socket is None or force:
self.pub_socket = Context.instance().socket(zmq.PUB)
self.pub_socket.connect(settings.zmq_settings.internal_pub_address)
async def _negotiate_connection(
self, max_retries: int = settings.behaviour_settings.comm_setup_max_retries
):
"""
Perform the handshake protocol with the Robot Interface.
Sends a ``negotiate/ports`` request and expects a configuration response containing
port assignments for various services (e.g., actuation).
:param max_retries: Number of attempts before giving up.
:return: True if negotiation succeeded, False otherwise.
"""
retries = 0
while retries < max_retries:
if self._req_socket is None:
retries += 1
continue
# Send our message and receive one back
message = {"endpoint": "negotiate/ports", "data": {}}
await self._req_socket.send_json(message)
retry_frequency = 1.0
try:
received_message = await asyncio.wait_for(
self._req_socket.recv_json(), timeout=retry_frequency
)
except TimeoutError:
self.logger.warning(
"No connection established in %d seconds (attempt %d/%d)",
retries * retry_frequency,
retries + 1,
max_retries,
)
retries += 1
continue
except Exception as e:
self.logger.warning("Unexpected error during negotiation: %s", e)
retries += 1
continue
# Validate endpoint
endpoint = received_message.get("endpoint")
if endpoint != "negotiate/ports":
self.logger.warning(
"Invalid endpoint '%s' received (attempt %d/%d)",
endpoint,
retries + 1,
max_retries,
)
retries += 1
await asyncio.sleep(1)
continue
# At this point, we have a valid response
try:
await self._handle_negotiation_response(received_message)
# Let UI know that we're connected
topic = b"ping"
data = json.dumps(True).encode()
if self.pub_socket:
await self.pub_socket.send_multipart([topic, data])
return True
except Exception as e:
self.logger.warning("Error unpacking negotiation data: %s", e)
retries += 1
await asyncio.sleep(settings.behaviour_settings.sleep_s)
continue
return False
async def _handle_negotiation_response(self, received_message):
"""
Parse the negotiation response and initialize services.
Based on the response, it might re-connect the main socket or spawn new agents
(e.g., for robot actuation).
"""
for port_data in received_message["data"]:
id = port_data["id"]
port = port_data["port"]
bind = port_data["bind"]
if not bind:
addr = f"tcp://{settings.ri_host}:{port}"
else:
addr = f"tcp://*:{port}"
match id:
case "main":
if addr != self._address:
assert self._req_socket is not None
if not bind:
self._req_socket.connect(addr)
else:
self._req_socket.bind(addr)
case "actuation":
gesture_tags = port_data.get("gestures", [])
gesture_single = port_data.get("single_gestures", [])
gesture_basic = port_data.get("basic_gestures", [])
robot_speech_agent = RobotSpeechAgent(
settings.agent_settings.robot_speech_name,
address=addr,
bind=bind,
)
robot_gesture_agent = RobotGestureAgent(
settings.agent_settings.robot_gesture_name,
address=addr,
bind=bind,
gesture_tags=gesture_tags,
gesture_basic=gesture_basic,
gesture_single=gesture_single,
)
await robot_speech_agent.start()
await asyncio.sleep(0.1) # Small delay
await robot_gesture_agent.start()
case "audio":
vad_agent = VADAgent(audio_in_address=addr, audio_in_bind=bind)
await vad_agent.start()
case _:
self.logger.warning("Unhandled negotiation id: %s", id)
async def stop(self):
"""
Closes all sockets.
:return:
"""
if self._req_socket:
self._req_socket.close()
if self.pub_socket:
self.pub_socket.close()
await super().stop()
async def _listen_loop(self):
"""
Maintain the connection via a heartbeat (ping) loop.
Sends a ``ping`` request periodically and waits for a reply.
If pings fail repeatedly, it triggers a disconnection handler to restart negotiation.
"""
while self._running:
if not self.connected:
await asyncio.sleep(settings.behaviour_settings.sleep_s)
continue
# We need to listen and send pings.
message = {"endpoint": "ping", "data": {"id": "e.g. some reference id"}}
seconds_to_wait_total = settings.behaviour_settings.sleep_s
try:
assert self._req_socket is not None
await asyncio.wait_for(
self._req_socket.send_json(message), timeout=seconds_to_wait_total / 2
)
except TimeoutError:
self.logger.debug(
"Waited too long to send message - "
"we probably dont have any receivers... but let's check!"
)
# Wait up to {seconds_to_wait_total/2} seconds for a reply
try:
assert self._req_socket is not None
message = await asyncio.wait_for(
self._req_socket.recv_json(), timeout=seconds_to_wait_total / 2
)
self.logger.debug(f'Received message "{message}" from RI.')
if "endpoint" not in message:
self.logger.warning("No received endpoint in message, expected ping endpoint.")
continue
# See what endpoint we received
match message["endpoint"]:
case "ping":
topic = b"ping"
data = json.dumps(True).encode()
if self.pub_socket is not None:
await self.pub_socket.send_multipart([topic, data])
await asyncio.sleep(settings.behaviour_settings.sleep_s)
case _:
self.logger.debug(
"Received message with topic different than ping, while ping expected."
)
# We didnt get a reply
except TimeoutError:
self.logger.info(
f"No ping retrieved in {seconds_to_wait_total} seconds, "
"sending UI disconnection event and attempting to restart."
)
await self._handle_disconnection()
continue
except Exception:
self.logger.error("Error while waiting for ping message.", exc_info=True)
raise
async def _handle_disconnection(self):
"""
Handle connection loss.
Notifies the UI of disconnection (via internal PUB) and attempts to restart negotiation.
"""
self.connected = False
# Tell UI we're disconnected.
topic = b"ping"
data = json.dumps(False).encode()
if self.pub_socket:
try:
await asyncio.wait_for(self.pub_socket.send_multipart([topic, data]), 5)
except TimeoutError:
self.logger.warning("Connection ping for router timed out.")
# Try to reboot/renegotiate
self.logger.debug("Restarting communication negotiation.")
if await self._negotiate_connection(max_retries=1):
self.connected = True