From 6ca86e4b819c706cd15b7ef7fb06af0b26100019 Mon Sep 17 00:00:00 2001 From: Pim Hutting Date: Fri, 2 Jan 2026 15:13:04 +0000 Subject: [PATCH] feat: made program reset LLM --- .../agents/bdi/bdi_program_manager.py | 21 ++++++++++++++++-- src/control_backend/agents/llm/llm_agent.py | 4 ++++ .../agents/bdi/test_bdi_program_manager.py | 22 +++++++++++++++++++ test/unit/agents/llm/test_llm_agent.py | 20 +++++++++++++++++ test/unit/test_main.py | 2 -- 5 files changed, 65 insertions(+), 4 deletions(-) diff --git a/src/control_backend/agents/bdi/bdi_program_manager.py b/src/control_backend/agents/bdi/bdi_program_manager.py index 83dea93..2f4f850 100644 --- a/src/control_backend/agents/bdi/bdi_program_manager.py +++ b/src/control_backend/agents/bdi/bdi_program_manager.py @@ -60,24 +60,41 @@ class BDIProgramManager(BaseAgent): await self.send(message) self.logger.debug("Sent new norms and goals to the BDI agent.") + async def _send_clear_llm_history(self): + """ + Clear the LLM Agent's conversation history. + + Sends an empty history to the LLM Agent to reset its state. + """ + message = InternalMessage( + to=settings.agent_settings.llm_name, + sender=self.name, + body="clear_history", + threads="clear history message", + ) + await self.send(message) + self.logger.debug("Sent message to LLM agent to clear history.") + async def _receive_programs(self): """ Continuous loop that receives program updates from the HTTP endpoint. It listens to the ``program`` topic on the internal ZMQ SUB socket. When a program is received, it is validated and forwarded to BDI via :meth:`_send_to_bdi`. + Additionally, the LLM history is cleared via :meth:`_send_clear_llm_history`. """ while True: topic, body = await self.sub_socket.recv_multipart() try: program = Program.model_validate_json(body) + await self._send_to_bdi(program) + await self._send_clear_llm_history() + except ValidationError: self.logger.exception("Received an invalid program.") continue - await self._send_to_bdi(program) - async def setup(self): """ Initialize the agent. diff --git a/src/control_backend/agents/llm/llm_agent.py b/src/control_backend/agents/llm/llm_agent.py index 0263b30..f1c70c9 100644 --- a/src/control_backend/agents/llm/llm_agent.py +++ b/src/control_backend/agents/llm/llm_agent.py @@ -52,6 +52,10 @@ class LLMAgent(BaseAgent): await self._process_bdi_message(prompt_message) except ValidationError: self.logger.debug("Prompt message from BDI core is invalid.") + elif msg.sender == settings.agent_settings.bdi_program_manager_name: + if msg.body == "clear_history": + self.logger.debug("Clearing conversation history.") + self.history.clear() else: self.logger.debug("Message ignored (not from BDI core.") diff --git a/test/unit/agents/bdi/test_bdi_program_manager.py b/test/unit/agents/bdi/test_bdi_program_manager.py index a54360c..573524e 100644 --- a/test/unit/agents/bdi/test_bdi_program_manager.py +++ b/test/unit/agents/bdi/test_bdi_program_manager.py @@ -63,6 +63,7 @@ async def test_receive_programs_valid_and_invalid(): manager = BDIProgramManager(name="program_manager_test") manager.sub_socket = sub manager._send_to_bdi = AsyncMock() + manager._send_clear_llm_history = AsyncMock() try: # Will give StopAsyncIteration when the predefined `sub.recv_multipart` side-effects run out @@ -75,3 +76,24 @@ async def test_receive_programs_valid_and_invalid(): forwarded: Program = manager._send_to_bdi.await_args[0][0] assert forwarded.phases[0].norms[0].norm == "N1" assert forwarded.phases[0].goals[0].description == "G1" + + # Verify history clear was triggered + assert manager._send_clear_llm_history.await_count == 1 + + +@pytest.mark.asyncio +async def test_send_clear_llm_history(mock_settings): + # Ensure the mock returns a string for the agent name (just like in your LLM tests) + mock_settings.agent_settings.llm_agent_name = "llm_agent" + + manager = BDIProgramManager(name="program_manager_test") + manager.send = AsyncMock() + + await manager._send_clear_llm_history() + + assert manager.send.await_count == 1 + msg: InternalMessage = manager.send.await_args[0][0] + + # Verify the content and recipient + assert msg.body == "clear_history" + assert msg.to == "llm_agent" diff --git a/test/unit/agents/llm/test_llm_agent.py b/test/unit/agents/llm/test_llm_agent.py index e2b6460..3341c7d 100644 --- a/test/unit/agents/llm/test_llm_agent.py +++ b/test/unit/agents/llm/test_llm_agent.py @@ -259,3 +259,23 @@ async def test_stream_query_llm_skips_non_data_lines(mock_httpx_client, mock_set # Only the valid 'data:' line should yield content assert tokens == ["Hi"] + + +@pytest.mark.asyncio +async def test_clear_history_command(mock_settings): + """Test that the 'clear_history' message clears the agent's memory.""" + # setup LLM to have some history + mock_settings.agent_settings.bdi_program_manager_name = "bdi_program_manager_agent" + agent = LLMAgent("llm_agent") + agent.history = [ + {"role": "user", "content": "Old conversation context"}, + {"role": "assistant", "content": "Old response"}, + ] + assert len(agent.history) == 2 + msg = InternalMessage( + to="llm_agent", + sender=mock_settings.agent_settings.bdi_program_manager_name, + body="clear_history", + ) + await agent.handle_message(msg) + assert len(agent.history) == 0 diff --git a/test/unit/test_main.py b/test/unit/test_main.py index 2737c76..a423703 100644 --- a/test/unit/test_main.py +++ b/test/unit/test_main.py @@ -53,8 +53,6 @@ async def test_lifespan_agent_start_exception(): Ensures exceptions are logged properly and re-raised. """ with ( - patch("control_backend.main.VADAgent.start", new_callable=AsyncMock), - patch("control_backend.main.VADAgent.reset_stream", new_callable=AsyncMock), patch( "control_backend.main.RICommunicationAgent.start", new_callable=AsyncMock ) as ri_start,