| """ |
| ERP Forward Warp 模块(移植自原版 ERPT erp_softsplat.py) |
| |
| 使用锁定的投影/坐标系接口: |
| - core.erp_projection: erp_to_direction, direction_to_erp, wrap_u, create_erp_grid |
| - utils.pose_utils: Pose (R_cw, R_wc, position) |
| |
| 算法流程: |
| 1. 对每个 src ERP 像素,通过 erp_to_direction 获取射线方向 |
| 2. 根据深度计算 3D 点,变换到目标相机坐标系 |
| 3. 通过 direction_to_erp 投影到目标 ERP |
| 4. Forward splatting 累积 RGB(softmax / zbuffer / point) |
| |
| 支持的 splatting 方法: |
| - softmax_splatting(默认):自适应半径 + 高斯核 + softmax 深度竞争 |
| - zbuffer_splatting:两遍 z-buffer 硬遮挡 |
| - zbuffer_point:最近邻投影 |
| """ |
| from __future__ import annotations |
|
|
| import math |
| from dataclasses import dataclass |
| from typing import Any, Dict, Optional, Tuple |
|
|
| import cv2 |
| import numpy as np |
| import torch |
|
|
| from .erp_projection import ( |
| erp_to_direction, |
| direction_to_erp, |
| wrap_u, |
| create_erp_grid, |
| ) |
|
|
| import sys |
| from pathlib import Path |
| sys.path.insert(0, str(Path(__file__).parent.parent)) |
| from utils.pose_utils import Pose |
|
|
|
|
@dataclass()
class WarpResult:
    """Container for the outputs of one ERP forward warp.

    Field contents below describe what ``warp_erp_to_target`` stores; the
    dataclass itself performs no validation.
    """

    # Warped colour image in the target view (uint8 RGB, source ERP size).
    warped_rgb: np.ndarray
    # uint8 mask, 1 where the target pixel received enough splat weight.
    valid_mask: np.ndarray
    # float32 per-source-pixel (du, dv) displacement in pixels, or None.
    flow: Optional[np.ndarray]
    # float32 accumulated kernel ("hit") weight per target pixel.
    weight_sum: np.ndarray
    # float32 warped range depth (NaN outside the valid mask), or None.
    warped_depth: Optional[np.ndarray] = None
|
|
|
|
| |
| |
| |
|
|
@torch.no_grad()
def _forward_project(
    src_depth_t: torch.Tensor,
    src_pose: Pose,
    tgt_pose: Pose,
    erp_h: int,
    erp_w: int,
    device: torch.device,
    uu: Optional[torch.Tensor] = None,
    vv: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Project every source ERP pixel into the target ERP image.

    All coordinate conversions go through the locked ``erp_projection`` helpers.
    The source ray of each pixel is scaled by its depth, moved to world space
    with ``src_pose.R_cw`` / ``src_pose.position``, then into the target camera
    with ``tgt_pose.R_wc`` / ``tgt_pose.position``.

    Args:
        src_depth_t: Source depth per pixel; multiplied onto the ray returned
            by ``erp_to_direction`` (presumably unit-length, so this is range
            depth — confirm against erp_projection).
        src_pose: Source camera pose.
        tgt_pose: Target camera pose.
        erp_h: ERP image height in pixels.
        erp_w: ERP image width in pixels.
        device: Device on which the pose tensors are created.
        uu: Optional precomputed pixel grid. If either ``uu`` or ``vv`` is
            missing, both are rebuilt with ``create_erp_grid``.
        vv: See ``uu``.

    Returns:
        ``(u_tgt, v_tgt, range_tgt, dirs_tgt)``: target pixel coordinates
        (u wrapped via ``wrap_u``), Euclidean distance to the target camera,
        and the normalised target-camera direction vectors (last dim 3).
    """
    if uu is None or vv is None:
        uu, vv = create_erp_grid(erp_h, erp_w, device)

    def _as_f32(value) -> torch.Tensor:
        # Pose attributes -> float32 tensors on the working device.
        return torch.tensor(value, device=device, dtype=torch.float32)

    # Back-project: ray * depth gives the point in the source camera frame.
    rays_src = erp_to_direction(uu, vv, erp_h, erp_w)
    pts_src_cam = rays_src * src_depth_t.unsqueeze(-1)

    # Source camera -> world.
    rot_cw = _as_f32(src_pose.R_cw)
    origin_src = _as_f32(src_pose.position)
    pts_world = torch.einsum("ij,hwj->hwi", rot_cw, pts_src_cam) + origin_src

    # World -> target camera.
    rot_wc = _as_f32(tgt_pose.R_wc)
    origin_tgt = _as_f32(tgt_pose.position)
    pts_tgt_cam = torch.einsum("ij,hwj->hwi", rot_wc, pts_world - origin_tgt)

    # Split into distance and direction; the clamp avoids dividing by zero for
    # points that coincide with the target camera centre.
    dist_tgt = torch.norm(pts_tgt_cam, dim=-1)
    rays_tgt = pts_tgt_cam / torch.clamp(dist_tgt.unsqueeze(-1), min=1e-9)

    u_raw, v_out = direction_to_erp(rays_tgt, erp_h, erp_w)
    return wrap_u(u_raw, erp_w), v_out, dist_tgt, rays_tgt
|
|
|
|
| |
| |
| |
|
|
| def _adaptive_splat_rgb( |
| erp_h: int, |
| erp_w: int, |
| u: torch.Tensor, |
| v: torch.Tensor, |
| rgb: torch.Tensor, |
| depth_compete: torch.Tensor, |
| valid: torch.Tensor, |
| alpha: float, |
| radius: torch.Tensor, |
| occlusion_gate: Optional[Dict[str, Any]] = None, |
| ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: |
| """ |
| 自适应半径 softmax splatting |
| |
| - 高斯核加权 |
| - softmax(alpha * inv_depth) 深度竞争 |
| - 可选 occlusion gate(近似 z-buffer 门控) |
| """ |
| device = u.device |
| u_flat = u.reshape(-1) |
| v_flat = v.reshape(-1) |
| rgb_flat = rgb.reshape(-1, 3) |
| d_flat = depth_compete.reshape(-1) |
| valid_flat = valid.reshape(-1) |
| r_flat = radius.reshape(-1) |
|
|
| |
| safe_d = torch.where( |
| valid_flat & torch.isfinite(d_flat) & (d_flat > 0), |
| d_flat, torch.ones_like(d_flat), |
| ) |
|
|
| |
| inv_d = 1.0 / torch.clamp(safe_d, min=0.1) |
| valid_inv = inv_d[valid_flat] |
| inv_max = valid_inv.max() if len(valid_inv) > 0 else inv_d.max() |
| exp_w = torch.exp(alpha * (inv_d - inv_max)) |
|
|
| |
| gate_enabled = False |
| min_d_flat: Optional[torch.Tensor] = None |
| gate_abs = 0.0 |
| gate_rel = 0.0 |
| if occlusion_gate and bool(occlusion_gate.get("enabled", False)): |
| gate_enabled = True |
| gate_abs = float(occlusion_gate.get("abs_eps_m", 0.05)) |
| gate_rel = float(occlusion_gate.get("rel_eps", 0.05)) |
| u_nn = torch.round(u_flat).to(torch.long) |
| v_nn = torch.round(v_flat).to(torch.long) |
| u_nn = torch.remainder(u_nn, erp_w) |
| v_ok = (v_nn >= 0) & (v_nn < erp_h) |
| v_nn_c = torch.clamp(v_nn, 0, erp_h - 1) |
| idx_nn = v_nn_c * erp_w + u_nn |
| min_d_flat = torch.full((erp_h * erp_w,), float("inf"), device=device) |
| d_nn = torch.where(valid_flat & v_ok & torch.isfinite(d_flat), |
| d_flat, torch.full_like(d_flat, float("inf"))) |
| min_d_flat.scatter_reduce_(0, idx_nn, d_nn, reduce="amin", include_self=True) |
|
|
| accum_rgb = torch.zeros(erp_h, erp_w, 3, device=device, dtype=torch.float32) |
| accum_w = torch.zeros(erp_h, erp_w, device=device, dtype=torch.float32) |
| accum_hit = torch.zeros(erp_h, erp_w, device=device, dtype=torch.float32) |
| accum_d = torch.zeros(erp_h, erp_w, device=device, dtype=torch.float32) |
|
|
| u0 = torch.floor(u_flat).to(torch.int64) |
| v0 = torch.floor(v_flat).to(torch.int64) |
| du = (u_flat - u0.float()).clamp(0, 1) |
| dv = (v_flat - v0.float()).clamp(0, 1) |
|
|
| |
| valid_radii = r_flat[valid_flat & torch.isfinite(r_flat)] |
| max_r = min(int(valid_radii.max().item()) + 1, 5) if len(valid_radii) > 0 else 2 |
|
|
| def _add(u_idx, v_idx, bw): |
| v_ok = (v_idx >= 0) & (v_idx < erp_h) |
| m = valid_flat & v_ok & torch.isfinite(d_flat) |
| u_safe = torch.where(m, u_idx, torch.zeros_like(u_idx)) |
| v_safe = torch.where(m, v_idx, torch.zeros_like(v_idx)) |
| idx = v_safe * erp_w + u_safe |
|
|
| if gate_enabled and min_d_flat is not None: |
| md = min_d_flat.gather(0, idx) |
| gate = d_flat <= (md * (1.0 + gate_rel) + gate_abs) |
| mm = m & gate |
| else: |
| mm = m |
|
|
| final_w = torch.where(mm, bw * exp_w, torch.zeros_like(bw)) |
| hit_w = torch.where(mm, bw, torch.zeros_like(bw)) |
| accum_w.view(-1).scatter_add_(0, idx, final_w) |
| accum_hit.view(-1).scatter_add_(0, idx, hit_w) |
| accum_rgb.view(-1, 3).scatter_add_( |
| 0, idx.unsqueeze(-1).expand(-1, 3), |
| (final_w.unsqueeze(-1) * rgb_flat).float(), |
| ) |
| accum_d.view(-1).scatter_add_(0, idx, (final_w * d_flat).float()) |
|
|
| for di in range(-max_r, max_r + 1): |
| for dj in range(-max_r, max_r + 1): |
| dist_ij = math.sqrt(di * di + dj * dj) |
| if dist_ij > max_r + 0.5: |
| continue |
| dx = float(di) - du |
| dy = float(dj) - dv |
| dist = torch.sqrt(dx * dx + dy * dy) |
| within = dist <= (r_flat + 0.5) |
| gauss_w = torch.where( |
| within, |
| torch.exp(-0.5 * (dist / r_flat.clamp(min=0.5)) ** 2), |
| torch.zeros_like(r_flat), |
| ) |
| u_off = torch.remainder(u0 + di, erp_w) |
| v_off = v0 + dj |
| _add(u_off, v_off, gauss_w) |
|
|
| return accum_rgb, accum_w, accum_hit, accum_d |
|
|
|
|
| |
| |
| |
|
|
| def _zbuffer_splat_rgb( |
| erp_h: int, erp_w: int, |
| u: torch.Tensor, v: torch.Tensor, |
| rgb: torch.Tensor, depth_compete: torch.Tensor, valid: torch.Tensor, |
| eps_abs_m: float, eps_rel: float, min_w: float, |
| ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: |
| """Z-buffer 硬遮挡 forward splatting(两遍法)""" |
| device = u.device |
| u_flat, v_flat = u.reshape(-1), v.reshape(-1) |
| d_flat = depth_compete.reshape(-1) |
| rgb_flat = rgb.reshape(-1, 3) |
| valid_flat = valid.reshape(-1) |
|
|
| m0 = valid_flat & torch.isfinite(u_flat) & torch.isfinite(v_flat) & \ |
| torch.isfinite(d_flat) & (d_flat > 0.0) |
|
|
| u0 = torch.floor(u_flat).to(torch.int64) |
| v0 = torch.floor(v_flat).to(torch.int64) |
| du = (u_flat - u0.float()).clamp(0, 1) |
| dv = (v_flat - v0.float()).clamp(0, 1) |
| u0w = torch.remainder(u0, erp_w) |
| u1w = torch.remainder(u0 + 1, erp_w) |
| v1 = v0 + 1 |
| w00 = (1 - du) * (1 - dv) |
| w10 = du * (1 - dv) |
| w01 = (1 - du) * dv |
| w11 = du * dv |
|
|
| |
| min_depth = torch.full((erp_h * erp_w,), float("inf"), device=device) |
|
|
| def _amin(ui, vi, w): |
| m = m0 & (vi >= 0) & (vi < erp_h) & (w >= min_w) |
| us = torch.where(m, ui, torch.zeros_like(ui)) |
| vs = torch.where(m, vi, torch.zeros_like(vi)) |
| idx = vs * erp_w + us |
| cand = torch.where(m, d_flat, torch.full_like(d_flat, float("inf"))) |
| min_depth.scatter_reduce_(0, idx, cand, reduce="amin", include_self=True) |
|
|
| _amin(u0w, v0, w00); _amin(u1w, v0, w10) |
| _amin(u0w, v1, w01); _amin(u1w, v1, w11) |
|
|
| |
| accum_rgb = torch.zeros(erp_h, erp_w, 3, device=device) |
| accum_w = torch.zeros(erp_h, erp_w, device=device) |
| accum_hit = torch.zeros(erp_h, erp_w, device=device) |
| accum_d = torch.zeros(erp_h, erp_w, device=device) |
|
|
| def _acc(ui, vi, w): |
| m = m0 & (vi >= 0) & (vi < erp_h) & (w >= min_w) |
| us = torch.where(m, ui, torch.zeros_like(ui)) |
| vs = torch.where(m, vi, torch.zeros_like(vi)) |
| idx = vs * erp_w + us |
| md = min_depth.gather(0, idx) |
| gate = d_flat <= (md * (1 + eps_rel) + eps_abs_m) |
| mm = m & gate |
| wf = torch.where(mm, w, torch.zeros_like(w)) |
| accum_w.view(-1).scatter_add_(0, idx, wf) |
| accum_hit.view(-1).scatter_add_(0, idx, wf) |
| accum_rgb.view(-1, 3).scatter_add_( |
| 0, idx.unsqueeze(-1).expand(-1, 3), |
| (wf.unsqueeze(-1) * rgb_flat).float(), |
| ) |
| accum_d.view(-1).scatter_add_(0, idx, (wf * d_flat).float()) |
|
|
| _acc(u0w, v0, w00); _acc(u1w, v0, w10) |
| _acc(u0w, v1, w01); _acc(u1w, v1, w11) |
|
|
| return accum_rgb, accum_w, accum_hit, accum_d |
|
|
|
|
| |
| |
| |
|
|
| def _zbuffer_point_rgb( |
| erp_h: int, erp_w: int, |
| u: torch.Tensor, v: torch.Tensor, |
| rgb: torch.Tensor, depth_compete: torch.Tensor, valid: torch.Tensor, |
| eps_abs_m: float, eps_rel: float, |
| ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: |
| """Z-buffer 点渲染(radius=0, winner-take-all)""" |
| device = u.device |
| u_flat, v_flat = u.reshape(-1), v.reshape(-1) |
| d_flat = depth_compete.reshape(-1) |
| rgb_flat = rgb.reshape(-1, 3) |
| valid_flat = valid.reshape(-1) |
|
|
| m0 = valid_flat & torch.isfinite(u_flat) & torch.isfinite(v_flat) & \ |
| torch.isfinite(d_flat) & (d_flat > 0.0) |
|
|
| u_nn = torch.remainder(torch.round(u_flat).to(torch.int64), erp_w) |
| v_nn = torch.round(v_flat).to(torch.int64) |
| v_ok = (v_nn >= 0) & (v_nn < erp_h) |
| m = m0 & v_ok |
| us = torch.where(m, u_nn, torch.zeros_like(u_nn)) |
| vs = torch.where(m, v_nn, torch.zeros_like(v_nn)) |
| idx = vs * erp_w + us |
|
|
| |
| min_depth = torch.full((erp_h * erp_w,), float("inf"), device=device) |
| cand = torch.where(m, d_flat, torch.full_like(d_flat, float("inf"))) |
| min_depth.scatter_reduce_(0, idx, cand, reduce="amin", include_self=True) |
|
|
| |
| md = min_depth.gather(0, idx) |
| gate = d_flat <= (md * (1 + eps_rel) + eps_abs_m) |
| mm = m & gate |
| wf = torch.where(mm, torch.ones_like(d_flat), torch.zeros_like(d_flat)) |
|
|
| accum_rgb = torch.zeros(erp_h, erp_w, 3, device=device) |
| accum_w = torch.zeros(erp_h, erp_w, device=device) |
| accum_hit = torch.zeros(erp_h, erp_w, device=device) |
| accum_d = torch.zeros(erp_h, erp_w, device=device) |
|
|
| accum_w.view(-1).scatter_add_(0, idx, wf) |
| accum_hit.view(-1).scatter_add_(0, idx, wf) |
| accum_rgb.view(-1, 3).scatter_add_( |
| 0, idx.unsqueeze(-1).expand(-1, 3), |
| (wf.unsqueeze(-1) * rgb_flat).float(), |
| ) |
| accum_d.view(-1).scatter_add_(0, idx, (wf * d_flat).float()) |
|
|
| return accum_rgb, accum_w, accum_hit, accum_d |
|
|
|
|
| |
| |
| |
|
|
def _edge_aware_hole_fill(
    rgb: np.ndarray, mask: np.ndarray,
    max_hole_px: int = 5,
    inpaint_radius: int = 2,
) -> Tuple[np.ndarray, np.ndarray]:
    """Inpaint only tiny holes, leaving real disocclusions untouched.

    Holes are the 8-connected components of ``mask == 0``. A component is
    filled when its area is at most ``max_hole_px ** 2`` pixels. Filled pixels
    come from OpenCV Telea inpainting and are marked 1 in the returned mask.

    NOTE(review): components are found on the flat image, so a hole that
    crosses the ERP left/right seam is measured as two separate pieces.

    Args:
        rgb: RGB image; converted with ``cv2.COLOR_RGB2BGR`` for inpainting, so
            it must be a 3-channel image OpenCV accepts (uint8 at the call site).
        mask: Validity mask; 0 marks a hole.
        max_hole_px: Side length of the largest square-equivalent hole to fill.
        inpaint_radius: Neighbourhood radius passed to ``cv2.inpaint``.

    Returns:
        ``(rgb_out, mask_out)``. The input objects are returned unchanged when
        nothing qualifies for filling; otherwise modified copies are returned.
    """
    holes = (mask == 0).astype(np.uint8)
    if not holes.any():
        return rgb, mask

    num, labels, stats, _ = cv2.connectedComponentsWithStats(holes, connectivity=8)
    max_area = max_hole_px * max_hole_px

    # Label 0 is the non-hole background. Build a per-label lookup table so
    # the selection is a single pass over the image, instead of one full-image
    # comparison per component.
    is_small = np.zeros(num, dtype=bool)
    is_small[1:] = stats[1:, cv2.CC_STAT_AREA] <= max_area
    fill_bool = is_small[labels]
    if not fill_bool.any():
        return rgb, mask

    fill_mask = fill_bool.astype(np.uint8)
    rgb_bgr = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)
    filled = cv2.inpaint(rgb_bgr, fill_mask, inpaint_radius, cv2.INPAINT_TELEA)
    filled_rgb = cv2.cvtColor(filled, cv2.COLOR_BGR2RGB)

    rgb_out = rgb.copy()
    mask_out = mask.copy()
    rgb_out[fill_bool] = filled_rgb[fill_bool]
    mask_out[fill_bool] = 1
    return rgb_out, mask_out
|
|
|
|
| |
| |
| |
|
|
@torch.no_grad()
def warp_erp_to_target(
    src_rgb: np.ndarray,
    src_depth: np.ndarray,
    src_pose: Pose,
    tgt_pose: Pose,
    cfg: Dict[str, Any],
    device: torch.device,
) -> WarpResult:
    """Forward-warp a source ERP view into a target ERP view.

    Coordinate transforms use the locked ``erp_projection`` helpers, and poses
    use the locked ``pose_utils.Pose`` interface.

    ``cfg["warp"]["method"]`` selects the splatting scheme:
    ``"zbuffer_splatting"``, ``"zbuffer_point"``, or, for any other value
    (default ``"softmax_splatting"``), adaptive-radius softmax splatting.
    The ``"min_weight_sum"`` key is not used; pixel validity is decided by
    ``"min_hit_sum"``.

    Args:
        src_rgb: (H, W, 3) uint8 source RGB.
        src_depth: (H, W) float32 source range depth in metres. Non-finite or
            non-positive values are treated as invalid.
        src_pose: Source camera pose (Pose instance).
        tgt_pose: Target camera pose (Pose instance).
        cfg: Configuration dict; only the ``"warp"`` sub-dict is read.
        device: Torch device used for the computation.

    Returns:
        WarpResult with the warped RGB, the validity mask, optional flow, the
        per-pixel hit weight, and optional warped depth (NaN outside the mask).
    """
    warp_cfg = cfg.get("warp", {})
    method = str(warp_cfg.get("method", "softmax_splatting"))
    alpha = float(warp_cfg.get("alpha", 2.0))
    output_flow = bool(warp_cfg.get("output_flow", True))
    output_depth = bool(warp_cfg.get("output_depth", True))
    depth_scale_factor = float(warp_cfg.get("depth_scale_factor", 1.0))
    min_hit = float(warp_cfg.get("min_hit_sum", 1e-6))

    # Z-buffer tolerances (shared by both z-buffer methods).
    z_eps_abs = float(warp_cfg.get("zbuffer_eps_abs_m", 0.03))
    z_eps_rel = float(warp_cfg.get("zbuffer_eps_rel", 0.03))
    z_min_w = float(warp_cfg.get("zbuffer_min_weight", 1e-3))

    # Adaptive splat-radius parameters (pixels unless noted otherwise).
    base_radius = float(warp_cfg.get("splat_radius_px", 1.5))
    radius_min = float(warp_cfg.get("radius_min_px", 0.6))
    radius_max_eq = float(warp_cfg.get("radius_max_px", 2.2))
    radius_max_pole = float(warp_cfg.get("radius_max_pole_px", 3.4))
    pole_radius_scale = float(warp_cfg.get("pole_radius_scale", 3.0))
    pole_lat_threshold = float(warp_cfg.get("pole_lat_threshold", 60.0)) * math.pi / 180.0
    depth_radius_scale = bool(warp_cfg.get("depth_radius_scale", False))
    depth_ref = float(warp_cfg.get("depth_ref_m", 2.0))
    depth_edge_aware = bool(warp_cfg.get("depth_edge_aware", True))
    depth_edge_threshold = float(warp_cfg.get("depth_edge_threshold", 0.3))
    depth_edge_min_scale = float(warp_cfg.get("depth_edge_min_scale", 0.12))

    # Small-hole inpainting applies only to the softmax path; the z-buffer
    # methods keep their holes.
    hole_fill = bool(warp_cfg.get("hole_fill_enabled", False)) and method not in (
        "zbuffer_splatting", "zbuffer_point",
    )
    max_hole_px = int(warp_cfg.get("max_hole_px", 16))

    erp_h, erp_w = src_rgb.shape[:2]

    src_rgb_t = torch.from_numpy(src_rgb.astype(np.float32)).to(device) / 255.0
    src_depth_t = torch.from_numpy(src_depth.astype(np.float32)).to(device)
    if depth_scale_factor != 1.0:
        # astype() already copied the array, so this does not touch caller data.
        src_depth_t *= depth_scale_factor
    valid = torch.isfinite(src_depth_t) & (src_depth_t > 0.0)

    # Shrink splats of source pixels that sit on depth discontinuities.
    if depth_edge_aware:
        depth_edge_scale = _depth_edge_radius_scale(
            src_depth_t, valid, depth_edge_threshold, depth_edge_min_scale,
        )
    else:
        depth_edge_scale = torch.ones_like(src_depth_t)

    uu, vv = create_erp_grid(erp_h, erp_w, device)
    u_tgt, v_tgt, range_tgt, dirs_tgt = _forward_project(
        src_depth_t, src_pose, tgt_pose, erp_h, erp_w, device, uu, vv,
    )

    # Enlarge splats near the poles of the target image, where ERP pixels are
    # stretched horizontally.
    # NOTE(review): assumes dirs[..., 1] is sin(latitude) in erp_projection's
    # convention; confirm.
    lat_tgt = torch.asin(torch.clamp(dirs_tgt[..., 1], -1.0, 1.0))
    # Guard against thresholds >= 90 deg, which would give a zero or negative span.
    pole_span = max(0.5 * math.pi - pole_lat_threshold, 1e-6)
    pole_factor = torch.clamp(
        (torch.abs(lat_tgt) - pole_lat_threshold) / pole_span, min=0.0, max=1.0,
    )
    lat_scale = 1.0 + pole_factor * (pole_radius_scale - 1.0)

    if depth_radius_scale:
        safe_range = torch.where(valid, range_tgt, torch.full_like(range_tgt, depth_ref))
        d_scale = 1.0 / (1.0 + safe_range / depth_ref)
    else:
        d_scale = torch.ones_like(range_tgt)

    adaptive_radius = base_radius * lat_scale * d_scale * depth_edge_scale
    adaptive_radius = torch.where(
        valid, adaptive_radius, torch.full_like(adaptive_radius, base_radius),
    )
    # The upper bound also grows towards the poles.
    radius_max_local = radius_max_eq + pole_factor * (radius_max_pole - radius_max_eq)
    adaptive_radius = torch.clamp(adaptive_radius, min=radius_min)
    adaptive_radius = torch.minimum(adaptive_radius, radius_max_local)

    if method == "zbuffer_splatting":
        acc_rgb, acc_w, acc_hit, acc_d = _zbuffer_splat_rgb(
            erp_h, erp_w, u_tgt, v_tgt, src_rgb_t, range_tgt, valid,
            z_eps_abs, z_eps_rel, z_min_w,
        )
    elif method == "zbuffer_point":
        acc_rgb, acc_w, acc_hit, acc_d = _zbuffer_point_rgb(
            erp_h, erp_w, u_tgt, v_tgt, src_rgb_t, range_tgt, valid,
            z_eps_abs, z_eps_rel,
        )
    else:
        acc_rgb, acc_w, acc_hit, acc_d = _adaptive_splat_rgb(
            erp_h, erp_w, u_tgt, v_tgt, src_rgb_t, range_tgt, valid,
            alpha, adaptive_radius, warp_cfg.get("occlusion_gate", None),
        )

    # Normalise the accumulated colour by the accumulated weight.
    denom = acc_w > 0.0
    out_rgb = torch.zeros_like(acc_rgb)
    out_rgb[denom] = acc_rgb[denom] / acc_w[denom].unsqueeze(-1)
    valid_mask = acc_hit > min_hit

    # Round rather than truncate. The weighted-average round trip can land just
    # below an integer level in float32, and truncation would drop it by one.
    warped_np = (out_rgb.clamp(0, 1) * 255).round().byte().cpu().numpy()
    mask_np = valid_mask.cpu().numpy().astype(np.uint8)
    weight_np = acc_hit.cpu().numpy().astype(np.float32)

    warped_depth_np = None
    if output_depth:
        out_d = torch.full((erp_h, erp_w), float("nan"), device=device)
        out_d[denom] = acc_d[denom] / torch.clamp(acc_w[denom], min=1e-9)
        out_d[~valid_mask] = float("nan")
        warped_depth_np = out_d.cpu().numpy().astype(np.float32)

    if hole_fill:
        warped_np, mask_np = _edge_aware_hole_fill(warped_np, mask_np, max_hole_px)

    flow_np = None
    if output_flow:
        # Per-source-pixel displacement. The horizontal component is wrapped
        # into [-W/2, W/2) so seam crossings yield short flows.
        du = u_tgt - uu
        du = (du + 0.5 * erp_w) % erp_w - 0.5 * erp_w
        dv = v_tgt - vv
        flow_np = torch.stack([du, dv], dim=-1).cpu().numpy().astype(np.float32)

    return WarpResult(
        warped_rgb=warped_np,
        valid_mask=mask_np,
        flow=flow_np,
        weight_sum=weight_np,
        warped_depth=warped_depth_np,
    )


def _depth_edge_radius_scale(
    src_depth_t: torch.Tensor,
    valid: torch.Tensor,
    threshold: float,
    min_scale: float,
) -> torch.Tensor:
    """Per-pixel splat-radius multiplier in [min_scale, 1] that shrinks at depth edges.

    The Sobel magnitude of log-depth is normalised by its maximum, and the
    scale is ``clamp(1 - gnorm / threshold, min_scale, 1)``. The ERP image is
    padded circularly in longitude, because the left and right columns are
    neighbours on the sphere, and by edge replication in latitude. The image
    border therefore no longer registers as a spurious depth edge, which zero
    padding caused and which also dominated the max normalisation.
    """
    from torch.nn.functional import conv2d

    device = src_depth_t.device
    sobel_x = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]],
                           dtype=torch.float32, device=device).view(1, 1, 3, 3)
    sobel_y = torch.tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]],
                           dtype=torch.float32, device=device).view(1, 1, 3, 3)

    # Invalid pixels take the median valid depth so they do not create edges.
    fill = src_depth_t[valid].median() if valid.any() else torch.ones_like(src_depth_t)
    safe_d = torch.where(valid, src_depth_t, fill)
    log_d = torch.log(torch.clamp(safe_d, min=0.1)).unsqueeze(0).unsqueeze(0)

    # Circular padding across the longitude seam, replicate padding at the top/bottom rows.
    log_d = torch.cat([log_d[..., -1:], log_d, log_d[..., :1]], dim=-1)
    log_d = torch.cat([log_d[..., :1, :], log_d, log_d[..., -1:, :]], dim=-2)

    gx = conv2d(log_d, sobel_x)[0, 0]
    gy = conv2d(log_d, sobel_y)[0, 0]
    grad = torch.sqrt(gx ** 2 + gy ** 2)
    gmax = grad.max()
    gnorm = grad / gmax if gmax > 1e-6 else torch.zeros_like(grad)
    scale = torch.clamp(1.0 - gnorm / max(threshold, 1e-6), min=min_scale, max=1.0)
    return torch.where(torch.isfinite(scale), scale, torch.ones_like(scale))
|
|
|
|
def create_comparison_image(
    warped_rgb: np.ndarray,
    valid_mask: np.ndarray,
    gt_rgb: Optional[np.ndarray] = None,
) -> np.ndarray:
    """Build a visualisation of a warp result, optionally paired with ground truth.

    Pixels where ``valid_mask == 0`` are blacked out in a copy of
    ``warped_rgb``; the input arrays are not modified. If ``gt_rgb`` is given,
    it is stacked below the masked warp along axis 0 (top: warped, bottom: GT).
    Otherwise only the masked warp is returned.
    """
    masked = warped_rgb.copy()
    masked[valid_mask == 0] = 0
    if gt_rgb is None:
        return masked
    return np.concatenate((masked, gt_rgb), axis=0)
|
|