fix: merge refactor/zmq-internal-socket-behaviour into feat/cb2ui-robot-connections. (And fixed all ruff/ test issues to commit)
ref: None
This commit is contained in:
@@ -11,25 +11,27 @@ from control_backend.agents.ri_command_agent import RICommandAgent
|
||||
async def test_setup_bind(monkeypatch):
|
||||
"""Test setup with bind=True"""
|
||||
fake_socket = MagicMock()
|
||||
fake_context = MagicMock()
|
||||
fake_context.socket.return_value = fake_socket
|
||||
|
||||
# Patch Context.instance() to return fake_context
|
||||
monkeypatch.setattr(
|
||||
"control_backend.agents.ri_command_agent.context.socket", lambda _: fake_socket
|
||||
"control_backend.agents.ri_command_agent.Context",
|
||||
MagicMock(instance=MagicMock(return_value=fake_context)),
|
||||
)
|
||||
|
||||
agent = RICommandAgent("test@server", "password", address="tcp://localhost:5555", bind=True)
|
||||
|
||||
monkeypatch.setattr(
|
||||
"control_backend.agents.ri_command_agent.settings",
|
||||
MagicMock(zmq_settings=MagicMock(internal_comm_address="tcp://internal:1234")),
|
||||
MagicMock(zmq_settings=MagicMock(internal_sub_address="tcp://internal:1234")),
|
||||
)
|
||||
|
||||
await agent.setup()
|
||||
|
||||
# Ensure PUB socket bound
|
||||
fake_socket.bind.assert_any_call("tcp://localhost:5555")
|
||||
# Ensure SUB socket connected to internal address and subscribed
|
||||
fake_socket.connect.assert_any_call("tcp://internal:1234")
|
||||
fake_socket.setsockopt.assert_any_call(zmq.SUBSCRIBE, b"command")
|
||||
|
||||
# Ensure behaviour attached
|
||||
assert any(isinstance(b, agent.SendCommandsBehaviour) for b in agent.behaviours)
|
||||
|
||||
|
||||
@@ -37,19 +39,23 @@ async def test_setup_bind(monkeypatch):
|
||||
async def test_setup_connect(monkeypatch):
|
||||
"""Test setup with bind=False"""
|
||||
fake_socket = MagicMock()
|
||||
fake_context = MagicMock()
|
||||
fake_context.socket.return_value = fake_socket
|
||||
|
||||
# Patch Context.instance() to return fake_context
|
||||
monkeypatch.setattr(
|
||||
"control_backend.agents.ri_command_agent.context.socket", lambda _: fake_socket
|
||||
"control_backend.agents.ri_command_agent.Context",
|
||||
MagicMock(instance=MagicMock(return_value=fake_context)),
|
||||
)
|
||||
|
||||
agent = RICommandAgent("test@server", "password", address="tcp://localhost:5555", bind=False)
|
||||
monkeypatch.setattr(
|
||||
"control_backend.agents.ri_command_agent.settings",
|
||||
MagicMock(zmq_settings=MagicMock(internal_comm_address="tcp://internal:1234")),
|
||||
MagicMock(zmq_settings=MagicMock(internal_sub_address="tcp://internal:1234")),
|
||||
)
|
||||
|
||||
await agent.setup()
|
||||
|
||||
# Ensure PUB socket connected
|
||||
fake_socket.connect.assert_any_call("tcp://localhost:5555")
|
||||
|
||||
|
||||
|
||||
@@ -93,12 +93,14 @@ async def test_setup_creates_socket_and_negotiate_1(monkeypatch):
|
||||
fake_socket = MagicMock()
|
||||
fake_socket.send_json = AsyncMock()
|
||||
fake_socket.recv_json = fake_json_correct_negototiate_1()
|
||||
|
||||
fake_pub_socket = AsyncMock()
|
||||
fake_socket.send_multipart = AsyncMock()
|
||||
|
||||
# Mock context.socket to return our fake socket
|
||||
fake_context = MagicMock()
|
||||
fake_context.socket.return_value = fake_socket
|
||||
monkeypatch.setattr(
|
||||
"control_backend.agents.ri_communication_agent.context.socket", lambda _: fake_socket
|
||||
"control_backend.agents.ri_communication_agent.Context",
|
||||
MagicMock(instance=MagicMock(return_value=fake_context)),
|
||||
)
|
||||
|
||||
# Mock RICommandAgent agent startup
|
||||
@@ -107,13 +109,11 @@ async def test_setup_creates_socket_and_negotiate_1(monkeypatch):
|
||||
) as MockCommandAgent:
|
||||
fake_agent_instance = MockCommandAgent.return_value
|
||||
fake_agent_instance.start = AsyncMock()
|
||||
fake_pub_socket = AsyncMock()
|
||||
|
||||
# --- Act ---
|
||||
agent = RICommunicationAgent(
|
||||
"test@server",
|
||||
"password",
|
||||
pub_socket=fake_pub_socket,
|
||||
address="tcp://localhost:5555",
|
||||
bind=False,
|
||||
)
|
||||
@@ -143,10 +143,14 @@ async def test_setup_creates_socket_and_negotiate_2(monkeypatch):
|
||||
fake_socket = MagicMock()
|
||||
fake_socket.send_json = AsyncMock()
|
||||
fake_socket.recv_json = fake_json_correct_negototiate_2()
|
||||
fake_socket.send_multipart = AsyncMock()
|
||||
|
||||
# Mock context.socket to return our fake socket
|
||||
fake_context = MagicMock()
|
||||
fake_context.socket.return_value = fake_socket
|
||||
monkeypatch.setattr(
|
||||
"control_backend.agents.ri_communication_agent.context.socket", lambda _: fake_socket
|
||||
"control_backend.agents.ri_communication_agent.Context",
|
||||
MagicMock(instance=MagicMock(return_value=fake_context)),
|
||||
)
|
||||
|
||||
# Mock RICommandAgent agent startup
|
||||
@@ -155,13 +159,11 @@ async def test_setup_creates_socket_and_negotiate_2(monkeypatch):
|
||||
) as MockCommandAgent:
|
||||
fake_agent_instance = MockCommandAgent.return_value
|
||||
fake_agent_instance.start = AsyncMock()
|
||||
fake_pub_socket = AsyncMock()
|
||||
|
||||
# --- Act ---
|
||||
agent = RICommunicationAgent(
|
||||
"test@server",
|
||||
"password",
|
||||
pub_socket=fake_pub_socket,
|
||||
address="tcp://localhost:5555",
|
||||
bind=False,
|
||||
)
|
||||
@@ -191,10 +193,14 @@ async def test_setup_creates_socket_and_negotiate_3(monkeypatch, caplog):
|
||||
fake_socket = MagicMock()
|
||||
fake_socket.send_json = AsyncMock()
|
||||
fake_socket.recv_json = fake_json_wrong_negototiate_1()
|
||||
fake_socket.send_multipart = AsyncMock()
|
||||
|
||||
# Mock context.socket to return our fake socket
|
||||
fake_context = MagicMock()
|
||||
fake_context.socket.return_value = fake_socket
|
||||
monkeypatch.setattr(
|
||||
"control_backend.agents.ri_communication_agent.context.socket", lambda _: fake_socket
|
||||
"control_backend.agents.ri_communication_agent.Context",
|
||||
MagicMock(instance=MagicMock(return_value=fake_context)),
|
||||
)
|
||||
|
||||
# Mock RICommandAgent agent startup
|
||||
@@ -206,13 +212,11 @@ async def test_setup_creates_socket_and_negotiate_3(monkeypatch, caplog):
|
||||
) as MockCommandAgent:
|
||||
fake_agent_instance = MockCommandAgent.return_value
|
||||
fake_agent_instance.start = AsyncMock()
|
||||
fake_pub_socket = AsyncMock()
|
||||
# --- Act ---
|
||||
with caplog.at_level("ERROR"):
|
||||
agent = RICommunicationAgent(
|
||||
"test@server",
|
||||
"password",
|
||||
pub_socket=fake_pub_socket,
|
||||
address="tcp://localhost:5555",
|
||||
bind=False,
|
||||
)
|
||||
@@ -239,10 +243,14 @@ async def test_setup_creates_socket_and_negotiate_4(monkeypatch):
|
||||
fake_socket = MagicMock()
|
||||
fake_socket.send_json = AsyncMock()
|
||||
fake_socket.recv_json = fake_json_correct_negototiate_3()
|
||||
fake_socket.send_multipart = AsyncMock()
|
||||
|
||||
# Mock context.socket to return our fake socket
|
||||
fake_context = MagicMock()
|
||||
fake_context.socket.return_value = fake_socket
|
||||
monkeypatch.setattr(
|
||||
"control_backend.agents.ri_communication_agent.context.socket", lambda _: fake_socket
|
||||
"control_backend.agents.ri_communication_agent.Context",
|
||||
MagicMock(instance=MagicMock(return_value=fake_context)),
|
||||
)
|
||||
|
||||
# Mock RICommandAgent agent startup
|
||||
@@ -251,12 +259,10 @@ async def test_setup_creates_socket_and_negotiate_4(monkeypatch):
|
||||
) as MockCommandAgent:
|
||||
fake_agent_instance = MockCommandAgent.return_value
|
||||
fake_agent_instance.start = AsyncMock()
|
||||
fake_pub_socket = AsyncMock()
|
||||
# --- Act ---
|
||||
agent = RICommunicationAgent(
|
||||
"test@server",
|
||||
"password",
|
||||
pub_socket=fake_pub_socket,
|
||||
address="tcp://localhost:5555",
|
||||
bind=True,
|
||||
)
|
||||
@@ -286,10 +292,14 @@ async def test_setup_creates_socket_and_negotiate_5(monkeypatch):
|
||||
fake_socket = MagicMock()
|
||||
fake_socket.send_json = AsyncMock()
|
||||
fake_socket.recv_json = fake_json_correct_negototiate_4()
|
||||
fake_socket.send_multipart = AsyncMock()
|
||||
|
||||
# Mock context.socket to return our fake socket
|
||||
fake_context = MagicMock()
|
||||
fake_context.socket.return_value = fake_socket
|
||||
monkeypatch.setattr(
|
||||
"control_backend.agents.ri_communication_agent.context.socket", lambda _: fake_socket
|
||||
"control_backend.agents.ri_communication_agent.Context",
|
||||
MagicMock(instance=MagicMock(return_value=fake_context)),
|
||||
)
|
||||
|
||||
# Mock RICommandAgent agent startup
|
||||
@@ -298,12 +308,10 @@ async def test_setup_creates_socket_and_negotiate_5(monkeypatch):
|
||||
) as MockCommandAgent:
|
||||
fake_agent_instance = MockCommandAgent.return_value
|
||||
fake_agent_instance.start = AsyncMock()
|
||||
fake_pub_socket = AsyncMock()
|
||||
# --- Act ---
|
||||
agent = RICommunicationAgent(
|
||||
"test@server",
|
||||
"password",
|
||||
pub_socket=fake_pub_socket,
|
||||
address="tcp://localhost:5555",
|
||||
bind=False,
|
||||
)
|
||||
@@ -333,10 +341,14 @@ async def test_setup_creates_socket_and_negotiate_6(monkeypatch):
|
||||
fake_socket = MagicMock()
|
||||
fake_socket.send_json = AsyncMock()
|
||||
fake_socket.recv_json = fake_json_correct_negototiate_5()
|
||||
fake_socket.send_multipart = AsyncMock()
|
||||
|
||||
# Mock context.socket to return our fake socket
|
||||
fake_context = MagicMock()
|
||||
fake_context.socket.return_value = fake_socket
|
||||
monkeypatch.setattr(
|
||||
"control_backend.agents.ri_communication_agent.context.socket", lambda _: fake_socket
|
||||
"control_backend.agents.ri_communication_agent.Context",
|
||||
MagicMock(instance=MagicMock(return_value=fake_context)),
|
||||
)
|
||||
|
||||
# Mock RICommandAgent agent startup
|
||||
@@ -345,12 +357,10 @@ async def test_setup_creates_socket_and_negotiate_6(monkeypatch):
|
||||
) as MockCommandAgent:
|
||||
fake_agent_instance = MockCommandAgent.return_value
|
||||
fake_agent_instance.start = AsyncMock()
|
||||
fake_pub_socket = AsyncMock()
|
||||
# --- Act ---
|
||||
agent = RICommunicationAgent(
|
||||
"test@server",
|
||||
"password",
|
||||
pub_socket=fake_pub_socket,
|
||||
address="tcp://localhost:5555",
|
||||
bind=False,
|
||||
)
|
||||
@@ -380,10 +390,14 @@ async def test_setup_creates_socket_and_negotiate_7(monkeypatch, caplog):
|
||||
fake_socket = MagicMock()
|
||||
fake_socket.send_json = AsyncMock()
|
||||
fake_socket.recv_json = fake_json_invalid_id_negototiate()
|
||||
fake_socket.send_multipart = AsyncMock()
|
||||
|
||||
# Mock context.socket to return our fake socket
|
||||
fake_context = MagicMock()
|
||||
fake_context.socket.return_value = fake_socket
|
||||
monkeypatch.setattr(
|
||||
"control_backend.agents.ri_communication_agent.context.socket", lambda _: fake_socket
|
||||
"control_backend.agents.ri_communication_agent.Context",
|
||||
MagicMock(instance=MagicMock(return_value=fake_context)),
|
||||
)
|
||||
|
||||
# Mock RICommandAgent agent startup
|
||||
@@ -395,14 +409,12 @@ async def test_setup_creates_socket_and_negotiate_7(monkeypatch, caplog):
|
||||
) as MockCommandAgent:
|
||||
fake_agent_instance = MockCommandAgent.return_value
|
||||
fake_agent_instance.start = AsyncMock()
|
||||
fake_pub_socket = AsyncMock()
|
||||
|
||||
# --- Act ---
|
||||
with caplog.at_level("WARNING"):
|
||||
agent = RICommunicationAgent(
|
||||
"test@server",
|
||||
"password",
|
||||
pub_socket=fake_pub_socket,
|
||||
address="tcp://localhost:5555",
|
||||
bind=False,
|
||||
)
|
||||
@@ -426,10 +438,14 @@ async def test_setup_creates_socket_and_negotiate_timeout(monkeypatch, caplog):
|
||||
fake_socket = MagicMock()
|
||||
fake_socket.send_json = AsyncMock()
|
||||
fake_socket.recv_json = AsyncMock(side_effect=asyncio.TimeoutError)
|
||||
fake_socket.send_multipart = AsyncMock()
|
||||
|
||||
# Mock context.socket to return our fake socket
|
||||
fake_context = MagicMock()
|
||||
fake_context.socket.return_value = fake_socket
|
||||
monkeypatch.setattr(
|
||||
"control_backend.agents.ri_communication_agent.context.socket", lambda _: fake_socket
|
||||
"control_backend.agents.ri_communication_agent.Context",
|
||||
MagicMock(instance=MagicMock(return_value=fake_context)),
|
||||
)
|
||||
|
||||
with patch(
|
||||
@@ -437,14 +453,12 @@ async def test_setup_creates_socket_and_negotiate_timeout(monkeypatch, caplog):
|
||||
) as MockCommandAgent:
|
||||
fake_agent_instance = MockCommandAgent.return_value
|
||||
fake_agent_instance.start = AsyncMock()
|
||||
fake_pub_socket = AsyncMock()
|
||||
|
||||
# --- Act ---
|
||||
with caplog.at_level("WARNING"):
|
||||
agent = RICommunicationAgent(
|
||||
"test@server",
|
||||
"password",
|
||||
pub_socket=fake_pub_socket,
|
||||
address="tcp://localhost:5555",
|
||||
bind=False,
|
||||
)
|
||||
@@ -466,11 +480,11 @@ async def test_listen_behaviour_ping_correct(caplog):
|
||||
fake_socket = AsyncMock()
|
||||
fake_socket.send_json = AsyncMock()
|
||||
fake_socket.recv_json = AsyncMock(return_value={"endpoint": "ping", "data": {}})
|
||||
fake_pub_socket = AsyncMock()
|
||||
fake_socket.send_multipart = AsyncMock()
|
||||
|
||||
# TODO: Integration test between actual server and password needed for spade agents
|
||||
agent = RICommunicationAgent("test@server", "password", fake_pub_socket)
|
||||
agent.req_socket = fake_socket
|
||||
agent = RICommunicationAgent("test@server", "password")
|
||||
agent._req_socket = fake_socket
|
||||
|
||||
behaviour = agent.ListenBehaviour()
|
||||
agent.add_behaviour(behaviour)
|
||||
@@ -505,7 +519,7 @@ async def test_listen_behaviour_ping_wrong_endpoint(caplog):
|
||||
fake_pub_socket = AsyncMock()
|
||||
|
||||
agent = RICommunicationAgent("test@server", "password", fake_pub_socket)
|
||||
agent.req_socket = fake_socket
|
||||
agent._req_socket = fake_socket
|
||||
|
||||
behaviour = agent.ListenBehaviour()
|
||||
agent.add_behaviour(behaviour)
|
||||
@@ -525,10 +539,10 @@ async def test_listen_behaviour_timeout(caplog):
|
||||
fake_socket.send_json = AsyncMock()
|
||||
# recv_json will never resolve, simulate timeout
|
||||
fake_socket.recv_json = AsyncMock(side_effect=asyncio.TimeoutError)
|
||||
fake_pub_socket = AsyncMock()
|
||||
fake_socket.send_multipart = AsyncMock()
|
||||
|
||||
agent = RICommunicationAgent("test@server", "password", fake_pub_socket)
|
||||
agent.req_socket = fake_socket
|
||||
agent = RICommunicationAgent("test@server", "password")
|
||||
agent._req_socket = fake_socket
|
||||
|
||||
behaviour = agent.ListenBehaviour()
|
||||
agent.add_behaviour(behaviour)
|
||||
@@ -546,6 +560,7 @@ async def test_listen_behaviour_ping_no_endpoint(caplog):
|
||||
"""
|
||||
fake_socket = AsyncMock()
|
||||
fake_socket.send_json = AsyncMock()
|
||||
fake_socket.send_multipart = AsyncMock()
|
||||
|
||||
# This is a message without endpoint >:(
|
||||
fake_socket.recv_json = AsyncMock(
|
||||
@@ -553,10 +568,9 @@ async def test_listen_behaviour_ping_no_endpoint(caplog):
|
||||
"data": "I dont have an endpoint >:)",
|
||||
}
|
||||
)
|
||||
fake_pub_socket = AsyncMock()
|
||||
|
||||
agent = RICommunicationAgent("test@server", "password", fake_pub_socket)
|
||||
agent.req_socket = fake_socket
|
||||
agent = RICommunicationAgent("test@server", "password")
|
||||
agent._req_socket = fake_socket
|
||||
|
||||
behaviour = agent.ListenBehaviour()
|
||||
agent.add_behaviour(behaviour)
|
||||
@@ -574,18 +588,20 @@ async def test_listen_behaviour_ping_no_endpoint(caplog):
|
||||
async def test_setup_unexpected_exception(monkeypatch, caplog):
|
||||
fake_socket = MagicMock()
|
||||
fake_socket.send_json = AsyncMock()
|
||||
fake_pub_socket = AsyncMock()
|
||||
# Simulate unexpected exception during recv_json()
|
||||
fake_socket.recv_json = AsyncMock(side_effect=Exception("boom!"))
|
||||
fake_socket.send_multipart = AsyncMock()
|
||||
|
||||
fake_context = MagicMock()
|
||||
fake_context.socket.return_value = fake_socket
|
||||
monkeypatch.setattr(
|
||||
"control_backend.agents.ri_communication_agent.context.socket", lambda _: fake_socket
|
||||
"control_backend.agents.ri_communication_agent.Context",
|
||||
MagicMock(instance=MagicMock(return_value=fake_context)),
|
||||
)
|
||||
|
||||
agent = RICommunicationAgent(
|
||||
"test@server",
|
||||
"password",
|
||||
pub_socket=fake_pub_socket,
|
||||
address="tcp://localhost:5555",
|
||||
bind=False,
|
||||
)
|
||||
@@ -602,6 +618,7 @@ async def test_setup_unpacking_exception(monkeypatch, caplog):
|
||||
# --- Arrange ---
|
||||
fake_socket = MagicMock()
|
||||
fake_socket.send_json = AsyncMock()
|
||||
fake_socket.send_multipart = AsyncMock()
|
||||
|
||||
# Make recv_json return malformed negotiation data to trigger unpacking exception
|
||||
malformed_data = {
|
||||
@@ -611,8 +628,11 @@ async def test_setup_unpacking_exception(monkeypatch, caplog):
|
||||
fake_socket.recv_json = AsyncMock(return_value=malformed_data)
|
||||
|
||||
# Patch context.socket
|
||||
fake_context = MagicMock()
|
||||
fake_context.socket.return_value = fake_socket
|
||||
monkeypatch.setattr(
|
||||
"control_backend.agents.ri_communication_agent.context.socket", lambda _: fake_socket
|
||||
"control_backend.agents.ri_communication_agent.Context",
|
||||
MagicMock(instance=MagicMock(return_value=fake_context)),
|
||||
)
|
||||
|
||||
# Patch RICommandAgent so it won't actually start
|
||||
@@ -621,12 +641,10 @@ async def test_setup_unpacking_exception(monkeypatch, caplog):
|
||||
) as MockCommandAgent:
|
||||
fake_agent_instance = MockCommandAgent.return_value
|
||||
fake_agent_instance.start = AsyncMock()
|
||||
fake_pub_socket = AsyncMock()
|
||||
|
||||
agent = RICommunicationAgent(
|
||||
"test@server",
|
||||
"password",
|
||||
pub_socket=fake_pub_socket,
|
||||
address="tcp://localhost:5555",
|
||||
bind=False,
|
||||
)
|
||||
|
||||
Binary file not shown.
108
test/integration/agents/vad_agent/test_vad_agent.py
Normal file
108
test/integration/agents/vad_agent/test_vad_agent.py
Normal file
@@ -0,0 +1,108 @@
|
||||
import random
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
import zmq
|
||||
from spade.agent import Agent
|
||||
|
||||
from control_backend.agents.vad_agent import VADAgent
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def zmq_context(mocker):
|
||||
return mocker.patch("control_backend.agents.vad_agent.zmq_context")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def streaming(mocker):
|
||||
return mocker.patch("control_backend.agents.vad_agent.Streaming")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def transcription_agent(mocker):
|
||||
return mocker.patch("control_backend.agents.vad_agent.TranscriptionAgent", autospec=True)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_normal_setup(streaming, transcription_agent):
|
||||
"""
|
||||
Test that during normal setup, the VAD agent creates a Streaming behavior and creates audio
|
||||
sockets, and starts the TranscriptionAgent without loading real models.
|
||||
"""
|
||||
vad_agent = VADAgent("tcp://localhost:12345", False)
|
||||
vad_agent.add_behaviour = MagicMock()
|
||||
|
||||
await vad_agent.setup()
|
||||
|
||||
streaming.assert_called_once()
|
||||
vad_agent.add_behaviour.assert_called_once_with(streaming.return_value)
|
||||
transcription_agent.assert_called_once()
|
||||
transcription_agent.return_value.start.assert_called_once()
|
||||
assert vad_agent.audio_in_socket is not None
|
||||
assert vad_agent.audio_out_socket is not None
|
||||
|
||||
|
||||
@pytest.mark.parametrize("do_bind", [True, False])
|
||||
def test_in_socket_creation(zmq_context, do_bind: bool):
|
||||
"""
|
||||
Test that the VAD agent creates an audio input socket, differentiating between binding and
|
||||
connecting.
|
||||
"""
|
||||
vad_agent = VADAgent(f"tcp://{'*' if do_bind else 'localhost'}:12345", do_bind)
|
||||
|
||||
vad_agent._connect_audio_in_socket()
|
||||
|
||||
assert vad_agent.audio_in_socket is not None
|
||||
|
||||
zmq_context.socket.assert_called_once_with(zmq.SUB)
|
||||
zmq_context.socket.return_value.setsockopt_string.assert_called_once_with(zmq.SUBSCRIBE, "")
|
||||
|
||||
if do_bind:
|
||||
zmq_context.socket.return_value.bind.assert_called_once_with("tcp://*:12345")
|
||||
else:
|
||||
zmq_context.socket.return_value.connect.assert_called_once_with("tcp://localhost:12345")
|
||||
|
||||
|
||||
def test_out_socket_creation(zmq_context):
|
||||
"""
|
||||
Test that the VAD agent creates an audio output socket correctly.
|
||||
"""
|
||||
vad_agent = VADAgent("tcp://localhost:12345", False)
|
||||
|
||||
vad_agent._connect_audio_out_socket()
|
||||
|
||||
assert vad_agent.audio_out_socket is not None
|
||||
|
||||
zmq_context.socket.assert_called_once_with(zmq.PUB)
|
||||
zmq_context.socket.return_value.bind_to_random_port.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_out_socket_creation_failure(zmq_context):
|
||||
"""
|
||||
Test setup failure when the audio output socket cannot be created.
|
||||
"""
|
||||
with patch.object(Agent, "stop", new_callable=AsyncMock) as mock_super_stop:
|
||||
zmq_context.socket.return_value.bind_to_random_port.side_effect = zmq.ZMQBindError
|
||||
vad_agent = VADAgent("tcp://localhost:12345", False)
|
||||
|
||||
await vad_agent.setup()
|
||||
|
||||
assert vad_agent.audio_out_socket is None
|
||||
mock_super_stop.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stop(zmq_context, transcription_agent):
|
||||
"""
|
||||
Test that when the VAD agent is stopped, the sockets are closed correctly.
|
||||
"""
|
||||
vad_agent = VADAgent("tcp://localhost:12345", False)
|
||||
zmq_context.socket.return_value.bind_to_random_port.return_value = random.randint(1000, 10000)
|
||||
|
||||
await vad_agent.setup()
|
||||
await vad_agent.stop()
|
||||
|
||||
assert zmq_context.socket.return_value.close.call_count == 2
|
||||
assert vad_agent.audio_in_socket is None
|
||||
assert vad_agent.audio_out_socket is None
|
||||
57
test/integration/agents/vad_agent/test_vad_with_audio.py
Normal file
57
test/integration/agents/vad_agent/test_vad_with_audio.py
Normal file
@@ -0,0 +1,57 @@
|
||||
import os
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
import soundfile as sf
|
||||
import zmq
|
||||
|
||||
from control_backend.agents.vad_agent import Streaming
|
||||
|
||||
|
||||
def get_audio_chunks() -> list[bytes]:
|
||||
curr_file = os.path.realpath(__file__)
|
||||
curr_dir = os.path.dirname(curr_file)
|
||||
file = f"{curr_dir}/speech_with_pauses_16k_1c_float32.wav"
|
||||
|
||||
chunk_size = 512
|
||||
|
||||
chunks = []
|
||||
|
||||
with sf.SoundFile(file, "r") as f:
|
||||
assert f.samplerate == 16000
|
||||
assert f.channels == 1
|
||||
assert f.subtype == "FLOAT"
|
||||
|
||||
while True:
|
||||
data = f.read(chunk_size, dtype="float32")
|
||||
if len(data) != chunk_size:
|
||||
break
|
||||
|
||||
chunks.append(data.tobytes())
|
||||
|
||||
return chunks
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_real_audio(mocker):
|
||||
"""
|
||||
Test the VAD agent with only input and output mocked. Using the real model, using real audio as
|
||||
input. Ensure that it outputs some fragments with audio.
|
||||
"""
|
||||
audio_chunks = get_audio_chunks()
|
||||
audio_in_socket = AsyncMock()
|
||||
audio_in_socket.recv.side_effect = audio_chunks
|
||||
|
||||
mock_poller: MagicMock = mocker.patch("control_backend.agents.vad_agent.zmq.Poller")
|
||||
mock_poller.return_value.poll.return_value = [(audio_in_socket, zmq.POLLIN)]
|
||||
|
||||
audio_out_socket = AsyncMock()
|
||||
|
||||
vad_streamer = Streaming(audio_in_socket, audio_out_socket)
|
||||
for _ in audio_chunks:
|
||||
await vad_streamer.run()
|
||||
|
||||
audio_out_socket.send.assert_called()
|
||||
for args in audio_out_socket.send.call_args_list:
|
||||
assert isinstance(args[0][0], bytes)
|
||||
assert len(args[0][0]) >= 512 * 4 * 3 # Should be at least 3 chunks of audio
|
||||
@@ -4,7 +4,7 @@ from unittest.mock import AsyncMock, MagicMock, call
|
||||
|
||||
import pytest
|
||||
|
||||
from control_backend.agents.bdi.behaviours.belief_setter import BeliefSetter
|
||||
from control_backend.agents.bdi.behaviours.belief_setter import BeliefSetterBehaviour
|
||||
|
||||
# Define a constant for the collector agent name to use in tests
|
||||
COLLECTOR_AGENT_NAME = "belief_collector"
|
||||
@@ -22,16 +22,14 @@ def mock_agent(mocker):
|
||||
|
||||
@pytest.fixture
|
||||
def belief_setter(mock_agent, mocker):
|
||||
"""Fixture to create an instance of BeliefSetter with a mocked agent."""
|
||||
"""Fixture to create an instance of BeliefSetterBehaviour with a mocked agent."""
|
||||
# Patch the settings to use a predictable agent name
|
||||
mocker.patch(
|
||||
"control_backend.agents.bdi.behaviours.belief_setter.settings.agent_settings.belief_collector_agent_name",
|
||||
COLLECTOR_AGENT_NAME,
|
||||
)
|
||||
# Patch asyncio.sleep to prevent tests from actually waiting
|
||||
mocker.patch("asyncio.sleep", return_value=None)
|
||||
|
||||
setter = BeliefSetter()
|
||||
setter = BeliefSetterBehaviour()
|
||||
setter.agent = mock_agent
|
||||
# Mock the receive method, we will control its return value in each test
|
||||
setter.receive = AsyncMock()
|
||||
@@ -115,7 +113,7 @@ def test_process_belief_message_valid_json(belief_setter, mocker):
|
||||
Test processing a valid belief message with correct thread and JSON body.
|
||||
"""
|
||||
# Arrange
|
||||
beliefs_payload = {"is_hot": [["kitchen"]], "is_clean": [["kitchen"], ["bathroom"]]}
|
||||
beliefs_payload = {"is_hot": ["kitchen"], "is_clean": ["kitchen", "bathroom"]}
|
||||
msg = create_mock_message(
|
||||
sender_node=COLLECTOR_AGENT_JID, body=json.dumps(beliefs_payload), thread="beliefs"
|
||||
)
|
||||
@@ -185,8 +183,8 @@ def test_set_beliefs_success(belief_setter, mock_agent, caplog):
|
||||
"""
|
||||
# Arrange
|
||||
beliefs_to_set = {
|
||||
"is_hot": [["kitchen"], ["living_room"]],
|
||||
"door_is": [["front_door", "closed"]],
|
||||
"is_hot": ["kitchen"],
|
||||
"door_opened": ["front_door", "back_door"],
|
||||
}
|
||||
|
||||
# Act
|
||||
@@ -196,29 +194,38 @@ def test_set_beliefs_success(belief_setter, mock_agent, caplog):
|
||||
# Assert
|
||||
expected_calls = [
|
||||
call("is_hot", "kitchen"),
|
||||
call("is_hot", "living_room"),
|
||||
call("door_is", "front_door", "closed"),
|
||||
call("door_opened", "front_door", "back_door"),
|
||||
]
|
||||
mock_agent.bdi.set_belief.assert_has_calls(expected_calls, any_order=True)
|
||||
assert mock_agent.bdi.set_belief.call_count == 3
|
||||
assert mock_agent.bdi.set_belief.call_count == 2
|
||||
|
||||
# Check logs
|
||||
assert "Set belief is_hot with arguments ['kitchen']" in caplog.text
|
||||
assert "Set belief is_hot with arguments ['living_room']" in caplog.text
|
||||
assert "Set belief door_is with arguments ['front_door', 'closed']" in caplog.text
|
||||
assert "Set belief door_opened with arguments ['front_door', 'back_door']" in caplog.text
|
||||
|
||||
|
||||
def test_set_beliefs_bdi_not_initialized(belief_setter, mock_agent, caplog):
|
||||
"""
|
||||
Test that a warning is logged if the agent's BDI is not initialized.
|
||||
"""
|
||||
# Arrange
|
||||
mock_agent.bdi = None # Simulate BDI not being ready
|
||||
beliefs_to_set = {"is_hot": [["kitchen"]]}
|
||||
# def test_responded_unset(belief_setter, mock_agent):
|
||||
# # Arrange
|
||||
# new_beliefs = {"user_said": ["message"]}
|
||||
#
|
||||
# # Act
|
||||
# belief_setter._set_beliefs(new_beliefs)
|
||||
#
|
||||
# # Assert
|
||||
# mock_agent.bdi.set_belief.assert_has_calls([call("user_said", "message")])
|
||||
# mock_agent.bdi.remove_belief.assert_has_calls([call("responded")])
|
||||
|
||||
# Act
|
||||
with caplog.at_level(logging.WARNING):
|
||||
belief_setter._set_beliefs(beliefs_to_set)
|
||||
|
||||
# Assert
|
||||
assert "Cannot set beliefs, since agent's BDI is not yet initialized." in caplog.text
|
||||
# def test_set_beliefs_bdi_not_initialized(belief_setter, mock_agent, caplog):
|
||||
# """
|
||||
# Test that a warning is logged if the agent's BDI is not initialized.
|
||||
# """
|
||||
# # Arrange
|
||||
# mock_agent.bdi = None # Simulate BDI not being ready
|
||||
# beliefs_to_set = {"is_hot": ["kitchen"]}
|
||||
#
|
||||
# # Act
|
||||
# with caplog.at_level(logging.WARNING):
|
||||
# belief_setter._set_beliefs(beliefs_to_set)
|
||||
#
|
||||
# # Assert
|
||||
# assert "Cannot set beliefs, since agent's BDI is not yet initialized." in caplog.text
|
||||
|
||||
@@ -0,0 +1,242 @@
|
||||
import json
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from control_backend.agents.belief_collector.behaviours.continuous_collect import (
|
||||
ContinuousBeliefCollector,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_agent(mocker):
|
||||
"""Fixture to create a mock Agent."""
|
||||
agent = MagicMock()
|
||||
agent.jid = "belief_collector_agent@test"
|
||||
return agent
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def continuous_collector(mock_agent, mocker):
|
||||
"""Fixture to create an instance of ContinuousBeliefCollector with a mocked agent."""
|
||||
# Patch asyncio.sleep to prevent tests from actually waiting
|
||||
mocker.patch("asyncio.sleep", return_value=None)
|
||||
|
||||
collector = ContinuousBeliefCollector()
|
||||
collector.agent = mock_agent
|
||||
# Mock the receive method, we will control its return value in each test
|
||||
collector.receive = AsyncMock()
|
||||
return collector
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_no_message_received(continuous_collector, mocker):
|
||||
"""
|
||||
Test that when no message is received, _process_message is not called.
|
||||
"""
|
||||
# Arrange
|
||||
continuous_collector.receive.return_value = None
|
||||
mocker.patch.object(continuous_collector, "_process_message")
|
||||
|
||||
# Act
|
||||
await continuous_collector.run()
|
||||
|
||||
# Assert
|
||||
continuous_collector._process_message.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_message_received(continuous_collector, mocker):
|
||||
"""
|
||||
Test that when a message is received, _process_message is called with that message.
|
||||
"""
|
||||
# Arrange
|
||||
mock_msg = MagicMock()
|
||||
continuous_collector.receive.return_value = mock_msg
|
||||
mocker.patch.object(continuous_collector, "_process_message")
|
||||
|
||||
# Act
|
||||
await continuous_collector.run()
|
||||
|
||||
# Assert
|
||||
continuous_collector._process_message.assert_awaited_once_with(mock_msg)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_message_invalid(continuous_collector, mocker):
|
||||
"""
|
||||
Test that when an invalid JSON message is received, a warning is logged and processing stops.
|
||||
"""
|
||||
# Arrange
|
||||
invalid_json = "this is not json"
|
||||
msg = MagicMock()
|
||||
msg.body = invalid_json
|
||||
msg.sender = "belief_text_agent_mock@test"
|
||||
|
||||
logger_mock = mocker.patch(
|
||||
"control_backend.agents.belief_collector.behaviours.continuous_collect.logger"
|
||||
)
|
||||
|
||||
# Act
|
||||
await continuous_collector._process_message(msg)
|
||||
|
||||
# Assert
|
||||
logger_mock.warning.assert_called_once()
|
||||
|
||||
|
||||
def test_get_sender_from_message(continuous_collector):
|
||||
"""
|
||||
Test that _sender_node correctly extracts the sender node from the message JID.
|
||||
"""
|
||||
# Arrange
|
||||
msg = MagicMock()
|
||||
msg.sender = "agent_node@host/resource"
|
||||
|
||||
# Act
|
||||
sender_node = continuous_collector._sender_node(msg)
|
||||
|
||||
# Assert
|
||||
assert sender_node == "agent_node"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_routes_to_handle_belief_text_by_type(continuous_collector, mocker):
|
||||
msg = MagicMock()
|
||||
msg.body = json.dumps({"type": "belief_extraction_text", "beliefs": {"user_said": [["hi"]]}})
|
||||
msg.sender = "anyone@test"
|
||||
spy = mocker.patch.object(continuous_collector, "_handle_belief_text", new=AsyncMock())
|
||||
await continuous_collector._process_message(msg)
|
||||
spy.assert_awaited_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_routes_to_handle_belief_text_by_sender(continuous_collector, mocker):
|
||||
msg = MagicMock()
|
||||
msg.body = json.dumps({"beliefs": {"user_said": [["hi"]]}}) # no type
|
||||
msg.sender = "belief_text_agent_mock@test"
|
||||
spy = mocker.patch.object(continuous_collector, "_handle_belief_text", new=AsyncMock())
|
||||
await continuous_collector._process_message(msg)
|
||||
spy.assert_awaited_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_routes_to_handle_emo_text(continuous_collector, mocker):
|
||||
msg = MagicMock()
|
||||
msg.body = json.dumps({"type": "emotion_extraction_text"})
|
||||
msg.sender = "anyone@test"
|
||||
spy = mocker.patch.object(continuous_collector, "_handle_emo_text", new=AsyncMock())
|
||||
await continuous_collector._process_message(msg)
|
||||
spy.assert_awaited_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unrecognized_message_logs_info(continuous_collector, mocker):
|
||||
msg = MagicMock()
|
||||
msg.body = json.dumps({"type": "something_else"})
|
||||
msg.sender = "x@test"
|
||||
logger_mock = mocker.patch(
|
||||
"control_backend.agents.belief_collector.behaviours.continuous_collect.logger"
|
||||
)
|
||||
await continuous_collector._process_message(msg)
|
||||
logger_mock.info.assert_any_call(
|
||||
"BeliefCollector: unrecognized message (sender=%s, type=%r). Ignoring.",
|
||||
"x",
|
||||
"something_else",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_belief_text_no_beliefs(continuous_collector, mocker):
|
||||
msg_payload = {"type": "belief_extraction_text"} # no 'beliefs'
|
||||
logger_mock = mocker.patch(
|
||||
"control_backend.agents.belief_collector.behaviours.continuous_collect.logger"
|
||||
)
|
||||
await continuous_collector._handle_belief_text(msg_payload, "origin_node")
|
||||
logger_mock.info.assert_any_call("BeliefCollector: no beliefs to process.")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_belief_text_beliefs_not_dict(continuous_collector, mocker):
|
||||
payload = {"type": "belief_extraction_text", "beliefs": ["not", "a", "dict"]}
|
||||
logger_mock = mocker.patch(
|
||||
"control_backend.agents.belief_collector.behaviours.continuous_collect.logger"
|
||||
)
|
||||
await continuous_collector._handle_belief_text(payload, "origin")
|
||||
logger_mock.warning.assert_any_call(
|
||||
"BeliefCollector: 'beliefs' is not a dict: %r", ["not", "a", "dict"]
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_belief_text_values_not_lists(continuous_collector, mocker):
|
||||
payload = {"type": "belief_extraction_text", "beliefs": {"user_said": "not-a-list"}}
|
||||
logger_mock = mocker.patch(
|
||||
"control_backend.agents.belief_collector.behaviours.continuous_collect.logger"
|
||||
)
|
||||
await continuous_collector._handle_belief_text(payload, "origin")
|
||||
logger_mock.warning.assert_any_call(
|
||||
"BeliefCollector: 'beliefs' values are not all lists: %r", {"user_said": "not-a-list"}
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_belief_text_happy_path_logs_items_and_sends(continuous_collector, mocker):
|
||||
payload = {"type": "belief_extraction_text", "beliefs": {"user_said": ["hello test", "No"]}}
|
||||
# Your code calls self.send(..); patch it
|
||||
# (or switch implementation to self.agent.send and patch that)
|
||||
continuous_collector.send = AsyncMock()
|
||||
logger_mock = mocker.patch(
|
||||
"control_backend.agents.belief_collector.behaviours.continuous_collect.logger"
|
||||
)
|
||||
await continuous_collector._handle_belief_text(payload, "belief_text_agent_mock")
|
||||
|
||||
logger_mock.info.assert_any_call("BeliefCollector: forwarding %d beliefs.", 1)
|
||||
# and the item logs:
|
||||
logger_mock.info.assert_any_call(" - %s %s", "user_said", "hello test")
|
||||
logger_mock.info.assert_any_call(" - %s %s", "user_said", "No")
|
||||
# make sure we attempted a send
|
||||
continuous_collector.send.assert_awaited_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_beliefs_noop_on_empty(continuous_collector):
|
||||
continuous_collector.send = AsyncMock()
|
||||
await continuous_collector._send_beliefs_to_bdi([], origin="o")
|
||||
continuous_collector.send.assert_not_awaited()
|
||||
|
||||
|
||||
# @pytest.mark.asyncio
|
||||
# async def test_send_beliefs_sends_json_packet(continuous_collector):
|
||||
# # Patch .send and capture the message body
|
||||
# sent = {}
|
||||
#
|
||||
# async def _fake_send(msg):
|
||||
# sent["body"] = msg.body
|
||||
# sent["to"] = str(msg.to)
|
||||
#
|
||||
# continuous_collector.send = AsyncMock(side_effect=_fake_send)
|
||||
# beliefs = ["user_said hello", "user_said No"]
|
||||
# await continuous_collector._send_beliefs_to_bdi(beliefs, origin="origin_node")
|
||||
#
|
||||
# assert "belief_packet" in json.loads(sent["body"])["type"]
|
||||
# assert json.loads(sent["body"])["beliefs"] == beliefs
|
||||
|
||||
|
||||
def test_sender_node_no_sender_returns_literal(continuous_collector):
|
||||
msg = MagicMock()
|
||||
msg.sender = None
|
||||
assert continuous_collector._sender_node(msg) == "no_sender"
|
||||
|
||||
|
||||
def test_sender_node_without_at(continuous_collector):
|
||||
msg = MagicMock()
|
||||
msg.sender = "localpartonly"
|
||||
assert continuous_collector._sender_node(msg) == "localpartonly"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_belief_text_coerces_non_strings(continuous_collector, mocker):
|
||||
payload = {"type": "belief_extraction_text", "beliefs": {"user_said": [["hi", 123]]}}
|
||||
continuous_collector.send = AsyncMock()
|
||||
await continuous_collector._handle_belief_text(payload, "origin")
|
||||
continuous_collector.send.assert_awaited_once()
|
||||
46
test/unit/agents/test_vad_socket_poller.py
Normal file
46
test/unit/agents/test_vad_socket_poller.py
Normal file
@@ -0,0 +1,46 @@
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
import zmq
|
||||
|
||||
from control_backend.agents.vad_agent import SocketPoller
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def socket():
|
||||
return AsyncMock()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_socket_poller_with_data(socket, mocker):
|
||||
socket_data = b"test"
|
||||
socket.recv.return_value = socket_data
|
||||
|
||||
mock_poller: MagicMock = mocker.patch("control_backend.agents.vad_agent.zmq.Poller")
|
||||
mock_poller.return_value.poll.return_value = [(socket, zmq.POLLIN)]
|
||||
|
||||
poller = SocketPoller(socket)
|
||||
# Calling `poll` twice to be able to check that the poller is reused
|
||||
await poller.poll()
|
||||
data = await poller.poll()
|
||||
|
||||
assert data == socket_data
|
||||
|
||||
# Ensure that the poller was reused
|
||||
mock_poller.assert_called_once_with()
|
||||
mock_poller.return_value.register.assert_called_once_with(socket, zmq.POLLIN)
|
||||
|
||||
assert socket.recv.call_count == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_socket_poller_no_data(socket, mocker):
|
||||
mock_poller: MagicMock = mocker.patch("control_backend.agents.vad_agent.zmq.Poller")
|
||||
mock_poller.return_value.poll.return_value = []
|
||||
|
||||
poller = SocketPoller(socket)
|
||||
data = await poller.poll()
|
||||
|
||||
assert data is None
|
||||
|
||||
socket.recv.assert_not_called()
|
||||
95
test/unit/agents/test_vad_streaming.py
Normal file
95
test/unit/agents/test_vad_streaming.py
Normal file
@@ -0,0 +1,95 @@
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from control_backend.agents.vad_agent import Streaming
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def audio_in_socket():
|
||||
return AsyncMock()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def audio_out_socket():
|
||||
return AsyncMock()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def streaming(audio_in_socket, audio_out_socket):
|
||||
import torch
|
||||
|
||||
torch.hub.load.return_value = (..., ...) # Mock
|
||||
return Streaming(audio_in_socket, audio_out_socket)
|
||||
|
||||
|
||||
async def simulate_streaming_with_probabilities(streaming, probabilities: list[float]):
|
||||
"""
|
||||
Simulates a streaming scenario with given VAD model probabilities for testing purposes.
|
||||
|
||||
:param streaming: The streaming component to be tested.
|
||||
:param probabilities: A list of probabilities representing the outputs of the VAD model.
|
||||
"""
|
||||
model_item = MagicMock()
|
||||
model_item.item.side_effect = probabilities
|
||||
streaming.model = MagicMock()
|
||||
streaming.model.return_value = model_item
|
||||
|
||||
audio_in_poller = AsyncMock()
|
||||
audio_in_poller.poll.return_value = np.empty(shape=512, dtype=np.float32)
|
||||
streaming.audio_in_poller = audio_in_poller
|
||||
|
||||
for _ in probabilities:
|
||||
await streaming.run()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_voice_activity_detected(audio_in_socket, audio_out_socket, streaming):
|
||||
"""
|
||||
Test a scenario where there is voice activity detected between silences.
|
||||
:return:
|
||||
"""
|
||||
speech_chunk_count = 5
|
||||
probabilities = [0.0] * 5 + [1.0] * speech_chunk_count + [0.0] * 5
|
||||
await simulate_streaming_with_probabilities(streaming, probabilities)
|
||||
|
||||
audio_out_socket.send.assert_called_once()
|
||||
data = audio_out_socket.send.call_args[0][0]
|
||||
assert isinstance(data, bytes)
|
||||
# each sample has 512 frames of 4 bytes, expecting 7 chunks (5 with speech, 2 as padding)
|
||||
assert len(data) == 512 * 4 * (speech_chunk_count + 2)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_voice_activity_short_pause(audio_in_socket, audio_out_socket, streaming):
|
||||
"""
|
||||
Test a scenario where there is a short pause between speech, checking whether it ignores the
|
||||
short pause.
|
||||
"""
|
||||
speech_chunk_count = 5
|
||||
probabilities = (
|
||||
[0.0] * 5 + [1.0] * speech_chunk_count + [0.0] + [1.0] * speech_chunk_count + [0.0] * 5
|
||||
)
|
||||
await simulate_streaming_with_probabilities(streaming, probabilities)
|
||||
|
||||
audio_out_socket.send.assert_called_once()
|
||||
data = audio_out_socket.send.call_args[0][0]
|
||||
assert isinstance(data, bytes)
|
||||
# Expecting 13 chunks (2*5 with speech, 1 pause between, 2 as padding)
|
||||
assert len(data) == 512 * 4 * (speech_chunk_count * 2 + 1 + 2)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_data(audio_in_socket, audio_out_socket, streaming):
|
||||
"""
|
||||
Test a scenario where there is no data received. This should not cause errors.
|
||||
"""
|
||||
audio_in_poller = AsyncMock()
|
||||
audio_in_poller.poll.return_value = None
|
||||
streaming.audio_in_poller = audio_in_poller
|
||||
|
||||
await streaming.run()
|
||||
|
||||
audio_out_socket.send.assert_not_called()
|
||||
assert len(streaming.audio_buffer) == 0
|
||||
36
test/unit/agents/transcription/test_speech_recognizer.py
Normal file
36
test/unit/agents/transcription/test_speech_recognizer.py
Normal file
@@ -0,0 +1,36 @@
|
||||
import numpy as np
|
||||
|
||||
from control_backend.agents.transcription import SpeechRecognizer
|
||||
from control_backend.agents.transcription.speech_recognizer import OpenAIWhisperSpeechRecognizer
|
||||
|
||||
|
||||
def test_estimate_max_tokens():
|
||||
"""Inputting one minute of audio, assuming 300 words per minute, expecting 400 tokens."""
|
||||
audio = np.empty(shape=(60 * 16_000), dtype=np.float32)
|
||||
|
||||
actual = SpeechRecognizer._estimate_max_tokens(audio)
|
||||
|
||||
assert actual == 400
|
||||
assert isinstance(actual, int)
|
||||
|
||||
|
||||
def test_get_decode_options():
|
||||
"""Check whether the right decode options are given under different scenarios."""
|
||||
audio = np.empty(shape=(60 * 16_000), dtype=np.float32)
|
||||
|
||||
# With the defaults, it should limit output length based on input size
|
||||
recognizer = OpenAIWhisperSpeechRecognizer()
|
||||
options = recognizer._get_decode_options(audio)
|
||||
|
||||
assert "sample_len" in options
|
||||
assert isinstance(options["sample_len"], int)
|
||||
|
||||
# When explicitly enabled, it should limit output length based on input size
|
||||
recognizer = OpenAIWhisperSpeechRecognizer(limit_output_length=True)
|
||||
options = recognizer._get_decode_options(audio)
|
||||
|
||||
assert "sample_len" in options
|
||||
assert isinstance(options["sample_len"], int)
|
||||
|
||||
# When disabled, it should not limit output length based on input size
|
||||
assert "sample_rate" not in options
|
||||
@@ -11,6 +11,7 @@ def pytest_configure(config):
|
||||
mock_spade = MagicMock()
|
||||
mock_spade.agent = MagicMock()
|
||||
mock_spade.behaviour = MagicMock()
|
||||
mock_spade.message = MagicMock()
|
||||
mock_spade_bdi = MagicMock()
|
||||
mock_spade_bdi.bdi = MagicMock()
|
||||
|
||||
@@ -21,6 +22,7 @@ def pytest_configure(config):
|
||||
sys.modules["spade"] = mock_spade
|
||||
sys.modules["spade.agent"] = mock_spade.agent
|
||||
sys.modules["spade.behaviour"] = mock_spade.behaviour
|
||||
sys.modules["spade.message"] = mock_spade.message
|
||||
sys.modules["spade_bdi"] = mock_spade_bdi
|
||||
sys.modules["spade_bdi.bdi"] = mock_spade_bdi.bdi
|
||||
|
||||
@@ -33,3 +35,26 @@ def pytest_configure(config):
|
||||
mock_config_module.settings = MagicMock()
|
||||
|
||||
sys.modules["control_backend.core.config"] = mock_config_module
|
||||
|
||||
# --- Mock torch and zmq for VAD ---
|
||||
mock_torch = MagicMock()
|
||||
mock_zmq = MagicMock()
|
||||
mock_zmq.asyncio = mock_zmq
|
||||
|
||||
# In individual tests, these can be imported and the return values changed
|
||||
sys.modules["torch"] = mock_torch
|
||||
sys.modules["zmq"] = mock_zmq
|
||||
sys.modules["zmq.asyncio"] = mock_zmq.asyncio
|
||||
|
||||
# --- Mock whisper ---
|
||||
mock_whisper = MagicMock()
|
||||
mock_mlx = MagicMock()
|
||||
mock_mlx.core = MagicMock()
|
||||
mock_mlx_whisper = MagicMock()
|
||||
mock_mlx_whisper.transcribe = MagicMock()
|
||||
|
||||
sys.modules["whisper"] = mock_whisper
|
||||
sys.modules["mlx"] = mock_mlx
|
||||
sys.modules["mlx.core"] = mock_mlx
|
||||
sys.modules["mlx_whisper"] = mock_mlx_whisper
|
||||
sys.modules["mlx_whisper.transcribe"] = mock_mlx_whisper.transcribe
|
||||
|
||||
Reference in New Issue
Block a user