Upload folder using huggingface_hub
- checkpoint.pt +3 -0
- s23dr_2026_example/__init__.py +0 -0
- s23dr_2026_example/attention.py +226 -0
- s23dr_2026_example/bad_samples.txt +156 -0
- s23dr_2026_example/cache_scenes.py +373 -0
- s23dr_2026_example/color_mappings.py +209 -0
- s23dr_2026_example/data.py +237 -0
- s23dr_2026_example/losses.py +311 -0
- s23dr_2026_example/make_sampled_cache.py +260 -0
- s23dr_2026_example/model.py +696 -0
- s23dr_2026_example/point_fusion.py +554 -0
- s23dr_2026_example/postprocess_v2.py +39 -0
- s23dr_2026_example/segment_postprocess.py +60 -0
- s23dr_2026_example/sinkhorn.py +181 -0
- s23dr_2026_example/soft_hss_loss.py +507 -0
- s23dr_2026_example/tokenizer.py +88 -0
- s23dr_2026_example/varifold.py +196 -0
- s23dr_2026_example/wire_varifold_kernels.py +461 -0
- script.py +322 -0
checkpoint.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:cc38a61ff512948b1dc92a30129d6efdd093f507948fc5b538050c4a38bfbf6c
size 106460054
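The three lines above are a Git LFS pointer: the ~106 MB checkpoint itself lives in LFS storage and is fetched by `git lfs pull`, not stored in the repo. A minimal loading sketch, assuming only that the file deserializes to a plain Python object (this commit does not show its contents):

import torch

# map_location avoids requiring the GPU device the checkpoint was saved from.
ckpt = torch.load("checkpoint.pt", map_location="cpu")
if isinstance(ckpt, dict):
    print(sorted(ckpt.keys()))  # inspect the layout before wiring into a model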
s23dr_2026_example/__init__.py
ADDED
File without changes
s23dr_2026_example/attention.py
ADDED
@@ -0,0 +1,226 @@
# custom_transformer.py
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

# =============================================================================
# Core Efficient Multihead Attention using Scaled Dot Product Attention (SDPA)
# =============================================================================

class MultiHeadSDPA(nn.Module):
    """
    Multi-head cross-attention using torch.nn.functional.scaled_dot_product_attention
    without causal masking. Suitable for set inputs and cross-attention.

    If qk_norm=True, L2-normalizes Q and K per-head before the dot product,
    then scales by a learned per-head temperature (log_scale). This caps logit
    magnitude to [-1, +1] * exp(log_scale), preventing attention entropy
    collapse at large head_dim.
    """
    def __init__(self, d_model: int, num_heads: int, kv_heads: int | None = None,
                 qk_norm: bool = False, qk_norm_type: str = "l2"):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.d_model = d_model
        self.num_heads = num_heads
        self.kv_heads = kv_heads or num_heads
        assert self.num_heads % self.kv_heads == 0, "kv_heads must divide num_heads"

        self.head_dim = d_model // num_heads
        self.qk_norm = qk_norm
        self.qk_norm_type = qk_norm_type

        # Input projection layers
        self.q_proj = nn.Linear(d_model, d_model, bias=False)
        self.k_proj = nn.Linear(d_model, self.kv_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(d_model, self.kv_heads * self.head_dim, bias=False)

        # Output projection (zero-init so each block starts as the identity)
        self.out_proj = nn.Linear(d_model, d_model, bias=False)
        nn.init.zeros_(self.out_proj.weight)

        if qk_norm and qk_norm_type != "rms":
            # L2 + learned temperature (nGPT/ViT-22B style): L2 projects to the
            # unit sphere, so a learned scale must compensate. The "rms" variant
            # (Qwen3/Gemma3 style) needs no extra parameters: SDPA's 1/sqrt(d)
            # scaling is sufficient because RMSNorm preserves the expected
            # logit variance.
            self.log_scale = nn.Parameter(
                torch.full((num_heads,), math.log(math.sqrt(self.head_dim))))

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        key_padding_mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        # Project
        q = self.q_proj(query)
        k = self.k_proj(key)
        v = self.v_proj(key)

        B, Tq, _ = q.shape
        _, Tk, _ = k.shape

        q = q.view(B, Tq, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(B, Tk, self.kv_heads, self.head_dim).transpose(1, 2)
        v = v.view(B, Tk, self.kv_heads, self.head_dim).transpose(1, 2)

        if self.kv_heads != self.num_heads:
            repeat = self.num_heads // self.kv_heads
            k = k.repeat_interleave(repeat, dim=1)
            v = v.repeat_interleave(repeat, dim=1)

        # key_padding_mask: True marks padded keys; SDPA wants True = attend.
        attn_mask = None
        if key_padding_mask is not None:
            attn_mask = ~key_padding_mask[:, None, None, :].to(dtype=torch.bool)

        if self.qk_norm and self.qk_norm_type == "rms":
            # RMSNorm (Qwen3/Gemma3 style): no learned temperature needed.
            # After RMSNorm, logit variance matches standard SDPA naturally.
            q = q * torch.rsqrt(q.square().mean(dim=-1, keepdim=True) + 1e-6)
            k = k * torch.rsqrt(k.square().mean(dim=-1, keepdim=True) + 1e-6)
            attn_out = F.scaled_dot_product_attention(
                q, k, v, attn_mask=attn_mask, dropout_p=0.0, is_causal=False,
            )
        elif self.qk_norm:
            # L2 + learned temperature (nGPT/ViT-22B style)
            q = F.normalize(q, dim=-1)
            k = F.normalize(k, dim=-1)
            scale = self.log_scale.exp().view(1, -1, 1, 1)
            q = q * scale
            attn_out = F.scaled_dot_product_attention(
                q, k, v, attn_mask=attn_mask, dropout_p=0.0, is_causal=False,
                scale=1.0,
            )
        else:
            attn_out = F.scaled_dot_product_attention(
                q, k, v, attn_mask=attn_mask, dropout_p=0.0, is_causal=False,
            )

        attn_out = attn_out.transpose(1, 2).reshape(B, Tq, self.d_model)
        return self.out_proj(attn_out)


# =============================================================================
# Transformer Feed-Forward Block
# =============================================================================

def _get_activation(name: str):
    """Look up activation function by name. Supports 'relu_sq' for ReLU^2."""
    if name == "relu_sq":
        return lambda x: F.relu(x).square()
    return getattr(F, name)


class FeedForward(nn.Module):
    """
    Position-wise MLP block: linear -> activation -> linear.
    Supports 'gelu', 'relu', 'relu_sq', etc.
    """
    def __init__(self, d_model: int, dim_ff: int, activation: str = "gelu"):
        super().__init__()
        self.linear1 = nn.Linear(d_model, dim_ff)
        self.linear2 = nn.Linear(dim_ff, d_model)
        nn.init.zeros_(self.linear2.weight)
        nn.init.zeros_(self.linear2.bias)
        self.activation = _get_activation(activation)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.linear1(x)
        return self.linear2(self.activation(x))


# =============================================================================
# Custom Transformer Block
# =============================================================================

class TransformerBlock(nn.Module):
    """
    Single transformer block combining:
      - multi-head SDPA (non-causal)
      - layernorm + residual
      - feed-forward MLP + residual
    """
    def __init__(
        self,
        d_model: int,
        num_heads: int,
        dim_ff: int,
        dropout: float = 0.0,
        activation: str = "gelu",
        kv_heads: int | None = None,
    ):
        super().__init__()
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

        self.attn = MultiHeadSDPA(d_model, num_heads, kv_heads=kv_heads)
        self.dropout1 = nn.Dropout(dropout)
        self.ffn = FeedForward(d_model, dim_ff, activation=activation)
        self.dropout2 = nn.Dropout(dropout)

    def forward(
        self,
        x: torch.Tensor,
        memory: torch.Tensor,
        memory_key_padding_mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        res = x
        x = self.norm1(x)
        x = self.attn(x, memory, key_padding_mask=memory_key_padding_mask)
        x = res + self.dropout1(x)

        res = x
        x = self.norm2(x)
        x = self.ffn(x)
        return res + self.dropout2(x)


class TransformerDecoderSets(nn.Module):
    """
    A stack of TransformerBlock layers for set-to-set
    modeling without causal masks.
    """
    def __init__(
        self,
        d_model: int,
        num_heads: int,
        dim_ff: int,
        num_layers: int,
        dropout: float = 0.0,
        activation: str = "gelu",
        kv_heads: int | None = None,
    ):
        super().__init__()
        self.layers = nn.ModuleList([
            TransformerBlock(
                d_model,
                num_heads,
                dim_ff,
                dropout=dropout,
                activation=activation,
                kv_heads=kv_heads,
            )
            for _ in range(num_layers)
        ])

    def forward(
        self,
        tgt: torch.Tensor,
        memory: torch.Tensor,
        memory_key_padding_mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        for layer in self.layers:
            tgt = layer(tgt, memory, memory_key_padding_mask=memory_key_padding_mask)
        return tgt
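A quick smoke test of the stack above, assuming only the shapes (the padding convention follows the `~key_padding_mask` inversion in MultiHeadSDPA: True marks padded memory tokens):

import torch
from s23dr_2026_example.attention import TransformerDecoderSets

dec = TransformerDecoderSets(d_model=64, num_heads=8, dim_ff=256, num_layers=2)
tgt = torch.randn(2, 4, 64)      # 4 query tokens per sample
memory = torch.randn(2, 16, 64)  # 16 memory tokens per sample
pad = torch.zeros(2, 16, dtype=torch.bool)
pad[1, 12:] = True               # last 4 memory tokens of sample 1 are padding
out = dec(tgt, memory, memory_key_padding_mask=pad)
print(out.shape)  # torch.Size([2, 4, 64])

Note the zero-initialized out_proj and linear2 weights: every block starts as the identity, so out equals tgt at initialization and the residual stream only picks up attention/MLP contributions as training proceeds.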
s23dr_2026_example/bad_samples.txt
ADDED
@@ -0,0 +1,156 @@
14b1872e960
1807ef90db4
180e6a67e87
1ad5c6bd31f
1c3f939ad93
1ede4c0d52f
214f17d9cc4
22256d88df9
24a92a8de6d
24b4e984bad
2565978cf53
2a71f1a2072
2d44c1fade6
2ebed43823a
33982551420
3b480496f82
412a2bdf7a4
44343bbabbb
4a0b3f04cbd
4a7fa170826
4b7dc027214
4e0dc2c9b18
5172a516c8b
529e8f15cd2
56fc6f6f163
575963ce814
578ec40a278
5a0c07c575a
5d521223c26
6148b5c9461
631eb6d7c03
655a14f8a75
66502d7ee6f
6da76fc6687
777eaaad0ca
7a4e2909d68
7c5c9baf483
80806dfd75e
81a4ead431d
833152dd554
85797868c0f
86460ad8181
86783a6bee4
95193322d7a
99a9d056200
9b1d4eeaab9
9ff759f2e4c
acbd243da16
b9b275710c0
beceaa9bb7c
c243d079286
c5c7337d2cb
cdf6f2d3b35
cfe370f1c87
d4a72aea80c
d655f066cd3
d79e8d9455c
d7d6c5be76e
dc30ae4b93b
de9495f7ca3
e1901819c72
e1d88c1a6b1
e5d3eb0a617
ec11d3cdcf6
ecb21fad0ad
ee55d8c6493
ee7e6d4dee1
008052054aa
03ecb7d3cf3
0555a655534
099cad230c6
0d061ae23f0
10741a421c0
110d5e407b9
128a7fb415a
13177736b26
1635d73bf7d
18a760de9ea
18d90d03e95
209627a5c1a
21e3cd4b7b8
22f5499200d
266eb64de68
269235f770b
2758490e558
2a203cf5d35
2a878ec47ab
2cb43eb2201
393298e282b
395abe6aac7
3d19c7a4ca3
44e2b719b1e
45039819fcc
4cb4ff01619
4e5eb5712fa
4e988765a6d
5077bf42714
55ed69b2622
5ae3b651a37
5ca1edeed4c
5daa76b1c7f
5fdd11dfae5
6078cf180c2
6682b309e9c
6c02d2038c0
71c595506c8
73c8f960c18
74ccc8fd057
7a34156a798
7ac7af9f59c
7f2ec0ea179
823b837b36c
82d7600f9a3
848161a2900
88cedf129eb
8dec106b6a6
8e335d08ca4
8ecf7c58193
8fa55008beb
90e09de2301
9197acc0b9d
954c25e876c
98517d5563d
99e717a0148
9a0c0635bd7
9ad436b7b3d
9be351cbf14
9e2a2e51798
a84a7ea9220
aa8cb84d3eb
b07977292da
b3e33456f0b
b7823de373e
bac379382d9
bd2d9bf67a3
c14584a84cd
c497170c970
cd8e767612b
d17917bb279
d42b9d432a9
d53d8857a85
d6808cf3d98
d6f509d1dd9
d7abd08e643
d83493bf974
d87293651ee
da9d4ac9e8e
daa1702791a
dcb12411c14
de9ab9cdd5b
df906c58a3c
e3870649eb5
ea90aed9b98
ecaa81b9711
efc1238665b
c5a65219daf
s23dr_2026_example/cache_scenes.py
ADDED
@@ -0,0 +1,373 @@
#!/usr/bin/env python3
"""Cache compact scenes from HoHo22k shards to training-ready .pt files.

Runs build_compact_scene + precomputes group_id, semantic class, and
normalization so training only needs fast sampling + GPU forward.

Usage:
    python cache_scenes.py --data-dir data/ --out-dir cache/train
    python cache_scenes.py --streaming --out-dir cache/train --limit 5000
    python cache_scenes.py --data-dir data/ --out-dir cache/train --workers 4

Cache format per file (.pt):
    xyz:         float32 [P, 3]   all points in world space
    source:      uint8   [P]      0=colmap, 1=depth
    group_id:    int8    [P]      priority tier 0-4, -1=excluded
    class_id:    uint8   [P]      one-hot class index (0-12), see SEMANTIC_CLASSES
    visible_src: uint8   [P]      for visualization (1=gestalt, 2=ade)
    visible_id:  int16   [P]      for visualization (class id within space)
    center:      float32 [3]      smart normalization center
    scale:       float32 scalar   smart normalization scale
    gt_vertices: float32 [V, 3]   ground truth wireframe vertices
    gt_edges:    int32   [E, 2]   ground truth wireframe edge indices
"""
from __future__ import annotations

import sys
from pathlib import Path as _Path

if __package__ is None or __package__ == "":
    # Allow running as a plain script: put the package parent on sys.path.
    _here = _Path(__file__).resolve().parent
    if str(_here.parent) not in sys.path:
        sys.path.insert(0, str(_here.parent))
    __package__ = _here.name

import argparse
import time
from concurrent.futures import ProcessPoolExecutor, as_completed
from pathlib import Path

import numpy as np
import torch

from .point_fusion import (
    FuserConfig, build_compact_scene,
    GEST_ID_TO_NAME, ADE_ID_TO_NAME, NUM_GEST,
)

# ---------------------------------------------------------------------------
# Semantic class encoding: 11 structural + 1 other_house + 1 non_house = 13
# ---------------------------------------------------------------------------

# Each structural gestalt class gets its own one-hot bit.
STRUCTURAL_CLASSES = (
    "apex", "eave_end_point", "flashing_end_point",  # point classes (tier 0)
    "rake", "ridge", "eave", "hip", "valley",        # roof edges (tier 1)
    "flashing", "step_flashing",
    "roof",                                          # roof face (tier 2)
)
# Index 11 = other house part (door, window, siding, etc.)
# Index 12 = non-house / ADE / unlabeled
NUM_SEMANTIC_CLASSES = len(STRUCTURAL_CLASSES) + 2  # 13

# Priority tiers (same as tokenizer.py)
_GEST_NAME_TO_ID = {n: i for i, n in enumerate(GEST_ID_TO_NAME)}
_POINT_IDS = {_GEST_NAME_TO_ID[n] for n in ("apex", "eave_end_point", "flashing_end_point") if n in _GEST_NAME_TO_ID}
_EDGE_IDS = {_GEST_NAME_TO_ID[n] for n in ("rake", "ridge", "eave", "hip", "valley", "flashing", "step_flashing") if n in _GEST_NAME_TO_ID}
_FACE_IDS = {_GEST_NAME_TO_ID[n] for n in ("roof",) if n in _GEST_NAME_TO_ID}
_HOUSE_IDS = {_GEST_NAME_TO_ID[n] for n in (
    "apex", "eave_end_point", "flashing_end_point",
    "rake", "ridge", "eave", "hip", "valley", "flashing", "step_flashing",
    "roof", "door", "garage", "window", "shutter", "fascia", "soffit",
    "horizontal_siding", "vertical_siding", "brick", "concrete",
    "other_wall", "trim", "post", "ground_line",
) if n in _GEST_NAME_TO_ID}

_ADE_NAME_TO_ID = {n.lower(): i for i, n in enumerate(ADE_ID_TO_NAME)}
_ADE_HOUSE_IDS = {_ADE_NAME_TO_ID[n] for n in ("building;edifice", "house", "wall", "windowpane;window", "door;double;door") if n in _ADE_NAME_TO_ID}

_UNCLS_ID = _GEST_NAME_TO_ID.get("unclassified", -1)

# Map structural gestalt names to one-hot index
_STRUCTURAL_ONEHOT = {}
for idx, name in enumerate(STRUCTURAL_CLASSES):
    gid = _GEST_NAME_TO_ID.get(name)
    if gid is not None:
        _STRUCTURAL_ONEHOT[gid] = idx


def _compute_group_and_class(visible_src, visible_id, behind_id, source):
    """Compute priority group_id and semantic class_id per point (vectorized).

    Args:
        visible_src: uint8 [P] -- 0=unlabeled, 1=gestalt, 2=ade
        visible_id:  int16 [P] -- class id within gestalt or ade space
        behind_id:   int16 [P] -- behind-gestalt id (-1 if none)
        source:      uint8 [P] -- 0=colmap, 1=depth (unused here, kept for symmetry)

    Returns:
        group_id: int8  [P] -- priority tier 0-4, -1 for excluded (unclassified)
        class_id: uint8 [P] -- one-hot class index 0-12
    """
    P = len(visible_src)
    vsrc = visible_src.astype(np.int32)
    vid = visible_id.astype(np.int32)
    bid = behind_id.astype(np.int32)

    # Effective gestalt id: prefer visible gestalt, fall back to behind
    gest_id = np.full(P, -1, dtype=np.int32)
    has_vis_gest = (vsrc == 1) & (vid >= 0)
    has_behind = (bid >= 0) & ~has_vis_gest
    gest_id[has_vis_gest] = vid[has_vis_gest]
    gest_id[has_behind] = bid[has_behind]

    # Exclude unclassified points
    if _UNCLS_ID >= 0:
        is_uncls = ((vsrc == 1) & (vid == _UNCLS_ID)) | (bid == _UNCLS_ID)
        gest_id[is_uncls] = -1  # force excluded

    # Build lookup arrays for gestalt id -> group and gestalt id -> class
    max_gid = NUM_GEST
    gid_to_group = np.full(max_gid, 4, dtype=np.int8)  # default: tier 4
    gid_to_class = np.full(max_gid, NUM_SEMANTIC_CLASSES - 1, dtype=np.uint8)  # default: non-house

    for gid in _POINT_IDS:
        gid_to_group[gid] = 0
    for gid in _EDGE_IDS:
        gid_to_group[gid] = 1
    for gid in _FACE_IDS:
        gid_to_group[gid] = 2
    for gid in _HOUSE_IDS - _POINT_IDS - _EDGE_IDS - _FACE_IDS:
        gid_to_group[gid] = 3
    for gid, onehot_idx in _STRUCTURAL_ONEHOT.items():
        gid_to_class[gid] = onehot_idx
    for gid in _HOUSE_IDS - set(_STRUCTURAL_ONEHOT.keys()):
        gid_to_class[gid] = len(STRUCTURAL_CLASSES)  # other_house

    # Apply lookup for points with valid gestalt ids
    has_gest = gest_id >= 0
    group_id = np.full(P, 4, dtype=np.int8)  # default: tier 4
    class_id = np.full(P, NUM_SEMANTIC_CLASSES - 1, dtype=np.uint8)  # default: non-house

    group_id[has_gest] = gid_to_group[gest_id[has_gest]]
    class_id[has_gest] = gid_to_class[gest_id[has_gest]]

    # ADE house points (no gestalt) get tier 3 + class_id = other_house
    ade_house_arr = np.array(sorted(_ADE_HOUSE_IDS), dtype=np.int32)
    is_ade_house = ~has_gest & (vsrc == 2) & (vid >= 0) & np.isin(vid, ade_house_arr)
    group_id[is_ade_house] = 3
    class_id[is_ade_house] = len(STRUCTURAL_CLASSES)  # other_house (index 11)

    # Mark excluded points (unclassified) as -1
    if _UNCLS_ID >= 0:
        group_id[is_uncls] = -1
        class_id[is_uncls] = NUM_SEMANTIC_CLASSES - 1

    return group_id, class_id


def _compute_smart_center_scale(xyz, source, mad_k=2.5, percentile=95.0,
                                max_points=8000):
    """Compute normalization center and scale from depth points with MAD filter."""
    depth_mask = source == 1
    ref = xyz[depth_mask] if depth_mask.any() else xyz
    if ref.shape[0] == 0:
        center = xyz.mean(axis=0)
        scale = max(np.linalg.norm(xyz - center, axis=1).max(), 1e-6)
        return center.astype(np.float32), np.float32(scale)

    if ref.shape[0] > max_points:
        idx = np.random.choice(ref.shape[0], max_points, replace=False)
        ref = ref[idx]

    center0 = np.median(ref, axis=0)
    dist = np.linalg.norm(ref - center0, axis=1)
    med = np.median(dist)
    mad = max(np.median(np.abs(dist - med)), 1e-6)
    inliers = dist <= (med + mad_k * mad)
    if inliers.any():
        ref = ref[inliers]

    # Percentile bounding box
    lo_f = (100.0 - percentile) * 0.5 / 100.0
    sorted_v = np.sort(ref, axis=0)
    n = sorted_v.shape[0]
    lo_idx = max(0, min(n - 1, int(lo_f * (n - 1))))
    hi_idx = max(0, min(n - 1, int((1.0 - lo_f) * (n - 1))))
    low = sorted_v[lo_idx]
    high = sorted_v[hi_idx]

    center = 0.5 * (low + high)
    scale = max(np.sqrt(((high - low) ** 2).sum()), 1e-6)
    return center.astype(np.float32), np.float32(scale)


def _process_one(sample, cfg):
    """Process a single HF sample into a cache dict. Returns (order_id, dict) or None."""
    rng = np.random.RandomState()  # worker-local rng

    n_edges = len(sample.get("wf_edges", []))
    if n_edges == 0 or n_edges > 64:
        return None

    scene = build_compact_scene(sample, cfg, rng=rng)
    if scene is None:
        return None

    gt_v = scene.get("gt_vertices")
    gt_e = scene.get("gt_edges")
    if gt_v is None or gt_e is None or len(gt_e) == 0:
        return None

    xyz = scene["xyz"]
    source = scene["source"]
    visible_src = scene["visible_src"]
    visible_id = scene["visible_id"]
    behind_id = scene["behind_gest_id"]

    group_id, class_id = _compute_group_and_class(
        visible_src, visible_id, behind_id, source
    )

    center, scale = _compute_smart_center_scale(xyz, source)

    order_id = sample.get("order_id", "unknown")

    return order_id, {
        "xyz": xyz.astype(np.float32),
        "source": source.astype(np.uint8),
        "group_id": group_id,
        "class_id": class_id,
        "behind_gest_id": behind_id.astype(np.int16),
        "visible_src": visible_src.astype(np.uint8),
        "visible_id": visible_id.astype(np.int16),
        "n_views_voted": scene["n_views_voted"],
        "vote_frac": scene["vote_frac"],
        "center": center,
        "scale": scale,
        "gt_vertices": gt_v.astype(np.float32),
        "gt_edges": gt_e.astype(np.int32),
    }


def main():
    p = argparse.ArgumentParser(description="Cache compact scenes from HoHo22k")
    g = p.add_mutually_exclusive_group(required=True)
    g.add_argument("--data-dir", help="Local dir with shards")
    g.add_argument("--streaming", action="store_true", help="Stream from HuggingFace")
    p.add_argument("--out-dir", required=True, help="Output directory for .pt files")
    p.add_argument("--limit", type=int, default=0)
    p.add_argument("--depth-per-view", type=int, default=8000)
    p.add_argument("--workers", type=int, default=0,
                   help="Parallel workers (0=sequential)")
    p.add_argument("--skip-existing", action="store_true",
                   help="Skip samples whose .pt already exists in out-dir")
    p.add_argument("--shard-start", type=int, default=0,
                   help="First shard index (for parallel launches)")
    p.add_argument("--shard-stride", type=int, default=1,
                   help="Stride between shards (e.g. 8 means take every 8th shard)")
    args = p.parse_args()

    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    existing_ids = set(f.stem for f in out_dir.glob("*.pt")) if args.skip_existing else set()

    # Load dataset
    from datasets import load_dataset
    if args.streaming:
        ds = load_dataset(
            "usm3d/hoho22k_2026_trainval",
            streaming=True, trust_remote_code=True, split="train",
        )
    else:
        data_root = Path(args.data_dir).resolve()
        tars = []
        for candidate in [data_root / "data" / "train", data_root / "train", data_root]:
            if candidate.exists():
                tars = sorted(str(t) for t in candidate.glob("*.tar"))
                if tars:
                    break
        loader = None
        for c in [data_root / "hoho22k_2026_trainval.py"]:
            if c.exists():
                loader = c
                break
        if loader is None:
            found = list(data_root.rglob("hoho22k_2026_trainval.py"))
            loader = found[0] if found else None
        if loader is None:
            raise FileNotFoundError("Cannot find loader script")
        # Shard-level parallelism: each process handles a slice of tars
        if args.shard_stride > 1:
            tars = tars[args.shard_start::args.shard_stride]
            print(f"Shard slice: start={args.shard_start} stride={args.shard_stride} -> {len(tars)} shards")
        ds = load_dataset(str(loader), data_files={"train": tars},
                          streaming=True, trust_remote_code=True, split="train")

    cfg = FuserConfig(depth_points_per_view=args.depth_per_view)

    saved = 0
    skipped = 0
    t_start = time.perf_counter()

    if args.workers > 0:
        # Parallel: collect samples into batches, process in worker pool.
        # Note: HF streaming datasets can't be shared across workers, so we
        # iterate in the main thread and dispatch processing to workers.
        with ProcessPoolExecutor(max_workers=args.workers) as pool:
            futures = {}
            for i, sample in enumerate(ds):
                if args.limit > 0 and i >= args.limit:
                    break
                oid = sample.get("order_id", "unknown")
                if oid in existing_ids:
                    skipped += 1
                    continue
                future = pool.submit(_process_one, sample, cfg)
                futures[future] = i

                # Drain completed futures to bound memory
                if len(futures) >= args.workers * 4:
                    done = [f for f in futures if f.done()]
                    for f in done:
                        result = f.result()
                        del futures[f]
                        if result is None:
                            skipped += 1
                            continue
                        order_id, data = result
                        torch.save(data, out_dir / f"{order_id}.pt")
                        saved += 1
                        if saved % 50 == 0:
                            elapsed = time.perf_counter() - t_start
                            print(f"Saved {saved} (skipped {skipped}) "
                                  f"[{saved / elapsed:.1f} samples/s]")

            # Drain remaining
            for f in as_completed(futures):
                result = f.result()
                if result is None:
                    skipped += 1
                    continue
                order_id, data = result
                torch.save(data, out_dir / f"{order_id}.pt")
                saved += 1
    else:
        # Sequential
        for i, sample in enumerate(ds):
            if args.limit > 0 and i >= args.limit:
                break
            oid = sample.get("order_id", "unknown")
            if oid in existing_ids:
                skipped += 1
                continue

            result = _process_one(sample, cfg)
            if result is None:
                skipped += 1
                continue
            order_id, data = result
            torch.save(data, out_dir / f"{order_id}.pt")
            saved += 1

            if saved % 50 == 0:
                elapsed = time.perf_counter() - t_start
                print(f"Saved {saved} (skipped {skipped}) "
                      f"[{saved / elapsed:.1f} samples/s]")

    elapsed = time.perf_counter() - t_start
    print(f"Done. Saved {saved}, skipped {skipped} in {elapsed:.0f}s "
          f"({saved / elapsed:.1f} samples/s)")


if __name__ == "__main__":
    main()
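A sketch of reading one cached scene back, assuming the field layout documented in the module docstring (the path and order id are placeholders; values are numpy arrays, as saved by _process_one):

import torch

scene = torch.load("cache/train/<order_id>.pt")  # hypothetical path
xyz_n = (scene["xyz"] - scene["center"]) / scene["scale"]  # normalized coords
keep = scene["group_id"] >= 0                              # drop excluded points
print(xyz_n[keep].shape, scene["gt_vertices"].shape, scene["gt_edges"].shape)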
s23dr_2026_example/color_mappings.py
ADDED
@@ -0,0 +1,209 @@
gestalt_color_mapping = {
    "unclassified": (215, 62, 138),
    "apex": (235, 88, 48),
    "eave_end_point": (248, 130, 228),
    "flashing_end_point": (71, 11, 161),
    "ridge": (214, 251, 248),
    "rake": (13, 94, 47),
    "eave": (54, 243, 63),
    "post": (187, 123, 236),
    "ground_line": (136, 206, 14),
    "flashing": (162, 162, 32),
    "step_flashing": (169, 255, 219),
    "hip": (8, 89, 52),
    "valley": (85, 27, 65),
    "roof": (215, 232, 179),
    "door": (110, 52, 23),
    "garage": (50, 233, 171),
    "window": (230, 249, 40),
    "shutter": (122, 4, 233),
    "fascia": (95, 230, 240),
    "soffit": (2, 102, 197),
    "horizontal_siding": (131, 88, 59),
    "vertical_siding": (110, 187, 198),
    "brick": (171, 252, 7),
    "concrete": (32, 47, 246),
    "other_wall": (112, 61, 240),
    "trim": (151, 206, 58),
    "unknown": (127, 127, 127),
    "transition_line": (0, 0, 0),
}

ade20k_color_mapping = {
    'wall': (120, 120, 120),
    'building;edifice': (180, 120, 120),
    'sky': (6, 230, 230),
    'floor;flooring': (80, 50, 50),
    'tree': (4, 200, 3),
    'ceiling': (120, 120, 80),
    'road;route': (140, 140, 140),
    'bed': (204, 5, 255),
    'windowpane;window': (230, 230, 230),
    'grass': (4, 250, 7),
    'cabinet': (224, 5, 255),
    'sidewalk;pavement': (235, 255, 7),
    'person;individual;someone;somebody;mortal;soul': (150, 5, 61),
    'earth;ground': (120, 120, 70),
    'door;double;door': (8, 255, 51),
    'table': (255, 6, 82),
    'mountain;mount': (143, 255, 140),
    'plant;flora;plant;life': (204, 255, 4),
    'curtain;drape;drapery;mantle;pall': (255, 51, 7),
    'chair': (204, 70, 3),
    'car;auto;automobile;machine;motorcar': (0, 102, 200),
    'water': (61, 230, 250),
    'painting;picture': (255, 6, 51),
    'sofa;couch;lounge': (11, 102, 255),
    'shelf': (255, 7, 71),
    'house': (255, 9, 224),
    'sea': (9, 7, 230),
    'mirror': (220, 220, 220),
    'rug;carpet;carpeting': (255, 9, 92),
    'field': (112, 9, 255),
    'armchair': (8, 255, 214),
    'seat': (7, 255, 224),
    'fence;fencing': (255, 184, 6),
    'desk': (10, 255, 71),
    'rock;stone': (255, 41, 10),
    'wardrobe;closet;press': (7, 255, 255),
    'lamp': (224, 255, 8),
    'bathtub;bathing;tub;bath;tub': (102, 8, 255),
    'railing;rail': (255, 61, 6),
    'cushion': (255, 194, 7),
    'base;pedestal;stand': (255, 122, 8),
    'box': (0, 255, 20),
    'column;pillar': (255, 8, 41),
    'signboard;sign': (255, 5, 153),
    'chest;of;drawers;chest;bureau;dresser': (6, 51, 255),
    'counter': (235, 12, 255),
    'sand': (160, 150, 20),
    'sink': (0, 163, 255),
    'skyscraper': (140, 140, 140),
    'fireplace;hearth;open;fireplace': (250, 10, 15),
    'refrigerator;icebox': (20, 255, 0),
    'grandstand;covered;stand': (31, 255, 0),
    'path': (255, 31, 0),
    'stairs;steps': (255, 224, 0),
    'runway': (153, 255, 0),
    'case;display;case;showcase;vitrine': (0, 0, 255),
    'pool;table;billiard;table;snooker;table': (255, 71, 0),
    'pillow': (0, 235, 255),
    'screen;door;screen': (0, 173, 255),
    'stairway;staircase': (31, 0, 255),
    'river': (11, 200, 200),
    'bridge;span': (255, 82, 0),
    'bookcase': (0, 255, 245),
    'blind;screen': (0, 61, 255),
    'coffee;table;cocktail;table': (0, 255, 112),
    'toilet;can;commode;crapper;pot;potty;stool;throne': (0, 255, 133),
    'flower': (255, 0, 0),
    'book': (255, 163, 0),
    'hill': (255, 102, 0),
    'bench': (194, 255, 0),
    'countertop': (0, 143, 255),
    'stove;kitchen;stove;range;kitchen;range;cooking;stove': (51, 255, 0),
    'palm;palm;tree': (0, 82, 255),
    'kitchen;island': (0, 255, 41),
    'computer;computing;machine;computing;device;data;processor;electronic;computer;information;processing;system': (0, 255, 173),
    'swivel;chair': (10, 0, 255),
    'boat': (173, 255, 0),
    'bar': (0, 255, 153),
    'arcade;machine': (255, 92, 0),
    'hovel;hut;hutch;shack;shanty': (255, 0, 255),
    'bus;autobus;coach;charabanc;double-decker;jitney;motorbus;motorcoach;omnibus;passenger;vehicle': (255, 0, 245),
    'towel': (255, 0, 102),
    'light;light;source': (255, 173, 0),
    'truck;motortruck': (255, 0, 20),
    'tower': (255, 184, 184),
    'chandelier;pendant;pendent': (0, 31, 255),
    'awning;sunshade;sunblind': (0, 255, 61),
    'streetlight;street;lamp': (0, 71, 255),
    'booth;cubicle;stall;kiosk': (255, 0, 204),
    'television;television;receiver;television;set;tv;tv;set;idiot;box;boob;tube;telly;goggle;box': (0, 255, 194),
    'airplane;aeroplane;plane': (0, 255, 82),
    'dirt;track': (0, 10, 255),
    'apparel;wearing;apparel;dress;clothes': (0, 112, 255),
    'pole': (51, 0, 255),
    'land;ground;soil': (0, 194, 255),
    'bannister;banister;balustrade;balusters;handrail': (0, 122, 255),
    'escalator;moving;staircase;moving;stairway': (0, 255, 163),
    'ottoman;pouf;pouffe;puff;hassock': (255, 153, 0),
    'bottle': (0, 255, 10),
    'buffet;counter;sideboard': (255, 112, 0),
    'poster;posting;placard;notice;bill;card': (143, 255, 0),
    'stage': (82, 0, 255),
    'van': (163, 255, 0),
    'ship': (255, 235, 0),
    'fountain': (8, 184, 170),
    'conveyer;belt;conveyor;belt;conveyer;conveyor;transporter': (133, 0, 255),
    'canopy': (0, 255, 92),
    'washer;automatic;washer;washing;machine': (184, 0, 255),
    'plaything;toy': (255, 0, 31),
    'swimming;pool;swimming;bath;natatorium': (0, 184, 255),
    'stool': (0, 214, 255),
    'barrel;cask': (255, 0, 112),
    'basket;handbasket': (92, 255, 0),
    'waterfall;falls': (0, 224, 255),
    'tent;collapsible;shelter': (112, 224, 255),
    'bag': (70, 184, 160),
    'minibike;motorbike': (163, 0, 255),
    'cradle': (153, 0, 255),
    'oven': (71, 255, 0),
    'ball': (255, 0, 163),
    'food;solid;food': (255, 204, 0),
    'step;stair': (255, 0, 143),
    'tank;storage;tank': (0, 255, 235),
    'trade;name;brand;name;brand;marque': (133, 255, 0),
    'microwave;microwave;oven': (255, 0, 235),
    'pot;flowerpot': (245, 0, 255),
    'animal;animate;being;beast;brute;creature;fauna': (255, 0, 122),
    'bicycle;bike;wheel;cycle': (255, 245, 0),
    'lake': (10, 190, 212),
    'dishwasher;dish;washer;dishwashing;machine': (214, 255, 0),
    'screen;silver;screen;projection;screen': (0, 204, 255),
    'blanket;cover': (20, 0, 255),
    'sculpture': (255, 255, 0),
    'hood;exhaust;hood': (0, 153, 255),
    'sconce': (0, 41, 255),
    'vase': (0, 255, 204),
    'traffic;light;traffic;signal;stoplight': (41, 0, 255),
    'tray': (41, 255, 0),
    'ashcan;trash;can;garbage;can;wastebin;ash;bin;ash-bin;ashbin;dustbin;trash;barrel;trash;bin': (173, 0, 255),
    'fan': (0, 245, 255),
    'pier;wharf;wharfage;dock': (71, 0, 255),
    'crt;screen': (122, 0, 255),
    'plate': (0, 255, 184),
    'monitor;monitoring;device': (0, 92, 255),
    'bulletin;board;notice;board': (184, 255, 0),
    'shower': (0, 133, 255),
    'radiator': (255, 214, 0),
    'glass;drinking;glass': (25, 194, 194),
    'clock': (102, 255, 0),
    'flag': (92, 0, 255),
}


EDGE_CLASSES = {
    'cornice_return': 0,
    'cornice_strip': 1,
    'eave': 2,
    'flashing': 3,
    'hip': 4,
    'rake': 5,
    'ridge': 6,
    'step_flashing': 7,
    'transition_line': 8,
    'valley': 9,
}
EDGE_CLASSES_BY_ID = {v: k for k, v in EDGE_CLASSES.items()}

edge_color_mapping = {
    'cornice_return': (215, 62, 138),
    'cornice_strip': (235, 88, 48),
    'eave': (54, 243, 63),
    'flashing': (162, 162, 32),
    'hip': (8, 89, 52),
    'rake': (13, 94, 47),
    'ridge': (214, 251, 248),
    'step_flashing': (169, 255, 219),
    'transition_line': (200, 0, 50),
    'valley': (85, 27, 65),
}
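These dicts map class names to RGB; for rendering an id mask you typically want an id-indexed array instead. A sketch assuming dict insertion order corresponds to class id (this mirrors how the id-to-name tables are enumerated in cache_scenes.py, but it is an assumption about this file on its own):

import numpy as np
from s23dr_2026_example.color_mappings import gestalt_color_mapping

# insertion order -> class id (assumption; verify against GEST_ID_TO_NAME)
palette = np.array(list(gestalt_color_mapping.values()), dtype=np.uint8)  # [C, 3]
ids = np.array([[0, 2], [13, 26]])  # toy 2x2 mask of gestalt class ids
rgb = palette[ids]                  # [2, 2, 3] uint8 color image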
s23dr_2026_example/data.py
ADDED
@@ -0,0 +1,237 @@
| 1 |
+
"""Data loading for pre-sampled HF datasets.
|
| 2 |
+
|
| 3 |
+
Expects pre-sampled npz blobs with xyz_norm [2048, 3] (not full PCD).
|
| 4 |
+
Use make_sampled_cache.py to produce these from full point clouds.
|
| 5 |
+
"""
|
| 6 |
+
from __future__ import annotations
|
| 7 |
+
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
|
| 10 |
+
import numpy as np
|
| 11 |
+
import torch
|
| 12 |
+
|
| 13 |
+
from .tokenizer import EdgeDepthSequenceConfig
|
| 14 |
+
|
| 15 |
+
# Default token budget (must match make_sampled_cache.py)
|
| 16 |
+
SEQ_LEN = 2048
|
| 17 |
+
COLMAP_POINTS = 1536
|
| 18 |
+
DEPTH_POINTS = 512
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
# ---------------------------------------------------------------------------
|
| 22 |
+
# Datasets
|
| 23 |
+
# ---------------------------------------------------------------------------
|
| 24 |
+
|
| 25 |
+
def _load_bad_sample_ids():
|
| 26 |
+
"""Load the set of known-bad sample IDs (misaligned GT, extreme scale)."""
|
| 27 |
+
bad_file = Path(__file__).parent / "bad_samples.txt"
|
| 28 |
+
if not bad_file.exists():
|
| 29 |
+
return set()
|
| 30 |
+
return set(line.strip() for line in bad_file.read_text().splitlines() if line.strip())
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class HFCachedDataset(torch.utils.data.Dataset):
|
| 34 |
+
"""Load pre-sampled HuggingFace dataset into memory."""
|
| 35 |
+
|
| 36 |
+
def __init__(self, hf_dataset, aug_rotate=False, aug_jitter=0.0,
|
| 37 |
+
aug_drop=0.0, aug_flip=False):
|
| 38 |
+
import io as _io
|
| 39 |
+
bad_ids = _load_bad_sample_ids()
|
| 40 |
+
print(f"Pre-decoding {len(hf_dataset)} samples into memory...")
|
| 41 |
+
self.samples = []
|
| 42 |
+
self.order_ids = []
|
| 43 |
+
n_skipped = 0
|
| 44 |
+
for i, sample in enumerate(hf_dataset):
|
| 45 |
+
if sample["order_id"] in bad_ids:
|
| 46 |
+
n_skipped += 1
|
| 47 |
+
continue
|
| 48 |
+
d = dict(np.load(_io.BytesIO(sample["data"])))
|
| 49 |
+
if "xyz_norm" not in d:
|
| 50 |
+
raise ValueError(
|
| 51 |
+
f"Sample {sample['order_id']} missing 'xyz_norm' -- this looks like "
|
| 52 |
+
f"a full PCD dataset, not pre-sampled. Use make_sampled_cache.py first.")
|
| 53 |
+
self.samples.append(d)
|
| 54 |
+
self.order_ids.append(sample["order_id"])
|
| 55 |
+
if (i + 1) % 2000 == 0:
|
| 56 |
+
print(f" {i+1}/{len(hf_dataset)}...")
|
| 57 |
+
print(f" Done. {len(self.samples)} samples in memory"
|
| 58 |
+
f" ({n_skipped} bad samples filtered).")
|
| 59 |
+
self.aug_rotate = aug_rotate
|
| 60 |
+
self.aug_jitter = aug_jitter
|
| 61 |
+
self.aug_drop = aug_drop
|
| 62 |
+
self.aug_flip = aug_flip
|
| 63 |
+
|
| 64 |
+
def __len__(self):
|
| 65 |
+
return len(self.samples)
|
| 66 |
+
|
| 67 |
+
def __getitem__(self, idx):
|
| 68 |
+
out = _process_sample(self.samples[idx], self.aug_rotate,
|
| 69 |
+
self.aug_jitter, self.aug_drop, self.aug_flip)
|
| 70 |
+
out["sample_id"] = self.order_ids[idx]
|
| 71 |
+
return out
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def _process_sample(d, aug_rotate, aug_jitter=0.0, aug_drop=0.0, aug_flip=False):
|
| 75 |
+
"""Process a pre-sampled npz dict into training tensors.
|
| 76 |
+
|
| 77 |
+
Args:
|
| 78 |
+
aug_rotate: random yaw rotation
|
| 79 |
+
aug_jitter: std of Gaussian noise added to point positions (0=disabled)
|
| 80 |
+
aug_drop: fraction of points to randomly drop (0=disabled)
|
| 81 |
+
aug_flip: random mirror along X axis (50% chance)
|
| 82 |
+
"""
|
| 83 |
+
xyz_norm = d["xyz_norm"].copy()
|
| 84 |
+
gt_seg = d["gt_segments"].copy()
|
| 85 |
+
mask = d["mask"].copy()
|
| 86 |
+
|
| 87 |
+
if aug_rotate:
|
| 88 |
+
theta = np.random.rand() * 2 * np.pi
|
| 89 |
+
cos_t, sin_t = np.cos(theta), np.sin(theta)
|
| 90 |
+
x, z = xyz_norm[:, 0].copy(), xyz_norm[:, 2].copy()
|
| 91 |
+
xyz_norm[:, 0] = x * cos_t - z * sin_t
|
| 92 |
+
xyz_norm[:, 2] = x * sin_t + z * cos_t
|
| 93 |
+
for ep in range(2):
|
| 94 |
+
sx, sz = gt_seg[:, ep, 0].copy(), gt_seg[:, ep, 2].copy()
|
| 95 |
+
gt_seg[:, ep, 0] = sx * cos_t - sz * sin_t
|
| 96 |
+
gt_seg[:, ep, 2] = sx * sin_t + sz * cos_t
|
| 97 |
+
|
| 98 |
+
if aug_flip and np.random.rand() < 0.5:
|
| 99 |
+
xyz_norm[:, 0] = -xyz_norm[:, 0]
|
| 100 |
+
gt_seg[:, :, 0] = -gt_seg[:, :, 0]
|
| 101 |
+
|
| 102 |
+
if aug_jitter > 0:
|
| 103 |
+
valid = mask.astype(bool)
|
| 104 |
+
xyz_norm[valid] += np.random.randn(valid.sum(), 3).astype(np.float32) * aug_jitter
|
| 105 |
+
|
| 106 |
+
if aug_drop > 0:
|
| 107 |
+
valid_idx = np.where(mask)[0]
|
| 108 |
+
n_drop = int(len(valid_idx) * aug_drop)
|
| 109 |
+
if n_drop > 0:
|
| 110 |
+
drop_idx = np.random.choice(valid_idx, n_drop, replace=False)
|
| 111 |
+
mask[drop_idx] = False
|
| 112 |
+
|
| 113 |
+
result = {
|
| 114 |
+
"xyz_norm": torch.as_tensor(xyz_norm, dtype=torch.float32),
|
| 115 |
+
"class_id": torch.as_tensor(d["class_id"], dtype=torch.long),
|
| 116 |
+
"source": torch.as_tensor(d["source"], dtype=torch.long),
|
| 117 |
+
"mask": torch.as_tensor(mask),
|
| 118 |
+
"gt_segments": torch.as_tensor(gt_seg, dtype=torch.float32),
|
| 119 |
+
"scale": torch.tensor(float(d["scale"]), dtype=torch.float32),
|
| 120 |
+
"center": torch.as_tensor(d["center"], dtype=torch.float32),
|
| 121 |
+
"gt_vertices": d["gt_vertices"],
|
| 122 |
+
"gt_edges": d["gt_edges"],
|
| 123 |
+
"visible_src": torch.as_tensor(d["visible_src"], dtype=torch.long),
|
| 124 |
+
"visible_id": torch.as_tensor(d["visible_id"], dtype=torch.long),
|
| 125 |
+
}
|
| 126 |
+
if "behind" in d:
|
| 127 |
+
result["behind"] = torch.as_tensor(
|
| 128 |
+
np.clip(np.asarray(d["behind"], dtype=np.int16), 0, None), dtype=torch.long)
|
| 129 |
+
if "n_views_voted" in d:
|
| 130 |
+
result["n_views_voted"] = torch.as_tensor(d["n_views_voted"], dtype=torch.float32)
|
| 131 |
+
if "vote_frac" in d:
|
| 132 |
+
result["vote_frac"] = torch.as_tensor(d["vote_frac"], dtype=torch.float32)
|
| 133 |
+
if "gt_edge_classes" in d:
|
| 134 |
+
result["gt_edge_classes"] = torch.as_tensor(
|
| 135 |
+
np.asarray(d["gt_edge_classes"], dtype=np.int64), dtype=torch.long)
|
| 136 |
+
return result
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
# ---------------------------------------------------------------------------
|
| 140 |
+
# Collation + DataLoader
|
| 141 |
+
# ---------------------------------------------------------------------------
|
| 142 |
+
|
| 143 |
+
def collate(batch):
|
| 144 |
+
"""Stack samples into batched tensors."""
|
| 145 |
+
out = {
|
| 146 |
+
"xyz_norm": torch.stack([d["xyz_norm"] for d in batch]),
|
| 147 |
+
"class_id": torch.stack([d["class_id"] for d in batch]),
|
| 148 |
+
"source": torch.stack([d["source"] for d in batch]),
|
| 149 |
+
"mask": torch.stack([d["mask"] for d in batch]),
|
| 150 |
+
"gt_segments": [d["gt_segments"] for d in batch],
|
| 151 |
+
"scales": torch.stack([d["scale"] for d in batch]),
|
| 152 |
+
"meta": batch,
|
| 153 |
+
}
|
| 154 |
+
# Optional fields: check ALL samples, not just batch[0].
|
| 155 |
+
# If any sample has it, all must have it (no mixed data versions).
|
| 156 |
+
for field in ("behind", "n_views_voted", "vote_frac"):
|
| 157 |
+
if any(field in d for d in batch):
|
| 158 |
+
missing = [i for i, d in enumerate(batch) if field not in d]
|
| 159 |
+
if missing:
|
| 160 |
+
raise KeyError(
|
| 161 |
+
f"Field '{field}' present in some batch samples but missing in "
|
| 162 |
+
f"{len(missing)}/{len(batch)}. Mixed data versions in cache?")
|
| 163 |
+
out[field] = torch.stack([d[field] for d in batch])
|
| 164 |
+
# gt_edge_classes: variable length per sample (like gt_segments), keep as list
|
| 165 |
+
if any("gt_edge_classes" in d for d in batch):
|
| 166 |
+
missing = [i for i, d in enumerate(batch) if "gt_edge_classes" not in d]
|
| 167 |
+
if missing:
|
| 168 |
+
raise KeyError(
|
| 169 |
+
f"Field 'gt_edge_classes' present in some batch samples but missing in "
|
| 170 |
+
f"{len(missing)}/{len(batch)}. Mixed data versions in cache?")
|
| 171 |
+
out["gt_edge_classes"] = [d["gt_edge_classes"] for d in batch]
|
| 172 |
+
return out
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
def build_loader(cache_dir, batch_size, aug_rotate=False, aug_jitter=0.0,
|
| 176 |
+
aug_drop=0.0, aug_flip=False):
|
| 177 |
+
"""Create a DataLoader from HF dataset.
|
| 178 |
+
|
| 179 |
+
cache_dir should be 'hf://repo/name:split' format.
|
| 180 |
+
"""
|
| 181 |
+
if not cache_dir.startswith("hf://"):
|
| 182 |
+
raise ValueError(
|
| 183 |
+
f"cache_dir must be 'hf://repo:split' format, got: {cache_dir}. "
|
| 184 |
+
f"Local .pt caches are no longer supported in the training path.")
|
| 185 |
+
parts = cache_dir[5:].split(":")
|
| 186 |
+
repo = parts[0]
|
| 187 |
+
split = parts[1] if len(parts) > 1 else "train"
|
| 188 |
+
from datasets import load_dataset
|
| 189 |
+
hf_ds = load_dataset(repo, split=split)
|
| 190 |
+
ds = HFCachedDataset(hf_ds, aug_rotate=aug_rotate, aug_jitter=aug_jitter,
|
| 191 |
+
aug_drop=aug_drop, aug_flip=aug_flip)
|
| 192 |
+
loader = torch.utils.data.DataLoader(
|
| 193 |
+
ds, batch_size=batch_size, shuffle=True,
|
| 194 |
+
num_workers=0, collate_fn=collate,
|
| 195 |
+
)
|
| 196 |
+
print(f"Dataset: {len(ds)} scenes, batch_size={batch_size}")
|
| 197 |
+
return loader
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
# ---------------------------------------------------------------------------
# Token building (GPU)
# ---------------------------------------------------------------------------

def build_tokens(batch, model, device):
    """Apply Fourier features + learned embeddings on GPU."""
    xyz = batch["xyz_norm"].to(device)
    cid = batch["class_id"].to(device)
    src = batch["source"].to(device)
    masks = batch["mask"].to(device)
    gt = [g.to(device) for g in batch["gt_segments"]]
    scales = batch["scales"]

    B, T, _ = xyz.shape
    tok = model.tokenizer
    fourier = tok.pos_enc(xyz.reshape(-1, 3)).reshape(B, T, -1) \
        if tok.pos_enc is not None else xyz.new_zeros(B, T, 0)
    parts = [xyz, fourier, tok.label_emb(cid), tok.src_emb(src.clamp(0, 1))]
    if tok.behind_emb_dim > 0:
        if "behind" in batch:
            beh = batch["behind"].to(device)
        else:
            # Data doesn't have behind -- use zeros (embed index 0).
            # This is intentional for eval on old data; for training,
            # fail fast by requiring the field (checked in _process_sample).
            beh = xyz.new_zeros(B, T, dtype=torch.long)
        parts.append(tok.behind_emb(beh))
    if tok.use_vote_features:
        if "n_views_voted" not in batch or "vote_frac" not in batch:
            raise KeyError(
                "Model expects vote features (--vote-features) but data is missing "
                "'n_views_voted'/'vote_frac'. Use v2 dataset or regenerate cache.")
        # Normalize to ~zero mean, unit variance (dataset stats: nv~2.7+/-1.0, vf~0.5+/-0.25)
        nv = ((batch["n_views_voted"].to(device).float() - 2.7) / 1.0).unsqueeze(-1)
        vf = ((batch["vote_frac"].to(device).float() - 0.5) / 0.25).unsqueeze(-1)
        parts.extend([nv, vf])
    tokens = torch.cat(parts, dim=-1)
    return tokens, masks, gt, scales, batch["meta"]
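The width of the concatenated token vector is the sum of the parts built above, so the in_dim the model is constructed with must match the tokenizer config. A sketch of that arithmetic under assumed embedding sizes (the real values live in tokenizer.py; every number below is an assumption):

    fourier_dim = 60   # assumed pos_enc output width
    label_dim = 16     # assumed label_emb dim
    src_dim = 8        # assumed src_emb dim
    behind_dim = 8     # assumed behind_emb dim (0 disables the behind feature)
    vote_dims = 2      # n_views_voted + vote_frac scalars, if use_vote_features
    in_dim = 3 + fourier_dim + label_dim + src_dim + behind_dim + vote_dims  # = 97
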
s23dr_2026_example/losses.py
ADDED
@@ -0,0 +1,311 @@
"""Loss computation for wireframe prediction."""
from __future__ import annotations

import torch

from .varifold import varifold_loss_batch
from .sinkhorn import batched_sinkhorn_loss
from .soft_hss_loss import batched_sinkhorn_vertex_f1, batched_soft_hss_v2

# Varifold config
VARIANT = "simpson3"
SIGMAS = [0.5, 1.0, 2.0]  # meters (divided by per-scene scale at runtime)
ALPHAS = [0.2, 0.6, 0.2]
LEN_POW = 1.0
VARIFOLD_CROSS_ONLY = False  # Set to True to drop self-energy (avoids O(S^2) blowup)

# Sinkhorn config (note: near-zero gradients at eps=0.05, effectively disabled)
SINKHORN_EPS = 0.05
SINKHORN_ITERS = 10

# Distance thresholds in meters (divided by per-scene scale at runtime)
VERTEX_THRESH_M = 0.5  # vertex match threshold (mirrors real HSS)
TUBE_RADIUS_M = 0.5    # tube IoU radius (mirrors real HSS)

# Sinkhorn dustbin cost: controls the OT "not matching" penalty.
# Like tau, this is an OT behavior parameter, NOT a physical distance.
# Must be comparable to typical matching costs in normalized space (~0.1).
# Do NOT divide by scale.
SINKHORN_DUSTBIN = 0.1

# Sigmoid temperature: controls gradient smoothness, NOT a distance threshold.
# Must stay large enough in normalized space to provide useful gradients.
# Do NOT divide by scale (unlike the thresholds above).
SIGMOID_TAU = 0.05

MAX_GT = 64  # fixed pad size for compile-friendly shapes

# Precomputed constants (created once on first call)
_loss_constants = {}


def _get_loss_constants(device, dtype):
    key = (device, dtype)
    if key not in _loss_constants:
        _loss_constants[key] = {
            "sigmas": torch.tensor(SIGMAS, device=device, dtype=dtype),
            "alphas": torch.tensor(ALPHAS, device=device, dtype=dtype),
        }
    return _loss_constants[key]
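
# Worked example of the scale convention (hypothetical scene scale of 5 m per
# normalized unit): distance-like constants are divided by the per-scene scale,
# the OT/gradient knobs are not.
#
#   scale = torch.tensor([5.0])
#   VERTEX_THRESH_M / scale  # -> 0.1  (0.5 m expressed in normalized units)
#   TUBE_RADIUS_M / scale    # -> 0.1
#   SIGMOID_TAU              # stays 0.05 regardless of scene size
#   SINKHORN_DUSTBIN         # stays 0.1 regardless of scene size
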
def pad_gt_fixed(gt_list, device, dtype):
    """Pad GT segments to fixed MAX_GT for compile-friendly shapes."""
    B = len(gt_list)
    gt_pad = torch.zeros((B, MAX_GT, 2, 3), device=device, dtype=dtype)
    gt_mask = torch.zeros((B, MAX_GT), device=device, dtype=torch.bool)
    gt_lengths = torch.zeros(B, device=device, dtype=dtype)
    for i, g in enumerate(gt_list):
        n = g.shape[0]
        if n > 0:
            gt_pad[i, :n] = g
            gt_mask[i, :n] = True
            gt_lengths[i] = torch.linalg.norm(g[:, 1] - g[:, 0], dim=-1).sum()
    return gt_pad, gt_mask, gt_lengths
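
# Shape sketch for pad_gt_fixed (two toy scenes, the second with no GT edges):
def _demo_pad_gt_fixed():
    gt_list = [torch.rand(5, 2, 3), torch.rand(0, 2, 3)]
    gt_pad, gt_mask, gt_lengths = pad_gt_fixed(gt_list, "cpu", torch.float32)
    print(gt_pad.shape, gt_mask.shape)  # [2, 64, 2, 3], [2, 64]
    print(gt_mask.sum(dim=1))           # tensor([5, 0])
    print(gt_lengths)                   # second entry 0 -> masked via has_gt below
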
def _loss_inner(pred_segments, gt_pad, gt_mask, gt_lengths, scales,
                sigmas, alphas, varifold_w, vertex_f1_w):
    """Pure tensor loss -- no Python control flow, no boolean indexing."""
    has_gt = (gt_lengths > 0).float()

    sigmas_eff = sigmas / scales[:, None]
    loss_batch = varifold_loss_batch(
        pred_segments, gt_pad, gt_mask=gt_mask,
        variant=VARIANT, sigmas=sigmas_eff, alpha=alphas, len_pow=LEN_POW,
        cross_only=VARIFOLD_CROSS_ONLY,
    )
    v = loss_batch / gt_lengths.clamp(min=1.0)
    v = (v * has_gt).sum() / has_gt.sum().clamp(min=1.0)

    thresh = VERTEX_THRESH_M / scales
    f1 = batched_sinkhorn_vertex_f1(
        pred_segments, gt_pad, gt_mask, thresh=thresh, tau=SIGMOID_TAU)
    f1 = (f1 * has_gt).sum() / has_gt.sum().clamp(min=1.0)

    total = varifold_w * v + vertex_f1_w * f1
    return total, v, f1


# Will be replaced with compiled version on CUDA
_loss_fn = _loss_inner


def _conf_match_loss(pred_segments, gt_pad, gt_mask, conf_logits, scales):
    """Auxiliary BCE loss: train conf to predict whether each segment matches GT.

    Computes per-segment min-distance to GT, creates soft match target via
    sigmoid thresholding, and returns BCE(sigmoid(conf), target).
    """
    B, S = pred_segments.shape[:2]
    # Decoupled cost: midpoint + direction + length (same as sinkhorn)
    p0, p1 = pred_segments[:, :, 0], pred_segments[:, :, 1]
    g0, g1 = gt_pad[:, :, 0], gt_pad[:, :, 1]
    mid_p, half_p = 0.5 * (p0 + p1), 0.5 * (p1 - p0)
    mid_g, half_g = 0.5 * (g0 + g1), 0.5 * (g1 - g0)
    d_mid = torch.linalg.norm(mid_p.unsqueeze(2) - mid_g.unsqueeze(1), dim=-1)
    len_p = torch.linalg.norm(half_p, dim=-1, keepdim=True).clamp(min=1e-6)
    len_g = torch.linalg.norm(half_g, dim=-1, keepdim=True).clamp(min=1e-6)
    dir_p = half_p / len_p
    dir_g = half_g / len_g
    cos_angle = (dir_p.unsqueeze(2) * dir_g.unsqueeze(1)).sum(dim=-1)
    d_dir = 1.0 - cos_angle.abs()
    d_len = (len_p.unsqueeze(2) - len_g.unsqueeze(1)).squeeze(-1).abs()
    cost = d_mid + d_dir + d_len  # [B, S, M]

    # Mask invalid GT with high cost
    cost = torch.where(gt_mask.unsqueeze(1), cost, cost.new_tensor(1e6))
    min_dist = cost.min(dim=2).values  # [B, S]

    # Soft target: sigmoid((thresh - dist) / tau), in normalized space
    thresh = VERTEX_THRESH_M / scales  # [B]
    target = torch.sigmoid((thresh[:, None] - min_dist) / SIGMOID_TAU)

    return torch.nn.functional.binary_cross_entropy_with_logits(
        conf_logits, target.detach(), reduction="mean")


def compute_loss(pred_segments, gt_list, scales, device,
                 varifold_w, sinkhorn_w, vertex_f1_w=0.0, soft_hss_w=0.0,
                 endpoint_w=0.0,
                 conf_logits=None, conf_weight=0.0, conf_mode="match",
                 sinkhorn_eps=None, sinkhorn_iters=None,
                 sinkhorn_dustbin=None, conf_clamp_min=None):
    """Combined loss with fixed-size GT padding.

    conf_mode: "match" = BCE matching supervision; "sinkhorn"/"sinkhorn_detach" =
        conf-weighted sinkhorn mass (the detach variant stops OT gradients from
        flowing into conf); "varifold" = conf-weighted varifold.
    """
    if conf_logits is not None and conf_clamp_min is not None:
        conf_logits = conf_logits.clamp(min=conf_clamp_min)
    gt_pad, gt_mask, gt_lengths = pad_gt_fixed(gt_list, device, pred_segments.dtype)
    c = _get_loss_constants(device, pred_segments.dtype)

    total, v, f1 = _loss_fn(
        pred_segments, gt_pad, gt_mask, gt_lengths, scales,
        c["sigmas"], c["alphas"], varifold_w, vertex_f1_w)

    terms = {}
    if varifold_w > 0:
        terms["varifold"] = v.detach()
    if vertex_f1_w > 0:
        terms["vertex_f1"] = f1.detach()

    if sinkhorn_w > 0:
        has_gt = (gt_lengths > 0).float()
        if conf_logits is not None and conf_mode == "sinkhorn":
            pred_mass = torch.sigmoid(conf_logits)
        elif conf_logits is not None and conf_mode == "sinkhorn_detach":
            pred_mass = torch.sigmoid(conf_logits.detach())
        else:
            pred_mass = None
        eps = sinkhorn_eps if sinkhorn_eps is not None else SINKHORN_EPS
        iters = sinkhorn_iters if sinkhorn_iters is not None else SINKHORN_ITERS
        dustbin = sinkhorn_dustbin if sinkhorn_dustbin is not None else SINKHORN_DUSTBIN
        S = pred_segments.shape[1]
        sink_per = batched_sinkhorn_loss(
            pred_segments, gt_pad, gt_mask,
            eps, iters, dustbin,
            pred_mass=pred_mass,
        ) / (gt_lengths.clamp(min=1.0) * S)
        s = (sink_per * has_gt).sum() / has_gt.sum().clamp(min=1.0)
        total = total + sinkhorn_w * s
        terms["sinkhorn"] = s.detach()

    if soft_hss_w > 0:
        has_gt = (gt_lengths > 0).float()
        vert_thresh = VERTEX_THRESH_M / scales
        edge_thresh = TUBE_RADIUS_M / scales
        hss_loss = batched_soft_hss_v2(
            pred_segments, gt_pad, gt_mask,
            vert_thresh=vert_thresh, edge_thresh=edge_thresh, tau=SIGMOID_TAU)
        hs = (hss_loss * has_gt).sum() / has_gt.sum().clamp(min=1.0)
        total = total + soft_hss_w * hs
        terms["soft_hss"] = hs.detach()

    if conf_logits is not None and conf_weight > 0:
        if conf_mode == "match":
            # Explicit BCE supervision from nearest-GT distances
            cl = _conf_match_loss(pred_segments, gt_pad, gt_mask, conf_logits, scales)
            total = total + conf_weight * cl
            terms["conf"] = cl.detach()
        elif conf_mode in ("sinkhorn", "sinkhorn_detach"):
            # Conf trained through sinkhorn transport gradients (via pred_mass).
            # sinkhorn_detach: pred_mass uses detached conf, so OT can't push conf negative.
            # Add count regularizer to prevent all-zero conf collapse.
            # Normalized by S so magnitude doesn't depend on segment count.
            conf_w = torch.sigmoid(conf_logits)
            S = conf_logits.shape[1]
            gt_counts = gt_mask.sum(dim=1).float()
            conf_sum = conf_w.sum(dim=1)
            reg = (((conf_sum - gt_counts) / S) ** 2).mean()
            total = total + conf_weight * reg
            terms["conf_reg"] = reg.detach()
        elif conf_mode == "varifold":
            # Conf-weighted varifold: weight each pred segment's contribution
            # by sigmoid(conf). Low-conf segments contribute less to the loss.
            # Needs regularizer to prevent all-zero conf collapse.
            has_gt = (gt_lengths > 0).float()
            conf_w = torch.sigmoid(conf_logits)  # [B, S]
            sigmas_eff = c["sigmas"] / scales[:, None]
            vf_conf = varifold_loss_batch(
                pred_segments, gt_pad, gt_mask=gt_mask,
                variant=VARIANT, sigmas=sigmas_eff, alpha=c["alphas"],
                len_pow=LEN_POW, pred_weights=conf_w,
            )
            vc = (vf_conf / gt_lengths.clamp(min=1.0))
            vc = (vc * has_gt).sum() / has_gt.sum().clamp(min=1.0)
            # Regularizer: penalize total conf being far from n_gt
            # Normalized by S so magnitude doesn't depend on segment count
            S = conf_logits.shape[1]
            gt_counts = gt_mask.sum(dim=1).float()  # [B]
            conf_sum = conf_w.sum(dim=1)  # [B]
            reg = (((conf_sum - gt_counts) / S) ** 2).mean()
            total = total + conf_weight * vc + 0.01 * reg
            terms["conf_vf"] = vc.detach()
            terms["conf_reg"] = reg.detach()
        else:
            raise ValueError(f"Unknown conf_mode: {conf_mode}")

    if endpoint_w > 0:
        has_gt = (gt_lengths > 0).float()
        eps_ep = sinkhorn_eps if sinkhorn_eps is not None else SINKHORN_EPS
        iters_ep = sinkhorn_iters if sinkhorn_iters is not None else SINKHORN_ITERS
        dustbin_ep = sinkhorn_dustbin if sinkhorn_dustbin is not None else SINKHORN_DUSTBIN
        B, S = pred_segments.shape[:2]
        M = gt_pad.shape[1]

        # Compute hard assignment via sinkhorn (detached -- matching is not trained)
        with torch.no_grad():
            pred_mass_ep = torch.sigmoid(conf_logits) if conf_logits is not None else None
            sink_loss_for_assign = batched_sinkhorn_loss(
                pred_segments, gt_pad, gt_mask, eps_ep, iters_ep, dustbin_ep,
                pred_mass=pred_mass_ep)  # value itself unused; transport is rebuilt below
            # Re-run sinkhorn to get transport matrix for assignment
            # (reuse the cost computation from batched_sinkhorn_loss internals)
            p0, p1 = pred_segments[:, :, 0], pred_segments[:, :, 1]
            g0, g1 = gt_pad[:, :, 0], gt_pad[:, :, 1]
            mid_p, half_p = 0.5 * (p0 + p1), 0.5 * (p1 - p0)
            mid_g, half_g = 0.5 * (g0 + g1), 0.5 * (g1 - g0)
            d_mid = torch.linalg.norm(mid_p.unsqueeze(2) - mid_g.unsqueeze(1), dim=-1)
            len_p = torch.linalg.norm(half_p, dim=-1, keepdim=True).clamp(min=1e-6)
            len_g = torch.linalg.norm(half_g, dim=-1, keepdim=True).clamp(min=1e-6)
            dir_p, dir_g = half_p / len_p, half_g / len_g
            cos_a = (dir_p.unsqueeze(2) * dir_g.unsqueeze(1)).sum(dim=-1)
            d_dir = 1.0 - cos_a.abs()
            d_len = (len_p.unsqueeze(2) - len_g.unsqueeze(1)).squeeze(-1).abs()
            cost = d_mid + d_dir + d_len
            dc = torch.as_tensor(dustbin_ep, device=cost.device, dtype=cost.dtype)
            cost = torch.where(gt_mask.unsqueeze(1), cost, dc * 10.0)
            cost_pad = dc.expand(B, S + 1, M + 1).clone()
            cost_pad[:, :S, :M] = cost
            cost_pad[:, -1, -1] = 0.0
            gt_counts = gt_mask.sum(dim=1).float()
            if pred_mass_ep is not None:
                pm = pred_mass_ep.clamp(min=0.0)
                a = torch.cat([pm, (gt_counts - pm.sum(1)).clamp(min=0).unsqueeze(1)], dim=1)
                b_val = torch.zeros(B, M + 1, device=cost.device, dtype=cost.dtype)
                b_val[:, :M] = gt_mask.float()
                b_val[:, -1] = (pm.sum(1) - gt_counts).clamp(min=0)
            else:
                n = float(S)
                denom = n + gt_counts
                a = (1.0 / denom).unsqueeze(1).expand(B, S + 1).clone()
                a[:, -1] = gt_counts / denom
                b_val = (1.0 / denom).unsqueeze(1).expand(B, M + 1).clone()
                b_val[:, -1] = n / denom
                b_val[:, :M] = b_val[:, :M] * gt_mask.float()
            log_a = torch.log(a + 1e-9)
            log_b = torch.log(b_val + 1e-9)
            log_k = -cost_pad / eps_ep
            log_u = torch.zeros_like(a)
            log_v = torch.zeros_like(b_val)
            for _ in range(iters_ep):
                log_u = log_a - torch.logsumexp(log_k + log_v.unsqueeze(1), dim=2)
                log_v = log_b - torch.logsumexp(log_k + log_u.unsqueeze(2), dim=1)
            transport = torch.exp(log_u.unsqueeze(2) + log_v.unsqueeze(1) + log_k)
            assignment = transport[:, :S, :M + 1].argmax(dim=2)
            assignment[assignment >= M] = -1

        # Everything below is WITH gradients (assignment is detached but pred_segments is live)
        matched = (assignment >= 0)  # [B, S]
        n_matched = matched.float().sum().clamp(min=1.0)
        assign_safe = assignment.clamp(min=0)
        gt_matched = gt_pad[
            torch.arange(B, device=device)[:, None].expand(B, S),
            assign_safe]  # [B, S, 2, 3]

        # Symmetric endpoint distance
        ref_ep1 = pred_segments[:, :, 0]
        ref_ep2 = pred_segments[:, :, 1]
        gt_ep1 = gt_matched[:, :, 0]
        gt_ep2 = gt_matched[:, :, 1]
        dist_fwd = (ref_ep1 - gt_ep1).norm(dim=-1) + (ref_ep2 - gt_ep2).norm(dim=-1)
        dist_rev = (ref_ep1 - gt_ep2).norm(dim=-1) + (ref_ep2 - gt_ep1).norm(dim=-1)
        ep_dist = torch.min(dist_fwd, dist_rev)

        # Normalize by GT total length * S (same scale as sinkhorn)
        ep_loss = (ep_dist * matched.float()).sum() / n_matched
        total = total + endpoint_w * ep_loss
        terms["endpoint"] = ep_loss.detach()

    return total, terms
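For orientation, a minimal end-to-end call of compute_loss on random tensors -- a sketch with illustrative weights (not tuned values from this repo), assuming the package is importable by this path and the varifold/sinkhorn kernels accept the shapes exactly as compute_loss passes them:

    import torch
    from s23dr_2026_example.losses import compute_loss

    pred = torch.rand(2, 32, 2, 3, requires_grad=True)    # [B, S, 2, 3] predictions
    gt_list = [torch.rand(6, 2, 3), torch.rand(4, 2, 3)]  # per-scene GT segments
    scales = torch.tensor([5.0, 4.0])                     # meters per normalized unit
    total, terms = compute_loss(pred, gt_list, scales, "cpu",
                                varifold_w=1.0, sinkhorn_w=0.1, vertex_f1_w=0.5)
    total.backward()
    print({k: float(v) for k, v in terms.items()})  # varifold / vertex_f1 / sinkhorn
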
s23dr_2026_example/make_sampled_cache.py
ADDED
@@ -0,0 +1,260 @@
#!/usr/bin/env python3
"""Convert full point cloud cache to pre-sampled 2048-point npz files.

Reads from either local .pt files or the HF dataset, priority-samples
2048 points, normalizes, and saves as compact npz files (~50KB each).

Usage:
    # From local cache:
    python make_sampled_cache.py --in-dir /workspace/cache/v2 --out-dir /workspace/cache/sampled

    # From HF dataset:
    python make_sampled_cache.py --hf-repo usm3d/s23dr-2026-cached_full_pcd --out-dir /workspace/cache/sampled

    # Specify split:
    python make_sampled_cache.py --hf-repo usm3d/s23dr-2026-cached_full_pcd --split validation --out-dir /workspace/cache/sampled_val

    # With edge classifications (from extract_edge_classes.py):
    python make_sampled_cache.py --hf-repo usm3d/s23dr-2026-cached_full_pcd --out-dir /workspace/cache/sampled \
        --edge-classes edge_classifications.npz

Note: uses a fixed seed so each scene gets one deterministic sample of 2048
points. This means no sampling augmentation across epochs -- every epoch sees
the same points. Fine for now; better augmentation can be added later.
"""
from __future__ import annotations

import sys
from pathlib import Path as _Path
if __package__ is None or __package__ == "":
    _here = _Path(__file__).resolve().parent
    if str(_here.parent) not in sys.path:
        sys.path.insert(0, str(_here.parent))
    __package__ = _here.name

import argparse
import io
import time
from pathlib import Path

import numpy as np
import torch


# Priority sampling (same logic as train.py)
def _priority_sample(source, group_id, seq_len, colmap_quota, depth_quota):
    def pick(src_id, quota):
        base = source == src_id
        picked, remaining = [], quota
        for tier in range(5):
            if remaining <= 0:
                break
            pool = np.where(base & (group_id == tier))[0]
            if len(pool) == 0:
                continue
            np.random.shuffle(pool)
            take = min(remaining, len(pool))
            picked.append(pool[:take])
            remaining -= take
        if remaining > 0:
            pool = np.where(base & (group_id >= 0))[0]
            if len(pool) > 0:
                np.random.shuffle(pool)
                picked.append(pool[:min(remaining, len(pool))])
                remaining -= min(remaining, len(pool))
        return np.concatenate(picked) if picked else np.array([], dtype=np.int64), remaining

    idx_c, rem_c = pick(0, colmap_quota)
    idx_d, rem_d = pick(1, depth_quota)

    if rem_c > 0:
        extra = np.setdiff1d(np.where((source == 1) & (group_id >= 0))[0], idx_d)
        np.random.shuffle(extra)
        idx_d = np.concatenate([idx_d, extra[:rem_c]])
    if rem_d > 0:
        extra = np.setdiff1d(np.where((source == 0) & (group_id >= 0))[0], idx_c)
        np.random.shuffle(extra)
        idx_c = np.concatenate([idx_c, extra[:rem_d]])

    indices = np.concatenate([idx_c, idx_d])
    num_valid = len(indices)
    if num_valid < seq_len:
        if num_valid == 0:
            return np.zeros(seq_len, dtype=np.int64), np.zeros(seq_len, dtype=bool)
        indices = np.concatenate([indices, np.full(seq_len - num_valid, indices[-1])])
    mask = np.zeros(seq_len, dtype=bool)
    mask[:num_valid] = True
    return indices[:seq_len], mask
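
# Toy run of the quota logic above (tiny quotas for readability): colmap points
# (source 0) fill their quota tier-by-tier, depth points (source 1) likewise,
# and group_id < 0 is never eligible.
def _demo_priority_sample():
    source = np.array([0, 0, 0, 1, 1, 1, 1, 1], dtype=np.uint8)
    group_id = np.array([0, 1, 2, 0, 0, 1, 2, -1], dtype=np.int8)
    idx, mask = _priority_sample(source, group_id, seq_len=6,
                                 colmap_quota=3, depth_quota=3)
    print(sorted(idx.tolist()), mask.all())  # [0, 1, 2, 3, 4, 5] True
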
def process_sample(xyz, source, group_id, class_id, vis_src, vis_id,
                   center, scale, gt_v, gt_e, behind=None,
                   n_views_voted=None, vote_frac=None,
                   gt_edge_classes=None,
                   seq_len=2048, colmap_q=1536, depth_q=512):
    """Sample and normalize one scene. Returns dict of numpy arrays."""
    indices, mask = _priority_sample(source, group_id, seq_len, colmap_q, depth_q)
    xyz_norm = ((xyz[indices] - center) / scale).astype(np.float32)
    gt_seg = np.stack([gt_v[gt_e[:, 0]], gt_v[gt_e[:, 1]]], axis=1)
    gt_seg_norm = ((gt_seg - center) / scale).astype(np.float32)

    result = {
        "xyz_norm": xyz_norm,
        "class_id": class_id[indices].astype(np.uint8),
        "source": source[indices].astype(np.uint8),
        "mask": mask,
        "gt_segments": gt_seg_norm,
        "scale": np.float32(scale),
        "center": center.astype(np.float32),
        "gt_vertices": gt_v.astype(np.float32),
        "gt_edges": gt_e.astype(np.int32),
        "visible_src": vis_src[indices].astype(np.uint8),
        "visible_id": vis_id[indices].astype(np.int16),
    }
    if behind is not None:
        result["behind"] = behind[indices].astype(np.int16)
    if n_views_voted is not None:
        result["n_views_voted"] = n_views_voted[indices].astype(np.uint8)
    if vote_frac is not None:
        result["vote_frac"] = vote_frac[indices].astype(np.float32)
    if gt_edge_classes is not None:
        if len(gt_edge_classes) != len(gt_e):
            raise ValueError(
                f"gt_edge_classes length {len(gt_edge_classes)} != "
                f"gt_edges length {len(gt_e)}")
        result["gt_edge_classes"] = gt_edge_classes.astype(np.int64)
    return result


def _load_edge_classes(path):
    """Load edge classifications lookup from npz file."""
    if path is None:
        return None
    path = Path(path)
    if not path.exists():
        raise FileNotFoundError(f"Edge classifications file not found: {path}")
    data = np.load(str(path), allow_pickle=False)
    lookup = {k: data[k] for k in data.files}
    print(f"Loaded edge classifications for {len(lookup)} orders from {path}")
    return lookup


def main():
    p = argparse.ArgumentParser()
    g = p.add_mutually_exclusive_group(required=True)
    g.add_argument("--in-dir", help="Local directory of .pt files")
    g.add_argument("--hf-repo", help="HuggingFace dataset repo (e.g. usm3d/s23dr-2026-cached_full_pcd)")
    p.add_argument("--split", default="train", help="HF dataset split")
    p.add_argument("--out-dir", required=True)
    p.add_argument("--edge-classes", default=None,
                   help="Path to edge_classifications.npz from extract_edge_classes.py")
    p.add_argument("--seq-len", type=int, default=2048)
    p.add_argument("--colmap-quota", type=int, default=1536)
    p.add_argument("--depth-quota", type=int, default=512)
    p.add_argument("--seed", type=int, default=7)
    args = p.parse_args()

    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    np.random.seed(args.seed)

    edge_cls_lookup = _load_edge_classes(args.edge_classes)
    n_edge_matched, n_edge_missing = 0, 0

    t_start = time.perf_counter()
    done = 0

    if args.in_dir:
        # Local .pt files
        files = sorted(Path(args.in_dir).glob("*.pt"))
        print(f"Converting {len(files)} local .pt files...")
        for f in files:
            out_f = out_dir / (f.stem + ".npz")
            if out_f.exists():
                done += 1
                continue
            d = torch.load(f, weights_only=False)
            behind = np.asarray(d["behind_gest_id"], np.int16) if "behind_gest_id" in d else None
            n_vv = np.asarray(d["n_views_voted"], np.uint8) if "n_views_voted" in d else None
            vf = np.asarray(d["vote_frac"], np.float32) if "vote_frac" in d else None
            gt_ec = None
            if edge_cls_lookup is not None:
                order_id = f.stem
                if order_id in edge_cls_lookup:
                    gt_ec = edge_cls_lookup[order_id]
                    n_edge_matched += 1
                else:
                    n_edge_missing += 1
            result = process_sample(
                np.asarray(d["xyz"], np.float32),
                np.asarray(d["source"], np.uint8),
                np.asarray(d["group_id"], np.int8),
                np.asarray(d["class_id"], np.uint8),
                np.asarray(d["visible_src"], np.uint8),
                np.asarray(d["visible_id"], np.int16),
                np.asarray(d["center"], np.float32),
                float(d["scale"]),
                np.asarray(d["gt_vertices"], np.float32),
                np.asarray(d["gt_edges"], np.int32),
                behind=behind, n_views_voted=n_vv, vote_frac=vf,
                gt_edge_classes=gt_ec,
                seq_len=args.seq_len, colmap_q=args.colmap_quota, depth_q=args.depth_quota,
            )
            np.savez(out_f, **result)
            done += 1
            if done % 2000 == 0:
                print(f"  {done}/{len(files)} [{done/(time.perf_counter()-t_start):.0f}/s]")
    else:
        # HF dataset
        from datasets import load_dataset
        print(f"Loading {args.hf_repo} split={args.split}...")
        ds = load_dataset(args.hf_repo, split=args.split)
        print(f"Converting {len(ds)} samples...")
        for i, sample in enumerate(ds):
            order_id = sample["order_id"]
            out_f = out_dir / f"{order_id}.npz"
            if out_f.exists():
                done += 1
                continue
            arrays = np.load(io.BytesIO(sample["data"]))
            behind = arrays["behind_gest_id"] if "behind_gest_id" in arrays else None
            n_vv = arrays["n_views_voted"] if "n_views_voted" in arrays else None
            vf = arrays["vote_frac"] if "vote_frac" in arrays else None
            gt_ec = None
            if edge_cls_lookup is not None:
                if order_id in edge_cls_lookup:
                    gt_ec = edge_cls_lookup[order_id]
                    n_edge_matched += 1
                else:
                    n_edge_missing += 1
            result = process_sample(
                arrays["xyz"], arrays["source"], arrays["group_id"],
                arrays["class_id"], arrays["visible_src"], arrays["visible_id"],
                arrays["center"], float(arrays["scale"]),
                arrays["gt_vertices"], arrays["gt_edges"],
                behind=behind, n_views_voted=n_vv, vote_frac=vf,
                gt_edge_classes=gt_ec,
                seq_len=args.seq_len, colmap_q=args.colmap_quota, depth_q=args.depth_quota,
            )
            np.savez(out_f, **result)
            done += 1
            if done % 2000 == 0:
                print(f"  {done}/{len(ds)} [{done/(time.perf_counter()-t_start):.0f}/s]")

    elapsed = time.perf_counter() - t_start
    print(f"Done: {done} files in {elapsed:.0f}s ({done/max(1,elapsed):.0f}/s)")

    if edge_cls_lookup is not None:
        print(f"Edge classifications: {n_edge_matched} matched, {n_edge_missing} missing")

    # Report sizes
    import os
    npz_files = list(out_dir.glob("*.npz"))
    if npz_files:
        sizes = [os.path.getsize(f) for f in npz_files[:100]]
        print(f"Avg file size: {np.mean(sizes)/1024:.0f}KB")
        print(f"Est total: {np.mean(sizes)*len(npz_files)/1e9:.1f}GB")


if __name__ == "__main__":
    main()
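Each output file is a flat dict of arrays. A quick sanity check on one converted scene (the path and order id are illustrative; optional keys such as behind, n_views_voted, vote_frac, and gt_edge_classes appear only when present in the source cache):

    import numpy as np

    arrays = np.load("/workspace/cache/sampled/<order_id>.npz")
    print(sorted(arrays.files))
    # ['center', 'class_id', 'gt_edges', 'gt_segments', 'gt_vertices', 'mask',
    #  'scale', 'source', 'visible_id', 'visible_src', 'xyz_norm'] (+ optionals)
    print(arrays["xyz_norm"].shape)     # (2048, 3) with the default --seq-len
    print(arrays["gt_segments"].shape)  # (n_gt_edges, 2, 3), same normalized frame
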
s23dr_2026_example/model.py
ADDED
@@ -0,0 +1,696 @@
"""
Perceiver-based transformer for 3D roof wireframe prediction.

Architecture overview:

    Input tokens [B, T, D]
            |
            v
    input_proj: Linear -> GELU -> Linear -> LayerNorm => [B, T, hidden]
            |
            v
    Perceiver latent bottleneck (N PerceiverLatentLayers):
        Learnable latent embeddings [L, hidden] are broadcast to batch.
        Each layer: cross-attn(latents <- tokens) -> self-attn(latents) -> FFN
        Output: latents [B, L, hidden]
            |
            v
    Segment decoder (M SegmentDecoderLayers):
        Learnable query embeddings [S, hidden] are broadcast to batch.
        Each layer: cross-attn(queries <- latents) -> self-attn(queries) -> FFN
        Output: queries [B, S, hidden]
            |
            v
    segment_head: Linear -> 6D -> (midpoint, half_vector)
        + query_offsets (learnable per-query bias)
        endpoints = midpoint +/- half_vector -> [B, S, 2, 3]
"""

import torch
import torch.nn as nn

from .attention import MultiHeadSDPA, FeedForward
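
# Worked example of the endpoint parametrization sketched above: a segment is
# stored as midpoint +/- half_vector.
#
#   mid  = torch.tensor([0.0, 0.0, 1.0])
#   half = torch.tensor([0.5, 0.0, 0.0])
#   p0, p1 = mid - half, mid + half  # (-0.5, 0, 1) and (0.5, 0, 1): a unit segment along x
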
# ---------------------------------------------------------------------------
# Building blocks
# ---------------------------------------------------------------------------

class AttnResidual(nn.Module):
    """Pre-norm attention + residual + dropout."""

    def __init__(
        self,
        d_model: int,
        num_heads: int,
        dropout: float = 0.0,
        kv_heads: int | None = None,
        norm_class=None,
        qk_norm: bool = False,
        qk_norm_type: str = "l2",
    ):
        super().__init__()
        norm_class = norm_class or nn.LayerNorm
        self.norm = norm_class(d_model)
        self.attn = MultiHeadSDPA(d_model, num_heads, kv_heads=kv_heads, qk_norm=qk_norm, qk_norm_type=qk_norm_type)
        self.drop = nn.Dropout(dropout)

    def forward(
        self,
        x: torch.Tensor,
        memory: torch.Tensor,
        memory_key_padding_mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        res = x
        x = self.norm(x)
        x = self.attn(x, memory, key_padding_mask=memory_key_padding_mask)
        return res + self.drop(x)


class FFNResidual(nn.Module):
    """Pre-norm feed-forward + residual + dropout."""

    def __init__(
        self,
        d_model: int,
        dim_ff: int,
        dropout: float = 0.0,
        activation: str = "gelu",
        norm_class=None,
    ):
        super().__init__()
        norm_class = norm_class or nn.LayerNorm
        self.norm = norm_class(d_model)
        self.ffn = FeedForward(d_model, dim_ff, activation=activation)
        self.drop = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        res = x
        x = self.norm(x)
        x = self.ffn(x)
        return res + self.drop(x)
# ---------------------------------------------------------------------------
# Perceiver encoder layer
# ---------------------------------------------------------------------------

class PerceiverLatentLayer(nn.Module):
    """Single Perceiver latent layer.

    If use_cross=True: cross-attn(latents <- points) -> self-attn -> FFN
    If use_cross=False: self-attn -> FFN (saves compute in deep stacks)
    """

    def __init__(
        self,
        d_model: int,
        num_heads: int,
        dim_ff: int,
        dropout: float = 0.0,
        activation: str = "gelu",
        kv_heads_cross: int | None = None,
        kv_heads_self: int | None = None,
        use_cross: bool = True,
        norm_class=None,
        qk_norm: bool = False,
        qk_norm_type: str = "l2",
    ):
        super().__init__()
        self.use_cross = use_cross
        if use_cross:
            self.cross = AttnResidual(d_model, num_heads, dropout, kv_heads=kv_heads_cross, norm_class=norm_class, qk_norm=qk_norm, qk_norm_type=qk_norm_type)
        self.self_attn = AttnResidual(d_model, num_heads, dropout, kv_heads=kv_heads_self, norm_class=norm_class, qk_norm=qk_norm, qk_norm_type=qk_norm_type)
        self.ffn = FFNResidual(d_model, dim_ff, dropout, activation=activation, norm_class=norm_class)

    def forward(
        self,
        latents: torch.Tensor,
        points: torch.Tensor,
        points_key_padding_mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        if self.use_cross:
            latents = self.cross(latents, points, memory_key_padding_mask=points_key_padding_mask)
        latents = self.self_attn(latents, latents)
        latents = self.ffn(latents)
        return latents
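
# Worked example of the cross-attention schedule used by TokenTransformerSegments
# below: with N = 6 latent layers and cross_attn_interval = 2, even layers
# cross-attend via (i % 2 == 0) and layer 5 via the last-layer rule:
#
#   [(i == 0) or (i == 5) or (i % 2 == 0) for i in range(6)]
#   # -> [True, False, True, False, True, True]
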
# ---------------------------------------------------------------------------
# Segment decoder layer
# ---------------------------------------------------------------------------

class SegmentDecoderLayer(nn.Module):
    """Single segment decoder layer.

    cross-attn(queries <- latents) -> [cross-attn(queries <- inputs)] -> self-attn(queries) -> FFN

    If input_xattn=True, adds a second cross-attention that attends directly
    to the projected input tokens (bypassing the latent bottleneck). This gives
    queries access to fine-grained point-level detail for vertex precision.
    """

    def __init__(
        self,
        d_model: int,
        num_heads: int,
        dim_ff: int,
        dropout: float = 0.0,
        activation: str = "gelu",
        kv_heads_cross: int | None = None,
        kv_heads_self: int | None = None,
        norm_class=None,
        input_xattn: bool = False,
        qk_norm: bool = False,
        qk_norm_type: str = "l2",
    ):
        super().__init__()
        self.cross = AttnResidual(d_model, num_heads, dropout, kv_heads=kv_heads_cross, norm_class=norm_class, qk_norm=qk_norm, qk_norm_type=qk_norm_type)
        self.input_xattn = input_xattn
        if input_xattn:
            self.cross_input = AttnResidual(d_model, num_heads, dropout, kv_heads=kv_heads_cross, norm_class=norm_class, qk_norm=qk_norm, qk_norm_type=qk_norm_type)
        self.self_attn = AttnResidual(d_model, num_heads, dropout, kv_heads=kv_heads_self, norm_class=norm_class, qk_norm=qk_norm, qk_norm_type=qk_norm_type)
        self.ffn = FFNResidual(d_model, dim_ff, dropout, activation=activation, norm_class=norm_class)

    def forward(
        self,
        queries: torch.Tensor,
        latents: torch.Tensor,
        src: torch.Tensor | None = None,
        src_key_padding_mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        queries = self.cross(queries, latents)
        if self.input_xattn and src is not None:
            queries = self.cross_input(queries, src, memory_key_padding_mask=src_key_padding_mask)
        queries = self.self_attn(queries, queries)
        queries = self.ffn(queries)
        return queries


# ---------------------------------------------------------------------------
# Full model
# ---------------------------------------------------------------------------

class TokenTransformerSegments(nn.Module):
    """Perceiver transformer that predicts 3D roof wireframe segments.

    Takes point-cloud tokens and outputs segment endpoints as [B, S, 2, 3]
    where S is the number of segments and each segment has two 3D endpoints.

    Args:
        segments: Number of predicted segments (S).
        in_dim: Dimensionality of input tokens.
        hidden: Internal hidden dimension throughout the model.
        num_heads: Number of attention heads.
        kv_heads_cross: Grouped-query heads for cross-attention (None = standard MHA).
        kv_heads_self: Grouped-query heads for self-attention (None = standard MHA).
        dim_feedforward: FFN intermediate dimension.
        dropout: Dropout rate applied after attention and FFN.
        latent_tokens: Number of learnable latent embeddings (L) in the bottleneck.
        latent_layers: Number of PerceiverLatentLayers (N).
        decoder_layers: Number of SegmentDecoderLayers (M).
    """

    def __init__(
        self,
        segments: int = 32,
        in_dim: int = 128,
        hidden: int = 128,
        num_heads: int = 4,
        kv_heads_cross: int | None = 2,
        kv_heads_self: int | None = 0,
        dim_feedforward: int = 256,
        dropout: float = 0.01,
        latent_tokens: int = 64,
        latent_layers: int = 2,
        decoder_layers: int = 2,
        cross_attn_interval: int = 1,
        norm_class=None,
        activation: str = "gelu",
        segment_conf: bool = False,
        pre_encoder_layers: int = 0,
        segment_param: str = "midpoint_halfvec",
        length_floor: float = 0.0,
        decoder_input_xattn: bool = False,
        qk_norm: bool = False,
        qk_norm_type: str = "l2",
    ):
        super().__init__()
        self.segments = segments
        self.out_vertices = segments * 2
        self.segment_param = segment_param
        self.length_floor = length_floor
        self.decoder_input_xattn = decoder_input_xattn
        norm_class = norm_class or nn.LayerNorm

        # Treat 0 as "use standard MHA"
        if kv_heads_cross is not None and kv_heads_cross <= 0:
            kv_heads_cross = None
        if kv_heads_self is not None and kv_heads_self <= 0:
            kv_heads_self = None

        # -- Input projection --
        self.input_proj = nn.Sequential(
            nn.Linear(in_dim, dim_feedforward),
            nn.GELU(),
            nn.Linear(dim_feedforward, hidden),
            norm_class(hidden),
        )

        # -- Optional pre-encoder: self-attention on full token sequence --
        if pre_encoder_layers > 0:
            self.pre_encoder = nn.ModuleList([
                SelfAttentionEncoderLayer(
                    d_model=hidden,
                    num_heads=num_heads,
                    dim_ff=dim_feedforward,
                    dropout=dropout,
                    activation=activation,
                    kv_heads=kv_heads_self,
                    norm_class=norm_class,
                    qk_norm=qk_norm, qk_norm_type=qk_norm_type,
                )
                for _ in range(pre_encoder_layers)
            ])
        else:
            self.pre_encoder = None

        # -- Perceiver latent bottleneck --
        self.latent_embed = nn.Embedding(latent_tokens, hidden)
        N = latent_layers
        self.latent_layers = nn.ModuleList([
            PerceiverLatentLayer(
                d_model=hidden,
                num_heads=num_heads,
                dim_ff=dim_feedforward,
                dropout=dropout,
                activation=activation,
                kv_heads_cross=kv_heads_cross,
                kv_heads_self=kv_heads_self,
                use_cross=(i == 0) or (i == N - 1) or (i % cross_attn_interval == 0),
                norm_class=norm_class,
                qk_norm=qk_norm, qk_norm_type=qk_norm_type,
            )
            for i in range(N)
        ])

        # -- Segment decoder --
        self.query_embed = nn.Embedding(segments, hidden)
        self.decoder_layers = nn.ModuleList([
            SegmentDecoderLayer(
                d_model=hidden,
                num_heads=num_heads,
                dim_ff=dim_feedforward,
                dropout=dropout,
                activation=activation,
                kv_heads_cross=kv_heads_cross,
                kv_heads_self=kv_heads_self,
                norm_class=norm_class,
                input_xattn=decoder_input_xattn,
                qk_norm=qk_norm, qk_norm_type=qk_norm_type,
            )
            for _ in range(decoder_layers)
        ])

        # -- Output head --
        if segment_param == "midpoint_dir_len":
            self.segment_head = nn.Linear(hidden, 7)  # mid(3) + dir(3) + len(1)
        else:
            self.segment_head = nn.Linear(hidden, 6)  # mid(3) + half(3)
        self.query_offsets = nn.Parameter(torch.zeros(segments, 2, 3))

        nn.init.trunc_normal_(self.segment_head.weight, mean=0.0, std=1e-3)
        if self.segment_head.bias is not None:
            nn.init.zeros_(self.segment_head.bias)
        if segment_param == "midpoint_dir_len":
            # softplus(0.5) * 0.1 ≈ 0.097 default length in normalized space
            self.segment_head.bias.data[6] = 0.5
        nn.init.normal_(self.query_offsets, mean=0.0, std=0.05)

        # -- Optional confidence head --
        self.segment_conf = segment_conf
        if segment_conf:
            self.conf_head = nn.Linear(hidden, 1)
            nn.init.zeros_(self.conf_head.bias)

    def forward(
        self,
        tokens: torch.Tensor,
        mask: torch.Tensor | None = None,
    ) -> dict[str, torch.Tensor | list]:
        """
        Args:
            tokens: Input point-cloud tokens [B, T, in_dim].
            mask: Boolean validity mask [B, T]. True = valid token.

        Returns:
            Dict with keys:
                "vertices": [B, S*2, 3] flattened endpoints.
                "segments": [B, S, 2, 3] segment endpoints.
                "edges": Per-batch list of (start, end) index pairs into vertices.
                "conf": [B, S] logits (only if segment_conf=True).
        """
        B = tokens.shape[0]

        # Project input tokens
        src = self.input_proj(tokens)  # [B, T, hidden]

        # Padding mask (True where padded) for cross-attention
        pad_mask = ~mask.bool() if mask is not None else None

        # Optional pre-encoder: self-attention on full token sequence
        if self.pre_encoder is not None:
            for layer in self.pre_encoder:
                src = layer(src, key_padding_mask=pad_mask)

        # Perceiver latent bottleneck
        latents = self.latent_embed.weight.unsqueeze(0).expand(B, -1, -1)
        for layer in self.latent_layers:
            latents = layer(latents, src, points_key_padding_mask=pad_mask)

        # Segment decoder
        queries = self.query_embed.weight.unsqueeze(0).expand(B, -1, -1)
        for layer in self.decoder_layers:
            queries = layer(queries, latents,
                            src=src if self.decoder_input_xattn else None,
                            src_key_padding_mask=pad_mask if self.decoder_input_xattn else None)

        # Predict segments -> endpoints
        if self.segment_param == "midpoint_dir_len":
            raw = self.segment_head(queries)  # [B, S, 7]
            mid = raw[:, :, :3] + self.query_offsets[:, 0, :].unsqueeze(0)
            direction = torch.nn.functional.normalize(raw[:, :, 3:6], dim=-1)
            length = torch.nn.functional.softplus(raw[:, :, 6:7]) * 0.1
            half = direction * length * 0.5
        else:
            raw = self.segment_head(queries).view(B, self.segments, 2, 3)
            raw = raw + self.query_offsets.unsqueeze(0)
            mid, half = raw[:, :, 0], raw[:, :, 1]
        seg_params = torch.stack([mid - half, mid + half], dim=2)

        vertices = seg_params.reshape(B, self.out_vertices, 3)
        edges = [[(2 * i, 2 * i + 1) for i in range(self.segments)] for _ in range(B)]

        out = {"vertices": vertices, "segments": seg_params, "edges": edges,
               "src": src, "pad_mask": pad_mask, "queries": queries}
        if self.segment_conf:
            out["conf"] = self.conf_head(queries).squeeze(-1)  # [B, S]
        return out
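
# Smoke-test sketch for the Perceiver model (random inputs, default config;
# assumes the MultiHeadSDPA / FeedForward primitives in .attention accept the
# arguments as wired above):
def _demo_token_transformer():
    model = TokenTransformerSegments(segments=32, in_dim=128, hidden=128)
    tokens = torch.randn(2, 256, 128)
    mask = torch.ones(2, 256, dtype=torch.bool)
    out = model(tokens, mask)
    print(out["segments"].shape)  # torch.Size([2, 32, 2, 3])
    print(out["vertices"].shape)  # torch.Size([2, 64, 3])
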
# ---------------------------------------------------------------------------
# Encoder-only layer (self-attention on full token sequence)
# ---------------------------------------------------------------------------

class SelfAttentionEncoderLayer(nn.Module):
    """Single self-attention layer: self-attn(tokens) -> FFN."""

    def __init__(
        self,
        d_model: int,
        num_heads: int,
        dim_ff: int,
        dropout: float = 0.0,
        activation: str = "gelu",
        kv_heads: int | None = None,
        norm_class=None,
        qk_norm: bool = False,
        qk_norm_type: str = "l2",
    ):
        super().__init__()
        self.self_attn = AttnResidual(d_model, num_heads, dropout, kv_heads=kv_heads, norm_class=norm_class, qk_norm=qk_norm, qk_norm_type=qk_norm_type)
        self.ffn = FFNResidual(d_model, dim_ff, dropout, activation=activation, norm_class=norm_class)

    def forward(self, x: torch.Tensor, key_padding_mask: torch.Tensor | None = None) -> torch.Tensor:
        x = self.self_attn(x, x, memory_key_padding_mask=key_padding_mask)
        x = self.ffn(x)
        return x


# ---------------------------------------------------------------------------
# Vanilla transformer: self-attention encoder + segment query decoder
# ---------------------------------------------------------------------------

class TransformerSegments(nn.Module):
    """Standard transformer encoder + cross-attention segment decoder.

    Architecture:
        Input tokens [B, T, D]
                |
                v
        input_proj: Linear -> GELU -> Linear -> Norm => [B, T, hidden]
                |
                v
        N SelfAttentionEncoderLayers (self-attn over all T tokens)
                |
                v
        Segment decoder (same as Perceiver version):
            M SegmentDecoderLayers (queries cross-attend to encoded tokens)
                |
                v
        segment_head -> endpoints [B, S, 2, 3] (midpoint_halfvec or midpoint_dir_len)
    """

    def __init__(
        self,
        segments: int = 32,
        in_dim: int = 128,
        hidden: int = 128,
        num_heads: int = 4,
        kv_heads_cross: int | None = 2,
        kv_heads_self: int | None = 0,
        dim_feedforward: int = 256,
        dropout: float = 0.01,
        encoder_layers: int = 4,
        decoder_layers: int = 2,
        norm_class=None,
        activation: str = "gelu",
        segment_conf: bool = False,
        segment_param: str = "midpoint_halfvec",
        length_floor: float = 0.0,
        decoder_input_xattn: bool = False,
        qk_norm: bool = False,
        qk_norm_type: str = "l2",
    ):
        super().__init__()
        self.segments = segments
        self.out_vertices = segments * 2
        self.segment_param = segment_param
        self.length_floor = length_floor
        norm_class = norm_class or nn.LayerNorm

        if kv_heads_cross is not None and kv_heads_cross <= 0:
            kv_heads_cross = None
        if kv_heads_self is not None and kv_heads_self <= 0:
            kv_heads_self = None

        # -- Input projection --
        self.input_proj = nn.Sequential(
            nn.Linear(in_dim, dim_feedforward),
            nn.GELU(),
            nn.Linear(dim_feedforward, hidden),
            norm_class(hidden),
        )

        # -- Self-attention encoder --
        self.encoder_layers = nn.ModuleList([
            SelfAttentionEncoderLayer(
                d_model=hidden,
                num_heads=num_heads,
                dim_ff=dim_feedforward,
                dropout=dropout,
                activation=activation,
                kv_heads=kv_heads_self,
                norm_class=norm_class,
                qk_norm=qk_norm, qk_norm_type=qk_norm_type,
            )
            for _ in range(encoder_layers)
        ])

        # -- Segment decoder (same structure as Perceiver version) --
        # Note: for transformer arch, decoder_input_xattn is ignored because
        # the decoder already cross-attends to the full encoded token sequence.
        self.query_embed = nn.Embedding(segments, hidden)
        self.decoder_layers = nn.ModuleList([
            SegmentDecoderLayer(
                d_model=hidden,
                num_heads=num_heads,
                dim_ff=dim_feedforward,
                dropout=dropout,
                activation=activation,
                kv_heads_cross=kv_heads_cross,
                kv_heads_self=kv_heads_self,
                norm_class=norm_class,
                qk_norm=qk_norm, qk_norm_type=qk_norm_type,
            )
            for _ in range(decoder_layers)
        ])

        # -- Output head (shared logic with Perceiver version) --
        if segment_param == "midpoint_dir_len":
            self.segment_head = nn.Linear(hidden, 7)  # mid(3) + dir(3) + len(1)
        else:
            self.segment_head = nn.Linear(hidden, 6)  # mid(3) + half(3)
        self.query_offsets = nn.Parameter(torch.zeros(segments, 2, 3))

        nn.init.trunc_normal_(self.segment_head.weight, mean=0.0, std=1e-3)
        if self.segment_head.bias is not None:
            nn.init.zeros_(self.segment_head.bias)
        if segment_param == "midpoint_dir_len":
            # sigmoid(-2.2) ~ 0.1 default length in normalized space (~3m)
            self.segment_head.bias.data[6] = -2.2
        nn.init.normal_(self.query_offsets, mean=0.0, std=0.05)

        self.segment_conf = segment_conf
        if segment_conf:
            self.conf_head = nn.Linear(hidden, 1)
            nn.init.zeros_(self.conf_head.bias)

    def forward(
        self,
        tokens: torch.Tensor,
        mask: torch.Tensor | None = None,
    ) -> dict[str, torch.Tensor | list]:
        B = tokens.shape[0]

        src = self.input_proj(tokens)
        pad_mask = ~mask.bool() if mask is not None else None
+
|
| 559 |
+
# Encode: self-attention over all tokens
|
| 560 |
+
for layer in self.encoder_layers:
|
| 561 |
+
src = layer(src, key_padding_mask=pad_mask)
|
| 562 |
+
|
| 563 |
+
# Decode: segment queries cross-attend to encoded tokens
|
| 564 |
+
queries = self.query_embed.weight.unsqueeze(0).expand(B, -1, -1)
|
| 565 |
+
for layer in self.decoder_layers:
|
| 566 |
+
queries = layer(queries, src)
|
| 567 |
+
|
| 568 |
+
# Predict segments -> endpoints
|
| 569 |
+
if self.segment_param == "midpoint_dir_len":
|
| 570 |
+
raw = self.segment_head(queries) # [B, S, 7]
|
| 571 |
+
mid = raw[:, :, :3] + self.query_offsets[:, 0, :].unsqueeze(0)
|
| 572 |
+
direction = torch.nn.functional.normalize(raw[:, :, 3:6], dim=-1)
|
| 573 |
+
length = torch.nn.functional.softplus(raw[:, :, 6:7]) * 0.1
|
| 574 |
+
half = direction * length * 0.5
|
| 575 |
+
else:
|
| 576 |
+
raw = self.segment_head(queries).view(B, self.segments, 2, 3)
|
| 577 |
+
raw = raw + self.query_offsets.unsqueeze(0)
|
| 578 |
+
mid, half = raw[:, :, 0], raw[:, :, 1]
|
| 579 |
+
seg_params = torch.stack([mid - half, mid + half], dim=2)
|
| 580 |
+
|
| 581 |
+
vertices = seg_params.reshape(B, self.out_vertices, 3)
|
| 582 |
+
edges = [[(2 * i, 2 * i + 1) for i in range(self.segments)] for _ in range(B)]
|
| 583 |
+
|
| 584 |
+
out = {"vertices": vertices, "segments": seg_params, "edges": edges}
|
| 585 |
+
if self.segment_conf:
|
| 586 |
+
out["conf"] = self.conf_head(queries).squeeze(-1)
|
| 587 |
+
return out
|
| 588 |
+
|
| 589 |
+
|
| 590 |
+
# ---------------------------------------------------------------------------
|
| 591 |
+
# End-to-end model: tokenizer embeddings + transformer/perceiver
|
| 592 |
+
# ---------------------------------------------------------------------------
|
| 593 |
+
|
| 594 |
+
class EdgeDepthSegmentsModel(nn.Module):
|
| 595 |
+
"""Tokenizer embeddings + transformer for 3D roof wireframes.
|
| 596 |
+
|
| 597 |
+
Supports two architectures via the `arch` parameter:
|
| 598 |
+
- "perceiver": Perceiver latent bottleneck (default, O(L*T) attention)
|
| 599 |
+
- "transformer": Standard self-attention encoder (O(T^2) attention)
|
| 600 |
+
|
| 601 |
+
Both share the same decoder, output head, and tokenizer.
|
| 602 |
+
"""
|
| 603 |
+
|
| 604 |
+
def __init__(
|
| 605 |
+
self,
|
| 606 |
+
seq_cfg,
|
| 607 |
+
segments: int = 32,
|
| 608 |
+
hidden: int = 128,
|
| 609 |
+
num_heads: int = 4,
|
| 610 |
+
kv_heads_cross: int | None = 2,
|
| 611 |
+
kv_heads_self: int | None = 0,
|
| 612 |
+
dim_feedforward: int = 256,
|
| 613 |
+
dropout: float = 0.1,
|
| 614 |
+
latent_tokens: int = 64,
|
| 615 |
+
latent_layers: int = 1,
|
| 616 |
+
decoder_layers: int = 2,
|
| 617 |
+
label_emb_dim: int = 16,
|
| 618 |
+
src_emb_dim: int = 2,
|
| 619 |
+
behind_emb_dim: int = 8,
|
| 620 |
+
fourier_seed: int = 0,
|
| 621 |
+
cross_attn_interval: int = 1,
|
| 622 |
+
norm_class=None,
|
| 623 |
+
activation: str = "gelu",
|
| 624 |
+
segment_conf: bool = False,
|
| 625 |
+
use_vote_features: bool = False,
|
| 626 |
+
arch: str = "perceiver",
|
| 627 |
+
encoder_layers: int = 4,
|
| 628 |
+
pre_encoder_layers: int = 0,
|
| 629 |
+
segment_param: str = "midpoint_halfvec",
|
| 630 |
+
length_floor: float = 0.0,
|
| 631 |
+
decoder_input_xattn: bool = False,
|
| 632 |
+
qk_norm: bool = False,
|
| 633 |
+
qk_norm_type: str = "l2",
|
| 634 |
+
learnable_fourier: bool = False,
|
| 635 |
+
):
|
| 636 |
+
super().__init__()
|
| 637 |
+
self.seq_cfg = seq_cfg
|
| 638 |
+
|
| 639 |
+
from .tokenizer import EdgeDepthSequenceBuilder
|
| 640 |
+
self.tokenizer = EdgeDepthSequenceBuilder(
|
| 641 |
+
seq_cfg,
|
| 642 |
+
label_emb_dim=label_emb_dim,
|
| 643 |
+
src_emb_dim=src_emb_dim,
|
| 644 |
+
behind_emb_dim=behind_emb_dim,
|
| 645 |
+
fourier_seed=fourier_seed,
|
| 646 |
+
use_vote_features=use_vote_features,
|
| 647 |
+
learnable_fourier=learnable_fourier,
|
| 648 |
+
)
|
| 649 |
+
|
| 650 |
+
if arch == "transformer":
|
| 651 |
+
self.segmenter = TransformerSegments(
|
| 652 |
+
segments=segments,
|
| 653 |
+
in_dim=self.tokenizer.out_dim,
|
| 654 |
+
hidden=hidden,
|
| 655 |
+
num_heads=num_heads,
|
| 656 |
+
kv_heads_cross=kv_heads_cross,
|
| 657 |
+
kv_heads_self=kv_heads_self,
|
| 658 |
+
dim_feedforward=dim_feedforward,
|
| 659 |
+
dropout=dropout,
|
| 660 |
+
encoder_layers=encoder_layers,
|
| 661 |
+
decoder_layers=decoder_layers,
|
| 662 |
+
norm_class=norm_class,
|
| 663 |
+
activation=activation,
|
| 664 |
+
segment_conf=segment_conf,
|
| 665 |
+
segment_param=segment_param,
|
| 666 |
+
length_floor=length_floor,
|
| 667 |
+
decoder_input_xattn=decoder_input_xattn,
|
| 668 |
+
qk_norm=qk_norm, qk_norm_type=qk_norm_type,
|
| 669 |
+
)
|
| 670 |
+
else:
|
| 671 |
+
self.segmenter = TokenTransformerSegments(
|
| 672 |
+
segments=segments,
|
| 673 |
+
in_dim=self.tokenizer.out_dim,
|
| 674 |
+
hidden=hidden,
|
| 675 |
+
num_heads=num_heads,
|
| 676 |
+
kv_heads_cross=kv_heads_cross,
|
| 677 |
+
kv_heads_self=kv_heads_self,
|
| 678 |
+
dim_feedforward=dim_feedforward,
|
| 679 |
+
dropout=dropout,
|
| 680 |
+
latent_tokens=latent_tokens,
|
| 681 |
+
latent_layers=latent_layers,
|
| 682 |
+
decoder_layers=decoder_layers,
|
| 683 |
+
cross_attn_interval=cross_attn_interval,
|
| 684 |
+
norm_class=norm_class,
|
| 685 |
+
activation=activation,
|
| 686 |
+
segment_conf=segment_conf,
|
| 687 |
+
pre_encoder_layers=pre_encoder_layers,
|
| 688 |
+
segment_param=segment_param,
|
| 689 |
+
length_floor=length_floor,
|
| 690 |
+
decoder_input_xattn=decoder_input_xattn,
|
| 691 |
+
qk_norm=qk_norm, qk_norm_type=qk_norm_type,
|
| 692 |
+
)
|
| 693 |
+
|
| 694 |
+
def forward_tokens(self, tokens: torch.Tensor, mask: torch.Tensor):
|
| 695 |
+
"""Run the segmenter on pre-built token tensors."""
|
| 696 |
+
return self.segmenter(tokens, mask)
|
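
# Minimal smoke-test sketch for TransformerSegments (illustrative shapes and
# hyperparameters; assumes the AttnResidual / FFNResidual / SegmentDecoderLayer
# blocks defined earlier in this file behave as their call sites suggest):
if __name__ == "__main__":
    tokens = torch.randn(2, 100, 64)   # [B, T, in_dim]
    mask = torch.ones(2, 100)          # 1 = valid token
    model = TransformerSegments(segments=8, in_dim=64, hidden=32,
                                num_heads=4, dim_feedforward=64,
                                encoder_layers=1, decoder_layers=1)
    out = model(tokens, mask)
    print(out["segments"].shape)       # torch.Size([2, 8, 2, 3])
    print(out["vertices"].shape)       # torch.Size([2, 16, 3])
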
s23dr_2026_example/point_fusion.py
ADDED
@@ -0,0 +1,554 @@
"""
point_fusion.py

Simplified semantic point fusion for the 2026 dataset format.

Takes per-view (ADE segmap, Gestalt segmap, depth) + sparse COLMAP point cloud
from the usm3d/hoho22k_2026_trainval dataset and builds a compact, house-centric
semantic point representation suitable for downstream wireframe prediction.

Key differences from the 2025 pipeline:
- COLMAP is a ZIP of text files (cameras.txt, images.txt, points3D.txt)
- Depth is a millimeter I;16 PNG (depth_scale=0.001 converts to meters)
- Views flagged with pose_only_in_colmap=True have zeroed K/R/t and must be
  skipped for depth unprojection and projection
- Images arrive as PIL Images, not byte arrays
"""

from __future__ import annotations

import zipfile
from dataclasses import dataclass
from io import BytesIO
from typing import Tuple

import cv2
import numpy as np

from .color_mappings import ade20k_color_mapping, gestalt_color_mapping

# ---------------------------------------------------------------------------
# Color packing helpers
# ---------------------------------------------------------------------------

def _pack_rgb_u32(rgb: np.ndarray) -> np.ndarray:
    """Pack uint8 RGB (..., 3) into uint32 codes."""
    rgb = rgb.astype(np.uint32, copy=False)
    return (rgb[..., 0] << 16) | (rgb[..., 1] << 8) | rgb[..., 2]


def _build_rgbcode_maps(color_mapping):
    """Return (rgbcode_to_id, id_to_name) for a color mapping dict."""
    names = list(color_mapping.keys())
    rgbs = np.array([color_mapping[n] for n in names], dtype=np.uint8)
    codes = _pack_rgb_u32(rgbs.reshape(-1, 1, 3)).reshape(-1)
    rgbcode_to_id = {int(c): i for i, c in enumerate(codes)}
    return rgbcode_to_id, names


def _name_to_packed_rgb(name, mapping):
    """Case-insensitive lookup returning a packed RGB code, or None."""
    for key in mapping:
        if key.lower() == name.lower():
            rgb = np.array(mapping[key], np.uint8).reshape(1, 1, 3)
            return int(_pack_rgb_u32(rgb).reshape(()))
    return None

# ---------------------------------------------------------------------------
# Label mapping constants
# ---------------------------------------------------------------------------

ADE_RGBCODE_TO_ID, ADE_ID_TO_NAME = _build_rgbcode_maps(ade20k_color_mapping)
GEST_RGBCODE_TO_ID, GEST_ID_TO_NAME = _build_rgbcode_maps(gestalt_color_mapping)
NUM_ADE = len(ADE_ID_TO_NAME)
NUM_GEST = len(GEST_ID_TO_NAME)

GEST_INVALID_NAMES = ("unclassified", "unknown", "transition_line")
GEST_INVALID_CODES = set(
    int(_pack_rgb_u32(np.array(gestalt_color_mapping[n], np.uint8).reshape(1, 1, 3)).reshape(()))
    for n in GEST_INVALID_NAMES if n in gestalt_color_mapping
)

# ADE classes whose surfaces are "see-through" for label fusion: when a point
# projects onto one of these, we use the Gestalt label behind it instead.
ADE_TRANSPARENT_NAMES = (
    "wall", "building;edifice", "floor;flooring", "ceiling",
    "windowpane;window", "door;double;door", "house", "skyscraper",
    "screen;door;screen", "blind;screen", "hovel;hut;hutch;shack;shanty",
    "tower", "booth;cubicle;stall;kiosk",
)

# ADE classes kept as "occluders/add-ons" when overlapping the house silhouette.
ADE_OCCLUDER_ALLOWLIST_NAMES = (
    "tree", "person;individual;someone;somebody;mortal;soul",
    "car;auto;automobile;machine;motorcar", "truck;motortruck", "van",
    "fence;fencing", "railing;rail",
    "bannister;banister;balustrade;balusters;handrail",
    "stairs;steps", "stairway;staircase", "step;stair", "pole",
    "streetlight;street;lamp", "signboard;sign", "awning;sunshade;sunblind",
    "plant;flora;plant;life", "pot;flowerpot",
)

# Precomputed arrays for the default name lists (avoids re-lookup every call).
_DEFAULT_ADE_TRANSPARENT_CODES = np.array(
    [c for n in ADE_TRANSPARENT_NAMES
     if (c := _name_to_packed_rgb(n, ade20k_color_mapping)) is not None],
    dtype=np.uint32,
)
_DEFAULT_ADE_OCCLUDER_IDS = np.array(
    sorted({ADE_RGBCODE_TO_ID[c]
            for n in ADE_OCCLUDER_ALLOWLIST_NAMES
            if (c := _name_to_packed_rgb(n, ade20k_color_mapping)) is not None
            and c in ADE_RGBCODE_TO_ID}),
    dtype=np.int32,
)

# ---------------------------------------------------------------------------
# Config
# ---------------------------------------------------------------------------

@dataclass(frozen=True)
class FuserConfig:
    """Simplified fusion configuration (no depth calibration fields)."""
    depth_points_per_view: int = 20_000    # depth samples per view
    depth_scale: float = 0.001             # mm -> meters
    depth_clip_percentile: float = 99.5    # drop extreme outliers
    house_mask_dilate_px: int = 5          # dilate gestalt mask
    min_support_views: int = 1             # min views for a kept point
    ade_transparent_classes: Tuple[str, ...] = ADE_TRANSPARENT_NAMES
    ade_occluder_allowlist: Tuple[str, ...] = ADE_OCCLUDER_ALLOWLIST_NAMES

# ---------------------------------------------------------------------------
# Geometry: projection + depth unprojection
# ---------------------------------------------------------------------------

def project_world_points(points_world, K, R, t):
    """Project (N, 3) world points to pixel (u, v) with a validity mask."""
    pts = points_world.astype(np.float32, copy=False)
    cam = (R @ pts.T + t).T  # (N, 3)
    z = cam[:, 2]
    valid = z > 1e-6
    inv_z = np.zeros_like(z)
    inv_z[valid] = 1.0 / z[valid]
    x = cam[:, 0] * inv_z
    y = cam[:, 1] * inv_z
    u = K[0, 0] * x + K[0, 2]
    v = K[1, 1] * y + K[1, 2]
    return u, v, valid


def unproject_depth_to_world(depth, K, R, t, num_points, sample_mask=None, rng=None):
    """Convert a depth map + camera params to (M, 3) world points, M <= num_points."""
    if rng is None:
        rng = np.random.default_rng()
    d = np.asarray(depth, dtype=np.float32)
    if d.ndim != 2:
        return np.zeros((0, 3), dtype=np.float32)

    valid = np.isfinite(d) & (d > 1e-6)
    if sample_mask is not None:
        mask = np.asarray(sample_mask, dtype=bool)
        if mask.shape != d.shape:
            return np.zeros((0, 3), dtype=np.float32)
        valid &= mask

    ys, xs = np.where(valid)
    if ys.size == 0:
        return np.zeros((0, 3), dtype=np.float32)

    idx = rng.choice(ys.size, size=min(num_points, ys.size), replace=False)
    y = ys[idx].astype(np.float32)
    x = xs[idx].astype(np.float32)
    z = d[ys[idx], xs[idx]].astype(np.float32)

    fx, fy, cx, cy = K[0, 0], K[1, 1], K[0, 2], K[1, 2]
    cam_pts = np.stack([(x - cx) * z / fx, (y - cy) * z / fy, z], axis=0)
    # cam = R * world + t  =>  world = R^T * (cam - t)
    world = (R.T @ (cam_pts - t)).T
    return world.astype(np.float32, copy=False)


def clean_depth(depth, clip_percentile):
    """Clip extreme depth values."""
    d = np.asarray(depth, dtype=np.float32)
    d = np.where(np.isfinite(d), d, 0.0)
    d[d <= 0] = 0.0
    if clip_percentile is not None and clip_percentile > 0 and np.any(d > 0):
        hi = float(np.percentile(d[d > 0], clip_percentile))
        d = np.clip(d, 0.0, hi)
    return d


def dilate_mask(mask, radius_px):
    """Binary dilation via cv2. mask: (H, W) bool."""
    if radius_px <= 0:
        return mask
    k = 2 * radius_px + 1
    kernel = np.ones((k, k), np.uint8)
    return cv2.dilate(mask.astype(np.uint8), kernel) > 0

# ---------------------------------------------------------------------------
# COLMAP extraction (2026 format)
# ---------------------------------------------------------------------------

def extract_colmap_points_2026(sample):
    """Extract (N, 3) float32 COLMAP world points from a 2026-format sample.

    sample['colmap'] must be a ZIP archive containing points3D.txt.
    Fails fast if that file is missing (it is always present in the 2026 format).
    """
    colmap_blob = sample.get("colmap")
    if colmap_blob is None:
        return np.zeros((0, 3), dtype=np.float32)
    if not isinstance(colmap_blob, (bytes, bytearray, memoryview)):
        return np.zeros((0, 3), dtype=np.float32)

    try:
        with zipfile.ZipFile(BytesIO(colmap_blob)) as zf:
            if "points3D.txt" not in set(zf.namelist()):
                raise FileNotFoundError(
                    "COLMAP ZIP is missing points3D.txt -- "
                    "this is required in the 2026 dataset format")
            with zf.open("points3D.txt") as f:
                text = f.read().decode("utf-8", errors="ignore")
            # Format: POINT3D_ID X Y Z R G B ERROR TRACK[]
            # Filter comment/blank lines, parse columns 1-3 (X, Y, Z)
            from io import StringIO
            clean = "\n".join(l for l in text.split("\n") if l and not l.startswith("#"))
            if not clean:
                return np.zeros((0, 3), dtype=np.float32)
            return np.loadtxt(StringIO(clean), dtype=np.float32, usecols=(1, 2, 3))
    except zipfile.BadZipFile:
        pass
    return np.zeros((0, 3), dtype=np.float32)

# ---------------------------------------------------------------------------
# Label helpers
# ---------------------------------------------------------------------------

def _codes_from_image(img):
    """Convert a PIL Image or numpy array to a (H, W) uint32 packed-RGB map."""
    arr = np.asarray(img)
    if arr.ndim == 2:
        arr = np.stack([arr, arr, arr], axis=-1)
    arr = arr[..., :3]
    if arr.dtype != np.uint8:
        arr = np.clip(arr, 0, 255).astype(np.uint8)
    return _pack_rgb_u32(arr)


def _row_majority(values):
    """Row-wise majority vote on a (P, V) int array; -1 means "no vote".

    Returns (P,) with the most frequent non-negative value per row, or -1.

    Masks -1 entries before voting so that abstentions don't outvote
    actual labels (which happens when a point is visible in only 1-2 views).
    """
    P, V = values.shape
    result = np.full(P, -1, dtype=values.dtype)

    # Fast path: take the first valid vote per row. For rows with a single
    # voting view this is already the answer.
    for vi in range(V):
        col = values[:, vi]
        unset = result == -1
        has_val = col >= 0
        update = unset & has_val
        result[update] = col[update]

    # Refine: rows with multiple valid votes get a proper mode. Values are
    # small non-negative ints, so np.bincount is cheap per row.
    has_any = np.any(values >= 0, axis=1)
    n_valid = np.sum(values >= 0, axis=1)
    needs_vote = has_any & (n_valid > 1)

    if np.any(needs_vote):
        for i in np.where(needs_vote)[0]:
            valid = values[i][values[i] >= 0]
            counts = np.bincount(valid.astype(np.intp))
            result[i] = counts.argmax()

    return result

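# Illustrative vote example (three points, three views; -1 = abstain):
#
#   _row_majority(np.array([[ 2,  2,  5],
#                           [-1,  7, -1],
#                           [-1, -1, -1]]))
#   # -> array([ 2,  7, -1])
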

# ---------------------------------------------------------------------------
# Semantic fusion: house-centric, occluder-aware
# ---------------------------------------------------------------------------

def _fuse_labels_for_points(
    points_world, Ks, Rs, ts, ade_images, gestalt_images,
    ade_transparent_codes, ade_occluder_allowed_ids,
    min_support_views, valid_view_mask=None,
):
    """Multi-view semantic label fusion with majority voting.

    For each 3D point, project into every valid view:
      - ADE "envelope" class -> use the Gestalt label behind it.
      - ADE non-envelope     -> keep if on the occluder allowlist.
    Then majority-vote across views.

    Returns dict: keep, visible_src, visible_id, behind_gest_id, support,
    n_views_voted, vote_frac.
    """
    P = points_world.shape[0]
    V = min(len(Ks), len(Rs), len(ts), len(ade_images), len(gestalt_images))
    empty = {
        "keep": np.zeros(P, dtype=bool),
        "visible_src": np.zeros(P, np.uint8),
        "visible_id": np.full(P, -1, np.int16),
        "behind_gest_id": np.full(P, -1, np.int16),
        "support": np.zeros(P, np.uint8),
    }
    if P == 0 or V == 0:
        return empty

    # Per-view labels. src: 1=gestalt, 2=ade; -1 = no contribution.
    visible_src_pv = np.full((P, V), -1, dtype=np.int8)
    visible_id_pv = np.full((P, V), -1, dtype=np.int32)
    behind_id_pv = np.full((P, V), -1, dtype=np.int32)
    support = np.zeros(P, dtype=np.int32)

    ade_allowed_set = set(ade_occluder_allowed_ids.tolist())
    ade_transparent_u32 = ade_transparent_codes.astype(np.uint32, copy=False)
    gest_invalid_arr = np.array(list(GEST_INVALID_CODES), dtype=np.uint32)

    for vi in range(V):
        if valid_view_mask is not None and not valid_view_mask[vi]:
            continue

        K = np.asarray(Ks[vi], np.float32)
        R = np.asarray(Rs[vi], np.float32)
        t = np.asarray(ts[vi], np.float32).reshape(3, 1)

        ade_codes_img = _codes_from_image(ade_images[vi])
        gest_codes_img = _codes_from_image(gestalt_images[vi])
        H, W = ade_codes_img.shape

        u, v, valid = project_world_points(points_world, K, R, t)
        in_img = valid & (u >= 0) & (u < W) & (v >= 0) & (v < H)
        if not np.any(in_img):
            continue

        ui = np.clip(np.round(u[in_img]).astype(np.int32), 0, W - 1)
        vi_pix = np.clip(np.round(v[in_img]).astype(np.int32), 0, H - 1)
        ade_codes = ade_codes_img[vi_pix, ui]
        gest_codes = gest_codes_img[vi_pix, ui]

        in_house = ~np.isin(gest_codes, gest_invalid_arr)
        if not np.any(in_house):
            continue

        idx = np.where(in_img)[0][in_house]
        ade_codes_h = ade_codes[in_house]
        gest_codes_h = gest_codes[in_house]

        behind_local = np.array(
            [GEST_RGBCODE_TO_ID.get(int(c), -1) for c in gest_codes_h],
            dtype=np.int32)
        behind_id_pv[idx, vi] = behind_local

        ade_is_transparent = np.isin(ade_codes_h, ade_transparent_u32)

        # Case A: ADE is envelope -- use the Gestalt label.
        mask_a = ade_is_transparent & (behind_local >= 0)
        if np.any(mask_a):
            visible_src_pv[idx[mask_a], vi] = 1
            visible_id_pv[idx[mask_a], vi] = behind_local[mask_a]

        # Case B: ADE is non-envelope -- use the ADE label (allowlist-filtered).
        mask_b = ~ade_is_transparent
        if np.any(mask_b):
            ade_local = np.array(
                [ADE_RGBCODE_TO_ID.get(int(c), -1) for c in ade_codes_h[mask_b]],
                dtype=np.int32)
            on_allowlist = np.array(
                [int(a) in ade_allowed_set for a in ade_local], dtype=bool
            ) & (ade_local >= 0)
            if np.any(on_allowlist):
                visible_src_pv[idx[mask_b][on_allowlist], vi] = 2
                visible_id_pv[idx[mask_b][on_allowlist], vi] = ade_local[on_allowlist]

        support[idx] += 1

    # ---- Aggregate across views via majority vote ----
    keep = (support >= min_support_views) & np.any(visible_src_pv >= 0, axis=1)

    # Combine (src, id) into a single key for voting, then split back.
    # src in {1, 2} and id in [0, ~150], so stride=100k avoids collisions.
    VIS_STRIDE = 100_000
    vis_key = np.where(
        visible_src_pv >= 0,
        visible_src_pv.astype(np.int64) * VIS_STRIDE + visible_id_pv.astype(np.int64),
        -1)
    voted_key = _row_majority(vis_key)
    voted_behind = _row_majority(behind_id_pv)

    final_src = np.zeros(P, dtype=np.uint8)
    final_id = np.full(P, -1, dtype=np.int16)
    ok = voted_key >= 0
    if np.any(ok):
        final_src[ok] = (voted_key[ok] // VIS_STRIDE).astype(np.uint8)
        final_id[ok] = (voted_key[ok] % VIS_STRIDE).astype(np.int16)

    # ---- Vote confidence metadata ----
    n_views_voted = np.sum(visible_src_pv >= 0, axis=1).astype(np.uint8)

    # Fraction of voting views that agreed with the majority label
    vote_frac = np.zeros(P, dtype=np.float32)
    if np.any(ok):
        for i in np.where(ok)[0]:
            votes = vis_key[i][vis_key[i] >= 0]
            if len(votes) > 0:
                vote_frac[i] = (votes == voted_key[i]).sum() / len(votes)

    return {
        "keep": keep,
        "visible_src": final_src,
        "visible_id": final_id,
        "behind_gest_id": voted_behind.astype(np.int16),
        "support": support.astype(np.uint8),
        "n_views_voted": n_views_voted,
        "vote_frac": vote_frac,
    }

# ---------------------------------------------------------------------------
# Compact scene builder (2026 dataset format)
# ---------------------------------------------------------------------------

def _resolve_ade_codes(cfg):
    """Return (transparent_codes, occluder_ids) for the given config.

    Uses precomputed module-level arrays when the config has default names.
    """
    if cfg.ade_transparent_classes == ADE_TRANSPARENT_NAMES:
        transparent = _DEFAULT_ADE_TRANSPARENT_CODES
    else:
        transparent = np.array(
            [c for n in cfg.ade_transparent_classes
             if (c := _name_to_packed_rgb(n, ade20k_color_mapping)) is not None],
            dtype=np.uint32)

    if cfg.ade_occluder_allowlist == ADE_OCCLUDER_ALLOWLIST_NAMES:
        occluder_ids = _DEFAULT_ADE_OCCLUDER_IDS
    else:
        occluder_ids = np.array(
            sorted({ADE_RGBCODE_TO_ID[c]
                    for n in cfg.ade_occluder_allowlist
                    if (c := _name_to_packed_rgb(n, ade20k_color_mapping)) is not None
                    and c in ADE_RGBCODE_TO_ID}),
            dtype=np.int32)
    return transparent, occluder_ids


def _parse_gt_array(sample, key, dtype, expected_cols):
    """Parse an optional ground-truth array from the sample dict."""
    raw = sample.get(key)
    if raw is None:
        return None
    arr = np.asarray(raw, dtype=dtype)
    if arr.ndim == 2 and arr.shape[1] == expected_cols:
        return arr
    return None


def build_compact_scene(sample, cfg, rng):
    """Build a compact semantic point representation from a HuggingFace sample.

    Expected sample keys: K, R, t, ade, gestalt, depth, colmap,
    pose_only_in_colmap, wf_vertices (opt), wf_edges (opt), __key__ (opt).

    Returns dict (xyz, source, visible_src, visible_id, behind_gest_id,
    n_views_voted, vote_frac, gt_vertices, gt_edges, sample_id) or None
    if no points survive fusion.
    """
    Ks = sample.get("K") or []
    Rs = sample.get("R") or []
    ts = sample.get("t") or []
    ade_imgs = sample.get("ade") or []
    gest_imgs = sample.get("gestalt") or []
    depths = sample.get("depth") or []
    pose_flags = sample.get("pose_only_in_colmap") or []

    V = min(len(Ks), len(Rs), len(ts), len(ade_imgs), len(gest_imgs))
    if V == 0:
        return None

    valid_view = [not (vi < len(pose_flags) and pose_flags[vi]) for vi in range(V)]
    if not any(valid_view):
        return None

    # ---- COLMAP points ----
    colmap_pts = extract_colmap_points_2026(sample)

    # ---- Precompute house masks (from Gestalt), optionally dilated ----
    gest_invalid_arr = np.array(list(GEST_INVALID_CODES), dtype=np.uint32)
    house_masks = []
    for vi in range(V):
        if not valid_view[vi]:
            house_masks.append(None)
            continue
        mask = ~np.isin(_codes_from_image(gest_imgs[vi]), gest_invalid_arr)
        if cfg.house_mask_dilate_px > 0:
            mask = dilate_mask(mask, cfg.house_mask_dilate_px)
        house_masks.append(mask)

    # ---- Sample depth points per view ----
    depth_points_all = []
    for vi in range(min(V, len(depths))):
        if not valid_view[vi] or depths[vi] is None:
            continue
        d = clean_depth(
            np.asarray(depths[vi], dtype=np.float32) * cfg.depth_scale,
            cfg.depth_clip_percentile)
        pts = unproject_depth_to_world(
            depth=d,
            K=np.asarray(Ks[vi], np.float32),
            R=np.asarray(Rs[vi], np.float32),
            t=np.asarray(ts[vi], np.float32).reshape(3, 1),
            num_points=cfg.depth_points_per_view,
            sample_mask=house_masks[vi], rng=rng)
        if pts.shape[0]:
            depth_points_all.append(pts)

    # ---- Combine COLMAP + depth points ----
    pts_list, src_list = [], []
    if colmap_pts.shape[0]:
        pts_list.append(colmap_pts)
        src_list.append(np.zeros(colmap_pts.shape[0], dtype=np.uint8))  # 0=colmap
    if depth_points_all:
        all_depth = np.concatenate(depth_points_all, axis=0)
        pts_list.append(all_depth)
        src_list.append(np.ones(all_depth.shape[0], dtype=np.uint8))  # 1=depth
    if not pts_list:
        return None

    points_world = np.concatenate(pts_list, axis=0).astype(np.float32, copy=False)
    point_source = np.concatenate(src_list, axis=0).astype(np.uint8, copy=False)

    # ---- Fuse semantic labels ----
    ade_transparent_arr, ade_allow_ids = _resolve_ade_codes(cfg)
    fused = _fuse_labels_for_points(
        points_world=points_world, Ks=Ks, Rs=Rs, ts=ts,
        ade_images=ade_imgs, gestalt_images=gest_imgs,
        ade_transparent_codes=ade_transparent_arr,
        ade_occluder_allowed_ids=ade_allow_ids,
        min_support_views=cfg.min_support_views,
        valid_view_mask=valid_view)

    keep = fused["keep"]
    if not np.any(keep):
        return None

    return {
        "xyz": points_world[keep],
        "source": point_source[keep],               # 0=colmap, 1=monodepth
        "visible_src": fused["visible_src"][keep],  # 1=gestalt, 2=ade
        "visible_id": fused["visible_id"][keep],
        "behind_gest_id": fused["behind_gest_id"][keep],
        "n_views_voted": fused["n_views_voted"][keep],
        "vote_frac": fused["vote_frac"][keep],
        "gt_vertices": _parse_gt_array(sample, "wf_vertices", np.float32, 3),
        "gt_edges": _parse_gt_array(sample, "wf_edges", np.int64, 2),
        "sample_id": sample.get("__key__", None),
    }
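
# Round-trip sanity sketch for the projection helpers (synthetic identity
# camera and illustrative intrinsics; not part of the dataset pipeline):
if __name__ == "__main__":
    rng = np.random.default_rng(0)
    K = np.array([[500.0, 0.0, 320.0],
                  [0.0, 500.0, 240.0],
                  [0.0, 0.0, 1.0]], dtype=np.float32)
    R = np.eye(3, dtype=np.float32)
    t = np.zeros((3, 1), dtype=np.float32)
    depth = rng.uniform(2.0, 10.0, size=(480, 640)).astype(np.float32)
    pts = unproject_depth_to_world(depth, K, R, t, num_points=1000, rng=rng)
    u, v, valid = project_world_points(pts, K, R, t)
    # Unprojected pixels must re-project inside the image with positive depth.
    assert valid.all() and (u >= 0).all() and (u <= 639).all()
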
s23dr_2026_example/postprocess_v2.py
ADDED
@@ -0,0 +1,39 @@
"""Post-processing functions for segment predictions."""
import numpy as np


def snap_to_point_cloud(vertices, xyz, class_id, snap_radius=0.5,
                        target_classes=None):
    """Snap vertices to nearby point-cloud clusters of specific semantic classes."""
    if target_classes is None:
        target_classes = [1, 2]  # apex, eave_end_point

    snapped = vertices.copy()
    mask = np.isin(class_id, target_classes)

    if mask.sum() < 2:
        return snapped

    target_pts = xyz[mask]

    for i, v in enumerate(vertices):
        dists = np.linalg.norm(target_pts - v, axis=-1)
        close = dists < snap_radius
        if close.sum() >= 2:
            snapped[i] = target_pts[close].mean(axis=0)

    return snapped


def snap_horizontal(vertices, edges, max_slope=0.05):
    """Snap near-horizontal edges to be exactly horizontal."""
    verts = vertices.copy()
    for a, b in edges:
        a, b = int(a), int(b)
        dy = abs(verts[a, 1] - verts[b, 1])
        dxz = np.sqrt((verts[a, 0] - verts[b, 0]) ** 2 + (verts[a, 2] - verts[b, 2]) ** 2)
        if dxz > 0.1 and dy / dxz < max_slope:
            avg_y = 0.5 * (verts[a, 1] + verts[b, 1])
            verts[a, 1] = avg_y
            verts[b, 1] = avg_y
    return verts
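
# Minimal usage sketch for snap_horizontal (illustrative coordinates; assumes
# axis 1 is the vertical axis, as in the function above):
if __name__ == "__main__":
    verts = np.array([[0.0, 1.00, 0.0],
                      [4.0, 1.08, 0.0]], dtype=np.float32)  # slope 0.02 < 0.05
    edges = [(0, 1)]
    out = snap_horizontal(verts, edges, max_slope=0.05)
    print(out[:, 1])  # both y values snapped to 1.04
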
s23dr_2026_example/segment_postprocess.py
ADDED
@@ -0,0 +1,60 @@
from __future__ import annotations

import numpy as np


def merge_vertices(vertices: np.ndarray, edges: np.ndarray, thresh: float):
    """Merge vertices closer than `thresh` (union-find) and remap/dedupe edges."""
    verts = np.asarray(vertices, dtype=np.float32)
    edges = np.asarray(edges, dtype=np.int64)
    if verts.size == 0 or edges.size == 0:
        return verts, edges

    n = verts.shape[0]
    parent = np.arange(n, dtype=np.int64)

    def find(i):
        # Union-find lookup with path halving.
        while parent[i] != i:
            parent[i] = parent[parent[i]]
            i = parent[i]
        return i

    def union(i, j):
        ri = find(i)
        rj = find(j)
        if ri != rj:
            parent[rj] = ri

    # Union every vertex pair closer than thresh (O(n^2) pairwise check).
    for i in range(n):
        vi = verts[i]
        for j in range(i + 1, n):
            if np.linalg.norm(vi - verts[j]) <= thresh:
                union(i, j)

    clusters = {}
    for i in range(n):
        root = find(i)
        clusters.setdefault(root, []).append(i)

    # Each cluster collapses to its centroid.
    new_vertices = []
    mapping = {}
    for new_idx, idxs in enumerate(clusters.values()):
        pts = verts[idxs]
        center = pts.mean(axis=0)
        new_vertices.append(center)
        for i in idxs:
            mapping[i] = new_idx

    # Remap edges, dropping self-loops and duplicates.
    new_edges = []
    seen = set()
    for a, b in edges:
        na = mapping.get(int(a), int(a))
        nb = mapping.get(int(b), int(b))
        if na == nb:
            continue
        key = (na, nb) if na <= nb else (nb, na)
        if key in seen:
            continue
        seen.add(key)
        new_edges.append([na, nb])

    return np.asarray(new_vertices, dtype=np.float32), np.asarray(new_edges, dtype=np.int64)
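
# Minimal usage sketch (illustrative values): two near-duplicate endpoints
# collapse into one vertex and the two edges dedupe into a single edge.
if __name__ == "__main__":
    verts = np.array([[0.0, 0.0, 0.0],
                      [0.01, 0.0, 0.0],
                      [1.0, 0.0, 0.0]], dtype=np.float32)
    edges = np.array([[0, 2], [1, 2]], dtype=np.int64)
    v2, e2 = merge_vertices(verts, edges, thresh=0.05)
    print(v2.shape, e2.shape)  # (2, 3) (1, 2)
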
s23dr_2026_example/sinkhorn.py
ADDED
@@ -0,0 +1,181 @@
"""Sinkhorn optimal transport loss for segment matching.

Note: at eps=0.05, sinkhorn gradients are near-zero (~1e-7 norm) for
typical matrix sizes. The loss value is tracked but does not meaningfully
train the model. Default sinkhorn_weight=0.0. See worklog.md for details.

Future: schedule eps from large (1.0) to small (0.05) during training
to get useful gradients early and precise matching late.
"""
import torch


def segment_pair_cost(pred_segments: torch.Tensor, gt_segments: torch.Tensor) -> torch.Tensor:
    """Cost between pred and GT segments: midpoint + direction + length (decoupled).

    pred_segments: [N, 2, 3], gt_segments: [M, 2, 3] -> [N, M]
    """
    p0, p1 = pred_segments[:, 0], pred_segments[:, 1]
    g0, g1 = gt_segments[:, 0], gt_segments[:, 1]
    mid_p, half_p = 0.5 * (p0 + p1), 0.5 * (p1 - p0)
    mid_g, half_g = 0.5 * (g0 + g1), 0.5 * (g1 - g0)
    d_mid = torch.cdist(mid_p, mid_g)
    len_p = torch.linalg.norm(half_p, dim=-1, keepdim=True).clamp(min=1e-6)
    len_g = torch.linalg.norm(half_g, dim=-1, keepdim=True).clamp(min=1e-6)
    dir_p = half_p / len_p
    dir_g = half_g / len_g
    cos_angle = (dir_p[:, None, :] * dir_g[None, :, :]).sum(dim=-1)
    d_dir = 1.0 - cos_angle.abs()
    d_len = (len_p[:, None, :] - len_g[None, :, :]).squeeze(-1).abs()
    return d_mid + d_dir + d_len


def batched_sinkhorn_loss(
    pred_segments: torch.Tensor,
    gt_pad: torch.Tensor,
    gt_mask: torch.Tensor,
    eps: float,
    iters: int,
    dustbin_cost: float | torch.Tensor,
    pred_mass: torch.Tensor | None = None,
) -> torch.Tensor:
    """Batched sinkhorn segment matching loss.

    Args:
        pred_segments: [B, S, 2, 3] predicted segments
        gt_pad: [B, M, 2, 3] padded GT segments
        gt_mask: [B, M] bool mask (True = valid GT segment)
        eps: sinkhorn regularization
        iters: sinkhorn iterations
        dustbin_cost: cost for unmatched segments (scalar or [B])
        pred_mass: [B, S] per-segment mass weights (e.g. sigmoid(conf)).
            If None, uniform masses are used.

    Returns:
        [B] per-sample sinkhorn transport cost
    """
    B, S, _, _ = pred_segments.shape
    M = gt_pad.shape[1]

    # Allow per-sample dustbin cost
    dc = torch.as_tensor(dustbin_cost, device=pred_segments.device, dtype=pred_segments.dtype)
    if dc.dim() == 0:
        dc = dc.expand(B)

    # Compute cost matrices [B, S, M] in midpoint-halfvec space.
    # Decouples position from direction: the mid gradient is pure position,
    # the half gradient is pure direction/length. Sign-invariance on half
    # handles segment direction ambiguity cleanly.
    p0 = pred_segments[:, :, 0]  # [B, S, 3]
    p1 = pred_segments[:, :, 1]  # [B, S, 3]
    g0 = gt_pad[:, :, 0]         # [B, M, 3]
    g1 = gt_pad[:, :, 1]         # [B, M, 3]

    mid_pred = 0.5 * (p0 + p1)   # [B, S, 3]
    half_pred = 0.5 * (p1 - p0)  # [B, S, 3]
    mid_gt = 0.5 * (g0 + g1)     # [B, M, 3]
    half_gt = 0.5 * (g1 - g0)    # [B, M, 3]

    # Midpoint distance [B, S, M]
    d_mid = torch.linalg.norm(
        mid_pred.unsqueeze(2) - mid_gt.unsqueeze(1), dim=-1)

    # Decoupled direction + length distance (sign-invariant for direction ambiguity)
    len_pred = torch.linalg.norm(half_pred, dim=-1, keepdim=True).clamp(min=1e-6)  # [B, S, 1]
    len_gt = torch.linalg.norm(half_gt, dim=-1, keepdim=True).clamp(min=1e-6)      # [B, M, 1]
    dir_pred = half_pred / len_pred  # [B, S, 3]
    dir_gt = half_gt / len_gt        # [B, M, 3]

    # Direction distance: 1 - |cos(angle)|, sign-invariant [B, S, M]
    cos_angle = (dir_pred.unsqueeze(2) * dir_gt.unsqueeze(1)).sum(dim=-1)  # [B, S, M]
    d_dir = 1.0 - cos_angle.abs()

    # Length distance [B, S, M]
    d_len = (len_pred.unsqueeze(2) - len_gt.unsqueeze(1)).squeeze(-1).abs()

    cost = d_mid + d_dir + d_len  # [B, S, M]

    # Mask invalid GT segments with a high cost so they go to the dustbin
    cost = torch.where(gt_mask.unsqueeze(1), cost, dc[:, None, None] * 10.0)

    # Pad with a dustbin row and column: [B, S+1, M+1]
    cost_pad = dc[:, None, None].expand(B, S + 1, M + 1).clone()
    cost_pad[:, :S, :M] = cost
    cost_pad[:, -1, -1] = 0.0

    # Masses
    gt_counts = gt_mask.sum(dim=1).float()  # [B]

    if pred_mass is not None:
        # Confidence-weighted masses (matches the learned_v2 approach).
        # sigmoid(conf) gives per-segment mass; dustbin masses balance the totals.
        # No normalization -- sum(a) == sum(b) == max(sum_pred, sum_gt).
        pm = pred_mass.clamp(min=0.0)                      # [B, S]
        sum_pred = pm.sum(dim=1)                           # [B]
        sum_gt = gt_counts                                 # [B]
        pred_dustbin = (sum_gt - sum_pred).clamp(min=0.0)  # [B]
        gt_dustbin = (sum_pred - sum_gt).clamp(min=0.0)    # [B]
        a = torch.cat([pm, pred_dustbin.unsqueeze(1)], dim=1)  # [B, S+1]
        b_val = torch.zeros(B, M + 1, device=cost.device, dtype=cost.dtype)
        b_val[:, :M] = gt_mask.float()  # 1.0 per valid GT segment
        b_val[:, -1] = gt_dustbin
    else:
        # Uniform masses (normalized)
        n = float(S)
        denom = n + gt_counts                                        # [B]
        a = (1.0 / denom).unsqueeze(1).expand(B, S + 1).clone()      # [B, S+1]
        a[:, -1] = gt_counts / denom
        b_val = (1.0 / denom).unsqueeze(1).expand(B, M + 1).clone()  # [B, M+1]
        b_val[:, -1] = n / denom
        # Zero out mass for invalid GT
        b_val[:, :M] = b_val[:, :M] * gt_mask.float()

    # Log-domain sinkhorn
    log_a = torch.log(a + 1e-9)
    log_b = torch.log(b_val + 1e-9)
    log_k = -cost_pad / eps

    log_u = torch.zeros_like(a)
    log_v = torch.zeros_like(b_val)

    for _ in range(iters):
        log_u = log_a - torch.logsumexp(log_k + log_v.unsqueeze(1), dim=2)
        log_v = log_b - torch.logsumexp(log_k + log_u.unsqueeze(2), dim=1)

    transport = torch.exp(log_u.unsqueeze(2) + log_v.unsqueeze(1) + log_k)
    return (transport * cost_pad).sum(dim=(1, 2))  # [B]


# Keep the per-sample version for compatibility
def sinkhorn_segment_loss(
    pred_segments: torch.Tensor,
    gt_segments: torch.Tensor,
    eps: float,
    iters: int,
    dustbin_cost: float,
    pred_mass: torch.Tensor | None = None,
) -> torch.Tensor:
    # pred_mass is accepted for API parity with batched_sinkhorn_loss but is
    # not used by this uniform-mass variant.
    if pred_segments.numel() == 0 or gt_segments.numel() == 0:
        return pred_segments.new_tensor(dustbin_cost)
    cost = segment_pair_cost(pred_segments, gt_segments)
    n, m = cost.shape
    if n == 0 or m == 0:
        return cost.new_tensor(dustbin_cost)
    cost_pad = torch.full((n + 1, m + 1), dustbin_cost, device=cost.device, dtype=cost.dtype)
    cost_pad[:n, :m] = cost
    cost_pad[-1, -1] = 0.0
    denom = float(n + m)
    a = torch.full((n + 1,), 1.0 / denom, device=cost.device, dtype=cost.dtype)
    b = torch.full((m + 1,), 1.0 / denom, device=cost.device, dtype=cost.dtype)
    a[-1] = m / denom
    b[-1] = n / denom
    log_a = torch.log(a + 1e-9)
    log_b = torch.log(b + 1e-9)
    log_k = -cost_pad / eps
    log_u = torch.zeros_like(a)
    log_v = torch.zeros_like(b)
    for _ in range(iters):
        log_u = log_a - torch.logsumexp(log_k + log_v[None, :], dim=1)
        log_v = log_b - torch.logsumexp(log_k + log_u[:, None], dim=0)
    transport = torch.exp(log_u[:, None] + log_v[None, :] + log_k)
    return torch.sum(transport * cost_pad)
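
# Minimal usage sketch for batched_sinkhorn_loss (illustrative shapes; eps=1.0
# per the scheduling note in the module docstring so gradients stay non-trivial):
if __name__ == "__main__":
    torch.manual_seed(0)
    pred = torch.randn(2, 8, 2, 3, requires_grad=True)  # [B, S, 2, 3]
    gt = torch.randn(2, 5, 2, 3)                        # [B, M, 2, 3]
    mask = torch.ones(2, 5, dtype=torch.bool)           # all GT valid
    loss = batched_sinkhorn_loss(pred, gt, mask, eps=1.0, iters=50,
                                 dustbin_cost=2.0)
    loss.mean().backward()
    print(loss.detach(), pred.grad.abs().mean())
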
s23dr_2026_example/soft_hss_loss.py
ADDED
@@ -0,0 +1,507 @@
import torch


def _softmin(values: torch.Tensor, dim: int, tau: float) -> torch.Tensor:
    tau_t = torch.as_tensor(tau, device=values.device, dtype=values.dtype).clamp_min(1e-8)
    return -tau_t * torch.logsumexp(-values / tau_t, dim=dim)


def point_segment_distance_squared(
    points: torch.Tensor,
    seg_a: torch.Tensor,
    seg_b: torch.Tensor,
    eps: float = 1e-9,
) -> torch.Tensor:
    """
    points: (P,3)
    seg_a/seg_b: (S,3)
    returns dist2: (P,S)
    """
    ab = seg_b - seg_a  # (S,3)
    ab2 = (ab * ab).sum(dim=-1).clamp_min(eps)  # (S,)
    ap = points[:, None, :] - seg_a[None, :, :]  # (P,S,3)
    t = (ap * ab[None, :, :]).sum(dim=-1) / ab2[None, :]  # (P,S)
    t = t.clamp(0.0, 1.0)
    closest = seg_a[None, :, :] + t[:, :, None] * ab[None, :, :]
    diff = points[:, None, :] - closest
    return (diff * diff).sum(dim=-1)


def distance_to_segments(
    points: torch.Tensor,
    segments: torch.Tensor,
    eps: float = 1e-9,
) -> torch.Tensor:
    """
    points: (P,3)
    segments: (S,2,3)
    returns min distance: (P,)
    """
    a = segments[:, 0]
    b = segments[:, 1]
    dist2 = point_segment_distance_squared(points, a, b, eps=eps)
    return torch.sqrt(dist2.min(dim=1).values + eps)
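
# Sanity check (a minimal sketch with made-up points, not part of training):
# for the unit segment from the origin to (1, 0, 0), a point 1 above the
# midpoint is at distance 1, and a point past an endpoint is clamped to it.
#
#   segs = torch.tensor([[[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]]])  # (1,2,3)
#   pts = torch.tensor([[0.5, 1.0, 0.0], [2.0, 0.0, 0.0]])     # (2,3)
#   distance_to_segments(pts, segs)  # -> approx tensor([1.0, 1.0])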


def soft_vertex_f1(
    pred_vertices: torch.Tensor,
    gt_vertices: torch.Tensor,
    thresh: float,
    tau: float = 0.05,
    softmin_tau: float = 0.05,
    eps: float = 1e-8,
) -> torch.Tensor:
    """
    Soft surrogate for the Hungarian-thresholded corner F1 used by HSS.

    Uses (soft) nearest-neighbor distances and a sigmoid threshold.
    """
    if pred_vertices.numel() == 0 or gt_vertices.numel() == 0:
        return torch.zeros((), device=pred_vertices.device, dtype=pred_vertices.dtype)

    pred = pred_vertices
    gt = gt_vertices

    diff = pred[:, None, :] - gt[None, :, :]
    dist = torch.sqrt((diff * diff).sum(dim=-1) + eps)  # (P,G)

    d_pred = _softmin(dist, dim=1, tau=softmin_tau)  # (P,)
    d_gt = _softmin(dist, dim=0, tau=softmin_tau)  # (G,)

    tau_t = torch.as_tensor(tau, device=dist.device, dtype=dist.dtype).clamp_min(1e-8)
    thresh_t = torch.as_tensor(thresh, device=dist.device, dtype=dist.dtype)
    p_match = torch.sigmoid((thresh_t - d_pred) / tau_t).mean()
    r_match = torch.sigmoid((thresh_t - d_gt) / tau_t).mean()
    return 2.0 * p_match * r_match / (p_match + r_match + eps)
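
# Known limitation, shown on a toy case (hypothetical tensors): because the
# softmin distances are independent per vertex, duplicates are not penalized.
#
#   pred = torch.zeros(5, 3)   # five predictions stacked on one corner
#   gt = torch.zeros(1, 3)
#   soft_vertex_f1(pred, gt, thresh=0.5)  # -> ~1.0 despite 4 duplicates
#
# The Sinkhorn variant further below addresses exactly this.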


def soft_tube_iou_mc(
    pred_segments: torch.Tensor,
    gt_segments: torch.Tensor,
    radius: float,
    n_samples: int = 4096,
    tau: float = 0.05,
    seed: int = 0,
    eps: float = 1e-8,
) -> torch.Tensor:
    """
    Soft surrogate for volumetric tube IoU (edge_thresh in HSS).

    Samples points uniformly in a padded bbox around {pred,gt} endpoints.
    Occupancy is sigmoid((radius - d(x, segments))/tau).
    IoU is approximated by mean(min(occ_p, occ_g)) / mean(max(occ_p, occ_g)).
    """
    if pred_segments.numel() == 0 or gt_segments.numel() == 0:
        return torch.zeros((), device=pred_segments.device, dtype=pred_segments.dtype)

    pts_all = torch.cat([pred_segments.reshape(-1, 3), gt_segments.reshape(-1, 3)], dim=0)
    pad = torch.as_tensor(radius, device=pts_all.device, dtype=pts_all.dtype)
    lo = pts_all.min(dim=0).values - pad
    hi = pts_all.max(dim=0).values + pad

    gen = torch.Generator(device=pts_all.device)
    gen.manual_seed(int(seed))
    u = torch.rand((int(n_samples), 3), generator=gen, device=pts_all.device, dtype=pts_all.dtype)
    x = lo[None, :] + u * (hi - lo)[None, :]

    d_p = distance_to_segments(x, pred_segments, eps=eps)
    d_g = distance_to_segments(x, gt_segments, eps=eps)

    tau_t = torch.as_tensor(tau, device=pts_all.device, dtype=pts_all.dtype).clamp_min(1e-8)
    rad_t = torch.as_tensor(radius, device=pts_all.device, dtype=pts_all.dtype)
    occ_p = torch.sigmoid((rad_t - d_p) / tau_t)
    occ_g = torch.sigmoid((rad_t - d_g) / tau_t)

    inter = torch.minimum(occ_p, occ_g).mean()
    union = torch.maximum(occ_p, occ_g).mean().clamp_min(eps)
    return inter / union
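
# Note: as tau -> 0 the sigmoid occupancy approaches a hard indicator of the
# radius-`radius` tube around each wireframe, so inter/union approaches a
# Monte Carlo estimate of the true volumetric tube IoU over the padded bbox.
# The fixed `seed` keeps the estimate (and hence the loss) deterministic.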


def soft_hss(
    pred_segments: torch.Tensor,
    gt_segments: torch.Tensor,
    gt_vertices: torch.Tensor,
    vert_thresh: float = 0.5,
    edge_thresh: float = 0.5,
    tau: float = 0.05,
    softmin_tau: float = 0.05,
    n_samples: int = 4096,
    seed: int = 0,
    eps: float = 1e-8,
):
    """
    Returns (soft_hss, soft_f1, soft_iou), all scalars in [0,1] (approximately).
    """
    pred_vertices = pred_segments.reshape(-1, 3)
    f1 = soft_vertex_f1(pred_vertices, gt_vertices, thresh=vert_thresh, tau=tau, softmin_tau=softmin_tau, eps=eps)
    iou = soft_tube_iou_mc(
        pred_segments,
        gt_segments,
        radius=edge_thresh,
        n_samples=n_samples,
        tau=tau,
        seed=seed,
        eps=eps,
    )
    denom = (f1 + iou).clamp_min(eps)
    hss = 2.0 * f1 * iou / denom
    return hss, f1, iou
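
# End-to-end sketch (made-up shapes; real callers pass model output and the
# annotated wireframe):
#
#   pred = torch.rand(8, 2, 3, requires_grad=True)  # 8 predicted segments
#   gt = torch.rand(5, 2, 3)                        # 5 GT segments
#   hss, f1, iou = soft_hss(pred, gt, gt.reshape(-1, 3))
#   (1.0 - hss).backward()                          # ascend on soft HSS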


# ---------------------------------------------------------------------------
# Improved: Sinkhorn-matched vertex F1
# ---------------------------------------------------------------------------
#
# The original soft_vertex_f1 uses independent softmin nearest-neighbor
# distances, which allows multiple predicted vertices to claim the same GT
# vertex. This inflates precision and fails to penalize duplicate vertices --
# the exact failure mode that requires merge_vertices post-processing.
#
# This version uses Sinkhorn optimal transport to find a soft one-to-one
# assignment between predicted and GT vertices, then computes precision and
# recall from the matched distances. This is a better surrogate for the
# Hungarian matching used by the real HSS metric.


def sinkhorn_vertex_f1(
    pred_vertices: torch.Tensor,
    gt_vertices: torch.Tensor,
    thresh: float = 0.5,
    tau: float = 0.05,
    eps_sinkhorn: float = 0.05,
    iters: int = 20,
    eps: float = 1e-8,
) -> torch.Tensor:
    """Soft vertex F1 using Sinkhorn matching (better aligned with real HSS).

    Instead of independent nearest-neighbor distances (which allow double-
    claiming), this uses optimal transport to find a soft one-to-one assignment
    between predicted and GT vertices.

    Returns a differentiable scalar in [0, 1].
    """
    if pred_vertices.numel() == 0 or gt_vertices.numel() == 0:
        return torch.zeros((), device=pred_vertices.device, dtype=pred_vertices.dtype)

    P = pred_vertices.shape[0]
    G = gt_vertices.shape[0]

    # Pairwise distance matrix (P, G)
    dist = torch.cdist(pred_vertices, gt_vertices)

    # Sinkhorn with dustbin: (P+1) x (G+1)
    # Dustbin cost = thresh (unmatched vertices are "at threshold distance")
    dustbin = thresh
    cost_pad = torch.full((P + 1, G + 1), dustbin, device=dist.device, dtype=dist.dtype)
    cost_pad[:P, :G] = dist
    cost_pad[-1, -1] = 0.0

    # Uniform masses with dustbin slack
    denom = float(P + G)
    a = torch.full((P + 1,), 1.0 / denom, device=dist.device, dtype=dist.dtype)
    b = torch.full((G + 1,), 1.0 / denom, device=dist.device, dtype=dist.dtype)
    a[-1] = G / denom  # pred dustbin absorbs unmatched GT
    b[-1] = P / denom  # GT dustbin absorbs unmatched pred

    # Log-domain Sinkhorn
    log_a = torch.log(a + 1e-9)
    log_b = torch.log(b + 1e-9)
    log_k = -cost_pad / max(eps_sinkhorn, 1e-6)
    log_u = torch.zeros_like(a)
    log_v = torch.zeros_like(b)
    for _ in range(iters):
        log_u = log_a - torch.logsumexp(log_k + log_v[None, :], dim=1)
        log_v = log_b - torch.logsumexp(log_k + log_u[:, None], dim=0)

    # Transport plan (P+1, G+1)
    transport = torch.exp(log_u[:, None] + log_v[None, :] + log_k)

    # Extract the non-dustbin transport (P, G) -- these are the soft assignments
    T = transport[:P, :G]

    # For each predicted vertex, its matched distance is the transport-weighted
    # average distance to GT vertices
    # Normalize rows to sum to 1 (how much of this pred is matched vs dustbin)
    row_sums = T.sum(dim=1).clamp_min(eps)
    matched_dist_pred = (T * dist).sum(dim=1) / row_sums  # (P,)
    match_weight_pred = row_sums * denom  # how much of this pred is matched (0-1 ish)

    # Same for GT vertices (column perspective)
    col_sums = T.sum(dim=0).clamp_min(eps)
    matched_dist_gt = (T * dist).sum(dim=0) / col_sums  # (G,)
    match_weight_gt = col_sums * denom

    # Soft precision: fraction of pred vertices that are matched AND within threshold
    tau_t = torch.as_tensor(tau, device=dist.device, dtype=dist.dtype).clamp_min(1e-8)
    thresh_t = torch.as_tensor(thresh, device=dist.device, dtype=dist.dtype)

    prec_per = match_weight_pred * torch.sigmoid((thresh_t - matched_dist_pred) / tau_t)
    precision = prec_per.mean()

    # Soft recall: fraction of GT vertices that are matched AND within threshold
    rec_per = match_weight_gt * torch.sigmoid((thresh_t - matched_dist_gt) / tau_t)
    recall = rec_per.mean()

    return 2.0 * precision * recall / (precision + recall + eps)
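
# Mass bookkeeping behind `match_weight ~ [0, 1]`: each real vertex carries
# mass 1/(P+G) while the dustbins carry G/(P+G) and P/(P+G), so both marginals
# sum to 1 and the transport problem is balanced. A pred vertex whose full
# 1/(P+G) mass lands on real GT columns gets row_sums * (P+G) = 1; one routed
# entirely to the dustbin gets ~0, and is discounted in the precision term.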


# ---------------------------------------------------------------------------
# Improved: Segment-sampled tube IoU
# ---------------------------------------------------------------------------
#
# The original soft_tube_iou_mc samples random points in the bounding box,
# wasting most samples in empty space. This version samples along the segments
# themselves, concentrating gradient signal where it matters.


def _sample_along_segments(segments: torch.Tensor, n_per_seg: int = 64) -> torch.Tensor:
    """Sample n_per_seg points uniformly along each segment.

    segments: (S, 2, 3)
    returns: (S * n_per_seg, 3)
    """
    t = torch.linspace(0, 1, n_per_seg, device=segments.device, dtype=segments.dtype)
    # (S, 1, 3) + (1, N, 1) * (S, 1, 3) -> (S, N, 3)
    a = segments[:, 0:1, :]
    b = segments[:, 1:2, :]
    pts = a + t[None, :, None] * (b - a)
    return pts.reshape(-1, 3)


def segment_sampled_tube_iou(
    pred_segments: torch.Tensor,
    gt_segments: torch.Tensor,
    radius: float = 0.5,
    n_per_seg: int = 64,
    tau: float = 0.05,
    eps: float = 1e-8,
) -> torch.Tensor:
    """Soft tube IoU by sampling along segments instead of in the bounding box.

    Samples points along predicted and GT segments, then checks what fraction
    of each set falls within radius of the other. More sample-efficient than
    bbox Monte Carlo and gives better gradients.

    Returns a differentiable scalar in [0, 1].
    """
    if pred_segments.numel() == 0 or gt_segments.numel() == 0:
        return torch.zeros((), device=pred_segments.device, dtype=pred_segments.dtype)

    pred_pts = _sample_along_segments(pred_segments, n_per_seg)
    gt_pts = _sample_along_segments(gt_segments, n_per_seg)

    tau_t = torch.as_tensor(tau, device=pred_pts.device, dtype=pred_pts.dtype).clamp_min(1e-8)
    rad_t = torch.as_tensor(radius, device=pred_pts.device, dtype=pred_pts.dtype)

    # Precision: fraction of pred points within radius of any GT segment
    d_pred = distance_to_segments(pred_pts, gt_segments, eps=eps)
    prec = torch.sigmoid((rad_t - d_pred) / tau_t).mean()

    # Recall: fraction of GT points within radius of any pred segment
    d_gt = distance_to_segments(gt_pts, pred_segments, eps=eps)
    rec = torch.sigmoid((rad_t - d_gt) / tau_t).mean()

    # Soft IoU from precision and recall:
    # IoU = intersection/union = (P*R) / (P + R - P*R) for occupancy overlap
    return prec * rec / (prec + rec - prec * rec + eps)
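
# Why that identity holds: with I = |A ∩ B|, precision p = I/|A| and recall
# r = I/|B| give |A| = I/p and |B| = I/r, so
#   IoU = I / (|A| + |B| - I) = 1 / (1/p + 1/r - 1) = p*r / (p + r - p*r).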


def soft_hss_v2(
    pred_segments: torch.Tensor,
    gt_segments: torch.Tensor,
    gt_vertices: torch.Tensor,
    vert_thresh: float = 0.5,
    edge_thresh: float = 0.5,
    tau: float = 0.05,
    sinkhorn_eps: float = 0.05,
    sinkhorn_iters: int = 20,
    n_per_seg: int = 64,
    eps: float = 1e-8,
):
    """Improved soft HSS using Sinkhorn vertex matching + segment-sampled IoU.

    Returns (soft_hss, soft_f1, soft_iou).
    """
    pred_vertices = pred_segments.reshape(-1, 3)
    f1 = sinkhorn_vertex_f1(
        pred_vertices, gt_vertices,
        thresh=vert_thresh, tau=tau,
        eps_sinkhorn=sinkhorn_eps, iters=sinkhorn_iters, eps=eps,
    )
    iou = segment_sampled_tube_iou(
        pred_segments, gt_segments,
        radius=edge_thresh, n_per_seg=n_per_seg, tau=tau, eps=eps,
    )
    denom = (f1 + iou).clamp_min(eps)
    hss = 2.0 * f1 * iou / denom
    return hss, f1, iou
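
# Drop-in usage (sketch): the positional arguments match soft_hss, so a
# training script can switch surrogates behind a single flag:
#
#   fn = soft_hss_v2 if use_sinkhorn else soft_hss   # `use_sinkhorn`: your flag
#   hss, f1, iou = fn(pred, gt_segments, gt_vertices)
#   loss = 1.0 - hss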



# ---------------------------------------------------------------------------
# Batched versions for training speed
# ---------------------------------------------------------------------------


def batched_sinkhorn_vertex_f1(
    pred_segments: torch.Tensor,
    gt_pad: torch.Tensor,
    gt_mask: torch.Tensor,
    thresh: float | torch.Tensor = 0.5,
    tau: float | torch.Tensor = 0.05,
    eps_sinkhorn: float = 0.05,
    iters: int = 10,
    eps: float = 1e-8,
) -> torch.Tensor:
    """Batched Sinkhorn vertex F1 loss.

    Args:
        pred_segments: [B, S, 2, 3] predicted segments
        gt_pad: [B, M, 2, 3] padded GT segments
        gt_mask: [B, M] bool mask (True = valid GT segment)
        thresh: distance threshold for a vertex match (scalar or [B])
        tau: sigmoid temperature (scalar or [B])
    Returns:
        [B] per-sample (1 - F1) loss
    """
    B, S = pred_segments.shape[:2]
    M = gt_pad.shape[1]
    P = S * 2  # pred vertices (both endpoints)

    # Allow per-sample thresh and tau ([B] tensors or scalars)
    thresh_t = torch.as_tensor(thresh, device=pred_segments.device, dtype=pred_segments.dtype)
    if thresh_t.dim() == 0:
        thresh_t = thresh_t.expand(B)
    tau_t = torch.as_tensor(tau, device=pred_segments.device, dtype=pred_segments.dtype)
    if tau_t.dim() == 0:
        tau_t = tau_t.expand(B)
    tau_t = tau_t.clamp_min(1e-8)

    pred_verts = pred_segments.reshape(B, P, 3)
    gt_verts = gt_pad.reshape(B, M * 2, 3)  # will mask invalid ones

    # Build GT vertex mask: each valid segment contributes 2 vertices
    gt_vert_mask = gt_mask.unsqueeze(2).expand(B, M, 2).reshape(B, M * 2)
    G = M * 2

    # Pairwise distances [B, P, G]
    dist = torch.linalg.norm(
        pred_verts.unsqueeze(2) - gt_verts.unsqueeze(1), dim=-1)

    # Mask invalid GT with high distance
    dist = torch.where(gt_vert_mask.unsqueeze(1), dist, thresh_t[:, None, None] * 10.0)

    # Sinkhorn matching: [B, P+1, G+1]
    cost_pad = thresh_t[:, None, None].expand(B, P + 1, G + 1).clone()
    cost_pad[:, :P, :G] = dist
    cost_pad[:, -1, -1] = 0.0

    gt_counts = gt_vert_mask.sum(dim=1).float()  # [B]
    n = float(P)
    denom = n + gt_counts  # [B]

    a = (1.0 / denom).unsqueeze(1).expand(B, P + 1).clone()
    a[:, -1] = gt_counts / denom
    b = (1.0 / denom).unsqueeze(1).expand(B, G + 1).clone()
    b[:, -1] = n / denom
    b[:, :G] = b[:, :G] * gt_vert_mask.float()

    log_a = torch.log(a + 1e-9)
    log_b = torch.log(b + 1e-9)
    log_k = -cost_pad / max(eps_sinkhorn, 1e-6)
    log_u = torch.zeros_like(a)
    log_v = torch.zeros_like(b)

    for _ in range(iters):
        log_u = log_a - torch.logsumexp(log_k + log_v.unsqueeze(1), dim=2)
        log_v = log_b - torch.logsumexp(log_k + log_u.unsqueeze(2), dim=1)

    transport = torch.exp(log_u.unsqueeze(2) + log_v.unsqueeze(1) + log_k)
    T = transport[:, :P, :G]  # [B, P, G]

    # Matched distances
    row_sums = T.sum(dim=2).clamp_min(eps)
    matched_d_pred = (T * dist).sum(dim=2) / row_sums  # [B, P]
    w_pred = row_sums * denom.unsqueeze(1)

    col_sums = T.sum(dim=1).clamp_min(eps)
    matched_d_gt = (T * dist).sum(dim=1) / col_sums  # [B, G]
    w_gt = col_sums * denom.unsqueeze(1)

    precision = (w_pred * torch.sigmoid((thresh_t[:, None] - matched_d_pred) / tau_t[:, None])).mean(dim=1)
    recall_raw = w_gt * torch.sigmoid((thresh_t[:, None] - matched_d_gt) / tau_t[:, None])
    # Mask invalid GT vertices in recall
    recall = (recall_raw * gt_vert_mask.float()).sum(dim=1) / gt_counts.clamp_min(1.0)

    f1 = 2.0 * precision * recall / (precision + recall + eps)
    return 1.0 - f1  # return loss (1 - F1)
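
# Padding note: invalid GT vertices are given cost 10 * thresh, so their
# sigmoid match probability saturates at ~0, and their column mass is zeroed
# (b[:, :G] *= mask) so transport cannot flow into them; recall is then
# renormalized by gt_counts rather than by the padded width G.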



def batched_segment_sampled_iou(
    pred_segments: torch.Tensor,
    gt_pad: torch.Tensor,
    gt_mask: torch.Tensor,
    radius: float | torch.Tensor = 0.5,
    n_per_seg: int = 32,
    tau: float | torch.Tensor = 0.05,
    eps: float = 1e-8,
) -> torch.Tensor:
    """Batched segment-sampled tube IoU loss.

    Returns [B] per-sample (1 - IoU) loss.
    """
    B, S = pred_segments.shape[:2]
    M = gt_pad.shape[1]

    # Allow per-sample radius and tau ([B] tensors or scalars)
    rad_t = torch.as_tensor(radius, device=pred_segments.device, dtype=pred_segments.dtype)
    if rad_t.dim() == 0:
        rad_t = rad_t.expand(B)
    tau_t = torch.as_tensor(tau, device=pred_segments.device, dtype=pred_segments.dtype)
    if tau_t.dim() == 0:
        tau_t = tau_t.expand(B)
    tau_t = tau_t.clamp_min(1e-8)

    # Sample points along segments
    t = torch.linspace(0, 1, n_per_seg, device=pred_segments.device, dtype=pred_segments.dtype)

    # Pred points: [B, S*n_per_seg, 3]
    pa = pred_segments[:, :, 0:1, :]  # [B, S, 1, 3]
    pb = pred_segments[:, :, 1:2, :]
    pred_pts = (pa + t[None, None, :, None] * (pb - pa)).reshape(B, S * n_per_seg, 3)

    # GT points: [B, M*n_per_seg, 3]
    ga = gt_pad[:, :, 0:1, :]
    gb = gt_pad[:, :, 1:2, :]
    gt_pts = (ga + t[None, None, :, None] * (gb - ga)).reshape(B, M * n_per_seg, 3)

    # For each pred point, min distance to the sampled GT points (a
    # point-to-point approximation of the point-to-segment distance)
    d_pred_to_gt = torch.cdist(pred_pts, gt_pts)  # [B, S*n, M*n]
    d_pred = d_pred_to_gt.min(dim=2).values  # [B, S*n]
    prec = torch.sigmoid((rad_t[:, None] - d_pred) / tau_t[:, None]).mean(dim=1)  # [B]

    d_gt_to_pred = d_pred_to_gt.min(dim=1).values  # [B, M*n]
    # Mask invalid GT points
    gt_pt_mask = gt_mask.unsqueeze(2).expand(B, M, n_per_seg).reshape(B, M * n_per_seg)
    rec_raw = torch.sigmoid((rad_t[:, None] - d_gt_to_pred) / tau_t[:, None])
    rec = (rec_raw * gt_pt_mask.float()).sum(dim=1) / gt_pt_mask.float().sum(dim=1).clamp_min(1.0)

    iou = prec * rec / (prec + rec - prec * rec + eps)
    return 1.0 - iou  # return loss


def batched_soft_hss_v2(pred_segments, gt_pad, gt_mask,
                        vert_thresh=0.5, edge_thresh=0.5, tau=0.05,
                        sinkhorn_iters=10, n_per_seg=32):
    """Batched soft HSS loss. Returns [B] per-sample (1 - HSS)."""
    f1_loss = batched_sinkhorn_vertex_f1(
        pred_segments, gt_pad, gt_mask,
        thresh=vert_thresh, tau=tau, iters=sinkhorn_iters)
    iou_loss = batched_segment_sampled_iou(
        pred_segments, gt_pad, gt_mask,
        radius=edge_thresh, n_per_seg=n_per_seg, tau=tau)
    f1 = 1.0 - f1_loss
    iou = 1.0 - iou_loss
    hss = 2.0 * f1 * iou / (f1 + iou + 1e-8)
    return 1.0 - hss
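
# Training-loop sketch (hypothetical variable names; the real loop lives in
# the model/training code of this repo):
#
#   # pred: [B, S, 2, 3] from the model; gt_pad/gt_mask from the collate fn
#   loss = batched_soft_hss_v2(pred, gt_pad, gt_mask).mean()
#   loss.backward()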
s23dr_2026_example/tokenizer.py
ADDED
@@ -0,0 +1,88 @@
"""Tokenizer: learned embeddings + Fourier features for the point cloud tokens.

The EdgeDepthSequenceBuilder holds the learned embedding tables (label, source,
behind) and the random Fourier positional encoding. At training time,
build_tokens() in data.py applies these to pre-sampled point indices on GPU.
"""
from __future__ import annotations

from dataclasses import dataclass
from typing import Tuple

import numpy as np
import torch
import torch.nn as nn

from .point_fusion import NUM_ADE, NUM_GEST


# -- Config --

@dataclass(frozen=True)
class EdgeDepthSequenceConfig:
    seq_len: int = 2048
    colmap_points: int = 1280
    depth_points: int = 768
    use_fourier: bool = True
    fourier_dim: int = 32
    fourier_scale: float = 10.0


# -- Fourier positional encoding --

class FourierFeatures(nn.Module):
    def __init__(self, in_dim: int = 3, fourier_dim: int = 64,
                 scale: float = 10.0, seed: int = 0,
                 learnable: bool = False):
        super().__init__()
        gen = torch.Generator()
        gen.manual_seed(seed)
        B = torch.randn(fourier_dim, in_dim, generator=gen) * scale
        if learnable:
            self.B = nn.Parameter(B)
        else:
            self.register_buffer("B", B, persistent=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        proj = (2.0 * np.pi) * (x @ self.B.t())
        return torch.cat([torch.sin(proj), torch.cos(proj)], dim=-1)
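
# This is the standard random Fourier feature map
# gamma(x) = [sin(2*pi*B x), cos(2*pi*B x)] with B ~ N(0, scale^2); larger
# `scale` biases the encoding toward higher spatial frequencies. Sketch:
#
#   enc = FourierFeatures(in_dim=3, fourier_dim=32)
#   enc(torch.rand(2048, 3)).shape  # -> torch.Size([2048, 64])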


# -- Sequence builder (holds embeddings) --

class EdgeDepthSequenceBuilder(nn.Module):
    """Holds learned embeddings for point cloud tokenization.

    Used by the model at training time: build_tokens() calls
    self.label_emb(class_id), self.src_emb(source), etc.
    """

    def __init__(self, cfg: EdgeDepthSequenceConfig, label_emb_dim: int = 16,
                 src_emb_dim: int = 2, behind_emb_dim: int = 8,
                 fourier_seed: int = 0, use_vote_features: bool = False,
                 learnable_fourier: bool = False):
        super().__init__()
        self.cfg = cfg

        self.num_labels = 13  # 11 structural + other_house + non_house
        self.label_emb = nn.Embedding(self.num_labels, label_emb_dim)
        self.src_emb = nn.Embedding(2, src_emb_dim)
        self.behind_emb_dim = behind_emb_dim
        if behind_emb_dim > 0:
            self.behind_emb = nn.Embedding(NUM_GEST + 1, behind_emb_dim)

        # Fourier positional encoding
        if cfg.use_fourier:
            self.pos_enc = FourierFeatures(
                in_dim=3, fourier_dim=cfg.fourier_dim,
                scale=cfg.fourier_scale, seed=fourier_seed,
                learnable=learnable_fourier,
            )
            pos_dim = 3 + 2 * cfg.fourier_dim
        else:
            self.pos_enc = None
            pos_dim = 3

        vote_dim = 2 if use_vote_features else 0  # n_views_voted + vote_frac
        self.use_vote_features = use_vote_features
        self.out_dim = pos_dim + label_emb_dim + src_emb_dim + behind_emb_dim + vote_dim
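
# Dimension check with the defaults above (fourier_dim=32, label_emb_dim=16,
# src_emb_dim=2, behind_emb_dim=8, no vote features):
#   pos_dim = 3 + 2 * 32 = 67, so out_dim = 67 + 16 + 2 + 8 + 0 = 93.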
s23dr_2026_example/varifold.py
ADDED
@@ -0,0 +1,196 @@
import torch

from .wire_varifold_kernels import (
    loss_semi_lobatto3,
    loss_semi_lobatto3_mix,
    loss_semi_lobatto3_mix_simple,
    loss_simpson3,
    loss_simpson3_batch,
    loss_simpson3_mix,
    loss_simpson3_mix_batch,
    loss_simpson3_lenpow,
    loss_simpson3_lenpow_mix,
    loss_semi_legendre,
)


def edges_to_segments(vertices, edges) -> torch.Tensor:
    verts = torch.as_tensor(vertices, dtype=torch.float32)
    idx = torch.as_tensor(edges, dtype=torch.long)
    return torch.stack([verts[idx[:, 0]], verts[idx[:, 1]]], dim=1)


def segments_to_vertices_edges(segments: torch.Tensor):
    segs = torch.as_tensor(segments, dtype=torch.float32)
    vertices = segs.reshape(-1, 3)
    edges = [(2 * i, 2 * i + 1) for i in range(segs.shape[0])]
    return vertices, edges
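
# Round-trip sketch (with `vertices`/`edges` from a dataset sample): a
# wireframe becomes a (S, 2, 3) segment soup and back; note the inverse
# duplicates shared corners (two fresh vertices per segment), which is the
# representation the losses below operate on.
#
#   segs = edges_to_segments(vertices, edges)  # (S, 2, 3)
#   v2, e2 = segments_to_vertices_edges(segs)  # v2: (2*S, 3)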


def varifold_loss(
    pred_segments: torch.Tensor,
    gt_segments: torch.Tensor,
    sigma: float = 0.1,
    variant: str = "semi_lobatto3",
    t_nodes01: torch.Tensor | None = None,
    t_w: torch.Tensor | None = None,
    sigmas: torch.Tensor | None = None,
    alpha: torch.Tensor | None = None,
    normalize_alpha: bool = True,
    len_pow: float | None = None,
) -> torch.Tensor:
    p_pred, q_pred = pred_segments[:, 0], pred_segments[:, 1]
    p_gt, q_gt = gt_segments[:, 0], gt_segments[:, 1]

    if variant == "semi_lobatto3":
        return loss_semi_lobatto3(p_pred, q_pred, p_gt, q_gt, sigma)
    if variant == "semi_lobatto3_mix":
        if sigmas is None or alpha is None:
            raise ValueError("sigmas and alpha are required for semi_lobatto3_mix")
        return loss_semi_lobatto3_mix(p_pred, q_pred, p_gt, q_gt, sigmas, alpha, normalize_alpha)
    if variant == "semi_lobatto3_mix_simple":
        if sigmas is None or alpha is None:
            raise ValueError("sigmas and alpha are required for semi_lobatto3_mix_simple")
        return loss_semi_lobatto3_mix_simple(p_pred, q_pred, p_gt, q_gt, sigmas, alpha, normalize_alpha)
    if variant == "simpson3":
        if sigmas is not None or alpha is not None:
            if sigmas is None or alpha is None:
                raise ValueError("sigmas and alpha are required for simpson3 mix")
            return loss_simpson3_mix(p_pred, q_pred, p_gt, q_gt, sigmas, alpha, normalize_alpha)
        return loss_simpson3(p_pred, q_pred, p_gt, q_gt, sigma)
    if variant == "simpson3_lenpow":
        if len_pow is None:
            len_pow = 1.0
        if sigmas is not None or alpha is not None:
            if sigmas is None or alpha is None:
                raise ValueError("sigmas and alpha are required for simpson3_lenpow mix")
            return loss_simpson3_lenpow_mix(p_pred, q_pred, p_gt, q_gt, sigmas, alpha, len_pow, normalize_alpha)
        return loss_simpson3_lenpow(p_pred, q_pred, p_gt, q_gt, sigma, len_pow)
    if variant == "semi_legendre":
        return loss_semi_legendre(p_pred, q_pred, p_gt, q_gt, sigma, t_nodes01, t_w)
    if variant in ("centers", "segments_varifold", "semi_lobatto1"):
        return varifold_loss_centers(pred_segments, gt_segments, sigma)
    raise ValueError(f"Unknown varifold variant: {variant}")
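
# Dispatch sketch (made-up tensors): single-scale variants need only `sigma`;
# mixture variants additionally take per-scale `sigmas` and weights `alpha`.
#
#   pred = torch.rand(10, 2, 3, requires_grad=True)
#   gt = torch.rand(7, 2, 3)
#   varifold_loss(pred, gt, sigma=0.1)                 # default semi_lobatto3
#   varifold_loss(pred, gt, variant="semi_lobatto3_mix",
#                 sigmas=torch.tensor([0.05, 0.2]),
#                 alpha=torch.tensor([0.5, 0.5]))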


def varifold_loss_batch(
    pred_segments: torch.Tensor,
    gt_segments: torch.Tensor,
    *,
    sigma: float = 0.1,
    variant: str = "semi_lobatto3",
    t_nodes01: torch.Tensor | None = None,
    t_w: torch.Tensor | None = None,
    sigmas: torch.Tensor | None = None,
    alpha: torch.Tensor | None = None,
    normalize_alpha: bool = True,
    len_pow: float | None = None,
    gt_mask: torch.Tensor | None = None,
    pred_weights: torch.Tensor | None = None,
    cross_only: bool = False,
) -> torch.Tensor:
    if pred_segments.dim() != 4 or gt_segments.dim() != 4:
        raise ValueError("pred_segments and gt_segments must be (B, N, 2, 3)")
    p_pred, q_pred = pred_segments[:, :, 0], pred_segments[:, :, 1]
    p_gt, q_gt = gt_segments[:, :, 0], gt_segments[:, :, 1]

    w_gt = None
    if gt_mask is not None:
        w_gt = gt_mask.to(device=pred_segments.device, dtype=pred_segments.dtype)

    w_pred = None
    if pred_weights is not None:
        w_pred = pred_weights.to(device=pred_segments.device, dtype=pred_segments.dtype)

    if variant == "simpson3":
        if sigmas is not None or alpha is not None:
            if sigmas is None or alpha is None:
                raise ValueError("sigmas and alpha are required for simpson3 mix")
            return loss_simpson3_mix_batch(p_pred, q_pred, p_gt, q_gt, sigmas, alpha, w_gt=w_gt, w_pred=w_pred, normalize_alpha=normalize_alpha, cross_only=cross_only)
        return loss_simpson3_batch(p_pred, q_pred, p_gt, q_gt, sigma, w_gt=w_gt, w_pred=w_pred)

    # Fallback to per-sample loop for unsupported variants.
    losses = []
    sigmas_t = None
    if sigmas is not None:
        sigmas_t = torch.as_tensor(sigmas, device=pred_segments.device, dtype=pred_segments.dtype)
    for idx in range(pred_segments.shape[0]):
        gt_b = gt_segments[idx]
        if gt_mask is not None:
            gt_b = gt_b[gt_mask[idx]]
        sigmas_i = sigmas
        if sigmas_t is not None and sigmas_t.ndim == 2:
            sigmas_i = sigmas_t[idx]
        losses.append(
            varifold_loss(
                pred_segments[idx],
                gt_b,
                sigma=sigma,
                variant=variant,
                t_nodes01=t_nodes01,
                t_w=t_w,
                sigmas=sigmas_i,
                alpha=alpha,
                normalize_alpha=normalize_alpha,
                len_pow=len_pow,
            )
        )
    return torch.stack(losses, dim=0)


def varifold_loss_centers(
    pred_segments: torch.Tensor,
    gt_segments: torch.Tensor,
    sigma: float = 0.1,
    normalize_weights: bool = True,
) -> torch.Tensor:
    eps = 1e-8
    a_p, b_p = pred_segments[:, 0], pred_segments[:, 1]
    a_g, b_g = gt_segments[:, 0], gt_segments[:, 1]

    v_p = b_p - a_p
    v_g = b_g - a_g
    len_p = torch.linalg.norm(v_p, dim=-1)
    len_g = torch.linalg.norm(v_g, dim=-1)

    x_p = 0.5 * (a_p + b_p)
    x_g = 0.5 * (a_g + b_g)

    u_p = v_p / (len_p[:, None] + eps)
    u_g = v_g / (len_g[:, None] + eps)

    w_p = len_p
    w_g = len_g
    if normalize_weights:
        w_p = w_p / (w_p.sum() + eps)
        w_g = w_g / (w_g.sum() + eps)

    diff_pp = x_p[:, None, :] - x_p[None, :, :]
    diff_gg = x_g[:, None, :] - x_g[None, :, :]
    diff_pg = x_p[:, None, :] - x_g[None, :, :]
    d_pp = (diff_pp * diff_pp).sum(dim=-1)
    d_gg = (diff_gg * diff_gg).sum(dim=-1)
    d_pg = (diff_pg * diff_pg).sum(dim=-1)

    inv2s2 = 1.0 / (2.0 * sigma * sigma)
    k_pp = torch.exp(-d_pp * inv2s2)
    k_gg = torch.exp(-d_gg * inv2s2)
    k_pg = torch.exp(-d_pg * inv2s2)

    dot_pp = (u_p[:, None, :] * u_p[None, :, :]).sum(dim=-1)
    dot_gg = (u_g[:, None, :] * u_g[None, :, :]).sum(dim=-1)
    dot_pg = (u_p[:, None, :] * u_g[None, :, :]).sum(dim=-1)

    k_pp = k_pp * (dot_pp * dot_pp)
    k_gg = k_gg * (dot_gg * dot_gg)
    k_pg = k_pg * (dot_pg * dot_pg)

    wp_row = w_p[:, None]
    wp_col = w_p[None, :]
    wg_row = w_g[:, None]
    wg_col = w_g[None, :]

    a_pp = (wp_row * wp_col * k_pp).sum(dim=-1).sum(dim=-1)
    a_gg = (wg_row * wg_col * k_gg).sum(dim=-1).sum(dim=-1)
    a_pg = (w_p[:, None] * w_g[None, :] * k_pg).sum(dim=-1).sum(dim=-1)
    return a_pp + a_gg - 2.0 * a_pg
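
# varifold_loss_centers is the one-point quadrature of the wire varifold
# kernel: each segment is a single length-weighted Dirac at its midpoint x
# with unit direction u, compared through
#   K((x, u), (y, v)) = exp(-||x - y||^2 / (2 sigma^2)) * (u . v)^2,
# and the loss expands the kernel norm ||mu_pred - mu_gt||_K^2 as
# a_pp + a_gg - 2 * a_pg. The simpson3/lobatto3 variants in
# wire_varifold_kernels.py refine the same double integral with 3-point
# quadrature along each segment.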
s23dr_2026_example/wire_varifold_kernels.py
ADDED
@@ -0,0 +1,461 @@
import math
import torch

# -----------------------------
# Helpers
# -----------------------------
def segment_geom(p: torch.Tensor, q: torch.Tensor, eps: float = 1e-9):
    """
    p,q: (...,3)
    returns d, a, ell, u:
      d = q - p
      a = ||d||^2
      ell = sqrt(a + eps^2)
      u = d / ell
    """
    d = q - p
    a = (d * d).sum(dim=-1)
    eps_val = eps
    if p.dtype in (torch.float16, torch.bfloat16):
        eps_val = max(eps, float(torch.finfo(p.dtype).eps))
    ell = torch.sqrt(a + eps_val * eps_val)
    u = d / ell.unsqueeze(-1)
    return d, a, ell, u

def sample_points(p: torch.Tensor, q: torch.Tensor, nodes01: torch.Tensor):
    # (...,3) + (K,) -> (...,K,3)
    d = q - p
    nodes = nodes01.to(device=p.device, dtype=p.dtype)
    shape = [1] * (p.dim() - 1) + [nodes.shape[0], 1]
    nodes = nodes.view(*shape)
    return p.unsqueeze(-2) + nodes * d.unsqueeze(-2)


# Fixed Lobatto-3 / Simpson nodes+weights on [0,1]
LOBATTO3_NODES = torch.tensor([0.0, 0.5, 1.0])
# LOBATTO3_W = torch.tensor([1.0/6.0, 4.0/6.0, 1.0/6.0])
LOBATTO3_W = torch.tensor([1/3, 1/3, 1/3])
LOBATTO3_W2 = LOBATTO3_W[:, None] * LOBATTO3_W[None, :]  # (3,3)


def _prepare_mix_weights(sigmas, alpha, device, dtype, normalize_alpha: bool):
    sigmas_t = torch.as_tensor(sigmas, device=device, dtype=dtype).clamp_min(1e-6)
    alpha_t = torch.as_tensor(alpha, device=device, dtype=dtype)
    if normalize_alpha:
        alpha_t = alpha_t / alpha_t.sum().clamp_min(1e-12)
    return sigmas_t, alpha_t
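
# Helper sketch (made-up endpoints): segment_geom returns the difference
# vector, its squared length, a numerically safe length, and the unit
# direction; sample_points places quadrature nodes along each segment.
#
#   p = torch.zeros(4, 3); q = torch.ones(4, 3)
#   d, a, ell, u = segment_geom(p, q)      # a = 3, ell ~ sqrt(3), u = d / ell
#   sample_points(p, q, LOBATTO3_NODES)    # (4, 3, 3): endpoints + midpoint
#
# Note that LOBATTO3_W is set to uniform weights (1/3 each); the standard
# Lobatto-3 weights (1/6, 4/6, 1/6) are kept above as a commented alternative.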

# -----------------------------
# 1) Simpson-3 on both segments (3x3 product rule)
# -----------------------------
def _prep_weight(w, n: int, b: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor | None:
    if w is None:
        return None
    w = torch.as_tensor(w, device=device, dtype=dtype)
    if w.dim() == 1:
        if w.shape[0] != n:
            raise ValueError(f"weight length {w.shape[0]} != {n}")
        w = w.unsqueeze(0).expand(b, -1)
    elif w.dim() == 2:
        if w.shape[0] != b or w.shape[1] != n:
            raise ValueError(f"weight shape {tuple(w.shape)} != ({b}, {n})")
    else:
        raise ValueError("weights must be 1D or 2D")
    return w


def cross_simpson3(
    pA,
    qA,
    pB,
    qB,
    sigma: float | torch.Tensor,
    wA: torch.Tensor | None = None,
    wB: torch.Tensor | None = None,
):
    device, dtype = pA.device, pA.dtype
    batched = pA.dim() == 3
    if not batched:
        pA = pA.unsqueeze(0)
        qA = qA.unsqueeze(0)
        pB = pB.unsqueeze(0)
        qB = qB.unsqueeze(0)
    nodes = LOBATTO3_NODES.to(device=device, dtype=dtype)
    w2 = LOBATTO3_W2.to(device=device, dtype=dtype)

    bsz, nA, _ = pA.shape
    nB = pB.shape[1]
    wA = _prep_weight(wA, nA, bsz, device, dtype)
    wB = _prep_weight(wB, nB, bsz, device, dtype)

    _, _, ellA, uA = segment_geom(pA, qA)
    _, _, ellB, uB = segment_geom(pB, qB)

    XA = sample_points(pA, qA, nodes)  # (B,N,3,3)
    YB = sample_points(pB, qB, nodes)  # (B,M,3,3)

    # angular + length factors: (N,M)
    ang = torch.matmul(uA, uB.transpose(-1, -2)).pow(2)
    lenfac = ellA[:, :, None] * ellB[:, None, :]
    if wA is not None or wB is not None:
        if wA is None:
            wA = torch.ones((bsz, nA), device=device, dtype=dtype)
        if wB is None:
            wB = torch.ones((bsz, nB), device=device, dtype=dtype)
        lenfac = lenfac * (wA[:, :, None] * wB[:, None, :])

    # spatial: build (N,M,3,3) kernel via broadcasting
    diff = XA[:, :, None, :, None, :] - YB[:, None, :, None, :, :]  # (B,N,M,3,3,3)
    r2 = (diff * diff).sum(dim=-1)  # (B,N,M,3,3)
    sigma_t = torch.as_tensor(sigma, device=device, dtype=dtype)
    if sigma_t.ndim == 0:
        inv2s2 = 1.0 / (2.0 * sigma_t * sigma_t)
    else:
        if sigma_t.shape[0] != bsz:
            raise ValueError(f"sigma batch {sigma_t.shape[0]} != {bsz}")
        inv2s2 = (1.0 / (2.0 * sigma_t * sigma_t)).view(bsz, 1, 1, 1, 1)
    K = torch.exp(-r2 * inv2s2)  # (B,N,M,3,3)

    spatial = (K * w2).sum(dim=-1).sum(dim=-1)  # (B,N,M)
    out = (ang * lenfac * spatial).sum(dim=-1).sum(dim=-1)  # (B,)
    return out[0] if not batched else out
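
# Cost note: the broadcasted kernel K is (B, N, M, 3, 3), i.e. nine Gaussian
# evaluations per (pred, gt) segment pair, so memory and time grow as
# O(B * N * M); for the self term (A == B) that is O(B * N^2), which is why
# loss_simpson3_batch further below offers a cross_only mode.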


def cross_simpson3_lenpow(
    pA,
    qA,
    pB,
    qB,
    sigma: float | torch.Tensor,
    len_pow: float,
    wA: torch.Tensor | None = None,
    wB: torch.Tensor | None = None,
):
    device, dtype = pA.device, pA.dtype
    batched = pA.dim() == 3
    if not batched:
        pA = pA.unsqueeze(0)
        qA = qA.unsqueeze(0)
        pB = pB.unsqueeze(0)
        qB = qB.unsqueeze(0)
    nodes = LOBATTO3_NODES.to(device=device, dtype=dtype)
    w2 = LOBATTO3_W2.to(device=device, dtype=dtype)

    bsz, nA, _ = pA.shape
    nB = pB.shape[1]
    wA = _prep_weight(wA, nA, bsz, device, dtype)
    wB = _prep_weight(wB, nB, bsz, device, dtype)

    _, _, ellA, uA = segment_geom(pA, qA)
    _, _, ellB, uB = segment_geom(pB, qB)

    XA = sample_points(pA, qA, nodes)  # (B,N,3,3)
    YB = sample_points(pB, qB, nodes)  # (B,M,3,3)

    ang = torch.matmul(uA, uB.transpose(-1, -2)).pow(2)
    lenfac = (ellA[:, :, None] * ellB[:, None, :]).pow(len_pow)
    if wA is not None or wB is not None:
        if wA is None:
            wA = torch.ones((bsz, nA), device=device, dtype=dtype)
        if wB is None:
            wB = torch.ones((bsz, nB), device=device, dtype=dtype)
        lenfac = lenfac * (wA[:, :, None] * wB[:, None, :])

    diff = XA[:, :, None, :, None, :] - YB[:, None, :, None, :, :]  # (B,N,M,3,3,3)
    r2 = (diff * diff).sum(dim=-1)  # (B,N,M,3,3)
    sigma_t = torch.as_tensor(sigma, device=device, dtype=dtype)
    if sigma_t.ndim == 0:
        inv2s2 = 1.0 / (2.0 * sigma_t * sigma_t)
    else:
        if sigma_t.shape[0] != bsz:
            raise ValueError(f"sigma batch {sigma_t.shape[0]} != {bsz}")
        inv2s2 = (1.0 / (2.0 * sigma_t * sigma_t)).view(bsz, 1, 1, 1, 1)
    K = torch.exp(-r2 * inv2s2)  # (B,N,M,3,3)

    spatial = (K * w2).sum(dim=-1).sum(dim=-1)  # (B,N,M)
    out = (ang * lenfac * spatial).sum(dim=-1).sum(dim=-1)  # (B,)
    return out[0] if not batched else out


# -----------------------------
# 2/3) Semi-analytic in s, quadrature in t
# - Lobatto-3 (endpoints+midpoint)
# - Gauss-Legendre Q (nodes/weights passed in)
# -----------------------------
def cross_semi_analytic(pA, qA, pB, qB, sigma: float, t_nodes01: torch.Tensor, t_w: torch.Tensor):
    """
    Gaussian k_x. Integrate s exactly along A, integrate t numerically along B.
    t_nodes01, t_w: (Q,) nodes/weights on [0,1] (constants you pass in)
    """
    device, dtype = pA.device, pA.dtype
    t = t_nodes01.to(device=device, dtype=dtype)  # (Q,)
    w = t_w.to(device=device, dtype=dtype)  # (Q,)

    dA, aA, ellA, uA = segment_geom(pA, qA)
    dB, _, ellB, uB = segment_geom(pB, qB)

    # (N,M) factors
    ang = (uA @ uB.t()).pow(2)
    lenfac = ellA[:, None] * ellB[None, :]

    # r0: (N,M,3)
    r0 = pA[:, None, :] - pB[None, :, :]

    # r(t): (N,M,Q,3)
    r = r0[:, :, None, :] - t[None, None, :, None] * dB[None, :, None, :]

    # beta, r2: (N,M,Q)
    beta = (r * dA[:, None, None, :]).sum(dim=-1)
    r2 = (r * r).sum(dim=-1)

    # semi-analytic constants per A segment: shapes broadcast to (N,1,1)
    a = aA.clamp_min(1e-12)
    inv_a = (1.0 / a).view(-1, 1, 1)
    denom = (torch.sqrt(2.0 * a) * sigma).view(-1, 1, 1)
    pref = (math.sqrt(math.pi) * sigma / torch.sqrt(2.0 * a)).view(-1, 1, 1)

    # J(t): (N,M,Q)
    exp_term = torch.exp(-(r2 - (beta * beta) * inv_a) / (2.0 * sigma * sigma))
    erf1 = torch.special.erf((a.view(-1, 1, 1) + beta) / denom)
    erf0 = torch.special.erf(beta / denom)
    J = pref * (erf1 - erf0) * exp_term

    # integrate over t: (N,M)
    spatial = (J * w.view(1, 1, -1)).sum(dim=-1)
    return (ang * lenfac * spatial).sum(dim=-1).sum(dim=-1)


def cross_semi_lobatto3(pA, qA, pB, qB, sigma: float):
    device, dtype = pA.device, pA.dtype
    t = LOBATTO3_NODES.to(device=device, dtype=dtype)
    w = LOBATTO3_W.to(device=device, dtype=dtype)
    return cross_semi_analytic(pA, qA, pB, qB, sigma, t, w)


def cross_semi_lobatto3_mix(
    pA,
    qA,
    pB,
    qB,
    sigmas,
    alpha,
    normalize_alpha: bool = True,
):
    """
    Semi-analytic in s (along A), Lobatto-3 in t (along B), with a sigma mixture.
    """
    device, dtype = pA.device, pA.dtype
    t_nodes = LOBATTO3_NODES.to(device=device, dtype=dtype)
    t_w = LOBATTO3_W.to(device=device, dtype=dtype)

    sigmas_t, alpha_t = _prepare_mix_weights(sigmas, alpha, device, dtype, normalize_alpha)

    dA, aA, ellA, uA = segment_geom(pA, qA)
    dB, _, ellB, uB = segment_geom(pB, qB)

    ang = (uA @ uB.t()).pow(2)
    lenfac = ellA[:, None] * ellB[None, :]

    r0 = pA[:, None, :] - pB[None, :, :]

    a = aA.clamp_min(1e-12)
    inv_a = (1.0 / a).view(-1, 1)
    sqrt_a = torch.sqrt(2.0 * a).clamp_min(1e-12)

    denom = (sqrt_a[:, None] * sigmas_t[None, :]).clamp_min(1e-12)
    pref = math.sqrt(math.pi) * sigmas_t[None, :] / sqrt_a[:, None]
    inv2s2 = (1.0 / (2.0 * sigmas_t * sigmas_t)).view(1, 1, -1)

    denom_nmS = denom[:, None, :]
    pref_nmS = pref[:, None, :]
    alpha_nmS = alpha_t.view(1, 1, -1)
    a_nm1 = a[:, None, None]

    spatial = torch.zeros((pA.shape[0], pB.shape[0]), device=device, dtype=dtype)
    for tk, wk in zip(t_nodes, t_w):
        r = r0 - tk * dB[None, :, :]
        beta = (r * dA[:, None, :]).sum(dim=-1)
        r2 = (r * r).sum(dim=-1)
        core = r2 - (beta * beta) * inv_a

        exp_term = torch.exp(-core[:, :, None] * inv2s2)
        erf1 = torch.special.erf((a_nm1 + beta[:, :, None]) / denom_nmS)
        erf0 = torch.special.erf(beta[:, :, None] / denom_nmS)
        J = pref_nmS * (erf1 - erf0) * exp_term
        spatial = spatial + wk * (J * alpha_nmS).sum(dim=-1)

    return (ang * lenfac * spatial).sum(dim=-1).sum(dim=-1)
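
# Closed form used above: with x(s) = pA + s * dA, r = r(t) = pA - y(t),
# a = |dA|^2 and beta = r . dA, completing the square in s gives
#   int_0^1 exp(-|r + s dA|^2 / (2 sigma^2)) ds
#     = sqrt(pi) * sigma / sqrt(2a)
#       * [erf((a + beta) / (sqrt(2a) sigma)) - erf(beta / (sqrt(2a) sigma))]
#       * exp(-(|r|^2 - beta^2 / a) / (2 sigma^2)),
# which is exactly pref * (erf1 - erf0) * exp_term.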


# -----------------------------
# Full losses (self + self - 2 cross)
# -----------------------------
# def loss_simpson3(p_pred, q_pred, p_gt, q_gt, sigma: float):
#     s_pred = cross_simpson3(p_pred, q_pred, p_pred, q_pred, sigma)
#     # s_gt = cross_simpson3(p_gt, q_gt, p_gt, q_gt, sigma)
#     cross = cross_simpson3(p_pred, q_pred, p_gt, q_gt, sigma)
#     # return s_pred + s_gt - 2.0 * cross
#     return s_pred - 2.0 * cross


def loss_simpson3(p_pred, q_pred, p_gt, q_gt, sigma: float):
    s_pred = cross_simpson3(p_pred, q_pred, p_pred, q_pred, sigma)
    # s_gt = cross_simpson3(p_gt, q_gt, p_gt, q_gt, sigma)
    cross = cross_simpson3(p_pred, q_pred, p_gt, q_gt, sigma)
    # return s_pred + s_gt - 2.0 * cross
    return s_pred - 2.0 * cross
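
# Note on the dropped self term: s_gt = <mu_gt, mu_gt> does not depend on the
# predictions, so `s_pred - 2 * cross` has exactly the same gradients as the
# full kernel distance `s_pred + s_gt - 2 * cross` (kept in the commented
# lines above); the loss value itself can therefore be negative.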
| 307 |
+
|
| 308 |
+
|
| 309 |
+
def loss_simpson3_lenpow(p_pred, q_pred, p_gt, q_gt, sigma: float, len_pow: float):
|
| 310 |
+
s_pred = cross_simpson3_lenpow(p_pred, q_pred, p_pred, q_pred, sigma, len_pow)
|
| 311 |
+
# s_gt = cross_simpson3_lenpow(p_gt, q_gt, p_gt, q_gt, sigma, len_pow)
|
| 312 |
+
cross = cross_simpson3_lenpow(p_pred, q_pred, p_gt, q_gt, sigma, len_pow)
|
| 313 |
+
# return s_pred + s_gt - 2.0 * cross
|
| 314 |
+
return s_pred - 2.0 * cross
|
| 315 |
+
|
| 316 |
+
def loss_simpson3_mix(
|
| 317 |
+
p_pred,
|
| 318 |
+
q_pred,
|
| 319 |
+
p_gt,
|
| 320 |
+
q_gt,
|
| 321 |
+
sigmas,
|
| 322 |
+
alpha,
|
| 323 |
+
normalize_alpha: bool = True,
|
| 324 |
+
):
|
| 325 |
+
device, dtype = p_pred.device, p_pred.dtype
|
| 326 |
+
sigmas_t, alpha_t = _prepare_mix_weights(sigmas, alpha, device, dtype, normalize_alpha)
|
| 327 |
+
losses = [loss_simpson3(p_pred, q_pred, p_gt, q_gt, s) for s in sigmas_t]
|
| 328 |
+
return (torch.stack(losses) * alpha_t).sum()
|
| 329 |
+
|
| 330 |
+
|
| 331 |
+
# def loss_simpson3_batch(
|
| 332 |
+
# p_pred: torch.Tensor,
|
| 333 |
+
# q_pred: torch.Tensor,
|
| 334 |
+
# p_gt: torch.Tensor,
|
| 335 |
+
# q_gt: torch.Tensor,
|
| 336 |
+
# sigma: float | torch.Tensor,
|
| 337 |
+
# w_gt: torch.Tensor | None = None,
|
| 338 |
+
# ) -> torch.Tensor:
|
| 339 |
+
# s_pred = cross_simpson3(p_pred, q_pred, p_pred, q_pred, sigma)
|
| 340 |
+
# # s_gt = cross_simpson3(p_gt, q_gt, p_gt, q_gt, sigma, wA=w_gt, wB=w_gt)
|
| 341 |
+
# cross = cross_simpson3(p_pred, q_pred, p_gt, q_gt, sigma, wB=w_gt)
|
| 342 |
+
# # return s_pred + s_gt - 2.0 * cross
|
| 343 |
+
# return s_pred - 2.0 * cross
|
| 344 |
+
|
| 345 |
+
|
| 346 |
+
def loss_simpson3_batch(
|
| 347 |
+
p_pred: torch.Tensor,
|
| 348 |
+
q_pred: torch.Tensor,
|
| 349 |
+
p_gt: torch.Tensor,
|
| 350 |
+
q_gt: torch.Tensor,
|
| 351 |
+
sigma: float | torch.Tensor,
|
| 352 |
+
w_gt: torch.Tensor | None = None,
|
| 353 |
+
w_pred: torch.Tensor | None = None,
|
| 354 |
+
cross_only: bool = False,
|
| 355 |
+
) -> torch.Tensor:
|
| 356 |
+
cross = cross_simpson3(p_pred, q_pred, p_gt, q_gt, sigma, wA=w_pred, wB=w_gt)
|
| 357 |
+
if cross_only:
|
| 358 |
+
# No self-energy: avoids O(S^2) blowup, sinkhorn handles repulsion
|
| 359 |
+
return -2.0 * cross
|
| 360 |
+
s_pred = cross_simpson3(p_pred, q_pred, p_pred, q_pred, sigma, wA=w_pred, wB=w_pred)
|
| 361 |
+
return s_pred - 2.0 * cross
|
| 362 |
+
|
| 363 |
+
|
| 364 |
+
def loss_simpson3_mix_batch(
|
| 365 |
+
p_pred: torch.Tensor,
|
| 366 |
+
q_pred: torch.Tensor,
|
| 367 |
+
p_gt: torch.Tensor,
|
| 368 |
+
q_gt: torch.Tensor,
|
| 369 |
+
sigmas,
|
| 370 |
+
alpha,
|
| 371 |
+
w_gt: torch.Tensor | None = None,
|
| 372 |
+
w_pred: torch.Tensor | None = None,
|
| 373 |
+
normalize_alpha: bool = True,
|
| 374 |
+
cross_only: bool = False,
|
| 375 |
+
) -> torch.Tensor:
|
| 376 |
+
device, dtype = p_pred.device, p_pred.dtype
|
| 377 |
+
sigmas_t = torch.as_tensor(sigmas, device=device, dtype=dtype).clamp_min(1e-6)
|
| 378 |
+
alpha_t = torch.as_tensor(alpha, device=device, dtype=dtype)
|
| 379 |
+
if normalize_alpha:
|
| 380 |
+
alpha_t = alpha_t / alpha_t.sum().clamp_min(1e-12)
|
| 381 |
+
if sigmas_t.ndim == 1:
|
| 382 |
+
losses = [loss_simpson3_batch(p_pred, q_pred, p_gt, q_gt, s, w_gt=w_gt, w_pred=w_pred, cross_only=cross_only) for s in sigmas_t]
|
| 383 |
+
return (torch.stack(losses, dim=0) * alpha_t[:, None]).sum(dim=0)
|
| 384 |
+
if sigmas_t.ndim == 2:
|
| 385 |
+
losses = [loss_simpson3_batch(p_pred, q_pred, p_gt, q_gt, sigmas_t[:, i], w_gt=w_gt, w_pred=w_pred, cross_only=cross_only) for i in range(sigmas_t.shape[1])]
|
| 386 |
+
return (torch.stack(losses, dim=0) * alpha_t[:, None]).sum(dim=0)
|
| 387 |
+
raise ValueError("sigmas must be 1D or 2D for batch loss")
|
| 388 |
+
|
| 389 |
+
|
| 390 |
+
def loss_simpson3_lenpow_mix(
|
| 391 |
+
p_pred,
|
| 392 |
+
q_pred,
|
| 393 |
+
p_gt,
|
| 394 |
+
q_gt,
|
| 395 |
+
sigmas,
|
| 396 |
+
alpha,
|
| 397 |
+
len_pow: float,
|
| 398 |
+
normalize_alpha: bool = True,
|
| 399 |
+
):
|
| 400 |
+
device, dtype = p_pred.device, p_pred.dtype
|
| 401 |
+
sigmas_t, alpha_t = _prepare_mix_weights(sigmas, alpha, device, dtype, normalize_alpha)
|
| 402 |
+
losses = [loss_simpson3_lenpow(p_pred, q_pred, p_gt, q_gt, s, len_pow) for s in sigmas_t]
|
| 403 |
+
return (torch.stack(losses) * alpha_t).sum()
|
| 404 |
+
|
| 405 |
+
def loss_semi_lobatto3(p_pred, q_pred, p_gt, q_gt, sigma: float):
|
| 406 |
+
s_pred = cross_semi_lobatto3(p_pred, q_pred, p_pred, q_pred, sigma)
|
| 407 |
+
# s_gt = cross_semi_lobatto3(p_gt, q_gt, p_gt, q_gt, sigma)
|
| 408 |
+
cross = cross_semi_lobatto3(p_pred, q_pred, p_gt, q_gt, sigma)
|
| 409 |
+
# return s_pred + s_gt - 2.0 * cross
|
| 410 |
+
return s_pred - 2.0 * cross
|
| 411 |
+
|
| 412 |
+
|
| 413 |
+
def loss_semi_lobatto3_mix(
|
| 414 |
+
p_pred,
|
| 415 |
+
q_pred,
|
| 416 |
+
p_gt,
|
| 417 |
+
q_gt,
|
| 418 |
+
sigmas,
|
| 419 |
+
alpha,
|
| 420 |
+
normalize_alpha: bool = True,
|
| 421 |
+
):
|
| 422 |
+
s_pred = cross_semi_lobatto3_mix(p_pred, q_pred, p_pred, q_pred, sigmas, alpha, normalize_alpha)
|
| 423 |
+
# s_gt = cross_semi_lobatto3_mix(p_gt, q_gt, p_gt, q_gt, sigmas, alpha, normalize_alpha)
|
| 424 |
+
cross = cross_semi_lobatto3_mix(p_pred, q_pred, p_gt, q_gt, sigmas, alpha, normalize_alpha)
|
| 425 |
+
# return s_pred + s_gt - 2.0 * cross
|
| 426 |
+
return s_pred - 2.0 * cross


def loss_semi_lobatto3_mix_simple(
    p_pred,
    q_pred,
    p_gt,
    q_gt,
    sigmas,
    alpha,
    normalize_alpha: bool = True,
):
    device, dtype = p_pred.device, p_pred.dtype
    sigmas_t, alpha_t = _prepare_mix_weights(sigmas, alpha, device, dtype, normalize_alpha)
    losses = [loss_semi_lobatto3(p_pred, q_pred, p_gt, q_gt, s) for s in sigmas_t]
    return (torch.stack(losses) * alpha_t).sum()


def loss_semi_legendre(p_pred, q_pred, p_gt, q_gt, sigma: float, t_nodes01, t_w):
    # Full squared kernel distance, including the constant gt-gt self term.
    s_pred = cross_semi_analytic(p_pred, q_pred, p_pred, q_pred, sigma, t_nodes01, t_w)
    s_gt = cross_semi_analytic(p_gt, q_gt, p_gt, q_gt, sigma, t_nodes01, t_w)
    cross = cross_semi_analytic(p_pred, q_pred, p_gt, q_gt, sigma, t_nodes01, t_w)
    return s_pred + s_gt - 2.0 * cross


# -----------------------------
# torch.compile usage
# -----------------------------
# For Legendre: generate nodes/weights ONCE outside compile and pass them in.
# Example:
#   import numpy as np
#   x, w = np.polynomial.legendre.leggauss(Q)
#   t_nodes = torch.tensor(0.5 * (x + 1.0), device=device, dtype=dtype)
#   t_w = torch.tensor(0.5 * w, device=device, dtype=dtype)
#
#   compiled_loss = torch.compile(loss_semi_lobatto3, fullgraph=True)
#   compiled_loss_leg = torch.compile(
#       lambda pp, qp, pg, qg, s: loss_semi_legendre(pp, qp, pg, qg, s, t_nodes, t_w),
#       fullgraph=True)
script.py
ADDED
@@ -0,0 +1,322 @@
"""S23DR 2026 submission: learned wireframe prediction from fused point clouds.

Pipeline: raw sample -> point fusion -> priority sample 2048 -> model -> post-process -> wireframe
"""
from pathlib import Path
from tqdm import tqdm
import json
import os
import sys
import time

import numpy as np
import torch


def empty_solution():
    # Fallback wireframe: two vertices at the origin joined by one edge.
    return np.zeros((2, 3)), [(0, 1)]


# ---------------------------------------------------------------------------
# Point fusion + sampling (from cache_scenes.py / make_sampled_cache.py)
# ---------------------------------------------------------------------------

# Add our package to path
SCRIPT_DIR = Path(__file__).resolve().parent
sys.path.insert(0, str(SCRIPT_DIR))

from s23dr_2026_example.point_fusion import build_compact_scene, FuserConfig
from s23dr_2026_example.cache_scenes import (
    _compute_group_and_class, _compute_smart_center_scale,
)
from s23dr_2026_example.make_sampled_cache import _priority_sample

# Tokenizer / model imports
from s23dr_2026_example.tokenizer import EdgeDepthSequenceConfig
from s23dr_2026_example.model import EdgeDepthSegmentsModel
from s23dr_2026_example.segment_postprocess import merge_vertices
from s23dr_2026_example.varifold import segments_to_vertices_edges
from s23dr_2026_example.postprocess_v2 import snap_to_point_cloud, snap_horizontal

SEQ_LEN = 2048         # tokens per scene fed to the model
COLMAP_QUOTA = 1536    # priority-sampling quota for COLMAP points
DEPTH_QUOTA = 512      # priority-sampling quota for depth-map points
CONF_THRESH = 0.7      # keep predicted segments above this confidence
MERGE_THRESH = 0.4     # vertex merge radius (world units)
SNAP_RADIUS = 0.5      # snap-to-point-cloud radius (world units)


def fuse_and_sample(sample, cfg, rng):
    """Run point fusion + priority sampling on a raw dataset sample.

    Returns a dict with xyz_norm, class_id, source, mask, center, scale, etc.,
    ready for model inference. Returns None if fusion fails.
    """
    try:
        scene = build_compact_scene(sample, cfg, rng)
    except Exception as e:
        print(f"  Fusion failed: {e}")
        return None

    xyz = scene["xyz"]
    source = scene["source"]

    if len(xyz) < 10:
        return None

    # Compute group_id and class_id (same as cache_scenes.py)
    behind_id = scene.get("behind_gest_id", np.full(len(xyz), -1, dtype=np.int16))
    group_id, class_id = _compute_group_and_class(
        scene["visible_src"], scene["visible_id"], behind_id, source)

    # Normalize
    center, scale = _compute_smart_center_scale(xyz, source)

    # Priority sample
    indices, mask = _priority_sample(source, group_id, SEQ_LEN, COLMAP_QUOTA, DEPTH_QUOTA)

    xyz_norm = (xyz[indices] - center) / scale

    result = {
        "xyz_norm": xyz_norm.astype(np.float32),
        "class_id": class_id[indices].astype(np.int64),
        "source": source[indices].astype(np.int64),
        "mask": mask,
        "center": center.astype(np.float32),
        "scale": np.float32(scale),
    }

    # Optional fields
    if "behind_gest_id" in scene:
        behind = np.clip(scene["behind_gest_id"][indices].astype(np.int16), 0, None)
        result["behind"] = behind.astype(np.int64)
    if "n_views_voted" in scene:
        result["n_views_voted"] = scene["n_views_voted"][indices].astype(np.float32)
    if "vote_frac" in scene:
        result["vote_frac"] = scene["vote_frac"][indices].astype(np.float32)

    # Visible src/id for snap post-processing
    result["visible_src"] = scene["visible_src"][indices].astype(np.int64)
    result["visible_id"] = scene["visible_id"][indices].astype(np.int64)

    return result
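
# Minimal usage sketch (assumes `raw_sample` is one record of the test dataset):
#   fused = fuse_and_sample(raw_sample, FuserConfig(), np.random.RandomState(0))
#   if fused is not None:
#       # xyz_norm should be padded to SEQ_LEN points, with mask marking valid slots
#       print(fused["xyz_norm"].shape, fused["mask"].sum())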


def load_model(checkpoint_path, device):
    """Load model from checkpoint."""
    ckpt = torch.load(checkpoint_path, map_location=device, weights_only=False)
    args = ckpt.get("args", {})

    norm_class = torch.nn.RMSNorm if args.get("rms_norm") else None
    seq_cfg = EdgeDepthSequenceConfig(
        seq_len=SEQ_LEN, colmap_points=COLMAP_QUOTA, depth_points=DEPTH_QUOTA)

    model = EdgeDepthSegmentsModel(
        seq_cfg=seq_cfg,
        segments=args.get("segments", 64),
        hidden=args.get("hidden", 256),
        num_heads=args.get("num_heads", 4),
        kv_heads_cross=args.get("kv_heads_cross", 2),
        kv_heads_self=args.get("kv_heads_self", 2),
        dim_feedforward=args.get("ff", 1024),
        dropout=args.get("dropout", 0.1),
        latent_tokens=args.get("latent_tokens", 256),
        latent_layers=args.get("latent_layers", 7),
        decoder_layers=args.get("decoder_layers", 3),
        cross_attn_interval=args.get("cross_attn_interval", 4),
        norm_class=norm_class,
        activation=args.get("activation", "gelu"),
        segment_conf=args.get("segment_conf", True),
        behind_emb_dim=args.get("behind_emb_dim", 8),
        use_vote_features=args.get("vote_features", True),
        arch=args.get("arch", "perceiver"),
        encoder_layers=args.get("encoder_layers", 4),
        pre_encoder_layers=args.get("pre_encoder_layers", 0),
        segment_param=args.get("segment_param", "midpoint_dir_len"),
        qk_norm=args.get("qk_norm", True),
    ).to(device)

    # Handle torch.compile _orig_mod prefix in saved state-dict keys
    state = ckpt["model"]
    fixed = {k.replace("segmenter._orig_mod.", "segmenter."): v
             for k, v in state.items()}
    model.load_state_dict(fixed, strict=True)
    model.eval()
    return model
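
# Minimal usage sketch (same call as in __main__ below):
#   device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#   model = load_model(SCRIPT_DIR / "checkpoint.pt", device)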


def build_tokens_single(sample_dict, model, device):
    """Build the token tensor for a single sample (no DataLoader)."""
    xyz = torch.as_tensor(sample_dict["xyz_norm"], dtype=torch.float32).unsqueeze(0).to(device)
    cid = torch.as_tensor(sample_dict["class_id"], dtype=torch.long).unsqueeze(0).to(device)
    src = torch.as_tensor(sample_dict["source"], dtype=torch.long).unsqueeze(0).to(device)
    masks = torch.as_tensor(sample_dict["mask"], dtype=torch.bool).unsqueeze(0).to(device)

    B, T, _ = xyz.shape
    tok = model.tokenizer
    fourier = tok.pos_enc(xyz.reshape(-1, 3)).reshape(B, T, -1) \
        if tok.pos_enc is not None else xyz.new_zeros(B, T, 0)
    parts = [xyz, fourier, tok.label_emb(cid), tok.src_emb(src.clamp(0, 1))]

    if tok.behind_emb_dim > 0:
        if "behind" in sample_dict:
            beh = torch.as_tensor(sample_dict["behind"], dtype=torch.long).unsqueeze(0).to(device)
        else:
            beh = xyz.new_zeros(B, T, dtype=torch.long)
        parts.append(tok.behind_emb(beh))

    if tok.use_vote_features:
        if "n_views_voted" in sample_dict and "vote_frac" in sample_dict:
            # Whiten vote statistics with fixed constants; these must match training-time normalization.
            nv = ((torch.as_tensor(sample_dict["n_views_voted"], dtype=torch.float32).unsqueeze(0).to(device) - 2.7) / 1.0).unsqueeze(-1)
            vf = ((torch.as_tensor(sample_dict["vote_frac"], dtype=torch.float32).unsqueeze(0).to(device) - 0.5) / 0.25).unsqueeze(-1)
            parts.extend([nv, vf])
        else:
            parts.extend([xyz.new_zeros(B, T, 1), xyz.new_zeros(B, T, 1)])

    tokens = torch.cat(parts, dim=-1)
    return tokens, masks
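
# Per-point token layout (concatenated along the last dim): raw xyz, Fourier
# positional features, class embedding, source embedding, optional "behind"
# embedding, and the two vote statistics (zeros when absent).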


def predict_sample(sample_dict, model, device):
    """Run model inference + post-processing on a fused sample.

    Returns (vertices, edges) in world space.
    """
    tokens, masks = build_tokens_single(sample_dict, model, device)
    scale = float(sample_dict["scale"])
    center = sample_dict["center"]

    with torch.no_grad(), torch.autocast(device_type='cuda', dtype=torch.float16,
                                         enabled=(device.type == 'cuda')):
        out = model.forward_tokens(tokens, masks)

    segs = out["segments"][0].float().cpu()
    conf = torch.sigmoid(out["conf"][0].float()).cpu().numpy() if "conf" in out else None

    # Confidence filter
    if conf is not None:
        keep = conf > CONF_THRESH
        segs = segs[keep]
        if len(segs) < 1:
            return empty_solution()

    # To world space
    segs_world = segs.numpy() * scale + center

    # Vertices + edges from segments
    pv, pe = segments_to_vertices_edges(torch.tensor(segs_world))
    pv, pe = pv.numpy(), np.array(pe, dtype=np.int32)

    # Merge nearby vertices
    pv, pe = merge_vertices(pv, pe, MERGE_THRESH)

    # Snap to point cloud
    xyz_norm = sample_dict["xyz_norm"]
    mask = sample_dict["mask"]
    cid = sample_dict["class_id"]
    xyz_world = xyz_norm[mask] * scale + center
    cid_valid = cid[mask]
    pv = snap_to_point_cloud(pv, xyz_world, cid_valid, snap_radius=SNAP_RADIUS)

    # Horizontal snap
    pv = snap_horizontal(pv, pe)

    if len(pv) < 2 or len(pe) < 1:
        return empty_solution()

    edges = [(int(a), int(b)) for a, b in pe]
    return pv, edges
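
# Minimal end-to-end sketch for a single sample (mirrors the __main__ loop below):
#   fused = fuse_and_sample(sample, cfg, rng)
#   if fused is not None:
#       vertices, edges = predict_sample(fused, model, device)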


# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------

if __name__ == "__main__":
    t_start = time.time()

    # Load params
    param_path = Path("params.json")
    with param_path.open() as f:
        params = json.load(f)
    print(f"Competition: {params.get('competition_id', '?')}")
    print(f"Dataset: {params.get('dataset', '?')}")

    # Load test data
    data_path = Path("/tmp/data")
    if not data_path.exists():
        from huggingface_hub import snapshot_download
        snapshot_download(
            repo_id=params["dataset"],
            local_dir="/tmp/data",
            repo_type="dataset",
        )

    from datasets import load_dataset
    data_files = {
        "validation": [str(p) for p in data_path.rglob("*public*/**/*.tar")],
        "test": [str(p) for p in data_path.rglob("*private*/**/*.tar")],
    }
    print(f"Data files: {data_files}")
    dataset = load_dataset(
        str(data_path / "hoho22k_2026_test_x_anon.py"),
        data_files=data_files,
        trust_remote_code=True,
        writer_batch_size=100,
    )
    print(f"Loaded: {dataset}")

    # Load model
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")
    checkpoint_path = SCRIPT_DIR / "checkpoint.pt"
    model = load_model(checkpoint_path, device)
    print(f"Model loaded: {sum(p.numel() for p in model.parameters()):,} params")

    # Point fusion config
    cfg = FuserConfig()
    rng = np.random.RandomState(2718)

    # Process all samples
    solution = []
    total_samples = sum(len(dataset[s]) for s in dataset)
    processed = 0

    for subset_name in dataset:
        print(f"\nProcessing {subset_name} ({len(dataset[subset_name])} samples)...")

        for sample in tqdm(dataset[subset_name], desc=subset_name):
            order_id = sample["order_id"]

            # Fuse + sample
            fused = fuse_and_sample(sample, cfg, rng)
            if fused is None:
                pred_v, pred_e = empty_solution()
            else:
                try:
                    pred_v, pred_e = predict_sample(fused, model, device)
                except Exception as e:
                    print(f"  Predict failed for {order_id}: {e}")
                    pred_v, pred_e = empty_solution()

            solution.append({
                "order_id": order_id,
                "wf_vertices": pred_v.tolist() if isinstance(pred_v, np.ndarray) else pred_v,
                "wf_edges": [(int(a), int(b)) for a, b in pred_e],
            })
            processed += 1

            if processed % 50 == 0:
                elapsed = time.time() - t_start
                rate = elapsed / processed
                remaining = (total_samples - processed) * rate
                print(f"  [{processed}/{total_samples}] "
                      f"{elapsed:.0f}s elapsed, ~{remaining:.0f}s remaining")

    # Save
    with open("submission.json", "w") as f:
        json.dump(solution, f)

    elapsed = time.time() - t_start
    print(f"\nDone. {processed} samples in {elapsed:.0f}s ({elapsed/max(processed,1):.1f}s/sample)")
    print(f"Saved submission.json ({len(solution)} entries)")