diff --git a/.githooks/check-commit-msg.sh b/.githooks/check-commit-msg.sh index eacf2a8..497a32f 100755 --- a/.githooks/check-commit-msg.sh +++ b/.githooks/check-commit-msg.sh @@ -30,7 +30,7 @@ HEADER=$(head -n 1 "$COMMIT_MSG_FILE") # Check for Merge commits (covers 'git merge' and PR merges from GitHub/GitLab) # Examples: "Merge branch 'main' into ...", "Merge pull request #123 from ..." -MERGE_PATTERN="^Merge (branch|pull request|tag) .*" +MERGE_PATTERN="^Merge (remote-tracking )?(branch|pull request|tag) .*" if [[ "$HEADER" =~ $MERGE_PATTERN ]]; then echo -e "${GREEN}Merge commit detected by message content. Skipping validation.${NC}" exit 0 diff --git a/src/control_backend/agents/bdi/bdi_core_agent/bdi_core_agent.py b/src/control_backend/agents/bdi/bdi_core_agent/bdi_core_agent.py index b798982..7a8cd13 100644 --- a/src/control_backend/agents/bdi/bdi_core_agent/bdi_core_agent.py +++ b/src/control_backend/agents/bdi/bdi_core_agent/bdi_core_agent.py @@ -11,9 +11,12 @@ from pydantic import ValidationError from control_backend.agents.base import BaseAgent from control_backend.core.agent_system import InternalMessage from control_backend.core.config import settings -from control_backend.schemas.belief_message import BeliefMessage +from control_backend.schemas.belief_message import Belief, BeliefMessage +from control_backend.schemas.llm_prompt_message import LLMPromptMessage from control_backend.schemas.ri_message import SpeechCommand +DELIMITER = ";\n" # TODO: temporary until we support lists in AgentSpeak + class BDICoreAgent(BaseAgent): bdi_agent: agentspeak.runtime.Agent @@ -77,17 +80,18 @@ class BDICoreAgent(BaseAgent): """ Route incoming messages (Beliefs or LLM responses). """ - sender = msg.sender + self.logger.debug("Processing message from %s.", msg.sender) - match sender: - case settings.agent_settings.bdi_belief_collector_name: - self.logger.debug("Processing message from belief collector.") - try: - if msg.thread == "beliefs": - beliefs = BeliefMessage.model_validate_json(msg.body).beliefs - self._add_beliefs(beliefs) - except ValidationError: - self.logger.exception("Error processing belief.") + if msg.thread == "beliefs": + try: + beliefs = BeliefMessage.model_validate_json(msg.body).beliefs + self._apply_beliefs(beliefs) + except ValidationError: + self.logger.exception("Error processing belief.") + return + + # The message was not a belief, handle special cases based on sender + match msg.sender: case settings.agent_settings.llm_name: content = msg.body self.logger.info("Received LLM response: %s", content) @@ -101,15 +105,19 @@ class BDICoreAgent(BaseAgent): ) await self.send(out_msg) - def _add_beliefs(self, beliefs: dict[str, list[str]]): + def _apply_beliefs(self, beliefs: list[Belief]): if not beliefs: return - for name, args in beliefs.items(): - self._add_belief(name, args) + for belief in beliefs: + if belief.replace: + self._remove_all_with_name(belief.name) + self._add_belief(belief.name, belief.arguments) def _add_belief(self, name: str, args: Iterable[str] = []): - new_args = (agentspeak.Literal(arg) for arg in args) + # new_args = (agentspeak.Literal(arg) for arg in args) # TODO: Eventually support multiple + merged_args = DELIMITER.join(arg for arg in args) + new_args = (agentspeak.Literal(merged_args),) term = agentspeak.Literal(name, new_args) self.bdi_agent.call( @@ -143,7 +151,6 @@ class BDICoreAgent(BaseAgent): else: self.logger.debug("Failed to remove belief (it was not in the belief base).") - # TODO: decide if this is needed def _remove_all_with_name(self, name: str): """ Removes all beliefs that match the given `name`. @@ -155,7 +162,8 @@ class BDICoreAgent(BaseAgent): removed_count = 0 for group in relevant_groups: - for belief in self.bdi_agent.beliefs[group]: + beliefs_to_remove = list(self.bdi_agent.beliefs[group]) + for belief in beliefs_to_remove: self.bdi_agent.call( agentspeak.Trigger.removal, agentspeak.GoalType.belief, @@ -175,21 +183,37 @@ class BDICoreAgent(BaseAgent): the function expects (which will be located in `term.args`). """ - @self.actions.add(".reply", 1) - def _reply(agent, term, intention): + @self.actions.add(".reply", 3) + def _reply(agent: "BDICoreAgent", term, intention): """ - Sends text to the LLM. + Sends text to the LLM (AgentSpeak action). + Example: .reply("Hello LLM!", "Some norm", "Some goal") """ message_text = agentspeak.grounded(term.args[0], intention.scope) + norms = agentspeak.grounded(term.args[1], intention.scope) + goals = agentspeak.grounded(term.args[2], intention.scope) - asyncio.create_task(self._send_to_llm(str(message_text))) + self.logger.debug("Norms: %s", norms) + self.logger.debug("Goals: %s", goals) + self.logger.debug("User text: %s", message_text) + + asyncio.create_task(self._send_to_llm(str(message_text), str(norms), str(goals))) yield - async def _send_to_llm(self, text: str): + async def _send_to_llm(self, text: str, norms: str = None, goals: str = None): """ Sends a text query to the LLM agent asynchronously. """ - msg = InternalMessage(to=settings.agent_settings.llm_name, sender=self.name, body=text) + prompt = LLMPromptMessage( + text=text, + norms=norms.split("\n") if norms else [], + goals=goals.split("\n") if norms else [], + ) + msg = InternalMessage( + to=settings.agent_settings.llm_name, + sender=self.name, + body=prompt.model_dump_json(), + ) await self.send(msg) self.logger.info("Message sent to LLM agent: %s", text) diff --git a/src/control_backend/agents/bdi/bdi_core_agent/rules.asl b/src/control_backend/agents/bdi/bdi_core_agent/rules.asl index a685f93..cc9b4ef 100644 --- a/src/control_backend/agents/bdi/bdi_core_agent/rules.asl +++ b/src/control_backend/agents/bdi/bdi_core_agent/rules.asl @@ -1,3 +1,6 @@ -+user_said(Message) <- +norms(""). +goals(""). + ++user_said(Message) : norms(Norms) & goals(Goals) <- -user_said(Message); - .reply(Message). + .reply(Message, Norms, Goals). diff --git a/src/control_backend/agents/bdi/bdi_program_manager.py b/src/control_backend/agents/bdi/bdi_program_manager.py new file mode 100644 index 0000000..14d95f8 --- /dev/null +++ b/src/control_backend/agents/bdi/bdi_program_manager.py @@ -0,0 +1,67 @@ +import zmq +from pydantic import ValidationError +from zmq.asyncio import Context + +from control_backend.agents import BaseAgent +from control_backend.core.agent_system import InternalMessage +from control_backend.core.config import settings +from control_backend.schemas.belief_message import Belief, BeliefMessage +from control_backend.schemas.program import Program + + +class BDIProgramManager(BaseAgent): + """ + Will interpret programs received from the HTTP endpoint. Extracts norms, goals, triggers and + forwards them to the BDI as beliefs. + """ + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.sub_socket = None + + async def _send_to_bdi(self, program: Program): + first_phase = program.phases[0] + norms_belief = Belief( + name="norms", + arguments=[norm.norm for norm in first_phase.norms], + replace=True, + ) + goals_belief = Belief( + name="goals", + arguments=[goal.description for goal in first_phase.goals], + replace=True, + ) + program_beliefs = BeliefMessage(beliefs=[norms_belief, goals_belief]) + + message = InternalMessage( + to=settings.agent_settings.bdi_core_name, + sender=self.name, + body=program_beliefs.model_dump_json(), + thread="beliefs", + ) + await self.send(message) + self.logger.debug("Sent new norms and goals to the BDI agent.") + + async def _receive_programs(self): + """ + Continuously receive programs from the HTTP endpoint, sent to us over ZMQ. + """ + while True: + topic, body = await self.sub_socket.recv_multipart() + + try: + program = Program.model_validate_json(body) + except ValidationError: + self.logger.exception("Received an invalid program.") + continue + + await self._send_to_bdi(program) + + async def setup(self): + context = Context.instance() + + self.sub_socket = context.socket(zmq.SUB) + self.sub_socket.connect(settings.zmq_settings.internal_sub_address) + self.sub_socket.subscribe("program") + + self.add_behavior(self._receive_programs()) diff --git a/src/control_backend/agents/bdi/belief_collector_agent.py b/src/control_backend/agents/bdi/belief_collector_agent.py index 5d25204..9f68461 100644 --- a/src/control_backend/agents/bdi/belief_collector_agent.py +++ b/src/control_backend/agents/bdi/belief_collector_agent.py @@ -1,9 +1,11 @@ import json +from pydantic import ValidationError + from control_backend.agents.base import BaseAgent from control_backend.core.agent_system import InternalMessage from control_backend.core.config import settings -from control_backend.schemas.belief_message import BeliefMessage +from control_backend.schemas.belief_message import Belief, BeliefMessage class BDIBeliefCollectorAgent(BaseAgent): @@ -60,10 +62,30 @@ class BDIBeliefCollectorAgent(BaseAgent): self.logger.debug("Received empty beliefs set.") return + def try_create_belief(name, arguments) -> Belief | None: + """ + Create a belief object from name and arguments, or return None silently if the input is + not correct. + + :param name: The name of the belief. + :param arguments: The arguments of the belief. + :return: A Belief object if the input is valid or None. + """ + try: + return Belief(name=name, arguments=arguments) + except ValidationError: + return None + + beliefs = [ + belief + for name, arguments in beliefs.items() + if (belief := try_create_belief(name, arguments)) is not None + ] + self.logger.debug("Forwarding %d beliefs.", len(beliefs)) - for belief_name, belief_list in beliefs.items(): - for belief in belief_list: - self.logger.debug(" - %s %s", belief_name, str(belief)) + for belief in beliefs: + for argument in belief.arguments: + self.logger.debug(" - %s %s", belief.name, argument) await self._send_beliefs_to_bdi(beliefs, origin=origin) @@ -71,7 +93,7 @@ class BDIBeliefCollectorAgent(BaseAgent): """TODO: implement (after we have emotional recognition)""" pass - async def _send_beliefs_to_bdi(self, beliefs: dict, origin: str | None = None): + async def _send_beliefs_to_bdi(self, beliefs: list[Belief], origin: str | None = None): """ Sends a unified belief packet to the BDI agent. """ diff --git a/src/control_backend/agents/llm/llm_agent.py b/src/control_backend/agents/llm/llm_agent.py index a6950f2..cc6a982 100644 --- a/src/control_backend/agents/llm/llm_agent.py +++ b/src/control_backend/agents/llm/llm_agent.py @@ -1,13 +1,16 @@ import json import re +import uuid from collections.abc import AsyncGenerator import httpx +from pydantic import ValidationError from control_backend.agents import BaseAgent from control_backend.core.agent_system import InternalMessage from control_backend.core.config import settings +from ...schemas.llm_prompt_message import LLMPromptMessage from .llm_instructions import LLMInstructions @@ -18,19 +21,26 @@ class LLMAgent(BaseAgent): and responds with processed LLM output. """ + def __init__(self, name: str): + super().__init__(name) + self.history = [] + async def setup(self): self.logger.info("Setting up %s.", self.name) async def handle_message(self, msg: InternalMessage): if msg.sender == settings.agent_settings.bdi_core_name: self.logger.debug("Processing message from BDI core.") - await self._process_bdi_message(msg) + try: + prompt_message = LLMPromptMessage.model_validate_json(msg.body) + await self._process_bdi_message(prompt_message) + except ValidationError: + self.logger.debug("Prompt message from BDI core is invalid.") else: self.logger.debug("Message ignored (not from BDI core.") - async def _process_bdi_message(self, message: InternalMessage): - user_text = message.body - async for chunk in self._query_llm(user_text): + async def _process_bdi_message(self, message: LLMPromptMessage): + async for chunk in self._query_llm(message.text, message.norms, message.goals): await self._send_reply(chunk) self.logger.debug( "Finished processing BDI message. Response sent in chunks to BDI core." @@ -47,39 +57,49 @@ class LLMAgent(BaseAgent): ) await self.send(reply) - async def _query_llm(self, prompt: str) -> AsyncGenerator[str]: + async def _query_llm( + self, prompt: str, norms: list[str], goals: list[str] + ) -> AsyncGenerator[str]: """ Sends a chat completion request to the local LLM service and streams the response by yielding fragments separated by punctuation like. :param prompt: Input text prompt to pass to the LLM. + :param norms: Norms the LLM should hold itself to. + :param goals: Goals the LLM should achieve. :yield: Fragments of the LLM-generated content. """ - instructions = LLMInstructions( - "- Be friendly and respectful.\n" - "- Make the conversation feel natural and engaging.\n" - "- Speak like a pirate.\n" - "- When the user asks what you can do, tell them.", - "- Try to learn the user's name during conversation.\n" - "- Suggest playing a game of asking yes or no questions where you think of a word " - "and the user must guess it.", + self.history.append( + { + "role": "user", + "content": prompt, + } ) + + instructions = LLMInstructions(norms if norms else None, goals if goals else None) messages = [ { "role": "developer", "content": instructions.build_developer_instruction(), }, - { - "role": "user", - "content": prompt, - }, + *self.history, ] + message_id = str(uuid.uuid4()) + try: + full_message = "" current_chunk = "" async for token in self._stream_query_llm(messages): + full_message += token current_chunk += token + self.logger.info( + "Received token: %s", + full_message, + extra={"reference": message_id}, # Used in the UI to update old logs + ) + # Stream the message in chunks separated by punctuation. # We include the delimiter in the emitted chunk for natural flow. pattern = re.compile(r".*?(?:,|;|:|—|–|\.{3}|…|\.|\?|!)\s*", re.DOTALL) @@ -92,6 +112,13 @@ class LLMAgent(BaseAgent): # Yield any remaining tail if current_chunk: yield current_chunk + + self.history.append( + { + "role": "assistant", + "content": full_message, + } + ) except httpx.HTTPError as err: self.logger.error("HTTP error.", exc_info=err) yield "LLM service unavailable." diff --git a/src/control_backend/agents/llm/llm_instructions.py b/src/control_backend/agents/llm/llm_instructions.py index 6922fca..5e3e7ba 100644 --- a/src/control_backend/agents/llm/llm_instructions.py +++ b/src/control_backend/agents/llm/llm_instructions.py @@ -5,21 +5,21 @@ class LLMInstructions: """ @staticmethod - def default_norms() -> str: - return """ - Be friendly and respectful. - Make the conversation feel natural and engaging. - """.strip() + def default_norms() -> list[str]: + return [ + "Be friendly and respectful.", + "Make the conversation feel natural and engaging.", + ] @staticmethod - def default_goals() -> str: - return """ - Try to learn the user's name during conversation. - """.strip() + def default_goals() -> list[str]: + return [ + "Try to learn the user's name during conversation.", + ] - def __init__(self, norms: str | None = None, goals: str | None = None): - self.norms = norms if norms is not None else self.default_norms() - self.goals = goals if goals is not None else self.default_goals() + def __init__(self, norms: list[str] | None = None, goals: list[str] | None = None): + self.norms = norms or self.default_norms() + self.goals = goals or self.default_goals() def build_developer_instruction(self) -> str: """ @@ -35,12 +35,14 @@ class LLMInstructions: if self.norms: sections.append("Norms to follow:") - sections.append(self.norms) + for norm in self.norms: + sections.append("- " + norm) sections.append("") if self.goals: sections.append("Goals to reach:") - sections.append(self.goals) + for goal in self.goals: + sections.append("- " + goal) sections.append("") return "\n".join(sections).strip() diff --git a/src/control_backend/core/config.py b/src/control_backend/core/config.py index bf131af..a959ae6 100644 --- a/src/control_backend/core/config.py +++ b/src/control_backend/core/config.py @@ -14,6 +14,7 @@ class AgentSettings(BaseModel): # agent names bdi_core_name: str = "bdi_core_agent" bdi_belief_collector_name: str = "belief_collector_agent" + bdi_program_manager_name: str = "bdi_program_manager_agent" text_belief_extractor_name: str = "text_belief_extractor_agent" vad_name: str = "vad_agent" llm_name: str = "llm_agent" diff --git a/src/control_backend/main.py b/src/control_backend/main.py index b16e01d..afa923e 100644 --- a/src/control_backend/main.py +++ b/src/control_backend/main.py @@ -13,6 +13,7 @@ from control_backend.agents.bdi import ( BDICoreAgent, TextBeliefExtractorAgent, ) +from control_backend.agents.bdi.bdi_program_manager import BDIProgramManager # Communication agents from control_backend.agents.communication import RICommunicationAgent @@ -112,6 +113,12 @@ async def lifespan(app: FastAPI): VADAgent, {"audio_in_address": settings.zmq_settings.vad_agent_address, "audio_in_bind": False}, ), + "ProgramManagerAgent": ( + BDIProgramManager, + { + "name": settings.agent_settings.bdi_program_manager_name, + }, + ), } agents = [] diff --git a/src/control_backend/schemas/belief_message.py b/src/control_backend/schemas/belief_message.py index a5f7507..1a0ef89 100644 --- a/src/control_backend/schemas/belief_message.py +++ b/src/control_backend/schemas/belief_message.py @@ -1,5 +1,11 @@ from pydantic import BaseModel +class Belief(BaseModel): + name: str + arguments: list[str] + replace: bool = False + + class BeliefMessage(BaseModel): - beliefs: dict[str, list[str]] + beliefs: list[Belief] diff --git a/src/control_backend/schemas/llm_prompt_message.py b/src/control_backend/schemas/llm_prompt_message.py new file mode 100644 index 0000000..12f8887 --- /dev/null +++ b/src/control_backend/schemas/llm_prompt_message.py @@ -0,0 +1,7 @@ +from pydantic import BaseModel + + +class LLMPromptMessage(BaseModel): + text: str + norms: list[str] + goals: list[str] diff --git a/src/control_backend/schemas/program.py b/src/control_backend/schemas/program.py index c207757..db94347 100644 --- a/src/control_backend/schemas/program.py +++ b/src/control_backend/schemas/program.py @@ -3,35 +3,35 @@ from pydantic import BaseModel class Norm(BaseModel): id: str - name: str - value: str + label: str + norm: str class Goal(BaseModel): id: str - name: str + label: str description: str achieved: bool -class Trigger(BaseModel): +class TriggerKeyword(BaseModel): + id: str + keyword: str + + +class KeywordTrigger(BaseModel): id: str label: str type: str - value: list[str] - - -class PhaseData(BaseModel): - norms: list[Norm] - goals: list[Goal] - triggers: list[Trigger] + keywords: list[TriggerKeyword] class Phase(BaseModel): id: str - name: str - nextPhaseId: str - phaseData: PhaseData + label: str + norms: list[Norm] + goals: list[Goal] + triggers: list[KeywordTrigger] class Program(BaseModel): diff --git a/test/unit/agents/bdi/test_bdi_core_agent.py b/test/unit/agents/bdi/test_bdi_core_agent.py index 43ee033..5c73b76 100644 --- a/test/unit/agents/bdi/test_bdi_core_agent.py +++ b/test/unit/agents/bdi/test_bdi_core_agent.py @@ -7,7 +7,7 @@ import pytest from control_backend.agents.bdi.bdi_core_agent.bdi_core_agent import BDICoreAgent from control_backend.core.agent_system import InternalMessage from control_backend.core.config import settings -from control_backend.schemas.belief_message import BeliefMessage +from control_backend.schemas.belief_message import Belief, BeliefMessage @pytest.fixture @@ -45,7 +45,7 @@ async def test_setup_no_asl(mock_agentspeak_env, agent): @pytest.mark.asyncio async def test_handle_belief_collector_message(agent, mock_settings): """Test that incoming beliefs are added to the BDI agent""" - beliefs = {"user_said": ["Hello"]} + beliefs = [Belief(name="user_said", arguments=["Hello"])] msg = InternalMessage( to="bdi_agent", sender=mock_settings.agent_settings.bdi_belief_collector_name, @@ -116,11 +116,11 @@ async def test_custom_actions(agent): # Invoke action mock_term = MagicMock() - mock_term.args = ["Hello"] + mock_term.args = ["Hello", "Norm", "Goal"] mock_intention = MagicMock() # Run generator gen = action_fn(agent, mock_term, mock_intention) next(gen) # Execute - agent._send_to_llm.assert_called_with("Hello") + agent._send_to_llm.assert_called_with("Hello", "Norm", "Goal") diff --git a/test/unit/agents/bdi/test_belief_collector.py b/test/unit/agents/bdi/test_belief_collector.py index ca89a9d..df28ac4 100644 --- a/test/unit/agents/bdi/test_belief_collector.py +++ b/test/unit/agents/bdi/test_belief_collector.py @@ -8,6 +8,7 @@ from control_backend.agents.bdi import ( ) from control_backend.core.agent_system import InternalMessage from control_backend.core.config import settings +from control_backend.schemas.belief_message import Belief @pytest.fixture @@ -57,10 +58,11 @@ async def test_handle_message_bad_json(agent, mocker): async def test_handle_belief_text_sends_when_beliefs_exist(agent, mocker): payload = {"type": "belief_extraction_text", "beliefs": {"user_said": ["hello"]}} spy = mocker.patch.object(agent, "_send_beliefs_to_bdi", new_callable=AsyncMock) + expected = [Belief(name="user_said", arguments=["hello"])] await agent._handle_belief_text(payload, "origin") - spy.assert_awaited_once_with(payload["beliefs"], origin="origin") + spy.assert_awaited_once_with(expected, origin="origin") @pytest.mark.asyncio @@ -76,7 +78,7 @@ async def test_handle_belief_text_no_send_when_empty(agent, mocker): @pytest.mark.asyncio async def test_send_beliefs_to_bdi(agent): agent.send = AsyncMock() - beliefs = {"user_said": ["hello", "world"]} + beliefs = [Belief(name="user_said", arguments=["hello", "world"])] await agent._send_beliefs_to_bdi(beliefs, origin="origin") @@ -84,4 +86,4 @@ async def test_send_beliefs_to_bdi(agent): sent: InternalMessage = agent.send.call_args.args[0] assert sent.to == settings.agent_settings.bdi_core_name assert sent.thread == "beliefs" - assert json.loads(sent.body)["beliefs"] == beliefs + assert json.loads(sent.body)["beliefs"] == [belief.model_dump() for belief in beliefs] diff --git a/test/unit/agents/llm/test_llm_agent.py b/test/unit/agents/llm/test_llm_agent.py index 4a8b7df..2f1b72e 100644 --- a/test/unit/agents/llm/test_llm_agent.py +++ b/test/unit/agents/llm/test_llm_agent.py @@ -7,6 +7,7 @@ import pytest from control_backend.agents.llm.llm_agent import LLMAgent, LLMInstructions from control_backend.core.agent_system import InternalMessage +from control_backend.schemas.llm_prompt_message import LLMPromptMessage @pytest.fixture @@ -49,8 +50,11 @@ async def test_llm_processing_success(mock_httpx_client, mock_settings): agent.send = AsyncMock() # Mock the send method to verify replies # Simulate receiving a message from BDI + prompt = LLMPromptMessage(text="Hi", norms=[], goals=[]) msg = InternalMessage( - to="llm_agent", sender=mock_settings.agent_settings.bdi_core_name, body="Hi" + to="llm_agent", + sender=mock_settings.agent_settings.bdi_core_name, + body=prompt.model_dump_json(), ) await agent.handle_message(msg) @@ -68,7 +72,12 @@ async def test_llm_processing_success(mock_httpx_client, mock_settings): async def test_llm_processing_errors(mock_httpx_client, mock_settings): agent = LLMAgent("llm_agent") agent.send = AsyncMock() - msg = InternalMessage(to="llm", sender=mock_settings.agent_settings.bdi_core_name, body="Hi") + prompt = LLMPromptMessage(text="Hi", norms=[], goals=[]) + msg = InternalMessage( + to="llm", + sender=mock_settings.agent_settings.bdi_core_name, + body=prompt.model_dump_json(), + ) # HTTP Error mock_httpx_client.stream = MagicMock(side_effect=httpx.HTTPError("Fail")) @@ -103,8 +112,11 @@ async def test_llm_json_error(mock_httpx_client, mock_settings): agent.send = AsyncMock() with patch.object(agent.logger, "error") as log: + prompt = LLMPromptMessage(text="Hi", norms=[], goals=[]) msg = InternalMessage( - to="llm", sender=mock_settings.agent_settings.bdi_core_name, body="Hi" + to="llm", + sender=mock_settings.agent_settings.bdi_core_name, + body=prompt.model_dump_json(), ) await agent.handle_message(msg) log.assert_called() # Should log JSONDecodeError @@ -112,10 +124,10 @@ async def test_llm_json_error(mock_httpx_client, mock_settings): def test_llm_instructions(): # Full custom - instr = LLMInstructions(norms="N", goals="G") + instr = LLMInstructions(norms=["N1", "N2"], goals=["G1", "G2"]) text = instr.build_developer_instruction() - assert "Norms to follow:\nN" in text - assert "Goals to reach:\nG" in text + assert "Norms to follow:\n- N1\n- N2" in text + assert "Goals to reach:\n- G1\n- G2" in text # Defaults instr_def = LLMInstructions() diff --git a/test/unit/api/v1/endpoints/test_program_endpoint.py b/test/unit/api/v1/endpoints/test_program_endpoint.py index f6bb261..178159c 100644 --- a/test/unit/api/v1/endpoints/test_program_endpoint.py +++ b/test/unit/api/v1/endpoints/test_program_endpoint.py @@ -29,22 +29,22 @@ def make_valid_program_dict(): "phases": [ { "id": "phase1", - "name": "basephase", - "nextPhaseId": "phase2", - "phaseData": { - "norms": [{"id": "n1", "name": "norm", "value": "be nice"}], - "goals": [ - {"id": "g1", "name": "goal", "description": "test goal", "achieved": False} - ], - "triggers": [ - { - "id": "t1", - "label": "trigger", - "type": "keyword", - "value": ["stop", "exit"], - } - ], - }, + "label": "basephase", + "norms": [{"id": "n1", "label": "norm", "norm": "be nice"}], + "goals": [ + {"id": "g1", "label": "goal", "description": "test goal", "achieved": False} + ], + "triggers": [ + { + "id": "t1", + "label": "trigger", + "type": "keywords", + "keywords": [ + {"id": "kw1", "keyword": "keyword1"}, + {"id": "kw2", "keyword": "keyword2"}, + ], + }, + ], } ] } diff --git a/test/unit/schemas/test_ui_program_message.py b/test/unit/schemas/test_ui_program_message.py index 36352d6..7ed544e 100644 --- a/test/unit/schemas/test_ui_program_message.py +++ b/test/unit/schemas/test_ui_program_message.py @@ -1,49 +1,52 @@ import pytest from pydantic import ValidationError -from control_backend.schemas.program import Goal, Norm, Phase, PhaseData, Program, Trigger +from control_backend.schemas.program import ( + Goal, + KeywordTrigger, + Norm, + Phase, + Program, + TriggerKeyword, +) def base_norm() -> Norm: return Norm( id="norm1", - name="testNorm", - value="you should act nice", + label="testNorm", + norm="testNormNorm", ) def base_goal() -> Goal: return Goal( id="goal1", - name="testGoal", - description="you should act nice", + label="testGoal", + description="testGoalDescription", achieved=False, ) -def base_trigger() -> Trigger: - return Trigger( +def base_trigger() -> KeywordTrigger: + return KeywordTrigger( id="trigger1", label="testTrigger", - type="keyword", - value=["Stop", "Exit"], - ) - - -def base_phase_data() -> PhaseData: - return PhaseData( - norms=[base_norm()], - goals=[base_goal()], - triggers=[base_trigger()], + type="keywords", + keywords=[ + TriggerKeyword(id="keyword1", keyword="testKeyword1"), + TriggerKeyword(id="keyword1", keyword="testKeyword2"), + ], ) def base_phase() -> Phase: return Phase( id="phase1", - name="basephase", - nextPhaseId="phase2", - phaseData=base_phase_data(), + label="basephase", + norms=[base_norm()], + goals=[base_goal()], + triggers=[base_trigger()], ) @@ -65,7 +68,7 @@ def test_valid_program(): program = base_program() validated = Program.model_validate(program) assert isinstance(validated, Program) - assert validated.phases[0].phaseData.norms[0].name == "testNorm" + assert validated.phases[0].norms[0].norm == "testNormNorm" def test_valid_deepprogram(): @@ -73,10 +76,9 @@ def test_valid_deepprogram(): validated = Program.model_validate(program) # validate nested components directly phase = validated.phases[0] - assert isinstance(phase.phaseData, PhaseData) - assert isinstance(phase.phaseData.goals[0], Goal) - assert isinstance(phase.phaseData.triggers[0], Trigger) - assert isinstance(phase.phaseData.norms[0], Norm) + assert isinstance(phase.goals[0], Goal) + assert isinstance(phase.triggers[0], KeywordTrigger) + assert isinstance(phase.norms[0], Norm) def test_invalid_program():