"""Tests for ``LLMAgent``: message dispatch, streamed-response parsing and
chunking, error handling, LLM instructions, and conversation history.

``httpx`` is mocked throughout, so no real LLM endpoint is contacted.
"""

from unittest.mock import AsyncMock, MagicMock, patch

import httpx
import pytest

from control_backend.agents.llm.llm_agent import LLMAgent, LLMInstructions
from control_backend.core.agent_system import InternalMessage
from control_backend.schemas.llm_prompt_message import LLMPromptMessage


def _mock_stream(client, lines):
    """Make ``client.stream(...)`` return an async context manager over *lines*.

    Entering the context yields a mock response whose ``raise_for_status`` is a
    no-op and whose ``aiter_lines()`` yields each string in *lines* verbatim.
    A fresh async generator is created on every ``aiter_lines()`` call.

    Args:
        client: The mocked HTTP client (e.g. the ``mock_httpx_client`` fixture).
        lines: The raw text lines the fake server "streams".

    Returns:
        The mock response object, for callers that want to inspect it.
    """
    mock_response = MagicMock()
    mock_response.raise_for_status = MagicMock()

    async def aiter_lines_gen():
        for line in lines:
            yield line

    mock_response.aiter_lines.side_effect = aiter_lines_gen

    mock_stream_context = MagicMock()
    mock_stream_context.__aenter__ = AsyncMock(return_value=mock_response)
    mock_stream_context.__aexit__ = AsyncMock(return_value=None)

    client.stream = MagicMock(return_value=mock_stream_context)
    return mock_response


@pytest.fixture
def mock_httpx_client():
    """Patch ``httpx.AsyncClient``; ``async with AsyncClient(...)`` yields an AsyncMock.

    NOTE(review): assumes the agent creates its client via
    ``async with httpx.AsyncClient(...)`` — confirm against llm_agent.
    """
    with patch("httpx.AsyncClient") as mock_cls:
        mock_client = AsyncMock()
        mock_cls.return_value.__aenter__.return_value = mock_client
        yield mock_client


@pytest.fixture(autouse=True)
def mock_experiment_logger():
    """Replace the agent module's ``experiment_logger`` with a mock for every test."""
    with patch("control_backend.agents.llm.llm_agent.experiment_logger") as logger:
        yield logger


@pytest.mark.asyncio
async def test_llm_processing_success(mock_httpx_client, mock_settings):
    """A prompt from the BDI core on the "prompt_message" thread is dispatched
    to ``_process_bdi_message``.

    ``_process_bdi_message`` is replaced with a mock, so no LLM stream is consumed
    here; the streaming path is exercised by ``test_process_bdi_message_success``.
    """
    agent = LLMAgent("llm_agent")
    agent.send = AsyncMock()  # Mock the send method to verify replies
    agent.logger = MagicMock()
    agent._process_bdi_message = AsyncMock()

    # Simulate receiving a message from BDI
    prompt = LLMPromptMessage(text="Hi", norms=[], goals=[])
    msg = InternalMessage(
        to="llm_agent",
        sender=mock_settings.agent_settings.bdi_core_name,
        body=prompt.model_dump_json(),
        thread="prompt_message",  # REQUIRED: thread must match handle_message logic
    )

    await agent.handle_message(msg)

    agent._process_bdi_message.assert_called()


@pytest.mark.asyncio
async def test_process_bdi_message_success(mock_httpx_client, mock_settings):
    """Streamed tokens are assembled into a chunk and sent by the agent."""
    _mock_stream(
        mock_httpx_client,
        [
            'data: {"choices": [{"delta": {"content": "Hello"}}]}',
            'data: {"choices": [{"delta": {"content": " world"}}]}',
            'data: {"choices": [{"delta": {"content": "."}}]}',
            "data: [DONE]",
        ],
    )

    agent = LLMAgent("llm_agent")
    agent.send = AsyncMock()  # Mock the send method to verify replies
    agent.logger = MagicMock()

    prompt = LLMPromptMessage(text="Hi", norms=[], goals=[])
    await agent._process_bdi_message(prompt)

    # "Hello world." ends with punctuation, so it is expected to be emitted as a
    # chunk. Only require that at least one sent body contains it; the exact
    # number of send() calls is not pinned here.
    assert agent.send.called
    bodies = [call.args[0].body for call in agent.send.call_args_list]
    assert any("Hello world." in body for body in bodies)


@pytest.mark.asyncio
async def test_llm_processing_errors(mock_httpx_client, mock_settings):
    """An ``httpx.HTTPError`` from ``stream`` results in an error reply being sent."""
    agent = LLMAgent("llm_agent")
    agent.send = AsyncMock()
    prompt = LLMPromptMessage(text="Hi", norms=[], goals=[])

    # HTTP Error: stream method RAISES exception immediately
    mock_httpx_client.stream = MagicMock(side_effect=httpx.HTTPError("Fail"))

    await agent._process_bdi_message(prompt)

    # Check that error message was sent
    assert agent.send.called
    assert "LLM service unavailable." in agent.send.call_args_list[0][0][0].body


@pytest.mark.asyncio
async def test_llm_json_error(mock_httpx_client, mock_settings):
    """A malformed JSON ``data:`` line in the stream is logged as an error."""
    _mock_stream(mock_httpx_client, ["data: {bad_json", "data: [DONE]"])

    agent = LLMAgent("llm_agent")
    agent.send = AsyncMock()
    agent.logger = MagicMock()  # Ensure logger is mocked

    prompt = LLMPromptMessage(text="Hi", norms=[], goals=[])
    await agent._process_bdi_message(prompt)

    agent.logger.error.assert_called()  # Should log JSONDecodeError


def test_llm_instructions():
    """Developer instructions list the given norms/goals, and contain the
    section headers even when built with defaults."""
    # Full custom
    instr = LLMInstructions(norms=["N1", "N2"], goals=["G1", "G2"])
    text = instr.build_developer_instruction()
    assert "Norms to follow:\n- N1\n- N2" in text
    assert "Goals to reach:\n- G1\n- G2" in text

    # Defaults
    instr_def = LLMInstructions()
    text_def = instr_def.build_developer_instruction()
    assert "Norms to follow" in text_def
    assert "Goals to reach" in text_def


@pytest.mark.asyncio
async def test_handle_message_validation_error_branch_no_send(mock_httpx_client, mock_settings):
    """
    Covers the ValidationError branch:
        except ValidationError:
            self.logger.debug("Prompt message from BDI core is invalid.")
    Assert: no message is sent.
    """
    agent = LLMAgent("llm_agent")
    agent.send = AsyncMock()

    # Payload that should fail LLMPromptMessage validation: `wrong_field` is not
    # in the schema and `norms`/`goals` are omitted. NOTE(review): which of these
    # actually triggers the ValidationError depends on the schema's config.
    invalid_json = '{"text": "Hi", "wrong_field": 123}'

    msg = InternalMessage(
        to="llm_agent",
        sender=mock_settings.agent_settings.bdi_core_name,
        body=invalid_json,
        thread="prompt_message",
    )

    await agent.handle_message(msg)

    # Should not send any reply
    agent.send.assert_not_called()


@pytest.mark.asyncio
async def test_handle_message_ignored_sender_branch_no_send(mock_httpx_client, mock_settings):
    """
    Covers the else branch for messages not from BDI core:
        else:
            self.logger.debug("Message ignored (not from BDI core.")
    Assert: no message is sent.
    """
    agent = LLMAgent("llm_agent")
    agent.send = AsyncMock()

    msg = InternalMessage(
        to="llm_agent",
        sender="some_other_agent",  # Not BDI core
        body='{"text": "Hi"}',
    )

    await agent.handle_message(msg)

    # Should not send any reply
    agent.send.assert_not_called()


@pytest.mark.asyncio
async def test_query_llm_yields_final_tail_chunk(mock_settings):
    """
    Covers the branch:
        if current_chunk:
            yield current_chunk
    Ensure that the last partial chunk is emitted.
    """
    agent = LLMAgent("llm_agent")
    agent.send = AsyncMock()
    agent.logger = MagicMock()
    agent.logger.llm = MagicMock()

    # Patch _stream_query_llm to yield tokens that do NOT end with punctuation
    async def fake_stream(messages):
        yield "Hello"
        yield " world"  # No punctuation to trigger the normal chunking

    agent._stream_query_llm = fake_stream

    prompt = LLMPromptMessage(text="Hi", norms=[], goals=[])

    # Collect chunks yielded
    chunks = []
    async for chunk in agent._query_llm(prompt.text, prompt.norms, prompt.goals):
        chunks.append(chunk)

    # The final chunk should be yielded
    assert chunks[-1] == "Hello world"
    assert any("Hello" in c for c in chunks)


@pytest.mark.asyncio
async def test_stream_query_llm_skips_non_data_lines(mock_httpx_client, mock_settings):
    """
    Covers: if not line or not line.startswith("data: "): continue
    Feed lines that are empty or do not start with 'data:' and check they are skipped.
    """
    _mock_stream(
        mock_httpx_client,
        [
            "",  # empty line
            "not data",  # invalid prefix
            'data: {"choices": [{"delta": {"content": "Hi"}}]}',
            "data: [DONE]",
        ],
    )

    agent = LLMAgent("llm_agent")
    agent.send = AsyncMock()

    # Patch settings for local LLM URL
    with patch("control_backend.agents.llm.llm_agent.settings") as mock_sett:
        mock_sett.llm_settings.local_llm_url = "http://localhost"
        mock_sett.llm_settings.local_llm_model = "test-model"

        # Collect tokens
        tokens = []
        async for token in agent._stream_query_llm([]):
            tokens.append(token)

    # Only the valid 'data:' line should yield content
    assert tokens == ["Hi"]


@pytest.mark.asyncio
async def test_clear_history_command(mock_settings):
    """Test that the 'clear_history' message clears the agent's memory."""
    # setup LLM to have some history
    mock_settings.agent_settings.bdi_program_manager_name = "bdi_program_manager_agent"
    agent = LLMAgent("llm_agent")
    agent.history = [
        {"role": "user", "content": "Old conversation context"},
        {"role": "assistant", "content": "Old response"},
    ]
    assert len(agent.history) == 2

    msg = InternalMessage(
        to="llm_agent",
        sender=mock_settings.agent_settings.bdi_program_manager_name,
        body="clear_history",
    )
    await agent.handle_message(msg)

    assert len(agent.history) == 0


@pytest.mark.asyncio
async def test_handle_assistant_and_user_messages(mock_settings):
    """Messages on the "assistant_message"/"user_message" threads from the BDI
    core are appended to history with the matching role."""
    agent = LLMAgent("llm_agent")

    # Assistant message
    msg_ast = InternalMessage(
        to="llm_agent",
        sender=mock_settings.agent_settings.bdi_core_name,
        thread="assistant_message",
        body="I said this",
    )
    await agent.handle_message(msg_ast)
    assert agent.history[-1] == {"role": "assistant", "content": "I said this"}

    # User message
    msg_usr = InternalMessage(
        to="llm_agent",
        sender=mock_settings.agent_settings.bdi_core_name,
        thread="user_message",
        body="User said this",
    )
    await agent.handle_message(msg_usr)
    assert agent.history[-1] == {"role": "user", "content": "User said this"}