feat: create tests for all currect functionality and add get available tags router
ref: N25B-334
This commit is contained in:
@@ -6,7 +6,7 @@ import zmq.asyncio as azmq
|
||||
from control_backend.agents import BaseAgent
|
||||
from control_backend.core.agent_system import InternalMessage
|
||||
from control_backend.core.config import settings
|
||||
from control_backend.schemas.ri_message import GestureCommand
|
||||
from control_backend.schemas.ri_message import GestureCommand, RIEndpoint
|
||||
|
||||
|
||||
class RobotGestureAgent(BaseAgent):
|
||||
@@ -36,7 +36,9 @@ class RobotGestureAgent(BaseAgent):
|
||||
gesture_data=None,
|
||||
):
|
||||
if gesture_data is None:
|
||||
gesture_data = []
|
||||
self.gesture_data = []
|
||||
else:
|
||||
self.gesture_data = gesture_data
|
||||
super().__init__(name)
|
||||
self.address = address
|
||||
self.bind = bind
|
||||
@@ -65,8 +67,10 @@ class RobotGestureAgent(BaseAgent):
|
||||
self.subsocket = context.socket(zmq.SUB)
|
||||
self.subsocket.connect(settings.zmq_settings.internal_sub_address)
|
||||
self.subsocket.setsockopt(zmq.SUBSCRIBE, b"command")
|
||||
# This one
|
||||
self.subsocket.setsockopt(zmq.SUBSCRIBE, b"send_gestures")
|
||||
|
||||
self.add_behavior(self._zmq_command_loop())
|
||||
self.add_behavior(self._fetch_gestures_loop())
|
||||
|
||||
self.logger.info("Finished setting up %s", self.name)
|
||||
|
||||
@@ -87,6 +91,14 @@ class RobotGestureAgent(BaseAgent):
|
||||
"""
|
||||
try:
|
||||
gesture_command = GestureCommand.model_validate_json(msg.body)
|
||||
if gesture_command.endpoint == RIEndpoint.GESTURE_TAG:
|
||||
if gesture_command.data not in self.availableTags():
|
||||
self.logger.warning(
|
||||
"Received gesture tag '%s' which is not in available tags. Early returning",
|
||||
gesture_command.data,
|
||||
)
|
||||
return
|
||||
|
||||
await self.pubsocket.send_json(gesture_command.model_dump())
|
||||
except Exception:
|
||||
self.logger.exception("Error processing internal message.")
|
||||
@@ -99,15 +111,63 @@ class RobotGestureAgent(BaseAgent):
|
||||
"""
|
||||
while self._running:
|
||||
try:
|
||||
_, body = await self.subsocket.recv_multipart()
|
||||
topic, body = await self.subsocket.recv_multipart()
|
||||
|
||||
# Don't process send_gestures here
|
||||
if topic != b"command":
|
||||
continue
|
||||
|
||||
body = json.loads(body)
|
||||
message = GestureCommand.model_validate(body)
|
||||
|
||||
await self.pubsocket.send_json(message.model_dump())
|
||||
gesture_command = GestureCommand.model_validate(body)
|
||||
if gesture_command.endpoint == RIEndpoint.GESTURE_TAG:
|
||||
if gesture_command.data not in self.availableTags():
|
||||
self.logger.warning(
|
||||
"Received gesture tag '%s' which is not in available tags.\
|
||||
Early returning",
|
||||
gesture_command.data,
|
||||
)
|
||||
continue
|
||||
await self.pubsocket.send_json(gesture_command.model_dump())
|
||||
except Exception:
|
||||
self.logger.exception("Error processing ZMQ message.")
|
||||
|
||||
async def _fetch_gestures_loop(self):
|
||||
"""
|
||||
Loop to handle fetching gestures received via ZMQ (e.g., from the UI).
|
||||
|
||||
Listens on the 'send_gestures' topic, and returns a list on the get_gestures topic.
|
||||
"""
|
||||
while self._running:
|
||||
try:
|
||||
topic, body = await self.subsocket.recv_multipart()
|
||||
|
||||
# Don't process commands here
|
||||
if topic != b"send_gestures":
|
||||
continue
|
||||
|
||||
try:
|
||||
body = json.loads(body)
|
||||
except json.JSONDecodeError:
|
||||
body = None
|
||||
|
||||
# We could have the body be the nummer of gestures you want to fetch or something.
|
||||
amount = None
|
||||
if isinstance(body, int):
|
||||
amount = body
|
||||
|
||||
tags = self.availableTags()[:amount] if amount else self.availableTags()
|
||||
response = json.dumps({"tags": tags}).encode()
|
||||
|
||||
await self.pubsocket.send_multipart(
|
||||
[
|
||||
b"get_gestures",
|
||||
response,
|
||||
]
|
||||
)
|
||||
|
||||
except Exception:
|
||||
self.logger.exception("Error fetching gesture tags.")
|
||||
|
||||
def availableTags(self):
|
||||
"""
|
||||
Returns the available gesture tags.
|
||||
|
||||
@@ -58,6 +58,45 @@ async def ping(request: Request):
|
||||
pass
|
||||
|
||||
|
||||
@router.get("/get_available_gesture_tags")
|
||||
async def get_available_gesture_tags(request: Request):
|
||||
"""
|
||||
Endpoint to retrieve the available gesture tags for the robot.
|
||||
|
||||
:param request: The FastAPI request object.
|
||||
:return: A list of available gesture tags.
|
||||
"""
|
||||
sub_socket = Context.instance().socket(zmq.SUB)
|
||||
sub_socket.connect(settings.zmq_settings.internal_sub_address)
|
||||
sub_socket.setsockopt(zmq.SUBSCRIBE, b"get_gestures")
|
||||
|
||||
pub_socket: Socket = request.app.state.endpoints_pub_socket
|
||||
topic = b"send_gestures"
|
||||
|
||||
# TODO: Implement a way to get a certain ammount from the UI, rather than everything.
|
||||
amount = None
|
||||
timeout = 5 # seconds
|
||||
|
||||
await pub_socket.send_multipart([topic, amount.to_bytes(4, "big") if amount else b""])
|
||||
try:
|
||||
_, body = await asyncio.wait_for(sub_socket.recv_multipart(), timeout=timeout)
|
||||
except TimeoutError:
|
||||
body = b"tags: []"
|
||||
logger.debug("got timeout error fetching gestures")
|
||||
|
||||
# Handle empty response and JSON decode errors
|
||||
available_tags = []
|
||||
if body:
|
||||
try:
|
||||
available_tags = json.loads(body).get("tags", [])
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"Failed to parse gesture tags JSON: {e}, body: {body}")
|
||||
# Return empty list on JSON error
|
||||
available_tags = []
|
||||
|
||||
return {"available_gesture_tags": available_tags}
|
||||
|
||||
|
||||
@router.get("/ping_stream")
|
||||
async def ping_stream(request: Request):
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user