"""Unit tests for VADAgent: setup, socket creation, shutdown and the program-status loop."""

import contextlib
import random
from unittest.mock import AsyncMock, MagicMock

import pytest
import zmq

from control_backend.agents.perception.vad_agent import VADAgent
from control_backend.schemas.program_status import PROGRAM_STATUS, ProgramStatus


def _swallow_background_task(coro):
    """
    Stand-in for ``VADAgent.add_behavior`` that closes the coroutine instead of scheduling it.

    Closing the coroutine keeps background loops from running during a test and avoids a
    "coroutine was never awaited" warning.
    """
    coro.close()


@pytest.fixture
def zmq_context(mocker):
    """Patch the async ZMQ context used by the VAD agent so the sockets it hands out are mocks."""
    mock_context = mocker.patch("control_backend.agents.perception.vad_agent.azmq.Context.instance")
    mock_context.return_value = MagicMock()
    return mock_context


@pytest.fixture
def per_transcription_agent(mocker):
    """Replace the TranscriptionAgent class used by the VAD agent with an autospec mock."""
    return mocker.patch(
        "control_backend.agents.perception.vad_agent.TranscriptionAgent", autospec=True
    )


@pytest.fixture(autouse=True)
def torch_load(mocker):
    """
    Mock torch for every test so no real VAD model is loaded.

    ``torch.hub.load`` returns a ``(model, None)`` pair, and ``torch.from_numpy`` returns its
    argument unchanged.
    """
    mock_torch = mocker.patch("control_backend.agents.perception.vad_agent.torch")
    model = MagicMock()
    mock_torch.hub.load.return_value = (model, None)
    mock_torch.from_numpy.side_effect = lambda arr: arr
    return mock_torch


@pytest.mark.asyncio
async def test_normal_setup(zmq_context, per_transcription_agent):
    """
    Test normal setup of the VAD agent.

    Setup should start the streaming loop, create both audio sockets and start the
    TranscriptionAgent, without loading real models or opening real ZMQ sockets.
    """
    per_vad_agent = VADAgent("tcp://localhost:12345", False)
    per_vad_agent._streaming_loop = AsyncMock()
    per_vad_agent.add_behavior = _swallow_background_task

    await per_vad_agent.setup()

    per_transcription_agent.assert_called_once()
    per_transcription_agent.return_value.start.assert_called_once()
    per_vad_agent._streaming_loop.assert_called_once()
    assert per_vad_agent.audio_in_socket is not None
    assert per_vad_agent.audio_out_socket is not None


@pytest.mark.parametrize("do_bind", [True, False])
def test_in_socket_creation(zmq_context, do_bind: bool):
    """
    Test that the VAD agent creates a SUB audio input socket subscribed to all topics.

    The socket should be bound when ``do_bind`` is True and connected otherwise.
    """
    per_vad_agent = VADAgent(f"tcp://{'*' if do_bind else 'localhost'}:12345", do_bind)

    per_vad_agent._connect_audio_in_socket()

    assert per_vad_agent.audio_in_socket is not None
    zmq_context.return_value.socket.assert_called_once_with(zmq.SUB)
    zmq_context.return_value.socket.return_value.setsockopt_string.assert_called_once_with(
        zmq.SUBSCRIBE,
        "",
    )

    if do_bind:
        zmq_context.return_value.socket.return_value.bind.assert_called_once_with("tcp://*:12345")
    else:
        zmq_context.return_value.socket.return_value.connect.assert_called_once_with(
            "tcp://localhost:12345"
        )


def test_out_socket_creation(zmq_context):
    """Test that the VAD agent creates a PUB audio output socket bound to the inproc stream."""
    per_vad_agent = VADAgent("tcp://localhost:12345", False)

    per_vad_agent._connect_audio_out_socket()

    assert per_vad_agent.audio_out_socket is not None
    zmq_context.return_value.socket.assert_called_once_with(zmq.PUB)
    zmq_context.return_value.socket.return_value.bind.assert_called_once_with("inproc://vad_stream")


@pytest.mark.asyncio
async def test_out_socket_creation_failure(zmq_context):
    """Test that setup stops the agent when the audio output socket cannot be created."""
    # NOTE(review): _connect_audio_out_socket is stubbed below, so this side effect is likely
    # never exercised — confirm and drop it if so.
    zmq_context.return_value.socket.return_value.bind_to_random_port.side_effect = zmq.ZMQBindError
    per_vad_agent = VADAgent("tcp://localhost:12345", False)
    per_vad_agent.stop = AsyncMock()
    per_vad_agent._reset_stream = AsyncMock()
    per_vad_agent._streaming_loop = AsyncMock()
    # Stubbed so that no output socket is ever assigned, simulating a failed socket creation.
    per_vad_agent._connect_audio_out_socket = MagicMock(return_value=None)
    per_vad_agent.add_behavior = _swallow_background_task

    await per_vad_agent.setup()

    assert per_vad_agent.audio_out_socket is None
    per_vad_agent.stop.assert_called_once()


@pytest.mark.asyncio
async def test_stop(zmq_context, per_transcription_agent):
    """Test that stopping the VAD agent closes its sockets and clears both audio sockets."""
    per_vad_agent = VADAgent("tcp://localhost:12345", False)
    per_vad_agent._reset_stream = AsyncMock()
    per_vad_agent._streaming_loop = AsyncMock()
    per_vad_agent.add_behavior = _swallow_background_task
    # Arbitrary port; its value is never asserted.
    zmq_context.return_value.socket.return_value.bind_to_random_port.return_value = random.randint(
        1000,
        10000,
    )

    await per_vad_agent.setup()
    await per_vad_agent.stop()

    # Every mocked socket is the same object, so this counts close() calls across all of them.
    assert zmq_context.return_value.socket.return_value.close.call_count == 2
    assert per_vad_agent.audio_in_socket is None
    assert per_vad_agent.audio_out_socket is None


@pytest.mark.asyncio
async def test_application_startup_complete(zmq_context):
    """Check that it resets the stream when the program finishes startup."""
    vad_agent = VADAgent("tcp://localhost:12345", False)
    vad_agent._running = True
    vad_agent._reset_stream = AsyncMock()
    vad_agent.program_sub_socket = AsyncMock()
    # close() is invoked without being awaited, so it must be a plain MagicMock.
    vad_agent.program_sub_socket.close = MagicMock()
    vad_agent.program_sub_socket.recv_multipart.side_effect = [
        (PROGRAM_STATUS, ProgramStatus.RUNNING.value),
    ]

    await vad_agent._status_loop()

    vad_agent._reset_stream.assert_called_once()
    vad_agent.program_sub_socket.close.assert_called_once()


@pytest.mark.asyncio
async def test_application_other_status(zmq_context):
    """
    Check that it does nothing when the internal communication message is a status update,
    but not running.
    """
    vad_agent = VADAgent("tcp://localhost:12345", False)
    vad_agent._running = True
    vad_agent._reset_stream = AsyncMock()
    vad_agent.program_sub_socket = AsyncMock()
    # Same non-awaited close() as in test_application_startup_complete.
    vad_agent.program_sub_socket.close = MagicMock()
    vad_agent.program_sub_socket.recv_multipart.side_effect = [
        (PROGRAM_STATUS, ProgramStatus.STARTING.value),
        (PROGRAM_STATUS, ProgramStatus.STOPPING.value),
    ]

    # The third recv_multipart call raises StopAsyncIteration (side_effect exhausted),
    # which ends the loop.
    with contextlib.suppress(StopAsyncIteration):
        await vad_agent._status_loop()

    vad_agent._reset_stream.assert_not_called()


@pytest.mark.asyncio
async def test_application_message_other(zmq_context):
    """
    Check that it does nothing when there's an internal communication message that is not a
    status update.
    """
    vad_agent = VADAgent("tcp://localhost:12345", False)
    vad_agent._running = True
    vad_agent._reset_stream = AsyncMock()
    vad_agent.program_sub_socket = AsyncMock()
    # Same non-awaited close() as in test_application_startup_complete.
    vad_agent.program_sub_socket.close = MagicMock()
    vad_agent.program_sub_socket.recv_multipart.side_effect = [(b"internal/other", b"Whatever")]

    # The second recv_multipart call raises StopAsyncIteration (side_effect exhausted),
    # which ends the loop.
    with contextlib.suppress(StopAsyncIteration):
        await vad_agent._status_loop()

    vad_agent._reset_stream.assert_not_called()