| import contextlib |
| import functools |
| import importlib.metadata as importlib_metadata |
| import logging |
| import os |
| from collections.abc import Iterator |
|
|
| from packaging.requirements import Requirement |
|
|
| from astrbot.core.utils.requirements_utils import ( |
| canonicalize_distribution_name, |
| collect_installed_distribution_versions, |
| get_requirement_check_paths, |
| ) |
|
|
| logger = logging.getLogger("astrbot") |
|
|
|
|
| def _resolve_core_dist_name(core_dist_name: str | None) -> str | None: |
| if core_dist_name: |
| try: |
| importlib_metadata.distribution(core_dist_name) |
| return core_dist_name |
| except importlib_metadata.PackageNotFoundError: |
| return None |
|
|
| try: |
| importlib_metadata.distribution("AstrBot") |
| return "AstrBot" |
| except importlib_metadata.PackageNotFoundError: |
| pass |
|
|
| if not __package__: |
| return None |
|
|
| top_pkg = __package__.split(".")[0] |
| for dist in importlib_metadata.distributions(): |
| try: |
| top_level = dist.read_text("top_level.txt") or "" |
| except Exception: |
| continue |
| if top_pkg in top_level.splitlines(): |
| if "Name" in dist.metadata: |
| return dist.metadata["Name"] |
|
|
| return None |
|
|
|
|
| @functools.cache |
| def _get_core_constraints(core_dist_name: str | None) -> tuple[str, ...]: |
| try: |
| resolved_core_dist_name = _resolve_core_dist_name(core_dist_name) |
| except Exception as exc: |
| logger.warning("解析核心分发名称失败: %s", exc) |
| return () |
|
|
| if not resolved_core_dist_name: |
| return () |
|
|
| try: |
| dist = importlib_metadata.distribution(resolved_core_dist_name) |
| except importlib_metadata.PackageNotFoundError: |
| return () |
| except Exception as exc: |
| logger.warning("读取核心分发元数据失败 (%s): %s", resolved_core_dist_name, exc) |
| return () |
|
|
| if not dist or not dist.requires: |
| return () |
|
|
| installed = collect_installed_distribution_versions(get_requirement_check_paths()) |
| if not installed: |
| return () |
|
|
| constraints: list[str] = [] |
| for req_str in dist.requires: |
| try: |
| req = Requirement(req_str) |
| if req.marker and not req.marker.evaluate(): |
| continue |
| name = canonicalize_distribution_name(req.name) |
| if name in installed: |
| constraints.append(f"{name}=={installed[name]}") |
| except Exception: |
| continue |
|
|
| return tuple(constraints) |
|
|
|
|
| class CoreConstraintsProvider: |
| def __init__(self, core_dist_name: str | None) -> None: |
| self._core_dist_name = core_dist_name |
|
|
| @contextlib.contextmanager |
| def constraints_file(self) -> Iterator[str | None]: |
| constraints = _get_core_constraints(self._core_dist_name) |
| if not constraints: |
| yield None |
| return |
|
|
| path: str | None = None |
| try: |
| import tempfile |
|
|
| with tempfile.NamedTemporaryFile( |
| mode="w", suffix="_constraints.txt", delete=False, encoding="utf-8" |
| ) as f: |
| f.write("\n".join(constraints)) |
| path = f.name |
| logger.info("已启用核心依赖版本保护 (%d 个约束)", len(constraints)) |
| except Exception as exc: |
| logger.warning("创建临时约束文件失败: %s", exc) |
| yield None |
| return |
|
|
| try: |
| yield path |
| finally: |
| if path and os.path.exists(path): |
| with contextlib.suppress(Exception): |
| os.remove(path) |
|
|