Remove SPADE dependency #29
@@ -1,9 +1,8 @@
|
|||||||
import random
|
import random
|
||||||
from unittest.mock import AsyncMock, MagicMock, patch
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import zmq
|
import zmq
|
||||||
from spade.agent import Agent
|
|
||||||
|
|
||||||
from control_backend.agents.perception.vad_agent import VADAgent
|
from control_backend.agents.perception.vad_agent import VADAgent
|
||||||
|
|
||||||
@@ -15,11 +14,6 @@ def zmq_context(mocker):
|
|||||||
return mock_context
|
return mock_context
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def streaming(mocker):
|
|
||||||
return mocker.patch("control_backend.agents.perception.vad_agent.StreamingBehaviour")
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def per_transcription_agent(mocker):
|
def per_transcription_agent(mocker):
|
||||||
return mocker.patch(
|
return mocker.patch(
|
||||||
@@ -27,21 +21,36 @@ def per_transcription_agent(mocker):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def torch_load(mocker):
|
||||||
|
mock_torch = mocker.patch("control_backend.agents.perception.vad_agent.torch")
|
||||||
|
model = MagicMock()
|
||||||
|
mock_torch.hub.load.return_value = (model, None)
|
||||||
|
mock_torch.from_numpy.side_effect = lambda arr: arr
|
||||||
|
return mock_torch
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_normal_setup(streaming, per_transcription_agent):
|
async def test_normal_setup(per_transcription_agent):
|
||||||
"""
|
"""
|
||||||
Test that during normal setup, the VAD agent creates a Streaming behavior and creates audio
|
Test that during normal setup, the VAD agent creates a Streaming behavior and creates audio
|
||||||
sockets, and starts the TranscriptionAgent without loading real models.
|
sockets, and starts the TranscriptionAgent without loading real models.
|
||||||
"""
|
"""
|
||||||
per_vad_agent = VADAgent("tcp://localhost:12345", False)
|
per_vad_agent = VADAgent("tcp://localhost:12345", False)
|
||||||
per_vad_agent.add_behaviour = MagicMock()
|
per_vad_agent._streaming_loop = AsyncMock()
|
||||||
|
|
||||||
|
async def swallow_background_task(coro):
|
||||||
|
coro.close()
|
||||||
|
|
||||||
|
per_vad_agent.add_background_task = swallow_background_task
|
||||||
|
per_vad_agent.reset_stream = AsyncMock()
|
||||||
|
|
||||||
await per_vad_agent.setup()
|
await per_vad_agent.setup()
|
||||||
|
|
||||||
streaming.assert_called_once()
|
|
||||||
per_vad_agent.add_behaviour.assert_called_once_with(streaming.return_value)
|
|
||||||
per_transcription_agent.assert_called_once()
|
per_transcription_agent.assert_called_once()
|
||||||
per_transcription_agent.return_value.start.assert_called_once()
|
per_transcription_agent.return_value.start.assert_called_once()
|
||||||
|
per_vad_agent._streaming_loop.assert_called_once()
|
||||||
|
per_vad_agent.reset_stream.assert_called_once()
|
||||||
assert per_vad_agent.audio_in_socket is not None
|
assert per_vad_agent.audio_in_socket is not None
|
||||||
assert per_vad_agent.audio_out_socket is not None
|
assert per_vad_agent.audio_out_socket is not None
|
||||||
|
|
||||||
@@ -91,16 +100,22 @@ async def test_out_socket_creation_failure(zmq_context):
|
|||||||
"""
|
"""
|
||||||
Test setup failure when the audio output socket cannot be created.
|
Test setup failure when the audio output socket cannot be created.
|
||||||
"""
|
"""
|
||||||
with patch.object(Agent, "stop", new_callable=AsyncMock) as mock_super_stop:
|
zmq_context.return_value.socket.return_value.bind_to_random_port.side_effect = zmq.ZMQBindError
|
||||||
zmq_context.return_value.socket.return_value.bind_to_random_port.side_effect = (
|
per_vad_agent = VADAgent("tcp://localhost:12345", False)
|
||||||
zmq.ZMQBindError
|
per_vad_agent.stop = AsyncMock()
|
||||||
)
|
per_vad_agent.reset_stream = AsyncMock()
|
||||||
per_vad_agent = VADAgent("tcp://localhost:12345", False)
|
per_vad_agent._streaming_loop = AsyncMock()
|
||||||
|
per_vad_agent._connect_audio_out_socket = MagicMock(return_value=None)
|
||||||
|
|
||||||
await per_vad_agent.setup()
|
async def swallow_background_task(coro):
|
||||||
|
coro.close()
|
||||||
|
|
||||||
assert per_vad_agent.audio_out_socket is None
|
per_vad_agent.add_background_task = swallow_background_task
|
||||||
mock_super_stop.assert_called_once()
|
|
||||||
|
await per_vad_agent.setup()
|
||||||
|
|
||||||
|
assert per_vad_agent.audio_out_socket is None
|
||||||
|
per_vad_agent.stop.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@@ -109,6 +124,13 @@ async def test_stop(zmq_context, per_transcription_agent):
|
|||||||
Test that when the VAD agent is stopped, the sockets are closed correctly.
|
Test that when the VAD agent is stopped, the sockets are closed correctly.
|
||||||
"""
|
"""
|
||||||
per_vad_agent = VADAgent("tcp://localhost:12345", False)
|
per_vad_agent = VADAgent("tcp://localhost:12345", False)
|
||||||
|
per_vad_agent.reset_stream = AsyncMock()
|
||||||
|
per_vad_agent._streaming_loop = AsyncMock()
|
||||||
|
|
||||||
|
async def swallow_background_task(coro):
|
||||||
|
coro.close()
|
||||||
|
|
||||||
|
per_vad_agent.add_background_task = swallow_background_task
|
||||||
zmq_context.return_value.socket.return_value.bind_to_random_port.return_value = random.randint(
|
zmq_context.return_value.socket.return_value.bind_to_random_port.return_value = random.randint(
|
||||||
1000,
|
1000,
|
||||||
10000,
|
10000,
|
||||||
|
|||||||
@@ -5,7 +5,24 @@ import pytest
|
|||||||
import soundfile as sf
|
import soundfile as sf
|
||||||
import zmq
|
import zmq
|
||||||
|
|
||||||
from control_backend.agents.perception.vad_agent import StreamingBehaviour
|
from control_backend.agents.perception.vad_agent import VADAgent
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def patch_settings():
|
||||||
|
from control_backend.agents.perception import vad_agent
|
||||||
|
|
||||||
|
vad_agent.settings.behaviour_settings.vad_prob_threshold = 0.5
|
||||||
|
vad_agent.settings.behaviour_settings.vad_non_speech_patience_chunks = 3
|
||||||
|
vad_agent.settings.behaviour_settings.vad_initial_since_speech = 0
|
||||||
|
vad_agent.settings.vad_settings.sample_rate_hz = 16_000
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def mock_torch(mocker):
|
||||||
|
mock_torch = mocker.patch("control_backend.agents.perception.vad_agent.torch")
|
||||||
|
mock_torch.from_numpy.side_effect = lambda arr: arr
|
||||||
|
return mock_torch
|
||||||
|
|
||||||
|
|
||||||
def get_audio_chunks() -> list[bytes]:
|
def get_audio_chunks() -> list[bytes]:
|
||||||
@@ -42,16 +59,39 @@ async def test_real_audio(mocker):
|
|||||||
audio_in_socket = AsyncMock()
|
audio_in_socket = AsyncMock()
|
||||||
audio_in_socket.recv.side_effect = audio_chunks
|
audio_in_socket.recv.side_effect = audio_chunks
|
||||||
|
|
||||||
mock_poller: MagicMock = mocker.patch("control_backend.agents.perception.vad_agent.zmq.Poller")
|
mock_poller: MagicMock = mocker.patch("control_backend.agents.perception.vad_agent.azmq.Poller")
|
||||||
mock_poller.return_value.poll.return_value = [(audio_in_socket, zmq.POLLIN)]
|
mock_poller.return_value.poll = AsyncMock(return_value=[(audio_in_socket, zmq.POLLIN)])
|
||||||
|
|
||||||
audio_out_socket = AsyncMock()
|
audio_out_socket = AsyncMock()
|
||||||
|
|
||||||
vad_streamer = StreamingBehaviour(audio_in_socket, audio_out_socket)
|
vad_agent = VADAgent("tcp://localhost:12345", False)
|
||||||
vad_streamer._ready = True
|
vad_agent.audio_out_socket = audio_out_socket
|
||||||
vad_streamer.agent = MagicMock()
|
|
||||||
for _ in audio_chunks:
|
# Use a fake model that marks most chunks as speech and ends with a few silences
|
||||||
await vad_streamer.run()
|
silence_padding = 5
|
||||||
|
probabilities = [1.0] * len(audio_chunks) + [0.0] * silence_padding
|
||||||
|
chunk_bytes = audio_chunks + [b"\x00" * len(audio_chunks[0])] * silence_padding
|
||||||
|
model_item = MagicMock()
|
||||||
|
model_item.item.side_effect = probabilities
|
||||||
|
vad_agent.model = MagicMock(return_value=model_item)
|
||||||
|
|
||||||
|
class DummyPoller:
|
||||||
|
def __init__(self, data, agent):
|
||||||
|
self.data = data
|
||||||
|
self.agent = agent
|
||||||
|
|
||||||
|
async def poll(self, timeout_ms=None):
|
||||||
|
if self.data:
|
||||||
|
return self.data.pop(0)
|
||||||
|
self.agent._running = False
|
||||||
|
return None
|
||||||
|
|
||||||
|
vad_agent.audio_in_poller = DummyPoller(chunk_bytes, vad_agent)
|
||||||
|
vad_agent._ready = True
|
||||||
|
vad_agent._running = True
|
||||||
|
vad_agent.i_since_speech = 0
|
||||||
|
|
||||||
|
await vad_agent._streaming_loop()
|
||||||
|
|
||||||
audio_out_socket.send.assert_called()
|
audio_out_socket.send.assert_called()
|
||||||
for args in audio_out_socket.send.call_args_list:
|
for args in audio_out_socket.send.call_args_list:
|
||||||
|
|||||||
@@ -2,11 +2,10 @@ import json
|
|||||||
from unittest.mock import AsyncMock, MagicMock, patch
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from spade.message import Message
|
|
||||||
|
|
||||||
from control_backend.agents.bdi.text_belief_extractor_agent.behaviours.text_belief_extractor_behaviour import ( # noqa: E501, We can't shorten this import.
|
from control_backend.agents.bdi.text_belief_extractor_agent.behaviours.text_belief_extractor_behaviour import ( # noqa: E501, We can't shorten this import.
|
||||||
TextBeliefExtractorBehaviour,
|
TextBeliefExtractorBehaviour,
|
||||||
)
|
)
|
||||||
|
from spade.message import Message
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
|
|||||||
@@ -16,8 +16,8 @@ async def test_socket_poller_with_data(socket, mocker):
|
|||||||
socket_data = b"test"
|
socket_data = b"test"
|
||||||
socket.recv.return_value = socket_data
|
socket.recv.return_value = socket_data
|
||||||
|
|
||||||
mock_poller: MagicMock = mocker.patch("control_backend.agents.perception.vad_agent.zmq.Poller")
|
mock_poller: MagicMock = mocker.patch("control_backend.agents.perception.vad_agent.azmq.Poller")
|
||||||
mock_poller.return_value.poll.return_value = [(socket, zmq.POLLIN)]
|
mock_poller.return_value.poll = AsyncMock(return_value=[(socket, zmq.POLLIN)])
|
||||||
|
|
||||||
poller = SocketPoller(socket)
|
poller = SocketPoller(socket)
|
||||||
# Calling `poll` twice to be able to check that the poller is reused
|
# Calling `poll` twice to be able to check that the poller is reused
|
||||||
@@ -35,8 +35,8 @@ async def test_socket_poller_with_data(socket, mocker):
|
|||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_socket_poller_no_data(socket, mocker):
|
async def test_socket_poller_no_data(socket, mocker):
|
||||||
mock_poller: MagicMock = mocker.patch("control_backend.agents.perception.vad_agent.zmq.Poller")
|
mock_poller: MagicMock = mocker.patch("control_backend.agents.perception.vad_agent.azmq.Poller")
|
||||||
mock_poller.return_value.poll.return_value = []
|
mock_poller.return_value.poll = AsyncMock(return_value=[])
|
||||||
|
|
||||||
poller = SocketPoller(socket)
|
poller = SocketPoller(socket)
|
||||||
data = await poller.poll()
|
data = await poller.poll()
|
||||||
|
|||||||
@@ -3,12 +3,7 @@ from unittest.mock import AsyncMock, MagicMock
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from control_backend.agents.perception.vad_agent import StreamingBehaviour
|
from control_backend.agents.perception.vad_agent import VADAgent
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def audio_in_socket():
|
|
||||||
return AsyncMock()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@@ -17,22 +12,8 @@ def audio_out_socket():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_agent(mocker):
|
def vad_agent(audio_out_socket):
|
||||||
"""Fixture to create a mock BDIAgent."""
|
return VADAgent("tcp://localhost:5555", False)
|
||||||
agent = MagicMock()
|
|
||||||
agent.jid = "vad_agent@test"
|
|
||||||
return agent
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def streaming(audio_in_socket, audio_out_socket, mock_agent):
|
|
||||||
import torch
|
|
||||||
|
|
||||||
torch.hub.load.return_value = (..., ...) # Mock
|
|
||||||
streaming = StreamingBehaviour(audio_in_socket, audio_out_socket)
|
|
||||||
streaming._ready = True
|
|
||||||
streaming.agent = mock_agent
|
|
||||||
return streaming
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
@@ -61,25 +42,40 @@ async def simulate_streaming_with_probabilities(streaming, probabilities: list[f
|
|||||||
"""
|
"""
|
||||||
model_item = MagicMock()
|
model_item = MagicMock()
|
||||||
model_item.item.side_effect = probabilities
|
model_item.item.side_effect = probabilities
|
||||||
streaming.model = MagicMock()
|
streaming.model = MagicMock(return_value=model_item)
|
||||||
streaming.model.return_value = model_item
|
|
||||||
|
|
||||||
audio_in_poller = AsyncMock()
|
# Prepare deterministic audio chunks and a poller that stops the loop when exhausted
|
||||||
audio_in_poller.poll.return_value = np.empty(shape=512, dtype=np.float32)
|
chunk_bytes = np.empty(shape=512, dtype=np.float32).tobytes()
|
||||||
streaming.audio_in_poller = audio_in_poller
|
chunks = [chunk_bytes for _ in probabilities]
|
||||||
|
|
||||||
for _ in probabilities:
|
class DummyPoller:
|
||||||
await streaming.run()
|
def __init__(self, data, agent):
|
||||||
|
self.data = data
|
||||||
|
self.agent = agent
|
||||||
|
|
||||||
|
async def poll(self, timeout_ms=None):
|
||||||
|
if self.data:
|
||||||
|
return self.data.pop(0)
|
||||||
|
# Stop the loop cleanly once we've consumed all chunks
|
||||||
|
self.agent._running = False
|
||||||
|
return None
|
||||||
|
|
||||||
|
streaming.audio_in_poller = DummyPoller(chunks, streaming)
|
||||||
|
streaming._ready = True
|
||||||
|
streaming._running = True
|
||||||
|
|
||||||
|
await streaming._streaming_loop()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_voice_activity_detected(audio_in_socket, audio_out_socket, streaming):
|
async def test_voice_activity_detected(audio_out_socket, vad_agent):
|
||||||
"""
|
"""
|
||||||
Test a scenario where there is voice activity detected between silences.
|
Test a scenario where there is voice activity detected between silences.
|
||||||
"""
|
"""
|
||||||
speech_chunk_count = 5
|
speech_chunk_count = 5
|
||||||
probabilities = [0.0] * 5 + [1.0] * speech_chunk_count + [0.0] * 5
|
probabilities = [0.0] * 5 + [1.0] * speech_chunk_count + [0.0] * 5
|
||||||
await simulate_streaming_with_probabilities(streaming, probabilities)
|
vad_agent.audio_out_socket = audio_out_socket
|
||||||
|
await simulate_streaming_with_probabilities(vad_agent, probabilities)
|
||||||
|
|
||||||
audio_out_socket.send.assert_called_once()
|
audio_out_socket.send.assert_called_once()
|
||||||
data = audio_out_socket.send.call_args[0][0]
|
data = audio_out_socket.send.call_args[0][0]
|
||||||
@@ -88,7 +84,7 @@ async def test_voice_activity_detected(audio_in_socket, audio_out_socket, stream
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_voice_activity_short_pause(audio_in_socket, audio_out_socket, streaming):
|
async def test_voice_activity_short_pause(audio_out_socket, vad_agent):
|
||||||
"""
|
"""
|
||||||
Test a scenario where there is a short pause between speech, checking whether it ignores the
|
Test a scenario where there is a short pause between speech, checking whether it ignores the
|
||||||
short pause.
|
short pause.
|
||||||
@@ -97,7 +93,8 @@ async def test_voice_activity_short_pause(audio_in_socket, audio_out_socket, str
|
|||||||
probabilities = (
|
probabilities = (
|
||||||
[0.0] * 5 + [1.0] * speech_chunk_count + [0.0] + [1.0] * speech_chunk_count + [0.0] * 5
|
[0.0] * 5 + [1.0] * speech_chunk_count + [0.0] + [1.0] * speech_chunk_count + [0.0] * 5
|
||||||
)
|
)
|
||||||
await simulate_streaming_with_probabilities(streaming, probabilities)
|
vad_agent.audio_out_socket = audio_out_socket
|
||||||
|
await simulate_streaming_with_probabilities(vad_agent, probabilities)
|
||||||
|
|
||||||
audio_out_socket.send.assert_called_once()
|
audio_out_socket.send.assert_called_once()
|
||||||
data = audio_out_socket.send.call_args[0][0]
|
data = audio_out_socket.send.call_args[0][0]
|
||||||
@@ -107,15 +104,22 @@ async def test_voice_activity_short_pause(audio_in_socket, audio_out_socket, str
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_no_data(audio_in_socket, audio_out_socket, streaming):
|
async def test_no_data(audio_out_socket, vad_agent):
|
||||||
"""
|
"""
|
||||||
Test a scenario where there is no data received. This should not cause errors.
|
Test a scenario where there is no data received. This should not cause errors.
|
||||||
"""
|
"""
|
||||||
audio_in_poller = AsyncMock()
|
|
||||||
audio_in_poller.poll.return_value = None
|
|
||||||
streaming.audio_in_poller = audio_in_poller
|
|
||||||
|
|
||||||
await streaming.run()
|
class DummyPoller:
|
||||||
|
async def poll(self, timeout_ms=None):
|
||||||
|
vad_agent._running = False
|
||||||
|
return None
|
||||||
|
|
||||||
|
vad_agent.audio_out_socket = audio_out_socket
|
||||||
|
vad_agent.audio_in_poller = DummyPoller()
|
||||||
|
vad_agent._ready = True
|
||||||
|
vad_agent._running = True
|
||||||
|
|
||||||
|
await vad_agent._streaming_loop()
|
||||||
|
|
||||||
audio_out_socket.send.assert_not_called()
|
audio_out_socket.send.assert_not_called()
|
||||||
assert len(streaming.audio_buffer) == 0
|
assert len(vad_agent.audio_buffer) == 0
|
||||||
|
|||||||
Reference in New Issue
Block a user