Spaces:
Sleeping
Sleeping
File size: 4,542 Bytes
09f8c96 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 | from __future__ import annotations
import argparse
import json
from pathlib import Path
from typing import Any
import yaml
from optimal_screening.analysis import compute_optimal_screening_actions
from optimal_screening.data_sources import load_dataframe
REQUIRED_FIELDS = {"alpha", "beta", "outcome", "strata"}
DEFAULT_ACTION_COL = "screening_decision"
def _read_config(path: Path) -> dict[str, Any]:
if not path.exists():
raise FileNotFoundError(f"Config file not found: {path}")
text = path.read_text()
if path.suffix.lower() == ".json":
data = json.loads(text)
elif path.suffix.lower() in {".yaml", ".yml"}:
data = yaml.safe_load(text)
else:
raise ValueError("Config file must be YAML or JSON")
if not isinstance(data, dict):
raise ValueError("Config must be a mapping")
return data
def _validate_config(config: dict[str, Any]) -> dict[str, Any]:
if "alpha_quantiles" in config:
raise ValueError("Use alpha for one screening budget; alpha_quantiles is only for curve outputs")
missing = sorted(REQUIRED_FIELDS - set(config))
if missing:
raise ValueError(f"Missing required config fields: {missing}")
has_csv = config.get("csv") is not None
has_hf_dataset = config.get("hf_dataset") is not None
if has_csv == has_hf_dataset:
raise ValueError("Config must provide exactly one data source: csv or hf_dataset")
strata = config["strata"]
if not isinstance(strata, list) or not strata or not all(isinstance(item, str) for item in strata):
raise ValueError("strata must be a non-empty list of column names")
beta = float(config["beta"])
if not 0 < beta <= 1:
raise ValueError("beta must be in the interval (0, 1]")
alpha = float(config["alpha"])
if not 0 <= alpha <= beta:
raise ValueError(f"alpha must be between 0 and beta={beta}")
action_col = str(config.get("action_col", DEFAULT_ACTION_COL))
if not action_col:
raise ValueError("action_col must not be empty")
return {
"csv": str(config["csv"]) if has_csv else None,
"hf_dataset": str(config["hf_dataset"]) if has_hf_dataset else None,
"hf_split": str(config.get("hf_split", "train")),
"hf_revision": str(config["hf_revision"]) if config.get("hf_revision") is not None else None,
"outcome": str(config["outcome"]),
"strata": strata,
"beta": beta,
"alpha": alpha,
"prediction_col": str(config.get("prediction_col", "probability")),
"risk_col": str(config["risk_col"]) if config.get("risk_col") is not None else None,
"action_col": action_col,
"output": str(config.get("output", "runs/optimal_screening.csv")),
}
def get_optimal_screening_from_config(config_path: Path) -> Path:
config = _validate_config(_read_config(config_path))
df, dataset_label = load_dataframe(
csv_path=config["csv"],
hf_dataset=config["hf_dataset"],
hf_split=config["hf_split"],
hf_revision=config["hf_revision"],
)
required_cols = {config["outcome"], *config["strata"]}
if config["risk_col"]:
required_cols.add(config["risk_col"])
elif config["prediction_col"] in df.columns:
required_cols.add(config["prediction_col"])
missing_cols = sorted(required_cols - set(df.columns))
if missing_cols:
raise ValueError(f"Missing required columns in {dataset_label}: {missing_cols}")
if config["action_col"] in df.columns:
raise ValueError(f"Output action column already exists in {dataset_label}: {config['action_col']}")
df[config["action_col"]] = compute_optimal_screening_actions(
rows=df.to_dict("records"),
outcome_col=config["outcome"],
strata_features=config["strata"],
prediction_col=config["prediction_col"],
beta=config["beta"],
alpha=config["alpha"],
use_custom_risk_col=config["risk_col"],
)
output_path = Path(config["output"])
output_path.parent.mkdir(parents=True, exist_ok=True)
df.to_csv(output_path, index=False)
return output_path
def main() -> None:
parser = argparse.ArgumentParser(description="Write optimal screening actions from a YAML or JSON config")
parser.add_argument("config", help="Path to a YAML or JSON config file")
args = parser.parse_args()
output_path = get_optimal_screening_from_config(Path(args.config))
print(f"Wrote {output_path}")
if __name__ == "__main__":
main()
|