File size: 6,300 Bytes
8e3ab85
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
"""
Loss functions for SDF and VG stages.
All pure PyTorch, no compiled extensions needed.
"""
import torch
import torch.nn.functional as F
import numpy as np


def gaussian_kernel(points, queries):
    """
    Normalised Gaussian weights of neighbour points around each query.

    The bandwidth for each query is the mean distance to its neighbours
    (plus 1e-8), and the weights along the neighbour axis sum to ~1.

    points:  (n, k, 3) neighbours, or (n, 3) which is treated as k == 1
    queries: (n, 3)
    Returns: (n, k) weights; a 2-D ``points`` input yields shape (n, 1).
    """
    neighbours = points[:, None, :] if points.dim() == 2 else points
    offsets = neighbours - queries[:, None, :]
    dist = torch.linalg.norm(offsets, ord=2, dim=-1)                 # (n, k)
    bandwidth = dist.mean(dim=-1, keepdim=True) + 1e-8                # (n, 1)
    unnormalised = torch.exp(-dist.pow(2) / bandwidth.pow(2))
    total = unnormalised.sum(dim=-1, keepdim=True) + 1e-8
    return unnormalised / total


def pull_knn_loss(points, samples_moved, samples):
    """
    KNN-weighted pull loss.

    For every moved sample, take the distance to each of its neighbour points.
    Weight those distances with ``gaussian_kernel`` (computed around the
    un-moved ``samples`` and detached), sum over neighbours, then average.

    points: (n, k, 3) or (n, 3); samples_moved: (n, 3); samples: (n, 3)
    Returns: scalar tensor.
    """
    neighbours = points if points.dim() != 2 else points.unsqueeze(1)
    residual = torch.linalg.norm(neighbours - samples_moved[:, None, :], ord=2, dim=-1)
    # Weights are treated as constants so gradients only flow through the distances.
    weights = gaussian_kernel(neighbours, samples).detach()
    per_sample = torch.sum(residual * weights, dim=-1)
    return torch.mean(per_sample)


def grad_consis_knn_loss(gradient_points_norm, gradients_samples_norm, points, samples):
    """
    KNN-weighted gradient-direction consistency loss.

    Penalises ``1 - cos`` between each sample's gradient and the gradients at
    its neighbour points. The terms are weighted by ``gaussian_kernel``
    (detached), summed over neighbours and averaged over samples.

    gradient_points_norm:   (n, k, 3) or (n, 3) gradients at the neighbour points
    gradients_samples_norm: (n, 3) gradients at the samples
    points:                 (n, k, 3) or (n, 3) neighbour points
    samples:                (n, 3) query samples
    Returns: scalar tensor.
    """
    if points.dim() == 2:
        points = points.unsqueeze(1)  # (n, 1, 3)
    if gradient_points_norm.dim() == 2:
        # Keep the neighbour axis aligned with `points`. Without this, an
        # (n, 3) tensor broadcasts against (n, 1, 3) into an (n, n)
        # all-pairs comparison instead of a per-sample one.
        gradient_points_norm = gradient_points_norm.unsqueeze(1)  # (n, 1, 3)
    loss_consis = 1 - F.cosine_similarity(
        gradient_points_norm, gradients_samples_norm.unsqueeze(1), dim=-1)  # (n, k)
    g_weight = gaussian_kernel(points, samples).detach()  # (n, k)
    loss_consis = (loss_consis * g_weight).sum(dim=-1).mean()
    return loss_consis


def eikonal_loss(gradients_samples):
    """Mean squared deviation of the per-row gradient L2 norm from 1."""
    grad_len = torch.norm(gradients_samples, p=2, dim=-1)
    deviation = grad_len - 1
    return torch.mean(deviation * deviation)


def div_loss(samples, gradients_samples):
    """
    Divergence regulariser on the gradient field.

    Computes div = d g_x/dx + d g_y/dy + d g_z/dz of ``gradients_samples``
    with respect to ``samples`` via autograd (keeping the graph so the loss is
    differentiable). It then returns the mean of ``div**2`` clamped to [0.1, 50].

    samples: tensor the gradients were computed from; its ``requires_grad``
        flag is switched on here.
    gradients_samples: tensor whose columns 0..2 are differentiated.
    Returns: scalar tensor.
    """
    samples.requires_grad_(True)

    def _partials(component):
        (grads,) = torch.autograd.grad(
            outputs=component,
            inputs=samples,
            grad_outputs=torch.ones_like(component),
            create_graph=True,
            retain_graph=True,
        )
        return grads

    divergence = None
    for axis in range(3):
        term = _partials(gradients_samples[:, axis])[:, axis]
        divergence = term if divergence is None else divergence + term
    return torch.square(divergence).clamp(0.1, 50).mean()


def sdf_loss(sample_sdf):
    """Mean absolute value of ``sample_sdf``; zero when every value is 0."""
    magnitudes = sample_sdf.abs()
    return magnitudes.mean()


def cal_curvature_with_normal(pts, normals, knn=16):
    """
    Estimate a per-point curvature score from normal variation over k-NN.

    For each point:
      1. Query its ``knn`` nearest neighbours, dropping the first returned hit
         (assumed to be the point itself).
      2. Average ``1 - cos(normal_i, normal_j)`` over them with
         ``gaussian_kernel`` weights.
      3. Squash the averages with a sigmoid centred on their mean.
      4. Min-max rescale the result.

    pts: (n, 3)    normals: (n, 3)
    knn: neighbour count; the query never asks for more than n points.
    Returns: (n, 1) curvature in (0, 1].
    """
    from scipy.spatial import KDTree
    pts_np = pts.detach().cpu().numpy()
    n_pts = len(pts_np)
    tree = KDTree(pts_np)
    _, idx = tree.query(pts_np, k=min(knn + 1, n_pts))
    # scipy squeezes the neighbour axis when k == 1 (n == 1 or knn == 0);
    # restore it so the self-hit can always be sliced off along axis 1.
    idx = np.asarray(idx).reshape(n_pts, -1)
    idx = torch.from_numpy(idx[:, 1:]).long().to(pts.device)  # (n, k)
    neigh_pts = pts[idx]            # (n, k, 3)
    neigh_normals = normals[idx]    # (n, k, 3)
    neigh_curvature = 1.0 - F.cosine_similarity(normals.unsqueeze(1), neigh_normals, dim=-1)  # (n, k)
    g_weight = gaussian_kernel(neigh_pts, pts).detach()
    pts_curvature_ave = (neigh_curvature * g_weight).sum(dim=-1, keepdim=True)  # (n, 1)

    pts_curvature_ave = torch.sigmoid(pts_curvature_ave - torch.mean(pts_curvature_ave))
    # The +1e-6 terms avoid 0/0 when all scores are equal (result is then 1).
    pts_curvature_ave = (pts_curvature_ave - torch.min(pts_curvature_ave) + 1e-6) / (
        torch.max(pts_curvature_ave) - torch.min(pts_curvature_ave) + 1e-6)
    return pts_curvature_ave


def cal_nc_loss(surface_normals, sample_normals, sur_neigh_idx):
    """
    Normal consistency loss.

    Each surface normal is paired with ``sample_normals[sur_neigh_idx]``.
    The loss is 1 minus the mean cosine similarity of those pairs.
    """
    matched = sample_normals[sur_neigh_idx]
    alignment = F.cosine_similarity(surface_normals, matched, dim=-1)
    return 1.0 - torch.mean(alignment)


def cal_chamfer_loss(sur_pts, sample_pts, curvature_surface, loss_w, nearest_clamp):
    """
    Bidirectional squared Chamfer distance with curvature weighting + repulsion.

    sur_pts:           (m, 3) surface points
    sample_pts:        (s, 3) sample points; the repulsion term needs s >= 2
    curvature_surface: per-surface-point weights. They multiply an (m, 1)
                       distance column, so they should be (m, 1); an (m,)
                       tensor would broadcast to (m, m).
    loss_w:            [w_curv, w_plain, w_reverse, w_repulsion]
    nearest_clamp:     squared self-NN distances are capped at
                       ``nearest_clamp * mean`` before averaging the repulsion.

    Returns:
        (chamfer_loss, sur_neigh_idx, sample_neigh_idx), where
        sur_neigh_idx (m,) indexes each surface point's nearest sample and
        sample_neigh_idx (s,) indexes each sample's nearest surface point.
    """
    from scipy.spatial import KDTree

    sample_np = sample_pts.detach().cpu().numpy()
    sur_np = sur_pts.detach().cpu().numpy()
    # One tree over the samples serves both the surface->sample query and the
    # sample self-neighbour (repulsion) query.
    sample_tree = KDTree(sample_np)

    # ---- nearest from surface to samples ----
    _, sur_neigh_idx = sample_tree.query(sur_np, k=1)
    sur_neigh_idx = torch.from_numpy(sur_neigh_idx).long().to(sur_pts.device)
    sur_neigh_pts = sample_pts[sur_neigh_idx]
    dist_1 = torch.linalg.norm((sur_pts - sur_neigh_pts), ord=2, dim=-1).unsqueeze(-1) ** 2  # (m, 1)
    loss_part_1 = loss_w[0] * (dist_1 * curvature_surface).mean() + loss_w[1] * dist_1.mean()

    # ---- nearest from samples to surface ----
    sur_tree = KDTree(sur_np)
    _, sample_neigh_idx = sur_tree.query(sample_np, k=1)
    sample_neigh_idx = torch.from_numpy(sample_neigh_idx).long().to(sample_pts.device)
    sample_neigh_pts = sur_pts[sample_neigh_idx]
    dist_2 = torch.linalg.norm((sample_pts - sample_neigh_pts), ord=2, dim=-1) ** 2
    loss_part_2 = loss_w[2] * dist_2.mean()

    # ---- repulsion (self-nearest-neighbor distance) ----
    # k=2 because the first hit is assumed to be the sample itself.
    _, self_idx = sample_tree.query(sample_np, k=2)
    self_idx = torch.from_numpy(self_idx[:, 1]).long().to(sample_pts.device)
    relative_dist = sample_pts[self_idx] - sample_pts
    norm_dist = torch.linalg.norm(relative_dist, ord=2, dim=-1) ** 2
    norm_dist_ = torch.clamp(norm_dist, max=nearest_clamp * norm_dist.mean())
    # Negative sign: larger spacing between samples lowers the loss.
    loss_part_3 = -norm_dist_.mean() * loss_w[3]

    chamfer_loss = loss_part_1 + loss_part_2 + loss_part_3
    return chamfer_loss, sur_neigh_idx, sample_neigh_idx


def cal_vg_loss(sur_pts, sur_normals, curvature_ave, sample_pts, sample_grad, loss_w, nearest_clamp):
    """
    Total VG stage loss.

    The loss is ``cal_chamfer_loss`` (weights ``loss_w[:4]``) plus
    ``loss_w[4]`` times ``cal_nc_loss``. The normal term compares surface
    normals with the unit-normalised, detached ``sample_grad`` of each surface
    point's nearest sample.
    """
    unit_sample_normals = F.normalize(sample_grad.detach(), dim=-1)
    chamfer_term, sur_to_sample_idx, _unused = cal_chamfer_loss(
        sur_pts, sample_pts, curvature_ave, loss_w[:4], nearest_clamp)
    normal_term = cal_nc_loss(sur_normals, unit_sample_normals, sur_to_sample_idx)
    return chamfer_term + loss_w[4] * normal_term