| |
| """Build the RQ1 fire-prone contract-progression figure from summary JSON.""" |
|
|
| from __future__ import annotations |
|
|
| import json |
| from pathlib import Path |
|
|
| from simple_pdf import PdfCanvas, draw_axes |
|
|
|
|
# Repository root: two levels up from this script's resolved location.
ROOT = Path(__file__).resolve().parents[1]
# Summary JSON consumed by main() and the PDF it writes.
IN_JSON = ROOT.joinpath("artifacts", "results", "fireprone_contract_progression_summary.json")
OUT = ROOT.joinpath("paper_outputs", "figures", "fig_fireprone_contract_progression_compact.pdf")


# Bar order within every scope block.
# Each entry is (JSON model_tag, full display name, short tick label).
MODEL_ORDER = [
    ("reference", "FireWx-FM ref.", "Ref."),
    ("prithvi_wxc", "Prithvi-WxC", "WxC"),
    ("aurora", "Aurora", "Aurora"),
    ("climax", "ClimaX", "ClimaX"),
    ("stormcast", "StormCast", "Storm"),
    ("dlwp", "DLWP", "DLWP"),
    ("fcn", "FCN", "FCN"),
    ("fengwu", "FengWu", "FengWu"),
    ("fuxi", "FuXi", "FuXi"),
    ("pangu6", "Pangu-Weather", "Pangu-W"),
    ("alphaearth", "AlphaEarth", "Alpha"),
]


# Left-to-right order of the scope blocks as (JSON scope key, block heading).
# Built from an insertion-ordered mapping; the result is a list of 2-tuples.
SCOPE_ORDER = list(
    {
        "full_domain": "global",
        "train_fire_top05pct": "top 5%",
        "train_fire_top10pct": "top 10%",
        "train_fire_top20pct": "top 20%",
    }.items()
)
|
|
|
|
def dashed_vline(
    c: PdfCanvas,
    x: float,
    y_start: float,
    y_end: float,
    *,
    dash: float = 7.0,
    gap: float = 5.0,
    color: tuple[float, float, float] = (0.42, 0.44, 0.46),
    lw: float = 0.75,
) -> None:
    """Draw a dashed vertical line at ``x`` from ``y_start`` up to ``y_end``.

    The line is emitted as separate ``c.line`` segments of length ``dash``,
    separated by ``gap``. The final dash is truncated at ``y_end``. Nothing is
    drawn when ``y_start >= y_end``.

    Args:
        c: Canvas to draw on (only its ``line`` method is used).
        x: Horizontal position of the line.
        y_start: Lower y coordinate where the first dash begins.
        y_end: Upper y coordinate; no dash extends beyond it.
        dash: Length of each drawn dash (defaults to the original 7.0).
        gap: Blank space between dashes (defaults to the original 5.0).
        color: RGB stroke colour passed to ``c.line``.
        lw: Stroke width passed to ``c.line``.

    Raises:
        ValueError: if ``dash`` is not positive or ``gap`` is negative. The
            loop below would not advance (or dashes would overlap) otherwise.
    """
    if dash <= 0 or gap < 0:
        raise ValueError(f"dash must be > 0 and gap >= 0, got dash={dash!r}, gap={gap!r}")
    step = dash + gap
    y = y_start
    while y < y_end:
        c.line([(x, y), (x, min(y + dash, y_end))], color=color, lw=lw)
        y += step
|
|
|
|
def _stacked_segments(strict: float, tolerance: float, union: float) -> list[tuple[str, float]]:
    """Return ``(segment, height_value)`` pairs for one stacked bar, bottom to top.

    Segment boundaries are running maxima of 0 -> strict -> tolerance -> union.
    For monotone inputs (0 <= strict <= tolerance <= union) this equals the plain
    differences. For non-monotone inputs the bar top still lands at the largest
    value, instead of a clipped negative difference shifting later segments up.
    """
    segments: list[tuple[str, float]] = []
    level = 0.0
    for name, value in (("strict", strict), ("tolerance", tolerance), ("union", union)):
        top = max(level, value)
        segments.append((name, top - level))
        level = top
    return segments


def main() -> None:
    """Render the compact fire-prone contract-progression bar chart to ``OUT``.

    Reads ``IN_JSON``, which must hold a top-level ``"summary"`` list. Each row
    needs ``model_tag``, ``scope``, and ``strict_f1`` / ``tolerance_f1`` /
    ``union_f1`` dicts, each with a ``"mean"`` value (scaled by 100 to percent
    below).

    Draws one block per ``SCOPE_ORDER`` entry, with a dashed divider after the
    first block. Each block holds one stacked bar per ``MODEL_ORDER`` entry
    (strict -> tolerance -> union F1). A legend is added, then the PDF is saved.

    Raises:
        KeyError: if the summary lacks a row for any plotted (model_tag, scope) pair.
    """
    data = json.loads(IN_JSON.read_text(encoding="utf-8"))
    by_key = {(row["model_tag"], row["scope"]): row for row in data["summary"]}
    # Fail once with every missing pair instead of a bare tuple KeyError mid-drawing.
    missing = [
        (model_tag, scope)
        for scope, _ in SCOPE_ORDER
        for model_tag, _, _ in MODEL_ORDER
        if (model_tag, scope) not in by_key
    ]
    if missing:
        raise KeyError(f"{IN_JSON} is missing summary rows for (model_tag, scope): {missing}")

    c = PdfCanvas(width=1320, height=470)
    c.rect(0, 0, c.width, c.height, fill=(1, 1, 1))

    x0, y0, plot_w, plot_h, ymax = 72, 132, 1194, 268, 80.0
    draw_axes(c, x0, y0, plot_w, plot_h, ymax, [0, 20, 40, 60, 80])
    c.text(x0 - 38, y0 + plot_h + 8, "F1 (%)", size=8, color=(0.15, 0.15, 0.15), bold=True)

    colors = {
        "strict": (0.09, 0.22, 0.37),
        "tolerance": (0.31, 0.55, 0.80),
        "union": (0.75, 0.84, 0.94),
    }
    left_pad = 8.0  # offset of the first block from the y-axis
    block_gap = 10.0  # gap between consecutive fire-prone blocks
    fire_gap = 28.0  # wider gap (with dashed divider) after the first block
    n_scopes = len(SCOPE_ORDER)
    # The first inter-block gap is the fire gap; any remaining ones are block gaps.
    total_gap = fire_gap + (n_scopes - 2) * block_gap if n_scopes > 1 else 0.0
    # Subtract left_pad too, so the last block ends at the plot's right edge instead of 8 pt past it.
    block_w = (plot_w - left_pad - total_gap) / n_scopes
    bar_step = block_w / len(MODEL_ORDER)
    bar_w = min(18.0, bar_step * 0.80)

    scope_lefts: dict[str, float] = {}
    current_x = x0 + left_pad
    for scope_idx, (scope, scope_label) in enumerate(SCOPE_ORDER):
        if scope_idx == 1:
            # Divider centred in the fire gap that precedes this block.
            dashed_vline(c, current_x - fire_gap / 2.0, y0 - 6, y0 + plot_h + 16)
        scope_lefts[scope] = current_x
        c.text(current_x + block_w / 2.0, y0 + plot_h + 17, scope_label, size=15.0, align="center", bold=True)
        current_x += block_w + (fire_gap if scope_idx == 0 else block_gap)

    for scope, _scope_label in SCOPE_ORDER:
        block_x = scope_lefts[scope]
        for idx, (model_tag, _label, short) in enumerate(MODEL_ORDER):
            row = by_key[(model_tag, scope)]
            bx = block_x + idx * bar_step + (bar_step - bar_w) / 2.0
            base = y0
            for segment, value in _stacked_segments(
                row["strict_f1"]["mean"] * 100.0,
                row["tolerance_f1"]["mean"] * 100.0,
                row["union_f1"]["mean"] * 100.0,
            ):
                # NOTE: heights are not clipped, so F1 above ymax draws past the top of the axes.
                height = plot_h * value / ymax
                if height <= 0:
                    continue
                c.rect(bx, base, bar_w, height, fill=colors[segment], stroke=(1, 1, 1), lw=0.35)
                base += height
            c.text_rotated(bx + bar_w / 2.0 - 3.0, y0 - 76, short, angle_deg=-45.0, size=10.0, align="right")

    legend_x, legend_y = x0 + 18, y0 + plot_h - 26
    c.rect(legend_x - 13, legend_y - 12, 304, 23, fill=(0.98, 0.98, 0.96), stroke=(0.78, 0.80, 0.78), lw=0.45)
    for idx, (label, color) in enumerate([("Strict", colors["strict"]), ("Tolerance", colors["tolerance"]), ("Union", colors["union"])]):
        x = legend_x + idx * 98
        c.rect(x, legend_y - 3, 24, 9, fill=color, stroke=(1, 1, 1), lw=0.35)
        c.text(x + 31, legend_y - 1, label, size=8.0)

    OUT.parent.mkdir(parents=True, exist_ok=True)
    c.save(OUT)
|
|
|
|
# Build the figure only when run as a script, not when imported.
if __name__ == "__main__":
    main()
|
|