File size: 4,542 Bytes
09f8c96
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
from __future__ import annotations

import argparse
import json
from pathlib import Path
from typing import Any

import yaml

from optimal_screening.analysis import compute_optimal_screening_actions
from optimal_screening.data_sources import load_dataframe


# Config keys that must always be present in a user-supplied config mapping.
REQUIRED_FIELDS = {"alpha", "beta", "outcome", "strata"}
# Name of the output action column when the config omits "action_col".
DEFAULT_ACTION_COL = "screening_decision"


def _read_config(path: Path) -> dict[str, Any]:
    if not path.exists():
        raise FileNotFoundError(f"Config file not found: {path}")

    text = path.read_text()
    if path.suffix.lower() == ".json":
        data = json.loads(text)
    elif path.suffix.lower() in {".yaml", ".yml"}:
        data = yaml.safe_load(text)
    else:
        raise ValueError("Config file must be YAML or JSON")

    if not isinstance(data, dict):
        raise ValueError("Config must be a mapping")
    return data


def _validate_config(config: dict[str, Any]) -> dict[str, Any]:
    """Validate a raw config mapping and return its normalized form.

    Checks required fields, the single data-source rule, strata, the
    alpha/beta screening budget, and the action column, then coerces every
    setting to a canonical type with defaults applied.

    Raises:
        ValueError: If any check fails.
    """
    if "alpha_quantiles" in config:
        raise ValueError("Use alpha for one screening budget; alpha_quantiles is only for curve outputs")

    absent = sorted(REQUIRED_FIELDS - set(config))
    if absent:
        raise ValueError(f"Missing required config fields: {absent}")

    # Exactly one of csv / hf_dataset must be given (both or neither fails).
    csv_given = config.get("csv") is not None
    hf_given = config.get("hf_dataset") is not None
    if csv_given == hf_given:
        raise ValueError("Config must provide exactly one data source: csv or hf_dataset")

    strata = config["strata"]
    strata_ok = (
        isinstance(strata, list)
        and bool(strata)
        and all(isinstance(name, str) for name in strata)
    )
    if not strata_ok:
        raise ValueError("strata must be a non-empty list of column names")

    beta = float(config["beta"])
    if not 0 < beta <= 1:
        raise ValueError("beta must be in the interval (0, 1]")

    alpha = float(config["alpha"])
    if not 0 <= alpha <= beta:
        raise ValueError(f"alpha must be between 0 and beta={beta}")

    action_col = str(config.get("action_col", DEFAULT_ACTION_COL))
    if not action_col:
        raise ValueError("action_col must not be empty")

    def _optional_str(key: str) -> str | None:
        # None stays None; anything else is coerced to str.
        raw = config.get(key)
        return None if raw is None else str(raw)

    return {
        "csv": _optional_str("csv"),
        "hf_dataset": _optional_str("hf_dataset"),
        "hf_split": str(config.get("hf_split", "train")),
        "hf_revision": _optional_str("hf_revision"),
        "outcome": str(config["outcome"]),
        "strata": strata,
        "beta": beta,
        "alpha": alpha,
        "prediction_col": str(config.get("prediction_col", "probability")),
        "risk_col": _optional_str("risk_col"),
        "action_col": action_col,
        "output": str(config.get("output", "runs/optimal_screening.csv")),
    }


def get_optimal_screening_from_config(config_path: Path) -> Path:
    """Run the optimal-screening pipeline described by a config file.

    Reads and validates the config, loads the configured data source,
    computes screening actions into a new column, writes the augmented
    table as CSV, and returns the output path.

    Raises:
        ValueError: If required columns are missing from the data or the
            action column already exists.
    """
    settings = _validate_config(_read_config(config_path))

    frame, source_label = load_dataframe(
        csv_path=settings["csv"],
        hf_dataset=settings["hf_dataset"],
        hf_split=settings["hf_split"],
        hf_revision=settings["hf_revision"],
    )

    needed = {settings["outcome"]}
    needed.update(settings["strata"])
    if settings["risk_col"]:
        needed.add(settings["risk_col"])
    elif settings["prediction_col"] in frame.columns:
        # NOTE(review): adding a column already known to be present is a
        # no-op for the missing-column check below — possibly the intent was
        # to require prediction_col whenever risk_col is absent; confirm.
        needed.add(settings["prediction_col"])

    absent = sorted(needed - set(frame.columns))
    if absent:
        raise ValueError(f"Missing required columns in {source_label}: {absent}")

    if settings["action_col"] in frame.columns:
        raise ValueError(f"Output action column already exists in {source_label}: {settings['action_col']}")

    frame[settings["action_col"]] = compute_optimal_screening_actions(
        rows=frame.to_dict("records"),
        outcome_col=settings["outcome"],
        strata_features=settings["strata"],
        prediction_col=settings["prediction_col"],
        beta=settings["beta"],
        alpha=settings["alpha"],
        use_custom_risk_col=settings["risk_col"],
    )

    destination = Path(settings["output"])
    destination.parent.mkdir(parents=True, exist_ok=True)
    frame.to_csv(destination, index=False)
    return destination


def main() -> None:
    """CLI entry point: parse the config-path argument and run the pipeline."""
    cli = argparse.ArgumentParser(description="Write optimal screening actions from a YAML or JSON config")
    cli.add_argument("config", help="Path to a YAML or JSON config file")
    namespace = cli.parse_args()

    result = get_optimal_screening_from_config(Path(namespace.config))
    print(f"Wrote {result}")


if __name__ == "__main__":
    main()