"""Command-line entry point: compute an optimal screening curve from a YAML/JSON config."""
from __future__ import annotations

import argparse
import json
from pathlib import Path
from typing import Any

import yaml

from optimal_screening.analysis import compute_optimal_screening_curve
from optimal_screening.data_sources import load_dataframe
# Keys every config mapping must supply before validation proceeds.
REQUIRED_FIELDS = {"beta", "outcome", "strata"}
| def _read_config(path: Path) -> dict[str, Any]: | |
| if not path.exists(): | |
| raise FileNotFoundError(f"Config file not found: {path}") | |
| text = path.read_text() | |
| if path.suffix.lower() == ".json": | |
| data = json.loads(text) | |
| elif path.suffix.lower() in {".yaml", ".yml"}: | |
| data = yaml.safe_load(text) | |
| else: | |
| raise ValueError("Config file must be YAML or JSON") | |
| if not isinstance(data, dict): | |
| raise ValueError("Config must be a mapping") | |
| return data | |
| def _as_float_sequence(values: Any, field: str) -> list[float] | None: | |
| if values is None: | |
| return None | |
| if not isinstance(values, list | tuple): | |
| raise ValueError(f"{field} must be a list of numbers") | |
| return [float(value) for value in values] | |
def _validate_config(config: dict[str, Any]) -> dict[str, Any]:
    """Validate a raw config mapping and normalise it into a flat dict.

    Enforces the required fields, the exactly-one-data-source rule, and
    value constraints on strata / beta / alpha_quantiles, then fills in
    defaults for the optional keys.

    Raises:
        ValueError: On any constraint violation.
    """
    absent = sorted(REQUIRED_FIELDS - set(config))
    if absent:
        raise ValueError(f"Missing required config fields: {absent}")

    csv_given = config.get("csv") is not None
    hf_given = config.get("hf_dataset") is not None
    # Exactly one of csv / hf_dataset must be supplied (XOR).
    if csv_given == hf_given:
        raise ValueError("Config must provide exactly one data source: csv or hf_dataset")

    strata = config["strata"]
    strata_ok = (
        isinstance(strata, list)
        and bool(strata)
        and all(isinstance(entry, str) for entry in strata)
    )
    if not strata_ok:
        raise ValueError("strata must be a non-empty list of column names")

    beta = float(config["beta"])
    if not (0 < beta <= 1):
        raise ValueError("beta must be in the interval (0, 1]")

    alpha_quantiles = _as_float_sequence(config.get("alpha_quantiles"), "alpha_quantiles")
    if alpha_quantiles is not None:
        out_of_range = [a for a in alpha_quantiles if a < 0 or a > beta]
        if out_of_range:
            raise ValueError(
                f"alpha_quantiles must be between 0 and beta={beta}; invalid values: {out_of_range}"
            )

    def _opt_str(key: str) -> str | None:
        # Optional keys pass through as strings; absent/None stays None.
        raw = config.get(key)
        return None if raw is None else str(raw)

    return {
        "csv": _opt_str("csv"),
        "hf_dataset": _opt_str("hf_dataset"),
        "hf_split": str(config.get("hf_split", "train")),
        "hf_revision": _opt_str("hf_revision"),
        "outcome": str(config["outcome"]),
        "strata": strata,
        "beta": beta,
        "prediction_col": str(config.get("prediction_col", "probability")),
        "risk_col": _opt_str("risk_col"),
        "alpha_quantiles": alpha_quantiles,
        "output": str(config.get("output", "runs/optimal_screening_curve.json")),
    }
| def _json_safe(value: Any) -> Any: | |
| if isinstance(value, dict): | |
| return {key: _json_safe(item) for key, item in value.items()} | |
| if isinstance(value, list | tuple): | |
| return [_json_safe(item) for item in value] | |
| if hasattr(value, "item"): | |
| return value.item() | |
| return value | |
def calculate_from_config(config_path: Path) -> Path:
    """Run the optimal-screening computation described by a config file.

    Loads and validates the config, fetches the data (CSV or HF dataset),
    checks that all referenced columns exist, computes the screening curve,
    and writes the JSON result to the configured output path.

    Args:
        config_path: Path to a YAML or JSON config file.

    Returns:
        The path the JSON result was written to.

    Raises:
        ValueError: If the config is invalid or required columns are
            missing from the loaded data.
    """
    config = _validate_config(_read_config(config_path))
    df, dataset_label = load_dataframe(
        csv_path=config["csv"],
        hf_dataset=config["hf_dataset"],
        hf_split=config["hf_split"],
        hf_revision=config["hf_revision"],
    )
    required_cols = {config["outcome"], *config["strata"]}
    if config["risk_col"]:
        required_cols.add(config["risk_col"])
    elif config["prediction_col"] in df.columns:
        # The default prediction column is only required when the dataset
        # actually carries it; a custom risk_col takes precedence.
        required_cols.add(config["prediction_col"])
    missing_cols = sorted(required_cols - set(df.columns))
    if missing_cols:
        raise ValueError(f"Missing required columns in {dataset_label}: {missing_cols}")
    result = compute_optimal_screening_curve(
        rows=df.to_dict("records"),
        outcome_col=config["outcome"],
        strata_features=config["strata"],
        prediction_col=config["prediction_col"],
        beta=config["beta"],
        alpha_quantiles=config["alpha_quantiles"],
        use_custom_risk_col=config["risk_col"],
    )
    output_path = Path(config["output"])
    output_path.parent.mkdir(parents=True, exist_ok=True)
    # Bug fix: write explicitly as UTF-8 so the output does not depend on
    # the platform's default locale encoding.
    output_path.write_text(json.dumps(_json_safe(result), indent=2), encoding="utf-8")
    return output_path
def main() -> None:
    """CLI entry point: parse the config-path argument and run the job."""
    parser = argparse.ArgumentParser(
        description="Compute an optimal screening curve from a YAML or JSON config"
    )
    parser.add_argument("config", help="Path to a YAML or JSON config file")
    namespace = parser.parse_args()
    result_path = calculate_from_config(Path(namespace.config))
    print(f"Wrote {result_path}")


if __name__ == "__main__":
    main()