"""Sat3DGen Gradio Demo.
Two-step interactive demo:
1. Upload a satellite image -> generate and visualize a 3D mesh.
2. Select a demo image with a pre-generated trajectory -> render panorama + perspective video.
"""
import csv
import datetime
import os
import shutil
import subprocess
import tempfile
from pathlib import Path
from typing import List, Optional, Tuple
def log(msg: str):
"""Print with Beijing time (UTC+8) prefix."""
beijing_time = datetime.datetime.now(datetime.timezone(datetime.timedelta(hours=8)))
timestamp = beijing_time.strftime("%Y-%m-%d %H:%M:%S")
print(f"[{timestamp}] {msg}")
import cv2
import gradio as gr
import numpy as np
import torch
import torchvision.transforms as T
import trimesh
from PIL import Image
from source.generator import Sat3DGen
from source.rendering.transform_perspective import compose_rotmat
# ---------------------------------------------------------------------------
# Global state
# ---------------------------------------------------------------------------
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
MODEL: Optional[Sat3DGen] = None
PATCH_SIZE: int = 16
SAT_TRANSFORM = None
RESULTS_DIR = Path("./results/gradio_demo")
TRAJECTORY_PREVIEW_SIZE = 256
DEFAULT_SKY_FILENAMES = (
"default_panorama.jpg",
"default_panorama.png",
"default_panorama.jpeg",
"default_demo_panorama.jpg",
"default_demo_panorama.png",
"default_demo_panorama.jpeg",
"default_sky.jpg",
"default_sky.png",
"default_sky.jpeg",
)
RESULTS_DIR.mkdir(parents=True, exist_ok=True)
HUGGINGFACE_REPO = "qian43/Sat3DGen"
def load_model(checkpoint_path: str = "checkpoints"):
"""Load the Sat3DGen model (singleton).
Resolution order:
1. Local *checkpoint_path* directory (if it contains model files).
2. HuggingFace Hub repo ``qian43/Sat3DGen``.
When loading from a full checkpoint (local or Hub), the backbone
weights are already included in the safetensors file, so the
standalone DINOv3 download is skipped automatically.
"""
global MODEL, PATCH_SIZE, SAT_TRANSFORM
if MODEL is not None:
return
model_path: str | None = None
checkpoint_path_obj = Path(checkpoint_path)
if (checkpoint_path_obj / "config.json").exists():
model_path = str(checkpoint_path_obj)
elif (checkpoint_path_obj / "vqmodel_ema").exists():
model_path = str(checkpoint_path_obj / "vqmodel_ema")
elif (checkpoint_path_obj / "vqmodel").exists():
model_path = str(checkpoint_path_obj / "vqmodel")
if model_path is None:
model_path = HUGGINGFACE_REPO
log(f"Local checkpoint not found at '{checkpoint_path}', loading from HuggingFace: {HUGGINGFACE_REPO}")
    # Skip the redundant backbone weight download - from_pretrained will
    # overwrite all parameters from the safetensors file anyway.
Sat3DGen._skip_backbone_weights = True
log(f"Loading model from {model_path} ...")
MODEL = Sat3DGen.from_pretrained(model_path).to(DEVICE)
Sat3DGen._skip_backbone_weights = False
MODEL.eval()
PATCH_SIZE = MODEL.unet_model.patch_size if hasattr(MODEL.unet_model, "patch_size") else 16
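    # The satellite preprocessing pipeline resizes the input to 16 x 16 patches
    # (256 x 256 for the default patch size of 16) and applies standard ImageNet
    # mean/std normalization.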
SAT_TRANSFORM = T.Compose([
T.Resize((PATCH_SIZE * 16, PATCH_SIZE * 16), interpolation=Image.BICUBIC),
T.ToTensor(),
T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
log("Model loaded successfully.")
# ---------------------------------------------------------------------------
# Utility helpers (adapted from single_image_inference.py)
# ---------------------------------------------------------------------------
def save_obj(vertices: np.ndarray, faces: np.ndarray, colors: np.ndarray, filepath: str):
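    # Mirror the Z axis before export and reverse each face's vertex order, so the
    # triangle winding (and hence the outward-facing normals) stays consistent after the flip.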
vertices = vertices @ np.array([[1, 0, 0], [0, 1, 0], [0, 0, -1]])
faces = faces[:, [2, 1, 0]]
mesh = trimesh.Trimesh(
vertices=vertices,
faces=faces,
vertex_colors=colors.astype(np.uint8),
)
mesh.export(filepath)
def position_to_c2w(position: Tuple[float, float, float]) -> torch.Tensor:
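    # Camera-to-world matrix with a fixed zero yaw/pitch/roll rotation; the normalized
    # position is remapped into the renderer's world axes by negating the first
    # component and swapping the first two components.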
rotation = compose_rotmat(0, 0, 0)
pos = np.array(position, dtype=np.float32)
pos[0] *= -1
pos = pos[[1, 0, 2]]
c2w = np.eye(4, dtype=np.float32)
c2w[:3, :3] = np.array(rotation, dtype=np.float32)
c2w[:3, 3] = pos
return torch.from_numpy(c2w).unsqueeze(0).to(DEVICE)
def build_intrinsics() -> torch.Tensor:
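    # Pinhole intrinsics for a 256 x 256 view with a 120 degree FOV: f = 0.5 * W / tan(FOV / 2).
    # The matrix is halved at the end, presumably because the raw render is produced at
    # half resolution (128 x 128) before super-resolution.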
fovx, fovy = 120, 120
fx = 0.5 * 256 / np.tan(0.5 * fovx / 180.0 * np.pi)
fy = 0.5 * 256 / np.tan(0.5 * fovy / 180.0 * np.pi)
cx = (256 - 1) / 2.0
cy = (256 - 1) / 2.0
intrinsics = np.array([[fx / 2, 0, cx / 2], [0, fy / 2, cy / 2], [0, 0, 1]], dtype=np.float32)
return torch.from_numpy(intrinsics).unsqueeze(0).to(DEVICE)
def tensor_to_numpy_rgb(tensor: torch.Tensor) -> np.ndarray:
"""Convert a [1, C, H, W] or [C, H, W] tensor in [0, 1] to a uint8 RGB numpy array."""
img = tensor.detach().cpu().clamp(0, 1)
if img.dim() == 4:
img = img.squeeze(0)
return (img.permute(1, 2, 0).numpy() * 255).astype(np.uint8)
def get_pano_rgb(output) -> torch.Tensor:
if hasattr(output.str_output, "sr_image"):
return output.str_output.sr_image
return output.str_output.image_raw_compo
def get_per_rgb(output) -> torch.Tensor:
if hasattr(output.per_output, "sr_image"):
return output.per_output.sr_image
return output.per_output.image_raw_compo
def make_histo(grd_img_path: str) -> torch.Tensor:
grd_img = Image.open(grd_img_path).convert("RGB").resize((512, 128))
grd_img = T.ToTensor()(grd_img).unsqueeze(0).float().to(DEVICE)
# Derive the sky-mask path by replacing only the parent directory name,
# keeping the filename intact (just switching extension to .png).
grd_path = Path(grd_img_path)
parent_name = grd_path.parent.name
if parent_name in ("streetview", "panorama"):
mask_dir = grd_path.parent.parent / "pano_sky_mask"
mask_img_path = str(mask_dir / grd_path.with_suffix(".png").name)
else:
raise ValueError(f"Cannot infer sky-mask path from {grd_img_path}")
mask_img = Image.open(mask_img_path).convert("L").resize((512, 128), Image.NEAREST)
mask_img = T.ToTensor()(mask_img).unsqueeze(0).float().to(DEVICE)
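    # Keep only the sky pixels and rescale from [0, 1] to [-1, 1] before computing the histogram.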
sky_image = (grd_img * mask_img).mul(2).sub(1)
sky_image = sky_image.detach().cpu().numpy()
from source.sky_histogram import compute_sky_histogram
histo_sky = torch.from_numpy(
compute_sky_histogram(sky_image[0], hist_range=(-1, 1))
).unsqueeze(0).float().to(DEVICE)
return histo_sky
def read_trajectory_from_csv(csv_path: str, sat_image_size: int) -> Tuple[List[Tuple[float, float, float]], np.ndarray]:
"""Read a pre-generated trajectory .csv file (format: w,h,angle).
Returns:
positions: list of (x_norm, y_norm, z) in [-1, 1] range for rendering
pixel_coords: Nx2 array of pixel coordinates for visualization
"""
half = sat_image_size / 2
positions = []
pixel_coords = []
with open(csv_path, "r") as f:
reader = csv.DictReader(f)
for row in reader:
px = float(row["w"])
py = float(row["h"])
pixel_coords.append((px, py))
positions.append(((py - half) / half, (px - half) / half, -0.85))
return positions, np.array(pixel_coords, dtype=np.float32)
def draw_trajectory_on_satellite(
sat_image_pil: Image.Image,
pixel_coords: np.ndarray,
active_index: Optional[int] = None,
) -> np.ndarray:
"""Draw trajectory on satellite image with glow effect (matching demo_inference style)."""
sat_frame = np.array(sat_image_pil.convert("RGB"))
if len(pixel_coords) >= 2:
# White outline pass (thicker, drawn first)
for idx in range(len(pixel_coords) - 1):
pt1 = tuple(np.round(pixel_coords[idx]).astype(int))
pt2 = tuple(np.round(pixel_coords[idx + 1]).astype(int))
cv2.line(sat_frame, pt1, pt2, (255, 255, 255), 3, cv2.LINE_AA)
# Colored line pass (thinner, on top)
for idx in range(len(pixel_coords) - 1):
pt1 = tuple(np.round(pixel_coords[idx]).astype(int))
pt2 = tuple(np.round(pixel_coords[idx + 1]).astype(int))
cv2.line(sat_frame, pt1, pt2, (255, 80, 80), 2, cv2.LINE_AA)
if active_index is not None and len(pixel_coords) > 0:
coord = pixel_coords[min(active_index, len(pixel_coords) - 1)]
px, py = int(round(coord[0])), int(round(coord[1]))
# Outer glow via alpha blending
overlay = sat_frame.copy()
cv2.circle(overlay, (px, py), 12, (0, 255, 100), -1, cv2.LINE_AA)
sat_frame = cv2.addWeighted(sat_frame, 0.7, overlay, 0.3, 0)
# Solid inner circle + white ring
cv2.circle(sat_frame, (px, py), 6, (0, 255, 100), -1, cv2.LINE_AA)
cv2.circle(sat_frame, (px, py), 7, (255, 255, 255), 2, cv2.LINE_AA)
return sat_frame
def build_trajectory_preview(sat_image_pil: Image.Image, pixel_coords: np.ndarray) -> Image.Image:
sat_frame = draw_trajectory_on_satellite(sat_image_pil, pixel_coords)
preview = cv2.resize(
sat_frame,
(TRAJECTORY_PREVIEW_SIZE, TRAJECTORY_PREVIEW_SIZE),
interpolation=cv2.INTER_LINEAR,
)
return Image.fromarray(preview)
def resolve_demo_sky_pairs(demo_dir: Path) -> Tuple[List[Tuple[Path, Path]], Optional[Path]]:
pano_dir = demo_dir / "panorama"
mask_dir = demo_dir / "pano_sky_mask"
if not pano_dir.exists() or not mask_dir.exists():
return [], None
mask_lookup = {mask_path.stem: mask_path for mask_path in sorted(mask_dir.glob("*.png"))}
sky_pairs: List[Tuple[Path, Path]] = []
for pano_path in sorted(pano_dir.glob("*")):
if pano_path.suffix.lower() not in {".jpg", ".jpeg", ".png"}:
continue
mask_path = mask_lookup.get(pano_path.stem)
if mask_path is not None:
sky_pairs.append((pano_path, mask_path))
if not sky_pairs:
return [], None
default_idx = 0
for idx, (pano_path, _) in enumerate(sky_pairs):
pano_name_lower = pano_path.name.lower()
pano_stem_lower = pano_path.stem.lower()
if pano_name_lower in DEFAULT_SKY_FILENAMES or "default" in pano_stem_lower:
default_idx = idx
break
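    # Move the default panorama to the front of the list so the gallery shows it first.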
ordered_pairs = [sky_pairs[default_idx], *sky_pairs[:default_idx], *sky_pairs[default_idx + 1 :]]
return ordered_pairs, ordered_pairs[0][0]
# ---------------------------------------------------------------------------
# Step 1: Satellite Image -> 3D Mesh
# ---------------------------------------------------------------------------
def generate_mesh(sat_image_pil: Image.Image, mesh_resolution: int = 256, progress=gr.Progress()):
"""Generate a 3D mesh from a satellite image."""
if sat_image_pil is None:
raise gr.Error("Please upload a satellite image first.")
log("[generate_mesh] >>> Start")
load_model()
log("[generate_mesh] Model loaded")
progress(0.1, desc="Preprocessing satellite image...")
log("[generate_mesh] Preprocessing satellite image...")
sat_input = SAT_TRANSFORM(sat_image_pil.convert("RGB")).unsqueeze(0).to(DEVICE)
progress(0.3, desc="Generating triplane features...")
log("[generate_mesh] Generating triplane features...")
with torch.no_grad():
triplane = MODEL.from_sat_to_triplane(sat_input)
log("[generate_mesh] Triplane generated successfully")
progress(0.5, desc="Extracting 3D mesh (this may take a moment)...")
log(f"[generate_mesh] Extracting 3D mesh (resolution={mesh_resolution})...")
with torch.no_grad():
vertices, faces, vertex_colors = MODEL.extract_mesh(triplane, mesh_resolution=mesh_resolution)
log(f"[generate_mesh] Mesh extracted: {vertices.shape[0]} vertices, {faces.shape[0]} faces")
vertices = vertices[:, [1, 2, 0]]
# Save mesh
mesh_path = str(RESULTS_DIR / "mesh.obj")
save_obj(vertices, faces, vertex_colors, mesh_path)
log(f"[generate_mesh] OBJ saved to {mesh_path}")
# Also save triplane to state for Step 2
state = {"triplane": triplane, "sat_image": sat_image_pil}
progress(0.9, desc="Preparing 3D visualization...")
log("[generate_mesh] Converting OBJ β GLB for 3D preview...")
# Create a glb file for Gradio's Model3D component.
# Use a tempfile so Gradio can reliably serve it via its file cache.
import tempfile, shutil
glb_path_local = str(RESULTS_DIR / "mesh.glb")
mesh_trimesh = trimesh.load(mesh_path, process=False)
# Ensure we have a single Trimesh (not a Scene) with vertex normals,
# otherwise Chrome's WebGL renderer shows a blank canvas.
if isinstance(mesh_trimesh, trimesh.Scene):
geometries = list(mesh_trimesh.geometry.values())
if geometries:
mesh_trimesh = trimesh.util.concatenate(geometries)
else:
raise gr.Error("Failed to load mesh geometry.")
if not hasattr(mesh_trimesh, 'vertex_normals') or mesh_trimesh.vertex_normals is None or len(mesh_trimesh.vertex_normals) == 0:
mesh_trimesh.vertex_normals # triggers auto-computation
log(f"[generate_mesh] Mesh has {len(mesh_trimesh.vertices)} verts, {len(mesh_trimesh.faces)} faces, normals: {mesh_trimesh.vertex_normals.shape}")
mesh_trimesh.export(glb_path_local, file_type="glb")
log(f"[generate_mesh] GLB saved to {glb_path_local} ({os.path.getsize(glb_path_local)} bytes)")
tmp_glb = tempfile.NamedTemporaryFile(suffix=".glb", delete=False)
shutil.copy2(glb_path_local, tmp_glb.name)
tmp_glb.close()
log(f"[generate_mesh] GLB copied to temp file: {tmp_glb.name}")
progress(1.0, desc="Done!")
log("[generate_mesh] <<< 3D mesh generated successfully!")
return tmp_glb.name, mesh_path, state
def download_mesh(mesh_path: str):
"""Return the mesh file for download."""
if mesh_path and os.path.exists(mesh_path):
return mesh_path
return None
# ---------------------------------------------------------------------------
# Step 2: Trajectory -> Panorama + Perspective Video
# ---------------------------------------------------------------------------
def render_trajectory_video(
sat_image_pil: Image.Image,
trajectory_csv_path: str,
sky_path: str,
progress=gr.Progress(),
):
"""Render panorama and perspective views along a pre-generated trajectory.
Layout per frame:
Top row: satellite image (with camera marker) | panorama RGB
Bottom row: 4 perspective views in a horizontal row (left, front, right, back)
"""
log("[render_trajectory_video] >>> Start")
load_model()
sat_size = sat_image_pil.size[0]
positions, pixel_coords = read_trajectory_from_csv(trajectory_csv_path, sat_size)
if len(positions) == 0:
raise gr.Error(f"Trajectory file is empty: {trajectory_csv_path}")
log(f"[render_trajectory_video] Loaded {len(positions)} positions from {trajectory_csv_path}")
progress(0.1, desc="Extracting triplane features...")
sat_tensor = SAT_TRANSFORM(sat_image_pil.convert("RGB")).unsqueeze(0).to(DEVICE)
with torch.no_grad():
triplane = MODEL.from_sat_to_triplane(sat_tensor)
progress(0.2, desc="Preparing sky condition...")
sky_hist = make_histo(sky_path)
with torch.no_grad():
w_sky = MODEL.w_sky_prepare(sky_hist)
sky_feature_2d = MODEL.w_sky2sky_feature_2D(w_sky, sky_hist)
progress(0.25, desc="Rendering views along trajectory...")
intrinsics = build_intrinsics()
yaw_values = [0, -90, 90, 180]
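    # Yaw angles (degrees) for the four perspective views: 0 = front, -90 = left,
    # 90 = right, 180 = back (the back view is mirrored when the frame is composed below).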
video_dir = RESULTS_DIR / "video_frames"
if video_dir.exists():
shutil.rmtree(video_dir)
video_dir.mkdir(parents=True, exist_ok=True)
total_positions = len(positions)
for idx, position in enumerate(positions):
progress(0.25 + 0.6 * idx / total_positions, desc=f"Rendering frame {idx + 1}/{total_positions}...")
if idx % 10 == 0 or idx == total_positions - 1:
log(f"[render_trajectory_video] Rendering frame {idx + 1}/{total_positions}...")
c2w = position_to_c2w(position)
c2w[:, :3, 3] = c2w[:, :3, 3] * MODEL.position_scale_factor
with torch.no_grad():
pano_result = MODEL.from_3D_to_results(
triplane,
c2w=c2w,
w_sky=w_sky,
sky_feature_2D=sky_feature_2d,
syn_pano=True,
)
pano_rgb = tensor_to_numpy_rgb(get_pano_rgb(pano_result))
per_views = []
for yaw in yaw_values:
c2w_per = c2w.clone()
c2w_per[:, :3, :3] = torch.from_numpy(compose_rotmat(0, 0, yaw)).unsqueeze(0).to(DEVICE)
per_result = MODEL.from_3D_to_results(
triplane,
c2w=c2w_per,
w_sky=w_sky,
intrinsics=intrinsics,
sky_feature_2D=sky_feature_2d,
syn_pano=False,
syn_per=True,
)
per_rgb = tensor_to_numpy_rgb(get_per_rgb(per_result))
per_views.append(per_rgb)
# --- Satellite image with camera position marker ---
sat_frame = draw_trajectory_on_satellite(sat_image_pil, pixel_coords, active_index=idx)
# --- Compose frame ---
# Top row: satellite (square) | panorama RGB
pano_h, pano_w = pano_rgb.shape[:2]
sat_resized = cv2.resize(sat_frame, (pano_h, pano_h))
top_row = np.concatenate([sat_resized, pano_rgb], axis=1)
# Bottom row: 4 perspective views in a horizontal row (left, front, right, back)
# Flip back view for consistency
per_back = cv2.flip(per_views[3], 1)
per_row = np.concatenate([per_views[1], per_views[0], per_views[2], per_back], axis=1)
# Resize bottom row to match top row width
top_width = top_row.shape[1]
per_row_h = int(per_row.shape[0] * top_width / per_row.shape[1])
per_row_resized = cv2.resize(per_row, (top_width, per_row_h))
composed = np.concatenate([top_row, per_row_resized], axis=0)
frame_path = video_dir / f"{idx:04d}.png"
cv2.imwrite(str(frame_path), cv2.cvtColor(composed, cv2.COLOR_RGB2BGR))
progress(0.9, desc="Encoding video...")
log("[render_trajectory_video] All frames rendered, encoding video with ffmpeg...")
video_path = str(RESULTS_DIR / "trajectory_video.mp4")
ffmpeg_path = shutil.which("ffmpeg")
if ffmpeg_path is None:
raise gr.Error("ffmpeg not found. Please install ffmpeg to generate videos.")
subprocess.run([
ffmpeg_path, "-y", "-framerate", "5",
"-i", str(video_dir / "%04d.png"),
"-vf", "scale=trunc(iw/2)*2:trunc(ih/2)*2",
"-c:v", "libx264", "-pix_fmt", "yuv420p",
video_path,
], check=True, capture_output=True)
log(f"[render_trajectory_video] Video saved to {video_path}")
progress(1.0, desc="Done!")
return video_path
# ---------------------------------------------------------------------------
# Gradio UI
# ---------------------------------------------------------------------------
def build_demo():
# Find sample images from demo directory
demo_dir = Path(__file__).resolve().parent / "demo_images"
sample_sat_images = sorted((demo_dir / "satellite").glob("*.png")) if (demo_dir / "satellite").exists() else []
sample_sat_images_with_csv = [p for p in sample_sat_images if p.with_suffix(".csv").exists()]
sample_sky_pairs, default_sky_path = resolve_demo_sky_pairs(demo_dir)
# Build thumbnail paths for faster UI loading
sat_thumb_dir = demo_dir / "satellite" / "thumbnails"
pano_thumb_dir = demo_dir / "panorama" / "thumbnails"
def get_thumbnail(original_path: Path) -> str:
"""Return thumbnail path if it exists, otherwise fall back to original."""
thumb_dir = sat_thumb_dir if "satellite" in str(original_path) else pano_thumb_dir
thumb_path = thumb_dir / (original_path.stem + ".jpg")
if thumb_path.exists():
return str(thumb_path)
return str(original_path)
with gr.Blocks(title="Sat3DGen Demo", theme=gr.themes.Soft()) as demo:
gr.Markdown(
"""
## [ICLR 2026] Sat3DGen: Comprehensive Street-Level 3D Scene Generation from Single Satellite Image
### [Code Page](https://github.com/qianmingduowan/Sat3DGen), [Project Page](https://qianmingduowan.github.io/Sat3DGen_project_page/)
        Authors: [Ming Qian](https://qianmingduowan.github.io/), [Zimin Xia](https://ziminxia.github.io/), [Changkun Liu](https://lck666666.github.io), [Shuailei Ma](https://scholar.google.com/citations?user=dNhzCu4AAAAJ&hl=zh-CN), [Wen Wang](https://encounter1997.github.io/), [Zeran Ke](https://calmke.github.io/), [Bin Tan](https://icetttb.github.io/), [Hang Zhang](https://openreview.net/profile?id=~Hang_Zhang22), [Gui-Song Xia](http://www.captain-whu.com/xia_En.html)
Upload a satellite image to **generate a 3D mesh** or **render a walkthrough video**.
        📌 **Input requirements:** The satellite image should be at **zoom level 20**
        (the same as the [VIGOR](https://github.com/Jeff-Zilence/VIGOR) dataset); it will then be resized to the model's input size.
You can download satellite tiles at this zoom level from any map tile API (e.g. Google Maps, Bing Maps, Mapbox).
"""
)
# Shared state
inference_state = gr.State(value=None)
mesh_file_path = gr.State(value=None)
# ---- 3D Mesh Generation ----
with gr.Tab("3D Mesh Generation"):
with gr.Row():
with gr.Column(scale=1):
sat_input = gr.Image(
label="Upload Satellite Image",
type="pil",
height=400,
)
mesh_resolution_slider = gr.Slider(
minimum=128, maximum=512, value=128, step=64,
label="Mesh Resolution (voxel size)",
)
                    generate_button = gr.Button("🚀 Generate 3D Mesh", variant="primary", size="lg")
with gr.Column(scale=2):
mesh_viewer = gr.Model3D(label="3D Mesh Preview", height=500)
gr.Markdown(
"β³ *After generation completes, the 3D preview may take ~10-200 seconds to load. Please wait patiently.*"
)
                    download_button = gr.DownloadButton("💾 Download Mesh (.obj)", variant="secondary")
if sample_sat_images:
gr.Markdown("### Sample Images β click to load")
mesh_sat_gallery = gr.Gallery(
value=[get_thumbnail(p) for p in sample_sat_images],
label="Click to load a sample satellite image",
columns=10,
rows=3,
height="auto",
object_fit="cover",
allow_preview=False,
)
def load_sat_for_mesh(evt: gr.SelectData):
"""Load the full-resolution image when a thumbnail is clicked."""
if evt.index is None or evt.index >= len(sample_sat_images):
return None
return Image.open(str(sample_sat_images[evt.index]))
mesh_sat_gallery.select(
fn=load_sat_for_mesh,
inputs=None,
outputs=[sat_input],
)
gr.Markdown(
"β οΈ **Note:** The 3D mesh preview may show slight color distortion. "
"The cause is currently under investigation."
)
generate_button.click(
fn=generate_mesh,
inputs=[sat_input, mesh_resolution_slider],
outputs=[mesh_viewer, mesh_file_path, inference_state],
)
mesh_file_path.change(
fn=download_mesh,
inputs=[mesh_file_path],
outputs=[download_button],
)
# ---- Video Rendering ----
with gr.Tab("Video Rendering"):
# Hidden state to track the resolved trajectory .csv path
trajectory_csv_state = gr.State(value=None)
sky_path_state = gr.State(value=str(default_sky_path) if default_sky_path is not None else None)
def load_sat_from_gallery(evt: gr.SelectData):
"""Load selected satellite image and check for a same-name trajectory .csv."""
if evt.index is None or evt.index >= len(sample_sat_images_with_csv):
return None, None, "No image selected.", None
sat_path = sample_sat_images_with_csv[evt.index]
sat_pil = Image.open(str(sat_path))
csv_path = sat_path.with_suffix(".csv")
if csv_path.exists():
status_msg = f"β
Trajectory found: `{csv_path.name}`"
_, pixel_coords = read_trajectory_from_csv(str(csv_path), sat_pil.size[0])
preview = build_trajectory_preview(sat_pil, pixel_coords)
return sat_pil, str(csv_path), status_msg, preview
status_msg = (
f"β οΈ No trajectory file found. "
f"Please pre-generate a trajectory and save it as "
f"`{csv_path.name}` in `{sat_path.parent}/` using:\n\n"
f"```\npython inference/make_trajectory.py "
f"--input_img_path {sat_path} --save_same_name\n```"
)
return sat_pil, None, status_msg, None
def on_sat_upload(sat_image_pil):
"""When user uploads a custom satellite image, no same-name trajectory CSV is available."""
if sat_image_pil is None:
return None, "No image uploaded.", None
return None, (
"β οΈ For uploaded images, you need a **trajectory .csv** file with the same name "
"as your satellite image (e.g. `my_image.csv` for `my_image.png`).\n\n"
"You can generate one interactively using either:\n\n"
"- **Jupyter Notebook** (recommended): `inference/make_trajectory.ipynb`\n"
"- **Command line**: "
"`python inference/make_trajectory.py --input_img_path <your_image_path> --save_same_name`\n\n"
"If you used the command line **without** `--save_same_name`, "
"the CSV is saved under `results/<image_name>/pixels.csv`. "
"You will need to **copy** it next to your satellite image with the same base name "
"(e.g. copy to `demo_images/satellite/my_image.csv`)."
), None
def load_sky_from_gallery(evt: gr.SelectData):
"""Select one demo panorama street image. The first entry is the default."""
if not sample_sky_pairs:
return None, None, "No demo panorama street image is available."
if evt.index is None or evt.index >= len(sample_sky_pairs):
sky_path = default_sky_path
else:
sky_path = sample_sky_pairs[evt.index][0]
default_suffix = " (Default)" if sky_path == default_sky_path else ""
status_msg = (
f"Selected demo panorama: `{sky_path.name}`{default_suffix}\n\n"
f"If you do not choose another one, this image will be used."
)
return str(sky_path), str(sky_path), status_msg
def render_video_from_state(sat_image, csv_path, sky_path, progress=gr.Progress()):
"""Render video using the pre-generated trajectory CSV."""
if sat_image is None:
raise gr.Error("Please select or upload a satellite image first.")
if csv_path is None or not Path(csv_path).exists():
raise gr.Error(
"No trajectory CSV found. Please pre-generate a trajectory using: "
"python inference/make_trajectory.py --input_img_path <image> --save_same_name"
)
resolved_sky_path = sky_path or (str(default_sky_path) if default_sky_path is not None else None)
if resolved_sky_path is None or not Path(resolved_sky_path).exists():
raise gr.Error("No valid demo panorama is available. Please add one under demo_images/panorama.")
return render_trajectory_video(sat_image, csv_path, resolved_sky_path, progress)
# ===== Main layout =====
with gr.Row(equal_height=False):
# Left column: satellite image selection
with gr.Column(scale=1):
sat_input_video = gr.Image(
label="Upload Satellite Image",
type="pil",
height=300,
)
trajectory_status = gr.Markdown(value="Select a demo image or upload your own.")
selected_sky_preview = gr.Image(
label="Selected Demo Panorama",
value=str(default_sky_path) if default_sky_path is not None else None,
height=180,
)
default_sky_message = "No demo panorama street image is available."
if default_sky_path is not None:
default_sky_message = (
f"Default demo panorama: `{default_sky_path.name}`\n\n"
"If you do not select another demo panorama, this one will be used."
)
sky_status = gr.Markdown(value=default_sky_message)
                    render_button = gr.Button("🎬 Render Video", variant="primary", size="lg")
gr.Markdown(
"β³ *Running on CPU β video rendering is slow (~5 min for 80 frames). Please be patient.*"
)
# Middle column: trajectory preview
with gr.Column(scale=1):
trajectory_preview = gr.Image(label="Trajectory Preview", height=300)
# Right column: video output
with gr.Column(scale=2):
video_output = gr.Video(label="Rendered Video", height=500)
# ===== Sample Satellite Images Gallery (only those with a trajectory CSV) =====
if sample_sat_images_with_csv:
gr.Markdown("### π°οΈ Sample Satellite Images β click to load")
sat_gallery = gr.Gallery(
value=[get_thumbnail(p) for p in sample_sat_images_with_csv],
label="Sample Satellite Images (with trajectory)",
columns=10,
rows=1,
height="auto",
object_fit="cover",
allow_preview=False,
)
sat_gallery.select(
fn=load_sat_from_gallery,
inputs=None,
outputs=[sat_input_video, trajectory_csv_state, trajectory_status, trajectory_preview],
)
if sample_sky_pairs:
gr.Markdown(
"### π€οΈ Demo Panorama Street Images β the first one is the default\n\n"
"The panorama image and its corresponding sky mask are used to extract a "
"**sky region color histogram**, which serves as a **lighting condition hint** "
"during street-view rendering. This only affects the appearance (illumination/color tone) "
"of the rendered views β it does **not** alter the underlying 3D NeRF geometry."
)
sky_gallery = gr.Gallery(
value=[
(
get_thumbnail(pano_path),
f"{pano_path.name} (Default)" if pano_path == default_sky_path else pano_path.name,
)
for pano_path, _ in sample_sky_pairs
],
label="Demo Panorama Street Images",
columns=5,
rows=1,
height="auto",
object_fit="cover",
allow_preview=False,
)
sky_gallery.select(
fn=load_sky_from_gallery,
inputs=None,
outputs=[sky_path_state, selected_sky_preview, sky_status],
)
# When user uploads a custom image
sat_input_video.upload(
fn=on_sat_upload,
inputs=[sat_input_video],
outputs=[trajectory_csv_state, trajectory_status, trajectory_preview],
)
render_button.click(
fn=render_video_from_state,
inputs=[sat_input_video, trajectory_csv_state, sky_path_state],
outputs=[video_output],
)
return demo
if __name__ == "__main__":
demo = build_demo()
port = int(os.environ.get("GRADIO_SERVER_PORT", 7860))
demo.launch(
server_name="0.0.0.0",
server_port=port,
share=False,
allowed_paths=[
str(Path(__file__).resolve().parent / "demo_images"),
str(Path(__file__).resolve().parent / "results"),
],
)