Carbon-3B / modeling_carbon.py
GenerTeam's picture
Upload modeling_carbon.py
01fd9be verified
raw
history blame
17.3 kB
"""
Carbon with bp_probs generation support.
generate_bp() reuses the full HF generate() pipeline (parameter preparation,
cache management, stopping criteria, logits processing, etc.) and only replaces
the token selection step with bp-level independent base selection.
"""
import os
from typing import Optional, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import LlamaForCausalLM
# Nucleotide <-> integer encoding used for k-mer flat indices and per-base
# probability axes. "N" (unknown base) uses the sentinel -1, which lies outside
# the 0..3 range of real bases.
BASE_TO_IDX = dict(A=0, T=1, C=2, G=3, N=-1)
IDX_TO_BASE = {idx: base for base, idx in BASE_TO_IDX.items()}
class CarbonForCausalLM(LlamaForCausalLM):
    """LlamaForCausalLM with bp-level autoregressive generation.

    Inherits all standard functionality (forward, generate, etc.) and adds:

    * ``generate_bp()``: generation where each of the ``k`` bases of the next
      k-mer token is chosen independently from its marginal distribution.
    * ``score_sequence()``: per-base conditional probabilities for DNA strings.

    ``setup_tokenizer()`` must be called before using either of them.
    """

    def setup_tokenizer(self, tokenizer):
        """Cache the tokenizer and precompute lookup tables for bp generation.

        The "flat index" of a k-mer is its base-4 value with the first base as
        the most significant digit, using the ``BASE_TO_IDX`` encoding
        (A=0, T=1, C=2, G=3).

        Args:
            tokenizer: k-mer tokenizer exposing ``k``, ``get_kmer_ids()`` and
                ``get_kmers()``. The two lists are assumed to be aligned
                element-wise (TODO confirm against the tokenizer implementation).
                k-mers that are not length-``k`` strings over A/T/C/G (e.g.
                ones containing 'N') are ignored: they cannot be expressed as
                a combination of independent per-base choices.

        Raises:
            ValueError: if some of the ``4 ** k`` A/T/C/G k-mers has no token.
        """
        self.tokenizer = tokenizer
        k = tokenizer.k
        self.k = k
        num_kmers = 4 ** k
        self._kmer_ids = tokenizer.get_kmer_ids()
        self._kmers = tokenizer.get_kmers()
        device = next(self.parameters()).device

        # _bp_base_index[j, f] = base index (0..3) at position j of the k-mer
        # whose flat index is f.
        flat = torch.arange(num_kmers)
        bp_base_index = torch.zeros(k, num_kmers, dtype=torch.long)
        for j in range(k):
            bp_base_index[j] = (flat >> ((k - 1 - j) * 2)) & 3
        self.register_buffer("_bp_base_index", bp_base_index.to(device), persistent=False)

        # Weights 4**(k-1), ..., 4**0 that turn per-base indices into a flat
        # index. Registered as a buffer so it follows the model on .to()/.cuda().
        self.register_buffer(
            "_bp_powers",
            torch.tensor(
                [4 ** i for i in range(k - 1, -1, -1)], dtype=torch.long, device=device
            ),
            persistent=False,
        )

        # flat index -> vocabulary token id; -1 marks "not provided by tokenizer".
        flat_to_tid = torch.full((num_kmers,), -1, dtype=torch.long)
        for kmer, tid in zip(self._kmers, self._kmer_ids):
            if len(kmer) != k or any(c not in "ATCG" for c in kmer):
                continue
            idx = sum(BASE_TO_IDX[c] * (4 ** (k - 1 - i)) for i, c in enumerate(kmer))
            flat_to_tid[idx] = tid
        missing = int((flat_to_tid < 0).sum())
        if missing:
            raise ValueError(
                f"Tokenizer is missing {missing} of the {num_kmers} A/T/C/G k-mers for k={k}"
            )
        self.register_buffer(
            "_flat_idx_to_token_id", flat_to_tid.to(device), persistent=False
        )

    def compute_bp_probs(self, logits):
        """Marginalize token logits into per-base probabilities (vectorized).

        A softmax is taken over the ``4 ** k`` k-mer tokens only (all other
        vocabulary entries are excluded). The probability of base ``nt`` at
        position ``pos`` is the summed probability of every k-mer carrying
        ``nt`` at ``pos``.

        Args:
            logits: [B, V] or [B, L, V] token logits.

        Returns:
            float32 bp_probs of shape [B, k, 4] or [B, L, k, 4]; the last axis
            follows ``BASE_TO_IDX`` (0=A, 1=T, 2=C, 3=G).
        """
        squeeze = logits.dim() == 2
        if squeeze:
            logits = logits.unsqueeze(1)  # [B, 1, V]
        # Gather k-mer logits in canonical flat order, independent of the order
        # in which the tokenizer lists its k-mers, so the reshape below is valid.
        kmer_ids = self._flat_idx_to_token_id.to(logits.device)
        kmer_probs = F.softmax(logits[..., kmer_ids].float(), dim=-1)  # [B, L, 4**k]
        batch, length, _ = kmer_probs.shape
        k = self.k
        per_position = []
        for pos in range(k):
            # In flat order the base at `pos` is the base-4 digit of weight
            # 4**(k-1-pos): split the axis into (prefix, base, suffix) and sum
            # out prefix and suffix.
            grouped = kmer_probs.reshape(batch, length, 4 ** pos, 4, 4 ** (k - 1 - pos))
            per_position.append(grouped.sum(dim=(2, 4)))
        bp_probs = torch.stack(per_position, dim=2)  # [B, L, k, 4]
        if squeeze:
            bp_probs = bp_probs.squeeze(1)  # [B, k, 4]
        return bp_probs

    # -------------------------------------------------------------------------
    # generate_bp: sets a flag then delegates to the standard generate()
    # -------------------------------------------------------------------------
    @torch.no_grad()
    def generate_bp(self, inputs=None, generation_config=None, **kwargs):
        """Same interface as generate(), but with bp-level independent base selection.

        Token logits are marginalized to per-base probabilities [k, 4], and each
        base position is selected independently. All standard generate()
        parameters (temperature, top_k, top_p, do_sample, attention_mask, etc.)
        are processed by the HF generate pipeline as usual.

        Only decoding paths that HF routes through ``_sample()`` (greedy and
        sampling) are affected; other modes (e.g. beam search) are not
        overridden here. In bp mode every emitted token is a k-mer token, so
        the model never emits special tokens such as EOS itself; generation
        stops through the other stopping criteria (e.g. max length).

        Returns:
            Same as generate(): token ids tensor or GenerateOutput.

        Raises:
            RuntimeError: if setup_tokenizer() has not been called.
        """
        if not hasattr(self, "_bp_base_index"):
            raise RuntimeError("Call setup_tokenizer() first")
        # Restore the previous value (rather than forcing False) so nested
        # calls do not switch off an outer bp generation.
        previous = getattr(self, "_bp_generation", False)
        self._bp_generation = True
        try:
            return super().generate(
                inputs=inputs, generation_config=generation_config, **kwargs
            )
        finally:
            self._bp_generation = previous

    def _select_bp_tokens(self, next_token_scores, do_sample):
        """Choose the next k-mer token per row by picking each base independently.

        Args:
            next_token_scores: [B, V] processed next-token scores.
            do_sample: sample each base from its marginal if True, else argmax.

        Returns:
            [B] long tensor of k-mer token ids on ``next_token_scores.device``.
        """
        bp_probs = self.compute_bp_probs(next_token_scores)  # [B, k, 4]
        if do_sample:
            # [B*k, 4] -> multinomial -> [B, k]
            base_indices = torch.multinomial(bp_probs.reshape(-1, 4), 1).view(
                bp_probs.shape[0], self.k
            )
        else:
            base_indices = bp_probs.argmax(dim=-1)  # [B, k]
        powers = self._bp_powers.to(base_indices.device)
        flat_idx = (base_indices * powers).sum(dim=-1)  # [B]
        return self._flat_idx_to_token_id.to(flat_idx.device)[flat_idx]  # [B]

    # -------------------------------------------------------------------------
    # Override _sample: when _bp_generation is set, use bp-level token selection
    # -------------------------------------------------------------------------
    def _sample(
        self,
        input_ids,
        logits_processor,
        stopping_criteria,
        generation_config,
        synced_gpus,
        streamer,
        **model_kwargs,
    ):
        if not getattr(self, "_bp_generation", False):
            return super()._sample(
                input_ids,
                logits_processor,
                stopping_criteria,
                generation_config,
                synced_gpus,
                streamer,
                **model_kwargs,
            )
        # ==================================================================
        # BP generation mode: copied from transformers 4.56.0 _sample(),
        # with ONLY the token selection block replaced by bp marginalization.
        # ==================================================================
        from transformers.generation.utils import (
            GenerateDecoderOnlyOutput,
        )

        # init values
        pad_token_id = generation_config._pad_token_tensor
        output_attentions = generation_config.output_attentions
        output_hidden_states = generation_config.output_hidden_states
        output_scores = generation_config.output_scores
        output_logits = generation_config.output_logits
        return_dict_in_generate = generation_config.return_dict_in_generate
        has_eos_stopping_criteria = any(
            hasattr(criteria, "eos_token_id") for criteria in stopping_criteria
        )
        do_sample = generation_config.do_sample

        # init attention / hidden states / scores tuples
        scores = () if (return_dict_in_generate and output_scores) else None
        raw_logits = () if (return_dict_in_generate and output_logits) else None
        decoder_attentions = (
            () if (return_dict_in_generate and output_attentions) else None
        )
        decoder_hidden_states = (
            () if (return_dict_in_generate and output_hidden_states) else None
        )

        # keep track of which sequences are already finished
        batch_size, cur_len = input_ids.shape[:2]
        this_peer_finished = False
        unfinished_sequences = torch.ones(
            batch_size, dtype=torch.long, device=input_ids.device
        )
        model_kwargs = self._get_initial_cache_position(
            cur_len, input_ids.device, model_kwargs
        )

        model_forward = self.__call__
        compile_forward = self._valid_auto_compile_criteria(
            model_kwargs, generation_config
        )
        if compile_forward:
            os.environ["TOKENIZERS_PARALLELISM"] = "0"
            # FA2 with a static cache cannot be compiled with fullgraph.
            if self.config._attn_implementation == "flash_attention_2":
                if (
                    generation_config.compile_config is not None
                    and generation_config.compile_config.fullgraph
                ):
                    generation_config.compile_config.fullgraph = False
            model_forward = self.get_compiled_call(generation_config.compile_config)

        if generation_config.prefill_chunk_size is not None:
            model_kwargs = self._prefill_chunking(
                input_ids, generation_config, **model_kwargs
            )
            is_prefill = False
        else:
            is_prefill = True

        while self._has_unfinished_sequences(
            this_peer_finished, synced_gpus, device=input_ids.device
        ):
            # prepare model inputs
            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

            # prepare variable output controls
            model_inputs.update(
                {"output_attentions": output_attentions} if output_attentions else {}
            )
            model_inputs.update(
                {"output_hidden_states": output_hidden_states}
                if output_hidden_states
                else {}
            )

            if is_prefill:
                outputs = self(**model_inputs, return_dict=True)
                is_prefill = False
            else:
                outputs = model_forward(**model_inputs, return_dict=True)

            # update model kwargs for next step (handles cache, attention_mask, etc.)
            model_kwargs = self._update_model_kwargs_for_generation(
                outputs,
                model_kwargs,
                is_encoder_decoder=self.config.is_encoder_decoder,
            )
            if synced_gpus and this_peer_finished:
                continue

            # Copy so no reference to the (possibly large) outputs.logits is kept.
            next_token_logits = outputs.logits[:, -1, :].to(
                copy=True, dtype=torch.float32, device=input_ids.device
            )

            # pre-process distribution (temperature, top_k, top_p, repetition_penalty, etc.)
            next_token_scores = logits_processor(input_ids, next_token_logits)

            # Store scores, attentions and hidden_states when required.
            # One entry per step, matching the structure HF generate() returns.
            if return_dict_in_generate:
                if output_scores:
                    scores += (next_token_scores,)
                if output_logits:
                    raw_logits += (next_token_logits,)
                if output_attentions:
                    decoder_attentions += (outputs.attentions,)
                if output_hidden_states:
                    decoder_hidden_states += (outputs.hidden_states,)

            # =============================================================
            # BP-LEVEL TOKEN SELECTION (the ONLY change vs. HF _sample)
            # =============================================================
            next_tokens = self._select_bp_tokens(next_token_scores, do_sample)  # [B]

            # finished sentences should have their next token be a padding token
            if has_eos_stopping_criteria:
                next_tokens = next_tokens * unfinished_sequences + pad_token_id * (
                    1 - unfinished_sequences
                )

            # update generated ids, model inputs, and length for next step
            input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
            if streamer is not None:
                streamer.put(next_tokens.cpu())

            unfinished_sequences = unfinished_sequences & ~stopping_criteria(
                input_ids, scores
            )
            this_peer_finished = unfinished_sequences.max() == 0
            cur_len += 1

            # Drop the reference so the first-step logits can be freed.
            del outputs

        if streamer is not None:
            streamer.end()

        if return_dict_in_generate:
            return GenerateDecoderOnlyOutput(
                sequences=input_ids,
                scores=scores,
                logits=raw_logits,
                attentions=decoder_attentions,
                hidden_states=decoder_hidden_states,
                past_key_values=model_kwargs.get("past_key_values"),
            )
        else:
            return input_ids

    @torch.no_grad()
    def score_sequence(self, sequences: Union[str, list[str]]):
        """Score DNA sequence(s) and return per-base conditional probabilities.

        Each sequence is prepended with the BOS token ("<dna>") and right-padded
        with 'A' if its length is not a multiple of k. Probabilities are
        returned for the original sequences only (padding excluded).

        The probabilities of the k bases inside one k-mer token are marginals
        conditioned on the preceding tokens only; they are not conditioned on
        each other.

        Args:
            sequences: Single DNA sequence string or an iterable of sequences.

        Returns:
            Tuple of (bp_probs, actual_probs):
            - bp_probs: Full probability distribution
                * Single sequence: [seq_len, 4] tensor
                * Batch: list of [seq_len_i, 4] tensors
            - actual_probs: Probability of the actual base at each position
                * Single sequence: [seq_len] tensor
                * Batch: list of [seq_len_i] tensors
            bp_probs[i, j] = P(base at position i is nucleotide j | preceding tokens)
            actual_probs[i] = P(actual base at position i | preceding tokens)
            where j: 0=A, 1=T, 2=C, 3=G

        Raises:
            RuntimeError: if setup_tokenizer() has not been called.
            KeyError: if a sequence contains a character outside A/T/C/G/N.

        Example:
            # Single sequence
            bp_probs, actual_probs = model.score_sequence("ACGT")
            # Batch of sequences
            bp_probs_list, actual_probs_list = model.score_sequence([
                "ACGT" * 150,
                "ACGT" * 149 + "AC",
            ])
        """
        if not hasattr(self, "tokenizer"):
            raise RuntimeError("Call setup_tokenizer() first")

        is_single = isinstance(sequences, str)
        original_sequences = [sequences] if is_single else list(sequences)
        if not original_sequences:
            return [], []

        k = self.k
        # Pad each sequence with 'A' to a multiple of k so it tokenizes into whole k-mers.
        padded_sequences = [seq + "A" * (-len(seq) % k) for seq in original_sequences]

        # Prepend BOS manually, hence add_special_tokens=False.
        inputs = self.tokenizer(
            ["<dna>" + seq for seq in padded_sequences],
            return_tensors="pt",
            padding=True,
            add_special_tokens=False,
        )
        input_ids = inputs["input_ids"].to(self.device)
        attention_mask = inputs["attention_mask"].to(self.device)
        # Positions derived from the mask so that left padding (if the tokenizer
        # uses it) does not shift positions of real tokens. With right padding
        # these equal the default 0..T-1 for every real token.
        position_ids = attention_mask.long().cumsum(dim=-1) - 1
        position_ids = position_ids.masked_fill(attention_mask == 0, 1)

        # Forward pass to get logits for all positions
        outputs = self(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            return_dict=True,
        )
        bp_probs = self.compute_bp_probs(outputs.logits)  # [B, max_seq_len, k, 4]
        token_mask = attention_mask.to(bp_probs.device).bool()

        bp_probs_results = []
        actual_probs_results = []
        for row, (original_seq, padded_seq) in enumerate(
            zip(original_sequences, padded_sequences)
        ):
            num_seq_tokens = len(padded_seq) // k
            # Real (non-pad) positions in order: position 0 is BOS and predicts
            # the first k-mer; position t predicts k-mer t+1. So the first
            # num_seq_tokens real positions predict exactly this sequence's k-mers.
            seq_bp_probs = bp_probs[row][token_mask[row]][:num_seq_tokens]  # [n, k, 4]
            # [n, k, 4] -> [n * k, 4], then drop the 'A' padding bases.
            seq_result = seq_bp_probs.reshape(-1, 4)[: len(original_seq)]
            bp_probs_results.append(seq_result)
            actual_probs_results.append(
                self._extract_actual_probs(seq_result, original_seq)
            )

        if is_single:
            return bp_probs_results[0], actual_probs_results[0]
        return bp_probs_results, actual_probs_results

    def _extract_actual_probs(self, bp_probs: torch.Tensor, sequence: str) -> torch.Tensor:
        """Extract the probability assigned to the actual base at each position.

        For 'N' (unknown) bases, the maximum probability across the 4 bases is
        returned instead.

        Args:
            bp_probs: [seq_len, 4] per-base probability distribution
                (rows beyond len(sequence) are ignored).
            sequence: DNA sequence string over A/T/C/G/N.

        Returns:
            actual_probs: [seq_len] tensor with the dtype/device of bp_probs;
                actual_probs[i] = bp_probs[i, sequence[i]] for A/T/C/G
                actual_probs[i] = max(bp_probs[i]) for 'N'

        Raises:
            KeyError: for characters not in BASE_TO_IDX.
        """
        base_idx = torch.tensor(
            [BASE_TO_IDX[base] for base in sequence],
            dtype=torch.long,
            device=bp_probs.device,
        )
        rows = bp_probs[: len(sequence)]
        is_unknown = base_idx < 0  # only 'N' encodes to -1
        # Single gather/where instead of one tensor write per base.
        picked = rows.gather(1, base_idx.clamp(min=0).unsqueeze(1)).squeeze(1)
        return torch.where(is_unknown, rows.max(dim=-1).values, picked)