# File: astrbbbb / tests/test_backup.py
# Hosting-page residue kept for provenance:
#   "qa1145's picture" / "Upload 1245 files" / "8ede856 verified"
"""备份功能单元测试"""
import json
import os
import re
import zipfile
from datetime import datetime
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from astrbot.core.backup import (
BACKUP_MANIFEST_VERSION,
KB_METADATA_MODELS,
MAIN_DB_MODELS,
ImportPreCheckResult,
)
from astrbot.core.backup.exporter import AstrBotExporter
from astrbot.core.backup.importer import (
DatabaseClearError,
PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT,
AstrBotImporter,
ImportResult,
_get_major_version,
)
from astrbot.core.config.default import VERSION
from astrbot.core.db.po import (
ConversationV2,
)
from astrbot.core.utils.version_comparator import VersionComparator
from astrbot.dashboard.routes.backup import (
generate_unique_filename,
secure_filename,
)
@pytest.fixture
def temp_backup_dir(tmp_path):
    """Provide an empty ``backups`` directory created under ``tmp_path``."""
    target = tmp_path.joinpath("backups")
    target.mkdir()
    return target
@pytest.fixture
def temp_data_dir(tmp_path):
    """Provide a ``data`` directory seeded with a config file and an attachments folder.

    The directory holds ``cmd_config.json`` containing ``{"test": "config"}``
    and an empty ``attachments`` subdirectory.
    """
    root = tmp_path.joinpath("data")
    root.mkdir()

    # Seed the configuration file.
    serialized_config = json.dumps({"test": "config"})
    root.joinpath("cmd_config.json").write_text(serialized_config)

    # Create the attachments directory.
    root.joinpath("attachments").mkdir()
    return root
@pytest.fixture
def mock_main_db():
    """Provide a mocked main database whose ``get_db()`` yields an async session."""
    fake_session = AsyncMock()
    # Stand-in async context manager: ``async with db.get_db() as s`` binds
    # ``s`` to ``fake_session``.
    session_ctx = AsyncMock(__aenter__=AsyncMock(return_value=fake_session))

    database = MagicMock()
    database.get_db = MagicMock(return_value=session_ctx)
    return database
@pytest.fixture
def mock_kb_manager():
    """Provide a mocked knowledge-base manager with no instances and a mocked ``kb_db``."""
    fake_session = AsyncMock()
    session_ctx = AsyncMock(__aenter__=AsyncMock(return_value=fake_session))

    # Mocked knowledge-base database handle.
    fake_kb_db = MagicMock()
    fake_kb_db.get_db = MagicMock(return_value=session_ctx)

    manager = MagicMock()
    manager.kb_insts = {}
    manager.kb_db = fake_kb_db
    return manager
class TestImportResult:
    """Tests for the ``ImportResult`` class."""

    def test_init(self):
        """A fresh result reports success and has empty collections."""
        fresh = ImportResult()
        assert fresh.success is True
        assert fresh.imported_tables == {}
        assert fresh.imported_files == {}
        assert fresh.warnings == []
        assert fresh.errors == []

    def test_add_warning(self):
        """Adding a warning records it without flipping the success flag."""
        outcome = ImportResult()
        outcome.add_warning("test warning")
        assert "test warning" in outcome.warnings
        # Warnings do not affect the success state.
        assert outcome.success is True

    def test_add_error(self):
        """Adding an error records it and marks the result as failed."""
        outcome = ImportResult()
        outcome.add_error("test error")
        assert "test error" in outcome.errors
        # Errors cause failure.
        assert outcome.success is False

    def test_to_dict(self):
        """``to_dict`` exposes success, imported tables and warnings."""
        outcome = ImportResult()
        outcome.imported_tables = {"test_table": 10}
        outcome.add_warning("warning")

        payload = outcome.to_dict()

        assert payload["success"] is True
        assert payload["imported_tables"] == {"test_table": 10}
        assert "warning" in payload["warnings"]
class TestAstrBotExporter:
    """Tests for the ``AstrBotExporter`` class."""

    def test_init(self, mock_main_db, mock_kb_manager, temp_data_dir):
        """The constructor keeps references to the collaborators it receives."""
        config_file = temp_data_dir / "cmd_config.json"
        subject = AstrBotExporter(
            main_db=mock_main_db,
            kb_manager=mock_kb_manager,
            config_path=str(config_file),
        )
        assert subject.main_db is mock_main_db
        assert subject.kb_manager is mock_kb_manager

    def test_model_to_dict_with_model_dump(self):
        """``_model_to_dict`` returns what the record's ``model_dump`` provides."""
        subject = AstrBotExporter(main_db=MagicMock())
        # Stand-in record exposing a ``model_dump`` method.
        record = MagicMock()
        record.model_dump.return_value = {"id": 1, "name": "test"}

        converted = subject._model_to_dict(record)

        assert converted == {"id": 1, "name": "test"}

    def test_model_to_dict_with_datetime(self):
        """``_model_to_dict`` renders a datetime field as its ISO-format string."""
        subject = AstrBotExporter(main_db=MagicMock())
        moment = datetime.now()
        record = MagicMock()
        record.model_dump.return_value = {"id": 1, "created_at": moment}

        converted = subject._model_to_dict(record)

        assert converted["created_at"] == moment.isoformat()

    def test_add_checksum(self):
        """``_add_checksum`` stores a ``sha256:``-prefixed entry under the file name."""
        subject = AstrBotExporter(main_db=MagicMock())
        subject._add_checksum("test.json", '{"test": "data"}')
        assert "test.json" in subject._checksums
        assert subject._checksums["test.json"].startswith("sha256:")

    def test_generate_manifest(self, mock_main_db, mock_kb_manager):
        """``_generate_manifest`` emits version info, origin marker and statistics."""
        subject = AstrBotExporter(
            main_db=mock_main_db,
            kb_manager=mock_kb_manager,
        )
        main_tables = {
            "platform_stats": [{"id": 1}],
            "conversations": [],
            "attachments": [],
        }
        kb_tables = {
            "knowledge_bases": [],
            "kb_documents": [],
        }
        directory_stats = {
            "plugins": {"files": 10, "size": 1024},
            "plugin_data": {"files": 5, "size": 512},
        }

        manifest = subject._generate_manifest(main_tables, kb_tables, directory_stats)

        assert manifest["version"] == BACKUP_MANIFEST_VERSION
        assert manifest["astrbot_version"] == VERSION
        # Verify the marker identifying where the backup came from.
        assert manifest["origin"] == "exported"
        for required_key in ("exported_at", "tables", "statistics", "directories"):
            assert required_key in manifest
        assert manifest["statistics"]["main_db"]["platform_stats"] == 1
        assert manifest["statistics"]["directories"] == directory_stats

    @pytest.mark.asyncio
    async def test_export_all_creates_zip(
        self, mock_main_db, temp_backup_dir, temp_data_dir
    ):
        """``export_all`` writes a ZIP archive containing the expected members."""
        # Make the mocked database hand back empty result sets.
        empty_result = MagicMock()
        empty_result.scalars.return_value.all.return_value = []
        fake_session = AsyncMock()
        fake_session.execute = AsyncMock(return_value=empty_result)
        mock_main_db.get_db.return_value = AsyncMock(
            __aenter__=AsyncMock(return_value=fake_session),
            __aexit__=AsyncMock(return_value=None),
        )

        subject = AstrBotExporter(
            main_db=mock_main_db,
            kb_manager=None,
            config_path=str(temp_data_dir / "cmd_config.json"),
        )
        archive_path = await subject.export_all(output_dir=str(temp_backup_dir))

        assert os.path.exists(archive_path)
        assert archive_path.endswith(".zip")
        assert "astrbot_backup_" in archive_path

        # Check the archive's member list.
        with zipfile.ZipFile(archive_path, "r") as archive:
            members = archive.namelist()
            assert "manifest.json" in members
            assert "databases/main_db.json" in members
            assert "config/cmd_config.json" in members
class TestAstrBotImporter:
    """Tests for the ``AstrBotImporter`` class."""

    def test_init(self, mock_main_db, mock_kb_manager, temp_data_dir):
        """The constructor stores the database and knowledge-base manager it is given."""
        importer = AstrBotImporter(
            main_db=mock_main_db,
            kb_manager=mock_kb_manager,
            config_path=str(temp_data_dir / "cmd_config.json"),
        )
        assert importer.main_db is mock_main_db
        assert importer.kb_manager is mock_kb_manager

    def test_validate_version_match(self):
        """A manifest carrying the current ``VERSION`` passes validation."""
        importer = AstrBotImporter(main_db=MagicMock())
        manifest = {"astrbot_version": VERSION}
        # Must not raise.
        importer._validate_version(manifest)

    def test_validate_version_major_diff_rejected(self):
        """A manifest with a different major version is rejected with ``ValueError``."""
        importer = AstrBotImporter(main_db=MagicMock())
        # Use a version whose major part is clearly different.
        manifest = {"astrbot_version": "0.0.1"}
        with pytest.raises(ValueError, match="主版本不兼容"):
            importer._validate_version(manifest)

    def test_validate_version_minor_diff_allowed(self):
        """A differing minor version under the same major version is accepted.

        The original note says this case only produces a warning; the warning
        itself is not asserted here, only the absence of an exception.
        """
        importer = AstrBotImporter(main_db=MagicMock())
        # Take the current major version.
        major_version = _get_major_version(VERSION)
        # Build a version with the same major part but a different minor part.
        minor_diff_version = f"{major_version}.999"
        manifest = {"astrbot_version": minor_diff_version}
        # Must not raise.
        importer._validate_version(manifest)

    def test_validate_version_missing(self):
        """A manifest without version information is rejected with ``ValueError``."""
        importer = AstrBotImporter(main_db=MagicMock())
        manifest = {}
        with pytest.raises(ValueError, match="缺少版本信息"):
            importer._validate_version(manifest)

    def test_convert_datetime_fields(self):
        """ISO strings in datetime columns are converted to ``datetime`` objects."""
        importer = AstrBotImporter(main_db=MagicMock())
        # Use ConversationV2 as the test model (per the original note, it has
        # created_at and updated_at fields).
        row = {
            "conversation_id": "test-123",
            "platform_id": "test",
            "user_id": "user1",
            "created_at": "2024-01-01T12:00:00",
            "updated_at": "2024-01-01T12:00:00",
        }
        result = importer._convert_datetime_fields(row, ConversationV2)
        # created_at should have been converted to a datetime object.
        assert isinstance(result["created_at"], datetime)
        assert isinstance(result["updated_at"], datetime)

    def test_merge_platform_stats_rows(self):
        """Rows with duplicate platform_stats keys are aggregated before import."""
        importer = AstrBotImporter(main_db=MagicMock())
        # Rows 1, 80 and 81 spell the same instant three ways ("Z" suffix,
        # "+00:00" offset, no offset); together with identical platform fields
        # they are expected to collapse into one row with count 14 + 3 + 2.
        rows = [
            {
                "id": 1,
                "timestamp": "2025-12-13T20:00:00Z",
                "platform_id": "webchat",
                "platform_type": "unknown",
                "count": 14,
            },
            {
                "id": 80,
                "timestamp": "2025-12-13T20:00:00+00:00",
                "platform_id": "webchat",
                "platform_type": "unknown",
                "count": 3,
            },
            {
                "id": 81,
                "timestamp": "2025-12-13T20:00:00",
                "platform_id": "webchat",
                "platform_type": "unknown",
                "count": 2,
            },
            {
                "id": 2,
                "timestamp": "2025-12-13T21:00:00",
                "platform_id": "aiocqhttp",
                "platform_type": "unknown",
                "count": 1,
            },
        ]
        merged_rows = importer._merge_platform_stats_rows(rows)
        duplicate_count = len(rows) - len(merged_rows)
        assert duplicate_count == 2
        assert len(merged_rows) == 2
        # Locate the merged webchat row by its normalized key.
        webchat_row = next(
            (
                r
                for r in merged_rows
                if r.get("timestamp") == "2025-12-13T20:00:00+00:00"
                and r.get("platform_id") == "webchat"
                and r.get("platform_type") == "unknown"
            ),
            None,
        )
        assert webchat_row is not None
        assert webchat_row["timestamp"] == "2025-12-13T20:00:00+00:00"
        assert webchat_row["platform_id"] == "webchat"
        assert webchat_row["platform_type"] == "unknown"
        assert webchat_row["count"] == 19
        aiocq_row = next(
            (
                r
                for r in merged_rows
                if r.get("platform_id") == "aiocqhttp"
                and r.get("platform_type") == "unknown"
            ),
            None,
        )
        assert aiocq_row is not None
        # The offset-less input timestamp is expected back with a "+00:00" offset.
        assert aiocq_row["timestamp"] == "2025-12-13T21:00:00+00:00"

    def test_merge_platform_stats_rows_normalizes_naive_timestamp_to_utc(self):
        """Naive timestamps are normalized to a UTC offset when merging platform_stats."""
        importer = AstrBotImporter(main_db=MagicMock())
        # One naive ISO string and one naive ``datetime`` object.
        rows = [
            {
                "timestamp": "2025-12-13T21:00:00",
                "platform_id": "webchat",
                "platform_type": "unknown",
                "count": 1,
            },
            {
                "timestamp": datetime(2025, 12, 13, 22, 0, 0),
                "platform_id": "telegram",
                "platform_type": "unknown",
                "count": 1,
            },
        ]
        merged_rows = importer._merge_platform_stats_rows(rows)
        assert len(merged_rows) == 2
        by_platform = {row["platform_id"]: row for row in merged_rows}
        assert by_platform["webchat"]["timestamp"] == "2025-12-13T21:00:00+00:00"
        assert by_platform["telegram"]["timestamp"] == "2025-12-13T22:00:00+00:00"

    def test_merge_platform_stats_rows_warns_on_invalid_count(self):
        """An invalid platform_stats count logs a warning and counts as 0 (with a cap)."""
        importer = AstrBotImporter(main_db=MagicMock())
        # The importer module's logger.warning is patched for the whole test so
        # that each scenario below can inspect the number of warnings emitted.
        with patch("astrbot.core.backup.importer.logger.warning") as warning_mock:
            # Scenario 1: valid row first, then a duplicate with a bad count.
            rows = [
                {
                    "timestamp": "2025-12-13T20:00:00+00:00",
                    "platform_id": "webchat",
                    "platform_type": "unknown",
                    "count": 5,
                },
                {
                    "timestamp": "2025-12-13T20:00:00Z",
                    "platform_id": "webchat",
                    "platform_type": "unknown",
                    "count": "bad-count",
                },
            ]
            merged_rows = importer._merge_platform_stats_rows(rows)
            duplicate_count = len(rows) - len(merged_rows)
            assert duplicate_count == 1
            assert len(merged_rows) == 1
            assert merged_rows[0]["count"] == 5
            assert warning_mock.call_count == 1
            warning_mock.reset_mock()
            # Scenario 2: the bad count comes first, the valid duplicate second.
            rows_existing_invalid = [
                {
                    "timestamp": "2025-12-13T21:00:00+00:00",
                    "platform_id": "webchat",
                    "platform_type": "unknown",
                    "count": "bad-count",
                },
                {
                    "timestamp": "2025-12-13T21:00:00Z",
                    "platform_id": "webchat",
                    "platform_type": "unknown",
                    "count": 7,
                },
            ]
            merged_rows = importer._merge_platform_stats_rows(rows_existing_invalid)
            duplicate_count = len(rows_existing_invalid) - len(merged_rows)
            assert duplicate_count == 1
            assert len(merged_rows) == 1
            assert merged_rows[0]["count"] == 7
            assert warning_mock.call_count == 1
            warning_mock.reset_mock()
            # Scenario 3: more invalid rows than the warning limit; the expected
            # call count is the limit plus one extra "limit reached" message.
            many_invalid_rows = [
                {
                    "timestamp": "2025-12-13T22:00:00+00:00",
                    "platform_id": "webchat",
                    "platform_type": "unknown",
                    "count": 1,
                },
                *[
                    {
                        "timestamp": "2025-12-13T22:00:00Z",
                        "platform_id": "webchat",
                        "platform_type": "unknown",
                        "count": "bad-count",
                    }
                    for _ in range(PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT + 5)
                ],
            ]
            importer._merge_platform_stats_rows(many_invalid_rows)
            assert (
                warning_mock.call_count == PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT + 1
            )
            assert any(
                "告警已达到上限" in str(call.args[0])
                for call in warning_mock.call_args_list
            )
            warning_mock.reset_mock()
            # Scenario 4: a later call with a single invalid row warns again,
            # i.e. the cap reached in scenario 3 does not carry over.
            single_invalid_row = [
                {
                    "timestamp": "2025-12-13T23:00:00+00:00",
                    "platform_id": "telegram",
                    "platform_type": "unknown",
                    "count": "still-bad",
                },
            ]
            merged_rows = importer._merge_platform_stats_rows(single_invalid_row)
            duplicate_count = len(single_invalid_row) - len(merged_rows)
            assert duplicate_count == 0
            assert len(merged_rows) == 1
            assert merged_rows[0]["count"] == 0
            assert warning_mock.call_count == 1

    def test_merge_platform_stats_rows_keeps_invalid_timestamps_distinct(self):
        """Empty or invalid timestamps are not aggregated, avoiding wrong merges."""
        importer = AstrBotImporter(main_db=MagicMock())
        # The last two rows share the same invalid timestamp text and must
        # still stay separate.
        rows = [
            {
                "timestamp": "",
                "platform_id": "webchat",
                "platform_type": "unknown",
                "count": 2,
            },
            {
                "timestamp": "not-a-datetime",
                "platform_id": "webchat",
                "platform_type": "unknown",
                "count": 3,
            },
            {
                "timestamp": "not-a-datetime",
                "platform_id": "webchat",
                "platform_type": "unknown",
                "count": 4,
            },
        ]
        merged_rows = importer._merge_platform_stats_rows(rows)
        duplicate_count = len(rows) - len(merged_rows)
        assert duplicate_count == 0
        assert len(merged_rows) == 3
        assert [row["count"] for row in merged_rows] == [2, 3, 4]

    def test_merge_platform_stats_rows_keeps_non_string_platform_keys_distinct(self):
        """Rows with a non-string platform_id or platform_type are not aggregated."""
        importer = AstrBotImporter(main_db=MagicMock())
        # Two pairs that would otherwise match: one pair with platform_id None,
        # one pair with an integer platform_type.
        rows = [
            {
                "timestamp": "2025-12-13T20:00:00+00:00",
                "platform_id": None,
                "platform_type": "unknown",
                "count": 2,
            },
            {
                "timestamp": "2025-12-13T20:00:00Z",
                "platform_id": None,
                "platform_type": "unknown",
                "count": 3,
            },
            {
                "timestamp": "2025-12-13T20:00:00+00:00",
                "platform_id": "webchat",
                "platform_type": 1,
                "count": 4,
            },
            {
                "timestamp": "2025-12-13T20:00:00Z",
                "platform_id": "webchat",
                "platform_type": 1,
                "count": 5,
            },
        ]
        merged_rows = importer._merge_platform_stats_rows(rows)
        duplicate_count = len(rows) - len(merged_rows)
        assert duplicate_count == 0
        assert len(merged_rows) == 4

    def test_merge_platform_stats_rows_preserves_input_order(self):
        """Merged platform_stats rows keep input order (by first occurrence)."""
        importer = AstrBotImporter(main_db=MagicMock())
        # Rows 1 and 3 are duplicates; row 2 (empty timestamp) sits between them.
        rows = [
            {
                "id": 1,
                "timestamp": "2025-12-13T20:00:00Z",
                "platform_id": "webchat",
                "platform_type": "unknown",
                "count": 2,
            },
            {
                "id": 2,
                "timestamp": "",
                "platform_id": "webchat",
                "platform_type": "unknown",
                "count": 3,
            },
            {
                "id": 3,
                "timestamp": "2025-12-13T20:00:00+00:00",
                "platform_id": "webchat",
                "platform_type": "unknown",
                "count": 5,
            },
            {
                "id": 4,
                "timestamp": "2025-12-13T21:00:00+00:00",
                "platform_id": "telegram",
                "platform_type": "unknown",
                "count": 7,
            },
        ]
        merged_rows = importer._merge_platform_stats_rows(rows)
        assert len(merged_rows) == 3
        assert [row["id"] for row in merged_rows] == [1, 2, 4]
        # Row 1 absorbed row 3: 2 + 5.
        assert merged_rows[0]["count"] == 7

    @pytest.mark.asyncio
    async def test_import_file_not_exists(self, mock_main_db, tmp_path):
        """Importing a path that does not exist fails with a "not found" error."""
        importer = AstrBotImporter(main_db=mock_main_db)
        result = await importer.import_all(str(tmp_path / "nonexistent.zip"))
        assert result.success is False
        assert any("不存在" in err for err in result.errors)

    @pytest.mark.asyncio
    async def test_import_invalid_zip(self, mock_main_db, tmp_path):
        """Importing a file that is not a ZIP archive fails with an error."""
        # Create a file that is not a valid archive.
        invalid_zip = tmp_path / "invalid.zip"
        invalid_zip.write_text("not a zip file")
        importer = AstrBotImporter(main_db=mock_main_db)
        result = await importer.import_all(str(invalid_zip))
        assert result.success is False
        assert any("无效" in err or "ZIP" in err for err in result.errors)

    @pytest.mark.asyncio
    async def test_import_missing_manifest(self, mock_main_db, tmp_path):
        """Importing a ZIP archive without a manifest fails with an error."""
        # Create a ZIP archive that has no manifest.
        zip_path = tmp_path / "no_manifest.zip"
        with zipfile.ZipFile(zip_path, "w") as zf:
            zf.writestr("test.txt", "test content")
        importer = AstrBotImporter(main_db=mock_main_db)
        result = await importer.import_all(str(zip_path))
        assert result.success is False
        assert any("manifest" in err.lower() for err in result.errors)

    @pytest.mark.asyncio
    async def test_import_major_version_mismatch(self, mock_main_db, tmp_path):
        """Importing a backup with a mismatched major version fails."""
        # Create a backup whose major version does not match.
        zip_path = tmp_path / "old_version.zip"
        manifest = {
            "version": "1.0",
            "astrbot_version": "0.0.1",  # different major version
            "tables": {"main_db": []},
        }
        with zipfile.ZipFile(zip_path, "w") as zf:
            zf.writestr("manifest.json", json.dumps(manifest))
        importer = AstrBotImporter(main_db=mock_main_db)
        result = await importer.import_all(str(zip_path))
        assert result.success is False
        assert any("主版本不兼容" in err for err in result.errors)

    @pytest.mark.asyncio
    async def test_import_replace_fails_when_clear_main_db_fails(
        self, mock_main_db, tmp_path
    ):
        """In replace mode, a failure to clear the main database aborts the import."""
        zip_path = tmp_path / "valid_backup.zip"
        manifest = {
            "version": "1.1",
            "astrbot_version": VERSION,
            "tables": {"platform_stats": 0},
        }
        main_data = {"platform_stats": []}
        with zipfile.ZipFile(zip_path, "w") as zf:
            zf.writestr("manifest.json", json.dumps(manifest))
            zf.writestr("databases/main_db.json", json.dumps(main_data))
        importer = AstrBotImporter(main_db=mock_main_db)
        # Force the clearing step to fail, and stub the import step so the test
        # can verify it is never awaited.
        importer._clear_main_db = AsyncMock(
            side_effect=DatabaseClearError("清空表 platform_stats 失败: db locked")
        )
        importer._import_main_database = AsyncMock(return_value={})
        result = await importer.import_all(str(zip_path), mode="replace")
        assert result.success is False
        assert any("清空主数据库失败" in err for err in result.errors)
        assert any("清空表 platform_stats 失败" in err for err in result.errors)
        importer._import_main_database.assert_not_awaited()
class TestSecureFilename:
    """Tests for the filename-sanitizing helpers."""

    def test_secure_filename_normal(self):
        """Ordinary file names pass through unchanged."""
        for plain_name in ("backup.zip", "my_backup_2024.zip"):
            assert secure_filename(plain_name) == plain_name

    def test_secure_filename_path_traversal(self):
        """Path-traversal inputs lose their traversal tokens and separators."""
        attempts = (
            ("..", "../../../etc/passwd"),
            ("/", "/etc/passwd"),
            ("\\", "..\\..\\windows\\system32"),
        )
        for forbidden, raw_name in attempts:
            assert forbidden not in secure_filename(raw_name)

    def test_secure_filename_with_path(self):
        """Only the final component of a POSIX or Windows path is kept."""
        posix_cleaned = secure_filename("/path/to/backup.zip")
        assert posix_cleaned == "backup.zip"
        windows_cleaned = secure_filename("C:\\Users\\test\\backup.zip")
        assert windows_cleaned == "backup.zip"

    def test_secure_filename_special_chars(self):
        """Special characters do not survive sanitizing."""
        cleaned = secure_filename('backup<>:"|?*.zip')
        # The original note says these are replaced by underscores; only their
        # absence is asserted here.
        for special in '<>:"|?*':
            assert special not in cleaned

    def test_secure_filename_hidden_file(self):
        """A leading dot (hidden file) is removed."""
        cleaned = secure_filename(".hidden_backup.zip")
        assert not cleaned.startswith(".")

    def test_secure_filename_empty(self):
        """Empty or dots-only names fall back to ``"backup"``."""
        for degenerate in ("", "..."):
            assert secure_filename(degenerate) == "backup"

    def test_generate_unique_filename(self):
        """The unique name keeps the stem and extension around a timestamp."""
        unique = generate_unique_filename("backup.zip")
        # Original stem first, then a timestamp suffix.
        assert unique.startswith("backup_")
        assert unique.endswith(".zip")
        # Timestamp layout is YYYYMMDD_HHMMSS.
        assert re.search(r"backup_\d{8}_\d{6}\.zip", unique)

    def test_generate_unique_filename_with_complex_name(self):
        """A stem that already contains underscores is kept intact."""
        unique = generate_unique_filename("my_backup_file.zip")
        # The timestamp is appended after the original stem.
        assert unique.startswith("my_backup_file_")
        assert unique.endswith(".zip")
        assert re.search(r"my_backup_file_\d{8}_\d{6}\.zip", unique)
class TestVersionComparison:
    """Tests for version helpers, using ``VersionComparator``."""

    def test_get_major_version_simple(self):
        """Plain versions reduce to their first two components."""
        expectations = (("1.0", "1.0"), ("2.1", "2.1"), ("4.9.1", "4.9"))
        for raw, expected in expectations:
            assert _get_major_version(raw) == expected

    def test_get_major_version_with_prefix(self):
        """A leading ``v``/``V`` is dropped."""
        expectations = (("v1.0", "1.0"), ("V4.9.1", "4.9"))
        for raw, expected in expectations:
            assert _get_major_version(raw) == expected

    def test_get_major_version_with_prerelease(self):
        """Pre-release and build suffixes do not affect the result."""
        for raw in ("4.9.1-beta", "4.9.1-alpha.1", "4.9.1+build123"):
            assert _get_major_version(raw) == "4.9"

    def test_get_major_version_single_part(self):
        """A single-component version is padded with ``.0``."""
        assert _get_major_version("1") == "1.0"

    def test_get_major_version_empty(self):
        """An empty version yields ``"0.0"``."""
        assert _get_major_version("") == "0.0"

    def test_compare_versions_equal(self):
        """Equal versions compare as 0."""
        equal_pairs = (("1.0", "1.0"), ("1.0.0", "1.0"), ("2.10", "2.10"))
        for left, right in equal_pairs:
            assert VersionComparator.compare_version(left, right) == 0

    def test_compare_versions_less_than(self):
        """A lower first version compares as -1."""
        ascending_pairs = (
            ("1.0", "1.1"),
            ("1.9", "1.10"),  # key case: multi-digit component
            ("1.2", "1.10"),
            ("1.0", "2.0"),
        )
        for lower, higher in ascending_pairs:
            assert VersionComparator.compare_version(lower, higher) == -1

    def test_compare_versions_greater_than(self):
        """A higher first version compares as 1."""
        descending_pairs = (
            ("1.1", "1.0"),
            ("1.10", "1.9"),  # key case: multi-digit component
            ("1.10", "1.2"),
            ("2.0", "1.0"),
        )
        for higher, lower in descending_pairs:
            assert VersionComparator.compare_version(higher, lower) == 1

    def test_compare_versions_different_lengths(self):
        """Versions with different numbers of components compare as expected."""
        cases = (
            ("1.0", "1.0.0", 0),
            ("1.0", "1.0.1", -1),
            ("1.0.1", "1.0", 1),
        )
        for left, right, expected in cases:
            assert VersionComparator.compare_version(left, right) == expected

    def test_compare_versions_prerelease(self):
        """Pre-release versions order below releases, and alpha below beta."""
        # A pre-release is lower than the release.
        assert VersionComparator.compare_version("1.0.0-alpha", "1.0.0") == -1
        assert VersionComparator.compare_version("1.0.0", "1.0.0-beta") == 1
        # alpha < beta
        assert VersionComparator.compare_version("1.0.0-alpha", "1.0.0-beta") == -1
class TestImportPreCheckResult:
    """Tests for the ``ImportPreCheckResult`` class."""

    def test_init_default_values(self):
        """A default-constructed result is invalid, non-importable and otherwise blank."""
        defaults = ImportPreCheckResult()
        assert defaults.valid is False
        assert defaults.can_import is False
        assert defaults.version_status == ""
        assert defaults.backup_version == ""
        assert defaults.current_version == VERSION
        assert defaults.confirm_message == ""
        assert defaults.warnings == []
        assert defaults.error == ""
        assert defaults.backup_summary == {}

    def test_to_dict(self):
        """``to_dict`` carries over the values passed to the constructor."""
        populated = ImportPreCheckResult(
            valid=True,
            can_import=True,
            version_status="match",
            backup_version="4.9.0",
            confirm_message="确认导入?",
            warnings=["警告1"],
            backup_summary={"tables": ["table1"]},
        )

        payload = populated.to_dict()

        assert payload["valid"] is True
        assert payload["can_import"] is True
        assert payload["version_status"] == "match"
        assert payload["backup_version"] == "4.9.0"
        assert payload["confirm_message"] == "确认导入?"
        assert "警告1" in payload["warnings"]
        assert payload["backup_summary"]["tables"] == ["table1"]
class TestPreCheck:
    """Tests for the importer's pre-check."""

    @staticmethod
    def _write_manifest_archive(archive_path, manifest):
        """Create a ZIP at ``archive_path`` holding only ``manifest.json``."""
        with zipfile.ZipFile(archive_path, "w") as archive:
            archive.writestr("manifest.json", json.dumps(manifest))

    def test_pre_check_file_not_exists(self, mock_main_db):
        """Pre-checking a missing file is reported as invalid."""
        checker = AstrBotImporter(main_db=mock_main_db)
        report = checker.pre_check("/nonexistent/file.zip")
        assert report.valid is False
        assert "不存在" in report.error

    def test_pre_check_invalid_zip(self, mock_main_db, tmp_path):
        """Pre-checking a file that is not a ZIP archive is reported as invalid."""
        bogus_archive = tmp_path / "invalid.zip"
        bogus_archive.write_text("not a zip file")
        checker = AstrBotImporter(main_db=mock_main_db)
        report = checker.pre_check(str(bogus_archive))
        assert report.valid is False
        assert "ZIP" in report.error or "无效" in report.error

    def test_pre_check_missing_manifest(self, mock_main_db, tmp_path):
        """Pre-checking a ZIP archive without a manifest is reported as invalid."""
        archive_path = tmp_path / "no_manifest.zip"
        with zipfile.ZipFile(archive_path, "w") as archive:
            archive.writestr("test.txt", "test content")
        checker = AstrBotImporter(main_db=mock_main_db)
        report = checker.pre_check(str(archive_path))
        assert report.valid is False
        assert "manifest" in report.error.lower()

    def test_pre_check_version_match(self, mock_main_db, tmp_path):
        """A backup made by the current version is valid, importable and a "match"."""
        archive_path = tmp_path / "backup.zip"
        self._write_manifest_archive(
            archive_path,
            {
                "version": "1.1",
                "astrbot_version": VERSION,
                "created_at": "2024-01-01T12:00:00",
                "tables": {"platform_stats": 1},
                "has_knowledge_bases": True,
                "has_config": True,
                "directories": ["plugins"],
            },
        )
        checker = AstrBotImporter(main_db=mock_main_db)
        report = checker.pre_check(str(archive_path))
        assert report.valid is True
        assert report.can_import is True
        assert report.version_status == "match"
        assert report.backup_version == VERSION
        # confirm_message is now generated by the frontend; the backend no
        # longer produces it.
        assert report.backup_summary["has_knowledge_bases"] is True

    def test_pre_check_minor_version_diff(self, mock_main_db, tmp_path):
        """A minor-version difference stays importable with status "minor_diff"."""
        # Same major version as the running one, but a different minor part.
        shifted_version = f"{_get_major_version(VERSION)}.999"
        archive_path = tmp_path / "backup.zip"
        self._write_manifest_archive(
            archive_path,
            {
                "version": "1.1",
                "astrbot_version": shifted_version,
                "created_at": "2024-01-01T12:00:00",
                "tables": {},
            },
        )
        checker = AstrBotImporter(main_db=mock_main_db)
        report = checker.pre_check(str(archive_path))
        assert report.valid is True
        assert report.can_import is True
        assert report.version_status == "minor_diff"
        # Version messages are produced by the frontend i18n layer; the backend
        # warnings list no longer contains version-related messages and is kept
        # for other, non-version warnings.

    def test_pre_check_major_version_diff(self, mock_main_db, tmp_path):
        """A major-version difference is a valid file that cannot be imported."""
        archive_path = tmp_path / "backup.zip"
        self._write_manifest_archive(
            archive_path,
            {
                "version": "1.1",
                "astrbot_version": "0.0.1",  # different major version
                "created_at": "2024-01-01T12:00:00",
                "tables": {},
            },
        )
        checker = AstrBotImporter(main_db=mock_main_db)
        report = checker.pre_check(str(archive_path))
        assert report.valid is True  # the file itself is valid
        assert report.can_import is False  # but it cannot be imported
        assert report.version_status == "major_diff"
        # Version messages are produced by the frontend i18n layer; the backend
        # warnings list no longer contains version-related messages.
class TestVersionCompatibility:
    """Tests for the version-compatibility check."""

    def test_check_version_compatibility_match(self, mock_main_db):
        """The running version is a "match" and importable."""
        checker = AstrBotImporter(main_db=mock_main_db)
        verdict = checker._check_version_compatibility(VERSION)
        assert verdict["status"] == "match"
        assert verdict["can_import"] is True

    def test_check_version_compatibility_minor_diff(self, mock_main_db):
        """Same major version with a different minor part is "minor_diff" and importable."""
        shifted_version = f"{_get_major_version(VERSION)}.999"
        checker = AstrBotImporter(main_db=mock_main_db)
        verdict = checker._check_version_compatibility(shifted_version)
        assert verdict["status"] == "minor_diff"
        assert verdict["can_import"] is True

    def test_check_version_compatibility_major_diff(self, mock_main_db):
        """A different major version is "major_diff" and not importable."""
        checker = AstrBotImporter(main_db=mock_main_db)
        verdict = checker._check_version_compatibility("0.0.1")
        assert verdict["status"] == "major_diff"
        assert verdict["can_import"] is False

    def test_check_version_compatibility_empty_version(self, mock_main_db):
        """An empty version string is treated as "major_diff" and not importable."""
        checker = AstrBotImporter(main_db=mock_main_db)
        verdict = checker._check_version_compatibility("")
        assert verdict["status"] == "major_diff"
        assert verdict["can_import"] is False
class TestModelMappings:
    """Tests for the table-to-model mapping configuration."""

    def test_main_db_models_not_empty(self):
        """The main-database mapping has at least one entry."""
        entry_count = len(MAIN_DB_MODELS)
        assert entry_count > 0

    def test_main_db_models_contain_expected_tables(self):
        """The main-database mapping lists the expected tables."""
        required = (
            "platform_stats",
            "conversations",
            "personas",
            "preferences",
            "attachments",
        )
        for table in required:
            assert table in MAIN_DB_MODELS, f"Missing table: {table}"

    def test_kb_metadata_models_not_empty(self):
        """The knowledge-base metadata mapping has at least one entry."""
        entry_count = len(KB_METADATA_MODELS)
        assert entry_count > 0

    def test_kb_metadata_models_contain_expected_tables(self):
        """The knowledge-base metadata mapping lists the expected tables."""
        required = (
            "knowledge_bases",
            "kb_documents",
            "kb_media",
        )
        for table in required:
            assert table in KB_METADATA_MODELS, f"Missing table: {table}"
class TestBackupIntegration:
    """Backup integration tests."""

    @pytest.mark.asyncio
    async def test_export_import_roundtrip(self, tmp_path):
        """Export to a ZIP and inspect its manifest, config and main-database payload."""
        export_root = tmp_path / "backups"
        export_root.mkdir()
        data_root = tmp_path / "data"
        data_root.mkdir()
        config_file = data_root / "cmd_config.json"
        config_file.write_text(json.dumps({"setting": "value"}))
        (data_root / "attachments").mkdir()

        # Build a mocked database whose queries return no rows.
        empty_result = MagicMock()
        empty_result.scalars.return_value.all.return_value = []
        fake_session = AsyncMock()
        fake_session.execute = AsyncMock(return_value=empty_result)
        fake_db = MagicMock()
        fake_db.get_db.return_value = AsyncMock(
            __aenter__=AsyncMock(return_value=fake_session),
            __aexit__=AsyncMock(return_value=None),
        )

        # Export.
        subject = AstrBotExporter(
            main_db=fake_db,
            kb_manager=None,
            config_path=str(config_file),
        )
        archive_path = await subject.export_all(output_dir=str(export_root))
        assert os.path.exists(archive_path)

        # Inspect the archive contents.
        with zipfile.ZipFile(archive_path, "r") as archive:
            # Manifest.
            manifest = json.loads(archive.read("manifest.json"))
            assert manifest["astrbot_version"] == VERSION
            # Verify the marker identifying where the backup came from.
            assert manifest["origin"] == "exported"
            # Configuration.
            exported_config = json.loads(archive.read("config/cmd_config.json"))
            assert exported_config["setting"] == "value"
            # Main database.
            exported_main_db = json.loads(archive.read("databases/main_db.json"))
            assert "platform_stats" in exported_main_db