Merge remote-tracking branch 'origin/dev' into refactor/zmq-internal-socket-behaviour
# Conflicts: # src/control_backend/agents/ri_command_agent.py # src/control_backend/agents/ri_communication_agent.py # src/control_backend/api/v1/endpoints/command.py # src/control_backend/main.py # test/integration/api/endpoints/test_command_endpoint.py
This commit is contained in:
@@ -1,10 +1,10 @@
|
||||
import asyncio
|
||||
import zmq
|
||||
import json
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
import zmq
|
||||
|
||||
from control_backend.agents.ri_command_agent import RICommandAgent
|
||||
from control_backend.schemas.ri_message import SpeechCommand
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import asyncio
|
||||
from unittest.mock import ANY, AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch, ANY
|
||||
|
||||
from control_backend.agents.ri_communication_agent import RICommunicationAgent
|
||||
|
||||
|
||||
@@ -177,8 +179,8 @@ async def test_setup_creates_socket_and_negotiate_3(zmq_context, caplog):
|
||||
|
||||
# Mock RICommandAgent agent startup
|
||||
|
||||
# We are sending wrong negotiation info to the communication agent, so we should retry and expect a
|
||||
# better response, within a limited time.
|
||||
# We are sending wrong negotiation info to the communication agent,
|
||||
# so we should retry and expect a better response, within a limited time.
|
||||
with patch(
|
||||
"control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True
|
||||
) as MockCommandAgent:
|
||||
@@ -330,8 +332,8 @@ async def test_setup_creates_socket_and_negotiate_7(zmq_context, caplog):
|
||||
|
||||
# Mock RICommandAgent agent startup
|
||||
|
||||
# We are sending wrong negotiation info to the communication agent, so we should retry and expect a
|
||||
# better response, within a limited time.
|
||||
# We are sending wrong negotiation info to the communication agent,
|
||||
# so we should retry and expect a better response, within a limited time.
|
||||
with patch(
|
||||
"control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True
|
||||
) as MockCommandAgent:
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from unittest.mock import AsyncMock, patch
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI
|
||||
@@ -16,6 +16,7 @@ def app():
|
||||
"""
|
||||
app = FastAPI()
|
||||
app.include_router(command.router)
|
||||
app.state.internal_comm_socket = MagicMock() # mock ZMQ socket
|
||||
return app
|
||||
|
||||
|
||||
@@ -25,32 +26,32 @@ def client(app):
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("control_backend.api.v1.endpoints.command.Context.instance")
|
||||
async def test_receive_command_success(mock_context_instance, client):
|
||||
def test_receive_command_endpoint(client, app):
|
||||
"""
|
||||
Test for successful reception of a command.
|
||||
Ensures the status code is 202 and the response body is correct.
|
||||
It also verifies that the ZeroMQ socket's send_multipart method is called with the expected data.
|
||||
Test that a POST to /command sends the right multipart message
|
||||
and returns a 202 with the expected JSON body.
|
||||
"""
|
||||
# Arrange
|
||||
mock_pub_socket = AsyncMock()
|
||||
client.app.state.endpoints_pub_socket = mock_pub_socket
|
||||
mock_socket = app.state.internal_comm_socket
|
||||
|
||||
command_data = {"endpoint": "actuate/speech", "data": "This is a test"}
|
||||
speech_command = SpeechCommand(**command_data)
|
||||
# Prepare test payload that matches SpeechCommand
|
||||
payload = {"endpoint": "actuate/speech", "data": "yooo"}
|
||||
|
||||
# Act
|
||||
response = client.post("/command", json=command_data)
|
||||
# Send POST request
|
||||
response = client.post("/command", json=payload)
|
||||
|
||||
# Assert
|
||||
# Check response
|
||||
assert response.status_code == 202
|
||||
assert response.json() == {"status": "Command received"}
|
||||
|
||||
# Verify that the ZMQ socket was used correctly
|
||||
mock_pub_socket.send_multipart.assert_awaited_once_with(
|
||||
[b"command", speech_command.model_dump_json().encode()]
|
||||
)
|
||||
# Verify that the socket was called with the correct data
|
||||
assert mock_socket.send_multipart.called, "Socket should be used to send data"
|
||||
|
||||
args, kwargs = mock_socket.send_multipart.call_args
|
||||
sent_data = args[0]
|
||||
|
||||
assert sent_data[0] == b"command"
|
||||
# Check JSON encoding roughly matches
|
||||
assert isinstance(SpeechCommand.model_validate_json(sent_data[1].decode()), SpeechCommand)
|
||||
|
||||
|
||||
def test_receive_command_invalid_payload(client):
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
import pytest
|
||||
from control_backend.schemas.ri_message import RIMessage, RIEndpoint, SpeechCommand
|
||||
from pydantic import ValidationError
|
||||
|
||||
from control_backend.schemas.ri_message import RIEndpoint, RIMessage, SpeechCommand
|
||||
|
||||
|
||||
def valid_command_1():
|
||||
return SpeechCommand(data="Hallo?")
|
||||
@@ -13,24 +14,13 @@ def invalid_command_1():
|
||||
|
||||
def test_valid_speech_command_1():
|
||||
command = valid_command_1()
|
||||
try:
|
||||
RIMessage.model_validate(command)
|
||||
SpeechCommand.model_validate(command)
|
||||
assert True
|
||||
except ValidationError:
|
||||
assert False
|
||||
RIMessage.model_validate(command)
|
||||
SpeechCommand.model_validate(command)
|
||||
|
||||
|
||||
def test_invalid_speech_command_1():
|
||||
command = invalid_command_1()
|
||||
passed_ri_message_validation = False
|
||||
try:
|
||||
# Should succeed, still.
|
||||
RIMessage.model_validate(command)
|
||||
passed_ri_message_validation = True
|
||||
RIMessage.model_validate(command)
|
||||
|
||||
# Should fail.
|
||||
with pytest.raises(ValidationError):
|
||||
SpeechCommand.model_validate(command)
|
||||
assert False
|
||||
except ValidationError:
|
||||
assert passed_ri_message_validation
|
||||
|
||||
@@ -203,6 +203,7 @@ def test_set_beliefs_success(belief_setter, mock_agent, caplog):
|
||||
assert "Set belief is_hot with arguments ['kitchen']" in caplog.text
|
||||
assert "Set belief door_opened with arguments ['front_door', 'back_door']" in caplog.text
|
||||
|
||||
|
||||
# def test_responded_unset(belief_setter, mock_agent):
|
||||
# # Arrange
|
||||
# new_beliefs = {"user_said": ["message"]}
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
import json
|
||||
import logging
|
||||
from unittest.mock import MagicMock, AsyncMock, call
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from control_backend.agents.belief_collector.behaviours.continuous_collect import ContinuousBeliefCollector
|
||||
from control_backend.agents.belief_collector.behaviours.continuous_collect import (
|
||||
ContinuousBeliefCollector,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_agent(mocker):
|
||||
@@ -13,18 +15,20 @@ def mock_agent(mocker):
|
||||
agent.jid = "belief_collector_agent@test"
|
||||
return agent
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def continuous_collector(mock_agent, mocker):
|
||||
"""Fixture to create an instance of ContinuousBeliefCollector with a mocked agent."""
|
||||
# Patch asyncio.sleep to prevent tests from actually waiting
|
||||
mocker.patch("asyncio.sleep", return_value=None)
|
||||
|
||||
|
||||
collector = ContinuousBeliefCollector()
|
||||
collector.agent = mock_agent
|
||||
# Mock the receive method, we will control its return value in each test
|
||||
collector.receive = AsyncMock()
|
||||
return collector
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_no_message_received(continuous_collector, mocker):
|
||||
"""
|
||||
@@ -40,6 +44,7 @@ async def test_run_no_message_received(continuous_collector, mocker):
|
||||
# Assert
|
||||
continuous_collector._process_message.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_message_received(continuous_collector, mocker):
|
||||
"""
|
||||
@@ -55,7 +60,8 @@ async def test_run_message_received(continuous_collector, mocker):
|
||||
|
||||
# Assert
|
||||
continuous_collector._process_message.assert_awaited_once_with(mock_msg)
|
||||
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_message_invalid(continuous_collector, mocker):
|
||||
"""
|
||||
@@ -66,15 +72,18 @@ async def test_process_message_invalid(continuous_collector, mocker):
|
||||
msg = MagicMock()
|
||||
msg.body = invalid_json
|
||||
msg.sender = "belief_text_agent_mock@test"
|
||||
|
||||
logger_mock = mocker.patch("control_backend.agents.belief_collector.behaviours.continuous_collect.logger")
|
||||
|
||||
|
||||
logger_mock = mocker.patch(
|
||||
"control_backend.agents.belief_collector.behaviours.continuous_collect.logger"
|
||||
)
|
||||
|
||||
# Act
|
||||
await continuous_collector._process_message(msg)
|
||||
|
||||
# Assert
|
||||
logger_mock.warning.assert_called_once()
|
||||
|
||||
|
||||
def test_get_sender_from_message(continuous_collector):
|
||||
"""
|
||||
Test that _sender_node correctly extracts the sender node from the message JID.
|
||||
@@ -89,6 +98,7 @@ def test_get_sender_from_message(continuous_collector):
|
||||
# Assert
|
||||
assert sender_node == "agent_node"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_routes_to_handle_belief_text_by_type(continuous_collector, mocker):
|
||||
msg = MagicMock()
|
||||
@@ -98,6 +108,7 @@ async def test_routes_to_handle_belief_text_by_type(continuous_collector, mocker
|
||||
await continuous_collector._process_message(msg)
|
||||
spy.assert_awaited_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_routes_to_handle_belief_text_by_sender(continuous_collector, mocker):
|
||||
msg = MagicMock()
|
||||
@@ -107,6 +118,7 @@ async def test_routes_to_handle_belief_text_by_sender(continuous_collector, mock
|
||||
await continuous_collector._process_message(msg)
|
||||
spy.assert_awaited_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_routes_to_handle_emo_text(continuous_collector, mocker):
|
||||
msg = MagicMock()
|
||||
@@ -116,50 +128,64 @@ async def test_routes_to_handle_emo_text(continuous_collector, mocker):
|
||||
await continuous_collector._process_message(msg)
|
||||
spy.assert_awaited_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unrecognized_message_logs_info(continuous_collector, mocker):
|
||||
msg = MagicMock()
|
||||
msg.body = json.dumps({"type": "something_else"})
|
||||
msg.sender = "x@test"
|
||||
logger_mock = mocker.patch("control_backend.agents.belief_collector.behaviours.continuous_collect.logger")
|
||||
logger_mock = mocker.patch(
|
||||
"control_backend.agents.belief_collector.behaviours.continuous_collect.logger"
|
||||
)
|
||||
await continuous_collector._process_message(msg)
|
||||
logger_mock.info.assert_any_call(
|
||||
"BeliefCollector: unrecognized message (sender=%s, type=%r). Ignoring.", "x", "something_else"
|
||||
"BeliefCollector: unrecognized message (sender=%s, type=%r). Ignoring.",
|
||||
"x",
|
||||
"something_else",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_belief_text_no_beliefs(continuous_collector, mocker):
|
||||
msg_payload = {"type": "belief_extraction_text"} # no 'beliefs'
|
||||
logger_mock = mocker.patch("control_backend.agents.belief_collector.behaviours.continuous_collect.logger")
|
||||
logger_mock = mocker.patch(
|
||||
"control_backend.agents.belief_collector.behaviours.continuous_collect.logger"
|
||||
)
|
||||
await continuous_collector._handle_belief_text(msg_payload, "origin_node")
|
||||
logger_mock.info.assert_any_call("BeliefCollector: no beliefs to process.")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_belief_text_beliefs_not_dict(continuous_collector, mocker):
|
||||
payload = {"type": "belief_extraction_text", "beliefs": ["not", "a", "dict"]}
|
||||
logger_mock = mocker.patch("control_backend.agents.belief_collector.behaviours.continuous_collect.logger")
|
||||
logger_mock = mocker.patch(
|
||||
"control_backend.agents.belief_collector.behaviours.continuous_collect.logger"
|
||||
)
|
||||
await continuous_collector._handle_belief_text(payload, "origin")
|
||||
logger_mock.warning.assert_any_call("BeliefCollector: 'beliefs' is not a dict: %r", ["not", "a", "dict"])
|
||||
logger_mock.warning.assert_any_call(
|
||||
"BeliefCollector: 'beliefs' is not a dict: %r", ["not", "a", "dict"]
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_belief_text_values_not_lists(continuous_collector, mocker):
|
||||
payload = {"type": "belief_extraction_text", "beliefs": {"user_said": "not-a-list"}}
|
||||
logger_mock = mocker.patch("control_backend.agents.belief_collector.behaviours.continuous_collect.logger")
|
||||
logger_mock = mocker.patch(
|
||||
"control_backend.agents.belief_collector.behaviours.continuous_collect.logger"
|
||||
)
|
||||
await continuous_collector._handle_belief_text(payload, "origin")
|
||||
logger_mock.warning.assert_any_call(
|
||||
"BeliefCollector: 'beliefs' values are not all lists: %r", {"user_said": "not-a-list"}
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_belief_text_happy_path_logs_items_and_sends(continuous_collector, mocker):
|
||||
payload = {
|
||||
"type": "belief_extraction_text",
|
||||
"beliefs": {"user_said": ["hello test", "No"]}
|
||||
}
|
||||
# Your code calls self.send(..); patch it (or switch implementation to self.agent.send and patch that)
|
||||
payload = {"type": "belief_extraction_text", "beliefs": {"user_said": ["hello test", "No"]}}
|
||||
continuous_collector.send = AsyncMock()
|
||||
logger_mock = mocker.patch("control_backend.agents.belief_collector.behaviours.continuous_collect.logger")
|
||||
logger_mock = mocker.patch(
|
||||
"control_backend.agents.belief_collector.behaviours.continuous_collect.logger"
|
||||
)
|
||||
await continuous_collector._handle_belief_text(payload, "belief_text_agent_mock")
|
||||
|
||||
logger_mock.info.assert_any_call("BeliefCollector: forwarding %d beliefs.", 1)
|
||||
@@ -169,12 +195,14 @@ async def test_belief_text_happy_path_logs_items_and_sends(continuous_collector,
|
||||
# make sure we attempted a send
|
||||
continuous_collector.send.assert_awaited_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_beliefs_noop_on_empty(continuous_collector):
|
||||
continuous_collector.send = AsyncMock()
|
||||
await continuous_collector._send_beliefs_to_bdi([], origin="o")
|
||||
continuous_collector.send.assert_not_awaited()
|
||||
|
||||
|
||||
# @pytest.mark.asyncio
|
||||
# async def test_send_beliefs_sends_json_packet(continuous_collector):
|
||||
# # Patch .send and capture the message body
|
||||
@@ -191,19 +219,22 @@ async def test_send_beliefs_noop_on_empty(continuous_collector):
|
||||
# assert "belief_packet" in json.loads(sent["body"])["type"]
|
||||
# assert json.loads(sent["body"])["beliefs"] == beliefs
|
||||
|
||||
|
||||
def test_sender_node_no_sender_returns_literal(continuous_collector):
|
||||
msg = MagicMock()
|
||||
msg.sender = None
|
||||
assert continuous_collector._sender_node(msg) == "no_sender"
|
||||
|
||||
|
||||
def test_sender_node_without_at(continuous_collector):
|
||||
msg = MagicMock()
|
||||
msg.sender = "localpartonly"
|
||||
assert continuous_collector._sender_node(msg) == "localpartonly"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_belief_text_coerces_non_strings(continuous_collector, mocker):
|
||||
payload = {"type": "belief_extraction_text", "beliefs": {"user_said": [["hi", 123]]}}
|
||||
continuous_collector.send = AsyncMock()
|
||||
await continuous_collector._handle_belief_text(payload, "origin")
|
||||
continuous_collector.send.assert_awaited_once()
|
||||
continuous_collector.send.assert_awaited_once()
|
||||
|
||||
@@ -6,7 +6,7 @@ from control_backend.agents.transcription.speech_recognizer import OpenAIWhisper
|
||||
|
||||
def test_estimate_max_tokens():
|
||||
"""Inputting one minute of audio, assuming 300 words per minute, expecting 400 tokens."""
|
||||
audio = np.empty(shape=(60*16_000), dtype=np.float32)
|
||||
audio = np.empty(shape=(60 * 16_000), dtype=np.float32)
|
||||
|
||||
actual = SpeechRecognizer._estimate_max_tokens(audio)
|
||||
|
||||
@@ -16,7 +16,7 @@ def test_estimate_max_tokens():
|
||||
|
||||
def test_get_decode_options():
|
||||
"""Check whether the right decode options are given under different scenarios."""
|
||||
audio = np.empty(shape=(60*16_000), dtype=np.float32)
|
||||
audio = np.empty(shape=(60 * 16_000), dtype=np.float32)
|
||||
|
||||
# With the defaults, it should limit output length based on input size
|
||||
recognizer = OpenAIWhisperSpeechRecognizer()
|
||||
|
||||
Reference in New Issue
Block a user