bdck commited on
Commit
191e774
·
verified ·
1 Parent(s): 3fe5c47

Upload learn_region_grow/utils.py

Browse files
Files changed (1) hide show
  1. learn_region_grow/utils.py +120 -0
learn_region_grow/utils.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Utility helpers: sampling, centering, metrics."""
2
+
3
+ import numpy as np
4
+ from typing import Tuple
5
+
6
+
7
+ def sample_or_pad(points: np.ndarray, target: int, seed: int = None) -> np.ndarray:
8
+ """
9
+ Sample exactly `target` points from `points` (N, C) with replacement if needed.
10
+
11
+ Parameters
12
+ ----------
13
+ points : np.ndarray, shape (N, C)
14
+ target : int
15
+ Number of points to return.
16
+ seed : int, optional
17
+ RNG seed for reproducibility.
18
+
19
+ Returns
20
+ -------
21
+ out : np.ndarray, shape (target, C)
22
+ """
23
+ n = len(points)
24
+ if n == 0:
25
+ return np.zeros((target, points.shape[1]), dtype=np.float32)
26
+ rng = np.random.default_rng(seed)
27
+ if n >= target:
28
+ idx = rng.choice(n, target, replace=False)
29
+ else:
30
+ idx = rng.choice(n, target, replace=True)
31
+ return points[idx]
32
+
33
+
34
+ def center_features(inliers: np.ndarray, neighbors: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
35
+ """
36
+ Translation / feature centering used by LrgNet.
37
+
38
+ Subtracts the median XY of the inlier set from XYZ of both inliers and neighbors.
39
+ Subtracts the median of channels 6+ (colors, normals, curvature) from both sets.
40
+
41
+ Parameters
42
+ ----------
43
+ inliers : np.ndarray, shape (Ni, 13)
44
+ neighbors: np.ndarray, shape (Nn, 13)
45
+
46
+ Returns
47
+ -------
48
+ inliers_c, neighbors_c : same shapes, centered.
49
+ """
50
+ inliers = inliers.copy()
51
+ neighbors = neighbors.copy()
52
+
53
+ # Median XY of inliers
54
+ med_xy = np.median(inliers[:, :2], axis=0)
55
+ inliers[:, :2] -= med_xy
56
+ neighbors[:, :2] -= med_xy
57
+
58
+ # Median of channels 6+ (colors, normals, curvature) from inliers
59
+ med_feat = np.median(inliers[:, 6:], axis=0)
60
+ inliers[:, 6:] -= med_feat
61
+ neighbors[:, 6:] -= med_feat
62
+
63
+ return inliers, neighbors
64
+
65
+
66
+ def cluster_metrics(pred_labels: np.ndarray, gt_labels: np.ndarray) -> dict:
67
+ """
68
+ Compute clustering quality metrics (NMI, AMI, ARS, Precision, Recall, IoU).
69
+
70
+ Parameters
71
+ ----------
72
+ pred_labels : np.ndarray, shape (N,)
73
+ Predicted instance IDs (0 = unlabeled / background).
74
+ gt_labels : np.ndarray, shape (N,)
75
+ Ground-truth instance IDs.
76
+
77
+ Returns
78
+ -------
79
+ dict with keys: nmi, ami, ars, precision, recall, iou.
80
+ """
81
+ from sklearn.metrics import normalized_mutual_info_score, adjusted_mutual_info_score, adjusted_rand_score
82
+
83
+ mask = gt_labels >= 0
84
+ pred = pred_labels[mask]
85
+ gt = gt_labels[mask]
86
+
87
+ if len(pred) == 0:
88
+ return {"nmi": 0.0, "ami": 0.0, "ars": 0.0, "precision": 0.0, "recall": 0.0, "iou": 0.0}
89
+
90
+ nmi = normalized_mutual_info_score(gt, pred)
91
+ ami = adjusted_mutual_info_score(gt, pred)
92
+ ars = adjusted_rand_score(gt, pred)
93
+
94
+ # Precision / Recall / IoU over pairs
95
+ pred_pairs = set()
96
+ gt_pairs = set()
97
+ n = len(pred)
98
+ for i in range(n):
99
+ for j in range(i + 1, n):
100
+ if pred[i] == pred[j] and pred[i] >= 0:
101
+ pred_pairs.add((i, j))
102
+ if gt[i] == gt[j] and gt[i] >= 0:
103
+ gt_pairs.add((i, j))
104
+
105
+ tp = len(pred_pairs & gt_pairs)
106
+ fp = len(pred_pairs - gt_pairs)
107
+ fn = len(gt_pairs - pred_pairs)
108
+
109
+ precision = tp / (tp + fp + 1e-8)
110
+ recall = tp / (tp + fn + 1e-8)
111
+ iou = tp / (tp + fp + fn + 1e-8)
112
+
113
+ return {
114
+ "nmi": nmi,
115
+ "ami": ami,
116
+ "ars": ars,
117
+ "precision": precision,
118
+ "recall": recall,
119
+ "iou": iou,
120
+ }