Compare commits
15 Commits
refactor/l
...
build/dock
| Author | SHA1 | Date | |
|---|---|---|---|
| 173326d4ad | |||
| 9c538d927f | |||
|
|
1518b14867 | ||
|
|
858a554c78 | ||
|
|
5376b3bb4c | ||
| 8cd8988fe0 | |||
|
|
919604493e | ||
| 273f621b1b | |||
|
|
43f3cba1a8 | ||
|
|
e39139cac9 | ||
|
|
b785493b97 | ||
|
|
781a05328f | ||
|
|
1c756474f2 | ||
| df7dc8fdf3 | |||
|
|
6235fcdaf4 |
14
.dockerignore
Normal file
14
.dockerignore
Normal file
@@ -0,0 +1,14 @@
|
||||
.git
|
||||
.venv
|
||||
__pycache__/
|
||||
*.pyc
|
||||
.dockerignore
|
||||
Dockerfile
|
||||
README.md
|
||||
.gitlab-ci.yml
|
||||
.gitignore
|
||||
.pre-commit-config.yaml
|
||||
.githooks/
|
||||
test/
|
||||
.pytest_cache/
|
||||
.ruff_cache/
|
||||
@@ -30,7 +30,7 @@ HEADER=$(head -n 1 "$COMMIT_MSG_FILE")
|
||||
|
||||
# Check for Merge commits (covers 'git merge' and PR merges from GitHub/GitLab)
|
||||
# Examples: "Merge branch 'main' into ...", "Merge pull request #123 from ..."
|
||||
MERGE_PATTERN="^Merge (branch|pull request|tag) .*"
|
||||
MERGE_PATTERN="^Merge (remote-tracking )?(branch|pull request|tag) .*"
|
||||
if [[ "$HEADER" =~ $MERGE_PATTERN ]]; then
|
||||
echo -e "${GREEN}Merge commit detected by message content. Skipping validation.${NC}"
|
||||
exit 0
|
||||
|
||||
21
Dockerfile
Normal file
21
Dockerfile
Normal file
@@ -0,0 +1,21 @@
|
||||
# Debian based image
|
||||
FROM ghcr.io/astral-sh/uv:0.9.8-trixie-slim
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
ENV VIRTUAL_ENV=/app/.venv
|
||||
ENV PATH="$VIRTUAL_ENV/bin:$PATH"
|
||||
|
||||
RUN apt-get update && apt-get install -y gcc=4:14.2.0-1 portaudio19-dev && apt-get install -y ca-certificates && rm -rf /var/lib/apt/lists/*
|
||||
|
||||
COPY pyproject.toml uv.lock .python-version ./
|
||||
|
||||
RUN uv sync
|
||||
|
||||
COPY . .
|
||||
|
||||
EXPOSE 8000
|
||||
|
||||
ENV PYTHONPATH=src
|
||||
|
||||
CMD [ "fastapi", "run", "src/control_backend/main.py" ]
|
||||
@@ -1,3 +1,4 @@
|
||||
import json
|
||||
import logging
|
||||
|
||||
import agentspeak
|
||||
@@ -37,28 +38,66 @@ class BDICoreAgent(BDIAgent):
|
||||
Registers custom AgentSpeak actions callable from plans.
|
||||
"""
|
||||
|
||||
@actions.add(".reply", 1)
|
||||
@actions.add(".reply", 3)
|
||||
def _reply(agent: "BDICoreAgent", term, intention):
|
||||
"""
|
||||
Sends text to the LLM (AgentSpeak action).
|
||||
Example: .reply("Hello LLM!")
|
||||
"""
|
||||
message_text = agentspeak.grounded(term.args[0], intention.scope)
|
||||
self.logger.debug("Reply action sending: %s", message_text)
|
||||
norms = agentspeak.grounded(term.args[1], intention.scope)
|
||||
goals = agentspeak.grounded(term.args[2], intention.scope)
|
||||
|
||||
self._send_to_llm(str(message_text))
|
||||
self.logger.debug("Norms: %s", norms)
|
||||
self.logger.debug("Goals: %s", goals)
|
||||
self.logger.debug("User text: %s", message_text)
|
||||
|
||||
self._send_to_llm(str(message_text), str(norms), str(goals))
|
||||
yield
|
||||
|
||||
def _send_to_llm(self, text: str):
|
||||
@actions.add(".reply_no_norms", 2)
|
||||
def _reply_no_norms(agent: "BDICoreAgent", term, intention):
|
||||
message_text = agentspeak.grounded(term.args[0], intention.scope)
|
||||
goals = agentspeak.grounded(term.args[1], intention.scope)
|
||||
|
||||
self.logger.debug("Goals: %s", goals)
|
||||
self.logger.debug("User text: %s", message_text)
|
||||
|
||||
self._send_to_llm(str(message_text), goals=str(goals))
|
||||
|
||||
@actions.add(".reply_no_goals", 2)
|
||||
def _reply_no_goals(agent: "BDICoreAgent", term, intention):
|
||||
message_text = agentspeak.grounded(term.args[0], intention.scope)
|
||||
norms = agentspeak.grounded(term.args[1], intention.scope)
|
||||
|
||||
self.logger.debug("Norms: %s", norms)
|
||||
self.logger.debug("User text: %s", message_text)
|
||||
|
||||
self._send_to_llm(str(message_text), norms=str(norms))
|
||||
|
||||
@actions.add(".reply_no_goals_no_norms", 1)
|
||||
def _reply_no_goals_no_norms(agent: "BDICoreAgent", term, intention):
|
||||
message_text = agentspeak.grounded(term.args[0], intention.scope)
|
||||
|
||||
self.logger.debug("User text: %s", message_text)
|
||||
|
||||
self._send_to_llm(message_text)
|
||||
|
||||
def _send_to_llm(self, text: str, norms: str = None, goals: str = None):
|
||||
"""
|
||||
Sends a text query to the LLM Agent asynchronously.
|
||||
"""
|
||||
|
||||
class SendBehaviour(OneShotBehaviour):
|
||||
async def run(self) -> None:
|
||||
message_dict = {
|
||||
"text": text,
|
||||
"norms": norms if norms else "",
|
||||
"goals": goals if goals else "",
|
||||
}
|
||||
msg = Message(
|
||||
to=settings.agent_settings.llm_agent_name + "@" + settings.agent_settings.host,
|
||||
body=text,
|
||||
body=json.dumps(message_dict),
|
||||
)
|
||||
|
||||
await self.send(msg)
|
||||
|
||||
@@ -0,0 +1,27 @@
|
||||
import zmq
|
||||
from zmq.asyncio import Context
|
||||
|
||||
from control_backend.agents import BaseAgent
|
||||
from control_backend.core.config import settings
|
||||
|
||||
from .receive_programs_behavior import ReceiveProgramsBehavior
|
||||
|
||||
|
||||
class BDIProgramManager(BaseAgent):
|
||||
"""
|
||||
Will interpret programs received from the HTTP endpoint. Extracts norms, goals, triggers and
|
||||
forwards them to the BDI as beliefs.
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.sub_socket = None
|
||||
|
||||
async def setup(self):
|
||||
context = Context.instance()
|
||||
|
||||
self.sub_socket = context.socket(zmq.SUB)
|
||||
self.sub_socket.connect(settings.zmq_settings.internal_sub_address)
|
||||
self.sub_socket.subscribe("program")
|
||||
|
||||
self.add_behaviour(ReceiveProgramsBehavior())
|
||||
@@ -0,0 +1,59 @@
|
||||
import json
|
||||
|
||||
from pydantic import ValidationError
|
||||
from spade.behaviour import CyclicBehaviour
|
||||
from spade.message import Message
|
||||
|
||||
from control_backend.core.config import settings
|
||||
from control_backend.schemas.program import Program
|
||||
|
||||
|
||||
class ReceiveProgramsBehavior(CyclicBehaviour):
|
||||
async def _receive(self) -> Program | None:
|
||||
topic, body = await self.agent.sub_socket.recv_multipart()
|
||||
|
||||
try:
|
||||
return Program.model_validate_json(body)
|
||||
except ValidationError as e:
|
||||
self.agent.logger.error("Received an invalid program.", exc_info=e)
|
||||
return None
|
||||
|
||||
def _extract_norms(self, program: Program) -> str:
|
||||
"""First phase only for now, as a single newline delimited string."""
|
||||
if not program.phases:
|
||||
return ""
|
||||
if not program.phases[0].phaseData.norms:
|
||||
return ""
|
||||
norm_values = [norm.value for norm in program.phases[0].phaseData.norms]
|
||||
return "\n".join(norm_values)
|
||||
|
||||
def _extract_goals(self, program: Program) -> str:
|
||||
"""First phase only for now, as a single newline delimited string."""
|
||||
if not program.phases:
|
||||
return ""
|
||||
if not program.phases[0].phaseData.goals:
|
||||
return ""
|
||||
goal_descriptions = [goal.description for goal in program.phases[0].phaseData.goals]
|
||||
return "\n".join(goal_descriptions)
|
||||
|
||||
async def _send_to_bdi(self, program: Program):
|
||||
temp_allowed_parts = {
|
||||
"norms": [self._extract_norms(program)],
|
||||
"goals": [self._extract_goals(program)],
|
||||
}
|
||||
|
||||
message = Message(
|
||||
to=settings.agent_settings.bdi_core_agent_name + "@" + settings.agent_settings.host,
|
||||
sender=self.agent.jid,
|
||||
body=json.dumps(temp_allowed_parts),
|
||||
thread="beliefs",
|
||||
)
|
||||
await self.send(message)
|
||||
self.agent.logger.debug("Sent new norms and goals to the BDI agent.")
|
||||
|
||||
async def run(self):
|
||||
program = await self._receive()
|
||||
if not program:
|
||||
return
|
||||
|
||||
await self._send_to_bdi(program)
|
||||
@@ -17,7 +17,9 @@ class BeliefSetterBehaviour(CyclicBehaviour):
|
||||
|
||||
async def run(self):
|
||||
"""Polls for messages and processes them."""
|
||||
msg = await self.receive()
|
||||
msg = await self.receive(timeout=1)
|
||||
if not msg:
|
||||
return
|
||||
self.agent.logger.debug(
|
||||
"Received message from %s with thread '%s' and body: %s",
|
||||
msg.sender,
|
||||
@@ -37,8 +39,13 @@ class BeliefSetterBehaviour(CyclicBehaviour):
|
||||
"Message is from the belief collector agent. Processing as belief message."
|
||||
)
|
||||
self._process_belief_message(message)
|
||||
case settings.agent_settings.program_manager_agent_name:
|
||||
self.agent.logger.debug(
|
||||
"Processing message from the program manager. Processing as belief message."
|
||||
)
|
||||
self._process_belief_message(message)
|
||||
case _:
|
||||
self.agent.logger.debug("Not the belief agent, discarding message")
|
||||
self.agent.logger.debug("Not from expected agents, discarding message")
|
||||
pass
|
||||
|
||||
def _process_belief_message(self, message: Message):
|
||||
|
||||
@@ -11,7 +11,9 @@ class ReceiveLLMResponseBehaviour(CyclicBehaviour):
|
||||
"""
|
||||
|
||||
async def run(self):
|
||||
msg = await self.receive()
|
||||
msg = await self.receive(timeout=1)
|
||||
if not msg:
|
||||
return
|
||||
|
||||
sender = msg.sender.node
|
||||
match sender:
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import json
|
||||
import logging
|
||||
|
||||
from spade.behaviour import CyclicBehaviour
|
||||
from spade.message import Message
|
||||
@@ -7,6 +8,8 @@ from control_backend.core.config import settings
|
||||
|
||||
|
||||
class BeliefFromText(CyclicBehaviour):
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# TODO: LLM prompt nog hardcoded
|
||||
llm_instruction_prompt = """
|
||||
You are an information extraction assistent for a BDI agent. Your task is to extract values \
|
||||
@@ -35,7 +38,10 @@ class BeliefFromText(CyclicBehaviour):
|
||||
beliefs = {"mood": ["X"], "car": ["Y"]}
|
||||
|
||||
async def run(self):
|
||||
msg = await self.receive()
|
||||
msg = await self.receive(timeout=1)
|
||||
if not msg:
|
||||
return
|
||||
|
||||
sender = msg.sender.node
|
||||
match sender:
|
||||
case settings.agent_settings.transcription_agent_name:
|
||||
@@ -62,10 +68,14 @@ class BeliefFromText(CyclicBehaviour):
|
||||
# Verify by trying to parse
|
||||
try:
|
||||
json.loads(response)
|
||||
belief_message = Message(
|
||||
to=settings.agent_settings.bdi_core_agent_name + "@" + settings.agent_settings.host,
|
||||
body=response,
|
||||
belief_message = Message()
|
||||
|
||||
belief_message.to = (
|
||||
settings.agent_settings.belief_collector_agent_name
|
||||
+ "@"
|
||||
+ settings.agent_settings.host
|
||||
)
|
||||
belief_message.body = response
|
||||
belief_message.thread = "beliefs"
|
||||
|
||||
await self.send(belief_message)
|
||||
@@ -82,12 +92,12 @@ class BeliefFromText(CyclicBehaviour):
|
||||
"""
|
||||
belief = {"beliefs": {"user_said": [txt]}, "type": "belief_extraction_text"}
|
||||
payload = json.dumps(belief)
|
||||
belief_msg = Message(
|
||||
to=settings.agent_settings.belief_collector_agent_name
|
||||
+ "@"
|
||||
+ settings.agent_settings.host,
|
||||
body=payload,
|
||||
belief_msg = Message()
|
||||
|
||||
belief_msg.to = (
|
||||
settings.agent_settings.belief_collector_agent_name + "@" + settings.agent_settings.host
|
||||
)
|
||||
belief_msg.body = payload
|
||||
belief_msg.thread = "beliefs"
|
||||
|
||||
await self.send(belief_msg)
|
||||
|
||||
@@ -1,3 +1,18 @@
|
||||
+new_message : user_said(Message) <-
|
||||
norms("").
|
||||
goals("").
|
||||
|
||||
+new_message : user_said(Message) & norms(Norms) & goals(Goals) <-
|
||||
-new_message;
|
||||
.reply(Message).
|
||||
.reply(Message, Norms, Goals).
|
||||
|
||||
// +new_message : user_said(Message) & norms(Norms) <-
|
||||
// -new_message;
|
||||
// .reply_no_goals(Message, Norms).
|
||||
//
|
||||
// +new_message : user_said(Message) & goals(Goals) <-
|
||||
// -new_message;
|
||||
// .reply_no_norms(Message, Goals).
|
||||
//
|
||||
// +new_message : user_said(Message) <-
|
||||
// -new_message;
|
||||
// .reply_no_goals_no_norms(Message).
|
||||
@@ -14,7 +14,9 @@ class ContinuousBeliefCollector(CyclicBehaviour):
|
||||
"""
|
||||
|
||||
async def run(self):
|
||||
msg = await self.receive()
|
||||
msg = await self.receive(timeout=1)
|
||||
if not msg:
|
||||
return
|
||||
await self._process_message(msg)
|
||||
|
||||
async def _process_message(self, msg: Message):
|
||||
|
||||
@@ -30,7 +30,9 @@ class LLMAgent(BaseAgent):
|
||||
Receives SPADE messages and processes only those originating from the
|
||||
configured BDI agent.
|
||||
"""
|
||||
msg = await self.receive()
|
||||
msg = await self.receive(timeout=1)
|
||||
if not msg:
|
||||
return
|
||||
|
||||
sender = msg.sender.node
|
||||
self.agent.logger.debug(
|
||||
@@ -50,9 +52,13 @@ class LLMAgent(BaseAgent):
|
||||
Forwards user text from the BDI to the LLM and replies with the generated text in chunks
|
||||
separated by punctuation.
|
||||
"""
|
||||
user_text = message.body
|
||||
try:
|
||||
message = json.loads(message.body)
|
||||
except json.JSONDecodeError:
|
||||
self.agent.logger.error("Could not process BDI message.", exc_info=True)
|
||||
|
||||
# Consume the streaming generator and send a reply for every chunk
|
||||
async for chunk in self._query_llm(user_text):
|
||||
async for chunk in self._query_llm(message["text"], message["norms"], message["goals"]):
|
||||
await self._reply(chunk)
|
||||
self.agent.logger.debug(
|
||||
"Finished processing BDI message. Response sent in chunks to BDI Core Agent."
|
||||
@@ -68,7 +74,7 @@ class LLMAgent(BaseAgent):
|
||||
)
|
||||
await self.send(reply)
|
||||
|
||||
async def _query_llm(self, prompt: str) -> AsyncGenerator[str]:
|
||||
async def _query_llm(self, prompt: str, norms: str, goals: str) -> AsyncGenerator[str]:
|
||||
"""
|
||||
Sends a chat completion request to the local LLM service and streams the response by
|
||||
yielding fragments separated by punctuation like.
|
||||
@@ -76,15 +82,7 @@ class LLMAgent(BaseAgent):
|
||||
:param prompt: Input text prompt to pass to the LLM.
|
||||
:yield: Fragments of the LLM-generated content.
|
||||
"""
|
||||
instructions = LLMInstructions(
|
||||
"- Be friendly and respectful.\n"
|
||||
"- Make the conversation feel natural and engaging.\n"
|
||||
"- Speak like a pirate.\n"
|
||||
"- When the user asks what you can do, tell them.",
|
||||
"- Try to learn the user's name during conversation.\n"
|
||||
"- Suggest playing a game of asking yes or no questions where you think of a word "
|
||||
"and the user must guess it.",
|
||||
)
|
||||
instructions = LLMInstructions(norms if norms else None, goals if goals else None)
|
||||
messages = [
|
||||
{
|
||||
"role": "developer",
|
||||
|
||||
@@ -6,10 +6,7 @@ class LLMInstructions:
|
||||
|
||||
@staticmethod
|
||||
def default_norms() -> str:
|
||||
return """
|
||||
Be friendly and respectful.
|
||||
Make the conversation feel natural and engaging.
|
||||
""".strip()
|
||||
return "Be friendly and respectful.\nMake the conversation feel natural and engaging."
|
||||
|
||||
@staticmethod
|
||||
def default_goals() -> str:
|
||||
|
||||
@@ -54,7 +54,9 @@ class RICommandAgent(BaseAgent):
|
||||
"""Behaviour for sending commands received from other Python agents."""
|
||||
|
||||
async def run(self):
|
||||
message: spade.agent.Message = await self.receive(timeout=0.1)
|
||||
message: spade.agent.Message = await self.receive(timeout=1)
|
||||
if not message:
|
||||
return
|
||||
if message and message.to == self.agent.jid:
|
||||
try:
|
||||
speech_command = SpeechCommand.model_validate_json(message.body)
|
||||
|
||||
@@ -21,10 +21,13 @@ class RICommunicationAgent(BaseAgent):
|
||||
password: str,
|
||||
port: int = 5222,
|
||||
verify_security: bool = False,
|
||||
address="tcp://localhost:0000",
|
||||
bind=False,
|
||||
address=None,
|
||||
bind=True,
|
||||
):
|
||||
super().__init__(jid, password, port, verify_security)
|
||||
if not address:
|
||||
self.logger.critical("No address set for negotiations.")
|
||||
raise Exception # TODO: improve
|
||||
self._address = address
|
||||
self._bind = bind
|
||||
|
||||
@@ -119,10 +122,7 @@ class RICommunicationAgent(BaseAgent):
|
||||
port = port_data["port"]
|
||||
bind = port_data["bind"]
|
||||
|
||||
if not bind:
|
||||
addr = f"tcp://localhost:{port}"
|
||||
else:
|
||||
addr = f"tcp://*:{port}"
|
||||
addr = f"tcp://{settings.zmq_settings.external_host}:{port}"
|
||||
|
||||
match id:
|
||||
case "main":
|
||||
|
||||
@@ -3,7 +3,6 @@ import logging
|
||||
import zmq
|
||||
from fastapi import APIRouter
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pyjabber.server_parameters import json
|
||||
from zmq.asyncio import Context
|
||||
|
||||
from control_backend.core.config import settings
|
||||
@@ -13,6 +12,7 @@ logger = logging.getLogger(__name__)
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
# DO NOT LOG INSIDE THIS FUNCTION
|
||||
@router.get("/logs/stream")
|
||||
async def log_stream():
|
||||
context = Context.instance()
|
||||
@@ -27,7 +27,6 @@ async def log_stream():
|
||||
while True:
|
||||
_, message = await socket.recv_multipart()
|
||||
message = message.decode().strip()
|
||||
json_data = json.dumps(message)
|
||||
yield f"data: {json_data}\n\n"
|
||||
yield f"data: {message}\n\n"
|
||||
|
||||
return StreamingResponse(gen(), media_type="text/event-stream")
|
||||
|
||||
35
src/control_backend/api/v1/endpoints/program.py
Normal file
35
src/control_backend/api/v1/endpoints/program.py
Normal file
@@ -0,0 +1,35 @@
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Request
|
||||
from pydantic import ValidationError
|
||||
|
||||
from control_backend.schemas.message import Message
|
||||
from control_backend.schemas.program import Program
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post("/program", status_code=202)
|
||||
async def receive_message(program: Message, request: Request):
|
||||
"""
|
||||
Receives a BehaviorProgram as a stringified JSON list inside `message`.
|
||||
Converts it into real Phase objects.
|
||||
"""
|
||||
logger.debug("Received raw program: %s", program)
|
||||
raw_str = program.message # This is the JSON string
|
||||
|
||||
# Validate program
|
||||
try:
|
||||
program = Program.model_validate_json(raw_str)
|
||||
except ValidationError as e:
|
||||
logger.error("Failed to validate program JSON: %s", e)
|
||||
raise HTTPException(status_code=400, detail="Not a valid program") from None
|
||||
|
||||
# send away
|
||||
topic = b"program"
|
||||
body = program.model_dump_json().encode()
|
||||
pub_socket = request.app.state.endpoints_pub_socket
|
||||
await pub_socket.send_multipart([topic, body])
|
||||
|
||||
return {"status": "Program parsed"}
|
||||
@@ -1,6 +1,6 @@
|
||||
from fastapi.routing import APIRouter
|
||||
|
||||
from control_backend.api.v1.endpoints import command, logs, message, sse
|
||||
from control_backend.api.v1.endpoints import command, logs, message, program, sse
|
||||
|
||||
api_router = APIRouter()
|
||||
|
||||
@@ -11,3 +11,5 @@ api_router.include_router(sse.router, tags=["SSE"])
|
||||
api_router.include_router(command.router, tags=["Commands"])
|
||||
|
||||
api_router.include_router(logs.router, tags=["Logs"])
|
||||
|
||||
api_router.include_router(program.router, tags=["Program"])
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
import os
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
@@ -6,9 +8,11 @@ class ZMQSettings(BaseModel):
|
||||
internal_pub_address: str = "tcp://localhost:5560"
|
||||
internal_sub_address: str = "tcp://localhost:5561"
|
||||
|
||||
external_host: str = "0.0.0.0"
|
||||
|
||||
|
||||
class AgentSettings(BaseModel):
|
||||
host: str = "localhost"
|
||||
host: str = os.environ.get("XMPP_HOST", "localhost")
|
||||
bdi_core_agent_name: str = "bdi_core"
|
||||
belief_collector_agent_name: str = "belief_collector"
|
||||
text_belief_extractor_agent_name: str = "text_belief_extractor"
|
||||
@@ -16,14 +20,15 @@ class AgentSettings(BaseModel):
|
||||
llm_agent_name: str = "llm_agent"
|
||||
test_agent_name: str = "test_agent"
|
||||
transcription_agent_name: str = "transcription_agent"
|
||||
program_manager_agent_name: str = "program_manager"
|
||||
|
||||
ri_communication_agent_name: str = "ri_communication_agent"
|
||||
ri_command_agent_name: str = "ri_command_agent"
|
||||
|
||||
|
||||
class LLMSettings(BaseModel):
|
||||
local_llm_url: str = "http://localhost:1234/v1/chat/completions"
|
||||
local_llm_model: str = "openai/gpt-oss-20b"
|
||||
local_llm_url: str = os.environ.get("LLM_URL", "http://localhost:1234/v1/") + "chat/completions"
|
||||
local_llm_model: str = os.environ.get("LLM_MODEL", "openai/gpt-oss-20b")
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import contextlib
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
|
||||
import zmq
|
||||
@@ -14,6 +15,7 @@ from control_backend.agents import (
|
||||
VADAgent,
|
||||
)
|
||||
from control_backend.agents.bdi import BDICoreAgent, TBeliefExtractorAgent
|
||||
from control_backend.agents.bdi.bdi_program_manager.bdi_program_manager import BDIProgramManager
|
||||
from control_backend.api.v1.router import api_router
|
||||
from control_backend.core.config import settings
|
||||
from control_backend.logging import setup_logging
|
||||
@@ -48,7 +50,9 @@ async def lifespan(app: FastAPI):
|
||||
# --- APPLICATION STARTUP ---
|
||||
setup_logging()
|
||||
logger.info("%s is starting up.", app.title)
|
||||
logger.warning("testing extra", extra={"extra1": "one", "extra2": "two"})
|
||||
logger.info(
|
||||
"LLM_URL: %s, LLM_MODEL: %s", os.environ.get("LLM_URL"), os.environ.get("LLM_MODEL")
|
||||
)
|
||||
|
||||
# Initiate sockets
|
||||
proxy_thread = threading.Thread(target=setup_sockets)
|
||||
@@ -71,7 +75,7 @@ async def lifespan(app: FastAPI):
|
||||
"jid": f"{settings.agent_settings.ri_communication_agent_name}"
|
||||
f"@{settings.agent_settings.host}",
|
||||
"password": settings.agent_settings.ri_communication_agent_name,
|
||||
"address": "tcp://*:5555",
|
||||
"address": f"tcp://{settings.zmq_settings.external_host}:5555",
|
||||
"bind": True,
|
||||
},
|
||||
),
|
||||
@@ -113,21 +117,39 @@ async def lifespan(app: FastAPI):
|
||||
),
|
||||
"VADAgent": (
|
||||
VADAgent,
|
||||
{"audio_in_address": "tcp://localhost:5558", "audio_in_bind": False},
|
||||
{
|
||||
"audio_in_address": f"tcp://{settings.zmq_settings.external_host}:5558",
|
||||
"audio_in_bind": True,
|
||||
},
|
||||
),
|
||||
"ProgramManager": (
|
||||
BDIProgramManager,
|
||||
{
|
||||
"name": settings.agent_settings.program_manager_agent_name,
|
||||
"jid": f"{settings.agent_settings.program_manager_agent_name}@"
|
||||
f"{settings.agent_settings.host}",
|
||||
"password": settings.agent_settings.program_manager_agent_name,
|
||||
},
|
||||
),
|
||||
}
|
||||
|
||||
vad_agent_instance = None
|
||||
|
||||
for name, (agent_class, kwargs) in agents_to_start.items():
|
||||
try:
|
||||
logger.debug("Starting agent: %s", name)
|
||||
agent_instance = agent_class(**{k: v for k, v in kwargs.items() if k != "name"})
|
||||
await agent_instance.start()
|
||||
if isinstance(agent_instance, VADAgent):
|
||||
vad_agent_instance = agent_instance
|
||||
logger.info("Agent '%s' started successfully.", name)
|
||||
except Exception as e:
|
||||
logger.error("Failed to start agent '%s': %s", name, e, exc_info=True)
|
||||
# Consider if the application should continue if an agent fails to start.
|
||||
raise
|
||||
|
||||
await vad_agent_instance.streaming_behaviour.reset()
|
||||
|
||||
logger.info("Application startup complete.")
|
||||
|
||||
yield
|
||||
|
||||
38
src/control_backend/schemas/program.py
Normal file
38
src/control_backend/schemas/program.py
Normal file
@@ -0,0 +1,38 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class Norm(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
value: str
|
||||
|
||||
|
||||
class Goal(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
description: str
|
||||
achieved: bool
|
||||
|
||||
|
||||
class Trigger(BaseModel):
|
||||
id: str
|
||||
label: str
|
||||
type: str
|
||||
value: list[str]
|
||||
|
||||
|
||||
class PhaseData(BaseModel):
|
||||
norms: list[Norm]
|
||||
goals: list[Goal]
|
||||
triggers: list[Trigger]
|
||||
|
||||
|
||||
class Phase(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
nextPhaseId: str
|
||||
phaseData: PhaseData
|
||||
|
||||
|
||||
class Program(BaseModel):
|
||||
phases: list[Phase]
|
||||
@@ -0,0 +1,187 @@
|
||||
import json
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from spade.message import Message
|
||||
|
||||
from control_backend.agents.bdi.behaviours.text_belief_extractor import BeliefFromText
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_settings():
|
||||
"""
|
||||
Mocks the settings object that the behaviour imports.
|
||||
We patch it at the source where it's imported by the module under test.
|
||||
"""
|
||||
# Create a mock object that mimics the nested structure
|
||||
settings_mock = MagicMock()
|
||||
settings_mock.agent_settings.transcription_agent_name = "transcriber"
|
||||
settings_mock.agent_settings.belief_collector_agent_name = "collector"
|
||||
settings_mock.agent_settings.host = "fake.host"
|
||||
|
||||
# Use patch to replace the settings object during the test
|
||||
# Adjust 'control_backend.behaviours.belief_from_text.settings' to where
|
||||
# your behaviour file imports it from.
|
||||
with patch(
|
||||
"control_backend.agents.bdi.behaviours.text_belief_extractor.settings", settings_mock
|
||||
):
|
||||
yield settings_mock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def behavior(mock_settings):
|
||||
"""
|
||||
Creates an instance of the BeliefFromText behaviour and mocks its
|
||||
agent, logger, send, and receive methods.
|
||||
"""
|
||||
b = BeliefFromText()
|
||||
|
||||
b.agent = MagicMock()
|
||||
b.send = AsyncMock()
|
||||
b.receive = AsyncMock()
|
||||
|
||||
return b
|
||||
|
||||
|
||||
def create_mock_message(sender_node: str, body: str, thread: str) -> MagicMock:
|
||||
"""Helper function to create a configured mock message."""
|
||||
msg = MagicMock()
|
||||
msg.sender.node = sender_node # MagicMock automatically creates nested mocks
|
||||
msg.body = body
|
||||
msg.thread = thread
|
||||
return msg
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_no_message(behavior):
|
||||
"""
|
||||
Tests the run() method when no message is received.
|
||||
"""
|
||||
# Arrange: Configure receive to return None
|
||||
behavior.receive.return_value = None
|
||||
|
||||
# Act: Run the behavior
|
||||
await behavior.run()
|
||||
|
||||
# Assert
|
||||
# 1. Check that receive was called
|
||||
behavior.receive.assert_called_once()
|
||||
# 2. Check that no message was sent
|
||||
behavior.send.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_message_from_other_agent(behavior):
|
||||
"""
|
||||
Tests the run() method when a message is received from an
|
||||
unknown agent (not the transcriber).
|
||||
"""
|
||||
# Arrange: Create a mock message from an unknown sender
|
||||
mock_msg = create_mock_message("unknown", "some data", None)
|
||||
behavior.receive.return_value = mock_msg
|
||||
behavior._process_transcription_demo = MagicMock()
|
||||
|
||||
# Act
|
||||
await behavior.run()
|
||||
|
||||
# Assert
|
||||
# 1. Check that receive was called
|
||||
behavior.receive.assert_called_once()
|
||||
# 2. Check that _process_transcription_demo was not sent
|
||||
behavior._process_transcription_demo.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_message_from_transcriber_demo(behavior, mock_settings, monkeypatch):
|
||||
"""
|
||||
Tests the main success path: receiving a message from the
|
||||
transcription agent, which triggers _process_transcription_demo.
|
||||
"""
|
||||
# Arrange: Create a mock message from the transcriber
|
||||
transcription_text = "hello world"
|
||||
mock_msg = create_mock_message(
|
||||
mock_settings.agent_settings.transcription_agent_name, transcription_text, None
|
||||
)
|
||||
behavior.receive.return_value = mock_msg
|
||||
|
||||
# Act
|
||||
await behavior.run()
|
||||
|
||||
# Assert
|
||||
# 1. Check that receive was called
|
||||
behavior.receive.assert_called_once()
|
||||
|
||||
# 2. Check that send was called *once*
|
||||
behavior.send.assert_called_once()
|
||||
|
||||
# 3. Deeply inspect the message that was sent
|
||||
sent_msg: Message = behavior.send.call_args[0][0]
|
||||
|
||||
assert (
|
||||
sent_msg.to
|
||||
== mock_settings.agent_settings.belief_collector_agent_name
|
||||
+ "@"
|
||||
+ mock_settings.agent_settings.host
|
||||
)
|
||||
|
||||
# Check thread
|
||||
assert sent_msg.thread == "beliefs"
|
||||
|
||||
# Parse the received JSON string back into a dict
|
||||
expected_dict = {
|
||||
"beliefs": {"user_said": [transcription_text]},
|
||||
"type": "belief_extraction_text",
|
||||
}
|
||||
sent_dict = json.loads(sent_msg.body)
|
||||
|
||||
# Assert that the dictionaries are equal
|
||||
assert sent_dict == expected_dict
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_transcription_success(behavior, mock_settings):
|
||||
"""
|
||||
Tests the (currently unused) _process_transcription method's
|
||||
success path, using its hardcoded mock response.
|
||||
"""
|
||||
# Arrange
|
||||
test_text = "I am feeling happy"
|
||||
# This is the hardcoded response inside the method
|
||||
expected_response_body = '{"mood": [["happy"]]}'
|
||||
|
||||
# Act
|
||||
await behavior._process_transcription(test_text)
|
||||
|
||||
# Assert
|
||||
# 1. Check that a message was sent
|
||||
behavior.send.assert_called_once()
|
||||
|
||||
# 2. Inspect the sent message
|
||||
sent_msg: Message = behavior.send.call_args[0][0]
|
||||
expected_to = (
|
||||
mock_settings.agent_settings.belief_collector_agent_name
|
||||
+ "@"
|
||||
+ mock_settings.agent_settings.host
|
||||
)
|
||||
assert str(sent_msg.to) == expected_to
|
||||
assert sent_msg.thread == "beliefs"
|
||||
assert sent_msg.body == expected_response_body
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_transcription_json_decode_error(behavior, mock_settings):
|
||||
"""
|
||||
Tests the _process_transcription method's error handling
|
||||
when the (mocked) response is invalid JSON.
|
||||
We do this by patching json.loads to raise an error.
|
||||
"""
|
||||
# Arrange
|
||||
test_text = "I am feeling happy"
|
||||
# Patch json.loads to raise an error when called
|
||||
with patch("json.loads", side_effect=json.JSONDecodeError("Mock error", "", 0)):
|
||||
# Act
|
||||
await behavior._process_transcription(test_text)
|
||||
|
||||
# Assert
|
||||
# 1. Check that NO message was sent
|
||||
behavior.send.assert_not_called()
|
||||
Reference in New Issue
Block a user