Compare commits
14 Commits
feat/reset
...
feat/pause
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
cdb7fac53f | ||
|
|
d1ad2c1549 | ||
|
|
612a96940d | ||
|
|
4c20656c75 | ||
|
|
6ca86e4b81 | ||
|
|
867837dcc4 | ||
|
|
9adeb1efff | ||
|
|
7d798f2e77 | ||
|
|
5282c2471f | ||
|
|
200bd27d9b | ||
|
|
539e814c5a | ||
|
|
0c682d6440 | ||
|
|
32d8f20dc9 | ||
|
|
9cc0e39955 |
@@ -3,11 +3,14 @@ import json
|
|||||||
|
|
||||||
import zmq
|
import zmq
|
||||||
import zmq.asyncio as azmq
|
import zmq.asyncio as azmq
|
||||||
|
from pydantic import ValidationError
|
||||||
from zmq.asyncio import Context
|
from zmq.asyncio import Context
|
||||||
|
|
||||||
from control_backend.agents import BaseAgent
|
from control_backend.agents import BaseAgent
|
||||||
from control_backend.agents.actuation.robot_gesture_agent import RobotGestureAgent
|
from control_backend.agents.actuation.robot_gesture_agent import RobotGestureAgent
|
||||||
|
from control_backend.core.agent_system import InternalMessage
|
||||||
from control_backend.core.config import settings
|
from control_backend.core.config import settings
|
||||||
|
from control_backend.schemas.ri_message import PauseCommand
|
||||||
|
|
||||||
from ..actuation.robot_speech_agent import RobotSpeechAgent
|
from ..actuation.robot_speech_agent import RobotSpeechAgent
|
||||||
from ..perception import VADAgent
|
from ..perception import VADAgent
|
||||||
@@ -298,3 +301,11 @@ class RICommunicationAgent(BaseAgent):
|
|||||||
self.logger.debug("Restarting communication negotiation.")
|
self.logger.debug("Restarting communication negotiation.")
|
||||||
if await self._negotiate_connection(max_retries=1):
|
if await self._negotiate_connection(max_retries=1):
|
||||||
self.connected = True
|
self.connected = True
|
||||||
|
|
||||||
|
async def handle_message(self, msg : InternalMessage):
|
||||||
|
try:
|
||||||
|
pause_command = PauseCommand.model_validate_json(msg.body)
|
||||||
|
self._req_socket.send_json(pause_command.model_dump())
|
||||||
|
self.logger.debug(self._req_socket.recv_json())
|
||||||
|
except ValidationError:
|
||||||
|
self.logger.warning("Incorrect message format for PauseCommand.")
|
||||||
|
|||||||
68
src/control_backend/agents/mock_agents/test_pause_ri.py
Normal file
68
src/control_backend/agents/mock_agents/test_pause_ri.py
Normal file
@@ -0,0 +1,68 @@
|
|||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
|
||||||
|
import zmq
|
||||||
|
from zmq.asyncio import Context
|
||||||
|
|
||||||
|
from control_backend.agents.base import BaseAgent
|
||||||
|
from control_backend.core.agent_system import InternalMessage
|
||||||
|
from control_backend.core.config import settings
|
||||||
|
|
||||||
|
|
||||||
|
class TestPauseAgent(BaseAgent):
|
||||||
|
def __init__(self, name: str):
|
||||||
|
super().__init__(name)
|
||||||
|
|
||||||
|
async def setup(self):
|
||||||
|
context = Context.instance()
|
||||||
|
self.pub_socket = context.socket(zmq.PUB)
|
||||||
|
self.pub_socket.connect(settings.zmq_settings.internal_pub_address)
|
||||||
|
|
||||||
|
self.add_behavior(self._pause_command_loop())
|
||||||
|
self.logger.debug("TestPauseAgent setup complete.")
|
||||||
|
|
||||||
|
async def _pause_command_loop(self):
|
||||||
|
print("Starting Pause command test loop.")
|
||||||
|
while True:
|
||||||
|
pause_command = {
|
||||||
|
"endpoint": "pause",
|
||||||
|
"data": True,
|
||||||
|
}
|
||||||
|
|
||||||
|
message = InternalMessage(
|
||||||
|
to="ri_communication_agent",
|
||||||
|
sender=self.name,
|
||||||
|
body=json.dumps(pause_command),
|
||||||
|
)
|
||||||
|
await self.send(message)
|
||||||
|
|
||||||
|
# User interrupt message
|
||||||
|
data = {
|
||||||
|
"type": "pause",
|
||||||
|
"context": True,
|
||||||
|
}
|
||||||
|
await self.pub_socket.send_multipart([b"button_pressed", json.dumps(data).encode()])
|
||||||
|
|
||||||
|
self.logger.info("Pausing robot actions.")
|
||||||
|
await asyncio.sleep(15) # Simulate delay between messages
|
||||||
|
|
||||||
|
pause_command = {
|
||||||
|
"endpoint": "pause",
|
||||||
|
"data": False,
|
||||||
|
}
|
||||||
|
message = InternalMessage(
|
||||||
|
to="ri_communication_agent",
|
||||||
|
sender=self.name,
|
||||||
|
body=json.dumps(pause_command),
|
||||||
|
)
|
||||||
|
await self.send(message)
|
||||||
|
|
||||||
|
# User interrupt message
|
||||||
|
data = {
|
||||||
|
"type": "pause",
|
||||||
|
"context": False,
|
||||||
|
}
|
||||||
|
await self.pub_socket.send_multipart([b"button_pressed", json.dumps(data).encode()])
|
||||||
|
|
||||||
|
self.logger.info("Resuming robot actions.")
|
||||||
|
await asyncio.sleep(15) # Simulate delay between messages
|
||||||
@@ -7,6 +7,7 @@ import zmq.asyncio as azmq
|
|||||||
|
|
||||||
from control_backend.agents import BaseAgent
|
from control_backend.agents import BaseAgent
|
||||||
from control_backend.core.config import settings
|
from control_backend.core.config import settings
|
||||||
|
from control_backend.schemas.internal_message import InternalMessage
|
||||||
|
|
||||||
from ...schemas.program_status import PROGRAM_STATUS, ProgramStatus
|
from ...schemas.program_status import PROGRAM_STATUS, ProgramStatus
|
||||||
from .transcription_agent.transcription_agent import TranscriptionAgent
|
from .transcription_agent.transcription_agent import TranscriptionAgent
|
||||||
@@ -86,6 +87,12 @@ class VADAgent(BaseAgent):
|
|||||||
self.audio_buffer = np.array([], dtype=np.float32)
|
self.audio_buffer = np.array([], dtype=np.float32)
|
||||||
self.i_since_speech = settings.behaviour_settings.vad_initial_since_speech
|
self.i_since_speech = settings.behaviour_settings.vad_initial_since_speech
|
||||||
self._ready = asyncio.Event()
|
self._ready = asyncio.Event()
|
||||||
|
|
||||||
|
# Pause control
|
||||||
|
self._reset_needed = False
|
||||||
|
self._paused = asyncio.Event()
|
||||||
|
self._paused.set() # Not paused at start
|
||||||
|
|
||||||
self.model = None
|
self.model = None
|
||||||
|
|
||||||
async def setup(self):
|
async def setup(self):
|
||||||
@@ -213,6 +220,16 @@ class VADAgent(BaseAgent):
|
|||||||
"""
|
"""
|
||||||
await self._ready.wait()
|
await self._ready.wait()
|
||||||
while self._running:
|
while self._running:
|
||||||
|
await self._paused.wait()
|
||||||
|
|
||||||
|
# After being unpaused, reset stream and buffers
|
||||||
|
if self._reset_needed:
|
||||||
|
self.logger.debug("Resuming: resetting stream and buffers.")
|
||||||
|
await self._reset_stream()
|
||||||
|
self.audio_buffer = np.array([], dtype=np.float32)
|
||||||
|
self.i_since_speech = settings.behaviour_settings.vad_initial_since_speech
|
||||||
|
self._reset_needed = False
|
||||||
|
|
||||||
assert self.audio_in_poller is not None
|
assert self.audio_in_poller is not None
|
||||||
data = await self.audio_in_poller.poll()
|
data = await self.audio_in_poller.poll()
|
||||||
if data is None:
|
if data is None:
|
||||||
@@ -254,3 +271,27 @@ class VADAgent(BaseAgent):
|
|||||||
# At this point, we know that the speech has ended.
|
# At this point, we know that the speech has ended.
|
||||||
# Prepend the last chunk that had no speech, for a more fluent boundary
|
# Prepend the last chunk that had no speech, for a more fluent boundary
|
||||||
self.audio_buffer = chunk
|
self.audio_buffer = chunk
|
||||||
|
|
||||||
|
async def handle_message(self, msg: InternalMessage):
|
||||||
|
"""
|
||||||
|
Handle incoming messages.
|
||||||
|
|
||||||
|
Expects messages to pause or resume the VAD processing from User Interrupt Agent.
|
||||||
|
|
||||||
|
:param msg: The received internal message.
|
||||||
|
"""
|
||||||
|
sender = msg.sender
|
||||||
|
|
||||||
|
if sender == settings.agent_settings.user_interrupt_name:
|
||||||
|
if msg.body == "PAUSE":
|
||||||
|
self.logger.info("Pausing VAD processing.")
|
||||||
|
self._paused.clear()
|
||||||
|
# If the robot needs to pick up speaking where it left off, do not set _reset_needed
|
||||||
|
self._reset_needed = True
|
||||||
|
elif msg.body == "RESUME":
|
||||||
|
self.logger.info("Resuming VAD processing.")
|
||||||
|
self._paused.set()
|
||||||
|
else:
|
||||||
|
self.logger.warning(f"Unknown command from User Interrupt Agent: {msg.body}")
|
||||||
|
else:
|
||||||
|
self.logger.debug(f"Ignoring message from unknown sender: {sender}")
|
||||||
@@ -6,7 +6,12 @@ from zmq.asyncio import Context
|
|||||||
from control_backend.agents import BaseAgent
|
from control_backend.agents import BaseAgent
|
||||||
from control_backend.core.agent_system import InternalMessage
|
from control_backend.core.agent_system import InternalMessage
|
||||||
from control_backend.core.config import settings
|
from control_backend.core.config import settings
|
||||||
from control_backend.schemas.ri_message import GestureCommand, RIEndpoint, SpeechCommand
|
from control_backend.schemas.ri_message import (
|
||||||
|
GestureCommand,
|
||||||
|
PauseCommand,
|
||||||
|
RIEndpoint,
|
||||||
|
SpeechCommand,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class UserInterruptAgent(BaseAgent):
|
class UserInterruptAgent(BaseAgent):
|
||||||
@@ -71,6 +76,12 @@ class UserInterruptAgent(BaseAgent):
|
|||||||
"Forwarded button press (override) with context '%s' to BDIProgramManager.",
|
"Forwarded button press (override) with context '%s' to BDIProgramManager.",
|
||||||
event_context,
|
event_context,
|
||||||
)
|
)
|
||||||
|
elif event_type == "pause":
|
||||||
|
await self._send_pause_command(event_context)
|
||||||
|
if event_context:
|
||||||
|
self.logger.info("Sent pause command.")
|
||||||
|
else:
|
||||||
|
self.logger.info("Sent resume command.")
|
||||||
else:
|
else:
|
||||||
self.logger.warning(
|
self.logger.warning(
|
||||||
"Received button press with unknown type '%s' (context: '%s').",
|
"Received button press with unknown type '%s' (context: '%s').",
|
||||||
@@ -130,6 +141,38 @@ class UserInterruptAgent(BaseAgent):
|
|||||||
belief_id,
|
belief_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def _send_pause_command(self, pause : bool):
|
||||||
|
"""
|
||||||
|
Send a pause command to the Robot Interface via the RI Communication Agent.
|
||||||
|
Send a pause command to the other internal agents; for now just VAD agent.
|
||||||
|
"""
|
||||||
|
cmd = PauseCommand(data=pause)
|
||||||
|
message = InternalMessage(
|
||||||
|
to=settings.agent_settings.ri_communication_name,
|
||||||
|
sender=self.name,
|
||||||
|
body=cmd.model_dump_json(),
|
||||||
|
)
|
||||||
|
await self.send(message)
|
||||||
|
|
||||||
|
if pause:
|
||||||
|
# Send pause to VAD agent
|
||||||
|
vad_message = InternalMessage(
|
||||||
|
to=settings.agent_settings.vad_name,
|
||||||
|
sender=self.name,
|
||||||
|
body="PAUSE",
|
||||||
|
)
|
||||||
|
await self.send(vad_message)
|
||||||
|
self.logger.info("Sent pause command to VAD Agent and RI Communication Agent.")
|
||||||
|
else:
|
||||||
|
# Send resume to VAD agent
|
||||||
|
vad_message = InternalMessage(
|
||||||
|
to=settings.agent_settings.vad_name,
|
||||||
|
sender=self.name,
|
||||||
|
body="RESUME",
|
||||||
|
)
|
||||||
|
await self.send(vad_message)
|
||||||
|
self.logger.info("Sent resume command to VAD Agent and RI Communication Agent.")
|
||||||
|
|
||||||
async def setup(self):
|
async def setup(self):
|
||||||
"""
|
"""
|
||||||
Initialize the agent.
|
Initialize the agent.
|
||||||
|
|||||||
@@ -39,10 +39,11 @@ from control_backend.agents.communication import RICommunicationAgent
|
|||||||
# LLM Agents
|
# LLM Agents
|
||||||
from control_backend.agents.llm import LLMAgent
|
from control_backend.agents.llm import LLMAgent
|
||||||
|
|
||||||
|
# Other backend imports
|
||||||
|
from control_backend.agents.mock_agents.test_pause_ri import TestPauseAgent
|
||||||
|
|
||||||
# User Interrupt Agent
|
# User Interrupt Agent
|
||||||
from control_backend.agents.user_interrupt.user_interrupt_agent import UserInterruptAgent
|
from control_backend.agents.user_interrupt.user_interrupt_agent import UserInterruptAgent
|
||||||
|
|
||||||
# Other backend imports
|
|
||||||
from control_backend.api.v1.router import api_router
|
from control_backend.api.v1.router import api_router
|
||||||
from control_backend.core.config import settings
|
from control_backend.core.config import settings
|
||||||
from control_backend.logging import setup_logging
|
from control_backend.logging import setup_logging
|
||||||
@@ -141,6 +142,12 @@ async def lifespan(app: FastAPI):
|
|||||||
"name": settings.agent_settings.bdi_program_manager_name,
|
"name": settings.agent_settings.bdi_program_manager_name,
|
||||||
},
|
},
|
||||||
),
|
),
|
||||||
|
"TestPauseAgent": (
|
||||||
|
TestPauseAgent,
|
||||||
|
{
|
||||||
|
"name": "pause_test_agent",
|
||||||
|
},
|
||||||
|
),
|
||||||
"UserInterruptAgent": (
|
"UserInterruptAgent": (
|
||||||
UserInterruptAgent,
|
UserInterruptAgent,
|
||||||
{
|
{
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ class RIEndpoint(str, Enum):
|
|||||||
GESTURE_TAG = "actuate/gesture/tag"
|
GESTURE_TAG = "actuate/gesture/tag"
|
||||||
PING = "ping"
|
PING = "ping"
|
||||||
NEGOTIATE_PORTS = "negotiate/ports"
|
NEGOTIATE_PORTS = "negotiate/ports"
|
||||||
|
PAUSE = "pause"
|
||||||
|
|
||||||
|
|
||||||
class RIMessage(BaseModel):
|
class RIMessage(BaseModel):
|
||||||
@@ -64,3 +65,14 @@ class GestureCommand(RIMessage):
|
|||||||
if self.endpoint not in allowed:
|
if self.endpoint not in allowed:
|
||||||
raise ValueError("endpoint must be GESTURE_SINGLE or GESTURE_TAG")
|
raise ValueError("endpoint must be GESTURE_SINGLE or GESTURE_TAG")
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
class PauseCommand(RIMessage):
|
||||||
|
"""
|
||||||
|
A specific command to pause or unpause the robot's actions.
|
||||||
|
|
||||||
|
:ivar endpoint: Fixed to ``RIEndpoint.PAUSE``.
|
||||||
|
:ivar data: A boolean indicating whether to pause (True) or unpause (False).
|
||||||
|
"""
|
||||||
|
|
||||||
|
endpoint: RIEndpoint = RIEndpoint(RIEndpoint.PAUSE)
|
||||||
|
data: bool
|
||||||
Reference in New Issue
Block a user