| |
| """Summarize fixed-feature selection-regret scope sweep.""" |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import csv |
| import json |
| import math |
| import re |
| import statistics |
| from collections import defaultdict |
| from pathlib import Path |
| from typing import Any, Iterable |
|
|
|
|
# Directory two levels above this file (the parent of the script's directory);
# default output paths are built under ROOT / "generated".
ROOT = Path(__file__).resolve().parents[1]
# (model_tag, display label) pairs, in the order rows appear in the table.
ROW_ORDER = [
    ("reference", "FireWx-FM ref."),
    ("prithvi_wxc", "Prithvi-WxC"),
    ("aurora", "Aurora"),
    ("climax", "ClimaX"),
    ("stormcast", "StormCast"),
    ("dlwp", "DLWP"),
    ("fcn", "FCN"),
    ("fengwu", "FengWu"),
    ("fuxi", "FuXi"),
    ("pangu6", "Pangu-Weather"),
    ("alphaearth", "AlphaEarth"),
]
# (scope key, LaTeX column label) pairs, in table column order.
SCOPE_ORDER = [
    ("global", r"\(\Omega=\)global"),
    ("top5", r"\(\Omega=\)top 5\%"),
    ("top10", r"\(\Omega=\)top 10\%"),
    ("top20", r"\(\Omega=\)top 20\%"),
]
# (CSV column name, short label) for each regret metric that gets summarized.
METRICS = [
    ("exact_regret", "Exact"),
    ("tolerated_regret", "Tol."),
    ("union_regret", "Union"),
]
|
|
|
|
def parse_args() -> argparse.Namespace:
    """Build and evaluate the command-line interface for this summarizer."""
    out_dir = ROOT / "generated"
    base = "selection_regret_scope_sweep_20260505"
    parser = argparse.ArgumentParser(description="Summarize selection-regret scope sweep.")
    parser.add_argument("--run-root", type=Path, required=True)
    parser.add_argument("--out-json", type=Path, default=out_dir / (base + ".json"))
    parser.add_argument("--out-csv", type=Path, default=out_dir / (base + ".csv"))
    parser.add_argument("--out-tex", type=Path, default=out_dir / (base + ".tex"))
    parser.add_argument("--out-audit", type=Path, default=out_dir / (base + "_audit.md"))
    parser.add_argument("--min-seeds", type=int, default=5)
    return parser.parse_args()
|
|
|
|
def read_csv(path: Path) -> list[dict[str, str]]:
    """Load *path* as a list of header-keyed row dictionaries."""
    with path.open("r", encoding="utf-8", newline="") as handle:
        reader = csv.DictReader(handle)
        return [dict(record) for record in reader]
|
|
|
|
def finite(values: Iterable[Any]) -> list[float]:
    """Coerce *values* to floats, keeping only finite numbers.

    Non-numeric entries and non-finite results (NaN, +/-inf) are skipped so
    callers can feed raw CSV strings straight in.
    """
    out: list[float] = []
    for value in values:
        try:
            number = float(value)
        # Narrowed from a bare ``except Exception``: float() signals coercion
        # failure with TypeError/ValueError only; anything else is a real bug
        # that should propagate instead of being silently dropped.
        except (TypeError, ValueError):
            continue
        if math.isfinite(number):
            out.append(number)
    return out
|
|
|
|
def stat(values: Iterable[Any]) -> dict[str, Any]:
    """Summary statistics (n/mean/std/min/max) over the finite floats in *values*.

    Entries that cannot be coerced to a finite float are ignored; an empty
    result yields NaN statistics with n == 0.
    """
    usable: list[float] = []
    for raw in values:
        try:
            number = float(raw)
        except Exception:
            continue
        if math.isfinite(number):
            usable.append(number)
    if not usable:
        return {"n": 0, "mean": math.nan, "std": math.nan, "min": math.nan, "max": math.nan}
    # Sample std needs at least two points; report exactly zero otherwise.
    spread = float(statistics.stdev(usable)) if len(usable) > 1 else 0.0
    return {
        "n": len(usable),
        "mean": float(statistics.fmean(usable)),
        "std": spread,
        "min": float(min(usable)),
        "max": float(max(usable)),
    }
|
|
|
|
def ms(summary: dict[str, Any], scale: float = 100.0) -> str:
    """Format a mean/std summary as a LaTeX ``\\ms`` cell, scaled by *scale*.

    Returns the sentinel ``not_bundled`` when fewer than two samples exist
    (the audit step treats that token as fatal), and a bare ``0.0000`` when
    both displayed values would round to zero.
    """
    count = int(summary.get("n", 0))
    if count < 2:
        return "not_bundled"
    scaled_mean = float(summary["mean"]) * scale
    scaled_std = float(summary["std"]) * scale
    rounds_to_zero = abs(scaled_mean) < 0.00005 and abs(scaled_std) < 0.00005
    if rounds_to_zero:
        return "0.0000"
    return rf"\ms{{{scaled_mean:.4f}}}{{{scaled_std:.4f}}}"
|
|
|
|
def collect_rows(run_root: Path) -> list[dict[str, Any]]:
    """Gather per-seed rows from every ``*/*/selection_rows.csv`` under *run_root*.

    Each row keeps its raw CSV fields and gains: ``path`` (source file),
    ``seed`` (int), and a float copy of every metric column in METRICS.
    """
    collected: list[dict[str, Any]] = []
    for csv_path in sorted(run_root.glob("*/*/selection_rows.csv")):
        for raw in read_csv(csv_path):
            record: dict[str, Any] = dict(raw)
            record["path"] = str(csv_path)
            # Seeds may be serialized as "3.0"; go through float first.
            record["seed"] = int(float(raw["seed"]))
            for metric_key, _ in METRICS:
                record[metric_key] = float(raw[metric_key])
            collected.append(record)
    return collected
|
|
|
|
def summarize(rows: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Aggregate raw per-seed rows into one summary entry per (model, scope).

    Every ROW_ORDER x SCOPE_ORDER combination produces an entry even when no
    rows match (n == 0, NaN statistics), so downstream output stays rectangular.
    """
    grouped: dict[tuple[str, str], list[dict[str, Any]]] = defaultdict(list)
    for row in rows:
        grouped[(str(row["model_tag"]), str(row["scope"]))].append(row)

    out: list[dict[str, Any]] = []
    for model_tag, label in ROW_ORDER:
        for scope, scope_label in SCOPE_ORDER:
            bucket = grouped[(model_tag, scope)]
            entry: dict[str, Any] = {
                "model_tag": model_tag,
                "label": label,
                "scope": scope,
                "scope_label": scope_label,
                "n": len(bucket),
                "seeds": sorted({int(member["seed"]) for member in bucket}),
                "paths": sorted({str(member["path"]) for member in bucket}),
            }
            for metric_key, _ in METRICS:
                entry[metric_key] = stat(member[metric_key] for member in bucket)
            out.append(entry)
    return out
|
|
|
|
def write_csv(rows: list[dict[str, Any]], path: Path) -> None:
    """Flatten summary entries and write them to *path* as a CSV file."""
    path.parent.mkdir(parents=True, exist_ok=True)
    flattened: list[dict[str, Any]] = []
    for entry in rows:
        record: dict[str, Any] = {
            "model_tag": entry["model_tag"],
            "label": entry["label"],
            "scope": entry["scope"],
            "scope_label": entry["scope_label"],
            "n": entry["n"],
            "seeds": " ".join(str(seed) for seed in entry["seeds"]),
        }
        # One column per metric statistic, e.g. union_regret_mean.
        for metric_key, _ in METRICS:
            stats = entry[metric_key]
            for field in ("mean", "std", "min", "max"):
                record[f"{metric_key}_{field}"] = stats[field]
        flattened.append(record)
    header = list(flattened[0]) if flattened else []
    with path.open("w", newline="", encoding="utf-8") as handle:
        writer = csv.DictWriter(handle, fieldnames=header)
        writer.writeheader()
        writer.writerows(flattened)
|
|
|
|
def write_tex(rows: list[dict[str, Any]], path: Path) -> str:
    """Render the union-regret scope-sweep LaTeX table and write it to *path*.

    Only the ``union_regret`` statistic is tabulated; rows follow ROW_ORDER
    and columns follow SCOPE_ORDER. Returns the rendered source so the
    caller can run the audit over the exact emitted text.
    """
    # Index summary entries for direct (model_tag, scope) cell lookup.
    by_key = {(row["model_tag"], row["scope"]): row for row in rows}
    lines = [
        r"\begin{table*}[!t]",
        r" \centering",
        r" \small",
        r" \setlength{\tabcolsep}{4pt}",
        r" \caption{Fixed-feature selection-regret sweep across evaluation scopes. Values are percentage-point regret \(\delta = D(h_D)-D(h_R)\) under union-\(F_1\). Top-\(k\) scopes are train-defined fire-prone masks. Rows report mean with small std over five seeds.}",
        r" \label{tab:selection_regret_scope_sweep}",
        r" \begin{tabular}{lcccc}",
        r" \toprule",
        r" \textbf{Feature source} & \textbf{\(\Omega=\)global} & \textbf{\(\Omega=\)top 5\%} & \textbf{\(\Omega=\)top 10\%} & \textbf{\(\Omega=\)top 20\%} \\",
        r" \midrule",
    ]
    for model_tag, label in ROW_ORDER:
        # One \ms{mean}{std} cell per scope, in column order; KeyError here
        # means summarize() did not produce the full model x scope grid.
        cells = [ms(by_key[(model_tag, scope)]["union_regret"]) for scope, _ in SCOPE_ORDER]
        if model_tag == "reference":
            # The reference feature source is highlighted in blue.
            label = r"\textcolor{blue}{FireWx-FM ref.}"
        lines.append(" " + label + " & " + " & ".join(cells) + r" \\")
    lines.extend(
        [
            r" \bottomrule",
            r" \end{tabular}",
            r"\end{table*}",
            "",
        ]
    )
    tex = "\n".join(lines)
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text(tex, encoding="utf-8")
    return tex
|
|
|
|
def audit(summary: list[dict[str, Any]], tex: str, min_seeds: int) -> tuple[list[str], list[str]]:
    """Validate the summary and rendered table.

    Returns ``(issues, notes)``: issues are blocking failures (too few seeds,
    non-finite stats, zero or duplicated displayed cells, placeholder tokens
    left in the LaTeX), notes are informational only.
    """
    issues: list[str] = []
    notes: list[str] = []
    for row in summary:
        if int(row["n"]) < int(min_seeds):
            issues.append(f"{row['label']} {row['scope']} has n={row['n']}, expected >= {min_seeds}")
        for metric, _ in METRICS:
            values = row[metric]
            mean_finite = math.isfinite(float(values["mean"]))
            std_finite = math.isfinite(float(values["std"]))
            if not (mean_finite and std_finite):
                issues.append(f"{row['label']} {row['scope']} {metric} is not finite")
            # Compare the *displayed* (scaled, rounded) std against zero.
            if f"{float(values['std']) * 100.0:.4f}" == "0.0000":
                notes.append(f"{row['label']} {row['scope']} {metric} has true zero displayed std")
    seen: dict[tuple[str, str], list[int]] = defaultdict(list)
    rendered = re.findall(r"\\ms\{([^}]*)\}\{([^}]*)\}", tex)
    for idx, (mean, std) in enumerate(rendered, start=1):
        seen[(mean, std)].append(idx)
        if std == "0.0000":
            issues.append(f"zero displayed std in cell {idx}: {mean} +/- {std}")
    for cell, idxs in seen.items():
        # Identical non-zero cells across the table suggest copy/paste data.
        if len(idxs) > 1 and cell != ("0.0000", "0.0000"):
            issues.append(f"duplicate displayed cell {cell} at positions {idxs}")
    for token in ("not_bundled", "not_applicable", "--", "nan", "NaN", "tied"):
        if token in tex:
            issues.append(f"forbidden token in tex: {token}")
    return issues, notes
|
|
|
|
def main() -> None:
    """Entry point: collect rows, summarize, render, audit, and persist outputs.

    Raises SystemExit when the audit reports blocking issues; the failure
    details are written to the audit file first.
    """
    args = parse_args()
    rows = collect_rows(args.run_root)
    summary = summarize(rows)
    tex = write_tex(summary, args.out_tex)
    write_csv(summary, args.out_csv)
    issues, notes = audit(summary, tex, int(args.min_seeds))
    args.out_json.parent.mkdir(parents=True, exist_ok=True)
    payload = {"rows": rows, "summary": summary, "issues": issues, "notes": notes}
    args.out_json.write_text(json.dumps(payload, indent=2), encoding="utf-8")
    if issues:
        report = "FAIL\n" + "\n".join(f"- {issue}" for issue in issues) + "\n"
        args.out_audit.write_text(report, encoding="utf-8")
        raise SystemExit("Selection-regret scope-sweep audit failed; see " + str(args.out_audit))
    note_body = "\n".join(f"- {note}" for note in notes)
    args.out_audit.write_text("PASS\n" + note_body + ("\n" if notes else ""), encoding="utf-8")
    status = {"rows": len(rows), "summary_rows": len(summary), "tex": str(args.out_tex), "audit": "PASS"}
    print(json.dumps(status, indent=2))
|
|
|
|
| if __name__ == "__main__": |
| main() |
|
|