| """AstrBot 数据导入器 |
| |
| 负责从 ZIP 备份文件恢复所有数据。 |
| 导入时进行版本校验: |
| - 主版本(前两位)不同时直接拒绝导入 |
| - 小版本(第三位)不同时提示警告,用户可选择强制导入 |
| - 版本匹配时也需要用户确认 |
| """ |
|
|
| import json |
| import os |
| import shutil |
| import zipfile |
| from dataclasses import dataclass, field |
| from datetime import datetime, timezone |
| from pathlib import Path |
| from typing import TYPE_CHECKING, Any |
|
|
| from sqlalchemy import delete |
|
|
| from astrbot.core import logger |
| from astrbot.core.config.default import VERSION |
| from astrbot.core.db import BaseDatabase |
| from astrbot.core.utils.astrbot_path import ( |
| get_astrbot_data_path, |
| get_astrbot_knowledge_base_path, |
| ) |
| from astrbot.core.utils.version_comparator import VersionComparator |
|
|
| |
| from .constants import ( |
| KB_METADATA_MODELS, |
| MAIN_DB_MODELS, |
| get_backup_directories, |
| ) |
|
|
| if TYPE_CHECKING: |
| from astrbot.core.knowledge_base.kb_mgr import KnowledgeBaseManager |
|
|
|
|
| def _get_major_version(version_str: str) -> str: |
| """提取版本的主版本部分(前两位) |
| |
| Args: |
| version_str: 版本字符串,如 "4.9.1", "4.10.0-beta" |
| |
| Returns: |
| 主版本字符串,如 "4.9", "4.10" |
| """ |
| if not version_str: |
| return "0.0" |
| |
| version = version_str.lower().replace("v", "").split("-")[0].split("+")[0] |
| parts = [p for p in version.split(".") if p] |
| if len(parts) >= 2: |
| return f"{parts[0]}.{parts[1]}" |
| elif len(parts) == 1 and parts[0]: |
| return f"{parts[0]}.0" |
| return "0.0" |
|
|
|
|
# Path of the main AstrBot command configuration file (data/cmd_config.json).
CMD_CONFIG_FILE_PATH = os.path.join(get_astrbot_data_path(), "cmd_config.json")
# Root directory where knowledge-base instances keep their on-disk files.
KB_PATH = get_astrbot_knowledge_base_path()
# Default cap on "invalid count" warnings emitted while merging platform_stats rows.
DEFAULT_PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT = 5
# Environment variable that overrides the warning cap above.
PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT_ENV = (
    "ASTRBOT_PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT"
)
|
|
|
|
def _load_platform_stats_invalid_count_warn_limit() -> int:
    """Read the invalid-count warn limit from the environment.

    Returns:
        The non-negative integer from the env var, or the default when the
        variable is unset or holds anything other than a non-negative integer.
    """
    raw_value = os.getenv(PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT_ENV)
    if raw_value is None:
        return DEFAULT_PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT

    try:
        parsed = int(raw_value)
        if parsed < 0:
            # Negative limits make no sense; treat like any other bad value.
            raise ValueError("negative")
    except (TypeError, ValueError):
        logger.warning(
            "Invalid env %s=%r, fallback to default %d",
            PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT_ENV,
            raw_value,
            DEFAULT_PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT,
        )
        return DEFAULT_PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT
    return parsed
|
|
|
|
# Resolved once at import time; see _load_platform_stats_invalid_count_warn_limit.
PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT = (
    _load_platform_stats_invalid_count_warn_limit()
)
|
|
|
|
| class _InvalidCountWarnLimiter: |
| """Rate-limit warnings for invalid platform_stats count values.""" |
|
|
| def __init__(self, limit: int) -> None: |
| self.limit = limit |
| self._count = 0 |
| self._suppression_logged = False |
|
|
| def warn_invalid_count(self, value: Any, key_for_log: tuple[Any, ...]) -> None: |
| if self.limit > 0: |
| if self._count < self.limit: |
| logger.warning( |
| "platform_stats count 非法,已按 0 处理: value=%r, key=%s", |
| value, |
| key_for_log, |
| ) |
| self._count += 1 |
| if self._count == self.limit and not self._suppression_logged: |
| logger.warning( |
| "platform_stats 非法 count 告警已达到上限 (%d),后续将抑制", |
| self.limit, |
| ) |
| self._suppression_logged = True |
| return |
|
|
| if not self._suppression_logged: |
| |
| logger.warning( |
| "platform_stats 非法 count 告警已达到上限 (%d),后续将抑制", |
| self.limit, |
| ) |
| self._suppression_logged = True |
|
|
|
|
@dataclass
class ImportPreCheckResult:
    """Result of the pre-import inspection of a backup archive.

    Carries version-compatibility findings so the frontend can show a
    confirmation dialog before the actual import runs.
    """

    # Whether the archive is a structurally valid AstrBot backup.
    valid: bool = False
    # Whether importing is allowed (False on major-version mismatch).
    can_import: bool = False
    # One of "match", "minor_diff", "major_diff" (see _check_version_compatibility).
    version_status: str = ""
    # Version recorded in the backup manifest.
    backup_version: str = ""
    # Version of the running AstrBot instance.
    current_version: str = VERSION
    # Export timestamp recorded in the manifest.
    backup_time: str = ""
    # Human-readable message for the confirmation dialog.
    confirm_message: str = ""
    # Non-fatal findings to surface to the user.
    warnings: list[str] = field(default_factory=list)
    # Fatal problem description; empty when the check passed.
    error: str = ""
    # Lightweight summary of what the backup contains.
    backup_summary: dict = field(default_factory=dict)

    def to_dict(self) -> dict:
        """Serialize to a plain dict for the API layer."""
        keys = (
            "valid",
            "can_import",
            "version_status",
            "backup_version",
            "current_version",
            "backup_time",
            "confirm_message",
            "warnings",
            "error",
            "backup_summary",
        )
        return {key: getattr(self, key) for key in keys}
|
|
|
|
class ImportResult:
    """Mutable accumulator for the outcome of an import run."""

    def __init__(self) -> None:
        # Overall success flag; flipped to False by add_error().
        self.success = True
        # table name -> rows imported
        self.imported_tables: dict[str, int] = {}
        # file category -> files imported
        self.imported_files: dict[str, int] = {}
        # directory name -> files restored
        self.imported_directories: dict[str, int] = {}
        self.warnings: list[str] = []
        self.errors: list[str] = []

    def add_warning(self, msg: str) -> None:
        """Record a non-fatal problem and log it."""
        self.warnings.append(msg)
        logger.warning(msg)

    def add_error(self, msg: str) -> None:
        """Record a fatal problem, mark the run as failed, and log it."""
        self.errors.append(msg)
        self.success = False
        logger.error(msg)

    def to_dict(self) -> dict:
        """Serialize to a plain dict for the API layer."""
        keys = (
            "success",
            "imported_tables",
            "imported_files",
            "imported_directories",
            "warnings",
            "errors",
        )
        return {key: getattr(self, key) for key in keys}
|
|
|
|
class DatabaseClearError(RuntimeError):
    """Raised when wiping the main-database tables fails during a replace-mode import."""
|
|
|
|
class AstrBotImporter:
    """AstrBot data importer.

    Restores everything contained in a backup archive:
    - all main-database tables
    - knowledge-base metadata and documents
    - the configuration file
    - attachment files
    - knowledge-base media files
    - plugin directory (data/plugins)
    - plugin data directory (data/plugin_data)
    - config directory (data/config)
    - T2I template directory (data/t2i_templates)
    - WebChat data directory (data/webchat)
    - temp file directory (data/temp)
    """
|
|
    def __init__(
        self,
        main_db: BaseDatabase,
        kb_manager: "KnowledgeBaseManager | None" = None,
        config_path: str = CMD_CONFIG_FILE_PATH,
        kb_root_dir: str = KB_PATH,
    ) -> None:
        """Initialize the importer.

        Args:
            main_db: Main database accessor used for all table imports.
            kb_manager: Optional knowledge-base manager; KB import is skipped
                when None.
            config_path: Path of the cmd_config.json file to overwrite.
            kb_root_dir: Root directory holding knowledge-base instances.
        """
        self.main_db = main_db
        self.kb_manager = kb_manager
        self.config_path = config_path
        self.kb_root_dir = kb_root_dir
|
|
| def pre_check(self, zip_path: str) -> ImportPreCheckResult: |
| """预检查备份文件 |
| |
| 在实际导入前检查备份文件的有效性和版本兼容性。 |
| 返回检查结果供前端显示确认对话框。 |
| |
| Args: |
| zip_path: ZIP 备份文件路径 |
| |
| Returns: |
| ImportPreCheckResult: 预检查结果 |
| """ |
| result = ImportPreCheckResult() |
| result.current_version = VERSION |
|
|
| if not os.path.exists(zip_path): |
| result.error = f"备份文件不存在: {zip_path}" |
| return result |
|
|
| try: |
| with zipfile.ZipFile(zip_path, "r") as zf: |
| |
| try: |
| manifest_data = zf.read("manifest.json") |
| manifest = json.loads(manifest_data) |
| except KeyError: |
| result.error = "备份文件缺少 manifest.json,不是有效的 AstrBot 备份" |
| return result |
| except json.JSONDecodeError as e: |
| result.error = f"manifest.json 格式错误: {e}" |
| return result |
|
|
| |
| result.backup_version = manifest.get("astrbot_version", "未知") |
| result.backup_time = manifest.get("exported_at", "未知") |
| result.valid = True |
|
|
| |
| result.backup_summary = { |
| "tables": list(manifest.get("tables", {}).keys()), |
| "has_knowledge_bases": manifest.get("has_knowledge_bases", False), |
| "has_config": manifest.get("has_config", False), |
| "directories": manifest.get("directories", []), |
| } |
|
|
| |
| version_check = self._check_version_compatibility(result.backup_version) |
| result.version_status = version_check["status"] |
| result.can_import = version_check["can_import"] |
|
|
| |
| |
| |
|
|
| return result |
|
|
| except zipfile.BadZipFile: |
| result.error = "无效的 ZIP 文件" |
| return result |
| except Exception as e: |
| result.error = f"检查备份文件失败: {e}" |
| return result |
|
|
| def _check_version_compatibility(self, backup_version: str) -> dict: |
| """检查版本兼容性 |
| |
| 规则: |
| - 主版本(前两位,如 4.9)必须一致,否则拒绝 |
| - 小版本(第三位,如 4.9.1 vs 4.9.2)不同时,警告但允许导入 |
| |
| Returns: |
| dict: {status, can_import, message} |
| """ |
| if not backup_version: |
| return { |
| "status": "major_diff", |
| "can_import": False, |
| "message": "备份文件缺少版本信息", |
| } |
|
|
| |
| backup_major = _get_major_version(backup_version) |
| current_major = _get_major_version(VERSION) |
|
|
| |
| if VersionComparator.compare_version(backup_major, current_major) != 0: |
| return { |
| "status": "major_diff", |
| "can_import": False, |
| "message": ( |
| f"主版本不兼容: 备份版本 {backup_version}, 当前版本 {VERSION}。" |
| f"跨主版本导入可能导致数据损坏,请使用相同主版本的 AstrBot。" |
| ), |
| } |
|
|
| |
| version_cmp = VersionComparator.compare_version(backup_version, VERSION) |
| if version_cmp != 0: |
| return { |
| "status": "minor_diff", |
| "can_import": True, |
| "message": ( |
| f"小版本差异: 备份版本 {backup_version}, 当前版本 {VERSION}。" |
| ), |
| } |
|
|
| return { |
| "status": "match", |
| "can_import": True, |
| "message": "版本匹配", |
| } |
|
|
    async def import_all(
        self,
        zip_path: str,
        mode: str = "replace",
        progress_callback: Any | None = None,
    ) -> ImportResult:
        """Import all data from a ZIP backup file.

        Stages, in order: validate manifest/version -> main database ->
        knowledge bases -> config file -> attachments -> plugin/data
        directories. Main-database failures abort the run; later stages
        only record warnings.

        Args:
            zip_path: Path to the ZIP backup file.
            mode: Import mode; currently only "replace" (clear, then import)
                is supported.
            progress_callback: Optional async callable receiving
                (stage, current, total, message).

        Returns:
            ImportResult: Aggregated outcome with per-table/per-file counts,
            warnings and errors.
        """
        result = ImportResult()

        if not os.path.exists(zip_path):
            result.add_error(f"备份文件不存在: {zip_path}")
            return result

        logger.info(f"开始从 {zip_path} 导入备份")

        try:
            with zipfile.ZipFile(zip_path, "r") as zf:
                # --- Stage 1: validate manifest and version compatibility ---
                if progress_callback:
                    await progress_callback("validate", 0, 100, "正在验证备份文件...")

                try:
                    manifest_data = zf.read("manifest.json")
                    manifest = json.loads(manifest_data)
                except KeyError:
                    result.add_error("备份文件缺少 manifest.json")
                    return result
                except json.JSONDecodeError as e:
                    result.add_error(f"manifest.json 格式错误: {e}")
                    return result

                # Second line of defense; pre_check() should already have run.
                try:
                    self._validate_version(manifest)
                except ValueError as e:
                    result.add_error(str(e))
                    return result

                if progress_callback:
                    await progress_callback("validate", 100, 100, "验证完成")

                # --- Stage 2: main database (fatal on failure) ---
                if progress_callback:
                    await progress_callback("main_db", 0, 100, "正在导入主数据库...")

                try:
                    main_data_content = zf.read("databases/main_db.json")
                    main_data = json.loads(main_data_content)

                    if mode == "replace":
                        await self._clear_main_db()

                    imported = await self._import_main_database(main_data)
                    result.imported_tables.update(imported)
                except DatabaseClearError as e:
                    result.add_error(f"清空主数据库失败: {e}")
                    return result
                except Exception as e:
                    result.add_error(f"导入主数据库失败: {e}")
                    return result

                if progress_callback:
                    await progress_callback("main_db", 100, 100, "主数据库导入完成")

                # --- Stage 3: knowledge bases (non-fatal; warnings only) ---
                if self.kb_manager and "databases/kb_metadata.json" in zf.namelist():
                    if progress_callback:
                        await progress_callback("kb", 0, 100, "正在导入知识库...")

                    try:
                        kb_meta_content = zf.read("databases/kb_metadata.json")
                        kb_meta_data = json.loads(kb_meta_content)

                        if mode == "replace":
                            await self._clear_kb_data()

                        await self._import_knowledge_bases(zf, kb_meta_data, result)
                    except Exception as e:
                        result.add_warning(f"导入知识库失败: {e}")

                    if progress_callback:
                        await progress_callback("kb", 100, 100, "知识库导入完成")

                # --- Stage 4: config file (existing config saved as .bak) ---
                if progress_callback:
                    await progress_callback("config", 0, 100, "正在导入配置文件...")

                if "config/cmd_config.json" in zf.namelist():
                    try:
                        config_content = zf.read("config/cmd_config.json")
                        # Keep a .bak copy of the current config before overwriting.
                        if os.path.exists(self.config_path):
                            backup_path = f"{self.config_path}.bak"
                            shutil.copy2(self.config_path, backup_path)

                        with open(self.config_path, "wb") as f:
                            f.write(config_content)
                        result.imported_files["config"] = 1
                    except Exception as e:
                        result.add_warning(f"导入配置文件失败: {e}")

                if progress_callback:
                    await progress_callback("config", 100, 100, "配置文件导入完成")

                # --- Stage 5: attachment files ---
                if progress_callback:
                    await progress_callback("attachments", 0, 100, "正在导入附件...")

                attachment_count = await self._import_attachments(
                    zf, main_data.get("attachments", [])
                )
                result.imported_files["attachments"] = attachment_count

                if progress_callback:
                    await progress_callback("attachments", 100, 100, "附件导入完成")

                # --- Stage 6: plugin and data directories ---
                if progress_callback:
                    await progress_callback(
                        "directories", 0, 100, "正在导入插件和数据目录..."
                    )

                dir_stats = await self._import_directories(zf, manifest, result)
                result.imported_directories = dir_stats

                if progress_callback:
                    await progress_callback("directories", 100, 100, "目录导入完成")

                logger.info(f"备份导入完成: {result.to_dict()}")
                return result

        except zipfile.BadZipFile:
            result.add_error("无效的 ZIP 文件")
            return result
        except Exception as e:
            result.add_error(f"导入失败: {e}")
            return result
|
|
| def _validate_version(self, manifest: dict) -> None: |
| """验证版本兼容性 - 仅允许相同主版本导入 |
| |
| 注意:此方法仅在 import_all 中调用,用于双重校验。 |
| 前端应先调用 pre_check 获取详细的版本信息并让用户确认。 |
| """ |
| backup_version = manifest.get("astrbot_version") |
| if not backup_version: |
| raise ValueError("备份文件缺少版本信息") |
|
|
| |
| version_check = self._check_version_compatibility(backup_version) |
|
|
| if version_check["status"] == "major_diff": |
| raise ValueError(version_check["message"]) |
|
|
| |
| if version_check["status"] == "minor_diff": |
| logger.warning(f"版本差异警告: {version_check['message']}") |
|
|
| async def _clear_main_db(self) -> None: |
| """清空主数据库所有表""" |
| async with self.main_db.get_db() as session: |
| async with session.begin(): |
| for table_name, model_class in MAIN_DB_MODELS.items(): |
| try: |
| await session.execute(delete(model_class)) |
| logger.debug(f"已清空表 {table_name}") |
| except Exception as e: |
| raise DatabaseClearError( |
| f"清空表 {table_name} 失败: {e}" |
| ) from e |
|
|
| async def _clear_kb_data(self) -> None: |
| """清空知识库数据""" |
| if not self.kb_manager: |
| return |
|
|
| |
| async with self.kb_manager.kb_db.get_db() as session: |
| async with session.begin(): |
| for table_name, model_class in KB_METADATA_MODELS.items(): |
| try: |
| await session.execute(delete(model_class)) |
| logger.debug(f"已清空知识库表 {table_name}") |
| except Exception as e: |
| logger.warning(f"清空知识库表 {table_name} 失败: {e}") |
|
|
| |
| for kb_id in list(self.kb_manager.kb_insts.keys()): |
| try: |
| kb_helper = self.kb_manager.kb_insts[kb_id] |
| await kb_helper.terminate() |
| if kb_helper.kb_dir.exists(): |
| shutil.rmtree(kb_helper.kb_dir) |
| except Exception as e: |
| logger.warning(f"清理知识库 {kb_id} 失败: {e}") |
|
|
| self.kb_manager.kb_insts.clear() |
|
|
| async def _import_main_database( |
| self, data: dict[str, list[dict]] |
| ) -> dict[str, int]: |
| """导入主数据库数据""" |
| imported: dict[str, int] = {} |
|
|
| async with self.main_db.get_db() as session: |
| async with session.begin(): |
| for table_name, rows in data.items(): |
| model_class = MAIN_DB_MODELS.get(table_name) |
| if not model_class: |
| logger.warning(f"未知的表: {table_name}") |
| continue |
| normalized_rows = self._preprocess_main_table_rows(table_name, rows) |
|
|
| count = 0 |
| for row in normalized_rows: |
| try: |
| |
| row = self._convert_datetime_fields(row, model_class) |
| obj = model_class(**row) |
| session.add(obj) |
| count += 1 |
| except Exception as e: |
| logger.warning(f"导入记录到 {table_name} 失败: {e}") |
|
|
| imported[table_name] = count |
| logger.debug(f"导入表 {table_name}: {count} 条记录") |
|
|
| return imported |
|
|
| def _preprocess_main_table_rows( |
| self, table_name: str, rows: list[dict[str, Any]] |
| ) -> list[dict[str, Any]]: |
| if table_name == "platform_stats": |
| normalized_rows = self._merge_platform_stats_rows(rows) |
| duplicate_count = len(rows) - len(normalized_rows) |
| if duplicate_count > 0: |
| logger.warning( |
| "检测到 %s 重复键 %d 条,已在导入前聚合", |
| table_name, |
| duplicate_count, |
| ) |
| return normalized_rows |
| return rows |
|
|
| def _merge_platform_stats_rows( |
| self, rows: list[dict[str, Any]] |
| ) -> list[dict[str, Any]]: |
| """Merge duplicate platform_stats rows by normalized timestamp/platform key. |
| |
| Note: |
| - Invalid/empty timestamps are kept as distinct rows to avoid accidental merging. |
| - Non-string platform_id/platform_type are kept as distinct rows. |
| - Invalid count warnings are rate-limited per function invocation. |
| """ |
| merged: dict[tuple[str, str, str], dict[str, Any]] = {} |
| result: list[dict[str, Any]] = [] |
| warn_limiter = _InvalidCountWarnLimiter(PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT) |
|
|
| for row in rows: |
| normalized_row, normalized_timestamp, count = ( |
| self._normalize_platform_stats_entry(row, warn_limiter) |
| ) |
| platform_id = normalized_row.get("platform_id") |
| platform_type = normalized_row.get("platform_type") |
|
|
| if ( |
| normalized_timestamp is None |
| or not isinstance(platform_id, str) |
| or not isinstance(platform_type, str) |
| ): |
| result.append(normalized_row) |
| continue |
|
|
| merge_key = (normalized_timestamp, platform_id, platform_type) |
| existing = merged.get(merge_key) |
| if existing is None: |
| merged[merge_key] = normalized_row |
| result.append(normalized_row) |
| else: |
| existing["count"] += count |
|
|
| return result |
|
|
| def _normalize_platform_stats_entry( |
| self, |
| row: dict[str, Any], |
| warn_limiter: _InvalidCountWarnLimiter, |
| ) -> tuple[dict[str, Any], str | None, int]: |
| normalized_row = dict(row) |
| raw_timestamp = normalized_row.get("timestamp") |
| normalized_timestamp = self._normalize_platform_stats_timestamp(raw_timestamp) |
|
|
| if normalized_timestamp is not None: |
| normalized_row["timestamp"] = normalized_timestamp |
| elif isinstance(raw_timestamp, str): |
| normalized_row["timestamp"] = raw_timestamp.strip() |
| elif raw_timestamp is None: |
| normalized_row["timestamp"] = "" |
| else: |
| normalized_row["timestamp"] = str(raw_timestamp) |
|
|
| raw_count = normalized_row.get("count", 0) |
| try: |
| count = int(raw_count) |
| except (TypeError, ValueError): |
| key_for_log = ( |
| normalized_row.get("timestamp"), |
| repr(normalized_row.get("platform_id")), |
| repr(normalized_row.get("platform_type")), |
| ) |
| warn_limiter.warn_invalid_count(raw_count, key_for_log) |
| count = 0 |
|
|
| normalized_row["count"] = count |
| return normalized_row, normalized_timestamp, count |
|
|
| def _normalize_platform_stats_timestamp(self, value: Any) -> str | None: |
| if isinstance(value, datetime): |
| dt = value |
| if dt.tzinfo is None: |
| dt = dt.replace(tzinfo=timezone.utc) |
| else: |
| dt = dt.astimezone(timezone.utc) |
| return dt.isoformat() |
| if isinstance(value, str): |
| timestamp = value.strip() |
| if not timestamp: |
| return None |
| if timestamp.endswith("Z"): |
| timestamp = f"{timestamp[:-1]}+00:00" |
| try: |
| dt = datetime.fromisoformat(timestamp) |
| if dt.tzinfo is None: |
| dt = dt.replace(tzinfo=timezone.utc) |
| else: |
| dt = dt.astimezone(timezone.utc) |
| return dt.isoformat() |
| except ValueError: |
| return None |
| return None |
|
|
    async def _import_knowledge_bases(
        self,
        zf: zipfile.ZipFile,
        kb_meta_data: dict[str, list[dict]],
        result: ImportResult,
    ) -> None:
        """Restore knowledge-base metadata, documents, indexes and media files.

        Args:
            zf: Open backup archive.
            kb_meta_data: Mapping of KB metadata table name -> row dicts.
            result: Import result to record per-table counts and warnings in.
        """
        if not self.kb_manager:
            return

        # 1) Metadata tables into the KB database.
        async with self.kb_manager.kb_db.get_db() as session:
            async with session.begin():
                for table_name, rows in kb_meta_data.items():
                    model_class = KB_METADATA_MODELS.get(table_name)
                    if not model_class:
                        # Unknown table in the backup; skip silently.
                        continue

                    count = 0
                    for row in rows:
                        try:
                            row = self._convert_datetime_fields(row, model_class)
                            obj = model_class(**row)
                            session.add(obj)
                            count += 1
                        except Exception as e:
                            logger.warning(f"导入知识库记录到 {table_name} 失败: {e}")

                    result.imported_tables[f"kb_{table_name}"] = count

        # 2) Per-KB on-disk artifacts: documents, FAISS index, media files.
        for kb_data in kb_meta_data.get("knowledge_bases", []):
            kb_id = kb_data.get("kb_id")
            if not kb_id:
                continue

            # Recreate the KB directory under the knowledge-base root.
            kb_dir = Path(self.kb_root_dir) / kb_id
            kb_dir.mkdir(parents=True, exist_ok=True)

            # Document chunks -> vector document store (doc.db).
            doc_path = f"databases/kb_{kb_id}/documents.json"
            if doc_path in zf.namelist():
                try:
                    doc_content = zf.read(doc_path)
                    doc_data = json.loads(doc_content)

                    await self._import_kb_documents(kb_id, doc_data)
                except Exception as e:
                    result.add_warning(f"导入知识库 {kb_id} 的文档失败: {e}")

            # FAISS vector index file.
            faiss_path = f"databases/kb_{kb_id}/index.faiss"
            if faiss_path in zf.namelist():
                try:
                    target_path = kb_dir / "index.faiss"
                    with zf.open(faiss_path) as src, open(target_path, "wb") as dst:
                        dst.write(src.read())
                except Exception as e:
                    result.add_warning(f"导入知识库 {kb_id} 的 FAISS 索引失败: {e}")

            # Media files belonging to this KB.
            # NOTE(review): rel_path comes straight from the archive entry name;
            # a crafted backup containing ".." segments could escape kb_dir
            # (zip-slip) — consider validating before writing. TODO confirm.
            media_prefix = f"files/kb_media/{kb_id}/"
            for name in zf.namelist():
                if name.startswith(media_prefix):
                    try:
                        rel_path = name[len(media_prefix) :]
                        target_path = kb_dir / rel_path
                        target_path.parent.mkdir(parents=True, exist_ok=True)
                        with zf.open(name) as src, open(target_path, "wb") as dst:
                            dst.write(src.read())
                    except Exception as e:
                        result.add_warning(f"导入媒体文件 {name} 失败: {e}")

        # 3) Reload KB instances so in-memory state reflects the restored data.
        await self.kb_manager.load_kbs()
|
|
| async def _import_kb_documents(self, kb_id: str, doc_data: dict) -> None: |
| """导入知识库文档到向量数据库""" |
| from astrbot.core.db.vec_db.faiss_impl.document_storage import DocumentStorage |
|
|
| kb_dir = Path(self.kb_root_dir) / kb_id |
| doc_db_path = kb_dir / "doc.db" |
|
|
| |
| doc_storage = DocumentStorage(str(doc_db_path)) |
| await doc_storage.initialize() |
|
|
| try: |
| documents = doc_data.get("documents", []) |
| for doc in documents: |
| try: |
| await doc_storage.insert_document( |
| doc_id=doc.get("doc_id", ""), |
| text=doc.get("text", ""), |
| metadata=json.loads(doc.get("metadata", "{}")), |
| ) |
| except Exception as e: |
| logger.warning(f"导入文档块失败: {e}") |
| finally: |
| await doc_storage.close() |
|
|
    async def _import_attachments(
        self,
        zf: zipfile.ZipFile,
        attachments: list[dict],
    ) -> int:
        """Restore attachment files from the archive.

        Each file is written back to its original path recorded in the
        attachments table dump when available, otherwise into the default
        attachments directory next to the config file.

        Args:
            zf: Open backup archive.
            attachments: Row dicts of the attachments table (from main_db dump).

        Returns:
            Number of attachment files successfully written.
        """
        count = 0

        attachments_dir = Path(self.config_path).parent / "attachments"
        attachments_dir.mkdir(parents=True, exist_ok=True)

        attachment_prefix = "files/attachments/"
        for name in zf.namelist():
            if name.startswith(attachment_prefix) and name != attachment_prefix:
                try:
                    # Match the archive entry to its DB row via the file stem.
                    attachment_id = os.path.splitext(os.path.basename(name))[0]
                    original_path = None
                    for att in attachments:
                        if att.get("attachment_id") == attachment_id:
                            original_path = att.get("path")
                            break

                    # NOTE(review): original_path is taken verbatim from the
                    # backup and may point anywhere on disk — consider
                    # restricting it to the data directory. TODO confirm.
                    if original_path:
                        target_path = Path(original_path)
                    else:
                        target_path = attachments_dir / os.path.basename(name)

                    target_path.parent.mkdir(parents=True, exist_ok=True)
                    with zf.open(name) as src, open(target_path, "wb") as dst:
                        dst.write(src.read())
                    count += 1
                except Exception as e:
                    logger.warning(f"导入附件 {name} 失败: {e}")

        return count
|
|
    async def _import_directories(
        self,
        zf: zipfile.ZipFile,
        manifest: dict,
        result: ImportResult,
    ) -> dict[str, int]:
        """Restore plugin and data directories listed in the manifest.

        Existing target directories are moved aside to "<dir>.bak" before the
        archived content is extracted in their place.

        Args:
            zf: Open backup archive.
            manifest: Backup manifest (directory backups need format >= 1.1).
            result: Import result used to record warnings.

        Returns:
            dict: Mapping of directory name -> number of files restored.
        """
        dir_stats: dict[str, int] = {}

        # Directory backups only exist in manifest format >= 1.1.
        backup_version = manifest.get("version", "1.0")
        if VersionComparator.compare_version(backup_version, "1.1") < 0:
            logger.info("备份版本不支持目录备份,跳过目录导入")
            return dir_stats

        backed_up_dirs = manifest.get("directories", [])
        backup_directories = get_backup_directories()

        for dir_name in backed_up_dirs:
            if dir_name not in backup_directories:
                result.add_warning(f"未知的目录类型: {dir_name}")
                continue

            target_dir = Path(backup_directories[dir_name])
            archive_prefix = f"directories/{dir_name}/"

            file_count = 0

            try:
                # All archive entries belonging to this directory.
                dir_files = [
                    name
                    for name in zf.namelist()
                    if name.startswith(archive_prefix) and name != archive_prefix
                ]

                if not dir_files:
                    continue

                # Move the current directory aside (replacing any older .bak).
                if target_dir.exists():
                    backup_path = Path(f"{target_dir}.bak")
                    if backup_path.exists():
                        shutil.rmtree(backup_path)
                    shutil.move(str(target_dir), str(backup_path))
                    logger.debug(f"已备份现有目录 {target_dir} 到 {backup_path}")

                target_dir.mkdir(parents=True, exist_ok=True)

                # Extract each file, preserving the relative layout.
                # NOTE(review): rel_path is not sanitized; a crafted archive
                # with ".." segments could write outside target_dir
                # (zip-slip) — consider validating before writing. TODO confirm.
                for name in dir_files:
                    try:
                        rel_path = name[len(archive_prefix) :]
                        if not rel_path:
                            continue

                        target_path = target_dir / rel_path
                        target_path.parent.mkdir(parents=True, exist_ok=True)

                        with zf.open(name) as src, open(target_path, "wb") as dst:
                            dst.write(src.read())
                        file_count += 1
                    except Exception as e:
                        result.add_warning(f"导入文件 {name} 失败: {e}")

                dir_stats[dir_name] = file_count
                logger.debug(f"导入目录 {dir_name}: {file_count} 个文件")

            except Exception as e:
                result.add_warning(f"导入目录 {dir_name} 失败: {e}")
                dir_stats[dir_name] = 0

        return dir_stats
|
|
| def _convert_datetime_fields(self, row: dict, model_class: type) -> dict: |
| """转换 datetime 字符串字段为 datetime 对象""" |
| result = row.copy() |
|
|
| |
| from sqlalchemy import inspect as sa_inspect |
|
|
| try: |
| mapper = sa_inspect(model_class) |
| for column in mapper.columns: |
| if column.name in result and result[column.name] is not None: |
| |
| from sqlalchemy import DateTime |
|
|
| if isinstance(column.type, DateTime): |
| value = result[column.name] |
| if isinstance(value, str): |
| |
| result[column.name] = datetime.fromisoformat(value) |
| except Exception: |
| pass |
|
|
| return result |
|
|