| """ |
| Two-stage optimization: |
| 1. SDF learning from point cloud |
| 2. Vertex generation + Delaunay meshing |
| |
| All pure PyTorch, no compiled extensions. |
| """ |
import os
import math
import time
import torch
import torch.nn.functional as F
import numpy as np
from scipy.spatial import KDTree
from tqdm import tqdm

from .sdfnet import SDFNetwork
from .vgnet import VGNetwork
from . import losses as loss_utils
from . import meshing as mesh_utils
from .io_utils import (
    load_pointcloud,
    normalize_pointcloud,
    denormalize_pointcloud,
    estimate_normals,
    fps_sample,
    build_sigma_knn,
    save_mesh_ply,
    save_mesh_obj,
)


class Runner:
    def __init__(self,
                 pointcloud_path,
                 out_dir='./output',
                 device='cpu',
                 sdf_iters=20_000,
                 vg_iters=8_000,
                 sdf_lr=1e-3,
                 vg_lr=1e-3,
                 sdf_batch=5_000,
                 vg_batch=3_400,
                 vertices_size=3_400,
                 update_size=5,
                 update_ratio=1.2,
                 k_samples=21,
                 multires=8,
                 queries_size=1_000_000,
                 surface_queries=200_000,
                 project_sdf_level=0.0,
                 save_freq=2_000,
                 loss_weights_sdf=None,
                 loss_weights_vg=None,
                 ):
        self.device = torch.device(device)
        self.out_dir = out_dir
        os.makedirs(out_dir, exist_ok=True)

        print("Loading point cloud...")
        raw_pts = load_pointcloud(pointcloud_path)
        self.raw_pts = raw_pts
        self.points, self.loc, self.scale = normalize_pointcloud(raw_pts)
        print(f"  Points: {len(self.points)}  Scale: {self.scale:.4f}")

        # Cap on the number of projected surface queries; consumed in
        # _preprocess_sdf_queries, so it must be set before the call below.
        self.surface_queries = surface_queries

        print("Preprocessing queries for SDF training...")
        self._preprocess_sdf_queries(queries_size)

        self.sdf_iters = sdf_iters
        self.vg_iters = vg_iters
        self.sdf_lr = sdf_lr
        self.vg_lr = vg_lr
        self.sdf_batch = sdf_batch
        self.vg_batch = vg_batch
        self.vertices_size = vertices_size
        self.update_size = update_size
        self.update_ratio = update_ratio
        self.k_samples = k_samples
        self.project_sdf_level = project_sdf_level
        self.save_freq = save_freq

        # SDF loss weights: [pull, sdf, grad_consistency, inter] (see train_sdf).
        self.loss_weights_sdf = loss_weights_sdf or [1.0, 0.1, 0.001, 0.0]
        # VG loss weights, consumed positionally by loss_utils.cal_vg_loss.
        self.loss_weights_vg = loss_weights_vg or [100.0, 1.0, 1.0, 1.0, 100.0]

        self.sdf_net = SDFNetwork(
            d_in=3, d_out=1, d_hidden=256, n_layers=8,
            skip_in=(4,), multires=multires,
            bias=0.5, scale=1.0,
            geometric_init=True, weight_norm=True,
        ).to(self.device)
        self.vg_net = VGNetwork(
            d_in=3, d_out=3, d_hidden=256, n_layers=8,
            skip_in=(4,), multires=multires,
            scale=1.0, geometric_init=True, weight_norm=True,
        ).to(self.device)

        self.sdf_optimizer = torch.optim.Adam(self.sdf_net.parameters(), lr=self.sdf_lr)
        self.vg_optimizer = torch.optim.Adam(self.vg_net.parameters(), lr=self.vg_lr)

        self.iter_step = 0

    def _preprocess_sdf_queries(self, queries_size):
        pts = self.points
        point_num = len(pts)
        # Round the ground-truth set down to a multiple of 60 so the query
        # budget divides evenly across points.
        point_num_gt = (point_num // 60) * 60
        if point_num_gt == 0:
            point_num_gt = point_num
        query_each = max(queries_size // point_num_gt, 1)

        if point_num > point_num_gt:
            idx = np.random.choice(point_num, point_num_gt, replace=False)
        else:
            idx = np.arange(point_num)
        subsample = pts[idx]

        # Per-point noise scale from the k-NN neighborhood radius.
        sigmas = build_sigma_knn(subsample, k=min(51, len(subsample)))

        sample = []
        sample_near = []
        scale = 0.25 * np.sqrt(max(point_num_gt, 1) / 20000.0)
        for _ in range(query_each):
            tt = subsample + scale * sigmas[:, None] * np.random.normal(0.0, 1.0, size=subsample.shape)
            sample.append(tt)
            sample_near.append(subsample)

        sample = np.concatenate(sample, axis=0).astype(np.float32)
        sample_near = np.concatenate(sample_near, axis=0).astype(np.float32)
        # Add a sprinkling of uniform queries in a slightly enlarged unit cube.
        n_uniform = max(sample.shape[0] // 10, 1)
        sample_uniform = 1.1 * (np.random.rand(n_uniform, 3).astype(np.float32) - 0.5)
        sample_uniform_near = subsample[np.random.choice(len(subsample), n_uniform, replace=True)]

        self.sample = torch.from_numpy(sample).to(self.device)
        self.sample_near = torch.from_numpy(sample_near).to(self.device)
        self.sample_uniform = torch.from_numpy(sample_uniform).to(self.device)
        self.sample_uniform_near = torch.from_numpy(sample_uniform_near).to(self.device)
        self.point_gt = torch.from_numpy(subsample).to(self.device)
        self.surface_queries_size = min(self.surface_queries, len(subsample))

        self.bbox_min = subsample.min(axis=0) - 0.05
        self.bbox_max = subsample.max(axis=0) + 0.05
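        # Sizing of the scheme above, with illustrative numbers (not enforced
        # defaults): for queries_size=1_000_000 and 120_000 input points,
        # point_num_gt is 120_000 and query_each is 8, giving 960_000 Gaussian
        # near-surface queries plus 96_000 uniform cube queries.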

    def train_sdf(self):
        print("\n=== Stage 1: SDF Learning ===")
        self.sdf_net.train()
        pbar = tqdm(range(self.sdf_iters), desc="SDF")
        for iter_i in pbar:
            self.update_lr(self.sdf_optimizer, iter_i, self.sdf_iters, self.sdf_lr, warm_up_end=1000)

            n_near = self.sdf_batch
            idx_near = np.random.choice(len(self.sample), n_near, replace=False)
            idx_uniform = np.random.choice(len(self.sample_uniform), max(n_near // 2, 1), replace=False)

            sample_near = self.sample[idx_near]
            points_near = self.sample_near[idx_near]
            sample_uniform = self.sample_uniform[idx_uniform]
            points_uniform = self.sample_uniform_near[idx_uniform]

            # Pull each query onto the zero level set along its normalized gradient.
            samples = torch.cat([sample_near, sample_uniform], dim=0)
            gradients_samples, sdf_samples = self.sdf_net.gradient(samples)
            gradients_samples_norm = F.normalize(gradients_samples, dim=-1)
            samples_moved = samples - gradients_samples_norm * sdf_samples

            # Gradient consistency: the gradient at the moved position should
            # agree with the gradient at the original query.
            move_pos = samples_moved.detach()
            grad_moved, _ = self.sdf_net.gradient(move_pos)
            grad_moved_norm = F.normalize(grad_moved, dim=-1)
            loss_grad_consis = (1.0 - F.cosine_similarity(grad_moved_norm, gradients_samples_norm, dim=-1)).mean()

            points = torch.cat([points_near, points_uniform], dim=0)
            sdf_points = self.sdf_net.sdf(points)

            loss_pull = torch.linalg.norm(points - samples_moved, ord=2, dim=-1).mean()
            loss_sdf = torch.abs(sdf_points).mean()
            loss_inter = torch.exp(-100.0 * torch.abs(sdf_samples)).mean()
            loss_normal = torch.zeros(1, device=self.device)  # placeholder; normals are not supervised here
            loss_eik = loss_utils.eikonal_loss(gradients_samples)
            loss_div = loss_utils.div_loss(samples, gradients_samples)

            w = self.loss_weights_sdf
            loss = (w[0] * loss_pull +
                    w[1] * loss_sdf +
                    w[2] * loss_grad_consis +
                    w[3] * loss_inter +
                    0.01 * loss_normal +
                    0.005 * loss_eik +
                    0.001 * loss_div)

            self.sdf_optimizer.zero_grad()
            loss.backward()
            self.sdf_optimizer.step()
            # Track progress so checkpoints record the true iteration count.
            self.iter_step = iter_i + 1

            if (iter_i + 1) % 500 == 0:
                pbar.set_postfix(loss=f"{loss.item():.4f}")
            if (iter_i + 1) % self.save_freq == 0:
                self.save_sdf_checkpoint(iter_i + 1)

        print("SDF training complete.")
        self.save_sdf_checkpoint('final')

    def update_lr(self, optimizer, iter_step, max_iter, init_lr, warm_up_end=1000):
        """Linear warmup followed by cosine decay to zero."""
        if iter_step < warm_up_end:
            lr = (iter_step / warm_up_end) * init_lr
        else:
            lr = 0.5 * (math.cos((iter_step - warm_up_end) / (max_iter - warm_up_end) * math.pi) + 1) * init_lr
        for g in optimizer.param_groups:
            g['lr'] = lr
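    # Schedule sanity check (with e.g. init_lr=1e-3, warm_up_end=1000,
    # max_iter=20_000): lr ramps linearly 0 -> 1e-3 over the first 1000 iters,
    # sits at 5e-4 at the cosine midpoint (iter 10_500), and decays toward 0
    # at max_iter.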

    def save_sdf_checkpoint(self, tag):
        ckpt = {
            'iter_step': self.iter_step,
            'sdf_network': self.sdf_net.state_dict(),
        }
        os.makedirs(os.path.join(self.out_dir, 'sdf_checkpoints'), exist_ok=True)
        torch.save(ckpt, os.path.join(self.out_dir, 'sdf_checkpoints', f'sdf_{tag}.pth'))

    def load_sdf_checkpoint(self, path):
        ckpt = torch.load(path, map_location=self.device)
        self.sdf_net.load_state_dict(ckpt['sdf_network'])
        self.iter_step = ckpt.get('iter_step', 0)
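    # To skip Stage 1 on a rerun (a sketch; the 'final' tag is the one written
    # by train_sdf):
    #   runner.load_sdf_checkpoint(
    #       os.path.join(out_dir, 'sdf_checkpoints', 'sdf_final.pth'))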

    @torch.no_grad()
    def get_surface_queries(self, noisy_pts=False):
        """Project point_gt onto the learned SDF level set, padding with
        projected, FPS-selected training queries up to surface_queries_size."""
        sdf_level = self.project_sdf_level
        queries = self.point_gt.clone()
        if noisy_pts or sdf_level != 0.0:
            queries = self.project_queries(queries, sdf_level)

        n = len(queries)
        target = min(self.surface_queries_size, n + len(self.sample))
        if target > n:
            pad_size = target - n
            pad_queries = self.sample.clone()
            pad_queries = self.project_queries(pad_queries, sdf_level)
            idx = fps_sample(pad_queries.cpu().numpy(), pad_size)
            pad_queries = pad_queries[idx]
            queries = torch.cat([queries, pad_queries], dim=0)
        return queries.detach()

    @torch.no_grad()
    def project_queries(self, queries, sdf_level):
        # Note: this runs under no_grad, so it assumes SDFNetwork.gradient
        # re-enables autograd internally when computing the spatial gradient.
        batch_size = 100_000
        out = []
        for i in range(0, len(queries), batch_size):
            batch = queries[i:i + batch_size]
            for _ in range(10):
                grad, sdf = self.sdf_net.gradient(batch)
                grad = F.normalize(grad, dim=-1)
                batch = batch - grad * (sdf - sdf_level)
            out.append(batch)
        return torch.cat(out, dim=0)
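    # Each pass above is a damped Newton step x <- x - n * (f(x) - level) along
    # the unit gradient n; assuming the eikonal term keeps |grad f| near 1,
    # ten passes are typically enough to converge onto the level set.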

    def train_vg(self, vertices_size=None):
        if vertices_size is None:
            vertices_size = self.vertices_size
        print(f"\n=== Stage 2: Vertex Generation ({vertices_size} vertices) ===")
        self.vg_net.train()
        self.sdf_net.eval()

        print("Projecting surface queries...")
        point_gt = self.get_surface_queries()
        print(f"  Surface queries: {len(point_gt)}")

        # Reference normals and curvature on the projected surface queries.
        normal_gt, _ = self.sdf_net.gradient(point_gt)
        normal_gt = F.normalize(normal_gt.detach(), dim=-1)
        curvature_surface = loss_utils.cal_curvature_with_normal(
            point_gt, normal_gt, knn=min(16, len(point_gt) - 1)).detach()

        # Curriculum: grow the vertex set in stages toward vertices_size.
        batch_sizes = self.generate_list_with_ratio(vertices_size)
        print(f"  Curriculum sizes: {batch_sizes}")

        cur_size_idx = 0
        current_batch_size = batch_sizes[cur_size_idx]
        sample_points = self.fps_select_vertices(point_gt, current_batch_size)
        sample_normal, _ = self.sdf_net.gradient(sample_points)
        sample_normal = F.normalize(sample_normal.detach(), dim=-1)

        pbar = tqdm(range(self.vg_iters), desc="VG")
        size_update_freq = self.vg_iters // (self.update_size + 1)
        if size_update_freq == 0:
            size_update_freq = self.vg_iters

        nearest_clamp = self.cal_nearest_clamp(sample_points)

        for iter_i in pbar:
            generated = self.vg_net(sample_points, sample_normal)
            vertices_grad, _ = self.sdf_net.gradient(generated)

            loss = loss_utils.cal_vg_loss(
                point_gt, normal_gt, curvature_surface,
                generated, vertices_grad,
                self.loss_weights_vg, nearest_clamp)

            self.vg_optimizer.zero_grad()
            loss.backward(retain_graph=True)
            self.vg_optimizer.step()

            if (iter_i + 1) % 500 == 0:
                pbar.set_postfix(loss=f"{loss.item():.4f}")

            # Grow the vertex set, upsampling near high-curvature regions.
            if (iter_i + 1) % size_update_freq == 0:
                cur_size_idx += 1
                if cur_size_idx < len(batch_sizes):
                    current_batch_size = batch_sizes[cur_size_idx]
                    moved = self.move_to_surface(generated)
                    curv = loss_utils.cal_curvature_with_normal(
                        moved, F.normalize(vertices_grad.detach(), dim=-1),
                        knn=min(16, len(moved) - 1))
                    sample_points = self.upsample(curv, moved, point_gt, current_batch_size)
                    sample_points = sample_points.detach()
                    sn, _ = self.sdf_net.gradient(sample_points)
                    sample_normal = F.normalize(sn.detach(), dim=-1)
                    nearest_clamp = self.cal_nearest_clamp(sample_points)

        final_vertices = self.move_to_surface(generated).detach().cpu().numpy()
        print(f"  Generated {len(final_vertices)} vertices.")
        return final_vertices
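    # Curriculum timing (with the defaults vg_iters=8_000, update_size=5):
    # size_update_freq = 1_333, so the vertex count steps up at iterations
    # 1333/2666/3999/5332/6665, and the last ~1_300 iterations train at the
    # full vertices_size.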

    def generate_list_with_ratio(self, final_size):
        """Build monotonically increasing curriculum vertex counts."""
        sizes = [max(int(final_size / (self.update_ratio ** (self.update_size - i))), 1)
                 for i in range(self.update_size)]
        sizes.append(final_size)
        # Enforce strictly increasing sizes.
        for i in range(1, len(sizes)):
            sizes[i] = max(sizes[i], sizes[i - 1] + 1)
        return sizes
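    # Worked example: update_size=5, update_ratio=1.2, final_size=3400 gives
    # [1366, 1639, 1967, 2361, 2833, 3400].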

    def fps_select_vertices(self, point_gt, batch_size):
        idx = fps_sample(point_gt.cpu().numpy(), min(batch_size, len(point_gt)))
        return point_gt[idx].detach()

    def cal_nearest_clamp(self, sample_pts):
        """Mean squared nearest-neighbor distance among the current vertices."""
        pts_np = sample_pts.detach().cpu().numpy()
        tree = KDTree(pts_np)
        _, idx = tree.query(pts_np, k=2)  # k=2: the first hit is the point itself
        idx = torch.from_numpy(idx[:, 1]).long().to(sample_pts.device)
        neigh = sample_pts[idx]
        dist = torch.linalg.norm(neigh - sample_pts, ord=2, dim=-1) ** 2
        return dist.mean().item()
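    # nearest_clamp is handed to cal_vg_loss, presumably as a distance scale;
    # train_vg recomputes it after every curriculum growth step so it tracks
    # the shrinking vertex spacing.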

    def move_to_surface(self, generated, step=10):
        for _ in range(step):
            grad, sdf = self.sdf_net.gradient(generated)
            grad = F.normalize(grad.detach(), dim=-1)
            generated = generated - grad * (sdf.detach() - self.project_sdf_level)
        return generated.detach()
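    # Same fixed-point projection as project_queries, but gradient-safe without
    # a no_grad context: grad and sdf are detached at every step.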

    def upsample(self, curvature, pts, point_gt, sample_size):
        """Upsample toward sample_size by adding the nearest ground-truth
        neighbors of the highest-curvature vertices. May return fewer than
        sample_size points when the deficit exceeds len(pts)."""
        if len(pts) >= sample_size:
            return pts[:sample_size]
        up = sample_size - len(pts)
        topk = min(up, len(pts))
        _, top_idx = torch.topk(curvature.view(-1), k=topk, largest=True)
        best = pts[top_idx]
        tree = KDTree(point_gt.cpu().numpy())
        _, idx = tree.query(best.cpu().numpy(), k=1)
        idx = torch.from_numpy(idx).long().to(pts.device)
        added = point_gt[idx]
        return torch.cat([pts, added], dim=0)

    def generate_mesh(self, vertices, save_path=None):
        print("\n=== Meshing ===")
        v, f = mesh_utils.delaunay_meshing(
            vertices, self.sdf_net,
            sdf_threshold=self.project_sdf_level,
            k_samples=self.k_samples,
            device=self.device)

        if len(f) > 0:
            v, f = mesh_utils.add_mid_vertices(v, f)

        # Undo the normalization applied at load time.
        v = denormalize_pointcloud(v, self.loc, self.scale)

        if save_path:
            if save_path.endswith('.obj'):
                save_mesh_obj(save_path, v, f)
            else:
                save_mesh_ply(save_path, v, f)
            print(f"Saved mesh to {save_path}")
        return v, f

    def run(self, mesh_path=None):
        self.train_sdf()
        vertices = self.train_vg()
        v, f = self.generate_mesh(vertices, save_path=mesh_path)
        return v, f
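

if __name__ == '__main__':
    # Minimal smoke-test entry point (a sketch: the input path is a placeholder
    # and the reduced iteration counts are illustrative, not recommended
    # settings). Run with `python -m <package>.<module>` so the relative
    # imports resolve.
    runner = Runner(
        pointcloud_path='./data/input.ply',
        out_dir='./output',
        device='cuda' if torch.cuda.is_available() else 'cpu',
        sdf_iters=2_000,
        vg_iters=1_000,
    )
    runner.run(mesh_path='./output/mesh.ply')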