# UniGeo / app.py
import torch, os, re
import numpy as np
import gradio as gr
from PIL import Image
from scipy.spatial.transform import Rotation
import cv2, sys
from huggingface_hub import snapshot_download
import spaces
import site
import importlib
#os.system("pip install ./pytorch3d-0.7.8+pt2.7.1cu126-cp310-cp310-linux_x86_64.whl")
#os.system("python -m pip install -e ./pytorch3d-0.7.8 --no-build-isolation")
#site.main()
#importlib.invalidate_caches()
# ===== VGGT =====
sys.path.append(os.path.join(os.getcwd(), "vggt"))
from vggt.models.vggt import VGGT
from vggt.utils.load_fn import load_and_preprocess_images
from vggt.utils.pose_enc import pose_encoding_to_extri_intri
# ===== Wan =====
sys.path.append(os.path.join(os.getcwd(), "DiffSynth-Studio"))
from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig
from safetensors.torch import load_file
# ===== PyTorch3D =====
from pytorch3d.structures import Pointclouds
from pytorch3d.renderer import (
PerspectiveCameras,
PointsRasterizationSettings,
PointsRenderer,
PointsRasterizer,
AlphaCompositor,
)
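# PyTorch3D is used below to re-render the VGGT point cloud from each pose along the
# requested camera trajectory.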
def todevice(batch, device, callback=None, non_blocking=False):
    '''Recursively transfer variables to another device (e.g. GPU, CPU torch tensors, or numpy).
    batch: list, tuple, or dict of tensors (or other objects)
    device: a pytorch device, or the string 'numpy'
    callback: optional function applied to every leaf element
    '''
    if isinstance(batch, dict):
        return {k: todevice(v, device, callback, non_blocking) for k, v in batch.items()}
    if isinstance(batch, (tuple, list)):
        return type(batch)(todevice(x, device, callback, non_blocking) for x in batch)
    x = batch
    if callback:
        x = callback(x)
    if device == 'numpy':
        if isinstance(x, torch.Tensor):
            x = x.detach().cpu().numpy()
    elif x is not None:
        if isinstance(x, np.ndarray):
            x = torch.from_numpy(x)
        if torch.is_tensor(x):
            x = x.to(device, non_blocking=non_blocking)
    return x
def to_numpy(x): return todevice(x, 'numpy')
# =========================
# Global configs (model paths)
# =========================
hf_token = os.getenv("HF_TOKEN")
VGGT_PATH = snapshot_download(repo_id="facebook/VGGT-1B", token=hf_token)
WAN_MODEL_DIR = snapshot_download(repo_id="Wan-AI/Wan2.2-TI2V-5B", token=hf_token)
LORA_DIR = snapshot_download(repo_id="123123aa123/UniGeo", token=hf_token)
LORA_PATH = os.path.join(LORA_DIR, "UniGeo_lora.safetensors")
#VGGT_PATH = "./VGGT_PATH"
#WAN_MODEL_DIR = "./WAN_MODEL_DIR"
#LORA_PATH = "./LORA_PATH"
WAN_CONFIG_PATH = "./my_config.json"
# =========================
# Global models
# =========================
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float16
vggt_model = None
wan_pipe = None
# =========================
# Load models once
# =========================
def load_models():
global vggt_model, wan_pipe
if vggt_model is None:
print("Loading VGGT...")
vggt_model = VGGT.from_pretrained(VGGT_PATH).to(device).eval()
if wan_pipe is None:
print("Loading Wan...")
wan_paths = [
os.path.join(WAN_MODEL_DIR, "diffusion_pytorch_model-00001-of-00003.safetensors"),
os.path.join(WAN_MODEL_DIR, "diffusion_pytorch_model-00002-of-00003.safetensors"),
os.path.join(WAN_MODEL_DIR, "diffusion_pytorch_model-00003-of-00003.safetensors"),
]
wan_pipe = WanVideoPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device=device,
model_configs=[
ModelConfig(path=os.path.join(WAN_MODEL_DIR, "models_t5_umt5-xxl-enc-bf16.pth")),
ModelConfig(path=os.path.join(WAN_MODEL_DIR, "Wan2.2_VAE.pth")),
],
tokenizer_config=ModelConfig(path=os.path.join(WAN_MODEL_DIR, "google/umt5-xxl/")),
wan_paths=wan_paths,
wan_config_path=WAN_CONFIG_PATH
)
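        # The UniGeo checkpoint bundles LoRA weights (".lora_" keys) and i2v adapter weights
        # ("i2v_adapter" keys); split the state dict and load each part separately.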
ckpt = load_file(LORA_PATH)
lora_sd, adapter_sd = {}, {}
for k, v in ckpt.items():
if ".lora_" in k:
lora_sd[k] = v
elif "i2v_adapter" in k:
adapter_sd[k] = v
wan_pipe.load_lora(wan_pipe.dit, state_dict=lora_sd, alpha=1)
wan_pipe.dit.load_state_dict(adapter_sd, strict=False)
wan_pipe.to(device)
wan_pipe.to(dtype=torch.bfloat16)
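# Load both models once at import time so every Gradio request reuses the same instances.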
load_models()
# =========================
# Renderer
# =========================
def setup_renderer(cameras, image_size):
    raster_settings = PointsRasterizationSettings(
        image_size=image_size,
        radius=0.01,
        points_per_pixel=10,
        bin_size=0,  # naive (non-binned) rasterization; avoids bin overflow on dense point clouds
    )
renderer = PointsRenderer(
rasterizer=PointsRasterizer(cameras=cameras, raster_settings=raster_settings),
compositor=AlphaCompositor()
)
render_setup = {'cameras': cameras, 'raster_settings': raster_settings, 'renderer': renderer}
return render_setup
def render_pcd(pts3d, imgs, masks, views, renderer, device, nbv=False):
imgs = to_numpy(imgs)
pts3d = to_numpy(pts3d)
if masks is None:
pts = torch.from_numpy(np.concatenate([p for p in pts3d])).view(-1, 3).to(device)
col = torch.from_numpy(np.concatenate([p for p in imgs])).view(-1, 3).to(device)
else:
pts = torch.from_numpy(np.concatenate([p[m] for p, m in zip(pts3d, masks)])).to(device)
col = torch.from_numpy(np.concatenate([p[m] for p, m in zip(imgs, masks)])).to(device)
point_cloud = Pointclouds(points=[pts], features=[col]).extend(views)
images = renderer(point_cloud)
if nbv:
color_mask = torch.ones(col.shape).to(device)
point_cloud_mask = Pointclouds(points=[pts], features=[color_mask]).extend(views)
view_masks = renderer(point_cloud_mask)
else:
view_masks = None
return images, view_masks
def run_render(pcd, imgs, masks, H, W, camera_traj, num_views, device, nbv=True):
render_setup = setup_renderer(camera_traj, image_size=(H,W))
renderer = render_setup['renderer']
render_results, viewmask = render_pcd(pcd, imgs, masks, num_views, renderer, device, nbv=nbv)
return render_results, viewmask
# =========================
# Prompt parsing
# =========================
def generate_all_motions_from_prompt(prompt, num_frames):
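    # Linearly interpolate the commanded motion: frame i gets the target scaled by
    # i / (num_frames - 1), so frame 0 is the identity pose and the last frame reaches
    # the full motion.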
x, y, z, phi, theta = parse_prompt_to_motion(prompt)
results = []
for i in range(num_frames):
alpha = i / (num_frames - 1)
results.append((
x * alpha,
y * alpha,
z * alpha,
phi * alpha,
theta * alpha
))
return results
def parse_prompt_to_motion(prompt):
prompt = prompt.lower()
x = y = z = phi = theta = 0.0
clauses = re.split(r'[;,\n]| and ', prompt)
for clause in clauses:
nums = re.findall(r"[-+]?\d*\.?\d+", clause)
if not nums:
continue
val = float(nums[0])
if "pans left" in clause:
phi = -val
elif "pans right" in clause:
phi = val
elif "tilts up" in clause:
theta = val
elif "tilts down" in clause:
theta = -val
elif "moves forward" in clause:
z = val
elif "moves backward" in clause:
z = -val
elif "moves up" in clause:
y = -val
elif "moves down" in clause:
y = val
elif "moves left" in clause:
x = -val
elif "moves right" in clause:
x = val
return x, y, z, phi, theta
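# Example: parse_prompt_to_motion("Camera pans left by 15 degrees; Camera moves forward by 0.3")
# returns (x, y, z, phi, theta) = (0.0, 0.0, 0.3, -15.0, 0.0).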
def build_estimate_rel(x, y, z, phi, theta):
delta_euler = [theta, phi, 0.0]
rot_mat = Rotation.from_euler('xyz', delta_euler, degrees=True).as_matrix()
mat = np.eye(4)
mat[:3, :3] = rot_mat
mat[:3, 3] = [x, y, z]
return mat
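# For example, build_estimate_rel(0.0, 0.0, 0.3, -15.0, 0.0) yields a 4x4 relative pose whose
# rotation is a -15 degree rotation about the y axis (pan) and whose translation is [0, 0, 0.3].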
# =========================
# Main inference
# =========================
@spaces.GPU
def generate_pcd(image, prompt):
if image is None:
raise gr.Error("Please upload an input image!")
if not prompt:
raise gr.Error("Please enter camera control prompts!")
img = image.convert("RGB")
    TARGET_H, TARGET_W = img.size[1], img.size[0]
    # Round the target size down to a multiple of 32 (size constraint of the models).
    TARGET_H = TARGET_H // 32 * 32
    TARGET_W = TARGET_W // 32 * 32
img = img.resize((TARGET_W, TARGET_H), Image.BICUBIC)
all_steps = generate_all_motions_from_prompt(prompt, num_frames=81)
cam_idx = list(range(81))
traj = [build_estimate_rel(*all_steps[idx]) for idx in cam_idx]
first_frame = [img, img]
first_frame = load_and_preprocess_images(first_frame)
first_frame = first_frame.to(device)
with torch.no_grad():
with torch.cuda.amp.autocast(dtype=dtype):
predictions = vggt_model(first_frame)
extrinsic, intrinsic = pose_encoding_to_extri_intri(predictions["pose_enc"], first_frame.shape[-2:])
first_frame_world_points = predictions["world_points"][0][0]
focals = intrinsic[0][0][:2, :2].diag().unsqueeze(0).to(device)
principal_points = intrinsic[0][0][:2, 2].unsqueeze(0).to(device)
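    # fx, fy and cx, cy come from the predicted intrinsics of the first view, in pixel units
    # (the PerspectiveCameras below are built with in_ndc=False).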
raw_image = first_frame[0].cpu().numpy()
raw_image = raw_image.transpose(1, 2, 0)
render_results_list = []
for estimate_rel in traj:
estimate_rel = torch.from_numpy(estimate_rel).float().to(device)
relative_c2ws = estimate_rel.unsqueeze(0)
R, T = relative_c2ws[:, :3, :3], relative_c2ws[:, :3, 3:]
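        # Flip the x and y axes of the rotation: this maps an OpenCV-style camera convention
        # (+X right, +Y down, +Z forward) to PyTorch3D's (+X left, +Y up, +Z forward).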
R = torch.stack([-R[:, :, 0], -R[:, :, 1], R[:, :, 2]], 2)
new_c2w = torch.cat([R, T], 2)
w2c = torch.linalg.inv(torch.cat(
(new_c2w, torch.Tensor([[[0, 0, 0, 1]]]).to(device).repeat(new_c2w.shape[0], 1, 1)),
1
))
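        # PyTorch3D applies rotations to row vectors (X_view = X_world @ R + T), so the
        # world-to-camera rotation is transposed before being passed to PerspectiveCameras.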
R_new, T_new = w2c[:, :3, :3].permute(0, 2, 1), w2c[:, :3, 3]
image_size = (first_frame.shape[-2:],)
cameras = PerspectiveCameras(
focal_length=focals,
principal_point=principal_points,
in_ndc=False,
image_size=image_size,
R=R_new,
T=T_new,
device=device
)
masks = None
render_results, viewmask = run_render(
[first_frame_world_points],
[raw_image],
masks,
image_size[0][0], image_size[0][1],
cameras,
1,
device=device
)
render_result = (render_results[-1].detach().cpu().numpy() * 255).astype(np.uint8)
if len(render_result.shape) == 2:
render_result = cv2.cvtColor(render_result, cv2.COLOR_GRAY2RGB)
elif render_result.shape[-1] == 4:
render_result = render_result[..., :3]
render_results_list.append(render_result)
raw_image = first_frame[0].cpu().numpy()
raw_image = raw_image.transpose(1, 2, 0)
raw_image = (raw_image * 255).clip(0, 255).astype(np.uint8)
render_results_list[0] = raw_image
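    # Subsample the 81 rendered frames to 25 and repeat the last one 4 times, giving the 29
    # conditioning frames expected by the Wan call (num_frames=29) in generate_final.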
frame_indices = np.linspace(0, 80, 25).round().astype(int)
frames = []
for idx in frame_indices:
frame = render_results_list[idx]
frame = Image.fromarray(frame)
frames.append(frame)
last = frames[-1]
for _ in range(4):
frames.append(last)
def resize_pil(img):
return img.resize((TARGET_W, TARGET_H), Image.BICUBIC)
frames = [resize_pil(f) for f in frames]
pcd_last = frames[-1]
    # Return the last point-cloud frame for display in the UI, and pass the full list of
    # frames to the hidden state variable.
return pcd_last, frames
@spaces.GPU
def generate_final(image, frames, seed):
if not frames:
raise gr.Error("Please generate point cloud first!")
img = image.convert("RGB")
TARGET_H, TARGET_W = img.size[1], img.size[0]
TARGET_H = TARGET_H // 32 * 32
TARGET_W = TARGET_W // 32 * 32
def resize_pil(img_to_resize):
return img_to_resize.resize((TARGET_W, TARGET_H), Image.BICUBIC)
image = resize_pil(img)
# ===== Wan =====
video = wan_pipe(
prompt="Ensure the consistency of the video",
        # Standard Chinese negative prompt for Wan, listing artifacts to suppress
        # (overexposure, blurry details, extra fingers, deformed limbs, static frames, etc.).
        negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
        src_video=frames,  # the frames cached from the point-cloud step
input_image=image,
height=TARGET_H,
width=TARGET_W,
cfg_scale=5.0,
num_frames=29,
num_inference_steps=28,
seed=int(seed),
tiled=True
)
video_frames = list(video)
last_frame = np.array(video_frames[-1])
return Image.fromarray(last_frame)
# =========================
# Gradio UI
# =========================
with gr.Blocks() as demo:
    # ===== Title + description =====
gr.HTML("""<div style="line-height:1.4; font-size:15px">
<b style="font-size:18px">UniGeo: Unifying Geometric Guidance for Camera-Controllable Image Editing via Video Models</b><br>
<hr style="margin:8px 0;">
<b>Input Requirement / 输入要求</b><br>
The input image is recommended to have width ≥ height due to VGGT and Wan model constraints.<br>
由于 VGGT 与 Wan 模型限制,建议输入图像满足 宽 ≥ 高。<br>
<hr style="margin:8px 0;">
<b>Usage Guide / 使用说明</b>
<ul style="margin-top: 4px; padding-left: 20px;">
<li style="margin-bottom: 4px;"><b>Command Format / 指令格式:</b>You can input one or multiple camera commands separated by semicolons (e.g., “Camera pans left by 15 degrees” or “Camera moves left by 0.27; Camera pans right by 26 degrees”).<br>
支持输入一条或多条相机控制指令,使用分号分隔(例如“Camera pans left by 15 degrees”或“Camera moves left by 0.27; Camera pans right by 26 degrees”)。</li>
<li style="margin-bottom: 4px;"><b>Scale & Adjustment / 尺度与调整:</b>The motion scale is normalized by VGGT, and the final point cloud is provided to help adjust motion parameters.<br>
所有运动数值由 VGGT 统一尺度建模,最终提供的点云结果可用于辅助调整相机运动参数。</li>
<li><b>Tips / 提示:</b>Default inference steps: 28 (Speed & Quality balanced). Run locally with higher steps for better results. <br>
为平衡时间与质量,当前推理步数设为 28。若想获得更佳效果,可在本地试着提高推理步数。</li>
</ul>
</div>""")
    # Hidden state used to pass the generated video frames between the two steps
frames_state = gr.State([])
gr.Markdown("### Step 1: Point Cloud Preview / 步骤一:点云预览与调节")
with gr.Row():
with gr.Column():
inp = gr.Image(type="pil", label="Input Image")
txt = gr.Textbox(label="Camera Prompt")
btn_pcd = gr.Button("Generate Point Cloud (生成点云)")
with gr.Column():
pcd_out = gr.Image(type="pil", label="Final Frame Point Cloud (预览结果)")
gr.Markdown("### Step 2: Final Result Generation / 步骤二:生成最终结果")
with gr.Row():
with gr.Column():
seed_inp = gr.Number(value=0, label="Seed", precision=0)
btn_final = gr.Button("Generate Final Result (生成编辑结果)", variant="primary")
with gr.Column():
out = gr.Image(type="numpy", label="Output Image")
    # ===== Step 1: generate the point-cloud preview and cache the rendered frames =====
btn_pcd.click(
fn=generate_pcd,
inputs=[inp, txt],
        outputs=[pcd_out, frames_state]  # update the point-cloud preview and cache the frames list in the hidden state
)
    # ===== Step 2: read the cached frames and generate the final image =====
btn_final.click(
fn=generate_final,
inputs=[inp, frames_state, seed_inp],
outputs=[out]
)
if __name__ == "__main__":
demo.launch(share=True)