"""Mocks `httpx` and tests chunking logic."""
|
|
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
import httpx
|
|
import pytest
|
|
|
|
from control_backend.agents.llm.llm_agent import LLMAgent, LLMInstructions
|
|
from control_backend.core.agent_system import InternalMessage
|
|
from control_backend.schemas.llm_prompt_message import LLMPromptMessage
|
|
|
|
|
|
@pytest.fixture
|
|
def mock_httpx_client():
|
|
with patch("httpx.AsyncClient") as mock_cls:
|
|
mock_client = AsyncMock()
|
|
mock_cls.return_value.__aenter__.return_value = mock_client
|
|
yield mock_client
|
|
|
|
|
|
@pytest.fixture(autouse=True)
|
|
def mock_experiment_logger():
|
|
with patch("control_backend.agents.llm.llm_agent.experiment_logger") as logger:
|
|
yield logger
|
|
|
|
|
|
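

# The streaming tests below hand-build the same httpx mock each time: a
# response whose `aiter_lines` yields SSE lines, wrapped in an async context
# manager returned by `client.stream(...)`. A helper along these lines could
# consolidate that setup (a sketch only; the tests below keep their inline
# setup so each stays self-contained):
def make_sse_stream_context(stream_lines):
    """Build a mock async context manager mimicking `AsyncClient.stream`."""
    mock_response = MagicMock()
    mock_response.raise_for_status = MagicMock()

    async def aiter_lines_gen():
        for line in stream_lines:
            yield line

    mock_response.aiter_lines.side_effect = aiter_lines_gen

    mock_stream_context = MagicMock()
    mock_stream_context.__aenter__ = AsyncMock(return_value=mock_response)
    mock_stream_context.__aexit__ = AsyncMock(return_value=None)
    return mock_stream_context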


@pytest.mark.asyncio
async def test_llm_processing_success(mock_httpx_client, mock_settings):
    # Set up the mock response for the stream
    mock_response = MagicMock()
    mock_response.raise_for_status = MagicMock()

    # Simulate stream lines
    lines = [
        b'data: {"choices": [{"delta": {"content": "Hello"}}]}',
        b'data: {"choices": [{"delta": {"content": " world"}}]}',
        b'data: {"choices": [{"delta": {"content": "."}}]}',
        b"data: [DONE]",
    ]

    async def aiter_lines_gen():
        for line in lines:
            yield line.decode()

    mock_response.aiter_lines.side_effect = aiter_lines_gen

    mock_stream_context = MagicMock()
    mock_stream_context.__aenter__ = AsyncMock(return_value=mock_response)
    mock_stream_context.__aexit__ = AsyncMock(return_value=None)

    # Configure the client
    mock_httpx_client.stream = MagicMock(return_value=mock_stream_context)

    # Set up the agent
    agent = LLMAgent("llm_agent")
    agent.send = AsyncMock()  # Mock the send method to verify replies

    mock_logger = MagicMock()
    agent.logger = mock_logger

    # Simulate receiving a message from BDI
    prompt = LLMPromptMessage(text="Hi", norms=[], goals=[])
    msg = InternalMessage(
        to="llm_agent",
        sender=mock_settings.agent_settings.bdi_core_name,
        body=prompt.model_dump_json(),
        thread="prompt_message",  # REQUIRED: thread must match handle_message logic
    )

    agent._process_bdi_message = AsyncMock()

    await agent.handle_message(msg)

    agent._process_bdi_message.assert_called()


@pytest.mark.asyncio
async def test_process_bdi_message_success(mock_httpx_client, mock_settings):
    # Set up the mock response for the stream
    mock_response = MagicMock()
    mock_response.raise_for_status = MagicMock()

    # Simulate stream lines
    lines = [
        b'data: {"choices": [{"delta": {"content": "Hello"}}]}',
        b'data: {"choices": [{"delta": {"content": " world"}}]}',
        b'data: {"choices": [{"delta": {"content": "."}}]}',
        b"data: [DONE]",
    ]

    async def aiter_lines_gen():
        for line in lines:
            yield line.decode()

    mock_response.aiter_lines.side_effect = aiter_lines_gen

    mock_stream_context = MagicMock()
    mock_stream_context.__aenter__ = AsyncMock(return_value=mock_response)
    mock_stream_context.__aexit__ = AsyncMock(return_value=None)

    # Configure the client
    mock_httpx_client.stream = MagicMock(return_value=mock_stream_context)

    # Set up the agent
    agent = LLMAgent("llm_agent")
    agent.send = AsyncMock()  # Mock the send method to verify replies

    mock_logger = MagicMock()
    agent.logger = mock_logger

    # Simulate receiving a message from BDI
    prompt = LLMPromptMessage(text="Hi", norms=[], goals=[])

    await agent._process_bdi_message(prompt)

    # Verification:
    # "Hello world." constitutes one sentence/chunk based on the punctuation
    # split, so the agent should send that chunk once, plus one more message
    # carrying the full reply.
    assert agent.send.called

    # Check args: we expect at least one call sending "Hello world."
    calls = agent.send.call_args_list
    bodies = [c[0][0].body for c in calls]
    assert any("Hello world." in b for b in bodies)


@pytest.mark.asyncio
async def test_llm_processing_errors(mock_httpx_client, mock_settings):
    agent = LLMAgent("llm_agent")
    agent.send = AsyncMock()
    prompt = LLMPromptMessage(text="Hi", norms=[], goals=[])

    # HTTP error: the stream method raises immediately
    mock_httpx_client.stream = MagicMock(side_effect=httpx.HTTPError("Fail"))

    await agent._process_bdi_message(prompt)

    # Check that an error message was sent
    assert agent.send.called
    assert "LLM service unavailable." in agent.send.call_args_list[0][0][0].body


@pytest.mark.asyncio
async def test_llm_json_error(mock_httpx_client, mock_settings):
    # Test malformed JSON in the stream
    mock_response = MagicMock()
    mock_response.raise_for_status = MagicMock()

    async def aiter_lines_gen():
        yield "data: {bad_json"
        yield "data: [DONE]"

    mock_response.aiter_lines.side_effect = aiter_lines_gen

    mock_stream_context = MagicMock()
    mock_stream_context.__aenter__ = AsyncMock(return_value=mock_response)
    mock_stream_context.__aexit__ = AsyncMock(return_value=None)
    mock_httpx_client.stream = MagicMock(return_value=mock_stream_context)

    agent = LLMAgent("llm_agent")
    agent.send = AsyncMock()
    # Ensure the logger is mocked
    agent.logger = MagicMock()

    prompt = LLMPromptMessage(text="Hi", norms=[], goals=[])
    await agent._process_bdi_message(prompt)

    agent.logger.error.assert_called()  # Should log the JSONDecodeError


def test_llm_instructions():
    # Fully custom norms and goals
    instr = LLMInstructions(norms=["N1", "N2"], goals=["G1", "G2"])
    text = instr.build_developer_instruction()
    assert "Norms to follow:\n- N1\n- N2" in text
    assert "Goals to reach:\n- G1\n- G2" in text

    # Defaults
    instr_def = LLMInstructions()
    text_def = instr_def.build_developer_instruction()
    assert "Norms to follow" in text_def
    assert "Goals to reach" in text_def
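

# Based on the assertions above, the developer instruction is assumed to
# contain blocks roughly like this (illustration only; the exact surrounding
# text comes from LLMInstructions.build_developer_instruction):
#
#   Norms to follow:
#   - N1
#   - N2
#
#   Goals to reach:
#   - G1
#   - G2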


@pytest.mark.asyncio
async def test_handle_message_validation_error_branch_no_send(mock_httpx_client, mock_settings):
    """
    Covers the ValidationError branch:

        except ValidationError:
            self.logger.debug("Prompt message from BDI core is invalid.")

    Assert: no message is sent.
    """
    agent = LLMAgent("llm_agent")
    agent.send = AsyncMock()

    # Invalid JSON that triggers a ValidationError in LLMPromptMessage
    invalid_json = '{"text": "Hi", "wrong_field": 123}'  # field not in the schema

    msg = InternalMessage(
        to="llm_agent",
        sender=mock_settings.agent_settings.bdi_core_name,
        body=invalid_json,
        thread="prompt_message",
    )

    await agent.handle_message(msg)

    # Should not send any reply
    agent.send.assert_not_called()


@pytest.mark.asyncio
async def test_handle_message_ignored_sender_branch_no_send(mock_httpx_client, mock_settings):
    """
    Covers the else branch for messages not from the BDI core:

        else:
            self.logger.debug("Message ignored (not from BDI core.")

    Assert: no message is sent.
    """
    agent = LLMAgent("llm_agent")
    agent.send = AsyncMock()

    msg = InternalMessage(
        to="llm_agent",
        sender="some_other_agent",  # Not the BDI core
        body='{"text": "Hi"}',
    )

    await agent.handle_message(msg)

    # Should not send any reply
    agent.send.assert_not_called()


@pytest.mark.asyncio
async def test_query_llm_yields_final_tail_chunk(mock_settings):
    """
    Covers the branch `if current_chunk: yield current_chunk`:
    ensure that the last partial chunk is emitted.
    """
    agent = LLMAgent("llm_agent")
    agent.send = AsyncMock()

    agent.logger = MagicMock()
    agent.logger.llm = MagicMock()

    # Patch _stream_query_llm to yield tokens that do NOT end with punctuation
    async def fake_stream(messages):
        yield "Hello"
        yield " world"  # No punctuation, so the normal chunking never triggers

    agent._stream_query_llm = fake_stream

    prompt = LLMPromptMessage(text="Hi", norms=[], goals=[])

    # Collect the yielded chunks
    chunks = []
    async for chunk in agent._query_llm(prompt.text, prompt.norms, prompt.goals):
        chunks.append(chunk)

    # The final partial chunk should still be yielded
    assert chunks[-1] == "Hello world"
    assert any("Hello" in c for c in chunks)
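

# For reference, a minimal sketch of the sentence-chunking contract that the
# test above (and test_process_bdi_message_success) relies on. This is an
# illustration inferred from the branches quoted in the docstrings, not the
# real `LLMAgent._query_llm` implementation; the delimiter set is an
# assumption.
async def _reference_chunker(tokens):
    current_chunk = ""
    for token in tokens:
        current_chunk += token
        if current_chunk.endswith((".", "!", "?")):  # assumed sentence delimiters
            yield current_chunk
            current_chunk = ""
    if current_chunk:
        yield current_chunk  # flush the final partial chunk ("Hello world")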


@pytest.mark.asyncio
async def test_stream_query_llm_skips_non_data_lines(mock_httpx_client, mock_settings):
    """
    Covers `if not line or not line.startswith("data: "): continue`:
    feed lines that are empty or do not start with 'data: ' and check
    that they are skipped.
    """
    # Mock response
    mock_response = MagicMock()
    mock_response.raise_for_status = MagicMock()

    lines = [
        "",  # empty line
        "not data",  # invalid prefix
        'data: {"choices": [{"delta": {"content": "Hi"}}]}',
        "data: [DONE]",
    ]

    async def aiter_lines_gen():
        for line in lines:
            yield line

    mock_response.aiter_lines.side_effect = aiter_lines_gen

    # Proper async context manager for the stream
    mock_stream_context = MagicMock()
    mock_stream_context.__aenter__ = AsyncMock(return_value=mock_response)
    mock_stream_context.__aexit__ = AsyncMock(return_value=None)

    # Make stream return the async context manager
    mock_httpx_client.stream = MagicMock(return_value=mock_stream_context)

    agent = LLMAgent("llm_agent")
    agent.send = AsyncMock()

    # Patch settings for the local LLM URL
    with patch("control_backend.agents.llm.llm_agent.settings") as mock_sett:
        mock_sett.llm_settings.local_llm_url = "http://localhost"
        mock_sett.llm_settings.local_llm_model = "test-model"

        # Collect tokens
        tokens = []
        async for token in agent._stream_query_llm([]):
            tokens.append(token)

    # Only the valid 'data:' line should yield content
    assert tokens == ["Hi"]
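

# The OpenAI-style SSE framing assumed throughout these tests: each payload
# line carries a "data: " prefix and the stream ends with a "[DONE]" sentinel.
# A minimal sketch of how one such line would be parsed (hypothetical helper,
# not the agent's actual code):
def _parse_sse_line(line):
    import json

    if not line or not line.startswith("data: "):
        return None  # skipped, as test_stream_query_llm_skips_non_data_lines shows
    payload = line[len("data: "):]
    if payload == "[DONE]":
        return None
    return json.loads(payload)["choices"][0]["delta"].get("content")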


@pytest.mark.asyncio
async def test_clear_history_command(mock_settings):
    """Test that the 'clear_history' message clears the agent's memory."""
    # Set up the agent with some existing history
    mock_settings.agent_settings.bdi_program_manager_name = "bdi_program_manager_agent"
    agent = LLMAgent("llm_agent")
    agent.history = [
        {"role": "user", "content": "Old conversation context"},
        {"role": "assistant", "content": "Old response"},
    ]
    assert len(agent.history) == 2

    msg = InternalMessage(
        to="llm_agent",
        sender=mock_settings.agent_settings.bdi_program_manager_name,
        body="clear_history",
    )
    await agent.handle_message(msg)

    assert len(agent.history) == 0


@pytest.mark.asyncio
async def test_handle_assistant_and_user_messages(mock_settings):
    agent = LLMAgent("llm_agent")

    # Assistant message
    msg_ast = InternalMessage(
        to="llm_agent",
        sender=mock_settings.agent_settings.bdi_core_name,
        thread="assistant_message",
        body="I said this",
    )
    await agent.handle_message(msg_ast)
    assert agent.history[-1] == {"role": "assistant", "content": "I said this"}

    # User message
    msg_usr = InternalMessage(
        to="llm_agent",
        sender=mock_settings.agent_settings.bdi_core_name,
        thread="user_message",
        body="User said this",
    )
    await agent.handle_message(msg_usr)
    assert agent.history[-1] == {"role": "user", "content": "User said this"}