fix: wait for req socket send to make sure we dont stay stuck - if there's no... #23

Merged
9828273 merged 19 commits from feat/cb2ui-robot-connections into dev 2025-11-18 12:24:15 +00:00
3 changed files with 67 additions and 75 deletions
Showing only changes of commit debc87c0bb - Show all commits

View File

@@ -56,28 +56,29 @@ class RICommunicationAgent(BaseAgent):
"we probably dont have any receivers... but let's check!" "we probably dont have any receivers... but let's check!"
) )
# Wait up to {seconds_to_wait_total/2} seconds for a reply:) # Wait up to {seconds_to_wait_total/2} seconds for a reply
try: try:
message = await asyncio.wait_for( message = await asyncio.wait_for(
self.agent._req_socket.recv_json(), timeout=seconds_to_wait_total / 2 self.agent._req_socket.recv_json(), timeout=seconds_to_wait_total / 2
) )
# We didnt get a reply :( # We didnt get a reply
except TimeoutError: except TimeoutError:
self.agent.logger.info( self.agent.logger.info(
f"No ping retrieved in {seconds_to_wait_total} seconds, " f"No ping retrieved in {seconds_to_wait_total} seconds, "
"sending UI disconnection event and soft killing myself." "sending UI disconnection event and attempting to restart."
) )
# Make sure we dont retry receiving messages untill we're setup. # Make sure we dont retry receiving messages untill we're setup.
self.agent.connected = False self.agent.connected = False
self.agent.remove_behaviour(self)
# Tell UI we're disconnected. # Tell UI we're disconnected.
topic = b"ping" topic = b"ping"
data = json.dumps(False).encode() data = json.dumps(False).encode()
if self.agent.pub_socket is None: if self.agent.pub_socket is None:
self.agent.logger.error( self.agent.logger.error(
"communication agent pub socket not correctly initialized." "Communication agent pub socket not correctly initialized."
) )
else: else:
try: try:
@@ -85,17 +86,20 @@ class RICommunicationAgent(BaseAgent):
self.agent.pub_socket.send_multipart([topic, data]), 5 self.agent.pub_socket.send_multipart([topic, data]), 5
) )
except TimeoutError: except TimeoutError:
self.agent.logger.error( self.agent.logger.warning(
"Initial connection ping for router timed" "Initial connection ping for router timed"
" out in ri_communication_agent." " out in ri_communication_agent."
) )
# Try to reboot. # Try to reboot.
self.agent.logger.debug("Restarting communication agent.")
await self.agent.setup() await self.agent.setup()
self.agent.logger.debug(f'Received message "{message}" from RI.') self.agent.logger.debug(f'Received message "{message}" from RI.')
if "endpoint" not in message: if "endpoint" not in message:
self.agent.logger.error("No received endpoint in message, excepted ping endpoint.") self.agent.logger.warning(
"No received endpoint in message, expected ping endpoint."
)
return return
# See what endpoint we received # See what endpoint we received
@@ -107,7 +111,7 @@ class RICommunicationAgent(BaseAgent):
await self.agent.pub_socket.send_multipart([topic, data]) await self.agent.pub_socket.send_multipart([topic, data])
await asyncio.sleep(1) await asyncio.sleep(1)
case _: case _:
self.agent.logger.info( self.agent.logger.debug(
"Received message with topic different than ping, while ping expected." "Received message with topic different than ping, while ping expected."
) )
@@ -143,16 +147,20 @@ class RICommunicationAgent(BaseAgent):
if self._req_socket is None: if self._req_socket is None:
continue continue
# Send our message and receive one back:) # Send our message and receive one back
message = {"endpoint": "negotiate/ports", "data": {}} message = {"endpoint": "negotiate/ports", "data": {}}
await self._req_socket.send_json(message) await self._req_socket.send_json(message)
retry_frequency = 1.0
try: try:
received_message = await asyncio.wait_for(self._req_socket.recv_json(), timeout=1.0) received_message = await asyncio.wait_for(
self._req_socket.recv_json(), timeout=retry_frequency
)
except TimeoutError: except TimeoutError:
self.logger.warning( self.logger.warning(
"No connection established in 20 seconds (attempt %d/%d)", "No connection established in %d seconds (attempt %d/%d)",
retries * retry_frequency,
retries + 1, retries + 1,
max_retries, max_retries,
) )
@@ -160,21 +168,21 @@ class RICommunicationAgent(BaseAgent):
continue continue
except Exception as e: except Exception as e:
self.logger.error("Unexpected error during negotiation: %s", e) self.logger.warning("Unexpected error during negotiation: %s", e)
retries += 1 retries += 1
continue continue
# Validate endpoint # Validate endpoint
endpoint = received_message.get("endpoint") endpoint = received_message.get("endpoint")
if endpoint != "negotiate/ports": if endpoint != "negotiate/ports":
# TODO: Should this send a message back? self.logger.warning(
self.logger.error(
"Invalid endpoint '%s' received (attempt %d/%d)", "Invalid endpoint '%s' received (attempt %d/%d)",
endpoint, endpoint,
retries + 1, retries + 1,
max_retries, max_retries,
) )
retries += 1 retries += 1
await asyncio.sleep(1)
continue continue
# At this point, we have a valid response # At this point, we have a valid response
@@ -194,7 +202,7 @@ class RICommunicationAgent(BaseAgent):
if addr != self._address: if addr != self._address:
if not bind: if not bind:
self._req_socket.connect(addr) self._req_socket.connect(addr)
else: # TODO: Should this ever be the case? else:
self._req_socket.bind(addr) self._req_socket.bind(addr)
case "actuation": case "actuation":
ri_commands_agent = RICommandAgent( ri_commands_agent = RICommandAgent(
@@ -210,31 +218,32 @@ class RICommunicationAgent(BaseAgent):
self.logger.warning("Unhandled negotiation id: %s", id) self.logger.warning("Unhandled negotiation id: %s", id)
except Exception as e: except Exception as e:
self.logger.error("Error unpacking negotiation data: %s", e) self.logger.warning("Error unpacking negotiation data: %s", e)
retries += 1 retries += 1
await asyncio.sleep(1)
continue continue
# setup succeeded # setup succeeded
break break
else: else:
self.logger.error("Failed to set up RICommunicationAgent after %d retries", max_retries) self.logger.error("Failed to set up %s after %d retries", self.name, max_retries)
return return
# Set up ping behaviour # Set up ping behaviour
listen_behaviour = self.ListenBehaviour() listen_behaviour = self.ListenBehaviour()
self.add_behaviour(listen_behaviour) self.add_behaviour(listen_behaviour)
# Let UI know that we're connected >:) # Let UI know that we're connected
topic = b"ping" topic = b"ping"
data = json.dumps(True).encode() data = json.dumps(True).encode()
if self.pub_socket is None: if self.pub_socket is None:
self.logger.error("communication agent pub socket not correctly initialized.") self.logger.error("Communication agent pub socket not correctly initialized.")
else: else:
try: try:
await asyncio.wait_for(self.pub_socket.send_multipart([topic, data]), 5) await asyncio.wait_for(self.pub_socket.send_multipart([topic, data]), 5)
except TimeoutError: except TimeoutError:
self.logger.error( self.logger.warning(
"Initial connection ping for router timed out in ri_communication_agent." "Initial connection ping for router timed out in ri_communication_agent."
) )

View File

@@ -21,7 +21,6 @@ async def receive_command(command: SpeechCommand, request: Request):
SpeechCommand.model_validate(command) SpeechCommand.model_validate(command)
topic = b"command" topic = b"command"
# TODO: Check with Kasper
pub_socket: Socket = request.app.state.endpoints_pub_socket pub_socket: Socket = request.app.state.endpoints_pub_socket
await pub_socket.send_multipart([topic, command.model_dump_json().encode()]) await pub_socket.send_multipart([topic, command.model_dump_json().encode()])
@@ -48,8 +47,8 @@ async def ping_stream(request: Request):
ping_frequency = 2 ping_frequency = 2
# Even though its most likely the updates should alternate # Even though its most likely the updates should alternate
# So, True - False - True - False for connectivity. # (So, True - False - True - False for connectivity),
# Let's still check:) # let's still check.
while True: while True:
try: try:
topic, body = await asyncio.wait_for( topic, body = await asyncio.wait_for(
@@ -58,11 +57,11 @@ async def ping_stream(request: Request):
connected = json.loads(body) connected = json.loads(body)
except TimeoutError: except TimeoutError:
logger.debug("got timeout error in ping loop in ping router") logger.debug("got timeout error in ping loop in ping router")
await asyncio.sleep(0.1) connected = False
# Stop if client disconnected # Stop if client disconnected
if await request.is_disconnected(): if await request.is_disconnected():
print("Client disconnected from SSE") logger.info("Client disconnected from SSE")
break break
logger.debug(f"Yielded new connection event in robot ping router: {str(connected)}") logger.debug(f"Yielded new connection event in robot ping router: {str(connected)}")

View File

@@ -196,14 +196,14 @@ async def test_setup_creates_socket_and_negotiate_3(zmq_context, caplog):
fake_agent_instance = MockCommandAgent.return_value fake_agent_instance = MockCommandAgent.return_value
fake_agent_instance.start = AsyncMock() fake_agent_instance.start = AsyncMock()
# --- Act --- # --- Act ---
with caplog.at_level("ERROR"):
agent = RICommunicationAgent( agent = RICommunicationAgent(
"test@server", "test@server",
"password", "password",
address="tcp://localhost:5555", address="tcp://localhost:5555",
bind=False, bind=False,
) )
await agent.setup(max_retries=1) await agent.setup(max_retries=1)
# --- Assert --- # --- Assert ---
fake_socket.connect.assert_any_call("tcp://localhost:5555") fake_socket.connect.assert_any_call("tcp://localhost:5555")
@@ -211,7 +211,6 @@ async def test_setup_creates_socket_and_negotiate_3(zmq_context, caplog):
# Since it failed, there should not be any command agent. # Since it failed, there should not be any command agent.
fake_agent_instance.start.assert_not_awaited() fake_agent_instance.start.assert_not_awaited()
assert "Failed to set up RICommunicationAgent" in caplog.text
# Ensure the agent did not attach a ListenBehaviour # Ensure the agent did not attach a ListenBehaviour
assert not any(isinstance(b, agent.ListenBehaviour) for b in agent.behaviours) assert not any(isinstance(b, agent.ListenBehaviour) for b in agent.behaviours)
@@ -362,14 +361,14 @@ async def test_setup_creates_socket_and_negotiate_7(zmq_context, caplog):
fake_agent_instance.start = AsyncMock() fake_agent_instance.start = AsyncMock()
# --- Act --- # --- Act ---
with caplog.at_level("WARNING"):
agent = RICommunicationAgent( agent = RICommunicationAgent(
"test@server", "test@server",
"password", "password",
address="tcp://localhost:5555", address="tcp://localhost:5555",
bind=False, bind=False,
) )
await agent.setup(max_retries=1) await agent.setup(max_retries=1)
# --- Assert --- # --- Assert ---
fake_socket.connect.assert_any_call("tcp://localhost:5555") fake_socket.connect.assert_any_call("tcp://localhost:5555")
@@ -377,7 +376,6 @@ async def test_setup_creates_socket_and_negotiate_7(zmq_context, caplog):
# Since it failed, there should not be any command agent. # Since it failed, there should not be any command agent.
fake_agent_instance.start.assert_not_awaited() fake_agent_instance.start.assert_not_awaited()
assert "Unhandled negotiation id:" in caplog.text
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -398,21 +396,20 @@ async def test_setup_creates_socket_and_negotiate_timeout(zmq_context, caplog):
fake_agent_instance.start = AsyncMock() fake_agent_instance.start = AsyncMock()
# --- Act --- # --- Act ---
with caplog.at_level("WARNING"):
agent = RICommunicationAgent( agent = RICommunicationAgent(
"test@server", "test@server",
"password", "password",
address="tcp://localhost:5555", address="tcp://localhost:5555",
bind=False, bind=False,
) )
await agent.setup(max_retries=1) await agent.setup(max_retries=1)
# --- Assert --- # --- Assert ---
fake_socket.connect.assert_any_call("tcp://localhost:5555") fake_socket.connect.assert_any_call("tcp://localhost:5555")
# Since it failed, there should not be any command agent. # Since it failed, there should not be any command agent.
fake_agent_instance.start.assert_not_awaited() fake_agent_instance.start.assert_not_awaited()
assert "No connection established in 20 seconds" in caplog.text
# Ensure the agent did not attach a ListenBehaviour # Ensure the agent did not attach a ListenBehaviour
assert not any(isinstance(b, agent.ListenBehaviour) for b in agent.behaviours) assert not any(isinstance(b, agent.ListenBehaviour) for b in agent.behaviours)
@@ -425,7 +422,6 @@ async def test_listen_behaviour_ping_correct(caplog):
fake_socket.recv_json = AsyncMock(return_value={"endpoint": "ping", "data": {}}) fake_socket.recv_json = AsyncMock(return_value={"endpoint": "ping", "data": {}})
fake_socket.send_multipart = AsyncMock() fake_socket.send_multipart = AsyncMock()
# TODO: Integration test between actual server and password needed for spade agents
agent = RICommunicationAgent("test@server", "password") agent = RICommunicationAgent("test@server", "password")
agent._req_socket = fake_socket agent._req_socket = fake_socket
agent.connected = True agent.connected = True
@@ -433,13 +429,10 @@ async def test_listen_behaviour_ping_correct(caplog):
behaviour = agent.ListenBehaviour() behaviour = agent.ListenBehaviour()
agent.add_behaviour(behaviour) agent.add_behaviour(behaviour)
# Run once (CyclicBehaviour normally loops) await behaviour.run()
with caplog.at_level("DEBUG"):
await behaviour.run()
fake_socket.send_json.assert_awaited() fake_socket.send_json.assert_awaited()
fake_socket.recv_json.assert_awaited() fake_socket.recv_json.assert_awaited()
assert "Received message" in caplog.text
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -470,10 +463,9 @@ async def test_listen_behaviour_ping_wrong_endpoint(caplog):
agent.add_behaviour(behaviour) agent.add_behaviour(behaviour)
# Run once (CyclicBehaviour normally loops) # Run once (CyclicBehaviour normally loops)
with caplog.at_level("INFO"):
await behaviour.run()
assert "Received message with topic different than ping, while ping expected." in caplog.text await behaviour.run()
fake_socket.send_json.assert_awaited() fake_socket.send_json.assert_awaited()
fake_socket.recv_json.assert_awaited() fake_socket.recv_json.assert_awaited()
@@ -493,10 +485,9 @@ async def test_listen_behaviour_timeout(zmq_context, caplog):
behaviour = agent.ListenBehaviour() behaviour = agent.ListenBehaviour()
agent.add_behaviour(behaviour) agent.add_behaviour(behaviour)
with caplog.at_level("INFO"): await behaviour.run()
await behaviour.run() assert not any(isinstance(b, agent.ListenBehaviour) for b in agent.behaviours)
assert not agent.connected
assert "No ping" in caplog.text
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -522,11 +513,8 @@ async def test_listen_behaviour_ping_no_endpoint(caplog):
behaviour = agent.ListenBehaviour() behaviour = agent.ListenBehaviour()
agent.add_behaviour(behaviour) agent.add_behaviour(behaviour)
# Run once (CyclicBehaviour normally loops) await behaviour.run()
with caplog.at_level("ERROR"):
await behaviour.run()
assert "No received endpoint in message, excepted ping endpoint." in caplog.text
fake_socket.send_json.assert_awaited() fake_socket.send_json.assert_awaited()
fake_socket.recv_json.assert_awaited() fake_socket.recv_json.assert_awaited()
@@ -546,11 +534,10 @@ async def test_setup_unexpected_exception(zmq_context, caplog):
bind=False, bind=False,
) )
with caplog.at_level("ERROR"): await agent.setup(max_retries=1)
await agent.setup(max_retries=1)
# Ensure that the error was logged assert not any(isinstance(b, agent.ListenBehaviour) for b in agent.behaviours)
assert "Unexpected error during negotiation: boom!" in caplog.text assert not agent.connected
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -582,11 +569,8 @@ async def test_setup_unpacking_exception(zmq_context, caplog):
) )
# --- Act & Assert --- # --- Act & Assert ---
with caplog.at_level("ERROR"):
await agent.setup(max_retries=1)
# Ensure the unpacking exception was logged await agent.setup(max_retries=1)
assert "Error unpacking negotiation data" in caplog.text
# Ensure no command agent was started # Ensure no command agent was started
fake_agent_instance.start.assert_not_awaited() fake_agent_instance.start.assert_not_awaited()