303 lines
12 KiB
Python
303 lines
12 KiB
Python
import asyncio
|
|
import json
|
|
|
|
import zmq
|
|
import zmq.asyncio as azmq
|
|
from zmq.asyncio import Context
|
|
|
|
from control_backend.agents import BaseAgent
|
|
from control_backend.agents.actuation.robot_gesture_agent import RobotGestureAgent
|
|
from control_backend.core.config import settings
|
|
|
|
from ..actuation.robot_speech_agent import RobotSpeechAgent
|
|
from ..perception import VADAgent
|
|
|
|
|
|
class RICommunicationAgent(BaseAgent):
|
|
"""
|
|
Robot Interface (RI) Communication Agent.
|
|
|
|
This agent manages the high-level connection negotiation and health checking (heartbeat)
|
|
between the Control Backend and the Robot Interface (or UI).
|
|
|
|
It acts as a service discovery mechanism:
|
|
1. It initiates a handshake (negotiation) to discover where other services (like the robot
|
|
command listener) are listening.
|
|
2. It spawns specific agents
|
|
(like :class:`~control_backend.agents.actuation.robot_speech_agent.RobotSpeechAgent`)
|
|
once the connection details are established.
|
|
3. It maintains a "ping" loop to ensure the connection remains active.
|
|
|
|
:ivar _address: The ZMQ address to attempt the initial connection negotiation.
|
|
:ivar _bind: Whether to bind or connect the negotiation socket.
|
|
:ivar _req_socket: ZMQ REQ socket for negotiation and pings.
|
|
:ivar pub_socket: ZMQ PUB socket for internal notifications (e.g., ping status).
|
|
:ivar connected: Boolean flag indicating active connection status.
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
name: str,
|
|
address=settings.zmq_settings.ri_communication_address,
|
|
bind=False,
|
|
):
|
|
super().__init__(name)
|
|
self._address = address
|
|
self._bind = bind
|
|
self._req_socket: azmq.Socket | None = None
|
|
self.pub_socket: azmq.Socket | None = None
|
|
self.connected = False
|
|
|
|
async def setup(self):
|
|
"""
|
|
Initialize the agent and attempt connection.
|
|
|
|
Tries to negotiate connection up to ``behaviour_settings.comm_setup_max_retries`` times.
|
|
If successful, starts the :meth:`_listen_loop`.
|
|
"""
|
|
self.logger.info("Setting up %s", self.name)
|
|
|
|
# Bind request socket
|
|
await self._setup_sockets()
|
|
|
|
if await self._negotiate_connection():
|
|
self.connected = True
|
|
self.add_behavior(self._listen_loop())
|
|
else:
|
|
self.logger.warning("Failed to negotiate connection during setup.")
|
|
|
|
self.logger.info("Finished setting up %s", self.name)
|
|
|
|
async def _setup_sockets(self, force=False):
|
|
"""
|
|
Initialize ZMQ sockets (REQ for negotiation, PUB for internal updates).
|
|
"""
|
|
# Bind request socket
|
|
if self._req_socket is None or force:
|
|
self._req_socket = Context.instance().socket(zmq.REQ)
|
|
if self._bind:
|
|
self._req_socket.bind(self._address)
|
|
else:
|
|
self._req_socket.connect(self._address)
|
|
|
|
if self.pub_socket is None or force:
|
|
self.pub_socket = Context.instance().socket(zmq.PUB)
|
|
self.pub_socket.connect(settings.zmq_settings.internal_pub_address)
|
|
|
|
async def _negotiate_connection(
|
|
self, max_retries: int = settings.behaviour_settings.comm_setup_max_retries
|
|
):
|
|
"""
|
|
Perform the handshake protocol with the Robot Interface.
|
|
|
|
Sends a ``negotiate/ports`` request and expects a configuration response containing
|
|
port assignments for various services (e.g., actuation).
|
|
|
|
:param max_retries: Number of attempts before giving up.
|
|
:return: True if negotiation succeeded, False otherwise.
|
|
"""
|
|
retries = 0
|
|
while retries < max_retries:
|
|
if self._req_socket is None:
|
|
retries += 1
|
|
continue
|
|
|
|
# Send our message and receive one back
|
|
message = {"endpoint": "negotiate/ports", "data": {}}
|
|
await self._req_socket.send_json(message)
|
|
|
|
retry_frequency = 1.0
|
|
try:
|
|
received_message = await asyncio.wait_for(
|
|
self._req_socket.recv_json(), timeout=retry_frequency
|
|
)
|
|
except TimeoutError:
|
|
self.logger.warning(
|
|
"No connection established in %d seconds (attempt %d/%d)",
|
|
retries * retry_frequency,
|
|
retries + 1,
|
|
max_retries,
|
|
)
|
|
retries += 1
|
|
continue
|
|
except Exception as e:
|
|
self.logger.warning("Unexpected error during negotiation: %s", e)
|
|
retries += 1
|
|
continue
|
|
|
|
# Validate endpoint
|
|
endpoint = received_message.get("endpoint")
|
|
if endpoint != "negotiate/ports":
|
|
self.logger.warning(
|
|
"Invalid endpoint '%s' received (attempt %d/%d)",
|
|
endpoint,
|
|
retries + 1,
|
|
max_retries,
|
|
)
|
|
retries += 1
|
|
await asyncio.sleep(1)
|
|
continue
|
|
|
|
# At this point, we have a valid response
|
|
try:
|
|
await self._handle_negotiation_response(received_message)
|
|
# Let UI know that we're connected
|
|
topic = b"ping"
|
|
data = json.dumps(True).encode()
|
|
if self.pub_socket:
|
|
await self.pub_socket.send_multipart([topic, data])
|
|
return True
|
|
except Exception as e:
|
|
self.logger.warning("Error unpacking negotiation data: %s", e)
|
|
retries += 1
|
|
await asyncio.sleep(settings.behaviour_settings.sleep_s)
|
|
continue
|
|
|
|
return False
|
|
|
|
async def _handle_negotiation_response(self, received_message):
|
|
"""
|
|
Parse the negotiation response and initialize services.
|
|
|
|
Based on the response, it might re-connect the main socket or spawn new agents
|
|
(e.g., for robot actuation).
|
|
"""
|
|
for port_data in received_message["data"]:
|
|
id = port_data["id"]
|
|
port = port_data["port"]
|
|
bind = port_data["bind"]
|
|
|
|
if not bind:
|
|
addr = f"tcp://{settings.ri_host}:{port}"
|
|
else:
|
|
addr = f"tcp://*:{port}"
|
|
|
|
match id:
|
|
case "main":
|
|
if addr != self._address:
|
|
assert self._req_socket is not None
|
|
if not bind:
|
|
self._req_socket.connect(addr)
|
|
else:
|
|
self._req_socket.bind(addr)
|
|
case "actuation":
|
|
gesture_tags = port_data.get("gestures", [])
|
|
gesture_single = port_data.get("single_gestures", [])
|
|
gesture_basic = port_data.get("basic_gestures", [])
|
|
robot_speech_agent = RobotSpeechAgent(
|
|
settings.agent_settings.robot_speech_name,
|
|
address=addr,
|
|
bind=bind,
|
|
)
|
|
robot_gesture_agent = RobotGestureAgent(
|
|
settings.agent_settings.robot_gesture_name,
|
|
address=addr,
|
|
bind=bind,
|
|
gesture_tags=gesture_tags,
|
|
gesture_basic=gesture_basic,
|
|
gesture_single=gesture_single,
|
|
)
|
|
await robot_speech_agent.start()
|
|
await asyncio.sleep(0.1) # Small delay
|
|
await robot_gesture_agent.start()
|
|
case "audio":
|
|
vad_agent = VADAgent(audio_in_address=addr, audio_in_bind=bind)
|
|
await vad_agent.start()
|
|
case _:
|
|
self.logger.warning("Unhandled negotiation id: %s", id)
|
|
|
|
async def stop(self):
|
|
"""
|
|
Closes all sockets.
|
|
:return:
|
|
"""
|
|
if self._req_socket:
|
|
self._req_socket.close()
|
|
if self.pub_socket:
|
|
self.pub_socket.close()
|
|
await super().stop()
|
|
|
|
async def _listen_loop(self):
|
|
"""
|
|
Maintain the connection via a heartbeat (ping) loop.
|
|
|
|
Sends a ``ping`` request periodically and waits for a reply.
|
|
If pings fail repeatedly, it triggers a disconnection handler to restart negotiation.
|
|
"""
|
|
while self._running:
|
|
if not self.connected:
|
|
await asyncio.sleep(settings.behaviour_settings.sleep_s)
|
|
continue
|
|
|
|
# We need to listen and send pings.
|
|
message = {"endpoint": "ping", "data": {"id": "e.g. some reference id"}}
|
|
seconds_to_wait_total = settings.behaviour_settings.sleep_s
|
|
try:
|
|
assert self._req_socket is not None
|
|
await asyncio.wait_for(
|
|
self._req_socket.send_json(message), timeout=seconds_to_wait_total / 2
|
|
)
|
|
except TimeoutError:
|
|
self.logger.debug(
|
|
"Waited too long to send message - "
|
|
"we probably dont have any receivers... but let's check!"
|
|
)
|
|
|
|
# Wait up to {seconds_to_wait_total/2} seconds for a reply
|
|
try:
|
|
assert self._req_socket is not None
|
|
message = await asyncio.wait_for(
|
|
self._req_socket.recv_json(), timeout=seconds_to_wait_total / 2
|
|
)
|
|
|
|
self.logger.debug(f'Received message "{message}" from RI.')
|
|
if "endpoint" not in message:
|
|
self.logger.warning("No received endpoint in message, expected ping endpoint.")
|
|
continue
|
|
|
|
# See what endpoint we received
|
|
match message["endpoint"]:
|
|
case "ping":
|
|
topic = b"ping"
|
|
data = json.dumps(True).encode()
|
|
if self.pub_socket is not None:
|
|
await self.pub_socket.send_multipart([topic, data])
|
|
await asyncio.sleep(settings.behaviour_settings.sleep_s)
|
|
case _:
|
|
self.logger.debug(
|
|
"Received message with topic different than ping, while ping expected."
|
|
)
|
|
# We didnt get a reply
|
|
except TimeoutError:
|
|
self.logger.info(
|
|
f"No ping retrieved in {seconds_to_wait_total} seconds, "
|
|
"sending UI disconnection event and attempting to restart."
|
|
)
|
|
await self._handle_disconnection()
|
|
continue
|
|
except Exception:
|
|
self.logger.error("Error while waiting for ping message.", exc_info=True)
|
|
raise
|
|
|
|
async def _handle_disconnection(self):
|
|
"""
|
|
Handle connection loss.
|
|
|
|
Notifies the UI of disconnection (via internal PUB) and attempts to restart negotiation.
|
|
"""
|
|
self.connected = False
|
|
|
|
# Tell UI we're disconnected.
|
|
topic = b"ping"
|
|
data = json.dumps(False).encode()
|
|
if self.pub_socket:
|
|
try:
|
|
await asyncio.wait_for(self.pub_socket.send_multipart([topic, data]), 5)
|
|
except TimeoutError:
|
|
self.logger.warning("Connection ping for router timed out.")
|
|
|
|
# Try to reboot/renegotiate
|
|
self.logger.debug("Restarting communication negotiation.")
|
|
if await self._negotiate_connection(max_retries=1):
|
|
self.connected = True
|