| |
| """Build the compact RQ2 selection-regret interval figure. |
| |
| The output is TikZ rather than a raster/PDF plot so the manuscript does not |
| depend on a local matplotlib installation. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import csv |
| import math |
| from pathlib import Path |
|
|
|
|
# Repository root: the directory one level above the folder holding this script.
ROOT = Path(__file__).resolve().parents[1]

# Input sweep results, and the two artefacts this script (re)generates.
IN_CSV = ROOT.joinpath(
    "artifacts", "results", "selection_regret_scope_sweep_20260505.csv"
)
OUT_TIKZ = ROOT.joinpath(
    "paper_outputs", "figures", "fig_selection_regret_rq2.tikz"
)
OUT_VALUES = ROOT.joinpath(
    "artifacts", "results", "selection_regret_rq2_figure_values.csv"
)

# Feature-source labels in the order their rows are emitted (first entry is
# drawn at the top of the figure).  Every label must be present in IN_CSV.
ROW_ORDER = [
    "FireWx-FM ref.",
    "Prithvi-WxC",
    "Aurora",
    "ClimaX",
    "StormCast",
    "DLWP",
    "FCN",
    "FengWu",
    "FuXi",
    "Pangu-Weather",
    "AlphaEarth",
]

# Scope name used by this script -> value expected in the CSV ``scope`` column.
SCOPES = {
    "global": "global",
    "fire_prone_top20": "top20",
}
|
|
|
|
def read_rows() -> dict[tuple[str, str], dict[str, str]]:
    """Load ``IN_CSV`` and index its rows by ``(label, scope)``.

    Returns:
        A mapping from ``(label, scope)`` to the raw CSV row.  If the file
        contains several rows with the same key, the last one wins.

    Raises:
        SystemExit: If any ``ROW_ORDER`` label lacks a row for one of the
            ``SCOPES`` values, or if any label still contains ``Pangu24``.
    """
    with IN_CSV.open("r", encoding="utf-8", newline="") as handle:
        rows = list(csv.DictReader(handle))

    by_key: dict[tuple[str, str], dict[str, str]] = {}
    for row in rows:
        by_key[(row["label"], row["scope"])] = row

    # Every figure row needs one CSV entry per scope; report all gaps at once.
    expected = [(label, scope) for label in ROW_ORDER for scope in SCOPES.values()]
    missing = [key for key in expected if key not in by_key]
    if missing:
        raise SystemExit(f"Missing selection-regret rows: {missing}")

    # Refuse to build from a CSV that still carries the old Pangu24 naming.
    stale = {row["label"] for row in rows if "Pangu24" in row["label"]}
    bad_labels = sorted(stale)
    if bad_labels:
        raise SystemExit(f"Stale Pangu24 labels found in final CSV: {bad_labels}")

    return by_key
|
|
|
|
def pct(row: dict[str, str], field: str) -> float:
    """Return ``row[field]`` parsed as a float and multiplied by 100.

    Raises:
        KeyError: If ``field`` is not a key of ``row``.
        ValueError: If the stored text cannot be parsed as a float.
    """
    fraction = float(row[field])
    return fraction * 100.0
|
|
|
|
# One-pass translation table for the ten characters LaTeX treats specially.
# A single pass matters: the braces introduced by a replacement (for example
# the ones closing the backslash macro) must not be escaped a second time.
_TEX_ESCAPES = str.maketrans(
    {
        "\\": r"\textbackslash{}",
        "&": r"\&",
        "%": r"\%",
        "$": r"\$",
        "#": r"\#",
        "_": r"\_",
        "{": r"\{",
        "}": r"\}",
        "~": r"\textasciitilde{}",
        "^": r"\textasciicircum{}",
    }
)


def escape_tex(text: str) -> str:
    """Return *text* with LaTeX special characters escaped for use as node text.

    Ampersand and percent are escaped exactly as before; the remaining
    specials (dollar, hash, underscore, braces, tilde, caret, backslash) are
    now escaped as well, so a label containing them no longer breaks the
    generated TikZ.  Text without special characters is returned unchanged.
    """
    return text.translate(_TEX_ESCAPES)
|
|
|
|
def nice_floor(value: float, step: float) -> float:
    """Round *value* down to the nearest multiple of *step*."""
    whole_steps = math.floor(value / step)
    return whole_steps * step
|
|
|
|
def nice_ceil(value: float, step: float) -> float:
    """Round *value* up to the nearest multiple of *step*."""
    whole_steps = math.ceil(value / step)
    return whole_steps * step
|
|
|
|
def write_tikz(records: list[dict[str, float | str]], x_min: float, x_max: float) -> None:
    """Render the interval figure as TikZ source and write it to ``OUT_TIKZ``.

    Each record becomes one labelled row holding two horizontal mean +/- std
    intervals: the ``global_*`` values (circle marker) and the ``top20_*``
    values (square marker).  Vertical grid lines with numeric labels are drawn
    at every multiple of 10 inside ``[x_min, x_max]``, and always at 0.

    Args:
        records: Rows in drawing order, first row at the top.  Each needs the
            keys ``feature_source``, ``global_mean_pp``, ``global_std_pp``,
            ``top20_mean_pp`` and ``top20_std_pp``.
        x_min: Data value mapped to the left edge of the plot area.
        x_max: Data value mapped to the right edge of the plot area.

    Raises:
        ValueError: If ``x_max`` is not strictly greater than ``x_min``.

    Note:
        The colours ``wfgray``, ``wfblue``, ``wfslate`` and ``wforange`` are
        referenced but not defined in the generated file; presumably the
        including document defines them.
    """
    # Layout in centimetres (the picture is opened with x=1cm,y=1cm).
    width = 5.55  # horizontal extent of the data area
    left = 2.45  # x coordinate of the data area's left edge
    y_step = 0.41  # vertical distance between consecutive rows
    top = 4.35  # y coordinate of the first row
    bottom = -0.35  # y coordinate of the horizontal axis
    offset = 0.13  # vertical shift separating the two scopes within a row
    tick_step = 10

    x_span = x_max - x_min
    if x_span <= 0:
        raise ValueError(f"x_max ({x_max}) must be greater than x_min ({x_min})")

    def x(value: float) -> float:
        """Map a data value to its horizontal picture coordinate."""
        return left + (value - x_min) / x_span * width

    lines: list[str] = [
        r"% Auto-generated by scripts/build_selection_regret_rq2_figure.py.",
        r"\begin{tikzpicture}[x=1cm,y=1cm]",
        r"\footnotesize",
    ]

    # Align the tick range to multiples of tick_step.  Starting the range at
    # int(x_min) only works when x_min is itself such a multiple; with the
    # -5 lower limit used when all intervals are non-negative it produced no
    # grid line except the one at zero.
    first_tick = math.ceil(x_min / tick_step) * tick_step
    last_tick = math.floor(x_max / tick_step) * tick_step
    ticks = set(range(first_tick, last_tick + 1, tick_step))
    ticks.add(0)  # the zero reference line is always drawn

    for tick in sorted(ticks):
        xt = x(float(tick))
        # The zero line is emphasised; the other grid lines stay faint.
        color = "wfgray" if tick == 0 else "black!12"
        lw = "0.55pt" if tick == 0 else "0.35pt"
        lines.append(
            rf"\draw[{color}, line width={lw}] ({xt:.3f},{bottom:.3f}) -- ({xt:.3f},{top + 0.18:.3f});"
        )
        lines.append(
            rf"\node[anchor=north, font=\scriptsize, text=black!70] at ({xt:.3f},{bottom - 0.06:.3f}) {{{tick}}};"
        )

    axis_y = bottom
    lines.append(
        rf"\draw[black!45, line width=0.4pt] ({x(x_min):.3f},{axis_y:.3f}) -- ({x(x_max):.3f},{axis_y:.3f});"
    )

    # (record key prefix, colour, vertical shift inside the row, marker shape)
    series = [
        ("global", "wfslate", -offset, "circle"),
        ("top20", "wforange", offset, "square"),
    ]

    for idx, record in enumerate(records):
        y_base = top - idx * y_step
        label = str(record["feature_source"])
        label_tex = escape_tex(label)
        if label == "FireWx-FM ref.":
            # The reference row's label is set in bold blue.
            label_tex = rf"\textcolor{{wfblue}}{{\textbf{{{label_tex}}}}}"
        lines.append(
            rf"\node[anchor=east, font=\scriptsize, text=black!82] at ({left - 0.13:.3f},{y_base:.3f}) {{{label_tex}}};"
        )

        for scope_key, color, y_offset, marker in series:
            mean = float(record[f"{scope_key}_mean_pp"])
            std = float(record[f"{scope_key}_std_pp"])
            y_val = y_base + y_offset
            x_lo = x(mean - std)
            x_hi = x(mean + std)
            x_mid = x(mean)
            # Interval bar followed by its two end caps.
            lines.append(
                rf"\draw[{color}, line width=0.72pt] ({x_lo:.3f},{y_val:.3f}) -- ({x_hi:.3f},{y_val:.3f});"
            )
            lines.append(
                rf"\draw[{color}, line width=0.72pt] ({x_lo:.3f},{y_val - 0.035:.3f}) -- ({x_lo:.3f},{y_val + 0.035:.3f});"
            )
            lines.append(
                rf"\draw[{color}, line width=0.72pt] ({x_hi:.3f},{y_val - 0.035:.3f}) -- ({x_hi:.3f},{y_val + 0.035:.3f});"
            )
            # Marker at the mean.
            if marker == "circle":
                lines.append(rf"\filldraw[{color}] ({x_mid:.3f},{y_val:.3f}) circle[radius=0.045];")
            else:
                lines.append(
                    rf"\filldraw[{color}] ({x_mid - 0.045:.3f},{y_val - 0.045:.3f}) rectangle ({x_mid + 0.045:.3f},{y_val + 0.045:.3f});"
                )

    lines.append(r"\end{tikzpicture}")
    OUT_TIKZ.parent.mkdir(parents=True, exist_ok=True)
    OUT_TIKZ.write_text("\n".join(lines) + "\n", encoding="utf-8")
|
|
|
|
def main() -> None:
    """Build the figure records, then write the TikZ file and the values CSV.

    For every label in ``ROW_ORDER`` the ``union_regret_mean`` and
    ``union_regret_std`` columns of both scopes are scaled by 100.  The x
    range covers every mean +/- std interval, rounded outward to multiples of
    10, and never narrower than -5 to 30.  The plotted numbers are also
    written to ``OUT_VALUES`` with four decimals.
    """
    table = read_rows()

    # (column prefix in the output record, scope value used as CSV key)
    scope_columns = (
        ("global", SCOPES["global"]),
        ("top20", SCOPES["fire_prone_top20"]),
    )

    records: list[dict[str, float | str]] = []
    interval_edges: list[float] = []
    for label in ROW_ORDER:
        record: dict[str, float | str] = {"feature_source": label}
        for prefix, scope in scope_columns:
            row = table[(label, scope)]
            mean_pp = pct(row, "union_regret_mean")
            std_pp = pct(row, "union_regret_std")
            record[f"{prefix}_mean_pp"] = mean_pp
            record[f"{prefix}_std_pp"] = std_pp
            interval_edges.extend([mean_pp - std_pp, mean_pp + std_pp])
        records.append(record)

    # Round the data range outward, but keep at least the -5..30 window.
    lowest = nice_floor(min(interval_edges), 10.0)
    x_min = min(-5.0, lowest)
    highest = nice_ceil(max(interval_edges), 10.0)
    x_max = max(30.0, highest)
    write_tikz(records, x_min, x_max)

    value_columns = ["global_mean_pp", "global_std_pp", "top20_mean_pp", "top20_std_pp"]
    OUT_VALUES.parent.mkdir(parents=True, exist_ok=True)
    with OUT_VALUES.open("w", encoding="utf-8", newline="") as fh:
        writer = csv.DictWriter(fh, fieldnames=["feature_source", *value_columns])
        writer.writeheader()
        for record in records:
            formatted = {
                column: f"{float(record[column]):.4f}" for column in value_columns
            }
            writer.writerow({"feature_source": record["feature_source"], **formatted})
|
|
|
|
# Build the figure only when executed as a script, never on import.
if __name__ == "__main__":
    main()
|
|