| import os, json |
| from pathlib import Path |
| import numpy as np |
| import torch |
| import torch.nn as nn |
| from torch.utils.data import DataLoader |
| import optuna |
| from datasets import load_from_disk, DatasetDict |
| from scipy.stats import spearmanr |
| from lightning.pytorch import seed_everything |
# Fix the global random seeds once, at import time, before any model is built.
seed_everything(1986)

# Use the first CUDA device when one is available, otherwise run on the CPU.
DEVICE = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
|
|
|
|
def safe_spearmanr(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """Return Spearman's rho between ``y_true`` and ``y_pred``, or 0.0 when undefined.

    Spearman correlation is undefined for fewer than two observations or when
    either input is constant (e.g. a model that predicts the same value for
    every sample). Those cases are answered with 0.0 before calling SciPy, so
    repeated evaluations do not emit a warning each time. Any other NaN result
    (for instance caused by NaNs in the inputs) is also mapped to 0.0.

    Args:
        y_true: Ground-truth values; anything ``np.asarray`` accepts, flattened to 1-D.
        y_pred: Predictions aligned element-wise with ``y_true``.

    Returns:
        The correlation as a Python float, or 0.0 when it is undefined.

    Raises:
        ValueError: if the two inputs contain a different number of elements.
    """
    a = np.asarray(y_true, dtype=np.float64).ravel()
    b = np.asarray(y_pred, dtype=np.float64).ravel()
    if a.size != b.size:
        raise ValueError(
            f"y_true and y_pred must have the same length ({a.size} != {b.size})"
        )
    # Degenerate inputs have no defined rank correlation; skip SciPy (and its warning).
    if a.size < 2 or np.all(a == a[0]) or np.all(b == b[0]):
        return 0.0
    rho = spearmanr(a, b).correlation
    if rho is None or np.isnan(rho):
        return 0.0
    return float(rho)
|
|
|
|
| |
| |
| |
| |
| |
def affinity_to_class_tensor(
    y: torch.Tensor, *, high_threshold: float = 9.0, low_threshold: float = 7.0
) -> torch.Tensor:
    """Bin continuous affinity labels into three integer classes.

    Class ids:
        0 -- ``y >= high_threshold``
        1 -- everything else (the band ``[low_threshold, high_threshold)`` and,
             because neither comparison holds for NaN, any NaN labels)
        2 -- ``y < low_threshold``
    If the thresholds overlap, values matching both rules get class 2.

    Args:
        y: Tensor of affinity values (any shape).
        high_threshold: Inclusive lower bound of class 0 (default 9.0).
        low_threshold: Exclusive upper bound of class 2 (default 7.0).

    Returns:
        A ``torch.long`` tensor with the same shape and device as ``y``.
    """
    cls = torch.ones_like(y, dtype=torch.long)  # default: middle band
    cls[y >= high_threshold] = 0
    cls[y < low_threshold] = 2  # applied last, so it wins on overlap
    return cls
|
|
|
|
| |
| |
| |
def load_split_paired(path: str):
    """Load a saved ``DatasetDict`` from ``path`` and return its ``(train, val)`` splits.

    Raises:
        ValueError: if ``path`` does not contain a ``DatasetDict``, or if the
            dict lacks a ``"train"`` or a ``"val"`` split.
    """
    dataset_dict = load_from_disk(path)
    if not isinstance(dataset_dict, DatasetDict):
        raise ValueError(f"Expected DatasetDict at {path}")
    has_both_splits = all(split in dataset_dict for split in ("train", "val"))
    if not has_both_splits:
        raise ValueError(f"DatasetDict missing train/val at {path}")
    return dataset_dict["train"], dataset_dict["val"]
|
|
|
|
| |
| |
| |
def collate_pair_pooled(batch):
    """Stack pooled (one vector per sequence) target/binder embeddings into a batch.

    Each row must provide ``"target_embedding"``, ``"label"`` and either
    ``"binder_embedding"`` or, as a fallback, ``"embedding"``. The key is chosen
    by inspecting the first row. Embeddings are gathered into one float32 NumPy
    array before conversion. That is much faster than ``torch.tensor`` on a
    list of lists, and it also accepts rows that already hold NumPy arrays.

    Returns:
        ``(Pt, Pb, y)``: float32 tensors of shape ``(B, Ht)``, ``(B, Hb)`` and ``(B,)``.
    """
    binder_key = "binder_embedding" if "binder_embedding" in batch[0] else "embedding"
    Pt = torch.from_numpy(
        np.asarray([x["target_embedding"] for x in batch], dtype=np.float32)
    )
    Pb = torch.from_numpy(np.asarray([x[binder_key] for x in batch], dtype=np.float32))
    y = torch.tensor([float(x["label"]) for x in batch], dtype=torch.float32)
    return Pt, Pb, y
|
|
|
|
| |
| |
| |
def _pad_token_embeddings(embeddings, masks):
    """Zero-pad ``(L_i, H)`` embeddings to ``(B, max L_i, H)`` plus a ``(B, max L_i)`` bool mask."""
    arrays = [np.asarray(e, dtype=np.float32) for e in embeddings]
    max_len = max(a.shape[0] for a in arrays)
    hidden = arrays[0].shape[1]  # hidden size is read from the first row
    padded = np.zeros((len(arrays), max_len, hidden), dtype=np.float32)
    valid = np.zeros((len(arrays), max_len), dtype=bool)
    for i, (arr, mask) in enumerate(zip(arrays, masks)):
        n = arr.shape[0]
        padded[i, :n] = arr
        # Only the first n mask entries describe real rows of this embedding.
        valid[i, :n] = np.asarray(mask[:n], dtype=bool)
    return torch.from_numpy(padded), torch.from_numpy(valid)


def collate_pair_unpooled(batch):
    """Right-pad per-token target/binder embeddings into dense tensors with masks.

    Each row must provide ``"target_embedding"`` / ``"binder_embedding"`` (a
    nested ``(L, H)`` sequence or array), the matching
    ``"target_attention_mask"`` / ``"binder_attention_mask"`` (at least ``L``
    entries; truthy = real token), and ``"label"``.

    The padded length is taken from the embeddings themselves, not from the
    ``"target_length"`` / ``"binder_length"`` fields. If a stored length were
    smaller than the number of embedding rows, a buffer sized from it could not
    hold the row and the copy would fail. Padding positions are zero and are
    marked ``False`` in the mask.

    Returns:
        ``(Pt, Mt, Pb, Mb, y)`` with shapes ``(B, Lt, Ht)``, ``(B, Lt)``,
        ``(B, Lb, Hb)``, ``(B, Lb)`` and ``(B,)``. Embeddings and labels are
        float32; masks are bool.
    """
    Pt, Mt = _pad_token_embeddings(
        [x["target_embedding"] for x in batch],
        [x["target_attention_mask"] for x in batch],
    )
    Pb, Mb = _pad_token_embeddings(
        [x["binder_embedding"] for x in batch],
        [x["binder_attention_mask"] for x in batch],
    )
    y = torch.tensor([float(x["label"]) for x in batch], dtype=torch.float32)
    return Pt, Mt, Pb, Mb, y
|
|
|
|
| |
| |
| |
class CrossAttnPooled(nn.Module):
    """Bidirectional cross-attention head for pooled (one vector per protein) embeddings.

    Each pooled vector is projected to ``hidden`` and treated as a length-1
    sequence in ``(seq=1, batch, hidden)`` layout (``batch_first=False``).
    Within every layer, the target first attends to the binder. The binder then
    attends to the already-updated target. Each attention step is followed by
    post-norm residual attention and feed-forward blocks.

    With a single key per sequence, the softmax weight is always 1 (apart from
    attention dropout in training). The "attention" therefore reduces to a
    learned projection of the other side's vector.

    The final target and binder vectors are concatenated and passed through a
    shared MLP. The MLP feeds a scalar regression head and a 3-way
    classification head.
    """

    def __init__(self, Ht, Hb, hidden=512, n_heads=8, n_layers=3, dropout=0.1):
        super().__init__()
        self.t_proj = nn.Sequential(nn.Linear(Ht, hidden), nn.LayerNorm(hidden))
        self.b_proj = nn.Sequential(nn.Linear(Hb, hidden), nn.LayerNorm(hidden))
        # Submodules are created in a fixed order so that seeded initialisation
        # and state_dict keys are reproducible.
        self.layers = nn.ModuleList(
            [self._make_layer(hidden, n_heads, dropout) for _ in range(n_layers)]
        )
        self.shared = nn.Sequential(nn.Linear(2 * hidden, hidden), nn.GELU(), nn.Dropout(dropout))
        self.reg = nn.Linear(hidden, 1)
        self.cls = nn.Linear(hidden, 3)

    @staticmethod
    def _make_layer(hidden, n_heads, dropout):
        """Build one target<->binder cross-attention layer as a ModuleDict."""

        def feed_forward():
            return nn.Sequential(
                nn.Linear(hidden, 4 * hidden),
                nn.GELU(),
                nn.Dropout(dropout),
                nn.Linear(4 * hidden, hidden),
            )

        return nn.ModuleDict({
            "attn_tb": nn.MultiheadAttention(hidden, n_heads, dropout=dropout, batch_first=False),
            "attn_bt": nn.MultiheadAttention(hidden, n_heads, dropout=dropout, batch_first=False),
            "n1t": nn.LayerNorm(hidden),
            "n2t": nn.LayerNorm(hidden),
            "n1b": nn.LayerNorm(hidden),
            "n2b": nn.LayerNorm(hidden),
            "fft": feed_forward(),
            "ffb": feed_forward(),
        })

    def forward(self, t_vec, b_vec):
        # (B, H) -> (1, B, hidden): a single-token sequence per sample.
        tgt = self.t_proj(t_vec).unsqueeze(0)
        bnd = self.b_proj(b_vec).unsqueeze(0)

        for layer in self.layers:
            # LayerNorm normalises only the last dimension, so it is applied
            # directly in the (seq, batch, hidden) layout.
            update, _ = layer["attn_tb"](tgt, bnd, bnd)
            tgt = layer["n1t"](tgt + update)
            tgt = layer["n2t"](tgt + layer["fft"](tgt))

            update, _ = layer["attn_bt"](bnd, tgt, tgt)
            bnd = layer["n1b"](bnd + update)
            bnd = layer["n2b"](bnd + layer["ffb"](bnd))

        features = self.shared(torch.cat([tgt[0], bnd[0]], dim=-1))
        return self.reg(features).squeeze(-1), self.cls(features)
|
|
|
|
class CrossAttnUnpooled(nn.Module):
    """Bidirectional cross-attention head over per-token target/binder embeddings.

    Inputs are batch-first token sequences (``batch_first=True`` attention)
    with boolean validity masks (``True`` = real token). Within every layer,
    the target tokens first attend to the binder tokens. The binder tokens then
    attend to the already-updated target tokens. Each attention step is
    followed by post-norm residual attention and feed-forward blocks.

    Padding is excluded as attention keys and from the final masked-mean
    pooling. The two pooled vectors feed a shared MLP with a scalar regression
    head and a 3-way classification head.
    """

    def __init__(self, Ht, Hb, hidden=512, n_heads=8, n_layers=3, dropout=0.1):
        super().__init__()
        self.t_proj = nn.Sequential(nn.Linear(Ht, hidden), nn.LayerNorm(hidden))
        self.b_proj = nn.Sequential(nn.Linear(Hb, hidden), nn.LayerNorm(hidden))

        self.layers = nn.ModuleList([])
        for _ in range(n_layers):
            self.layers.append(nn.ModuleDict({
                "attn_tb": nn.MultiheadAttention(hidden, n_heads, dropout=dropout, batch_first=True),
                "attn_bt": nn.MultiheadAttention(hidden, n_heads, dropout=dropout, batch_first=True),
                "n1t": nn.LayerNorm(hidden),
                "n2t": nn.LayerNorm(hidden),
                "n1b": nn.LayerNorm(hidden),
                "n2b": nn.LayerNorm(hidden),
                "fft": nn.Sequential(nn.Linear(hidden, 4*hidden), nn.GELU(), nn.Dropout(dropout), nn.Linear(4*hidden, hidden)),
                "ffb": nn.Sequential(nn.Linear(hidden, 4*hidden), nn.GELU(), nn.Dropout(dropout), nn.Linear(4*hidden, hidden)),
            }))

        self.shared = nn.Sequential(nn.Linear(2*hidden, hidden), nn.GELU(), nn.Dropout(dropout))
        self.reg = nn.Linear(hidden, 1)
        self.cls = nn.Linear(hidden, 3)

    def masked_mean(self, X, M):
        """Average ``X`` over dim 1 using only positions where ``M`` is True.

        Rows with no valid positions return zeros; the denominator is clamped
        to at least 1.
        """
        Mf = M.unsqueeze(-1).float()
        denom = Mf.sum(dim=1).clamp(min=1.0)
        return (X * Mf).sum(dim=1) / denom

    @staticmethod
    def _key_padding_mask(M):
        """Convert a validity mask into an MHA ``key_padding_mask`` that never hides every key.

        A row without any valid token would mask all keys. The attention
        softmax would then see only ``-inf`` and return NaN, and that NaN
        propagates into the batch-mean loss and the gradients. For such rows
        the first position is left visible, which keeps the computation finite
        (if meaningless). That row's pooled vector for the empty side is still
        zero, because ``masked_mean`` uses the original mask.
        """
        kp = ~M
        if kp.size(1) > 0:
            fully_padded = kp.all(dim=1)
            kp[:, 0] = kp[:, 0] & ~fully_padded
        return kp

    def forward(self, T, Mt, B, Mb):
        T = self.t_proj(T)
        Bx = self.b_proj(B)

        # True = position ignored as an attention key.
        kp_t = self._key_padding_mask(Mt)
        kp_b = self._key_padding_mask(Mb)

        for L in self.layers:
            # Target tokens attend to binder tokens.
            T_attn, _ = L["attn_tb"](T, Bx, Bx, key_padding_mask=kp_b)
            T = L["n1t"](T + T_attn)
            T = L["n2t"](T + L["fft"](T))

            # Binder tokens attend to the freshly updated target tokens.
            B_attn, _ = L["attn_bt"](Bx, T, T, key_padding_mask=kp_t)
            Bx = L["n1b"](Bx + B_attn)
            Bx = L["n2b"](Bx + L["ffb"](Bx))

        t_pool = self.masked_mean(T, Mt)
        b_pool = self.masked_mean(Bx, Mb)
        z = torch.cat([t_pool, b_pool], dim=-1)
        h = self.shared(z)
        return self.reg(h).squeeze(-1), self.cls(h)
|
|
|
|
| |
| |
| |
@torch.no_grad()
def eval_spearman_pooled(model, loader):
    """Score a pooled model on ``loader`` by Spearman's rho of its regression output.

    Switches the model to eval mode and leaves it there. The decorator disables
    autograd. Only the regression prediction is used; the classification
    logits are ignored.
    """
    model.eval()
    labels_seen, preds_seen = [], []
    for t_vec, b_vec, labels in loader:
        reg_out, _ = model(*(x.to(DEVICE, non_blocking=True) for x in (t_vec, b_vec)))
        labels_seen.append(labels.numpy())
        preds_seen.append(reg_out.detach().cpu().numpy())
    return safe_spearmanr(np.concatenate(labels_seen), np.concatenate(preds_seen))
|
|
@torch.no_grad()
def eval_spearman_unpooled(model, loader):
    """Score an unpooled model on ``loader`` by Spearman's rho of its regression output.

    Switches the model to eval mode and leaves it there. The decorator disables
    autograd. The padded tokens and masks are moved to ``DEVICE``; labels stay
    on the CPU.
    """
    model.eval()
    labels_seen, preds_seen = [], []
    for T, Mt, B, Mb, labels in loader:
        inputs = [x.to(DEVICE, non_blocking=True) for x in (T, Mt, B, Mb)]
        reg_out, _ = model(*inputs)
        labels_seen.append(labels.numpy())
        preds_seen.append(reg_out.detach().cpu().numpy())
    return safe_spearmanr(np.concatenate(labels_seen), np.concatenate(preds_seen))
|
|
def train_one_epoch_pooled(model, loader, opt, loss_reg, loss_cls, cls_w=1.0, clip=1.0):
    """Run one optimisation pass over ``loader`` for a pooled model.

    Per batch, the objective is
    ``loss_reg(pred, y) + cls_w * loss_cls(logits, y_cls)``, where ``y_cls``
    bins the labels with ``affinity_to_class_tensor``. When ``clip`` is not
    None, the global gradient norm is clipped to it before every optimiser step.
    """

    def to_device(x):
        return x.to(DEVICE, non_blocking=True)

    model.train()
    for batch in loader:
        t_vec, b_vec, labels = (to_device(x) for x in batch)
        label_classes = affinity_to_class_tensor(labels)

        opt.zero_grad(set_to_none=True)
        reg_out, logits = model(t_vec, b_vec)
        loss = loss_reg(reg_out, labels) + cls_w * loss_cls(logits, label_classes)
        loss.backward()
        if clip is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        opt.step()
|
|
def train_one_epoch_unpooled(model, loader, opt, loss_reg, loss_cls, cls_w=1.0, clip=1.0):
    """Run one optimisation pass over ``loader`` for an unpooled (per-token) model.

    Each batch is ``(T, Mt, B, Mb, y)``. Per batch, the objective is
    ``loss_reg(pred, y) + cls_w * loss_cls(logits, y_cls)``, where ``y_cls``
    bins the labels with ``affinity_to_class_tensor``. When ``clip`` is not
    None, the global gradient norm is clipped to it before every optimiser step.
    """

    def to_device(x):
        return x.to(DEVICE, non_blocking=True)

    model.train()
    for batch in loader:
        T, Mt, B, Mb, labels = (to_device(x) for x in batch)
        label_classes = affinity_to_class_tensor(labels)

        opt.zero_grad(set_to_none=True)
        reg_out, logits = model(T, Mt, B, Mb)
        loss = loss_reg(reg_out, labels) + cls_w * loss_cls(logits, label_classes)
        loss.backward()
        if clip is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        opt.step()
|
|
|
|
| |
| |
| |
def objective_crossattn(trial: optuna.Trial, mode: str, train_ds, val_ds) -> float:
    """Optuna objective: train one cross-attention model and return its best validation rho.

    Samples the hyper-parameters, then builds the pooled or unpooled model and
    loaders; any ``mode`` other than ``"pooled"`` is treated as unpooled.
    Training runs for at most 60 epochs, with early stopping after 10 epochs
    without improvement. Validation Spearman rho is reported to the trial every
    epoch so the pruner can abort the trial (``optuna.TrialPruned``).
    """
    # The suggestion order is kept stable so the sampler sees the same sequence.
    lr = trial.suggest_float("lr", 1e-5, 3e-3, log=True)
    wd = trial.suggest_float("weight_decay", 1e-10, 1e-2, log=True)
    dropout = trial.suggest_float("dropout", 0.0, 0.4)
    hidden = trial.suggest_categorical("hidden_dim", [256, 384, 512, 768])
    n_heads = trial.suggest_categorical("n_heads", [4, 8])
    n_layers = trial.suggest_int("n_layers", 1, 4)
    cls_w = trial.suggest_float("cls_weight", 0.1, 2.0, log=True)
    batch = trial.suggest_categorical("batch_size", [16, 32, 64, 128])

    first = train_ds[0]
    if mode == "pooled":
        binder_key = "binder_embedding" if "binder_embedding" in train_ds.column_names else "embedding"
        Ht, Hb = len(first["target_embedding"]), len(first[binder_key])
        model_cls, collate = CrossAttnPooled, collate_pair_pooled
        train_fn, eval_fn = train_one_epoch_pooled, eval_spearman_pooled
    else:
        Ht, Hb = len(first["target_embedding"][0]), len(first["binder_embedding"][0])
        model_cls, collate = CrossAttnUnpooled, collate_pair_unpooled
        train_fn, eval_fn = train_one_epoch_unpooled, eval_spearman_unpooled

    model = model_cls(Ht, Hb, hidden=hidden, n_heads=n_heads, n_layers=n_layers, dropout=dropout).to(DEVICE)
    loader_kwargs = dict(batch_size=batch, num_workers=4, pin_memory=True, collate_fn=collate)
    train_loader = DataLoader(train_ds, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_ds, shuffle=False, **loader_kwargs)

    opt = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)
    loss_reg = nn.MSELoss()
    loss_cls = nn.CrossEntropyLoss()

    patience = 10
    best_rho = -1e9
    stale_epochs = 0
    for epoch in range(1, 61):
        train_fn(model, train_loader, opt, loss_reg, loss_cls, cls_w=cls_w)
        rho = eval_fn(model, val_loader)

        trial.report(rho, epoch)
        if trial.should_prune():
            raise optuna.TrialPruned()

        if rho > best_rho + 1e-6:
            best_rho, stale_epochs = rho, 0
            continue
        stale_epochs += 1
        if stale_epochs >= patience:
            break

    return float(best_rho)
|
|
|
|
| |
| |
| |
def _build_refit_components(mode, train_ds, val_ds, params):
    """Instantiate (model, train_loader, val_loader, train_fn, eval_fn) for ``mode`` from tuned ``params``."""
    batch = int(params["batch_size"])
    if mode == "pooled":
        Ht = len(train_ds[0]["target_embedding"])
        binder_key = "binder_embedding" if "binder_embedding" in train_ds.column_names else "embedding"
        Hb = len(train_ds[0][binder_key])
        model_cls, collate = CrossAttnPooled, collate_pair_pooled
        train_fn, eval_fn = train_one_epoch_pooled, eval_spearman_pooled
    else:
        Ht = len(train_ds[0]["target_embedding"][0])
        Hb = len(train_ds[0]["binder_embedding"][0])
        model_cls, collate = CrossAttnUnpooled, collate_pair_unpooled
        train_fn, eval_fn = train_one_epoch_unpooled, eval_spearman_unpooled

    model = model_cls(
        Ht,
        Hb,
        hidden=int(params["hidden_dim"]),
        n_heads=int(params["n_heads"]),
        n_layers=int(params["n_layers"]),
        dropout=float(params["dropout"]),
    ).to(DEVICE)
    train_loader = DataLoader(train_ds, batch_size=batch, shuffle=True, num_workers=4, pin_memory=True, collate_fn=collate)
    val_loader = DataLoader(val_ds, batch_size=batch, shuffle=False, num_workers=4, pin_memory=True, collate_fn=collate)
    return model, train_loader, val_loader, train_fn, eval_fn


def _refit_with_early_stopping(model, train_loader, val_loader, train_fn, eval_fn, opt, cls_w,
                               max_epochs=200, patience=20):
    """Train until validation rho stalls for ``patience`` epochs, restore the best weights, return the best rho."""
    loss_reg = nn.MSELoss()
    loss_cls = nn.CrossEntropyLoss()
    best_rho = -1e9
    bad = 0
    best_state = None
    for _ in range(max_epochs):
        train_fn(model, train_loader, opt, loss_reg, loss_cls, cls_w=cls_w)
        rho = eval_fn(model, val_loader)
        if rho > best_rho + 1e-6:
            best_rho = rho
            bad = 0
            # Snapshot on the CPU so later epochs cannot overwrite the best weights.
            best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
        else:
            bad += 1
            if bad >= patience:
                break
    if best_state is not None:
        model.load_state_dict(best_state)
    return best_rho


def run(dataset_path: str, out_dir: str, mode: str, n_trials: int = 50, seed=1986):
    """Tune a cross-attention model with Optuna, refit the best configuration, and save it.

    Steps:
      1. Load the ``train``/``val`` splits from ``dataset_path``.
      2. Run ``n_trials`` trials of ``objective_crossattn`` with median pruning.
      3. Write all trials to ``optuna_trials.csv``.
      4. Retrain the best configuration for up to 200 epochs, early-stopping on
         validation rho (patience 20) and restoring the best epoch's weights.
      5. Save ``best_model.pt`` and ``best_params.json`` under ``out_dir``
         (created if missing).

    Args:
        dataset_path: Path to a saved ``DatasetDict`` with ``train`` and ``val`` splits.
        out_dir: Output directory.
        mode: ``"pooled"``; any other value selects the unpooled (per-token) model.
        n_trials: Number of Optuna trials.
        seed: Seed for the Optuna TPE sampler and for re-seeding the global
            RNGs right before the refit, so that both the search and the final
            model are reproducible. ``None`` leaves the sampler unseeded and
            skips the re-seeding.

    Raises:
        ValueError: propagated from Optuna if no trial completed (``study.best_trial``).
    """
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    train_ds, val_ds = load_split_paired(dataset_path)
    print(f"[Data] Train={len(train_ds)} Val={len(val_ds)} | mode={mode}")

    # Optuna samplers keep their own RNG; the global seeding done at import
    # time does not make the suggested hyper-parameters reproducible.
    sampler = optuna.samplers.TPESampler(seed=seed)
    study = optuna.create_study(direction="maximize", sampler=sampler, pruner=optuna.pruners.MedianPruner())
    study.optimize(lambda t: objective_crossattn(t, mode, train_ds, val_ds), n_trials=n_trials)

    study.trials_dataframe().to_csv(out_dir / "optuna_trials.csv", index=False)
    best_params = dict(study.best_trial.params)

    if seed is not None:
        # Make the refit's initialisation and shuffling independent of how
        # much randomness the tuning trials consumed.
        seed_everything(seed)

    model, train_loader, val_loader, train_fn, eval_fn = _build_refit_components(
        mode, train_ds, val_ds, best_params
    )
    opt = torch.optim.AdamW(
        model.parameters(),
        lr=float(best_params["lr"]),
        weight_decay=float(best_params["weight_decay"]),
    )
    best_rho = _refit_with_early_stopping(
        model, train_loader, val_loader, train_fn, eval_fn, opt,
        cls_w=float(best_params["cls_weight"]),
    )

    torch.save({"mode": mode, "best_params": best_params, "state_dict": model.state_dict()}, out_dir / "best_model.pt")
    with open(out_dir / "best_params.json", "w") as f:
        json.dump(best_params, f, indent=2)

    print(f"[DONE] {out_dir} | best_optuna_rho={study.best_value:.4f} | refit_best_rho={best_rho:.4f}")
|
|
|
|
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset_path", type=str, required=True, help="Paired DatasetDict path (pair_*)")
    parser.add_argument("--mode", type=str, choices=["pooled", "unpooled"], required=True)
    parser.add_argument("--out_dir", type=str, required=True)
    parser.add_argument("--n_trials", type=int, default=50)
    cli = parser.parse_args()

    # Hand the parsed options to the tune-then-refit pipeline.
    run(
        dataset_path=cli.dataset_path,
        out_dir=cli.out_dir,
        mode=cli.mode,
        n_trials=cli.n_trials,
    )
|
|