Add program manager #30

Merged
0950726 merged 8 commits from feat/norms-and-goals-program into dev 2025-11-25 11:20:51 +00:00
17 changed files with 313 additions and 131 deletions

View File

@@ -30,7 +30,7 @@ HEADER=$(head -n 1 "$COMMIT_MSG_FILE")
# Check for Merge commits (covers 'git merge' and PR merges from GitHub/GitLab) # Check for Merge commits (covers 'git merge' and PR merges from GitHub/GitLab)
# Examples: "Merge branch 'main' into ...", "Merge pull request #123 from ..." # Examples: "Merge branch 'main' into ...", "Merge pull request #123 from ..."
MERGE_PATTERN="^Merge (branch|pull request|tag) .*" MERGE_PATTERN="^Merge (remote-tracking )?(branch|pull request|tag) .*"
if [[ "$HEADER" =~ $MERGE_PATTERN ]]; then if [[ "$HEADER" =~ $MERGE_PATTERN ]]; then
echo -e "${GREEN}Merge commit detected by message content. Skipping validation.${NC}" echo -e "${GREEN}Merge commit detected by message content. Skipping validation.${NC}"
exit 0 exit 0

View File

@@ -11,9 +11,12 @@ from pydantic import ValidationError
from control_backend.agents.base import BaseAgent from control_backend.agents.base import BaseAgent
from control_backend.core.agent_system import InternalMessage from control_backend.core.agent_system import InternalMessage
from control_backend.core.config import settings from control_backend.core.config import settings
from control_backend.schemas.belief_message import BeliefMessage from control_backend.schemas.belief_message import Belief, BeliefMessage
from control_backend.schemas.llm_prompt_message import LLMPromptMessage
from control_backend.schemas.ri_message import SpeechCommand from control_backend.schemas.ri_message import SpeechCommand
DELIMITER = ";\n" # TODO: temporary until we support lists in AgentSpeak
class BDICoreAgent(BaseAgent): class BDICoreAgent(BaseAgent):
bdi_agent: agentspeak.runtime.Agent bdi_agent: agentspeak.runtime.Agent
@@ -77,17 +80,18 @@ class BDICoreAgent(BaseAgent):
""" """
Route incoming messages (Beliefs or LLM responses). Route incoming messages (Beliefs or LLM responses).
""" """
sender = msg.sender self.logger.debug("Processing message from %s.", msg.sender)
match sender: if msg.thread == "beliefs":
case settings.agent_settings.bdi_belief_collector_name: try:
self.logger.debug("Processing message from belief collector.") beliefs = BeliefMessage.model_validate_json(msg.body).beliefs
try: self._apply_beliefs(beliefs)
if msg.thread == "beliefs": except ValidationError:
beliefs = BeliefMessage.model_validate_json(msg.body).beliefs self.logger.exception("Error processing belief.")
self._add_beliefs(beliefs) return
except ValidationError:
self.logger.exception("Error processing belief.") # The message was not a belief, handle special cases based on sender
match msg.sender:
case settings.agent_settings.llm_name: case settings.agent_settings.llm_name:
content = msg.body content = msg.body
self.logger.info("Received LLM response: %s", content) self.logger.info("Received LLM response: %s", content)
@@ -101,15 +105,19 @@ class BDICoreAgent(BaseAgent):
) )
await self.send(out_msg) await self.send(out_msg)
def _add_beliefs(self, beliefs: dict[str, list[str]]): def _apply_beliefs(self, beliefs: list[Belief]):
if not beliefs: if not beliefs:
return return
for name, args in beliefs.items(): for belief in beliefs:
self._add_belief(name, args) if belief.replace:
self._remove_all_with_name(belief.name)
self._add_belief(belief.name, belief.arguments)
def _add_belief(self, name: str, args: Iterable[str] = []): def _add_belief(self, name: str, args: Iterable[str] = []):
new_args = (agentspeak.Literal(arg) for arg in args) # new_args = (agentspeak.Literal(arg) for arg in args) # TODO: Eventually support multiple
merged_args = DELIMITER.join(arg for arg in args)
new_args = (agentspeak.Literal(merged_args),)
term = agentspeak.Literal(name, new_args) term = agentspeak.Literal(name, new_args)
self.bdi_agent.call( self.bdi_agent.call(
@@ -143,7 +151,6 @@ class BDICoreAgent(BaseAgent):
else: else:
self.logger.debug("Failed to remove belief (it was not in the belief base).") self.logger.debug("Failed to remove belief (it was not in the belief base).")
# TODO: decide if this is needed
def _remove_all_with_name(self, name: str): def _remove_all_with_name(self, name: str):
""" """
Removes all beliefs that match the given `name`. Removes all beliefs that match the given `name`.
@@ -155,7 +162,8 @@ class BDICoreAgent(BaseAgent):
removed_count = 0 removed_count = 0
for group in relevant_groups: for group in relevant_groups:
for belief in self.bdi_agent.beliefs[group]: beliefs_to_remove = list(self.bdi_agent.beliefs[group])
for belief in beliefs_to_remove:
self.bdi_agent.call( self.bdi_agent.call(
agentspeak.Trigger.removal, agentspeak.Trigger.removal,
agentspeak.GoalType.belief, agentspeak.GoalType.belief,
@@ -175,21 +183,37 @@ class BDICoreAgent(BaseAgent):
the function expects (which will be located in `term.args`). the function expects (which will be located in `term.args`).
""" """
@self.actions.add(".reply", 1) @self.actions.add(".reply", 3)
def _reply(agent, term, intention): def _reply(agent: "BDICoreAgent", term, intention):
""" """
Sends text to the LLM. Sends text to the LLM (AgentSpeak action).
Example: .reply("Hello LLM!", "Some norm", "Some goal")
""" """
message_text = agentspeak.grounded(term.args[0], intention.scope) message_text = agentspeak.grounded(term.args[0], intention.scope)
norms = agentspeak.grounded(term.args[1], intention.scope)
goals = agentspeak.grounded(term.args[2], intention.scope)
asyncio.create_task(self._send_to_llm(str(message_text))) self.logger.debug("Norms: %s", norms)
self.logger.debug("Goals: %s", goals)
self.logger.debug("User text: %s", message_text)
asyncio.create_task(self._send_to_llm(str(message_text), str(norms), str(goals)))
yield yield
async def _send_to_llm(self, text: str): async def _send_to_llm(self, text: str, norms: str = None, goals: str = None):
""" """
Sends a text query to the LLM agent asynchronously. Sends a text query to the LLM agent asynchronously.
""" """
msg = InternalMessage(to=settings.agent_settings.llm_name, sender=self.name, body=text) prompt = LLMPromptMessage(
text=text,
norms=norms.split("\n") if norms else [],
goals=goals.split("\n") if norms else [],
)
msg = InternalMessage(
to=settings.agent_settings.llm_name,
sender=self.name,
body=prompt.model_dump_json(),
)
await self.send(msg) await self.send(msg)
self.logger.info("Message sent to LLM agent: %s", text) self.logger.info("Message sent to LLM agent: %s", text)

View File

@@ -1,3 +1,6 @@
+user_said(Message) <- norms("").
goals("").
+user_said(Message) : norms(Norms) & goals(Goals) <-
-user_said(Message); -user_said(Message);
.reply(Message). .reply(Message, Norms, Goals).

View File

@@ -0,0 +1,67 @@
import zmq
from pydantic import ValidationError
from zmq.asyncio import Context
from control_backend.agents import BaseAgent
from control_backend.core.agent_system import InternalMessage
from control_backend.core.config import settings
from control_backend.schemas.belief_message import Belief, BeliefMessage
from control_backend.schemas.program import Program
class BDIProgramManager(BaseAgent):
"""
Will interpret programs received from the HTTP endpoint. Extracts norms, goals, triggers and
forwards them to the BDI as beliefs.
"""
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.sub_socket = None
async def _send_to_bdi(self, program: Program):
first_phase = program.phases[0]
norms_belief = Belief(
name="norms",
arguments=[norm.norm for norm in first_phase.norms],
replace=True,
)
goals_belief = Belief(
name="goals",
arguments=[goal.description for goal in first_phase.goals],
replace=True,
)
program_beliefs = BeliefMessage(beliefs=[norms_belief, goals_belief])
message = InternalMessage(
to=settings.agent_settings.bdi_core_name,
sender=self.name,
body=program_beliefs.model_dump_json(),
thread="beliefs",
)
await self.send(message)
self.logger.debug("Sent new norms and goals to the BDI agent.")
async def _receive_programs(self):
"""
Continuously receive programs from the HTTP endpoint, sent to us over ZMQ.
"""
while True:
topic, body = await self.sub_socket.recv_multipart()
try:
program = Program.model_validate_json(body)
except ValidationError:
self.logger.exception("Received an invalid program.")
continue
await self._send_to_bdi(program)
async def setup(self):
context = Context.instance()
self.sub_socket = context.socket(zmq.SUB)
self.sub_socket.connect(settings.zmq_settings.internal_sub_address)
self.sub_socket.subscribe("program")
self.add_behavior(self._receive_programs())

View File

@@ -1,9 +1,11 @@
import json import json
from pydantic import ValidationError
from control_backend.agents.base import BaseAgent from control_backend.agents.base import BaseAgent
from control_backend.core.agent_system import InternalMessage from control_backend.core.agent_system import InternalMessage
from control_backend.core.config import settings from control_backend.core.config import settings
from control_backend.schemas.belief_message import BeliefMessage from control_backend.schemas.belief_message import Belief, BeliefMessage
class BDIBeliefCollectorAgent(BaseAgent): class BDIBeliefCollectorAgent(BaseAgent):
@@ -60,10 +62,30 @@ class BDIBeliefCollectorAgent(BaseAgent):
self.logger.debug("Received empty beliefs set.") self.logger.debug("Received empty beliefs set.")
return return
def try_create_belief(name, arguments) -> Belief | None:
"""
Create a belief object from name and arguments, or return None silently if the input is
not correct.
:param name: The name of the belief.
:param arguments: The arguments of the belief.
:return: A Belief object if the input is valid or None.
"""
try:
return Belief(name=name, arguments=arguments)
except ValidationError:
return None
beliefs = [
belief
for name, arguments in beliefs.items()
if (belief := try_create_belief(name, arguments)) is not None
]
self.logger.debug("Forwarding %d beliefs.", len(beliefs)) self.logger.debug("Forwarding %d beliefs.", len(beliefs))
for belief_name, belief_list in beliefs.items(): for belief in beliefs:
for belief in belief_list: for argument in belief.arguments:
self.logger.debug(" - %s %s", belief_name, str(belief)) self.logger.debug(" - %s %s", belief.name, argument)
await self._send_beliefs_to_bdi(beliefs, origin=origin) await self._send_beliefs_to_bdi(beliefs, origin=origin)
@@ -71,7 +93,7 @@ class BDIBeliefCollectorAgent(BaseAgent):
"""TODO: implement (after we have emotional recognition)""" """TODO: implement (after we have emotional recognition)"""
pass pass
async def _send_beliefs_to_bdi(self, beliefs: dict, origin: str | None = None): async def _send_beliefs_to_bdi(self, beliefs: list[Belief], origin: str | None = None):
""" """
Sends a unified belief packet to the BDI agent. Sends a unified belief packet to the BDI agent.
""" """

View File

@@ -1,13 +1,16 @@
import json import json
import re import re
import uuid
from collections.abc import AsyncGenerator from collections.abc import AsyncGenerator
import httpx import httpx
from pydantic import ValidationError
from control_backend.agents import BaseAgent from control_backend.agents import BaseAgent
from control_backend.core.agent_system import InternalMessage from control_backend.core.agent_system import InternalMessage
from control_backend.core.config import settings from control_backend.core.config import settings
from ...schemas.llm_prompt_message import LLMPromptMessage
from .llm_instructions import LLMInstructions from .llm_instructions import LLMInstructions
@@ -18,19 +21,26 @@ class LLMAgent(BaseAgent):
and responds with processed LLM output. and responds with processed LLM output.
""" """
def __init__(self, name: str):
super().__init__(name)
self.history = []
async def setup(self): async def setup(self):
self.logger.info("Setting up %s.", self.name) self.logger.info("Setting up %s.", self.name)
async def handle_message(self, msg: InternalMessage): async def handle_message(self, msg: InternalMessage):
if msg.sender == settings.agent_settings.bdi_core_name: if msg.sender == settings.agent_settings.bdi_core_name:
self.logger.debug("Processing message from BDI core.") self.logger.debug("Processing message from BDI core.")
await self._process_bdi_message(msg) try:
prompt_message = LLMPromptMessage.model_validate_json(msg.body)
await self._process_bdi_message(prompt_message)
except ValidationError:
self.logger.debug("Prompt message from BDI core is invalid.")
else: else:
self.logger.debug("Message ignored (not from BDI core.") self.logger.debug("Message ignored (not from BDI core.")
async def _process_bdi_message(self, message: InternalMessage): async def _process_bdi_message(self, message: LLMPromptMessage):
user_text = message.body async for chunk in self._query_llm(message.text, message.norms, message.goals):
async for chunk in self._query_llm(user_text):
await self._send_reply(chunk) await self._send_reply(chunk)
self.logger.debug( self.logger.debug(
"Finished processing BDI message. Response sent in chunks to BDI core." "Finished processing BDI message. Response sent in chunks to BDI core."
@@ -47,39 +57,49 @@ class LLMAgent(BaseAgent):
) )
await self.send(reply) await self.send(reply)
async def _query_llm(self, prompt: str) -> AsyncGenerator[str]: async def _query_llm(
self, prompt: str, norms: list[str], goals: list[str]
) -> AsyncGenerator[str]:
""" """
Sends a chat completion request to the local LLM service and streams the response by Sends a chat completion request to the local LLM service and streams the response by
yielding fragments separated by punctuation like. yielding fragments separated by punctuation like.
:param prompt: Input text prompt to pass to the LLM. :param prompt: Input text prompt to pass to the LLM.
:param norms: Norms the LLM should hold itself to.
:param goals: Goals the LLM should achieve.
:yield: Fragments of the LLM-generated content. :yield: Fragments of the LLM-generated content.
""" """
instructions = LLMInstructions( self.history.append(
"- Be friendly and respectful.\n" {
"- Make the conversation feel natural and engaging.\n" "role": "user",
"- Speak like a pirate.\n" "content": prompt,
"- When the user asks what you can do, tell them.", }
"- Try to learn the user's name during conversation.\n"
"- Suggest playing a game of asking yes or no questions where you think of a word "
"and the user must guess it.",
) )
instructions = LLMInstructions(norms if norms else None, goals if goals else None)
messages = [ messages = [
{ {
"role": "developer", "role": "developer",
"content": instructions.build_developer_instruction(), "content": instructions.build_developer_instruction(),
}, },
{ *self.history,
"role": "user",
"content": prompt,
},
] ]
message_id = str(uuid.uuid4())
try: try:
full_message = ""
current_chunk = "" current_chunk = ""
async for token in self._stream_query_llm(messages): async for token in self._stream_query_llm(messages):
full_message += token
current_chunk += token current_chunk += token
self.logger.info(
"Received token: %s",
full_message,
extra={"reference": message_id}, # Used in the UI to update old logs
)
# Stream the message in chunks separated by punctuation. # Stream the message in chunks separated by punctuation.
# We include the delimiter in the emitted chunk for natural flow. # We include the delimiter in the emitted chunk for natural flow.
pattern = re.compile(r".*?(?:,|;|:|—||\.{3}|…|\.|\?|!)\s*", re.DOTALL) pattern = re.compile(r".*?(?:,|;|:|—||\.{3}|…|\.|\?|!)\s*", re.DOTALL)
@@ -92,6 +112,13 @@ class LLMAgent(BaseAgent):
# Yield any remaining tail # Yield any remaining tail
if current_chunk: if current_chunk:
yield current_chunk yield current_chunk
self.history.append(
{
"role": "assistant",
"content": full_message,
}
)
except httpx.HTTPError as err: except httpx.HTTPError as err:
self.logger.error("HTTP error.", exc_info=err) self.logger.error("HTTP error.", exc_info=err)
yield "LLM service unavailable." yield "LLM service unavailable."

View File

@@ -5,21 +5,21 @@ class LLMInstructions:
""" """
@staticmethod @staticmethod
def default_norms() -> str: def default_norms() -> list[str]:
return """ return [
Be friendly and respectful. "Be friendly and respectful.",
Make the conversation feel natural and engaging. "Make the conversation feel natural and engaging.",
""".strip() ]
@staticmethod @staticmethod
def default_goals() -> str: def default_goals() -> list[str]:
return """ return [
Try to learn the user's name during conversation. "Try to learn the user's name during conversation.",
""".strip() ]
def __init__(self, norms: str | None = None, goals: str | None = None): def __init__(self, norms: list[str] | None = None, goals: list[str] | None = None):
self.norms = norms if norms is not None else self.default_norms() self.norms = norms or self.default_norms()
self.goals = goals if goals is not None else self.default_goals() self.goals = goals or self.default_goals()
def build_developer_instruction(self) -> str: def build_developer_instruction(self) -> str:
""" """
@@ -35,12 +35,14 @@ class LLMInstructions:
if self.norms: if self.norms:
sections.append("Norms to follow:") sections.append("Norms to follow:")
sections.append(self.norms) for norm in self.norms:
sections.append("- " + norm)
sections.append("") sections.append("")
if self.goals: if self.goals:
sections.append("Goals to reach:") sections.append("Goals to reach:")
sections.append(self.goals) for goal in self.goals:
sections.append("- " + goal)
sections.append("") sections.append("")
return "\n".join(sections).strip() return "\n".join(sections).strip()

View File

@@ -14,6 +14,7 @@ class AgentSettings(BaseModel):
# agent names # agent names
bdi_core_name: str = "bdi_core_agent" bdi_core_name: str = "bdi_core_agent"
bdi_belief_collector_name: str = "belief_collector_agent" bdi_belief_collector_name: str = "belief_collector_agent"
bdi_program_manager_name: str = "bdi_program_manager_agent"
text_belief_extractor_name: str = "text_belief_extractor_agent" text_belief_extractor_name: str = "text_belief_extractor_agent"
vad_name: str = "vad_agent" vad_name: str = "vad_agent"
llm_name: str = "llm_agent" llm_name: str = "llm_agent"

View File

@@ -13,6 +13,7 @@ from control_backend.agents.bdi import (
BDICoreAgent, BDICoreAgent,
TextBeliefExtractorAgent, TextBeliefExtractorAgent,
) )
from control_backend.agents.bdi.bdi_program_manager import BDIProgramManager
# Communication agents # Communication agents
from control_backend.agents.communication import RICommunicationAgent from control_backend.agents.communication import RICommunicationAgent
@@ -112,6 +113,12 @@ async def lifespan(app: FastAPI):
VADAgent, VADAgent,
{"audio_in_address": settings.zmq_settings.vad_agent_address, "audio_in_bind": False}, {"audio_in_address": settings.zmq_settings.vad_agent_address, "audio_in_bind": False},
), ),
"ProgramManagerAgent": (
BDIProgramManager,
{
"name": settings.agent_settings.bdi_program_manager_name,
},
),
} }
agents = [] agents = []

View File

@@ -1,5 +1,11 @@
from pydantic import BaseModel from pydantic import BaseModel
class Belief(BaseModel):
name: str
arguments: list[str]
replace: bool = False
class BeliefMessage(BaseModel): class BeliefMessage(BaseModel):
beliefs: dict[str, list[str]] beliefs: list[Belief]

View File

@@ -0,0 +1,7 @@
from pydantic import BaseModel
class LLMPromptMessage(BaseModel):
text: str
norms: list[str]
goals: list[str]

View File

@@ -3,35 +3,35 @@ from pydantic import BaseModel
class Norm(BaseModel): class Norm(BaseModel):
id: str id: str
name: str label: str
value: str norm: str
class Goal(BaseModel): class Goal(BaseModel):
id: str id: str
name: str label: str
description: str description: str
achieved: bool achieved: bool
class Trigger(BaseModel): class TriggerKeyword(BaseModel):
id: str
keyword: str
class KeywordTrigger(BaseModel):
id: str id: str
label: str label: str
type: str type: str
value: list[str] keywords: list[TriggerKeyword]
class PhaseData(BaseModel):
norms: list[Norm]
goals: list[Goal]
triggers: list[Trigger]
class Phase(BaseModel): class Phase(BaseModel):
id: str id: str
name: str label: str
nextPhaseId: str norms: list[Norm]
phaseData: PhaseData goals: list[Goal]
triggers: list[KeywordTrigger]
class Program(BaseModel): class Program(BaseModel):

View File

@@ -7,7 +7,7 @@ import pytest
from control_backend.agents.bdi.bdi_core_agent.bdi_core_agent import BDICoreAgent from control_backend.agents.bdi.bdi_core_agent.bdi_core_agent import BDICoreAgent
from control_backend.core.agent_system import InternalMessage from control_backend.core.agent_system import InternalMessage
from control_backend.core.config import settings from control_backend.core.config import settings
from control_backend.schemas.belief_message import BeliefMessage from control_backend.schemas.belief_message import Belief, BeliefMessage
@pytest.fixture @pytest.fixture
@@ -45,7 +45,7 @@ async def test_setup_no_asl(mock_agentspeak_env, agent):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_handle_belief_collector_message(agent, mock_settings): async def test_handle_belief_collector_message(agent, mock_settings):
"""Test that incoming beliefs are added to the BDI agent""" """Test that incoming beliefs are added to the BDI agent"""
beliefs = {"user_said": ["Hello"]} beliefs = [Belief(name="user_said", arguments=["Hello"])]
msg = InternalMessage( msg = InternalMessage(
to="bdi_agent", to="bdi_agent",
sender=mock_settings.agent_settings.bdi_belief_collector_name, sender=mock_settings.agent_settings.bdi_belief_collector_name,
@@ -116,11 +116,11 @@ async def test_custom_actions(agent):
# Invoke action # Invoke action
mock_term = MagicMock() mock_term = MagicMock()
mock_term.args = ["Hello"] mock_term.args = ["Hello", "Norm", "Goal"]
mock_intention = MagicMock() mock_intention = MagicMock()
# Run generator # Run generator
gen = action_fn(agent, mock_term, mock_intention) gen = action_fn(agent, mock_term, mock_intention)
next(gen) # Execute next(gen) # Execute
agent._send_to_llm.assert_called_with("Hello") agent._send_to_llm.assert_called_with("Hello", "Norm", "Goal")

View File

@@ -8,6 +8,7 @@ from control_backend.agents.bdi import (
) )
from control_backend.core.agent_system import InternalMessage from control_backend.core.agent_system import InternalMessage
from control_backend.core.config import settings from control_backend.core.config import settings
from control_backend.schemas.belief_message import Belief
@pytest.fixture @pytest.fixture
@@ -57,10 +58,11 @@ async def test_handle_message_bad_json(agent, mocker):
async def test_handle_belief_text_sends_when_beliefs_exist(agent, mocker): async def test_handle_belief_text_sends_when_beliefs_exist(agent, mocker):
payload = {"type": "belief_extraction_text", "beliefs": {"user_said": ["hello"]}} payload = {"type": "belief_extraction_text", "beliefs": {"user_said": ["hello"]}}
spy = mocker.patch.object(agent, "_send_beliefs_to_bdi", new_callable=AsyncMock) spy = mocker.patch.object(agent, "_send_beliefs_to_bdi", new_callable=AsyncMock)
expected = [Belief(name="user_said", arguments=["hello"])]
await agent._handle_belief_text(payload, "origin") await agent._handle_belief_text(payload, "origin")
spy.assert_awaited_once_with(payload["beliefs"], origin="origin") spy.assert_awaited_once_with(expected, origin="origin")
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -76,7 +78,7 @@ async def test_handle_belief_text_no_send_when_empty(agent, mocker):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_send_beliefs_to_bdi(agent): async def test_send_beliefs_to_bdi(agent):
agent.send = AsyncMock() agent.send = AsyncMock()
beliefs = {"user_said": ["hello", "world"]} beliefs = [Belief(name="user_said", arguments=["hello", "world"])]
await agent._send_beliefs_to_bdi(beliefs, origin="origin") await agent._send_beliefs_to_bdi(beliefs, origin="origin")
@@ -84,4 +86,4 @@ async def test_send_beliefs_to_bdi(agent):
sent: InternalMessage = agent.send.call_args.args[0] sent: InternalMessage = agent.send.call_args.args[0]
assert sent.to == settings.agent_settings.bdi_core_name assert sent.to == settings.agent_settings.bdi_core_name
assert sent.thread == "beliefs" assert sent.thread == "beliefs"
assert json.loads(sent.body)["beliefs"] == beliefs assert json.loads(sent.body)["beliefs"] == [belief.model_dump() for belief in beliefs]

View File

@@ -7,6 +7,7 @@ import pytest
from control_backend.agents.llm.llm_agent import LLMAgent, LLMInstructions from control_backend.agents.llm.llm_agent import LLMAgent, LLMInstructions
from control_backend.core.agent_system import InternalMessage from control_backend.core.agent_system import InternalMessage
from control_backend.schemas.llm_prompt_message import LLMPromptMessage
@pytest.fixture @pytest.fixture
@@ -49,8 +50,11 @@ async def test_llm_processing_success(mock_httpx_client, mock_settings):
agent.send = AsyncMock() # Mock the send method to verify replies agent.send = AsyncMock() # Mock the send method to verify replies
# Simulate receiving a message from BDI # Simulate receiving a message from BDI
prompt = LLMPromptMessage(text="Hi", norms=[], goals=[])
msg = InternalMessage( msg = InternalMessage(
to="llm_agent", sender=mock_settings.agent_settings.bdi_core_name, body="Hi" to="llm_agent",
sender=mock_settings.agent_settings.bdi_core_name,
body=prompt.model_dump_json(),
) )
await agent.handle_message(msg) await agent.handle_message(msg)
@@ -68,7 +72,12 @@ async def test_llm_processing_success(mock_httpx_client, mock_settings):
async def test_llm_processing_errors(mock_httpx_client, mock_settings): async def test_llm_processing_errors(mock_httpx_client, mock_settings):
agent = LLMAgent("llm_agent") agent = LLMAgent("llm_agent")
agent.send = AsyncMock() agent.send = AsyncMock()
msg = InternalMessage(to="llm", sender=mock_settings.agent_settings.bdi_core_name, body="Hi") prompt = LLMPromptMessage(text="Hi", norms=[], goals=[])
msg = InternalMessage(
to="llm",
sender=mock_settings.agent_settings.bdi_core_name,
body=prompt.model_dump_json(),
)
# HTTP Error # HTTP Error
mock_httpx_client.stream = MagicMock(side_effect=httpx.HTTPError("Fail")) mock_httpx_client.stream = MagicMock(side_effect=httpx.HTTPError("Fail"))
@@ -103,8 +112,11 @@ async def test_llm_json_error(mock_httpx_client, mock_settings):
agent.send = AsyncMock() agent.send = AsyncMock()
with patch.object(agent.logger, "error") as log: with patch.object(agent.logger, "error") as log:
prompt = LLMPromptMessage(text="Hi", norms=[], goals=[])
msg = InternalMessage( msg = InternalMessage(
to="llm", sender=mock_settings.agent_settings.bdi_core_name, body="Hi" to="llm",
sender=mock_settings.agent_settings.bdi_core_name,
body=prompt.model_dump_json(),
) )
await agent.handle_message(msg) await agent.handle_message(msg)
log.assert_called() # Should log JSONDecodeError log.assert_called() # Should log JSONDecodeError
@@ -112,10 +124,10 @@ async def test_llm_json_error(mock_httpx_client, mock_settings):
def test_llm_instructions(): def test_llm_instructions():
# Full custom # Full custom
instr = LLMInstructions(norms="N", goals="G") instr = LLMInstructions(norms=["N1", "N2"], goals=["G1", "G2"])
text = instr.build_developer_instruction() text = instr.build_developer_instruction()
assert "Norms to follow:\nN" in text assert "Norms to follow:\n- N1\n- N2" in text
assert "Goals to reach:\nG" in text assert "Goals to reach:\n- G1\n- G2" in text
# Defaults # Defaults
instr_def = LLMInstructions() instr_def = LLMInstructions()

View File

@@ -29,22 +29,22 @@ def make_valid_program_dict():
"phases": [ "phases": [
{ {
"id": "phase1", "id": "phase1",
"name": "basephase", "label": "basephase",
"nextPhaseId": "phase2", "norms": [{"id": "n1", "label": "norm", "norm": "be nice"}],
"phaseData": { "goals": [
"norms": [{"id": "n1", "name": "norm", "value": "be nice"}], {"id": "g1", "label": "goal", "description": "test goal", "achieved": False}
"goals": [ ],
{"id": "g1", "name": "goal", "description": "test goal", "achieved": False} "triggers": [
], {
"triggers": [ "id": "t1",
{ "label": "trigger",
"id": "t1", "type": "keywords",
"label": "trigger", "keywords": [
"type": "keyword", {"id": "kw1", "keyword": "keyword1"},
"value": ["stop", "exit"], {"id": "kw2", "keyword": "keyword2"},
} ],
], },
}, ],
} }
] ]
} }

View File

@@ -1,49 +1,52 @@
import pytest import pytest
from pydantic import ValidationError from pydantic import ValidationError
from control_backend.schemas.program import Goal, Norm, Phase, PhaseData, Program, Trigger from control_backend.schemas.program import (
Goal,
KeywordTrigger,
Norm,
Phase,
Program,
TriggerKeyword,
)
def base_norm() -> Norm: def base_norm() -> Norm:
return Norm( return Norm(
id="norm1", id="norm1",
name="testNorm", label="testNorm",
value="you should act nice", norm="testNormNorm",
) )
def base_goal() -> Goal: def base_goal() -> Goal:
return Goal( return Goal(
id="goal1", id="goal1",
name="testGoal", label="testGoal",
description="you should act nice", description="testGoalDescription",
achieved=False, achieved=False,
) )
def base_trigger() -> Trigger: def base_trigger() -> KeywordTrigger:
return Trigger( return KeywordTrigger(
id="trigger1", id="trigger1",
label="testTrigger", label="testTrigger",
type="keyword", type="keywords",
value=["Stop", "Exit"], keywords=[
) TriggerKeyword(id="keyword1", keyword="testKeyword1"),
TriggerKeyword(id="keyword1", keyword="testKeyword2"),
],
def base_phase_data() -> PhaseData:
return PhaseData(
norms=[base_norm()],
goals=[base_goal()],
triggers=[base_trigger()],
) )
def base_phase() -> Phase: def base_phase() -> Phase:
return Phase( return Phase(
id="phase1", id="phase1",
name="basephase", label="basephase",
nextPhaseId="phase2", norms=[base_norm()],
phaseData=base_phase_data(), goals=[base_goal()],
triggers=[base_trigger()],
) )
@@ -65,7 +68,7 @@ def test_valid_program():
program = base_program() program = base_program()
validated = Program.model_validate(program) validated = Program.model_validate(program)
assert isinstance(validated, Program) assert isinstance(validated, Program)
assert validated.phases[0].phaseData.norms[0].name == "testNorm" assert validated.phases[0].norms[0].norm == "testNormNorm"
def test_valid_deepprogram(): def test_valid_deepprogram():
@@ -73,10 +76,9 @@ def test_valid_deepprogram():
validated = Program.model_validate(program) validated = Program.model_validate(program)
# validate nested components directly # validate nested components directly
phase = validated.phases[0] phase = validated.phases[0]
assert isinstance(phase.phaseData, PhaseData) assert isinstance(phase.goals[0], Goal)
assert isinstance(phase.phaseData.goals[0], Goal) assert isinstance(phase.triggers[0], KeywordTrigger)
assert isinstance(phase.phaseData.triggers[0], Trigger) assert isinstance(phase.norms[0], Norm)
assert isinstance(phase.phaseData.norms[0], Norm)
def test_invalid_program(): def test_invalid_program():