| import gc |
| import glob |
| import os |
| |
| os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" |
|
|
| import shutil |
| import sys |
| import time |
| import zipfile |
| from datetime import datetime |
| from pathlib import Path |
|
|
| import cv2 |
| import gradio as gr |
| import matplotlib |
| import numpy as np |
| import plotly.graph_objects as go |
| import torch |
| import torch.nn.functional as F |
| import json |
| import spaces |
| from hydra import compose, initialize |
| from hydra.core.global_hydra import GlobalHydra |
|
|
| from dpm.model import VDPM |
| from vggt.utils.load_fn import load_and_preprocess_images |
| from util.depth import write_depth_to_png |
|
|
| import torch |
| import torch._dynamo |
| torch._dynamo.config.suppress_errors = True |
| |
| torch.set_float32_matmul_precision('high') |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
# ---------------------------------------------------------------------------
# Visualization / tracking constants
# ---------------------------------------------------------------------------
MAX_POINTS_PER_FRAME = 50_000  # cap on points rendered per frame (Plotly performance)
TRAIL_LENGTH = 20              # number of past frames shown in each track trail
MAX_TRACKS = 150               # total track budget, shared across input images
STATIC_THRESHOLD = 0.01        # min displacement (world units) for a point to count as dynamic
VIDEO_SAMPLE_HZ = 2.0          # presumably the frame-sampling rate for input videos — used outside this chunk; verify

# Inference precision / memory knobs. MAX_FRAMES may be raised below based on
# detected VRAM.
USE_QUANTIZATION = False       # INT8 dynamic quantization (applied in load_model)
USE_HALF_PRECISION = True      # FP16/BF16 conversion (applied in load_model)
MAX_FRAMES = 5                 # default cap on number of input frames
|
|
# ---------------------------------------------------------------------------
# Device selection and VRAM-based frame budget
# ---------------------------------------------------------------------------
device = "cuda" if torch.cuda.is_available() else "cpu"

if device == "cuda":
    # TF32 trades a little matmul precision for large speedups on Ampere+.
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

    # Scale MAX_FRAMES with available VRAM so inference fits in memory.
    vram_bytes = torch.cuda.get_device_properties(0).total_memory
    vram_gb = vram_bytes / (1024**3)

    print(f"\u2713 GPU Detected: {torch.cuda.get_device_name(0)} ({vram_gb:.1f} GB VRAM)")

    if vram_gb >= 30:
        MAX_FRAMES = 80
        print(f" -> High VRAM detected! Increased MAX_FRAMES to {MAX_FRAMES}")
    elif vram_gb >= 15:
        MAX_FRAMES = 32
        print(f" -> Medium VRAM detected! Increased MAX_FRAMES to {MAX_FRAMES}")
    elif vram_gb >= 7.5:
        MAX_FRAMES = 24
        print(f" -> 8GB VRAM detected. Set MAX_FRAMES to {MAX_FRAMES}")
    else:
        MAX_FRAMES = 8
        print(f" -> Low VRAM (<8GB). Keeping MAX_FRAMES at {MAX_FRAMES} to prevent OOM")
    print(f"\u2713 TF32 enabled for faster matrix operations")
|
|
|
|
|
|
def load_cfg_from_cli() -> "omegaconf.DictConfig":
    """Compose the Hydra `visualise` config, passing through CLI overrides.

    Clears any previously-initialized GlobalHydra instance so the function is
    safe to call repeatedly (e.g. from a notebook or an app reload). Jupyter
    kernel arguments (``--f=...`` / anything containing "kernel") are filtered
    out of ``sys.argv`` so Hydra does not treat them as config overrides.

    Returns:
        The composed DictConfig for the "visualise" config in ./configs.
    """
    # Bug fix: the initialized-check/clear was duplicated verbatim; a single
    # clear is sufficient (the second check could never be true).
    if GlobalHydra.instance().is_initialized():
        GlobalHydra.instance().clear()
    raw_overrides = sys.argv[1:]
    # Drop Jupyter/ipykernel args that Hydra would reject as overrides.
    overrides = [o for o in raw_overrides if not (o.startswith("--f=") or "kernel" in o)]
    with initialize(config_path="configs"):
        return compose(config_name="visualise", overrides=overrides)
|
|
|
|
def load_model(cfg) -> VDPM:
    """Build a VDPM model, load (cached or downloaded) weights, set precision.

    Weights are cached under ``~/.cache/vdpm``; on first run they are
    downloaded from the Hugging Face Hub. Depending on the module-level flags
    USE_HALF_PRECISION / USE_QUANTIZATION, the model is converted to BF16/FP16
    and/or INT8-quantized.

    Args:
        cfg: Hydra config passed to the VDPM constructor.

    Returns:
        The model in eval mode, on the global `device`.
    """
    model = VDPM(cfg).to(device)

    # Cache weights locally so restarts don't re-download.
    cache_dir = os.path.expanduser("~/.cache/vdpm")
    os.makedirs(cache_dir, exist_ok=True)
    model_path = os.path.join(cache_dir, "vdpm_model.pt")

    _URL = "https://huggingface.co/edgarsucar/vdpm/resolve/main/model.pt"

    if not os.path.exists(model_path):
        print(f"Downloading model to {model_path}...")
        sd = torch.hub.load_state_dict_from_url(
            _URL,
            file_name="vdpm_model.pt",
            progress=True,
            map_location=device
        )
        # Persist a copy under our own cache path (in addition to torch.hub's cache).
        torch.save(sd, model_path)
        print(f"✓ Model cached at {model_path}")
    else:
        print(f"✓ Loading cached model from {model_path}")
        sd = torch.load(model_path, map_location=device)

    # load_state_dict returns a report of missing/unexpected keys; printing it
    # surfaces any mismatch (strict=True would already have raised on one).
    print(model.load_state_dict(sd, strict=True))

    model.eval()

    if USE_HALF_PRECISION and not USE_QUANTIZATION:
        # BF16 on compute capability >= 9 (Hopper and newer), FP16 otherwise.
        if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 9:
            print("Converting model to BF16 precision (Hopper+ GPU detected)...")
            model = model.to(torch.bfloat16)
            print("✓ Model converted to BF16")
        else:
            print("Converting model to FP16 precision...")
            # Prefer the model's own converter; presumably it handles norm
            # layers specially (keep_norm_fp32) — falls back to a blanket
            # .half() if the method is missing or fails.
            try:
                model = model.to_fp16(keep_norm_fp32=False)
                print("✓ Model converted to FP16 via to_fp16()")
            except Exception:
                model = model.half()
                print("✓ Model converted to FP16 via half() (fallback)")

    if USE_QUANTIZATION:
        try:
            print("Applying INT8 dynamic quantization to reduce memory usage...")
            # Dynamic quantization is applied on CPU, then the model is moved
            # back to the target device.
            model = model.cpu()
            model = torch.quantization.quantize_dynamic(
                model,
                {torch.nn.Linear, torch.nn.Conv2d},
                dtype=torch.qint8
            )
            model = model.to(device)
            print("✓ Model quantized: ~3x memory reduction (3GB -> 1GB)")
        except Exception as e:
            # Best-effort: keep the un-quantized model (already moved to CPU
            # above), just restore it to the device.
            print(f"⚠️ Quantization failed: {e}")
            print("Continuing with FP16/BF16 precision...")
            model = model.to(device)

    return model
|
|
|
|
def require_cuda():
    """Raise a ValueError unless a CUDA device was detected at startup."""
    if device == "cuda":
        return
    raise ValueError("CUDA is not available. Check your environment.")
|
|
|
|
def gradio_file_path(file_obj):
    """Normalize a Gradio upload object to a filesystem path.

    Gradio may hand back None, a dict carrying a "name" key, or a plain path
    string; the first two are unwrapped, anything else passes through.
    """
    if isinstance(file_obj, dict) and "name" in file_obj:
        return file_obj["name"]
    # Covers both the None case and plain path strings.
    return file_obj
|
|
|
|
def ensure_nhwc_images(images: np.ndarray) -> np.ndarray:
    """Return images in NHWC layout, converting from NCHW if necessary.

    A 4-D batch whose second axis has size 3 is treated as channels-first and
    transposed; every other array is returned unchanged.
    """
    looks_nchw = images.ndim == 4 and images.shape[1] == 3
    return np.transpose(images, (0, 2, 3, 1)) if looks_nchw else images
|
|
|
|
| |
| |
| |
|
|
def compute_normals_from_pointmap(point_map: np.ndarray) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Derive unit surface normals from a world-space point map.

    Normals are the normalized cross product of the horizontal (dP/dx) and
    vertical (dP/dy) forward differences of the point map; the tangents are
    those differences normalized.

    Args:
        point_map: (T, V, H, W, 3) or (T, H, W, 3) points in world coordinates.

    Returns:
        Tuple (normals, tangent_x, tangent_y), each float32 with the same
        shape as the input.
    """
    if point_map.ndim == 5:
        T, V, H, W, _ = point_map.shape
        flat = point_map.reshape(T * V, H, W, 3)
    elif point_map.ndim == 4:
        T, H, W, _ = point_map.shape
        V = 1
        flat = point_map
    else:
        raise ValueError(f"Expected 4D or 5D point map, got shape {point_map.shape}")

    grid = torch.from_numpy(flat).permute(0, 3, 1, 2).float()

    # Forward differences along rows/columns; replicate the last difference at
    # the border so the output keeps the input's spatial size.
    d_row = F.pad(grid[:, :, 1:, :] - grid[:, :, :-1, :], (0, 0, 0, 1), mode="replicate")
    d_col = F.pad(grid[:, :, :, 1:] - grid[:, :, :, :-1], (0, 1, 0, 0), mode="replicate")

    unit_tx = torch.nan_to_num(
        F.normalize(d_col, dim=1, eps=1e-8), nan=0.0, posinf=0.0, neginf=0.0
    )
    unit_ty = torch.nan_to_num(
        F.normalize(d_row, dim=1, eps=1e-8), nan=0.0, posinf=0.0, neginf=0.0
    )
    unit_n = torch.nan_to_num(
        F.normalize(torch.cross(d_col, d_row, dim=1), dim=1, eps=1e-8),
        nan=0.0, posinf=0.0, neginf=0.0,
    )

    def _back_to_numpy(tensor: torch.Tensor) -> np.ndarray:
        out = tensor.permute(0, 2, 3, 1).numpy()
        if point_map.ndim == 5:
            out = out.reshape(T, V, H, W, 3)
        return out.astype(np.float32)

    return _back_to_numpy(unit_n), _back_to_numpy(unit_tx), _back_to_numpy(unit_ty)
|
|
|
|
def compute_smooth_normals(normals: np.ndarray, kernel_size: int = 5) -> np.ndarray:
    """Gaussian-blur normal maps, then renormalize each vector to unit length.

    Args:
        normals: (T, V, H, W, 3) or (T, H, W, 3) normal map.
        kernel_size: Side length of the Gaussian kernel (odd sizes expected).

    Returns:
        Smoothed, renormalized float32 normals with the input's shape.
    """
    if normals.ndim == 5:
        T, V, H, W, _ = normals.shape
        flat = normals.reshape(T * V, H, W, 3)
    elif normals.ndim == 4:
        T, H, W, _ = normals.shape
        V = 1
        flat = normals
    else:
        raise ValueError(f"Expected 4D or 5D normals, got shape {normals.shape}")

    maps = torch.from_numpy(flat.astype(np.float32)).permute(0, 3, 1, 2)

    # 2-D Gaussian built as the outer product of a normalized 1-D profile;
    # one copy per channel so conv2d with groups=3 filters each independently.
    sigma = kernel_size / 4.0
    offsets = torch.arange(kernel_size, dtype=torch.float32) - kernel_size // 2
    profile = torch.exp(-offsets**2 / (2 * sigma**2))
    profile = profile / profile.sum()
    weights = profile.unsqueeze(0) * profile.unsqueeze(1)
    kernel = weights[None, None].repeat(3, 1, 1, 1)

    blurred = F.conv2d(maps, kernel, padding=kernel_size // 2, groups=3)

    # Blurring shortens vectors (especially near edges); renormalize and
    # scrub any NaN/inf introduced by degenerate inputs.
    blurred = F.normalize(blurred, dim=1, eps=1e-8)
    blurred = torch.nan_to_num(blurred, nan=0.0, posinf=0.0, neginf=0.0)

    out = blurred.permute(0, 2, 3, 1).numpy()
    if normals.ndim == 5:
        out = out.reshape(T, V, H, W, 3)
    return out.astype(np.float32)
|
|
|
|
def compute_optical_flow(world_points: np.ndarray, extrinsics: np.ndarray = None, intrinsics: np.ndarray = None, num_views: int = 1) -> np.ndarray:
    """Compute optical flow from world point trajectories.

    With cameras available, both timesteps are projected through the *current*
    camera and the flow is the pixel-space displacement. Without cameras, the
    world-space XY displacement is used as an approximation.

    Args:
        world_points: (T, V, H, W, 3) world point trajectories.
        extrinsics: Optional (N, 4, 4) world-to-camera matrices, indexed t*V+v.
        intrinsics: Optional (N, 3, 3) camera intrinsics.
        num_views: Number of views per timestep (unused; V comes from shape).

    Returns:
        (max(T-1, 1), V, H, W, 2) float32 flow as (dx, dy) in pixels or
        world units; all zeros when T < 2.
    """
    T, V, H, W, _ = world_points.shape

    # A single timestep has no motion.
    if T < 2:
        return np.zeros((1, V, H, W, 2), dtype=np.float32)

    if extrinsics is None or intrinsics is None:
        # No cameras: approximate flow with the world-space XY displacement.
        delta = world_points[1:, ..., :2] - world_points[:-1, ..., :2]
        return delta.astype(np.float32)

    per_step = []
    for t in range(T - 1):
        step_flow = np.zeros((V, H, W, 2), dtype=np.float32)
        for v in range(V):
            cam_now = t * V + v
            cam_next = (t + 1) * V + v

            # NOTE(review): the second bound is checked against `intrinsics`
            # although only `cam_now` indexes it — harmless while both arrays
            # have equal length; confirm intent.
            if cam_now >= len(extrinsics) or cam_next >= len(intrinsics):
                continue

            K = intrinsics[cam_now]
            R = extrinsics[cam_now, :3, :3]
            trans = extrinsics[cam_now, :3, 3]

            # Transform both timesteps into the current camera's frame.
            cam_pts_now = np.einsum('ij,hwj->hwi', R, world_points[t, v]) + trans
            cam_pts_next = np.einsum('ij,hwj->hwi', R, world_points[t + 1, v]) + trans

            # Clamp depth away from zero before the perspective divide.
            depth_now = np.maximum(cam_pts_now[..., 2:3], 1e-6)
            depth_next = np.maximum(cam_pts_next[..., 2:3], 1e-6)

            uv_now = (cam_pts_now[..., :2] / depth_now) @ K[:2, :2].T + K[:2, 2]
            uv_next = (cam_pts_next[..., :2] / depth_next) @ K[:2, :2].T + K[:2, 2]

            step_flow[v] = uv_next - uv_now

        per_step.append(step_flow)

    return np.stack(per_step, axis=0).astype(np.float32)
|
|
|
|
def write_normal_to_png(outpath: str, normals: np.ndarray) -> None:
    """Save a normal map as an 8-bit RGB PNG.

    Args:
        outpath: Output file path.
        normals: (H, W, 3) normal map with values in [-1, 1]; mapped to
            [0, 255] for storage.
    """
    import PIL.Image

    # Map [-1, 1] -> [0, 1], clamp, then quantize to uint8.
    unit_range = np.clip((normals.astype(np.float32) + 1.0) * 0.5, 0.0, 1.0)
    image = PIL.Image.fromarray((unit_range * 255.0).astype(np.uint8))
    image.save(outpath)
|
|
|
|
def write_optical_flow_to_png(outpath: str, flow: np.ndarray, max_flow: float = None) -> None:
    """Write optical flow as a color-coded PNG using HSV color wheel.

    Hue encodes flow direction, saturation encodes magnitude, value is fixed
    at full brightness.

    Args:
        outpath: Output file path.
        flow: (H, W, 2) optical flow as (dx, dy).
        max_flow: Optional maximum flow magnitude for normalization; defaults
            to the image's own max magnitude (or 1.0 for an all-zero field).
    """
    import PIL.Image

    fx, fy = flow[..., 0], flow[..., 1]
    magnitude = np.sqrt(fx**2 + fy**2)
    angle = np.arctan2(fy, fx)

    # Normalize magnitude into [0, 1]; guard against division by zero.
    if max_flow is None:
        max_flow = magnitude.max() if magnitude.max() > 0 else 1.0
    magnitude = np.clip(magnitude / max_flow, 0, 1)

    # Map angle [-pi, pi] -> hue [0, 180]. NOTE(review): OpenCV's 8-bit hue
    # range is 0-179; an angle of exactly +pi produces 180, which renders as
    # the same color as hue 0 (360° wraps to 0°) — confirm this is acceptable.
    hue = ((angle + np.pi) / (2 * np.pi) * 180).astype(np.uint8)
    saturation = (magnitude * 255).astype(np.uint8)
    value = np.full_like(hue, 255, dtype=np.uint8)

    hsv = np.stack([hue, saturation, value], axis=-1)
    rgb = cv2.cvtColor(hsv, cv2.COLOR_HSV2RGB)

    PIL.Image.fromarray(rgb).save(outpath)
|
|
|
|
| def _flow_to_image_3d(flow_hw3: np.ndarray) -> np.ndarray: |
| """Convert a 3D flow map (H, W, 3) to an RGB image. |
| |
| Maps each XYZ component to R, G, B scaled by the global max magnitude. |
| """ |
| mag = np.linalg.norm(flow_hw3, axis=-1, keepdims=True) |
| max_mag = float(mag.max()) + 1e-8 |
| norm = flow_hw3 / max_mag |
| rgb = np.clip((norm + 1.0) * 0.5, 0.0, 1.0) |
| return (rgb * 255).astype(np.uint8) |
|
|
|
|
def create_output_zip(target_dir: str) -> str:
    """Bundle the known output files and directories into `outputs.zip`.

    Entries that do not exist under `target_dir` are skipped silently;
    directories are walked recursively with archive names relative to
    `target_dir`.

    Args:
        target_dir: Directory containing all output files.

    Returns:
        Path to the created zip file.
    """
    archive_path = os.path.join(target_dir, "outputs.zip")

    wanted = (
        "output_4d.npz",
        "tracks.npz",
        "poses.npz",
        "depths.npz",
        "depth_normals.npz",
        "scene_flow.npz",
        "angular_flow.npz",
        "depths",
        "normals",
        "scene_flow",
        "angular_flow",
        "images",
        "meta.json",
    )

    with zipfile.ZipFile(archive_path, 'w', zipfile.ZIP_DEFLATED) as archive:
        for entry in wanted:
            full_path = os.path.join(target_dir, entry)
            if os.path.isfile(full_path):
                archive.write(full_path, entry)
                continue
            if not os.path.isdir(full_path):
                continue
            for root, _dirs, filenames in os.walk(full_path):
                for filename in filenames:
                    src = os.path.join(root, filename)
                    archive.write(src, os.path.relpath(src, target_dir))

    print(f"✓ Created output zip: {archive_path}")
    return archive_path
|
|
|
|
def compute_scene_bounds(world_points: np.ndarray):
    """Compute a padded, centered bounding box and Plotly aspect ratio.

    Args:
        world_points: Array reshapeable to (-1, 3) of world coordinates.

    Returns:
        (global_min, global_max, aspectratio): box corners (each shape (3,))
        and a dict of per-axis extents normalized by the largest axis.
    """
    flat = world_points.reshape(-1, 3)
    lo = flat.min(axis=0)
    hi = flat.max(axis=0)

    center = 0.5 * (lo + hi)
    # 5% padding around the raw extent.
    half = 0.5 * (hi - lo) * 1.05

    if np.all(half < 1e-6):
        # Fully degenerate (single point): fall back to a unit cube.
        half[:] = 1.0
    else:
        # Flat axes inherit the largest extent so the box is never paper-thin.
        half[half < 1e-6] = half.max()

    largest = half.max()
    aspect = {
        "x": float(half[0] / largest),
        "y": float(half[1] / largest),
        "z": float(half[2] / largest),
    }
    return center - half, center + half, aspect
|
|
|
|
def stride_downsample(pts: np.ndarray, cols: np.ndarray, max_points: int):
    """Uniformly subsample matched point/color arrays down to `max_points`.

    When the input already fits the budget, the original arrays are returned
    unchanged; otherwise every `ceil(n / max_points)`-th row is kept.
    """
    count = pts.shape[0]
    if count <= max_points:
        return pts, cols
    stride = int(np.ceil(count / max_points))
    keep = np.arange(0, count, stride)[:max_points]
    return pts[keep], cols[keep]
|
|
|
|
| |
| |
| |
| def compute_point_mask( |
| conf_score: np.ndarray | None, |
| cols: np.ndarray, |
| conf_thres: float, |
| mask_black_bg: bool, |
| mask_white_bg: bool, |
| ) -> np.ndarray: |
| """ |
| conf_score: (N,) or None |
| cols: (N,3) uint8 |
| Returns: (N,) boolean mask |
| """ |
| mask = np.ones(cols.shape[0], dtype=bool) |
|
|
| |
| if conf_score is not None and conf_thres > 0: |
| thresh = np.percentile(conf_score, conf_thres) |
| mask &= (conf_score >= thresh) & (conf_score > 1e-5) |
|
|
| |
| if mask_black_bg: |
| mask &= (cols.sum(axis=1) >= 16) |
| if mask_white_bg: |
| mask &= ~((cols[:, 0] > 240) & (cols[:, 1] > 240) & (cols[:, 2] > 240)) |
|
|
| return mask |
|
|
|
|
def sample_frame_points(
    world_points: np.ndarray,
    images_nhwc: np.ndarray,
    conf: np.ndarray | None,
    idx: int,
    conf_thres: float,
    mask_black_bg: bool,
    mask_white_bg: bool,
    max_points: int,
):
    """Select, filter, and downsample the point cloud for one frame.

    Args:
        world_points: (S, N, 3) flattened per-frame points.
        images_nhwc: (S, H, W, 3) images in [0, 1] supplying per-point colors.
        conf: optional (S, N) per-point confidence.
        idx: frame index (clamped into range).
        conf_thres / mask_black_bg / mask_white_bg: see compute_point_mask.
        max_points: budget passed to stride_downsample.

    Returns:
        (pts, colors): filtered (M, 3) points and a list of "#rrggbb" strings.
    """
    frame = int(np.clip(idx, 0, world_points.shape[0] - 1))

    pts = world_points[frame]
    cols = (images_nhwc[frame].reshape(-1, 3) * 255).astype(np.uint8)
    frame_conf = conf[frame] if (conf is not None) else None

    keep = compute_point_mask(
        conf_score=frame_conf,
        cols=cols,
        conf_thres=conf_thres,
        mask_black_bg=mask_black_bg,
        mask_white_bg=mask_white_bg,
    )
    pts, cols = pts[keep], cols[keep]
    pts, cols = stride_downsample(pts, cols, max_points)

    # Plotly needs at least one point; fall back to a single white marker.
    if pts.size == 0:
        pts = np.array([[0.0, 0.0, 0.0]])
        cols = np.array([[255, 255, 255]], dtype=np.uint8)

    hex_colors = [f"#{r:02x}{g:02x}{b:02x}" for r, g, b in cols]
    return pts, hex_colors
|
|
|
|
| |
| |
| |
def prepare_tracks(
    world_points_tracks: np.ndarray,
    world_points_conf_tracks: np.ndarray,
    images_raw: np.ndarray,
    conf_thres: float,
    mask_black_bg: bool,
    mask_white_bg: bool,
    num_views: int,
    static_threshold: float = 0.01,
):
    """
    Prepare tracks from per-image temporal trajectories.

    Args:
        world_points_tracks: (S, T, H, W, 3) - each of S images tracked over T timesteps
        world_points_conf_tracks: (S, T, H, W) - confidence for tracks
        images_raw: (S, C, H, W) - original input images
        conf_thres: confidence threshold percentile
        mask_black_bg: filter black background
        mask_white_bg: filter white background
        num_views: number of camera views (currently unused in this function)
        static_threshold: minimum displacement from the t=0 position (world
            units) for a pixel to count as dynamic and be kept as a track

    Returns:
        tracks_xyz: (T, N_tracks, 3) - selected point trajectories
        colorscale: colormap for tracks
        track_ids: track identifiers
    """
    S, T, H, W, _ = world_points_tracks.shape

    # Need at least two timesteps to have a trajectory at all.
    if T < 2:
        return None, None, None

    all_tracks = []
    all_colors = []

    for img_idx in range(S):
        tracks_img = world_points_tracks[img_idx]
        conf_img = world_points_conf_tracks[img_idx]

        # Flatten spatial dims: (T, H*W, 3) and (T, H*W).
        tracks_flat = tracks_img.reshape(T, -1, 3)
        conf_flat = conf_img.reshape(T, -1)

        N = tracks_flat.shape[1]

        # A pixel is "dynamic" if it ever moves farther than static_threshold
        # from its position at t=0.
        disp = np.linalg.norm(tracks_flat - tracks_flat[0:1], axis=-1)
        dynamic_mask = disp.max(axis=0) > static_threshold

        # Temporal mean confidence drives the confidence filter.
        conf_score = conf_flat.mean(axis=0)

        # Per-pixel colors from the source image (CHW -> HWC -> flat RGB).
        img_hwc = np.transpose(images_raw[img_idx], (1, 2, 0))
        cols = (img_hwc.reshape(-1, 3) * 255).astype(np.uint8)

        point_mask = compute_point_mask(
            conf_score=conf_score,
            cols=cols,
            conf_thres=conf_thres,
            mask_black_bg=mask_black_bg,
            mask_white_bg=mask_white_bg,
        )

        valid_mask = dynamic_mask & point_mask
        idx_valid = np.nonzero(valid_mask)[0]

        if idx_valid.size == 0:
            continue

        # Share the global MAX_TRACKS budget evenly across images, uniformly
        # subsampling this image's candidates if it has too many.
        max_per_image = MAX_TRACKS // S
        if idx_valid.size > max_per_image:
            step = int(np.ceil(idx_valid.size / max_per_image))
            idx_valid = idx_valid[::step][:max_per_image]

        selected_tracks = tracks_flat[:, idx_valid, :]
        selected_colors = cols[idx_valid]

        all_tracks.append(selected_tracks)
        all_colors.append(selected_colors)

    if not all_tracks:
        return None, None, None

    # Concatenate across images: (T, N_total, 3).
    tracks_xyz = np.concatenate(all_tracks, axis=1)
    colors_all = np.concatenate(all_colors, axis=0)

    num_tracks = tracks_xyz.shape[1]

    # Sort tracks by their initial Y so the HSV colormap varies smoothly.
    order = np.argsort(tracks_xyz[0, :, 1])
    tracks_xyz = tracks_xyz[:, order, :]

    # Build a Plotly colorscale mapping track index -> HSV color.
    # NOTE(review): colors_all (image-sampled colors) is computed but never
    # returned; confirm whether it was meant to feed the colorscale.
    cmap = matplotlib.colormaps.get_cmap("hsv")
    norm = matplotlib.colors.Normalize(vmin=0, vmax=max(num_tracks - 1, 1))

    colorscale = []
    for t in range(num_tracks):
        r, g, b, _ = cmap(norm(t))
        r, g, b = int(r * 255), int(g * 255), int(b * 255)
        pos = t / max(num_tracks - 1, 1)
        colorscale.append([pos, f"rgb({r},{g},{b})"])

    track_ids = np.arange(num_tracks, dtype=float)

    print(f" Prepared {num_tracks} tracks from {S} images ({len(all_tracks)} with valid tracks)")

    return tracks_xyz, colorscale, track_ids
|
|
|
|
def track_segments_for_frame(tracks_xyz: np.ndarray | None, track_ids: np.ndarray | None, f: int):
    """Build NaN-separated polyline arrays for the track trails ending at frame f.

    Each track contributes its last TRAIL_LENGTH positions up to frame `f`,
    followed by a NaN sentinel so Plotly renders the concatenation as
    separate line segments in a single trace.

    Returns:
        (x, y, z, c): flat coordinate arrays and a matching per-vertex color
        id array; all empty when there are no tracks or f <= 0.
    """
    if tracks_xyz is None or track_ids is None or f <= 0:
        return np.array([]), np.array([]), np.array([]), np.array([])

    start = max(0, f - TRAIL_LENGTH)
    gap = np.array([np.nan])

    xs, ys, zs, ids = [], [], [], []
    for track in range(tracks_xyz.shape[1]):
        segment = tracks_xyz[start : f + 1, track, :]
        # A single point cannot form a line segment.
        if segment.shape[0] < 2:
            continue

        xs += [segment[:, 0], gap]
        ys += [segment[:, 1], gap]
        zs += [segment[:, 2], gap]
        # One color id per vertex, including the NaN sentinel slot.
        ids.append(np.full(segment.shape[0] + 1, track_ids[track], dtype=float))

    def _join(chunks):
        return np.concatenate(chunks) if chunks else np.array([])

    return _join(xs), _join(ys), _join(zs), _join(ids)
|
|
|
|
def build_pointcloud_figure_update(
    data,
    conf_thres: float,
    mask_black_bg: bool,
    mask_white_bg: bool,
    static_threshold: float = 0.01,
):
    """Build a Plotly 3-D figure with per-frame point clouds and track trails.

    All frames' point/track arrays are precomputed once; a slider step per
    frame swaps the data of the two Scatter3d traces, while fixed axis ranges
    and `uirevision` keep the camera stable while scrubbing.

    Args:
        data: dict with "world_points" ((S, N, 3) flattened per frame),
            "images", optional "world_points_conf", and optional track inputs
            ("world_points_tracks", "world_points_conf_tracks", "images_raw",
            "num_views"). May be None.
        conf_thres: percentile confidence threshold for point filtering.
        mask_black_bg: drop near-black points.
        mask_white_bg: drop near-white points.
        static_threshold: min displacement for a point to count as a track.

    Returns:
        A plotly.graph_objects.Figure (empty when data is None).
    """
    if data is None:
        return go.Figure()

    world_points = data["world_points"]
    conf = data.get("world_points_conf")
    images = ensure_nhwc_images(data["images"])
    S = world_points.shape[0]

    # Global bounds across all frames so the view doesn't jump per frame.
    global_min, global_max, aspectratio = compute_scene_bounds(world_points)

    # Optional temporal-track inputs.
    world_points_tracks = data.get("world_points_tracks")
    world_points_conf_tracks = data.get("world_points_conf_tracks")
    images_raw = data.get("images_raw")
    num_views = data.get("num_views", 1)

    tracks_xyz, colorscale, track_ids = None, None, None
    if world_points_tracks is not None and images_raw is not None:
        tracks_xyz, colorscale, track_ids = prepare_tracks(
            world_points_tracks=world_points_tracks,
            world_points_conf_tracks=world_points_conf_tracks,
            images_raw=images_raw,
            conf_thres=conf_thres,
            mask_black_bg=mask_black_bg,
            mask_white_bg=mask_white_bg,
            num_views=num_views,
            static_threshold=static_threshold,
        )

    track_cmax = max(len(track_ids) - 1, 1) if track_ids is not None else 1

    # Precompute every frame's filtered points and track segments up front so
    # the slider only swaps arrays.
    pts_xyz = [None] * S
    pts_cols = [None] * S
    trk_xyz = [None] * S
    trk_c = [None] * S

    for i in range(S):
        pts_i, cols_i = sample_frame_points(
            world_points=world_points,
            images_nhwc=images,
            conf=conf,
            idx=i,
            conf_thres=conf_thres,
            mask_black_bg=mask_black_bg,
            mask_white_bg=mask_white_bg,
            max_points=MAX_POINTS_PER_FRAME,
        )
        pts_xyz[i] = pts_i
        pts_cols[i] = cols_i

        x, y, z, c = track_segments_for_frame(tracks_xyz, track_ids, f=i)
        trk_xyz[i] = (x, y, z)
        trk_c[i] = c

    # Initial (frame 0) data for both traces.
    p0 = pts_xyz[0]
    c0 = pts_cols[0]
    x0, y0, z0 = trk_xyz[0]
    tc0 = trk_c[0]

    # Hide all axes; fix ranges/aspect; orbit drag; camera looks along -z
    # with y pointing down (image-style orientation).
    scene_cfg = dict(
        xaxis=dict(
            visible=False,
            showbackground=False,
            showgrid=False,
            zeroline=False,
            showticklabels=False,
            range=[float(global_min[0]), float(global_max[0])],
        ),
        yaxis=dict(
            visible=False,
            showbackground=False,
            showgrid=False,
            zeroline=False,
            showticklabels=False,
            range=[float(global_min[1]), float(global_max[1])],
        ),
        zaxis=dict(
            visible=False,
            showbackground=False,
            showgrid=False,
            zeroline=False,
            showticklabels=False,
            range=[float(global_min[2]), float(global_max[2])],
        ),
        aspectmode="manual",
        aspectratio=aspectratio,
        dragmode="orbit",
        camera=dict(
            eye=dict(x=0.0, y=0.0, z=-1.0),
            center=dict(x=0.0, y=0.0, z=0.0),
            up=dict(x=0.0, y=-1.0, z=0.0),
        ),
    )

    # Trace 0: point cloud; trace 1: NaN-separated track polylines.
    fig = go.Figure(
        data=[
            go.Scatter3d(
                x=p0[:, 0],
                y=p0[:, 1],
                z=p0[:, 2],
                mode="markers",
                marker=dict(size=2, color=c0),
                showlegend=False,
                name="points",
            ),
            go.Scatter3d(
                x=x0,
                y=y0,
                z=z0,
                mode="lines",
                line=dict(
                    width=2,
                    color=tc0 if (tc0 is not None and tc0.size) else None,
                    colorscale=colorscale if colorscale is not None else None,
                    cmin=0,
                    cmax=track_cmax,
                ),
                hoverinfo="skip",
                showlegend=False,
                name="tracks",
            ),
        ]
    )

    # One slider step per frame; method="update" replaces both traces' data.
    steps = []
    for i in range(S):
        pi = pts_xyz[i]
        ci = pts_cols[i]
        xi, yi, zi = trk_xyz[i]
        ti = trk_c[i]

        steps.append(
            dict(
                method="update",
                label=str(i),
                args=[
                    {
                        "x": [pi[:, 0], xi],
                        "y": [pi[:, 1], yi],
                        "z": [pi[:, 2], zi],
                        "marker.color": [ci, None],
                        "line.color": [None, ti if (ti is not None and len(ti)) else None],
                    },
                    {},
                ],
            )
        )

    sliders = [
        dict(
            active=0,
            currentvalue={"prefix": "Frame: ", "visible": True, "font": {"size": 14}},
            pad={"t": 10},
            len=0.6,
            x=0.2,
            font={"size": 8},
            steps=steps,
        )
    ]

    # uirevision preserves the user's camera across slider updates.
    fig.update_layout(
        margin=dict(l=0, r=0, t=30, b=0),
        scene=scene_cfg,
        sliders=sliders,
        showlegend=False,
        title="Scrub frames with the slider below",
        uirevision="keep-camera",
        height=700,
    )
    return fig
|
|
|
|
def decode_poses(pose_enc: np.ndarray, image_hw: tuple) -> tuple:
    """Decode VGGT pose encodings to camera matrices.

    Args:
        pose_enc: (1, N, D) numpy batch of pose encodings.
        image_hw: (H, W) image size used to recover intrinsics.

    Returns:
        (extrinsics, intrinsics): (N, 4, 4) homogeneous extrinsics and
        (N, 3, 3) intrinsics. When vggt is unavailable, identity extrinsics
        and a heuristic pinhole intrinsic (f = max(H, W), principal point at
        the image center) are returned instead.
    """
    try:
        from vggt.utils.pose_enc import pose_encoding_to_extri_intri

        encoded = torch.from_numpy(pose_enc).float()
        extri, intri = pose_encoding_to_extri_intri(encoded, image_hw)

        # Drop the batch dimension.
        extri = extri[0].numpy()
        intri = intri[0].numpy()

        # Promote (N, 3, 4) extrinsics to homogeneous (N, 4, 4).
        count = extri.shape[0]
        last_row = np.array([0, 0, 0, 1], dtype=np.float32).reshape(1, 1, 4)
        last_row = np.tile(last_row, (count, 1, 1))
        return np.concatenate([extri, last_row], axis=1), intri

    except ImportError:
        print("Warning: vggt not available. Using identity poses.")
        N = pose_enc.shape[1]
        extrinsics = np.tile(np.eye(4, dtype=np.float32), (N, 1, 1))

        H, W = image_hw
        fx = fy = max(H, W)
        cx, cy = W / 2, H / 2
        K = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]], dtype=np.float32)
        return extrinsics, np.tile(K, (N, 1, 1))
|
|
|
|
def compute_depths(world_points: np.ndarray, extrinsics: np.ndarray, num_views: int) -> np.ndarray:
    """Project world points into each camera and keep camera-space Z as depth.

    Args:
        world_points: (T, V, H, W, 3) world-space point maps.
        extrinsics: (N, 4, 4) world-to-camera matrices, indexed as t * V + v.
        num_views: unused; the view count is taken from world_points' shape.

    Returns:
        (T, V, H, W) float32 depth maps; frames without a camera stay zero.
    """
    T, V, H, W, _ = world_points.shape
    depths = np.zeros((T, V, H, W), dtype=np.float32)

    for t in range(T):
        for v in range(V):
            cam = t * V + v
            if cam >= len(extrinsics):
                continue  # no pose for this frame; leave depth at zero

            rot = extrinsics[cam][:3, :3]
            trans = extrinsics[cam][:3, 3]

            cam_space = np.einsum('ij,hwj->hwi', rot, world_points[t, v]) + trans
            depths[t, v] = cam_space[..., 2]

    return depths
|
|
|
|
|
|
|
|
| def run_model(target_dir: str, model: VDPM, frame_id_arg=0, use_temporal_tracking=True) -> dict: |
| require_cuda() |
|
|
| image_names = sorted(glob.glob(os.path.join(target_dir, "images", "*"))) |
| if not image_names: |
| raise ValueError("No images found. Check your upload.") |
|
|
| |
| meta_path = os.path.join(target_dir, "meta.json") |
| num_views = 1 |
| if os.path.exists(meta_path): |
| try: |
| import json |
| with open(meta_path, 'r') as f: |
| num_views = json.load(f).get("num_views", 1) |
| except: |
| pass |
|
|
| |
| if len(image_names) > MAX_FRAMES: |
| |
| limit = (MAX_FRAMES // num_views) * num_views |
| if limit == 0: |
| limit = num_views |
| print(f"⚠️ Warning: MAX_FRAMES={MAX_FRAMES} is smaller than num_views={num_views}. Processing 1 full timestep anyway (may OOM).") |
| |
| print(f"⚠️ Limiting to {limit} frames ({limit // num_views} timesteps * {num_views} views) to fit in GPU memory") |
| image_names = image_names[:limit] |
|
|
| images = load_and_preprocess_images(image_names).to(device) |
|
|
| if device == "cuda": |
| print(f"GPU memory before inference: {torch.cuda.memory_allocated() / 1024**3:.2f} GB / {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB") |
| |
| S = len(image_names) |
| print(f"Running inference on {S} images...") |
|
|
| inference_start = time.time() |
| |
| with torch.no_grad(): |
| with torch.amp.autocast('cuda'): |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| predictions = model.inference(None, images=images.unsqueeze(0)) |
| |
| inference_time = time.time() - inference_start |
| print(f"✓ Inference completed in {inference_time:.2f}s ({inference_time/S:.2f}s per frame)") |
|
|
| |
| |
| |
| pts_list = [pm["pts3d"].detach().cpu().numpy() for pm in predictions["pointmaps"]] |
| conf_list = [pm["conf"].detach().cpu().numpy() for pm in predictions["pointmaps"]] |
| |
| |
| pose_enc = None |
| if "pose_enc" in predictions: |
| pose_enc = predictions["pose_enc"].detach().cpu().numpy() |
| |
| |
| del predictions |
| if device == "cuda": |
| torch.cuda.empty_cache() |
| print(f"GPU memory after inference: {torch.cuda.memory_allocated() / 1024**3:.2f} GB") |
|
|
| |
| world_points_raw = np.concatenate(pts_list, axis=0) |
| world_points_conf_raw = np.concatenate(conf_list, axis=0) |
| |
| T = world_points_raw.shape[0] |
| S = world_points_raw.shape[1] |
| |
| print(f"DEBUG: Raw model output shape: (T={T}, S={S}, H, W, 3)") |
| print(f"DEBUG: num_views={num_views}") |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| try: |
| frame_id = int(frame_id_arg) |
| except Exception: |
| frame_id = 0 |
| |
| if frame_id >= S: |
| frame_id = 0 |
| |
| |
| |
| num_timesteps = S // num_views |
| |
| print(f"DEBUG: Detected {num_views} views × {num_timesteps} timesteps") |
| |
| |
| |
| |
| world_points_multiview = [] |
| world_points_conf_multiview = [] |
| |
| for t in range(num_timesteps): |
| views_at_t = [] |
| conf_at_t = [] |
| for v in range(num_views): |
| img_idx = t * num_views + v |
| views_at_t.append(world_points_raw[img_idx, img_idx]) |
| conf_at_t.append(world_points_conf_raw[img_idx, img_idx]) |
| |
| |
| world_points_multiview.append(np.stack(views_at_t, axis=0)) |
| world_points_conf_multiview.append(np.stack(conf_at_t, axis=0)) |
| |
| |
| world_points_full = np.stack(world_points_multiview, axis=0) |
| world_points_conf_full = np.stack(world_points_conf_multiview, axis=0) |
| |
| print(f"DEBUG: Multi-view extraction shape: {world_points_full.shape}") |
| |
| |
| |
| world_points_tracked = world_points_full.reshape(num_timesteps, -1, 3) |
| world_points_conf_tracked = world_points_conf_full.reshape(num_timesteps, -1) |
| |
| print(f"DEBUG: Visualization (all views combined) shape: {world_points_tracked.shape}") |
|
|
| |
| tracks_path = os.path.join(target_dir, "tracks.npz") |
| print(f"Saving tracks to {tracks_path}") |
| np.savez_compressed( |
| tracks_path, |
| |
| world_points=world_points_full, |
| world_points_conf=world_points_conf_full, |
| |
| world_points_raw=world_points_raw, |
| world_points_conf_raw=world_points_conf_raw, |
| num_views=num_views, |
| num_timesteps=num_timesteps, |
| reference_frame=frame_id |
| ) |
|
|
| |
| if pose_enc is not None: |
| poses_path = os.path.join(target_dir, "poses.npz") |
| print(f"Saving poses to {poses_path}") |
| np.savez_compressed(poses_path, pose_enc=pose_enc) |
| |
| |
| depths = None |
| if pose_enc is not None: |
| print("Computing depth maps from world points and camera poses...") |
| |
| H_orig, W_orig = world_points_full.shape[2:4] |
| extrinsics, intrinsics = decode_poses(pose_enc, (H_orig, W_orig)) |
| depths = compute_depths(world_points_full, extrinsics, num_views) |
| |
| |
| depths_path = os.path.join(target_dir, "depths.npz") |
| print(f"Saving depths to {depths_path}") |
| np.savez_compressed( |
| depths_path, |
| depths=depths, |
| num_views=num_views, |
| num_timesteps=num_timesteps |
| ) |
| |
| |
| depths_dir = os.path.join(target_dir, "depths") |
| os.makedirs(depths_dir, exist_ok=True) |
| print(f"Saving depth images to {depths_dir}/") |
| |
| T_depth = depths.shape[0] |
| V_depth = depths.shape[1] |
| for t in range(T_depth): |
| for v in range(V_depth): |
| depth_map = depths[t, v] |
| png_path = os.path.join(depths_dir, f"depth_t{t:04d}_v{v:02d}.png") |
| write_depth_to_png(png_path, depth_map) |
| |
| print(f"✓ Saved {T_depth * V_depth} depth images") |
| else: |
| print("⚠ No pose encodings available - skipping depth computation") |
| extrinsics, intrinsics = None, None |
|
|
| |
| |
| |
| print("Computing surface normals from point maps...") |
| normals, tangent_x, tangent_y = compute_normals_from_pointmap(world_points_full) |
| |
| |
| print("Computing smooth normals...") |
| smooth_normals = compute_smooth_normals(normals, kernel_size=7) |
| |
| |
| depth_normals_path = os.path.join(target_dir, "depth_normals.npz") |
| print(f"Saving normals to {depth_normals_path}") |
| np.savez_compressed( |
| depth_normals_path, |
| depth_normals=normals.astype(np.float16), |
| smooth_normals=smooth_normals.astype(np.float16), |
| tangent_x=tangent_x.astype(np.float16), |
| tangent_y=tangent_y.astype(np.float16), |
| ) |
| |
| |
| normals_dir = os.path.join(target_dir, "normals") |
| os.makedirs(normals_dir, exist_ok=True) |
| print(f"Saving normal images to {normals_dir}/") |
| |
| T_norm, V_norm = normals.shape[:2] |
| for t in range(T_norm): |
| for v in range(V_norm): |
| |
| normal_map = normals[t, v] |
| png_path = os.path.join(normals_dir, f"normal_t{t:04d}_v{v:02d}.png") |
| write_normal_to_png(png_path, normal_map) |
| |
| |
| smooth_map = smooth_normals[t, v] |
| smooth_path = os.path.join(normals_dir, f"smooth_normal_t{t:04d}_v{v:02d}.png") |
| write_normal_to_png(smooth_path, smooth_map) |
| |
| print(f"✓ Saved {T_norm * V_norm * 2} normal images (raw + smooth)") |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| print("Computing scene flow and angular flow from pairwise DPM output...") |
| |
| sf_npz_dict = {} |
| af_npz_dict = {} |
| |
| sf_dir = os.path.join(target_dir, "scene_flow") |
| af_dir = os.path.join(target_dir, "angular_flow") |
| os.makedirs(sf_dir, exist_ok=True) |
| os.makedirs(af_dir, exist_ok=True) |
| |
| |
| |
| |
| |
| H_raw, W_raw = world_points_raw.shape[2:4] |
| |
| sf_count = 0 |
| for t in range(num_timesteps - 1): |
| for v in range(num_views): |
| frame_idx = t * num_views + v |
| next_frame_idx = (t + 1) * num_views + v |
| |
| |
| P_t = world_points_raw[frame_idx, frame_idx].astype(np.float32) |
| |
| P_t1 = world_points_raw[frame_idx, next_frame_idx].astype(np.float32) |
| |
| scene_flow = np.nan_to_num(P_t1 - P_t, nan=0.0, posinf=0.0, neginf=0.0) |
| |
| key = f"t{t:04d}_v{v:02d}" |
| sf_npz_dict[key] = scene_flow.astype(np.float32) |
| |
| |
| n_t, tx_t, ty_t = compute_normals_from_pointmap(P_t[np.newaxis]) |
| n_t1, tx_t1, ty_t1 = compute_normals_from_pointmap(P_t1[np.newaxis]) |
| |
| delta_n = np.nan_to_num(n_t1[0] - n_t[0], nan=0.0) |
| delta_tx = np.nan_to_num(tx_t1[0] - tx_t[0], nan=0.0) |
| delta_ty = np.nan_to_num(ty_t1[0] - ty_t[0], nan=0.0) |
| angular_flow = np.concatenate([delta_n, delta_tx, delta_ty], axis=-1) |
| af_npz_dict[key] = angular_flow.astype(np.float32) |
| |
| |
| try: |
| sf_img = _flow_to_image_3d(scene_flow) |
| PIL.Image.fromarray(sf_img).save(os.path.join(sf_dir, f"{key}.png")) |
| af_img = _flow_to_image_3d(delta_n) |
| PIL.Image.fromarray(af_img).save(os.path.join(af_dir, f"{key}.png")) |
| except Exception: |
| pass |
| |
| sf_count += 1 |
| if sf_count <= 4: |
| mag = np.linalg.norm(scene_flow, axis=-1) |
| print(f" [{key}] scene flow: mean={mag.mean():.6f}, max={mag.max():.6f}") |
| |
| sf_npz_path = os.path.join(target_dir, "scene_flow.npz") |
| af_npz_path = os.path.join(target_dir, "angular_flow.npz") |
| np.savez_compressed(sf_npz_path, **sf_npz_dict) |
| np.savez_compressed(af_npz_path, **af_npz_dict) |
| print(f"✓ Saved {sf_count} scene flow entries to {sf_npz_path}") |
| print(f"✓ Saved {sf_count} angular flow entries (9ch) to {af_npz_path}") |
|
|
| |
|
|
| num_viz_frames = num_timesteps |
| print(f"DEBUG: Final viz world_points.shape={world_points_tracked.shape}, num_viz_frames={num_viz_frames}") |
|
|
| torch.cuda.empty_cache() |
|
|
| |
| |
| |
| if use_temporal_tracking: |
| |
| |
| |
| print("Using TEMPORAL TRACKING mode: tracking reference frame pixels over time") |
| |
| |
| ref_indices = list(range(num_views)) |
| |
| |
| |
| |
| ref_trajectories = [] |
| ref_conf_trajectories = [] |
| |
| for ref_idx in ref_indices: |
| |
| traj_list = [] |
| conf_list = [] |
| for t in range(num_timesteps): |
| |
| timestep_idx = t * num_views |
| traj_list.append(world_points_raw[timestep_idx, ref_idx, ::2, ::2, :]) |
| conf_list.append(world_points_conf_raw[timestep_idx, ref_idx, ::2, ::2]) |
| |
| |
| traj = np.stack(traj_list, axis=0) |
| conf_traj = np.stack(conf_list, axis=0) |
| ref_trajectories.append(traj) |
| ref_conf_trajectories.append(conf_traj) |
| |
| |
| ref_trajectories_stacked = np.stack(ref_trajectories, axis=0) |
| ref_trajectories_stacked = np.transpose(ref_trajectories_stacked, (1, 0, 2, 3, 4)) |
| |
| ref_conf_stacked = np.stack(ref_conf_trajectories, axis=0) |
| ref_conf_stacked = np.transpose(ref_conf_stacked, (1, 0, 2, 3)) |
| |
| |
| T_viz, V_viz, H_viz, W_viz, _ = ref_trajectories_stacked.shape |
| world_points_s = ref_trajectories_stacked.reshape(T_viz, V_viz * H_viz * W_viz, 3) |
| world_points_conf_s = ref_conf_stacked.reshape(T_viz, V_viz * H_viz * W_viz) |
| else: |
| |
| |
| |
| print("Using PER-FRAME RECONSTRUCTION mode: diagonal extraction") |
| |
| |
| world_points_viz = world_points_full[:, :, ::2, ::2, :] |
| world_points_conf_viz = world_points_conf_full[:, :, ::2, ::2] |
| |
| |
| T_viz, V_viz, H_viz, W_viz, _ = world_points_viz.shape |
| world_points_s = world_points_viz.reshape(T_viz, V_viz * H_viz * W_viz, 3) |
| world_points_conf_s = world_points_conf_viz.reshape(T_viz, V_viz * H_viz * W_viz) |
| |
| print(f"DEBUG: Subsampled visualization shape: {world_points_s.shape}") |
| |
| |
| |
| |
| world_points_tracking = [] |
| world_points_conf_tracking = [] |
| |
| for img_idx in range(S): |
| |
| trajectory = world_points_raw[:, img_idx, ::2, ::2, :] |
| conf_traj = world_points_conf_raw[:, img_idx, ::2, ::2] |
| world_points_tracking.append(trajectory) |
| world_points_conf_tracking.append(conf_traj) |
| |
| |
| world_points_tracks = np.stack(world_points_tracking, axis=0) |
| world_points_conf_tracks = np.stack(world_points_conf_tracking, axis=0) |
| |
| print(f"DEBUG: Tracking shape: {world_points_tracks.shape} (S={S} images × T={world_points_tracks.shape[1]} timesteps)") |
|
|
| |
| |
| img_np = images.detach().cpu().numpy() |
| |
| if use_temporal_tracking: |
| |
| ref_imgs = [] |
| for v in range(num_views): |
| ref_imgs.append(img_np[v:v+1, :, ::2, ::2]) |
| |
| combined_ref = np.concatenate(ref_imgs, axis=2) |
| |
| img_np_viz = np.repeat(combined_ref, num_timesteps, axis=0) |
| else: |
| |
| img_viz_list = [] |
| for t in range(num_timesteps): |
| views_imgs = [] |
| for v in range(num_views): |
| img_idx = t * num_views + v |
| views_imgs.append(img_np[img_idx:img_idx+1, :, ::2, ::2]) |
| |
| |
| combined = np.concatenate(views_imgs, axis=2) |
| img_viz_list.append(combined) |
| img_np_viz = np.concatenate(img_viz_list, axis=0) |
|
|
| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| TV = num_timesteps * num_views |
| assert world_points_raw.shape[0] == TV and world_points_raw.shape[1] == TV, \ |
| f"Expected ({TV}, {TV}, H, W, 3) but got {world_points_raw.shape}" |
| world_points_tracks_full = world_points_raw |
|
|
| |
| output_path = os.path.join(target_dir, "output_4d.npz") |
| save_dict = { |
| "world_points": world_points_s, |
| "world_points_conf": world_points_conf_s, |
| "world_points_tracks": world_points_tracks_full, |
| "world_points_conf_tracks": world_points_conf_tracks, |
| "images": img_np_viz, |
| "images_raw": img_np[:, :, ::2, ::2], |
| "num_views": num_views, |
| "num_timesteps": num_timesteps, |
| "use_temporal_tracking": use_temporal_tracking, |
| } |
| if depths is not None: |
| save_dict["depths"] = depths |
| np.savez_compressed(output_path, **save_dict) |
|
|
| |
| print("Creating downloadable zip archive of all outputs...") |
| zip_path = create_output_zip(target_dir) |
|
|
| return { |
| "world_points": world_points_s, |
| "world_points_conf": world_points_conf_s, |
| "world_points_tracks": world_points_tracks, |
| "world_points_conf_tracks": world_points_conf_tracks, |
| "images": img_np_viz, |
| "images_raw": img_np[:, :, ::2, ::2], |
| "download_path": zip_path, |
| "output_dir": target_dir, |
| "num_views": num_views, |
| "num_timesteps": num_timesteps, |
| "use_temporal_tracking": use_temporal_tracking, |
| } |
|
|
|
|
def handle_uploads(input_video, input_images):
    """Copy uploaded images and/or sample frames from uploaded videos into a
    fresh timestamped working directory.

    Parameters
    ----------
    input_video : file-like, list of file-like, or None
        One or more uploaded video files (Gradio file objects or paths).
    input_images : list of file-like or None
        Uploaded still images.

    Returns
    -------
    tuple[str, list[str]]
        ``(target_dir, image_paths)`` where ``target_dir`` is the new
        reconstruction directory and ``image_paths`` lists the image files
        written under ``target_dir/images``.
    """
    start_time = time.time()
    gc.collect()
    torch.cuda.empty_cache()

    # One unique output directory per upload (microsecond timestamp).
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
    output_base = os.path.join(os.path.dirname(__file__), "..", "output")
    os.makedirs(output_base, exist_ok=True)
    target_dir = os.path.join(output_base, f"reconstruction_{timestamp}")
    target_dir_images = os.path.join(target_dir, "images")

    # The microsecond timestamp makes collisions unlikely; clean up just in case.
    if os.path.exists(target_dir):
        shutil.rmtree(target_dir)
    os.makedirs(target_dir_images, exist_ok=True)

    image_paths = []

    # --- Copy still images (if any) ---
    if input_images:
        for file_obj in input_images:
            src = gradio_file_path(file_obj)
            if not src:
                continue
            dst = os.path.join(target_dir_images, os.path.basename(src))
            shutil.copy(src, dst)
            image_paths.append(dst)

    # --- Extract frames from video(s) (if any) ---
    if input_video:
        # Normalize to a list so single-video and multi-view uploads share one path.
        input_video_list = input_video if isinstance(input_video, list) else [input_video]

        # Continue frame numbering after any still images copied above.
        frame_num = len(os.listdir(target_dir_images))

        # Open all captures up front so multi-view videos are read in lockstep
        # (interleaved), preserving per-timestep view ordering on disk.
        captures = []
        capture_meta = []
        try:
            for idx, vid_obj in enumerate(input_video_list):
                video_path = gradio_file_path(vid_obj)
                print(f"Preparing video {idx+1}/{len(input_video_list)}: {video_path}")

                vs = cv2.VideoCapture(video_path)
                fps = float(vs.get(cv2.CAP_PROP_FPS) or 0.0)
                if fps <= 0:
                    # Some containers report 0 FPS; fall back to a common default.
                    fps = 30.0

                # Sample roughly VIDEO_SAMPLE_HZ frames per second of video.
                frame_interval = max(int(fps / max(VIDEO_SAMPLE_HZ, 1e-6)), 1)
                captures.append(vs)
                capture_meta.append({"interval": frame_interval, "name": video_path})

            print("Processing videos in interleaved mode...")
            step_count = 0
            active_videos = True

            while active_videos:
                active_videos = False
                for i, vs in enumerate(captures):
                    if not vs.isOpened():
                        continue

                    gotit, frame = vs.read()
                    if gotit:
                        active_videos = True
                        # Keep only every Nth step, per that video's interval.
                        if step_count % capture_meta[i]["interval"] == 0:
                            out_path = os.path.join(target_dir_images, f"{frame_num:06}.png")
                            cv2.imwrite(out_path, frame)
                            image_paths.append(out_path)
                            frame_num += 1
                    else:
                        vs.release()

                step_count += 1
        finally:
            # Fix: release any captures still open (e.g. a file that never
            # opened, or an exception mid-loop) to avoid leaking OS handles.
            for vs in captures:
                if vs.isOpened():
                    vs.release()

    # Deterministic ordering for downstream consumers.
    image_paths.sort()

    # Record the number of views (one per uploaded video) for later stages.
    # Fix: use the module-level json import instead of re-importing locally.
    num_views = len(input_video_list) if input_video else 1
    meta_path = os.path.join(target_dir, "meta.json")
    try:
        with open(meta_path, 'w') as f:
            json.dump({"num_views": num_views}, f)
    except Exception as e:
        # Metadata is best-effort; reconstruction can proceed without it.
        print(f"Warning: could not save metadata: {e}")

    print(f"Files copied to {target_dir_images}; took {time.time() - start_time:.3f} seconds")
    return target_dir, image_paths
|
|
|
|
def update_gallery_on_upload(input_video, input_images):
    """Stage freshly-uploaded media and refresh the preview gallery.

    Returns ``(plot, target_dir, image_paths, status_message)``; the plot slot
    is always ``None`` so any stale figure is cleared.
    """
    if input_video or input_images:
        target_dir, image_paths = handle_uploads(input_video, input_images)
        return (
            None,
            target_dir,
            image_paths,
            "Upload complete. Click 'Reconstruct' to begin 3D processing.",
        )
    # Nothing uploaded yet: leave every output empty.
    return None, None, None, None
|
|
|
|
@spaces.GPU(duration=240)
def gradio_reconstruct(
    target_dir,
    conf_thres=50.0,
    mask_black_bg=False,
    mask_white_bg=False,
    static_threshold=0.01,
    use_temporal_tracking=True,
    frame_id_val=0,
):
    """Run the model on the staged images and build the 3D point-cloud figure.

    Parameters
    ----------
    target_dir : str or None
        Reconstruction directory produced by ``handle_uploads`` (the Gradio
        state uses the literal string "None" when empty).
    conf_thres : float
        Confidence threshold (%) for the visualization.
    mask_black_bg, mask_white_bg : bool
        Background-filtering toggles.
    static_threshold : float
        Motion threshold used for track classification in the figure.
    use_temporal_tracking : bool
        Temporal-tracking vs. per-frame reconstruction visualization mode.
    frame_id_val : int
        Reference frame index passed to the model.

    Returns
    -------
    tuple
        ``(fig, status_message, predictions)``; ``(None, msg, None)`` when no
        valid upload directory exists.
    """
    # Fix: test for a falsy value (actual None / empty string) BEFORE calling
    # os.path.isdir, which raises TypeError on None. The sentinel string
    # "None" used by the hidden textbox state is still handled.
    if not target_dir or target_dir == "None" or not os.path.isdir(target_dir):
        return None, "No valid target directory found. Please upload first.", None

    # Free as much host/GPU memory as possible before inference.
    gc.collect()
    torch.cuda.empty_cache()

    target_dir_images = os.path.join(target_dir, "images")
    num_frames = len(os.listdir(target_dir_images)) if os.path.isdir(target_dir_images) else 0

    # Inference only; gradients are never needed here.
    with torch.no_grad():
        predictions = run_model(target_dir, model, frame_id_val, use_temporal_tracking)

    fig = build_pointcloud_figure_update(predictions, conf_thres, mask_black_bg, mask_white_bg, static_threshold)

    torch.cuda.empty_cache()
    msg = f"Reconstruction Success ({num_frames} frames processed, showing frame {frame_id_val})."
    return fig, msg, predictions
|
|
|
|
def update_download_button(predictions):
    """Reveal the download button only when a results archive exists on disk."""
    if predictions is not None:
        path = predictions.get("download_path")
        if path and os.path.exists(path):
            return gr.update(value=path, visible=True)
    # No predictions yet, or the zip has gone missing: keep the button hidden.
    return gr.update(visible=False)
|
|
|
|
def update_plot(
    target_dir,
    predictions,
    conf_thres,
    mask_black_bg,
    mask_white_bg,
    static_threshold,
    is_example,
):
    """Rebuild the point-cloud figure when a visualization control changes.

    Uses the cached ``predictions`` dict, so the model is NOT re-run; example
    rows and missing predictions require a full Reconstruct first.
    """
    if predictions is None or is_example == "True":
        msg = "No reconstruction available. Please click the Reconstruct button first."
        return None, msg, gr.update(visible=False)

    fig = build_pointcloud_figure_update(predictions, conf_thres, mask_black_bg, mask_white_bg, static_threshold)
    msg = "Updated visualization with new settings. Use the slider below the plot to scrub frames."
    return fig, msg, gr.update(visible=False)
|
|
|
|
def clear_plot():
    """Return ``None`` to blank out the 3D plot component."""
    return None
|
|
|
|
def clear_fields():
    """Return reset values for the reconstruction-related UI components."""
    empty_gallery = []
    return None, None, empty_gallery, 0, 0, "", None
|
|
|
|
def update_log():
    """Status banner shown while reconstruction is in progress."""
    return "Loading and Reconstructing..."
|
|
|
|
def example_pipeline(
    input_video_ex,
    num_images_str,
    input_images_ex,
    conf_thres_val,
    mask_black_bg_val,
    mask_white_bg_val,
    static_threshold_val,
    use_temporal_tracking_val,
    is_example_str,
    frame_id_val,
):
    """End-to-end pipeline for example rows: stage the media, then reconstruct.

    Returns ``(fig, log_msg, target_dir, predictions, image_paths)``.
    """
    target_dir, image_paths = handle_uploads(input_video_ex, input_images_ex)
    recon_args = (
        conf_thres_val,
        mask_black_bg_val,
        mask_white_bg_val,
        static_threshold_val,
        use_temporal_tracking_val,
        frame_id_val,
    )
    fig, log_msg, predictions = gradio_reconstruct(target_dir, *recon_args)
    return fig, log_msg, target_dir, predictions, image_paths
|
|
|
|
# Bundled example clips (paths are relative to the app's working directory).
_EXAMPLE_VIDEO_DIR = "examples/videos"

colosseum_video = f"{_EXAMPLE_VIDEO_DIR}/Colosseum.mp4"
camel_video = f"{_EXAMPLE_VIDEO_DIR}/camel.mp4"
tennis_video = f"{_EXAMPLE_VIDEO_DIR}/tennis.mp4"
paragliding_video = f"{_EXAMPLE_VIDEO_DIR}/paragliding.mp4"
stroller_video = f"{_EXAMPLE_VIDEO_DIR}/stroller.mp4"
goldfish_video = f"{_EXAMPLE_VIDEO_DIR}/goldfish.mp4"
horse_video = f"{_EXAMPLE_VIDEO_DIR}/horse.mp4"
swing_video = f"{_EXAMPLE_VIDEO_DIR}/swing.mp4"
car_video = f"{_EXAMPLE_VIDEO_DIR}/car.mp4"
figure1_video = f"{_EXAMPLE_VIDEO_DIR}/figure1.mp4"
figure2_video = f"{_EXAMPLE_VIDEO_DIR}/figure2.mp4"
figure3_video = f"{_EXAMPLE_VIDEO_DIR}/figure3.mp4"
tesla_video = f"{_EXAMPLE_VIDEO_DIR}/tesla.mp4"
pstudio_video = f"{_EXAMPLE_VIDEO_DIR}/pstudio.mp4"
|
|
# Slate/zinc Gradio theme; selected checkboxes reuse the primary button colors.
_base_theme = gr.themes.Default(
    primary_hue=gr.themes.colors.slate,
    secondary_hue=gr.themes.colors.zinc,
    neutral_hue=gr.themes.colors.slate,
)
theme = _base_theme.set(
    checkbox_label_background_fill_selected="*button_primary_background_fill",
    checkbox_label_text_color_selected="*button_primary_text_color",
    body_background_fill="#FFFFFF",
)
|
|
| css = """ |
| .custom-log * { |
| font-style: italic; |
| font-size: 22px !important; |
| background-image: linear-gradient(120deg, #1f2937 0%, #4b5563 100%); |
| -webkit-background-clip: text; |
| background-clip: text; |
| font-weight: bold !important; |
| color: transparent !important; |
| text-align: center !important; |
| } |
| |
| .example-log * { |
| font-style: italic; |
| font-size: 16px !important; |
| background-image: linear-gradient(120deg, #1f2937 0%, #4b5563 100%); |
| -webkit-background-clip: text; |
| background-clip: text; |
| color: transparent !important; |
| } |
| |
| #my_radio .wrap { |
| display: flex; |
| flex-wrap: nowrap; |
| justify-content: center; |
| align-items: center; |
| } |
| |
| #my_radio .wrap label { |
| display: flex; |
| width: 50%; |
| justify-content: center; |
| align-items: center; |
| margin: 0; |
| padding: 10px 0; |
| box-sizing: border-box; |
| } |
| """ |
|
|
# Load the run configuration and instantiate the model once at startup; the
# module-level `model` is reused by every gradio_reconstruct call.
cfg = load_cfg_from_cli()
model = load_model(cfg)
|
|
# ---------------------------------------------------------------------------
# Gradio UI. Components are declared inside the Blocks context; the event
# wiring (.click()/.change()) is registered at the bottom of the block.
# ---------------------------------------------------------------------------
with gr.Blocks(theme=theme, css=css) as demo:
    # Hidden textboxes used as lightweight state (also fed by gr.Examples rows).
    is_example = gr.Textbox(label="is_example", visible=False, value="None")
    num_images = gr.Textbox(label="num_images", visible=False, value="None")
    frame_id_state = gr.Textbox(label="frame_id", visible=False, value="0")

    # Page header and project links.
    gr.HTML(
        """
    <h1>V-DPM: Video Reconstruction with Dynamic Point Maps</h1>
    <p>
    <a href="https://github.com/eldar/vdpm">🐙 GitHub Repository</a> |
    <a href="https://www.robots.ox.ac.uk/~vgg/research/vdpm/">Project Page</a>
    </p>
    <div style="font-size: 16px; line-height: 1.5;">
    <p>Upload a video (or multiple videos for multi-view setup) or a set of images to create a dynamic point map reconstruction of a scene or object.</p>
    </div>
    """
    )

    # Hidden working-directory path plus the cached predictions dict returned
    # by gradio_reconstruct (reused when only visualization settings change).
    target_dir_output = gr.Textbox(label="Target Dir", visible=False, value="None")
    predictions_state = gr.State(value=None)

    with gr.Row():
        with gr.Column(scale=2):
            # Input side: uploads and a preview gallery.
            gr.Markdown("### Input")
            input_video = gr.File(
                label="Upload Video(s)",
                file_count="multiple",
                file_types=[".mp4", ".avi", ".mov", ".mkv", ".webm"],
                interactive=True
            )
            input_images = gr.File(file_count="multiple", label="Upload Images", interactive=True)
            image_gallery = gr.Gallery(
                label="Preview",
                columns=4,
                height="300px",
                show_download_button=True,
                object_fit="contain",
                preview=True,
            )

        with gr.Column(scale=5):
            # Output side: status log, download button, 3D plot and controls.
            gr.Markdown("**3D Reconstruction (Point Cloud)**")
            log_output = gr.Markdown(
                "Please upload a video or images, then click Reconstruct.", elem_classes=["custom-log"]
            )
            download_btn = gr.DownloadButton("Download All Outputs (.zip)", visible=False)

            reconstruction_output = gr.Plot(label="3D Point Cloud")

            with gr.Row():
                submit_btn = gr.Button("Reconstruct", scale=1, variant="primary")
                gr.ClearButton(
                    [input_video, input_images, reconstruction_output, log_output, target_dir_output, image_gallery, download_btn],
                    scale=1,
                )

            # Visualization-only controls (re-rendered without re-running the model).
            with gr.Row():
                conf_thres = gr.Slider(0, 100, value=50, step=1, label="Confidence Threshold (%)")
                static_threshold = gr.Slider(0.001, 0.1, value=0.01, step=0.001, label="Track Motion Threshold")
            with gr.Row():
                with gr.Column():
                    mask_black_bg = gr.Checkbox(label="Filter Black Background", value=False)
                    mask_white_bg = gr.Checkbox(label="Filter White Background", value=False)
                with gr.Column():
                    use_temporal_tracking = gr.Checkbox(
                        label="Temporal Tracking Mode",
                        value=True,
                        info="Track reference frame pixels over time (vs. per-frame reconstruction)"
                    )

    # Example rows; column order must match the `inputs=` list passed to
    # gr.Examples below.
    examples = [
        [camel_video, "17", None, 15.0, False, False, 0.01, True, "True", "8"],
        [horse_video, "18", None, 50.0, False, False, 0.01, True, "True", "2"],
        [tennis_video, "11", None, 5.0, False, False, 0.01, True, "True", "0"],
        [paragliding_video, "11", None, 5.0, False, False, 0.01, True, "True", "0"],
        [stroller_video, "17", None, 10.0, False, False, 0.01, True, "True", "8"],
        [goldfish_video, "11", None, 12.0, False, False, 0.01, True, "True", "5"],
        [swing_video, "10", None, 40.0, False, False, 0.01, True, "True", "4"],
        [car_video, "13", None, 15.0, False, False, 0.01, True, "True", "7"],
        [figure1_video, "10", None, 25.0, False, False, 0.01, True, "True", "0"],
        [figure2_video, "12", None, 25.0, False, False, 0.01, True, "True", "6"],
        [figure3_video, "13", None, 30.0, False, False, 0.01, True, "True", "0"],
        [tesla_video, "18", None, 20.0, False, True, 0.01, True, "True", "0"],
        [pstudio_video, "12", None, 0.0, False, False, 0.01, True, "True", "6"],
    ]

    gr.Markdown("Click any row to load an example.", elem_classes=["example-log"])

    def example_pipeline_wrapper(*args):
        # Run the full example pipeline, then derive the download-button
        # update from the predictions dict (4th returned value).
        res = example_pipeline(*args)
        btn_update = update_download_button(res[3])
        return list(res) + [btn_update]

    # Register the Examples table only when every bundled clip exists on disk;
    # a missing asset must never break app startup.
    try:
        examples_exist = True
        for ex in (camel_video, horse_video, tennis_video, paragliding_video,
                   stroller_video, goldfish_video, swing_video, car_video,
                   figure1_video, figure2_video, figure3_video, tesla_video, pstudio_video):
            if not os.path.exists(ex):
                examples_exist = False
                break

        if examples_exist:
            gr.Examples(
                examples=examples,
                inputs=[
                    input_video,
                    num_images,
                    input_images,
                    conf_thres,
                    mask_black_bg,
                    mask_white_bg,
                    static_threshold,
                    use_temporal_tracking,
                    is_example,
                    frame_id_state,
                ],
                outputs=[
                    reconstruction_output,
                    log_output,
                    target_dir_output,
                    predictions_state,
                    image_gallery,
                    download_btn,
                ],
                fn=example_pipeline_wrapper,
                cache_examples=False,
                examples_per_page=50,
            )
        else:
            print('[vdpm] Some example files missing; skipping Gradio Examples.')
    except Exception:
        print('[vdpm] Failed to build Gradio Examples; continuing without examples.')

    # Reconstruct button chain: clear the old plot, show the progress banner,
    # run the model, refresh the download button, then mark as non-example.
    submit_btn.click(fn=clear_plot, inputs=[], outputs=[reconstruction_output]).then(
        fn=update_log, inputs=[], outputs=[log_output]
    ).then(
        fn=gradio_reconstruct,
        inputs=[
            target_dir_output,
            conf_thres,
            mask_black_bg,
            mask_white_bg,
            static_threshold,
            use_temporal_tracking,
            frame_id_state,
        ],
        outputs=[reconstruction_output, log_output, predictions_state],
    ).then(
        fn=update_download_button, inputs=[predictions_state], outputs=[download_btn]
    ).then(
        fn=lambda: "False", inputs=[], outputs=[is_example]
    )

    # Visualization controls re-render the figure from cached predictions
    # without re-running the model.
    for ctrl in (conf_thres, mask_black_bg, mask_white_bg, static_threshold):
        ctrl.change(
            fn=update_plot,
            inputs=[
                target_dir_output,
                predictions_state,
                conf_thres,
                mask_black_bg,
                mask_white_bg,
                static_threshold,
                is_example,
            ],
            outputs=[reconstruction_output, log_output, download_btn],
        )

    # Switching tracking mode invalidates the current plot; a re-run is needed.
    use_temporal_tracking.change(
        fn=clear_plot, inputs=[], outputs=[reconstruction_output]
    ).then(
        fn=lambda: "Switching visualization mode. Click 'Reconstruct' to update.",
        inputs=[],
        outputs=[log_output]
    )

    # Stage new uploads (copy images / extract frames) as soon as either
    # input component changes.
    input_video.change(
        fn=update_gallery_on_upload,
        inputs=[input_video, input_images],
        outputs=[reconstruction_output, target_dir_output, image_gallery, log_output],
    )
    input_images.change(
        fn=update_gallery_on_upload,
        inputs=[input_video, input_images],
        outputs=[reconstruction_output, target_dir_output, image_gallery, log_output],
    )

# Allow the frontend to serve files from the output directory (the zip
# archives offered by the download button live there).
output_dir = str(Path(__file__).parent.parent / "output")
demo.queue(max_size=20).launch(show_error=True, share=True, allowed_paths=[output_dir])
|
|