| |
| |
|
|
| import argparse |
| import os |
| import numpy as np |
| import pandas as pd |
| import torch |
| import tabm |
| from sklearn.metrics import precision_recall_curve, auc |
|
|
def normalize_rt(s: pd.Series) -> pd.Series:
    """Return *s* coerced to string, whitespace-trimmed and upper-cased.

    Used to make 'response_type' comparisons (CD8 / NEGATIVE / ...) robust
    against casing and stray whitespace in the input TSV.
    """
    as_text = s.astype(str)
    return as_text.str.strip().str.upper()
|
|
def compute_patient_metrics(df_p: pd.DataFrame, y_prob: np.ndarray) -> tuple:
    """Rank one patient's peptides by predicted score and compute top-k metrics.

    Args:
        df_p: Peptide rows for a single patient; must contain the columns
            'response_type', 'mutant_seq' and 'gene'.
        y_prob: Predicted positive-class probability per row of ``df_p``
            (same length and order).

    Returns:
        Tuple of:
        (predictions sorted descending, DataFrame sorted by prediction,
         nr_correct_top20, nr_tested_top20,
         nr_correct_top50, nr_tested_top50,
         nr_correct_top100, nr_tested_top100,
         nr_immunogenic, 0-based ranks of immunogenic peptides,
         exponential rank score,
         comma-joined CD8 ranks / mutant sequences / genes).
    """
    X_r = df_p.copy()
    X_r['ML_pred'] = y_prob
    # Ground truth: a peptide counts as immunogenic iff response_type == CD8.
    X_r['response'] = (normalize_rt(X_r['response_type']) == 'CD8').astype(int)

    # Order peptides from highest to lowest predicted probability.
    X_r = X_r.sort_values(by=['ML_pred'], ascending=False).reset_index(drop=True)

    # 0-based ranks of immunogenic and tested-negative peptides
    # (np.where returns indices in ascending order).
    idx_pos = np.where(X_r['response'].to_numpy() == 1)[0]
    idx_tested = np.where(normalize_rt(X_r['response_type']) == 'NEGATIVE')[0]

    def _topk_counts(k: int) -> tuple:
        """Count immunogenic hits and experimentally tested peptides in top-k."""
        k_eff = min(k, len(X_r))
        nr_correct = int(np.sum(idx_pos < k_eff))
        # "Tested" = immunogenic hits plus tested negatives within the top-k.
        nr_tested = nr_correct + int(np.sum(idx_tested < k_eff))
        return nr_correct, nr_tested

    nr_correct20, nr_tested20 = _topk_counts(20)
    nr_correct50, nr_tested50 = _topk_counts(50)
    nr_correct100, nr_tested100 = _topk_counts(100)

    nr_immuno = int(idx_pos.size)

    # Exponentially decaying reward per immunogenic peptide's rank; alpha
    # controls how quickly low-ranked hits stop contributing to the score.
    alpha = 0.005
    score = float(np.sum(np.exp(-alpha * idx_pos)))

    if nr_immuno > 0:
        # idx_pos is already ascending and boolean .loc selection preserves
        # row order, so ranks, sequences and genes stay aligned without any
        # extra argsort (the original's argsort on idx_pos was the identity).
        ranks_str = ",".join(f"{int(r + 1)}" for r in idx_pos)
        pos_rows = X_r.loc[X_r['response'] == 1]
        mut_seqs_str = ",".join(str(s) for s in pos_rows['mutant_seq'])
        genes_str = ",".join(str(g) for g in pos_rows['gene'])
    else:
        ranks_str = ""
        mut_seqs_str = ""
        genes_str = ""

    return (X_r['ML_pred'].to_numpy(), X_r,
            nr_correct20, nr_tested20,
            nr_correct50, nr_tested50,
            nr_correct100, nr_tested100,
            nr_immuno, idx_pos, score,
            ranks_str, mut_seqs_str, genes_str)
|
|
|
|
def predict_in_batches(model, X_all, device, batch_size=1024):
    """Run *model* over *X_all* in mini-batches and return positive-class
    probabilities as a 1-D numpy array.

    The model is expected to emit (batch, k, 2) logits; the k ensemble
    members are averaged before the softmax. Batches are moved to *device*
    one at a time to bound GPU memory usage.
    """
    model.eval()
    collected = []
    total = len(X_all)

    with torch.inference_mode():
        start = 0
        while start < total:
            stop = min(start + batch_size, total)
            chunk = X_all[start:stop].to(device)

            logits = model(chunk).mean(1)  # average the k ensemble heads
            probs = torch.softmax(logits, dim=1)[:, 1]

            collected.append(probs.cpu())

            # Free per-batch tensors eagerly to keep peak memory low.
            del chunk, logits, probs
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            start = stop

    return torch.cat(collected, dim=0).numpy()
|
|
def main():
    """CLI entry point.

    Loads one or more TabM checkpoints, runs an equal-weighted ensemble
    prediction over a TSV of neo-peptides, and appends per-patient top-k
    metrics (and optionally TESLA TTIF/FR/AUPRC scores) to the output files.
    """

    # ---- Command-line arguments --------------------------------------------
    ap = argparse.ArgumentParser(description="TabM model evaluation, output format consistent with TestVotingClassifier")
    ap.add_argument("--model_file", type=str, required=False, help="TabM model file, e.g. tabm_results/tabm_model.pth (mutually exclusive with --model_files/--model_glob, choose one of three)")
    ap.add_argument("--model_files", type=str, nargs='*', default=None, help="Multiple model files for equal-weighted average prediction")
    ap.add_argument("--model_glob", type=str, default=None, help="Use wildcards to match multiple model files (e.g. tabm_results/tabm_hyperopt_best_rep*.pth)")
    ap.add_argument("--data_file", type=str, required=True, help="Input TSV: TestVoting_selection_neopep.tsv")
    ap.add_argument("--output_file", type=str, required=True, help="Main result output file (header consistent with original)")
    ap.add_argument("--tesla_file", type=str, default=None, help="TESLA score output file (for neopep task)")
    ap.add_argument("--output_xlsx", type=str, default=None, help="Main result Excel output path (optional)")
    ap.add_argument("--tesla_xlsx", type=str, default=None, help="TESLA result Excel output path (optional)")
    ap.add_argument("--dataset_name", type=str, default=None, help="If no dataset column exists, use this value as Dataset column in TESLA")
    ap.add_argument("--skip_no_cd8", action="store_true", help="Skip patients without CD8")
    ap.add_argument("--device", type=str, default="auto", choices=["auto", "cuda", "cpu"],
                    help="Device selection: auto/cuda/cpu")
    ap.add_argument("--batch_size", type=int, default=1024,
                    help="Batch size to avoid GPU memory overflow (default 1024)")
    args = ap.parse_args()

    # ---- Device selection ----------------------------------------------------
    # "auto" prefers CUDA when available; "cuda" fails loudly when absent.
    if args.device == "auto":
        if torch.cuda.is_available():
            device = torch.device('cuda:0')
            print(f"🚀 Auto-selected GPU: {torch.cuda.get_device_name(0)}")
            print(f" GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")
        else:
            device = torch.device('cpu')
            print("⚠️ No GPU detected, using CPU")
    elif args.device == "cuda":
        if torch.cuda.is_available():
            device = torch.device('cuda:0')
            print(f"🚀 Force using GPU: {torch.cuda.get_device_name(0)}")
        else:
            raise RuntimeError("CUDA specified but no GPU detected")
    else:
        device = torch.device('cpu')
        print("️ Using CPU")

    print(f" Batch size: {args.batch_size}")

    # ---- Load input data -----------------------------------------------------
    df = pd.read_csv(args.data_file, sep="\t", header=0, low_memory=False)
    print(f"📈 Data shape: {df.shape}")

    # Metadata columns that must exist; they are excluded from the features.
    required_cols = ["patient", "response_type", "gene", "mutant_seq"]
    for c in required_cols:
        if c not in df.columns:
            raise KeyError(f"Missing required column: {c}")

    # Every remaining column is treated as a numeric feature.
    # NOTE(review): a 'dataset' column (read again below for TESLA output)
    # would also be coerced here (non-numeric -> NaN -> 0.0) — confirm intended.
    feature_cols = [c for c in df.columns if c not in required_cols]

    X_all = df[feature_cols].apply(pd.to_numeric, errors='coerce').fillna(0.0).to_numpy()
    print(f" Number of features: {X_all.shape[1]}")

    # ---- Resolve model checkpoint paths --------------------------------------
    # Priority: explicit --model_files, then --model_glob matches (sorted),
    # then the single --model_file as a fallback.
    import glob as _glob
    model_paths: list[str] = []
    if args.model_files:
        model_paths.extend(list(args.model_files))
    if args.model_glob:
        model_paths.extend(sorted(_glob.glob(args.model_glob)))
    if not model_paths and args.model_file:
        model_paths = [args.model_file]
    if not model_paths:
        raise FileNotFoundError("No model files found, please check!")

    # The first checkpoint supplies the hyperparameters printed below.
    # weights_only=False: the checkpoint stores a pickled args namespace,
    # so only load checkpoints from trusted sources.
    first_ckpt = torch.load(model_paths[0], map_location='cpu', weights_only=False)
    model_args = first_ckpt['args']

    def _predict_with_model(model_path: str, X_all_np: np.ndarray) -> np.ndarray:
        """Load one checkpoint, rebuild its TabM model, and return the
        positive-class probability for every row of *X_all_np*."""
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"Model file not existed: {model_path}")
        ckpt = torch.load(model_path, map_location='cpu', weights_only=False)
        m_args = ckpt['args']
        X_np = X_all_np
        # Optional per-checkpoint feature subset (column indices into X).
        if ckpt.get("used_feature_idx") is not None:
            try:
                ufi = ckpt["used_feature_idx"]
                import numpy as _np
                ufi_arr = _np.array(ufi, dtype=int)
                max_idx = X_np.shape[1] - 1
                # Silently drop out-of-range indices instead of failing.
                ufi_arr = ufi_arr[(ufi_arr >= 0) & (ufi_arr <= max_idx)]
                if len(ufi_arr) > 0:
                    X_np = X_np[:, ufi_arr]
            except Exception:
                pass  # best effort: fall back to the full feature matrix
        X_tensor_cpu = torch.as_tensor(X_np, dtype=torch.float32)
        # Rebuild the numeric-embedding module used at training time.
        num_embeddings = None
        if getattr(m_args, 'use_embeddings', False):
            if m_args.embedding_type == 'linear':
                import rtdl_num_embeddings
                num_embeddings = rtdl_num_embeddings.LinearReLUEmbeddings(X_tensor_cpu.shape[1])
            elif m_args.embedding_type == 'periodic':
                import rtdl_num_embeddings
                num_embeddings = rtdl_num_embeddings.PeriodicEmbeddings(X_tensor_cpu.shape[1], lite=False)
            elif m_args.embedding_type == 'piecewise':
                import rtdl_num_embeddings
                # NOTE(review): bins are recomputed on the evaluation data, not
                # restored from training — confirm this matches the training setup.
                num_embeddings = rtdl_num_embeddings.PiecewiseLinearEmbeddings(
                    rtdl_num_embeddings.compute_bins(X_tensor_cpu, n_bins=48),
                    d_embedding=16,
                    activation=False,
                    version='B',
                )
        # Binary classifier head (d_out=2); no categorical features.
        model = tabm.TabM.make(
            n_num_features=X_tensor_cpu.shape[1],
            cat_cardinalities=[],
            d_out=2,
            k=m_args.k,
            n_blocks=m_args.n_blocks,
            d_block=m_args.d_block,
            num_embeddings=num_embeddings,
            arch_type=getattr(m_args, 'arch_type', 'tabm'),
        )
        model.load_state_dict(ckpt['model_state_dict'])
        model.to(device)
        model.eval()
        # Never predict with batches smaller than 256, even if requested.
        bs = max(256, args.batch_size)
        probs_list = []
        n = len(X_tensor_cpu)
        with torch.inference_mode():
            for i in range(0, n, bs):
                j = min(i + bs, n)
                xb = X_tensor_cpu[i:j].to(device)
                logits = model(xb).mean(1)  # average the k TabM ensemble heads
                probs = torch.softmax(logits, dim=1)[:, 1].detach().cpu().numpy()
                probs_list.append(probs)
                # Free per-batch tensors eagerly to bound GPU memory usage.
                del xb, logits
                if torch.cuda.is_available() and device.type == 'cuda':
                    torch.cuda.empty_cache()
                # Progress heartbeat every 50 batches.
                if (i // bs) % 50 == 0:
                    print(f" batch {i//bs}/{(n+bs-1)//bs}")
        return np.concatenate(probs_list, axis=0)

    def _stringify(v):
        """Best-effort printable representation for hyperparameter dumps."""
        try:
            return repr(v)
        except Exception:
            try:
                return str(v)
            except Exception:
                return "<unprintable>"

    # ---- Dump hyperparameters stored in the first checkpoint -----------------
    print("===== Saved Hyperparameters from checkpoint['args'] =====")
    if hasattr(model_args, "__dict__"):
        hp_items = sorted(vars(model_args).items())
    elif isinstance(model_args, dict):
        hp_items = sorted(model_args.items())
    else:
        try:
            hp_items = sorted(model_args.__dict__.items())
        except Exception:
            hp_items = []
            print("⚠️ Unable to enumerate contents of model_args")
    for key, val in hp_items:
        print(f"- {key}: {_stringify(val)}")
    print("=========================================================")

    def _p_dict(title, d):
        """Print dict *d* as '- key: value' lines under *title* (best effort)."""
        try:
            print(title)
            for k in sorted(d.keys()):
                try:
                    print(f"- {k}: {repr(d[k])}")
                except Exception:
                    print(f"- {k}: <unprintable>")
            print("=" * len(title))
        except Exception:
            pass

    # Optional extra metadata blocks stored by the training script.
    if isinstance(first_ckpt.get("training_args"), dict):
        _p_dict("===== checkpoint['training_args'] =====", first_ckpt["training_args"])

    if isinstance(first_ckpt.get("best_params"), dict):
        _p_dict("===== checkpoint['best_params'] =====", first_ckpt["best_params"])

    if isinstance(first_ckpt.get("full_args"), dict):
        _p_dict("===== checkpoint['full_args'] =====", first_ckpt["full_args"])

    if first_ckpt.get("used_feature_idx") is not None:
        try:
            ufi = first_ckpt["used_feature_idx"]
            print("===== used_feature_idx =====")
            print(f"- length: {len(ufi)}")
            print(f"- head: {list(ufi[:10])}")
            print("=" * 25)
        except Exception:
            print("===== used_feature_idx =====\n<unprintable>\n============================")

    # ---- Runtime environment summary (best effort, never fatal) --------------
    try:
        print("===== Environment =====")
        print(f"- torch: {torch.__version__}")
        print(f"- cuda available: {torch.cuda.is_available()}")
        if torch.cuda.is_available():
            print(f"- device: {torch.cuda.get_device_name(0)}")
            print(f"- cuda version: {torch.version.cuda}")
        import tabm as _tabm_mod
        print(f"- tabm: {getattr(_tabm_mod, '__version__', 'unknown')}")
        print("========================")
    except Exception:
        pass

    # ---- Equal-weighted ensemble over all checkpoints -------------------------
    n_models = len(model_paths)
    print(f"🔗 Loading {n_models} models for equal-weighted average prediction...")
    y_prob_all = None
    for mp in model_paths:
        print(f" -> {mp}")
        probs = _predict_with_model(mp, X_all)
        if y_prob_all is None:
            y_prob_all = probs.astype(np.float64)
        else:
            y_prob_all += probs
    y_prob_all = (y_prob_all / float(n_models)).astype(np.float64)

    print(f"✅ Prediction completed, total {len(y_prob_all)} samples; number of models={n_models}")

    rows_main = []   # per-patient rows for the optional main Excel export
    rows_tesla = []  # per-patient TESLA rows for the optional Excel export

    # Append to the output file; write the header only if it is new/empty.
    need_header = (not os.path.exists(args.output_file)) or (os.path.getsize(args.output_file) == 0)
    with open(args.output_file, "a", encoding="utf-8") as f:
        if need_header:
            f.write("Patient\tNr_correct_top20\tNr_tested_top20\tNr_correct_top50\tNr_tested_top50\t"
                    "Nr_correct_top100\tNr_tested_top100\tNr_immunogenic\tNr_peptides\tClf_score\t"
                    "CD8_ranks\tCD8_mut_seqs\tCD8_genes\n")

        # Per-patient evaluation, preserving input file order (sort=False).
        for patient, df_p in df.groupby("patient", sort=False):
            has_cd8 = (normalize_rt(df_p["response_type"]) == "CD8").any()
            if args.skip_no_cd8 and not has_cd8:
                continue

            # df comes straight from read_csv, so its index is the default
            # RangeIndex and can be used positionally into y_prob_all.
            idx = df_p.index.to_numpy()
            y_prob = y_prob_all[idx]

            (y_pred_sorted, X_sorted,
             nr_correct20, nr_tested20,
             nr_correct50, nr_tested50,
             nr_correct100, nr_tested100,
             nr_immuno, r, score,
             ranks_str, mut_seqs_str, genes_str) = compute_patient_metrics(df_p, y_prob)

            f.write(f"{patient}\t{nr_correct20}\t{nr_tested20}\t{nr_correct50}\t{nr_tested50}\t"
                    f"{nr_correct100}\t{nr_tested100}\t{nr_immuno}\t{len(df_p)}\t{score:.6f}\t"
                    f"{ranks_str}\t{mut_seqs_str}\t{genes_str}\n")

            rows_main.append({
                "Patient": patient,
                "Nr_correct_top20": nr_correct20,
                "Nr_tested_top20": nr_tested20,
                "Nr_correct_top50": nr_correct50,
                "Nr_tested_top50": nr_tested50,
                "Nr_correct_top100": nr_correct100,
                "Nr_tested_top100": nr_tested100,
                "Nr_immunogenic": nr_immuno,
                "Nr_peptides": len(df_p),
                "Clf_score": score,
                "CD8_ranks": ranks_str,
                "CD8_mut_seqs": mut_seqs_str,
                "CD8_genes": genes_str,
            })

            # ---- Optional TESLA metrics (TTIF / FR / AUPRC) -----------------
            if args.tesla_file or args.tesla_xlsx:
                if "dataset" in df_p.columns:
                    dataset_val = str(df_p["dataset"].iloc[0])
                else:
                    dataset_val = args.dataset_name if args.dataset_name is not None else ""
                # NOTE(review): this comparison is case-sensitive, unlike the
                # normalize_rt() handling elsewhere — confirm 'not_tested' casing
                # in the input file.
                idx_nt = X_sorted['response_type'].astype(str) != 'not_tested'
                y_pred_tesla = pd.Series(y_pred_sorted)[idx_nt].to_numpy()
                y_tesla = X_sorted.loc[idx_nt, 'response'].to_numpy()
                # TTIF: fraction of tested top-20 peptides that are immunogenic.
                ttif = (nr_correct20 / nr_tested20) if nr_tested20 > 0 else 0.0
                # FR: fraction of all immunogenic peptides recovered in top-100.
                fr = (nr_correct100 / nr_immuno) if nr_immuno > 0 else 0.0
                precision, recall, _ = precision_recall_curve(y_tesla, y_pred_tesla)
                auprc = auc(recall, precision)

                if args.tesla_file:
                    # Append mode with header only when the file is new/empty.
                    new_tesla = (not os.path.exists(args.tesla_file)) or (os.path.getsize(args.tesla_file) == 0)
                    with open(args.tesla_file, "a", encoding="utf-8") as tf:
                        if new_tesla:
                            tf.write("Dataset\tPatient\tTTIF\tFR\tAUPRC\n")
                        tf.write(f"{dataset_val}\t{patient}\t{ttif:.3f}\t{fr:.3f}\t{auprc:.3f}\n")

                rows_tesla.append({
                    "Dataset": dataset_val,
                    "Patient": patient,
                    "TTIF": ttif,
                    "FR": fr,
                    "AUPRC": auprc,
                })

    # ---- Optional Excel exports ----------------------------------------------
    if args.output_xlsx and rows_main:
        os.makedirs(os.path.dirname(args.output_xlsx) or '.', exist_ok=True)
        pd.DataFrame(rows_main).to_excel(args.output_xlsx, index=False)
    if args.tesla_xlsx and rows_tesla:
        os.makedirs(os.path.dirname(args.tesla_xlsx) or '.', exist_ok=True)
        pd.DataFrame(rows_tesla).to_excel(args.tesla_xlsx, index=False)

    print(f" Evaluation completed! Processed {len(rows_main)} patients")


if __name__ == "__main__":
    main()