bdck commited on
Commit
0b9d51d
Β·
verified Β·
1 Parent(s): 03b9058

add main optimization loop

Browse files
Files changed (1) hide show
  1. point2mesh/optimize.py +292 -0
point2mesh/optimize.py ADDED
@@ -0,0 +1,292 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Core Point2Mesh optimisation loop.
3
+
4
+ Implements the coarse-to-fine self-prior optimisation: at each level
5
+ the CNN weights are re-initialised, a fixed random input is drawn,
6
+ and the network learns to deform the mesh surface to match the target
7
+ point cloud. Between levels the mesh is subdivided and remeshed.
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ import logging
13
+ import time
14
+ from dataclasses import dataclass, field
15
+ from typing import Optional, Callable
16
+
17
+ import numpy as np
18
+ import torch
19
+
20
+ from .mesh import Mesh, PartMesh, edge_to_vertex_displacement
21
+ from .network import Point2MeshNet
22
+ from .losses import sample_surface, chamfer_loss, beam_gap_loss, normal_loss
23
+ from .io_utils import (
24
+ load_pointcloud,
25
+ build_initial_mesh,
26
+ remesh,
27
+ save_mesh,
28
+ estimate_normals,
29
+ )
30
+
31
+ logger = logging.getLogger("point2mesh")
32
+
33
+
34
+ # ──────────────────────────────────────────────────────────────────────
35
+ # Configuration
36
+ # ──────────────────────────────────────────────────────────────────────
37
+ @dataclass
38
+ class Point2MeshConfig:
39
+ # Coarse-to-fine levels
40
+ n_levels: int = 4
41
+ iters_per_level: int = 1000
42
+
43
+ # Mesh resolution schedule
44
+ init_faces: int = 2000
45
+ face_growth: float = 1.5
46
+ max_faces: int = 20000
47
+
48
+ # Sampling
49
+ samples_start: int = 15000
50
+ samples_end: int = 50000
51
+
52
+ # Loss weights
53
+ lambda_beam: float = 1.0
54
+ lambda_normal: float = 0.1
55
+ beam_epsilon: float = 0.5
56
+
57
+ # Network
58
+ in_channels: int = 6
59
+ enc_channels: list = field(default_factory=lambda: [64, 128, 256, 256])
60
+ lr: float = 2e-4
61
+
62
+ # PartMesh (set > 0 to enable spatial partitioning)
63
+ part_threshold: int = 10000 # enable parts when #faces > this
64
+ n_parts: int = 2
65
+
66
+ # Misc
67
+ device: str = "cuda" if torch.cuda.is_available() else "cpu"
68
+ log_every: int = 50
69
+ save_intermediates: bool = False
70
+ output_dir: str = "."
71
+
72
+
73
+ # ──────────────────────────────────────────────────────────────────────
74
+ # Optimisation loop
75
+ # ──────────────────────────────────────────────────────────────────────
76
+ def run_point2mesh(
77
+ input_path: str,
78
+ output_path: str,
79
+ cfg: Optional[Point2MeshConfig] = None,
80
+ progress_callback: Optional[Callable] = None,
81
+ ) -> str:
82
+ """
83
+ Full Point2Mesh pipeline.
84
+
85
+ Parameters
86
+ ----------
87
+ input_path : path to .ply / .pcd / .xyz / .obj point cloud
88
+ output_path : path to write final mesh (.obj / .ply / .stl)
89
+ cfg : configuration dataclass (defaults are sensible)
90
+ progress_callback : optional fn(level, iteration, loss_val)
91
+
92
+ Returns
93
+ -------
94
+ output_path : the path where the mesh was saved
95
+ """
96
+ if cfg is None:
97
+ cfg = Point2MeshConfig()
98
+
99
+ device = torch.device(cfg.device)
100
+
101
+ # ── 1. Load point cloud ──────────────────────────────────────────
102
+ logger.info(f"Loading point cloud from {input_path}")
103
+ points_np, normals_np = load_pointcloud(input_path)
104
+ logger.info(f" {len(points_np)} points loaded")
105
+
106
+ # Centre and scale to unit sphere
107
+ centroid = points_np.mean(axis=0)
108
+ points_np -= centroid
109
+ scale = np.abs(points_np).max()
110
+ if scale > 0:
111
+ points_np /= scale
112
+
113
+ X = torch.tensor(points_np, dtype=torch.float32, device=device)
114
+
115
+ # Normals
116
+ if normals_np is None:
117
+ logger.info(" Estimating normals …")
118
+ normals_np = estimate_normals(points_np)
119
+ else:
120
+ normals_np = normals_np.copy()
121
+ # Renormalize
122
+ norms = np.linalg.norm(normals_np, axis=1, keepdims=True)
123
+ normals_np /= np.clip(norms, 1e-8, None)
124
+
125
+ X_normals = torch.tensor(normals_np, dtype=torch.float32, device=device)
126
+
127
+ # ── 2. Build initial mesh (convex hull) ──────────────────────────
128
+ logger.info("Building initial mesh (convex hull) …")
129
+ init_verts, init_faces = build_initial_mesh(points_np, target_faces=cfg.init_faces)
130
+ logger.info(f" Initial mesh: {len(init_verts)} verts, {len(init_faces)} faces")
131
+
132
+ mesh = Mesh(init_verts, init_faces, device=str(device))
133
+
134
+ # ── 3. Coarse-to-fine optimisation ───────────────────────────────
135
+ net = Point2MeshNet(
136
+ in_ch=cfg.in_channels,
137
+ enc_channels=cfg.enc_channels,
138
+ ).to(device)
139
+
140
+ for level in range(cfg.n_levels):
141
+ logger.info(f"\n{'='*60}")
142
+ logger.info(f"Level {level} | {mesh.n_faces} faces, {mesh.n_edges} edges")
143
+ logger.info(f"{'='*60}")
144
+
145
+ # Re-initialise network weights and random input
146
+ net.reset_weights()
147
+ C_l = torch.rand(1, cfg.in_channels, mesh.n_edges, device=device)
148
+
149
+ optimiser = torch.optim.Adam(net.parameters(), lr=cfg.lr)
150
+
151
+ # Optionally use PartMesh for large meshes
152
+ use_parts = mesh.n_faces > cfg.part_threshold
153
+ pmesh = PartMesh(mesh, cfg.n_parts) if use_parts else None
154
+ if use_parts:
155
+ logger.info(f" Using PartMesh with {len(pmesh.parts)} parts")
156
+
157
+ t0 = time.time()
158
+ for it in range(cfg.iters_per_level):
159
+ optimiser.zero_grad()
160
+
161
+ # Linear ramp of sample count
162
+ frac = it / max(cfg.iters_per_level - 1, 1)
163
+ n_samples = int(cfg.samples_start + (cfg.samples_end - cfg.samples_start) * frac)
164
+
165
+ if use_parts:
166
+ loss_total = _forward_partmesh(
167
+ pmesh, net, C_l, X, X_normals, n_samples, cfg, mesh, device
168
+ )
169
+ else:
170
+ loss_total = _forward_single(
171
+ net, C_l, mesh, X, X_normals, n_samples, cfg
172
+ )
173
+
174
+ loss_total.backward()
175
+ optimiser.step()
176
+
177
+ if progress_callback:
178
+ progress_callback(level, it, loss_total.item())
179
+
180
+ if it % cfg.log_every == 0 or it == cfg.iters_per_level - 1:
181
+ elapsed = time.time() - t0
182
+ logger.info(
183
+ f" [{level}][{it:4d}/{cfg.iters_per_level}] "
184
+ f"loss={loss_total.item():.6f} "
185
+ f"({elapsed:.1f}s)"
186
+ )
187
+
188
+ # ── Apply final displacements ────────────────────────────────
189
+ with torch.no_grad():
190
+ delta_edges = net(C_l, mesh)
191
+ delta_v = edge_to_vertex_displacement(delta_edges, mesh)
192
+ new_verts = mesh.vs + delta_v
193
+
194
+ # Save intermediate
195
+ if cfg.save_intermediates:
196
+ import os
197
+ intermediate_path = os.path.join(
198
+ cfg.output_dir, f"level_{level}.obj"
199
+ )
200
+ save_mesh(
201
+ intermediate_path,
202
+ (new_verts.cpu().numpy() * scale) + centroid,
203
+ mesh.faces.cpu().numpy(),
204
+ )
205
+ logger.info(f" Saved intermediate β†’ {intermediate_path}")
206
+
207
+ # ── Remesh for next level ────────────────────────────────────
208
+ if level < cfg.n_levels - 1:
209
+ target = min(
210
+ int(mesh.n_faces * cfg.face_growth), cfg.max_faces
211
+ )
212
+ logger.info(f" Remeshing {mesh.n_faces} β†’ {target} faces …")
213
+ new_v_np = new_verts.cpu().numpy()
214
+ new_f_np = mesh.faces.cpu().numpy()
215
+ rem_v, rem_f = remesh(new_v_np, new_f_np, target)
216
+ mesh = Mesh(rem_v, rem_f, device=str(device))
217
+ else:
218
+ # Final level β€” just update vertex positions
219
+ mesh.vs = new_verts
220
+
221
+ # ── 4. Write output ──────────────────────────────────────────────
222
+ final_verts = mesh.vs.detach().cpu().numpy() * scale + centroid
223
+ final_faces = mesh.faces.cpu().numpy()
224
+
225
+ save_mesh(output_path, final_verts, final_faces)
226
+ logger.info(f"\nDone! Mesh saved to {output_path}")
227
+ logger.info(f" {len(final_verts)} vertices, {len(final_faces)} faces")
228
+
229
+ return output_path
230
+
231
+
232
+ # ──────────────────────────────────────────────────────────────────────
233
+ # Forward pass helpers
234
+ # ──────────────────────────────────────────────────────────────────────
235
+ def _forward_single(net, C_l, mesh, X, X_normals, n_samples, cfg):
236
+ """Standard forward on the full mesh."""
237
+ delta_edges = net(C_l, mesh)
238
+ delta_v = edge_to_vertex_displacement(delta_edges, mesh)
239
+ V_new = mesh.vs + delta_v
240
+
241
+ Y, face_ids = sample_surface(V_new, mesh.faces, n_samples)
242
+
243
+ # Face normals at sampled points
244
+ fnormals = mesh.face_normals(V_new)
245
+ Y_normals = fnormals[face_ids]
246
+
247
+ loss = chamfer_loss(X, Y)
248
+
249
+ if cfg.lambda_beam > 0:
250
+ loss = loss + cfg.lambda_beam * beam_gap_loss(
251
+ Y, Y_normals, X, epsilon=cfg.beam_epsilon
252
+ )
253
+
254
+ if cfg.lambda_normal > 0 and X_normals is not None:
255
+ loss = loss + cfg.lambda_normal * normal_loss(Y, Y_normals, X, X_normals)
256
+
257
+ return loss
258
+
259
+
260
+ def _forward_partmesh(pmesh, net, C_l_full, X, X_normals, n_samples, cfg, full_mesh, device):
261
+ """
262
+ Forward pass with PartMesh splitting.
263
+ Process each spatial part independently, accumulate gradients.
264
+ """
265
+ n_parts = len(pmesh.parts)
266
+ samples_per_part = max(n_samples // n_parts, 1000)
267
+
268
+ total_loss = torch.tensor(0.0, device=device, requires_grad=True)
269
+
270
+ for part_idx, part_mesh in enumerate(pmesh.parts):
271
+ # Create random input for this part's edge count
272
+ C_part = torch.rand(1, cfg.in_channels, part_mesh.n_edges, device=device)
273
+
274
+ delta_edges = net(C_part, part_mesh)
275
+ delta_v = edge_to_vertex_displacement(delta_edges, part_mesh)
276
+ V_new = part_mesh.vs + delta_v
277
+
278
+ Y, face_ids = sample_surface(V_new, part_mesh.faces, samples_per_part)
279
+ fnormals = part_mesh.face_normals(V_new)
280
+ Y_normals = fnormals[face_ids]
281
+
282
+ loss = chamfer_loss(X, Y)
283
+ if cfg.lambda_beam > 0:
284
+ loss = loss + cfg.lambda_beam * beam_gap_loss(
285
+ Y, Y_normals, X, epsilon=cfg.beam_epsilon
286
+ )
287
+ if cfg.lambda_normal > 0 and X_normals is not None:
288
+ loss = loss + cfg.lambda_normal * normal_loss(Y, Y_normals, X, X_normals)
289
+
290
+ total_loss = total_loss + loss / n_parts
291
+
292
+ return total_loss