# z-image-studio / models.py
# Author: techfreakworm
# Commit 0cf8ffc (unverified): fix: pool-stashed transformer swap + MPS-safe vram
# + corrected model-zoo anchor
"""Device autodetect, ZImagePipeline ModelConfig registry, and (Task 4) HF cache mirror."""
from __future__ import annotations
import os
from dataclasses import dataclass
from pathlib import Path
from typing import Any
# Avoid importing torch at module load — keeps `import models` fast in CI.
def on_spaces() -> bool:
    """Report whether this process is running inside a Hugging Face ZeroGPU Space.

    Detection relies solely on the ``SPACES_ZERO_GPU`` environment variable being
    set to a non-empty value; an unset or empty variable means "not on Spaces".
    """
    flag = os.environ.get("SPACES_ZERO_GPU")
    return flag is not None and flag != ""
def auto_device() -> str:
    """Pick the preferred compute device: ``"cuda"``, then ``"mps"``, else ``"cpu"``.

    ``torch`` is imported lazily so importing this module stays cheap.
    """
    import torch

    if torch.cuda.is_available():
        device = "cuda"
    elif torch.backends.mps.is_available():
        device = "mps"
    else:
        device = "cpu"
    return device
def vram_limit_for(device: str, free_gb: float | None = None) -> float | None:
"""Conservative VRAM limit (GB) passed to DiffSynth's vram_management.
- CUDA: keep a few GB headroom (loaded models + scratch).
- MPS: ``None`` — PyTorch's MPS has no ``mem_get_info`` API, and DiffSynth's
``check_free_vram`` raises AttributeError when called on MPS. Returning
``None`` short-circuits the check (``vram/layers.py:195``) so module
swapping still works without the gate.
- CPU: 0.0 (no offload budget).
"""
if device == "cpu":
return 0.0
if device == "mps":
# PyTorch's MPS backend has no ``torch.mps.mem_get_info``. DiffSynth's
# ``AutoWrappedModule.check_free_vram`` calls it and raises AttributeError.
# Returning None short-circuits the gate at vram/layers.py:195 so we keep
# CPU↔MPS module swapping (offload/onload) without the doomed check.
return None
# cuda
if free_gb is None:
import torch
free_gb = torch.cuda.mem_get_info()[1] / (1024**3)
return max(8.0, free_gb - 4.0)
@dataclass(frozen=True)
class ModelConfig:
    """Import-cheap description of one checkpoint for DiffSynth's ModelConfig.
    Stored as plain data so this module imports cheaply in CI. The real
    ``diffsynth.core.ModelConfig`` instance is built on demand by
    :func:`build_diffsynth_configs`, which forwards ``model_id`` and
    ``origin_file_pattern`` (plus any offload kwargs) and does not pass
    ``description``. Frozen, so instances are immutable and hashable.
    """
    model_id: str  # Hugging Face repo id, e.g. "Tongyi-MAI/Z-Image"
    origin_file_pattern: str  # glob or path inside the repo selecting the files to load
    description: str = ""  # human-readable note only; not forwarded to DiffSynth
# Checkpoint registry. build_diffsynth_configs() converts all of these by default.
# Base and Turbo each contribute a transformer; the text encoder and VAE are listed
# once (from the base repo) and shared. NOTE(review): which transformer a given
# generation uses is decided by the caller — confirm at the pipeline call site.
MODEL_CONFIGS: tuple[ModelConfig, ...] = (
    # Base
    ModelConfig("Tongyi-MAI/Z-Image", "transformer/*.safetensors", "Z-Image base transformer (25 steps, cfg=4)"),
    ModelConfig(
        "Tongyi-MAI/Z-Image", "text_encoder/*.safetensors", "Qwen3-4B text encoder — shared between base + turbo"
    ),
    ModelConfig(
        "Tongyi-MAI/Z-Image", "vae/diffusion_pytorch_model.safetensors", "Flux-family VAE — shared between base + turbo"
    ),
    # Turbo (transformer only — encoder + VAE come from the Z-Image entry above)
    ModelConfig("Tongyi-MAI/Z-Image-Turbo", "transformer/*.safetensors", "Z-Image-Turbo transformer (8 steps, cfg=1)"),
    # ControlNet Union 2.1 (eager preload per spec; can move to lazy if RAM is tight)
    # Single file at the repo root, so the pattern is an exact filename, not a glob.
    ModelConfig(
        "alibaba-pai/Z-Image-Turbo-Fun-Controlnet-Union-2.1",
        "Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.safetensors",
        "ControlNet Union 2.1 — canny/depth/pose",
    ),
)
# Tokenizer directory, kept out of MODEL_CONFIGS (so not part of the default
# build_diffsynth_configs() output). Presumably handed to the pipeline separately
# — confirm at the call site.
TOKENIZER_CONFIG = ModelConfig("Tongyi-MAI/Z-Image", "tokenizer/", "Qwen3-4B tokenizer")
def build_diffsynth_configs(
    configs: tuple[ModelConfig, ...] = MODEL_CONFIGS,
    vram_cfg: dict[str, Any] | None = None,
) -> list[Any]:
    """Convert our plain ``ModelConfig`` records into DiffSynth ``ModelConfig`` objects.

    Intended for app boot, not module import — DiffSynth is imported lazily here.
    Only ``model_id`` and ``origin_file_pattern`` are forwarded; ``description`` is
    dropped. ``vram_cfg``, when given, is the disk-offload block (offload_dtype,
    offload_device, etc.) used by DiffSynth's low-VRAM examples and is splatted as
    extra keyword arguments into every resulting config.
    """
    from diffsynth.core import ModelConfig as DSConfig

    shared_kwargs = dict(vram_cfg) if vram_cfg else {}
    return [
        DSConfig(
            model_id=entry.model_id,
            origin_file_pattern=entry.origin_file_pattern,
            **shared_kwargs,
        )
        for entry in configs
    ]
def mirror_preload_hf_cache(src_root: Path | str, dst_root: Path | str) -> None:
    """Mirror a read-only HF cache tree (preload_from_hub) into a writable tree.

    Walks ``<src_root>/hub`` and recreates it under ``<dst_root>/hub``:

    - ``blobs/<sha>`` files -> **hardlinked** (zero-copy, shared inode).
    - ``snapshots/<commit>/...`` symlinks -> **preserved** with original relative target.
    - ``refs/<branch>`` files -> **byte-copied** (HF lib overwrites on etag check).
      Only the bytes are copied, not the permission bits, so a read-only source
      ref does not yield a read-only copy.
    - Directories -> ``mkdir`` so the runtime user owns them.

    Falls back to an absolute ``symlink`` when ``os.link()`` fails with EXDEV
    (cross-device) or EPERM (e.g. Linux ``fs.protected_hardlinks`` refusing to
    link a file the runtime user does not own).

    Existing destination entries, including dangling symlinks, are left
    untouched, so repeated calls are safe. If ``<src_root>/hub`` does not exist
    this is a no-op. Other ``OSError``s from linking are re-raised.
    """
    import errno
    import shutil

    src_root = Path(src_root)
    dst_root = Path(dst_root)
    src_hub = src_root / "hub"
    if not src_hub.exists():
        return  # nothing preloaded -- no-op
    for src_dir, _, files in os.walk(src_hub):
        rel = Path(src_dir).relative_to(src_root)
        dst_dir = dst_root / rel
        dst_dir.mkdir(parents=True, exist_ok=True)
        # HF layout is hub/<repo-folder>/{blobs,refs,snapshots}/... . Compare the
        # path component: a substring test for "refs/" misses "hub/<repo>/refs"
        # itself (no trailing slash) and could match unrelated snapshot folders.
        is_refs_dir = len(rel.parts) >= 3 and rel.parts[2] == "refs"
        for name in files:
            src_path = Path(src_dir) / name
            dst_path = dst_dir / name
            # lexists: a dangling symlink from an earlier run counts as present.
            # Path.exists() follows it, reports False, and symlink_to would raise.
            if os.path.lexists(dst_path):
                continue
            # Refs get byte-copied (not copy2: that would also copy a read-only mode)
            if is_refs_dir:
                shutil.copyfile(src_path, dst_path)
                continue
            # Symlinks (snapshot files) preserve their relative target
            if src_path.is_symlink():
                dst_path.symlink_to(os.readlink(src_path))
                continue
            # Regular files (blobs) hardlink, with a symlink fallback
            try:
                os.link(src_path, dst_path)
            except OSError as e:
                if e.errno not in (errno.EXDEV, errno.EPERM):
                    raise
                # Absolute target: a relative src_root would otherwise be resolved
                # against dst_dir and produce a broken link.
                dst_path.symlink_to(src_path.absolute())
def symlink_hf_cache_to_diffsynth_layout(cache_hub: Path | str, dest_root: Path | str) -> list[str]:
    """For each ``models--<org>--<repo>`` under ``cache_hub``, symlink the latest snapshot
    dir to ``dest_root/<org>/<repo>/`` — the layout DiffSynth's ModelConfig expects.

    DiffSynth's ``download()`` joins ``local_model_path`` with ``model_id`` and either
    finds matching files (skipping download) or fetches them. Putting symlinks at the
    expected location lets DiffSynth reuse our HF-cache snapshots instead of re-downloading.

    Link targets are absolute, so the links work even when ``cache_hub`` is given
    as a relative path.

    Returns the list of dest paths created (as given, not resolved). Idempotent:
    existing destinations (valid symlinks or real directories) are kept, while a
    dangling symlink is replaced and reported as created. A missing ``cache_hub``
    returns ``[]``.
    """
    cache_hub = Path(cache_hub)
    dest_root = Path(dest_root)
    if not cache_hub.is_dir():
        return []
    created: list[str] = []
    for entry in sorted(cache_hub.iterdir()):
        if not entry.is_dir() or not entry.name.startswith("models--"):
            continue
        # "models--Tongyi-MAI--Z-Image-Turbo" -> ("Tongyi-MAI", "Z-Image-Turbo")
        # Some repos contain "--" in their name; only split off the first segment.
        org, sep, repo = entry.name[len("models--") :].partition("--")
        if not (sep and org and repo):
            continue
        snapshots = entry / "snapshots"
        if not snapshots.is_dir():
            continue
        sha_dirs = [d for d in snapshots.iterdir() if d.is_dir()]
        if not sha_dirs:
            continue
        # Newest by mtime — usually the only one for our preload + first-fetch flow.
        # max() keeps the first of equal mtimes, same as a stable reverse sort.
        snap = max(sha_dirs, key=lambda d: d.stat().st_mtime)
        link = dest_root / org / repo
        if link.is_symlink() and not link.exists():
            # Dangling (snapshot moved or removed): re-point it instead of keeping it.
            link.unlink()
        elif link.exists():
            continue
        link.parent.mkdir(parents=True, exist_ok=True)
        # Absolute target: a relative target is resolved against link.parent,
        # which breaks whenever cache_hub was passed as a relative path.
        link.symlink_to(snap.absolute())
        created.append(str(link))
    return created