"""Refit the best Optuna configuration under a fixed seed, and aggregate seeds.

Two modes share this script:

* Training (default): read ``best_params`` from ``<base_out_dir>/best_model.pt``,
  retrain with ``--seed`` using validation Spearman rho for early stopping, and
  write ``val_predictions.csv``, ``model.pt`` and ``metrics.json`` to
  ``<base_out_dir>/seed_<seed>/``.
* ``--aggregate``: read ``seed_<seed>/metrics.json`` for every seed in
  ``--all_seeds`` and report mean ± 95% t-interval for each metric.
"""

import os
import json
import argparse
import sys

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from datasets import load_from_disk, DatasetDict
from scipy.stats import spearmanr
from scipy import stats as scipy_stats
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
from lightning.pytorch import seed_everything

# Make the sibling training module importable regardless of the caller's CWD.
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from binding_training import (  # noqa: E402  (depends on the sys.path tweak above)
    CrossAttnPooled,
    CrossAttnUnpooled,
    collate_pair_pooled,
    collate_pair_unpooled,
    eval_spearman_pooled,
    eval_spearman_unpooled,
    train_one_epoch_pooled,
    train_one_epoch_unpooled,
    affinity_to_class_tensor,
    safe_spearmanr,
)

DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


def load_split_paired(path: str):
    """Load a saved ``DatasetDict`` and return its ``(train, val)`` splits.

    Raises:
        ValueError: if the object on disk is not a ``DatasetDict``.
    """
    dd = load_from_disk(path)
    if not isinstance(dd, DatasetDict):
        raise ValueError(f"Expected DatasetDict at {path}")
    return dd["train"], dd["val"]


def eval_regression(y_true: np.ndarray, y_pred: np.ndarray) -> dict:
    """Return Spearman rho, RMSE, MAE and R^2 for a set of predictions."""
    try:
        # root_mean_squared_error only exists in newer scikit-learn releases.
        from sklearn.metrics import root_mean_squared_error
    except ImportError:
        rmse = float(np.sqrt(mean_squared_error(y_true, y_pred)))
    else:
        rmse = float(root_mean_squared_error(y_true, y_pred))
    return {
        "spearman_rho": safe_spearmanr(y_true, y_pred),
        "rmse": rmse,
        "mae": float(mean_absolute_error(y_true, y_pred)),
        "r2": float(r2_score(y_true, y_pred)),
    }


@torch.no_grad()
def predict_all_pooled(model, loader):
    """Run a pooled model over ``loader``; return ``(y_true, y_pred)`` arrays.

    Each batch is unpacked as ``(target, binder, y)``; the model's second
    output is ignored.
    """
    model.eval()
    ys, ps = [], []
    for t, b, y in loader:
        t = t.to(DEVICE, non_blocking=True)
        b = b.to(DEVICE, non_blocking=True)
        pred, _ = model(t, b)
        ys.append(y.numpy())
        ps.append(pred.detach().cpu().numpy())
    return np.concatenate(ys), np.concatenate(ps)


@torch.no_grad()
def predict_all_unpooled(model, loader):
    """Run an unpooled model over ``loader``; return ``(y_true, y_pred)`` arrays.

    Each batch is unpacked as ``(T, Mt, B, Mb, y)``. ``Mt``/``Mb`` are
    presumably padding masks for the target/binder tensors; confirm in
    ``collate_pair_unpooled``.
    """
    model.eval()
    ys, ps = [], []
    for T, Mt, B, Mb, y in loader:
        T = T.to(DEVICE, non_blocking=True)
        Mt = Mt.to(DEVICE, non_blocking=True)
        B = B.to(DEVICE, non_blocking=True)
        Mb = Mb.to(DEVICE, non_blocking=True)
        pred, _ = model(T, Mt, B, Mb)
        ys.append(y.numpy())
        ps.append(pred.detach().cpu().numpy())
    return np.concatenate(ys), np.concatenate(ps)


def build_model(mode: str, params: dict, train_ds) -> nn.Module:
    """Instantiate the cross-attention model for ``mode`` from tuned ``params``.

    Input widths are read from the first training example. For ``"pooled"``
    each embedding is a flat vector; for any other mode it is a sequence of
    vectors, so the width is taken from its first element.

    The binder column is ``binder_embedding`` when present, else
    ``embedding``. Both modes now resolve it the same way; previously the
    unpooled branch hard-coded ``binder_embedding``.
    """
    hidden = int(params["hidden_dim"])
    n_heads = int(params["n_heads"])
    n_layers = int(params["n_layers"])
    dropout = float(params["dropout"])
    binder_key = "binder_embedding" if "binder_embedding" in train_ds.column_names else "embedding"
    first = train_ds[0]

    if mode == "pooled":
        Ht = len(first["target_embedding"])
        Hb = len(first[binder_key])
        model_cls = CrossAttnPooled
    else:
        Ht = len(first["target_embedding"][0])
        Hb = len(first[binder_key][0])
        model_cls = CrossAttnUnpooled
    return model_cls(Ht, Hb, hidden=hidden, n_heads=n_heads, n_layers=n_layers, dropout=dropout).to(DEVICE)


def _fit_with_early_stopping(model, train_fn, eval_fn, train_loader, val_loader,
                             opt, loss_reg, loss_cls, cls_w, patience, max_epochs):
    """Train until validation rho stalls for ``patience`` epochs, then restore the best weights.

    Returns the best validation rho seen (``-1e9`` if no epoch improved on it).
    """
    best_rho, bad, best_state = -1e9, 0, None
    for epoch in range(1, max_epochs + 1):
        train_fn(model, train_loader, opt, loss_reg, loss_cls, cls_w=cls_w)
        rho = eval_fn(model, val_loader)
        if rho > best_rho + 1e-6:  # small tolerance so float noise doesn't reset patience
            best_rho = rho
            bad = 0
            # Keep the snapshot on CPU so it doesn't hold extra device memory.
            best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
        else:
            bad += 1
            if bad >= patience:
                print(f" Early stopping at epoch {epoch} (best rho={best_rho:.4f})")
                break
    if best_state is not None:
        model.load_state_dict(best_state)
    return best_rho


# Refit
def refit_with_seed(dataset_path: str, base_out_dir: str, mode: str, seed: int,
                    patience: int = 20, max_epochs: int = 200) -> dict:
    """Retrain the best Optuna configuration with ``seed`` and save outputs.

    Args:
        dataset_path: path to a saved ``DatasetDict`` with ``train``/``val`` splits.
        base_out_dir: directory holding ``best_model.pt`` from the Optuna run.
            Outputs go to ``<base_out_dir>/seed_<seed>/``.
        mode: ``"pooled"`` selects the pooled pipeline; any other value selects
            the unpooled one.
        seed: global RNG seed passed to ``seed_everything``.
        patience: epochs without improvement in validation rho before stopping.
        max_epochs: upper bound on training epochs (default 200, matching the
            previous hard-coded limit).

    Returns:
        dict with ``mode``, ``seed`` and the regression metrics rounded to
        6 decimals (also written to ``metrics.json``).

    Raises:
        FileNotFoundError: if ``best_model.pt`` is missing.
    """
    model_path = os.path.join(base_out_dir, "best_model.pt")
    if not os.path.exists(model_path):
        raise FileNotFoundError(
            f"No best_model.pt found at {model_path}. Run Optuna (binding_training.py) first."
        )
    # NOTE(review): torch>=2.6 defaults to weights_only=True, which only works if
    # best_params holds plain Python types; confirm against the Optuna writer.
    checkpoint = torch.load(model_path, map_location="cpu")
    best_params = checkpoint["best_params"]
    print(f"Loaded best_params from {model_path}")
    print(json.dumps(best_params, indent=2))

    seed_everything(seed)
    out_dir = os.path.join(base_out_dir, f"seed_{seed}")
    os.makedirs(out_dir, exist_ok=True)

    train_ds, val_ds = load_split_paired(dataset_path)
    print(f"[Data] Train={len(train_ds)} Val={len(val_ds)} mode={mode}")

    batch = int(best_params["batch_size"])
    cls_w = float(best_params["cls_weight"])
    if mode == "pooled":
        collate, eval_fn = collate_pair_pooled, eval_spearman_pooled
        train_fn, predict = train_one_epoch_pooled, predict_all_pooled
    else:
        collate, eval_fn = collate_pair_unpooled, eval_spearman_unpooled
        train_fn, predict = train_one_epoch_unpooled, predict_all_unpooled

    train_loader = DataLoader(train_ds, batch_size=batch, shuffle=True,
                              num_workers=4, pin_memory=True, collate_fn=collate)
    val_loader = DataLoader(val_ds, batch_size=batch, shuffle=False,
                            num_workers=4, pin_memory=True, collate_fn=collate)

    model = build_model(mode, best_params, train_ds)
    opt = torch.optim.AdamW(model.parameters(), lr=float(best_params["lr"]),
                            weight_decay=float(best_params["weight_decay"]))
    loss_reg = nn.MSELoss()
    loss_cls = nn.CrossEntropyLoss()

    _fit_with_early_stopping(model, train_fn, eval_fn, train_loader, val_loader,
                             opt, loss_reg, loss_cls, cls_w, patience, max_epochs)

    y_true, y_pred = predict(model, val_loader)
    metrics = eval_regression(y_true, y_pred)

    # Save predictions
    residual = y_true - y_pred
    df_val = pd.DataFrame({
        "y_true": y_true.astype(float),
        "y_pred": y_pred.astype(float),
        "residual": residual.astype(float),
        "abs_error": np.abs(residual).astype(float),
    })
    # Prepend identifying columns when the split carries them.
    for col in ("target_sequence", "sequence", "affinity_class"):
        if col in val_ds.column_names:
            df_val.insert(0, col, np.asarray(val_ds[col]))
    df_val.to_csv(os.path.join(out_dir, "val_predictions.csv"), index=False)

    torch.save({"state_dict": model.state_dict(), "best_params": best_params,
                "mode": mode, "seed": seed},
               os.path.join(out_dir, "model.pt"))

    summary = {"mode": mode, "seed": seed, **{k: round(v, 6) for k, v in metrics.items()}}
    with open(os.path.join(out_dir, "metrics.json"), "w") as f:
        json.dump(summary, f, indent=2)

    print(f"\n[Seed {seed}] rho={metrics['spearman_rho']:.4f} "
          f"RMSE={metrics['rmse']:.4f} R2={metrics['r2']:.4f}")
    return summary


# CI aggregation
def aggregate_seed_results(base_out_dir: str, seeds: list) -> pd.DataFrame:
    """Summarise per-seed ``metrics.json`` files as mean ± 95% t-interval.

    Missing seeds are reported and skipped. With a single available seed,
    ``std`` and ``ci_95`` are NaN (no sample variance exists). The summary is
    written to ``<base_out_dir>/seed_aggregated_metrics.csv`` and returned.

    Raises:
        ValueError: if none of the requested seeds has a ``metrics.json``.
    """
    records = []
    for seed in seeds:
        p = os.path.join(base_out_dir, f"seed_{seed}", "metrics.json")
        if os.path.exists(p):
            with open(p) as f:
                records.append(json.load(f))
        else:
            print(f"[WARN] Missing seed {seed} at {p}")

    if not records:
        raise ValueError("No seed results found — did the refit jobs complete?")

    df = pd.DataFrame(records)
    print("\nPer-seed results:")
    print(df.to_string(index=False))

    summary_rows = []
    for metric in ["spearman_rho", "rmse", "mae", "r2"]:
        vals = df[metric].to_numpy(dtype=float)
        n = len(vals)
        mean = float(vals.mean())
        if n > 1:
            std = float(vals.std(ddof=1))
            se = std / np.sqrt(n)
            ci = float(scipy_stats.t.ppf(0.975, df=n - 1) * se)
        else:
            # One seed gives no variance estimate; emit NaN explicitly instead of
            # letting numpy/scipy warn about zero degrees of freedom.
            std = ci = float("nan")
        row = {
            "metric": metric,
            "mean": round(mean, 4),
            "std": round(std, 4),
            "ci_95": round(ci, 4),
            "report": f"{mean:.4f} ± {ci:.4f}",
            "n_seeds": n,
        }
        # A symmetric t-interval can cross the [-1, 1] bounds of a correlation.
        if metric == "spearman_rho" and (mean + ci > 0.95 or mean - ci < -0.95):
            row["note"] = "rho near boundary — consider Fisher z-transform CI"
        summary_rows.append(row)

    summary_df = pd.DataFrame(summary_rows)
    out_path = os.path.join(base_out_dir, "seed_aggregated_metrics.csv")
    summary_df.to_csv(out_path, index=False)

    print("\n=== Aggregated Metrics (95% CI, t-distribution) ===")
    for _, row in summary_df.iterrows():
        # row.get returns None when no row produced a "note" column at all.
        note = f" ← {row['note']}" if pd.notna(row.get("note")) else ""
        print(f" {row['metric']:15s}: {row['report']}{note}")
    print(f"\nSaved → {out_path}")
    return summary_df


def _parse_args(argv=None):
    """Parse CLI arguments; the training-only flags are required unless ``--aggregate``."""
    parser = argparse.ArgumentParser(
        description="Refit the best Optuna config with a given seed, or aggregate seed runs."
    )
    parser.add_argument("--dataset_path", type=str, default=None,
                        help="Paired DatasetDict path (required unless --aggregate)")
    parser.add_argument("--base_out_dir", type=str, required=True,
                        help="Directory containing best_model.pt from the Optuna run")
    parser.add_argument("--mode", type=str, default=None,
                        help="'pooled' or unpooled (required unless --aggregate)")
    parser.add_argument("--seed", type=int, default=None,
                        help="Refit seed (required unless --aggregate)")
    parser.add_argument("--patience", type=int, default=20)
    parser.add_argument("--max_epochs", type=int, default=200)
    parser.add_argument("--aggregate", action="store_true",
                        help="Aggregate across seed runs instead of training")
    parser.add_argument("--all_seeds", type=int, nargs="+",
                        default=[1986, 42, 0, 123, 12345])
    args = parser.parse_args(argv)

    if not args.aggregate:
        missing = [f"--{name}" for name in ("dataset_path", "mode", "seed")
                   if getattr(args, name) is None]
        if missing:
            parser.error("the following arguments are required unless --aggregate is set: "
                         + ", ".join(missing))
    return args


if __name__ == "__main__":
    args = _parse_args()
    if args.aggregate:
        aggregate_seed_results(args.base_out_dir, args.all_seeds)
    else:
        refit_with_seed(
            dataset_path=args.dataset_path,
            base_out_dir=args.base_out_dir,
            mode=args.mode,
            seed=args.seed,
            patience=args.patience,
            max_epochs=args.max_epochs,
        )