Compare commits


1 Commit

Author: Twirre Meulenbelt
SHA1: f12a57248a
Message: fix: do not double JSON encode
         ref: N25B-242
Date: 2025-11-05 16:35:12 +01:00
22 changed files with 54 additions and 535 deletions
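For context on the commit title: "double JSON encode" refers to serializing a payload that is already a JSON string a second time. A minimal standard-library illustration of the effect (not code from this repository):

import json

payload = {"text": "hello"}

once = json.dumps(payload)   # the characters {"text": "hello"}
twice = json.dumps(once)     # the characters "{\"text\": \"hello\"}" -- JSON of a string, not of an object

print(json.loads(once))      # {'text': 'hello'} -- a dict, as intended
print(json.loads(twice))     # {"text": "hello"} -- still a string; a second loads is needed to get the dict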

View File

@@ -1,14 +0,0 @@
.git
.venv
__pycache__/
*.pyc
.dockerignore
Dockerfile
README.md
.gitlab-ci.yml
.gitignore
.pre-commit-config.yaml
.githooks/
test/
.pytest_cache/
.ruff_cache/

View File

@@ -30,7 +30,7 @@ HEADER=$(head -n 1 "$COMMIT_MSG_FILE")
 # Check for Merge commits (covers 'git merge' and PR merges from GitHub/GitLab)
 # Examples: "Merge branch 'main' into ...", "Merge pull request #123 from ..."
-MERGE_PATTERN="^Merge (remote-tracking )?(branch|pull request|tag) .*"
+MERGE_PATTERN="^Merge (branch|pull request|tag) .*"
 if [[ "$HEADER" =~ $MERGE_PATTERN ]]; then
     echo -e "${GREEN}Merge commit detected by message content. Skipping validation.${NC}"
     exit 0
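The only change above is dropping the optional "remote-tracking " group from the pattern. A quick check of which headers each variant accepts — a Python stand-in for the hook's bash [[ =~ ]] test, for illustration only:

import re

old_pattern = r"^Merge (remote-tracking )?(branch|pull request|tag) .*"
new_pattern = r"^Merge (branch|pull request|tag) .*"

for header in [
    "Merge branch 'main' into feature/x",
    "Merge pull request #123 from org/repo",
    "Merge remote-tracking branch 'origin/main'",
]:
    print(header, "|", bool(re.match(old_pattern, header)), bool(re.match(new_pattern, header)))

# Only the remote-tracking header differs: it matches the old pattern but not the new one,
# so such merge commits would no longer be skipped by the hook.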

View File

@@ -1,21 +0,0 @@
# Debian based image
FROM ghcr.io/astral-sh/uv:0.9.8-trixie-slim
WORKDIR /app
ENV VIRTUAL_ENV=/app/.venv
ENV PATH="$VIRTUAL_ENV/bin:$PATH"
RUN apt-get update && apt-get install -y gcc=4:14.2.0-1 portaudio19-dev && apt-get install -y ca-certificates && rm -rf /var/lib/apt/lists/*
COPY pyproject.toml uv.lock .python-version ./
RUN uv sync
COPY . .
EXPOSE 8000
ENV PYTHONPATH=src
CMD [ "fastapi", "run", "src/control_backend/main.py" ]

View File

@@ -1,4 +1,3 @@
-import json
 import logging
 import agentspeak
@@ -38,66 +37,28 @@ class BDICoreAgent(BDIAgent):
 Registers custom AgentSpeak actions callable from plans.
 """
-@actions.add(".reply", 3)
+@actions.add(".reply", 1)
 def _reply(agent: "BDICoreAgent", term, intention):
     """
     Sends text to the LLM (AgentSpeak action).
     Example: .reply("Hello LLM!")
     """
     message_text = agentspeak.grounded(term.args[0], intention.scope)
-    norms = agentspeak.grounded(term.args[1], intention.scope)
-    goals = agentspeak.grounded(term.args[2], intention.scope)
-    self.logger.debug("Norms: %s", norms)
-    self.logger.debug("Goals: %s", goals)
-    self.logger.debug("User text: %s", message_text)
-    self._send_to_llm(str(message_text), str(norms), str(goals))
+    self.logger.debug("Reply action sending: %s", message_text)
+    self._send_to_llm(str(message_text))
     yield
-@actions.add(".reply_no_norms", 2)
-def _reply_no_norms(agent: "BDICoreAgent", term, intention):
-    message_text = agentspeak.grounded(term.args[0], intention.scope)
-    goals = agentspeak.grounded(term.args[1], intention.scope)
-    self.logger.debug("Goals: %s", goals)
-    self.logger.debug("User text: %s", message_text)
-    self._send_to_llm(str(message_text), goals=str(goals))
-@actions.add(".reply_no_goals", 2)
-def _reply_no_goals(agent: "BDICoreAgent", term, intention):
-    message_text = agentspeak.grounded(term.args[0], intention.scope)
-    norms = agentspeak.grounded(term.args[1], intention.scope)
-    self.logger.debug("Norms: %s", norms)
-    self.logger.debug("User text: %s", message_text)
-    self._send_to_llm(str(message_text), norms=str(norms))
-@actions.add(".reply_no_goals_no_norms", 1)
-def _reply_no_goals_no_norms(agent: "BDICoreAgent", term, intention):
-    message_text = agentspeak.grounded(term.args[0], intention.scope)
-    self.logger.debug("User text: %s", message_text)
-    self._send_to_llm(message_text)
-def _send_to_llm(self, text: str, norms: str = None, goals: str = None):
+def _send_to_llm(self, text: str):
     """
     Sends a text query to the LLM Agent asynchronously.
     """
     class SendBehaviour(OneShotBehaviour):
         async def run(self) -> None:
-            message_dict = {
-                "text": text,
-                "norms": norms if norms else "",
-                "goals": goals if goals else "",
-            }
             msg = Message(
                 to=settings.agent_settings.llm_agent_name + "@" + settings.agent_settings.host,
-                body=json.dumps(message_dict),
+                body=text,
             )
             await self.send(msg)
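A note on the 3 -> 1 change above: in python-agentspeak the second argument to actions.add is the action's arity, i.e. how many terms a plan passes to the action, so plans now call .reply(Message) instead of .reply(Message, Norms, Goals). A minimal sketch of registering such an action (assumed usage of the python-agentspeak API, not repository code):

import agentspeak
import agentspeak.stdlib

actions = agentspeak.Actions(agentspeak.stdlib.actions)

@actions.add(".reply", 1)                     # arity 1: plans call .reply(Message)
def _reply(agent, term, intention):
    text = agentspeak.grounded(term.args[0], intention.scope)
    print("would forward to the LLM:", text)  # stand-in for _send_to_llm
    yield                                     # signals success to the interpreter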

View File

@@ -1,27 +0,0 @@
import zmq
from zmq.asyncio import Context
from control_backend.agents import BaseAgent
from control_backend.core.config import settings
from .receive_programs_behavior import ReceiveProgramsBehavior
class BDIProgramManager(BaseAgent):
"""
Will interpret programs received from the HTTP endpoint. Extracts norms, goals, triggers and
forwards them to the BDI as beliefs.
"""
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.sub_socket = None
async def setup(self):
context = Context.instance()
self.sub_socket = context.socket(zmq.SUB)
self.sub_socket.connect(settings.zmq_settings.internal_sub_address)
self.sub_socket.subscribe("program")
self.add_behaviour(ReceiveProgramsBehavior())

View File

@@ -1,59 +0,0 @@
import json
from pydantic import ValidationError
from spade.behaviour import CyclicBehaviour
from spade.message import Message
from control_backend.core.config import settings
from control_backend.schemas.program import Program
class ReceiveProgramsBehavior(CyclicBehaviour):
async def _receive(self) -> Program | None:
topic, body = await self.agent.sub_socket.recv_multipart()
try:
return Program.model_validate_json(body)
except ValidationError as e:
self.agent.logger.error("Received an invalid program.", exc_info=e)
return None
def _extract_norms(self, program: Program) -> str:
"""First phase only for now, as a single newline delimited string."""
if not program.phases:
return ""
if not program.phases[0].phaseData.norms:
return ""
norm_values = [norm.value for norm in program.phases[0].phaseData.norms]
return "\n".join(norm_values)
def _extract_goals(self, program: Program) -> str:
"""First phase only for now, as a single newline delimited string."""
if not program.phases:
return ""
if not program.phases[0].phaseData.goals:
return ""
goal_descriptions = [goal.description for goal in program.phases[0].phaseData.goals]
return "\n".join(goal_descriptions)
async def _send_to_bdi(self, program: Program):
temp_allowed_parts = {
"norms": [self._extract_norms(program)],
"goals": [self._extract_goals(program)],
}
message = Message(
to=settings.agent_settings.bdi_core_agent_name + "@" + settings.agent_settings.host,
sender=self.agent.jid,
body=json.dumps(temp_allowed_parts),
thread="beliefs",
)
await self.send(message)
self.agent.logger.debug("Sent new norms and goals to the BDI agent.")
async def run(self):
program = await self._receive()
if not program:
return
await self._send_to_bdi(program)

View File

@@ -17,9 +17,7 @@ class BeliefSetterBehaviour(CyclicBehaviour):
 async def run(self):
     """Polls for messages and processes them."""
-    msg = await self.receive(timeout=1)
-    if not msg:
-        return
+    msg = await self.receive()
     self.agent.logger.debug(
         "Received message from %s with thread '%s' and body: %s",
         msg.sender,
@@ -39,13 +37,8 @@ class BeliefSetterBehaviour(CyclicBehaviour):
"Message is from the belief collector agent. Processing as belief message." "Message is from the belief collector agent. Processing as belief message."
) )
self._process_belief_message(message) self._process_belief_message(message)
case settings.agent_settings.program_manager_agent_name:
self.agent.logger.debug(
"Processing message from the program manager. Processing as belief message."
)
self._process_belief_message(message)
case _: case _:
self.agent.logger.debug("Not from expected agents, discarding message") self.agent.logger.debug("Not the belief agent, discarding message")
pass pass
def _process_belief_message(self, message: Message): def _process_belief_message(self, message: Message):

View File

@@ -11,9 +11,7 @@ class ReceiveLLMResponseBehaviour(CyclicBehaviour):
""" """
async def run(self): async def run(self):
msg = await self.receive(timeout=1) msg = await self.receive()
if not msg:
return
sender = msg.sender.node sender = msg.sender.node
match sender: match sender:

View File

@@ -1,5 +1,4 @@
 import json
-import logging
 from spade.behaviour import CyclicBehaviour
 from spade.message import Message
@@ -8,8 +7,6 @@ from control_backend.core.config import settings
 class BeliefFromText(CyclicBehaviour):
-    logger = logging.getLogger(__name__)
     # TODO: LLM prompt nog hardcoded
     llm_instruction_prompt = """
 You are an information extraction assistent for a BDI agent. Your task is to extract values \
@@ -38,10 +35,7 @@ class BeliefFromText(CyclicBehaviour):
 beliefs = {"mood": ["X"], "car": ["Y"]}
 async def run(self):
-    msg = await self.receive(timeout=1)
-    if not msg:
-        return
+    msg = await self.receive()
     sender = msg.sender.node
     match sender:
         case settings.agent_settings.transcription_agent_name:
@@ -68,14 +62,10 @@ class BeliefFromText(CyclicBehaviour):
 # Verify by trying to parse
 try:
     json.loads(response)
-    belief_message = Message()
-    belief_message.to = (
-        settings.agent_settings.belief_collector_agent_name
-        + "@"
-        + settings.agent_settings.host
-    )
-    belief_message.body = response
+    belief_message = Message(
+        to=settings.agent_settings.bdi_core_agent_name + "@" + settings.agent_settings.host,
+        body=response,
+    )
     belief_message.thread = "beliefs"
     await self.send(belief_message)
@@ -92,12 +82,12 @@ class BeliefFromText(CyclicBehaviour):
""" """
belief = {"beliefs": {"user_said": [txt]}, "type": "belief_extraction_text"} belief = {"beliefs": {"user_said": [txt]}, "type": "belief_extraction_text"}
payload = json.dumps(belief) payload = json.dumps(belief)
belief_msg = Message() belief_msg = Message(
to=settings.agent_settings.belief_collector_agent_name
belief_msg.to = ( + "@"
settings.agent_settings.belief_collector_agent_name + "@" + settings.agent_settings.host + settings.agent_settings.host,
body=payload,
) )
belief_msg.body = payload
belief_msg.thread = "beliefs" belief_msg.thread = "beliefs"
await self.send(belief_msg) await self.send(belief_msg)

View File

@@ -1,18 +1,3 @@
norms(""). +new_message : user_said(Message) <-
goals("").
+new_message : user_said(Message) & norms(Norms) & goals(Goals) <-
-new_message; -new_message;
.reply(Message, Norms, Goals). .reply(Message).
// +new_message : user_said(Message) & norms(Norms) <-
// -new_message;
// .reply_no_goals(Message, Norms).
//
// +new_message : user_said(Message) & goals(Goals) <-
// -new_message;
// .reply_no_norms(Message, Goals).
//
// +new_message : user_said(Message) <-
// -new_message;
// .reply_no_goals_no_norms(Message).

View File

@@ -14,9 +14,7 @@ class ContinuousBeliefCollector(CyclicBehaviour):
""" """
async def run(self): async def run(self):
msg = await self.receive(timeout=1) msg = await self.receive()
if not msg:
return
await self._process_message(msg) await self._process_message(msg)
async def _process_message(self, msg: Message): async def _process_message(self, msg: Message):

View File

@@ -30,9 +30,7 @@ class LLMAgent(BaseAgent):
 Receives SPADE messages and processes only those originating from the
 configured BDI agent.
 """
-msg = await self.receive(timeout=1)
-if not msg:
-    return
+msg = await self.receive()
 sender = msg.sender.node
 self.agent.logger.debug(
@@ -52,13 +50,9 @@ class LLMAgent(BaseAgent):
 Forwards user text from the BDI to the LLM and replies with the generated text in chunks
 separated by punctuation.
 """
-try:
-    message = json.loads(message.body)
-except json.JSONDecodeError:
-    self.agent.logger.error("Could not process BDI message.", exc_info=True)
+user_text = message.body
 # Consume the streaming generator and send a reply for every chunk
-async for chunk in self._query_llm(message["text"], message["norms"], message["goals"]):
+async for chunk in self._query_llm(user_text):
     await self._reply(chunk)
 self.agent.logger.debug(
     "Finished processing BDI message. Response sent in chunks to BDI Core Agent."
@@ -74,7 +68,7 @@ class LLMAgent(BaseAgent):
 )
 await self.send(reply)
-async def _query_llm(self, prompt: str, norms: str, goals: str) -> AsyncGenerator[str]:
+async def _query_llm(self, prompt: str) -> AsyncGenerator[str]:
     """
     Sends a chat completion request to the local LLM service and streams the response by
     yielding fragments separated by punctuation like.
@@ -82,7 +76,15 @@ class LLMAgent(BaseAgent):
 :param prompt: Input text prompt to pass to the LLM.
 :yield: Fragments of the LLM-generated content.
 """
-instructions = LLMInstructions(norms if norms else None, goals if goals else None)
+instructions = LLMInstructions(
+    "- Be friendly and respectful.\n"
+    "- Make the conversation feel natural and engaging.\n"
+    "- Speak like a pirate.\n"
+    "- When the user asks what you can do, tell them.",
+    "- Try to learn the user's name during conversation.\n"
+    "- Suggest playing a game of asking yes or no questions where you think of a word "
+    "and the user must guess it.",
+)
 messages = [
     {
         "role": "developer",

View File

@@ -6,7 +6,10 @@ class LLMInstructions:
 @staticmethod
 def default_norms() -> str:
-    return "Be friendly and respectful.\nMake the conversation feel natural and engaging."
+    return """
+Be friendly and respectful.
+Make the conversation feel natural and engaging.
+""".strip()
 @staticmethod
 def default_goals() -> str:
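The triple-quoted form with .strip() yields the same string as the old "\n"-joined literal as long as the inner lines are not indented; .strip() only trims the outer newlines. A small check (illustrative only):

a = "Be friendly and respectful.\nMake the conversation feel natural and engaging."
b = """
Be friendly and respectful.
Make the conversation feel natural and engaging.
""".strip()
assert a == b
# If the inner lines were indented, that indentation would survive .strip();
# textwrap.dedent would be needed to remove it.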

View File

@@ -54,9 +54,7 @@ class RICommandAgent(BaseAgent):
"""Behaviour for sending commands received from other Python agents.""" """Behaviour for sending commands received from other Python agents."""
async def run(self): async def run(self):
message: spade.agent.Message = await self.receive(timeout=1) message: spade.agent.Message = await self.receive(timeout=0.1)
if not message:
return
if message and message.to == self.agent.jid: if message and message.to == self.agent.jid:
try: try:
speech_command = SpeechCommand.model_validate_json(message.body) speech_command = SpeechCommand.model_validate_json(message.body)

View File

@@ -21,13 +21,10 @@ class RICommunicationAgent(BaseAgent):
 password: str,
 port: int = 5222,
 verify_security: bool = False,
-address=None,
-bind=True,
+address="tcp://localhost:0000",
+bind=False,
 ):
     super().__init__(jid, password, port, verify_security)
-    if not address:
-        self.logger.critical("No address set for negotiations.")
-        raise Exception # TODO: improve
     self._address = address
     self._bind = bind
@@ -122,7 +119,10 @@ class RICommunicationAgent(BaseAgent):
 port = port_data["port"]
 bind = port_data["bind"]
-addr = f"tcp://{settings.zmq_settings.external_host}:{port}"
+if not bind:
+    addr = f"tcp://localhost:{port}"
+else:
+    addr = f"tcp://*:{port}"
 match id:
     case "main":

View File

@@ -12,9 +12,10 @@ logger = logging.getLogger(__name__)
 router = APIRouter()
-# DO NOT LOG INSIDE THIS FUNCTION
 @router.get("/logs/stream")
 async def log_stream():
+    # DO NOT LOG in this function, or you'll get recursive logs. If you need it for debugging, use
+    # the built-in `print()`
     context = Context.instance()
     socket = context.socket(zmq.SUB)

View File

@@ -1,35 +0,0 @@
import logging
from fastapi import APIRouter, HTTPException, Request
from pydantic import ValidationError
from control_backend.schemas.message import Message
from control_backend.schemas.program import Program
logger = logging.getLogger(__name__)
router = APIRouter()
@router.post("/program", status_code=202)
async def receive_message(program: Message, request: Request):
"""
Receives a BehaviorProgram as a stringified JSON list inside `message`.
Converts it into real Phase objects.
"""
logger.debug("Received raw program: %s", program)
raw_str = program.message # This is the JSON string
# Validate program
try:
program = Program.model_validate_json(raw_str)
except ValidationError as e:
logger.error("Failed to validate program JSON: %s", e)
raise HTTPException(status_code=400, detail="Not a valid program") from None
# send away
topic = b"program"
body = program.model_dump_json().encode()
pub_socket = request.app.state.endpoints_pub_socket
await pub_socket.send_multipart([topic, body])
return {"status": "Program parsed"}

View File

@@ -1,6 +1,6 @@
 from fastapi.routing import APIRouter
-from control_backend.api.v1.endpoints import command, logs, message, program, sse
+from control_backend.api.v1.endpoints import command, logs, message, sse
 api_router = APIRouter()
@@ -11,5 +11,3 @@ api_router.include_router(sse.router, tags=["SSE"])
 api_router.include_router(command.router, tags=["Commands"])
 api_router.include_router(logs.router, tags=["Logs"])
-api_router.include_router(program.router, tags=["Program"])

View File

@@ -1,5 +1,3 @@
-import os
 from pydantic import BaseModel
 from pydantic_settings import BaseSettings, SettingsConfigDict
@@ -8,11 +6,9 @@ class ZMQSettings(BaseModel):
     internal_pub_address: str = "tcp://localhost:5560"
     internal_sub_address: str = "tcp://localhost:5561"
-    external_host: str = "0.0.0.0"
 class AgentSettings(BaseModel):
-    host: str = os.environ.get("XMPP_HOST", "localhost")
+    host: str = "localhost"
     bdi_core_agent_name: str = "bdi_core"
     belief_collector_agent_name: str = "belief_collector"
     text_belief_extractor_agent_name: str = "text_belief_extractor"
@@ -20,15 +16,14 @@ class AgentSettings(BaseModel):
     llm_agent_name: str = "llm_agent"
     test_agent_name: str = "test_agent"
     transcription_agent_name: str = "transcription_agent"
-    program_manager_agent_name: str = "program_manager"
     ri_communication_agent_name: str = "ri_communication_agent"
     ri_command_agent_name: str = "ri_command_agent"
 class LLMSettings(BaseModel):
-    local_llm_url: str = os.environ.get("LLM_URL", "http://localhost:1234/v1/") + "chat/completions"
-    local_llm_model: str = os.environ.get("LLM_MODEL", "openai/gpt-oss-20b")
+    local_llm_url: str = "http://localhost:1234/v1/chat/completions"
+    local_llm_model: str = "openai/gpt-oss-20b"
 class Settings(BaseSettings):

View File

@@ -1,6 +1,5 @@
 import contextlib
 import logging
-import os
 import threading
 import zmq
@@ -15,7 +14,6 @@ from control_backend.agents import (
 VADAgent,
 )
 from control_backend.agents.bdi import BDICoreAgent, TBeliefExtractorAgent
-from control_backend.agents.bdi.bdi_program_manager.bdi_program_manager import BDIProgramManager
 from control_backend.api.v1.router import api_router
 from control_backend.core.config import settings
 from control_backend.logging import setup_logging
@@ -50,9 +48,7 @@ async def lifespan(app: FastAPI):
 # --- APPLICATION STARTUP ---
 setup_logging()
 logger.info("%s is starting up.", app.title)
-logger.info(
-    "LLM_URL: %s, LLM_MODEL: %s", os.environ.get("LLM_URL"), os.environ.get("LLM_MODEL")
-)
+logger.warning("testing extra", extra={"extra1": "one", "extra2": "two"})
 # Initiate sockets
 proxy_thread = threading.Thread(target=setup_sockets)
@@ -75,7 +71,7 @@ async def lifespan(app: FastAPI):
"jid": f"{settings.agent_settings.ri_communication_agent_name}" "jid": f"{settings.agent_settings.ri_communication_agent_name}"
f"@{settings.agent_settings.host}", f"@{settings.agent_settings.host}",
"password": settings.agent_settings.ri_communication_agent_name, "password": settings.agent_settings.ri_communication_agent_name,
"address": f"tcp://{settings.zmq_settings.external_host}:5555", "address": "tcp://*:5555",
"bind": True, "bind": True,
}, },
), ),
@@ -117,39 +113,21 @@ async def lifespan(app: FastAPI):
 ),
 "VADAgent": (
     VADAgent,
-    {
-        "audio_in_address": f"tcp://{settings.zmq_settings.external_host}:5558",
-        "audio_in_bind": True,
-    },
-),
-"ProgramManager": (
-    BDIProgramManager,
-    {
-        "name": settings.agent_settings.program_manager_agent_name,
-        "jid": f"{settings.agent_settings.program_manager_agent_name}@"
-        f"{settings.agent_settings.host}",
-        "password": settings.agent_settings.program_manager_agent_name,
-    },
+    {"audio_in_address": "tcp://localhost:5558", "audio_in_bind": False},
 ),
 }
-vad_agent_instance = None
 for name, (agent_class, kwargs) in agents_to_start.items():
     try:
         logger.debug("Starting agent: %s", name)
         agent_instance = agent_class(**{k: v for k, v in kwargs.items() if k != "name"})
         await agent_instance.start()
-        if isinstance(agent_instance, VADAgent):
-            vad_agent_instance = agent_instance
         logger.info("Agent '%s' started successfully.", name)
     except Exception as e:
         logger.error("Failed to start agent '%s': %s", name, e, exc_info=True)
         # Consider if the application should continue if an agent fails to start.
         raise
-await vad_agent_instance.streaming_behaviour.reset()
 logger.info("Application startup complete.")
 yield

View File

@@ -1,38 +0,0 @@
from pydantic import BaseModel
class Norm(BaseModel):
id: str
name: str
value: str
class Goal(BaseModel):
id: str
name: str
description: str
achieved: bool
class Trigger(BaseModel):
id: str
label: str
type: str
value: list[str]
class PhaseData(BaseModel):
norms: list[Norm]
goals: list[Goal]
triggers: list[Trigger]
class Phase(BaseModel):
id: str
name: str
nextPhaseId: str
phaseData: PhaseData
class Program(BaseModel):
phases: list[Phase]

View File

@@ -1,187 +0,0 @@
import json
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from spade.message import Message
from control_backend.agents.bdi.behaviours.text_belief_extractor import BeliefFromText
@pytest.fixture
def mock_settings():
"""
Mocks the settings object that the behaviour imports.
We patch it at the source where it's imported by the module under test.
"""
# Create a mock object that mimics the nested structure
settings_mock = MagicMock()
settings_mock.agent_settings.transcription_agent_name = "transcriber"
settings_mock.agent_settings.belief_collector_agent_name = "collector"
settings_mock.agent_settings.host = "fake.host"
# Use patch to replace the settings object during the test
# Adjust 'control_backend.behaviours.belief_from_text.settings' to where
# your behaviour file imports it from.
with patch(
"control_backend.agents.bdi.behaviours.text_belief_extractor.settings", settings_mock
):
yield settings_mock
@pytest.fixture
def behavior(mock_settings):
"""
Creates an instance of the BeliefFromText behaviour and mocks its
agent, logger, send, and receive methods.
"""
b = BeliefFromText()
b.agent = MagicMock()
b.send = AsyncMock()
b.receive = AsyncMock()
return b
def create_mock_message(sender_node: str, body: str, thread: str) -> MagicMock:
"""Helper function to create a configured mock message."""
msg = MagicMock()
msg.sender.node = sender_node # MagicMock automatically creates nested mocks
msg.body = body
msg.thread = thread
return msg
@pytest.mark.asyncio
async def test_run_no_message(behavior):
"""
Tests the run() method when no message is received.
"""
# Arrange: Configure receive to return None
behavior.receive.return_value = None
# Act: Run the behavior
await behavior.run()
# Assert
# 1. Check that receive was called
behavior.receive.assert_called_once()
# 2. Check that no message was sent
behavior.send.assert_not_called()
@pytest.mark.asyncio
async def test_run_message_from_other_agent(behavior):
"""
Tests the run() method when a message is received from an
unknown agent (not the transcriber).
"""
# Arrange: Create a mock message from an unknown sender
mock_msg = create_mock_message("unknown", "some data", None)
behavior.receive.return_value = mock_msg
behavior._process_transcription_demo = MagicMock()
# Act
await behavior.run()
# Assert
# 1. Check that receive was called
behavior.receive.assert_called_once()
# 2. Check that _process_transcription_demo was not sent
behavior._process_transcription_demo.assert_not_called()
@pytest.mark.asyncio
async def test_run_message_from_transcriber_demo(behavior, mock_settings, monkeypatch):
"""
Tests the main success path: receiving a message from the
transcription agent, which triggers _process_transcription_demo.
"""
# Arrange: Create a mock message from the transcriber
transcription_text = "hello world"
mock_msg = create_mock_message(
mock_settings.agent_settings.transcription_agent_name, transcription_text, None
)
behavior.receive.return_value = mock_msg
# Act
await behavior.run()
# Assert
# 1. Check that receive was called
behavior.receive.assert_called_once()
# 2. Check that send was called *once*
behavior.send.assert_called_once()
# 3. Deeply inspect the message that was sent
sent_msg: Message = behavior.send.call_args[0][0]
assert (
sent_msg.to
== mock_settings.agent_settings.belief_collector_agent_name
+ "@"
+ mock_settings.agent_settings.host
)
# Check thread
assert sent_msg.thread == "beliefs"
# Parse the received JSON string back into a dict
expected_dict = {
"beliefs": {"user_said": [transcription_text]},
"type": "belief_extraction_text",
}
sent_dict = json.loads(sent_msg.body)
# Assert that the dictionaries are equal
assert sent_dict == expected_dict
@pytest.mark.asyncio
async def test_process_transcription_success(behavior, mock_settings):
"""
Tests the (currently unused) _process_transcription method's
success path, using its hardcoded mock response.
"""
# Arrange
test_text = "I am feeling happy"
# This is the hardcoded response inside the method
expected_response_body = '{"mood": [["happy"]]}'
# Act
await behavior._process_transcription(test_text)
# Assert
# 1. Check that a message was sent
behavior.send.assert_called_once()
# 2. Inspect the sent message
sent_msg: Message = behavior.send.call_args[0][0]
expected_to = (
mock_settings.agent_settings.belief_collector_agent_name
+ "@"
+ mock_settings.agent_settings.host
)
assert str(sent_msg.to) == expected_to
assert sent_msg.thread == "beliefs"
assert sent_msg.body == expected_response_body
@pytest.mark.asyncio
async def test_process_transcription_json_decode_error(behavior, mock_settings):
"""
Tests the _process_transcription method's error handling
when the (mocked) response is invalid JSON.
We do this by patching json.loads to raise an error.
"""
# Arrange
test_text = "I am feeling happy"
# Patch json.loads to raise an error when called
with patch("json.loads", side_effect=json.JSONDecodeError("Mock error", "", 0)):
# Act
await behavior._process_transcription(test_text)
# Assert
# 1. Check that NO message was sent
behavior.send.assert_not_called()