cmpatino's picture
cmpatino HF Staff
feat: add risk calculation gradio app
a1b4ce8
from __future__ import annotations
import argparse
import json
from pathlib import Path
from typing import Any
import yaml
from optimal_screening.analysis import compute_optimal_screening_curve
from optimal_screening.data_sources import load_dataframe
REQUIRED_FIELDS = {"outcome", "strata", "beta"}
def _read_config(path: Path) -> dict[str, Any]:
if not path.exists():
raise FileNotFoundError(f"Config file not found: {path}")
text = path.read_text()
if path.suffix.lower() == ".json":
data = json.loads(text)
elif path.suffix.lower() in {".yaml", ".yml"}:
data = yaml.safe_load(text)
else:
raise ValueError("Config file must be YAML or JSON")
if not isinstance(data, dict):
raise ValueError("Config must be a mapping")
return data
def _as_float_sequence(values: Any, field: str) -> list[float] | None:
if values is None:
return None
if not isinstance(values, list | tuple):
raise ValueError(f"{field} must be a list of numbers")
return [float(value) for value in values]
def _validate_config(config: dict[str, Any]) -> dict[str, Any]:
missing = sorted(REQUIRED_FIELDS - set(config))
if missing:
raise ValueError(f"Missing required config fields: {missing}")
has_csv = config.get("csv") is not None
has_hf_dataset = config.get("hf_dataset") is not None
if has_csv == has_hf_dataset:
raise ValueError("Config must provide exactly one data source: csv or hf_dataset")
strata = config["strata"]
if not isinstance(strata, list) or not strata or not all(isinstance(item, str) for item in strata):
raise ValueError("strata must be a non-empty list of column names")
beta = float(config["beta"])
if not 0 < beta <= 1:
raise ValueError("beta must be in the interval (0, 1]")
alpha_quantiles = _as_float_sequence(config.get("alpha_quantiles"), "alpha_quantiles")
if alpha_quantiles is not None:
invalid = [alpha for alpha in alpha_quantiles if alpha < 0 or alpha > beta]
if invalid:
raise ValueError(f"alpha_quantiles must be between 0 and beta={beta}; invalid values: {invalid}")
return {
"csv": str(config["csv"]) if has_csv else None,
"hf_dataset": str(config["hf_dataset"]) if has_hf_dataset else None,
"hf_split": str(config.get("hf_split", "train")),
"hf_revision": str(config["hf_revision"]) if config.get("hf_revision") is not None else None,
"outcome": str(config["outcome"]),
"strata": strata,
"beta": beta,
"prediction_col": str(config.get("prediction_col", "probability")),
"risk_col": str(config["risk_col"]) if config.get("risk_col") is not None else None,
"alpha_quantiles": alpha_quantiles,
"output": str(config.get("output", "runs/optimal_screening_curve.json")),
}
def _json_safe(value: Any) -> Any:
if isinstance(value, dict):
return {key: _json_safe(item) for key, item in value.items()}
if isinstance(value, list | tuple):
return [_json_safe(item) for item in value]
if hasattr(value, "item"):
return value.item()
return value
def calculate_from_config(config_path: Path) -> Path:
config = _validate_config(_read_config(config_path))
df, dataset_label = load_dataframe(
csv_path=config["csv"],
hf_dataset=config["hf_dataset"],
hf_split=config["hf_split"],
hf_revision=config["hf_revision"],
)
required_cols = {config["outcome"], *config["strata"]}
if config["risk_col"]:
required_cols.add(config["risk_col"])
elif config["prediction_col"] in df.columns:
required_cols.add(config["prediction_col"])
missing_cols = sorted(required_cols - set(df.columns))
if missing_cols:
raise ValueError(f"Missing required columns in {dataset_label}: {missing_cols}")
result = compute_optimal_screening_curve(
rows=df.to_dict("records"),
outcome_col=config["outcome"],
strata_features=config["strata"],
prediction_col=config["prediction_col"],
beta=config["beta"],
alpha_quantiles=config["alpha_quantiles"],
use_custom_risk_col=config["risk_col"],
)
output_path = Path(config["output"])
output_path.parent.mkdir(parents=True, exist_ok=True)
output_path.write_text(json.dumps(_json_safe(result), indent=2))
return output_path
def main() -> None:
parser = argparse.ArgumentParser(description="Compute an optimal screening curve from a YAML or JSON config")
parser.add_argument("config", help="Path to a YAML or JSON config file")
args = parser.parse_args()
output_path = calculate_from_config(Path(args.config))
print(f"Wrote {output_path}")
if __name__ == "__main__":
main()