# vdpm / gradio_demo.py
# Uploaded via huggingface_hub (commit 4d798a0, verified) by dxm21
import gc
import glob
import os
# Set fragmentation fix BEFORE importing torch (or immediately after imports)
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
import shutil
import sys
import time
import zipfile
from datetime import datetime
from pathlib import Path
import cv2
import gradio as gr
import matplotlib
import numpy as np
import plotly.graph_objects as go
import torch
import torch.nn.functional as F
import json # Added for exporting
import spaces # ZeroGPU support
from hydra import compose, initialize
from hydra.core.global_hydra import GlobalHydra
from dpm.model import VDPM
from vggt.utils.load_fn import load_and_preprocess_images
from util.depth import write_depth_to_png
import torch
import torch._dynamo
torch._dynamo.config.suppress_errors = True
# suppress_errors makes Dynamo fall back to eager execution when a compiled kernel
# (e.g. the fused GELU that was crashing) fails, instead of raising
torch.set_float32_matmul_precision('high')
# ============================================================================
# MEMORY OPTIMIZATION SETTINGS FOR 8GB GPUs (RTX 3070 Ti, 3060 Ti, etc.)
# ============================================================================
# Model size: 1.66B parameters (~3GB FP16 weights, ~12GB total with activations)
#
# Memory reduction options (choose one):
# USE_HALF_PRECISION = True: FP16 model -> ~1.5GB weights, ~6-7GB total (RECOMMENDED FOR GPU)
# USE_QUANTIZATION = True: INT8 quantization -> CPU ONLY, not supported on CUDA
# Both False: FP16/BF16 inference only -> ~3GB weights, ~8-10GB total (may OOM)
#
# MAX_FRAMES: Limit input frames (5 recommended for 8GB GPUs)
# ============================================================================
# Visualization / runtime tunables.
MAX_POINTS_PER_FRAME = 50_000   # cap on rendered points per frame
TRAIL_LENGTH = 20               # frames of history drawn per track trail
MAX_TRACKS = 150                # total track budget across all views
STATIC_THRESHOLD = 0.01  # Lower threshold to show more subtle motion
VIDEO_SAMPLE_HZ = 2.0

# Precision / frame-budget configuration; MAX_FRAMES is auto-scaled below.
USE_QUANTIZATION = False
USE_HALF_PRECISION = True
MAX_FRAMES = 5  # conservative default for 8GB GPUs

device = "cuda" if torch.cuda.is_available() else "cpu"
if device == "cuda":
    # TF32 speeds up matmuls/convs on Ampere+ with negligible accuracy loss.
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    # Scale the frame budget with the amount of VRAM on device 0.
    vram_bytes = torch.cuda.get_device_properties(0).total_memory
    vram_gb = vram_bytes / (1024**3)
    print(f"\u2713 GPU Detected: {torch.cuda.get_device_name(0)} ({vram_gb:.1f} GB VRAM)")
    if vram_gb >= 30:  # A10G (24GB), A100 (40/80GB), RTX 3090/4090 (24GB)
        MAX_FRAMES = 80
        print(f" -> High VRAM detected! Increased MAX_FRAMES to {MAX_FRAMES}")
    elif vram_gb >= 15:  # T4 (16GB), 4080 (16GB)
        MAX_FRAMES = 32
        print(f" -> Medium VRAM detected! Increased MAX_FRAMES to {MAX_FRAMES}")
    elif vram_gb >= 7.5:  # RTX 3070 Ti, 2080, etc (8GB)
        MAX_FRAMES = 24
        print(f" -> 8GB VRAM detected. Set MAX_FRAMES to {MAX_FRAMES}")
    else:
        MAX_FRAMES = 8
        print(f" -> Low VRAM (<8GB). Keeping MAX_FRAMES at {MAX_FRAMES} to prevent OOM")
    print(f"\u2713 TF32 enabled for faster matrix operations")
def load_cfg_from_cli() -> "omegaconf.DictConfig":
    """Compose the `visualise` Hydra config using CLI overrides.

    Returns:
        The composed DictConfig for the `visualise` config group.
    """
    # Hydra may only be initialized once per process; clear any prior state.
    # (The original code performed this clear twice — once is sufficient.)
    if GlobalHydra.instance().is_initialized():
        GlobalHydra.instance().clear()
    # In notebooks or some interactive shells, sys.argv contains kernel launch flags
    # like `--f=...` which Hydra cannot parse as overrides. Filter those out.
    raw_overrides = sys.argv[1:]
    overrides = [o for o in raw_overrides if not (o.startswith("--f=") or "kernel" in o)]
    with initialize(config_path="configs"):
        return compose(config_name="visualise", overrides=overrides)
def load_model(cfg) -> VDPM:
    """Build a VDPM model, load (and cache) its checkpoint, and apply precision options.

    Downloads the checkpoint to ~/.cache/vdpm on first use, then optionally
    converts the model to FP16/BF16 or INT8 according to the module-level
    USE_HALF_PRECISION / USE_QUANTIZATION flags.

    Args:
        cfg: Hydra config passed straight to the VDPM constructor.

    Returns:
        The model in eval mode, on the module-level `device`.
    """
    model = VDPM(cfg).to(device)
    # Use a persistent cache directory that Spaces preserves
    cache_dir = os.path.expanduser("~/.cache/vdpm")
    os.makedirs(cache_dir, exist_ok=True)
    model_path = os.path.join(cache_dir, "vdpm_model.pt")
    _URL = "https://huggingface.co/edgarsucar/vdpm/resolve/main/model.pt"
    # Download only if not cached
    if not os.path.exists(model_path):
        print(f"Downloading model to {model_path}...")
        sd = torch.hub.load_state_dict_from_url(
            _URL,
            file_name="vdpm_model.pt",
            progress=True,
            map_location=device
        )
        # Save to persistent location
        torch.save(sd, model_path)
        print(f"✓ Model cached at {model_path}")
    else:
        print(f"✓ Loading cached model from {model_path}")
        sd = torch.load(model_path, map_location=device)
    # Printing the result surfaces any missing/unexpected keys from strict loading.
    print(model.load_state_dict(sd, strict=True))
    model.eval()
    # Option 1: Use FP16/BF16 for all model weights (simple, ~2x memory/speed boost)
    if USE_HALF_PRECISION and not USE_QUANTIZATION:
        # Use BF16 only on Hopper+ (compute >= 9) where BF16 throughput matches FP16
        # On Ampere (compute 8.x, e.g. 3070Ti), FP16 tensor cores are ~2x faster than BF16
        if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 9:
            print("Converting model to BF16 precision (Hopper+ GPU detected)...")
            model = model.to(torch.bfloat16)
            print("✓ Model converted to BF16")
        else:
            print("Converting model to FP16 precision...")
            # Convert model to full FP16 (with helper that can preserve norm layers if desired)
            try:
                # NOTE(review): to_fp16 appears to be a VDPM-specific helper — confirm.
                model = model.to_fp16(keep_norm_fp32=False)
                print("✓ Model converted to FP16 via to_fp16()")
            except Exception:
                # Fallback to half() if helper not available
                model = model.half()
                print("✓ Model converted to FP16 via half() (fallback)")
    # Option 2: Apply INT8 dynamic quantization (more aggressive, ~3-4x reduction)
    if USE_QUANTIZATION:
        try:
            print("Applying INT8 dynamic quantization to reduce memory usage...")
            # Move to CPU for quantization, then back to GPU
            model = model.cpu()
            model = torch.quantization.quantize_dynamic(
                model,
                {torch.nn.Linear, torch.nn.Conv2d},  # Quantize these layer types
                dtype=torch.qint8
            )
            model = model.to(device)
            print("✓ Model quantized: ~3x memory reduction (3GB -> 1GB)")
        except Exception as e:
            print(f"⚠️ Quantization failed: {e}")
            print("Continuing with FP16/BF16 precision...")
            # On failure the model may still be on CPU; move it back to the target device.
            model = model.to(device)
    # Enable torch.compile for faster inference (PyTorch 2.0+)
    # torch.compile is handled per-block in decoder.compile_blocks()
    # called during load_state_dict above
    return model
def require_cuda():
    """Raise ValueError unless a CUDA device was detected at startup."""
    if device == "cuda":
        return
    raise ValueError("CUDA is not available. Check your environment.")
def gradio_file_path(file_obj):
    """Return the filesystem path behind a Gradio upload.

    Accepts None, a dict with a "name" key, or a plain path; anything else
    is passed through unchanged.
    """
    if isinstance(file_obj, dict):
        return file_obj.get("name", file_obj)
    return file_obj
def ensure_nhwc_images(images: np.ndarray) -> np.ndarray:
    """Convert channel-first (T, 3, H, W) images to (T, H, W, 3); pass through otherwise."""
    is_channel_first = images.ndim == 4 and images.shape[1] == 3
    return np.transpose(images, (0, 2, 3, 1)) if is_channel_first else images
# ============================================================================
# NORMALS AND OPTICAL FLOW COMPUTATION UTILITIES
# ============================================================================
def compute_normals_from_pointmap(point_map: np.ndarray) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Compute surface normals from a point map using cross product of local tangents.

    Args:
        point_map: (T, V, H, W, 3) or (T, H, W, 3) point map in world coordinates.

    Returns:
        normals: Same shape as input, unit normals in world space.
        tangent_x: Same shape as input, unit dP/dx tangents.
        tangent_y: Same shape as input, unit dP/dy tangents.

    Raises:
        ValueError: If the point map is neither 4D nor 5D.
    """
    # Handle both (T, V, H, W, 3) and (T, H, W, 3) by flattening to (B, H, W, 3).
    if point_map.ndim == 5:
        T, V, H, W, _ = point_map.shape
        pm = point_map.reshape(T * V, H, W, 3)
    elif point_map.ndim == 4:
        pm = point_map
    else:
        raise ValueError(f"Expected 4D or 5D point map, got shape {point_map.shape}")
    # Convert to torch and permute to (B, 3, H, W) for spatial finite differences.
    pm_torch = torch.from_numpy(pm).permute(0, 3, 1, 2).float()
    # Forward differences along rows (dy) and columns (dx).
    dy = pm_torch[:, :, 1:, :] - pm_torch[:, :, :-1, :]
    dx = pm_torch[:, :, :, 1:] - pm_torch[:, :, :, :-1]
    # Replicate-pad the last row/column back to the original resolution.
    dy = F.pad(dy, (0, 0, 0, 1), mode="replicate")
    dx = F.pad(dx, (0, 1, 0, 0), mode="replicate")
    # Unit tangents; the normal is the normalized cross product dx x dy.
    tangent_x = F.normalize(dx, dim=1, eps=1e-8)
    tangent_y = F.normalize(dy, dim=1, eps=1e-8)
    normals = torch.cross(dx, dy, dim=1)
    normals = F.normalize(normals, dim=1, eps=1e-8)
    # Degenerate neighborhoods (zero-length tangents) can yield NaN/Inf — zero them.
    tangent_x = torch.nan_to_num(tangent_x, nan=0.0, posinf=0.0, neginf=0.0)
    tangent_y = torch.nan_to_num(tangent_y, nan=0.0, posinf=0.0, neginf=0.0)
    normals = torch.nan_to_num(normals, nan=0.0, posinf=0.0, neginf=0.0)
    # Convert back to numpy with the original layout.
    normals_np = normals.permute(0, 2, 3, 1).numpy()
    tangent_x_np = tangent_x.permute(0, 2, 3, 1).numpy()
    tangent_y_np = tangent_y.permute(0, 2, 3, 1).numpy()
    if point_map.ndim == 5:
        normals_np = normals_np.reshape(T, V, H, W, 3)
        tangent_x_np = tangent_x_np.reshape(T, V, H, W, 3)
        tangent_y_np = tangent_y_np.reshape(T, V, H, W, 3)
    return normals_np.astype(np.float32), tangent_x_np.astype(np.float32), tangent_y_np.astype(np.float32)
def compute_smooth_normals(normals: np.ndarray, kernel_size: int = 5) -> np.ndarray:
    """Apply Gaussian smoothing to normal maps then renormalize.

    Args:
        normals: (T, V, H, W, 3) or (T, H, W, 3) normal map.
        kernel_size: Side length of the square Gaussian kernel.

    Returns:
        Smoothed and renormalized normal maps with same shape as input.

    Raises:
        ValueError: If the normal map is neither 4D nor 5D.
    """
    if normals.ndim == 5:
        T, V, H, W, _ = normals.shape
        n = normals.reshape(T * V, H, W, 3)
    elif normals.ndim == 4:
        n = normals
    else:
        raise ValueError(f"Expected 4D or 5D normals, got shape {normals.shape}")
    # Convert to torch (B, 3, H, W)
    n_torch = torch.from_numpy(n.astype(np.float32)).permute(0, 3, 1, 2)
    # Build a 2D Gaussian kernel as the outer product of a normalized 1D Gaussian.
    sigma = kernel_size / 4.0
    coords = torch.arange(kernel_size, dtype=torch.float32) - kernel_size // 2
    gauss = torch.exp(-coords**2 / (2 * sigma**2))
    gauss = gauss / gauss.sum()
    kernel = gauss.unsqueeze(0) * gauss.unsqueeze(1)
    kernel = kernel.unsqueeze(0).unsqueeze(0).repeat(3, 1, 1, 1)  # (3, 1, kH, kW)
    # Apply the 2D Gaussian blur per channel (grouped conv, zero padding).
    padding = kernel_size // 2
    n_smoothed = F.conv2d(n_torch, kernel, padding=padding, groups=3)
    # Re-normalize to unit length; scrub NaN/Inf from degenerate pixels.
    n_smoothed = F.normalize(n_smoothed, dim=1, eps=1e-8)
    n_smoothed = torch.nan_to_num(n_smoothed, nan=0.0, posinf=0.0, neginf=0.0)
    # Convert back to the original layout.
    result = n_smoothed.permute(0, 2, 3, 1).numpy()
    if normals.ndim == 5:
        result = result.reshape(T, V, H, W, 3)
    return result.astype(np.float32)
def compute_optical_flow(world_points: np.ndarray, extrinsics: np.ndarray = None, intrinsics: np.ndarray = None, num_views: int = 1) -> np.ndarray:
    """Compute optical flow from world point trajectories.

    Optical flow is computed as the 2D displacement of points between consecutive frames.
    If camera intrinsics/extrinsics are provided, we project to image space (using the
    current frame's camera for both frames, so the flow reflects point motion only).
    Otherwise, we use the XY displacement in world space as an approximation.

    Args:
        world_points: (T, V, H, W, 3) world point trajectories.
        extrinsics: Optional (N, 4, 4) camera extrinsics, N = T * V.
        intrinsics: Optional (N, 3, 3) camera intrinsics.
        num_views: Number of views per timestep (kept for API compatibility; V is
            taken from world_points' shape).

    Returns:
        optical_flow: (T-1, V, H, W, 2) flow as (dx, dy) in pixels or world units;
        a single zero frame when T < 2.
    """
    T, V, H, W, _ = world_points.shape
    if T < 2:
        # Cannot compute flow with less than 2 frames
        return np.zeros((1, V, H, W, 2), dtype=np.float32)
    if extrinsics is not None and intrinsics is not None:
        # Project points to image space and compute flow.
        # Fix: guard the larger index (t+1) against BOTH camera arrays — the
        # original checked cam_idx_curr vs extrinsics but cam_idx_next vs intrinsics.
        n_cams = min(len(extrinsics), len(intrinsics))
        flows = []
        for t in range(T - 1):
            flow_t = np.zeros((V, H, W, 2), dtype=np.float32)
            for v in range(V):
                cam_idx_curr = t * V + v
                cam_idx_next = (t + 1) * V + v
                if cam_idx_next >= n_cams:  # implies cam_idx_curr is also valid
                    continue
                pts_curr = world_points[t, v]      # (H, W, 3)
                pts_next = world_points[t + 1, v]  # (H, W, 3)
                # Current frame's camera projects both frames' points.
                K = intrinsics[cam_idx_curr]  # (3, 3)
                R = extrinsics[cam_idx_curr, :3, :3]
                t_vec = extrinsics[cam_idx_curr, :3, 3]
                # Transform to camera space
                pts_cam_curr = np.einsum('ij,hwj->hwi', R, pts_curr) + t_vec
                pts_cam_next = np.einsum('ij,hwj->hwi', R, pts_next) + t_vec
                # Project to image space (avoid division by zero)
                z_curr = np.maximum(pts_cam_curr[..., 2:3], 1e-6)
                z_next = np.maximum(pts_cam_next[..., 2:3], 1e-6)
                xy_curr = pts_cam_curr[..., :2] / z_curr
                xy_next = pts_cam_next[..., :2] / z_next
                uv_curr = xy_curr @ K[:2, :2].T + K[:2, 2]
                uv_next = xy_next @ K[:2, :2].T + K[:2, 2]
                flow_t[v] = uv_next - uv_curr
            flows.append(flow_t)
        return np.stack(flows, axis=0).astype(np.float32)
    else:
        # Simple world-space XY displacement
        flow = world_points[1:, ..., :2] - world_points[:-1, ..., :2]
        return flow.astype(np.float32)
def write_normal_to_png(outpath: str, normals: np.ndarray) -> None:
    """Write a normal map as a PNG image.

    Args:
        outpath: Output file path.
        normals: (H, W, 3) normal map with values in [-1, 1].
    """
    import PIL.Image
    # Remap [-1, 1] -> [0, 1], clamp, then quantize to 8-bit RGB.
    unit = (normals.astype(np.float32) + 1.0) * 0.5
    rgb = (np.clip(unit, 0.0, 1.0) * 255.0).astype(np.uint8)
    PIL.Image.fromarray(rgb).save(outpath)
def write_optical_flow_to_png(outpath: str, flow: np.ndarray, max_flow: float = None) -> None:
    """Write optical flow as a color-coded PNG using HSV color wheel.

    Hue encodes flow direction, saturation encodes magnitude, value is constant.

    Args:
        outpath: Output file path.
        flow: (H, W, 2) optical flow as (dx, dy).
        max_flow: Optional maximum flow magnitude for normalization.
    """
    import PIL.Image
    dx = flow[..., 0]
    dy = flow[..., 1]
    mag = np.sqrt(dx**2 + dy**2)
    ang = np.arctan2(dy, dx)  # [-pi, pi]
    # Normalize magnitude into [0, 1]; guard against an all-zero flow field.
    if max_flow is None:
        max_flow = mag.max() if mag.max() > 0 else 1.0
    mag = np.clip(mag / max_flow, 0, 1)
    # OpenCV's uint8 HSV hue channel spans [0, 180).
    hue = ((ang + np.pi) / (2 * np.pi) * 180).astype(np.uint8)
    saturation = (mag * 255).astype(np.uint8)
    value = np.full_like(hue, 255, dtype=np.uint8)
    hsv = np.stack([hue, saturation, value], axis=-1)
    rgb = cv2.cvtColor(hsv, cv2.COLOR_HSV2RGB)
    PIL.Image.fromarray(rgb).save(outpath)
def _flow_to_image_3d(flow_hw3: np.ndarray) -> np.ndarray:
"""Convert a 3D flow map (H, W, 3) to an RGB image.
Maps each XYZ component to R, G, B scaled by the global max magnitude.
"""
mag = np.linalg.norm(flow_hw3, axis=-1, keepdims=True)
max_mag = float(mag.max()) + 1e-8
norm = flow_hw3 / max_mag
rgb = np.clip((norm + 1.0) * 0.5, 0.0, 1.0)
return (rgb * 255).astype(np.uint8)
def create_output_zip(target_dir: str) -> str:
    """Create a zip file containing all outputs for download.

    Args:
        target_dir: Directory containing all output files.

    Returns:
        Path to the created zip file.
    """
    zip_path = os.path.join(target_dir, "outputs.zip")
    # Known output artifacts; missing entries are silently skipped.
    wanted = (
        "output_4d.npz",
        "tracks.npz",
        "poses.npz",
        "depths.npz",
        "depth_normals.npz",
        "scene_flow.npz",
        "angular_flow.npz",
        "depths",        # directory
        "normals",       # directory
        "scene_flow",    # directory
        "angular_flow",  # directory
        "images",        # directory
        "meta.json",
    )
    with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as archive:
        for name in wanted:
            path = os.path.join(target_dir, name)
            if os.path.isfile(path):
                archive.write(path, name)
            elif os.path.isdir(path):
                # Recurse into directories, keeping paths relative to target_dir.
                for root, _dirs, files in os.walk(path):
                    for fname in files:
                        fpath = os.path.join(root, fname)
                        archive.write(fpath, os.path.relpath(fpath, target_dir))
    print(f"✓ Created output zip: {zip_path}")
    return zip_path
def compute_scene_bounds(world_points: np.ndarray):
    """Return (global_min, global_max, aspectratio) for a padded scene bounding box."""
    flat = world_points.reshape(-1, 3)
    lo = flat.min(axis=0)
    hi = flat.max(axis=0)
    center = 0.5 * (lo + hi)
    half_extent = 0.5 * (hi - lo) * 1.05  # pad 5% on each side
    # Guard against degenerate (near-zero) extents.
    if np.all(half_extent < 1e-6):
        half_extent[:] = 1.0
    else:
        half_extent[half_extent < 1e-6] = half_extent.max()
    max_half = half_extent.max()
    # Plotly aspect ratio relative to the largest axis.
    aspectratio = {axis: float(h / max_half) for axis, h in zip("xyz", half_extent)}
    return center - half_extent, center + half_extent, aspectratio
def stride_downsample(pts: np.ndarray, cols: np.ndarray, max_points: int):
    """Uniformly subsample points and colors down to at most max_points rows."""
    total = pts.shape[0]
    if total <= max_points:
        return pts, cols
    stride = int(np.ceil(total / max_points))
    keep = np.arange(0, total, stride)[:max_points]
    return pts[keep], cols[keep]
# ============================================================
# NEW: Single shared mask function (used by points + tracks)
# ============================================================
def compute_point_mask(
conf_score: np.ndarray | None,
cols: np.ndarray,
conf_thres: float,
mask_black_bg: bool,
mask_white_bg: bool,
) -> np.ndarray:
"""
conf_score: (N,) or None
cols: (N,3) uint8
Returns: (N,) boolean mask
"""
mask = np.ones(cols.shape[0], dtype=bool)
# confidence percentile threshold (same semantics as before)
if conf_score is not None and conf_thres > 0:
thresh = np.percentile(conf_score, conf_thres)
mask &= (conf_score >= thresh) & (conf_score > 1e-5)
# background masks (same as before)
if mask_black_bg:
mask &= (cols.sum(axis=1) >= 16)
if mask_white_bg:
mask &= ~((cols[:, 0] > 240) & (cols[:, 1] > 240) & (cols[:, 2] > 240))
return mask
def sample_frame_points(
    world_points: np.ndarray,
    images_nhwc: np.ndarray,
    conf: np.ndarray | None,
    idx: int,
    conf_thres: float,
    mask_black_bg: bool,
    mask_white_bg: bool,
    max_points: int,
):
    """Select, mask, and downsample the point cloud for a single frame.

    Returns (pts, colors_str) where colors_str is a list of "#rrggbb" strings,
    falling back to a single white origin point if everything is filtered out.
    """
    frame = int(np.clip(idx, 0, world_points.shape[0] - 1))
    # world_points is (T, N, 3) where N = V*H*W (all views already flattened).
    pts = world_points[frame]
    cols = (images_nhwc[frame].reshape(-1, 3) * 255).astype(np.uint8)
    conf_score = None if conf is None else conf[frame]
    keep = compute_point_mask(
        conf_score=conf_score,
        cols=cols,
        conf_thres=conf_thres,
        mask_black_bg=mask_black_bg,
        mask_white_bg=mask_white_bg,
    )
    pts, cols = stride_downsample(pts[keep], cols[keep], max_points)
    if pts.size == 0:
        # Plotly needs at least one point to render a trace.
        pts = np.array([[0.0, 0.0, 0.0]])
        cols = np.array([[255, 255, 255]], dtype=np.uint8)
    colors_str = [f"#{r:02x}{g:02x}{b:02x}" for r, g, b in cols]
    return pts, colors_str
# ============================================================
# UPDATED: prepare_tracks now applies the SAME masks as points
# ============================================================
def prepare_tracks(
    world_points_tracks: np.ndarray,
    world_points_conf_tracks: np.ndarray,
    images_raw: np.ndarray,
    conf_thres: float,
    mask_black_bg: bool,
    mask_white_bg: bool,
    num_views: int,
    static_threshold: float = 0.01,
):
    """
    Prepare tracks from per-image temporal trajectories, applying the SAME
    confidence/background masks as the point-cloud rendering.
    Args:
        world_points_tracks: (S, T, H, W, 3) - each of S images tracked over T timesteps
        world_points_conf_tracks: (S, T, H, W) - confidence for tracks
        images_raw: (S, C, H, W) - original input images
        conf_thres: confidence threshold percentile
        mask_black_bg: filter black background
        mask_white_bg: filter white background
        num_views: number of camera views (currently unused in this function)
        static_threshold: minimum peak displacement (world units) for a point
            to be considered dynamic
    Returns:
        tracks_xyz: (T, N_tracks, 3) - selected point trajectories
        colorscale: colormap for tracks
        track_ids: track identifiers
        All three are None when T < 2 or no tracks survive filtering.
    """
    S, T, H, W, _ = world_points_tracks.shape
    if T < 2:
        return None, None, None
    # Collect tracks from all images, but sample to avoid overwhelming visualization
    all_tracks = []
    all_colors = []
    # Process each input image's tracks
    for img_idx in range(S):
        tracks_img = world_points_tracks[img_idx]  # (T, H, W, 3)
        conf_img = world_points_conf_tracks[img_idx]  # (T, H, W)
        # Flatten spatial dimensions: (T, H*W, 3)
        tracks_flat = tracks_img.reshape(T, -1, 3)
        conf_flat = conf_img.reshape(T, -1)
        N = tracks_flat.shape[1]
        # Compute displacement relative to frame 0 to find dynamic points
        disp = np.linalg.norm(tracks_flat - tracks_flat[0:1], axis=-1)  # (T, N)
        dynamic_mask = disp.max(axis=0) > static_threshold  # (N,)
        # Average confidence across time
        conf_score = conf_flat.mean(axis=0)  # (N,)
        # Get colors from source image
        img_hwc = np.transpose(images_raw[img_idx], (1, 2, 0))  # (H, W, C)
        cols = (img_hwc.reshape(-1, 3) * 255).astype(np.uint8)  # (H*W, 3)
        # Apply the shared confidence/background masks (same as point clouds)
        point_mask = compute_point_mask(
            conf_score=conf_score,
            cols=cols,
            conf_thres=conf_thres,
            mask_black_bg=mask_black_bg,
            mask_white_bg=mask_white_bg,
        )
        valid_mask = dynamic_mask & point_mask
        idx_valid = np.nonzero(valid_mask)[0]
        if idx_valid.size == 0:
            continue
        # Sample tracks per image (distribute the MAX_TRACKS budget across images)
        max_per_image = MAX_TRACKS // S
        if idx_valid.size > max_per_image:
            step = int(np.ceil(idx_valid.size / max_per_image))
            idx_valid = idx_valid[::step][:max_per_image]
        # Extract selected tracks and colors
        selected_tracks = tracks_flat[:, idx_valid, :]  # (T, N_selected, 3)
        selected_colors = cols[idx_valid]  # (N_selected, 3)
        all_tracks.append(selected_tracks)
        all_colors.append(selected_colors)
    if not all_tracks:
        return None, None, None
    # Concatenate tracks from all images: (T, N_total, 3)
    tracks_xyz = np.concatenate(all_tracks, axis=1)
    # NOTE(review): colors_all is computed but never used below — the rendered
    # track colors come from the HSV colorscale instead.
    colors_all = np.concatenate(all_colors, axis=0)  # (N_total, 3)
    num_tracks = tracks_xyz.shape[1]
    # Sort tracks by Y position for consistent ordering
    order = np.argsort(tracks_xyz[0, :, 1])
    tracks_xyz = tracks_xyz[:, order, :]
    # Create HSV colormap for tracks (rainbow colors)
    cmap = matplotlib.colormaps.get_cmap("hsv")
    norm = matplotlib.colors.Normalize(vmin=0, vmax=max(num_tracks - 1, 1))
    colorscale = []
    for t in range(num_tracks):
        r, g, b, _ = cmap(norm(t))
        r, g, b = int(r * 255), int(g * 255), int(b * 255)
        pos = t / max(num_tracks - 1, 1)
        colorscale.append([pos, f"rgb({r},{g},{b})"])
    track_ids = np.arange(num_tracks, dtype=float)
    print(f" Prepared {num_tracks} tracks from {S} images ({len(all_tracks)} with valid tracks)")
    return tracks_xyz, colorscale, track_ids
def track_segments_for_frame(tracks_xyz: np.ndarray | None, track_ids: np.ndarray | None, f: int):
if tracks_xyz is None or track_ids is None or f <= 0:
return np.array([]), np.array([]), np.array([]), np.array([])
start_t = max(0, f - TRAIL_LENGTH)
num_tracks = tracks_xyz.shape[1]
xs, ys, zs, cs = [], [], [], []
for j in range(num_tracks):
seg = tracks_xyz[start_t : f + 1, j, :]
if seg.shape[0] < 2:
continue
xs.extend([seg[:, 0], np.array([np.nan])])
ys.extend([seg[:, 1], np.array([np.nan])])
zs.extend([seg[:, 2], np.array([np.nan])])
cs.append(np.full(seg.shape[0] + 1, track_ids[j], dtype=float))
x = np.concatenate(xs) if xs else np.array([])
y = np.concatenate(ys) if ys else np.array([])
z = np.concatenate(zs) if zs else np.array([])
c = np.concatenate(cs) if cs else np.array([])
return x, y, z, c
def build_pointcloud_figure_update(
    data,
    conf_thres: float,
    mask_black_bg: bool,
    mask_white_bg: bool,
    static_threshold: float = 0.01,
):
    """Build the interactive Plotly 3D figure: per-frame point clouds plus track
    trails, scrubbed via a slider (one step per timestep).

    Args:
        data: prediction dict with "world_points" (T, N, 3), "images", optional
            "world_points_conf", and optional per-image tracking arrays.
            None yields an empty figure.
        conf_thres: confidence percentile threshold for the shared point mask.
        mask_black_bg: drop near-black points.
        mask_white_bg: drop near-white points.
        static_threshold: minimum displacement for a track to count as dynamic.

    Returns:
        plotly.graph_objects.Figure with two traces ("points", "tracks") whose
        data each slider step swaps in place.
    """
    if data is None:
        return go.Figure()
    world_points = data["world_points"]  # (T, N, 3) where N = V*H*W
    conf = data.get("world_points_conf")  # (T, N)
    images = ensure_nhwc_images(data["images"])  # (T, C, H, W) -> (T, H, W, C)
    S = world_points.shape[0]  # Number of timesteps
    global_min, global_max, aspectratio = compute_scene_bounds(world_points)
    # UPDATED: use separate tracking data for per-image trajectories
    world_points_tracks = data.get("world_points_tracks")  # (S_images, T, H, W, 3)
    world_points_conf_tracks = data.get("world_points_conf_tracks")  # (S_images, T, H, W)
    images_raw = data.get("images_raw")  # (S_images, C, H, W)
    num_views = data.get("num_views", 1)
    tracks_xyz, colorscale, track_ids = None, None, None
    if world_points_tracks is not None and images_raw is not None:
        tracks_xyz, colorscale, track_ids = prepare_tracks(
            world_points_tracks=world_points_tracks,
            world_points_conf_tracks=world_points_conf_tracks,
            images_raw=images_raw,
            conf_thres=conf_thres,
            mask_black_bg=mask_black_bg,
            mask_white_bg=mask_white_bg,
            num_views=num_views,
            static_threshold=static_threshold,
        )
    track_cmax = max(len(track_ids) - 1, 1) if track_ids is not None else 1
    # Precompute per-frame point clouds and track segments for every slider step.
    pts_xyz = [None] * S
    pts_cols = [None] * S
    trk_xyz = [None] * S
    trk_c = [None] * S
    for i in range(S):
        pts_i, cols_i = sample_frame_points(
            world_points=world_points,
            images_nhwc=images,
            conf=conf,
            idx=i,
            conf_thres=conf_thres,
            mask_black_bg=mask_black_bg,
            mask_white_bg=mask_white_bg,
            max_points=MAX_POINTS_PER_FRAME,
        )
        pts_xyz[i] = pts_i
        pts_cols[i] = cols_i
        x, y, z, c = track_segments_for_frame(tracks_xyz, track_ids, f=i)
        trk_xyz[i] = (x, y, z)
        trk_c[i] = c
    # Initial (frame 0) data for both traces.
    p0 = pts_xyz[0]
    c0 = pts_cols[0]
    x0, y0, z0 = trk_xyz[0]
    tc0 = trk_c[0]
    # Hide all axes and fix ranges/aspect so the scene doesn't rescale per frame.
    scene_cfg = dict(
        xaxis=dict(
            visible=False,
            showbackground=False,
            showgrid=False,
            zeroline=False,
            showticklabels=False,
            range=[float(global_min[0]), float(global_max[0])],
        ),
        yaxis=dict(
            visible=False,
            showbackground=False,
            showgrid=False,
            zeroline=False,
            showticklabels=False,
            range=[float(global_min[1]), float(global_max[1])],
        ),
        zaxis=dict(
            visible=False,
            showbackground=False,
            showgrid=False,
            zeroline=False,
            showticklabels=False,
            range=[float(global_min[2]), float(global_max[2])],
        ),
        aspectmode="manual",
        aspectratio=aspectratio,
        dragmode="orbit",
        # Y-down camera orientation matches image coordinates.
        camera=dict(
            eye=dict(x=0.0, y=0.0, z=-1.0),
            center=dict(x=0.0, y=0.0, z=0.0),
            up=dict(x=0.0, y=-1.0, z=0.0),
        ),
    )
    fig = go.Figure(
        data=[
            go.Scatter3d(
                x=p0[:, 0],
                y=p0[:, 1],
                z=p0[:, 2],
                mode="markers",
                marker=dict(size=2, color=c0),
                showlegend=False,
                name="points",
            ),
            go.Scatter3d(
                x=x0,
                y=y0,
                z=z0,
                mode="lines",
                line=dict(
                    width=2,
                    color=tc0 if (tc0 is not None and tc0.size) else None,
                    colorscale=colorscale if colorscale is not None else None,
                    cmin=0,
                    cmax=track_cmax,
                ),
                hoverinfo="skip",
                showlegend=False,
                name="tracks",
            ),
        ]
    )
    # One slider step per frame; each step swaps both traces' data in place.
    steps = []
    for i in range(S):
        pi = pts_xyz[i]
        ci = pts_cols[i]
        xi, yi, zi = trk_xyz[i]
        ti = trk_c[i]
        steps.append(
            dict(
                method="update",
                label=str(i),
                args=[
                    {
                        "x": [pi[:, 0], xi],
                        "y": [pi[:, 1], yi],
                        "z": [pi[:, 2], zi],
                        "marker.color": [ci, None],
                        "line.color": [None, ti if (ti is not None and len(ti)) else None],
                    },
                    {},
                ],
            )
        )
    sliders = [
        dict(
            active=0,
            currentvalue={"prefix": "Frame: ", "visible": True, "font": {"size": 14}},
            pad={"t": 10},
            len=0.6,
            x=0.2,
            font={"size": 8},
            steps=steps,
        )
    ]
    fig.update_layout(
        margin=dict(l=0, r=0, t=30, b=0),
        scene=scene_cfg,
        sliders=sliders,
        showlegend=False,
        title="Scrub frames with the slider below",
        # uirevision keeps the user's camera while the slider updates data.
        uirevision="keep-camera",
        height=700,
    )
    return fig
def decode_poses(pose_enc: np.ndarray, image_hw: tuple) -> tuple:
    """Decode VGGT pose encodings to camera matrices.

    Returns (extrinsics (N, 4, 4), intrinsics (N, 3, 3)). Falls back to
    identity extrinsics and a heuristic pinhole intrinsic when vggt is
    unavailable.
    """
    try:
        from vggt.utils.pose_enc import pose_encoding_to_extri_intri
        enc = torch.from_numpy(pose_enc).float()
        extri, intri = pose_encoding_to_extri_intri(enc, image_hw)
        extri = extri[0].numpy()  # (N, 3, 4)
        intri = intri[0].numpy()  # (N, 3, 3)
        n_cams = extri.shape[0]
        # Append the homogeneous [0, 0, 0, 1] row to each 3x4 extrinsic.
        bottom_rows = np.tile(
            np.array([0, 0, 0, 1], dtype=np.float32).reshape(1, 1, 4),
            (n_cams, 1, 1),
        )
        return np.concatenate([extri, bottom_rows], axis=1), intri
    except ImportError:
        print("Warning: vggt not available. Using identity poses.")
        n_cams = pose_enc.shape[1]
        extrinsics = np.tile(np.eye(4, dtype=np.float32), (n_cams, 1, 1))
        H, W = image_hw
        fx = fy = max(H, W)
        cx, cy = W / 2, H / 2
        K = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]], dtype=np.float32)
        return extrinsics, np.tile(K, (n_cams, 1, 1))
def compute_depths(world_points: np.ndarray, extrinsics: np.ndarray, num_views: int) -> np.ndarray:
    """Compute per-view depth maps (camera-space Z) from world points.

    Args:
        world_points: (T, V, H, W, 3) world coordinates.
        extrinsics: (N, 4, 4) world-to-camera transforms, N = T * V.
        num_views: kept for API compatibility; V is taken from world_points.

    Returns:
        (T, V, H, W) float32 depth maps; frames without a camera remain zero.
    """
    T, V, H, W, _ = world_points.shape
    depths = np.zeros((T, V, H, W), dtype=np.float32)
    n_cams = len(extrinsics)
    for t in range(T):
        for v in range(V):
            cam_idx = t * V + v
            if cam_idx >= n_cams:
                continue
            R = extrinsics[cam_idx][:3, :3]
            t_vec = extrinsics[cam_idx][:3, 3]
            # World -> camera, then take the Z coordinate as depth.
            cam_pts = np.einsum('ij,hwj->hwi', R, world_points[t, v]) + t_vec
            depths[t, v] = cam_pts[..., 2]
    return depths
def run_model(target_dir: str, model: VDPM, frame_id_arg=0, use_temporal_tracking=True) -> dict:
    """Run VDPM pairwise-DPM inference on staged images and save all outputs.

    Loads every image under ``target_dir/images``, runs a single-batch
    inference, then writes tracks, poses, depths, normals, scene flow and
    angular flow into ``target_dir`` plus a zip archive for download.

    Args:
        target_dir: Directory produced by handle_uploads(); must contain an
            ``images/`` subfolder and optionally ``meta.json`` with num_views.
        model: Loaded VDPM model.
        frame_id_arg: Reference frame index (coerced to int; reset to 0 when
            invalid or out of range).
        use_temporal_tracking: If True, visualize trajectories of the t=0
            reference pixels; otherwise show each frame's own reconstruction.

    Returns:
        dict with subsampled visualization points/confidences, per-image
        tracks, coloring images, the zip download path and view metadata.

    Raises:
        ValueError: If no images are found in the upload directory.
    """
    require_cuda()
    image_names = sorted(glob.glob(os.path.join(target_dir, "images", "*")))
    if not image_names:
        raise ValueError("No images found. Check your upload.")
    # Load metadata for Multi-View sync (Load BEFORE slicing to respect view count)
    meta_path = os.path.join(target_dir, "meta.json")
    num_views = 1
    if os.path.exists(meta_path):
        try:
            with open(meta_path, 'r') as f:
                num_views = json.load(f).get("num_views", 1)
        except Exception:
            # Best-effort: fall back to a single view if meta.json is unreadable.
            pass
    # Limit frames to prevent OOM on 8GB GPUs
    if len(image_names) > MAX_FRAMES:
        # Round down to nearest multiple of num_views to preserve full scenes
        limit = (MAX_FRAMES // num_views) * num_views
        if limit == 0:
            limit = num_views  # At least one full timestep
            print(f"⚠️ Warning: MAX_FRAMES={MAX_FRAMES} is smaller than num_views={num_views}. Processing 1 full timestep anyway (may OOM).")
        print(f"⚠️ Limiting to {limit} frames ({limit // num_views} timesteps * {num_views} views) to fit in GPU memory")
        image_names = image_names[:limit]
    images = load_and_preprocess_images(image_names).to(device)
    if device == "cuda":
        print(f"GPU memory before inference: {torch.cuda.memory_allocated() / 1024**3:.2f} GB / {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB")
    S = len(image_names)  # Number of input images
    print(f"Running inference on {S} images...")
    inference_start = time.time()
    with torch.no_grad():
        with torch.amp.autocast('cuda'):
            # =================================================================
            # CORRECT API: Single-batch inference for temporal tracking
            # =================================================================
            # Pass images as a batch: (S, C, H, W) -> unsqueeze to (1, S, C, H, W)
            # This triggers VDPM's pairwise DPM mode where:
            #   - Output: (T, S, H, W, 3) with T = S
            #   - world_points[t, s] = P_s(t, π₀) = "image s's pixels at time t"
            #   - Pixel correspondences are maintained across time
            # =================================================================
            predictions = model.inference(None, images=images.unsqueeze(0))
    inference_time = time.time() - inference_start
    print(f"✓ Inference completed in {inference_time:.2f}s ({inference_time/S:.2f}s per frame)")
    # Move results to CPU immediately to free GPU memory
    # predictions["pointmaps"] is a list of dicts, one per timestep query
    # Each dict has pts3d shape (1, S, H, W, 3) where S = num input images
    pts_list = [pm["pts3d"].detach().cpu().numpy() for pm in predictions["pointmaps"]]
    conf_list = [pm["conf"].detach().cpu().numpy() for pm in predictions["pointmaps"]]
    # Extract pose encoding before clearing predictions
    pose_enc = None
    if "pose_enc" in predictions:
        pose_enc = predictions["pose_enc"].detach().cpu().numpy()
    # Clear predictions from GPU to save memory
    del predictions
    if device == "cuda":
        torch.cuda.empty_cache()
        print(f"GPU memory after inference: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
    # Stack timestep results: list of (1, S, H, W, 3) -> (T, S, H, W, 3)
    world_points_raw = np.concatenate(pts_list, axis=0)  # (T, S, H, W, 3)
    world_points_conf_raw = np.concatenate(conf_list, axis=0)  # (T, S, H, W)
    T = world_points_raw.shape[0]  # Number of timesteps
    S = world_points_raw.shape[1]  # Number of input images
    print(f"DEBUG: Raw model output shape: (T={T}, S={S}, H, W, 3)")
    print(f"DEBUG: num_views={num_views}")
    # ==========================================================================
    # VDPM OUTPUT INTERPRETATION (Pairwise DPM)
    # ==========================================================================
    #
    # For S input images, VDPM outputs (T, S, H, W, 3) where:
    #   - T = S (one timestep query per input image)
    #   - world_points[t, s, h, w] = 3D position of pixel (h,w) from image s,
    #                                as it would appear at timestep t from
    #                                the reference viewpoint π₀
    #
    # KEY INSIGHT from the paper:
    #   - P_s(t, π₀) = world_points[t, s] = "Image s's content at time t"
    #   - Scene flow for pixel (h,w) in image s:
    #       flow = P_s(t, π₀)[h,w] - P_s(t_s, π₀)[h,w]
    #            = world_points[t, s, h, w] - world_points[s, s, h, w]
    #
    # For VISUALIZATION and TRACKING:
    #   - Use a reference frame (frame_id) to track all pixels
    #   - world_points[:, frame_id] gives temporal trajectory of frame_id's pixels
    #
    # ==========================================================================
    try:
        frame_id = int(frame_id_arg)
    except Exception:
        frame_id = 0
    # Clamp out-of-range (including negative) reference indices back to 0.
    if frame_id < 0 or frame_id >= S:
        frame_id = 0
    # For multi-view: S images = num_views * num_timesteps
    # Interleaved as: [v0_t0, v1_t0, ..., v0_t1, v1_t1, ...]
    num_timesteps = S // num_views
    print(f"DEBUG: Detected {num_views} views × {num_timesteps} timesteps")
    # For 4DGS export, extract diagonal entries (each image at its natural time)
    # For multi-view: at timestep t, extract all V views
    # Diagonal indices: t*num_views+0, t*num_views+1, ..., t*num_views+(V-1)
    world_points_multiview = []
    world_points_conf_multiview = []
    for t in range(num_timesteps):
        views_at_t = []
        conf_at_t = []
        for v in range(num_views):
            img_idx = t * num_views + v  # Index in interleaved sequence
            views_at_t.append(world_points_raw[img_idx, img_idx])  # Diagonal element
            conf_at_t.append(world_points_conf_raw[img_idx, img_idx])
        # Stack views: (V, H, W, 3)
        world_points_multiview.append(np.stack(views_at_t, axis=0))
        world_points_conf_multiview.append(np.stack(conf_at_t, axis=0))
    # Stack timesteps: (T, V, H, W, 3)
    world_points_full = np.stack(world_points_multiview, axis=0)
    world_points_conf_full = np.stack(world_points_conf_multiview, axis=0)
    print(f"DEBUG: Multi-view extraction shape: {world_points_full.shape}")
    # For visualization: combine all views at each timestep into a single point cloud
    # Flatten views dimension: (T, V*H*W, 3)
    world_points_tracked = world_points_full.reshape(num_timesteps, -1, 3)
    world_points_conf_tracked = world_points_conf_full.reshape(num_timesteps, -1)
    print(f"DEBUG: Visualization (all views combined) shape: {world_points_tracked.shape}")
    # Save tracks for 4DGS training
    tracks_path = os.path.join(target_dir, "tracks.npz")
    print(f"Saving tracks to {tracks_path}")
    np.savez_compressed(
        tracks_path,
        # For per-frame 3DGS (diagonal extraction)
        world_points=world_points_full,  # (T, V, H, W, 3)
        world_points_conf=world_points_conf_full,  # (T, V, H, W)
        # For scene flow / deformation (full pairwise DPM output)
        world_points_raw=world_points_raw,  # (T, S, H, W, 3)
        world_points_conf_raw=world_points_conf_raw,  # (T, S, H, W)
        num_views=num_views,
        num_timesteps=num_timesteps,
        reference_frame=frame_id
    )
    # Save pose encodings for GS training
    if pose_enc is not None:
        poses_path = os.path.join(target_dir, "poses.npz")
        print(f"Saving poses to {poses_path}")
        np.savez_compressed(poses_path, pose_enc=pose_enc)
    # Compute and save depths
    depths = None
    if pose_enc is not None:
        print("Computing depth maps from world points and camera poses...")
        # Get image dimensions from world_points_full
        H_orig, W_orig = world_points_full.shape[2:4]
        extrinsics, intrinsics = decode_poses(pose_enc, (H_orig, W_orig))
        depths = compute_depths(world_points_full, extrinsics, num_views)
        # Save depths as npz
        depths_path = os.path.join(target_dir, "depths.npz")
        print(f"Saving depths to {depths_path}")
        np.savez_compressed(
            depths_path,
            depths=depths,
            num_views=num_views,
            num_timesteps=num_timesteps
        )
        # Save individual depth images as PNGs
        depths_dir = os.path.join(target_dir, "depths")
        os.makedirs(depths_dir, exist_ok=True)
        print(f"Saving depth images to {depths_dir}/")
        T_depth = depths.shape[0]
        V_depth = depths.shape[1]
        for t in range(T_depth):
            for v in range(V_depth):
                depth_map = depths[t, v]
                png_path = os.path.join(depths_dir, f"depth_t{t:04d}_v{v:02d}.png")
                write_depth_to_png(png_path, depth_map)
        print(f"✓ Saved {T_depth * V_depth} depth images")
    else:
        print("⚠ No pose encodings available - skipping depth computation")
        extrinsics, intrinsics = None, None
    # ============================================================================
    # COMPUTE AND SAVE NORMALS (depth_normals.npz — matches trainer format)
    # ============================================================================
    print("Computing surface normals from point maps...")
    normals, tangent_x, tangent_y = compute_normals_from_pointmap(world_points_full)
    # Compute smoothed normals
    print("Computing smooth normals...")
    smooth_normals = compute_smooth_normals(normals, kernel_size=7)
    # Save as depth_normals.npz (the name the trainer/viewer expect)
    depth_normals_path = os.path.join(target_dir, "depth_normals.npz")
    print(f"Saving normals to {depth_normals_path}")
    np.savez_compressed(
        depth_normals_path,
        depth_normals=normals.astype(np.float16),  # key must be 'depth_normals'
        smooth_normals=smooth_normals.astype(np.float16),
        tangent_x=tangent_x.astype(np.float16),
        tangent_y=tangent_y.astype(np.float16),
    )
    # Save individual normal images as PNGs
    normals_dir = os.path.join(target_dir, "normals")
    os.makedirs(normals_dir, exist_ok=True)
    print(f"Saving normal images to {normals_dir}/")
    T_norm, V_norm = normals.shape[:2]
    for t in range(T_norm):
        for v in range(V_norm):
            # Raw normals
            normal_map = normals[t, v]
            png_path = os.path.join(normals_dir, f"normal_t{t:04d}_v{v:02d}.png")
            write_normal_to_png(png_path, normal_map)
            # Smooth normals
            smooth_map = smooth_normals[t, v]
            smooth_path = os.path.join(normals_dir, f"smooth_normal_t{t:04d}_v{v:02d}.png")
            write_normal_to_png(smooth_path, smooth_map)
    print(f"✓ Saved {T_norm * V_norm * 2} normal images (raw + smooth)")
    # ============================================================================
    # COMPUTE AND SAVE SCENE FLOW + ANGULAR FLOW (matches trainer format)
    # ============================================================================
    # The trainer expects:
    #   scene_flow.npz   — keys "tXXXX_vYY" → (H, W, 3) float32
    #   angular_flow.npz — keys "tXXXX_vYY" → (H, W, 9) float32
    # Computed from world_points_raw which has the full pairwise DPM output:
    #   world_points_raw[t, s, h, w] = P_s(t, π₀)
    #   tracks[frame_idx, ref_idx] style indexing
    # ============================================================================
    print("Computing scene flow and angular flow from pairwise DPM output...")
    sf_npz_dict = {}
    af_npz_dict = {}
    sf_dir = os.path.join(target_dir, "scene_flow")
    af_dir = os.path.join(target_dir, "angular_flow")
    os.makedirs(sf_dir, exist_ok=True)
    os.makedirs(af_dir, exist_ok=True)
    # world_points_raw: (T_query, S_source, H, W, 3)
    # For multi-view: S = num_views * num_timesteps, interleaved as
    #   [v0_t0, v1_t0, ..., v0_t1, v1_t1, ...]
    # frame_idx = t * num_views + v (same as the trainer's convention)
    H_raw, W_raw = world_points_raw.shape[2:4]
    sf_count = 0
    for t in range(num_timesteps - 1):
        for v in range(num_views):
            frame_idx = t * num_views + v
            next_frame_idx = (t + 1) * num_views + v
            # P(t): points at current time (frame queries itself)
            P_t = world_points_raw[frame_idx, frame_idx].astype(np.float32)  # (H, W, 3)
            # P(t+1): where frame t's points are at time t+1.
            # FIX: index as [query_time, source_image] — the convention used
            # everywhere else in this function (diagonal extraction, temporal
            # tracking). The previous [frame_idx, next_frame_idx] sampled a
            # DIFFERENT source image, so P_t1 - P_t subtracted non-corresponding
            # pixels and the resulting flow was meaningless.
            P_t1 = world_points_raw[next_frame_idx, frame_idx].astype(np.float32)  # (H, W, 3)
            scene_flow = np.nan_to_num(P_t1 - P_t, nan=0.0, posinf=0.0, neginf=0.0)
            key = f"t{t:04d}_v{v:02d}"
            sf_npz_dict[key] = scene_flow.astype(np.float32)
            # Angular flow: normal difference + tangent frame difference
            n_t, tx_t, ty_t = compute_normals_from_pointmap(P_t[np.newaxis])
            n_t1, tx_t1, ty_t1 = compute_normals_from_pointmap(P_t1[np.newaxis])
            delta_n = np.nan_to_num(n_t1[0] - n_t[0], nan=0.0)
            delta_tx = np.nan_to_num(tx_t1[0] - tx_t[0], nan=0.0)
            delta_ty = np.nan_to_num(ty_t1[0] - ty_t[0], nan=0.0)
            angular_flow = np.concatenate([delta_n, delta_tx, delta_ty], axis=-1)  # (H, W, 9)
            af_npz_dict[key] = angular_flow.astype(np.float32)
            # Save debug images (best-effort; skipped if PIL/encoding fails)
            try:
                sf_img = _flow_to_image_3d(scene_flow)
                PIL.Image.fromarray(sf_img).save(os.path.join(sf_dir, f"{key}.png"))
                af_img = _flow_to_image_3d(delta_n)
                PIL.Image.fromarray(af_img).save(os.path.join(af_dir, f"{key}.png"))
            except Exception:
                pass
            sf_count += 1
            if sf_count <= 4:
                mag = np.linalg.norm(scene_flow, axis=-1)
                print(f" [{key}] scene flow: mean={mag.mean():.6f}, max={mag.max():.6f}")
    sf_npz_path = os.path.join(target_dir, "scene_flow.npz")
    af_npz_path = os.path.join(target_dir, "angular_flow.npz")
    np.savez_compressed(sf_npz_path, **sf_npz_dict)
    np.savez_compressed(af_npz_path, **af_npz_dict)
    print(f"✓ Saved {sf_count} scene flow entries to {sf_npz_path}")
    print(f"✓ Saved {sf_count} angular flow entries (9ch) to {af_npz_path}")
    # (Moved saving logic to the end of function to capture all viz variables)
    num_viz_frames = num_timesteps
    print(f"DEBUG: Final viz world_points.shape={world_points_tracked.shape}, num_viz_frames={num_viz_frames}")
    torch.cuda.empty_cache()
    # ============================================================================
    # VISUALIZATION MODE SELECTION
    # ============================================================================
    if use_temporal_tracking:
        # TEMPORAL TRAJECTORY MODE (DEFAULT):
        # Track reference frames (t=0, all views) across time
        # Shows "Where do the initial pixels move to over time?"
        print("Using TEMPORAL TRACKING mode: tracking reference frame pixels over time")
        # Get indices of all views at t=0
        ref_indices = list(range(num_views))  # [0, 1, ..., V-1]
        # Extract temporal trajectories for all reference views
        # For multi-view: extract timestep indices (every num_views-th index)
        # world_points_raw[t*num_views, ref_idx] = "reference view at timestep t"
        ref_trajectories = []
        ref_conf_trajectories = []
        for ref_idx in ref_indices:
            # Extract trajectory for this reference view across actual timesteps
            traj_list = []
            conf_list = []
            for t in range(num_timesteps):
                # For multi-view, each timestep's data is at index t*num_views
                timestep_idx = t * num_views
                traj_list.append(world_points_raw[timestep_idx, ref_idx, ::2, ::2, :])  # (H/2, W/2, 3)
                conf_list.append(world_points_conf_raw[timestep_idx, ref_idx, ::2, ::2])  # (H/2, W/2)
            # Stack across timesteps: (T_actual, H/2, W/2, 3)
            traj = np.stack(traj_list, axis=0)
            conf_traj = np.stack(conf_list, axis=0)
            ref_trajectories.append(traj)
            ref_conf_trajectories.append(conf_traj)
        # Stack views: (V, T, H/2, W/2, 3) -> (T, V, H/2, W/2, 3)
        ref_trajectories_stacked = np.stack(ref_trajectories, axis=0)  # (V, T, H/2, W/2, 3)
        ref_trajectories_stacked = np.transpose(ref_trajectories_stacked, (1, 0, 2, 3, 4))  # (T, V, H/2, W/2, 3)
        ref_conf_stacked = np.stack(ref_conf_trajectories, axis=0)  # (V, T, H/2, W/2)
        ref_conf_stacked = np.transpose(ref_conf_stacked, (1, 0, 2, 3))  # (T, V, H/2, W/2)
        # Flatten views for visualization: (T, V*H/2*W/2, 3)
        T_viz, V_viz, H_viz, W_viz, _ = ref_trajectories_stacked.shape
        world_points_s = ref_trajectories_stacked.reshape(T_viz, V_viz * H_viz * W_viz, 3)
        world_points_conf_s = ref_conf_stacked.reshape(T_viz, V_viz * H_viz * W_viz)
    else:
        # PER-FRAME RECONSTRUCTION MODE:
        # Show each frame's natural point map at its own time (diagonal extraction)
        # Shows "What does the scene look like at time t?"
        print("Using PER-FRAME RECONSTRUCTION mode: diagonal extraction")
        # Subsample and combine all views at each timestep
        world_points_viz = world_points_full[:, :, ::2, ::2, :]  # (T, V, H/2, W/2, 3)
        world_points_conf_viz = world_points_conf_full[:, :, ::2, ::2]  # (T, V, H/2, W/2)
        # Flatten views for visualization: (T, V*H/2*W/2, 3)
        T_viz, V_viz, H_viz, W_viz, _ = world_points_viz.shape
        world_points_s = world_points_viz.reshape(T_viz, V_viz * H_viz * W_viz, 3)
        world_points_conf_s = world_points_conf_viz.reshape(T_viz, V_viz * H_viz * W_viz)
    print(f"DEBUG: Subsampled visualization shape: {world_points_s.shape}")
    # For TRACKING: extract temporal trajectories for each input image
    # world_points_raw[t, s] = P_s(t, π_0) = "image s's pixels at time t"
    # For each image s, extract its trajectory across all timesteps T
    world_points_tracking = []
    world_points_conf_tracking = []
    for img_idx in range(S):
        # Extract this image's pixels across all timesteps
        trajectory = world_points_raw[:, img_idx, ::2, ::2, :]  # (T, H/2, W/2, 3)
        conf_traj = world_points_conf_raw[:, img_idx, ::2, ::2]  # (T, H/2, W/2)
        world_points_tracking.append(trajectory)
        world_points_conf_tracking.append(conf_traj)
    # Stack: (S, T, H/2, W/2, 3) - S images, each tracked over T timesteps
    world_points_tracks = np.stack(world_points_tracking, axis=0)
    world_points_conf_tracks = np.stack(world_points_conf_tracking, axis=0)
    print(f"DEBUG: Tracking shape: {world_points_tracks.shape} (S={S} images × T={world_points_tracks.shape[1]} timesteps)")
    # Get images for coloring - match visualization mode
    # img_np: (S, C, H, W) where S = V*T interleaved
    img_np = images.detach().cpu().numpy()
    if use_temporal_tracking:
        # Use reference frame images (t=0) repeated across time
        ref_imgs = []
        for v in range(num_views):
            ref_imgs.append(img_np[v:v+1, :, ::2, ::2])  # (1, C, H/2, W/2)
        # Concatenate views: (1, C, V*H/2, W/2)
        combined_ref = np.concatenate(ref_imgs, axis=2)
        # Repeat across timesteps: (T, C, V*H/2, W/2)
        img_np_viz = np.repeat(combined_ref, num_timesteps, axis=0)
    else:
        # Deinterleave and combine views for each timestep
        img_viz_list = []
        for t in range(num_timesteps):
            views_imgs = []
            for v in range(num_views):
                img_idx = t * num_views + v
                views_imgs.append(img_np[img_idx:img_idx+1, :, ::2, ::2])  # (1, C, H/2, W/2)
            # Concatenate views along height dimension for combined visualization
            # Result: (1, C, V*H/2, W/2)
            combined = np.concatenate(views_imgs, axis=2)
            img_viz_list.append(combined)
        img_np_viz = np.concatenate(img_viz_list, axis=0)  # (T, C, V*H/2, W/2)
    # ================================================================
    # BUILD world_points_tracks in (TV, TV, H, W, 3) pairwise format
    # ================================================================
    # The trainer's _ensure_flow_targets expects:
    #   tracks[frame_idx, ref_idx] where frame_idx = t * V + v
    # world_points_raw is (T_query, S_source, H, W, 3) — this IS the
    # pairwise DPM output with T_query = S_source = T * V.
    # We just need to save it at full resolution (no subsampling).
    TV = num_timesteps * num_views
    assert world_points_raw.shape[0] == TV and world_points_raw.shape[1] == TV, \
        f"Expected ({TV}, {TV}, H, W, 3) but got {world_points_raw.shape}"
    world_points_tracks_full = world_points_raw  # (TV, TV, H, W, 3)
    # Save Results for Download (Final Format)
    output_path = os.path.join(target_dir, "output_4d.npz")
    save_dict = {
        "world_points": world_points_s,
        "world_points_conf": world_points_conf_s,
        "world_points_tracks": world_points_tracks_full,  # (TV, TV, H, W, 3) pairwise
        "world_points_conf_tracks": world_points_conf_tracks,  # (S, T, H, W) for full_sample
        "images": img_np_viz,
        "images_raw": img_np[:, :, ::2, ::2],  # Original images subsampled
        "num_views": num_views,
        "num_timesteps": num_timesteps,
        "use_temporal_tracking": use_temporal_tracking,
    }
    if depths is not None:
        save_dict["depths"] = depths
    np.savez_compressed(output_path, **save_dict)
    # Create zip file containing all outputs for easy download
    print("Creating downloadable zip archive of all outputs...")
    zip_path = create_output_zip(target_dir)
    return {
        "world_points": world_points_s,
        "world_points_conf": world_points_conf_s,
        "world_points_tracks": world_points_tracks,  # (S, T, H, W, 3) for tracking
        "world_points_conf_tracks": world_points_conf_tracks,  # (S, T, H, W)
        "images": img_np_viz,
        "images_raw": img_np[:, :, ::2, ::2],  # Original images subsampled
        "download_path": zip_path,  # Now returns the zip file for download
        "output_dir": target_dir,  # Full output directory path
        "num_views": num_views,
        "num_timesteps": num_timesteps,
        "use_temporal_tracking": use_temporal_tracking,
    }
def handle_uploads(input_video, input_images):
    """Stage uploaded media into a fresh timestamped reconstruction directory.

    Copies uploaded images verbatim and samples frames from uploaded video(s)
    at VIDEO_SAMPLE_HZ. Multiple videos are read in lockstep (interleaved) so
    frames of synchronized multi-view captures end up adjacent on disk as
    [v0_t0, v1_t0, ..., v0_t1, ...]. The view count is recorded in meta.json
    for run_model() to read back.

    Args:
        input_video: Gradio file object or list of them (one per view), or None.
        input_images: List of Gradio image file objects, or None.

    Returns:
        (target_dir, image_paths): the new output directory and the sorted
        list of staged image paths inside ``target_dir/images``.
    """
    start_time = time.time()
    # Free CPU/GPU memory left over from a previous run before staging files.
    gc.collect()
    torch.cuda.empty_cache()
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
    output_base = os.path.join(os.path.dirname(__file__), "..", "output")
    os.makedirs(output_base, exist_ok=True)
    target_dir = os.path.join(output_base, f"reconstruction_{timestamp}")
    target_dir_images = os.path.join(target_dir, "images")
    # Microsecond timestamp makes collisions unlikely, but wipe any leftover.
    if os.path.exists(target_dir):
        shutil.rmtree(target_dir)
    os.makedirs(target_dir_images, exist_ok=True)
    image_paths = []
    if input_images:
        for file_obj in input_images:
            src = gradio_file_path(file_obj)
            if not src:
                continue
            dst = os.path.join(target_dir_images, os.path.basename(src))
            shutil.copy(src, dst)
            image_paths.append(dst)
    if input_video:
        # Check if input is a list (Gradio 4.x/5.x or file_count="multiple")
        input_video_list = input_video if isinstance(input_video, list) else [input_video]
        # Determine starting frame number based on existing images
        existing_files = os.listdir(target_dir_images)
        frame_num = len(existing_files)
        # Modified for Interleaved/Synchronized processing
        # 1. Open all videos
        captures = []
        capture_meta = []
        for idx, vid_obj in enumerate(input_video_list):
            video_path = gradio_file_path(vid_obj)
            print(f"Preparing video {idx+1}/{len(input_video_list)}: {video_path}")
            vs = cv2.VideoCapture(video_path)
            # CAP_PROP_FPS may be 0/None for some containers; guard against it.
            fps = float(vs.get(cv2.CAP_PROP_FPS) or 0.0)
            if fps <= 0: fps = 30.0  # Fallback
            # Keep roughly VIDEO_SAMPLE_HZ frames per second of video.
            frame_interval = max(int(fps / max(VIDEO_SAMPLE_HZ, 1e-6)), 1)
            captures.append(vs)
            capture_meta.append({"interval": frame_interval, "name": video_path})
        # 2. Step through them together: one frame from each video per step so
        # saved frame numbers interleave views at matching timestamps.
        print("Processing videos in interleaved mode...")
        step_count = 0
        active_videos = True
        while active_videos:
            active_videos = False
            for i, vs in enumerate(captures):
                if not vs.isOpened():
                    continue
                gotit, frame = vs.read()
                if gotit:
                    active_videos = True  # Keep going as long as at least one video has frames
                    # Each capture keeps its own sampling interval (fps may differ).
                    if step_count % capture_meta[i]["interval"] == 0:
                        out_path = os.path.join(target_dir_images, f"{frame_num:06}.png")
                        cv2.imwrite(out_path, frame)
                        image_paths.append(out_path)
                        frame_num += 1
                else:
                    vs.release()
            step_count += 1
    image_paths.sort()
    # Save metadata about capture structure (num_views)
    num_views = len(input_video_list) if input_video else 1
    meta_path = os.path.join(target_dir, "meta.json")
    try:
        import json
        with open(meta_path, 'w') as f:
            json.dump({"num_views": num_views}, f)
    except Exception as e:
        # Best-effort: downstream falls back to num_views=1 without meta.json.
        print(f"Warning: could not save metadata: {e}")
    print(f"Files copied to {target_dir_images}; took {time.time() - start_time:.3f} seconds")
    return target_dir, image_paths
def update_gallery_on_upload(input_video, input_images):
    """Stage freshly uploaded media and refresh the preview gallery."""
    if not (input_video or input_images):
        # Nothing uploaded yet: clear every output.
        return None, None, None, None
    staged_dir, staged_paths = handle_uploads(input_video, input_images)
    return None, staged_dir, staged_paths, "Upload complete. Click 'Reconstruct' to begin 3D processing."
@spaces.GPU(duration=240)
def gradio_reconstruct(
    target_dir,
    conf_thres=50.0,
    mask_black_bg=False,
    mask_white_bg=False,
    static_threshold=0.01,
    use_temporal_tracking=True,
    frame_id_val=0,
):
    """Run the reconstruction pipeline and build the 3D point-cloud figure.

    Returns a (figure, status message, predictions) triple for Gradio.
    """
    # Guard: nothing staged yet (textbox default is the literal string "None").
    if target_dir == "None" or not os.path.isdir(target_dir):
        return None, "No valid target directory found. Please upload first.", None
    gc.collect()
    torch.cuda.empty_cache()
    images_dir = os.path.join(target_dir, "images")
    frame_count = len(os.listdir(images_dir)) if os.path.isdir(images_dir) else 0
    with torch.no_grad():
        preds = run_model(target_dir, model, frame_id_val, use_temporal_tracking)
    figure = build_pointcloud_figure_update(preds, conf_thres, mask_black_bg, mask_white_bg, static_threshold)
    torch.cuda.empty_cache()
    status = f"Reconstruction Success ({frame_count} frames processed, showing frame {frame_id_val})."
    return figure, status, preds
def update_download_button(predictions):
    """Reveal the download button only when a finished zip archive exists."""
    if predictions is not None:
        zip_path = predictions.get("download_path")
        if zip_path and os.path.exists(zip_path):
            return gr.update(value=zip_path, visible=True)
    return gr.update(visible=False)
def update_plot(
    target_dir,
    predictions,
    conf_thres,
    mask_black_bg,
    mask_white_bg,
    static_threshold,
    is_example,
):
    """Re-render the point cloud with updated display settings (no re-inference)."""
    # Examples and missing predictions cannot be re-rendered client-side.
    nothing_to_show = predictions is None or is_example == "True"
    if nothing_to_show:
        return None, "No reconstruction available. Please click the Reconstruct button first.", gr.update(visible=False)
    figure = build_pointcloud_figure_update(predictions, conf_thres, mask_black_bg, mask_white_bg, static_threshold)
    return figure, "Updated visualization with new settings. Use the slider below the plot to scrub frames.", gr.update(visible=False)
def clear_plot():
    """Blank out the 3D plot component."""
    return None
def clear_fields():
    """Reset every reconstruction-related component to its empty value."""
    blank_outputs = (None, None, [], 0, 0, "", None)
    return blank_outputs
def update_log():
    """Status text shown while a reconstruction is in flight."""
    return "Loading and Reconstructing..."
def example_pipeline(
    input_video_ex,
    num_images_str,
    input_images_ex,
    conf_thres_val,
    mask_black_bg_val,
    mask_white_bg_val,
    static_threshold_val,
    use_temporal_tracking_val,
    is_example_str,
    frame_id_val,
):
    """End-to-end pipeline for gr.Examples: stage inputs, then reconstruct.

    Returns (figure, log message, target dir, predictions, gallery paths).
    """
    staged_dir, staged_images = handle_uploads(input_video_ex, input_images_ex)
    figure, message, preds = gradio_reconstruct(
        staged_dir,
        conf_thres_val,
        mask_black_bg_val,
        mask_white_bg_val,
        static_threshold_val,
        use_temporal_tracking_val,
        frame_id_val,
    )
    return figure, message, staged_dir, preds, staged_images
# Example videos bundled with the demo (paths relative to the repo root).
_EXAMPLE_VIDEO_DIR = "examples/videos"
colosseum_video = f"{_EXAMPLE_VIDEO_DIR}/Colosseum.mp4"
camel_video = f"{_EXAMPLE_VIDEO_DIR}/camel.mp4"
tennis_video = f"{_EXAMPLE_VIDEO_DIR}/tennis.mp4"
paragliding_video = f"{_EXAMPLE_VIDEO_DIR}/paragliding.mp4"
stroller_video = f"{_EXAMPLE_VIDEO_DIR}/stroller.mp4"
goldfish_video = f"{_EXAMPLE_VIDEO_DIR}/goldfish.mp4"
horse_video = f"{_EXAMPLE_VIDEO_DIR}/horse.mp4"
swing_video = f"{_EXAMPLE_VIDEO_DIR}/swing.mp4"
car_video = f"{_EXAMPLE_VIDEO_DIR}/car.mp4"
figure1_video = f"{_EXAMPLE_VIDEO_DIR}/figure1.mp4"
figure2_video = f"{_EXAMPLE_VIDEO_DIR}/figure2.mp4"
figure3_video = f"{_EXAMPLE_VIDEO_DIR}/figure3.mp4"
tesla_video = f"{_EXAMPLE_VIDEO_DIR}/tesla.mp4"
pstudio_video = f"{_EXAMPLE_VIDEO_DIR}/pstudio.mp4"
# Muted slate/zinc palette on a plain white page background.
_base_theme = gr.themes.Default(
    primary_hue=gr.themes.colors.slate,
    secondary_hue=gr.themes.colors.zinc,
    neutral_hue=gr.themes.colors.slate,
)
# Theme.set() returns the theme object, so `theme` is the configured instance.
theme = _base_theme.set(
    checkbox_label_background_fill_selected="*button_primary_background_fill",
    checkbox_label_text_color_selected="*button_primary_text_color",
    body_background_fill="#FFFFFF",
)
css = """
.custom-log * {
font-style: italic;
font-size: 22px !important;
background-image: linear-gradient(120deg, #1f2937 0%, #4b5563 100%);
-webkit-background-clip: text;
background-clip: text;
font-weight: bold !important;
color: transparent !important;
text-align: center !important;
}
.example-log * {
font-style: italic;
font-size: 16px !important;
background-image: linear-gradient(120deg, #1f2937 0%, #4b5563 100%);
-webkit-background-clip: text;
background-clip: text;
color: transparent !important;
}
#my_radio .wrap {
display: flex;
flex-wrap: nowrap;
justify-content: center;
align-items: center;
}
#my_radio .wrap label {
display: flex;
width: 50%;
justify-content: center;
align-items: center;
margin: 0;
padding: 10px 0;
box-sizing: border-box;
}
"""
# Build the Hydra config from CLI overrides and load the VDPM model once at
# import time; Gradio callbacks share it as a module-level global.
cfg = load_cfg_from_cli()
model = load_model(cfg)
with gr.Blocks(theme=theme, css=css) as demo:
is_example = gr.Textbox(label="is_example", visible=False, value="None")
num_images = gr.Textbox(label="num_images", visible=False, value="None")
frame_id_state = gr.Textbox(label="frame_id", visible=False, value="0")
gr.HTML(
"""
<h1>V-DPM: Video Reconstruction with Dynamic Point Maps</h1>
<p>
<a href="https://github.com/eldar/vdpm">🐙 GitHub Repository</a> |
<a href="https://www.robots.ox.ac.uk/~vgg/research/vdpm/">Project Page</a>
</p>
<div style="font-size: 16px; line-height: 1.5;">
<p>Upload a video (or multiple videos for multi-view setup) or a set of images to create a dynamic point map reconstruction of a scene or object.</p>
</div>
"""
)
target_dir_output = gr.Textbox(label="Target Dir", visible=False, value="None")
predictions_state = gr.State(value=None)
with gr.Row():
with gr.Column(scale=2):
# Change Video input to File input to allow multiple videos for multi-view support
gr.Markdown("### Input")
input_video = gr.File(
label="Upload Video(s)",
file_count="multiple",
file_types=[".mp4", ".avi", ".mov", ".mkv", ".webm"],
interactive=True
)
input_images = gr.File(file_count="multiple", label="Upload Images", interactive=True)
image_gallery = gr.Gallery(
label="Preview",
columns=4,
height="300px",
show_download_button=True,
object_fit="contain",
preview=True,
)
with gr.Column(scale=5):
gr.Markdown("**3D Reconstruction (Point Cloud)**")
log_output = gr.Markdown(
"Please upload a video or images, then click Reconstruct.", elem_classes=["custom-log"]
)
download_btn = gr.DownloadButton("Download All Outputs (.zip)", visible=False) # Hidden initially - now downloads zip with depths, normals, flow
reconstruction_output = gr.Plot(label="3D Point Cloud") # Moved up so it's defined before ClearButton usage
with gr.Row():
submit_btn = gr.Button("Reconstruct", scale=1, variant="primary")
gr.ClearButton(
[input_video, input_images, reconstruction_output, log_output, target_dir_output, image_gallery, download_btn],
scale=1,
)
with gr.Row():
conf_thres = gr.Slider(0, 100, value=50, step=1, label="Confidence Threshold (%)")
static_threshold = gr.Slider(0.001, 0.1, value=0.01, step=0.001, label="Track Motion Threshold")
with gr.Row():
with gr.Column():
mask_black_bg = gr.Checkbox(label="Filter Black Background", value=False)
mask_white_bg = gr.Checkbox(label="Filter White Background", value=False)
with gr.Column():
use_temporal_tracking = gr.Checkbox(
label="Temporal Tracking Mode",
value=True,
info="Track reference frame pixels over time (vs. per-frame reconstruction)"
)
examples = [
[camel_video, "17", None, 15.0, False, False, 0.01, True, "True", "8"],
[horse_video, "18", None, 50.0, False, False, 0.01, True, "True", "2"],
[tennis_video, "11", None, 5.0, False, False, 0.01, True, "True", "0"],
[paragliding_video, "11", None, 5.0, False, False, 0.01, True, "True", "0"],
[stroller_video, "17", None, 10.0, False, False, 0.01, True, "True", "8"],
[goldfish_video, "11", None, 12.0, False, False, 0.01, True, "True", "5"],
[swing_video, "10", None, 40.0, False, False, 0.01, True, "True", "4"],
[car_video, "13", None, 15.0, False, False, 0.01, True, "True", "7"],
[figure1_video, "10", None, 25.0, False, False, 0.01, True, "True", "0"],
[figure2_video, "12", None, 25.0, False, False, 0.01, True, "True", "6"],
[figure3_video, "13", None, 30.0, False, False, 0.01, True, "True", "0"],
[tesla_video, "18", None, 20.0, False, True, 0.01, True, "True", "0"],
[pstudio_video, "12", None, 0.0, False, False, 0.01, True, "True", "6"],
]
gr.Markdown("Click any row to load an example.", elem_classes=["example-log"])
# Helper to return dummy download for examples
def example_pipeline_wrapper(*args):
    """Run example_pipeline and append the download-button update.

    example_pipeline returns (fig, log, target_dir, predictions, gallery);
    the predictions entry (index 3) feeds update_download_button so the
    result lines up with the 6-element outputs list used by gr.Examples.
    """
    outputs = list(example_pipeline(*args))
    predictions = outputs[3]
    outputs.append(update_download_button(predictions))
    return outputs
# Create Gradio examples only if every example file exists, so a missing
# asset does not abort app startup. Any failure here is non-fatal: the demo
# simply launches without the Examples panel.
try:
    example_files = (camel_video, horse_video, tennis_video, paragliding_video,
                     stroller_video, goldfish_video, swing_video, car_video,
                     figure1_video, figure2_video, figure3_video, tesla_video, pstudio_video)
    missing = [ex for ex in example_files if not os.path.exists(ex)]
    if not missing:
        gr.Examples(
            examples=examples,
            inputs=[
                input_video,
                num_images,
                input_images,
                conf_thres,
                mask_black_bg,
                mask_white_bg,
                static_threshold,
                use_temporal_tracking,
                is_example,
                frame_id_state,
            ],
            outputs=[
                reconstruction_output,
                log_output,
                target_dir_output,
                predictions_state,
                image_gallery,
                download_btn,  # Add download button to outputs
            ],
            fn=example_pipeline_wrapper,
            cache_examples=False,
            examples_per_page=50,
        )
    else:
        # Report which files are missing instead of a generic message.
        print(f'[vdpm] Some example files missing; skipping Gradio Examples: {missing}')
except Exception as e:
    # Was a silent swallow; include the exception so startup issues are diagnosable.
    print(f'[vdpm] Failed to build Gradio Examples; continuing without examples: {e}')
# Reconstruct pipeline, run as a sequential .then() chain:
# 1) clear any stale plot, 2) show a progress log message, 3) run the model
# (gradio_reconstruct), 4) refresh the download button from the fresh
# predictions_state, 5) mark this session as a live (non-example) run so
# subsequent filter changes operate on the new data.
submit_btn.click(fn=clear_plot, inputs=[], outputs=[reconstruction_output]).then(
fn=update_log, inputs=[], outputs=[log_output]
).then(
fn=gradio_reconstruct,
inputs=[
target_dir_output,
conf_thres,
mask_black_bg,
mask_white_bg,
static_threshold,
use_temporal_tracking,
frame_id_state,
],
outputs=[reconstruction_output, log_output, predictions_state],
).then(
fn=update_download_button, inputs=[predictions_state], outputs=[download_btn]
).then(
fn=lambda: "False", inputs=[], outputs=[is_example]
)
# Wire every visualization filter control to update_plot; each change event
# passes the full set of current filter values plus cached predictions_state.
filter_controls = [conf_thres, mask_black_bg, mask_white_bg, static_threshold]
for control in filter_controls:
    control.change(
        fn=update_plot,
        inputs=[
            target_dir_output,
            predictions_state,
            conf_thres,
            mask_black_bg,
            mask_white_bg,
            static_threshold,
            is_example,
        ],
        outputs=[reconstruction_output, log_output, download_btn],
    )
# Temporal tracking mode requires re-running model
# Toggling the mode only clears the plot and posts a hint in the log;
# the expensive reconstruction is deferred until 'Reconstruct' is clicked.
use_temporal_tracking.change(
fn=clear_plot, inputs=[], outputs=[reconstruction_output]
).then(
fn=lambda: "Switching visualization mode. Click 'Reconstruct' to update.",
inputs=[],
outputs=[log_output]
)
# Both upload widgets share one handler, which always receives the current
# values of BOTH inputs regardless of which one fired — so the two .change()
# registrations are identical and can be wired in a single loop (was
# duplicated verbatim for each widget).
for _uploader in (input_video, input_images):
    _uploader.change(
        fn=update_gallery_on_upload,
        inputs=[input_video, input_images],
        outputs=[reconstruction_output, target_dir_output, image_gallery, log_output],
    )
# Allow access to output directory (parent of vdpm folder)
# NOTE(review): share=True opens a public Gradio tunnel, and allowed_paths
# exposes <repo>/output to the file-serving endpoint (needed so the
# DownloadButton can serve the generated zip) — confirm both are intended
# for the deployment target.
output_dir = str(Path(__file__).parent.parent / "output")
demo.queue(max_size=20).launch(show_error=True, share=True, allowed_paths=[output_dir])