diff --git a/.githooks/check-branch-name.sh b/.githooks/check-branch-name.sh index 0e71c9b..6a6669a 100755 --- a/.githooks/check-branch-name.sh +++ b/.githooks/check-branch-name.sh @@ -10,7 +10,7 @@ # An array of allowed commit types ALLOWED_TYPES=(feat fix refactor perf style test docs build chore revert) # An array of branches to ignore -IGNORED_BRANCHES=(main dev) +IGNORED_BRANCHES=(main dev demo) # --- Colors for Output --- RED='\033[0;31m' diff --git a/src/control_backend/agents/bdi/behaviours/receive_llm_resp_behaviour.py b/src/control_backend/agents/bdi/behaviours/receive_llm_resp_behaviour.py index dc6e862..71e69c6 100644 --- a/src/control_backend/agents/bdi/behaviours/receive_llm_resp_behaviour.py +++ b/src/control_backend/agents/bdi/behaviours/receive_llm_resp_behaviour.py @@ -1,8 +1,10 @@ import logging from spade.behaviour import CyclicBehaviour +from spade.message import Message from control_backend.core.config import settings +from control_backend.schemas.ri_message import SpeechCommand class ReceiveLLMResponseBehaviour(CyclicBehaviour): @@ -10,7 +12,7 @@ class ReceiveLLMResponseBehaviour(CyclicBehaviour): Adds behavior to receive responses from the LLM Agent. """ - logger = logging.getLogger("BDI/LLM Reciever") + logger = logging.getLogger("BDI/LLM Receiver") async def run(self): msg = await self.receive(timeout=2) @@ -22,7 +24,20 @@ class ReceiveLLMResponseBehaviour(CyclicBehaviour): case settings.agent_settings.llm_agent_name: content = msg.body self.logger.info("Received LLM response: %s", content) - # Here the BDI can pass the message back as a response + + speech_command = SpeechCommand(data=content) + + message = Message( + to=settings.agent_settings.ri_command_agent_name + + "@" + + settings.agent_settings.host, + sender=self.agent.jid, + body=speech_command.model_dump_json(), + ) + + self.logger.debug("Sending message: %s", message) + + await self.send(message) case _: self.logger.debug("Not from the llm, discarding message") pass diff --git a/src/control_backend/agents/llm/llm.py b/src/control_backend/agents/llm/llm.py index c3c17ab..4487b23 100644 --- a/src/control_backend/agents/llm/llm.py +++ b/src/control_backend/agents/llm/llm.py @@ -3,8 +3,10 @@ LLM Agent module for routing text queries from the BDI Core Agent to a local LLM service and returning its responses back to the BDI Core Agent. """ +import json import logging -from typing import Any +import re +from collections.abc import AsyncGenerator import httpx from spade.agent import Agent @@ -54,11 +56,16 @@ class LLMAgent(Agent): async def _process_bdi_message(self, message: Message): """ - Forwards user text to the LLM and replies with the generated text. + Forwards user text from the BDI to the LLM and replies with the generated text in chunks + separated by punctuation. """ user_text = message.body - llm_response = await self._query_llm(user_text) - await self._reply(llm_response) + # Consume the streaming generator and send a reply for every chunk + async for chunk in self._query_llm(user_text): + await self._reply(chunk) + self.agent.logger.debug( + "Finished processing BDI message. Response sent in chunks to BDI Core Agent." + ) async def _reply(self, msg: str): """ @@ -69,48 +76,89 @@ class LLMAgent(Agent): body=msg, ) await self.send(reply) - self.agent.logger.info("Reply sent to BDI Core Agent") - async def _query_llm(self, prompt: str) -> str: + async def _query_llm(self, prompt: str) -> AsyncGenerator[str]: """ - Sends a chat completion request to the local LLM service. 
+        Sends a chat completion request to the local LLM service and streams the response by
+        yielding fragments separated by punctuation.

         :param prompt: Input text prompt to pass to the LLM.
-        :return: LLM-generated content or fallback message.
+        :yield: Fragments of the LLM-generated content.
         """
-        async with httpx.AsyncClient(timeout=120.0) as client:
-            # Example dynamic content for future (optional)
+        instructions = LLMInstructions(
+            "- Be friendly and respectful.\n"
+            "- Make the conversation feel natural and engaging.\n"
+            "- Speak like a pirate.\n"
+            "- When the user asks what you can do, tell them.",
+            "- Try to learn the user's name during conversation.\n"
+            "- Suggest playing a game of asking yes or no questions where you think of a word "
+            "and the user must guess it.",
+        )
+        messages = [
+            {
+                "role": "developer",
+                "content": instructions.build_developer_instruction(),
+            },
+            {
+                "role": "user",
+                "content": prompt,
+            },
+        ]

-            instructions = LLMInstructions()
-            developer_instruction = instructions.build_developer_instruction()
+        try:
+            current_chunk = ""
+            async for token in self._stream_query_llm(messages):
+                current_chunk += token

-            response = await client.post(
+                # Stream the message in chunks separated by punctuation.
+                # We include the delimiter in the emitted chunk for natural flow.
+                pattern = re.compile(r".*?(?:,|;|:|—|–|\.{3}|…|\.|\?|!)\s*", re.DOTALL)
+                for m in pattern.finditer(current_chunk):
+                    chunk = m.group(0)
+                    if chunk:
+                        yield chunk
+                        current_chunk = current_chunk[len(chunk) :]
+
+            # Yield any remaining tail
+            if current_chunk:
+                yield current_chunk
+        except httpx.HTTPError as err:
+            self.agent.logger.error("HTTP error.", exc_info=err)
+            yield "LLM service unavailable."
+        except Exception as err:
+            self.agent.logger.error("Unexpected error.", exc_info=err)
+            yield "Error processing the request."
+
+    async def _stream_query_llm(self, messages) -> AsyncGenerator[str]:
+        """Raises httpx.HTTPError when the API returns an error."""
+        async with httpx.AsyncClient(timeout=None) as client:
+            async with client.stream(
+                "POST",
                 settings.llm_settings.local_llm_url,
-                headers={"Content-Type": "application/json"},
                 json={
                     "model": settings.llm_settings.local_llm_model,
-                    "messages": [
-                        {"role": "developer", "content": developer_instruction},
-                        {"role": "user", "content": prompt},
-                    ],
+                    "messages": messages,
                     "temperature": 0.3,
+                    "stream": True,
                 },
-            )
-
-            try:
+            ) as response:
                 response.raise_for_status()
-                data: dict[str, Any] = response.json()
-                return (
-                    data.get("choices", [{}])[0]
-                    .get("message", {})
-                    .get("content", "No response")
-                )
-            except httpx.HTTPError as err:
-                self.agent.logger.error("HTTP error: %s", err)
-                return "LLM service unavailable."
-            except Exception as err:
-                self.agent.logger.error("Unexpected error: %s", err)
-                return "Error processing the request."
+ + async for line in response.aiter_lines(): + if not line or not line.startswith("data: "): + continue + + data = line[len("data: ") :] + if data.strip() == "[DONE]": + break + + try: + event = json.loads(data) + delta = event.get("choices", [{}])[0].get("delta", {}).get("content") + if delta: + yield delta + except json.JSONDecodeError: + self.agent.logger.error("Failed to parse LLM response: %s", data) async def setup(self): """ diff --git a/src/control_backend/agents/llm/llm_instructions.py b/src/control_backend/agents/llm/llm_instructions.py index 9636d88..6922fca 100644 --- a/src/control_backend/agents/llm/llm_instructions.py +++ b/src/control_backend/agents/llm/llm_instructions.py @@ -28,7 +28,9 @@ class LLMInstructions: """ sections = [ "You are a Pepper robot engaging in natural human conversation.", - "Keep responses between 1–5 sentences, unless instructed otherwise.\n", + "Keep responses between 1–3 sentences, unless told otherwise.\n", + "You're given goals to reach. Reach them in order, but make the conversation feel " + "natural. Some turns you should not try to achieve your goals.\n", ] if self.norms: diff --git a/src/control_backend/agents/ri_command_agent.py b/src/control_backend/agents/ri_command_agent.py index 0dcc981..dac41f3 100644 --- a/src/control_backend/agents/ri_command_agent.py +++ b/src/control_backend/agents/ri_command_agent.py @@ -1,6 +1,7 @@ import json import logging +import spade.agent import zmq from spade.agent import Agent from spade.behaviour import CyclicBehaviour @@ -32,6 +33,8 @@ class RICommandAgent(Agent): self.bind = bind class SendCommandsBehaviour(CyclicBehaviour): + """Behaviour for sending commands received from the UI.""" + async def run(self): """ Run the command publishing loop indefinetely. @@ -50,6 +53,18 @@ class RICommandAgent(Agent): except Exception as e: logger.error("Error processing message: %s", e) + class SendPythonCommandsBehaviour(CyclicBehaviour): + """Behaviour for sending commands received from other Python agents.""" + + async def run(self): + message: spade.agent.Message = await self.receive(timeout=0.1) + if message and message.to == self.agent.jid: + try: + speech_command = SpeechCommand.model_validate_json(message.body) + await self.agent.pubsocket.send_json(speech_command.model_dump()) + except Exception as e: + logger.error("Error processing message: %s", e) + async def setup(self): """ Setup the command agent @@ -73,5 +88,6 @@ class RICommandAgent(Agent): # Add behaviour to our agent commands_behaviour = self.SendCommandsBehaviour() self.add_behaviour(commands_behaviour) + self.add_behaviour(self.SendPythonCommandsBehaviour()) logger.info("Finished setting up %s", self.jid) diff --git a/src/control_backend/agents/transcription/speech_recognizer.py b/src/control_backend/agents/transcription/speech_recognizer.py index 19d82ff..527d371 100644 --- a/src/control_backend/agents/transcription/speech_recognizer.py +++ b/src/control_backend/agents/transcription/speech_recognizer.py @@ -36,16 +36,16 @@ class SpeechRecognizer(abc.ABC): def _estimate_max_tokens(audio: np.ndarray) -> int: """ Estimate the maximum length of a given audio sample in tokens. Assumes a maximum speaking - rate of 300 words per minute (2x average), and assumes that 3 words is 4 tokens. + rate of 450 words per minute (3x average), and assumes that 3 words is 4 tokens. :param audio: The audio sample (16 kHz) to use for length estimation. :return: The estimated length of the transcribed audio in tokens. 
""" length_seconds = len(audio) / 16_000 length_minutes = length_seconds / 60 - word_count = length_minutes * 300 + word_count = length_minutes * 450 token_count = word_count / 3 * 4 - return int(token_count) + return int(token_count) + 10 def _get_decode_options(self, audio: np.ndarray) -> dict: """ @@ -85,9 +85,10 @@ class MLXWhisperSpeechRecognizer(SpeechRecognizer): def recognize_speech(self, audio: np.ndarray) -> str: self.load_model() return mlx_whisper.transcribe( - audio, path_or_hf_repo=self.model_name, decode_options=self._get_decode_options(audio) - )["text"] - return mlx_whisper.transcribe(audio, path_or_hf_repo=self.model_name)["text"].strip() + audio, + path_or_hf_repo=self.model_name, + **self._get_decode_options(audio), + )["text"].strip() class OpenAIWhisperSpeechRecognizer(SpeechRecognizer): @@ -103,6 +104,4 @@ class OpenAIWhisperSpeechRecognizer(SpeechRecognizer): def recognize_speech(self, audio: np.ndarray) -> str: self.load_model() - return whisper.transcribe( - self.model, audio, decode_options=self._get_decode_options(audio) - )["text"] + return whisper.transcribe(self.model, audio, **self._get_decode_options(audio))["text"] diff --git a/src/control_backend/agents/transcription/transcription_agent.py b/src/control_backend/agents/transcription/transcription_agent.py index 530bd68..25103a4 100644 --- a/src/control_backend/agents/transcription/transcription_agent.py +++ b/src/control_backend/agents/transcription/transcription_agent.py @@ -58,6 +58,10 @@ class TranscriptionAgent(Agent): audio = await self.audio_in_socket.recv() audio = np.frombuffer(audio, dtype=np.float32) speech = await self._transcribe(audio) + if not speech: + logger.info("Nothing transcribed.") + return + logger.info("Transcribed speech: %s", speech) await self._share_transcription(speech) diff --git a/src/control_backend/agents/vad_agent.py b/src/control_backend/agents/vad_agent.py index f16abf4..9cf2adf 100644 --- a/src/control_backend/agents/vad_agent.py +++ b/src/control_backend/agents/vad_agent.py @@ -54,8 +54,20 @@ class Streaming(CyclicBehaviour): self.audio_buffer = np.array([], dtype=np.float32) self.i_since_speech = 100 # Used to allow small pauses in speech + self._ready = False + + async def reset(self): + """Clears the ZeroMQ queue and tells this behavior to start.""" + discarded = 0 + while await self.audio_in_poller.poll(1) is not None: + discarded += 1 + logging.info(f"Discarded {discarded} audio packets before starting.") + self._ready = True async def run(self) -> None: + if not self._ready: + return + data = await self.audio_in_poller.poll() if data is None: if len(self.audio_buffer) > 0: @@ -107,6 +119,8 @@ class VADAgent(Agent): self.audio_in_socket: azmq.Socket | None = None self.audio_out_socket: azmq.Socket | None = None + self.streaming_behaviour: Streaming | None = None + async def stop(self): """ Stop listening to audio, stop publishing audio, close sockets. 
@@ -149,8 +163,8 @@ class VADAgent(Agent): return audio_out_address = f"tcp://localhost:{audio_out_port}" - streaming = Streaming(self.audio_in_socket, self.audio_out_socket) - self.add_behaviour(streaming) + self.streaming_behaviour = Streaming(self.audio_in_socket, self.audio_out_socket) + self.add_behaviour(self.streaming_behaviour) # Start agents dependent on the output audio fragments here transcriber = TranscriptionAgent(audio_out_address) diff --git a/src/control_backend/main.py b/src/control_backend/main.py index 29f1396..043eefd 100644 --- a/src/control_backend/main.py +++ b/src/control_backend/main.py @@ -95,6 +95,8 @@ async def lifespan(app: FastAPI): _temp_vad_agent = VADAgent("tcp://localhost:5558", False) await _temp_vad_agent.start() + logger.info("VAD agent started, now making ready...") + await _temp_vad_agent.streaming_behaviour.reset() yield diff --git a/test/integration/agents/vad_agent/test_vad_with_audio.py b/test/integration/agents/vad_agent/test_vad_with_audio.py index 7d10aa3..fd7d4d7 100644 --- a/test/integration/agents/vad_agent/test_vad_with_audio.py +++ b/test/integration/agents/vad_agent/test_vad_with_audio.py @@ -48,6 +48,7 @@ async def test_real_audio(mocker): audio_out_socket = AsyncMock() vad_streamer = Streaming(audio_in_socket, audio_out_socket) + vad_streamer._ready = True for _ in audio_chunks: await vad_streamer.run() diff --git a/test/unit/agents/test_vad_streaming.py b/test/unit/agents/test_vad_streaming.py index 9b38cd0..ab2da0d 100644 --- a/test/unit/agents/test_vad_streaming.py +++ b/test/unit/agents/test_vad_streaming.py @@ -21,7 +21,9 @@ def streaming(audio_in_socket, audio_out_socket): import torch torch.hub.load.return_value = (..., ...) # Mock - return Streaming(audio_in_socket, audio_out_socket) + streaming = Streaming(audio_in_socket, audio_out_socket) + streaming._ready = True + return streaming async def simulate_streaming_with_probabilities(streaming, probabilities: list[float]): diff --git a/test/unit/agents/transcription/test_speech_recognizer.py b/test/unit/agents/transcription/test_speech_recognizer.py index 88a5ac2..ab28dcf 100644 --- a/test/unit/agents/transcription/test_speech_recognizer.py +++ b/test/unit/agents/transcription/test_speech_recognizer.py @@ -5,12 +5,13 @@ from control_backend.agents.transcription.speech_recognizer import OpenAIWhisper def test_estimate_max_tokens(): - """Inputting one minute of audio, assuming 300 words per minute, expecting 400 tokens.""" + """Inputting one minute of audio, assuming 450 words per minute and adding a 10 token padding, + expecting 610 tokens.""" audio = np.empty(shape=(60 * 16_000), dtype=np.float32) actual = SpeechRecognizer._estimate_max_tokens(audio) - assert actual == 400 + assert actual == 610 assert isinstance(actual, int)