fix: move VAD agent creation to RI communication agent

Previously, it was started in main, but it should use values negotiated by the RI communication agent.

ref: N25B-356
This commit is contained in:
Twirre Meulenbelt
2025-12-03 15:07:29 +01:00
parent c85753f834
commit 21e9d05d6e
6 changed files with 125 additions and 27 deletions

View File

@@ -9,6 +9,7 @@ from control_backend.agents import BaseAgent
from control_backend.core.config import settings from control_backend.core.config import settings
from ..actuation.robot_speech_agent import RobotSpeechAgent from ..actuation.robot_speech_agent import RobotSpeechAgent
from ..perception import VADAgent
class RICommunicationAgent(BaseAgent): class RICommunicationAgent(BaseAgent):
@@ -185,6 +186,9 @@ class RICommunicationAgent(BaseAgent):
bind=bind, bind=bind,
) )
await ri_commands_agent.start() await ri_commands_agent.start()
case "audio":
vad_agent = VADAgent(audio_in_address=addr, audio_in_bind=bind)
await vad_agent.start()
case _: case _:
self.logger.warning("Unhandled negotiation id: %s", id) self.logger.warning("Unhandled negotiation id: %s", id)

View File

@@ -8,6 +8,7 @@ import zmq.asyncio as azmq
from control_backend.agents import BaseAgent from control_backend.agents import BaseAgent
from control_backend.core.config import settings from control_backend.core.config import settings
from ...schemas.program_status import PROGRAM_STATUS, ProgramStatus
from .transcription_agent.transcription_agent import TranscriptionAgent from .transcription_agent.transcription_agent import TranscriptionAgent
@@ -61,6 +62,7 @@ class VADAgent(BaseAgent):
:ivar audio_in_address: Address of the input audio stream. :ivar audio_in_address: Address of the input audio stream.
:ivar audio_in_bind: Whether to bind or connect to the input address. :ivar audio_in_bind: Whether to bind or connect to the input address.
:ivar audio_out_socket: ZMQ PUB socket for sending speech fragments. :ivar audio_out_socket: ZMQ PUB socket for sending speech fragments.
:ivar program_sub_socket: ZMQ SUB socket for receiving program status updates.
""" """
def __init__(self, audio_in_address: str, audio_in_bind: bool): def __init__(self, audio_in_address: str, audio_in_bind: bool):
@@ -79,6 +81,8 @@ class VADAgent(BaseAgent):
self.audio_out_socket: azmq.Socket | None = None self.audio_out_socket: azmq.Socket | None = None
self.audio_in_poller: SocketPoller | None = None self.audio_in_poller: SocketPoller | None = None
self.program_sub_socket: azmq.Socket | None = None
self.audio_buffer = np.array([], dtype=np.float32) self.audio_buffer = np.array([], dtype=np.float32)
self.i_since_speech = settings.behaviour_settings.vad_initial_since_speech self.i_since_speech = settings.behaviour_settings.vad_initial_since_speech
self._ready = asyncio.Event() self._ready = asyncio.Event()
@@ -90,9 +94,10 @@ class VADAgent(BaseAgent):
1. Connects audio input socket. 1. Connects audio input socket.
2. Binds audio output socket (random port). 2. Binds audio output socket (random port).
3. Loads VAD model from Torch Hub. 3. Connects to program communication socket.
4. Starts the streaming loop. 4. Loads VAD model from Torch Hub.
5. Instantiates and starts the :class:`TranscriptionAgent` with the output address. 5. Starts the streaming loop.
6. Instantiates and starts the :class:`TranscriptionAgent` with the output address.
""" """
self.logger.info("Setting up %s", self.name) self.logger.info("Setting up %s", self.name)
@@ -105,6 +110,11 @@ class VADAgent(BaseAgent):
return return
audio_out_address = f"tcp://localhost:{audio_out_port}" audio_out_address = f"tcp://localhost:{audio_out_port}"
# Connect to internal communication socket
self.program_sub_socket = azmq.Context.instance().socket(zmq.SUB)
self.program_sub_socket.connect(settings.zmq_settings.internal_sub_address)
self.program_sub_socket.subscribe(PROGRAM_STATUS)
# Initialize VAD model # Initialize VAD model
try: try:
self.model, _ = torch.hub.load( self.model, _ = torch.hub.load(
@@ -117,10 +127,8 @@ class VADAgent(BaseAgent):
await self.stop() await self.stop()
return return
# Warmup/reset
await self.reset_stream()
self.add_behavior(self._streaming_loop()) self.add_behavior(self._streaming_loop())
self.add_behavior(self._status_loop())
# Start agents dependent on the output audio fragments here # Start agents dependent on the output audio fragments here
transcriber = TranscriptionAgent(audio_out_address) transcriber = TranscriptionAgent(audio_out_address)
@@ -165,7 +173,7 @@ class VADAgent(BaseAgent):
self.audio_out_socket = None self.audio_out_socket = None
return None return None
async def reset_stream(self): async def _reset_stream(self):
""" """
Clears the ZeroMQ queue and sets ready state. Clears the ZeroMQ queue and sets ready state.
""" """
@@ -176,6 +184,23 @@ class VADAgent(BaseAgent):
self.logger.info(f"Discarded {discarded} audio packets before starting.") self.logger.info(f"Discarded {discarded} audio packets before starting.")
self._ready.set() self._ready.set()
async def _status_loop(self):
"""Loop for checking program status. Only start listening if program is RUNNING."""
while self._running:
topic, body = await self.program_sub_socket.recv_multipart()
if topic != PROGRAM_STATUS:
continue
if body != ProgramStatus.RUNNING.value:
continue
# Program is now running, we can start our stream
await self._reset_stream()
# We don't care about further status updates
self.program_sub_socket.close()
break
async def _streaming_loop(self): async def _streaming_loop(self):
""" """
Main loop for processing audio stream. Main loop for processing audio stream.

View File

@@ -17,7 +17,6 @@ class ZMQSettings(BaseModel):
internal_sub_address: str = "tcp://localhost:5561" internal_sub_address: str = "tcp://localhost:5561"
ri_command_address: str = "tcp://localhost:0000" ri_command_address: str = "tcp://localhost:0000"
ri_communication_address: str = "tcp://*:5555" ri_communication_address: str = "tcp://*:5555"
vad_agent_address: str = "tcp://localhost:5558"
class AgentSettings(BaseModel): class AgentSettings(BaseModel):

View File

@@ -39,13 +39,11 @@ from control_backend.agents.communication import RICommunicationAgent
# LLM Agents # LLM Agents
from control_backend.agents.llm import LLMAgent from control_backend.agents.llm import LLMAgent
# Perceive agents
from control_backend.agents.perception import VADAgent
# Other backend imports # Other backend imports
from control_backend.api.v1.router import api_router from control_backend.api.v1.router import api_router
from control_backend.core.config import settings from control_backend.core.config import settings
from control_backend.logging import setup_logging from control_backend.logging import setup_logging
from control_backend.schemas.program_status import PROGRAM_STATUS, ProgramStatus
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -95,6 +93,8 @@ async def lifespan(app: FastAPI):
endpoints_pub_socket.connect(settings.zmq_settings.internal_pub_address) endpoints_pub_socket.connect(settings.zmq_settings.internal_pub_address)
app.state.endpoints_pub_socket = endpoints_pub_socket app.state.endpoints_pub_socket = endpoints_pub_socket
await endpoints_pub_socket.send_multipart([PROGRAM_STATUS, ProgramStatus.STARTING.value])
# --- Initialize Agents --- # --- Initialize Agents ---
logger.info("Initializing and starting agents.") logger.info("Initializing and starting agents.")
@@ -132,10 +132,6 @@ async def lifespan(app: FastAPI):
"name": settings.agent_settings.text_belief_extractor_name, "name": settings.agent_settings.text_belief_extractor_name,
}, },
), ),
"VADAgent": (
VADAgent,
{"audio_in_address": settings.zmq_settings.vad_agent_address, "audio_in_bind": False},
),
"ProgramManagerAgent": ( "ProgramManagerAgent": (
BDIProgramManager, BDIProgramManager,
{ {
@@ -146,32 +142,28 @@ async def lifespan(app: FastAPI):
agents = [] agents = []
vad_agent = None
for name, (agent_class, kwargs) in agents_to_start.items(): for name, (agent_class, kwargs) in agents_to_start.items():
try: try:
logger.debug("Starting agent: %s", name) logger.debug("Starting agent: %s", name)
agent_instance = agent_class(**kwargs) agent_instance = agent_class(**kwargs)
await agent_instance.start() await agent_instance.start()
if isinstance(agent_instance, VADAgent):
vad_agent = agent_instance
agents.append(agent_instance) agents.append(agent_instance)
logger.info("Agent '%s' started successfully.", name) logger.info("Agent '%s' started successfully.", name)
except Exception as e: except Exception as e:
logger.error("Failed to start agent '%s': %s", name, e, exc_info=True) logger.error("Failed to start agent '%s': %s", name, e, exc_info=True)
raise raise
assert vad_agent is not None
await vad_agent.reset_stream()
logger.info("Application startup complete.") logger.info("Application startup complete.")
await endpoints_pub_socket.send_multipart([PROGRAM_STATUS, ProgramStatus.RUNNING.value])
yield yield
# --- APPLICATION SHUTDOWN --- # --- APPLICATION SHUTDOWN ---
logger.info("%s is shutting down.", app.title) logger.info("%s is shutting down.", app.title)
# Potential shutdown logic goes here await endpoints_pub_socket.send_multipart([PROGRAM_STATUS, ProgramStatus.STOPPING.value])
# Additional shutdown logic goes here
logger.info("Application shutdown complete.") logger.info("Application shutdown complete.")

View File

@@ -0,0 +1,16 @@
from enum import Enum
PROGRAM_STATUS = b"internal/program_status"
"""A topic key for the program status."""
class ProgramStatus(Enum):
"""
Used in internal communication, to tell agents what the status of the program is.
For example, the VAD agent only starts listening when the program is RUNNING.
"""
STARTING = b"starting"
RUNNING = b"running"
STOPPING = b"stopping"

View File

@@ -5,6 +5,7 @@ import pytest
import zmq import zmq
from control_backend.agents.perception.vad_agent import VADAgent from control_backend.agents.perception.vad_agent import VADAgent
from control_backend.schemas.program_status import PROGRAM_STATUS, ProgramStatus
@pytest.fixture @pytest.fixture
@@ -43,14 +44,12 @@ async def test_normal_setup(per_transcription_agent):
coro.close() coro.close()
per_vad_agent.add_behavior = swallow_background_task per_vad_agent.add_behavior = swallow_background_task
per_vad_agent.reset_stream = AsyncMock()
await per_vad_agent.setup() await per_vad_agent.setup()
per_transcription_agent.assert_called_once() per_transcription_agent.assert_called_once()
per_transcription_agent.return_value.start.assert_called_once() per_transcription_agent.return_value.start.assert_called_once()
per_vad_agent._streaming_loop.assert_called_once() per_vad_agent._streaming_loop.assert_called_once()
per_vad_agent.reset_stream.assert_called_once()
assert per_vad_agent.audio_in_socket is not None assert per_vad_agent.audio_in_socket is not None
assert per_vad_agent.audio_out_socket is not None assert per_vad_agent.audio_out_socket is not None
@@ -103,7 +102,7 @@ async def test_out_socket_creation_failure(zmq_context):
zmq_context.return_value.socket.return_value.bind_to_random_port.side_effect = zmq.ZMQBindError zmq_context.return_value.socket.return_value.bind_to_random_port.side_effect = zmq.ZMQBindError
per_vad_agent = VADAgent("tcp://localhost:12345", False) per_vad_agent = VADAgent("tcp://localhost:12345", False)
per_vad_agent.stop = AsyncMock() per_vad_agent.stop = AsyncMock()
per_vad_agent.reset_stream = AsyncMock() per_vad_agent._reset_stream = AsyncMock()
per_vad_agent._streaming_loop = AsyncMock() per_vad_agent._streaming_loop = AsyncMock()
per_vad_agent._connect_audio_out_socket = MagicMock(return_value=None) per_vad_agent._connect_audio_out_socket = MagicMock(return_value=None)
@@ -124,7 +123,7 @@ async def test_stop(zmq_context, per_transcription_agent):
Test that when the VAD agent is stopped, the sockets are closed correctly. Test that when the VAD agent is stopped, the sockets are closed correctly.
""" """
per_vad_agent = VADAgent("tcp://localhost:12345", False) per_vad_agent = VADAgent("tcp://localhost:12345", False)
per_vad_agent.reset_stream = AsyncMock() per_vad_agent._reset_stream = AsyncMock()
per_vad_agent._streaming_loop = AsyncMock() per_vad_agent._streaming_loop = AsyncMock()
async def swallow_background_task(coro): async def swallow_background_task(coro):
@@ -142,3 +141,66 @@ async def test_stop(zmq_context, per_transcription_agent):
assert zmq_context.return_value.socket.return_value.close.call_count == 2 assert zmq_context.return_value.socket.return_value.close.call_count == 2
assert per_vad_agent.audio_in_socket is None assert per_vad_agent.audio_in_socket is None
assert per_vad_agent.audio_out_socket is None assert per_vad_agent.audio_out_socket is None
@pytest.mark.asyncio
async def test_application_startup_complete(zmq_context):
"""Check that it resets the stream when the program finishes startup."""
vad_agent = VADAgent("tcp://localhost:12345", False)
vad_agent._running = True
vad_agent._reset_stream = AsyncMock()
vad_agent.program_sub_socket = AsyncMock()
vad_agent.program_sub_socket.recv_multipart.side_effect = [
(PROGRAM_STATUS, ProgramStatus.RUNNING.value),
]
await vad_agent._status_loop()
vad_agent._reset_stream.assert_called_once()
vad_agent.program_sub_socket.close.assert_called_once()
@pytest.mark.asyncio
async def test_application_other_status(zmq_context):
"""
Check that it does nothing when the internal communication message is a status update, but not
running.
"""
vad_agent = VADAgent("tcp://localhost:12345", False)
vad_agent._running = True
vad_agent._reset_stream = AsyncMock()
vad_agent.program_sub_socket = AsyncMock()
vad_agent.program_sub_socket.recv_multipart.side_effect = [
(PROGRAM_STATUS, ProgramStatus.STARTING.value),
(PROGRAM_STATUS, ProgramStatus.STOPPING.value),
]
try:
# Raises StopAsyncIteration the third time it calls `program_sub_socket.recv_multipart`
await vad_agent._status_loop()
except StopAsyncIteration:
pass
vad_agent._reset_stream.assert_not_called()
@pytest.mark.asyncio
async def test_application_message_other(zmq_context):
"""
Check that it does nothing when there's an internal communication message that is not a status
update.
"""
vad_agent = VADAgent("tcp://localhost:12345", False)
vad_agent._running = True
vad_agent._reset_stream = AsyncMock()
vad_agent.program_sub_socket = AsyncMock()
vad_agent.program_sub_socket.recv_multipart.side_effect = [(b"internal/other", b"Whatever")]
try:
# Raises StopAsyncIteration the second time it calls `program_sub_socket.recv_multipart`
await vad_agent._status_loop()
except StopAsyncIteration:
pass
vad_agent._reset_stream.assert_not_called()