Files
pepperplus-cb/test/unit/agents/llm/test_llm_agent.py
Twirre Meulenbelt 54502e441c test: fix tests after changing schema and
ref: N25B-299
2025-11-24 20:53:53 +01:00

137 lines
4.5 KiB
Python

"""Mocks `httpx` and tests chunking logic."""
from unittest.mock import AsyncMock, MagicMock, patch
import httpx
import pytest
from control_backend.agents.llm.llm_agent import LLMAgent, LLMInstructions
from control_backend.core.agent_system import InternalMessage
from control_backend.schemas.llm_prompt_message import LLMPromptMessage
@pytest.fixture
def mock_httpx_client():
    """Patch ``httpx.AsyncClient`` and yield the mocked client instance.

    While the fixture is active, ``httpx.AsyncClient(...)`` returns a mock
    whose ``__aenter__`` resolves to the yielded ``AsyncMock``, so tests can
    configure the object that ``async with httpx.AsyncClient(...) as c``
    binds to ``c``.
    """
    with patch("httpx.AsyncClient") as client_cls:
        fake_client = AsyncMock()
        # Entering the patched class's instance hands back our fake client.
        client_cls.return_value.__aenter__.return_value = fake_client
        yield fake_client
@pytest.mark.asyncio
async def test_llm_processing_success(mock_httpx_client, mock_settings):
    """Stream three content deltas through the agent and inspect its reply.

    The fake response yields "Hello", " world" and "." followed by a
    ``[DONE]`` line. The test then checks that ``agent.send`` was called and
    that the most recent message is addressed to the BDI core and contains
    "Hello world.".
    """
    # ``data:``-prefixed lines that the fake response will emit, in order.
    stream_payload = [
        b'data: {"choices": [{"delta": {"content": "Hello"}}]}',
        b'data: {"choices": [{"delta": {"content": " world"}}]}',
        b'data: {"choices": [{"delta": {"content": "."}}]}',
        b"data: [DONE]",
    ]

    async def emit_lines():
        for raw_line in stream_payload:
            yield raw_line.decode()

    # Fake streamed response: status check is a no-op, lines come from above.
    fake_response = MagicMock()
    fake_response.raise_for_status = MagicMock()
    fake_response.aiter_lines.side_effect = emit_lines

    # Object returned by ``client.stream(...)``; entering it with
    # ``async with`` produces the fake response.
    stream_ctx = MagicMock()
    stream_ctx.__aenter__ = AsyncMock(return_value=fake_response)
    stream_ctx.__aexit__ = AsyncMock(return_value=None)
    mock_httpx_client.stream = MagicMock(return_value=stream_ctx)

    # Agent under test, with ``send`` replaced so replies can be inspected.
    agent = LLMAgent("llm_agent")
    agent.send = AsyncMock()

    # Build the incoming message as if it came from the BDI core.
    bdi_name = mock_settings.agent_settings.bdi_core_name
    prompt = LLMPromptMessage(text="Hi", norms=[], goals=[])
    incoming = InternalMessage(
        to="llm_agent",
        sender=bdi_name,
        body=prompt.model_dump_json(),
    )

    await agent.handle_message(incoming)

    # The three deltas are expected to form a single sentence chunk
    # ("Hello world."); only the most recent ``send`` call is examined.
    assert agent.send.called
    reply = agent.send.call_args.args[0]
    assert reply.to == mock_settings.agent_settings.bdi_core_name
    assert "Hello world." in reply.body
@pytest.mark.asyncio
async def test_llm_processing_errors(mock_httpx_client, mock_settings):
    """Check the fallback replies sent when the stream call raises.

    Two failures are simulated in turn on ``client.stream``: an
    ``httpx.HTTPError`` and a plain ``Exception``. After each one, the body
    of the latest message passed to ``agent.send`` must contain the
    corresponding fallback text.
    """
    agent = LLMAgent("llm_agent")
    agent.send = AsyncMock()

    request = LLMPromptMessage(text="Hi", norms=[], goals=[])
    incoming = InternalMessage(
        to="llm",
        sender=mock_settings.agent_settings.bdi_core_name,
        body=request.model_dump_json(),
    )

    # Case 1: the HTTP layer fails.
    mock_httpx_client.stream = MagicMock(side_effect=httpx.HTTPError("Fail"))
    await agent.handle_message(incoming)
    http_failure_reply = agent.send.call_args.args[0]
    assert "LLM service unavailable." in http_failure_reply.body

    # Case 2: any other exception. Clear the recorded calls first so the
    # assertion below can only see the reply from this second attempt.
    agent.send.reset_mock()
    mock_httpx_client.stream = MagicMock(side_effect=Exception("Boom"))
    await agent.handle_message(incoming)
    generic_failure_reply = agent.send.call_args.args[0]
    assert "Error processing the request." in generic_failure_reply.body
@pytest.mark.asyncio
async def test_llm_json_error(mock_httpx_client, mock_settings):
    """Stream a malformed JSON line and check that the agent logs an error.

    The fake response yields one ``data:`` line that is not valid JSON,
    followed by ``[DONE]``. The test passes if ``agent.logger.error`` is
    called at least once while the message is handled.
    """

    async def emit_lines():
        yield "data: {bad_json"
        yield "data: [DONE]"

    fake_response = MagicMock()
    fake_response.raise_for_status = MagicMock()
    fake_response.aiter_lines.side_effect = emit_lines

    # Object returned by ``client.stream(...)``; entering it with
    # ``async with`` produces the fake response.
    stream_ctx = MagicMock()
    stream_ctx.__aenter__ = AsyncMock(return_value=fake_response)
    stream_ctx.__aexit__ = AsyncMock(return_value=None)
    mock_httpx_client.stream = MagicMock(return_value=stream_ctx)

    agent = LLMAgent("llm_agent")
    agent.send = AsyncMock()

    request = LLMPromptMessage(text="Hi", norms=[], goals=[])
    incoming = InternalMessage(
        to="llm",
        sender=mock_settings.agent_settings.bdi_core_name,
        body=request.model_dump_json(),
    )

    with patch.object(agent.logger, "error") as error_log:
        await agent.handle_message(incoming)
        # Presumably the JSON decode failure is what gets reported here; the
        # test only requires that some error was logged.
        error_log.assert_called()
def test_llm_instructions():
    """Check the developer instruction text for custom and default inputs.

    With explicit norms and goals, each list must appear as a bulleted block
    under its heading. With no arguments, both headings must still be
    present in the generated text.
    """
    # Explicit norms and goals are rendered as "- item" lines.
    custom = LLMInstructions(norms=["N1", "N2"], goals=["G1", "G2"])
    custom_text = custom.build_developer_instruction()
    assert "Norms to follow:\n- N1\n- N2" in custom_text
    assert "Goals to reach:\n- G1\n- G2" in custom_text

    # Default construction still produces both section headings.
    default_text = LLMInstructions().build_developer_instruction()
    assert "Norms to follow" in default_text
    assert "Goals to reach" in default_text