| |
| """ |
| Extract reverse complement embeddings for TBX5 motif data using Evo2 40B model. |
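| - Extract embeddings from the layer named by LAYER_NAME (currently `blocks.26.mlp.l3`) |
| - Locate each motif inside an 8192bp window around the motif site; only the 61bp motif itself is passed to the model |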
| - Average embeddings for 61bp sequences (reverse complement) |
| - Create 4096 dimensional feature vector for each motif |
| """ |
|
|
| import pandas as pd |
| import numpy as np |
| import torch |
| import gzip |
| from Bio import SeqIO |
| from Bio.Seq import Seq |
| from evo2 import Evo2 |
| import pickle |
| from tqdm import tqdm |
| import os |
| import sys |
| import argparse |
|
|
| |
# Register tqdm's pandas integration (adds `progress_apply` and friends to
# pandas objects).
tqdm.pandas()


# Width, in bp, of the window cut around each motif; default for
# get_sequence_window and recorded in the saved metadata.
WINDOW_SIZE = 8192
# Name of the model layer whose output is requested as the embedding.
LAYER_NAME = "blocks.26.mlp.l3"
# Every motif interval is normalized to this many bp before it is embedded.
SEQUENCE_LENGTH = 61
# NOTE(review): not referenced anywhere in the visible code - confirm whether
# it is still needed.
BATCH_SIZE = 8
|
|
def get_reverse_complement(sequence):
    """Return the reverse complement of a DNA sequence as a plain string.

    Args:
        sequence: DNA sequence in a form accepted by ``Bio.Seq.Seq``.

    Returns:
        The reverse-complemented sequence converted to ``str``.
    """
    dna = Seq(sequence)
    reverse_complemented = dna.reverse_complement()
    return str(reverse_complemented)
|
|
def load_fasta(fasta_path, chromosome):
    """Read the first record of a gzipped FASTA file as an upper-case string.

    Args:
        fasta_path: Path to a gzip-compressed FASTA file.
        chromosome: Chromosome label; only used in the progress messages.

    Returns:
        The upper-cased sequence of the first record in the file, or None
        when the file contains no records.
    """
    print(f"Loading chromosome {chromosome} FASTA file...")
    with gzip.open(fasta_path, "rt") as handle:
        # Only the first record is ever used; later records are ignored.
        first_record = next(SeqIO.parse(handle, "fasta"), None)
        if first_record is None:
            return None
        sequence = str(first_record.seq).upper()
    print(f"Loaded chromosome {chromosome}, length: {len(sequence):,} bp")
    return sequence
|
|
def normalize_sequence_length(df):
    """Return a copy of ``df`` whose motif intervals are all SEQUENCE_LENGTH bp.

    Coordinates are treated as 1-based and inclusive (an interval's length is
    ``end - start + 1``). Intervals are resized roughly symmetrically:

    * too short: the start moves left by half of the missing length
      (floored), clamped so it never drops below position 1;
    * too long: the start moves right by half of the excess (floored).

    In both cases the end is recomputed as ``start + SEQUENCE_LENGTH - 1``.
    Rows that already have the target length are left untouched. The end is
    not clamped here because the chromosome length is not known.

    Args:
        df: DataFrame with numeric ``start`` and ``end`` columns.

    Returns:
        A new DataFrame in which ``start``, ``end`` and ``length`` have been
        rewritten on every row that had to be resized. ``df`` itself is not
        modified.
    """
    print("Normalizing sequence lengths to 61bp...")

    df_normalized = df.copy()

    starts = df_normalized['start'].to_numpy()
    ends = df_normalized['end'].to_numpy()

    # Signed distance from the target length: negative = too short,
    # positive = too long, zero = already correct.
    surplus = (ends - starts + 1) - SEQUENCE_LENGTH

    # The two branches are kept separate because floor division rounds
    # differently for negative numbers: (-surplus) // 2 != -(surplus // 2).
    # The lower bound is 1, not 0: positions are 1-based, and a start of 0
    # would become index -1 once converted to 0-based downstream.
    resized_starts = np.where(
        surplus < 0,
        np.maximum(1, starts - (-surplus) // 2),
        starts + surplus // 2,
    )

    needs_resize = surplus != 0
    # Guarded so that a missing 'length' column is only created when at least
    # one row is actually rewritten.
    if needs_resize.any():
        new_starts = resized_starts[needs_resize]
        df_normalized.loc[needs_resize, 'start'] = new_starts
        df_normalized.loc[needs_resize, 'end'] = new_starts + SEQUENCE_LENGTH - 1
        df_normalized.loc[needs_resize, 'length'] = SEQUENCE_LENGTH

    print(f"Normalized {len(df_normalized)} sequences to {SEQUENCE_LENGTH}bp")
    return df_normalized
|
|
def get_sequence_window(chr_seq, start, end, window_size=WINDOW_SIZE):
    """
    Cut a window out of the chromosome, centred on a motif.

    The window spans ``window_size // 2`` bases on either side of the motif's
    midpoint and is truncated at the chromosome boundaries, so it can be
    shorter than ``window_size`` near either end.

    Args:
        chr_seq: Full chromosome sequence
        start: Start position of motif (1-based)
        end: End position of motif (1-based)
        window_size: Size of window around motif (default 8192bp)

    Returns:
        A 3-tuple ``(seq_window, motif_start_in_window, motif_end_in_window)``:
        the extracted window, followed by the 0-based offsets of the motif's
        first and last base relative to the start of that window.
    """
    # Switch to 0-based indexing for slicing.
    first, last = start - 1, end - 1

    midpoint = (first + last) // 2
    reach = window_size // 2

    # Clamp the window to the chromosome.
    left = max(0, midpoint - reach)
    right = min(len(chr_seq), midpoint + reach)

    return chr_seq[left:right], first - left, last - left
|
|
def extract_embeddings_batch(model, sequences, layer_name=LAYER_NAME, device="cuda:0"):
    """
    Extract one mean-pooled embedding per input sequence.

    Despite the name, sequences are run through the model one at a time;
    each forward pass sees a batch of size 1.

    Args:
        model: Evo2 model (must expose ``tokenizer.tokenize`` and be callable
            with ``return_embeddings`` / ``layer_names``)
        sequences: List of DNA sequences
        layer_name: Name of layer to extract embeddings from
        device: Torch device the token ids are moved to before the forward
            pass. Defaults to ``"cuda:0"``, the previously hard-coded value.

    Returns:
        A NumPy array with one row per input sequence, holding that
        sequence's layer output averaged over dim 1.

    Raises:
        ValueError: If ``sequences`` is empty (``np.vstack`` has nothing to
            stack).
    """
    pooled = []

    for seq in sequences:
        token_ids = torch.tensor(model.tokenizer.tokenize(seq), dtype=torch.int)
        # Add a leading batch axis of size 1.
        input_ids = token_ids.unsqueeze(0).to(device)

        with torch.no_grad():
            _, layer_outputs = model(
                input_ids, return_embeddings=True, layer_names=[layer_name]
            )

        # Average over dim 1 - presumably the token axis of a
        # (batch, tokens, hidden) tensor; TODO confirm against the Evo2 API.
        # Cast to float32 before leaving torch so NumPy can represent it.
        pooled.append(layer_outputs[layer_name].mean(dim=1).float().cpu().numpy())

    return np.vstack(pooled)
|
|
def process_motifs(model, chr_seq, motif_df, chromosome):
    """
    Embed the reverse complement of every motif in ``motif_df``.

    For each row the motif is located through get_sequence_window, sliced out
    of that window, reverse-complemented and passed to
    extract_embeddings_batch on its own. Only the motif itself is given to
    the model; the surrounding window is used solely to slice the motif out.
    Rows whose slice is not exactly SEQUENCE_LENGTH long, or that raise any
    exception, are reported and skipped.

    Args:
        model: Evo2 model
        chr_seq: Chromosome sequence
        motif_df: DataFrame with motif information (``start`` and ``end``
            columns; ``tbx5_score`` and ``label`` are optional, default 0)
        chromosome: Chromosome identifier

    Returns:
        Dictionary keyed by the DataFrame index of each successfully
        processed motif. Each value is a dict with ``start``, ``end``,
        ``embedding``, ``tbx5_score``, ``label``, ``chromosome`` and
        ``sequence_type``.
    """
    print(f"Processing {len(motif_df)} motifs on chromosome {chromosome} (reverse complement)...")

    results = {}
    skipped = []

    progress = tqdm(
        motif_df.iterrows(),
        total=len(motif_df),
        desc=f"Chr{chromosome} RC embeddings",
        ncols=100,
        leave=True,
        position=0,
    )

    for motif_idx, motif_row in progress:
        try:
            start = int(motif_row['start'])
            end = int(motif_row['end'])

            window, offset_first, offset_last = get_sequence_window(
                chr_seq, start, end
            )
            if window is None:
                skipped.append(motif_idx)
                continue

            # Offsets are inclusive, hence the +1 on the slice end.
            motif = window[offset_first:offset_last + 1]
            if len(motif) != SEQUENCE_LENGTH:
                print(f"Warning: Motif length {len(motif)} != {SEQUENCE_LENGTH} at position {start}-{end}")
                skipped.append(motif_idx)
                continue

            rc_motif = get_reverse_complement(motif)

            # A one-element batch: take its single row as the motif vector.
            motif_vector = extract_embeddings_batch(model, [rc_motif])[0]

            results[motif_idx] = {
                "start": start,
                "end": end,
                "embedding": motif_vector,
                "tbx5_score": motif_row.get("tbx5_score", 0),
                "label": motif_row.get("label", 0),
                "chromosome": chromosome,
                "sequence_type": "reverse_complement",
            }
        except Exception as e:
            # Best effort: one bad motif must not abort the whole chromosome.
            print(f"Error processing motif at index {motif_idx}: {e}")
            skipped.append(motif_idx)
            continue

    print(f"Successfully processed {len(results)} motifs (reverse complement)")
    if skipped:
        print(f"Failed to process {len(skipped)} motifs: {skipped[:10]}...")

    return results
|
|
def save_embeddings(embeddings_dict, output_path, chromosome):
    """Write the embeddings to a pickle file and, if non-empty, to an ``.npz``.

    Two files are produced:

    * ``output_path``: a pickle of ``{"embeddings": ..., "metadata": ...}``.
    * ``output_path`` with ``".pkl"`` replaced by ``"_arrays.npz"``: the same
      data as column arrays (``indices``, ``starts``, ``ends``,
      ``embeddings``, ``tbx5_scores``, ``labels``, ``metadata``). This file
      is skipped when there are no embeddings.

    Args:
        embeddings_dict: Mapping of motif index to a record dict with the keys
            ``start``, ``end``, ``embedding``, ``tbx5_score`` and ``label``.
        output_path: Destination of the pickle file.
        chromosome: Chromosome identifier stored in the metadata.
    """
    print(f"Saving reverse complement embeddings to {output_path}")

    records = list(embeddings_dict.values())

    # Record the width the model actually produced instead of assuming 4096;
    # the hidden size depends on which checkpoint was loaded. 4096 is kept
    # only as the fallback when there is nothing to measure.
    if records:
        embedding_dim = int(np.atleast_1d(records[0]["embedding"]).shape[-1])
    else:
        embedding_dim = 4096

    save_data = {
        "embeddings": dict(embeddings_dict),
        "metadata": {
            "chromosome": chromosome,
            "window_size": WINDOW_SIZE,
            "sequence_length": SEQUENCE_LENGTH,
            "layer_name": LAYER_NAME,
            "embedding_dim": embedding_dim,
            "num_motifs": len(embeddings_dict),
            "sequence_type": "reverse_complement",
        },
    }

    with open(output_path, "wb") as f:
        pickle.dump(save_data, f)

    np_output = output_path.replace(".pkl", "_arrays.npz")

    if not records:
        print("No embeddings to save in numpy format")
        return

    # NOTE: `metadata` is a dict, so NumPy stores it as a pickled object
    # array; readers need np.load(..., allow_pickle=True) to get it back.
    np.savez_compressed(
        np_output,
        indices=np.array(list(embeddings_dict)),
        starts=np.array([record["start"] for record in records]),
        ends=np.array([record["end"] for record in records]),
        embeddings=np.vstack([record["embedding"] for record in records]),
        tbx5_scores=np.array([record["tbx5_score"] for record in records]),
        labels=np.array([record["label"] for record in records]),
        metadata=save_data["metadata"],
    )
    print(f"Saved numpy arrays to {np_output}")
|
|
def main():
    """Command-line entry point: embed all TBX5 motifs of one chromosome.

    Steps: parse arguments, check that the FASTA and CSV inputs exist, load
    the chromosome sequence, select that chromosome's motifs from the CSV,
    normalize their length, load the Evo2 model, extract reverse-complement
    embeddings and save them.

    Returns:
        Process exit status: 0 on success (including the case where the CSV
        has no motifs for the chromosome), 1 when an input file is missing or
        the chromosome sequence cannot be loaded.
    """
    parser = argparse.ArgumentParser(
        description="Extract reverse complement embeddings for TBX5 motif data"
    )
    parser.add_argument(
        "chromosome", type=str, help="Chromosome to process (e.g., 1, 2, X, Y)"
    )
    parser.add_argument(
        "--fasta-dir",
        type=str,
        default="fasta",
        help="Directory containing FASTA files (default: fasta)",
    )
    parser.add_argument(
        "--csv-file",
        type=str,
        default="processed_data/all_tbx5_data.csv",
        help="TBX5 CSV file (default: processed_data/all_tbx5_data.csv)",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default="tbx5_embeddings_reverse_complement",
        help="Output directory for reverse complement embeddings (default: tbx5_embeddings_reverse_complement)",
    )
    parser.add_argument(
        "--model",
        type=str,
        default="evo2_40b",
        help="Evo2 model to use (default: evo2_40b)",
    )

    args = parser.parse_args()
    chromosome = args.chromosome

    os.makedirs(args.output_dir, exist_ok=True)

    fasta_path = os.path.join(
        args.fasta_dir, f"Homo_sapiens.GRCh38.dna.chromosome.{chromosome}.fa.gz"
    )
    csv_path = args.csv_file
    output_path = os.path.join(args.output_dir, f"chr{chromosome}_tbx5_embeddings_rc.pkl")

    # Fail fast on missing inputs, before anything expensive is loaded.
    if not os.path.exists(fasta_path):
        print(f"Error: FASTA file not found at {fasta_path}")
        return 1

    if not os.path.exists(csv_path):
        print(f"Error: CSV file not found at {csv_path}")
        return 1

    chr_seq = load_fasta(fasta_path, chromosome)
    if chr_seq is None:
        print(f"Error: Failed to load chromosome {chromosome} sequence")
        return 1

    print(f"Loading TBX5 data for chromosome {chromosome}...")
    motif_df = pd.read_csv(csv_path)

    # The command-line value is always a string, but pandas parses a purely
    # numeric chromosome column as integers (and a mixed one may hold both
    # ints and strings). Compare as strings so "1" matches 1.
    on_chromosome = motif_df['chromosome'].astype(str) == chromosome
    chr_motif_df = motif_df[on_chromosome].copy()

    if len(chr_motif_df) == 0:
        print(f"Warning: No chromosome {chromosome} motifs found in TBX5 data")
        # Still write a (metadata-only) pickle so downstream steps find a
        # file for every chromosome.
        save_embeddings({}, output_path, chromosome)
        print(f"Created empty reverse complement embeddings file for chromosome {chromosome}")
        return 0

    print(f"Found {len(chr_motif_df)} motifs on chromosome {chromosome}")

    chr_motif_df = normalize_sequence_length(chr_motif_df)

    # The model is loaded only once there is known to be work for it.
    print(f"Loading {args.model} model...")
    model = Evo2(args.model)
    model.model.eval()

    embeddings_dict = process_motifs(model, chr_seq, chr_motif_df, chromosome)

    save_embeddings(embeddings_dict, output_path, chromosome)

    print(f"Done processing chromosome {chromosome} (reverse complement)!")

    # Report the dimension that was actually produced; it depends on the
    # model that was loaded.
    first_record = next(iter(embeddings_dict.values()), None)
    if first_record is None:
        embedding_dim = "unknown (no embeddings extracted)"
    else:
        embedding_dim = np.atleast_1d(first_record["embedding"]).shape[-1]

    print(f"\n=== Summary for Chromosome {chromosome} (Reverse Complement) ===")
    print(f"Total motifs processed: {len(embeddings_dict)}")
    print(f"Embedding dimension: {embedding_dim}")
    print(f"Sequence length: {SEQUENCE_LENGTH}bp")
    print(f"Window size: {WINDOW_SIZE}bp")
    print("Sequence type: Reverse complement")
    print("Output files:")
    print(f" - {output_path}")
    # save_embeddings skips the .npz file when there is nothing to store.
    if embeddings_dict:
        print(f" - {output_path.replace('.pkl', '_arrays.npz')}")

    return 0
|
|
# Script entry point: main()'s integer return value (0 on success, 1 on a
# setup error) becomes the process exit status.
if __name__ == "__main__":
    sys.exit(main())
|
|
|
|
|
|
|
|
|
|
|
|
|
|