#!/usr/bin/env python3
"""Summarize fixed-feature selection-regret scope sweep."""

from __future__ import annotations

import argparse
import csv
import json
import math
import re
import statistics
from collections import defaultdict
from pathlib import Path
from typing import Any, Iterable

# Default output files are placed under ``<ROOT>/generated``.
ROOT = Path(__file__).resolve().parents[1]

# (model_tag, display label) pairs, in the row order of the generated table.
ROW_ORDER = [
    ("reference", "FireWx-FM ref."),
    ("prithvi_wxc", "Prithvi-WxC"),
    ("aurora", "Aurora"),
    ("climax", "ClimaX"),
    ("stormcast", "StormCast"),
    ("dlwp", "DLWP"),
    ("fcn", "FCN"),
    ("fengwu", "FengWu"),
    ("fuxi", "FuXi"),
    ("pangu6", "Pangu-Weather"),
    ("alphaearth", "AlphaEarth"),
]

# (scope key, LaTeX label) pairs, in the column order of the generated table.
SCOPE_ORDER = [
    ("global", r"\(\Omega=\)global"),
    ("top5", r"\(\Omega=\)top 5\%"),
    ("top10", r"\(\Omega=\)top 10\%"),
    ("top20", r"\(\Omega=\)top 20\%"),
]

# (CSV column, short label) pairs for the metrics that get summarized.
METRICS = [
    ("exact_regret", "Exact"),
    ("tolerated_regret", "Tol."),
    ("union_regret", "Union"),
]


def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the summarizer.

    Returns:
        Namespace with ``run_root`` (required), the four output paths
        ``out_json``/``out_csv``/``out_tex``/``out_audit`` (defaulting to
        files under ``ROOT / "generated"``) and ``min_seeds`` (default 5).
    """
    parser = argparse.ArgumentParser(description="Summarize selection-regret scope sweep.")
    parser.add_argument("--run-root", type=Path, required=True)
    parser.add_argument(
        "--out-json",
        type=Path,
        default=ROOT / "generated" / "selection_regret_scope_sweep_20260505.json",
    )
    parser.add_argument(
        "--out-csv",
        type=Path,
        default=ROOT / "generated" / "selection_regret_scope_sweep_20260505.csv",
    )
    parser.add_argument(
        "--out-tex",
        type=Path,
        default=ROOT / "generated" / "selection_regret_scope_sweep_20260505.tex",
    )
    parser.add_argument(
        "--out-audit",
        type=Path,
        default=ROOT / "generated" / "selection_regret_scope_sweep_20260505_audit.md",
    )
    parser.add_argument("--min-seeds", type=int, default=5)
    return parser.parse_args()


def read_csv(path: Path) -> list[dict[str, str]]:
    """Read a UTF-8 CSV file with a header row into a list of row dicts."""
    with path.open("r", encoding="utf-8", newline="") as fh:
        return list(csv.DictReader(fh))


def finite(values: Iterable[Any]) -> list[float]:
    """Return the finite floats obtainable from *values*, preserving order.

    Entries that ``float()`` cannot convert, and entries that convert to
    NaN or +/-inf, are skipped.
    """
    out: list[float] = []
    for value in values:
        try:
            number = float(value)
        except (TypeError, ValueError, OverflowError):
            # Only conversion failures are skipped; other errors propagate.
            continue
        if math.isfinite(number):
            out.append(number)
    return out


def stat(values: Iterable[Any]) -> dict[str, Any]:
    """Compute count, mean, sample std, min and max of the finite values.

    Returns:
        Dict with keys ``n``, ``mean``, ``std``, ``min`` and ``max``. When
        no finite value is present ``n`` is 0 and every other entry is NaN;
        with exactly one value ``std`` is 0.0.
    """
    vals = finite(values)
    if not vals:
        return {"n": 0, "mean": math.nan, "std": math.nan, "min": math.nan, "max": math.nan}
    return {
        "n": len(vals),
        "mean": float(statistics.fmean(vals)),
        # statistics.stdev needs at least two data points.
        "std": float(statistics.stdev(vals)) if len(vals) > 1 else 0.0,
        "min": float(min(vals)),
        "max": float(max(vals)),
    }


def ms(summary: dict[str, Any], scale: float = 100.0) -> str:
    r"""Format a ``stat`` summary as a LaTeX ``\ms{mean}{std}`` cell.

    Args:
        summary: Dict with ``n``, ``mean`` and ``std`` as produced by ``stat``.
        scale: Multiplier applied to mean and std before formatting.

    Returns:
        ``"not_bundled"`` when fewer than two values were summarized,
        ``"0.0000"`` when scaled mean and std both round to zero at four
        decimals, otherwise ``\ms{mean}{std}`` with four decimals each.
    """
    n = int(summary.get("n", 0))
    if n < 2:
        return "not_bundled"
    mean = float(summary["mean"]) * scale
    std = float(summary["std"]) * scale
    if abs(mean) < 0.00005 and abs(std) < 0.00005:
        return "0.0000"
    mean_text = f"{mean:.4f}"
    if mean_text == "-0.0000":
        # A tiny negative mean would otherwise be displayed as negative zero.
        mean_text = "0.0000"
    std_text = f"{std:.4f}"
    return rf"\ms{{{mean_text}}}{{{std_text}}}"


def collect_rows(run_root: Path) -> list[dict[str, Any]]:
    """Load every ``selection_rows.csv`` found two levels below *run_root*.

    Each CSV row is copied and enriched with the source ``path`` (as a
    string), an integer ``seed`` and float values for every metric listed in
    ``METRICS``. Files are visited in sorted path order.

    Raises:
        KeyError: If a row lacks the ``seed`` column or a metric column.
        ValueError: If ``seed`` or a metric value is not numeric.
    """
    rows: list[dict[str, Any]] = []
    for path in sorted(run_root.glob("*/*/selection_rows.csv")):
        for row in read_csv(path):
            enriched: dict[str, Any] = dict(row)
            enriched["path"] = str(path)
            # Going through float() accepts seeds written as e.g. "3.0".
            enriched["seed"] = int(float(row["seed"]))
            for metric, _ in METRICS:
                enriched[metric] = float(row[metric])
            rows.append(enriched)
    return rows


def summarize(rows: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Aggregate rows per (model_tag, scope) in table order.

    One item is emitted for every combination of ``ROW_ORDER`` and
    ``SCOPE_ORDER``, including combinations without any row (``n`` is 0
    then). Each item carries the labels, the row count ``n``, the sorted
    distinct ``seeds`` and ``paths``, and a ``stat`` dict per metric.
    """
    by_key: dict[tuple[str, str], list[dict[str, Any]]] = defaultdict(list)
    for row in rows:
        by_key[(str(row["model_tag"]), str(row["scope"]))].append(row)

    summary: list[dict[str, Any]] = []
    for model_tag, label in ROW_ORDER:
        for scope, scope_label in SCOPE_ORDER:
            selected = by_key[(model_tag, scope)]
            seeds = sorted({int(row["seed"]) for row in selected})
            item: dict[str, Any] = {
                "model_tag": model_tag,
                "label": label,
                "scope": scope,
                "scope_label": scope_label,
                "n": len(selected),
                "seeds": seeds,
                "paths": sorted({str(row["path"]) for row in selected}),
            }
            for metric, _ in METRICS:
                item[metric] = stat(row[metric] for row in selected)
            summary.append(item)
    return summary


def write_csv(rows: list[dict[str, Any]], path: Path) -> None:
    """Write the summary items as a flat CSV file at *path*.

    Each metric is expanded into ``<metric>_mean``, ``_std``, ``_min`` and
    ``_max`` columns; seeds are joined with spaces. The parent directory is
    created if needed.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    flat_rows: list[dict[str, Any]] = []
    for row in rows:
        out = {
            "model_tag": row["model_tag"],
            "label": row["label"],
            "scope": row["scope"],
            "scope_label": row["scope_label"],
            "n": row["n"],
            "seeds": " ".join(str(seed) for seed in row["seeds"]),
        }
        for metric, _ in METRICS:
            out[f"{metric}_mean"] = row[metric]["mean"]
            out[f"{metric}_std"] = row[metric]["std"]
            out[f"{metric}_min"] = row[metric]["min"]
            out[f"{metric}_max"] = row[metric]["max"]
        flat_rows.append(out)
    # The column order follows the first flattened row.
    fieldnames = list(flat_rows[0]) if flat_rows else []
    with path.open("w", newline="", encoding="utf-8") as fh:
        writer = csv.DictWriter(fh, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(flat_rows)


def write_tex(rows: list[dict[str, Any]], path: Path) -> str:
    """Render the ``union_regret`` table as LaTeX, write it and return it.

    Args:
        rows: Summary items; one is required for every combination of
            ``ROW_ORDER`` and ``SCOPE_ORDER``.
        path: Destination file; its parent directory is created if needed.

    Returns:
        The LaTeX source that was written.

    Raises:
        KeyError: If a (model_tag, scope) combination is missing from *rows*.
    """
    by_key = {(row["model_tag"], row["scope"]): row for row in rows}
    lines = [
        r"\begin{table*}[!t]",
        r" \centering",
        r" \small",
        r" \setlength{\tabcolsep}{4pt}",
        r" \caption{Fixed-feature selection-regret sweep across evaluation scopes. Values are percentage-point regret \(\delta = D(h_D)-D(h_R)\) under union-\(F_1\). Top-\(k\) scopes are train-defined fire-prone masks. Rows report mean with small std over five seeds.}",
        r" \label{tab:selection_regret_scope_sweep}",
        r" \begin{tabular}{lcccc}",
        r" \toprule",
        r" \textbf{Feature source} & \textbf{\(\Omega=\)global} & \textbf{\(\Omega=\)top 5\%} & \textbf{\(\Omega=\)top 10\%} & \textbf{\(\Omega=\)top 20\%} \\",
        r" \midrule",
    ]
    for model_tag, label in ROW_ORDER:
        cells = [ms(by_key[(model_tag, scope)]["union_regret"]) for scope, _ in SCOPE_ORDER]
        if model_tag == "reference":
            # The reference row gets a colored label in the table.
            label = r"\textcolor{blue}{FireWx-FM ref.}"
        lines.append(" " + label + " & " + " & ".join(cells) + r" \\")
    lines.extend(
        [
            r" \bottomrule",
            r" \end{tabular}",
            r"\end{table*}",
            "",
        ]
    )
    tex = "\n".join(lines)
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text(tex, encoding="utf-8")
    return tex


def audit(summary: list[dict[str, Any]], tex: str, min_seeds: int) -> tuple[list[str], list[str]]:
    r"""Check the summary and the rendered table for suspicious content.

    Issues are reported for: items with fewer than *min_seeds* rows,
    non-finite metric mean/std, ``\ms`` cells whose displayed std is
    ``0.0000``, identical ``\ms`` cells appearing more than once, and
    placeholder tokens present in *tex*. Notes record metrics whose std
    displays as ``0.0000`` after scaling by 100.

    Returns:
        ``(issues, notes)``; a non-empty ``issues`` list makes ``main`` fail.
    """
    issues: list[str] = []
    notes: list[str] = []
    for row in summary:
        if int(row["n"]) < int(min_seeds):
            issues.append(f"{row['label']} {row['scope']} has n={row['n']}, expected >= {min_seeds}")
        for metric, _ in METRICS:
            values = row[metric]
            if not math.isfinite(float(values["mean"])) or not math.isfinite(float(values["std"])):
                issues.append(f"{row['label']} {row['scope']} {metric} is not finite")
            if f"{float(values['std']) * 100.0:.4f}" == "0.0000":
                notes.append(f"{row['label']} {row['scope']} {metric} has true zero displayed std")

    # Inspect the (mean, std) pairs exactly as they are displayed in the table.
    cells = re.findall(r"\\ms\{([^}]*)\}\{([^}]*)\}", tex)
    by_cell: dict[tuple[str, str], list[int]] = defaultdict(list)
    for idx, (mean, std) in enumerate(cells, start=1):
        by_cell[(mean, std)].append(idx)
        if std == "0.0000":
            issues.append(f"zero displayed std in cell {idx}: {mean} +/- {std}")
    for cell, idxs in by_cell.items():
        if len(idxs) > 1 and cell != ("0.0000", "0.0000"):
            issues.append(f"duplicate displayed cell {cell} at positions {idxs}")

    forbidden = ["not_bundled", "not_applicable", "--", "nan", "NaN", "tied"]
    for token in forbidden:
        if token in tex:
            issues.append(f"forbidden token in tex: {token}")
    return issues, notes


def main() -> None:
    """Run the pipeline: collect, summarize, write outputs, then audit.

    Writes the LaTeX table, the CSV, the JSON dump and the audit report.

    Raises:
        SystemExit: If the audit reports any issue; the audit file then
            starts with ``FAIL`` and lists the issues.
    """
    args = parse_args()
    rows = collect_rows(args.run_root)
    summary = summarize(rows)
    tex = write_tex(summary, args.out_tex)
    write_csv(summary, args.out_csv)
    issues, notes = audit(summary, tex, int(args.min_seeds))

    args.out_json.parent.mkdir(parents=True, exist_ok=True)
    args.out_json.write_text(
        json.dumps({"rows": rows, "summary": summary, "issues": issues, "notes": notes}, indent=2),
        encoding="utf-8",
    )

    # The audit file may live in a different directory than the JSON output.
    args.out_audit.parent.mkdir(parents=True, exist_ok=True)
    if issues:
        args.out_audit.write_text(
            "FAIL\n" + "\n".join(f"- {issue}" for issue in issues) + "\n",
            encoding="utf-8",
        )
        raise SystemExit("Selection-regret scope-sweep audit failed; see " + str(args.out_audit))
    args.out_audit.write_text(
        "PASS\n" + "\n".join(f"- {note}" for note in notes) + ("\n" if notes else ""),
        encoding="utf-8",
    )
    print(
        json.dumps(
            {"rows": len(rows), "summary_rows": len(summary), "tex": str(args.out_tex), "audit": "PASS"},
            indent=2,
        )
    )


if __name__ == "__main__":
    main()