from __future__ import annotations import json import tempfile from pathlib import Path from typing import Any from uuid import uuid4 import gradio as gr import pandas as pd from optimal_screening.cli.get_optimal_screening import get_optimal_screening_from_config SOURCE_CSV_UPLOAD = "Upload CSV" SOURCE_CSV_PASTE = "Paste CSV" SOURCE_HF_DATASET = "Hugging Face dataset" DEFAULT_DATASET = "cmpatino/landmine-detection" DEFAULT_SPLIT = "train" DEFAULT_OUTCOME = "mines_outcome" DEFAULT_STRATA = "Municipio" DEFAULT_BETA = 0.1 DEFAULT_ALPHA = 0.05 DEFAULT_ACTION_COL = "screening_decision" def _uploaded_path(uploaded_csv: Any) -> str | None: if uploaded_csv is None: return None if isinstance(uploaded_csv, list): if not uploaded_csv: return None uploaded_csv = uploaded_csv[0] if isinstance(uploaded_csv, str): return uploaded_csv if hasattr(uploaded_csv, "name"): return str(uploaded_csv.name) return str(uploaded_csv) def _parse_list(value: str, field: str) -> list[str]: values = [item.strip() for item in value.replace("\n", ",").split(",") if item.strip()] if not values: raise ValueError(f"{field} must include at least one value.") return values def _optional_text(value: str | None) -> str | None: if value is None: return None value = value.strip() return value or None def _csv_filename(value: str | None) -> str: filename = Path(value.strip()).name if value and value.strip() else "optimal-screening.csv" if not filename.endswith(".csv"): filename = f"{filename}.csv" return filename def _source_visibility(source: str) -> tuple[Any, Any, Any, Any, Any]: return ( gr.update(visible=source == SOURCE_CSV_UPLOAD), gr.update(visible=source == SOURCE_CSV_PASTE), gr.update(visible=source == SOURCE_HF_DATASET), gr.update(visible=source == SOURCE_HF_DATASET), gr.update(visible=source == SOURCE_HF_DATASET), ) def _build_config( *, data_source: str, uploaded_csv: Any, pasted_csv: str, hf_dataset: str, hf_split: str, hf_revision: str, outcome: str, strata: str, beta: float, alpha: float, prediction_col: str, risk_col: str, action_col: str, output_filename: str, run_dir: Path, ) -> dict[str, Any]: config: dict[str, Any] = { "outcome": outcome.strip(), "strata": _parse_list(strata, "strata"), "beta": float(beta), "alpha": float(alpha), "output": str(run_dir / _csv_filename(output_filename)), } if data_source == SOURCE_CSV_UPLOAD: csv_path = _uploaded_path(uploaded_csv) if csv_path is None: raise ValueError("Upload a CSV file before running.") config["csv"] = csv_path elif data_source == SOURCE_CSV_PASTE: if not pasted_csv.strip(): raise ValueError("Paste CSV data before running.") pasted_csv_path = run_dir / "input.csv" pasted_csv_path.write_text(pasted_csv.strip() + "\n") config["csv"] = str(pasted_csv_path) elif data_source == SOURCE_HF_DATASET: dataset = hf_dataset.strip() if not dataset: raise ValueError("Hugging Face dataset is required.") config["hf_dataset"] = dataset config["hf_split"] = hf_split.strip() or "train" revision = _optional_text(hf_revision) if revision is not None: config["hf_revision"] = revision else: raise ValueError(f"Unknown data source: {data_source}") prediction = _optional_text(prediction_col) if prediction is not None: config["prediction_col"] = prediction risk = _optional_text(risk_col) if risk is not None: config["risk_col"] = risk action = _optional_text(action_col) if action is not None: config["action_col"] = action return config def _result_summary(output_path: Path, action_col: str) -> tuple[str, pd.DataFrame]: df = pd.read_csv(output_path) summary_lines = [ f"Wrote `{output_path.name}`.", "", f"Rows: `{len(df)}`", ] if action_col in df.columns: counts = df[action_col].value_counts().sort_index() count_text = ", ".join(f"{int(action)}: {int(count)}" for action, count in counts.items()) summary_lines.append(f"{action_col}: `{count_text}`") return "\n".join(summary_lines), df.head(100) def get_optimal_screening( data_source: str, uploaded_csv: Any, pasted_csv: str, hf_dataset: str, hf_split: str, hf_revision: str, outcome: str, strata: str, beta: float, alpha: float, prediction_col: str, risk_col: str, action_col: str, output_filename: str, ) -> tuple[str, pd.DataFrame | None, Any]: try: run_dir = Path(tempfile.gettempdir()) / "optimal-screening" / uuid4().hex run_dir.mkdir(parents=True, exist_ok=True) config = _build_config( data_source=data_source, uploaded_csv=uploaded_csv, pasted_csv=pasted_csv, hf_dataset=hf_dataset, hf_split=hf_split, hf_revision=hf_revision, outcome=outcome, strata=strata, beta=beta, alpha=alpha, prediction_col=prediction_col, risk_col=risk_col, action_col=action_col, output_filename=output_filename, run_dir=run_dir, ) config_path = run_dir / "optimal-screening-config.json" config_path.write_text(json.dumps(config, indent=2)) output_path = get_optimal_screening_from_config(config_path) summary, preview = _result_summary(output_path, config.get("action_col", DEFAULT_ACTION_COL)) return summary, preview, gr.update(value=str(output_path), interactive=True) except Exception as exc: # noqa: BLE001 - show validation/runtime errors in the interface. return f"Run failed: `{exc}`", None, gr.update(value=None, interactive=False) with gr.Blocks(title="Optimal Screening Decisions") as demo: gr.Markdown("# Optimal Screening Decisions") with gr.Row(): with gr.Column(scale=2): data_source = gr.Radio( choices=[SOURCE_HF_DATASET, SOURCE_CSV_UPLOAD, SOURCE_CSV_PASTE], value=SOURCE_HF_DATASET, label="Data source", ) uploaded_csv = gr.File( label="Upload CSV", file_types=[".csv"], type="filepath", visible=False, ) pasted_csv = gr.Textbox( label="Paste CSV", lines=8, max_lines=16, placeholder="risk,outcome,group\n0.9,1,a\n0.1,0,b", visible=False, ) hf_dataset = gr.Textbox( value=DEFAULT_DATASET, label="Hugging Face dataset", ) with gr.Row(): hf_split = gr.Textbox(value=DEFAULT_SPLIT, label="Split") hf_revision = gr.Textbox(value="", label="Revision") outcome = gr.Textbox(value=DEFAULT_OUTCOME, label="Outcome column") strata = gr.Textbox(value=DEFAULT_STRATA, label="Strata columns") with gr.Row(): beta = gr.Number( value=DEFAULT_BETA, label="Treatment budget beta", minimum=0, maximum=1, step=0.01, ) alpha = gr.Number( value=DEFAULT_ALPHA, label="Screening budget alpha", minimum=0, maximum=1, step=0.01, ) prediction_col = gr.Textbox(value="probability", label="Prediction column") risk_col = gr.Textbox(value="", label="Risk column") action_col = gr.Textbox(value=DEFAULT_ACTION_COL, label="Action column") output_filename = gr.Textbox(value="optimal-screening.csv", label="Output file name") run_button = gr.Button("Run", variant="primary") with gr.Column(scale=3): status_output = gr.Markdown(label="Status") download_output = gr.DownloadButton( label="Download CSV", value=None, interactive=False, ) preview_output = gr.Dataframe(label="CSV preview", interactive=False) data_source.change( fn=_source_visibility, inputs=data_source, outputs=[uploaded_csv, pasted_csv, hf_dataset, hf_split, hf_revision], show_progress="hidden", ) run_button.click( fn=get_optimal_screening, inputs=[ data_source, uploaded_csv, pasted_csv, hf_dataset, hf_split, hf_revision, outcome, strata, beta, alpha, prediction_col, risk_col, action_col, output_filename, ], outputs=[status_output, preview_output, download_output], api_name="get_optimal_screening", ) if __name__ == "__main__": demo.queue().launch()