| import importlib.metadata as importlib_metadata |
| import logging |
| import os |
| import re |
| import shlex |
| import sys |
| from collections.abc import Iterable, Iterator, Sequence |
| from dataclasses import dataclass |
|
|
| from packaging.requirements import InvalidRequirement, Requirement |
| from packaging.specifiers import SpecifierSet |
| from packaging.version import InvalidVersion, Version |
|
|
| from astrbot.core.utils.astrbot_path import get_astrbot_site_packages_path |
| from astrbot.core.utils.runtime_env import is_packaged_desktop_runtime |
|
|
# Shared "astrbot" logger rather than a per-module ``__name__`` logger;
# presumably so output follows the application's logging setup (confirm).
logger = logging.getLogger(name="astrbot")
|
|
|
|
class RequirementsPrecheckFailed(Exception):
    """Signal that the missing-requirements pre-check could not be completed.

    Raised by ``find_missing_requirements_or_raise`` when the pre-check
    returns no result (for example, an unreadable requirements file).
    """
|
|
|
|
@dataclass(frozen=True)
class ParsedPackageInput:
    """Immutable result of parsing user-supplied package install input.

    Built by ``parse_package_install_input``.
    """

    # Install arguments in input order: a whole line when it is a single
    # PEP 508 requirement, otherwise the line's shell-split tokens.
    specs: tuple[str, ...]
    # Canonicalized project names recognised in the input.
    requirement_names: frozenset[str]
| requirement_names: frozenset[str] |
|
|
|
|
| @dataclass(frozen=True) |
| class MissingRequirementsPlan: |
| missing_names: frozenset[str] |
| install_lines: tuple[str, ...] |
| fallback_reason: str | None = None |
|
|
|
|
def canonicalize_distribution_name(name: str) -> str:
    """Normalize a distribution name for comparisons.

    Runs of ``-``, ``_`` and ``.`` collapse into a single ``-``. Separators at
    either end are dropped, and the result is lower-cased.
    """
    pieces = re.split(r"[-_.]+", name)
    return "-".join(piece for piece in pieces if piece).lower()
|
|
|
|
def strip_inline_requirement_comment(raw_input: str) -> str:
    """Remove a ``#`` comment from one requirements line and trim whitespace.

    A line whose first non-blank character is ``#`` becomes ``""``. Otherwise
    a comment starts only at a ``#`` preceded by spaces/tabs, so URL
    fragments such as ``...#egg=name`` are preserved.
    """
    if raw_input.lstrip().startswith("#"):
        return ""
    comment = re.search(r"[ \t]+#", raw_input)
    content = raw_input if comment is None else raw_input[: comment.start()]
    return content.strip()
|
|
|
|
def _specifier_contains_version(specifier: SpecifierSet, version: str) -> bool:
    """Return whether ``version`` satisfies ``specifier``.

    Pre-releases are accepted when they fall inside the range. A version
    string that cannot be parsed never satisfies the specifier.
    """
    try:
        candidate = Version(version)
    except InvalidVersion:
        return False
    else:
        return specifier.contains(candidate, prereleases=True)
|
|
|
|
| def _looks_like_local_path_reference(token: str) -> bool: |
| candidate = token.strip() |
| if not candidate: |
| return False |
| return candidate in {".", ".."} or candidate.startswith( |
| ("./", "../", "/", "~/", ".\\", "..\\", "\\") |
| ) |
|
|
|
|
def looks_like_direct_reference(token: str) -> bool:
    """Return whether ``token`` points at a location rather than a project name.

    True for URLs (anything containing ``://``), ``git+`` VCS references and
    local filesystem paths.
    """
    stripped = token.strip()
    if not stripped:
        return False
    if "://" in stripped or stripped.startswith("git+"):
        return True
    return _looks_like_local_path_reference(stripped)
|
|
|
|
| def extract_requirement_name(raw_requirement: str) -> str | None: |
| line = raw_requirement.split("#", 1)[0].strip() |
| if not line: |
| return None |
| if line.startswith(("-r", "--requirement", "-c", "--constraint")): |
| return None |
|
|
| egg_match = re.search(r"#egg=([A-Za-z0-9_.-]+)", raw_requirement) |
| if egg_match: |
| return canonicalize_distribution_name(egg_match.group(1)) |
|
|
| if line.startswith("-"): |
| return None |
|
|
| candidate = re.split(r"[<>=!~;\s\[]", line, maxsplit=1)[0].strip() |
| if not candidate: |
| return None |
| return canonicalize_distribution_name(candidate) |
|
|
|
|
def _parse_editable_or_direct_name(target: str) -> str | None:
    """Return a project name for ``target`` only when it can be trusted.

    The name from ``extract_requirement_name`` is kept if it came from an
    explicit ``#egg=`` fragment, or if ``target`` is not a URL/VCS/path
    reference. A bare location yields ``None``, because its text is not a
    project name.
    """
    name = extract_requirement_name(target)
    if not name:
        return None
    trustworthy = "#egg=" in target or not looks_like_direct_reference(target)
    return name if trustworthy else None
|
|
|
|
def _parse_requirement_name_and_spec(
    line: str,
) -> tuple[str | None, SpecifierSet | None]:
    """Extract the canonical project name and version specifier from one line.

    Args:
        line: A single requirements line with comments already stripped.

    Returns:
        ``(name, specifier)``. ``name`` is the canonicalized project name.
        ``specifier`` is ``None`` when the line has no version constraint or
        is not a PEP 508 requirement (editable/URL lines).

        ``(None, None)`` is returned for:

        - ``-c``/``--constraint`` lines;
        - requirements whose environment marker is false for this interpreter;
        - lines from which no trustworthy name can be derived.

    Raises:
        ValueError: from ``shlex.split`` for non-PEP 508 lines with unbalanced
            quotes.
    """
    if line.startswith(("-c", "--constraint")):
        return None, None

    req = _parse_pep508_requirement(line)
    if req is None:
        return _parse_non_pep508_requirement(line)

    # A requirement guarded by a marker that does not apply to the running
    # environment (e.g. ``; sys_platform == "win32"``) is not needed here.
    if req.marker and not req.marker.evaluate():
        return None, None

    return canonicalize_distribution_name(req.name), (req.specifier or None)


def _parse_pep508_requirement(line: str):
    """Parse ``line`` as PEP 508, tolerating trailing pip per-requirement options.

    Returns a ``packaging.requirements.Requirement``, or ``None`` when neither
    the whole line nor its text before the first option parses.
    """
    try:
        return Requirement(line)
    except InvalidRequirement:
        pass
    # pip allows options after a requirement on the same line, e.g.
    # ``foo>=2.0 --hash=sha256:...``. PEP 508 rejects the whole line, and the
    # name-only fallback would silently drop the specifier and marker. Retry
    # on the text before the first whitespace-separated option token.
    head = re.split(r"\s+(?=-)", line, maxsplit=1)[0]
    if head == line or not head or head.startswith("-"):
        return None
    try:
        return Requirement(head)
    except InvalidRequirement:
        return None


def _parse_non_pep508_requirement(
    line: str,
) -> tuple[str | None, SpecifierSet | None]:
    """Derive a name (no specifier) from an editable, option or direct-reference line."""
    tokens = shlex.split(line)
    if not tokens:
        return None, None

    editable_target: str | None = None
    if tokens[0] in {"-e", "--editable"} and len(tokens) > 1:
        editable_target = tokens[1]
    elif tokens[0].startswith("--editable="):
        editable_target = tokens[0].split("=", 1)[1]

    # For editables, inspect only the target. Otherwise inspect the whole line.
    name = _parse_editable_or_direct_name(editable_target or line)
    return (name, None) if name else (None, None)
|
|
|
|
def _parse_requirement_line(
    line: str,
) -> tuple[str, SpecifierSet | None] | None:
    """Return ``(name, specifier)`` for a line that names a project, else ``None``."""
    name, specifier = _parse_requirement_name_and_spec(line)
    if not name:
        return None
    return name, specifier
|
|
|
|
| def _extract_requirement_names_from_package_tokens(tokens: list[str]) -> frozenset[str]: |
| requirement_names: set[str] = set() |
| skip_next_for: str | None = None |
|
|
| for token in tokens: |
| if skip_next_for: |
| if skip_next_for == "editable": |
| name = _parse_editable_or_direct_name(token) |
| if name: |
| requirement_names.add(name) |
| skip_next_for = None |
| continue |
|
|
| if token in {"-e", "--editable"}: |
| skip_next_for = "editable" |
| continue |
|
|
| if token in { |
| "-i", |
| "--index-url", |
| "--extra-index-url", |
| "-f", |
| "--find-links", |
| "--trusted-host", |
| "-r", |
| "--requirement", |
| "-c", |
| "--constraint", |
| }: |
| skip_next_for = "option-value" |
| continue |
|
|
| if token.startswith(("--editable=",)): |
| editable_target = token.split("=", 1)[1] |
| name = _parse_editable_or_direct_name(editable_target) |
| if name: |
| requirement_names.add(name) |
| continue |
|
|
| if token.startswith( |
| ( |
| "--index-url=", |
| "--extra-index-url=", |
| "--find-links=", |
| "--trusted-host=", |
| "--requirement=", |
| "--constraint=", |
| ) |
| ): |
| continue |
|
|
| if ( |
| (token.startswith("-i") and token != "-i") |
| or (token.startswith("-f") and token != "-f") |
| or token == "--no-index" |
| ): |
| continue |
|
|
| if token.startswith("-"): |
| continue |
|
|
| name, _ = _parse_requirement_name_and_spec(token) |
| if name: |
| requirement_names.add(name) |
|
|
| return frozenset(requirement_names) |
|
|
|
|
def parse_package_install_input(raw_input: str) -> ParsedPackageInput:
    """Parse free-form package install input into install specs and names.

    Each non-empty, comment-stripped line that is a single PEP 508
    requirement is kept whole. Any other line is shell-split into tokens,
    and names are extracted from those tokens.

    Raises:
        ValueError: from ``shlex.split`` on lines with unbalanced quotes.
    """
    text = raw_input.strip()
    if not text:
        return ParsedPackageInput(specs=(), requirement_names=frozenset())

    collected_specs: list[str] = []
    collected_names: set[str] = set()

    for raw_line in text.splitlines():
        line = strip_inline_requirement_comment(raw_line)
        if not line:
            continue

        try:
            Requirement(line)
        except InvalidRequirement:
            tokens = shlex.split(line)
            if tokens:
                collected_specs.extend(tokens)
                collected_names.update(
                    _extract_requirement_names_from_package_tokens(tokens)
                )
        else:
            collected_specs.append(line)
            name, _spec = _parse_requirement_name_and_spec(line)
            if name:
                collected_names.add(name)

    return ParsedPackageInput(
        specs=tuple(collected_specs),
        requirement_names=frozenset(collected_names),
    )
|
|
|
|
def _iter_requirement_lines(
    requirements_path: str,
    _visited: set[str] | None = None,
    _include_stack: tuple[str, ...] = (),
) -> Iterator[str]:
    """Yield the non-empty, comment-stripped lines of a requirements file.

    ``-r``/``--requirement`` include lines are not yielded. The included file
    is expanded in their place instead, with relative include paths resolved
    against the including file's directory. Each file is read at most once
    per top-level call.

    Args:
        requirements_path: Path to the requirements file.
        _visited: Internal. Real paths already read during this expansion.
        _include_stack: Internal. Real paths of the files currently being
            expanded (the include chain), used to tell a genuine cycle from a
            file that is merely included more than once.

    Raises:
        OSError: if a requirements file cannot be opened.
        ValueError: from ``shlex.split`` on lines with unbalanced quotes or a
            trailing backslash, or from UTF-8 decoding.
    """
    # ``is None`` rather than ``or``: a caller-supplied empty set must be
    # shared and populated, not silently replaced by a fresh one.
    visited = set() if _visited is None else _visited
    resolved_path = os.path.realpath(requirements_path)
    if resolved_path in visited:
        if resolved_path in _include_stack:
            logger.warning(
                "检测到循环依赖的 requirements 包含: %s,将跳过该文件", resolved_path
            )
        else:
            # Reached again via another branch (diamond include). Its lines
            # were already yielded, so skip without claiming a cycle.
            logger.debug("requirements 文件已被包含,跳过重复包含: %s", resolved_path)
        return
    visited.add(resolved_path)
    include_stack = (*_include_stack, resolved_path)

    with open(resolved_path, encoding="utf-8") as f:
        for raw_line in f:
            line = strip_inline_requirement_comment(raw_line)
            if not line:
                continue

            tokens = shlex.split(line)
            if not tokens:
                continue

            nested: str | None = None
            if tokens[0] in {"-r", "--requirement"} and len(tokens) > 1:
                nested = tokens[1]
            elif tokens[0].startswith("--requirement="):
                nested = tokens[0].split("=", 1)[1]

            if nested:
                if not os.path.isabs(nested):
                    nested = os.path.join(os.path.dirname(resolved_path), nested)
                yield from _iter_requirement_lines(
                    nested, _visited=visited, _include_stack=include_stack
                )
                continue

            yield line
|
|
|
|
def iter_requirements(
    requirements_path: str | None = None,
    lines: Iterable[str] | None = None,
) -> Iterator[tuple[str, SpecifierSet | None]]:
    """Yield ``(canonical_name, specifier)`` for each line that names a project.

    ``lines`` is used when given. Otherwise the lines of ``requirements_path``
    are read, with nested ``-r`` includes expanded. Lines that yield no name
    are skipped.

    Raises:
        ValueError: (on first iteration) if neither ``requirements_path`` nor
            ``lines`` is given.
    """
    source = lines
    if source is None:
        if requirements_path is None:
            raise ValueError("Either requirements_path or lines must be provided")
        source = _iter_requirement_lines(requirements_path)

    for parsed in map(_parse_requirement_line, source):
        if parsed is not None:
            yield parsed
|
|
|
|
def extract_requirement_names(requirements_path: str) -> set[str]:
    """Return the canonical names required by ``requirements_path``.

    This is best-effort: any error while reading or parsing is logged as a
    warning, and an empty set is returned.
    """
    try:
        names = {
            name
            for name, _spec in iter_requirements(requirements_path=requirements_path)
        }
    except Exception as exc:
        logger.warning("读取依赖文件失败,跳过冲突检测: %s", exc)
        return set()
    return names
|
|
|
|
def get_requirement_check_paths() -> list[str]:
    """Return the paths to search for installed distributions.

    The result is a copy of ``sys.path``. In the packaged desktop runtime,
    the AstrBot site-packages directory is prepended when it exists.
    """
    search_paths = [*sys.path]
    if not is_packaged_desktop_runtime():
        return search_paths
    bundled_site_packages = get_astrbot_site_packages_path()
    if os.path.isdir(bundled_site_packages):
        return [bundled_site_packages, *search_paths]
    return search_paths
|
|
|
|
| def _canonical_distribution_identity(distribution) -> tuple[str | None, str | None]: |
| distribution_name = ( |
| distribution.metadata["Name"] if "Name" in distribution.metadata else None |
| ) |
| if not distribution_name: |
| return None, None |
| return canonicalize_distribution_name(distribution_name), distribution.version |
|
|
|
|
def collect_installed_distribution_versions(paths: list[str]) -> dict[str, str] | None:
    """Map canonical distribution names to installed versions found on ``paths``.

    When a name appears more than once, the first distribution encountered
    wins. Distributions without a name or version are skipped. Returns
    ``None`` (after logging a warning) if enumeration fails.
    """
    versions: dict[str, str] = {}
    try:
        for dist in importlib_metadata.distributions(path=paths):
            name, version = _canonical_distribution_identity(dist)
            if name and version and name not in versions:
                versions[name] = version
    except Exception as exc:
        logger.warning("读取已安装依赖失败,跳过缺失依赖预检查: %s", exc)
        return None
    return versions
|
|
|
|
def _load_requirement_lines_for_precheck(
    requirements_path: str,
) -> tuple[bool, list[str] | None]:
    """Read requirement lines and decide whether a selective pre-check is safe.

    Returns:
        ``(True, lines)`` when every line can be reasoned about.
        ``(False, None)`` when the file cannot be read, or when it contains
        an editable without ``#egg=`` or an unnamed direct reference. In both
        ``False`` cases the reason is logged.
    """
    try:
        requirement_lines = list(_iter_requirement_lines(requirements_path))
    except Exception as exc:
        logger.warning(
            "预检查缺失依赖失败,将回退到完整安装: %s (%s)",
            requirements_path,
            exc,
        )
        return False, None

    for candidate in requirement_lines:
        unnamed_editable = (
            candidate.startswith(("-e ", "--editable ", "--editable="))
            and "#egg=" not in candidate
        )
        if unnamed_editable or (
            _parse_requirement_line(candidate) is None
            and looks_like_direct_reference(candidate)
        ):
            logger.info(
                "缺失依赖预检查发现无法安全裁剪的 option/direct-reference 行,将回退到完整安装: %s (%s)",
                requirements_path,
                candidate,
            )
            return False, None

    return True, requirement_lines
|
|
|
|
def find_missing_requirements(requirements_path: str) -> set[str] | None:
    """Return names required by ``requirements_path`` that are not satisfied.

    Returns ``None`` when a selective pre-check is not possible for this file.
    """
    ok, requirement_lines = _load_requirement_lines_for_precheck(requirements_path)
    if ok and requirement_lines is not None:
        return find_missing_requirements_from_lines(requirement_lines)
    return None
|
|
|
|
def find_missing_requirements_from_lines(
    requirement_lines: Sequence[str],
) -> set[str] | None:
    """Return canonical names of requirements from ``requirement_lines`` that are unsatisfied.

    A requirement is unsatisfied when no distribution with its name is
    installed on the check paths, or when the installed version fails the
    line's specifier. Pre-releases count as matching; unparseable versions
    never match.

    Returns:
        The set of unsatisfied names (empty if no line applies), or ``None``
        when installed distributions could not be enumerated.
    """
    required = list(iter_requirements(lines=requirement_lines))
    if not required:
        return set()

    installed = collect_installed_distribution_versions(get_requirement_check_paths())
    if installed is None:
        return None

    unsatisfied: set[str] = set()
    for name, specifier in required:
        installed_version = installed.get(name)
        satisfied = bool(installed_version) and (
            not specifier
            or _specifier_contains_version(specifier, installed_version)
        )
        if not satisfied:
            unsatisfied.add(name)
    return unsatisfied
|
|
|
|
| def build_missing_requirements_install_lines( |
| requirements_path: str, |
| requirement_lines: Sequence[str], |
| missing_names: set[str] | frozenset[str], |
| ) -> tuple[str, ...] | None: |
| wanted_names = set(missing_names) |
| install_lines: list[str] = [] |
| for line in requirement_lines: |
| parsed = _parse_requirement_line(line) |
| if parsed is None: |
| if looks_like_direct_reference(line) or line.startswith(("-", "--")): |
| logger.debug( |
| "缺失依赖行筛选回退到完整安装:requirements 中包含无法安全裁剪的 option/direct-reference 行: %s (%s)", |
| requirements_path, |
| line, |
| ) |
| return None |
| continue |
|
|
| name, _specifier = parsed |
| if name in wanted_names: |
| install_lines.append(line) |
|
|
| return tuple(install_lines) |
|
|
|
|
def plan_missing_requirements_install(
    requirements_path: str,
) -> MissingRequirementsPlan | None:
    """Build a plan to install only the unsatisfied requirements of a file.

    Returns ``None`` when pruning is not possible (unreadable file, lines
    that cannot be reasoned about, or installed packages that cannot be
    enumerated). If some names are missing but none map to an install line,
    the plan has no install lines and ``fallback_reason`` is set.
    """
    ok, requirement_lines = _load_requirement_lines_for_precheck(requirements_path)
    if not ok or requirement_lines is None:
        return None

    missing = find_missing_requirements_from_lines(requirement_lines)
    if missing is None:
        return None

    install_lines = build_missing_requirements_install_lines(
        requirements_path,
        requirement_lines,
        missing,
    )
    if install_lines is None:
        return None

    if not missing or install_lines:
        return MissingRequirementsPlan(
            missing_names=frozenset(missing),
            install_lines=install_lines,
        )

    logger.warning(
        "预检查缺失依赖成功,但无法映射到可安装 requirement 行,将回退到完整安装: %s -> %s",
        requirements_path,
        sorted(missing),
    )
    return MissingRequirementsPlan(
        missing_names=frozenset(missing),
        install_lines=(),
        fallback_reason="unmapped missing requirement names",
    )
|
|
|
|
def find_missing_requirements_or_raise(requirements_path: str) -> set[str]:
    """Like ``find_missing_requirements``, but raise instead of returning ``None``.

    Raises:
        RequirementsPrecheckFailed: if the pre-check could not be performed.
    """
    result = find_missing_requirements(requirements_path)
    if result is not None:
        return result
    raise RequirementsPrecheckFailed(f"预检查失败: {requirements_path}")
|
|