techfreakworm committed on
Commit
0cf8ffc
·
unverified ·
1 Parent(s): 5ab6428

fix: pool-stashed transformer swap + MPS-safe vram + corrected model-zoo anchor

Browse files

Four coupled fixes uncovered by the first real Generate click locally:

1. AttributeError: 'ZImagePipeline' object has no attribute 'model_pool'
DiffSynth's from_pretrained builds a fresh ModelPool, attaches pipe.dit/
pipe.text_encoder/etc. from it, then discards the pool. My _swap_transformer
in modes.py assumed pipe.model_pool persisted. Replace from_pretrained call
with a manual replication that stashes the pool on pipe._zis_pool, then
index into pool.model for the two z_image_dit entries (Base loaded first,
Turbo second per MODEL_CONFIGS order). fetch_model can't distinguish them
since both register under the same name.

2. AttributeError: module 'torch.mps' has no attribute 'mem_get_info'
DiffSynth's AutoWrappedModule.forward gates module load on check_free_vram
(vram/layers.py:195), which calls torch.<device>.mem_get_info — CUDA-only.
The escape hatch is vram_limit=None, which short-circuits the gate. Update
models.vram_limit_for('mps') -> None (was a positive float).

3. Preemptively set PYTORCH_ENABLE_MPS_FALLBACK=1 in app.py so any other
MPS-unsupported op (SDPA variants, certain index ops) falls back to CPU
instead of crashing the request.

4. Corrected the Z-Image Model Zoo anchor — the README heading starts with an
emoji, which GitHub drops from the anchor slug while the space after it becomes a leading '-', so the URL is #-model-zoo.

Tests: 68 passing, ruff clean. Validated locally; HF Space will rebuild on push.

app.py CHANGED
@@ -14,6 +14,11 @@ from pathlib import Path
14
  # Must be set before any diffsynth import path is taken (backend imports it lazily).
15
  os.environ.setdefault("DIFFSYNTH_DOWNLOAD_SOURCE", "huggingface")
16
 
 
 
 
 
 
17
  import gradio as gr
18
 
19
  import backend
 
14
  # Must be set before any diffsynth import path is taken (backend imports it lazily).
15
  os.environ.setdefault("DIFFSYNTH_DOWNLOAD_SOURCE", "huggingface")
16
 
17
+ # Apple Silicon: let PyTorch fall back to CPU for the small set of ops MPS doesn't
18
+ # implement (some scaled-dot-product flavors, certain index ops). Without this,
19
+ # DiffSynth crashes mid-pipeline on the first unsupported op rather than degrading.
20
+ os.environ.setdefault("PYTORCH_ENABLE_MPS_FALLBACK", "1")
21
+
22
  import gradio as gr
23
 
24
  import backend
backend.py CHANGED
@@ -61,9 +61,18 @@ _GPU = (
61
 
62
 
63
  def _build_pipeline() -> Any:
64
- """Construct the DiffSynth ZImagePipeline. Imported lazily to keep tests fast."""
 
 
 
 
 
 
 
 
65
  import torch
66
  from diffsynth.pipelines.z_image import ZImagePipeline
 
67
 
68
  import models
69
 
@@ -81,16 +90,35 @@ def _build_pipeline() -> Any:
81
  computation_device=device,
82
  )
83
 
84
- pipe = ZImagePipeline.from_pretrained(
85
- torch_dtype=torch.bfloat16,
86
- device=device,
87
- model_configs=models.build_diffsynth_configs(vram_cfg=vram_cfg),
88
- tokenizer_config=models.build_diffsynth_configs(
89
- (models.TOKENIZER_CONFIG,),
90
- vram_cfg=None,
91
- )[0],
92
  vram_limit=models.vram_limit_for(device),
93
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
  return pipe
95
 
96
 
 
61
 
62
 
63
  def _build_pipeline() -> Any:
64
+ """Construct a ZImagePipeline carrying BOTH Base and Turbo transformers.
65
+
66
+ DiffSynth's ``ZImagePipeline.from_pretrained`` builds a fresh ``ModelPool``
67
+ locally and throws it away after attaching ``pipe.dit`` etc. — so a later
68
+ transformer swap has nothing to switch between. We replicate the same
69
+ initialization manually and keep the pool on ``pipe._zis_pool`` so
70
+ :func:`modes._swap_transformer` can flip ``pipe.dit`` between the two
71
+ ``z_image_dit`` entries (Base loaded first, Turbo second per MODEL_CONFIGS).
72
+ """
73
  import torch
74
  from diffsynth.pipelines.z_image import ZImagePipeline
75
+ from transformers import AutoTokenizer
76
 
77
  import models
78
 
 
90
  computation_device=device,
91
  )
92
 
93
+ pipe = ZImagePipeline(device=device, torch_dtype=torch.bfloat16)
94
+
95
+ # Load every safetensors listed in MODEL_CONFIGS — both transformers + shared
96
+ # text encoder + VAE + controlnet — into one pool.
97
+ pool = pipe.download_and_load_models(
98
+ models.build_diffsynth_configs(vram_cfg=vram_cfg),
 
 
99
  vram_limit=models.vram_limit_for(device),
100
  )
101
+ pipe._zis_pool = pool
102
+
103
+ pipe.text_encoder = pool.fetch_model("z_image_text_encoder")
104
+ pipe.dit = pool.fetch_model("z_image_dit") # first match = Base per load order
105
+ pipe.vae_encoder = pool.fetch_model("flux_vae_encoder")
106
+ pipe.vae_decoder = pool.fetch_model("flux_vae_decoder")
107
+ pipe.controlnet = pool.fetch_model("z_image_controlnet")
108
+ # Optional image encoders that DiffSynth's ZImagePipeline references but
109
+ # aren't in our preload (Omni / image2lora). fetch_model returns None when
110
+ # absent — that's the documented "not an error" path.
111
+ pipe.image_encoder = pool.fetch_model("siglip_vision_model_428m")
112
+ pipe.siglip2_image_encoder = pool.fetch_model("siglip2_image_encoder")
113
+ pipe.dinov3_image_encoder = pool.fetch_model("dinov3_image_encoder")
114
+ pipe.image2lora_style = pool.fetch_model("z_image_image2lora_style")
115
+
116
+ # Tokenizer (Qwen3-4B tokenizer dir under Z-Image)
117
+ tok_cfg = models.build_diffsynth_configs((models.TOKENIZER_CONFIG,), vram_cfg=None)[0]
118
+ tok_cfg.download_if_necessary()
119
+ pipe.tokenizer = AutoTokenizer.from_pretrained(tok_cfg.path)
120
+
121
+ pipe.vram_management_enabled = pipe.check_vram_management_state()
122
  return pipe
123
 
124
 
docs/superpowers/plans/2026-05-13-z-image-studio.md CHANGED
@@ -2892,7 +2892,7 @@ def test_model_selector_html_marks_current_as_on():
2892
 
2893
  def test_model_selector_html_includes_both_soon_cards_with_github_link():
2894
  out = ui.model_selector_html(current="Turbo")
2895
- assert out.count("github.com/Tongyi-MAI/Z-Image#model-zoo") == 2
2896
  assert "Edit" in out
2897
  assert "Omni Base" in out
2898
  assert "soon-tag" in out
@@ -2920,7 +2920,7 @@ from __future__ import annotations
2920
 
2921
  from html import escape
2922
 
2923
- GITHUB_MODEL_ZOO_URL = "https://github.com/Tongyi-MAI/Z-Image#model-zoo"
2924
 
2925
 
2926
  def labeled_label(text: str, info_text: str) -> str:
 
2892
 
2893
  def test_model_selector_html_includes_both_soon_cards_with_github_link():
2894
  out = ui.model_selector_html(current="Turbo")
2895
+ assert out.count("github.com/Tongyi-MAI/Z-Image#-model-zoo") == 2
2896
  assert "Edit" in out
2897
  assert "Omni Base" in out
2898
  assert "soon-tag" in out
 
2920
 
2921
  from html import escape
2922
 
2923
+ GITHUB_MODEL_ZOO_URL = "https://github.com/Tongyi-MAI/Z-Image#-model-zoo"
2924
 
2925
 
2926
  def labeled_label(text: str, info_text: str) -> str:
docs/superpowers/specs/2026-05-13-z-image-studio-design.md CHANGED
@@ -195,14 +195,14 @@ The T2I tab's Model selector replaces `gr.Radio` with a custom HTML grid because
195
  <span class="dot"></span><span class="name">Turbo</span>
196
  </button>
197
  <a class="zis-model soon"
198
- href="https://github.com/Tongyi-MAI/Z-Image#model-zoo"
199
  target="_blank" rel="noopener noreferrer">
200
  <span class="dot"></span>
201
  <span class="name">Edit<span class="ext">↗</span></span>
202
  <span class="soon-tag">soon</span>
203
  </a>
204
  <a class="zis-model soon"
205
- href="https://github.com/Tongyi-MAI/Z-Image#model-zoo"
206
  target="_blank" rel="noopener noreferrer">
207
  <span class="dot"></span>
208
  <span class="name">Omni Base<span class="ext">↗</span></span>
 
195
  <span class="dot"></span><span class="name">Turbo</span>
196
  </button>
197
  <a class="zis-model soon"
198
+ href="https://github.com/Tongyi-MAI/Z-Image#-model-zoo"
199
  target="_blank" rel="noopener noreferrer">
200
  <span class="dot"></span>
201
  <span class="name">Edit<span class="ext">↗</span></span>
202
  <span class="soon-tag">soon</span>
203
  </a>
204
  <a class="zis-model soon"
205
+ href="https://github.com/Tongyi-MAI/Z-Image#-model-zoo"
206
  target="_blank" rel="noopener noreferrer">
207
  <span class="dot"></span>
208
  <span class="name">Omni Base<span class="ext">↗</span></span>
models.py CHANGED
@@ -26,27 +26,29 @@ def auto_device() -> str:
26
  return "cpu"
27
 
28
 
29
- def vram_limit_for(device: str, free_gb: float | None = None) -> float:
30
  """Conservative VRAM limit (GB) passed to DiffSynth's vram_management.
31
 
32
- - CUDA: keep ~5% headroom (loaded models + scratch).
33
- - MPS: half of unified memory (CPU still needs RAM), capped.
 
 
 
34
  - CPU: 0.0 (no offload budget).
35
  """
36
  if device == "cpu":
37
  return 0.0
 
 
 
 
 
 
 
38
  if free_gb is None:
39
  import torch
40
 
41
- if device == "cuda":
42
- free_gb = torch.cuda.mem_get_info()[1] / (1024**3)
43
- else: # mps
44
- # torch.mps has no mem_get_info on most builds; fall back to a safe constant.
45
- free_gb = 24.0
46
- if device == "mps":
47
- # Use half of unified memory; clamp to 8 GB floor for safety.
48
- return max(8.0, free_gb / 2)
49
- # cuda
50
  return max(8.0, free_gb - 4.0)
51
 
52
 
 
26
  return "cpu"
27
 
28
 
29
+ def vram_limit_for(device: str, free_gb: float | None = None) -> float | None:
30
  """Conservative VRAM limit (GB) passed to DiffSynth's vram_management.
31
 
32
+ - CUDA: keep a few GB headroom (loaded models + scratch).
33
+ - MPS: ``None`` — PyTorch's MPS has no ``mem_get_info`` API, and DiffSynth's
34
+ ``check_free_vram`` raises AttributeError when called on MPS. Returning
35
+ ``None`` short-circuits the check (``vram/layers.py:195``) so module
36
+ swapping still works without the gate.
37
  - CPU: 0.0 (no offload budget).
38
  """
39
  if device == "cpu":
40
  return 0.0
41
+ if device == "mps":
42
+ # PyTorch's MPS backend has no ``torch.mps.mem_get_info``. DiffSynth's
43
+ # ``AutoWrappedModule.check_free_vram`` calls it and raises AttributeError.
44
+ # Returning None short-circuits the gate at vram/layers.py:195 so we keep
45
+ # CPU↔MPS module swapping (offload/onload) without the doomed check.
46
+ return None
47
+ # cuda
48
  if free_gb is None:
49
  import torch
50
 
51
+ free_gb = torch.cuda.mem_get_info()[1] / (1024**3)
 
 
 
 
 
 
 
 
52
  return max(8.0, free_gb - 4.0)
53
 
54
 
modes.py CHANGED
@@ -36,9 +36,23 @@ class T2IParams(TypedDict, total=False):
36
 
37
 
38
  def _swap_transformer(pipe: Any, model_name: str) -> None:
39
- """Swap the active transformer in the pipeline's model pool."""
 
 
 
 
 
 
 
 
 
 
40
  variant = "z_image" if model_name == "Base" else "z_image_turbo"
41
- pipe.dit = pipe.model_pool.fetch_model("z_image_dit", variant=variant)
 
 
 
 
42
  try:
43
  pipe.dit._zis_variant = variant
44
  except (AttributeError, RuntimeError):
 
36
 
37
 
38
  def _swap_transformer(pipe: Any, model_name: str) -> None:
39
+ """Swap the active transformer between Base (index 0) and Turbo (index 1).
40
+
41
+ ``backend._build_pipeline`` loads both transformers into ``pipe._zis_pool``
42
+ and stores them under the same name ``z_image_dit``. DiffSynth's
43
+ ``ModelPool.fetch_model`` doesn't expose a variant kwarg — both entries
44
+ share the same name — so we index into ``pool.model`` directly. MODEL_CONFIGS
45
+ loads Base first, then Turbo (so index 0 = Base, index 1 = Turbo).
46
+
47
+ No-op if the pool is unavailable (e.g. mocked tests) or only one transformer
48
+ was loaded.
49
+ """
50
  variant = "z_image" if model_name == "Base" else "z_image_turbo"
51
+ pool = getattr(pipe, "_zis_pool", None)
52
+ if pool is not None:
53
+ dits = [m for m, n in zip(pool.model, pool.model_name, strict=False) if n == "z_image_dit"]
54
+ if len(dits) >= 2:
55
+ pipe.dit = dits[0 if model_name == "Base" else 1]
56
  try:
57
  pipe.dit._zis_variant = variant
58
  except (AttributeError, RuntimeError):
tests/test_models.py CHANGED
@@ -33,9 +33,11 @@ def test_vram_limit_for_cuda_is_reasonable():
33
  assert 60.0 <= limit <= 80.0 # leave headroom
34
 
35
 
36
- def test_vram_limit_for_mps_is_unified_memory_aware():
37
- limit = models.vram_limit_for("mps", free_gb=24.0)
38
- assert 12.0 <= limit <= 22.0 # half of unified, headroom
 
 
39
 
40
 
41
  def test_vram_limit_for_cpu_is_zero():
 
33
  assert 60.0 <= limit <= 80.0 # leave headroom
34
 
35
 
36
+ def test_vram_limit_for_mps_returns_none():
37
+ # MPS has no torch.mps.mem_get_info; DiffSynth's check_free_vram crashes
38
+ # on a numeric limit. None short-circuits the check (vram/layers.py:195).
39
+ assert models.vram_limit_for("mps", free_gb=24.0) is None
40
+ assert models.vram_limit_for("mps") is None
41
 
42
 
43
  def test_vram_limit_for_cpu_is_zero():
tests/test_modes.py CHANGED
@@ -68,25 +68,33 @@ def test_t2i_base_passes_negative_prompt_and_cfg4(fake_pipe):
68
  assert kwargs["num_inference_steps"] == 25
69
 
70
 
71
- def test_t2i_swaps_transformer_via_model_pool(fake_pipe):
 
 
 
 
 
 
 
72
  modes.call_t2i(
73
  fake_pipe,
74
  params=dict(
75
- prompt="x",
76
- negative_prompt="",
77
- model="Base",
78
- steps=25,
79
- cfg=4.0,
80
- width=1024,
81
- height=1024,
82
- seed=0,
83
- lora_path=None,
84
- lora_strength=0.0,
 
 
 
85
  ),
86
  )
87
- fake_pipe.model_pool.fetch_model.assert_called()
88
- call = fake_pipe.model_pool.fetch_model.call_args
89
- assert call.args[0] == "z_image_dit"
90
 
91
 
92
  def test_controlnet_calls_preprocessor_then_pipeline(fake_pipe, monkeypatch):
 
68
  assert kwargs["num_inference_steps"] == 25
69
 
70
 
71
+ def test_t2i_swaps_transformer_via_pool_index(fake_pipe):
72
+ """Base picks pool.model[0]; Turbo picks pool.model[1] (load-order indexed)."""
73
+ base_dit = object()
74
+ turbo_dit = object()
75
+ # Two z_image_dit entries in load order: Base first, Turbo second.
76
+ fake_pipe._zis_pool.model = [base_dit, turbo_dit, "vae_decoder_obj"]
77
+ fake_pipe._zis_pool.model_name = ["z_image_dit", "z_image_dit", "flux_vae_decoder"]
78
+
79
  modes.call_t2i(
80
  fake_pipe,
81
  params=dict(
82
+ prompt="x", negative_prompt="", model="Base",
83
+ steps=25, cfg=4.0, width=1024, height=1024, seed=0,
84
+ lora_path=None, lora_strength=0.0,
85
+ ),
86
+ )
87
+ assert fake_pipe.dit is base_dit
88
+
89
+ modes.call_t2i(
90
+ fake_pipe,
91
+ params=dict(
92
+ prompt="x", negative_prompt="", model="Turbo",
93
+ steps=8, cfg=1.0, width=1024, height=1024, seed=0,
94
+ lora_path=None, lora_strength=0.0,
95
  ),
96
  )
97
+ assert fake_pipe.dit is turbo_dit
 
 
98
 
99
 
100
  def test_controlnet_calls_preprocessor_then_pipeline(fake_pipe, monkeypatch):
tests/test_ui.py CHANGED
@@ -28,7 +28,7 @@ def test_model_selector_html_marks_current_as_on():
28
 
29
  def test_model_selector_html_includes_both_soon_cards_with_github_link():
30
  out = ui.model_selector_html(current="Turbo")
31
- assert out.count("github.com/Tongyi-MAI/Z-Image#model-zoo") == 2
32
  assert "Edit" in out
33
  assert "Omni Base" in out
34
  assert "soon-tag" in out
 
28
 
29
  def test_model_selector_html_includes_both_soon_cards_with_github_link():
30
  out = ui.model_selector_html(current="Turbo")
31
+ assert out.count("github.com/Tongyi-MAI/Z-Image#-model-zoo") == 2
32
  assert "Edit" in out
33
  assert "Omni Base" in out
34
  assert "soon-tag" in out
ui.py CHANGED
@@ -9,7 +9,7 @@ import gradio as gr
9
  import preprocessors
10
  from tooltips import TOOLTIPS
11
 
12
- GITHUB_MODEL_ZOO_URL = "https://github.com/Tongyi-MAI/Z-Image#model-zoo"
13
 
14
 
15
  def labeled_label(text: str, info_text: str) -> str:
 
9
  import preprocessors
10
  from tooltips import TOOLTIPS
11
 
12
+ GITHUB_MODEL_ZOO_URL = "https://github.com/Tongyi-MAI/Z-Image#-model-zoo"
13
 
14
 
15
  def labeled_label(text: str, info_text: str) -> str: