"""Sat3DGen Gradio Demo.
Two-step interactive demo:
1. Upload a satellite image -> generate and visualize a 3D mesh.
2. Select a demo image with a pre-generated trajectory -> render panorama + perspective video.
"""
import csv
import datetime
import os
import shutil
import subprocess
import tempfile
from pathlib import Path
from typing import List, Optional, Tuple

import cv2
import gradio as gr
import numpy as np
import torch
import torchvision.transforms as T
import trimesh
from PIL import Image

from source.generator import Sat3DGen
from source.rendering.transform_perspective import compose_rotmat


def log(msg: str):
    """Print with a Beijing time (UTC+8) prefix."""
    beijing_time = datetime.datetime.now(datetime.timezone(datetime.timedelta(hours=8)))
    timestamp = beijing_time.strftime("%Y-%m-%d %H:%M:%S")
    print(f"[{timestamp}] {msg}")
# ---------------------------------------------------------------------------
# Global state
# ---------------------------------------------------------------------------
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
MODEL: Optional[Sat3DGen] = None
PATCH_SIZE: int = 16
SAT_TRANSFORM = None
RESULTS_DIR = Path("./results/gradio_demo")
TRAJECTORY_PREVIEW_SIZE = 256
DEFAULT_SKY_FILENAMES = (
"default_panorama.jpg",
"default_panorama.png",
"default_panorama.jpeg",
"default_demo_panorama.jpg",
"default_demo_panorama.png",
"default_demo_panorama.jpeg",
"default_sky.jpg",
"default_sky.png",
"default_sky.jpeg",
)
RESULTS_DIR.mkdir(parents=True, exist_ok=True)
HUGGINGFACE_REPO = "qian43/Sat3DGen"
def load_model(checkpoint_path: str = "checkpoints"):
"""Load the Sat3DGen model (singleton).
Resolution order:
1. Local *checkpoint_path* directory (if it contains model files).
2. HuggingFace Hub repo ``qian43/Sat3DGen``.
When loading from a full checkpoint (local or Hub), the backbone
weights are already included in the safetensors file, so the
standalone DINOv3 download is skipped automatically.
"""
global MODEL, PATCH_SIZE, SAT_TRANSFORM
if MODEL is not None:
return
model_path: str | None = None
checkpoint_path_obj = Path(checkpoint_path)
if (checkpoint_path_obj / "config.json").exists():
model_path = str(checkpoint_path_obj)
elif (checkpoint_path_obj / "vqmodel_ema").exists():
model_path = str(checkpoint_path_obj / "vqmodel_ema")
elif (checkpoint_path_obj / "vqmodel").exists():
model_path = str(checkpoint_path_obj / "vqmodel")
if model_path is None:
model_path = HUGGINGFACE_REPO
log(f"Local checkpoint not found at '{checkpoint_path}', loading from HuggingFace: {HUGGINGFACE_REPO}")
# Skip redundant backbone weight download – from_pretrained will
# overwrite all parameters from the safetensors file anyway.
Sat3DGen._skip_backbone_weights = True
log(f"Loading model from {model_path} ...")
MODEL = Sat3DGen.from_pretrained(model_path).to(DEVICE)
Sat3DGen._skip_backbone_weights = False
MODEL.eval()
PATCH_SIZE = MODEL.unet_model.patch_size if hasattr(MODEL.unet_model, "patch_size") else 16
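    # Standard ImageNet mean/std normalization; the input is resized to a square
    # grid of 16 x 16 patches (256 px when the patch size is 16).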
SAT_TRANSFORM = T.Compose([
T.Resize((PATCH_SIZE * 16, PATCH_SIZE * 16), interpolation=Image.BICUBIC),
T.ToTensor(),
T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
log("Model loaded successfully.")
# ---------------------------------------------------------------------------
# Utility helpers (adapted from single_image_inference.py)
# ---------------------------------------------------------------------------
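# save_obj mirrors the mesh along the Z axis before export; because a reflection
# flips handedness, the face winding is reversed as well so that the triangle
# orientation (and thus the normals) stays consistent.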
def save_obj(vertices: np.ndarray, faces: np.ndarray, colors: np.ndarray, filepath: str):
vertices = vertices @ np.array([[1, 0, 0], [0, 1, 0], [0, 0, -1]])
faces = faces[:, [2, 1, 0]]
mesh = trimesh.Trimesh(
vertices=vertices,
faces=faces,
vertex_colors=colors.astype(np.uint8),
)
mesh.export(filepath)
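# position_to_c2w builds a 4x4 camera-to-world matrix: a zero-angle rotation from
# compose_rotmat(0, 0, 0) plus a translation whose axes are remapped (x negated,
# then x/y swapped) into the renderer's coordinate convention.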
def position_to_c2w(position: Tuple[float, float, float]) -> torch.Tensor:
rotation = compose_rotmat(0, 0, 0)
pos = np.array(position, dtype=np.float32)
pos[0] *= -1
pos = pos[[1, 0, 2]]
c2w = np.eye(4, dtype=np.float32)
c2w[:3, :3] = np.array(rotation, dtype=np.float32)
c2w[:3, 3] = pos
return torch.from_numpy(c2w).unsqueeze(0).to(DEVICE)
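# build_intrinsics uses the standard pinhole relation fx = 0.5 * W / tan(0.5 * fov_x)
# and cx = (W - 1) / 2 for a 256 px image with a 120 degree FOV; the final division
# by 2 appears to rescale the intrinsics to a 128 px render target (assumption).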
def build_intrinsics() -> torch.Tensor:
fovx, fovy = 120, 120
fx = 0.5 * 256 / np.tan(0.5 * fovx / 180.0 * np.pi)
fy = 0.5 * 256 / np.tan(0.5 * fovy / 180.0 * np.pi)
cx = (256 - 1) / 2.0
cy = (256 - 1) / 2.0
intrinsics = np.array([[fx / 2, 0, cx / 2], [0, fy / 2, cy / 2], [0, 0, 1]], dtype=np.float32)
return torch.from_numpy(intrinsics).unsqueeze(0).to(DEVICE)
def tensor_to_numpy_rgb(tensor: torch.Tensor) -> np.ndarray:
"""Convert a [1, C, H, W] or [C, H, W] tensor in [0, 1] to a uint8 RGB numpy array."""
img = tensor.detach().cpu().clamp(0, 1)
if img.dim() == 4:
img = img.squeeze(0)
return (img.permute(1, 2, 0).numpy() * 255).astype(np.uint8)
def get_pano_rgb(output) -> torch.Tensor:
if hasattr(output.str_output, "sr_image"):
return output.str_output.sr_image
return output.str_output.image_raw_compo
def get_per_rgb(output) -> torch.Tensor:
if hasattr(output.per_output, "sr_image"):
return output.per_output.sr_image
return output.per_output.image_raw_compo
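# Sky-mask path convention assumed by make_histo (illustrative, hypothetical file):
#     demo_images/panorama/foo.jpg  ->  demo_images/pano_sky_mask/foo.png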
def make_histo(grd_img_path: str) -> torch.Tensor:
grd_img = Image.open(grd_img_path).convert("RGB").resize((512, 128))
grd_img = T.ToTensor()(grd_img).unsqueeze(0).float().to(DEVICE)
# Derive the sky-mask path by replacing only the parent directory name,
# keeping the filename intact (just switching extension to .png).
grd_path = Path(grd_img_path)
parent_name = grd_path.parent.name
if parent_name in ("streetview", "panorama"):
mask_dir = grd_path.parent.parent / "pano_sky_mask"
mask_img_path = str(mask_dir / grd_path.with_suffix(".png").name)
else:
raise ValueError(f"Cannot infer sky-mask path from {grd_img_path}")
mask_img = Image.open(mask_img_path).convert("L").resize((512, 128), Image.NEAREST)
mask_img = T.ToTensor()(mask_img).unsqueeze(0).float().to(DEVICE)
sky_image = (grd_img * mask_img).mul(2).sub(1)
sky_image = sky_image.detach().cpu().numpy()
from source.sky_histogram import compute_sky_histogram
histo_sky = torch.from_numpy(
compute_sky_histogram(sky_image[0], hist_range=(-1, 1))
).unsqueeze(0).float().to(DEVICE)
return histo_sky
def read_trajectory_from_csv(csv_path: str, sat_image_size: int) -> Tuple[List[Tuple[float, float, float]], np.ndarray]:
"""Read a pre-generated trajectory .csv file (format: w,h,angle).
Returns:
positions: list of (x_norm, y_norm, z) in [-1, 1] range for rendering
pixel_coords: Nx2 array of pixel coordinates for visualization
"""
half = sat_image_size / 2
positions = []
pixel_coords = []
with open(csv_path, "r") as f:
reader = csv.DictReader(f)
for row in reader:
px = float(row["w"])
py = float(row["h"])
pixel_coords.append((px, py))
positions.append(((py - half) / half, (px - half) / half, -0.85))
return positions, np.array(pixel_coords, dtype=np.float32)
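# Illustrative layout of the trajectory CSV expected by read_trajectory_from_csv
# (values are hypothetical); the "angle" column is part of the format but is not
# read here:
#
#     w,h,angle
#     120.0,240.0,0.0
#     126.5,236.0,5.0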
def draw_trajectory_on_satellite(
sat_image_pil: Image.Image,
pixel_coords: np.ndarray,
active_index: Optional[int] = None,
) -> np.ndarray:
"""Draw trajectory on satellite image with glow effect (matching demo_inference style)."""
sat_frame = np.array(sat_image_pil.convert("RGB"))
if len(pixel_coords) >= 2:
# White outline pass (thicker, drawn first)
for idx in range(len(pixel_coords) - 1):
pt1 = tuple(np.round(pixel_coords[idx]).astype(int))
pt2 = tuple(np.round(pixel_coords[idx + 1]).astype(int))
cv2.line(sat_frame, pt1, pt2, (255, 255, 255), 3, cv2.LINE_AA)
# Colored line pass (thinner, on top)
for idx in range(len(pixel_coords) - 1):
pt1 = tuple(np.round(pixel_coords[idx]).astype(int))
pt2 = tuple(np.round(pixel_coords[idx + 1]).astype(int))
cv2.line(sat_frame, pt1, pt2, (255, 80, 80), 2, cv2.LINE_AA)
if active_index is not None and len(pixel_coords) > 0:
coord = pixel_coords[min(active_index, len(pixel_coords) - 1)]
px, py = int(round(coord[0])), int(round(coord[1]))
# Outer glow via alpha blending
overlay = sat_frame.copy()
cv2.circle(overlay, (px, py), 12, (0, 255, 100), -1, cv2.LINE_AA)
sat_frame = cv2.addWeighted(sat_frame, 0.7, overlay, 0.3, 0)
# Solid inner circle + white ring
cv2.circle(sat_frame, (px, py), 6, (0, 255, 100), -1, cv2.LINE_AA)
cv2.circle(sat_frame, (px, py), 7, (255, 255, 255), 2, cv2.LINE_AA)
return sat_frame
def build_trajectory_preview(sat_image_pil: Image.Image, pixel_coords: np.ndarray) -> Image.Image:
sat_frame = draw_trajectory_on_satellite(sat_image_pil, pixel_coords)
preview = cv2.resize(
sat_frame,
(TRAJECTORY_PREVIEW_SIZE, TRAJECTORY_PREVIEW_SIZE),
interpolation=cv2.INTER_LINEAR,
)
return Image.fromarray(preview)
def resolve_demo_sky_pairs(demo_dir: Path) -> Tuple[List[Tuple[Path, Path]], Optional[Path]]:
pano_dir = demo_dir / "panorama"
mask_dir = demo_dir / "pano_sky_mask"
if not pano_dir.exists() or not mask_dir.exists():
return [], None
mask_lookup = {mask_path.stem: mask_path for mask_path in sorted(mask_dir.glob("*.png"))}
sky_pairs: List[Tuple[Path, Path]] = []
for pano_path in sorted(pano_dir.glob("*")):
if pano_path.suffix.lower() not in {".jpg", ".jpeg", ".png"}:
continue
mask_path = mask_lookup.get(pano_path.stem)
if mask_path is not None:
sky_pairs.append((pano_path, mask_path))
if not sky_pairs:
return [], None
default_idx = 0
for idx, (pano_path, _) in enumerate(sky_pairs):
pano_name_lower = pano_path.name.lower()
pano_stem_lower = pano_path.stem.lower()
if pano_name_lower in DEFAULT_SKY_FILENAMES or "default" in pano_stem_lower:
default_idx = idx
break
ordered_pairs = [sky_pairs[default_idx], *sky_pairs[:default_idx], *sky_pairs[default_idx + 1 :]]
return ordered_pairs, ordered_pairs[0][0]
# ---------------------------------------------------------------------------
# Step 1: Satellite Image → 3D Mesh
# ---------------------------------------------------------------------------
def generate_mesh(sat_image_pil: Image.Image, mesh_resolution: int = 256, progress=gr.Progress()):
"""Generate a 3D mesh from a satellite image."""
if sat_image_pil is None:
raise gr.Error("Please upload a satellite image first.")
log("[generate_mesh] >>> Start")
load_model()
log("[generate_mesh] Model loaded")
progress(0.1, desc="Preprocessing satellite image...")
log("[generate_mesh] Preprocessing satellite image...")
sat_input = SAT_TRANSFORM(sat_image_pil.convert("RGB")).unsqueeze(0).to(DEVICE)
progress(0.3, desc="Generating triplane features...")
log("[generate_mesh] Generating triplane features...")
with torch.no_grad():
triplane = MODEL.from_sat_to_triplane(sat_input)
log("[generate_mesh] Triplane generated successfully")
progress(0.5, desc="Extracting 3D mesh (this may take a moment)...")
log(f"[generate_mesh] Extracting 3D mesh (resolution={mesh_resolution})...")
with torch.no_grad():
vertices, faces, vertex_colors = MODEL.extract_mesh(triplane, mesh_resolution=mesh_resolution)
log(f"[generate_mesh] Mesh extracted: {vertices.shape[0]} vertices, {faces.shape[0]} faces")
vertices = vertices[:, [1, 2, 0]]
# Save mesh
mesh_path = str(RESULTS_DIR / "mesh.obj")
save_obj(vertices, faces, vertex_colors, mesh_path)
log(f"[generate_mesh] OBJ saved to {mesh_path}")
# Also save triplane to state for Step 2
state = {"triplane": triplane, "sat_image": sat_image_pil}
progress(0.9, desc="Preparing 3D visualization...")
log("[generate_mesh] Converting OBJ → GLB for 3D preview...")
# Create a glb file for Gradio's Model3D component.
# Use a tempfile so Gradio can reliably serve it via its file cache.
glb_path_local = str(RESULTS_DIR / "mesh.glb")
mesh_trimesh = trimesh.load(mesh_path, process=False)
# Ensure we have a single Trimesh (not a Scene) with vertex normals,
# otherwise Chrome's WebGL renderer shows a blank canvas.
if isinstance(mesh_trimesh, trimesh.Scene):
geometries = list(mesh_trimesh.geometry.values())
if geometries:
mesh_trimesh = trimesh.util.concatenate(geometries)
else:
raise gr.Error("Failed to load mesh geometry.")
if not hasattr(mesh_trimesh, 'vertex_normals') or mesh_trimesh.vertex_normals is None or len(mesh_trimesh.vertex_normals) == 0:
mesh_trimesh.vertex_normals # triggers auto-computation
log(f"[generate_mesh] Mesh has {len(mesh_trimesh.vertices)} verts, {len(mesh_trimesh.faces)} faces, normals: {mesh_trimesh.vertex_normals.shape}")
mesh_trimesh.export(glb_path_local, file_type="glb")
log(f"[generate_mesh] GLB saved to {glb_path_local} ({os.path.getsize(glb_path_local)} bytes)")
tmp_glb = tempfile.NamedTemporaryFile(suffix=".glb", delete=False)
shutil.copy2(glb_path_local, tmp_glb.name)
tmp_glb.close()
log(f"[generate_mesh] GLB copied to temp file: {tmp_glb.name}")
progress(1.0, desc="Done!")
log("[generate_mesh] <<< 3D mesh generated successfully!")
return tmp_glb.name, mesh_path, state
def download_mesh(mesh_path: str):
"""Return the mesh file for download."""
if mesh_path and os.path.exists(mesh_path):
return mesh_path
return None
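# Minimal standalone usage sketch for the mesh step (hypothetical path; intended
# to be run inside the Gradio app, where progress is reported to the UI):
#
#     from PIL import Image
#     glb_path, obj_path, state = generate_mesh(Image.open("satellite.png"), mesh_resolution=256)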
# ---------------------------------------------------------------------------
# Step 2: Trajectory → Panorama + Perspective Video
# ---------------------------------------------------------------------------
def render_trajectory_video(
sat_image_pil: Image.Image,
trajectory_csv_path: str,
sky_path: str,
progress=gr.Progress(),
):
"""Render panorama and perspective views along a pre-generated trajectory.
Layout per frame:
Top row: satellite image (with camera marker) | panorama RGB
Bottom row: 4 perspective views in a horizontal row (left, front, right, back)
"""
log("[render_trajectory_video] >>> Start")
load_model()
sat_size = sat_image_pil.size[0]
positions, pixel_coords = read_trajectory_from_csv(trajectory_csv_path, sat_size)
if len(positions) == 0:
raise gr.Error(f"Trajectory file is empty: {trajectory_csv_path}")
log(f"[render_trajectory_video] Loaded {len(positions)} positions from {trajectory_csv_path}")
progress(0.1, desc="Extracting triplane features...")
sat_tensor = SAT_TRANSFORM(sat_image_pil.convert("RGB")).unsqueeze(0).to(DEVICE)
with torch.no_grad():
triplane = MODEL.from_sat_to_triplane(sat_tensor)
progress(0.2, desc="Preparing sky condition...")
sky_hist = make_histo(sky_path)
with torch.no_grad():
w_sky = MODEL.w_sky_prepare(sky_hist)
sky_feature_2d = MODEL.w_sky2sky_feature_2D(w_sky, sky_hist)
progress(0.25, desc="Rendering views along trajectory...")
intrinsics = build_intrinsics()
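    # Yaw angles in degrees: 0 = front, -90 = left, 90 = right, 180 = back,
    # matching the bottom-row layout described in the docstring.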
yaw_values = [0, -90, 90, 180]
video_dir = RESULTS_DIR / "video_frames"
if video_dir.exists():
shutil.rmtree(video_dir)
video_dir.mkdir(parents=True, exist_ok=True)
total_positions = len(positions)
for idx, position in enumerate(positions):
progress(0.25 + 0.6 * idx / total_positions, desc=f"Rendering frame {idx + 1}/{total_positions}...")
if idx % 10 == 0 or idx == total_positions - 1:
log(f"[render_trajectory_video] Rendering frame {idx + 1}/{total_positions}...")
c2w = position_to_c2w(position)
c2w[:, :3, 3] = c2w[:, :3, 3] * MODEL.position_scale_factor
with torch.no_grad():
pano_result = MODEL.from_3D_to_results(
triplane,
c2w=c2w,
w_sky=w_sky,
sky_feature_2D=sky_feature_2d,
syn_pano=True,
)
pano_rgb = tensor_to_numpy_rgb(get_pano_rgb(pano_result))
per_views = []
for yaw in yaw_values:
c2w_per = c2w.clone()
c2w_per[:, :3, :3] = torch.from_numpy(compose_rotmat(0, 0, yaw)).unsqueeze(0).to(DEVICE)
per_result = MODEL.from_3D_to_results(
triplane,
c2w=c2w_per,
w_sky=w_sky,
intrinsics=intrinsics,
sky_feature_2D=sky_feature_2d,
syn_pano=False,
syn_per=True,
)
per_rgb = tensor_to_numpy_rgb(get_per_rgb(per_result))
per_views.append(per_rgb)
# --- Satellite image with camera position marker ---
sat_frame = draw_trajectory_on_satellite(sat_image_pil, pixel_coords, active_index=idx)
# --- Compose frame ---
# Top row: satellite (square) | panorama RGB
pano_h, pano_w = pano_rgb.shape[:2]
sat_resized = cv2.resize(sat_frame, (pano_h, pano_h))
top_row = np.concatenate([sat_resized, pano_rgb], axis=1)
# Bottom row: 4 perspective views in a horizontal row (left, front, right, back)
# Flip back view for consistency
per_back = cv2.flip(per_views[3], 1)
per_row = np.concatenate([per_views[1], per_views[0], per_views[2], per_back], axis=1)
# Resize bottom row to match top row width
top_width = top_row.shape[1]
per_row_h = int(per_row.shape[0] * top_width / per_row.shape[1])
per_row_resized = cv2.resize(per_row, (top_width, per_row_h))
composed = np.concatenate([top_row, per_row_resized], axis=0)
frame_path = video_dir / f"{idx:04d}.png"
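        # cv2.imwrite expects BGR channel order, hence the RGB -> BGR conversion.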
cv2.imwrite(str(frame_path), cv2.cvtColor(composed, cv2.COLOR_RGB2BGR))
progress(0.9, desc="Encoding video...")
log("[render_trajectory_video] All frames rendered, encoding video with ffmpeg...")
video_path = str(RESULTS_DIR / "trajectory_video.mp4")
ffmpeg_path = shutil.which("ffmpeg")
if ffmpeg_path is None:
raise gr.Error("ffmpeg not found. Please install ffmpeg to generate videos.")
subprocess.run([
ffmpeg_path, "-y", "-framerate", "5",
"-i", str(video_dir / "%04d.png"),
"-vf", "scale=trunc(iw/2)*2:trunc(ih/2)*2",
"-c:v", "libx264", "-pix_fmt", "yuv420p",
video_path,
], check=True, capture_output=True)
log(f"[render_trajectory_video] Video saved to {video_path}")
progress(1.0, desc="Done!")
return video_path
# ---------------------------------------------------------------------------
# Gradio UI
# ---------------------------------------------------------------------------
def build_demo():
# Find sample images from demo directory
demo_dir = Path(__file__).resolve().parent / "demo_images"
sample_sat_images = sorted((demo_dir / "satellite").glob("*.png")) if (demo_dir / "satellite").exists() else []
sample_sat_images_with_csv = [p for p in sample_sat_images if p.with_suffix(".csv").exists()]
sample_sky_pairs, default_sky_path = resolve_demo_sky_pairs(demo_dir)
# Build thumbnail paths for faster UI loading
sat_thumb_dir = demo_dir / "satellite" / "thumbnails"
pano_thumb_dir = demo_dir / "panorama" / "thumbnails"
def get_thumbnail(original_path: Path) -> str:
"""Return thumbnail path if it exists, otherwise fall back to original."""
thumb_dir = sat_thumb_dir if "satellite" in str(original_path) else pano_thumb_dir
thumb_path = thumb_dir / (original_path.stem + ".jpg")
if thumb_path.exists():
return str(thumb_path)
return str(original_path)
with gr.Blocks(title="Sat3DGen Demo", theme=gr.themes.Soft()) as demo:
gr.Markdown(
"""
## [ICLR 2026] Sat3DGen: Comprehensive Street-Level 3D Scene Generation from Single Satellite Image
### [Code Page](https://github.com/qianmingduowan/Sat3DGen), [Project Page](https://qianmingduowan.github.io/Sat3DGen_project_page/)
Authors: [Ming Qian](https://qianmingduowan.github.io/), [Zimin Xia](https://ziminxia.github.io/), [Changkun Liu](https://lck666666.github.io), [Shuailei Ma](https://scholar.google.com/citations?user=dNhzCu4AAAAJ&hl=zh-CN), [Wen Wang](https://encounter1997.github.io/), [Zeran Ke](https://calmke.github.io/), [Bin Tan](https://icetttb.github.io/), [Hang Zhang](https://openreview.net/profile?id=~Hang_Zhang22), [Gui-Song Xia](http://www.captain-whu.com/xia_En.html)
Upload a satellite image to **generate a 3D mesh** or **render a walkthrough video**.
📌 **Input requirements:** The satellite image should be at **zoom level 20**
(same as the [VIGOR](https://github.com/Jeff-Zilence/VIGOR) dataset); it will then be resized to the model's input size.
You can download satellite tiles at this zoom level from any map tile API (e.g. Google Maps, Bing Maps, Mapbox).
"""
)
# Shared state
inference_state = gr.State(value=None)
mesh_file_path = gr.State(value=None)
# ---- 3D Mesh Generation ----
with gr.Tab("3D Mesh Generation"):
with gr.Row():
with gr.Column(scale=1):
sat_input = gr.Image(
label="Upload Satellite Image",
type="pil",
height=400,
)
mesh_resolution_slider = gr.Slider(
minimum=128, maximum=512, value=128, step=64,
label="Mesh Resolution (voxel size)",
)
generate_button = gr.Button("🚀 Generate 3D Mesh", variant="primary", size="lg")
with gr.Column(scale=2):
mesh_viewer = gr.Model3D(label="3D Mesh Preview", height=500)
gr.Markdown(
"⏳ *After generation completes, the 3D preview may take ~10-200 seconds to load. Please wait patiently.*"
)
download_button = gr.DownloadButton("💾 Download Mesh (.obj)", variant="secondary")
if sample_sat_images:
gr.Markdown("### Sample Images — click to load")
mesh_sat_gallery = gr.Gallery(
value=[get_thumbnail(p) for p in sample_sat_images],
label="Click to load a sample satellite image",
columns=10,
rows=3,
height="auto",
object_fit="cover",
allow_preview=False,
)
def load_sat_for_mesh(evt: gr.SelectData):
"""Load the full-resolution image when a thumbnail is clicked."""
if evt.index is None or evt.index >= len(sample_sat_images):
return None
return Image.open(str(sample_sat_images[evt.index]))
mesh_sat_gallery.select(
fn=load_sat_for_mesh,
inputs=None,
outputs=[sat_input],
)
gr.Markdown(
"⚠️ **Note:** The 3D mesh preview may show slight color distortion. "
"The cause is currently under investigation."
)
generate_button.click(
fn=generate_mesh,
inputs=[sat_input, mesh_resolution_slider],
outputs=[mesh_viewer, mesh_file_path, inference_state],
)
mesh_file_path.change(
fn=download_mesh,
inputs=[mesh_file_path],
outputs=[download_button],
)
# ---- Video Rendering ----
with gr.Tab("Video Rendering"):
# Hidden state to track the resolved trajectory .csv path
trajectory_csv_state = gr.State(value=None)
sky_path_state = gr.State(value=str(default_sky_path) if default_sky_path is not None else None)
def load_sat_from_gallery(evt: gr.SelectData):
"""Load selected satellite image and check for a same-name trajectory .csv."""
if evt.index is None or evt.index >= len(sample_sat_images_with_csv):
return None, None, "No image selected.", None
sat_path = sample_sat_images_with_csv[evt.index]
sat_pil = Image.open(str(sat_path))
csv_path = sat_path.with_suffix(".csv")
if csv_path.exists():
status_msg = f"✅ Trajectory found: `{csv_path.name}`"
_, pixel_coords = read_trajectory_from_csv(str(csv_path), sat_pil.size[0])
preview = build_trajectory_preview(sat_pil, pixel_coords)
return sat_pil, str(csv_path), status_msg, preview
status_msg = (
f"⚠️ No trajectory file found. "
f"Please pre-generate a trajectory and save it as "
f"`{csv_path.name}` in `{sat_path.parent}/` using:\n\n"
f"```\npython inference/make_trajectory.py "
f"--input_img_path {sat_path} --save_same_name\n```"
)
return sat_pil, None, status_msg, None
def on_sat_upload(sat_image_pil):
"""When user uploads a custom satellite image, no same-name trajectory CSV is available."""
if sat_image_pil is None:
return None, "No image uploaded.", None
return None, (
"⚠️ For uploaded images, you need a **trajectory .csv** file with the same name "
"as your satellite image (e.g. `my_image.csv` for `my_image.png`).\n\n"
"You can generate one interactively using either:\n\n"
"- **Jupyter Notebook** (recommended): `inference/make_trajectory.ipynb`\n"
"- **Command line**: "
"`python inference/make_trajectory.py --input_img_path <your_image_path> --save_same_name`\n\n"
"If you used the command line **without** `--save_same_name`, "
"the CSV is saved under `results/<image_name>/pixels.csv`. "
"You will need to **copy** it next to your satellite image with the same base name "
"(e.g. copy to `demo_images/satellite/my_image.csv`)."
), None
def load_sky_from_gallery(evt: gr.SelectData):
"""Select one demo panorama street image. The first entry is the default."""
if not sample_sky_pairs:
return None, None, "No demo panorama street image is available."
if evt.index is None or evt.index >= len(sample_sky_pairs):
sky_path = default_sky_path
else:
sky_path = sample_sky_pairs[evt.index][0]
default_suffix = " (Default)" if sky_path == default_sky_path else ""
status_msg = (
f"Selected demo panorama: `{sky_path.name}`{default_suffix}\n\n"
f"If you do not choose another one, this image will be used."
)
return str(sky_path), str(sky_path), status_msg
def render_video_from_state(sat_image, csv_path, sky_path, progress=gr.Progress()):
"""Render video using the pre-generated trajectory CSV."""
if sat_image is None:
raise gr.Error("Please select or upload a satellite image first.")
if csv_path is None or not Path(csv_path).exists():
raise gr.Error(
"No trajectory CSV found. Please pre-generate a trajectory using: "
"python inference/make_trajectory.py --input_img_path <image> --save_same_name"
)
resolved_sky_path = sky_path or (str(default_sky_path) if default_sky_path is not None else None)
if resolved_sky_path is None or not Path(resolved_sky_path).exists():
raise gr.Error("No valid demo panorama is available. Please add one under demo_images/panorama.")
return render_trajectory_video(sat_image, csv_path, resolved_sky_path, progress)
# ===== Main layout =====
with gr.Row(equal_height=False):
# Left column: satellite image selection
with gr.Column(scale=1):
sat_input_video = gr.Image(
label="Upload Satellite Image",
type="pil",
height=300,
)
trajectory_status = gr.Markdown(value="Select a demo image or upload your own.")
selected_sky_preview = gr.Image(
label="Selected Demo Panorama",
value=str(default_sky_path) if default_sky_path is not None else None,
height=180,
)
default_sky_message = "No demo panorama street image is available."
if default_sky_path is not None:
default_sky_message = (
f"Default demo panorama: `{default_sky_path.name}`\n\n"
"If you do not select another demo panorama, this one will be used."
)
sky_status = gr.Markdown(value=default_sky_message)
render_button = gr.Button("🎬 Render Video", variant="primary", size="lg")
gr.Markdown(
"⏳ *Running on CPU — video rendering is slow (~5 min for 80 frames). Please be patient.*"
)
# Middle column: trajectory preview
with gr.Column(scale=1):
trajectory_preview = gr.Image(label="Trajectory Preview", height=300)
# Right column: video output
with gr.Column(scale=2):
video_output = gr.Video(label="Rendered Video", height=500)
# ===== Sample Satellite Images Gallery (only those with a trajectory CSV) =====
if sample_sat_images_with_csv:
gr.Markdown("### 🛰️ Sample Satellite Images — click to load")
sat_gallery = gr.Gallery(
value=[get_thumbnail(p) for p in sample_sat_images_with_csv],
label="Sample Satellite Images (with trajectory)",
columns=10,
rows=1,
height="auto",
object_fit="cover",
allow_preview=False,
)
sat_gallery.select(
fn=load_sat_from_gallery,
inputs=None,
outputs=[sat_input_video, trajectory_csv_state, trajectory_status, trajectory_preview],
)
if sample_sky_pairs:
gr.Markdown(
"### 🌤️ Demo Panorama Street Images — the first one is the default\n\n"
"The panorama image and its corresponding sky mask are used to extract a "
"**sky region color histogram**, which serves as a **lighting condition hint** "
"during street-view rendering. This only affects the appearance (illumination/color tone) "
"of the rendered views — it does **not** alter the underlying 3D NeRF geometry."
)
sky_gallery = gr.Gallery(
value=[
(
get_thumbnail(pano_path),
f"{pano_path.name} (Default)" if pano_path == default_sky_path else pano_path.name,
)
for pano_path, _ in sample_sky_pairs
],
label="Demo Panorama Street Images",
columns=5,
rows=1,
height="auto",
object_fit="cover",
allow_preview=False,
)
sky_gallery.select(
fn=load_sky_from_gallery,
inputs=None,
outputs=[sky_path_state, selected_sky_preview, sky_status],
)
# When user uploads a custom image
sat_input_video.upload(
fn=on_sat_upload,
inputs=[sat_input_video],
outputs=[trajectory_csv_state, trajectory_status, trajectory_preview],
)
render_button.click(
fn=render_video_from_state,
inputs=[sat_input_video, trajectory_csv_state, sky_path_state],
outputs=[video_output],
)
return demo
if __name__ == "__main__":
demo = build_demo()
port = int(os.environ.get("GRADIO_SERVER_PORT", 7860))
demo.launch(
server_name="0.0.0.0",
server_port=port,
share=False,
allowed_paths=[
str(Path(__file__).resolve().parent / "demo_images"),
str(Path(__file__).resolve().parent / "results"),
],
)