Spaces:
Running on Zero
Running on Zero
Commit ·
622f4d0
1
Parent(s): e2f50b1
Stream PiD's 4 internal student-sampler steps (yield after each); 'Upscaling with PiD — step K/4' label
Browse files
app.py
CHANGED
|
@@ -139,24 +139,74 @@ def _taef1_preview(packed_latent: torch.Tensor, H: int, W: int) -> Image.Image:
|
|
| 139 |
return Image.fromarray(arr)
|
| 140 |
|
| 141 |
|
| 142 |
-
def
|
| 143 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 144 |
lq_h, lq_w = baseline_01.shape[-2], baseline_01.shape[-1]
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 158 |
)
|
| 159 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 160 |
|
| 161 |
|
| 162 |
def _evenly_spaced_capture_steps(total_steps: int, num_captures: int) -> list[int]:
|
|
@@ -248,11 +298,15 @@ def generate(
|
|
| 248 |
(baseline_01[0].clamp(0, 1).permute(1, 2, 0).float().cpu().numpy() * 255).astype(np.uint8)
|
| 249 |
)
|
| 250 |
|
| 251 |
-
# ---- PiD upscaling on the final latent ----
|
| 252 |
-
yield gr.update(visible=True, value=zimage_img, label="Upscaling with PiD (4× super-resolution, 4 steps)…"), gr.update(visible=False)
|
| 253 |
final_sigma = float(pipeline.scheduler.sigmas[-1].item())
|
| 254 |
-
|
| 255 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 256 |
|
| 257 |
# ---- Done: hide live preview, show the A/B slider ----
|
| 258 |
yield (
|
|
|
|
| 139 |
return Image.fromarray(arr)
|
| 140 |
|
| 141 |
|
| 142 |
+
def _pid_pixel_to_pil(x: torch.Tensor) -> Image.Image:
|
| 143 |
+
"""PiD pixel-space tensor (B, 3, H, W) in [-1, 1] -> PIL.Image."""
|
| 144 |
+
arr = ((x[0].float().clamp(-1, 1) + 1) * 127.5).permute(1, 2, 0).cpu().numpy().astype(np.uint8)
|
| 145 |
+
return Image.fromarray(arr)
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
def _pid_stream(latent: torch.Tensor, baseline_01: torch.Tensor, sigma: float, caption: str, num_steps: int = PID_INFERENCE_STEPS):
|
| 149 |
+
"""Reimplementation of PiDDistillModel.generate_samples_from_batch that yields
|
| 150 |
+
the current pixel-space tensor after each of the `num_steps` student-sampler
|
| 151 |
+
iterations. Final yield is the clean output."""
|
| 152 |
+
from contextlib import nullcontext
|
| 153 |
+
|
| 154 |
+
B = 1
|
| 155 |
lq_h, lq_w = baseline_01.shape[-2], baseline_01.shape[-1]
|
| 156 |
+
img_h, img_w = lq_h * SR_SCALE, lq_w * SR_SCALE
|
| 157 |
+
|
| 158 |
+
caption_embs, _ = pid_model._encode_text_raw([caption])
|
| 159 |
+
caption_embs = caption_embs.to(**pid_model.tensor_kwargs)
|
| 160 |
+
|
| 161 |
+
lq_video_or_image = (baseline_01 * 2.0 - 1.0).to(dtype=DTYPE, device="cuda")
|
| 162 |
+
lq_latent = latent.to(dtype=DTYPE, device="cuda")
|
| 163 |
+
degrade_sigma_tensor = torch.tensor([sigma], device="cuda", dtype=torch.float32)
|
| 164 |
+
|
| 165 |
+
gen = torch.Generator(device="cuda").manual_seed(0)
|
| 166 |
+
noise = torch.randn(B, 3, img_h, img_w, device="cuda", generator=gen)
|
| 167 |
+
|
| 168 |
+
t_list = pid_model._get_t_list(device=torch.device("cuda"), num_steps=num_steps)
|
| 169 |
+
autocast_ctx = (
|
| 170 |
+
torch.autocast("cuda", dtype=pid_model.autocast_dtype)
|
| 171 |
+
if pid_model.autocast_dtype
|
| 172 |
+
else nullcontext()
|
| 173 |
)
|
| 174 |
+
net = pid_model.net
|
| 175 |
+
net.eval()
|
| 176 |
+
timescale = pid_model.fm_trainer.timescale
|
| 177 |
+
student_sample_type = pid_model.config.student_sample_type
|
| 178 |
+
prediction_type = pid_model.config.prediction_type
|
| 179 |
+
|
| 180 |
+
x = noise
|
| 181 |
+
with torch.no_grad(), autocast_ctx:
|
| 182 |
+
steps_total = len(t_list) - 1
|
| 183 |
+
for step_idx, (t_cur, t_next) in enumerate(zip(t_list[:-1], t_list[1:])):
|
| 184 |
+
t_cur_batch = t_cur.expand(B)
|
| 185 |
+
t_cur_scaled = t_cur_batch * timescale
|
| 186 |
+
v_pred = net(
|
| 187 |
+
x,
|
| 188 |
+
t_cur_scaled,
|
| 189 |
+
caption_embs,
|
| 190 |
+
lq_video_or_image=lq_video_or_image,
|
| 191 |
+
lq_latent=lq_latent,
|
| 192 |
+
degrade_sigma=degrade_sigma_tensor,
|
| 193 |
+
)
|
| 194 |
+
if t_next.item() > 0:
|
| 195 |
+
if student_sample_type == "ode":
|
| 196 |
+
v_for_step = pid_model._net_output_to_velocity(x, v_pred, t_cur_batch, prediction_type)
|
| 197 |
+
dt = t_next - t_cur
|
| 198 |
+
x = x + dt * v_for_step
|
| 199 |
+
else:
|
| 200 |
+
x0_pred = pid_model._velocity_to_x0(x, v_pred, t_cur_batch)
|
| 201 |
+
eps_infer = torch.randn(
|
| 202 |
+
x0_pred.shape, device=x0_pred.device, dtype=x0_pred.dtype, generator=gen
|
| 203 |
+
)
|
| 204 |
+
s = [B] + [1] * (x.ndim - 1)
|
| 205 |
+
t_next_bcast = t_next.reshape(1).expand(s)
|
| 206 |
+
x = (1.0 - t_next_bcast) * x0_pred + t_next_bcast * eps_infer
|
| 207 |
+
else:
|
| 208 |
+
x = pid_model._velocity_to_x0(x, v_pred, t_cur_batch)
|
| 209 |
+
yield step_idx + 1, steps_total, x.clone()
|
| 210 |
|
| 211 |
|
| 212 |
def _evenly_spaced_capture_steps(total_steps: int, num_captures: int) -> list[int]:
|
|
|
|
| 298 |
(baseline_01[0].clamp(0, 1).permute(1, 2, 0).float().cpu().numpy() * 255).astype(np.uint8)
|
| 299 |
)
|
| 300 |
|
| 301 |
+
# ---- PiD upscaling on the final latent, streaming the 4 internal steps ----
|
|
|
|
| 302 |
final_sigma = float(pipeline.scheduler.sigmas[-1].item())
|
| 303 |
+
pid_img = None
|
| 304 |
+
for k, total, x in _pid_stream(final_latent, baseline_01, final_sigma, prompt):
|
| 305 |
+
pid_img = _pid_pixel_to_pil(x)
|
| 306 |
+
yield (
|
| 307 |
+
gr.update(visible=True, value=pid_img, label=f"Upscaling with PiD — step {k}/{total}"),
|
| 308 |
+
gr.update(visible=False),
|
| 309 |
+
)
|
| 310 |
|
| 311 |
# ---- Done: hide live preview, show the A/B slider ----
|
| 312 |
yield (
|