"""
Carbon with bp_probs generation support.

generate_bp() reuses the full HF generate() pipeline (parameter preparation,
cache management, stopping criteria, logits processing, etc.) and only replaces
the token selection step with bp-level independent base selection.
"""

import os
from typing import Optional, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import LlamaForCausalLM

# Nucleotide <-> index mapping used by every lookup table below.
# "N" (unknown base) maps to -1 and is special-cased where it is consumed.
BASE_TO_IDX = {"A": 0, "T": 1, "C": 2, "G": 3, "N": -1}
IDX_TO_BASE = {0: "A", 1: "T", 2: "C", 3: "G", -1: "N"}


class CarbonForCausalLM(LlamaForCausalLM):
    """LlamaForCausalLM with bp-level autoregressive generation.

    Inherits all standard functionality (forward, generate, etc.) and adds
    generate_bp() for base-pair independent generation, plus score_sequence()
    for per-base conditional probabilities.
    """

    def setup_tokenizer(self, tokenizer):
        """Cache tokenizer and precompute lookup tables for bp generation.

        Reads ``tokenizer.k``, ``tokenizer.get_kmer_ids()`` and
        ``tokenizer.get_kmers()`` and stores three non-persistent buffers (so
        they follow the module across ``.to(device)`` calls):

        - ``_bp_base_index`` [k, num_kmers]: base index (0..3) found at each
          position of each k-mer, in the column order of ``get_kmers()``.
        - ``_bp_powers`` [k]: base-4 place values, most significant first.
        - ``_flat_idx_to_token_id`` [4**k]: base-4 flat k-mer index -> token id.

        Args:
            tokenizer: k-mer tokenizer exposing ``k``, ``get_kmer_ids()`` and
                ``get_kmers()``; the two getters are consumed pairwise, so they
                must be aligned element by element.
        """
        self.tokenizer = tokenizer
        k = tokenizer.k
        self.k = k
        num_kmers = 4 ** k

        self._kmer_ids = tokenizer.get_kmer_ids()
        self._kmers = tokenizer.get_kmers()

        device = next(self.parameters()).device

        # Derive the per-position base index from the k-mer strings themselves,
        # the same way _flat_idx_to_token_id is derived below, so both tables
        # agree regardless of the order in which the tokenizer lists k-mers.
        bp_base_index = torch.tensor(
            [[BASE_TO_IDX[kmer[j]] for kmer in self._kmers] for j in range(k)],
            dtype=torch.long,
        )
        self.register_buffer(
            "_bp_base_index", bp_base_index.to(device), persistent=False
        )

        # Registered as a buffer (not a plain attribute) so it moves together
        # with the other lookup tables when the model changes device.
        bp_powers = torch.tensor(
            [4 ** i for i in range(k - 1, -1, -1)],
            dtype=torch.long,
            device=device,
        )
        self.register_buffer("_bp_powers", bp_powers, persistent=False)

        flat_to_tid = torch.zeros(num_kmers, dtype=torch.long, device=device)
        for kmer, tid in zip(self._kmers, self._kmer_ids):
            idx = sum(
                BASE_TO_IDX[c] * (4 ** (k - 1 - i)) for i, c in enumerate(kmer)
            )
            flat_to_tid[idx] = tid
        self.register_buffer(
            "_flat_idx_to_token_id", flat_to_tid, persistent=False
        )

    def compute_bp_probs(self, logits):
        """Compute per-base marginal probabilities from token logits.

        The logits are restricted to the k-mer token columns, soft-maxed over
        those columns only, and then summed per (position, nucleotide).

        Args:
            logits: [B, V] or [B, L, V] token logits

        Returns:
            bp_probs: [B, k, 4] or [B, L, k, 4] (float32, since the softmax is
            taken on ``logits.float()``)
        """
        squeeze = False
        if logits.dim() == 2:
            logits = logits.unsqueeze(1)  # [B, 1, V]
            squeeze = True

        kmer_logits = logits[:, :, self._kmer_ids]  # [B, L, num_kmers]
        kmer_probs = F.softmax(kmer_logits.float(), dim=-1)

        B, L, _ = kmer_probs.shape
        bp_probs = torch.zeros(
            B, L, self.k, 4, device=logits.device, dtype=kmer_probs.dtype
        )
        for pos in range(self.k):
            idx = self._bp_base_index[pos]  # [num_kmers] -> 0~3
            for nt in range(4):
                # Marginal: total mass of all k-mers carrying `nt` at `pos`.
                bp_probs[:, :, pos, nt] = kmer_probs[:, :, idx == nt].sum(dim=-1)

        if squeeze:
            bp_probs = bp_probs.squeeze(1)  # [B, k, 4]
        return bp_probs

    # -------------------------------------------------------------------------
    # generate_bp: sets a flag then delegates to the standard generate()
    # -------------------------------------------------------------------------
    @torch.no_grad()
    def generate_bp(self, inputs=None, generation_config=None, **kwargs):
        """Same interface as generate(), but with bp-level independent base selection.

        Token logits are marginalized to per-base probabilities [k, 4], and each
        base position is selected independently. All standard generate()
        parameters (temperature, top_k, top_p, do_sample, attention_mask, etc.)
        are processed by the HF generate pipeline as usual.

        The bp behaviour is implemented by setting ``self._bp_generation`` for
        the duration of the call; only the ``_sample`` override in this class
        reads that flag.

        Returns:
            Same as generate() — token ids tensor or GenerateOutput.
        """
        assert hasattr(self, "_bp_base_index"), "Call setup_tokenizer() first"
        self._bp_generation = True
        try:
            return super().generate(
                inputs=inputs, generation_config=generation_config, **kwargs
            )
        finally:
            # Always reset so a later plain generate() is unaffected.
            self._bp_generation = False

    # -------------------------------------------------------------------------
    # Override _sample: when _bp_generation is set, use bp-level token selection
    # -------------------------------------------------------------------------
    def _sample(
        self,
        input_ids,
        logits_processor,
        stopping_criteria,
        generation_config,
        synced_gpus,
        streamer,
        **model_kwargs,
    ):
        """Decoding loop; defers to the parent unless bp generation is active.

        When ``self._bp_generation`` is falsy or missing, this is a pure
        pass-through to ``super()._sample``. Otherwise it runs a copy of the
        upstream loop in which the next token is assembled from ``k``
        independently chosen bases (sampled if ``generation_config.do_sample``,
        arg-max otherwise).

        Returns:
            ``GenerateDecoderOnlyOutput`` when
            ``generation_config.return_dict_in_generate`` is set, otherwise the
            extended ``input_ids`` tensor.
        """
        if not getattr(self, "_bp_generation", False):
            return super()._sample(
                input_ids,
                logits_processor,
                stopping_criteria,
                generation_config,
                synced_gpus,
                streamer,
                **model_kwargs,
            )

        # ==================================================================
        # BP generation mode — copied from transformers 4.56.0 _sample(),
        # with ONLY the token selection block replaced by bp marginalization.
        # ==================================================================
        from transformers.generation.utils import (
            GenerateDecoderOnlyOutput,
        )

        # init values
        pad_token_id = generation_config._pad_token_tensor
        output_attentions = generation_config.output_attentions
        output_hidden_states = generation_config.output_hidden_states
        output_scores = generation_config.output_scores
        output_logits = generation_config.output_logits
        return_dict_in_generate = generation_config.return_dict_in_generate
        has_eos_stopping_criteria = any(
            hasattr(criteria, "eos_token_id") for criteria in stopping_criteria
        )
        do_sample = generation_config.do_sample

        # init attention / hidden states / scores tuples
        scores = () if (return_dict_in_generate and output_scores) else None
        raw_logits = () if (return_dict_in_generate and output_logits) else None
        decoder_attentions = (
            () if (return_dict_in_generate and output_attentions) else None
        )
        decoder_hidden_states = (
            () if (return_dict_in_generate and output_hidden_states) else None
        )

        # keep track of which sequences are already finished
        batch_size, cur_len = input_ids.shape[:2]
        this_peer_finished = False
        unfinished_sequences = torch.ones(
            batch_size, dtype=torch.long, device=input_ids.device
        )
        model_kwargs = self._get_initial_cache_position(
            cur_len, input_ids.device, model_kwargs
        )

        model_forward = self.__call__
        compile_forward = self._valid_auto_compile_criteria(
            model_kwargs, generation_config
        )
        if compile_forward:
            os.environ["TOKENIZERS_PARALLELISM"] = "0"
            if self.config._attn_implementation == "flash_attention_2":
                if (
                    generation_config.compile_config is not None
                    and generation_config.compile_config.fullgraph
                ):
                    generation_config.compile_config.fullgraph = False
            model_forward = self.get_compiled_call(generation_config.compile_config)

        if generation_config.prefill_chunk_size is not None:
            model_kwargs = self._prefill_chunking(
                input_ids, generation_config, **model_kwargs
            )
            is_prefill = False
        else:
            is_prefill = True

        while self._has_unfinished_sequences(
            this_peer_finished, synced_gpus, device=input_ids.device
        ):
            # prepare model inputs
            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

            # prepare variable output controls
            model_inputs.update(
                {"output_attentions": output_attentions} if output_attentions else {}
            )
            model_inputs.update(
                {"output_hidden_states": output_hidden_states}
                if output_hidden_states
                else {}
            )

            if is_prefill:
                outputs = self(**model_inputs, return_dict=True)
                is_prefill = False
            else:
                outputs = model_forward(**model_inputs, return_dict=True)

            # update model kwargs for next step (handles cache, attention_mask, etc.)
            model_kwargs = self._update_model_kwargs_for_generation(
                outputs,
                model_kwargs,
                is_encoder_decoder=self.config.is_encoder_decoder,
            )
            if synced_gpus and this_peer_finished:
                continue

            next_token_logits = outputs.logits[:, -1, :].to(
                copy=True, dtype=torch.float32, device=input_ids.device
            )

            # pre-process distribution (temperature, top_k, top_p, repetition_penalty, etc.)
            next_token_scores = logits_processor(input_ids, next_token_logits)

            # Store scores, attentions and hidden_states when required
            if return_dict_in_generate:
                if output_scores:
                    scores += (next_token_scores,)
                if output_logits:
                    raw_logits += (next_token_logits,)
                # One entry per generated step; each entry is the model's own
                # attentions / hidden_states object, without an extra wrapping
                # tuple, so the nesting matches what generate() returns.
                if output_attentions:
                    decoder_attentions += (outputs.attentions,)
                if output_hidden_states:
                    decoder_hidden_states += (outputs.hidden_states,)

            # =============================================================
            # BP-LEVEL TOKEN SELECTION (vectorized, the ONLY change)
            # =============================================================
            # [B, V] -> [B, k, 4] marginal bp probabilities
            bp_probs = self.compute_bp_probs(next_token_scores)  # [B, k, 4]

            if do_sample:
                # [B*k, 4] -> multinomial -> [B, k]
                base_indices = torch.multinomial(
                    bp_probs.view(-1, 4), 1
                ).view(batch_size, self.k)
            else:
                base_indices = bp_probs.argmax(dim=-1)  # [B, k]

            # base_indices [B, k] -> flat kmer index -> token_id [B]
            flat_idx = (base_indices * self._bp_powers).sum(dim=-1)  # [B]
            next_tokens = self._flat_idx_to_token_id[flat_idx]  # [B]
            # =============================================================

            # finished sentences should have their next token be a padding token
            if has_eos_stopping_criteria:
                next_tokens = next_tokens * unfinished_sequences + pad_token_id * (
                    1 - unfinished_sequences
                )

            # update generated ids, model inputs, and length for next step
            input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
            if streamer is not None:
                streamer.put(next_tokens.cpu())

            unfinished_sequences = unfinished_sequences & ~stopping_criteria(
                input_ids, scores
            )
            this_peer_finished = unfinished_sequences.max() == 0
            cur_len += 1

            # Drop the reference so the step's logits can be freed before the
            # next forward pass.
            del outputs

        if streamer is not None:
            streamer.end()

        if return_dict_in_generate:
            return GenerateDecoderOnlyOutput(
                sequences=input_ids,
                scores=scores,
                logits=raw_logits,
                attentions=decoder_attentions,
                hidden_states=decoder_hidden_states,
                past_key_values=model_kwargs.get("past_key_values"),
            )
        else:
            return input_ids

    @torch.no_grad()
    def score_sequence(self, sequences: Union[str, list[str]]):
        """Score DNA sequence(s) and return per-base conditional probabilities.

        Each sequence is padded with 'A' up to a multiple of k and manually
        prefixed with the tokenizer's BOS token (``tokenizer.bos_token`` when
        the tokenizer exposes a non-empty one; otherwise nothing is prefixed).
        Returns probabilities for the original sequences only (excluding
        padding).

        Args:
            sequences: Single DNA sequence string or list of sequences

        Returns:
            Tuple of (bp_probs, actual_probs):
            - bp_probs: Full probability distribution
                * Single sequence: [seq_len, 4] tensor
                * Batch: list of [seq_len_i, 4] tensors
            - actual_probs: Probability of the actual base at each position
                * Single sequence: [seq_len] tensor
                * Batch: list of [seq_len_i] tensors

            bp_probs[i, j] = P(base at position i is nucleotide j | context)
            actual_probs[i] = P(actual base at position i | context)
            where j: 0=A, 1=T, 2=C, 3=G

            An empty batch returns ``([], [])``.

        Example:
            # Single sequence
            bp_probs, actual_probs = model.score_sequence("ACGT")

            # Batch of sequences
            bp_probs_list, actual_probs_list = model.score_sequence([
                "ACGT" * 150,
                "ACGT" * 149 + "AC",
            ])
        """
        assert hasattr(self, "tokenizer"), "Call setup_tokenizer() first"

        # Handle single sequence case
        is_single = isinstance(sequences, str)
        if is_single:
            sequences = [sequences]

        # Snapshot the inputs; list() also accepts tuples and other iterables.
        original_sequences = list(sequences)
        if not original_sequences:
            return [], []
        original_lens = [len(seq) for seq in original_sequences]

        # Pad each sequence to multiple of k with 'A'
        padded_sequences = []
        for seq in original_sequences:
            remainder = len(seq) % self.k
            if remainder:
                seq = seq + "A" * (self.k - remainder)
            padded_sequences.append(seq)

        # Manually prepend the BOS token so that the logits at position 0
        # predict the first sequence token. Falls back to an empty prefix when
        # the tokenizer does not expose a BOS token.
        bos_token = str(getattr(self.tokenizer, "bos_token", None) or "")
        sequences_with_bos = [bos_token + seq for seq in padded_sequences]

        # Tokenize batch (without add_special_tokens since we added manually)
        inputs = self.tokenizer(
            sequences_with_bos,
            return_tensors="pt",
            padding=True,
            add_special_tokens=False,
        )
        input_ids = inputs["input_ids"].to(self.device)
        attention_mask = inputs["attention_mask"].to(self.device)

        # Forward pass to get logits for all positions
        outputs = self(input_ids, attention_mask=attention_mask, return_dict=True)
        logits = outputs.logits  # [B, max_seq_len, vocab_size]

        # Compute bp probabilities for all token positions
        bp_probs = self.compute_bp_probs(logits)  # [B, max_seq_len, k, 4]

        # Process each sequence in the batch
        bp_probs_results = []
        actual_probs_results = []
        for i, (original_seq, original_len, padded_seq) in enumerate(
            zip(original_sequences, original_lens, padded_sequences)
        ):
            # Calculate number of actual sequence tokens (excluding BOS)
            num_seq_tokens = len(padded_seq) // self.k

            # logits[0] predicts token after BOS (first sequence token)
            # logits[t] predicts token[t+1]
            # So logits[0:num_seq_tokens] predict the sequence tokens.
            # NOTE(review): this slice starts at index 0, i.e. it presumes the
            # tokenizer pads on the right — confirm against the tokenizer.
            seq_bp_probs = bp_probs[i, :num_seq_tokens]  # [num_seq_tokens, k, 4]

            # Reshape: [num_seq_tokens, k, 4] -> [num_seq_tokens * k, 4]
            seq_result = seq_bp_probs.reshape(-1, 4)

            # Trim to original sequence length (remove padding)
            seq_result = seq_result[:original_len]

            # Extract actual base probabilities
            actual_probs = self._extract_actual_probs(seq_result, original_seq)

            bp_probs_results.append(seq_result)
            actual_probs_results.append(actual_probs)

        # Return single tensors if input was single sequence
        if is_single:
            return bp_probs_results[0], actual_probs_results[0]
        return bp_probs_results, actual_probs_results

    def _extract_actual_probs(self, bp_probs: torch.Tensor, sequence: str):
        """Extract probabilities of actual bases in the sequence.

        For each position i in the sequence, returns the probability that the
        model assigned to the actual base at that position. For 'N' bases
        (unknown), returns the maximum probability across all 4 bases.

        Args:
            bp_probs: [seq_len, 4] probability distribution from logits
                     bp_probs[i] = P(position i | context before i)
            sequence: DNA sequence string (may contain 'N')

        Returns:
            actual_probs: [seq_len] probabilities of actual bases
                         actual_probs[i] = bp_probs[i, sequence[i]] for A/T/C/G
                         actual_probs[i] = max(bp_probs[i]) for 'N'

        Raises:
            KeyError: if ``sequence`` contains a character that is not a key
                of ``BASE_TO_IDX``.
        """
        base_idx = torch.tensor(
            [BASE_TO_IDX[base] for base in sequence],
            dtype=torch.long,
            device=bp_probs.device,
        )
        rows = bp_probs[: len(sequence)]

        # 'N' is encoded as -1; clamp it to a valid column for the gather and
        # overwrite those positions with the row maximum afterwards.
        is_unknown = base_idx < 0
        picked = rows.gather(1, base_idx.clamp(min=0).unsqueeze(1)).squeeze(1)
        row_max = rows.max(dim=-1).values
        return torch.where(is_unknown, row_max, picked)