123123aa123 committed
Commit 2a9fea7 · verified · 1 Parent(s): 9c8a5bf

Update app.py

Files changed (1): app.py (+82 −76)

app.py CHANGED
@@ -136,7 +136,7 @@ def load_models():
      #wan_pipe.to(device)
      #wan_pipe.to(dtype=torch.bfloat16)
  
- 
  # =========================
  # Renderer
  # =========================
@@ -266,20 +266,21 @@ def build_estimate_rel(x, y, z, phi, theta):
  # =========================
  
  @spaces.GPU
- def infer(image, prompt, seed):
- 
-     load_models()
  
      img = image.convert("RGB")
- 
      TARGET_H, TARGET_W = img.size[1], img.size[0]
      TARGET_H = TARGET_H // 32 * 32
      TARGET_W = TARGET_W // 32 * 32
- 
      img = img.resize((TARGET_W, TARGET_H), Image.BICUBIC)
  
      all_steps = generate_all_motions_from_prompt(prompt, num_frames=81)
- 
      cam_idx = list(range(81))
      traj = [build_estimate_rel(*all_steps[idx]) for idx in cam_idx]
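Both the removed `infer` and its replacements later in this diff snap the working resolution down to the nearest multiple of 32 before resizing, presumably because the downstream VGGT/Wan pipelines need dimensions divisible by 32. A minimal, self-contained sketch of that rounding step (names mirror the diff; nothing else is assumed):

```python
from PIL import Image

def snap_to_multiple_of_32(img: Image.Image) -> Image.Image:
    # Floor-divide then multiply to round each side down to a multiple of 32.
    target_w = img.width // 32 * 32
    target_h = img.height // 32 * 32
    return img.resize((target_w, target_h), Image.BICUBIC)

# e.g. a 1000x750 input becomes 992x736
print(snap_to_multiple_of_32(Image.new("RGB", (1000, 750))).size)
```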
@@ -287,16 +288,11 @@ def infer(image, prompt, seed):
      first_frame = load_and_preprocess_images(first_frame)
      first_frame = first_frame.to(device)
  
- 
- 
      with torch.no_grad():
          with torch.cuda.amp.autocast(dtype=dtype):
              predictions = vggt_model(first_frame)
- 
      extrinsic, intrinsic = pose_encoding_to_extri_intri(predictions["pose_enc"], first_frame.shape[-2:])
- 
      first_frame_world_points = predictions["world_points"][0][0]
- 
      focals = intrinsic[0][0][:2, :2].diag().unsqueeze(0).to(device)
      principal_points = intrinsic[0][0][:2, 2].unsqueeze(0).to(device)
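`pose_encoding_to_extri_intri` yields a per-view 3×3 intrinsic matrix, and the two lines above read the focal lengths off its diagonal and the principal point from its last column. A standalone illustration of that slicing with a made-up matrix:

```python
import torch

# A made-up pinhole intrinsic matrix K = [[fx, 0, cx], [0, fy, cy], [0, 0, 1]].
K = torch.tensor([[500.0,   0.0, 320.0],
                  [  0.0, 480.0, 240.0],
                  [  0.0,   0.0,   1.0]])

focals = K[:2, :2].diag().unsqueeze(0)       # shape (1, 2): (fx, fy)
principal_point = K[:2, 2].unsqueeze(0)      # shape (1, 2): (cx, cy)

print(focals)           # tensor([[500., 480.]])
print(principal_point)  # tensor([[320., 240.]])
```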
@@ -304,24 +300,20 @@ def infer(image, prompt, seed):
      raw_image = raw_image.transpose(1, 2, 0)
  
      render_results_list = []
- 
- 
      for estimate_rel in traj:
          estimate_rel = torch.from_numpy(estimate_rel).float().to(device)
          relative_c2ws = estimate_rel.unsqueeze(0)
          R, T = relative_c2ws[:, :3, :3], relative_c2ws[:, :3, 3:]
          R = torch.stack([-R[:, :, 0], -R[:, :, 1], R[:, :, 2]], 2)
          new_c2w = torch.cat([R, T], 2)
- 
          w2c = torch.linalg.inv(torch.cat(
              (new_c2w, torch.Tensor([[[0, 0, 0, 1]]]).to(device).repeat(new_c2w.shape[0], 1, 1)),
              1
          ))
          R_new, T_new = w2c[:, :3, :3].permute(0, 2, 1), w2c[:, :3, 3]
  
- 
          image_size = (first_frame.shape[-2:],)
- 
          cameras = PerspectiveCameras(
              focal_length=focals,
              principal_point=principal_points,
@@ -331,7 +323,7 @@ def infer(image, prompt, seed):
              T=T_new,
              device=device
          )
- 
          masks = None
          render_results, viewmask = run_render(
              [first_frame_world_points],
@@ -342,55 +334,60 @@ def infer(image, prompt, seed):
              1,
              device=device
          )
- 
  
          render_result = (render_results[-1].detach().cpu().numpy() * 255).astype(np.uint8)
- 
          if len(render_result.shape) == 2:
              render_result = cv2.cvtColor(render_result, cv2.COLOR_GRAY2RGB)
          elif render_result.shape[-1] == 4:
              render_result = render_result[..., :3]
- 
          render_results_list.append(render_result)
  
- 
      raw_image = first_frame[0].cpu().numpy()
      raw_image = raw_image.transpose(1, 2, 0)
- 
      raw_image = (raw_image * 255).clip(0, 255).astype(np.uint8)
- 
      render_results_list[0] = raw_image
  
-     frame_indices = np.linspace(
-         0,
-         80,
-         25
-     ).round().astype(int)
- 
      frames = []
      for idx in frame_indices:
          frame = render_results_list[idx]
          frame = Image.fromarray(frame)
          frames.append(frame)
- 
- 
      last = frames[-1]
      for _ in range(4):
          frames.append(last)
  
-     # TARGET_H, TARGET_W = 704, 1248
- 
      def resize_pil(img):
          return img.resize((TARGET_W, TARGET_H), Image.BICUBIC)
  
      frames = [resize_pil(f) for f in frames]
-     image = resize_pil(image)
  
      # ===== Wan =====
      video = wan_pipe(
          prompt="Ensure the consistency of the video",
          negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
-         src_video=frames,
          input_image=image,
          height=TARGET_H,
          width=TARGET_W,
@@ -403,58 +400,67 @@ def infer(image, prompt, seed):
  
      video_frames = list(video)
      last_frame = np.array(video_frames[-1])
- 
-     pcd_last = frames[-1]
- 
-     return Image.fromarray(last_frame), pcd_last
- 
  
  # =========================
  # Gradio UI
  # =========================
  with gr.Blocks() as demo:
- 
      # ===== Title + description =====
-     gr.Markdown("""
-     <div style="line-height:1.2; font-size:16px">
- 
-     <b>UniGeo: Unifying Geometric Guidance for Camera-Controllable Image Editing via Video Models</b><br>
  
-     <hr style="margin:6px 0;">
- 
-     <b>Input Requirement / 输入要求</b><br>
-     The input image is recommended to have width height due to VGGT and Wan model constraints.<br>
-     由于 VGGT 与 Wan 模型限制建议输入图像满足 宽 ≥ 高
- 
-     <hr style="margin:6px 0;">
- 
-     <b>Usage Guide / 使用说明</b><br>
-     You can input one or multiple camera commands separated by semicolons, such as “Camera pans left by 15 degrees” or “Camera moves left by 0.27; Camera pans right by 26 degrees”. The motion scale is normalized by VGGT, and the final point cloud is provided to help adjust motion parameters.<br>
-     支持输入一条或多条相机控制指令(使用分号分隔),例如“Camera pans left by 15 degrees”或“Camera moves left by 0.27; Camera pans right by 26 degrees”。所有运动数值由 VGGT 统一尺度建模,最终提供的点云结果可用于辅助调整相机运动参数。
  
-     </div>
-     """)
  
-     # ===== Input / output images =====
      with gr.Row():
-         inp = gr.Image(type="pil", label="Input Image")
-         out = gr.Image(type="numpy", label="Output Image")
- 
-     # ===== Prompt + seed =====
      with gr.Row():
-         txt = gr.Textbox(label="Camera Prompt")
-         seed_inp = gr.Number(value=0, label="Seed", precision=0)
- 
-     run_btn = gr.Button("Run")
- 
-     # ===== Point-cloud output =====
-     pcd_out = gr.Image(type="pil", label="Final Frame Point Cloud")
  
-     # ===== Bindings =====
-     run_btn.click(
-         fn=infer,
-         inputs=[inp, txt, seed_inp],
-         outputs=[out, pcd_out]
      )
  
  if __name__ == "__main__":
 
      #wan_pipe.to(device)
      #wan_pipe.to(dtype=torch.bfloat16)
  
+ load_models()
  # =========================
  # Renderer
  # =========================
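The behavioural change on this side of the diff is that `load_models()` now runs once at module import, when the Space starts, instead of inside every GPU call as the deleted `infer` did. A minimal sketch of that load-once pattern; the placeholder model dictionary is hypothetical and only stands in for the real VGGT/Wan pipelines:

```python
import spaces  # huggingface `spaces` package, provides the GPU decorator

_MODELS = {}

def load_models():
    # Idempotent: pay the weight-loading cost only once per process.
    if not _MODELS:
        _MODELS["pipe"] = lambda x: x  # hypothetical placeholder for the real pipelines

load_models()  # runs at import time, i.e. when the Space boots

@spaces.GPU
def handler(x):
    # The GPU-decorated handler can assume the weights are already loaded.
    return _MODELS["pipe"](x)
```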
 
  # =========================
  
  @spaces.GPU
+ def generate_pcd(image, prompt):
+ 
+     if image is None:
+         raise gr.Error("Please upload an input image!")
+     if not prompt:
+         raise gr.Error("Please enter camera control prompts!")
+ 
      img = image.convert("RGB")
      TARGET_H, TARGET_W = img.size[1], img.size[0]
      TARGET_H = TARGET_H // 32 * 32
      TARGET_W = TARGET_W // 32 * 32
      img = img.resize((TARGET_W, TARGET_H), Image.BICUBIC)
  
      all_steps = generate_all_motions_from_prompt(prompt, num_frames=81)
      cam_idx = list(range(81))
      traj = [build_estimate_rel(*all_steps[idx]) for idx in cam_idx]
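`generate_all_motions_from_prompt` is not part of this diff; all that is visible is that it maps the camera prompt to one `(x, y, z, phi, theta)` step per frame, which `build_estimate_rel` then turns into a relative pose. Purely as an illustration of that shape of output — assuming a single "Camera pans left by N degrees" command ramped linearly over the 81 frames — a toy stand-in could look like this:

```python
import re
import numpy as np

def toy_motions_from_prompt(prompt: str, num_frames: int = 81):
    """Hypothetical stand-in for generate_all_motions_from_prompt."""
    match = re.search(r"pans left by ([\d.]+) degrees", prompt)
    total_deg = float(match.group(1)) if match else 0.0
    # Ramp the pan angle linearly; translation and pitch stay at zero.
    phis = np.linspace(0.0, np.deg2rad(total_deg), num_frames)
    return [(0.0, 0.0, 0.0, float(phi), 0.0) for phi in phis]

steps = toy_motions_from_prompt("Camera pans left by 15 degrees")
print(len(steps), steps[-1])  # 81 steps; the last one carries the full 15-degree pan
```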
  
      first_frame = load_and_preprocess_images(first_frame)
      first_frame = first_frame.to(device)
  
      with torch.no_grad():
          with torch.cuda.amp.autocast(dtype=dtype):
              predictions = vggt_model(first_frame)
  
      extrinsic, intrinsic = pose_encoding_to_extri_intri(predictions["pose_enc"], first_frame.shape[-2:])
      first_frame_world_points = predictions["world_points"][0][0]
      focals = intrinsic[0][0][:2, :2].diag().unsqueeze(0).to(device)
      principal_points = intrinsic[0][0][:2, 2].unsqueeze(0).to(device)
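The VGGT forward pass runs under `torch.no_grad()` and `torch.cuda.amp.autocast(dtype=dtype)`. That autocast spelling still works but is deprecated in recent PyTorch releases in favour of the device-agnostic `torch.amp.autocast`; an equivalent sketch (the tiny linear layer is only a stand-in, and a CUDA device is assumed):

```python
import torch

model = torch.nn.Linear(4, 2).cuda()
x = torch.randn(1, 4, device="cuda")

with torch.no_grad():
    # Modern spelling of the mixed-precision context used in the diff.
    with torch.amp.autocast("cuda", dtype=torch.bfloat16):
        y = model(x)

print(y.dtype)  # torch.bfloat16 for matmul-style ops inside autocast
```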
  
      raw_image = raw_image.transpose(1, 2, 0)
  
      render_results_list = []
      for estimate_rel in traj:
          estimate_rel = torch.from_numpy(estimate_rel).float().to(device)
          relative_c2ws = estimate_rel.unsqueeze(0)
          R, T = relative_c2ws[:, :3, :3], relative_c2ws[:, :3, 3:]
          R = torch.stack([-R[:, :, 0], -R[:, :, 1], R[:, :, 2]], 2)
          new_c2w = torch.cat([R, T], 2)
+ 
          w2c = torch.linalg.inv(torch.cat(
              (new_c2w, torch.Tensor([[[0, 0, 0, 1]]]).to(device).repeat(new_c2w.shape[0], 1, 1)),
              1
          ))
          R_new, T_new = w2c[:, :3, :3].permute(0, 2, 1), w2c[:, :3, 3]
  
          image_size = (first_frame.shape[-2:],)
          cameras = PerspectiveCameras(
              focal_length=focals,
              principal_point=principal_points,
  
              T=T_new,
              device=device
          )
+ 
          masks = None
          render_results, viewmask = run_render(
              [first_frame_world_points],
  
              1,
              device=device
          )
  
          render_result = (render_results[-1].detach().cpu().numpy() * 255).astype(np.uint8)
          if len(render_result.shape) == 2:
              render_result = cv2.cvtColor(render_result, cv2.COLOR_GRAY2RGB)
          elif render_result.shape[-1] == 4:
              render_result = render_result[..., :3]
          render_results_list.append(render_result)
  
      raw_image = first_frame[0].cpu().numpy()
      raw_image = raw_image.transpose(1, 2, 0)
      raw_image = (raw_image * 255).clip(0, 255).astype(np.uint8)
      render_results_list[0] = raw_image
  
+     frame_indices = np.linspace(0, 80, 25).round().astype(int)
      frames = []
      for idx in frame_indices:
          frame = render_results_list[idx]
          frame = Image.fromarray(frame)
          frames.append(frame)
+ 
      last = frames[-1]
      for _ in range(4):
          frames.append(last)
  
      def resize_pil(img):
          return img.resize((TARGET_W, TARGET_H), Image.BICUBIC)
  
      frames = [resize_pil(f) for f in frames]
+     pcd_last = frames[-1]
+ 
+     # Return the last point-cloud frame for display in the UI, and pass the full frame list to the hidden state variable.
+     return pcd_last, frames
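Inside the render loop above, each relative camera-to-world pose has its x and y axes flipped, is padded with a homogeneous `[0, 0, 0, 1]` row, inverted to world-to-camera, and its rotation transposed, because PyTorch3D's `PerspectiveCameras` uses a row-vector convention. A single-pose version of the same conversion, with an identity pose as dummy input:

```python
import torch

def c2w_to_pytorch3d_RT(c2w: torch.Tensor):
    """c2w: (3, 4) camera-to-world pose. Returns (R, T) for PerspectiveCameras."""
    R, T = c2w[:3, :3], c2w[:3, 3:]
    # Negate the x and y axes (column flip), as in the loop above.
    R = torch.stack([-R[:, 0], -R[:, 1], R[:, 2]], dim=1)
    c2w_fixed = torch.cat([R, T], dim=1)
    # Pad to 4x4 with a homogeneous row, then invert to get world-to-camera.
    bottom = torch.tensor([[0.0, 0.0, 0.0, 1.0]])
    w2c = torch.linalg.inv(torch.cat([c2w_fixed, bottom], dim=0))
    # PyTorch3D expects row-vector rotations, hence the transpose.
    return w2c[:3, :3].T, w2c[:3, 3]

R, T = c2w_to_pytorch3d_RT(torch.eye(4)[:3])
print(R.shape, T.shape)  # torch.Size([3, 3]) torch.Size([3])
```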
+ 
+ @spaces.GPU
+ def generate_final(image, frames, seed):
+     if not frames:
+         raise gr.Error("Please generate the point cloud first!")
+ 
+     img = image.convert("RGB")
+     TARGET_H, TARGET_W = img.size[1], img.size[0]
+     TARGET_H = TARGET_H // 32 * 32
+     TARGET_W = TARGET_W // 32 * 32
+ 
+     def resize_pil(img_to_resize):
+         return img_to_resize.resize((TARGET_W, TARGET_H), Image.BICUBIC)
+ 
+     image = resize_pil(img)
  
      # ===== Wan =====
      video = wan_pipe(
          prompt="Ensure the consistency of the video",
          negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
+         src_video=frames,  # use the frames passed along from the previous step
          input_image=image,
          height=TARGET_H,
          width=TARGET_W,
  
      video_frames = list(video)
      last_frame = np.array(video_frames[-1])
+     return Image.fromarray(last_frame)
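`generate_final` still receives `seed` from the UI, but the tail of the `wan_pipe(...)` call falls outside this hunk, so how the seed is consumed is not visible here. A common way to make such a seed effective in diffusers-style pipelines is a `torch.Generator`; the sketch below is only an assumption about that wiring, not necessarily what `wan_pipe` does:

```python
import torch

def make_generator(seed: int, device: str = "cpu") -> torch.Generator:
    # A fixed seed makes the pipeline's sampling reproducible.
    return torch.Generator(device=device).manual_seed(int(seed))

# Hypothetical wiring (argument name assumed, not confirmed by this diff):
# video = wan_pipe(..., generator=make_generator(seed, device="cuda"))
print(torch.rand(2, generator=make_generator(0)))
```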
 
 
 
 
  
  # =========================
  # Gradio UI
  # =========================
  with gr.Blocks() as demo:
      # ===== Title + description =====
+     gr.HTML("""<div style="line-height:1.4; font-size:15px">
+     <b style="font-size:18px">UniGeo: Unifying Geometric Guidance for Camera-Controllable Image Editing via Video Models</b><br>
+ 
+     <hr style="margin:8px 0;">
+     <b>Input Requirement / 输入要求</b><br>
+     The input image is recommended to have width ≥ height due to VGGT and Wan model constraints.<br>
+     由于 VGGT 与 Wan 模型限制,建议输入图像满足 宽 ≥ 高。<br>
+ 
+     <hr style="margin:8px 0;">
+     <b>Usage Guide / 使用说明</b>
+     <ul style="margin-top: 4px; padding-left: 20px;">
+     <li style="margin-bottom: 4px;"><b>Command Format / 指令格式:</b> You can input one or multiple camera commands separated by semicolons (e.g., “Camera pans left by 15 degrees” or “Camera moves left by 0.27; Camera pans right by 26 degrees”).<br>
+     支持输入一条或多条相机控制指令,使用分号分隔(例如“Camera pans left by 15 degrees”或“Camera moves left by 0.27; Camera pans right by 26 degrees”)。</li>
  
+     <li style="margin-bottom: 4px;"><b>Scale & Adjustment / 尺度与调整:</b> The motion scale is normalized by VGGT, and the final point cloud is provided to help adjust motion parameters.<br>
+     所有运动数值由 VGGT 统一尺度建模,最终提供的点云结果可用于辅助调整相机运动参数。</li>
+ 
+     <li><b>First Run / 首次运行:</b> Please note that the first execution will take slightly longer as the models are being loaded onto the GPU.<br>
+     首次运行需要将模型权重加载到显存,耗时会稍微久一点,请耐心等待。</li>
+     </ul>
+     </div>""")
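The usage guide above says multiple camera commands can be chained with semicolons. The actual splitting lives inside `generate_all_motions_from_prompt`, which this diff does not show; the step itself is as simple as:

```python
prompt = "Camera moves left by 0.27; Camera pans right by 26 degrees"

# Split on semicolons and drop surrounding whitespace / empty pieces.
commands = [part.strip() for part in prompt.split(";") if part.strip()]
print(commands)
# ['Camera moves left by 0.27', 'Camera pans right by 26 degrees']
```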
 
 
 
 
  
+     # Hidden state variable used to pass the generated guidance frames between the two steps
+     frames_state = gr.State([])
  
+     gr.Markdown("### Step 1: Point Cloud Preview / 步骤一:点云预览与调节")
      with gr.Row():
+         with gr.Column():
+             inp = gr.Image(type="pil", label="Input Image")
+             txt = gr.Textbox(label="Camera Prompt")
+             btn_pcd = gr.Button("Generate Point Cloud (生成点云)")
+         with gr.Column():
+             pcd_out = gr.Image(type="pil", label="Final Frame Point Cloud (预览结果)")
+ 
+     gr.Markdown("### Step 2: Final Result Generation / 步骤二:生成最终结果")
      with gr.Row():
+         with gr.Column():
+             seed_inp = gr.Number(value=0, label="Seed", precision=0)
+             btn_final = gr.Button("Generate Final Result (生成编辑结果)", variant="primary")
+         with gr.Column():
+             out = gr.Image(type="numpy", label="Output Image")
+ 
+     # ===== Step 1 binding: only generates the point cloud and caches the video frames =====
+     btn_pcd.click(
+         fn=generate_pcd,
+         inputs=[inp, txt],
+         outputs=[pcd_out, frames_state]  # update the point-cloud preview and cache the frames sequence
+     )
  
+     # ===== Step 2 binding: reads the cached frames and generates the final image =====
+     btn_final.click(
+         fn=generate_final,
+         inputs=[inp, frames_state, seed_inp],
+         outputs=[out]
      )
  
  if __name__ == "__main__":
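The two bindings rely on `frames_state`, a `gr.State`, to hand the rendered guidance frames from the step-1 handler to the step-2 handler without re-rendering them. A minimal, model-free sketch of that two-button pattern (component names here are illustrative, not the ones used above):

```python
import gradio as gr

def step_one(text):
    frames = [f"{text}-{i}" for i in range(3)]  # stand-in for rendered guidance frames
    return frames[-1], frames                   # preview for the UI + value for the State

def step_two(frames):
    if not frames:
        raise gr.Error("Run step 1 first!")
    return f"final result built from {len(frames)} cached frames"

with gr.Blocks() as demo:
    cache = gr.State([])                        # per-session cache shared between clicks
    inp = gr.Textbox(label="Input")
    preview = gr.Textbox(label="Step 1 preview")
    result = gr.Textbox(label="Step 2 result")
    gr.Button("Step 1").click(step_one, inputs=[inp], outputs=[preview, cache])
    gr.Button("Step 2").click(step_two, inputs=[cache], outputs=[result])

# demo.launch()
```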