diff --git a/src/control_backend/agents/bdi/bdi_core_agent.py b/src/control_backend/agents/bdi/bdi_core_agent.py index 427e024..23c2808 100644 --- a/src/control_backend/agents/bdi/bdi_core_agent.py +++ b/src/control_backend/agents/bdi/bdi_core_agent.py @@ -11,7 +11,7 @@ from pydantic import ValidationError from control_backend.agents.base import BaseAgent from control_backend.core.agent_system import InternalMessage from control_backend.core.config import settings -from control_backend.schemas.belief_message import Belief, BeliefMessage +from control_backend.schemas.belief_message import BeliefMessage from control_backend.schemas.llm_prompt_message import LLMPromptMessage from control_backend.schemas.ri_message import SpeechCommand @@ -124,8 +124,8 @@ class BDICoreAgent(BaseAgent): if msg.thread == "beliefs": try: - beliefs = BeliefMessage.model_validate_json(msg.body).beliefs - self._apply_beliefs(beliefs) + belief_changes = BeliefMessage.model_validate_json(msg.body) + self._apply_belief_changes(belief_changes) except ValidationError: self.logger.exception("Error processing belief.") return @@ -145,21 +145,28 @@ class BDICoreAgent(BaseAgent): ) await self.send(out_msg) - def _apply_beliefs(self, beliefs: list[Belief]): + def _apply_belief_changes(self, belief_changes: BeliefMessage): """ Update the belief base with a list of new beliefs. - If ``replace=True`` is set on a belief, it removes all existing beliefs with that name - before adding the new one. + For beliefs in ``belief_changes.replace``, it removes all existing beliefs with that name + before adding one new one. + + :param belief_changes: The changes in beliefs to apply. """ - if not beliefs: + if not belief_changes.create and not belief_changes.replace and not belief_changes.delete: return - for belief in beliefs: - if belief.replace: - self._remove_all_with_name(belief.name) + for belief in belief_changes.create: self._add_belief(belief.name, belief.arguments) + for belief in belief_changes.replace: + self._remove_all_with_name(belief.name) + self._add_belief(belief.name, belief.arguments) + + for belief in belief_changes.delete: + self._remove_belief(belief.name, belief.arguments) + def _add_belief(self, name: str, args: list[str] = None): """ Add a single belief to the BDI agent. diff --git a/src/control_backend/agents/bdi/belief_collector_agent.py b/src/control_backend/agents/bdi/belief_collector_agent.py index 788cff1..ac0e2e5 100644 --- a/src/control_backend/agents/bdi/belief_collector_agent.py +++ b/src/control_backend/agents/bdi/belief_collector_agent.py @@ -144,7 +144,7 @@ class BDIBeliefCollectorAgent(BaseAgent): msg = InternalMessage( to=settings.agent_settings.bdi_core_name, sender=self.name, - body=BeliefMessage(beliefs=beliefs).model_dump_json(), + body=BeliefMessage(create=beliefs).model_dump_json(), thread="beliefs", ) diff --git a/src/control_backend/agents/bdi/text_belief_extractor_agent.py b/src/control_backend/agents/bdi/text_belief_extractor_agent.py index 0324573..5cc75d8 100644 --- a/src/control_backend/agents/bdi/text_belief_extractor_agent.py +++ b/src/control_backend/agents/bdi/text_belief_extractor_agent.py @@ -34,8 +34,8 @@ class TextBeliefExtractorAgent(BaseAgent): def __init__(self, name: str): super().__init__(name) - self.beliefs = {} - self.available_beliefs = [] + self.beliefs: dict[str, bool] = {} + self.available_beliefs: list[SemanticBelief] = [] self.conversation = ChatHistory(messages=[]) async def setup(self): @@ -151,23 +151,30 @@ class TextBeliefExtractorAgent(BaseAgent): return candidate_beliefs = await self._infer_turn() - new_beliefs: list[InternalBelief] = [] + belief_changes = BeliefMessage() for belief_key, belief_value in candidate_beliefs.items(): if belief_value is None: continue old_belief_value = self.beliefs.get(belief_key) - # TODO: Do we need this check? Can we send the same beliefs multiple times? if belief_value == old_belief_value: continue + self.beliefs[belief_key] = belief_value - new_beliefs.append( - InternalBelief(name=belief_key, arguments=[belief_value], replace=True), - ) + + belief = InternalBelief(name=belief_key, arguments=None) + if belief_value: + belief_changes.create.append(belief) + else: + belief_changes.delete.append(belief) + + # Return if there were no changes in beliefs + if not belief_changes.has_values(): + return beliefs_message = InternalMessage( to=settings.agent_settings.bdi_core_name, sender=self.name, - body=BeliefMessage(beliefs=new_beliefs).model_dump_json(), + body=belief_changes.model_dump_json(), thread="beliefs", ) await self.send(beliefs_message) @@ -184,7 +191,7 @@ class TextBeliefExtractorAgent(BaseAgent): :return: A dict mapping belief names to a value ``True``, ``False`` or ``None``. """ - n_parallel = min(settings.llm_settings.n_parallel - 1, len(self.available_beliefs)) + n_parallel = max(1, min(settings.llm_settings.n_parallel - 1, len(self.available_beliefs))) all_beliefs = await asyncio.gather( *[ self._infer_beliefs(self.conversation, beliefs) @@ -286,7 +293,7 @@ Respond with a JSON similar to the following, but with the property names as giv try: return await self._query_llm(prompt, schema) - except (httpx.HTTPStatusError, json.JSONDecodeError, KeyError) as e: + except (httpx.HTTPError, json.JSONDecodeError, KeyError) as e: if try_count < tries: continue self.logger.exception( diff --git a/src/control_backend/schemas/belief_message.py b/src/control_backend/schemas/belief_message.py index deb1152..56a8a4a 100644 --- a/src/control_backend/schemas/belief_message.py +++ b/src/control_backend/schemas/belief_message.py @@ -6,18 +6,27 @@ class Belief(BaseModel): Represents a single belief in the BDI system. :ivar name: The functor or name of the belief (e.g., 'user_said'). - :ivar arguments: A list of string arguments for the belief. - :ivar replace: If True, existing beliefs with this name should be replaced by this one. + :ivar arguments: A list of string arguments for the belief, or None if the belief has no + arguments. """ name: str - arguments: list[str] - replace: bool = False + arguments: list[str] | None class BeliefMessage(BaseModel): """ - A container for transporting a list of beliefs between agents. + A container for communicating beliefs between agents. + + :ivar create: Beliefs to create. + :ivar delete: Beliefs to delete. + :ivar replace: Beliefs to replace. Deletes all beliefs with the same name, replacing them with + one new belief. """ - beliefs: list[Belief] + create: list[Belief] = [] + delete: list[Belief] = [] + replace: list[Belief] = [] + + def has_values(self) -> bool: + return len(self.create) > 0 or len(self.delete) > 0 or len(self.replace) > 0 diff --git a/src/control_backend/schemas/program.py b/src/control_backend/schemas/program.py index 5a8caa9..be538b0 100644 --- a/src/control_backend/schemas/program.py +++ b/src/control_backend/schemas/program.py @@ -194,7 +194,7 @@ class Phase(ProgramElement): """ name: str = "" - norms: list[Norm] + norms: list[BasicNorm | ConditionalNorm] goals: list[Goal] triggers: list[Trigger] diff --git a/test/unit/agents/bdi/test_bdi_core_agent.py b/test/unit/agents/bdi/test_bdi_core_agent.py index 8d004fc..2325a57 100644 --- a/test/unit/agents/bdi/test_bdi_core_agent.py +++ b/test/unit/agents/bdi/test_bdi_core_agent.py @@ -51,7 +51,7 @@ async def test_handle_belief_collector_message(agent, mock_settings): msg = InternalMessage( to="bdi_agent", sender=mock_settings.agent_settings.bdi_belief_collector_name, - body=BeliefMessage(beliefs=beliefs).model_dump_json(), + body=BeliefMessage(create=beliefs).model_dump_json(), thread="beliefs", ) @@ -64,6 +64,26 @@ async def test_handle_belief_collector_message(agent, mock_settings): assert args[2] == agentspeak.Literal("user_said", (agentspeak.Literal("Hello"),)) +@pytest.mark.asyncio +async def test_handle_delete_belief_message(agent, mock_settings): + """Test that incoming beliefs to be deleted are removed from the BDI agent""" + beliefs = [Belief(name="user_said", arguments=["Hello"])] + + msg = InternalMessage( + to="bdi_agent", + sender=mock_settings.agent_settings.bdi_belief_collector_name, + body=BeliefMessage(delete=beliefs).model_dump_json(), + thread="beliefs", + ) + await agent.handle_message(msg) + + # Expect bdi_agent.call to be triggered to remove belief + args = agent.bdi_agent.call.call_args.args + assert args[0] == agentspeak.Trigger.removal + assert args[1] == agentspeak.GoalType.belief + assert args[2] == agentspeak.Literal("user_said", (agentspeak.Literal("Hello"),)) + + @pytest.mark.asyncio async def test_incorrect_belief_collector_message(agent, mock_settings): """Test that incorrect message format triggers an exception.""" @@ -128,7 +148,8 @@ def test_add_belief_sets_event(agent): agent._wake_bdi_loop = MagicMock() belief = Belief(name="test_belief", arguments=["a", "b"]) - agent._apply_beliefs([belief]) + belief_changes = BeliefMessage(replace=[belief]) + agent._apply_belief_changes(belief_changes) assert agent.bdi_agent.call.called agent._wake_bdi_loop.set.assert_called() @@ -137,7 +158,7 @@ def test_add_belief_sets_event(agent): def test_apply_beliefs_empty_returns(agent): """Line: if not beliefs: return""" agent._wake_bdi_loop = MagicMock() - agent._apply_beliefs([]) + agent._apply_belief_changes(BeliefMessage()) agent.bdi_agent.call.assert_not_called() agent._wake_bdi_loop.set.assert_not_called() @@ -220,8 +241,9 @@ def test_replace_belief_calls_remove_all(agent): agent._remove_all_with_name = MagicMock() agent._wake_bdi_loop = MagicMock() - belief = Belief(name="user_said", arguments=["Hello"], replace=True) - agent._apply_beliefs([belief]) + belief = Belief(name="user_said", arguments=["Hello"]) + belief_changes = BeliefMessage(replace=[belief]) + agent._apply_belief_changes(belief_changes) agent._remove_all_with_name.assert_called_with("user_said") diff --git a/test/unit/agents/bdi/test_belief_collector.py b/test/unit/agents/bdi/test_belief_collector.py index 67b2ed5..69db269 100644 --- a/test/unit/agents/bdi/test_belief_collector.py +++ b/test/unit/agents/bdi/test_belief_collector.py @@ -86,7 +86,7 @@ async def test_send_beliefs_to_bdi(agent): sent: InternalMessage = agent.send.call_args.args[0] assert sent.to == settings.agent_settings.bdi_core_name assert sent.thread == "beliefs" - assert json.loads(sent.body)["beliefs"] == [belief.model_dump() for belief in beliefs] + assert json.loads(sent.body)["create"] == [belief.model_dump() for belief in beliefs] @pytest.mark.asyncio diff --git a/test/unit/agents/bdi/test_text_belief_extractor.py b/test/unit/agents/bdi/test_text_belief_extractor.py new file mode 100644 index 0000000..827adbc --- /dev/null +++ b/test/unit/agents/bdi/test_text_belief_extractor.py @@ -0,0 +1,346 @@ +import json +import uuid +from unittest.mock import AsyncMock, MagicMock, patch + +import httpx +import pytest + +from control_backend.agents.bdi import TextBeliefExtractorAgent +from control_backend.core.agent_system import InternalMessage +from control_backend.core.config import settings +from control_backend.schemas.belief_message import BeliefMessage +from control_backend.schemas.program import ( + ConditionalNorm, + LLMAction, + Phase, + Plan, + Program, + SemanticBelief, + Trigger, +) + + +@pytest.fixture +def agent(): + agent = TextBeliefExtractorAgent("text_belief_agent") + agent.send = AsyncMock() + agent._query_llm = AsyncMock() + return agent + + +@pytest.fixture +def sample_program(): + return Program( + phases=[ + Phase( + name="Some phase", + id=uuid.uuid4(), + norms=[ + ConditionalNorm( + name="Some norm", + id=uuid.uuid4(), + norm="Use nautical terms.", + critical=False, + condition=SemanticBelief( + name="is_pirate", + id=uuid.uuid4(), + description="The user is a pirate. Perhaps because they say " + "they are, or because they speak like a pirate " + 'with terms like "arr".', + ), + ), + ], + goals=[], + triggers=[ + Trigger( + name="Some trigger", + id=uuid.uuid4(), + condition=SemanticBelief( + name="no_more_booze", + id=uuid.uuid4(), + description="There is no more alcohol.", + ), + plan=Plan( + name="Some plan", + id=uuid.uuid4(), + steps=[ + LLMAction( + name="Some action", + id=uuid.uuid4(), + goal="Suggest eating chocolate instead.", + ), + ], + ), + ), + ], + ), + ], + ) + + +def make_msg(sender: str, body: str, thread: str | None = None) -> InternalMessage: + return InternalMessage(to="unused", sender=sender, body=body, thread=thread) + + +@pytest.mark.asyncio +async def test_handle_message_ignores_other_agents(agent): + msg = make_msg("unknown", "some data", None) + + await agent.handle_message(msg) + + agent.send.assert_not_called() # noqa # `agent.send` has no such property, but we mock it. + + +@pytest.mark.asyncio +async def test_handle_message_from_transcriber(agent, mock_settings): + transcription = "hello world" + msg = make_msg(mock_settings.agent_settings.transcription_name, transcription, None) + + await agent.handle_message(msg) + + agent.send.assert_awaited_once() # noqa # `agent.send` has no such property, but we mock it. + sent: InternalMessage = agent.send.call_args.args[0] # noqa + assert sent.to == mock_settings.agent_settings.bdi_belief_collector_name + assert sent.thread == "beliefs" + parsed = json.loads(sent.body) + assert parsed == {"beliefs": {"user_said": [transcription]}, "type": "belief_extraction_text"} + + +@pytest.mark.asyncio +async def test_process_user_said(agent, mock_settings): + transcription = "this is a test" + + await agent._user_said(transcription) + + agent.send.assert_awaited_once() # noqa # `agent.send` has no such property, but we mock it. + sent: InternalMessage = agent.send.call_args.args[0] # noqa + assert sent.to == mock_settings.agent_settings.bdi_belief_collector_name + assert sent.thread == "beliefs" + parsed = json.loads(sent.body) + assert parsed["beliefs"]["user_said"] == [transcription] + + +@pytest.mark.asyncio +async def test_query_llm(): + mock_response = MagicMock() + mock_response.json.return_value = { + "choices": [ + { + "message": { + "content": "null", + } + } + ] + } + mock_client = AsyncMock() + mock_client.post.return_value = mock_response + mock_async_client = MagicMock() + mock_async_client.__aenter__.return_value = mock_client + mock_async_client.__aexit__.return_value = None + + with patch( + "control_backend.agents.bdi.text_belief_extractor_agent.httpx.AsyncClient", + return_value=mock_async_client, + ): + agent = TextBeliefExtractorAgent("text_belief_agent") + + res = await agent._query_llm("hello world", {"type": "null"}) + # Response content was set as "null", so should be deserialized as None + assert res is None + + +@pytest.mark.asyncio +async def test_retry_query_llm_success(agent): + agent._query_llm.return_value = None + res = await agent._retry_query_llm("hello world", {"type": "null"}) + + agent._query_llm.assert_called_once() + assert res is None + + +@pytest.mark.asyncio +async def test_retry_query_llm_success_after_failure(agent): + agent._query_llm.side_effect = [KeyError(), "real value"] + res = await agent._retry_query_llm("hello world", {"type": "string"}) + + assert agent._query_llm.call_count == 2 + assert res == "real value" + + +@pytest.mark.asyncio +async def test_retry_query_llm_failures(agent): + agent._query_llm.side_effect = [KeyError(), KeyError(), KeyError(), "real value"] + res = await agent._retry_query_llm("hello world", {"type": "string"}) + + assert agent._query_llm.call_count == 3 + assert res is None + + +@pytest.mark.asyncio +async def test_retry_query_llm_fail_immediately(agent): + agent._query_llm.side_effect = [KeyError(), "real value"] + res = await agent._retry_query_llm("hello world", {"type": "string"}, tries=1) + + assert agent._query_llm.call_count == 1 + assert res is None + + +@pytest.mark.asyncio +async def test_extracting_beliefs_from_program(agent, sample_program): + assert len(agent.available_beliefs) == 0 + await agent.handle_message( + InternalMessage( + to=settings.agent_settings.text_belief_extractor_name, + sender=settings.agent_settings.bdi_program_manager_name, + body=sample_program.model_dump_json(), + ), + ) + assert len(agent.available_beliefs) == 2 + + +@pytest.mark.asyncio +async def test_handle_invalid_program(agent, sample_program): + agent.available_beliefs.append(sample_program.phases[0].norms[0].condition) + agent.available_beliefs.append(sample_program.phases[0].triggers[0].condition) + assert len(agent.available_beliefs) == 2 + + await agent.handle_message( + InternalMessage( + to=settings.agent_settings.text_belief_extractor_name, + sender=settings.agent_settings.bdi_program_manager_name, + body=json.dumps({"phases": "Invalid"}), + ), + ) + + assert len(agent.available_beliefs) == 2 + + +@pytest.mark.asyncio +async def test_handle_robot_response(agent): + initial_length = len(agent.conversation.messages) + response = "Hi, I'm Pepper. What's your name?" + + await agent.handle_message( + InternalMessage( + to=settings.agent_settings.text_belief_extractor_name, + sender=settings.agent_settings.llm_name, + body=response, + ), + ) + + assert len(agent.conversation.messages) == initial_length + 1 + assert agent.conversation.messages[-1].role == "assistant" + assert agent.conversation.messages[-1].content == response + + +@pytest.mark.asyncio +async def test_simulated_real_turn_with_beliefs(agent, sample_program): + """Test sending user message to extract beliefs from.""" + agent.available_beliefs.append(sample_program.phases[0].norms[0].condition) + agent.available_beliefs.append(sample_program.phases[0].triggers[0].condition) + + # Send a user message with the belief that there's no more booze + agent._query_llm.return_value = {"is_pirate": None, "no_more_booze": True} + assert len(agent.conversation.messages) == 0 + await agent.handle_message( + InternalMessage( + to=settings.agent_settings.text_belief_extractor_name, + sender=settings.agent_settings.transcription_name, + body="We're all out of schnaps.", + ), + ) + assert len(agent.conversation.messages) == 1 + + # There should be a belief set and sent to the BDI core, as well as the user_said belief + assert agent.send.call_count == 2 + + # First should be the beliefs message + message: InternalMessage = agent.send.call_args_list[0].args[0] + beliefs = BeliefMessage.model_validate_json(message.body) + assert len(beliefs.create) == 1 + assert beliefs.create[0].name == "no_more_booze" + + +@pytest.mark.asyncio +async def test_simulated_real_turn_no_beliefs(agent, sample_program): + """Test a user message to extract beliefs from, but no beliefs are formed.""" + agent.available_beliefs.append(sample_program.phases[0].norms[0].condition) + agent.available_beliefs.append(sample_program.phases[0].triggers[0].condition) + + # Send a user message with no new beliefs + agent._query_llm.return_value = {"is_pirate": None, "no_more_booze": None} + await agent.handle_message( + InternalMessage( + to=settings.agent_settings.text_belief_extractor_name, + sender=settings.agent_settings.transcription_name, + body="Hello there!", + ), + ) + + # Only the user_said belief should've been sent + agent.send.assert_called_once() + + +@pytest.mark.asyncio +async def test_simulated_real_turn_no_new_beliefs(agent, sample_program): + """ + Test a user message to extract beliefs from, but no new beliefs are formed because they already + existed. + """ + agent.available_beliefs.append(sample_program.phases[0].norms[0].condition) + agent.available_beliefs.append(sample_program.phases[0].triggers[0].condition) + agent.beliefs["is_pirate"] = True + + # Send a user message with the belief the user is a pirate, still + agent._query_llm.return_value = {"is_pirate": True, "no_more_booze": None} + await agent.handle_message( + InternalMessage( + to=settings.agent_settings.text_belief_extractor_name, + sender=settings.agent_settings.transcription_name, + body="Arr, nice to meet you, matey.", + ), + ) + + # Only the user_said belief should've been sent, as no beliefs have changed + agent.send.assert_called_once() + + +@pytest.mark.asyncio +async def test_simulated_real_turn_remove_belief(agent, sample_program): + """ + Test a user message to extract beliefs from, but an existing belief is determined no longer to + hold. + """ + agent.available_beliefs.append(sample_program.phases[0].norms[0].condition) + agent.available_beliefs.append(sample_program.phases[0].triggers[0].condition) + agent.beliefs["no_more_booze"] = True + + # Send a user message with the belief the user is a pirate, still + agent._query_llm.return_value = {"is_pirate": None, "no_more_booze": False} + await agent.handle_message( + InternalMessage( + to=settings.agent_settings.text_belief_extractor_name, + sender=settings.agent_settings.transcription_name, + body="I found an untouched barrel of wine!", + ), + ) + + # Both user_said and belief change should've been sent + assert agent.send.call_count == 2 + + # Agent's current beliefs should've changed + assert not agent.beliefs["no_more_booze"] + + +@pytest.mark.asyncio +async def test_llm_failure_handling(agent, sample_program): + """ + Check that the agent handles failures gracefully without crashing. + """ + agent._query_llm.side_effect = httpx.HTTPError("") + agent.available_beliefs.append(sample_program.phases[0].norms[0].condition) + agent.available_beliefs.append(sample_program.phases[0].triggers[0].condition) + + belief_changes = await agent._infer_turn() + + assert len(belief_changes) == 0 diff --git a/test/unit/agents/bdi/test_text_extractor.py b/test/unit/agents/bdi/test_text_extractor.py deleted file mode 100644 index c51571a..0000000 --- a/test/unit/agents/bdi/test_text_extractor.py +++ /dev/null @@ -1,58 +0,0 @@ -import json -from unittest.mock import AsyncMock - -import pytest - -from control_backend.agents.bdi import ( - TextBeliefExtractorAgent, -) -from control_backend.core.agent_system import InternalMessage - - -@pytest.fixture -def agent(): - agent = TextBeliefExtractorAgent("text_belief_agent") - agent.send = AsyncMock() - return agent - - -def make_msg(sender: str, body: str, thread: str | None = None) -> InternalMessage: - return InternalMessage(to="unused", sender=sender, body=body, thread=thread) - - -@pytest.mark.asyncio -async def test_handle_message_ignores_other_agents(agent): - msg = make_msg("unknown", "some data", None) - - await agent.handle_message(msg) - - agent.send.assert_not_called() # noqa # `agent.send` has no such property, but we mock it. - - -@pytest.mark.asyncio -async def test_handle_message_from_transcriber(agent, mock_settings): - transcription = "hello world" - msg = make_msg(mock_settings.agent_settings.transcription_name, transcription, None) - - await agent.handle_message(msg) - - agent.send.assert_awaited_once() # noqa # `agent.send` has no such property, but we mock it. - sent: InternalMessage = agent.send.call_args.args[0] # noqa - assert sent.to == mock_settings.agent_settings.bdi_belief_collector_name - assert sent.thread == "beliefs" - parsed = json.loads(sent.body) - assert parsed == {"beliefs": {"user_said": [transcription]}, "type": "belief_extraction_text"} - - -@pytest.mark.asyncio -async def test_process_user_said(agent, mock_settings): - transcription = "this is a test" - - await agent._user_said(transcription) - - agent.send.assert_awaited_once() # noqa # `agent.send` has no such property, but we mock it. - sent: InternalMessage = agent.send.call_args.args[0] # noqa - assert sent.to == mock_settings.agent_settings.bdi_belief_collector_name - assert sent.thread == "beliefs" - parsed = json.loads(sent.body) - assert parsed["beliefs"]["user_said"] == [transcription] diff --git a/test/unit/schemas/test_ui_program_message.py b/test/unit/schemas/test_ui_program_message.py index a9f96dd..6014db7 100644 --- a/test/unit/schemas/test_ui_program_message.py +++ b/test/unit/schemas/test_ui_program_message.py @@ -5,11 +5,15 @@ from pydantic import ValidationError from control_backend.schemas.program import ( BasicNorm, + ConditionalNorm, Goal, + InferredBelief, KeywordBelief, + LogicalOperator, Phase, Plan, Program, + SemanticBelief, Trigger, ) @@ -97,3 +101,104 @@ def test_invalid_program(): bad = invalid_program() with pytest.raises(ValidationError): Program.model_validate(bad) + + +def test_conditional_norm_parsing(): + """ + Check that pydantic is able to preserve the type of the norm, that it doesn't lose its + "condition" field when serializing and deserializing. + """ + norm = ConditionalNorm( + name="testNormName", + id=uuid.uuid4(), + norm="testNormNorm", + critical=False, + condition=KeywordBelief( + name="testKeywordBelief", + id=uuid.uuid4(), + keyword="testKeywordBelief", + ), + ) + program = Program( + phases=[ + Phase( + name="Some phase", + id=uuid.uuid4(), + norms=[norm], + goals=[], + triggers=[], + ), + ], + ) + + parsed_program = Program.model_validate_json(program.model_dump_json()) + parsed_norm = parsed_program.phases[0].norms[0] + + assert hasattr(parsed_norm, "condition") + assert isinstance(parsed_norm, ConditionalNorm) + + +def test_belief_type_parsing(): + """ + Check that pydantic is able to discern between the different types of beliefs when serializing + and deserializing. + """ + keyword_belief = KeywordBelief( + name="testKeywordBelief", + id=uuid.uuid4(), + keyword="something", + ) + semantic_belief = SemanticBelief( + name="testSemanticBelief", + id=uuid.uuid4(), + description="something", + ) + inferred_belief = InferredBelief( + name="testInferredBelief", + id=uuid.uuid4(), + operator=LogicalOperator.OR, + left=keyword_belief, + right=semantic_belief, + ) + + program = Program( + phases=[ + Phase( + name="Some phase", + id=uuid.uuid4(), + norms=[], + goals=[], + triggers=[ + Trigger( + name="testTriggerKeywordTrigger", + id=uuid.uuid4(), + condition=keyword_belief, + plan=Plan(name="testTriggerPlanName", id=uuid.uuid4(), steps=[]), + ), + Trigger( + name="testTriggerSemanticTrigger", + id=uuid.uuid4(), + condition=semantic_belief, + plan=Plan(name="testTriggerPlanName", id=uuid.uuid4(), steps=[]), + ), + Trigger( + name="testTriggerInferredTrigger", + id=uuid.uuid4(), + condition=inferred_belief, + plan=Plan(name="testTriggerPlanName", id=uuid.uuid4(), steps=[]), + ), + ], + ), + ], + ) + + parsed_program = Program.model_validate_json(program.model_dump_json()) + + parsed_keyword_belief = parsed_program.phases[0].triggers[0].condition + assert isinstance(parsed_keyword_belief, KeywordBelief) + + parsed_semantic_belief = parsed_program.phases[0].triggers[1].condition + assert isinstance(parsed_semantic_belief, SemanticBelief) + + parsed_inferred_belief = parsed_program.phases[0].triggers[2].condition + assert isinstance(parsed_inferred_belief, InferredBelief)