cmpatino HF Staff committed on
Commit
09f8c96
·
1 Parent(s): 2ebff59

Update app for optimal screening CSV output

Browse files
README.md CHANGED
@@ -1,5 +1,5 @@
1
  ---
2
- title: Optimal Screening Calculation
3
  emoji: 🐠
4
  colorFrom: yellow
5
  colorTo: indigo
@@ -10,12 +10,12 @@ app_file: app.py
10
  pinned: false
11
  ---
12
 
13
- Gradio app for computing optimal screening risk curves from form inputs.
14
 
15
  The default form values mirror:
16
 
17
  ```bash
18
- uv run calculate-risk configs/example-risk.yaml
19
  ```
20
 
21
- The app accepts a Hugging Face dataset, an uploaded CSV, or pasted CSV rows. Each run writes a temporary JSON file and exposes it through the download button.
 
1
  ---
2
+ title: Optimal Screening Decisions
3
  emoji: 🐠
4
  colorFrom: yellow
5
  colorTo: indigo
 
10
  pinned: false
11
  ---
12
 
13
+ Gradio app for writing optimal screening decisions from form inputs.
14
 
15
  The default form values mirror:
16
 
17
  ```bash
18
+ uv run get-optimal-screening configs/example-risk.yaml
19
  ```
20
 
21
+ The app accepts a Hugging Face dataset, an uploaded CSV, or pasted CSV rows. Each run writes a temporary CSV file with an added screening decision column and exposes it through the download button.
app.py CHANGED
@@ -7,11 +7,11 @@ from typing import Any
7
  from uuid import uuid4
8
 
9
  import gradio as gr
 
10
 
11
- from optimal_screening.cli.calculate_risk import calculate_from_config
12
 
13
 
14
- ROOT = Path(__file__).parent
15
  SOURCE_CSV_UPLOAD = "Upload CSV"
16
  SOURCE_CSV_PASTE = "Paste CSV"
17
  SOURCE_HF_DATASET = "Hugging Face dataset"
@@ -21,21 +21,8 @@ DEFAULT_SPLIT = "train"
21
  DEFAULT_OUTCOME = "mines_outcome"
22
  DEFAULT_STRATA = "Municipio"
23
  DEFAULT_BETA = 0.1
24
- DEFAULT_ALPHA_VALUES = "0.0, 0.05, 0.1"
25
-
26
-
27
- def _result_summary(result: dict[str, Any], output_path: Path) -> str:
28
- alpha_count = len(result.get("alpha_values", []))
29
- total_samples = result.get("total_samples", "unknown")
30
- total_positive = result.get("total_positive", "unknown")
31
- beta = result.get("beta", "unknown")
32
-
33
- return (
34
- f"Computed `{alpha_count}` alpha point(s) with beta `{beta}`.\n\n"
35
- f"Total samples: `{total_samples}` \n"
36
- f"Total positives: `{total_positive}` \n"
37
- f"Output file: `{output_path.name}`"
38
- )
39
 
40
 
41
  def _uploaded_path(uploaded_csv: Any) -> str | None:
@@ -59,12 +46,6 @@ def _parse_list(value: str, field: str) -> list[str]:
59
  return values
60
 
61
 
62
- def _parse_optional_float_list(value: str) -> list[float] | None:
63
- if not value.strip():
64
- return None
65
- return [float(item) for item in _parse_list(value, "alpha values")]
66
-
67
-
68
  def _optional_text(value: str | None) -> str | None:
69
  if value is None:
70
  return None
@@ -72,10 +53,10 @@ def _optional_text(value: str | None) -> str | None:
72
  return value or None
73
 
74
 
75
- def _result_filename(value: str | None) -> str:
76
- filename = Path(value.strip()).name if value and value.strip() else "risk-results.json"
77
- if not filename.endswith(".json"):
78
- filename = f"{filename}.json"
79
  return filename
80
 
81
 
@@ -100,27 +81,29 @@ def _build_config(
100
  outcome: str,
101
  strata: str,
102
  beta: float,
 
103
  prediction_col: str,
104
  risk_col: str,
105
- alpha_values: str,
106
- result_filename: str,
107
  run_dir: Path,
108
  ) -> dict[str, Any]:
109
  config: dict[str, Any] = {
110
  "outcome": outcome.strip(),
111
  "strata": _parse_list(strata, "strata"),
112
  "beta": float(beta),
113
- "output": str(run_dir / _result_filename(result_filename)),
 
114
  }
115
 
116
  if data_source == SOURCE_CSV_UPLOAD:
117
  csv_path = _uploaded_path(uploaded_csv)
118
  if csv_path is None:
119
- raise ValueError("Upload a CSV file before calculating.")
120
  config["csv"] = csv_path
121
  elif data_source == SOURCE_CSV_PASTE:
122
  if not pasted_csv.strip():
123
- raise ValueError("Paste CSV data before calculating.")
124
  pasted_csv_path = run_dir / "input.csv"
125
  pasted_csv_path.write_text(pasted_csv.strip() + "\n")
126
  config["csv"] = str(pasted_csv_path)
@@ -144,14 +127,30 @@ def _build_config(
144
  if risk is not None:
145
  config["risk_col"] = risk
146
 
147
- alpha_quantiles = _parse_optional_float_list(alpha_values)
148
- if alpha_quantiles is not None:
149
- config["alpha_quantiles"] = alpha_quantiles
150
 
151
  return config
152
 
153
 
154
- def calculate_risk(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
155
  data_source: str,
156
  uploaded_csv: Any,
157
  pasted_csv: str,
@@ -161,13 +160,14 @@ def calculate_risk(
161
  outcome: str,
162
  strata: str,
163
  beta: float,
 
164
  prediction_col: str,
165
  risk_col: str,
166
- alpha_values: str,
167
- result_filename: str,
168
- ) -> tuple[str, dict[str, Any] | None, Any]:
169
  try:
170
- run_dir = Path(tempfile.gettempdir()) / "risk-calculation" / uuid4().hex
171
  run_dir.mkdir(parents=True, exist_ok=True)
172
 
173
  config = _build_config(
@@ -180,28 +180,26 @@ def calculate_risk(
180
  outcome=outcome,
181
  strata=strata,
182
  beta=beta,
 
183
  prediction_col=prediction_col,
184
  risk_col=risk_col,
185
- alpha_values=alpha_values,
186
- result_filename=result_filename,
187
  run_dir=run_dir,
188
  )
189
 
190
- config_path = run_dir / "risk-config.json"
191
  config_path.write_text(json.dumps(config, indent=2))
192
 
193
- calculated_output_path = calculate_from_config(config_path)
194
- result = json.loads(calculated_output_path.read_text())
195
- return _result_summary(result, calculated_output_path), result, gr.update(
196
- value=str(calculated_output_path),
197
- interactive=True,
198
- )
199
  except Exception as exc: # noqa: BLE001 - show validation/runtime errors in the interface.
200
- return f"Calculation failed: `{exc}`", None, gr.update(value=None, interactive=False)
201
 
202
 
203
- with gr.Blocks(title="Risk Calculation") as demo:
204
- gr.Markdown("# Risk Calculation")
205
 
206
  with gr.Row():
207
  with gr.Column(scale=2):
@@ -233,30 +231,35 @@ with gr.Blocks(title="Risk Calculation") as demo:
233
 
234
  outcome = gr.Textbox(value=DEFAULT_OUTCOME, label="Outcome column")
235
  strata = gr.Textbox(value=DEFAULT_STRATA, label="Strata columns")
236
- beta = gr.Number(
237
- value=DEFAULT_BETA,
238
- label="Treatment budget beta",
239
- minimum=0,
240
- maximum=1,
241
- step=0.01,
242
- )
 
 
 
 
 
 
 
 
243
  prediction_col = gr.Textbox(value="probability", label="Prediction column")
244
  risk_col = gr.Textbox(value="", label="Risk column")
245
- alpha_values = gr.Textbox(
246
- value=DEFAULT_ALPHA_VALUES,
247
- label="Alpha values",
248
- )
249
- result_filename = gr.Textbox(value="risk-results.json", label="Result file name")
250
- run_button = gr.Button("Calculate risk", variant="primary")
251
 
252
  with gr.Column(scale=3):
253
  status_output = gr.Markdown(label="Status")
254
  download_output = gr.DownloadButton(
255
- label="Download results JSON",
256
  value=None,
257
  interactive=False,
258
  )
259
- result_output = gr.JSON(label="Results")
260
 
261
  data_source.change(
262
  fn=_source_visibility,
@@ -265,7 +268,7 @@ with gr.Blocks(title="Risk Calculation") as demo:
265
  show_progress="hidden",
266
  )
267
  run_button.click(
268
- fn=calculate_risk,
269
  inputs=[
270
  data_source,
271
  uploaded_csv,
@@ -276,13 +279,14 @@ with gr.Blocks(title="Risk Calculation") as demo:
276
  outcome,
277
  strata,
278
  beta,
 
279
  prediction_col,
280
  risk_col,
281
- alpha_values,
282
- result_filename,
283
  ],
284
- outputs=[status_output, result_output, download_output],
285
- api_name="calculate_risk",
286
  )
287
 
288
 
 
7
  from uuid import uuid4
8
 
9
  import gradio as gr
10
+ import pandas as pd
11
 
12
+ from optimal_screening.cli.get_optimal_screening import get_optimal_screening_from_config
13
 
14
 
 
15
  SOURCE_CSV_UPLOAD = "Upload CSV"
16
  SOURCE_CSV_PASTE = "Paste CSV"
17
  SOURCE_HF_DATASET = "Hugging Face dataset"
 
21
  DEFAULT_OUTCOME = "mines_outcome"
22
  DEFAULT_STRATA = "Municipio"
23
  DEFAULT_BETA = 0.1
24
+ DEFAULT_ALPHA = 0.05
25
+ DEFAULT_ACTION_COL = "screening_decision"
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
27
 
28
  def _uploaded_path(uploaded_csv: Any) -> str | None:
 
46
  return values
47
 
48
 
 
 
 
 
 
 
49
  def _optional_text(value: str | None) -> str | None:
50
  if value is None:
51
  return None
 
53
  return value or None
54
 
55
 
56
+ def _csv_filename(value: str | None) -> str:
57
+ filename = Path(value.strip()).name if value and value.strip() else "optimal-screening.csv"
58
+ if not filename.endswith(".csv"):
59
+ filename = f"{filename}.csv"
60
  return filename
61
 
62
 
 
81
  outcome: str,
82
  strata: str,
83
  beta: float,
84
+ alpha: float,
85
  prediction_col: str,
86
  risk_col: str,
87
+ action_col: str,
88
+ output_filename: str,
89
  run_dir: Path,
90
  ) -> dict[str, Any]:
91
  config: dict[str, Any] = {
92
  "outcome": outcome.strip(),
93
  "strata": _parse_list(strata, "strata"),
94
  "beta": float(beta),
95
+ "alpha": float(alpha),
96
+ "output": str(run_dir / _csv_filename(output_filename)),
97
  }
98
 
99
  if data_source == SOURCE_CSV_UPLOAD:
100
  csv_path = _uploaded_path(uploaded_csv)
101
  if csv_path is None:
102
+ raise ValueError("Upload a CSV file before running.")
103
  config["csv"] = csv_path
104
  elif data_source == SOURCE_CSV_PASTE:
105
  if not pasted_csv.strip():
106
+ raise ValueError("Paste CSV data before running.")
107
  pasted_csv_path = run_dir / "input.csv"
108
  pasted_csv_path.write_text(pasted_csv.strip() + "\n")
109
  config["csv"] = str(pasted_csv_path)
 
127
  if risk is not None:
128
  config["risk_col"] = risk
129
 
130
+ action = _optional_text(action_col)
131
+ if action is not None:
132
+ config["action_col"] = action
133
 
134
  return config
135
 
136
 
137
+ def _result_summary(output_path: Path, action_col: str) -> tuple[str, pd.DataFrame]:
138
+ df = pd.read_csv(output_path)
139
+ summary_lines = [
140
+ f"Wrote `{output_path.name}`.",
141
+ "",
142
+ f"Rows: `{len(df)}`",
143
+ ]
144
+
145
+ if action_col in df.columns:
146
+ counts = df[action_col].value_counts().sort_index()
147
+ count_text = ", ".join(f"{int(action)}: {int(count)}" for action, count in counts.items())
148
+ summary_lines.append(f"{action_col}: `{count_text}`")
149
+
150
+ return "\n".join(summary_lines), df.head(100)
151
+
152
+
153
+ def get_optimal_screening(
154
  data_source: str,
155
  uploaded_csv: Any,
156
  pasted_csv: str,
 
160
  outcome: str,
161
  strata: str,
162
  beta: float,
163
+ alpha: float,
164
  prediction_col: str,
165
  risk_col: str,
166
+ action_col: str,
167
+ output_filename: str,
168
+ ) -> tuple[str, pd.DataFrame | None, Any]:
169
  try:
170
+ run_dir = Path(tempfile.gettempdir()) / "optimal-screening" / uuid4().hex
171
  run_dir.mkdir(parents=True, exist_ok=True)
172
 
173
  config = _build_config(
 
180
  outcome=outcome,
181
  strata=strata,
182
  beta=beta,
183
+ alpha=alpha,
184
  prediction_col=prediction_col,
185
  risk_col=risk_col,
186
+ action_col=action_col,
187
+ output_filename=output_filename,
188
  run_dir=run_dir,
189
  )
190
 
191
+ config_path = run_dir / "optimal-screening-config.json"
192
  config_path.write_text(json.dumps(config, indent=2))
193
 
194
+ output_path = get_optimal_screening_from_config(config_path)
195
+ summary, preview = _result_summary(output_path, config.get("action_col", DEFAULT_ACTION_COL))
196
+ return summary, preview, gr.update(value=str(output_path), interactive=True)
 
 
 
197
  except Exception as exc: # noqa: BLE001 - show validation/runtime errors in the interface.
198
+ return f"Run failed: `{exc}`", None, gr.update(value=None, interactive=False)
199
 
200
 
201
+ with gr.Blocks(title="Optimal Screening Decisions") as demo:
202
+ gr.Markdown("# Optimal Screening Decisions")
203
 
204
  with gr.Row():
205
  with gr.Column(scale=2):
 
231
 
232
  outcome = gr.Textbox(value=DEFAULT_OUTCOME, label="Outcome column")
233
  strata = gr.Textbox(value=DEFAULT_STRATA, label="Strata columns")
234
+ with gr.Row():
235
+ beta = gr.Number(
236
+ value=DEFAULT_BETA,
237
+ label="Treatment budget beta",
238
+ minimum=0,
239
+ maximum=1,
240
+ step=0.01,
241
+ )
242
+ alpha = gr.Number(
243
+ value=DEFAULT_ALPHA,
244
+ label="Screening budget alpha",
245
+ minimum=0,
246
+ maximum=1,
247
+ step=0.01,
248
+ )
249
  prediction_col = gr.Textbox(value="probability", label="Prediction column")
250
  risk_col = gr.Textbox(value="", label="Risk column")
251
+ action_col = gr.Textbox(value=DEFAULT_ACTION_COL, label="Action column")
252
+ output_filename = gr.Textbox(value="optimal-screening.csv", label="Output file name")
253
+ run_button = gr.Button("Run", variant="primary")
 
 
 
254
 
255
  with gr.Column(scale=3):
256
  status_output = gr.Markdown(label="Status")
257
  download_output = gr.DownloadButton(
258
+ label="Download CSV",
259
  value=None,
260
  interactive=False,
261
  )
262
+ preview_output = gr.Dataframe(label="CSV preview", interactive=False)
263
 
264
  data_source.change(
265
  fn=_source_visibility,
 
268
  show_progress="hidden",
269
  )
270
  run_button.click(
271
+ fn=get_optimal_screening,
272
  inputs=[
273
  data_source,
274
  uploaded_csv,
 
279
  outcome,
280
  strata,
281
  beta,
282
+ alpha,
283
  prediction_col,
284
  risk_col,
285
+ action_col,
286
+ output_filename,
287
  ],
288
+ outputs=[status_output, preview_output, download_output],
289
+ api_name="get_optimal_screening",
290
  )
291
 
292
 
configs/example-risk.yaml CHANGED
@@ -4,5 +4,5 @@ outcome: mines_outcome
4
  strata:
5
  - Municipio
6
  beta: 0.1
7
- alpha_quantiles: [0.0, 0.05, 0.1]
8
- output: runs/example-risk-output.json
 
4
  strata:
5
  - Municipio
6
  beta: 0.1
7
+ alpha: 0.05
8
+ output: runs/example-risk-output.csv
optimal_screening/analysis/__init__.py CHANGED
@@ -3,6 +3,7 @@ from .stratified import (
3
  SIMULATION_SIZE,
4
  compute_empirical_probabilities,
5
  compute_intuitive_optimal_curve,
 
6
  compute_optimal_screening_curve,
7
  compute_random_screening_curve,
8
  generate_simulation_data,
@@ -14,6 +15,7 @@ __all__ = [
14
  "SIMULATION_SIZE",
15
  "compute_empirical_probabilities",
16
  "compute_intuitive_optimal_curve",
 
17
  "compute_optimal_screening_curve",
18
  "compute_random_screening_curve",
19
  "generate_simulation_data",
 
3
  SIMULATION_SIZE,
4
  compute_empirical_probabilities,
5
  compute_intuitive_optimal_curve,
6
+ compute_optimal_screening_actions,
7
  compute_optimal_screening_curve,
8
  compute_random_screening_curve,
9
  generate_simulation_data,
 
15
  "SIMULATION_SIZE",
16
  "compute_empirical_probabilities",
17
  "compute_intuitive_optimal_curve",
18
+ "compute_optimal_screening_actions",
19
  "compute_optimal_screening_curve",
20
  "compute_random_screening_curve",
21
  "generate_simulation_data",
optimal_screening/analysis/stratified.py CHANGED
@@ -130,54 +130,15 @@ def generate_simulation_data(
130
  return risk_scores, outcomes
131
 
132
 
133
- def compute_optimal_screening_curve(
134
  rows: list[dict[str, Any]],
135
  outcome_col: str,
136
  strata_features: Sequence[str],
137
- prediction_col: str = "probability",
138
- beta: float = 0.5,
139
- alpha_quantiles: Sequence[float] | None = None,
140
- max_iterations: int = 20,
141
- tolerance: float = 1e-6,
142
- seed: int | None = None,
143
- use_custom_risk_col: str | None = None,
144
- simulation: str | tuple[float, float] | None = None,
145
- ) -> dict[str, Any]:
146
- """Compute optimal screening curve with treatment budget β and screening budget α.
147
-
148
- Band structure (highest to lowest risk):
149
- - Band 1: Top (β - α) - Treated, model predictions
150
- - Band 2: Next (α - avg_risk(Band 3)) - Treated, model predictions
151
- - Band 3: Next α - Screened (true outcomes)
152
- - Band 4: Bottom (1 - β - α + avg_risk) - Untreated (predict 0)
153
-
154
- Uses iterative method to resolve circular dependency between Band 2 and Band 3.
155
-
156
- Args:
157
- rows: List of data rows with features, outcome, and predictions
158
- outcome_col: Name of outcome column
159
- strata_features: Features defining strata for computing empirical P(Y=1|X)
160
- prediction_col: Column name for model predictions
161
- beta: Treatment budget (proportion who can be treated)
162
- alpha_quantiles: Screening budget levels to evaluate
163
- max_iterations: Maximum iterations for convergence
164
- tolerance: Convergence tolerance for avg_risk
165
- seed: Random seed for uniform distribution override (for debugging)
166
- use_custom_risk_col: If provided, use this column for risk instead of computing
167
- empirical probabilities from strata. Useful for comparing LLM predictions
168
- with empirical baselines.
169
- simulation: If provided, generate synthetic data from a Beta distribution instead
170
- of using real data. Pass a preset name ('uniform', 'bimodal', 'unimodal') or
171
- a tuple (a, b) of Beta distribution parameters. Uses SIMULATION_SIZE samples.
172
-
173
- Returns:
174
- Dictionary with screening curves and band information
175
- """
176
- if alpha_quantiles is None:
177
- # Default: 10 equally spaced values from 0 to beta
178
- alpha_quantiles = [beta * i / 49 for i in range(50)]
179
-
180
- # Assign each row its risk (simulation, custom, or empirical)
181
  rows_with_risk = []
182
 
183
  if simulation is not None:
@@ -202,11 +163,12 @@ def compute_optimal_screening_curve(
202
  "empirical_risk": float(risk_scores[i]),
203
  "true_outcome": bool(outcomes[i]),
204
  "model_prediction": float(risk_scores[i]),
 
205
  }
206
  )
207
  elif use_custom_risk_col is not None:
208
  # Use custom risk column directly
209
- for row in rows:
210
  risk = row.get(use_custom_risk_col, 0.5)
211
  rows_with_risk.append(
212
  {
@@ -214,13 +176,14 @@ def compute_optimal_screening_curve(
214
  "empirical_risk": risk,
215
  "true_outcome": _is_positive_outcome(row.get(outcome_col)),
216
  "model_prediction": row.get(prediction_col, 0.5),
 
217
  }
218
  )
219
  else:
220
  # Compute empirical P(Y=1|X) for each stratum
221
  empirical_probs = compute_empirical_probabilities(rows, outcome_col, strata_features)
222
 
223
- for row in rows:
224
  stratum_key = tuple(row.get(f) for f in strata_features)
225
  empirical_risk = empirical_probs.get(stratum_key, {}).get("probability", 0.5)
226
 
@@ -230,9 +193,192 @@ def compute_optimal_screening_curve(
230
  "empirical_risk": empirical_risk,
231
  "true_outcome": _is_positive_outcome(row.get(outcome_col)),
232
  "model_prediction": row.get(prediction_col, 0.5),
 
233
  }
234
  )
235
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
236
  # Sort by risk (highest to lowest)
237
  rows_with_risk.sort(key=lambda x: x["empirical_risk"], reverse=True)
238
 
@@ -250,77 +396,13 @@ def compute_optimal_screening_curve(
250
  }
251
 
252
  for alpha in alpha_quantiles:
253
- assert alpha <= beta, f"Screening budget α={alpha} exceeds treatment budget β={beta}"
254
-
255
- # Iteratively find Band 3 position.
256
- prev_avg_risk = 0.0
257
-
258
- for _iteration in range(max_iterations):
259
- # Compute target mass: ∫ f(risk) d(risk) = target
260
- # Where f(risk) is the density over risk values
261
- # For discrete: sum of (count at each risk / total count) = proportion of population at that risk
262
- band1_target_mass = beta - alpha
263
- # Band 2 size: ∫ 1 × f(risk) d(risk) over Band 2 = ∫ (1 - risk) × f(risk) d(risk) over Band 3
264
- # Since Band 3 has mass α and average risk prev_avg_risk:
265
- # ∫ (1 - risk) × f(risk) d(risk) over Band 3 = α × (1 - prev_avg_risk)
266
- band2_target_mass = alpha * (1 - prev_avg_risk)
267
- band3_target_mass = alpha
268
-
269
- # Band 1: Find index where cumulative proportion of population = band1_target_mass
270
- # This is: ∫ f(risk) d(risk) from risk=1 down to some risk threshold
271
- cumulative_mass = 0.0
272
- band1_end_idx = 0
273
- for i in range(n):
274
- # Each person contributes 1/n to the density (proportion of population)
275
- population_contribution = 1.0 / n
276
- cumulative_mass += population_contribution
277
- if cumulative_mass >= band1_target_mass:
278
- band1_end_idx = i + 1
279
- break
280
- if band1_end_idx == 0 and band1_target_mass > 0:
281
- band1_end_idx = 1 # At least one person
282
-
283
- # Band 2: Continue from Band 1 end
284
- target_mass_band1_plus_band2 = band1_target_mass + band2_target_mass
285
- band2_end_idx = band1_end_idx
286
- for i in range(band1_end_idx, n):
287
- population_contribution = 1.0 / n
288
- cumulative_mass += population_contribution
289
- if cumulative_mass >= target_mass_band1_plus_band2:
290
- band2_end_idx = i + 1
291
- break
292
-
293
- # Band 3: Continue from Band 2 end
294
- target_mass_band1_plus_band2_plus_band3 = band1_target_mass + band2_target_mass + band3_target_mass
295
- band3_end_idx = band2_end_idx
296
- for i in range(band2_end_idx, n):
297
- population_contribution = 1.0 / n
298
- cumulative_mass += population_contribution
299
- if cumulative_mass >= target_mass_band1_plus_band2_plus_band3:
300
- band3_end_idx = i + 1
301
- break
302
-
303
- # Ensure indices are within bounds
304
- band1_end_idx = min(band1_end_idx, n)
305
- band2_end_idx = min(band2_end_idx, n)
306
- band3_end_idx = min(band3_end_idx, n)
307
-
308
- # Compute average risk of Band 3
309
- if band3_end_idx > band2_end_idx:
310
- band3_risks = [rows_with_risk[i]["empirical_risk"] for i in range(band2_end_idx, band3_end_idx)]
311
- current_avg_risk = np.mean(band3_risks) if band3_risks else 0.0
312
- else:
313
- current_avg_risk = 0.0
314
-
315
- # Check convergence
316
- if abs(current_avg_risk - prev_avg_risk) < tolerance:
317
- break
318
-
319
- prev_avg_risk = current_avg_risk
320
-
321
- # Final band sizes (keep the indices from the last iteration)
322
- # The indices are already set from the converged iteration above
323
- avg_risk_band3 = prev_avg_risk
324
 
325
  # Compute integrals: ∫ risk × (1/n) dx for each band (for reporting purposes)
326
  band1_integral = sum(rows_with_risk[i]["empirical_risk"] / n for i in range(0, band1_end_idx))
 
130
  return risk_scores, outcomes
131
 
132
 
133
+ def _build_rows_with_risk(
134
  rows: list[dict[str, Any]],
135
  outcome_col: str,
136
  strata_features: Sequence[str],
137
+ prediction_col: str,
138
+ seed: int | None,
139
+ use_custom_risk_col: str | None,
140
+ simulation: str | tuple[float, float] | None,
141
+ ) -> list[dict[str, Any]]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
142
  rows_with_risk = []
143
 
144
  if simulation is not None:
 
163
  "empirical_risk": float(risk_scores[i]),
164
  "true_outcome": bool(outcomes[i]),
165
  "model_prediction": float(risk_scores[i]),
166
+ "_input_index": i,
167
  }
168
  )
169
  elif use_custom_risk_col is not None:
170
  # Use custom risk column directly
171
+ for input_index, row in enumerate(rows):
172
  risk = row.get(use_custom_risk_col, 0.5)
173
  rows_with_risk.append(
174
  {
 
176
  "empirical_risk": risk,
177
  "true_outcome": _is_positive_outcome(row.get(outcome_col)),
178
  "model_prediction": row.get(prediction_col, 0.5),
179
+ "_input_index": input_index,
180
  }
181
  )
182
  else:
183
  # Compute empirical P(Y=1|X) for each stratum
184
  empirical_probs = compute_empirical_probabilities(rows, outcome_col, strata_features)
185
 
186
+ for input_index, row in enumerate(rows):
187
  stratum_key = tuple(row.get(f) for f in strata_features)
188
  empirical_risk = empirical_probs.get(stratum_key, {}).get("probability", 0.5)
189
 
 
193
  "empirical_risk": empirical_risk,
194
  "true_outcome": _is_positive_outcome(row.get(outcome_col)),
195
  "model_prediction": row.get(prediction_col, 0.5),
196
+ "_input_index": input_index,
197
  }
198
  )
199
 
200
+ return rows_with_risk
201
+
202
+
203
+ def _end_index_for_target_mass(n: int, target_mass: float) -> int:
204
+ if n <= 0 or target_mass <= 0:
205
+ return 0
206
+
207
+ cumulative_mass = 0.0
208
+ for i in range(n):
209
+ cumulative_mass += 1.0 / n
210
+ if cumulative_mass >= target_mass:
211
+ return i + 1
212
+ return n
213
+
214
+
215
+ def _find_optimal_band_indices(
216
+ rows_with_risk: list[dict[str, Any]],
217
+ beta: float,
218
+ alpha: float,
219
+ max_iterations: int,
220
+ tolerance: float,
221
+ ) -> tuple[int, int, int, float]:
222
+ assert alpha <= beta, f"Screening budget α={alpha} exceeds treatment budget β={beta}"
223
+
224
+ n = len(rows_with_risk)
225
+ if n == 0:
226
+ return 0, 0, 0, 0.0
227
+
228
+ prev_avg_risk = 0.0
229
+ band1_end_idx = 0
230
+ band2_end_idx = 0
231
+ band3_end_idx = 0
232
+ avg_risk_band3 = 0.0
233
+
234
+ for _iteration in range(max_iterations):
235
+ # Compute target mass: ∫ f(risk) d(risk) = target
236
+ # Where f(risk) is the density over risk values
237
+ # For discrete: sum of (count at each risk / total count) = proportion of population at that risk
238
+ band1_target_mass = beta - alpha
239
+ # Band 2 size: ∫ 1 × f(risk) d(risk) over Band 2 = ∫ (1 - risk) × f(risk) d(risk) over Band 3
240
+ # Since Band 3 has mass α and average risk prev_avg_risk:
241
+ # ∫ (1 - risk) × f(risk) d(risk) over Band 3 = α × (1 - prev_avg_risk)
242
+ band2_target_mass = alpha * (1 - prev_avg_risk)
243
+ band3_target_mass = alpha
244
+
245
+ band1_end_idx = _end_index_for_target_mass(n, band1_target_mass)
246
+ band2_end_idx = _end_index_for_target_mass(n, band1_target_mass + band2_target_mass)
247
+ band3_end_idx = _end_index_for_target_mass(n, band1_target_mass + band2_target_mass + band3_target_mass)
248
+
249
+ # Ensure indices are ordered and within bounds
250
+ band1_end_idx = min(band1_end_idx, n)
251
+ band2_end_idx = max(band1_end_idx, min(band2_end_idx, n))
252
+ band3_end_idx = max(band2_end_idx, min(band3_end_idx, n))
253
+
254
+ # Compute average risk of Band 3
255
+ if band3_end_idx > band2_end_idx:
256
+ band3_risks = [rows_with_risk[i]["empirical_risk"] for i in range(band2_end_idx, band3_end_idx)]
257
+ current_avg_risk = np.mean(band3_risks) if band3_risks else 0.0
258
+ else:
259
+ current_avg_risk = 0.0
260
+
261
+ avg_risk_band3 = current_avg_risk
262
+
263
+ # Check convergence
264
+ if abs(current_avg_risk - prev_avg_risk) < tolerance:
265
+ break
266
+
267
+ prev_avg_risk = current_avg_risk
268
+
269
+ return band1_end_idx, band2_end_idx, band3_end_idx, avg_risk_band3
270
+
271
+
272
def compute_optimal_screening_actions(
    rows: list[dict[str, Any]],
    outcome_col: str,
    strata_features: Sequence[str],
    prediction_col: str = "probability",
    beta: float = 0.5,
    alpha: float = 0.0,
    max_iterations: int = 20,
    tolerance: float = 1e-6,
    seed: int | None = None,
    use_custom_risk_col: str | None = None,
    simulation: str | tuple[float, float] | None = None,
) -> list[int]:
    """Assign a screening action to each row for one (beta, alpha) budget pair.

    Rows are given a risk, ranked from highest to lowest risk, and split at
    the band boundaries found for the given budgets.

    Args:
        rows: Input data rows.
        outcome_col: Name of the outcome column.
        strata_features: Columns defining the strata used for risk estimation.
        prediction_col: Column holding model predictions.
        beta: Treatment budget.
        alpha: Screening budget.
        max_iterations: Iteration cap for the band search.
        tolerance: Convergence tolerance for the band search.
        seed: Forwarded to the risk-building step.
        use_custom_risk_col: Forwarded to the risk-building step.
        simulation: Forwarded to the risk-building step.

    Returns:
        One action code per risk-annotated row, ordered by the row's original
        input index:

        - 0: ignore
        - 1: treat directly
        - 2: screen
    """
    ranked = sorted(
        _build_rows_with_risk(
            rows=rows,
            outcome_col=outcome_col,
            strata_features=strata_features,
            prediction_col=prediction_col,
            seed=seed,
            use_custom_risk_col=use_custom_risk_col,
            simulation=simulation,
        ),
        key=lambda item: item["empirical_risk"],
        reverse=True,
    )

    # Bands 1 and 2 are both treated directly, so only the Band 2 and Band 3
    # end positions matter here.
    _, treat_end, screen_end, _ = _find_optimal_band_indices(
        rows_with_risk=ranked,
        beta=beta,
        alpha=alpha,
        max_iterations=max_iterations,
        tolerance=tolerance,
    )

    def action_for(rank: int) -> int:
        if rank < treat_end:
            return 1
        if rank < screen_end:
            return 2
        return 0

    # Map each ranked row back to its position in the input.
    action_lookup = {item["_input_index"]: action_for(rank) for rank, item in enumerate(ranked)}
    return [action_lookup[position] for position in range(len(ranked))]
322
+
323
+
324
+ def compute_optimal_screening_curve(
325
+ rows: list[dict[str, Any]],
326
+ outcome_col: str,
327
+ strata_features: Sequence[str],
328
+ prediction_col: str = "probability",
329
+ beta: float = 0.5,
330
+ alpha_quantiles: Sequence[float] | None = None,
331
+ max_iterations: int = 20,
332
+ tolerance: float = 1e-6,
333
+ seed: int | None = None,
334
+ use_custom_risk_col: str | None = None,
335
+ simulation: str | tuple[float, float] | None = None,
336
+ ) -> dict[str, Any]:
337
+ """Compute optimal screening curve with treatment budget β and screening budget α.
338
+
339
+ Band structure (highest to lowest risk):
340
+ - Band 1: Top (β - α) - Treated, model predictions
341
+ - Band 2: Next α × (1 - avg_risk(Band 3)) - Treated, model predictions
342
+ - Band 3: Next α - Screened (true outcomes)
343
+ - Band 4: Bottom (1 - β - α + α × avg_risk(Band 3)) - Untreated (predict 0)
344
+
345
+ Uses iterative method to resolve circular dependency between Band 2 and Band 3.
346
+
347
+ Args:
348
+ rows: List of data rows with features, outcome, and predictions
349
+ outcome_col: Name of outcome column
350
+ strata_features: Features defining strata for computing empirical P(Y=1|X)
351
+ prediction_col: Column name for model predictions
352
+ beta: Treatment budget (proportion who can be treated)
353
+ alpha_quantiles: Screening budget levels to evaluate
354
+ max_iterations: Maximum iterations for convergence
355
+ tolerance: Convergence tolerance for avg_risk
356
+ seed: Random seed for uniform distribution override (for debugging)
357
+ use_custom_risk_col: If provided, use this column for risk instead of computing
358
+ empirical probabilities from strata. Useful for comparing LLM predictions
359
+ with empirical baselines.
360
+ simulation: If provided, generate synthetic data from a Beta distribution instead
361
+ of using real data. Pass a preset name ('uniform', 'bimodal', 'unimodal') or
362
+ a tuple (a, b) of Beta distribution parameters. Uses SIMULATION_SIZE samples.
363
+
364
+ Returns:
365
+ Dictionary with screening curves and band information
366
+ """
367
+ if alpha_quantiles is None:
368
+ # Default: 50 equally spaced values from 0 to beta
369
+ alpha_quantiles = [beta * i / 49 for i in range(50)]
370
+
371
+ # Assign each row its risk (simulation, custom, or empirical)
372
+ rows_with_risk = _build_rows_with_risk(
373
+ rows=rows,
374
+ outcome_col=outcome_col,
375
+ strata_features=strata_features,
376
+ prediction_col=prediction_col,
377
+ seed=seed,
378
+ use_custom_risk_col=use_custom_risk_col,
379
+ simulation=simulation,
380
+ )
381
+
382
  # Sort by risk (highest to lowest)
383
  rows_with_risk.sort(key=lambda x: x["empirical_risk"], reverse=True)
384
 
 
396
  }
397
 
398
  for alpha in alpha_quantiles:
399
+ band1_end_idx, band2_end_idx, band3_end_idx, avg_risk_band3 = _find_optimal_band_indices(
400
+ rows_with_risk=rows_with_risk,
401
+ beta=beta,
402
+ alpha=alpha,
403
+ max_iterations=max_iterations,
404
+ tolerance=tolerance,
405
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
406
 
407
  # Compute integrals: ∫ risk × (1/n) dx for each band (for reporting purposes)
408
  band1_integral = sum(rows_with_risk[i]["empirical_risk"] / n for i in range(0, band1_end_idx))
optimal_screening/cli/get_optimal_screening.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import argparse
4
+ import json
5
+ from pathlib import Path
6
+ from typing import Any
7
+
8
+ import yaml
9
+
10
+ from optimal_screening.analysis import compute_optimal_screening_actions
11
+ from optimal_screening.data_sources import load_dataframe
12
+
13
+
14
+ REQUIRED_FIELDS = {"alpha", "beta", "outcome", "strata"}
15
+ DEFAULT_ACTION_COL = "screening_decision"
16
+
17
+
18
def _read_config(path: Path) -> dict[str, Any]:
    """Load a YAML or JSON config file and return its top-level mapping.

    Args:
        path: Location of the config file. The suffix (case-insensitive)
            selects the parser: ``.json`` for JSON, ``.yaml``/``.yml`` for YAML.

    Returns:
        The parsed config as a dictionary.

    Raises:
        FileNotFoundError: If ``path`` does not exist.
        ValueError: If the suffix is not a supported one, or the parsed
            document is not a mapping (e.g. an empty file or a top-level list).
    """
    if not path.exists():
        raise FileNotFoundError(f"Config file not found: {path}")

    # Decode explicitly as UTF-8 instead of the platform's locale encoding so
    # configs with non-ASCII column names load the same way on every machine.
    text = path.read_text(encoding="utf-8")
    suffix = path.suffix.lower()
    if suffix == ".json":
        data = json.loads(text)
    elif suffix in {".yaml", ".yml"}:
        data = yaml.safe_load(text)
    else:
        raise ValueError("Config file must be YAML or JSON")

    if not isinstance(data, dict):
        raise ValueError("Config must be a mapping")
    return data
33
+
34
+
35
def _validate_config(config: dict[str, Any]) -> dict[str, Any]:
    """Validate a raw config mapping and return a normalized copy.

    Args:
        config: Parsed config. Must contain every key in ``REQUIRED_FIELDS``
            and exactly one data source (``csv`` or ``hf_dataset``).

    Returns:
        A new dictionary with the keys ``csv``, ``hf_dataset``, ``hf_split``,
        ``hf_revision``, ``outcome``, ``strata``, ``beta``, ``alpha``,
        ``prediction_col``, ``risk_col``, ``action_col`` and ``output``.
        ``beta`` and ``alpha`` are floats; the other scalar values are strings
        or ``None``.

    Raises:
        ValueError: If ``alpha_quantiles`` is present, a required field is
            missing, the data source is ambiguous or absent, ``strata`` is not
            a non-empty list of strings, ``beta`` is outside ``(0, 1]``,
            ``alpha`` is outside ``[0, beta]``, or ``action_col`` is empty.
    """
    if "alpha_quantiles" in config:
        raise ValueError("Use alpha for one screening budget; alpha_quantiles is only for curve outputs")

    missing = sorted(REQUIRED_FIELDS - set(config))
    if missing:
        raise ValueError(f"Missing required config fields: {missing}")

    has_csv = config.get("csv") is not None
    has_hf_dataset = config.get("hf_dataset") is not None
    if has_csv == has_hf_dataset:
        raise ValueError("Config must provide exactly one data source: csv or hf_dataset")

    strata = config["strata"]
    if not isinstance(strata, list) or not strata or not all(isinstance(item, str) for item in strata):
        raise ValueError("strata must be a non-empty list of column names")

    beta = float(config["beta"])
    if not 0 < beta <= 1:
        raise ValueError("beta must be in the interval (0, 1]")

    alpha = float(config["alpha"])
    if not 0 <= alpha <= beta:
        raise ValueError(f"alpha must be between 0 and beta={beta}")

    def _optional_text(key: str, default: Any = None) -> Any:
        # An explicit null (YAML `key:` / JSON `null`) is treated like an
        # absent key; str(None) would otherwise yield the literal "None".
        value = config.get(key)
        return default if value is None else str(value)

    action_col = _optional_text("action_col", DEFAULT_ACTION_COL)
    if not action_col:
        raise ValueError("action_col must not be empty")

    return {
        "csv": str(config["csv"]) if has_csv else None,
        "hf_dataset": str(config["hf_dataset"]) if has_hf_dataset else None,
        "hf_split": _optional_text("hf_split", "train"),
        "hf_revision": _optional_text("hf_revision"),
        "outcome": str(config["outcome"]),
        "strata": strata,
        "beta": beta,
        "alpha": alpha,
        "prediction_col": _optional_text("prediction_col", "probability"),
        "risk_col": _optional_text("risk_col"),
        "action_col": action_col,
        "output": _optional_text("output", "runs/optimal_screening.csv"),
    }
78
+
79
+
80
def get_optimal_screening_from_config(config_path: Path) -> Path:
    """Compute screening actions for the dataset named in a config and save them as CSV.

    The config is read and validated, the dataset is loaded, one action per
    row is stored in a new column, and the resulting table is written to the
    configured output path (parent directories are created when needed).

    Args:
        config_path: Path to the YAML or JSON config file.

    Returns:
        The path of the CSV file that was written.

    Raises:
        ValueError: If the dataset lacks a required column, or already has a
            column with the configured action column name.
    """
    settings = _validate_config(_read_config(config_path))
    outcome_col = settings["outcome"]
    strata_cols = settings["strata"]
    risk_col = settings["risk_col"]
    prediction_col = settings["prediction_col"]
    action_col = settings["action_col"]

    frame, source_label = load_dataframe(
        csv_path=settings["csv"],
        hf_dataset=settings["hf_dataset"],
        hf_split=settings["hf_split"],
        hf_revision=settings["hf_revision"],
    )

    needed = {outcome_col, *strata_cols}
    if risk_col:
        needed.add(risk_col)
    elif prediction_col in frame.columns:
        # Only added when already present, so it can never be reported missing.
        needed.add(prediction_col)

    absent = sorted(needed - set(frame.columns))
    if absent:
        raise ValueError(f"Missing required columns in {source_label}: {absent}")

    # Refuse to overwrite an existing column with the computed actions.
    if action_col in frame.columns:
        raise ValueError(f"Output action column already exists in {source_label}: {action_col}")

    actions = compute_optimal_screening_actions(
        rows=frame.to_dict("records"),
        outcome_col=outcome_col,
        strata_features=strata_cols,
        prediction_col=prediction_col,
        beta=settings["beta"],
        alpha=settings["alpha"],
        use_custom_risk_col=risk_col,
    )
    frame[action_col] = actions

    destination = Path(settings["output"])
    destination.parent.mkdir(parents=True, exist_ok=True)
    frame.to_csv(destination, index=False)
    return destination
117
+
118
+
119
def main() -> None:
    """Command-line entry point: parse the config path, run the pipeline, report the output file."""
    cli = argparse.ArgumentParser(description="Write optimal screening actions from a YAML or JSON config")
    cli.add_argument("config", help="Path to a YAML or JSON config file")
    config_file = Path(cli.parse_args().config)

    written_to = get_optimal_screening_from_config(config_file)
    print(f"Wrote {written_to}")
126
+
127
+
128
+ if __name__ == "__main__":
129
+ main()