feat: made program reset LLM
also added test_bdi_program_manager back cause it was somehow missing in my files ref: N25B-355
This commit is contained in:
@@ -60,24 +60,41 @@ class BDIProgramManager(BaseAgent):
|
|||||||
await self.send(message)
|
await self.send(message)
|
||||||
self.logger.debug("Sent new norms and goals to the BDI agent.")
|
self.logger.debug("Sent new norms and goals to the BDI agent.")
|
||||||
|
|
||||||
|
async def _send_clear_llm_history(self):
|
||||||
|
"""
|
||||||
|
Clear the LLM Agent's conversation history.
|
||||||
|
|
||||||
|
Sends an empty history to the LLM Agent to reset its state.
|
||||||
|
"""
|
||||||
|
message = InternalMessage(
|
||||||
|
to=settings.agent_settings.llm_name,
|
||||||
|
sender=self.name,
|
||||||
|
body="clear_history",
|
||||||
|
threads="clear history message",
|
||||||
|
)
|
||||||
|
await self.send(message)
|
||||||
|
self.logger.debug("Sent message to LLM agent to clear history.")
|
||||||
|
|
||||||
async def _receive_programs(self):
|
async def _receive_programs(self):
|
||||||
"""
|
"""
|
||||||
Continuous loop that receives program updates from the HTTP endpoint.
|
Continuous loop that receives program updates from the HTTP endpoint.
|
||||||
|
|
||||||
It listens to the ``program`` topic on the internal ZMQ SUB socket.
|
It listens to the ``program`` topic on the internal ZMQ SUB socket.
|
||||||
When a program is received, it is validated and forwarded to BDI via :meth:`_send_to_bdi`.
|
When a program is received, it is validated and forwarded to BDI via :meth:`_send_to_bdi`.
|
||||||
|
Additionally, the LLM history is cleared via :meth:`_send_clear_llm_history`.
|
||||||
"""
|
"""
|
||||||
while True:
|
while True:
|
||||||
topic, body = await self.sub_socket.recv_multipart()
|
topic, body = await self.sub_socket.recv_multipart()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
program = Program.model_validate_json(body)
|
program = Program.model_validate_json(body)
|
||||||
|
await self._send_to_bdi(program)
|
||||||
|
await self._send_clear_llm_history()
|
||||||
|
|
||||||
except ValidationError:
|
except ValidationError:
|
||||||
self.logger.exception("Received an invalid program.")
|
self.logger.exception("Received an invalid program.")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
await self._send_to_bdi(program)
|
|
||||||
|
|
||||||
async def setup(self):
|
async def setup(self):
|
||||||
"""
|
"""
|
||||||
Initialize the agent.
|
Initialize the agent.
|
||||||
@@ -92,3 +109,4 @@ class BDIProgramManager(BaseAgent):
|
|||||||
self.sub_socket.subscribe("program")
|
self.sub_socket.subscribe("program")
|
||||||
|
|
||||||
self.add_behavior(self._receive_programs())
|
self.add_behavior(self._receive_programs())
|
||||||
|
# self.add_behavior(self._reset_llm_on_new_program())
|
||||||
|
|||||||
@@ -52,6 +52,10 @@ class LLMAgent(BaseAgent):
|
|||||||
await self._process_bdi_message(prompt_message)
|
await self._process_bdi_message(prompt_message)
|
||||||
except ValidationError:
|
except ValidationError:
|
||||||
self.logger.debug("Prompt message from BDI core is invalid.")
|
self.logger.debug("Prompt message from BDI core is invalid.")
|
||||||
|
elif msg.sender == settings.agent_settings.bdi_program_manager_name:
|
||||||
|
if msg.body == "clear_history":
|
||||||
|
self.logger.debug("Clearing conversation history.")
|
||||||
|
self.history.clear()
|
||||||
else:
|
else:
|
||||||
self.logger.debug("Message ignored (not from BDI core.")
|
self.logger.debug("Message ignored (not from BDI core.")
|
||||||
|
|
||||||
|
|||||||
99
test/unit/agents/bdi/test_bdi_program_manager.py
Normal file
99
test/unit/agents/bdi/test_bdi_program_manager.py
Normal file
@@ -0,0 +1,99 @@
|
|||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import sys
|
||||||
|
from unittest.mock import AsyncMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from control_backend.agents.bdi.bdi_program_manager import BDIProgramManager
|
||||||
|
from control_backend.core.agent_system import InternalMessage
|
||||||
|
from control_backend.schemas.belief_message import BeliefMessage
|
||||||
|
from control_backend.schemas.program import Program
|
||||||
|
|
||||||
|
# Fix Windows Proactor loop for zmq
|
||||||
|
if sys.platform.startswith("win"):
|
||||||
|
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
|
||||||
|
|
||||||
|
|
||||||
|
def make_valid_program_json(norm="N1", goal="G1"):
|
||||||
|
return json.dumps(
|
||||||
|
{
|
||||||
|
"phases": [
|
||||||
|
{
|
||||||
|
"id": "phase1",
|
||||||
|
"label": "Phase 1",
|
||||||
|
"triggers": [],
|
||||||
|
"norms": [{"id": "n1", "label": "Norm 1", "norm": norm}],
|
||||||
|
"goals": [
|
||||||
|
{"id": "g1", "label": "Goal 1", "description": goal, "achieved": False}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_send_to_bdi():
|
||||||
|
manager = BDIProgramManager(name="program_manager_test")
|
||||||
|
manager.send = AsyncMock()
|
||||||
|
|
||||||
|
program = Program.model_validate_json(make_valid_program_json())
|
||||||
|
await manager._send_to_bdi(program)
|
||||||
|
|
||||||
|
assert manager.send.await_count == 1
|
||||||
|
msg: InternalMessage = manager.send.await_args[0][0]
|
||||||
|
assert msg.thread == "beliefs"
|
||||||
|
|
||||||
|
beliefs = BeliefMessage.model_validate_json(msg.body)
|
||||||
|
names = {b.name: b.arguments for b in beliefs.beliefs}
|
||||||
|
|
||||||
|
assert "norms" in names and names["norms"] == ["N1"]
|
||||||
|
assert "goals" in names and names["goals"] == ["G1"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_receive_programs_valid_and_invalid():
|
||||||
|
sub = AsyncMock()
|
||||||
|
sub.recv_multipart.side_effect = [
|
||||||
|
(b"program", b"{bad json"),
|
||||||
|
(b"program", make_valid_program_json().encode()),
|
||||||
|
]
|
||||||
|
|
||||||
|
manager = BDIProgramManager(name="program_manager_test")
|
||||||
|
manager.sub_socket = sub
|
||||||
|
manager._send_to_bdi = AsyncMock()
|
||||||
|
manager._send_clear_llm_history = AsyncMock()
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Will give StopAsyncIteration when the predefined `sub.recv_multipart` side-effects run out
|
||||||
|
await manager._receive_programs()
|
||||||
|
except StopAsyncIteration:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Only valid Program should have triggered _send_to_bdi
|
||||||
|
assert manager._send_to_bdi.await_count == 1
|
||||||
|
forwarded: Program = manager._send_to_bdi.await_args[0][0]
|
||||||
|
assert forwarded.phases[0].norms[0].norm == "N1"
|
||||||
|
assert forwarded.phases[0].goals[0].description == "G1"
|
||||||
|
|
||||||
|
# Verify history clear was triggered
|
||||||
|
assert manager._send_clear_llm_history.await_count == 1
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_send_clear_llm_history(mock_settings):
|
||||||
|
# Ensure the mock returns a string for the agent name (just like in your LLM tests)
|
||||||
|
mock_settings.agent_settings.llm_agent_name = "llm_agent"
|
||||||
|
|
||||||
|
manager = BDIProgramManager(name="program_manager_test")
|
||||||
|
manager.send = AsyncMock()
|
||||||
|
|
||||||
|
await manager._send_clear_llm_history()
|
||||||
|
|
||||||
|
assert manager.send.await_count == 1
|
||||||
|
msg: InternalMessage = manager.send.await_args[0][0]
|
||||||
|
|
||||||
|
# Verify the content and recipient
|
||||||
|
assert msg.body == "clear_history"
|
||||||
|
assert msg.to == "llm_agent"
|
||||||
@@ -134,3 +134,23 @@ def test_llm_instructions():
|
|||||||
text_def = instr_def.build_developer_instruction()
|
text_def = instr_def.build_developer_instruction()
|
||||||
assert "Norms to follow" in text_def
|
assert "Norms to follow" in text_def
|
||||||
assert "Goals to reach" in text_def
|
assert "Goals to reach" in text_def
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_clear_history_command(mock_settings):
|
||||||
|
"""Test that the 'clear_history' message clears the agent's memory."""
|
||||||
|
# setup LLM to have some history
|
||||||
|
mock_settings.agent_settings.bdi_program_manager_name = "bdi_program_manager_agent"
|
||||||
|
agent = LLMAgent("llm_agent")
|
||||||
|
agent.history = [
|
||||||
|
{"role": "user", "content": "Old conversation context"},
|
||||||
|
{"role": "assistant", "content": "Old response"},
|
||||||
|
]
|
||||||
|
assert len(agent.history) == 2
|
||||||
|
msg = InternalMessage(
|
||||||
|
to="llm_agent",
|
||||||
|
sender=mock_settings.agent_settings.bdi_program_manager_name,
|
||||||
|
body="clear_history",
|
||||||
|
)
|
||||||
|
await agent.handle_message(msg)
|
||||||
|
assert len(agent.history) == 0
|
||||||
|
|||||||
Reference in New Issue
Block a user