"""
segment.py
==========

Command-line script to run learned region growing on a PLY / PCD point cloud.

Example:
    python segment.py --input scene.ply --ckpt checkpoints/best_model.pt \
                      --output segmented_scene.ply --device cuda
"""

import argparse
import torch
import numpy as np
from pathlib import Path

from learn_region_grow.io import load_point_cloud, save_ply
from learn_region_grow.preprocess import voxel_equalize, compute_normals_and_curvature, build_feature_vector
from learn_region_grow.lrg_net import LrgNet
from learn_region_grow.growing import RegionGrower

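# --- Illustrative sketch: voxel equalization (not called by main) -------------
# voxel_equalize thins the cloud to a roughly uniform density at the given
# resolution. The helper below shows one simple way to do that, assuming the
# first point that falls in each occupied voxel is kept as its representative.
# It is NOT the implementation in learn_region_grow.preprocess, and the
# voxel-map format used here (voxel coordinate tuple -> row in eq_xyz) is a
# hypothetical choice for illustration.
def _sketch_voxel_equalize(xyz, resolution):
    """Return (eq_xyz, eq_idx, voxel_map) keeping one point per voxel (illustrative only)."""
    import numpy as np

    xyz = np.asarray(xyz, dtype=np.float64)
    voxels = np.floor(xyz / resolution).astype(np.int64)   # integer voxel coordinates
    # return_index gives the index of the first occurrence of each distinct voxel.
    _, eq_idx = np.unique(voxels, axis=0, return_index=True)
    eq_idx = np.sort(eq_idx)                                # keep the input order
    # Hypothetical voxel map: voxel coordinate tuple -> row in the equalized cloud.
    voxel_map = {tuple(v): i for i, v in enumerate(voxels[eq_idx].tolist())}
    return xyz[eq_idx], eq_idx, voxel_map
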

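# --- Illustrative sketch: PCA normals and curvature (not called by main) ------
# main() relies on compute_normals_and_curvature for a PCA-based estimate. A
# common formulation, assumed here rather than taken from learn_region_grow, is:
# for each point, collect its k nearest neighbors, eigen-decompose their 3x3
# covariance, use the eigenvector of the smallest eigenvalue as the normal, and
# use the "surface variation" l0 / (l0 + l1 + l2) as curvature. The helper name
# and the neighborhood size k are hypothetical.
def _sketch_pca_normals_curvature(xyz, k=10):
    """Return (normals, curvature) estimated by k-NN PCA (illustrative only)."""
    import numpy as np
    from scipy.spatial import cKDTree

    xyz = np.asarray(xyz, dtype=np.float64)
    k = min(k, len(xyz))
    if k < 3:
        raise ValueError("need at least 3 points for a PCA neighborhood")
    _, nbr = cKDTree(xyz).query(xyz, k=k)            # (N, k) neighbor indices
    nbhd = xyz[nbr]                                   # (N, k, 3) neighbor coordinates
    centered = nbhd - nbhd.mean(axis=1, keepdims=True)
    cov = np.einsum("nki,nkj->nij", centered, centered) / k   # (N, 3, 3) covariances
    # eigh returns eigenvalues in ascending order; eigenvectors are the columns.
    evals, evecs = np.linalg.eigh(cov)
    evals = np.clip(evals, 0.0, None)                # drop tiny negative round-off
    normals = evecs[:, :, 0]                          # direction of least variance
    total = evals.sum(axis=1)
    curvature = np.divide(evals[:, 0], total, out=np.zeros_like(total), where=total > 0)
    return normals.astype(np.float32), curvature.astype(np.float32)
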
def main():
    parser = argparse.ArgumentParser(description="LRGNet: Learned Region Growing on a Point Cloud")
    parser.add_argument("--input", required=True, help="Input .ply or .pcd file")
    parser.add_argument("--ckpt", required=True, help="PyTorch checkpoint (.pt)")
    parser.add_argument("--output", default="output.ply", help="Output segmented PLY")
    parser.add_argument("--device", default="cuda", help="cuda or cpu")
    parser.add_argument("--resolution", type=float, default=0.1, help="Voxel resolution in meters")
    parser.add_argument("--lite", type=int, default=0, choices=[0,1,2], help="Lite model variant")
    parser.add_argument("--stochastic", action="store_true", default=True, help="Use stochastic growing (already the default)")
    parser.add_argument("--deterministic", dest="stochastic", action="store_false", help="Use deterministic thresholding")
    parser.add_argument("--add_threshold", type=float, default=0.5)
    parser.add_argument("--remove_threshold", type=float, default=0.5)
    parser.add_argument("--cluster_threshold", type=int, default=10)
    args = parser.parse_args()

    device = torch.device(args.device if torch.cuda.is_available() else 'cpu')
    print(f"Loading point cloud from {args.input} ...")
    xyz, rgb, normals_input = load_point_cloud(args.input)
    print(f"  {len(xyz)} points loaded")

    # Voxel equalization
    print(f"Voxel equalization (resolution={args.resolution}m) ...")
    eq_xyz, eq_idx, voxel_map = voxel_equalize(xyz, args.resolution)
    eq_rgb = rgb[eq_idx] if rgb is not None else None
    print(f"  {len(eq_xyz)} points after equalization")

    # Normals / curvature
    if normals_input is not None:
        # Reuse the normals from the input file. np.abs makes the feature
        # independent of normal orientation (n and -n map to the same vector).
        # The loader does not provide curvature, so it is still estimated via PCA.
        normals = np.abs(normals_input[eq_idx])
        _, curvature = compute_normals_and_curvature(eq_xyz, args.resolution)
    else:
        print("Computing normals & curvature via PCA ...")
        normals, curvature = compute_normals_and_curvature(eq_xyz, args.resolution)

    # Build feature vector
    features = build_feature_vector(eq_xyz, eq_rgb, normals, curvature)
    print(f"Feature vector shape: {features.shape}")

    # Load model
    print(f"Loading checkpoint {args.ckpt} ...")
    model = LrgNet(in_channels=13, lite=args.lite)
    model.load_state_dict(torch.load(args.ckpt, map_location=device))
    model.to(device)
    model.eval()  # inference mode: batch norm uses running stats, dropout is disabled
    print("Model loaded.")

    # Run region growing
    grower = RegionGrower(
        model, device,
        add_threshold=args.add_threshold,
        remove_threshold=args.remove_threshold,
        cluster_threshold=args.cluster_threshold,
        stochastic=args.stochastic,
    )
    print("Running learned region growing ...")
    labels = grower.grow(eq_xyz, features, voxel_map, args.resolution)
    n_instances = len(np.unique(labels[labels >= 0]))
    print(f"Segmented into {n_instances} instances")

    # Map labels back to the original (dense) cloud: each input point takes the
    # label of its nearest equalized point.
    from scipy.spatial import cKDTree
    tree = cKDTree(eq_xyz)
    _, nn = tree.query(xyz)
    full_labels = labels[nn]

    # Save
    print(f"Saving output to {args.output} ...")
    save_ply(args.output, xyz, labels=full_labels)
    print("Done.")

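# --- Illustrative sketch: label propagation with a distance cap (not called by main)
# main() copies to every input point the label of its nearest equalized point.
# This variant does the same with scipy's cKDTree but, as a hypothetical extra
# option, leaves points farther than max_dist from every equalized point
# unlabeled (-1) instead of forcing a label onto them.
def _sketch_propagate_labels(eq_xyz, labels, xyz, max_dist=float("inf")):
    """Nearest-neighbor label transfer with an optional distance cap (illustrative only)."""
    import numpy as np
    from scipy.spatial import cKDTree

    eq_xyz = np.asarray(eq_xyz, dtype=np.float64)
    dist, nn = cKDTree(eq_xyz).query(np.asarray(xyz, dtype=np.float64),
                                     distance_upper_bound=max_dist)
    # Misses come back as dist == inf and nn == len(eq_xyz).
    full = np.full(len(nn), -1, dtype=np.int64)
    hit = np.isfinite(dist)
    full[hit] = np.asarray(labels)[nn[hit]]
    return full
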

if __name__ == "__main__":
    main()