feat: add program manager
ref: N25B-299
This commit is contained in:
@@ -11,7 +11,7 @@ from pydantic import ValidationError
|
|||||||
from control_backend.agents.base import BaseAgent
|
from control_backend.agents.base import BaseAgent
|
||||||
from control_backend.core.agent_system import InternalMessage
|
from control_backend.core.agent_system import InternalMessage
|
||||||
from control_backend.core.config import settings
|
from control_backend.core.config import settings
|
||||||
from control_backend.schemas.belief_message import BeliefMessage
|
from control_backend.schemas.belief_message import Belief, BeliefMessage
|
||||||
from control_backend.schemas.ri_message import SpeechCommand
|
from control_backend.schemas.ri_message import SpeechCommand
|
||||||
|
|
||||||
|
|
||||||
@@ -77,17 +77,18 @@ class BDICoreAgent(BaseAgent):
|
|||||||
"""
|
"""
|
||||||
Route incoming messages (Beliefs or LLM responses).
|
Route incoming messages (Beliefs or LLM responses).
|
||||||
"""
|
"""
|
||||||
sender = msg.sender
|
self.logger.debug("Processing message from %s.", msg.sender)
|
||||||
|
|
||||||
match sender:
|
if msg.thread == "beliefs":
|
||||||
case settings.agent_settings.bdi_belief_collector_name:
|
try:
|
||||||
self.logger.debug("Processing message from belief collector.")
|
beliefs = BeliefMessage.model_validate_json(msg.body).beliefs
|
||||||
try:
|
self._apply_beliefs(beliefs)
|
||||||
if msg.thread == "beliefs":
|
except ValidationError:
|
||||||
beliefs = BeliefMessage.model_validate_json(msg.body).beliefs
|
self.logger.exception("Error processing belief.")
|
||||||
self._add_beliefs(beliefs)
|
return
|
||||||
except ValidationError:
|
|
||||||
self.logger.exception("Error processing belief.")
|
# The message was not a belief, handle special cases based on sender
|
||||||
|
match msg.sender:
|
||||||
case settings.agent_settings.llm_name:
|
case settings.agent_settings.llm_name:
|
||||||
content = msg.body
|
content = msg.body
|
||||||
self.logger.info("Received LLM response: %s", content)
|
self.logger.info("Received LLM response: %s", content)
|
||||||
@@ -101,12 +102,14 @@ class BDICoreAgent(BaseAgent):
|
|||||||
)
|
)
|
||||||
await self.send(out_msg)
|
await self.send(out_msg)
|
||||||
|
|
||||||
def _add_beliefs(self, beliefs: dict[str, list[str]]):
|
def _apply_beliefs(self, beliefs: list[Belief]):
|
||||||
if not beliefs:
|
if not beliefs:
|
||||||
return
|
return
|
||||||
|
|
||||||
for name, args in beliefs.items():
|
for belief in beliefs:
|
||||||
self._add_belief(name, args)
|
if belief.replace:
|
||||||
|
self._remove_all_with_name(belief.name)
|
||||||
|
self._add_belief(belief.name, belief.arguments)
|
||||||
|
|
||||||
def _add_belief(self, name: str, args: Iterable[str] = []):
|
def _add_belief(self, name: str, args: Iterable[str] = []):
|
||||||
new_args = (agentspeak.Literal(arg) for arg in args)
|
new_args = (agentspeak.Literal(arg) for arg in args)
|
||||||
@@ -143,7 +146,6 @@ class BDICoreAgent(BaseAgent):
|
|||||||
else:
|
else:
|
||||||
self.logger.debug("Failed to remove belief (it was not in the belief base).")
|
self.logger.debug("Failed to remove belief (it was not in the belief base).")
|
||||||
|
|
||||||
# TODO: decide if this is needed
|
|
||||||
def _remove_all_with_name(self, name: str):
|
def _remove_all_with_name(self, name: str):
|
||||||
"""
|
"""
|
||||||
Removes all beliefs that match the given `name`.
|
Removes all beliefs that match the given `name`.
|
||||||
@@ -155,7 +157,8 @@ class BDICoreAgent(BaseAgent):
|
|||||||
|
|
||||||
removed_count = 0
|
removed_count = 0
|
||||||
for group in relevant_groups:
|
for group in relevant_groups:
|
||||||
for belief in self.bdi_agent.beliefs[group]:
|
beliefs_to_remove = list(self.bdi_agent.beliefs[group])
|
||||||
|
for belief in beliefs_to_remove:
|
||||||
self.bdi_agent.call(
|
self.bdi_agent.call(
|
||||||
agentspeak.Trigger.removal,
|
agentspeak.Trigger.removal,
|
||||||
agentspeak.GoalType.belief,
|
agentspeak.GoalType.belief,
|
||||||
|
|||||||
67
src/control_backend/agents/bdi/bdi_program_manager.py
Normal file
67
src/control_backend/agents/bdi/bdi_program_manager.py
Normal file
@@ -0,0 +1,67 @@
|
|||||||
|
import zmq
|
||||||
|
from pydantic import ValidationError
|
||||||
|
from zmq.asyncio import Context
|
||||||
|
|
||||||
|
from control_backend.agents import BaseAgent
|
||||||
|
from control_backend.core.agent_system import InternalMessage
|
||||||
|
from control_backend.core.config import settings
|
||||||
|
from control_backend.schemas.belief_message import Belief, BeliefMessage
|
||||||
|
from control_backend.schemas.program import Program
|
||||||
|
|
||||||
|
|
||||||
|
class BDIProgramManager(BaseAgent):
|
||||||
|
"""
|
||||||
|
Will interpret programs received from the HTTP endpoint. Extracts norms, goals, triggers and
|
||||||
|
forwards them to the BDI as beliefs.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
self.sub_socket = None
|
||||||
|
|
||||||
|
async def _send_to_bdi(self, program: Program):
|
||||||
|
first_phase = program.phases[0]
|
||||||
|
norms_belief = Belief(
|
||||||
|
name="norms",
|
||||||
|
arguments=[norm.norm for norm in first_phase.norms],
|
||||||
|
replace=True,
|
||||||
|
)
|
||||||
|
goals_belief = Belief(
|
||||||
|
name="goals",
|
||||||
|
arguments=[goal.description for goal in first_phase.goals],
|
||||||
|
replace=True,
|
||||||
|
)
|
||||||
|
program_beliefs = BeliefMessage(beliefs=[norms_belief, goals_belief])
|
||||||
|
|
||||||
|
message = InternalMessage(
|
||||||
|
to=settings.agent_settings.bdi_core_name,
|
||||||
|
sender=self.name,
|
||||||
|
body=program_beliefs.model_dump_json(),
|
||||||
|
thread="beliefs",
|
||||||
|
)
|
||||||
|
await self.send(message)
|
||||||
|
self.logger.debug("Sent new norms and goals to the BDI agent.")
|
||||||
|
|
||||||
|
async def _receive_programs(self):
|
||||||
|
"""
|
||||||
|
Continuously receive programs from the HTTP endpoint, sent to us over ZMQ.
|
||||||
|
"""
|
||||||
|
while True:
|
||||||
|
topic, body = await self.sub_socket.recv_multipart()
|
||||||
|
|
||||||
|
try:
|
||||||
|
program = Program.model_validate_json(body)
|
||||||
|
except ValidationError as e:
|
||||||
|
self.logger.error("Received an invalid program.", exc_info=e)
|
||||||
|
continue
|
||||||
|
|
||||||
|
await self._send_to_bdi(program)
|
||||||
|
|
||||||
|
async def setup(self):
|
||||||
|
context = Context.instance()
|
||||||
|
|
||||||
|
self.sub_socket = context.socket(zmq.SUB)
|
||||||
|
self.sub_socket.connect(settings.zmq_settings.internal_sub_address)
|
||||||
|
self.sub_socket.subscribe("program")
|
||||||
|
|
||||||
|
await self.add_behavior(self._receive_programs())
|
||||||
@@ -1,9 +1,11 @@
|
|||||||
import json
|
import json
|
||||||
|
|
||||||
|
from pydantic import ValidationError
|
||||||
|
|
||||||
from control_backend.agents.base import BaseAgent
|
from control_backend.agents.base import BaseAgent
|
||||||
from control_backend.core.agent_system import InternalMessage
|
from control_backend.core.agent_system import InternalMessage
|
||||||
from control_backend.core.config import settings
|
from control_backend.core.config import settings
|
||||||
from control_backend.schemas.belief_message import BeliefMessage
|
from control_backend.schemas.belief_message import Belief, BeliefMessage
|
||||||
|
|
||||||
|
|
||||||
class BDIBeliefCollectorAgent(BaseAgent):
|
class BDIBeliefCollectorAgent(BaseAgent):
|
||||||
@@ -60,10 +62,30 @@ class BDIBeliefCollectorAgent(BaseAgent):
|
|||||||
self.logger.debug("Received empty beliefs set.")
|
self.logger.debug("Received empty beliefs set.")
|
||||||
return
|
return
|
||||||
|
|
||||||
|
def try_create_belief(name, arguments) -> Belief | None:
|
||||||
|
"""
|
||||||
|
Create a belief object from name and arguments, or return None silently if the input is
|
||||||
|
not correct.
|
||||||
|
|
||||||
|
:param name: The name of the belief.
|
||||||
|
:param arguments: The arguments of the belief.
|
||||||
|
:return: A Belief object if the input is valid or None.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
return Belief(name=name, arguments=arguments)
|
||||||
|
except ValidationError:
|
||||||
|
return None
|
||||||
|
|
||||||
|
beliefs = [
|
||||||
|
belief
|
||||||
|
for name, arguments in beliefs.items()
|
||||||
|
if (belief := try_create_belief(name, arguments)) is not None
|
||||||
|
]
|
||||||
|
|
||||||
self.logger.debug("Forwarding %d beliefs.", len(beliefs))
|
self.logger.debug("Forwarding %d beliefs.", len(beliefs))
|
||||||
for belief_name, belief_list in beliefs.items():
|
for belief in beliefs:
|
||||||
for belief in belief_list:
|
for argument in belief.arguments:
|
||||||
self.logger.debug(" - %s %s", belief_name, str(belief))
|
self.logger.debug(" - %s %s", belief.name, argument)
|
||||||
|
|
||||||
await self._send_beliefs_to_bdi(beliefs, origin=origin)
|
await self._send_beliefs_to_bdi(beliefs, origin=origin)
|
||||||
|
|
||||||
@@ -71,7 +93,7 @@ class BDIBeliefCollectorAgent(BaseAgent):
|
|||||||
"""TODO: implement (after we have emotional recognition)"""
|
"""TODO: implement (after we have emotional recognition)"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def _send_beliefs_to_bdi(self, beliefs: dict, origin: str | None = None):
|
async def _send_beliefs_to_bdi(self, beliefs: list[Belief], origin: str | None = None):
|
||||||
"""
|
"""
|
||||||
Sends a unified belief packet to the BDI agent.
|
Sends a unified belief packet to the BDI agent.
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ class AgentSettings(BaseModel):
|
|||||||
# agent names
|
# agent names
|
||||||
bdi_core_name: str = "bdi_core_agent"
|
bdi_core_name: str = "bdi_core_agent"
|
||||||
bdi_belief_collector_name: str = "belief_collector_agent"
|
bdi_belief_collector_name: str = "belief_collector_agent"
|
||||||
|
bdi_program_manager_name: str = "bdi_program_manager_agent"
|
||||||
text_belief_extractor_name: str = "text_belief_extractor_agent"
|
text_belief_extractor_name: str = "text_belief_extractor_agent"
|
||||||
vad_name: str = "vad_agent"
|
vad_name: str = "vad_agent"
|
||||||
llm_name: str = "llm_agent"
|
llm_name: str = "llm_agent"
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ from control_backend.agents.bdi import (
|
|||||||
BDICoreAgent,
|
BDICoreAgent,
|
||||||
TextBeliefExtractorAgent,
|
TextBeliefExtractorAgent,
|
||||||
)
|
)
|
||||||
|
from control_backend.agents.bdi.bdi_program_manager import BDIProgramManager
|
||||||
|
|
||||||
# Communication agents
|
# Communication agents
|
||||||
from control_backend.agents.communication import RICommunicationAgent
|
from control_backend.agents.communication import RICommunicationAgent
|
||||||
@@ -112,6 +113,12 @@ async def lifespan(app: FastAPI):
|
|||||||
VADAgent,
|
VADAgent,
|
||||||
{"audio_in_address": settings.zmq_settings.vad_agent_address, "audio_in_bind": False},
|
{"audio_in_address": settings.zmq_settings.vad_agent_address, "audio_in_bind": False},
|
||||||
),
|
),
|
||||||
|
"ProgramManagerAgent": (
|
||||||
|
BDIProgramManager,
|
||||||
|
{
|
||||||
|
"name": settings.agent_settings.bdi_program_manager_name,
|
||||||
|
},
|
||||||
|
),
|
||||||
}
|
}
|
||||||
|
|
||||||
agents = []
|
agents = []
|
||||||
|
|||||||
@@ -1,5 +1,11 @@
|
|||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class Belief(BaseModel):
|
||||||
|
name: str
|
||||||
|
arguments: list[str]
|
||||||
|
replace: bool = False
|
||||||
|
|
||||||
|
|
||||||
class BeliefMessage(BaseModel):
|
class BeliefMessage(BaseModel):
|
||||||
beliefs: dict[str, list[str]]
|
beliefs: list[Belief]
|
||||||
|
|||||||
@@ -3,36 +3,36 @@ from pydantic import BaseModel
|
|||||||
|
|
||||||
class Norm(BaseModel):
|
class Norm(BaseModel):
|
||||||
id: str
|
id: str
|
||||||
name: str
|
label: str
|
||||||
value: str
|
norm: str
|
||||||
|
|
||||||
|
|
||||||
class Goal(BaseModel):
|
class Goal(BaseModel):
|
||||||
id: str
|
id: str
|
||||||
name: str
|
label: str
|
||||||
description: str
|
description: str
|
||||||
achieved: bool
|
achieved: bool
|
||||||
|
|
||||||
|
|
||||||
|
class KeywordTrigger(BaseModel):
|
||||||
|
id: str
|
||||||
|
keyword: str
|
||||||
|
|
||||||
|
|
||||||
class Trigger(BaseModel):
|
class Trigger(BaseModel):
|
||||||
id: str
|
id: str
|
||||||
label: str
|
label: str
|
||||||
type: str
|
type: str
|
||||||
value: list[str]
|
keywords: list[KeywordTrigger]
|
||||||
|
|
||||||
|
|
||||||
class PhaseData(BaseModel):
|
class Phase(BaseModel):
|
||||||
|
id: str
|
||||||
|
label: str
|
||||||
norms: list[Norm]
|
norms: list[Norm]
|
||||||
goals: list[Goal]
|
goals: list[Goal]
|
||||||
triggers: list[Trigger]
|
triggers: list[Trigger]
|
||||||
|
|
||||||
|
|
||||||
class Phase(BaseModel):
|
|
||||||
id: str
|
|
||||||
name: str
|
|
||||||
nextPhaseId: str
|
|
||||||
phaseData: PhaseData
|
|
||||||
|
|
||||||
|
|
||||||
class Program(BaseModel):
|
class Program(BaseModel):
|
||||||
phases: list[Phase]
|
phases: list[Phase]
|
||||||
|
|||||||
Reference in New Issue
Block a user