dapatil211 commited on
Commit
8ba58fb
·
1 Parent(s): 021848e

Re-add modeling_amplify after history purge

Browse files
Files changed (1) hide show
  1. modeling_amplify.py +291 -0
modeling_amplify.py ADDED
@@ -0,0 +1,291 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Standalone AMPLIFY model for HuggingFace Hub (trust_remote_code=True).
2
+
3
+ This is a self-contained file that can be shipped in a HuggingFace repo so that
4
+ ``AutoModel.from_pretrained(..., trust_remote_code=True)`` works without
5
+ installing the ``amplify`` package.
6
+
7
+ Based on: https://github.com/chandar-lab/AMPLIFY
8
+ """
9
+
10
+ from typing import Tuple
11
+
12
+ import torch
13
+ from torch import nn
14
+ from torch.nn.functional import scaled_dot_product_attention
15
+ from transformers import PreTrainedModel, PretrainedConfig
16
+ from transformers.modeling_outputs import MaskedLMOutput
17
+
18
+ # Optional: flash attention for packed-sequence training. Not required for
19
+ # standard inference.
20
+ try:
21
+ from flash_attn.flash_attn_interface import flash_attn_varlen_func # type: ignore
22
+ except ImportError:
23
+ flash_attn_varlen_func = None
24
+
25
+
26
+ # ---------------------------------------------------------------------------
27
+ # Rotary positional embeddings (inlined from amplify.model.rotary)
28
+ # ---------------------------------------------------------------------------
29
+
30
+ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
31
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
32
+ t = torch.arange(end, device=freqs.device, dtype=torch.float32)
33
+ freqs = torch.outer(t, freqs)
34
+ freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
35
+ return freqs_cis
36
+
37
+
38
+ def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
39
+ assert freqs_cis.shape == (x.shape[0], x.shape[1], x.shape[-1])
40
+ return freqs_cis.contiguous().unsqueeze(2)
41
+
42
+
43
+ def apply_rotary_emb(
44
+ xq: torch.Tensor,
45
+ xk: torch.Tensor,
46
+ freqs_cis: torch.Tensor,
47
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
48
+ xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
49
+ xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
50
+ freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
51
+ xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
52
+ xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
53
+ return xq_out.type_as(xq), xk_out.type_as(xk)
54
+
55
+
56
+ # ---------------------------------------------------------------------------
57
+ # Config
58
+ # ---------------------------------------------------------------------------
59
+
60
+ class AMPLIFYConfig(PretrainedConfig):
61
+ model_type = "AMPLIFY"
62
+
63
+ def __init__(
64
+ self,
65
+ hidden_size: int = 960,
66
+ num_hidden_layers: int = 32,
67
+ num_attention_heads: int = 15,
68
+ intermediate_size: int = 3840,
69
+ embedding_init_range: float = 0.02,
70
+ decoder_init_range: float = 0.02,
71
+ norm_eps: float = 1e-05,
72
+ vocab_size: int = 32,
73
+ pad_token_id: int = 0,
74
+ max_length: int = 2048,
75
+ max_protein_length: int = 50000,
76
+ **kwargs,
77
+ ):
78
+ super().__init__(**kwargs)
79
+ self.hidden_size = hidden_size
80
+ self.num_hidden_layers = num_hidden_layers
81
+ self.num_attention_heads = num_attention_heads
82
+ self.intermediate_size = intermediate_size
83
+ self.embedding_init_range = embedding_init_range
84
+ self.decoder_init_range = decoder_init_range
85
+ self.norm_eps = norm_eps
86
+ self.vocab_size = vocab_size
87
+ self.pad_token_id = pad_token_id
88
+ self.max_length = max_length
89
+ self.max_protein_length = max_protein_length
90
+
91
+
92
+ # ---------------------------------------------------------------------------
93
+ # Encoder blocks
94
+ # ---------------------------------------------------------------------------
95
+
96
+ class EncoderBlock(nn.Module):
97
+ """Standard transformer encoder block with SwiGLU FFN and RoPE."""
98
+
99
+ def __init__(self, config: AMPLIFYConfig):
100
+ super().__init__()
101
+ self.config = config
102
+ self.d_head = config.hidden_size // config.num_attention_heads
103
+
104
+ # Attention
105
+ self.qkv = nn.Linear(config.hidden_size, config.hidden_size * 3, bias=False)
106
+ self.wo = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
107
+
108
+ # SwiGLU FFN
109
+ multiple_of = 8
110
+ intermediate_size = multiple_of * (
111
+ (int(2 * config.intermediate_size / 3) + multiple_of - 1) // multiple_of
112
+ )
113
+ self.c_fc = nn.Linear(config.hidden_size, 2 * intermediate_size, bias=False)
114
+ self.silu = nn.SiLU()
115
+ self.mlp_c_proj = nn.Linear(intermediate_size, config.hidden_size, bias=False)
116
+
117
+ self.attention_norm = nn.RMSNorm(config.hidden_size, config.norm_eps)
118
+ self.ffn_norm = nn.RMSNorm(config.hidden_size, config.norm_eps)
119
+
120
+ def forward(
121
+ self,
122
+ x: torch.Tensor,
123
+ attention_mask: torch.Tensor,
124
+ freqs_cis: torch.Tensor,
125
+ output_attentions: bool,
126
+ max_seqlen: int = None,
127
+ cu_seqlens: torch.Tensor = None,
128
+ ):
129
+ batch_size, seq_len, _ = x.shape
130
+
131
+ xq, xk, xv = (
132
+ self.qkv(self.attention_norm(x))
133
+ .reshape(batch_size, seq_len, self.config.num_attention_heads, self.d_head * 3)
134
+ .chunk(3, axis=-1)
135
+ )
136
+ xq, xk = apply_rotary_emb(xq, xk, freqs_cis)
137
+
138
+ attn_weights = None
139
+
140
+ if cu_seqlens is not None:
141
+ assert flash_attn_varlen_func is not None, (
142
+ "flash_attn is required for packed-sequence attention. "
143
+ "Install with: pip install flash-attn"
144
+ )
145
+ attn = flash_attn_varlen_func(
146
+ q=xq.squeeze(0),
147
+ k=xk.squeeze(0),
148
+ v=xv.squeeze(0),
149
+ cu_seqlens_q=cu_seqlens.squeeze(),
150
+ cu_seqlens_k=cu_seqlens.squeeze(),
151
+ max_seqlen_q=max_seqlen,
152
+ max_seqlen_k=max_seqlen,
153
+ dropout_p=0.0,
154
+ causal=False,
155
+ )
156
+ elif output_attentions:
157
+ attn_weights = xq.permute(0, 2, 1, 3) @ xk.permute(0, 2, 3, 1) / (xq.size(-1) ** 0.5)
158
+ if attention_mask is not None:
159
+ attn_weights = attn_weights * attention_mask
160
+ attn_weights = attn_weights.softmax(-1)
161
+ attn = attn_weights @ xv.permute(0, 2, 1, 3)
162
+ attn = attn.transpose(1, 2)
163
+ else:
164
+ attn = scaled_dot_product_attention(
165
+ query=xq.transpose(1, 2),
166
+ key=xk.transpose(1, 2),
167
+ value=xv.transpose(1, 2),
168
+ attn_mask=attention_mask.bool() if attention_mask is not None else None,
169
+ dropout_p=0,
170
+ ).transpose(1, 2)
171
+
172
+ attn = self.wo(attn.reshape(batch_size, seq_len, self.config.num_attention_heads * self.d_head))
173
+
174
+ x = x + attn
175
+
176
+ uv = self.c_fc(self.ffn_norm(x))
177
+ u, v = torch.chunk(uv, 2, dim=-1)
178
+ x_mlp = u * self.silu(v)
179
+ h_mlp = self.mlp_c_proj(x_mlp)
180
+
181
+ x = x + h_mlp
182
+ return x, attn_weights
183
+
184
+
185
+ # ---------------------------------------------------------------------------
186
+ # Model
187
+ # ---------------------------------------------------------------------------
188
+
189
+ class AMPLIFYPreTrainedModel(PreTrainedModel):
190
+ config_class = AMPLIFYConfig
191
+
192
+ def _init_weights(self, module):
193
+ if isinstance(module, nn.Linear):
194
+ module.weight.data.uniform_(
195
+ -self.config.decoder_init_range, self.config.decoder_init_range
196
+ )
197
+ elif isinstance(module, nn.Embedding):
198
+ module.weight.data.uniform_(
199
+ -self.config.embedding_init_range, self.config.embedding_init_range
200
+ )
201
+
202
+
203
+ class AMPLIFY(AMPLIFYPreTrainedModel):
204
+ """AMPLIFY protein language model.
205
+
206
+ A transformer encoder for protein sequences using RoPE and SwiGLU,
207
+ trained with masked language modelling.
208
+ """
209
+
210
+ def __init__(self, config: AMPLIFYConfig, **kwargs):
211
+ super().__init__(config)
212
+ self.config = config
213
+
214
+ self.encoder = nn.Embedding(
215
+ config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id
216
+ )
217
+
218
+ self.transformer_encoder = nn.ModuleList()
219
+ for _ in range(config.num_hidden_layers):
220
+ self.transformer_encoder.append(EncoderBlock(config))
221
+
222
+ self.layer_norm = nn.RMSNorm(config.hidden_size, config.norm_eps)
223
+
224
+ self.decoder = nn.Linear(config.hidden_size, config.vocab_size)
225
+
226
+ freqs_cis = precompute_freqs_cis(
227
+ config.hidden_size // config.num_attention_heads,
228
+ config.max_protein_length * 2,
229
+ )
230
+ self.register_buffer("freqs_cis", freqs_cis, persistent=False)
231
+
232
+ self.post_init()
233
+
234
+ def forward(
235
+ self,
236
+ input_ids: torch.Tensor,
237
+ position_ids: torch.Tensor = None,
238
+ max_seqlen: int = None,
239
+ cu_seqlens: torch.Tensor = None,
240
+ attention_mask: torch.Tensor = None,
241
+ output_hidden_states: bool = False,
242
+ output_attentions: bool = False,
243
+ ):
244
+ hidden_states, attentions = [], []
245
+
246
+ if isinstance(output_hidden_states, bool) and not output_hidden_states:
247
+ output_hidden_index = self.config.num_hidden_layers + 1
248
+ elif isinstance(output_hidden_states, int):
249
+ output_hidden_index = output_hidden_states
250
+ else:
251
+ output_hidden_index = 0
252
+
253
+ if attention_mask is not None:
254
+ attention_mask = (
255
+ attention_mask.unsqueeze(1)
256
+ .unsqueeze(1)
257
+ .repeat(1, self.config.num_attention_heads, attention_mask.size(-1), 1)
258
+ )
259
+
260
+ if cu_seqlens is not None:
261
+ assert not output_attentions, "Output attentions is not supported when sequences are packed."
262
+ assert max_seqlen is not None, "Missing max_seqlen. It must be provided when cu_seqlens are not None."
263
+ assert input_ids.shape[0] == 1, "Cumulative sequence lengths are provided but input_ids are not packed."
264
+ assert input_ids.is_cuda, "Packing uses flash-attention and is only supported on GPU."
265
+
266
+ # RoPE
267
+ if position_ids is not None:
268
+ freqs_cis = self.freqs_cis[position_ids]
269
+ else:
270
+ freqs_cis = (
271
+ self.freqs_cis[: input_ids.shape[1]]
272
+ .unsqueeze(0)
273
+ .repeat(input_ids.shape[0], 1, 1)
274
+ )
275
+
276
+ x = self.encoder(input_ids)
277
+
278
+ for idx, layer in enumerate(self.transformer_encoder):
279
+ x, attn = layer(x, attention_mask, freqs_cis, output_attentions, max_seqlen, cu_seqlens)
280
+ if idx >= output_hidden_index:
281
+ hidden_states.append(x)
282
+ if output_attentions:
283
+ attentions.append(attn)
284
+
285
+ logits = self.decoder(self.layer_norm(x))
286
+
287
+ return MaskedLMOutput(
288
+ logits=logits,
289
+ hidden_states=hidden_states,
290
+ attentions=attentions,
291
+ )