File size: 5,400 Bytes
5c1bb37
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
"""
位姿工具模块

位姿加载、验证和转换。

ERPT_native 位姿格式:
{
    "frame_id": int,
    "position": [x, y, z],          # 相机中心在世界坐标系的位置(米)
    "rotation_quaternion": [w, x, y, z]  # camera->world 旋转 (R_cw)
}
"""

import numpy as np
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, Any, List, Optional, Tuple

from .io_utils import load_json

import sys
sys.path.insert(0, str(Path(__file__).parent.parent))
from core.coordinate import (
    quat_wxyz_to_rotation_matrix,
    validate_rotation_matrix,
    orthonormalize_rotation,
)


@dataclass
class Pose:
    """
    Camera pose in the ERPT_native convention.

    Attributes:
        frame_id: frame index this pose belongs to.
        position: (3,) camera centre expressed in world coordinates.
        R_cw: (3, 3) rotation taking camera coordinates to world coordinates.
    """
    frame_id: int
    position: np.ndarray        # (3,) camera centre in world coordinates
    R_cw: np.ndarray            # (3, 3) camera->world rotation matrix

    @property
    def R_wc(self) -> np.ndarray:
        """World->camera rotation: the transpose of R_cw (its inverse when R_cw is a proper rotation)."""
        return self.R_cw.T

    def to_dict(self) -> Dict[str, Any]:
        """
        Serialize to a JSON-compatible dict in the ERPT_native layout.

        Returns:
            dict with "frame_id", "position" (list of 3 floats) and
            "rotation_quaternion" ([w, x, y, z] derived from R_cw).
        """
        from core.coordinate import rotation_matrix_to_quat_wxyz

        quat_wxyz = rotation_matrix_to_quat_wxyz(self.R_cw)
        position_list = self.position.tolist()
        return dict(
            frame_id=self.frame_id,
            position=position_list,
            rotation_quaternion=quat_wxyz.tolist(),
        )


def load_pose(path: Path) -> Pose:
    """
    Load a single ERPT_native pose file.

    The file is decoded with ``load_json`` and the resulting data is handed
    to ``parse_pose`` unchanged.

    Args:
        path: path to the pose JSON file.

    Returns:
        The parsed Pose instance.

    Raises:
        ValueError: propagated from ``parse_pose`` when the data is malformed.
    """
    return parse_pose(load_json(path))


def parse_pose(data: Dict[str, Any]) -> Pose:
    """
    Parse decoded ERPT_native pose data into a Pose.

    Args:
        data: decoded pose JSON object. Must contain "position" (3 numbers)
            and "rotation_quaternion" (4 numbers, [w, x, y, z]); "frame_id"
            is optional and defaults to 0.

    Returns:
        Pose whose position and R_cw are stored as float32.

    Raises:
        ValueError: if a required field is missing or has the wrong shape,
            if the quaternion contains NaN/inf, or if its norm is (near) zero.
    """
    # Required fields
    if "position" not in data:
        raise ValueError("Pose requires 'position' field")
    if "rotation_quaternion" not in data:
        raise ValueError("Pose requires 'rotation_quaternion' field")

    frame_id = data.get("frame_id", 0)
    position = np.array(data["position"], dtype=np.float64)
    quat = np.array(data["rotation_quaternion"], dtype=np.float64)

    # Shape validation
    if position.shape != (3,):
        raise ValueError(f"position shape {position.shape}, expected (3,)")
    if quat.shape != (4,):
        raise ValueError(f"rotation_quaternion shape {quat.shape}, expected (4,)")

    # A NaN/inf component would slip past the norm guard below (NaN < 1e-9 is
    # False, and inf / inf is NaN) and silently produce a NaN rotation matrix.
    if not np.all(np.isfinite(quat)):
        raise ValueError(f"rotation_quaternion contains non-finite values: {quat}")

    # Normalize so slightly de-normalized input still maps to a rotation.
    quat_norm = np.linalg.norm(quat)
    if quat_norm < 1e-9:
        raise ValueError(f"Quaternion norm too small: {quat_norm}")
    quat = quat / quat_norm

    # Unit quaternion [w, x, y, z] -> camera->world rotation matrix
    R_cw = quat_wxyz_to_rotation_matrix(quat)

    return Pose(
        frame_id=int(frame_id),
        position=position.astype(np.float32),
        R_cw=R_cw.astype(np.float32),
    )


def validate_pose(pose: Pose, strict: bool = True) -> Tuple[bool, str]:
    """
    Check that a pose is usable.

    Args:
        pose: Pose to check. In non-strict mode its R_cw may be replaced in
            place by a repaired (orthonormalized, float32) matrix.
        strict: if True, an invalid rotation is rejected immediately; if
            False, a repair via ``orthonormalize_rotation`` is attempted and
            accepted only when the repaired matrix passes validation.

    Returns:
        (is_valid, message)
    """
    # Check the rotation matrix
    is_valid, msg = validate_rotation_matrix(pose.R_cw)
    if not is_valid:
        if strict:
            return False, msg
        # Attempt a repair, but verify it: without this re-check, non-strict
        # mode would report "Valid pose" for a matrix that is still invalid.
        repaired = orthonormalize_rotation(pose.R_cw).astype(np.float32)
        repaired_ok, repaired_msg = validate_rotation_matrix(repaired)
        if not repaired_ok:
            return False, f"{msg}; repair failed: {repaired_msg}"
        pose.R_cw = repaired

    # The camera centre must be finite
    if not np.all(np.isfinite(pose.position)):
        return False, f"Position contains non-finite values: {pose.position}"

    return True, "Valid pose"


def load_all_poses(pose_dir: Path) -> List[Pose]:
    """
    Load every ``pose_*.json`` file in a directory.

    Loading is best effort. A file that fails to load or parse is skipped and a
    warning is printed, so one corrupt file does not abort the whole load.

    Args:
        pose_dir: directory containing the pose files (``Path`` or ``str``).

    Returns:
        poses: the successfully loaded poses sorted by frame_id. The sort is
            stable, so equal frame_ids keep file-name order. The list is empty
            if no file matches.
    """
    # Accept plain string paths as well as Path objects.
    pose_files = sorted(Path(pose_dir).glob("pose_*.json"))

    poses = []
    for path in pose_files:
        try:
            poses.append(load_pose(path))
        except Exception as e:  # deliberate best effort: report and skip
            print(f"[Warning] Failed to load {path}: {e}")

    # Sort by frame_id
    poses.sort(key=lambda p: p.frame_id)

    return poses


def compute_relative_pose(
    pose_src: Pose,
    pose_tgt: Pose,
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Relative transform taking source-camera coordinates to target-camera coordinates.

    Each pose satisfies X_w = R_cw @ X_c + C, where C is ``position``.
    Substituting the source pose into X_tgt = R_wc_tgt @ (X_w - C_tgt) gives
    X_tgt = R_rel @ X_src + t_rel.

    Args:
        pose_src: source pose
        pose_tgt: target pose

    Returns:
        R_rel: (3, 3) relative rotation R_tgt_src = R_wc_tgt @ R_cw_src
        t_rel: (3,) source camera centre in the target camera frame,
            R_wc_tgt @ (C_src - C_tgt)
    """
    world_to_tgt = pose_tgt.R_wc
    centre_offset = pose_src.position - pose_tgt.position
    return world_to_tgt @ pose_src.R_cw, world_to_tgt @ centre_offset


def compute_translation_distance(pose1: Pose, pose2: Pose) -> float:
    """
    Euclidean distance between the camera centres of two poses.

    Args:
        pose1, pose2: Pose instances

    Returns:
        distance: straight-line distance between the two ``position`` vectors
            (metres, per the ERPT_native format).
    """
    offset = pose1.position - pose2.position
    distance = np.linalg.norm(offset)
    return float(distance)


def compute_rotation_angle(pose1: Pose, pose2: Pose) -> float:
    """
    Angle of the relative rotation between two poses.

    The relative rotation is R_rel = R_wc1 @ R_cw2, evaluated in float64. Its
    angle is atan2(sin, cos), where cos = (trace(R_rel) - 1) / 2 and
    sin = ||vee(R_rel - R_rel^T)|| / 2.

    Taking arccos of the trace alone loses precision near 0 degrees. With
    float32 matrices, cos rounds to exactly 1 for angles around 0.01 degrees
    and below, so such angles were reported as 0.

    Args:
        pose1, pose2: Pose instances

    Returns:
        angle: rotation angle in degrees, in [0, 180].
    """
    R_rel = np.asarray(pose1.R_wc, dtype=np.float64) @ np.asarray(pose2.R_cw, dtype=np.float64)
    cos_theta = (np.trace(R_rel) - 1.0) / 2.0
    # vee(R - R^T) = 2 * sin(theta) * axis for a rotation matrix
    skew_vec = np.array([
        R_rel[2, 1] - R_rel[1, 2],
        R_rel[0, 2] - R_rel[2, 0],
        R_rel[1, 0] - R_rel[0, 1],
    ])
    sin_theta = np.linalg.norm(skew_vec) / 2.0
    # sin_theta >= 0, so atan2 lands in [0, pi] without needing a clip.
    angle_rad = np.arctan2(sin_theta, cos_theta)
    return float(np.degrees(angle_rad))