qian43 commited on
Commit
23edff6
·
verified ·
1 Parent(s): 924a6c1

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -21
app.py CHANGED
@@ -6,6 +6,7 @@ Two-step interactive demo:
6
  """
7
 
8
  import csv
 
9
  import os
10
  import shutil
11
  import subprocess
@@ -13,6 +14,13 @@ import tempfile
13
  from pathlib import Path
14
  from typing import List, Optional, Tuple
15
 
 
 
 
 
 
 
 
16
  import cv2
17
  import gradio as gr
18
  import numpy as np
@@ -76,12 +84,12 @@ def load_model(checkpoint_path: str = "checkpoints"):
76
 
77
  if model_path is None:
78
  model_path = HUGGINGFACE_REPO
79
- print(f"Local checkpoint not found at '{checkpoint_path}', loading from HuggingFace: {HUGGINGFACE_REPO}")
80
 
81
  # Skip redundant backbone weight download – from_pretrained will
82
  # overwrite all parameters from the safetensors file anyway.
83
  Sat3DGen._skip_backbone_weights = True
84
- print(f"Loading model from {model_path} ...")
85
  MODEL = Sat3DGen.from_pretrained(model_path).to(DEVICE)
86
  Sat3DGen._skip_backbone_weights = False
87
  MODEL.eval()
@@ -91,7 +99,7 @@ def load_model(checkpoint_path: str = "checkpoints"):
91
  T.ToTensor(),
92
  T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
93
  ])
94
- print("Model loaded successfully.")
95
 
96
 
97
  # ---------------------------------------------------------------------------
@@ -281,38 +289,38 @@ def generate_mesh(sat_image_pil: Image.Image, mesh_resolution: int = 256, progre
281
  if sat_image_pil is None:
282
  raise gr.Error("Please upload a satellite image first.")
283
 
284
- print("[generate_mesh] >>> Start")
285
  load_model()
286
- print("[generate_mesh] Model loaded")
287
 
288
  progress(0.1, desc="Preprocessing satellite image...")
289
- print("[generate_mesh] Preprocessing satellite image...")
290
  sat_input = SAT_TRANSFORM(sat_image_pil.convert("RGB")).unsqueeze(0).to(DEVICE)
291
 
292
  progress(0.3, desc="Generating triplane features...")
293
- print("[generate_mesh] Generating triplane features...")
294
  with torch.no_grad():
295
  triplane = MODEL.from_sat_to_triplane(sat_input)
296
- print("[generate_mesh] Triplane generated successfully")
297
 
298
  progress(0.5, desc="Extracting 3D mesh (this may take a moment)...")
299
- print(f"[generate_mesh] Extracting 3D mesh (resolution={mesh_resolution})...")
300
  with torch.no_grad():
301
  vertices, faces, vertex_colors = MODEL.extract_mesh(triplane, mesh_resolution=mesh_resolution)
302
- print(f"[generate_mesh] Mesh extracted: {vertices.shape[0]} vertices, {faces.shape[0]} faces")
303
 
304
  vertices = vertices[:, [1, 2, 0]]
305
 
306
  # Save mesh
307
  mesh_path = str(RESULTS_DIR / "mesh.obj")
308
  save_obj(vertices, faces, vertex_colors, mesh_path)
309
- print(f"[generate_mesh] OBJ saved to {mesh_path}")
310
 
311
  # Also save triplane to state for Step 2
312
  state = {"triplane": triplane, "sat_image": sat_image_pil}
313
 
314
  progress(0.9, desc="Preparing 3D visualization...")
315
- print("[generate_mesh] Converting OBJ → GLB for 3D preview...")
316
 
317
  # Create a glb file for Gradio's Model3D component.
318
  # Use a tempfile so Gradio can reliably serve it via its file cache.
@@ -329,17 +337,17 @@ def generate_mesh(sat_image_pil: Image.Image, mesh_resolution: int = 256, progre
329
  raise gr.Error("Failed to load mesh geometry.")
330
  if not hasattr(mesh_trimesh, 'vertex_normals') or mesh_trimesh.vertex_normals is None or len(mesh_trimesh.vertex_normals) == 0:
331
  mesh_trimesh.vertex_normals # triggers auto-computation
332
- print(f"[generate_mesh] Mesh has {len(mesh_trimesh.vertices)} verts, {len(mesh_trimesh.faces)} faces, normals: {mesh_trimesh.vertex_normals.shape}")
333
  mesh_trimesh.export(glb_path_local, file_type="glb")
334
- print(f"[generate_mesh] GLB saved to {glb_path_local} ({os.path.getsize(glb_path_local)} bytes)")
335
 
336
  tmp_glb = tempfile.NamedTemporaryFile(suffix=".glb", delete=False)
337
  shutil.copy2(glb_path_local, tmp_glb.name)
338
  tmp_glb.close()
339
- print(f"[generate_mesh] GLB copied to temp file: {tmp_glb.name}")
340
 
341
  progress(1.0, desc="Done!")
342
- print("[generate_mesh] <<< 3D mesh generated successfully!")
343
  return tmp_glb.name, mesh_path, state
344
 
345
 
@@ -366,14 +374,14 @@ def render_trajectory_video(
366
  Top row: satellite image (with camera marker) | panorama RGB
367
  Bottom row: 4 perspective views in a horizontal row (left, front, right, back)
368
  """
369
- print("[render_trajectory_video] >>> Start")
370
  load_model()
371
 
372
  sat_size = sat_image_pil.size[0]
373
  positions, pixel_coords = read_trajectory_from_csv(trajectory_csv_path, sat_size)
374
  if len(positions) == 0:
375
  raise gr.Error(f"Trajectory file is empty: {trajectory_csv_path}")
376
- print(f"[render_trajectory_video] Loaded {len(positions)} positions from {trajectory_csv_path}")
377
 
378
  progress(0.1, desc="Extracting triplane features...")
379
  sat_tensor = SAT_TRANSFORM(sat_image_pil.convert("RGB")).unsqueeze(0).to(DEVICE)
@@ -399,7 +407,7 @@ def render_trajectory_video(
399
  for idx, position in enumerate(positions):
400
  progress(0.25 + 0.6 * idx / total_positions, desc=f"Rendering frame {idx + 1}/{total_positions}...")
401
  if idx % 10 == 0 or idx == total_positions - 1:
402
- print(f"[render_trajectory_video] Rendering frame {idx + 1}/{total_positions}...")
403
 
404
  c2w = position_to_c2w(position)
405
  c2w[:, :3, 3] = c2w[:, :3, 3] * MODEL.position_scale_factor
@@ -455,7 +463,7 @@ def render_trajectory_video(
455
  cv2.imwrite(str(frame_path), cv2.cvtColor(composed, cv2.COLOR_RGB2BGR))
456
 
457
  progress(0.9, desc="Encoding video...")
458
- print("[render_trajectory_video] All frames rendered, encoding video with ffmpeg...")
459
  video_path = str(RESULTS_DIR / "trajectory_video.mp4")
460
  ffmpeg_path = shutil.which("ffmpeg")
461
  if ffmpeg_path is None:
@@ -469,7 +477,7 @@ def render_trajectory_video(
469
  video_path,
470
  ], check=True, capture_output=True)
471
 
472
- print(f"[render_trajectory_video] Video saved to {video_path}")
473
  progress(1.0, desc="Done!")
474
  return video_path
475
 
@@ -533,6 +541,9 @@ def build_demo():
533
 
534
  with gr.Column(scale=2):
535
  mesh_viewer = gr.Model3D(label="3D Mesh Preview", height=500)
 
 
 
536
  download_button = gr.DownloadButton("💾 Download Mesh (.obj)", variant="secondary")
537
 
538
  if sample_sat_images:
@@ -670,6 +681,9 @@ def build_demo():
670
  )
671
  sky_status = gr.Markdown(value=default_sky_message)
672
  render_button = gr.Button("🎬 Render Video", variant="primary", size="lg")
 
 
 
673
 
674
  # Middle column: trajectory preview
675
  with gr.Column(scale=1):
 
6
  """
7
 
8
  import csv
9
+ import datetime
10
  import os
11
  import shutil
12
  import subprocess
 
14
  from pathlib import Path
15
  from typing import List, Optional, Tuple
16
 
17
+
18
+ def log(msg: str):
19
+ """Print with Beijing time (UTC+8) prefix."""
20
+ beijing_time = datetime.datetime.now(datetime.timezone(datetime.timedelta(hours=8)))
21
+ timestamp = beijing_time.strftime("%Y-%m-%d %H:%M:%S")
22
+ print(f"[{timestamp}] {msg}")
23
+
24
  import cv2
25
  import gradio as gr
26
  import numpy as np
 
84
 
85
  if model_path is None:
86
  model_path = HUGGINGFACE_REPO
87
+ log(f"Local checkpoint not found at '{checkpoint_path}', loading from HuggingFace: {HUGGINGFACE_REPO}")
88
 
89
  # Skip redundant backbone weight download – from_pretrained will
90
  # overwrite all parameters from the safetensors file anyway.
91
  Sat3DGen._skip_backbone_weights = True
92
+ log(f"Loading model from {model_path} ...")
93
  MODEL = Sat3DGen.from_pretrained(model_path).to(DEVICE)
94
  Sat3DGen._skip_backbone_weights = False
95
  MODEL.eval()
 
99
  T.ToTensor(),
100
  T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
101
  ])
102
+ log("Model loaded successfully.")
103
 
104
 
105
  # ---------------------------------------------------------------------------
 
289
  if sat_image_pil is None:
290
  raise gr.Error("Please upload a satellite image first.")
291
 
292
+ log("[generate_mesh] >>> Start")
293
  load_model()
294
+ log("[generate_mesh] Model loaded")
295
 
296
  progress(0.1, desc="Preprocessing satellite image...")
297
+ log("[generate_mesh] Preprocessing satellite image...")
298
  sat_input = SAT_TRANSFORM(sat_image_pil.convert("RGB")).unsqueeze(0).to(DEVICE)
299
 
300
  progress(0.3, desc="Generating triplane features...")
301
+ log("[generate_mesh] Generating triplane features...")
302
  with torch.no_grad():
303
  triplane = MODEL.from_sat_to_triplane(sat_input)
304
+ log("[generate_mesh] Triplane generated successfully")
305
 
306
  progress(0.5, desc="Extracting 3D mesh (this may take a moment)...")
307
+ log(f"[generate_mesh] Extracting 3D mesh (resolution={mesh_resolution})...")
308
  with torch.no_grad():
309
  vertices, faces, vertex_colors = MODEL.extract_mesh(triplane, mesh_resolution=mesh_resolution)
310
+ log(f"[generate_mesh] Mesh extracted: {vertices.shape[0]} vertices, {faces.shape[0]} faces")
311
 
312
  vertices = vertices[:, [1, 2, 0]]
313
 
314
  # Save mesh
315
  mesh_path = str(RESULTS_DIR / "mesh.obj")
316
  save_obj(vertices, faces, vertex_colors, mesh_path)
317
+ log(f"[generate_mesh] OBJ saved to {mesh_path}")
318
 
319
  # Also save triplane to state for Step 2
320
  state = {"triplane": triplane, "sat_image": sat_image_pil}
321
 
322
  progress(0.9, desc="Preparing 3D visualization...")
323
+ log("[generate_mesh] Converting OBJ → GLB for 3D preview...")
324
 
325
  # Create a glb file for Gradio's Model3D component.
326
  # Use a tempfile so Gradio can reliably serve it via its file cache.
 
337
  raise gr.Error("Failed to load mesh geometry.")
338
  if not hasattr(mesh_trimesh, 'vertex_normals') or mesh_trimesh.vertex_normals is None or len(mesh_trimesh.vertex_normals) == 0:
339
  mesh_trimesh.vertex_normals # triggers auto-computation
340
+ log(f"[generate_mesh] Mesh has {len(mesh_trimesh.vertices)} verts, {len(mesh_trimesh.faces)} faces, normals: {mesh_trimesh.vertex_normals.shape}")
341
  mesh_trimesh.export(glb_path_local, file_type="glb")
342
+ log(f"[generate_mesh] GLB saved to {glb_path_local} ({os.path.getsize(glb_path_local)} bytes)")
343
 
344
  tmp_glb = tempfile.NamedTemporaryFile(suffix=".glb", delete=False)
345
  shutil.copy2(glb_path_local, tmp_glb.name)
346
  tmp_glb.close()
347
+ log(f"[generate_mesh] GLB copied to temp file: {tmp_glb.name}")
348
 
349
  progress(1.0, desc="Done!")
350
+ log("[generate_mesh] <<< 3D mesh generated successfully!")
351
  return tmp_glb.name, mesh_path, state
352
 
353
 
 
374
  Top row: satellite image (with camera marker) | panorama RGB
375
  Bottom row: 4 perspective views in a horizontal row (left, front, right, back)
376
  """
377
+ log("[render_trajectory_video] >>> Start")
378
  load_model()
379
 
380
  sat_size = sat_image_pil.size[0]
381
  positions, pixel_coords = read_trajectory_from_csv(trajectory_csv_path, sat_size)
382
  if len(positions) == 0:
383
  raise gr.Error(f"Trajectory file is empty: {trajectory_csv_path}")
384
+ log(f"[render_trajectory_video] Loaded {len(positions)} positions from {trajectory_csv_path}")
385
 
386
  progress(0.1, desc="Extracting triplane features...")
387
  sat_tensor = SAT_TRANSFORM(sat_image_pil.convert("RGB")).unsqueeze(0).to(DEVICE)
 
407
  for idx, position in enumerate(positions):
408
  progress(0.25 + 0.6 * idx / total_positions, desc=f"Rendering frame {idx + 1}/{total_positions}...")
409
  if idx % 10 == 0 or idx == total_positions - 1:
410
+ log(f"[render_trajectory_video] Rendering frame {idx + 1}/{total_positions}...")
411
 
412
  c2w = position_to_c2w(position)
413
  c2w[:, :3, 3] = c2w[:, :3, 3] * MODEL.position_scale_factor
 
463
  cv2.imwrite(str(frame_path), cv2.cvtColor(composed, cv2.COLOR_RGB2BGR))
464
 
465
  progress(0.9, desc="Encoding video...")
466
+ log("[render_trajectory_video] All frames rendered, encoding video with ffmpeg...")
467
  video_path = str(RESULTS_DIR / "trajectory_video.mp4")
468
  ffmpeg_path = shutil.which("ffmpeg")
469
  if ffmpeg_path is None:
 
477
  video_path,
478
  ], check=True, capture_output=True)
479
 
480
+ log(f"[render_trajectory_video] Video saved to {video_path}")
481
  progress(1.0, desc="Done!")
482
  return video_path
483
 
 
541
 
542
  with gr.Column(scale=2):
543
  mesh_viewer = gr.Model3D(label="3D Mesh Preview", height=500)
544
+ gr.Markdown(
545
+ "⏳ *After generation completes, the 3D preview may take ~10 seconds to load. Please wait.*"
546
+ )
547
  download_button = gr.DownloadButton("💾 Download Mesh (.obj)", variant="secondary")
548
 
549
  if sample_sat_images:
 
681
  )
682
  sky_status = gr.Markdown(value=default_sky_message)
683
  render_button = gr.Button("🎬 Render Video", variant="primary", size="lg")
684
+ gr.Markdown(
685
+ "⏳ *Running on CPU — video rendering is slow (~5 min for 80 frames). Please be patient.*"
686
+ )
687
 
688
  # Middle column: trajectory preview
689
  with gr.Column(scale=1):