apolinario committed on
Commit
0972cc0
·
1 Parent(s): 4192431

Initial PiD + Z-Image step-by-step denoising demo for ZeroGPU

Browse files
Files changed (3) hide show
  1. README.md +11 -7
  2. app.py +213 -0
  3. requirements.txt +18 -0
README.md CHANGED
@@ -1,13 +1,17 @@
1
  ---
2
- title: Pid
3
- emoji: 🏆
4
- colorFrom: red
5
- colorTo: pink
6
  sdk: gradio
7
- sdk_version: 6.14.0
8
- python_version: '3.13'
9
  app_file: app.py
10
  pinned: false
 
11
  ---
12
 
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
1
  ---
2
+ title: PiD — Z-Image Pixel Diffusion Decoder
3
+ emoji: 🪄
4
+ colorFrom: indigo
5
+ colorTo: red
6
  sdk: gradio
7
+ sdk_version: 5.49.1
8
+ python_version: '3.10'
9
  app_file: app.py
10
  pinned: false
11
+ short_description: Z-Image denoising loop decoded step-by-step by PiD
12
  ---
13
 
14
+ Demo for [NVIDIA PiD](https://github.com/nv-tlabs/PiD) — Pixel Diffusion
15
+ Decoder — paired with [Z-Image](https://huggingface.co/Tongyi-MAI/Z-Image).
16
+ Captures intermediate latents from Z-Image's denoising loop and decodes each one
17
+ with PiD's 4-step distilled pixel-space decoder.
app.py ADDED
@@ -0,0 +1,213 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import subprocess
4
+ import tempfile
5
+
6
+ import spaces
7
+
8
+
9
# Upstream PiD sources and the directory (next to this file) they are cloned into.
PID_REPO_URL = "https://github.com/nv-tlabs/PiD.git"
PID_REPO_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "PiD")

# Bootstrap only when the checkout is absent: shallow-clone the repository,
# then install it in editable mode into the running interpreter.
if not os.path.exists(PID_REPO_DIR):
    print(f"[pid] cloning {PID_REPO_URL} -> {PID_REPO_DIR}", flush=True)
    for _bootstrap_cmd in (
        ["git", "clone", "--depth", "1", PID_REPO_URL, PID_REPO_DIR],
        [sys.executable, "-m", "pip", "install", "-e", PID_REPO_DIR],
    ):
        # check_call raises CalledProcessError on a non-zero exit status,
        # so a failed clone/install aborts startup instead of continuing.
        subprocess.check_call(_bootstrap_cmd)

# PiD's loader resolves paths relative to CWD, so chdir into the repo root.
# The repo root is also put first on sys.path so the `pid` package imports below
# resolve to this checkout.
os.chdir(PID_REPO_DIR)
sys.path.insert(0, PID_REPO_DIR)
20
+
21
+ import torch
22
+ import numpy as np
23
+ import gradio as gr
24
+ from PIL import Image
25
+ from types import SimpleNamespace
26
+ from huggingface_hub import snapshot_download
27
+
28
# Fetch only the files this demo needs from the nvidia/PiD model repo (the
# Flux-1 / Z-Image-compatible 4-step distilled checkpoint and the autoencoder
# weights), placing them under the cloned repo so they land in its expected
# checkpoints/ tree.
snapshot_download(
    repo_id="nvidia/PiD",
    local_dir=PID_REPO_DIR,
    allow_patterns=[
        f"checkpoints/{_entry}"
        for _entry in (
            "PiD_res2k_sr4x_official_flux_distill_4step/*",
            "ae.safetensors",
        )
    ],
)
38
+
39
+ from pid._src.inference.checkpoint_registry import get_pid_checkpoint
40
+ from pid._src.inference.create_dataset import XtCaptureCallback
41
+ from pid._src.inference.pipeline_registry import (
42
+ decode_with_pipeline_vae,
43
+ extract_latent,
44
+ load_pipeline,
45
+ )
46
+ from pid._src.utils.model_loader import load_model_from_checkpoint
47
+
48
+
49
# Demo-wide settings.
DTYPE = torch.bfloat16       # dtype used for the pipeline and for PiD inputs
BACKBONE = "zimage"          # key passed to the PiD pipeline/checkpoint registries
CKPT_TYPE = "2k"             # checkpoint variant requested from the registry
SR_SCALE = 4                 # PiD output side length = SR_SCALE x baseline side length
PID_INFERENCE_STEPS = 4      # sampling steps used by the PiD decoder

# Everything below runs once at import time, before the UI is built.
print("[pid] loading Z-Image pipeline...", flush=True)
pipeline, pipe_cfg = load_pipeline(BACKBONE, dtype=DTYPE)
pipeline.to("cuda")

print("[pid] loading PiD decoder...", flush=True)
pid_meta = get_pid_checkpoint(BACKBONE, CKPT_TYPE)
pid_model, _pid_cfg = load_model_from_checkpoint(
    # The config path is relative; it relies on the earlier chdir into the repo.
    config_file="pid/_src/configs/pid/config.py",
    experiment_name=pid_meta.experiment,
    checkpoint_path=pid_meta.checkpoint_path,
    enable_fsdp=False,
    strict=False,
)
pid_model.eval()
print("[pid] ready", flush=True)
70
+
71
+
72
def _latent_to_pil(tensor: torch.Tensor) -> Image.Image:
    """Convert an image tensor with values in [-1, 1] to an 8-bit PIL image.

    Args:
        tensor: ``[C, H, W]`` tensor, or ``[1, C, H, W]`` (the leading
            singleton batch axis is dropped). Values outside [-1, 1] are
            clamped.

    Returns:
        The image built from the ``[H, W, C]`` uint8 array.

    Raises:
        ValueError: If the tensor is not 3-D after removing a singleton batch
            axis (e.g. a batch with more than one image).
    """
    if tensor.dim() == 4:
        tensor = tensor.squeeze(0)
    if tensor.dim() != 3:
        raise ValueError(
            f"expected a [C, H, W] or [1, C, H, W] tensor, got shape {tuple(tensor.shape)}"
        )
    # Map [-1, 1] -> [0, 255] in float32 (the input may be bfloat16).
    scaled = (tensor.float().clamp(-1, 1) + 1) * 127.5
    # Round before the uint8 cast: a bare astype() truncates toward zero,
    # which systematically darkens the image by up to one level.
    arr = scaled.round().permute(1, 2, 0).cpu().numpy().astype(np.uint8)
    return Image.fromarray(arr)
78
+
79
+
80
def _pid_decode(latent: torch.Tensor, baseline_01: torch.Tensor, sigma: float, caption: str) -> Image.Image:
    """Decode one latent with the PiD model and return the first sample as a PIL image.

    Args:
        latent: Latent passed to PiD as ``LQ_latent``.
        baseline_01: Baseline decode passed to PiD as the low-quality image;
            by its name it is expected in [0, 1] and is rescaled with
            ``x * 2 - 1`` before use. Its last two dimensions set the output size.
        sigma: Value passed to PiD as ``degrade_sigma``.
        caption: Text stored under the model's configured caption key.

    Returns:
        The first generated sample converted by ``_latent_to_pil``.
    """
    # The output is SR_SCALE times the baseline's spatial size.
    low_h, low_w = baseline_01.shape[-2], baseline_01.shape[-1]
    target_size = (low_h * SR_SCALE, low_w * SR_SCALE)

    # Rescale first, then cast/move, so the arithmetic happens in the
    # baseline's own dtype exactly as before.
    lq_image = baseline_01 * 2.0 - 1.0

    batch = {
        pid_model.config.input_caption_key: [caption],
        "LQ_video_or_image": lq_image.to(dtype=DTYPE, device="cuda"),
        "LQ_latent": latent.to(dtype=DTYPE, device="cuda"),
        # degrade_sigma stays float32 regardless of DTYPE.
        "degrade_sigma": torch.tensor([sigma], device="cuda", dtype=torch.float32),
    }
    decoded = pid_model.generate_samples_from_batch(
        batch,
        cfg_scale=1.0,
        num_steps=PID_INFERENCE_STEPS,
        seed=0,
        shift=None,
        image_size=target_size,
    )
    return _latent_to_pil(decoded[0])
98
+
99
+
100
+ def _evenly_spaced_capture_steps(total_steps: int, num_captures: int) -> list[int]:
101
+ """Pick N capture indices spread across [1, total_steps-1]. The final x0 is always added separately."""
102
+ if num_captures <= 0:
103
+ return []
104
+ # avoid 0 (no forward pass yet) and total_steps (== final clean, captured separately)
105
+ raw = np.linspace(1, max(2, total_steps - 1), num_captures + 1)[1:]
106
+ return sorted({int(round(x)) for x in raw})
107
+
108
+
109
@spaces.GPU(duration=240)
def generate(
    prompt: str,
    num_inference_steps: int = 28,
    num_captures: int = 4,
    guidance_scale: float = 5.0,
    seed: int = 0,
    resolution: int = 512,
    progress=gr.Progress(track_tqdm=True),
):
    """Run Z-Image once and decode selected denoising steps with PiD.

    Args:
        prompt: Text prompt; must contain non-whitespace characters.
        num_inference_steps: Number of Z-Image denoising steps.
        num_captures: How many intermediate latents to capture and decode.
        guidance_scale: Guidance scale forwarded to the pipeline.
        seed: Seed for the CUDA generator. ``None`` (e.g. a cleared number
            box) is treated as 0.
        resolution: Side length of the square Z-Image generation.
        progress: Gradio progress reporter.

    Returns:
        A list of ``(PIL image, caption)`` tuples: one per captured
        intermediate step in ascending step order, followed by the decode of
        the final latent.

    Raises:
        gr.Error: If ``prompt`` is empty or whitespace only.
    """
    if not prompt or not prompt.strip():
        raise gr.Error("Please enter a prompt.")

    # UI widgets are not guaranteed to deliver Python ints, so normalise once.
    num_inference_steps = int(num_inference_steps)
    num_captures = int(num_captures)
    resolution = int(resolution)
    # An emptied Seed field arrives as None; fall back to 0 instead of
    # failing with a TypeError inside int().
    seed = 0 if seed is None else int(seed)
    H = W = resolution

    capture_ks = set(_evenly_spaced_capture_steps(num_inference_steps, num_captures))
    progress(0.05, desc="Running Z-Image latent diffusion…")

    # Only install the capture callback when there is something to capture.
    xt_cb = XtCaptureCallback(capture_ks) if capture_ks else None
    generator = torch.Generator(device="cuda").manual_seed(seed)
    gen_kwargs = dict(
        prompt=prompt,
        height=H,
        width=W,
        num_inference_steps=num_inference_steps,
        guidance_scale=float(guidance_scale),
        num_images_per_prompt=1,
        output_type="latent",
        generator=generator,
    )
    # Backbone-specific extras may override the defaults above.
    gen_kwargs.update(pipe_cfg.extra_generate_kwargs)
    if xt_cb is not None:
        gen_kwargs["callback_on_step_end"] = xt_cb
        gen_kwargs["callback_on_step_end_tensor_inputs"] = ["latents"]

    with torch.no_grad():
        raw_output = pipeline(**gen_kwargs)
        final_latent = extract_latent(pipeline, raw_output, pipe_cfg, H, W)

    progress(0.5, desc="Decoding each captured step with PiD…")

    # Work list of (label, latent, sigma): captured steps first, final latent last.
    steps_iter = []
    if xt_cb is not None:
        for K in sorted(xt_cb.captured.keys()):
            xt_packed_cpu, sigma = xt_cb.captured[K]
            xt_packed = xt_packed_cpu.to(device="cuda", dtype=DTYPE)
            # Wrap the captured tensor so it goes through the same extraction
            # path as the pipeline's own output object (read via `.images`).
            xt_latent = extract_latent(pipeline, SimpleNamespace(images=xt_packed), pipe_cfg, H, W)
            steps_iter.append((f"step {K:02d}/{num_inference_steps}", xt_latent, sigma))
    final_sigma = float(pipeline.scheduler.sigmas[-1].item())
    steps_iter.append(("final x₀", final_latent, final_sigma))

    outputs: list[tuple[Image.Image, str]] = []
    total = len(steps_iter)
    for i, (label, latent, sigma) in enumerate(steps_iter):
        # The second half of the progress bar is split evenly across decodes.
        progress(0.5 + 0.5 * (i / total), desc=f"PiD decoding {label}")
        with torch.no_grad():
            baseline_01 = decode_with_pipeline_vae(pipeline, latent, pipe_cfg)
            pid_img = _pid_decode(latent, baseline_01, sigma, prompt)
        outputs.append((pid_img, f"{label} (σ={sigma:.3f})"))

    return outputs
172
+
173
+
174
# Markdown shown at the top of the page.
DESCRIPTION = """
# 🪄 PiD — Pixel Diffusion Decoder for Z-Image

Each tile shows what NVIDIA's [PiD](https://github.com/nv-tlabs/PiD) (a 4-step
distilled pixel-space diffusion decoder) reconstructs from Z-Image's denoising
loop at progressive timesteps. The first few tiles come from noisy intermediate
latents (`xt`); the last tile is decoded from the final clean `x₀`.

PiD upsamples 4× during decode, so a 512² Z-Image latent track becomes a
2048² super-resolved image.
"""

# Layout: controls in a narrow left column, result gallery in a wider right one.
with gr.Blocks() as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Row():
        with gr.Column(scale=1):
            prompt = gr.Textbox(
                label="Prompt",
                value=(
                    "A photorealistic close-up of a brown tabby cat sitting on a rustic "
                    "wooden table, morning light, ultra-detailed fur"
                ),
                lines=3,
            )
            with gr.Row():
                resolution = gr.Slider(
                    label="Z-Image resolution",
                    minimum=256,
                    maximum=1024,
                    step=128,
                    value=512,
                )
                num_inference_steps = gr.Slider(
                    label="Z-Image steps",
                    minimum=8,
                    maximum=50,
                    step=1,
                    value=28,
                )
            with gr.Row():
                num_captures = gr.Slider(
                    label="Intermediate captures",
                    minimum=1,
                    maximum=8,
                    step=1,
                    value=4,
                )
                guidance_scale = gr.Slider(
                    label="Guidance",
                    minimum=1.0,
                    maximum=10.0,
                    step=0.5,
                    value=5.0,
                )
            seed = gr.Number(label="Seed", value=0, precision=0)
            run = gr.Button("Run", variant="primary")
        with gr.Column(scale=2):
            gallery = gr.Gallery(
                label="PiD-decoded denoising trajectory",
                columns=2,
                object_fit="contain",
            )

    # Input order must match generate()'s positional parameters.
    run.click(
        fn=generate,
        inputs=[prompt, num_inference_steps, num_captures, guidance_scale, seed, resolution],
        outputs=[gallery],
    )

if __name__ == "__main__":
    demo.queue().launch()
requirements.txt ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ diffusers>=0.37.0
2
+ transformers==4.57.1
3
+ sentencepiece
4
+ safetensors
5
+ hydra-core==1.3.2
6
+ omegaconf==2.3.0
7
+ attrs
8
+ einops
9
+ loguru
10
+ termcolor
11
+ fvcore
12
+ iopath
13
+ pynvml
14
+ imageio
15
+ opencv-python-headless
16
+ pandas
17
+ numpy<2
18
+ pillow