128 lines
4.2 KiB
Python
128 lines
4.2 KiB
Python
from unittest.mock import MagicMock, patch
|
|
|
|
import pytest
|
|
from fastapi import FastAPI, HTTPException
|
|
from fastapi.testclient import TestClient
|
|
from starlette.responses import StreamingResponse
|
|
|
|
from control_backend.api.v1.endpoints import logs
|
|
|
|
|
|
@pytest.fixture
def client():
    """Return a ``TestClient`` wrapping a bare FastAPI app that mounts only ``logs.router``.

    ``include_router`` is called without a prefix, so the ``/api/logs/...``
    URLs used by the tests come from the router's own configuration.
    """
    application = FastAPI()
    application.include_router(logs.router)
    test_client = TestClient(application)
    return test_client
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_log_stream_endpoint_lines(client):
    """Await ``logs.log_stream`` with a fake ZMQ socket and read one streamed chunk.

    ``Context.instance`` is patched so ``Context.instance().socket(...)`` returns
    ``DummySocket``. After the first chunk is read, the response's body iterator
    is closed explicitly. Otherwise it stays suspended, still holding the fake
    socket, until garbage collection finalizes it.
    """

    class DummySocket:
        """Minimal stand-in for the ZMQ socket the endpoint obtains from ``Context``."""

        def __init__(self):
            self.subscribed = []  # topics passed to subscribe()
            self.connected = False  # set by connect()
            self.closed = False  # set by close(), if the endpoint calls it
            self.recv_count = 0  # number of recv_multipart() calls served

        def subscribe(self, topic):
            self.subscribed.append(topic)

        def connect(self, addr):
            self.connected = True

        def close(self, linger=None):
            # No-op so endpoint cleanup code (triggered by aclose below) cannot fail
            # if it closes the socket.
            self.closed = True

        async def recv_multipart(self):
            # Serve exactly one (topic, message) pair. The test reads one chunk and
            # then closes the stream, so a second call means the endpoint consumed
            # more than expected. Raising StopAsyncIteration here would NOT end the
            # stream cleanly: inside an async generator it becomes RuntimeError.
            if self.recv_count == 0:
                self.recv_count += 1
                return (b"INFO", b"test message")
            raise AssertionError("DummySocket.recv_multipart called more than once")

    dummy_socket = DummySocket()

    # Keep the patch active while iterating: the endpoint's generator may only
    # create/connect the socket once it is first advanced.
    with patch("control_backend.api.v1.endpoints.logs.Context.instance") as mock_context:
        mock_context.return_value.socket.return_value = dummy_socket

        response = await logs.log_stream()
        assert isinstance(response, StreamingResponse)

        gen = response.body_iterator
        try:
            chunk = await gen.__anext__()
        finally:
            # Close the suspended stream deterministically instead of leaving it
            # to the GC / event-loop finalizer.
            aclose = getattr(gen, "aclose", None)
            if aclose is not None:
                await aclose()

    if isinstance(chunk, bytes):
        chunk = chunk.decode("utf-8")
    assert "data:" in chunk

    # The endpoint must have subscribed to at least one topic and connected.
    assert dummy_socket.subscribed
    assert dummy_socket.connected
|
|
|
|
|
|
@patch("control_backend.api.v1.endpoints.logs.LOGGING_DIR")
def test_files_endpoint(mock_logging_dir, client):
    """GET /api/logs/files returns the ``name`` of each entry from a mocked ``LOGGING_DIR.glob``."""
    fake_files = []
    for file_name in ("file_1", "file_2"):
        fake = MagicMock()
        fake.name = file_name  # set after construction: ``name`` is a Mock ctor kwarg
        fake_files.append(fake)
    mock_logging_dir.glob.return_value = fake_files

    response = client.get("/api/logs/files")

    assert response.status_code == 200
    assert response.json() == ["file_1", "file_2"]
|
|
|
|
|
|
@patch("control_backend.api.v1.endpoints.logs.FileResponse")
@patch("control_backend.api.v1.endpoints.logs.LOGGING_DIR")
def test_log_file_endpoint_success(mock_logging_dir, mock_file_response, client):
    """GET /api/logs/files/{name} returns the ``FileResponse`` built for a valid file.

    The patched ``FileResponse`` yields a real Starlette ``Response``. FastAPI
    passes ``Response`` instances through unchanged but JSON-encodes anything
    else, so returning a bare ``MagicMock`` would give a 200 that does not prove
    the file response reached the client.
    """
    from starlette.responses import Response

    mock_file_path = MagicMock()
    mock_file_path.is_relative_to.return_value = True
    mock_file_path.is_file.return_value = True
    mock_file_path.name = "test.log"

    # ``LOGGING_DIR / filename`` yields the mock path, and ``.resolve()`` maps it to itself.
    mock_logging_dir.__truediv__ = MagicMock(return_value=mock_file_path)
    mock_file_path.resolve.return_value = mock_file_path

    mock_file_response.return_value = Response(
        content=b"log contents", media_type="text/plain"
    )

    result = client.get("/api/logs/files/test.log")

    assert result.status_code == 200
    # The body proves the endpoint returned the object FileResponse produced.
    assert result.content == b"log contents"
    mock_file_response.assert_called_once_with(mock_file_path, filename="test.log")
|
|
|
|
|
|
@pytest.mark.asyncio
@patch("control_backend.api.v1.endpoints.logs.LOGGING_DIR")
async def test_log_file_endpoint_path_traversal(mock_logging_dir):
    """A filename resolving outside LOGGING_DIR makes ``log_file`` raise HTTP 400.

    The endpoint is awaited directly rather than via ``client``. Presumably this
    is so an HTTP client cannot normalise the ``..`` segment away — TODO confirm.
    """
    outside_path = MagicMock(**{"is_relative_to.return_value": False})
    outside_path.resolve.return_value = outside_path
    mock_logging_dir.__truediv__ = MagicMock(return_value=outside_path)

    with pytest.raises(HTTPException) as raised:
        await logs.log_file("../secret.txt")

    error = raised.value
    assert error.status_code == 400
    assert error.detail == "Invalid filename."
|
|
|
|
|
|
@patch("control_backend.api.v1.endpoints.logs.LOGGING_DIR")
def test_log_file_endpoint_file_not_found(mock_logging_dir, client):
    """A path inside LOGGING_DIR whose ``is_file()`` is False yields 404 "File not found."."""
    missing_path = MagicMock(
        **{
            "is_relative_to.return_value": True,
            "is_file.return_value": False,
        }
    )
    missing_path.resolve.return_value = missing_path
    mock_logging_dir.__truediv__ = MagicMock(return_value=missing_path)

    response = client.get("/api/logs/files/nonexistent.log")

    assert response.status_code == 404
    assert response.json()["detail"] == "File not found."
|