optimal-screening-app / optimal_screening /cli /get_optimal_screening.py
cmpatino's picture
cmpatino HF Staff
Update app for optimal screening CSV output
09f8c96
from __future__ import annotations
import argparse
import json
from pathlib import Path
from typing import Any
import yaml
from optimal_screening.analysis import compute_optimal_screening_actions
from optimal_screening.data_sources import load_dataframe
# Config keys that must be present; _validate_config rejects a config missing any of them.
REQUIRED_FIELDS = {"alpha", "beta", "outcome", "strata"}
# Name of the column that receives the screening actions when the config sets no action_col.
DEFAULT_ACTION_COL = "screening_decision"
def _read_config(path: Path) -> dict[str, Any]:
    """Load a YAML or JSON config file and return its top-level mapping.

    The parser is chosen from the file suffix (case-insensitive): ``.json``
    is parsed with ``json.loads``; ``.yaml`` / ``.yml`` with
    ``yaml.safe_load``.

    Args:
        path: Location of the config file.

    Returns:
        The parsed config as a dictionary.

    Raises:
        FileNotFoundError: If ``path`` does not exist.
        ValueError: If the suffix is not a supported one, or the parsed
            document is not a mapping (e.g. a list, a scalar, or an empty
            YAML file).
    """
    if not path.exists():
        raise FileNotFoundError(f"Config file not found: {path}")

    # Reject unsupported formats before decoding, so an arbitrary (possibly
    # binary) file reports the format error rather than a decode error.
    suffix = path.suffix.lower()
    if suffix not in {".json", ".yaml", ".yml"}:
        raise ValueError("Config file must be YAML or JSON")

    # Decode explicitly instead of relying on the locale default encoding.
    # "utf-8-sig" reads plain UTF-8 and also drops a leading BOM, which
    # json.loads would otherwise reject.
    text = path.read_text(encoding="utf-8-sig")
    data = json.loads(text) if suffix == ".json" else yaml.safe_load(text)

    if not isinstance(data, dict):
        raise ValueError("Config must be a mapping")
    return data
def _coerce_number(config: dict[str, Any], key: str) -> float:
    """Return ``config[key]`` as a float, raising ``ValueError`` that names the key."""
    value = config[key]
    try:
        return float(value)
    except (TypeError, ValueError) as err:
        # float(None) / float([]) raise TypeError; report every bad config
        # value uniformly as ValueError, like the other validation failures.
        raise ValueError(f"{key} must be a number, got {value!r}") from err


def _optional_str(config: dict[str, Any], key: str, default: str | None = None) -> str | None:
    """Return ``config[key]`` as a string, or ``default`` when it is absent or null."""
    value = config.get(key)
    return default if value is None else str(value)


def _validate_config(config: dict[str, Any]) -> dict[str, Any]:
    """Validate a raw config mapping and return a normalized copy.

    Checks performed, in order:

    * ``alpha_quantiles`` must not be present.
    * Every key in ``REQUIRED_FIELDS`` must be present.
    * Exactly one of ``csv`` / ``hf_dataset`` must be set (non-null).
    * ``outcome`` must not be null.
    * ``strata`` must be a non-empty list of strings.
    * ``beta`` must be a number in ``(0, 1]`` and ``alpha`` a number in
      ``[0, beta]``.
    * ``action_col`` must not be empty.

    Optional keys that are absent or explicitly null fall back to their
    defaults (``hf_split="train"``, ``prediction_col="probability"``,
    ``action_col=DEFAULT_ACTION_COL``,
    ``output="runs/optimal_screening.csv"``) or to ``None``
    (``hf_revision``, ``risk_col``) instead of becoming the string "None".

    Args:
        config: Mapping as parsed from the config file.

    Returns:
        A new dictionary with the keys ``csv``, ``hf_dataset``, ``hf_split``,
        ``hf_revision``, ``outcome``, ``strata``, ``beta``, ``alpha``,
        ``prediction_col``, ``risk_col``, ``action_col`` and ``output``.
        ``beta`` and ``alpha`` are floats, ``strata`` is the list from the
        input, and the remaining values are strings or ``None``.

    Raises:
        ValueError: If any of the checks above fails.
    """
    if "alpha_quantiles" in config:
        raise ValueError("Use alpha for one screening budget; alpha_quantiles is only for curve outputs")

    missing = sorted(REQUIRED_FIELDS - set(config))
    if missing:
        raise ValueError(f"Missing required config fields: {missing}")

    csv_path = _optional_str(config, "csv")
    hf_dataset = _optional_str(config, "hf_dataset")
    # Both set or both unset is invalid: exactly one source is required.
    if (csv_path is None) == (hf_dataset is None):
        raise ValueError("Config must provide exactly one data source: csv or hf_dataset")

    # The key is guaranteed present by the REQUIRED_FIELDS check, but a null
    # value would otherwise be stringified to the column name "None".
    if config["outcome"] is None:
        raise ValueError("outcome must be a column name")

    strata = config["strata"]
    if not isinstance(strata, list) or not strata or not all(isinstance(item, str) for item in strata):
        raise ValueError("strata must be a non-empty list of column names")

    beta = _coerce_number(config, "beta")
    if not 0 < beta <= 1:
        raise ValueError("beta must be in the interval (0, 1]")

    alpha = _coerce_number(config, "alpha")
    if not 0 <= alpha <= beta:
        raise ValueError(f"alpha must be between 0 and beta={beta}")

    action_col = _optional_str(config, "action_col", DEFAULT_ACTION_COL)
    if not action_col:
        raise ValueError("action_col must not be empty")

    return {
        "csv": csv_path,
        "hf_dataset": hf_dataset,
        "hf_split": _optional_str(config, "hf_split", "train"),
        "hf_revision": _optional_str(config, "hf_revision"),
        "outcome": str(config["outcome"]),
        "strata": strata,
        "beta": beta,
        "alpha": alpha,
        "prediction_col": _optional_str(config, "prediction_col", "probability"),
        "risk_col": _optional_str(config, "risk_col"),
        "action_col": action_col,
        "output": _optional_str(config, "output", "runs/optimal_screening.csv"),
    }
def get_optimal_screening_from_config(config_path: Path) -> Path:
    """Compute screening actions for the dataset named in a config and save them as CSV.

    The config is read and validated, the dataset is loaded through
    ``load_dataframe``, the result of ``compute_optimal_screening_actions``
    is stored in the configured action column, and the table is written
    with ``to_csv(index=False)``. Parent directories of the output file are
    created when missing.

    Args:
        config_path: Path to the YAML or JSON config file.

    Returns:
        Path of the CSV file that was written.

    Raises:
        ValueError: If a required column is missing from the dataset or the
            action column already exists in it. Errors from reading or
            validating the config propagate unchanged.
    """
    settings = _validate_config(_read_config(config_path))
    frame, label = load_dataframe(
        csv_path=settings["csv"],
        hf_dataset=settings["hf_dataset"],
        hf_split=settings["hf_split"],
        hf_revision=settings["hf_revision"],
    )

    risk_col = settings["risk_col"]
    prediction_col = settings["prediction_col"]
    action_col = settings["action_col"]

    needed = set(settings["strata"])
    needed.add(settings["outcome"])
    if risk_col:
        needed.add(risk_col)
    elif prediction_col in frame.columns:
        # NOTE(review): the prediction column is only added when it is already
        # present, so its absence never triggers the error below - confirm
        # that a missing prediction column is meant to be tolerated.
        needed.add(prediction_col)

    absent = sorted(needed.difference(frame.columns))
    if absent:
        raise ValueError(f"Missing required columns in {label}: {absent}")

    # Refuse to overwrite a column that already exists in the input data.
    if action_col in frame.columns:
        raise ValueError(f"Output action column already exists in {label}: {action_col}")

    actions = compute_optimal_screening_actions(
        rows=frame.to_dict("records"),
        outcome_col=settings["outcome"],
        strata_features=settings["strata"],
        prediction_col=prediction_col,
        beta=settings["beta"],
        alpha=settings["alpha"],
        use_custom_risk_col=risk_col,
    )
    frame[action_col] = actions

    destination = Path(settings["output"])
    destination.parent.mkdir(parents=True, exist_ok=True)
    frame.to_csv(destination, index=False)
    return destination
def main() -> None:
    """Command-line entry point.

    Takes one positional argument, the path to a YAML or JSON config file,
    runs the screening computation for it and prints the path of the file
    that was written.
    """
    cli = argparse.ArgumentParser(description="Write optimal screening actions from a YAML or JSON config")
    cli.add_argument("config", help="Path to a YAML or JSON config file")
    config_path = Path(cli.parse_args().config)

    written = get_optimal_screening_from_config(config_path)
    print(f"Wrote {written}")


# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()