apolinario committed on
Commit
b0fc9e3
·
1 Parent(s): bcabac3

Replace transformers' vmap-based mask broadcaster with explicit broadcasting (ZeroGPU __torch_function__ can't fake-alloc inside vmap)

Browse files
Files changed (1) hide show
  1. app.py +22 -19
app.py CHANGED
@@ -53,26 +53,29 @@ SR_SCALE = 4
53
  PID_INFERENCE_STEPS = 4
54
 
55
  print("[pid] loading Z-Image pipeline...", flush=True)
56
- # transformers 4.57's SDPA causal-mask uses torch.vmap, which clashes with
57
- # ZeroGPU's __torch_function__ hijack during fake tensor allocation. Force
58
- # eager attention on the text encoder to skip the vmap codepath.
59
- from diffusers import ZImagePipeline
60
- from transformers import Qwen3Model
61
-
62
- _text_encoder = Qwen3Model.from_pretrained(
63
- "Tongyi-MAI/Z-Image",
64
- subfolder="text_encoder",
65
- torch_dtype=DTYPE,
66
- attn_implementation="eager",
67
- )
68
- pipeline = ZImagePipeline.from_pretrained(
69
- "Tongyi-MAI/Z-Image",
70
- torch_dtype=DTYPE,
71
- text_encoder=_text_encoder,
72
- )
 
 
 
 
 
73
  pipeline.to("cuda")
74
- from pid._src.inference.pipeline_registry import get_config as _get_pipe_cfg
75
- pipe_cfg = _get_pipe_cfg(BACKBONE)
76
 
77
  print("[pid] loading PiD decoder...", flush=True)
78
  pid_meta = get_pid_checkpoint(BACKBONE, CKPT_TYPE)
 
53
  PID_INFERENCE_STEPS = 4
54
 
55
  print("[pid] loading Z-Image pipeline...", flush=True)
56
+ # transformers 4.57's SDPA / eager mask builders both broadcast the mask
57
+ # function over (b, h, q, k) via torch.vmap, which trips ZeroGPU's
58
+ # __torch_function__ hijack when it tries to fake-allocate the indexed
59
+ # tensors. Replace vmap with explicit broadcasting — same result, same speed,
60
+ # no functorch transform context.
61
+ from transformers import masking_utils as _mu
62
+
63
def _broadcasting_vmap_for_bhqkv(mask_function, bh_indices: bool = True):
    """Build a vmap-free broadcaster for a ``(b, h, q, k)`` mask function.

    Intended as a drop-in replacement for a ``torch.vmap``-based broadcaster:
    instead of nesting vmap transforms, the 1-D index tensors are reshaped so
    that ordinary broadcasting evaluates ``mask_function`` over the whole
    index grid in a single call.

    Args:
        mask_function: Callable ``(batch_idx, head_idx, q_idx, kv_idx)`` that
            returns a mask. It is assumed to be written with element-wise /
            broadcasting-friendly tensor ops and to return a tensor that
            supports ``.expand`` -- TODO confirm for custom mask functions.
        bh_indices: When True, broadcast over batch and head as well and
            produce a ``(B, H, Q, K)`` mask. When False, ``b`` and ``h`` are
            forwarded untouched and the mask is ``(Q, K)``.

    Returns:
        A function ``wrapped(b, h, q, k)`` that returns the mask evaluated on
        the full index grid.
    """

    def wrapped(b, h, q, k):
        if bh_indices:
            # Give each index tensor its own axis so that the mask function's
            # element-wise ops broadcast to (B, H, Q, K).
            mask = mask_function(
                b[:, None, None, None],
                h[None, :, None, None],
                q[None, None, :, None],
                k[None, None, None, :],
            )
            full_shape = (b.shape[0], h.shape[0], q.shape[0], k.shape[0])
        else:
            # b and h are passed through as-is; only q and k are broadcast.
            mask = mask_function(b, h, q[:, None], k[None, :])
            full_shape = (q.shape[0], k.shape[0])
        # A mask function that ignores some of its indices (e.g. a pure causal
        # mask never reads b or h) only broadcasts over the axes it touches,
        # leaving size-1 axes behind. vmap materializes every mapped axis, so
        # expand (a zero-copy view) to keep the output shape identical.
        return mask.expand(full_shape)

    return wrapped
74
+
75
+ _mu._vmap_for_bhqkv = _broadcasting_vmap_for_bhqkv
76
+
77
+ pipeline, pipe_cfg = load_pipeline(BACKBONE, dtype=DTYPE)
78
  pipeline.to("cuda")
 
 
79
 
80
  print("[pid] loading PiD decoder...", flush=True)
81
  pid_meta = get_pid_checkpoint(BACKBONE, CKPT_TYPE)