| import torch |
| import sys |
| from sklearn.preprocessing import StandardScaler |
| import pytorch_lightning as pl |
| from torch.utils.data import DataLoader |
| from pytorch_lightning.utilities.combined_loader import CombinedLoader |
| import laspy |
| import numpy as np |
| from scipy.spatial import cKDTree |
| import math |
| from functools import partial |
| from torch.utils.data import TensorDataset |
|
|
|
|
class GaussianMM:
    """Isotropic Gaussian mixture with uniform component weights.

    Args:
        mu: sequence of K component means, shape (K, D).
        var: shared variance (scalar); std = sqrt(var).
    """

    def __init__(self, mu, var):
        self.centers = torch.tensor(mu)
        # log(std) = log(var) / 2
        self.logstd = torch.tensor(var).log() / 2.0
        self.K = self.centers.shape[0]

    def logprob(self, x):
        """Log-density of points x (shape (N, D)) under the uniform mixture."""
        logprobs = self.normal_logprob(
            x.unsqueeze(1), self.centers.unsqueeze(0), self.logstd
        )
        # Diagonal covariance: per-dimension log-densities sum over dim 2,
        # then log-mean-exp over the K components (uniform 1/K weights).
        logprobs = torch.sum(logprobs, dim=2)
        return torch.logsumexp(logprobs, dim=1) - math.log(self.K)

    def normal_logprob(self, z, mean, log_std):
        """Elementwise log N(z; mean, exp(log_std)^2).

        The previous no-op additions (`mean + torch.tensor(0.0)`) and the
        shape-(1,) constant tensor are replaced by a plain scalar constant;
        broadcasting over (N, K, D) inputs is unchanged.
        """
        inv_sigma = torch.exp(-log_std)
        tmp = (z - mean) * inv_sigma
        return -0.5 * (tmp * tmp + 2 * log_std + math.log(2 * math.pi))

    def __call__(self, n_samples):
        """Draw n_samples points: uniform component choice plus Gaussian noise."""
        idx = torch.randint(self.K, (n_samples,), device=self.centers.device)
        mean = self.centers[idx]
        return torch.randn_like(mean) * torch.exp(self.logstd) + mean
| |
class BranchedLidarDataModule(pl.LightningDataModule):
    """Data module for branched trajectory experiments on lidar ground points.

    Reads ground-classified points (ASPRS class 2) from a LAS file, rescales
    them into the box [-5, 5] x [-5, 5] x [0, 2], samples one source (x0) and
    two branch-target (x1_1, x1_2) Gaussian-mixture populations, snaps the
    samples onto locally fitted tangent planes of the lidar surface, and
    builds train/val/test loaders plus three spatially partitioned
    "metric sample" loaders.
    """

    def __init__(self, args):
        super().__init__()
        self.save_hyperparameters()

        self.data_path = args.data_path
        self.batch_size = args.batch_size
        self.max_dim = args.dim
        self.whiten = args.whiten
        # Component means of the source mixture p0 (x, y, z).
        self.p0_mu = [
            [-4.5, -4.0, 0.5],
            [-4.2, -3.5, 0.5],
            [-4.0, -3.0, 0.5],
            [-3.75, -2.5, 0.5],
        ]
        self.p0_var = 0.02

        # Component means of the two branch targets p1_1 / p1_2.
        self.p1_1_mu = [
            [-2.5, -0.25, 0.5],
            [-2.25, 0.675, 0.5],
            [-2, 1.5, 0.5],
        ]
        self.p1_2_mu = [
            [2, -2, 0.5],
            [2.6, -1.25, 0.5],
            [3.2, -0.5, 0.5],
        ]

        self.p1_var = 0.03
        self.k = 20  # neighbors used when fitting a local tangent plane
        self.n_samples = 5000
        self.num_timesteps = 2
        self.split_ratios = args.split_ratios
        self._prepare_data()

    def assign_region(self):
        """Label every dataset point with the nearest mixture region.

        Returns:
            LongTensor of shape (N,): 0 for p0, 1 for p1_1, 2 for p1_2, by
            minimum squared euclidean distance to any component center.
            Ties resolve to the lowest region index (argmin returns the
            first minimum), matching the original strict-< loop.
        """
        center_groups = [
            torch.tensor(self.p0_mu),
            torch.tensor(self.p1_1_mu),
            torch.tensor(self.p1_2_mu),
        ]
        centers = torch.cat(center_groups, dim=0).to(torch.float32)
        # Region id of each concatenated center, in region order.
        region_ids = torch.cat(
            [
                torch.full((c.shape[0],), r, dtype=torch.long)
                for r, c in enumerate(center_groups)
            ]
        )

        dataset = self.dataset.to(torch.float32)
        # (N, C) squared distances, vectorized instead of a per-point loop.
        sq_dists = (dataset.unsqueeze(1) - centers.unsqueeze(0)).pow(2).sum(dim=2)
        return region_ids[sq_dists.argmin(dim=1)]

    def _prepare_data(self):
        """Load the LAS file and build every dataloader used downstream."""
        las = laspy.read(self.data_path)

        # Keep only ground-classified points (ASPRS classification code 2).
        self.mask = las.classification == 2

        x_offset, x_scale = las.header.offsets[0], las.header.scales[0]
        y_offset, y_scale = las.header.offsets[1], las.header.scales[1]
        z_offset, z_scale = las.header.offsets[2], las.header.scales[2]
        dataset = np.vstack(
            (
                las.X[self.mask] * x_scale + x_offset,
                las.Y[self.mask] * y_scale + y_offset,
                las.Z[self.mask] * z_scale + z_offset,
            )
        ).transpose()
        # Rescale into the fixed box [-5, 5] x [-5, 5] x [0, 2].
        mi = dataset.min(axis=0, keepdims=True)
        ma = dataset.max(axis=0, keepdims=True)
        dataset = (dataset - mi) / (ma - mi) * [10.0, 10.0, 2.0] + [-5.0, -5.0, 0.0]

        self.dataset = torch.tensor(dataset, dtype=torch.float32)
        self.tree = cKDTree(dataset)

        # Sample the endpoint populations ...
        x0_gaussian = GaussianMM(self.p0_mu, self.p0_var)(self.n_samples)
        x1_1_gaussian = GaussianMM(self.p1_1_mu, self.p1_var)(self.n_samples)
        x1_2_gaussian = GaussianMM(self.p1_2_mu, self.p1_var)(self.n_samples)

        # ... and snap them onto the locally fitted tangent planes.
        x0 = self.get_tangent_proj(x0_gaussian)(x0_gaussian)
        x1_1 = self.get_tangent_proj(x1_1_gaussian)(x1_1_gaussian)
        x1_2 = self.get_tangent_proj(x1_2_gaussian)(x1_2_gaussian)

        split_index = int(self.n_samples * self.split_ratios[0])
        # BUGFIX: clamp BEFORE slicing so the validation split holds at
        # least one full batch (the val loaders use drop_last=True and
        # would otherwise yield nothing). Previously this clamp ran after
        # the train/val slices were taken and had no effect.
        if self.n_samples - split_index < self.batch_size:
            split_index = self.n_samples - self.batch_size

        self.scaler = StandardScaler()
        if self.whiten:
            self.dataset = torch.tensor(
                self.scaler.fit_transform(dataset), dtype=torch.float32
            )
            x0 = torch.tensor(self.scaler.transform(x0), dtype=torch.float32)
            x1_1 = torch.tensor(self.scaler.transform(x1_1), dtype=torch.float32)
            x1_2 = torch.tensor(self.scaler.transform(x1_2), dtype=torch.float32)

        train_x0, val_x0 = x0[:split_index], x0[split_index:]
        train_x1_1, val_x1_1 = x1_1[:split_index], x1_1[split_index:]
        train_x1_2, val_x1_2 = x1_2[:split_index], x1_2[split_index:]

        self.val_x0 = val_x0

        self.train_dataloaders = {
            "x0": DataLoader(train_x0, batch_size=self.batch_size, shuffle=True, drop_last=True),
            "x1_1": DataLoader(train_x1_1, batch_size=self.batch_size, shuffle=True, drop_last=True),
            "x1_2": DataLoader(train_x1_2, batch_size=self.batch_size, shuffle=True, drop_last=True),
        }
        self.val_dataloaders = {
            "x0": DataLoader(val_x0, batch_size=self.batch_size, shuffle=False, drop_last=True),
            "x1_1": DataLoader(val_x1_1, batch_size=self.batch_size, shuffle=True, drop_last=True),
            "x1_2": DataLoader(val_x1_2, batch_size=self.batch_size, shuffle=True, drop_last=True),
        }

        # Test loaders evaluate the full validation x0 set and the full
        # lidar cloud, each as a single batch.
        self.test_dataloaders = [
            DataLoader(
                self.val_x0,
                batch_size=self.val_x0.shape[0],
                shuffle=False,
                drop_last=False,
            ),
            DataLoader(
                self.dataset,
                batch_size=self.dataset.shape[0],
                shuffle=False,
                drop_last=False,
            ),
        ]

        # Partition the cloud into three spatial regions for metric
        # evaluation: the lowest 30% along the diagonal u = (x + y)/sqrt(2),
        # then the remainder split by the x = y line.
        points = self.dataset.cpu().numpy()
        x, y = points[:, 0], points[:, 1]
        u = (x + y) / np.sqrt(2)
        u_thresh = np.percentile(u, 30)
        mask_A = u <= u_thresh

        remaining = ~mask_A
        mask_B = remaining & (x < y)
        mask_C = remaining & (x >= y)

        self.metric_samples_dataloaders = [
            DataLoader(
                torch.tensor(points[m], dtype=torch.float32),
                batch_size=points[m].shape[0],
                shuffle=False,
            )
            for m in (mask_A, mask_B, mask_C)
        ]

    def train_dataloader(self):
        """Combined loader pairing training populations with metric samples."""
        combined_loaders = {
            "train_samples": CombinedLoader(self.train_dataloaders, mode="min_size"),
            "metric_samples": CombinedLoader(
                self.metric_samples_dataloaders, mode="min_size"
            ),
        }
        return CombinedLoader(combined_loaders, mode="max_size_cycle")

    def val_dataloader(self):
        """Combined loader pairing validation populations with metric samples."""
        combined_loaders = {
            "val_samples": CombinedLoader(self.val_dataloaders, mode="min_size"),
            "metric_samples": CombinedLoader(
                self.metric_samples_dataloaders, mode="min_size"
            ),
        }
        return CombinedLoader(combined_loaders, mode="max_size_cycle")

    def test_dataloader(self):
        """Combined loader over the two single-batch test loaders."""
        return CombinedLoader(self.test_dataloaders)

    def get_tangent_proj(self, points):
        """Return a callable projecting inputs onto the planes fit at `points`."""
        w = self.get_tangent_plane(points)
        return partial(BranchedLidarDataModule.projection_op, w=w)

    def get_tangent_plane(self, points, temp=1e-3):
        """Fit a locally weighted plane around each point.

        The k nearest lidar points are weighted by exp(-squared distance /
        temp), so nearer neighbors dominate the fit.
        """
        points_np = points.detach().cpu().numpy()
        _, idx = self.tree.query(points_np, k=self.k)
        # Index directly; the previous torch.tensor(tensor) call only
        # triggered an unnecessary-copy warning.
        nearest_pts = self.dataset[idx].to(points)

        dists = (points.unsqueeze(1) - nearest_pts).pow(2).sum(-1, keepdim=True)
        weights = torch.exp(-dists / temp)

        return BranchedLidarDataModule.fit_plane(nearest_pts, weights)

    @staticmethod
    def fit_plane(points, weights=None):
        """Least-squares plane fit. Expects points of shape (..., k, 3).

        Returns w = [a, b, c] such that the plane is a*x + b*y + c = z.
        When `weights` (shape (..., k, 1)) is given, solves the weighted
        normal equations D^T W D w = D^T W z.
        """
        D = torch.cat([points[..., :2], torch.ones_like(points[..., 2:3])], dim=-1)
        z = points[..., 2]
        # BUGFIX: the two branches were swapped — passing weights silently
        # ran the unweighted fit, and omitting them crashed on `D * None`.
        if weights is None:
            Dtrans = D.transpose(-1, -2)
        else:
            Dtrans = (D * weights).transpose(-1, -2)
        w = torch.linalg.solve(
            torch.matmul(Dtrans, D), torch.matmul(Dtrans, z.unsqueeze(-1))
        ).squeeze(-1)
        return w

    @staticmethod
    def projection_op(x, w):
        """Project points `x` onto the plane(s) defined by w = [a, b, c]."""
        # Plane normal n = [a, b, -1]; points p on the plane satisfy
        # n . p + c = 0.
        n = torch.cat([w[..., :2], -torch.ones_like(w[..., 2:3])], dim=-1)

        pn = torch.sum(x * n, dim=-1, keepdim=True)
        nn = torch.sum(n * n, dim=-1, keepdim=True)

        d = w[..., 2:3]

        # Component of x normal to the plane, removed to land on the plane.
        projn_x = ((pn + d) / nn) * n
        return x - projn_x
|
class WeightedBranchedLidarDataModule(pl.LightningDataModule):
    """Weighted variant of BranchedLidarDataModule.

    Same lidar loading, mixture sampling, and tangent-plane projection, but
    every sample batch carries a per-sample importance weight: 1.0 for the
    source population x0 and 0.5 for each of the two branch targets, so the
    two branches together weigh as much as the source.
    """

    def __init__(self, args):
        super().__init__()
        self.save_hyperparameters()

        self.data_path = args.data_path
        self.batch_size = args.batch_size
        self.max_dim = args.dim
        self.whiten = args.whiten
        # Component means of the source mixture p0 (x, y, z).
        self.p0_mu = [
            [-4.5, -4.0, 0.5],
            [-4.2, -3.5, 0.5],
            [-4.0, -3.0, 0.5],
            [-3.75, -2.5, 0.5],
        ]
        self.p0_var = 0.02

        # Component means of the two branch targets p1_1 / p1_2.
        self.p1_1_mu = [
            [-2.5, -0.25, 0.5],
            [-2.25, 0.675, 0.5],
            [-2, 1.5, 0.5],
        ]
        self.p1_2_mu = [
            [2, -2, 0.5],
            [2.6, -1.25, 0.5],
            [3.2, -0.5, 0.5],
        ]

        self.p1_var = 0.03
        self.k = 20  # neighbors used when fitting a local tangent plane
        self.n_samples = 5000
        # NOTE: previously assigned twice with the same value; kept once.
        self.num_timesteps = 2
        self.split_ratios = args.split_ratios
        self.metric_clusters = 3
        self.args = args
        self._prepare_data()

    def _prepare_data(self):
        """Load the LAS file and build all weighted dataloaders."""
        las = laspy.read(self.data_path)

        # Keep only ground-classified points (ASPRS classification code 2).
        self.mask = las.classification == 2

        x_offset, x_scale = las.header.offsets[0], las.header.scales[0]
        y_offset, y_scale = las.header.offsets[1], las.header.scales[1]
        z_offset, z_scale = las.header.offsets[2], las.header.scales[2]
        dataset = np.vstack(
            (
                las.X[self.mask] * x_scale + x_offset,
                las.Y[self.mask] * y_scale + y_offset,
                las.Z[self.mask] * z_scale + z_offset,
            )
        ).transpose()
        # Rescale into the fixed box [-5, 5] x [-5, 5] x [0, 2].
        mi = dataset.min(axis=0, keepdims=True)
        ma = dataset.max(axis=0, keepdims=True)
        dataset = (dataset - mi) / (ma - mi) * [10.0, 10.0, 2.0] + [-5.0, -5.0, 0.0]

        self.dataset = torch.tensor(dataset, dtype=torch.float32)
        self.tree = cKDTree(dataset)

        # Sample the endpoint populations ...
        x0_gaussian = GaussianMM(self.p0_mu, self.p0_var)(self.n_samples)
        x1_1_gaussian = GaussianMM(self.p1_1_mu, self.p1_var)(self.n_samples)
        x1_2_gaussian = GaussianMM(self.p1_2_mu, self.p1_var)(self.n_samples)

        # ... and snap them onto the locally fitted tangent planes.
        x0 = self.get_tangent_proj(x0_gaussian)(x0_gaussian)
        x1_1 = self.get_tangent_proj(x1_1_gaussian)(x1_1_gaussian)
        x1_2 = self.get_tangent_proj(x1_2_gaussian)(x1_2_gaussian)

        split_index = int(self.n_samples * self.split_ratios[0])
        # BUGFIX: clamp BEFORE slicing so the validation split holds at
        # least one full batch (val loaders use drop_last=True and would
        # otherwise yield nothing). Previously this clamp ran after the
        # train/val slices were taken and had no effect.
        if self.n_samples - split_index < self.batch_size:
            split_index = self.n_samples - self.batch_size

        self.scaler = StandardScaler()
        if self.whiten:
            self.dataset = torch.tensor(
                self.scaler.fit_transform(dataset), dtype=torch.float32
            )
            x0 = torch.tensor(self.scaler.transform(x0), dtype=torch.float32)
            x1_1 = torch.tensor(self.scaler.transform(x1_1), dtype=torch.float32)
            x1_2 = torch.tensor(self.scaler.transform(x1_2), dtype=torch.float32)

        # Keep the full (possibly whitened) populations around for
        # visualization via get_timepoint_data().
        self.coords_t0 = x0
        self.coords_t1_1 = x1_1
        self.coords_t1_2 = x1_2
        self.time_labels = np.concatenate([
            np.zeros(len(self.coords_t0)),
            np.ones(len(self.coords_t1_1)),
            np.ones(len(self.coords_t1_2)),
        ])

        train_x0, val_x0 = x0[:split_index], x0[split_index:]
        train_x1_1, val_x1_1 = x1_1[:split_index], x1_1[split_index:]
        train_x1_2, val_x1_2 = x1_2[:split_index], x1_2[split_index:]

        self.val_x0 = val_x0

        # Importance weights: source counts fully, each branch half.
        train_x0_weights = torch.full((train_x0.shape[0], 1), fill_value=1.0)
        train_x1_1_weights = torch.full((train_x1_1.shape[0], 1), fill_value=0.5)
        train_x1_2_weights = torch.full((train_x1_2.shape[0], 1), fill_value=0.5)

        val_x0_weights = torch.full((val_x0.shape[0], 1), fill_value=1.0)
        val_x1_1_weights = torch.full((val_x1_1.shape[0], 1), fill_value=0.5)
        val_x1_2_weights = torch.full((val_x1_2.shape[0], 1), fill_value=0.5)

        self.train_dataloaders = {
            "x0": DataLoader(TensorDataset(train_x0, train_x0_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
            "x1_1": DataLoader(TensorDataset(train_x1_1, train_x1_1_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
            "x1_2": DataLoader(TensorDataset(train_x1_2, train_x1_2_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
        }

        self.val_dataloaders = {
            "x0": DataLoader(TensorDataset(val_x0, val_x0_weights), batch_size=self.batch_size, shuffle=False, drop_last=True),
            "x1_1": DataLoader(TensorDataset(val_x1_1, val_x1_1_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
            "x1_2": DataLoader(TensorDataset(val_x1_2, val_x1_2_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
        }

        # Test loaders: one val_x0-sized batch per population plus the full
        # lidar cloud in a single batch.
        self.test_dataloaders = {
            "x0": DataLoader(TensorDataset(val_x0, val_x0_weights), batch_size=self.val_x0.shape[0], shuffle=False, drop_last=False),
            "x1_1": DataLoader(TensorDataset(val_x1_1, val_x1_1_weights), batch_size=self.val_x0.shape[0], shuffle=True, drop_last=True),
            "x1_2": DataLoader(TensorDataset(val_x1_2, val_x1_2_weights), batch_size=self.val_x0.shape[0], shuffle=True, drop_last=True),
            "dataset": DataLoader(TensorDataset(self.dataset), batch_size=self.dataset.shape[0], shuffle=False, drop_last=False),
        }

        # Partition the cloud into three spatial regions for metric
        # evaluation: the lowest 30% along the diagonal u = (x + y)/sqrt(2),
        # then the remainder split by the x = y line.
        points = self.dataset.cpu().numpy()
        x, y = points[:, 0], points[:, 1]
        u = (x + y) / np.sqrt(2)
        u_thresh = np.percentile(u, 30)
        mask_A = u <= u_thresh

        remaining = ~mask_A
        mask_B = remaining & (x < y)
        mask_C = remaining & (x >= y)

        self.metric_samples_dataloaders = [
            DataLoader(
                torch.tensor(points[m], dtype=torch.float32),
                batch_size=points[m].shape[0],
                shuffle=False,
            )
            for m in (mask_A, mask_B, mask_C)
        ]

    def train_dataloader(self):
        """Combined loader pairing training populations with metric samples."""
        combined_loaders = {
            "train_samples": CombinedLoader(self.train_dataloaders, mode="min_size"),
            "metric_samples": CombinedLoader(
                self.metric_samples_dataloaders, mode="min_size"
            ),
        }
        return CombinedLoader(combined_loaders, mode="max_size_cycle")

    def val_dataloader(self):
        """Combined loader pairing validation populations with metric samples."""
        combined_loaders = {
            "val_samples": CombinedLoader(self.val_dataloaders, mode="min_size"),
            "metric_samples": CombinedLoader(
                self.metric_samples_dataloaders, mode="min_size"
            ),
        }
        return CombinedLoader(combined_loaders, mode="max_size_cycle")

    def test_dataloader(self):
        """Combined loader pairing test populations with metric samples."""
        combined_loaders = {
            "test_samples": CombinedLoader(self.test_dataloaders, mode="min_size"),
            "metric_samples": CombinedLoader(
                self.metric_samples_dataloaders, mode="min_size"
            ),
        }
        return CombinedLoader(combined_loaders, mode="max_size_cycle")

    def get_tangent_proj(self, points):
        """Return a callable projecting inputs onto the planes fit at `points`."""
        w = self.get_tangent_plane(points)
        return partial(BranchedLidarDataModule.projection_op, w=w)

    def get_tangent_plane(self, points, temp=1e-3):
        """Fit a locally weighted plane around each point.

        The k nearest lidar points are weighted by exp(-squared distance /
        temp), so nearer neighbors dominate the fit.
        """
        points_np = points.detach().cpu().numpy()
        _, idx = self.tree.query(points_np, k=self.k)
        # Index directly; the previous torch.tensor(tensor) call only
        # triggered an unnecessary-copy warning.
        nearest_pts = self.dataset[idx].to(points)

        dists = (points.unsqueeze(1) - nearest_pts).pow(2).sum(-1, keepdim=True)
        weights = torch.exp(-dists / temp)

        return BranchedLidarDataModule.fit_plane(nearest_pts, weights)

    @staticmethod
    def fit_plane(points, weights=None):
        """Least-squares plane fit. Expects points of shape (..., k, 3).

        Returns w = [a, b, c] such that the plane is a*x + b*y + c = z.
        When `weights` (shape (..., k, 1)) is given, solves the weighted
        normal equations D^T W D w = D^T W z.
        """
        D = torch.cat([points[..., :2], torch.ones_like(points[..., 2:3])], dim=-1)
        z = points[..., 2]
        # BUGFIX: the two branches were swapped — passing weights silently
        # ran the unweighted fit, and omitting them crashed on `D * None`.
        if weights is None:
            Dtrans = D.transpose(-1, -2)
        else:
            Dtrans = (D * weights).transpose(-1, -2)
        w = torch.linalg.solve(
            torch.matmul(Dtrans, D), torch.matmul(Dtrans, z.unsqueeze(-1))
        ).squeeze(-1)
        return w

    @staticmethod
    def projection_op(x, w):
        """Project points `x` onto the plane(s) defined by w = [a, b, c]."""
        # Plane normal n = [a, b, -1]; points p on the plane satisfy
        # n . p + c = 0.
        n = torch.cat([w[..., :2], -torch.ones_like(w[..., 2:3])], dim=-1)

        pn = torch.sum(x * n, dim=-1, keepdim=True)
        nn = torch.sum(n * n, dim=-1, keepdim=True)

        d = w[..., 2:3]

        # Component of x normal to the plane, removed to land on the plane.
        projn_x = ((pn + d) / nn) * n
        return x - projn_x

    def get_timepoint_data(self):
        """Return sampled populations organized by timepoint for visualization."""
        return {
            't0': self.coords_t0,
            't1_1': self.coords_t1_1,
            't1_2': self.coords_t1_2,
            'time_labels': self.time_labels
        }
|
|
def get_datamodule(args):
    """Construct and set up the weighted lidar datamodule.

    Args:
        args: configuration namespace with the attributes read by
            WeightedBranchedLidarDataModule (data_path, batch_size, dim,
            whiten, split_ratios).

    Returns:
        A WeightedBranchedLidarDataModule with setup(stage="fit") applied.

    BUGFIX: the previous version read a free variable ``args`` that was
    never defined at module scope, so every call raised NameError; it is
    now an explicit parameter.
    """
    datamodule = WeightedBranchedLidarDataModule(args)
    datamodule.setup(stage="fit")
    return datamodule