"""Tests for the transcription agent and the speech recognizer implementations."""

import asyncio
from unittest.mock import AsyncMock, MagicMock, patch

import numpy as np
import pytest

from control_backend.agents.perception.transcription_agent.speech_recognizer import (
    MLXWhisperSpeechRecognizer,
    OpenAIWhisperSpeechRecognizer,
    SpeechRecognizer,
)
from control_backend.agents.perception.transcription_agent.transcription_agent import (
    TranscriptionAgent,
)


@pytest.mark.asyncio
async def test_transcription_agent_flow(mock_zmq_context):
    """A non-empty transcript ends up as the body of the message passed to ``send``."""
    # Payload delivered by the fake socket: 16000 float32 zeros, serialized to bytes.
    audio_bytes = np.zeros(16000, dtype=np.float32).tobytes()

    # The socket handed out by the mocked ZMQ context yields the audio once and then
    # raises CancelledError, which is how this test gets out of the transcribing loop.
    subscriber = MagicMock()
    subscriber.recv = AsyncMock(side_effect=[audio_bytes, asyncio.CancelledError()])
    mock_zmq_context.instance.return_value.socket.return_value = subscriber

    # Stand-in recognizer returned by the patched factory.
    recognizer = MagicMock()
    recognizer.recognize_speech.return_value = "Hello"

    with patch.object(SpeechRecognizer, "best_type", return_value=recognizer):
        agent = TranscriptionAgent("tcp://in")
        agent.send = AsyncMock()
        agent._running = True
        agent.add_background_task = AsyncMock()

        await agent.setup()

        try:
            await agent._transcribing_loop()
        except asyncio.CancelledError:
            # Raised by the second recv() call configured above.
            pass

        # The recognizer was consulted ...
        assert recognizer.recognize_speech.called
        # ... and its transcript was forwarded through send().
        assert agent.send.called
        sent_message = agent.send.call_args[0][0]
        assert sent_message.body == "Hello"

        await agent.stop()


@pytest.mark.asyncio
async def test_transcription_empty(mock_zmq_context):
    """An empty transcript must not result in any call to ``send``."""
    # Well-formed audio bytes; the recognizer below answers with an empty string anyway.
    audio_bytes = np.zeros(10, dtype=np.float32).tobytes()

    subscriber = MagicMock()
    subscriber.recv = AsyncMock(side_effect=[audio_bytes, asyncio.CancelledError()])
    mock_zmq_context.instance.return_value.socket.return_value = subscriber

    recognizer = MagicMock()
    recognizer.recognize_speech.return_value = ""

    with patch.object(SpeechRecognizer, "best_type", return_value=recognizer):
        agent = TranscriptionAgent("tcp://in")
        agent.send = AsyncMock()

        await agent.setup()

        try:
            await agent._transcribing_loop()
        except asyncio.CancelledError:
            # Raised by the second recv() call configured above.
            pass

        # Nothing should have been sent.
        agent.send.assert_not_called()


def test_speech_recognizer_factory():
    """``SpeechRecognizer.best_type`` follows the value of ``torch.mps.is_available``."""
    expectations = (
        (True, MLXWhisperSpeechRecognizer),
        (False, OpenAIWhisperSpeechRecognizer),
    )
    for mps_available, expected_cls in expectations:
        with patch("torch.mps.is_available", return_value=mps_available):
            assert isinstance(SpeechRecognizer.best_type(), expected_cls)


def test_openai_recognizer():
    """The OpenAI recognizer loads a model and returns the ``text`` of the transcription."""
    with patch("whisper.load_model") as fake_load, patch("whisper.transcribe") as fake_transcribe:
        recognizer = OpenAIWhisperSpeechRecognizer()
        recognizer.load_model()
        fake_load.assert_called()

        fake_transcribe.return_value = {"text": "Hi"}
        transcript = recognizer.recognize_speech(np.zeros(10))
        assert transcript == "Hi"


def test_mlx_recognizer():
    """The MLX recognizer fetches a model via ``ModelHolder`` and returns the ``text`` field."""
    target = "control_backend.agents.perception.transcription_agent.speech_recognizer"

    # On Linux the module does not import 'mlx_whisper', so the attribute does not exist
    # there; create=True lets patch inject it into the module namespace for the test.
    platform_patch = patch("sys.platform", "darwin")
    whisper_patch = patch(f"{target}.mlx_whisper", create=True)
    holder_patch = patch(f"{target}.ModelHolder", create=True)
    # mlx.core is mocked as well in case it is used for types/constants.
    mx_patch = patch(f"{target}.mx", create=True)

    with platform_patch, whisper_patch as mlx_mock, holder_patch as holder_mock, mx_patch:
        recognizer = MLXWhisperSpeechRecognizer()
        recognizer.load_model()
        holder_mock.get_model.assert_called()

        mlx_mock.transcribe.return_value = {"text": "Hi"}
        transcript = recognizer.recognize_speech(np.zeros(10))
        assert transcript == "Hi"