GrimSqueaker committed on
Commit df284ec · verified · 1 Parent(s): 7031eae

Upload modeling_modern_protein.py

Files changed (1):
  1. modeling_modern_protein.py +396 -0
modeling_modern_protein.py ADDED
"""
ModernProteinLM: A next-generation protein encoder combining:
- ModernBERT architectural improvements (RoPE, Pre-LN, GeGLU, FlashAttention-compatible)
- ELECTRA-style discriminative pre-training
- Deep & narrow design optimal for protein sequences
- Curriculum masking (30% -> 5%)
- Span masking for protein structural motifs

Architecture goals (~150M params):
- 28 layers, hidden 576, heads 9, intermediate 2304 (GeGLU)
- RoPE position embeddings (no absolute PE)
- Pre-LayerNorm with extra LN after embedding
- No dropout (following ESM-2)
- Tied input/output embeddings
"""

import math
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel, PretrainedConfig
from transformers.modeling_outputs import MaskedLMOutput, SequenceClassifierOutput, TokenClassifierOutput
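
# Optional dependency: `flash-attn`. It is imported lazily inside
# ModernProteinAttention.forward; when it is unavailable (or inputs are not
# fp16/bf16), the model falls back to standard scaled dot-product attention.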


class ModernProteinLMConfig(PretrainedConfig):
    model_type = "modern_protein_lm"

    def __init__(
        self,
        vocab_size=33,
        hidden_size=576,
        num_hidden_layers=28,
        num_attention_heads=9,
        intermediate_size=2304,
        hidden_act="gelu",
        hidden_dropout_prob=0.0,
        attention_probs_dropout_prob=0.0,
        max_position_embeddings=1026,
        initializer_range=0.02,
        layer_norm_eps=1e-12,
        position_embedding_type="rotary",
        rope_theta=10000.0,
        use_geglu=True,
        tie_word_embeddings=True,
        pad_token_id=1,
        mask_token_id=32,
        cls_token_id=0,
        eos_token_id=2,
        **kwargs,
    ):
        super().__init__(
            pad_token_id=pad_token_id,
            mask_token_id=mask_token_id,
            cls_token_id=cls_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps
        self.position_embedding_type = position_embedding_type
        self.rope_theta = rope_theta
        self.use_geglu = use_geglu


class RotaryEmbedding(nn.Module):
    """RoPE (Rotary Position Embedding) for protein sequences."""

    def __init__(self, dim, max_seq_len=1026, base=10000.0, device=None):
        super().__init__()
        self.dim = dim
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.max_seq_len = max_seq_len

    def forward(self, seq_len, device):
        t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
        freqs = torch.outer(t, self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        return emb.cos().to(torch.float32), emb.sin().to(torch.float32)


def rotate_half(x):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin):
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
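
# Note: rotate_half splits the head dimension into two contiguous halves
# (GPT-NeoX-style RoPE) rather than interleaving even/odd pairs as in the
# original RoFormer formulation. The two layouts are equivalent up to a fixed
# permutation of feature dimensions, but checkpoints are not interchangeable
# between them.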


class ModernProteinAttention(nn.Module):
    """Multi-head attention with RoPE and optional FlashAttention."""

    def __init__(self, config: ModernProteinLMConfig):
        super().__init__()
        self.num_heads = config.num_attention_heads
        self.head_dim = config.hidden_size // config.num_attention_heads
        self.scale = self.head_dim ** -0.5

        self.qkv_proj = nn.Linear(config.hidden_size, 3 * config.hidden_size, bias=False)
        self.out_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)

        self.rotary_emb = RotaryEmbedding(self.head_dim, max_seq_len=config.max_position_embeddings, base=config.rope_theta)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob) if config.attention_probs_dropout_prob > 0 else None

    def forward(self, hidden_states, attention_mask=None, output_attentions=False):
        batch_size, seq_len, _ = hidden_states.shape

        qkv = self.qkv_proj(hidden_states)
        qkv = qkv.reshape(batch_size, seq_len, 3, self.num_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)  # (3, B, H, T, D)
        q, k, v = qkv[0], qkv[1], qkv[2]

        # Apply RoPE; cast cos/sin to the activation dtype so q/k keep their
        # (possibly half-precision) dtype for the FlashAttention path below.
        cos, sin = self.rotary_emb(seq_len, device=hidden_states.device)
        cos = cos[None, None, :, :].to(q.dtype)  # (1, 1, T, D)
        sin = sin[None, None, :, :].to(q.dtype)
        q, k = apply_rotary_pos_emb(q, k, cos, sin)

        # Attention probabilities are only materialized on the fallback path;
        # FlashAttention cannot return them.
        attn_probs = None
        attn_output = None

        # Try FlashAttention if available. It expects (B, T, H, D) tensors and
        # fp16/bf16 inputs, and does not accept an additive attention mask here.
        if attention_mask is None and not output_attentions and q.dtype in (torch.float16, torch.bfloat16):
            try:
                from flash_attn import flash_attn_func
                attn_output = flash_attn_func(
                    q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2),
                    dropout_p=self.dropout.p if (self.dropout is not None and self.training) else 0.0,
                    causal=False,
                )
                attn_output = attn_output.transpose(1, 2)  # back to (B, H, T, D)
            except ImportError:
                attn_output = None

        if attn_output is None:
            # Standard scaled dot-product attention
            attn_scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale

            if attention_mask is not None:
                attn_scores = attn_scores + attention_mask

            attn_probs = F.softmax(attn_scores, dim=-1, dtype=torch.float32).to(q.dtype)
            if self.dropout is not None:
                attn_probs = self.dropout(attn_probs)

            attn_output = torch.matmul(attn_probs, v)

        attn_output = attn_output.transpose(1, 2).reshape(batch_size, seq_len, -1)
        attn_output = self.out_proj(attn_output)

        if output_attentions:
            return attn_output, attn_probs
        return attn_output, None


class GeGLU(nn.Module):
    """GeGLU activation: GELU(gate) * value. More expressive than GELU alone."""

    def __init__(self, config: ModernProteinLMConfig):
        super().__init__()
        self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
        self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
        self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
        self.act = nn.GELU()

    def forward(self, x):
        gate = self.act(self.gate_proj(x))
        up = self.up_proj(x)
        return self.down_proj(gate * up)


class ModernProteinMLP(nn.Module):
    def __init__(self, config: ModernProteinLMConfig):
        super().__init__()
        if config.use_geglu:
            self.mlp = GeGLU(config)
        else:
            self.mlp = nn.Sequential(
                nn.Linear(config.hidden_size, config.intermediate_size, bias=False),
                nn.GELU(),
                nn.Linear(config.intermediate_size, config.hidden_size, bias=False),
            )

    def forward(self, x):
        return self.mlp(x)


class ModernProteinLayer(nn.Module):
    """Pre-LN transformer layer."""

    def __init__(self, config: ModernProteinLMConfig):
        super().__init__()
        self.ln1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.attn = ModernProteinAttention(config)
        self.ln2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.mlp = ModernProteinMLP(config)

    def forward(self, hidden_states, attention_mask=None, output_attentions=False):
        # Pre-LN: LN -> Attn -> Residual
        attn_out, attn_weights = self.attn(self.ln1(hidden_states), attention_mask, output_attentions)
        hidden_states = hidden_states + attn_out

        # Pre-LN: LN -> MLP -> Residual
        mlp_out = self.mlp(self.ln2(hidden_states))
        hidden_states = hidden_states + mlp_out

        return hidden_states, attn_weights


class ModernProteinLM(PreTrainedModel):
    config_class = ModernProteinLMConfig

    def __init__(self, config: ModernProteinLMConfig):
        super().__init__(config)
        self.config = config

        self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
        self.embed_ln = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

        self.layers = nn.ModuleList([
            ModernProteinLayer(config) for _ in range(config.num_hidden_layers)
        ])

        self.final_ln = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights (post_init applies _init_weights to each module)
        self.post_init()

        # Tie input/output embeddings if requested
        if config.tie_word_embeddings:
            self.lm_head.weight = self.embeddings.weight

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
        elif isinstance(module, nn.LayerNorm):
            nn.init.ones_(module.weight)
            nn.init.zeros_(module.bias)

    def get_input_embeddings(self):
        return self.embeddings

    def set_input_embeddings(self, value):
        self.embeddings = value

    def forward(
        self,
        input_ids,
        attention_mask=None,
        position_ids=None,
        labels=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
    ):
        batch_size, seq_len = input_ids.shape

        # Embedding (positions are injected via RoPE inside attention)
        hidden_states = self.embeddings(input_ids)
        hidden_states = self.embed_ln(hidden_states)

        # Additive attention mask for padding: (B, T) -> (B, 1, 1, T)
        if attention_mask is not None:
            attention_mask = (1.0 - attention_mask[:, None, None, :].to(hidden_states.dtype)) * -10000.0

        all_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        # Transformer layers
        for layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            hidden_states, attn_weights = layer(hidden_states, attention_mask, output_attentions)

            if output_attentions:
                all_attentions += (attn_weights,)

        hidden_states = self.final_ln(hidden_states)

        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        # LM head
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))

        if not return_dict:
            output = (logits,)
            if output_hidden_states:
                output += (all_hidden_states,)
            if output_attentions:
                output += (all_attentions,)
            return ((loss,) + output) if loss is not None else output

        return MaskedLMOutput(
            loss=loss,
            logits=logits,
            hidden_states=all_hidden_states,
            attentions=all_attentions,
        )

    def get_sequence_embedding(self, input_ids, attention_mask=None):
        """Mean-pooled (over non-padding positions) or CLS embedding for downstream tasks."""
        outputs = self.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
            return_dict=True,
        )
        hidden = outputs.hidden_states[-1]

        if attention_mask is not None:
            # Mean pool over non-padded positions
            mask_expanded = attention_mask.unsqueeze(-1).float()
            sum_hidden = (hidden * mask_expanded).sum(dim=1)
            pooled = sum_hidden / mask_expanded.sum(dim=1).clamp(min=1e-9)
        else:
            pooled = hidden[:, 0]  # CLS token

        return pooled


class ModernProteinLMForMaskedLM(ModernProteinLM):
    """Masked Language Model wrapper; the base model already exposes the LM head and MLM loss."""


class ModernProteinLMForSequenceClassification(PreTrainedModel):
    config_class = ModernProteinLMConfig

    def __init__(self, config: ModernProteinLMConfig):
        super().__init__(config)
        self.modern_protein = ModernProteinLM(config)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

    def forward(self, input_ids, attention_mask=None, labels=None):
        pooled = self.modern_protein.get_sequence_embedding(input_ids, attention_mask)
        logits = self.classifier(pooled)

        loss = None
        if labels is not None:
            if self.config.num_labels == 1:
                loss_fct = nn.MSELoss()
                loss = loss_fct(logits.squeeze(-1), labels.float())
            else:
                loss_fct = nn.CrossEntropyLoss()
                loss = loss_fct(logits, labels)

        return SequenceClassifierOutput(loss=loss, logits=logits)


class ModernProteinLMForTokenClassification(PreTrainedModel):
    config_class = ModernProteinLMConfig

    def __init__(self, config: ModernProteinLMConfig):
        super().__init__(config)
        self.modern_protein = ModernProteinLM(config)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

    def forward(self, input_ids, attention_mask=None, labels=None):
        # output_hidden_states=True is required: per-residue features are read
        # from hidden_states[-1] (the final LayerNorm output).
        outputs = self.modern_protein(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
            return_dict=True,
        )
        logits = self.classifier(outputs.hidden_states[-1])

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))

        return TokenClassifierOutput(loss=loss, logits=logits)
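

# Minimal usage sketch (an illustration, not part of the uploaded training
# code). It assumes the ESM-style 33-token vocabulary configured above; the
# ids drawn below are arbitrary stand-ins for amino-acid tokens, not the
# output of a real tokenizer.
if __name__ == "__main__":
    config = ModernProteinLMConfig()
    model = ModernProteinLM(config)
    model.eval()

    input_ids = torch.randint(4, 32, (2, 64))    # fake residue tokens, avoiding special ids
    attention_mask = torch.ones_like(input_ids)

    # MLM-style labels: predict only the ~15% of positions that get masked
    labels = torch.full_like(input_ids, -100)    # -100 = ignored by CrossEntropyLoss
    mask_pos = torch.rand(input_ids.shape) < 0.15
    labels[mask_pos] = input_ids[mask_pos]
    input_ids = input_ids.masked_fill(mask_pos, config.mask_token_id)

    with torch.no_grad():
        out = model(input_ids, attention_mask=attention_mask, labels=labels)
        emb = model.get_sequence_embedding(input_ids, attention_mask)

    print(out.loss.item())   # scalar MLM loss
    print(out.logits.shape)  # (2, 64, 33)
    print(emb.shape)         # (2, 576)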