From 1c756474f27248e877ee6900554b047511b02e54 Mon Sep 17 00:00:00 2001 From: "Luijkx,S.O.H. (Storm)" Date: Thu, 6 Nov 2025 12:57:09 +0000 Subject: [PATCH] test: added tests for text_belief_extractor --- .../bdi/behaviours/text_belief_extractor.py | 26 ++- .../behaviours/test_belief_from_text.py | 187 ++++++++++++++++++ 2 files changed, 205 insertions(+), 8 deletions(-) create mode 100644 test/unit/agents/belief_from_text/behaviours/test_belief_from_text.py diff --git a/src/control_backend/agents/bdi/behaviours/text_belief_extractor.py b/src/control_backend/agents/bdi/behaviours/text_belief_extractor.py index bc98bf1..8a8273e 100644 --- a/src/control_backend/agents/bdi/behaviours/text_belief_extractor.py +++ b/src/control_backend/agents/bdi/behaviours/text_belief_extractor.py @@ -1,4 +1,5 @@ import json +import logging from spade.behaviour import CyclicBehaviour from spade.message import Message @@ -7,6 +8,8 @@ from control_backend.core.config import settings class BeliefFromText(CyclicBehaviour): + logger = logging.getLogger(__name__) + # TODO: LLM prompt nog hardcoded llm_instruction_prompt = """ You are an information extraction assistent for a BDI agent. Your task is to extract values \ @@ -36,6 +39,9 @@ class BeliefFromText(CyclicBehaviour): async def run(self): msg = await self.receive() + if msg is None: + return + sender = msg.sender.node match sender: case settings.agent_settings.transcription_agent_name: @@ -62,10 +68,14 @@ class BeliefFromText(CyclicBehaviour): # Verify by trying to parse try: json.loads(response) - belief_message = Message( - to=settings.agent_settings.bdi_core_agent_name + "@" + settings.agent_settings.host, - body=response, + belief_message = Message() + + belief_message.to = ( + settings.agent_settings.belief_collector_agent_name + + "@" + + settings.agent_settings.host ) + belief_message.body = response belief_message.thread = "beliefs" await self.send(belief_message) @@ -82,12 +92,12 @@ class BeliefFromText(CyclicBehaviour): """ belief = {"beliefs": {"user_said": [txt]}, "type": "belief_extraction_text"} payload = json.dumps(belief) - belief_msg = Message( - to=settings.agent_settings.belief_collector_agent_name - + "@" - + settings.agent_settings.host, - body=payload, + belief_msg = Message() + + belief_msg.to = ( + settings.agent_settings.belief_collector_agent_name + "@" + settings.agent_settings.host ) + belief_msg.body = payload belief_msg.thread = "beliefs" await self.send(belief_msg) diff --git a/test/unit/agents/belief_from_text/behaviours/test_belief_from_text.py b/test/unit/agents/belief_from_text/behaviours/test_belief_from_text.py new file mode 100644 index 0000000..7a3eacd --- /dev/null +++ b/test/unit/agents/belief_from_text/behaviours/test_belief_from_text.py @@ -0,0 +1,187 @@ +import json +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from spade.message import Message + +from control_backend.agents.bdi.behaviours.text_belief_extractor import BeliefFromText + + +@pytest.fixture +def mock_settings(): + """ + Mocks the settings object that the behaviour imports. + We patch it at the source where it's imported by the module under test. + """ + # Create a mock object that mimics the nested structure + settings_mock = MagicMock() + settings_mock.agent_settings.transcription_agent_name = "transcriber" + settings_mock.agent_settings.belief_collector_agent_name = "collector" + settings_mock.agent_settings.host = "fake.host" + + # Use patch to replace the settings object during the test + # Adjust 'control_backend.behaviours.belief_from_text.settings' to where + # your behaviour file imports it from. + with patch( + "control_backend.agents.bdi.behaviours.text_belief_extractor.settings", settings_mock + ): + yield settings_mock + + +@pytest.fixture +def behavior(mock_settings): + """ + Creates an instance of the BeliefFromText behaviour and mocks its + agent, logger, send, and receive methods. + """ + b = BeliefFromText() + + b.agent = MagicMock() + b.send = AsyncMock() + b.receive = AsyncMock() + + return b + + +def create_mock_message(sender_node: str, body: str, thread: str) -> MagicMock: + """Helper function to create a configured mock message.""" + msg = MagicMock() + msg.sender.node = sender_node # MagicMock automatically creates nested mocks + msg.body = body + msg.thread = thread + return msg + + +@pytest.mark.asyncio +async def test_run_no_message(behavior): + """ + Tests the run() method when no message is received. + """ + # Arrange: Configure receive to return None + behavior.receive.return_value = None + + # Act: Run the behavior + await behavior.run() + + # Assert + # 1. Check that receive was called + behavior.receive.assert_called_once() + # 2. Check that no message was sent + behavior.send.assert_not_called() + + +@pytest.mark.asyncio +async def test_run_message_from_other_agent(behavior): + """ + Tests the run() method when a message is received from an + unknown agent (not the transcriber). + """ + # Arrange: Create a mock message from an unknown sender + mock_msg = create_mock_message("unknown", "some data", None) + behavior.receive.return_value = mock_msg + behavior._process_transcription_demo = MagicMock() + + # Act + await behavior.run() + + # Assert + # 1. Check that receive was called + behavior.receive.assert_called_once() + # 2. Check that _process_transcription_demo was not sent + behavior._process_transcription_demo.assert_not_called() + + +@pytest.mark.asyncio +async def test_run_message_from_transcriber_demo(behavior, mock_settings, monkeypatch): + """ + Tests the main success path: receiving a message from the + transcription agent, which triggers _process_transcription_demo. + """ + # Arrange: Create a mock message from the transcriber + transcription_text = "hello world" + mock_msg = create_mock_message( + mock_settings.agent_settings.transcription_agent_name, transcription_text, None + ) + behavior.receive.return_value = mock_msg + + # Act + await behavior.run() + + # Assert + # 1. Check that receive was called + behavior.receive.assert_called_once() + + # 2. Check that send was called *once* + behavior.send.assert_called_once() + + # 3. Deeply inspect the message that was sent + sent_msg: Message = behavior.send.call_args[0][0] + + assert ( + sent_msg.to + == mock_settings.agent_settings.belief_collector_agent_name + + "@" + + mock_settings.agent_settings.host + ) + + # Check thread + assert sent_msg.thread == "beliefs" + + # Parse the received JSON string back into a dict + expected_dict = { + "beliefs": {"user_said": [transcription_text]}, + "type": "belief_extraction_text", + } + sent_dict = json.loads(sent_msg.body) + + # Assert that the dictionaries are equal + assert sent_dict == expected_dict + + +@pytest.mark.asyncio +async def test_process_transcription_success(behavior, mock_settings): + """ + Tests the (currently unused) _process_transcription method's + success path, using its hardcoded mock response. + """ + # Arrange + test_text = "I am feeling happy" + # This is the hardcoded response inside the method + expected_response_body = '{"mood": [["happy"]]}' + + # Act + await behavior._process_transcription(test_text) + + # Assert + # 1. Check that a message was sent + behavior.send.assert_called_once() + + # 2. Inspect the sent message + sent_msg: Message = behavior.send.call_args[0][0] + expected_to = ( + mock_settings.agent_settings.belief_collector_agent_name + + "@" + + mock_settings.agent_settings.host + ) + assert str(sent_msg.to) == expected_to + assert sent_msg.thread == "beliefs" + assert sent_msg.body == expected_response_body + + +@pytest.mark.asyncio +async def test_process_transcription_json_decode_error(behavior, mock_settings): + """ + Tests the _process_transcription method's error handling + when the (mocked) response is invalid JSON. + We do this by patching json.loads to raise an error. + """ + # Arrange + test_text = "I am feeling happy" + # Patch json.loads to raise an error when called + with patch("json.loads", side_effect=json.JSONDecodeError("Mock error", "", 0)): + # Act + await behavior._process_transcription(test_text) + + # Assert + # 1. Check that NO message was sent + behavior.send.assert_not_called()