fix: move VAD agent creation to RI communication agent
Previously, the VAD agent was started in main, but it should use the values negotiated by the RI communication agent. ref: N25B-356
This commit is contained in:
@@ -9,6 +9,7 @@ from control_backend.agents import BaseAgent
|
||||
from control_backend.core.config import settings
|
||||
|
||||
from ..actuation.robot_speech_agent import RobotSpeechAgent
|
||||
from ..perception import VADAgent
|
||||
|
||||
|
||||
class RICommunicationAgent(BaseAgent):
|
||||
@@ -185,6 +186,9 @@ class RICommunicationAgent(BaseAgent):
|
||||
bind=bind,
|
||||
)
|
||||
await ri_commands_agent.start()
|
||||
case "audio":
|
||||
vad_agent = VADAgent(audio_in_address=addr, audio_in_bind=bind)
|
||||
await vad_agent.start()
|
||||
case _:
|
||||
self.logger.warning("Unhandled negotiation id: %s", id)
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@ import zmq.asyncio as azmq
|
||||
from control_backend.agents import BaseAgent
|
||||
from control_backend.core.config import settings
|
||||
|
||||
from ...schemas.program_status import PROGRAM_STATUS, ProgramStatus
|
||||
from .transcription_agent.transcription_agent import TranscriptionAgent
|
||||
|
||||
|
||||
@@ -61,6 +62,7 @@ class VADAgent(BaseAgent):
|
||||
:ivar audio_in_address: Address of the input audio stream.
|
||||
:ivar audio_in_bind: Whether to bind or connect to the input address.
|
||||
:ivar audio_out_socket: ZMQ PUB socket for sending speech fragments.
|
||||
:ivar program_sub_socket: ZMQ SUB socket for receiving program status updates.
|
||||
"""
|
||||
|
||||
def __init__(self, audio_in_address: str, audio_in_bind: bool):
|
||||
@@ -79,6 +81,8 @@ class VADAgent(BaseAgent):
|
||||
self.audio_out_socket: azmq.Socket | None = None
|
||||
self.audio_in_poller: SocketPoller | None = None
|
||||
|
||||
self.program_sub_socket: azmq.Socket | None = None
|
||||
|
||||
self.audio_buffer = np.array([], dtype=np.float32)
|
||||
self.i_since_speech = settings.behaviour_settings.vad_initial_since_speech
|
||||
self._ready = asyncio.Event()
|
||||
@@ -90,9 +94,10 @@ class VADAgent(BaseAgent):
|
||||
|
||||
1. Connects audio input socket.
|
||||
2. Binds audio output socket (random port).
|
||||
3. Loads VAD model from Torch Hub.
|
||||
4. Starts the streaming loop.
|
||||
5. Instantiates and starts the :class:`TranscriptionAgent` with the output address.
|
||||
3. Connects to program communication socket.
|
||||
4. Loads VAD model from Torch Hub.
|
||||
5. Starts the streaming loop.
|
||||
6. Instantiates and starts the :class:`TranscriptionAgent` with the output address.
|
||||
"""
|
||||
self.logger.info("Setting up %s", self.name)
|
||||
|
||||
@@ -105,6 +110,11 @@ class VADAgent(BaseAgent):
|
||||
return
|
||||
audio_out_address = f"tcp://localhost:{audio_out_port}"
|
||||
|
||||
# Connect to internal communication socket
|
||||
self.program_sub_socket = azmq.Context.instance().socket(zmq.SUB)
|
||||
self.program_sub_socket.connect(settings.zmq_settings.internal_sub_address)
|
||||
self.program_sub_socket.subscribe(PROGRAM_STATUS)
|
||||
|
||||
# Initialize VAD model
|
||||
try:
|
||||
self.model, _ = torch.hub.load(
|
||||
@@ -117,10 +127,8 @@ class VADAgent(BaseAgent):
|
||||
await self.stop()
|
||||
return
|
||||
|
||||
# Warmup/reset
|
||||
await self.reset_stream()
|
||||
|
||||
self.add_behavior(self._streaming_loop())
|
||||
self.add_behavior(self._status_loop())
|
||||
|
||||
# Start agents dependent on the output audio fragments here
|
||||
transcriber = TranscriptionAgent(audio_out_address)
|
||||
@@ -165,7 +173,7 @@ class VADAgent(BaseAgent):
|
||||
self.audio_out_socket = None
|
||||
return None
|
||||
|
||||
async def reset_stream(self):
|
||||
async def _reset_stream(self):
|
||||
"""
|
||||
Clears the ZeroMQ queue and sets ready state.
|
||||
"""
|
||||
@@ -176,6 +184,23 @@ class VADAgent(BaseAgent):
|
||||
self.logger.info(f"Discarded {discarded} audio packets before starting.")
|
||||
self._ready.set()
|
||||
|
||||
async def _status_loop(self):
    """
    Wait for the program to report that it is RUNNING, then start the audio stream.

    Reads multipart messages from ``program_sub_socket`` for as long as the agent is running.
    Messages on other topics, and status updates other than RUNNING, are ignored. On the first
    RUNNING update the stream is reset, the subscription socket is closed and the loop ends.
    """
    while self._running:
        frames = await self.program_sub_socket.recv_multipart()
        topic, body = frames

        # Only a program status message announcing RUNNING is of interest here.
        announces_running = topic == PROGRAM_STATUS and body == ProgramStatus.RUNNING.value
        if not announces_running:
            continue

        # The program is up: reset the stream so that listening can begin.
        await self._reset_stream()

        # Later status updates are irrelevant to this agent, so stop subscribing.
        self.program_sub_socket.close()
        return
|
||||
|
||||
async def _streaming_loop(self):
|
||||
"""
|
||||
Main loop for processing audio stream.
|
||||
|
||||
@@ -17,7 +17,6 @@ class ZMQSettings(BaseModel):
|
||||
internal_sub_address: str = "tcp://localhost:5561"
|
||||
ri_command_address: str = "tcp://localhost:0000"
|
||||
ri_communication_address: str = "tcp://*:5555"
|
||||
vad_agent_address: str = "tcp://localhost:5558"
|
||||
|
||||
|
||||
class AgentSettings(BaseModel):
|
||||
|
||||
@@ -39,13 +39,11 @@ from control_backend.agents.communication import RICommunicationAgent
|
||||
# LLM Agents
|
||||
from control_backend.agents.llm import LLMAgent
|
||||
|
||||
# Perceive agents
|
||||
from control_backend.agents.perception import VADAgent
|
||||
|
||||
# Other backend imports
|
||||
from control_backend.api.v1.router import api_router
|
||||
from control_backend.core.config import settings
|
||||
from control_backend.logging import setup_logging
|
||||
from control_backend.schemas.program_status import PROGRAM_STATUS, ProgramStatus
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -95,6 +93,8 @@ async def lifespan(app: FastAPI):
|
||||
endpoints_pub_socket.connect(settings.zmq_settings.internal_pub_address)
|
||||
app.state.endpoints_pub_socket = endpoints_pub_socket
|
||||
|
||||
await endpoints_pub_socket.send_multipart([PROGRAM_STATUS, ProgramStatus.STARTING.value])
|
||||
|
||||
# --- Initialize Agents ---
|
||||
logger.info("Initializing and starting agents.")
|
||||
|
||||
@@ -132,10 +132,6 @@ async def lifespan(app: FastAPI):
|
||||
"name": settings.agent_settings.text_belief_extractor_name,
|
||||
},
|
||||
),
|
||||
"VADAgent": (
|
||||
VADAgent,
|
||||
{"audio_in_address": settings.zmq_settings.vad_agent_address, "audio_in_bind": False},
|
||||
),
|
||||
"ProgramManagerAgent": (
|
||||
BDIProgramManager,
|
||||
{
|
||||
@@ -146,32 +142,28 @@ async def lifespan(app: FastAPI):
|
||||
|
||||
agents = []
|
||||
|
||||
vad_agent = None
|
||||
|
||||
for name, (agent_class, kwargs) in agents_to_start.items():
|
||||
try:
|
||||
logger.debug("Starting agent: %s", name)
|
||||
agent_instance = agent_class(**kwargs)
|
||||
await agent_instance.start()
|
||||
if isinstance(agent_instance, VADAgent):
|
||||
vad_agent = agent_instance
|
||||
agents.append(agent_instance)
|
||||
logger.info("Agent '%s' started successfully.", name)
|
||||
except Exception as e:
|
||||
logger.error("Failed to start agent '%s': %s", name, e, exc_info=True)
|
||||
raise
|
||||
|
||||
assert vad_agent is not None
|
||||
await vad_agent.reset_stream()
|
||||
|
||||
logger.info("Application startup complete.")
|
||||
|
||||
await endpoints_pub_socket.send_multipart([PROGRAM_STATUS, ProgramStatus.RUNNING.value])
|
||||
|
||||
yield
|
||||
|
||||
# --- APPLICATION SHUTDOWN ---
|
||||
logger.info("%s is shutting down.", app.title)
|
||||
|
||||
# Potential shutdown logic goes here
|
||||
await endpoints_pub_socket.send_multipart([PROGRAM_STATUS, ProgramStatus.STOPPING.value])
|
||||
# Additional shutdown logic goes here
|
||||
|
||||
logger.info("Application shutdown complete.")
|
||||
|
||||
|
||||
16
src/control_backend/schemas/program_status.py
Normal file
16
src/control_backend/schemas/program_status.py
Normal file
@@ -0,0 +1,16 @@
|
||||
from enum import Enum

#: Topic key under which program status updates are sent on the internal communication sockets.
PROGRAM_STATUS = b"internal/program_status"


class ProgramStatus(Enum):
    """
    Status of the program, shared with agents through internal communication.

    Agents can use it to decide when to act; for example, the VAD agent only starts listening
    once the program is RUNNING. Each member's value is the raw message body sent alongside
    the :data:`PROGRAM_STATUS` topic.
    """

    STARTING = b"starting"
    RUNNING = b"running"
    STOPPING = b"stopping"
|
||||
@@ -5,6 +5,7 @@ import pytest
|
||||
import zmq
|
||||
|
||||
from control_backend.agents.perception.vad_agent import VADAgent
|
||||
from control_backend.schemas.program_status import PROGRAM_STATUS, ProgramStatus
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -43,14 +44,12 @@ async def test_normal_setup(per_transcription_agent):
|
||||
coro.close()
|
||||
|
||||
per_vad_agent.add_behavior = swallow_background_task
|
||||
per_vad_agent.reset_stream = AsyncMock()
|
||||
|
||||
await per_vad_agent.setup()
|
||||
|
||||
per_transcription_agent.assert_called_once()
|
||||
per_transcription_agent.return_value.start.assert_called_once()
|
||||
per_vad_agent._streaming_loop.assert_called_once()
|
||||
per_vad_agent.reset_stream.assert_called_once()
|
||||
assert per_vad_agent.audio_in_socket is not None
|
||||
assert per_vad_agent.audio_out_socket is not None
|
||||
|
||||
@@ -103,7 +102,7 @@ async def test_out_socket_creation_failure(zmq_context):
|
||||
zmq_context.return_value.socket.return_value.bind_to_random_port.side_effect = zmq.ZMQBindError
|
||||
per_vad_agent = VADAgent("tcp://localhost:12345", False)
|
||||
per_vad_agent.stop = AsyncMock()
|
||||
per_vad_agent.reset_stream = AsyncMock()
|
||||
per_vad_agent._reset_stream = AsyncMock()
|
||||
per_vad_agent._streaming_loop = AsyncMock()
|
||||
per_vad_agent._connect_audio_out_socket = MagicMock(return_value=None)
|
||||
|
||||
@@ -124,7 +123,7 @@ async def test_stop(zmq_context, per_transcription_agent):
|
||||
Test that when the VAD agent is stopped, the sockets are closed correctly.
|
||||
"""
|
||||
per_vad_agent = VADAgent("tcp://localhost:12345", False)
|
||||
per_vad_agent.reset_stream = AsyncMock()
|
||||
per_vad_agent._reset_stream = AsyncMock()
|
||||
per_vad_agent._streaming_loop = AsyncMock()
|
||||
|
||||
async def swallow_background_task(coro):
|
||||
@@ -142,3 +141,66 @@ async def test_stop(zmq_context, per_transcription_agent):
|
||||
assert zmq_context.return_value.socket.return_value.close.call_count == 2
|
||||
assert per_vad_agent.audio_in_socket is None
|
||||
assert per_vad_agent.audio_out_socket is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_application_startup_complete(zmq_context):
    """The status loop resets the stream and closes its socket once the program is RUNNING."""
    running_update = (PROGRAM_STATUS, ProgramStatus.RUNNING.value)

    # A single RUNNING update is all the subscription socket will ever deliver.
    status_socket = AsyncMock()
    status_socket.recv_multipart.side_effect = [running_update]

    agent = VADAgent("tcp://localhost:12345", False)
    agent._running = True
    agent._reset_stream = AsyncMock()
    agent.program_sub_socket = status_socket

    await agent._status_loop()

    agent._reset_stream.assert_called_once()
    status_socket.close.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_application_other_status(zmq_context):
    """
    Check that it does nothing when the internal communication message is a status update, but not
    running.

    The loop must keep listening after each non-RUNNING update, so it is expected to consume both
    queued messages and only stop when the mocked socket runs out of messages.
    """
    vad_agent = VADAgent("tcp://localhost:12345", False)
    vad_agent._running = True
    vad_agent._reset_stream = AsyncMock()
    vad_agent.program_sub_socket = AsyncMock()

    vad_agent.program_sub_socket.recv_multipart.side_effect = [
        (PROGRAM_STATUS, ProgramStatus.STARTING.value),
        (PROGRAM_STATUS, ProgramStatus.STOPPING.value),
    ]

    # The mock raises StopAsyncIteration on the third `recv_multipart` call. Requiring that
    # exception (instead of merely tolerating it) proves the loop did not exit early.
    with pytest.raises(StopAsyncIteration):
        await vad_agent._status_loop()

    vad_agent._reset_stream.assert_not_called()
    # The subscription must stay open while the program has not reached RUNNING.
    vad_agent.program_sub_socket.close.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_application_message_other(zmq_context):
    """
    Check that it does nothing when there's an internal communication message that is not a status
    update.

    The loop must skip the unrelated message and keep listening, so it is expected to only stop
    when the mocked socket runs out of messages.
    """
    vad_agent = VADAgent("tcp://localhost:12345", False)
    vad_agent._running = True
    vad_agent._reset_stream = AsyncMock()
    vad_agent.program_sub_socket = AsyncMock()

    vad_agent.program_sub_socket.recv_multipart.side_effect = [(b"internal/other", b"Whatever")]

    # The mock raises StopAsyncIteration on the second `recv_multipart` call. Requiring that
    # exception (instead of merely tolerating it) proves the loop did not exit early.
    with pytest.raises(StopAsyncIteration):
        await vad_agent._status_loop()

    vad_agent._reset_stream.assert_not_called()
    # The subscription must stay open while the program has not reached RUNNING.
    vad_agent.program_sub_socket.close.assert_not_called()
|
||||
|
||||
Reference in New Issue
Block a user