Spaces:
Running on Zero
Running on Zero
feat(models): device autodetect, vram-limit helpers, model config registry
Browse files- models.py +96 -0
- tests/test_models.py +38 -0
models.py
ADDED
|
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Device autodetect, ZImagePipeline ModelConfig registry, and (Task 4) HF cache mirror."""
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
import os
|
| 5 |
+
from dataclasses import dataclass, field
|
| 6 |
+
from typing import Any
|
| 7 |
+
|
| 8 |
+
# Avoid importing torch at module load — keeps `import models` fast in CI.
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def on_spaces() -> bool:
    """Tell whether this process looks like a Hugging Face ZeroGPU Space.

    The only signal consulted is the ``SPACES_ZERO_GPU`` environment
    variable: any non-empty value counts as "on Spaces"; an unset or empty
    variable does not.
    """
    flag = os.environ.get("SPACES_ZERO_GPU")
    return flag is not None and flag != ""
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def auto_device() -> str:
    """Pick the compute backend to use, in order of preference.

    Returns:
        ``"cuda"`` when ``torch.cuda.is_available()`` is true, otherwise
        ``"mps"`` when ``torch.backends.mps.is_available()`` is true,
        otherwise ``"cpu"``.
    """
    # Imported here rather than at module level so importing this module
    # does not load torch.
    import torch

    if torch.cuda.is_available():
        return "cuda"
    # The MPS probe is only reached when CUDA is unavailable.
    return "mps" if torch.backends.mps.is_available() else "cpu"
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def vram_limit_for(device: str, free_gb: float | None = None) -> float:
    """Return a conservative VRAM limit (GB) for DiffSynth's vram_management.

    Policy per device:

    - ``"cpu"``: always ``0.0`` (no offload budget).
    - ``"mps"``: half of the memory budget (the CPU shares unified memory),
      with an 8 GB floor.
    - anything else (treated as CUDA): the budget minus a fixed 4 GB of
      headroom for loaded models and scratch space, with an 8 GB floor.

    The result is additionally clamped to the budget itself, so the 8 GB
    floor can never produce a limit larger than the memory the device has
    (e.g. a 6 GB budget yields 6.0, not 8.0).

    Args:
        device: Device name, normally ``"cuda"``, ``"mps"`` or ``"cpu"``.
        free_gb: Memory budget in GB. When ``None`` it is probed: element 1
            of ``torch.cuda.mem_get_info()`` for ``"cuda"``, or an assumed
            24 GB for any other non-CPU device.

    Returns:
        The limit in GB.
    """
    if device == "cpu":
        return 0.0

    if free_gb is None:
        # Lazy import: only the probing path needs torch.
        import torch

        if device == "cuda":
            # mem_get_info() returns (free, total) in bytes; element 1 is used.
            free_gb = torch.cuda.mem_get_info()[1] / (1024 ** 3)
        else:  # mps
            # torch.mps has no mem_get_info on most builds; fall back to a
            # safe constant.
            free_gb = 24.0

    budget = float(free_gb)
    target = budget / 2 if device == "mps" else budget - 4.0
    # Apply the 8 GB floor, then make sure the floor did not push the limit
    # past the budget on small-memory devices.
    return min(budget, max(8.0, target))
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
@dataclass(frozen=True)
class ModelConfig:
    """Immutable, plain-data description of one model file set.

    Stored as plain data so this module imports cheaply in CI. The real
    ``diffsynth.core.ModelConfig`` instance is built on demand by
    :func:`build_diffsynth_configs`.
    """
    # Repository identifier, e.g. "Tongyi-MAI/Z-Image".
    model_id: str
    # File path or glob pattern inside the repository,
    # e.g. "transformer/*.safetensors".
    origin_file_pattern: str
    # Free-form human-readable note about the entry; defaults to empty.
    description: str = ""
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
MODEL_CONFIGS: tuple[ModelConfig, ...] = (
    # --- Z-Image (base) -------------------------------------------------
    ModelConfig(
        model_id="Tongyi-MAI/Z-Image",
        origin_file_pattern="transformer/*.safetensors",
        description="Z-Image base transformer (25 steps, cfg=4)",
    ),
    ModelConfig(
        model_id="Tongyi-MAI/Z-Image",
        origin_file_pattern="text_encoder/*.safetensors",
        description="Qwen3-4B text encoder — shared between base + turbo",
    ),
    ModelConfig(
        model_id="Tongyi-MAI/Z-Image",
        origin_file_pattern="vae/diffusion_pytorch_model.safetensors",
        description="Flux-family VAE — shared between base + turbo",
    ),
    # --- Z-Image-Turbo --------------------------------------------------
    # Only the transformer is listed; the text encoder and VAE are the
    # shared Z-Image entries above.
    ModelConfig(
        model_id="Tongyi-MAI/Z-Image-Turbo",
        origin_file_pattern="transformer/*.safetensors",
        description="Z-Image-Turbo transformer (8 steps, cfg=1)",
    ),
    # --- ControlNet Union 2.1 -------------------------------------------
    # Preloaded eagerly per spec; can move to lazy loading if RAM is tight.
    ModelConfig(
        model_id="PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1",
        origin_file_pattern="Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.safetensors",
        description="ControlNet Union 2.1 — canny/depth/pose",
    ),
)

# The tokenizer is kept out of MODEL_CONFIGS and exposed on its own.
TOKENIZER_CONFIG = ModelConfig(
    model_id="Tongyi-MAI/Z-Image",
    origin_file_pattern="tokenizer/",
    description="Qwen3-4B tokenizer",
)
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def build_diffsynth_configs(
    configs: tuple[ModelConfig, ...] = MODEL_CONFIGS,
    vram_cfg: dict[str, Any] | None = None,
) -> list[Any]:
    """Turn the lightweight registry entries into DiffSynth ``ModelConfig`` objects.

    Intended to be called at app boot, not at module import, because it
    imports ``diffsynth``.

    Args:
        configs: Entries to convert; defaults to ``MODEL_CONFIGS``.
        vram_cfg: Optional extra keyword arguments forwarded unchanged to
            every DiffSynth ``ModelConfig`` — the disk-offload block
            (offload_dtype, offload_device, etc.) that DiffSynth's low-VRAM
            examples use. ``None`` or an empty dict forwards nothing.

    Returns:
        One DiffSynth ``ModelConfig`` per entry, in the same order. Only
        ``model_id`` and ``origin_file_pattern`` are carried over from each
        entry; ``description`` is not forwarded.
    """
    from diffsynth.core import ModelConfig as DSConfig

    offload_kwargs = vram_cfg or {}
    return [
        DSConfig(
            model_id=entry.model_id,
            origin_file_pattern=entry.origin_file_pattern,
            **offload_kwargs,
        )
        for entry in configs
    ]
|
tests/test_models.py
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from unittest import mock
|
| 3 |
+
|
| 4 |
+
import models
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def test_auto_device_returns_cuda_or_mps_or_cpu():
    """auto_device() must report one of the three supported backends."""
    supported = ("cuda", "mps", "cpu")
    assert models.auto_device() in supported
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def test_on_spaces_reads_env_var():
    """on_spaces() follows the SPACES_ZERO_GPU environment variable."""
    # Variable present: reported as running on Spaces.
    with_flag = mock.patch.dict(os.environ, {"SPACES_ZERO_GPU": "1"}, clear=False)
    with with_flag:
        assert models.on_spaces() is True

    # Environment emptied entirely: reported as not on Spaces.
    without_flag = mock.patch.dict(os.environ, {}, clear=True)
    with without_flag:
        assert models.on_spaces() is False
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def test_model_configs_contains_both_transformers():
    """The registry lists the base, turbo and ControlNet repositories."""
    registered = {entry.model_id for entry in models.MODEL_CONFIGS}
    for repo in (
        "Tongyi-MAI/Z-Image",
        "Tongyi-MAI/Z-Image-Turbo",
        "PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1",
    ):
        assert repo in registered
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def test_vram_limit_for_cuda_is_reasonable():
    """With an 80 GB budget the CUDA limit leaves headroom but stays large."""
    budget_gb = 80.0
    cuda_limit = models.vram_limit_for("cuda", free_gb=budget_gb)
    # Must not exceed the budget, and must not give away more than 20 GB.
    assert cuda_limit >= 60.0
    assert cuda_limit <= 80.0
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def test_vram_limit_for_mps_is_unified_memory_aware():
    """With 24 GB of unified memory the MPS limit is at least half, with headroom."""
    unified_gb = 24.0
    mps_limit = models.vram_limit_for("mps", free_gb=unified_gb)
    assert mps_limit >= 12.0
    assert mps_limit <= 22.0
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def test_vram_limit_for_cpu_is_zero():
    """CPU gets no offload budget regardless of the memory passed in."""
    cpu_limit = models.vram_limit_for("cpu", free_gb=64.0)
    assert cpu_limit == 0.0
|