# Wildfire-FM / scripts / summarize_selection_regret_scope_sweep_20260505.py
# Upload metadata (uploader: yx21e): "Initial FireWx-FM artifact release",
# commit 80ef3b2 (verified). Commented out: page chrome from the file host,
# not Python source.
#!/usr/bin/env python3
"""Summarize fixed-feature selection-regret scope sweep."""
from __future__ import annotations
import argparse
import csv
import json
import math
import re
import statistics
from collections import defaultdict
from pathlib import Path
from typing import Any, Iterable
# Repository root: this script lives in <root>/scripts/, so parents[1] is <root>.
ROOT = Path(__file__).resolve().parents[1]
# (model_tag, display label) pairs in the order rows appear in the LaTeX table.
ROW_ORDER: list[tuple[str, str]] = [
    ("reference", "FireWx-FM ref."),
    ("prithvi_wxc", "Prithvi-WxC"),
    ("aurora", "Aurora"),
    ("climax", "ClimaX"),
    ("stormcast", "StormCast"),
    ("dlwp", "DLWP"),
    ("fcn", "FCN"),
    ("fengwu", "FengWu"),
    ("fuxi", "FuXi"),
    ("pangu6", "Pangu-Weather"),
    ("alphaearth", "AlphaEarth"),
]
# (scope key, LaTeX column label) pairs in the order columns appear in the table.
SCOPE_ORDER: list[tuple[str, str]] = [
    ("global", r"\(\Omega=\)global"),
    ("top5", r"\(\Omega=\)top 5\%"),
    ("top10", r"\(\Omega=\)top 10\%"),
    ("top20", r"\(\Omega=\)top 20\%"),
]
# (csv column name, short display label) for each regret metric summarized.
METRICS: list[tuple[str, str]] = [
    ("exact_regret", "Exact"),
    ("tolerated_regret", "Tol."),
    ("union_regret", "Union"),
]
def parse_args() -> argparse.Namespace:
    """Parse CLI options for the scope-sweep summarizer.

    Only ``--run-root`` is required; every output path defaults to a file
    under ``<repo>/generated/`` sharing the sweep's date-stamped stem.
    """
    generated = ROOT / "generated"
    stem = "selection_regret_scope_sweep_20260505"
    parser = argparse.ArgumentParser(description="Summarize selection-regret scope sweep.")
    parser.add_argument("--run-root", type=Path, required=True)
    parser.add_argument("--out-json", type=Path, default=generated / f"{stem}.json")
    parser.add_argument("--out-csv", type=Path, default=generated / f"{stem}.csv")
    parser.add_argument("--out-tex", type=Path, default=generated / f"{stem}.tex")
    parser.add_argument("--out-audit", type=Path, default=generated / f"{stem}_audit.md")
    parser.add_argument("--min-seeds", type=int, default=5)
    return parser.parse_args()
def read_csv(path: Path) -> list[dict[str, str]]:
    """Read *path* as UTF-8 CSV and return every row as a header-keyed dict."""
    with path.open("r", encoding="utf-8", newline="") as handle:
        return [dict(record) for record in csv.DictReader(handle)]
def finite(values: Iterable[Any]) -> list[float]:
    """Coerce *values* to floats, keeping only the finite ones.

    Entries that cannot be coerced (e.g. ``None``, non-numeric strings) are
    skipped, as are NaN and +/-inf, so downstream statistics stay well-defined.

    Fix: catch only ``(TypeError, ValueError)`` — the exceptions ``float()``
    raises on bad input — instead of a bare ``except Exception`` that would
    also hide genuine bugs (e.g. a typo'd name inside the loop).
    """
    out: list[float] = []
    for value in values:
        try:
            number = float(value)
        except (TypeError, ValueError):
            # Expected coercion failure for non-numeric cells; skip silently.
            continue
        if math.isfinite(number):
            out.append(number)
    return out
def stat(values: Iterable[Any]) -> dict[str, Any]:
    """Summarize the finite entries of *values* as n/mean/std/min/max.

    Every statistic is NaN (with ``n == 0``) when no finite value survives
    filtering; the std of a single sample is reported as 0.0.
    """
    samples = finite(values)
    if not samples:
        nan = math.nan
        return {"n": 0, "mean": nan, "std": nan, "min": nan, "max": nan}
    spread = statistics.stdev(samples) if len(samples) > 1 else 0.0
    return {
        "n": len(samples),
        "mean": float(statistics.fmean(samples)),
        "std": float(spread),
        "min": float(min(samples)),
        "max": float(max(samples)),
    }
def ms(summary: dict[str, Any], scale: float = 100.0) -> str:
    r"""Format a mean/std summary as a LaTeX ``\ms{mean}{std}`` cell.

    Values are multiplied by *scale* (percent points by default). Returns
    ``"not_bundled"`` when fewer than two samples back the summary, and a
    bare ``"0.0000"`` when both scaled values round to zero at 4 decimals.
    """
    if int(summary.get("n", 0)) < 2:
        return "not_bundled"
    scaled_mean = scale * float(summary["mean"])
    scaled_std = scale * float(summary["std"])
    if abs(scaled_mean) < 0.00005 and abs(scaled_std) < 0.00005:
        return "0.0000"
    return rf"\ms{{{scaled_mean:.4f}}}{{{scaled_std:.4f}}}"
def collect_rows(run_root: Path) -> list[dict[str, Any]]:
    """Load every ``*/*/selection_rows.csv`` under *run_root* into one list.

    Each CSV row is copied, tagged with its source ``path``, and has its
    ``seed`` plus every metric column coerced to numeric types.
    """
    collected: list[dict[str, Any]] = []
    for csv_path in sorted(run_root.glob("*/*/selection_rows.csv")):
        for raw in read_csv(csv_path):
            record: dict[str, Any] = dict(raw)
            record["path"] = str(csv_path)
            # Seeds may be serialized as "3.0"; go through float first.
            record["seed"] = int(float(raw["seed"]))
            record.update((name, float(raw[name])) for name, _ in METRICS)
            collected.append(record)
    return collected
def summarize(rows: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Aggregate raw rows into one summary entry per (model, scope) pair.

    Entries are emitted in ``ROW_ORDER`` x ``SCOPE_ORDER`` order so downstream
    writers can rely on a stable layout; pairs with no data get ``n == 0`` and
    NaN statistics from ``stat``.
    """
    grouped: dict[tuple[str, str], list[dict[str, Any]]] = defaultdict(list)
    for row in rows:
        grouped[(str(row["model_tag"]), str(row["scope"]))].append(row)
    out: list[dict[str, Any]] = []
    for model_tag, label in ROW_ORDER:
        for scope, scope_label in SCOPE_ORDER:
            bucket = grouped[(model_tag, scope)]
            entry: dict[str, Any] = {
                "model_tag": model_tag,
                "label": label,
                "scope": scope,
                "scope_label": scope_label,
                "n": len(bucket),
                "seeds": sorted({int(r["seed"]) for r in bucket}),
                "paths": sorted({str(r["path"]) for r in bucket}),
            }
            for metric, _ in METRICS:
                entry[metric] = stat([r[metric] for r in bucket])
            out.append(entry)
    return out
def write_csv(rows: list[dict[str, Any]], path: Path) -> None:
path.parent.mkdir(parents=True, exist_ok=True)
flat_rows: list[dict[str, Any]] = []
for row in rows:
out = {
"model_tag": row["model_tag"],
"label": row["label"],
"scope": row["scope"],
"scope_label": row["scope_label"],
"n": row["n"],
"seeds": " ".join(str(seed) for seed in row["seeds"]),
}
for metric, _ in METRICS:
out[f"{metric}_mean"] = row[metric]["mean"]
out[f"{metric}_std"] = row[metric]["std"]
out[f"{metric}_min"] = row[metric]["min"]
out[f"{metric}_max"] = row[metric]["max"]
flat_rows.append(out)
fieldnames = list(flat_rows[0]) if flat_rows else []
with path.open("w", newline="", encoding="utf-8") as fh:
writer = csv.DictWriter(fh, fieldnames=fieldnames)
writer.writeheader()
writer.writerows(flat_rows)
def write_tex(rows: list[dict[str, Any]], path: Path) -> str:
    """Render the union-regret sweep as a LaTeX table, write it, return it.

    One row per ``ROW_ORDER`` entry and one column per ``SCOPE_ORDER`` scope;
    the reference row is highlighted in blue.
    """
    lookup = {(entry["model_tag"], entry["scope"]): entry for entry in rows}
    header = [
        r"\begin{table*}[!t]",
        r" \centering",
        r" \small",
        r" \setlength{\tabcolsep}{4pt}",
        r" \caption{Fixed-feature selection-regret sweep across evaluation scopes. Values are percentage-point regret \(\delta = D(h_D)-D(h_R)\) under union-\(F_1\). Top-\(k\) scopes are train-defined fire-prone masks. Rows report mean with small std over five seeds.}",
        r" \label{tab:selection_regret_scope_sweep}",
        r" \begin{tabular}{lcccc}",
        r" \toprule",
        r" \textbf{Feature source} & \textbf{\(\Omega=\)global} & \textbf{\(\Omega=\)top 5\%} & \textbf{\(\Omega=\)top 10\%} & \textbf{\(\Omega=\)top 20\%} \\",
        r" \midrule",
    ]
    body: list[str] = []
    for model_tag, label in ROW_ORDER:
        # The reference feature source is color-highlighted in the table.
        shown = r"\textcolor{blue}{FireWx-FM ref.}" if model_tag == "reference" else label
        cells = [ms(lookup[(model_tag, scope)]["union_regret"]) for scope, _ in SCOPE_ORDER]
        body.append(" " + shown + " & " + " & ".join(cells) + r" \\")
    footer = [
        r" \bottomrule",
        r" \end{tabular}",
        r"\end{table*}",
        "",
    ]
    tex = "\n".join(header + body + footer)
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text(tex, encoding="utf-8")
    return tex
def audit(summary: list[dict[str, Any]], tex: str, min_seeds: int) -> tuple[list[str], list[str]]:
    """Cross-check summary statistics and the rendered TeX table.

    Returns ``(issues, notes)``: issues fail the run (too few seeds,
    non-finite stats, zero displayed stds inside ``\\ms`` cells, duplicated
    cells, forbidden placeholder tokens), notes are informational only.
    """
    issues: list[str] = []
    notes: list[str] = []
    for row in summary:
        if int(row["n"]) < int(min_seeds):
            issues.append(f"{row['label']} {row['scope']} has n={row['n']}, expected >= {min_seeds}")
        for metric, _ in METRICS:
            block = row[metric]
            if not math.isfinite(float(block["mean"])) or not math.isfinite(float(block["std"])):
                issues.append(f"{row['label']} {row['scope']} {metric} is not finite")
            # A std that *displays* as 0.0000 is worth flagging even if nonzero.
            if f"{float(block['std']) * 100.0:.4f}" == "0.0000":
                notes.append(f"{row['label']} {row['scope']} {metric} has true zero displayed std")
    seen: dict[tuple[str, str], list[int]] = defaultdict(list)
    extracted = re.findall(r"\\ms\{([^}]*)\}\{([^}]*)\}", tex)
    for position, (mean_txt, std_txt) in enumerate(extracted, start=1):
        seen[(mean_txt, std_txt)].append(position)
        if std_txt == "0.0000":
            issues.append(f"zero displayed std in cell {position}: {mean_txt} +/- {std_txt}")
    for pair, positions in seen.items():
        if len(positions) > 1 and pair != ("0.0000", "0.0000"):
            issues.append(f"duplicate displayed cell {pair} at positions {positions}")
    for token in ["not_bundled", "not_applicable", "--", "nan", "NaN", "tied"]:
        if token in tex:
            issues.append(f"forbidden token in tex: {token}")
    return issues, notes
def main() -> None:
    """Entry point: collect rows, summarize, render outputs, run the audit.

    Always writes the JSON bundle; writes a FAIL audit file and exits nonzero
    when the audit finds issues, otherwise writes PASS and prints a summary.
    """
    args = parse_args()
    raw_rows = collect_rows(args.run_root)
    summary = summarize(raw_rows)
    tex = write_tex(summary, args.out_tex)
    write_csv(summary, args.out_csv)
    issues, notes = audit(summary, tex, int(args.min_seeds))
    payload = {"rows": raw_rows, "summary": summary, "issues": issues, "notes": notes}
    args.out_json.parent.mkdir(parents=True, exist_ok=True)
    args.out_json.write_text(json.dumps(payload, indent=2), encoding="utf-8")
    if issues:
        report = "FAIL\n" + "\n".join(f"- {issue}" for issue in issues) + "\n"
        args.out_audit.write_text(report, encoding="utf-8")
        raise SystemExit("Selection-regret scope-sweep audit failed; see " + str(args.out_audit))
    note_block = "\n".join(f"- {note}" for note in notes)
    args.out_audit.write_text("PASS\n" + note_block + ("\n" if notes else ""), encoding="utf-8")
    status = {"rows": len(raw_rows), "summary_rows": len(summary), "tex": str(args.out_tex), "audit": "PASS"}
    print(json.dumps(status, indent=2))
# Run the summarizer only when executed as a script, not on import.
if __name__ == "__main__":
    main()