| import json |
| import os |
| import torch |
| import torch.nn.functional as F |
| import esm |
| from tqdm import tqdm |
| import numpy as np |
|
|
| |
# ---- Run-wide settings shared by every stage of the pipeline ----
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
BATCH_SIZE = 32  # sequences per ESM forward pass
MAX_SEQ_LEN = 50  # longest peptide kept; every embedding is padded/truncated to this length
MIN_SEQ_LEN = 2  # shortest peptide kept
CANONICAL_AA = set("ACDEFGHIKLMNPQRSTVWY")  # one-letter codes of the 20 standard amino acids

# Report the compute device once, at import time.
print(f"Using device: {DEVICE}")
if DEVICE.type == "cuda":
    print(f"GPU: {torch.cuda.get_device_name()}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
|
|
| |
def _filter_sequence(raw_seq):
    """Return ``raw_seq`` upper-cased and stripped if it passes the filters, else None.

    Filters: non-empty, only the 20 canonical amino acids (CANONICAL_AA), and a
    length within [MIN_SEQ_LEN, MAX_SEQ_LEN] after stripping.
    """
    if not raw_seq:
        return None
    seq = raw_seq.upper().strip()
    if MIN_SEQ_LEN <= len(seq) <= MAX_SEQ_LEN and set(seq) <= CANONICAL_AA:
        return seq
    return None


def read_peptides_json(json_file):
    """
    Read and filter sequences from the all_peptides_data.json file.

    Extracts sequences from both the main peptide entries and their ``monomers``
    lists. Each sequence is upper-cased and stripped, then kept only if it:
      - is non-empty,
      - contains only the canonical 20 amino acids,
      - has a length between MIN_SEQ_LEN and MAX_SEQ_LEN (inclusive).

    IDs are ``main_<id>`` for peptides and ``monomer_<id>`` for monomers. The first
    occurrence of each ID wins; later entries with the same ID are skipped.

    Entries without an ``id`` get a position-based ID: ``main_unk_<i>`` or
    ``monomer_unk_<i>_<j>``. Here ``i`` is the peptide's index and ``j`` is the
    monomer's index within it. This stops all id-less entries from collapsing into
    one ``unk`` ID and being dropped by the dedup.

    Args:
        json_file: Path to a JSON file. It presumably holds a list of dicts with
            optional ``id``, ``sequence`` and ``monomers`` keys; this is inferred
            from the keys read here, so verify against the data producer.

    Returns:
        List of (seq_id, sequence) tuples, in file order.
    """
    print(f"Loading peptides from {json_file}...")
    with open(json_file, 'r', encoding='utf-8') as f:
        data = json.load(f)

    seqs = []
    processed_ids = set()

    def _add(seq_id, seq):
        # Dedup by ID: shared monomers referenced by several peptides are kept once.
        if seq_id not in processed_ids:
            seqs.append((seq_id, seq))
            processed_ids.add(seq_id)

    for item_idx, item in enumerate(tqdm(data, desc="Processing peptides")):
        main_seq = _filter_sequence(item.get('sequence'))
        if main_seq is not None:
            main_id = item.get('id')
            if main_id is None:
                main_id = f"unk_{item_idx}"
            _add(f"main_{main_id}", main_seq)

        for mono_idx, monomer in enumerate(item.get('monomers') or []):
            mono_seq = _filter_sequence(monomer.get('sequence'))
            if mono_seq is None:
                continue
            mono_id = monomer.get('id')
            if mono_id is None:
                mono_id = f"unk_{item_idx}_{mono_idx}"
            _add(f"monomer_{mono_id}", mono_seq)

    print(f"Found {len(seqs)} valid sequences")
    return seqs
|
|
@torch.no_grad()
def get_per_residue_embeddings(model, alphabet, sequences, batch_size=BATCH_SIZE,
                               max_seq_len=MAX_SEQ_LEN, repr_layer=None):
    """
    Compute fixed-length per-residue ESM-2 embeddings for a list of (id, seq).

    Args:
        model: ESM model, as returned by ``esm.pretrained.load_model_and_alphabet``.
        alphabet: The matching alphabet; its batch converter tokenizes ``sequences``.
        sequences: List of ``(seq_id, sequence)`` tuples.
        batch_size: Number of sequences per forward pass.
        max_seq_len: Each embedding is zero-padded or truncated to this many rows.
        repr_layer: Layer whose representations are returned. Defaults to the
            model's ``num_layers`` attribute. If the model lacks it, defaults to
            33, the value previously hard-coded for esm2_t33_650M_UR50D.

    Returns:
        dict ``{seq_id: tensor[max_seq_len, D]}`` on CPU. Each tensor owns its own
        storage. Padded rows are zeros, and no length mask is stored.
    """
    model.eval()
    converter = alphabet.get_batch_converter()
    if repr_layer is None:
        repr_layer = getattr(model, "num_layers", 33)
    # Feed tokens to whatever device the model's weights are on, rather than assuming DEVICE.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else DEVICE
    embeddings = {}

    print(f"Computing embeddings for {len(sequences)} sequences...")
    for start in tqdm(range(0, len(sequences), batch_size), desc="Computing embeddings"):
        batch = sequences[start:start + batch_size]
        labels, strs, tokens = converter(batch)
        tokens = tokens.to(device)

        out = model(tokens, repr_layers=[repr_layer], return_contacts=False)
        reps = out['representations'][repr_layer]

        for idx, (sid, seq) in enumerate(zip(labels, strs)):
            length = len(seq)
            # Position 0 holds the token the converter prepends; residues occupy 1..length.
            emb = reps[idx, 1:1 + length, :]
            if length < max_seq_len:
                emb = F.pad(emb, (0, 0, 0, max_seq_len - length))
            elif length > max_seq_len:
                emb = emb[:max_seq_len, :]
            # clone(): on CPU, .cpu() is a no-op, so an unpadded slice would remain a
            # view of the whole batch tensor. That keeps the batch alive, and a later
            # torch.save would serialize the full batch storage.
            embeddings[sid] = emb.cpu().clone()

    return embeddings
|
|
def save_embeddings_for_compressor(embeddings, output_dir="/data2/edwardsun/flow_project/peptide_embeddings",
                                   model_name="esm2_t33_650M_UR50D"):
    """
    Save embeddings in a format compatible with the compressor.

    Writes into ``output_dir``, which is created if missing:
      * ``<seq_id>.pt``: one tensor per sequence;
      * ``all_peptide_embeddings.pt``: all tensors stacked along a new leading axis;
      * ``sequence_ids.json``: seq_ids in the same order as the stacked rows;
      * ``metadata.json``: count, embedding dim, MAX_SEQ_LEN, DEVICE and model name.

    Args:
        embeddings: dict ``{seq_id: tensor}``. All tensors must share one shape,
            as ``torch.stack`` requires.
        output_dir: Destination directory.
        model_name: Model name recorded in ``metadata.json``.

    Returns:
        Path to ``all_peptide_embeddings.pt``.

    Raises:
        ValueError: if ``embeddings`` is empty (there is nothing to stack).
    """
    if not embeddings:
        raise ValueError("No embeddings to save: 'embeddings' is empty.")
    os.makedirs(output_dir, exist_ok=True)

    print(f"Saving individual embeddings to {output_dir}/...")
    for seq_id, emb in tqdm(embeddings.items(), desc="Saving individual files"):
        # NOTE(review): seq_id is used verbatim as a file name. An ID containing a
        # path separator would fail or write outside output_dir.
        # clone() so that only this tensor's data is written, even if emb is a view
        # into a larger tensor (torch.save stores the whole underlying storage).
        torch.save(emb.clone(), os.path.join(output_dir, f"{seq_id}.pt"))

    print("Creating combined tensor...")
    seq_ids = list(embeddings)
    combined_embeddings = torch.stack([embeddings[sid] for sid in seq_ids])

    combined_path = os.path.join(output_dir, "all_peptide_embeddings.pt")
    torch.save(combined_embeddings, combined_path)

    # Row i of the combined tensor corresponds to seq_ids[i].
    seq_ids_path = os.path.join(output_dir, "sequence_ids.json")
    with open(seq_ids_path, 'w', encoding='utf-8') as f:
        json.dump(seq_ids, f, indent=2)

    metadata = {
        "num_sequences": len(embeddings),
        "embedding_dim": int(combined_embeddings.shape[-1]),
        "max_seq_len": MAX_SEQ_LEN,
        "device_used": str(DEVICE),
        "model_name": model_name,
    }
    metadata_path = os.path.join(output_dir, "metadata.json")
    with open(metadata_path, 'w', encoding='utf-8') as f:
        json.dump(metadata, f, indent=2)

    print(f"Saved combined embeddings: {combined_path}")
    print(f"Combined tensor shape: {combined_embeddings.shape}")
    print(f"Memory usage: {combined_embeddings.element_size() * combined_embeddings.nelement() / 1e6:.1f} MB")

    return combined_path
|
|
def create_compressor_dataset(embeddings, output_dir="/data2/edwardsun/flow_project/compressor_dataset"):
    """
    Create a dataset format specifically for the compressor training.

    Stacks every embedding along a new leading axis. Row i is the i-th value in
    the dict's iteration order. The stacked array is written to ``output_dir`` as:
      * ``peptide_embeddings.npy`` (NumPy);
      * ``peptide_embeddings.pt`` (torch).

    Args:
        embeddings: dict ``{seq_id: tensor}``. All tensors must share one shape.
        output_dir: Destination directory; created if missing.

    Returns:
        Path to ``peptide_embeddings.pt``.

    Raises:
        ValueError: if ``embeddings`` is empty (there is nothing to stack).
    """
    if not embeddings:
        raise ValueError("No embeddings to save: 'embeddings' is empty.")
    os.makedirs(output_dir, exist_ok=True)

    all_embeddings = torch.stack(list(embeddings.values()))

    np_path = os.path.join(output_dir, "peptide_embeddings.npy")
    # detach().cpu() so .numpy() also works for GPU tensors or tensors tracking gradients.
    np.save(np_path, all_embeddings.detach().cpu().numpy())

    torch_path = os.path.join(output_dir, "peptide_embeddings.pt")
    torch.save(all_embeddings, torch_path)

    print("Created compressor dataset:")
    print(f" Shape: {all_embeddings.shape}")
    print(f" Numpy: {np_path}")
    print(f" Torch: {torch_path}")

    return torch_path
|
|
| |
if __name__ == '__main__':
    import sys  # local import: only the script entry point needs it

    # Read and filter the input before the slow, memory-heavy model load, so that
    # empty input fails fast. The JSON path can be overridden as the first CLI argument.
    json_file = sys.argv[1] if len(sys.argv) > 1 else 'all_peptides_data.json'
    sequences = read_peptides_json(json_file)
    print(f"Loaded {len(sequences)} valid sequences from {json_file}")

    if not sequences:
        print("No valid sequences found. Exiting.")
        # sys.exit rather than the site-injected exit(), which is absent under `python -S`.
        sys.exit(1)

    print("Loading ESM-2 model...")
    model_name = 'esm2_t33_650M_UR50D'
    model, alphabet = esm.pretrained.load_model_and_alphabet(model_name)
    model = model.to(DEVICE)
    print(f"Loaded {model_name}")

    embeddings = get_per_residue_embeddings(model, alphabet, sequences)

    print("\nSaving embeddings...")
    combined_path = save_embeddings_for_compressor(embeddings)
    compressor_path = create_compressor_dataset(embeddings)

    print(f"\n✓ Successfully processed {len(embeddings)} peptide sequences")
    print("✓ Embeddings saved and ready for compressor training")
    print(f"✓ Combined embeddings with IDs/metadata: '{combined_path}'")
    print(f"✓ Use '{compressor_path}' in your compressor.py file")

    # embeddings is non-empty here because sequences was non-empty.
    sample_emb = next(iter(embeddings.values()))
    print("\nEmbedding statistics:")
    print(f" Individual embedding shape: {sample_emb.shape}")
    print(f" Embedding dimension: {sample_emb.shape[-1]}")
    print(f" Data type: {sample_emb.dtype}")
|
|