| |
| """ |
| ERPT Pipeline 主入口(Forward Warp,深度估计可选) |
| |
| 使用方法: |
| # 默认:使用已有深度真值做 warp(不加载深度估计权重) |
| python run_pipeline.py --stage warp_only --data_dir /path/to/scene |
| |
| # 强制完整流程(深度估计 + warp) |
| python run_pipeline.py --stage all |
| |
| # 仅深度估计 |
| python run_pipeline.py --stage depth_only |
| """ |
|
|
| import argparse |
| import re |
| import time |
| from pathlib import Path |
| from typing import Dict, Any, List, Optional |
|
|
| import yaml |
| import numpy as np |
| import torch |
| import cv2 |
|
|
| |
| import sys |
| sys.path.insert(0, str(Path(__file__).parent)) |
|
|
| |
| from core.erp_warp import warp_erp_to_target, WarpResult, create_comparison_image |
| from utils.io_utils import load_image, save_image, load_json, save_json, save_depth |
| from utils.pose_utils import Pose, load_pose |
|
|
| |
# Becomes True once the depth-estimation modules have been imported.
_depth_modules_loaded = False


def _load_depth_modules():
    """Import the depth-estimation modules on first use.

    The imported functions are bound as module-level globals so the rest of
    this file can call them by name. Deferring the import avoids loading the
    large depth-model weights in ``warp_only`` mode. Calls after the first
    successful one do nothing.
    """
    global _depth_modules_loaded
    global build_icosahedron_slices, extract_all_tangents, compute_coverage_mask
    global estimate_all_tangent_depths
    global fuse_tangent_depths_to_erp, save_depth_visualization, visualize_depth

    if not _depth_modules_loaded:
        from core.tangent_extraction import (
            build_icosahedron_slices,
            compute_coverage_mask,
            extract_all_tangents,
        )
        from core.depth_estimation import estimate_all_tangent_depths
        from core.depth_fusion import (
            fuse_tangent_depths_to_erp,
            save_depth_visualization,
            visualize_depth,
        )

        _depth_modules_loaded = True
        print("[Depth] 深度估计模块已加载")
|
|
|
|
| |
| |
| |
|
|
def discover_image_files(directory: Path) -> dict:
    """Discover panorama images in ``directory`` and key them by frame index.

    A file's frame index comes from a trailing ``_<digits>`` / ``-<digits>``
    in its stem (``pano_0003.png`` -> 3), or from the whole stem when it is
    purely numeric (``12.jpg`` -> 12). Files matching neither pattern are
    ignored, unless *no* file matches. In that case every image is numbered
    0..N-1 in sorted-path order.

    Extensions (``.png``, ``.jpg``, ``.jpeg``) are matched case-insensitively
    in a single directory pass. Each file therefore appears once, even on
    case-insensitive filesystems where separate ``*.png`` / ``*.PNG`` globs
    would both return it. Sub-directories are never treated as images.

    Args:
        directory: Directory to scan (non-recursive).

    Returns:
        Dict mapping frame index -> image path. It is empty if the directory
        does not exist or holds no images. When two files resolve to the same
        index, the later one in sorted-path order wins.
    """
    image_extensions = {".png", ".jpg", ".jpeg"}
    if not directory.is_dir():
        return {}

    image_files = sorted(
        p for p in directory.iterdir()
        if p.is_file() and p.suffix.lower() in image_extensions
    )
    if not image_files:
        return {}

    frame_suffix = re.compile(r'[_-](\d+)$')
    result = {}
    for img_path in image_files:
        stem = img_path.stem
        match = frame_suffix.search(stem)
        if match:
            result[int(match.group(1))] = img_path
        elif stem.isdigit():
            result[int(stem)] = img_path

    if not result:
        # No file carries a frame number: fall back to positional indices.
        result = dict(enumerate(image_files))
    return result
|
|
|
|
def discover_pose_files(directory: Path) -> dict:
    """Discover per-frame pose JSON files in ``directory``.

    JSON files whose stem is one of the known non-pose names (``meta``,
    ``config``, ``stats``, ``cameras``, ``render_meta``, ``description``) are
    skipped. The frame index is the trailing ``_<digits>`` / ``-<digits>`` of
    the stem, or the whole stem when it is purely numeric. Files matching
    neither pattern are ignored.

    Returns:
        Dict mapping frame index -> JSON path (empty if nothing matches).
    """
    reserved_stems = ("meta", "config", "stats", "cameras", "render_meta", "description")
    frame_suffix = re.compile(r'[_-](\d+)$')

    poses = {}
    for json_path in sorted(directory.glob("*.json")):
        name = json_path.stem
        if name in reserved_stems:
            continue
        found = frame_suffix.search(name)
        if found is not None:
            poses[int(found.group(1))] = json_path
        elif name.isdigit():
            poses[int(name)] = json_path
    return poses
|
|
|
|
| |
| |
| |
|
|
def load_config(config_path: Path) -> Dict[str, Any]:
    """Load a YAML configuration file into a dict.

    ``yaml.safe_load`` is used so the file cannot construct arbitrary Python
    objects. An empty file, which ``safe_load`` parses as ``None``, yields an
    empty dict so callers can always use ``cfg.get(...)``.

    Raises:
        TypeError: if the top-level YAML value is not a mapping.
    """
    with open(config_path, "r", encoding="utf-8") as f:
        cfg = yaml.safe_load(f)
    if cfg is None:
        return {}
    if not isinstance(cfg, dict):
        raise TypeError(
            f"Config file {config_path} must contain a mapping at the top level, "
            f"got {type(cfg).__name__}"
        )
    return cfg
|
|
|
|
def resolve_paths(cfg: Dict[str, Any], config_dir: Path) -> Dict[str, Any]:
    """Resolve relative paths in ``cfg`` against ``config_dir``, in place.

    ``data.data_dir``, ``data.output_dir``, ``data.depth_dir`` and
    ``depth_pro.repo_dir`` are rewritten to ``config_dir / value`` when they
    are set and relative. Absolute and empty values are left untouched.
    ``config_dir`` is recorded as ``cfg["_project_root"]``.

    A missing or null ``data`` section is replaced by an empty dict stored in
    ``cfg``. Later code (e.g. command-line overrides) can then write into
    ``cfg["data"]`` without a KeyError. A null ``depth_pro`` section is
    skipped.

    Returns:
        The same ``cfg`` object, or a new dict if ``cfg`` is ``None``.
    """
    if cfg is None:
        cfg = {}

    def _absolutise(section: Dict[str, Any], key: str) -> None:
        # Only non-empty relative paths are rebased onto config_dir.
        value = section.get(key)
        if value:
            path = Path(value)
            if not path.is_absolute():
                section[key] = str(config_dir / path)

    if cfg.get("data") is None:
        cfg["data"] = {}
    data_cfg = cfg["data"]
    for key in ("data_dir", "output_dir", "depth_dir"):
        _absolutise(data_cfg, key)

    depth_pro_cfg = cfg.get("depth_pro")
    if depth_pro_cfg is not None:
        _absolutise(depth_pro_cfg, "repo_dir")

    cfg["_project_root"] = str(config_dir)
    return cfg
|
|
|
|
| |
| |
| |
|
|
def run_depth_pipeline(
    center_rgb: np.ndarray,
    cfg: Dict[str, Any],
    device: torch.device,
    output_dir: Path,
    erp_h: int,
    erp_w: int,
    frame_id: int = 0,
) -> np.ndarray:
    """Estimate an ERP depth map for one panorama: slice -> infer -> fuse.

    Steps:
      1. Build the icosahedron tangent-slice layout and its ERP coverage mask.
      2. Sample every tangent slice from ``center_rgb``.
      3. Run Depth Pro on each slice, or use a constant 5.0 placeholder depth
         when ``depth_pro.enabled`` is false.
      4. Fuse the per-slice depths back into an ``erp_h`` x ``erp_w`` map.

    Side effects:
      * ``cfg["erp"]["height"]`` / ``["width"]`` are overwritten with
        ``erp_h`` / ``erp_w`` so the slicing and fusion helpers see the size
        actually being processed.
      * Writes ``debug/coverage_mask.png``. Under ``depth_erp/`` it writes
        the depth ``.npy``, a visualisation PNG and a 0/255 valid-mask PNG.
      * With ``run.save_intermediates``, also writes per-slice RGB/depth
        under ``tangents/`` and passes ``debug/`` to the fusion step.

    Returns:
        The fused ERP depth map as returned by ``fuse_tangent_depths_to_erp``.
    """
    _load_depth_modules()

    # Null sections in the YAML are treated like missing ones.
    run_cfg = cfg.get("run") or {}
    save_intermediates = bool(run_cfg.get("save_intermediates", False))

    depth_out_dir = output_dir / "depth_erp"
    depth_out_dir.mkdir(parents=True, exist_ok=True)

    # ---- Step 1: slice layout and coverage --------------------------------
    print(f"\n{'='*60}")
    print(f"[Step 1] Building tangent slices (frame {frame_id})")
    print(f"{'='*60}")

    # The slicing/fusion helpers read the ERP resolution from cfg.
    if cfg.get("erp") is None:
        cfg["erp"] = {}
    cfg["erp"]["height"] = erp_h
    cfg["erp"]["width"] = erp_w

    slices = build_icosahedron_slices(cfg)
    print(f" Total slices: {len(slices)}")
    for s in slices:
        if s.slice_type != "face":
            print(f" {s.slice_id}: type={s.slice_type}, fov={s.fov_deg:.1f}°")

    coverage_mask, coverage_stats = compute_coverage_mask(slices, erp_h, erp_w, device)
    print(f" Coverage: {coverage_stats['total_coverage']:.2f}%")

    dbg_dir = output_dir / "debug"
    dbg_dir.mkdir(parents=True, exist_ok=True)
    save_image(np.stack([coverage_mask] * 3, axis=-1), dbg_dir / "coverage_mask.png")

    # ---- Step 2: tangent extraction ----------------------------------------
    print(f"\n{'='*60}")
    print(f"[Step 2] Extracting tangent slices (frame {frame_id})")
    print(f"{'='*60}")

    t0 = time.time()
    tangent_rgbs = extract_all_tangents(center_rgb, slices, device)
    print(f" Extracted {len(tangent_rgbs)} slices in {time.time()-t0:.2f}s")

    tangent_dir = output_dir / "tangents"
    if save_intermediates:
        tangent_dir.mkdir(parents=True, exist_ok=True)
        for slice_id, rgb in tangent_rgbs.items():
            save_image(rgb, tangent_dir / f"{slice_id}_rgb.png")

    # ---- Step 3: per-slice depth inference ---------------------------------
    print(f"\n{'='*60}")
    print(f"[Step 3] Running Depth Pro inference (frame {frame_id})")
    print(f"{'='*60}")

    dp_cfg = cfg.get("depth_pro") or {}
    if not bool(dp_cfg.get("enabled", True)):
        print(" [Warning] Depth Pro disabled, using dummy depth")
        # Constant placeholder depth so the remaining steps can still run.
        tangent_depths = {
            sid: np.full(rgb.shape[:2], 5.0, dtype=np.float32)
            for sid, rgb in tangent_rgbs.items()
        }
    else:
        t0 = time.time()
        tangent_depths = estimate_all_tangent_depths(tangent_rgbs, slices, cfg, device)
        print(f" Estimated {len(tangent_depths)} depths in {time.time()-t0:.2f}s")

    if save_intermediates:
        for sid, depth in tangent_depths.items():
            save_depth(depth, tangent_dir / f"{sid}_depth.npy")

    # ---- Step 4: fusion back to ERP ----------------------------------------
    print(f"\n{'='*60}")
    print(f"[Step 4] Fusing tangent depths to ERP (frame {frame_id})")
    print(f"{'='*60}")

    t0 = time.time()
    depth_erp, _weight_sum, valid_mask = fuse_tangent_depths_to_erp(
        tangent_depths, slices, cfg, device,
        debug_dir=dbg_dir if save_intermediates else None,
    )
    print(f" Fused in {time.time()-t0:.2f}s")

    valid_ratio = np.sum(valid_mask > 0) / (erp_h * erp_w)
    valid_depths = depth_erp[np.isfinite(depth_erp) & (depth_erp > 0)]
    print(f" Valid depth ratio: {valid_ratio * 100:.2f}%")
    if valid_depths.size > 0:
        print(f" Depth range: [{valid_depths.min():.2f}, {valid_depths.max():.2f}] m")
    else:
        print(" [Warning] No finite, positive depth values after fusion")

    save_depth(depth_erp, depth_out_dir / f"depth_{frame_id:04d}.npy")
    save_depth_visualization(depth_erp, depth_out_dir / f"depth_{frame_id:04d}_vis.png")
    # Binarise explicitly. ``valid_mask * 255`` gives int64 for bool masks
    # (rejected by cv2.imwrite) and wraps around for uint8 masks already 0/255.
    mask_u8 = ((np.asarray(valid_mask) > 0) * 255).astype(np.uint8)
    cv2.imwrite(str(depth_out_dir / f"depth_{frame_id:04d}_valid_mask.png"), mask_u8)

    return depth_erp
|
|
|
|
| |
| |
| |
|
|
def run_warp_pipeline(
    center_rgb: np.ndarray,
    depth_erp: np.ndarray,
    center_frame: int,
    image_files: dict,
    pose_files: dict,
    cfg: Dict[str, Any],
    device: torch.device,
    output_dir: Path,
    erp_h: int,
    erp_w: int,
) -> None:
    """Forward-splat the center panorama into each selected target frame.

    Target selection: every frame with a pose file other than
    ``center_frame``, optionally restricted by ``warp.target_frames``. That
    key may be a list of frame ids or a single id; ``None`` or ``"auto"``
    means all. Requested frames that have no pose file are reported and
    skipped.

    For each target this writes, under ``output_dir``:
      * ``warp_rgb/pano<src>_to_pano<tgt>_rgb.png``: warped colour.
      * ``warp_rgb/pano<src>_to_pano<tgt>_mask.png``: 0/255 validity mask.
      * ``warp_depth/pano<src>_to_pano<tgt>_depth_range.npy``: warped depth,
        only when ``warp.output_depth`` is true (default) and the warp
        returned one.

    If the center frame has no pose file, an error is printed and nothing is
    warped. ``image_files``, ``erp_h`` and ``erp_w`` are not used here; they
    are kept for interface compatibility.
    """
    warp_cfg = cfg.get("warp") or {}
    output_depth = bool(warp_cfg.get("output_depth", True))

    available_targets = sorted(fid for fid in pose_files if fid != center_frame)

    cfg_targets = warp_cfg.get("target_frames", None)
    if cfg_targets is None or cfg_targets == "auto":
        target_frames = available_targets
        missing_targets = []
    else:
        # Accept a single id (``target_frames: 3``) as well as a list; a bare
        # string must not be iterated character by character.
        if isinstance(cfg_targets, (int, str)):
            cfg_targets = [cfg_targets]
        requested = {int(t) for t in cfg_targets}
        target_frames = [fid for fid in available_targets if fid in requested]
        # Asking for the center frame itself is a harmless no-op.
        missing_targets = sorted(requested - set(available_targets) - {center_frame})

    print(f"\n{'='*60}")
    print(f"[Warp] Forward splatting from frame {center_frame}")
    print(f"{'='*60}")
    print(f" Method: {warp_cfg.get('method', 'softmax_splatting')}")
    print(f" Available targets with pose: {available_targets}")
    print(f" Will warp: {target_frames}")
    if missing_targets:
        print(f" [Warning] Requested targets without pose, skipped: {missing_targets}")

    if center_frame not in pose_files:
        print(f" [Error] Center pose not found for frame {center_frame}")
        return
    src_pose = load_pose(pose_files[center_frame])
    print(f" Source pose: position={src_pose.position.tolist()}")

    warp_rgb_dir = output_dir / "warp_rgb"
    warp_rgb_dir.mkdir(parents=True, exist_ok=True)
    warp_depth_dir = output_dir / "warp_depth" if output_depth else None
    if warp_depth_dir is not None:
        warp_depth_dir.mkdir(parents=True, exist_ok=True)

    total_warp = len(target_frames)
    # Every id in target_frames comes from pose_files, so its pose exists.
    for idx, tgt_id in enumerate(target_frames, start=1):
        tgt_pose = load_pose(pose_files[tgt_id])
        print(f" [{idx}/{total_warp}] Frame {center_frame} -> {tgt_id} ...", end="", flush=True)

        t0 = time.time()
        result = warp_erp_to_target(
            src_rgb=center_rgb,
            src_depth=depth_erp,
            src_pose=src_pose,
            tgt_pose=tgt_pose,
            cfg=cfg,
            device=device,
        )
        dt = time.time() - t0

        # Binarise once. This keeps the percentage right whatever the mask's
        # dtype/scale, and gives cv2.imwrite a uint8 image.
        valid = np.asarray(result.valid_mask) > 0
        valid_pct = 100.0 * valid.mean() if valid.size else 0.0
        print(f" done ({dt:.2f}s, valid={valid_pct:.1f}%)")

        prefix = f"pano{center_frame:04d}_to_pano{tgt_id:04d}"
        save_image(result.warped_rgb, warp_rgb_dir / f"{prefix}_rgb.png")
        cv2.imwrite(str(warp_rgb_dir / f"{prefix}_mask.png"), (valid * 255).astype(np.uint8))

        if warp_depth_dir is not None and result.warped_depth is not None:
            save_depth(result.warped_depth, warp_depth_dir / f"{prefix}_depth_range.npy")

    print(f" Warp complete. Output saved to: {warp_rgb_dir}")
|
|
|
|
| |
| |
| |
|
|
def _load_existing_depth(candidates: List[Path]) -> Optional[np.ndarray]:
    """Return the first loadable depth map among ``candidates`` as float32.

    ``.exr`` files are read with OpenCV, keeping the first channel if the
    image has several. Any other candidate is loaded with ``np.load``.
    Missing files and EXR files OpenCV cannot decode are skipped.
    Returns ``None`` when nothing loads.
    """
    for path in candidates:
        if not path.exists():
            continue
        if path.suffix == ".exr":
            depth = cv2.imread(str(path), cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH)
            if depth is None:
                continue
            if depth.ndim == 3:
                depth = depth[:, :, 0]
        else:
            depth = np.load(str(path))
        depth = depth.astype(np.float32)
        print(f" Loaded depth from {path}")
        return depth
    return None


def main():
    """Command-line entry point for the ERPT pipeline.

    Stages (``--stage``):
      * ``warp_only`` (default): load an existing depth map for the center
        frame, then forward-warp to the other frames that have poses.
      * ``all``: estimate the center depth, then warp.
      * ``depth_only``: estimate and save the center depth only.

    Configuration comes from ``--config``, or ``config.yaml`` next to this
    script when it exists, otherwise a built-in fallback. ``--data_dir``,
    ``--output_dir`` and ``--center_frame`` override the config values.

    Raises:
        FileNotFoundError: if the data directory holds no images, or if
            ``warp_only`` cannot find a depth file for the center frame.
    """
    script_dir = Path(__file__).parent
    default_config = script_dir / "config.yaml"

    parser = argparse.ArgumentParser(description="ERPT Pipeline")
    parser.add_argument("--config", type=str,
                        default=str(default_config) if default_config.exists() else None,
                        help="Config file path")
    parser.add_argument("--data_dir", type=str, default=None,
                        help="Data directory (overrides config)")
    parser.add_argument("--output_dir", type=str, default=None,
                        help="Output directory (overrides config)")
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--stage", type=str, default="warp_only",
                        choices=["all", "depth_only", "warp_only"])
    parser.add_argument("--center_frame", type=int, default=None,
                        help="Center frame ID (overrides config)")
    args = parser.parse_args()

    # ---- Configuration ------------------------------------------------------
    if args.config:
        config_path = Path(args.config)
        # An empty YAML file parses to None; treat it as an empty config.
        cfg = load_config(config_path) or {}
        cfg = resolve_paths(cfg, config_path.parent)
    else:
        cfg = {
            "data": {},
            "erp": {"auto_size": True},
            "tangent": {},
            "depth_pro": {"enabled": True, "precision": "fp16", "pass_f_px": True},
            "fusion": {"blend_mode": "multiband", "output_scale": 1.10, "k": 4},
            "run": {"save_intermediates": False},
        }

    # A config without a ``data`` section (or with ``data:`` left empty) must
    # still accept the command-line overrides below.
    if cfg.get("data") is None:
        cfg["data"] = {}
    data_cfg = cfg["data"]
    if args.data_dir:
        data_cfg["data_dir"] = str(Path(args.data_dir).resolve())
    if args.output_dir:
        data_cfg["output_dir"] = args.output_dir

    data_dir = Path(data_cfg.get("data_dir") or "inputs")
    output_dir = Path(data_cfg.get("output_dir") or "outputs")

    device = torch.device(args.device if torch.cuda.is_available() or args.device == "cpu" else "cpu")
    print(f"Using device: {device}")

    warp_cfg = cfg.get("warp") or {}
    # ``--center_frame 0`` is a legitimate choice, so test for None rather
    # than truthiness.
    if args.center_frame is not None:
        center_frame = args.center_frame
    else:
        center_frame = int(warp_cfg.get("center_frame", 0))

    print(f"\n{'='*60}")
    print("ERPT Pipeline")
    print(f"{'='*60}")
    print(f"Stage: {args.stage}")
    print(f"Data dir: {data_dir}")
    print(f"Output dir: {output_dir}")

    t_start = time.time()

    # ---- Data discovery -----------------------------------------------------
    print(f"\n{'='*60}")
    print("[Loading data]")
    print(f"{'='*60}")

    image_files = discover_image_files(data_dir)
    pose_files = discover_pose_files(data_dir)
    print(f" Found {len(image_files)} images, {len(pose_files)} poses")

    if not image_files:
        raise FileNotFoundError(f"No image files found in: {data_dir}")

    if center_frame not in image_files:
        center_frame = sorted(image_files.keys())[0]
        print(f" Using frame {center_frame} as center")

    center_rgb = load_image(image_files[center_frame])
    print(f" Center image: {image_files[center_frame].name}")
    print(f" Shape: {center_rgb.shape}")

    erp_cfg = cfg.get("erp") or {}
    if bool(erp_cfg.get("auto_size", True)):
        erp_h, erp_w = center_rgb.shape[:2]
        print(f" Auto size: {erp_w}x{erp_h}")
    else:
        erp_h = int(erp_cfg.get("height", 2048))
        erp_w = int(erp_cfg.get("width", 4096))

    # ---- Depth: estimate or load -------------------------------------------
    depth_erp = None

    if args.stage in ("all", "depth_only"):
        if args.stage == "all":
            print("\n [Stage: all] 强制执行深度估计")
        depth_erp = run_depth_pipeline(
            center_rgb, cfg, device, output_dir, erp_h, erp_w, center_frame,
        )
    elif args.stage == "warp_only":
        # center_frame is guaranteed to be in image_files at this point.
        stem = image_files[center_frame].stem
        depth_candidates = [
            data_dir / f"{stem}_depth.npy",
            data_dir / f"{stem}_depth.exr",
            data_dir / f"{stem}.npy",
            data_dir / f"depth_{center_frame:04d}.npy",
            output_dir / "depth_erp" / f"depth_{center_frame:04d}.npy",
        ]
        depth_erp = _load_existing_depth(depth_candidates)

        if depth_erp is None:
            tried = "\n ".join(str(p) for p in depth_candidates)
            raise FileNotFoundError(
                f"[warp_only] 未找到深度文件,无法执行 warp。\n"
                f"已搜索路径:\n {tried}\n"
                f"如需深度估计请使用 --stage all"
            )

        if depth_erp.shape != (erp_h, erp_w):
            old_shape = depth_erp.shape
            depth_erp = cv2.resize(depth_erp, (erp_w, erp_h), interpolation=cv2.INTER_LINEAR)
            print(f" [Warning] Depth resized: {old_shape} -> ({erp_h}, {erp_w})")

    # ---- Warp ---------------------------------------------------------------
    warp_enabled = bool(warp_cfg.get("enabled", True))
    if args.stage in ("all", "warp_only") and warp_enabled:
        run_warp_pipeline(
            center_rgb, depth_erp, center_frame,
            image_files, pose_files,
            cfg, device, output_dir, erp_h, erp_w,
        )

    total_time = time.time() - t_start
    print(f"\n{'='*60}")
    print("Pipeline Complete")
    print(f"{'='*60}")
    print(f"Total time: {total_time:.2f}s")
    print(f"Output saved to: {output_dir}")
|
|
|
|
# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|