apolinario committed on
Commit
afb0b5a
·
1 Parent(s): 6cd8e25

Stream TAEF1 previews from Z-Image (via a worker thread + queue) into a live-preview gr.Image; swap to the gr.Gallery only once all PiD steps are done

Browse files
Files changed (1) hide show
  1. app.py +86 -29
app.py CHANGED
@@ -100,6 +100,11 @@ _gm.Gemma2Model.forward = _patched_gemma2_forward
100
  pipeline, pipe_cfg = load_pipeline(BACKBONE, dtype=DTYPE)
101
  pipeline.to("cuda")
102
 
 
 
 
 
 
103
  print("[pid] loading PiD decoder...", flush=True)
104
  pid_meta = get_pid_checkpoint(BACKBONE, CKPT_TYPE)
105
  pid_model, _pid_cfg = load_model_from_checkpoint(
@@ -121,6 +126,19 @@ def _latent_to_pil(tensor: torch.Tensor) -> Image.Image:
121
  return Image.fromarray(arr)
122
 
123
 
 
 
 
 
 
 
 
 
 
 
 
 
 
124
  def _pid_decode(latent: torch.Tensor, baseline_01: torch.Tensor, sigma: float, caption: str) -> Image.Image:
125
  baseline_neg1_1 = baseline_01 * 2.0 - 1.0
126
  lq_h, lq_w = baseline_01.shape[-2], baseline_01.shape[-1]
@@ -150,6 +168,10 @@ def _evenly_spaced_capture_steps(total_steps: int, num_captures: int) -> list[in
150
  return sorted({int(round(x)) for x in raw})
151
 
152
 
 
 
 
 
153
  @spaces.GPU(duration=240)
154
  def generate(
155
  prompt: str,
@@ -158,41 +180,73 @@ def generate(
158
  guidance_scale: float = 5.0,
159
  seed: int = 0,
160
  resolution: int = 512,
161
- progress=gr.Progress(track_tqdm=True),
162
  ):
163
  if not prompt or not prompt.strip():
164
  raise gr.Error("Please enter a prompt.")
165
 
166
  num_inference_steps = int(num_inference_steps)
167
  num_captures = int(num_captures)
168
- resolution = int(resolution)
169
- H = W = resolution
170
 
171
- capture_ks = set(_evenly_spaced_capture_steps(num_inference_steps, num_captures))
172
- progress(0.05, desc="Running Z-Image latent diffusion…")
173
 
 
174
  xt_cb = XtCaptureCallback(capture_ks) if capture_ks else None
175
- generator = torch.Generator(device="cuda").manual_seed(int(seed))
176
- gen_kwargs = dict(
177
- prompt=prompt,
178
- height=H,
179
- width=W,
180
- num_inference_steps=num_inference_steps,
181
- guidance_scale=float(guidance_scale),
182
- num_images_per_prompt=1,
183
- output_type="latent",
184
- generator=generator,
185
- )
186
- gen_kwargs.update(pipe_cfg.extra_generate_kwargs)
187
- if xt_cb is not None:
188
- gen_kwargs["callback_on_step_end"] = xt_cb
189
- gen_kwargs["callback_on_step_end_tensor_inputs"] = ["latents"]
190
 
191
- with torch.no_grad():
192
- raw_output = pipeline(**gen_kwargs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
193
  final_latent = extract_latent(pipeline, raw_output, pipe_cfg, H, W)
194
 
195
- progress(0.5, desc="Decoding each captured step with PiD…")
196
  steps_iter = []
197
  if xt_cb is not None:
198
  for K in sorted(xt_cb.captured.keys()):
@@ -204,14 +258,16 @@ def generate(
204
  steps_iter.append(("final x₀", final_latent, final_sigma))
205
 
206
  outputs: list[tuple[Image.Image, str]] = []
207
- total = len(steps_iter)
208
- for i, (label, latent, sigma) in enumerate(steps_iter):
209
- progress(0.5 + 0.5 * (i / total), desc=f"PiD decoding {label}")
210
  with torch.no_grad():
211
  baseline_01 = decode_with_pipeline_vae(pipeline, latent, pipe_cfg)
212
  pid_img = _pid_decode(latent, baseline_01, sigma, prompt)
213
  outputs.append((pid_img, f"{label} (σ={sigma:.3f})"))
214
- yield outputs
 
 
 
 
215
 
216
 
217
  DESCRIPTION = """
@@ -246,12 +302,13 @@ with gr.Blocks(theme=gr.themes.Citrus(), css=CSS) as demo:
246
  seed = gr.Number(label="Seed", value=0, precision=0)
247
  run = gr.Button("Run", variant="primary")
248
  with gr.Column(scale=2):
249
- gallery = gr.Gallery(label="PiD-decoded denoising trajectory", columns=2, object_fit="contain")
 
250
 
251
  run.click(
252
  fn=generate,
253
  inputs=[prompt, num_inference_steps, num_captures, guidance_scale, seed, resolution],
254
- outputs=[gallery],
255
  )
256
 
257
  if __name__ == "__main__":
 
100
  pipeline, pipe_cfg = load_pipeline(BACKBONE, dtype=DTYPE)
101
  pipeline.to("cuda")
102
 
103
+ print("[pid] loading TAEF1 (fast preview decoder)...", flush=True)
104
+ from diffusers import AutoencoderTiny
105
+ taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=DTYPE).to("cuda")
106
+ taef1.eval()
107
+
108
  print("[pid] loading PiD decoder...", flush=True)
109
  pid_meta = get_pid_checkpoint(BACKBONE, CKPT_TYPE)
110
  pid_model, _pid_cfg = load_model_from_checkpoint(
 
126
  return Image.fromarray(arr)
127
 
128
 
129
def _taef1_preview(packed_latent: torch.Tensor, H: int, W: int) -> Image.Image:
    """Decode a pipeline latent into a quick preview image using TAEF1.

    The latent is wrapped in a ``SimpleNamespace(images=...)`` so it can be
    passed through ``extract_latent`` the same way a pipeline output is, then
    un-normalised with the pipeline VAE's ``scaling_factor`` / ``shift_factor``
    and decoded by the module-level ``taef1`` model.

    Args:
        packed_latent: Latent tensor as delivered by the step-end callback.
        H: Image height forwarded to ``extract_latent``.
        W: Image width forwarded to ``extract_latent``.

    Returns:
        A PIL image built from the first item of the decoded batch.
    """
    with torch.no_grad():
        wrapped_output = SimpleNamespace(images=packed_latent)
        latent = extract_latent(pipeline, wrapped_output, pipe_cfg, H, W)

        vae_cfg = pipeline.vae.config
        # A missing, None, or zero shift_factor all mean "no shift".
        shift_factor = getattr(vae_cfg, "shift_factor", None) or 0.0
        raw_latent = latent.to(dtype=DTYPE) / vae_cfg.scaling_factor + shift_factor

        decoded = taef1.decode(raw_latent).sample
        # Work in float32 and map the clamped [-1, 1] range onto [0, 1].
        unit_range = (decoded.float().clamp(-1, 1) + 1) / 2
        # CHW -> HWC for the first batch item, scaled to 8-bit pixel values.
        hwc = unit_range[0].permute(1, 2, 0).cpu().numpy()
        pixels = (hwc * 255).astype(np.uint8)
        return Image.fromarray(pixels)
140
+
141
+
142
  def _pid_decode(latent: torch.Tensor, baseline_01: torch.Tensor, sigma: float, caption: str) -> Image.Image:
143
  baseline_neg1_1 = baseline_01 * 2.0 - 1.0
144
  lq_h, lq_w = baseline_01.shape[-2], baseline_01.shape[-1]
 
168
  return sorted({int(round(x)) for x in raw})
169
 
170
 
171
+ import threading
172
+ import queue as _queue
173
+
174
+
175
  @spaces.GPU(duration=240)
176
  def generate(
177
  prompt: str,
 
180
  guidance_scale: float = 5.0,
181
  seed: int = 0,
182
  resolution: int = 512,
 
183
  ):
184
  if not prompt or not prompt.strip():
185
  raise gr.Error("Please enter a prompt.")
186
 
187
  num_inference_steps = int(num_inference_steps)
188
  num_captures = int(num_captures)
189
+ H = W = int(resolution)
 
190
 
191
+ # initial: show the live-preview image, hide the final gallery
192
+ yield gr.update(visible=True, value=None), gr.update(visible=False, value=None)
193
 
194
+ capture_ks = set(_evenly_spaced_capture_steps(num_inference_steps, num_captures))
195
  xt_cb = XtCaptureCallback(capture_ks) if capture_ks else None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
196
 
197
+ # ---- Run Z-Image in a thread; stream taef1 previews via a queue ----
198
+ preview_q: "_queue.Queue" = _queue.Queue()
199
+ _DONE = object()
200
+
201
+ def streaming_cb(pipe, step_index, timestep, callback_kwargs):
202
+ if xt_cb is not None:
203
+ xt_cb(pipe, step_index, timestep, callback_kwargs)
204
+ try:
205
+ preview = _taef1_preview(callback_kwargs["latents"], H, W)
206
+ preview_q.put((step_index, preview))
207
+ except Exception as e:
208
+ print(f"[pid] taef1 preview failed at step {step_index}: {e}", flush=True)
209
+ return callback_kwargs
210
+
211
+ def run_pipeline():
212
+ gen_torch = torch.Generator(device="cuda").manual_seed(int(seed))
213
+ gen_kwargs = dict(
214
+ prompt=prompt,
215
+ height=H,
216
+ width=W,
217
+ num_inference_steps=num_inference_steps,
218
+ guidance_scale=float(guidance_scale),
219
+ num_images_per_prompt=1,
220
+ output_type="latent",
221
+ generator=gen_torch,
222
+ callback_on_step_end=streaming_cb,
223
+ callback_on_step_end_tensor_inputs=["latents"],
224
+ )
225
+ gen_kwargs.update(pipe_cfg.extra_generate_kwargs)
226
+ try:
227
+ with torch.no_grad():
228
+ out = pipeline(**gen_kwargs)
229
+ preview_q.put((_DONE, out))
230
+ except Exception as e:
231
+ preview_q.put((_DONE, e))
232
+
233
+ thread = threading.Thread(target=run_pipeline, daemon=True)
234
+ thread.start()
235
+
236
+ raw_output = None
237
+ while True:
238
+ step_index, payload = preview_q.get()
239
+ if step_index is _DONE:
240
+ if isinstance(payload, Exception):
241
+ raise payload
242
+ raw_output = payload
243
+ break
244
+ yield gr.update(visible=True, value=payload), gr.update(visible=False)
245
+
246
+ thread.join()
247
  final_latent = extract_latent(pipeline, raw_output, pipe_cfg, H, W)
248
 
249
+ # ---- PiD per-step decode (sequentially) ----
250
  steps_iter = []
251
  if xt_cb is not None:
252
  for K in sorted(xt_cb.captured.keys()):
 
258
  steps_iter.append(("final x₀", final_latent, final_sigma))
259
 
260
  outputs: list[tuple[Image.Image, str]] = []
261
+ for label, latent, sigma in steps_iter:
 
 
262
  with torch.no_grad():
263
  baseline_01 = decode_with_pipeline_vae(pipeline, latent, pipe_cfg)
264
  pid_img = _pid_decode(latent, baseline_01, sigma, prompt)
265
  outputs.append((pid_img, f"{label} (σ={sigma:.3f})"))
266
+ # Flash the latest PiD output in the live-preview image during PiD decoding too
267
+ yield gr.update(visible=True, value=pid_img), gr.update(visible=False)
268
+
269
+ # ---- Done: hide live preview, show the final gallery ----
270
+ yield gr.update(visible=False, value=None), gr.update(visible=True, value=outputs)
271
 
272
 
273
  DESCRIPTION = """
 
302
  seed = gr.Number(label="Seed", value=0, precision=0)
303
  run = gr.Button("Run", variant="primary")
304
  with gr.Column(scale=2):
305
+ live_preview = gr.Image(label="Live preview", visible=True, show_label=True, type="pil")
306
+ gallery = gr.Gallery(label="PiD-decoded denoising trajectory", visible=False, columns=2, object_fit="contain")
307
 
308
  run.click(
309
  fn=generate,
310
  inputs=[prompt, num_inference_steps, num_captures, guidance_scale, seed, resolution],
311
+ outputs=[live_preview, gallery],
312
  )
313
 
314
  if __name__ == "__main__":