"""Bootstrap CIs and residual-based prediction intervals for regression eval.

Operates on a ``val_predictions.csv`` file with columns ``y_true`` and
``y_pred``.  Two CLI modes:

* ``ci``                   -- percentile-bootstrap CIs for the core metrics.
* ``uncertainty_residual`` -- same CIs plus a constant-width residual
  prediction interval (appropriate for linear/homoscedastic models).
"""

import argparse
import json
import os

import numpy as np
import pandas as pd
import xgboost as xgb  # NOTE(review): unused in this file -- confirm no other entry point needs it before removing
from scipy.stats import norm, spearmanr
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
from datasets import load_from_disk, DatasetDict  # NOTE(review): unused in this file -- confirm before removing


def safe_spearmanr(y_true, y_pred):
    """Spearman rho that returns 0.0 instead of None/NaN.

    ``scipy.stats.spearmanr`` yields NaN for constant inputs; mapping that to
    0.0 keeps downstream aggregation (bootstrap means, JSON reports) NaN-free.
    """
    rho = spearmanr(y_true, y_pred).correlation
    return 0.0 if (rho is None or np.isnan(rho)) else float(rho)


def eval_regression(y_true, y_pred):
    """Return {"spearman_rho", "rmse", "mae", "r2"} as plain floats.

    Uses sklearn's ``root_mean_squared_error`` when available (sklearn >= 1.4)
    and falls back to ``sqrt(MSE)`` on older versions.
    """
    try:
        from sklearn.metrics import root_mean_squared_error
        rmse = float(root_mean_squared_error(y_true, y_pred))
    except Exception:  # ImportError on sklearn < 1.4
        rmse = float(np.sqrt(mean_squared_error(y_true, y_pred)))
    return {
        "spearman_rho": safe_spearmanr(y_true, y_pred),
        "rmse": rmse,
        "mae": float(mean_absolute_error(y_true, y_pred)),
        "r2": float(r2_score(y_true, y_pred)),
    }


# ======================== Bootstrap CI =========================================
def bootstrap_ci_reg(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    n_bootstrap: int = 2000,
    ci: float = 0.95,
    seed: int = 1986,
) -> dict:
    """Percentile bootstrap CI for regression metrics.

    Uses the percentile method (not a t-CI) because:
      - Spearman rho is bounded [-1, 1]; a t-CI can produce impossible values
        near the extremes.
      - RMSE is strictly positive; a symmetric t-CI is inappropriate near 0.
      - The percentile bootstrap makes no distributional assumptions.

    A Fisher z-transform CI for rho is also computed as a cross-check.

    Args:
        y_true: ground-truth values, shape (n,).
        y_pred: predicted values, shape (n,).
        n_bootstrap: number of bootstrap resamples.
        ci: two-sided confidence level (e.g. 0.95).
        seed: RNG seed for reproducible resampling.

    Returns:
        Dict keyed by metric name; each value holds mean/std/ci_low/ci_high/
        report/n_bootstrap.  Also contains "n_samples", and the
        "spearman_rho" entry carries an extra "fisher_z_ci" sub-dict.
    """
    rng = np.random.default_rng(seed=seed)
    n = len(y_true)
    alpha = 1 - ci
    lo, hi = alpha / 2, 1 - alpha / 2

    boot_metrics = {k: [] for k in ["spearman_rho", "rmse", "mae", "r2"]}
    for _ in range(n_bootstrap):
        idx = rng.integers(0, n, size=n)
        yt, yp = y_true[idx], y_pred[idx]
        # Skip degenerate resamples: constant y_true makes rho/r2 undefined.
        if len(np.unique(yt)) < 2:
            continue
        m = eval_regression(yt, yp)
        for k in boot_metrics:
            boot_metrics[k].append(m[k])

    results = {}
    for name, arr in boot_metrics.items():
        arr = np.array(arr)
        results[name] = {
            "mean": float(arr.mean()),
            "std": float(arr.std()),
            "ci_low": float(np.quantile(arr, lo)),
            "ci_high": float(np.quantile(arr, hi)),
            "report": f"{arr.mean():.4f} [{np.quantile(arr, lo):.4f}, {np.quantile(arr, hi):.4f}]",
            "n_bootstrap": len(arr),
        }

    # Fisher z-transform CI for Spearman rho (cross-check, more accurate near ±1)
    rho_obs = safe_spearmanr(y_true, y_pred)
    # z-transform: arctanh(rho), SE = 1/sqrt(n-3); clip rho so arctanh is finite
    z = np.arctanh(np.clip(rho_obs, -0.9999, 0.9999))
    se_z = 1.0 / np.sqrt(max(n - 3, 1))
    # Respect the requested `ci` level (the original hard-coded 1.96 even for
    # non-default levels).  Keep the literal 1.96 at the default so reports
    # stay byte-identical to previously saved results.
    z_crit = 1.96 if ci == 0.95 else float(norm.ppf(1 - alpha / 2))
    z_lo = z - z_crit * se_z
    z_hi = z + z_crit * se_z
    results["spearman_rho"]["fisher_z_ci"] = {
        "ci_low": float(np.tanh(z_lo)),
        "ci_high": float(np.tanh(z_hi)),
        "report": f"[{np.tanh(z_lo):.4f}, {np.tanh(z_hi):.4f}]",
        "note": "Fisher z-transform CI - more accurate when rho > 0.9",
    }
    results["n_samples"] = int(n)
    return results


def residual_uncertainty(val_preds_df: pd.DataFrame, coverage: float = 0.95) -> tuple[pd.DataFrame, dict]:
    """Constant-width Gaussian prediction intervals from validation residuals.

    - Assumes residuals ~ N(0, sigma) where sigma = std(residuals).
    - Prediction interval for molecule i: y_pred_i ± z * sigma.
    - The uncertainty score (sigma) is constant across all molecules, which
      is only appropriate for linear/homoscedastic models.

    Args:
        val_preds_df: DataFrame with columns "y_true" and "y_pred".
        coverage: nominal two-sided coverage; 0.90, 0.95 and 0.99 are
            supported, anything else falls back to the 0.95 z-score.

    Returns:
        (df, meta): ``df`` is a copy of the input with interval columns
        added; ``meta`` summarizes sigma, half-width and empirical coverage.
        (The original annotation claimed a bare DataFrame; it always returned
        this tuple.)
    """
    df = val_preds_df.copy()
    residuals = df["y_true"] - df["y_pred"]
    sigma = float(residuals.std(ddof=1))  # sample std (n-1 denominator)
    z = {0.90: 1.645, 0.95: 1.960, 0.99: 2.576}.get(coverage, 1.960)
    half_width = z * sigma
    df["pred_interval_low"] = df["y_pred"] - half_width
    df["pred_interval_high"] = df["y_pred"] + half_width
    df["pred_interval_width"] = 2 * half_width  # constant for linear models
    df["abs_error"] = residuals.abs()
    # What fraction of y_true actually falls inside the nominal interval.
    empirical_coverage = float(
        ((df["y_true"] >= df["pred_interval_low"]) & (df["y_true"] <= df["pred_interval_high"])).mean()
    )
    meta = {
        "residual_std": round(sigma, 6),
        "interval_halfwidth": round(half_width, 6),
        "nominal_coverage": coverage,
        "empirical_coverage": round(empirical_coverage, 4),
        "note": (
            "Prediction interval assumes N(0, sigma) residuals. "
            "Interval width is constant across molecules for linear models. "
        ),
    }
    return df, meta


def save_ci_report(ci_results: dict, out_dir: str, model_name: str = ""):
    """Write bootstrap CI results to ``bootstrap_ci_reg.json`` and print a summary.

    Args:
        ci_results: output of :func:`bootstrap_ci_reg`.
        out_dir: directory for the JSON report (created if missing).
        model_name: label used only in the printed header.
    """
    os.makedirs(out_dir, exist_ok=True)
    path = os.path.join(out_dir, "bootstrap_ci_reg.json")
    with open(path, "w") as f:
        json.dump(ci_results, f, indent=2)

    print(f"\n=== Bootstrap 95% CI - Regression ({model_name}) ===")
    for metric in ["spearman_rho", "rmse", "mae", "r2"]:
        r = ci_results[metric]
        print(f"  {metric:15s}: {r['report']}")
        if metric == "spearman_rho" and "fisher_z_ci" in r:
            fz = r["fisher_z_ci"]
            print(f"    Fisher z CI  : {fz['report']}  ← use this if rho > 0.9")
    print(f"  n_val={ci_results['n_samples']}, n_bootstrap={ci_results['spearman_rho']['n_bootstrap']}")
    print(f"Saved to {path}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--mode", required=True, choices=["ci", "uncertainty_residual"],
        help=(
            "ci : bootstrap CI from val_predictions.csv\n"
            "uncertainty_residual: residual interval for ElasticNet/SVR"
        ),
    )
    parser.add_argument("--val_preds", type=str, help="Path to val_predictions.csv")
    parser.add_argument("--out_dir", type=str, required=True)
    parser.add_argument("--model_name", type=str, default="")
    parser.add_argument("--n_bootstrap", type=int, default=2000)
    args = parser.parse_args()

    # Both modes need --val_preds; use parser.error instead of assert so the
    # check survives `python -O` and prints proper usage.
    if not args.val_preds:
        parser.error("--val_preds required")

    # Common to both modes: bootstrap CIs over the validation predictions.
    df_preds = pd.read_csv(args.val_preds)
    ci_results = bootstrap_ci_reg(
        df_preds["y_true"].values,
        df_preds["y_pred"].values,
        n_bootstrap=args.n_bootstrap,
    )
    save_ci_report(ci_results, args.out_dir, args.model_name)

    if args.mode == "uncertainty_residual":
        df_unc, meta = residual_uncertainty(df_preds)
        path = os.path.join(args.out_dir, "val_uncertainty_residual.csv")
        df_unc.to_csv(path, index=False)
        meta_path = os.path.join(args.out_dir, "residual_interval_meta.json")
        with open(meta_path, "w") as f:
            json.dump(meta, f, indent=2)
        print("\nResidual interval summary:")
        print(f"  Residual std        : {meta['residual_std']:.4f}")
        print(f"  95% interval        ± {meta['interval_halfwidth']:.4f}")
        print(f"  Empirical coverage  : {meta['empirical_coverage']:.4f} (nominal={meta['nominal_coverage']})")
        print(f"  Saved to {path}")