bdck committed on
Commit 4d5a18c · verified · 1 Parent(s): dccd28a

Upload point_sam/inference.py

Files changed (1)
  1. point_sam/inference.py +402 -0
point_sam/inference.py ADDED
@@ -0,0 +1,402 @@
"""High-level inference API for Point-SAM.

This module provides a clean, hydra-free interface for running Point-SAM
segmentation on point clouds loaded from PLY or PCD files.
"""

import os
import warnings
from typing import Union, Tuple, Optional

import numpy as np
import torch
import torch.nn as nn
from safetensors.torch import load_model as load_safetensors_model

from .model.pc_encoder import PatchEmbed, PointCloudEncoder
from .model.pc_sam import PointCloudSAM
from .model.prompt_encoder import MaskEncoder, PointEncoder
from .model.mask_decoder import MaskDecoder
from .model.transformer import TwoWayTransformer
from .utils.torch_utils import replace_with_fused_layernorm


def _load_ply_ascii(filename: str) -> np.ndarray:
    """Load an ASCII PLY file with xyzrgb columns."""
    with open(filename, "r") as rf:
        num_of_points = None
        while True:
            line = rf.readline()
            if not line:
                break
            if "end_header" in line:
                break
            if "element vertex" in line:
                num_of_points = int(line.split()[2])
        if num_of_points is None:
            raise ValueError(f"Could not parse vertex count from PLY header: {filename}")
        points = np.zeros([num_of_points, 6], dtype=np.float32)
        for i in range(num_of_points):
            point = rf.readline().split()
            if len(point) < 6:
                raise ValueError(
                    f"Line {i} in PLY has fewer than 6 values ({len(point)})."
                )
            points[i] = [float(p) for p in point[:6]]
    return points

def _load_pcd_ascii(filename: str) -> np.ndarray:
    """Load an ASCII PCD file with xyzrgb columns."""
    with open(filename, "r") as rf:
        header_ended = False
        num_of_points = None
        data_mode = None
        while True:
            line = rf.readline()
            if not line:
                break
            line = line.strip()
            if line.startswith("POINTS "):
                num_of_points = int(line.split()[1])
            if line.startswith("DATA "):
                data_mode = line.split()[1]
                header_ended = True
                break
        if num_of_points is None:
            raise ValueError(f"Could not parse POINTS from PCD header: {filename}")
        if data_mode != "ascii":
            raise ValueError(f"Only ASCII PCD is supported; got DATA {data_mode}")
        points = np.zeros([num_of_points, 6], dtype=np.float32)
        for i in range(num_of_points):
            point = rf.readline().split()
            if len(point) < 6:
                # Some PCD files store x y z rgb as single float — try unpacking
                if len(point) == 4:
                    # x y z rgb (packed) — unpack rgb into r g b
                    rgb_packed = float(point[3])
                    rgb_int = int(rgb_packed)
                    r = ((rgb_int >> 16) & 0xFF)
                    g = ((rgb_int >> 8) & 0xFF)
                    b = (rgb_int & 0xFF)
                    points[i] = [float(point[0]), float(point[1]), float(point[2]), r, g, b]
                else:
                    raise ValueError(
                        f"Line {i} in PCD has fewer than 6 values ({len(point)})."
                    )
            else:
                points[i] = [float(p) for p in point[:6]]
    return points

def load_pointcloud(
    filepath: str,
    normalize: bool = True,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Load a point cloud from a PLY or PCD file.

    Args:
        filepath: Path to .ply or .pcd file.
        normalize: Whether to normalize coordinates to a unit sphere in [-1, 1].

    Returns:
        coords: [N, 3] numpy array of coordinates.
        rgb: [N, 3] numpy array of colors in [0, 255].
        original_coords: [N, 3] un-normalized coordinates.
    """
    ext = os.path.splitext(filepath)[1].lower()
    if ext == ".ply":
        points = _load_ply_ascii(filepath)
    elif ext == ".pcd":
        points = _load_pcd_ascii(filepath)
    else:
        raise ValueError(f"Unsupported file extension: {ext}. Use .ply or .pcd")

    original_coords = points[:, :3].copy()
    rgb = points[:, 3:6].copy()

    # If colors look very small, they may already be in [0, 1]
    if rgb.max() <= 1.0 + 1e-6:
        rgb = rgb * 255.0

    if normalize:
        coords = original_coords - original_coords.mean(axis=0)
        max_norm = np.linalg.norm(coords, axis=1).max()
        if max_norm > 1e-8:
            coords = coords / max_norm
    else:
        coords = original_coords

    return coords, rgb, original_coords

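# Example of the loader's contract (illustrative; "scene.ply" is a placeholder path):
#
#   coords, rgb, original = load_pointcloud("scene.ply")
#   # coords:   (N, 3) centered and scaled into the unit sphere
#   # rgb:      (N, 3) colors in [0, 255]
#   # original: (N, 3) raw, un-normalized coordinates
#
# Prompt points later passed to PointSAM.predict must live in the same normalized
# space as `coords`, e.g. a row taken directly from `coords`.
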
def build_point_sam(
    variant: str = "large",
    embed_dim: int = 256,
    device: str = "cuda",
    use_fused_layernorm: bool = False,
) -> PointCloudSAM:
    """Build a Point-SAM model from scratch (no hydra/omegaconf required).

    Args:
        variant: Model size — "large" or "giant".
        embed_dim: Embedding dimension for the decoder.
        device: torch device to place the model on.
        use_fused_layernorm: Whether to replace LayerNorm with apex FusedLayerNorm.
            Requires apex to be installed.

    Returns:
        PointCloudSAM model on the requested device.
    """
    import timm

    if variant == "large":
        model_name = "eva02_large_patch14_448"
        num_patches = 1024
        patch_size = 256
        prompt_iters = 5
    elif variant == "giant":
        model_name = "eva_giant_patch14_560"
        num_patches = 512
        patch_size = 64
        prompt_iters = 10
    else:
        raise ValueError(f"Unknown variant: {variant}. Choose 'large' or 'giant'.")

    # Build encoder
    transformer_encoder = timm.create_model(model_name, pretrained=False)
    patch_embed = PatchEmbed(
        in_channels=6,
        out_channels=512,
        num_patches=num_patches,
        patch_size=patch_size,
    )
    pc_encoder = PointCloudEncoder(
        patch_embed=patch_embed,
        transformer=transformer_encoder,
        embed_dim=embed_dim,
    )

    # Build prompt encoder
    mask_encoder = MaskEncoder(embed_dim=embed_dim)

    # Build decoder
    transformer_decoder = TwoWayTransformer(
        depth=2,
        embedding_dim=embed_dim,
        num_heads=8,
        mlp_dim=2048,
    )
    mask_decoder = MaskDecoder(
        transformer_dim=embed_dim,
        transformer=transformer_decoder,
        num_multimask_outputs=3,
        iou_head_depth=3,
        iou_head_hidden_dim=256,
    )

    # Assemble full model
    model = PointCloudSAM(
        pc_encoder=pc_encoder,
        mask_encoder=mask_encoder,
        mask_decoder=mask_decoder,
        prompt_iters=prompt_iters,
    )

    if use_fused_layernorm:
        if replace_with_fused_layernorm is None:
            warnings.warn("apex FusedLayerNorm requested but not available. Skipping.")
        else:
            model = replace_with_fused_layernorm(model)

    model = model.to(device)
    return model

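# Note (illustrative): build_point_sam only assembles the architecture, so its
# weights are random until a checkpoint is loaded. For inference, prefer
# PointSAM.from_pretrained below, which calls this builder and then loads a local
# .safetensors checkpoint:
#
#   model = build_point_sam(variant="large", device="cuda")  # randomly initialized
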
class PointSAM:
    """User-friendly wrapper around PointCloudSAM for single-file inference.

    Typical usage:
        >>> psam = PointSAM.from_pretrained("model.safetensors", device="cuda")
        >>> coords, rgb, original = load_pointcloud("scene.ply")
        >>> masks, ious = psam.predict(coords, rgb, prompt_point=[0.5, 0.1, -0.2])
    """

    def __init__(
        self,
        model: PointCloudSAM,
        device: str = "cuda",
        variant: str = "large",
    ):
        self.model = model
        self.device = device
        self.variant = variant
        self._pc_cache: Optional[Tuple[torch.Tensor, torch.Tensor]] = None

    @classmethod
    def from_pretrained(
        cls,
        checkpoint_path: Optional[str] = None,
        variant: str = "large",
        device: str = "cuda",
        use_fused_layernorm: bool = False,
    ) -> "PointSAM":
        """Load a Point-SAM model from a local .safetensors checkpoint.

        Args:
            checkpoint_path: Local path to a .safetensors checkpoint.
                If None, the model is initialized with random weights.
            variant: "large" or "giant".
            device: torch device.
            use_fused_layernorm: Whether to use apex FusedLayerNorm.

        Returns:
            PointSAM wrapper ready for inference.
        """
        model = build_point_sam(
            variant=variant,
            device=device,
            use_fused_layernorm=use_fused_layernorm,
        )
        if checkpoint_path is not None:
            load_safetensors_model(model, checkpoint_path)
            print(f"Loaded checkpoint from {checkpoint_path}")
        else:
            warnings.warn(
                "No checkpoint provided — model weights are randomly initialized!"
            )
        model.eval()
        return cls(model=model, device=device, variant=variant)

    def set_pointcloud(
        self,
        coords: Union[np.ndarray, torch.Tensor],
        rgb: Union[np.ndarray, torch.Tensor],
    ):
        """Cache a point cloud for repeated segmentation queries.

        This prepares the tensors once (adds a batch dimension, normalizes colors,
        moves them to the device) and caches them so that subsequent `predict`
        calls with different prompt points can pass `coords=None` and reuse them.

        Args:
            coords: [N, 3] coordinates (normalized to [-1, 1]).
            rgb: [N, 3] colors in [0, 255].
        """
        if isinstance(coords, np.ndarray):
            coords = torch.from_numpy(coords).float()
        if isinstance(rgb, np.ndarray):
            rgb = torch.from_numpy(rgb).float()

        # Ensure batch dim and normalize colors
        if coords.dim() == 2:
            coords = coords.unsqueeze(0)  # [1, N, 3]
        if rgb.dim() == 2:
            rgb = rgb.unsqueeze(0)  # [1, N, 3]

        if rgb.max() > 1.0 + 1e-6:
            rgb = rgb / 255.0

        coords = coords.to(self.device)
        rgb = rgb.to(self.device)

        self._pc_cache = (coords, rgb)

    def predict(
        self,
        coords: Optional[Union[np.ndarray, torch.Tensor]],
        rgb: Optional[Union[np.ndarray, torch.Tensor]],
        prompt_point: Union[list, tuple, np.ndarray, torch.Tensor],
        prompt_label: int = 1,
        multimask_output: bool = True,
        return_logits: bool = False,
    ) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
        """Run segmentation on a point cloud given a prompt point.

        Args:
            coords: [N, 3] normalized coordinates, or None to use the point cloud
                cached via `set_pointcloud`.
            rgb: [N, 3] colors in [0, 255]. Ignored (and may be None) when the
                cached point cloud is used.
            prompt_point: [x, y, z] coordinate of the prompt. Must be in the same
                normalized space as `coords` (i.e., [-1, 1] if you used the default
                `load_pointcloud` normalization).
            prompt_label: 1 for foreground (positive), 0 for background (negative).
            multimask_output: If True, return 3 mask candidates + IoU scores.
                If False, return a single mask.
            return_logits: If True, return raw logits instead of a boolean mask.

        Returns:
            If multimask_output=False:
                mask: [N] boolean array (or float logits if return_logits=True).
            If multimask_output=True:
                masks: [3, N] boolean array of candidate masks.
                iou_preds: [3] IoU confidence scores for each candidate.
        """
        # Use cached point cloud if available and coords wasn't passed fresh
        if self._pc_cache is not None and coords is None:
            coords, rgb = self._pc_cache
        else:
            if isinstance(coords, np.ndarray):
                coords = torch.from_numpy(coords).float()
            if isinstance(rgb, np.ndarray):
                rgb = torch.from_numpy(rgb).float()
            if coords.dim() == 2:
                coords = coords.unsqueeze(0)
            if rgb.dim() == 2:
                rgb = rgb.unsqueeze(0)
            if rgb.max() > 1.0 + 1e-6:
                rgb = rgb / 255.0
            coords = coords.to(self.device)
            rgb = rgb.to(self.device)

        # Prepare prompt
        if isinstance(prompt_point, (list, tuple)):
            prompt_point = np.array(prompt_point, dtype=np.float32)
        if isinstance(prompt_point, np.ndarray):
            prompt_point = torch.from_numpy(prompt_point).float()
        if prompt_point.dim() == 1:
            prompt_point = prompt_point.unsqueeze(0).unsqueeze(0)  # [1, 1, 3]
        prompt_point = prompt_point.to(self.device)

        prompt_labels = torch.tensor([[prompt_label]], dtype=torch.long, device=self.device)

        with torch.no_grad():
            masks, iou_preds = self.model.predict_masks(
                coords,
                rgb,
                prompt_point,
                prompt_labels,
                prompt_masks=None,
                multimask_output=multimask_output,
            )

        # masks: [1, num_outputs, N]
        # iou_preds: [1, num_outputs]
        masks = masks[0]  # [num_outputs, N]
        iou_preds = iou_preds[0]  # [num_outputs]

        if not multimask_output:
            mask = masks[0]
            if return_logits:
                return mask.cpu().numpy()
            return (mask > 0).cpu().numpy()

        if return_logits:
            return masks.cpu().numpy(), iou_preds.cpu().numpy()
        return (masks > 0).cpu().numpy(), iou_preds.cpu().numpy()

    @property
    def num_points(self) -> int:
        """Number of points in the cached point cloud, or 0 if none."""
        if self._pc_cache is None:
            return 0
        return self._pc_cache[0].shape[1]

    def adjust_patch_params(self, num_groups: int, group_size: int):
        """Dynamically adjust the number of patches and patch size.

        Call this when working with very large point clouds to avoid OOM.
        The authors suggest num_groups=2048, group_size=256 for clouds with
        > 100k points.
        """
        self.model.pc_encoder.patch_embed.grouper.num_groups = num_groups
        self.model.pc_encoder.patch_embed.grouper.group_size = group_size
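
A minimal end-to-end usage sketch of the API above. This is illustrative only: the file paths and prompt point are placeholders, the `point_sam.inference` import path assumes the package layout implied by the file name, and a CUDA device plus a local Point-SAM .safetensors checkpoint are assumed.

import numpy as np

from point_sam.inference import PointSAM, load_pointcloud

# Load and normalize an ASCII PLY scan ("scene.ply" is a placeholder path).
coords, rgb, original = load_pointcloud("scene.ply")

# Build the model and load local weights ("model.safetensors" is a placeholder path).
psam = PointSAM.from_pretrained("model.safetensors", variant="large", device="cuda")

# For very large clouds, enlarge the patch parameters as suggested in
# adjust_patch_params' docstring.
if coords.shape[0] > 100_000:
    psam.adjust_patch_params(num_groups=2048, group_size=256)

# Prompt with a point taken from the normalized cloud (label 1 = foreground).
masks, iou_preds = psam.predict(coords, rgb, prompt_point=coords[0].tolist())

# Keep the candidate mask with the highest predicted IoU.
best = masks[int(np.argmax(iou_preds))]
print(f"Selected {int(best.sum())} of {coords.shape[0]} points")

# Repeated prompts on the same cloud: cache it once, then pass coords=None.
psam.set_pointcloud(coords, rgb)
masks2, iou_preds2 = psam.predict(None, None, prompt_point=coords[0].tolist())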