refactor: remove constants and put in config file #24
@@ -1,4 +1,5 @@
|
|||||||
import json
|
import json
|
||||||
|
import logging
|
||||||
|
|
||||||
from spade.behaviour import CyclicBehaviour
|
from spade.behaviour import CyclicBehaviour
|
||||||
from spade.message import Message
|
from spade.message import Message
|
||||||
@@ -7,6 +8,8 @@ from control_backend.core.config import settings
|
|||||||
|
|
||||||
|
|
||||||
class BeliefFromText(CyclicBehaviour):
|
class BeliefFromText(CyclicBehaviour):
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# TODO: LLM prompt nog hardcoded
|
# TODO: LLM prompt nog hardcoded
|
||||||
llm_instruction_prompt = """
|
llm_instruction_prompt = """
|
||||||
You are an information extraction assistent for a BDI agent. Your task is to extract values \
|
You are an information extraction assistent for a BDI agent. Your task is to extract values \
|
||||||
@@ -36,6 +39,9 @@ class BeliefFromText(CyclicBehaviour):
|
|||||||
|
|
||||||
async def run(self):
|
async def run(self):
|
||||||
msg = await self.receive()
|
msg = await self.receive()
|
||||||
|
if msg is None:
|
||||||
|
return
|
||||||
|
|
||||||
sender = msg.sender.node
|
sender = msg.sender.node
|
||||||
match sender:
|
match sender:
|
||||||
case settings.agent_settings.transcription_agent_name:
|
case settings.agent_settings.transcription_agent_name:
|
||||||
@@ -62,10 +68,14 @@ class BeliefFromText(CyclicBehaviour):
|
|||||||
# Verify by trying to parse
|
# Verify by trying to parse
|
||||||
try:
|
try:
|
||||||
json.loads(response)
|
json.loads(response)
|
||||||
belief_message = Message(
|
belief_message = Message()
|
||||||
to=settings.agent_settings.bdi_core_agent_name + "@" + settings.agent_settings.host,
|
|
||||||
body=response,
|
belief_message.to = (
|
||||||
|
settings.agent_settings.belief_collector_agent_name
|
||||||
|
+ "@"
|
||||||
|
+ settings.agent_settings.host
|
||||||
)
|
)
|
||||||
|
belief_message.body = response
|
||||||
belief_message.thread = "beliefs"
|
belief_message.thread = "beliefs"
|
||||||
|
|
||||||
await self.send(belief_message)
|
await self.send(belief_message)
|
||||||
@@ -82,12 +92,12 @@ class BeliefFromText(CyclicBehaviour):
|
|||||||
"""
|
"""
|
||||||
belief = {"beliefs": {"user_said": [txt]}, "type": "belief_extraction_text"}
|
belief = {"beliefs": {"user_said": [txt]}, "type": "belief_extraction_text"}
|
||||||
payload = json.dumps(belief)
|
payload = json.dumps(belief)
|
||||||
belief_msg = Message(
|
belief_msg = Message()
|
||||||
to=settings.agent_settings.belief_collector_agent_name
|
|
||||||
+ "@"
|
belief_msg.to = (
|
||||||
+ settings.agent_settings.host,
|
settings.agent_settings.belief_collector_agent_name + "@" + settings.agent_settings.host
|
||||||
body=payload,
|
|
||||||
)
|
)
|
||||||
|
belief_msg.body = payload
|
||||||
belief_msg.thread = "beliefs"
|
belief_msg.thread = "beliefs"
|
||||||
|
|
||||||
await self.send(belief_msg)
|
await self.send(belief_msg)
|
||||||
|
|||||||
@@ -3,7 +3,6 @@ import logging
|
|||||||
import zmq
|
import zmq
|
||||||
from fastapi import APIRouter
|
from fastapi import APIRouter
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse
|
||||||
from pyjabber.server_parameters import json
|
|
||||||
from zmq.asyncio import Context
|
from zmq.asyncio import Context
|
||||||
|
|
||||||
from control_backend.core.config import settings
|
from control_backend.core.config import settings
|
||||||
@@ -13,6 +12,7 @@ logger = logging.getLogger(__name__)
|
|||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
|
# DO NOT LOG INSIDE THIS FUNCTION
|
||||||
@router.get("/logs/stream")
|
@router.get("/logs/stream")
|
||||||
async def log_stream():
|
async def log_stream():
|
||||||
context = Context.instance()
|
context = Context.instance()
|
||||||
@@ -27,7 +27,6 @@ async def log_stream():
|
|||||||
while True:
|
while True:
|
||||||
_, message = await socket.recv_multipart()
|
_, message = await socket.recv_multipart()
|
||||||
message = message.decode().strip()
|
message = message.decode().strip()
|
||||||
json_data = json.dumps(message)
|
yield f"data: {message}\n\n"
|
||||||
yield f"data: {json_data}\n\n"
|
|
||||||
|
|
||||||
return StreamingResponse(gen(), media_type="text/event-stream")
|
return StreamingResponse(gen(), media_type="text/event-stream")
|
||||||
|
|||||||
@@ -0,0 +1,187 @@
|
|||||||
|
import json
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from spade.message import Message
|
||||||
|
|
||||||
|
from control_backend.agents.bdi.behaviours.text_belief_extractor import BeliefFromText
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_settings():
|
||||||
|
"""
|
||||||
|
Mocks the settings object that the behaviour imports.
|
||||||
|
We patch it at the source where it's imported by the module under test.
|
||||||
|
"""
|
||||||
|
# Create a mock object that mimics the nested structure
|
||||||
|
settings_mock = MagicMock()
|
||||||
|
settings_mock.agent_settings.transcription_agent_name = "transcriber"
|
||||||
|
settings_mock.agent_settings.belief_collector_agent_name = "collector"
|
||||||
|
settings_mock.agent_settings.host = "fake.host"
|
||||||
|
|
||||||
|
# Use patch to replace the settings object during the test
|
||||||
|
# Adjust 'control_backend.behaviours.belief_from_text.settings' to where
|
||||||
|
# your behaviour file imports it from.
|
||||||
|
with patch(
|
||||||
|
"control_backend.agents.bdi.behaviours.text_belief_extractor.settings", settings_mock
|
||||||
|
):
|
||||||
|
yield settings_mock
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def behavior(mock_settings):
|
||||||
|
"""
|
||||||
|
Creates an instance of the BeliefFromText behaviour and mocks its
|
||||||
|
agent, logger, send, and receive methods.
|
||||||
|
"""
|
||||||
|
b = BeliefFromText()
|
||||||
|
|
||||||
|
b.agent = MagicMock()
|
||||||
|
b.send = AsyncMock()
|
||||||
|
b.receive = AsyncMock()
|
||||||
|
|
||||||
|
return b
|
||||||
|
|
||||||
|
|
||||||
|
def create_mock_message(sender_node: str, body: str, thread: str) -> MagicMock:
|
||||||
|
"""Helper function to create a configured mock message."""
|
||||||
|
msg = MagicMock()
|
||||||
|
msg.sender.node = sender_node # MagicMock automatically creates nested mocks
|
||||||
|
msg.body = body
|
||||||
|
msg.thread = thread
|
||||||
|
return msg
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_run_no_message(behavior):
|
||||||
|
"""
|
||||||
|
Tests the run() method when no message is received.
|
||||||
|
"""
|
||||||
|
# Arrange: Configure receive to return None
|
||||||
|
behavior.receive.return_value = None
|
||||||
|
|
||||||
|
# Act: Run the behavior
|
||||||
|
await behavior.run()
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
# 1. Check that receive was called
|
||||||
|
behavior.receive.assert_called_once()
|
||||||
|
# 2. Check that no message was sent
|
||||||
|
behavior.send.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_run_message_from_other_agent(behavior):
|
||||||
|
"""
|
||||||
|
Tests the run() method when a message is received from an
|
||||||
|
unknown agent (not the transcriber).
|
||||||
|
"""
|
||||||
|
# Arrange: Create a mock message from an unknown sender
|
||||||
|
mock_msg = create_mock_message("unknown", "some data", None)
|
||||||
|
behavior.receive.return_value = mock_msg
|
||||||
|
behavior._process_transcription_demo = MagicMock()
|
||||||
|
|
||||||
|
# Act
|
||||||
|
await behavior.run()
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
# 1. Check that receive was called
|
||||||
|
behavior.receive.assert_called_once()
|
||||||
|
# 2. Check that _process_transcription_demo was not sent
|
||||||
|
behavior._process_transcription_demo.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_run_message_from_transcriber_demo(behavior, mock_settings, monkeypatch):
|
||||||
|
"""
|
||||||
|
Tests the main success path: receiving a message from the
|
||||||
|
transcription agent, which triggers _process_transcription_demo.
|
||||||
|
"""
|
||||||
|
# Arrange: Create a mock message from the transcriber
|
||||||
|
transcription_text = "hello world"
|
||||||
|
mock_msg = create_mock_message(
|
||||||
|
mock_settings.agent_settings.transcription_agent_name, transcription_text, None
|
||||||
|
)
|
||||||
|
behavior.receive.return_value = mock_msg
|
||||||
|
|
||||||
|
# Act
|
||||||
|
await behavior.run()
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
# 1. Check that receive was called
|
||||||
|
behavior.receive.assert_called_once()
|
||||||
|
|
||||||
|
# 2. Check that send was called *once*
|
||||||
|
behavior.send.assert_called_once()
|
||||||
|
|
||||||
|
# 3. Deeply inspect the message that was sent
|
||||||
|
sent_msg: Message = behavior.send.call_args[0][0]
|
||||||
|
|
||||||
|
assert (
|
||||||
|
sent_msg.to
|
||||||
|
== mock_settings.agent_settings.belief_collector_agent_name
|
||||||
|
+ "@"
|
||||||
|
+ mock_settings.agent_settings.host
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check thread
|
||||||
|
assert sent_msg.thread == "beliefs"
|
||||||
|
|
||||||
|
# Parse the received JSON string back into a dict
|
||||||
|
expected_dict = {
|
||||||
|
"beliefs": {"user_said": [transcription_text]},
|
||||||
|
"type": "belief_extraction_text",
|
||||||
|
}
|
||||||
|
sent_dict = json.loads(sent_msg.body)
|
||||||
|
|
||||||
|
# Assert that the dictionaries are equal
|
||||||
|
assert sent_dict == expected_dict
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_process_transcription_success(behavior, mock_settings):
|
||||||
|
"""
|
||||||
|
Tests the (currently unused) _process_transcription method's
|
||||||
|
success path, using its hardcoded mock response.
|
||||||
|
"""
|
||||||
|
# Arrange
|
||||||
|
test_text = "I am feeling happy"
|
||||||
|
# This is the hardcoded response inside the method
|
||||||
|
expected_response_body = '{"mood": [["happy"]]}'
|
||||||
|
|
||||||
|
# Act
|
||||||
|
await behavior._process_transcription(test_text)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
# 1. Check that a message was sent
|
||||||
|
behavior.send.assert_called_once()
|
||||||
|
|
||||||
|
# 2. Inspect the sent message
|
||||||
|
sent_msg: Message = behavior.send.call_args[0][0]
|
||||||
|
expected_to = (
|
||||||
|
mock_settings.agent_settings.belief_collector_agent_name
|
||||||
|
+ "@"
|
||||||
|
+ mock_settings.agent_settings.host
|
||||||
|
)
|
||||||
|
assert str(sent_msg.to) == expected_to
|
||||||
|
assert sent_msg.thread == "beliefs"
|
||||||
|
assert sent_msg.body == expected_response_body
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_process_transcription_json_decode_error(behavior, mock_settings):
|
||||||
|
"""
|
||||||
|
Tests the _process_transcription method's error handling
|
||||||
|
when the (mocked) response is invalid JSON.
|
||||||
|
We do this by patching json.loads to raise an error.
|
||||||
|
"""
|
||||||
|
# Arrange
|
||||||
|
test_text = "I am feeling happy"
|
||||||
|
# Patch json.loads to raise an error when called
|
||||||
|
with patch("json.loads", side_effect=json.JSONDecodeError("Mock error", "", 0)):
|
||||||
|
# Act
|
||||||
|
await behavior._process_transcription(test_text)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
# 1. Check that NO message was sent
|
||||||
|
behavior.send.assert_not_called()
|
||||||
@@ -35,13 +35,24 @@ def streaming(audio_in_socket, audio_out_socket, mock_agent):
|
|||||||
return streaming
|
return streaming
|
||||||
|
|
||||||
|
|
||||||
async def simulate_streaming_with_probabilities(streaming, probabilities: list[float]):
|
@pytest.fixture(autouse=True)
|
||||||
"""
|
def patch_settings(monkeypatch):
|
||||||
Simulates a streaming scenario with given VAD model probabilities for testing purposes.
|
# Patch the settings that vad_agent.run() reads
|
||||||
|
from control_backend.agents import vad_agent
|
||||||
|
|
||||||
:param streaming: The streaming component to be tested.
|
monkeypatch.setattr(
|
||||||
:param probabilities: A list of probabilities representing the outputs of the VAD model.
|
vad_agent.settings.behaviour_settings, "vad_prob_threshold", 0.5, raising=False
|
||||||
"""
|
)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
vad_agent.settings.behaviour_settings, "vad_non_speech_patience_chunks", 2, raising=False
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
vad_agent.settings.behaviour_settings, "vad_initial_since_speech", 0, raising=False
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(vad_agent.settings.vad_settings, "sample_rate_hz", 16_000, raising=False)
|
||||||
|
|
||||||
|
|
||||||
|
async def simulate_streaming_with_probabilities(streaming, probabilities: list[float]):
|
||||||
model_item = MagicMock()
|
model_item = MagicMock()
|
||||||
model_item.item.side_effect = probabilities
|
model_item.item.side_effect = probabilities
|
||||||
streaming.model = MagicMock()
|
streaming.model = MagicMock()
|
||||||
@@ -57,10 +68,6 @@ async def simulate_streaming_with_probabilities(streaming, probabilities: list[f
|
|||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_voice_activity_detected(audio_in_socket, audio_out_socket, streaming):
|
async def test_voice_activity_detected(audio_in_socket, audio_out_socket, streaming):
|
||||||
"""
|
|
||||||
Test a scenario where there is voice activity detected between silences.
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
speech_chunk_count = 5
|
speech_chunk_count = 5
|
||||||
probabilities = [0.0] * 5 + [1.0] * speech_chunk_count + [0.0] * 5
|
probabilities = [0.0] * 5 + [1.0] * speech_chunk_count + [0.0] * 5
|
||||||
await simulate_streaming_with_probabilities(streaming, probabilities)
|
await simulate_streaming_with_probabilities(streaming, probabilities)
|
||||||
@@ -68,8 +75,7 @@ async def test_voice_activity_detected(audio_in_socket, audio_out_socket, stream
|
|||||||
audio_out_socket.send.assert_called_once()
|
audio_out_socket.send.assert_called_once()
|
||||||
data = audio_out_socket.send.call_args[0][0]
|
data = audio_out_socket.send.call_args[0][0]
|
||||||
assert isinstance(data, bytes)
|
assert isinstance(data, bytes)
|
||||||
# each sample has 512 frames of 4 bytes, expecting 7 chunks (5 with speech, 2 as padding)
|
assert len(data) == 512 * 4 * (speech_chunk_count + 1)
|
||||||
assert len(data) == 512 * 4 * (speech_chunk_count + 2)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@@ -87,8 +93,8 @@ async def test_voice_activity_short_pause(audio_in_socket, audio_out_socket, str
|
|||||||
audio_out_socket.send.assert_called_once()
|
audio_out_socket.send.assert_called_once()
|
||||||
data = audio_out_socket.send.call_args[0][0]
|
data = audio_out_socket.send.call_args[0][0]
|
||||||
assert isinstance(data, bytes)
|
assert isinstance(data, bytes)
|
||||||
# Expecting 13 chunks (2*5 with speech, 1 pause between, 2 as padding)
|
# Expecting 13 chunks (2*5 with speech, 1 pause between, 1 as padding)
|
||||||
assert len(data) == 512 * 4 * (speech_chunk_count * 2 + 1 + 2)
|
assert len(data) == 512 * 4 * (speech_chunk_count * 2 + 1 + 1)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
|
||||||
from control_backend.agents.transcription.speech_recognizer import (
|
from control_backend.agents.transcription.speech_recognizer import (
|
||||||
OpenAIWhisperSpeechRecognizer,
|
OpenAIWhisperSpeechRecognizer,
|
||||||
@@ -6,6 +7,24 @@ from control_backend.agents.transcription.speech_recognizer import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def patch_sr_settings(monkeypatch):
|
||||||
|
# Patch the *module-local* settings that SpeechRecognizer imported
|
||||||
|
from control_backend.agents.transcription import speech_recognizer as sr
|
||||||
|
|
||||||
|
# Provide real numbers for everything _estimate_max_tokens() reads
|
||||||
|
monkeypatch.setattr(sr.settings.vad_settings, "sample_rate_hz", 16_000, raising=False)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
sr.settings.behaviour_settings, "transcription_words_per_minute", 450, raising=False
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
sr.settings.behaviour_settings, "transcription_words_per_token", 0.75, raising=False
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
sr.settings.behaviour_settings, "transcription_token_buffer", 10, raising=False
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_estimate_max_tokens():
|
def test_estimate_max_tokens():
|
||||||
"""Inputting one minute of audio, assuming 450 words per minute and adding a 10 token padding,
|
"""Inputting one minute of audio, assuming 450 words per minute and adding a 10 token padding,
|
||||||
expecting 610 tokens."""
|
expecting 610 tokens."""
|
||||||
|
|||||||
Reference in New Issue
Block a user