import torch, os, re
import numpy as np
import gradio as gr
from PIL import Image
from scipy.spatial.transform import Rotation
import cv2, sys
from huggingface_hub import snapshot_download
import spaces
import site
import importlib

# os.system("pip install ./pytorch3d-0.7.8+pt2.7.1cu126-cp310-cp310-linux_x86_64.whl")
# os.system("python -m pip install -e ./pytorch3d-0.7.8 --no-build-isolation")
# site.main()
# importlib.invalidate_caches()
# ===== VGGT =====
sys.path.append(os.path.join(os.getcwd(), "vggt"))
from vggt.models.vggt import VGGT
from vggt.utils.load_fn import load_and_preprocess_images
from vggt.utils.pose_enc import pose_encoding_to_extri_intri

# ===== Wan =====
sys.path.append(os.path.join(os.getcwd(), "DiffSynth-Studio"))
from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig
from safetensors.torch import load_file

# ===== PyTorch3D =====
from pytorch3d.structures import Pointclouds
from pytorch3d.renderer import (
    PerspectiveCameras,
    PointsRasterizationSettings,
    PointsRenderer,
    PointsRasterizer,
    AlphaCompositor,
)
def todevice(batch, device, callback=None, non_blocking=False):
    '''Transfer a (possibly nested) batch to another device (GPU, CPU:torch, or CPU:numpy).

    batch: tensor, or a list/tuple/dict of tensors (other objects pass through)
    device: a PyTorch device, or the string 'numpy'
    callback: optional function applied to every leaf element
    '''
    # Thread callback/non_blocking through the recursion (the original recursion
    # dropped them, and callback was documented but never invoked).
    if isinstance(batch, dict):
        return {k: todevice(v, device, callback, non_blocking) for k, v in batch.items()}
    if isinstance(batch, (tuple, list)):
        return type(batch)(todevice(x, device, callback, non_blocking) for x in batch)

    x = batch
    if callback:
        x = callback(x)
    if device == 'numpy':
        if isinstance(x, torch.Tensor):
            x = x.detach().cpu().numpy()
    elif x is not None:
        if isinstance(x, np.ndarray):
            x = torch.from_numpy(x)
        if torch.is_tensor(x):
            x = x.to(device, non_blocking=non_blocking)
    return x

def to_numpy(x): return todevice(x, 'numpy')
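
# Example: to_numpy({"pts": torch.zeros(2, 3), "ids": [torch.arange(4)]}) returns
# the same nesting with every tensor converted to a NumPy array.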
# =========================
# Global configs (weights are fetched from the Hub; to run offline,
# swap in the commented-out local paths below)
# =========================
hf_token = os.getenv("HF_TOKEN")
VGGT_PATH = snapshot_download(repo_id="facebook/VGGT-1B", token=hf_token)
WAN_MODEL_DIR = snapshot_download(repo_id="Wan-AI/Wan2.2-TI2V-5B", token=hf_token)
LORA_DIR = snapshot_download(repo_id="123123aa123/UniGeo", token=hf_token)
LORA_PATH = os.path.join(LORA_DIR, "UniGeo_lora.safetensors")
# VGGT_PATH = "./VGGT_PATH"
# WAN_MODEL_DIR = "./WAN_MODEL_DIR"
# LORA_PATH = "./LORA_PATH"
WAN_CONFIG_PATH = "./my_config.json"
# =========================
# Global models
# =========================
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float16
vggt_model = None
wan_pipe = None
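# NOTE: bfloat16 assumes an Ampere-or-newer GPU (compute capability >= 8.0);
# on older CUDA cards, switch the CUDA branch of `dtype` above to torch.float16.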
# =========================
# Load models once
# =========================
def load_models():
    global vggt_model, wan_pipe
    if vggt_model is None:
        print("Loading VGGT...")
        vggt_model = VGGT.from_pretrained(VGGT_PATH).to(device).eval()
    if wan_pipe is None:
        print("Loading Wan...")
        wan_paths = [
            os.path.join(WAN_MODEL_DIR, "diffusion_pytorch_model-00001-of-00003.safetensors"),
            os.path.join(WAN_MODEL_DIR, "diffusion_pytorch_model-00002-of-00003.safetensors"),
            os.path.join(WAN_MODEL_DIR, "diffusion_pytorch_model-00003-of-00003.safetensors"),
        ]
        wan_pipe = WanVideoPipeline.from_pretrained(
            torch_dtype=torch.bfloat16,
            device=device,
            model_configs=[
                ModelConfig(path=os.path.join(WAN_MODEL_DIR, "models_t5_umt5-xxl-enc-bf16.pth")),
                ModelConfig(path=os.path.join(WAN_MODEL_DIR, "Wan2.2_VAE.pth")),
            ],
            tokenizer_config=ModelConfig(path=os.path.join(WAN_MODEL_DIR, "google/umt5-xxl/")),
            wan_paths=wan_paths,
            wan_config_path=WAN_CONFIG_PATH,
        )
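        # Split the UniGeo checkpoint: ".lora_" keys are LoRA weights for the
        # Wan DiT, "i2v_adapter" keys belong to the image-to-video adapter.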
        ckpt = load_file(LORA_PATH)
        lora_sd, adapter_sd = {}, {}
        for k, v in ckpt.items():
            if ".lora_" in k:
                lora_sd[k] = v
            elif "i2v_adapter" in k:
                adapter_sd[k] = v
        wan_pipe.load_lora(wan_pipe.dit, state_dict=lora_sd, alpha=1)
        wan_pipe.dit.load_state_dict(adapter_sd, strict=False)
        wan_pipe.to(device)
        wan_pipe.to(dtype=torch.bfloat16)

load_models()
# =========================
# Renderer
# =========================
def setup_renderer(cameras, image_size):
    raster_settings = PointsRasterizationSettings(
        image_size=image_size,
        radius=0.01,
        points_per_pixel=10,
        bin_size=0,
    )
    renderer = PointsRenderer(
        rasterizer=PointsRasterizer(cameras=cameras, raster_settings=raster_settings),
        compositor=AlphaCompositor(),
    )
    render_setup = {'cameras': cameras, 'raster_settings': raster_settings, 'renderer': renderer}
    return render_setup
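
# In PyTorch3D's point rasterizer, `radius` is the splat size in NDC units,
# `points_per_pixel` is how many nearest points are kept per pixel for alpha
# compositing, and `bin_size=0` forces the naive (non-binned) rasterizer,
# which avoids bin-overflow artifacts on dense clouds at some speed cost.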
def render_pcd(pts3d, imgs, masks, views, renderer, device, nbv=False):
    imgs = to_numpy(imgs)
    pts3d = to_numpy(pts3d)
    if masks is None:
        pts = torch.from_numpy(np.concatenate([p for p in pts3d])).view(-1, 3).to(device)
        col = torch.from_numpy(np.concatenate([p for p in imgs])).view(-1, 3).to(device)
    else:
        pts = torch.from_numpy(np.concatenate([p[m] for p, m in zip(pts3d, masks)])).to(device)
        col = torch.from_numpy(np.concatenate([p[m] for p, m in zip(imgs, masks)])).to(device)
    point_cloud = Pointclouds(points=[pts], features=[col]).extend(views)
    images = renderer(point_cloud)
    if nbv:
        color_mask = torch.ones(col.shape).to(device)
        point_cloud_mask = Pointclouds(points=[pts], features=[color_mask]).extend(views)
        view_masks = renderer(point_cloud_mask)
    else:
        view_masks = None
    return images, view_masks
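
# With nbv=True the same cloud is re-rendered with all-white features, so
# `view_masks` marks which target-view pixels are covered by any reprojected
# point (a visibility/hole mask for the novel view).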
def run_render(pcd, imgs, masks, H, W, camera_traj, num_views, device, nbv=True):
    render_setup = setup_renderer(camera_traj, image_size=(H, W))
    renderer = render_setup['renderer']
    render_results, viewmask = render_pcd(pcd, imgs, masks, num_views, renderer, device, nbv=nbv)
    return render_results, viewmask
# =========================
# Prompt parsing
# =========================
def generate_all_motions_from_prompt(prompt, num_frames):
    '''Linearly interpolate from identity (frame 0) to the full commanded
    motion (last frame), one (x, y, z, phi, theta) tuple per frame.'''
    x, y, z, phi, theta = parse_prompt_to_motion(prompt)
    results = []
    for i in range(num_frames):
        alpha = i / max(num_frames - 1, 1)  # guard against num_frames == 1
        results.append((
            x * alpha,
            y * alpha,
            z * alpha,
            phi * alpha,
            theta * alpha,
        ))
    return results
def parse_prompt_to_motion(prompt):
    prompt = prompt.lower()
    x = y = z = phi = theta = 0.0
    clauses = re.split(r'[;,\n]| and ', prompt)
    for clause in clauses:
        nums = re.findall(r"[-+]?\d*\.?\d+", clause)
        if not nums:
            continue
        val = float(nums[0])
        if "pans left" in clause:
            phi = -val
        elif "pans right" in clause:
            phi = val
        elif "tilts up" in clause:
            theta = val
        elif "tilts down" in clause:
            theta = -val
        elif "moves forward" in clause:
            z = val
        elif "moves backward" in clause:
            z = -val
        elif "moves up" in clause:
            y = -val
        elif "moves down" in clause:
            y = val
        elif "moves left" in clause:
            x = -val
        elif "moves right" in clause:
            x = val
    return x, y, z, phi, theta
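
# Example: parse_prompt_to_motion("Camera pans left by 15 degrees; Camera moves forward by 0.3")
# -> (0.0, 0.0, 0.3, -15.0, 0.0), i.e. (x, y, z, phi, theta) with angles in
# degrees and translations in VGGT's normalized scene scale.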
def build_estimate_rel(x, y, z, phi, theta):
    delta_euler = [theta, phi, 0.0]
    rot_mat = Rotation.from_euler('xyz', delta_euler, degrees=True).as_matrix()
    mat = np.eye(4)
    mat[:3, :3] = rot_mat
    mat[:3, 3] = [x, y, z]
    return mat
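
# build_estimate_rel packs one motion step into a 4x4 relative camera-to-world
# matrix: theta rotates about x (tilt), phi about y (pan), and (x, y, z) is the
# translation. E.g. build_estimate_rel(0, 0, 0, 90, 0) is a pure 90-degree pan.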
# =========================
# Main inference
# =========================
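# This Space runs on ZeroGPU ("Running on Zero"), where GPU-touching entry
# points are normally wrapped with @spaces.GPU so a device is attached for the
# call; `spaces` is imported above, so the decorator is assumed intended here.
@spaces.GPU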
def generate_pcd(image, prompt):
    if image is None:
        raise gr.Error("Please upload an input image!")
    if not prompt:
        raise gr.Error("Please enter camera control prompts!")

    img = image.convert("RGB")
    # PIL size is (W, H); round both down to multiples of 32 for the pipeline.
    TARGET_H, TARGET_W = img.size[1], img.size[0]
    TARGET_H = TARGET_H // 32 * 32
    TARGET_W = TARGET_W // 32 * 32
    img = img.resize((TARGET_W, TARGET_H), Image.BICUBIC)

    all_steps = generate_all_motions_from_prompt(prompt, num_frames=81)
    cam_idx = list(range(81))
    traj = [build_estimate_rel(*all_steps[idx]) for idx in cam_idx]

    # VGGT is fed the same image twice, as a minimal two-view batch.
    first_frame = [img, img]
    first_frame = load_and_preprocess_images(first_frame)
    first_frame = first_frame.to(device)
    with torch.no_grad():
        # torch.cuda.amp.autocast is deprecated; torch.amp.autocast is the
        # modern equivalent.
        with torch.amp.autocast("cuda", dtype=dtype):
            predictions = vggt_model(first_frame)
            extrinsic, intrinsic = pose_encoding_to_extri_intri(predictions["pose_enc"], first_frame.shape[-2:])
    first_frame_world_points = predictions["world_points"][0][0]
    focals = intrinsic[0][0][:2, :2].diag().unsqueeze(0).to(device)
    principal_points = intrinsic[0][0][:2, 2].unsqueeze(0).to(device)
    raw_image = first_frame[0].cpu().numpy()
    raw_image = raw_image.transpose(1, 2, 0)
    render_results_list = []
    for estimate_rel in traj:
        estimate_rel = torch.from_numpy(estimate_rel).float().to(device)
        relative_c2ws = estimate_rel.unsqueeze(0)
        R, T = relative_c2ws[:, :3, :3], relative_c2ws[:, :3, 3:]
        # Flip the x and y axes: OpenCV-style cameras (x right, y down) ->
        # PyTorch3D's convention (x left, y up, z forward).
        R = torch.stack([-R[:, :, 0], -R[:, :, 1], R[:, :, 2]], 2)
        new_c2w = torch.cat([R, T], 2)
        w2c = torch.linalg.inv(torch.cat(
            (new_c2w, torch.Tensor([[[0, 0, 0, 1]]]).to(device).repeat(new_c2w.shape[0], 1, 1)),
            1
        ))
        # PyTorch3D uses a row-vector convention (X_cam = X_world @ R + T),
        # hence the transpose of the rotation block.
        R_new, T_new = w2c[:, :3, :3].permute(0, 2, 1), w2c[:, :3, 3]
        image_size = (first_frame.shape[-2:],)
        cameras = PerspectiveCameras(
            focal_length=focals,
            principal_point=principal_points,
            in_ndc=False,
            image_size=image_size,
            R=R_new,
            T=T_new,
            device=device,
        )
        masks = None
        render_results, viewmask = run_render(
            [first_frame_world_points],
            [raw_image],
            masks,
            image_size[0][0], image_size[0][1],
            cameras,
            1,
            device=device,
        )
        render_result = (render_results[-1].detach().cpu().numpy() * 255).astype(np.uint8)
        if len(render_result.shape) == 2:
            render_result = cv2.cvtColor(render_result, cv2.COLOR_GRAY2RGB)
        elif render_result.shape[-1] == 4:
            render_result = render_result[..., :3]
        render_results_list.append(render_result)

    # Replace frame 0 with the original (unwarped) input image.
    raw_image = first_frame[0].cpu().numpy()
    raw_image = raw_image.transpose(1, 2, 0)
    raw_image = (raw_image * 255).clip(0, 255).astype(np.uint8)
    render_results_list[0] = raw_image
    # Subsample 25 of the 81 rendered frames, then repeat the last one 4 times
    # to reach the 29 frames the Wan call in generate_final expects.
    frame_indices = np.linspace(0, 80, 25).round().astype(int)
    frames = []
    for idx in frame_indices:
        frame = render_results_list[idx]
        frame = Image.fromarray(frame)
        frames.append(frame)
    last = frames[-1]
    for _ in range(4):
        frames.append(last)

    def resize_pil(img):
        return img.resize((TARGET_W, TARGET_H), Image.BICUBIC)

    frames = [resize_pil(f) for f in frames]
    pcd_last = frames[-1]
    # Show the last point-cloud frame in the UI, and hand the full frame list
    # to the hidden state variable for step 2.
    return pcd_last, frames
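
# Assumed intended here as well (see the note above generate_pcd): attach a
# ZeroGPU device for the diffusion call.
@spaces.GPU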
def generate_final(image, frames, seed):
    if not frames:
        raise gr.Error("Please generate the point cloud first!")
    img = image.convert("RGB")
    TARGET_H, TARGET_W = img.size[1], img.size[0]
    TARGET_H = TARGET_H // 32 * 32
    TARGET_W = TARGET_W // 32 * 32

    def resize_pil(img_to_resize):
        return img_to_resize.resize((TARGET_W, TARGET_H), Image.BICUBIC)

    image = resize_pil(img)
    # ===== Wan =====
    video = wan_pipe(
        prompt="Ensure the consistency of the video",
        # Wan's standard (Chinese) negative prompt, kept verbatim so the model
        # behaves as in the reference setup.
        negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
        src_video=frames,  # frames cached by generate_pcd via the hidden state
        input_image=image,
        height=TARGET_H,
        width=TARGET_W,
        cfg_scale=5.0,
        num_frames=29,
        num_inference_steps=28,
        seed=int(seed),
        tiled=True,
    )
    video_frames = list(video)
    last_frame = np.array(video_frames[-1])
    return Image.fromarray(last_frame)
# =========================
# Gradio UI
# =========================
with gr.Blocks() as demo:
    # ===== Title + usage notes =====
    gr.HTML("""<div style="line-height:1.4; font-size:15px">
    <b style="font-size:18px">UniGeo: Unifying Geometric Guidance for Camera-Controllable Image Editing via Video Models</b><br>
    <hr style="margin:8px 0;">
    <b>Input Requirement / 输入要求</b><br>
    Due to VGGT and Wan model constraints, the input image should have width ≥ height.<br>
    由于 VGGT 与 Wan 模型限制，建议输入图像满足 宽 ≥ 高。<br>
    <hr style="margin:8px 0;">
    <b>Usage Guide / 使用说明</b>
    <ul style="margin-top: 4px; padding-left: 20px;">
    <li style="margin-bottom: 4px;"><b>Command Format / 指令格式:</b> Enter one or more camera commands separated by semicolons (e.g., "Camera pans left by 15 degrees" or "Camera moves left by 0.27; Camera pans right by 26 degrees").<br>
    支持输入一条或多条相机控制指令，使用分号分隔（例如“Camera pans left by 15 degrees”或“Camera moves left by 0.27; Camera pans right by 26 degrees”）。</li>
    <li style="margin-bottom: 4px;"><b>Scale & Adjustment / 尺度与调整:</b> Motion magnitudes are in VGGT's normalized scene scale; use the point-cloud preview to tune them.<br>
    所有运动数值由 VGGT 统一尺度建模，最终提供的点云结果可用于辅助调整相机运动参数。</li>
    <li><b>Tips / 提示:</b> The default 28 inference steps balance speed and quality; for better results, run locally with more steps.<br>
    为平衡时间与质量，当前推理步数设为 28。若想获得更佳效果，可在本地试着提高推理步数。</li>
    </ul>
    </div>""")

    # Hidden state used to pass the generated frames from step 1 to step 2.
    frames_state = gr.State([])
| gr.Markdown("### Step 1: Point Cloud Preview / 步骤一:点云预览与调节") | |
| with gr.Row(): | |
| with gr.Column(): | |
| inp = gr.Image(type="pil", label="Input Image") | |
| txt = gr.Textbox(label="Camera Prompt") | |
| btn_pcd = gr.Button("Generate Point Cloud (生成点云)") | |
| with gr.Column(): | |
| pcd_out = gr.Image(type="pil", label="Final Frame Point Cloud (预览结果)") | |
| gr.Markdown("### Step 2: Final Result Generation / 步骤二:生成最终结果") | |
| with gr.Row(): | |
| with gr.Column(): | |
| seed_inp = gr.Number(value=0, label="Seed", precision=0) | |
| btn_final = gr.Button("Generate Final Result (生成编辑结果)", variant="primary") | |
| with gr.Column(): | |
| out = gr.Image(type="numpy", label="Output Image") | |
    # ===== Step 1: generate the point-cloud preview and cache the frames =====
    btn_pcd.click(
        fn=generate_pcd,
        inputs=[inp, txt],
        outputs=[pcd_out, frames_state],  # update the preview; stash the frame list in state
    )
    # ===== Step 2: read the cached frames and generate the final image =====
    btn_final.click(
        fn=generate_final,
        inputs=[inp, frames_state, seed_inp],
        outputs=[out],
    )
if __name__ == "__main__":
    demo.launch(share=True)
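
# NOTE: share=True only matters for local runs; inside a Hugging Face Space,
# Gradio ignores it with a warning.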