# Scraped page header (not part of the program): last commit 09f8c96 by
# cmpatino (HF Staff), "Update app for optimal screening CSV output".
from __future__ import annotations
import json
import tempfile
from pathlib import Path
from typing import Any
from uuid import uuid4
import gradio as gr
import pandas as pd
from optimal_screening.cli.get_optimal_screening import get_optimal_screening_from_config
# Choices of the "Data source" radio. _build_config and _source_visibility
# compare the selected value against these exact strings.
SOURCE_CSV_UPLOAD = "Upload CSV"
SOURCE_CSV_PASTE = "Paste CSV"
SOURCE_HF_DATASET = "Hugging Face dataset"

# Initial values pre-filled into the form widgets.
DEFAULT_DATASET = "cmpatino/landmine-detection"
DEFAULT_SPLIT = "train"
DEFAULT_OUTCOME = "mines_outcome"
DEFAULT_STRATA = "Municipio"
DEFAULT_BETA = 0.1
DEFAULT_ALPHA = 0.05
# Also the column get_optimal_screening summarizes when the built config
# carries no "action_col" entry.
DEFAULT_ACTION_COL = "screening_decision"
def _uploaded_path(uploaded_csv: Any) -> str | None:
    """Reduce the file widget's value to a single path string.

    Accepts ``None``, a path string, an object exposing ``.name``, or a list
    of those (only the first element is used). Returns ``None`` when nothing
    was provided; any other object falls back to its ``str()`` form.
    """
    if uploaded_csv is None:
        return None

    entry = uploaded_csv
    if isinstance(entry, list):
        if not entry:
            return None
        # Multi-file payload: only the first entry matters.
        entry = entry[0]

    if isinstance(entry, str):
        return entry
    # File-like wrappers carry the path in ``.name``; otherwise stringify.
    return str(getattr(entry, "name", entry))
def _parse_list(value: str, field: str) -> list[str]:
    """Split a comma- or newline-separated string into trimmed, non-empty items.

    Raises:
        ValueError: if no non-empty item remains; ``field`` names the
            offending input in the message.
    """
    pieces = value.replace("\n", ",").split(",")
    # filter(None, ...) drops the entries that are empty after trimming.
    cleaned = list(filter(None, map(str.strip, pieces)))
    if len(cleaned) == 0:
        raise ValueError(f"{field} must include at least one value.")
    return cleaned
def _optional_text(value: str | None) -> str | None:
    """Return ``value`` trimmed, or ``None`` when it is ``None`` or blank."""
    trimmed = "" if value is None else value.strip()
    if trimmed:
        return trimmed
    return None
def _csv_filename(value: str | None) -> str:
    """Turn the user-supplied output name into a bare ``*.csv`` file name.

    Only the final path component is kept, so directory parts in the input
    cannot move the output outside the run directory. Missing or blank input,
    and input whose final component is empty or a dot entry (``"/"``, ``"."``,
    ``".."``), fall back to ``"optimal-screening.csv"``. A ``.csv`` suffix is
    appended unless one is already present (checked case-insensitively).
    """
    default_name = "optimal-screening.csv"
    text = value.strip() if value else ""
    if not text:
        return default_name

    filename = Path(text).name
    # Path("/").name and Path(".").name are "", and Path("..").name is "..";
    # none of these is a usable file name.
    if filename in ("", ".", ".."):
        return default_name

    if not filename.lower().endswith(".csv"):
        filename = f"{filename}.csv"
    return filename
def _source_visibility(source: str) -> tuple[Any, Any, Any, Any, Any]:
    """Return visibility updates for the source-specific input widgets.

    The five updates are, in order: the upload widget, the paste box, and the
    three Hugging Face fields (dataset, split, revision). Each is visible only
    when ``source`` equals the matching ``SOURCE_*`` label.
    """
    show_upload = source == SOURCE_CSV_UPLOAD
    show_paste = source == SOURCE_CSV_PASTE
    show_hf = source == SOURCE_HF_DATASET
    # The three Hugging Face fields always toggle together.
    flags = (show_upload, show_paste, show_hf, show_hf, show_hf)
    return tuple(gr.update(visible=flag) for flag in flags)
def _build_config(
    *,
    data_source: str,
    uploaded_csv: Any,
    pasted_csv: str,
    hf_dataset: str,
    hf_split: str,
    hf_revision: str,
    outcome: str,
    strata: str,
    beta: float,
    alpha: float,
    prediction_col: str,
    risk_col: str,
    action_col: str,
    output_filename: str,
    run_dir: Path,
) -> dict[str, Any]:
    """Translate the form values into the config mapping for one run.

    The mapping always holds ``outcome``, ``strata``, ``beta``, ``alpha`` and
    ``output`` (a path inside ``run_dir``). Depending on ``data_source`` it
    also holds either ``csv`` or ``hf_dataset``/``hf_split`` (plus
    ``hf_revision`` when given). ``prediction_col``, ``risk_col`` and
    ``action_col`` are included only when non-blank.

    Side effect: for pasted data, the text is written to
    ``run_dir / "input.csv"`` and that path becomes ``csv``.

    Raises:
        ValueError: if a required value is missing (outcome, strata, beta,
            alpha, or the input for the selected source) or ``data_source``
            is not one of the ``SOURCE_*`` labels.
    """
    outcome_name = outcome.strip()
    if not outcome_name:
        # Validate here, like strata and dataset, rather than failing later
        # inside the screening run.
        raise ValueError("Outcome column is required.")
    if beta is None or alpha is None:
        # A cleared number field arrives as None; float(None) would raise an
        # opaque TypeError.
        raise ValueError("beta and alpha are required.")

    config: dict[str, Any] = {
        "outcome": outcome_name,
        "strata": _parse_list(strata, "strata"),
        "beta": float(beta),
        "alpha": float(alpha),
        "output": str(run_dir / _csv_filename(output_filename)),
    }

    if data_source == SOURCE_CSV_UPLOAD:
        csv_path = _uploaded_path(uploaded_csv)
        if csv_path is None:
            raise ValueError("Upload a CSV file before running.")
        config["csv"] = csv_path
    elif data_source == SOURCE_CSV_PASTE:
        pasted_text = pasted_csv.strip()
        if not pasted_text:
            raise ValueError("Paste CSV data before running.")
        pasted_csv_path = run_dir / "input.csv"
        # Explicit UTF-8: the platform default encoding may not be able to
        # represent the pasted text, and CSV readers generally expect UTF-8.
        pasted_csv_path.write_text(pasted_text + "\n", encoding="utf-8")
        config["csv"] = str(pasted_csv_path)
    elif data_source == SOURCE_HF_DATASET:
        dataset = hf_dataset.strip()
        if not dataset:
            raise ValueError("Hugging Face dataset is required.")
        config["hf_dataset"] = dataset
        config["hf_split"] = hf_split.strip() or DEFAULT_SPLIT
        revision = _optional_text(hf_revision)
        if revision is not None:
            config["hf_revision"] = revision
    else:
        raise ValueError(f"Unknown data source: {data_source}")

    # Optional column overrides: omitted entirely when left blank.
    optional_columns = (
        ("prediction_col", prediction_col),
        ("risk_col", risk_col),
        ("action_col", action_col),
    )
    for key, raw in optional_columns:
        text = _optional_text(raw)
        if text is not None:
            config[key] = text
    return config
def _result_summary(output_path: Path, action_col: str) -> tuple[str, pd.DataFrame]:
    """Summarize the written CSV as Markdown text plus a preview frame.

    Args:
        output_path: Location of the CSV to read (a ``str`` is accepted too).
        action_col: Column whose value counts are appended to the summary
            when the column is present.

    Returns:
        ``(summary, preview)``: the summary reports the file name, the row
        count and, if available, per-value counts of ``action_col`` sorted by
        value; the preview is the first 100 rows.
    """

    def _label(value: Any) -> str:
        # Integral values (1, 1.0, True) print as integers. Anything else,
        # such as text labels or fractional numbers, keeps its own string
        # form instead of raising or being truncated by int().
        try:
            as_int = int(value)
        except (TypeError, ValueError):
            return str(value)
        return str(as_int) if as_int == value else str(value)

    output_path = Path(output_path)
    df = pd.read_csv(output_path)
    summary_lines = [
        f"Wrote `{output_path.name}`.",
        "",
        f"Rows: `{len(df)}`",
    ]
    if action_col in df.columns:
        counts = df[action_col].value_counts().sort_index()
        count_text = ", ".join(
            f"{_label(action)}: {int(count)}" for action, count in counts.items()
        )
        summary_lines.append(f"{action_col}: `{count_text}`")
    return "\n".join(summary_lines), df.head(100)
def get_optimal_screening(
    data_source: str,
    uploaded_csv: Any,
    pasted_csv: str,
    hf_dataset: str,
    hf_split: str,
    hf_revision: str,
    outcome: str,
    strata: str,
    beta: float,
    alpha: float,
    prediction_col: str,
    risk_col: str,
    action_col: str,
    output_filename: str,
) -> tuple[str, pd.DataFrame | None, Any]:
    """Run one screening job from the form values and report the outcome.

    A fresh per-run directory is created under the system temp directory.
    The config built from the inputs is saved there as JSON and handed to
    ``get_optimal_screening_from_config``.

    Returns:
        ``(status, preview, download_update)``. On success these are the
        Markdown summary, the preview frame, and an update enabling the
        download button with the output path. On any exception they are a
        "Run failed" message, ``None``, and an update disabling the button.
    """
    form_values = {
        "data_source": data_source,
        "uploaded_csv": uploaded_csv,
        "pasted_csv": pasted_csv,
        "hf_dataset": hf_dataset,
        "hf_split": hf_split,
        "hf_revision": hf_revision,
        "outcome": outcome,
        "strata": strata,
        "beta": beta,
        "alpha": alpha,
        "prediction_col": prediction_col,
        "risk_col": risk_col,
        "action_col": action_col,
        "output_filename": output_filename,
    }
    try:
        # Unique directory per run, so runs never share files.
        workspace = Path(tempfile.gettempdir()) / "optimal-screening" / uuid4().hex
        workspace.mkdir(parents=True, exist_ok=True)

        settings = _build_config(**form_values, run_dir=workspace)
        settings_file = workspace / "optimal-screening-config.json"
        settings_file.write_text(json.dumps(settings, indent=2))

        result_file = get_optimal_screening_from_config(settings_file)
        summarized_col = settings.get("action_col", DEFAULT_ACTION_COL)
        report, head = _result_summary(result_file, summarized_col)
        download = gr.update(value=str(result_file), interactive=True)
    except Exception as exc:  # noqa: BLE001 - show validation/runtime errors in the interface.
        return f"Run failed: `{exc}`", None, gr.update(value=None, interactive=False)
    else:
        return report, head, download
# UI definition: one row with two columns -- the input form first, then the
# run results -- followed by the event wiring.
with gr.Blocks(title="Optimal Screening Decisions") as demo:
    gr.Markdown("# Optimal Screening Decisions")
    with gr.Row():
        # First column: data source selection and run parameters.
        with gr.Column(scale=2):
            data_source = gr.Radio(
                choices=[SOURCE_HF_DATASET, SOURCE_CSV_UPLOAD, SOURCE_CSV_PASTE],
                value=SOURCE_HF_DATASET,
                label="Data source",
            )
            # The upload and paste widgets start hidden because the initial
            # source is the Hugging Face dataset; _source_visibility toggles
            # them when the radio changes.
            uploaded_csv = gr.File(
                label="Upload CSV",
                file_types=[".csv"],
                type="filepath",
                visible=False,
            )
            pasted_csv = gr.Textbox(
                label="Paste CSV",
                lines=8,
                max_lines=16,
                placeholder="risk,outcome,group\n0.9,1,a\n0.1,0,b",
                visible=False,
            )
            hf_dataset = gr.Textbox(
                value=DEFAULT_DATASET,
                label="Hugging Face dataset",
            )
            with gr.Row():
                hf_split = gr.Textbox(value=DEFAULT_SPLIT, label="Split")
                # Blank revision is omitted from the config by _build_config.
                hf_revision = gr.Textbox(value="", label="Revision")
            outcome = gr.Textbox(value=DEFAULT_OUTCOME, label="Outcome column")
            # Parsed by _build_config as a comma- or newline-separated list.
            strata = gr.Textbox(value=DEFAULT_STRATA, label="Strata columns")
            with gr.Row():
                beta = gr.Number(
                    value=DEFAULT_BETA,
                    label="Treatment budget beta",
                    minimum=0,
                    maximum=1,
                    step=0.01,
                )
                alpha = gr.Number(
                    value=DEFAULT_ALPHA,
                    label="Screening budget alpha",
                    minimum=0,
                    maximum=1,
                    step=0.01,
                )
            # Column-name fields; blank ones are left out of the config.
            prediction_col = gr.Textbox(value="probability", label="Prediction column")
            risk_col = gr.Textbox(value="", label="Risk column")
            action_col = gr.Textbox(value=DEFAULT_ACTION_COL, label="Action column")
            output_filename = gr.Textbox(value="optimal-screening.csv", label="Output file name")
            run_button = gr.Button("Run", variant="primary")
        # Second column: status text, download button, and CSV preview.
        with gr.Column(scale=3):
            status_output = gr.Markdown(label="Status")
            # Disabled until a run succeeds; get_optimal_screening returns the
            # update that sets its value and enables it.
            download_output = gr.DownloadButton(
                label="Download CSV",
                value=None,
                interactive=False,
            )
            preview_output = gr.Dataframe(label="CSV preview", interactive=False)
    # Show only the inputs relevant to the selected data source. The outputs
    # order must match the tuple returned by _source_visibility.
    data_source.change(
        fn=_source_visibility,
        inputs=data_source,
        outputs=[uploaded_csv, pasted_csv, hf_dataset, hf_split, hf_revision],
        show_progress="hidden",
    )
    # The inputs order must match the positional parameters of
    # get_optimal_screening; the outputs order matches its returned triple.
    run_button.click(
        fn=get_optimal_screening,
        inputs=[
            data_source,
            uploaded_csv,
            pasted_csv,
            hf_dataset,
            hf_split,
            hf_revision,
            outcome,
            strata,
            beta,
            alpha,
            prediction_col,
            risk_col,
            action_col,
            output_filename,
        ],
        outputs=[status_output, preview_output, download_output],
        api_name="get_optimal_screening",
    )
# Script entry point: enable the request queue and start the app server.
if __name__ == "__main__":
    demo.queue().launch()