Spaces:
Running on Zero
Running on Zero
File size: 7,473 Bytes
"""Device autodetect, ZImagePipeline ModelConfig registry, and (Task 4) HF cache mirror."""
from __future__ import annotations
import os
from dataclasses import dataclass
from pathlib import Path
from typing import Any
# Avoid importing torch at module load — keeps `import models` fast in CI.
def on_spaces() -> bool:
    """Report whether this process is running on a Hugging Face ZeroGPU Space.

    Detection is based solely on the ``SPACES_ZERO_GPU`` environment variable:
    any non-empty value counts as "on Spaces"; unset or empty does not.
    """
    flag = os.environ.get("SPACES_ZERO_GPU", "")
    return flag != ""
def auto_device() -> str:
    """Return the preferred compute device: ``"cuda"``, then ``"mps"``, else ``"cpu"``.

    torch is imported inside the function so that importing this module does
    not pay torch's import cost.
    """
    import torch

    if torch.cuda.is_available():
        return "cuda"
    # MPS is only probed when CUDA is unavailable.
    return "mps" if torch.backends.mps.is_available() else "cpu"
def vram_limit_for(device: str, free_gb: float | None = None) -> float | None:
    """Return the VRAM limit (GB) passed to DiffSynth's vram_management for ``device``.

    - ``"cpu"``: ``0.0`` -- no offload budget.
    - ``"mps"``: ``None``. PyTorch's MPS backend has no ``mem_get_info`` API, and
      DiffSynth's ``check_free_vram`` raises AttributeError when it calls it.
      Returning ``None`` short-circuits that check (``vram/layers.py:195``), so
      CPU<->MPS module swapping (offload/onload) still works without the gate.
    - any other device string (treated as CUDA): ``max(8.0, free_gb - 4.0)`` --
      a few GB of headroom for loaded models + scratch, never below 8 GB.

    Args:
        device: device string, e.g. a value returned by :func:`auto_device`.
        free_gb: memory figure (GB) used by the CUDA branch; ignored otherwise.
            When omitted it is read from ``torch.cuda.mem_get_info()[1]``, which
            is the device's *total* memory, not its free memory.
            NOTE(review): despite the parameter name, the auto-detected value is
            total VRAM -- presumably intentional (DiffSynth's examples size
            ``vram_limit`` from total memory); confirm before changing.
    """
    if device == "mps":
        # No torch.mps.mem_get_info exists; None disables DiffSynth's VRAM gate.
        return None
    if device == "cpu":
        return 0.0

    budget_gb = free_gb
    if budget_gb is None:
        import torch

        _free_bytes, total_bytes = torch.cuda.mem_get_info()
        budget_gb = total_bytes / (1024**3)
    return max(8.0, budget_gb - 4.0)
@dataclass(frozen=True)
class ModelConfig:
    """Plain-data stand-in for DiffSynth's ``ModelConfig``.

    Holding only strings keeps this module importable without DiffSynth (the
    cheap CI import path). :func:`build_diffsynth_configs` converts instances
    into real ``diffsynth.core.ModelConfig`` objects on demand. Instances are
    immutable (``frozen=True``) and therefore hashable.
    """

    # Hugging Face repo id, e.g. "Tongyi-MAI/Z-Image".
    model_id: str
    # File glob or path inside that repo, e.g. "transformer/*.safetensors".
    origin_file_pattern: str
    # Human-readable note; build_diffsynth_configs does not forward it.
    description: str = ""
# Files the Z-Image pipeline loads. Several entries share a model_id because
# each one pulls a different file pattern from the same repo.
MODEL_CONFIGS: tuple[ModelConfig, ...] = (
    # Base
    ModelConfig(
        model_id="Tongyi-MAI/Z-Image",
        origin_file_pattern="transformer/*.safetensors",
        description="Z-Image base transformer (25 steps, cfg=4)",
    ),
    ModelConfig(
        model_id="Tongyi-MAI/Z-Image",
        origin_file_pattern="text_encoder/*.safetensors",
        description="Qwen3-4B text encoder — shared between base + turbo",
    ),
    ModelConfig(
        model_id="Tongyi-MAI/Z-Image",
        origin_file_pattern="vae/diffusion_pytorch_model.safetensors",
        description="Flux-family VAE — shared between base + turbo",
    ),
    # Turbo: transformer only; the text encoder and VAE come from the Z-Image entries above.
    ModelConfig(
        model_id="Tongyi-MAI/Z-Image-Turbo",
        origin_file_pattern="transformer/*.safetensors",
        description="Z-Image-Turbo transformer (8 steps, cfg=1)",
    ),
    # ControlNet Union 2.1 (eager preload per spec; can move to lazy if RAM is tight)
    ModelConfig(
        model_id="alibaba-pai/Z-Image-Turbo-Fun-Controlnet-Union-2.1",
        origin_file_pattern="Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.safetensors",
        description="ControlNet Union 2.1 — canny/depth/pose",
    ),
)

# Kept separate from MODEL_CONFIGS, so build_diffsynth_configs' default does not include it.
TOKENIZER_CONFIG = ModelConfig(
    model_id="Tongyi-MAI/Z-Image",
    origin_file_pattern="tokenizer/",
    description="Qwen3-4B tokenizer",
)
def build_diffsynth_configs(
    configs: tuple[ModelConfig, ...] = MODEL_CONFIGS,
    vram_cfg: dict[str, Any] | None = None,
) -> list[Any]:
    """Convert lightweight :class:`ModelConfig` entries into DiffSynth ``ModelConfig`` objects.

    Meant to run at app boot, not at module import, because it imports ``diffsynth``.

    Args:
        configs: entries to convert; defaults to ``MODEL_CONFIGS``.
        vram_cfg: optional keyword arguments forwarded unchanged to every
            DiffSynth ``ModelConfig``. This is the disk-offload block
            (``offload_dtype``, ``offload_device``, ...) used by DiffSynth's
            low-VRAM examples. ``None`` or an empty dict forwards nothing.

    Returns:
        One DiffSynth ``ModelConfig`` per entry, in input order. Each entry's
        ``description`` is not forwarded.
    """
    from diffsynth.core import ModelConfig as DSConfig

    extra_kwargs = vram_cfg or {}
    return [
        DSConfig(
            model_id=entry.model_id,
            origin_file_pattern=entry.origin_file_pattern,
            **extra_kwargs,
        )
        for entry in configs
    ]
def mirror_preload_hf_cache(src_root: Path | str, dst_root: Path | str) -> None:
    """Mirror a read-only HF cache tree (preload_from_hub) into a writable tree.

    Only ``<src_root>/hub`` is mirrored, to ``<dst_root>/hub``. If it does not
    exist, the call is a no-op.

    - ``blobs/<sha>`` and other regular files -> **hardlinked** (zero-copy, shared inode).
    - ``snapshots/<commit>/...`` symlinks -> **preserved** with their original relative target.
    - files under ``<repo>/refs/`` -> **byte-copied** (the HF lib overwrites refs on
      etag check; a hardlink would share the read-only source inode).
    - Directories -> ``mkdir`` so the runtime user owns them.

    Falls back to an absolute ``symlink`` when ``os.link()`` raises EXDEV (cross-device).

    Anything already present at the destination is left untouched. That includes
    dangling symlinks, e.g. snapshot links whose blob was not yet linked when an
    earlier run was interrupted. This makes the call safe to repeat.

    Raises:
        OSError: from ``os.link`` for any errno other than EXDEV, or from the
            mkdir / copy / symlink operations.
    """
    import errno
    import shutil

    src_root = Path(src_root)
    dst_root = Path(dst_root)
    src_hub = src_root / "hub"
    if not src_hub.exists():
        return  # nothing preloaded -- no-op

    for src_dir, _, files in os.walk(src_hub):
        rel = Path(src_dir).relative_to(src_root)
        dst_dir = dst_root / rel
        dst_dir.mkdir(parents=True, exist_ok=True)
        # rel.parts is ("hub", "models--org--repo", "refs", ...) for ref files.
        # Comparing path components works on every OS. A substring test for
        # "refs/" misses "hub/<repo>/refs" itself, which has no trailing slash.
        in_refs = rel.parts[2:3] == ("refs",)
        for name in files:
            src_path = Path(src_dir) / name
            dst_path = dst_dir / name
            # exists() follows symlinks and is False for a dangling link, which
            # would make symlink_to/os.link below fail with FileExistsError.
            if dst_path.is_symlink() or dst_path.exists():
                continue
            if in_refs:
                shutil.copy2(src_path, dst_path)
                continue
            # Symlinks (snapshot files) keep their relative target into ../../blobs.
            if src_path.is_symlink():
                dst_path.symlink_to(os.readlink(src_path))
                continue
            # Regular files (blobs) are hardlinked, with a symlink fallback across devices.
            try:
                os.link(src_path, dst_path)
            except OSError as e:
                if e.errno != errno.EXDEV:
                    raise
                # Absolute target: a relative src_root would otherwise be resolved
                # relative to dst_dir and dangle.
                dst_path.symlink_to(os.path.abspath(src_path))
def _resolve_snapshot_dir(repo_dir: Path) -> Path | None:
    """Return the snapshot dir to expose for one ``models--<org>--<repo>`` cache entry, or None."""
    snapshots = repo_dir / "snapshots"
    if not snapshots.is_dir():
        return None
    # Prefer the commit recorded in refs/main, the same pointer huggingface_hub
    # uses to resolve the default revision. Commit ids are alphanumeric (hex), so
    # isalnum() also rejects empty, "." / ".." or path-like contents.
    ref = repo_dir / "refs" / "main"
    if ref.is_file():
        commit = ref.read_text(encoding="utf-8", errors="replace").strip()
        if commit.isalnum() and (snapshots / commit).is_dir():
            return snapshots / commit
    # Fallback: newest snapshot dir by mtime -- usually the only one for our
    # preload + first-fetch flow.
    sha_dirs = [d for d in snapshots.iterdir() if d.is_dir()]
    if not sha_dirs:
        return None
    return max(sha_dirs, key=lambda d: d.stat().st_mtime)


def symlink_hf_cache_to_diffsynth_layout(cache_hub: Path | str, dest_root: Path | str) -> list[str]:
    """For each ``models--<org>--<repo>`` under ``cache_hub``, symlink its snapshot
    dir to ``dest_root/<org>/<repo>/`` -- the layout DiffSynth's ModelConfig expects.

    DiffSynth's ``download()`` joins ``local_model_path`` with ``model_id`` and either
    finds matching files (skipping download) or fetches them. Putting symlinks at the
    expected location lets DiffSynth reuse our HF-cache snapshots instead of re-downloading.

    The snapshot used is the commit named in ``refs/main`` when that snapshot
    exists; otherwise it is the most recently modified snapshot dir. Link targets
    are absolute, so a relative ``cache_hub`` still produces working links.

    Returns the list of dest paths created (in sorted cache-entry order).
    Idempotent:
    - existing valid symlinks and existing files/dirs are kept;
    - dangling symlinks are replaced.
    Entries that are not ``models--`` dirs, have no org part, or have no
    snapshots are skipped. A missing ``cache_hub`` returns ``[]``.
    """
    cache_hub = Path(cache_hub)
    dest_root = Path(dest_root)
    if not cache_hub.is_dir():
        return []
    created: list[str] = []
    for entry in sorted(cache_hub.iterdir()):
        if not entry.is_dir() or not entry.name.startswith("models--"):
            continue
        # "models--Tongyi-MAI--Z-Image-Turbo" -> ("Tongyi-MAI", "Z-Image-Turbo").
        # Only the first "--" separates org from repo.
        parts = entry.name[len("models--") :].split("--", 1)
        if len(parts) != 2:
            continue
        org, repo = parts
        snap = _resolve_snapshot_dir(entry)
        if snap is None:
            continue
        link = dest_root / org / repo
        if link.is_symlink():
            if link.exists():
                continue  # valid link: keep it
            link.unlink()  # dangling link: replace it below
        elif link.exists():
            continue  # real file/dir already there: never clobber
        link.parent.mkdir(parents=True, exist_ok=True)
        link.symlink_to(os.path.abspath(snap), target_is_directory=True)
        created.append(str(link))
    return created
|