import math import sys from contextlib import contextmanager from pathlib import Path import numpy as np import pandas as pd import torch from datasets import Dataset, DatasetDict from tqdm import tqdm from transformers import AutoModel, AutoModelForMaskedLM, AutoTokenizer, EsmModel # ====================== # CONFIG # ====================== ROOT = Path("<>") # CHANGE HERE PROJ_ROOT = ROOT / "PeptiVerse" CSV_PATH = PROJ_ROOT / "training_data" / "c-binding.csv" OUT_ROOT = PROJ_ROOT / "training_data_cleaned" / "binding_affinity" # ESM2 - target encoder (shared across all branches) ESM_MODEL = "facebook/esm2_t33_650M_UR50D" ESM_MAX_LEN = 1022 ESM_BATCH = 32 # PeptideCLM - SMILES binder encoder sys.path.append(str(PROJ_ROOT)) from tokenizer.my_tokenizers import SMILES_SPE_Tokenizer PEPTIDECLM_MODEL = "aaronfeller/PeptideCLM-23M-all" TOKENIZER_VOCAB = str(PROJ_ROOT / "tokenizer" / "new_vocab.txt") TOKENIZER_SPLITS = str(PROJ_ROOT / "tokenizer" / "new_splits.txt") PEPTIDECLM_MAX_LEN = 768 PEPTIDECLM_BATCH = 128 # ChemBERTa - SMILES binder encoder CHEMBERTA_MODEL = "DeepChem/ChemBERTa-77M-MLM" CHEMBERTA_MAX_LEN = 512 CHEMBERTA_BATCH = 128 # Which SMILES binder models to run RUN_PEPTIDECLM = True RUN_CHEMBERTA = True # CSV column names COL_SEQ1 = "seq1" COL_SEQ2 = "seq2" COL_AFF = "affinity" COL_F2S = "Fasta2SMILES" COL_REACT = "REACT_SMILES" COL_MERGE = "Merge_SMILES" COL_WT_IPTM = "wt_iptm_score" COL_SMI_IPTM = "smiles_iptm_score" # Split config TRAIN_FRAC = 0.80 RANDOM_SEED = 1986 AFFINITY_Q_BINS = 30 # Logging QUIET = True USE_TQDM = False DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") # ====================== # Logging / progress # ====================== def log(msg: str): if not QUIET: print(msg) def pbar(it, **kwargs): return tqdm(it, **kwargs) if USE_TQDM else it @contextmanager def section(title: str): log(f"\n=== {title} ===") yield log(f"=== done: {title} ===") # ====================== # Data Handling # ====================== def has_uaa(seq: str) -> bool: return "X" in str(seq).upper() def pick_smiles(row) -> str | None: """Column Priority: Fasta2SMILES > REACT_SMILES > Merge_SMILES.""" for col in [COL_F2S, COL_REACT, COL_MERGE]: val = row.get(col, None) if val is not None and str(val).strip() not in ("", "nan", "None"): return str(val).strip() return None def affinity_to_class(a: float) -> str: if a >= 9.0: return "High" elif a >= 7.0: return "Moderate" else: return "Low" def make_distribution_matched_split(df: pd.DataFrame) -> pd.DataFrame: df = df.copy() df[COL_AFF] = pd.to_numeric(df[COL_AFF], errors="coerce") df = df.dropna(subset=[COL_AFF]).reset_index(drop=True) df["affinity_class"] = df[COL_AFF].apply(affinity_to_class) try: df["aff_bin"] = pd.qcut(df[COL_AFF], q=AFFINITY_Q_BINS, duplicates="drop") strat_col = "aff_bin" except Exception: df["aff_bin"] = df["affinity_class"] strat_col = "aff_bin" rng = np.random.RandomState(RANDOM_SEED) df["split"] = None for _, g in df.groupby(strat_col, observed=True): idx = g.index.to_numpy() rng.shuffle(idx) n_train = int(math.floor(len(idx) * TRAIN_FRAC)) df.loc[idx[:n_train], "split"] = "train" df.loc[idx[n_train:], "split"] = "val" df["split"] = df["split"].fillna("train") return df def prep_view(df_in: pd.DataFrame, binder_seq_col: str, iptm_col: str) -> pd.DataFrame: out = df_in.copy() out["target_sequence"] = out[COL_SEQ1].astype(str).str.strip() out["sequence"] = out[binder_seq_col].astype(str).str.strip() out["label"] = pd.to_numeric(out[COL_AFF], errors="coerce") out[iptm_col] = pd.to_numeric(out[iptm_col], errors="coerce") out[COL_AFF] = pd.to_numeric(out[COL_AFF], errors="coerce") out = out.dropna(subset=["target_sequence", "sequence", "label"]).reset_index(drop=True) return out[["target_sequence", "sequence", "label", "split", iptm_col, COL_AFF, "affinity_class"]] # ====================== # Dataset builders # ====================== def build_pooled_ds(view: pd.DataFrame, iptm_col: str, tgt_embs: np.ndarray, bnd_embs: np.ndarray) -> Dataset: """Both target and binder are (N, H) pooled float32 arrays.""" return Dataset.from_dict({ "target_sequence": view["target_sequence"].tolist(), "sequence": view["sequence"].tolist(), "label": view["label"].astype(float).tolist(), "target_embedding": tgt_embs, # (N, H_esm) float32 "binder_embedding": bnd_embs, # (N, H_binder) float32 "affinity": view[COL_AFF].astype(float).tolist(), "affinity_class": view["affinity_class"].tolist(), iptm_col: view[iptm_col].astype(float).tolist(), }) def build_unpooled_ds(view: pd.DataFrame, iptm_col: str, tgt_tok_embs, tgt_masks, tgt_lengths, bnd_tok_embs, bnd_masks, bnd_lengths) -> Dataset: """ Per-token lists for both sides. target_embedding[i] : (Lt_i, H_esm) float16 ndarray binder_embedding[i] : (Lb_i, H_binder) float16 ndarray """ return Dataset.from_dict({ "target_sequence": view["target_sequence"].tolist(), "sequence": view["sequence"].tolist(), "label": view["label"].astype(float).tolist(), "target_embedding": tgt_tok_embs, "target_attention_mask": tgt_masks, "target_length": tgt_lengths, "binder_embedding": bnd_tok_embs, "binder_attention_mask": bnd_masks, "binder_length": bnd_lengths, "affinity": view[COL_AFF].astype(float).tolist(), "affinity_class": view["affinity_class"].tolist(), iptm_col: view[iptm_col].astype(float).tolist(), }) # ====================== # ESM2 - shared target encoder # ====================== def load_esm(): print(f" Loading ESM2: {ESM_MODEL}") tok = AutoTokenizer.from_pretrained(ESM_MODEL) model = EsmModel.from_pretrained(ESM_MODEL).to(DEVICE).eval() return tok, model @torch.no_grad() def embed_esm_pooled(seqs, tok, model) -> np.ndarray: """Returns (N, H) float32 - mean-pooled over non-pad tokens.""" all_embs = [] for i in pbar(range(0, len(seqs), ESM_BATCH), desc=" ESM2 pooled"): batch = seqs[i:i + ESM_BATCH] enc = tok(batch, return_tensors="pt", padding=True, truncation=True, max_length=ESM_MAX_LEN) ids = enc["input_ids"].to(DEVICE) mask = enc["attention_mask"].to(DEVICE) h = model(input_ids=ids, attention_mask=mask).last_hidden_state attn_f = mask.unsqueeze(-1).float() pooled = ((h * attn_f).sum(dim=1) / attn_f.sum(dim=1).clamp(min=1e-9)).cpu().numpy().astype(np.float32) all_embs.append(pooled) return np.vstack(all_embs) @torch.no_grad() def embed_esm_unpooled(seqs, tok, model): """ Returns per-token lists (CLS/EOS/pad excluded). tok_embs : list of (Lt_i, H) float16 arrays masks : list of (Lt_i,) int8 arrays (all-ones) lengths : list of int """ cls_id = tok.cls_token_id eos_id = tok.eos_token_id tok_embs, masks, lengths = [], [], [] for i in pbar(range(0, len(seqs), ESM_BATCH), desc=" ESM2 unpooled"): batch = seqs[i:i + ESM_BATCH] enc = tok(batch, return_tensors="pt", padding=True, truncation=True, max_length=ESM_MAX_LEN) ids = enc["input_ids"].to(DEVICE) mask = enc["attention_mask"].to(DEVICE) h = model(input_ids=ids, attention_mask=mask).last_hidden_state for b in range(h.shape[0]): keep = mask[b].bool() if cls_id is not None: keep = keep & (ids[b] != cls_id) if eos_id is not None: keep = keep & (ids[b] != eos_id) emb = h[b, keep].cpu().to(torch.float16).numpy() tok_embs.append(emb) masks.append(np.ones(emb.shape[0], dtype=np.int8)) lengths.append(emb.shape[0]) return tok_embs, masks, lengths # ====================== # Generic binder embedding helpers # ====================== def _get_special_ids_t(tokenizer): special_ids = sorted({ x for x in [ getattr(tokenizer, attr, None) for attr in ("pad_token_id", "cls_token_id", "sep_token_id", "bos_token_id", "eos_token_id", "mask_token_id") ] if x is not None }) return (torch.tensor(special_ids, device=DEVICE, dtype=torch.long) if special_ids else None) def _pool_and_unpool(last_hidden, input_ids, attention_mask, special_ids_t): """Mean-pool over non-special valid tokens; also return per-token arrays.""" valid = attention_mask.bool() if special_ids_t is not None: valid = valid & (~torch.isin(input_ids, special_ids_t)) valid_f = valid.unsqueeze(-1).float() pooled = ( torch.sum(last_hidden * valid_f, dim=1) / torch.clamp(valid_f.sum(dim=1), min=1e-9) ).cpu().numpy().astype(np.float32) tok_embs, masks, lengths = [], [], [] for b in range(last_hidden.shape[0]): emb = last_hidden[b, valid[b]].cpu().to(torch.float16).numpy() tok_embs.append(emb) masks.append(np.ones(emb.shape[0], dtype=np.int8)) lengths.append(emb.shape[0]) return pooled, tok_embs, masks, lengths # ====================== # PeptideCLM - SMILES binder encoder # ====================== def load_peptideclm(): print(f" Loading PeptideCLM: {PEPTIDECLM_MODEL}") tok = SMILES_SPE_Tokenizer(TOKENIZER_VOCAB, TOKENIZER_SPLITS) model = (AutoModelForMaskedLM.from_pretrained(PEPTIDECLM_MODEL) .roformer.to(DEVICE).eval()) return tok, model, _get_special_ids_t(tok) @torch.no_grad() def embed_peptideclm(seqs, tok, model, sid_t): pooled_all, tok_all, mask_all, len_all = [], [], [], [] for i in pbar(range(0, len(seqs), PEPTIDECLM_BATCH), desc=" PeptideCLM binder"): batch = seqs[i:i + PEPTIDECLM_BATCH] enc = tok(batch, return_tensors="pt", padding=True, truncation=True, max_length=PEPTIDECLM_MAX_LEN) ids = enc["input_ids"].to(DEVICE) mask = enc["attention_mask"].to(DEVICE) h = model(input_ids=ids, attention_mask=mask).last_hidden_state p, t, m, l = _pool_and_unpool(h, ids, mask, sid_t) pooled_all.append(p); tok_all.extend(t); mask_all.extend(m); len_all.extend(l) return np.vstack(pooled_all), tok_all, mask_all, len_all # ====================== # ChemBERTa - SMILES binder encoder # ====================== def load_chemberta(): print(f" Loading ChemBERTa: {CHEMBERTA_MODEL}") tok = AutoTokenizer.from_pretrained(CHEMBERTA_MODEL) model = AutoModel.from_pretrained(CHEMBERTA_MODEL).to(DEVICE).eval() return tok, model, _get_special_ids_t(tok) @torch.no_grad() def embed_chemberta(seqs, tok, model, sid_t): pooled_all, tok_all, mask_all, len_all = [], [], [], [] for i in pbar(range(0, len(seqs), CHEMBERTA_BATCH), desc=" ChemBERTa binder"): batch = seqs[i:i + CHEMBERTA_BATCH] enc = tok(batch, return_tensors="pt", padding=True, truncation=True, max_length=CHEMBERTA_MAX_LEN) ids = enc["input_ids"].to(DEVICE) mask = enc["attention_mask"].to(DEVICE) h = model(input_ids=ids, attention_mask=mask).last_hidden_state p, t, m, l = _pool_and_unpool(h, ids, mask, sid_t) pooled_all.append(p); tok_all.extend(t); mask_all.extend(m); len_all.extend(l) return np.vstack(pooled_all), tok_all, mask_all, len_all # ====================== # WT branch (ESM2 × ESM2) # ====================== def run_wt_branch(wt_train: pd.DataFrame, wt_val: pd.DataFrame, esm_tok, esm_model): print("\n" + "="*55) print(" Branch : WT (ESM2 target × ESM2 binder)") print("="*55) pooled_splits, unpooled_splits = {}, {} for split_name, view in [("train", wt_train), ("val", wt_val)]: print(f"\n [{split_name}] {len(view)} rows") targets = view["target_sequence"].tolist() binders = view["sequence"].tolist() tgt_pooled = embed_esm_pooled(targets, esm_tok, esm_model) bnd_pooled = embed_esm_pooled(binders, esm_tok, esm_model) tgt_tok_embs, tgt_masks, tgt_lengths = embed_esm_unpooled(targets, esm_tok, esm_model) bnd_tok_embs, bnd_masks, bnd_lengths = embed_esm_unpooled(binders, esm_tok, esm_model) pooled_splits[split_name] = build_pooled_ds( view, COL_WT_IPTM, tgt_pooled, bnd_pooled) unpooled_splits[split_name] = build_unpooled_ds( view, COL_WT_IPTM, tgt_tok_embs, tgt_masks, tgt_lengths, bnd_tok_embs, bnd_masks, bnd_lengths) pooled_out = OUT_ROOT / "pair_wt_wt_pooled" unpooled_out = OUT_ROOT / "pair_wt_wt_unpooled" DatasetDict(pooled_splits).save_to_disk(str(pooled_out)) DatasetDict(unpooled_splits).save_to_disk(str(unpooled_out)) print(f"\n WT pooled to {pooled_out}") print(f" WT unpooled to {unpooled_out}") # ====================== # SMILES branch (ESM2 × {PeptideCLM | ChemBERTa}) # ====================== def run_smiles_binder_model(name: str, smi_train: pd.DataFrame, smi_val: pd.DataFrame, esm_tok, esm_model, load_fn, embed_fn): print("\n" + "="*55) print(f" Branch : SMILES (ESM2 target × {name} binder)") print("="*55) binder_tok, binder_model, sid_t = load_fn() pooled_splits, unpooled_splits = {}, {} for split_name, view in [("train", smi_train), ("val", smi_val)]: print(f"\n [{split_name}] {len(view)} rows") targets = view["target_sequence"].tolist() binders = view["sequence"].tolist() print(" ESM2 target - pooled ...") tgt_pooled = embed_esm_pooled(targets, esm_tok, esm_model) print(" ESM2 target - unpooled ...") tgt_tok_embs, tgt_masks, tgt_lengths = embed_esm_unpooled( targets, esm_tok, esm_model) print(f" {name} binder - pooled + unpooled ...") bnd_pooled, bnd_tok_embs, bnd_masks, bnd_lengths = embed_fn( binders, binder_tok, binder_model, sid_t) pooled_splits[split_name] = build_pooled_ds( view, COL_SMI_IPTM, tgt_pooled, bnd_pooled) unpooled_splits[split_name] = build_unpooled_ds( view, COL_SMI_IPTM, tgt_tok_embs, tgt_masks, tgt_lengths, bnd_tok_embs, bnd_masks, bnd_lengths) suffix = "" if name.lower() == "peptideclm" else f"_{name.lower()}" pooled_out = OUT_ROOT / f"pair_wt_smiles_pooled{suffix}" unpooled_out = OUT_ROOT / f"pair_wt_smiles_unpooled{suffix}" DatasetDict(pooled_splits).save_to_disk(str(pooled_out)) DatasetDict(unpooled_splits).save_to_disk(str(unpooled_out)) print(f"\n {name} pooled to {pooled_out}") print(f" {name} unpooled to {unpooled_out}") del binder_model torch.cuda.empty_cache() # ====================== # Main # ====================== def main(): print(f"Device : {DEVICE}") print(f"CSV : {CSV_PATH}") print(f"Out : {OUT_ROOT}\n") OUT_ROOT.mkdir(parents=True, exist_ok=True) # ------------------------------------------------------------------ # 1. Load + dedup # ------------------------------------------------------------------ with section("load csv + dedup"): df = pd.read_csv(CSV_PATH) print(f"Raw rows: {len(df)}") df["orig_idx"] = df.index # traceability only for c in [COL_SEQ1, COL_SEQ2, COL_F2S, COL_REACT, COL_MERGE]: if c in df.columns: df[c] = df[c].apply(lambda x: x.strip() if isinstance(x, str) else x) for col in [COL_SEQ1, COL_SEQ2, COL_AFF, COL_F2S, COL_REACT, COL_WT_IPTM, COL_SMI_IPTM]: if col not in df.columns: raise ValueError(f"Missing required column: '{col}'") dedup_cols = [c for c in [COL_SEQ1, COL_SEQ2, COL_F2S, COL_REACT, COL_MERGE] if c in df.columns] before = len(df) df = df.drop_duplicates(subset=dedup_cols, keep="first").reset_index(drop=True) print(f"After dedup pass 1 (raw columns) : {len(df)} (-{before - len(df)})") df[COL_AFF] = pd.to_numeric(df[COL_AFF], errors="coerce") # ------------------------------------------------------------------ # 2. Prepare per-branch subsets # ------------------------------------------------------------------ with section("prepare WT / SMILES subsets"): # ── WT branch ────────────────────────────────────────────────── # Both seq1 and seq2 must be canonical (no X) for ESM2 df_wt = df.copy() df_wt["wt_sequence"] = df_wt[COL_SEQ2].astype(str).str.strip() df_wt = df_wt.dropna(subset=[COL_AFF]) df_wt = df_wt[~df_wt[COL_SEQ1].astype(str).str.contains("X", case=False, na=False)] df_wt = df_wt[df_wt["wt_sequence"] != ""] df_wt = df_wt[~df_wt["wt_sequence"].str.contains("X", case=False, na=False)] df_wt = df_wt.reset_index(drop=True) # ── SMILES branch ────────────────────────────────────────────── # seq1 must be canonical (no X) for ESM2; binder SMILES picked # by priority (Fasta2SMILES > REACT_SMILES > Merge_SMILES), then # dedup pass 2 on (seq1, picked smiles_sequence) df_smi = df.copy() df_smi = df_smi.dropna(subset=[COL_AFF]) df_smi = df_smi[ pd.to_numeric(df_smi[COL_SMI_IPTM], errors="coerce").notna() ] df_smi = df_smi[~df_smi[COL_SEQ1].astype(str).str.contains("X", case=False, na=False)] df_smi = df_smi.reset_index(drop=True) df_smi["smiles_sequence"] = df_smi.apply(pick_smiles, axis=1) df_smi = df_smi[df_smi["smiles_sequence"].notna()].reset_index(drop=True) print(f"After requiring ≥1 valid SMILES : {len(df_smi)}") # Dedup pass 2: (seq1, picked smiles_sequence) before = len(df_smi) df_smi = df_smi.drop_duplicates( subset=[COL_SEQ1, "smiles_sequence"], keep="first" ).reset_index(drop=True) print(f"After dedup pass 2 (seq1, smiles_sequence): {len(df_smi)} (-{before - len(df_smi)})") assert not df_smi.duplicated(subset=[COL_SEQ1, "smiles_sequence"]).any(), \ "BUG: duplicate (seq1, smiles_sequence) pairs remain!" print(f"\n[counts] WT rows={len(df_wt)} | SMILES rows={len(df_smi)}") # ------------------------------------------------------------------ # 3. Split # ------------------------------------------------------------------ with section("split WT and SMILES separately"): df_wt2 = make_distribution_matched_split(df_wt) df_smi2 = make_distribution_matched_split(df_smi) df_wt2.to_csv(OUT_ROOT / "binding_affinity_wt_meta_with_split.csv", index=False) df_smi2.to_csv(OUT_ROOT / "binding_affinity_smiles_meta_with_split.csv", index=False) # ------------------------------------------------------------------ # 4. Build split views # ------------------------------------------------------------------ wt_view = prep_view(df_wt2, "wt_sequence", COL_WT_IPTM) smi_view = prep_view(df_smi2, "smiles_sequence", COL_SMI_IPTM) wt_train = wt_view[wt_view["split"] == "train"].reset_index(drop=True) wt_val = wt_view[wt_view["split"] == "val"].reset_index(drop=True) smi_train = smi_view[smi_view["split"] == "train"].reset_index(drop=True) smi_val = smi_view[smi_view["split"] == "val"].reset_index(drop=True) print(f"\nSplit sizes - WT: train={len(wt_train)} val={len(wt_val)}") print(f"Split sizes - SMILES: train={len(smi_train)} val={len(smi_val)}") # ------------------------------------------------------------------ # 5. Load ESM2 once - shared across all branches # ------------------------------------------------------------------ print("\nLoading ESM2 (shared target encoder) ...") esm_tok, esm_model = load_esm() # ------------------------------------------------------------------ # 6. WT branch # ------------------------------------------------------------------ run_wt_branch(wt_train, wt_val, esm_tok, esm_model) # ------------------------------------------------------------------ # 7. SMILES branches # ------------------------------------------------------------------ if RUN_PEPTIDECLM: run_smiles_binder_model( "peptideclm", smi_train, smi_val, esm_tok, esm_model, load_fn=load_peptideclm, embed_fn=embed_peptideclm, ) if RUN_CHEMBERTA: run_smiles_binder_model( "chemberta", smi_train, smi_val, esm_tok, esm_model, load_fn=load_chemberta, embed_fn=embed_chemberta, ) print(f"\n All done. Datasets saved under: {OUT_ROOT}") if __name__ == "__main__": main()