cmpatino HF Staff committed on
Commit
a1b4ce8
·
1 Parent(s): cfe9277

feat: add risk calculation gradio app

Browse files
.gitignore ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ .env
2
+ .venv/
3
+ __pycache__/
4
+ *.py[cod]
5
+ runs/
README.md CHANGED
@@ -5,9 +5,17 @@ colorFrom: yellow
5
  colorTo: indigo
6
  sdk: gradio
7
  sdk_version: 6.14.0
8
- python_version: '3.13'
9
  app_file: app.py
10
  pinned: false
11
  ---
12
 
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
5
  colorTo: indigo
6
  sdk: gradio
7
  sdk_version: 6.14.0
8
+ python_version: '3.11'
9
  app_file: app.py
10
  pinned: false
11
  ---
12
 
13
+ Gradio app for computing optimal screening risk curves from form inputs.
14
+
15
+ The default form values mirror:
16
+
17
+ ```bash
18
+ uv run calculate-risk configs/example-risk.yaml
19
+ ```
20
+
21
+ The app accepts a Hugging Face dataset, an uploaded CSV, or pasted CSV rows. Each run writes a temporary JSON file and exposes it through the download button.
app.py ADDED
@@ -0,0 +1,290 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ import tempfile
5
+ from pathlib import Path
6
+ from typing import Any
7
+ from uuid import uuid4
8
+
9
+ import gradio as gr
10
+
11
+ from optimal_screening.cli.calculate_risk import calculate_from_config
12
+
13
+
14
# Directory that contains this module.
ROOT = Path(__file__).parent

# Labels of the "Data source" radio choices. The selected label is compared
# against these constants to toggle input visibility and to build the config.
SOURCE_CSV_UPLOAD = "Upload CSV"
SOURCE_CSV_PASTE = "Paste CSV"
SOURCE_HF_DATASET = "Hugging Face dataset"

# Initial values used to pre-fill the form fields.
DEFAULT_DATASET = "cmpatino/landmine-detection"
DEFAULT_SPLIT = "train"
DEFAULT_OUTCOME = "mines_outcome"
DEFAULT_STRATA = "Municipio"
DEFAULT_BETA = 0.1
DEFAULT_ALPHA_VALUES = "0.0, 0.05, 0.1"
25
+
26
+
27
+ def _result_summary(result: dict[str, Any], output_path: Path) -> str:
28
+ alpha_count = len(result.get("alpha_values", []))
29
+ total_samples = result.get("total_samples", "unknown")
30
+ total_positive = result.get("total_positive", "unknown")
31
+ beta = result.get("beta", "unknown")
32
+
33
+ return (
34
+ f"Computed `{alpha_count}` alpha point(s) with beta `{beta}`.\n\n"
35
+ f"Total samples: `{total_samples}` \n"
36
+ f"Total positives: `{total_positive}` \n"
37
+ f"Output file: `{output_path.name}`"
38
+ )
39
+
40
+
41
+ def _uploaded_path(uploaded_csv: Any) -> str | None:
42
+ if uploaded_csv is None:
43
+ return None
44
+ if isinstance(uploaded_csv, list):
45
+ if not uploaded_csv:
46
+ return None
47
+ uploaded_csv = uploaded_csv[0]
48
+ if isinstance(uploaded_csv, str):
49
+ return uploaded_csv
50
+ if hasattr(uploaded_csv, "name"):
51
+ return str(uploaded_csv.name)
52
+ return str(uploaded_csv)
53
+
54
+
55
+ def _parse_list(value: str, field: str) -> list[str]:
56
+ values = [item.strip() for item in value.replace("\n", ",").split(",") if item.strip()]
57
+ if not values:
58
+ raise ValueError(f"{field} must include at least one value.")
59
+ return values
60
+
61
+
62
def _parse_optional_float_list(value: str) -> list[float] | None:
    """Parse an optional list of floats from free-form text.

    Returns:
        ``None`` when *value* is blank, otherwise the parsed numbers.

    Raises:
        ValueError: If the text contains only separators, or an item is not
            a valid float.
    """
    if value.strip() == "":
        return None
    return list(map(float, _parse_list(value, "alpha values")))
66
+
67
+
68
+ def _optional_text(value: str | None) -> str | None:
69
+ if value is None:
70
+ return None
71
+ value = value.strip()
72
+ return value or None
73
+
74
+
75
+ def _result_filename(value: str | None) -> str:
76
+ filename = Path(value.strip()).name if value and value.strip() else "risk-results.json"
77
+ if not filename.endswith(".json"):
78
+ filename = f"{filename}.json"
79
+ return filename
80
+
81
+
82
def _source_visibility(source: str) -> tuple[Any, Any, Any, Any, Any]:
    """Compute visibility updates for the source-specific input widgets.

    Args:
        source: The currently selected data-source label.

    Returns:
        Five ``gr.update`` objects, in order: CSV upload, pasted CSV, and the
        three Hugging Face inputs (which are shown or hidden together).
    """
    show_upload = source == SOURCE_CSV_UPLOAD
    show_paste = source == SOURCE_CSV_PASTE
    show_hf = source == SOURCE_HF_DATASET

    return (
        gr.update(visible=show_upload),
        gr.update(visible=show_paste),
        gr.update(visible=show_hf),
        gr.update(visible=show_hf),
        gr.update(visible=show_hf),
    )
90
+
91
+
92
def _build_config(
    *,
    data_source: str,
    uploaded_csv: Any,
    pasted_csv: str,
    hf_dataset: str,
    hf_split: str,
    hf_revision: str,
    outcome: str,
    strata: str,
    beta: float,
    prediction_col: str,
    risk_col: str,
    alpha_values: str,
    result_filename: str,
    run_dir: Path,
) -> dict[str, Any]:
    """Translate raw form values into a risk-calculation config dictionary.

    Args:
        data_source: Selected source label (one of the ``SOURCE_*`` constants).
        uploaded_csv: Value of the upload widget (used for the upload source).
        pasted_csv: Raw CSV text (used for the paste source). It is written to
            ``run_dir / "input.csv"``.
        hf_dataset: Dataset identifier (used for the Hugging Face source).
        hf_split: Dataset split; blank falls back to ``"train"``.
        hf_revision: Optional dataset revision; omitted from the config when blank.
        outcome: Name of the outcome column.
        strata: Comma/newline-separated strata column names.
        beta: Treatment budget.
        prediction_col: Optional prediction column; omitted when blank.
        risk_col: Optional risk column; omitted when blank.
        alpha_values: Optional comma/newline-separated alpha values; omitted
            when blank.
        result_filename: Requested name of the results file.
        run_dir: Existing directory that receives this run's files.

    Returns:
        The config dictionary. ``output`` always points inside *run_dir*.

    Raises:
        ValueError: If a required value is missing or malformed, or the data
            source is unknown.
    """
    outcome_name = outcome.strip()
    if not outcome_name:
        raise ValueError("Outcome column is required.")
    # NOTE(review): a cleared numeric field presumably arrives as None; fail
    # with a readable message instead of float(None)'s TypeError.
    if beta is None:
        raise ValueError("Treatment budget beta is required.")

    config: dict[str, Any] = {
        "outcome": outcome_name,
        "strata": _parse_list(strata, "strata"),
        "beta": float(beta),
        "output": str(run_dir / _result_filename(result_filename)),
    }

    if data_source == SOURCE_CSV_UPLOAD:
        csv_path = _uploaded_path(uploaded_csv)
        if csv_path is None:
            raise ValueError("Upload a CSV file before calculating.")
        config["csv"] = csv_path
    elif data_source == SOURCE_CSV_PASTE:
        if not pasted_csv.strip():
            raise ValueError("Paste CSV data before calculating.")
        pasted_csv_path = run_dir / "input.csv"
        # Write with an explicit encoding so non-ASCII cell values do not
        # depend on the platform's default locale.
        pasted_csv_path.write_text(pasted_csv.strip() + "\n", encoding="utf-8")
        config["csv"] = str(pasted_csv_path)
    elif data_source == SOURCE_HF_DATASET:
        dataset = hf_dataset.strip()
        if not dataset:
            raise ValueError("Hugging Face dataset is required.")
        config["hf_dataset"] = dataset
        config["hf_split"] = hf_split.strip() or "train"
        revision = _optional_text(hf_revision)
        if revision is not None:
            config["hf_revision"] = revision
    else:
        raise ValueError(f"Unknown data source: {data_source}")

    # Optional keys are only emitted when the user supplied a value.
    prediction = _optional_text(prediction_col)
    if prediction is not None:
        config["prediction_col"] = prediction

    risk = _optional_text(risk_col)
    if risk is not None:
        config["risk_col"] = risk

    alpha_quantiles = _parse_optional_float_list(alpha_values)
    if alpha_quantiles is not None:
        config["alpha_quantiles"] = alpha_quantiles

    return config
152
+
153
+
154
def calculate_risk(
    data_source: str,
    uploaded_csv: Any,
    pasted_csv: str,
    hf_dataset: str,
    hf_split: str,
    hf_revision: str,
    outcome: str,
    strata: str,
    beta: float,
    prediction_col: str,
    risk_col: str,
    alpha_values: str,
    result_filename: str,
) -> tuple[str, dict[str, Any] | None, Any]:
    """Run one risk calculation from the form values.

    A fresh directory under the system temp dir is created per run; it
    receives the generated config, the pasted CSV (if any) and the results
    file. On success the directory is kept so the results can be downloaded;
    on failure it is removed.

    Returns:
        A 3-tuple ``(status_markdown, result_json_or_None, download_update)``
        matching the UI outputs. Errors are never raised to the caller; they
        are reported in the status message instead.
    """
    import shutil

    run_dir: Path | None = None
    try:
        run_dir = Path(tempfile.gettempdir()) / "risk-calculation" / uuid4().hex
        run_dir.mkdir(parents=True, exist_ok=True)

        config = _build_config(
            data_source=data_source,
            uploaded_csv=uploaded_csv,
            pasted_csv=pasted_csv,
            hf_dataset=hf_dataset,
            hf_split=hf_split,
            hf_revision=hf_revision,
            outcome=outcome,
            strata=strata,
            beta=beta,
            prediction_col=prediction_col,
            risk_col=risk_col,
            alpha_values=alpha_values,
            result_filename=result_filename,
            run_dir=run_dir,
        )

        config_path = run_dir / "risk-config.json"
        config_path.write_text(json.dumps(config, indent=2), encoding="utf-8")

        calculated_output_path = calculate_from_config(config_path)
        result = json.loads(calculated_output_path.read_text())
        return _result_summary(result, calculated_output_path), result, gr.update(
            value=str(calculated_output_path),
            interactive=True,
        )
    except Exception as exc:  # noqa: BLE001 - show validation/runtime errors in the interface.
        # Nothing from a failed run is offered for download, so do not leave
        # its directory (which may contain pasted user data) behind.
        if run_dir is not None:
            shutil.rmtree(run_dir, ignore_errors=True)
        return f"Calculation failed: `{exc}`", None, gr.update(value=None, interactive=False)
201
+
202
+
203
# UI definition. Left column: inputs; right column: status, download, results.
with gr.Blocks(title="Risk Calculation") as demo:
    gr.Markdown("# Risk Calculation")

    with gr.Row():
        with gr.Column(scale=2):
            data_source = gr.Radio(
                choices=[SOURCE_HF_DATASET, SOURCE_CSV_UPLOAD, SOURCE_CSV_PASTE],
                value=SOURCE_HF_DATASET,
                label="Data source",
            )
            # The two CSV inputs start hidden because the initial source is
            # the Hugging Face dataset.
            uploaded_csv = gr.File(
                label="Upload CSV",
                file_types=[".csv"],
                type="filepath",
                visible=False,
            )
            pasted_csv = gr.Textbox(
                label="Paste CSV",
                lines=8,
                max_lines=16,
                placeholder="risk,outcome,group\n0.9,1,a\n0.1,0,b",
                visible=False,
            )
            hf_dataset = gr.Textbox(
                value=DEFAULT_DATASET,
                label="Hugging Face dataset",
            )
            with gr.Row():
                hf_split = gr.Textbox(value=DEFAULT_SPLIT, label="Split")
                hf_revision = gr.Textbox(value="", label="Revision")

            outcome = gr.Textbox(value=DEFAULT_OUTCOME, label="Outcome column")
            strata = gr.Textbox(value=DEFAULT_STRATA, label="Strata columns")
            beta = gr.Number(
                value=DEFAULT_BETA,
                label="Treatment budget beta",
                minimum=0,
                maximum=1,
                step=0.01,
            )
            prediction_col = gr.Textbox(value="probability", label="Prediction column")
            risk_col = gr.Textbox(value="", label="Risk column")
            alpha_values = gr.Textbox(
                value=DEFAULT_ALPHA_VALUES,
                label="Alpha values",
            )
            result_filename = gr.Textbox(value="risk-results.json", label="Result file name")
            run_button = gr.Button("Calculate risk", variant="primary")

        with gr.Column(scale=3):
            status_output = gr.Markdown(label="Status")
            # Disabled until a run succeeds and provides a file to download.
            download_output = gr.DownloadButton(
                label="Download results JSON",
                value=None,
                interactive=False,
            )
            result_output = gr.JSON(label="Results")

    # Show only the inputs that belong to the selected data source. The output
    # order must match the tuple returned by _source_visibility.
    data_source.change(
        fn=_source_visibility,
        inputs=data_source,
        outputs=[uploaded_csv, pasted_csv, hf_dataset, hf_split, hf_revision],
        show_progress="hidden",
    )
    # The input order must match the positional parameters of calculate_risk.
    run_button.click(
        fn=calculate_risk,
        inputs=[
            data_source,
            uploaded_csv,
            pasted_csv,
            hf_dataset,
            hf_split,
            hf_revision,
            outcome,
            strata,
            beta,
            prediction_col,
            risk_col,
            alpha_values,
            result_filename,
        ],
        outputs=[status_output, result_output, download_output],
        api_name="calculate_risk",
    )


if __name__ == "__main__":
    demo.queue().launch()
configs/example-risk.yaml ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ hf_dataset: cmpatino/landmine-detection
2
+ hf_split: train
3
+ outcome: mines_outcome
4
+ strata:
5
+ - Municipio
6
+ beta: 0.1
7
+ alpha_quantiles: [0.0, 0.05, 0.1]
8
+ output: runs/example-risk-output.json
optimal_screening/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ """Camera-ready optimal screening code for paper replication."""
2
+
3
+ __all__ = []
optimal_screening/analysis/__init__.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .stratified import (
2
+ RISK_PRESETS,
3
+ SIMULATION_SIZE,
4
+ compute_empirical_probabilities,
5
+ compute_intuitive_optimal_curve,
6
+ compute_optimal_screening_curve,
7
+ compute_random_screening_curve,
8
+ generate_simulation_data,
9
+ )
10
+
11
+
12
+ __all__ = [
13
+ "RISK_PRESETS",
14
+ "SIMULATION_SIZE",
15
+ "compute_empirical_probabilities",
16
+ "compute_intuitive_optimal_curve",
17
+ "compute_optimal_screening_curve",
18
+ "compute_random_screening_curve",
19
+ "generate_simulation_data",
20
+ ]
optimal_screening/analysis/stratified.py ADDED
@@ -0,0 +1,674 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from collections import defaultdict
4
+ from collections.abc import Sequence
5
+ from typing import Any
6
+
7
+ import numpy as np
8
+
9
+
10
# Number of synthetic samples generated when a simulation is requested.
SIMULATION_SIZE = 100_000

# Named simulation presets. A tuple is the (a, b) parameter pair of a Beta
# distribution; a bare float is a point mass (every risk score equals it).
RISK_PRESETS: dict[str, tuple[float, float] | float] = {
    "uniform": (1.0, 1.0),
    "bimodal": (0.1, 0.1),
    "unimodal": (10.0, 10.0),
    "delta_half": 0.5,
}
18
+
19
+
20
def compute_empirical_probabilities(
    rows: list[dict[str, Any]],
    outcome_col: str,
    strata_features: Sequence[str],
) -> dict[tuple[Any, ...], dict[str, Any]]:
    """Estimate P(Y=1|X) per stratum as the observed share of positive outcomes.

    A stratum is identified by the tuple of values a row holds for
    *strata_features* (missing features count as ``None``).

    Args:
        rows: Data rows (dicts) containing the features and the outcome.
        outcome_col: Name of the outcome column (e.g., "PINCP > 50k").
        strata_features: Feature names that define a stratum (e.g., ['AGEP', 'SEX']).

    Returns:
        Mapping from stratum key to a dict with:
            'probability': (# positive rows) / (# rows) in the stratum,
            'count': number of rows in the stratum,
            'positive_count': number of positive rows,
            'features': feature name -> value for this stratum.
        Strata appear in the order they are first encountered.

    Example:
        >>> rows = [
        ...     {'AGEP': 35, 'SEX': 1, 'PINCP > 50k': True},
        ...     {'AGEP': 35, 'SEX': 1, 'PINCP > 50k': False},
        ...     {'AGEP': 35, 'SEX': 1, 'PINCP > 50k': True},
        ... ]
        >>> strata = compute_empirical_probabilities(rows, 'PINCP > 50k', ['AGEP', 'SEX'])
        >>> strata[(35, 1)]['probability']
        0.6666666666666666
        >>> strata[(35, 1)]['count']
        3
    """
    totals: dict[tuple[Any, ...], int] = {}
    positives: dict[tuple[Any, ...], int] = {}
    features_by_stratum: dict[tuple[Any, ...], dict[str, Any]] = {}

    for record in rows:
        key = tuple(record.get(name) for name in strata_features)

        if key not in totals:
            # First row of this stratum: initialise counters and remember
            # its feature values.
            totals[key] = 0
            positives[key] = 0
            features_by_stratum[key] = {name: record.get(name) for name in strata_features}

        totals[key] += 1
        if _is_positive_outcome(record.get(outcome_col)):
            positives[key] += 1

    return {
        key: {
            "probability": positives[key] / count if count > 0 else 0.0,
            "count": count,
            "positive_count": positives[key],
            "features": features_by_stratum[key],
        }
        for key, count in totals.items()
    }
86
+
87
+
88
+ def _is_positive_outcome(value: Any) -> bool:
89
+ """Helper to determine if outcome value represents Y=1."""
90
+ if value is None:
91
+ return False
92
+ if isinstance(value, bool):
93
+ return value
94
+ if isinstance(value, (int, float)):
95
+ return value > 0
96
+ if isinstance(value, str):
97
+ return value.lower() in ("true", "1", "yes", "t", "y")
98
+ return False
99
+
100
+
101
def generate_simulation_data(
    a: float | None = None,
    b: float | None = None,
    size: int = SIMULATION_SIZE,
    seed: int | None = None,
    point_mass: float | None = None,
) -> tuple[np.ndarray, np.ndarray]:
    """Generate synthetic risk scores and binary outcomes.

    Supports two modes:
    - **Beta-Binomial**: risk_scores ~ Beta(a, b), outcomes ~ Binomial(1, risk_scores).
    - **Point mass**: all risk_scores = *point_mass*, outcomes ~ Binomial(1, point_mass).

    Args:
        a: Alpha parameter of the Beta distribution (ignored when *point_mass* is set).
        b: Beta parameter of the Beta distribution (ignored when *point_mass* is set).
        size: Number of samples to generate.
        seed: Random seed for reproducibility.
        point_mass: If provided, every risk score is set to this constant value.

    Returns:
        Tuple of (risk_scores, outcomes), each an array of length *size*.

    Raises:
        ValueError: If *point_mass* is not given and *a* or *b* is missing.
    """
    rng = np.random.default_rng(seed)
    if point_mass is not None:
        # Force a float array so an integer point mass (0 or 1) still yields
        # float risk scores, like the Beta branch does.
        risk_scores = np.full(size, point_mass, dtype=float)
    else:
        if a is None or b is None:
            raise ValueError("Provide both 'a' and 'b' for the Beta distribution, or set 'point_mass'.")
        risk_scores = rng.beta(a, b, size=size)
    outcomes = rng.binomial(1, risk_scores)
    return risk_scores, outcomes
131
+
132
+
133
def compute_optimal_screening_curve(
    rows: list[dict[str, Any]],
    outcome_col: str,
    strata_features: Sequence[str],
    prediction_col: str = "probability",
    beta: float = 0.5,
    alpha_quantiles: Sequence[float] | None = None,
    max_iterations: int = 20,
    tolerance: float = 1e-6,
    seed: int | None = None,
    use_custom_risk_col: str | None = None,
    simulation: str | tuple[float, float] | None = None,
) -> dict[str, Any]:
    """Compute optimal screening curve with treatment budget β and screening budget α.

    Rows are sorted by risk (highest first) and split into four contiguous bands:
    - Band 1: top (β - α) of the population - treated
    - Band 2: next α × (1 - avg_risk(Band 3)) - treated
    - Band 3: next α - screened (true outcomes)
    - Band 4: the remainder - untreated (predict 0)

    The size of Band 2 depends on the average risk of Band 3, whose position
    in turn depends on the size of Band 2; this circular dependency is resolved
    by fixed-point iteration.

    Args:
        rows: List of data rows with features, outcome, and predictions
        outcome_col: Name of outcome column
        strata_features: Features defining strata for computing empirical P(Y=1|X)
        prediction_col: Column name for model predictions. The value is stored
            with each row but is not used to place the bands.
        beta: Treatment budget (proportion who can be treated)
        alpha_quantiles: Screening budget levels to evaluate. Defaults to 50
            equally spaced values from 0 to beta.
        max_iterations: Maximum iterations for convergence
        tolerance: Convergence tolerance for avg_risk
        seed: Random seed forwarded to the simulation data generator; unused
            when *simulation* is None.
        use_custom_risk_col: If provided, use this column for risk instead of computing
            empirical probabilities from strata. Useful for comparing LLM predictions
            with empirical baselines.
        simulation: If provided, generate synthetic data instead of using real
            data. Pass a key of RISK_PRESETS or a tuple (a, b) of Beta distribution
            parameters. Uses SIMULATION_SIZE samples.

    Returns:
        Dictionary with keys 'beta', 'alpha_values', 'true_positives' (made
        non-decreasing across the evaluated alphas), 'band_info' (one dict per
        alpha), 'total_positive' and 'total_samples'.

    Raises:
        ValueError: If the simulation preset is unknown, there are no rows to
            evaluate, or an alpha exceeds beta.
    """
    if alpha_quantiles is None:
        # Default: 50 equally spaced values from 0 to beta
        alpha_quantiles = [beta * i / 49 for i in range(50)]

    # Assign each row its risk (simulation, custom, or empirical)
    rows_with_risk = []

    if simulation is not None:
        # Generate synthetic data from a preset or explicit parameters
        if isinstance(simulation, str):
            if simulation not in RISK_PRESETS:
                raise ValueError(f"Unknown simulation preset '{simulation}'. Choose from {list(RISK_PRESETS.keys())}.")
            preset = RISK_PRESETS[simulation]
        else:
            preset = simulation

        if isinstance(preset, (int, float)):
            # A bare number is a point mass: every risk score equals it.
            risk_scores, outcomes = generate_simulation_data(size=SIMULATION_SIZE, seed=seed, point_mass=float(preset))
        else:
            a, b = preset
            risk_scores, outcomes = generate_simulation_data(a, b, size=SIMULATION_SIZE, seed=seed)

        for i in range(SIMULATION_SIZE):
            rows_with_risk.append(
                {
                    "row": {"_sim_index": i, "_sim_feature": 0, outcome_col: bool(outcomes[i])},
                    "empirical_risk": float(risk_scores[i]),
                    "true_outcome": bool(outcomes[i]),
                    "model_prediction": float(risk_scores[i]),
                }
            )
    elif use_custom_risk_col is not None:
        # Use custom risk column directly
        for row in rows:
            risk = row.get(use_custom_risk_col, 0.5)
            rows_with_risk.append(
                {
                    "row": row,
                    "empirical_risk": risk,
                    "true_outcome": _is_positive_outcome(row.get(outcome_col)),
                    "model_prediction": row.get(prediction_col, 0.5),
                }
            )
    else:
        # Compute empirical P(Y=1|X) for each stratum
        empirical_probs = compute_empirical_probabilities(rows, outcome_col, strata_features)

        for row in rows:
            stratum_key = tuple(row.get(f) for f in strata_features)
            empirical_risk = empirical_probs.get(stratum_key, {}).get("probability", 0.5)

            rows_with_risk.append(
                {
                    "row": row,
                    "empirical_risk": empirical_risk,
                    "true_outcome": _is_positive_outcome(row.get(outcome_col)),
                    "model_prediction": row.get(prediction_col, 0.5),
                }
            )

    # Sort by risk (highest to lowest)
    rows_with_risk.sort(key=lambda x: x["empirical_risk"], reverse=True)

    total_positive = sum(1 for r in rows_with_risk if r["true_outcome"])
    n = len(rows_with_risk)
    if n == 0:
        # Every band proportion below divides by n.
        raise ValueError("Cannot compute a screening curve from an empty dataset.")

    # Each person contributes 1/n to the population mass.
    population_contribution = 1.0 / n

    # Results storage
    results = {
        "beta": beta,
        "alpha_values": [],
        "true_positives": [],
        "band_info": [],
        "total_positive": total_positive,
        "total_samples": n,
    }

    for alpha in alpha_quantiles:
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if not alpha <= beta:
            raise ValueError(f"Screening budget α={alpha} exceeds treatment budget β={beta}")

        # Iteratively find Band 3 position.
        prev_avg_risk = 0.0

        for _iteration in range(max_iterations):
            # Target masses are proportions of the population.
            band1_target_mass = beta - alpha
            # Band 2 has the mass of the expected negatives in Band 3:
            # Band 3 has mass α and average risk prev_avg_risk, hence α × (1 - prev_avg_risk).
            band2_target_mass = alpha * (1 - prev_avg_risk)
            band3_target_mass = alpha

            # Band 1: first index at which the cumulative mass reaches the target.
            cumulative_mass = 0.0
            band1_end_idx = 0
            for i in range(n):
                cumulative_mass += population_contribution
                if cumulative_mass >= band1_target_mass:
                    band1_end_idx = i + 1
                    break
            if band1_end_idx == 0 and band1_target_mass > 0:
                band1_end_idx = 1  # At least one person

            # Band 2: Continue from Band 1 end
            target_mass_band1_plus_band2 = band1_target_mass + band2_target_mass
            band2_end_idx = band1_end_idx
            for i in range(band1_end_idx, n):
                cumulative_mass += population_contribution
                if cumulative_mass >= target_mass_band1_plus_band2:
                    band2_end_idx = i + 1
                    break

            # Band 3: Continue from Band 2 end
            target_mass_band1_plus_band2_plus_band3 = band1_target_mass + band2_target_mass + band3_target_mass
            band3_end_idx = band2_end_idx
            for i in range(band2_end_idx, n):
                cumulative_mass += population_contribution
                if cumulative_mass >= target_mass_band1_plus_band2_plus_band3:
                    band3_end_idx = i + 1
                    break

            # Ensure indices are within bounds
            band1_end_idx = min(band1_end_idx, n)
            band2_end_idx = min(band2_end_idx, n)
            band3_end_idx = min(band3_end_idx, n)

            # Compute average risk of Band 3
            if band3_end_idx > band2_end_idx:
                band3_risks = [rows_with_risk[i]["empirical_risk"] for i in range(band2_end_idx, band3_end_idx)]
                current_avg_risk = np.mean(band3_risks) if band3_risks else 0.0
            else:
                current_avg_risk = 0.0

            # Check convergence
            if abs(current_avg_risk - prev_avg_risk) < tolerance:
                break

            prev_avg_risk = current_avg_risk

        # The band indices from the last iteration are kept; the reported
        # average risk is the one that iteration's Band 2 size was based on.
        avg_risk_band3 = prev_avg_risk

        # Compute integrals: ∫ risk × (1/n) dx for each band (for reporting purposes)
        band1_integral = sum(rows_with_risk[i]["empirical_risk"] / n for i in range(0, band1_end_idx))
        band2_integral = sum(rows_with_risk[i]["empirical_risk"] / n for i in range(band1_end_idx, band2_end_idx))
        band3_integral = sum(rows_with_risk[i]["empirical_risk"] / n for i in range(band2_end_idx, band3_end_idx))
        band4_integral = sum(rows_with_risk[i]["empirical_risk"] / n for i in range(band3_end_idx, n))

        # Population proportion (mass) of each band
        band1_pop_prop = band1_end_idx / n
        band2_pop_prop = (band2_end_idx - band1_end_idx) / n
        band3_pop_prop = (band3_end_idx - band2_end_idx) / n
        band4_pop_prop = (n - band3_end_idx) / n

        # Expected negatives in Band 3: ∫ (1 - risk) × f(risk) d(risk) over Band 3
        band3_expected_negatives = sum(
            (1 - rows_with_risk[i]["empirical_risk"]) / n for i in range(band2_end_idx, band3_end_idx)
        )

        # True positives are counted in Bands 1-3, which are contiguous and
        # together cover indices [0, band3_end_idx). Band 4 is untreated and
        # contributes none.
        tp_count = sum(1 for i in range(band3_end_idx) if rows_with_risk[i]["true_outcome"])

        results["alpha_values"].append(alpha)
        # Enforce monotonicity: TP can never decrease as screening budget grows
        tp_count = max(tp_count, results["true_positives"][-1] if results["true_positives"] else 0)
        results["true_positives"].append(tp_count)
        results["band_info"].append(
            {
                "alpha": alpha,
                "band1_integral": band1_integral,
                "band2_integral": band2_integral,
                "band3_integral": band3_integral,
                "band4_integral": band4_integral,
                "band1_pop_prop": band1_pop_prop,
                "band2_pop_prop": band2_pop_prop,
                "band3_pop_prop": band3_pop_prop,
                "band4_pop_prop": band4_pop_prop,
                "band3_expected_negatives": band3_expected_negatives,
                "avg_risk_band3": avg_risk_band3,
                "band1_end_idx": band1_end_idx,
                "band2_end_idx": band2_end_idx,
                "band3_end_idx": band3_end_idx,
            }
        )

    return results
391
+
392
+
393
def compute_random_screening_curve(
    rows: list[dict[str, Any]],
    outcome_col: str,
    strata_features: Sequence[str],
    prediction_col: str = "probability",
    beta: float = 0.5,
    alpha_quantiles: Sequence[float] | None = None,
    seed: int = 42,
    use_custom_risk_col: str | None = None,
    simulation: str | tuple[float, float] | None = None,
) -> dict[str, Any]:
    """Compute random screening baseline curve.

    This baseline screens α proportion of the population uniformly at random
    (instead of targeting a risk band). For each α it:
    1. Treats the screened individuals with Y=1, up to the treatment budget.
    2. Spends the remaining budget on the highest-risk individuals among
       everyone who is not a screened positive.

    A single random permutation is used for all α values, so the screened
    sets are nested (a larger α always includes the smaller α set).

    Args:
        rows: List of data rows with features, outcome, and predictions
        outcome_col: Name of outcome column
        strata_features: Features defining strata (used for risk scoring)
        prediction_col: Column name for model predictions. The value is stored
            with each row but is not used to select who is treated.
        beta: Treatment budget (proportion who can be treated)
        alpha_quantiles: Screening budget levels to evaluate. Defaults to 50
            equally spaced values from 0 to beta.
        seed: Random seed for the random screening order; also forwarded to the
            simulation data generator when *simulation* is set.
        use_custom_risk_col: If provided, use this column for risk instead of empirical
        simulation: If provided, generate synthetic data instead of using real
            data. Pass a key of RISK_PRESETS or a tuple (a, b) of Beta distribution
            parameters. Uses SIMULATION_SIZE samples.

    Returns:
        Dictionary with keys 'beta', 'alpha_values', 'true_positives',
        'total_positive' and 'total_samples'.

    Raises:
        ValueError: If the simulation preset is unknown or an alpha exceeds beta.
    """
    if alpha_quantiles is None:
        alpha_quantiles = [beta * i / 49 for i in range(50)]

    # Assign each row its risk (simulation, custom, or empirical)
    rows_with_risk = []

    if simulation is not None:
        # Generate synthetic data from a preset or explicit parameters
        if isinstance(simulation, str):
            if simulation not in RISK_PRESETS:
                raise ValueError(f"Unknown simulation preset '{simulation}'. Choose from {list(RISK_PRESETS.keys())}.")
            preset = RISK_PRESETS[simulation]
        else:
            preset = simulation

        if isinstance(preset, (int, float)):
            # A bare number is a point mass: every risk score equals it.
            risk_scores, outcomes = generate_simulation_data(size=SIMULATION_SIZE, seed=seed, point_mass=float(preset))
        else:
            a, b = preset
            risk_scores, outcomes = generate_simulation_data(a, b, size=SIMULATION_SIZE, seed=seed)

        for i in range(SIMULATION_SIZE):
            rows_with_risk.append(
                {
                    "row": {"_sim_index": i, "_sim_feature": 0, outcome_col: bool(outcomes[i])},
                    "empirical_risk": float(risk_scores[i]),
                    "true_outcome": bool(outcomes[i]),
                    "model_prediction": float(risk_scores[i]),
                }
            )
    elif use_custom_risk_col is not None:
        # Use custom risk column directly
        for row in rows:
            risk = row.get(use_custom_risk_col, 0.5)
            rows_with_risk.append(
                {
                    "row": row,
                    "empirical_risk": risk,
                    "true_outcome": _is_positive_outcome(row.get(outcome_col)),
                    "model_prediction": row.get(prediction_col, 0.5),
                }
            )
    else:
        # Compute empirical P(Y=1|X) for each stratum
        empirical_probs = compute_empirical_probabilities(rows, outcome_col, strata_features)

        for row in rows:
            stratum_key = tuple(row.get(f) for f in strata_features)
            empirical_risk = empirical_probs.get(stratum_key, {}).get("probability", 0.5)

            rows_with_risk.append(
                {
                    "row": row,
                    "empirical_risk": empirical_risk,
                    "true_outcome": _is_positive_outcome(row.get(outcome_col)),
                    "model_prediction": row.get(prediction_col, 0.5),
                }
            )

    total_positive = sum(1 for r in rows_with_risk if r["true_outcome"])
    n = len(rows_with_risk)

    # Results storage
    results = {
        "beta": beta,
        "alpha_values": [],
        "true_positives": [],
        "total_positive": total_positive,
        "total_samples": n,
    }

    # Set random seed for reproducibility — use a single permutation so that
    # screened sets are nested (larger α always includes the smaller α set).
    rng = np.random.RandomState(seed)
    random_order = rng.permutation(n)

    # Rank the whole population by risk once. The sort is stable, so removing
    # the screened positives from this ranking gives exactly the order that
    # re-sorting the reduced pool for every α would give.
    ranked_indices = sorted(range(n), key=lambda idx: rows_with_risk[idx]["empirical_risk"], reverse=True)

    # The treatment budget does not depend on α.
    n_treat = int(beta * n)

    for alpha in alpha_quantiles:
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if not alpha <= beta:
            raise ValueError(f"Screening budget α={alpha} exceeds treatment budget β={beta}")

        # Screen α proportion uniformly at random
        n_screen = min(int(alpha * n), n)
        screened_indices = set(random_order[:n_screen])

        # Identify screened positives (gamma mass)
        screened_positive_indices = {idx for idx in screened_indices if rows_with_risk[idx]["true_outcome"]}
        gamma_count = len(screened_positive_indices)

        # Treat screened positives up to budget
        tp_from_screening = min(gamma_count, n_treat)
        remaining_budget = max(0, n_treat - tp_from_screening)

        # Pool for risk-based treatment: everyone except screened positives,
        # highest risk first. Slicing caps the selection at the pool size.
        pool = [idx for idx in ranked_indices if idx not in screened_positive_indices]
        tp_from_risk = sum(1 for idx in pool[:remaining_budget] if rows_with_risk[idx]["true_outcome"])

        tp_count = tp_from_screening + tp_from_risk
        results["alpha_values"].append(alpha)
        results["true_positives"].append(tp_count)

    return results
535
+
536
+
537
def compute_intuitive_optimal_curve(
    rows: list[dict[str, Any]],
    outcome_col: str,
    strata_features: Sequence[str],
    prediction_col: str = "probability",
    beta: float = 0.5,
    alpha_quantiles: Sequence[float] | None = None,
    seed: int | None = None,
    use_custom_risk_col: str | None = None,
    simulation: str | tuple[float, float] | None = None,
) -> dict[str, Any]:
    """Compute intuitive-optimal screening curve.

    Algorithm (all bands are adjacent slices of the risk-sorted population):
      1. Band A: treat the top (β − α) mass by risk (highest risk, no screening).
      2. Band B: screen the next α mass.  Let γ ≤ α be the mass of screened
         individuals with Y=0.  Screened Y=1 are treated; screened Y=0 are not.
      3. Band C: treat the next γ mass below the screened band (replaces the
         screened negatives, preserving total treatment budget = β).

    Because the three bands are contiguous and every member of them is either
    treated or a screened negative, the true-positive count for a given α is
    simply the number of positives among the first ``band_c_end`` individuals
    of the risk-sorted population; it is read from a prefix-sum table.

    Args:
        rows: List of data rows (ignored when *simulation* is set).
        outcome_col: Name of outcome column.
        strata_features: Features defining strata (used only when neither
            *simulation* nor *use_custom_risk_col* is given).
        prediction_col: Column name for model predictions.  Kept for
            signature compatibility; the returned curve does not depend on it.
        beta: Treatment budget (proportion who can be treated).
        alpha_quantiles: Screening budget levels to evaluate.  Defaults to 50
            evenly spaced values from 0 to *beta* inclusive.
        seed: Random seed forwarded to ``generate_simulation_data``.
        use_custom_risk_col: Use this column for risk instead of empirical
            stratum probabilities (missing values default to 0.5).
        simulation: Key of ``RISK_PRESETS`` or explicit parameters.  A pair
            ``(a, b)`` is forwarded positionally to
            ``generate_simulation_data``; a single number is forwarded as
            ``point_mass``.

    Returns:
        Dictionary with beta, alpha_values, true_positives, total_positive,
        total_samples.

    Raises:
        ValueError: If *simulation* names an unknown preset.
        AssertionError: If any screening budget exceeds *beta*.
    """
    if alpha_quantiles is None:
        # The endpoint is written as `beta` itself: `beta * 49 / 49` is not
        # guaranteed to round back to exactly `beta`, and an overshoot of one
        # ulp would fail the α ≤ β check below.
        alpha_quantiles = [beta * i / 49 for i in range(49)] + [beta]

    # --- Build (risk, outcome) pairs; nothing else about a row is needed ---
    if simulation is not None:
        if isinstance(simulation, str):
            if simulation not in RISK_PRESETS:
                raise ValueError(f"Unknown simulation preset '{simulation}'. Choose from {list(RISK_PRESETS.keys())}.")
            preset = RISK_PRESETS[simulation]
        else:
            preset = simulation

        if isinstance(preset, (int, float)):
            risk_scores, outcomes = generate_simulation_data(size=SIMULATION_SIZE, seed=seed, point_mass=float(preset))
        else:
            a, b = preset
            risk_scores, outcomes = generate_simulation_data(a, b, size=SIMULATION_SIZE, seed=seed)

        scored = [(float(risk_scores[i]), bool(outcomes[i])) for i in range(SIMULATION_SIZE)]
    elif use_custom_risk_col is not None:
        scored = [
            (row.get(use_custom_risk_col, 0.5), _is_positive_outcome(row.get(outcome_col)))
            for row in rows
        ]
    else:
        empirical_probs = compute_empirical_probabilities(rows, outcome_col, strata_features)
        scored = [
            (
                empirical_probs.get(tuple(row.get(f) for f in strata_features), {}).get("probability", 0.5),
                _is_positive_outcome(row.get(outcome_col)),
            )
            for row in rows
        ]

    # Sort by risk (highest to lowest).  The sort is stable, so ties keep
    # their input order.
    scored.sort(key=lambda pair: pair[0], reverse=True)

    # positives_before[k] == number of positives among the k highest-risk rows.
    positives_before = [0]
    for _, outcome in scored:
        positives_before.append(positives_before[-1] + (1 if outcome else 0))

    n = len(scored)
    total_positive = positives_before[-1]

    results = {
        "beta": beta,
        "alpha_values": [],
        "true_positives": [],
        "total_positive": total_positive,
        "total_samples": n,
    }

    for alpha in alpha_quantiles:
        # Raised explicitly rather than via `assert` so the check survives
        # `python -O`; the exception type is unchanged for existing callers.
        if not alpha <= beta:
            raise AssertionError(f"Screening budget α={alpha} exceeds treatment budget β={beta}")

        # Clamp band boundaries to [0, n] so float rounding (or beta > 1)
        # can never index past the population.
        band_a_end = min(max(int((beta - alpha) * n), 0), n)
        band_b_end = min(max(band_a_end + int(alpha * n), band_a_end), n)

        # Band B: screened negatives free up budget for band C.
        screened_positives = positives_before[band_b_end] - positives_before[band_a_end]
        gamma_count = (band_b_end - band_a_end) - screened_positives

        # Band C: next gamma_count individuals treated by risk.
        band_c_end = min(band_b_end + gamma_count, n)

        # Positives in bands A, B and C together.
        tp_count = positives_before[band_c_end]
        results["alpha_values"].append(alpha)
        results["true_positives"].append(tp_count)

    return results
optimal_screening/cli/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ """Command-line entry points for the camera-ready paper code."""
2
+
3
+ __all__ = []
optimal_screening/cli/calculate_risk.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import argparse
4
+ import json
5
+ from pathlib import Path
6
+ from typing import Any
7
+
8
+ import yaml
9
+
10
+ from optimal_screening.analysis import compute_optimal_screening_curve
11
+ from optimal_screening.data_sources import load_dataframe
12
+
13
+
14
# Config keys that must be present; _validate_config rejects a config missing any of them.
REQUIRED_FIELDS = {"outcome", "strata", "beta"}
15
+
16
+
17
+ def _read_config(path: Path) -> dict[str, Any]:
18
+ if not path.exists():
19
+ raise FileNotFoundError(f"Config file not found: {path}")
20
+
21
+ text = path.read_text()
22
+ if path.suffix.lower() == ".json":
23
+ data = json.loads(text)
24
+ elif path.suffix.lower() in {".yaml", ".yml"}:
25
+ data = yaml.safe_load(text)
26
+ else:
27
+ raise ValueError("Config file must be YAML or JSON")
28
+
29
+ if not isinstance(data, dict):
30
+ raise ValueError("Config must be a mapping")
31
+ return data
32
+
33
+
34
+ def _as_float_sequence(values: Any, field: str) -> list[float] | None:
35
+ if values is None:
36
+ return None
37
+ if not isinstance(values, list | tuple):
38
+ raise ValueError(f"{field} must be a list of numbers")
39
+ return [float(value) for value in values]
40
+
41
+
42
def _validate_config(config: dict[str, Any]) -> dict[str, Any]:
    """Validate a raw config mapping and return it in normalised form.

    Args:
        config: Parsed config mapping.

    Returns:
        A dict with the keys ``csv``, ``hf_dataset``, ``hf_split``,
        ``hf_revision``, ``outcome``, ``strata``, ``beta``,
        ``prediction_col``, ``risk_col``, ``alpha_quantiles`` and ``output``.
        Optional keys that are absent *or explicitly null* receive their
        defaults (``hf_split="train"``, ``prediction_col="probability"``,
        ``output="runs/optimal_screening_curve.json"``) or ``None``.

    Raises:
        ValueError: If a required field is missing, the data source is not
            exactly one of ``csv``/``hf_dataset``, ``strata`` is not a
            non-empty list of strings, ``beta`` is not a number in (0, 1],
            or ``alpha_quantiles`` contains a value outside [0, beta].
    """

    def _text(key: str, default: str | None = None) -> str | None:
        # Treat an explicit null like a missing key; str(None) would otherwise
        # turn it into the literal text "None".
        value = config.get(key)
        return default if value is None else str(value)

    missing = sorted(REQUIRED_FIELDS - set(config))
    if missing:
        raise ValueError(f"Missing required config fields: {missing}")

    has_csv = config.get("csv") is not None
    has_hf_dataset = config.get("hf_dataset") is not None
    if has_csv == has_hf_dataset:
        raise ValueError("Config must provide exactly one data source: csv or hf_dataset")

    strata = config["strata"]
    if not isinstance(strata, list) or not strata or not all(isinstance(item, str) for item in strata):
        raise ValueError("strata must be a non-empty list of column names")

    try:
        beta = float(config["beta"])
    except (TypeError, ValueError) as exc:
        raise ValueError("beta must be in the interval (0, 1]") from exc
    if not 0 < beta <= 1:
        raise ValueError("beta must be in the interval (0, 1]")

    alpha_quantiles = _as_float_sequence(config.get("alpha_quantiles"), "alpha_quantiles")
    if alpha_quantiles is not None:
        # Written as a negated chained comparison so NaN is rejected too.
        invalid = [alpha for alpha in alpha_quantiles if not 0 <= alpha <= beta]
        if invalid:
            raise ValueError(f"alpha_quantiles must be between 0 and beta={beta}; invalid values: {invalid}")

    return {
        "csv": _text("csv"),
        "hf_dataset": _text("hf_dataset"),
        "hf_split": _text("hf_split", "train"),
        "hf_revision": _text("hf_revision"),
        "outcome": str(config["outcome"]),
        "strata": strata,
        "beta": beta,
        "prediction_col": _text("prediction_col", "probability"),
        "risk_col": _text("risk_col"),
        "alpha_quantiles": alpha_quantiles,
        "output": _text("output", "runs/optimal_screening_curve.json"),
    }
79
+
80
+
81
+ def _json_safe(value: Any) -> Any:
82
+ if isinstance(value, dict):
83
+ return {key: _json_safe(item) for key, item in value.items()}
84
+ if isinstance(value, list | tuple):
85
+ return [_json_safe(item) for item in value]
86
+ if hasattr(value, "item"):
87
+ return value.item()
88
+ return value
89
+
90
+
91
def calculate_from_config(config_path: Path) -> Path:
    """Run the optimal-screening computation described by a config file.

    Reads and validates the config, loads the dataset it points to, checks
    that the needed columns exist, computes the curve and writes the result
    as indented JSON.

    Args:
        config_path: Path to a YAML or JSON config file.

    Returns:
        Path of the JSON file that was written (the config's ``output``).

    Raises:
        ValueError: If the loaded data lacks the outcome column, a strata
            column, or the configured risk column.  Config reading,
            validation and data loading may raise their own errors.
    """
    config = _validate_config(_read_config(config_path))

    df, dataset_label = load_dataframe(
        csv_path=config["csv"],
        hf_dataset=config["hf_dataset"],
        hf_split=config["hf_split"],
        hf_revision=config["hf_revision"],
    )

    # prediction_col is not checked here: it was only ever added to the
    # required set when already present in the frame, so it could never be
    # reported missing.
    required_cols = {config["outcome"], *config["strata"]}
    if config["risk_col"]:
        required_cols.add(config["risk_col"])

    missing_cols = sorted(required_cols - set(df.columns))
    if missing_cols:
        raise ValueError(f"Missing required columns in {dataset_label}: {missing_cols}")

    result = compute_optimal_screening_curve(
        rows=df.to_dict("records"),
        outcome_col=config["outcome"],
        strata_features=config["strata"],
        prediction_col=config["prediction_col"],
        beta=config["beta"],
        alpha_quantiles=config["alpha_quantiles"],
        use_custom_risk_col=config["risk_col"],
    )

    output_path = Path(config["output"])
    output_path.parent.mkdir(parents=True, exist_ok=True)
    output_path.write_text(json.dumps(_json_safe(result), indent=2), encoding="utf-8")
    return output_path
125
+
126
+
127
def main() -> None:
    """CLI entry point: compute the curve for one config file and report the output path."""
    cli = argparse.ArgumentParser(description="Compute an optimal screening curve from a YAML or JSON config")
    cli.add_argument("config", help="Path to a YAML or JSON config file")
    config_file = Path(cli.parse_args().config)

    written_to = calculate_from_config(config_file)
    print(f"Wrote {written_to}")
134
+
135
+
136
+ if __name__ == "__main__":
137
+ main()
optimal_screening/data_sources.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from pathlib import Path
4
+
5
+ import pandas as pd
6
+ from datasets import load_dataset
7
+
8
+
9
def load_hf_dataframe(dataset: str, split: str = "train", revision: str | None = None) -> pd.DataFrame:
    """Load one split of a Hugging Face dataset and return it as a pandas DataFrame.

    ``revision`` is forwarded to ``load_dataset`` only when it is not ``None``.
    """
    if revision is None:
        loaded = load_dataset(path=dataset, split=split)
    else:
        loaded = load_dataset(path=dataset, split=split, revision=revision)
    return loaded.to_pandas()
15
+
16
+
17
def load_dataframe(
    *,
    csv_path: str | None = None,
    hf_dataset: str | None = None,
    hf_split: str = "train",
    hf_revision: str | None = None,
) -> tuple[pd.DataFrame, str]:
    """Load a DataFrame from a CSV file or a Hugging Face dataset.

    Exactly one of *csv_path* and *hf_dataset* must be given.

    Returns:
        ``(frame, label)`` where *label* is the dataset name for a Hugging
        Face source or the CSV path as a string.

    Raises:
        ValueError: If zero or both sources are provided.
        FileNotFoundError: If *csv_path* does not exist.
    """
    provided = (csv_path is not None) + (hf_dataset is not None)
    if provided != 1:
        raise ValueError("Provide exactly one data source: csv_path or hf_dataset")

    if hf_dataset is None:
        # Only the CSV source remains; the assert narrows the type.
        assert csv_path is not None
        csv_file = Path(csv_path)
        if csv_file.exists():
            return pd.read_csv(csv_file), str(csv_file)
        raise FileNotFoundError(f"CSV file not found: {csv_file}")

    frame = load_hf_dataframe(hf_dataset, split=hf_split, revision=hf_revision)
    return frame, hf_dataset
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ gradio==6.14.0
2
+ PyYAML>=6.0
3
+ datasets>=2.18
4
+ numpy>=1.24
5
+ pandas>=2.0