Files
pepperplus-cb/test/unit/agents/perception/transcription_agent/test_transcription_agent.py
Kasper Marinus b1c18abffd test: bunch of tests
Written with AI, still need to check them

ref: N25B-449
2026-01-16 13:11:41 +01:00

218 lines
7.2 KiB
Python

import asyncio
from unittest.mock import AsyncMock, MagicMock, patch
import numpy as np
import pytest
from control_backend.agents.perception.transcription_agent.speech_recognizer import (
MLXWhisperSpeechRecognizer,
OpenAIWhisperSpeechRecognizer,
SpeechRecognizer,
)
from control_backend.agents.perception.transcription_agent.transcription_agent import (
TranscriptionAgent,
)
@pytest.mark.asyncio
async def test_transcription_agent_flow(mock_zmq_context):
    """Drive one loop iteration end to end and check the transcript is sent.

    The subscriber socket yields one audio chunk and then raises
    ``asyncio.CancelledError`` to end the loop. The recognizer is mocked to
    return ``"Hello"``, which must show up as the body of the first positional
    argument passed to ``agent.send``.
    """
    # 16000 zero-valued float32 samples, serialized the way recv() delivers them.
    audio_bytes = np.zeros(16000, dtype=np.float32).tobytes()

    # recv() call 1 -> audio, call 2 -> CancelledError (stops the loop).
    sub_socket = MagicMock(
        recv=AsyncMock(side_effect=[audio_bytes, asyncio.CancelledError()])
    )
    # Every socket created through the mocked ZMQ context is this one.
    mock_zmq_context.instance.return_value.socket.return_value = sub_socket

    recognizer = MagicMock()
    recognizer.recognize_speech.return_value = "Hello"

    def discard_behavior(coro):
        # Close the coroutine handed to add_behavior so it is never scheduled
        # (and never triggers a "never awaited" warning).
        coro.close()
        return MagicMock()

    with patch.object(SpeechRecognizer, "best_type", return_value=recognizer):
        agent = TranscriptionAgent("tcp://in")
        agent.send = AsyncMock()
        agent._running = True
        agent.add_behavior = MagicMock(side_effect=discard_behavior)

        await agent.setup()

        # Run the loop by hand; the queued CancelledError may propagate out.
        try:
            await agent._transcribing_loop()
        except asyncio.CancelledError:
            pass

        # The audio chunk reached the recognizer ...
        assert recognizer.recognize_speech.called
        # ... and its transcript was forwarded.
        assert agent.send.called
        sent_message = agent.send.call_args[0][0]
        assert sent_message.body == "Hello"

        await agent.stop()
@pytest.mark.asyncio
async def test_transcription_empty(mock_zmq_context):
    """An empty transcript must not result in a message being sent.

    The subscriber socket yields one short audio chunk and then raises
    ``asyncio.CancelledError`` to end the loop. The recognizer is mocked to
    return an empty string, so ``agent.send`` must never be called.
    """
    mock_sub = MagicMock()
    mock_sub.recv = AsyncMock()
    mock_zmq_context.instance.return_value.socket.return_value = mock_sub

    # Return valid audio, but recognizer returns empty string
    fake_audio = np.zeros(10, dtype=np.float32).tobytes()
    mock_sub.recv.side_effect = [fake_audio, asyncio.CancelledError()]

    with patch.object(SpeechRecognizer, "best_type") as mock_best:
        mock_recognizer = MagicMock()
        mock_recognizer.recognize_speech.return_value = ""
        mock_best.return_value = mock_recognizer

        agent = TranscriptionAgent("tcp://in")
        agent.send = AsyncMock()

        # The loop is presumably guarded by `_running`; without this flag the
        # loop body never executes and the "not sent" assertion below would
        # pass without exercising anything.
        agent._running = True

        def close_coro(coro):
            # Close the coroutine handed to add_behavior so it is not scheduled
            # in the background, where it would compete with the manual loop
            # run below for the two queued recv() results.
            coro.close()
            return MagicMock()

        agent.add_behavior = MagicMock(side_effect=close_coro)

        await agent.setup()

        try:
            await agent._transcribing_loop()
        except asyncio.CancelledError:
            pass

        # Guard against a vacuous pass: the audio must have been transcribed.
        assert mock_recognizer.recognize_speech.call_count == 1
        # Should NOT send message
        agent.send.assert_not_called()
def test_speech_recognizer_factory():
    """``SpeechRecognizer.best_type`` selects the backend from MPS availability.

    With ``torch.mps.is_available`` patched to ``True`` the factory must return
    an ``MLXWhisperSpeechRecognizer``; patched to ``False`` it must return an
    ``OpenAIWhisperSpeechRecognizer``.
    """
    # (patched MPS availability, recognizer class the factory should return)
    expected_backends = (
        (True, MLXWhisperSpeechRecognizer),
        (False, OpenAIWhisperSpeechRecognizer),
    )
    for mps_available, recognizer_cls in expected_backends:
        with patch("torch.mps.is_available", return_value=mps_available):
            chosen = SpeechRecognizer.best_type()
        assert isinstance(chosen, recognizer_cls)
def test_openai_recognizer():
    """``OpenAIWhisperSpeechRecognizer`` loads a model and returns the text.

    ``whisper.load_model`` and ``whisper.transcribe`` are both patched, so no
    real model is loaded. The ``"text"`` entry of the mocked transcription
    result must be what ``recognize_speech`` returns.
    """
    load_patch = patch("whisper.load_model")
    transcribe_patch = patch("whisper.transcribe")

    with load_patch as load_model_mock, transcribe_patch as transcribe_mock:
        recognizer = OpenAIWhisperSpeechRecognizer()
        recognizer.load_model()
        load_model_mock.assert_called()

        transcribe_mock.return_value = {"text": "Hi"}
        transcript = recognizer.recognize_speech(np.zeros(10))
        assert transcript == "Hi"
def test_mlx_recognizer():
    """``MLXWhisperSpeechRecognizer`` loads via ``ModelHolder`` and returns the text.

    The MLX names are injected into the recognizer module with ``create=True``
    so the test can run on platforms where the module never imported them.
    """
    module_path = "control_backend.agents.perception.transcription_agent.speech_recognizer"

    # Pretend to be on macOS for the duration of the test.
    platform_patch = patch("sys.platform", "darwin")
    # On Linux the module does not import `mlx_whisper`, so the attribute does
    # not exist there; create=True lets patch add it to the module namespace.
    mlx_patch = patch(f"{module_path}.mlx_whisper", create=True)
    holder_patch = patch(f"{module_path}.ModelHolder", create=True)
    # `mx` (mlx.core) is stubbed as well in case it is used for types/constants.
    mx_patch = patch(f"{module_path}.mx", create=True)

    with platform_patch, mlx_patch as mlx_mock, holder_patch as holder_mock, mx_patch:
        recognizer = MLXWhisperSpeechRecognizer()
        recognizer.load_model()
        holder_mock.get_model.assert_called()

        mlx_mock.transcribe.return_value = {"text": "Hi"}
        transcript = recognizer.recognize_speech(np.zeros(10))
        assert transcript == "Hi"
@pytest.mark.asyncio
async def test_transcription_loop_continues_after_error(mock_zmq_context):
    """A recognizer failure must not end the loop or produce a message.

    The first ``recv()`` delivers audio for which the mocked recognizer raises
    ``RuntimeError``; the second ``recv()`` raises ``asyncio.CancelledError``
    to end the loop. Seeing two ``recv()`` calls shows the loop went on to a
    second iteration after the failure.
    """
    audio_bytes = np.zeros(16000, dtype=np.float32).tobytes()

    sub_socket = MagicMock()
    sub_socket.recv = AsyncMock(
        side_effect=[
            audio_bytes,  # first iteration -> recognizer fails
            asyncio.CancelledError(),  # second iteration -> stop loop
        ]
    )
    mock_zmq_context.instance.return_value.socket.return_value = sub_socket

    failing_recognizer = MagicMock()
    failing_recognizer.recognize_speech.side_effect = RuntimeError("fail")

    def discard_behavior(coro):
        # Same stub as the other loop tests: close the coroutine, never run it.
        coro.close()
        return MagicMock()

    with patch.object(SpeechRecognizer, "best_type", return_value=failing_recognizer):
        agent = TranscriptionAgent("tcp://in")
        agent._running = True  # required to enter the loop
        agent.send = AsyncMock()  # should never be called
        agent.add_behavior = MagicMock(side_effect=discard_behavior)

        await agent.setup()

        try:
            await agent._transcribing_loop()
        except asyncio.CancelledError:
            pass

        # The recognizer failed, so nothing may have been sent.
        agent.send.assert_not_called()

        # recv() was awaited twice: once for the audio, once for CancelledError.
        assert sub_socket.recv.call_count == 2
@pytest.mark.asyncio
async def test_transcription_continue_branch_when_empty(mock_zmq_context):
    """An empty transcript skips sending and the loop moves to the next read.

    The first ``recv()`` delivers audio for which the mocked recognizer returns
    ``""``; the second ``recv()`` raises ``asyncio.CancelledError`` to end the
    loop. The test checks that nothing was sent, that the recognizer ran
    exactly once, and that ``recv()`` was awaited exactly twice.
    """
    audio_bytes = np.zeros(16000, dtype=np.float32).tobytes()

    # recv() call 1 -> audio chunk, call 2 -> CancelledError (stops the loop).
    sub_socket = MagicMock(
        recv=AsyncMock(side_effect=[audio_bytes, asyncio.CancelledError()])
    )
    mock_zmq_context.instance.return_value.socket.return_value = sub_socket

    silent_recognizer = MagicMock()
    silent_recognizer.recognize_speech.return_value = ""  # triggers the continue branch

    def discard_behavior(coro):
        # Close the coroutine handed to add_behavior so it is never scheduled.
        coro.close()
        return MagicMock()

    with patch.object(SpeechRecognizer, "best_type", return_value=silent_recognizer):
        agent = TranscriptionAgent("tcp://in")

        # Make the loop runnable.
        agent._running = True
        agent.send = AsyncMock()
        agent.add_behavior = MagicMock(side_effect=discard_behavior)

        await agent.setup()

        # Execute the loop manually.
        try:
            await agent._transcribing_loop()
        except asyncio.CancelledError:
            pass

        # Because of the "continue", no sending may occur.
        agent.send.assert_not_called()

        # The loop went back to reading: audio first, then CancelledError.
        assert sub_socket.recv.call_count == 2

        # The recognizer ran once, for the first iteration only.
        assert silent_recognizer.recognize_speech.call_count == 1