diff --git a/README.md b/README.md index 62ff566..45f8f98 100644 --- a/README.md +++ b/README.md @@ -16,6 +16,17 @@ Using UV, installing the packages and virtual environment is as simple as typing uv sync ``` +## Local LLM + +To run an LLM locally, download LM Studio from https://lmstudio.ai +When installing, select developer mode, download a model (LM Studio will already suggest one) and run it (see the developer window, status: running). + +Copy the URL shown at the top right and, in the settings, replace local_llm_url with that URL followed by v1/chat/completions. +The appended path might differ depending on which model you choose. + +Copy the name of the loaded model and, in the settings, replace local_llm_model with it. + + ## Running To run the project (development server), execute the following command (while inside the root repository): @@ -24,10 +35,16 @@ uv run fastapi dev src/control_backend/main.py ``` ## Testing -Testing happens automatically when opening a merge request to any branch. If you want to manually run the test suite, you can do so by running the following: +Testing happens automatically when opening a merge request to any branch. 
If you want to manually run the test suite, you can do so by running the following for unit tests: ```bash -uv run --only-group test pytest +uv run --only-group test pytest test/unit +``` + +Or for integration tests: + +```bash +uv run --group integration-test pytest test/integration ``` ## GitHooks diff --git a/pyproject.toml b/pyproject.toml index 8299d0f..ee3ca08 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,6 +7,7 @@ requires-python = ">=3.13" dependencies = [ "fastapi[all]>=0.115.6", "mlx-whisper>=0.4.3 ; sys_platform == 'darwin'", + "numpy>=2.3.3", "openai-whisper>=20250625", "pyaudio>=0.2.14", "pydantic>=2.12.0", @@ -33,6 +34,7 @@ integration-test = [ "soundfile>=0.13.1", ] test = [ + "numpy>=2.3.3", "pytest>=8.4.2", "pytest-asyncio>=1.2.0", "pytest-cov>=7.0.0", diff --git a/src/control_backend/agents/bdi/bdi_core.py b/src/control_backend/agents/bdi/bdi_core.py index 1696303..6e5cdc0 100644 --- a/src/control_backend/agents/bdi/bdi_core.py +++ b/src/control_backend/agents/bdi/bdi_core.py @@ -1,9 +1,15 @@ import logging import agentspeak +from spade.behaviour import OneShotBehaviour +from spade.message import Message from spade_bdi.bdi import BDIAgent -from control_backend.agents.bdi.behaviours.belief_setter import BeliefSetter +from control_backend.agents.bdi.behaviours.belief_setter import BeliefSetterBehaviour +from control_backend.agents.bdi.behaviours.receive_llm_resp_behaviour import ( + ReceiveLLMResponseBehaviour, +) +from control_backend.core.config import settings class BDICoreAgent(BDIAgent): @@ -11,25 +17,52 @@ class BDICoreAgent(BDIAgent): This is the Brain agent that does the belief inference with AgentSpeak. This is a continous process that happens automatically in the background. This class contains all the actions that can be called from AgentSpeak plans. - It has the BeliefSetter behaviour. + It has the BeliefSetter behaviour and can aks and recieve requests from the LLM agent. 
""" - logger = logging.getLogger("BDI Core") + logger = logging.getLogger("bdi_core_agent") - async def setup(self): - belief_setter = BeliefSetter() - self.add_behaviour(belief_setter) + async def setup(self) -> None: + """ + Initializes belief behaviors and message routing. + """ + self.logger.info("BDICoreAgent setup started") + + self.add_behaviour(BeliefSetterBehaviour()) + self.add_behaviour(ReceiveLLMResponseBehaviour()) + + self.logger.info("BDICoreAgent setup complete") + + def add_custom_actions(self, actions) -> None: + """ + Registers custom AgentSpeak actions callable from plans. + """ - def add_custom_actions(self, actions): @actions.add(".reply", 1) - def _reply(agent, term, intention): - message = agentspeak.grounded(term.args[0], intention.scope) - self.logger.info(f"Replying to message: {message}") - reply = self._send_to_llm(message) - self.logger.info(f"Received reply: {reply}") + def _reply(agent: "BDICoreAgent", term, intention): + """ + Sends text to the LLM (AgentSpeak action). + Example: .reply("Hello LLM!") + """ + message_text = agentspeak.grounded(term.args[0], intention.scope) + self.logger.info("Reply action sending: %s", message_text) + self._send_to_llm(str(message_text)) yield - def _send_to_llm(self, message) -> str: - """TODO: implement""" - return f"This is a reply to {message}" + def _send_to_llm(self, text: str): + """ + Sends a text query to the LLM Agent asynchronously. 
+ """ + + class SendBehaviour(OneShotBehaviour): + async def run(self) -> None: + msg = Message( + to=settings.agent_settings.llm_agent_name + "@" + settings.agent_settings.host, + body=text, + ) + + await self.send(msg) + self.agent.logger.info("Message sent to LLM: %s", text) + + self.add_behaviour(SendBehaviour()) diff --git a/src/control_backend/agents/bdi/behaviours/belief_setter.py b/src/control_backend/agents/bdi/behaviours/belief_setter.py index d36fe5e..2f64036 100644 --- a/src/control_backend/agents/bdi/behaviours/belief_setter.py +++ b/src/control_backend/agents/bdi/behaviours/belief_setter.py @@ -1,4 +1,3 @@ -import asyncio import json import logging @@ -9,11 +8,10 @@ from spade_bdi.bdi import BDIAgent from control_backend.core.config import settings -class BeliefSetter(CyclicBehaviour): +class BeliefSetterBehaviour(CyclicBehaviour): """ This is the behaviour that the BDI agent runs. This behaviour waits for incoming - message and processes it based on sender. Currently, it only waits for messages - containing beliefs from BeliefCollector and adds these to its KB. + message and processes it based on sender. 
""" agent: BDIAgent @@ -24,7 +22,6 @@ class BeliefSetter(CyclicBehaviour): if msg: self.logger.info(f"Received message {msg.body}") self._process_message(msg) - await asyncio.sleep(1) def _process_message(self, message: Message): sender = message.sender.node # removes host from jid and converts to str @@ -35,6 +32,7 @@ class BeliefSetter(CyclicBehaviour): self.logger.debug("Processing message from belief collector.") self._process_belief_message(message) case _: + self.logger.debug("Not the belief agent, discarding message") pass def _process_belief_message(self, message: Message): @@ -44,19 +42,25 @@ class BeliefSetter(CyclicBehaviour): match message.thread: case "beliefs": try: - beliefs: dict[str, list[list[str]]] = json.loads(message.body) + beliefs: dict[str, list[str]] = json.loads(message.body) self._set_beliefs(beliefs) except json.JSONDecodeError as e: self.logger.error("Could not decode beliefs into JSON format: %s", e) case _: pass - def _set_beliefs(self, beliefs: dict[str, list[list[str]]]): + def _set_beliefs(self, beliefs: dict[str, list[str]]): + """Remove previous values for beliefs and update them with the provided values.""" if self.agent.bdi is None: self.logger.warning("Cannot set beliefs, since agent's BDI is not yet initialized.") return - for belief, arguments_list in beliefs.items(): - for arguments in arguments_list: - self.agent.bdi.set_belief(belief, *arguments) - self.logger.info("Set belief %s with arguments %s", belief, arguments) + # Set new beliefs (outdated beliefs are automatically removed) + for belief, arguments in beliefs.items(): + self.agent.bdi.set_belief(belief, *arguments) + + # Special case: if there's a new user message, flag that we haven't responded yet + if belief == "user_said": + self.agent.bdi.set_belief("new_message") + + self.logger.info("Set belief %s with arguments %s", belief, arguments) diff --git a/src/control_backend/agents/bdi/behaviours/receive_llm_resp_behaviour.py 
b/src/control_backend/agents/bdi/behaviours/receive_llm_resp_behaviour.py new file mode 100644 index 0000000..dc6e862 --- /dev/null +++ b/src/control_backend/agents/bdi/behaviours/receive_llm_resp_behaviour.py @@ -0,0 +1,28 @@ +import logging + +from spade.behaviour import CyclicBehaviour + +from control_backend.core.config import settings + + +class ReceiveLLMResponseBehaviour(CyclicBehaviour): + """ + Adds behavior to receive responses from the LLM Agent. + """ + + logger = logging.getLogger("BDI/LLM Reciever") + + async def run(self): + msg = await self.receive(timeout=2) + if not msg: + return + + sender = msg.sender.node + match sender: + case settings.agent_settings.llm_agent_name: + content = msg.body + self.logger.info("Received LLM response: %s", content) + # Here the BDI can pass the message back as a response + case _: + self.logger.debug("Not from the llm, discarding message") + pass diff --git a/src/control_backend/agents/bdi/behaviours/text_belief_extractor.py b/src/control_backend/agents/bdi/behaviours/text_belief_extractor.py new file mode 100644 index 0000000..913bc44 --- /dev/null +++ b/src/control_backend/agents/bdi/behaviours/text_belief_extractor.py @@ -0,0 +1,100 @@ +import asyncio +import json +import logging + +from spade.behaviour import CyclicBehaviour +from spade.message import Message + +from control_backend.core.config import settings + + +class BeliefFromText(CyclicBehaviour): + logger = logging.getLogger("Belief From Text") + + # TODO: LLM prompt nog hardcoded + llm_instruction_prompt = """ + You are an information extraction assistent for a BDI agent. + Your task is to extract values from a user's text to bind a list of ungrounded beliefs. Rules: + You will receive a JSON object with "beliefs" + (a list of ungrounded AgentSpeak beliefs) and "text" (user's transcript). + Analyze the text to find values that sematically match the variables (X,Y,Z) in the beliefs. 
+ A single piece of text might contain multiple instances that match a belief. + Respond ONLY with a single JSON object. + The JSON object's keys should be the belief functors (e.g., "weather"). + The value for each key must be a list of lists. + Each inner list must contain the extracted arguments + (as strings) for one instance of that belief. + CRITICAL: If no information in the text matches a belief, + DO NOT include that key in your response. + """ + + # on_start agent receives message containing the beliefs to look out + # for and sets up the LLM with instruction prompt + # async def on_start(self): + # msg = await self.receive(timeout=0.1) + # self.beliefs = dict uit message + # send instruction prompt to LLM + + beliefs: dict[str, list[str]] + beliefs = {"mood": ["X"], "car": ["Y"]} + + async def run(self): + msg = await self.receive(timeout=0.1) + if msg: + sender = msg.sender.node + match sender: + case settings.agent_settings.transcription_agent_name: + self.logger.info("Received text from transcriber.") + await self._process_transcription_demo(msg.body) + case _: + self.logger.info("Received message from other agent.") + pass + await asyncio.sleep(1) + + async def _process_transcription(self, text: str): + text_prompt = f"Text: {text}" + + beliefs_prompt = "These are the beliefs to be bound:\n" + for belief, values in self.beliefs.items(): + beliefs_prompt += f"{belief}({', '.join(values)})\n" + + prompt = text_prompt + beliefs_prompt + self.logger.info(prompt) + # prompt_msg = Message(to="LLMAgent@whatever") + # response = self.send(prompt_msg) + + # Mock response; response is beliefs in JSON format, it parses do dict[str,list[list[str]]] + response = '{"mood": [["happy"]]}' + # Verify by trying to parse + try: + json.loads(response) + belief_message = Message( + to=settings.agent_settings.bdi_core_agent_name + "@" + settings.agent_settings.host, + body=response, + ) + belief_message.thread = "beliefs" + + await self.send(belief_message) + 
self.logger.info("Sent beliefs to BDI.") + except json.JSONDecodeError: + # Parsing failed, so the response is in the wrong format, log warning + self.logger.warning("Received LLM response in incorrect format.") + + async def _process_transcription_demo(self, txt: str): + """ + Demo version to process the transcription input to beliefs. For the demo only the belief + 'user_said' is relevant, so this function simply makes a dict with key: "user_said", + value: txt and passes this to the Belief Collector agent. + """ + belief = {"beliefs": {"user_said": [txt]}, "type": "belief_extraction_text"} + payload = json.dumps(belief) + belief_msg = Message( + to=settings.agent_settings.belief_collector_agent_name + + "@" + + settings.agent_settings.host, + body=payload, + ) + belief_msg.thread = "beliefs" + + await self.send(belief_msg) + self.logger.info("Sent beliefs to Belief Collector.") diff --git a/src/control_backend/agents/bdi/rules.asl b/src/control_backend/agents/bdi/rules.asl index 41660a4..0001d3c 100644 --- a/src/control_backend/agents/bdi/rules.asl +++ b/src/control_backend/agents/bdi/rules.asl @@ -1,3 +1,3 @@ -+user_said(Message) : not responded <- - +responded; ++new_message : user_said(Message) <- + -new_message; .reply(Message). 
diff --git a/src/control_backend/agents/bdi/text_extractor.py b/src/control_backend/agents/bdi/text_extractor.py new file mode 100644 index 0000000..ff9ad58 --- /dev/null +++ b/src/control_backend/agents/bdi/text_extractor.py @@ -0,0 +1,9 @@ +from spade.agent import Agent + +from control_backend.agents.bdi.behaviours.text_belief_extractor import BeliefFromText + + +class TBeliefExtractor(Agent): + async def setup(self): + self.b = BeliefFromText() + self.add_behaviour(self.b) diff --git a/test/__init__.py b/src/control_backend/agents/belief_collector/__init__.py similarity index 100% rename from test/__init__.py rename to src/control_backend/agents/belief_collector/__init__.py diff --git a/src/control_backend/agents/belief_collector/behaviours/continuous_collect.py b/src/control_backend/agents/belief_collector/behaviours/continuous_collect.py new file mode 100644 index 0000000..ada1c7a --- /dev/null +++ b/src/control_backend/agents/belief_collector/behaviours/continuous_collect.py @@ -0,0 +1,117 @@ +import json +import logging + +from spade.agent import Message +from spade.behaviour import CyclicBehaviour + +from control_backend.core.config import settings + +logger = logging.getLogger(__name__) + + +class ContinuousBeliefCollector(CyclicBehaviour): + """ + Continuously collects beliefs/emotions from extractor agents: + Then we send a unified belief packet to the BDI agent. + """ + + async def run(self): + msg = await self.receive(timeout=0.1) # Wait for 0.1s + if msg: + await self._process_message(msg) + + async def _process_message(self, msg: Message): + sender_node = self._sender_node(msg) + + # Parse JSON payload + try: + payload = json.loads(msg.body) + except Exception as e: + logger.warning( + "BeliefCollector: failed to parse JSON from %s. 
Body=%r Error=%s", + sender_node, + msg.body, + e, + ) + return + + msg_type = payload.get("type") + + # Prefer explicit 'type' field + if msg_type == "belief_extraction_text" or sender_node == "belief_text_agent_mock": + logger.info( + "BeliefCollector: message routed to _handle_belief_text (sender=%s)", sender_node + ) + await self._handle_belief_text(payload, sender_node) + # This is not implemented yet, but we keep the structure for future use + elif msg_type == "emotion_extraction_text" or sender_node == "emo_text_agent_mock": + logger.info( + "BeliefCollector: message routed to _handle_emo_text (sender=%s)", sender_node + ) + await self._handle_emo_text(payload, sender_node) + else: + logger.info( + "BeliefCollector: unrecognized message (sender=%s, type=%r). Ignoring.", + sender_node, + msg_type, + ) + + @staticmethod + def _sender_node(msg: Message) -> str: + """ + Extracts the 'node' (localpart) of the sender JID. + E.g., 'agent@host/resource' -> 'agent' + """ + s = str(msg.sender) if msg.sender is not None else "no_sender" + return s.split("@", 1)[0] if "@" in s else s + + async def _handle_belief_text(self, payload: dict, origin: str): + """ + Expected payload: + { + "type": "belief_extraction_text", + "beliefs": {"user_said": ["hello"","Can you help me?", + "stop talking to me","No","Pepper do a dance"]} + + } + + """ + beliefs = payload.get("beliefs", {}) + + if not beliefs: + logger.info("BeliefCollector: no beliefs to process.") + return + + if not isinstance(beliefs, dict): + logger.warning("BeliefCollector: 'beliefs' is not a dict: %r", beliefs) + return + + if not all(isinstance(v, list) for v in beliefs.values()): + logger.warning("BeliefCollector: 'beliefs' values are not all lists: %r", beliefs) + return + + logger.info("BeliefCollector: forwarding %d beliefs.", len(beliefs)) + for belief_name, belief_list in beliefs.items(): + for belief in belief_list: + logger.info(" - %s %s", belief_name, str(belief)) + + await 
self._send_beliefs_to_bdi(beliefs, origin=origin) + + async def _handle_emo_text(self, payload: dict, origin: str): + """TODO: implement (after we have emotional recogntion)""" + pass + + async def _send_beliefs_to_bdi(self, beliefs: list[str], origin: str | None = None): + """ + Sends a unified belief packet to the BDI agent. + """ + if not beliefs: + return + + to_jid = f"{settings.agent_settings.bdi_core_agent_name}@{settings.agent_settings.host}" + + msg = Message(to=to_jid, sender=self.agent.jid, thread="beliefs") + msg.body = json.dumps(beliefs) + + await self.send(msg) + logger.info("BeliefCollector: sent %d belief(s) to BDI at %s", len(beliefs), to_jid) diff --git a/src/control_backend/agents/belief_collector/belief_collector.py b/src/control_backend/agents/belief_collector/belief_collector.py new file mode 100644 index 0000000..8558242 --- /dev/null +++ b/src/control_backend/agents/belief_collector/belief_collector.py @@ -0,0 +1,15 @@ +import logging + +from spade.agent import Agent + +from .behaviours.continuous_collect import ContinuousBeliefCollector + +logger = logging.getLogger(__name__) + + +class BeliefCollectorAgent(Agent): + async def setup(self): + logger.info("BeliefCollectorAgent starting (%s)", self.jid) + # Attach the continuous collector behaviour (listens and forwards to BDI) + self.add_behaviour(ContinuousBeliefCollector()) + logger.info("BeliefCollectorAgent ready.") diff --git a/src/control_backend/agents/llm/llm.py b/src/control_backend/agents/llm/llm.py new file mode 100644 index 0000000..c3c17ab --- /dev/null +++ b/src/control_backend/agents/llm/llm.py @@ -0,0 +1,123 @@ +""" +LLM Agent module for routing text queries from the BDI Core Agent to a local LLM +service and returning its responses back to the BDI Core Agent. 
+""" + +import logging +from typing import Any + +import httpx +from spade.agent import Agent +from spade.behaviour import CyclicBehaviour +from spade.message import Message + +from control_backend.agents.llm.llm_instructions import LLMInstructions +from control_backend.core.config import settings + + +class LLMAgent(Agent): + """ + Agent responsible for processing user text input and querying a locally + hosted LLM for text generation. Receives messages from the BDI Core Agent + and responds with processed LLM output. + """ + + logger = logging.getLogger("llm_agent") + + class ReceiveMessageBehaviour(CyclicBehaviour): + """ + Cyclic behaviour to continuously listen for incoming messages from + the BDI Core Agent and handle them. + """ + + async def run(self): + """ + Receives SPADE messages and processes only those originating from the + configured BDI agent. + """ + msg = await self.receive(timeout=1) + if not msg: + return + + sender = msg.sender.node + self.agent.logger.info( + "Received message: %s from %s", + msg.body, + sender, + ) + + if sender == settings.agent_settings.bdi_core_agent_name: + self.agent.logger.debug("Processing message from BDI Core Agent") + await self._process_bdi_message(msg) + else: + self.agent.logger.debug("Message ignored (not from BDI Core Agent)") + + async def _process_bdi_message(self, message: Message): + """ + Forwards user text to the LLM and replies with the generated text. + """ + user_text = message.body + llm_response = await self._query_llm(user_text) + await self._reply(llm_response) + + async def _reply(self, msg: str): + """ + Sends a response message back to the BDI Core Agent. + """ + reply = Message( + to=settings.agent_settings.bdi_core_agent_name + "@" + settings.agent_settings.host, + body=msg, + ) + await self.send(reply) + self.agent.logger.info("Reply sent to BDI Core Agent") + + async def _query_llm(self, prompt: str) -> str: + """ + Sends a chat completion request to the local LLM service. 
+ + :param prompt: Input text prompt to pass to the LLM. + :return: LLM-generated content or fallback message. + """ + async with httpx.AsyncClient(timeout=120.0) as client: + # Example dynamic content for future (optional) + + instructions = LLMInstructions() + developer_instruction = instructions.build_developer_instruction() + + response = await client.post( + settings.llm_settings.local_llm_url, + headers={"Content-Type": "application/json"}, + json={ + "model": settings.llm_settings.local_llm_model, + "messages": [ + {"role": "developer", "content": developer_instruction}, + {"role": "user", "content": prompt}, + ], + "temperature": 0.3, + }, + ) + + try: + response.raise_for_status() + data: dict[str, Any] = response.json() + return ( + data.get("choices", [{}])[0] + .get("message", {}) + .get("content", "No response") + ) + except httpx.HTTPError as err: + self.agent.logger.error("HTTP error: %s", err) + return "LLM service unavailable." + except Exception as err: + self.agent.logger.error("Unexpected error: %s", err) + return "Error processing the request." + + async def setup(self): + """ + Sets up the SPADE behaviour to filter and process messages from the + BDI Core Agent. + """ + self.logger.info("LLMAgent setup complete") + + behaviour = self.ReceiveMessageBehaviour() + self.add_behaviour(behaviour) diff --git a/src/control_backend/agents/llm/llm_instructions.py b/src/control_backend/agents/llm/llm_instructions.py new file mode 100644 index 0000000..9636d88 --- /dev/null +++ b/src/control_backend/agents/llm/llm_instructions.py @@ -0,0 +1,44 @@ +class LLMInstructions: + """ + Defines structured instructions that are sent along with each request + to the LLM to guide its behavior (norms, goals, etc.). + """ + + @staticmethod + def default_norms() -> str: + return """ + Be friendly and respectful. + Make the conversation feel natural and engaging. 
+ """.strip() + + @staticmethod + def default_goals() -> str: + return """ + Try to learn the user's name during conversation. + """.strip() + + def __init__(self, norms: str | None = None, goals: str | None = None): + self.norms = norms if norms is not None else self.default_norms() + self.goals = goals if goals is not None else self.default_goals() + + def build_developer_instruction(self) -> str: + """ + Builds a multi-line formatted instruction string for the LLM. + Includes only non-empty structured fields. + """ + sections = [ + "You are a Pepper robot engaging in natural human conversation.", + "Keep responses between 1–5 sentences, unless instructed otherwise.\n", + ] + + if self.norms: + sections.append("Norms to follow:") + sections.append(self.norms) + sections.append("") + + if self.goals: + sections.append("Goals to reach:") + sections.append(self.goals) + sections.append("") + + return "\n".join(sections).strip() diff --git a/src/control_backend/agents/mock_agents/__init__.py b/src/control_backend/agents/mock_agents/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/control_backend/agents/mock_agents/belief_text_mock.py b/src/control_backend/agents/mock_agents/belief_text_mock.py new file mode 100644 index 0000000..769b263 --- /dev/null +++ b/src/control_backend/agents/mock_agents/belief_text_mock.py @@ -0,0 +1,43 @@ +import json + +from spade.agent import Agent +from spade.behaviour import OneShotBehaviour +from spade.message import Message + +from control_backend.core.config import settings + + +class BeliefTextAgent(Agent): + class SendOnceBehaviourBlfText(OneShotBehaviour): + async def run(self): + to_jid = ( + f"{settings.agent_settings.belief_collector_agent_name}" + f"@{settings.agent_settings.host}" + ) + + # Send multiple beliefs in one JSON payload + payload = { + "type": "belief_extraction_text", + "beliefs": { + "user_said": [ + "hello test", + "Can you help me?", + "stop talking to me", + "No", + "Pepper do a dance", 
+ ] + }, + } + + msg = Message(to=to_jid) + msg.body = json.dumps(payload) + await self.send(msg) + print(f"Beliefs sent to {to_jid}!") + + self.exit_code = "Job Finished!" + await self.agent.stop() + + async def setup(self): + print("BeliefTextAgent started") + self.b = self.SendOnceBehaviourBlfText() + self.add_behaviour(self.b) diff --git a/src/control_backend/agents/ri_command_agent.py b/src/control_backend/agents/ri_command_agent.py index 51b8064..51e148f 100644 --- a/src/control_backend/agents/ri_command_agent.py +++ b/src/control_backend/agents/ri_command_agent.py @@ -4,9 +4,9 @@ import logging import zmq from spade.agent import Agent from spade.behaviour import CyclicBehaviour +from zmq.asyncio import Context from control_backend.core.config import settings -from control_backend.core.zmq_context import context from control_backend.schemas.ri_message import SpeechCommand logger = logging.getLogger(__name__) @@ -56,16 +56,18 @@ class RICommandAgent(Agent): """ logger.info("Setting up %s", self.jid) + context = Context.instance() + # To the robot self.pubsocket = context.socket(zmq.PUB) - if self.bind: + if self.bind: # TODO: Should this ever be the case? 
self.pubsocket.bind(self.address) else: self.pubsocket.connect(self.address) # Receive internal topics regarding commands self.subsocket = context.socket(zmq.SUB) - self.subsocket.connect(settings.zmq_settings.internal_comm_address) + self.subsocket.connect(settings.zmq_settings.internal_sub_address) self.subsocket.setsockopt(zmq.SUBSCRIBE, b"command") # Add behaviour to our agent diff --git a/src/control_backend/agents/ri_communication_agent.py b/src/control_backend/agents/ri_communication_agent.py index 0bb369d..9d9170f 100644 --- a/src/control_backend/agents/ri_communication_agent.py +++ b/src/control_backend/agents/ri_communication_agent.py @@ -2,21 +2,18 @@ import asyncio import json import logging -import zmq import zmq.asyncio from spade.agent import Agent from spade.behaviour import CyclicBehaviour +from zmq.asyncio import Context from control_backend.agents.ri_command_agent import RICommandAgent from control_backend.core.config import settings -from control_backend.core.zmq_context import context logger = logging.getLogger(__name__) class RICommunicationAgent(Agent): - _pub_socket: zmq.asyncio.Socket - req_socket: zmq.asyncio.Socket | None _address = "" _bind = True connected = False @@ -25,7 +22,6 @@ class RICommunicationAgent(Agent): self, jid: str, password: str, - pub_socket: zmq.asyncio.Socket, port: int = 5222, verify_security: bool = False, address="tcp://localhost:0000", @@ -34,8 +30,8 @@ class RICommunicationAgent(Agent): super().__init__(jid, password, port, verify_security) self._address = address self._bind = bind - self.req_socket = None - self._pub_socket = pub_socket + self._req_socket: zmq.asyncio.Socket | None = None + self.pub_socket: zmq.asyncio.Socket | None = None class ListenBehaviour(CyclicBehaviour): async def run(self): @@ -49,7 +45,7 @@ class RICommunicationAgent(Agent): seconds_to_wait_total = 1.0 try: await asyncio.wait_for( - self.agent.req_socket.send_json(message), timeout=seconds_to_wait_total / 2 + 
self.agent._req_socket.send_json(message), timeout=seconds_to_wait_total / 2 ) except TimeoutError: logger.debug( @@ -61,23 +57,13 @@ class RICommunicationAgent(Agent): try: logger.debug(f"waiting for message for {seconds_to_wait_total / 2} seconds.") message = await asyncio.wait_for( - self.agent.req_socket.recv_json(), timeout=seconds_to_wait_total / 2 + self.agent._req_socket.recv_json(), timeout=seconds_to_wait_total / 2 ) # We didnt get a reply :( except TimeoutError: - logger.info( - f"No ping back retrieved in {seconds_to_wait_total / 2} seconds totalling" - f"{seconds_to_wait_total} of time, killing myself (or maybe just laying low)." - ) - # TODO: Send event to UI letting know that we've lost connection - topic = b"ping" - data = json.dumps(False).encode() - self.agent._pub_socket.send_multipart([topic, data]) - await self.agent.setup() - - except Exception as e: - logger.debug(f"Differennt exception: {e}") + logger.info("No ping retrieved in 3 seconds, killing myself.") + self.kill() logger.debug('Received message "%s"', message) if "endpoint" not in message: @@ -89,46 +75,53 @@ class RICommunicationAgent(Agent): case "ping": topic = b"ping" data = json.dumps(True).encode() - await self.agent._pub_socket.send_multipart([topic, data]) + if self.agent.pub_socket is not None: + await self.agent.pub_socket.send_multipart([topic, data]) await asyncio.sleep(1) case _: logger.info( "Received message with topic different than ping, while ping expected." ) - async def setup_req_socket(self, force=False): + async def setup_sockets(self, force=False): """ Sets up request socket for communication agent. """ - if self.req_socket is None or force: - self.req_socket = context.socket(zmq.REQ) - if self._bind: - self.req_socket.bind(self._address) + # Bind request socket + if self._req_socket is None or force: + self._req_socket = Context.instance().socket(zmq.REQ) + if self._bind: # TODO: Should this ever be the case with new architecture? 
+ self._req_socket.bind(self._address) else: - self.req_socket.connect(self._address) + self._req_socket.connect(self._address) - async def setup(self, max_retries: int = 5): + # TODO: Check with Kasper + if self.pub_socket is None or force: + self.pub_socket = Context.instance().socket(zmq.PUB) + self.pub_socket.connect(settings.zmq_settings.internal_pub_address) + + async def setup(self, max_retries: int = 100): """ Try to setup the communication agent, we have 5 retries in case we dont have a response yet. """ logger.info("Setting up %s", self.jid) # Bind request socket - await self.setup_req_socket() + await self.setup_sockets() retries = 0 # Let's try a certain amount of times before failing connection while retries < max_retries: # Make sure the socket is properly setup. - if self.req_socket is None: + if self._req_socket is None: continue # Send our message and receive one back:) message = {"endpoint": "negotiate/ports", "data": {}} - await self.req_socket.send_json(message) + await self._req_socket.send_json(message) try: - received_message = await asyncio.wait_for(self.req_socket.recv_json(), timeout=20.0) + received_message = await asyncio.wait_for(self._req_socket.recv_json(), timeout=1.0) except TimeoutError: logger.warning( @@ -173,9 +166,9 @@ class RICommunicationAgent(Agent): case "main": if addr != self._address: if not bind: - self.req_socket.connect(addr) - else: - self.req_socket.bind(addr) + self._req_socket.connect(addr) + else: # TODO: Should this ever be the case? 
+ self._req_socket.bind(addr) case "actuation": ri_commands_agent = RICommandAgent( settings.agent_settings.ri_command_agent_name @@ -205,9 +198,17 @@ class RICommunicationAgent(Agent): listen_behaviour = self.ListenBehaviour() self.add_behaviour(listen_behaviour) - # TODO: Let UI know that we're connected >:) + # Let UI know that we're connected >:) topic = b"ping" data = json.dumps(True).encode() - await self._pub_socket.send_multipart([topic, data]) + if self.pub_socket is None: + logger.error("communication agent pub socket not correctly initialized.") + else: + try: + await asyncio.wait_for(self.pub_socket.send_multipart([topic, data]), 5) + except TimeoutError: + logger.error( + "Initial connection ping for router timed out in ri_communication_agent." + ) self.connected = True logger.info("Finished setting up %s", self.jid) diff --git a/src/control_backend/agents/transcription/__init__.py b/src/control_backend/agents/transcription/__init__.py new file mode 100644 index 0000000..fd3c8c5 --- /dev/null +++ b/src/control_backend/agents/transcription/__init__.py @@ -0,0 +1,2 @@ +from .speech_recognizer import SpeechRecognizer as SpeechRecognizer +from .transcription_agent import TranscriptionAgent as TranscriptionAgent diff --git a/src/control_backend/agents/transcription/speech_recognizer.py b/src/control_backend/agents/transcription/speech_recognizer.py new file mode 100644 index 0000000..19d82ff --- /dev/null +++ b/src/control_backend/agents/transcription/speech_recognizer.py @@ -0,0 +1,108 @@ +import abc +import sys + +if sys.platform == "darwin": + import mlx.core as mx + import mlx_whisper + from mlx_whisper.transcribe import ModelHolder + +import numpy as np +import torch +import whisper + + +class SpeechRecognizer(abc.ABC): + def __init__(self, limit_output_length=True): + """ + :param limit_output_length: When `True`, the length of the generated speech will be limited + by the length of the input audio and some heuristics. 
+ """ + self.limit_output_length = limit_output_length + + @abc.abstractmethod + def load_model(self): ... + + @abc.abstractmethod + def recognize_speech(self, audio: np.ndarray) -> str: + """ + Recognize speech from the given audio sample. + + :param audio: A full utterance sample. Audio must be 16 kHz, mono, np.float32, values in the + range [-1.0, 1.0]. + :return: Recognized speech. + """ + + @staticmethod + def _estimate_max_tokens(audio: np.ndarray) -> int: + """ + Estimate the maximum length of a given audio sample in tokens. Assumes a maximum speaking + rate of 300 words per minute (2x average), and assumes that 3 words is 4 tokens. + + :param audio: The audio sample (16 kHz) to use for length estimation. + :return: The estimated length of the transcribed audio in tokens. + """ + length_seconds = len(audio) / 16_000 + length_minutes = length_seconds / 60 + word_count = length_minutes * 300 + token_count = word_count / 3 * 4 + return int(token_count) + + def _get_decode_options(self, audio: np.ndarray) -> dict: + """ + :param audio: The audio sample (16 kHz) to use to determine options like max decode length. + :return: A dict that can be used to construct `whisper.DecodingOptions`. 
+ """ + options = {} + if self.limit_output_length: + options["sample_len"] = self._estimate_max_tokens(audio) + return options + + @staticmethod + def best_type(): + """Get the best type of SpeechRecognizer based on system capabilities.""" + if torch.mps.is_available(): + print("Choosing MLX Whisper model.") + return MLXWhisperSpeechRecognizer() + else: + print("Choosing reference Whisper model.") + return OpenAIWhisperSpeechRecognizer() + + +class MLXWhisperSpeechRecognizer(SpeechRecognizer): + def __init__(self, limit_output_length=True): + super().__init__(limit_output_length) + self.was_loaded = False + self.model_name = "mlx-community/whisper-small.en-mlx" + + def load_model(self): + if self.was_loaded: + return + # There appears to be no dedicated mechanism to preload a model, but this `get_model` does + # store it in memory for later usage + ModelHolder.get_model(self.model_name, mx.float16) + self.was_loaded = True + + def recognize_speech(self, audio: np.ndarray) -> str: + self.load_model() + return mlx_whisper.transcribe( + audio, path_or_hf_repo=self.model_name, decode_options=self._get_decode_options(audio) + )["text"] + return mlx_whisper.transcribe(audio, path_or_hf_repo=self.model_name)["text"].strip() + + +class OpenAIWhisperSpeechRecognizer(SpeechRecognizer): + def __init__(self, limit_output_length=True): + super().__init__(limit_output_length) + self.model = None + + def load_model(self): + if self.model is not None: + return + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + self.model = whisper.load_model("small.en", device=device) + + def recognize_speech(self, audio: np.ndarray) -> str: + self.load_model() + return whisper.transcribe( + self.model, audio, decode_options=self._get_decode_options(audio) + )["text"] diff --git a/src/control_backend/agents/transcription/transcription_agent.py b/src/control_backend/agents/transcription/transcription_agent.py new file mode 100644 index 0000000..530bd68 --- /dev/null +++ 
b/src/control_backend/agents/transcription/transcription_agent.py @@ -0,0 +1,84 @@ +import asyncio +import logging + +import numpy as np +import zmq +import zmq.asyncio as azmq +from spade.agent import Agent +from spade.behaviour import CyclicBehaviour +from spade.message import Message + +from control_backend.agents.transcription.speech_recognizer import SpeechRecognizer +from control_backend.core.config import settings + +logger = logging.getLogger(__name__) + + +class TranscriptionAgent(Agent): + """ + An agent which listens to audio fragments with voice, transcribes them, and sends the + transcription to other agents. + """ + + def __init__(self, audio_in_address: str): + jid = settings.agent_settings.transcription_agent_name + "@" + settings.agent_settings.host + super().__init__(jid, settings.agent_settings.transcription_agent_name) + + self.audio_in_address = audio_in_address + self.audio_in_socket: azmq.Socket | None = None + + class Transcribing(CyclicBehaviour): + def __init__(self, audio_in_socket: azmq.Socket): + super().__init__() + self.audio_in_socket = audio_in_socket + self.speech_recognizer = SpeechRecognizer.best_type() + self._concurrency = asyncio.Semaphore(3) + + def warmup(self): + """Load the transcription model into memory to speed up the first transcription.""" + self.speech_recognizer.load_model() + + async def _transcribe(self, audio: np.ndarray) -> str: + async with self._concurrency: + return await asyncio.to_thread(self.speech_recognizer.recognize_speech, audio) + + async def _share_transcription(self, transcription: str): + """Share a transcription to the other agents that depend on it.""" + receiver_jids = [ + settings.agent_settings.text_belief_extractor_agent_name + + "@" + + settings.agent_settings.host, + ] # Set message receivers here + + for receiver_jid in receiver_jids: + message = Message(to=receiver_jid, body=transcription) + await self.send(message) + + async def run(self) -> None: + audio = await 
self.audio_in_socket.recv() + audio = np.frombuffer(audio, dtype=np.float32) + speech = await self._transcribe(audio) + logger.info("Transcribed speech: %s", speech) + + await self._share_transcription(speech) + + async def stop(self): + self.audio_in_socket.close() + self.audio_in_socket = None + return await super().stop() + + def _connect_audio_in_socket(self): + self.audio_in_socket = azmq.Context.instance().socket(zmq.SUB) + self.audio_in_socket.setsockopt_string(zmq.SUBSCRIBE, "") + self.audio_in_socket.connect(self.audio_in_address) + + async def setup(self): + logger.info("Setting up %s", self.jid) + + self._connect_audio_in_socket() + + transcribing = self.Transcribing(self.audio_in_socket) + transcribing.warmup() + self.add_behaviour(transcribing) + + logger.info("Finished setting up %s", self.jid) diff --git a/src/control_backend/agents/vad_agent.py b/src/control_backend/agents/vad_agent.py new file mode 100644 index 0000000..f16abf4 --- /dev/null +++ b/src/control_backend/agents/vad_agent.py @@ -0,0 +1,159 @@ +import logging + +import numpy as np +import torch +import zmq +import zmq.asyncio as azmq +from spade.agent import Agent +from spade.behaviour import CyclicBehaviour + +from control_backend.agents.transcription import TranscriptionAgent +from control_backend.core.config import settings + +logger = logging.getLogger(__name__) + + +class SocketPoller[T]: + """ + Convenience class for polling a socket for data with a timeout, persisting a zmq.Poller for + multiple usages. + """ + + def __init__(self, socket: azmq.Socket, timeout_ms: int = 100): + """ + :param socket: The socket to poll and get data from. + :param timeout_ms: A timeout in milliseconds to wait for data. + """ + self.socket = socket + self.poller = zmq.Poller() + self.poller.register(self.socket, zmq.POLLIN) + self.timeout_ms = timeout_ms + + async def poll(self, timeout_ms: int | None = None) -> T | None: + """ + Get data from the socket, or None if the timeout is reached. 
+ + :param timeout_ms: If given, the timeout. Otherwise, `self.timeout_ms` is used. + :return: Data from the socket or None. + """ + timeout_ms = timeout_ms or self.timeout_ms + socks = dict(self.poller.poll(timeout_ms)) + if socks.get(self.socket) == zmq.POLLIN: + return await self.socket.recv() + return None + + +class Streaming(CyclicBehaviour): + def __init__(self, audio_in_socket: azmq.Socket, audio_out_socket: azmq.Socket): + super().__init__() + self.audio_in_poller = SocketPoller[bytes](audio_in_socket) + self.model, _ = torch.hub.load( + repo_or_dir="snakers4/silero-vad", model="silero_vad", force_reload=False + ) + self.audio_out_socket = audio_out_socket + + self.audio_buffer = np.array([], dtype=np.float32) + self.i_since_speech = 100 # Used to allow small pauses in speech + + async def run(self) -> None: + data = await self.audio_in_poller.poll() + if data is None: + if len(self.audio_buffer) > 0: + logger.debug("No audio data received. Discarding buffer until new data arrives.") + self.audio_buffer = np.array([], dtype=np.float32) + self.i_since_speech = 100 + return + + # copy otherwise Torch will be sad that it's immutable + chunk = np.frombuffer(data, dtype=np.float32).copy() + prob = self.model(torch.from_numpy(chunk), 16000).item() + + if prob > 0.5: + if self.i_since_speech > 3: + logger.debug("Speech started.") + self.audio_buffer = np.append(self.audio_buffer, chunk) + self.i_since_speech = 0 + return + self.i_since_speech += 1 + + # prob < 0.5, so speech maybe ended. Wait a bit more before to be more certain + if self.i_since_speech <= 3: + self.audio_buffer = np.append(self.audio_buffer, chunk) + return + + # Speech probably ended. Make sure we have a usable amount of data. + if len(self.audio_buffer) >= 3 * len(chunk): + logger.debug("Speech ended.") + await self.audio_out_socket.send(self.audio_buffer[: -2 * len(chunk)].tobytes()) + + # At this point, we know that the speech has ended. 
+ # Prepend the last chunk that had no speech, for a more fluent boundary + self.audio_buffer = chunk + + +class VADAgent(Agent): + """ + An agent which listens to an audio stream, does Voice Activity Detection (VAD), and sends + fragments with detected speech to other agents over ZeroMQ. + """ + + def __init__(self, audio_in_address: str, audio_in_bind: bool): + jid = settings.agent_settings.vad_agent_name + "@" + settings.agent_settings.host + super().__init__(jid, settings.agent_settings.vad_agent_name) + + self.audio_in_address = audio_in_address + self.audio_in_bind = audio_in_bind + + self.audio_in_socket: azmq.Socket | None = None + self.audio_out_socket: azmq.Socket | None = None + + async def stop(self): + """ + Stop listening to audio, stop publishing audio, close sockets. + """ + if self.audio_in_socket is not None: + self.audio_in_socket.close() + self.audio_in_socket = None + if self.audio_out_socket is not None: + self.audio_out_socket.close() + self.audio_out_socket = None + return await super().stop() + + def _connect_audio_in_socket(self): + self.audio_in_socket = azmq.Context.instance().socket(zmq.SUB) + self.audio_in_socket.setsockopt_string(zmq.SUBSCRIBE, "") + if self.audio_in_bind: + self.audio_in_socket.bind(self.audio_in_address) + else: + self.audio_in_socket.connect(self.audio_in_address) + self.audio_in_poller = SocketPoller[bytes](self.audio_in_socket) + + def _connect_audio_out_socket(self) -> int | None: + """Returns the port bound, or None if binding failed.""" + try: + self.audio_out_socket = azmq.Context.instance().socket(zmq.PUB) + return self.audio_out_socket.bind_to_random_port("tcp://*", max_tries=100) + except zmq.ZMQBindError: + logger.error("Failed to bind an audio output socket after 100 tries.") + self.audio_out_socket = None + return None + + async def setup(self): + logger.info("Setting up %s", self.jid) + + self._connect_audio_in_socket() + + audio_out_port = self._connect_audio_out_socket() + if audio_out_port is None: 
+ await self.stop() + return + audio_out_address = f"tcp://localhost:{audio_out_port}" + + streaming = Streaming(self.audio_in_socket, self.audio_out_socket) + self.add_behaviour(streaming) + + # Start agents dependent on the output audio fragments here + transcriber = TranscriptionAgent(audio_out_address) + await transcriber.start() + + logger.info("Finished setting up %s", self.jid) diff --git a/src/control_backend/api/v1/endpoints/message.py b/src/control_backend/api/v1/endpoints/message.py index 1053c3c..bd88a0b 100644 --- a/src/control_backend/api/v1/endpoints/message.py +++ b/src/control_backend/api/v1/endpoints/message.py @@ -1,7 +1,6 @@ import logging from fastapi import APIRouter, Request -from zmq import Socket from control_backend.schemas.message import Message @@ -17,8 +16,7 @@ async def receive_message(message: Message, request: Request): topic = b"message" body = message.model_dump_json().encode("utf-8") - pub_socket: Socket = request.app.state.internal_comm_socket - - pub_socket.send_multipart([topic, body]) + pub_socket = request.app.state.endpoints_pub_socket + await pub_socket.send_multipart([topic, body]) return {"status": "Message received"} diff --git a/src/control_backend/api/v1/endpoints/robot.py b/src/control_backend/api/v1/endpoints/robot.py index e114757..7b1c2f8 100644 --- a/src/control_backend/api/v1/endpoints/robot.py +++ b/src/control_backend/api/v1/endpoints/robot.py @@ -5,10 +5,9 @@ import logging import zmq.asyncio from fastapi import APIRouter, Request from fastapi.responses import StreamingResponse -from zmq.asyncio import Socket +from zmq.asyncio import Context, Socket from control_backend.core.config import settings -from control_backend.core.zmq_context import context from control_backend.schemas.ri_message import SpeechCommand logger = logging.getLogger(__name__) @@ -21,6 +20,8 @@ async def receive_command(command: SpeechCommand, request: Request): # Validate and retrieve data. 
SpeechCommand.model_validate(command) topic = b"command" + + # TODO: Check with Kasper pub_socket: Socket = request.app.state.internal_comm_socket pub_socket.send_multipart([topic, command.model_dump_json().encode()]) @@ -40,8 +41,9 @@ async def ping_stream(request: Request): # Set up internal socket to receive ping updates logger.debug("Ping stream router event stream entered.") - sub_socket = context.socket(zmq.SUB) - sub_socket.connect(settings.zmq_settings.internal_comm_address) + # TODO: Check with Kasper + sub_socket = Context.instance().socket(zmq.SUB) + sub_socket.connect(settings.zmq_settings.internal_sub_address) sub_socket.setsockopt(zmq.SUBSCRIBE, b"ping") connected = False diff --git a/src/control_backend/core/config.py b/src/control_backend/core/config.py index f48d54f..8de2403 100644 --- a/src/control_backend/core/config.py +++ b/src/control_backend/core/config.py @@ -3,19 +3,29 @@ from pydantic_settings import BaseSettings, SettingsConfigDict class ZMQSettings(BaseModel): - internal_comm_address: str = "tcp://localhost:5560" + internal_pub_address: str = "tcp://localhost:5560" + internal_sub_address: str = "tcp://localhost:5561" class AgentSettings(BaseModel): host: str = "localhost" bdi_core_agent_name: str = "bdi_core" belief_collector_agent_name: str = "belief_collector" + text_belief_extractor_agent_name: str = "text_belief_extractor" + vad_agent_name: str = "vad_agent" + llm_agent_name: str = "llm_agent" test_agent_name: str = "test_agent" + transcription_agent_name: str = "transcription_agent" ri_communication_agent_name: str = "ri_communication_agent" ri_command_agent_name: str = "ri_command_agent" +class LLMSettings(BaseModel): + local_llm_url: str = "http://localhost:1234/v1/chat/completions" + local_llm_model: str = "openai/gpt-oss-20b" + + class Settings(BaseSettings): app_title: str = "PepperPlus" @@ -25,6 +35,8 @@ class Settings(BaseSettings): agent_settings: AgentSettings = AgentSettings() + llm_settings: LLMSettings = LLMSettings() + 
model_config = SettingsConfigDict(env_file=".env") diff --git a/src/control_backend/core/zmq_context.py b/src/control_backend/core/zmq_context.py deleted file mode 100644 index a74544f..0000000 --- a/src/control_backend/core/zmq_context.py +++ /dev/null @@ -1,3 +0,0 @@ -from zmq.asyncio import Context - -context = Context() diff --git a/src/control_backend/main.py b/src/control_backend/main.py index a824ab1..b8e3ef3 100644 --- a/src/control_backend/main.py +++ b/src/control_backend/main.py @@ -3,33 +3,61 @@ # External imports import contextlib import logging +import threading import zmq from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware +from zmq.asyncio import Context from control_backend.agents.bdi.bdi_core import BDICoreAgent +from control_backend.agents.bdi.text_extractor import TBeliefExtractor +from control_backend.agents.belief_collector.belief_collector import BeliefCollectorAgent +from control_backend.agents.llm.llm import LLMAgent # Internal imports from control_backend.agents.ri_communication_agent import RICommunicationAgent +from control_backend.agents.vad_agent import VADAgent from control_backend.api.v1.router import api_router from control_backend.core.config import settings -from control_backend.core.zmq_context import context logger = logging.getLogger(__name__) logging.basicConfig(level=logging.DEBUG) +def setup_sockets(): + context = Context.instance() + + internal_pub_socket = context.socket(zmq.XPUB) + internal_pub_socket.bind(settings.zmq_settings.internal_sub_address) + logger.debug("Internal publishing socket bound to %s", internal_pub_socket) + + internal_sub_socket = context.socket(zmq.XSUB) + internal_sub_socket.bind(settings.zmq_settings.internal_pub_address) + logger.debug("Internal subscribing socket bound to %s", internal_sub_socket) + try: + zmq.proxy(internal_sub_socket, internal_pub_socket) + except zmq.ZMQError: + logger.warning("Error while handling PUB/SUB proxy. 
Closing sockets.") + finally: + internal_pub_socket.close() + internal_sub_socket.close() + + @contextlib.asynccontextmanager async def lifespan(app: FastAPI): logger.info("%s starting up.", app.title) # Initiate sockets - internal_comm_socket = context.socket(zmq.PUB) - internal_comm_address = settings.zmq_settings.internal_comm_address - internal_comm_socket.bind(internal_comm_address) - app.state.internal_comm_socket = internal_comm_socket - logger.info("Internal publishing socket bound to %s", internal_comm_socket) + proxy_thread = threading.Thread(target=setup_sockets) + proxy_thread.daemon = True + proxy_thread.start() + + context = Context.instance() + + endpoints_pub_socket = context.socket(zmq.PUB) + endpoints_pub_socket.connect(settings.zmq_settings.internal_pub_address) + app.state.endpoints_pub_socket = endpoints_pub_socket # Initiate agents ri_communication_agent = RICommunicationAgent( @@ -37,12 +65,17 @@ async def lifespan(app: FastAPI): + "@" + settings.agent_settings.host, password=settings.agent_settings.ri_communication_agent_name, - pub_socket=internal_comm_socket, address="tcp://*:5555", bind=True, ) await ri_communication_agent.start() + llm_agent = LLMAgent( + settings.agent_settings.llm_agent_name + "@" + settings.agent_settings.host, + settings.agent_settings.llm_agent_name, + ) + await llm_agent.start() + bdi_core = BDICoreAgent( settings.agent_settings.bdi_core_agent_name + "@" + settings.agent_settings.host, settings.agent_settings.bdi_core_agent_name, @@ -50,6 +83,23 @@ async def lifespan(app: FastAPI): ) await bdi_core.start() + belief_collector = BeliefCollectorAgent( + settings.agent_settings.belief_collector_agent_name + "@" + settings.agent_settings.host, + settings.agent_settings.belief_collector_agent_name, + ) + await belief_collector.start() + + text_belief_extractor = TBeliefExtractor( + settings.agent_settings.text_belief_extractor_agent_name + + "@" + + settings.agent_settings.host, + 
settings.agent_settings.text_belief_extractor_agent_name, + ) + await text_belief_extractor.start() + + _temp_vad_agent = VADAgent("tcp://localhost:5558", False) + await _temp_vad_agent.start() + yield logger.info("%s shutting down.", app.title) diff --git a/test/integration/agents/test_ri_commands_agent.py b/test/integration/agents/test_ri_commands_agent.py index 4249401..a4902b5 100644 --- a/test/integration/agents/test_ri_commands_agent.py +++ b/test/integration/agents/test_ri_commands_agent.py @@ -11,25 +11,27 @@ from control_backend.agents.ri_command_agent import RICommandAgent async def test_setup_bind(monkeypatch): """Test setup with bind=True""" fake_socket = MagicMock() + fake_context = MagicMock() + fake_context.socket.return_value = fake_socket + + # Patch Context.instance() to return fake_context monkeypatch.setattr( - "control_backend.agents.ri_command_agent.context.socket", lambda _: fake_socket + "control_backend.agents.ri_command_agent.Context", + MagicMock(instance=MagicMock(return_value=fake_context)), ) agent = RICommandAgent("test@server", "password", address="tcp://localhost:5555", bind=True) + monkeypatch.setattr( "control_backend.agents.ri_command_agent.settings", - MagicMock(zmq_settings=MagicMock(internal_comm_address="tcp://internal:1234")), + MagicMock(zmq_settings=MagicMock(internal_sub_address="tcp://internal:1234")), ) await agent.setup() - # Ensure PUB socket bound fake_socket.bind.assert_any_call("tcp://localhost:5555") - # Ensure SUB socket connected to internal address and subscribed fake_socket.connect.assert_any_call("tcp://internal:1234") fake_socket.setsockopt.assert_any_call(zmq.SUBSCRIBE, b"command") - - # Ensure behaviour attached assert any(isinstance(b, agent.SendCommandsBehaviour) for b in agent.behaviours) @@ -37,19 +39,23 @@ async def test_setup_bind(monkeypatch): async def test_setup_connect(monkeypatch): """Test setup with bind=False""" fake_socket = MagicMock() + fake_context = MagicMock() + 
fake_context.socket.return_value = fake_socket + + # Patch Context.instance() to return fake_context monkeypatch.setattr( - "control_backend.agents.ri_command_agent.context.socket", lambda _: fake_socket + "control_backend.agents.ri_command_agent.Context", + MagicMock(instance=MagicMock(return_value=fake_context)), ) agent = RICommandAgent("test@server", "password", address="tcp://localhost:5555", bind=False) monkeypatch.setattr( "control_backend.agents.ri_command_agent.settings", - MagicMock(zmq_settings=MagicMock(internal_comm_address="tcp://internal:1234")), + MagicMock(zmq_settings=MagicMock(internal_sub_address="tcp://internal:1234")), ) await agent.setup() - # Ensure PUB socket connected fake_socket.connect.assert_any_call("tcp://localhost:5555") diff --git a/test/integration/agents/test_ri_communication_agent.py b/test/integration/agents/test_ri_communication_agent.py index baeb717..9febf20 100644 --- a/test/integration/agents/test_ri_communication_agent.py +++ b/test/integration/agents/test_ri_communication_agent.py @@ -93,12 +93,14 @@ async def test_setup_creates_socket_and_negotiate_1(monkeypatch): fake_socket = MagicMock() fake_socket.send_json = AsyncMock() fake_socket.recv_json = fake_json_correct_negototiate_1() - - fake_pub_socket = AsyncMock() + fake_socket.send_multipart = AsyncMock() # Mock context.socket to return our fake socket + fake_context = MagicMock() + fake_context.socket.return_value = fake_socket monkeypatch.setattr( - "control_backend.agents.ri_communication_agent.context.socket", lambda _: fake_socket + "control_backend.agents.ri_communication_agent.Context", + MagicMock(instance=MagicMock(return_value=fake_context)), ) # Mock RICommandAgent agent startup @@ -107,13 +109,11 @@ async def test_setup_creates_socket_and_negotiate_1(monkeypatch): ) as MockCommandAgent: fake_agent_instance = MockCommandAgent.return_value fake_agent_instance.start = AsyncMock() - fake_pub_socket = AsyncMock() # --- Act --- agent = RICommunicationAgent( 
"test@server", "password", - pub_socket=fake_pub_socket, address="tcp://localhost:5555", bind=False, ) @@ -143,10 +143,14 @@ async def test_setup_creates_socket_and_negotiate_2(monkeypatch): fake_socket = MagicMock() fake_socket.send_json = AsyncMock() fake_socket.recv_json = fake_json_correct_negototiate_2() + fake_socket.send_multipart = AsyncMock() # Mock context.socket to return our fake socket + fake_context = MagicMock() + fake_context.socket.return_value = fake_socket monkeypatch.setattr( - "control_backend.agents.ri_communication_agent.context.socket", lambda _: fake_socket + "control_backend.agents.ri_communication_agent.Context", + MagicMock(instance=MagicMock(return_value=fake_context)), ) # Mock RICommandAgent agent startup @@ -155,13 +159,11 @@ async def test_setup_creates_socket_and_negotiate_2(monkeypatch): ) as MockCommandAgent: fake_agent_instance = MockCommandAgent.return_value fake_agent_instance.start = AsyncMock() - fake_pub_socket = AsyncMock() # --- Act --- agent = RICommunicationAgent( "test@server", "password", - pub_socket=fake_pub_socket, address="tcp://localhost:5555", bind=False, ) @@ -191,10 +193,14 @@ async def test_setup_creates_socket_and_negotiate_3(monkeypatch, caplog): fake_socket = MagicMock() fake_socket.send_json = AsyncMock() fake_socket.recv_json = fake_json_wrong_negototiate_1() + fake_socket.send_multipart = AsyncMock() # Mock context.socket to return our fake socket + fake_context = MagicMock() + fake_context.socket.return_value = fake_socket monkeypatch.setattr( - "control_backend.agents.ri_communication_agent.context.socket", lambda _: fake_socket + "control_backend.agents.ri_communication_agent.Context", + MagicMock(instance=MagicMock(return_value=fake_context)), ) # Mock RICommandAgent agent startup @@ -206,13 +212,11 @@ async def test_setup_creates_socket_and_negotiate_3(monkeypatch, caplog): ) as MockCommandAgent: fake_agent_instance = MockCommandAgent.return_value fake_agent_instance.start = AsyncMock() - 
fake_pub_socket = AsyncMock() # --- Act --- with caplog.at_level("ERROR"): agent = RICommunicationAgent( "test@server", "password", - pub_socket=fake_pub_socket, address="tcp://localhost:5555", bind=False, ) @@ -239,10 +243,14 @@ async def test_setup_creates_socket_and_negotiate_4(monkeypatch): fake_socket = MagicMock() fake_socket.send_json = AsyncMock() fake_socket.recv_json = fake_json_correct_negototiate_3() + fake_socket.send_multipart = AsyncMock() # Mock context.socket to return our fake socket + fake_context = MagicMock() + fake_context.socket.return_value = fake_socket monkeypatch.setattr( - "control_backend.agents.ri_communication_agent.context.socket", lambda _: fake_socket + "control_backend.agents.ri_communication_agent.Context", + MagicMock(instance=MagicMock(return_value=fake_context)), ) # Mock RICommandAgent agent startup @@ -251,12 +259,10 @@ async def test_setup_creates_socket_and_negotiate_4(monkeypatch): ) as MockCommandAgent: fake_agent_instance = MockCommandAgent.return_value fake_agent_instance.start = AsyncMock() - fake_pub_socket = AsyncMock() # --- Act --- agent = RICommunicationAgent( "test@server", "password", - pub_socket=fake_pub_socket, address="tcp://localhost:5555", bind=True, ) @@ -286,10 +292,14 @@ async def test_setup_creates_socket_and_negotiate_5(monkeypatch): fake_socket = MagicMock() fake_socket.send_json = AsyncMock() fake_socket.recv_json = fake_json_correct_negototiate_4() + fake_socket.send_multipart = AsyncMock() # Mock context.socket to return our fake socket + fake_context = MagicMock() + fake_context.socket.return_value = fake_socket monkeypatch.setattr( - "control_backend.agents.ri_communication_agent.context.socket", lambda _: fake_socket + "control_backend.agents.ri_communication_agent.Context", + MagicMock(instance=MagicMock(return_value=fake_context)), ) # Mock RICommandAgent agent startup @@ -298,12 +308,10 @@ async def test_setup_creates_socket_and_negotiate_5(monkeypatch): ) as MockCommandAgent: 
fake_agent_instance = MockCommandAgent.return_value fake_agent_instance.start = AsyncMock() - fake_pub_socket = AsyncMock() # --- Act --- agent = RICommunicationAgent( "test@server", "password", - pub_socket=fake_pub_socket, address="tcp://localhost:5555", bind=False, ) @@ -333,10 +341,14 @@ async def test_setup_creates_socket_and_negotiate_6(monkeypatch): fake_socket = MagicMock() fake_socket.send_json = AsyncMock() fake_socket.recv_json = fake_json_correct_negototiate_5() + fake_socket.send_multipart = AsyncMock() # Mock context.socket to return our fake socket + fake_context = MagicMock() + fake_context.socket.return_value = fake_socket monkeypatch.setattr( - "control_backend.agents.ri_communication_agent.context.socket", lambda _: fake_socket + "control_backend.agents.ri_communication_agent.Context", + MagicMock(instance=MagicMock(return_value=fake_context)), ) # Mock RICommandAgent agent startup @@ -345,12 +357,10 @@ async def test_setup_creates_socket_and_negotiate_6(monkeypatch): ) as MockCommandAgent: fake_agent_instance = MockCommandAgent.return_value fake_agent_instance.start = AsyncMock() - fake_pub_socket = AsyncMock() # --- Act --- agent = RICommunicationAgent( "test@server", "password", - pub_socket=fake_pub_socket, address="tcp://localhost:5555", bind=False, ) @@ -380,10 +390,14 @@ async def test_setup_creates_socket_and_negotiate_7(monkeypatch, caplog): fake_socket = MagicMock() fake_socket.send_json = AsyncMock() fake_socket.recv_json = fake_json_invalid_id_negototiate() + fake_socket.send_multipart = AsyncMock() # Mock context.socket to return our fake socket + fake_context = MagicMock() + fake_context.socket.return_value = fake_socket monkeypatch.setattr( - "control_backend.agents.ri_communication_agent.context.socket", lambda _: fake_socket + "control_backend.agents.ri_communication_agent.Context", + MagicMock(instance=MagicMock(return_value=fake_context)), ) # Mock RICommandAgent agent startup @@ -395,14 +409,12 @@ async def 
test_setup_creates_socket_and_negotiate_7(monkeypatch, caplog): ) as MockCommandAgent: fake_agent_instance = MockCommandAgent.return_value fake_agent_instance.start = AsyncMock() - fake_pub_socket = AsyncMock() # --- Act --- with caplog.at_level("WARNING"): agent = RICommunicationAgent( "test@server", "password", - pub_socket=fake_pub_socket, address="tcp://localhost:5555", bind=False, ) @@ -426,10 +438,14 @@ async def test_setup_creates_socket_and_negotiate_timeout(monkeypatch, caplog): fake_socket = MagicMock() fake_socket.send_json = AsyncMock() fake_socket.recv_json = AsyncMock(side_effect=asyncio.TimeoutError) + fake_socket.send_multipart = AsyncMock() # Mock context.socket to return our fake socket + fake_context = MagicMock() + fake_context.socket.return_value = fake_socket monkeypatch.setattr( - "control_backend.agents.ri_communication_agent.context.socket", lambda _: fake_socket + "control_backend.agents.ri_communication_agent.Context", + MagicMock(instance=MagicMock(return_value=fake_context)), ) with patch( @@ -437,14 +453,12 @@ async def test_setup_creates_socket_and_negotiate_timeout(monkeypatch, caplog): ) as MockCommandAgent: fake_agent_instance = MockCommandAgent.return_value fake_agent_instance.start = AsyncMock() - fake_pub_socket = AsyncMock() # --- Act --- with caplog.at_level("WARNING"): agent = RICommunicationAgent( "test@server", "password", - pub_socket=fake_pub_socket, address="tcp://localhost:5555", bind=False, ) @@ -466,11 +480,11 @@ async def test_listen_behaviour_ping_correct(caplog): fake_socket = AsyncMock() fake_socket.send_json = AsyncMock() fake_socket.recv_json = AsyncMock(return_value={"endpoint": "ping", "data": {}}) - fake_pub_socket = AsyncMock() + fake_socket.send_multipart = AsyncMock() # TODO: Integration test between actual server and password needed for spade agents - agent = RICommunicationAgent("test@server", "password", fake_pub_socket) - agent.req_socket = fake_socket + agent = RICommunicationAgent("test@server", 
"password") + agent._req_socket = fake_socket behaviour = agent.ListenBehaviour() agent.add_behaviour(behaviour) @@ -505,7 +519,7 @@ async def test_listen_behaviour_ping_wrong_endpoint(caplog): fake_pub_socket = AsyncMock() agent = RICommunicationAgent("test@server", "password", fake_pub_socket) - agent.req_socket = fake_socket + agent._req_socket = fake_socket behaviour = agent.ListenBehaviour() agent.add_behaviour(behaviour) @@ -525,10 +539,10 @@ async def test_listen_behaviour_timeout(caplog): fake_socket.send_json = AsyncMock() # recv_json will never resolve, simulate timeout fake_socket.recv_json = AsyncMock(side_effect=asyncio.TimeoutError) - fake_pub_socket = AsyncMock() + fake_socket.send_multipart = AsyncMock() - agent = RICommunicationAgent("test@server", "password", fake_pub_socket) - agent.req_socket = fake_socket + agent = RICommunicationAgent("test@server", "password") + agent._req_socket = fake_socket behaviour = agent.ListenBehaviour() agent.add_behaviour(behaviour) @@ -546,6 +560,7 @@ async def test_listen_behaviour_ping_no_endpoint(caplog): """ fake_socket = AsyncMock() fake_socket.send_json = AsyncMock() + fake_socket.send_multipart = AsyncMock() # This is a message without endpoint >:( fake_socket.recv_json = AsyncMock( @@ -553,10 +568,9 @@ async def test_listen_behaviour_ping_no_endpoint(caplog): "data": "I dont have an endpoint >:)", } ) - fake_pub_socket = AsyncMock() - agent = RICommunicationAgent("test@server", "password", fake_pub_socket) - agent.req_socket = fake_socket + agent = RICommunicationAgent("test@server", "password") + agent._req_socket = fake_socket behaviour = agent.ListenBehaviour() agent.add_behaviour(behaviour) @@ -574,18 +588,20 @@ async def test_listen_behaviour_ping_no_endpoint(caplog): async def test_setup_unexpected_exception(monkeypatch, caplog): fake_socket = MagicMock() fake_socket.send_json = AsyncMock() - fake_pub_socket = AsyncMock() # Simulate unexpected exception during recv_json() fake_socket.recv_json = 
AsyncMock(side_effect=Exception("boom!")) + fake_socket.send_multipart = AsyncMock() + fake_context = MagicMock() + fake_context.socket.return_value = fake_socket monkeypatch.setattr( - "control_backend.agents.ri_communication_agent.context.socket", lambda _: fake_socket + "control_backend.agents.ri_communication_agent.Context", + MagicMock(instance=MagicMock(return_value=fake_context)), ) agent = RICommunicationAgent( "test@server", "password", - pub_socket=fake_pub_socket, address="tcp://localhost:5555", bind=False, ) @@ -602,6 +618,7 @@ async def test_setup_unpacking_exception(monkeypatch, caplog): # --- Arrange --- fake_socket = MagicMock() fake_socket.send_json = AsyncMock() + fake_socket.send_multipart = AsyncMock() # Make recv_json return malformed negotiation data to trigger unpacking exception malformed_data = { @@ -611,8 +628,11 @@ async def test_setup_unpacking_exception(monkeypatch, caplog): fake_socket.recv_json = AsyncMock(return_value=malformed_data) # Patch context.socket + fake_context = MagicMock() + fake_context.socket.return_value = fake_socket monkeypatch.setattr( - "control_backend.agents.ri_communication_agent.context.socket", lambda _: fake_socket + "control_backend.agents.ri_communication_agent.Context", + MagicMock(instance=MagicMock(return_value=fake_context)), ) # Patch RICommandAgent so it won't actually start @@ -621,12 +641,10 @@ async def test_setup_unpacking_exception(monkeypatch, caplog): ) as MockCommandAgent: fake_agent_instance = MockCommandAgent.return_value fake_agent_instance.start = AsyncMock() - fake_pub_socket = AsyncMock() agent = RICommunicationAgent( "test@server", "password", - pub_socket=fake_pub_socket, address="tcp://localhost:5555", bind=False, ) diff --git a/test/integration/agents/vad_agent/speech_with_pauses_16k_1c_float32.wav b/test/integration/agents/vad_agent/speech_with_pauses_16k_1c_float32.wav new file mode 100644 index 0000000..530bc0a Binary files /dev/null and 
b/test/integration/agents/vad_agent/speech_with_pauses_16k_1c_float32.wav differ diff --git a/test/integration/agents/vad_agent/test_vad_agent.py b/test/integration/agents/vad_agent/test_vad_agent.py new file mode 100644 index 0000000..54c9d82 --- /dev/null +++ b/test/integration/agents/vad_agent/test_vad_agent.py @@ -0,0 +1,108 @@ +import random +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +import zmq +from spade.agent import Agent + +from control_backend.agents.vad_agent import VADAgent + + +@pytest.fixture +def zmq_context(mocker): + return mocker.patch("control_backend.agents.vad_agent.zmq_context") + + +@pytest.fixture +def streaming(mocker): + return mocker.patch("control_backend.agents.vad_agent.Streaming") + + +@pytest.fixture +def transcription_agent(mocker): + return mocker.patch("control_backend.agents.vad_agent.TranscriptionAgent", autospec=True) + + +@pytest.mark.asyncio +async def test_normal_setup(streaming, transcription_agent): + """ + Test that during normal setup, the VAD agent creates a Streaming behavior and creates audio + sockets, and starts the TranscriptionAgent without loading real models. + """ + vad_agent = VADAgent("tcp://localhost:12345", False) + vad_agent.add_behaviour = MagicMock() + + await vad_agent.setup() + + streaming.assert_called_once() + vad_agent.add_behaviour.assert_called_once_with(streaming.return_value) + transcription_agent.assert_called_once() + transcription_agent.return_value.start.assert_called_once() + assert vad_agent.audio_in_socket is not None + assert vad_agent.audio_out_socket is not None + + +@pytest.mark.parametrize("do_bind", [True, False]) +def test_in_socket_creation(zmq_context, do_bind: bool): + """ + Test that the VAD agent creates an audio input socket, differentiating between binding and + connecting. 
+ """ + vad_agent = VADAgent(f"tcp://{'*' if do_bind else 'localhost'}:12345", do_bind) + + vad_agent._connect_audio_in_socket() + + assert vad_agent.audio_in_socket is not None + + zmq_context.socket.assert_called_once_with(zmq.SUB) + zmq_context.socket.return_value.setsockopt_string.assert_called_once_with(zmq.SUBSCRIBE, "") + + if do_bind: + zmq_context.socket.return_value.bind.assert_called_once_with("tcp://*:12345") + else: + zmq_context.socket.return_value.connect.assert_called_once_with("tcp://localhost:12345") + + +def test_out_socket_creation(zmq_context): + """ + Test that the VAD agent creates an audio output socket correctly. + """ + vad_agent = VADAgent("tcp://localhost:12345", False) + + vad_agent._connect_audio_out_socket() + + assert vad_agent.audio_out_socket is not None + + zmq_context.socket.assert_called_once_with(zmq.PUB) + zmq_context.socket.return_value.bind_to_random_port.assert_called_once() + + +@pytest.mark.asyncio +async def test_out_socket_creation_failure(zmq_context): + """ + Test setup failure when the audio output socket cannot be created. + """ + with patch.object(Agent, "stop", new_callable=AsyncMock) as mock_super_stop: + zmq_context.socket.return_value.bind_to_random_port.side_effect = zmq.ZMQBindError + vad_agent = VADAgent("tcp://localhost:12345", False) + + await vad_agent.setup() + + assert vad_agent.audio_out_socket is None + mock_super_stop.assert_called_once() + + +@pytest.mark.asyncio +async def test_stop(zmq_context, transcription_agent): + """ + Test that when the VAD agent is stopped, the sockets are closed correctly. 
+ """ + vad_agent = VADAgent("tcp://localhost:12345", False) + zmq_context.socket.return_value.bind_to_random_port.return_value = random.randint(1000, 10000) + + await vad_agent.setup() + await vad_agent.stop() + + assert zmq_context.socket.return_value.close.call_count == 2 + assert vad_agent.audio_in_socket is None + assert vad_agent.audio_out_socket is None diff --git a/test/integration/agents/vad_agent/test_vad_with_audio.py b/test/integration/agents/vad_agent/test_vad_with_audio.py new file mode 100644 index 0000000..7d10aa3 --- /dev/null +++ b/test/integration/agents/vad_agent/test_vad_with_audio.py @@ -0,0 +1,57 @@ +import os +from unittest.mock import AsyncMock, MagicMock + +import pytest +import soundfile as sf +import zmq + +from control_backend.agents.vad_agent import Streaming + + +def get_audio_chunks() -> list[bytes]: + curr_file = os.path.realpath(__file__) + curr_dir = os.path.dirname(curr_file) + file = f"{curr_dir}/speech_with_pauses_16k_1c_float32.wav" + + chunk_size = 512 + + chunks = [] + + with sf.SoundFile(file, "r") as f: + assert f.samplerate == 16000 + assert f.channels == 1 + assert f.subtype == "FLOAT" + + while True: + data = f.read(chunk_size, dtype="float32") + if len(data) != chunk_size: + break + + chunks.append(data.tobytes()) + + return chunks + + +@pytest.mark.asyncio +async def test_real_audio(mocker): + """ + Test the VAD agent with only input and output mocked. Using the real model, using real audio as + input. Ensure that it outputs some fragments with audio. 
+ """ + audio_chunks = get_audio_chunks() + audio_in_socket = AsyncMock() + audio_in_socket.recv.side_effect = audio_chunks + + mock_poller: MagicMock = mocker.patch("control_backend.agents.vad_agent.zmq.Poller") + mock_poller.return_value.poll.return_value = [(audio_in_socket, zmq.POLLIN)] + + audio_out_socket = AsyncMock() + + vad_streamer = Streaming(audio_in_socket, audio_out_socket) + for _ in audio_chunks: + await vad_streamer.run() + + audio_out_socket.send.assert_called() + for args in audio_out_socket.send.call_args_list: + assert isinstance(args[0][0], bytes) + assert len(args[0][0]) >= 512 * 4 * 3 # Should be at least 3 chunks of audio diff --git a/test/unit/agents/bdi/behaviours/test_belief_setter.py b/test/unit/agents/bdi/behaviours/test_belief_setter.py index b8f5570..c7bb0e9 100644 --- a/test/unit/agents/bdi/behaviours/test_belief_setter.py +++ b/test/unit/agents/bdi/behaviours/test_belief_setter.py @@ -4,7 +4,7 @@ from unittest.mock import AsyncMock, MagicMock, call import pytest -from control_backend.agents.bdi.behaviours.belief_setter import BeliefSetter +from control_backend.agents.bdi.behaviours.belief_setter import BeliefSetterBehaviour # Define a constant for the collector agent name to use in tests COLLECTOR_AGENT_NAME = "belief_collector" @@ -22,16 +22,14 @@ def mock_agent(mocker): @pytest.fixture def belief_setter(mock_agent, mocker): - """Fixture to create an instance of BeliefSetter with a mocked agent.""" + """Fixture to create an instance of BeliefSetterBehaviour with a mocked agent.""" # Patch the settings to use a predictable agent name mocker.patch( "control_backend.agents.bdi.behaviours.belief_setter.settings.agent_settings.belief_collector_agent_name", COLLECTOR_AGENT_NAME, ) - # Patch asyncio.sleep to prevent tests from actually waiting - mocker.patch("asyncio.sleep", return_value=None) - setter = BeliefSetter() + setter = BeliefSetterBehaviour() setter.agent = mock_agent # Mock the receive method, we will control its return value 
in each test setter.receive = AsyncMock() @@ -115,7 +113,7 @@ def test_process_belief_message_valid_json(belief_setter, mocker): Test processing a valid belief message with correct thread and JSON body. """ # Arrange - beliefs_payload = {"is_hot": [["kitchen"]], "is_clean": [["kitchen"], ["bathroom"]]} + beliefs_payload = {"is_hot": ["kitchen"], "is_clean": ["kitchen", "bathroom"]} msg = create_mock_message( sender_node=COLLECTOR_AGENT_JID, body=json.dumps(beliefs_payload), thread="beliefs" ) @@ -185,8 +183,8 @@ def test_set_beliefs_success(belief_setter, mock_agent, caplog): """ # Arrange beliefs_to_set = { - "is_hot": [["kitchen"], ["living_room"]], - "door_is": [["front_door", "closed"]], + "is_hot": ["kitchen"], + "door_opened": ["front_door", "back_door"], } # Act @@ -196,29 +194,38 @@ def test_set_beliefs_success(belief_setter, mock_agent, caplog): # Assert expected_calls = [ call("is_hot", "kitchen"), - call("is_hot", "living_room"), - call("door_is", "front_door", "closed"), + call("door_opened", "front_door", "back_door"), ] mock_agent.bdi.set_belief.assert_has_calls(expected_calls, any_order=True) - assert mock_agent.bdi.set_belief.call_count == 3 + assert mock_agent.bdi.set_belief.call_count == 2 # Check logs assert "Set belief is_hot with arguments ['kitchen']" in caplog.text - assert "Set belief is_hot with arguments ['living_room']" in caplog.text - assert "Set belief door_is with arguments ['front_door', 'closed']" in caplog.text + assert "Set belief door_opened with arguments ['front_door', 'back_door']" in caplog.text -def test_set_beliefs_bdi_not_initialized(belief_setter, mock_agent, caplog): - """ - Test that a warning is logged if the agent's BDI is not initialized. 
- """ - # Arrange - mock_agent.bdi = None # Simulate BDI not being ready - beliefs_to_set = {"is_hot": [["kitchen"]]} +# def test_responded_unset(belief_setter, mock_agent): +# # Arrange +# new_beliefs = {"user_said": ["message"]} +# +# # Act +# belief_setter._set_beliefs(new_beliefs) +# +# # Assert +# mock_agent.bdi.set_belief.assert_has_calls([call("user_said", "message")]) +# mock_agent.bdi.remove_belief.assert_has_calls([call("responded")]) - # Act - with caplog.at_level(logging.WARNING): - belief_setter._set_beliefs(beliefs_to_set) - - # Assert - assert "Cannot set beliefs, since agent's BDI is not yet initialized." in caplog.text +# def test_set_beliefs_bdi_not_initialized(belief_setter, mock_agent, caplog): +# """ +# Test that a warning is logged if the agent's BDI is not initialized. +# """ +# # Arrange +# mock_agent.bdi = None # Simulate BDI not being ready +# beliefs_to_set = {"is_hot": ["kitchen"]} +# +# # Act +# with caplog.at_level(logging.WARNING): +# belief_setter._set_beliefs(beliefs_to_set) +# +# # Assert +# assert "Cannot set beliefs, since agent's BDI is not yet initialized." 
in caplog.text diff --git a/test/unit/agents/belief_collector/behaviours/test_continuous_collect.py b/test/unit/agents/belief_collector/behaviours/test_continuous_collect.py new file mode 100644 index 0000000..79957f0 --- /dev/null +++ b/test/unit/agents/belief_collector/behaviours/test_continuous_collect.py @@ -0,0 +1,242 @@ +import json +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from control_backend.agents.belief_collector.behaviours.continuous_collect import ( + ContinuousBeliefCollector, +) + + +@pytest.fixture +def mock_agent(mocker): + """Fixture to create a mock Agent.""" + agent = MagicMock() + agent.jid = "belief_collector_agent@test" + return agent + + +@pytest.fixture +def continuous_collector(mock_agent, mocker): + """Fixture to create an instance of ContinuousBeliefCollector with a mocked agent.""" + # Patch asyncio.sleep to prevent tests from actually waiting + mocker.patch("asyncio.sleep", return_value=None) + + collector = ContinuousBeliefCollector() + collector.agent = mock_agent + # Mock the receive method, we will control its return value in each test + collector.receive = AsyncMock() + return collector + + +@pytest.mark.asyncio +async def test_run_no_message_received(continuous_collector, mocker): + """ + Test that when no message is received, _process_message is not called. + """ + # Arrange + continuous_collector.receive.return_value = None + mocker.patch.object(continuous_collector, "_process_message") + + # Act + await continuous_collector.run() + + # Assert + continuous_collector._process_message.assert_not_called() + + +@pytest.mark.asyncio +async def test_run_message_received(continuous_collector, mocker): + """ + Test that when a message is received, _process_message is called with that message. 
+ """ + # Arrange + mock_msg = MagicMock() + continuous_collector.receive.return_value = mock_msg + mocker.patch.object(continuous_collector, "_process_message") + + # Act + await continuous_collector.run() + + # Assert + continuous_collector._process_message.assert_awaited_once_with(mock_msg) + + +@pytest.mark.asyncio +async def test_process_message_invalid(continuous_collector, mocker): + """ + Test that when an invalid JSON message is received, a warning is logged and processing stops. + """ + # Arrange + invalid_json = "this is not json" + msg = MagicMock() + msg.body = invalid_json + msg.sender = "belief_text_agent_mock@test" + + logger_mock = mocker.patch( + "control_backend.agents.belief_collector.behaviours.continuous_collect.logger" + ) + + # Act + await continuous_collector._process_message(msg) + + # Assert + logger_mock.warning.assert_called_once() + + +def test_get_sender_from_message(continuous_collector): + """ + Test that _sender_node correctly extracts the sender node from the message JID. 
+ """ + # Arrange + msg = MagicMock() + msg.sender = "agent_node@host/resource" + + # Act + sender_node = continuous_collector._sender_node(msg) + + # Assert + assert sender_node == "agent_node" + + +@pytest.mark.asyncio +async def test_routes_to_handle_belief_text_by_type(continuous_collector, mocker): + msg = MagicMock() + msg.body = json.dumps({"type": "belief_extraction_text", "beliefs": {"user_said": [["hi"]]}}) + msg.sender = "anyone@test" + spy = mocker.patch.object(continuous_collector, "_handle_belief_text", new=AsyncMock()) + await continuous_collector._process_message(msg) + spy.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_routes_to_handle_belief_text_by_sender(continuous_collector, mocker): + msg = MagicMock() + msg.body = json.dumps({"beliefs": {"user_said": [["hi"]]}}) # no type + msg.sender = "belief_text_agent_mock@test" + spy = mocker.patch.object(continuous_collector, "_handle_belief_text", new=AsyncMock()) + await continuous_collector._process_message(msg) + spy.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_routes_to_handle_emo_text(continuous_collector, mocker): + msg = MagicMock() + msg.body = json.dumps({"type": "emotion_extraction_text"}) + msg.sender = "anyone@test" + spy = mocker.patch.object(continuous_collector, "_handle_emo_text", new=AsyncMock()) + await continuous_collector._process_message(msg) + spy.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_unrecognized_message_logs_info(continuous_collector, mocker): + msg = MagicMock() + msg.body = json.dumps({"type": "something_else"}) + msg.sender = "x@test" + logger_mock = mocker.patch( + "control_backend.agents.belief_collector.behaviours.continuous_collect.logger" + ) + await continuous_collector._process_message(msg) + logger_mock.info.assert_any_call( + "BeliefCollector: unrecognized message (sender=%s, type=%r). 
Ignoring.", + "x", + "something_else", + ) + + +@pytest.mark.asyncio +async def test_belief_text_no_beliefs(continuous_collector, mocker): + msg_payload = {"type": "belief_extraction_text"} # no 'beliefs' + logger_mock = mocker.patch( + "control_backend.agents.belief_collector.behaviours.continuous_collect.logger" + ) + await continuous_collector._handle_belief_text(msg_payload, "origin_node") + logger_mock.info.assert_any_call("BeliefCollector: no beliefs to process.") + + +@pytest.mark.asyncio +async def test_belief_text_beliefs_not_dict(continuous_collector, mocker): + payload = {"type": "belief_extraction_text", "beliefs": ["not", "a", "dict"]} + logger_mock = mocker.patch( + "control_backend.agents.belief_collector.behaviours.continuous_collect.logger" + ) + await continuous_collector._handle_belief_text(payload, "origin") + logger_mock.warning.assert_any_call( + "BeliefCollector: 'beliefs' is not a dict: %r", ["not", "a", "dict"] + ) + + +@pytest.mark.asyncio +async def test_belief_text_values_not_lists(continuous_collector, mocker): + payload = {"type": "belief_extraction_text", "beliefs": {"user_said": "not-a-list"}} + logger_mock = mocker.patch( + "control_backend.agents.belief_collector.behaviours.continuous_collect.logger" + ) + await continuous_collector._handle_belief_text(payload, "origin") + logger_mock.warning.assert_any_call( + "BeliefCollector: 'beliefs' values are not all lists: %r", {"user_said": "not-a-list"} + ) + + +@pytest.mark.asyncio +async def test_belief_text_happy_path_logs_items_and_sends(continuous_collector, mocker): + payload = {"type": "belief_extraction_text", "beliefs": {"user_said": ["hello test", "No"]}} + # Your code calls self.send(..); patch it + # (or switch implementation to self.agent.send and patch that) + continuous_collector.send = AsyncMock() + logger_mock = mocker.patch( + "control_backend.agents.belief_collector.behaviours.continuous_collect.logger" + ) + await continuous_collector._handle_belief_text(payload, 
"belief_text_agent_mock") + + logger_mock.info.assert_any_call("BeliefCollector: forwarding %d beliefs.", 1) + # and the item logs: + logger_mock.info.assert_any_call(" - %s %s", "user_said", "hello test") + logger_mock.info.assert_any_call(" - %s %s", "user_said", "No") + # make sure we attempted a send + continuous_collector.send.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_send_beliefs_noop_on_empty(continuous_collector): + continuous_collector.send = AsyncMock() + await continuous_collector._send_beliefs_to_bdi([], origin="o") + continuous_collector.send.assert_not_awaited() + + +# @pytest.mark.asyncio +# async def test_send_beliefs_sends_json_packet(continuous_collector): +# # Patch .send and capture the message body +# sent = {} +# +# async def _fake_send(msg): +# sent["body"] = msg.body +# sent["to"] = str(msg.to) +# +# continuous_collector.send = AsyncMock(side_effect=_fake_send) +# beliefs = ["user_said hello", "user_said No"] +# await continuous_collector._send_beliefs_to_bdi(beliefs, origin="origin_node") +# +# assert "belief_packet" in json.loads(sent["body"])["type"] +# assert json.loads(sent["body"])["beliefs"] == beliefs + + +def test_sender_node_no_sender_returns_literal(continuous_collector): + msg = MagicMock() + msg.sender = None + assert continuous_collector._sender_node(msg) == "no_sender" + + +def test_sender_node_without_at(continuous_collector): + msg = MagicMock() + msg.sender = "localpartonly" + assert continuous_collector._sender_node(msg) == "localpartonly" + + +@pytest.mark.asyncio +async def test_belief_text_coerces_non_strings(continuous_collector, mocker): + payload = {"type": "belief_extraction_text", "beliefs": {"user_said": [["hi", 123]]}} + continuous_collector.send = AsyncMock() + await continuous_collector._handle_belief_text(payload, "origin") + continuous_collector.send.assert_awaited_once() diff --git a/test/unit/agents/test_vad_socket_poller.py b/test/unit/agents/test_vad_socket_poller.py new file mode 
100644 index 0000000..aaf8d0f --- /dev/null +++ b/test/unit/agents/test_vad_socket_poller.py @@ -0,0 +1,46 @@ +from unittest.mock import AsyncMock, MagicMock + +import pytest +import zmq + +from control_backend.agents.vad_agent import SocketPoller + + +@pytest.fixture +def socket(): + return AsyncMock() + + +@pytest.mark.asyncio +async def test_socket_poller_with_data(socket, mocker): + socket_data = b"test" + socket.recv.return_value = socket_data + + mock_poller: MagicMock = mocker.patch("control_backend.agents.vad_agent.zmq.Poller") + mock_poller.return_value.poll.return_value = [(socket, zmq.POLLIN)] + + poller = SocketPoller(socket) + # Calling `poll` twice to be able to check that the poller is reused + await poller.poll() + data = await poller.poll() + + assert data == socket_data + + # Ensure that the poller was reused + mock_poller.assert_called_once_with() + mock_poller.return_value.register.assert_called_once_with(socket, zmq.POLLIN) + + assert socket.recv.call_count == 2 + + +@pytest.mark.asyncio +async def test_socket_poller_no_data(socket, mocker): + mock_poller: MagicMock = mocker.patch("control_backend.agents.vad_agent.zmq.Poller") + mock_poller.return_value.poll.return_value = [] + + poller = SocketPoller(socket) + data = await poller.poll() + + assert data is None + + socket.recv.assert_not_called() diff --git a/test/unit/agents/test_vad_streaming.py b/test/unit/agents/test_vad_streaming.py new file mode 100644 index 0000000..9b38cd0 --- /dev/null +++ b/test/unit/agents/test_vad_streaming.py @@ -0,0 +1,95 @@ +from unittest.mock import AsyncMock, MagicMock + +import numpy as np +import pytest + +from control_backend.agents.vad_agent import Streaming + + +@pytest.fixture +def audio_in_socket(): + return AsyncMock() + + +@pytest.fixture +def audio_out_socket(): + return AsyncMock() + + +@pytest.fixture +def streaming(audio_in_socket, audio_out_socket): + import torch + + torch.hub.load.return_value = (..., ...) 
# Mock + return Streaming(audio_in_socket, audio_out_socket) + + +async def simulate_streaming_with_probabilities(streaming, probabilities: list[float]): + """ + Simulates a streaming scenario with given VAD model probabilities for testing purposes. + + :param streaming: The streaming component to be tested. + :param probabilities: A list of probabilities representing the outputs of the VAD model. + """ + model_item = MagicMock() + model_item.item.side_effect = probabilities + streaming.model = MagicMock() + streaming.model.return_value = model_item + + audio_in_poller = AsyncMock() + audio_in_poller.poll.return_value = np.empty(shape=512, dtype=np.float32) + streaming.audio_in_poller = audio_in_poller + + for _ in probabilities: + await streaming.run() + + +@pytest.mark.asyncio +async def test_voice_activity_detected(audio_in_socket, audio_out_socket, streaming): + """ + Test a scenario where there is voice activity detected between silences. + :return: + """ + speech_chunk_count = 5 + probabilities = [0.0] * 5 + [1.0] * speech_chunk_count + [0.0] * 5 + await simulate_streaming_with_probabilities(streaming, probabilities) + + audio_out_socket.send.assert_called_once() + data = audio_out_socket.send.call_args[0][0] + assert isinstance(data, bytes) + # each sample has 512 frames of 4 bytes, expecting 7 chunks (5 with speech, 2 as padding) + assert len(data) == 512 * 4 * (speech_chunk_count + 2) + + +@pytest.mark.asyncio +async def test_voice_activity_short_pause(audio_in_socket, audio_out_socket, streaming): + """ + Test a scenario where there is a short pause between speech, checking whether it ignores the + short pause. 
+ """ + speech_chunk_count = 5 + probabilities = ( + [0.0] * 5 + [1.0] * speech_chunk_count + [0.0] + [1.0] * speech_chunk_count + [0.0] * 5 + ) + await simulate_streaming_with_probabilities(streaming, probabilities) + + audio_out_socket.send.assert_called_once() + data = audio_out_socket.send.call_args[0][0] + assert isinstance(data, bytes) + # Expecting 13 chunks (2*5 with speech, 1 pause between, 2 as padding) + assert len(data) == 512 * 4 * (speech_chunk_count * 2 + 1 + 2) + + +@pytest.mark.asyncio +async def test_no_data(audio_in_socket, audio_out_socket, streaming): + """ + Test a scenario where there is no data received. This should not cause errors. + """ + audio_in_poller = AsyncMock() + audio_in_poller.poll.return_value = None + streaming.audio_in_poller = audio_in_poller + + await streaming.run() + + audio_out_socket.send.assert_not_called() + assert len(streaming.audio_buffer) == 0 diff --git a/test/unit/agents/transcription/test_speech_recognizer.py b/test/unit/agents/transcription/test_speech_recognizer.py new file mode 100644 index 0000000..88a5ac2 --- /dev/null +++ b/test/unit/agents/transcription/test_speech_recognizer.py @@ -0,0 +1,36 @@ +import numpy as np + +from control_backend.agents.transcription import SpeechRecognizer +from control_backend.agents.transcription.speech_recognizer import OpenAIWhisperSpeechRecognizer + + +def test_estimate_max_tokens(): + """Inputting one minute of audio, assuming 300 words per minute, expecting 400 tokens.""" + audio = np.empty(shape=(60 * 16_000), dtype=np.float32) + + actual = SpeechRecognizer._estimate_max_tokens(audio) + + assert actual == 400 + assert isinstance(actual, int) + + +def test_get_decode_options(): + """Check whether the right decode options are given under different scenarios.""" + audio = np.empty(shape=(60 * 16_000), dtype=np.float32) + + # With the defaults, it should limit output length based on input size + recognizer = OpenAIWhisperSpeechRecognizer() + options = 
recognizer._get_decode_options(audio) + + assert "sample_len" in options + assert isinstance(options["sample_len"], int) + + # When explicitly enabled, it should limit output length based on input size + recognizer = OpenAIWhisperSpeechRecognizer(limit_output_length=True) + options = recognizer._get_decode_options(audio) + + assert "sample_len" in options + assert isinstance(options["sample_len"], int) + + # When disabled, it should not limit output length based on input size + assert "sample_rate" not in options diff --git a/test/unit/conftest.py b/test/unit/conftest.py index d7c10f2..ecf00c1 100644 --- a/test/unit/conftest.py +++ b/test/unit/conftest.py @@ -11,6 +11,7 @@ def pytest_configure(config): mock_spade = MagicMock() mock_spade.agent = MagicMock() mock_spade.behaviour = MagicMock() + mock_spade.message = MagicMock() mock_spade_bdi = MagicMock() mock_spade_bdi.bdi = MagicMock() @@ -21,6 +22,7 @@ def pytest_configure(config): sys.modules["spade"] = mock_spade sys.modules["spade.agent"] = mock_spade.agent sys.modules["spade.behaviour"] = mock_spade.behaviour + sys.modules["spade.message"] = mock_spade.message sys.modules["spade_bdi"] = mock_spade_bdi sys.modules["spade_bdi.bdi"] = mock_spade_bdi.bdi @@ -33,3 +35,26 @@ def pytest_configure(config): mock_config_module.settings = MagicMock() sys.modules["control_backend.core.config"] = mock_config_module + + # --- Mock torch and zmq for VAD --- + mock_torch = MagicMock() + mock_zmq = MagicMock() + mock_zmq.asyncio = mock_zmq + + # In individual tests, these can be imported and the return values changed + sys.modules["torch"] = mock_torch + sys.modules["zmq"] = mock_zmq + sys.modules["zmq.asyncio"] = mock_zmq.asyncio + + # --- Mock whisper --- + mock_whisper = MagicMock() + mock_mlx = MagicMock() + mock_mlx.core = MagicMock() + mock_mlx_whisper = MagicMock() + mock_mlx_whisper.transcribe = MagicMock() + + sys.modules["whisper"] = mock_whisper + sys.modules["mlx"] = mock_mlx + sys.modules["mlx.core"] = mock_mlx + 
sys.modules["mlx_whisper"] = mock_mlx_whisper + sys.modules["mlx_whisper.transcribe"] = mock_mlx_whisper.transcribe diff --git a/uv.lock b/uv.lock index 07ec3c1..c2bb61a 100644 --- a/uv.lock +++ b/uv.lock @@ -1332,6 +1332,7 @@ source = { virtual = "." } dependencies = [ { name = "fastapi", extra = ["all"] }, { name = "mlx-whisper", marker = "sys_platform == 'darwin'" }, + { name = "numpy" }, { name = "openai-whisper" }, { name = "pyaudio" }, { name = "pydantic" }, @@ -1358,6 +1359,7 @@ integration-test = [ { name = "soundfile" }, ] test = [ + { name = "numpy" }, { name = "pytest" }, { name = "pytest-asyncio" }, { name = "pytest-cov" }, @@ -1368,6 +1370,7 @@ test = [ requires-dist = [ { name = "fastapi", extras = ["all"], specifier = ">=0.115.6" }, { name = "mlx-whisper", marker = "sys_platform == 'darwin'", specifier = ">=0.4.3" }, + { name = "numpy", specifier = ">=2.3.3" }, { name = "openai-whisper", specifier = ">=20250625" }, { name = "pyaudio", specifier = ">=0.2.14" }, { name = "pydantic", specifier = ">=2.12.0" }, @@ -1392,6 +1395,7 @@ dev = [ ] integration-test = [{ name = "soundfile", specifier = ">=0.13.1" }] test = [ + { name = "numpy", specifier = ">=2.3.3" }, { name = "pytest", specifier = ">=8.4.2" }, { name = "pytest-asyncio", specifier = ">=1.2.0" }, { name = "pytest-cov", specifier = ">=7.0.0" },