diff --git a/src/control_backend/agents/bdi/bdi_program_manager.py b/src/control_backend/agents/bdi/bdi_program_manager.py
index 83dea93..010b9a0 100644
--- a/src/control_backend/agents/bdi/bdi_program_manager.py
+++ b/src/control_backend/agents/bdi/bdi_program_manager.py
@@ -60,24 +60,41 @@ class BDIProgramManager(BaseAgent):
         await self.send(message)
         self.logger.debug("Sent new norms and goals to the BDI agent.")
 
+    async def _send_clear_llm_history(self):
+        """
+        Clear the LLM Agent's conversation history.
+
+        Sends a ``clear_history`` command to the LLM Agent so that it resets its state.
+        """
+        message = InternalMessage(
+            to=settings.agent_settings.llm_name,
+            sender=self.name,
+            body="clear_history",
+            thread="clear history message",
+        )
+        await self.send(message)
+        self.logger.debug("Sent message to LLM agent to clear history.")
+
     async def _receive_programs(self):
         """
         Continuous loop that receives program updates from the HTTP endpoint.
 
         It listens to the ``program`` topic on the internal ZMQ SUB socket.
         When a program is received, it is validated and forwarded to BDI via :meth:`_send_to_bdi`.
+        Additionally, the LLM history is cleared via :meth:`_send_clear_llm_history`.
         """
         while True:
             topic, body = await self.sub_socket.recv_multipart()
 
             try:
                 program = Program.model_validate_json(body)
             except ValidationError:
                 self.logger.exception("Received an invalid program.")
                 continue
 
             await self._send_to_bdi(program)
+            await self._send_clear_llm_history()
 
     async def setup(self):
         """
         Initialize the agent.
diff --git a/src/control_backend/agents/llm/llm_agent.py b/src/control_backend/agents/llm/llm_agent.py
index 0263b30..f1c70c9 100644
--- a/src/control_backend/agents/llm/llm_agent.py
+++ b/src/control_backend/agents/llm/llm_agent.py
@@ -52,6 +52,10 @@ class LLMAgent(BaseAgent):
                 await self._process_bdi_message(prompt_message)
             except ValidationError:
                 self.logger.debug("Prompt message from BDI core is invalid.")
+        elif msg.sender == settings.agent_settings.bdi_program_manager_name:
+            if msg.body == "clear_history":
+                self.logger.debug("Clearing conversation history.")
+                self.history.clear()
         else:
             self.logger.debug("Message ignored (not from BDI core.")
 
diff --git a/test/unit/agents/bdi/test_bdi_program_manager.py b/test/unit/agents/bdi/test_bdi_program_manager.py
new file mode 100644
index 0000000..573524e
--- /dev/null
+++ b/test/unit/agents/bdi/test_bdi_program_manager.py
@@ -0,0 +1,97 @@
+import asyncio
+import json
+import sys
+from unittest.mock import AsyncMock
+
+import pytest
+
+from control_backend.agents.bdi.bdi_program_manager import BDIProgramManager
+from control_backend.core.agent_system import InternalMessage
+from control_backend.schemas.belief_message import BeliefMessage
+from control_backend.schemas.program import Program
+
+# Fix Windows Proactor loop for zmq
+if sys.platform.startswith("win"):
+    asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
+
+
+def make_valid_program_json(norm="N1", goal="G1"):
+    return json.dumps(
+        {
+            "phases": [
+                {
+                    "id": "phase1",
+                    "label": "Phase 1",
+                    "triggers": [],
+                    "norms": [{"id": "n1", "label": "Norm 1", "norm": norm}],
+                    "goals": [
+                        {"id": "g1", "label": "Goal 1", "description": goal, "achieved": False}
+                    ],
+                }
+            ]
+        }
+    )
+
+
+@pytest.mark.asyncio
+async def test_send_to_bdi():
+    manager = BDIProgramManager(name="program_manager_test")
+    manager.send = AsyncMock()
+
+    program = Program.model_validate_json(make_valid_program_json())
+    await manager._send_to_bdi(program)
+
+    assert manager.send.await_count == 1
+    msg: InternalMessage = manager.send.await_args[0][0]
+    assert msg.thread == "beliefs"
+
+    beliefs = BeliefMessage.model_validate_json(msg.body)
+    names = {b.name: b.arguments for b in beliefs.beliefs}
+
+    assert "norms" in names and names["norms"] == ["N1"]
+    assert "goals" in names and names["goals"] == ["G1"]
+
+
+@pytest.mark.asyncio
+async def test_receive_programs_valid_and_invalid():
+    sub = AsyncMock()
+    sub.recv_multipart.side_effect = [
+        (b"program", b"{bad json"),
+        (b"program", make_valid_program_json().encode()),
+    ]
+
+    manager = BDIProgramManager(name="program_manager_test")
+    manager.sub_socket = sub
+    manager._send_to_bdi = AsyncMock()
+    manager._send_clear_llm_history = AsyncMock()
+
+    # The loop only ends once the predefined `sub.recv_multipart` side effects run out.
+    with pytest.raises(StopAsyncIteration):
+        await manager._receive_programs()
+
+    # Only valid Program should have triggered _send_to_bdi
+    assert manager._send_to_bdi.await_count == 1
+    forwarded: Program = manager._send_to_bdi.await_args[0][0]
+    assert forwarded.phases[0].norms[0].norm == "N1"
+    assert forwarded.phases[0].goals[0].description == "G1"
+
+    # Verify history clear was triggered
+    assert manager._send_clear_llm_history.await_count == 1
+
+
+@pytest.mark.asyncio
+async def test_send_clear_llm_history(mock_settings):
+    # Ensure the mocked settings return a plain string for the LLM agent's name.
+    mock_settings.agent_settings.llm_name = "llm_agent"
+
+    manager = BDIProgramManager(name="program_manager_test")
+    manager.send = AsyncMock()
+
+    await manager._send_clear_llm_history()
+
+    assert manager.send.await_count == 1
+    msg: InternalMessage = manager.send.await_args[0][0]
+
+    # Verify the content and recipient
+    assert msg.body == "clear_history"
+    assert msg.to == "llm_agent"
diff --git a/test/unit/agents/llm/test_llm_agent.py b/test/unit/agents/llm/test_llm_agent.py
index 2f1b72e..5d07bb7 100644
--- a/test/unit/agents/llm/test_llm_agent.py
+++ b/test/unit/agents/llm/test_llm_agent.py
@@ -134,3 +134,23 @@ def test_llm_instructions():
     text_def = instr_def.build_developer_instruction()
     assert "Norms to follow" in text_def
     assert "Goals to reach" in text_def
+
+
+@pytest.mark.asyncio
+async def test_clear_history_command(mock_settings):
+    """Test that the 'clear_history' message clears the agent's memory."""
+    # Give the LLM agent some pre-existing history to clear.
+    mock_settings.agent_settings.bdi_program_manager_name = "bdi_program_manager_agent"
+    agent = LLMAgent("llm_agent")
+    agent.history = [
+        {"role": "user", "content": "Old conversation context"},
+        {"role": "assistant", "content": "Old response"},
+    ]
+    assert len(agent.history) == 2
+    msg = InternalMessage(
+        to="llm_agent",
+        sender=mock_settings.agent_settings.bdi_program_manager_name,
+        body="clear_history",
+    )
+    await agent.handle_message(msg)
+    assert agent.history == []