feat: apply new agent naming standards
Expanding abbreviations to remove ambiguity, simplifying agent names to reduce repetition. ref: N25B-257
This commit is contained in:
Binary file not shown.
122
test/integration/agents/perception/vad_agent/test_vad_agent.py
Normal file
122
test/integration/agents/perception/vad_agent/test_vad_agent.py
Normal file
@@ -0,0 +1,122 @@
|
||||
import random
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
import zmq
|
||||
from spade.agent import Agent
|
||||
|
||||
from control_backend.agents.perception.vad_agent import VADAgent
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def zmq_context(mocker):
|
||||
mock_context = mocker.patch("control_backend.agents.perception.vad_agent.azmq.Context.instance")
|
||||
mock_context.return_value = MagicMock()
|
||||
return mock_context
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def streaming(mocker):
|
||||
return mocker.patch("control_backend.agents.perception.vad_agent.StreamingBehaviour")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def per_transcription_agent(mocker):
|
||||
return mocker.patch(
|
||||
"control_backend.agents.perception.vad_agent.TranscriptionAgent", autospec=True
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_normal_setup(streaming, per_transcription_agent):
|
||||
"""
|
||||
Test that during normal setup, the VAD agent creates a Streaming behavior and creates audio
|
||||
sockets, and starts the TranscriptionAgent without loading real models.
|
||||
"""
|
||||
per_vad_agent = VADAgent("tcp://localhost:12345", False)
|
||||
per_vad_agent.add_behaviour = MagicMock()
|
||||
|
||||
await per_vad_agent.setup()
|
||||
|
||||
streaming.assert_called_once()
|
||||
per_vad_agent.add_behaviour.assert_called_once_with(streaming.return_value)
|
||||
per_transcription_agent.assert_called_once()
|
||||
per_transcription_agent.return_value.start.assert_called_once()
|
||||
assert per_vad_agent.audio_in_socket is not None
|
||||
assert per_vad_agent.audio_out_socket is not None
|
||||
|
||||
|
||||
@pytest.mark.parametrize("do_bind", [True, False])
|
||||
def test_in_socket_creation(zmq_context, do_bind: bool):
|
||||
"""
|
||||
Test that the VAD agent creates an audio input socket, differentiating between binding and
|
||||
connecting.
|
||||
"""
|
||||
per_vad_agent = VADAgent(f"tcp://{'*' if do_bind else 'localhost'}:12345", do_bind)
|
||||
|
||||
per_vad_agent._connect_audio_in_socket()
|
||||
|
||||
assert per_vad_agent.audio_in_socket is not None
|
||||
|
||||
zmq_context.return_value.socket.assert_called_once_with(zmq.SUB)
|
||||
zmq_context.return_value.socket.return_value.setsockopt_string.assert_called_once_with(
|
||||
zmq.SUBSCRIBE,
|
||||
"",
|
||||
)
|
||||
|
||||
if do_bind:
|
||||
zmq_context.return_value.socket.return_value.bind.assert_called_once_with("tcp://*:12345")
|
||||
else:
|
||||
zmq_context.return_value.socket.return_value.connect.assert_called_once_with(
|
||||
"tcp://localhost:12345"
|
||||
)
|
||||
|
||||
|
||||
def test_out_socket_creation(zmq_context):
|
||||
"""
|
||||
Test that the VAD agent creates an audio output socket correctly.
|
||||
"""
|
||||
per_vad_agent = VADAgent("tcp://localhost:12345", False)
|
||||
|
||||
per_vad_agent._connect_audio_out_socket()
|
||||
|
||||
assert per_vad_agent.audio_out_socket is not None
|
||||
|
||||
zmq_context.return_value.socket.assert_called_once_with(zmq.PUB)
|
||||
zmq_context.return_value.socket.return_value.bind_to_random_port.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_out_socket_creation_failure(zmq_context):
|
||||
"""
|
||||
Test setup failure when the audio output socket cannot be created.
|
||||
"""
|
||||
with patch.object(Agent, "stop", new_callable=AsyncMock) as mock_super_stop:
|
||||
zmq_context.return_value.socket.return_value.bind_to_random_port.side_effect = (
|
||||
zmq.ZMQBindError
|
||||
)
|
||||
per_vad_agent = VADAgent("tcp://localhost:12345", False)
|
||||
|
||||
await per_vad_agent.setup()
|
||||
|
||||
assert per_vad_agent.audio_out_socket is None
|
||||
mock_super_stop.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stop(zmq_context, per_transcription_agent):
|
||||
"""
|
||||
Test that when the VAD agent is stopped, the sockets are closed correctly.
|
||||
"""
|
||||
per_vad_agent = VADAgent("tcp://localhost:12345", False)
|
||||
zmq_context.return_value.socket.return_value.bind_to_random_port.return_value = random.randint(
|
||||
1000,
|
||||
10000,
|
||||
)
|
||||
|
||||
await per_vad_agent.setup()
|
||||
await per_vad_agent.stop()
|
||||
|
||||
assert zmq_context.return_value.socket.return_value.close.call_count == 2
|
||||
assert per_vad_agent.audio_in_socket is None
|
||||
assert per_vad_agent.audio_out_socket is None
|
||||
@@ -0,0 +1,59 @@
|
||||
import os
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
import soundfile as sf
|
||||
import zmq
|
||||
|
||||
from control_backend.agents.perception.vad_agent import StreamingBehaviour
|
||||
|
||||
|
||||
def get_audio_chunks() -> list[bytes]:
|
||||
curr_file = os.path.realpath(__file__)
|
||||
curr_dir = os.path.dirname(curr_file)
|
||||
file = f"{curr_dir}/speech_with_pauses_16k_1c_float32.wav"
|
||||
|
||||
chunk_size = 512
|
||||
|
||||
chunks = []
|
||||
|
||||
with sf.SoundFile(file, "r") as f:
|
||||
assert f.samplerate == 16000
|
||||
assert f.channels == 1
|
||||
assert f.subtype == "FLOAT"
|
||||
|
||||
while True:
|
||||
data = f.read(chunk_size, dtype="float32")
|
||||
if len(data) != chunk_size:
|
||||
break
|
||||
|
||||
chunks.append(data.tobytes())
|
||||
|
||||
return chunks
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_real_audio(mocker):
|
||||
"""
|
||||
Test the VAD agent with only input and output mocked. Using the real model, using real audio as
|
||||
input. Ensure that it outputs some fragments with audio.
|
||||
"""
|
||||
audio_chunks = get_audio_chunks()
|
||||
audio_in_socket = AsyncMock()
|
||||
audio_in_socket.recv.side_effect = audio_chunks
|
||||
|
||||
mock_poller: MagicMock = mocker.patch("control_backend.agents.perception.vad_agent.zmq.Poller")
|
||||
mock_poller.return_value.poll.return_value = [(audio_in_socket, zmq.POLLIN)]
|
||||
|
||||
audio_out_socket = AsyncMock()
|
||||
|
||||
vad_streamer = StreamingBehaviour(audio_in_socket, audio_out_socket)
|
||||
vad_streamer._ready = True
|
||||
vad_streamer.agent = MagicMock()
|
||||
for _ in audio_chunks:
|
||||
await vad_streamer.run()
|
||||
|
||||
audio_out_socket.send.assert_called()
|
||||
for args in audio_out_socket.send.call_args_list:
|
||||
assert isinstance(args[0][0], bytes)
|
||||
assert len(args[0][0]) >= 512 * 4 * 3 # Should be at least 3 chunks of audio
|
||||
Reference in New Issue
Block a user