"""Refit a tuned regression head with a fixed seed and aggregate seed runs.

The script has two modes:

* refit (default): load ``best_params`` from ``<base_out_dir>/best_model.pt``,
  retrain the chosen model with the given seed, and write predictions, weights
  and metrics under ``<base_out_dir>/seed_<seed>/``.
* ``--aggregate``: read every ``seed_<seed>/metrics.json`` and write mean, std
  and 95% t-distribution confidence intervals to
  ``<base_out_dir>/seed_aggregated_metrics.csv``.
"""
import argparse
import json
import os
from typing import Dict, Optional

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from datasets import DatasetDict, load_from_disk
from lightning.pytorch import seed_everything
from scipy.stats import spearmanr
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
from torch.cuda.amp import GradScaler, autocast
from torch.utils.data import DataLoader

# Module-level AMP gradient scaler shared by all training steps; it is created
# disabled when CUDA is not available.
scaler_amp = GradScaler(enabled=torch.cuda.is_available())


def load_split(dataset_path):
    """Load a saved ``DatasetDict`` and return its ``(train, val)`` splits.

    Raises:
        ValueError: if the object on disk is not a ``DatasetDict`` or does not
            contain both a ``"train"`` and a ``"val"`` split.
    """
    ds = load_from_disk(dataset_path)
    if isinstance(ds, DatasetDict) and "train" in ds and "val" in ds:
        return ds["train"], ds["val"]
    raise ValueError("Expected DatasetDict with 'train' and 'val' splits")


def infer_in_dim(ds) -> int:
    """Return the feature width, read from the first row's first embedding vector."""
    return int(len(ds[0]["embedding"][0]))


def collate_unpooled_reg(batch):
    """Pad per-position embeddings to a common length and build a validity mask.

    Each item must provide ``"length"``, ``"label"`` and ``"embedding"`` (a
    sequence of per-position vectors); ``"attention_mask"`` is optional.

    Returns:
        ``(X, M, y)`` where ``X`` is a float32 tensor of shape
        ``(batch, Lmax, H)``, ``M`` is a bool tensor of shape ``(batch, Lmax)``
        that is True on valid positions, and ``y`` is a float32 tensor of
        labels. ``Lmax`` is the largest ``"length"`` value in the batch.
    """
    lengths = [int(x["length"]) for x in batch]
    Lmax = max(lengths)
    H = len(batch[0]["embedding"][0])
    X = torch.zeros(len(batch), Lmax, H, dtype=torch.float32)
    M = torch.zeros(len(batch), Lmax, dtype=torch.bool)
    y = torch.tensor([float(x["label"]) for x in batch], dtype=torch.float32)
    for i, x in enumerate(batch):
        emb = torch.tensor(x["embedding"], dtype=torch.float32)
        L = emb.shape[0]
        X[i, :L] = emb
        if "attention_mask" in x:
            # Honour the stored mask, truncated to the embedding length.
            m = torch.tensor(x["attention_mask"], dtype=torch.bool)
            M[i, :L] = m[:L]
        else:
            # No stored mask: every embedded position counts as valid.
            M[i, :L] = True
    return X, M, y


# ======================== Models =========================================
class MaskedMeanPool(nn.Module):
    """Average ``X`` over dim 1, counting only positions where ``M`` is True."""

    def forward(self, X, M):
        Mf = M.unsqueeze(-1).float()
        # clamp avoids a division by zero for rows with no valid position.
        return (X * Mf).sum(dim=1) / Mf.sum(dim=1).clamp(min=1.0)


class MLPRegressor(nn.Module):
    """Masked mean pooling followed by a one-hidden-layer MLP with a scalar output."""

    def __init__(self, in_dim, hidden=512, dropout=0.1):
        super().__init__()
        self.pool = MaskedMeanPool()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, 1),
        )

    def forward(self, X, M):
        return self.net(self.pool(X, M)).squeeze(-1)


class CNNRegressor(nn.Module):
    """Stack of Conv1d/GELU/Dropout blocks, masked mean pooling, linear head."""

    def __init__(self, in_ch, c=256, k=5, layers=2, dropout=0.1):
        super().__init__()
        blocks, ch = [], in_ch
        for _ in range(layers):
            blocks += [
                nn.Conv1d(ch, c, kernel_size=k, padding=k // 2),
                nn.GELU(),
                nn.Dropout(dropout),
            ]
            ch = c
        self.conv = nn.Sequential(*blocks)
        self.head = nn.Linear(c, 1)

    def forward(self, X, M):
        Y = self.conv(X.transpose(1, 2)).transpose(1, 2)
        # With an even kernel size, padding=k//2 makes each conv layer emit one
        # extra position; trim back to the input length so the mask lines up.
        # This is a no-op for odd kernel sizes.
        Y = Y[:, : X.shape[1]]
        Mf = M.unsqueeze(-1).float()
        pooled = (Y * Mf).sum(dim=1) / Mf.sum(dim=1).clamp(min=1.0)
        return self.head(pooled).squeeze(-1)


class TransformerRegressor(nn.Module):
    """Linear projection, Transformer encoder, masked mean pooling, linear head."""

    def __init__(self, in_dim, d_model=256, nhead=8, layers=2, ff=512, dropout=0.1):
        super().__init__()
        self.proj = nn.Linear(in_dim, d_model)
        self.enc = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=d_model,
                nhead=nhead,
                dim_feedforward=ff,
                dropout=dropout,
                batch_first=True,
                activation="gelu",
            ),
            num_layers=layers,
        )
        self.head = nn.Linear(d_model, 1)

    def forward(self, X, M):
        # src_key_padding_mask is True on positions to ignore, hence ~M.
        Z = self.enc(self.proj(X), src_key_padding_mask=~M)
        Mf = M.unsqueeze(-1).float()
        pooled = (Z * Mf).sum(dim=1) / Mf.sum(dim=1).clamp(min=1.0)
        return self.head(pooled).squeeze(-1)


# ======================== utils =========================================
def safe_spearmanr(y_true, y_pred):
    """Return Spearman's rho as a float, or 0.0 when it is undefined (None/NaN)."""
    rho = spearmanr(y_true, y_pred).correlation
    return 0.0 if (rho is None or np.isnan(rho)) else float(rho)


def eval_regression(y_true, y_pred) -> Dict[str, float]:
    """Compute Spearman rho, RMSE, MAE and R2 for a set of predictions."""
    try:
        # Only present in newer scikit-learn releases.
        from sklearn.metrics import root_mean_squared_error
    except ImportError:
        rmse = float(np.sqrt(mean_squared_error(y_true, y_pred)))
    else:
        rmse = float(root_mean_squared_error(y_true, y_pred))
    return {
        "spearman_rho": safe_spearmanr(y_true, y_pred),
        "rmse": rmse,
        "mae": float(mean_absolute_error(y_true, y_pred)),
        "r2": float(r2_score(y_true, y_pred)),
    }


def score_from_metrics(metrics, objective):
    """Map a metrics dict to a higher-is-better score for ``objective``.

    ``objective`` must be one of ``"spearman"``, ``"neg_rmse"`` or ``"r2"``;
    any other value raises ``KeyError``.
    """
    return {
        "spearman": metrics["spearman_rho"],
        "neg_rmse": -metrics["rmse"],
        "r2": metrics["r2"],
    }[objective]


@torch.no_grad()
def eval_preds(model, loader, device):
    """Run ``model`` in eval mode over ``loader``.

    Returns:
        ``(y_true, y_pred)`` as concatenated NumPy arrays.
    """
    model.eval()
    ys, ps = [], []
    for X, M, y in loader:
        X, M = X.to(device), M.to(device)
        ps.append(model(X, M).cpu().numpy())
        ys.append(y.numpy())
    return np.concatenate(ys), np.concatenate(ps)


def train_one_epoch(model, loader, optim, criterion, device):
    """Run one optimisation pass over ``loader``.

    Mixed precision and the shared ``scaler_amp`` are used only when the
    scaler is enabled and ``device`` is a CUDA device; otherwise a plain
    backward/step is performed.
    """
    use_amp = scaler_amp.is_enabled() and torch.device(device).type == "cuda"
    model.train()
    for X, M, y in loader:
        X, M, y = X.to(device), M.to(device), y.to(device)
        optim.zero_grad(set_to_none=True)
        with autocast(enabled=use_amp):
            loss = criterion(model(X, M), y)
        if use_amp:
            scaler_amp.scale(loss).backward()
            scaler_amp.step(optim)
            scaler_amp.update()
        else:
            loss.backward()
            optim.step()


def build_model(model_name, in_dim, params):
    """Instantiate ``"mlp"``, ``"cnn"`` or ``"transformer"`` from a params dict.

    ``params["dropout"]`` defaults to 0.1; the remaining keys read depend on
    the model type and must be present.

    Raises:
        ValueError: if ``model_name`` is not one of the three known names.
    """
    dropout = float(params.get("dropout", 0.1))
    if model_name == "mlp":
        return MLPRegressor(in_dim=in_dim, hidden=int(params["hidden"]), dropout=dropout)
    elif model_name == "cnn":
        return CNNRegressor(
            in_ch=in_dim,
            c=int(params["channels"]),
            k=int(params["kernel"]),
            layers=int(params["layers"]),
            dropout=dropout,
        )
    elif model_name == "transformer":
        return TransformerRegressor(
            in_dim=in_dim,
            d_model=int(params["d_model"]),
            nhead=int(params["nhead"]),
            layers=int(params["layers"]),
            ff=int(params["ff"]),
            dropout=dropout,
        )
    raise ValueError(model_name)


def _resolve_device(device) -> torch.device:
    """Return ``device`` as a ``torch.device``, using CPU if CUDA is unavailable."""
    resolved = torch.device(device)
    if resolved.type == "cuda" and not torch.cuda.is_available():
        print(f"[Device] CUDA requested ({device}) but not available; using CPU")
        return torch.device("cpu")
    return resolved


# ======================== Refit Loop =========================================
def refit_with_seed(dataset_path, base_out_dir, model_name, seed,
                    objective="spearman", device="cuda:0"):
    """Retrain ``model_name`` with the stored best hyper-parameters and one seed.

    Reads ``best_params`` from ``<base_out_dir>/best_model.pt``, trains for up
    to 200 epochs with early stopping (patience 15) on the validation score
    selected by ``objective``, restores the best weights and writes
    ``val_predictions.csv``, ``model.pt`` and ``metrics.json`` to
    ``<base_out_dir>/seed_<seed>/``.

    Args:
        dataset_path: directory of a saved ``DatasetDict`` with train/val splits.
        base_out_dir: directory containing ``best_model.pt``; outputs go below it.
        model_name: ``"mlp"``, ``"cnn"`` or ``"transformer"``.
        seed: random seed passed to ``seed_everything``.
        objective: ``"spearman"``, ``"neg_rmse"`` or ``"r2"``.
        device: target device; a CUDA device falls back to CPU when CUDA is
            not available.

    Returns:
        The summary dict that is also written to ``metrics.json``.

    Raises:
        FileNotFoundError: if ``best_model.pt`` does not exist.
    """
    model_path = os.path.join(base_out_dir, "best_model.pt")
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"No best_model.pt at {model_path}. Run Optuna first.")

    # NOTE(review): torch.load unpickles the file; only load checkpoints from a
    # trusted source.
    checkpoint = torch.load(model_path, map_location="cpu")
    best_params = checkpoint["best_params"]
    print(f"Loaded best_params from {model_path}")
    print(json.dumps(best_params, indent=2))

    seed_everything(seed)
    device = _resolve_device(device)
    out_dir = os.path.join(base_out_dir, f"seed_{seed}")
    os.makedirs(out_dir, exist_ok=True)

    train_ds, val_ds = load_split(dataset_path)
    print(f"[Data] Train: {len(train_ds)}, Val: {len(val_ds)}")

    batch_size = int(best_params.get("batch_size", 32))
    pin_memory = device.type == "cuda"  # pinning only helps host-to-GPU copies
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,
                              collate_fn=collate_unpooled_reg, num_workers=4,
                              pin_memory=pin_memory)
    val_loader = DataLoader(val_ds, batch_size=64, shuffle=False,
                            collate_fn=collate_unpooled_reg, num_workers=4,
                            pin_memory=pin_memory)

    in_dim = infer_in_dim(train_ds)
    model = build_model(model_name, in_dim, best_params).to(device)

    # Loss
    loss_name = best_params.get("loss", "mse")
    if loss_name == "mse":
        criterion = nn.MSELoss()
    else:
        criterion = nn.HuberLoss(delta=float(best_params.get("huber_delta", 1.0)))

    optim = torch.optim.AdamW(model.parameters(),
                              lr=float(best_params["lr"]),
                              weight_decay=float(best_params["weight_decay"]))

    best_score, bad, patience = -1e18, 0, 15
    best_state: Optional[Dict[str, torch.Tensor]] = None

    for epoch in range(1, 201):
        train_one_epoch(model, train_loader, optim, criterion, device)
        y_true, y_pred = eval_preds(model, val_loader, device)
        metrics = eval_regression(y_true, y_pred)
        score = score_from_metrics(metrics, objective)

        if score > best_score + 1e-6:
            best_score = score
            bad = 0
            # Snapshot on CPU so later epochs cannot overwrite the best weights.
            best_state = {k: v.detach().cpu().clone()
                          for k, v in model.state_dict().items()}
        else:
            bad += 1
            if bad >= patience:
                print(f"Early stopping at epoch {epoch}")
                break

    if best_state is not None:
        model.load_state_dict(best_state)

    y_true_val, y_pred_val = eval_preds(model, val_loader, device)
    final_metrics = eval_regression(y_true_val, y_pred_val)

    df_val = pd.DataFrame({
        "y_true": y_true_val.astype(float),
        "y_pred": y_pred_val.astype(float),
        "residual": (y_true_val - y_pred_val).astype(float),
        "abs_error": np.abs(y_true_val - y_pred_val).astype(float),
    })
    if "sequence" in val_ds.column_names:
        df_val.insert(0, "sequence", np.asarray(val_ds["sequence"]))
    df_val.to_csv(os.path.join(out_dir, "val_predictions.csv"), index=False)

    torch.save({"state_dict": model.state_dict(), "best_params": best_params, "seed": seed},
               os.path.join(out_dir, "model.pt"))

    summary = {"model": model_name, "seed": seed,
               **{k: round(v, 6) for k, v in final_metrics.items()}}
    with open(os.path.join(out_dir, "metrics.json"), "w") as f:
        json.dump(summary, f, indent=2)

    print(f"\n[Seed {seed}] rho={final_metrics['spearman_rho']:.4f} "
          f"RMSE={final_metrics['rmse']:.4f} R2={final_metrics['r2']:.4f}")
    return summary


# ======================== CI aggregation =========================================
def aggregate_seed_results(base_out_dir, seeds):
    """Aggregate per-seed ``metrics.json`` files into mean, std and 95% CI.

    For Spearman rho, RMSE, MAE and R2 the half-width of a two-sided 95%
    t-distribution confidence interval is computed across seeds. No Fisher
    z-transform is applied; a row is only annotated with a note when the
    t-interval for rho reaches beyond +/-0.95. With a single seed the std and
    CI are reported as NaN.

    Seeds whose ``metrics.json`` is missing are skipped with a warning.

    Returns:
        The summary ``DataFrame``, which is also written to
        ``<base_out_dir>/seed_aggregated_metrics.csv``.

    Raises:
        ValueError: if no seed result file is found.
    """
    from scipy import stats

    records = []
    for seed in seeds:
        p = os.path.join(base_out_dir, f"seed_{seed}", "metrics.json")
        if os.path.exists(p):
            with open(p) as f:
                records.append(json.load(f))
        else:
            print(f"Warning: missing seed {seed}")

    if not records:
        raise ValueError("No seed results found.")

    df = pd.DataFrame(records)
    print("\nPer-seed results:")
    print(df.to_string(index=False))

    summary_rows = []
    for metric in ["spearman_rho", "rmse", "mae", "r2"]:
        vals = df[metric].values
        n = len(vals)
        mean = vals.mean()
        if n > 1:
            std = vals.std(ddof=1)
            se = std / np.sqrt(n)
            t_crit = stats.t.ppf(0.975, df=n - 1)
            ci = t_crit * se
        else:
            # Sample std and the t critical value are undefined for one value.
            std = ci = float("nan")

        row = {
            "metric": metric,
            "mean": round(mean, 4),
            "std": round(std, 4),
            "ci_95": round(ci, 4),
            "report": f"{mean:.4f} ± {ci:.4f}",
            "n_seeds": n,
        }
        if n < 2:
            row["note"] = "single seed — std and CI undefined"
        # Flag if rho is high enough that the t-CI boundary might exceed 1.0
        elif metric == "spearman_rho" and (mean + ci > 0.95 or mean - ci < -0.95):
            row["note"] = "rho near boundary — consider Fisher z-transform CI"
        summary_rows.append(row)

    summary_df = pd.DataFrame(summary_rows)
    out_path = os.path.join(base_out_dir, "seed_aggregated_metrics.csv")
    summary_df.to_csv(out_path, index=False)

    print("\n=== Aggregated Metrics (95% CI, t-distribution) ===")
    for _, row in summary_df.iterrows():
        note = f"  ← {row['note']}" if "note" in row and pd.notna(row.get("note")) else ""
        print(f"  {row['metric']:15s}: {row['report']}{note}")
    print(f"\nSaved to {out_path}")
    return summary_df


def main():
    """Parse command-line arguments and run either the refit or the aggregation."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset_path", type=str, default=None)
    parser.add_argument("--base_out_dir", type=str, required=True)
    parser.add_argument("--model", type=str, choices=["mlp", "cnn", "transformer"],
                        default=None)
    parser.add_argument("--seed", type=int, default=None)
    parser.add_argument("--objective", type=str, default="spearman",
                        choices=["spearman", "neg_rmse", "r2"])
    parser.add_argument("--device", type=str, default="cuda:0")
    parser.add_argument("--aggregate", action="store_true")
    parser.add_argument("--all_seeds", type=int, nargs="+",
                        default=[1986, 42, 0, 123, 12345])
    args = parser.parse_args()

    if args.aggregate:
        aggregate_seed_results(args.base_out_dir, args.all_seeds)
        return

    # These are only needed for a refit run, so they are validated here rather
    # than being marked required for aggregation as well.
    required = (
        ("--dataset_path", args.dataset_path),
        ("--model", args.model),
        ("--seed", args.seed),
    )
    missing = [flag for flag, value in required if value is None]
    if missing:
        parser.error("the following arguments are required: " + ", ".join(missing))

    refit_with_seed(
        dataset_path=args.dataset_path,
        base_out_dir=args.base_out_dir,
        model_name=args.model,
        seed=args.seed,
        objective=args.objective,
        device=args.device,
    )


if __name__ == "__main__":
    main()