"""Mocks `httpx` and tests chunking logic."""
|
|
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
import httpx
|
|
import pytest
|
|
|
|
from control_backend.agents.llm.llm_agent import LLMAgent, LLMInstructions
|
|
from control_backend.core.agent_system import InternalMessage
|
|
from control_backend.schemas.llm_prompt_message import LLMPromptMessage
|
|
|
|
|
|
@pytest.fixture
def mock_httpx_client():
    """Yield an AsyncMock standing in for the client yielded by `async with httpx.AsyncClient()`."""
    patcher = patch("httpx.AsyncClient")
    mock_cls = patcher.start()
    try:
        fake_client = AsyncMock()
        # `async with AsyncClient() as client:` resolves to this mock.
        mock_cls.return_value.__aenter__.return_value = fake_client
        yield fake_client
    finally:
        patcher.stop()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_llm_processing_success(mock_httpx_client, mock_settings):
    """Happy path: streamed deltas are accumulated and sent back as one sentence chunk."""
    # Fake streaming response whose status check is a no-op.
    fake_response = MagicMock()
    fake_response.raise_for_status = MagicMock()

    # SSE-style payload: three content deltas followed by the DONE sentinel.
    stream_payload = [
        b'data: {"choices": [{"delta": {"content": "Hello"}}]}',
        b'data: {"choices": [{"delta": {"content": " world"}}]}',
        b'data: {"choices": [{"delta": {"content": "."}}]}',
        b"data: [DONE]",
    ]

    async def fake_line_iter():
        for raw in stream_payload:
            yield raw.decode()

    fake_response.aiter_lines.side_effect = fake_line_iter

    # `client.stream(...)` must return an async context manager wrapping the response.
    stream_cm = MagicMock()
    stream_cm.__aenter__ = AsyncMock(return_value=fake_response)
    stream_cm.__aexit__ = AsyncMock(return_value=None)
    mock_httpx_client.stream = MagicMock(return_value=stream_cm)

    # Agent under test; capture outgoing replies instead of routing them.
    agent = LLMAgent("llm_agent")
    agent.send = AsyncMock()

    # Simulate a prompt arriving from the BDI core.
    request = LLMPromptMessage(text="Hi", norms=[], goals=[])
    incoming = InternalMessage(
        to="llm_agent",
        sender=mock_settings.agent_settings.bdi_core_name,
        body=request.model_dump_json(),
    )

    await agent.handle_message(incoming)

    # "Hello world." is one sentence under the punctuation-based split,
    # so the agent should send it back as a single chunk.
    assert agent.send.called
    reply = agent.send.call_args[0][0]
    assert reply.to == mock_settings.agent_settings.bdi_core_name
    assert "Hello world." in reply.body
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_llm_processing_errors(mock_httpx_client, mock_settings):
    """HTTP and unexpected failures are reported back as user-facing error replies."""
    agent = LLMAgent("llm_agent")
    agent.send = AsyncMock()
    prompt = LLMPromptMessage(text="Hi", norms=[], goals=[])
    msg = InternalMessage(
        # Address the agent by its actual name ("llm_agent"), matching the
        # success test, instead of the stale "llm" alias.
        to="llm_agent",
        sender=mock_settings.agent_settings.bdi_core_name,
        body=prompt.model_dump_json(),
    )

    # Transport-level failure -> "service unavailable" reply.
    mock_httpx_client.stream = MagicMock(side_effect=httpx.HTTPError("Fail"))
    await agent.handle_message(msg)
    assert "LLM service unavailable." in agent.send.call_args[0][0].body

    # Any other exception -> generic processing-error reply.
    agent.send.reset_mock()
    mock_httpx_client.stream = MagicMock(side_effect=Exception("Boom"))
    await agent.handle_message(msg)
    assert "Error processing the request." in agent.send.call_args[0][0].body
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_llm_json_error(mock_httpx_client, mock_settings):
    """A malformed JSON line in the stream is logged as an error, not raised."""
    fake_response = MagicMock()
    fake_response.raise_for_status = MagicMock()

    async def fake_line_iter():
        yield "data: {bad_json"  # invalid JSON payload
        yield "data: [DONE]"

    fake_response.aiter_lines.side_effect = fake_line_iter

    # `client.stream(...)` returns an async context manager wrapping the response.
    stream_cm = MagicMock()
    stream_cm.__aenter__ = AsyncMock(return_value=fake_response)
    stream_cm.__aexit__ = AsyncMock(return_value=None)
    mock_httpx_client.stream = MagicMock(return_value=stream_cm)

    agent = LLMAgent("llm_agent")
    agent.send = AsyncMock()

    with patch.object(agent.logger, "error") as log:
        prompt = LLMPromptMessage(text="Hi", norms=[], goals=[])
        msg = InternalMessage(
            # Address the agent by its actual name ("llm_agent") for
            # consistency with how the agent is constructed above.
            to="llm_agent",
            sender=mock_settings.agent_settings.bdi_core_name,
            body=prompt.model_dump_json(),
        )
        await agent.handle_message(msg)
        log.assert_called()  # Should log the JSONDecodeError
|
|
|
|
|
|
def test_llm_instructions():
    """Developer-instruction text embeds provided norms/goals and has sane defaults."""
    # Explicit norms and goals render as bulleted sections.
    custom = LLMInstructions(norms=["N1", "N2"], goals=["G1", "G2"])
    rendered = custom.build_developer_instruction()
    assert "Norms to follow:\n- N1\n- N2" in rendered
    assert "Goals to reach:\n- G1\n- G2" in rendered

    # Default construction still produces both section headers.
    defaults = LLMInstructions()
    rendered_defaults = defaults.build_developer_instruction()
    assert "Norms to follow" in rendered_defaults
    assert "Goals to reach" in rendered_defaults
|