kashif HF Staff committed on
Commit
6300b57
·
verified ·
1 Parent(s): e1cc2f4

modeling_carbon: replace _sample() override with stable LogitsProcessor API

Browse files
Files changed (1) hide show
  1. modeling_carbon.py +135 -353
modeling_carbon.py CHANGED
@@ -1,427 +1,209 @@
1
  """
2
- Carbon with bp_probs generation support.
3
 
4
- generate_bp() reuses the full HF generate() pipeline (parameter preparation,
5
- cache management, stopping criteria, logits processing, etc.) and only replaces
6
- the token selection step with bp-level independent base selection.
7
  """
8
- import os
9
- from typing import Optional, Union
10
-
11
  import torch
12
- import torch.nn as nn
13
  import torch.nn.functional as F
14
- from transformers import LlamaForCausalLM
 
15
 
16
  BASE_TO_IDX = {"A": 0, "T": 1, "C": 2, "G": 3, "N": -1}
17
  IDX_TO_BASE = {0: "A", 1: "T", 2: "C", 3: "G", -1: "N"}
18
 
19
 
20
- class CarbonForCausalLM(LlamaForCausalLM):
21
- """LlamaForCausalLM with bp-level autoregressive generation.
22
 
23
- Inherits all standard functionality (forward, generate, etc.)
24
- and adds generate_bp() for base-pair independent generation.
25
  """
26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  def setup_tokenizer(self, tokenizer):
28
- """Cache tokenizer and precompute lookup tables for bp generation."""
29
  self.tokenizer = tokenizer
30
  k = tokenizer.k
31
  self.k = k
32
- num_special = len(tokenizer.special_tokens)
33
- num_kmers = 4 ** k
34
 
35
- self._kmer_ids = tokenizer.get_kmer_ids()
36
- self._kmers = tokenizer.get_kmers()
 
 
 
 
 
 
 
 
 
 
 
 
37
 
 
 
 
38
  bp_base_index = torch.zeros(k, num_kmers, dtype=torch.long)
39
- for j in range(k):
40
- bp_base_index[j] = torch.arange(num_kmers) >> ((k - 1 - j) * 2) & 3
41
- device = next(self.parameters()).device
42
  self.register_buffer("_bp_base_index", bp_base_index.to(device), persistent=False)
43
 
44
  self._bp_powers = torch.tensor(
45
  [4 ** i for i in range(k - 1, -1, -1)], dtype=torch.long, device=device
46
  )
 
 
47
  flat_to_tid = torch.zeros(num_kmers, dtype=torch.long, device=device)
48
- for kmer, tid in zip(self._kmers, self._kmer_ids):
49
- idx = sum(BASE_TO_IDX[c] * (4 ** (k - 1 - i)) for i, c in enumerate(kmer))
50
- flat_to_tid[idx] = tid
51
  self.register_buffer("_flat_idx_to_token_id", flat_to_tid, persistent=False)
52
 
53
  def compute_bp_probs(self, logits):
54
- """Compute per-base marginal probabilities from token logits (vectorized).
55
 
56
  Args:
57
- logits: [B, V] or [B, L, V] token logits
58
  Returns:
59
  bp_probs: [B, k, 4] or [B, L, k, 4]
60
  """
61
- squeeze = False
62
- if logits.dim() == 2:
63
- logits = logits.unsqueeze(1) # [B, 1, V]
64
- squeeze = True
65
 
66
- kmer_logits = logits[:, :, self._kmer_ids] # [B, L, num_kmers]
67
  kmer_probs = F.softmax(kmer_logits.float(), dim=-1)
68
  B, L, _ = kmer_probs.shape
69
  bp_probs = torch.zeros(B, L, self.k, 4, device=logits.device, dtype=kmer_probs.dtype)
70
  for pos in range(self.k):
71
- idx = self._bp_base_index[pos] # [num_kmers] -> 0~3
72
  for nt in range(4):
73
  bp_probs[:, :, pos, nt] = kmer_probs[:, :, idx == nt].sum(dim=-1)
74
 
75
- if squeeze:
76
- bp_probs = bp_probs.squeeze(1) # [B, k, 4]
77
- return bp_probs
78
 
79
- # -------------------------------------------------------------------------
80
- # generate_bp: sets a flag then delegates to the standard generate()
81
- # -------------------------------------------------------------------------
82
- @torch.no_grad()
83
  def generate_bp(self, inputs=None, generation_config=None, **kwargs):
84
- """Same interface as generate(), but with bp-level independent base selection.
85
 
86
- Token logits are marginalized to per-base probabilities [k, 4], and each
87
- base position is selected independently. All standard generate() parameters
88
- (temperature, top_k, top_p, do_sample, attention_mask, etc.) are fully
89
- supported — they are processed by the HF generate pipeline as usual.
90
-
91
- Returns:
92
- Same as generate() — token ids tensor or GenerateOutput.
93
  """
94
- assert hasattr(self, "_bp_base_index"), "Call setup_tokenizer() first"
95
- self._bp_generation = True
96
- try:
97
- return super().generate(
98
- inputs=inputs, generation_config=generation_config, **kwargs
99
- )
100
- finally:
101
- self._bp_generation = False
102
-
103
- # -------------------------------------------------------------------------
104
- # Override _sample: when _bp_generation is set, use bp-level token selection
105
- # -------------------------------------------------------------------------
106
- def _sample(
107
- self,
108
- input_ids,
109
- logits_processor,
110
- stopping_criteria,
111
- generation_config,
112
- synced_gpus,
113
- streamer,
114
- **model_kwargs,
115
- ):
116
- if not getattr(self, "_bp_generation", False):
117
- return super()._sample(
118
- input_ids,
119
- logits_processor,
120
- stopping_criteria,
121
- generation_config,
122
- synced_gpus,
123
- streamer,
124
- **model_kwargs,
125
- )
126
-
127
- # ==================================================================
128
- # BP generation mode — copied from transformers 4.56.0 _sample(),
129
- # with ONLY the token selection block replaced by bp marginalization.
130
- # ==================================================================
131
- from transformers.generation.utils import (
132
- GenerateDecoderOnlyOutput,
133
  )
 
 
134
 
135
- # init values
136
- pad_token_id = generation_config._pad_token_tensor
137
- output_attentions = generation_config.output_attentions
138
- output_hidden_states = generation_config.output_hidden_states
139
- output_scores = generation_config.output_scores
140
- output_logits = generation_config.output_logits
141
- return_dict_in_generate = generation_config.return_dict_in_generate
142
- has_eos_stopping_criteria = any(
143
- hasattr(criteria, "eos_token_id") for criteria in stopping_criteria
144
- )
145
- do_sample = generation_config.do_sample
146
-
147
- # init attention / hidden states / scores tuples
148
- scores = () if (return_dict_in_generate and output_scores) else None
149
- raw_logits = () if (return_dict_in_generate and output_logits) else None
150
- decoder_attentions = (
151
- () if (return_dict_in_generate and output_attentions) else None
152
- )
153
- decoder_hidden_states = (
154
- () if (return_dict_in_generate and output_hidden_states) else None
155
- )
156
-
157
- # keep track of which sequences are already finished
158
- batch_size, cur_len = input_ids.shape[:2]
159
- this_peer_finished = False
160
- unfinished_sequences = torch.ones(
161
- batch_size, dtype=torch.long, device=input_ids.device
162
- )
163
- model_kwargs = self._get_initial_cache_position(
164
- cur_len, input_ids.device, model_kwargs
165
- )
166
-
167
- model_forward = self.__call__
168
- compile_forward = self._valid_auto_compile_criteria(
169
- model_kwargs, generation_config
170
- )
171
- if compile_forward:
172
- os.environ["TOKENIZERS_PARALLELISM"] = "0"
173
- if self.config._attn_implementation == "flash_attention_2":
174
- if (
175
- generation_config.compile_config is not None
176
- and generation_config.compile_config.fullgraph
177
- ):
178
- generation_config.compile_config.fullgraph = False
179
- model_forward = self.get_compiled_call(generation_config.compile_config)
180
-
181
- if generation_config.prefill_chunk_size is not None:
182
- model_kwargs = self._prefill_chunking(
183
- input_ids, generation_config, **model_kwargs
184
- )
185
- is_prefill = False
186
- else:
187
- is_prefill = True
188
-
189
- while self._has_unfinished_sequences(
190
- this_peer_finished, synced_gpus, device=input_ids.device
191
- ):
192
- # prepare model inputs
193
- model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
194
-
195
- # prepare variable output controls
196
- model_inputs.update(
197
- {"output_attentions": output_attentions} if output_attentions else {}
198
- )
199
- model_inputs.update(
200
- {"output_hidden_states": output_hidden_states}
201
- if output_hidden_states
202
- else {}
203
- )
204
-
205
- if is_prefill:
206
- outputs = self(**model_inputs, return_dict=True)
207
- is_prefill = False
208
- else:
209
- outputs = model_forward(**model_inputs, return_dict=True)
210
-
211
- # update model kwargs for next step (handles cache, attention_mask, etc.)
212
- model_kwargs = self._update_model_kwargs_for_generation(
213
- outputs,
214
- model_kwargs,
215
- is_encoder_decoder=self.config.is_encoder_decoder,
216
- )
217
- if synced_gpus and this_peer_finished:
218
- continue
219
-
220
- next_token_logits = outputs.logits[:, -1, :].to(
221
- copy=True, dtype=torch.float32, device=input_ids.device
222
- )
223
-
224
- # pre-process distribution (temperature, top_k, top_p, repetition_penalty, etc.)
225
- next_token_scores = logits_processor(input_ids, next_token_logits)
226
-
227
- # Store scores, attentions and hidden_states when required
228
- if return_dict_in_generate:
229
- if output_scores:
230
- scores += (next_token_scores,)
231
- if output_logits:
232
- raw_logits += (next_token_logits,)
233
- if output_attentions:
234
- decoder_attentions += ((outputs.attentions,),)
235
- if output_hidden_states:
236
- decoder_hidden_states += ((outputs.hidden_states,),)
237
-
238
- # =============================================================
239
- # BP-LEVEL TOKEN SELECTION (vectorized, the ONLY change)
240
- # =============================================================
241
- # [B, V] -> [B, k, 4] marginal bp probabilities
242
- bp_probs = self.compute_bp_probs(next_token_scores) # [B, k, 4]
243
-
244
- if do_sample:
245
- # [B*k, 4] -> multinomial -> [B, k]
246
- base_indices = torch.multinomial(
247
- bp_probs.view(-1, 4), 1
248
- ).view(batch_size, self.k)
249
- else:
250
- base_indices = bp_probs.argmax(dim=-1) # [B, k]
251
-
252
- # base_indices [B, k] -> flat kmer index -> token_id [B]
253
- flat_idx = (base_indices * self._bp_powers).sum(dim=-1) # [B]
254
- next_tokens = self._flat_idx_to_token_id[flat_idx] # [B]
255
- # =============================================================
256
-
257
- # finished sentences should have their next token be a padding token
258
- if has_eos_stopping_criteria:
259
- next_tokens = next_tokens * unfinished_sequences + pad_token_id * (
260
- 1 - unfinished_sequences
261
- )
262
-
263
- # update generated ids, model inputs, and length for next step
264
- input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
265
- if streamer is not None:
266
- streamer.put(next_tokens.cpu())
267
-
268
- unfinished_sequences = unfinished_sequences & ~stopping_criteria(
269
- input_ids, scores
270
- )
271
- this_peer_finished = unfinished_sequences.max() == 0
272
- cur_len += 1
273
-
274
- del outputs
275
-
276
- if streamer is not None:
277
- streamer.end()
278
-
279
- if return_dict_in_generate:
280
- return GenerateDecoderOnlyOutput(
281
- sequences=input_ids,
282
- scores=scores,
283
- logits=raw_logits,
284
- attentions=decoder_attentions,
285
- hidden_states=decoder_hidden_states,
286
- past_key_values=model_kwargs.get("past_key_values"),
287
- )
288
- else:
289
- return input_ids
290
 
291
  @torch.no_grad()
292
- def score_sequence(self, sequences: Union[str, list[str]]):
293
- """Score DNA sequence(s) and return per-base conditional probabilities.
294
 
295
- Each sequence is manually prepended with BOS token ("<dna>") and padded
296
- with 'A' if length is not a multiple of k. Returns probabilities for the
297
- original sequences only (excluding padding).
298
 
299
  Args:
300
- sequences: Single DNA sequence string or list of sequences
301
 
302
  Returns:
303
- Tuple of (bp_probs, actual_probs):
304
- - bp_probs: Full probability distribution
305
- * Single sequence: [seq_len, 4] tensor
306
- * Batch: list of [seq_len_i, 4] tensors
307
- - actual_probs: Probability of the actual base at each position
308
- * Single sequence: [seq_len] tensor
309
- * Batch: list of [seq_len_i] tensors
310
-
311
- bp_probs[i, j] = P(base at position i is nucleotide j | context)
312
- actual_probs[i] = P(actual base at position i | context)
313
- where j: 0=A, 1=T, 2=C, 3=G
314
-
315
- Example:
316
- # Single sequence
317
- bp_probs, actual_probs = model.score_sequence("ACGT")
318
-
319
- # Batch of sequences
320
- bp_probs_list, actual_probs_list = model.score_sequence([
321
- "ACGT" * 150,
322
- "ACGT" * 149 + "AC",
323
- ])
324
  """
325
- assert hasattr(self, "tokenizer"), "Call setup_tokenizer() first"
326
 
327
- # Handle single sequence case
328
  is_single = isinstance(sequences, str)
329
  if is_single:
330
  sequences = [sequences]
331
 
332
- # Store original info
333
- original_lens = [len(seq) for seq in sequences]
334
- original_sequences = sequences.copy()
335
 
336
- # Pad each sequence to multiple of k with 'A'
337
- padded_sequences = []
338
- for seq in sequences:
339
- if len(seq) % self.k != 0:
340
- padding_len = self.k - (len(seq) % self.k)
341
- seq = seq + 'A' * padding_len
342
- padded_sequences.append(seq)
343
 
344
- # Manually prepend BOS token "<dna>" to each sequence
345
- sequences_with_bos = ["<dna>" + seq for seq in padded_sequences]
346
 
347
- # Tokenize batch (without add_special_tokens since we added manually)
348
  inputs = self.tokenizer(
349
- sequences_with_bos,
350
- return_tensors="pt",
351
- padding=True,
352
- add_special_tokens=False
353
  )
354
  input_ids = inputs["input_ids"].to(self.device)
355
  attention_mask = inputs["attention_mask"].to(self.device)
356
 
357
- # Forward pass to get logits for all positions
358
- outputs = self(input_ids, attention_mask=attention_mask, return_dict=True)
359
- logits = outputs.logits # [B, max_seq_len, vocab_size]
360
-
361
- # Compute bp probabilities for all token positions
362
- bp_probs = self.compute_bp_probs(logits) # [B, max_seq_len, k, 4]
363
-
364
- # Process each sequence in the batch
365
- bp_probs_results = []
366
- actual_probs_results = []
367
-
368
- for i, (original_seq, original_len, padded_seq) in enumerate(
369
- zip(original_sequences, original_lens, padded_sequences)
370
- ):
371
- # Calculate number of actual sequence tokens (excluding BOS)
372
- num_seq_tokens = len(padded_seq) // self.k
373
 
374
- # Extract bp_probs for this sequence
375
- # logits[0] predicts token after BOS (first sequence token)
376
- # logits[i] predicts token[i+1]
377
- # So logits[0:num_seq_tokens] predict the sequence tokens
378
- seq_bp_probs = bp_probs[i, :num_seq_tokens] # [num_seq_tokens, k, 4]
 
 
 
 
379
 
380
- # Reshape: [num_seq_tokens, k, 4] -> [num_seq_tokens * k, 4]
381
- seq_result = seq_bp_probs.reshape(-1, 4)
382
-
383
- # Trim to original sequence length (remove padding)
384
- seq_result = seq_result[:original_len]
385
-
386
- # Extract actual base probabilities
387
- actual_probs = self._extract_actual_probs(seq_result, original_seq)
388
-
389
- bp_probs_results.append(seq_result)
390
- actual_probs_results.append(actual_probs)
391
-
392
- # Return single tensors if input was single sequence
393
  if is_single:
394
- return bp_probs_results[0], actual_probs_results[0]
395
-
396
- return bp_probs_results, actual_probs_results
397
-
398
- def _extract_actual_probs(self, bp_probs: torch.Tensor, sequence: str):
399
- """Extract probabilities of actual bases in the sequence.
400
-
401
- For each position i in the sequence, returns the probability that the model
402
- assigned to the actual base at that position.
403
-
404
- For 'N' bases (unknown), returns the maximum probability across all 4 bases.
405
-
406
- Args:
407
- bp_probs: [seq_len, 4] probability distribution from logits
408
- bp_probs[i] = P(position i | context before i)
409
- sequence: DNA sequence string (may contain 'N')
410
-
411
- Returns:
412
- actual_probs: [seq_len] probabilities of actual bases
413
- actual_probs[i] = bp_probs[i, sequence[i]] for A/T/C/G
414
- actual_probs[i] = max(bp_probs[i]) for 'N'
415
- """
416
- seq_len = len(sequence)
417
- actual_probs = torch.zeros(seq_len, device=bp_probs.device, dtype=bp_probs.dtype)
418
 
 
 
419
  for i, base in enumerate(sequence):
420
- if base == 'N':
421
- # For N, take the maximum probability across all 4 bases
422
- actual_probs[i] = bp_probs[i].max()
423
- else:
424
- base_idx = BASE_TO_IDX[base]
425
- actual_probs[i] = bp_probs[i, base_idx]
426
-
427
- return actual_probs
 
1
  """
2
+ Carbon with bp-level generation and scoring.
3
 
4
+ generate_bp() plugs into the standard HF generate() pipeline via a
5
+ LogitsProcessor — no internal methods are overridden, so it is compatible
6
+ with any transformers version.
7
  """
 
 
 
8
  import torch
 
9
  import torch.nn.functional as F
10
+ from transformers import LlamaForCausalLM, LogitsProcessor, LogitsProcessorList
11
+ from typing import Union
12
 
13
  BASE_TO_IDX = {"A": 0, "T": 1, "C": 2, "G": 3, "N": -1}
14
  IDX_TO_BASE = {0: "A", 1: "T", 2: "C", 3: "G", -1: "N"}
15
 
16
 
17
+ class _BPLogitsProcessor(LogitsProcessor):
18
+ """Forces token selection to use per-base marginal probabilities.
19
 
20
+ Runs LAST in the logits-processor chain so that temperature / top-k /
21
+ top-p etc. influence the marginal distributions before base selection.
22
  """
23
 
24
+ def __init__(self, kmer_ids, bp_base_index, flat_idx_to_token_id, bp_powers, k, do_sample):
25
+ self.kmer_ids = kmer_ids
26
+ self.bp_base_index = bp_base_index
27
+ self.flat_idx_to_token_id = flat_idx_to_token_id
28
+ self.bp_powers = bp_powers
29
+ self.k = k
30
+ self.do_sample = do_sample
31
+
32
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
33
+ B = scores.shape[0]
34
+ kmer_probs = F.softmax(scores[:, self.kmer_ids].float(), dim=-1) # [B, num_kmers]
35
+
36
+ # Marginalise to per-base probabilities [B, k, 4]
37
+ bp_probs = torch.zeros(B, self.k, 4, device=scores.device, dtype=kmer_probs.dtype)
38
+ for pos in range(self.k):
39
+ idx = self.bp_base_index[pos] # [num_kmers] in {0,1,2,3}
40
+ for nt in range(4):
41
+ bp_probs[:, pos, nt] = kmer_probs[:, idx == nt].sum(dim=-1)
42
+
43
+ if self.do_sample:
44
+ base_indices = torch.multinomial(bp_probs.view(-1, 4), 1).view(B, self.k)
45
+ else:
46
+ base_indices = bp_probs.argmax(dim=-1) # [B, k]
47
+
48
+ flat_idx = (base_indices * self.bp_powers).sum(dim=-1) # [B]
49
+ selected = self.flat_idx_to_token_id[flat_idx] # [B]
50
+
51
+ # One-hot: both argmax and multinomial land on the bp-selected token
52
+ new_scores = torch.full_like(scores, float("-inf"))
53
+ new_scores.scatter_(1, selected.unsqueeze(1), 0.0)
54
+ return new_scores
55
+
56
+
57
+ class CarbonForCausalLM(LlamaForCausalLM):
58
+ """LlamaForCausalLM with bp-level generation and sequence scoring."""
59
+
60
  def setup_tokenizer(self, tokenizer):
61
+ """Cache tokenizer and precompute lookup tables for bp-level operations."""
62
  self.tokenizer = tokenizer
63
  k = tokenizer.k
64
  self.k = k
 
 
65
 
66
+ device = next(self.parameters()).device
67
+
68
+ # Build ordered kmer list from the tokenizer's DNA vocab
69
+ kmer_items = sorted(
70
+ [
71
+ (kmer, tid)
72
+ for kmer, tid in tokenizer.dna_token_to_id.items()
73
+ if len(kmer) == k and all(b in "ATCG" for b in kmer)
74
+ ],
75
+ key=lambda x: x[1],
76
+ )
77
+ kmers = [item[0] for item in kmer_items]
78
+ kmer_ids = [item[1] for item in kmer_items]
79
+ num_kmers = len(kmer_ids)
80
 
81
+ self._kmer_ids = torch.tensor(kmer_ids, dtype=torch.long, device=device)
82
+
83
+ # bp_base_index[pos, j] = base index (0-3) of kmer j at position pos
84
  bp_base_index = torch.zeros(k, num_kmers, dtype=torch.long)
85
+ for j, kmer in enumerate(kmers):
86
+ for pos, base in enumerate(kmer):
87
+ bp_base_index[pos, j] = BASE_TO_IDX[base]
88
  self.register_buffer("_bp_base_index", bp_base_index.to(device), persistent=False)
89
 
90
  self._bp_powers = torch.tensor(
91
  [4 ** i for i in range(k - 1, -1, -1)], dtype=torch.long, device=device
92
  )
93
+
94
+ # flat kmer index -> token id (flat index = sum base_idx[i] * 4^(k-1-i))
95
  flat_to_tid = torch.zeros(num_kmers, dtype=torch.long, device=device)
96
+ for j, (kmer, tid) in enumerate(kmer_items):
97
+ flat_idx = sum(BASE_TO_IDX[c] * (4 ** (k - 1 - i)) for i, c in enumerate(kmer))
98
+ flat_to_tid[flat_idx] = tid
99
  self.register_buffer("_flat_idx_to_token_id", flat_to_tid, persistent=False)
100
 
101
  def compute_bp_probs(self, logits):
102
+ """Compute per-base marginal probabilities from token logits.
103
 
104
  Args:
105
+ logits: [B, V] or [B, L, V]
106
  Returns:
107
  bp_probs: [B, k, 4] or [B, L, k, 4]
108
  """
109
+ squeeze = logits.dim() == 2
110
+ if squeeze:
111
+ logits = logits.unsqueeze(1)
 
112
 
113
+ kmer_logits = logits[:, :, self._kmer_ids]
114
  kmer_probs = F.softmax(kmer_logits.float(), dim=-1)
115
  B, L, _ = kmer_probs.shape
116
  bp_probs = torch.zeros(B, L, self.k, 4, device=logits.device, dtype=kmer_probs.dtype)
117
  for pos in range(self.k):
118
+ idx = self._bp_base_index[pos]
119
  for nt in range(4):
120
  bp_probs[:, :, pos, nt] = kmer_probs[:, :, idx == nt].sum(dim=-1)
121
 
122
+ return bp_probs.squeeze(1) if squeeze else bp_probs
 
 
123
 
 
 
 
 
124
  def generate_bp(self, inputs=None, generation_config=None, **kwargs):
125
+ """Like generate(), but each token is selected base-by-base from marginal distributions.
126
 
127
+ Temperature, top_k, top_p, repetition_penalty etc. all apply as usual —
128
+ they run before the bp processor and shift the marginal distributions.
129
+ Output shape and type are identical to generate().
 
 
 
 
130
  """
131
+ assert hasattr(self, "_bp_base_index"), "Call setup_tokenizer(tokenizer) first"
132
+
133
+ gc = generation_config or self.generation_config
134
+ do_sample = kwargs.get("do_sample", getattr(gc, "do_sample", False))
135
+
136
+ bp_proc = _BPLogitsProcessor(
137
+ kmer_ids=self._kmer_ids,
138
+ bp_base_index=self._bp_base_index,
139
+ flat_idx_to_token_id=self._flat_idx_to_token_id,
140
+ bp_powers=self._bp_powers,
141
+ k=self.k,
142
+ do_sample=do_sample,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
143
  )
144
+ existing = list(kwargs.pop("logits_processor", None) or [])
145
+ kwargs["logits_processor"] = LogitsProcessorList(existing + [bp_proc])
146
 
147
+ return super().generate(inputs=inputs, generation_config=generation_config, **kwargs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
148
 
149
  @torch.no_grad()
150
+ def score_sequence(self, sequences: Union[str, list]):
151
+ """Score DNA sequence(s) at base resolution.
152
 
153
+ Returns per-base probability distributions and the probability of the
154
+ actual base at each position, given all preceding context.
 
155
 
156
  Args:
157
+ sequences: single DNA string or list of DNA strings (ACGT only)
158
 
159
  Returns:
160
+ (bp_probs, actual_probs) for a single sequence, or
161
+ (list of bp_probs, list of actual_probs) for a batch.
162
+ bp_probs[i]: [seq_len_i, 4] — P(base | context) at each position
163
+ actual_probs[i]: [seq_len_i] — P(actual base | context)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
164
  """
165
+ assert hasattr(self, "tokenizer"), "Call setup_tokenizer(tokenizer) first"
166
 
 
167
  is_single = isinstance(sequences, str)
168
  if is_single:
169
  sequences = [sequences]
170
 
171
+ original_lens = [len(s) for s in sequences]
 
 
172
 
173
+ # Right-pad to multiple of k with 'A' (matches tokenizer convention)
174
+ padded = []
175
+ for s in sequences:
176
+ r = len(s) % self.k
177
+ padded.append(s + "A" * (self.k - r) if r else s)
 
 
178
 
179
+ # Prepend <dna> tag manually (training format)
180
+ tagged = ["<dna>" + s for s in padded]
181
 
 
182
  inputs = self.tokenizer(
183
+ tagged, return_tensors="pt", padding=True, add_special_tokens=False
 
 
 
184
  )
185
  input_ids = inputs["input_ids"].to(self.device)
186
  attention_mask = inputs["attention_mask"].to(self.device)
187
 
188
+ logits = self(input_ids, attention_mask=attention_mask, return_dict=True).logits
189
+ bp_probs_all = self.compute_bp_probs(logits) # [B, L, k, 4]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
190
 
191
+ bp_results, actual_results = [], []
192
+ for i, (seq, orig_len, pad_seq) in enumerate(zip(sequences, original_lens, padded)):
193
+ num_tokens = len(pad_seq) // self.k
194
+ # logits[t] predicts token t+1; logits[0] (from <dna>) predicts token 1
195
+ seq_bp = bp_probs_all[i, :num_tokens] # [num_tokens, k, 4]
196
+ seq_bp = seq_bp.reshape(-1, 4)[:orig_len] # [orig_len, 4]
197
+ actual = self._extract_actual_probs(seq_bp, seq)
198
+ bp_results.append(seq_bp)
199
+ actual_results.append(actual)
200
 
 
 
 
 
 
 
 
 
 
 
 
 
 
201
  if is_single:
202
+ return bp_results[0], actual_results[0]
203
+ return bp_results, actual_results
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
204
 
205
+ def _extract_actual_probs(self, bp_probs: torch.Tensor, sequence: str) -> torch.Tensor:
206
+ actual = torch.zeros(len(sequence), device=bp_probs.device, dtype=bp_probs.dtype)
207
  for i, base in enumerate(sequence):
208
+ actual[i] = bp_probs[i].max() if base == "N" else bp_probs[i, BASE_TO_IDX[base]]
209
+ return actual