"""Unit tests for VADAgent: interrupt handling, status loop, setup and stream reset."""

from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from control_backend.agents.perception.vad_agent import VADAgent
from control_backend.core.agent_system import InternalMessage
from control_backend.schemas.program_status import PROGRAM_STATUS, ProgramStatus

# Dotted path of the ``settings`` object the VAD agent module reads; patched per test.
SETTINGS_PATH = "control_backend.agents.perception.vad_agent.settings"
# Sender name the patched settings register as the user-interrupt agent.
USER_INTERRUPT_NAME = "user_interrupt_agent"


@pytest.fixture(autouse=True)
def mock_zmq():
    """Patch ``zmq.asyncio.Context`` so no real ZeroMQ context or sockets are created.

    Yields the patched ``Context`` class; ``Context.instance()`` returns a MagicMock.
    """
    with patch("zmq.asyncio.Context") as mock:
        mock.instance.return_value = MagicMock()
        yield mock


@pytest.fixture
def agent():
    """Return a VADAgent for a dummy address, with the second constructor argument False."""
    return VADAgent("tcp://localhost:5555", False)


@pytest.fixture
def interrupt_settings():
    """Patch the agent's settings so USER_INTERRUPT_NAME is the user-interrupt sender.

    The match is needed because the agent compares the message sender against
    ``settings.agent_settings.user_interrupt_name``.
    """
    with patch(SETTINGS_PATH) as mock_settings:
        mock_settings.agent_settings.user_interrupt_name = USER_INTERRUPT_NAME
        yield mock_settings


def _vad_message(body, sender=USER_INTERRUPT_NAME):
    """Build an InternalMessage addressed to the VAD agent with the given body and sender."""
    return InternalMessage(to="vad", sender=sender, body=body)


@pytest.mark.asyncio
async def test_handle_message_pause(agent, interrupt_settings):
    """A PAUSE from the user-interrupt agent clears the pause event and requests a reset."""
    # In the agent this event starts set (not paused); a mock lets us observe the calls.
    agent._paused = MagicMock()

    await agent.handle_message(_vad_message("PAUSE"))

    agent._paused.clear.assert_called_once()
    assert agent._reset_needed is True


@pytest.mark.asyncio
async def test_handle_message_resume(agent, interrupt_settings):
    """A RESUME from the user-interrupt agent sets the pause event."""
    agent._paused = MagicMock()

    await agent.handle_message(_vad_message("RESUME"))

    agent._paused.set.assert_called_once()


@pytest.mark.asyncio
async def test_handle_message_unknown_command(agent, interrupt_settings):
    """An unrecognised command logs a warning and leaves the pause event untouched."""
    agent._paused = MagicMock()
    agent.logger = MagicMock()

    await agent.handle_message(_vad_message("UNKNOWN"))

    agent.logger.warning.assert_called()
    agent._paused.clear.assert_not_called()
    agent._paused.set.assert_not_called()


@pytest.mark.asyncio
async def test_handle_message_unknown_sender(agent, interrupt_settings):
    """Commands from a sender other than the user-interrupt agent are ignored."""
    agent._paused = MagicMock()

    await agent.handle_message(_vad_message("PAUSE", sender="other_agent"))

    agent._paused.clear.assert_not_called()
    agent._paused.set.assert_not_called()


@pytest.mark.asyncio
async def test_status_loop_waits_for_running(agent):
    """_status_loop skips other topics and non-RUNNING statuses, then stops on RUNNING."""
    agent._running = True
    agent.program_sub_socket = AsyncMock()
    agent.program_sub_socket.close = MagicMock()
    agent._reset_stream = AsyncMock()

    # 1. Wrong topic
    # 2. Right topic, wrong status (STARTING)
    # 3. Right topic, RUNNING -> should break the loop
    agent.program_sub_socket.recv_multipart.side_effect = [
        (b"wrong_topic", b"whatever"),
        (PROGRAM_STATUS, ProgramStatus.STARTING.value),
        (PROGRAM_STATUS, ProgramStatus.RUNNING.value),
    ]

    await agent._status_loop()

    # All three messages were consumed, i.e. the loop did not stop before RUNNING.
    assert agent.program_sub_socket.recv_multipart.await_count == 3
    assert agent._reset_stream.await_count == 1
    agent.program_sub_socket.close.assert_called_once()


@pytest.mark.asyncio
async def test_setup_success(agent, mock_zmq):
    """setup() creates its sockets, starts transcription and registers two behaviours."""

    def _close_coro(coro):
        # Close the coroutine handed to add_behavior so it is not left un-awaited.
        coro.close()
        return MagicMock()

    agent.add_behavior = MagicMock(side_effect=_close_coro)

    mock_context = mock_zmq.instance.return_value
    mock_audio_in = MagicMock()
    mock_audio_out = MagicMock()
    mock_program_sub = MagicMock()
    # Expected socket creation order: audio_in (SUB), audio_out (PUB), program_sub (SUB).
    # Distinct mocks let us verify each attribute receives the right socket.
    mock_context.socket.side_effect = [mock_audio_in, mock_audio_out, mock_program_sub]

    with patch("control_backend.agents.perception.vad_agent.torch.hub.load") as mock_load:
        mock_load.return_value = (MagicMock(), None)

        with patch("control_backend.agents.perception.vad_agent.TranscriptionAgent") as MockTrans:
            mock_trans_instance = MockTrans.return_value
            mock_trans_instance.start = AsyncMock()

            await agent.setup()

            mock_trans_instance.start.assert_awaited_once()

    assert agent.add_behavior.call_count == 2  # streaming_loop + status_loop
    assert agent.audio_in_socket is mock_audio_in
    assert agent.audio_out_socket is mock_audio_out
    assert agent.program_sub_socket is mock_program_sub


@pytest.mark.asyncio
async def test_reset_stream(agent):
    """_reset_stream polls until poll() returns None, then sets the ready event once."""
    mock_poller = MagicMock()
    agent.audio_in_poller = mock_poller
    # poll() returns a non-None value twice, then None.
    mock_poller.poll = AsyncMock(side_effect=[b"data", b"data", None])
    agent._ready = MagicMock()

    await agent._reset_stream()

    assert mock_poller.poll.await_count == 3
    agent._ready.set.assert_called_once()