cmevs-code / utils /pose_utils.py
anon-cmevs-2026's picture
Initial code release for NeurIPS 2026 D&B reviewer reference
5c1bb37 verified
"""
位姿工具模块
位姿加载、验证和转换。
ERPT_native 位姿格式:
{
"frame_id": int,
"position": [x, y, z], # 相机中心在世界坐标系的位置(米)
"rotation_quaternion": [w, x, y, z] # camera->world 旋转 (R_cw)
}
"""
import numpy as np
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, Any, List, Optional, Tuple
from .io_utils import load_json
import sys
sys.path.insert(0, str(Path(__file__).parent.parent))
from core.coordinate import (
quat_wxyz_to_rotation_matrix,
validate_rotation_matrix,
orthonormalize_rotation,
)
@dataclass
class Pose:
    """Camera pose: a frame index, a camera centre and a camera->world rotation."""

    frame_id: int
    position: np.ndarray  # (3,) camera centre expressed in world coordinates
    R_cw: np.ndarray  # (3, 3) camera->world rotation matrix

    @property
    def R_wc(self) -> np.ndarray:
        """World->camera rotation, obtained as the transpose of ``R_cw``."""
        cam_to_world = self.R_cw
        return cam_to_world.T

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the pose to a plain dictionary.

        Returns:
            Dict with keys "frame_id", "position" (list) and
            "rotation_quaternion" (list produced from ``R_cw`` by
            ``rotation_matrix_to_quat_wxyz``).
        """
        # Imported lazily, exactly as the module does for this helper.
        from core.coordinate import rotation_matrix_to_quat_wxyz

        wxyz = rotation_matrix_to_quat_wxyz(self.R_cw)
        payload: Dict[str, Any] = {"frame_id": self.frame_id}
        payload["position"] = self.position.tolist()
        payload["rotation_quaternion"] = wxyz.tolist()
        return payload
def load_pose(path: Path) -> Pose:
    """Read a pose JSON file and turn it into a Pose.

    Args:
        path: Location of the pose JSON file.

    Returns:
        The Pose built by ``parse_pose`` from the loaded JSON content.
    """
    return parse_pose(load_json(path))
def parse_pose(data: Dict[str, Any]) -> Pose:
    """Parse a decoded pose dictionary into a Pose.

    Args:
        data: Pose JSON data. Must contain "position" (3 values) and
            "rotation_quaternion" (4 values, w-x-y-z order, camera->world).
            "frame_id" is optional and defaults to 0.

    Returns:
        Pose whose ``position`` and ``R_cw`` are stored as float32.

    Raises:
        ValueError: If a required field is missing, a field has the wrong
            shape, or the quaternion norm is non-finite or below 1e-9.
    """
    # Required fields (checked in a fixed order so the first error is stable).
    for field in ("position", "rotation_quaternion"):
        if field not in data:
            raise ValueError(f"Pose requires '{field}' field")

    frame_id = data.get("frame_id", 0)
    # Parse in float64 so normalization is done at full precision.
    position = np.array(data["position"], dtype=np.float64)
    quat = np.array(data["rotation_quaternion"], dtype=np.float64)

    # Shape validation.
    if position.shape != (3,):
        raise ValueError(f"position shape {position.shape}, expected (3,)")
    if quat.shape != (4,):
        raise ValueError(f"rotation_quaternion shape {quat.shape}, expected (4,)")

    # Quaternion normalization.
    quat_norm = np.linalg.norm(quat)
    # A NaN/inf component yields a non-finite norm; `nan < 1e-9` is False, so
    # without this check such a quaternion would be turned into a NaN rotation.
    if not np.isfinite(quat_norm):
        raise ValueError(f"Quaternion norm is not finite: {quat_norm}")
    if quat_norm < 1e-9:
        raise ValueError(f"Quaternion norm too small: {quat_norm}")
    quat = quat / quat_norm

    # Convert to a rotation matrix.
    R_cw = quat_wxyz_to_rotation_matrix(quat)

    return Pose(
        frame_id=int(frame_id),
        position=position.astype(np.float32),
        R_cw=R_cw.astype(np.float32),
    )
def validate_pose(pose: Pose, strict: bool = True) -> Tuple[bool, str]:
    """Check that a pose has a valid rotation and a finite position.

    Args:
        pose: Pose to check. In non-strict mode its ``R_cw`` may be
            overwritten in place.
        strict: When True, a rotation rejected by
            ``validate_rotation_matrix`` makes the pose invalid. When False,
            the rotation is instead replaced by the output of
            ``orthonormalize_rotation`` (cast to float32); the replacement is
            not validated again.

    Returns:
        (is_valid, message)
    """
    rotation_ok, rotation_msg = validate_rotation_matrix(pose.R_cw)

    # Strict mode: report the rotation problem as-is.
    if not rotation_ok and strict:
        return False, rotation_msg

    # Lenient mode: attempt a repair and carry on.
    if not rotation_ok:
        pose.R_cw = orthonormalize_rotation(pose.R_cw).astype(np.float32)

    # Position must contain only finite values.
    if np.all(np.isfinite(pose.position)):
        return True, "Valid pose"
    return False, f"Position contains non-finite values: {pose.position}"
def load_all_poses(pose_dir: Path) -> List[Pose]:
    """Load every ``pose_*.json`` file found directly inside a directory.

    Files that fail to load are reported with a printed warning and skipped.

    Args:
        pose_dir: Directory containing the pose files.

    Returns:
        poses: Successfully loaded poses, ordered by ``frame_id``.
    """
    loaded = []
    for pose_path in sorted(pose_dir.glob("pose_*.json")):
        try:
            loaded.append(load_pose(pose_path))
        except Exception as err:
            # Best effort: one bad file must not abort the whole directory.
            print(f"[Warning] Failed to load {pose_path}: {err}")

    # Order by frame_id (stable, so ties keep filename order).
    return sorted(loaded, key=lambda item: item.frame_id)
def compute_relative_pose(
    pose_src: Pose,
    pose_tgt: Pose,
) -> Tuple[np.ndarray, np.ndarray]:
    """Compute the relative pose taking the source view to the target view.

    Args:
        pose_src: Source pose.
        pose_tgt: Target pose.

    Returns:
        R_rel: (3, 3) relative rotation R_tgt_src.
        t_rel: (3,) relative translation, i.e. the source camera centre
            expressed in the target camera frame.
    """
    world_to_tgt = pose_tgt.R_wc
    centre_offset = pose_src.position - pose_tgt.position

    # R_tgt_src = R_wc_tgt @ R_cw_src
    rotation = world_to_tgt @ pose_src.R_cw
    # t_rel = R_wc_tgt @ (t_src - t_tgt)
    translation = world_to_tgt @ centre_offset
    return rotation, translation
def compute_translation_distance(pose1: Pose, pose2: Pose) -> float:
    """Return the Euclidean distance between two pose positions.

    Args:
        pose1, pose2: Pose instances.

    Returns:
        distance: Norm of the difference of the two positions, in the units
            of ``position`` (metres per the module's pose format).
    """
    offset = pose1.position - pose2.position
    return float(np.linalg.norm(offset))
def compute_rotation_angle(pose1: Pose, pose2: Pose) -> float:
    """Return the angle of the rotation relating two poses.

    Args:
        pose1, pose2: Pose instances.

    Returns:
        angle: Rotation angle in degrees, in the range [0, 180].
    """
    relative = pose1.R_wc @ pose2.R_cw

    # For a rotation by theta, trace(R) = 1 + 2 * cos(theta).
    cosine = (np.trace(relative) - 1.0) / 2.0
    # Clamp so rounding error cannot push arccos outside its domain.
    cosine = np.clip(cosine, -1.0, 1.0)

    return float(np.degrees(np.arccos(cosine)))