# cmevs-code / pipelines /run_pipeline.py
# anon-cmevs-2026's picture
# Initial code release for NeurIPS 2026 D&B reviewer reference
# 5c1bb37 verified
#!/usr/bin/env python3
"""
ERPT pipeline main entry point (forward warp; depth estimation is optional).

Usage:
    # Default: warp using an existing depth map (ground truth, or a previous
    # estimate); the depth-estimation modules/weights are never loaded.
    python run_pipeline.py --stage warp_only --data_dir /path/to/scene

    # Force the full pipeline (depth estimation + warp)
    python run_pipeline.py --stage all

    # Depth estimation only
    python run_pipeline.py --stage depth_only

Other options: --config, --output_dir, --center_frame, --device.
"""
import argparse
import re
import time
from pathlib import Path
from typing import Dict, Any, List, Optional
import yaml
import numpy as np
import torch
import cv2
# 添加模块路径
import sys
sys.path.insert(0, str(Path(__file__).parent))
# Warp 相关(始终加载)
from core.erp_warp import warp_erp_to_target, WarpResult, create_comparison_image
from utils.io_utils import load_image, save_image, load_json, save_json, save_depth
from utils.pose_utils import Pose, load_pose
# 深度估计相关(延迟加载,仅 depth_only / all 模式才 import)
# Flipped to True once the depth-estimation modules have been imported.
_depth_modules_loaded = False


def _load_depth_modules():
    """Import the depth-estimation modules on first use.

    The imported callables are bound as module-level globals so the depth
    pipeline can call them directly. Deferring these imports keeps
    warp_only runs from loading the large depth model weights. Calls after
    the first one do nothing.
    """
    global _depth_modules_loaded
    global build_icosahedron_slices, extract_all_tangents, compute_coverage_mask
    global estimate_all_tangent_depths
    global fuse_tangent_depths_to_erp, save_depth_visualization, visualize_depth

    if not _depth_modules_loaded:
        from core.tangent_extraction import (
            build_icosahedron_slices,
            compute_coverage_mask,
            extract_all_tangents,
        )
        from core.depth_estimation import estimate_all_tangent_depths
        from core.depth_fusion import (
            fuse_tangent_depths_to_erp,
            save_depth_visualization,
            visualize_depth,
        )

        _depth_modules_loaded = True
        print("[Depth] 深度估计模块已加载")
# =============================================================================
# 数据发现
# =============================================================================
def discover_image_files(directory: Path) -> dict:
    """Discover panorama images in ``directory`` and key them by frame ID.

    Files whose extension is ``.png``, ``.jpg`` or ``.jpeg`` (compared
    case-insensitively) are considered; the scan is not recursive. The frame
    ID comes from a trailing ``_<digits>`` / ``-<digits>`` in the file stem
    (``pano_0003.png`` -> 3), or from the whole stem when it is purely numeric
    (``7.jpg`` -> 7). Files matching neither pattern are ignored, unless no
    file yields an ID at all, in which case every image is numbered 0..N-1 in
    sorted path order.

    Args:
        directory: Directory to scan.

    Returns:
        ``{frame_id: image_path}``; empty if the directory does not exist or
        contains no image files. When two files map to the same ID, the one
        later in sorted order wins.
    """
    image_suffixes = {".png", ".jpg", ".jpeg"}
    # glob() on a missing path simply yields nothing; iterdir() would raise.
    if not directory.is_dir():
        return {}
    # A single directory pass with lower-cased suffixes (instead of one glob
    # per case variant) lists each file exactly once: on case-insensitive
    # filesystems "*.png" and "*.PNG" match the same file, which used to
    # create duplicates and shift the positional fallback numbering.
    image_files = sorted(
        path for path in directory.iterdir()
        if path.suffix.lower() in image_suffixes and path.is_file()
    )
    if not image_files:
        return {}

    result = {}
    for img_path in image_files:
        stem = img_path.stem
        match = re.search(r'[_-](\d+)$', stem)
        if match:
            result[int(match.group(1))] = img_path
        # isdecimal() (not isdigit()) guarantees int() can parse the stem.
        elif stem.isdecimal():
            result[int(stem)] = img_path

    if not result:
        # No parsable frame IDs anywhere: fall back to positional numbering.
        result = dict(enumerate(image_files))
    return result
def discover_pose_files(directory: Path) -> dict:
    """Discover per-frame pose JSON files in ``directory`` keyed by frame ID.

    Only ``*.json`` files directly inside ``directory`` are considered.
    Well-known non-pose files (meta, config, stats, cameras, render_meta,
    description) are skipped. The frame ID is a trailing ``_<digits>`` /
    ``-<digits>`` in the stem, or the whole stem when it is all digits;
    files without an ID are ignored.

    Returns:
        ``{frame_id: json_path}``; on duplicate IDs the file that comes later
        in sorted order wins.
    """
    reserved_stems = ('meta', 'config', 'stats', 'cameras', 'render_meta', 'description')
    trailing_id = re.compile(r'[_-](\d+)$')

    poses = {}
    for json_path in sorted(directory.glob("*.json")):
        name = json_path.stem
        if name in reserved_stems:
            continue
        found = trailing_id.search(name)
        if found is not None:
            poses[int(found.group(1))] = json_path
        elif name.isdigit():
            poses[int(name)] = json_path
    return poses
# =============================================================================
# 配置加载
# =============================================================================
def load_config(config_path: Path) -> Dict[str, Any]:
    """Load a YAML pipeline config file.

    ``yaml.safe_load`` is used, so the file cannot instantiate arbitrary
    Python objects.

    Args:
        config_path: Path to the YAML file (read as UTF-8).

    Returns:
        The parsed top-level mapping. An empty file yields ``{}`` rather than
        ``None`` (which is what ``safe_load`` returns), so callers can use
        ``cfg.get(...)`` unconditionally.

    Raises:
        ValueError: If the top-level YAML node is not a mapping.
    """
    with open(config_path, "r", encoding="utf-8") as f:
        cfg = yaml.safe_load(f)
    if cfg is None:
        return {}
    if not isinstance(cfg, dict):
        raise ValueError(
            f"Config file {config_path} must contain a YAML mapping at the top "
            f"level, got {type(cfg).__name__}"
        )
    return cfg
def resolve_paths(cfg: Dict[str, Any], config_dir: Path) -> Dict[str, Any]:
    """Resolve the relative paths in ``cfg`` against ``config_dir``, in place.

    Handles ``data.data_dir``, ``data.output_dir``, ``data.depth_dir`` and
    ``depth_pro.repo_dir`` when they are set (empty values are skipped).
    Absolute paths are left untouched. A leading ``~`` is expanded to the
    user's home directory instead of being treated as a relative directory
    named ``~``.

    A missing or ``null`` ``data`` section is replaced by ``{}`` so callers
    can index ``cfg["data"]`` safely. ``cfg["_project_root"]`` is set to
    ``str(config_dir)``.

    Returns:
        The same ``cfg`` object, mutated.
    """
    def _resolve(value):
        path = Path(value)
        if path.is_absolute():
            return value
        try:
            expanded = path.expanduser()
        except RuntimeError:  # home directory cannot be determined
            expanded = path
        if expanded.is_absolute():
            return str(expanded)
        return str(config_dir / path)

    data_cfg = cfg.get("data")
    if data_cfg is None:
        data_cfg = cfg["data"] = {}
    for key in ("data_dir", "output_dir", "depth_dir"):
        if data_cfg.get(key):
            data_cfg[key] = _resolve(data_cfg[key])

    # `or {}` also tolerates an explicit `depth_pro: null` in the YAML.
    depth_pro_cfg = cfg.get("depth_pro") or {}
    if depth_pro_cfg.get("repo_dir"):
        depth_pro_cfg["repo_dir"] = _resolve(depth_pro_cfg["repo_dir"])

    cfg["_project_root"] = str(config_dir)
    return cfg
# =============================================================================
# 深度估计流程(仅 all / depth_only 模式调用)
# =============================================================================
def run_depth_pipeline(
    center_rgb: np.ndarray,
    cfg: Dict[str, Any],
    device: torch.device,
    output_dir: Path,
    erp_h: int,
    erp_w: int,
    frame_id: int = 0,
) -> np.ndarray:
    """Estimate an ERP depth map for one panorama: slice -> infer -> fuse.

    Steps:
      1. Build the icosahedron tangent-slice specs and save a coverage mask to
         ``<output_dir>/debug/coverage_mask.png``.
      2. Extract the tangent RGB slices from ``center_rgb``.
      3. Run Depth Pro on every slice, or fill each slice with a constant 5.0
         dummy depth when ``cfg["depth_pro"]["enabled"]`` is false.
      4. Fuse the per-slice depths into one ERP depth map.
      5. Save ``depth_XXXX.npy``, ``depth_XXXX_vis.png`` and
         ``depth_XXXX_valid_mask.png`` to ``<output_dir>/depth_erp``.

    With ``cfg["run"]["save_intermediates"]`` set, per-slice RGB/depth files go
    to ``<output_dir>/tangents`` and fusion debug output to ``<output_dir>/debug``.

    Side effect: ``cfg["erp"]["height"]`` / ``cfg["erp"]["width"]`` are
    overwritten with ``erp_h`` / ``erp_w``.

    Args:
        center_rgb: Source panorama as returned by ``load_image``.
        cfg: Pipeline config.
        device: Torch device passed to the slicing, inference and fusion helpers.
        output_dir: Root output directory.
        erp_h: ERP height in pixels.
        erp_w: ERP width in pixels.
        frame_id: Frame number used in log messages and output file names.

    Returns:
        The fused ERP depth map from ``fuse_tangent_depths_to_erp``.
    """
    _load_depth_modules()
    save_intermediates = bool(cfg.get("run", {}).get("save_intermediates", False))
    tangent_dir = output_dir / "tangents"

    depth_out_dir = output_dir / "depth_erp"
    depth_out_dir.mkdir(parents=True, exist_ok=True)

    # --- Step 1: build slice specs ---
    print(f"\n{'='*60}")
    print(f"[Step 1] Building tangent slices (frame {frame_id})")
    print(f"{'='*60}")
    # The helpers receive cfg and presumably read the ERP size from it, so
    # keep it in sync with the actual panorama size. `is None` also covers
    # an explicit `erp: null` in the YAML.
    if cfg.get("erp") is None:
        cfg["erp"] = {}
    cfg["erp"]["height"] = erp_h
    cfg["erp"]["width"] = erp_w
    slices = build_icosahedron_slices(cfg)
    print(f" Total slices: {len(slices)}")
    for s in slices:
        if s.slice_type != "face":
            print(f" {s.slice_id}: type={s.slice_type}, fov={s.fov_deg:.1f}°")

    coverage_mask, coverage_stats = compute_coverage_mask(slices, erp_h, erp_w, device)
    print(f" Coverage: {coverage_stats['total_coverage']:.2f}%")
    dbg_dir = output_dir / "debug"
    dbg_dir.mkdir(parents=True, exist_ok=True)
    save_image(np.stack([coverage_mask] * 3, axis=-1), dbg_dir / "coverage_mask.png")

    # --- Step 2: extract slices ---
    print(f"\n{'='*60}")
    print(f"[Step 2] Extracting tangent slices (frame {frame_id})")
    print(f"{'='*60}")
    t0 = time.time()
    tangent_rgbs = extract_all_tangents(center_rgb, slices, device)
    print(f" Extracted {len(tangent_rgbs)} slices in {time.time()-t0:.2f}s")
    if save_intermediates:
        tangent_dir.mkdir(parents=True, exist_ok=True)
        for slice_id, rgb in tangent_rgbs.items():
            save_image(rgb, tangent_dir / f"{slice_id}_rgb.png")

    # --- Step 3: Depth Pro inference ---
    print(f"\n{'='*60}")
    print(f"[Step 3] Running Depth Pro inference (frame {frame_id})")
    print(f"{'='*60}")
    dp_cfg = cfg.get("depth_pro", {})
    if not bool(dp_cfg.get("enabled", True)):
        print(" [Warning] Depth Pro disabled, using dummy depth")
        tangent_depths = {
            sid: np.full(rgb.shape[:2], 5.0, dtype=np.float32)
            for sid, rgb in tangent_rgbs.items()
        }
    else:
        t0 = time.time()
        tangent_depths = estimate_all_tangent_depths(
            tangent_rgbs, slices, cfg, device,
        )
        print(f" Estimated {len(tangent_depths)} depths in {time.time()-t0:.2f}s")
    if save_intermediates:
        for sid, depth in tangent_depths.items():
            save_depth(depth, tangent_dir / f"{sid}_depth.npy")

    # --- Step 4: fuse into ERP ---
    print(f"\n{'='*60}")
    print(f"[Step 4] Fusing tangent depths to ERP (frame {frame_id})")
    print(f"{'='*60}")
    t0 = time.time()
    depth_erp, _weight_sum, valid_mask = fuse_tangent_depths_to_erp(
        tangent_depths, slices, cfg, device,
        debug_dir=dbg_dir if save_intermediates else None,
    )
    print(f" Fused in {time.time()-t0:.2f}s")

    valid = np.asarray(valid_mask) > 0
    valid_ratio = np.count_nonzero(valid) / max(valid.size, 1)
    valid_depths = depth_erp[np.isfinite(depth_erp) & (depth_erp > 0)]
    print(f" Valid depth ratio: {valid_ratio * 100:.2f}%")
    if len(valid_depths) > 0:
        print(f" Depth range: [{valid_depths.min():.2f}, {valid_depths.max():.2f}] m")
    else:
        print(" [Warning] Fused depth contains no finite positive values")

    # --- Step 5: save results ---
    save_depth(depth_erp, depth_out_dir / f"depth_{frame_id:04d}.npy")
    save_depth_visualization(depth_erp, depth_out_dir / f"depth_{frame_id:04d}_vis.png")
    # Build an explicit 0/255 uint8 image: `mask * 255` turns a bool mask into
    # int64 (PNG only stores 8/16-bit unsigned data) and wraps a 0/255 uint8
    # mask around to 1.
    mask_png = np.where(valid, 255, 0).astype(np.uint8)
    cv2.imwrite(str(depth_out_dir / f"depth_{frame_id:04d}_valid_mask.png"), mask_png)
    return depth_erp
# =============================================================================
# Warp 流程
# =============================================================================
def run_warp_pipeline(
    center_rgb: np.ndarray,
    depth_erp: np.ndarray,
    center_frame: int,
    image_files: dict,
    pose_files: dict,
    cfg: Dict[str, Any],
    device: torch.device,
    output_dir: Path,
    erp_h: int,
    erp_w: int,
) -> None:
    """Forward-splat the center panorama into every selected target frame.

    Target frames are all frames with a pose except the center frame, or the
    subset listed in ``cfg["warp"]["target_frames"]``. That setting may be
    ``None``/``"auto"`` (all), a list of IDs, a single int, or a
    comma-separated string such as ``"3"`` or ``"1, 5, 7"``. Requested IDs
    without a pose are reported and skipped.

    For each target, writes ``<prefix>_rgb.png`` and ``<prefix>_mask.png``
    to ``<output_dir>/warp_rgb``. When ``warp.output_depth`` is true (the
    default), it also writes ``<prefix>_depth_range.npy`` to
    ``<output_dir>/warp_depth``. Here ``prefix`` is
    ``pano<center>_to_pano<target>``. Returns early, after logging an error,
    if the center frame has no pose.

    ``image_files``, ``erp_h`` and ``erp_w`` are not used here; they are kept
    for interface compatibility.
    """
    warp_cfg = cfg.get("warp") or {}
    output_depth = bool(warp_cfg.get("output_depth", True))

    # Determine target frames.
    available_targets = sorted(fid for fid in pose_files if fid != center_frame)
    cfg_targets = warp_cfg.get("target_frames", None)
    skipped = []
    if cfg_targets is None or cfg_targets == "auto":
        target_frames = available_targets
    else:
        if isinstance(cfg_targets, str):
            # Iterating a raw string would split "12" into the frames 1 and 2.
            cfg_targets = [t for t in cfg_targets.split(",") if t.strip()]
        elif isinstance(cfg_targets, int):
            cfg_targets = [cfg_targets]
        requested = {int(t) for t in cfg_targets}
        target_frames = [fid for fid in available_targets if fid in requested]
        skipped = sorted(requested.difference(available_targets))

    print(f"\n{'='*60}")
    print(f"[Warp] Forward splatting from frame {center_frame}")
    print(f"{'='*60}")
    print(f" Method: {warp_cfg.get('method', 'softmax_splatting')}")
    print(f" Available targets with pose: {available_targets}")
    print(f" Will warp: {target_frames}")
    if skipped:
        print(f" [Warning] Requested target frames skipped (no pose, or same as center frame): {skipped}")

    # Load the center (source) pose.
    if center_frame not in pose_files:
        print(f" [Error] Center pose not found for frame {center_frame}")
        return
    src_pose = load_pose(pose_files[center_frame])
    print(f" Source pose: position={src_pose.position.tolist()}")

    # Output directories.
    warp_rgb_dir = output_dir / "warp_rgb"
    warp_rgb_dir.mkdir(parents=True, exist_ok=True)
    warp_depth_dir = output_dir / "warp_depth"
    if output_depth:
        warp_depth_dir.mkdir(parents=True, exist_ok=True)

    total_warp = len(target_frames)
    for idx, tgt_id in enumerate(target_frames, start=1):
        # Every target was drawn from pose_files, so its pose always exists.
        tgt_pose = load_pose(pose_files[tgt_id])
        print(f" [{idx}/{total_warp}] Frame {center_frame} -> {tgt_id} ...", end="", flush=True)
        t0 = time.time()
        result = warp_erp_to_target(
            src_rgb=center_rgb,
            src_depth=depth_erp,
            src_pose=src_pose,
            tgt_pose=tgt_pose,
            cfg=cfg,
            device=device,
        )
        dt = time.time() - t0
        valid = np.asarray(result.valid_mask) > 0
        valid_pct = 100.0 * np.count_nonzero(valid) / valid.size if valid.size else 0.0
        print(f" done ({dt:.2f}s, valid={valid_pct:.1f}%)")

        prefix = f"pano{center_frame:04d}_to_pano{tgt_id:04d}"
        save_image(result.warped_rgb, warp_rgb_dir / f"{prefix}_rgb.png")
        # Explicit 0/255 uint8 mask: `mask * 255` yields int64 for bool masks,
        # which is not a PNG-compatible depth.
        cv2.imwrite(str(warp_rgb_dir / f"{prefix}_mask.png"), np.where(valid, 255, 0).astype(np.uint8))
        if output_depth and result.warped_depth is not None:
            save_depth(result.warped_depth, warp_depth_dir / f"{prefix}_depth_range.npy")

    print(f" Warp complete. Output saved to: {warp_rgb_dir}")
# =============================================================================
# 主函数
# =============================================================================
def _read_depth_file(path: Path) -> Optional[np.ndarray]:
    """Load a depth map from ``path`` as a 2-D float32 array.

    ``.exr`` files are read with ``cv2.imread``; ``None`` is returned when
    OpenCV cannot decode the file, so the caller can try another candidate.
    Any other suffix is read with ``np.load``. Singleton axes are squeezed
    away, and only the first channel of multi-channel data is kept.
    """
    if path.suffix == ".exr":
        depth = cv2.imread(str(path), cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH)
        if depth is None:
            return None
    else:
        depth = np.load(str(path))
    if depth.ndim > 2:
        depth = np.squeeze(depth)
    if depth.ndim == 3:
        depth = depth[:, :, 0]
    return depth.astype(np.float32)


def _find_existing_depth(data_dir: Path, output_dir: Path, image_stem: str, center_frame: int):
    """Look for a precomputed depth map of the center frame.

    Candidates are tried in priority order. Ground-truth files in
    ``data_dir`` come first (``<stem>_depth.npy``, ``<stem>_depth.exr``,
    ``<stem>.npy``, ``depth_XXXX.npy``). A previous estimate in
    ``<output_dir>/depth_erp`` comes last.

    Returns:
        ``(depth, candidates)``: the first map that loads successfully (or
        ``None``), and the list of every path that was searched.
    """
    candidates = [
        data_dir / f"{image_stem}_depth.npy",
        data_dir / f"{image_stem}_depth.exr",
        data_dir / f"{image_stem}.npy",
        data_dir / f"depth_{center_frame:04d}.npy",
        output_dir / "depth_erp" / f"depth_{center_frame:04d}.npy",
    ]
    for path in candidates:
        if not path.exists():
            continue
        depth = _read_depth_file(path)
        if depth is not None:
            print(f" Loaded depth from {path}")
            return depth, candidates
    return None, candidates


def main():
    """Command-line entry point: parse arguments, load config and data, run stages.

    Stages (``--stage``):
      * ``warp_only`` (default): load an existing depth map for the center frame
        and forward-warp the center panorama via ``run_warp_pipeline``. Depth
        estimation is never run; a missing depth map raises FileNotFoundError.
      * ``depth_only``: run depth estimation for the center frame only.
      * ``all``: depth estimation followed by warping.

    Warping is skipped when ``cfg["warp"]["enabled"]`` is false.
    """
    script_dir = Path(__file__).parent
    default_config = script_dir / "config.yaml"

    parser = argparse.ArgumentParser(description="ERPT Pipeline")
    parser.add_argument("--config", type=str,
                        default=str(default_config) if default_config.exists() else None,
                        help="Config file path")
    parser.add_argument("--data_dir", type=str, default=None,
                        help="Data directory (overrides config)")
    parser.add_argument("--output_dir", type=str, default=None,
                        help="Output directory (overrides config)")
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--stage", type=str, default="warp_only",
                        choices=["all", "depth_only", "warp_only"])
    parser.add_argument("--center_frame", type=int, default=None,
                        help="Center frame ID (overrides config)")
    args = parser.parse_args()

    # --- Configuration ---
    if args.config:
        config_path = Path(args.config)
        # An empty YAML file parses to None; treat it as an empty config.
        cfg = load_config(config_path) or {}
        cfg = resolve_paths(cfg, config_path.parent)
    else:
        cfg = {
            "data": {},
            "erp": {"auto_size": True},
            "tangent": {},
            "depth_pro": {"enabled": True, "precision": "fp16", "pass_f_px": True},
            "fusion": {"blend_mode": "multiband", "output_scale": 1.10, "k": 4},
            "run": {"save_intermediates": False},
        }
    # A config file may omit the "data" section or leave it null.
    if cfg.get("data") is None:
        cfg["data"] = {}

    # Command-line overrides.
    if args.data_dir:
        cfg["data"]["data_dir"] = str(Path(args.data_dir).resolve())
    if args.output_dir:
        cfg["data"]["output_dir"] = args.output_dir

    data_dir = Path(cfg["data"].get("data_dir", "inputs"))
    output_dir = Path(cfg["data"].get("output_dir", "outputs"))
    device = torch.device(args.device if torch.cuda.is_available() or args.device == "cpu" else "cpu")
    print(f"Using device: {device}")

    warp_cfg = cfg.get("warp") or {}
    # Compare against None: `args.center_frame or ...` discarded an explicit
    # `--center_frame 0` and silently used the config value instead.
    if args.center_frame is not None:
        center_frame = args.center_frame
    else:
        center_frame = int(warp_cfg.get("center_frame", 0))

    print(f"\n{'='*60}")
    print("ERPT Pipeline")
    print(f"{'='*60}")
    print(f"Stage: {args.stage}")
    print(f"Data dir: {data_dir}")
    print(f"Output dir: {output_dir}")
    t_start = time.time()

    # --- Load data ---
    print(f"\n{'='*60}")
    print("[Loading data]")
    print(f"{'='*60}")
    image_files = discover_image_files(data_dir)
    pose_files = discover_pose_files(data_dir)
    print(f" Found {len(image_files)} images, {len(pose_files)} poses")
    if not image_files:
        raise FileNotFoundError(f"No image files found in: {data_dir}")
    if center_frame not in image_files:
        fallback = min(image_files)
        print(f" [Warning] Frame {center_frame} has no image, falling back to frame {fallback}")
        center_frame = fallback
    print(f" Using frame {center_frame} as center")
    center_rgb = load_image(image_files[center_frame])
    print(f" Center image: {image_files[center_frame].name}")
    print(f" Shape: {center_rgb.shape}")

    erp_cfg = cfg.get("erp") or {}
    if bool(erp_cfg.get("auto_size", True)):
        erp_h, erp_w = center_rgb.shape[:2]
        print(f" Auto size: {erp_w}x{erp_h}")
    else:
        erp_h = int(erp_cfg.get("height", 2048))
        erp_w = int(erp_cfg.get("width", 4096))

    # --- Depth: estimate, or load an existing map ---
    if args.stage in ("all", "depth_only"):
        if args.stage == "all":
            print("\n [Stage: all] 强制执行深度估计")
        depth_erp = run_depth_pipeline(
            center_rgb, cfg, device, output_dir, erp_h, erp_w, center_frame,
        )
    else:
        # warp_only: ground truth > previous estimate; never falls back to
        # running depth estimation.
        depth_erp, tried = _find_existing_depth(
            data_dir, output_dir, image_files[center_frame].stem, center_frame,
        )
        if depth_erp is None:
            tried_str = "\n ".join(str(p) for p in tried)
            raise FileNotFoundError(
                f"[warp_only] 未找到深度文件,无法执行 warp。\n"
                f"已搜索路径:\n {tried_str}\n"
                f"如需深度估计请使用 --stage all"
            )
        if depth_erp.shape != (erp_h, erp_w):
            old_shape = depth_erp.shape
            depth_erp = cv2.resize(depth_erp, (erp_w, erp_h), interpolation=cv2.INTER_LINEAR)
            print(f" [Warning] Depth resized: {old_shape} -> ({erp_h}, {erp_w})")

    # --- Warp stage ---
    warp_enabled = bool(warp_cfg.get("enabled", True))
    if args.stage in ("all", "warp_only") and warp_enabled:
        run_warp_pipeline(
            center_rgb, depth_erp, center_frame,
            image_files, pose_files,
            cfg, device, output_dir, erp_h, erp_w,
        )

    # --- Done ---
    total_time = time.time() - t_start
    print(f"\n{'='*60}")
    print("Pipeline Complete")
    print(f"{'='*60}")
    print(f"Total time: {total_time:.2f}s")
    print(f"Output saved to: {output_dir}")
# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()