"""
Two-stage optimization:

1. SDF learning from point cloud
2. Vertex generation + Delaunay meshing

All pure PyTorch, no compiled extensions.
"""
import os
import math
import time

import torch
import torch.nn.functional as F
import numpy as np
from tqdm import tqdm

from .sdfnet import SDFNetwork
from .vgnet import VGNetwork
from . import losses as loss_utils
from . import meshing as mesh_utils
from .io_utils import (
    load_pointcloud,
    normalize_pointcloud,
    denormalize_pointcloud,
    estimate_normals,
    fps_sample,
    build_sigma_knn,
    save_mesh_ply,
    save_mesh_obj,
)


class Runner:
    """Drive the surface-reconstruction pipeline end to end.

    Stage 1 (``train_sdf``) fits an ``SDFNetwork`` to the normalized point
    cloud. The objective is pull-style: each query is moved along its
    normalized SDF gradient by its SDF value and pulled onto a nearby input
    point.

    Stage 2 (``train_vg``) trains a ``VGNetwork`` that places mesh vertices on
    the learned level set ``project_sdf_level``. It uses a curriculum over
    vertex counts.

    ``generate_mesh`` then triangulates the vertices with
    ``meshing.delaunay_meshing`` and maps them back to the input frame.
    """

    def __init__(self, pointcloud_path, out_dir='./output', device='cpu',
                 sdf_iters=20_000, vg_iters=8_000,
                 sdf_lr=1e-3, vg_lr=1e-3,
                 sdf_batch=5_000, vg_batch=3_400,
                 vertices_size=3_400, update_size=5, update_ratio=1.2,
                 k_samples=21, multires=8,
                 queries_size=1_000_000, surface_queries=200_000,
                 project_sdf_level=0.0, save_freq=2_000,
                 loss_weights_sdf=None, loss_weights_vg=None,
                 ):
        """Load and normalize the point cloud, precompute SDF queries, build networks.

        Args:
            pointcloud_path: Path handed to ``load_pointcloud``.
            out_dir: Output directory; created if missing. Checkpoints go to
                ``<out_dir>/sdf_checkpoints``.
            device: Torch device string for networks and cached tensors.
            sdf_iters / vg_iters: Iteration counts of stage 1 / stage 2.
            sdf_lr / vg_lr: Adam learning rates. ``sdf_lr`` is the peak of
                the warm-up + cosine schedule in ``update_lr``.
            sdf_batch: Near-surface queries per SDF step. Half as many
                uniform queries are added.
            vg_batch: Stored on the instance but not read in this module.
            vertices_size: Default final vertex count for ``train_vg``.
            update_size / update_ratio: Curriculum length and growth ratio
                (see ``generate_list_with_ratio``).
            k_samples: Forwarded to ``meshing.delaunay_meshing``.
            multires: Forwarded to both networks. Presumably the number of
                positional-encoding frequencies; confirm in ``sdfnet``.
            queries_size: Approximate number of Gaussian SDF training queries.
            surface_queries: Upper bound on the surface query count used by
                ``get_surface_queries``. It is additionally capped at the
                subsampled cloud size.
            project_sdf_level: Level set that vertices and queries are
                projected onto.
            save_freq: SDF checkpoint period, in iterations.
            loss_weights_sdf / loss_weights_vg: Loss weight lists. ``None``
                selects the defaults below.
        """
        self.device = torch.device(device)
        self.out_dir = out_dir
        os.makedirs(out_dir, exist_ok=True)

        # Load & normalize point cloud
        print("Loading point cloud...")
        raw_pts = load_pointcloud(pointcloud_path)
        self.raw_pts = raw_pts
        self.points, self.loc, self.scale = normalize_pointcloud(raw_pts)
        print(f" Points: {len(self.points)} Scale: {self.scale:.4f}")

        # Preprocess: build sigma + query samples for SDF training
        print("Preprocessing queries for SDF training...")
        self._preprocess_sdf_queries(queries_size, surface_queries=surface_queries)

        # Config
        self.sdf_iters = sdf_iters
        self.vg_iters = vg_iters
        self.sdf_lr = sdf_lr
        self.vg_lr = vg_lr
        self.sdf_batch = sdf_batch
        self.vg_batch = vg_batch
        self.vertices_size = vertices_size
        self.update_size = update_size
        self.update_ratio = update_ratio
        self.k_samples = k_samples
        self.project_sdf_level = project_sdf_level
        self.save_freq = save_freq
        # SDF weights: [pull, sdf-on-points, gradient consistency, intersection].
        self.loss_weights_sdf = loss_weights_sdf or [1.0, 0.1, 0.001, 0.0]
        # VG weights are consumed by losses.cal_vg_loss; their order is defined there.
        self.loss_weights_vg = loss_weights_vg or [100.0, 1.0, 1.0, 1.0, 100.0]

        # Networks
        self.sdf_net = SDFNetwork(
            d_in=3, d_out=1, d_hidden=256, n_layers=8, skip_in=(4,),
            multires=multires, bias=0.5, scale=1.0,
            geometric_init=True, weight_norm=True,
        ).to(self.device)
        self.vg_net = VGNetwork(
            d_in=3, d_out=3, d_hidden=256, n_layers=8, skip_in=(4,),
            multires=multires, scale=1.0,
            geometric_init=True, weight_norm=True,
        ).to(self.device)

        self.sdf_optimizer = torch.optim.Adam(self.sdf_net.parameters(), lr=self.sdf_lr)
        self.vg_optimizer = torch.optim.Adam(self.vg_net.parameters(), lr=self.vg_lr)
        # Number of completed SDF iterations; saved in and restored from checkpoints.
        self.iter_step = 0

    # ------------------------------------------------------------------
    # SDF preprocessing
    # ------------------------------------------------------------------
    def _preprocess_sdf_queries(self, queries_size, surface_queries=200_000):
        """Build the query/target pairs consumed by ``train_sdf``.

        A random subset of the points is kept. Its size is the largest
        multiple of 60 not exceeding the cloud size, or the whole cloud when
        it has fewer than 60 points.

        For every kept point, ``query_each`` Gaussian-perturbed queries are
        drawn with per-point std ``scale * sigma``, where sigma comes from
        ``build_sigma_knn``. About 10% more queries are drawn uniformly in the
        cube ``[-0.55, 0.55]^3`` and paired with random kept points.

        Sets ``sample``/``sample_near``, ``sample_uniform``/
        ``sample_uniform_near``, ``point_gt``, ``surface_queries_size`` and
        ``bbox_min``/``bbox_max``.

        Raises:
            ValueError: if the point cloud is empty.
        """
        pts = self.points
        point_num = len(pts)
        if point_num == 0:
            raise ValueError("point cloud contains no points")

        # Truncate to a multiple of 60 (all points when fewer than 60).
        point_num_gt = (point_num // 60) * 60
        if point_num_gt == 0:
            point_num_gt = point_num
        query_each = max(queries_size // point_num_gt, 1)

        if point_num > point_num_gt:
            idx = np.random.choice(point_num, point_num_gt, replace=False)
        else:
            idx = np.arange(point_num)
        # float32 to match the query arrays below and the network weights.
        subsample = np.asarray(pts[idx], dtype=np.float32)

        # NOTE(review): assumes build_sigma_knn returns one scale per point, shape (N,).
        sigmas = build_sigma_knn(subsample, k=min(51, len(subsample)))

        scale = 0.25 * np.sqrt(max(point_num_gt, 1) / 20000.0)
        sample = []
        sample_near = []
        for _ in range(query_each):
            noise = np.random.normal(0.0, 1.0, size=subsample.shape)
            sample.append(subsample + scale * sigmas[:, None] * noise)
            sample_near.append(subsample)
        sample = np.concatenate(sample, axis=0).astype(np.float32)
        sample_near = np.concatenate(sample_near, axis=0).astype(np.float32)

        n_uniform = max(sample.shape[0] // 10, 1)
        sample_uniform = 1.1 * (np.random.rand(n_uniform, 3).astype(np.float32) - 0.5)
        sample_uniform_near = subsample[np.random.choice(len(subsample), n_uniform, replace=True)]

        self.sample = torch.from_numpy(sample).to(self.device)
        self.sample_near = torch.from_numpy(sample_near).to(self.device)
        self.sample_uniform = torch.from_numpy(sample_uniform).to(self.device)
        self.sample_uniform_near = torch.from_numpy(sample_uniform_near).to(self.device)
        self.point_gt = torch.from_numpy(subsample).to(self.device)
        # NOTE(review): because this is capped at len(subsample) == len(point_gt),
        # the padding branch of get_surface_queries can never trigger.
        self.surface_queries_size = min(surface_queries, len(subsample))

        # bbox
        self.bbox_min = subsample.min(axis=0) - 0.05
        self.bbox_max = subsample.max(axis=0) + 0.05

    # ------------------------------------------------------------------
    # SDF stage
    # ------------------------------------------------------------------
    def train_sdf(self):
        """Train ``sdf_net`` for ``sdf_iters`` steps, checkpointing every ``save_freq``.

        Each step combines several terms:
          - pull loss: distance between a target point and its moved query;
          - ``|sdf|`` on target points;
          - gradient consistency between a query and its moved position;
          - an optional intersection term;
          - fixed-weight eikonal and divergence terms from ``losses``.

        A final ``sdf_final.pth`` checkpoint is written at the end.
        """
        print("\n=== Stage 1: SDF Learning ===")
        self.sdf_net.train()
        # Clamp batch sizes so replace=False sampling cannot fail on tiny query sets.
        n_near = min(self.sdf_batch, len(self.sample))
        n_far = min(max(n_near // 2, 1), len(self.sample_uniform))
        pbar = tqdm(range(self.sdf_iters), desc="SDF")
        for iter_i in pbar:
            self.update_lr(self.sdf_optimizer, iter_i, self.sdf_iters, self.sdf_lr, warm_up_end=1000)

            # Sample batch: noisy near-surface queries plus uniform queries.
            idx_near = np.random.choice(len(self.sample), n_near, replace=False)
            idx_uniform = np.random.choice(len(self.sample_uniform), n_far, replace=False)
            sample_near = self.sample[idx_near]
            points_near = self.sample_near[idx_near]
            sample_uniform = self.sample_uniform[idx_uniform]
            points_uniform = self.sample_uniform_near[idx_uniform]
            samples = torch.cat([sample_near, sample_uniform], dim=0)

            gradients_samples, sdf_samples = self.sdf_net.gradient(samples)
            gradients_samples_norm = F.normalize(gradients_samples, dim=-1)
            # Pull each query along the normalized gradient by its SDF value.
            samples_moved = samples - gradients_samples_norm * sdf_samples

            # Gradient consistency: normals at the moved position should match.
            move_pos = samples_moved.detach()
            grad_moved, _ = self.sdf_net.gradient(move_pos)
            grad_moved_norm = F.normalize(grad_moved, dim=-1)
            loss_grad_consis = (1.0 - F.cosine_similarity(
                grad_moved_norm, gradients_samples_norm, dim=-1)).mean()

            points = torch.cat([points_near, points_uniform], dim=0)
            sdf_points = self.sdf_net.sdf(points)

            loss_pull = torch.linalg.norm((points - samples_moved), ord=2, dim=-1).mean()
            loss_sdf = torch.abs(sdf_points).mean()
            loss_inter = torch.exp(-100.0 * torch.abs(sdf_samples)).mean()
            # Placeholder: no normal supervision is used.
            loss_normal = torch.zeros(1, device=self.device)
            loss_eik = loss_utils.eikonal_loss(gradients_samples)
            loss_div = loss_utils.div_loss(samples, gradients_samples)

            w = self.loss_weights_sdf
            loss = (w[0] * loss_pull + w[1] * loss_sdf + w[2] * loss_grad_consis
                    + w[3] * loss_inter + 0.01 * loss_normal + 0.005 * loss_eik
                    + 0.001 * loss_div)

            self.sdf_optimizer.zero_grad()
            loss.backward()
            self.sdf_optimizer.step()
            # Track progress so checkpoints record the real iteration count.
            self.iter_step = iter_i + 1

            if (iter_i + 1) % 500 == 0:
                pbar.set_postfix(loss=f"{loss.item():.4f}")
            if (iter_i + 1) % self.save_freq == 0:
                self.save_sdf_checkpoint(iter_i + 1)

        print("SDF training complete.")
        self.save_sdf_checkpoint('final')

    def update_lr(self, optimizer, iter_step, max_iter, init_lr, warm_up_end=1000):
        """Set every param group's lr: linear warm-up to ``init_lr``, then cosine decay to 0."""
        if iter_step < warm_up_end:
            lr = (iter_step / warm_up_end) * init_lr
        else:
            # Guard against max_iter == warm_up_end (zero-length decay span).
            decay_span = max(max_iter - warm_up_end, 1)
            progress = (iter_step - warm_up_end) / decay_span
            lr = 0.5 * (math.cos(progress * math.pi) + 1) * init_lr
        for g in optimizer.param_groups:
            g['lr'] = lr

    def save_sdf_checkpoint(self, tag):
        """Save ``iter_step`` and the SDF weights to ``<out_dir>/sdf_checkpoints/sdf_<tag>.pth``."""
        ckpt = {
            'iter_step': self.iter_step,
            'sdf_network': self.sdf_net.state_dict(),
        }
        ckpt_dir = os.path.join(self.out_dir, 'sdf_checkpoints')
        os.makedirs(ckpt_dir, exist_ok=True)
        torch.save(ckpt, os.path.join(ckpt_dir, f'sdf_{tag}.pth'))

    def load_sdf_checkpoint(self, path):
        """Restore SDF weights (and ``iter_step``, defaulting to 0) from ``path``."""
        ckpt = torch.load(path, map_location=self.device)
        self.sdf_net.load_state_dict(ckpt['sdf_network'])
        self.iter_step = ckpt.get('iter_step', 0)

    # ------------------------------------------------------------------
    # VG stage helpers
    # ------------------------------------------------------------------
    @torch.no_grad()
    def get_surface_queries(self, noisy_pts=False):
        """Project point_gt onto the learned SDF surface.

        Projection only runs when ``noisy_pts`` is set or ``project_sdf_level``
        is non-zero; otherwise ``point_gt`` is returned as-is.

        If ``surface_queries_size`` exceeds the point count, the result is
        padded with FPS-selected projected training queries. With the current
        cap in ``_preprocess_sdf_queries`` this does not happen.
        """
        sdf_level = self.project_sdf_level
        queries = self.point_gt.clone()
        if noisy_pts or sdf_level != 0.0:
            queries = self.project_queries(queries, sdf_level)

        n = len(queries)
        target = min(self.surface_queries_size, n + len(self.sample))
        if target > n:
            pad_size = target - n
            # Use FPS on projected samples
            pad_queries = self.project_queries(self.sample.clone(), sdf_level)
            idx = fps_sample(pad_queries.cpu().numpy(), pad_size)
            pad_queries = pad_queries[idx]
            queries = torch.cat([queries, pad_queries], dim=0)
        return queries.detach()

    @torch.no_grad()
    def project_queries(self, queries, sdf_level):
        """Move ``queries`` onto the ``sdf_level`` set with 10 gradient-projection steps.

        Points are processed in chunks of 100k. Autograd is re-enabled locally
        for ``sdf_net.gradient``, since this method otherwise runs under
        ``no_grad``. Each step's result is detached so no graph accumulates.
        """
        batch_size = 100_000
        out = []
        for i in range(0, len(queries), batch_size):
            batch = queries[i:i + batch_size]
            for _ in range(10):
                with torch.enable_grad():
                    grad, sdf = self.sdf_net.gradient(batch)
                grad = F.normalize(grad.detach(), dim=-1)
                batch = (batch - grad * (sdf.detach() - sdf_level)).detach()
            out.append(batch)
        return torch.cat(out, dim=0)

    # ------------------------------------------------------------------
    # VG stage
    # ------------------------------------------------------------------
    def train_vg(self, vertices_size=None):
        """Train ``vg_net`` and return the final surface-projected vertices as a numpy array.

        The vertex count follows the curriculum from ``generate_list_with_ratio``.
        It starts from an FPS subset of the surface queries and grows at
        evenly spaced iterations by ``upsample``-ing toward high-curvature
        regions.

        Args:
            vertices_size: Final vertex count. Defaults to ``self.vertices_size``.
        """
        if vertices_size is None:
            vertices_size = self.vertices_size
        print(f"\n=== Stage 2: Vertex Generation ({vertices_size} vertices) ===")
        self.vg_net.train()
        self.sdf_net.eval()

        # Build target surface queries
        print("Projecting surface queries...")
        point_gt = self.get_surface_queries()
        print(f" Surface queries: {len(point_gt)}")

        # Normals and curvature of the target surface (fixed during training).
        normal_gt, _ = self.sdf_net.gradient(point_gt)
        normal_gt = F.normalize(normal_gt.detach(), dim=-1)
        curvature_surface = loss_utils.cal_curvature_with_normal(
            point_gt, normal_gt, knn=min(16, len(point_gt) - 1)).detach()

        # Generate curriculum sizes and seed the first stage with FPS.
        batch_sizes = self.generate_list_with_ratio(vertices_size)
        print(f" Curriculum sizes: {batch_sizes}")
        cur_size_idx = 0
        current_batch_size = batch_sizes[cur_size_idx]
        sample_points = self.fps_select_vertices(point_gt, current_batch_size)
        sample_normal, _ = self.sdf_net.gradient(sample_points)
        sample_normal = F.normalize(sample_normal.detach(), dim=-1)

        pbar = tqdm(range(self.vg_iters), desc="VG")
        size_update_freq = self.vg_iters // (self.update_size + 1)
        if size_update_freq == 0:
            size_update_freq = self.vg_iters
        nearest_clamp = self.cal_nearest_clamp(sample_points)

        generated = None
        for iter_i in pbar:
            generated = self.vg_net(sample_points, sample_normal)
            vertices_grad, _ = self.sdf_net.gradient(generated)
            loss = loss_utils.cal_vg_loss(
                point_gt, normal_gt, curvature_surface, generated, vertices_grad,
                self.loss_weights_vg, nearest_clamp)

            self.vg_optimizer.zero_grad()
            # NOTE(review): retain_graph=True is kept from the original. Every
            # iteration rebuilds its graph, so it may be unnecessary; confirm
            # against losses.cal_vg_loss before removing.
            loss.backward(retain_graph=True)
            self.vg_optimizer.step()

            if (iter_i + 1) % 500 == 0:
                pbar.set_postfix(loss=f"{loss.item():.4f}")

            # Curriculum: increase vertex count
            if (iter_i + 1) % size_update_freq == 0:
                cur_size_idx += 1
                if cur_size_idx < len(batch_sizes):
                    current_batch_size = batch_sizes[cur_size_idx]
                    moved = self.move_to_surface(generated)
                    curv = loss_utils.cal_curvature_with_normal(
                        moved, F.normalize(vertices_grad.detach(), dim=-1),
                        knn=min(16, len(moved) - 1))
                    sample_points = self.upsample(curv, moved, point_gt, current_batch_size)
                    sample_points = sample_points.detach()
                    sn, _ = self.sdf_net.gradient(sample_points)
                    sample_normal = F.normalize(sn.detach(), dim=-1)
                    nearest_clamp = self.cal_nearest_clamp(sample_points)

        if generated is None:
            # vg_iters <= 0: no optimisation step ran, so evaluate the network once.
            generated = self.vg_net(sample_points, sample_normal)

        # Final projection to surface
        final_vertices = self.move_to_surface(generated).detach().cpu().numpy()
        print(f" Generated {len(final_vertices)} vertices.")
        return final_vertices

    def generate_list_with_ratio(self, final_size):
        """Build curriculum vertex counts ending at exactly ``final_size``.

        Entry ``i`` (for ``i < update_size``) is
        ``int(final_size / update_ratio ** (update_size - i))``, and
        ``final_size`` is appended last. Entries are then clamped to
        ``[1, final_size]`` and made non-decreasing, so the curriculum never
        overshoots the requested size.
        """
        sizes = [int(final_size / (self.update_ratio ** (self.update_size - i)))
                 for i in range(self.update_size)]
        sizes.append(final_size)
        lower = min(1, final_size)
        sizes = [min(max(s, lower), final_size) for s in sizes]
        # Ensure monotonic (non-decreasing); the last entry stays final_size.
        for i in range(1, len(sizes)):
            sizes[i] = max(sizes[i], sizes[i - 1])
        return sizes

    def fps_select_vertices(self, point_gt, batch_size):
        """Return ``min(batch_size, len(point_gt))`` detached points chosen by ``fps_sample``."""
        idx = fps_sample(point_gt.cpu().numpy(), min(batch_size, len(point_gt)))
        return point_gt[idx].detach()

    def cal_nearest_clamp(self, sample_pts):
        """Return the mean squared distance from each vertex to its nearest other vertex.

        Returns 0.0 when fewer than two vertices are given, since no neighbour exists.
        """
        if len(sample_pts) < 2:
            return 0.0
        from scipy.spatial import KDTree
        pts_np = sample_pts.detach().cpu().numpy()
        _, idx = KDTree(pts_np).query(pts_np, k=2)
        # Column 0 is normally the query point itself; column 1 is its nearest neighbour.
        idx = torch.from_numpy(idx[:, 1]).long().to(sample_pts.device)
        neigh = sample_pts[idx]
        dist = torch.linalg.norm(neigh - sample_pts, ord=2, dim=-1) ** 2
        return dist.mean().item()

    def move_to_surface(self, generated, step=10):
        """Project points onto the ``project_sdf_level`` set with ``step`` gradient steps.

        Returns a detached tensor. Each step is detached, so no autograd graph
        is kept alive for the input.
        """
        generated = generated.detach()
        for _ in range(step):
            grad, sdf = self.sdf_net.gradient(generated)
            grad = F.normalize(grad.detach(), dim=-1)
            generated = (generated - grad * (sdf.detach() - self.project_sdf_level)).detach()
        return generated.detach()

    def upsample(self, curvature, pts, point_gt, sample_size):
        """Upsample to target size by adding high-curvature neighbors.

        If ``pts`` already has at least ``sample_size`` rows, its first
        ``sample_size`` rows are returned. Otherwise, take the
        ``sample_size - len(pts)`` highest-curvature points, capped at
        ``len(pts)``. For each, append its nearest point from ``point_gt``.
        Because of that cap, the result can be shorter than ``sample_size``
        when ``sample_size > 2 * len(pts)``.
        """
        if len(pts) >= sample_size:
            return pts[:sample_size]
        up = sample_size - len(pts)
        topk = min(up, len(pts))
        _, top_idx = torch.topk(curvature.view(-1), k=topk, largest=True)
        best = pts[top_idx]
        from scipy.spatial import KDTree
        tree = KDTree(point_gt.cpu().numpy())
        _, idx = tree.query(best.cpu().numpy(), k=1)
        idx = torch.from_numpy(idx).long().to(pts.device)
        added = point_gt[idx]
        return torch.cat([pts, added], dim=0)

    # ------------------------------------------------------------------
    # Meshing
    # ------------------------------------------------------------------
    def generate_mesh(self, vertices, save_path=None):
        """Triangulate ``vertices``, denormalize them and optionally save.

        ``add_mid_vertices`` runs only when meshing produced faces. Paths
        ending in ``.obj`` (any case) are written as OBJ; all other paths
        are written as PLY.

        Returns:
            ``(v, f)``: vertices in the input frame, and faces.
        """
        print("\n=== Meshing ===")
        v, f = mesh_utils.delaunay_meshing(
            vertices, self.sdf_net,
            sdf_threshold=self.project_sdf_level,
            k_samples=self.k_samples,
            device=self.device)
        if len(f) > 0:
            v, f = mesh_utils.add_mid_vertices(v, f)
        # Denormalize
        v = denormalize_pointcloud(v, self.loc, self.scale)
        if save_path:
            if save_path.lower().endswith('.obj'):
                save_mesh_obj(save_path, v, f)
            else:
                save_mesh_ply(save_path, v, f)
            print(f"Saved mesh to {save_path}")
        return v, f

    # ------------------------------------------------------------------
    # End-to-end
    # ------------------------------------------------------------------
    def run(self, mesh_path=None):
        """Run SDF training, vertex generation and meshing; return ``(v, f)``."""
        self.train_sdf()
        vertices = self.train_vg()
        v, f = self.generate_mesh(vertices, save_path=mesh_path)
        return v, f