style: apply ruff check and format
Made sure all ruff checks pass and formatted all files. ref: N25B-224
This commit is contained in:
@@ -58,8 +58,8 @@ class BDICoreAgent(BDIAgent):
|
||||
class SendBehaviour(OneShotBehaviour):
|
||||
async def run(self) -> None:
|
||||
msg = Message(
|
||||
to= settings.agent_settings.llm_agent_name + '@' + settings.agent_settings.host,
|
||||
body= text
|
||||
to=settings.agent_settings.llm_agent_name + "@" + settings.agent_settings.host,
|
||||
body=text,
|
||||
)
|
||||
|
||||
await self.send(msg)
|
||||
|
||||
@@ -3,7 +3,7 @@ import logging
|
||||
|
||||
from spade.agent import Message
|
||||
from spade.behaviour import CyclicBehaviour
|
||||
from spade_bdi.bdi import BDIAgent, BeliefNotInitiated
|
||||
from spade_bdi.bdi import BDIAgent
|
||||
|
||||
from control_backend.core.config import settings
|
||||
|
||||
@@ -23,7 +23,6 @@ class BeliefSetterBehaviour(CyclicBehaviour):
|
||||
self.logger.info(f"Received message {msg.body}")
|
||||
self._process_message(msg)
|
||||
|
||||
|
||||
def _process_message(self, message: Message):
|
||||
sender = message.sender.node # removes host from jid and converts to str
|
||||
self.logger.debug("Sender: %s", sender)
|
||||
@@ -61,6 +60,7 @@ class BeliefSetterBehaviour(CyclicBehaviour):
|
||||
self.agent.bdi.set_belief(belief, *arguments)
|
||||
|
||||
# Special case: if there's a new user message, flag that we haven't responded yet
|
||||
if belief == "user_said": self.agent.bdi.set_belief("new_message")
|
||||
if belief == "user_said":
|
||||
self.agent.bdi.set_belief("new_message")
|
||||
|
||||
self.logger.info("Set belief %s with arguments %s", belief, arguments)
|
||||
|
||||
@@ -9,7 +9,9 @@ class ReceiveLLMResponseBehaviour(CyclicBehaviour):
|
||||
"""
|
||||
Adds behavior to receive responses from the LLM Agent.
|
||||
"""
|
||||
|
||||
logger = logging.getLogger("BDI/LLM Reciever")
|
||||
|
||||
async def run(self):
|
||||
msg = await self.receive(timeout=2)
|
||||
if not msg:
|
||||
@@ -20,7 +22,7 @@ class ReceiveLLMResponseBehaviour(CyclicBehaviour):
|
||||
case settings.agent_settings.llm_agent_name:
|
||||
content = msg.body
|
||||
self.logger.info("Received LLM response: %s", content)
|
||||
#Here the BDI can pass the message back as a response
|
||||
# Here the BDI can pass the message back as a response
|
||||
case _:
|
||||
self.logger.debug("Not from the llm, discarding message")
|
||||
pass
|
||||
@@ -13,28 +13,30 @@ class BeliefFromText(CyclicBehaviour):
|
||||
|
||||
# TODO: LLM prompt nog hardcoded
|
||||
llm_instruction_prompt = """
|
||||
You are an information extraction assistent for a BDI agent. Your task is to extract values from a user's text to bind a list of ungrounded beliefs. Rules:
|
||||
You will receive a JSON object with "beliefs" (a list of ungrounded AgentSpeak beliefs) and "text" (user's transcript).
|
||||
You are an information extraction assistent for a BDI agent. Your task is to extract values \
|
||||
from a user's text to bind a list of ungrounded beliefs. Rules:
|
||||
You will receive a JSON object with "beliefs" (a list of ungrounded AgentSpeak beliefs) \
|
||||
and "text" (user's transcript).
|
||||
Analyze the text to find values that sematically match the variables (X,Y,Z) in the beliefs.
|
||||
A single piece of text might contain multiple instances that match a belief.
|
||||
Respond ONLY with a single JSON object.
|
||||
The JSON object's keys should be the belief functors (e.g., "weather").
|
||||
The value for each key must be a list of lists.
|
||||
Each inner list must contain the extracted arguments (as strings) for one instance of that belief.
|
||||
CRITICAL: If no information in the text matches a belief, DO NOT include that key in your response.
|
||||
Each inner list must contain the extracted arguments (as strings) for one instance \
|
||||
of that belief.
|
||||
CRITICAL: If no information in the text matches a belief, DO NOT include that key \
|
||||
in your response.
|
||||
"""
|
||||
|
||||
# on_start agent receives message containing the beliefs to look out for and sets up the LLM with instruction prompt
|
||||
#async def on_start(self):
|
||||
# on_start agent receives message containing the beliefs to look out for and
|
||||
# sets up the LLM with instruction prompt
|
||||
# async def on_start(self):
|
||||
# msg = await self.receive(timeout=0.1)
|
||||
# self.beliefs = dict uit message
|
||||
# send instruction prompt to LLM
|
||||
|
||||
beliefs: dict[str, list[str]]
|
||||
beliefs = {
|
||||
"mood": ["X"],
|
||||
"car": ["Y"]
|
||||
}
|
||||
beliefs = {"mood": ["X"], "car": ["Y"]}
|
||||
|
||||
async def run(self):
|
||||
msg = await self.receive(timeout=0.1)
|
||||
@@ -58,8 +60,8 @@ class BeliefFromText(CyclicBehaviour):
|
||||
|
||||
prompt = text_prompt + beliefs_prompt
|
||||
self.logger.info(prompt)
|
||||
#prompt_msg = Message(to="LLMAgent@whatever")
|
||||
#response = self.send(prompt_msg)
|
||||
# prompt_msg = Message(to="LLMAgent@whatever")
|
||||
# response = self.send(prompt_msg)
|
||||
|
||||
# Mock response; response is beliefs in JSON format, it parses do dict[str,list[list[str]]]
|
||||
response = '{"mood": [["happy"]]}'
|
||||
@@ -67,8 +69,9 @@ class BeliefFromText(CyclicBehaviour):
|
||||
try:
|
||||
json.loads(response)
|
||||
belief_message = Message(
|
||||
to=settings.agent_settings.bdi_core_agent_name + '@' + settings.agent_settings.host,
|
||||
body=response)
|
||||
to=settings.agent_settings.bdi_core_agent_name + "@" + settings.agent_settings.host,
|
||||
body=response,
|
||||
)
|
||||
belief_message.thread = "beliefs"
|
||||
|
||||
await self.send(belief_message)
|
||||
@@ -85,9 +88,12 @@ class BeliefFromText(CyclicBehaviour):
|
||||
"""
|
||||
belief = {"beliefs": {"user_said": [txt]}, "type": "belief_extraction_text"}
|
||||
payload = json.dumps(belief)
|
||||
belief_msg = Message(to=settings.agent_settings.belief_collector_agent_name
|
||||
+ '@' + settings.agent_settings.host,
|
||||
body=payload)
|
||||
belief_msg = Message(
|
||||
to=settings.agent_settings.belief_collector_agent_name
|
||||
+ "@"
|
||||
+ settings.agent_settings.host,
|
||||
body=payload,
|
||||
)
|
||||
belief_msg.thread = "beliefs"
|
||||
|
||||
await self.send(belief_msg)
|
||||
|
||||
@@ -1,11 +1,14 @@
|
||||
import json
|
||||
import logging
|
||||
from spade.behaviour import CyclicBehaviour
|
||||
|
||||
from spade.agent import Message
|
||||
from spade.behaviour import CyclicBehaviour
|
||||
|
||||
from control_backend.core.config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ContinuousBeliefCollector(CyclicBehaviour):
|
||||
"""
|
||||
Continuously collects beliefs/emotions from extractor agents:
|
||||
@@ -17,7 +20,6 @@ class ContinuousBeliefCollector(CyclicBehaviour):
|
||||
if msg:
|
||||
await self._process_message(msg)
|
||||
|
||||
|
||||
async def _process_message(self, msg: Message):
|
||||
sender_node = self._sender_node(msg)
|
||||
|
||||
@@ -27,7 +29,9 @@ class ContinuousBeliefCollector(CyclicBehaviour):
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"BeliefCollector: failed to parse JSON from %s. Body=%r Error=%s",
|
||||
sender_node, msg.body, e
|
||||
sender_node,
|
||||
msg.body,
|
||||
e,
|
||||
)
|
||||
return
|
||||
|
||||
@@ -35,16 +39,21 @@ class ContinuousBeliefCollector(CyclicBehaviour):
|
||||
|
||||
# Prefer explicit 'type' field
|
||||
if msg_type == "belief_extraction_text" or sender_node == "belief_text_agent_mock":
|
||||
logger.info("BeliefCollector: message routed to _handle_belief_text (sender=%s)", sender_node)
|
||||
logger.info(
|
||||
"BeliefCollector: message routed to _handle_belief_text (sender=%s)", sender_node
|
||||
)
|
||||
await self._handle_belief_text(payload, sender_node)
|
||||
#This is not implemented yet, but we keep the structure for future use
|
||||
# This is not implemented yet, but we keep the structure for future use
|
||||
elif msg_type == "emotion_extraction_text" or sender_node == "emo_text_agent_mock":
|
||||
logger.info("BeliefCollector: message routed to _handle_emo_text (sender=%s)", sender_node)
|
||||
logger.info(
|
||||
"BeliefCollector: message routed to _handle_emo_text (sender=%s)", sender_node
|
||||
)
|
||||
await self._handle_emo_text(payload, sender_node)
|
||||
else:
|
||||
logger.info(
|
||||
"BeliefCollector: unrecognized message (sender=%s, type=%r). Ignoring.",
|
||||
sender_node, msg_type
|
||||
sender_node,
|
||||
msg_type,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@@ -56,13 +65,12 @@ class ContinuousBeliefCollector(CyclicBehaviour):
|
||||
s = str(msg.sender) if msg.sender is not None else "no_sender"
|
||||
return s.split("@", 1)[0] if "@" in s else s
|
||||
|
||||
|
||||
async def _handle_belief_text(self, payload: dict, origin: str):
|
||||
"""
|
||||
Expected payload:
|
||||
{
|
||||
"type": "belief_extraction_text",
|
||||
"beliefs": {"user_said": ["hello"","Can you help me?","stop talking to me","No","Pepper do a dance"]}
|
||||
"beliefs": {"user_said": ["Can you help me?"]}
|
||||
|
||||
}
|
||||
|
||||
@@ -84,17 +92,14 @@ class ContinuousBeliefCollector(CyclicBehaviour):
|
||||
logger.info("BeliefCollector: forwarding %d beliefs.", len(beliefs))
|
||||
for belief_name, belief_list in beliefs.items():
|
||||
for belief in belief_list:
|
||||
logger.info(" - %s %s", belief_name,str(belief))
|
||||
logger.info(" - %s %s", belief_name, str(belief))
|
||||
|
||||
await self._send_beliefs_to_bdi(beliefs, origin=origin)
|
||||
|
||||
|
||||
|
||||
async def _handle_emo_text(self, payload: dict, origin: str):
|
||||
"""TODO: implement (after we have emotional recogntion)"""
|
||||
pass
|
||||
|
||||
|
||||
async def _send_beliefs_to_bdi(self, beliefs: list[str], origin: str | None = None):
|
||||
"""
|
||||
Sends a unified belief packet to the BDI agent.
|
||||
@@ -107,6 +112,5 @@ class ContinuousBeliefCollector(CyclicBehaviour):
|
||||
msg = Message(to=to_jid, sender=self.agent.jid, thread="beliefs")
|
||||
msg.body = json.dumps(beliefs)
|
||||
|
||||
|
||||
await self.send(msg)
|
||||
logger.info("BeliefCollector: sent %d belief(s) to BDI at %s", len(beliefs), to_jid)
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
import logging
|
||||
|
||||
from spade.agent import Agent
|
||||
|
||||
from .behaviours.continuous_collect import ContinuousBeliefCollector
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BeliefCollectorAgent(Agent):
|
||||
async def setup(self):
|
||||
logger.info("BeliefCollectorAgent starting (%s)", self.jid)
|
||||
|
||||
@@ -65,8 +65,8 @@ class LLMAgent(Agent):
|
||||
Sends a response message back to the BDI Core Agent.
|
||||
"""
|
||||
reply = Message(
|
||||
to=settings.agent_settings.bdi_core_agent_name + '@' + settings.agent_settings.host,
|
||||
body=msg
|
||||
to=settings.agent_settings.bdi_core_agent_name + "@" + settings.agent_settings.host,
|
||||
body=msg,
|
||||
)
|
||||
await self.send(reply)
|
||||
self.agent.logger.info("Reply sent to BDI Core Agent")
|
||||
@@ -90,25 +90,21 @@ class LLMAgent(Agent):
|
||||
json={
|
||||
"model": settings.llm_settings.local_llm_model,
|
||||
"messages": [
|
||||
{
|
||||
"role": "developer",
|
||||
"content": developer_instruction
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": prompt
|
||||
}
|
||||
{"role": "developer", "content": developer_instruction},
|
||||
{"role": "user", "content": prompt},
|
||||
],
|
||||
"temperature": 0.3
|
||||
"temperature": 0.3,
|
||||
},
|
||||
)
|
||||
|
||||
try:
|
||||
response.raise_for_status()
|
||||
data: dict[str, Any] = response.json()
|
||||
return data.get("choices", [{}])[0].get(
|
||||
"message", {}
|
||||
).get("content", "No response")
|
||||
return (
|
||||
data.get("choices", [{}])[0]
|
||||
.get("message", {})
|
||||
.get("content", "No response")
|
||||
)
|
||||
except httpx.HTTPError as err:
|
||||
self.agent.logger.error("HTTP error: %s", err)
|
||||
return "LLM service unavailable."
|
||||
|
||||
@@ -1,18 +1,33 @@
|
||||
import json
|
||||
|
||||
from spade.agent import Agent
|
||||
from spade.behaviour import OneShotBehaviour
|
||||
from spade.message import Message
|
||||
|
||||
from control_backend.core.config import settings
|
||||
|
||||
|
||||
class BeliefTextAgent(Agent):
|
||||
class SendOnceBehaviourBlfText(OneShotBehaviour):
|
||||
async def run(self):
|
||||
to_jid = f"{settings.agent_settings.belief_collector_agent_name}@{settings.agent_settings.host}"
|
||||
to_jid = (
|
||||
settings.agent_settings.belief_collector_agent_name
|
||||
+ "@"
|
||||
+ settings.agent_settings.host
|
||||
)
|
||||
|
||||
# Send multiple beliefs in one JSON payload
|
||||
payload = {
|
||||
"type": "belief_extraction_text",
|
||||
"beliefs": {"user_said": ["hello test","Can you help me?","stop talking to me","No","Pepper do a dance"]}
|
||||
"beliefs": {
|
||||
"user_said": [
|
||||
"hello test",
|
||||
"Can you help me?",
|
||||
"stop talking to me",
|
||||
"No",
|
||||
"Pepper do a dance",
|
||||
]
|
||||
},
|
||||
}
|
||||
|
||||
msg = Message(to=to_jid)
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
import json
|
||||
import logging
|
||||
|
||||
import zmq
|
||||
from spade.agent import Agent
|
||||
from spade.behaviour import CyclicBehaviour
|
||||
import zmq
|
||||
|
||||
from control_backend.core.config import settings
|
||||
from control_backend.core.zmq_context import context
|
||||
|
||||
@@ -1,14 +1,13 @@
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
|
||||
import zmq
|
||||
from spade.agent import Agent
|
||||
from spade.behaviour import CyclicBehaviour
|
||||
import zmq
|
||||
|
||||
from control_backend.agents.ri_command_agent import RICommandAgent
|
||||
from control_backend.core.config import settings
|
||||
from control_backend.core.zmq_context import context
|
||||
from control_backend.schemas.message import Message
|
||||
from control_backend.agents.ri_command_agent import RICommandAgent
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -47,7 +46,7 @@ class RICommunicationAgent(Agent):
|
||||
message = await asyncio.wait_for(self.agent.req_socket.recv_json(), timeout=3.0)
|
||||
|
||||
# We didnt get a reply :(
|
||||
except asyncio.TimeoutError as e:
|
||||
except TimeoutError:
|
||||
logger.info("No ping retrieved in 3 seconds, killing myself.")
|
||||
self.kill()
|
||||
|
||||
@@ -88,7 +87,7 @@ class RICommunicationAgent(Agent):
|
||||
try:
|
||||
received_message = await asyncio.wait_for(self.req_socket.recv_json(), timeout=20.0)
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
except TimeoutError:
|
||||
logger.warning(
|
||||
"No connection established in 20 seconds (attempt %d/%d)",
|
||||
retries + 1,
|
||||
|
||||
@@ -75,7 +75,8 @@ class MLXWhisperSpeechRecognizer(SpeechRecognizer):
|
||||
self.model_name = "mlx-community/whisper-small.en-mlx"
|
||||
|
||||
def load_model(self):
|
||||
if self.was_loaded: return
|
||||
if self.was_loaded:
|
||||
return
|
||||
# There appears to be no dedicated mechanism to preload a model, but this `get_model` does
|
||||
# store it in memory for later usage
|
||||
ModelHolder.get_model(self.model_name, mx.float16)
|
||||
@@ -83,9 +84,9 @@ class MLXWhisperSpeechRecognizer(SpeechRecognizer):
|
||||
|
||||
def recognize_speech(self, audio: np.ndarray) -> str:
|
||||
self.load_model()
|
||||
return mlx_whisper.transcribe(audio,
|
||||
path_or_hf_repo=self.model_name,
|
||||
decode_options=self._get_decode_options(audio))["text"]
|
||||
return mlx_whisper.transcribe(
|
||||
audio, path_or_hf_repo=self.model_name, decode_options=self._get_decode_options(audio)
|
||||
)["text"]
|
||||
return mlx_whisper.transcribe(audio, path_or_hf_repo=self.model_name)["text"].strip()
|
||||
|
||||
|
||||
@@ -95,12 +96,13 @@ class OpenAIWhisperSpeechRecognizer(SpeechRecognizer):
|
||||
self.model = None
|
||||
|
||||
def load_model(self):
|
||||
if self.model is not None: return
|
||||
if self.model is not None:
|
||||
return
|
||||
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||
self.model = whisper.load_model("small.en", device=device)
|
||||
|
||||
def recognize_speech(self, audio: np.ndarray) -> str:
|
||||
self.load_model()
|
||||
return whisper.transcribe(self.model,
|
||||
audio,
|
||||
decode_options=self._get_decode_options(audio))["text"]
|
||||
return whisper.transcribe(
|
||||
self.model, audio, decode_options=self._get_decode_options(audio)
|
||||
)["text"]
|
||||
|
||||
@@ -47,7 +47,8 @@ class TranscriptionAgent(Agent):
|
||||
"""Share a transcription to the other agents that depend on it."""
|
||||
receiver_jids = [
|
||||
settings.agent_settings.text_belief_extractor_agent_name
|
||||
+ '@' + settings.agent_settings.host,
|
||||
+ "@"
|
||||
+ settings.agent_settings.host,
|
||||
] # Set message receivers here
|
||||
|
||||
for receiver_jid in receiver_jids:
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
from fastapi import APIRouter, Request
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter, Request
|
||||
from zmq import Socket
|
||||
|
||||
from control_backend.schemas.ri_message import SpeechCommand, RIEndpoint
|
||||
from control_backend.schemas.ri_message import SpeechCommand
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -18,5 +18,4 @@ async def receive_command(command: SpeechCommand, request: Request):
|
||||
pub_socket: Socket = request.app.state.internal_comm_socket
|
||||
pub_socket.send_multipart([topic, command.model_dump_json().encode()])
|
||||
|
||||
|
||||
return {"status": "Command received"}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from fastapi.routing import APIRouter
|
||||
|
||||
from control_backend.api.v1.endpoints import message, sse, command
|
||||
from control_backend.api.v1.endpoints import command, message, sse
|
||||
|
||||
api_router = APIRouter()
|
||||
|
||||
|
||||
@@ -24,6 +24,7 @@ class LLMSettings(BaseModel):
|
||||
local_llm_url: str = "http://localhost:1234/v1/chat/completions"
|
||||
local_llm_model: str = "openai/gpt-oss-20b"
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
app_title: str = "PepperPlus"
|
||||
|
||||
@@ -37,4 +38,5 @@ class Settings(BaseSettings):
|
||||
|
||||
model_config = SettingsConfigDict(env_file=".env")
|
||||
|
||||
|
||||
settings = Settings()
|
||||
|
||||
@@ -8,13 +8,14 @@ import zmq
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
# Internal imports
|
||||
from control_backend.agents.ri_communication_agent import RICommunicationAgent
|
||||
from control_backend.agents.bdi.bdi_core import BDICoreAgent
|
||||
from control_backend.agents.vad_agent import VADAgent
|
||||
from control_backend.agents.llm.llm import LLMAgent
|
||||
from control_backend.agents.bdi.text_extractor import TBeliefExtractor
|
||||
from control_backend.agents.belief_collector.belief_collector import BeliefCollectorAgent
|
||||
from control_backend.agents.llm.llm import LLMAgent
|
||||
|
||||
# Internal imports
|
||||
from control_backend.agents.ri_communication_agent import RICommunicationAgent
|
||||
from control_backend.agents.vad_agent import VADAgent
|
||||
from control_backend.api.v1.router import api_router
|
||||
from control_backend.core.config import settings
|
||||
from control_backend.core.zmq_context import context
|
||||
@@ -34,7 +35,6 @@ async def lifespan(app: FastAPI):
|
||||
app.state.internal_comm_socket = internal_comm_socket
|
||||
logger.info("Internal publishing socket bound to %s", internal_comm_socket)
|
||||
|
||||
|
||||
# Initiate agents
|
||||
ri_communication_agent = RICommunicationAgent(
|
||||
settings.agent_settings.ri_communication_agent_name + "@" + settings.agent_settings.host,
|
||||
@@ -45,26 +45,28 @@ async def lifespan(app: FastAPI):
|
||||
await ri_communication_agent.start()
|
||||
|
||||
llm_agent = LLMAgent(
|
||||
settings.agent_settings.llm_agent_name + '@' + settings.agent_settings.host,
|
||||
settings.agent_settings.llm_agent_name + "@" + settings.agent_settings.host,
|
||||
settings.agent_settings.llm_agent_name,
|
||||
)
|
||||
await llm_agent.start()
|
||||
|
||||
bdi_core = BDICoreAgent(
|
||||
settings.agent_settings.bdi_core_agent_name + '@' + settings.agent_settings.host,
|
||||
settings.agent_settings.bdi_core_agent_name + "@" + settings.agent_settings.host,
|
||||
settings.agent_settings.bdi_core_agent_name,
|
||||
"src/control_backend/agents/bdi/rules.asl",
|
||||
)
|
||||
await bdi_core.start()
|
||||
|
||||
belief_collector = BeliefCollectorAgent(
|
||||
settings.agent_settings.belief_collector_agent_name + '@' + settings.agent_settings.host,
|
||||
settings.agent_settings.belief_collector_agent_name + "@" + settings.agent_settings.host,
|
||||
settings.agent_settings.belief_collector_agent_name,
|
||||
)
|
||||
await belief_collector.start()
|
||||
|
||||
text_belief_extractor = TBeliefExtractor(
|
||||
settings.agent_settings.text_belief_extractor_agent_name + '@' + settings.agent_settings.host,
|
||||
settings.agent_settings.text_belief_extractor_agent_name
|
||||
+ "@"
|
||||
+ settings.agent_settings.host,
|
||||
settings.agent_settings.text_belief_extractor_agent_name,
|
||||
)
|
||||
await text_belief_extractor.start()
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from enum import Enum
|
||||
from typing import Any, Literal
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field, ValidationError
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class RIEndpoint(str, Enum):
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
import asyncio
|
||||
import zmq
|
||||
import json
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
import zmq
|
||||
|
||||
from control_backend.agents.ri_command_agent import RICommandAgent
|
||||
from control_backend.schemas.ri_message import SpeechCommand
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import asyncio
|
||||
from unittest.mock import ANY, AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch, ANY
|
||||
|
||||
from control_backend.agents.ri_communication_agent import RICommunicationAgent
|
||||
|
||||
|
||||
@@ -185,8 +187,8 @@ async def test_setup_creates_socket_and_negotiate_3(monkeypatch, caplog):
|
||||
|
||||
# Mock RICommandAgent agent startup
|
||||
|
||||
# We are sending wrong negotiation info to the communication agent, so we should retry and expect a
|
||||
# better response, within a limited time.
|
||||
# We are sending wrong negotiation info to the communication agent,
|
||||
# so we should retry and expect a better response, within a limited time.
|
||||
with patch(
|
||||
"control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True
|
||||
) as MockCommandAgent:
|
||||
@@ -358,8 +360,8 @@ async def test_setup_creates_socket_and_negotiate_7(monkeypatch, caplog):
|
||||
|
||||
# Mock RICommandAgent agent startup
|
||||
|
||||
# We are sending wrong negotiation info to the communication agent, so we should retry and expect a
|
||||
# better response, within a limited time.
|
||||
# We are sending wrong negotiation info to the communication agent,
|
||||
# so we should retry and expect a better response, within a limited time.
|
||||
with patch(
|
||||
"control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True
|
||||
) as MockCommandAgent:
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from control_backend.api.v1.endpoints import command
|
||||
from control_backend.schemas.ri_message import SpeechCommand
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
import pytest
|
||||
from control_backend.schemas.ri_message import RIMessage, RIEndpoint, SpeechCommand
|
||||
from pydantic import ValidationError
|
||||
|
||||
from control_backend.schemas.ri_message import RIEndpoint, RIMessage, SpeechCommand
|
||||
|
||||
|
||||
def valid_command_1():
|
||||
return SpeechCommand(data="Hallo?")
|
||||
@@ -13,24 +14,13 @@ def invalid_command_1():
|
||||
|
||||
def test_valid_speech_command_1():
|
||||
command = valid_command_1()
|
||||
try:
|
||||
RIMessage.model_validate(command)
|
||||
SpeechCommand.model_validate(command)
|
||||
assert True
|
||||
except ValidationError:
|
||||
assert False
|
||||
|
||||
|
||||
def test_invalid_speech_command_1():
|
||||
command = invalid_command_1()
|
||||
passed_ri_message_validation = False
|
||||
try:
|
||||
# Should succeed, still.
|
||||
RIMessage.model_validate(command)
|
||||
passed_ri_message_validation = True
|
||||
|
||||
# Should fail.
|
||||
with pytest.raises(ValidationError):
|
||||
SpeechCommand.model_validate(command)
|
||||
assert False
|
||||
except ValidationError:
|
||||
assert passed_ri_message_validation
|
||||
|
||||
@@ -203,6 +203,7 @@ def test_set_beliefs_success(belief_setter, mock_agent, caplog):
|
||||
assert "Set belief is_hot with arguments ['kitchen']" in caplog.text
|
||||
assert "Set belief door_opened with arguments ['front_door', 'back_door']" in caplog.text
|
||||
|
||||
|
||||
# def test_responded_unset(belief_setter, mock_agent):
|
||||
# # Arrange
|
||||
# new_beliefs = {"user_said": ["message"]}
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
import json
|
||||
import logging
|
||||
from unittest.mock import MagicMock, AsyncMock, call
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from control_backend.agents.belief_collector.behaviours.continuous_collect import ContinuousBeliefCollector
|
||||
from control_backend.agents.belief_collector.behaviours.continuous_collect import (
|
||||
ContinuousBeliefCollector,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_agent(mocker):
|
||||
@@ -13,6 +15,7 @@ def mock_agent(mocker):
|
||||
agent.jid = "belief_collector_agent@test"
|
||||
return agent
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def continuous_collector(mock_agent, mocker):
|
||||
"""Fixture to create an instance of ContinuousBeliefCollector with a mocked agent."""
|
||||
@@ -25,6 +28,7 @@ def continuous_collector(mock_agent, mocker):
|
||||
collector.receive = AsyncMock()
|
||||
return collector
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_no_message_received(continuous_collector, mocker):
|
||||
"""
|
||||
@@ -40,6 +44,7 @@ async def test_run_no_message_received(continuous_collector, mocker):
|
||||
# Assert
|
||||
continuous_collector._process_message.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_message_received(continuous_collector, mocker):
|
||||
"""
|
||||
@@ -56,6 +61,7 @@ async def test_run_message_received(continuous_collector, mocker):
|
||||
# Assert
|
||||
continuous_collector._process_message.assert_awaited_once_with(mock_msg)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_message_invalid(continuous_collector, mocker):
|
||||
"""
|
||||
@@ -67,7 +73,9 @@ async def test_process_message_invalid(continuous_collector, mocker):
|
||||
msg.body = invalid_json
|
||||
msg.sender = "belief_text_agent_mock@test"
|
||||
|
||||
logger_mock = mocker.patch("control_backend.agents.belief_collector.behaviours.continuous_collect.logger")
|
||||
logger_mock = mocker.patch(
|
||||
"control_backend.agents.belief_collector.behaviours.continuous_collect.logger"
|
||||
)
|
||||
|
||||
# Act
|
||||
await continuous_collector._process_message(msg)
|
||||
@@ -75,6 +83,7 @@ async def test_process_message_invalid(continuous_collector, mocker):
|
||||
# Assert
|
||||
logger_mock.warning.assert_called_once()
|
||||
|
||||
|
||||
def test_get_sender_from_message(continuous_collector):
|
||||
"""
|
||||
Test that _sender_node correctly extracts the sender node from the message JID.
|
||||
@@ -89,6 +98,7 @@ def test_get_sender_from_message(continuous_collector):
|
||||
# Assert
|
||||
assert sender_node == "agent_node"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_routes_to_handle_belief_text_by_type(continuous_collector, mocker):
|
||||
msg = MagicMock()
|
||||
@@ -98,6 +108,7 @@ async def test_routes_to_handle_belief_text_by_type(continuous_collector, mocker
|
||||
await continuous_collector._process_message(msg)
|
||||
spy.assert_awaited_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_routes_to_handle_belief_text_by_sender(continuous_collector, mocker):
|
||||
msg = MagicMock()
|
||||
@@ -107,6 +118,7 @@ async def test_routes_to_handle_belief_text_by_sender(continuous_collector, mock
|
||||
await continuous_collector._process_message(msg)
|
||||
spy.assert_awaited_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_routes_to_handle_emo_text(continuous_collector, mocker):
|
||||
msg = MagicMock()
|
||||
@@ -116,50 +128,64 @@ async def test_routes_to_handle_emo_text(continuous_collector, mocker):
|
||||
await continuous_collector._process_message(msg)
|
||||
spy.assert_awaited_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unrecognized_message_logs_info(continuous_collector, mocker):
|
||||
msg = MagicMock()
|
||||
msg.body = json.dumps({"type": "something_else"})
|
||||
msg.sender = "x@test"
|
||||
logger_mock = mocker.patch("control_backend.agents.belief_collector.behaviours.continuous_collect.logger")
|
||||
logger_mock = mocker.patch(
|
||||
"control_backend.agents.belief_collector.behaviours.continuous_collect.logger"
|
||||
)
|
||||
await continuous_collector._process_message(msg)
|
||||
logger_mock.info.assert_any_call(
|
||||
"BeliefCollector: unrecognized message (sender=%s, type=%r). Ignoring.", "x", "something_else"
|
||||
"BeliefCollector: unrecognized message (sender=%s, type=%r). Ignoring.",
|
||||
"x",
|
||||
"something_else",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_belief_text_no_beliefs(continuous_collector, mocker):
|
||||
msg_payload = {"type": "belief_extraction_text"} # no 'beliefs'
|
||||
logger_mock = mocker.patch("control_backend.agents.belief_collector.behaviours.continuous_collect.logger")
|
||||
logger_mock = mocker.patch(
|
||||
"control_backend.agents.belief_collector.behaviours.continuous_collect.logger"
|
||||
)
|
||||
await continuous_collector._handle_belief_text(msg_payload, "origin_node")
|
||||
logger_mock.info.assert_any_call("BeliefCollector: no beliefs to process.")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_belief_text_beliefs_not_dict(continuous_collector, mocker):
|
||||
payload = {"type": "belief_extraction_text", "beliefs": ["not", "a", "dict"]}
|
||||
logger_mock = mocker.patch("control_backend.agents.belief_collector.behaviours.continuous_collect.logger")
|
||||
logger_mock = mocker.patch(
|
||||
"control_backend.agents.belief_collector.behaviours.continuous_collect.logger"
|
||||
)
|
||||
await continuous_collector._handle_belief_text(payload, "origin")
|
||||
logger_mock.warning.assert_any_call("BeliefCollector: 'beliefs' is not a dict: %r", ["not", "a", "dict"])
|
||||
logger_mock.warning.assert_any_call(
|
||||
"BeliefCollector: 'beliefs' is not a dict: %r", ["not", "a", "dict"]
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_belief_text_values_not_lists(continuous_collector, mocker):
|
||||
payload = {"type": "belief_extraction_text", "beliefs": {"user_said": "not-a-list"}}
|
||||
logger_mock = mocker.patch("control_backend.agents.belief_collector.behaviours.continuous_collect.logger")
|
||||
logger_mock = mocker.patch(
|
||||
"control_backend.agents.belief_collector.behaviours.continuous_collect.logger"
|
||||
)
|
||||
await continuous_collector._handle_belief_text(payload, "origin")
|
||||
logger_mock.warning.assert_any_call(
|
||||
"BeliefCollector: 'beliefs' values are not all lists: %r", {"user_said": "not-a-list"}
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_belief_text_happy_path_logs_items_and_sends(continuous_collector, mocker):
|
||||
payload = {
|
||||
"type": "belief_extraction_text",
|
||||
"beliefs": {"user_said": ["hello test", "No"]}
|
||||
}
|
||||
# Your code calls self.send(..); patch it (or switch implementation to self.agent.send and patch that)
|
||||
payload = {"type": "belief_extraction_text", "beliefs": {"user_said": ["hello test", "No"]}}
|
||||
continuous_collector.send = AsyncMock()
|
||||
logger_mock = mocker.patch("control_backend.agents.belief_collector.behaviours.continuous_collect.logger")
|
||||
logger_mock = mocker.patch(
|
||||
"control_backend.agents.belief_collector.behaviours.continuous_collect.logger"
|
||||
)
|
||||
await continuous_collector._handle_belief_text(payload, "belief_text_agent_mock")
|
||||
|
||||
logger_mock.info.assert_any_call("BeliefCollector: forwarding %d beliefs.", 1)
|
||||
@@ -169,12 +195,14 @@ async def test_belief_text_happy_path_logs_items_and_sends(continuous_collector,
|
||||
# make sure we attempted a send
|
||||
continuous_collector.send.assert_awaited_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_beliefs_noop_on_empty(continuous_collector):
|
||||
continuous_collector.send = AsyncMock()
|
||||
await continuous_collector._send_beliefs_to_bdi([], origin="o")
|
||||
continuous_collector.send.assert_not_awaited()
|
||||
|
||||
|
||||
# @pytest.mark.asyncio
|
||||
# async def test_send_beliefs_sends_json_packet(continuous_collector):
|
||||
# # Patch .send and capture the message body
|
||||
@@ -191,16 +219,19 @@ async def test_send_beliefs_noop_on_empty(continuous_collector):
|
||||
# assert "belief_packet" in json.loads(sent["body"])["type"]
|
||||
# assert json.loads(sent["body"])["beliefs"] == beliefs
|
||||
|
||||
|
||||
def test_sender_node_no_sender_returns_literal(continuous_collector):
|
||||
msg = MagicMock()
|
||||
msg.sender = None
|
||||
assert continuous_collector._sender_node(msg) == "no_sender"
|
||||
|
||||
|
||||
def test_sender_node_without_at(continuous_collector):
|
||||
msg = MagicMock()
|
||||
msg.sender = "localpartonly"
|
||||
assert continuous_collector._sender_node(msg) == "localpartonly"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_belief_text_coerces_non_strings(continuous_collector, mocker):
|
||||
payload = {"type": "belief_extraction_text", "beliefs": {"user_said": [["hi", 123]]}}
|
||||
|
||||
@@ -6,7 +6,7 @@ from control_backend.agents.transcription.speech_recognizer import OpenAIWhisper
|
||||
|
||||
def test_estimate_max_tokens():
|
||||
"""Inputting one minute of audio, assuming 300 words per minute, expecting 400 tokens."""
|
||||
audio = np.empty(shape=(60*16_000), dtype=np.float32)
|
||||
audio = np.empty(shape=(60 * 16_000), dtype=np.float32)
|
||||
|
||||
actual = SpeechRecognizer._estimate_max_tokens(audio)
|
||||
|
||||
@@ -16,7 +16,7 @@ def test_estimate_max_tokens():
|
||||
|
||||
def test_get_decode_options():
|
||||
"""Check whether the right decode options are given under different scenarios."""
|
||||
audio = np.empty(shape=(60*16_000), dtype=np.float32)
|
||||
audio = np.empty(shape=(60 * 16_000), dtype=np.float32)
|
||||
|
||||
# With the defaults, it should limit output length based on input size
|
||||
recognizer = OpenAIWhisperSpeechRecognizer()
|
||||
|
||||
Reference in New Issue
Block a user