fix: move VAD agent creation to RI communication agent
Previously, it was started in main, but it should use values negotiated by the RI communication agent. ref: N25B-356
This commit is contained in:
@@ -9,6 +9,7 @@ from control_backend.agents import BaseAgent
|
|||||||
from control_backend.core.config import settings
|
from control_backend.core.config import settings
|
||||||
|
|
||||||
from ..actuation.robot_speech_agent import RobotSpeechAgent
|
from ..actuation.robot_speech_agent import RobotSpeechAgent
|
||||||
|
from ..perception import VADAgent
|
||||||
|
|
||||||
|
|
||||||
class RICommunicationAgent(BaseAgent):
|
class RICommunicationAgent(BaseAgent):
|
||||||
@@ -185,6 +186,9 @@ class RICommunicationAgent(BaseAgent):
|
|||||||
bind=bind,
|
bind=bind,
|
||||||
)
|
)
|
||||||
await ri_commands_agent.start()
|
await ri_commands_agent.start()
|
||||||
|
case "audio":
|
||||||
|
vad_agent = VADAgent(audio_in_address=addr, audio_in_bind=bind)
|
||||||
|
await vad_agent.start()
|
||||||
case _:
|
case _:
|
||||||
self.logger.warning("Unhandled negotiation id: %s", id)
|
self.logger.warning("Unhandled negotiation id: %s", id)
|
||||||
|
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import zmq.asyncio as azmq
|
|||||||
from control_backend.agents import BaseAgent
|
from control_backend.agents import BaseAgent
|
||||||
from control_backend.core.config import settings
|
from control_backend.core.config import settings
|
||||||
|
|
||||||
|
from ...schemas.program_status import PROGRAM_STATUS, ProgramStatus
|
||||||
from .transcription_agent.transcription_agent import TranscriptionAgent
|
from .transcription_agent.transcription_agent import TranscriptionAgent
|
||||||
|
|
||||||
|
|
||||||
@@ -61,6 +62,7 @@ class VADAgent(BaseAgent):
|
|||||||
:ivar audio_in_address: Address of the input audio stream.
|
:ivar audio_in_address: Address of the input audio stream.
|
||||||
:ivar audio_in_bind: Whether to bind or connect to the input address.
|
:ivar audio_in_bind: Whether to bind or connect to the input address.
|
||||||
:ivar audio_out_socket: ZMQ PUB socket for sending speech fragments.
|
:ivar audio_out_socket: ZMQ PUB socket for sending speech fragments.
|
||||||
|
:ivar program_sub_socket: ZMQ SUB socket for receiving program status updates.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, audio_in_address: str, audio_in_bind: bool):
|
def __init__(self, audio_in_address: str, audio_in_bind: bool):
|
||||||
@@ -79,6 +81,8 @@ class VADAgent(BaseAgent):
|
|||||||
self.audio_out_socket: azmq.Socket | None = None
|
self.audio_out_socket: azmq.Socket | None = None
|
||||||
self.audio_in_poller: SocketPoller | None = None
|
self.audio_in_poller: SocketPoller | None = None
|
||||||
|
|
||||||
|
self.program_sub_socket: azmq.Socket | None = None
|
||||||
|
|
||||||
self.audio_buffer = np.array([], dtype=np.float32)
|
self.audio_buffer = np.array([], dtype=np.float32)
|
||||||
self.i_since_speech = settings.behaviour_settings.vad_initial_since_speech
|
self.i_since_speech = settings.behaviour_settings.vad_initial_since_speech
|
||||||
self._ready = asyncio.Event()
|
self._ready = asyncio.Event()
|
||||||
@@ -90,9 +94,10 @@ class VADAgent(BaseAgent):
|
|||||||
|
|
||||||
1. Connects audio input socket.
|
1. Connects audio input socket.
|
||||||
2. Binds audio output socket (random port).
|
2. Binds audio output socket (random port).
|
||||||
3. Loads VAD model from Torch Hub.
|
3. Connects to program communication socket.
|
||||||
4. Starts the streaming loop.
|
4. Loads VAD model from Torch Hub.
|
||||||
5. Instantiates and starts the :class:`TranscriptionAgent` with the output address.
|
5. Starts the streaming loop.
|
||||||
|
6. Instantiates and starts the :class:`TranscriptionAgent` with the output address.
|
||||||
"""
|
"""
|
||||||
self.logger.info("Setting up %s", self.name)
|
self.logger.info("Setting up %s", self.name)
|
||||||
|
|
||||||
@@ -105,6 +110,11 @@ class VADAgent(BaseAgent):
|
|||||||
return
|
return
|
||||||
audio_out_address = f"tcp://localhost:{audio_out_port}"
|
audio_out_address = f"tcp://localhost:{audio_out_port}"
|
||||||
|
|
||||||
|
# Connect to internal communication socket
|
||||||
|
self.program_sub_socket = azmq.Context.instance().socket(zmq.SUB)
|
||||||
|
self.program_sub_socket.connect(settings.zmq_settings.internal_sub_address)
|
||||||
|
self.program_sub_socket.subscribe(PROGRAM_STATUS)
|
||||||
|
|
||||||
# Initialize VAD model
|
# Initialize VAD model
|
||||||
try:
|
try:
|
||||||
self.model, _ = torch.hub.load(
|
self.model, _ = torch.hub.load(
|
||||||
@@ -117,10 +127,8 @@ class VADAgent(BaseAgent):
|
|||||||
await self.stop()
|
await self.stop()
|
||||||
return
|
return
|
||||||
|
|
||||||
# Warmup/reset
|
|
||||||
await self.reset_stream()
|
|
||||||
|
|
||||||
self.add_behavior(self._streaming_loop())
|
self.add_behavior(self._streaming_loop())
|
||||||
|
self.add_behavior(self._status_loop())
|
||||||
|
|
||||||
# Start agents dependent on the output audio fragments here
|
# Start agents dependent on the output audio fragments here
|
||||||
transcriber = TranscriptionAgent(audio_out_address)
|
transcriber = TranscriptionAgent(audio_out_address)
|
||||||
@@ -165,7 +173,7 @@ class VADAgent(BaseAgent):
|
|||||||
self.audio_out_socket = None
|
self.audio_out_socket = None
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def reset_stream(self):
|
async def _reset_stream(self):
|
||||||
"""
|
"""
|
||||||
Clears the ZeroMQ queue and sets ready state.
|
Clears the ZeroMQ queue and sets ready state.
|
||||||
"""
|
"""
|
||||||
@@ -176,6 +184,23 @@ class VADAgent(BaseAgent):
|
|||||||
self.logger.info(f"Discarded {discarded} audio packets before starting.")
|
self.logger.info(f"Discarded {discarded} audio packets before starting.")
|
||||||
self._ready.set()
|
self._ready.set()
|
||||||
|
|
||||||
|
async def _status_loop(self):
|
||||||
|
"""Loop for checking program status. Only start listening if program is RUNNING."""
|
||||||
|
while self._running:
|
||||||
|
topic, body = await self.program_sub_socket.recv_multipart()
|
||||||
|
|
||||||
|
if topic != PROGRAM_STATUS:
|
||||||
|
continue
|
||||||
|
if body != ProgramStatus.RUNNING.value:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Program is now running, we can start our stream
|
||||||
|
await self._reset_stream()
|
||||||
|
|
||||||
|
# We don't care about further status updates
|
||||||
|
self.program_sub_socket.close()
|
||||||
|
break
|
||||||
|
|
||||||
async def _streaming_loop(self):
|
async def _streaming_loop(self):
|
||||||
"""
|
"""
|
||||||
Main loop for processing audio stream.
|
Main loop for processing audio stream.
|
||||||
|
|||||||
@@ -17,7 +17,6 @@ class ZMQSettings(BaseModel):
|
|||||||
internal_sub_address: str = "tcp://localhost:5561"
|
internal_sub_address: str = "tcp://localhost:5561"
|
||||||
ri_command_address: str = "tcp://localhost:0000"
|
ri_command_address: str = "tcp://localhost:0000"
|
||||||
ri_communication_address: str = "tcp://*:5555"
|
ri_communication_address: str = "tcp://*:5555"
|
||||||
vad_agent_address: str = "tcp://localhost:5558"
|
|
||||||
|
|
||||||
|
|
||||||
class AgentSettings(BaseModel):
|
class AgentSettings(BaseModel):
|
||||||
|
|||||||
@@ -39,13 +39,11 @@ from control_backend.agents.communication import RICommunicationAgent
|
|||||||
# LLM Agents
|
# LLM Agents
|
||||||
from control_backend.agents.llm import LLMAgent
|
from control_backend.agents.llm import LLMAgent
|
||||||
|
|
||||||
# Perceive agents
|
|
||||||
from control_backend.agents.perception import VADAgent
|
|
||||||
|
|
||||||
# Other backend imports
|
# Other backend imports
|
||||||
from control_backend.api.v1.router import api_router
|
from control_backend.api.v1.router import api_router
|
||||||
from control_backend.core.config import settings
|
from control_backend.core.config import settings
|
||||||
from control_backend.logging import setup_logging
|
from control_backend.logging import setup_logging
|
||||||
|
from control_backend.schemas.program_status import PROGRAM_STATUS, ProgramStatus
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -95,6 +93,8 @@ async def lifespan(app: FastAPI):
|
|||||||
endpoints_pub_socket.connect(settings.zmq_settings.internal_pub_address)
|
endpoints_pub_socket.connect(settings.zmq_settings.internal_pub_address)
|
||||||
app.state.endpoints_pub_socket = endpoints_pub_socket
|
app.state.endpoints_pub_socket = endpoints_pub_socket
|
||||||
|
|
||||||
|
await endpoints_pub_socket.send_multipart([PROGRAM_STATUS, ProgramStatus.STARTING.value])
|
||||||
|
|
||||||
# --- Initialize Agents ---
|
# --- Initialize Agents ---
|
||||||
logger.info("Initializing and starting agents.")
|
logger.info("Initializing and starting agents.")
|
||||||
|
|
||||||
@@ -132,10 +132,6 @@ async def lifespan(app: FastAPI):
|
|||||||
"name": settings.agent_settings.text_belief_extractor_name,
|
"name": settings.agent_settings.text_belief_extractor_name,
|
||||||
},
|
},
|
||||||
),
|
),
|
||||||
"VADAgent": (
|
|
||||||
VADAgent,
|
|
||||||
{"audio_in_address": settings.zmq_settings.vad_agent_address, "audio_in_bind": False},
|
|
||||||
),
|
|
||||||
"ProgramManagerAgent": (
|
"ProgramManagerAgent": (
|
||||||
BDIProgramManager,
|
BDIProgramManager,
|
||||||
{
|
{
|
||||||
@@ -146,32 +142,28 @@ async def lifespan(app: FastAPI):
|
|||||||
|
|
||||||
agents = []
|
agents = []
|
||||||
|
|
||||||
vad_agent = None
|
|
||||||
|
|
||||||
for name, (agent_class, kwargs) in agents_to_start.items():
|
for name, (agent_class, kwargs) in agents_to_start.items():
|
||||||
try:
|
try:
|
||||||
logger.debug("Starting agent: %s", name)
|
logger.debug("Starting agent: %s", name)
|
||||||
agent_instance = agent_class(**kwargs)
|
agent_instance = agent_class(**kwargs)
|
||||||
await agent_instance.start()
|
await agent_instance.start()
|
||||||
if isinstance(agent_instance, VADAgent):
|
|
||||||
vad_agent = agent_instance
|
|
||||||
agents.append(agent_instance)
|
agents.append(agent_instance)
|
||||||
logger.info("Agent '%s' started successfully.", name)
|
logger.info("Agent '%s' started successfully.", name)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("Failed to start agent '%s': %s", name, e, exc_info=True)
|
logger.error("Failed to start agent '%s': %s", name, e, exc_info=True)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
assert vad_agent is not None
|
|
||||||
await vad_agent.reset_stream()
|
|
||||||
|
|
||||||
logger.info("Application startup complete.")
|
logger.info("Application startup complete.")
|
||||||
|
|
||||||
|
await endpoints_pub_socket.send_multipart([PROGRAM_STATUS, ProgramStatus.RUNNING.value])
|
||||||
|
|
||||||
yield
|
yield
|
||||||
|
|
||||||
# --- APPLICATION SHUTDOWN ---
|
# --- APPLICATION SHUTDOWN ---
|
||||||
logger.info("%s is shutting down.", app.title)
|
logger.info("%s is shutting down.", app.title)
|
||||||
|
|
||||||
# Potential shutdown logic goes here
|
await endpoints_pub_socket.send_multipart([PROGRAM_STATUS, ProgramStatus.STOPPING.value])
|
||||||
|
# Additional shutdown logic goes here
|
||||||
|
|
||||||
logger.info("Application shutdown complete.")
|
logger.info("Application shutdown complete.")
|
||||||
|
|
||||||
|
|||||||
16
src/control_backend/schemas/program_status.py
Normal file
16
src/control_backend/schemas/program_status.py
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
PROGRAM_STATUS = b"internal/program_status"
|
||||||
|
"""A topic key for the program status."""
|
||||||
|
|
||||||
|
|
||||||
|
class ProgramStatus(Enum):
|
||||||
|
"""
|
||||||
|
Used in internal communication, to tell agents what the status of the program is.
|
||||||
|
|
||||||
|
For example, the VAD agent only starts listening when the program is RUNNING.
|
||||||
|
"""
|
||||||
|
|
||||||
|
STARTING = b"starting"
|
||||||
|
RUNNING = b"running"
|
||||||
|
STOPPING = b"stopping"
|
||||||
@@ -5,6 +5,7 @@ import pytest
|
|||||||
import zmq
|
import zmq
|
||||||
|
|
||||||
from control_backend.agents.perception.vad_agent import VADAgent
|
from control_backend.agents.perception.vad_agent import VADAgent
|
||||||
|
from control_backend.schemas.program_status import PROGRAM_STATUS, ProgramStatus
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@@ -43,14 +44,12 @@ async def test_normal_setup(per_transcription_agent):
|
|||||||
coro.close()
|
coro.close()
|
||||||
|
|
||||||
per_vad_agent.add_behavior = swallow_background_task
|
per_vad_agent.add_behavior = swallow_background_task
|
||||||
per_vad_agent.reset_stream = AsyncMock()
|
|
||||||
|
|
||||||
await per_vad_agent.setup()
|
await per_vad_agent.setup()
|
||||||
|
|
||||||
per_transcription_agent.assert_called_once()
|
per_transcription_agent.assert_called_once()
|
||||||
per_transcription_agent.return_value.start.assert_called_once()
|
per_transcription_agent.return_value.start.assert_called_once()
|
||||||
per_vad_agent._streaming_loop.assert_called_once()
|
per_vad_agent._streaming_loop.assert_called_once()
|
||||||
per_vad_agent.reset_stream.assert_called_once()
|
|
||||||
assert per_vad_agent.audio_in_socket is not None
|
assert per_vad_agent.audio_in_socket is not None
|
||||||
assert per_vad_agent.audio_out_socket is not None
|
assert per_vad_agent.audio_out_socket is not None
|
||||||
|
|
||||||
@@ -103,7 +102,7 @@ async def test_out_socket_creation_failure(zmq_context):
|
|||||||
zmq_context.return_value.socket.return_value.bind_to_random_port.side_effect = zmq.ZMQBindError
|
zmq_context.return_value.socket.return_value.bind_to_random_port.side_effect = zmq.ZMQBindError
|
||||||
per_vad_agent = VADAgent("tcp://localhost:12345", False)
|
per_vad_agent = VADAgent("tcp://localhost:12345", False)
|
||||||
per_vad_agent.stop = AsyncMock()
|
per_vad_agent.stop = AsyncMock()
|
||||||
per_vad_agent.reset_stream = AsyncMock()
|
per_vad_agent._reset_stream = AsyncMock()
|
||||||
per_vad_agent._streaming_loop = AsyncMock()
|
per_vad_agent._streaming_loop = AsyncMock()
|
||||||
per_vad_agent._connect_audio_out_socket = MagicMock(return_value=None)
|
per_vad_agent._connect_audio_out_socket = MagicMock(return_value=None)
|
||||||
|
|
||||||
@@ -124,7 +123,7 @@ async def test_stop(zmq_context, per_transcription_agent):
|
|||||||
Test that when the VAD agent is stopped, the sockets are closed correctly.
|
Test that when the VAD agent is stopped, the sockets are closed correctly.
|
||||||
"""
|
"""
|
||||||
per_vad_agent = VADAgent("tcp://localhost:12345", False)
|
per_vad_agent = VADAgent("tcp://localhost:12345", False)
|
||||||
per_vad_agent.reset_stream = AsyncMock()
|
per_vad_agent._reset_stream = AsyncMock()
|
||||||
per_vad_agent._streaming_loop = AsyncMock()
|
per_vad_agent._streaming_loop = AsyncMock()
|
||||||
|
|
||||||
async def swallow_background_task(coro):
|
async def swallow_background_task(coro):
|
||||||
@@ -142,3 +141,66 @@ async def test_stop(zmq_context, per_transcription_agent):
|
|||||||
assert zmq_context.return_value.socket.return_value.close.call_count == 2
|
assert zmq_context.return_value.socket.return_value.close.call_count == 2
|
||||||
assert per_vad_agent.audio_in_socket is None
|
assert per_vad_agent.audio_in_socket is None
|
||||||
assert per_vad_agent.audio_out_socket is None
|
assert per_vad_agent.audio_out_socket is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_application_startup_complete(zmq_context):
|
||||||
|
"""Check that it resets the stream when the program finishes startup."""
|
||||||
|
vad_agent = VADAgent("tcp://localhost:12345", False)
|
||||||
|
vad_agent._running = True
|
||||||
|
vad_agent._reset_stream = AsyncMock()
|
||||||
|
vad_agent.program_sub_socket = AsyncMock()
|
||||||
|
vad_agent.program_sub_socket.recv_multipart.side_effect = [
|
||||||
|
(PROGRAM_STATUS, ProgramStatus.RUNNING.value),
|
||||||
|
]
|
||||||
|
|
||||||
|
await vad_agent._status_loop()
|
||||||
|
|
||||||
|
vad_agent._reset_stream.assert_called_once()
|
||||||
|
vad_agent.program_sub_socket.close.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_application_other_status(zmq_context):
|
||||||
|
"""
|
||||||
|
Check that it does nothing when the internal communication message is a status update, but not
|
||||||
|
running.
|
||||||
|
"""
|
||||||
|
vad_agent = VADAgent("tcp://localhost:12345", False)
|
||||||
|
vad_agent._running = True
|
||||||
|
vad_agent._reset_stream = AsyncMock()
|
||||||
|
vad_agent.program_sub_socket = AsyncMock()
|
||||||
|
|
||||||
|
vad_agent.program_sub_socket.recv_multipart.side_effect = [
|
||||||
|
(PROGRAM_STATUS, ProgramStatus.STARTING.value),
|
||||||
|
(PROGRAM_STATUS, ProgramStatus.STOPPING.value),
|
||||||
|
]
|
||||||
|
try:
|
||||||
|
# Raises StopAsyncIteration the third time it calls `program_sub_socket.recv_multipart`
|
||||||
|
await vad_agent._status_loop()
|
||||||
|
except StopAsyncIteration:
|
||||||
|
pass
|
||||||
|
|
||||||
|
vad_agent._reset_stream.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_application_message_other(zmq_context):
|
||||||
|
"""
|
||||||
|
Check that it does nothing when there's an internal communication message that is not a status
|
||||||
|
update.
|
||||||
|
"""
|
||||||
|
vad_agent = VADAgent("tcp://localhost:12345", False)
|
||||||
|
vad_agent._running = True
|
||||||
|
vad_agent._reset_stream = AsyncMock()
|
||||||
|
vad_agent.program_sub_socket = AsyncMock()
|
||||||
|
|
||||||
|
vad_agent.program_sub_socket.recv_multipart.side_effect = [(b"internal/other", b"Whatever")]
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Raises StopAsyncIteration the second time it calls `program_sub_socket.recv_multipart`
|
||||||
|
await vad_agent._status_loop()
|
||||||
|
except StopAsyncIteration:
|
||||||
|
pass
|
||||||
|
|
||||||
|
vad_agent._reset_stream.assert_not_called()
|
||||||
|
|||||||
Reference in New Issue
Block a user