apolinario committed on
Commit
6cd8e25
·
1 Parent(s): 9014add

Fix _latent_to_pil squeeze dim (T=1, dim=1 not dim=0); stream gallery as each PiD step completes

Browse files
Files changed (1) hide show
  1. app.py +5 -6
app.py CHANGED
@@ -114,9 +114,9 @@ print("[pid] ready", flush=True)
114
 
115
 
116
  def _latent_to_pil(tensor: torch.Tensor) -> Image.Image:
117
- """[C, H, W] in [-1, 1] -> PIL.Image."""
118
  if tensor.dim() == 4:
119
- tensor = tensor.squeeze(0)
120
  arr = ((tensor.float().clamp(-1, 1) + 1) * 127.5).permute(1, 2, 0).cpu().numpy().astype(np.uint8)
121
  return Image.fromarray(arr)
122
 
@@ -193,7 +193,6 @@ def generate(
193
  final_latent = extract_latent(pipeline, raw_output, pipe_cfg, H, W)
194
 
195
  progress(0.5, desc="Decoding each captured step with PiD…")
196
- outputs: list[tuple[Image.Image, str]] = []
197
  steps_iter = []
198
  if xt_cb is not None:
199
  for K in sorted(xt_cb.captured.keys()):
@@ -202,8 +201,9 @@ def generate(
202
  xt_latent = extract_latent(pipeline, SimpleNamespace(images=xt_packed), pipe_cfg, H, W)
203
  steps_iter.append((f"step {K:02d}/{num_inference_steps}", xt_latent, sigma))
204
  final_sigma = float(pipeline.scheduler.sigmas[-1].item())
205
- steps_iter.append((f"final x₀", final_latent, final_sigma))
206
 
 
207
  total = len(steps_iter)
208
  for i, (label, latent, sigma) in enumerate(steps_iter):
209
  progress(0.5 + 0.5 * (i / total), desc=f"PiD decoding {label}")
@@ -211,8 +211,7 @@ def generate(
211
  baseline_01 = decode_with_pipeline_vae(pipeline, latent, pipe_cfg)
212
  pid_img = _pid_decode(latent, baseline_01, sigma, prompt)
213
  outputs.append((pid_img, f"{label} (σ={sigma:.3f})"))
214
-
215
- return outputs
216
 
217
 
218
  DESCRIPTION = """
 
114
 
115
 
116
  def _latent_to_pil(tensor: torch.Tensor) -> Image.Image:
117
+ """PiD output is (C, T, H, W) with T=1 for image -> PIL.Image."""
118
  if tensor.dim() == 4:
119
+ tensor = tensor.squeeze(1)
120
  arr = ((tensor.float().clamp(-1, 1) + 1) * 127.5).permute(1, 2, 0).cpu().numpy().astype(np.uint8)
121
  return Image.fromarray(arr)
122
 
 
193
  final_latent = extract_latent(pipeline, raw_output, pipe_cfg, H, W)
194
 
195
  progress(0.5, desc="Decoding each captured step with PiD…")
 
196
  steps_iter = []
197
  if xt_cb is not None:
198
  for K in sorted(xt_cb.captured.keys()):
 
201
  xt_latent = extract_latent(pipeline, SimpleNamespace(images=xt_packed), pipe_cfg, H, W)
202
  steps_iter.append((f"step {K:02d}/{num_inference_steps}", xt_latent, sigma))
203
  final_sigma = float(pipeline.scheduler.sigmas[-1].item())
204
+ steps_iter.append(("final x₀", final_latent, final_sigma))
205
 
206
+ outputs: list[tuple[Image.Image, str]] = []
207
  total = len(steps_iter)
208
  for i, (label, latent, sigma) in enumerate(steps_iter):
209
  progress(0.5 + 0.5 * (i / total), desc=f"PiD decoding {label}")
 
211
  baseline_01 = decode_with_pipeline_vae(pipeline, latent, pipe_cfg)
212
  pid_img = _pid_decode(latent, baseline_01, sigma, prompt)
213
  outputs.append((pid_img, f"{label} (σ={sigma:.3f})"))
214
+ yield outputs
 
215
 
216
 
217
  DESCRIPTION = """