import torch
import os
import re
import numpy as np
import gradio as gr
from PIL import Image
from scipy.spatial.transform import Rotation
import cv2
import sys
from huggingface_hub import snapshot_download
import spaces
import site
import importlib

# os.system("pip install ./pytorch3d-0.7.8+pt2.7.1cu126-cp310-cp310-linux_x86_64.whl")
# os.system("python -m pip install -e ./pytorch3d-0.7.8 --no-build-isolation")
# site.main()
# importlib.invalidate_caches()

# ===== VGGT =====
sys.path.append(os.path.join(os.getcwd(), "vggt"))
from vggt.models.vggt import VGGT
from vggt.utils.load_fn import load_and_preprocess_images
from vggt.utils.pose_enc import pose_encoding_to_extri_intri

# ===== Wan =====
sys.path.append(os.path.join(os.getcwd(), "DiffSynth-Studio"))
from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig
from safetensors.torch import load_file

# ===== PyTorch3D =====
from pytorch3d.structures import Pointclouds
from pytorch3d.renderer import (
    PerspectiveCameras,
    PointsRasterizationSettings,
    PointsRenderer,
    PointsRasterizer,
    AlphaCompositor,
)


def todevice(batch, device, callback=None, non_blocking=False):
    '''Transfer some variables to another device (i.e. GPU, CPU:torch, CPU:numpy).

    batch: list, tuple, dict of tensors or other things
    device: pytorch device or 'numpy'
    callback: function that would be called on every sub-elements.
    '''
    if isinstance(batch, dict):
        return {k: todevice(v, device) for k, v in batch.items()}

    if isinstance(batch, (tuple, list)):
        return type(batch)(todevice(x, device) for x in batch)

    x = batch
    if device == 'numpy':
        if isinstance(x, torch.Tensor):
            x = x.detach().cpu().numpy()
    elif x is not None:
        if isinstance(x, np.ndarray):
            x = torch.from_numpy(x)
        if torch.is_tensor(x):
            x = x.to(device, non_blocking=non_blocking)
    return x


def to_numpy(x):
    return todevice(x, 'numpy')


# =========================
# Global configs (CHANGE THESE PATHS)
# =========================
hf_token = os.getenv("HF_TOKEN")

VGGT_PATH = snapshot_download(repo_id="facebook/VGGT-1B", token=hf_token)
WAN_MODEL_DIR = snapshot_download(repo_id="Wan-AI/Wan2.2-TI2V-5B", token=hf_token)
LORA_DIR = snapshot_download(repo_id="123123aa123/UniGeo", token=hf_token)
LORA_PATH = os.path.join(LORA_DIR, "UniGeo_lora.safetensors")

# VGGT_PATH = "./VGGT_PATH"
# WAN_MODEL_DIR = "./WAN_MODEL_DIR"
# LORA_PATH = "./LORA_PATH"
WAN_CONFIG_PATH = "./my_config.json"

# =========================
# Global models
# =========================
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float16

vggt_model = None
wan_pipe = None


# =========================
# Load models once
# =========================
def load_models():
    global vggt_model, wan_pipe

    if vggt_model is None:
        print("Loading VGGT...")
        vggt_model = VGGT.from_pretrained(VGGT_PATH).to(device).eval()

    if wan_pipe is None:
        print("Loading Wan...")
        wan_paths = [
            os.path.join(WAN_MODEL_DIR, "diffusion_pytorch_model-00001-of-00003.safetensors"),
            os.path.join(WAN_MODEL_DIR, "diffusion_pytorch_model-00002-of-00003.safetensors"),
            os.path.join(WAN_MODEL_DIR, "diffusion_pytorch_model-00003-of-00003.safetensors"),
        ]
        wan_pipe = WanVideoPipeline.from_pretrained(
            torch_dtype=torch.bfloat16,
            device=device,
            model_configs=[
                ModelConfig(path=os.path.join(WAN_MODEL_DIR, "models_t5_umt5-xxl-enc-bf16.pth")),
                ModelConfig(path=os.path.join(WAN_MODEL_DIR, "Wan2.2_VAE.pth")),
            ],
            tokenizer_config=ModelConfig(path=os.path.join(WAN_MODEL_DIR, "google/umt5-xxl/")),
            wan_paths=wan_paths,
            wan_config_path=WAN_CONFIG_PATH,
        )

        ckpt = load_file(LORA_PATH)
        lora_sd, adapter_sd = {}, {}
        for k, v in ckpt.items():
            if ".lora_" in k:
                lora_sd[k] = v
            elif "i2v_adapter" in k:
                adapter_sd[k] = v
        wan_pipe.load_lora(wan_pipe.dit, state_dict=lora_sd, alpha=1)
        wan_pipe.dit.load_state_dict(adapter_sd, strict=False)
        wan_pipe.to(device)
        wan_pipe.to(dtype=torch.bfloat16)


load_models()
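# For reference, load_models() above assumes the downloaded snapshots contain the
# following files (paths taken verbatim from the calls above; if the upstream
# repositories reorganize their weights, these are the paths to update):
#   WAN_MODEL_DIR/diffusion_pytorch_model-0000{1,2,3}-of-00003.safetensors   (DiT shards)
#   WAN_MODEL_DIR/models_t5_umt5-xxl-enc-bf16.pth                            (T5 text encoder)
#   WAN_MODEL_DIR/Wan2.2_VAE.pth                                             (VAE)
#   WAN_MODEL_DIR/google/umt5-xxl/                                           (tokenizer)
#   LORA_DIR/UniGeo_lora.safetensors                                         (LoRA + i2v adapter)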
# =========================
# Renderer
# =========================
def setup_renderer(cameras, image_size):
    raster_settings = PointsRasterizationSettings(
        image_size=image_size,
        radius=0.01,
        points_per_pixel=10,
        bin_size=0,
    )
    renderer = PointsRenderer(
        rasterizer=PointsRasterizer(cameras=cameras, raster_settings=raster_settings),
        compositor=AlphaCompositor(),
    )
    render_setup = {'cameras': cameras, 'raster_settings': raster_settings, 'renderer': renderer}
    return render_setup


def render_pcd(pts3d, imgs, masks, views, renderer, device, nbv=False):
    imgs = to_numpy(imgs)
    pts3d = to_numpy(pts3d)

    if masks is None:
        pts = torch.from_numpy(np.concatenate([p for p in pts3d])).view(-1, 3).to(device)
        col = torch.from_numpy(np.concatenate([p for p in imgs])).view(-1, 3).to(device)
    else:
        pts = torch.from_numpy(np.concatenate([p[m] for p, m in zip(pts3d, masks)])).to(device)
        col = torch.from_numpy(np.concatenate([p[m] for p, m in zip(imgs, masks)])).to(device)

    point_cloud = Pointclouds(points=[pts], features=[col]).extend(views)
    images = renderer(point_cloud)

    if nbv:
        color_mask = torch.ones(col.shape).to(device)
        point_cloud_mask = Pointclouds(points=[pts], features=[color_mask]).extend(views)
        view_masks = renderer(point_cloud_mask)
    else:
        view_masks = None
    return images, view_masks


def run_render(pcd, imgs, masks, H, W, camera_traj, num_views, device, nbv=True):
    render_setup = setup_renderer(camera_traj, image_size=(H, W))
    renderer = render_setup['renderer']
    render_results, viewmask = render_pcd(pcd, imgs, masks, num_views, renderer, device, nbv=nbv)
    return render_results, viewmask


# =========================
# Prompt parsing
# =========================
def generate_all_motions_from_prompt(prompt, num_frames):
    x, y, z, phi, theta = parse_prompt_to_motion(prompt)
    results = []
    for i in range(num_frames):
        alpha = i / (num_frames - 1)
        results.append((
            x * alpha,
            y * alpha,
            z * alpha,
            phi * alpha,
            theta * alpha,
        ))
    return results


def parse_prompt_to_motion(prompt):
    prompt = prompt.lower()
    x = y = z = phi = theta = 0.0

    clauses = re.split(r'[;,\n]| and ', prompt)
    for clause in clauses:
        nums = re.findall(r"[-+]?\d*\.?\d+", clause)
        if not nums:
            continue
        val = float(nums[0])

        if "pans left" in clause:
            phi = -val
        elif "pans right" in clause:
            phi = val
        elif "tilts up" in clause:
            theta = val
        elif "tilts down" in clause:
            theta = -val
        elif "moves forward" in clause:
            z = val
        elif "moves backward" in clause:
            z = -val
        elif "moves up" in clause:
            y = -val
        elif "moves down" in clause:
            y = val
        elif "moves left" in clause:
            x = -val
        elif "moves right" in clause:
            x = val

    return x, y, z, phi, theta


def build_estimate_rel(x, y, z, phi, theta):
    delta_euler = [theta, phi, 0.0]
    rot_mat = Rotation.from_euler('xyz', delta_euler, degrees=True).as_matrix()
    mat = np.eye(4)
    mat[:3, :3] = rot_mat
    mat[:3, 3] = [x, y, z]
    return mat
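# Illustrative example of the prompt convention parsed above (not executed):
#
#   prompt = "the camera pans right 30 and moves forward 0.5"
#   parse_prompt_to_motion(prompt)  ->  (x=0.0, y=0.0, z=0.5, phi=30.0, theta=0.0)
#
# generate_all_motions_from_prompt() then interpolates these targets linearly over
# the requested frames (frame 0 is the identity pose, the last frame reaches the
# full motion), and build_estimate_rel() turns each step into a 4x4 pose matrix
# with an XYZ Euler rotation of (theta, phi, 0) degrees and translation (x, y, z).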
# =========================
# Main inference
# =========================
@spaces.GPU
def generate_pcd(image, prompt):
    if image is None:
        raise gr.Error("Please upload an input image!")
    if not prompt:
        raise gr.Error("Please enter camera control prompts!")

    img = image.convert("RGB")
    TARGET_H, TARGET_W = img.size[1], img.size[0]
    TARGET_H = TARGET_H // 32 * 32
    TARGET_W = TARGET_W // 32 * 32
    img = img.resize((TARGET_W, TARGET_H), Image.BICUBIC)

    all_steps = generate_all_motions_from_prompt(prompt, num_frames=81)
    cam_idx = list(range(81))
    traj = [build_estimate_rel(*all_steps[idx]) for idx in cam_idx]

    first_frame = [img, img]
    first_frame = load_and_preprocess_images(first_frame)
    first_frame = first_frame.to(device)

    with torch.no_grad():
        with torch.cuda.amp.autocast(dtype=dtype):
            predictions = vggt_model(first_frame)
            extrinsic, intrinsic = pose_encoding_to_extri_intri(predictions["pose_enc"], first_frame.shape[-2:])

    first_frame_world_points = predictions["world_points"][0][0]
    focals = intrinsic[0][0][:2, :2].diag().unsqueeze(0).to(device)
    principal_points = intrinsic[0][0][:2, 2].unsqueeze(0).to(device)

    raw_image = first_frame[0].cpu().numpy()
    raw_image = raw_image.transpose(1, 2, 0)

    render_results_list = []
    for estimate_rel in traj:
        estimate_rel = torch.from_numpy(estimate_rel).float().to(device)
        relative_c2ws = estimate_rel.unsqueeze(0)

        R, T = relative_c2ws[:, :3, :3], relative_c2ws[:, :3, 3:]
        R = torch.stack([-R[:, :, 0], -R[:, :, 1], R[:, :, 2]], 2)
        new_c2w = torch.cat([R, T], 2)
        w2c = torch.linalg.inv(torch.cat(
            (new_c2w, torch.Tensor([[[0, 0, 0, 1]]]).to(device).repeat(new_c2w.shape[0], 1, 1)), 1
        ))
        R_new, T_new = w2c[:, :3, :3].permute(0, 2, 1), w2c[:, :3, 3]

        image_size = (first_frame.shape[-2:],)
        cameras = PerspectiveCameras(
            focal_length=focals,
            principal_point=principal_points,
            in_ndc=False,
            image_size=image_size,
            R=R_new,
            T=T_new,
            device=device,
        )

        masks = None
        render_results, viewmask = run_render(
            [first_frame_world_points], [raw_image], masks,
            image_size[0][0], image_size[0][1], cameras, 1, device=device
        )

        render_result = (render_results[-1].detach().cpu().numpy() * 255).astype(np.uint8)
        if len(render_result.shape) == 2:
            render_result = cv2.cvtColor(render_result, cv2.COLOR_GRAY2RGB)
        elif render_result.shape[-1] == 4:
            render_result = render_result[..., :3]
        render_results_list.append(render_result)

    raw_image = first_frame[0].cpu().numpy()
    raw_image = raw_image.transpose(1, 2, 0)
    raw_image = (raw_image * 255).clip(0, 255).astype(np.uint8)
    render_results_list[0] = raw_image

    frame_indices = np.linspace(0, 80, 25).round().astype(int)
    frames = []
    for idx in frame_indices:
        frame = render_results_list[idx]
        frame = Image.fromarray(frame)
        frames.append(frame)
    last = frames[-1]
    for _ in range(4):
        frames.append(last)

    def resize_pil(img):
        return img.resize((TARGET_W, TARGET_H), Image.BICUBIC)

    frames = [resize_pil(f) for f in frames]
    pcd_last = frames[-1]

    # Show the last point-cloud render in the UI, and pass the full frame list
    # to the hidden state variable for the second stage.
    return pcd_last, frames


@spaces.GPU
def generate_final(image, frames, seed):
    if not frames:
        raise gr.Error("Please generate point cloud first!")

    img = image.convert("RGB")
    TARGET_H, TARGET_W = img.size[1], img.size[0]
    TARGET_H = TARGET_H // 32 * 32
    TARGET_W = TARGET_W // 32 * 32

    def resize_pil(img_to_resize):
        return img_to_resize.resize((TARGET_W, TARGET_H), Image.BICUBIC)

    image = resize_pil(img)

    # ===== Wan =====
    video = wan_pipe(
        prompt="Ensure the consistency of the video",
        # Standard Chinese negative prompt used with Wan models; kept verbatim so the
        # text encoder sees the original wording.
        negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
        src_video=frames,  # reuse the frames state passed from the previous step
        input_image=image,
        height=TARGET_H,
        width=TARGET_W,
        cfg_scale=5.0,
        num_frames=29,
        num_inference_steps=28,
        seed=int(seed),
        tiled=True,
    )
    video_frames = list(video)
    last_frame = np.array(video_frames[-1])
    return Image.fromarray(last_frame)
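# The UI below is expected to feed generate_pcd's second output (the frame list)
# into a hidden gr.State, which generate_final then consumes. A minimal sketch of
# that wiring, assuming component names that do not appear in this file
# (illustrative only; the real layout follows):
#
#   frames_state = gr.State()
#   pcd_btn.click(generate_pcd, inputs=[input_image, prompt_box],
#                 outputs=[pcd_preview, frames_state])
#   gen_btn.click(generate_final, inputs=[input_image, frames_state, seed_box],
#                 outputs=[result_image])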
# =========================
# Gradio UI
# =========================
with gr.Blocks() as demo:
    # ===== Title + description =====
    gr.HTML("""