# Joblib
# PeptiVerse / training_classifiers/refit_nn_seed.py
# ynuozhang
# major update
# 04c2975
import numpy as np
import torch
from torch.utils.data import DataLoader
from datasets import load_from_disk, DatasetDict
from sklearn.metrics import roc_auc_score, precision_recall_curve, f1_score
import torch.nn as nn
import os
import json
import pandas as pd
import argparse
from typing import Optional
from lightning.pytorch import seed_everything
def infer_in_dim_from_unpooled_ds(ds) -> int:
    """Return the width of one embedding row of the first example in ``ds``.

    ``ds[0]["embedding"]`` is indexed as a 2-D nested sequence (rows x width);
    the width of its first row is used as the model input dimension.
    """
    first_row = ds[0]["embedding"][0]
    return len(first_row)
def load_split(dataset_path):
    """Load a saved ``DatasetDict`` from disk and return its ``(train, val)`` splits.

    Raises:
        ValueError: if the object on disk is not a ``DatasetDict``.
    """
    loaded = load_from_disk(dataset_path)
    if not isinstance(loaded, DatasetDict):
        raise ValueError("Expected DatasetDict with 'train' and 'val' splits")
    return loaded["train"], loaded["val"]
def collate_unpooled(batch):
    """Pad a list of variable-length embedding matrices into dense batch tensors.

    Each example must provide:
      * ``"embedding"``: a 2-D nested sequence (one row per sequence position,
        all rows the same width H),
      * ``"label"``: a scalar label,
      * optionally ``"attention_mask"``: at least as many entries as embedding
        rows; only the first ``L`` entries are used.
    A ``"length"`` field, if present, is ignored: the padded length is taken
    from the actual embedding row counts so the two can never disagree.

    Returns:
        X: float32 tensor (batch, Lmax, H), zero-padded.
        M: bool tensor (batch, Lmax), True at valid positions.
        y: float32 tensor (batch,) of labels.
    """
    # Convert every embedding exactly once; their real row counts define Lmax.
    # (Previously Lmax came from x["length"], which crashed with a shape error
    # whenever it disagreed with the embedding's row count.)
    embs = [torch.tensor(x["embedding"], dtype=torch.float32) for x in batch]
    Lmax = max(int(e.shape[0]) for e in embs)
    H = int(embs[0].shape[1])
    X = torch.zeros(len(batch), Lmax, H, dtype=torch.float32)
    M = torch.zeros(len(batch), Lmax, dtype=torch.bool)
    y = torch.tensor([x["label"] for x in batch], dtype=torch.float32)
    for i, (x, emb) in enumerate(zip(batch, embs)):
        L = emb.shape[0]
        X[i, :L] = emb
        if "attention_mask" in x:
            m = torch.tensor(x["attention_mask"], dtype=torch.bool)
            M[i, :L] = m[:L]
        else:
            M[i, :L] = True
    return X, M, y
# ======================== Models =========================================
class MaskedMeanPool(nn.Module):
    """Average ``X`` over dimension 1, counting only positions where ``M`` is True.

    ``X`` is indexed as (batch, length, features) and ``M`` as (batch, length).
    Rows whose mask is entirely False divide by 1 instead of 0 and so pool to zeros.
    """

    def forward(self, X, M):
        weight = M.float()[..., None]
        count = torch.clamp(weight.sum(1), min=1.0)
        return torch.sum(X * weight, dim=1) / count
class MLPClassifier(nn.Module):
    """Masked-mean-pool the embeddings, then score with a one-hidden-layer MLP.

    Args:
        in_dim: feature width of the input embeddings.
        hidden: hidden-layer width.
        dropout: dropout probability after the GELU.

    ``forward(X, M)`` returns one logit per batch row.
    """

    def __init__(self, in_dim, hidden=512, dropout=0.1):
        super().__init__()
        self.pool = MaskedMeanPool()
        # Attribute names and layer order are kept stable so saved state_dict
        # keys (net.0.*, net.3.*) and seeded initialisation do not change.
        stack = [
            nn.Linear(in_dim, hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, 1),
        ]
        self.net = nn.Sequential(*stack)

    def forward(self, X, M):
        pooled = self.pool(X, M)
        logits = self.net(pooled)
        return logits.squeeze(-1)
class CNNClassifier(nn.Module):
    """1-D convolution stack over the sequence axis, masked-mean-pooled to one logit.

    Args:
        in_ch: feature width of the input ``X`` (Conv1d input channels).
        c: output channels of every conv layer.
        k: kernel size. Odd and even values both work because the convolutions
            use ``padding="same"``, which keeps the length aligned with ``M``.
        layers: number of Conv1d -> GELU -> Dropout blocks (0 = no convs).
        dropout: dropout probability after each conv block.

    ``forward(X, M)``: ``X`` is (batch, length, in_ch), ``M`` is a (batch, length)
    bool mask. Returns one logit per batch row.
    """

    def __init__(self, in_ch, c=256, k=5, layers=2, dropout=0.1):
        super().__init__()
        blocks, ch = [], in_ch
        for _ in range(layers):
            # padding="same" equals the old symmetric k//2 padding for odd k, but
            # also keeps the output length == input length for even k (k//2
            # padding grew the sequence by one per layer and broke Y * Mf).
            blocks += [nn.Conv1d(ch, c, kernel_size=k, padding="same"), nn.GELU(), nn.Dropout(dropout)]
            ch = c
        self.conv = nn.Sequential(*blocks)
        # Size the head from the last channel count so layers=0 also works
        # (identical to Linear(c, 1) whenever layers >= 1).
        self.head = nn.Linear(ch, 1)

    def forward(self, X, M):
        # NOTE(review): padded positions become non-zero after the first
        # conv + GELU and can leak into valid boundary positions of the next
        # layer, so outputs may depend on batch padding. Re-masking between
        # layers would change trained behaviour; confirm before altering.
        Y = self.conv(X.transpose(1, 2)).transpose(1, 2)
        Mf = M.unsqueeze(-1).float()
        pooled = (Y * Mf).sum(dim=1) / Mf.sum(dim=1).clamp(min=1.0)
        return self.head(pooled).squeeze(-1)
class TransformerClassifier(nn.Module):
    """Project embeddings, run a Transformer encoder, masked-mean-pool to one logit.

    Args:
        in_dim: feature width of the input ``X``.
        d_model: encoder width after the input projection.
        nhead: attention heads per encoder layer.
        layers: number of encoder layers.
        ff: feed-forward width inside each encoder layer.
        dropout: dropout used inside the encoder layers.

    ``forward(X, M)``: ``X`` is (batch, length, in_dim) (``batch_first=True``),
    ``M`` is a (batch, length) bool mask with True at valid positions.
    """

    def __init__(self, in_dim, d_model=256, nhead=8, layers=2, ff=512, dropout=0.1):
        super().__init__()
        # Construction order (proj, encoder layer, head) is preserved so seeded
        # initialisation matches; attribute names keep state_dict keys stable.
        self.proj = nn.Linear(in_dim, d_model)
        self.enc = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=d_model,
                nhead=nhead,
                dim_feedforward=ff,
                dropout=dropout,
                batch_first=True,
                activation="gelu",
            ),
            num_layers=layers,
        )
        self.head = nn.Linear(d_model, 1)

    def forward(self, X, M):
        hidden = self.proj(X)
        # The key-padding mask expects True at positions to IGNORE, hence the inversion.
        hidden = self.enc(hidden, src_key_padding_mask=M.logical_not())
        weight = M.float()[..., None]
        denom = torch.clamp(weight.sum(1), min=1.0)
        pooled = (hidden * weight).sum(1) / denom
        return self.head(pooled).squeeze(-1)
# ======================== Training utils =========================================
def best_f1_threshold(y_true, y_prob):
    """Return ``(threshold, f1)`` for the threshold that maximises F1 on the PR curve."""
    precision, recall, thresholds = precision_recall_curve(y_true, y_prob)
    # precision/recall carry one extra trailing point with no threshold;
    # drop it so index i lines up with thresholds[i].
    prec, rec = precision[:-1], recall[:-1]
    f1 = 2 * prec * rec / (prec + rec + 1e-12)
    best = int(np.nanargmax(f1))
    return float(thresholds[best]), float(f1[best])
@torch.no_grad()
def eval_probs(model, loader, device):
    """Run ``model`` in eval mode over ``loader``; return ``(labels, sigmoid probs)``.

    ``loader`` yields ``(X, M, y)`` triples. Inputs are moved to ``device``;
    labels stay on the CPU. Both results are 1-D numpy arrays concatenated
    over all batches.
    """
    model.eval()
    labels, probs = [], []
    for feats, mask, target in loader:
        logits = model(feats.to(device), mask.to(device))
        probs.append(logits.sigmoid().cpu().numpy())
        labels.append(target.numpy())
    return np.concatenate(labels), np.concatenate(probs)
def train_one_epoch(model, loader, optim, criterion, device):
    """Do one optimisation pass over ``loader`` (``(X, M, y)`` batches) in train mode."""
    model.train()
    for batch in loader:
        feats, mask, target = (t.to(device) for t in batch)
        optim.zero_grad(set_to_none=True)
        loss = criterion(model(feats, mask), target)
        loss.backward()
        optim.step()
def build_model(model_name, in_dim, params):
    """Instantiate the classifier named ``model_name`` from a hyper-parameter dict.

    ``params`` must contain the keys used by the chosen architecture
    (mlp: hidden; cnn: channels, kernel, layers; transformer: d_model, nhead,
    layers, ff). ``dropout`` is optional and defaults to 0.1.

    Raises:
        ValueError: for an unknown ``model_name``.
    """
    dropout = float(params.get("dropout", 0.1))
    # Factories are lazy so only the chosen architecture's params are read.
    candidates = (
        ("mlp", lambda: MLPClassifier(
            in_dim=in_dim,
            hidden=int(params["hidden"]),
            dropout=dropout,
        )),
        ("cnn", lambda: CNNClassifier(
            in_ch=in_dim,
            c=int(params["channels"]),
            k=int(params["kernel"]),
            layers=int(params["layers"]),
            dropout=dropout,
        )),
        ("transformer", lambda: TransformerClassifier(
            in_dim=in_dim,
            d_model=int(params["d_model"]),
            nhead=int(params["nhead"]),
            layers=int(params["layers"]),
            ff=int(params["ff"]),
            dropout=dropout,
        )),
    )
    for name, factory in candidates:
        if model_name == name:
            return factory()
    raise ValueError(model_name)
# ======================== Main refit =========================================
def _train_with_early_stopping(model, train_loader, val_loader, optim, criterion, device,
                               max_epochs, patience, min_delta):
    """Train ``model`` in place, early-stopping on validation F1; restore the best weights.

    After each epoch the validation F1 at its best threshold is computed. Training
    stops after ``patience`` consecutive epochs without an improvement larger than
    ``min_delta``. The best epoch's weights are snapshotted on the CPU and loaded
    back into ``model`` at the end.

    Returns:
        The best validation F1 observed (-1.0 if no epoch ran).
    """
    best_f1, bad, best_state = -1.0, 0, None
    for epoch in range(1, max_epochs + 1):
        train_one_epoch(model, train_loader, optim, criterion, device)
        y_true, y_prob = eval_probs(model, val_loader, device)
        _, f1 = best_f1_threshold(y_true, y_prob)
        if f1 > best_f1 + min_delta:
            best_f1, bad = f1, 0
            # Clone to the CPU so later optimizer steps cannot mutate the snapshot.
            best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
        else:
            bad += 1
            if bad >= patience:
                print(f"Early stopping at epoch {epoch}")
                break
    if best_state is not None:
        model.load_state_dict(best_state)
    return best_f1


def refit_with_seed(dataset_path, base_out_dir, model_name, seed, device="cuda:0", *,
                    max_epochs=150, patience=12, min_delta=1e-4,
                    val_batch_size=64, num_workers=4):
    """
    Loads best_params from base_out_dir/best_model.pt (saved by original Optuna run),
    retrains with the given seed, saves results to base_out_dir/seed_{seed}/.

    Args:
        dataset_path: path of a saved DatasetDict with 'train' and 'val' splits.
        base_out_dir: directory holding best_model.pt; outputs go to seed_{seed}/.
        model_name: one of "mlp", "cnn", "transformer".
        seed: seed passed to lightning's ``seed_everything`` before data/model setup.
        device: torch device string for training and evaluation.
        max_epochs: maximum number of training epochs (default 150).
        patience: epochs without validation-F1 improvement before stopping (default 12).
        min_delta: minimum F1 gain that counts as an improvement (default 1e-4).
        val_batch_size: validation DataLoader batch size (default 64).
        num_workers: DataLoader worker processes for both loaders (default 4).

    Writes val_predictions.csv, model.pt and metrics.json into the seed directory.

    Returns:
        dict with model, seed, val_f1, val_auc, val_thr (also saved as metrics.json).

    Raises:
        FileNotFoundError: if base_out_dir/best_model.pt does not exist.
    """
    # Load best params from completed Optuna run
    model_path = os.path.join(base_out_dir, "best_model.pt")
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"No best_model.pt found at {model_path}. Run Optuna first.")
    # NOTE(review): torch>=2.6 defaults torch.load to weights_only=True; this works
    # only if the checkpoint holds plain containers/tensors — confirm for old runs.
    checkpoint = torch.load(model_path, map_location="cpu")
    best_params = checkpoint["best_params"]
    print(f"Loaded best_params from {model_path}")
    print(json.dumps(best_params, indent=2))

    # Seed before any data shuffling or weight initialisation happens.
    seed_everything(seed)
    out_dir = os.path.join(base_out_dir, f"seed_{seed}")
    os.makedirs(out_dir, exist_ok=True)

    # Data import
    train_ds, val_ds = load_split(dataset_path)
    print(f"[Data] Train: {len(train_ds)}, Val: {len(val_ds)}")
    batch_size = int(best_params.get("batch_size", 32))
    # Pinned host memory only helps host->GPU copies.
    pin_memory = str(device).startswith("cuda")
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,
                              collate_fn=collate_unpooled, num_workers=num_workers,
                              pin_memory=pin_memory)
    val_loader = DataLoader(val_ds, batch_size=val_batch_size, shuffle=False,
                            collate_fn=collate_unpooled, num_workers=num_workers,
                            pin_memory=pin_memory)
    in_dim = infer_in_dim_from_unpooled_ds(train_ds)
    model = build_model(model_name, in_dim, best_params).to(device)

    # Loss: weight positives by neg/pos to counter class imbalance.
    ytr = np.asarray(train_ds["label"], dtype=np.int64)
    pos = int(ytr.sum())
    neg = len(ytr) - pos
    pos_weight = torch.tensor([neg / max(pos, 1)], device=device, dtype=torch.float32)
    criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
    optim = torch.optim.AdamW(model.parameters(),
                              lr=float(best_params["lr"]),
                              weight_decay=float(best_params["weight_decay"]))

    _train_with_early_stopping(model, train_loader, val_loader, optim, criterion, device,
                               max_epochs=max_epochs, patience=patience, min_delta=min_delta)

    # Final eval on the restored best weights; threshold re-tuned on val.
    y_true_val, y_prob_val = eval_probs(model, val_loader, device)
    best_thr_final, best_f1_final = best_f1_threshold(y_true_val, y_prob_val)
    auc_final = roc_auc_score(y_true_val, y_prob_val)

    # Save
    df_val = pd.DataFrame({
        "y_true": y_true_val.astype(int),
        "y_prob": y_prob_val.astype(float),
        "y_pred": (y_prob_val >= best_thr_final).astype(int),
    })
    if "sequence" in val_ds.column_names:
        df_val.insert(0, "sequence", np.asarray(val_ds["sequence"]))
    df_val.to_csv(os.path.join(out_dir, "val_predictions.csv"), index=False)
    torch.save({"state_dict": model.state_dict(), "best_params": best_params, "seed": seed},
               os.path.join(out_dir, "model.pt"))
    summary = {
        "model": model_name,
        "seed": seed,
        "val_f1": round(best_f1_final, 6),
        "val_auc": round(auc_final, 6),
        "val_thr": round(best_thr_final, 6),
    }
    with open(os.path.join(out_dir, "metrics.json"), "w") as f:
        json.dump(summary, f, indent=2)
    print(f"\n[Seed {seed}] F1={best_f1_final:.4f} AUC={auc_final:.4f} thr={best_thr_final:.4f}")
    print(f"Saved to {out_dir}")
    return summary
# ======================== CI aggregation =========================================
def aggregate_seed_results(base_out_dir, seeds, metrics=("val_f1", "val_auc")):
    """
    Call after all seed runs finish to compute mean ± 95% CI across seeds.
    Saves a summary CSV to base_out_dir/seed_aggregated_metrics.csv

    Reads base_out_dir/seed_{seed}/metrics.json for every seed (missing seeds
    are reported and skipped). The CI is a two-sided Student-t interval on the
    per-seed values, so it reflects seed-to-seed variation only.

    Args:
        base_out_dir: directory containing the seed_{seed}/ subdirectories.
        seeds: seeds whose results to aggregate.
        metrics: metric keys of metrics.json to summarise (default val_f1, val_auc).

    Returns:
        DataFrame with columns metric, mean, std, ci_95, report, n_seeds.
        With a single seed, std and ci_95 are NaN (no spread estimate exists).

    Raises:
        ValueError: if none of the seeds has a metrics.json.
    """
    from scipy import stats

    records = []
    for seed in seeds:
        path = os.path.join(base_out_dir, f"seed_{seed}", "metrics.json")
        if not os.path.exists(path):
            print(f"Warning: missing seed {seed} at {path}")
            continue
        # Context manager so the file handle is closed deterministically.
        with open(path) as fh:
            records.append(json.load(fh))
    if not records:
        raise ValueError("No seed results found.")

    df = pd.DataFrame(records)
    print("\nPer-seed results:")
    print(df.to_string(index=False))

    summary_rows = []
    for metric in metrics:
        vals = df[metric].to_numpy(dtype=float)
        n = len(vals)
        mean = vals.mean()
        if n > 1:
            std = vals.std(ddof=1)
            se = std / np.sqrt(n)
            t_crit = stats.t.ppf(0.975, df=n - 1)
            ci = t_crit * se
        else:
            # One seed has no sample spread; report NaN explicitly rather than
            # triggering numpy/scipy degrees-of-freedom warnings.
            std = ci = float("nan")
        summary_rows.append({
            "metric": metric,
            "mean": round(mean, 4),
            "std": round(std, 4),
            "ci_95": round(ci, 4),
            "report": f"{mean:.4f} ± {ci:.4f}",
            "n_seeds": n,
        })
    summary_df = pd.DataFrame(summary_rows)
    out_path = os.path.join(base_out_dir, "seed_aggregated_metrics.csv")
    summary_df.to_csv(out_path, index=False)
    print("\n=== Aggregated Metrics (95% CI) ===")
    for _, row in summary_df.iterrows():
        print(f" {row['metric']:12s}: {row['report']} (n={row['n_seeds']})")
    print(f"\nSaved to {out_path}")
    return summary_df
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--dataset_path", type=str, required=True)
parser.add_argument("--base_out_dir", type=str, required=True,
help="Directory containing best_model.pt from Optuna run")
parser.add_argument("--model", type=str, choices=["mlp", "cnn", "transformer"], required=True)
parser.add_argument("--seed", type=int, required=True,
help="Training seed for this run (1986, 42, 0, 123, 12345)")
parser.add_argument("--aggregate", action="store_true",
help="After all seeds done: aggregate results into CI summary")
parser.add_argument("--all_seeds", type=int, nargs="+", default=[1986, 42, 0, 123, 12345],
help="All seeds to aggregate (used with --aggregate)")
args = parser.parse_args()
if args.aggregate:
aggregate_seed_results(args.base_out_dir, args.all_seeds)
else:
refit_with_seed(
dataset_path=args.dataset_path,
base_out_dir=args.base_out_dir,
model_name=args.model,
seed=args.seed,
)