# PeptiVerse/training_data_cleaned/binding_affinity_split.py
import math
import sys
from contextlib import contextmanager
from pathlib import Path
import numpy as np
import pandas as pd
import torch
from datasets import Dataset, DatasetDict
from tqdm import tqdm
from transformers import AutoModel, AutoModelForMaskedLM, AutoTokenizer, EsmModel
# ======================
# CONFIG
# ======================
ROOT = Path("<>") # CHANGE HERE
PROJ_ROOT = ROOT / "PeptiVerse"
CSV_PATH = PROJ_ROOT / "training_data" / "c-binding.csv"
OUT_ROOT = PROJ_ROOT / "training_data_cleaned" / "binding_affinity"
# ESM2 - target encoder (shared across all branches)
ESM_MODEL = "facebook/esm2_t33_650M_UR50D"
ESM_MAX_LEN = 1022
ESM_BATCH = 32
# PeptideCLM - SMILES binder encoder
sys.path.append(str(PROJ_ROOT))
from tokenizer.my_tokenizers import SMILES_SPE_Tokenizer
PEPTIDECLM_MODEL = "aaronfeller/PeptideCLM-23M-all"
TOKENIZER_VOCAB = str(PROJ_ROOT / "tokenizer" / "new_vocab.txt")
TOKENIZER_SPLITS = str(PROJ_ROOT / "tokenizer" / "new_splits.txt")
PEPTIDECLM_MAX_LEN = 768
PEPTIDECLM_BATCH = 128
# ChemBERTa - SMILES binder encoder
CHEMBERTA_MODEL = "DeepChem/ChemBERTa-77M-MLM"
CHEMBERTA_MAX_LEN = 512
CHEMBERTA_BATCH = 128
# Which SMILES binder models to run
RUN_PEPTIDECLM = True
RUN_CHEMBERTA = True
# CSV column names
COL_SEQ1 = "seq1"
COL_SEQ2 = "seq2"
COL_AFF = "affinity"
COL_F2S = "Fasta2SMILES"
COL_REACT = "REACT_SMILES"
COL_MERGE = "Merge_SMILES"
COL_WT_IPTM = "wt_iptm_score"
COL_SMI_IPTM = "smiles_iptm_score"
# Split config
TRAIN_FRAC = 0.80
RANDOM_SEED = 1986
AFFINITY_Q_BINS = 30
# Logging
QUIET = True
USE_TQDM = False
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# ======================
# Logging / progress
# ======================
def log(msg: str):
if not QUIET:
print(msg)
def pbar(it, **kwargs):
return tqdm(it, **kwargs) if USE_TQDM else it
@contextmanager
def section(title: str):
log(f"\n=== {title} ===")
yield
log(f"=== done: {title} ===")
# ======================
# Data Handling
# ======================
def has_uaa(seq: str) -> bool:
return "X" in str(seq).upper()
def pick_smiles(row) -> str | None:
"""Column Priority: Fasta2SMILES > REACT_SMILES > Merge_SMILES."""
for col in [COL_F2S, COL_REACT, COL_MERGE]:
val = row.get(col, None)
if val is not None and str(val).strip() not in ("", "nan", "None"):
return str(val).strip()
return None
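# Illustrative only (nothing below is executed by the pipeline): how pick_smiles
# resolves the column priority on a toy row. DataFrame.apply(..., axis=1) passes
# each row as a pandas Series, which supports .get() just like the real call site.
#
#   row = pd.Series({COL_F2S: "", COL_REACT: "CC(=O)NC", COL_MERGE: "CCO"})
#   pick_smiles(row)                           # -> "CC(=O)NC" (Fasta2SMILES empty, REACT_SMILES wins)
#   pick_smiles(pd.Series({COL_F2S: "nan"}))   # -> None       ("nan" counts as missing)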
def affinity_to_class(a: float) -> str:
if a >= 9.0:
return "High"
elif a >= 7.0:
return "Moderate"
else:
return "Low"
def make_distribution_matched_split(df: pd.DataFrame) -> pd.DataFrame:
df = df.copy()
df[COL_AFF] = pd.to_numeric(df[COL_AFF], errors="coerce")
df = df.dropna(subset=[COL_AFF]).reset_index(drop=True)
df["affinity_class"] = df[COL_AFF].apply(affinity_to_class)
try:
df["aff_bin"] = pd.qcut(df[COL_AFF], q=AFFINITY_Q_BINS, duplicates="drop")
strat_col = "aff_bin"
except Exception:
df["aff_bin"] = df["affinity_class"]
strat_col = "aff_bin"
rng = np.random.RandomState(RANDOM_SEED)
df["split"] = None
for _, g in df.groupby(strat_col, observed=True):
idx = g.index.to_numpy()
rng.shuffle(idx)
n_train = int(math.floor(len(idx) * TRAIN_FRAC))
df.loc[idx[:n_train], "split"] = "train"
df.loc[idx[n_train:], "split"] = "val"
df["split"] = df["split"].fillna("train")
return df
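# Optional sanity check (illustrative; not called anywhere in this script):
# compares a few affinity quantiles between the train and val partitions, which
# is one quick way to confirm the stratified split above kept the label
# distributions aligned. Assumes the frame already carries the "split" column
# added by make_distribution_matched_split.
def check_split_balance(df: pd.DataFrame, quantiles=(0.1, 0.25, 0.5, 0.75, 0.9)) -> pd.DataFrame:
    cols = {}
    for split_name, g in df.groupby("split"):
        cols[split_name] = g[COL_AFF].quantile(list(quantiles))
    # One column per split, indexed by quantile level; large train/val gaps
    # would suggest the per-bin shuffling did not balance the distribution.
    return pd.DataFrame(cols)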
def prep_view(df_in: pd.DataFrame, binder_seq_col: str, iptm_col: str) -> pd.DataFrame:
out = df_in.copy()
out["target_sequence"] = out[COL_SEQ1].astype(str).str.strip()
out["sequence"] = out[binder_seq_col].astype(str).str.strip()
out["label"] = pd.to_numeric(out[COL_AFF], errors="coerce")
out[iptm_col] = pd.to_numeric(out[iptm_col], errors="coerce")
out[COL_AFF] = pd.to_numeric(out[COL_AFF], errors="coerce")
out = out.dropna(subset=["target_sequence", "sequence", "label"]).reset_index(drop=True)
return out[["target_sequence", "sequence", "label", "split",
iptm_col, COL_AFF, "affinity_class"]]
# ======================
# Dataset builders
# ======================
def build_pooled_ds(view: pd.DataFrame, iptm_col: str,
tgt_embs: np.ndarray, bnd_embs: np.ndarray) -> Dataset:
"""Both target and binder are (N, H) pooled float32 arrays."""
return Dataset.from_dict({
"target_sequence": view["target_sequence"].tolist(),
"sequence": view["sequence"].tolist(),
"label": view["label"].astype(float).tolist(),
"target_embedding": tgt_embs, # (N, H_esm) float32
"binder_embedding": bnd_embs, # (N, H_binder) float32
"affinity": view[COL_AFF].astype(float).tolist(),
"affinity_class": view["affinity_class"].tolist(),
iptm_col: view[iptm_col].astype(float).tolist(),
})
def build_unpooled_ds(view: pd.DataFrame, iptm_col: str,
tgt_tok_embs, tgt_masks, tgt_lengths,
bnd_tok_embs, bnd_masks, bnd_lengths) -> Dataset:
"""
Per-token lists for both sides.
target_embedding[i] : (Lt_i, H_esm) float16 ndarray
binder_embedding[i] : (Lb_i, H_binder) float16 ndarray
"""
return Dataset.from_dict({
"target_sequence": view["target_sequence"].tolist(),
"sequence": view["sequence"].tolist(),
"label": view["label"].astype(float).tolist(),
"target_embedding": tgt_tok_embs,
"target_attention_mask": tgt_masks,
"target_length": tgt_lengths,
"binder_embedding": bnd_tok_embs,
"binder_attention_mask": bnd_masks,
"binder_length": bnd_lengths,
"affinity": view[COL_AFF].astype(float).tolist(),
"affinity_class": view["affinity_class"].tolist(),
iptm_col: view[iptm_col].astype(float).tolist(),
})
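# Illustrative collate sketch (not used by this script): a downstream DataLoader
# could re-pad the variable-length per-token embeddings stored by
# build_unpooled_ds. Field names follow the docstring above; the batch-first
# padded layout is an assumption of this sketch, not something the training
# code in this repo necessarily uses.
def _example_collate_unpooled(batch: list[dict]) -> dict:
    from torch.nn.utils.rnn import pad_sequence
    tgt = [torch.as_tensor(np.asarray(ex["target_embedding"], dtype=np.float16)) for ex in batch]
    bnd = [torch.as_tensor(np.asarray(ex["binder_embedding"], dtype=np.float16)) for ex in batch]
    tgt_mask = [torch.as_tensor(np.asarray(ex["target_attention_mask"], dtype=np.int8)) for ex in batch]
    bnd_mask = [torch.as_tensor(np.asarray(ex["binder_attention_mask"], dtype=np.int8)) for ex in batch]
    return {
        "target_embedding": pad_sequence(tgt, batch_first=True),            # (B, Lt_max, H_esm)
        "target_attention_mask": pad_sequence(tgt_mask, batch_first=True),  # (B, Lt_max)
        "binder_embedding": pad_sequence(bnd, batch_first=True),            # (B, Lb_max, H_binder)
        "binder_attention_mask": pad_sequence(bnd_mask, batch_first=True),  # (B, Lb_max)
        "label": torch.tensor([float(ex["label"]) for ex in batch], dtype=torch.float32),
    }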
# ======================
# ESM2 - shared target encoder
# ======================
def load_esm():
print(f" Loading ESM2: {ESM_MODEL}")
tok = AutoTokenizer.from_pretrained(ESM_MODEL)
model = EsmModel.from_pretrained(ESM_MODEL).to(DEVICE).eval()
return tok, model
@torch.no_grad()
def embed_esm_pooled(seqs, tok, model) -> np.ndarray:
"""Returns (N, H) float32 - mean-pooled over non-pad tokens."""
all_embs = []
for i in pbar(range(0, len(seqs), ESM_BATCH), desc=" ESM2 pooled"):
batch = seqs[i:i + ESM_BATCH]
enc = tok(batch, return_tensors="pt", padding=True,
truncation=True, max_length=ESM_MAX_LEN)
ids = enc["input_ids"].to(DEVICE)
mask = enc["attention_mask"].to(DEVICE)
h = model(input_ids=ids, attention_mask=mask).last_hidden_state
attn_f = mask.unsqueeze(-1).float()
pooled = ((h * attn_f).sum(dim=1) /
attn_f.sum(dim=1).clamp(min=1e-9)).cpu().numpy().astype(np.float32)
all_embs.append(pooled)
return np.vstack(all_embs)
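# Note: the masked mean above is pooled[b, j] = sum_i(mask[b, i] * h[b, i, j]) / sum_i(mask[b, i]),
# so padding positions contribute nothing. The ESM attention mask still covers the
# CLS/EOS tokens, so they are part of this pooled vector, whereas the unpooled
# variant below explicitly drops them.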
@torch.no_grad()
def embed_esm_unpooled(seqs, tok, model):
"""
Returns per-token lists (CLS/EOS/pad excluded).
tok_embs : list of (Lt_i, H) float16 arrays
masks : list of (Lt_i,) int8 arrays (all-ones)
lengths : list of int
"""
cls_id = tok.cls_token_id
eos_id = tok.eos_token_id
tok_embs, masks, lengths = [], [], []
for i in pbar(range(0, len(seqs), ESM_BATCH), desc=" ESM2 unpooled"):
batch = seqs[i:i + ESM_BATCH]
enc = tok(batch, return_tensors="pt", padding=True,
truncation=True, max_length=ESM_MAX_LEN)
ids = enc["input_ids"].to(DEVICE)
mask = enc["attention_mask"].to(DEVICE)
h = model(input_ids=ids, attention_mask=mask).last_hidden_state
for b in range(h.shape[0]):
keep = mask[b].bool()
if cls_id is not None:
keep = keep & (ids[b] != cls_id)
if eos_id is not None:
keep = keep & (ids[b] != eos_id)
emb = h[b, keep].cpu().to(torch.float16).numpy()
tok_embs.append(emb)
masks.append(np.ones(emb.shape[0], dtype=np.int8))
lengths.append(emb.shape[0])
return tok_embs, masks, lengths
# ======================
# Generic binder embedding helpers
# ======================
def _get_special_ids_t(tokenizer):
special_ids = sorted({
x for x in [
getattr(tokenizer, attr, None)
for attr in ("pad_token_id", "cls_token_id", "sep_token_id",
"bos_token_id", "eos_token_id", "mask_token_id")
] if x is not None
})
return (torch.tensor(special_ids, device=DEVICE, dtype=torch.long)
if special_ids else None)
def _pool_and_unpool(last_hidden, input_ids, attention_mask, special_ids_t):
"""Mean-pool over non-special valid tokens; also return per-token arrays."""
valid = attention_mask.bool()
if special_ids_t is not None:
valid = valid & (~torch.isin(input_ids, special_ids_t))
valid_f = valid.unsqueeze(-1).float()
pooled = (
torch.sum(last_hidden * valid_f, dim=1) /
torch.clamp(valid_f.sum(dim=1), min=1e-9)
).cpu().numpy().astype(np.float32)
tok_embs, masks, lengths = [], [], []
for b in range(last_hidden.shape[0]):
emb = last_hidden[b, valid[b]].cpu().to(torch.float16).numpy()
tok_embs.append(emb)
masks.append(np.ones(emb.shape[0], dtype=np.int8))
lengths.append(emb.shape[0])
return pooled, tok_embs, masks, lengths
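# Worked example: for one row with token ids [CLS, t1, t2, EOS, PAD] and
# attention_mask [1, 1, 1, 1, 0], a special_ids_t covering {CLS, EOS, PAD} gives
# valid = [False, True, True, False, False]; the pooled vector is then the mean
# of the t1/t2 hidden states, and the per-token output for that row has shape
# (2, H) with mask [1, 1] and length 2.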
# ======================
# PeptideCLM - SMILES binder encoder
# ======================
def load_peptideclm():
print(f" Loading PeptideCLM: {PEPTIDECLM_MODEL}")
tok = SMILES_SPE_Tokenizer(TOKENIZER_VOCAB, TOKENIZER_SPLITS)
model = (AutoModelForMaskedLM.from_pretrained(PEPTIDECLM_MODEL)
.roformer.to(DEVICE).eval())
return tok, model, _get_special_ids_t(tok)
@torch.no_grad()
def embed_peptideclm(seqs, tok, model, sid_t):
pooled_all, tok_all, mask_all, len_all = [], [], [], []
for i in pbar(range(0, len(seqs), PEPTIDECLM_BATCH), desc=" PeptideCLM binder"):
batch = seqs[i:i + PEPTIDECLM_BATCH]
enc = tok(batch, return_tensors="pt", padding=True,
truncation=True, max_length=PEPTIDECLM_MAX_LEN)
ids = enc["input_ids"].to(DEVICE)
mask = enc["attention_mask"].to(DEVICE)
h = model(input_ids=ids, attention_mask=mask).last_hidden_state
p, t, m, l = _pool_and_unpool(h, ids, mask, sid_t)
pooled_all.append(p); tok_all.extend(t); mask_all.extend(m); len_all.extend(l)
return np.vstack(pooled_all), tok_all, mask_all, len_all
# ======================
# ChemBERTa - SMILES binder encoder
# ======================
def load_chemberta():
print(f" Loading ChemBERTa: {CHEMBERTA_MODEL}")
tok = AutoTokenizer.from_pretrained(CHEMBERTA_MODEL)
model = AutoModel.from_pretrained(CHEMBERTA_MODEL).to(DEVICE).eval()
return tok, model, _get_special_ids_t(tok)
@torch.no_grad()
def embed_chemberta(seqs, tok, model, sid_t):
pooled_all, tok_all, mask_all, len_all = [], [], [], []
for i in pbar(range(0, len(seqs), CHEMBERTA_BATCH), desc=" ChemBERTa binder"):
batch = seqs[i:i + CHEMBERTA_BATCH]
enc = tok(batch, return_tensors="pt", padding=True,
truncation=True, max_length=CHEMBERTA_MAX_LEN)
ids = enc["input_ids"].to(DEVICE)
mask = enc["attention_mask"].to(DEVICE)
h = model(input_ids=ids, attention_mask=mask).last_hidden_state
p, t, m, l = _pool_and_unpool(h, ids, mask, sid_t)
pooled_all.append(p); tok_all.extend(t); mask_all.extend(m); len_all.extend(l)
return np.vstack(pooled_all), tok_all, mask_all, len_all
# ======================
# WT branch (ESM2 × ESM2)
# ======================
def run_wt_branch(wt_train: pd.DataFrame, wt_val: pd.DataFrame,
esm_tok, esm_model):
print("\n" + "="*55)
print(" Branch : WT (ESM2 target Γ— ESM2 binder)")
print("="*55)
pooled_splits, unpooled_splits = {}, {}
for split_name, view in [("train", wt_train), ("val", wt_val)]:
print(f"\n [{split_name}] {len(view)} rows")
targets = view["target_sequence"].tolist()
binders = view["sequence"].tolist()
tgt_pooled = embed_esm_pooled(targets, esm_tok, esm_model)
bnd_pooled = embed_esm_pooled(binders, esm_tok, esm_model)
tgt_tok_embs, tgt_masks, tgt_lengths = embed_esm_unpooled(targets, esm_tok, esm_model)
bnd_tok_embs, bnd_masks, bnd_lengths = embed_esm_unpooled(binders, esm_tok, esm_model)
pooled_splits[split_name] = build_pooled_ds(
view, COL_WT_IPTM, tgt_pooled, bnd_pooled)
unpooled_splits[split_name] = build_unpooled_ds(
view, COL_WT_IPTM,
tgt_tok_embs, tgt_masks, tgt_lengths,
bnd_tok_embs, bnd_masks, bnd_lengths)
pooled_out = OUT_ROOT / "pair_wt_wt_pooled"
unpooled_out = OUT_ROOT / "pair_wt_wt_unpooled"
DatasetDict(pooled_splits).save_to_disk(str(pooled_out))
DatasetDict(unpooled_splits).save_to_disk(str(unpooled_out))
print(f"\n WT pooled to {pooled_out}")
print(f" WT unpooled to {unpooled_out}")
# ======================
# SMILES branch (ESM2 × {PeptideCLM | ChemBERTa})
# ======================
def run_smiles_binder_model(name: str,
smi_train: pd.DataFrame, smi_val: pd.DataFrame,
esm_tok, esm_model,
load_fn, embed_fn):
print("\n" + "="*55)
print(f" Branch : SMILES (ESM2 target Γ— {name} binder)")
print("="*55)
binder_tok, binder_model, sid_t = load_fn()
pooled_splits, unpooled_splits = {}, {}
for split_name, view in [("train", smi_train), ("val", smi_val)]:
print(f"\n [{split_name}] {len(view)} rows")
targets = view["target_sequence"].tolist()
binders = view["sequence"].tolist()
print(" ESM2 target - pooled ...")
tgt_pooled = embed_esm_pooled(targets, esm_tok, esm_model)
print(" ESM2 target - unpooled ...")
tgt_tok_embs, tgt_masks, tgt_lengths = embed_esm_unpooled(
targets, esm_tok, esm_model)
print(f" {name} binder - pooled + unpooled ...")
bnd_pooled, bnd_tok_embs, bnd_masks, bnd_lengths = embed_fn(
binders, binder_tok, binder_model, sid_t)
pooled_splits[split_name] = build_pooled_ds(
view, COL_SMI_IPTM, tgt_pooled, bnd_pooled)
unpooled_splits[split_name] = build_unpooled_ds(
view, COL_SMI_IPTM,
tgt_tok_embs, tgt_masks, tgt_lengths,
bnd_tok_embs, bnd_masks, bnd_lengths)
suffix = "" if name.lower() == "peptideclm" else f"_{name.lower()}"
pooled_out = OUT_ROOT / f"pair_wt_smiles_pooled{suffix}"
unpooled_out = OUT_ROOT / f"pair_wt_smiles_unpooled{suffix}"
DatasetDict(pooled_splits).save_to_disk(str(pooled_out))
DatasetDict(unpooled_splits).save_to_disk(str(unpooled_out))
print(f"\n {name} pooled to {pooled_out}")
print(f" {name} unpooled to {unpooled_out}")
del binder_model
torch.cuda.empty_cache()
# ======================
# Main
# ======================
def main():
print(f"Device : {DEVICE}")
print(f"CSV : {CSV_PATH}")
print(f"Out : {OUT_ROOT}\n")
OUT_ROOT.mkdir(parents=True, exist_ok=True)
# ------------------------------------------------------------------
# 1. Load + dedup
# ------------------------------------------------------------------
with section("load csv + dedup"):
df = pd.read_csv(CSV_PATH)
print(f"Raw rows: {len(df)}")
df["orig_idx"] = df.index # traceability only
for c in [COL_SEQ1, COL_SEQ2, COL_F2S, COL_REACT, COL_MERGE]:
if c in df.columns:
df[c] = df[c].apply(lambda x: x.strip() if isinstance(x, str) else x)
for col in [COL_SEQ1, COL_SEQ2, COL_AFF, COL_F2S, COL_REACT, COL_WT_IPTM, COL_SMI_IPTM]:
if col not in df.columns:
raise ValueError(f"Missing required column: '{col}'")
dedup_cols = [c for c in [COL_SEQ1, COL_SEQ2, COL_F2S, COL_REACT, COL_MERGE]
if c in df.columns]
before = len(df)
df = df.drop_duplicates(subset=dedup_cols, keep="first").reset_index(drop=True)
print(f"After dedup pass 1 (raw columns) : {len(df)} (-{before - len(df)})")
df[COL_AFF] = pd.to_numeric(df[COL_AFF], errors="coerce")
# ------------------------------------------------------------------
# 2. Prepare per-branch subsets
# ------------------------------------------------------------------
with section("prepare WT / SMILES subsets"):
# ── WT branch ──────────────────────────────────────────────────
# Both seq1 and seq2 must be canonical (no X) for ESM2
df_wt = df.copy()
df_wt["wt_sequence"] = df_wt[COL_SEQ2].astype(str).str.strip()
df_wt = df_wt.dropna(subset=[COL_AFF])
df_wt = df_wt[~df_wt[COL_SEQ1].astype(str).str.contains("X", case=False, na=False)]
df_wt = df_wt[df_wt["wt_sequence"] != ""]
df_wt = df_wt[~df_wt["wt_sequence"].str.contains("X", case=False, na=False)]
df_wt = df_wt.reset_index(drop=True)
# ── SMILES branch ──────────────────────────────────────────────
# seq1 must be canonical (no X) for ESM2; binder SMILES picked
# by priority (Fasta2SMILES > REACT_SMILES > Merge_SMILES), then
# dedup pass 2 on (seq1, picked smiles_sequence)
df_smi = df.copy()
df_smi = df_smi.dropna(subset=[COL_AFF])
df_smi = df_smi[
pd.to_numeric(df_smi[COL_SMI_IPTM], errors="coerce").notna()
]
df_smi = df_smi[~df_smi[COL_SEQ1].astype(str).str.contains("X", case=False, na=False)]
df_smi = df_smi.reset_index(drop=True)
df_smi["smiles_sequence"] = df_smi.apply(pick_smiles, axis=1)
df_smi = df_smi[df_smi["smiles_sequence"].notna()].reset_index(drop=True)
print(f"After requiring β‰₯1 valid SMILES : {len(df_smi)}")
# Dedup pass 2: (seq1, picked smiles_sequence)
before = len(df_smi)
df_smi = df_smi.drop_duplicates(
subset=[COL_SEQ1, "smiles_sequence"], keep="first"
).reset_index(drop=True)
print(f"After dedup pass 2 (seq1, smiles_sequence): {len(df_smi)} (-{before - len(df_smi)})")
assert not df_smi.duplicated(subset=[COL_SEQ1, "smiles_sequence"]).any(), \
"BUG: duplicate (seq1, smiles_sequence) pairs remain!"
print(f"\n[counts] WT rows={len(df_wt)} | SMILES rows={len(df_smi)}")
# ------------------------------------------------------------------
# 3. Split
# ------------------------------------------------------------------
with section("split WT and SMILES separately"):
df_wt2 = make_distribution_matched_split(df_wt)
df_smi2 = make_distribution_matched_split(df_smi)
df_wt2.to_csv(OUT_ROOT / "binding_affinity_wt_meta_with_split.csv", index=False)
df_smi2.to_csv(OUT_ROOT / "binding_affinity_smiles_meta_with_split.csv", index=False)
# ------------------------------------------------------------------
# 4. Build split views
# ------------------------------------------------------------------
wt_view = prep_view(df_wt2, "wt_sequence", COL_WT_IPTM)
smi_view = prep_view(df_smi2, "smiles_sequence", COL_SMI_IPTM)
wt_train = wt_view[wt_view["split"] == "train"].reset_index(drop=True)
wt_val = wt_view[wt_view["split"] == "val"].reset_index(drop=True)
smi_train = smi_view[smi_view["split"] == "train"].reset_index(drop=True)
smi_val = smi_view[smi_view["split"] == "val"].reset_index(drop=True)
print(f"\nSplit sizes - WT: train={len(wt_train)} val={len(wt_val)}")
print(f"Split sizes - SMILES: train={len(smi_train)} val={len(smi_val)}")
# ------------------------------------------------------------------
# 5. Load ESM2 once - shared across all branches
# ------------------------------------------------------------------
print("\nLoading ESM2 (shared target encoder) ...")
esm_tok, esm_model = load_esm()
# ------------------------------------------------------------------
# 6. WT branch
# ------------------------------------------------------------------
run_wt_branch(wt_train, wt_val, esm_tok, esm_model)
# ------------------------------------------------------------------
# 7. SMILES branches
# ------------------------------------------------------------------
if RUN_PEPTIDECLM:
run_smiles_binder_model(
"peptideclm", smi_train, smi_val,
esm_tok, esm_model,
load_fn=load_peptideclm,
embed_fn=embed_peptideclm,
)
if RUN_CHEMBERTA:
run_smiles_binder_model(
"chemberta", smi_train, smi_val,
esm_tok, esm_model,
load_fn=load_chemberta,
embed_fn=embed_chemberta,
)
print(f"\n All done. Datasets saved under: {OUT_ROOT}")
if __name__ == "__main__":
main()
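# Example of consuming the output (illustrative; run in a separate session after
# this script finishes). load_from_disk is the standard Hugging Face datasets
# loader, and the directory names mirror the save_to_disk calls above:
#
#   from datasets import load_from_disk
#   dsd = load_from_disk(str(OUT_ROOT / "pair_wt_wt_pooled"))
#   row = dsd["train"][0]
#   emb = np.asarray(row["target_embedding"], dtype=np.float32)   # (H_esm,)
#   print(row["affinity_class"], row["label"], emb.shape)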