refactor: testing
Redid testing structure, added tests and changed some tests. ref: N25B-301
This commit is contained in:
158
test/unit/agents/actuation/test_robot_speech_agent.py
Normal file
158
test/unit/agents/actuation/test_robot_speech_agent.py
Normal file
@@ -0,0 +1,158 @@
|
||||
import json
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
import zmq
|
||||
|
||||
from control_backend.agents.actuation.robot_speech_agent import RobotSpeechAgent
|
||||
from control_backend.core.agent_system import InternalMessage
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def zmq_context(mocker):
|
||||
mock_context = mocker.patch(
|
||||
"control_backend.agents.actuation.robot_speech_agent.azmq.Context.instance"
|
||||
)
|
||||
mock_context.return_value = MagicMock()
|
||||
return mock_context
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_setup_bind(zmq_context, mocker):
|
||||
"""Setup binds and subscribes to internal commands."""
|
||||
fake_socket = zmq_context.return_value.socket.return_value
|
||||
agent = RobotSpeechAgent("robot_speech", address="tcp://localhost:5555", bind=True)
|
||||
settings = mocker.patch("control_backend.agents.actuation.robot_speech_agent.settings")
|
||||
settings.zmq_settings.internal_sub_address = "tcp://internal:1234"
|
||||
|
||||
# Swallow background task coroutines to avoid un-awaited warnings
|
||||
class Swallow:
|
||||
def __init__(self):
|
||||
self.calls = 0
|
||||
|
||||
async def __call__(self, coro):
|
||||
self.calls += 1
|
||||
coro.close()
|
||||
|
||||
swallow = Swallow()
|
||||
agent.add_background_task = swallow
|
||||
|
||||
await agent.setup()
|
||||
|
||||
fake_socket.bind.assert_any_call("tcp://localhost:5555")
|
||||
fake_socket.connect.assert_any_call("tcp://internal:1234")
|
||||
fake_socket.setsockopt.assert_any_call(zmq.SUBSCRIBE, b"command")
|
||||
assert swallow.calls == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_setup_connect(zmq_context, mocker):
|
||||
"""Setup connects when bind=False."""
|
||||
fake_socket = zmq_context.return_value.socket.return_value
|
||||
agent = RobotSpeechAgent("robot_speech", address="tcp://localhost:5555", bind=False)
|
||||
settings = mocker.patch("control_backend.agents.actuation.robot_speech_agent.settings")
|
||||
settings.zmq_settings.internal_sub_address = "tcp://internal:1234"
|
||||
|
||||
class Swallow:
|
||||
def __init__(self):
|
||||
self.calls = 0
|
||||
|
||||
async def __call__(self, coro):
|
||||
self.calls += 1
|
||||
coro.close()
|
||||
|
||||
swallow = Swallow()
|
||||
agent.add_background_task = swallow
|
||||
|
||||
await agent.setup()
|
||||
|
||||
fake_socket.connect.assert_any_call("tcp://localhost:5555")
|
||||
fake_socket.connect.assert_any_call("tcp://internal:1234")
|
||||
assert swallow.calls == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_message_sends_command():
|
||||
"""Internal message is forwarded to robot pub socket as JSON."""
|
||||
pubsocket = AsyncMock()
|
||||
agent = RobotSpeechAgent("robot_speech")
|
||||
agent.pubsocket = pubsocket
|
||||
|
||||
payload = {"endpoint": "actuate/speech", "data": "hello"}
|
||||
msg = InternalMessage(to="robot", sender="tester", body=json.dumps(payload))
|
||||
|
||||
await agent.handle_message(msg)
|
||||
|
||||
pubsocket.send_json.assert_awaited_once_with(payload)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_zmq_command_loop_valid_payload(zmq_context):
|
||||
"""UI command is read from SUB and published."""
|
||||
command = {"endpoint": "actuate/speech", "data": "hello"}
|
||||
fake_socket = AsyncMock()
|
||||
|
||||
async def recv_once():
|
||||
# stop after first iteration
|
||||
agent._running = False
|
||||
return (b"command", json.dumps(command).encode("utf-8"))
|
||||
|
||||
fake_socket.recv_multipart = recv_once
|
||||
fake_socket.send_json = AsyncMock()
|
||||
agent = RobotSpeechAgent("robot_speech")
|
||||
agent.subsocket = fake_socket
|
||||
agent.pubsocket = fake_socket
|
||||
agent._running = True
|
||||
|
||||
await agent._zmq_command_loop()
|
||||
|
||||
fake_socket.send_json.assert_awaited_once_with(command)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_zmq_command_loop_invalid_json():
|
||||
"""Invalid JSON is ignored without sending."""
|
||||
fake_socket = AsyncMock()
|
||||
|
||||
async def recv_once():
|
||||
agent._running = False
|
||||
return (b"command", b"{not_json}")
|
||||
|
||||
fake_socket.recv_multipart = recv_once
|
||||
fake_socket.send_json = AsyncMock()
|
||||
agent = RobotSpeechAgent("robot_speech")
|
||||
agent.subsocket = fake_socket
|
||||
agent.pubsocket = fake_socket
|
||||
agent._running = True
|
||||
|
||||
await agent._zmq_command_loop()
|
||||
|
||||
fake_socket.send_json.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_message_invalid_payload():
|
||||
"""Invalid payload is caught and does not send."""
|
||||
pubsocket = AsyncMock()
|
||||
agent = RobotSpeechAgent("robot_speech")
|
||||
agent.pubsocket = pubsocket
|
||||
|
||||
msg = InternalMessage(to="robot", sender="tester", body=json.dumps({"bad": "data"}))
|
||||
|
||||
await agent.handle_message(msg)
|
||||
|
||||
pubsocket.send_json.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stop_closes_sockets():
|
||||
pubsocket = MagicMock()
|
||||
subsocket = MagicMock()
|
||||
agent = RobotSpeechAgent("robot_speech")
|
||||
agent.pubsocket = pubsocket
|
||||
agent.subsocket = subsocket
|
||||
|
||||
await agent.stop()
|
||||
|
||||
pubsocket.close.assert_called_once()
|
||||
subsocket.close.assert_called_once()
|
||||
@@ -1,211 +0,0 @@
|
||||
import json
|
||||
import logging
|
||||
from unittest.mock import AsyncMock, MagicMock, call
|
||||
|
||||
import pytest
|
||||
from control_backend.agents.bdi.bdi_core_agent.behaviours.belief_setter_behaviour import (
|
||||
BeliefSetterBehaviour,
|
||||
)
|
||||
|
||||
# Define a constant for the collector agent name to use in tests
|
||||
COLLECTOR_AGENT_NAME = "belief_collector_agent"
|
||||
COLLECTOR_AGENT_JID = f"{COLLECTOR_AGENT_NAME}@test"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_agent(mocker):
|
||||
"""Fixture to create a mock BDIAgent."""
|
||||
agent = MagicMock()
|
||||
agent.bdi = MagicMock()
|
||||
agent.jid = "bdi_agent@test"
|
||||
return agent
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def belief_setter_behaviour(mock_agent, mocker):
|
||||
"""Fixture to create an instance of BeliefSetterBehaviour with a mocked agent."""
|
||||
# Patch the settings to use a predictable agent name
|
||||
mocker.patch(
|
||||
"control_backend.agents.bdi.bdi_core_agent."
|
||||
"behaviours.belief_setter_behaviour.settings.agent_settings.bdi_belief_collector_name",
|
||||
COLLECTOR_AGENT_NAME,
|
||||
)
|
||||
|
||||
setter = BeliefSetterBehaviour()
|
||||
setter.agent = mock_agent
|
||||
# Mock the receive method, we will control its return value in each test
|
||||
setter.receive = AsyncMock()
|
||||
return setter
|
||||
|
||||
|
||||
def create_mock_message(sender_node: str, body: str, thread: str) -> MagicMock:
|
||||
"""Helper function to create a configured mock message."""
|
||||
msg = MagicMock()
|
||||
msg.sender.node = sender_node # MagicMock automatically creates nested mocks
|
||||
msg.body = body
|
||||
msg.thread = thread
|
||||
return msg
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_message_received(belief_setter_behaviour, mocker):
|
||||
"""
|
||||
Test that when a message is received, _process_message is called.
|
||||
"""
|
||||
# Arrange
|
||||
msg = MagicMock()
|
||||
belief_setter_behaviour.receive.return_value = msg
|
||||
mocker.patch.object(belief_setter_behaviour, "_process_message")
|
||||
|
||||
# Act
|
||||
await belief_setter_behaviour.run()
|
||||
|
||||
# Assert
|
||||
belief_setter_behaviour._process_message.assert_called_once_with(msg)
|
||||
|
||||
|
||||
def test_process_message_from_bdi_belief_collector_agent(belief_setter_behaviour, mocker):
|
||||
"""
|
||||
Test processing a message from the correct belief collector agent.
|
||||
"""
|
||||
# Arrange
|
||||
msg = create_mock_message(sender_node=COLLECTOR_AGENT_NAME, body="", thread="")
|
||||
mock_process_belief = mocker.patch.object(belief_setter_behaviour, "_process_belief_message")
|
||||
|
||||
# Act
|
||||
belief_setter_behaviour._process_message(msg)
|
||||
|
||||
# Assert
|
||||
mock_process_belief.assert_called_once_with(msg)
|
||||
|
||||
|
||||
def test_process_message_from_other_agent(belief_setter_behaviour, mocker):
|
||||
"""
|
||||
Test that messages from other agents are ignored.
|
||||
"""
|
||||
# Arrange
|
||||
msg = create_mock_message(sender_node="other_agent", body="", thread="")
|
||||
mock_process_belief = mocker.patch.object(belief_setter_behaviour, "_process_belief_message")
|
||||
|
||||
# Act
|
||||
belief_setter_behaviour._process_message(msg)
|
||||
|
||||
# Assert
|
||||
mock_process_belief.assert_not_called()
|
||||
|
||||
|
||||
def test_process_belief_message_valid_json(belief_setter_behaviour, mocker):
|
||||
"""
|
||||
Test processing a valid belief message with correct thread and JSON body.
|
||||
"""
|
||||
# Arrange
|
||||
beliefs_payload = {"is_hot": ["kitchen"], "is_clean": ["kitchen", "bathroom"]}
|
||||
msg = create_mock_message(
|
||||
sender_node=COLLECTOR_AGENT_JID, body=json.dumps(beliefs_payload), thread="beliefs"
|
||||
)
|
||||
mock_set_beliefs = mocker.patch.object(belief_setter_behaviour, "_set_beliefs")
|
||||
|
||||
# Act
|
||||
belief_setter_behaviour._process_belief_message(msg)
|
||||
|
||||
# Assert
|
||||
mock_set_beliefs.assert_called_once_with(beliefs_payload)
|
||||
|
||||
|
||||
def test_process_belief_message_invalid_json(belief_setter_behaviour, mocker, caplog):
|
||||
"""
|
||||
Test that a message with invalid JSON is handled gracefully and an error is logged.
|
||||
"""
|
||||
# Arrange
|
||||
msg = create_mock_message(
|
||||
sender_node=COLLECTOR_AGENT_JID, body="this is not a json string", thread="beliefs"
|
||||
)
|
||||
mock_set_beliefs = mocker.patch.object(belief_setter_behaviour, "_set_beliefs")
|
||||
|
||||
# Act
|
||||
belief_setter_behaviour._process_belief_message(msg)
|
||||
|
||||
# Assert
|
||||
mock_set_beliefs.assert_not_called()
|
||||
|
||||
|
||||
def test_process_belief_message_wrong_thread(belief_setter_behaviour, mocker):
|
||||
"""
|
||||
Test that a message with an incorrect thread is ignored.
|
||||
"""
|
||||
# Arrange
|
||||
msg = create_mock_message(
|
||||
sender_node=COLLECTOR_AGENT_JID, body='{"some": "data"}', thread="not_beliefs"
|
||||
)
|
||||
mock_set_beliefs = mocker.patch.object(belief_setter_behaviour, "_set_beliefs")
|
||||
|
||||
# Act
|
||||
belief_setter_behaviour._process_belief_message(msg)
|
||||
|
||||
# Assert
|
||||
mock_set_beliefs.assert_not_called()
|
||||
|
||||
|
||||
def test_process_belief_message_empty_body(belief_setter_behaviour, mocker):
|
||||
"""
|
||||
Test that a message with an empty body is ignored.
|
||||
"""
|
||||
# Arrange
|
||||
msg = create_mock_message(sender_node=COLLECTOR_AGENT_JID, body="", thread="beliefs")
|
||||
mock_set_beliefs = mocker.patch.object(belief_setter_behaviour, "_set_beliefs")
|
||||
|
||||
# Act
|
||||
belief_setter_behaviour._process_belief_message(msg)
|
||||
|
||||
# Assert
|
||||
mock_set_beliefs.assert_not_called()
|
||||
|
||||
|
||||
def test_set_beliefs_success(belief_setter_behaviour, mock_agent, caplog):
|
||||
"""
|
||||
Test that beliefs are correctly set on the agent's BDI.
|
||||
"""
|
||||
# Arrange
|
||||
beliefs_to_set = {
|
||||
"is_hot": ["kitchen"],
|
||||
"door_opened": ["front_door", "back_door"],
|
||||
}
|
||||
|
||||
# Act
|
||||
with caplog.at_level(logging.INFO):
|
||||
belief_setter_behaviour._set_beliefs(beliefs_to_set)
|
||||
|
||||
# Assert
|
||||
expected_calls = [
|
||||
call("is_hot", "kitchen"),
|
||||
call("door_opened", "front_door", "back_door"),
|
||||
]
|
||||
mock_agent.bdi.set_belief.assert_has_calls(expected_calls, any_order=True)
|
||||
assert mock_agent.bdi.set_belief.call_count == 2
|
||||
|
||||
|
||||
# def test_responded_unset(belief_setter_behaviour, mock_agent):
|
||||
# # Arrange
|
||||
# new_beliefs = {"user_said": ["message"]}
|
||||
#
|
||||
# # Act
|
||||
# belief_setter_behaviour._set_beliefs(new_beliefs)
|
||||
#
|
||||
# # Assert
|
||||
# mock_agent.bdi.set_belief.assert_has_calls([call("user_said", "message")])
|
||||
# mock_agent.bdi.remove_belief.assert_has_calls([call("responded")])
|
||||
|
||||
# def test_set_beliefs_bdi_not_initialized(belief_setter_behaviour, mock_agent, caplog):
|
||||
# """
|
||||
# Test that a warning is logged if the agent's BDI is not initialized.
|
||||
# """
|
||||
# # Arrange
|
||||
# mock_agent.bdi = None # Simulate BDI not being ready
|
||||
# beliefs_to_set = {"is_hot": ["kitchen"]}
|
||||
#
|
||||
# # Act
|
||||
# with caplog.at_level(logging.WARNING):
|
||||
# belief_setter_behaviour._set_beliefs(beliefs_to_set)
|
||||
#
|
||||
# # Assert
|
||||
# assert "Cannot set beliefs, since agent's BDI is not yet initialized." in caplog.text
|
||||
103
test/unit/agents/bdi/test_bdi_core_agent.py
Normal file
103
test/unit/agents/bdi/test_bdi_core_agent.py
Normal file
@@ -0,0 +1,103 @@
|
||||
from unittest.mock import AsyncMock, MagicMock, mock_open, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from control_backend.agents.bdi.bdi_core_agent.bdi_core_agent import BDICoreAgent
|
||||
from control_backend.core.agent_system import InternalMessage
|
||||
from control_backend.core.config import settings
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_agentspeak_env():
|
||||
with patch("agentspeak.runtime.Environment") as mock_env:
|
||||
yield mock_env
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def agent():
|
||||
agent = BDICoreAgent("bdi_agent", "dummy.asl")
|
||||
agent.send = AsyncMock()
|
||||
agent.bdi_agent = MagicMock()
|
||||
return agent
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_setup_loads_asl(mock_agentspeak_env, agent):
|
||||
# Mock file opening
|
||||
with patch("builtins.open", mock_open(read_data="+initial_goal.")):
|
||||
await agent.setup()
|
||||
|
||||
# Check if environment tried to build agent
|
||||
mock_agentspeak_env.return_value.build_agent.assert_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_setup_no_asl(mock_agentspeak_env, agent):
|
||||
with patch("builtins.open", side_effect=FileNotFoundError):
|
||||
await agent.setup()
|
||||
|
||||
mock_agentspeak_env.return_value.build_agent.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_belief_collector_message(agent):
|
||||
"""Test that incoming beliefs are added to the BDI agent"""
|
||||
# Simulate message from belief collector
|
||||
import json
|
||||
|
||||
beliefs = {"user_said": ["Hello"]}
|
||||
msg = InternalMessage(
|
||||
to="bdi_agent",
|
||||
sender=settings.agent_settings.bdi_belief_collector_name,
|
||||
body=json.dumps(beliefs),
|
||||
thread="beliefs",
|
||||
)
|
||||
|
||||
await agent.handle_message(msg)
|
||||
|
||||
# Expect bdi_agent.call to be triggered to add belief
|
||||
assert agent.bdi_agent.call.called
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_llm_response(agent):
|
||||
"""Test that LLM responses are forwarded to the Robot Speech Agent"""
|
||||
msg = InternalMessage(
|
||||
to="bdi_agent", sender=settings.agent_settings.llm_name, body="This is the LLM reply"
|
||||
)
|
||||
|
||||
await agent.handle_message(msg)
|
||||
|
||||
# Verify forward
|
||||
assert agent.send.called
|
||||
sent_msg = agent.send.call_args[0][0]
|
||||
assert sent_msg.to == settings.agent_settings.robot_speech_name
|
||||
assert "This is the LLM reply" in sent_msg.body
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_custom_actions(agent):
|
||||
agent._send_to_llm = MagicMock(side_effect=agent.send) # Mock specific method
|
||||
|
||||
# Initialize actions manually since we didn't call setup with real file
|
||||
agent._add_custom_actions()
|
||||
|
||||
# Find the action
|
||||
action_fn = None
|
||||
for (functor, _), fn in agent.actions.actions.items():
|
||||
if functor == ".reply":
|
||||
action_fn = fn
|
||||
break
|
||||
|
||||
assert action_fn is not None
|
||||
|
||||
# Invoke action
|
||||
mock_term = MagicMock()
|
||||
mock_term.args = ["Hello"]
|
||||
mock_intention = MagicMock()
|
||||
|
||||
# Run generator
|
||||
gen = action_fn(agent, mock_term, mock_intention)
|
||||
next(gen) # Execute
|
||||
|
||||
agent._send_to_llm.assert_called_with("Hello")
|
||||
@@ -7,16 +7,6 @@ from control_backend.agents.bdi.text_belief_extractor_agent.text_belief_extracto
|
||||
TextBeliefExtractorAgent,
|
||||
)
|
||||
from control_backend.core.agent_system import InternalMessage
|
||||
from control_backend.core.config import settings
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def patch_settings(monkeypatch):
|
||||
monkeypatch.setattr(settings.agent_settings, "transcription_name", "transcriber", raising=False)
|
||||
monkeypatch.setattr(
|
||||
settings.agent_settings, "bdi_belief_collector_name", "collector", raising=False
|
||||
)
|
||||
monkeypatch.setattr(settings.agent_settings, "host", "fake.host", raising=False)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -40,29 +30,29 @@ async def test_handle_message_ignores_other_agents(agent):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_message_from_transcriber(agent):
|
||||
async def test_handle_message_from_transcriber(agent, mock_settings):
|
||||
transcription = "hello world"
|
||||
msg = make_msg(settings.agent_settings.transcription_name, transcription, None)
|
||||
msg = make_msg(mock_settings.agent_settings.transcription_name, transcription, None)
|
||||
|
||||
await agent.handle_message(msg)
|
||||
|
||||
agent.send.assert_awaited_once() # noqa # `agent.send` has no such property, but we mock it.
|
||||
sent: InternalMessage = agent.send.call_args.args[0] # noqa
|
||||
assert sent.to == settings.agent_settings.bdi_belief_collector_name
|
||||
assert sent.to == mock_settings.agent_settings.bdi_belief_collector_name
|
||||
assert sent.thread == "beliefs"
|
||||
parsed = json.loads(sent.body)
|
||||
assert parsed == {"beliefs": {"user_said": [transcription]}, "type": "belief_extraction_text"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_transcription_demo(agent):
|
||||
async def test_process_transcription_demo(agent, mock_settings):
|
||||
transcription = "this is a test"
|
||||
|
||||
await agent._process_transcription_demo(transcription)
|
||||
|
||||
agent.send.assert_awaited_once() # noqa # `agent.send` has no such property, but we mock it.
|
||||
sent: InternalMessage = agent.send.call_args.args[0] # noqa
|
||||
assert sent.to == settings.agent_settings.bdi_belief_collector_name
|
||||
assert sent.to == mock_settings.agent_settings.bdi_belief_collector_name
|
||||
assert sent.thread == "beliefs"
|
||||
parsed = json.loads(sent.body)
|
||||
assert parsed["beliefs"]["user_said"] == [transcription]
|
||||
354
test/unit/agents/communication/test_ri_communication_agent.py
Normal file
354
test/unit/agents/communication/test_ri_communication_agent.py
Normal file
@@ -0,0 +1,354 @@
|
||||
import asyncio
|
||||
from unittest.mock import ANY, AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from control_backend.agents.communication.ri_communication_agent import RICommunicationAgent
|
||||
|
||||
|
||||
def speech_agent_path():
|
||||
return "control_backend.agents.communication.ri_communication_agent.RobotSpeechAgent"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def zmq_context(mocker):
|
||||
mock_context = mocker.patch(
|
||||
"control_backend.agents.communication.ri_communication_agent.Context.instance"
|
||||
)
|
||||
mock_context.return_value = MagicMock()
|
||||
return mock_context
|
||||
|
||||
|
||||
def negotiation_message(
|
||||
actuation_port: int = 5556,
|
||||
bind_main: bool = False,
|
||||
bind_actuation: bool = True,
|
||||
main_port: int = 5555,
|
||||
):
|
||||
return {
|
||||
"endpoint": "negotiate/ports",
|
||||
"data": [
|
||||
{"id": "main", "port": main_port, "bind": bind_main},
|
||||
{"id": "actuation", "port": actuation_port, "bind": bind_actuation},
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_setup_success_connects_and_starts_robot(zmq_context):
|
||||
fake_socket = zmq_context.return_value.socket.return_value
|
||||
fake_socket.send_json = AsyncMock()
|
||||
fake_socket.recv_json = AsyncMock(return_value=negotiation_message())
|
||||
fake_socket.send_multipart = AsyncMock()
|
||||
|
||||
with patch(speech_agent_path(), autospec=True) as MockRobot:
|
||||
robot_instance = MockRobot.return_value
|
||||
robot_instance.start = AsyncMock()
|
||||
agent = RICommunicationAgent("ri_comm", address="tcp://localhost:5555", bind=False)
|
||||
|
||||
class Swallow:
|
||||
def __init__(self):
|
||||
self.calls = 0
|
||||
|
||||
async def __call__(self, coro):
|
||||
self.calls += 1
|
||||
coro.close()
|
||||
|
||||
swallow = Swallow()
|
||||
agent.add_background_task = swallow
|
||||
|
||||
await agent.setup()
|
||||
|
||||
fake_socket.connect.assert_any_call("tcp://localhost:5555")
|
||||
fake_socket.send_json.assert_any_call({"endpoint": "negotiate/ports", "data": {}})
|
||||
robot_instance.start.assert_awaited_once()
|
||||
MockRobot.assert_called_once_with(ANY, address="tcp://*:5556", bind=True)
|
||||
assert swallow.calls == 1
|
||||
assert agent.connected is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_setup_binds_when_requested(zmq_context):
|
||||
fake_socket = zmq_context.return_value.socket.return_value
|
||||
fake_socket.send_json = AsyncMock()
|
||||
fake_socket.recv_json = AsyncMock(return_value=negotiation_message(bind_main=True))
|
||||
fake_socket.send_multipart = AsyncMock()
|
||||
|
||||
agent = RICommunicationAgent("ri_comm", address="tcp://localhost:5555", bind=True)
|
||||
|
||||
class Swallow:
|
||||
def __init__(self):
|
||||
self.calls = 0
|
||||
|
||||
async def __call__(self, coro):
|
||||
self.calls += 1
|
||||
coro.close()
|
||||
|
||||
swallow = Swallow()
|
||||
agent.add_background_task = swallow
|
||||
|
||||
with patch(speech_agent_path(), autospec=True) as MockRobot:
|
||||
MockRobot.return_value.start = AsyncMock()
|
||||
await agent.setup()
|
||||
|
||||
fake_socket.bind.assert_any_call("tcp://localhost:5555")
|
||||
assert swallow.calls == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_negotiate_invalid_endpoint_retries(zmq_context):
|
||||
fake_socket = zmq_context.return_value.socket.return_value
|
||||
fake_socket.send_json = AsyncMock()
|
||||
fake_socket.recv_json = AsyncMock(return_value={"endpoint": "ping", "data": {}})
|
||||
fake_socket.send_multipart = AsyncMock()
|
||||
|
||||
agent = RICommunicationAgent("ri_comm", address="tcp://localhost:5555", bind=False)
|
||||
agent._req_socket = fake_socket
|
||||
|
||||
success = await agent._negotiate_connection(max_retries=1)
|
||||
|
||||
assert success is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_negotiate_timeout(zmq_context):
|
||||
fake_socket = zmq_context.return_value.socket.return_value
|
||||
fake_socket.send_json = AsyncMock()
|
||||
fake_socket.recv_json = AsyncMock(side_effect=asyncio.TimeoutError)
|
||||
fake_socket.send_multipart = AsyncMock()
|
||||
|
||||
agent = RICommunicationAgent("ri_comm", address="tcp://localhost:5555", bind=False)
|
||||
agent._req_socket = fake_socket
|
||||
|
||||
success = await agent._negotiate_connection(max_retries=1)
|
||||
|
||||
assert success is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_negotiation_response_updates_req_socket(zmq_context):
|
||||
fake_socket = zmq_context.return_value.socket.return_value
|
||||
agent = RICommunicationAgent("ri_comm", address="tcp://localhost:5555", bind=False)
|
||||
agent._req_socket = fake_socket
|
||||
with patch(speech_agent_path(), autospec=True) as MockRobot:
|
||||
MockRobot.return_value.start = AsyncMock()
|
||||
await agent._handle_negotiation_response(
|
||||
negotiation_message(
|
||||
main_port=6000,
|
||||
actuation_port=6001,
|
||||
bind_main=False,
|
||||
bind_actuation=False,
|
||||
)
|
||||
)
|
||||
|
||||
fake_socket.connect.assert_any_call("tcp://localhost:6000")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_disconnection_publishes_and_reconnects():
|
||||
pub_socket = AsyncMock()
|
||||
agent = RICommunicationAgent("ri_comm")
|
||||
agent.pub_socket = pub_socket
|
||||
agent.connected = True
|
||||
agent._negotiate_connection = AsyncMock(return_value=True)
|
||||
|
||||
await agent._handle_disconnection()
|
||||
|
||||
pub_socket.send_multipart.assert_awaited()
|
||||
assert agent.connected is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_listen_loop_handles_non_ping(zmq_context):
|
||||
fake_socket = zmq_context.return_value.socket.return_value
|
||||
fake_socket.send_json = AsyncMock()
|
||||
|
||||
async def recv_once():
|
||||
agent._running = False
|
||||
return {"endpoint": "negotiate/ports", "data": {}}
|
||||
|
||||
fake_socket.recv_json = recv_once
|
||||
agent = RICommunicationAgent("ri_comm")
|
||||
agent._req_socket = fake_socket
|
||||
agent.pub_socket = AsyncMock()
|
||||
agent.connected = True
|
||||
agent._running = True
|
||||
|
||||
await agent._listen_loop()
|
||||
|
||||
fake_socket.send_json.assert_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_negotiate_unexpected_error(zmq_context):
|
||||
fake_socket = zmq_context.return_value.socket.return_value
|
||||
fake_socket.send_json = AsyncMock()
|
||||
fake_socket.recv_json = AsyncMock(side_effect=Exception("boom"))
|
||||
agent = RICommunicationAgent("ri_comm")
|
||||
agent._req_socket = fake_socket
|
||||
|
||||
assert await agent._negotiate_connection(max_retries=1) is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_negotiate_handle_response_error(zmq_context):
|
||||
fake_socket = zmq_context.return_value.socket.return_value
|
||||
fake_socket.send_json = AsyncMock()
|
||||
fake_socket.recv_json = AsyncMock(return_value=negotiation_message())
|
||||
|
||||
agent = RICommunicationAgent("ri_comm")
|
||||
agent._req_socket = fake_socket
|
||||
agent._handle_negotiation_response = AsyncMock(side_effect=Exception("bad response"))
|
||||
|
||||
assert await agent._negotiate_connection(max_retries=1) is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_setup_warns_on_failed_negotiate(zmq_context, mocker):
|
||||
fake_socket = zmq_context.return_value.socket.return_value
|
||||
fake_socket.send_json = AsyncMock()
|
||||
fake_socket.recv_json = AsyncMock()
|
||||
agent = RICommunicationAgent("ri_comm")
|
||||
|
||||
async def swallow(coro):
|
||||
coro.close()
|
||||
|
||||
agent.add_background_task = swallow
|
||||
agent._negotiate_connection = AsyncMock(return_value=False)
|
||||
|
||||
await agent.setup()
|
||||
|
||||
assert agent.connected is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_negotiation_response_unhandled_id():
|
||||
agent = RICommunicationAgent("ri_comm")
|
||||
|
||||
await agent._handle_negotiation_response(
|
||||
{"data": [{"id": "other", "port": 5000, "bind": False}]}
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stop_closes_sockets():
|
||||
req = MagicMock()
|
||||
pub = MagicMock()
|
||||
agent = RICommunicationAgent("ri_comm")
|
||||
agent._req_socket = req
|
||||
agent.pub_socket = pub
|
||||
|
||||
await agent.stop()
|
||||
|
||||
req.close.assert_called_once()
|
||||
pub.close.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_listen_loop_not_connected(monkeypatch):
|
||||
agent = RICommunicationAgent("ri_comm")
|
||||
agent._running = True
|
||||
agent.connected = False
|
||||
agent._req_socket = AsyncMock()
|
||||
|
||||
async def fake_sleep(duration):
|
||||
agent._running = False
|
||||
|
||||
monkeypatch.setattr("asyncio.sleep", fake_sleep)
|
||||
|
||||
await agent._listen_loop()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_listen_loop_send_and_recv_timeout():
|
||||
req = AsyncMock()
|
||||
req.send_json = AsyncMock(side_effect=TimeoutError)
|
||||
req.recv_json = AsyncMock(side_effect=TimeoutError)
|
||||
|
||||
agent = RICommunicationAgent("ri_comm")
|
||||
agent._req_socket = req
|
||||
agent.pub_socket = AsyncMock()
|
||||
agent.connected = True
|
||||
agent._running = True
|
||||
|
||||
async def stop_run():
|
||||
agent._running = False
|
||||
|
||||
agent._handle_disconnection = AsyncMock(side_effect=stop_run)
|
||||
|
||||
await agent._listen_loop()
|
||||
|
||||
agent._handle_disconnection.assert_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_listen_loop_missing_endpoint(monkeypatch):
|
||||
req = AsyncMock()
|
||||
req.send_json = AsyncMock()
|
||||
|
||||
async def recv_once():
|
||||
agent._running = False
|
||||
return {"data": {}}
|
||||
|
||||
req.recv_json = recv_once
|
||||
|
||||
agent = RICommunicationAgent("ri_comm")
|
||||
agent._req_socket = req
|
||||
agent.pub_socket = AsyncMock()
|
||||
agent.connected = True
|
||||
agent._running = True
|
||||
|
||||
await agent._listen_loop()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_listen_loop_generic_exception():
|
||||
req = AsyncMock()
|
||||
req.send_json = AsyncMock()
|
||||
req.recv_json = AsyncMock(side_effect=ValueError("boom"))
|
||||
|
||||
agent = RICommunicationAgent("ri_comm")
|
||||
agent._req_socket = req
|
||||
agent.pub_socket = AsyncMock()
|
||||
agent.connected = True
|
||||
agent._running = True
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
await agent._listen_loop()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_disconnection_timeout(monkeypatch):
|
||||
pub = AsyncMock()
|
||||
pub.send_multipart = AsyncMock(side_effect=TimeoutError)
|
||||
|
||||
agent = RICommunicationAgent("ri_comm")
|
||||
agent.pub_socket = pub
|
||||
agent._negotiate_connection = AsyncMock(return_value=False)
|
||||
|
||||
await agent._handle_disconnection()
|
||||
|
||||
pub.send_multipart.assert_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_listen_loop_ping_sends_internal(zmq_context):
|
||||
fake_socket = zmq_context.return_value.socket.return_value
|
||||
fake_socket.send_json = AsyncMock()
|
||||
pub_socket = AsyncMock()
|
||||
|
||||
agent = RICommunicationAgent("ri_comm")
|
||||
agent._req_socket = fake_socket
|
||||
agent.pub_socket = pub_socket
|
||||
agent.connected = True
|
||||
agent._running = True
|
||||
|
||||
async def recv_once():
|
||||
agent._running = False
|
||||
return {"endpoint": "ping", "data": {}}
|
||||
|
||||
fake_socket.recv_json = recv_once
|
||||
|
||||
await agent._listen_loop()
|
||||
|
||||
pub_socket.send_multipart.assert_awaited()
|
||||
124
test/unit/agents/llm/test_llm_agent.py
Normal file
124
test/unit/agents/llm/test_llm_agent.py
Normal file
@@ -0,0 +1,124 @@
|
||||
"""Mocks `httpx` and tests chunking logic."""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from control_backend.agents.llm.llm_agent import LLMAgent, LLMInstructions
|
||||
from control_backend.core.agent_system import InternalMessage
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_httpx_client():
|
||||
with patch("httpx.AsyncClient") as mock_cls:
|
||||
mock_client = AsyncMock()
|
||||
mock_cls.return_value.__aenter__.return_value = mock_client
|
||||
yield mock_client
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_llm_processing_success(mock_httpx_client, mock_settings):
|
||||
# Setup the mock response for the stream
|
||||
mock_response = MagicMock()
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
|
||||
# Simulate stream lines
|
||||
lines = [
|
||||
b'data: {"choices": [{"delta": {"content": "Hello"}}]}',
|
||||
b'data: {"choices": [{"delta": {"content": " world"}}]}',
|
||||
b'data: {"choices": [{"delta": {"content": "."}}]}',
|
||||
b"data: [DONE]",
|
||||
]
|
||||
|
||||
async def aiter_lines_gen():
|
||||
for line in lines:
|
||||
yield line.decode()
|
||||
|
||||
mock_response.aiter_lines.side_effect = aiter_lines_gen
|
||||
|
||||
mock_stream_context = MagicMock()
|
||||
mock_stream_context.__aenter__ = AsyncMock(return_value=mock_response)
|
||||
mock_stream_context.__aexit__ = AsyncMock(return_value=None)
|
||||
|
||||
# Configure the client
|
||||
mock_httpx_client.stream = MagicMock(return_value=mock_stream_context)
|
||||
|
||||
# Setup Agent
|
||||
agent = LLMAgent("llm_agent")
|
||||
agent.send = AsyncMock() # Mock the send method to verify replies
|
||||
|
||||
# Simulate receiving a message from BDI
|
||||
msg = InternalMessage(
|
||||
to="llm_agent", sender=mock_settings.agent_settings.bdi_core_name, body="Hi"
|
||||
)
|
||||
|
||||
await agent.handle_message(msg)
|
||||
|
||||
# Verification
|
||||
# "Hello world." constitutes one sentence/chunk based on punctuation split
|
||||
# The agent should call send once with the full sentence
|
||||
assert agent.send.called
|
||||
args = agent.send.call_args[0][0]
|
||||
assert args.to == mock_settings.agent_settings.bdi_core_name
|
||||
assert "Hello world." in args.body
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_llm_processing_errors(mock_httpx_client, mock_settings):
|
||||
agent = LLMAgent("llm_agent")
|
||||
agent.send = AsyncMock()
|
||||
msg = InternalMessage(to="llm", sender=mock_settings.agent_settings.bdi_core_name, body="Hi")
|
||||
|
||||
# HTTP Error
|
||||
mock_httpx_client.stream = MagicMock(side_effect=httpx.HTTPError("Fail"))
|
||||
await agent.handle_message(msg)
|
||||
assert "LLM service unavailable." in agent.send.call_args[0][0].body
|
||||
|
||||
# General Exception
|
||||
agent.send.reset_mock()
|
||||
mock_httpx_client.stream = MagicMock(side_effect=Exception("Boom"))
|
||||
await agent.handle_message(msg)
|
||||
assert "Error processing the request." in agent.send.call_args[0][0].body
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_llm_json_error(mock_httpx_client, mock_settings):
|
||||
# Test malformed JSON in stream
|
||||
mock_response = MagicMock()
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
|
||||
async def aiter_lines_gen():
|
||||
yield "data: {bad_json"
|
||||
yield "data: [DONE]"
|
||||
|
||||
mock_response.aiter_lines.side_effect = aiter_lines_gen
|
||||
|
||||
mock_stream_context = MagicMock()
|
||||
mock_stream_context.__aenter__ = AsyncMock(return_value=mock_response)
|
||||
mock_stream_context.__aexit__ = AsyncMock(return_value=None)
|
||||
mock_httpx_client.stream = MagicMock(return_value=mock_stream_context)
|
||||
|
||||
agent = LLMAgent("llm_agent")
|
||||
agent.send = AsyncMock()
|
||||
|
||||
with patch.object(agent.logger, "error") as log:
|
||||
msg = InternalMessage(
|
||||
to="llm", sender=mock_settings.agent_settings.bdi_core_name, body="Hi"
|
||||
)
|
||||
await agent.handle_message(msg)
|
||||
log.assert_called() # Should log JSONDecodeError
|
||||
|
||||
|
||||
def test_llm_instructions():
|
||||
# Full custom
|
||||
instr = LLMInstructions(norms="N", goals="G")
|
||||
text = instr.build_developer_instruction()
|
||||
assert "Norms to follow:\nN" in text
|
||||
assert "Goals to reach:\nG" in text
|
||||
|
||||
# Defaults
|
||||
instr_def = LLMInstructions()
|
||||
text_def = instr_def.build_developer_instruction()
|
||||
assert "Norms to follow" in text_def
|
||||
assert "Goals to reach" in text_def
|
||||
@@ -0,0 +1,122 @@
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from control_backend.agents.perception.transcription_agent.speech_recognizer import (
|
||||
MLXWhisperSpeechRecognizer,
|
||||
OpenAIWhisperSpeechRecognizer,
|
||||
SpeechRecognizer,
|
||||
)
|
||||
from control_backend.agents.perception.transcription_agent.transcription_agent import (
|
||||
TranscriptionAgent,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_transcription_agent_flow(mock_zmq_context):
|
||||
mock_sub = MagicMock()
|
||||
mock_sub.recv = AsyncMock()
|
||||
|
||||
# Setup context to return this specific mock socket
|
||||
mock_zmq_context.instance.return_value.socket.return_value = mock_sub
|
||||
|
||||
# Data: [Audio Bytes, Cancel Loop]
|
||||
fake_audio = np.zeros(16000, dtype=np.float32).tobytes()
|
||||
mock_sub.recv.side_effect = [fake_audio, asyncio.CancelledError()]
|
||||
|
||||
# Mock Recognizer
|
||||
with patch.object(SpeechRecognizer, "best_type") as mock_best:
|
||||
mock_recognizer = MagicMock()
|
||||
mock_recognizer.recognize_speech.return_value = "Hello"
|
||||
mock_best.return_value = mock_recognizer
|
||||
|
||||
agent = TranscriptionAgent("tcp://in")
|
||||
agent.send = AsyncMock()
|
||||
|
||||
agent._running = True
|
||||
agent.add_background_task = AsyncMock()
|
||||
|
||||
await agent.setup()
|
||||
|
||||
try:
|
||||
await agent._transcribing_loop()
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
# Check transcription happened
|
||||
assert mock_recognizer.recognize_speech.called
|
||||
# Check sending
|
||||
assert agent.send.called
|
||||
assert agent.send.call_args[0][0].body == "Hello"
|
||||
|
||||
await agent.stop()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_transcription_empty(mock_zmq_context):
|
||||
mock_sub = MagicMock()
|
||||
mock_sub.recv = AsyncMock()
|
||||
mock_zmq_context.instance.return_value.socket.return_value = mock_sub
|
||||
|
||||
# Return valid audio, but recognizer returns empty string
|
||||
fake_audio = np.zeros(10, dtype=np.float32).tobytes()
|
||||
mock_sub.recv.side_effect = [fake_audio, asyncio.CancelledError()]
|
||||
|
||||
with patch.object(SpeechRecognizer, "best_type") as mock_best:
|
||||
mock_recognizer = MagicMock()
|
||||
mock_recognizer.recognize_speech.return_value = ""
|
||||
mock_best.return_value = mock_recognizer
|
||||
|
||||
agent = TranscriptionAgent("tcp://in")
|
||||
agent.send = AsyncMock()
|
||||
await agent.setup()
|
||||
|
||||
try:
|
||||
await agent._transcribing_loop()
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
# Should NOT send message
|
||||
agent.send.assert_not_called()
|
||||
|
||||
|
||||
def test_speech_recognizer_factory():
|
||||
# Test Factory Logic
|
||||
with patch("torch.mps.is_available", return_value=True):
|
||||
assert isinstance(SpeechRecognizer.best_type(), MLXWhisperSpeechRecognizer)
|
||||
|
||||
with patch("torch.mps.is_available", return_value=False):
|
||||
assert isinstance(SpeechRecognizer.best_type(), OpenAIWhisperSpeechRecognizer)
|
||||
|
||||
|
||||
def test_openai_recognizer():
|
||||
with patch("whisper.load_model") as load_mock:
|
||||
with patch("whisper.transcribe") as trans_mock:
|
||||
rec = OpenAIWhisperSpeechRecognizer()
|
||||
rec.load_model()
|
||||
load_mock.assert_called()
|
||||
|
||||
trans_mock.return_value = {"text": "Hi"}
|
||||
res = rec.recognize_speech(np.zeros(10))
|
||||
assert res == "Hi"
|
||||
|
||||
|
||||
def test_mlx_recognizer():
|
||||
# Fix: On Linux, 'mlx_whisper' isn't imported by the module, so it's missing from dir().
|
||||
# We must use create=True to inject it into the module namespace during the test.
|
||||
module_path = "control_backend.agents.perception.transcription_agent.speech_recognizer"
|
||||
|
||||
with patch("sys.platform", "darwin"):
|
||||
with patch(f"{module_path}.mlx_whisper", create=True) as mlx_mock:
|
||||
with patch(f"{module_path}.ModelHolder", create=True) as holder_mock:
|
||||
# We also need to mock mlx.core if it's used for types/constants
|
||||
with patch(f"{module_path}.mx", create=True):
|
||||
rec = MLXWhisperSpeechRecognizer()
|
||||
rec.load_model()
|
||||
holder_mock.get_model.assert_called()
|
||||
|
||||
mlx_mock.transcribe.return_value = {"text": "Hi"}
|
||||
res = rec.recognize_speech(np.zeros(10))
|
||||
assert res == "Hi"
|
||||
156
test/unit/api/v1/endpoints/test_robot_endpoint.py
Normal file
156
test/unit/api/v1/endpoints/test_robot_endpoint.py
Normal file
@@ -0,0 +1,156 @@
|
||||
import json
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from control_backend.api.v1.endpoints import robot
|
||||
from control_backend.schemas.ri_message import SpeechCommand
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app():
|
||||
"""
|
||||
Creates a FastAPI test app and attaches the router under test.
|
||||
Also sets up a mock internal_comm_socket.
|
||||
"""
|
||||
app = FastAPI()
|
||||
app.include_router(robot.router)
|
||||
return app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client(app):
|
||||
"""Create a test client for the app."""
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
def test_receive_command_success(client):
|
||||
"""
|
||||
Test for successful reception of a command. Ensures the status code is 202 and the response body
|
||||
is correct. It also verifies that the ZeroMQ socket's send_multipart method is called with the
|
||||
expected data.
|
||||
"""
|
||||
# Arrange
|
||||
mock_pub_socket = AsyncMock()
|
||||
client.app.state.endpoints_pub_socket = mock_pub_socket
|
||||
|
||||
command_data = {"endpoint": "actuate/speech", "data": "This is a test"}
|
||||
speech_command = SpeechCommand(**command_data)
|
||||
|
||||
# Act
|
||||
response = client.post("/command", json=command_data)
|
||||
|
||||
# Assert
|
||||
assert response.status_code == 202
|
||||
assert response.json() == {"status": "Command received"}
|
||||
|
||||
# Verify that the ZMQ socket was used correctly
|
||||
mock_pub_socket.send_multipart.assert_awaited_once_with(
|
||||
[b"command", speech_command.model_dump_json().encode()]
|
||||
)
|
||||
|
||||
|
||||
def test_receive_command_invalid_payload(client):
|
||||
"""
|
||||
Test invalid data handling (schema validation).
|
||||
"""
|
||||
# Missing required field(s)
|
||||
bad_payload = {"invalid": "data"}
|
||||
response = client.post("/command", json=bad_payload)
|
||||
assert response.status_code == 422 # validation error
|
||||
|
||||
|
||||
def test_ping_check_returns_none(client):
|
||||
"""Ensure /ping_check returns 200 and None (currently unimplemented)."""
|
||||
response = client.get("/ping_check")
|
||||
assert response.status_code == 200
|
||||
assert response.json() is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ping_stream_yields_ping_event(monkeypatch):
|
||||
"""Test that ping_stream yields a proper SSE message when a ping is received."""
|
||||
mock_sub_socket = AsyncMock()
|
||||
mock_sub_socket.connect = MagicMock()
|
||||
mock_sub_socket.setsockopt = MagicMock()
|
||||
mock_sub_socket.recv_multipart = AsyncMock(return_value=[b"ping", b"true"])
|
||||
|
||||
mock_context = MagicMock()
|
||||
mock_context.socket.return_value = mock_sub_socket
|
||||
monkeypatch.setattr(robot.Context, "instance", lambda: mock_context)
|
||||
|
||||
mock_request = AsyncMock()
|
||||
mock_request.is_disconnected = AsyncMock(side_effect=[False, True])
|
||||
|
||||
response = await robot.ping_stream(mock_request)
|
||||
generator = aiter(response.body_iterator)
|
||||
|
||||
event = await anext(generator)
|
||||
event_text = event.decode() if isinstance(event, bytes) else str(event)
|
||||
assert event_text.strip() == "data: true"
|
||||
|
||||
with pytest.raises(StopAsyncIteration):
|
||||
await anext(generator)
|
||||
|
||||
mock_sub_socket.connect.assert_called_once()
|
||||
mock_sub_socket.setsockopt.assert_called_once_with(robot.zmq.SUBSCRIBE, b"ping")
|
||||
mock_sub_socket.recv_multipart.assert_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ping_stream_handles_timeout(monkeypatch):
|
||||
"""Test that ping_stream continues looping on TimeoutError."""
|
||||
mock_sub_socket = AsyncMock()
|
||||
mock_sub_socket.connect = MagicMock()
|
||||
mock_sub_socket.setsockopt = MagicMock()
|
||||
mock_sub_socket.recv_multipart.side_effect = TimeoutError()
|
||||
|
||||
mock_context = MagicMock()
|
||||
mock_context.socket.return_value = mock_sub_socket
|
||||
monkeypatch.setattr(robot.Context, "instance", lambda: mock_context)
|
||||
|
||||
mock_request = AsyncMock()
|
||||
mock_request.is_disconnected = AsyncMock(return_value=True)
|
||||
|
||||
response = await robot.ping_stream(mock_request)
|
||||
generator = aiter(response.body_iterator)
|
||||
|
||||
with pytest.raises(StopAsyncIteration):
|
||||
await anext(generator)
|
||||
|
||||
mock_sub_socket.connect.assert_called_once()
|
||||
mock_sub_socket.setsockopt.assert_called_once_with(robot.zmq.SUBSCRIBE, b"ping")
|
||||
mock_sub_socket.recv_multipart.assert_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ping_stream_yields_json_values(monkeypatch):
|
||||
"""Ensure ping_stream correctly parses and yields JSON body values."""
|
||||
mock_sub_socket = AsyncMock()
|
||||
mock_sub_socket.connect = MagicMock()
|
||||
mock_sub_socket.setsockopt = MagicMock()
|
||||
mock_sub_socket.recv_multipart = AsyncMock(
|
||||
return_value=[b"ping", json.dumps({"connected": True}).encode()]
|
||||
)
|
||||
|
||||
mock_context = MagicMock()
|
||||
mock_context.socket.return_value = mock_sub_socket
|
||||
monkeypatch.setattr(robot.Context, "instance", lambda: mock_context)
|
||||
|
||||
mock_request = AsyncMock()
|
||||
mock_request.is_disconnected = AsyncMock(side_effect=[False, True])
|
||||
|
||||
response = await robot.ping_stream(mock_request)
|
||||
generator = aiter(response.body_iterator)
|
||||
|
||||
event = await anext(generator)
|
||||
event_text = event.decode() if isinstance(event, bytes) else str(event)
|
||||
|
||||
assert "connected" in event_text
|
||||
assert "true" in event_text
|
||||
|
||||
mock_sub_socket.connect.assert_called_once()
|
||||
mock_sub_socket.setsockopt.assert_called_once_with(robot.zmq.SUBSCRIBE, b"ping")
|
||||
mock_sub_socket.recv_multipart.assert_awaited()
|
||||
@@ -1,71 +1,43 @@
|
||||
import sys
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from control_backend.core.agent_system import _agent_directory
|
||||
|
||||
|
||||
def pytest_configure(config):
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_agent_directory():
|
||||
"""
|
||||
This hook runs at the start of the pytest session, before any tests are
|
||||
collected. It mocks heavy or unavailable modules to prevent ImportErrors.
|
||||
Automatically clears the global agent directory before and after each test
|
||||
to prevent state leakage between tests.
|
||||
"""
|
||||
# --- Mock spade and spade-bdi ---
|
||||
mock_agentspeak = MagicMock()
|
||||
mock_httpx = MagicMock()
|
||||
mock_pydantic = MagicMock()
|
||||
mock_spade = MagicMock()
|
||||
mock_spade.agent = MagicMock()
|
||||
mock_spade.behaviour = MagicMock()
|
||||
mock_spade.message = MagicMock()
|
||||
mock_spade_bdi = MagicMock()
|
||||
mock_spade_bdi.bdi = MagicMock()
|
||||
_agent_directory.clear()
|
||||
yield
|
||||
_agent_directory.clear()
|
||||
|
||||
mock_spade.agent.Message = MagicMock()
|
||||
mock_spade.behaviour.CyclicBehaviour = type("CyclicBehaviour", (object,), {})
|
||||
mock_spade_bdi.bdi.BDIAgent = type("BDIAgent", (object,), {})
|
||||
|
||||
# Ensure submodule imports like `agentspeak.runtime` succeed
|
||||
mock_agentspeak.runtime = MagicMock()
|
||||
mock_agentspeak.stdlib = MagicMock()
|
||||
sys.modules["agentspeak"] = mock_agentspeak
|
||||
sys.modules["agentspeak.runtime"] = mock_agentspeak.runtime
|
||||
sys.modules["agentspeak.stdlib"] = mock_agentspeak.stdlib
|
||||
sys.modules["httpx"] = mock_httpx
|
||||
sys.modules["pydantic"] = mock_pydantic
|
||||
sys.modules["spade"] = mock_spade
|
||||
sys.modules["spade.agent"] = mock_spade.agent
|
||||
sys.modules["spade.behaviour"] = mock_spade.behaviour
|
||||
sys.modules["spade.message"] = mock_spade.message
|
||||
sys.modules["spade_bdi"] = mock_spade_bdi
|
||||
sys.modules["spade_bdi.bdi"] = mock_spade_bdi.bdi
|
||||
@pytest.fixture
|
||||
def mock_settings():
|
||||
with patch("control_backend.core.config.settings") as mock:
|
||||
# Set default values that match the pydantic model defaults
|
||||
# to avoid AttributeErrors during tests
|
||||
mock.zmq_settings.internal_pub_address = "tcp://localhost:5560"
|
||||
mock.zmq_settings.internal_sub_address = "tcp://localhost:5561"
|
||||
mock.zmq_settings.ri_command_address = "tcp://localhost:0000"
|
||||
mock.agent_settings.bdi_core_name = "bdi_core_agent"
|
||||
mock.agent_settings.bdi_belief_collector_name = "belief_collector_agent"
|
||||
mock.agent_settings.llm_name = "llm_agent"
|
||||
mock.agent_settings.robot_speech_name = "robot_speech_agent"
|
||||
mock.agent_settings.transcription_name = "transcription_agent"
|
||||
mock.agent_settings.text_belief_extractor_name = "text_belief_extractor_agent"
|
||||
mock.agent_settings.vad_name = "vad_agent"
|
||||
mock.behaviour_settings.sleep_s = 0.01 # Speed up tests
|
||||
mock.behaviour_settings.comm_setup_max_retries = 1
|
||||
yield mock
|
||||
|
||||
# --- Mock the config module to prevent Pydantic ImportError ---
|
||||
mock_config_module = MagicMock()
|
||||
|
||||
# The code under test does `from ... import settings`, so our mock module
|
||||
# must have a `settings` attribute. We'll make it a MagicMock so we can
|
||||
# configure it later in our tests using mocker.patch.
|
||||
mock_config_module.settings = MagicMock()
|
||||
|
||||
sys.modules["control_backend.core.config"] = mock_config_module
|
||||
|
||||
# --- Mock torch and zmq for VAD ---
|
||||
mock_torch = MagicMock()
|
||||
mock_zmq = MagicMock()
|
||||
mock_zmq.asyncio = mock_zmq
|
||||
|
||||
# In individual tests, these can be imported and the return values changed
|
||||
sys.modules["torch"] = mock_torch
|
||||
sys.modules["zmq"] = mock_zmq
|
||||
sys.modules["zmq.asyncio"] = mock_zmq.asyncio
|
||||
|
||||
# --- Mock whisper ---
|
||||
mock_whisper = MagicMock()
|
||||
mock_mlx = MagicMock()
|
||||
mock_mlx.core = MagicMock()
|
||||
mock_mlx_whisper = MagicMock()
|
||||
mock_mlx_whisper.transcribe = MagicMock()
|
||||
|
||||
sys.modules["whisper"] = mock_whisper
|
||||
sys.modules["mlx"] = mock_mlx
|
||||
sys.modules["mlx.core"] = mock_mlx
|
||||
sys.modules["mlx_whisper"] = mock_mlx_whisper
|
||||
sys.modules["mlx_whisper.transcribe"] = mock_mlx_whisper.transcribe
|
||||
@pytest.fixture
|
||||
def mock_zmq_context():
|
||||
with patch("zmq.asyncio.Context") as mock:
|
||||
mock.instance.return_value = MagicMock()
|
||||
yield mock
|
||||
|
||||
68
test/unit/core/test_agent_system.py
Normal file
68
test/unit/core/test_agent_system.py
Normal file
@@ -0,0 +1,68 @@
|
||||
"""Test the base class logic, message passing and background task handling."""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
|
||||
import pytest
|
||||
|
||||
from control_backend.core.agent_system import AgentDirectory, BaseAgent, InternalMessage
|
||||
|
||||
|
||||
class ConcreteTestAgent(BaseAgent):
|
||||
logger = logging.getLogger("test")
|
||||
|
||||
def __init__(self, name: str):
|
||||
super().__init__(name)
|
||||
self.received = []
|
||||
|
||||
async def setup(self):
|
||||
pass
|
||||
|
||||
async def handle_message(self, msg: InternalMessage):
|
||||
self.received.append(msg)
|
||||
if msg.body == "stop":
|
||||
await self.stop()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_lifecycle():
|
||||
agent = ConcreteTestAgent("lifecycle_agent")
|
||||
await agent.start()
|
||||
assert agent._running is True
|
||||
|
||||
# Test background task
|
||||
async def dummy_task():
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
await agent.add_background_task(dummy_task())
|
||||
assert len(agent._tasks) > 0
|
||||
|
||||
# Wait for task to finish
|
||||
await asyncio.sleep(0.02)
|
||||
assert len(agent._tasks) == 1 # _process_inbox is still running
|
||||
|
||||
await agent.stop()
|
||||
assert agent._running is False
|
||||
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
# Tasks should be cancelled
|
||||
assert len(agent._tasks) == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_unknown_agent(caplog):
|
||||
agent = ConcreteTestAgent("sender")
|
||||
msg = InternalMessage(to="unknown_sender", sender="sender", body="boo")
|
||||
|
||||
with caplog.at_level(logging.WARNING):
|
||||
await agent.send(msg)
|
||||
|
||||
assert "Attempted to send message to unknown agent: unknown_sender" in caplog.text
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_agent():
|
||||
agent = ConcreteTestAgent("registrant")
|
||||
assert AgentDirectory.get("registrant") == agent
|
||||
assert AgentDirectory.get("non_existent") is None
|
||||
14
test/unit/core/test_config.py
Normal file
14
test/unit/core/test_config.py
Normal file
@@ -0,0 +1,14 @@
|
||||
"""Test if settings load correctly and environment variables override defaults."""
|
||||
|
||||
from control_backend.core.config import Settings
|
||||
|
||||
|
||||
def test_default_settings():
|
||||
settings = Settings()
|
||||
assert settings.app_title == "PepperPlus"
|
||||
|
||||
|
||||
def test_env_override(monkeypatch):
|
||||
monkeypatch.setenv("APP_TITLE", "TestPepper")
|
||||
settings = Settings()
|
||||
assert settings.app_title == "TestPepper"
|
||||
88
test/unit/core/test_logging.py
Normal file
88
test/unit/core/test_logging.py
Normal file
@@ -0,0 +1,88 @@
|
||||
import logging
|
||||
from unittest.mock import mock_open, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from control_backend.logging.setup_logging import add_logging_level, setup_logging
|
||||
|
||||
|
||||
def test_add_logging_level():
|
||||
# Add a unique level to avoid conflicts with other tests/libraries
|
||||
level_name = "TESTLEVEL"
|
||||
level_num = 35
|
||||
|
||||
add_logging_level(level_name, level_num)
|
||||
|
||||
assert logging.getLevelName(level_num) == level_name
|
||||
assert hasattr(logging, level_name)
|
||||
assert hasattr(logging.getLoggerClass(), level_name.lower())
|
||||
|
||||
# Test functionality
|
||||
logger = logging.getLogger("test_custom_level")
|
||||
with patch.object(logger, "_log") as mock_log:
|
||||
getattr(logger, level_name.lower())("message")
|
||||
mock_log.assert_called_with(level_num, "message", ())
|
||||
|
||||
# Test duplicates
|
||||
with pytest.raises(AttributeError):
|
||||
add_logging_level(level_name, level_num)
|
||||
|
||||
with pytest.raises(AttributeError):
|
||||
add_logging_level("INFO", 20) # Existing level
|
||||
|
||||
|
||||
def test_setup_logging_no_file(caplog):
|
||||
with patch("os.path.exists", return_value=False):
|
||||
setup_logging("dummy.yaml")
|
||||
assert "Logging config file not found" in caplog.text
|
||||
|
||||
|
||||
def test_setup_logging_yaml_error(caplog):
|
||||
with patch("os.path.exists", return_value=True):
|
||||
with patch("builtins.open", mock_open(read_data="invalid: [yaml")):
|
||||
with patch("logging.config.dictConfig") as mock_dict_config:
|
||||
setup_logging("config.yaml")
|
||||
|
||||
# Verify we logged the warning
|
||||
assert "Could not load logging configuration" in caplog.text
|
||||
# Verify dictConfig was called with empty dict (which would crash real dictConfig)
|
||||
mock_dict_config.assert_called_with({})
|
||||
assert "Could not load logging configuration" in caplog.text
|
||||
|
||||
|
||||
def test_setup_logging_success():
|
||||
config_data = """
|
||||
version: 1
|
||||
handlers:
|
||||
console:
|
||||
class: logging.StreamHandler
|
||||
root:
|
||||
handlers: [console]
|
||||
level: INFO
|
||||
custom_levels:
|
||||
MYLEVEL: 15
|
||||
"""
|
||||
with patch("os.path.exists", return_value=True):
|
||||
with patch("builtins.open", mock_open(read_data=config_data)):
|
||||
with patch("logging.config.dictConfig") as mock_dict_config:
|
||||
setup_logging("config.yaml")
|
||||
mock_dict_config.assert_called()
|
||||
assert hasattr(logging, "MYLEVEL")
|
||||
|
||||
|
||||
def test_setup_logging_zmq_handler(mock_zmq_context):
|
||||
config_data = """
|
||||
version: 1
|
||||
handlers:
|
||||
ui:
|
||||
class: logging.NullHandler
|
||||
# In real config this would be a zmq handler, but for unit test logic
|
||||
# we just want to see if the socket injection happens
|
||||
"""
|
||||
with patch("os.path.exists", return_value=True):
|
||||
with patch("builtins.open", mock_open(read_data=config_data)):
|
||||
with patch("logging.config.dictConfig") as mock_dict_config:
|
||||
setup_logging("config.yaml")
|
||||
|
||||
args = mock_dict_config.call_args[0][0]
|
||||
assert "interface_or_socket" in args["handlers"]["ui"]
|
||||
26
test/unit/schemas/test_ri_message.py
Normal file
26
test/unit/schemas/test_ri_message.py
Normal file
@@ -0,0 +1,26 @@
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from control_backend.schemas.ri_message import RIEndpoint, RIMessage, SpeechCommand
|
||||
|
||||
|
||||
def valid_command_1():
|
||||
return SpeechCommand(data="Hallo?")
|
||||
|
||||
|
||||
def invalid_command_1():
|
||||
return RIMessage(endpoint=RIEndpoint.PING, data="Hello again.")
|
||||
|
||||
|
||||
def test_valid_speech_command_1():
|
||||
command = valid_command_1()
|
||||
RIMessage.model_validate(command)
|
||||
SpeechCommand.model_validate(command)
|
||||
|
||||
|
||||
def test_invalid_speech_command_1():
|
||||
command = invalid_command_1()
|
||||
RIMessage.model_validate(command)
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
SpeechCommand.model_validate(command)
|
||||
85
test/unit/schemas/test_ui_program_message.py
Normal file
85
test/unit/schemas/test_ui_program_message.py
Normal file
@@ -0,0 +1,85 @@
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from control_backend.schemas.program import Goal, Norm, Phase, PhaseData, Program, Trigger
|
||||
|
||||
|
||||
def base_norm() -> Norm:
|
||||
return Norm(
|
||||
id="norm1",
|
||||
name="testNorm",
|
||||
value="you should act nice",
|
||||
)
|
||||
|
||||
|
||||
def base_goal() -> Goal:
|
||||
return Goal(
|
||||
id="goal1",
|
||||
name="testGoal",
|
||||
description="you should act nice",
|
||||
achieved=False,
|
||||
)
|
||||
|
||||
|
||||
def base_trigger() -> Trigger:
|
||||
return Trigger(
|
||||
id="trigger1",
|
||||
label="testTrigger",
|
||||
type="keyword",
|
||||
value=["Stop", "Exit"],
|
||||
)
|
||||
|
||||
|
||||
def base_phase_data() -> PhaseData:
|
||||
return PhaseData(
|
||||
norms=[base_norm()],
|
||||
goals=[base_goal()],
|
||||
triggers=[base_trigger()],
|
||||
)
|
||||
|
||||
|
||||
def base_phase() -> Phase:
|
||||
return Phase(
|
||||
id="phase1",
|
||||
name="basephase",
|
||||
nextPhaseId="phase2",
|
||||
phaseData=base_phase_data(),
|
||||
)
|
||||
|
||||
|
||||
def base_program() -> Program:
|
||||
return Program(phases=[base_phase()])
|
||||
|
||||
|
||||
def invalid_program() -> dict:
|
||||
# wrong types inside phases list (not Phase objects)
|
||||
return {
|
||||
"phases": [
|
||||
{"id": "phase1"}, # incomplete
|
||||
{"not_a_phase": True},
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
def test_valid_program():
|
||||
program = base_program()
|
||||
validated = Program.model_validate(program)
|
||||
assert isinstance(validated, Program)
|
||||
assert validated.phases[0].phaseData.norms[0].name == "testNorm"
|
||||
|
||||
|
||||
def test_valid_deepprogram():
|
||||
program = base_program()
|
||||
validated = Program.model_validate(program)
|
||||
# validate nested components directly
|
||||
phase = validated.phases[0]
|
||||
assert isinstance(phase.phaseData, PhaseData)
|
||||
assert isinstance(phase.phaseData.goals[0], Goal)
|
||||
assert isinstance(phase.phaseData.triggers[0], Trigger)
|
||||
assert isinstance(phase.phaseData.norms[0], Norm)
|
||||
|
||||
|
||||
def test_invalid_program():
|
||||
bad = invalid_program()
|
||||
with pytest.raises(ValidationError):
|
||||
Program.model_validate(bad)
|
||||
Reference in New Issue
Block a user