| """ |
| Perceiver-based transformer for 3D roof wireframe prediction. |
| |
| Architecture overview: |
| |
| Input tokens [B, T, D] |
| | |
| v |
| input_proj: Linear -> GELU -> Linear -> LayerNorm => [B, T, hidden] |
| | |
| v |
| Perceiver latent bottleneck (N PerceiverLatentLayers): |
| Learnable latent embeddings [L, hidden] are broadcast to batch. |
| Each layer: cross-attn(latents <- tokens) -> self-attn(latents) -> FFN |
| Output: latents [B, L, hidden] |
| | |
| v |
| Segment decoder (M SegmentDecoderLayers): |
| Learnable query embeddings [S, hidden] are broadcast to batch. |
| Each layer: cross-attn(queries <- latents) -> self-attn(queries) -> FFN |
| Output: queries [B, S, hidden] |
| | |
| v |
| segment_head: Linear -> 6D -> (midpoint, half_vector) |
| + query_offsets (learnable per-query bias) |
| endpoints = midpoint +/- half_vector -> [B, S, 2, 3] |
| """ |
|
|
import torch
import torch.nn as nn

from .attention import MultiHeadSDPA, FeedForward
|
|
class AttnResidual(nn.Module):
    """Pre-norm attention + residual + dropout."""

    def __init__(
        self,
        d_model: int,
        num_heads: int,
        dropout: float = 0.0,
        kv_heads: int | None = None,
        norm_class=None,
        qk_norm: bool = False,
        qk_norm_type: str = "l2",
    ):
        super().__init__()
        norm_class = norm_class or nn.LayerNorm
        self.norm = norm_class(d_model)
        self.attn = MultiHeadSDPA(d_model, num_heads, kv_heads=kv_heads, qk_norm=qk_norm, qk_norm_type=qk_norm_type)
        self.drop = nn.Dropout(dropout)

    def forward(
        self,
        x: torch.Tensor,
        memory: torch.Tensor,
        memory_key_padding_mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        res = x
        x = self.norm(x)
        x = self.attn(x, memory, key_padding_mask=memory_key_padding_mask)
        return res + self.drop(x)
|
|
class FFNResidual(nn.Module):
    """Pre-norm feed-forward + residual + dropout."""

    def __init__(
        self,
        d_model: int,
        dim_ff: int,
        dropout: float = 0.0,
        activation: str = "gelu",
        norm_class=None,
    ):
        super().__init__()
        norm_class = norm_class or nn.LayerNorm
        self.norm = norm_class(d_model)
        self.ffn = FeedForward(d_model, dim_ff, activation=activation)
        self.drop = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        res = x
        x = self.norm(x)
        x = self.ffn(x)
        return res + self.drop(x)
|
|
class PerceiverLatentLayer(nn.Module):
    """Single Perceiver latent layer.

    If use_cross=True: cross-attn(latents <- points) -> self-attn -> FFN
    If use_cross=False: self-attn -> FFN (saves compute in deep stacks)
    """

    def __init__(
        self,
        d_model: int,
        num_heads: int,
        dim_ff: int,
        dropout: float = 0.0,
        activation: str = "gelu",
        kv_heads_cross: int | None = None,
        kv_heads_self: int | None = None,
        use_cross: bool = True,
        norm_class=None,
        qk_norm: bool = False,
        qk_norm_type: str = "l2",
    ):
        super().__init__()
        self.use_cross = use_cross
        if use_cross:
            self.cross = AttnResidual(d_model, num_heads, dropout, kv_heads=kv_heads_cross, norm_class=norm_class, qk_norm=qk_norm, qk_norm_type=qk_norm_type)
        self.self_attn = AttnResidual(d_model, num_heads, dropout, kv_heads=kv_heads_self, norm_class=norm_class, qk_norm=qk_norm, qk_norm_type=qk_norm_type)
        self.ffn = FFNResidual(d_model, dim_ff, dropout, activation=activation, norm_class=norm_class)

    def forward(
        self,
        latents: torch.Tensor,
        points: torch.Tensor,
        points_key_padding_mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        if self.use_cross:
            latents = self.cross(latents, points, memory_key_padding_mask=points_key_padding_mask)
        latents = self.self_attn(latents, latents)
        latents = self.ffn(latents)
        return latents
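
# Shape sketch for a single latent layer (hyperparameters are illustrative); with
# use_cross=False the `points` argument is ignored and only self-attn + FFN run.
#   layer = PerceiverLatentLayer(d_model=128, num_heads=4, dim_ff=256)
#   latents = torch.zeros(2, 64, 128)   # [B, L, hidden]
#   points = torch.zeros(2, 500, 128)   # [B, T, hidden], already projected
#   latents = layer(latents, points)    # -> [2, 64, 128]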
|
|
class SegmentDecoderLayer(nn.Module):
    """Single segment decoder layer.

    cross-attn(queries <- latents) -> [cross-attn(queries <- inputs)] -> self-attn(queries) -> FFN

    If input_xattn=True, adds a second cross-attention that attends directly
    to the projected input tokens (bypassing the latent bottleneck). This gives
    queries access to fine-grained point-level detail for vertex precision.
    """

    def __init__(
        self,
        d_model: int,
        num_heads: int,
        dim_ff: int,
        dropout: float = 0.0,
        activation: str = "gelu",
        kv_heads_cross: int | None = None,
        kv_heads_self: int | None = None,
        norm_class=None,
        input_xattn: bool = False,
        qk_norm: bool = False,
        qk_norm_type: str = "l2",
    ):
        super().__init__()
        self.cross = AttnResidual(d_model, num_heads, dropout, kv_heads=kv_heads_cross, norm_class=norm_class, qk_norm=qk_norm, qk_norm_type=qk_norm_type)
        self.input_xattn = input_xattn
        if input_xattn:
            self.cross_input = AttnResidual(d_model, num_heads, dropout, kv_heads=kv_heads_cross, norm_class=norm_class, qk_norm=qk_norm, qk_norm_type=qk_norm_type)
        self.self_attn = AttnResidual(d_model, num_heads, dropout, kv_heads=kv_heads_self, norm_class=norm_class, qk_norm=qk_norm, qk_norm_type=qk_norm_type)
        self.ffn = FFNResidual(d_model, dim_ff, dropout, activation=activation, norm_class=norm_class)

    def forward(
        self,
        queries: torch.Tensor,
        latents: torch.Tensor,
        src: torch.Tensor | None = None,
        src_key_padding_mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        queries = self.cross(queries, latents)
        if self.input_xattn and src is not None:
            queries = self.cross_input(queries, src, memory_key_padding_mask=src_key_padding_mask)
        queries = self.self_attn(queries, queries)
        queries = self.ffn(queries)
        return queries
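
# Shape sketch (illustrative): queries read from the latents; with input_xattn=True
# they additionally cross-attend to the projected input tokens.
#   layer = SegmentDecoderLayer(d_model=128, num_heads=4, dim_ff=256, input_xattn=True)
#   queries = layer(queries, latents, src=src, src_key_padding_mask=pad_mask)  # [B, S, 128]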
|
|
class TokenTransformerSegments(nn.Module):
    """Perceiver transformer that predicts 3D roof wireframe segments.

    Takes point-cloud tokens and outputs segment endpoints as [B, S, 2, 3],
    where S is the number of segments and each segment has two 3D endpoints.

    Args:
        segments: Number of predicted segments (S).
        in_dim: Dimensionality of input tokens.
        hidden: Internal hidden dimension used throughout the model.
        num_heads: Number of attention heads.
        kv_heads_cross: Grouped-query KV heads for cross-attention (None or <= 0 = standard MHA).
        kv_heads_self: Grouped-query KV heads for self-attention (None or <= 0 = standard MHA).
        dim_feedforward: FFN intermediate dimension.
        dropout: Dropout rate applied after attention and FFN.
        latent_tokens: Number of learnable latent embeddings (L) in the bottleneck.
        latent_layers: Number of PerceiverLatentLayers (N).
        decoder_layers: Number of SegmentDecoderLayers (M).
        cross_attn_interval: Latent layers cross-attend to the inputs on the first and
            last layers and on every layer whose index is a multiple of this interval.
        norm_class: Normalization layer constructor (defaults to nn.LayerNorm).
        activation: Activation name used in the feed-forward blocks.
        segment_conf: If True, adds a per-segment confidence head ("conf" logits).
        pre_encoder_layers: Number of self-attention layers applied to the projected
            tokens before the latent bottleneck (0 = disabled).
        segment_param: Output parameterization, "midpoint_halfvec" (default, 6 values)
            or "midpoint_dir_len" (7 values: midpoint, unit direction, softplus length).
        length_floor: Accepted for config compatibility; currently unused here.
        decoder_input_xattn: If True, decoder layers also cross-attend directly to the
            projected input tokens.
        qk_norm: If True, applies query/key normalization in the attention layers.
        qk_norm_type: Query/key normalization variant passed to the attention layers.
    """
|
|
    def __init__(
        self,
        segments: int = 32,
        in_dim: int = 128,
        hidden: int = 128,
        num_heads: int = 4,
        kv_heads_cross: int | None = 2,
        kv_heads_self: int | None = 0,
        dim_feedforward: int = 256,
        dropout: float = 0.01,
        latent_tokens: int = 64,
        latent_layers: int = 2,
        decoder_layers: int = 2,
        cross_attn_interval: int = 1,
        norm_class=None,
        activation: str = "gelu",
        segment_conf: bool = False,
        pre_encoder_layers: int = 0,
        segment_param: str = "midpoint_halfvec",
        length_floor: float = 0.0,
        decoder_input_xattn: bool = False,
        qk_norm: bool = False,
        qk_norm_type: str = "l2",
    ):
        super().__init__()
        self.segments = segments
        self.out_vertices = segments * 2
        self.segment_param = segment_param
        self.decoder_input_xattn = decoder_input_xattn
        norm_class = norm_class or nn.LayerNorm

        # Non-positive kv_heads means "disable grouped-query attention" (standard MHA).
        if kv_heads_cross is not None and kv_heads_cross <= 0:
            kv_heads_cross = None
        if kv_heads_self is not None and kv_heads_self <= 0:
            kv_heads_self = None

        # Input projection: in_dim -> hidden.
        self.input_proj = nn.Sequential(
            nn.Linear(in_dim, dim_feedforward),
            nn.GELU(),
            nn.Linear(dim_feedforward, hidden),
            norm_class(hidden),
        )

        # Optional self-attention pre-encoder over the projected tokens.
        if pre_encoder_layers > 0:
            self.pre_encoder = nn.ModuleList([
                SelfAttentionEncoderLayer(
                    d_model=hidden,
                    num_heads=num_heads,
                    dim_ff=dim_feedforward,
                    dropout=dropout,
                    activation=activation,
                    kv_heads=kv_heads_self,
                    norm_class=norm_class,
                    qk_norm=qk_norm, qk_norm_type=qk_norm_type,
                )
                for _ in range(pre_encoder_layers)
            ])
        else:
            self.pre_encoder = None

        # Perceiver latent bottleneck. Cross-attention to the inputs runs on the
        # first and last layers and on every cross_attn_interval-th layer in between.
        self.latent_embed = nn.Embedding(latent_tokens, hidden)
        N = latent_layers
        self.latent_layers = nn.ModuleList([
            PerceiverLatentLayer(
                d_model=hidden,
                num_heads=num_heads,
                dim_ff=dim_feedforward,
                dropout=dropout,
                activation=activation,
                kv_heads_cross=kv_heads_cross,
                kv_heads_self=kv_heads_self,
                use_cross=(i == 0) or (i == N - 1) or (i % cross_attn_interval == 0),
                norm_class=norm_class,
                qk_norm=qk_norm, qk_norm_type=qk_norm_type,
            )
            for i in range(N)
        ])

        # Segment decoder with one learnable query per predicted segment.
        self.query_embed = nn.Embedding(segments, hidden)
        self.decoder_layers = nn.ModuleList([
            SegmentDecoderLayer(
                d_model=hidden,
                num_heads=num_heads,
                dim_ff=dim_feedforward,
                dropout=dropout,
                activation=activation,
                kv_heads_cross=kv_heads_cross,
                kv_heads_self=kv_heads_self,
                norm_class=norm_class,
                input_xattn=decoder_input_xattn,
                qk_norm=qk_norm, qk_norm_type=qk_norm_type,
            )
            for _ in range(decoder_layers)
        ])

        # Output head: 6 values (midpoint, half-vector) or 7 (midpoint, direction, length).
        if segment_param == "midpoint_dir_len":
            self.segment_head = nn.Linear(hidden, 7)
        else:
            self.segment_head = nn.Linear(hidden, 6)
        self.query_offsets = nn.Parameter(torch.zeros(segments, 2, 3))

        nn.init.trunc_normal_(self.segment_head.weight, mean=0.0, std=1e-3)
        if self.segment_head.bias is not None:
            nn.init.zeros_(self.segment_head.bias)
        if segment_param == "midpoint_dir_len":
            # Positive length bias so softplus starts from a non-degenerate length.
            self.segment_head.bias.data[6] = 0.5
        nn.init.normal_(self.query_offsets, mean=0.0, std=0.05)

        # Optional per-segment confidence head (logits).
        self.segment_conf = segment_conf
        if segment_conf:
            self.conf_head = nn.Linear(hidden, 1)
            nn.init.zeros_(self.conf_head.bias)
|
|
    def forward(
        self,
        tokens: torch.Tensor,
        mask: torch.Tensor | None = None,
    ) -> dict[str, torch.Tensor | list]:
        """
        Args:
            tokens: Input point-cloud tokens [B, T, in_dim].
            mask: Boolean validity mask [B, T]. True = valid token.

        Returns:
            Dict with keys:
                "vertices": [B, S*2, 3] flattened endpoints.
                "segments": [B, S, 2, 3] segment endpoints.
                "edges": Per-batch list of (start, end) index pairs into vertices.
                "conf": [B, S] logits (only if segment_conf=True).
                "src": Projected (and optionally pre-encoded) tokens [B, T, hidden].
                "pad_mask": Key-padding mask [B, T] (True = padded), or None.
                "queries": Final decoder queries [B, S, hidden].
        """
        B = tokens.shape[0]

        # Project input tokens to the hidden width.
        src = self.input_proj(tokens)

        # Convert the validity mask (True = valid) to a key-padding mask (True = ignore).
        pad_mask = ~mask.bool() if mask is not None else None

        # Optional self-attention pre-encoder over the projected tokens.
        if self.pre_encoder is not None:
            for layer in self.pre_encoder:
                src = layer(src, key_padding_mask=pad_mask)

        # Perceiver latent bottleneck: latents repeatedly read from the tokens.
        latents = self.latent_embed.weight.unsqueeze(0).expand(B, -1, -1)
        for layer in self.latent_layers:
            latents = layer(latents, src, points_key_padding_mask=pad_mask)

        # Segment decoder: queries read from the latents (and optionally the tokens).
        queries = self.query_embed.weight.unsqueeze(0).expand(B, -1, -1)
        for layer in self.decoder_layers:
            queries = layer(queries, latents,
                            src=src if self.decoder_input_xattn else None,
                            src_key_padding_mask=pad_mask if self.decoder_input_xattn else None)

        # Decode per-query segment parameters into endpoint pairs.
        if self.segment_param == "midpoint_dir_len":
            raw = self.segment_head(queries)
            mid = raw[:, :, :3] + self.query_offsets[:, 0, :].unsqueeze(0)
            direction = torch.nn.functional.normalize(raw[:, :, 3:6], dim=-1)
            length = torch.nn.functional.softplus(raw[:, :, 6:7]) * 0.1
            half = direction * length * 0.5
        else:
            raw = self.segment_head(queries).view(B, self.segments, 2, 3)
            raw = raw + self.query_offsets.unsqueeze(0)
            mid, half = raw[:, :, 0], raw[:, :, 1]
        seg_params = torch.stack([mid - half, mid + half], dim=2)

        vertices = seg_params.reshape(B, self.out_vertices, 3)
        edges = [[(2 * i, 2 * i + 1) for i in range(self.segments)] for _ in range(B)]

        out = {"vertices": vertices, "segments": seg_params, "edges": edges,
               "src": src, "pad_mask": pad_mask, "queries": queries}
        if self.segment_conf:
            out["conf"] = self.conf_head(queries).squeeze(-1)
        return out
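
# Output-dict usage sketch (illustrative):
#   out = model(tokens, mask)                                      # model: TokenTransformerSegments
#   p0, p1 = out["segments"][:, :, 0], out["segments"][:, :, 1]    # [B, S, 3] endpoint pairs
#   if model.segment_conf:
#       keep = out["conf"].sigmoid() > 0.5                         # [B, S] per-segment keep mask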
|
|
class SelfAttentionEncoderLayer(nn.Module):
    """Single self-attention layer: self-attn(tokens) -> FFN."""

    def __init__(
        self,
        d_model: int,
        num_heads: int,
        dim_ff: int,
        dropout: float = 0.0,
        activation: str = "gelu",
        kv_heads: int | None = None,
        norm_class=None,
        qk_norm: bool = False,
        qk_norm_type: str = "l2",
    ):
        super().__init__()
        self.self_attn = AttnResidual(d_model, num_heads, dropout, kv_heads=kv_heads, norm_class=norm_class, qk_norm=qk_norm, qk_norm_type=qk_norm_type)
        self.ffn = FFNResidual(d_model, dim_ff, dropout, activation=activation, norm_class=norm_class)

    def forward(self, x: torch.Tensor, key_padding_mask: torch.Tensor | None = None) -> torch.Tensor:
        x = self.self_attn(x, x, memory_key_padding_mask=key_padding_mask)
        x = self.ffn(x)
        return x
|
|
class EdgeDepthSegmentsModel(nn.Module):
    """Tokenizer embeddings + transformer for 3D roof wireframes.

    Builds an EdgeDepthSequenceBuilder tokenizer and a TokenTransformerSegments
    segmenter (Perceiver latent bottleneck, O(L*T) attention). Only
    arch="perceiver" is supported; arch="transformer" (the former standard
    O(T^2) self-attention encoder) has been removed and now raises ValueError.
    The encoder_layers argument is accepted but currently unused.
    """
|
|
    def __init__(
        self,
        seq_cfg,
        segments: int = 32,
        hidden: int = 128,
        num_heads: int = 4,
        kv_heads_cross: int | None = 2,
        kv_heads_self: int | None = 0,
        dim_feedforward: int = 256,
        dropout: float = 0.1,
        latent_tokens: int = 64,
        latent_layers: int = 1,
        decoder_layers: int = 2,
        label_emb_dim: int = 16,
        src_emb_dim: int = 2,
        behind_emb_dim: int = 8,
        fourier_seed: int = 0,
        cross_attn_interval: int = 1,
        norm_class=None,
        activation: str = "gelu",
        segment_conf: bool = False,
        use_vote_features: bool = False,
        arch: str = "perceiver",
        encoder_layers: int = 4,
        pre_encoder_layers: int = 0,
        segment_param: str = "midpoint_halfvec",
        length_floor: float = 0.0,
        decoder_input_xattn: bool = False,
        qk_norm: bool = False,
        qk_norm_type: str = "l2",
        learnable_fourier: bool = False,
    ):
        super().__init__()
        self.seq_cfg = seq_cfg

        from .tokenizer import EdgeDepthSequenceBuilder
        self.tokenizer = EdgeDepthSequenceBuilder(
            seq_cfg,
            label_emb_dim=label_emb_dim,
            src_emb_dim=src_emb_dim,
            behind_emb_dim=behind_emb_dim,
            fourier_seed=fourier_seed,
            use_vote_features=use_vote_features,
            learnable_fourier=learnable_fourier,
        )

        if arch == "transformer":
            raise ValueError(
                "arch='transformer' is no longer supported. "
                "TransformerSegments has been removed; use arch='perceiver'.")
        else:
            self.segmenter = TokenTransformerSegments(
                segments=segments,
                in_dim=self.tokenizer.out_dim,
                hidden=hidden,
                num_heads=num_heads,
                kv_heads_cross=kv_heads_cross,
                kv_heads_self=kv_heads_self,
                dim_feedforward=dim_feedforward,
                dropout=dropout,
                latent_tokens=latent_tokens,
                latent_layers=latent_layers,
                decoder_layers=decoder_layers,
                cross_attn_interval=cross_attn_interval,
                norm_class=norm_class,
                activation=activation,
                segment_conf=segment_conf,
                pre_encoder_layers=pre_encoder_layers,
                segment_param=segment_param,
                length_floor=length_floor,
                decoder_input_xattn=decoder_input_xattn,
                qk_norm=qk_norm, qk_norm_type=qk_norm_type,
            )
|
|
    def forward_tokens(self, tokens: torch.Tensor, mask: torch.Tensor):
        """Run the segmenter on pre-built token tensors."""
        return self.segmenter(tokens, mask)
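
# Minimal shape smoke test (illustrative only). It exercises TokenTransformerSegments
# directly and bypasses the tokenizer, whose seq_cfg schema lives in .tokenizer and is
# not assumed here. Because of the relative imports above, run it as a module, e.g.
# `python -m <package>.<this_module>`.
if __name__ == "__main__":
    _model = TokenTransformerSegments(
        segments=8, in_dim=32, hidden=64, num_heads=4,
        dim_feedforward=128, latent_tokens=16, segment_conf=True,
    )
    _tokens = torch.randn(2, 100, 32)
    _mask = torch.ones(2, 100, dtype=torch.bool)
    _out = _model(_tokens, _mask)
    print(_out["segments"].shape)  # expected: torch.Size([2, 8, 2, 3])
    print(_out["vertices"].shape)  # expected: torch.Size([2, 16, 3])
    print(_out["conf"].shape)      # expected: torch.Size([2, 8])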
|
|