| """High-level inference API for Point-SAM. |
| |
| This module provides a clean, hydra-free interface for running Point-SAM |
| segmentation on point clouds loaded from PLY or PCD files. |
| """ |
|
|
import os
import struct
import warnings
from typing import Optional, Tuple, Union

import numpy as np
import torch
from safetensors.torch import load_model as load_safetensors_model

from .model.pc_encoder import PatchEmbed, PointCloudEncoder
from .model.pc_sam import PointCloudSAM
from .model.prompt_encoder import MaskEncoder
from .model.mask_decoder import MaskDecoder
from .model.transformer import TwoWayTransformer
from .utils.torch_utils import replace_with_fused_layernorm

def _load_ply_ascii(filename: str) -> np.ndarray:
    """Load an ASCII PLY file with xyzrgb columns."""
    with open(filename, "r") as rf:
        num_of_points = None
        while True:
            line = rf.readline()
            if not line:
                break
            if "end_header" in line:
                break
            if "element vertex" in line:
                num_of_points = int(line.split()[2])
        if num_of_points is None:
            raise ValueError(f"Could not parse vertex count from PLY header: {filename}")
        points = np.zeros([num_of_points, 6], dtype=np.float32)
        for i in range(num_of_points):
            point = rf.readline().split()
            if len(point) < 6:
                raise ValueError(
                    f"Line {i} in PLY has fewer than 6 values ({len(point)})."
                )
            points[i] = [float(p) for p in point[:6]]
    return points
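
# A minimal example of the ASCII PLY layout _load_ply_ascii expects (header
# lines other than "element vertex" and "end_header" are skipped):
#
#   ply
#   format ascii 1.0
#   element vertex 2
#   property float x
#   ... (y, z and red, green, blue properties) ...
#   end_header
#   0.0 0.0 0.0 255 0 0
#   1.0 0.0 0.0 0 255 0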


def _load_pcd_ascii(filename: str) -> np.ndarray:
    """Load an ASCII PCD file with xyzrgb columns."""
    with open(filename, "r") as rf:
        num_of_points = None
        data_mode = None
        while True:
            line = rf.readline()
            if not line:
                break
            line = line.strip()
            if line.startswith("POINTS "):
                num_of_points = int(line.split()[1])
            if line.startswith("DATA "):
                # DATA is the last header line in a PCD file.
                data_mode = line.split()[1]
                break
        if num_of_points is None:
            raise ValueError(f"Could not parse POINTS from PCD header: {filename}")
        if data_mode != "ascii":
            raise ValueError(f"Only ASCII PCD is supported; got DATA {data_mode}")
        points = np.zeros([num_of_points, 6], dtype=np.float32)
        for i in range(num_of_points):
            point = rf.readline().split()
            if len(point) < 6:
                if len(point) == 4:
                    # xyz plus a single packed-rgb column. PCL packs RGB into
                    # the raw bits of a float32 (0x00RRGGBB), so reinterpret
                    # the bits instead of truncating the float value.
                    rgb_packed = float(point[3])
                    rgb_int = struct.unpack("<I", struct.pack("<f", rgb_packed))[0]
                    r = (rgb_int >> 16) & 0xFF
                    g = (rgb_int >> 8) & 0xFF
                    b = rgb_int & 0xFF
                    points[i] = [float(point[0]), float(point[1]), float(point[2]), r, g, b]
                else:
                    raise ValueError(
                        f"Line {i} in PCD has fewer than 6 values ({len(point)})."
                    )
            else:
                points[i] = [float(p) for p in point[:6]]
    return points
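
# Packed-RGB round trip (a sanity-check sketch of the reinterpretation above):
#
#   >>> import struct
#   >>> packed = struct.unpack("<f", struct.pack("<I", 255 << 16))[0]  # pure red
#   >>> hex(struct.unpack("<I", struct.pack("<f", packed))[0])
#   '0xff0000'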


def load_pointcloud(
    filepath: str,
    normalize: bool = True,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Load a point cloud from a PLY or PCD file.

    Args:
        filepath: Path to a .ply or .pcd file.
        normalize: Whether to center the coordinates and rescale them so the
            cloud fits inside the unit sphere (every coordinate in [-1, 1]).

    Returns:
        coords: [N, 3] numpy array of (optionally normalized) coordinates.
        rgb: [N, 3] numpy array of colors in [0, 255].
        original_coords: [N, 3] un-normalized coordinates.
    """
    ext = os.path.splitext(filepath)[1].lower()
    if ext == ".ply":
        points = _load_ply_ascii(filepath)
    elif ext == ".pcd":
        points = _load_pcd_ascii(filepath)
    else:
        raise ValueError(f"Unsupported file extension: {ext}. Use .ply or .pcd")

    original_coords = points[:, :3].copy()
    rgb = points[:, 3:6].copy()

    # Heuristic: colors already in [0, 1] are rescaled to [0, 255].
    if rgb.max() <= 1.0 + 1e-6:
        rgb = rgb * 255.0

    if normalize:
        coords = original_coords - original_coords.mean(axis=0)
        max_norm = np.linalg.norm(coords, axis=1).max()
        if max_norm > 1e-8:
            coords = coords / max_norm
    else:
        coords = original_coords

    return coords, rgb, original_coords
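
# Example usage (a sketch; "scene.ply" is a placeholder path):
#
#   coords, rgb, original = load_pointcloud("scene.ply")
#   # coords is centered and unit-sphere scaled; use `original` to map
#   # results back to the raw coordinate frame.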


def build_point_sam(
    variant: str = "large",
    embed_dim: int = 256,
    device: str = "cuda",
    use_fused_layernorm: bool = False,
) -> PointCloudSAM:
    """Build a Point-SAM model from scratch (no hydra/omegaconf required).

    Args:
        variant: Model size, "large" or "giant".
        embed_dim: Embedding dimension for the decoder.
        device: torch device to place the model on.
        use_fused_layernorm: Whether to replace LayerNorm with apex
            FusedLayerNorm. Requires apex to be installed.

    Returns:
        PointCloudSAM model on the requested device.
    """
    import timm

    if variant == "large":
        model_name = "eva02_large_patch14_448"
        num_patches = 1024
        patch_size = 256
        prompt_iters = 5
    elif variant == "giant":
        model_name = "eva_giant_patch14_560"
        num_patches = 512
        patch_size = 64
        prompt_iters = 10
    else:
        raise ValueError(f"Unknown variant: {variant}. Choose 'large' or 'giant'.")

    # Point cloud encoder: patch embedding feeding an EVA ViT backbone
    # (architecture only; weights come from a checkpoint later).
    transformer_encoder = timm.create_model(model_name, pretrained=False)
    patch_embed = PatchEmbed(
        in_channels=6,
        out_channels=512,
        num_patches=num_patches,
        patch_size=patch_size,
    )
    pc_encoder = PointCloudEncoder(
        patch_embed=patch_embed,
        transformer=transformer_encoder,
        embed_dim=embed_dim,
    )

    # Prompt (mask) encoder.
    mask_encoder = MaskEncoder(embed_dim=embed_dim)

    # Mask decoder built around a two-way transformer.
    transformer_decoder = TwoWayTransformer(
        depth=2,
        embedding_dim=embed_dim,
        num_heads=8,
        mlp_dim=2048,
    )
    mask_decoder = MaskDecoder(
        transformer_dim=embed_dim,
        transformer=transformer_decoder,
        num_multimask_outputs=3,
        iou_head_depth=3,
        iou_head_hidden_dim=256,
    )

    # Assemble the full model.
    model = PointCloudSAM(
        pc_encoder=pc_encoder,
        mask_encoder=mask_encoder,
        mask_decoder=mask_decoder,
        prompt_iters=prompt_iters,
    )

    if use_fused_layernorm:
        if replace_with_fused_layernorm is None:
            warnings.warn("apex FusedLayerNorm requested but not available. Skipping.")
        else:
            model = replace_with_fused_layernorm(model)

    model = model.to(device)
    return model
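
# Example (a sketch): the backbone is created with pretrained=False, so a
# freshly built model is randomly initialized until a checkpoint is loaded.
#
#   model = build_point_sam(variant="large", device="cpu")
#   n_params = sum(p.numel() for p in model.parameters())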


class PointSAM:
    """User-friendly wrapper around PointCloudSAM for single-file inference.

    Typical usage:
        >>> psam = PointSAM.from_pretrained("path/to/model.safetensors", device="cuda")
        >>> coords, rgb, original = load_pointcloud("scene.ply")
        >>> masks, iou_preds = psam.predict(coords, rgb, prompt_point=[0.5, 0.1, -0.2])
    """

    def __init__(
        self,
        model: PointCloudSAM,
        device: str = "cuda",
        variant: str = "large",
    ):
        self.model = model
        self.device = device
        self.variant = variant
        self._pc_cache: Optional[Tuple[torch.Tensor, torch.Tensor]] = None

    @classmethod
    def from_pretrained(
        cls,
        checkpoint_path: Optional[str] = None,
        variant: str = "large",
        device: str = "cuda",
        use_fused_layernorm: bool = False,
    ) -> "PointSAM":
        """Load a Point-SAM model from a local checkpoint.

        Args:
            checkpoint_path: Local path to a .safetensors checkpoint.
                If None, the model is initialized with random weights.
            variant: "large" or "giant".
            device: torch device.
            use_fused_layernorm: Whether to use apex FusedLayerNorm.

        Returns:
            PointSAM wrapper ready for inference.
        """
        model = build_point_sam(
            variant=variant,
            device=device,
            use_fused_layernorm=use_fused_layernorm,
        )
        if checkpoint_path is not None:
            load_safetensors_model(model, checkpoint_path)
            print(f"Loaded checkpoint from {checkpoint_path}")
        else:
            warnings.warn(
                "No checkpoint provided; model weights are randomly initialized!"
            )
        model.eval()
        return cls(model=model, device=device, variant=variant)

    def set_pointcloud(
        self,
        coords: Union[np.ndarray, torch.Tensor],
        rgb: Union[np.ndarray, torch.Tensor],
    ):
        """Cache a point cloud for repeated segmentation queries.

        The tensors are converted, batched, rescaled, and moved to the target
        device once; subsequent `predict` calls with coords=None reuse them
        and skip that preparation.

        Args:
            coords: [N, 3] coordinates (normalized to [-1, 1]).
            rgb: [N, 3] colors in [0, 255].
        """
        if isinstance(coords, np.ndarray):
            coords = torch.from_numpy(coords).float()
        if isinstance(rgb, np.ndarray):
            rgb = torch.from_numpy(rgb).float()

        # Add a batch dimension if needed.
        if coords.dim() == 2:
            coords = coords.unsqueeze(0)
        if rgb.dim() == 2:
            rgb = rgb.unsqueeze(0)

        if rgb.max() > 1.0 + 1e-6:
            rgb = rgb / 255.0

        coords = coords.to(self.device)
        rgb = rgb.to(self.device)

        self._pc_cache = (coords, rgb)
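
    # Cache-then-query sketch (the prompt coordinates are placeholders):
    #
    #   psam.set_pointcloud(coords, rgb)
    #   masks, iou = psam.predict(None, None, prompt_point=[0.0, 0.0, 0.0])
    #   masks, iou = psam.predict(None, None, prompt_point=[0.2, -0.1, 0.3])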

    def predict(
        self,
        coords: Optional[Union[np.ndarray, torch.Tensor]],
        rgb: Optional[Union[np.ndarray, torch.Tensor]],
        prompt_point: Union[list, tuple, np.ndarray, torch.Tensor],
        prompt_label: int = 1,
        multimask_output: bool = True,
        return_logits: bool = False,
    ) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
        """Run segmentation on a point cloud given a prompt point.

        Args:
            coords: [N, 3] normalized coordinates, or None to reuse the point
                cloud cached via `set_pointcloud`.
            rgb: [N, 3] colors in [0, 255]. Ignored when the cached point
                cloud is used.
            prompt_point: [x, y, z] coordinate of the prompt. Must be in the same
                normalized space as `coords` (i.e., [-1, 1] if you used the default
                `load_pointcloud` normalization).
            prompt_label: 1 for foreground (positive), 0 for background (negative).
            multimask_output: If True, return 3 mask candidates plus IoU scores.
                If False, return a single mask.
            return_logits: If True, return raw logits instead of a boolean mask.

        Returns:
            If multimask_output=False:
                mask: [N] boolean array (or float logits if return_logits=True).
            If multimask_output=True:
                masks: [3, N] boolean array of candidate masks.
                iou_preds: [3] IoU confidence scores for each candidate.
        """
        # Reuse the cached point cloud when no coordinates are given.
        if coords is None:
            if self._pc_cache is None:
                raise ValueError(
                    "coords is None and no point cloud was cached via set_pointcloud."
                )
            coords, rgb = self._pc_cache
        else:
            if isinstance(coords, np.ndarray):
                coords = torch.from_numpy(coords).float()
            if isinstance(rgb, np.ndarray):
                rgb = torch.from_numpy(rgb).float()
            if coords.dim() == 2:
                coords = coords.unsqueeze(0)
            if rgb.dim() == 2:
                rgb = rgb.unsqueeze(0)
            if rgb.max() > 1.0 + 1e-6:
                rgb = rgb / 255.0
            coords = coords.to(self.device)
            rgb = rgb.to(self.device)

        # Normalize the prompt to a [1, 1, 3] tensor on the target device.
        if isinstance(prompt_point, (list, tuple)):
            prompt_point = np.array(prompt_point, dtype=np.float32)
        if isinstance(prompt_point, np.ndarray):
            prompt_point = torch.from_numpy(prompt_point).float()
        if prompt_point.dim() == 1:
            prompt_point = prompt_point.unsqueeze(0).unsqueeze(0)
        prompt_point = prompt_point.to(self.device)

        prompt_labels = torch.tensor([[prompt_label]], dtype=torch.long, device=self.device)

        with torch.no_grad():
            masks, iou_preds = self.model.predict_masks(
                coords,
                rgb,
                prompt_point,
                prompt_labels,
                prompt_masks=None,
                multimask_output=multimask_output,
            )

        # Strip the batch dimension.
        masks = masks[0]
        iou_preds = iou_preds[0]

        if not multimask_output:
            mask = masks[0]
            if return_logits:
                return mask.cpu().numpy()
            return (mask > 0).cpu().numpy()

        if return_logits:
            return masks.cpu().numpy(), iou_preds.cpu().numpy()
        return (masks > 0).cpu().numpy(), iou_preds.cpu().numpy()
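
    # Selecting the best candidate when multimask_output=True (a sketch):
    #
    #   masks, iou_preds = psam.predict(coords, rgb, prompt_point=[0.5, 0.1, -0.2])
    #   best_mask = masks[int(iou_preds.argmax())]  # [N] boolean mask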

    @property
    def num_points(self) -> int:
        """Number of points in the cached point cloud, or 0 if none."""
        if self._pc_cache is None:
            return 0
        return self._pc_cache[0].shape[1]

    def adjust_patch_params(self, num_groups: int, group_size: int):
        """Dynamically adjust the number of patches and patch size.

        Call this when working with very large point clouds (e.g. > 100k points)
        to avoid OOM. The authors suggest num_groups=2048, group_size=256 for
        clouds with > 100k points.
        """
        self.model.pc_encoder.patch_embed.grouper.num_groups = num_groups
        self.model.pc_encoder.patch_embed.grouper.group_size = group_size
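
    # Large-cloud sketch using the values suggested above:
    #
    #   psam.adjust_patch_params(num_groups=2048, group_size=256)
    #   masks, iou = psam.predict(coords, rgb, prompt_point=[0.0, 0.0, 0.0])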