153 lines
4.7 KiB
Python
153 lines
4.7 KiB
Python
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
import pytest
|
|
|
|
from control_backend.agents.perception.vad_agent import VADAgent
|
|
from control_backend.core.agent_system import InternalMessage
|
|
from control_backend.schemas.program_status import PROGRAM_STATUS, ProgramStatus
|
|
|
|
|
|
@pytest.fixture(autouse=True)
def mock_zmq():
    """Patch ``zmq.asyncio.Context`` for every test in this module.

    The patched class's ``instance()`` returns a ``MagicMock``. Code that looks
    the context up through ``zmq.asyncio.Context`` therefore gets a fake
    context instead of a real one. The patch object is yielded so tests can
    reach the fake context via ``mock_zmq.instance.return_value``.
    """
    context_patcher = patch("zmq.asyncio.Context")
    fake_context_cls = context_patcher.start()
    fake_context_cls.instance.return_value = MagicMock()
    try:
        yield fake_context_cls
    finally:
        context_patcher.stop()
|
|
|
|
|
|
@pytest.fixture
def agent():
    """Return a fresh ``VADAgent`` for a single test.

    The first argument is presumably the audio input endpoint (TODO confirm
    against ``VADAgent.__init__``). The second positional argument is passed
    as ``False``; its meaning is defined by ``VADAgent`` and is not visible here.
    """
    endpoint = "tcp://localhost:5555"
    return VADAgent(endpoint, False)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_handle_message_pause(agent):
    """A PAUSE from the user-interrupt agent clears ``_paused`` and flags a reset."""
    # In the real agent this event starts set (i.e. not paused). Here it is
    # replaced by a mock so the calls made on it can be asserted.
    agent._paused = MagicMock()
    # Start from a known state so the final assertion proves that handling the
    # message set the flag, rather than passing because it was already True.
    agent._reset_needed = False

    msg = InternalMessage(to="vad", sender="user_interrupt_agent", body="PAUSE")

    # Mock settings so the configured user-interrupt name matches the sender.
    with patch("control_backend.agents.perception.vad_agent.settings") as mock_settings:
        mock_settings.agent_settings.user_interrupt_name = "user_interrupt_agent"

        await agent.handle_message(msg)

    agent._paused.clear.assert_called_once()
    # Pausing must not simultaneously resume.
    agent._paused.set.assert_not_called()
    assert agent._reset_needed is True
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_handle_message_resume(agent):
    """A RESUME command from the user-interrupt agent sets the ``_paused`` event."""
    agent._paused = MagicMock()
    resume_msg = InternalMessage(to="vad", sender="user_interrupt_agent", body="RESUME")

    settings_target = "control_backend.agents.perception.vad_agent.settings"
    with patch(settings_target) as fake_settings:
        # Make the configured interrupt-agent name equal the message sender.
        fake_settings.agent_settings.user_interrupt_name = "user_interrupt_agent"
        await agent.handle_message(resume_msg)

    agent._paused.set.assert_called_once()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_handle_message_unknown_command(agent):
    """An unrecognised command body must leave the ``_paused`` event untouched."""
    agent._paused = MagicMock()
    # Replace the logger with a mock, presumably so any message the agent
    # emits about the unknown command is swallowed.
    agent.logger = MagicMock()
    bogus = InternalMessage(to="vad", sender="user_interrupt_agent", body="UNKNOWN")

    with patch("control_backend.agents.perception.vad_agent.settings") as mock_settings:
        mock_settings.agent_settings.user_interrupt_name = "user_interrupt_agent"
        await agent.handle_message(bogus)

    for method_name in ("clear", "set"):
        getattr(agent._paused, method_name).assert_not_called()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_handle_message_unknown_sender(agent):
    """Commands from anyone other than the user-interrupt agent are ignored."""
    agent._paused = MagicMock()
    msg = InternalMessage(to="vad", sender="other_agent", body="PAUSE")

    # Configure the interrupt-agent name so that it does NOT match the sender.
    with patch("control_backend.agents.perception.vad_agent.settings") as mock_settings:
        mock_settings.agent_settings.user_interrupt_name = "user_interrupt_agent"

        await agent.handle_message(msg)

    # A message from an unrecognised sender must neither pause nor resume,
    # mirroring the checks in the unknown-command test.
    agent._paused.clear.assert_not_called()
    agent._paused.set.assert_not_called()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_status_loop_waits_for_running(agent):
    """``_status_loop`` skips messages until RUNNING arrives, then resets once and exits."""
    agent._running = True
    agent.program_sub_socket = AsyncMock()
    # close() is asserted with assert_called_once (not awaited), so it gets a
    # plain MagicMock instead of the AsyncMock default.
    agent.program_sub_socket.close = MagicMock()
    agent._reset_stream = AsyncMock()

    # Sequence of messages:
    # 1. Wrong topic
    # 2. Right topic, wrong status (STARTING)
    # 3. Right topic, RUNNING -> should break the loop
    agent.program_sub_socket.recv_multipart.side_effect = [
        (b"wrong_topic", b"whatever"),
        (PROGRAM_STATUS, ProgramStatus.STARTING.value),
        (PROGRAM_STATUS, ProgramStatus.RUNNING.value),
    ]

    await agent._status_loop()

    # The loop must actually wait: all three messages are consumed, so it
    # did not stop early on an irrelevant message.
    assert agent.program_sub_socket.recv_multipart.await_count == 3
    assert agent._reset_stream.await_count == 1
    agent.program_sub_socket.close.assert_called_once()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_setup_success(agent, mock_zmq):
    """``setup`` creates its sockets, starts transcription and registers two behaviours."""

    def _discard_coroutine(coro):
        # add_behavior receives coroutine objects; close them so they are never
        # left un-awaited, and hand back a dummy handle.
        coro.close()
        return MagicMock()

    agent.add_behavior = MagicMock(side_effect=_discard_coroutine)

    sub_socket, pub_socket = MagicMock(), MagicMock()
    # Sockets are expected in this order:
    #   1. audio_in (SUB)  2. audio_out (PUB)  3. program_sub (SUB)
    # The same SUB mock is handed out for both SUB sockets.
    mock_zmq.instance.return_value.socket.side_effect = [sub_socket, pub_socket, sub_socket]

    torch_load_target = "control_backend.agents.perception.vad_agent.torch.hub.load"
    transcription_target = "control_backend.agents.perception.vad_agent.TranscriptionAgent"
    with patch(torch_load_target) as mock_load, patch(transcription_target) as MockTrans:
        mock_load.return_value = (MagicMock(), None)
        transcription = MockTrans.return_value
        transcription.start = AsyncMock()

        await agent.setup()

        transcription.start.assert_awaited_once()

    # streaming_loop + status_loop
    assert agent.add_behavior.call_count == 2
    for socket_attr in ("audio_in_socket", "audio_out_socket", "program_sub_socket"):
        assert getattr(agent, socket_attr) is not None
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_reset_stream(agent):
    """``_reset_stream`` keeps polling while data is reported, then sets ``_ready``."""
    poller = MagicMock()
    # poll(1) reports pending data twice, then nothing.
    poll_results = [b"data", b"data", None]
    poller.poll = AsyncMock(side_effect=poll_results)
    agent.audio_in_poller = poller
    agent._ready = MagicMock()

    await agent._reset_stream()

    assert poller.poll.await_count == len(poll_results)
    agent._ready.set.assert_called_once()
|