feat: implement negotiation
By implementing SocketBase and adding the socket to the state, the negotiation will automatically give the right endpoints. ref: N25B-168
This commit is contained in:
@@ -1,6 +1,7 @@
|
|||||||
import zmq
|
import zmq
|
||||||
|
|
||||||
from robot_interface.endpoints.receiver_base import ReceiverBase
|
from robot_interface.endpoints.receiver_base import ReceiverBase
|
||||||
|
from robot_interface.state import state
|
||||||
|
|
||||||
|
|
||||||
class MainReceiver(ReceiverBase):
|
class MainReceiver(ReceiverBase):
|
||||||
@@ -14,14 +15,20 @@ class MainReceiver(ReceiverBase):
|
|||||||
:param port: The port to use.
|
:param port: The port to use.
|
||||||
:type port: int
|
:type port: int
|
||||||
"""
|
"""
|
||||||
super(MainReceiver, self).__init__("main")
|
super(MainReceiver, self).__init__("main", "json")
|
||||||
self.create_socket(zmq_context, zmq.REP, port)
|
self.create_socket(zmq_context, zmq.REP, port, bind=False)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _handle_ping(message):
|
def _handle_ping(message):
|
||||||
"""A simple ping endpoint. Returns the provided data."""
|
"""A simple ping endpoint. Returns the provided data."""
|
||||||
return {"endpoint": "ping", "data": message.get("data")}
|
return {"endpoint": "ping", "data": message.get("data")}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _handle_port_negotiation(message):
|
||||||
|
endpoints = [socket.endpoint_description() for socket in state.sockets]
|
||||||
|
|
||||||
|
return {"endpoint": "negotiation/ports", "data": endpoints}
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _handle_negotiation(message):
|
def _handle_negotiation(message):
|
||||||
"""
|
"""
|
||||||
@@ -33,7 +40,11 @@ class MainReceiver(ReceiverBase):
|
|||||||
:return: A response dictionary with a 'ports' key containing a list of ports and their function.
|
:return: A response dictionary with a 'ports' key containing a list of ports and their function.
|
||||||
:rtype: dict[str, list[dict]]
|
:rtype: dict[str, list[dict]]
|
||||||
"""
|
"""
|
||||||
# TODO: .../error on all endpoints?
|
# In the future, the sender could send information like the robot's IP address, etc.
|
||||||
|
|
||||||
|
if message["endpoint"] == "negotiation/ports":
|
||||||
|
return MainReceiver._handle_port_negotiation(message)
|
||||||
|
|
||||||
return {"endpoint": "negotiation/error", "data": "The requested endpoint is not implemented."}
|
return {"endpoint": "negotiation/error", "data": "The requested endpoint is not implemented."}
|
||||||
|
|
||||||
def handle_message(self, message):
|
def handle_message(self, message):
|
||||||
@@ -42,7 +53,7 @@ class MainReceiver(ReceiverBase):
|
|||||||
|
|
||||||
if message["endpoint"] == "ping":
|
if message["endpoint"] == "ping":
|
||||||
return self._handle_ping(message)
|
return self._handle_ping(message)
|
||||||
elif message["endpoint"] == "negotiation":
|
elif message["endpoint"].startswith("negotiation"):
|
||||||
return self._handle_negotiation(message)
|
return self._handle_negotiation(message)
|
||||||
|
|
||||||
return {"endpoint": "error", "data": "The requested endpoint is not supported."}
|
return {"endpoint": "error", "data": "The requested endpoint is not supported."}
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ class ReceiverBase(SocketBase, object):
|
|||||||
:param message: The message to handle.
|
:param message: The message to handle.
|
||||||
:type message: dict
|
:type message: dict
|
||||||
|
|
||||||
:return: A response message.
|
:return: A response message or None if this type of receiver doesn't publish.
|
||||||
:rtype: dict
|
:rtype: dict | None
|
||||||
"""
|
"""
|
||||||
return {"endpoint": "error", "data": "The requested receiver is not implemented."}
|
raise NotImplementedError()
|
||||||
|
|||||||
@@ -1,5 +1,9 @@
|
|||||||
from abc import ABCMeta
|
from abc import ABCMeta
|
||||||
|
|
||||||
|
import zmq
|
||||||
|
|
||||||
|
from robot_interface.utils import zmq_socket_type_int_to_str, zmq_socket_type_complement
|
||||||
|
|
||||||
|
|
||||||
class SocketBase(object):
|
class SocketBase(object):
|
||||||
__metaclass__ = ABCMeta
|
__metaclass__ = ABCMeta
|
||||||
@@ -7,15 +11,21 @@ class SocketBase(object):
|
|||||||
name = None
|
name = None
|
||||||
socket = None
|
socket = None
|
||||||
|
|
||||||
def __init__(self, name):
|
def __init__(self, name, data_type):
|
||||||
"""
|
"""
|
||||||
:param name: The name of the endpoint.
|
:param name: The name of the endpoint.
|
||||||
:type name: str
|
:type name: str
|
||||||
|
|
||||||
|
:param data_type: The data type of the endpoint, e.g. "json", "binary", "text", etc.
|
||||||
|
:type data_type: str
|
||||||
"""
|
"""
|
||||||
self.name = name
|
self.name = name
|
||||||
self.socket = None
|
self.data_type = data_type
|
||||||
|
self.port = None # Set later by `create_socket`
|
||||||
|
self.socket = None # Set later by `create_socket`
|
||||||
|
self.bound = None # Set later by `create_socket`
|
||||||
|
|
||||||
def create_socket(self, zmq_context, socket_type, port):
|
def create_socket(self, zmq_context, socket_type, port, bind=True):
|
||||||
"""
|
"""
|
||||||
Create a ZeroMQ socket.
|
Create a ZeroMQ socket.
|
||||||
|
|
||||||
@@ -27,8 +37,16 @@ class SocketBase(object):
|
|||||||
|
|
||||||
:param port: The port to use.
|
:param port: The port to use.
|
||||||
:type port: int
|
:type port: int
|
||||||
|
|
||||||
|
:param bind: Whether to bind the socket or connect to it.
|
||||||
|
:type bind: bool
|
||||||
"""
|
"""
|
||||||
|
self.port = port
|
||||||
self.socket = zmq_context.socket(socket_type)
|
self.socket = zmq_context.socket(socket_type)
|
||||||
|
self.bound = bind
|
||||||
|
if bind:
|
||||||
|
self.socket.bind("tcp://*:{}".format(port))
|
||||||
|
else:
|
||||||
self.socket.connect("tcp://localhost:{}".format(port))
|
self.socket.connect("tcp://localhost:{}".format(port))
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
@@ -36,3 +54,18 @@ class SocketBase(object):
|
|||||||
if not self.socket: return
|
if not self.socket: return
|
||||||
self.socket.close()
|
self.socket.close()
|
||||||
self.socket = None
|
self.socket = None
|
||||||
|
|
||||||
|
def endpoint_description(self):
|
||||||
|
"""
|
||||||
|
Description of the endpoint. Used for negotiation.
|
||||||
|
|
||||||
|
:return: A dictionary with the following keys: name, port, pattern, data_type. See https://utrechtuniversity.youtrack.cloud/articles/N25B-A-14/RI-CB-Communication#negotiation
|
||||||
|
:rtype: dict
|
||||||
|
"""
|
||||||
|
return {
|
||||||
|
"name": self.name,
|
||||||
|
"port": self.port,
|
||||||
|
"pattern": zmq_socket_type_int_to_str[zmq_socket_type_complement[self.socket.getsockopt(zmq.TYPE)]],
|
||||||
|
"data_type": self.data_type,
|
||||||
|
"bind": not self.bound
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
import logging
|
import logging
|
||||||
|
logging.basicConfig(level=logging.DEBUG)
|
||||||
import time
|
import time
|
||||||
|
|
||||||
import zmq
|
import zmq
|
||||||
@@ -25,6 +26,8 @@ def main_loop(context):
|
|||||||
for receiver in receivers:
|
for receiver in receivers:
|
||||||
poller.register(receiver.socket, zmq.POLLIN)
|
poller.register(receiver.socket, zmq.POLLIN)
|
||||||
|
|
||||||
|
logging.debug("Starting main loop.")
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
if state.exit_event.is_set(): break
|
if state.exit_event.is_set(): break
|
||||||
socks = dict(poller.poll(100))
|
socks = dict(poller.poll(100))
|
||||||
@@ -36,6 +39,7 @@ def main_loop(context):
|
|||||||
|
|
||||||
message = receiver.socket.recv_json()
|
message = receiver.socket.recv_json()
|
||||||
response = receiver.handle_message(message)
|
response = receiver.handle_message(message)
|
||||||
|
if receiver.socket.getsockopt(zmq.TYPE) == zmq.REP:
|
||||||
receiver.socket.send_json(response)
|
receiver.socket.send_json(response)
|
||||||
|
|
||||||
time_spent_ms = (time.time() - start_time) * 1000
|
time_spent_ms = (time.time() - start_time) * 1000
|
||||||
@@ -51,7 +55,7 @@ def main():
|
|||||||
try:
|
try:
|
||||||
main_loop(context)
|
main_loop(context)
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
print("User interrupted.")
|
logging.info("User interrupted.")
|
||||||
finally:
|
finally:
|
||||||
state.deinitialize()
|
state.deinitialize()
|
||||||
context.term()
|
context.term()
|
||||||
|
|||||||
25
src/robot_interface/utils.py
Normal file
25
src/robot_interface/utils.py
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
zmq_socket_type_complement = {
|
||||||
|
0: 0, # PAIR - PAIR
|
||||||
|
1: 2, # PUB - SUB
|
||||||
|
2: 1, # SUB - PUB
|
||||||
|
3: 4, # REQ - REP
|
||||||
|
4: 3, # REP - REQ
|
||||||
|
5: 6, # DEALER - ROUTER
|
||||||
|
6: 5, # ROUTER - DEALER
|
||||||
|
7: 8, # PULL - PUSH
|
||||||
|
8: 7, # PUSH - PULL
|
||||||
|
}
|
||||||
|
|
||||||
|
zmq_socket_type_int_to_str = {
|
||||||
|
0: "PAIR",
|
||||||
|
1: "PUB",
|
||||||
|
2: "SUB",
|
||||||
|
3: "REQ",
|
||||||
|
4: "REP",
|
||||||
|
5: "DEALER",
|
||||||
|
6: "ROUTER",
|
||||||
|
7: "PULL",
|
||||||
|
8: "PUSH",
|
||||||
|
}
|
||||||
|
|
||||||
|
zmq_socket_type_str_to_int = {value: key for key, value in zmq_socket_type_int_to_str.items()}
|
||||||
Reference in New Issue
Block a user