Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import json | |
| import tempfile | |
| from pathlib import Path | |
| from typing import Any | |
| from uuid import uuid4 | |
| import gradio as gr | |
| import pandas as pd | |
| from optimal_screening.cli.get_optimal_screening import get_optimal_screening_from_config | |
# Labels for the "Data source" radio. _build_config branches on these exact
# strings, so the UI choices and the config builder must stay in sync.
SOURCE_CSV_UPLOAD = "Upload CSV"
SOURCE_CSV_PASTE = "Paste CSV"
SOURCE_HF_DATASET = "Hugging Face dataset"

# Initial values pre-filled in the form.
DEFAULT_DATASET = "cmpatino/landmine-detection"
DEFAULT_SPLIT = "train"
DEFAULT_OUTCOME = "mines_outcome"
DEFAULT_STRATA = "Municipio"
DEFAULT_BETA = 0.1
DEFAULT_ALPHA = 0.05
# Also the column get_optimal_screening summarises when the form leaves the
# action column blank.
DEFAULT_ACTION_COL = "screening_decision"
| def _uploaded_path(uploaded_csv: Any) -> str | None: | |
| if uploaded_csv is None: | |
| return None | |
| if isinstance(uploaded_csv, list): | |
| if not uploaded_csv: | |
| return None | |
| uploaded_csv = uploaded_csv[0] | |
| if isinstance(uploaded_csv, str): | |
| return uploaded_csv | |
| if hasattr(uploaded_csv, "name"): | |
| return str(uploaded_csv.name) | |
| return str(uploaded_csv) | |
| def _parse_list(value: str, field: str) -> list[str]: | |
| values = [item.strip() for item in value.replace("\n", ",").split(",") if item.strip()] | |
| if not values: | |
| raise ValueError(f"{field} must include at least one value.") | |
| return values | |
| def _optional_text(value: str | None) -> str | None: | |
| if value is None: | |
| return None | |
| value = value.strip() | |
| return value or None | |
| def _csv_filename(value: str | None) -> str: | |
| filename = Path(value.strip()).name if value and value.strip() else "optimal-screening.csv" | |
| if not filename.endswith(".csv"): | |
| filename = f"{filename}.csv" | |
| return filename | |
def _source_visibility(source: str) -> tuple[Any, Any, Any, Any, Any]:
    """Compute visibility updates for the source-specific inputs.

    Args:
        source: The selected data-source label.

    Returns:
        Five ``gr.update`` objects, in order: the upload widget, the paste
        box, and three entries that are visible only for the Hugging Face
        source.
    """
    show_upload = source == SOURCE_CSV_UPLOAD
    show_paste = source == SOURCE_CSV_PASTE
    show_hf = source == SOURCE_HF_DATASET
    return (
        gr.update(visible=show_upload),
        gr.update(visible=show_paste),
        # Three Hugging Face inputs share the same visibility flag.
        *(gr.update(visible=show_hf) for _ in range(3)),
    )
def _build_config(
    *,
    data_source: str,
    uploaded_csv: Any,
    pasted_csv: str,
    hf_dataset: str,
    hf_split: str,
    hf_revision: str,
    outcome: str,
    strata: str,
    beta: float,
    alpha: float,
    prediction_col: str,
    risk_col: str,
    action_col: str,
    output_filename: str,
    run_dir: Path,
) -> dict[str, Any]:
    """Translate the form values into the config dict for a screening run.

    Args:
        data_source: One of the ``SOURCE_*`` labels; selects which input
            fields are read.
        uploaded_csv: Upload widget value (used for ``SOURCE_CSV_UPLOAD``).
        pasted_csv: Raw CSV text (used for ``SOURCE_CSV_PASTE``); it is written
            to ``run_dir / "input.csv"``.
        hf_dataset: Dataset identifier (used for ``SOURCE_HF_DATASET``).
        hf_split: Dataset split; blank falls back to ``DEFAULT_SPLIT``.
        hf_revision: Optional dataset revision; omitted from the config when
            blank.
        outcome: Outcome column name (stripped).
        strata: Comma/newline separated strata column names.
        beta: Treatment budget, coerced with ``float``.
        alpha: Screening budget, coerced with ``float``.
        prediction_col: Optional column name; omitted when blank.
        risk_col: Optional column name; omitted when blank.
        action_col: Optional column name; omitted when blank.
        output_filename: Requested output file name, sanitised and placed
            inside *run_dir*.
        run_dir: Existing directory for this run's files.

    Returns:
        The config mapping. It always contains ``outcome``, ``strata``,
        ``beta``, ``alpha`` and ``output``, plus either ``csv`` or the
        ``hf_*`` keys, plus any non-blank optional column names.

    Raises:
        ValueError: If beta/alpha is missing, strata is empty, the selected
            source has no data, or *data_source* is not a known label.
    """
    # A numeric field may arrive as None (e.g. a cleared number box); name the
    # field instead of surfacing float()'s generic TypeError.
    for budget_name, budget in (("beta", beta), ("alpha", alpha)):
        if budget is None:
            raise ValueError(f"{budget_name} is required.")

    config: dict[str, Any] = {
        "outcome": outcome.strip(),
        "strata": _parse_list(strata, "strata"),
        "beta": float(beta),
        "alpha": float(alpha),
        "output": str(run_dir / _csv_filename(output_filename)),
    }

    if data_source == SOURCE_CSV_UPLOAD:
        csv_path = _uploaded_path(uploaded_csv)
        if csv_path is None:
            raise ValueError("Upload a CSV file before running.")
        config["csv"] = csv_path
    elif data_source == SOURCE_CSV_PASTE:
        if not pasted_csv.strip():
            raise ValueError("Paste CSV data before running.")
        pasted_csv_path = run_dir / "input.csv"
        # Write UTF-8 explicitly: the platform default encoding could mangle
        # non-ASCII column names or values on non-UTF-8 locales.
        pasted_csv_path.write_text(pasted_csv.strip() + "\n", encoding="utf-8")
        config["csv"] = str(pasted_csv_path)
    elif data_source == SOURCE_HF_DATASET:
        dataset = hf_dataset.strip()
        if not dataset:
            raise ValueError("Hugging Face dataset is required.")
        config["hf_dataset"] = dataset
        config["hf_split"] = hf_split.strip() or DEFAULT_SPLIT
        revision = _optional_text(hf_revision)
        if revision is not None:
            config["hf_revision"] = revision
    else:
        raise ValueError(f"Unknown data source: {data_source}")

    # Optional column names are only included when the user filled them in.
    optional_columns = (
        ("prediction_col", prediction_col),
        ("risk_col", risk_col),
        ("action_col", action_col),
    )
    for key, raw in optional_columns:
        cleaned = _optional_text(raw)
        if cleaned is not None:
            config[key] = cleaned
    return config
| def _result_summary(output_path: Path, action_col: str) -> tuple[str, pd.DataFrame]: | |
| df = pd.read_csv(output_path) | |
| summary_lines = [ | |
| f"Wrote `{output_path.name}`.", | |
| "", | |
| f"Rows: `{len(df)}`", | |
| ] | |
| if action_col in df.columns: | |
| counts = df[action_col].value_counts().sort_index() | |
| count_text = ", ".join(f"{int(action)}: {int(count)}" for action, count in counts.items()) | |
| summary_lines.append(f"{action_col}: `{count_text}`") | |
| return "\n".join(summary_lines), df.head(100) | |
def get_optimal_screening(
    data_source: str,
    uploaded_csv: Any,
    pasted_csv: str,
    hf_dataset: str,
    hf_split: str,
    hf_revision: str,
    outcome: str,
    strata: str,
    beta: float,
    alpha: float,
    prediction_col: str,
    risk_col: str,
    action_col: str,
    output_filename: str,
) -> tuple[str, pd.DataFrame | None, Any]:
    """Run one screening job from the form values and report the outcome.

    A fresh directory under the system temp dir holds the run's config JSON,
    any pasted input, and the output CSV.

    Returns:
        ``(status_markdown, preview_dataframe_or_None, download_update)``.
        On any exception the status carries the error text, the preview is
        ``None`` and the download control is reset and disabled.
    """
    try:
        workdir = Path(tempfile.gettempdir()).joinpath("optimal-screening", uuid4().hex)
        workdir.mkdir(parents=True, exist_ok=True)

        form_values = {
            "data_source": data_source,
            "uploaded_csv": uploaded_csv,
            "pasted_csv": pasted_csv,
            "hf_dataset": hf_dataset,
            "hf_split": hf_split,
            "hf_revision": hf_revision,
            "outcome": outcome,
            "strata": strata,
            "beta": beta,
            "alpha": alpha,
            "prediction_col": prediction_col,
            "risk_col": risk_col,
            "action_col": action_col,
            "output_filename": output_filename,
        }
        settings = _build_config(**form_values, run_dir=workdir)

        # The screening entry point is driven by a config file on disk.
        settings_file = workdir / "optimal-screening-config.json"
        settings_file.write_text(json.dumps(settings, indent=2))
        result_file = get_optimal_screening_from_config(settings_file)

        action_name = settings.get("action_col", DEFAULT_ACTION_COL)
        summary, preview = _result_summary(result_file, action_name)
        download = gr.update(value=str(result_file), interactive=True)
        return summary, preview, download
    except Exception as exc:  # noqa: BLE001 - show validation/runtime errors in the interface.
        failure = f"Run failed: `{exc}`"
        return failure, None, gr.update(value=None, interactive=False)
# Build the Gradio interface. Components are rendered in the order they are
# created inside each layout context, so statement order here is layout order.
# NOTE(review): indentation was lost in the copy this was recovered from; the
# nesting of the Row/Column contexts and of the event wiring is inferred —
# confirm against the deployed app.
with gr.Blocks(title="Optimal Screening Decisions") as demo:
    gr.Markdown("# Optimal Screening Decisions")
    with gr.Row():
        # Input column: data source, model parameters and the Run button.
        with gr.Column(scale=2):
            data_source = gr.Radio(
                choices=[SOURCE_HF_DATASET, SOURCE_CSV_UPLOAD, SOURCE_CSV_PASTE],
                value=SOURCE_HF_DATASET,
                label="Data source",
            )
            # The two CSV inputs start hidden because the initial source is the
            # Hugging Face dataset; _source_visibility toggles them later.
            uploaded_csv = gr.File(
                label="Upload CSV",
                file_types=[".csv"],
                type="filepath",
                visible=False,
            )
            pasted_csv = gr.Textbox(
                label="Paste CSV",
                lines=8,
                max_lines=16,
                placeholder="risk,outcome,group\n0.9,1,a\n0.1,0,b",
                visible=False,
            )
            hf_dataset = gr.Textbox(
                value=DEFAULT_DATASET,
                label="Hugging Face dataset",
            )
            with gr.Row():
                hf_split = gr.Textbox(value=DEFAULT_SPLIT, label="Split")
                hf_revision = gr.Textbox(value="", label="Revision")
            outcome = gr.Textbox(value=DEFAULT_OUTCOME, label="Outcome column")
            strata = gr.Textbox(value=DEFAULT_STRATA, label="Strata columns")
            with gr.Row():
                beta = gr.Number(
                    value=DEFAULT_BETA,
                    label="Treatment budget beta",
                    minimum=0,
                    maximum=1,
                    step=0.01,
                )
                alpha = gr.Number(
                    value=DEFAULT_ALPHA,
                    label="Screening budget alpha",
                    minimum=0,
                    maximum=1,
                    step=0.01,
                )
            # Blank optional column names are dropped from the config by
            # _build_config.
            prediction_col = gr.Textbox(value="probability", label="Prediction column")
            risk_col = gr.Textbox(value="", label="Risk column")
            action_col = gr.Textbox(value=DEFAULT_ACTION_COL, label="Action column")
            output_filename = gr.Textbox(value="optimal-screening.csv", label="Output file name")
            run_button = gr.Button("Run", variant="primary")
        # Output column: status text, download button and a preview table.
        with gr.Column(scale=3):
            status_output = gr.Markdown(label="Status")
            # Disabled until a run succeeds; get_optimal_screening enables it
            # and points it at the output file.
            download_output = gr.DownloadButton(
                label="Download CSV",
                value=None,
                interactive=False,
            )
            preview_output = gr.Dataframe(label="CSV preview", interactive=False)
    # The outputs list must stay in the same order as the five updates
    # returned by _source_visibility.
    data_source.change(
        fn=_source_visibility,
        inputs=data_source,
        outputs=[uploaded_csv, pasted_csv, hf_dataset, hf_split, hf_revision],
        show_progress="hidden",
    )
    # The inputs list is passed positionally, so it must match the parameter
    # order of get_optimal_screening; outputs match its returned 3-tuple.
    run_button.click(
        fn=get_optimal_screening,
        inputs=[
            data_source,
            uploaded_csv,
            pasted_csv,
            hf_dataset,
            hf_split,
            hf_revision,
            outcome,
            strata,
            beta,
            alpha,
            prediction_col,
            risk_col,
            action_col,
            output_filename,
        ],
        outputs=[status_output, preview_output, download_output],
        api_name="get_optimal_screening",
    )
# Start the app only when executed as a script, not on import.
if __name__ == "__main__":
    demo.queue().launch()