Update app.py

app.py CHANGED
@@ -136,7 +136,7 @@ def load_models():
 #wan_pipe.to(device)
 #wan_pipe.to(dtype=torch.bfloat16)
 
-
+load_models()
 # =========================
 # Renderer
 # =========================
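With this hunk the model weights are loaded once at module import instead of on every inference call, so only Space startup pays the cost. If per-request loading ever needs to come back, a guarded lazy initializer is the usual alternative. A minimal sketch, assuming `load_models()` is idempotent; the `ensure_models` helper and `_models_loaded` flag are hypothetical, not part of this app:

```python
# Hypothetical lazy-loading guard around this app's load_models().
_models_loaded = False

def ensure_models():
    """Load the heavy models exactly once, on first use."""
    global _models_loaded
    if not _models_loaded:
        load_models()
        _models_loaded = True
```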
@@ -266,20 +266,21 @@ def build_estimate_rel(x, y, z, phi, theta):
 # =========================
 
 @spaces.GPU
-def infer(image, prompt, seed):
-
-    load_models()
+def generate_pcd(image, prompt):
 
+    if image is None:
+        raise gr.Error("Please upload an input image!")
+    if not prompt:
+        raise gr.Error("Please enter camera control prompts!")
+
+
     img = image.convert("RGB")
-
     TARGET_H, TARGET_W = img.size[1], img.size[0]
     TARGET_H = TARGET_H // 32 * 32
     TARGET_W = TARGET_W // 32 * 32
-
     img = img.resize((TARGET_W, TARGET_H), Image.BICUBIC)
 
     all_steps = generate_all_motions_from_prompt(prompt, num_frames=81)
-
     cam_idx = list(range(81))
     traj = [build_estimate_rel(*all_steps[idx]) for idx in cam_idx]
 
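Two details in this hunk are worth spelling out. `raise gr.Error(...)` makes Gradio show a user-visible error message in the UI instead of a server traceback, and `x // 32 * 32` floors each image dimension to the nearest lower multiple of 32, which the downstream VGGT and Wan models apparently expect. A quick illustration of the rounding (the helper name is ours, for demonstration only):

```python
def floor_to_multiple(x: int, m: int = 32) -> int:
    """Largest multiple of m that is <= x."""
    return x // m * m

assert floor_to_multiple(1080) == 1056
assert floor_to_multiple(720) == 704
assert floor_to_multiple(1920) == 1920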
@@ -287,16 +288,11 @@ def infer(image, prompt, seed):
     first_frame = load_and_preprocess_images(first_frame)
     first_frame = first_frame.to(device)
 
-
-
     with torch.no_grad():
        with torch.cuda.amp.autocast(dtype=dtype):
            predictions = vggt_model(first_frame)
-
    extrinsic, intrinsic = pose_encoding_to_extri_intri(predictions["pose_enc"], first_frame.shape[-2:])
-
    first_frame_world_points = predictions["world_points"][0][0]
-
    focals = intrinsic[0][0][:2, :2].diag().unsqueeze(0).to(device)
    principal_points = intrinsic[0][0][:2, 2].unsqueeze(0).to(device)
 
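The focal and principal-point extraction relies on `intrinsic[0][0]` being a standard 3x3 pinhole matrix K: the top-left 2x2 diagonal holds (fx, fy) and the first two entries of the last column hold (cx, cy). A self-contained illustration with made-up numbers (note also that recent PyTorch deprecates `torch.cuda.amp.autocast(...)` in favor of `torch.amp.autocast("cuda", ...)`, though both still run):

```python
import torch

# K = [[fx, 0, cx], [0, fy, cy], [0, 0, 1]] -- a hypothetical pinhole intrinsic matrix
K = torch.tensor([[500.0,   0.0, 320.0],
                  [  0.0, 480.0, 240.0],
                  [  0.0,   0.0,   1.0]])

focals = K[:2, :2].diag()       # tensor([500., 480.]) -> (fx, fy)
principal_point = K[:2, 2]      # tensor([320., 240.]) -> (cx, cy)
```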
@@ -304,24 +300,20 @@ def infer(image, prompt, seed):
     raw_image = raw_image.transpose(1, 2, 0)
 
     render_results_list = []
-
-
     for estimate_rel in traj:
         estimate_rel = torch.from_numpy(estimate_rel).float().to(device)
         relative_c2ws = estimate_rel.unsqueeze(0)
         R, T = relative_c2ws[:, :3, :3], relative_c2ws[:, :3, 3:]
         R = torch.stack([-R[:, :, 0], -R[:, :, 1], R[:, :, 2]], 2)
         new_c2w = torch.cat([R, T], 2)
-
+
         w2c = torch.linalg.inv(torch.cat(
             (new_c2w, torch.Tensor([[[0, 0, 0, 1]]]).to(device).repeat(new_c2w.shape[0], 1, 1)),
             1
         ))
         R_new, T_new = w2c[:, :3, :3].permute(0, 2, 1), w2c[:, :3, 3]
 
-
         image_size = (first_frame.shape[-2:],)
-
         cameras = PerspectiveCameras(
             focal_length=focals,
             principal_point=principal_points,
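The loop pads each 3x4 camera-to-world matrix with a `[0, 0, 0, 1]` row so the whole batch can be inverted as 4x4 homogeneous transforms, yielding world-to-camera matrices. The column negations before that flip the x/y axes, and the trailing `.permute(0, 2, 1)` transposes R; both appear to bridge an OpenCV-style convention to PyTorch3D's row-vector camera convention. The core pad-and-invert step on a single pose, as a minimal sketch with made-up data and no batch dimension:

```python
import torch

c2w = torch.eye(4)[:3]                                # 3x4 camera-to-world (identity pose here)
bottom = torch.tensor([[0.0, 0.0, 0.0, 1.0]])         # homogeneous padding row
w2c = torch.linalg.inv(torch.cat((c2w, bottom), 0))   # 4x4 world-to-camera
R, t = w2c[:3, :3], w2c[:3, 3]                        # rotation and translation blocks
```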
@@ -331,7 +323,7 @@ def infer(image, prompt, seed):
             T=T_new,
             device=device
         )
-
+
         masks = None
         render_results, viewmask = run_render(
             [first_frame_world_points],
@@ -342,55 +334,60 @@ def infer(image, prompt, seed):
             1,
             device=device
         )
-
 
         render_result = (render_results[-1].detach().cpu().numpy() * 255).astype(np.uint8)
-
         if len(render_result.shape) == 2:
             render_result = cv2.cvtColor(render_result, cv2.COLOR_GRAY2RGB)
         elif render_result.shape[-1] == 4:
             render_result = render_result[..., :3]
-
         render_results_list.append(render_result)
 
-
     raw_image = first_frame[0].cpu().numpy()
     raw_image = raw_image.transpose(1, 2, 0)
-
     raw_image = (raw_image * 255).clip(0, 255).astype(np.uint8)
-
     render_results_list[0] = raw_image
 
-    frame_indices = np.linspace(
-        0,
-        80,
-        25
-    ).round().astype(int)
-
+    frame_indices = np.linspace(0, 80, 25).round().astype(int)
     frames = []
     for idx in frame_indices:
         frame = render_results_list[idx]
         frame = Image.fromarray(frame)
         frames.append(frame)
-
-
+
     last = frames[-1]
     for _ in range(4):
         frames.append(last)
 
-    # TARGET_H, TARGET_W = 704, 1248
-
     def resize_pil(img):
         return img.resize((TARGET_W, TARGET_H), Image.BICUBIC)
 
     frames = [resize_pil(f) for f in frames]
-
+    pcd_last = frames[-1]
+
+    # Show the last point-cloud frame in the UI, and pass the full frame list to the hidden state variable
+    return pcd_last, frames
+
+@spaces.GPU
+def generate_final(image, frames, seed):
+    if not frames:
+        raise gr.Error("Please generate point cloud first!")
+
+
+    img = image.convert("RGB")
+    TARGET_H, TARGET_W = img.size[1], img.size[0]
+    TARGET_H = TARGET_H // 32 * 32
+    TARGET_W = TARGET_W // 32 * 32
+
+    def resize_pil(img_to_resize):
+        return img_to_resize.resize((TARGET_W, TARGET_H), Image.BICUBIC)
+
+    image = resize_pil(img)
 
     # ===== Wan =====
     video = wan_pipe(
         prompt="Ensure the consistency of the video",
         negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
-        src_video=frames,
+        src_video=frames,  # directly use the frames state handed over from the previous step
         input_image=image,
         height=TARGET_H,
         width=TARGET_W,
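`np.linspace(0, 80, 25)` subsamples 25 evenly spaced frames from the 81-step trajectory, and repeating the last frame four more times yields 29 guidance frames. This hunk also splits the old `infer` in two: `generate_pcd` now ends at the point-cloud preview and returns the frames, while the new `generate_final` picks them up for the Wan pass. For reference, the sampled indices are:

```python
import numpy as np

idx = np.linspace(0, 80, 25).round().astype(int)
# [ 0  3  7 10 13 17 20 23 27 30 33 37 40 43 47 50 53 57 60 63 67 70 73 77 80]
```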
@@ -403,58 +400,67 @@ def infer(image, prompt, seed):
 
     video_frames = list(video)
     last_frame = np.array(video_frames[-1])
-
-    pcd_last = frames[-1]
-
-    return Image.fromarray(last_frame), pcd_last
-
+    return Image.fromarray(last_frame)
 
 # =========================
 # Gradio UI
 # =========================
 with gr.Blocks() as demo:
-
     # ===== Title + notes =====
-    gr.
-    <
-    <
-    <
-    <
-        <b>Usage Guide / 使用说明</b><br>
-        You can input one or multiple camera commands separated by semicolons, such as “Camera pans left by 15 degrees” or “Camera moves left by 0.27; Camera pans right by 26 degrees”. The motion scale is normalized by VGGT, and the final point cloud is provided to help adjust motion parameters.<br>
-        支持输入一条或多条相机控制指令(使用分号分隔),例如“Camera pans left by 15 degrees”或“Camera moves left by 0.27; Camera pans right by 26 degrees”。所有运动数值由 VGGT 统一尺度建模,最终提供的点云结果可用于辅助调整相机运动参数。
+    gr.HTML("""<div style="line-height:1.4; font-size:15px">
+        <b style="font-size:18px">UniGeo: Unifying Geometric Guidance for Camera-Controllable Image Editing via Video Models</b><br>
+
+        <hr style="margin:8px 0;">
+        <b>Input Requirement / 输入要求</b><br>
+        The input image is recommended to have width ≥ height due to VGGT and Wan model constraints.<br>
+        由于 VGGT 与 Wan 模型限制,建议输入图像满足 宽 ≥ 高。<br>
+
+        <hr style="margin:8px 0;">
+        <b>Usage Guide / 使用说明</b>
+        <ul style="margin-top: 4px; padding-left: 20px;">
+            <li style="margin-bottom: 4px;"><b>Command Format / 指令格式:</b>You can input one or multiple camera commands separated by semicolons (e.g., “Camera pans left by 15 degrees” or “Camera moves left by 0.27; Camera pans right by 26 degrees”).<br>
+            支持输入一条或多条相机控制指令,使用分号分隔(例如“Camera pans left by 15 degrees”或“Camera moves left by 0.27; Camera pans right by 26 degrees”)。</li>
+
+            <li style="margin-bottom: 4px;"><b>Scale & Adjustment / 尺度与调整:</b>The motion scale is normalized by VGGT, and the final point cloud is provided to help adjust motion parameters.<br>
+            所有运动数值由 VGGT 统一尺度建模,最终提供的点云结果可用于辅助调整相机运动参数。</li>
+
+            <li><b>First Run / 首次运行:</b>Please note that the first execution will take slightly longer as the models are being loaded into the GPU.<br>
+            首次运行需要将模型权重加载到显存,耗时会稍微久一点,请耐心等待。</li>
+        </ul>
+    </div>""")
 
+    # Hidden state variable used to pass the generated video frames between the two steps
+    frames_state = gr.State([])
 
-    #
+    gr.Markdown("### Step 1: Point Cloud Preview / 步骤一:点云预览与调节")
     with gr.Row():
-
-
+        with gr.Column():
+            inp = gr.Image(type="pil", label="Input Image")
+            txt = gr.Textbox(label="Camera Prompt")
+            btn_pcd = gr.Button("Generate Point Cloud (生成点云)")
+        with gr.Column():
+            pcd_out = gr.Image(type="pil", label="Final Frame Point Cloud (预览结果)")
+
+    gr.Markdown("### Step 2: Final Result Generation / 步骤二:生成最终结果")
     with gr.Row():
-
-
+        with gr.Column():
+            seed_inp = gr.Number(value=0, label="Seed", precision=0)
+            btn_final = gr.Button("Generate Final Result (生成编辑结果)", variant="primary")
+        with gr.Column():
+            out = gr.Image(type="numpy", label="Output Image")
 
-    # ===== Bindings =====
-
-        fn=
-        inputs=[inp,
-        outputs=[out
+    # ===== Bind step 1: only generate the point cloud and cache the video frames =====
+    btn_pcd.click(
+        fn=generate_pcd,
+        inputs=[inp, txt],
+        outputs=[pcd_out, frames_state]  # update the point-cloud preview; cache the frames sequence in state
+    )
+
+    # ===== Bind step 2: read the cached frames and generate the final image =====
+    btn_final.click(
+        fn=generate_final,
+        inputs=[inp, frames_state, seed_inp],
+        outputs=[out]
     )
 
 if __name__ == "__main__":
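The binding section wires the two buttons to the two new functions, with `gr.State` acting as an invisible channel: `generate_pcd` writes the guidance frames into `frames_state`, and `generate_final` reads them back without recomputing the render. A minimal, self-contained sketch of the same two-step handoff pattern; all names here are hypothetical and independent of this app:

```python
import gradio as gr

def step1(text):
    data = [text.upper()]                    # stand-in for the expensive render
    return f"preview: {data[-1]}", data      # second output lands in gr.State

def step2(data):
    if not data:
        raise gr.Error("Run step 1 first!")
    return f"final: {data[-1]}"

with gr.Blocks() as demo:
    state = gr.State([])
    inp = gr.Textbox(label="Input")
    preview = gr.Textbox(label="Preview")
    final = gr.Textbox(label="Final")
    gr.Button("Step 1").click(fn=step1, inputs=[inp], outputs=[preview, state])
    gr.Button("Step 2").click(fn=step2, inputs=[state], outputs=[final])
```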