| import math |
| import sys |
| from contextlib import contextmanager |
| from pathlib import Path |
|
|
| import numpy as np |
| import pandas as pd |
| import torch |
| from datasets import Dataset, DatasetDict |
| from tqdm import tqdm |
| from transformers import AutoModel, AutoModelForMaskedLM, AutoTokenizer, EsmModel |
|
|
| # ---------------------------------------------------------------------------
| # Paths & configuration
| # ---------------------------------------------------------------------------
|
|
| ROOT = Path("<>") |
| PROJ_ROOT = ROOT / "PeptiVerse" |
|
|
| CSV_PATH = PROJ_ROOT / "training_data" / "c-binding.csv" |
|
|
| OUT_ROOT = PROJ_ROOT / "training_data_cleaned" / "binding_affinity" |
|
|
| # --- ESM2 (target encoder, also used for the WT binder) ---
| ESM_MODEL = "facebook/esm2_t33_650M_UR50D" |
| ESM_MAX_LEN = 1022 |
| ESM_BATCH = 32 |
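| # ESM2-650M (t33) produces 1280-dim embeddings; truncating to 1022 residues
| # leaves room for the BOS/EOS tokens within the usual 1024-token window.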
|
|
| # --- PeptideCLM (SMILES binder) + custom SMILES tokenizer ---
| sys.path.append(str(PROJ_ROOT)) |
| from tokenizer.my_tokenizers import SMILES_SPE_Tokenizer |
|
|
| PEPTIDECLM_MODEL = "aaronfeller/PeptideCLM-23M-all" |
| TOKENIZER_VOCAB = str(PROJ_ROOT / "tokenizer" / "new_vocab.txt") |
| TOKENIZER_SPLITS = str(PROJ_ROOT / "tokenizer" / "new_splits.txt") |
| PEPTIDECLM_MAX_LEN = 768 |
| PEPTIDECLM_BATCH = 128 |
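| # PeptideCLM is a RoFormer-style masked LM over peptide SMILES; it requires the
| # custom SMILES pair-encoding tokenizer imported above.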
|
|
| # --- ChemBERTa (SMILES binder) ---
| CHEMBERTA_MODEL = "DeepChem/ChemBERTa-77M-MLM" |
| CHEMBERTA_MAX_LEN = 512 |
| CHEMBERTA_BATCH = 128 |
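| # ChemBERTa is used as an alternative encoder for the same binder SMILES strings.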
|
|
| # --- Which SMILES binder branches to run ---
| RUN_PEPTIDECLM = True |
| RUN_CHEMBERTA = True |
|
|
| # --- Input CSV column names ---
| COL_SEQ1 = "seq1" |
| COL_SEQ2 = "seq2" |
| COL_AFF = "affinity" |
| COL_F2S = "Fasta2SMILES" |
| COL_REACT = "REACT_SMILES" |
| COL_MERGE = "Merge_SMILES" |
| COL_WT_IPTM = "wt_iptm_score" |
| COL_SMI_IPTM = "smiles_iptm_score" |
|
|
| # --- Train/val split ---
| TRAIN_FRAC = 0.80 |
| RANDOM_SEED = 1986 |
| AFFINITY_Q_BINS = 30 |
|
|
| # --- Verbosity ---
| QUIET = True |
| USE_TQDM = False |
|
|
| DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") |
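| # Models run on a single GPU when available; each batch of embeddings is moved
| # to CPU / NumPy immediately, so GPU memory stays bounded.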
|
|
|
|
| # ---------------------------------------------------------------------------
| # Logging helpers
| # ---------------------------------------------------------------------------
|
|
| def log(msg: str): |
| if not QUIET: |
| print(msg) |
|
|
| def pbar(it, **kwargs): |
| return tqdm(it, **kwargs) if USE_TQDM else it |
|
|
| @contextmanager |
| def section(title: str): |
| log(f"\n=== {title} ===") |
| yield |
| log(f"=== done: {title} ===") |
|
|
|
|
| # ---------------------------------------------------------------------------
| # Data-preparation helpers
| # ---------------------------------------------------------------------------
|
|
| def has_uaa(seq: str) -> bool: |
| return "X" in str(seq).upper() |
|
|
| def pick_smiles(row) -> str | None: |
| """Column Priority: Fasta2SMILES > REACT_SMILES > Merge_SMILES.""" |
| for col in [COL_F2S, COL_REACT, COL_MERGE]: |
| val = row.get(col, None) |
| if val is not None and str(val).strip() not in ("", "nan", "None"): |
| return str(val).strip() |
| return None |
|
|
| def affinity_to_class(a: float) -> str: |
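| # Assumes the affinity label is on a -log10 scale (pKd/pKi-like): >= 9 is
| # roughly single-digit-nM or tighter, >= 7 is sub-uM.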
| if a >= 9.0: |
| return "High" |
| elif a >= 7.0: |
| return "Moderate" |
| else: |
| return "Low" |
|
|
| def make_distribution_matched_split(df: pd.DataFrame) -> pd.DataFrame: |
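| """Stratified ~80/20 split: bin affinity into quantiles (fallback: the three
| affinity classes) and draw TRAIN_FRAC of each bin into 'train', the rest into 'val'."""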
| df = df.copy() |
| df[COL_AFF] = pd.to_numeric(df[COL_AFF], errors="coerce") |
| df = df.dropna(subset=[COL_AFF]).reset_index(drop=True) |
| df["affinity_class"] = df[COL_AFF].apply(affinity_to_class) |
|
|
| try: |
| df["aff_bin"] = pd.qcut(df[COL_AFF], q=AFFINITY_Q_BINS, duplicates="drop") |
| strat_col = "aff_bin" |
| except Exception: |
| df["aff_bin"] = df["affinity_class"] |
| strat_col = "aff_bin" |
|
|
| rng = np.random.RandomState(RANDOM_SEED) |
| df["split"] = None |
| for _, g in df.groupby(strat_col, observed=True): |
| idx = g.index.to_numpy() |
| rng.shuffle(idx) |
| n_train = int(math.floor(len(idx) * TRAIN_FRAC)) |
| df.loc[idx[:n_train], "split"] = "train" |
| df.loc[idx[n_train:], "split"] = "val" |
| df["split"] = df["split"].fillna("train") |
| return df |
|
|
| def prep_view(df_in: pd.DataFrame, binder_seq_col: str, iptm_col: str) -> pd.DataFrame: |
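| """Normalise one subset into the columns the embedding branches expect:
| target_sequence, sequence (binder), numeric label, split, ipTM, affinity, class."""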
| out = df_in.copy() |
| out["target_sequence"] = out[COL_SEQ1].astype(str).str.strip() |
| out["sequence"] = out[binder_seq_col].astype(str).str.strip() |
| out["label"] = pd.to_numeric(out[COL_AFF], errors="coerce") |
| out[iptm_col] = pd.to_numeric(out[iptm_col], errors="coerce") |
| out[COL_AFF] = pd.to_numeric(out[COL_AFF], errors="coerce") |
| out = out.dropna(subset=["target_sequence", "sequence", "label"]).reset_index(drop=True) |
| return out[["target_sequence", "sequence", "label", "split", |
| iptm_col, COL_AFF, "affinity_class"]] |
|
|
|
|
| # ---------------------------------------------------------------------------
| # HF Dataset builders
| # ---------------------------------------------------------------------------
|
|
| def build_pooled_ds(view: pd.DataFrame, iptm_col: str, |
| tgt_embs: np.ndarray, bnd_embs: np.ndarray) -> Dataset: |
| """Both target and binder are (N, H) pooled float32 arrays.""" |
| return Dataset.from_dict({ |
| "target_sequence": view["target_sequence"].tolist(), |
| "sequence": view["sequence"].tolist(), |
| "label": view["label"].astype(float).tolist(), |
| "target_embedding": tgt_embs, |
| "binder_embedding": bnd_embs, |
| "affinity": view[COL_AFF].astype(float).tolist(), |
| "affinity_class": view["affinity_class"].tolist(), |
| iptm_col: view[iptm_col].astype(float).tolist(), |
| }) |
|
|
|
|
| def build_unpooled_ds(view: pd.DataFrame, iptm_col: str, |
| tgt_tok_embs, tgt_masks, tgt_lengths, |
| bnd_tok_embs, bnd_masks, bnd_lengths) -> Dataset: |
| """ |
| Per-token lists for both sides. |
| target_embedding[i] : (Lt_i, H_esm) float16 ndarray |
| binder_embedding[i] : (Lb_i, H_binder) float16 ndarray |
| """ |
| return Dataset.from_dict({ |
| "target_sequence": view["target_sequence"].tolist(), |
| "sequence": view["sequence"].tolist(), |
| "label": view["label"].astype(float).tolist(), |
|
|
| "target_embedding": tgt_tok_embs, |
| "target_attention_mask": tgt_masks, |
| "target_length": tgt_lengths, |
|
|
| "binder_embedding": bnd_tok_embs, |
| "binder_attention_mask": bnd_masks, |
| "binder_length": bnd_lengths, |
|
|
| "affinity": view[COL_AFF].astype(float).tolist(), |
| "affinity_class": view["affinity_class"].tolist(), |
| iptm_col: view[iptm_col].astype(float).tolist(), |
| }) |
|
|
|
|
| # ---------------------------------------------------------------------------
| # ESM2 embeddings (target and WT binder)
| # ---------------------------------------------------------------------------
|
|
| def load_esm(): |
| print(f" Loading ESM2: {ESM_MODEL}") |
| tok = AutoTokenizer.from_pretrained(ESM_MODEL) |
| model = EsmModel.from_pretrained(ESM_MODEL).to(DEVICE).eval() |
| return tok, model |
|
|
|
|
| @torch.no_grad() |
| def embed_esm_pooled(seqs, tok, model) -> np.ndarray: |
| """Returns (N, H) float32 - mean-pooled over non-pad tokens.""" |
| all_embs = [] |
| for i in pbar(range(0, len(seqs), ESM_BATCH), desc=" ESM2 pooled"): |
| batch = seqs[i:i + ESM_BATCH] |
| enc = tok(batch, return_tensors="pt", padding=True, |
| truncation=True, max_length=ESM_MAX_LEN) |
| ids = enc["input_ids"].to(DEVICE) |
| mask = enc["attention_mask"].to(DEVICE) |
| h = model(input_ids=ids, attention_mask=mask).last_hidden_state |
| attn_f = mask.unsqueeze(-1).float() |
| pooled = ((h * attn_f).sum(dim=1) / |
| attn_f.sum(dim=1).clamp(min=1e-9)).cpu().numpy().astype(np.float32) |
| all_embs.append(pooled) |
| return np.vstack(all_embs) |
|
|
|
|
| @torch.no_grad() |
| def embed_esm_unpooled(seqs, tok, model): |
| """ |
| Returns per-token lists (CLS/EOS/pad excluded). |
| tok_embs : list of (Lt_i, H) float16 arrays |
| masks : list of (Lt_i,) int8 arrays (all-ones) |
| lengths : list of int |
| """ |
| cls_id = tok.cls_token_id |
| eos_id = tok.eos_token_id |
|
|
| tok_embs, masks, lengths = [], [], [] |
| for i in pbar(range(0, len(seqs), ESM_BATCH), desc=" ESM2 unpooled"): |
| batch = seqs[i:i + ESM_BATCH] |
| enc = tok(batch, return_tensors="pt", padding=True, |
| truncation=True, max_length=ESM_MAX_LEN) |
| ids = enc["input_ids"].to(DEVICE) |
| mask = enc["attention_mask"].to(DEVICE) |
| h = model(input_ids=ids, attention_mask=mask).last_hidden_state |
|
|
| for b in range(h.shape[0]): |
| keep = mask[b].bool() |
| if cls_id is not None: |
| keep = keep & (ids[b] != cls_id) |
| if eos_id is not None: |
| keep = keep & (ids[b] != eos_id) |
| emb = h[b, keep].cpu().to(torch.float16).numpy() |
| tok_embs.append(emb) |
| masks.append(np.ones(emb.shape[0], dtype=np.int8)) |
| lengths.append(emb.shape[0]) |
| return tok_embs, masks, lengths |
|
|
|
|
| # ---------------------------------------------------------------------------
| # Shared pooling helpers for the SMILES binder models
| # ---------------------------------------------------------------------------
|
|
| def _get_special_ids_t(tokenizer): |
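| """Collect whichever special-token ids the tokenizer defines (pad/cls/sep/bos/eos/mask)
| as a tensor on DEVICE, or None if it defines none, so pooling can mask them out."""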
| special_ids = sorted({ |
| x for x in [ |
| getattr(tokenizer, attr, None) |
| for attr in ("pad_token_id", "cls_token_id", "sep_token_id", |
| "bos_token_id", "eos_token_id", "mask_token_id") |
| ] if x is not None |
| }) |
| return (torch.tensor(special_ids, device=DEVICE, dtype=torch.long) |
| if special_ids else None) |
|
|
|
|
| def _pool_and_unpool(last_hidden, input_ids, attention_mask, special_ids_t): |
| """Mean-pool over non-special valid tokens; also return per-token arrays.""" |
| valid = attention_mask.bool() |
| if special_ids_t is not None: |
| valid = valid & (~torch.isin(input_ids, special_ids_t)) |
|
|
| valid_f = valid.unsqueeze(-1).float() |
| pooled = ( |
| torch.sum(last_hidden * valid_f, dim=1) / |
| torch.clamp(valid_f.sum(dim=1), min=1e-9) |
| ).cpu().numpy().astype(np.float32) |
|
|
| tok_embs, masks, lengths = [], [], [] |
| for b in range(last_hidden.shape[0]): |
| emb = last_hidden[b, valid[b]].cpu().to(torch.float16).numpy() |
| tok_embs.append(emb) |
| masks.append(np.ones(emb.shape[0], dtype=np.int8)) |
| lengths.append(emb.shape[0]) |
| return pooled, tok_embs, masks, lengths |
|
|
|
|
| # ---------------------------------------------------------------------------
| # PeptideCLM binder embeddings
| # ---------------------------------------------------------------------------
|
|
| def load_peptideclm(): |
| print(f" Loading PeptideCLM: {PEPTIDECLM_MODEL}") |
| tok = SMILES_SPE_Tokenizer(TOKENIZER_VOCAB, TOKENIZER_SPLITS) |
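| # Load the full MLM checkpoint, then keep only the RoFormer backbone; the MLM
| # head is not needed for extracting hidden states.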
| model = (AutoModelForMaskedLM.from_pretrained(PEPTIDECLM_MODEL) |
| .roformer.to(DEVICE).eval()) |
| return tok, model, _get_special_ids_t(tok) |
|
|
|
|
| @torch.no_grad() |
| def embed_peptideclm(seqs, tok, model, sid_t): |
| pooled_all, tok_all, mask_all, len_all = [], [], [], [] |
| for i in pbar(range(0, len(seqs), PEPTIDECLM_BATCH), desc=" PeptideCLM binder"): |
| batch = seqs[i:i + PEPTIDECLM_BATCH] |
| enc = tok(batch, return_tensors="pt", padding=True, |
| truncation=True, max_length=PEPTIDECLM_MAX_LEN) |
| ids = enc["input_ids"].to(DEVICE) |
| mask = enc["attention_mask"].to(DEVICE) |
| h = model(input_ids=ids, attention_mask=mask).last_hidden_state |
| p, t, m, l = _pool_and_unpool(h, ids, mask, sid_t) |
| pooled_all.append(p); tok_all.extend(t); mask_all.extend(m); len_all.extend(l) |
| return np.vstack(pooled_all), tok_all, mask_all, len_all |
|
|
|
|
| # ---------------------------------------------------------------------------
| # ChemBERTa binder embeddings
| # ---------------------------------------------------------------------------
|
|
| def load_chemberta(): |
| print(f" Loading ChemBERTa: {CHEMBERTA_MODEL}") |
| tok = AutoTokenizer.from_pretrained(CHEMBERTA_MODEL) |
| model = AutoModel.from_pretrained(CHEMBERTA_MODEL).to(DEVICE).eval() |
| return tok, model, _get_special_ids_t(tok) |
|
|
|
|
| @torch.no_grad() |
| def embed_chemberta(seqs, tok, model, sid_t): |
| pooled_all, tok_all, mask_all, len_all = [], [], [], [] |
| for i in pbar(range(0, len(seqs), CHEMBERTA_BATCH), desc=" ChemBERTa binder"): |
| batch = seqs[i:i + CHEMBERTA_BATCH] |
| enc = tok(batch, return_tensors="pt", padding=True, |
| truncation=True, max_length=CHEMBERTA_MAX_LEN) |
| ids = enc["input_ids"].to(DEVICE) |
| mask = enc["attention_mask"].to(DEVICE) |
| h = model(input_ids=ids, attention_mask=mask).last_hidden_state |
| p, t, m, l = _pool_and_unpool(h, ids, mask, sid_t) |
| pooled_all.append(p); tok_all.extend(t); mask_all.extend(m); len_all.extend(l) |
| return np.vstack(pooled_all), tok_all, mask_all, len_all |
|
|
|
|
| # ---------------------------------------------------------------------------
| # Branch: WT (ESM2 target × ESM2 binder)
| # ---------------------------------------------------------------------------
|
|
| def run_wt_branch(wt_train: pd.DataFrame, wt_val: pd.DataFrame, |
| esm_tok, esm_model): |
| print("\n" + "="*55) |
| print(" Branch : WT (ESM2 target Γ ESM2 binder)") |
| print("="*55) |
|
|
| pooled_splits, unpooled_splits = {}, {} |
|
|
| for split_name, view in [("train", wt_train), ("val", wt_val)]: |
| print(f"\n [{split_name}] {len(view)} rows") |
| targets = view["target_sequence"].tolist() |
| binders = view["sequence"].tolist() |
|
|
| tgt_pooled = embed_esm_pooled(targets, esm_tok, esm_model) |
| bnd_pooled = embed_esm_pooled(binders, esm_tok, esm_model) |
|
|
| tgt_tok_embs, tgt_masks, tgt_lengths = embed_esm_unpooled(targets, esm_tok, esm_model) |
| bnd_tok_embs, bnd_masks, bnd_lengths = embed_esm_unpooled(binders, esm_tok, esm_model) |
|
|
| pooled_splits[split_name] = build_pooled_ds( |
| view, COL_WT_IPTM, tgt_pooled, bnd_pooled) |
| unpooled_splits[split_name] = build_unpooled_ds( |
| view, COL_WT_IPTM, |
| tgt_tok_embs, tgt_masks, tgt_lengths, |
| bnd_tok_embs, bnd_masks, bnd_lengths) |
|
|
| pooled_out = OUT_ROOT / "pair_wt_wt_pooled" |
| unpooled_out = OUT_ROOT / "pair_wt_wt_unpooled" |
| DatasetDict(pooled_splits).save_to_disk(str(pooled_out)) |
| DatasetDict(unpooled_splits).save_to_disk(str(unpooled_out)) |
| print(f"\n WT pooled to {pooled_out}") |
| print(f" WT unpooled to {unpooled_out}") |
|
|
|
|
| # ---------------------------------------------------------------------------
| # Branch: SMILES (ESM2 target × SMILES-LM binder)
| # ---------------------------------------------------------------------------
|
|
| def run_smiles_binder_model(name: str, |
| smi_train: pd.DataFrame, smi_val: pd.DataFrame, |
| esm_tok, esm_model, |
| load_fn, embed_fn): |
| print("\n" + "="*55) |
| print(f" Branch : SMILES (ESM2 target Γ {name} binder)") |
| print("="*55) |
|
|
| binder_tok, binder_model, sid_t = load_fn() |
| pooled_splits, unpooled_splits = {}, {} |
|
|
| for split_name, view in [("train", smi_train), ("val", smi_val)]: |
| print(f"\n [{split_name}] {len(view)} rows") |
| targets = view["target_sequence"].tolist() |
| binders = view["sequence"].tolist() |
|
|
| print(" ESM2 target - pooled ...") |
| tgt_pooled = embed_esm_pooled(targets, esm_tok, esm_model) |
|
|
| print(" ESM2 target - unpooled ...") |
| tgt_tok_embs, tgt_masks, tgt_lengths = embed_esm_unpooled( |
| targets, esm_tok, esm_model) |
|
|
| print(f" {name} binder - pooled + unpooled ...") |
| bnd_pooled, bnd_tok_embs, bnd_masks, bnd_lengths = embed_fn( |
| binders, binder_tok, binder_model, sid_t) |
|
|
| pooled_splits[split_name] = build_pooled_ds( |
| view, COL_SMI_IPTM, tgt_pooled, bnd_pooled) |
| unpooled_splits[split_name] = build_unpooled_ds( |
| view, COL_SMI_IPTM, |
| tgt_tok_embs, tgt_masks, tgt_lengths, |
| bnd_tok_embs, bnd_masks, bnd_lengths) |
|
|
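| # PeptideCLM keeps the bare "pair_wt_smiles_*" directory names; any other binder
| # model gets its name appended as a suffix.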
| suffix = "" if name.lower() == "peptideclm" else f"_{name.lower()}" |
| pooled_out = OUT_ROOT / f"pair_wt_smiles_pooled{suffix}" |
| unpooled_out = OUT_ROOT / f"pair_wt_smiles_unpooled{suffix}" |
| DatasetDict(pooled_splits).save_to_disk(str(pooled_out)) |
| DatasetDict(unpooled_splits).save_to_disk(str(unpooled_out)) |
| print(f"\n {name} pooled to {pooled_out}") |
| print(f" {name} unpooled to {unpooled_out}") |
|
|
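| # Free the binder model so the next branch can load its own weights.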
| del binder_model |
| torch.cuda.empty_cache() |
|
|
|
|
| # ---------------------------------------------------------------------------
| # Main
| # ---------------------------------------------------------------------------
|
|
| def main(): |
| print(f"Device : {DEVICE}") |
| print(f"CSV : {CSV_PATH}") |
| print(f"Out : {OUT_ROOT}\n") |
| OUT_ROOT.mkdir(parents=True, exist_ok=True) |
|
|
| # ------------------------------------------------------------------
| # 1) Load CSV + dedup
| # ------------------------------------------------------------------
| with section("load csv + dedup"): |
| df = pd.read_csv(CSV_PATH) |
| print(f"Raw rows: {len(df)}") |
| df["orig_idx"] = df.index |
|
|
| for c in [COL_SEQ1, COL_SEQ2, COL_F2S, COL_REACT, COL_MERGE]: |
| if c in df.columns: |
| df[c] = df[c].apply(lambda x: x.strip() if isinstance(x, str) else x) |
|
|
| for col in [COL_SEQ1, COL_SEQ2, COL_AFF, COL_F2S, COL_REACT, COL_WT_IPTM, COL_SMI_IPTM]: |
| if col not in df.columns: |
| raise ValueError(f"Missing required column: '{col}'") |
|
|
| dedup_cols = [c for c in [COL_SEQ1, COL_SEQ2, COL_F2S, COL_REACT, COL_MERGE] |
| if c in df.columns] |
| before = len(df) |
| df = df.drop_duplicates(subset=dedup_cols, keep="first").reset_index(drop=True) |
| print(f"After dedup pass 1 (raw columns) : {len(df)} (-{before - len(df)})") |
|
|
| df[COL_AFF] = pd.to_numeric(df[COL_AFF], errors="coerce") |
|
|
| # ------------------------------------------------------------------
| # 2) Prepare the WT and SMILES subsets
| # ------------------------------------------------------------------
| with section("prepare WT / SMILES subsets"): |
| # WT subset: seq2 is the wild-type binder; drop rows with UAAs ('X')
| # in either chain or with a missing/empty binder sequence.
| df_wt = df.copy() |
| df_wt["wt_sequence"] = df_wt[COL_SEQ2].astype(str).str.strip() |
| df_wt = df_wt.dropna(subset=[COL_AFF, COL_SEQ2])  # also drop missing WT sequences (astype(str) would turn NaN into "nan")
| df_wt = df_wt[~df_wt[COL_SEQ1].astype(str).str.contains("X", case=False, na=False)] |
| df_wt = df_wt[df_wt["wt_sequence"] != ""] |
| df_wt = df_wt[~df_wt["wt_sequence"].str.contains("X", case=False, na=False)] |
| df_wt = df_wt.reset_index(drop=True) |
|
|
| # SMILES subset: needs a usable SMILES representation of the binder and a
| # SMILES-based ipTM score; UAA-containing targets are dropped.
| df_smi = df.copy() |
| df_smi = df_smi.dropna(subset=[COL_AFF]) |
| df_smi = df_smi[ |
| pd.to_numeric(df_smi[COL_SMI_IPTM], errors="coerce").notna() |
| ] |
| df_smi = df_smi[~df_smi[COL_SEQ1].astype(str).str.contains("X", case=False, na=False)] |
| df_smi = df_smi.reset_index(drop=True) |
|
|
| df_smi["smiles_sequence"] = df_smi.apply(pick_smiles, axis=1) |
| df_smi = df_smi[df_smi["smiles_sequence"].notna()].reset_index(drop=True) |
| print(f"After requiring β₯1 valid SMILES : {len(df_smi)}") |
|
|
| # Dedup again on the (target, SMILES) pair actually used downstream.
| before = len(df_smi) |
| df_smi = df_smi.drop_duplicates( |
| subset=[COL_SEQ1, "smiles_sequence"], keep="first" |
| ).reset_index(drop=True) |
| print(f"After dedup pass 2 (seq1, smiles_sequence): {len(df_smi)} (-{before - len(df_smi)})") |
|
|
| assert not df_smi.duplicated(subset=[COL_SEQ1, "smiles_sequence"]).any(), \ |
| "BUG: duplicate (seq1, smiles_sequence) pairs remain!" |
|
|
| print(f"\n[counts] WT rows={len(df_wt)} | SMILES rows={len(df_smi)}") |
|
|
| # ------------------------------------------------------------------
| # 3) Split WT and SMILES subsets separately
| # ------------------------------------------------------------------
| with section("split WT and SMILES separately"): |
| df_wt2 = make_distribution_matched_split(df_wt) |
| df_smi2 = make_distribution_matched_split(df_smi) |
|
|
| df_wt2.to_csv(OUT_ROOT / "binding_affinity_wt_meta_with_split.csv", index=False) |
| df_smi2.to_csv(OUT_ROOT / "binding_affinity_smiles_meta_with_split.csv", index=False) |
|
|
| # Model-ready views: target_sequence / sequence / label / split / ipTM.
| wt_view = prep_view(df_wt2, "wt_sequence", COL_WT_IPTM) |
| smi_view = prep_view(df_smi2, "smiles_sequence", COL_SMI_IPTM) |
|
|
| wt_train = wt_view[wt_view["split"] == "train"].reset_index(drop=True) |
| wt_val = wt_view[wt_view["split"] == "val"].reset_index(drop=True) |
| smi_train = smi_view[smi_view["split"] == "train"].reset_index(drop=True) |
| smi_val = smi_view[smi_view["split"] == "val"].reset_index(drop=True) |
|
|
| print(f"\nSplit sizes - WT: train={len(wt_train)} val={len(wt_val)}") |
| print(f"Split sizes - SMILES: train={len(smi_train)} val={len(smi_val)}") |
|
|
| # ------------------------------------------------------------------
| # 4) Shared ESM2 target encoder
| # ------------------------------------------------------------------
| print("\nLoading ESM2 (shared target encoder) ...") |
| esm_tok, esm_model = load_esm() |
|
|
| # ------------------------------------------------------------------
| # 5) WT branch
| # ------------------------------------------------------------------
| run_wt_branch(wt_train, wt_val, esm_tok, esm_model) |
|
|
| # ------------------------------------------------------------------
| # 6) SMILES branches
| # ------------------------------------------------------------------
| if RUN_PEPTIDECLM: |
| run_smiles_binder_model( |
| "peptideclm", smi_train, smi_val, |
| esm_tok, esm_model, |
| load_fn=load_peptideclm, |
| embed_fn=embed_peptideclm, |
| ) |
|
|
| if RUN_CHEMBERTA: |
| run_smiles_binder_model( |
| "chemberta", smi_train, smi_val, |
| esm_tok, esm_model, |
| load_fn=load_chemberta, |
| embed_fn=embed_chemberta, |
| ) |
|
|
| print(f"\n All done. Datasets saved under: {OUT_ROOT}") |
|
|
|
|
| if __name__ == "__main__": |
| main() |