test: added tests for text_belief_extractor
This commit is contained in:
@@ -1,4 +1,5 @@
|
|||||||
import json
|
import json
|
||||||
|
import logging
|
||||||
|
|
||||||
from spade.behaviour import CyclicBehaviour
|
from spade.behaviour import CyclicBehaviour
|
||||||
from spade.message import Message
|
from spade.message import Message
|
||||||
@@ -7,6 +8,8 @@ from control_backend.core.config import settings
|
|||||||
|
|
||||||
|
|
||||||
class BeliefFromText(CyclicBehaviour):
|
class BeliefFromText(CyclicBehaviour):
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# TODO: LLM prompt nog hardcoded
|
# TODO: LLM prompt nog hardcoded
|
||||||
llm_instruction_prompt = """
|
llm_instruction_prompt = """
|
||||||
You are an information extraction assistent for a BDI agent. Your task is to extract values \
|
You are an information extraction assistent for a BDI agent. Your task is to extract values \
|
||||||
@@ -36,6 +39,9 @@ class BeliefFromText(CyclicBehaviour):
|
|||||||
|
|
||||||
async def run(self):
|
async def run(self):
|
||||||
msg = await self.receive()
|
msg = await self.receive()
|
||||||
|
if msg is None:
|
||||||
|
return
|
||||||
|
|
||||||
sender = msg.sender.node
|
sender = msg.sender.node
|
||||||
match sender:
|
match sender:
|
||||||
case settings.agent_settings.transcription_agent_name:
|
case settings.agent_settings.transcription_agent_name:
|
||||||
@@ -62,10 +68,14 @@ class BeliefFromText(CyclicBehaviour):
|
|||||||
# Verify by trying to parse
|
# Verify by trying to parse
|
||||||
try:
|
try:
|
||||||
json.loads(response)
|
json.loads(response)
|
||||||
belief_message = Message(
|
belief_message = Message()
|
||||||
to=settings.agent_settings.bdi_core_agent_name + "@" + settings.agent_settings.host,
|
|
||||||
body=response,
|
belief_message.to = (
|
||||||
|
settings.agent_settings.belief_collector_agent_name
|
||||||
|
+ "@"
|
||||||
|
+ settings.agent_settings.host
|
||||||
)
|
)
|
||||||
|
belief_message.body = response
|
||||||
belief_message.thread = "beliefs"
|
belief_message.thread = "beliefs"
|
||||||
|
|
||||||
await self.send(belief_message)
|
await self.send(belief_message)
|
||||||
@@ -82,12 +92,12 @@ class BeliefFromText(CyclicBehaviour):
|
|||||||
"""
|
"""
|
||||||
belief = {"beliefs": {"user_said": [txt]}, "type": "belief_extraction_text"}
|
belief = {"beliefs": {"user_said": [txt]}, "type": "belief_extraction_text"}
|
||||||
payload = json.dumps(belief)
|
payload = json.dumps(belief)
|
||||||
belief_msg = Message(
|
belief_msg = Message()
|
||||||
to=settings.agent_settings.belief_collector_agent_name
|
|
||||||
+ "@"
|
belief_msg.to = (
|
||||||
+ settings.agent_settings.host,
|
settings.agent_settings.belief_collector_agent_name + "@" + settings.agent_settings.host
|
||||||
body=payload,
|
|
||||||
)
|
)
|
||||||
|
belief_msg.body = payload
|
||||||
belief_msg.thread = "beliefs"
|
belief_msg.thread = "beliefs"
|
||||||
|
|
||||||
await self.send(belief_msg)
|
await self.send(belief_msg)
|
||||||
|
|||||||
@@ -0,0 +1,187 @@
|
|||||||
|
import json
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from spade.message import Message
|
||||||
|
|
||||||
|
from control_backend.agents.bdi.behaviours.text_belief_extractor import BeliefFromText
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_settings():
|
||||||
|
"""
|
||||||
|
Mocks the settings object that the behaviour imports.
|
||||||
|
We patch it at the source where it's imported by the module under test.
|
||||||
|
"""
|
||||||
|
# Create a mock object that mimics the nested structure
|
||||||
|
settings_mock = MagicMock()
|
||||||
|
settings_mock.agent_settings.transcription_agent_name = "transcriber"
|
||||||
|
settings_mock.agent_settings.belief_collector_agent_name = "collector"
|
||||||
|
settings_mock.agent_settings.host = "fake.host"
|
||||||
|
|
||||||
|
# Use patch to replace the settings object during the test
|
||||||
|
# Adjust 'control_backend.behaviours.belief_from_text.settings' to where
|
||||||
|
# your behaviour file imports it from.
|
||||||
|
with patch(
|
||||||
|
"control_backend.agents.bdi.behaviours.text_belief_extractor.settings", settings_mock
|
||||||
|
):
|
||||||
|
yield settings_mock
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def behavior(mock_settings):
|
||||||
|
"""
|
||||||
|
Creates an instance of the BeliefFromText behaviour and mocks its
|
||||||
|
agent, logger, send, and receive methods.
|
||||||
|
"""
|
||||||
|
b = BeliefFromText()
|
||||||
|
|
||||||
|
b.agent = MagicMock()
|
||||||
|
b.send = AsyncMock()
|
||||||
|
b.receive = AsyncMock()
|
||||||
|
|
||||||
|
return b
|
||||||
|
|
||||||
|
|
||||||
|
def create_mock_message(sender_node: str, body: str, thread: str) -> MagicMock:
|
||||||
|
"""Helper function to create a configured mock message."""
|
||||||
|
msg = MagicMock()
|
||||||
|
msg.sender.node = sender_node # MagicMock automatically creates nested mocks
|
||||||
|
msg.body = body
|
||||||
|
msg.thread = thread
|
||||||
|
return msg
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_run_no_message(behavior):
|
||||||
|
"""
|
||||||
|
Tests the run() method when no message is received.
|
||||||
|
"""
|
||||||
|
# Arrange: Configure receive to return None
|
||||||
|
behavior.receive.return_value = None
|
||||||
|
|
||||||
|
# Act: Run the behavior
|
||||||
|
await behavior.run()
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
# 1. Check that receive was called
|
||||||
|
behavior.receive.assert_called_once()
|
||||||
|
# 2. Check that no message was sent
|
||||||
|
behavior.send.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_run_message_from_other_agent(behavior):
|
||||||
|
"""
|
||||||
|
Tests the run() method when a message is received from an
|
||||||
|
unknown agent (not the transcriber).
|
||||||
|
"""
|
||||||
|
# Arrange: Create a mock message from an unknown sender
|
||||||
|
mock_msg = create_mock_message("unknown", "some data", None)
|
||||||
|
behavior.receive.return_value = mock_msg
|
||||||
|
behavior._process_transcription_demo = MagicMock()
|
||||||
|
|
||||||
|
# Act
|
||||||
|
await behavior.run()
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
# 1. Check that receive was called
|
||||||
|
behavior.receive.assert_called_once()
|
||||||
|
# 2. Check that _process_transcription_demo was not sent
|
||||||
|
behavior._process_transcription_demo.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_run_message_from_transcriber_demo(behavior, mock_settings, monkeypatch):
|
||||||
|
"""
|
||||||
|
Tests the main success path: receiving a message from the
|
||||||
|
transcription agent, which triggers _process_transcription_demo.
|
||||||
|
"""
|
||||||
|
# Arrange: Create a mock message from the transcriber
|
||||||
|
transcription_text = "hello world"
|
||||||
|
mock_msg = create_mock_message(
|
||||||
|
mock_settings.agent_settings.transcription_agent_name, transcription_text, None
|
||||||
|
)
|
||||||
|
behavior.receive.return_value = mock_msg
|
||||||
|
|
||||||
|
# Act
|
||||||
|
await behavior.run()
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
# 1. Check that receive was called
|
||||||
|
behavior.receive.assert_called_once()
|
||||||
|
|
||||||
|
# 2. Check that send was called *once*
|
||||||
|
behavior.send.assert_called_once()
|
||||||
|
|
||||||
|
# 3. Deeply inspect the message that was sent
|
||||||
|
sent_msg: Message = behavior.send.call_args[0][0]
|
||||||
|
|
||||||
|
assert (
|
||||||
|
sent_msg.to
|
||||||
|
== mock_settings.agent_settings.belief_collector_agent_name
|
||||||
|
+ "@"
|
||||||
|
+ mock_settings.agent_settings.host
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check thread
|
||||||
|
assert sent_msg.thread == "beliefs"
|
||||||
|
|
||||||
|
# Parse the received JSON string back into a dict
|
||||||
|
expected_dict = {
|
||||||
|
"beliefs": {"user_said": [transcription_text]},
|
||||||
|
"type": "belief_extraction_text",
|
||||||
|
}
|
||||||
|
sent_dict = json.loads(sent_msg.body)
|
||||||
|
|
||||||
|
# Assert that the dictionaries are equal
|
||||||
|
assert sent_dict == expected_dict
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_process_transcription_success(behavior, mock_settings):
|
||||||
|
"""
|
||||||
|
Tests the (currently unused) _process_transcription method's
|
||||||
|
success path, using its hardcoded mock response.
|
||||||
|
"""
|
||||||
|
# Arrange
|
||||||
|
test_text = "I am feeling happy"
|
||||||
|
# This is the hardcoded response inside the method
|
||||||
|
expected_response_body = '{"mood": [["happy"]]}'
|
||||||
|
|
||||||
|
# Act
|
||||||
|
await behavior._process_transcription(test_text)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
# 1. Check that a message was sent
|
||||||
|
behavior.send.assert_called_once()
|
||||||
|
|
||||||
|
# 2. Inspect the sent message
|
||||||
|
sent_msg: Message = behavior.send.call_args[0][0]
|
||||||
|
expected_to = (
|
||||||
|
mock_settings.agent_settings.belief_collector_agent_name
|
||||||
|
+ "@"
|
||||||
|
+ mock_settings.agent_settings.host
|
||||||
|
)
|
||||||
|
assert str(sent_msg.to) == expected_to
|
||||||
|
assert sent_msg.thread == "beliefs"
|
||||||
|
assert sent_msg.body == expected_response_body
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_process_transcription_json_decode_error(behavior, mock_settings):
|
||||||
|
"""
|
||||||
|
Tests the _process_transcription method's error handling
|
||||||
|
when the (mocked) response is invalid JSON.
|
||||||
|
We do this by patching json.loads to raise an error.
|
||||||
|
"""
|
||||||
|
# Arrange
|
||||||
|
test_text = "I am feeling happy"
|
||||||
|
# Patch json.loads to raise an error when called
|
||||||
|
with patch("json.loads", side_effect=json.JSONDecodeError("Mock error", "", 0)):
|
||||||
|
# Act
|
||||||
|
await behavior._process_transcription(test_text)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
# 1. Check that NO message was sent
|
||||||
|
behavior.send.assert_not_called()
|
||||||
Reference in New Issue
Block a user