| """Post-processing functions for segment predictions.""" |
| import numpy as np |
|
|
|
|
def snap_to_point_cloud(vertices, xyz, class_id, snap_radius=0.5,
                        target_classes=None, min_points=2):
    """Snap vertices to nearby point cloud clusters of specific semantic classes.

    Each vertex is replaced by the centroid of the target-class points lying
    strictly within ``snap_radius`` of it, provided at least ``min_points``
    such points exist; otherwise the vertex is left unchanged.

    Args:
        vertices: Vertex coordinates, one row per vertex.
        xyz: Point-cloud coordinates, one row per point, with the same number
            of columns as ``vertices``.
        class_id: Per-point class labels, usable as a boolean mask over the
            rows of ``xyz`` once compared against ``target_classes``.
        snap_radius: Euclidean search radius (points at exactly this distance
            are excluded).
        target_classes: Class labels eligible as snap targets.
            Defaults to ``[1, 2]``.
        min_points: Minimum number of in-radius target points required to
            snap a vertex. Defaults to 2.

    Returns:
        A new array of (possibly) snapped vertices; the input is not
        modified. Integer input is promoted to float64 so centroids are not
        truncated.

    Raises:
        ValueError: If ``min_points`` is less than 1.
    """
    if target_classes is None:
        target_classes = [1, 2]
    if min_points < 1:
        # With zero required points an empty neighbourhood would be "snapped"
        # to the mean of nothing, i.e. NaN.
        raise ValueError(f"min_points must be >= 1, got {min_points}")

    snapped = np.array(vertices, copy=True)
    if not np.issubdtype(snapped.dtype, np.floating):
        # Writing a centroid into an integer array would silently truncate it.
        snapped = snapped.astype(np.float64)

    mask = np.isin(class_id, target_classes)
    if mask.sum() < min_points:
        # No vertex can gather enough neighbours; skip the per-vertex search.
        return snapped

    target_pts = np.asarray(xyz)[mask]

    # One distance pass per vertex keeps memory at O(#points) instead of
    # materializing a full (#vertices, #points) distance matrix.
    for i, v in enumerate(vertices):
        dists = np.linalg.norm(target_pts - v, axis=-1)
        close = dists < snap_radius
        if close.sum() >= min_points:
            snapped[i] = target_pts[close].mean(axis=0)

    return snapped
|
|
|
|
def snap_horizontal(vertices, edges, max_slope=0.05, min_length=0.1):
    """Snap near-horizontal edges to be exactly horizontal.

    Coordinates 0 and 2 span the horizontal plane; coordinate 1 is height.
    An edge qualifies when its horizontal length exceeds ``min_length`` and
    its slope ``|dy| / horizontal_length`` is below ``max_slope``, both
    measured on the input geometry. Qualifying edges that share vertices are
    merged into connected groups. Every vertex in a group is assigned the
    group's mean height, so all qualifying edges end up exactly horizontal,
    even when they form chains. An isolated edge gets its midpoint height.

    Args:
        vertices: Vertex coordinates with at least 3 columns.
        edges: Iterable of ``(a, b)`` vertex-index pairs.
        max_slope: Maximum rise-over-run for an edge to be flattened.
        min_length: Minimum horizontal length for an edge to be considered;
            shorter (and purely vertical) edges are left alone. Values
            below 0 are treated as 0.

    Returns:
        A new array with adjusted heights; the input is not modified.
        Integer input is promoted to float64 so averages are not truncated.
    """
    verts = np.array(vertices, copy=True)
    if not np.issubdtype(verts.dtype, np.floating):
        # Writing an average into an integer array would silently truncate it.
        verts = verts.astype(np.float64)

    n = len(verts)
    parent = list(range(n))  # union-find forest over vertex indices

    def find(i):
        # Root lookup with path halving.
        while parent[i] != i:
            parent[i] = parent[parent[i]]
            i = parent[i]
        return i

    min_run = max(min_length, 0.0)  # keeps the slope division well-defined
    on_flat_edge = np.zeros(n, dtype=bool)
    for a, b in edges:
        a, b = int(a), int(b)
        # Heights are not rewritten until every edge has been classified,
        # so the outcome does not depend on edge order.
        dy = abs(verts[a, 1] - verts[b, 1])
        dxz = np.hypot(verts[a, 0] - verts[b, 0], verts[a, 2] - verts[b, 2])
        if dxz > min_run and dy / dxz < max_slope:
            root_a, root_b = find(a), find(b)
            if root_a != root_b:
                parent[root_a] = root_b
            on_flat_edge[a] = on_flat_edge[b] = True

    groups = {}
    for idx in np.flatnonzero(on_flat_edge):
        idx = int(idx)
        groups.setdefault(find(idx), []).append(idx)
    for members in groups.values():
        verts[members, 1] = verts[members, 1].mean()
    return verts
|
|