| |
| |
|
|
| import argparse |
| import os |
| import random |
| from copy import deepcopy |
| from typing import Any, Dict |
|
|
| import numpy as np |
| import pandas as pd |
| from hyperopt import fmin, tpe, hp, Trials, STATUS_OK |
| from hyperopt.pyll.base import scope |
| from sklearn.model_selection import StratifiedKFold |
|
|
| import torch |
| import torch.nn as nn |
| import torch.optim |
| from torch import Tensor |
|
|
| import tabm |
| import rtdl_num_embeddings |
|
|
def set_seed(seed: int) -> None:
    """Seed the Python, NumPy and PyTorch global random number generators.

    Each library receives a different offset of ``seed`` (+0 for ``random``,
    +1 for NumPy, +2 for PyTorch) so the three generators do not start from
    the identical integer.

    Args:
        seed: Base seed.
    """
    seeders = (
        (random.seed, 0),
        (np.random.seed, 1),
        (torch.manual_seed, 2),
    )
    for seeder, offset in seeders:
        seeder(seed + offset)
|
|
| def _dump_model_info_sidecar(model_path: str) -> None: |
| try: |
| if not os.path.exists(model_path): |
| return |
| ckpt = torch.load(model_path, map_location='cpu', weights_only=False) |
| sidecar = os.path.splitext(model_path)[0] + ".info.txt" |
| with open(sidecar, "w", encoding="utf-8") as f: |
| def _p(title: str, d): |
| try: |
| f.write(title + "\n") |
| if hasattr(d, "__dict__"): |
| items = sorted(vars(d).items()) |
| elif isinstance(d, dict): |
| items = sorted(d.items()) |
| else: |
| try: |
| items = sorted(d.__dict__.items()) |
| except Exception: |
| items = [] |
| for k, v in items: |
| try: |
| f.write(f"- {k}: {repr(v)}\n") |
| except Exception: |
| f.write(f"- {k}: <unprintable>\n") |
| f.write("=" * len(title) + "\n") |
| except Exception: |
| pass |
|
|
| _p("===== checkpoint['args'] =====", ckpt.get('args')) |
| _p("===== checkpoint['training_args'] =====", ckpt.get('training_args', {})) |
| _p("===== checkpoint['best_params'] =====", ckpt.get('best_params', {})) |
| _p("===== checkpoint['full_args'] =====", ckpt.get('full_args', {})) |
|
|
| if ckpt.get("used_feature_idx") is not None: |
| ufi = ckpt["used_feature_idx"] |
| f.write("===== used_feature_idx =====\n") |
| try: |
| f.write(f"- length: {len(ufi)}\n") |
| f.write(f"- head: {list(ufi[:10])}\n") |
| except Exception: |
| f.write("<unprintable>\n") |
| f.write("=" * 25 + "\n") |
|
|
| |
| try: |
| f.write("===== Environment =====\n") |
| f.write(f"- torch: {torch.__version__}\n") |
| f.write(f"- cuda available: {torch.cuda.is_available()}\n") |
| if torch.cuda.is_available(): |
| f.write(f"- device: {torch.cuda.get_device_name(0)}\n") |
| f.write(f"- cuda version: {torch.version.cuda}\n") |
| import tabm as _tabm_mod |
| f.write(f"- tabm: {getattr(_tabm_mod, '__version__', 'unknown')}\n") |
| f.write("========================\n") |
| except Exception: |
| pass |
| except Exception: |
| pass |
def load_training_data(data_file: str) -> tuple[np.ndarray, np.ndarray]:
    """Load a tab-separated training table into a feature matrix and labels.

    The label column is ``label`` when present, otherwise the first column.
    Every other column is a numeric feature. Feature values that cannot be
    parsed as numbers, or that are not finite (``inf``, or values that
    overflow float32), become 0.0. Missing or non-numeric labels become 0,
    with a printed warning.

    Args:
        data_file: Path to a TSV file with a header row.

    Returns:
        ``(X, y)``:
            X: float32 array of shape ``(n_rows, n_features)``.
            y: int64 array of 0/1 labels.

    Raises:
        ValueError: if the table has no rows or fewer than two columns, has
            no feature columns, or contains labels other than 0 and 1 (the
            model built by this script has exactly two output classes).
    """
    df = pd.read_csv(
        data_file,
        sep='\t',
        header=0,
        dtype=str,
        keep_default_na=False,
        na_filter=False,
        engine='python',
    )

    if df.shape[0] == 0 or df.shape[1] < 2:
        raise ValueError(
            f"Incorrect training data format: {data_file}, requires at least 1 label column + 1 feature column, actual shape={df.shape}"
        )

    label_col = 'label' if 'label' in df.columns else df.columns[0]

    raw_labels = pd.to_numeric(df[label_col], errors='coerce').to_numpy(dtype=np.float64)
    non_numeric = ~np.isfinite(raw_labels)
    if non_numeric.any():
        # Keep the historical "treat as 0" behaviour, but make it visible.
        print(f"Warning: {int(non_numeric.sum())} missing/non-numeric label(s) "
              f"in column '{label_col}' treated as 0")
    y = np.where(non_numeric, 0.0, raw_labels).astype(np.int64)

    unexpected = np.setdiff1d(np.unique(y), [0, 1])
    if unexpected.size:
        raise ValueError(
            f"Labels in column '{label_col}' must be 0 or 1, found {unexpected.tolist()}"
        )

    feature_cols = [c for c in df.columns if c != label_col]
    if not feature_cols:
        raise ValueError("No feature columns found")

    X = df[feature_cols].apply(pd.to_numeric, errors='coerce').to_numpy(dtype=np.float32)
    # NaN (unparseable), +/-inf, and values that overflowed float32 all become 0.0.
    X = np.where(np.isfinite(X), X, np.float32(0.0)).astype(np.float32, copy=False)

    return X, y
|
|
def build_num_embeddings(embedding_type: str,
                         X_fold: np.ndarray,
                         n_bins: int = 48,
                         d_embedding: int = 16) -> tuple[Any, np.ndarray]:
    """Create the numerical-feature embedding module for one training split.

    Args:
        embedding_type: 'linear', 'periodic' or 'piecewise'. Any other value
            (e.g. 'none') means no embedding module.
        X_fold: Training features, indexed as ``X_fold[:, column]``. Only
            used to choose columns and, for 'piecewise', to compute bin edges.
        n_bins: Bins per feature for piecewise-linear embeddings.
        d_embedding: Embedding width for piecewise-linear embeddings.

    Returns:
        ``(module_or_None, used_idx)``:
            module_or_None: The embedding module, or None.
            used_idx: The column indices the model must be fed. For
                'piecewise' this excludes zero-variance columns; otherwise it
                is every column.

        For 'piecewise', None is also returned, after a printed warning,
        when no column varies or when the bins/module cannot be built. In
        that case ``used_idx`` still lists the kept columns.
    """
    n_all = X_fold.shape[1]
    if embedding_type == 'linear':
        return rtdl_num_embeddings.LinearReLUEmbeddings(n_all), np.arange(n_all)
    if embedding_type == 'periodic':
        return rtdl_num_embeddings.PeriodicEmbeddings(n_all, lite=False), np.arange(n_all)
    if embedding_type != 'piecewise':
        return None, np.arange(n_all)

    # Bin computation needs at least two distinct values per column, so
    # constant (zero-variance) columns are dropped for piecewise embeddings.
    used_idx = np.flatnonzero(X_fold.var(axis=0) > 0.0)
    if used_idx.size == 0:
        print("Warning: no non-constant features; piecewise embeddings disabled")
        return None, used_idx
    try:
        X_tensor = torch.as_tensor(X_fold[:, used_idx], dtype=torch.float32)
        num_embeddings = rtdl_num_embeddings.PiecewiseLinearEmbeddings(
            rtdl_num_embeddings.compute_bins(X_tensor, n_bins=n_bins),
            d_embedding=d_embedding,
            activation=False,
            version='B',
        )
        return num_embeddings, used_idx
    except Exception as exc:
        # Best-effort fallback kept from the original, but no longer silent.
        print(f"Warning: piecewise embeddings unavailable ({exc!r}); "
              f"falling back to no embeddings")
        return None, used_idx
|
|
def make_model(n_features: int,
               k: int,
               n_blocks: int,
               d_block: int,
               num_embeddings: Any,
               arch_type: str = 'tabm') -> nn.Module:
    """Build a TabM model over numeric features only, with two outputs.

    Args:
        n_features: Number of numeric input columns.
        k: Ensemble size passed to TabM.
        n_blocks: Number of backbone blocks.
        d_block: Width of each block.
        num_embeddings: Embedding module for numeric features, or None.
        arch_type: TabM architecture variant (e.g. 'tabm' or 'tabm-mini').

    Returns:
        The model from ``tabm.TabM.make``, with ``d_out=2`` and no
        categorical features.
    """
    config = dict(
        n_num_features=n_features,
        cat_cardinalities=[],
        d_out=2,
        k=k,
        n_blocks=n_blocks,
        d_block=d_block,
        num_embeddings=num_embeddings,
        arch_type=arch_type,
    )
    return tabm.TabM.make(**config)
|
|
def train_one_epoch(model: nn.Module,
                    X: torch.Tensor,
                    y: torch.Tensor,
                    optimizer: torch.optim.Optimizer,
                    batch_size: int,
                    device: torch.device,
                    max_grad_norm: float = 1.0) -> float:
    """Run one shuffled pass over ``(X, y)`` and return the mean training loss.

    The model is expected to return logits of shape ``(batch, k, n_classes)``,
    one prediction per ensemble member, and to expose ``k`` as
    ``model.backbone.k``. The TabM model from ``make_model`` is assumed to
    provide this. Every member sees the same mini-batch, so each label is
    repeated ``k`` times to line up with the flattened predictions.

    Args:
        model: Model to train (switched to train mode).
        X: Feature tensor on ``device``.
        y: Integer class labels on ``device``.
        optimizer: Optimizer over ``model``'s parameters.
        batch_size: Mini-batch size. The last batch may be smaller.
        device: Device used for the shuffling permutation.
        max_grad_norm: Gradient-norm clipping threshold.

    Returns:
        Cross-entropy averaged over all samples of the epoch. Each mini-batch
        is weighted by its size, so a short final batch does not skew it.
    """
    model.train()
    k = model.backbone.k
    total_loss = 0.0
    n_seen = 0
    for batch_idx in torch.randperm(len(X), device=device).split(batch_size):
        optimizer.zero_grad()
        logits = model(X[batch_idx])
        # (b, k, C) -> (b*k, C), matching labels repeated per member.
        targets = y[batch_idx].repeat_interleave(k)
        loss = nn.functional.cross_entropy(logits.flatten(0, 1), targets)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()
        n_batch = len(batch_idx)
        total_loss += loss.item() * n_batch
        n_seen += n_batch
    return total_loss / max(1, n_seen)
|
|
def sum_rank_correct_numpy(y_true: np.ndarray, y_prob: np.ndarray, alpha: float = 0.005) -> float:
    """Score a ranking by summing ``exp(-alpha * rank)`` over positive samples.

    Samples are ordered by descending ``y_prob`` (rank 0 is the highest
    score). Every sample whose label equals 1 contributes
    ``exp(-alpha * rank)``, so positives near the top count the most.

    Ties keep their original order (stable sort), so the score does not
    depend on which sorting algorithm NumPy selects for the platform. Scores
    are converted to float64 before negation so unsigned inputs cannot wrap
    around.

    Args:
        y_true: Labels. Entries equal to 1 are positives.
        y_prob: Scores aligned with ``y_true``.
        alpha: Decay rate per rank position.

    Returns:
        The summed score (0.0 when there are no positives).
    """
    order = np.argsort(-np.asarray(y_prob, dtype=np.float64), kind='stable')
    ranks = np.flatnonzero(np.asarray(y_true)[order] == 1)
    return float(np.sum(np.exp(-alpha * ranks)))
|
|
@torch.inference_mode()
def evaluate_sum_exp_rank(model: nn.Module,
                          X: torch.Tensor,
                          y: torch.Tensor,
                          device: torch.device,
                          alpha: float = 0.005,
                          eval_batch_size: int = 8096) -> float:
    """Score ``model`` on ``(X, y)`` with ``sum_rank_correct_numpy``.

    Predictions are computed in chunks of ``eval_batch_size`` rows to bound
    memory. The model output is assumed to be ``(n, k, 2)`` per-member logits,
    matching how ``train_one_epoch`` consumes it.

    Each member's logits are turned into probabilities first, and then the
    probabilities are averaged over the ``k`` members. The TabM authors
    recommend this probability-space averaging for classification, rather
    than averaging logits.

    Args:
        model: Trained model (switched to eval mode).
        X: Feature tensor on ``device``.
        y: Labels aligned with ``X``.
        device: Device used to build the chunk indices.
        alpha: Decay rate forwarded to the metric.
        eval_batch_size: Rows per forward pass.

    Returns:
        The sum-of-exp-rank score of the positive-class probabilities.
    """
    model.eval()
    chunks = torch.arange(len(X), device=device).split(eval_batch_size)
    probs = torch.cat([
        torch.softmax(model(X[idx]), dim=-1).mean(1)
        for idx in chunks
    ])
    probs_pos = probs[:, 1].cpu().numpy()
    return sum_rank_correct_numpy(y.cpu().numpy(), probs_pos, alpha)
|
|
|
|
def objective(params: Dict[str, Any],
              X: np.ndarray,
              y: np.ndarray,
              device: torch.device,
              seed: int,
              cv_folds: int,
              epochs: int,
              batch_size: int,
              alpha: float = 0.005) -> Dict[str, Any]:
    """Hyperopt objective: mean cross-validated sum_exp_rank for one configuration.

    For each stratified fold, a fresh model is built and trained for
    ``epochs`` epochs on the training part, then scored on the held-out part
    with ``evaluate_sum_exp_rank``.

    Args:
        params: Sampled hyperparameters. Keys: 'n_blocks', 'd_block', 'lr',
            'weight_decay_choice', 'weight_decay_val', 'embedding_type',
            'arch_type', and optionally 'k' (default 32).
        X: Feature matrix.
        y: Labels.
        device: Device for training.
        seed: ``random_state`` of the fold shuffling.
        cv_folds: Number of folds.
        epochs: Training epochs per fold.
        batch_size: Mini-batch size.
        alpha: Metric decay rate.

    Returns:
        Hyperopt result dict. 'loss' is the negated mean score (hyperopt
        minimizes); 'score' is the mean score itself.
    """
    k = int(params.get('k', 32))
    n_blocks = int(params['n_blocks'])
    d_block = int(params['d_block'])
    lr = float(params['lr'])
    # weight_decay_val is sampled every trial but only used when the flag is on.
    weight_decay = 0.0 if params['weight_decay_choice'] == 0 else float(params['weight_decay_val'])
    embedding_type = params['embedding_type']
    arch_type = params['arch_type']

    cv = StratifiedKFold(n_splits=cv_folds, shuffle=True, random_state=seed)
    fold_scores: list[float] = []

    for train_idx, val_idx in cv.split(X, y):
        X_tr, y_tr = X[train_idx], y[train_idx]
        X_va, y_va = X[val_idx], y[val_idx]

        # Embeddings are fitted on the training split only. For 'piecewise'
        # the same subset of columns is then selected for validation.
        num_embeddings, used_idx = build_num_embeddings(embedding_type, X_tr)
        if embedding_type == 'piecewise':
            X_tr, X_va = X_tr[:, used_idx], X_va[:, used_idx]

        model = make_model(X_tr.shape[1], k, n_blocks, d_block, num_embeddings, arch_type).to(device)
        optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

        X_tr_t = torch.as_tensor(X_tr, device=device)
        y_tr_t = torch.as_tensor(y_tr, device=device)
        X_va_t = torch.as_tensor(X_va, device=device)
        y_va_t = torch.as_tensor(y_va, device=device)

        for _ in range(epochs):
            train_one_epoch(model, X_tr_t, y_tr_t, optimizer, batch_size, device)

        fold_scores.append(evaluate_sum_exp_rank(model, X_va_t, y_va_t, device, alpha))

        # Release this fold's model and device tensors before the next fold
        # allocates its own, so peak (GPU) memory holds one fold at a time.
        del model, optimizer, X_tr_t, y_tr_t, X_va_t, y_va_t

    mean_score = float(np.mean(fold_scores))
    return {"loss": -mean_score, "status": STATUS_OK, "score": mean_score}
|
|
def train_final(X: np.ndarray,
                y: np.ndarray,
                best_params: Dict[str, Any],
                device: torch.device,
                final_epochs: int,
                batch_size: int,
                output_path: str,
                seed: int,
                alpha: float = 0.005) -> None:
    """Train one model on all of ``(X, y)`` with ``best_params`` and save it.

    The checkpoint written to ``output_path`` (parent directories are
    created) contains:
        - the model state dict;
        - an ``args`` Namespace describing the architecture;
        - the hyperparameters and training settings;
        - ``used_feature_idx``: the input columns the model consumes. These
          are a subset only for 'piecewise' embeddings.
    A ``.info.txt`` summary is then written next to it.

    Args:
        X: Feature matrix.
        y: Labels.
        best_params: Hyperparameters, same keys as ``objective`` expects.
        device: Training device.
        final_epochs: Number of training epochs.
        batch_size: Mini-batch size.
        output_path: Checkpoint path.
        seed: Recorded in the checkpoint. Not used for training here.
        alpha: Recorded in the checkpoint. Not used for training here.
    """
    k = int(best_params.get('k', 32))
    n_blocks = int(best_params['n_blocks'])
    d_block = int(best_params['d_block'])
    lr = float(best_params['lr'])
    wd_choice = best_params['weight_decay_choice']
    weight_decay = 0.0 if wd_choice == 0 else float(best_params['weight_decay_val'])
    embedding_type = best_params['embedding_type']
    arch_type = best_params['arch_type']

    num_embeddings, used_idx = build_num_embeddings(embedding_type, X)
    X_used = X[:, used_idx] if embedding_type == 'piecewise' else X
    n_features = X_used.shape[1]

    model = make_model(n_features, k, n_blocks, d_block, num_embeddings, arch_type).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

    X_t = torch.as_tensor(X_used, device=device)
    y_t = torch.as_tensor(y, device=device)

    for _ in range(final_epochs):
        train_one_epoch(model, X_t, y_t, optimizer, batch_size, device)

    os.makedirs(os.path.dirname(output_path) or '.', exist_ok=True)
    torch.save({
        "model_state_dict": model.state_dict(),
        "args": argparse.Namespace(
            k=k,
            n_blocks=n_blocks,
            d_block=d_block,
            # Reflect the model actually saved. build_num_embeddings returns
            # None for 'none', and also when the piecewise build falls back.
            # In both cases the state dict contains no embedding weights.
            use_embeddings=num_embeddings is not None,
            embedding_type=embedding_type,
            arch_type=arch_type,
        ),
        "best_params": deepcopy(best_params),
        "training_args": {
            "lr": lr,
            "weight_decay_choice": wd_choice,
            # Effective weight decay: 0.0 when weight decay was disabled.
            "weight_decay_val": weight_decay,
            "batch_size": batch_size,
            "final_epochs": final_epochs,
            "seed": seed,
            "alpha": alpha,
            "device": str(device),
        },
        "used_feature_idx": used_idx,
        "full_args": dict(
            best_params=deepcopy(best_params),
            final_epochs=final_epochs, batch_size=batch_size,
            seed=seed, alpha=alpha, device=str(device),
        ),
        "search_space": "hyperopt space v1",
    }, output_path)
    print(f"Final models saved into: {output_path}")
    _dump_model_info_sidecar(output_path)
|
|
def hyperopt_search(X: np.ndarray,
                    y: np.ndarray,
                    device: torch.device,
                    seed: int,
                    cv_folds: int,
                    epochs: int,
                    batch_size: int,
                    alpha: float,
                    tune_k: bool,
                    max_evals: int) -> tuple[dict, float]:
    """Run a TPE hyperparameter search and return the best configuration.

    Each trial is scored by ``objective`` (mean cross-validated
    sum_exp_rank). The TPE sampler is seeded from ``seed``, so a repetition
    with the same seed proposes the same sequence of configurations.

    Args:
        X: Feature matrix.
        y: Labels.
        device: Training device.
        seed: Seed for the TPE sampler and the CV fold shuffling.
        cv_folds: Number of CV folds per trial.
        epochs: Training epochs per fold.
        batch_size: Mini-batch size.
        alpha: Metric decay rate.
        tune_k: Whether to search over ``k``. When False, ``k`` is fixed at 32.
        max_evals: Number of trials.

    Returns:
        ``(best_params, best_score)``.
            best_params: Concrete values. Names are strings, sizes are ints,
                rates are floats, and 'k' is always included.
            best_score: The best trial's mean score.
    """
    from hyperopt import space_eval  # only needed here

    embedding_choices = ["none", "linear", "periodic", "piecewise"]
    arch_choices = ["tabm", "tabm-mini"]
    space = {
        "n_blocks": scope.int(hp.quniform("n_blocks", 2, 5, 1)),
        "d_block": scope.int(hp.quniform("d_block", 64, 1024, 16)),
        "lr": hp.loguniform("lr", np.log(1e-4), np.log(5e-3)),
        "weight_decay_choice": hp.choice("weight_decay_choice", [0, 1]),
        "weight_decay_val": hp.loguniform("weight_decay_val", np.log(1e-4), np.log(1e-1)),
        "embedding_type": hp.choice("embedding_type", embedding_choices),
        "arch_type": hp.choice("arch_type", arch_choices),
        "k": scope.int(hp.quniform("k", 16, 32, 8)) if tune_k else 32,
    }

    def obj_fn(hparams):
        return objective(hparams, X, y, device, seed, cv_folds, epochs, batch_size, alpha)

    trials = Trials()
    fmin(fn=obj_fn, space=space, algo=tpe.suggest, max_evals=max_evals,
         trials=trials, rstate=np.random.default_rng(seed))

    best_trial = trials.best_trial
    best_score = -float(best_trial["result"]["loss"])

    # Trial records hold raw values (choice indices, quantised floats).
    # space_eval maps them back through the space to the real values.
    raw_vals = {name: vals[0] for name, vals in best_trial["misc"]["vals"].items() if vals}
    resolved = space_eval(space, raw_vals)
    best_params = {
        "n_blocks": int(resolved["n_blocks"]),
        "d_block": int(resolved["d_block"]),
        "lr": float(resolved["lr"]),
        "weight_decay_choice": int(resolved["weight_decay_choice"]),
        "weight_decay_val": float(resolved["weight_decay_val"]),
        "embedding_type": str(resolved["embedding_type"]),
        "arch_type": str(resolved["arch_type"]),
        "k": int(resolved["k"]),
    }
    return best_params, best_score
|
|
def run_one_pipeline(rep_idx: int,
                     X: np.ndarray,
                     y: np.ndarray,
                     device_str: str,
                     args_dict: dict,
                     out_dir: str,
                     base: str,
                     ext: str) -> str:
    """Run one repetition: hyperparameter search, then final training.

    Intended to run in a worker process. The repetition seed is
    ``args_dict['seed'] + 997 * rep_idx``, so repetitions differ.

    Args:
        rep_idx: Repetition index. Used for the seed and the output file name.
        X: Feature matrix.
        y: Labels.
        device_str: Device string, e.g. 'cpu' or 'cuda:0'.
        args_dict: Settings. Keys: 'seed', 'cv_folds', 'epochs',
            'final_epochs', 'batch_size', 'alpha', 'tune_k', 'max_evals'.
        out_dir: Output directory.
        base: Checkpoint file stem.
        ext: Checkpoint file extension.

    Returns:
        The checkpoint path ``<out_dir>/<base>_rep<rep_idx><ext>``.
    """
    device = torch.device(device_str)
    rep_seed = int(args_dict["seed"]) + 997 * int(rep_idx)
    set_seed(rep_seed)

    # ASCII-only log lines: the previous emoji had been mis-encoded and could
    # raise UnicodeEncodeError on non-UTF-8 consoles. flush=True keeps worker
    # output from sitting in buffers.
    print(f"[rep {rep_idx}] Starting hyperparameter search (max_evals={args_dict['max_evals']}) ...",
          flush=True)
    best_params, best_ap = hyperopt_search(
        X, y, device,
        seed=rep_seed,
        cv_folds=args_dict["cv_folds"],
        epochs=args_dict["epochs"],
        batch_size=args_dict["batch_size"],
        alpha=args_dict["alpha"],
        tune_k=args_dict["tune_k"],
        max_evals=args_dict["max_evals"],
    )
    print(f"[rep {rep_idx}] Best sum_exp_rank={best_ap:.6f}", flush=True)
    print(f"[rep {rep_idx}] Best parameters={best_params}", flush=True)

    out_path = os.path.join(out_dir, f"{base}_rep{rep_idx}{ext}")
    print(f"[rep {rep_idx}] Starting final training and saving to: {out_path}", flush=True)
    train_final(
        X, y, best_params, device,
        final_epochs=args_dict["final_epochs"],
        batch_size=args_dict["batch_size"],
        output_path=out_path,
        seed=rep_seed,
        alpha=args_dict["alpha"],
    )
    return out_path
|
|
def _select_device(choice: str) -> torch.device:
    """Resolve the --device option ('auto', 'cuda' or 'cpu') and report the choice.

    Raises:
        RuntimeError: if 'cuda' is requested but no GPU is available.
    """
    if choice == "cpu":
        print("Using CPU")
        return torch.device('cpu')
    if not torch.cuda.is_available():
        if choice == "cuda":
            raise RuntimeError("CUDA specified but no GPU detected")
        print("Warning: No GPU detected, using CPU")
        return torch.device('cpu')
    if choice == "cuda":
        print(f"Forcing GPU usage: {torch.cuda.get_device_name(0)}")
    else:
        print(f"Detected GPU: {torch.cuda.get_device_name(0)}")
        print(f"  GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")
        print(f"  CUDA Version: {torch.version.cuda}")
    return torch.device('cuda:0')


def main():
    """Command-line entry point.

    Parses the options, loads the training data, and runs
    ``--nr_hyperopt_rep`` independent repetitions (hyperparameter search +
    final training) in spawned worker processes. The saved checkpoint paths
    are printed at the end.
    """
    ap = argparse.ArgumentParser(description="TabM hyperparameter search (Hyperopt) with internal cross-validation, target=sum_exp_rank; training set only, no external validation/test")
    ap.add_argument("--data_file", type=str, default="Neopep_ml_with_labels.txt", help="Training data TSV")
    ap.add_argument("--model_out", type=str, default="tabm_results/tabm_hyperopt_best.pth", help="Final model save path (or base name within directory)")
    ap.add_argument("--max_evals", type=int, default=30, help="Number of Hyperopt evaluations per parallel repetition")
    ap.add_argument("--cv_folds", type=int, default=5, help="Number of cross-validation folds")
    ap.add_argument("--epochs", type=int, default=40, help="Training epochs per fold")
    ap.add_argument("--final_epochs", type=int, default=120, help="Final model training epochs")
    ap.add_argument("--batch_size", type=int, default=256, help="Batch size")
    ap.add_argument("--seed", type=int, default=42, help="Random seed (each repetition will be offset when running in parallel)")
    ap.add_argument("--alpha", type=float, default=0.005, help="Alpha for sum_exp_rank")
    ap.add_argument("--tune_k", action="store_true", help="Whether to search for k together (default fixed at 32)")
    # choices= rejects typos such as 'gpu', which previously fell through to CPU silently.
    ap.add_argument("--device", type=str, default="auto", choices=["auto", "cuda", "cpu"], help="Device selection: auto/cuda/cpu")
    ap.add_argument("--nr_hyperopt_rep", type=int, default=1, help="Parallel repetition count: each independent hyperparameter search + final training")
    args = ap.parse_args()
    if args.nr_hyperopt_rep < 1:
        ap.error("--nr_hyperopt_rep must be at least 1")
    if args.cv_folds < 2:
        ap.error("--cv_folds must be at least 2")

    set_seed(args.seed)
    device = _select_device(args.device)

    X, y = load_training_data(args.data_file)
    print(f"Training data: {X.shape}, Positive sample ratio: {np.mean(y):.5f}")

    out_dir = os.path.dirname(args.model_out) or '.'
    os.makedirs(out_dir, exist_ok=True)
    base = os.path.splitext(os.path.basename(args.model_out))[0]
    ext = os.path.splitext(args.model_out)[1] or '.pth'

    args_dict = {
        "seed": int(args.seed),
        "cv_folds": int(args.cv_folds),
        "epochs": int(args.epochs),
        "final_epochs": int(args.final_epochs),
        "batch_size": int(args.batch_size),
        "alpha": float(args.alpha),
        "tune_k": bool(args.tune_k),
        "max_evals": int(args.max_evals),
    }

    # 'spawn' gives each worker a fresh interpreter. This is required for
    # CUDA, which cannot be re-initialised in a forked child.
    from multiprocessing import get_context
    ctx = get_context('spawn')
    repeats = int(args.nr_hyperopt_rep)
    print(f"Parallel repetitions: {repeats} (each independent hyperparameter search + final training)")

    with ctx.Pool(processes=repeats) as pool:
        paths = pool.starmap(
            run_one_pipeline,
            [(i, X, y, str(device), args_dict, out_dir, base, ext) for i in range(repeats)]
        )
    print("Saved model files:")
    for p in sorted(paths):
        print("-", p)
|
|
if __name__ == "__main__":
    # The guard matters: main() uses the 'spawn' start method, which
    # re-imports this module in every worker. Without the guard, each worker
    # would re-run main().
    main()