File size: 3,602 Bytes
8ede856 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 | import contextlib
import functools
import importlib.metadata as importlib_metadata
import logging
import os
from collections.abc import Iterator
from packaging.requirements import Requirement
from astrbot.core.utils.requirements_utils import (
canonicalize_distribution_name,
collect_installed_distribution_versions,
get_requirement_check_paths,
)
# Module-level logger; every function below reports through the "astrbot" logger.
logger = logging.getLogger("astrbot")
def _resolve_core_dist_name(core_dist_name: str | None) -> str | None:
if core_dist_name:
try:
importlib_metadata.distribution(core_dist_name)
return core_dist_name
except importlib_metadata.PackageNotFoundError:
return None
try:
importlib_metadata.distribution("AstrBot")
return "AstrBot"
except importlib_metadata.PackageNotFoundError:
pass
if not __package__:
return None
top_pkg = __package__.split(".")[0]
for dist in importlib_metadata.distributions():
try:
top_level = dist.read_text("top_level.txt") or ""
except Exception:
continue
if top_pkg in top_level.splitlines():
if "Name" in dist.metadata:
return dist.metadata["Name"]
return None
@functools.cache
def _get_core_constraints(core_dist_name: str | None) -> tuple[str, ...]:
    """Build ``name==version`` pins for the core distribution's dependencies.

    The core distribution is resolved with ``_resolve_core_dist_name``. Each of
    its declared requirements whose environment marker (if any) evaluates true
    and whose canonicalized name is present in the installed-version mapping
    becomes one ``"<name>==<installed version>"`` entry.

    Because of ``functools.cache`` the result is computed once per distinct
    ``core_dist_name`` and then reused for the rest of the process.

    Args:
        core_dist_name: Explicit core distribution name, or ``None`` to let
            ``_resolve_core_dist_name`` auto-detect it.

    Returns:
        A tuple of constraint strings; empty when the core distribution cannot
        be resolved or read, declares no requirements, or no installed
        versions were collected.
    """
    try:
        resolved_core_dist_name = _resolve_core_dist_name(core_dist_name)
    except Exception as exc:
        logger.warning("解析核心分发名称失败: %s", exc)
        return ()
    if not resolved_core_dist_name:
        return ()

    try:
        dist = importlib_metadata.distribution(resolved_core_dist_name)
    except importlib_metadata.PackageNotFoundError:
        return ()
    except Exception as exc:
        logger.warning("读取核心分发元数据失败 (%s): %s", resolved_core_dist_name, exc)
        return ()

    # Read the property once and reuse it for both the emptiness test and the loop.
    requirement_strings = dist.requires if dist else None
    if not requirement_strings:
        return ()

    # Used below as a mapping keyed by canonicalized distribution name;
    # NOTE(review): exact return type of the helper is defined elsewhere.
    installed = collect_installed_distribution_versions(get_requirement_check_paths())
    if not installed:
        return ()

    constraints: list[str] = []
    for req_str in requirement_strings:
        try:
            req = Requirement(req_str)
            if req.marker and not req.marker.evaluate():
                continue
            name = canonicalize_distribution_name(req.name)
            if name in installed:
                constraints.append(f"{name}=={installed[name]}")
        except Exception:
            # Best effort: a requirement that cannot be parsed or whose marker
            # cannot be evaluated is skipped instead of failing the whole set.
            continue
    return tuple(constraints)
class CoreConstraintsProvider:
    """Expose the core dependency pins as a short-lived constraints file.

    The pins come from ``_get_core_constraints`` and are written one
    ``name==version`` entry per line.
    """

    def __init__(self, core_dist_name: str | None) -> None:
        """Store the distribution name used to look up constraints.

        Args:
            core_dist_name: Value forwarded to ``_get_core_constraints``;
                ``None`` lets that function auto-detect the core distribution.
        """
        self._core_dist_name = core_dist_name

    @staticmethod
    def _discard(path: str | None) -> None:
        """Remove ``path`` if set, ignoring a missing file or other OS error."""
        if not path:
            return
        with contextlib.suppress(OSError):
            os.remove(path)

    @contextlib.contextmanager
    def constraints_file(self) -> Iterator[str | None]:
        """Yield the path of a temporary constraints file, or ``None``.

        ``None`` is yielded when there are no constraints or when the file
        could not be created. When a path is yielded, the file is removed as
        soon as the ``with`` body finishes, whether it returns or raises.

        Yields:
            Path of a UTF-8 text file ending in ``_constraints.txt``, or
            ``None``.
        """
        constraints = _get_core_constraints(self._core_dist_name)
        if not constraints:
            yield None
            return

        path: str | None = None
        try:
            import tempfile

            with tempfile.NamedTemporaryFile(
                mode="w", suffix="_constraints.txt", delete=False, encoding="utf-8"
            ) as f:
                # Record the name before writing: with delete=False the file
                # outlives the handle, so a failed write must still be able to
                # find and remove it.
                path = f.name
                f.write("\n".join(constraints))
            logger.info("已启用核心依赖版本保护 (%d 个约束)", len(constraints))
        except Exception as exc:
            logger.warning("创建临时约束文件失败: %s", exc)
            # Do not leave a half-written file behind on the failure path.
            self._discard(path)
            yield None
            return

        try:
            yield path
        finally:
            self._discard(path)
|