275 lines
7.0 KiB
Python
275 lines
7.0 KiB
Python
"""Tests for the BaseAgent base class: lifecycle, message passing, and background task handling."""
|
|
|
|
import asyncio
|
|
import logging
|
|
from unittest.mock import AsyncMock, MagicMock
|
|
|
|
import pytest
|
|
|
|
from control_backend.core.agent_system import AgentDirectory, BaseAgent, InternalMessage
|
|
|
|
|
|
class ConcreteTestAgent(BaseAgent):
    """Minimal agent that records every message and stops itself on a ``"stop"`` body."""

    logger = logging.getLogger("test")

    def __init__(self, name: str):
        super().__init__(name)
        # Messages in the order handle_message() saw them.
        self.received: list[InternalMessage] = []

    async def setup(self):
        """No extra setup is needed for this test agent."""

    async def handle_message(self, msg: InternalMessage):
        """Record ``msg``; a body of ``"stop"`` additionally stops the agent."""
        self.received.append(msg)
        if msg.body != "stop":
            return
        await self.stop()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_agent_lifecycle():
    """Start an agent, run a one-shot background behavior, then stop it and check cleanup."""
    agent = ConcreteTestAgent("lifecycle_agent")
    await agent.start()
    try:
        assert agent._running is True

        async def dummy_task():
            pass

        # A freshly added behavior is tracked while it is pending...
        task = agent.add_behavior(dummy_task())
        assert task in agent._tasks

        # ...and is no longer tracked once it has finished.
        await task
        assert task not in agent._tasks
        assert len(agent._tasks) == 2  # message handling tasks are still running
    finally:
        # Always stop the agent, so that a failed assertion above does not
        # leave its background tasks running into later tests.
        await agent.stop()

    assert agent._running is False

    # Yield to the event loop briefly so the cancelled tasks can finish
    # and be removed from the tracked set.
    await asyncio.sleep(0.01)

    # Tasks should be cancelled
    assert len(agent._tasks) == 0
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_send_unknown_agent():
    """A message for an agent that is not registered locally is published over ZMQ."""
    agent = ConcreteTestAgent("sender")
    msg = InternalMessage(to="unknown_receiver", sender="sender", body="boo")

    agent._internal_pub_socket = AsyncMock()

    await agent.send(msg)

    publish = agent._internal_pub_socket.send_multipart
    # assert_called() would also pass for a coroutine that was created but never
    # awaited, so check that the publish was actually awaited.
    publish.assert_awaited_once()
    frames = publish.await_args.args[0]
    assert frames[0] == b"internal/unknown_receiver"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_get_agent():
    """Constructing an agent makes it retrievable from AgentDirectory by name."""
    agent = ConcreteTestAgent("registrant")

    # Identity, not equality: the directory must hand back this very instance.
    assert AgentDirectory.get("registrant") is agent
    assert AgentDirectory.get("non_existent") is None
|
|
|
|
|
|
class DummyAgent(BaseAgent):
    """Bare-bones agent that only remembers the most recent message it handled."""

    async def setup(self):
        """Intentionally empty; setup behavior is tested separately."""

    async def handle_message(self, msg: InternalMessage):
        """Store ``msg`` on ``self.last_handled``."""
        self.last_handled = msg
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_base_agent_setup_is_noop():
    """BaseAgent.setup() returns None without error when called on a real agent."""
    agent = DummyAgent("dummy")

    # DummyAgent overrides setup(), so agent.setup() would only exercise the
    # override. Call the base-class implementation explicitly instead.
    assert await BaseAgent.setup(agent) is None
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_send_to_local_agent(monkeypatch):
    """Sending to a locally registered agent puts the message straight into its inbox."""
    src = DummyAgent("sender")
    dst = DummyAgent("receiver")

    src.logger = MagicMock()  # keep the sender's logging out of the way
    dst.inbox.put = AsyncMock()  # capture the local delivery

    outgoing = InternalMessage(to=dst.name, sender=src.name, body="hello")
    await src.send(outgoing)

    dst.inbox.put.assert_awaited_once_with(outgoing)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_send_to_zmq_agent(monkeypatch):
    """A message for a non-local agent is published with an ``internal/<name>`` topic."""
    sender = DummyAgent("sender")
    target = "remote_receiver"

    # Fake logger
    sender.logger = MagicMock()

    # Fake zmq
    sender._internal_pub_socket = AsyncMock()

    message = InternalMessage(to=target, sender=sender.name, body="hello")

    await sender.send(message)

    publish = sender._internal_pub_socket.send_multipart
    # Check that the publish happened before reading its arguments. Otherwise a
    # missing call surfaces as an opaque TypeError from indexing None.
    publish.assert_awaited_once()
    frames = publish.await_args.args[0]
    assert frames[0] == f"internal/{target}".encode()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_send_to_multiple_local_agents(monkeypatch):
    """A message addressed to several local agents reaches every one of their inboxes."""
    src = DummyAgent("sender")
    receivers = [DummyAgent("receiver1"), DummyAgent("receiver2")]

    src.logger = MagicMock()
    for receiver in receivers:
        receiver.inbox.put = AsyncMock()

    outgoing = InternalMessage(
        to=[receiver.name for receiver in receivers], sender=src.name, body="hello"
    )
    await src.send(outgoing)

    for receiver in receivers:
        receiver.inbox.put.assert_awaited_once_with(outgoing)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_send_to_multiple_agents(monkeypatch):
    """A mixed recipient list delivers locally to inboxes and publishes for remote names."""
    sender = DummyAgent("sender")
    target1 = DummyAgent("receiver1")
    target2 = "remote_receiver"

    # Fake logger
    sender.logger = MagicMock()

    # Fake zmq
    sender._internal_pub_socket = AsyncMock()

    # Patch inbox.put
    target1.inbox.put = AsyncMock()

    message = InternalMessage(to=[target1.name, target2], sender=sender.name, body="hello")

    await sender.send(message)

    target1.inbox.put.assert_awaited_once_with(message)

    publish = sender._internal_pub_socket.send_multipart
    # Fail with a clear message if nothing was published, rather than with a
    # TypeError from indexing a None call_args.
    publish.assert_awaited()
    frames = publish.call_args[0][0]
    assert frames[0] == f"internal/{target2}".encode()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_process_inbox_calls_handle_message(monkeypatch):
    """_process_inbox passes each dequeued message to handle_message."""
    agent = DummyAgent("dummy")
    agent.logger = MagicMock()
    agent._running = True  # the loop only runs while this flag is set

    incoming = InternalMessage(to="dummy", sender="x", body="test")

    async def _deliver_then_halt():
        # Clear the flag so the loop exits after handling this single message.
        agent._running = False
        return incoming

    agent.inbox.get = AsyncMock(side_effect=_deliver_then_halt)
    agent.handle_message = AsyncMock()

    await agent._process_inbox()

    agent.handle_message.assert_awaited_once_with(incoming)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_receive_internal_zmq_loop_success(monkeypatch):
    """A serialized InternalMessage received on the internal socket is forwarded to the inbox."""
    agent = DummyAgent("dummy")
    agent.logger = MagicMock()
    agent._running = True

    payload = InternalMessage(to="dummy", sender="x", body="hi").model_dump_json().encode()

    mock_socket = MagicMock()
    mock_socket.recv_multipart = AsyncMock(
        side_effect=[
            (b"topic", payload),
            asyncio.CancelledError(),  # stop loop
        ]
    )
    agent._internal_sub_socket = mock_socket

    agent.inbox.put = AsyncMock()

    await agent._receive_internal_zmq_loop()

    # Exactly one frame was received, so exactly one message must be forwarded.
    agent.inbox.put.assert_awaited_once()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_receive_internal_zmq_loop_exception_logs_error():
    """An error while receiving is logged and the loop keeps going until cancelled."""
    agent = DummyAgent("dummy")
    agent.logger = MagicMock()
    agent._running = True

    mock_socket = MagicMock()
    mock_socket.recv_multipart = AsyncMock(
        side_effect=[Exception("boom"), asyncio.CancelledError()]
    )
    agent._internal_sub_socket = mock_socket

    agent.inbox.put = AsyncMock()

    await agent._receive_internal_zmq_loop()

    agent.logger.exception.assert_called_once()
    assert "Could not process ZMQ message." in agent.logger.exception.call_args[0][0]
    # Nothing valid was received, so nothing may reach the inbox.
    agent.inbox.put.assert_not_awaited()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_base_agent_handle_message_not_implemented():
    """Calling BaseAgent.handle_message directly raises NotImplementedError."""

    class RawAgent(BaseAgent):
        """Agent that supplies setup() but inherits handle_message() unchanged."""

        async def setup(self):
            pass

    raw = RawAgent("raw")
    incoming = InternalMessage(to="raw", sender="x", body="hi")

    base_handler = BaseAgent.handle_message
    with pytest.raises(NotImplementedError):
        await base_handler(raw, incoming)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_base_agent_setup_abstract_method_body_executes():
    """
    Execute the body of BaseAgent.setup() directly.

    BaseAgent itself is never instantiated here. Instead, the coroutine function
    is looked up on the class and invoked with a throwaway stand-in for ``self``.
    """

    class _SelfStandIn:
        """Empty object passed in place of ``self``."""

    outcome = await BaseAgent.setup(_SelfStandIn())

    # The base implementation's body is just ``pass``, so awaiting it yields None.
    assert outcome is None
|