| from __future__ import annotations |
| import argparse |
| import logging |
| import os |
| import re |
| import tempfile |
| import zipfile |
| from dataclasses import dataclass |
| from functools import cached_property |
| from pathlib import Path |
| from typing import TypedDict, Optional |
|
|
| import requests |
| from typing_extensions import NotRequired |
| from comfy.cli_args import DEFAULT_VERSION_STRING |
|
|
|
|
| REQUEST_TIMEOUT = 10 |
|
|
|
|
| class Asset(TypedDict): |
| url: str |
|
|
|
|
| class Release(TypedDict): |
| id: int |
| tag_name: str |
| name: str |
| prerelease: bool |
| created_at: str |
| published_at: str |
| body: str |
| assets: NotRequired[list[Asset]] |
|
|
|
|
| @dataclass |
| class FrontEndProvider: |
| owner: str |
| repo: str |
|
|
| @property |
| def folder_name(self) -> str: |
| return f"{self.owner}_{self.repo}" |
|
|
| @property |
| def release_url(self) -> str: |
| return f"https://api.github.com/repos/{self.owner}/{self.repo}/releases" |
|
|
| @cached_property |
| def all_releases(self) -> list[Release]: |
| releases = [] |
| api_url = self.release_url |
| while api_url: |
| response = requests.get(api_url, timeout=REQUEST_TIMEOUT) |
| response.raise_for_status() |
| releases.extend(response.json()) |
| |
| if "next" in response.links: |
| api_url = response.links["next"]["url"] |
| else: |
| api_url = None |
| return releases |
|
|
| @cached_property |
| def latest_release(self) -> Release: |
| latest_release_url = f"{self.release_url}/latest" |
| response = requests.get(latest_release_url, timeout=REQUEST_TIMEOUT) |
| response.raise_for_status() |
| return response.json() |
|
|
| def get_release(self, version: str) -> Release: |
| if version == "latest": |
| return self.latest_release |
| else: |
| for release in self.all_releases: |
| if release["tag_name"] in [version, f"v{version}"]: |
| return release |
| raise ValueError(f"Version {version} not found in releases") |
|
|
|
|
| def download_release_asset_zip(release: Release, destination_path: str) -> None: |
| """Download dist.zip from github release.""" |
| asset_url = None |
| for asset in release.get("assets", []): |
| if asset["name"] == "dist.zip": |
| asset_url = asset["url"] |
| break |
|
|
| if not asset_url: |
| raise ValueError("dist.zip not found in the release assets") |
|
|
| |
| with tempfile.TemporaryFile() as tmp_file: |
| headers = {"Accept": "application/octet-stream"} |
| response = requests.get( |
| asset_url, headers=headers, allow_redirects=True, timeout=REQUEST_TIMEOUT |
| ) |
| response.raise_for_status() |
|
|
| |
| tmp_file.write(response.content) |
|
|
| |
| tmp_file.seek(0) |
|
|
| |
| with zipfile.ZipFile(tmp_file, "r") as zip_ref: |
| zip_ref.extractall(destination_path) |
|
|
|
|
| class FrontendManager: |
| DEFAULT_FRONTEND_PATH = str(Path(__file__).parents[1] / "web") |
| CUSTOM_FRONTENDS_ROOT = str(Path(__file__).parents[1] / "web_custom_versions") |
|
|
| @classmethod |
| def parse_version_string(cls, value: str) -> tuple[str, str, str]: |
| """ |
| Args: |
| value (str): The version string to parse. |
| |
| Returns: |
| tuple[str, str]: A tuple containing provider name and version. |
| |
| Raises: |
| argparse.ArgumentTypeError: If the version string is invalid. |
| """ |
| VERSION_PATTERN = r"^([a-zA-Z0-9][a-zA-Z0-9-]{0,38})/([a-zA-Z0-9_.-]+)@(v?\d+\.\d+\.\d+|latest)$" |
| match_result = re.match(VERSION_PATTERN, value) |
| if match_result is None: |
| raise argparse.ArgumentTypeError(f"Invalid version string: {value}") |
|
|
| return match_result.group(1), match_result.group(2), match_result.group(3) |
|
|
| @classmethod |
| def init_frontend_unsafe(cls, version_string: str, provider: Optional[FrontEndProvider] = None) -> str: |
| """ |
| Initializes the frontend for the specified version. |
| |
| Args: |
| version_string (str): The version string. |
| provider (FrontEndProvider, optional): The provider to use. Defaults to None. |
| |
| Returns: |
| str: The path to the initialized frontend. |
| |
| Raises: |
| Exception: If there is an error during the initialization process. |
| main error source might be request timeout or invalid URL. |
| """ |
| if version_string == DEFAULT_VERSION_STRING: |
| return cls.DEFAULT_FRONTEND_PATH |
|
|
| repo_owner, repo_name, version = cls.parse_version_string(version_string) |
|
|
| if version.startswith("v"): |
| expected_path = str(Path(cls.CUSTOM_FRONTENDS_ROOT) / f"{repo_owner}_{repo_name}" / version.lstrip("v")) |
| if os.path.exists(expected_path): |
| logging.info(f"Using existing copy of specific frontend version tag: {repo_owner}/{repo_name}@{version}") |
| return expected_path |
|
|
| logging.info(f"Initializing frontend: {repo_owner}/{repo_name}@{version}, requesting version details from GitHub...") |
|
|
| provider = provider or FrontEndProvider(repo_owner, repo_name) |
| release = provider.get_release(version) |
|
|
| semantic_version = release["tag_name"].lstrip("v") |
| web_root = str( |
| Path(cls.CUSTOM_FRONTENDS_ROOT) / provider.folder_name / semantic_version |
| ) |
| if not os.path.exists(web_root): |
| try: |
| os.makedirs(web_root, exist_ok=True) |
| logging.info( |
| "Downloading frontend(%s) version(%s) to (%s)", |
| provider.folder_name, |
| semantic_version, |
| web_root, |
| ) |
| logging.debug(release) |
| download_release_asset_zip(release, destination_path=web_root) |
| finally: |
| |
| if not os.listdir(web_root): |
| os.rmdir(web_root) |
|
|
| return web_root |
|
|
| @classmethod |
| def init_frontend(cls, version_string: str) -> str: |
| """ |
| Initializes the frontend with the specified version string. |
| |
| Args: |
| version_string (str): The version string to initialize the frontend with. |
| |
| Returns: |
| str: The path of the initialized frontend. |
| """ |
| try: |
| return cls.init_frontend_unsafe(version_string) |
| except Exception as e: |
| logging.error("Failed to initialize frontend: %s", e) |
| logging.info("Falling back to the default frontend.") |
| return cls.DEFAULT_FRONTEND_PATH |
|
|