# vggt-omega / app.py
# JianyuanWang's picture
# add example
# 2597ec6
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import gc
import glob
import os
import shutil
from datetime import datetime
import cv2
import gradio as gr
import numpy as np
import spaces
import torch
from huggingface_hub import hf_hub_download
from visual_util import predictions_to_glb
from vggt_omega.models import VGGTOmega
from vggt_omega.utils.load_fn import load_and_preprocess_images
from vggt_omega.utils.pose_enc import encoding_to_camera
# Hugging Face Hub repository and file name of the checkpoint that
# _build_model() downloads.
CHECKPOINT_REPO_ID = "facebook/VGGT-Omega"
CHECKPOINT_FILENAME = "vggt_omega_1b_512.pt"
# Passed as `image_resolution` to load_and_preprocess_images() in run_model().
IMAGE_RESOLUTION = 512
def _build_model() -> VGGTOmega:
    """Download the checkpoint from the Hub and return the model in eval mode.

    The weights are deserialized with ``map_location="cpu"``; moving the model
    to an accelerator is left to the caller.

    Returns:
        A ``VGGTOmega`` instance with the checkpoint's state dict loaded.
    """
    weights_file = hf_hub_download(
        repo_id=CHECKPOINT_REPO_ID,
        filename=CHECKPOINT_FILENAME,
    )
    print(f"Loaded checkpoint to {weights_file}")

    network = VGGTOmega()
    network = network.eval()
    network.load_state_dict(torch.load(weights_file, map_location="cpu"))
    return network
# Built once at import time and shared by every request; run_model() moves it
# to the GPU when it is invoked.
MODEL = _build_model()
@spaces.GPU()
def run_model(target_dir: str) -> dict:
    """Run the model on every file under ``<target_dir>/images``.

    Args:
        target_dir: Working directory that contains an ``images`` sub-folder.

    Returns:
        A dict with the model outputs converted to NumPy arrays, plus
        ``"extrinsic"``/``"intrinsic"`` decoded from ``"pose_enc"`` and
        ``"world_points_from_depth"`` computed by
        ``unproject_depth_map_to_point_map``.

    Raises:
        gr.Error: If the ``images`` folder contains no files.
    """
    print(f"Processing images from {target_dir}")

    frame_paths = glob.glob(os.path.join(target_dir, "images", "*"))
    frame_paths.sort()
    if not frame_paths:
        raise gr.Error("No images found. Please upload images or a video first.")

    device = "cuda"
    model = MODEL.to(device)
    images = load_and_preprocess_images(frame_paths, image_resolution=IMAGE_RESOLUTION)
    images = images.to(device)
    print(f"Preprocessed images shape: {tuple(images.shape)}")

    with torch.inference_mode():
        predictions = model(images)
        # Decode the pose encoding into camera matrices, using the spatial
        # size (last two dims) of the images echoed back by the model.
        image_hw = predictions["images"].shape[-2:]
        extrinsic, intrinsic = encoding_to_camera(predictions["pose_enc"], image_hw)
        predictions["extrinsic"] = extrinsic
        predictions["intrinsic"] = intrinsic

    def _as_numpy(item):
        # Entries that are not tensors are passed through unchanged.
        if not isinstance(item, torch.Tensor):
            return item
        array = item.detach().float().cpu().numpy()
        # Drop a leading singleton axis (presumably the batch dimension).
        return array[0] if array.shape[0] == 1 else array

    predictions_np = {name: _as_numpy(item) for name, item in predictions.items()}
    predictions_np["world_points_from_depth"] = unproject_depth_map_to_point_map(
        predictions_np["depth"],
        predictions_np["extrinsic"],
        predictions_np["intrinsic"],
    )

    torch.cuda.empty_cache()
    return predictions_np
def unproject_depth_map_to_point_map(depth_map: np.ndarray, extrinsic: np.ndarray, intrinsic: np.ndarray) -> np.ndarray:
    """Lift per-pixel depths to 3D points using pinhole intrinsics and extrinsics.

    Each pixel ``(x, y)`` with depth ``d`` becomes the camera-frame point
    ``((x - cx) / fx * d, (y - cy) / fy * d, d)``, which is then mapped through
    ``R^T (p - t)`` where ``R = extrinsic[:, :3, :3]`` and
    ``t = extrinsic[:, :3, 3]``.

    Args:
        depth_map: Array indexed as ``[frame, row, col, channel]``; only
            channel 0 is used.
        extrinsic: Per-frame matrices providing ``R`` and ``t`` as above.
        intrinsic: Per-frame matrices with ``fx, fy`` on the diagonal and
            ``cx, cy`` in the last column.

    Returns:
        Array of shape ``(frames, height, width, 3)`` with the transformed points.
    """
    z = depth_map[..., 0]
    _, height, width = z.shape

    # Integer pixel grids: `rows` varies along axis 0, `cols` along axis 1.
    rows, cols = np.mgrid[0:height, 0:width]

    fx = intrinsic[:, 0, 0].reshape(-1, 1, 1)
    fy = intrinsic[:, 1, 1].reshape(-1, 1, 1)
    cx = intrinsic[:, 0, 2].reshape(-1, 1, 1)
    cy = intrinsic[:, 1, 2].reshape(-1, 1, 1)

    # Broadcasting expands the (H, W) grids against the per-frame parameters.
    x_cam = (cols[None] - cx) / fx * z
    y_cam = (rows[None] - cy) / fy * z
    camera_points = np.stack([x_cam, y_cam, z], axis=-1)

    rotation = extrinsic[:, :3, :3]
    offset = camera_points - extrinsic[:, :3, 3][:, None, None, :]
    # Contracting over the first matrix index of R applies R^T to each point.
    return np.einsum("fkc,fhwk->fhwc", rotation, offset)
def file_path(file_data) -> str:
    """Resolve an upload payload to a path string.

    Lookup order:
      1. For a dict: the ``"name"`` entry, then the ``"path"`` entry, then a
         recursive lookup on a non-``None`` ``"video"`` entry.
      2. The object's ``name`` attribute, if it has one.
      3. ``str(file_data)`` as a last resort.
    """
    if isinstance(file_data, dict):
        for field in ("name", "path"):
            if field in file_data:
                return file_data[field]
        nested_video = file_data.get("video")
        if nested_video is not None:
            return file_path(nested_video)

    try:
        return file_data.name
    except AttributeError:
        return str(file_data)
def handle_uploads(input_video, input_images, video_sample_fps=1.0):
    """Copy uploaded images and sampled video frames into a fresh working directory.

    Every call creates ``demo_outputs/input_images_<timestamp>/images``.
    Uploaded images are copied there; video frames are written as zero-padded
    PNG files, keeping one frame out of every
    ``round(reported_fps / video_sample_fps)`` (at least every frame).

    Destination names are made unique, so two uploads that share a basename
    (or an upload that collides with a generated frame name) no longer
    overwrite each other; later duplicates get a ``_<n>`` suffix.

    Args:
        input_video: Video payload accepted by ``file_path``, or ``None``.
        input_images: Iterable of image payloads accepted by ``file_path``,
            or ``None``.
        video_sample_fps: Frames to keep per second of video; values below
            0.1 are clamped to 0.1.

    Returns:
        ``(target_dir, image_paths)`` where ``image_paths`` is the sorted list
        of files written under ``<target_dir>/images``.
    """
    gc.collect()
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
    target_dir = os.path.join("demo_outputs", f"input_images_{timestamp}")
    target_dir_images = os.path.join(target_dir, "images")
    os.makedirs(target_dir_images, exist_ok=True)

    image_paths = []
    used_names = set()

    def _unique_destination(file_name):
        # Return a path in target_dir_images whose name was not handed out before.
        stem, ext = os.path.splitext(file_name)
        candidate = file_name
        attempt = 1
        while candidate in used_names:
            candidate = f"{stem}_{attempt}{ext}"
            attempt += 1
        used_names.add(candidate)
        return os.path.join(target_dir_images, candidate)

    if input_images is not None:
        for item in input_images:
            src_path = file_path(item)
            dst_path = _unique_destination(os.path.basename(src_path))
            shutil.copy(src_path, dst_path)
            image_paths.append(dst_path)

    if input_video is not None:
        video = cv2.VideoCapture(file_path(input_video))
        # try/finally guarantees the capture handle is released even when
        # reading, converting the sampling rate, or writing a frame raises.
        try:
            fps = video.get(cv2.CAP_PROP_FPS)
            video_sample_fps = max(float(video_sample_fps), 0.1)
            # Treat a missing or non-positive reported FPS as 1.
            source_fps = fps if fps and fps > 0 else 1
            frame_interval = max(int(round(source_fps / video_sample_fps)), 1)

            frame_idx = 0
            saved_idx = 0
            while True:
                ok, frame = video.read()
                if not ok:
                    break
                if frame_idx % frame_interval == 0:
                    image_path = _unique_destination(f"{saved_idx:06}.png")
                    cv2.imwrite(image_path, frame)
                    image_paths.append(image_path)
                    saved_idx += 1
                frame_idx += 1
        finally:
            video.release()

    image_paths = sorted(image_paths)
    return target_dir, image_paths
def update_gallery_on_upload(input_video, input_images, video_sample_fps):
    """Stage new uploads and produce the values for the upload-change outputs.

    Returns:
        A 4-tuple ``(model3d_value, target_dir, gallery_paths, log_message)``.
        The first element is always ``None``. When neither a video nor images
        are present, the directory is the string ``"None"`` and the gallery
        value is ``None``.
    """
    has_input = bool(input_video) or bool(input_images)
    if not has_input:
        return None, "None", None, "Upload images or a video."

    staged_dir, staged_paths = handle_uploads(input_video, input_images, video_sample_fps)
    return None, staged_dir, staged_paths, "Upload complete. Click Reconstruct."
def gradio_demo(
    target_dir,
    conf_thres=20.0,
    mask_black_bg=False,
    mask_white_bg=False,
    show_cam=True,
    mask_sky=False,
    max_points_k=1000,
):
    """Run inference for ``target_dir`` and export the resulting scene as a GLB file.

    The predictions are also saved to ``<target_dir>/predictions.npz`` so that
    ``update_visualization`` can rebuild the scene without running the model again.

    Args:
        target_dir: Working directory created by ``handle_uploads``.
        conf_thres: Confidence threshold; values below 3.0 are raised to 3.0.
        mask_black_bg, mask_white_bg, show_cam, mask_sky: Visualization flags
            forwarded to ``predictions_to_glb``.
        max_points_k: Point budget in thousands.

    Returns:
        ``(glb_file_path, status_message)``.

    Raises:
        gr.Error: If ``target_dir`` is empty, the string ``"None"``, or not a directory.
    """
    has_workdir = bool(target_dir) and target_dir != "None" and os.path.isdir(target_dir)
    if not has_workdir:
        raise gr.Error("Please upload images or a video first.")

    conf_thres = max(3.0, float(conf_thres))
    gc.collect()

    # Counted before inference; only used in the status message.
    frame_count = len(os.listdir(os.path.join(target_dir, "images")))

    predictions = run_model(target_dir)
    np.savez(os.path.join(target_dir, "predictions.npz"), **predictions)

    render_options = {
        "conf_thres": conf_thres,
        "mask_black_bg": mask_black_bg,
        "mask_white_bg": mask_white_bg,
        "show_cam": show_cam,
        "mask_sky": mask_sky,
    }
    glbfile = glb_path(target_dir, max_points_k=max_points_k, **render_options)
    scene = predictions_to_glb(
        predictions,
        target_dir=target_dir,
        max_points=int(max_points_k * 1000),
        **render_options,
    )
    scene.export(file_obj=glbfile)

    # Drop the prediction arrays before returning to the UI.
    del predictions
    gc.collect()

    status = f"Reconstruction complete: {frame_count} frames."
    return glbfile, status
def glb_path(
    target_dir,
    conf_thres,
    mask_black_bg,
    mask_white_bg,
    show_cam,
    mask_sky,
    max_points_k,
):
    """Return the GLB file path inside ``target_dir`` for one set of visualization settings.

    Every setting is encoded in the file name, so each distinct combination
    maps to its own file. ``max_points_k`` is truncated to an integer.
    """
    settings_tag = "_".join(
        (
            f"conf{conf_thres}",
            f"black{mask_black_bg}",
            f"white{mask_white_bg}",
            f"cam{show_cam}",
            f"sky{mask_sky}",
            f"max{int(max_points_k)}k",
        )
    )
    return os.path.join(target_dir, f"scene_{settings_tag}.glb")
def update_visualization(
    target_dir,
    conf_thres,
    mask_black_bg,
    mask_white_bg,
    show_cam,
    mask_sky,
    max_points_k,
):
    """Rebuild the GLB scene from saved predictions without running the model.

    A GLB file is written only when no file for this exact combination of
    settings exists yet in ``target_dir``; otherwise the existing file is reused.

    Returns:
        ``(glb_file_path, status_message)``, or ``(None, message)`` when
        ``target_dir`` is unusable or has no ``predictions.npz``.
    """
    unavailable = (None, "No reconstruction available. Click Reconstruct first.")

    has_workdir = bool(target_dir) and target_dir != "None" and os.path.isdir(target_dir)
    if not has_workdir:
        return unavailable

    predictions_path = os.path.join(target_dir, "predictions.npz")
    if not os.path.exists(predictions_path):
        return unavailable

    conf_thres = max(3.0, float(conf_thres))
    render_options = {
        "conf_thres": conf_thres,
        "mask_black_bg": mask_black_bg,
        "mask_white_bg": mask_white_bg,
        "show_cam": show_cam,
        "mask_sky": mask_sky,
    }
    glbfile = glb_path(target_dir, max_points_k=max_points_k, **render_options)

    if os.path.exists(glbfile):
        return glbfile, "Visualization updated."

    # Copy every array out of the archive so it can be closed right away.
    with np.load(predictions_path) as archive:
        predictions = {name: np.array(archive[name]) for name in archive.files}

    scene = predictions_to_glb(
        predictions,
        target_dir=target_dir,
        max_points=int(max_points_k * 1000),
        **render_options,
    )
    scene.export(file_obj=glbfile)
    return glbfile, "Visualization updated."
def clear_model3d():
    """Return ``None``, the value used to reset the 3D viewer output."""
    cleared_value = None
    return cleared_value
def update_log():
    """Return the status text shown while a reconstruction is running."""
    status_text = "Loading and Reconstructing..."
    return status_text
def update_visual_log():
    """Return the status text shown while the visualization is being refreshed."""
    status_text = "Updating visualization..."
    return status_text
# -------------------------------------------------------------------------
# Example videos
# -------------------------------------------------------------------------
# Relative paths to the example clips. build_ui() pairs each clip with the
# visualization settings (confidence threshold, sky filter, max points) used
# in its row of the Examples table.
conf_20_video = "examples/conf_20.mp4"
conf_20_robot_video = "examples/conf_20_robot.mp4"
conf_30_video = "examples/conf_30.mp4"
conf50_video = "examples/conf50.mp4"
conf50_filter_sky_video = "examples/conf50_filter_sky.mp4"
def build_ui():
    """Assemble the Gradio Blocks interface and wire up its event handlers.

    Layout: an intro header, a left column for uploads and the preview
    gallery, a right column with the 3D viewer, action buttons and
    visualization controls, followed by the examples table.

    Returns:
        The constructed ``gr.Blocks`` app (not launched).
    """
    # NOTE(review): the nesting of the layout containers below was
    # reconstructed from a source whose indentation had been lost — confirm
    # against the rendered UI.
    ocean_theme = gr.themes.Ocean()
    ocean_theme.set(
        checkbox_label_background_fill_selected="*button_primary_background_fill",
        checkbox_label_text_color_selected="*button_primary_text_color",
    )

    # Stylesheet for elements tagged with the "custom-log" class (the status line).
    status_css = """
.custom-log * {
font-style: italic;
font-size: 22px !important;
background-image: linear-gradient(120deg, #0ea5e9 0%, #6ee7b7 60%, #34d399 100%);
-webkit-background-clip: text;
background-clip: text;
font-weight: bold !important;
color: transparent !important;
text-align: center !important;
}
"""

    # Intro header with links and usage instructions.
    header_html = """
<h1>🌀 VGGT-Ω</h1>
<p>
<a href="https://github.com/facebookresearch/vggt-omega">🐙 GitHub Repository</a> |
<a href="https://vggt-omega.github.io/">Project Page</a>
</p>
<div style="font-size: 16px; line-height: 1.5;">
<p>Upload a video or a set of images to create a 3D reconstruction of a scene or object. VGGT-Ω takes these images and generates a 3D point cloud, along with estimated camera poses.</p>
<h3>Getting Started:</h3>
<ol>
<li><strong>Upload Your Data:</strong> Use the "Upload Video" or "Upload Images" buttons on the left to provide your input. Videos will be automatically split into individual frames using the selected sampling rate.</li>
<li><strong>Preview:</strong> Your uploaded images will appear in the gallery on the left.</li>
<li><strong>Reconstruct:</strong> Click the "Reconstruct" button to run camera and depth inference and build the first GLB scene.</li>
<li><strong>Visualize:</strong> The point cloud and camera poses will appear in the viewer on the right. You can rotate, pan, zoom, and download the GLB file.</li>
<li>
<strong>Adjust Visualization (Optional):</strong>
After reconstruction, adjust the visualization options and click "Update Visual" to refresh the GLB without rerunning inference.
</li>
</ol>
<p><strong style="color: #0ea5e9;">Please note:</strong> <span style="color: #0ea5e9; font-weight: bold;">The demo limits Max Points by default to keep the UI responsive; increase Max Points if you need a denser point cloud. Visualizing very dense point clouds may take longer due to third-party rendering, which is independent of VGGT-Ω's processing time.</span></p>
</div>
"""

    with gr.Blocks(theme=ocean_theme, css=status_css) as demo:
        gr.HTML(header_html)

        # Hidden textbox carrying the current working directory between events.
        target_dir_output = gr.Textbox(label="Target Dir", visible=False, value="None")

        with gr.Row():
            # Left column: inputs and preview.
            with gr.Column(scale=2):
                input_video = gr.Video(label="Upload Video", interactive=True)
                video_sample_fps = gr.Slider(
                    minimum=0.5,
                    maximum=2.0,
                    value=1.0,
                    step=0.1,
                    label="Video Sampling FPS",
                    interactive=True,
                )
                input_images = gr.File(file_count="multiple", label="Upload Images", interactive=True)
                image_gallery = gr.Gallery(
                    label="Preview",
                    columns=4,
                    height="300px",
                    show_download_button=True,
                    object_fit="contain",
                    preview=True,
                )

            # Right column: viewer, buttons and visualization controls.
            with gr.Column(scale=4):
                with gr.Column():
                    gr.Markdown("**Reconstruction (Point Cloud and Camera Poses)**")
                    log_output = gr.Markdown(
                        "Please upload a video or images, then click Reconstruct.",
                        elem_classes=["custom-log"],
                    )
                    reconstruction_output = gr.Model3D(height=780, zoom_speed=0.2, pan_speed=0.2)

                with gr.Row():
                    submit_btn = gr.Button("Reconstruct", scale=1, variant="primary")
                    update_visual_btn = gr.Button("Update Visual", scale=1)
                    gr.ClearButton(
                        [input_video, input_images, reconstruction_output, log_output, target_dir_output, image_gallery],
                        scale=1,
                    )

                with gr.Row():
                    conf_thres = gr.Slider(
                        minimum=2,
                        maximum=100,
                        value=50,
                        step=0.1,
                        label="Confidence Threshold (%)",
                    )
                    max_points_k = gr.Slider(
                        minimum=500,
                        maximum=10000,
                        value=1000,
                        step=500,
                        label="Max Points (K points)",
                    )
                    with gr.Column():
                        show_cam = gr.Checkbox(label="Show Camera", value=True)
                        mask_sky = gr.Checkbox(label="Filter Sky", value=False)
                        mask_black_bg = gr.Checkbox(label="Filter Black Background", value=False)
                        mask_white_bg = gr.Checkbox(label="Filter White Background", value=False)

        # ---------------------- Examples section ----------------------
        # Each row: video, sampling fps, images, confidence threshold,
        # black-bg mask, white-bg mask, show cameras, sky mask, max points (K).
        example_specs = (
            (conf_20_video, 20.0, False, 1000),
            (conf_20_robot_video, 20.0, False, 2000),
            (conf_30_video, 30.0, False, 1000),
            (conf50_video, 50.0, False, 1000),
            (conf50_filter_sky_video, 50.0, True, 1000),
        )
        examples = [
            [clip, 1.0, [], threshold, False, False, True, filter_sky, points_k]
            for clip, threshold, filter_sky, points_k in example_specs
        ]

        def example_pipeline(
            input_video,
            video_sample_fps,
            input_images,
            conf_thres,
            mask_black_bg,
            mask_white_bg,
            show_cam,
            mask_sky,
            max_points_k,
        ):
            # Stage the example inputs, then run the full reconstruction.
            staged_dir, staged_paths = handle_uploads(input_video, input_images, video_sample_fps)
            glbfile, log_msg = gradio_demo(
                staged_dir,
                conf_thres,
                mask_black_bg,
                mask_white_bg,
                show_cam,
                mask_sky,
                max_points_k,
            )
            return glbfile, log_msg, staged_dir, staged_paths

        gr.Markdown("Click any row to load an example.")
        gr.Examples(
            examples=examples,
            inputs=[
                input_video,
                video_sample_fps,
                input_images,
                conf_thres,
                mask_black_bg,
                mask_white_bg,
                show_cam,
                mask_sky,
                max_points_k,
            ],
            outputs=[
                reconstruction_output,
                log_output,
                target_dir_output,
                image_gallery,
            ],
            fn=example_pipeline,
            cache_examples=False,
            examples_per_page=50,
        )

        # Re-stage the inputs whenever the video, the images or the sampling
        # rate changes. Handlers are registered in the same order as before.
        for upload_source in (input_video, input_images, video_sample_fps):
            upload_source.change(
                fn=update_gallery_on_upload,
                inputs=[input_video, input_images, video_sample_fps],
                outputs=[reconstruction_output, target_dir_output, image_gallery, log_output],
            )

        # Inputs/outputs shared by "Reconstruct" and "Update Visual"; each
        # event receives its own copy of the lists.
        visual_inputs = [
            target_dir_output,
            conf_thres,
            mask_black_bg,
            mask_white_bg,
            show_cam,
            mask_sky,
            max_points_k,
        ]
        visual_outputs = [reconstruction_output, log_output]

        # Reconstruct: clear the viewer, show a progress message, run inference.
        reconstruct_chain = submit_btn.click(fn=clear_model3d, inputs=[], outputs=[reconstruction_output])
        reconstruct_chain = reconstruct_chain.then(fn=update_log, inputs=[], outputs=[log_output])
        reconstruct_chain.then(
            fn=gradio_demo,
            inputs=list(visual_inputs),
            outputs=list(visual_outputs),
        )

        # Update Visual: show a progress message, rebuild the GLB from saved predictions.
        refresh_chain = update_visual_btn.click(fn=update_visual_log, inputs=[], outputs=[log_output])
        refresh_chain.then(
            fn=update_visualization,
            inputs=list(visual_inputs),
            outputs=list(visual_outputs),
        )

    return demo
# Build the interface at import time and enable Gradio's request queue
# (max_size=20).
demo = build_ui()
demo.queue(max_size=20)
if __name__ == "__main__":
    # When run as a script, serve on all network interfaces at port 7860.
    demo.launch(server_name="0.0.0.0", server_port=7860, show_error=True)