#!/usr/bin/env python3
"""Build the compact RQ2 selection-regret interval figure.
The output is TikZ rather than a raster/PDF plot so the manuscript does not
depend on a local matplotlib installation.
"""
from __future__ import annotations
import csv
import math
from pathlib import Path
# Repository root; this script is expected to live one directory below it
# (the generated TikZ header names it scripts/build_selection_regret_rq2_figure.py).
ROOT = Path(__file__).resolve().parents[1]
_RESULTS_DIR = ROOT / "artifacts" / "results"

# Input: selection-regret scope-sweep summary; rows are looked up by their
# ``label`` and ``scope`` columns.
IN_CSV = _RESULTS_DIR / "selection_regret_scope_sweep_20260505.csv"

# Outputs: the TikZ figure for the manuscript, and a CSV of the plotted values.
OUT_TIKZ = ROOT / "paper_outputs" / "figures" / "fig_selection_regret_rq2.tikz"
OUT_VALUES = _RESULTS_DIR / "selection_regret_rq2_figure_values.csv"

# Feature sources in top-to-bottom figure order; each one must be present in
# IN_CSV for every scope listed in SCOPES.
ROW_ORDER = [
    "FireWx-FM ref.",
    "Prithvi-WxC",
    "Aurora",
    "ClimaX",
    "StormCast",
    "DLWP",
    "FCN",
    "FengWu",
    "FuXi",
    "Pangu-Weather",
    "AlphaEarth",
]

# Internal scope name -> value stored in the CSV ``scope`` column.
SCOPES = {
    "global": "global",
    "fire_prone_top20": "top20",
}
def read_rows() -> dict[tuple[str, str], dict[str, str]]:
    """Load ``IN_CSV`` and index its rows by ``(label, scope)``.

    Every label in ``ROW_ORDER`` must appear exactly once for each scope value
    in ``SCOPES``. A missing or duplicated row would make the figure plot no
    value, or an arbitrary one, so the script aborts instead.

    Returns:
        Mapping from ``(row["label"], row["scope"])`` to the raw CSV row
        (all values are strings). Rows for labels/scopes that are not plotted
        are kept as well.

    Raises:
        SystemExit: if the CSV lacks a required column, a plotted
            ``(label, scope)`` pair is missing or occurs more than once, or a
            stale ``Pangu24`` label is present.
    """
    required_columns = ("label", "scope", "union_regret_mean", "union_regret_std")
    with IN_CSV.open("r", encoding="utf-8", newline="") as fh:
        reader = csv.DictReader(fh)
        # fieldnames is None for an empty file; treat that as "no columns".
        header = reader.fieldnames or []
        absent = [name for name in required_columns if name not in header]
        if absent:
            raise SystemExit(f"{IN_CSV} is missing required columns: {absent}")
        rows = list(reader)

    by_key: dict[tuple[str, str], dict[str, str]] = {}
    seen_twice: set[tuple[str, str]] = set()
    for row in rows:
        key = (row["label"], row["scope"])
        if key in by_key:
            seen_twice.add(key)
        by_key[key] = row  # last occurrence wins for rows that are not plotted

    plotted = [(label, scope) for label in ROW_ORDER for scope in SCOPES.values()]
    missing = [key for key in plotted if key not in by_key]
    if missing:
        raise SystemExit(f"Missing selection-regret rows: {missing}")
    # A duplicated plotted row means the CSV mixes runs; refuse to pick one silently.
    duplicated = [key for key in plotted if key in seen_twice]
    if duplicated:
        raise SystemExit(f"Duplicate selection-regret rows: {duplicated}")

    bad_labels = sorted({row["label"] for row in rows if "Pangu24" in row["label"]})
    if bad_labels:
        raise SystemExit(f"Stale Pangu24 labels found in final CSV: {bad_labels}")
    return by_key
def pct(row: dict[str, str], field: str) -> float:
    """Parse ``row[field]`` as a number and scale it by 100 (fraction -> percentage points)."""
    fraction = float(row[field])
    return 100.0 * fraction
# One-pass translation table for LaTeX's reserved text characters. A single
# str.translate pass never re-escapes the braces/backslashes introduced by an
# earlier substitution, which chained .replace() calls would.
_TEX_ESCAPES = str.maketrans(
    {
        "\\": r"\textbackslash{}",
        "&": r"\&",
        "%": r"\%",
        "$": r"\$",
        "#": r"\#",
        "_": r"\_",
        "{": r"\{",
        "}": r"\}",
        "~": r"\textasciitilde{}",
        "^": r"\textasciicircum{}",
    }
)


def escape_tex(text: str) -> str:
    """Escape LaTeX special characters so *text* renders literally in text mode.

    Previously only ``&`` and ``%`` were escaped, so labels containing e.g.
    ``_`` or ``#`` produced invalid TikZ node text.
    """
    return text.translate(_TEX_ESCAPES)
def nice_floor(value: float, step: float) -> float:
    """Round *value* down to a multiple of *step* (for positive, non-zero *step*)."""
    whole_steps = math.floor(value / step)
    return whole_steps * step
def nice_ceil(value: float, step: float) -> float:
    """Round *value* up to a multiple of *step* (for positive, non-zero *step*)."""
    whole_steps = math.ceil(value / step)
    return whole_steps * step
def _tick_values(x_min: float, x_max: float, step: int = 10) -> list[int]:
    """Return the integer multiples of *step* that lie within ``[x_min, x_max]``.

    The range starts at ``ceil(x_min / step) * step`` so the grid stays aligned
    to multiples of *step* even when ``x_min`` is not one (``x_min=-5`` gives
    ``0, 10, 20, ...``). Zero is therefore included whenever
    ``x_min <= 0 <= x_max``.
    """
    first = math.ceil(x_min / step) * step
    last = math.floor(x_max / step) * step
    return list(range(first, last + 1, step))


def write_tikz(records: list[dict[str, float | str]], x_min: float, x_max: float) -> None:
    """Write the interval figure to ``OUT_TIKZ`` as a single ``tikzpicture``.

    Each record becomes one labelled row, with the first record at the top.
    Every row carries two horizontal mean +/- std bars: the ``global`` scope
    (circle marker, drawn just below the row baseline) and the ``top20`` scope
    (square marker, drawn just above it). The row labelled ``FireWx-FM ref.``
    is highlighted.

    Args:
        records: Rows with keys ``feature_source``, ``global_mean_pp``,
            ``global_std_pp``, ``top20_mean_pp`` and ``top20_std_pp``.
        x_min: Data value mapped to the left edge of the plot area.
        x_max: Data value mapped to the right edge of the plot area.

    Raises:
        ValueError: if ``x_max <= x_min`` (the axis would have no width).
    """
    if x_max <= x_min:
        raise ValueError(f"x_max ({x_max}) must be greater than x_min ({x_min})")

    width = 5.55  # plot-area width; the picture uses x=1cm, y=1cm
    left = 2.45  # x of the plot area's left edge; row labels end just left of it
    y_step = 0.41  # vertical spacing between consecutive rows
    bottom = -0.35  # y of the horizontal axis line
    row_clearance = 0.6  # gap between the axis line and the lowest row
    # Position the first row so the lowest row always sits row_clearance above
    # the axis, whatever the number of rows (11 rows -> top = 4.35).
    top = bottom + row_clearance + max(len(records) - 1, 0) * y_step
    offset = 0.13  # how far each scope's bar sits above/below the row baseline
    x_span = x_max - x_min

    def x(value: float) -> float:
        """Map a data value to a TikZ x coordinate inside the plot area."""
        return left + (value - x_min) / x_span * width

    lines: list[str] = [
        r"% Auto-generated by scripts/build_selection_regret_rq2_figure.py.",
        r"\begin{tikzpicture}[x=1cm,y=1cm]",
        r"\footnotesize",
    ]

    # Vertical grid lines with tick labels; the zero line uses wfgray and a
    # thicker stroke than the other grid lines.
    for tick in _tick_values(x_min, x_max):
        xt = x(float(tick))
        color = "wfgray" if tick == 0 else "black!12"
        lw = "0.55pt" if tick == 0 else "0.35pt"
        lines.append(
            rf"\draw[{color}, line width={lw}] ({xt:.3f},{bottom:.3f}) -- ({xt:.3f},{top + 0.18:.3f});"
        )
        lines.append(rf"\node[anchor=north, font=\scriptsize, text=black!70] at ({xt:.3f},{bottom - 0.06:.3f}) {{{tick}}};")

    axis_y = bottom
    lines.append(
        rf"\draw[black!45, line width=0.4pt] ({x(x_min):.3f},{axis_y:.3f}) -- ({x(x_max):.3f},{axis_y:.3f});"
    )

    for idx, record in enumerate(records):
        y_base = top - idx * y_step
        label = str(record["feature_source"])
        label_tex = escape_tex(label)
        if label == "FireWx-FM ref.":
            # Emphasize the reference row's label.
            label_tex = rf"\textcolor{{wfblue}}{{\textbf{{{label_tex}}}}}"
        lines.append(rf"\node[anchor=east, font=\scriptsize, text=black!82] at ({left - 0.13:.3f},{y_base:.3f}) {{{label_tex}}};")
        for scope_key, color, y_offset, marker in [
            ("global", "wfslate", -offset, "circle"),
            ("top20", "wforange", offset, "square"),
        ]:
            mean = float(record[f"{scope_key}_mean_pp"])
            std = float(record[f"{scope_key}_std_pp"])
            y_val = y_base + y_offset
            x_lo = x(mean - std)
            x_hi = x(mean + std)
            x_mid = x(mean)
            # Horizontal mean +/- std bar with short vertical end caps.
            lines.append(rf"\draw[{color}, line width=0.72pt] ({x_lo:.3f},{y_val:.3f}) -- ({x_hi:.3f},{y_val:.3f});")
            lines.append(rf"\draw[{color}, line width=0.72pt] ({x_lo:.3f},{y_val - 0.035:.3f}) -- ({x_lo:.3f},{y_val + 0.035:.3f});")
            lines.append(rf"\draw[{color}, line width=0.72pt] ({x_hi:.3f},{y_val - 0.035:.3f}) -- ({x_hi:.3f},{y_val + 0.035:.3f});")
            # Marker at the mean.
            if marker == "circle":
                lines.append(rf"\filldraw[{color}] ({x_mid:.3f},{y_val:.3f}) circle[radius=0.045];")
            else:
                lines.append(rf"\filldraw[{color}] ({x_mid - 0.045:.3f},{y_val - 0.045:.3f}) rectangle ({x_mid + 0.045:.3f},{y_val + 0.045:.3f});")

    lines.append(r"\end{tikzpicture}")
    OUT_TIKZ.parent.mkdir(parents=True, exist_ok=True)
    OUT_TIKZ.write_text("\n".join(lines) + "\n", encoding="utf-8")
def main() -> None:
    """Build the RQ2 TikZ figure and write the exact plotted values to ``OUT_VALUES``.

    For every feature source in ``ROW_ORDER``, the union-regret mean and std
    are read for both scopes and converted to percentage points. The x-range
    covers every mean +/- std interval and always includes [-5, 30]. Any
    extension beyond that is rounded outward to a multiple of 10.
    """
    by_key = read_rows()
    # (key into SCOPES, prefix of the record / output-CSV columns)
    scope_columns = (("global", "global"), ("fire_prone_top20", "top20"))
    records: list[dict[str, float | str]] = []
    interval_edges: list[float] = []
    for label in ROW_ORDER:
        scope_rows = [(prefix, by_key[(label, SCOPES[scope])]) for scope, prefix in scope_columns]
        record: dict[str, float | str] = {"feature_source": label}
        for prefix, row in scope_rows:
            mean = pct(row, "union_regret_mean")
            std = pct(row, "union_regret_std")
            record[f"{prefix}_mean_pp"] = mean
            record[f"{prefix}_std_pp"] = std
            interval_edges += (mean - std, mean + std)
        records.append(record)

    lowest = nice_floor(min(interval_edges), 10.0)
    highest = nice_ceil(max(interval_edges), 10.0)
    x_min = min(-5.0, lowest)
    x_max = max(30.0, highest)
    write_tikz(records, x_min, x_max)

    value_columns = ["global_mean_pp", "global_std_pp", "top20_mean_pp", "top20_std_pp"]
    OUT_VALUES.parent.mkdir(parents=True, exist_ok=True)
    with OUT_VALUES.open("w", encoding="utf-8", newline="") as fh:
        writer = csv.DictWriter(fh, fieldnames=["feature_source", *value_columns])
        writer.writeheader()
        # Values are written with four decimals; the label is copied verbatim.
        writer.writerows(
            {
                "feature_source": record["feature_source"],
                **{column: f"{float(record[column]):.4f}" for column in value_columns},
            }
            for record in records
        )
if __name__ == "__main__":
    # Build the figure only when run as a script, not when imported.
    main()
|