from __future__ import annotations
import argparse
import json
from pathlib import Path
from typing import Any
import yaml
from optimal_screening.analysis import compute_optimal_screening_curve
from optimal_screening.data_sources import load_dataframe
# Config keys that must always be present; checked first by _validate_config.
REQUIRED_FIELDS = {"outcome", "strata", "beta"}
def _read_config(path: Path) -> dict[str, Any]:
if not path.exists():
raise FileNotFoundError(f"Config file not found: {path}")
text = path.read_text()
if path.suffix.lower() == ".json":
data = json.loads(text)
elif path.suffix.lower() in {".yaml", ".yml"}:
data = yaml.safe_load(text)
else:
raise ValueError("Config file must be YAML or JSON")
if not isinstance(data, dict):
raise ValueError("Config must be a mapping")
return data
def _as_float_sequence(values: Any, field: str) -> list[float] | None:
if values is None:
return None
if not isinstance(values, list | tuple):
raise ValueError(f"{field} must be a list of numbers")
return [float(value) for value in values]
def _validate_config(config: dict[str, Any]) -> dict[str, Any]:
    """Validate a raw config mapping and normalise it to a canonical dict.

    All required fields must be present and exactly one data source
    (``csv`` or ``hf_dataset``) must be given; optional keys receive their
    defaults or ``None``.

    Raises:
        ValueError: On any missing field, ambiguous data source, malformed
            ``strata``, out-of-range ``beta``, or invalid ``alpha_quantiles``.
    """
    absent = sorted(REQUIRED_FIELDS - set(config))
    if absent:
        raise ValueError(f"Missing required config fields: {absent}")

    csv_given = config.get("csv") is not None
    hf_given = config.get("hf_dataset") is not None
    # Exclusive-or: exactly one of the two sources must be configured.
    if csv_given == hf_given:
        raise ValueError("Config must provide exactly one data source: csv or hf_dataset")

    strata = config["strata"]
    strata_ok = (
        isinstance(strata, list)
        and bool(strata)
        and all(isinstance(name, str) for name in strata)
    )
    if not strata_ok:
        raise ValueError("strata must be a non-empty list of column names")

    beta = float(config["beta"])
    if not 0 < beta <= 1:
        raise ValueError("beta must be in the interval (0, 1]")

    alphas = _as_float_sequence(config.get("alpha_quantiles"), "alpha_quantiles")
    if alphas is not None:
        invalid = [alpha for alpha in alphas if alpha < 0 or alpha > beta]
        if invalid:
            raise ValueError(f"alpha_quantiles must be between 0 and beta={beta}; invalid values: {invalid}")

    def _opt_str(key: str) -> str | None:
        # Optional string field: absent/None stays None, anything else is coerced.
        raw = config.get(key)
        return None if raw is None else str(raw)

    return {
        "csv": _opt_str("csv"),
        "hf_dataset": _opt_str("hf_dataset"),
        "hf_split": str(config.get("hf_split", "train")),
        "hf_revision": _opt_str("hf_revision"),
        "outcome": str(config["outcome"]),
        "strata": strata,
        "beta": beta,
        "prediction_col": str(config.get("prediction_col", "probability")),
        "risk_col": _opt_str("risk_col"),
        "alpha_quantiles": alphas,
        "output": str(config.get("output", "runs/optimal_screening_curve.json")),
    }
def _json_safe(value: Any) -> Any:
if isinstance(value, dict):
return {key: _json_safe(item) for key, item in value.items()}
if isinstance(value, list | tuple):
return [_json_safe(item) for item in value]
if hasattr(value, "item"):
return value.item()
return value
def calculate_from_config(config_path: Path) -> Path:
    """Run the optimal-screening computation described by a config file.

    Loads and validates the config, loads the data source it names,
    verifies the required columns exist, computes the screening curve,
    and writes the result to the configured JSON output path.

    Args:
        config_path: Path to a YAML or JSON config file.

    Returns:
        The path of the JSON file that was written.

    Raises:
        ValueError: If the config is invalid or the dataset is missing
            required columns.
    """
    config = _validate_config(_read_config(config_path))
    df, dataset_label = load_dataframe(
        csv_path=config["csv"],
        hf_dataset=config["hf_dataset"],
        hf_split=config["hf_split"],
        hf_revision=config["hf_revision"],
    )
    # Columns the computation will read: outcome, strata, and risk/prediction.
    required_cols = {config["outcome"], *config["strata"]}
    if config["risk_col"]:
        required_cols.add(config["risk_col"])
    elif config["prediction_col"] in df.columns:
        # NOTE(review): prediction_col is only added when already present, so
        # it can never appear in missing_cols — presumably optional by design.
        required_cols.add(config["prediction_col"])
    missing_cols = sorted(required_cols - set(df.columns))
    if missing_cols:
        raise ValueError(f"Missing required columns in {dataset_label}: {missing_cols}")
    result = compute_optimal_screening_curve(
        rows=df.to_dict("records"),
        outcome_col=config["outcome"],
        strata_features=config["strata"],
        prediction_col=config["prediction_col"],
        beta=config["beta"],
        alpha_quantiles=config["alpha_quantiles"],
        use_custom_risk_col=config["risk_col"],
    )
    output_path = Path(config["output"])
    output_path.parent.mkdir(parents=True, exist_ok=True)
    # Write UTF-8 explicitly so the artifact's bytes do not depend on the
    # platform's default locale encoding.
    output_path.write_text(json.dumps(_json_safe(result), indent=2), encoding="utf-8")
    return output_path
def main() -> None:
    """CLI entry point: parse the config path from argv and run the pipeline."""
    parser = argparse.ArgumentParser(description="Compute an optimal screening curve from a YAML or JSON config")
    parser.add_argument("config", help="Path to a YAML or JSON config file")
    cli_args = parser.parse_args()
    result_path = calculate_from_config(Path(cli_args.config))
    print(f"Wrote {result_path}")


if __name__ == "__main__":
    main()