#!/usr/bin/env python3
"""Build the compact RQ2 selection-regret interval figure.

The output is TikZ rather than a raster/PDF plot so the manuscript does not
depend on a local matplotlib installation.
"""

from __future__ import annotations

import csv
import math
from pathlib import Path

ROOT = Path(__file__).resolve().parents[1]
IN_CSV = ROOT / "artifacts" / "results" / "selection_regret_scope_sweep_20260505.csv"
OUT_TIKZ = ROOT / "paper_outputs" / "figures" / "fig_selection_regret_rq2.tikz"
OUT_VALUES = ROOT / "artifacts" / "results" / "selection_regret_rq2_figure_values.csv"

# Top-to-bottom order of the rows drawn in the figure.
ROW_ORDER = [
    "FireWx-FM ref.",
    "Prithvi-WxC",
    "Aurora",
    "ClimaX",
    "StormCast",
    "DLWP",
    "FCN",
    "FengWu",
    "FuXi",
    "Pangu-Weather",
    "AlphaEarth",
]

# Values are the strings matched against the "scope" column of IN_CSV.
SCOPES = {
    "global": "global",
    "fire_prone_top20": "top20",
}


def read_rows() -> dict[tuple[str, str], dict[str, str]]:
    """Load IN_CSV and index its rows by ``(label, scope)``.

    Returns:
        Mapping from ``(label, scope)`` to the raw CSV row. When a key occurs
        more than once, the last row in the file wins.

    Raises:
        SystemExit: if any ``ROW_ORDER`` x ``SCOPES`` combination is absent,
            or if any label in the file contains ``"Pangu24"``.
    """
    with IN_CSV.open("r", encoding="utf-8", newline="") as fh:
        rows = list(csv.DictReader(fh))
    by_key = {(row["label"], row["scope"]): row for row in rows}

    missing = [
        (label, scope)
        for label in ROW_ORDER
        for scope in SCOPES.values()
        if (label, scope) not in by_key
    ]
    if missing:
        raise SystemExit(f"Missing selection-regret rows: {missing}")

    bad_labels = sorted({row["label"] for row in rows if "Pangu24" in row["label"]})
    if bad_labels:
        raise SystemExit(f"Stale Pangu24 labels found in final CSV: {bad_labels}")
    return by_key


def pct(row: dict[str, str], field: str) -> float:
    """Return ``row[field]`` parsed as a float and multiplied by 100."""
    return float(row[field]) * 100.0


def escape_tex(text: str) -> str:
    """Backslash-escape ``&`` and ``%`` in *text*; nothing else is escaped."""
    return text.replace("&", r"\&").replace("%", r"\%")


def nice_floor(value: float, step: float) -> float:
    """Round *value* down to a multiple of *step*."""
    return math.floor(value / step) * step


def nice_ceil(value: float, step: float) -> float:
    """Round *value* up to a multiple of *step*."""
    return math.ceil(value / step) * step


def tick_values(x_min: float, x_max: float, step: int = 10) -> list[int]:
    """Return the sorted axis ticks: multiples of *step* in ``[x_min, x_max]``.

    Zero is always included because it is drawn as the emphasised reference
    line. The first tick is snapped up to a multiple of *step*; starting the
    range at ``int(x_min)`` would miss every multiple whenever ``x_min``
    itself is not one (e.g. ``x_min == -5`` would yield -5, 5, 15, 25).

    Raises:
        ValueError: if *step* is not positive.
    """
    if step <= 0:
        raise ValueError(f"step must be positive, got {step}")
    first = math.ceil(x_min / step) * step
    last = math.floor(x_max / step) * step
    ticks = set(range(first, last + 1, step))
    ticks.add(0)
    return sorted(ticks)


def build_tikz_lines(
    records: list[dict[str, float | str]], x_min: float, x_max: float
) -> list[str]:
    """Return the TikZ source lines of the interval figure.

    Args:
        records: one dict per row, top to bottom, holding ``feature_source``
            plus ``{global,top20}_{mean,std}_pp`` values.
        x_min: data value mapped to the left edge of the plot area.
        x_max: data value mapped to the right edge of the plot area.

    Raises:
        ValueError: if ``x_max`` is not strictly greater than ``x_min``.
    """
    if not x_max > x_min:
        raise ValueError(f"x_max ({x_max}) must be greater than x_min ({x_min})")

    # Layout constants, in the picture's cm units (x=1cm, y=1cm).
    width = 5.55
    left = 2.45
    y_step = 0.41
    top = 4.35
    bottom = -0.35
    offset = 0.13
    cap = 0.035  # half-height of the interval end caps
    half = 0.045  # marker radius / half side length
    x_span = x_max - x_min

    def to_x(value: float) -> float:
        """Map a data value onto the horizontal picture coordinate."""
        return left + (value - x_min) / x_span * width

    lines: list[str] = [
        r"% Auto-generated by scripts/build_selection_regret_rq2_figure.py.",
        r"\begin{tikzpicture}[x=1cm,y=1cm]",
        r"\footnotesize",
    ]

    # Vertical grid lines with their labels; the zero line is emphasised.
    for tick in tick_values(x_min, x_max):
        xt = to_x(float(tick))
        color = "wfgray" if tick == 0 else "black!12"
        lw = "0.55pt" if tick == 0 else "0.35pt"
        lines.append(
            rf"\draw[{color}, line width={lw}] ({xt:.3f},{bottom:.3f}) -- ({xt:.3f},{top + 0.18:.3f});"
        )
        lines.append(
            rf"\node[anchor=north, font=\scriptsize, text=black!70] at ({xt:.3f},{bottom - 0.06:.3f}) {{{tick}}};"
        )

    # Horizontal axis along the bottom of the plot area.
    axis_y = bottom
    lines.append(
        rf"\draw[black!45, line width=0.4pt] ({to_x(x_min):.3f},{axis_y:.3f}) -- ({to_x(x_max):.3f},{axis_y:.3f});"
    )

    # (record key prefix, colour, vertical offset from the row centre, marker)
    series = [
        ("global", "wfslate", -offset, "circle"),
        ("top20", "wforange", offset, "square"),
    ]
    for idx, record in enumerate(records):
        y_base = top - idx * y_step
        label = str(record["feature_source"])
        label_tex = escape_tex(label)
        if label == "FireWx-FM ref.":
            # The reference row is set in bold blue.
            label_tex = rf"\textcolor{{wfblue}}{{\textbf{{{label_tex}}}}}"
        lines.append(
            rf"\node[anchor=east, font=\scriptsize, text=black!82] at ({left - 0.13:.3f},{y_base:.3f}) {{{label_tex}}};"
        )
        for scope_key, color, y_offset, marker in series:
            mean = float(record[f"{scope_key}_mean_pp"])
            std = float(record[f"{scope_key}_std_pp"])
            y_val = y_base + y_offset
            x_lo = to_x(mean - std)
            x_hi = to_x(mean + std)
            x_mid = to_x(mean)
            # mean +/- std interval with an end cap at each side
            lines.append(
                rf"\draw[{color}, line width=0.72pt] ({x_lo:.3f},{y_val:.3f}) -- ({x_hi:.3f},{y_val:.3f});"
            )
            lines.append(
                rf"\draw[{color}, line width=0.72pt] ({x_lo:.3f},{y_val - cap:.3f}) -- ({x_lo:.3f},{y_val + cap:.3f});"
            )
            lines.append(
                rf"\draw[{color}, line width=0.72pt] ({x_hi:.3f},{y_val - cap:.3f}) -- ({x_hi:.3f},{y_val + cap:.3f});"
            )
            if marker == "circle":
                lines.append(
                    rf"\filldraw[{color}] ({x_mid:.3f},{y_val:.3f}) circle[radius=0.045];"
                )
            else:
                lines.append(
                    rf"\filldraw[{color}] ({x_mid - half:.3f},{y_val - half:.3f}) rectangle ({x_mid + half:.3f},{y_val + half:.3f});"
                )

    lines.append(r"\end{tikzpicture}")
    return lines


def write_tikz(records: list[dict[str, float | str]], x_min: float, x_max: float) -> None:
    """Render the figure and write it to OUT_TIKZ, creating parent directories.

    Raises:
        ValueError: if ``x_max`` is not strictly greater than ``x_min``.
    """
    lines = build_tikz_lines(records, x_min, x_max)
    OUT_TIKZ.parent.mkdir(parents=True, exist_ok=True)
    OUT_TIKZ.write_text("\n".join(lines) + "\n", encoding="utf-8")


def main() -> None:
    """Read the sweep CSV, write the TikZ figure and the plotted-values CSV."""
    by_key = read_rows()
    records: list[dict[str, float | str]] = []
    interval_edges: list[float] = []
    for label in ROW_ORDER:
        g = by_key[(label, SCOPES["global"])]
        f = by_key[(label, SCOPES["fire_prone_top20"])]
        g_mean = pct(g, "union_regret_mean")
        g_std = pct(g, "union_regret_std")
        f_mean = pct(f, "union_regret_mean")
        f_std = pct(f, "union_regret_std")
        interval_edges.extend(
            [g_mean - g_std, g_mean + g_std, f_mean - f_std, f_mean + f_std]
        )
        records.append(
            {
                "feature_source": label,
                "global_mean_pp": g_mean,
                "global_std_pp": g_std,
                "top20_mean_pp": f_mean,
                "top20_std_pp": f_std,
            }
        )

    # The axis always spans at least [-5, 30] and widens in steps of 10.
    x_min = min(-5.0, nice_floor(min(interval_edges), 10.0))
    x_max = max(30.0, nice_ceil(max(interval_edges), 10.0))
    write_tikz(records, x_min, x_max)

    OUT_VALUES.parent.mkdir(parents=True, exist_ok=True)
    with OUT_VALUES.open("w", encoding="utf-8", newline="") as fh:
        fieldnames = [
            "feature_source",
            "global_mean_pp",
            "global_std_pp",
            "top20_mean_pp",
            "top20_std_pp",
        ]
        writer = csv.DictWriter(fh, fieldnames=fieldnames)
        writer.writeheader()
        for record in records:
            writer.writerow(
                {
                    "feature_source": record["feature_source"],
                    "global_mean_pp": f"{float(record['global_mean_pp']):.4f}",
                    "global_std_pp": f"{float(record['global_std_pp']):.4f}",
                    "top20_mean_pp": f"{float(record['top20_mean_pp']):.4f}",
                    "top20_std_pp": f"{float(record['top20_std_pp']):.4f}",
                }
            )


if __name__ == "__main__":
    main()