GenerTeam committed on
Commit
01fd9be
·
verified ·
1 Parent(s): ceed708

Upload modeling_carbon.py

Browse files
Files changed (1) hide show
  1. modeling_carbon.py +427 -0
modeling_carbon.py ADDED
@@ -0,0 +1,427 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Carbon with bp_probs generation support.
3
+
4
+ generate_bp() reuses the full HF generate() pipeline (parameter preparation,
5
+ cache management, stopping criteria, logits processing, etc.) and only replaces
6
+ the token selection step with bp-level independent base selection.
7
+ """
8
+ import os
9
+ from typing import Optional, Union
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+ from transformers import LlamaForCausalLM
15
+
16
# Nucleotide <-> integer code lookup tables. The unknown base "N" is coded as -1.
BASE_TO_IDX = dict(zip("ATCGN", (0, 1, 2, 3, -1)))
# Inverse table, derived from the forward table so the two cannot drift apart.
IDX_TO_BASE = {code: base for base, code in BASE_TO_IDX.items()}
18
+
19
+
20
class CarbonForCausalLM(LlamaForCausalLM):
    """LlamaForCausalLM with bp-level autoregressive generation.

    Inherits all standard functionality (forward, generate, etc.) and adds:

    * ``setup_tokenizer()``  - caches the tokenizer and builds k-mer lookup tables.
    * ``compute_bp_probs()`` - marginalizes k-mer token logits to per-base
      probabilities.
    * ``generate_bp()``      - ``generate()`` with each base of the next k-mer
      selected independently from its marginal distribution.
    * ``score_sequence()``   - per-base probabilities for given DNA sequence(s).

    ``setup_tokenizer()`` must be called before any of the bp-level methods.
    """

    def setup_tokenizer(self, tokenizer):
        """Cache tokenizer and precompute lookup tables for bp generation.

        Args:
            tokenizer: object exposing ``k`` (k-mer length), ``get_kmers()``
                (k-mer strings) and ``get_kmer_ids()`` (their token ids, in the
                same order as ``get_kmers()``). It is also called directly by
                ``score_sequence()``.

        Raises:
            ValueError: if the tokenizer does not provide exactly ``4 ** k``
                k-mers with one token id each.
            KeyError: if a k-mer contains a character missing from
                ``BASE_TO_IDX``.
        """
        self.tokenizer = tokenizer
        k = tokenizer.k
        self.k = k
        num_kmers = 4 ** k

        self._kmer_ids = tokenizer.get_kmer_ids()
        self._kmers = tokenizer.get_kmers()
        if not (len(self._kmers) == len(self._kmer_ids) == num_kmers):
            # Fail here rather than with a mask-shape error in compute_bp_probs().
            raise ValueError(
                f"Expected {num_kmers} k-mers for k={k}, got "
                f"{len(self._kmers)} k-mers and {len(self._kmer_ids)} token ids"
            )

        device = next(self.parameters()).device

        # _bp_base_index[j, n] = base code (0..3) at position j of the n-th k-mer
        # *in tokenizer order*. compute_bp_probs() gathers logits in that same
        # order, so the table is derived from the k-mer strings themselves and
        # does not assume any particular enumeration order of the k-mers.
        bp_base_index = torch.tensor(
            [[BASE_TO_IDX[kmer[j]] for kmer in self._kmers] for j in range(k)],
            dtype=torch.long,
        )
        self.register_buffer(
            "_bp_base_index", bp_base_index.to(device), persistent=False
        )

        # Place values turning k base codes into a flat base-4 k-mer index.
        # Registered as a buffer (like the other tables) so it follows the
        # model across .to(device) calls.
        bp_powers = torch.tensor(
            [4 ** i for i in range(k - 1, -1, -1)], dtype=torch.long, device=device
        )
        self.register_buffer("_bp_powers", bp_powers, persistent=False)

        # flat base-4 k-mer index -> token id; filled on the host, moved once.
        flat_to_tid = [0] * num_kmers
        for kmer, tid in zip(self._kmers, self._kmer_ids):
            idx = sum(BASE_TO_IDX[c] * (4 ** (k - 1 - i)) for i, c in enumerate(kmer))
            flat_to_tid[idx] = int(tid)
        self.register_buffer(
            "_flat_idx_to_token_id",
            torch.tensor(flat_to_tid, dtype=torch.long, device=device),
            persistent=False,
        )

    def compute_bp_probs(self, logits):
        """Compute per-base marginal probabilities from token logits.

        The logits of the k-mer tokens are softmax-normalized among themselves
        (all other vocabulary entries are ignored), then for every position of
        the k-mer the probabilities of all k-mers sharing a base are summed.

        Args:
            logits: [B, V] or [B, L, V] token logits
        Returns:
            bp_probs: [B, k, 4] or [B, L, k, 4] (float32), last axis ordered
            as in BASE_TO_IDX (0=A, 1=T, 2=C, 3=G)
        """
        squeeze = False
        if logits.dim() == 2:
            logits = logits.unsqueeze(1)  # [B, 1, V]
            squeeze = True

        kmer_logits = logits[:, :, self._kmer_ids]  # [B, L, num_kmers]
        kmer_probs = F.softmax(kmer_logits.float(), dim=-1)
        B, L, _ = kmer_probs.shape
        bp_probs = torch.zeros(
            B, L, self.k, 4, device=logits.device, dtype=kmer_probs.dtype
        )
        for pos in range(self.k):
            idx = self._bp_base_index[pos]  # [num_kmers] -> 0~3
            for nt in range(4):
                # Sum the probability mass of every k-mer with base `nt` at `pos`.
                bp_probs[:, :, pos, nt] = kmer_probs[:, :, idx == nt].sum(dim=-1)

        if squeeze:
            bp_probs = bp_probs.squeeze(1)  # [B, k, 4]
        return bp_probs

    # -------------------------------------------------------------------------
    # generate_bp: sets a flag then delegates to the standard generate()
    # -------------------------------------------------------------------------
    @torch.no_grad()
    def generate_bp(self, inputs=None, generation_config=None, **kwargs):
        """Same interface as generate(), but with bp-level independent base selection.

        Token logits are marginalized to per-base probabilities [k, 4], and each
        base position is selected independently. The call is forwarded to the
        parent ``generate()`` with the ``_bp_generation`` flag set, so argument
        handling is the parent's; only the token selection inside ``_sample()``
        differs. Only k-mer tokens can be selected in this mode.

        Returns:
            Same as generate() — token ids tensor or GenerateOutput.
        """
        assert hasattr(self, "_bp_base_index"), "Call setup_tokenizer() first"
        self._bp_generation = True
        try:
            return super().generate(
                inputs=inputs, generation_config=generation_config, **kwargs
            )
        finally:
            # Always clear the flag so a later plain generate() is unaffected.
            self._bp_generation = False

    # -------------------------------------------------------------------------
    # Override _sample: when _bp_generation is set, use bp-level token selection
    # -------------------------------------------------------------------------
    def _sample(
        self,
        input_ids,
        logits_processor,
        stopping_criteria,
        generation_config,
        synced_gpus,
        streamer,
        **model_kwargs,
    ):
        """Decoding loop used by generate(); bp-aware when ``_bp_generation`` is set.

        Without the flag this defers entirely to the parent implementation.
        With the flag, the loop mirrors the parent loop except that the next
        token is assembled from k independently chosen bases (sampled when
        ``generation_config.do_sample`` is true, argmax otherwise).

        Returns:
            ``GenerateDecoderOnlyOutput`` when
            ``generation_config.return_dict_in_generate`` is set, otherwise the
            tensor of token ids (prompt followed by generated tokens).
        """
        if not getattr(self, "_bp_generation", False):
            return super()._sample(
                input_ids,
                logits_processor,
                stopping_criteria,
                generation_config,
                synced_gpus,
                streamer,
                **model_kwargs,
            )

        # ==================================================================
        # BP generation mode — copied from transformers 4.56.0 _sample(),
        # with ONLY the token selection block replaced by bp marginalization.
        # ==================================================================
        from transformers.generation.utils import (
            GenerateDecoderOnlyOutput,
        )

        # init values
        pad_token_id = generation_config._pad_token_tensor
        output_attentions = generation_config.output_attentions
        output_hidden_states = generation_config.output_hidden_states
        output_scores = generation_config.output_scores
        output_logits = generation_config.output_logits
        return_dict_in_generate = generation_config.return_dict_in_generate
        has_eos_stopping_criteria = any(
            hasattr(criteria, "eos_token_id") for criteria in stopping_criteria
        )
        do_sample = generation_config.do_sample

        # init attention / hidden states / scores tuples
        scores = () if (return_dict_in_generate and output_scores) else None
        raw_logits = () if (return_dict_in_generate and output_logits) else None
        decoder_attentions = (
            () if (return_dict_in_generate and output_attentions) else None
        )
        decoder_hidden_states = (
            () if (return_dict_in_generate and output_hidden_states) else None
        )

        # keep track of which sequences are already finished
        batch_size, cur_len = input_ids.shape[:2]
        this_peer_finished = False
        unfinished_sequences = torch.ones(
            batch_size, dtype=torch.long, device=input_ids.device
        )
        model_kwargs = self._get_initial_cache_position(
            cur_len, input_ids.device, model_kwargs
        )

        model_forward = self.__call__
        compile_forward = self._valid_auto_compile_criteria(
            model_kwargs, generation_config
        )
        if compile_forward:
            os.environ["TOKENIZERS_PARALLELISM"] = "0"
            if self.config._attn_implementation == "flash_attention_2":
                if (
                    generation_config.compile_config is not None
                    and generation_config.compile_config.fullgraph
                ):
                    generation_config.compile_config.fullgraph = False
            model_forward = self.get_compiled_call(generation_config.compile_config)

        if generation_config.prefill_chunk_size is not None:
            model_kwargs = self._prefill_chunking(
                input_ids, generation_config, **model_kwargs
            )
            is_prefill = False
        else:
            is_prefill = True

        while self._has_unfinished_sequences(
            this_peer_finished, synced_gpus, device=input_ids.device
        ):
            # prepare model inputs
            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

            # prepare variable output controls
            model_inputs.update(
                {"output_attentions": output_attentions} if output_attentions else {}
            )
            model_inputs.update(
                {"output_hidden_states": output_hidden_states}
                if output_hidden_states
                else {}
            )

            if is_prefill:
                outputs = self(**model_inputs, return_dict=True)
                is_prefill = False
            else:
                outputs = model_forward(**model_inputs, return_dict=True)

            # update model kwargs for next step (handles cache, attention_mask, etc.)
            model_kwargs = self._update_model_kwargs_for_generation(
                outputs,
                model_kwargs,
                is_encoder_decoder=self.config.is_encoder_decoder,
            )
            if synced_gpus and this_peer_finished:
                continue

            next_token_logits = outputs.logits[:, -1, :].to(
                copy=True, dtype=torch.float32, device=input_ids.device
            )

            # pre-process distribution (temperature, top_k, top_p, repetition_penalty, etc.)
            next_token_scores = logits_processor(input_ids, next_token_logits)

            # Store scores, attentions and hidden_states when required
            if return_dict_in_generate:
                if output_scores:
                    scores += (next_token_scores,)
                if output_logits:
                    raw_logits += (next_token_logits,)
                # One entry per generated step, each the per-layer tuple returned
                # by the model — the same nesting generate() normally produces.
                if output_attentions:
                    decoder_attentions += (outputs.attentions,)
                if output_hidden_states:
                    decoder_hidden_states += (outputs.hidden_states,)

            # =============================================================
            # BP-LEVEL TOKEN SELECTION (the only departure from the parent loop)
            # =============================================================
            # [B, V] -> [B, k, 4] marginal bp probabilities
            bp_probs = self.compute_bp_probs(next_token_scores)  # [B, k, 4]

            if do_sample:
                # [B*k, 4] -> multinomial -> [B, k]
                base_indices = torch.multinomial(
                    bp_probs.view(-1, 4), 1
                ).view(batch_size, self.k)
            else:
                base_indices = bp_probs.argmax(dim=-1)  # [B, k]

            # base_indices [B, k] -> flat kmer index -> token_id [B]
            flat_idx = (base_indices * self._bp_powers).sum(dim=-1)  # [B]
            next_tokens = self._flat_idx_to_token_id[flat_idx]  # [B]
            # =============================================================

            # finished sentences should have their next token be a padding token
            if has_eos_stopping_criteria:
                next_tokens = next_tokens * unfinished_sequences + pad_token_id * (
                    1 - unfinished_sequences
                )

            # update generated ids, model inputs, and length for next step
            input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
            if streamer is not None:
                streamer.put(next_tokens.cpu())

            unfinished_sequences = unfinished_sequences & ~stopping_criteria(
                input_ids, scores
            )
            this_peer_finished = unfinished_sequences.max() == 0
            cur_len += 1

            # drop the reference to this step's model outputs before the next step
            del outputs

        if streamer is not None:
            streamer.end()

        if return_dict_in_generate:
            return GenerateDecoderOnlyOutput(
                sequences=input_ids,
                scores=scores,
                logits=raw_logits,
                attentions=decoder_attentions,
                hidden_states=decoder_hidden_states,
                past_key_values=model_kwargs.get("past_key_values"),
            )
        else:
            return input_ids

    @torch.no_grad()
    def score_sequence(self, sequences: Union[str, list[str]]):
        """Score DNA sequence(s) and return per-base conditional probabilities.

        Each sequence is manually prepended with BOS token ("<dna>") and padded
        with 'A' if length is not a multiple of k. Returns probabilities for the
        original sequences only (excluding padding).

        Args:
            sequences: Single DNA sequence string or list of sequences

        Returns:
            Tuple of (bp_probs, actual_probs):
            - bp_probs: Full probability distribution
                * Single sequence: [seq_len, 4] tensor
                * Batch: list of [seq_len_i, 4] tensors
            - actual_probs: Probability of the actual base at each position
                * Single sequence: [seq_len] tensor
                * Batch: list of [seq_len_i] tensors

            bp_probs[i, j] = P(base at position i is nucleotide j | context)
            actual_probs[i] = P(actual base at position i | context)
            where j: 0=A, 1=T, 2=C, 3=G

        Note:
            The probabilities are per-k-mer marginals: all k bases of one k-mer
            are conditioned on the tokens *before* that k-mer, not on the
            preceding bases inside the same k-mer.

        Example:
            # Single sequence
            bp_probs, actual_probs = model.score_sequence("ACGT")

            # Batch of sequences
            bp_probs_list, actual_probs_list = model.score_sequence([
                "ACGT" * 150,
                "ACGT" * 149 + "AC",
            ])
        """
        assert hasattr(self, "tokenizer"), "Call setup_tokenizer() first"

        # Handle single sequence case; list() also accepts tuples and other
        # iterables and never aliases the caller's container.
        is_single = isinstance(sequences, str)
        sequences = [sequences] if is_single else list(sequences)

        # Pad each sequence to multiple of k with 'A'
        padded_sequences = []
        for seq in sequences:
            remainder = len(seq) % self.k
            if remainder:
                seq = seq + "A" * (self.k - remainder)
            padded_sequences.append(seq)

        # Manually prepend BOS token "<dna>" to each sequence
        sequences_with_bos = ["<dna>" + seq for seq in padded_sequences]

        # Tokenize batch (without add_special_tokens since we added manually)
        inputs = self.tokenizer(
            sequences_with_bos,
            return_tensors="pt",
            padding=True,
            add_special_tokens=False,
        )
        input_ids = inputs["input_ids"].to(self.device)
        attention_mask = inputs["attention_mask"].to(self.device)

        # Forward pass to get logits for all positions
        outputs = self(input_ids, attention_mask=attention_mask, return_dict=True)
        logits = outputs.logits  # [B, max_seq_len, vocab_size]

        # Compute bp probabilities for all token positions
        bp_probs = self.compute_bp_probs(logits)  # [B, max_seq_len, k, 4]

        # Column of the first attended token (assumed to be the BOS) in each row.
        # This is 0 when the tokenizer pads on the right; for rows padded on the
        # left it skips the pad columns so the slices below stay aligned.
        first_attended = attention_mask.long().argmax(dim=1).tolist()

        # Process each sequence in the batch
        bp_probs_results = []
        actual_probs_results = []

        for i, (original_seq, padded_seq) in enumerate(
            zip(sequences, padded_sequences)
        ):
            # Calculate number of actual sequence tokens (excluding BOS)
            num_seq_tokens = len(padded_seq) // self.k

            # The logits at the BOS column predict the first sequence token and
            # the logits at column t predict token t + 1, so num_seq_tokens
            # columns starting at the BOS cover every sequence token.
            start = first_attended[i]
            seq_bp_probs = bp_probs[i, start:start + num_seq_tokens]  # [T, k, 4]

            # Reshape: [num_seq_tokens, k, 4] -> [num_seq_tokens * k, 4]
            seq_result = seq_bp_probs.reshape(-1, 4)

            # Trim to original sequence length (remove padding)
            seq_result = seq_result[:len(original_seq)]

            # Extract actual base probabilities
            actual_probs = self._extract_actual_probs(seq_result, original_seq)

            bp_probs_results.append(seq_result)
            actual_probs_results.append(actual_probs)

        # Return single tensors if input was single sequence
        if is_single:
            return bp_probs_results[0], actual_probs_results[0]

        return bp_probs_results, actual_probs_results

    def _extract_actual_probs(self, bp_probs: torch.Tensor, sequence: str):
        """Extract probabilities of actual bases in the sequence.

        For each position i in the sequence, returns the probability that the model
        assigned to the actual base at that position.

        For 'N' bases (unknown), returns the maximum probability across all 4 bases.

        Args:
            bp_probs: [seq_len, 4] probability distribution from logits
                      (rows beyond len(sequence) are ignored)
            sequence: DNA sequence string (may contain 'N')

        Returns:
            actual_probs: [seq_len] probabilities of actual bases
                         actual_probs[i] = bp_probs[i, sequence[i]] for A/T/C/G
                         actual_probs[i] = max(bp_probs[i]) for 'N'

        Raises:
            KeyError: if the sequence contains a character not in BASE_TO_IDX.
            IndexError: if bp_probs has fewer rows than the sequence has bases.
        """
        seq_len = len(sequence)
        if bp_probs.shape[0] < seq_len:
            raise IndexError(
                f"bp_probs has {bp_probs.shape[0]} rows but the sequence has "
                f"{seq_len} bases"
            )
        rows = bp_probs[:seq_len]

        # Base codes for the whole sequence at once; 'N' is coded as -1.
        base_codes = torch.tensor(
            [BASE_TO_IDX[base] for base in sequence],
            dtype=torch.long,
            device=bp_probs.device,
        )
        is_unknown = base_codes < 0

        # Clamp so 'N' positions gather a valid (discarded) column.
        picked = rows.gather(1, base_codes.clamp(min=0).unsqueeze(1)).squeeze(1)
        # For N, take the maximum probability across all 4 bases.
        return torch.where(is_unknown, rows.max(dim=-1).values, picked)