Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import argparse | |
| import json | |
| from pathlib import Path | |
| from typing import Any | |
| import yaml | |
| from optimal_screening.analysis import compute_optimal_screening_actions | |
| from optimal_screening.data_sources import load_dataframe | |
# Config keys that must always be present in the user's config file.
REQUIRED_FIELDS = {"alpha", "beta", "outcome", "strata"}
# Default name for the output column that will hold the computed screening
# actions; overridable via the "action_col" config entry.
DEFAULT_ACTION_COL = "screening_decision"
| def _read_config(path: Path) -> dict[str, Any]: | |
| if not path.exists(): | |
| raise FileNotFoundError(f"Config file not found: {path}") | |
| text = path.read_text() | |
| if path.suffix.lower() == ".json": | |
| data = json.loads(text) | |
| elif path.suffix.lower() in {".yaml", ".yml"}: | |
| data = yaml.safe_load(text) | |
| else: | |
| raise ValueError("Config file must be YAML or JSON") | |
| if not isinstance(data, dict): | |
| raise ValueError("Config must be a mapping") | |
| return data | |
def _validate_config(config: dict[str, Any]) -> dict[str, Any]:
    """Validate a raw config mapping and return a normalized copy.

    Verifies that required fields are present, that exactly one data
    source is given, and that ``alpha``/``beta`` are in range, then
    coerces every setting to its canonical type with defaults applied.

    Args:
        config: Raw mapping parsed from the YAML/JSON config file.

    Returns:
        A fresh dict of normalized settings.

    Raises:
        ValueError: If any field is missing, conflicting, or out of range.
    """
    # alpha_quantiles belongs to the curve-output workflow, not this one.
    if "alpha_quantiles" in config:
        raise ValueError("Use alpha for one screening budget; alpha_quantiles is only for curve outputs")
    missing = sorted(REQUIRED_FIELDS - set(config))
    if missing:
        raise ValueError(f"Missing required config fields: {missing}")
    # Exactly one of csv / hf_dataset must be provided (XOR on presence).
    has_csv = config.get("csv") is not None
    has_hf_dataset = config.get("hf_dataset") is not None
    if has_csv is has_hf_dataset:
        raise ValueError("Config must provide exactly one data source: csv or hf_dataset")
    strata = config["strata"]
    strata_ok = (
        isinstance(strata, list)
        and bool(strata)
        and all(isinstance(item, str) for item in strata)
    )
    if not strata_ok:
        raise ValueError("strata must be a non-empty list of column names")
    beta = float(config["beta"])
    if not 0 < beta <= 1:
        raise ValueError("beta must be in the interval (0, 1]")
    alpha = float(config["alpha"])
    if not 0 <= alpha <= beta:
        raise ValueError(f"alpha must be between 0 and beta={beta}")
    action_col = str(config.get("action_col", DEFAULT_ACTION_COL))
    if not action_col:
        raise ValueError("action_col must not be empty")

    def _opt_str(key: str) -> str | None:
        # Coerce an optional entry to str, preserving None when absent.
        value = config.get(key)
        return None if value is None else str(value)

    return {
        "csv": _opt_str("csv"),
        "hf_dataset": _opt_str("hf_dataset"),
        "hf_split": str(config.get("hf_split", "train")),
        "hf_revision": _opt_str("hf_revision"),
        "outcome": str(config["outcome"]),
        "strata": strata,
        "beta": beta,
        "alpha": alpha,
        "prediction_col": str(config.get("prediction_col", "probability")),
        "risk_col": _opt_str("risk_col"),
        "action_col": action_col,
        "output": str(config.get("output", "runs/optimal_screening.csv")),
    }
def get_optimal_screening_from_config(config_path: Path) -> Path:
    """Run the optimal-screening pipeline described by a config file.

    Loads the configured dataset, computes one screening action per row,
    stores them in a new column, and writes the result to CSV.

    Args:
        config_path: Path to the YAML/JSON config file.

    Returns:
        The path of the CSV file that was written.

    Raises:
        ValueError: If required columns are missing from the dataset or
            the action column already exists.
    """
    cfg = _validate_config(_read_config(config_path))
    df, dataset_label = load_dataframe(
        csv_path=cfg["csv"],
        hf_dataset=cfg["hf_dataset"],
        hf_split=cfg["hf_split"],
        hf_revision=cfg["hf_revision"],
    )
    # Columns the computation reads: outcome, every stratum, and either a
    # caller-supplied risk column or (when present) the prediction column.
    needed = {cfg["outcome"]}
    needed.update(cfg["strata"])
    if cfg["risk_col"]:
        needed.add(cfg["risk_col"])
    elif cfg["prediction_col"] in df.columns:
        needed.add(cfg["prediction_col"])
    missing_cols = sorted(needed - set(df.columns))
    if missing_cols:
        raise ValueError(f"Missing required columns in {dataset_label}: {missing_cols}")
    # Refuse to clobber a column that already exists in the dataset.
    if cfg["action_col"] in df.columns:
        raise ValueError(f"Output action column already exists in {dataset_label}: {cfg['action_col']}")
    df[cfg["action_col"]] = compute_optimal_screening_actions(
        rows=df.to_dict("records"),
        outcome_col=cfg["outcome"],
        strata_features=cfg["strata"],
        prediction_col=cfg["prediction_col"],
        beta=cfg["beta"],
        alpha=cfg["alpha"],
        use_custom_risk_col=cfg["risk_col"],
    )
    destination = Path(cfg["output"])
    destination.parent.mkdir(parents=True, exist_ok=True)
    df.to_csv(destination, index=False)
    return destination
def main() -> None:
    """CLI entry point: parse the config path argument and run the pipeline."""
    parser = argparse.ArgumentParser(description="Write optimal screening actions from a YAML or JSON config")
    parser.add_argument("config", help="Path to a YAML or JSON config file")
    namespace = parser.parse_args()
    result = get_optimal_screening_from_config(Path(namespace.config))
    print(f"Wrote {result}")


if __name__ == "__main__":
    main()