apolinario commited on
Commit
622f4d0
·
1 Parent(s): e2f50b1

Stream PiD's 4 internal student-sampler steps (yield after each); 'Upscaling with PiD — step K/4' label

Browse files
Files changed (1) hide show
  1. app.py +74 -20
app.py CHANGED
@@ -139,24 +139,74 @@ def _taef1_preview(packed_latent: torch.Tensor, H: int, W: int) -> Image.Image:
139
  return Image.fromarray(arr)
140
 
141
 
142
- def _pid_decode(latent: torch.Tensor, baseline_01: torch.Tensor, sigma: float, caption: str) -> Image.Image:
143
- baseline_neg1_1 = baseline_01 * 2.0 - 1.0
 
 
 
 
 
 
 
 
 
 
 
144
  lq_h, lq_w = baseline_01.shape[-2], baseline_01.shape[-1]
145
- data_batch = {
146
- pid_model.config.input_caption_key: [caption],
147
- "LQ_video_or_image": baseline_neg1_1.to(dtype=DTYPE, device="cuda"),
148
- "LQ_latent": latent.to(dtype=DTYPE, device="cuda"),
149
- "degrade_sigma": torch.tensor([sigma], device="cuda", dtype=torch.float32),
150
- }
151
- samples = pid_model.generate_samples_from_batch(
152
- data_batch,
153
- cfg_scale=1.0,
154
- num_steps=PID_INFERENCE_STEPS,
155
- seed=0,
156
- shift=None,
157
- image_size=(lq_h * SR_SCALE, lq_w * SR_SCALE),
 
 
 
 
158
  )
159
- return _latent_to_pil(samples[0])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
160
 
161
 
162
  def _evenly_spaced_capture_steps(total_steps: int, num_captures: int) -> list[int]:
@@ -248,11 +298,15 @@ def generate(
248
  (baseline_01[0].clamp(0, 1).permute(1, 2, 0).float().cpu().numpy() * 255).astype(np.uint8)
249
  )
250
 
251
- # ---- PiD upscaling on the final latent ----
252
- yield gr.update(visible=True, value=zimage_img, label="Upscaling with PiD (4× super-resolution, 4 steps)…"), gr.update(visible=False)
253
  final_sigma = float(pipeline.scheduler.sigmas[-1].item())
254
- with torch.no_grad():
255
- pid_img = _pid_decode(final_latent, baseline_01, final_sigma, prompt)
 
 
 
 
 
256
 
257
  # ---- Done: hide live preview, show the A/B slider ----
258
  yield (
 
139
  return Image.fromarray(arr)
140
 
141
 
142
+ def _pid_pixel_to_pil(x: torch.Tensor) -> Image.Image:
143
+ """PiD pixel-space tensor (B, 3, H, W) in [-1, 1] -> PIL.Image."""
144
+ arr = ((x[0].float().clamp(-1, 1) + 1) * 127.5).permute(1, 2, 0).cpu().numpy().astype(np.uint8)
145
+ return Image.fromarray(arr)
146
+
147
+
148
+ def _pid_stream(latent: torch.Tensor, baseline_01: torch.Tensor, sigma: float, caption: str, num_steps: int = PID_INFERENCE_STEPS):
149
+ """Reimplementation of PiDDistillModel.generate_samples_from_batch that yields
150
+ the current pixel-space tensor after each of the `num_steps` student-sampler
151
+ iterations. Final yield is the clean output."""
152
+ from contextlib import nullcontext
153
+
154
+ B = 1
155
  lq_h, lq_w = baseline_01.shape[-2], baseline_01.shape[-1]
156
+ img_h, img_w = lq_h * SR_SCALE, lq_w * SR_SCALE
157
+
158
+ caption_embs, _ = pid_model._encode_text_raw([caption])
159
+ caption_embs = caption_embs.to(**pid_model.tensor_kwargs)
160
+
161
+ lq_video_or_image = (baseline_01 * 2.0 - 1.0).to(dtype=DTYPE, device="cuda")
162
+ lq_latent = latent.to(dtype=DTYPE, device="cuda")
163
+ degrade_sigma_tensor = torch.tensor([sigma], device="cuda", dtype=torch.float32)
164
+
165
+ gen = torch.Generator(device="cuda").manual_seed(0)
166
+ noise = torch.randn(B, 3, img_h, img_w, device="cuda", generator=gen)
167
+
168
+ t_list = pid_model._get_t_list(device=torch.device("cuda"), num_steps=num_steps)
169
+ autocast_ctx = (
170
+ torch.autocast("cuda", dtype=pid_model.autocast_dtype)
171
+ if pid_model.autocast_dtype
172
+ else nullcontext()
173
  )
174
+ net = pid_model.net
175
+ net.eval()
176
+ timescale = pid_model.fm_trainer.timescale
177
+ student_sample_type = pid_model.config.student_sample_type
178
+ prediction_type = pid_model.config.prediction_type
179
+
180
+ x = noise
181
+ with torch.no_grad(), autocast_ctx:
182
+ steps_total = len(t_list) - 1
183
+ for step_idx, (t_cur, t_next) in enumerate(zip(t_list[:-1], t_list[1:])):
184
+ t_cur_batch = t_cur.expand(B)
185
+ t_cur_scaled = t_cur_batch * timescale
186
+ v_pred = net(
187
+ x,
188
+ t_cur_scaled,
189
+ caption_embs,
190
+ lq_video_or_image=lq_video_or_image,
191
+ lq_latent=lq_latent,
192
+ degrade_sigma=degrade_sigma_tensor,
193
+ )
194
+ if t_next.item() > 0:
195
+ if student_sample_type == "ode":
196
+ v_for_step = pid_model._net_output_to_velocity(x, v_pred, t_cur_batch, prediction_type)
197
+ dt = t_next - t_cur
198
+ x = x + dt * v_for_step
199
+ else:
200
+ x0_pred = pid_model._velocity_to_x0(x, v_pred, t_cur_batch)
201
+ eps_infer = torch.randn(
202
+ x0_pred.shape, device=x0_pred.device, dtype=x0_pred.dtype, generator=gen
203
+ )
204
+ s = [B] + [1] * (x.ndim - 1)
205
+ t_next_bcast = t_next.reshape(1).expand(s)
206
+ x = (1.0 - t_next_bcast) * x0_pred + t_next_bcast * eps_infer
207
+ else:
208
+ x = pid_model._velocity_to_x0(x, v_pred, t_cur_batch)
209
+ yield step_idx + 1, steps_total, x.clone()
210
 
211
 
212
  def _evenly_spaced_capture_steps(total_steps: int, num_captures: int) -> list[int]:
 
298
  (baseline_01[0].clamp(0, 1).permute(1, 2, 0).float().cpu().numpy() * 255).astype(np.uint8)
299
  )
300
 
301
+ # ---- PiD upscaling on the final latent, streaming the 4 internal steps ----
 
302
  final_sigma = float(pipeline.scheduler.sigmas[-1].item())
303
+ pid_img = None
304
+ for k, total, x in _pid_stream(final_latent, baseline_01, final_sigma, prompt):
305
+ pid_img = _pid_pixel_to_pil(x)
306
+ yield (
307
+ gr.update(visible=True, value=pid_img, label=f"Upscaling with PiD — step {k}/{total}"),
308
+ gr.update(visible=False),
309
+ )
310
 
311
  # ---- Done: hide live preview, show the A/B slider ----
312
  yield (