diff --git a/src/control_backend/agents/bdi/bdi_core_agent.py b/src/control_backend/agents/bdi/bdi_core_agent.py
index 628bb53..54b5149 100644
--- a/src/control_backend/agents/bdi/bdi_core_agent.py
+++ b/src/control_backend/agents/bdi/bdi_core_agent.py
@@ -338,7 +338,7 @@ class BDICoreAgent(BaseAgent):
             yield

         @self.actions.add(".reply_with_goal", 3)
-        def _reply_with_goal(agent: "BDICoreAgent", term, intention):
+        def _reply_with_goal(agent, term, intention):
            """
            Let the LLM generate a response to a user's utterance with the
            current norms and a specific goal.
@@ -512,10 +512,6 @@ class BDICoreAgent(BaseAgent):

             yield

-        @self.actions.add(".notify_ui", 0)
-        def _notify_ui(agent, term, intention):
-            pass
-
     async def _send_to_llm(self, text: str, norms: str, goals: str):
         """
        Sends a text query to the LLM agent asynchronously.
diff --git a/src/control_backend/agents/bdi/text_belief_extractor_agent.py b/src/control_backend/agents/bdi/text_belief_extractor_agent.py
index 362dfbf..9ea6b9a 100644
--- a/src/control_backend/agents/bdi/text_belief_extractor_agent.py
+++ b/src/control_backend/agents/bdi/text_belief_extractor_agent.py
@@ -318,6 +318,9 @@ class TextBeliefExtractorAgent(BaseAgent):
         async with httpx.AsyncClient() as client:
             response = await client.post(
                 settings.llm_settings.local_llm_url,
+                headers={"Authorization": f"Bearer {settings.llm_settings.api_key}"}
+                if settings.llm_settings.api_key
+                else {},
                 json={
                     "model": settings.llm_settings.local_llm_model,
                     "messages": [{"role": "user", "content": prompt}],
diff --git a/src/control_backend/agents/llm/llm_agent.py b/src/control_backend/agents/llm/llm_agent.py
index 1c72dfc..8d81249 100644
--- a/src/control_backend/agents/llm/llm_agent.py
+++ b/src/control_backend/agents/llm/llm_agent.py
@@ -1,3 +1,4 @@
+import asyncio
 import json
 import re
 import uuid
@@ -32,6 +33,10 @@ class LLMAgent(BaseAgent):
     def __init__(self, name: str):
         super().__init__(name)
         self.history = []
+        self._querying = False
+        self._interrupted = False
+        self._interrupted_message = ""
+        self._go_ahead = asyncio.Event()

     async def setup(self):
         self.logger.info("Setting up %s.", self.name)
@@ -50,13 +55,13 @@ class LLMAgent(BaseAgent):
                 case "prompt_message":
                     try:
                         prompt_message = LLMPromptMessage.model_validate_json(msg.body)
-                        await self._process_bdi_message(prompt_message)
+                        self.add_behavior(self._process_bdi_message(prompt_message))  # no block
                     except ValidationError:
                         self.logger.debug("Prompt message from BDI core is invalid.")
                 case "assistant_message":
-                    self.history.append({"role": "assistant", "content": msg.body})
+                    self._apply_conversation_message({"role": "assistant", "content": msg.body})
                 case "user_message":
-                    self.history.append({"role": "user", "content": msg.body})
+                    self._apply_conversation_message({"role": "user", "content": msg.body})
         elif msg.sender == settings.agent_settings.bdi_program_manager_name:
             if msg.body == "clear_history":
                 self.logger.debug("Clearing conversation history.")
@@ -73,12 +78,45 @@ class LLMAgent(BaseAgent):

        :param message: The parsed prompt message containing text, norms, and goals.
        """
+        if self._querying:
+            self.logger.debug("Received another BDI prompt while processing previous message.")
+            self._interrupted = True  # interrupt the previous processing
+            await self._go_ahead.wait()  # wait until we get the go-ahead
+
+            message.text = f"{self._interrupted_message} {message.text}"
+
+            self._go_ahead.clear()
+        self._querying = True
         full_message = ""
         async for chunk in self._query_llm(message.text, message.norms, message.goals):
+            if self._interrupted:
+                self._interrupted_message = message.text
+                self.logger.debug("Interrupted processing of previous message.")
+                break
             await self._send_reply(chunk)
             full_message += chunk
-        self.logger.debug("Finished processing BDI message. Response sent in chunks to BDI core.")
-        await self._send_full_reply(full_message)
+        else:
+            self._querying = False
+
+            self._apply_conversation_message(
+                {
+                    "role": "assistant",
+                    "content": full_message,
+                }
+            )
+            self.logger.debug(
+                "Finished processing BDI message. Response sent in chunks to BDI core."
+            )
+            await self._send_full_reply(full_message)
+
+        self._go_ahead.set()
+        self._interrupted = False
+
+    def _apply_conversation_message(self, message: dict[str, str]):
+        if len(self.history) > 0 and message["role"] == self.history[-1]["role"]:
+            self.history[-1]["content"] += " " + message["content"]
+            return
+        self.history.append(message)

     async def _send_reply(self, msg: str):
         """
@@ -159,13 +197,6 @@ class LLMAgent(BaseAgent):
            # Yield any remaining tail
            if current_chunk:
                yield current_chunk
-
-            self.history.append(
-                {
-                    "role": "assistant",
-                    "content": full_message,
-                }
-            )
        except httpx.HTTPError as err:
            self.logger.error("HTTP error.", exc_info=err)
            yield "LLM service unavailable."
@@ -185,6 +216,9 @@ class LLMAgent(BaseAgent):
             async with client.stream(
                 "POST",
                 settings.llm_settings.local_llm_url,
+                headers={"Authorization": f"Bearer {settings.llm_settings.api_key}"}
+                if settings.llm_settings.api_key
+                else {},
                 json={
                     "model": settings.llm_settings.local_llm_model,
                     "messages": messages,
diff --git a/src/control_backend/agents/perception/transcription_agent/speech_recognizer.py b/src/control_backend/agents/perception/transcription_agent/speech_recognizer.py
index 9fae676..1fe7e3f 100644
--- a/src/control_backend/agents/perception/transcription_agent/speech_recognizer.py
+++ b/src/control_backend/agents/perception/transcription_agent/speech_recognizer.py
@@ -145,4 +145,6 @@ class OpenAIWhisperSpeechRecognizer(SpeechRecognizer):

     def recognize_speech(self, audio: np.ndarray) -> str:
         self.load_model()
-        return whisper.transcribe(self.model, audio, **self._get_decode_options(audio))["text"]
+        return whisper.transcribe(self.model, audio, **self._get_decode_options(audio))[
+            "text"
+        ].strip()
diff --git a/src/control_backend/core/config.py b/src/control_backend/core/config.py
index 329a246..82b9ede 100644
--- a/src/control_backend/core/config.py
+++ b/src/control_backend/core/config.py
@@ -117,6 +117,7 @@ class LLMSettings(BaseModel):

     local_llm_url: str = "http://localhost:1234/v1/chat/completions"
     local_llm_model: str = "gpt-oss"
+    api_key: str = ""
     chat_temperature: float = 1.0
     code_temperature: float = 0.3
     n_parallel: int = 4
diff --git a/test/unit/agents/llm/test_llm_agent.py b/test/unit/agents/llm/test_llm_agent.py
index a1cc297..bd407cc 100644
--- a/test/unit/agents/llm/test_llm_agent.py
+++ b/test/unit/agents/llm/test_llm_agent.py
@@ -61,8 +61,52 @@ async def test_llm_processing_success(mock_httpx_client, mock_settings):
         thread="prompt_message",  # REQUIRED: thread must match handle_message logic
     )

+    agent._process_bdi_message = AsyncMock()
+
+    await agent.handle_message(msg)
+
+    agent._process_bdi_message.assert_called()
+
+
+@pytest.mark.asyncio
+async def test_process_bdi_message_success(mock_httpx_client, mock_settings):
+    # Setup the mock response for the stream
+    mock_response = MagicMock()
+    mock_response.raise_for_status = MagicMock()
+
+    # Simulate stream lines
+    lines = [
+        b'data: {"choices": [{"delta": {"content": "Hello"}}]}',
+        b'data: {"choices": [{"delta": {"content": " world"}}]}',
+        b'data: {"choices": [{"delta": {"content": "."}}]}',
+        b"data: [DONE]",
+    ]
+
+    async def aiter_lines_gen():
+        for line in lines:
+            yield line.decode()
+
+    mock_response.aiter_lines.side_effect = aiter_lines_gen
+
+    mock_stream_context = MagicMock()
+    mock_stream_context.__aenter__ = AsyncMock(return_value=mock_response)
+    mock_stream_context.__aexit__ = AsyncMock(return_value=None)
+
+    # Configure the client
+    mock_httpx_client.stream = MagicMock(return_value=mock_stream_context)
+
+    # Setup Agent
+    agent = LLMAgent("llm_agent")
+    agent.send = AsyncMock()  # Mock the send method to verify replies
+
+    mock_logger = MagicMock()
+    agent.logger = mock_logger
+
+    # Simulate receiving a message from BDI
+    prompt = LLMPromptMessage(text="Hi", norms=[], goals=[])
+
+    await agent._process_bdi_message(prompt)
+
     # Verification
     # "Hello world." constitutes one sentence/chunk based on punctuation split
     # The agent should call send once with the full sentence, PLUS once more for full reply
@@ -79,28 +123,16 @@ async def test_llm_processing_errors(mock_httpx_client, mock_settings):
     agent = LLMAgent("llm_agent")
     agent.send = AsyncMock()
     prompt = LLMPromptMessage(text="Hi", norms=[], goals=[])
-    msg = InternalMessage(
-        to="llm",
-        sender=mock_settings.agent_settings.bdi_core_name,
-        body=prompt.model_dump_json(),
-        thread="prompt_message",
-    )

     # HTTP Error: stream method RAISES exception immediately
     mock_httpx_client.stream = MagicMock(side_effect=httpx.HTTPError("Fail"))
-    await agent.handle_message(msg)
+    await agent._process_bdi_message(prompt)

     # Check that error message was sent
     assert agent.send.called
     assert "LLM service unavailable." in agent.send.call_args_list[0][0][0].body

-    # General Exception
-    agent.send.reset_mock()
-    mock_httpx_client.stream = MagicMock(side_effect=Exception("Boom"))
-    await agent.handle_message(msg)
-    assert "Error processing the request." in agent.send.call_args_list[0][0][0].body
-

 @pytest.mark.asyncio
 async def test_llm_json_error(mock_httpx_client, mock_settings):
@@ -125,13 +157,7 @@ async def test_llm_json_error(mock_httpx_client, mock_settings):
     agent.logger = MagicMock()

     prompt = LLMPromptMessage(text="Hi", norms=[], goals=[])
-    msg = InternalMessage(
-        to="llm",
-        sender=mock_settings.agent_settings.bdi_core_name,
-        body=prompt.model_dump_json(),
-        thread="prompt_message",
-    )
-    await agent.handle_message(msg)
+    await agent._process_bdi_message(prompt)

     agent.logger.error.assert_called()  # Should log JSONDecodeError