lightweightmr / losses.py
bdck's picture
Upload losses.py
8e3ab85 verified
"""
Loss functions for SDF and VG stages.
All pure PyTorch, no compiled extensions needed.
"""
import torch
import torch.nn.functional as F
import numpy as np
def gaussian_kernel(points, queries):
    """
    Normalised Gaussian weights of neighbours around their query point.

    points:  (n, k, 3) neighbours per query, or (n, 3) for one neighbour each.
    queries: (n, 3)
    Returns: (n, k) weights that sum to ~1 along the last axis. A 2-D
             ``points`` input is treated as k = 1 and yields shape (n, 1).
    """
    neighbours = points.unsqueeze(1) if points.dim() == 2 else points
    offsets = neighbours - queries.unsqueeze(1)
    dist = torch.linalg.norm(offsets, ord=2, dim=-1)  # (n, k)

    # The bandwidth adapts per query: it is the mean neighbour distance,
    # with a small epsilon so coincident points do not divide by zero.
    bandwidth = dist.mean(dim=-1, keepdim=True) + 1e-8
    unnormalised = torch.exp(-dist ** 2 / bandwidth ** 2)
    total = unnormalised.sum(dim=-1, keepdim=True) + 1e-8
    return unnormalised / total
def pull_knn_loss(points, samples_moved, samples):
    """
    KNN-weighted pull loss.

    points:        (n, k, 3) neighbours, or (n, 3) for a single neighbour.
    samples_moved: (n, 3) samples after being moved.
    samples:       (n, 3) original samples; only used to compute the weights.
    Returns: scalar, the batch mean of the weighted neighbour distances.
    """
    neighbours = points.unsqueeze(1) if points.dim() == 2 else points
    residual = neighbours - samples_moved.unsqueeze(1)
    pull_dist = torch.linalg.norm(residual, ord=2, dim=-1)  # (n, k)

    # Weights come from the un-moved samples and are detached, so gradients
    # flow only through the distances to the moved samples.
    weights = gaussian_kernel(neighbours, samples).detach()
    per_sample = (pull_dist * weights).sum(dim=-1)
    return per_sample.mean()
def grad_consis_knn_loss(gradient_points_norm, gradients_samples_norm, points, samples):
    """
    KNN-weighted gradient-direction consistency loss.

    Penalises ``1 - cos`` between the gradient direction at each neighbour
    point and the gradient direction at the corresponding sample, weighted by
    the Gaussian kernel of the neighbour around the sample.

    gradient_points_norm:   (n, k, 3), or (n, 3) for a single neighbour.
    gradients_samples_norm: (n, 3)
    points:                 (n, k, 3), or (n, 3) for a single neighbour.
    samples:                (n, 3)
    Returns: scalar, the batch mean of the weighted cosine dissimilarity.
    """
    if points.dim() == 2:
        points = points.unsqueeze(1)
    # Promote a 2-D gradient tensor the same way as ``points``. Left as
    # (n, 3) it would broadcast against the (n, 1, 3) sample gradients into
    # an (n, n) all-pairs comparison instead of a per-row one.
    if gradient_points_norm.dim() == 2:
        gradient_points_norm = gradient_points_norm.unsqueeze(1)

    cos_sim = F.cosine_similarity(
        gradient_points_norm, gradients_samples_norm.unsqueeze(1), dim=-1)
    loss_consis = 1 - cos_sim  # (n, k)

    # Weights are detached so they act as constants during back-propagation.
    g_weight = gaussian_kernel(points, samples).detach()
    return (loss_consis * g_weight).sum(dim=-1).mean()
def eikonal_loss(gradients_samples):
    """
    Eikonal regulariser: mean squared deviation of the gradient norm from 1.

    gradients_samples: (..., d) gradient vectors; the L2 norm is taken over
                       the last axis.
    Returns: scalar loss.
    """
    grad_norm = gradients_samples.norm(p=2, dim=-1)
    unit_residual = grad_norm - 1
    return torch.square(unit_residual).mean()
def div_loss(samples, gradients_samples):
    """
    Divergence of gradient field.

    samples:           (n, 3) points the gradient field was evaluated at.
                       Flagged ``requires_grad`` in place by this function.
    gradients_samples: (n, 3) gradient field at ``samples``; must be connected
                       to ``samples`` in the autograd graph.
    Returns: scalar, the mean of the squared divergence clamped to [0.1, 50].
    """
    samples.requires_grad_(True)

    def _partials(component):
        # Differentiate one gradient component w.r.t. the sample coordinates,
        # keeping the graph so the loss itself stays differentiable.
        seed = torch.ones_like(component, requires_grad=False, device=component.device)
        (component_grad,) = torch.autograd.grad(
            outputs=component,
            inputs=samples,
            grad_outputs=seed,
            create_graph=True,
            retain_graph=True,
        )
        return component_grad

    # Divergence is the sum of d(g_x)/dx, d(g_y)/dy and d(g_z)/dz.
    divergence = None
    for axis in range(3):
        diagonal = _partials(gradients_samples[:, axis])[:, axis]
        divergence = diagonal if divergence is None else divergence + diagonal

    squared = divergence ** 2
    return torch.clamp(squared, 0.1, 50).mean()
def sdf_loss(sample_sdf):
    """Mean absolute value of the given SDF predictions, as a scalar."""
    magnitude = sample_sdf.abs()
    return magnitude.mean()
def cal_curvature_with_normal(pts, normals, knn=16):
    """
    Estimate curvature from normal variation over k-NN.

    pts:     (n, 3) point positions.
    normals: (n, 3) normals at ``pts``.
    knn:     number of neighbours per point (capped at n - 1).
    Returns: (n, 1) curvature rescaled to [0, 1].
    """
    from scipy.spatial import KDTree

    # Only the positions are needed on the host for the neighbour search;
    # the normals stay on their original device.
    pts_np = pts.detach().cpu().numpy()
    tree = KDTree(pts_np)
    _, idx = tree.query(pts_np, k=min(knn + 1, len(pts_np)))
    # Column 0 is each point itself (distance 0), so drop it.
    idx = torch.from_numpy(idx[:, 1:]).long().to(pts.device)  # (n, k)

    neigh_pts = pts[idx]  # (n, k, 3)
    neigh_normals = normals[idx]  # (n, k, 3)
    # Per-neighbour dissimilarity between the point normal and neighbour normal.
    neigh_curvature = 1.0 - F.cosine_similarity(
        normals.unsqueeze(1), neigh_normals, dim=-1)  # (n, k)

    g_weight = gaussian_kernel(neigh_pts, pts).detach()
    pts_curvature_ave = (neigh_curvature * g_weight).sum(dim=-1, keepdim=True)  # (n, 1)

    # Centre on the mean and squash with a sigmoid, then min-max rescale.
    # The epsilon keeps the division defined when all values are equal.
    pts_curvature_ave = torch.sigmoid(pts_curvature_ave - torch.mean(pts_curvature_ave))
    lowest = torch.min(pts_curvature_ave)
    highest = torch.max(pts_curvature_ave)
    pts_curvature_ave = (pts_curvature_ave - lowest + 1e-6) / (highest - lowest + 1e-6)
    return pts_curvature_ave
def cal_nc_loss(surface_normals, sample_normals, sur_neigh_idx):
    """
    Normal consistency loss.

    surface_normals: normals at the surface points.
    sample_normals:  normals at the samples, indexed by ``sur_neigh_idx``.
    sur_neigh_idx:   for each surface point, the index of its matched sample.
    Returns: scalar, one minus the mean cosine similarity of matched normals.
    """
    matched_normals = sample_normals[sur_neigh_idx]
    cos_sim = F.cosine_similarity(surface_normals, matched_normals, dim=-1)
    return 1.0 - cos_sim.mean()
def cal_chamfer_loss(sur_pts, sample_pts, curvature_surface, loss_w, nearest_clamp):
    """
    Bidirectional Chamfer with curvature weighting + repulsion.

    sur_pts:           (m, 3) surface points.
    sample_pts:        (n, 3) sample points.
    curvature_surface: per-surface-point weight, multiplied with the (m, 1)
                       squared surface-to-sample distances.
    loss_w:            [w_curv, w_plain, w_reverse, w_repulsion]
    nearest_clamp:     the squared self-neighbour distances used for repulsion
                       are capped at ``nearest_clamp`` times their mean.
    Returns: (chamfer_loss, sur_neigh_idx, sample_neigh_idx), where the index
             tensors give each surface point's nearest sample and each
             sample's nearest surface point.
    """
    from scipy.spatial import KDTree

    sample_np = sample_pts.detach().cpu().numpy()
    sur_np = sur_pts.detach().cpu().numpy()

    # One tree over the samples serves both the surface->sample query and the
    # sample self-neighbour query used for repulsion.
    sample_tree = KDTree(sample_np)

    # ---- nearest from surface to samples ----
    _, sur_neigh_idx = sample_tree.query(sur_np, k=1)
    sur_neigh_idx = torch.from_numpy(sur_neigh_idx).long().to(sur_pts.device)
    sur_neigh_pts = sample_pts[sur_neigh_idx]
    dist_1 = torch.linalg.norm((sur_pts - sur_neigh_pts), ord=2, dim=-1).unsqueeze(-1) ** 2
    loss_part_1 = loss_w[0] * (dist_1 * curvature_surface).mean() + loss_w[1] * dist_1.mean()

    # ---- nearest from samples to surface ----
    sur_tree = KDTree(sur_np)
    _, sample_neigh_idx = sur_tree.query(sample_np, k=1)
    sample_neigh_idx = torch.from_numpy(sample_neigh_idx).long().to(sample_pts.device)
    sample_neigh_pts = sur_pts[sample_neigh_idx]
    dist_2 = torch.linalg.norm((sample_pts - sample_neigh_pts), ord=2, dim=-1) ** 2
    loss_part_2 = loss_w[2] * dist_2.mean()

    # ---- repulsion (self-nearest-neighbor distance) ----
    if sample_np.shape[0] < 2:
        # A lone sample has no neighbour to be pushed away from. The k=2
        # query would return an out-of-range sentinel index, so the term is
        # defined as zero instead.
        loss_part_3 = dist_2.new_zeros(())
    else:
        _, self_idx = sample_tree.query(sample_np, k=2)
        # Column 0 is the sample itself; column 1 is its nearest other sample.
        self_idx = torch.from_numpy(self_idx[:, 1]).long().to(sample_pts.device)
        relative_dist = sample_pts[self_idx] - sample_pts
        norm_dist = torch.linalg.norm(relative_dist, ord=2, dim=-1) ** 2
        # Cap large gaps so isolated samples do not dominate the reward.
        norm_dist_ = torch.clamp(norm_dist, max=nearest_clamp * norm_dist.mean())
        # Negative sign: larger nearest-neighbour gaps lower the loss.
        loss_part_3 = -norm_dist_.mean() * loss_w[3]

    chamfer_loss = loss_part_1 + loss_part_2 + loss_part_3
    return chamfer_loss, sur_neigh_idx, sample_neigh_idx
def cal_vg_loss(sur_pts, sur_normals, curvature_ave, sample_pts, sample_grad, loss_w, nearest_clamp):
    """
    Total VG stage loss.

    Adds the Chamfer loss (weights ``loss_w[:4]``) to the normal-consistency
    loss scaled by ``loss_w[4]``. Surface normals are compared against the
    normalised gradient at each surface point's nearest sample.
    """
    chamfer_weights = loss_w[:4]
    normal_weight = loss_w[4]

    chamfer_term, sur_to_sample_idx, _ = cal_chamfer_loss(
        sur_pts, sample_pts, curvature_ave, chamfer_weights, nearest_clamp)

    # The sample gradients are detached before normalisation, so the normal
    # term does not back-propagate through ``sample_grad``.
    unit_sample_normals = F.normalize(sample_grad.detach(), dim=-1)
    normal_term = cal_nc_loss(sur_normals, unit_sample_normals, sur_to_sample_idx)

    return chamfer_term + normal_weight * normal_term