Upload learn_region_grow/utils.py
Browse files- learn_region_grow/utils.py +120 -0
learn_region_grow/utils.py
ADDED
|
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Utility helpers: sampling, centering, metrics."""
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
from typing import Tuple
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def sample_or_pad(points: np.ndarray, target: int, seed: int = None) -> np.ndarray:
|
| 8 |
+
"""
|
| 9 |
+
Sample exactly `target` points from `points` (N, C) with replacement if needed.
|
| 10 |
+
|
| 11 |
+
Parameters
|
| 12 |
+
----------
|
| 13 |
+
points : np.ndarray, shape (N, C)
|
| 14 |
+
target : int
|
| 15 |
+
Number of points to return.
|
| 16 |
+
seed : int, optional
|
| 17 |
+
RNG seed for reproducibility.
|
| 18 |
+
|
| 19 |
+
Returns
|
| 20 |
+
-------
|
| 21 |
+
out : np.ndarray, shape (target, C)
|
| 22 |
+
"""
|
| 23 |
+
n = len(points)
|
| 24 |
+
if n == 0:
|
| 25 |
+
return np.zeros((target, points.shape[1]), dtype=np.float32)
|
| 26 |
+
rng = np.random.default_rng(seed)
|
| 27 |
+
if n >= target:
|
| 28 |
+
idx = rng.choice(n, target, replace=False)
|
| 29 |
+
else:
|
| 30 |
+
idx = rng.choice(n, target, replace=True)
|
| 31 |
+
return points[idx]
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def center_features(inliers: np.ndarray, neighbors: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
|
| 35 |
+
"""
|
| 36 |
+
Translation / feature centering used by LrgNet.
|
| 37 |
+
|
| 38 |
+
Subtracts the median XY of the inlier set from XYZ of both inliers and neighbors.
|
| 39 |
+
Subtracts the median of channels 6+ (colors, normals, curvature) from both sets.
|
| 40 |
+
|
| 41 |
+
Parameters
|
| 42 |
+
----------
|
| 43 |
+
inliers : np.ndarray, shape (Ni, 13)
|
| 44 |
+
neighbors: np.ndarray, shape (Nn, 13)
|
| 45 |
+
|
| 46 |
+
Returns
|
| 47 |
+
-------
|
| 48 |
+
inliers_c, neighbors_c : same shapes, centered.
|
| 49 |
+
"""
|
| 50 |
+
inliers = inliers.copy()
|
| 51 |
+
neighbors = neighbors.copy()
|
| 52 |
+
|
| 53 |
+
# Median XY of inliers
|
| 54 |
+
med_xy = np.median(inliers[:, :2], axis=0)
|
| 55 |
+
inliers[:, :2] -= med_xy
|
| 56 |
+
neighbors[:, :2] -= med_xy
|
| 57 |
+
|
| 58 |
+
# Median of channels 6+ (colors, normals, curvature) from inliers
|
| 59 |
+
med_feat = np.median(inliers[:, 6:], axis=0)
|
| 60 |
+
inliers[:, 6:] -= med_feat
|
| 61 |
+
neighbors[:, 6:] -= med_feat
|
| 62 |
+
|
| 63 |
+
return inliers, neighbors
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def cluster_metrics(pred_labels: np.ndarray, gt_labels: np.ndarray) -> dict:
|
| 67 |
+
"""
|
| 68 |
+
Compute clustering quality metrics (NMI, AMI, ARS, Precision, Recall, IoU).
|
| 69 |
+
|
| 70 |
+
Parameters
|
| 71 |
+
----------
|
| 72 |
+
pred_labels : np.ndarray, shape (N,)
|
| 73 |
+
Predicted instance IDs (0 = unlabeled / background).
|
| 74 |
+
gt_labels : np.ndarray, shape (N,)
|
| 75 |
+
Ground-truth instance IDs.
|
| 76 |
+
|
| 77 |
+
Returns
|
| 78 |
+
-------
|
| 79 |
+
dict with keys: nmi, ami, ars, precision, recall, iou.
|
| 80 |
+
"""
|
| 81 |
+
from sklearn.metrics import normalized_mutual_info_score, adjusted_mutual_info_score, adjusted_rand_score
|
| 82 |
+
|
| 83 |
+
mask = gt_labels >= 0
|
| 84 |
+
pred = pred_labels[mask]
|
| 85 |
+
gt = gt_labels[mask]
|
| 86 |
+
|
| 87 |
+
if len(pred) == 0:
|
| 88 |
+
return {"nmi": 0.0, "ami": 0.0, "ars": 0.0, "precision": 0.0, "recall": 0.0, "iou": 0.0}
|
| 89 |
+
|
| 90 |
+
nmi = normalized_mutual_info_score(gt, pred)
|
| 91 |
+
ami = adjusted_mutual_info_score(gt, pred)
|
| 92 |
+
ars = adjusted_rand_score(gt, pred)
|
| 93 |
+
|
| 94 |
+
# Precision / Recall / IoU over pairs
|
| 95 |
+
pred_pairs = set()
|
| 96 |
+
gt_pairs = set()
|
| 97 |
+
n = len(pred)
|
| 98 |
+
for i in range(n):
|
| 99 |
+
for j in range(i + 1, n):
|
| 100 |
+
if pred[i] == pred[j] and pred[i] >= 0:
|
| 101 |
+
pred_pairs.add((i, j))
|
| 102 |
+
if gt[i] == gt[j] and gt[i] >= 0:
|
| 103 |
+
gt_pairs.add((i, j))
|
| 104 |
+
|
| 105 |
+
tp = len(pred_pairs & gt_pairs)
|
| 106 |
+
fp = len(pred_pairs - gt_pairs)
|
| 107 |
+
fn = len(gt_pairs - pred_pairs)
|
| 108 |
+
|
| 109 |
+
precision = tp / (tp + fp + 1e-8)
|
| 110 |
+
recall = tp / (tp + fn + 1e-8)
|
| 111 |
+
iou = tp / (tp + fp + fn + 1e-8)
|
| 112 |
+
|
| 113 |
+
return {
|
| 114 |
+
"nmi": nmi,
|
| 115 |
+
"ami": ami,
|
| 116 |
+
"ars": ars,
|
| 117 |
+
"precision": precision,
|
| 118 |
+
"recall": recall,
|
| 119 |
+
"iou": iou,
|
| 120 |
+
}
|