227 lines
7.4 KiB
Python
227 lines
7.4 KiB
Python
import asyncio
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
import numpy as np
|
|
import pytest
|
|
|
|
from control_backend.agents.perception.transcription_agent.speech_recognizer import (
|
|
MLXWhisperSpeechRecognizer,
|
|
OpenAIWhisperSpeechRecognizer,
|
|
SpeechRecognizer,
|
|
)
|
|
from control_backend.agents.perception.transcription_agent.transcription_agent import (
|
|
TranscriptionAgent,
|
|
)
|
|
|
|
|
|
@pytest.fixture(autouse=True)
def mock_experiment_logger():
    """Swap the transcription agent module's ``experiment_logger`` for a mock in every test.

    Yields:
        The mock that ``patch`` installs in place of ``experiment_logger``.
    """
    patch_target = (
        "control_backend.agents.perception"
        ".transcription_agent.transcription_agent.experiment_logger"
    )
    # The patch stays active for the whole test and is undone when the fixture resumes.
    with patch(patch_target) as logger_mock:
        yield logger_mock
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_transcription_agent_flow(mock_zmq_context):
    """Feed one audio chunk through the loop and check the recognised text is sent on.

    The mocked subscriber socket returns one audio buffer and then raises
    ``asyncio.CancelledError``, which is what ends the loop.
    """
    audio_chunk = np.zeros(16000, dtype=np.float32).tobytes()

    # First await yields audio, second await cancels the loop.
    sub_socket = MagicMock()
    sub_socket.recv = AsyncMock(side_effect=[audio_chunk, asyncio.CancelledError()])

    # Make the mocked context hand out this specific socket.
    mock_zmq_context.instance.return_value.socket.return_value = sub_socket

    def discard_behavior(coro):
        # Close whatever is passed to add_behavior so it is never run; return a stand-in.
        coro.close()
        return MagicMock()

    with patch.object(SpeechRecognizer, "best_type") as best_type_mock:
        recognizer = MagicMock()
        recognizer.recognize_speech.return_value = "Hello"
        best_type_mock.return_value = recognizer

        agent = TranscriptionAgent("tcp://in")
        agent.send = AsyncMock()
        agent._running = True
        agent.add_behavior = MagicMock(side_effect=discard_behavior)

        await agent.setup()

        try:
            await agent._transcribing_loop()
        except asyncio.CancelledError:
            pass

        # The recogniser ran ...
        recognizer.recognize_speech.assert_called()
        # ... and its text was sent as the message body.
        agent.send.assert_called()
        sent_message = agent.send.call_args.args[0]
        assert sent_message.body == "Hello"

        await agent.stop()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_transcription_empty(mock_zmq_context):
    """Check that nothing is sent when the recogniser returns an empty string.

    The mocked subscriber socket returns one short audio buffer and then raises
    ``asyncio.CancelledError`` to end the loop.

    Unlike the earlier version of this test, the agent is prepared the same way
    as in the other loop tests in this file: ``_running`` is set so the loop is
    actually entered, and ``add_behavior`` is stubbed so ``setup()`` does not
    schedule a real behavior. Without that, ``send.assert_not_called()`` could
    pass without the loop ever reading from the socket.
    """
    mock_sub = MagicMock()
    mock_sub.recv = AsyncMock()
    mock_zmq_context.instance.return_value.socket.return_value = mock_sub

    # Return valid audio, but the recogniser returns an empty string.
    fake_audio = np.zeros(10, dtype=np.float32).tobytes()
    mock_sub.recv.side_effect = [fake_audio, asyncio.CancelledError()]

    def close_coro(coro):
        # Close whatever is passed to add_behavior so it is never run; return a stand-in.
        coro.close()
        return MagicMock()

    with patch.object(SpeechRecognizer, "best_type") as mock_best:
        mock_recognizer = MagicMock()
        mock_recognizer.recognize_speech.return_value = ""
        mock_best.return_value = mock_recognizer

        agent = TranscriptionAgent("tcp://in")
        agent._running = True  # required to enter the loop (same as the sibling tests)
        agent.send = AsyncMock()
        agent.add_behavior = MagicMock(side_effect=close_coro)

        await agent.setup()

        try:
            await agent._transcribing_loop()
        except asyncio.CancelledError:
            pass

        # Should NOT send a message.
        agent.send.assert_not_called()

        # Guard against a vacuous pass: the loop must have read the audio chunk
        # and then hit the CancelledError on the second read.
        assert mock_sub.recv.call_count == 2
|
|
|
|
|
|
def test_speech_recognizer_factory():
    """Check which recogniser class ``best_type`` picks for each MPS availability value."""
    # (value returned by torch.mps.is_available, expected recogniser class)
    cases = (
        (True, MLXWhisperSpeechRecognizer),
        (False, OpenAIWhisperSpeechRecognizer),
    )
    for mps_available, expected_cls in cases:
        with patch("torch.mps.is_available", return_value=mps_available):
            chosen = SpeechRecognizer.best_type()
            assert isinstance(chosen, expected_cls)
|
|
|
|
|
|
def test_openai_recognizer():
    """Check that the OpenAI recogniser loads a model and returns the transcribed text."""
    with patch("whisper.load_model") as load_model, patch("whisper.transcribe") as transcribe:
        recognizer = OpenAIWhisperSpeechRecognizer()

        recognizer.load_model()
        load_model.assert_called()

        # The recogniser is expected to return the "text" entry of whisper's result.
        transcribe.return_value = {"text": "Hi"}
        transcript = recognizer.recognize_speech(np.zeros(10))
        assert transcript == "Hi"
|
|
|
|
|
|
def test_mlx_recognizer():
    """Check that the MLX recogniser fetches a model and returns the transcribed text.

    On Linux the recogniser module does not import ``mlx_whisper``, so the name is
    absent from the module namespace. ``create=True`` lets ``patch`` inject it (and
    the related names) for the duration of the test.
    """
    recognizer_module = "control_backend.agents.perception.transcription_agent.speech_recognizer"
    mlx_whisper_target = f"{recognizer_module}.mlx_whisper"
    model_holder_target = f"{recognizer_module}.ModelHolder"
    # mlx.core is patched as well in case the module uses it for types/constants.
    mlx_core_target = f"{recognizer_module}.mx"

    with patch("sys.platform", "darwin"), patch(mlx_whisper_target, create=True) as mlx_whisper:
        with patch(model_holder_target, create=True) as model_holder, patch(
            mlx_core_target, create=True
        ):
            recognizer = MLXWhisperSpeechRecognizer()

            recognizer.load_model()
            model_holder.get_model.assert_called()

            mlx_whisper.transcribe.return_value = {"text": "Hi"}
            transcript = recognizer.recognize_speech(np.zeros(10))
            assert transcript == "Hi"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_transcription_loop_continues_after_error(mock_zmq_context):
    """Check that a recogniser failure does not stop the loop from reading again.

    The first socket read returns audio, on which the recogniser raises
    ``RuntimeError``. The second read raises ``asyncio.CancelledError`` to end the
    loop. Reaching that second read shows the loop carried on after the failure.
    """
    audio_chunk = np.zeros(16000, dtype=np.float32).tobytes()

    sub_socket = MagicMock()
    sub_socket.recv = AsyncMock(
        side_effect=[
            audio_chunk,  # first iteration -> recogniser fails
            asyncio.CancelledError(),  # second iteration -> stop loop
        ]
    )
    mock_zmq_context.instance.return_value.socket.return_value = sub_socket

    def discard_behavior(coro):
        # Close whatever is passed to add_behavior so it is never run; return a stand-in.
        coro.close()
        return MagicMock()

    with patch.object(SpeechRecognizer, "best_type") as best_type_mock:
        failing_recognizer = MagicMock()
        failing_recognizer.recognize_speech.side_effect = RuntimeError("fail")
        best_type_mock.return_value = failing_recognizer

        agent = TranscriptionAgent("tcp://in")
        agent._running = True  # required to enter the loop
        agent.send = AsyncMock()  # should never be called
        agent.add_behavior = MagicMock(side_effect=discard_behavior)  # match other tests

        await agent.setup()

        try:
            await agent._transcribing_loop()
        except asyncio.CancelledError:
            pass

        # The recogniser failed, so nothing may have been sent.
        agent.send.assert_not_called()

        # Two reads: the audio chunk, then the CancelledError.
        assert sub_socket.recv.call_count == 2
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_transcription_continue_branch_when_empty(mock_zmq_context):
    """Check that an empty transcription is skipped and the loop moves to the next read.

    The first socket read returns audio, for which the recogniser returns ``""``.
    The second read raises ``asyncio.CancelledError`` to end the loop.
    """
    audio_chunk = np.zeros(16000, dtype=np.float32).tobytes()

    # Read 1 -> audio chunk; read 2 -> cancellation that stops the loop.
    sub_socket = MagicMock()
    sub_socket.recv = AsyncMock(side_effect=[audio_chunk, asyncio.CancelledError()])
    mock_zmq_context.instance.return_value.socket.return_value = sub_socket

    def discard_behavior(coro):
        # Close whatever is passed to add_behavior so it is never run; return a stand-in.
        coro.close()
        return MagicMock()

    with patch.object(SpeechRecognizer, "best_type") as best_type_mock:
        silent_recognizer = MagicMock()
        # An empty result is what triggers the "continue" branch under test.
        silent_recognizer.recognize_speech.return_value = ""
        best_type_mock.return_value = silent_recognizer

        agent = TranscriptionAgent("tcp://in")
        agent._running = True  # make the loop runnable
        agent.send = AsyncMock()
        agent.add_behavior = MagicMock(side_effect=discard_behavior)

        await agent.setup()

        # Drive the loop by hand.
        try:
            await agent._transcribing_loop()
        except asyncio.CancelledError:
            pass

        # Because of the "continue", nothing may have been sent.
        agent.send.assert_not_called()

        # Exactly two reads: the audio chunk, then the CancelledError.
        assert sub_socket.recv.call_count == 2

        # The recogniser ran once, on the first iteration only.
        assert silent_recognizer.recognize_speech.call_count == 1
|