feat: pydantic models and inter-process messaging
Moved `InternalMessage` into schemas and created a `BeliefMessage` model. Also added the ability for agents to communicate via ZMQ to agents on another process. ref: N25B-316
This commit is contained in:
@@ -1,16 +1,17 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import copy
|
import copy
|
||||||
import json
|
|
||||||
import time
|
import time
|
||||||
from collections.abc import Iterable
|
from collections.abc import Iterable
|
||||||
|
|
||||||
import agentspeak
|
import agentspeak
|
||||||
import agentspeak.runtime
|
import agentspeak.runtime
|
||||||
import agentspeak.stdlib
|
import agentspeak.stdlib
|
||||||
|
from pydantic import ValidationError
|
||||||
|
|
||||||
from control_backend.agents.base import BaseAgent
|
from control_backend.agents.base import BaseAgent
|
||||||
from control_backend.core.agent_system import InternalMessage
|
from control_backend.core.agent_system import InternalMessage
|
||||||
from control_backend.core.config import settings
|
from control_backend.core.config import settings
|
||||||
|
from control_backend.schemas.belief_message import BeliefMessage
|
||||||
from control_backend.schemas.ri_message import SpeechCommand
|
from control_backend.schemas.ri_message import SpeechCommand
|
||||||
|
|
||||||
|
|
||||||
@@ -58,16 +59,19 @@ class BDICoreAgent(BaseAgent):
|
|||||||
maybe_more_work = True
|
maybe_more_work = True
|
||||||
while maybe_more_work:
|
while maybe_more_work:
|
||||||
maybe_more_work = False
|
maybe_more_work = False
|
||||||
|
self.logger.debug("Stepping BDI.")
|
||||||
if self.bdi_agent.step():
|
if self.bdi_agent.step():
|
||||||
maybe_more_work = True
|
maybe_more_work = True
|
||||||
|
|
||||||
if not maybe_more_work:
|
if not maybe_more_work:
|
||||||
deadline = self.bdi_agent.shortest_deadline()
|
deadline = self.bdi_agent.shortest_deadline()
|
||||||
if deadline:
|
if deadline:
|
||||||
|
self.logger.debug("Sleeping until %s", deadline)
|
||||||
await asyncio.sleep(deadline - time.time())
|
await asyncio.sleep(deadline - time.time())
|
||||||
maybe_more_work = True
|
maybe_more_work = True
|
||||||
else:
|
else:
|
||||||
self._wake_bdi_loop.clear()
|
self._wake_bdi_loop.clear()
|
||||||
|
self.logger.debug("No more deadlines. Halting BDI loop.")
|
||||||
|
|
||||||
async def handle_message(self, msg: InternalMessage):
|
async def handle_message(self, msg: InternalMessage):
|
||||||
"""
|
"""
|
||||||
@@ -80,10 +84,10 @@ class BDICoreAgent(BaseAgent):
|
|||||||
self.logger.debug("Processing message from belief collector.")
|
self.logger.debug("Processing message from belief collector.")
|
||||||
try:
|
try:
|
||||||
if msg.thread == "beliefs":
|
if msg.thread == "beliefs":
|
||||||
beliefs = json.loads(msg.body)
|
beliefs = BeliefMessage.model_validate_json(msg.body).beliefs
|
||||||
self._add_beliefs(beliefs)
|
self._add_beliefs(beliefs)
|
||||||
except Exception as e:
|
except ValidationError:
|
||||||
self.logger.error(f"Error processing belief: {e}")
|
self.logger.exception("Error processing belief.")
|
||||||
case settings.agent_settings.llm_name:
|
case settings.agent_settings.llm_name:
|
||||||
content = msg.body
|
content = msg.body
|
||||||
self.logger.info("Received LLM response: %s", content)
|
self.logger.info("Received LLM response: %s", content)
|
||||||
|
|||||||
@@ -1,3 +1,3 @@
|
|||||||
+user_said(NewMessage) <-
|
+user_said(Message) <-
|
||||||
-user_said(NewMessage);
|
-user_said(Message);
|
||||||
.reply(NewMessage).
|
.reply(Message).
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ import json
|
|||||||
from control_backend.agents.base import BaseAgent
|
from control_backend.agents.base import BaseAgent
|
||||||
from control_backend.core.agent_system import InternalMessage
|
from control_backend.core.agent_system import InternalMessage
|
||||||
from control_backend.core.config import settings
|
from control_backend.core.config import settings
|
||||||
|
from control_backend.schemas.belief_message import BeliefMessage
|
||||||
|
|
||||||
|
|
||||||
class BDIBeliefCollectorAgent(BaseAgent):
|
class BDIBeliefCollectorAgent(BaseAgent):
|
||||||
@@ -80,7 +81,7 @@ class BDIBeliefCollectorAgent(BaseAgent):
|
|||||||
msg = InternalMessage(
|
msg = InternalMessage(
|
||||||
to=settings.agent_settings.bdi_core_name,
|
to=settings.agent_settings.bdi_core_name,
|
||||||
sender=self.name,
|
sender=self.name,
|
||||||
body=json.dumps(beliefs),
|
body=BeliefMessage(beliefs=beliefs).model_dump_json(),
|
||||||
thread="beliefs",
|
thread="beliefs",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -2,24 +2,17 @@ import asyncio
|
|||||||
import logging
|
import logging
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from collections.abc import Coroutine
|
from collections.abc import Coroutine
|
||||||
from dataclasses import dataclass
|
|
||||||
|
import zmq
|
||||||
|
import zmq.asyncio as azmq
|
||||||
|
|
||||||
|
from control_backend.core.config import settings
|
||||||
|
from control_backend.schemas.internal_message import InternalMessage
|
||||||
|
|
||||||
# Central directory to resolve agent names to instances
|
# Central directory to resolve agent names to instances
|
||||||
_agent_directory: dict[str, "BaseAgent"] = {}
|
_agent_directory: dict[str, "BaseAgent"] = {}
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class InternalMessage:
|
|
||||||
"""
|
|
||||||
Represents a message to an agent.
|
|
||||||
"""
|
|
||||||
|
|
||||||
to: str
|
|
||||||
sender: str
|
|
||||||
body: str
|
|
||||||
thread: str | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class AgentDirectory:
|
class AgentDirectory:
|
||||||
"""
|
"""
|
||||||
Helper class to keep track of which agents are registered.
|
Helper class to keep track of which agents are registered.
|
||||||
@@ -67,10 +60,23 @@ class BaseAgent(ABC):
|
|||||||
"""Starts the agent and its loops."""
|
"""Starts the agent and its loops."""
|
||||||
self.logger.info(f"Starting agent {self.name}")
|
self.logger.info(f"Starting agent {self.name}")
|
||||||
self._running = True
|
self._running = True
|
||||||
|
|
||||||
|
context = azmq.Context.instance()
|
||||||
|
|
||||||
|
# Setup the internal publishing socket
|
||||||
|
self._internal_pub_socket = context.socket(zmq.PUB)
|
||||||
|
self._internal_pub_socket.connect(settings.zmq_settings.internal_pub_address)
|
||||||
|
|
||||||
|
# Setup the internal receiving socket
|
||||||
|
self._internal_sub_socket = context.socket(zmq.SUB)
|
||||||
|
self._internal_sub_socket.connect(settings.zmq_settings.internal_sub_address)
|
||||||
|
self._internal_sub_socket.subscribe(f"internal/{self.name}")
|
||||||
|
|
||||||
await self.setup()
|
await self.setup()
|
||||||
|
|
||||||
# Start processing inbox
|
# Start processing inbox and ZMQ messages
|
||||||
await self.add_behavior(self._process_inbox())
|
await self.add_behavior(self._process_inbox())
|
||||||
|
await self.add_behavior(self._receive_internal_zmq_loop())
|
||||||
|
|
||||||
async def stop(self):
|
async def stop(self):
|
||||||
"""Stops the agent."""
|
"""Stops the agent."""
|
||||||
@@ -86,15 +92,38 @@ class BaseAgent(ABC):
|
|||||||
target = AgentDirectory.get(message.to)
|
target = AgentDirectory.get(message.to)
|
||||||
if target:
|
if target:
|
||||||
await target.inbox.put(message)
|
await target.inbox.put(message)
|
||||||
|
self.logger.debug(f"Sent message {message.body} to {message.to} via regular inbox.")
|
||||||
else:
|
else:
|
||||||
self.logger.warning(f"Attempted to send message to unknown agent: {message.to}")
|
# Apparently target agent is on a different process, send via ZMQ
|
||||||
|
topic = f"internal/{message.to}".encode()
|
||||||
|
body = message.model_dump_json().encode()
|
||||||
|
await self._internal_pub_socket.send_multipart([topic, body])
|
||||||
|
self.logger.debug(f"Sent message {message.body} to {message.to} via ZMQ.")
|
||||||
|
|
||||||
async def _process_inbox(self):
|
async def _process_inbox(self):
|
||||||
"""Default loop: equivalent to a CyclicBehaviour receiving messages."""
|
"""Default loop: equivalent to a CyclicBehaviour receiving messages."""
|
||||||
while self._running:
|
while self._running:
|
||||||
msg = await self.inbox.get()
|
msg = await self.inbox.get()
|
||||||
|
self.logger.debug(f"Received message from {msg.sender}.")
|
||||||
await self.handle_message(msg)
|
await self.handle_message(msg)
|
||||||
|
|
||||||
|
async def _receive_internal_zmq_loop(self):
|
||||||
|
"""
|
||||||
|
Listens for internal messages sent from agents on another process via ZMQ
|
||||||
|
and puts them into the normal inbox.
|
||||||
|
"""
|
||||||
|
while self._running:
|
||||||
|
try:
|
||||||
|
_, body = await self._internal_sub_socket.recv_multipart()
|
||||||
|
|
||||||
|
msg = InternalMessage.model_validate_json(body)
|
||||||
|
|
||||||
|
await self.inbox.put(msg)
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
break
|
||||||
|
except Exception:
|
||||||
|
self.logger.exception("Could not process ZMQ message.")
|
||||||
|
|
||||||
async def handle_message(self, msg: InternalMessage):
|
async def handle_message(self, msg: InternalMessage):
|
||||||
"""Override this to handle incoming messages."""
|
"""Override this to handle incoming messages."""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|||||||
5
src/control_backend/schemas/belief_message.py
Normal file
5
src/control_backend/schemas/belief_message.py
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class BeliefMessage(BaseModel):
|
||||||
|
beliefs: dict[str, list[str]]
|
||||||
12
src/control_backend/schemas/internal_message.py
Normal file
12
src/control_backend/schemas/internal_message.py
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class InternalMessage(BaseModel):
|
||||||
|
"""
|
||||||
|
Represents a message to an agent.
|
||||||
|
"""
|
||||||
|
|
||||||
|
to: str
|
||||||
|
sender: str
|
||||||
|
body: str
|
||||||
|
thread: str | None = None
|
||||||
@@ -1,10 +1,13 @@
|
|||||||
|
import json
|
||||||
from unittest.mock import AsyncMock, MagicMock, mock_open, patch
|
from unittest.mock import AsyncMock, MagicMock, mock_open, patch
|
||||||
|
|
||||||
|
import agentspeak
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from control_backend.agents.bdi.bdi_core_agent.bdi_core_agent import BDICoreAgent
|
from control_backend.agents.bdi.bdi_core_agent.bdi_core_agent import BDICoreAgent
|
||||||
from control_backend.core.agent_system import InternalMessage
|
from control_backend.core.agent_system import InternalMessage
|
||||||
from control_backend.core.config import settings
|
from control_backend.core.config import settings
|
||||||
|
from control_backend.schemas.belief_message import BeliefMessage
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@@ -40,23 +43,43 @@ async def test_setup_no_asl(mock_agentspeak_env, agent):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_handle_belief_collector_message(agent):
|
async def test_handle_belief_collector_message(agent, mock_settings):
|
||||||
"""Test that incoming beliefs are added to the BDI agent"""
|
"""Test that incoming beliefs are added to the BDI agent"""
|
||||||
# Simulate message from belief collector
|
|
||||||
import json
|
|
||||||
|
|
||||||
beliefs = {"user_said": ["Hello"]}
|
beliefs = {"user_said": ["Hello"]}
|
||||||
msg = InternalMessage(
|
msg = InternalMessage(
|
||||||
to="bdi_agent",
|
to="bdi_agent",
|
||||||
sender=settings.agent_settings.bdi_belief_collector_name,
|
sender=mock_settings.agent_settings.bdi_belief_collector_name,
|
||||||
body=json.dumps(beliefs),
|
body=BeliefMessage(beliefs=beliefs).model_dump_json(),
|
||||||
thread="beliefs",
|
thread="beliefs",
|
||||||
)
|
)
|
||||||
|
|
||||||
await agent.handle_message(msg)
|
await agent.handle_message(msg)
|
||||||
|
|
||||||
# Expect bdi_agent.call to be triggered to add belief
|
# Expect bdi_agent.call to be triggered to add belief
|
||||||
assert agent.bdi_agent.call.called
|
args = agent.bdi_agent.call.call_args.args
|
||||||
|
assert args[0] == agentspeak.Trigger.addition
|
||||||
|
assert args[1] == agentspeak.GoalType.belief
|
||||||
|
assert args[2] == agentspeak.Literal("user_said", (agentspeak.Literal("Hello"),))
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_incorrect_belief_collector_message(agent, mock_settings):
|
||||||
|
"""Test that incorrect message format triggers an exception."""
|
||||||
|
msg = InternalMessage(
|
||||||
|
to="bdi_agent",
|
||||||
|
sender=mock_settings.agent_settings.bdi_belief_collector_name,
|
||||||
|
body=json.dumps({"bad_format": "bad_format"}),
|
||||||
|
thread="beliefs",
|
||||||
|
)
|
||||||
|
|
||||||
|
await agent.handle_message(msg)
|
||||||
|
|
||||||
|
agent.bdi_agent.call.assert_not_called() # did not set belief
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test():
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
|||||||
@@ -84,4 +84,4 @@ async def test_send_beliefs_to_bdi(agent):
|
|||||||
sent: InternalMessage = agent.send.call_args.args[0]
|
sent: InternalMessage = agent.send.call_args.args[0]
|
||||||
assert sent.to == settings.agent_settings.bdi_core_name
|
assert sent.to == settings.agent_settings.bdi_core_name
|
||||||
assert sent.thread == "beliefs"
|
assert sent.thread == "beliefs"
|
||||||
assert json.loads(sent.body) == beliefs
|
assert json.loads(sent.body)["beliefs"] == beliefs
|
||||||
|
|||||||
@@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
|
from unittest.mock import AsyncMock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
@@ -39,7 +40,7 @@ async def test_agent_lifecycle():
|
|||||||
|
|
||||||
# Wait for task to finish
|
# Wait for task to finish
|
||||||
await asyncio.sleep(0.02)
|
await asyncio.sleep(0.02)
|
||||||
assert len(agent._tasks) == 1 # _process_inbox is still running
|
assert len(agent._tasks) == 2 # message handling tasks are running
|
||||||
|
|
||||||
await agent.stop()
|
await agent.stop()
|
||||||
assert agent._running is False
|
assert agent._running is False
|
||||||
@@ -51,14 +52,15 @@ async def test_agent_lifecycle():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_send_unknown_agent(caplog):
|
async def test_send_unknown_agent():
|
||||||
agent = ConcreteTestAgent("sender")
|
agent = ConcreteTestAgent("sender")
|
||||||
msg = InternalMessage(to="unknown_sender", sender="sender", body="boo")
|
msg = InternalMessage(to="unknown_receiver", sender="sender", body="boo")
|
||||||
|
|
||||||
|
agent._internal_pub_socket = AsyncMock()
|
||||||
|
|
||||||
with caplog.at_level(logging.WARNING):
|
|
||||||
await agent.send(msg)
|
await agent.send(msg)
|
||||||
|
|
||||||
assert "Attempted to send message to unknown agent: unknown_sender" in caplog.text
|
agent._internal_pub_socket.send_multipart.assert_called()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
|||||||
Reference in New Issue
Block a user