Spaces:
Running on Zero
Running on Zero
Commit ·
b0fc9e3
1
Parent(s): bcabac3
Replace transformers' vmap-based mask broadcaster with explicit broadcasting (ZeroGPU __torch_function__ can't fake-alloc inside vmap)
Browse files
app.py
CHANGED
|
@@ -53,26 +53,29 @@ SR_SCALE = 4
|
|
| 53 |
PID_INFERENCE_STEPS = 4
|
| 54 |
|
| 55 |
print("[pid] loading Z-Image pipeline...", flush=True)
|
| 56 |
-
# transformers 4.57's SDPA
|
| 57 |
-
#
|
| 58 |
-
#
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 73 |
pipeline.to("cuda")
|
| 74 |
-
from pid._src.inference.pipeline_registry import get_config as _get_pipe_cfg
|
| 75 |
-
pipe_cfg = _get_pipe_cfg(BACKBONE)
|
| 76 |
|
| 77 |
print("[pid] loading PiD decoder...", flush=True)
|
| 78 |
pid_meta = get_pid_checkpoint(BACKBONE, CKPT_TYPE)
|
|
|
|
| 53 |
PID_INFERENCE_STEPS = 4
|
| 54 |
|
| 55 |
print("[pid] loading Z-Image pipeline...", flush=True)
|
| 56 |
+
# transformers 4.57's SDPA / eager mask builders both broadcast the mask
|
| 57 |
+
# function over (b, h, q, k) via torch.vmap, which trips ZeroGPU's
|
| 58 |
+
# __torch_function__ hijack when it tries to fake-allocate the indexed
|
| 59 |
+
# tensors. Replace vmap with explicit broadcasting — same result, same speed,
|
| 60 |
+
# no functorch transform context.
|
| 61 |
+
from transformers import masking_utils as _mu
|
| 62 |
+
|
| 63 |
+
def _broadcasting_vmap_for_bhqkv(mask_function, bh_indices: bool = True):
|
| 64 |
+
def wrapped(b, h, q, k):
|
| 65 |
+
if bh_indices:
|
| 66 |
+
return mask_function(
|
| 67 |
+
b[:, None, None, None],
|
| 68 |
+
h[None, :, None, None],
|
| 69 |
+
q[None, None, :, None],
|
| 70 |
+
k[None, None, None, :],
|
| 71 |
+
)
|
| 72 |
+
return mask_function(b, h, q[:, None], k[None, :])
|
| 73 |
+
return wrapped
|
| 74 |
+
|
| 75 |
+
_mu._vmap_for_bhqkv = _broadcasting_vmap_for_bhqkv
|
| 76 |
+
|
| 77 |
+
pipeline, pipe_cfg = load_pipeline(BACKBONE, dtype=DTYPE)
|
| 78 |
pipeline.to("cuda")
|
|
|
|
|
|
|
| 79 |
|
| 80 |
print("[pid] loading PiD decoder...", flush=True)
|
| 81 |
pid_meta = get_pid_checkpoint(BACKBONE, CKPT_TYPE)
|