Wildfire-FM / scripts /build_selection_regret_rq2_figure.py
yx21e's picture
Initial FireWx-FM artifact release
80ef3b2 verified
#!/usr/bin/env python3
"""Build the compact RQ2 selection-regret interval figure.
The output is TikZ rather than a raster/PDF plot so the manuscript does not
depend on a local matplotlib installation.
"""
from __future__ import annotations
import csv
import math
from pathlib import Path
# Directory one level above the folder that contains this script (``parents[1]``
# of the resolved script path); all input/output paths below hang off it.
ROOT = Path(__file__).resolve().parents[1]
# Input table parsed by ``read_rows``.
IN_CSV = ROOT / "artifacts" / "results" / "selection_regret_scope_sweep_20260505.csv"
# TikZ source emitted by ``write_tikz``.
OUT_TIKZ = ROOT / "paper_outputs" / "figures" / "fig_selection_regret_rq2.tikz"
# CSV copy of the plotted numbers, emitted by ``main``.
OUT_VALUES = ROOT / "artifacts" / "results" / "selection_regret_rq2_figure_values.csv"
# Labels that must be present in the input CSV, listed in the order the rows
# are drawn (first entry at the top of the figure).
ROW_ORDER = [
    "FireWx-FM ref.",
    "Prithvi-WxC",
    "Aurora",
    "ClimaX",
    "StormCast",
    "DLWP",
    "FCN",
    "FengWu",
    "FuXi",
    "Pangu-Weather",
    "AlphaEarth",
]
# Maps the scope names used inside this script to the values that are matched
# against the ``scope`` column of the input CSV.
SCOPES = {
    "global": "global",
    "fire_prone_top20": "top20",
}
def read_rows() -> dict[tuple[str, str], dict[str, str]]:
    """Load ``IN_CSV`` and index its rows by ``(label, scope)``.

    Returns:
        A mapping from ``(label, scope)`` to the parsed CSV row. When several
        rows share a key, the last one in the file is kept.

    Raises:
        SystemExit: if any ``ROW_ORDER`` label lacks a row for one of the
            ``SCOPES`` values, or if any label in the file contains the
            substring ``"Pangu24"``.
    """
    indexed: dict[tuple[str, str], dict[str, str]] = {}
    seen_labels: set[str] = set()
    with IN_CSV.open("r", encoding="utf-8", newline="") as fh:
        for entry in csv.DictReader(fh):
            name = entry["label"]
            indexed[(name, entry["scope"])] = entry
            seen_labels.add(name)

    # Every figure row needs one CSV row per scope; collect all gaps so the
    # error message reports them in one go.
    absent: list[tuple[str, str]] = []
    for name in ROW_ORDER:
        for scope_value in SCOPES.values():
            key = (name, scope_value)
            if key not in indexed:
                absent.append(key)
    if absent:
        raise SystemExit(f"Missing selection-regret rows: {absent}")

    # Reject files that still carry "Pangu24" labels.
    stale = sorted(name for name in seen_labels if "Pangu24" in name)
    if stale:
        raise SystemExit(f"Stale Pangu24 labels found in final CSV: {stale}")
    return indexed
def pct(row: dict[str, str], field: str) -> float:
    """Parse ``row[field]`` as a float and scale it by 100."""
    fraction = float(row[field])
    return fraction * 100.0
# Single-pass translation table covering the ten characters LaTeX treats as
# special in running text. A single pass matters: chained ``str.replace``
# calls would re-escape the backslashes introduced by earlier replacements.
_TEX_ESCAPES = str.maketrans(
    {
        "\\": r"\textbackslash{}",
        "&": r"\&",
        "%": r"\%",
        "$": r"\$",
        "#": r"\#",
        "_": r"\_",
        "{": r"\{",
        "}": r"\}",
        "~": r"\textasciitilde{}",
        "^": r"\textasciicircum{}",
    }
)


def escape_tex(text: str) -> str:
    """Return *text* with LaTeX special characters escaped.

    ``&`` and ``%`` are escaped exactly as before (``\\&`` and ``\\%``). The
    remaining specials (``$ # _ { } ~ ^`` and the backslash) are now escaped
    too, so a label containing e.g. an underscore no longer breaks the
    generated TikZ.

    Args:
        text: Plain (not already escaped) label text.

    Returns:
        The escaped string, safe to place inside a TikZ node.
    """
    return text.translate(_TEX_ESCAPES)
def nice_floor(value: float, step: float) -> float:
    """Round *value* down to a multiple of *step*."""
    whole_steps = math.floor(value / step)
    return whole_steps * step
def nice_ceil(value: float, step: float) -> float:
    """Round *value* up to a multiple of *step*."""
    whole_steps = math.ceil(value / step)
    return whole_steps * step
def write_tikz(records: list[dict[str, float | str]], x_min: float, x_max: float) -> None:
    """Render the interval figure as TikZ source and write it to ``OUT_TIKZ``.

    Each record becomes one labelled row containing two horizontal
    mean +/- std intervals: a circle-marked one read from the ``global_*``
    keys and a square-marked one read from the ``top20_*`` keys.

    Args:
        records: One dict per row, drawn top to bottom in list order. Each
            needs the keys ``feature_source``, ``global_mean_pp``,
            ``global_std_pp``, ``top20_mean_pp`` and ``top20_std_pp``.
        x_min: Data value mapped to the left edge of the plot area.
        x_max: Data value mapped to the right edge of the plot area.

    Raises:
        ZeroDivisionError: if ``x_min == x_max``.

    Note:
        The colours ``wfgray``, ``wfblue``, ``wfslate`` and ``wforange`` are
        referenced but not defined in the emitted code; presumably the
        including document defines them (confirm against the manuscript).
    """
    # Layout constants, in the picture's centimetre units (x=1cm, y=1cm).
    width = 5.55
    left = 2.45
    y_step = 0.41
    top = 4.35
    bottom = -0.35
    offset = 0.13
    tick_step = 10
    x_span = x_max - x_min

    def x(value: float) -> float:
        """Map a data value onto the horizontal picture coordinate."""
        return left + (value - x_min) / x_span * width

    lines: list[str] = [
        r"% Auto-generated by scripts/build_selection_regret_rq2_figure.py.",
        r"\begin{tikzpicture}[x=1cm,y=1cm]",
        r"\footnotesize",
    ]

    # Grid lines sit on multiples of ``tick_step`` inside [x_min, x_max].
    # Start from the first multiple at or above x_min: starting the range at
    # x_min itself (e.g. -5) would step over -5, 5, 15, ... and never land on
    # a multiple of 10, leaving only the forced zero line.
    first_tick = math.ceil(x_min / tick_step) * tick_step
    last_tick = math.floor(x_max / tick_step) * tick_step
    ticks = set(range(first_tick, last_tick + 1, tick_step))
    ticks.add(0)  # the zero reference line is always drawn
    for tick in sorted(ticks):
        xt = x(float(tick))
        # The zero line is drawn heavier and in its own colour.
        color = "wfgray" if tick == 0 else "black!12"
        lw = "0.55pt" if tick == 0 else "0.35pt"
        lines.append(
            rf"\draw[{color}, line width={lw}] ({xt:.3f},{bottom:.3f}) -- ({xt:.3f},{top + 0.18:.3f});"
        )
        lines.append(rf"\node[anchor=north, font=\scriptsize, text=black!70] at ({xt:.3f},{bottom - 0.06:.3f}) {{{tick}}};")

    # Horizontal axis along the bottom edge.
    axis_y = bottom
    lines.append(
        rf"\draw[black!45, line width=0.4pt] ({x(x_min):.3f},{axis_y:.3f}) -- ({x(x_max):.3f},{axis_y:.3f});"
    )

    for idx, record in enumerate(records):
        y_base = top - idx * y_step
        label = str(record["feature_source"])
        label_tex = escape_tex(label)
        if label == "FireWx-FM ref.":
            # The reference row is highlighted in bold blue.
            label_tex = rf"\textcolor{{wfblue}}{{\textbf{{{label_tex}}}}}"
        lines.append(rf"\node[anchor=east, font=\scriptsize, text=black!82] at ({left - 0.13:.3f},{y_base:.3f}) {{{label_tex}}};")

        # The two intervals are offset vertically so they do not overlap.
        for scope_key, color, y_offset, marker in [
            ("global", "wfslate", -offset, "circle"),
            ("top20", "wforange", offset, "square"),
        ]:
            mean = float(record[f"{scope_key}_mean_pp"])
            std = float(record[f"{scope_key}_std_pp"])
            y_val = y_base + y_offset
            x_lo = x(mean - std)
            x_hi = x(mean + std)
            x_mid = x(mean)
            # Interval bar followed by its two end caps.
            lines.append(rf"\draw[{color}, line width=0.72pt] ({x_lo:.3f},{y_val:.3f}) -- ({x_hi:.3f},{y_val:.3f});")
            lines.append(rf"\draw[{color}, line width=0.72pt] ({x_lo:.3f},{y_val - 0.035:.3f}) -- ({x_lo:.3f},{y_val + 0.035:.3f});")
            lines.append(rf"\draw[{color}, line width=0.72pt] ({x_hi:.3f},{y_val - 0.035:.3f}) -- ({x_hi:.3f},{y_val + 0.035:.3f});")
            # Marker at the mean.
            if marker == "circle":
                lines.append(rf"\filldraw[{color}] ({x_mid:.3f},{y_val:.3f}) circle[radius=0.045];")
            else:
                lines.append(rf"\filldraw[{color}] ({x_mid - 0.045:.3f},{y_val - 0.045:.3f}) rectangle ({x_mid + 0.045:.3f},{y_val + 0.045:.3f});")

    lines.append(r"\end{tikzpicture}")
    OUT_TIKZ.parent.mkdir(parents=True, exist_ok=True)
    OUT_TIKZ.write_text("\n".join(lines) + "\n", encoding="utf-8")
def main() -> None:
    """Build the figure and its companion values table.

    Reads the input rows, scales the ``union_regret_mean`` and
    ``union_regret_std`` fields by 100 for both scopes of every
    ``ROW_ORDER`` label, picks an x-range that covers all mean +/- std
    intervals, writes the TikZ figure, and finally writes the plotted numbers
    to ``OUT_VALUES`` with four decimals.
    """
    table = read_rows()
    value_columns = ("global_mean_pp", "global_std_pp", "top20_mean_pp", "top20_std_pp")

    records: list[dict[str, float | str]] = []
    interval_edges: list[float] = []
    for label in ROW_ORDER:
        global_row = table[(label, SCOPES["global"])]
        fire_row = table[(label, SCOPES["fire_prone_top20"])]
        global_mean = pct(global_row, "union_regret_mean")
        global_std = pct(global_row, "union_regret_std")
        fire_mean = pct(fire_row, "union_regret_mean")
        fire_std = pct(fire_row, "union_regret_std")
        interval_edges += [
            global_mean - global_std,
            global_mean + global_std,
            fire_mean - fire_std,
            fire_mean + fire_std,
        ]
        records.append(
            dict(zip(("feature_source", *value_columns), (label, global_mean, global_std, fire_mean, fire_std)))
        )

    # The axis always spans at least [-5, 30]; beyond that it grows outward to
    # the next multiple of 10 enclosing every interval edge.
    lowest_edge = min(interval_edges)
    highest_edge = max(interval_edges)
    x_min = min(-5.0, nice_floor(lowest_edge, 10.0))
    x_max = max(30.0, nice_ceil(highest_edge, 10.0))
    write_tikz(records, x_min, x_max)

    OUT_VALUES.parent.mkdir(parents=True, exist_ok=True)
    with OUT_VALUES.open("w", encoding="utf-8", newline="") as fh:
        writer = csv.DictWriter(fh, fieldnames=["feature_source", *value_columns])
        writer.writeheader()
        writer.writerows(
            {
                "feature_source": record["feature_source"],
                **{column: f"{float(record[column]):.4f}" for column in value_columns},
            }
            for record in records
        )
# Run the build only when executed as a script, not when imported.
if __name__ == "__main__":
    main()