# Python source, 123 lines, 4.1 KiB
import asyncio
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
import numpy as np
|
|
import pytest
|
|
|
|
from control_backend.agents.perception.transcription_agent.speech_recognizer import (
|
|
MLXWhisperSpeechRecognizer,
|
|
OpenAIWhisperSpeechRecognizer,
|
|
SpeechRecognizer,
|
|
)
|
|
from control_backend.agents.perception.transcription_agent.transcription_agent import (
|
|
TranscriptionAgent,
|
|
)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_transcription_agent_flow(mock_zmq_context):
    """Check that received audio is transcribed and the text is sent onward.

    The mocked subscriber socket delivers one audio buffer and then raises
    ``asyncio.CancelledError`` so that ``_transcribing_loop`` stops. The
    recognizer is mocked to return "Hello", which must end up as the body of
    the message passed to ``agent.send``.
    """
    # Payload: 16000 zero-valued float32 samples serialized to raw bytes.
    audio_payload = np.zeros(16000, dtype=np.float32).tobytes()

    # First recv() yields the audio; the second raises to break out of the loop.
    subscriber = MagicMock()
    subscriber.recv = AsyncMock(side_effect=[audio_payload, asyncio.CancelledError()])

    # Any socket obtained through the mocked ZMQ context is this subscriber.
    mock_zmq_context.instance.return_value.socket.return_value = subscriber

    recognizer = MagicMock()
    recognizer.recognize_speech.return_value = "Hello"

    # While patched, SpeechRecognizer.best_type() hands back the mock recognizer.
    with patch.object(SpeechRecognizer, "best_type", return_value=recognizer):
        agent = TranscriptionAgent("tcp://in")
        agent.send = AsyncMock()

        # NOTE(review): presumably the loop only iterates while this flag is
        # set, and add_behavior is stubbed to keep setup() side-effect free;
        # confirm against the agent implementation.
        agent._running = True
        agent.add_behavior = AsyncMock()

        await agent.setup()

        try:
            await agent._transcribing_loop()
        except asyncio.CancelledError:
            # Expected: this is how the mocked socket ends the loop.
            pass

        # The recognizer must have been invoked on the received audio.
        assert recognizer.recognize_speech.called

        # The transcription must have been sent as the message body.
        assert agent.send.called
        sent_message = agent.send.call_args.args[0]
        assert sent_message.body == "Hello"

        await agent.stop()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_transcription_empty(mock_zmq_context):
    """Check that no message is sent when the recognizer returns an empty string.

    The mocked subscriber socket delivers one audio buffer and then raises
    ``asyncio.CancelledError``; the mocked recognizer returns ``""``.
    """
    # Payload: 10 zero-valued float32 samples serialized to raw bytes.
    audio_payload = np.zeros(10, dtype=np.float32).tobytes()

    # First recv() yields the audio; the second raises to break out of the loop.
    subscriber = MagicMock()
    subscriber.recv = AsyncMock(side_effect=[audio_payload, asyncio.CancelledError()])
    mock_zmq_context.instance.return_value.socket.return_value = subscriber

    # Recognizer that produces an empty transcription.
    recognizer = MagicMock()
    recognizer.recognize_speech.return_value = ""

    with patch.object(SpeechRecognizer, "best_type", return_value=recognizer):
        agent = TranscriptionAgent("tcp://in")
        agent.send = AsyncMock()
        # NOTE(review): unlike test_transcription_agent_flow, this test does
        # not set agent._running or stub add_behavior before setup(); confirm
        # that is intentional.
        await agent.setup()

        try:
            await agent._transcribing_loop()
        except asyncio.CancelledError:
            # Expected: this is how the mocked socket ends the loop.
            pass

        # An empty transcription must not result in an outgoing message.
        agent.send.assert_not_called()
|
|
|
|
|
|
def test_speech_recognizer_factory():
    """Check which recognizer class ``SpeechRecognizer.best_type`` selects.

    With ``torch.mps.is_available`` patched to return True the factory must
    produce an ``MLXWhisperSpeechRecognizer``; with False it must produce an
    ``OpenAIWhisperSpeechRecognizer``.
    """
    # (patched MPS availability, expected recognizer class), checked in order.
    expectations = (
        (True, MLXWhisperSpeechRecognizer),
        (False, OpenAIWhisperSpeechRecognizer),
    )

    for mps_available, expected_cls in expectations:
        with patch("torch.mps.is_available", return_value=mps_available):
            assert isinstance(SpeechRecognizer.best_type(), expected_cls)
|
|
|
|
|
|
def test_openai_recognizer():
    """Check model loading and transcription of the OpenAI Whisper recognizer.

    ``whisper.load_model`` and ``whisper.transcribe`` are both patched, so no
    real model is loaded; the recognizer must return the "text" entry of the
    mocked transcription result.
    """
    with patch("whisper.load_model") as load_model_mock, patch(
        "whisper.transcribe"
    ) as transcribe_mock:
        recognizer = OpenAIWhisperSpeechRecognizer()

        # Loading must go through whisper.load_model.
        recognizer.load_model()
        load_model_mock.assert_called()

        # Transcribing a short zero signal must return the mocked text.
        transcribe_mock.return_value = {"text": "Hi"}
        transcript = recognizer.recognize_speech(np.zeros(10))
        assert transcript == "Hi"
|
|
|
|
|
|
def test_mlx_recognizer():
    """Check model loading and transcription of the MLX Whisper recognizer.

    The MLX-related names are injected into the recognizer module with
    ``create=True`` because, per the original author's note, the module does
    not import ``mlx_whisper`` on Linux and the attributes would otherwise be
    missing when patching.
    """
    module_path = "control_backend.agents.perception.transcription_agent.speech_recognizer"

    # Build the patchers up front; they are only activated by the `with` below.
    platform_patch = patch("sys.platform", "darwin")
    whisper_patch = patch(f"{module_path}.mlx_whisper", create=True)
    holder_patch = patch(f"{module_path}.ModelHolder", create=True)
    # mlx.core is patched as well in case it is used for types/constants.
    mx_patch = patch(f"{module_path}.mx", create=True)

    with platform_patch, whisper_patch as mlx_mock, holder_patch as holder_mock, mx_patch:
        recognizer = MLXWhisperSpeechRecognizer()

        # Loading must fetch the model through ModelHolder.get_model.
        recognizer.load_model()
        holder_mock.get_model.assert_called()

        # Transcribing a short zero signal must return the mocked text.
        mlx_mock.transcribe.return_value = {"text": "Hi"}
        transcript = recognizer.recognize_speech(np.zeros(10))
        assert transcript == "Hi"
|