File size: 6,607 Bytes
80ef3b2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
#!/usr/bin/env python3
"""Build the compact RQ2 selection-regret interval figure.

The output is TikZ rather than a raster/PDF plot so the manuscript does not
depend on a local matplotlib installation.
"""

from __future__ import annotations

import csv
import math
from pathlib import Path


# Repository root: the parent of the directory that holds this script.
ROOT = Path(__file__).resolve().parents[1]
_RESULTS_DIR = ROOT / "artifacts" / "results"
_FIGURES_DIR = ROOT / "paper_outputs" / "figures"

# Input table: one selection-regret row per (label, scope) pair.
IN_CSV = _RESULTS_DIR / "selection_regret_scope_sweep_20260505.csv"
# Outputs: the TikZ figure source and the plotted values in percentage points.
OUT_TIKZ = _FIGURES_DIR / "fig_selection_regret_rq2.tikz"
OUT_VALUES = _RESULTS_DIR / "selection_regret_rq2_figure_values.csv"

# Feature sources in plotting order, top row first. Every entry must appear in
# the ``label`` column of IN_CSV for each scope listed in SCOPES.
ROW_ORDER = [
    "FireWx-FM ref.",
    "Prithvi-WxC",
    "Aurora",
    "ClimaX",
    "StormCast",
    "DLWP",
    "FCN",
    "FengWu",
    "FuXi",
    "Pangu-Weather",
    "AlphaEarth",
]

# Internal scope key -> value stored in the ``scope`` column of IN_CSV.
SCOPES = {
    "global": "global",
    "fire_prone_top20": "top20",
}


def read_rows() -> dict[tuple[str, str], dict[str, str]]:
    """Load ``IN_CSV`` and index its rows by ``(label, scope)``.

    Returns:
        Mapping from ``(row["label"], row["scope"])`` to the raw CSV row
        (all values as strings). If a key occurs more than once, the row
        read last wins.

    Raises:
        SystemExit: if some ``ROW_ORDER`` label has no row for one of the
            ``SCOPES`` values, or if any label contains ``"Pangu24"``.
    """
    with IN_CSV.open("r", encoding="utf-8", newline="") as handle:
        table = list(csv.DictReader(handle))

    indexed: dict[tuple[str, str], dict[str, str]] = {}
    for row in table:
        indexed[(row["label"], row["scope"])] = row

    # Every plotted label needs a row in every plotted scope.
    required = [(label, scope) for label in ROW_ORDER for scope in SCOPES.values()]
    missing = [key for key in required if key not in indexed]
    if missing:
        raise SystemExit(f"Missing selection-regret rows: {missing}")

    # Guard against an outdated CSV that still uses the old Pangu24 naming.
    stale = sorted({row["label"] for row in table if "Pangu24" in row["label"]})
    if stale:
        raise SystemExit(f"Stale Pangu24 labels found in final CSV: {stale}")
    return indexed


def pct(row: dict[str, str], field: str) -> float:
    """Parse ``row[field]`` as a fraction and return it in percent (x100)."""
    fraction = float(row[field])
    return 100.0 * fraction


# Replacement text for every character LaTeX treats specially in running text.
# Applied in one pass by str.translate, so the braces in the replacements for
# backslash/tilde/caret are never themselves re-escaped.
_TEX_ESCAPES = str.maketrans(
    {
        "\\": r"\textbackslash{}",
        "&": r"\&",
        "%": r"\%",
        "$": r"\$",
        "#": r"\#",
        "_": r"\_",
        "{": r"\{",
        "}": r"\}",
        "~": r"\textasciitilde{}",
        "^": r"\textasciicircum{}",
    }
)


def escape_tex(text: str) -> str:
    """Escape LaTeX special characters so *text* can be typeset literally.

    Handles ``\\ & % $ # _ { } ~ ^``. The previous version escaped only ``&``
    and ``%``, so a label containing e.g. ``_`` or ``#`` broke compilation of
    the generated TikZ.
    """
    return text.translate(_TEX_ESCAPES)


def nice_floor(value: float, step: float) -> float:
    """Snap *value* down onto the grid of integer multiples of *step*.

    For a positive *step* this is the largest multiple of *step* that is
    ``<= value`` (subject to floating-point rounding of ``value / step``).
    """
    whole_steps = math.floor(value / step)
    return whole_steps * step


def nice_ceil(value: float, step: float) -> float:
    """Snap *value* up onto the grid of integer multiples of *step*.

    For a positive *step* this is the smallest multiple of *step* that is
    ``>= value`` (subject to floating-point rounding of ``value / step``).
    """
    whole_steps = math.ceil(value / step)
    return whole_steps * step


def write_tikz(records: list[dict[str, float | str]], x_min: float, x_max: float) -> None:
    """Render the selection-regret interval chart as TikZ and write it to ``OUT_TIKZ``.

    Each record becomes one row, the first record at the top. A row shows two
    horizontal mean +/- std bars with end caps: the ``global`` scope in
    ``wfslate`` with a circle marker (drawn just below the row's baseline) and
    the ``top20`` scope in ``wforange`` with a square marker (just above it).
    Vertical gridlines are drawn at every multiple of 10 within
    ``[x_min, x_max]``, plus a heavier ``wfgray`` line at zero.

    Args:
        records: one dict per row with ``feature_source`` (display label) and
            ``global_mean_pp``, ``global_std_pp``, ``top20_mean_pp``,
            ``top20_std_pp`` (percentage points).
        x_min: data value mapped to the left edge of the plotting area.
        x_max: data value mapped to the right edge of the plotting area.

    Raises:
        ValueError: if ``x_max <= x_min``.

    NOTE(review): the colours ``wfgray``, ``wfblue``, ``wfslate`` and
    ``wforange`` are not defined here; presumably the including manuscript
    defines them — confirm.
    """
    if x_max <= x_min:
        raise ValueError(f"x_max ({x_max}) must be greater than x_min ({x_min})")

    width = 5.55  # horizontal extent (cm) of the [x_min, x_max] data range
    left = 2.45  # x coordinate (cm) of x_min; row labels sit just left of it
    y_step = 0.41  # vertical distance between consecutive row baselines
    bottom = -0.35  # y coordinate of the horizontal axis
    offset = 0.13  # vertical shift of each scope's bar away from the row baseline
    # Derive the first row's height from the row count so the last row always
    # sits 0.60 above the axis (11 rows -> 4.35, the previous fixed value);
    # a fixed top would push any extra rows below the axis.
    top = bottom + 0.60 + max(len(records) - 1, 0) * y_step
    x_span = x_max - x_min

    def x(value: float) -> float:
        # Linear map from data units (percentage points) to TikZ x coordinates.
        return left + (value - x_min) / x_span * width

    lines: list[str] = [
        r"% Auto-generated by scripts/build_selection_regret_rq2_figure.py.",
        r"\begin{tikzpicture}[x=1cm,y=1cm]",
        r"\footnotesize",
    ]

    # Gridlines at every multiple of 10 inside [x_min, x_max]. The range must
    # start on a multiple of 10: starting at int(x_min) (e.g. -5) would step
    # through -5, 5, 15, ... and leave no gridline other than 0.
    first_tick = math.ceil(x_min / 10) * 10
    last_tick = math.floor(x_max / 10) * 10
    ticks = list(range(first_tick, last_tick + 1, 10))
    if 0 not in ticks:
        ticks.append(0)  # the zero reference line is always drawn
    ticks = sorted(set(ticks))

    for tick in ticks:
        xt = x(float(tick))
        color = "wfgray" if tick == 0 else "black!12"
        lw = "0.55pt" if tick == 0 else "0.35pt"
        lines.append(
            rf"\draw[{color}, line width={lw}] ({xt:.3f},{bottom:.3f}) -- ({xt:.3f},{top + 0.18:.3f});"
        )
        lines.append(rf"\node[anchor=north, font=\scriptsize, text=black!70] at ({xt:.3f},{bottom - 0.06:.3f}) {{{tick}}};")

    axis_y = bottom
    lines.append(
        rf"\draw[black!45, line width=0.4pt] ({x(x_min):.3f},{axis_y:.3f}) -- ({x(x_max):.3f},{axis_y:.3f});"
    )

    for idx, record in enumerate(records):
        y_base = top - idx * y_step
        label = str(record["feature_source"])
        label_tex = escape_tex(label)
        if label == "FireWx-FM ref.":
            # Highlight the reference model's label.
            label_tex = rf"\textcolor{{wfblue}}{{\textbf{{{label_tex}}}}}"
        lines.append(rf"\node[anchor=east, font=\scriptsize, text=black!82] at ({left - 0.13:.3f},{y_base:.3f}) {{{label_tex}}};")

        for scope_key, color, y_offset, marker in [
            ("global", "wfslate", -offset, "circle"),
            ("top20", "wforange", offset, "square"),
        ]:
            mean = float(record[f"{scope_key}_mean_pp"])
            std = float(record[f"{scope_key}_std_pp"])
            y_val = y_base + y_offset
            x_lo = x(mean - std)
            x_hi = x(mean + std)
            x_mid = x(mean)
            # Horizontal +/- std bar followed by short vertical end caps.
            lines.append(rf"\draw[{color}, line width=0.72pt] ({x_lo:.3f},{y_val:.3f}) -- ({x_hi:.3f},{y_val:.3f});")
            lines.append(rf"\draw[{color}, line width=0.72pt] ({x_lo:.3f},{y_val - 0.035:.3f}) -- ({x_lo:.3f},{y_val + 0.035:.3f});")
            lines.append(rf"\draw[{color}, line width=0.72pt] ({x_hi:.3f},{y_val - 0.035:.3f}) -- ({x_hi:.3f},{y_val + 0.035:.3f});")
            if marker == "circle":
                lines.append(rf"\filldraw[{color}] ({x_mid:.3f},{y_val:.3f}) circle[radius=0.045];")
            else:
                lines.append(rf"\filldraw[{color}] ({x_mid - 0.045:.3f},{y_val - 0.045:.3f}) rectangle ({x_mid + 0.045:.3f},{y_val + 0.045:.3f});")

    lines.append(r"\end{tikzpicture}")
    OUT_TIKZ.parent.mkdir(parents=True, exist_ok=True)
    OUT_TIKZ.write_text("\n".join(lines) + "\n", encoding="utf-8")


def main() -> None:
    """Build the RQ2 selection-regret figure and its companion values CSV.

    Reads the scope-sweep rows via ``read_rows`` and converts the
    ``union_regret_mean`` / ``union_regret_std`` fractions to percentage points
    for the global and fire-prone top-20 scopes of every ``ROW_ORDER`` label.
    It then picks an x range that encloses every mean +/- std interval (always
    covering at least [-5, 30]), writes the TikZ figure with ``write_tikz``,
    and finally writes the plotted numbers, rounded to 4 decimals, to
    ``OUT_VALUES``.
    """
    rows_by_key = read_rows()
    global_scope = SCOPES["global"]
    top20_scope = SCOPES["fire_prone_top20"]

    records: list[dict[str, float | str]] = []
    edges: list[float] = []
    for label in ROW_ORDER:
        global_row = rows_by_key[(label, global_scope)]
        top20_row = rows_by_key[(label, top20_scope)]
        global_mean = pct(global_row, "union_regret_mean")
        global_std = pct(global_row, "union_regret_std")
        top20_mean = pct(top20_row, "union_regret_mean")
        top20_std = pct(top20_row, "union_regret_std")
        # Both ends of both error bars, so the axis range can enclose them all.
        edges += [
            global_mean - global_std,
            global_mean + global_std,
            top20_mean - top20_std,
            top20_mean + top20_std,
        ]
        records.append(
            dict(
                feature_source=label,
                global_mean_pp=global_mean,
                global_std_pp=global_std,
                top20_mean_pp=top20_mean,
                top20_std_pp=top20_std,
            )
        )

    # Round the data extent outward to multiples of 10, never narrower than [-5, 30].
    lowest = nice_floor(min(edges), 10.0)
    highest = nice_ceil(max(edges), 10.0)
    x_min = min(-5.0, lowest)
    x_max = max(30.0, highest)
    write_tikz(records, x_min, x_max)

    columns = ["feature_source", "global_mean_pp", "global_std_pp", "top20_mean_pp", "top20_std_pp"]
    OUT_VALUES.parent.mkdir(parents=True, exist_ok=True)
    with OUT_VALUES.open("w", encoding="utf-8", newline="") as out_fh:
        writer = csv.DictWriter(out_fh, fieldnames=columns)
        writer.writeheader()
        # The label is written verbatim; every numeric column uses 4 decimals.
        writer.writerows(
            {
                column: record[column] if column == "feature_source" else f"{float(record[column]):.4f}"
                for column in columns
            }
            for record in records
        )


if __name__ == "__main__":
    # Build the figure only when run as a script, not when imported.
    main()