techfreakworm committed on
Commit
613dab3
·
unverified ·
1 Parent(s): fd0ad15

feat(models): device autodetect, vram-limit helpers, model config registry

Browse files
Files changed (2) hide show
  1. models.py +96 -0
  2. tests/test_models.py +38 -0
models.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Device autodetect, ZImagePipeline ModelConfig registry, and (Task 4) HF cache mirror."""
2
+ from __future__ import annotations
3
+
4
+ import os
5
+ from dataclasses import dataclass, field
6
+ from typing import Any
7
+
8
+ # Avoid importing torch at module load — keeps `import models` fast in CI.
9
+
10
+
11
def on_spaces() -> bool:
    """Report whether the ``SPACES_ZERO_GPU`` environment variable is set.

    Returns:
        ``True`` when the variable exists and is a non-empty string,
        ``False`` when it is missing or empty. The variable is taken as the
        marker of a Hugging Face ZeroGPU Space.
    """
    marker = os.environ.get("SPACES_ZERO_GPU")
    # Any non-empty value counts; the content of the string is not inspected.
    return True if marker else False
14
+
15
+
16
def auto_device() -> str:
    """Pick the compute device name to use: ``"cuda"``, ``"mps"`` or ``"cpu"``.

    CUDA is probed first, then Apple MPS; ``"cpu"`` is the fallback when
    neither backend reports itself as available.
    """
    # Imported lazily so that importing this module does not import torch.
    import torch

    if torch.cuda.is_available():
        return "cuda"
    return "mps" if torch.backends.mps.is_available() else "cpu"
24
+
25
+
26
def vram_limit_for(device: str, free_gb: float | None = None) -> float:
    """Return a conservative VRAM limit (GB) for DiffSynth's vram_management.

    Policy per device:

    - ``"cpu"``: always ``0.0`` (no offload budget).
    - ``"mps"``: half of the memory figure, because the CPU shares the same
      unified memory.
    - anything else (treated as CUDA): the memory figure minus a fixed 4 GB
      of headroom for loaded models and scratch space.

    The budget is raised to an 8 GB floor, but it is never allowed to exceed
    ``free_gb`` itself, so small devices do not get a limit larger than the
    memory they were reported to have.

    Args:
        device: Device name, e.g. the value returned by ``auto_device()``.
        free_gb: Memory figure in GB to budget against. When ``None`` it is
            probed: on ``"cuda"`` the *total* device memory (second element
            of ``torch.cuda.mem_get_info()``) is used; for any other
            non-CPU device a constant 24.0 GB is assumed.

    Returns:
        The limit in GB.
    """
    if device == "cpu":
        return 0.0

    if free_gb is None:
        if device == "cuda":
            # Lazy import keeps module import cheap; only the probe needs torch.
            import torch

            # mem_get_info() returns (free, total) in bytes; index 1 is total.
            free_gb = torch.cuda.mem_get_info()[1] / (1024 ** 3)
        else:  # mps
            # torch.mps has no mem_get_info on most builds; fall back to a safe constant.
            free_gb = 24.0

    budget = free_gb / 2 if device == "mps" else free_gb - 4.0
    # Apply the 8 GB floor, then cap at the memory figure so the floor cannot
    # push the limit above what the device actually has.
    return min(free_gb, max(8.0, budget))
47
+
48
+
49
@dataclass(frozen=True)
class ModelConfig:
    """Plain-data description of one model file set to load.

    This is an immutable record holding only strings, so the module can be
    imported without pulling in DiffSynth. The real
    ``diffsynth.core.ModelConfig`` objects are created later by
    :func:`build_diffsynth_configs`.
    """

    # Repository identifier, e.g. "Tongyi-MAI/Z-Image".
    model_id: str
    # File or glob pattern inside the repository selecting the weights.
    origin_file_pattern: str
    # Optional human-readable note; defaults to the empty string.
    description: str = ""
61
+
62
# Registry of the weight files to load, in order.
MODEL_CONFIGS: tuple[ModelConfig, ...] = (
    # Base
    ModelConfig(
        model_id="Tongyi-MAI/Z-Image",
        origin_file_pattern="transformer/*.safetensors",
        description="Z-Image base transformer (25 steps, cfg=4)",
    ),
    ModelConfig(
        model_id="Tongyi-MAI/Z-Image",
        origin_file_pattern="text_encoder/*.safetensors",
        description="Qwen3-4B text encoder — shared between base + turbo",
    ),
    ModelConfig(
        model_id="Tongyi-MAI/Z-Image",
        origin_file_pattern="vae/diffusion_pytorch_model.safetensors",
        description="Flux-family VAE — shared between base + turbo",
    ),
    # Turbo (transformer only — encoder + VAE come from the Z-Image entry above)
    ModelConfig(
        model_id="Tongyi-MAI/Z-Image-Turbo",
        origin_file_pattern="transformer/*.safetensors",
        description="Z-Image-Turbo transformer (8 steps, cfg=1)",
    ),
    # ControlNet Union 2.1 (eager preload per spec; can move to lazy if RAM is tight)
    ModelConfig(
        model_id="PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1",
        origin_file_pattern="Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.safetensors",
        description="ControlNet Union 2.1 — canny/depth/pose",
    ),
)

# Tokenizer location, kept separate from the weight registry above.
TOKENIZER_CONFIG = ModelConfig(
    model_id="Tongyi-MAI/Z-Image",
    origin_file_pattern="tokenizer/",
    description="Qwen3-4B tokenizer",
)
81
+
82
+
83
def build_diffsynth_configs(
    configs: tuple[ModelConfig, ...] = MODEL_CONFIGS,
    vram_cfg: dict[str, Any] | None = None,
) -> list[Any]:
    """Turn the lightweight registry entries into DiffSynth ``ModelConfig`` objects.

    DiffSynth is imported inside the function, so the dependency is only
    needed when this is called (intended for app boot, not module import).

    Args:
        configs: Entries to convert; defaults to ``MODEL_CONFIGS``.
        vram_cfg: Optional extra keyword arguments forwarded unchanged to
            every DiffSynth ``ModelConfig`` (the disk-offload block such as
            offload_dtype / offload_device used by DiffSynth's low-VRAM
            examples). ``None`` or an empty dict adds nothing.

    Returns:
        One DiffSynth ``ModelConfig`` per input entry, in the same order.
    """
    from diffsynth.core import ModelConfig as DSConfig

    built: list[Any] = []
    for entry in configs:
        extra_kwargs = vram_cfg or {}
        built.append(
            DSConfig(
                model_id=entry.model_id,
                origin_file_pattern=entry.origin_file_pattern,
                **extra_kwargs,
            )
        )
    return built
tests/test_models.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from unittest import mock
3
+
4
+ import models
5
+
6
+
7
def test_auto_device_returns_cuda_or_mps_or_cpu():
    """auto_device() must name one of the three supported backends."""
    known_backends = ("cuda", "mps", "cpu")
    assert models.auto_device() in known_backends
10
+
11
+
12
def test_on_spaces_reads_env_var():
    """on_spaces() follows the SPACES_ZERO_GPU environment variable."""
    # Variable present: detection should be positive.
    with mock.patch.dict(os.environ, {"SPACES_ZERO_GPU": "1"}, clear=False):
        with_flag = models.on_spaces()
        assert with_flag is True
    # Environment wiped: detection should be negative.
    with mock.patch.dict(os.environ, {}, clear=True):
        without_flag = models.on_spaces()
        assert without_flag is False
17
+
18
+
19
def test_model_configs_contains_both_transformers():
    """The registry must reference the base, turbo and ControlNet repositories."""
    registered = {entry.model_id for entry in models.MODEL_CONFIGS}
    for expected_repo in (
        "Tongyi-MAI/Z-Image",
        "Tongyi-MAI/Z-Image-Turbo",
        "PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1",
    ):
        assert expected_repo in registered
25
+
26
+
27
def test_vram_limit_for_cuda_is_reasonable():
    """An 80 GB CUDA figure yields a limit between 60 and 80 GB."""
    budget = models.vram_limit_for("cuda", free_gb=80.0)
    # leave headroom
    assert budget >= 60.0
    assert budget <= 80.0
30
+
31
+
32
def test_vram_limit_for_mps_is_unified_memory_aware():
    """A 24 GB MPS figure yields a limit between 12 and 22 GB."""
    budget = models.vram_limit_for("mps", free_gb=24.0)
    # half of unified, headroom
    assert budget >= 12.0
    assert budget <= 22.0
35
+
36
+
37
def test_vram_limit_for_cpu_is_zero():
    """The CPU device never gets an offload budget, whatever figure is passed."""
    budget = models.vram_limit_for("cpu", free_gb=64.0)
    assert budget == 0.0