"""Refit a tuned classifier under a fixed seed and aggregate per-seed metrics.

The script reads ``best_params`` from ``<base_out_dir>/best_model.pt``, retrains
the chosen model ("mlp", "cnn" or "transformer") on unpooled embeddings with the
given seed, and writes predictions, weights and metrics to
``<base_out_dir>/seed_<seed>/``. With ``--aggregate`` it instead combines the
per-seed ``metrics.json`` files into a mean ± 95% CI summary.
"""
import argparse
import json
import os
from typing import Optional

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from datasets import load_from_disk, DatasetDict
from sklearn.metrics import roc_auc_score, precision_recall_curve, f1_score
from lightning.pytorch import seed_everything


def infer_in_dim_from_unpooled_ds(ds) -> int:
    """Return the per-position embedding width, read from the first example of ``ds``."""
    ex = ds[0]
    return int(len(ex["embedding"][0]))


def load_split(dataset_path):
    """Load a dataset saved on disk and return its ``(train, val)`` splits.

    Raises:
        ValueError: if the loaded object is not a ``DatasetDict``.
    """
    ds = load_from_disk(dataset_path)
    if isinstance(ds, DatasetDict):
        return ds["train"], ds["val"]
    raise ValueError("Expected DatasetDict with 'train' and 'val' splits")


def collate_unpooled(batch):
    """Pad a list of examples into dense tensors.

    Each example must provide ``"embedding"`` (a sequence of per-position
    vectors), ``"length"`` and ``"label"``; ``"attention_mask"`` is optional.

    Returns:
        X: float32 tensor ``(batch, Lmax, H)`` of zero-padded embeddings.
        M: bool tensor ``(batch, Lmax)``, True at valid positions.
        y: float32 tensor ``(batch,)`` of labels.
    """
    lengths = [int(x["length"]) for x in batch]
    Lmax = max(lengths)
    H = len(batch[0]["embedding"][0])

    X = torch.zeros(len(batch), Lmax, H, dtype=torch.float32)
    M = torch.zeros(len(batch), Lmax, dtype=torch.bool)
    y = torch.tensor([x["label"] for x in batch], dtype=torch.float32)

    for i, x in enumerate(batch):
        emb = torch.tensor(x["embedding"], dtype=torch.float32)
        L = emb.shape[0]
        X[i, :L] = emb
        if "attention_mask" in x:
            # Use the stored mask, truncated to the embedding length.
            m = torch.tensor(x["attention_mask"], dtype=torch.bool)
            M[i, :L] = m[:L]
        else:
            # No mask stored: every embedded position counts as valid.
            M[i, :L] = True
    return X, M, y


# ======================== Models =========================================
class MaskedMeanPool(nn.Module):
    """Average over the sequence axis (dim 1), counting only positions where ``M`` is True."""

    def forward(self, X, M):
        Mf = M.unsqueeze(-1).float()
        # Clamp so a fully-masked row divides by 1 instead of 0.
        denom = Mf.sum(dim=1).clamp(min=1.0)
        return (X * Mf).sum(dim=1) / denom


class MLPClassifier(nn.Module):
    """Masked mean pooling followed by a two-layer MLP producing one logit per example."""

    def __init__(self, in_dim, hidden=512, dropout=0.1):
        super().__init__()
        self.pool = MaskedMeanPool()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, 1),
        )

    def forward(self, X, M):
        return self.net(self.pool(X, M)).squeeze(-1)


class CNNClassifier(nn.Module):
    """Stack of 1-D convolutions, masked mean pooling, and a linear head (one logit)."""

    def __init__(self, in_ch, c=256, k=5, layers=2, dropout=0.1):
        super().__init__()
        blocks, ch = [], in_ch
        for _ in range(layers):
            # padding=k//2 preserves the sequence length only for odd k; the
            # masked pooling in forward() relies on that length matching M.
            blocks += [
                nn.Conv1d(ch, c, kernel_size=k, padding=k // 2),
                nn.GELU(),
                nn.Dropout(dropout),
            ]
            ch = c
        self.conv = nn.Sequential(*blocks)
        self.head = nn.Linear(c, 1)

    def forward(self, X, M):
        # Conv1d wants channels on dim 1, so swap (length, hidden) around the conv stack.
        Y = self.conv(X.transpose(1, 2)).transpose(1, 2)
        Mf = M.unsqueeze(-1).float()
        pooled = (Y * Mf).sum(dim=1) / Mf.sum(dim=1).clamp(min=1.0)
        return self.head(pooled).squeeze(-1)


class TransformerClassifier(nn.Module):
    """Linear projection, Transformer encoder, masked mean pooling, linear head (one logit)."""

    def __init__(self, in_dim, d_model=256, nhead=8, layers=2, ff=512, dropout=0.1):
        super().__init__()
        self.proj = nn.Linear(in_dim, d_model)
        enc_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=ff,
            dropout=dropout,
            batch_first=True,
            activation="gelu",
        )
        self.enc = nn.TransformerEncoder(enc_layer, num_layers=layers)
        self.head = nn.Linear(d_model, 1)

    def forward(self, X, M):
        # M is True at valid positions; the encoder expects True at padding, hence ~M.
        Z = self.enc(self.proj(X), src_key_padding_mask=~M)
        Mf = M.unsqueeze(-1).float()
        pooled = (Z * Mf).sum(dim=1) / Mf.sum(dim=1).clamp(min=1.0)
        return self.head(pooled).squeeze(-1)


# ======================== Training utils =========================================
def best_f1_threshold(y_true, y_prob):
    """Return ``(threshold, f1)`` for the threshold that maximises F1 on the PR curve."""
    p, r, thr = precision_recall_curve(y_true, y_prob)
    # The last precision/recall point has no threshold, so drop it; the epsilon
    # avoids 0/0 where precision and recall are both zero.
    f1s = (2 * p[:-1] * r[:-1]) / (p[:-1] + r[:-1] + 1e-12)
    i = int(np.nanargmax(f1s))
    return float(thr[i]), float(f1s[i])


@torch.no_grad()
def eval_probs(model, loader, device):
    """Run ``model`` over ``loader`` in eval mode; return ``(labels, probabilities)`` arrays."""
    model.eval()
    ys, ps = [], []
    for X, M, y in loader:
        X, M = X.to(device), M.to(device)
        ps.append(torch.sigmoid(model(X, M)).cpu().numpy())
        ys.append(y.numpy())
    return np.concatenate(ys), np.concatenate(ps)


def train_one_epoch(model, loader, optim, criterion, device):
    """Run one optimisation pass over ``loader`` with ``model`` in train mode."""
    model.train()
    for X, M, y in loader:
        X, M, y = X.to(device), M.to(device), y.to(device)
        optim.zero_grad(set_to_none=True)
        criterion(model(X, M), y).backward()
        optim.step()


def build_model(model_name, in_dim, params):
    """Instantiate the classifier named ``model_name`` from a hyper-parameter dict.

    ``params`` must contain the keys the chosen architecture reads
    (``hidden``; ``channels``/``kernel``/``layers``; or
    ``d_model``/``nhead``/``layers``/``ff``). ``dropout`` defaults to 0.1.

    Raises:
        ValueError: if ``model_name`` is not "mlp", "cnn" or "transformer".
    """
    dropout = float(params.get("dropout", 0.1))
    if model_name == "mlp":
        return MLPClassifier(in_dim=in_dim, hidden=int(params["hidden"]), dropout=dropout)
    elif model_name == "cnn":
        return CNNClassifier(
            in_ch=in_dim,
            c=int(params["channels"]),
            k=int(params["kernel"]),
            layers=int(params["layers"]),
            dropout=dropout,
        )
    elif model_name == "transformer":
        return TransformerClassifier(
            in_dim=in_dim,
            d_model=int(params["d_model"]),
            nhead=int(params["nhead"]),
            layers=int(params["layers"]),
            ff=int(params["ff"]),
            dropout=dropout,
        )
    raise ValueError(model_name)


def _resolve_device(device):
    """Return ``device``, or ``"cpu"`` when a CUDA device was requested but CUDA is unavailable."""
    if str(device).startswith("cuda") and not torch.cuda.is_available():
        print(f"Warning: CUDA is not available; falling back to CPU (requested {device!r})")
        return "cpu"
    return device


# ======================== Main refit =========================================
def refit_with_seed(dataset_path, base_out_dir, model_name, seed, device="cuda:0"):
    """
    Loads best_params from base_out_dir/best_model.pt (saved by original Optuna run),
    retrains with the given seed, saves results to base_out_dir/seed_{seed}/.

    Training runs for at most 150 epochs and stops early after 12 consecutive
    epochs without a validation-F1 improvement of more than 1e-4. The weights
    from the best epoch are restored before the final evaluation.

    Writes ``val_predictions.csv``, ``model.pt`` and ``metrics.json`` into the
    seed directory and returns the metrics dict (model, seed, val_f1, val_auc,
    val_thr).

    If ``device`` names a CUDA device but CUDA is unavailable, training falls
    back to the CPU with a printed warning.

    Raises:
        FileNotFoundError: if ``best_model.pt`` is missing from ``base_out_dir``.
    """
    # Load best params from completed Optuna run
    model_path = os.path.join(base_out_dir, "best_model.pt")
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"No best_model.pt found at {model_path}. Run Optuna first.")
    # NOTE(review): torch.load unpickles the file; only load checkpoints from a trusted source.
    checkpoint = torch.load(model_path, map_location="cpu")
    best_params = checkpoint["best_params"]
    print(f"Loaded best_params from {model_path}")
    print(json.dumps(best_params, indent=2))

    # Seed
    seed_everything(seed)
    out_dir = os.path.join(base_out_dir, f"seed_{seed}")
    os.makedirs(out_dir, exist_ok=True)

    device = _resolve_device(device)
    # Pinned host memory only helps host-to-GPU copies.
    pin_memory = str(device).startswith("cuda")

    # Data
    train_ds, val_ds = load_split(dataset_path)
    print(f"[Data] Train: {len(train_ds)}, Val: {len(val_ds)}")
    batch_size = int(best_params.get("batch_size", 32))
    train_loader = DataLoader(
        train_ds,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=collate_unpooled,
        num_workers=4,
        pin_memory=pin_memory,
    )
    # shuffle=False keeps validation predictions aligned with val_ds row order,
    # which the "sequence" column below depends on.
    val_loader = DataLoader(
        val_ds,
        batch_size=64,
        shuffle=False,
        collate_fn=collate_unpooled,
        num_workers=4,
        pin_memory=pin_memory,
    )
    in_dim = infer_in_dim_from_unpooled_ds(train_ds)
    model = build_model(model_name, in_dim, best_params).to(device)

    # Loss: weight positives by the negative/positive ratio of the training labels.
    ytr = np.asarray(train_ds["label"], dtype=np.int64)
    pos, neg = ytr.sum(), len(ytr) - ytr.sum()
    pos_weight = torch.tensor([neg / max(pos, 1)], device=device, dtype=torch.float32)
    criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
    optim = torch.optim.AdamW(
        model.parameters(),
        lr=float(best_params["lr"]),
        weight_decay=float(best_params["weight_decay"]),
    )

    # Training loop with early stopping
    best_f1, best_thr, bad, patience = -1.0, 0.5, 0, 12
    best_state = None
    for epoch in range(1, 151):
        train_one_epoch(model, train_loader, optim, criterion, device)
        y_true, y_prob = eval_probs(model, val_loader, device)
        thr, f1 = best_f1_threshold(y_true, y_prob)
        if f1 > best_f1 + 1e-4:
            best_f1 = f1
            best_thr = thr
            bad = 0
            # Snapshot on CPU so the best weights survive further training steps.
            best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
        else:
            bad += 1
            if bad >= patience:
                print(f"Early stopping at epoch {epoch}")
                break

    if best_state is not None:
        model.load_state_dict(best_state)

    # Final eval
    y_true_val, y_prob_val = eval_probs(model, val_loader, device)
    best_thr_final, best_f1_final = best_f1_threshold(y_true_val, y_prob_val)
    auc_final = roc_auc_score(y_true_val, y_prob_val)

    # Save
    df_val = pd.DataFrame({
        "y_true": y_true_val.astype(int),
        "y_prob": y_prob_val.astype(float),
        "y_pred": (y_prob_val >= best_thr_final).astype(int),
    })
    if "sequence" in val_ds.column_names:
        df_val.insert(0, "sequence", np.asarray(val_ds["sequence"]))
    df_val.to_csv(os.path.join(out_dir, "val_predictions.csv"), index=False)

    torch.save(
        {"state_dict": model.state_dict(), "best_params": best_params, "seed": seed},
        os.path.join(out_dir, "model.pt"),
    )

    summary = {
        "model": model_name,
        "seed": seed,
        "val_f1": round(best_f1_final, 6),
        "val_auc": round(auc_final, 6),
        "val_thr": round(best_thr_final, 6),
    }
    with open(os.path.join(out_dir, "metrics.json"), "w") as f:
        json.dump(summary, f, indent=2)

    print(f"\n[Seed {seed}] F1={best_f1_final:.4f} AUC={auc_final:.4f} thr={best_thr_final:.4f}")
    print(f"Saved to {out_dir}")
    return summary


# ======================== CI aggregation =========================================
def aggregate_seed_results(base_out_dir, seeds):
    """
    Call after all seed runs finish to compute mean ± 95% CI across seeds.
    Saves a summary CSV to base_out_dir/seed_aggregated_metrics.csv

    Seeds whose ``metrics.json`` is missing are skipped with a warning. The CI
    half-width is the Student-t critical value (two-sided 95%, n-1 degrees of
    freedom) times the standard error. With a single seed the standard
    deviation and CI are undefined and reported as NaN.

    Returns:
        The summary ``DataFrame`` (one row per metric).

    Raises:
        ValueError: if no seed has a ``metrics.json``.
    """
    from scipy import stats

    records = []
    for seed in seeds:
        p = os.path.join(base_out_dir, f"seed_{seed}", "metrics.json")
        if os.path.exists(p):
            # Context manager closes the handle promptly instead of leaking it.
            with open(p) as f:
                records.append(json.load(f))
        else:
            print(f"Warning: missing seed {seed} at {p}")

    if not records:
        raise ValueError("No seed results found.")

    df = pd.DataFrame(records)
    print("\nPer-seed results:")
    print(df.to_string(index=False))

    if len(df) < 2:
        print("Warning: fewer than 2 seed results; std and 95% CI are undefined (NaN).")

    summary_rows = []
    for metric in ["val_f1", "val_auc"]:
        vals = df[metric].values
        n = len(vals)
        mean = vals.mean()
        if n > 1:
            std = vals.std(ddof=1)
            se = std / np.sqrt(n)
            t_crit = stats.t.ppf(0.975, df=n - 1)
            ci = t_crit * se
        else:
            # Sample std needs at least two values; skip the computation
            # rather than trigger NumPy's degrees-of-freedom warning.
            std = ci = float("nan")
        summary_rows.append({
            "metric": metric,
            "mean": round(mean, 4),
            "std": round(std, 4),
            "ci_95": round(ci, 4),
            "report": f"{mean:.4f} ± {ci:.4f}",
            "n_seeds": n,
        })

    summary_df = pd.DataFrame(summary_rows)
    out_path = os.path.join(base_out_dir, "seed_aggregated_metrics.csv")
    summary_df.to_csv(out_path, index=False)

    print("\n=== Aggregated Metrics (95% CI) ===")
    for _, row in summary_df.iterrows():
        print(f"  {row['metric']:12s}: {row['report']}  (n={row['n_seeds']})")
    print(f"\nSaved to {out_path}")
    return summary_df


def main(argv: Optional[list] = None) -> None:
    """Command-line entry point: refit one seed, or aggregate with ``--aggregate``.

    ``--dataset_path``, ``--model`` and ``--seed`` are only needed for a refit,
    so they are validated on that path instead of being globally required.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset_path", type=str, default=None,
                        help="Path to the saved DatasetDict (required unless --aggregate)")
    parser.add_argument("--base_out_dir", type=str, required=True,
                        help="Directory containing best_model.pt from Optuna run")
    parser.add_argument("--model", type=str, choices=["mlp", "cnn", "transformer"], default=None,
                        help="Model architecture to refit (required unless --aggregate)")
    parser.add_argument("--seed", type=int, default=None,
                        help="Training seed for this run (1986, 42, 0, 123, 12345)")
    parser.add_argument("--aggregate", action="store_true",
                        help="After all seeds done: aggregate results into CI summary")
    parser.add_argument("--all_seeds", type=int, nargs="+",
                        default=[1986, 42, 0, 123, 12345],
                        help="All seeds to aggregate (used with --aggregate)")
    args = parser.parse_args(argv)

    if args.aggregate:
        aggregate_seed_results(args.base_out_dir, args.all_seeds)
        return

    # Compare against None explicitly: seed 0 is a valid value.
    missing = [
        flag
        for flag, value in (
            ("--dataset_path", args.dataset_path),
            ("--model", args.model),
            ("--seed", args.seed),
        )
        if value is None
    ]
    if missing:
        parser.error(f"the following arguments are required: {', '.join(missing)}")

    refit_with_seed(
        dataset_path=args.dataset_path,
        base_out_dir=args.base_out_dir,
        model_name=args.model,
        seed=args.seed,
    )


if __name__ == "__main__":
    main()