feat: LLM agent #12

Merged
2584433 merged 13 commits from feat/llm-agent into dev 2025-10-29 12:58:41 +00:00
13 changed files with 483 additions and 3 deletions
Showing only changes of commit f163e0ee6c - Show all commits

View File

@@ -35,10 +35,16 @@ uv run fastapi dev src/control_backend/main.py
``` ```
## Testing ## Testing
Testing happens automatically when opening a merge request to any branch. If you want to manually run the test suite, you can do so by running the following: Testing happens automatically when opening a merge request to any branch. If you want to manually run the test suite, you can do so by running the following for unit tests:
```bash ```bash
uv run --only-group test pytest uv run --only-group test pytest test/unit
```
Or for integration tests:
```bash
uv run --group integration-test pytest test/integration
``` ```
## GitHooks ## GitHooks

View File

@@ -7,6 +7,7 @@ requires-python = ">=3.13"
dependencies = [ dependencies = [
"fastapi[all]>=0.115.6", "fastapi[all]>=0.115.6",
"mlx-whisper>=0.4.3 ; sys_platform == 'darwin'", "mlx-whisper>=0.4.3 ; sys_platform == 'darwin'",
"numpy>=2.3.3",
"openai-whisper>=20250625", "openai-whisper>=20250625",
"pyaudio>=0.2.14", "pyaudio>=0.2.14",
"pydantic>=2.12.0", "pydantic>=2.12.0",
@@ -33,6 +34,7 @@ integration-test = [
"soundfile>=0.13.1", "soundfile>=0.13.1",
] ]
test = [ test = [
"numpy>=2.3.3",
"pytest>=8.4.2", "pytest>=8.4.2",
"pytest-asyncio>=1.2.0", "pytest-asyncio>=1.2.0",
"pytest-cov>=7.0.0", "pytest-cov>=7.0.0",

View File

@@ -0,0 +1,156 @@
import logging
import numpy as np
import torch
import zmq
import zmq.asyncio as azmq
from spade.agent import Agent
from spade.behaviour import CyclicBehaviour
from control_backend.core.config import settings
from control_backend.core.zmq_context import context as zmq_context
logger = logging.getLogger(__name__)
class SocketPoller[T]:
"""
Convenience class for polling a socket for data with a timeout, persisting a zmq.Poller for
multiple usages.
"""
def __init__(self, socket: azmq.Socket, timeout_ms: int = 100):
"""
:param socket: The socket to poll and get data from.
:param timeout_ms: A timeout in milliseconds to wait for data.
"""
self.socket = socket
self.poller = zmq.Poller()
self.poller.register(self.socket, zmq.POLLIN)
self.timeout_ms = timeout_ms
async def poll(self, timeout_ms: int | None = None) -> T | None:
"""
Get data from the socket, or None if the timeout is reached.
:param timeout_ms: If given, the timeout. Otherwise, `self.timeout_ms` is used.
:return: Data from the socket or None.
"""
timeout_ms = timeout_ms or self.timeout_ms
socks = dict(self.poller.poll(timeout_ms))
if socks.get(self.socket) == zmq.POLLIN:
return await self.socket.recv()
return None
class Streaming(CyclicBehaviour):
def __init__(self, audio_in_socket: azmq.Socket, audio_out_socket: azmq.Socket):
super().__init__()
self.audio_in_poller = SocketPoller[bytes](audio_in_socket)
self.model, _ = torch.hub.load(
repo_or_dir="snakers4/silero-vad", model="silero_vad", force_reload=False
)
self.audio_out_socket = audio_out_socket
self.audio_buffer = np.array([], dtype=np.float32)
self.i_since_speech = 100 # Used to allow small pauses in speech
async def run(self) -> None:
data = await self.audio_in_poller.poll()
if data is None:
if len(self.audio_buffer) > 0:
logger.debug("No audio data received. Discarding buffer until new data arrives.")
self.audio_buffer = np.array([], dtype=np.float32)
self.i_since_speech = 100
return
# copy otherwise Torch will be sad that it's immutable
chunk = np.frombuffer(data, dtype=np.float32).copy()
prob = self.model(torch.from_numpy(chunk), 16000).item()
if prob > 0.5:
if self.i_since_speech > 3:
logger.debug("Speech started.")
self.audio_buffer = np.append(self.audio_buffer, chunk)
self.i_since_speech = 0
return
self.i_since_speech += 1
# prob < 0.5, so speech maybe ended. Wait a bit more before to be more certain
if self.i_since_speech <= 3:
self.audio_buffer = np.append(self.audio_buffer, chunk)
return
# Speech probably ended. Make sure we have a usable amount of data.
if len(self.audio_buffer) >= 3 * len(chunk):
logger.debug("Speech ended.")
await self.audio_out_socket.send(self.audio_buffer[: -2 * len(chunk)].tobytes())
# At this point, we know that the speech has ended.
# Prepend the last chunk that had no speech, for a more fluent boundary
self.audio_buffer = chunk
class VADAgent(Agent):
"""
An agent which listens to an audio stream, does Voice Activity Detection (VAD), and sends
fragments with detected speech to other agents over ZeroMQ.
"""
def __init__(self, audio_in_address: str, audio_in_bind: bool):
jid = settings.agent_settings.vad_agent_name + "@" + settings.agent_settings.host
super().__init__(jid, settings.agent_settings.vad_agent_name)
self.audio_in_address = audio_in_address
self.audio_in_bind = audio_in_bind
self.audio_in_socket: azmq.Socket | None = None
self.audio_out_socket: azmq.Socket | None = None
async def stop(self):
"""
Stop listening to audio, stop publishing audio, close sockets.
"""
if self.audio_in_socket is not None:
self.audio_in_socket.close()
self.audio_in_socket = None
if self.audio_out_socket is not None:
self.audio_out_socket.close()
self.audio_out_socket = None
return await super().stop()
def _connect_audio_in_socket(self):
self.audio_in_socket = zmq_context.socket(zmq.SUB)
self.audio_in_socket.setsockopt_string(zmq.SUBSCRIBE, "")
if self.audio_in_bind:
self.audio_in_socket.bind(self.audio_in_address)
else:
self.audio_in_socket.connect(self.audio_in_address)
self.audio_in_poller = SocketPoller[bytes](self.audio_in_socket)
def _connect_audio_out_socket(self) -> int | None:
"""Returns the port bound, or None if binding failed."""
try:
self.audio_out_socket = zmq_context.socket(zmq.PUB)
return self.audio_out_socket.bind_to_random_port("tcp://*", max_tries=100)
except zmq.ZMQBindError:
logger.error("Failed to bind an audio output socket after 100 tries.")
self.audio_out_socket = None
return None
async def setup(self):
logger.info("Setting up %s", self.jid)
self._connect_audio_in_socket()
audio_out_port = self._connect_audio_out_socket()
if audio_out_port is None:
await self.stop()
return
streaming = Streaming(self.audio_in_socket, self.audio_out_socket)
self.add_behaviour(streaming)
# ... start agents dependent on the output audio fragments here
logger.info("Finished setting up %s", self.jid)

View File

@@ -10,6 +10,7 @@ class AgentSettings(BaseModel):
host: str = "xmpp.twirre.dev" host: str = "xmpp.twirre.dev"
bdi_core_agent_name: str = "bdi_core" bdi_core_agent_name: str = "bdi_core"
belief_collector_agent_name: str = "belief_collector" belief_collector_agent_name: str = "belief_collector"
vad_agent_name: str = "vad_agent"
llm_agent_name: str = "llm_agent" llm_agent_name: str = "llm_agent"
test_agent_name: str = "test_agent" test_agent_name: str = "test_agent"

View File

@@ -11,6 +11,7 @@ from fastapi.middleware.cors import CORSMiddleware
# Internal imports # Internal imports
from control_backend.agents.ri_communication_agent import RICommunicationAgent from control_backend.agents.ri_communication_agent import RICommunicationAgent
from control_backend.agents.bdi.bdi_core import BDICoreAgent from control_backend.agents.bdi.bdi_core import BDICoreAgent
from control_backend.agents.vad_agent import VADAgent
from control_backend.agents.llm.llm import LLMAgent from control_backend.agents.llm.llm import LLMAgent
from control_backend.api.v1.router import api_router from control_backend.api.v1.router import api_router
from control_backend.core.config import settings from control_backend.core.config import settings
@@ -48,7 +49,10 @@ async def lifespan(app: FastAPI):
bdi_core = BDICoreAgent(settings.agent_settings.bdi_core_agent_name + '@' + settings.agent_settings.host, bdi_core = BDICoreAgent(settings.agent_settings.bdi_core_agent_name + '@' + settings.agent_settings.host,
"secret, ask twirre", "src/control_backend/agents/bdi/rules.asl") "secret, ask twirre", "src/control_backend/agents/bdi/rules.asl")
await bdi_core.start() await bdi_core.start()
_temp_vad_agent = VADAgent("tcp://localhost:5558", False)
await _temp_vad_agent.start()
yield yield
logger.info("%s shutting down.", app.title) logger.info("%s shutting down.", app.title)

View File

View File

@@ -0,0 +1,99 @@
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
import zmq
from spade.agent import Agent
from control_backend.agents.vad_agent import VADAgent
@pytest.fixture
def zmq_context(mocker):
return mocker.patch("control_backend.agents.vad_agent.zmq_context")
@pytest.fixture
def streaming(mocker):
return mocker.patch("control_backend.agents.vad_agent.Streaming")
@pytest.mark.asyncio
async def test_normal_setup(streaming):
"""
Test that during normal setup, the VAD agent creates a Streaming behavior and creates audio
sockets.
"""
vad_agent = VADAgent("tcp://localhost:12345", False)
vad_agent.add_behaviour = MagicMock()
await vad_agent.setup()
streaming.assert_called_once()
vad_agent.add_behaviour.assert_called_once_with(streaming.return_value)
assert vad_agent.audio_in_socket is not None
assert vad_agent.audio_out_socket is not None
@pytest.mark.parametrize("do_bind", [True, False])
def test_in_socket_creation(zmq_context, do_bind: bool):
"""
Test that the VAD agent creates an audio input socket, differentiating between binding and
connecting.
"""
vad_agent = VADAgent(f"tcp://{'*' if do_bind else 'localhost'}:12345", do_bind)
vad_agent._connect_audio_in_socket()
assert vad_agent.audio_in_socket is not None
zmq_context.socket.assert_called_once_with(zmq.SUB)
zmq_context.socket.return_value.setsockopt_string.assert_called_once_with(zmq.SUBSCRIBE, "")
if do_bind:
zmq_context.socket.return_value.bind.assert_called_once_with("tcp://*:12345")
else:
zmq_context.socket.return_value.connect.assert_called_once_with("tcp://localhost:12345")
def test_out_socket_creation(zmq_context):
"""
Test that the VAD agent creates an audio output socket correctly.
"""
vad_agent = VADAgent("tcp://localhost:12345", False)
vad_agent._connect_audio_out_socket()
assert vad_agent.audio_out_socket is not None
zmq_context.socket.assert_called_once_with(zmq.PUB)
zmq_context.socket.return_value.bind_to_random_port.assert_called_once()
@pytest.mark.asyncio
async def test_out_socket_creation_failure(zmq_context):
"""
Test setup failure when the audio output socket cannot be created.
"""
with patch.object(Agent, "stop", new_callable=AsyncMock) as mock_super_stop:
zmq_context.socket.return_value.bind_to_random_port.side_effect = zmq.ZMQBindError
vad_agent = VADAgent("tcp://localhost:12345", False)
await vad_agent.setup()
assert vad_agent.audio_out_socket is None
mock_super_stop.assert_called_once()
@pytest.mark.asyncio
async def test_stop(zmq_context):
"""
Test that when the VAD agent is stopped, the sockets are closed correctly.
"""
vad_agent = VADAgent("tcp://localhost:12345", False)
await vad_agent.setup()
await vad_agent.stop()
assert zmq_context.socket.return_value.close.call_count == 2
assert vad_agent.audio_in_socket is None
assert vad_agent.audio_out_socket is None

View File

@@ -0,0 +1,57 @@
import os
from unittest.mock import AsyncMock, MagicMock
import pytest
import soundfile as sf
import zmq
from control_backend.agents.vad_agent import Streaming
def get_audio_chunks() -> list[bytes]:
curr_file = os.path.realpath(__file__)
curr_dir = os.path.dirname(curr_file)
file = f"{curr_dir}/speech_with_pauses_16k_1c_float32.wav"
chunk_size = 512
chunks = []
with sf.SoundFile(file, "r") as f:
assert f.samplerate == 16000
assert f.channels == 1
assert f.subtype == "FLOAT"
while True:
data = f.read(chunk_size, dtype="float32")
if len(data) != chunk_size:
break
chunks.append(data.tobytes())
return chunks
@pytest.mark.asyncio
async def test_real_audio(mocker):
"""
Test the VAD agent with only input and output mocked. Using the real model, using real audio as
input. Ensure that it outputs some fragments with audio.
"""
audio_chunks = get_audio_chunks()
audio_in_socket = AsyncMock()
audio_in_socket.recv.side_effect = audio_chunks
mock_poller: MagicMock = mocker.patch("control_backend.agents.vad_agent.zmq.Poller")
mock_poller.return_value.poll.return_value = [(audio_in_socket, zmq.POLLIN)]
audio_out_socket = AsyncMock()
vad_streamer = Streaming(audio_in_socket, audio_out_socket)
for _ in audio_chunks:
await vad_streamer.run()
audio_out_socket.send.assert_called()
for args in audio_out_socket.send.call_args_list:
assert isinstance(args[0][0], bytes)
assert len(args[0][0]) >= 512 * 4 * 3 # Should be at least 3 chunks of audio

View File

@@ -0,0 +1,46 @@
from unittest.mock import AsyncMock, MagicMock
import pytest
import zmq
from control_backend.agents.vad_agent import SocketPoller
@pytest.fixture
def socket():
return AsyncMock()
@pytest.mark.asyncio
async def test_socket_poller_with_data(socket, mocker):
socket_data = b"test"
socket.recv.return_value = socket_data
mock_poller: MagicMock = mocker.patch("control_backend.agents.vad_agent.zmq.Poller")
mock_poller.return_value.poll.return_value = [(socket, zmq.POLLIN)]
poller = SocketPoller(socket)
# Calling `poll` twice to be able to check that the poller is reused
await poller.poll()
data = await poller.poll()
assert data == socket_data
# Ensure that the poller was reused
mock_poller.assert_called_once_with()
mock_poller.return_value.register.assert_called_once_with(socket, zmq.POLLIN)
assert socket.recv.call_count == 2
@pytest.mark.asyncio
async def test_socket_poller_no_data(socket, mocker):
mock_poller: MagicMock = mocker.patch("control_backend.agents.vad_agent.zmq.Poller")
mock_poller.return_value.poll.return_value = []
poller = SocketPoller(socket)
data = await poller.poll()
assert data is None
socket.recv.assert_not_called()

View File

@@ -0,0 +1,95 @@
from unittest.mock import AsyncMock, MagicMock
import numpy as np
import pytest
from control_backend.agents.vad_agent import Streaming
@pytest.fixture
def audio_in_socket():
return AsyncMock()
@pytest.fixture
def audio_out_socket():
return AsyncMock()
@pytest.fixture
def streaming(audio_in_socket, audio_out_socket):
import torch
torch.hub.load.return_value = (..., ...) # Mock
return Streaming(audio_in_socket, audio_out_socket)
async def simulate_streaming_with_probabilities(streaming, probabilities: list[float]):
"""
Simulates a streaming scenario with given VAD model probabilities for testing purposes.
:param streaming: The streaming component to be tested.
:param probabilities: A list of probabilities representing the outputs of the VAD model.
"""
model_item = MagicMock()
model_item.item.side_effect = probabilities
streaming.model = MagicMock()
streaming.model.return_value = model_item
audio_in_poller = AsyncMock()
audio_in_poller.poll.return_value = np.empty(shape=512, dtype=np.float32)
streaming.audio_in_poller = audio_in_poller
for _ in probabilities:
await streaming.run()
@pytest.mark.asyncio
async def test_voice_activity_detected(audio_in_socket, audio_out_socket, streaming):
"""
Test a scenario where there is voice activity detected between silences.
:return:
"""
speech_chunk_count = 5
probabilities = [0.0] * 5 + [1.0] * speech_chunk_count + [0.0] * 5
await simulate_streaming_with_probabilities(streaming, probabilities)
audio_out_socket.send.assert_called_once()
data = audio_out_socket.send.call_args[0][0]
assert isinstance(data, bytes)
# each sample has 512 frames of 4 bytes, expecting 7 chunks (5 with speech, 2 as padding)
assert len(data) == 512 * 4 * (speech_chunk_count + 2)
@pytest.mark.asyncio
async def test_voice_activity_short_pause(audio_in_socket, audio_out_socket, streaming):
"""
Test a scenario where there is a short pause between speech, checking whether it ignores the
short pause.
"""
speech_chunk_count = 5
probabilities = (
[0.0] * 5 + [1.0] * speech_chunk_count + [0.0] + [1.0] * speech_chunk_count + [0.0] * 5
)
await simulate_streaming_with_probabilities(streaming, probabilities)
audio_out_socket.send.assert_called_once()
data = audio_out_socket.send.call_args[0][0]
assert isinstance(data, bytes)
# Expecting 13 chunks (2*5 with speech, 1 pause between, 2 as padding)
assert len(data) == 512 * 4 * (speech_chunk_count * 2 + 1 + 2)
@pytest.mark.asyncio
async def test_no_data(audio_in_socket, audio_out_socket, streaming):
"""
Test a scenario where there is no data received. This should not cause errors.
"""
audio_in_poller = AsyncMock()
audio_in_poller.poll.return_value = None
streaming.audio_in_poller = audio_in_poller
await streaming.run()
audio_out_socket.send.assert_not_called()
assert len(streaming.audio_buffer) == 0

View File

@@ -33,3 +33,13 @@ def pytest_configure(config):
mock_config_module.settings = MagicMock() mock_config_module.settings = MagicMock()
sys.modules["control_backend.core.config"] = mock_config_module sys.modules["control_backend.core.config"] = mock_config_module
# --- Mock torch and zmq for VAD ---
mock_torch = MagicMock()
mock_zmq = MagicMock()
mock_zmq.asyncio = mock_zmq
# In individual tests, these can be imported and the return values changed
sys.modules["torch"] = mock_torch
sys.modules["zmq"] = mock_zmq
sys.modules["zmq.asyncio"] = mock_zmq.asyncio

4
uv.lock generated
View File

@@ -1332,6 +1332,7 @@ source = { virtual = "." }
dependencies = [ dependencies = [
{ name = "fastapi", extra = ["all"] }, { name = "fastapi", extra = ["all"] },
{ name = "mlx-whisper", marker = "sys_platform == 'darwin'" }, { name = "mlx-whisper", marker = "sys_platform == 'darwin'" },
{ name = "numpy" },
{ name = "openai-whisper" }, { name = "openai-whisper" },
{ name = "pyaudio" }, { name = "pyaudio" },
{ name = "pydantic" }, { name = "pydantic" },
@@ -1358,6 +1359,7 @@ integration-test = [
{ name = "soundfile" }, { name = "soundfile" },
] ]
test = [ test = [
{ name = "numpy" },
{ name = "pytest" }, { name = "pytest" },
{ name = "pytest-asyncio" }, { name = "pytest-asyncio" },
{ name = "pytest-cov" }, { name = "pytest-cov" },
@@ -1368,6 +1370,7 @@ test = [
requires-dist = [ requires-dist = [
{ name = "fastapi", extras = ["all"], specifier = ">=0.115.6" }, { name = "fastapi", extras = ["all"], specifier = ">=0.115.6" },
{ name = "mlx-whisper", marker = "sys_platform == 'darwin'", specifier = ">=0.4.3" }, { name = "mlx-whisper", marker = "sys_platform == 'darwin'", specifier = ">=0.4.3" },
{ name = "numpy", specifier = ">=2.3.3" },
{ name = "openai-whisper", specifier = ">=20250625" }, { name = "openai-whisper", specifier = ">=20250625" },
{ name = "pyaudio", specifier = ">=0.2.14" }, { name = "pyaudio", specifier = ">=0.2.14" },
{ name = "pydantic", specifier = ">=2.12.0" }, { name = "pydantic", specifier = ">=2.12.0" },
@@ -1392,6 +1395,7 @@ dev = [
] ]
integration-test = [{ name = "soundfile", specifier = ">=0.13.1" }] integration-test = [{ name = "soundfile", specifier = ">=0.13.1" }]
test = [ test = [
{ name = "numpy", specifier = ">=2.3.3" },
{ name = "pytest", specifier = ">=8.4.2" }, { name = "pytest", specifier = ">=8.4.2" },
{ name = "pytest-asyncio", specifier = ">=1.2.0" }, { name = "pytest-asyncio", specifier = ">=1.2.0" },
{ name = "pytest-cov", specifier = ">=7.0.0" }, { name = "pytest-cov", specifier = ">=7.0.0" },