refactor: testing
Redid testing structure, added tests and changed some tests. ref: N25B-301
This commit is contained in:
@@ -0,0 +1,122 @@
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from control_backend.agents.perception.transcription_agent.speech_recognizer import (
|
||||
MLXWhisperSpeechRecognizer,
|
||||
OpenAIWhisperSpeechRecognizer,
|
||||
SpeechRecognizer,
|
||||
)
|
||||
from control_backend.agents.perception.transcription_agent.transcription_agent import (
|
||||
TranscriptionAgent,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_transcription_agent_flow(mock_zmq_context):
|
||||
mock_sub = MagicMock()
|
||||
mock_sub.recv = AsyncMock()
|
||||
|
||||
# Setup context to return this specific mock socket
|
||||
mock_zmq_context.instance.return_value.socket.return_value = mock_sub
|
||||
|
||||
# Data: [Audio Bytes, Cancel Loop]
|
||||
fake_audio = np.zeros(16000, dtype=np.float32).tobytes()
|
||||
mock_sub.recv.side_effect = [fake_audio, asyncio.CancelledError()]
|
||||
|
||||
# Mock Recognizer
|
||||
with patch.object(SpeechRecognizer, "best_type") as mock_best:
|
||||
mock_recognizer = MagicMock()
|
||||
mock_recognizer.recognize_speech.return_value = "Hello"
|
||||
mock_best.return_value = mock_recognizer
|
||||
|
||||
agent = TranscriptionAgent("tcp://in")
|
||||
agent.send = AsyncMock()
|
||||
|
||||
agent._running = True
|
||||
agent.add_background_task = AsyncMock()
|
||||
|
||||
await agent.setup()
|
||||
|
||||
try:
|
||||
await agent._transcribing_loop()
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
# Check transcription happened
|
||||
assert mock_recognizer.recognize_speech.called
|
||||
# Check sending
|
||||
assert agent.send.called
|
||||
assert agent.send.call_args[0][0].body == "Hello"
|
||||
|
||||
await agent.stop()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_transcription_empty(mock_zmq_context):
|
||||
mock_sub = MagicMock()
|
||||
mock_sub.recv = AsyncMock()
|
||||
mock_zmq_context.instance.return_value.socket.return_value = mock_sub
|
||||
|
||||
# Return valid audio, but recognizer returns empty string
|
||||
fake_audio = np.zeros(10, dtype=np.float32).tobytes()
|
||||
mock_sub.recv.side_effect = [fake_audio, asyncio.CancelledError()]
|
||||
|
||||
with patch.object(SpeechRecognizer, "best_type") as mock_best:
|
||||
mock_recognizer = MagicMock()
|
||||
mock_recognizer.recognize_speech.return_value = ""
|
||||
mock_best.return_value = mock_recognizer
|
||||
|
||||
agent = TranscriptionAgent("tcp://in")
|
||||
agent.send = AsyncMock()
|
||||
await agent.setup()
|
||||
|
||||
try:
|
||||
await agent._transcribing_loop()
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
# Should NOT send message
|
||||
agent.send.assert_not_called()
|
||||
|
||||
|
||||
def test_speech_recognizer_factory():
|
||||
# Test Factory Logic
|
||||
with patch("torch.mps.is_available", return_value=True):
|
||||
assert isinstance(SpeechRecognizer.best_type(), MLXWhisperSpeechRecognizer)
|
||||
|
||||
with patch("torch.mps.is_available", return_value=False):
|
||||
assert isinstance(SpeechRecognizer.best_type(), OpenAIWhisperSpeechRecognizer)
|
||||
|
||||
|
||||
def test_openai_recognizer():
|
||||
with patch("whisper.load_model") as load_mock:
|
||||
with patch("whisper.transcribe") as trans_mock:
|
||||
rec = OpenAIWhisperSpeechRecognizer()
|
||||
rec.load_model()
|
||||
load_mock.assert_called()
|
||||
|
||||
trans_mock.return_value = {"text": "Hi"}
|
||||
res = rec.recognize_speech(np.zeros(10))
|
||||
assert res == "Hi"
|
||||
|
||||
|
||||
def test_mlx_recognizer():
|
||||
# Fix: On Linux, 'mlx_whisper' isn't imported by the module, so it's missing from dir().
|
||||
# We must use create=True to inject it into the module namespace during the test.
|
||||
module_path = "control_backend.agents.perception.transcription_agent.speech_recognizer"
|
||||
|
||||
with patch("sys.platform", "darwin"):
|
||||
with patch(f"{module_path}.mlx_whisper", create=True) as mlx_mock:
|
||||
with patch(f"{module_path}.ModelHolder", create=True) as holder_mock:
|
||||
# We also need to mock mlx.core if it's used for types/constants
|
||||
with patch(f"{module_path}.mx", create=True):
|
||||
rec = MLXWhisperSpeechRecognizer()
|
||||
rec.load_model()
|
||||
holder_mock.get_model.assert_called()
|
||||
|
||||
mlx_mock.transcribe.return_value = {"text": "Hi"}
|
||||
res = rec.recognize_speech(np.zeros(10))
|
||||
assert res == "Hi"
|
||||
Reference in New Issue
Block a user