# Joblib
# PeptiVerse / training_classifiers/ml_uncertainty_reg.py
# ynuozhang
# major update
# 04c2975
import os
import json
import argparse
import numpy as np
import pandas as pd
import xgboost as xgb
from scipy.stats import spearmanr
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
from datasets import load_from_disk, DatasetDict
def safe_spearmanr(y_true, y_pred):
    """Return Spearman's rank correlation as a plain float.

    scipy yields NaN when the correlation is undefined (e.g. one input is
    constant); that case is mapped to 0.0 so downstream aggregation never
    sees NaN.
    """
    result = spearmanr(y_true, y_pred)
    rho = result.correlation
    if rho is not None and not np.isnan(rho):
        return float(rho)
    return 0.0
def eval_regression(y_true, y_pred):
    """Compute point-estimate regression metrics for one set of predictions.

    Args:
        y_true: Observed targets (array-like).
        y_pred: Predicted targets, same length as y_true.

    Returns:
        dict of floats: "spearman_rho" (via safe_spearmanr, so 0.0 when
        undefined), "rmse", "mae" and "r2".
    """
    # root_mean_squared_error was added in scikit-learn 1.4; older releases
    # fall back to sqrt(MSE). Only the import is guarded (and only ImportError
    # is caught) so genuine input errors raised by sklearn are not masked.
    try:
        from sklearn.metrics import root_mean_squared_error
    except ImportError:
        rmse = float(np.sqrt(mean_squared_error(y_true, y_pred)))
    else:
        rmse = float(root_mean_squared_error(y_true, y_pred))
    return {
        "spearman_rho": safe_spearmanr(y_true, y_pred),
        "rmse": rmse,
        "mae": float(mean_absolute_error(y_true, y_pred)),
        "r2": float(r2_score(y_true, y_pred)),
    }
# ======================== Bootstrap CI =========================================
def bootstrap_ci_reg(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    n_bootstrap: int = 2000,
    ci: float = 0.95,
    seed: int = 1986,
) -> dict:
    """
    Percentile bootstrap CI for regression metrics.

    Uses percentile method (not t-CI) because:
    - Spearman rho is bounded [-1, 1] - t-CI can produce impossible values near extremes
    - RMSE is strictly positive - symmetric t-CI is inappropriate near 0
    - Percentile bootstrap makes no distributional assumptions
    Fisher z-transform CI for rho is also computed as a cross-check, at the
    same confidence level ``ci``.

    Args:
        y_true: Observed targets (any 1-D array-like; converted with np.asarray).
        y_pred: Predictions, same length as y_true.
        n_bootstrap: Number of resamples drawn. Resamples whose y_true has fewer
            than 2 distinct values are skipped, so the number actually used
            (reported per metric as "n_bootstrap") can be smaller.
        ci: Two-sided confidence level, strictly between 0 and 1.
        seed: Seed for numpy.random.default_rng (reproducible resampling).

    Returns:
        dict with one entry per metric ("spearman_rho", "rmse", "mae", "r2"),
        each holding mean / std / ci_low / ci_high / report / n_bootstrap; the
        "spearman_rho" entry also holds "fisher_z_ci". Top-level "n_samples"
        is len(y_true) and "ci_level" is ``ci``. If no resample was usable the
        per-metric statistics are NaN and their n_bootstrap is 0.

    Raises:
        ValueError: if y_true is empty, lengths differ, or ci is not in (0, 1).
    """
    # asarray makes y_true[idx] positional even for pandas Series / lists.
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    n = len(y_true)
    if n == 0:
        raise ValueError("bootstrap_ci_reg: y_true is empty")
    if len(y_pred) != n:
        raise ValueError(
            f"bootstrap_ci_reg: len(y_true)={n} != len(y_pred)={len(y_pred)}"
        )
    if not 0.0 < ci < 1.0:
        raise ValueError(f"bootstrap_ci_reg: ci must be in (0, 1), got {ci}")

    rng = np.random.default_rng(seed=seed)
    alpha = 1 - ci
    lo, hi = alpha / 2, 1 - alpha / 2

    metric_names = ("spearman_rho", "rmse", "mae", "r2")
    boot_metrics = {k: [] for k in metric_names}
    for _ in range(n_bootstrap):
        idx = rng.integers(0, n, size=n)
        yt, yp = y_true[idx], y_pred[idx]
        # A resample with constant y_true leaves Spearman/R^2 undefined; skip it.
        if len(np.unique(yt)) < 2:
            continue
        m = eval_regression(yt, yp)
        for k in metric_names:
            boot_metrics[k].append(m[k])

    results = {}
    for name, values in boot_metrics.items():
        arr = np.asarray(values, dtype=float)
        if arr.size > 0:
            mean = float(arr.mean())
            std = float(arr.std())
            q_lo = float(np.quantile(arr, lo))
            q_hi = float(np.quantile(arr, hi))
        else:
            # Every resample was skipped (e.g. constant y_true, or n == 1);
            # np.quantile raises on an empty array, so report NaN instead.
            mean = std = q_lo = q_hi = float("nan")
        results[name] = {
            "mean": mean,
            "std": std,
            "ci_low": q_lo,
            "ci_high": q_hi,
            "report": f"{mean:.4f} [{q_lo:.4f}, {q_hi:.4f}]",
            "n_bootstrap": int(arr.size),
        }

    # Fisher z-transform CI for Spearman rho (cross-check, more accurate near ±1),
    # computed at the same confidence level as the bootstrap intervals.
    rho_obs = safe_spearmanr(y_true, y_pred)
    # z-transform: arctanh(rho), SE = 1/sqrt(n-3).
    # NOTE(review): 1/sqrt(n-3) is the Pearson-r standard error; for Spearman
    # rho, Fieller, Hartley & Pearson (1957) suggest sqrt(1.06/(n-3)), so this
    # interval may be slightly too narrow.
    z = np.arctanh(np.clip(rho_obs, -0.9999, 0.9999))
    se_z = 1.0 / np.sqrt(max(n - 3, 1))
    # Tabulated critical values keep the common levels identical to earlier
    # reports; any other level uses the exact standard-normal quantile.
    z_crit = {0.90: 1.645, 0.95: 1.960, 0.99: 2.576}.get(ci)
    if z_crit is None:
        from scipy.stats import norm
        z_crit = float(norm.ppf(hi))
    z_lo = z - z_crit * se_z
    z_hi = z + z_crit * se_z
    results["spearman_rho"]["fisher_z_ci"] = {
        "ci_low": float(np.tanh(z_lo)),
        "ci_high": float(np.tanh(z_hi)),
        "report": f"[{np.tanh(z_lo):.4f}, {np.tanh(z_hi):.4f}]",
        "note": "Fisher z-transform CI - more accurate when rho > 0.9",
    }
    results["n_samples"] = int(n)
    results["ci_level"] = float(ci)
    return results
def residual_uncertainty(val_preds_df: pd.DataFrame, coverage: float = 0.95) -> "tuple[pd.DataFrame, dict]":
    """
    Dataset-level prediction intervals from validation residuals.

    - Assume residuals ~ N(0, sigma) where sigma = std(residuals) (sample std, ddof=1)
    - Prediction interval for molecule i: y_pred_i ± z * sigma, where z is the
      two-sided standard-normal critical value for ``coverage``
    - sigma is a single dataset-level value, so every row gets the same width

    Args:
        val_preds_df: DataFrame with "y_true" and "y_pred" columns (not modified).
        coverage: Nominal two-sided coverage, strictly between 0 and 1.
            0.90 / 0.95 / 0.99 use the tabulated z = 1.645 / 1.960 / 2.576;
            any other level uses the exact normal quantile.

    Returns:
        (df, meta): a copy of val_preds_df with added columns
        pred_interval_low, pred_interval_high, pred_interval_width and
        abs_error; and a dict with residual_std, interval_halfwidth,
        nominal_coverage, empirical_coverage and a note.

    Raises:
        ValueError: if coverage is not strictly between 0 and 1.
    """
    if not 0.0 < coverage < 1.0:
        raise ValueError(f"coverage must be in (0, 1), got {coverage}")
    # Previously an unsupported coverage silently fell back to z = 1.96.
    z = {0.90: 1.645, 0.95: 1.960, 0.99: 2.576}.get(coverage)
    if z is None:
        from scipy.stats import norm
        z = float(norm.ppf(0.5 + coverage / 2))

    df = val_preds_df.copy()
    residuals = df["y_true"] - df["y_pred"]
    sigma = float(residuals.std(ddof=1))
    half_width = z * sigma
    df["pred_interval_low"] = df["y_pred"] - half_width
    df["pred_interval_high"] = df["y_pred"] + half_width
    df["pred_interval_width"] = 2 * half_width  # same value on every row
    df["abs_error"] = residuals.abs()
    # what fraction of y_true actually falls inside the interval
    empirical_coverage = float(
        ((df["y_true"] >= df["pred_interval_low"]) &
         (df["y_true"] <= df["pred_interval_high"])).mean()
    )
    meta = {
        "residual_std": round(sigma, 6),
        "interval_halfwidth": round(half_width, 6),
        "nominal_coverage": coverage,
        "empirical_coverage": round(empirical_coverage, 4),
        "note": (
            "Prediction interval assumes N(0, sigma) residuals. "
            "Interval width is constant across molecules for linear models."
        ),
    }
    return df, meta
def save_ci_report(ci_results: dict, out_dir: str, model_name: str = ""):
    """Write bootstrap CI results to <out_dir>/bootstrap_ci_reg.json and print a summary.

    Args:
        ci_results: Result dict as produced by bootstrap_ci_reg. Must contain
            "spearman_rho", "rmse", "mae", "r2" entries each with a "report"
            string, "n_bootstrap" under "spearman_rho", and "n_samples". An
            optional "ci_level" sets the level shown in the heading
            (defaults to 0.95, the previous hard-coded label).
        out_dir: Output directory; created if missing.
        model_name: Label shown in the printed heading.

    Returns:
        Path of the written JSON file.
    """
    os.makedirs(out_dir, exist_ok=True)
    path = os.path.join(out_dir, "bootstrap_ci_reg.json")
    with open(path, "w") as f:
        json.dump(ci_results, f, indent=2)

    # Label the heading with the level actually used instead of a fixed "95%".
    level = ci_results.get("ci_level", 0.95)
    print(f"\n=== Bootstrap {level * 100:g}% CI - Regression ({model_name}) ===")
    for metric in ("spearman_rho", "rmse", "mae", "r2"):
        r = ci_results[metric]
        print(f" {metric:15s}: {r['report']}")
        if metric == "spearman_rho" and "fisher_z_ci" in r:
            fz = r["fisher_z_ci"]
            print(f" Fisher z CI : {fz['report']} ← use this if rho > 0.9")
    print(f" n_val={ci_results['n_samples']}, n_bootstrap={ci_results['spearman_rho']['n_bootstrap']}")
    print(f"Saved to {path}")
    return path
def main(argv=None):
    """Command-line entry point working from a val_predictions.csv file.

    Modes:
      ci                    bootstrap CIs -> <out_dir>/bootstrap_ci_reg.json
      uncertainty_residual  the above, plus val_uncertainty_residual.csv and
                            residual_interval_meta.json in <out_dir>

    Args:
        argv: Argument list to parse; None means sys.argv[1:] (argparse default).
    """
    # RawTextHelpFormatter keeps the explicit "\n" in the --mode help text.
    parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter)
    parser.add_argument("--mode", required=True,
                        choices=["ci", "uncertainty_residual"],
                        help=(
                            "ci : bootstrap CI from val_predictions.csv\n"
                            "uncertainty_residual: residual interval for ElasticNet/SVR"
                        ))
    parser.add_argument("--val_preds", type=str, help="Path to val_predictions.csv")
    parser.add_argument("--out_dir", type=str, required=True)
    parser.add_argument("--model_name", type=str, default="")
    parser.add_argument("--n_bootstrap", type=int, default=2000)
    args = parser.parse_args(argv)

    # Both modes need the predictions file. parser.error (usage + exit 2) is
    # used instead of assert, which is stripped under `python -O`.
    if not args.val_preds:
        parser.error("--val_preds required")

    # Shared by both modes: bootstrap CIs on the validation predictions.
    df_preds = pd.read_csv(args.val_preds)
    ci = bootstrap_ci_reg(df_preds["y_true"].values, df_preds["y_pred"].values,
                          n_bootstrap=args.n_bootstrap)
    save_ci_report(ci, args.out_dir, args.model_name)

    if args.mode == "uncertainty_residual":
        df_unc, meta = residual_uncertainty(df_preds)
        path = os.path.join(args.out_dir, "val_uncertainty_residual.csv")
        df_unc.to_csv(path, index=False)
        meta_path = os.path.join(args.out_dir, "residual_interval_meta.json")
        with open(meta_path, "w") as f:
            json.dump(meta, f, indent=2)
        print("\nResidual interval summary:")
        print(f" Residual std : {meta['residual_std']:.4f}")
        print(f" {meta['nominal_coverage']:.0%} interval ± {meta['interval_halfwidth']:.4f}")
        print(f" Empirical coverage : {meta['empirical_coverage']:.4f} (nominal={meta['nominal_coverage']})")
        print(f" Saved to {path}")


if __name__ == "__main__":
    main()