# custom_transformer.py
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

# =============================================================================
# Core Efficient Multihead Attention using Scaled Dot Product Attention (SDPA)
# =============================================================================

class MultiHeadSDPA(nn.Module):
    """
    Multi-head cross-attention using torch.nn.functional.scaled_dot_product_attention
    without causal masking. Suitable for set inputs and cross-attention.

    If qk_norm=True, L2-normalizes Q and K per-head before the dot product,
    then scales by a learned per-head temperature (log_scale). This caps logit
    magnitude to [-1, +1] * exp(log_scale), preventing attention entropy
    collapse at large head_dim.
    """
    def __init__(self, d_model: int, num_heads: int, kv_heads: int | None = None,
                 qk_norm: bool = False, qk_norm_type: str = "l2"):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.d_model = d_model
        self.num_heads = num_heads
        self.kv_heads = kv_heads or num_heads
        assert self.num_heads % self.kv_heads == 0, "kv_heads must divide num_heads"

        self.head_dim = d_model // num_heads
        self.qk_norm = qk_norm
        self.qk_norm_type = qk_norm_type

        # Input projection layers
        self.q_proj = nn.Linear(d_model, d_model, bias=False)
        self.k_proj = nn.Linear(d_model, self.kv_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(d_model, self.kv_heads * self.head_dim, bias=False)

        # Output projection, zero-initialized below so the attention block outputs
        # zeros at initialization (a common trick for stability in a residual stream)
        self.out_proj = nn.Linear(d_model, d_model, bias=False)
        nn.init.zeros_(self.out_proj.weight)

        if qk_norm:
            if qk_norm_type == "rms":
                # Standard QK-norm (Qwen3/Gemma3 style): RMSNorm on Q and K,
                # no learned temperature. SDPA's 1/sqrt(d) scaling is sufficient
                # because RMSNorm preserves the expected logit variance.
                pass  # no extra parameters needed
            else:
                # L2 + learned temperature (nGPT/ViT-22B style):
                # L2 projects to unit sphere, needs learned scale to compensate.
                self.log_scale = nn.Parameter(
                    torch.full((num_heads,), math.log(math.sqrt(self.head_dim))))

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        key_padding_mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        # Project
        q = self.q_proj(query)
        k = self.k_proj(key)
        v = self.v_proj(key)

        B, Tq, _ = q.shape
        _, Tk, _ = k.shape

        q = q.view(B, Tq, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(B, Tk, self.kv_heads, self.head_dim).transpose(1, 2)
        v = v.view(B, Tk, self.kv_heads, self.head_dim).transpose(1, 2)

        # Grouped-query attention: K/V are projected with fewer heads than Q, so
        # repeat each K/V head across its group of query heads before SDPA.
        if self.kv_heads != self.num_heads:
            repeat = self.num_heads // self.kv_heads
            k = k.repeat_interleave(repeat, dim=1)
            v = v.repeat_interleave(repeat, dim=1)

        # key_padding_mask convention: True marks padded key positions. SDPA's
        # boolean mask is True where attention IS allowed, hence the inversion.
        attn_mask = None
        if key_padding_mask is not None:
            attn_mask = ~key_padding_mask[:, None, None, :].to(dtype=torch.bool)

        sdpa_scale = None  # None -> SDPA's default 1/sqrt(head_dim) scaling
        if self.qk_norm:
            if self.qk_norm_type == "rms":
                # RMSNorm (Qwen3/Gemma3 style): no learned temperature needed.
                # After RMSNorm, logit variance matches standard SDPA naturally.
                q = q * torch.rsqrt(q.square().mean(dim=-1, keepdim=True) + 1e-6)
                k = k * torch.rsqrt(k.square().mean(dim=-1, keepdim=True) + 1e-6)
            else:
                # L2 + learned temperature (nGPT/ViT-22B style): Q and K become
                # unit vectors, so the learned per-head scale replaces 1/sqrt(d).
                q = F.normalize(q, dim=-1)
                k = F.normalize(k, dim=-1)
                q = q * self.log_scale.exp().view(1, -1, 1, 1)
                sdpa_scale = 1.0

        attn_out = F.scaled_dot_product_attention(
            q, k, v, attn_mask=attn_mask, dropout_p=0.0, is_causal=False,
            scale=sdpa_scale,
        )

        attn_out = attn_out.transpose(1, 2).reshape(B, Tq, self.d_model)
        return self.out_proj(attn_out)
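

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the module API; the helper name is
# hypothetical): with qk_norm_type="l2", each row of Q and K is a unit vector,
# so every pre-softmax logit is a cosine in [-1, +1]. Scaling Q by
# exp(log_scale) therefore bounds |logit| by exp(log_scale), independent of
# head_dim, which is the "capped logit magnitude" claim in the docstring above.
# ---------------------------------------------------------------------------
def _sketch_l2_logit_cap(head_dim: int = 64, log_scale: float = 2.0) -> None:
    q = F.normalize(torch.randn(8, head_dim), dim=-1) * math.exp(log_scale)
    k = F.normalize(torch.randn(8, head_dim), dim=-1)
    logits = q @ k.T  # what SDPA computes internally when scale=1.0
    assert logits.abs().max().item() <= math.exp(log_scale) + 1e-5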


# =============================================================================
# Transformer Feed-Forward Block
# =============================================================================

def _get_activation(name: str):
    """Look up activation function by name. Supports 'relu_sq' for ReLU^2."""
    if name == "relu_sq":
        return lambda x: F.relu(x).square()
    return getattr(F, name)


class FeedForward(nn.Module):
    """
    Position-wise MLP block: linear -> activation -> linear.
    Supports 'gelu', 'relu', 'relu_sq', etc.
    """
    def __init__(self, d_model: int, dim_ff: int, activation: str = "gelu"):
        super().__init__()
        self.linear1 = nn.Linear(d_model, dim_ff)
        self.linear2 = nn.Linear(dim_ff, d_model)
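        # linear2 is zero-initialized below, so the MLP outputs zeros at init
        # (a common trick so the block starts as an identity in a residual stream).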
        nn.init.zeros_(self.linear2.weight)
        nn.init.zeros_(self.linear2.bias)
        self.activation = _get_activation(activation)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.linear1(x)
        return self.linear2(self.activation(x))
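

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only): the shapes and hyperparameters are
# arbitrary choices for demonstration, not values prescribed by this module.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)
    attn = MultiHeadSDPA(d_model=256, num_heads=8, kv_heads=2, qk_norm=True)
    ffn = FeedForward(d_model=256, dim_ff=1024, activation="relu_sq")

    queries = torch.randn(2, 16, 256)               # (batch, query tokens, d_model)
    keys = torch.randn(2, 32, 256)                  # (batch, key/value tokens, d_model)
    padding = torch.zeros(2, 32, dtype=torch.bool)
    padding[:, 28:] = True                          # last 4 key positions are padding

    x = queries + attn(queries, keys, key_padding_mask=padding)  # residual; out_proj is zero-init
    x = x + ffn(x)                                               # residual around the MLP
    print(x.shape)                                               # torch.Size([2, 16, 256])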