"""
Pipeline:
  1. Read *_meta_with_split.csv (sequence, label, id, split).
  2. Convert wild-type (wt) sequences to SMILES via:
         fasta2smi -i peptides.fasta -o peptides.p2smi
  3. Parse the .p2smi format: "{seq}-linear: {SMILES}".
  4. Embed SMILES with ChemBERTa and save pooled + unpooled DatasetDicts.
  5. Embed SMILES with PeptideCLM and save pooled + unpooled DatasetDicts.
"""

import os
import subprocess
import sys
import tempfile

import numpy as np
import pandas as pd
import torch
from datasets import Dataset, DatasetDict
from tqdm import tqdm
from transformers import AutoModel, AutoModelForMaskedLM, AutoTokenizer

PROJECT_ROOT = "<>"  # change here

# using permeability as example
META_CSV = (
    f"{PROJECT_ROOT}/training_data_cleaned/"
    "permeability_penetrance/permeability_meta_with_split.csv"
)
BASE_OUT = f"{PROJECT_ROOT}/alternative_embeddings"

# ChemBERTa
CHEMBERTA_MODEL = "DeepChem/ChemBERTa-77M-MLM"
CHEMBERTA_OUT = f"{BASE_OUT}/permeability_chemberta/perm_smiles_with_embeddings"

# PeptideCLM
# The SMILES_SPE_Tokenizer lives inside the project tree, so PROJECT_ROOT has
# to be on sys.path before it can be imported.
sys.path.append(PROJECT_ROOT)
from tokenizer.my_tokenizers import SMILES_SPE_Tokenizer  # noqa: E402

PEPTIDECLM_MODEL = "aaronfeller/PeptideCLM-23M-all"
PEPTIDECLM_TOKENIZER = f"{PROJECT_ROOT}/tokenizer/new_vocab.txt"
PEPTIDECLM_SPLITS = f"{PROJECT_ROOT}/tokenizer/new_splits.txt"
PEPTIDECLM_OUT = f"{BASE_OUT}/permeability_peptideclm/perm_smiles_with_embeddings"

# Column names in the CSV
SEQ_COL = "sequence"
LABEL_COL = "label"
SPLIT_COL = "split"
# Used as the FASTA header. Must be unique; main() falls back to generated
# ids when it is not.
ID_COL = "id"

# Split values that are embedded and saved. Rows with any other split value
# are reported and skipped.
SPLITS = ("train", "val")

# Which encoders main() runs (both reuse the same converted SMILES).
RUN_CHEMBERTA = True
RUN_PEPTIDECLM = True

# fasta2smi settings
FASTA2SMI_BIN = "fasta2smi"  # install via github
# Separator between the amino-acid sequence and its SMILES in .p2smi output.
P2SMI_SEPARATOR = "-linear: "
# SMILES-position values (compared case-insensitively) treated as "no SMILES".
_NULL_SMILES = frozenset({"none", "null", "n/a"})

# Embedding settings
MAX_LENGTH_CHEMBERTA = 512
MAX_LENGTH_PEPTIDECLM = 768
BATCH_SIZE = 128

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


# ===========================================================================================
# Step 1 — fasta2smi conversion. Do not apply it to properties whose data is already SMILES.
# ===========================================================================================
def sequences_to_smiles(sequences: list[str], ids: list[str]) -> dict[str, str]:
    """Convert amino-acid sequences to SMILES with a single fasta2smi call.

    The result is keyed by sequence. Therefore each distinct, non-empty
    sequence is written to the FASTA file only once, under the first id it
    appears with.

    The .p2smi format produced by fasta2smi looks like::

        MIIFAIAASHKK-linear: N[C@@H](CCSC)C(=O)...
        KIAKLKAKIQ...-linear: N[C@@H](CCCCN)C(=O)...

    Args:
        sequences: Amino-acid sequences, one per row.
        ids: FASTA headers, parallel to ``sequences``.

    Returns:
        A mapping ``sequence -> SMILES`` for every sequence fasta2smi
        converted. Sequences that failed are absent from the mapping.

    Raises:
        ValueError: If ``sequences`` and ``ids`` differ in length.
        RuntimeError: If fasta2smi exits non-zero or writes no output file.
    """
    if len(sequences) != len(ids):
        raise ValueError(
            f"sequences and ids differ in length ({len(sequences)} vs {len(ids)})"
        )

    # Deduplicate: identical sequences map to the same SMILES, so converting
    # them more than once is wasted work. Empty sequences cannot be converted.
    seq_to_id: dict[str, str] = {}
    for sid, seq in zip(ids, sequences):
        if seq:
            seq_to_id.setdefault(seq, sid)

    if not seq_to_id:
        print(" fasta2smi: nothing to convert.")
        return {}

    with tempfile.TemporaryDirectory() as tmpdir:
        fasta_path = os.path.join(tmpdir, "peptides.fasta")
        p2smi_path = os.path.join(tmpdir, "peptides.p2smi")

        with open(fasta_path, "w", encoding="utf-8") as fh:
            fh.writelines(f">{sid}\n{seq}\n" for seq, sid in seq_to_id.items())

        cmd = [FASTA2SMI_BIN, "-i", fasta_path, "-o", p2smi_path]
        print(f" Running: {' '.join(cmd)}")
        # List form (shell=False): paths are never interpreted by a shell.
        result = subprocess.run(cmd, capture_output=True, text=True)
        if result.returncode != 0:
            raise RuntimeError(
                f"fasta2smi failed (exit {result.returncode}):\n"
                f" stdout: {result.stdout}\n stderr: {result.stderr}"
            )
        if not os.path.exists(p2smi_path):
            raise RuntimeError(
                f"fasta2smi exited 0 but wrote no output file {p2smi_path}:\n"
                f" stdout: {result.stdout}\n stderr: {result.stderr}"
            )

        # Parse before the temporary directory is removed.
        seq2smi = _parse_p2smi(p2smi_path)

    n_total = len(seq_to_id)
    n_ok = sum(seq in seq2smi for seq in seq_to_id)
    n_fail = n_total - n_ok
    print(
        f" fasta2smi: {n_ok}/{n_total} unique sequences converted "
        f"({n_fail} failed/skipped)"
    )
    return seq2smi


def _parse_p2smi(path: str) -> dict[str, str]:
    """Parse a .p2smi file into a ``sequence -> SMILES`` mapping.

    Handling of lines:

    * Blank lines and lines starting with ``#`` are ignored.
    * Lines without the ``-linear: `` separator are reported and skipped.
    * Entries whose SMILES is empty or a null placeholder ("none", "null" or
      "n/a", in any case) are dropped.
    * If a sequence occurs more than once, its last entry wins.
    """
    seq2smi: dict[str, str] = {}
    with open(path, encoding="utf-8") as fh:
        for raw in fh:
            line = raw.strip()
            if not line or line.startswith("#"):
                continue
            # Split on the first "-linear: ", the separator fasta2smi uses.
            aa_seq, sep, smi = line.partition(P2SMI_SEPARATOR)
            if not sep:
                print(f" [WARN] Unexpected p2smi line, skipping: {line[:80]}")
                continue
            smi = smi.strip()
            if smi and smi.lower() not in _NULL_SMILES:
                seq2smi[aa_seq] = smi
    return seq2smi


# ============================================================
# Setups
# ============================================================
def _get_special_ids_tensor(tokenizer):
    """Return the tokenizer's special-token ids on ``device``, or ``None``.

    The pad/cls/sep/bos/eos/mask token ids that are set (not None) are
    deduplicated and sorted into a long tensor. ``None`` is returned when the
    tokenizer defines none of them.
    """
    attrs = [
        "pad_token_id", "cls_token_id", "sep_token_id",
        "bos_token_id", "eos_token_id", "mask_token_id",
    ]
    ids = sorted({getattr(tokenizer, a, None) for a in attrs} - {None})
    return torch.tensor(ids, device=device, dtype=torch.long) if ids else None


@torch.no_grad()
def _embed_batch(tokenizer, model, special_ids_t, sequences, max_length):
    """Embed one batch of SMILES strings.

    Tokens are padded to the longest item and truncated to ``max_length``.
    Padding tokens, and tokens whose id is in ``special_ids_t``, are excluded
    from both outputs.

    Returns:
        A tuple ``(pooled, token_embs, masks, lengths)``:

        * ``pooled``: A (B, H) numpy array holding the mean of the kept token
          states. Its dtype follows the model output (float32 for a default
          float32 model).
        * ``token_embs``: A list of B float16 arrays, each of shape
          (n_kept_tokens, H).
        * ``masks``: A list of B int8 all-ones arrays, one entry per kept token.
        * ``lengths``: A list of B ints, the number of kept tokens per item.
    """
    tok = tokenizer(
        sequences,
        return_tensors="pt",
        padding=True,
        max_length=max_length,
        truncation=True,
    )
    input_ids = tok["input_ids"].to(device)
    attention_mask = tok["attention_mask"].to(device)

    out = model(input_ids=input_ids, attention_mask=attention_mask)
    last_hidden = out.last_hidden_state  # (B, L, H)

    valid = attention_mask.bool()
    if special_ids_t is not None:
        valid = valid & (~torch.isin(input_ids, special_ids_t))

    valid_f = valid.unsqueeze(-1).float()
    # The clamp avoids 0/0 when an item has no kept tokens (its pool is 0).
    pooled = (
        torch.sum(last_hidden * valid_f, dim=1)
        / torch.clamp(valid_f.sum(dim=1), min=1e-9)
    ).cpu().numpy()

    # Cast and transfer the whole batch once instead of once per row. This
    # gives the same fp16 values, because the cast is element-wise.
    hidden_f16 = last_hidden.to(torch.float16).cpu()
    valid_cpu = valid.cpu()
    token_embs = [hidden_f16[b, valid_cpu[b]].numpy() for b in range(hidden_f16.shape[0])]
    lengths = [emb.shape[0] for emb in token_embs]
    masks = [np.ones(n, dtype=np.int8) for n in lengths]
    return pooled, token_embs, masks, lengths


def _embed_all(tokenizer, model, special_ids_t, sequences, max_length):
    """Embed ``sequences`` in chunks of BATCH_SIZE.

    See ``_embed_batch`` for what is returned. The pooled arrays of all
    batches are stacked into one (N, H) array.

    Raises:
        ValueError: If ``sequences`` is empty, because there is nothing to stack.
    """
    if len(sequences) == 0:
        raise ValueError("_embed_all() needs at least one sequence")

    pooled_all, token_all, mask_all, len_all = [], [], [], []
    for i in tqdm(range(0, len(sequences), BATCH_SIZE), desc=" batches"):
        p, t, m, l = _embed_batch(
            tokenizer, model, special_ids_t,
            sequences[i:i + BATCH_SIZE], max_length,
        )
        pooled_all.append(p)
        token_all.extend(t)
        mask_all.extend(m)
        len_all.extend(l)
    return np.vstack(pooled_all), token_all, mask_all, len_all


def _build_datasets(wt_seqs, smiles, labels, tokenizer, model, special_ids_t, max_length):
    """Embed ``smiles`` and return ``(pooled_ds, full_ds)``.

    Both datasets carry the columns ``sequence``, ``smiles`` and ``label``.
    In ``pooled_ds``, ``embedding`` is one vector per row. In ``full_ds``,
    ``embedding`` holds the per-token vectors, with matching
    ``attention_mask`` and ``length`` columns.
    """
    pooled, tok_embs, masks, lengths = _embed_all(
        tokenizer, model, special_ids_t, smiles, max_length
    )
    pooled_ds = Dataset.from_dict({
        "sequence": wt_seqs,
        "smiles": smiles,
        "label": labels,
        "embedding": pooled,
    })
    full_ds = Dataset.from_dict({
        "sequence": wt_seqs,
        "smiles": smiles,
        "label": labels,
        "embedding": tok_embs,
        "attention_mask": masks,
        "length": lengths,
    })
    return pooled_ds, full_ds


def _save(splits: dict, out_path: str):
    """Save the pooled DatasetDict to ``out_path`` and the unpooled one to ``out_path + "_unpooled"``.

    ``splits`` maps a split name to a ``(pooled_ds, full_ds)`` pair. Nothing
    is written when ``splits`` is empty.
    """
    if not splits:
        print(f" [WARN] No non-empty splits; nothing saved to {out_path}")
        return
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    DatasetDict({k: v[0] for k, v in splits.items()}).save_to_disk(out_path)
    DatasetDict({k: v[1] for k, v in splits.items()}).save_to_disk(out_path + "_unpooled")
    print(f" Saved pooled to {out_path}")
    print(f" Saved unpooled to {out_path}_unpooled")


def _embed_splits(meta: pd.DataFrame, tokenizer, model, max_length: int) -> dict[str, tuple]:
    """Embed each split listed in SPLITS.

    Returns ``{split_name: (pooled_ds, full_ds)}``. Empty splits are skipped,
    and rows whose split value is not in SPLITS are reported rather than
    embedded.
    """
    special_ids_t = _get_special_ids_tensor(tokenizer)

    unused = meta.loc[~meta[SPLIT_COL].isin(SPLITS), SPLIT_COL]
    if not unused.empty:
        print(
            f"\n [WARN] {len(unused)} rows have a split outside {list(SPLITS)} "
            f"({sorted(unused.astype(str).unique())}); they are not embedded."
        )

    splits: dict[str, tuple] = {}
    for split_name in SPLITS:
        df = meta[meta[SPLIT_COL] == split_name].reset_index(drop=True)
        print(f"\n [{split_name}] {len(df)} rows")
        if df.empty:
            print(" [WARN] Empty split, skipping.")
            continue
        splits[split_name] = _build_datasets(
            df[SEQ_COL].tolist(),
            df["smiles"].tolist(),
            df[LABEL_COL].tolist(),
            tokenizer, model, special_ids_t,
            max_length,
        )
    return splits


# ============================================================
# ChemBERTa
# ============================================================
def run_chemberta(meta: pd.DataFrame):
    """Embed ``meta["smiles"]`` with ChemBERTa and save the results to CHEMBERTA_OUT."""
    print(f"\n{'='*60}")
    print(" Encoder: ChemBERTa")
    print(f"{'='*60}")

    print(f" Loading {CHEMBERTA_MODEL} ...")
    tokenizer = AutoTokenizer.from_pretrained(CHEMBERTA_MODEL)
    model = AutoModel.from_pretrained(CHEMBERTA_MODEL).to(device).eval()

    splits = _embed_splits(meta, tokenizer, model, MAX_LENGTH_CHEMBERTA)
    _save(splits, CHEMBERTA_OUT)

    # Free GPU memory before the next model is loaded.
    del model
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


# ============================================================
# PeptideCLM
# ============================================================
def run_peptideclm(meta: pd.DataFrame):
    """Embed ``meta["smiles"]`` with the PeptideCLM encoder and save the results to PEPTIDECLM_OUT."""
    print(f"\n{'='*60}")
    print(" Encoder: PeptideCLM")
    print(f"{'='*60}")

    print(f" Loading tokenizer from {PEPTIDECLM_TOKENIZER} ...")
    tokenizer = SMILES_SPE_Tokenizer(PEPTIDECLM_TOKENIZER, PEPTIDECLM_SPLITS)

    print(f" Loading {PEPTIDECLM_MODEL} ...")
    # Keep only the `roformer` encoder and do not hold a reference to the MLM
    # wrapper. If the wrapper stayed alive, it would keep the same GPU-resident
    # submodule alive after `del model`, and empty_cache() could not free it.
    model = AutoModelForMaskedLM.from_pretrained(PEPTIDECLM_MODEL).roformer.to(device).eval()

    splits = _embed_splits(meta, tokenizer, model, MAX_LENGTH_PEPTIDECLM)
    _save(splits, PEPTIDECLM_OUT)

    del model
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


# ============================================================
# Main
# ============================================================
def main():
    """Run the pipeline: load metadata, convert it to SMILES, save the enriched CSV, then embed."""
    print(f"\nDevice : {device}")
    print(f"Meta : {META_CSV}")

    # Load metadata (sep=None lets the python engine sniff the delimiter).
    meta = pd.read_csv(META_CSV, sep=None, engine="python")
    print(f"Loaded {len(meta)} rows. Columns: {meta.columns.tolist()}")

    for col in [SEQ_COL, LABEL_COL, SPLIT_COL]:
        if col not in meta.columns:
            raise ValueError(f"Expected column '{col}' not found. Available: {meta.columns.tolist()}")

    # Ensure numeric labels, and drop rows without a sequence or a usable label.
    n_loaded = len(meta)
    meta[LABEL_COL] = pd.to_numeric(meta[LABEL_COL], errors="coerce")
    meta = meta.dropna(subset=[SEQ_COL, LABEL_COL]).reset_index(drop=True)
    # Surrounding whitespace would be written into the FASTA record and break
    # the sequence -> SMILES lookup below, so strip it. Then drop the empty
    # sequences that remain.
    meta[SEQ_COL] = meta[SEQ_COL].astype(str).str.strip()
    meta = meta.loc[meta[SEQ_COL] != ""].reset_index(drop=True)
    n_dropped = n_loaded - len(meta)
    if n_dropped:
        print(f" [WARN] Dropped {n_dropped} rows with a missing/empty sequence or non-numeric label.")

    # Build the id list for the FASTA headers. The ids must be unique.
    ids = None
    if ID_COL in meta.columns:
        id_series = meta[ID_COL].astype(str)
        if id_series.is_unique:
            ids = id_series.tolist()
        else:
            print(f" [WARN] Column '{ID_COL}' is not unique; using generated FASTA ids.")
    if ids is None:
        ids = [f"seq_{i}" for i in range(len(meta))]

    # For properties whose data already starts as SMILES, fasta2smi is not
    # needed. Convert wt sequences to SMILES (one fasta2smi call for the whole
    # dataset).
    print("\nConverting peptide sequences to SMILES ...")
    seq2smi = sequences_to_smiles(meta[SEQ_COL].tolist(), ids)
    meta["smiles"] = meta[SEQ_COL].map(seq2smi)

    n_missing = meta["smiles"].isna().sum()
    if n_missing:
        print(f" [WARN] {n_missing} sequences had no SMILES — dropping.")
        meta = meta.dropna(subset=["smiles"]).reset_index(drop=True)
    print(f" Retained {len(meta)} rows with valid SMILES.")

    # Save the SMILES-enriched meta CSV.
    os.makedirs(BASE_OUT, exist_ok=True)
    smiles_meta_path = os.path.join(BASE_OUT, "permeability_smiles_meta_with_split.csv")
    meta.to_csv(smiles_meta_path, index=False)
    print(f" Saved SMILES meta to {smiles_meta_path}")

    # Run the enabled encoders sequentially (they share the converted SMILES).
    if RUN_CHEMBERTA:
        run_chemberta(meta)
    if RUN_PEPTIDECLM:
        run_peptideclm(meta)

    print("\nAll done.")


if __name__ == "__main__":
    main()