Jonttup committed on
Commit
e27c6bd
·
verified ·
1 Parent(s): 3ff7322

Upload models/vit.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. models/vit.py +386 -0
models/vit.py ADDED
@@ -0,0 +1,386 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Vision Transformer (ViT) for Palette Feature Extraction
3
+
4
+ Implements a standard ViT with Samsung TRM best practices:
5
+ - RMS Normalization
6
+ - SwiGLU activation
7
+ - Truncated normal initialization
8
+ - Spatial feature preservation
9
+ """
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+ import math
15
+ from typing import Tuple
16
+
17
+
18
+ # ============================================================================
19
+ # Helper functions (local copies)
20
+ # NOTE: These are intentionally local copies, NOT imported from transformer_layers.py.
21
+ # transformer_layers.py uses different parameter names (variance_epsilon vs eps,
22
+ # lower/upper vs a/b), CastedLinear instead of nn.Linear, and different SwiGLU
23
+ # expansion defaults. Callers here rely on the local signatures.
24
+ # ============================================================================
25
+
26
def rms_norm(hidden_states: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    """
    Root-mean-square normalization over the last dimension.

    The computation is carried out in float32 for numerical stability and the
    result is cast back to the caller's dtype.

    Args:
        hidden_states: Input tensor of any shape; normalized along dim -1.
        eps: Small constant added to the mean square before the inverse sqrt.

    Returns:
        Tensor with the same shape and dtype as the input.
    """
    original_dtype = hidden_states.dtype
    x = hidden_states.to(torch.float32)
    mean_square = x.pow(2).mean(dim=-1, keepdim=True)
    normalized = x * torch.rsqrt(mean_square + eps)
    return normalized.to(original_dtype)
44
+
45
+
46
def trunc_normal_init_(tensor: torch.Tensor, std: float = 1.0, a: float = -2, b: float = 2):
    """
    In-place truncated normal initialization.

    Fills `tensor` with samples from N(0, std^2) restricted to the interval
    [a*std, b*std]. The previous implementation clamped an unbounded normal
    sample, which is NOT a truncated normal: clamping concentrates probability
    mass exactly at the two bounds. `torch.nn.init.trunc_normal_` draws from
    the truncated distribution properly (inverse-CDF sampling).

    Args:
        tensor: Tensor to initialize (modified in place).
        std: Standard deviation of the underlying normal.
        a: Lower truncation bound, in units of std.
        b: Upper truncation bound, in units of std.

    Returns:
        The same tensor, for call chaining.
    """
    with torch.no_grad():
        nn.init.trunc_normal_(tensor, mean=0.0, std=std, a=a * std, b=b * std)
    return tensor
63
+
64
+
65
+ # ============================================================================
66
+ # SwiGLU Activation
67
+ # ============================================================================
68
+
69
class SwiGLU(nn.Module):
    """
    SwiGLU feed-forward block: down(silu(gate) * up).

    A gated linear unit with the SiLU (Swish) nonlinearity, as used in modern
    LLM stacks (LLaMA, PaLM, etc.); more expressive than a plain ReLU MLP.
    """

    def __init__(self, hidden_size: int, expansion: float = 2.0):
        super().__init__()

        # Scale the expansion by 2/3 (the gate doubles the parameter count of
        # the up-projection), then round the intermediate width up to a
        # multiple of 256 so the matmuls tile efficiently on hardware.
        intermediate = int(expansion * hidden_size * 2 / 3)
        intermediate = -(-intermediate // 256) * 256  # ceil to multiple of 256

        # Fused gate+up projection and the down projection, both bias-free.
        self.gate_up_proj = nn.Linear(hidden_size, intermediate * 2, bias=False)
        self.down_proj = nn.Linear(intermediate, hidden_size, bias=False)

    def forward(self, x):
        projected = self.gate_up_proj(x)
        gate, up = projected.chunk(2, dim=-1)
        return self.down_proj(F.silu(gate) * up)
90
+
91
+
92
+ # ============================================================================
93
+ # Multi-Head Self-Attention
94
+ # ============================================================================
95
+
96
class MultiHeadSelfAttention(nn.Module):
    """Standard multi-head self-attention over a (B, N, D) token sequence."""

    def __init__(self, hidden_dim: int, num_heads: int = 8, dropout: float = 0.1, rms_eps: float = 1e-5):
        super().__init__()
        assert hidden_dim % num_heads == 0, "hidden_dim must be divisible by num_heads"

        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.head_dim = hidden_dim // num_heads
        # Stored for interface parity; normalization itself happens in the caller.
        self.rms_eps = rms_eps

        # Fused Q/K/V projection plus the output projection, both bias-free.
        self.qkv_proj = nn.Linear(hidden_dim, hidden_dim * 3, bias=False)
        self.out_proj = nn.Linear(hidden_dim, hidden_dim, bias=False)

        self.dropout = nn.Dropout(dropout)
        self.scale = self.head_dim ** -0.5

        self._init_weights()

    def _init_weights(self):
        """Truncated-normal init with std = 1/sqrt(fan_in) for both projections."""
        for proj in (self.qkv_proj, self.out_proj):
            trunc_normal_init_(proj.weight, std=1.0 / math.sqrt(proj.in_features))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (B, N, D) input sequence.

        Returns:
            (B, N, D) attended sequence.
        """
        batch, seq_len, dim = x.shape

        # One matmul yields Q, K and V; then split into heads: each (B, H, N, d).
        qkv = self.qkv_proj(x).reshape(batch, seq_len, 3, self.num_heads, self.head_dim)
        queries, keys, values = qkv.permute(2, 0, 3, 1, 4).unbind(0)

        # Scaled dot-product attention with dropout applied to the weights.
        logits = queries @ keys.transpose(-2, -1) * self.scale
        weights = self.dropout(F.softmax(logits, dim=-1))
        context = weights @ values  # (B, H, N, d)

        # Re-merge heads and project back to model width.
        context = context.transpose(1, 2).contiguous().view(batch, seq_len, dim)
        return self.out_proj(context)
152
+
153
+
154
+ # ============================================================================
155
+ # Transformer Block
156
+ # ============================================================================
157
+
158
class TransformerBlock(nn.Module):
    """
    Pre-norm transformer block: RMS-normed multi-head self-attention and an
    RMS-normed SwiGLU feed-forward network, each wrapped in a residual
    connection with dropout.
    """

    def __init__(
        self,
        hidden_dim: int,
        num_heads: int = 8,
        dropout: float = 0.1,
        swiglu_expansion: float = 2.0,
        rms_eps: float = 1e-5
    ):
        super().__init__()

        self.hidden_dim = hidden_dim
        self.rms_eps = rms_eps

        # Token-mixing and channel-mixing sublayers.
        self.attention = MultiHeadSelfAttention(hidden_dim, num_heads, dropout, rms_eps)
        self.ffn = SwiGLU(hidden_dim, swiglu_expansion)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (B, N, D) input sequence.

        Returns:
            (B, N, D) output sequence.
        """
        # Residual around (RMS norm -> attention -> dropout).
        x = x + self.dropout(self.attention(rms_norm(x, eps=self.rms_eps)))
        # Residual around (RMS norm -> SwiGLU FFN -> dropout).
        x = x + self.dropout(self.ffn(rms_norm(x, eps=self.rms_eps)))
        return x
203
+
204
+
205
+ # ============================================================================
206
+ # Vision Transformer
207
+ # ============================================================================
208
+
209
class VisionTransformer(nn.Module):
    """
    Vision Transformer for palette feature extraction.

    Takes embedded palettes (B, H, W, D) and returns spatial features of the
    same shape:

        1. Patchify with a strided Conv2d (H, W -> H/P, W/P).
        2. Run the patch sequence through pre-norm transformer blocks.
        3. Unpatchify with a ConvTranspose2d back to (H, W).

    Best practices from Samsung TRM:
        - RMS normalization
        - SwiGLU activation
        - Truncated normal initialization

    Note: forward() now raises ValueError when H or W is not divisible by
    patch_size. The conv/deconv pair only round-trips exact multiples; the
    previous code silently returned a smaller feature map in that case,
    violating the documented (B, H, W, D) contract.
    """

    def __init__(
        self,
        hidden_dim: int = 768,
        num_layers: int = 6,
        num_heads: int = 8,
        patch_size: int = 4,
        dropout: float = 0.1,
        rms_eps: float = 1e-5
    ):
        super().__init__()

        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.patch_size = patch_size
        self.rms_eps = rms_eps

        # Patch embedding: non-overlapping patches (stride == kernel size).
        self.patch_embed = nn.Conv2d(
            hidden_dim, hidden_dim,
            kernel_size=patch_size,
            stride=patch_size,
            bias=False
        )

        # Transformer trunk.
        self.blocks = nn.ModuleList([
            TransformerBlock(hidden_dim, num_heads, dropout, rms_eps=rms_eps)
            for _ in range(num_layers)
        ])

        # Inverse of patch_embed; restores the original resolution exactly
        # when H and W are multiples of patch_size.
        self.unpatch = nn.ConvTranspose2d(
            hidden_dim, hidden_dim,
            kernel_size=patch_size,
            stride=patch_size,
            bias=False
        )

        # Parameter-free final RMS normalization.
        # NOTE(review): a lambda attribute is not an nn.Module, so it is
        # invisible to state_dict()/.to() — acceptable only because it holds
        # no parameters.
        self.final_norm = lambda x: rms_norm(x, eps=rms_eps)

        self._init_weights()

    def _init_weights(self):
        """Truncated-normal init (std = 1/sqrt(input width)) for all projections."""
        for module in self.modules():
            if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)):
                # weight.shape[1] is in_features for Linear and in_channels for
                # conv layers; the kernel area is deliberately ignored to match
                # the 1/sqrt(width) convention used elsewhere in this file.
                std = 1.0 / math.sqrt(module.weight.shape[1] if len(module.weight.shape) > 1 else module.weight.shape[0])
                trunc_normal_init_(module.weight, std=std)
                if module.bias is not None:
                    module.bias.data.zero_()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Extract spatial features from embedded palettes.

        Args:
            x: (B, H, W, D) embedded palette; H and W must be divisible by
                patch_size.

        Returns:
            (B, H, W, D) RMS-normalized spatial features.

        Raises:
            ValueError: if H or W is not a multiple of patch_size.
        """
        B, H, W, D = x.shape

        # Fail loudly instead of silently producing a smaller feature map:
        # ConvTranspose2d only inverts the patchify conv for exact multiples.
        if H % self.patch_size or W % self.patch_size:
            raise ValueError(
                f"Spatial dims ({H}, {W}) must be divisible by patch_size={self.patch_size}"
            )

        # Channels-last -> channels-first for Conv2d: (B, H, W, D) -> (B, D, H, W).
        x = x.permute(0, 3, 1, 2)

        # 1. Patchify: (B, D, H, W) -> (B, D, H/P, W/P).
        x_patches = self.patch_embed(x)
        B, D, H_p, W_p = x_patches.shape

        # 2. Flatten the patch grid into a sequence: (B, N, D), N = H_p * W_p.
        x_seq = x_patches.flatten(2).transpose(1, 2)

        # 3. Transformer trunk.
        for block in self.blocks:
            x_seq = block(x_seq)

        # 4. Back to a patch grid: (B, N, D) -> (B, D, H_p, W_p).
        x_patches = x_seq.transpose(1, 2).reshape(B, D, H_p, W_p)

        # 5. Unpatchify: (B, D, H_p, W_p) -> (B, D, H, W).
        x_out = self.unpatch(x_patches)

        # 6. Normalize along the feature dimension, channels-last.
        return self.final_norm(x_out.permute(0, 2, 3, 1))
318
+
319
+
320
+ # ============================================================================
321
+ # Palette Embedding + ViT Pipeline
322
+ # ============================================================================
323
+
324
class PaletteFeatureExtractor(nn.Module):
    """
    Full pipeline: palette indices -> token embedding -> ViT features.

    Combines a token embedding (discrete palette indices to continuous
    vectors) with a Vision Transformer for spatial feature extraction.

    Input:  (B, H, W) LongTensor of palette indices.
    Output: (B, H, W, D) FloatTensor of spatial features.
    """

    def __init__(
        self,
        palette_size: int = 4096,
        hidden_dim: int = 768,
        num_layers: int = 6,
        num_heads: int = 8,
        patch_size: int = 4,
        dropout: float = 0.1
    ):
        super().__init__()

        self.palette_size = palette_size
        self.hidden_dim = hidden_dim

        # Discrete palette index -> continuous D-dimensional vector.
        self.palette_embed = nn.Embedding(palette_size, hidden_dim)

        # Spatial feature extractor over the embedded grid.
        self.vit = VisionTransformer(
            hidden_dim=hidden_dim,
            num_layers=num_layers,
            num_heads=num_heads,
            patch_size=patch_size,
            dropout=dropout
        )

        self._init_embeddings()

    def _init_embeddings(self):
        """Truncated-normal init for the embedding table (std = 1/sqrt(D))."""
        trunc_normal_init_(self.palette_embed.weight, std=1.0 / math.sqrt(self.hidden_dim))

    def forward(self, palette: torch.Tensor) -> torch.Tensor:
        """
        Extract features from a palette.

        Args:
            palette: (B, H, W) LongTensor of palette indices.

        Returns:
            (B, H, W, D) FloatTensor of features.
        """
        embedded = self.palette_embed(palette)  # (B, H, W, D)
        return self.vit(embedded)               # (B, H, W, D)