| import os |
| import json |
| import argparse |
| import numpy as np |
| import pandas as pd |
| import torch |
| import torch.nn as nn |
| from torch.utils.data import DataLoader |
| from datasets import load_from_disk, DatasetDict |
| from scipy.stats import spearmanr |
| from scipy import stats as scipy_stats |
| from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score |
| from lightning.pytorch import seed_everything |
| import sys |
| sys.path.append(os.path.dirname(os.path.abspath(__file__))) |
| from binding_training import ( |
| CrossAttnPooled, |
| CrossAttnUnpooled, |
| collate_pair_pooled, |
| collate_pair_unpooled, |
| eval_spearman_pooled, |
| eval_spearman_unpooled, |
| train_one_epoch_pooled, |
| train_one_epoch_unpooled, |
| affinity_to_class_tensor, |
| safe_spearmanr, |
| ) |
|
|
# Use the first CUDA GPU when available, otherwise fall back to CPU.
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
|
def load_split_paired(path: str):
    """Load a paired DatasetDict from disk and return its (train, val) splits.

    Raises:
        ValueError: if the object stored at ``path`` is not a DatasetDict.
    """
    ds = load_from_disk(path)
    if isinstance(ds, DatasetDict):
        return ds["train"], ds["val"]
    raise ValueError(f"Expected DatasetDict at {path}")
|
|
| |
def eval_regression(y_true: np.ndarray, y_pred: np.ndarray) -> dict:
    """Compute validation regression metrics as plain floats.

    Returns a dict with Spearman rho, RMSE, MAE and R^2.
    """
    # sklearn >= 1.4 ships root_mean_squared_error; on older versions
    # (or any import/runtime failure) fall back to sqrt(MSE).
    try:
        from sklearn.metrics import root_mean_squared_error
        rmse = float(root_mean_squared_error(y_true, y_pred))
    except Exception:
        rmse = float(np.sqrt(mean_squared_error(y_true, y_pred)))

    metrics = {"spearman_rho": safe_spearmanr(y_true, y_pred)}
    metrics["rmse"] = rmse
    metrics["mae"] = float(mean_absolute_error(y_true, y_pred))
    metrics["r2"] = float(r2_score(y_true, y_pred))
    return metrics
|
|
@torch.no_grad()
def predict_all_pooled(model, loader):
    """Collect (y_true, y_pred) over a pooled-embedding loader.

    Runs the model in eval mode on DEVICE and returns two concatenated
    numpy arrays (ground truth, predictions).
    """
    model.eval()
    truths = []
    preds = []
    for tgt, bnd, labels in loader:
        tgt = tgt.to(DEVICE, non_blocking=True)
        bnd = bnd.to(DEVICE, non_blocking=True)
        out, _ = model(tgt, bnd)
        truths.append(labels.numpy())
        preds.append(out.detach().cpu().numpy())
    return np.concatenate(truths), np.concatenate(preds)
|
|
|
|
@torch.no_grad()
def predict_all_unpooled(model, loader):
    """Collect (y_true, y_pred) over an unpooled (per-token) loader.

    Each batch carries target/binder token embeddings plus their padding
    masks; everything is moved to DEVICE before the forward pass.
    """
    model.eval()
    truths = []
    preds = []
    for tgt, tgt_mask, bnd, bnd_mask, labels in loader:
        tgt = tgt.to(DEVICE, non_blocking=True)
        tgt_mask = tgt_mask.to(DEVICE, non_blocking=True)
        bnd = bnd.to(DEVICE, non_blocking=True)
        bnd_mask = bnd_mask.to(DEVICE, non_blocking=True)
        out, _ = model(tgt, tgt_mask, bnd, bnd_mask)
        truths.append(labels.numpy())
        preds.append(out.detach().cpu().numpy())
    return np.concatenate(truths), np.concatenate(preds)
|
|
|
|
def build_model(mode: str, params: dict, train_ds) -> nn.Module:
    """Construct a cross-attention model sized from hyperparameters and data.

    Args:
        mode: "pooled" selects CrossAttnPooled; any other value selects the
            unpooled (per-token) CrossAttnUnpooled variant.
        params: hyperparameter dict with "hidden_dim", "n_heads",
            "n_layers" and "dropout".
        train_ds: training dataset, used only to infer embedding widths
            from its first example.

    Returns:
        The freshly constructed model, moved to DEVICE.
    """
    hidden = int(params["hidden_dim"])
    n_heads = int(params["n_heads"])
    n_layers = int(params["n_layers"])
    dropout = float(params["dropout"])

    # Some datasets store the binder under "embedding" instead of
    # "binder_embedding"; prefer the latter when present.
    binder_key = ("binder_embedding"
                  if "binder_embedding" in train_ds.column_names
                  else "embedding")

    if mode == "pooled":
        # Pooled embeddings are flat vectors; their length is the input width.
        Ht = len(train_ds[0]["target_embedding"])
        Hb = len(train_ds[0][binder_key])
        return CrossAttnPooled(Ht, Hb, hidden=hidden, n_heads=n_heads,
                               n_layers=n_layers, dropout=dropout).to(DEVICE)
    else:
        # Unpooled embeddings are (tokens, dim); use the per-token width.
        # Bug fix: use binder_key here as well — the unpooled branch
        # previously hard-coded "binder_embedding" and raised a KeyError on
        # datasets that only carry an "embedding" column.
        Ht = len(train_ds[0]["target_embedding"][0])
        Hb = len(train_ds[0][binder_key][0])
        return CrossAttnUnpooled(Ht, Hb, hidden=hidden, n_heads=n_heads,
                                 n_layers=n_layers, dropout=dropout).to(DEVICE)
|
|
|
|
| |
def refit_with_seed(dataset_path: str, base_out_dir: str, mode: str,
                    seed: int, patience: int = 20) -> dict:
    """Retrain the Optuna-selected configuration from scratch under one seed.

    Reads ``best_model.pt`` from ``base_out_dir`` only for its stored
    ``best_params`` (the weights are NOT reused), builds a fresh model,
    trains with early stopping on validation Spearman rho, then writes
    per-example predictions, a model checkpoint, and a metrics summary
    under ``base_out_dir/seed_<seed>/``.

    Args:
        dataset_path: path to a paired DatasetDict with "train"/"val" splits.
        base_out_dir: directory containing best_model.pt from the Optuna run.
        mode: "pooled" or anything else (treated as unpooled); selects the
            collate/eval/train/predict function family.
        seed: RNG seed passed to lightning's seed_everything.
        patience: epochs without validation improvement before stopping.

    Returns:
        dict with mode, seed, and the regression metrics rounded to 6 dp.

    Raises:
        FileNotFoundError: if best_model.pt does not exist yet.
    """
    model_path = os.path.join(base_out_dir, "best_model.pt")
    if not os.path.exists(model_path):
        raise FileNotFoundError(
            f"No best_model.pt found at {model_path}. Run Optuna (binding_training.py) first."
        )

    # Only the tuned hyperparameters are taken from the checkpoint.
    checkpoint = torch.load(model_path, map_location="cpu")
    best_params = checkpoint["best_params"]
    print(f"Loaded best_params from {model_path}")
    print(json.dumps(best_params, indent=2))

    seed_everything(seed)
    out_dir = os.path.join(base_out_dir, f"seed_{seed}")
    os.makedirs(out_dir, exist_ok=True)

    train_ds, val_ds = load_split_paired(dataset_path)
    print(f"[Data] Train={len(train_ds)} Val={len(val_ds)} mode={mode}")

    batch = int(best_params["batch_size"])
    cls_w = float(best_params["cls_weight"])

    # Pick the pooled vs. unpooled variants of the helper family.
    if mode == "pooled":
        collate = collate_pair_pooled
        eval_fn = eval_spearman_pooled
        train_fn = train_one_epoch_pooled
        predict = predict_all_pooled
    else:
        collate = collate_pair_unpooled
        eval_fn = eval_spearman_unpooled
        train_fn = train_one_epoch_unpooled
        predict = predict_all_unpooled

    train_loader = DataLoader(train_ds, batch_size=batch, shuffle=True,
                              num_workers=4, pin_memory=True, collate_fn=collate)
    val_loader = DataLoader(val_ds, batch_size=batch, shuffle=False,
                            num_workers=4, pin_memory=True, collate_fn=collate)

    model = build_model(mode, best_params, train_ds)
    opt = torch.optim.AdamW(model.parameters(),
                            lr=float(best_params["lr"]),
                            weight_decay=float(best_params["weight_decay"]))
    # Joint objective: MSE regression plus an auxiliary classification loss;
    # the cls_w mixing is applied inside train_fn.
    loss_reg = nn.MSELoss()
    loss_cls = nn.CrossEntropyLoss()

    best_rho, bad, best_state = -1e9, 0, None

    # Up to 200 epochs with patience-based early stopping on validation rho.
    for epoch in range(1, 201):
        train_fn(model, train_loader, opt, loss_reg, loss_cls, cls_w=cls_w)
        rho = eval_fn(model, val_loader)

        # 1e-6 tolerance avoids resetting patience on float noise.
        if rho > best_rho + 1e-6:
            best_rho = rho
            bad = 0
            # Snapshot the best weights on CPU so later (worse) epochs
            # cannot overwrite them.
            best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
        else:
            bad += 1
            if bad >= patience:
                print(f" Early stopping at epoch {epoch} (best rho={best_rho:.4f})")
                break

    # Restore the best checkpoint before final evaluation/export.
    if best_state:
        model.load_state_dict(best_state)

    y_true, y_pred = predict(model, val_loader)
    metrics = eval_regression(y_true, y_pred)

    # Persist per-example validation predictions, prepending identifying
    # columns when the dataset provides them, for downstream error analysis.
    df_val = pd.DataFrame({
        "y_true": y_true.astype(float),
        "y_pred": y_pred.astype(float),
        "residual": (y_true - y_pred).astype(float),
        "abs_error": np.abs(y_true - y_pred).astype(float),
    })
    for col in ("target_sequence", "sequence", "affinity_class"):
        if col in val_ds.column_names:
            df_val.insert(0, col, np.asarray(val_ds[col]))
    df_val.to_csv(os.path.join(out_dir, "val_predictions.csv"), index=False)

    torch.save({"state_dict": model.state_dict(),
                "best_params": best_params,
                "mode": mode,
                "seed": seed},
               os.path.join(out_dir, "model.pt"))

    summary = {"mode": mode, "seed": seed,
               **{k: round(v, 6) for k, v in metrics.items()}}
    with open(os.path.join(out_dir, "metrics.json"), "w") as f:
        json.dump(summary, f, indent=2)

    print(f"\n[Seed {seed}] rho={metrics['spearman_rho']:.4f} "
          f"RMSE={metrics['rmse']:.4f} R2={metrics['r2']:.4f}")
    return summary
|
|
|
|
| |
|
|
def aggregate_seed_results(base_out_dir: str, seeds: list) -> pd.DataFrame:
    """Aggregate per-seed metrics.json files into mean ± 95% CI summaries.

    Args:
        base_out_dir: directory containing seed_<N>/metrics.json subfolders.
        seeds: seed values to collect; missing ones are warned about and
            skipped.

    Returns:
        DataFrame with one row per metric (mean, std, ci_95, report,
        n_seeds), also written to base_out_dir/seed_aggregated_metrics.csv.

    Raises:
        ValueError: if no metrics.json exists for any requested seed.
    """
    records = []
    for seed in seeds:
        p = os.path.join(base_out_dir, f"seed_{seed}", "metrics.json")
        if os.path.exists(p):
            # Fix: use a context manager so the file handle is closed
            # deterministically (previously json.load(open(p)) leaked it).
            with open(p) as f:
                records.append(json.load(f))
        else:
            print(f"[WARN] Missing seed {seed} at {p}")

    if not records:
        raise ValueError("No seed results found — did the refit jobs complete?")

    df = pd.DataFrame(records)
    print("\nPer-seed results:")
    print(df.to_string(index=False))

    summary_rows = []
    for metric in ["spearman_rho", "rmse", "mae", "r2"]:
        vals = df[metric].values
        n = len(vals)
        mean = vals.mean()
        # Sample std (ddof=1) and a t-distribution 95% CI half-width.
        # NOTE: with a single seed (n == 1) both come out NaN by construction.
        std = vals.std(ddof=1)
        se = std / np.sqrt(n)
        t_crit = scipy_stats.t.ppf(0.975, df=n - 1)
        ci = t_crit * se
        row = {
            "metric": metric,
            "mean": round(mean, 4),
            "std": round(std, 4),
            "ci_95": round(ci, 4),
            "report": f"{mean:.4f} ± {ci:.4f}",
            "n_seeds": n,
        }
        # rho is bounded in [-1, 1]; a symmetric CI near the boundary is
        # distorted, so flag it for a Fisher z-transform instead.
        if metric == "spearman_rho" and (mean + ci > 0.95 or mean - ci < -0.95):
            row["note"] = "rho near boundary — consider Fisher z-transform CI"
        summary_rows.append(row)

    summary_df = pd.DataFrame(summary_rows)
    out_path = os.path.join(base_out_dir, "seed_aggregated_metrics.csv")
    summary_df.to_csv(out_path, index=False)

    print("\n=== Aggregated Metrics (95% CI, t-distribution) ===")
    for _, row in summary_df.iterrows():
        note = f" ← {row['note']}" if "note" in row and pd.notna(row.get("note")) else ""
        print(f" {row['metric']:15s}: {row['report']}{note}")
    print(f"\nSaved → {out_path}")
    return summary_df
|
|
|
|
if __name__ == "__main__":
    # CLI: either refit one seed (default) or aggregate finished seed runs.
    ap = argparse.ArgumentParser()
    ap.add_argument("--dataset_path", type=str, required=True,
                    help="Paired DatasetDict path")
    ap.add_argument("--base_out_dir", type=str, required=True,
                    help="Directory containing best_model.pt from the Optuna run")
    ap.add_argument("--mode", type=str, required=True)
    ap.add_argument("--seed", type=int, required=True)
    ap.add_argument("--patience", type=int, default=20)
    ap.add_argument("--aggregate", action="store_true",
                    help="Aggregate across seed runs instead of training")
    ap.add_argument("--all_seeds", type=int, nargs="+",
                    default=[1986, 42, 0, 123, 12345])
    cli = ap.parse_args()

    if cli.aggregate:
        aggregate_seed_results(cli.base_out_dir, cli.all_seeds)
    else:
        refit_with_seed(dataset_path=cli.dataset_path,
                        base_out_dir=cli.base_out_dir,
                        mode=cli.mode,
                        seed=cli.seed,
                        patience=cli.patience)