"""备份功能单元测试""" import json import os import re import zipfile from datetime import datetime from unittest.mock import AsyncMock, MagicMock, patch import pytest from astrbot.core.backup import ( BACKUP_MANIFEST_VERSION, KB_METADATA_MODELS, MAIN_DB_MODELS, ImportPreCheckResult, ) from astrbot.core.backup.exporter import AstrBotExporter from astrbot.core.backup.importer import ( DatabaseClearError, PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT, AstrBotImporter, ImportResult, _get_major_version, ) from astrbot.core.config.default import VERSION from astrbot.core.db.po import ( ConversationV2, ) from astrbot.core.utils.version_comparator import VersionComparator from astrbot.dashboard.routes.backup import ( generate_unique_filename, secure_filename, ) @pytest.fixture def temp_backup_dir(tmp_path): """创建临时备份目录""" backup_dir = tmp_path / "backups" backup_dir.mkdir() return backup_dir @pytest.fixture def temp_data_dir(tmp_path): """创建临时数据目录""" data_dir = tmp_path / "data" data_dir.mkdir() # 创建配置文件 config_path = data_dir / "cmd_config.json" config_path.write_text(json.dumps({"test": "config"})) # 创建附件目录 attachments_dir = data_dir / "attachments" attachments_dir.mkdir() return data_dir @pytest.fixture def mock_main_db(): """创建模拟的主数据库""" db = MagicMock() # 模拟异步上下文管理器 session = AsyncMock() db.get_db = MagicMock( return_value=AsyncMock(__aenter__=AsyncMock(return_value=session)) ) return db @pytest.fixture def mock_kb_manager(): """创建模拟的知识库管理器""" kb_manager = MagicMock() kb_manager.kb_insts = {} # 模拟 kb_db kb_db = MagicMock() session = AsyncMock() kb_db.get_db = MagicMock( return_value=AsyncMock(__aenter__=AsyncMock(return_value=session)) ) kb_manager.kb_db = kb_db return kb_manager class TestImportResult: """ImportResult 类测试""" def test_init(self): """测试初始化""" result = ImportResult() assert result.success is True assert result.imported_tables == {} assert result.imported_files == {} assert result.warnings == [] assert result.errors == [] def test_add_warning(self): """测试添加警告""" result = ImportResult() result.add_warning("test warning") assert "test warning" in result.warnings assert result.success is True # 警告不影响成功状态 def test_add_error(self): """测试添加错误""" result = ImportResult() result.add_error("test error") assert "test error" in result.errors assert result.success is False # 错误会导致失败 def test_to_dict(self): """测试转换为字典""" result = ImportResult() result.imported_tables = {"test_table": 10} result.add_warning("warning") d = result.to_dict() assert d["success"] is True assert d["imported_tables"] == {"test_table": 10} assert "warning" in d["warnings"] class TestAstrBotExporter: """AstrBotExporter 类测试""" def test_init(self, mock_main_db, mock_kb_manager, temp_data_dir): """测试初始化""" exporter = AstrBotExporter( main_db=mock_main_db, kb_manager=mock_kb_manager, config_path=str(temp_data_dir / "cmd_config.json"), ) assert exporter.main_db is mock_main_db assert exporter.kb_manager is mock_kb_manager def test_model_to_dict_with_model_dump(self): """测试 _model_to_dict 使用 model_dump 方法""" exporter = AstrBotExporter(main_db=MagicMock()) # 创建一个有 model_dump 方法的模拟对象 mock_record = MagicMock() mock_record.model_dump.return_value = {"id": 1, "name": "test"} result = exporter._model_to_dict(mock_record) assert result == {"id": 1, "name": "test"} def test_model_to_dict_with_datetime(self): """测试 _model_to_dict 处理 datetime 字段""" exporter = AstrBotExporter(main_db=MagicMock()) now = datetime.now() mock_record = MagicMock() mock_record.model_dump.return_value = {"id": 1, "created_at": now} result = exporter._model_to_dict(mock_record) assert result["created_at"] == now.isoformat() def test_add_checksum(self): """测试添加校验和""" exporter = AstrBotExporter(main_db=MagicMock()) exporter._add_checksum("test.json", '{"test": "data"}') assert "test.json" in exporter._checksums assert exporter._checksums["test.json"].startswith("sha256:") def test_generate_manifest(self, mock_main_db, mock_kb_manager): """测试生成清单""" exporter = AstrBotExporter( main_db=mock_main_db, kb_manager=mock_kb_manager, ) main_data = { "platform_stats": [{"id": 1}], "conversations": [], "attachments": [], } kb_meta_data = { "knowledge_bases": [], "kb_documents": [], } dir_stats = { "plugins": {"files": 10, "size": 1024}, "plugin_data": {"files": 5, "size": 512}, } manifest = exporter._generate_manifest(main_data, kb_meta_data, dir_stats) assert manifest["version"] == BACKUP_MANIFEST_VERSION assert manifest["astrbot_version"] == VERSION assert manifest["origin"] == "exported" # 验证备份来源标记 assert "exported_at" in manifest assert "tables" in manifest assert "statistics" in manifest assert "directories" in manifest assert manifest["statistics"]["main_db"]["platform_stats"] == 1 assert manifest["statistics"]["directories"] == dir_stats @pytest.mark.asyncio async def test_export_all_creates_zip( self, mock_main_db, temp_backup_dir, temp_data_dir ): """测试导出创建 ZIP 文件""" # 设置模拟数据库返回空数据 session = AsyncMock() result = MagicMock() result.scalars.return_value.all.return_value = [] session.execute = AsyncMock(return_value=result) mock_main_db.get_db.return_value = AsyncMock( __aenter__=AsyncMock(return_value=session), __aexit__=AsyncMock(return_value=None), ) exporter = AstrBotExporter( main_db=mock_main_db, kb_manager=None, config_path=str(temp_data_dir / "cmd_config.json"), ) zip_path = await exporter.export_all(output_dir=str(temp_backup_dir)) assert os.path.exists(zip_path) assert zip_path.endswith(".zip") assert "astrbot_backup_" in zip_path # 验证 ZIP 文件内容 with zipfile.ZipFile(zip_path, "r") as zf: namelist = zf.namelist() assert "manifest.json" in namelist assert "databases/main_db.json" in namelist assert "config/cmd_config.json" in namelist class TestAstrBotImporter: """AstrBotImporter 类测试""" def test_init(self, mock_main_db, mock_kb_manager, temp_data_dir): """测试初始化""" importer = AstrBotImporter( main_db=mock_main_db, kb_manager=mock_kb_manager, config_path=str(temp_data_dir / "cmd_config.json"), ) assert importer.main_db is mock_main_db assert importer.kb_manager is mock_kb_manager def test_validate_version_match(self): """测试版本匹配验证""" importer = AstrBotImporter(main_db=MagicMock()) manifest = {"astrbot_version": VERSION} # 不应该抛出异常 importer._validate_version(manifest) def test_validate_version_major_diff_rejected(self): """测试主版本不同被拒绝""" importer = AstrBotImporter(main_db=MagicMock()) # 使用一个明显不同的主版本 manifest = {"astrbot_version": "0.0.1"} with pytest.raises(ValueError, match="主版本不兼容"): importer._validate_version(manifest) def test_validate_version_minor_diff_allowed(self): """测试小版本不同被允许(仅警告)""" importer = AstrBotImporter(main_db=MagicMock()) # 获取当前主版本 major_version = _get_major_version(VERSION) # 构造一个同主版本但小版本不同的版本 minor_diff_version = f"{major_version}.999" manifest = {"astrbot_version": minor_diff_version} # 不应该抛出异常 importer._validate_version(manifest) def test_validate_version_missing(self): """测试缺少版本信息""" importer = AstrBotImporter(main_db=MagicMock()) manifest = {} with pytest.raises(ValueError, match="缺少版本信息"): importer._validate_version(manifest) def test_convert_datetime_fields(self): """测试 datetime 字段转换""" importer = AstrBotImporter(main_db=MagicMock()) # 使用 ConversationV2 作为测试模型(它有 created_at 和 updated_at 字段) row = { "conversation_id": "test-123", "platform_id": "test", "user_id": "user1", "created_at": "2024-01-01T12:00:00", "updated_at": "2024-01-01T12:00:00", } result = importer._convert_datetime_fields(row, ConversationV2) # created_at 应该被转换为 datetime 对象 assert isinstance(result["created_at"], datetime) assert isinstance(result["updated_at"], datetime) def test_merge_platform_stats_rows(self): """测试 platform_stats 重复键会在导入前聚合""" importer = AstrBotImporter(main_db=MagicMock()) rows = [ { "id": 1, "timestamp": "2025-12-13T20:00:00Z", "platform_id": "webchat", "platform_type": "unknown", "count": 14, }, { "id": 80, "timestamp": "2025-12-13T20:00:00+00:00", "platform_id": "webchat", "platform_type": "unknown", "count": 3, }, { "id": 81, "timestamp": "2025-12-13T20:00:00", "platform_id": "webchat", "platform_type": "unknown", "count": 2, }, { "id": 2, "timestamp": "2025-12-13T21:00:00", "platform_id": "aiocqhttp", "platform_type": "unknown", "count": 1, }, ] merged_rows = importer._merge_platform_stats_rows(rows) duplicate_count = len(rows) - len(merged_rows) assert duplicate_count == 2 assert len(merged_rows) == 2 webchat_row = next( ( r for r in merged_rows if r.get("timestamp") == "2025-12-13T20:00:00+00:00" and r.get("platform_id") == "webchat" and r.get("platform_type") == "unknown" ), None, ) assert webchat_row is not None assert webchat_row["timestamp"] == "2025-12-13T20:00:00+00:00" assert webchat_row["platform_id"] == "webchat" assert webchat_row["platform_type"] == "unknown" assert webchat_row["count"] == 19 aiocq_row = next( ( r for r in merged_rows if r.get("platform_id") == "aiocqhttp" and r.get("platform_type") == "unknown" ), None, ) assert aiocq_row is not None assert aiocq_row["timestamp"] == "2025-12-13T21:00:00+00:00" def test_merge_platform_stats_rows_normalizes_naive_timestamp_to_utc(self): """测试 platform_stats 合并前会将 naive timestamp 标准化为 UTC 偏移""" importer = AstrBotImporter(main_db=MagicMock()) rows = [ { "timestamp": "2025-12-13T21:00:00", "platform_id": "webchat", "platform_type": "unknown", "count": 1, }, { "timestamp": datetime(2025, 12, 13, 22, 0, 0), "platform_id": "telegram", "platform_type": "unknown", "count": 1, }, ] merged_rows = importer._merge_platform_stats_rows(rows) assert len(merged_rows) == 2 by_platform = {row["platform_id"]: row for row in merged_rows} assert by_platform["webchat"]["timestamp"] == "2025-12-13T21:00:00+00:00" assert by_platform["telegram"]["timestamp"] == "2025-12-13T22:00:00+00:00" def test_merge_platform_stats_rows_warns_on_invalid_count(self): """测试 platform_stats count 非法时会告警并按 0 处理(含上限)""" importer = AstrBotImporter(main_db=MagicMock()) with patch("astrbot.core.backup.importer.logger.warning") as warning_mock: rows = [ { "timestamp": "2025-12-13T20:00:00+00:00", "platform_id": "webchat", "platform_type": "unknown", "count": 5, }, { "timestamp": "2025-12-13T20:00:00Z", "platform_id": "webchat", "platform_type": "unknown", "count": "bad-count", }, ] merged_rows = importer._merge_platform_stats_rows(rows) duplicate_count = len(rows) - len(merged_rows) assert duplicate_count == 1 assert len(merged_rows) == 1 assert merged_rows[0]["count"] == 5 assert warning_mock.call_count == 1 warning_mock.reset_mock() rows_existing_invalid = [ { "timestamp": "2025-12-13T21:00:00+00:00", "platform_id": "webchat", "platform_type": "unknown", "count": "bad-count", }, { "timestamp": "2025-12-13T21:00:00Z", "platform_id": "webchat", "platform_type": "unknown", "count": 7, }, ] merged_rows = importer._merge_platform_stats_rows(rows_existing_invalid) duplicate_count = len(rows_existing_invalid) - len(merged_rows) assert duplicate_count == 1 assert len(merged_rows) == 1 assert merged_rows[0]["count"] == 7 assert warning_mock.call_count == 1 warning_mock.reset_mock() many_invalid_rows = [ { "timestamp": "2025-12-13T22:00:00+00:00", "platform_id": "webchat", "platform_type": "unknown", "count": 1, }, *[ { "timestamp": "2025-12-13T22:00:00Z", "platform_id": "webchat", "platform_type": "unknown", "count": "bad-count", } for _ in range(PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT + 5) ], ] importer._merge_platform_stats_rows(many_invalid_rows) assert ( warning_mock.call_count == PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT + 1 ) assert any( "告警已达到上限" in str(call.args[0]) for call in warning_mock.call_args_list ) warning_mock.reset_mock() single_invalid_row = [ { "timestamp": "2025-12-13T23:00:00+00:00", "platform_id": "telegram", "platform_type": "unknown", "count": "still-bad", }, ] merged_rows = importer._merge_platform_stats_rows(single_invalid_row) duplicate_count = len(single_invalid_row) - len(merged_rows) assert duplicate_count == 0 assert len(merged_rows) == 1 assert merged_rows[0]["count"] == 0 assert warning_mock.call_count == 1 def test_merge_platform_stats_rows_keeps_invalid_timestamps_distinct(self): """测试空/非法 timestamp 不参与聚合,避免误合并""" importer = AstrBotImporter(main_db=MagicMock()) rows = [ { "timestamp": "", "platform_id": "webchat", "platform_type": "unknown", "count": 2, }, { "timestamp": "not-a-datetime", "platform_id": "webchat", "platform_type": "unknown", "count": 3, }, { "timestamp": "not-a-datetime", "platform_id": "webchat", "platform_type": "unknown", "count": 4, }, ] merged_rows = importer._merge_platform_stats_rows(rows) duplicate_count = len(rows) - len(merged_rows) assert duplicate_count == 0 assert len(merged_rows) == 3 assert [row["count"] for row in merged_rows] == [2, 3, 4] def test_merge_platform_stats_rows_keeps_non_string_platform_keys_distinct(self): """测试非字符串 platform_id/platform_type 不参与聚合""" importer = AstrBotImporter(main_db=MagicMock()) rows = [ { "timestamp": "2025-12-13T20:00:00+00:00", "platform_id": None, "platform_type": "unknown", "count": 2, }, { "timestamp": "2025-12-13T20:00:00Z", "platform_id": None, "platform_type": "unknown", "count": 3, }, { "timestamp": "2025-12-13T20:00:00+00:00", "platform_id": "webchat", "platform_type": 1, "count": 4, }, { "timestamp": "2025-12-13T20:00:00Z", "platform_id": "webchat", "platform_type": 1, "count": 5, }, ] merged_rows = importer._merge_platform_stats_rows(rows) duplicate_count = len(rows) - len(merged_rows) assert duplicate_count == 0 assert len(merged_rows) == 4 def test_merge_platform_stats_rows_preserves_input_order(self): """测试 platform_stats 聚合后仍保持输入顺序(按首次出现位置)""" importer = AstrBotImporter(main_db=MagicMock()) rows = [ { "id": 1, "timestamp": "2025-12-13T20:00:00Z", "platform_id": "webchat", "platform_type": "unknown", "count": 2, }, { "id": 2, "timestamp": "", "platform_id": "webchat", "platform_type": "unknown", "count": 3, }, { "id": 3, "timestamp": "2025-12-13T20:00:00+00:00", "platform_id": "webchat", "platform_type": "unknown", "count": 5, }, { "id": 4, "timestamp": "2025-12-13T21:00:00+00:00", "platform_id": "telegram", "platform_type": "unknown", "count": 7, }, ] merged_rows = importer._merge_platform_stats_rows(rows) assert len(merged_rows) == 3 assert [row["id"] for row in merged_rows] == [1, 2, 4] assert merged_rows[0]["count"] == 7 @pytest.mark.asyncio async def test_import_file_not_exists(self, mock_main_db, tmp_path): """测试导入不存在的文件""" importer = AstrBotImporter(main_db=mock_main_db) result = await importer.import_all(str(tmp_path / "nonexistent.zip")) assert result.success is False assert any("不存在" in err for err in result.errors) @pytest.mark.asyncio async def test_import_invalid_zip(self, mock_main_db, tmp_path): """测试导入无效的 ZIP 文件""" # 创建一个无效的文件 invalid_zip = tmp_path / "invalid.zip" invalid_zip.write_text("not a zip file") importer = AstrBotImporter(main_db=mock_main_db) result = await importer.import_all(str(invalid_zip)) assert result.success is False assert any("无效" in err or "ZIP" in err for err in result.errors) @pytest.mark.asyncio async def test_import_missing_manifest(self, mock_main_db, tmp_path): """测试导入缺少 manifest 的 ZIP 文件""" # 创建一个没有 manifest 的 ZIP 文件 zip_path = tmp_path / "no_manifest.zip" with zipfile.ZipFile(zip_path, "w") as zf: zf.writestr("test.txt", "test content") importer = AstrBotImporter(main_db=mock_main_db) result = await importer.import_all(str(zip_path)) assert result.success is False assert any("manifest" in err.lower() for err in result.errors) @pytest.mark.asyncio async def test_import_major_version_mismatch(self, mock_main_db, tmp_path): """测试导入主版本不匹配的备份""" # 创建一个主版本不匹配的备份 zip_path = tmp_path / "old_version.zip" manifest = { "version": "1.0", "astrbot_version": "0.0.1", # 主版本不同 "tables": {"main_db": []}, } with zipfile.ZipFile(zip_path, "w") as zf: zf.writestr("manifest.json", json.dumps(manifest)) importer = AstrBotImporter(main_db=mock_main_db) result = await importer.import_all(str(zip_path)) assert result.success is False assert any("主版本不兼容" in err for err in result.errors) @pytest.mark.asyncio async def test_import_replace_fails_when_clear_main_db_fails( self, mock_main_db, tmp_path ): """测试 replace 模式下主库清空失败会直接终止导入""" zip_path = tmp_path / "valid_backup.zip" manifest = { "version": "1.1", "astrbot_version": VERSION, "tables": {"platform_stats": 0}, } main_data = {"platform_stats": []} with zipfile.ZipFile(zip_path, "w") as zf: zf.writestr("manifest.json", json.dumps(manifest)) zf.writestr("databases/main_db.json", json.dumps(main_data)) importer = AstrBotImporter(main_db=mock_main_db) importer._clear_main_db = AsyncMock( side_effect=DatabaseClearError("清空表 platform_stats 失败: db locked") ) importer._import_main_database = AsyncMock(return_value={}) result = await importer.import_all(str(zip_path), mode="replace") assert result.success is False assert any("清空主数据库失败" in err for err in result.errors) assert any("清空表 platform_stats 失败" in err for err in result.errors) importer._import_main_database.assert_not_awaited() class TestSecureFilename: """安全文件名函数测试""" def test_secure_filename_normal(self): """测试正常文件名""" assert secure_filename("backup.zip") == "backup.zip" assert secure_filename("my_backup_2024.zip") == "my_backup_2024.zip" def test_secure_filename_path_traversal(self): """测试路径遍历攻击""" assert ".." not in secure_filename("../../../etc/passwd") assert "/" not in secure_filename("/etc/passwd") assert "\\" not in secure_filename("..\\..\\windows\\system32") def test_secure_filename_with_path(self): """测试带路径的文件名""" result = secure_filename("/path/to/backup.zip") assert result == "backup.zip" result = secure_filename("C:\\Users\\test\\backup.zip") assert result == "backup.zip" def test_secure_filename_special_chars(self): """测试特殊字符""" result = secure_filename('backup<>:"|?*.zip') # 特殊字符应被替换为下划线 assert "<" not in result assert ">" not in result assert ":" not in result assert '"' not in result assert "|" not in result assert "?" not in result assert "*" not in result def test_secure_filename_hidden_file(self): """测试隐藏文件(前导点)""" result = secure_filename(".hidden_backup.zip") assert not result.startswith(".") def test_secure_filename_empty(self): """测试空文件名""" assert secure_filename("") == "backup" assert secure_filename("...") == "backup" def test_generate_unique_filename(self): """测试生成唯一文件名""" result = generate_unique_filename("backup.zip") # 应包含原文件名和时间戳后缀 assert result.startswith("backup_") assert result.endswith(".zip") # 应包含时间戳格式 YYYYMMDD_HHMMSS assert re.search(r"backup_\d{8}_\d{6}\.zip", result) def test_generate_unique_filename_with_complex_name(self): """测试复杂文件名生成唯一文件名""" result = generate_unique_filename("my_backup_file.zip") # 应在原文件名后添加时间戳 assert result.startswith("my_backup_file_") assert result.endswith(".zip") assert re.search(r"my_backup_file_\d{8}_\d{6}\.zip", result) class TestVersionComparison: """版本比较函数测试 - 使用 VersionComparator""" def test_get_major_version_simple(self): """测试提取简单主版本号""" assert _get_major_version("1.0") == "1.0" assert _get_major_version("2.1") == "2.1" assert _get_major_version("4.9.1") == "4.9" def test_get_major_version_with_prefix(self): """测试带 v 前缀的版本号""" assert _get_major_version("v1.0") == "1.0" assert _get_major_version("V4.9.1") == "4.9" def test_get_major_version_with_prerelease(self): """测试带预发布标签的版本号""" assert _get_major_version("4.9.1-beta") == "4.9" assert _get_major_version("4.9.1-alpha.1") == "4.9" assert _get_major_version("4.9.1+build123") == "4.9" def test_get_major_version_single_part(self): """测试单部分版本号""" assert _get_major_version("1") == "1.0" def test_get_major_version_empty(self): """测试空版本号""" assert _get_major_version("") == "0.0" def test_compare_versions_equal(self): """测试版本相等""" assert VersionComparator.compare_version("1.0", "1.0") == 0 assert VersionComparator.compare_version("1.0.0", "1.0") == 0 assert VersionComparator.compare_version("2.10", "2.10") == 0 def test_compare_versions_less_than(self): """测试版本小于""" assert VersionComparator.compare_version("1.0", "1.1") == -1 assert ( VersionComparator.compare_version("1.9", "1.10") == -1 ) # 关键测试:多位数版本比较 assert VersionComparator.compare_version("1.2", "1.10") == -1 assert VersionComparator.compare_version("1.0", "2.0") == -1 def test_compare_versions_greater_than(self): """测试版本大于""" assert VersionComparator.compare_version("1.1", "1.0") == 1 assert ( VersionComparator.compare_version("1.10", "1.9") == 1 ) # 关键测试:多位数版本比较 assert VersionComparator.compare_version("1.10", "1.2") == 1 assert VersionComparator.compare_version("2.0", "1.0") == 1 def test_compare_versions_different_lengths(self): """测试不同长度版本比较""" assert VersionComparator.compare_version("1.0", "1.0.0") == 0 assert VersionComparator.compare_version("1.0", "1.0.1") == -1 assert VersionComparator.compare_version("1.0.1", "1.0") == 1 def test_compare_versions_prerelease(self): """测试预发布版本比较""" # 预发布版本低于正式版本 assert VersionComparator.compare_version("1.0.0-alpha", "1.0.0") == -1 assert VersionComparator.compare_version("1.0.0", "1.0.0-beta") == 1 # alpha < beta assert VersionComparator.compare_version("1.0.0-alpha", "1.0.0-beta") == -1 class TestImportPreCheckResult: """ImportPreCheckResult 类测试""" def test_init_default_values(self): """测试默认值初始化""" result = ImportPreCheckResult() assert result.valid is False assert result.can_import is False assert result.version_status == "" assert result.backup_version == "" assert result.current_version == VERSION assert result.confirm_message == "" assert result.warnings == [] assert result.error == "" assert result.backup_summary == {} def test_to_dict(self): """测试转换为字典""" result = ImportPreCheckResult( valid=True, can_import=True, version_status="match", backup_version="4.9.0", confirm_message="确认导入?", warnings=["警告1"], backup_summary={"tables": ["table1"]}, ) d = result.to_dict() assert d["valid"] is True assert d["can_import"] is True assert d["version_status"] == "match" assert d["backup_version"] == "4.9.0" assert d["confirm_message"] == "确认导入?" assert "警告1" in d["warnings"] assert d["backup_summary"]["tables"] == ["table1"] class TestPreCheck: """预检查功能测试""" def test_pre_check_file_not_exists(self, mock_main_db): """测试预检查不存在的文件""" importer = AstrBotImporter(main_db=mock_main_db) result = importer.pre_check("/nonexistent/file.zip") assert result.valid is False assert "不存在" in result.error def test_pre_check_invalid_zip(self, mock_main_db, tmp_path): """测试预检查无效的 ZIP 文件""" invalid_zip = tmp_path / "invalid.zip" invalid_zip.write_text("not a zip file") importer = AstrBotImporter(main_db=mock_main_db) result = importer.pre_check(str(invalid_zip)) assert result.valid is False assert "ZIP" in result.error or "无效" in result.error def test_pre_check_missing_manifest(self, mock_main_db, tmp_path): """测试预检查缺少 manifest 的 ZIP 文件""" zip_path = tmp_path / "no_manifest.zip" with zipfile.ZipFile(zip_path, "w") as zf: zf.writestr("test.txt", "test content") importer = AstrBotImporter(main_db=mock_main_db) result = importer.pre_check(str(zip_path)) assert result.valid is False assert "manifest" in result.error.lower() def test_pre_check_version_match(self, mock_main_db, tmp_path): """测试预检查版本匹配""" zip_path = tmp_path / "backup.zip" manifest = { "version": "1.1", "astrbot_version": VERSION, "created_at": "2024-01-01T12:00:00", "tables": {"platform_stats": 1}, "has_knowledge_bases": True, "has_config": True, "directories": ["plugins"], } with zipfile.ZipFile(zip_path, "w") as zf: zf.writestr("manifest.json", json.dumps(manifest)) importer = AstrBotImporter(main_db=mock_main_db) result = importer.pre_check(str(zip_path)) assert result.valid is True assert result.can_import is True assert result.version_status == "match" assert result.backup_version == VERSION # confirm_message 现在由前端生成,后端不再生成 assert result.backup_summary["has_knowledge_bases"] is True def test_pre_check_minor_version_diff(self, mock_main_db, tmp_path): """测试预检查小版本差异""" # 构造一个同主版本但小版本不同的版本 major_version = _get_major_version(VERSION) minor_diff_version = f"{major_version}.999" zip_path = tmp_path / "backup.zip" manifest = { "version": "1.1", "astrbot_version": minor_diff_version, "created_at": "2024-01-01T12:00:00", "tables": {}, } with zipfile.ZipFile(zip_path, "w") as zf: zf.writestr("manifest.json", json.dumps(manifest)) importer = AstrBotImporter(main_db=mock_main_db) result = importer.pre_check(str(zip_path)) assert result.valid is True assert result.can_import is True assert result.version_status == "minor_diff" # 版本消息由前端 i18n 生成,后端 warnings 列表不再包含版本相关消息 # warnings 列表保留用于其他非版本相关的警告 def test_pre_check_major_version_diff(self, mock_main_db, tmp_path): """测试预检查主版本差异""" zip_path = tmp_path / "backup.zip" manifest = { "version": "1.1", "astrbot_version": "0.0.1", # 主版本不同 "created_at": "2024-01-01T12:00:00", "tables": {}, } with zipfile.ZipFile(zip_path, "w") as zf: zf.writestr("manifest.json", json.dumps(manifest)) importer = AstrBotImporter(main_db=mock_main_db) result = importer.pre_check(str(zip_path)) assert result.valid is True # 文件有效 assert result.can_import is False # 但不能导入 assert result.version_status == "major_diff" # 版本消息由前端 i18n 生成,后端 warnings 列表不再包含版本相关消息 class TestVersionCompatibility: """版本兼容性检查测试""" def test_check_version_compatibility_match(self, mock_main_db): """测试版本完全匹配""" importer = AstrBotImporter(main_db=mock_main_db) result = importer._check_version_compatibility(VERSION) assert result["status"] == "match" assert result["can_import"] is True def test_check_version_compatibility_minor_diff(self, mock_main_db): """测试小版本差异""" major_version = _get_major_version(VERSION) minor_diff_version = f"{major_version}.999" importer = AstrBotImporter(main_db=mock_main_db) result = importer._check_version_compatibility(minor_diff_version) assert result["status"] == "minor_diff" assert result["can_import"] is True def test_check_version_compatibility_major_diff(self, mock_main_db): """测试主版本差异""" importer = AstrBotImporter(main_db=mock_main_db) result = importer._check_version_compatibility("0.0.1") assert result["status"] == "major_diff" assert result["can_import"] is False def test_check_version_compatibility_empty_version(self, mock_main_db): """测试空版本号""" importer = AstrBotImporter(main_db=mock_main_db) result = importer._check_version_compatibility("") assert result["status"] == "major_diff" assert result["can_import"] is False class TestModelMappings: """测试模型映射配置""" def test_main_db_models_not_empty(self): """测试主数据库模型映射非空""" assert len(MAIN_DB_MODELS) > 0 def test_main_db_models_contain_expected_tables(self): """测试主数据库模型映射包含预期的表""" expected_tables = [ "platform_stats", "conversations", "personas", "preferences", "attachments", ] for table in expected_tables: assert table in MAIN_DB_MODELS, f"Missing table: {table}" def test_kb_metadata_models_not_empty(self): """测试知识库元数据模型映射非空""" assert len(KB_METADATA_MODELS) > 0 def test_kb_metadata_models_contain_expected_tables(self): """测试知识库元数据模型映射包含预期的表""" expected_tables = [ "knowledge_bases", "kb_documents", "kb_media", ] for table in expected_tables: assert table in KB_METADATA_MODELS, f"Missing table: {table}" class TestBackupIntegration: """备份集成测试""" @pytest.mark.asyncio async def test_export_import_roundtrip(self, tmp_path): """测试导出-导入往返""" backup_dir = tmp_path / "backups" backup_dir.mkdir() data_dir = tmp_path / "data" data_dir.mkdir() config_path = data_dir / "cmd_config.json" config_path.write_text(json.dumps({"setting": "value"})) attachments_dir = data_dir / "attachments" attachments_dir.mkdir() # 创建模拟数据库 mock_db = MagicMock() session = AsyncMock() result = MagicMock() result.scalars.return_value.all.return_value = [] session.execute = AsyncMock(return_value=result) mock_db.get_db.return_value = AsyncMock( __aenter__=AsyncMock(return_value=session), __aexit__=AsyncMock(return_value=None), ) # 导出 exporter = AstrBotExporter( main_db=mock_db, kb_manager=None, config_path=str(config_path), ) zip_path = await exporter.export_all(output_dir=str(backup_dir)) assert os.path.exists(zip_path) # 验证 ZIP 内容 with zipfile.ZipFile(zip_path, "r") as zf: # 读取 manifest manifest = json.loads(zf.read("manifest.json")) assert manifest["astrbot_version"] == VERSION assert manifest["origin"] == "exported" # 验证备份来源标记 # 读取配置 config = json.loads(zf.read("config/cmd_config.json")) assert config["setting"] == "value" # 读取主数据库 main_db = json.loads(zf.read("databases/main_db.json")) assert "platform_stats" in main_db