style: apply ruff check and format

Made sure all ruff checks pass and formatted all files.

ref: N25B-224
This commit is contained in:
2025-11-02 19:45:01 +01:00
parent 657c300bc7
commit 48c9746417
25 changed files with 199 additions and 143 deletions

View File

@@ -58,8 +58,8 @@ class BDICoreAgent(BDIAgent):
class SendBehaviour(OneShotBehaviour): class SendBehaviour(OneShotBehaviour):
async def run(self) -> None: async def run(self) -> None:
msg = Message( msg = Message(
to= settings.agent_settings.llm_agent_name + '@' + settings.agent_settings.host, to=settings.agent_settings.llm_agent_name + "@" + settings.agent_settings.host,
body= text body=text,
) )
await self.send(msg) await self.send(msg)

View File

@@ -3,7 +3,7 @@ import logging
from spade.agent import Message from spade.agent import Message
from spade.behaviour import CyclicBehaviour from spade.behaviour import CyclicBehaviour
from spade_bdi.bdi import BDIAgent, BeliefNotInitiated from spade_bdi.bdi import BDIAgent
from control_backend.core.config import settings from control_backend.core.config import settings
@@ -23,7 +23,6 @@ class BeliefSetterBehaviour(CyclicBehaviour):
self.logger.info(f"Received message {msg.body}") self.logger.info(f"Received message {msg.body}")
self._process_message(msg) self._process_message(msg)
def _process_message(self, message: Message): def _process_message(self, message: Message):
sender = message.sender.node # removes host from jid and converts to str sender = message.sender.node # removes host from jid and converts to str
self.logger.debug("Sender: %s", sender) self.logger.debug("Sender: %s", sender)
@@ -61,6 +60,7 @@ class BeliefSetterBehaviour(CyclicBehaviour):
self.agent.bdi.set_belief(belief, *arguments) self.agent.bdi.set_belief(belief, *arguments)
# Special case: if there's a new user message, flag that we haven't responded yet # Special case: if there's a new user message, flag that we haven't responded yet
if belief == "user_said": self.agent.bdi.set_belief("new_message") if belief == "user_said":
self.agent.bdi.set_belief("new_message")
self.logger.info("Set belief %s with arguments %s", belief, arguments) self.logger.info("Set belief %s with arguments %s", belief, arguments)

View File

@@ -9,7 +9,9 @@ class ReceiveLLMResponseBehaviour(CyclicBehaviour):
""" """
Adds behavior to receive responses from the LLM Agent. Adds behavior to receive responses from the LLM Agent.
""" """
logger = logging.getLogger("BDI/LLM Reciever") logger = logging.getLogger("BDI/LLM Reciever")
async def run(self): async def run(self):
msg = await self.receive(timeout=2) msg = await self.receive(timeout=2)
if not msg: if not msg:

View File

@@ -13,28 +13,30 @@ class BeliefFromText(CyclicBehaviour):
# TODO: LLM prompt nog hardcoded # TODO: LLM prompt nog hardcoded
llm_instruction_prompt = """ llm_instruction_prompt = """
You are an information extraction assistent for a BDI agent. Your task is to extract values from a user's text to bind a list of ungrounded beliefs. Rules: You are an information extraction assistent for a BDI agent. Your task is to extract values \
You will receive a JSON object with "beliefs" (a list of ungrounded AgentSpeak beliefs) and "text" (user's transcript). from a user's text to bind a list of ungrounded beliefs. Rules:
You will receive a JSON object with "beliefs" (a list of ungrounded AgentSpeak beliefs) \
and "text" (user's transcript).
Analyze the text to find values that sematically match the variables (X,Y,Z) in the beliefs. Analyze the text to find values that sematically match the variables (X,Y,Z) in the beliefs.
A single piece of text might contain multiple instances that match a belief. A single piece of text might contain multiple instances that match a belief.
Respond ONLY with a single JSON object. Respond ONLY with a single JSON object.
The JSON object's keys should be the belief functors (e.g., "weather"). The JSON object's keys should be the belief functors (e.g., "weather").
The value for each key must be a list of lists. The value for each key must be a list of lists.
Each inner list must contain the extracted arguments (as strings) for one instance of that belief. Each inner list must contain the extracted arguments (as strings) for one instance \
CRITICAL: If no information in the text matches a belief, DO NOT include that key in your response. of that belief.
CRITICAL: If no information in the text matches a belief, DO NOT include that key \
in your response.
""" """
# on_start agent receives message containing the beliefs to look out for and sets up the LLM with instruction prompt # on_start agent receives message containing the beliefs to look out for and
# sets up the LLM with instruction prompt
# async def on_start(self): # async def on_start(self):
# msg = await self.receive(timeout=0.1) # msg = await self.receive(timeout=0.1)
# self.beliefs = dict uit message # self.beliefs = dict uit message
# send instruction prompt to LLM # send instruction prompt to LLM
beliefs: dict[str, list[str]] beliefs: dict[str, list[str]]
beliefs = { beliefs = {"mood": ["X"], "car": ["Y"]}
"mood": ["X"],
"car": ["Y"]
}
async def run(self): async def run(self):
msg = await self.receive(timeout=0.1) msg = await self.receive(timeout=0.1)
@@ -67,8 +69,9 @@ class BeliefFromText(CyclicBehaviour):
try: try:
json.loads(response) json.loads(response)
belief_message = Message( belief_message = Message(
to=settings.agent_settings.bdi_core_agent_name + '@' + settings.agent_settings.host, to=settings.agent_settings.bdi_core_agent_name + "@" + settings.agent_settings.host,
body=response) body=response,
)
belief_message.thread = "beliefs" belief_message.thread = "beliefs"
await self.send(belief_message) await self.send(belief_message)
@@ -85,9 +88,12 @@ class BeliefFromText(CyclicBehaviour):
""" """
belief = {"beliefs": {"user_said": [txt]}, "type": "belief_extraction_text"} belief = {"beliefs": {"user_said": [txt]}, "type": "belief_extraction_text"}
payload = json.dumps(belief) payload = json.dumps(belief)
belief_msg = Message(to=settings.agent_settings.belief_collector_agent_name belief_msg = Message(
+ '@' + settings.agent_settings.host, to=settings.agent_settings.belief_collector_agent_name
body=payload) + "@"
+ settings.agent_settings.host,
body=payload,
)
belief_msg.thread = "beliefs" belief_msg.thread = "beliefs"
await self.send(belief_msg) await self.send(belief_msg)

View File

@@ -1,11 +1,14 @@
import json import json
import logging import logging
from spade.behaviour import CyclicBehaviour
from spade.agent import Message from spade.agent import Message
from spade.behaviour import CyclicBehaviour
from control_backend.core.config import settings from control_backend.core.config import settings
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class ContinuousBeliefCollector(CyclicBehaviour): class ContinuousBeliefCollector(CyclicBehaviour):
""" """
Continuously collects beliefs/emotions from extractor agents: Continuously collects beliefs/emotions from extractor agents:
@@ -17,7 +20,6 @@ class ContinuousBeliefCollector(CyclicBehaviour):
if msg: if msg:
await self._process_message(msg) await self._process_message(msg)
async def _process_message(self, msg: Message): async def _process_message(self, msg: Message):
sender_node = self._sender_node(msg) sender_node = self._sender_node(msg)
@@ -27,7 +29,9 @@ class ContinuousBeliefCollector(CyclicBehaviour):
except Exception as e: except Exception as e:
logger.warning( logger.warning(
"BeliefCollector: failed to parse JSON from %s. Body=%r Error=%s", "BeliefCollector: failed to parse JSON from %s. Body=%r Error=%s",
sender_node, msg.body, e sender_node,
msg.body,
e,
) )
return return
@@ -35,16 +39,21 @@ class ContinuousBeliefCollector(CyclicBehaviour):
# Prefer explicit 'type' field # Prefer explicit 'type' field
if msg_type == "belief_extraction_text" or sender_node == "belief_text_agent_mock": if msg_type == "belief_extraction_text" or sender_node == "belief_text_agent_mock":
logger.info("BeliefCollector: message routed to _handle_belief_text (sender=%s)", sender_node) logger.info(
"BeliefCollector: message routed to _handle_belief_text (sender=%s)", sender_node
)
await self._handle_belief_text(payload, sender_node) await self._handle_belief_text(payload, sender_node)
# This is not implemented yet, but we keep the structure for future use # This is not implemented yet, but we keep the structure for future use
elif msg_type == "emotion_extraction_text" or sender_node == "emo_text_agent_mock": elif msg_type == "emotion_extraction_text" or sender_node == "emo_text_agent_mock":
logger.info("BeliefCollector: message routed to _handle_emo_text (sender=%s)", sender_node) logger.info(
"BeliefCollector: message routed to _handle_emo_text (sender=%s)", sender_node
)
await self._handle_emo_text(payload, sender_node) await self._handle_emo_text(payload, sender_node)
else: else:
logger.info( logger.info(
"BeliefCollector: unrecognized message (sender=%s, type=%r). Ignoring.", "BeliefCollector: unrecognized message (sender=%s, type=%r). Ignoring.",
sender_node, msg_type sender_node,
msg_type,
) )
@staticmethod @staticmethod
@@ -56,13 +65,12 @@ class ContinuousBeliefCollector(CyclicBehaviour):
s = str(msg.sender) if msg.sender is not None else "no_sender" s = str(msg.sender) if msg.sender is not None else "no_sender"
return s.split("@", 1)[0] if "@" in s else s return s.split("@", 1)[0] if "@" in s else s
async def _handle_belief_text(self, payload: dict, origin: str): async def _handle_belief_text(self, payload: dict, origin: str):
""" """
Expected payload: Expected payload:
{ {
"type": "belief_extraction_text", "type": "belief_extraction_text",
"beliefs": {"user_said": ["hello"","Can you help me?","stop talking to me","No","Pepper do a dance"]} "beliefs": {"user_said": ["Can you help me?"]}
} }
@@ -88,13 +96,10 @@ class ContinuousBeliefCollector(CyclicBehaviour):
await self._send_beliefs_to_bdi(beliefs, origin=origin) await self._send_beliefs_to_bdi(beliefs, origin=origin)
async def _handle_emo_text(self, payload: dict, origin: str): async def _handle_emo_text(self, payload: dict, origin: str):
"""TODO: implement (after we have emotional recogntion)""" """TODO: implement (after we have emotional recogntion)"""
pass pass
async def _send_beliefs_to_bdi(self, beliefs: list[str], origin: str | None = None): async def _send_beliefs_to_bdi(self, beliefs: list[str], origin: str | None = None):
""" """
Sends a unified belief packet to the BDI agent. Sends a unified belief packet to the BDI agent.
@@ -107,6 +112,5 @@ class ContinuousBeliefCollector(CyclicBehaviour):
msg = Message(to=to_jid, sender=self.agent.jid, thread="beliefs") msg = Message(to=to_jid, sender=self.agent.jid, thread="beliefs")
msg.body = json.dumps(beliefs) msg.body = json.dumps(beliefs)
await self.send(msg) await self.send(msg)
logger.info("BeliefCollector: sent %d belief(s) to BDI at %s", len(beliefs), to_jid) logger.info("BeliefCollector: sent %d belief(s) to BDI at %s", len(beliefs), to_jid)

View File

@@ -1,10 +1,12 @@
import logging import logging
from spade.agent import Agent from spade.agent import Agent
from .behaviours.continuous_collect import ContinuousBeliefCollector from .behaviours.continuous_collect import ContinuousBeliefCollector
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class BeliefCollectorAgent(Agent): class BeliefCollectorAgent(Agent):
async def setup(self): async def setup(self):
logger.info("BeliefCollectorAgent starting (%s)", self.jid) logger.info("BeliefCollectorAgent starting (%s)", self.jid)

View File

@@ -65,8 +65,8 @@ class LLMAgent(Agent):
Sends a response message back to the BDI Core Agent. Sends a response message back to the BDI Core Agent.
""" """
reply = Message( reply = Message(
to=settings.agent_settings.bdi_core_agent_name + '@' + settings.agent_settings.host, to=settings.agent_settings.bdi_core_agent_name + "@" + settings.agent_settings.host,
body=msg body=msg,
) )
await self.send(reply) await self.send(reply)
self.agent.logger.info("Reply sent to BDI Core Agent") self.agent.logger.info("Reply sent to BDI Core Agent")
@@ -90,25 +90,21 @@ class LLMAgent(Agent):
json={ json={
"model": settings.llm_settings.local_llm_model, "model": settings.llm_settings.local_llm_model,
"messages": [ "messages": [
{ {"role": "developer", "content": developer_instruction},
"role": "developer", {"role": "user", "content": prompt},
"content": developer_instruction
},
{
"role": "user",
"content": prompt
}
], ],
"temperature": 0.3 "temperature": 0.3,
}, },
) )
try: try:
response.raise_for_status() response.raise_for_status()
data: dict[str, Any] = response.json() data: dict[str, Any] = response.json()
return data.get("choices", [{}])[0].get( return (
"message", {} data.get("choices", [{}])[0]
).get("content", "No response") .get("message", {})
.get("content", "No response")
)
except httpx.HTTPError as err: except httpx.HTTPError as err:
self.agent.logger.error("HTTP error: %s", err) self.agent.logger.error("HTTP error: %s", err)
return "LLM service unavailable." return "LLM service unavailable."

View File

@@ -1,18 +1,33 @@
import json import json
from spade.agent import Agent from spade.agent import Agent
from spade.behaviour import OneShotBehaviour from spade.behaviour import OneShotBehaviour
from spade.message import Message from spade.message import Message
from control_backend.core.config import settings from control_backend.core.config import settings
class BeliefTextAgent(Agent): class BeliefTextAgent(Agent):
class SendOnceBehaviourBlfText(OneShotBehaviour): class SendOnceBehaviourBlfText(OneShotBehaviour):
async def run(self): async def run(self):
to_jid = f"{settings.agent_settings.belief_collector_agent_name}@{settings.agent_settings.host}" to_jid = (
settings.agent_settings.belief_collector_agent_name
+ "@"
+ settings.agent_settings.host
)
# Send multiple beliefs in one JSON payload # Send multiple beliefs in one JSON payload
payload = { payload = {
"type": "belief_extraction_text", "type": "belief_extraction_text",
"beliefs": {"user_said": ["hello test","Can you help me?","stop talking to me","No","Pepper do a dance"]} "beliefs": {
"user_said": [
"hello test",
"Can you help me?",
"stop talking to me",
"No",
"Pepper do a dance",
]
},
} }
msg = Message(to=to_jid) msg = Message(to=to_jid)

View File

@@ -1,8 +1,9 @@
import json import json
import logging import logging
import zmq
from spade.agent import Agent from spade.agent import Agent
from spade.behaviour import CyclicBehaviour from spade.behaviour import CyclicBehaviour
import zmq
from control_backend.core.config import settings from control_backend.core.config import settings
from control_backend.core.zmq_context import context from control_backend.core.zmq_context import context

View File

@@ -1,14 +1,13 @@
import asyncio import asyncio
import json
import logging import logging
import zmq
from spade.agent import Agent from spade.agent import Agent
from spade.behaviour import CyclicBehaviour from spade.behaviour import CyclicBehaviour
import zmq
from control_backend.agents.ri_command_agent import RICommandAgent
from control_backend.core.config import settings from control_backend.core.config import settings
from control_backend.core.zmq_context import context from control_backend.core.zmq_context import context
from control_backend.schemas.message import Message
from control_backend.agents.ri_command_agent import RICommandAgent
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -47,7 +46,7 @@ class RICommunicationAgent(Agent):
message = await asyncio.wait_for(self.agent.req_socket.recv_json(), timeout=3.0) message = await asyncio.wait_for(self.agent.req_socket.recv_json(), timeout=3.0)
# We didnt get a reply :( # We didnt get a reply :(
except asyncio.TimeoutError as e: except TimeoutError:
logger.info("No ping retrieved in 3 seconds, killing myself.") logger.info("No ping retrieved in 3 seconds, killing myself.")
self.kill() self.kill()
@@ -88,7 +87,7 @@ class RICommunicationAgent(Agent):
try: try:
received_message = await asyncio.wait_for(self.req_socket.recv_json(), timeout=20.0) received_message = await asyncio.wait_for(self.req_socket.recv_json(), timeout=20.0)
except asyncio.TimeoutError: except TimeoutError:
logger.warning( logger.warning(
"No connection established in 20 seconds (attempt %d/%d)", "No connection established in 20 seconds (attempt %d/%d)",
retries + 1, retries + 1,

View File

@@ -75,7 +75,8 @@ class MLXWhisperSpeechRecognizer(SpeechRecognizer):
self.model_name = "mlx-community/whisper-small.en-mlx" self.model_name = "mlx-community/whisper-small.en-mlx"
def load_model(self): def load_model(self):
if self.was_loaded: return if self.was_loaded:
return
# There appears to be no dedicated mechanism to preload a model, but this `get_model` does # There appears to be no dedicated mechanism to preload a model, but this `get_model` does
# store it in memory for later usage # store it in memory for later usage
ModelHolder.get_model(self.model_name, mx.float16) ModelHolder.get_model(self.model_name, mx.float16)
@@ -83,9 +84,9 @@ class MLXWhisperSpeechRecognizer(SpeechRecognizer):
def recognize_speech(self, audio: np.ndarray) -> str: def recognize_speech(self, audio: np.ndarray) -> str:
self.load_model() self.load_model()
return mlx_whisper.transcribe(audio, return mlx_whisper.transcribe(
path_or_hf_repo=self.model_name, audio, path_or_hf_repo=self.model_name, decode_options=self._get_decode_options(audio)
decode_options=self._get_decode_options(audio))["text"] )["text"]
return mlx_whisper.transcribe(audio, path_or_hf_repo=self.model_name)["text"].strip() return mlx_whisper.transcribe(audio, path_or_hf_repo=self.model_name)["text"].strip()
@@ -95,12 +96,13 @@ class OpenAIWhisperSpeechRecognizer(SpeechRecognizer):
self.model = None self.model = None
def load_model(self): def load_model(self):
if self.model is not None: return if self.model is not None:
return
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
self.model = whisper.load_model("small.en", device=device) self.model = whisper.load_model("small.en", device=device)
def recognize_speech(self, audio: np.ndarray) -> str: def recognize_speech(self, audio: np.ndarray) -> str:
self.load_model() self.load_model()
return whisper.transcribe(self.model, return whisper.transcribe(
audio, self.model, audio, decode_options=self._get_decode_options(audio)
decode_options=self._get_decode_options(audio))["text"] )["text"]

View File

@@ -47,7 +47,8 @@ class TranscriptionAgent(Agent):
"""Share a transcription to the other agents that depend on it.""" """Share a transcription to the other agents that depend on it."""
receiver_jids = [ receiver_jids = [
settings.agent_settings.text_belief_extractor_agent_name settings.agent_settings.text_belief_extractor_agent_name
+ '@' + settings.agent_settings.host, + "@"
+ settings.agent_settings.host,
] # Set message receivers here ] # Set message receivers here
for receiver_jid in receiver_jids: for receiver_jid in receiver_jids:

View File

@@ -1,9 +1,9 @@
from fastapi import APIRouter, Request
import logging import logging
from fastapi import APIRouter, Request
from zmq import Socket from zmq import Socket
from control_backend.schemas.ri_message import SpeechCommand, RIEndpoint from control_backend.schemas.ri_message import SpeechCommand
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -18,5 +18,4 @@ async def receive_command(command: SpeechCommand, request: Request):
pub_socket: Socket = request.app.state.internal_comm_socket pub_socket: Socket = request.app.state.internal_comm_socket
pub_socket.send_multipart([topic, command.model_dump_json().encode()]) pub_socket.send_multipart([topic, command.model_dump_json().encode()])
return {"status": "Command received"} return {"status": "Command received"}

View File

@@ -1,6 +1,6 @@
from fastapi.routing import APIRouter from fastapi.routing import APIRouter
from control_backend.api.v1.endpoints import message, sse, command from control_backend.api.v1.endpoints import command, message, sse
api_router = APIRouter() api_router = APIRouter()

View File

@@ -24,6 +24,7 @@ class LLMSettings(BaseModel):
local_llm_url: str = "http://localhost:1234/v1/chat/completions" local_llm_url: str = "http://localhost:1234/v1/chat/completions"
local_llm_model: str = "openai/gpt-oss-20b" local_llm_model: str = "openai/gpt-oss-20b"
class Settings(BaseSettings): class Settings(BaseSettings):
app_title: str = "PepperPlus" app_title: str = "PepperPlus"
@@ -37,4 +38,5 @@ class Settings(BaseSettings):
model_config = SettingsConfigDict(env_file=".env") model_config = SettingsConfigDict(env_file=".env")
settings = Settings() settings = Settings()

View File

@@ -8,13 +8,14 @@ import zmq
from fastapi import FastAPI from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
# Internal imports
from control_backend.agents.ri_communication_agent import RICommunicationAgent
from control_backend.agents.bdi.bdi_core import BDICoreAgent from control_backend.agents.bdi.bdi_core import BDICoreAgent
from control_backend.agents.vad_agent import VADAgent
from control_backend.agents.llm.llm import LLMAgent
from control_backend.agents.bdi.text_extractor import TBeliefExtractor from control_backend.agents.bdi.text_extractor import TBeliefExtractor
from control_backend.agents.belief_collector.belief_collector import BeliefCollectorAgent from control_backend.agents.belief_collector.belief_collector import BeliefCollectorAgent
from control_backend.agents.llm.llm import LLMAgent
# Internal imports
from control_backend.agents.ri_communication_agent import RICommunicationAgent
from control_backend.agents.vad_agent import VADAgent
from control_backend.api.v1.router import api_router from control_backend.api.v1.router import api_router
from control_backend.core.config import settings from control_backend.core.config import settings
from control_backend.core.zmq_context import context from control_backend.core.zmq_context import context
@@ -34,7 +35,6 @@ async def lifespan(app: FastAPI):
app.state.internal_comm_socket = internal_comm_socket app.state.internal_comm_socket = internal_comm_socket
logger.info("Internal publishing socket bound to %s", internal_comm_socket) logger.info("Internal publishing socket bound to %s", internal_comm_socket)
# Initiate agents # Initiate agents
ri_communication_agent = RICommunicationAgent( ri_communication_agent = RICommunicationAgent(
settings.agent_settings.ri_communication_agent_name + "@" + settings.agent_settings.host, settings.agent_settings.ri_communication_agent_name + "@" + settings.agent_settings.host,
@@ -45,26 +45,28 @@ async def lifespan(app: FastAPI):
await ri_communication_agent.start() await ri_communication_agent.start()
llm_agent = LLMAgent( llm_agent = LLMAgent(
settings.agent_settings.llm_agent_name + '@' + settings.agent_settings.host, settings.agent_settings.llm_agent_name + "@" + settings.agent_settings.host,
settings.agent_settings.llm_agent_name, settings.agent_settings.llm_agent_name,
) )
await llm_agent.start() await llm_agent.start()
bdi_core = BDICoreAgent( bdi_core = BDICoreAgent(
settings.agent_settings.bdi_core_agent_name + '@' + settings.agent_settings.host, settings.agent_settings.bdi_core_agent_name + "@" + settings.agent_settings.host,
settings.agent_settings.bdi_core_agent_name, settings.agent_settings.bdi_core_agent_name,
"src/control_backend/agents/bdi/rules.asl", "src/control_backend/agents/bdi/rules.asl",
) )
await bdi_core.start() await bdi_core.start()
belief_collector = BeliefCollectorAgent( belief_collector = BeliefCollectorAgent(
settings.agent_settings.belief_collector_agent_name + '@' + settings.agent_settings.host, settings.agent_settings.belief_collector_agent_name + "@" + settings.agent_settings.host,
settings.agent_settings.belief_collector_agent_name, settings.agent_settings.belief_collector_agent_name,
) )
await belief_collector.start() await belief_collector.start()
text_belief_extractor = TBeliefExtractor( text_belief_extractor = TBeliefExtractor(
settings.agent_settings.text_belief_extractor_agent_name + '@' + settings.agent_settings.host, settings.agent_settings.text_belief_extractor_agent_name
+ "@"
+ settings.agent_settings.host,
settings.agent_settings.text_belief_extractor_agent_name, settings.agent_settings.text_belief_extractor_agent_name,
) )
await text_belief_extractor.start() await text_belief_extractor.start()

View File

@@ -1,7 +1,7 @@
from enum import Enum from enum import Enum
from typing import Any, Literal from typing import Any
from pydantic import BaseModel, Field, ValidationError from pydantic import BaseModel
class RIEndpoint(str, Enum): class RIEndpoint(str, Enum):

View File

@@ -1,10 +1,10 @@
import asyncio
import zmq
import json import json
import pytest
from unittest.mock import AsyncMock, MagicMock, patch from unittest.mock import AsyncMock, MagicMock, patch
import pytest
import zmq
from control_backend.agents.ri_command_agent import RICommandAgent from control_backend.agents.ri_command_agent import RICommandAgent
from control_backend.schemas.ri_message import SpeechCommand
@pytest.mark.asyncio @pytest.mark.asyncio

View File

@@ -1,6 +1,8 @@
import asyncio import asyncio
from unittest.mock import ANY, AsyncMock, MagicMock, patch
import pytest import pytest
from unittest.mock import AsyncMock, MagicMock, patch, ANY
from control_backend.agents.ri_communication_agent import RICommunicationAgent from control_backend.agents.ri_communication_agent import RICommunicationAgent
@@ -185,8 +187,8 @@ async def test_setup_creates_socket_and_negotiate_3(monkeypatch, caplog):
# Mock RICommandAgent agent startup # Mock RICommandAgent agent startup
# We are sending wrong negotiation info to the communication agent, so we should retry and expect a # We are sending wrong negotiation info to the communication agent,
# better response, within a limited time. # so we should retry and expect a better response, within a limited time.
with patch( with patch(
"control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True "control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True
) as MockCommandAgent: ) as MockCommandAgent:
@@ -358,8 +360,8 @@ async def test_setup_creates_socket_and_negotiate_7(monkeypatch, caplog):
# Mock RICommandAgent agent startup # Mock RICommandAgent agent startup
# We are sending wrong negotiation info to the communication agent, so we should retry and expect a # We are sending wrong negotiation info to the communication agent,
# better response, within a limited time. # so we should retry and expect a better response, within a limited time.
with patch( with patch(
"control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True "control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True
) as MockCommandAgent: ) as MockCommandAgent:

View File

@@ -1,7 +1,8 @@
from unittest.mock import MagicMock
import pytest import pytest
from fastapi import FastAPI from fastapi import FastAPI
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
from unittest.mock import MagicMock
from control_backend.api.v1.endpoints import command from control_backend.api.v1.endpoints import command
from control_backend.schemas.ri_message import SpeechCommand from control_backend.schemas.ri_message import SpeechCommand

View File

@@ -1,7 +1,8 @@
import pytest import pytest
from control_backend.schemas.ri_message import RIMessage, RIEndpoint, SpeechCommand
from pydantic import ValidationError from pydantic import ValidationError
from control_backend.schemas.ri_message import RIEndpoint, RIMessage, SpeechCommand
def valid_command_1(): def valid_command_1():
return SpeechCommand(data="Hallo?") return SpeechCommand(data="Hallo?")
@@ -13,24 +14,13 @@ def invalid_command_1():
def test_valid_speech_command_1(): def test_valid_speech_command_1():
command = valid_command_1() command = valid_command_1()
try:
RIMessage.model_validate(command) RIMessage.model_validate(command)
SpeechCommand.model_validate(command) SpeechCommand.model_validate(command)
assert True
except ValidationError:
assert False
def test_invalid_speech_command_1(): def test_invalid_speech_command_1():
command = invalid_command_1() command = invalid_command_1()
passed_ri_message_validation = False
try:
# Should succeed, still.
RIMessage.model_validate(command) RIMessage.model_validate(command)
passed_ri_message_validation = True
# Should fail. with pytest.raises(ValidationError):
SpeechCommand.model_validate(command) SpeechCommand.model_validate(command)
assert False
except ValidationError:
assert passed_ri_message_validation

View File

@@ -203,6 +203,7 @@ def test_set_beliefs_success(belief_setter, mock_agent, caplog):
assert "Set belief is_hot with arguments ['kitchen']" in caplog.text assert "Set belief is_hot with arguments ['kitchen']" in caplog.text
assert "Set belief door_opened with arguments ['front_door', 'back_door']" in caplog.text assert "Set belief door_opened with arguments ['front_door', 'back_door']" in caplog.text
# def test_responded_unset(belief_setter, mock_agent): # def test_responded_unset(belief_setter, mock_agent):
# # Arrange # # Arrange
# new_beliefs = {"user_said": ["message"]} # new_beliefs = {"user_said": ["message"]}

View File

@@ -1,10 +1,12 @@
import json import json
import logging from unittest.mock import AsyncMock, MagicMock
from unittest.mock import MagicMock, AsyncMock, call
import pytest import pytest
from control_backend.agents.belief_collector.behaviours.continuous_collect import ContinuousBeliefCollector from control_backend.agents.belief_collector.behaviours.continuous_collect import (
ContinuousBeliefCollector,
)
@pytest.fixture @pytest.fixture
def mock_agent(mocker): def mock_agent(mocker):
@@ -13,6 +15,7 @@ def mock_agent(mocker):
agent.jid = "belief_collector_agent@test" agent.jid = "belief_collector_agent@test"
return agent return agent
@pytest.fixture @pytest.fixture
def continuous_collector(mock_agent, mocker): def continuous_collector(mock_agent, mocker):
"""Fixture to create an instance of ContinuousBeliefCollector with a mocked agent.""" """Fixture to create an instance of ContinuousBeliefCollector with a mocked agent."""
@@ -25,6 +28,7 @@ def continuous_collector(mock_agent, mocker):
collector.receive = AsyncMock() collector.receive = AsyncMock()
return collector return collector
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_run_no_message_received(continuous_collector, mocker): async def test_run_no_message_received(continuous_collector, mocker):
""" """
@@ -40,6 +44,7 @@ async def test_run_no_message_received(continuous_collector, mocker):
# Assert # Assert
continuous_collector._process_message.assert_not_called() continuous_collector._process_message.assert_not_called()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_run_message_received(continuous_collector, mocker): async def test_run_message_received(continuous_collector, mocker):
""" """
@@ -56,6 +61,7 @@ async def test_run_message_received(continuous_collector, mocker):
# Assert # Assert
continuous_collector._process_message.assert_awaited_once_with(mock_msg) continuous_collector._process_message.assert_awaited_once_with(mock_msg)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_process_message_invalid(continuous_collector, mocker): async def test_process_message_invalid(continuous_collector, mocker):
""" """
@@ -67,7 +73,9 @@ async def test_process_message_invalid(continuous_collector, mocker):
msg.body = invalid_json msg.body = invalid_json
msg.sender = "belief_text_agent_mock@test" msg.sender = "belief_text_agent_mock@test"
logger_mock = mocker.patch("control_backend.agents.belief_collector.behaviours.continuous_collect.logger") logger_mock = mocker.patch(
"control_backend.agents.belief_collector.behaviours.continuous_collect.logger"
)
# Act # Act
await continuous_collector._process_message(msg) await continuous_collector._process_message(msg)
@@ -75,6 +83,7 @@ async def test_process_message_invalid(continuous_collector, mocker):
# Assert # Assert
logger_mock.warning.assert_called_once() logger_mock.warning.assert_called_once()
def test_get_sender_from_message(continuous_collector): def test_get_sender_from_message(continuous_collector):
""" """
Test that _sender_node correctly extracts the sender node from the message JID. Test that _sender_node correctly extracts the sender node from the message JID.
@@ -89,6 +98,7 @@ def test_get_sender_from_message(continuous_collector):
# Assert # Assert
assert sender_node == "agent_node" assert sender_node == "agent_node"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_routes_to_handle_belief_text_by_type(continuous_collector, mocker): async def test_routes_to_handle_belief_text_by_type(continuous_collector, mocker):
msg = MagicMock() msg = MagicMock()
@@ -98,6 +108,7 @@ async def test_routes_to_handle_belief_text_by_type(continuous_collector, mocker
await continuous_collector._process_message(msg) await continuous_collector._process_message(msg)
spy.assert_awaited_once() spy.assert_awaited_once()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_routes_to_handle_belief_text_by_sender(continuous_collector, mocker): async def test_routes_to_handle_belief_text_by_sender(continuous_collector, mocker):
msg = MagicMock() msg = MagicMock()
@@ -107,6 +118,7 @@ async def test_routes_to_handle_belief_text_by_sender(continuous_collector, mock
await continuous_collector._process_message(msg) await continuous_collector._process_message(msg)
spy.assert_awaited_once() spy.assert_awaited_once()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_routes_to_handle_emo_text(continuous_collector, mocker): async def test_routes_to_handle_emo_text(continuous_collector, mocker):
msg = MagicMock() msg = MagicMock()
@@ -116,50 +128,64 @@ async def test_routes_to_handle_emo_text(continuous_collector, mocker):
await continuous_collector._process_message(msg) await continuous_collector._process_message(msg)
spy.assert_awaited_once() spy.assert_awaited_once()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_unrecognized_message_logs_info(continuous_collector, mocker): async def test_unrecognized_message_logs_info(continuous_collector, mocker):
msg = MagicMock() msg = MagicMock()
msg.body = json.dumps({"type": "something_else"}) msg.body = json.dumps({"type": "something_else"})
msg.sender = "x@test" msg.sender = "x@test"
logger_mock = mocker.patch("control_backend.agents.belief_collector.behaviours.continuous_collect.logger") logger_mock = mocker.patch(
"control_backend.agents.belief_collector.behaviours.continuous_collect.logger"
)
await continuous_collector._process_message(msg) await continuous_collector._process_message(msg)
logger_mock.info.assert_any_call( logger_mock.info.assert_any_call(
"BeliefCollector: unrecognized message (sender=%s, type=%r). Ignoring.", "x", "something_else" "BeliefCollector: unrecognized message (sender=%s, type=%r). Ignoring.",
"x",
"something_else",
) )
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_belief_text_no_beliefs(continuous_collector, mocker): async def test_belief_text_no_beliefs(continuous_collector, mocker):
msg_payload = {"type": "belief_extraction_text"} # no 'beliefs' msg_payload = {"type": "belief_extraction_text"} # no 'beliefs'
logger_mock = mocker.patch("control_backend.agents.belief_collector.behaviours.continuous_collect.logger") logger_mock = mocker.patch(
"control_backend.agents.belief_collector.behaviours.continuous_collect.logger"
)
await continuous_collector._handle_belief_text(msg_payload, "origin_node") await continuous_collector._handle_belief_text(msg_payload, "origin_node")
logger_mock.info.assert_any_call("BeliefCollector: no beliefs to process.") logger_mock.info.assert_any_call("BeliefCollector: no beliefs to process.")
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_belief_text_beliefs_not_dict(continuous_collector, mocker): async def test_belief_text_beliefs_not_dict(continuous_collector, mocker):
payload = {"type": "belief_extraction_text", "beliefs": ["not", "a", "dict"]} payload = {"type": "belief_extraction_text", "beliefs": ["not", "a", "dict"]}
logger_mock = mocker.patch("control_backend.agents.belief_collector.behaviours.continuous_collect.logger") logger_mock = mocker.patch(
"control_backend.agents.belief_collector.behaviours.continuous_collect.logger"
)
await continuous_collector._handle_belief_text(payload, "origin") await continuous_collector._handle_belief_text(payload, "origin")
logger_mock.warning.assert_any_call("BeliefCollector: 'beliefs' is not a dict: %r", ["not", "a", "dict"]) logger_mock.warning.assert_any_call(
"BeliefCollector: 'beliefs' is not a dict: %r", ["not", "a", "dict"]
)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_belief_text_values_not_lists(continuous_collector, mocker): async def test_belief_text_values_not_lists(continuous_collector, mocker):
payload = {"type": "belief_extraction_text", "beliefs": {"user_said": "not-a-list"}} payload = {"type": "belief_extraction_text", "beliefs": {"user_said": "not-a-list"}}
logger_mock = mocker.patch("control_backend.agents.belief_collector.behaviours.continuous_collect.logger") logger_mock = mocker.patch(
"control_backend.agents.belief_collector.behaviours.continuous_collect.logger"
)
await continuous_collector._handle_belief_text(payload, "origin") await continuous_collector._handle_belief_text(payload, "origin")
logger_mock.warning.assert_any_call( logger_mock.warning.assert_any_call(
"BeliefCollector: 'beliefs' values are not all lists: %r", {"user_said": "not-a-list"} "BeliefCollector: 'beliefs' values are not all lists: %r", {"user_said": "not-a-list"}
) )
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_belief_text_happy_path_logs_items_and_sends(continuous_collector, mocker): async def test_belief_text_happy_path_logs_items_and_sends(continuous_collector, mocker):
payload = { payload = {"type": "belief_extraction_text", "beliefs": {"user_said": ["hello test", "No"]}}
"type": "belief_extraction_text",
"beliefs": {"user_said": ["hello test", "No"]}
}
# Your code calls self.send(..); patch it (or switch implementation to self.agent.send and patch that)
continuous_collector.send = AsyncMock() continuous_collector.send = AsyncMock()
logger_mock = mocker.patch("control_backend.agents.belief_collector.behaviours.continuous_collect.logger") logger_mock = mocker.patch(
"control_backend.agents.belief_collector.behaviours.continuous_collect.logger"
)
await continuous_collector._handle_belief_text(payload, "belief_text_agent_mock") await continuous_collector._handle_belief_text(payload, "belief_text_agent_mock")
logger_mock.info.assert_any_call("BeliefCollector: forwarding %d beliefs.", 1) logger_mock.info.assert_any_call("BeliefCollector: forwarding %d beliefs.", 1)
@@ -169,12 +195,14 @@ async def test_belief_text_happy_path_logs_items_and_sends(continuous_collector,
# make sure we attempted a send # make sure we attempted a send
continuous_collector.send.assert_awaited_once() continuous_collector.send.assert_awaited_once()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_send_beliefs_noop_on_empty(continuous_collector): async def test_send_beliefs_noop_on_empty(continuous_collector):
continuous_collector.send = AsyncMock() continuous_collector.send = AsyncMock()
await continuous_collector._send_beliefs_to_bdi([], origin="o") await continuous_collector._send_beliefs_to_bdi([], origin="o")
continuous_collector.send.assert_not_awaited() continuous_collector.send.assert_not_awaited()
# @pytest.mark.asyncio # @pytest.mark.asyncio
# async def test_send_beliefs_sends_json_packet(continuous_collector): # async def test_send_beliefs_sends_json_packet(continuous_collector):
# # Patch .send and capture the message body # # Patch .send and capture the message body
@@ -191,16 +219,19 @@ async def test_send_beliefs_noop_on_empty(continuous_collector):
# assert "belief_packet" in json.loads(sent["body"])["type"] # assert "belief_packet" in json.loads(sent["body"])["type"]
# assert json.loads(sent["body"])["beliefs"] == beliefs # assert json.loads(sent["body"])["beliefs"] == beliefs
def test_sender_node_no_sender_returns_literal(continuous_collector): def test_sender_node_no_sender_returns_literal(continuous_collector):
msg = MagicMock() msg = MagicMock()
msg.sender = None msg.sender = None
assert continuous_collector._sender_node(msg) == "no_sender" assert continuous_collector._sender_node(msg) == "no_sender"
def test_sender_node_without_at(continuous_collector): def test_sender_node_without_at(continuous_collector):
msg = MagicMock() msg = MagicMock()
msg.sender = "localpartonly" msg.sender = "localpartonly"
assert continuous_collector._sender_node(msg) == "localpartonly" assert continuous_collector._sender_node(msg) == "localpartonly"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_belief_text_coerces_non_strings(continuous_collector, mocker): async def test_belief_text_coerces_non_strings(continuous_collector, mocker):
payload = {"type": "belief_extraction_text", "beliefs": {"user_said": [["hi", 123]]}} payload = {"type": "belief_extraction_text", "beliefs": {"user_said": [["hi", 123]]}}