Tags: Other · PyTorch · 3d-reconstruction · wireframe · building · point-cloud · s23dr · cvpr-2026
jacklangerman committed (verified) · commit f4487da · 1 parent: 4ddee35

Upload folder using huggingface_hub

checkpoint.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:cc38a61ff512948b1dc92a30129d6efdd093f507948fc5b538050c4a38bfbf6c
+ size 106460054
s23dr_2026_example/__init__.py ADDED
File without changes
s23dr_2026_example/attention.py ADDED
@@ -0,0 +1,226 @@
+ # custom_transformer.py
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ # =============================================================================
+ # Core Efficient Multihead Attention using Scaled Dot Product Attention (SDPA)
+ # =============================================================================
+
+ class MultiHeadSDPA(nn.Module):
+     """
+     Multi-head cross-attention using torch.nn.functional.scaled_dot_product_attention
+     without causal masking. Suitable for set inputs and cross-attention.
+
+     If qk_norm=True, L2-normalizes Q and K per-head before the dot product,
+     then scales by a learned per-head temperature (log_scale). This caps logit
+     magnitude to [-1, +1] * exp(log_scale), preventing attention entropy
+     collapse at large head_dim.
+     """
+     def __init__(self, d_model: int, num_heads: int, kv_heads: int | None = None,
+                  qk_norm: bool = False, qk_norm_type: str = "l2"):
+         super().__init__()
+         assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
+         self.d_model = d_model
+         self.num_heads = num_heads
+         self.kv_heads = kv_heads or num_heads
+         assert self.num_heads % self.kv_heads == 0, "kv_heads must divide num_heads"
+
+         self.head_dim = d_model // num_heads
+         self.qk_norm = qk_norm
+         self.qk_norm_type = qk_norm_type
+
+         # Input projection layers
+         self.q_proj = nn.Linear(d_model, d_model, bias=False)
+         self.k_proj = nn.Linear(d_model, self.kv_heads * self.head_dim, bias=False)
+         self.v_proj = nn.Linear(d_model, self.kv_heads * self.head_dim, bias=False)
+
+         # Output projection
+         self.out_proj = nn.Linear(d_model, d_model, bias=False)
+         nn.init.zeros_(self.out_proj.weight)
+
+         if qk_norm:
+             import math
+             if qk_norm_type == "rms":
+                 # Standard QK-norm (Qwen3/Gemma3 style): RMSNorm on Q and K,
+                 # no learned temperature. SDPA's 1/sqrt(d) scaling is sufficient
+                 # because RMSNorm preserves the expected logit variance.
+                 pass  # no extra parameters needed
+             else:
+                 # L2 + learned temperature (nGPT/ViT-22B style):
+                 # L2 projects to unit sphere, needs learned scale to compensate.
+                 self.log_scale = nn.Parameter(
+                     torch.full((num_heads,), math.log(math.sqrt(self.head_dim))))
+
+     def forward(
+         self,
+         query: torch.Tensor,
+         key: torch.Tensor,
+         key_padding_mask: torch.Tensor | None = None,
+     ) -> torch.Tensor:
+         # Project
+         q = self.q_proj(query)
+         k = self.k_proj(key)
+         v = self.v_proj(key)
+
+         B, Tq, _ = q.shape
+         _, Tk, _ = k.shape
+
+         q = q.view(B, Tq, self.num_heads, self.head_dim).transpose(1, 2)
+         k = k.view(B, Tk, self.kv_heads, self.head_dim).transpose(1, 2)
+         v = v.view(B, Tk, self.kv_heads, self.head_dim).transpose(1, 2)
+
+         if self.kv_heads != self.num_heads:
+             repeat = self.num_heads // self.kv_heads
+             k = k.repeat_interleave(repeat, dim=1)
+             v = v.repeat_interleave(repeat, dim=1)
+
+         if self.qk_norm:
+             if self.qk_norm_type == "rms":
+                 # RMSNorm (Qwen3/Gemma3 style): no learned temperature needed.
+                 # After RMSNorm, logit variance matches standard SDPA naturally.
+                 q = q * torch.rsqrt(q.square().mean(dim=-1, keepdim=True) + 1e-6)
+                 k = k * torch.rsqrt(k.square().mean(dim=-1, keepdim=True) + 1e-6)
+                 attn_mask = None
+                 if key_padding_mask is not None:
+                     attn_mask = ~key_padding_mask[:, None, None, :].to(dtype=torch.bool)
+                 attn_out = F.scaled_dot_product_attention(
+                     q, k, v, attn_mask=attn_mask, dropout_p=0.0, is_causal=False,
+                 )
+             else:
+                 # L2 + learned temperature (nGPT/ViT-22B style)
+                 q = F.normalize(q, dim=-1)
+                 k = F.normalize(k, dim=-1)
+                 scale = self.log_scale.exp().view(1, -1, 1, 1)
+                 q = q * scale
+                 attn_mask = None
+                 if key_padding_mask is not None:
+                     attn_mask = ~key_padding_mask[:, None, None, :].to(dtype=torch.bool)
+                 attn_out = F.scaled_dot_product_attention(
+                     q, k, v, attn_mask=attn_mask, dropout_p=0.0, is_causal=False,
+                     scale=1.0,
+                 )
+         else:
+             attn_mask = None
+             if key_padding_mask is not None:
+                 attn_mask = ~key_padding_mask[:, None, None, :].to(dtype=torch.bool)
+             attn_out = F.scaled_dot_product_attention(
+                 q, k, v, attn_mask=attn_mask, dropout_p=0.0, is_causal=False
+             )
+
+         attn_out = attn_out.transpose(1, 2).reshape(B, Tq, self.d_model)
+         return self.out_proj(attn_out)
+
+
+ # =============================================================================
+ # Transformer Feed-Forward Block
+ # =============================================================================
+
+ def _get_activation(name: str):
+     """Look up activation function by name. Supports 'relu_sq' for ReLU^2."""
+     if name == "relu_sq":
+         return lambda x: F.relu(x).square()
+     return getattr(F, name)
+
+
+ class FeedForward(nn.Module):
+     """
+     Position-wise MLP block: linear -> activation -> linear.
+     Supports 'gelu', 'relu', 'relu_sq', etc.
+     """
+     def __init__(self, d_model: int, dim_ff: int, activation: str = "gelu"):
+         super().__init__()
+         self.linear1 = nn.Linear(d_model, dim_ff)
+         self.linear2 = nn.Linear(dim_ff, d_model)
+         nn.init.zeros_(self.linear2.weight)
+         nn.init.zeros_(self.linear2.bias)
+         self.activation = _get_activation(activation)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         x = self.linear1(x)
+         return self.linear2(self.activation(x))
+
+
+ # =============================================================================
+ # Custom Transformer Block
+ # =============================================================================
+
+ class TransformerBlock(nn.Module):
+     """
+     Single transformer block combining:
+     - multi-head SDPA (non-causal)
+     - layernorm + residual
+     - feed-forward MLP + residual
+     """
+     def __init__(
+         self,
+         d_model: int,
+         num_heads: int,
+         dim_ff: int,
+         dropout: float = 0.0,
+         activation: str = "gelu",
+         kv_heads: int | None = None,
+     ):
+         super().__init__()
+         self.norm1 = nn.LayerNorm(d_model)
+         self.norm2 = nn.LayerNorm(d_model)
+
+         self.attn = MultiHeadSDPA(d_model, num_heads, kv_heads=kv_heads)
+         self.dropout1 = nn.Dropout(dropout)
+         self.ffn = FeedForward(d_model, dim_ff, activation=activation)
+         self.dropout2 = nn.Dropout(dropout)
+
+     def forward(
+         self,
+         x: torch.Tensor,
+         memory: torch.Tensor,
+         memory_key_padding_mask: torch.Tensor | None = None,
+     ) -> torch.Tensor:
+         res = x
+         x = self.norm1(x)
+         x = self.attn(x, memory, key_padding_mask=memory_key_padding_mask)
+         x = res + self.dropout1(x)
+
+         res = x
+         x = self.norm2(x)
+         x = self.ffn(x)
+         return res + self.dropout2(x)
+
+
+ class TransformerDecoderSets(nn.Module):
+     """
+     A stack of TransformerBlock layers for set-to-set
+     modeling without causal masks.
+     """
+     def __init__(
+         self,
+         d_model: int,
+         num_heads: int,
+         dim_ff: int,
+         num_layers: int,
+         dropout: float = 0.0,
+         activation: str = "gelu",
+         kv_heads: int | None = None,
+     ):
+         super().__init__()
+         self.layers = nn.ModuleList([
+             TransformerBlock(
+                 d_model,
+                 num_heads,
+                 dim_ff,
+                 dropout=dropout,
+                 activation=activation,
+                 kv_heads=kv_heads,
+             )
+             for _ in range(num_layers)
+         ])
+
+     def forward(
+         self,
+         tgt: torch.Tensor,
+         memory: torch.Tensor,
+         memory_key_padding_mask: torch.Tensor | None = None,
+     ) -> torch.Tensor:
+         for layer in self.layers:
+             tgt = layer(tgt, memory, memory_key_padding_mask=memory_key_padding_mask)
+         return tgt
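
Editor's note: a minimal smoke test for the module above (my sketch, not part of the upload). It assumes only what the code itself defines, plus the key-padding convention implied by the `~key_padding_mask` inversion in `forward`: True marks padded key positions.

    import torch
    from s23dr_2026_example.attention import TransformerDecoderSets

    decoder = TransformerDecoderSets(d_model=256, num_heads=8, dim_ff=1024,
                                     num_layers=4, kv_heads=2)  # GQA: 8 query heads share 2 KV heads

    tgt = torch.randn(2, 64, 256)       # e.g. 64 learned queries per scene
    memory = torch.randn(2, 2048, 256)  # e.g. 2048 point tokens per scene
    pad = torch.zeros(2, 2048, dtype=torch.bool)
    pad[:, 1500:] = True                # True = padded memory position (masked out)

    out = decoder(tgt, memory, memory_key_padding_mask=pad)
    assert out.shape == (2, 64, 256)

Note that out_proj and the FFN's second linear are zero-initialized, so at initialization each block is an identity mapping on tgt; training gradually opens up the residual branches.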
s23dr_2026_example/bad_samples.txt ADDED
@@ -0,0 +1,156 @@
+ 14b1872e960
+ 1807ef90db4
+ 180e6a67e87
+ 1ad5c6bd31f
+ 1c3f939ad93
+ 1ede4c0d52f
+ 214f17d9cc4
+ 22256d88df9
+ 24a92a8de6d
+ 24b4e984bad
+ 2565978cf53
+ 2a71f1a2072
+ 2d44c1fade6
+ 2ebed43823a
+ 33982551420
+ 3b480496f82
+ 412a2bdf7a4
+ 44343bbabbb
+ 4a0b3f04cbd
+ 4a7fa170826
+ 4b7dc027214
+ 4e0dc2c9b18
+ 5172a516c8b
+ 529e8f15cd2
+ 56fc6f6f163
+ 575963ce814
+ 578ec40a278
+ 5a0c07c575a
+ 5d521223c26
+ 6148b5c9461
+ 631eb6d7c03
+ 655a14f8a75
+ 66502d7ee6f
+ 6da76fc6687
+ 777eaaad0ca
+ 7a4e2909d68
+ 7c5c9baf483
+ 80806dfd75e
+ 81a4ead431d
+ 833152dd554
+ 85797868c0f
+ 86460ad8181
+ 86783a6bee4
+ 95193322d7a
+ 99a9d056200
+ 9b1d4eeaab9
+ 9ff759f2e4c
+ acbd243da16
+ b9b275710c0
+ beceaa9bb7c
+ c243d079286
+ c5c7337d2cb
+ cdf6f2d3b35
+ cfe370f1c87
+ d4a72aea80c
+ d655f066cd3
+ d79e8d9455c
+ d7d6c5be76e
+ dc30ae4b93b
+ de9495f7ca3
+ e1901819c72
+ e1d88c1a6b1
+ e5d3eb0a617
+ ec11d3cdcf6
+ ecb21fad0ad
+ ee55d8c6493
+ ee7e6d4dee1
+ 008052054aa
+ 03ecb7d3cf3
+ 0555a655534
+ 099cad230c6
+ 0d061ae23f0
+ 10741a421c0
+ 110d5e407b9
+ 128a7fb415a
+ 13177736b26
+ 1635d73bf7d
+ 18a760de9ea
+ 18d90d03e95
+ 209627a5c1a
+ 21e3cd4b7b8
+ 22f5499200d
+ 266eb64de68
+ 269235f770b
+ 2758490e558
+ 2a203cf5d35
+ 2a878ec47ab
+ 2cb43eb2201
+ 393298e282b
+ 395abe6aac7
+ 3d19c7a4ca3
+ 44e2b719b1e
+ 45039819fcc
+ 4cb4ff01619
+ 4e5eb5712fa
+ 4e988765a6d
+ 5077bf42714
+ 55ed69b2622
+ 5ae3b651a37
+ 5ca1edeed4c
+ 5daa76b1c7f
+ 5fdd11dfae5
+ 6078cf180c2
+ 6682b309e9c
+ 6c02d2038c0
+ 71c595506c8
+ 73c8f960c18
+ 74ccc8fd057
+ 7a34156a798
+ 7ac7af9f59c
+ 7f2ec0ea179
+ 823b837b36c
+ 82d7600f9a3
+ 848161a2900
+ 88cedf129eb
+ 8dec106b6a6
+ 8e335d08ca4
+ 8ecf7c58193
+ 8fa55008beb
+ 90e09de2301
+ 9197acc0b9d
+ 954c25e876c
+ 98517d5563d
+ 99e717a0148
+ 9a0c0635bd7
+ 9ad436b7b3d
+ 9be351cbf14
+ 9e2a2e51798
+ a84a7ea9220
+ aa8cb84d3eb
+ b07977292da
+ b3e33456f0b
+ b7823de373e
+ bac379382d9
+ bd2d9bf67a3
+ c14584a84cd
+ c497170c970
+ cd8e767612b
+ d17917bb279
+ d42b9d432a9
+ d53d8857a85
+ d6808cf3d98
+ d6f509d1dd9
+ d7abd08e643
+ d83493bf974
+ d87293651ee
+ da9d4ac9e8e
+ daa1702791a
+ dcb12411c14
+ de9ab9cdd5b
+ df906c58a3c
+ e3870649eb5
+ ea90aed9b98
+ ecaa81b9711
+ efc1238665b
+ c5a65219daf
s23dr_2026_example/cache_scenes.py ADDED
@@ -0,0 +1,373 @@
+ #!/usr/bin/env python3
+ """Cache compact scenes from HoHo22k shards to training-ready .pt files.
+
+ Runs build_compact_scene + precomputes group_id, semantic class, and
+ normalization so training only needs fast sampling + GPU forward.
+
+ Usage:
+     python cache_scenes.py --data-dir data/ --out-dir cache/train
+     python cache_scenes.py --streaming --out-dir cache/train --limit 5000
+     python cache_scenes.py --data-dir data/ --out-dir cache/train --workers 4
+
+ Cache format per file (.pt):
+     xyz:         float32 [P, 3]  all points in world space
+     source:      uint8   [P]     0=colmap, 1=depth
+     group_id:    int8    [P]     priority tier 0-4, -1=excluded
+     class_id:    uint8   [P]     one-hot class index (0-12), see SEMANTIC_CLASSES
+     visible_src: uint8   [P]     for visualization (1=gestalt, 2=ade)
+     visible_id:  int16   [P]     for visualization (class id within space)
+     center:      float32 [3]     smart normalization center
+     scale:       float32 scalar  smart normalization scale
+     gt_vertices: float32 [V, 3]  ground truth wireframe vertices
+     gt_edges:    int32   [E, 2]  ground truth wireframe edge indices
+ """
+ from __future__ import annotations
+
+ import sys
+ from pathlib import Path as _Path
+ if __package__ is None or __package__ == "":
+     _here = _Path(__file__).resolve().parent
+     if str(_here.parent) not in sys.path:
+         sys.path.insert(0, str(_here.parent))
+     __package__ = _here.name
+
+ import argparse
+ import time
+ from concurrent.futures import ProcessPoolExecutor, as_completed
+ from pathlib import Path
+
+ import numpy as np
+ import torch
+
+ from .point_fusion import (
+     FuserConfig, build_compact_scene,
+     GEST_ID_TO_NAME, ADE_ID_TO_NAME, NUM_GEST,
+ )
+
+ # ---------------------------------------------------------------------------
+ # Semantic class encoding: 11 structural + 1 other_house + 1 non_house = 13
+ # ---------------------------------------------------------------------------
+
+ # Each structural gestalt class gets its own one-hot bit.
+ STRUCTURAL_CLASSES = (
+     "apex", "eave_end_point", "flashing_end_point",  # point classes (tier 0)
+     "rake", "ridge", "eave", "hip", "valley",        # roof edges (tier 1)
+     "flashing", "step_flashing",
+     "roof",                                          # roof face (tier 2)
+ )
+ # Index 11 = other house part (door, window, siding, etc.)
+ # Index 12 = non-house / ADE / unlabeled
+ NUM_SEMANTIC_CLASSES = len(STRUCTURAL_CLASSES) + 2  # 13
+
+ # Priority tiers (same as tokenizer.py)
+ _GEST_NAME_TO_ID = {n: i for i, n in enumerate(GEST_ID_TO_NAME)}
+ _POINT_IDS = {_GEST_NAME_TO_ID[n] for n in ("apex", "eave_end_point", "flashing_end_point") if n in _GEST_NAME_TO_ID}
+ _EDGE_IDS = {_GEST_NAME_TO_ID[n] for n in ("rake", "ridge", "eave", "hip", "valley", "flashing", "step_flashing") if n in _GEST_NAME_TO_ID}
+ _FACE_IDS = {_GEST_NAME_TO_ID[n] for n in ("roof",) if n in _GEST_NAME_TO_ID}
+ _HOUSE_IDS = {_GEST_NAME_TO_ID[n] for n in (
+     "apex", "eave_end_point", "flashing_end_point",
+     "rake", "ridge", "eave", "hip", "valley", "flashing", "step_flashing",
+     "roof", "door", "garage", "window", "shutter", "fascia", "soffit",
+     "horizontal_siding", "vertical_siding", "brick", "concrete",
+     "other_wall", "trim", "post", "ground_line",
+ ) if n in _GEST_NAME_TO_ID}
+
+ _ADE_NAME_TO_ID = {n.lower(): i for i, n in enumerate(ADE_ID_TO_NAME)}
+ _ADE_HOUSE_IDS = {_ADE_NAME_TO_ID[n] for n in ("building;edifice", "house", "wall", "windowpane;window", "door;double;door") if n in _ADE_NAME_TO_ID}
+
+ _UNCLS_ID = _GEST_NAME_TO_ID.get("unclassified", -1)
+
+ # Map structural gestalt names to one-hot index
+ _STRUCTURAL_ONEHOT = {}
+ for idx, name in enumerate(STRUCTURAL_CLASSES):
+     gid = _GEST_NAME_TO_ID.get(name)
+     if gid is not None:
+         _STRUCTURAL_ONEHOT[gid] = idx
+
+
+ def _compute_group_and_class(visible_src, visible_id, behind_id, source):
+     """Compute priority group_id and semantic class_id per point (vectorized).
+
+     Args:
+         visible_src: uint8 [P] -- 0=unlabeled, 1=gestalt, 2=ade
+         visible_id:  int16 [P] -- class id within gestalt or ade space
+         behind_id:   int16 [P] -- behind-gestalt id (-1 if none)
+         source:      uint8 [P] -- 0=colmap, 1=depth
+
+     Returns:
+         group_id: int8  [P] -- priority tier 0-4, -1 for excluded (unclassified)
+         class_id: uint8 [P] -- one-hot class index 0-12
+     """
+     P = len(visible_src)
+     vsrc = visible_src.astype(np.int32)
+     vid = visible_id.astype(np.int32)
+     bid = behind_id.astype(np.int32)
+
+     # Effective gestalt id: prefer visible gestalt, fall back to behind
+     gest_id = np.full(P, -1, dtype=np.int32)
+     has_vis_gest = (vsrc == 1) & (vid >= 0)
+     has_behind = (bid >= 0) & ~has_vis_gest
+     gest_id[has_vis_gest] = vid[has_vis_gest]
+     gest_id[has_behind] = bid[has_behind]
+
+     # Exclude unclassified points
+     if _UNCLS_ID >= 0:
+         is_uncls = ((vsrc == 1) & (vid == _UNCLS_ID)) | (bid == _UNCLS_ID)
+         gest_id[is_uncls] = -1  # force excluded
+
+     # Build lookup arrays for gestalt id -> group and gestalt id -> class
+     max_gid = NUM_GEST
+     gid_to_group = np.full(max_gid, 4, dtype=np.int8)  # default: tier 4
+     gid_to_class = np.full(max_gid, NUM_SEMANTIC_CLASSES - 1, dtype=np.uint8)  # default: non-house
+
+     for gid in _POINT_IDS:
+         gid_to_group[gid] = 0
+     for gid in _EDGE_IDS:
+         gid_to_group[gid] = 1
+     for gid in _FACE_IDS:
+         gid_to_group[gid] = 2
+     for gid in _HOUSE_IDS - _POINT_IDS - _EDGE_IDS - _FACE_IDS:
+         gid_to_group[gid] = 3
+     for gid, onehot_idx in _STRUCTURAL_ONEHOT.items():
+         gid_to_class[gid] = onehot_idx
+     for gid in _HOUSE_IDS - set(_STRUCTURAL_ONEHOT.keys()):
+         gid_to_class[gid] = len(STRUCTURAL_CLASSES)  # other_house
+
+     # Apply lookup for points with valid gestalt ids
+     has_gest = gest_id >= 0
+     group_id = np.full(P, 4, dtype=np.int8)  # default: tier 4
+     class_id = np.full(P, NUM_SEMANTIC_CLASSES - 1, dtype=np.uint8)  # default: non-house
+
+     group_id[has_gest] = gid_to_group[gest_id[has_gest]]
+     class_id[has_gest] = gid_to_class[gest_id[has_gest]]
+
+     # ADE house points (no gestalt) get tier 3 + class_id = other_house
+     ade_house_arr = np.array(sorted(_ADE_HOUSE_IDS), dtype=np.int32)
+     is_ade_house = ~has_gest & (vsrc == 2) & (vid >= 0) & np.isin(vid, ade_house_arr)
+     group_id[is_ade_house] = 3
+     class_id[is_ade_house] = len(STRUCTURAL_CLASSES)  # other_house (index 11)
+
+     # Mark excluded points (unclassified) as -1
+     if _UNCLS_ID >= 0:
+         group_id[is_uncls] = -1
+         class_id[is_uncls] = NUM_SEMANTIC_CLASSES - 1
+
+     return group_id, class_id
+
+
+ def _compute_smart_center_scale(xyz, source, mad_k=2.5, percentile=95.0,
+                                 max_points=8000):
+     """Compute normalization center and scale from depth points with MAD filter."""
+     depth_mask = source == 1
+     ref = xyz[depth_mask] if depth_mask.any() else xyz
+     if ref.shape[0] == 0:
+         center = xyz.mean(axis=0)
+         scale = max(np.linalg.norm(xyz - center, axis=1).max(), 1e-6)
+         return center.astype(np.float32), np.float32(scale)
+
+     if ref.shape[0] > max_points:
+         idx = np.random.choice(ref.shape[0], max_points, replace=False)
+         ref = ref[idx]
+
+     center0 = np.median(ref, axis=0)
+     dist = np.linalg.norm(ref - center0, axis=1)
+     med = np.median(dist)
+     mad = max(np.median(np.abs(dist - med)), 1e-6)
+     inliers = dist <= (med + mad_k * mad)
+     if inliers.any():
+         ref = ref[inliers]
+
+     # Percentile bounding box
+     lo_f = (100.0 - percentile) * 0.5 / 100.0
+     sorted_v = np.sort(ref, axis=0)
+     n = sorted_v.shape[0]
+     lo_idx = max(0, min(n - 1, int(lo_f * (n - 1))))
+     hi_idx = max(0, min(n - 1, int((1.0 - lo_f) * (n - 1))))
+     low = sorted_v[lo_idx]
+     high = sorted_v[hi_idx]
+
+     center = 0.5 * (low + high)
+     scale = max(np.sqrt(((high - low) ** 2).sum()), 1e-6)
+     return center.astype(np.float32), np.float32(scale)
+
+
+ def _process_one(sample, cfg):
+     """Process a single HF sample into a cache dict. Returns (order_id, dict) or None."""
+     rng = np.random.RandomState()  # worker-local rng
+
+     n_edges = len(sample.get("wf_edges", []))
+     if n_edges == 0 or n_edges > 64:
+         return None
+
+     scene = build_compact_scene(sample, cfg, rng=rng)
+     if scene is None:
+         return None
+
+     gt_v = scene.get("gt_vertices")
+     gt_e = scene.get("gt_edges")
+     if gt_v is None or gt_e is None or len(gt_e) == 0:
+         return None
+
+     xyz = scene["xyz"]
+     source = scene["source"]
+     visible_src = scene["visible_src"]
+     visible_id = scene["visible_id"]
+     behind_id = scene["behind_gest_id"]
+
+     group_id, class_id = _compute_group_and_class(
+         visible_src, visible_id, behind_id, source
+     )
+
+     center, scale = _compute_smart_center_scale(xyz, source)
+
+     order_id = sample.get("order_id", "unknown")
+
+     return order_id, {
+         "xyz": xyz.astype(np.float32),
+         "source": source.astype(np.uint8),
+         "group_id": group_id,
+         "class_id": class_id,
+         "behind_gest_id": behind_id.astype(np.int16),
+         "visible_src": visible_src.astype(np.uint8),
+         "visible_id": visible_id.astype(np.int16),
+         "n_views_voted": scene["n_views_voted"],
+         "vote_frac": scene["vote_frac"],
+         "center": center,
+         "scale": scale,
+         "gt_vertices": gt_v.astype(np.float32),
+         "gt_edges": gt_e.astype(np.int32),
+     }
+
+
+ def main():
+     p = argparse.ArgumentParser(description="Cache compact scenes from HoHo22k")
+     g = p.add_mutually_exclusive_group(required=True)
+     g.add_argument("--data-dir", help="Local dir with shards")
+     g.add_argument("--streaming", action="store_true", help="Stream from HuggingFace")
+     p.add_argument("--out-dir", required=True, help="Output directory for .pt files")
+     p.add_argument("--limit", type=int, default=0)
+     p.add_argument("--depth-per-view", type=int, default=8000)
+     p.add_argument("--workers", type=int, default=0,
+                    help="Parallel workers (0=sequential)")
+     p.add_argument("--skip-existing", action="store_true",
+                    help="Skip samples whose .pt already exists in out-dir")
+     p.add_argument("--shard-start", type=int, default=0,
+                    help="First shard index (for parallel launches)")
+     p.add_argument("--shard-stride", type=int, default=1,
+                    help="Stride between shards (e.g. 8 means take every 8th shard)")
+     args = p.parse_args()
+
+     out_dir = Path(args.out_dir)
+     out_dir.mkdir(parents=True, exist_ok=True)
+     existing_ids = set(p.stem for p in out_dir.glob("*.pt")) if args.skip_existing else set()
+
+     # Load dataset
+     from datasets import load_dataset
+     if args.streaming:
+         ds = load_dataset(
+             "usm3d/hoho22k_2026_trainval",
+             streaming=True, trust_remote_code=True, split="train",
+         )
+     else:
+         data_root = Path(args.data_dir).resolve()
+         tars = []
+         for candidate in [data_root / "data" / "train", data_root / "train", data_root]:
+             if candidate.exists():
+                 tars = sorted(str(p) for p in candidate.glob("*.tar"))
+                 if tars:
+                     break
+         loader = None
+         for c in [data_root / "hoho22k_2026_trainval.py"]:
+             if c.exists():
+                 loader = c
+                 break
+         if loader is None:
+             found = list(data_root.rglob("hoho22k_2026_trainval.py"))
+             loader = found[0] if found else None
+         if loader is None:
+             raise FileNotFoundError("Cannot find loader script")
+         # Shard-level parallelism: each process handles a slice of tars
+         if args.shard_stride > 1:
+             tars = tars[args.shard_start::args.shard_stride]
+             print(f"Shard slice: start={args.shard_start} stride={args.shard_stride} -> {len(tars)} shards")
+         ds = load_dataset(str(loader), data_files={"train": tars},
+                           streaming=True, trust_remote_code=True, split="train")
+
+     cfg = FuserConfig(depth_points_per_view=args.depth_per_view)
+
+     saved = 0
+     skipped = 0
+     t_start = time.perf_counter()
+
+     if args.workers > 0:
+         # Parallel: collect samples into batches, process in worker pool.
+         # Note: HF streaming datasets can't be shared across workers, so we
+         # iterate in the main thread and dispatch processing to workers.
+         with ProcessPoolExecutor(max_workers=args.workers) as pool:
+             futures = {}
+             for i, sample in enumerate(ds):
+                 if args.limit > 0 and i >= args.limit:
+                     break
+                 oid = sample.get("order_id", "unknown")
+                 if oid in existing_ids:
+                     skipped += 1
+                     continue
+                 future = pool.submit(_process_one, sample, cfg)
+                 futures[future] = i
+
+                 # Drain completed futures to bound memory
+                 if len(futures) >= args.workers * 4:
+                     done = [f for f in futures if f.done()]
+                     for f in done:
+                         result = f.result()
+                         del futures[f]
+                         if result is None:
+                             skipped += 1
+                             continue
+                         order_id, data = result
+                         torch.save(data, out_dir / f"{order_id}.pt")
+                         saved += 1
+                         if saved % 50 == 0:
+                             elapsed = time.perf_counter() - t_start
+                             print(f"Saved {saved} (skipped {skipped}) "
+                                   f"[{saved / elapsed:.1f} samples/s]")
+
+             # Drain remaining
+             for f in as_completed(futures):
+                 result = f.result()
+                 if result is None:
+                     skipped += 1
+                     continue
+                 order_id, data = result
+                 torch.save(data, out_dir / f"{order_id}.pt")
+                 saved += 1
+     else:
+         # Sequential
+         for i, sample in enumerate(ds):
+             if args.limit > 0 and i >= args.limit:
+                 break
+             oid = sample.get("order_id", "unknown")
+             if oid in existing_ids:
+                 skipped += 1
+                 continue
+
+             result = _process_one(sample, cfg)
+             if result is None:
+                 skipped += 1
+                 continue
+             order_id, data = result
+             torch.save(data, out_dir / f"{order_id}.pt")
+             saved += 1
+
+             if saved % 50 == 0:
+                 elapsed = time.perf_counter() - t_start
+                 print(f"Saved {saved} (skipped {skipped}) "
+                       f"[{saved / elapsed:.1f} samples/s]")
+
+     elapsed = time.perf_counter() - t_start
+     print(f"Done. Saved {saved}, skipped {skipped} in {elapsed:.0f}s "
+           f"({saved / elapsed:.1f} samples/s)")
+
+
+ if __name__ == "__main__":
+     main()
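
Editor's note: a hedged sketch of a consumer of the cache format documented in the module docstring above (my example, not part of the upload; the path is illustrative, field names and dtypes follow the docstring):

    import torch

    scene = torch.load("cache/train/example.pt")    # illustrative path
    xyz = torch.as_tensor(scene["xyz"])             # [P, 3] world-space points
    keep = torch.as_tensor(scene["group_id"]) >= 0  # drop excluded (-1) points

    # Smart normalization: shift by the precomputed center, divide by scale.
    center = torch.as_tensor(scene["center"])
    xyz_norm = (xyz[keep] - center) / float(scene["scale"])

    gt_v = torch.as_tensor(scene["gt_vertices"])    # [V, 3]
    gt_e = torch.as_tensor(scene["gt_edges"])       # [E, 2] indices into gt_v
    print(xyz_norm.shape, gt_v.shape, gt_e.shape)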
s23dr_2026_example/color_mappings.py ADDED
@@ -0,0 +1,209 @@
+ gestalt_color_mapping = {
+     "unclassified": (215, 62, 138),
+     "apex": (235, 88, 48),
+     "eave_end_point": (248, 130, 228),
+     "flashing_end_point": (71, 11, 161),
+     "ridge": (214, 251, 248),
+     "rake": (13, 94, 47),
+     "eave": (54, 243, 63),
+     "post": (187, 123, 236),
+     "ground_line": (136, 206, 14),
+     "flashing": (162, 162, 32),
+     "step_flashing": (169, 255, 219),
+     "hip": (8, 89, 52),
+     "valley": (85, 27, 65),
+     "roof": (215, 232, 179),
+     "door": (110, 52, 23),
+     "garage": (50, 233, 171),
+     "window": (230, 249, 40),
+     "shutter": (122, 4, 233),
+     "fascia": (95, 230, 240),
+     "soffit": (2, 102, 197),
+     "horizontal_siding": (131, 88, 59),
+     "vertical_siding": (110, 187, 198),
+     "brick": (171, 252, 7),
+     "concrete": (32, 47, 246),
+     "other_wall": (112, 61, 240),
+     "trim": (151, 206, 58),
+     "unknown": (127, 127, 127),
+     "transition_line": (0, 0, 0),
+ }
+
+ ade20k_color_mapping = {
+     'wall': (120, 120, 120),
+     'building;edifice': (180, 120, 120),
+     'sky': (6, 230, 230),
+     'floor;flooring': (80, 50, 50),
+     'tree': (4, 200, 3),
+     'ceiling': (120, 120, 80),
+     'road;route': (140, 140, 140),
+     'bed': (204, 5, 255),
+     'windowpane;window': (230, 230, 230),
+     'grass': (4, 250, 7),
+     'cabinet': (224, 5, 255),
+     'sidewalk;pavement': (235, 255, 7),
+     'person;individual;someone;somebody;mortal;soul': (150, 5, 61),
+     'earth;ground': (120, 120, 70),
+     'door;double;door': (8, 255, 51),
+     'table': (255, 6, 82),
+     'mountain;mount': (143, 255, 140),
+     'plant;flora;plant;life': (204, 255, 4),
+     'curtain;drape;drapery;mantle;pall': (255, 51, 7),
+     'chair': (204, 70, 3),
+     'car;auto;automobile;machine;motorcar': (0, 102, 200),
+     'water': (61, 230, 250),
+     'painting;picture': (255, 6, 51),
+     'sofa;couch;lounge': (11, 102, 255),
+     'shelf': (255, 7, 71),
+     'house': (255, 9, 224),
+     'sea': (9, 7, 230),
+     'mirror': (220, 220, 220),
+     'rug;carpet;carpeting': (255, 9, 92),
+     'field': (112, 9, 255),
+     'armchair': (8, 255, 214),
+     'seat': (7, 255, 224),
+     'fence;fencing': (255, 184, 6),
+     'desk': (10, 255, 71),
+     'rock;stone': (255, 41, 10),
+     'wardrobe;closet;press': (7, 255, 255),
+     'lamp': (224, 255, 8),
+     'bathtub;bathing;tub;bath;tub': (102, 8, 255),
+     'railing;rail': (255, 61, 6),
+     'cushion': (255, 194, 7),
+     'base;pedestal;stand': (255, 122, 8),
+     'box': (0, 255, 20),
+     'column;pillar': (255, 8, 41),
+     'signboard;sign': (255, 5, 153),
+     'chest;of;drawers;chest;bureau;dresser': (6, 51, 255),
+     'counter': (235, 12, 255),
+     'sand': (160, 150, 20),
+     'sink': (0, 163, 255),
+     'skyscraper': (140, 140, 140),
+     'fireplace;hearth;open;fireplace': (250, 10, 15),
+     'refrigerator;icebox': (20, 255, 0),
+     'grandstand;covered;stand': (31, 255, 0),
+     'path': (255, 31, 0),
+     'stairs;steps': (255, 224, 0),
+     'runway': (153, 255, 0),
+     'case;display;case;showcase;vitrine': (0, 0, 255),
+     'pool;table;billiard;table;snooker;table': (255, 71, 0),
+     'pillow': (0, 235, 255),
+     'screen;door;screen': (0, 173, 255),
+     'stairway;staircase': (31, 0, 255),
+     'river': (11, 200, 200),
+     'bridge;span': (255, 82, 0),
+     'bookcase': (0, 255, 245),
+     'blind;screen': (0, 61, 255),
+     'coffee;table;cocktail;table': (0, 255, 112),
+     'toilet;can;commode;crapper;pot;potty;stool;throne': (0, 255, 133),
+     'flower': (255, 0, 0),
+     'book': (255, 163, 0),
+     'hill': (255, 102, 0),
+     'bench': (194, 255, 0),
+     'countertop': (0, 143, 255),
+     'stove;kitchen;stove;range;kitchen;range;cooking;stove': (51, 255, 0),
+     'palm;palm;tree': (0, 82, 255),
+     'kitchen;island': (0, 255, 41),
+     'computer;computing;machine;computing;device;data;processor;electronic;computer;information;processing;system': (0, 255, 173),
+     'swivel;chair': (10, 0, 255),
+     'boat': (173, 255, 0),
+     'bar': (0, 255, 153),
+     'arcade;machine': (255, 92, 0),
+     'hovel;hut;hutch;shack;shanty': (255, 0, 255),
+     'bus;autobus;coach;charabanc;double-decker;jitney;motorbus;motorcoach;omnibus;passenger;vehicle': (255, 0, 245),
+     'towel': (255, 0, 102),
+     'light;light;source': (255, 173, 0),
+     'truck;motortruck': (255, 0, 20),
+     'tower': (255, 184, 184),
+     'chandelier;pendant;pendent': (0, 31, 255),
+     'awning;sunshade;sunblind': (0, 255, 61),
+     'streetlight;street;lamp': (0, 71, 255),
+     'booth;cubicle;stall;kiosk': (255, 0, 204),
+     'television;television;receiver;television;set;tv;tv;set;idiot;box;boob;tube;telly;goggle;box': (0, 255, 194),
+     'airplane;aeroplane;plane': (0, 255, 82),
+     'dirt;track': (0, 10, 255),
+     'apparel;wearing;apparel;dress;clothes': (0, 112, 255),
+     'pole': (51, 0, 255),
+     'land;ground;soil': (0, 194, 255),
+     'bannister;banister;balustrade;balusters;handrail': (0, 122, 255),
+     'escalator;moving;staircase;moving;stairway': (0, 255, 163),
+     'ottoman;pouf;pouffe;puff;hassock': (255, 153, 0),
+     'bottle': (0, 255, 10),
+     'buffet;counter;sideboard': (255, 112, 0),
+     'poster;posting;placard;notice;bill;card': (143, 255, 0),
+     'stage': (82, 0, 255),
+     'van': (163, 255, 0),
+     'ship': (255, 235, 0),
+     'fountain': (8, 184, 170),
+     'conveyer;belt;conveyor;belt;conveyer;conveyor;transporter': (133, 0, 255),
+     'canopy': (0, 255, 92),
+     'washer;automatic;washer;washing;machine': (184, 0, 255),
+     'plaything;toy': (255, 0, 31),
+     'swimming;pool;swimming;bath;natatorium': (0, 184, 255),
+     'stool': (0, 214, 255),
+     'barrel;cask': (255, 0, 112),
+     'basket;handbasket': (92, 255, 0),
+     'waterfall;falls': (0, 224, 255),
+     'tent;collapsible;shelter': (112, 224, 255),
+     'bag': (70, 184, 160),
+     'minibike;motorbike': (163, 0, 255),
+     'cradle': (153, 0, 255),
+     'oven': (71, 255, 0),
+     'ball': (255, 0, 163),
+     'food;solid;food': (255, 204, 0),
+     'step;stair': (255, 0, 143),
+     'tank;storage;tank': (0, 255, 235),
+     'trade;name;brand;name;brand;marque': (133, 255, 0),
+     'microwave;microwave;oven': (255, 0, 235),
+     'pot;flowerpot': (245, 0, 255),
+     'animal;animate;being;beast;brute;creature;fauna': (255, 0, 122),
+     'bicycle;bike;wheel;cycle': (255, 245, 0),
+     'lake': (10, 190, 212),
+     'dishwasher;dish;washer;dishwashing;machine': (214, 255, 0),
+     'screen;silver;screen;projection;screen': (0, 204, 255),
+     'blanket;cover': (20, 0, 255),
+     'sculpture': (255, 255, 0),
+     'hood;exhaust;hood': (0, 153, 255),
+     'sconce': (0, 41, 255),
+     'vase': (0, 255, 204),
+     'traffic;light;traffic;signal;stoplight': (41, 0, 255),
+     'tray': (41, 255, 0),
+     'ashcan;trash;can;garbage;can;wastebin;ash;bin;ash-bin;ashbin;dustbin;trash;barrel;trash;bin': (173, 0, 255),
+     'fan': (0, 245, 255),
+     'pier;wharf;wharfage;dock': (71, 0, 255),
+     'crt;screen': (122, 0, 255),
+     'plate': (0, 255, 184),
+     'monitor;monitoring;device': (0, 92, 255),
+     'bulletin;board;notice;board': (184, 255, 0),
+     'shower': (0, 133, 255),
+     'radiator': (255, 214, 0),
+     'glass;drinking;glass': (25, 194, 194),
+     'clock': (102, 255, 0),
+     'flag': (92, 0, 255),
+ }
+
+
+ EDGE_CLASSES = {
+     'cornice_return': 0,
+     'cornice_strip': 1,
+     'eave': 2,
+     'flashing': 3,
+     'hip': 4,
+     'rake': 5,
+     'ridge': 6,
+     'step_flashing': 7,
+     'transition_line': 8,
+     'valley': 9,
+ }
+ EDGE_CLASSES_BY_ID = {v: k for k, v in EDGE_CLASSES.items()}
+
+ edge_color_mapping = {
+     'cornice_return': (215, 62, 138),
+     'cornice_strip': (235, 88, 48),
+     'eave': (54, 243, 63),
+     'flashing': (162, 162, 32),
+     'hip': (8, 89, 52),
+     'rake': (13, 94, 47),
+     'ridge': (214, 251, 248),
+     'step_flashing': (169, 255, 219),
+     'transition_line': (200, 0, 50),
+     'valley': (85, 27, 65),
+ }
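
Editor's note: one common downstream use of these mappings is as an id-indexed lookup table for colorizing per-point labels. A small sketch (mine, not part of the upload), assuming gestalt class ids follow the insertion order of gestalt_color_mapping, i.e. the GEST_ID_TO_NAME convention referenced in cache_scenes.py:

    import numpy as np
    from s23dr_2026_example.color_mappings import gestalt_color_mapping

    # [C, 3] uint8 LUT in dict insertion order (assumed to match class ids)
    lut = np.array(list(gestalt_color_mapping.values()), dtype=np.uint8)

    def colorize(class_ids: np.ndarray) -> np.ndarray:
        """Map per-point gestalt ids [P] to RGB [P, 3]; out-of-range ids clamp."""
        return lut[np.clip(class_ids, 0, len(lut) - 1)]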
s23dr_2026_example/data.py ADDED
@@ -0,0 +1,237 @@
+ """Data loading for pre-sampled HF datasets.
+
+ Expects pre-sampled npz blobs with xyz_norm [2048, 3] (not full PCD).
+ Use make_sampled_cache.py to produce these from full point clouds.
+ """
+ from __future__ import annotations
+
+ from pathlib import Path
+
+ import numpy as np
+ import torch
+
+ from .tokenizer import EdgeDepthSequenceConfig
+
+ # Default token budget (must match make_sampled_cache.py)
+ SEQ_LEN = 2048
+ COLMAP_POINTS = 1536
+ DEPTH_POINTS = 512
+
+
+ # ---------------------------------------------------------------------------
+ # Datasets
+ # ---------------------------------------------------------------------------
+
+ def _load_bad_sample_ids():
+     """Load the set of known-bad sample IDs (misaligned GT, extreme scale)."""
+     bad_file = Path(__file__).parent / "bad_samples.txt"
+     if not bad_file.exists():
+         return set()
+     return set(line.strip() for line in bad_file.read_text().splitlines() if line.strip())
+
+
+ class HFCachedDataset(torch.utils.data.Dataset):
+     """Load pre-sampled HuggingFace dataset into memory."""
+
+     def __init__(self, hf_dataset, aug_rotate=False, aug_jitter=0.0,
+                  aug_drop=0.0, aug_flip=False):
+         import io as _io
+         bad_ids = _load_bad_sample_ids()
+         print(f"Pre-decoding {len(hf_dataset)} samples into memory...")
+         self.samples = []
+         self.order_ids = []
+         n_skipped = 0
+         for i, sample in enumerate(hf_dataset):
+             if sample["order_id"] in bad_ids:
+                 n_skipped += 1
+                 continue
+             d = dict(np.load(_io.BytesIO(sample["data"])))
+             if "xyz_norm" not in d:
+                 raise ValueError(
+                     f"Sample {sample['order_id']} missing 'xyz_norm' -- this looks like "
+                     f"a full PCD dataset, not pre-sampled. Use make_sampled_cache.py first.")
+             self.samples.append(d)
+             self.order_ids.append(sample["order_id"])
+             if (i + 1) % 2000 == 0:
+                 print(f"  {i+1}/{len(hf_dataset)}...")
+         print(f"  Done. {len(self.samples)} samples in memory"
+               f" ({n_skipped} bad samples filtered).")
+         self.aug_rotate = aug_rotate
+         self.aug_jitter = aug_jitter
+         self.aug_drop = aug_drop
+         self.aug_flip = aug_flip
+
+     def __len__(self):
+         return len(self.samples)
+
+     def __getitem__(self, idx):
+         out = _process_sample(self.samples[idx], self.aug_rotate,
+                               self.aug_jitter, self.aug_drop, self.aug_flip)
+         out["sample_id"] = self.order_ids[idx]
+         return out
+
+
+ def _process_sample(d, aug_rotate, aug_jitter=0.0, aug_drop=0.0, aug_flip=False):
+     """Process a pre-sampled npz dict into training tensors.
+
+     Args:
+         aug_rotate: random yaw rotation
+         aug_jitter: std of Gaussian noise added to point positions (0=disabled)
+         aug_drop:   fraction of points to randomly drop (0=disabled)
+         aug_flip:   random mirror along X axis (50% chance)
+     """
+     xyz_norm = d["xyz_norm"].copy()
+     gt_seg = d["gt_segments"].copy()
+     mask = d["mask"].copy()
+
+     if aug_rotate:
+         theta = np.random.rand() * 2 * np.pi
+         cos_t, sin_t = np.cos(theta), np.sin(theta)
+         x, z = xyz_norm[:, 0].copy(), xyz_norm[:, 2].copy()
+         xyz_norm[:, 0] = x * cos_t - z * sin_t
+         xyz_norm[:, 2] = x * sin_t + z * cos_t
+         for ep in range(2):
+             sx, sz = gt_seg[:, ep, 0].copy(), gt_seg[:, ep, 2].copy()
+             gt_seg[:, ep, 0] = sx * cos_t - sz * sin_t
+             gt_seg[:, ep, 2] = sx * sin_t + sz * cos_t
+
+     if aug_flip and np.random.rand() < 0.5:
+         xyz_norm[:, 0] = -xyz_norm[:, 0]
+         gt_seg[:, :, 0] = -gt_seg[:, :, 0]
+
+     if aug_jitter > 0:
+         valid = mask.astype(bool)
+         xyz_norm[valid] += np.random.randn(valid.sum(), 3).astype(np.float32) * aug_jitter
+
+     if aug_drop > 0:
+         valid_idx = np.where(mask)[0]
+         n_drop = int(len(valid_idx) * aug_drop)
+         if n_drop > 0:
+             drop_idx = np.random.choice(valid_idx, n_drop, replace=False)
+             mask[drop_idx] = False
+
+     result = {
+         "xyz_norm": torch.as_tensor(xyz_norm, dtype=torch.float32),
+         "class_id": torch.as_tensor(d["class_id"], dtype=torch.long),
+         "source": torch.as_tensor(d["source"], dtype=torch.long),
+         "mask": torch.as_tensor(mask),
+         "gt_segments": torch.as_tensor(gt_seg, dtype=torch.float32),
+         "scale": torch.tensor(float(d["scale"]), dtype=torch.float32),
+         "center": torch.as_tensor(d["center"], dtype=torch.float32),
+         "gt_vertices": d["gt_vertices"],
+         "gt_edges": d["gt_edges"],
+         "visible_src": torch.as_tensor(d["visible_src"], dtype=torch.long),
+         "visible_id": torch.as_tensor(d["visible_id"], dtype=torch.long),
+     }
+     if "behind" in d:
+         result["behind"] = torch.as_tensor(
+             np.clip(np.asarray(d["behind"], dtype=np.int16), 0, None), dtype=torch.long)
+     if "n_views_voted" in d:
+         result["n_views_voted"] = torch.as_tensor(d["n_views_voted"], dtype=torch.float32)
+     if "vote_frac" in d:
+         result["vote_frac"] = torch.as_tensor(d["vote_frac"], dtype=torch.float32)
+     if "gt_edge_classes" in d:
+         result["gt_edge_classes"] = torch.as_tensor(
+             np.asarray(d["gt_edge_classes"], dtype=np.int64), dtype=torch.long)
+     return result
+
+
+ # ---------------------------------------------------------------------------
+ # Collation + DataLoader
+ # ---------------------------------------------------------------------------
+
+ def collate(batch):
+     """Stack samples into batched tensors."""
+     out = {
+         "xyz_norm": torch.stack([d["xyz_norm"] for d in batch]),
+         "class_id": torch.stack([d["class_id"] for d in batch]),
+         "source": torch.stack([d["source"] for d in batch]),
+         "mask": torch.stack([d["mask"] for d in batch]),
+         "gt_segments": [d["gt_segments"] for d in batch],
+         "scales": torch.stack([d["scale"] for d in batch]),
+         "meta": batch,
+     }
+     # Optional fields: check ALL samples, not just batch[0].
+     # If any sample has it, all must have it (no mixed data versions).
+     for field in ("behind", "n_views_voted", "vote_frac"):
+         if any(field in d for d in batch):
+             missing = [i for i, d in enumerate(batch) if field not in d]
+             if missing:
+                 raise KeyError(
+                     f"Field '{field}' present in some batch samples but missing in "
+                     f"{len(missing)}/{len(batch)}. Mixed data versions in cache?")
+             out[field] = torch.stack([d[field] for d in batch])
+     # gt_edge_classes: variable length per sample (like gt_segments), keep as list
+     if any("gt_edge_classes" in d for d in batch):
+         missing = [i for i, d in enumerate(batch) if "gt_edge_classes" not in d]
+         if missing:
+             raise KeyError(
+                 f"Field 'gt_edge_classes' present in some batch samples but missing in "
+                 f"{len(missing)}/{len(batch)}. Mixed data versions in cache?")
+         out["gt_edge_classes"] = [d["gt_edge_classes"] for d in batch]
+     return out
+
+
+ def build_loader(cache_dir, batch_size, aug_rotate=False, aug_jitter=0.0,
+                  aug_drop=0.0, aug_flip=False):
+     """Create a DataLoader from HF dataset.
+
+     cache_dir should be 'hf://repo/name:split' format.
+     """
+     if not cache_dir.startswith("hf://"):
+         raise ValueError(
+             f"cache_dir must be 'hf://repo:split' format, got: {cache_dir}. "
+             f"Local .pt caches are no longer supported in the training path.")
+     parts = cache_dir[5:].split(":")
+     repo = parts[0]
+     split = parts[1] if len(parts) > 1 else "train"
+     from datasets import load_dataset
+     hf_ds = load_dataset(repo, split=split)
+     ds = HFCachedDataset(hf_ds, aug_rotate=aug_rotate, aug_jitter=aug_jitter,
+                          aug_drop=aug_drop, aug_flip=aug_flip)
+     loader = torch.utils.data.DataLoader(
+         ds, batch_size=batch_size, shuffle=True,
+         num_workers=0, collate_fn=collate,
+     )
+     print(f"Dataset: {len(ds)} scenes, batch_size={batch_size}")
+     return loader
+
+
+ # ---------------------------------------------------------------------------
+ # Token building (GPU)
+ # ---------------------------------------------------------------------------
+
+ def build_tokens(batch, model, device):
+     """Apply Fourier features + learned embeddings on GPU."""
+     xyz = batch["xyz_norm"].to(device)
+     cid = batch["class_id"].to(device)
+     src = batch["source"].to(device)
+     masks = batch["mask"].to(device)
+     gt = [g.to(device) for g in batch["gt_segments"]]
+     scales = batch["scales"]
+
+     B, T, _ = xyz.shape
+     tok = model.tokenizer
+     fourier = tok.pos_enc(xyz.reshape(-1, 3)).reshape(B, T, -1) \
+         if tok.pos_enc is not None else xyz.new_zeros(B, T, 0)
+     parts = [xyz, fourier, tok.label_emb(cid), tok.src_emb(src.clamp(0, 1))]
+     if tok.behind_emb_dim > 0:
+         if "behind" in batch:
+             beh = batch["behind"].to(device)
+         else:
+             # Data doesn't have behind -- use zeros (embed index 0).
+             # This is intentional for eval on old data; for training,
+             # fail fast by requiring the field (checked in _process_sample).
+             beh = xyz.new_zeros(B, T, dtype=torch.long)
+         parts.append(tok.behind_emb(beh))
+     if tok.use_vote_features:
+         if "n_views_voted" not in batch or "vote_frac" not in batch:
+             raise KeyError(
+                 "Model expects vote features (--vote-features) but data is missing "
+                 "'n_views_voted'/'vote_frac'. Use v2 dataset or regenerate cache.")
+         # Normalize to ~zero mean, unit variance (dataset stats: nv~2.7+/-1.0, vf~0.5+/-0.25)
+         nv = ((batch["n_views_voted"].to(device).float() - 2.7) / 1.0).unsqueeze(-1)
+         vf = ((batch["vote_frac"].to(device).float() - 0.5) / 0.25).unsqueeze(-1)
+         parts.extend([nv, vf])
+     tokens = torch.cat(parts, dim=-1)
+     return tokens, masks, gt, scales, batch["meta"]
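
Editor's note: to show how the pieces above fit together, a hedged end-to-end sketch (mine, not part of the upload). The hf:// repo name is a placeholder, and the stub model only exposes the .tokenizer attributes that build_tokens actually reads:

    import torch
    import torch.nn as nn
    from s23dr_2026_example.data import build_loader, build_tokens

    class _StubTokenizer(nn.Module):
        """Minimal stand-in exposing the attributes build_tokens reads."""
        def __init__(self, n_classes=13, emb=8):
            super().__init__()
            self.pos_enc = None             # skip Fourier features in this sketch
            self.label_emb = nn.Embedding(n_classes, emb)
            self.src_emb = nn.Embedding(2, emb)
            self.behind_emb_dim = 0         # no behind-embedding branch
            self.use_vote_features = False

    class _StubModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.tokenizer = _StubTokenizer()

    model = _StubModel()
    loader = build_loader("hf://usm3d/some-presampled-cache:train",  # placeholder repo
                          batch_size=4, aug_rotate=True, aug_flip=True)

    batch = next(iter(loader))
    tokens, masks, gt, scales, meta = build_tokens(batch, model, "cpu")
    # tokens = [xyz | class-emb | source-emb] -> [4, 2048, 3 + 8 + 8]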
s23dr_2026_example/losses.py ADDED
@@ -0,0 +1,311 @@
+ """Loss computation for wireframe prediction."""
+ from __future__ import annotations
+
+ import torch
+
+ from .varifold import varifold_loss_batch
+ from .sinkhorn import batched_sinkhorn_loss
+ from .soft_hss_loss import batched_sinkhorn_vertex_f1, batched_soft_hss_v2
+
+ # Varifold config
+ VARIANT = "simpson3"
+ SIGMAS = [0.5, 1.0, 2.0]  # meters (divided by per-scene scale at runtime)
+ ALPHAS = [0.2, 0.6, 0.2]
+ LEN_POW = 1.0
+ VARIFOLD_CROSS_ONLY = False  # Set to True to drop self-energy (avoids O(S^2) blowup)
+
+ # Sinkhorn config (note: near-zero gradients at eps=0.05, effectively disabled)
+ SINKHORN_EPS = 0.05
+ SINKHORN_ITERS = 10
+
+ # Distance thresholds in meters (divided by per-scene scale at runtime)
+ VERTEX_THRESH_M = 0.5  # vertex match threshold (mirrors real HSS)
+ TUBE_RADIUS_M = 0.5    # tube IoU radius (mirrors real HSS)
+
+ # Sinkhorn dustbin cost: controls the OT "not matching" penalty.
+ # Like tau, this is an OT behavior parameter, NOT a physical distance.
+ # Must be comparable to typical matching costs in normalized space (~0.1).
+ # Do NOT divide by scale.
+ SINKHORN_DUSTBIN = 0.1
+
+ # Sigmoid temperature: controls gradient smoothness, NOT a distance threshold.
+ # Must stay large enough in normalized space to provide useful gradients.
+ # Do NOT divide by scale (unlike the thresholds above).
+ SIGMOID_TAU = 0.05
+
+ MAX_GT = 64  # fixed pad size for compile-friendly shapes
+
+ # Precomputed constants (created once on first call)
+ _loss_constants = {}
+
+
+ def _get_loss_constants(device, dtype):
+     key = (device, dtype)
+     if key not in _loss_constants:
+         _loss_constants[key] = {
+             "sigmas": torch.tensor(SIGMAS, device=device, dtype=dtype),
+             "alphas": torch.tensor(ALPHAS, device=device, dtype=dtype),
+         }
+     return _loss_constants[key]
+
+
+ def pad_gt_fixed(gt_list, device, dtype):
+     """Pad GT segments to fixed MAX_GT for compile-friendly shapes."""
+     B = len(gt_list)
+     gt_pad = torch.zeros((B, MAX_GT, 2, 3), device=device, dtype=dtype)
+     gt_mask = torch.zeros((B, MAX_GT), device=device, dtype=torch.bool)
+     gt_lengths = torch.zeros(B, device=device, dtype=dtype)
+     for i, g in enumerate(gt_list):
+         n = g.shape[0]
+         if n > 0:
+             gt_pad[i, :n] = g
+             gt_mask[i, :n] = True
+             gt_lengths[i] = torch.linalg.norm(g[:, 1] - g[:, 0], dim=-1).sum()
+     return gt_pad, gt_mask, gt_lengths
+
+
+ def _loss_inner(pred_segments, gt_pad, gt_mask, gt_lengths, scales,
+                 sigmas, alphas, varifold_w, vertex_f1_w):
+     """Pure tensor loss -- no Python control flow, no boolean indexing."""
+     has_gt = (gt_lengths > 0).float()
+
+     sigmas_eff = sigmas / scales[:, None]
+     loss_batch = varifold_loss_batch(
+         pred_segments, gt_pad, gt_mask=gt_mask,
+         variant=VARIANT, sigmas=sigmas_eff, alpha=alphas, len_pow=LEN_POW,
+         cross_only=VARIFOLD_CROSS_ONLY,
+     )
+     v = loss_batch / gt_lengths.clamp(min=1.0)
+     v = (v * has_gt).sum() / has_gt.sum().clamp(min=1.0)
+
+     thresh = VERTEX_THRESH_M / scales
+     f1 = batched_sinkhorn_vertex_f1(
+         pred_segments, gt_pad, gt_mask, thresh=thresh, tau=SIGMOID_TAU)
+     f1 = (f1 * has_gt).sum() / has_gt.sum().clamp(min=1.0)
+
+     total = varifold_w * v + vertex_f1_w * f1
+     return total, v, f1
+
+
+ # Will be replaced with compiled version on CUDA
+ _loss_fn = _loss_inner
+
+
+ def _conf_match_loss(pred_segments, gt_pad, gt_mask, conf_logits, scales):
+     """Auxiliary BCE loss: train conf to predict whether each segment matches GT.
+
+     Computes per-segment min-distance to GT, creates soft match target via
+     sigmoid thresholding, and returns BCE(sigmoid(conf), target).
+     """
+     B, S = pred_segments.shape[:2]
+     # Decoupled cost: midpoint + direction + length (same as sinkhorn)
+     p0, p1 = pred_segments[:, :, 0], pred_segments[:, :, 1]
+     g0, g1 = gt_pad[:, :, 0], gt_pad[:, :, 1]
+     mid_p, half_p = 0.5 * (p0 + p1), 0.5 * (p1 - p0)
+     mid_g, half_g = 0.5 * (g0 + g1), 0.5 * (g1 - g0)
+     d_mid = torch.linalg.norm(mid_p.unsqueeze(2) - mid_g.unsqueeze(1), dim=-1)
+     len_p = torch.linalg.norm(half_p, dim=-1, keepdim=True).clamp(min=1e-6)
+     len_g = torch.linalg.norm(half_g, dim=-1, keepdim=True).clamp(min=1e-6)
+     dir_p = half_p / len_p
+     dir_g = half_g / len_g
+     cos_angle = (dir_p.unsqueeze(2) * dir_g.unsqueeze(1)).sum(dim=-1)
+     d_dir = 1.0 - cos_angle.abs()
+     d_len = (len_p.unsqueeze(2) - len_g.unsqueeze(1)).squeeze(-1).abs()
+     cost = d_mid + d_dir + d_len  # [B, S, M]
+
+     # Mask invalid GT with high cost
+     cost = torch.where(gt_mask.unsqueeze(1), cost, cost.new_tensor(1e6))
+     min_dist = cost.min(dim=2).values  # [B, S]
+
+     # Soft target: sigmoid((thresh - dist) / tau), in normalized space
+     thresh = VERTEX_THRESH_M / scales  # [B]
+     target = torch.sigmoid((thresh[:, None] - min_dist) / SIGMOID_TAU)
+
+     return torch.nn.functional.binary_cross_entropy_with_logits(
+         conf_logits, target.detach(), reduction="mean")
+
+
+ def compute_loss(pred_segments, gt_list, scales, device,
+                  varifold_w, sinkhorn_w, vertex_f1_w=0.0, soft_hss_w=0.0,
+                  endpoint_w=0.0,
+                  conf_logits=None, conf_weight=0.0, conf_mode="match",
+                  sinkhorn_eps=None, sinkhorn_iters=None,
+                  sinkhorn_dustbin=None, conf_clamp_min=None):
+     """Combined loss with fixed-size GT padding.
+
+     conf_mode: "match" = BCE matching supervision, "sinkhorn" = conf-weighted sinkhorn.
+     """
+     if conf_logits is not None and conf_clamp_min is not None:
+         conf_logits = conf_logits.clamp(min=conf_clamp_min)
+     gt_pad, gt_mask, gt_lengths = pad_gt_fixed(gt_list, device, pred_segments.dtype)
+     c = _get_loss_constants(device, pred_segments.dtype)
+
+     total, v, f1 = _loss_fn(
+         pred_segments, gt_pad, gt_mask, gt_lengths, scales,
+         c["sigmas"], c["alphas"], varifold_w, vertex_f1_w)
+
+     terms = {}
+     if varifold_w > 0:
+         terms["varifold"] = v.detach()
+     if vertex_f1_w > 0:
+         terms["vertex_f1"] = f1.detach()
+
+     if sinkhorn_w > 0:
+         has_gt = (gt_lengths > 0).float()
+         if conf_logits is not None and conf_mode == "sinkhorn":
+             pred_mass = torch.sigmoid(conf_logits)
+         elif conf_logits is not None and conf_mode == "sinkhorn_detach":
+             pred_mass = torch.sigmoid(conf_logits.detach())
+         else:
+             pred_mass = None
+         eps = sinkhorn_eps if sinkhorn_eps is not None else SINKHORN_EPS
+         iters = sinkhorn_iters if sinkhorn_iters is not None else SINKHORN_ITERS
+         dustbin = sinkhorn_dustbin if sinkhorn_dustbin is not None else SINKHORN_DUSTBIN
+         S = pred_segments.shape[1]
+         sink_per = batched_sinkhorn_loss(
+             pred_segments, gt_pad, gt_mask,
+             eps, iters, dustbin,
+             pred_mass=pred_mass,
+         ) / (gt_lengths.clamp(min=1.0) * S)
+         s = (sink_per * has_gt).sum() / has_gt.sum().clamp(min=1.0)
+         total = total + sinkhorn_w * s
+         terms["sinkhorn"] = s.detach()
+
+     if soft_hss_w > 0:
+         has_gt = (gt_lengths > 0).float()
+         vert_thresh = VERTEX_THRESH_M / scales
+         edge_thresh = TUBE_RADIUS_M / scales
+         hss_loss = batched_soft_hss_v2(
+             pred_segments, gt_pad, gt_mask,
+             vert_thresh=vert_thresh, edge_thresh=edge_thresh, tau=SIGMOID_TAU)
+         hs = (hss_loss * has_gt).sum() / has_gt.sum().clamp(min=1.0)
+         total = total + soft_hss_w * hs
+         terms["soft_hss"] = hs.detach()
+
+     if conf_logits is not None and conf_weight > 0:
+         if conf_mode == "match":
+             # Explicit BCE supervision from nearest-GT distances
+             cl = _conf_match_loss(pred_segments, gt_pad, gt_mask, conf_logits, scales)
+             total = total + conf_weight * cl
+             terms["conf"] = cl.detach()
+         elif conf_mode in ("sinkhorn", "sinkhorn_detach"):
+             # Conf trained through sinkhorn transport gradients (via pred_mass).
+             # sinkhorn_detach: pred_mass uses detached conf, so OT can't push conf negative.
+             # Add count regularizer to prevent all-zero conf collapse.
+             # Normalized by S so magnitude doesn't depend on segment count.
+             conf_w = torch.sigmoid(conf_logits)
+             S = conf_logits.shape[1]
+             gt_counts = gt_mask.sum(dim=1).float()
+             conf_sum = conf_w.sum(dim=1)
+             reg = (((conf_sum - gt_counts) / S) ** 2).mean()
+             total = total + conf_weight * reg
+             terms["conf_reg"] = reg.detach()
+         elif conf_mode == "varifold":
+             # Conf-weighted varifold: weight each pred segment's contribution
+             # by sigmoid(conf). Low-conf segments contribute less to the loss.
+             # Needs regularizer to prevent all-zero conf collapse.
+             has_gt = (gt_lengths > 0).float()
+             conf_w = torch.sigmoid(conf_logits)  # [B, S]
+             sigmas_eff = c["sigmas"] / scales[:, None]
+             vf_conf = varifold_loss_batch(
+                 pred_segments, gt_pad, gt_mask=gt_mask,
+                 variant=VARIANT, sigmas=sigmas_eff, alpha=c["alphas"],
+                 len_pow=LEN_POW, pred_weights=conf_w,
+             )
+             vc = (vf_conf / gt_lengths.clamp(min=1.0))
+             vc = (vc * has_gt).sum() / has_gt.sum().clamp(min=1.0)
+             # Regularizer: penalize total conf being far from n_gt.
+             # Normalized by S so magnitude doesn't depend on segment count.
+             S = conf_logits.shape[1]
+             gt_counts = gt_mask.sum(dim=1).float()  # [B]
+             conf_sum = conf_w.sum(dim=1)  # [B]
+             reg = (((conf_sum - gt_counts) / S) ** 2).mean()
+             total = total + conf_weight * vc + 0.01 * reg
+             terms["conf_vf"] = vc.detach()
+             terms["conf_reg"] = reg.detach()
+         else:
+             raise ValueError(f"Unknown conf_mode: {conf_mode}")
+
+     if endpoint_w > 0:
+         has_gt = (gt_lengths > 0).float()
+         eps_ep = sinkhorn_eps if sinkhorn_eps is not None else SINKHORN_EPS
+         iters_ep = sinkhorn_iters if sinkhorn_iters is not None else SINKHORN_ITERS
+         dustbin_ep = sinkhorn_dustbin if sinkhorn_dustbin is not None else SINKHORN_DUSTBIN
+         B, S = pred_segments.shape[:2]
+         M = gt_pad.shape[1]
+
+         # Compute hard assignment via sinkhorn (detached; matching is not trained)
+         with torch.no_grad():
+             pred_mass_ep = torch.sigmoid(conf_logits) if conf_logits is not None else None
+             sink_loss_for_assign = batched_sinkhorn_loss(
+                 pred_segments, gt_pad, gt_mask, eps_ep, iters_ep, dustbin_ep,
+                 pred_mass=pred_mass_ep)
+             # Re-run sinkhorn to get transport matrix for assignment
+             # (reuse the cost computation from batched_sinkhorn_loss internals)
+             p0, p1 = pred_segments[:, :, 0], pred_segments[:, :, 1]
+             g0, g1 = gt_pad[:, :, 0], gt_pad[:, :, 1]
+             mid_p, half_p = 0.5 * (p0 + p1), 0.5 * (p1 - p0)
+             mid_g, half_g = 0.5 * (g0 + g1), 0.5 * (g1 - g0)
+             d_mid = torch.linalg.norm(mid_p.unsqueeze(2) - mid_g.unsqueeze(1), dim=-1)
+             len_p = torch.linalg.norm(half_p, dim=-1, keepdim=True).clamp(min=1e-6)
+             len_g = torch.linalg.norm(half_g, dim=-1, keepdim=True).clamp(min=1e-6)
+             dir_p, dir_g = half_p / len_p, half_g / len_g
+             cos_a = (dir_p.unsqueeze(2) * dir_g.unsqueeze(1)).sum(dim=-1)
+             d_dir = 1.0 - cos_a.abs()
+             d_len = (len_p.unsqueeze(2) - len_g.unsqueeze(1)).squeeze(-1).abs()
+             cost = d_mid + d_dir + d_len
+             dc = torch.as_tensor(dustbin_ep, device=cost.device, dtype=cost.dtype)
+             cost = torch.where(gt_mask.unsqueeze(1), cost, dc * 10.0)
+             cost_pad = dc.expand(B, S + 1, M + 1).clone()
+             cost_pad[:, :S, :M] = cost
+             cost_pad[:, -1, -1] = 0.0
+             gt_counts = gt_mask.sum(dim=1).float()
+             if pred_mass_ep is not None:
+                 pm = pred_mass_ep.clamp(min=0.0)
+                 a = torch.cat([pm, (gt_counts - pm.sum(1)).clamp(min=0).unsqueeze(1)], dim=1)
+                 b_val = torch.zeros(B, M + 1, device=cost.device, dtype=cost.dtype)
+                 b_val[:, :M] = gt_mask.float()
+                 b_val[:, -1] = (pm.sum(1) - gt_counts).clamp(min=0)
+             else:
+                 n = float(S)
+                 denom = n + gt_counts
+                 a = (1.0 / denom).unsqueeze(1).expand(B, S + 1).clone()
+                 a[:, -1] = gt_counts / denom
+                 b_val = (1.0 / denom).unsqueeze(1).expand(B, M + 1).clone()
+                 b_val[:, -1] = n / denom
+                 b_val[:, :M] = b_val[:, :M] * gt_mask.float()
+             log_a = torch.log(a + 1e-9)
+             log_b = torch.log(b_val + 1e-9)
+             log_k = -cost_pad / eps_ep
+             log_u = torch.zeros_like(a)
+             log_v = torch.zeros_like(b_val)
+             for _ in range(iters_ep):
+                 log_u = log_a - torch.logsumexp(log_k + log_v.unsqueeze(1), dim=2)
+                 log_v = log_b - torch.logsumexp(log_k + log_u.unsqueeze(2), dim=1)
+             transport = torch.exp(log_u.unsqueeze(2) + log_v.unsqueeze(1) + log_k)
+             assignment = transport[:, :S, :M+1].argmax(dim=2)
+             assignment[assignment >= M] = -1
+
+         # Everything below is WITH gradients (assignment is detached but pred_segments is live)
+         matched = (assignment >= 0)  # [B, S]
+         n_matched = matched.float().sum().clamp(min=1.0)
+         assign_safe = assignment.clamp(min=0)
+         gt_matched = gt_pad[
+             torch.arange(B, device=device)[:, None].expand(B, S),
+             assign_safe]  # [B, S, 2, 3]
+
+         # Symmetric endpoint distance
+         ref_ep1 = pred_segments[:, :, 0]
+         ref_ep2 = pred_segments[:, :, 1]
+         gt_ep1 = gt_matched[:, :, 0]
+         gt_ep2 = gt_matched[:, :, 1]
+         dist_fwd = (ref_ep1 - gt_ep1).norm(dim=-1) + (ref_ep2 - gt_ep2).norm(dim=-1)
+         dist_rev = (ref_ep1 - gt_ep2).norm(dim=-1) + (ref_ep2 - gt_ep1).norm(dim=-1)
+         ep_dist = torch.min(dist_fwd, dist_rev)
+
+         # Normalize by GT total length * S (same scale as sinkhorn)
+         ep_loss = (ep_dist * matched.float()).sum() / n_matched
+ ep_loss = (ep_dist * matched.float()).sum() / n_matched
308
+ total = total + endpoint_w * ep_loss
309
+ terms["endpoint"] = ep_loss.detach()
310
+
311
+ return total, terms
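
For intuition, the symmetric endpoint distance used in the endpoint term above reduces to the following standalone computation (a minimal sketch under the same assumed [B, S, 2, 3] shapes; `symmetric_endpoint_distance` is an illustrative helper, not part of the file):

    import torch

    def symmetric_endpoint_distance(pred: torch.Tensor, gt: torch.Tensor) -> torch.Tensor:
        """Order-invariant endpoint distance for already-matched segment pairs [B, S, 2, 3]."""
        fwd = (pred[:, :, 0] - gt[:, :, 0]).norm(dim=-1) + (pred[:, :, 1] - gt[:, :, 1]).norm(dim=-1)
        rev = (pred[:, :, 0] - gt[:, :, 1]).norm(dim=-1) + (pred[:, :, 1] - gt[:, :, 0]).norm(dim=-1)
        return torch.min(fwd, rev)  # [B, S]

    # Reversing a segment end-for-end leaves the distance at zero:
    seg = torch.randn(1, 4, 2, 3)
    assert torch.allclose(symmetric_endpoint_distance(seg, seg.flip(dims=[2])),
                          torch.zeros(1, 4), atol=1e-6)

Taking the min over the forward and reversed pairings is what makes the loss invariant to the arbitrary endpoint ordering of both predicted and ground-truth segments.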
s23dr_2026_example/make_sampled_cache.py ADDED
@@ -0,0 +1,260 @@
+ #!/usr/bin/env python3
+ """Convert full point cloud cache to pre-sampled 2048-point npz files.
+
+ Reads from either local .pt files or the HF dataset, priority-samples
+ 2048 points, normalizes, and saves as compact npz files (~50KB each).
+
+ Usage:
+     # From local cache:
+     python make_sampled_cache.py --in-dir /workspace/cache/v2 --out-dir /workspace/cache/sampled
+
+     # From HF dataset:
+     python make_sampled_cache.py --hf-repo usm3d/s23dr-2026-cached_full_pcd --out-dir /workspace/cache/sampled
+
+     # Specify split:
+     python make_sampled_cache.py --hf-repo usm3d/s23dr-2026-cached_full_pcd --split validation --out-dir /workspace/cache/sampled_val
+
+     # With edge classifications (from extract_edge_classes.py):
+     python make_sampled_cache.py --hf-repo usm3d/s23dr-2026-cached_full_pcd --out-dir /workspace/cache/sampled \
+         --edge-classes edge_classifications.npz
+
+ Note: uses a fixed seed so each scene gets one deterministic sample of 2048
+ points. This means no sampling augmentation across epochs -- every epoch sees
+ the same points. Fine for now; better augmentation can be added later.
+ """
+ from __future__ import annotations
+
+ import sys
+ from pathlib import Path as _Path
+ if __package__ is None or __package__ == "":
+     _here = _Path(__file__).resolve().parent
+     if str(_here.parent) not in sys.path:
+         sys.path.insert(0, str(_here.parent))
+     __package__ = _here.name
+
+ import argparse
+ import io
+ import time
+ from pathlib import Path
+
+ import numpy as np
+ import torch
+
+
+ # Priority sampling (same logic as train.py)
+ def _priority_sample(source, group_id, seq_len, colmap_quota, depth_quota):
+     def pick(src_id, quota):
+         base = source == src_id
+         picked, remaining = [], quota
+         for tier in range(5):
+             if remaining <= 0:
+                 break
+             pool = np.where(base & (group_id == tier))[0]
+             if len(pool) == 0:
+                 continue
+             np.random.shuffle(pool)
+             take = min(remaining, len(pool))
+             picked.append(pool[:take])
+             remaining -= take
+         if remaining > 0:
+             pool = np.where(base & (group_id >= 0))[0]
+             if len(pool) > 0:
+                 np.random.shuffle(pool)
+                 picked.append(pool[:min(remaining, len(pool))])
+                 remaining -= min(remaining, len(pool))
+         return np.concatenate(picked) if picked else np.array([], dtype=np.int64), remaining
+
+     idx_c, rem_c = pick(0, colmap_quota)
+     idx_d, rem_d = pick(1, depth_quota)
+
+     if rem_c > 0:
+         extra = np.setdiff1d(np.where((source == 1) & (group_id >= 0))[0], idx_d)
+         np.random.shuffle(extra)
+         idx_d = np.concatenate([idx_d, extra[:rem_c]])
+     if rem_d > 0:
+         extra = np.setdiff1d(np.where((source == 0) & (group_id >= 0))[0], idx_c)
+         np.random.shuffle(extra)
+         idx_c = np.concatenate([idx_c, extra[:rem_d]])
+
+     indices = np.concatenate([idx_c, idx_d])
+     num_valid = len(indices)
+     if num_valid < seq_len:
+         if num_valid == 0:
+             return np.zeros(seq_len, dtype=np.int64), np.zeros(seq_len, dtype=bool)
+         indices = np.concatenate([indices, np.full(seq_len - num_valid, indices[-1])])
+     mask = np.zeros(seq_len, dtype=bool)
+     mask[:num_valid] = True
+     return indices[:seq_len], mask
+
+
+ def process_sample(xyz, source, group_id, class_id, vis_src, vis_id,
+                    center, scale, gt_v, gt_e, behind=None,
+                    n_views_voted=None, vote_frac=None,
+                    gt_edge_classes=None,
+                    seq_len=2048, colmap_q=1536, depth_q=512):
+     """Sample and normalize one scene. Returns dict of numpy arrays."""
+     indices, mask = _priority_sample(source, group_id, seq_len, colmap_q, depth_q)
+     xyz_norm = ((xyz[indices] - center) / scale).astype(np.float32)
+     gt_seg = np.stack([gt_v[gt_e[:, 0]], gt_v[gt_e[:, 1]]], axis=1)
+     gt_seg_norm = ((gt_seg - center) / scale).astype(np.float32)
+
+     result = {
+         "xyz_norm": xyz_norm,
+         "class_id": class_id[indices].astype(np.uint8),
+         "source": source[indices].astype(np.uint8),
+         "mask": mask,
+         "gt_segments": gt_seg_norm,
+         "scale": np.float32(scale),
+         "center": center.astype(np.float32),
+         "gt_vertices": gt_v.astype(np.float32),
+         "gt_edges": gt_e.astype(np.int32),
+         "visible_src": vis_src[indices].astype(np.uint8),
+         "visible_id": vis_id[indices].astype(np.int16),
+     }
+     if behind is not None:
+         result["behind"] = behind[indices].astype(np.int16)
+     if n_views_voted is not None:
+         result["n_views_voted"] = n_views_voted[indices].astype(np.uint8)
+     if vote_frac is not None:
+         result["vote_frac"] = vote_frac[indices].astype(np.float32)
+     if gt_edge_classes is not None:
+         if len(gt_edge_classes) != len(gt_e):
+             raise ValueError(
+                 f"gt_edge_classes length {len(gt_edge_classes)} != "
+                 f"gt_edges length {len(gt_e)}")
+         result["gt_edge_classes"] = gt_edge_classes.astype(np.int64)
+     return result
+
+
+ def _load_edge_classes(path):
+     """Load edge classifications lookup from npz file."""
+     if path is None:
+         return None
+     path = Path(path)
+     if not path.exists():
+         raise FileNotFoundError(f"Edge classifications file not found: {path}")
+     data = np.load(str(path), allow_pickle=False)
+     lookup = {k: data[k] for k in data.files}
+     print(f"Loaded edge classifications for {len(lookup)} orders from {path}")
+     return lookup
+
+
+ def main():
+     p = argparse.ArgumentParser()
+     g = p.add_mutually_exclusive_group(required=True)
+     g.add_argument("--in-dir", help="Local directory of .pt files")
+     g.add_argument("--hf-repo", help="HuggingFace dataset repo (e.g. usm3d/s23dr-2026-cached_full_pcd)")
+     p.add_argument("--split", default="train", help="HF dataset split")
+     p.add_argument("--out-dir", required=True)
+     p.add_argument("--edge-classes", default=None,
+                    help="Path to edge_classifications.npz from extract_edge_classes.py")
+     p.add_argument("--seq-len", type=int, default=2048)
+     p.add_argument("--colmap-quota", type=int, default=1536)
+     p.add_argument("--depth-quota", type=int, default=512)
+     p.add_argument("--seed", type=int, default=7)
+     args = p.parse_args()
+
+     out_dir = Path(args.out_dir)
+     out_dir.mkdir(parents=True, exist_ok=True)
+     np.random.seed(args.seed)
+
+     edge_cls_lookup = _load_edge_classes(args.edge_classes)
+     n_edge_matched, n_edge_missing = 0, 0
+
+     t_start = time.perf_counter()
+     done = 0
+
+     if args.in_dir:
+         # Local .pt files
+         files = sorted(Path(args.in_dir).glob("*.pt"))
+         print(f"Converting {len(files)} local .pt files...")
+         for f in files:
+             out_f = out_dir / (f.stem + ".npz")
+             if out_f.exists():
+                 done += 1
+                 continue
+             d = torch.load(f, weights_only=False)
+             behind = np.asarray(d["behind_gest_id"], np.int16) if "behind_gest_id" in d else None
+             n_vv = np.asarray(d["n_views_voted"], np.uint8) if "n_views_voted" in d else None
+             vf = np.asarray(d["vote_frac"], np.float32) if "vote_frac" in d else None
+             gt_ec = None
+             if edge_cls_lookup is not None:
+                 order_id = f.stem
+                 if order_id in edge_cls_lookup:
+                     gt_ec = edge_cls_lookup[order_id]
+                     n_edge_matched += 1
+                 else:
+                     n_edge_missing += 1
+             result = process_sample(
+                 np.asarray(d["xyz"], np.float32),
+                 np.asarray(d["source"], np.uint8),
+                 np.asarray(d["group_id"], np.int8),
+                 np.asarray(d["class_id"], np.uint8),
+                 np.asarray(d["visible_src"], np.uint8),
+                 np.asarray(d["visible_id"], np.int16),
+                 np.asarray(d["center"], np.float32),
+                 float(d["scale"]),
+                 np.asarray(d["gt_vertices"], np.float32),
+                 np.asarray(d["gt_edges"], np.int32),
+                 behind=behind, n_views_voted=n_vv, vote_frac=vf,
+                 gt_edge_classes=gt_ec,
+                 seq_len=args.seq_len, colmap_q=args.colmap_quota, depth_q=args.depth_quota,
+             )
+             np.savez(out_f, **result)
+             done += 1
+             if done % 2000 == 0:
+                 print(f"  {done}/{len(files)} [{done/(time.perf_counter()-t_start):.0f}/s]")
+     else:
+         # HF dataset
+         from datasets import load_dataset
+         print(f"Loading {args.hf_repo} split={args.split}...")
+         ds = load_dataset(args.hf_repo, split=args.split)
+         print(f"Converting {len(ds)} samples...")
+         for i, sample in enumerate(ds):
+             order_id = sample["order_id"]
+             out_f = out_dir / f"{order_id}.npz"
+             if out_f.exists():
+                 done += 1
+                 continue
+             arrays = np.load(io.BytesIO(sample["data"]))
+             behind = arrays["behind_gest_id"] if "behind_gest_id" in arrays else None
+             n_vv = arrays["n_views_voted"] if "n_views_voted" in arrays else None
+             vf = arrays["vote_frac"] if "vote_frac" in arrays else None
+             gt_ec = None
+             if edge_cls_lookup is not None:
+                 if order_id in edge_cls_lookup:
+                     gt_ec = edge_cls_lookup[order_id]
+                     n_edge_matched += 1
+                 else:
+                     n_edge_missing += 1
+             result = process_sample(
+                 arrays["xyz"], arrays["source"], arrays["group_id"],
+                 arrays["class_id"], arrays["visible_src"], arrays["visible_id"],
+                 arrays["center"], float(arrays["scale"]),
+                 arrays["gt_vertices"], arrays["gt_edges"],
+                 behind=behind, n_views_voted=n_vv, vote_frac=vf,
+                 gt_edge_classes=gt_ec,
+                 seq_len=args.seq_len, colmap_q=args.colmap_quota, depth_q=args.depth_quota,
+             )
+             np.savez(out_f, **result)
+             done += 1
+             if done % 2000 == 0:
+                 print(f"  {done}/{len(ds)} [{done/(time.perf_counter()-t_start):.0f}/s]")
+
+     elapsed = time.perf_counter() - t_start
+     print(f"Done: {done} files in {elapsed:.0f}s ({done/max(1, elapsed):.0f}/s)")
+
+     if edge_cls_lookup is not None:
+         print(f"Edge classifications: {n_edge_matched} matched, {n_edge_missing} missing")
+
+     # Report sizes
+     import os
+     npz_files = list(out_dir.glob("*.npz"))
+     if npz_files:
+         sizes = [os.path.getsize(f) for f in npz_files[:100]]
+         print(f"Avg file size: {np.mean(sizes)/1024:.0f}KB")
+         print(f"Est total: {np.mean(sizes)*len(npz_files)/1e9:.1f}GB")
+
+
+ if __name__ == "__main__":
+     main()
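
A quick smoke test for the priority sampler above (an illustrative sketch with synthetic inputs, not part of the original file): with 2000 COLMAP points and 1000 depth points spread over the five tiers, the 1536/512 quotas fill a 2048-token sequence exactly and no padding is needed.

    import numpy as np

    np.random.seed(0)
    source = np.concatenate([np.zeros(2000, np.uint8), np.ones(1000, np.uint8)])  # 0=colmap, 1=depth
    group_id = np.random.randint(0, 5, size=3000).astype(np.int8)                 # priority tiers 0-4

    indices, mask = _priority_sample(source, group_id, 2048, 1536, 512)
    assert indices.shape == (2048,) and mask.all()   # sequence filled, no padding
    assert (source[indices[:1536]] == 0).all()       # the COLMAP quota is placed first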
s23dr_2026_example/model.py ADDED
@@ -0,0 +1,696 @@
+ """
+ Perceiver-based transformer for 3D roof wireframe prediction.
+
+ Architecture overview:
+
+     Input tokens [B, T, D]
+         |
+         v
+     input_proj: Linear -> GELU -> Linear -> LayerNorm  =>  [B, T, hidden]
+         |
+         v
+     Perceiver latent bottleneck (N PerceiverLatentLayers):
+         Learnable latent embeddings [L, hidden] are broadcast to batch.
+         Each layer: cross-attn(latents <- tokens) -> self-attn(latents) -> FFN
+         Output: latents [B, L, hidden]
+         |
+         v
+     Segment decoder (M SegmentDecoderLayers):
+         Learnable query embeddings [S, hidden] are broadcast to batch.
+         Each layer: cross-attn(queries <- latents) -> self-attn(queries) -> FFN
+         Output: queries [B, S, hidden]
+         |
+         v
+     segment_head: Linear -> 6D -> (midpoint, half_vector)
+         + query_offsets (learnable per-query bias)
+         endpoints = midpoint +/- half_vector  ->  [B, S, 2, 3]
+ """
+
+ import torch
+ import torch.nn as nn
+
+ from .attention import MultiHeadSDPA, FeedForward
+
+
+ # ---------------------------------------------------------------------------
+ # Building blocks
+ # ---------------------------------------------------------------------------
+
+ class AttnResidual(nn.Module):
+     """Pre-norm attention + residual + dropout."""
+
+     def __init__(
+         self,
+         d_model: int,
+         num_heads: int,
+         dropout: float = 0.0,
+         kv_heads: int | None = None,
+         norm_class=None,
+         qk_norm: bool = False,
+         qk_norm_type: str = "l2",
+     ):
+         super().__init__()
+         norm_class = norm_class or nn.LayerNorm
+         self.norm = norm_class(d_model)
+         self.attn = MultiHeadSDPA(d_model, num_heads, kv_heads=kv_heads, qk_norm=qk_norm, qk_norm_type=qk_norm_type)
+         self.drop = nn.Dropout(dropout)
+
+     def forward(
+         self,
+         x: torch.Tensor,
+         memory: torch.Tensor,
+         memory_key_padding_mask: torch.Tensor | None = None,
+     ) -> torch.Tensor:
+         res = x
+         x = self.norm(x)
+         x = self.attn(x, memory, key_padding_mask=memory_key_padding_mask)
+         return res + self.drop(x)
+
+
+ class FFNResidual(nn.Module):
+     """Pre-norm feed-forward + residual + dropout."""
+
+     def __init__(
+         self,
+         d_model: int,
+         dim_ff: int,
+         dropout: float = 0.0,
+         activation: str = "gelu",
+         norm_class=None,
+     ):
+         super().__init__()
+         norm_class = norm_class or nn.LayerNorm
+         self.norm = norm_class(d_model)
+         self.ffn = FeedForward(d_model, dim_ff, activation=activation)
+         self.drop = nn.Dropout(dropout)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         res = x
+         x = self.norm(x)
+         x = self.ffn(x)
+         return res + self.drop(x)
+
+
+ # ---------------------------------------------------------------------------
+ # Perceiver encoder layer
+ # ---------------------------------------------------------------------------
+
+ class PerceiverLatentLayer(nn.Module):
+     """Single Perceiver latent layer.
+
+     If use_cross=True:  cross-attn(latents <- points) -> self-attn -> FFN
+     If use_cross=False: self-attn -> FFN (saves compute in deep stacks)
+     """
+
+     def __init__(
+         self,
+         d_model: int,
+         num_heads: int,
+         dim_ff: int,
+         dropout: float = 0.0,
+         activation: str = "gelu",
+         kv_heads_cross: int | None = None,
+         kv_heads_self: int | None = None,
+         use_cross: bool = True,
+         norm_class=None,
+         qk_norm: bool = False,
+         qk_norm_type: str = "l2",
+     ):
+         super().__init__()
+         self.use_cross = use_cross
+         if use_cross:
+             self.cross = AttnResidual(d_model, num_heads, dropout, kv_heads=kv_heads_cross, norm_class=norm_class, qk_norm=qk_norm, qk_norm_type=qk_norm_type)
+         self.self_attn = AttnResidual(d_model, num_heads, dropout, kv_heads=kv_heads_self, norm_class=norm_class, qk_norm=qk_norm, qk_norm_type=qk_norm_type)
+         self.ffn = FFNResidual(d_model, dim_ff, dropout, activation=activation, norm_class=norm_class)
+
+     def forward(
+         self,
+         latents: torch.Tensor,
+         points: torch.Tensor,
+         points_key_padding_mask: torch.Tensor | None = None,
+     ) -> torch.Tensor:
+         if self.use_cross:
+             latents = self.cross(latents, points, memory_key_padding_mask=points_key_padding_mask)
+         latents = self.self_attn(latents, latents)
+         latents = self.ffn(latents)
+         return latents
+
+
+ # ---------------------------------------------------------------------------
+ # Segment decoder layer
+ # ---------------------------------------------------------------------------
+
+ class SegmentDecoderLayer(nn.Module):
+     """Single segment decoder layer.
+
+     cross-attn(queries <- latents) -> [cross-attn(queries <- inputs)] -> self-attn(queries) -> FFN
+
+     If input_xattn=True, adds a second cross-attention that attends directly
+     to the projected input tokens (bypassing the latent bottleneck). This gives
+     queries access to fine-grained point-level detail for vertex precision.
+     """
+
+     def __init__(
+         self,
+         d_model: int,
+         num_heads: int,
+         dim_ff: int,
+         dropout: float = 0.0,
+         activation: str = "gelu",
+         kv_heads_cross: int | None = None,
+         kv_heads_self: int | None = None,
+         norm_class=None,
+         input_xattn: bool = False,
+         qk_norm: bool = False,
+         qk_norm_type: str = "l2",
+     ):
+         super().__init__()
+         self.cross = AttnResidual(d_model, num_heads, dropout, kv_heads=kv_heads_cross, norm_class=norm_class, qk_norm=qk_norm, qk_norm_type=qk_norm_type)
+         self.input_xattn = input_xattn
+         if input_xattn:
+             self.cross_input = AttnResidual(d_model, num_heads, dropout, kv_heads=kv_heads_cross, norm_class=norm_class, qk_norm=qk_norm, qk_norm_type=qk_norm_type)
+         self.self_attn = AttnResidual(d_model, num_heads, dropout, kv_heads=kv_heads_self, norm_class=norm_class, qk_norm=qk_norm, qk_norm_type=qk_norm_type)
+         self.ffn = FFNResidual(d_model, dim_ff, dropout, activation=activation, norm_class=norm_class)
+
+     def forward(
+         self,
+         queries: torch.Tensor,
+         latents: torch.Tensor,
+         src: torch.Tensor | None = None,
+         src_key_padding_mask: torch.Tensor | None = None,
+     ) -> torch.Tensor:
+         queries = self.cross(queries, latents)
+         if self.input_xattn and src is not None:
+             queries = self.cross_input(queries, src, memory_key_padding_mask=src_key_padding_mask)
+         queries = self.self_attn(queries, queries)
+         queries = self.ffn(queries)
+         return queries
+
+
+ # ---------------------------------------------------------------------------
+ # Full model
+ # ---------------------------------------------------------------------------
+
+ class TokenTransformerSegments(nn.Module):
+     """Perceiver transformer that predicts 3D roof wireframe segments.
+
+     Takes point-cloud tokens and outputs segment endpoints as [B, S, 2, 3]
+     where S is the number of segments and each segment has two 3D endpoints.
+
+     Args:
+         segments: Number of predicted segments (S).
+         in_dim: Dimensionality of input tokens.
+         hidden: Internal hidden dimension throughout the model.
+         num_heads: Number of attention heads.
+         kv_heads_cross: Grouped-query heads for cross-attention (None = standard MHA).
+         kv_heads_self: Grouped-query heads for self-attention (None = standard MHA).
+         dim_feedforward: FFN intermediate dimension.
+         dropout: Dropout rate applied after attention and FFN.
+         latent_tokens: Number of learnable latent embeddings (L) in the bottleneck.
+         latent_layers: Number of PerceiverLatentLayers (N).
+         decoder_layers: Number of SegmentDecoderLayers (M).
+     """
+
+     def __init__(
+         self,
+         segments: int = 32,
+         in_dim: int = 128,
+         hidden: int = 128,
+         num_heads: int = 4,
+         kv_heads_cross: int | None = 2,
+         kv_heads_self: int | None = 0,
+         dim_feedforward: int = 256,
+         dropout: float = 0.01,
+         latent_tokens: int = 64,
+         latent_layers: int = 2,
+         decoder_layers: int = 2,
+         cross_attn_interval: int = 1,
+         norm_class=None,
+         activation: str = "gelu",
+         segment_conf: bool = False,
+         pre_encoder_layers: int = 0,
+         segment_param: str = "midpoint_halfvec",
+         length_floor: float = 0.0,
+         decoder_input_xattn: bool = False,
+         qk_norm: bool = False,
+         qk_norm_type: str = "l2",
+     ):
+         super().__init__()
+         self.segments = segments
+         self.out_vertices = segments * 2
+         self.segment_param = segment_param
+         self.length_floor = length_floor
+         self.decoder_input_xattn = decoder_input_xattn
+         norm_class = norm_class or nn.LayerNorm
+
+         # Treat 0 as "use standard MHA"
+         if kv_heads_cross is not None and kv_heads_cross <= 0:
+             kv_heads_cross = None
+         if kv_heads_self is not None and kv_heads_self <= 0:
+             kv_heads_self = None
+
+         # -- Input projection --
+         self.input_proj = nn.Sequential(
+             nn.Linear(in_dim, dim_feedforward),
+             nn.GELU(),
+             nn.Linear(dim_feedforward, hidden),
+             norm_class(hidden),
+         )
+
+         # -- Optional pre-encoder: self-attention on full token sequence --
+         if pre_encoder_layers > 0:
+             self.pre_encoder = nn.ModuleList([
+                 SelfAttentionEncoderLayer(
+                     d_model=hidden,
+                     num_heads=num_heads,
+                     dim_ff=dim_feedforward,
+                     dropout=dropout,
+                     activation=activation,
+                     kv_heads=kv_heads_self,
+                     norm_class=norm_class,
+                     qk_norm=qk_norm, qk_norm_type=qk_norm_type,
+                 )
+                 for _ in range(pre_encoder_layers)
+             ])
+         else:
+             self.pre_encoder = None
+
+         # -- Perceiver latent bottleneck --
+         self.latent_embed = nn.Embedding(latent_tokens, hidden)
+         N = latent_layers
+         self.latent_layers = nn.ModuleList([
+             PerceiverLatentLayer(
+                 d_model=hidden,
+                 num_heads=num_heads,
+                 dim_ff=dim_feedforward,
+                 dropout=dropout,
+                 activation=activation,
+                 kv_heads_cross=kv_heads_cross,
+                 kv_heads_self=kv_heads_self,
+                 use_cross=(i == 0) or (i == N - 1) or (i % cross_attn_interval == 0),
+                 norm_class=norm_class,
+                 qk_norm=qk_norm, qk_norm_type=qk_norm_type,
+             )
+             for i in range(N)
+         ])
+
+         # -- Segment decoder --
+         self.query_embed = nn.Embedding(segments, hidden)
+         self.decoder_layers = nn.ModuleList([
+             SegmentDecoderLayer(
+                 d_model=hidden,
+                 num_heads=num_heads,
+                 dim_ff=dim_feedforward,
+                 dropout=dropout,
+                 activation=activation,
+                 kv_heads_cross=kv_heads_cross,
+                 kv_heads_self=kv_heads_self,
+                 norm_class=norm_class,
+                 input_xattn=decoder_input_xattn,
+                 qk_norm=qk_norm, qk_norm_type=qk_norm_type,
+             )
+             for _ in range(decoder_layers)
+         ])
+
+         # -- Output head --
+         if segment_param == "midpoint_dir_len":
+             self.segment_head = nn.Linear(hidden, 7)  # mid(3) + dir(3) + len(1)
+         else:
+             self.segment_head = nn.Linear(hidden, 6)  # mid(3) + half(3)
+         self.query_offsets = nn.Parameter(torch.zeros(segments, 2, 3))
+
+         nn.init.trunc_normal_(self.segment_head.weight, mean=0.0, std=1e-3)
+         if self.segment_head.bias is not None:
+             nn.init.zeros_(self.segment_head.bias)
+         if segment_param == "midpoint_dir_len":
+             # softplus(0.5) * 0.1 ≈ 0.097 default length in normalized space
+             self.segment_head.bias.data[6] = 0.5
+         nn.init.normal_(self.query_offsets, mean=0.0, std=0.05)
+
+         # -- Optional confidence head --
+         self.segment_conf = segment_conf
+         if segment_conf:
+             self.conf_head = nn.Linear(hidden, 1)
+             nn.init.zeros_(self.conf_head.bias)
+
+     def forward(
+         self,
+         tokens: torch.Tensor,
+         mask: torch.Tensor | None = None,
+     ) -> dict[str, torch.Tensor | list]:
+         """
+         Args:
+             tokens: Input point-cloud tokens [B, T, in_dim].
+             mask: Boolean validity mask [B, T]. True = valid token.
+
+         Returns:
+             Dict with keys:
+                 "vertices": [B, S*2, 3] flattened endpoints.
+                 "segments": [B, S, 2, 3] segment endpoints.
+                 "edges": Per-batch list of (start, end) index pairs into vertices.
+                 "conf": [B, S] logits (only if segment_conf=True).
+                 "src", "pad_mask", "queries": intermediate activations for
+                     auxiliary heads and losses.
+         """
+         B = tokens.shape[0]
+
+         # Project input tokens
+         src = self.input_proj(tokens)  # [B, T, hidden]
+
+         # Padding mask (True where padded) for cross-attention
+         pad_mask = ~mask.bool() if mask is not None else None
+
+         # Optional pre-encoder: self-attention on full token sequence
+         if self.pre_encoder is not None:
+             for layer in self.pre_encoder:
+                 src = layer(src, key_padding_mask=pad_mask)
+
+         # Perceiver latent bottleneck
+         latents = self.latent_embed.weight.unsqueeze(0).expand(B, -1, -1)
+         for layer in self.latent_layers:
+             latents = layer(latents, src, points_key_padding_mask=pad_mask)
+
+         # Segment decoder
+         queries = self.query_embed.weight.unsqueeze(0).expand(B, -1, -1)
+         for layer in self.decoder_layers:
+             queries = layer(queries, latents,
+                             src=src if self.decoder_input_xattn else None,
+                             src_key_padding_mask=pad_mask if self.decoder_input_xattn else None)
+
+         # Predict segments -> endpoints
+         if self.segment_param == "midpoint_dir_len":
+             raw = self.segment_head(queries)  # [B, S, 7]
+             mid = raw[:, :, :3] + self.query_offsets[:, 0, :].unsqueeze(0)
+             direction = torch.nn.functional.normalize(raw[:, :, 3:6], dim=-1)
+             length = torch.nn.functional.softplus(raw[:, :, 6:7]) * 0.1
+             half = direction * length * 0.5
+         else:
+             raw = self.segment_head(queries).view(B, self.segments, 2, 3)
+             raw = raw + self.query_offsets.unsqueeze(0)
+             mid, half = raw[:, :, 0], raw[:, :, 1]
+         seg_params = torch.stack([mid - half, mid + half], dim=2)
+
+         vertices = seg_params.reshape(B, self.out_vertices, 3)
+         edges = [[(2 * i, 2 * i + 1) for i in range(self.segments)] for _ in range(B)]
+
+         out = {"vertices": vertices, "segments": seg_params, "edges": edges,
+                "src": src, "pad_mask": pad_mask, "queries": queries}
+         if self.segment_conf:
+             out["conf"] = self.conf_head(queries).squeeze(-1)  # [B, S]
+         return out
+
+
+ # ---------------------------------------------------------------------------
+ # Encoder-only layer (self-attention on full token sequence)
+ # ---------------------------------------------------------------------------
+
+ class SelfAttentionEncoderLayer(nn.Module):
+     """Single self-attention layer: self-attn(tokens) -> FFN."""
+
+     def __init__(
+         self,
+         d_model: int,
+         num_heads: int,
+         dim_ff: int,
+         dropout: float = 0.0,
+         activation: str = "gelu",
+         kv_heads: int | None = None,
+         norm_class=None,
+         qk_norm: bool = False,
+         qk_norm_type: str = "l2",
+     ):
+         super().__init__()
+         self.self_attn = AttnResidual(d_model, num_heads, dropout, kv_heads=kv_heads, norm_class=norm_class, qk_norm=qk_norm, qk_norm_type=qk_norm_type)
+         self.ffn = FFNResidual(d_model, dim_ff, dropout, activation=activation, norm_class=norm_class)
+
+     def forward(self, x: torch.Tensor, key_padding_mask: torch.Tensor | None = None) -> torch.Tensor:
+         x = self.self_attn(x, x, memory_key_padding_mask=key_padding_mask)
+         x = self.ffn(x)
+         return x
+
+
+ # ---------------------------------------------------------------------------
+ # Vanilla transformer: self-attention encoder + segment query decoder
+ # ---------------------------------------------------------------------------
+
+ class TransformerSegments(nn.Module):
+     """Standard transformer encoder + cross-attention segment decoder.
+
+     Architecture:
+         Input tokens [B, T, D]
+             |
+             v
+         input_proj: Linear -> GELU -> Linear -> Norm  =>  [B, T, hidden]
+             |
+             v
+         N SelfAttentionEncoderLayers (self-attn over all T tokens)
+             |
+             v
+         Segment decoder (same as Perceiver version):
+             M SegmentDecoderLayers (queries cross-attend to encoded tokens)
+             |
+             v
+         segment_head -> endpoints [B, S, 2, 3] (midpoint_halfvec or midpoint_dir_len)
+     """
+
+     def __init__(
+         self,
+         segments: int = 32,
+         in_dim: int = 128,
+         hidden: int = 128,
+         num_heads: int = 4,
+         kv_heads_cross: int | None = 2,
+         kv_heads_self: int | None = 0,
+         dim_feedforward: int = 256,
+         dropout: float = 0.01,
+         encoder_layers: int = 4,
+         decoder_layers: int = 2,
+         norm_class=None,
+         activation: str = "gelu",
+         segment_conf: bool = False,
+         segment_param: str = "midpoint_halfvec",
+         length_floor: float = 0.0,
+         decoder_input_xattn: bool = False,
+         qk_norm: bool = False,
+         qk_norm_type: str = "l2",
+     ):
+         super().__init__()
+         self.segments = segments
+         self.out_vertices = segments * 2
+         self.segment_param = segment_param
+         self.length_floor = length_floor
+         norm_class = norm_class or nn.LayerNorm
+
+         if kv_heads_cross is not None and kv_heads_cross <= 0:
+             kv_heads_cross = None
+         if kv_heads_self is not None and kv_heads_self <= 0:
+             kv_heads_self = None
+
+         # -- Input projection --
+         self.input_proj = nn.Sequential(
+             nn.Linear(in_dim, dim_feedforward),
+             nn.GELU(),
+             nn.Linear(dim_feedforward, hidden),
+             norm_class(hidden),
+         )
+
+         # -- Self-attention encoder --
+         self.encoder_layers = nn.ModuleList([
+             SelfAttentionEncoderLayer(
+                 d_model=hidden,
+                 num_heads=num_heads,
+                 dim_ff=dim_feedforward,
+                 dropout=dropout,
+                 activation=activation,
+                 kv_heads=kv_heads_self,
+                 norm_class=norm_class,
+                 qk_norm=qk_norm, qk_norm_type=qk_norm_type,
+             )
+             for _ in range(encoder_layers)
+         ])
+
+         # -- Segment decoder (same structure as Perceiver version) --
+         # Note: for the transformer arch, decoder_input_xattn is ignored because
+         # the decoder already cross-attends to the full encoded token sequence.
+         self.query_embed = nn.Embedding(segments, hidden)
+         self.decoder_layers = nn.ModuleList([
+             SegmentDecoderLayer(
+                 d_model=hidden,
+                 num_heads=num_heads,
+                 dim_ff=dim_feedforward,
+                 dropout=dropout,
+                 activation=activation,
+                 kv_heads_cross=kv_heads_cross,
+                 kv_heads_self=kv_heads_self,
+                 norm_class=norm_class,
+                 qk_norm=qk_norm, qk_norm_type=qk_norm_type,
+             )
+             for _ in range(decoder_layers)
+         ])
+
+         # -- Output head (shared logic with Perceiver version) --
+         if segment_param == "midpoint_dir_len":
+             self.segment_head = nn.Linear(hidden, 7)  # mid(3) + dir(3) + len(1)
+         else:
+             self.segment_head = nn.Linear(hidden, 6)  # mid(3) + half(3)
+         self.query_offsets = nn.Parameter(torch.zeros(segments, 2, 3))
+
+         nn.init.trunc_normal_(self.segment_head.weight, mean=0.0, std=1e-3)
+         if self.segment_head.bias is not None:
+             nn.init.zeros_(self.segment_head.bias)
+         if segment_param == "midpoint_dir_len":
+             # softplus(-2.2) * 0.1 ≈ 0.01 default length in normalized space
+             self.segment_head.bias.data[6] = -2.2
+         nn.init.normal_(self.query_offsets, mean=0.0, std=0.05)
+
+         self.segment_conf = segment_conf
+         if segment_conf:
+             self.conf_head = nn.Linear(hidden, 1)
+             nn.init.zeros_(self.conf_head.bias)
+
+     def forward(
+         self,
+         tokens: torch.Tensor,
+         mask: torch.Tensor | None = None,
+     ) -> dict[str, torch.Tensor | list]:
+         B = tokens.shape[0]
+
+         src = self.input_proj(tokens)
+         pad_mask = ~mask.bool() if mask is not None else None
+
+         # Encode: self-attention over all tokens
+         for layer in self.encoder_layers:
+             src = layer(src, key_padding_mask=pad_mask)
+
+         # Decode: segment queries cross-attend to encoded tokens
+         queries = self.query_embed.weight.unsqueeze(0).expand(B, -1, -1)
+         for layer in self.decoder_layers:
+             queries = layer(queries, src)
+
+         # Predict segments -> endpoints
+         if self.segment_param == "midpoint_dir_len":
+             raw = self.segment_head(queries)  # [B, S, 7]
+             mid = raw[:, :, :3] + self.query_offsets[:, 0, :].unsqueeze(0)
+             direction = torch.nn.functional.normalize(raw[:, :, 3:6], dim=-1)
+             length = torch.nn.functional.softplus(raw[:, :, 6:7]) * 0.1
+             half = direction * length * 0.5
+         else:
+             raw = self.segment_head(queries).view(B, self.segments, 2, 3)
+             raw = raw + self.query_offsets.unsqueeze(0)
+             mid, half = raw[:, :, 0], raw[:, :, 1]
+         seg_params = torch.stack([mid - half, mid + half], dim=2)
+
+         vertices = seg_params.reshape(B, self.out_vertices, 3)
+         edges = [[(2 * i, 2 * i + 1) for i in range(self.segments)] for _ in range(B)]
+
+         out = {"vertices": vertices, "segments": seg_params, "edges": edges}
+         if self.segment_conf:
+             out["conf"] = self.conf_head(queries).squeeze(-1)
+         return out
+
+
+ # ---------------------------------------------------------------------------
+ # End-to-end model: tokenizer embeddings + transformer/perceiver
+ # ---------------------------------------------------------------------------
+
+ class EdgeDepthSegmentsModel(nn.Module):
+     """Tokenizer embeddings + transformer for 3D roof wireframes.
+
+     Supports two architectures via the `arch` parameter:
+       - "perceiver": Perceiver latent bottleneck (default, O(L*T) attention)
+       - "transformer": Standard self-attention encoder (O(T^2) attention)
+
+     Both share the same decoder, output head, and tokenizer.
+     """
+
+     def __init__(
+         self,
+         seq_cfg,
+         segments: int = 32,
+         hidden: int = 128,
+         num_heads: int = 4,
+         kv_heads_cross: int | None = 2,
+         kv_heads_self: int | None = 0,
+         dim_feedforward: int = 256,
+         dropout: float = 0.1,
+         latent_tokens: int = 64,
+         latent_layers: int = 1,
+         decoder_layers: int = 2,
+         label_emb_dim: int = 16,
+         src_emb_dim: int = 2,
+         behind_emb_dim: int = 8,
+         fourier_seed: int = 0,
+         cross_attn_interval: int = 1,
+         norm_class=None,
+         activation: str = "gelu",
+         segment_conf: bool = False,
+         use_vote_features: bool = False,
+         arch: str = "perceiver",
+         encoder_layers: int = 4,
+         pre_encoder_layers: int = 0,
+         segment_param: str = "midpoint_halfvec",
+         length_floor: float = 0.0,
+         decoder_input_xattn: bool = False,
+         qk_norm: bool = False,
+         qk_norm_type: str = "l2",
+         learnable_fourier: bool = False,
+     ):
+         super().__init__()
+         self.seq_cfg = seq_cfg
+
+         from .tokenizer import EdgeDepthSequenceBuilder
+         self.tokenizer = EdgeDepthSequenceBuilder(
+             seq_cfg,
+             label_emb_dim=label_emb_dim,
+             src_emb_dim=src_emb_dim,
+             behind_emb_dim=behind_emb_dim,
+             fourier_seed=fourier_seed,
+             use_vote_features=use_vote_features,
+             learnable_fourier=learnable_fourier,
+         )
+
+         if arch == "transformer":
+             self.segmenter = TransformerSegments(
+                 segments=segments,
+                 in_dim=self.tokenizer.out_dim,
+                 hidden=hidden,
+                 num_heads=num_heads,
+                 kv_heads_cross=kv_heads_cross,
+                 kv_heads_self=kv_heads_self,
+                 dim_feedforward=dim_feedforward,
+                 dropout=dropout,
+                 encoder_layers=encoder_layers,
+                 decoder_layers=decoder_layers,
+                 norm_class=norm_class,
+                 activation=activation,
+                 segment_conf=segment_conf,
+                 segment_param=segment_param,
+                 length_floor=length_floor,
+                 decoder_input_xattn=decoder_input_xattn,
+                 qk_norm=qk_norm, qk_norm_type=qk_norm_type,
+             )
+         else:
+             self.segmenter = TokenTransformerSegments(
+                 segments=segments,
+                 in_dim=self.tokenizer.out_dim,
+                 hidden=hidden,
+                 num_heads=num_heads,
+                 kv_heads_cross=kv_heads_cross,
+                 kv_heads_self=kv_heads_self,
+                 dim_feedforward=dim_feedforward,
+                 dropout=dropout,
+                 latent_tokens=latent_tokens,
+                 latent_layers=latent_layers,
+                 decoder_layers=decoder_layers,
+                 cross_attn_interval=cross_attn_interval,
+                 norm_class=norm_class,
+                 activation=activation,
+                 segment_conf=segment_conf,
+                 pre_encoder_layers=pre_encoder_layers,
+                 segment_param=segment_param,
+                 length_floor=length_floor,
+                 decoder_input_xattn=decoder_input_xattn,
+                 qk_norm=qk_norm, qk_norm_type=qk_norm_type,
+             )
+
+     def forward_tokens(self, tokens: torch.Tensor, mask: torch.Tensor):
+         """Run the segmenter on pre-built token tensors."""
+         return self.segmenter(tokens, mask)
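
As a quick shape sanity check (an illustrative sketch, not a training script -- random tokens stand in for the tokenizer output the model assumes):

    import torch

    model = TokenTransformerSegments(segments=32, in_dim=128, hidden=128, num_heads=4,
                                     latent_tokens=64, latent_layers=2,
                                     decoder_layers=2, segment_conf=True)
    tokens = torch.randn(2, 2048, 128)
    mask = torch.ones(2, 2048, dtype=torch.bool)

    out = model(tokens, mask)
    assert out["segments"].shape == (2, 32, 2, 3)   # S segments, two 3D endpoints each
    assert out["vertices"].shape == (2, 64, 3)      # S*2 flattened endpoints
    assert out["conf"].shape == (2, 32)             # per-segment confidence logits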
s23dr_2026_example/point_fusion.py ADDED
@@ -0,0 +1,554 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ point_fusion.py
3
+
4
+ Simplified semantic point fusion for the 2026 dataset format.
5
+
6
+ Takes per-view (ADE segmap, Gestalt segmap, depth) + sparse COLMAP point cloud
7
+ from the usm3d/hoho22k_2026_trainval dataset and builds a compact, house-centric
8
+ semantic point representation suitable for downstream wireframe prediction.
9
+
10
+ Key differences from the 2025 pipeline:
11
+ - COLMAP is a ZIP of text files (cameras.txt, images.txt, points3D.txt)
12
+ - Depth is millimeter I;16 PNG (depth_scale=0.001 converts to meters)
13
+ - Views flagged with pose_only_in_colmap=True have zeroed K/R/t and must be
14
+ skipped for depth unprojection and projection
15
+ - Images arrive as PIL Images, not byte arrays
16
+ """
17
+
18
+ from __future__ import annotations
19
+
20
+ import zipfile
21
+ from dataclasses import dataclass
22
+ from io import BytesIO
23
+ from typing import Dict, List, Optional, Tuple
24
+
25
+ import cv2
26
+ import numpy as np
27
+ from scipy.stats import mode as scipy_mode
28
+
29
+ from .color_mappings import ade20k_color_mapping, gestalt_color_mapping
30
+
31
+ # ---------------------------------------------------------------------------
32
+ # Color packing helpers
33
+ # ---------------------------------------------------------------------------
34
+
35
+ def _pack_rgb_u32(rgb: np.ndarray) -> np.ndarray:
36
+ """Pack uint8 RGB (..., 3) into uint32 codes."""
37
+ rgb = rgb.astype(np.uint32, copy=False)
38
+ return (rgb[..., 0] << 16) | (rgb[..., 1] << 8) | rgb[..., 2]
39
+
40
+
41
+ def _build_rgbcode_maps(color_mapping):
42
+ """Return (rgbcode_to_id, id_to_name) for a color mapping dict."""
43
+ names = list(color_mapping.keys())
44
+ rgbs = np.array([color_mapping[n] for n in names], dtype=np.uint8)
45
+ codes = _pack_rgb_u32(rgbs.reshape(-1, 1, 3)).reshape(-1)
46
+ rgbcode_to_id = {int(c): i for i, c in enumerate(codes)}
47
+ return rgbcode_to_id, names
48
+
49
+
50
+ def _name_to_packed_rgb(name, mapping):
51
+ """Case-insensitive lookup returning a packed RGB code, or None."""
52
+ for key in mapping:
53
+ if key.lower() == name.lower():
54
+ rgb = np.array(mapping[key], np.uint8).reshape(1, 1, 3)
55
+ return int(_pack_rgb_u32(rgb).reshape(()))
56
+ return None
57
+
58
+ # ---------------------------------------------------------------------------
59
+ # Label mapping constants
60
+ # ---------------------------------------------------------------------------
61
+
62
+ ADE_RGBCODE_TO_ID, ADE_ID_TO_NAME = _build_rgbcode_maps(ade20k_color_mapping)
63
+ GEST_RGBCODE_TO_ID, GEST_ID_TO_NAME = _build_rgbcode_maps(gestalt_color_mapping)
64
+ NUM_ADE = len(ADE_ID_TO_NAME)
65
+ NUM_GEST = len(GEST_ID_TO_NAME)
66
+
67
+ GEST_INVALID_NAMES = ("unclassified", "unknown", "transition_line")
68
+ GEST_INVALID_CODES = set(
69
+ int(_pack_rgb_u32(np.array(gestalt_color_mapping[n], np.uint8).reshape(1, 1, 3)).reshape(()))
70
+ for n in GEST_INVALID_NAMES if n in gestalt_color_mapping
71
+ )
72
+
73
+ # ADE classes whose surfaces are "see-through" for label fusion: when a point
74
+ # projects onto one of these, we use the Gestalt label behind it instead.
75
+ ADE_TRANSPARENT_NAMES = (
76
+ "wall", "building;edifice", "floor;flooring", "ceiling",
77
+ "windowpane;window", "door;double;door", "house", "skyscraper",
78
+ "screen;door;screen", "blind;screen", "hovel;hut;hutch;shack;shanty",
79
+ "tower", "booth;cubicle;stall;kiosk",
80
+ )
81
+
82
+ # ADE classes kept as "occluders/add-ons" when overlapping the house silhouette.
83
+ ADE_OCCLUDER_ALLOWLIST_NAMES = (
84
+ "tree", "person;individual;someone;somebody;mortal;soul",
85
+ "car;auto;automobile;machine;motorcar", "truck;motortruck", "van",
86
+ "fence;fencing", "railing;rail",
87
+ "bannister;banister;balustrade;balusters;handrail",
88
+ "stairs;steps", "stairway;staircase", "step;stair", "pole",
89
+ "streetlight;street;lamp", "signboard;sign", "awning;sunshade;sunblind",
90
+ "plant;flora;plant;life", "pot;flowerpot",
91
+ )
92
+
93
+ # Precomputed arrays for the default name lists (avoids re-lookup every call).
94
+ _DEFAULT_ADE_TRANSPARENT_CODES = np.array(
95
+ [c for n in ADE_TRANSPARENT_NAMES
96
+ if (c := _name_to_packed_rgb(n, ade20k_color_mapping)) is not None],
97
+ dtype=np.uint32,
98
+ )
99
+ _DEFAULT_ADE_OCCLUDER_IDS = np.array(
100
+ sorted({ADE_RGBCODE_TO_ID[c]
101
+ for n in ADE_OCCLUDER_ALLOWLIST_NAMES
102
+ if (c := _name_to_packed_rgb(n, ade20k_color_mapping)) is not None
103
+ and c in ADE_RGBCODE_TO_ID}),
104
+ dtype=np.int32,
105
+ )
106
+
107
+ # ---------------------------------------------------------------------------
108
+ # Config
109
+ # ---------------------------------------------------------------------------
110
+
111
+ @dataclass(frozen=True)
112
+ class FuserConfig:
113
+ """Simplified fusion configuration (no depth calibration fields)."""
114
+ depth_points_per_view: int = 20_000 # depth samples per view
115
+ depth_scale: float = 0.001 # mm -> meters
116
+ depth_clip_percentile: float = 99.5 # drop extreme outliers
117
+ house_mask_dilate_px: int = 5 # dilate gestalt mask
118
+ min_support_views: int = 1 # min views for a kept point
119
+ ade_transparent_classes: Tuple[str, ...] = ADE_TRANSPARENT_NAMES
120
+ ade_occluder_allowlist: Tuple[str, ...] = ADE_OCCLUDER_ALLOWLIST_NAMES
121
+
122
+ # ---------------------------------------------------------------------------
123
+ # Geometry: projection + depth unprojection
124
+ # ---------------------------------------------------------------------------
125
+
126
+ def project_world_points(points_world, K, R, t):
127
+ """Project (N,3) world points to pixel (u,v) with validity mask."""
128
+ pts = points_world.astype(np.float32, copy=False)
129
+ cam = (R @ pts.T + t).T # (N, 3)
130
+ z = cam[:, 2]
131
+ valid = z > 1e-6
132
+ inv_z = np.zeros_like(z)
133
+ inv_z[valid] = 1.0 / z[valid]
134
+ x = cam[:, 0] * inv_z
135
+ y = cam[:, 1] * inv_z
136
+ u = K[0, 0] * x + K[0, 2]
137
+ v = K[1, 1] * y + K[1, 2]
138
+ return u, v, valid
139
+
140
+
141
+ def unproject_depth_to_world(depth, K, R, t, num_points, sample_mask=None, rng=None):
142
+ """Convert a depth map + camera params to (M, 3) world points, M <= num_points."""
143
+ if rng is None:
144
+ rng = np.random.default_rng()
145
+ d = np.asarray(depth, dtype=np.float32)
146
+ if d.ndim != 2:
147
+ return np.zeros((0, 3), dtype=np.float32)
148
+
149
+ valid = np.isfinite(d) & (d > 1e-6)
150
+ if sample_mask is not None:
151
+ mask = np.asarray(sample_mask, dtype=bool)
152
+ if mask.shape != d.shape:
153
+ return np.zeros((0, 3), dtype=np.float32)
154
+ valid &= mask
155
+
156
+ ys, xs = np.where(valid)
157
+ if ys.size == 0:
158
+ return np.zeros((0, 3), dtype=np.float32)
159
+
160
+ idx = rng.choice(ys.size, size=min(num_points, ys.size), replace=False)
161
+ y = ys[idx].astype(np.float32)
162
+ x = xs[idx].astype(np.float32)
163
+ z = d[ys[idx], xs[idx]].astype(np.float32)
164
+
165
+ fx, fy, cx, cy = K[0, 0], K[1, 1], K[0, 2], K[1, 2]
166
+ cam_pts = np.stack([(x - cx) * z / fx, (y - cy) * z / fy, z], axis=0)
167
+ # cam = R * world + t => world = R^T * (cam - t)
168
+ world = (R.T @ (cam_pts - t)).T
169
+ return world.astype(np.float32, copy=False)
170
+
171
+
172
+ def clean_depth(depth, clip_percentile):
173
+ """Clip extreme depth values."""
174
+ d = np.asarray(depth, dtype=np.float32)
175
+ d = np.where(np.isfinite(d), d, 0.0)
176
+ d[d <= 0] = 0.0
177
+ if clip_percentile is not None and clip_percentile > 0 and np.any(d > 0):
178
+ hi = float(np.percentile(d[d > 0], clip_percentile))
179
+ d = np.clip(d, 0.0, hi)
180
+ return d
181
+
182
+
183
+ def dilate_mask(mask, radius_px):
184
+ """Binary dilation via cv2. mask: (H, W) bool."""
185
+ if radius_px <= 0:
186
+ return mask
187
+ k = 2 * radius_px + 1
188
+ kernel = np.ones((k, k), np.uint8)
189
+ return cv2.dilate(mask.astype(np.uint8), kernel) > 0
190
+
191
+ # ---------------------------------------------------------------------------
192
+ # COLMAP extraction (2026 format)
193
+ # ---------------------------------------------------------------------------
194
+
195
+ def extract_colmap_points_2026(sample):
196
+ """Extract (N, 3) float32 COLMAP world points from a 2026-format sample.
197
+
198
+ sample['colmap'] must be a ZIP archive containing points3D.txt.
199
+ Fails fast if that file is missing (it is always present in the 2026 format).
200
+ """
201
+ colmap_blob = sample.get("colmap")
202
+ if colmap_blob is None:
203
+ return np.zeros((0, 3), dtype=np.float32)
204
+ if not isinstance(colmap_blob, (bytes, bytearray, memoryview)):
205
+ return np.zeros((0, 3), dtype=np.float32)
206
+
207
+ try:
208
+ with zipfile.ZipFile(BytesIO(colmap_blob)) as zf:
209
+ if "points3D.txt" not in set(zf.namelist()):
210
+ raise FileNotFoundError(
211
+ "COLMAP ZIP is missing points3D.txt -- "
212
+ "this is required in the 2026 dataset format")
213
+ with zf.open("points3D.txt") as f:
214
+ text = f.read().decode("utf-8", errors="ignore")
215
+ # Format: POINT3D_ID X Y Z R G B ERROR TRACK[]
216
+ # Filter comment/blank lines, parse columns 1-3 (X,Y,Z)
217
+ from io import StringIO
218
+ clean = "\n".join(l for l in text.split("\n") if l and not l.startswith("#"))
219
+ if not clean:
220
+ return np.zeros((0, 3), dtype=np.float32)
221
+ return np.loadtxt(StringIO(clean), dtype=np.float32, usecols=(1, 2, 3))
222
+ except zipfile.BadZipFile:
223
+ pass
224
+ return np.zeros((0, 3), dtype=np.float32)
225
+
226
+ # ---------------------------------------------------------------------------
227
+ # Label helpers
228
+ # ---------------------------------------------------------------------------
229
+
230
+ def _codes_from_image(img):
231
+ """Convert a PIL Image or numpy array to a (H, W) uint32 packed-RGB map."""
232
+ arr = np.asarray(img)
233
+ if arr.ndim == 2:
234
+ arr = np.stack([arr, arr, arr], axis=-1)
235
+ arr = arr[..., :3]
236
+ if arr.dtype != np.uint8:
237
+ arr = np.clip(arr, 0, 255).astype(np.uint8)
238
+ return _pack_rgb_u32(arr)
239
+
240
+
241
+ def _row_majority(values):
242
+ """Row-wise majority vote on (P, V) int array; -1 means "no vote".
243
+ Returns (P,) with the most frequent non-negative value per row, or -1.
244
+
245
+ Masks -1 entries before voting so that abstentions don't outvote
246
+ actual labels (which happens when a point is visible in only 1-2 views).
247
+ """
248
+ P, V = values.shape
249
+ result = np.full(P, -1, dtype=values.dtype)
250
+
251
+ # For each row, find the most frequent non-negative value.
252
+ # Vectorized approach: flatten valid entries per row using argmax on counts.
253
+ # Since values are typically small non-negative ints (0-200), we can use
254
+ # a simple max-of-first-valid approach for speed when V is small.
255
+ for vi in range(V):
256
+ # For rows still unset, take the first valid vote
257
+ col = values[:, vi]
258
+ unset = result == -1
259
+ has_val = col >= 0
260
+ update = unset & has_val
261
+ result[update] = col[update]
262
+
263
+ # Now refine: if a row has multiple different valid votes, pick the mode.
264
+ # Check if any row has conflicting votes across views.
265
+ has_any = np.any(values >= 0, axis=1)
266
+ n_valid = np.sum(values >= 0, axis=1)
267
+ needs_vote = has_any & (n_valid > 1)
268
+
269
+ if np.any(needs_vote):
270
+ for i in np.where(needs_vote)[0]:
271
+ valid = values[i][values[i] >= 0]
272
+ # Use numpy bincount for speed (values are small non-neg ints)
273
+ counts = np.bincount(valid.astype(np.intp))
274
+ result[i] = counts.argmax()
275
+
276
+ return result
277
+
278
+ # ---------------------------------------------------------------------------
279
+ # Semantic fusion: house-centric, occluder-aware
280
+ # ---------------------------------------------------------------------------
281
+
282
+ def _fuse_labels_for_points(
283
+ points_world, Ks, Rs, ts, ade_images, gestalt_images,
284
+ ade_transparent_codes, ade_occluder_allowed_ids,
285
+ min_support_views, valid_view_mask=None,
286
+ ):
287
+ """Multi-view semantic label fusion with majority voting.
288
+
289
+ For each 3D point, project into every valid view:
290
+ - ADE "envelope" class -> use the Gestalt label behind it.
291
+ - ADE non-envelope -> keep if on the occluder allowlist.
292
+ Then majority-vote across views.
293
+
294
+ Returns dict: keep, visible_src, visible_id, behind_gest_id, support
295
+ """
296
+ P = points_world.shape[0]
297
+ V = min(len(Ks), len(Rs), len(ts), len(ade_images), len(gestalt_images))
298
+ empty = {
299
+ "keep": np.zeros(P, dtype=bool),
300
+ "visible_src": np.zeros(P, np.uint8),
301
+ "visible_id": np.full(P, -1, np.int16),
302
+ "behind_gest_id": np.full(P, -1, np.int16),
303
+ "support": np.zeros(P, np.uint8),
304
+ }
305
+     if P == 0 or V == 0:
+         return empty
+
+     # Per-view labels. src: 1=gestalt, 2=ade; -1 = no contribution.
+     visible_src_pv = np.full((P, V), -1, dtype=np.int8)
+     visible_id_pv = np.full((P, V), -1, dtype=np.int32)
+     behind_id_pv = np.full((P, V), -1, dtype=np.int32)
+     support = np.zeros(P, dtype=np.int32)
+
+     ade_allowed_set = set(ade_occluder_allowed_ids.tolist())
+     ade_transparent_u32 = ade_transparent_codes.astype(np.uint32, copy=False)
+     gest_invalid_arr = np.array(list(GEST_INVALID_CODES), dtype=np.uint32)
+
+     for vi in range(V):
+         if valid_view_mask is not None and not valid_view_mask[vi]:
+             continue
+
+         K = np.asarray(Ks[vi], np.float32)
+         R = np.asarray(Rs[vi], np.float32)
+         t = np.asarray(ts[vi], np.float32).reshape(3, 1)
+
+         ade_codes_img = _codes_from_image(ade_images[vi])
+         gest_codes_img = _codes_from_image(gestalt_images[vi])
+         H, W = ade_codes_img.shape
+
+         u, v, valid = project_world_points(points_world, K, R, t)
+         in_img = valid & (u >= 0) & (u < W) & (v >= 0) & (v < H)
+         if not np.any(in_img):
+             continue
+
+         ui = np.clip(np.round(u[in_img]).astype(np.int32), 0, W - 1)
+         vi_pix = np.clip(np.round(v[in_img]).astype(np.int32), 0, H - 1)
+         ade_codes = ade_codes_img[vi_pix, ui]
+         gest_codes = gest_codes_img[vi_pix, ui]
+
+         in_house = ~np.isin(gest_codes, gest_invalid_arr)
+         if not np.any(in_house):
+             continue
+
+         idx = np.where(in_img)[0][in_house]
+         ade_codes_h = ade_codes[in_house]
+         gest_codes_h = gest_codes[in_house]
+
+         behind_local = np.array(
+             [GEST_RGBCODE_TO_ID.get(int(c), -1) for c in gest_codes_h],
+             dtype=np.int32)
+         behind_id_pv[idx, vi] = behind_local
+
+         ade_is_transparent = np.isin(ade_codes_h, ade_transparent_u32)
+
+         # Case A: ADE is envelope -- use Gestalt label.
+         mask_a = ade_is_transparent & (behind_local >= 0)
+         if np.any(mask_a):
+             visible_src_pv[idx[mask_a], vi] = 1
+             visible_id_pv[idx[mask_a], vi] = behind_local[mask_a]
+
+         # Case B: ADE is non-envelope -- use ADE label (allowlist-filtered).
+         mask_b = ~ade_is_transparent
+         if np.any(mask_b):
+             ade_local = np.array(
+                 [ADE_RGBCODE_TO_ID.get(int(c), -1) for c in ade_codes_h[mask_b]],
+                 dtype=np.int32)
+             on_allowlist = np.array(
+                 [int(a) in ade_allowed_set for a in ade_local], dtype=bool
+             ) & (ade_local >= 0)
+             if np.any(on_allowlist):
+                 visible_src_pv[idx[mask_b][on_allowlist], vi] = 2
+                 visible_id_pv[idx[mask_b][on_allowlist], vi] = ade_local[on_allowlist]
+
+         support[idx] += 1
+
+     # ---- Aggregate across views via majority vote ----
+     keep = (support >= min_support_views) & np.any(visible_src_pv >= 0, axis=1)
+
+     # Combine (src, id) into a single key for voting, then split back.
+     # src in {1,2} and id in [0, ~150], so stride=100k avoids collisions.
+     VIS_STRIDE = 100_000
+     vis_key = np.where(
+         visible_src_pv >= 0,
+         visible_src_pv.astype(np.int64) * VIS_STRIDE + visible_id_pv.astype(np.int64),
+         -1)
+     voted_key = _row_majority(vis_key)
+     voted_behind = _row_majority(behind_id_pv)
+
+     final_src = np.zeros(P, dtype=np.uint8)
+     final_id = np.full(P, -1, dtype=np.int16)
+     ok = voted_key >= 0
+     if np.any(ok):
+         final_src[ok] = (voted_key[ok] // VIS_STRIDE).astype(np.uint8)
+         final_id[ok] = (voted_key[ok] % VIS_STRIDE).astype(np.int16)
+
+     # ---- Vote confidence metadata ----
+     n_views_voted = np.sum(visible_src_pv >= 0, axis=1).astype(np.uint8)
+
+     # Fraction of voting views that agreed with the majority label
+     vote_frac = np.zeros(P, dtype=np.float32)
+     if np.any(ok):
+         for i in np.where(ok)[0]:
+             votes = vis_key[i][vis_key[i] >= 0]
+             if len(votes) > 0:
+                 vote_frac[i] = (votes == voted_key[i]).sum() / len(votes)
+
+     return {
+         "keep": keep,
+         "visible_src": final_src,
+         "visible_id": final_id,
+         "behind_gest_id": voted_behind.astype(np.int16),
+         "support": support.astype(np.uint8),
+         "n_views_voted": n_views_voted,
+         "vote_frac": vote_frac,
+     }
+
+ # ---------------------------------------------------------------------------
+ # Compact scene builder (2026 dataset format)
+ # ---------------------------------------------------------------------------
+
+ def _resolve_ade_codes(cfg):
+     """Return (transparent_codes, occluder_ids) for the given config.
+     Uses precomputed module-level arrays when the config has default names.
+     """
+     if cfg.ade_transparent_classes == ADE_TRANSPARENT_NAMES:
+         transparent = _DEFAULT_ADE_TRANSPARENT_CODES
+     else:
+         transparent = np.array(
+             [c for n in cfg.ade_transparent_classes
+              if (c := _name_to_packed_rgb(n, ade20k_color_mapping)) is not None],
+             dtype=np.uint32)
+
+     if cfg.ade_occluder_allowlist == ADE_OCCLUDER_ALLOWLIST_NAMES:
+         occluder_ids = _DEFAULT_ADE_OCCLUDER_IDS
+     else:
+         occluder_ids = np.array(
+             sorted({ADE_RGBCODE_TO_ID[c]
+                     for n in cfg.ade_occluder_allowlist
+                     if (c := _name_to_packed_rgb(n, ade20k_color_mapping)) is not None
+                     and c in ADE_RGBCODE_TO_ID}),
+             dtype=np.int32)
+     return transparent, occluder_ids
+
+
+ def _parse_gt_array(sample, key, dtype, expected_cols):
+     """Parse an optional ground-truth array from the sample dict."""
+     raw = sample.get(key)
+     if raw is None:
+         return None
+     arr = np.asarray(raw, dtype=dtype)
+     if arr.ndim == 2 and arr.shape[1] == expected_cols:
+         return arr
+     return None
+
+
+ def build_compact_scene(sample, cfg, rng):
+     """Build a compact semantic point representation from a HuggingFace sample.
+
+     Expected sample keys: K, R, t, ade, gestalt, depth, colmap,
+     pose_only_in_colmap, wf_vertices (opt), wf_edges (opt), __key__ (opt).
+
+     Returns dict (xyz, source, visible_src, visible_id, behind_gest_id,
+     gt_vertices, gt_edges, sample_id) or None if no points survive fusion.
+     """
+     Ks = sample.get("K") or []
+     Rs = sample.get("R") or []
+     ts = sample.get("t") or []
+     ade_imgs = sample.get("ade") or []
+     gest_imgs = sample.get("gestalt") or []
+     depths = sample.get("depth") or []
+     pose_flags = sample.get("pose_only_in_colmap") or []
+
+     V = min(len(Ks), len(Rs), len(ts), len(ade_imgs), len(gest_imgs))
+     if V == 0:
+         return None
+
+     valid_view = [not (vi < len(pose_flags) and pose_flags[vi]) for vi in range(V)]
+     if not any(valid_view):
+         return None
+
+     # ---- COLMAP points ----
+     colmap_pts = extract_colmap_points_2026(sample)
+
+     # ---- Precompute house masks (from Gestalt), optionally dilated ----
+     gest_invalid_arr = np.array(list(GEST_INVALID_CODES), dtype=np.uint32)
+     house_masks = []
+     for vi in range(V):
+         if not valid_view[vi]:
+             house_masks.append(None)
+             continue
+         mask = ~np.isin(_codes_from_image(gest_imgs[vi]), gest_invalid_arr)
+         if cfg.house_mask_dilate_px > 0:
+             mask = dilate_mask(mask, cfg.house_mask_dilate_px)
+         house_masks.append(mask)
+
+     # ---- Sample depth points per view ----
+     depth_points_all = []
+     for vi in range(min(V, len(depths))):
+         if not valid_view[vi] or depths[vi] is None:
+             continue
+         d = clean_depth(
+             np.asarray(depths[vi], dtype=np.float32) * cfg.depth_scale,
+             cfg.depth_clip_percentile)
+         pts = unproject_depth_to_world(
+             depth=d,
+             K=np.asarray(Ks[vi], np.float32),
+             R=np.asarray(Rs[vi], np.float32),
+             t=np.asarray(ts[vi], np.float32).reshape(3, 1),
+             num_points=cfg.depth_points_per_view,
+             sample_mask=house_masks[vi], rng=rng)
+         if pts.shape[0]:
+             depth_points_all.append(pts)
+
+     # ---- Combine COLMAP + depth points ----
+     pts_list, src_list = [], []
+     if colmap_pts.shape[0]:
+         pts_list.append(colmap_pts)
+         src_list.append(np.zeros(colmap_pts.shape[0], dtype=np.uint8))  # 0=colmap
+     if depth_points_all:
+         all_depth = np.concatenate(depth_points_all, axis=0)
+         pts_list.append(all_depth)
+         src_list.append(np.ones(all_depth.shape[0], dtype=np.uint8))  # 1=depth
+     if not pts_list:
+         return None
+
+     points_world = np.concatenate(pts_list, axis=0).astype(np.float32, copy=False)
+     point_source = np.concatenate(src_list, axis=0).astype(np.uint8, copy=False)
+
+     # ---- Fuse semantic labels ----
+     ade_transparent_arr, ade_allow_ids = _resolve_ade_codes(cfg)
+     fused = _fuse_labels_for_points(
+         points_world=points_world, Ks=Ks, Rs=Rs, ts=ts,
+         ade_images=ade_imgs, gestalt_images=gest_imgs,
+         ade_transparent_codes=ade_transparent_arr,
+         ade_occluder_allowed_ids=ade_allow_ids,
+         min_support_views=cfg.min_support_views,
+         valid_view_mask=valid_view)
+
+     keep = fused["keep"]
+     if not np.any(keep):
+         return None
+
+     return {
+         "xyz": points_world[keep],
+         "source": point_source[keep],               # 0=colmap, 1=monodepth
+         "visible_src": fused["visible_src"][keep],  # 1=gestalt, 2=ade
+         "visible_id": fused["visible_id"][keep],
+         "behind_gest_id": fused["behind_gest_id"][keep],
+         "n_views_voted": fused["n_views_voted"][keep],
+         "vote_frac": fused["vote_frac"][keep],
+         "gt_vertices": _parse_gt_array(sample, "wf_vertices", np.float32, 3),
+         "gt_edges": _parse_gt_array(sample, "wf_edges", np.int64, 2),
+         "sample_id": sample.get("__key__", None),
+     }
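The stride-packed vote key is the one non-obvious trick in the fusion step, so here is a minimal standalone sketch (not part of the commit) of how it round-trips. The toy row_majority below mirrors _row_majority only so the snippet runs on its own; the input values are hypothetical.

import numpy as np

VIS_STRIDE = 100_000  # same stride as in point_fusion.py

def row_majority(values):
    # Per-row majority vote over non-negative entries; -1 where a row has no votes.
    out = np.full(values.shape[0], -1, dtype=np.int64)
    for i in range(values.shape[0]):
        valid = values[i][values[i] >= 0]
        if valid.size:
            out[i] = np.bincount(valid.astype(np.intp)).argmax()
    return out

# Two points seen in three views: (src, id) packed into one integer key.
src = np.array([[1, 1, 2], [2, -1, 2]])   # 1=gestalt, 2=ade, -1=no vote
cid = np.array([[4, 4, 9], [7, -1, 7]])
key = np.where(src >= 0, src.astype(np.int64) * VIS_STRIDE + cid, -1)

voted = row_majority(key)
print(voted // VIS_STRIDE, voted % VIS_STRIDE)  # [1 2] [4 7]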
s23dr_2026_example/postprocess_v2.py ADDED
@@ -0,0 +1,39 @@
+ """Post-processing functions for segment predictions."""
+ import numpy as np
+
+
+ def snap_to_point_cloud(vertices, xyz, class_id, snap_radius=0.5,
+                         target_classes=None):
+     """Snap vertices to nearby point cloud clusters of specific semantic classes."""
+     if target_classes is None:
+         target_classes = [1, 2]  # apex, eave_end_point
+
+     snapped = vertices.copy()
+     mask = np.isin(class_id, target_classes)
+
+     if mask.sum() < 2:
+         return snapped
+
+     target_pts = xyz[mask]
+
+     for i, v in enumerate(vertices):
+         dists = np.linalg.norm(target_pts - v, axis=-1)
+         close = dists < snap_radius
+         if close.sum() >= 2:
+             snapped[i] = target_pts[close].mean(axis=0)
+
+     return snapped
+
+
+ def snap_horizontal(vertices, edges, max_slope=0.05):
+     """Snap near-horizontal edges to be exactly horizontal."""
+     verts = vertices.copy()
+     for a, b in edges:
+         a, b = int(a), int(b)
+         dy = abs(verts[a, 1] - verts[b, 1])
+         dxz = np.sqrt((verts[a, 0] - verts[b, 0])**2 + (verts[a, 2] - verts[b, 2])**2)
+         if dxz > 0.1 and dy / dxz < max_slope:
+             avg_y = 0.5 * (verts[a, 1] + verts[b, 1])
+             verts[a, 1] = avg_y
+             verts[b, 1] = avg_y
+     return verts
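A hypothetical usage sketch for the two snapping passes (toy coordinates; class ids 1/2 follow the defaults above, and axis 1 is treated as vertical, matching snap_horizontal):

import numpy as np

vertices = np.array([[0.0, 3.0, 0.0], [4.0, 3.05, 0.0]], dtype=np.float32)
edges = np.array([[0, 1]])

xyz = np.array([[0.1, 3.1, 0.0], [-0.1, 2.9, 0.1], [9.0, 0.0, 0.0]], dtype=np.float32)
class_id = np.array([1, 1, 3])  # two apex points, one unrelated class

snapped = snap_to_point_cloud(vertices, xyz, class_id, snap_radius=0.5)
# vertex 0 snaps to the mean of the two apex points; vertex 1 is out of range
flat = snap_horizontal(snapped, edges, max_slope=0.05)
# dy/dxz ~= 0.0125 < 0.05, so both endpoints receive the averaged height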
s23dr_2026_example/segment_postprocess.py ADDED
@@ -0,0 +1,60 @@
+ from __future__ import annotations
+
+ import numpy as np
+
+
+ def merge_vertices(vertices: np.ndarray, edges: np.ndarray, thresh: float):
+     verts = np.asarray(vertices, dtype=np.float32)
+     edges = np.asarray(edges, dtype=np.int64)
+     if verts.size == 0 or edges.size == 0:
+         return verts, edges
+
+     n = verts.shape[0]
+     parent = np.arange(n, dtype=np.int64)
+
+     def find(i):
+         while parent[i] != i:
+             parent[i] = parent[parent[i]]
+             i = parent[i]
+         return i
+
+     def union(i, j):
+         ri = find(i)
+         rj = find(j)
+         if ri != rj:
+             parent[rj] = ri
+
+     for i in range(n):
+         vi = verts[i]
+         for j in range(i + 1, n):
+             if np.linalg.norm(vi - verts[j]) <= thresh:
+                 union(i, j)
+
+     clusters = {}
+     for i in range(n):
+         root = find(i)
+         clusters.setdefault(root, []).append(i)
+
+     new_vertices = []
+     mapping = {}
+     for new_idx, idxs in enumerate(clusters.values()):
+         pts = verts[idxs]
+         center = pts.mean(axis=0)
+         new_vertices.append(center)
+         for i in idxs:
+             mapping[i] = new_idx
+
+     new_edges = []
+     seen = set()
+     for a, b in edges:
+         na = mapping.get(int(a), int(a))
+         nb = mapping.get(int(b), int(b))
+         if na == nb:
+             continue
+         key = (na, nb) if na <= nb else (nb, na)
+         if key in seen:
+             continue
+         seen.add(key)
+         new_edges.append([na, nb])
+
+     return np.asarray(new_vertices, dtype=np.float32), np.asarray(new_edges, dtype=np.int64)
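A quick sanity sketch on toy data: two near-duplicate corners collapse into one vertex, and the edge that became a duplicate is dropped.

import numpy as np

verts = np.array([[0.0, 0.0, 0.0], [0.02, 0.0, 0.0], [1.0, 0.0, 0.0]], dtype=np.float32)
edges = np.array([[0, 2], [1, 2]])

new_v, new_e = merge_vertices(verts, edges, thresh=0.05)
print(new_v.shape, new_e.shape)  # (2, 3) (1, 2) -- duplicate corner and edge merged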
s23dr_2026_example/sinkhorn.py ADDED
@@ -0,0 +1,181 @@
+ """Sinkhorn optimal transport loss for segment matching.
+
+ Note: at eps=0.05, Sinkhorn gradients are near-zero (~1e-7 norm) for
+ typical matrix sizes, so the loss value is tracked but does not
+ meaningfully train the model; the default is sinkhorn_weight=0.0.
+ See worklog.md for details.
+
+ Future: schedule eps from large (1.0) to small (0.05) during training
+ to get useful gradients early and precise matching late.
+ """
+ import torch
+
+
+ def segment_pair_cost(pred_segments: torch.Tensor, gt_segments: torch.Tensor) -> torch.Tensor:
+     """Cost between pred and GT segments: midpoint + direction + length (decoupled).
+     pred_segments: [N, 2, 3], gt_segments: [M, 2, 3] -> [N, M]
+     """
+     p0, p1 = pred_segments[:, 0], pred_segments[:, 1]
+     g0, g1 = gt_segments[:, 0], gt_segments[:, 1]
+     mid_p, half_p = 0.5 * (p0 + p1), 0.5 * (p1 - p0)
+     mid_g, half_g = 0.5 * (g0 + g1), 0.5 * (g1 - g0)
+     d_mid = torch.cdist(mid_p, mid_g)
+     len_p = torch.linalg.norm(half_p, dim=-1, keepdim=True).clamp(min=1e-6)
+     len_g = torch.linalg.norm(half_g, dim=-1, keepdim=True).clamp(min=1e-6)
+     dir_p = half_p / len_p
+     dir_g = half_g / len_g
+     cos_angle = (dir_p[:, None, :] * dir_g[None, :, :]).sum(dim=-1)
+     d_dir = 1.0 - cos_angle.abs()
+     d_len = (len_p[:, None, :] - len_g[None, :, :]).squeeze(-1).abs()
+     return d_mid + d_dir + d_len
+
+
+ def batched_sinkhorn_loss(
+     pred_segments: torch.Tensor,
+     gt_pad: torch.Tensor,
+     gt_mask: torch.Tensor,
+     eps: float,
+     iters: int,
+     dustbin_cost: float | torch.Tensor,
+     pred_mass: torch.Tensor | None = None,
+ ) -> torch.Tensor:
+     """Batched sinkhorn segment matching loss.
+
+     Args:
+         pred_segments: [B, S, 2, 3] predicted segments
+         gt_pad: [B, M, 2, 3] padded GT segments
+         gt_mask: [B, M] bool mask (True = valid GT segment)
+         eps: sinkhorn regularization
+         iters: sinkhorn iterations
+         dustbin_cost: cost for unmatched segments (scalar or [B])
+         pred_mass: [B, S] per-segment mass weights (e.g. sigmoid(conf)).
+             If None, uniform masses are used.
+
+     Returns:
+         [B] per-sample sinkhorn transport cost
+     """
+     B, S, _, _ = pred_segments.shape
+     M = gt_pad.shape[1]
+
+     # Allow per-sample dustbin cost
+     dc = torch.as_tensor(dustbin_cost, device=pred_segments.device, dtype=pred_segments.dtype)
+     if dc.dim() == 0:
+         dc = dc.expand(B)
+
+     # Compute cost matrices [B, S, M] in midpoint-halfvec space.
+     # Decouples position from direction: mid gradient is pure position,
+     # half gradient is pure direction/length. Sign-invariance on half
+     # handles segment direction ambiguity cleanly.
+     p0 = pred_segments[:, :, 0]  # [B, S, 3]
+     p1 = pred_segments[:, :, 1]  # [B, S, 3]
+     g0 = gt_pad[:, :, 0]         # [B, M, 3]
+     g1 = gt_pad[:, :, 1]         # [B, M, 3]
+
+     mid_pred = 0.5 * (p0 + p1)   # [B, S, 3]
+     half_pred = 0.5 * (p1 - p0)  # [B, S, 3]
+     mid_gt = 0.5 * (g0 + g1)     # [B, M, 3]
+     half_gt = 0.5 * (g1 - g0)    # [B, M, 3]
+
+     # Midpoint distance [B, S, M]
+     d_mid = torch.linalg.norm(
+         mid_pred.unsqueeze(2) - mid_gt.unsqueeze(1), dim=-1)
+
+     # Decoupled direction + length distance (sign-invariant for direction ambiguity)
+     len_pred = torch.linalg.norm(half_pred, dim=-1, keepdim=True).clamp(min=1e-6)  # [B, S, 1]
+     len_gt = torch.linalg.norm(half_gt, dim=-1, keepdim=True).clamp(min=1e-6)      # [B, M, 1]
+     dir_pred = half_pred / len_pred  # [B, S, 3]
+     dir_gt = half_gt / len_gt        # [B, M, 3]
+
+     # Direction distance: 1 - |cos(angle)|, sign-invariant [B, S, M]
+     cos_angle = (dir_pred.unsqueeze(2) * dir_gt.unsqueeze(1)).sum(dim=-1)  # [B, S, M]
+     d_dir = 1.0 - cos_angle.abs()
+
+     # Length distance [B, S, M]
+     d_len = (len_pred.unsqueeze(2) - len_gt.unsqueeze(1)).squeeze(-1).abs()
+
+     cost = d_mid + d_dir + d_len  # [B, S, M]
+
+     # Mask invalid GT segments with high cost so they go to dustbin
+     cost = torch.where(gt_mask.unsqueeze(1), cost, dc[:, None, None] * 10.0)
+
+     # Pad with dustbin row and column: [B, S+1, M+1]
+     cost_pad = dc[:, None, None].expand(B, S + 1, M + 1).clone()
+     cost_pad[:, :S, :M] = cost
+     cost_pad[:, -1, -1] = 0.0
+
+     # Masses
+     gt_counts = gt_mask.sum(dim=1).float()  # [B]
+
+     if pred_mass is not None:
+         # Confidence-weighted masses (matches learned_v2 approach).
+         # sigmoid(conf) gives per-segment mass; dustbin masses balance the totals.
+         # No normalization -- sum(a) == sum(b) == max(sum_pred, sum_gt).
+         pm = pred_mass.clamp(min=0.0)                      # [B, S]
+         sum_pred = pm.sum(dim=1)                           # [B]
+         sum_gt = gt_counts                                 # [B]
+         pred_dustbin = (sum_gt - sum_pred).clamp(min=0.0)  # [B]
+         gt_dustbin = (sum_pred - sum_gt).clamp(min=0.0)    # [B]
+         a = torch.cat([pm, pred_dustbin.unsqueeze(1)], dim=1)  # [B, S+1]
+         b_val = torch.zeros(B, M + 1, device=cost.device, dtype=cost.dtype)
+         b_val[:, :M] = gt_mask.float()  # 1.0 per valid GT segment
+         b_val[:, -1] = gt_dustbin
+     else:
+         # Uniform masses (normalized)
+         n = float(S)
+         denom = n + gt_counts  # [B]
+         a = (1.0 / denom).unsqueeze(1).expand(B, S + 1).clone()  # [B, S+1]
+         a[:, -1] = gt_counts / denom
+         b_val = (1.0 / denom).unsqueeze(1).expand(B, M + 1).clone()  # [B, M+1]
+         b_val[:, -1] = n / denom
+         # Zero out mass for invalid GT
+         b_val[:, :M] = b_val[:, :M] * gt_mask.float()
+
+     # Log-domain sinkhorn
+     log_a = torch.log(a + 1e-9)
+     log_b = torch.log(b_val + 1e-9)
+     log_k = -cost_pad / eps
+
+     log_u = torch.zeros_like(a)
+     log_v = torch.zeros_like(b_val)
+
+     for _ in range(iters):
+         log_u = log_a - torch.logsumexp(log_k + log_v.unsqueeze(1), dim=2)
+         log_v = log_b - torch.logsumexp(log_k + log_u.unsqueeze(2), dim=1)
+
+     transport = torch.exp(log_u.unsqueeze(2) + log_v.unsqueeze(1) + log_k)
+     return (transport * cost_pad).sum(dim=(1, 2))  # [B]
+
+
+ # Keep the per-sample version for compatibility
+ def sinkhorn_segment_loss(
+     pred_segments: torch.Tensor,
+     gt_segments: torch.Tensor,
+     eps: float,
+     iters: int,
+     dustbin_cost: float,
+     pred_mass: torch.Tensor | None = None,
+ ) -> torch.Tensor:
+     if pred_segments.numel() == 0 or gt_segments.numel() == 0:
+         return pred_segments.new_tensor(dustbin_cost)
+     cost = segment_pair_cost(pred_segments, gt_segments)
+     n, m = cost.shape
+     if n == 0 or m == 0:
+         return cost.new_tensor(dustbin_cost)
+     cost_pad = torch.full((n + 1, m + 1), dustbin_cost, device=cost.device, dtype=cost.dtype)
+     cost_pad[:n, :m] = cost
+     cost_pad[-1, -1] = 0.0
+     denom = float(n + m)
+     a = torch.full((n + 1,), 1.0 / denom, device=cost.device, dtype=cost.dtype)
+     b = torch.full((m + 1,), 1.0 / denom, device=cost.device, dtype=cost.dtype)
+     a[-1] = m / denom
+     b[-1] = n / denom
+     log_a = torch.log(a + 1e-9)
+     log_b = torch.log(b + 1e-9)
+     log_k = -cost_pad / eps
+     log_u = torch.zeros_like(a)
+     log_v = torch.zeros_like(b)
+     for _ in range(iters):
+         log_u = log_a - torch.logsumexp(log_k + log_v[None, :], dim=1)
+         log_v = log_b - torch.logsumexp(log_k + log_u[:, None], dim=0)
+     transport = torch.exp(log_u[:, None] + log_v[None, :] + log_k)
+     return torch.sum(transport * cost_pad)
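To make the mass and dustbin mechanics concrete, here is a smoke-test sketch on random tensors, together with one possible shape for the eps schedule suggested in the module docstring (the schedule function is an illustrative assumption, not something this commit implements):

import torch

def eps_schedule(step, total, eps_hi=1.0, eps_lo=0.05):
    # Hypothetical linear schedule: large eps early for smooth gradients,
    # small eps late for sharper matching.
    frac = min(step / max(total, 1), 1.0)
    return eps_hi + frac * (eps_lo - eps_hi)

B, S, M = 2, 8, 5
pred = torch.randn(B, S, 2, 3, requires_grad=True)
gt = torch.randn(B, M, 2, 3)
mask = torch.ones(B, M, dtype=torch.bool)

loss = batched_sinkhorn_loss(pred, gt, mask, eps=eps_schedule(0, 1000),
                             iters=30, dustbin_cost=2.0)
loss.sum().backward()
print(loss.shape, pred.grad.abs().mean())  # torch.Size([2]) and a nonzero gradient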
s23dr_2026_example/soft_hss_loss.py ADDED
@@ -0,0 +1,507 @@
+ import torch
+
+
+ def _softmin(values: torch.Tensor, dim: int, tau: float) -> torch.Tensor:
+     tau_t = torch.as_tensor(tau, device=values.device, dtype=values.dtype).clamp_min(1e-8)
+     return -tau_t * torch.logsumexp(-values / tau_t, dim=dim)
+
+
+ def point_segment_distance_squared(
+     points: torch.Tensor,
+     seg_a: torch.Tensor,
+     seg_b: torch.Tensor,
+     eps: float = 1e-9,
+ ) -> torch.Tensor:
+     """
+     points: (P,3)
+     seg_a/seg_b: (S,3)
+     returns dist2: (P,S)
+     """
+     ab = seg_b - seg_a  # (S,3)
+     ab2 = (ab * ab).sum(dim=-1).clamp_min(eps)  # (S,)
+     ap = points[:, None, :] - seg_a[None, :, :]  # (P,S,3)
+     t = (ap * ab[None, :, :]).sum(dim=-1) / ab2[None, :]  # (P,S)
+     t = t.clamp(0.0, 1.0)
+     closest = seg_a[None, :, :] + t[:, :, None] * ab[None, :, :]
+     diff = points[:, None, :] - closest
+     return (diff * diff).sum(dim=-1)
+
+
+ def distance_to_segments(
+     points: torch.Tensor,
+     segments: torch.Tensor,
+     eps: float = 1e-9,
+ ) -> torch.Tensor:
+     """
+     points: (P,3)
+     segments: (S,2,3)
+     returns min distance: (P,)
+     """
+     a = segments[:, 0]
+     b = segments[:, 1]
+     dist2 = point_segment_distance_squared(points, a, b, eps=eps)
+     return torch.sqrt(dist2.min(dim=1).values + eps)
+
+
+ def soft_vertex_f1(
+     pred_vertices: torch.Tensor,
+     gt_vertices: torch.Tensor,
+     thresh: float,
+     tau: float = 0.05,
+     softmin_tau: float = 0.05,
+     eps: float = 1e-8,
+ ) -> torch.Tensor:
+     """
+     Soft surrogate for the Hungarian-thresholded corner F1 used by HSS.
+
+     Uses (soft) nearest-neighbor distances and a sigmoid threshold.
+     """
+     if pred_vertices.numel() == 0 or gt_vertices.numel() == 0:
+         return torch.zeros((), device=pred_vertices.device, dtype=pred_vertices.dtype)
+
+     pred = pred_vertices
+     gt = gt_vertices
+
+     diff = pred[:, None, :] - gt[None, :, :]
+     dist = torch.sqrt((diff * diff).sum(dim=-1) + eps)  # (P,G)
+
+     d_pred = _softmin(dist, dim=1, tau=softmin_tau)  # (P,)
+     d_gt = _softmin(dist, dim=0, tau=softmin_tau)    # (G,)
+
+     tau_t = torch.as_tensor(tau, device=dist.device, dtype=dist.dtype).clamp_min(1e-8)
+     thresh_t = torch.as_tensor(thresh, device=dist.device, dtype=dist.dtype)
+     p_match = torch.sigmoid((thresh_t - d_pred) / tau_t).mean()
+     r_match = torch.sigmoid((thresh_t - d_gt) / tau_t).mean()
+     return 2.0 * p_match * r_match / (p_match + r_match + eps)
+
+
+ def soft_tube_iou_mc(
+     pred_segments: torch.Tensor,
+     gt_segments: torch.Tensor,
+     radius: float,
+     n_samples: int = 4096,
+     tau: float = 0.05,
+     seed: int = 0,
+     eps: float = 1e-8,
+ ) -> torch.Tensor:
+     """
+     Soft surrogate for volumetric tube IoU (edge_thresh in HSS).
+
+     Samples points uniformly in a padded bbox around {pred,gt} endpoints.
+     Occupancy is sigmoid((radius - d(x, segments))/tau).
+     IoU is approximated by mean(min(occ_p, occ_g)) / mean(max(occ_p, occ_g)).
+     """
+     if pred_segments.numel() == 0 or gt_segments.numel() == 0:
+         return torch.zeros((), device=pred_segments.device, dtype=pred_segments.dtype)
+
+     pts_all = torch.cat([pred_segments.reshape(-1, 3), gt_segments.reshape(-1, 3)], dim=0)
+     pad = torch.as_tensor(radius, device=pts_all.device, dtype=pts_all.dtype)
+     lo = pts_all.min(dim=0).values - pad
+     hi = pts_all.max(dim=0).values + pad
+
+     gen = torch.Generator(device=pts_all.device)
+     gen.manual_seed(int(seed))
+     u = torch.rand((int(n_samples), 3), generator=gen, device=pts_all.device, dtype=pts_all.dtype)
+     x = lo[None, :] + u * (hi - lo)[None, :]
+
+     d_p = distance_to_segments(x, pred_segments, eps=eps)
+     d_g = distance_to_segments(x, gt_segments, eps=eps)
+
+     tau_t = torch.as_tensor(tau, device=pts_all.device, dtype=pts_all.dtype).clamp_min(1e-8)
+     rad_t = torch.as_tensor(radius, device=pts_all.device, dtype=pts_all.dtype)
+     occ_p = torch.sigmoid((rad_t - d_p) / tau_t)
+     occ_g = torch.sigmoid((rad_t - d_g) / tau_t)
+
+     inter = torch.minimum(occ_p, occ_g).mean()
+     union = torch.maximum(occ_p, occ_g).mean().clamp_min(eps)
+     return inter / union
+
+
+ def soft_hss(
+     pred_segments: torch.Tensor,
+     gt_segments: torch.Tensor,
+     gt_vertices: torch.Tensor,
+     vert_thresh: float = 0.5,
+     edge_thresh: float = 0.5,
+     tau: float = 0.05,
+     softmin_tau: float = 0.05,
+     n_samples: int = 4096,
+     seed: int = 0,
+     eps: float = 1e-8,
+ ):
+     """
+     Returns (soft_hss, soft_f1, soft_iou), all scalars in [0,1] (approximately).
+     """
+     pred_vertices = pred_segments.reshape(-1, 3)
+     f1 = soft_vertex_f1(pred_vertices, gt_vertices, thresh=vert_thresh, tau=tau, softmin_tau=softmin_tau, eps=eps)
+     iou = soft_tube_iou_mc(
+         pred_segments,
+         gt_segments,
+         radius=edge_thresh,
+         n_samples=n_samples,
+         tau=tau,
+         seed=seed,
+         eps=eps,
+     )
+     denom = (f1 + iou).clamp_min(eps)
+     hss = 2.0 * f1 * iou / denom
+     return hss, f1, iou
+
+
+ # ---------------------------------------------------------------------------
+ # Improved: Sinkhorn-matched vertex F1
+ # ---------------------------------------------------------------------------
+ #
+ # The original soft_vertex_f1 uses independent softmin nearest-neighbor
+ # distances, which allows multiple predicted vertices to claim the same GT
+ # vertex. This inflates precision and fails to penalize duplicate vertices --
+ # the exact failure mode that requires merge_vertices post-processing.
+ #
+ # This version uses Sinkhorn optimal transport to find a soft one-to-one
+ # assignment between predicted and GT vertices, then computes precision and
+ # recall from the matched distances. This is a better surrogate for the
+ # Hungarian matching used by the real HSS metric.
+
+
+ def sinkhorn_vertex_f1(
+     pred_vertices: torch.Tensor,
+     gt_vertices: torch.Tensor,
+     thresh: float = 0.5,
+     tau: float = 0.05,
+     eps_sinkhorn: float = 0.05,
+     iters: int = 20,
+     eps: float = 1e-8,
+ ) -> torch.Tensor:
+     """Soft vertex F1 using Sinkhorn matching (better aligned with real HSS).
+
+     Instead of independent nearest-neighbor distances (which allow double-
+     claiming), this uses optimal transport to find a soft one-to-one assignment
+     between predicted and GT vertices.
+
+     Returns a differentiable scalar in [0, 1].
+     """
+     if pred_vertices.numel() == 0 or gt_vertices.numel() == 0:
+         return torch.zeros((), device=pred_vertices.device, dtype=pred_vertices.dtype)
+
+     P = pred_vertices.shape[0]
+     G = gt_vertices.shape[0]
+
+     # Pairwise distance matrix (P, G)
+     dist = torch.cdist(pred_vertices, gt_vertices)
+
+     # Sinkhorn with dustbin: (P+1) x (G+1)
+     # Dustbin cost = thresh (unmatched vertices are "at threshold distance")
+     dustbin = thresh
+     cost_pad = torch.full((P + 1, G + 1), dustbin, device=dist.device, dtype=dist.dtype)
+     cost_pad[:P, :G] = dist
+     cost_pad[-1, -1] = 0.0
+
+     # Uniform masses with dustbin slack
+     denom = float(P + G)
+     a = torch.full((P + 1,), 1.0 / denom, device=dist.device, dtype=dist.dtype)
+     b = torch.full((G + 1,), 1.0 / denom, device=dist.device, dtype=dist.dtype)
+     a[-1] = G / denom  # pred dustbin absorbs unmatched GT
+     b[-1] = P / denom  # GT dustbin absorbs unmatched pred
+
+     # Log-domain Sinkhorn
+     log_a = torch.log(a + 1e-9)
+     log_b = torch.log(b + 1e-9)
+     log_k = -cost_pad / max(eps_sinkhorn, 1e-6)
+     log_u = torch.zeros_like(a)
+     log_v = torch.zeros_like(b)
+     for _ in range(iters):
+         log_u = log_a - torch.logsumexp(log_k + log_v[None, :], dim=1)
+         log_v = log_b - torch.logsumexp(log_k + log_u[:, None], dim=0)
+
+     # Transport plan (P+1, G+1)
+     transport = torch.exp(log_u[:, None] + log_v[None, :] + log_k)
+
+     # Extract the non-dustbin transport (P, G) -- these are the soft assignments
+     T = transport[:P, :G]
+
+     # For each predicted vertex, its matched distance is the transport-weighted
+     # average distance to GT vertices.
+     # Normalize rows to sum to 1 (how much of this pred is matched vs dustbin)
+     row_sums = T.sum(dim=1).clamp_min(eps)
+     matched_dist_pred = (T * dist).sum(dim=1) / row_sums  # (P,)
+     match_weight_pred = row_sums * denom  # how much of this pred is matched (0-1 ish)
+
+     # Same for GT vertices (column perspective)
+     col_sums = T.sum(dim=0).clamp_min(eps)
+     matched_dist_gt = (T * dist).sum(dim=0) / col_sums  # (G,)
+     match_weight_gt = col_sums * denom
+
+     # Soft precision: fraction of pred vertices that are matched AND within threshold
+     tau_t = torch.as_tensor(tau, device=dist.device, dtype=dist.dtype).clamp_min(1e-8)
+     thresh_t = torch.as_tensor(thresh, device=dist.device, dtype=dist.dtype)
+
+     prec_per = match_weight_pred * torch.sigmoid((thresh_t - matched_dist_pred) / tau_t)
+     precision = prec_per.mean()
+
+     # Soft recall: fraction of GT vertices that are matched AND within threshold
+     rec_per = match_weight_gt * torch.sigmoid((thresh_t - matched_dist_gt) / tau_t)
+     recall = rec_per.mean()
+
+     return 2.0 * precision * recall / (precision + recall + eps)
+
+
+ # ---------------------------------------------------------------------------
+ # Improved: Segment-sampled tube IoU
+ # ---------------------------------------------------------------------------
+ #
+ # The original soft_tube_iou_mc samples random points in the bounding box,
+ # wasting most samples in empty space. This version samples along the segments
+ # themselves, concentrating gradient signal where it matters.
+
+
+ def _sample_along_segments(segments: torch.Tensor, n_per_seg: int = 64) -> torch.Tensor:
+     """Sample n_per_seg points uniformly along each segment.
+
+     segments: (S, 2, 3)
+     returns: (S * n_per_seg, 3)
+     """
+     t = torch.linspace(0, 1, n_per_seg, device=segments.device, dtype=segments.dtype)
+     # (S, 1, 3) + (1, N, 1) * (S, 1, 3) -> (S, N, 3)
+     a = segments[:, 0:1, :]
+     b = segments[:, 1:2, :]
+     pts = a + t[None, :, None] * (b - a)
+     return pts.reshape(-1, 3)
+
+
+ def segment_sampled_tube_iou(
+     pred_segments: torch.Tensor,
+     gt_segments: torch.Tensor,
+     radius: float = 0.5,
+     n_per_seg: int = 64,
+     tau: float = 0.05,
+     eps: float = 1e-8,
+ ) -> torch.Tensor:
+     """Soft tube IoU by sampling along segments instead of in the bounding box.
+
+     Samples points along predicted and GT segments, then checks what fraction
+     of each set falls within radius of the other. More sample-efficient than
+     bbox Monte Carlo and gives better gradients.
+
+     Returns a differentiable scalar in [0, 1].
+     """
+     if pred_segments.numel() == 0 or gt_segments.numel() == 0:
+         return torch.zeros((), device=pred_segments.device, dtype=pred_segments.dtype)
+
+     pred_pts = _sample_along_segments(pred_segments, n_per_seg)
+     gt_pts = _sample_along_segments(gt_segments, n_per_seg)
+
+     tau_t = torch.as_tensor(tau, device=pred_pts.device, dtype=pred_pts.dtype).clamp_min(1e-8)
+     rad_t = torch.as_tensor(radius, device=pred_pts.device, dtype=pred_pts.dtype)
+
+     # Precision: fraction of pred points within radius of any GT segment
+     d_pred = distance_to_segments(pred_pts, gt_segments, eps=eps)
+     prec = torch.sigmoid((rad_t - d_pred) / tau_t).mean()
+
+     # Recall: fraction of GT points within radius of any pred segment
+     d_gt = distance_to_segments(gt_pts, pred_segments, eps=eps)
+     rec = torch.sigmoid((rad_t - d_gt) / tau_t).mean()
+
+     # Soft IoU from precision and recall:
+     # IoU = intersection/union = (P*R) / (P + R - P*R) for occupancy overlap
+     return prec * rec / (prec + rec - prec * rec + eps)
+
+
+ def soft_hss_v2(
+     pred_segments: torch.Tensor,
+     gt_segments: torch.Tensor,
+     gt_vertices: torch.Tensor,
+     vert_thresh: float = 0.5,
+     edge_thresh: float = 0.5,
+     tau: float = 0.05,
+     sinkhorn_eps: float = 0.05,
+     sinkhorn_iters: int = 20,
+     n_per_seg: int = 64,
+     eps: float = 1e-8,
+ ):
+     """Improved soft HSS using Sinkhorn vertex matching + segment-sampled IoU.
+
+     Returns (soft_hss, soft_f1, soft_iou).
+     """
+     pred_vertices = pred_segments.reshape(-1, 3)
+     f1 = sinkhorn_vertex_f1(
+         pred_vertices, gt_vertices,
+         thresh=vert_thresh, tau=tau,
+         eps_sinkhorn=sinkhorn_eps, iters=sinkhorn_iters, eps=eps,
+     )
+     iou = segment_sampled_tube_iou(
+         pred_segments, gt_segments,
+         radius=edge_thresh, n_per_seg=n_per_seg, tau=tau, eps=eps,
+     )
+     denom = (f1 + iou).clamp_min(eps)
+     hss = 2.0 * f1 * iou / denom
+     return hss, f1, iou
+
+
+ # ---------------------------------------------------------------------------
+ # Batched versions for training speed
+ # ---------------------------------------------------------------------------
+
+
+ def batched_sinkhorn_vertex_f1(
+     pred_segments: torch.Tensor,
+     gt_pad: torch.Tensor,
+     gt_mask: torch.Tensor,
+     thresh: float | torch.Tensor = 0.5,
+     tau: float | torch.Tensor = 0.05,
+     eps_sinkhorn: float = 0.05,
+     iters: int = 10,
+     eps: float = 1e-8,
+ ) -> torch.Tensor:
+     """Batched Sinkhorn vertex F1 loss.
+
+     Args:
+         pred_segments: [B, S, 2, 3] predicted segments
+         gt_pad: [B, M, 2, 3] padded GT segments
+         gt_mask: [B, M] bool mask (True = valid GT segment)
+         thresh: distance threshold for a vertex match (scalar or [B])
+         tau: sigmoid temperature (scalar or [B])
+     Returns:
+         [B] per-sample (1 - F1) loss
+     """
+     B, S = pred_segments.shape[:2]
+     M = gt_pad.shape[1]
+     P = S * 2  # pred vertices (both endpoints)
+
+     # Allow per-sample thresh and tau ([B] tensors or scalars)
+     thresh_t = torch.as_tensor(thresh, device=pred_segments.device, dtype=pred_segments.dtype)
+     if thresh_t.dim() == 0:
+         thresh_t = thresh_t.expand(B)
+     tau_t = torch.as_tensor(tau, device=pred_segments.device, dtype=pred_segments.dtype)
+     if tau_t.dim() == 0:
+         tau_t = tau_t.expand(B)
+     tau_t = tau_t.clamp_min(1e-8)
+
+     pred_verts = pred_segments.reshape(B, P, 3)
+     gt_verts = gt_pad.reshape(B, M * 2, 3)  # will mask invalid ones
+
+     # Build GT vertex mask: each valid segment contributes 2 vertices
+     gt_vert_mask = gt_mask.unsqueeze(2).expand(B, M, 2).reshape(B, M * 2)
+     G = M * 2
+
+     # Pairwise distances [B, P, G]
+     dist = torch.linalg.norm(
+         pred_verts.unsqueeze(2) - gt_verts.unsqueeze(1), dim=-1)
+
+     # Mask invalid GT with high distance
+     dist = torch.where(gt_vert_mask.unsqueeze(1), dist, thresh_t[:, None, None] * 10.0)
+
+     # Sinkhorn matching: [B, P+1, G+1]
+     cost_pad = thresh_t[:, None, None].expand(B, P + 1, G + 1).clone()
+     cost_pad[:, :P, :G] = dist
+     cost_pad[:, -1, -1] = 0.0
+
+     gt_counts = gt_vert_mask.sum(dim=1).float()  # [B]
+     n = float(P)
+     denom = n + gt_counts  # [B]
+
+     a = (1.0 / denom).unsqueeze(1).expand(B, P + 1).clone()
+     a[:, -1] = gt_counts / denom
+     b = (1.0 / denom).unsqueeze(1).expand(B, G + 1).clone()
+     b[:, -1] = n / denom
+     b[:, :G] = b[:, :G] * gt_vert_mask.float()
+
+     log_a = torch.log(a + 1e-9)
+     log_b = torch.log(b + 1e-9)
+     log_k = -cost_pad / max(eps_sinkhorn, 1e-6)
+     log_u = torch.zeros_like(a)
+     log_v = torch.zeros_like(b)
+
+     for _ in range(iters):
+         log_u = log_a - torch.logsumexp(log_k + log_v.unsqueeze(1), dim=2)
+         log_v = log_b - torch.logsumexp(log_k + log_u.unsqueeze(2), dim=1)
+
+     transport = torch.exp(log_u.unsqueeze(2) + log_v.unsqueeze(1) + log_k)
+     T = transport[:, :P, :G]  # [B, P, G]
+
+     # Matched distances
+     row_sums = T.sum(dim=2).clamp_min(eps)
+     matched_d_pred = (T * dist).sum(dim=2) / row_sums  # [B, P]
+     w_pred = row_sums * denom.unsqueeze(1)
+
+     col_sums = T.sum(dim=1).clamp_min(eps)
+     matched_d_gt = (T * dist).sum(dim=1) / col_sums  # [B, G]
+     w_gt = col_sums * denom.unsqueeze(1)
+
+     precision = (w_pred * torch.sigmoid((thresh_t[:, None] - matched_d_pred) / tau_t[:, None])).mean(dim=1)
+     recall_raw = w_gt * torch.sigmoid((thresh_t[:, None] - matched_d_gt) / tau_t[:, None])
+     # Mask invalid GT vertices in recall
+     recall = (recall_raw * gt_vert_mask.float()).sum(dim=1) / gt_counts.clamp_min(1.0)
+
+     f1 = 2.0 * precision * recall / (precision + recall + eps)
+     return 1.0 - f1  # return loss (1 - F1)
+
+
+ def batched_segment_sampled_iou(
+     pred_segments: torch.Tensor,
+     gt_pad: torch.Tensor,
+     gt_mask: torch.Tensor,
+     radius: float | torch.Tensor = 0.5,
+     n_per_seg: int = 32,
+     tau: float | torch.Tensor = 0.05,
+     eps: float = 1e-8,
+ ) -> torch.Tensor:
+     """Batched segment-sampled tube IoU loss.
+
+     Returns [B] per-sample (1 - IoU) loss.
+     """
+     B, S = pred_segments.shape[:2]
+     M = gt_pad.shape[1]
+
+     # Allow per-sample radius and tau ([B] tensors or scalars)
+     rad_t = torch.as_tensor(radius, device=pred_segments.device, dtype=pred_segments.dtype)
+     if rad_t.dim() == 0:
+         rad_t = rad_t.expand(B)
+     tau_t = torch.as_tensor(tau, device=pred_segments.device, dtype=pred_segments.dtype)
+     if tau_t.dim() == 0:
+         tau_t = tau_t.expand(B)
+     tau_t = tau_t.clamp_min(1e-8)
+
+     # Sample points along segments
+     t = torch.linspace(0, 1, n_per_seg, device=pred_segments.device, dtype=pred_segments.dtype)
+
+     # Pred points: [B, S*n_per_seg, 3]
+     pa = pred_segments[:, :, 0:1, :]  # [B, S, 1, 3]
+     pb = pred_segments[:, :, 1:2, :]
+     pred_pts = (pa + t[None, None, :, None] * (pb - pa)).reshape(B, S * n_per_seg, 3)
+
+     # GT points: [B, M*n_per_seg, 3]
+     ga = gt_pad[:, :, 0:1, :]
+     gb = gt_pad[:, :, 1:2, :]
+     gt_pts = (ga + t[None, None, :, None] * (gb - ga)).reshape(B, M * n_per_seg, 3)
+
+     # For each pred point, the minimum distance over all sampled GT points
+     d_pred_to_gt = torch.cdist(pred_pts, gt_pts)  # [B, S*n, M*n]
+     d_pred = d_pred_to_gt.min(dim=2).values  # [B, S*n]
+     prec = torch.sigmoid((rad_t[:, None] - d_pred) / tau_t[:, None]).mean(dim=1)  # [B]
+
+     d_gt_to_pred = d_pred_to_gt.min(dim=1).values  # [B, M*n]
+     # Mask invalid GT points
+     gt_pt_mask = gt_mask.unsqueeze(2).expand(B, M, n_per_seg).reshape(B, M * n_per_seg)
+     rec_raw = torch.sigmoid((rad_t[:, None] - d_gt_to_pred) / tau_t[:, None])
+     rec = (rec_raw * gt_pt_mask.float()).sum(dim=1) / gt_pt_mask.float().sum(dim=1).clamp_min(1.0)
+
+     iou = prec * rec / (prec + rec - prec * rec + eps)
+     return 1.0 - iou  # return loss
+
+
+ def batched_soft_hss_v2(pred_segments, gt_pad, gt_mask,
+                         vert_thresh=0.5, edge_thresh=0.5, tau=0.05,
+                         sinkhorn_iters=10, n_per_seg=32):
+     """Batched soft HSS loss. Returns [B] per-sample (1 - HSS)."""
+     f1_loss = batched_sinkhorn_vertex_f1(
+         pred_segments, gt_pad, gt_mask,
+         thresh=vert_thresh, tau=tau, iters=sinkhorn_iters)
+     iou_loss = batched_segment_sampled_iou(
+         pred_segments, gt_pad, gt_mask,
+         radius=edge_thresh, n_per_seg=n_per_seg, tau=tau)
+     f1 = 1.0 - f1_loss
+     iou = 1.0 - iou_loss
+     hss = 2.0 * f1 * iou / (f1 + iou + 1e-8)
+     return 1.0 - hss
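A usage sketch for the batched v2 loss on random tensors, with the module defaults (untuned values, for illustration only):

import torch

B, S, M = 2, 16, 6
pred = torch.randn(B, S, 2, 3, requires_grad=True)
gt = torch.randn(B, M, 2, 3)
mask = torch.ones(B, M, dtype=torch.bool)

loss = batched_soft_hss_v2(pred, gt, mask,
                           vert_thresh=0.5, edge_thresh=0.5,
                           tau=0.05, sinkhorn_iters=10, n_per_seg=32)
loss.mean().backward()  # [B] -> scalar; gradients flow back to pred
print(loss)             # each entry is 1 - soft HSS, roughly in [0, 1]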
s23dr_2026_example/tokenizer.py ADDED
@@ -0,0 +1,88 @@
+ """Tokenizer: learned embeddings + Fourier features for the point cloud tokens.
+
+ The EdgeDepthSequenceBuilder holds the learned embedding tables (label, source,
+ behind) and the random Fourier positional encoding. At training time,
+ build_tokens() in data.py applies these to pre-sampled point indices on GPU.
+ """
+ from __future__ import annotations
+
+ from dataclasses import dataclass
+ from typing import Tuple
+
+ import numpy as np
+ import torch
+ import torch.nn as nn
+
+ from .point_fusion import NUM_ADE, NUM_GEST
+
+
+ # -- Config --
+
+ @dataclass(frozen=True)
+ class EdgeDepthSequenceConfig:
+     seq_len: int = 2048
+     colmap_points: int = 1280
+     depth_points: int = 768
+     use_fourier: bool = True
+     fourier_dim: int = 32
+     fourier_scale: float = 10.0
+
+
+ # -- Fourier positional encoding --
+
+ class FourierFeatures(nn.Module):
+     def __init__(self, in_dim: int = 3, fourier_dim: int = 64,
+                  scale: float = 10.0, seed: int = 0,
+                  learnable: bool = False):
+         super().__init__()
+         gen = torch.Generator()
+         gen.manual_seed(seed)
+         B = torch.randn(fourier_dim, in_dim, generator=gen) * scale
+         if learnable:
+             self.B = nn.Parameter(B)
+         else:
+             self.register_buffer("B", B, persistent=True)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         proj = (2.0 * np.pi) * (x @ self.B.t())
+         return torch.cat([torch.sin(proj), torch.cos(proj)], dim=-1)
+
+
+ # -- Sequence builder (holds embeddings) --
+
+ class EdgeDepthSequenceBuilder(nn.Module):
+     """Holds learned embeddings for point cloud tokenization.
+
+     Used by the model at training time: build_tokens() calls
+     self.label_emb(class_id), self.src_emb(source), etc.
+     """
+
+     def __init__(self, cfg: EdgeDepthSequenceConfig, label_emb_dim: int = 16,
+                  src_emb_dim: int = 2, behind_emb_dim: int = 8,
+                  fourier_seed: int = 0, use_vote_features: bool = False,
+                  learnable_fourier: bool = False):
+         super().__init__()
+         self.cfg = cfg
+
+         self.num_labels = 13  # 11 structural + other_house + non_house
+         self.label_emb = nn.Embedding(self.num_labels, label_emb_dim)
+         self.src_emb = nn.Embedding(2, src_emb_dim)
+         self.behind_emb_dim = behind_emb_dim
+         if behind_emb_dim > 0:
+             self.behind_emb = nn.Embedding(NUM_GEST + 1, behind_emb_dim)
+
+         # Fourier positional encoding
+         if cfg.use_fourier:
+             self.pos_enc = FourierFeatures(
+                 in_dim=3, fourier_dim=cfg.fourier_dim,
+                 scale=cfg.fourier_scale, seed=fourier_seed,
+                 learnable=learnable_fourier,
+             )
+             pos_dim = 3 + 2 * cfg.fourier_dim
+         else:
+             self.pos_enc = None
+             pos_dim = 3
+
+         vote_dim = 2 if use_vote_features else 0  # n_views_voted + vote_frac
+         self.use_vote_features = use_vote_features
+         self.out_dim = pos_dim + label_emb_dim + src_emb_dim + behind_emb_dim + vote_dim
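To make the out_dim arithmetic concrete: with the default config a token concatenates xyz (3), Fourier sin/cos (2x32), label (16), source (2) and behind (8) features, i.e. 93 dims. A small sketch, assuming the package import of NUM_GEST resolves:

import torch

cfg = EdgeDepthSequenceConfig()
builder = EdgeDepthSequenceBuilder(cfg)
print(builder.out_dim)  # 3 + 2*32 + 16 + 2 + 8 = 93

pe = FourierFeatures(in_dim=3, fourier_dim=32, scale=10.0)
print(pe(torch.rand(5, 3)).shape)  # torch.Size([5, 64]) -- sin and cos halves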
s23dr_2026_example/varifold.py ADDED
@@ -0,0 +1,196 @@
+ import torch
+
+ from .wire_varifold_kernels import (
+     loss_semi_lobatto3,
+     loss_semi_lobatto3_mix,
+     loss_semi_lobatto3_mix_simple,
+     loss_simpson3,
+     loss_simpson3_batch,
+     loss_simpson3_mix,
+     loss_simpson3_mix_batch,
+     loss_simpson3_lenpow,
+     loss_simpson3_lenpow_mix,
+     loss_semi_legendre,
+ )
+
+
+ def edges_to_segments(vertices, edges) -> torch.Tensor:
+     verts = torch.as_tensor(vertices, dtype=torch.float32)
+     idx = torch.as_tensor(edges, dtype=torch.long)
+     return torch.stack([verts[idx[:, 0]], verts[idx[:, 1]]], dim=1)
+
+
+ def segments_to_vertices_edges(segments: torch.Tensor):
+     segs = torch.as_tensor(segments, dtype=torch.float32)
+     vertices = segs.reshape(-1, 3)
+     edges = [(2 * i, 2 * i + 1) for i in range(segs.shape[0])]
+     return vertices, edges
+
+
+ def varifold_loss(
+     pred_segments: torch.Tensor,
+     gt_segments: torch.Tensor,
+     sigma: float = 0.1,
+     variant: str = "semi_lobatto3",
+     t_nodes01: torch.Tensor | None = None,
+     t_w: torch.Tensor | None = None,
+     sigmas: torch.Tensor | None = None,
+     alpha: torch.Tensor | None = None,
+     normalize_alpha: bool = True,
+     len_pow: float | None = None,
+ ) -> torch.Tensor:
+     p_pred, q_pred = pred_segments[:, 0], pred_segments[:, 1]
+     p_gt, q_gt = gt_segments[:, 0], gt_segments[:, 1]
+
+     if variant == "semi_lobatto3":
+         return loss_semi_lobatto3(p_pred, q_pred, p_gt, q_gt, sigma)
+     if variant == "semi_lobatto3_mix":
+         if sigmas is None or alpha is None:
+             raise ValueError("sigmas and alpha are required for semi_lobatto3_mix")
+         return loss_semi_lobatto3_mix(p_pred, q_pred, p_gt, q_gt, sigmas, alpha, normalize_alpha)
+     if variant == "semi_lobatto3_mix_simple":
+         if sigmas is None or alpha is None:
+             raise ValueError("sigmas and alpha are required for semi_lobatto3_mix_simple")
+         return loss_semi_lobatto3_mix_simple(p_pred, q_pred, p_gt, q_gt, sigmas, alpha, normalize_alpha)
+     if variant == "simpson3":
+         if sigmas is not None or alpha is not None:
+             if sigmas is None or alpha is None:
+                 raise ValueError("sigmas and alpha are required for simpson3 mix")
+             return loss_simpson3_mix(p_pred, q_pred, p_gt, q_gt, sigmas, alpha, normalize_alpha)
+         return loss_simpson3(p_pred, q_pred, p_gt, q_gt, sigma)
+     if variant == "simpson3_lenpow":
+         if len_pow is None:
+             len_pow = 1.0
+         if sigmas is not None or alpha is not None:
+             if sigmas is None or alpha is None:
+                 raise ValueError("sigmas and alpha are required for simpson3_lenpow mix")
+             return loss_simpson3_lenpow_mix(p_pred, q_pred, p_gt, q_gt, sigmas, alpha, len_pow, normalize_alpha)
+         return loss_simpson3_lenpow(p_pred, q_pred, p_gt, q_gt, sigma, len_pow)
+     if variant == "semi_legendre":
+         return loss_semi_legendre(p_pred, q_pred, p_gt, q_gt, sigma, t_nodes01, t_w)
+     if variant in ("centers", "segments_varifold", "semi_lobatto1"):
+         return varifold_loss_centers(pred_segments, gt_segments, sigma)
+     raise ValueError(f"Unknown varifold variant: {variant}")
+
+
+ def varifold_loss_batch(
+     pred_segments: torch.Tensor,
+     gt_segments: torch.Tensor,
+     *,
+     sigma: float = 0.1,
+     variant: str = "semi_lobatto3",
+     t_nodes01: torch.Tensor | None = None,
+     t_w: torch.Tensor | None = None,
+     sigmas: torch.Tensor | None = None,
+     alpha: torch.Tensor | None = None,
+     normalize_alpha: bool = True,
+     len_pow: float | None = None,
+     gt_mask: torch.Tensor | None = None,
+     pred_weights: torch.Tensor | None = None,
+     cross_only: bool = False,
+ ) -> torch.Tensor:
+     if pred_segments.dim() != 4 or gt_segments.dim() != 4:
+         raise ValueError("pred_segments and gt_segments must be (B, N, 2, 3)")
+     p_pred, q_pred = pred_segments[:, :, 0], pred_segments[:, :, 1]
+     p_gt, q_gt = gt_segments[:, :, 0], gt_segments[:, :, 1]
+
+     w_gt = None
+     if gt_mask is not None:
+         w_gt = gt_mask.to(device=pred_segments.device, dtype=pred_segments.dtype)
+
+     w_pred = None
+     if pred_weights is not None:
+         w_pred = pred_weights.to(device=pred_segments.device, dtype=pred_segments.dtype)
+
+     if variant == "simpson3":
+         if sigmas is not None or alpha is not None:
+             if sigmas is None or alpha is None:
+                 raise ValueError("sigmas and alpha are required for simpson3 mix")
+             return loss_simpson3_mix_batch(p_pred, q_pred, p_gt, q_gt, sigmas, alpha, w_gt=w_gt, w_pred=w_pred, normalize_alpha=normalize_alpha, cross_only=cross_only)
+         return loss_simpson3_batch(p_pred, q_pred, p_gt, q_gt, sigma, w_gt=w_gt, w_pred=w_pred)
+
+     # Fallback to per-sample loop for unsupported variants.
+     losses = []
+     sigmas_t = None
+     if sigmas is not None:
+         sigmas_t = torch.as_tensor(sigmas, device=pred_segments.device, dtype=pred_segments.dtype)
+     for idx in range(pred_segments.shape[0]):
+         gt_b = gt_segments[idx]
+         if gt_mask is not None:
+             gt_b = gt_b[gt_mask[idx]]
+         sigmas_i = sigmas
+         if sigmas_t is not None and sigmas_t.ndim == 2:
+             sigmas_i = sigmas_t[idx]
+         losses.append(
+             varifold_loss(
+                 pred_segments[idx],
+                 gt_b,
+                 sigma=sigma,
+                 variant=variant,
+                 t_nodes01=t_nodes01,
+                 t_w=t_w,
+                 sigmas=sigmas_i,
+                 alpha=alpha,
+                 normalize_alpha=normalize_alpha,
+                 len_pow=len_pow,
+             )
+         )
+     return torch.stack(losses, dim=0)
+
+
+ def varifold_loss_centers(
+     pred_segments: torch.Tensor,
+     gt_segments: torch.Tensor,
+     sigma: float = 0.1,
+     normalize_weights: bool = True,
+ ) -> torch.Tensor:
+     eps = 1e-8
+     a_p, b_p = pred_segments[:, 0], pred_segments[:, 1]
+     a_g, b_g = gt_segments[:, 0], gt_segments[:, 1]
+
+     v_p = b_p - a_p
+     v_g = b_g - a_g
+     len_p = torch.linalg.norm(v_p, dim=-1)
+     len_g = torch.linalg.norm(v_g, dim=-1)
+
+     x_p = 0.5 * (a_p + b_p)
+     x_g = 0.5 * (a_g + b_g)
+
+     u_p = v_p / (len_p[:, None] + eps)
+     u_g = v_g / (len_g[:, None] + eps)
+
+     w_p = len_p
+     w_g = len_g
+     if normalize_weights:
+         w_p = w_p / (w_p.sum() + eps)
+         w_g = w_g / (w_g.sum() + eps)
+
+     diff_pp = x_p[:, None, :] - x_p[None, :, :]
+     diff_gg = x_g[:, None, :] - x_g[None, :, :]
+     diff_pg = x_p[:, None, :] - x_g[None, :, :]
+     d_pp = (diff_pp * diff_pp).sum(dim=-1)
+     d_gg = (diff_gg * diff_gg).sum(dim=-1)
+     d_pg = (diff_pg * diff_pg).sum(dim=-1)
+
+     inv2s2 = 1.0 / (2.0 * sigma * sigma)
+     k_pp = torch.exp(-d_pp * inv2s2)
+     k_gg = torch.exp(-d_gg * inv2s2)
+     k_pg = torch.exp(-d_pg * inv2s2)
+
+     dot_pp = (u_p[:, None, :] * u_p[None, :, :]).sum(dim=-1)
+     dot_gg = (u_g[:, None, :] * u_g[None, :, :]).sum(dim=-1)
+     dot_pg = (u_p[:, None, :] * u_g[None, :, :]).sum(dim=-1)
+
+     k_pp = k_pp * (dot_pp * dot_pp)
+     k_gg = k_gg * (dot_gg * dot_gg)
+     k_pg = k_pg * (dot_pg * dot_pg)
+
+     wp_row = w_p[:, None]
+     wp_col = w_p[None, :]
+     wg_row = w_g[:, None]
+     wg_col = w_g[None, :]
+
+     a_pp = (wp_row * wp_col * k_pp).sum(dim=-1).sum(dim=-1)
+     a_gg = (wg_row * wg_col * k_gg).sum(dim=-1).sum(dim=-1)
+     a_pg = (w_p[:, None] * w_g[None, :] * k_pg).sum(dim=-1).sum(dim=-1)
+     return a_pp + a_gg - 2.0 * a_pg
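A small sanity sketch for the centers variant on toy tensors: the expression a_pp + a_gg - 2*a_pg is a squared kernel distance, so it vanishes when the prediction equals the ground truth and grows as segments drift.

import torch

gt = torch.tensor([[[0., 0., 0.], [1., 0., 0.]],
                   [[0., 0., 0.], [0., 1., 0.]]])
print(varifold_loss_centers(gt, gt, sigma=0.1).item())        # ~0.0
print(varifold_loss_centers(gt + 0.2, gt, sigma=0.1).item())  # > 0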
s23dr_2026_example/wire_varifold_kernels.py ADDED
@@ -0,0 +1,461 @@
1
+ import math
2
+ import torch
3
+
4
+ # -----------------------------
5
+ # Helpers
6
+ # -----------------------------
7
+ def segment_geom(p: torch.Tensor, q: torch.Tensor, eps: float = 1e-9):
8
+ """
9
+ p,q: (...,3)
10
+ returns d, a, ell, u:
11
+ d = q - p
12
+ a = ||d||^2
13
+ ell = sqrt(a + eps^2)
14
+ u = d / ell
15
+ """
16
+ d = q - p
17
+ a = (d * d).sum(dim=-1)
18
+ eps_val = eps
19
+ if p.dtype in (torch.float16, torch.bfloat16):
20
+ eps_val = max(eps, float(torch.finfo(p.dtype).eps))
21
+ ell = torch.sqrt(a + eps_val * eps_val)
22
+ u = d / ell.unsqueeze(-1)
23
+ return d, a, ell, u
24
+
25
+ def sample_points(p: torch.Tensor, q: torch.Tensor, nodes01: torch.Tensor):
26
+ # (...,3) + (K,) -> (...,K,3)
27
+ d = q - p
28
+ nodes = nodes01.to(device=p.device, dtype=p.dtype)
29
+ shape = [1] * (p.dim() - 1) + [nodes.shape[0], 1]
30
+ nodes = nodes.view(*shape)
31
+ return p.unsqueeze(-2) + nodes * d.unsqueeze(-2)
32
+
33
+
34
+ # Fixed Lobatto-3 / Simpson nodes+weights on [0,1]
35
+ LOBATTO3_NODES = torch.tensor([0.0, 0.5, 1.0])
36
+ # LOBATTO3_W = torch.tensor([1.0/6.0, 4.0/6.0, 1.0/6.0])
37
+ LOBATTO3_W = torch.tensor([1/3, 1/3, 1/3])
38
+ LOBATTO3_W2 = LOBATTO3_W[:, None] * LOBATTO3_W[None, :] # (3,3)
39
+
40
+
41
+ def _prepare_mix_weights(sigmas, alpha, device, dtype, normalize_alpha: bool):
42
+ sigmas_t = torch.as_tensor(sigmas, device=device, dtype=dtype).clamp_min(1e-6)
43
+ alpha_t = torch.as_tensor(alpha, device=device, dtype=dtype)
44
+ if normalize_alpha:
45
+ alpha_t = alpha_t / alpha_t.sum().clamp_min(1e-12)
46
+ return sigmas_t, alpha_t
47
+
48
+ # -----------------------------
49
+ # 1) Simpson-3 on both segments (3x3 product rule)
50
+ # -----------------------------
51
+ def _prep_weight(w, n: int, b: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor | None:
52
+ if w is None:
53
+ return None
54
+ w = torch.as_tensor(w, device=device, dtype=dtype)
55
+ if w.dim() == 1:
56
+ if w.shape[0] != n:
57
+ raise ValueError(f"weight length {w.shape[0]} != {n}")
58
+ w = w.unsqueeze(0).expand(b, -1)
59
+ elif w.dim() == 2:
60
+ if w.shape[0] != b or w.shape[1] != n:
61
+ raise ValueError(f"weight shape {tuple(w.shape)} != ({b}, {n})")
62
+ else:
63
+ raise ValueError("weights must be 1D or 2D")
64
+ return w
65
+
66
+
67
+ def cross_simpson3(
68
+ pA,
69
+ qA,
70
+ pB,
71
+ qB,
72
+ sigma: float | torch.Tensor,
73
+ wA: torch.Tensor | None = None,
74
+ wB: torch.Tensor | None = None,
75
+ ):
76
+ device, dtype = pA.device, pA.dtype
77
+ batched = pA.dim() == 3
78
+ if not batched:
79
+ pA = pA.unsqueeze(0)
80
+ qA = qA.unsqueeze(0)
81
+ pB = pB.unsqueeze(0)
82
+ qB = qB.unsqueeze(0)
83
+ nodes = LOBATTO3_NODES.to(device=device, dtype=dtype)
84
+ w2 = LOBATTO3_W2.to(device=device, dtype=dtype)
85
+
86
+ bsz, nA, _ = pA.shape
87
+ nB = pB.shape[1]
88
+ wA = _prep_weight(wA, nA, bsz, device, dtype)
89
+ wB = _prep_weight(wB, nB, bsz, device, dtype)
90
+
91
+ _, _, ellA, uA = segment_geom(pA, qA)
92
+ _, _, ellB, uB = segment_geom(pB, qB)
93
+
94
+ XA = sample_points(pA, qA, nodes) # (B,N,3,3)
95
+ YB = sample_points(pB, qB, nodes) # (B,M,3,3)
96
+
97
+ # angular + length factors: (N,M)
98
+ ang = torch.matmul(uA, uB.transpose(-1, -2)).pow(2)
99
+ lenfac = ellA[:, :, None] * ellB[:, None, :]
100
+ if wA is not None or wB is not None:
101
+ if wA is None:
102
+ wA = torch.ones((bsz, nA), device=device, dtype=dtype)
103
+ if wB is None:
104
+ wB = torch.ones((bsz, nB), device=device, dtype=dtype)
105
+ lenfac = lenfac * (wA[:, :, None] * wB[:, None, :])
106
+
107
+ # spatial: build (N,M,3,3) kernel via broadcasting
108
+ diff = XA[:, :, None, :, None, :] - YB[:, None, :, None, :, :] # (B,N,M,3,3,3)
109
+ r2 = (diff * diff).sum(dim=-1) # (B,N,M,3,3)
110
+ sigma_t = torch.as_tensor(sigma, device=device, dtype=dtype)
111
+ if sigma_t.ndim == 0:
112
+ inv2s2 = 1.0 / (2.0 * sigma_t * sigma_t)
113
+ else:
114
+ if sigma_t.shape[0] != bsz:
115
+ raise ValueError(f"sigma batch {sigma_t.shape[0]} != {bsz}")
116
+ inv2s2 = (1.0 / (2.0 * sigma_t * sigma_t)).view(bsz, 1, 1, 1, 1)
117
+ K = torch.exp(-r2 * inv2s2) # (B,N,M,3,3)
118
+
119
+ spatial = (K * w2).sum(dim=-1).sum(dim=-1) # (B,N,M)
120
+ out = (ang * lenfac * spatial).sum(dim=-1).sum(dim=-1) # (B,)
121
+ return out[0] if not batched else out
122
+
123
+
124
+ def cross_simpson3_lenpow(
+     pA,
+     qA,
+     pB,
+     qB,
+     sigma: float | torch.Tensor,
+     len_pow: float,
+     wA: torch.Tensor | None = None,
+     wB: torch.Tensor | None = None,
+ ):
+     device, dtype = pA.device, pA.dtype
+     batched = pA.dim() == 3
+     if not batched:
+         pA = pA.unsqueeze(0)
+         qA = qA.unsqueeze(0)
+         pB = pB.unsqueeze(0)
+         qB = qB.unsqueeze(0)
+     nodes = LOBATTO3_NODES.to(device=device, dtype=dtype)
+     w2 = LOBATTO3_W2.to(device=device, dtype=dtype)
+
+     bsz, nA, _ = pA.shape
+     nB = pB.shape[1]
+     wA = _prep_weight(wA, nA, bsz, device, dtype)
+     wB = _prep_weight(wB, nB, bsz, device, dtype)
+
+     _, _, ellA, uA = segment_geom(pA, qA)
+     _, _, ellB, uB = segment_geom(pB, qB)
+
+     XA = sample_points(pA, qA, nodes)  # (B,N,3,3)
+     YB = sample_points(pB, qB, nodes)  # (B,M,3,3)
+
+     ang = torch.matmul(uA, uB.transpose(-1, -2)).pow(2)
+     lenfac = (ellA[:, :, None] * ellB[:, None, :]).pow(len_pow)
+     if wA is not None or wB is not None:
+         if wA is None:
+             wA = torch.ones((bsz, nA), device=device, dtype=dtype)
+         if wB is None:
+             wB = torch.ones((bsz, nB), device=device, dtype=dtype)
+         lenfac = lenfac * (wA[:, :, None] * wB[:, None, :])
+
+     diff = XA[:, :, None, :, None, :] - YB[:, None, :, None, :, :]  # (B,N,M,3,3,3)
+     r2 = (diff * diff).sum(dim=-1)  # (B,N,M,3,3)
+     sigma_t = torch.as_tensor(sigma, device=device, dtype=dtype)
+     if sigma_t.ndim == 0:
+         inv2s2 = 1.0 / (2.0 * sigma_t * sigma_t)
+     else:
+         if sigma_t.shape[0] != bsz:
+             raise ValueError(f"sigma batch {sigma_t.shape[0]} != {bsz}")
+         inv2s2 = (1.0 / (2.0 * sigma_t * sigma_t)).view(bsz, 1, 1, 1, 1)
+     K = torch.exp(-r2 * inv2s2)  # (B,N,M,3,3)
+
+     spatial = (K * w2).sum(dim=-1).sum(dim=-1)  # (B,N,M)
+     out = (ang * lenfac * spatial).sum(dim=-1).sum(dim=-1)  # (B,)
+     return out[0] if not batched else out
+
+
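+ # Example (illustrative): len_pow compresses the dynamic range of the
+ # ell_A * ell_B length weighting; len_pow=1.0 recovers cross_simpson3.
+ #   v = cross_simpson3_lenpow(pA, qA, pB, qB, sigma=0.25, len_pow=0.5)
+
+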
+ # -----------------------------
+ # 2/3) Semi-analytic in s, quadrature in t
+ #   - Lobatto-3 (endpoints + midpoint)
+ #   - Gauss-Legendre Q (nodes/weights passed in)
+ # -----------------------------
+ def cross_semi_analytic(pA, qA, pB, qB, sigma: float, t_nodes01: torch.Tensor, t_w: torch.Tensor):
+     """
+     Gaussian k_x. Integrate s exactly along A, integrate t numerically along B.
+     t_nodes01, t_w: (Q,) nodes/weights on [0,1] (constants you pass in)
+     """
+     device, dtype = pA.device, pA.dtype
+     t = t_nodes01.to(device=device, dtype=dtype)  # (Q,)
+     w = t_w.to(device=device, dtype=dtype)  # (Q,)
+
+     dA, aA, ellA, uA = segment_geom(pA, qA)
+     dB, _, ellB, uB = segment_geom(pB, qB)
+
+     # (N,M) factors
+     ang = (uA @ uB.t()).pow(2)
+     lenfac = ellA[:, None] * ellB[None, :]
+
+     # r0: (N,M,3)
+     r0 = pA[:, None, :] - pB[None, :, :]
+
+     # r(t): (N,M,Q,3)
+     r = r0[:, :, None, :] - t[None, None, :, None] * dB[None, :, None, :]
+
+     # beta, r2: (N,M,Q)
+     beta = (r * dA[:, None, None, :]).sum(dim=-1)
+     r2 = (r * r).sum(dim=-1)
+
+     # semi-analytic constants per A segment: shapes broadcast to (N,1,1)
+     a = aA.clamp_min(1e-12)
+     inv_a = (1.0 / a).view(-1, 1, 1)
+     denom = (torch.sqrt(2.0 * a) * sigma).view(-1, 1, 1)
+     pref = (math.sqrt(math.pi) * sigma / torch.sqrt(2.0 * a)).view(-1, 1, 1)
+
+     # J(t): (N,M,Q)
+     exp_term = torch.exp(-(r2 - (beta * beta) * inv_a) / (2.0 * sigma * sigma))
+     erf1 = torch.special.erf((a.view(-1, 1, 1) + beta) / denom)
+     erf0 = torch.special.erf(beta / denom)
+     J = pref * (erf1 - erf0) * exp_term
+
+     # integrate over t: (N,M)
+     spatial = (J * w.view(1, 1, -1)).sum(dim=-1)
+     return (ang * lenfac * spatial).sum(dim=-1).sum(dim=-1)
+
+
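+ # The closed form used above (sketch): with x(s) = pA + s*dA on A and
+ # y = pB + t*dB fixed on B, write r = x(0) - y, a = |dA|^2, beta = <r, dA>.
+ # Completing the square in s,
+ #   int_0^1 exp(-|r + s*dA|^2 / (2 sigma^2)) ds
+ #     = sqrt(pi) * sigma / sqrt(2a)
+ #       * exp(-(|r|^2 - beta^2 / a) / (2 sigma^2))
+ #       * [erf((a + beta) / (sqrt(2a) sigma)) - erf(beta / (sqrt(2a) sigma))],
+ # which is exactly pref * exp_term * (erf1 - erf0).
+
+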
+ def cross_semi_lobatto3(pA, qA, pB, qB, sigma: float):
+     device, dtype = pA.device, pA.dtype
+     t = LOBATTO3_NODES.to(device=device, dtype=dtype)
+     w = LOBATTO3_W.to(device=device, dtype=dtype)
+     return cross_semi_analytic(pA, qA, pB, qB, sigma, t, w)
+
+
+ def cross_semi_lobatto3_mix(
+     pA,
+     qA,
+     pB,
+     qB,
+     sigmas,
+     alpha,
+     normalize_alpha: bool = True,
+ ):
+     """
+     Semi-analytic in s (along A), Lobatto-3 in t (along B), with a sigma mixture.
+     """
+     device, dtype = pA.device, pA.dtype
+     t_nodes = LOBATTO3_NODES.to(device=device, dtype=dtype)
+     t_w = LOBATTO3_W.to(device=device, dtype=dtype)
+
+     sigmas_t, alpha_t = _prepare_mix_weights(sigmas, alpha, device, dtype, normalize_alpha)
+
+     dA, aA, ellA, uA = segment_geom(pA, qA)
+     dB, _, ellB, uB = segment_geom(pB, qB)
+
+     ang = (uA @ uB.t()).pow(2)
+     lenfac = ellA[:, None] * ellB[None, :]
+
+     r0 = pA[:, None, :] - pB[None, :, :]
+
+     a = aA.clamp_min(1e-12)
+     inv_a = (1.0 / a).view(-1, 1)
+     sqrt_2a = torch.sqrt(2.0 * a).clamp_min(1e-12)
+
+     denom = (sqrt_2a[:, None] * sigmas_t[None, :]).clamp_min(1e-12)
+     pref = math.sqrt(math.pi) * sigmas_t[None, :] / sqrt_2a[:, None]
+     inv2s2 = (1.0 / (2.0 * sigmas_t * sigmas_t)).view(1, 1, -1)
+
+     denom_nmS = denom[:, None, :]
+     pref_nmS = pref[:, None, :]
+     alpha_nmS = alpha_t.view(1, 1, -1)
+     a_nm1 = a[:, None, None]
+
+     spatial = torch.zeros((pA.shape[0], pB.shape[0]), device=device, dtype=dtype)
+     for tk, wk in zip(t_nodes, t_w):
+         r = r0 - tk * dB[None, :, :]
+         beta = (r * dA[:, None, :]).sum(dim=-1)
+         r2 = (r * r).sum(dim=-1)
+         core = r2 - (beta * beta) * inv_a
+
+         exp_term = torch.exp(-core[:, :, None] * inv2s2)
+         erf1 = torch.special.erf((a_nm1 + beta[:, :, None]) / denom_nmS)
+         erf0 = torch.special.erf(beta[:, :, None] / denom_nmS)
+         J = pref_nmS * (erf1 - erf0) * exp_term
+         spatial = spatial + wk * (J * alpha_nmS).sum(dim=-1)
+
+     return (ang * lenfac * spatial).sum(dim=-1).sum(dim=-1)
+
+
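+ # Example (illustrative): a coarse-to-fine sigma mixture, equally weighted.
+ #   v = cross_semi_lobatto3_mix(pA, qA, pB, qB,
+ #                               sigmas=[0.5, 0.25, 0.1], alpha=[1.0, 1.0, 1.0])
+ # With normalize_alpha=True this equals the mean over the three bandwidths.
+
+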
+ # -----------------------------
+ # Full losses (self_pred + self_gt - 2*cross; the GT self-energy is constant
+ # w.r.t. the predictions, so most variants below drop it)
+ # -----------------------------
+
+
+ def loss_simpson3(p_pred, q_pred, p_gt, q_gt, sigma: float):
+     s_pred = cross_simpson3(p_pred, q_pred, p_pred, q_pred, sigma)
+     cross = cross_simpson3(p_pred, q_pred, p_gt, q_gt, sigma)
+     # s_gt (the GT self-energy) is constant in the predictions, so it is dropped.
+     return s_pred - 2.0 * cross
+
+
+ def loss_simpson3_lenpow(p_pred, q_pred, p_gt, q_gt, sigma: float, len_pow: float):
+     s_pred = cross_simpson3_lenpow(p_pred, q_pred, p_pred, q_pred, sigma, len_pow)
+     cross = cross_simpson3_lenpow(p_pred, q_pred, p_gt, q_gt, sigma, len_pow)
+     # s_gt dropped (constant in the predictions).
+     return s_pred - 2.0 * cross
+
+
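+ # Example (illustrative): gradient descent of free segments onto a GT wireframe
+ # (p_gt, q_gt), using the varifold pull as the only objective.
+ #   p = torch.randn(32, 3, requires_grad=True)
+ #   q = torch.randn(32, 3, requires_grad=True)
+ #   opt = torch.optim.Adam([p, q], lr=1e-2)
+ #   for _ in range(200):
+ #       opt.zero_grad()
+ #       loss_simpson3(p, q, p_gt, q_gt, sigma=0.25).backward()
+ #       opt.step()
+
+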
+ def loss_simpson3_mix(
+     p_pred,
+     q_pred,
+     p_gt,
+     q_gt,
+     sigmas,
+     alpha,
+     normalize_alpha: bool = True,
+ ):
+     device, dtype = p_pred.device, p_pred.dtype
+     sigmas_t, alpha_t = _prepare_mix_weights(sigmas, alpha, device, dtype, normalize_alpha)
+     losses = [loss_simpson3(p_pred, q_pred, p_gt, q_gt, s) for s in sigmas_t]
+     return (torch.stack(losses) * alpha_t).sum()
+
+
+ def loss_simpson3_batch(
+     p_pred: torch.Tensor,
+     q_pred: torch.Tensor,
+     p_gt: torch.Tensor,
+     q_gt: torch.Tensor,
+     sigma: float | torch.Tensor,
+     w_gt: torch.Tensor | None = None,
+     w_pred: torch.Tensor | None = None,
+     cross_only: bool = False,
+ ) -> torch.Tensor:
+     cross = cross_simpson3(p_pred, q_pred, p_gt, q_gt, sigma, wA=w_pred, wB=w_gt)
+     if cross_only:
+         # No self-energy: avoids the O(S^2) blowup; Sinkhorn handles repulsion.
+         return -2.0 * cross
+     s_pred = cross_simpson3(p_pred, q_pred, p_pred, q_pred, sigma, wA=w_pred, wB=w_pred)
+     return s_pred - 2.0 * cross
+
+
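+ # Example (illustrative): attraction-only batch loss; with cross_only=True the
+ # O(S^2) predicted self-energy is skipped and repulsion is left to the matcher.
+ #   l = loss_simpson3_batch(p_pred, q_pred, p_gt, q_gt, sigma=0.2,
+ #                           w_gt=w_gt, cross_only=True)  # (B,) per-scene values
+
+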
+ def loss_simpson3_mix_batch(
+     p_pred: torch.Tensor,
+     q_pred: torch.Tensor,
+     p_gt: torch.Tensor,
+     q_gt: torch.Tensor,
+     sigmas,
+     alpha,
+     w_gt: torch.Tensor | None = None,
+     w_pred: torch.Tensor | None = None,
+     normalize_alpha: bool = True,
+     cross_only: bool = False,
+ ) -> torch.Tensor:
+     device, dtype = p_pred.device, p_pred.dtype
+     sigmas_t = torch.as_tensor(sigmas, device=device, dtype=dtype).clamp_min(1e-6)
+     alpha_t = torch.as_tensor(alpha, device=device, dtype=dtype)
+     if normalize_alpha:
+         alpha_t = alpha_t / alpha_t.sum().clamp_min(1e-12)
+     if sigmas_t.ndim == 1:
+         losses = [
+             loss_simpson3_batch(p_pred, q_pred, p_gt, q_gt, s,
+                                 w_gt=w_gt, w_pred=w_pred, cross_only=cross_only)
+             for s in sigmas_t
+         ]
+         return (torch.stack(losses, dim=0) * alpha_t[:, None]).sum(dim=0)
+     if sigmas_t.ndim == 2:
+         losses = [
+             loss_simpson3_batch(p_pred, q_pred, p_gt, q_gt, sigmas_t[:, i],
+                                 w_gt=w_gt, w_pred=w_pred, cross_only=cross_only)
+             for i in range(sigmas_t.shape[1])
+         ]
+         return (torch.stack(losses, dim=0) * alpha_t[:, None]).sum(dim=0)
+     raise ValueError("sigmas must be 1D or 2D for batch loss")
+
+
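+ # Example (illustrative): a shared 1D sigma mixture vs. per-sample 2D sigmas.
+ #   loss_simpson3_mix_batch(pp, qp, pg, qg,
+ #                           sigmas=[0.4, 0.2, 0.1], alpha=[0.5, 0.3, 0.2])  # (S,)
+ #   loss_simpson3_mix_batch(pp, qp, pg, qg,
+ #                           sigmas=sigma_bS, alpha=[0.5, 0.3, 0.2])         # (B, S)
+
+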
+ def loss_simpson3_lenpow_mix(
+     p_pred,
+     q_pred,
+     p_gt,
+     q_gt,
+     sigmas,
+     alpha,
+     len_pow: float,
+     normalize_alpha: bool = True,
+ ):
+     device, dtype = p_pred.device, p_pred.dtype
+     sigmas_t, alpha_t = _prepare_mix_weights(sigmas, alpha, device, dtype, normalize_alpha)
+     losses = [loss_simpson3_lenpow(p_pred, q_pred, p_gt, q_gt, s, len_pow) for s in sigmas_t]
+     return (torch.stack(losses) * alpha_t).sum()
+
+
+ def loss_semi_lobatto3(p_pred, q_pred, p_gt, q_gt, sigma: float):
+     s_pred = cross_semi_lobatto3(p_pred, q_pred, p_pred, q_pred, sigma)
+     cross = cross_semi_lobatto3(p_pred, q_pred, p_gt, q_gt, sigma)
+     # s_gt dropped (constant in the predictions).
+     return s_pred - 2.0 * cross
+
+
+ def loss_semi_lobatto3_mix(
+     p_pred,
+     q_pred,
+     p_gt,
+     q_gt,
+     sigmas,
+     alpha,
+     normalize_alpha: bool = True,
+ ):
+     s_pred = cross_semi_lobatto3_mix(p_pred, q_pred, p_pred, q_pred, sigmas, alpha, normalize_alpha)
+     cross = cross_semi_lobatto3_mix(p_pred, q_pred, p_gt, q_gt, sigmas, alpha, normalize_alpha)
+     # s_gt dropped (constant in the predictions).
+     return s_pred - 2.0 * cross
+
+
+ def loss_semi_lobatto3_mix_simple(
+     p_pred,
+     q_pred,
+     p_gt,
+     q_gt,
+     sigmas,
+     alpha,
+     normalize_alpha: bool = True,
+ ):
+     device, dtype = p_pred.device, p_pred.dtype
+     sigmas_t, alpha_t = _prepare_mix_weights(sigmas, alpha, device, dtype, normalize_alpha)
+     losses = [loss_semi_lobatto3(p_pred, q_pred, p_gt, q_gt, s) for s in sigmas_t]
+     return (torch.stack(losses) * alpha_t).sum()
+
+
+ def loss_semi_legendre(p_pred, q_pred, p_gt, q_gt, sigma: float, t_nodes01, t_w):
+     s_pred = cross_semi_analytic(p_pred, q_pred, p_pred, q_pred, sigma, t_nodes01, t_w)
+     s_gt = cross_semi_analytic(p_gt, q_gt, p_gt, q_gt, sigma, t_nodes01, t_w)
+     cross = cross_semi_analytic(p_pred, q_pred, p_gt, q_gt, sigma, t_nodes01, t_w)
+     # Full squared varifold distance (this variant keeps the constant s_gt term).
+     return s_pred + s_gt - 2.0 * cross
+
+
+ # -----------------------------
+ # torch.compile usage
+ # -----------------------------
+ # For Legendre: generate nodes/weights ONCE outside compile and pass them in.
+ # Example:
+ #   import numpy as np
+ #   x, w = np.polynomial.legendre.leggauss(Q)
+ #   t_nodes = torch.tensor(0.5 * (x + 1.0), device=device, dtype=dtype)
+ #   t_w = torch.tensor(0.5 * w, device=device, dtype=dtype)
+ #
+ #   compiled_loss = torch.compile(loss_semi_lobatto3, fullgraph=True)
+ #   compiled_loss_leg = torch.compile(
+ #       lambda pp, qp, pg, qg, s: loss_semi_legendre(pp, qp, pg, qg, s, t_nodes, t_w),
+ #       fullgraph=True)
script.py ADDED
@@ -0,0 +1,322 @@
+ """S23DR 2026 submission: learned wireframe prediction from fused point clouds.
2
+
3
+ Pipeline: raw sample -> point fusion -> priority sample 2048 -> model -> post-process -> wireframe
4
+ """
5
+ from pathlib import Path
6
+ from tqdm import tqdm
7
+ import json
8
+ import os
9
+ import sys
10
+ import time
11
+
12
+ import numpy as np
13
+ import torch
14
+
15
+
16
+ def empty_solution():
17
+ return np.zeros((2, 3)), [(0, 1)]
18
+
19
+
20
+ # ---------------------------------------------------------------------------
+ # Point fusion + sampling (from cache_scenes.py / make_sampled_cache.py)
+ # ---------------------------------------------------------------------------
+
+ # Add our package to path
+ SCRIPT_DIR = Path(__file__).resolve().parent
+ sys.path.insert(0, str(SCRIPT_DIR))
+
+ from s23dr_2026_example.point_fusion import build_compact_scene, FuserConfig
+ from s23dr_2026_example.cache_scenes import (
+     _compute_group_and_class, _compute_smart_center_scale,
+ )
+ from s23dr_2026_example.make_sampled_cache import _priority_sample
+
+ # Tokenizer / model imports
+ from s23dr_2026_example.tokenizer import EdgeDepthSequenceConfig
+ from s23dr_2026_example.model import EdgeDepthSegmentsModel
+ from s23dr_2026_example.segment_postprocess import merge_vertices
+ from s23dr_2026_example.varifold import segments_to_vertices_edges
+ from s23dr_2026_example.postprocess_v2 import snap_to_point_cloud, snap_horizontal
+
+ # Inference constants; the quotas partition the sequence (1536 + 512 = 2048).
+ SEQ_LEN = 2048
+ COLMAP_QUOTA = 1536
+ DEPTH_QUOTA = 512
+ CONF_THRESH = 0.7
+ MERGE_THRESH = 0.4
+ SNAP_RADIUS = 0.5
+
+
+ def fuse_and_sample(sample, cfg, rng):
+     """Run point fusion + priority sampling on a raw dataset sample.
+
+     Returns a dict with xyz_norm, class_id, source, mask, center, scale, etc.,
+     ready for model inference. Returns None if fusion fails.
+     """
+     try:
+         scene = build_compact_scene(sample, cfg, rng)
+     except Exception as e:
+         print(f"  Fusion failed: {e}")
+         return None
+
+     xyz = scene["xyz"]
+     source = scene["source"]
+
+     if len(xyz) < 10:
+         return None
+
+     # Compute group_id and class_id (same as cache_scenes.py)
+     behind_id = scene.get("behind_gest_id", np.full(len(xyz), -1, dtype=np.int16))
+     group_id, class_id = _compute_group_and_class(
+         scene["visible_src"], scene["visible_id"], behind_id, source)
+
+     # Normalize
+     center, scale = _compute_smart_center_scale(xyz, source)
+
+     # Priority sample
+     indices, mask = _priority_sample(source, group_id, SEQ_LEN, COLMAP_QUOTA, DEPTH_QUOTA)
+
+     xyz_norm = (xyz[indices] - center) / scale
+
+     result = {
+         "xyz_norm": xyz_norm.astype(np.float32),
+         "class_id": class_id[indices].astype(np.int64),
+         "source": source[indices].astype(np.int64),
+         "mask": mask,
+         "center": center.astype(np.float32),
+         "scale": np.float32(scale),
+     }
+
+     # Optional fields
+     if "behind_gest_id" in scene:
+         behind = np.clip(scene["behind_gest_id"][indices].astype(np.int16), 0, None)
+         result["behind"] = behind.astype(np.int64)
+     if "n_views_voted" in scene:
+         result["n_views_voted"] = scene["n_views_voted"][indices].astype(np.float32)
+     if "vote_frac" in scene:
+         result["vote_frac"] = scene["vote_frac"][indices].astype(np.float32)
+
+     # Visible src/id for snap post-processing
+     result["visible_src"] = scene["visible_src"][indices].astype(np.int64)
+     result["visible_id"] = scene["visible_id"][indices].astype(np.int64)
+
+     return result
+
+
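+ # Example (illustrative; assumes _priority_sample pads/subsamples to SEQ_LEN
+ # points, with `mask` marking the valid entries):
+ #   fused = fuse_and_sample(sample, FuserConfig(), np.random.RandomState(0))
+ #   if fused is not None:
+ #       xyz_norm, mask = fused["xyz_norm"], fused["mask"]
+
+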
+ def load_model(checkpoint_path, device):
+     """Load model from checkpoint."""
+     ckpt = torch.load(checkpoint_path, map_location=device, weights_only=False)
+     args = ckpt.get("args", {})
+
+     norm_class = torch.nn.RMSNorm if args.get("rms_norm") else None
+     seq_cfg = EdgeDepthSequenceConfig(
+         seq_len=SEQ_LEN, colmap_points=COLMAP_QUOTA, depth_points=DEPTH_QUOTA)
+
+     model = EdgeDepthSegmentsModel(
+         seq_cfg=seq_cfg,
+         segments=args.get("segments", 64),
+         hidden=args.get("hidden", 256),
+         num_heads=args.get("num_heads", 4),
+         kv_heads_cross=args.get("kv_heads_cross", 2),
+         kv_heads_self=args.get("kv_heads_self", 2),
+         dim_feedforward=args.get("ff", 1024),
+         dropout=args.get("dropout", 0.1),
+         latent_tokens=args.get("latent_tokens", 256),
+         latent_layers=args.get("latent_layers", 7),
+         decoder_layers=args.get("decoder_layers", 3),
+         cross_attn_interval=args.get("cross_attn_interval", 4),
+         norm_class=norm_class,
+         activation=args.get("activation", "gelu"),
+         segment_conf=args.get("segment_conf", True),
+         behind_emb_dim=args.get("behind_emb_dim", 8),
+         use_vote_features=args.get("vote_features", True),
+         arch=args.get("arch", "perceiver"),
+         encoder_layers=args.get("encoder_layers", 4),
+         pre_encoder_layers=args.get("pre_encoder_layers", 0),
+         segment_param=args.get("segment_param", "midpoint_dir_len"),
+         qk_norm=args.get("qk_norm", True),
+     ).to(device)
+
+     # Handle torch.compile _orig_mod prefix
+     state = ckpt["model"]
+     fixed = {k.replace("segmenter._orig_mod.", "segmenter."): v
+              for k, v in state.items()}
+     model.load_state_dict(fixed, strict=True)
+     model.eval()
+     return model
+
+
+ def build_tokens_single(sample_dict, model, device):
+     """Build token tensor for a single sample (no DataLoader)."""
+     xyz = torch.as_tensor(sample_dict["xyz_norm"], dtype=torch.float32).unsqueeze(0).to(device)
+     cid = torch.as_tensor(sample_dict["class_id"], dtype=torch.long).unsqueeze(0).to(device)
+     src = torch.as_tensor(sample_dict["source"], dtype=torch.long).unsqueeze(0).to(device)
+     masks = torch.as_tensor(sample_dict["mask"], dtype=torch.bool).unsqueeze(0).to(device)
+
+     B, T, _ = xyz.shape
+     tok = model.tokenizer
+     fourier = (tok.pos_enc(xyz.reshape(-1, 3)).reshape(B, T, -1)
+                if tok.pos_enc is not None else xyz.new_zeros(B, T, 0))
+     parts = [xyz, fourier, tok.label_emb(cid), tok.src_emb(src.clamp(0, 1))]
+
+     if tok.behind_emb_dim > 0:
+         if "behind" in sample_dict:
+             beh = torch.as_tensor(sample_dict["behind"], dtype=torch.long).unsqueeze(0).to(device)
+         else:
+             beh = xyz.new_zeros(B, T, dtype=torch.long)
+         parts.append(tok.behind_emb(beh))
+
+     if tok.use_vote_features:
+         if "n_views_voted" in sample_dict and "vote_frac" in sample_dict:
+             nv = torch.as_tensor(sample_dict["n_views_voted"], dtype=torch.float32).unsqueeze(0).to(device)
+             vf = torch.as_tensor(sample_dict["vote_frac"], dtype=torch.float32).unsqueeze(0).to(device)
+             # Center/scale the vote features with fixed constants
+             nv = ((nv - 2.7) / 1.0).unsqueeze(-1)
+             vf = ((vf - 0.5) / 0.25).unsqueeze(-1)
+             parts.extend([nv, vf])
+         else:
+             parts.extend([xyz.new_zeros(B, T, 1), xyz.new_zeros(B, T, 1)])
+
+     tokens = torch.cat(parts, dim=-1)
+     return tokens, masks
+
+
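+ # Token layout (illustrative; exact widths come from the tokenizer config):
+ #   [xyz(3) | fourier | label_emb(class) | src_emb | behind_emb? | nv(1) | vf(1)]
+ # concatenated along the channel axis, one token per fused point.
+
+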
+ def predict_sample(sample_dict, model, device):
+     """Run model inference + post-processing on a fused sample.
+
+     Returns (vertices, edges) in world space.
+     """
+     tokens, masks = build_tokens_single(sample_dict, model, device)
+     scale = float(sample_dict["scale"])
+     center = sample_dict["center"]
+
+     with torch.no_grad(), torch.autocast(device_type='cuda', dtype=torch.float16,
+                                          enabled=(device.type == 'cuda')):
+         out = model.forward_tokens(tokens, masks)
+
+     segs = out["segments"][0].float().cpu()
+     conf = torch.sigmoid(out["conf"][0].float()).cpu().numpy() if "conf" in out else None
+
+     # Confidence filter
+     if conf is not None:
+         keep = conf > CONF_THRESH
+         segs = segs[keep]
+         if len(segs) < 1:
+             return empty_solution()
+
+     # To world space
+     segs_world = segs.numpy() * scale + center
+
+     # Vertices + edges from segments
+     pv, pe = segments_to_vertices_edges(torch.tensor(segs_world))
+     pv, pe = pv.numpy(), np.array(pe, dtype=np.int32)
+
+     # Merge
+     pv, pe = merge_vertices(pv, pe, MERGE_THRESH)
+
+     # Snap to point cloud
+     xyz_norm = sample_dict["xyz_norm"]
+     mask = sample_dict["mask"]
+     cid = sample_dict["class_id"]
+     xyz_world = xyz_norm[mask] * scale + center
+     cid_valid = cid[mask]
+     pv = snap_to_point_cloud(pv, xyz_world, cid_valid, snap_radius=SNAP_RADIUS)
+
+     # Horizontal snap
+     pv = snap_horizontal(pv, pe)
+
+     if len(pv) < 2 or len(pe) < 1:
+         return empty_solution()
+
+     edges = [(int(a), int(b)) for a, b in pe]
+     return pv, edges
+
+
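+ # Post-processing order (as implemented above): confidence filter -> un-normalize ->
+ # segments_to_vertices_edges -> merge_vertices(MERGE_THRESH) ->
+ # snap_to_point_cloud(SNAP_RADIUS) -> snap_horizontal.
+
+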
+ # ---------------------------------------------------------------------------
+ # Main
+ # ---------------------------------------------------------------------------
+
+ if __name__ == "__main__":
+     t_start = time.time()
+
+     # Load params
+     param_path = Path("params.json")
+     with param_path.open() as f:
+         params = json.load(f)
+     print(f"Competition: {params.get('competition_id', '?')}")
+     print(f"Dataset: {params.get('dataset', '?')}")
+
+     # Load test data
+     data_path = Path("/tmp/data")
+     if not data_path.exists():
+         from huggingface_hub import snapshot_download
+         snapshot_download(
+             repo_id=params["dataset"],
+             local_dir="/tmp/data",
+             repo_type="dataset",
+         )
+
+     from datasets import load_dataset
+     data_files = {
+         "validation": [str(p) for p in data_path.rglob("*public*/**/*.tar")],
+         "test": [str(p) for p in data_path.rglob("*private*/**/*.tar")],
+     }
+     print(f"Data files: {data_files}")
+     dataset = load_dataset(
+         str(data_path / "hoho22k_2026_test_x_anon.py"),
+         data_files=data_files,
+         trust_remote_code=True,
+         writer_batch_size=100,
+     )
+     print(f"Loaded: {dataset}")
+
+     # Load model
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     print(f"Device: {device}")
+     checkpoint_path = SCRIPT_DIR / "checkpoint.pt"
+     model = load_model(checkpoint_path, device)
+     print(f"Model loaded: {sum(p.numel() for p in model.parameters()):,} params")
+
+     # Point fusion config
+     cfg = FuserConfig()
+     rng = np.random.RandomState(2718)
+
+     # Process all samples
+     solution = []
+     total_samples = sum(len(dataset[s]) for s in dataset)
+     processed = 0
+
+     for subset_name in dataset:
+         print(f"\nProcessing {subset_name} ({len(dataset[subset_name])} samples)...")
+
+         for sample in tqdm(dataset[subset_name], desc=subset_name):
+             order_id = sample["order_id"]
+
+             # Fuse + sample
+             fused = fuse_and_sample(sample, cfg, rng)
+             if fused is None:
+                 pred_v, pred_e = empty_solution()
+             else:
+                 try:
+                     pred_v, pred_e = predict_sample(fused, model, device)
+                 except Exception as e:
+                     print(f"  Predict failed for {order_id}: {e}")
+                     pred_v, pred_e = empty_solution()
+
+             solution.append({
+                 "order_id": order_id,
+                 "wf_vertices": pred_v.tolist() if isinstance(pred_v, np.ndarray) else pred_v,
+                 "wf_edges": [(int(a), int(b)) for a, b in pred_e],
+             })
+             processed += 1
+
+             if processed % 50 == 0:
+                 elapsed = time.time() - t_start
+                 rate = elapsed / processed
+                 remaining = (total_samples - processed) * rate
+                 print(f"  [{processed}/{total_samples}] "
+                       f"{elapsed:.0f}s elapsed, ~{remaining:.0f}s remaining")
+
+     # Save
+     with open("submission.json", "w") as f:
+         json.dump(solution, f)
+
+     elapsed = time.time() - t_start
+     print(f"\nDone. {processed} samples in {elapsed:.0f}s ({elapsed/max(processed, 1):.1f}s/sample)")
+     print(f"Saved submission.json ({len(solution)} entries)")