File size: 4,208 Bytes
80ef3b2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
#!/usr/bin/env python3
"""Build the RQ1 fire-prone contract-progression figure from summary JSON."""

from __future__ import annotations

import json
from pathlib import Path

from simple_pdf import PdfCanvas, draw_axes


# Project root: the directory one level above the folder holding this script.
ROOT = Path(__file__).resolve().parents[1]
# Input summary produced by the evaluation pipeline.
IN_JSON = ROOT.joinpath(
    "artifacts", "results", "fireprone_contract_progression_summary.json"
)
# Destination of the rendered figure.
OUT = ROOT.joinpath(
    "paper_outputs", "figures", "fig_fireprone_contract_progression_compact.pdf"
)

# Left-to-right bar order inside every scope block.
# Each entry: (model_tag as it appears in the summary rows,
#              display name,
#              short label drawn under the bar).
MODEL_ORDER = [
    ("reference", "FireWx-FM ref.", "Ref."),
    ("prithvi_wxc", "Prithvi-WxC", "WxC"),
    ("aurora", "Aurora", "Aurora"),
    ("climax", "ClimaX", "ClimaX"),
    ("stormcast", "StormCast", "Storm"),
    ("dlwp", "DLWP", "DLWP"),
    ("fcn", "FCN", "FCN"),
    ("fengwu", "FengWu", "FengWu"),
    ("fuxi", "FuXi", "FuXi"),
    ("pangu6", "Pangu-Weather", "Pangu-W"),
    ("alphaearth", "AlphaEarth", "Alpha"),
]

# Left-to-right order of the scope blocks.
# Each entry: (scope key as it appears in the summary rows, block header label).
SCOPE_ORDER = [
    ("full_domain", "global"),
    ("train_fire_top05pct", "top 5%"),
    ("train_fire_top10pct", "top 10%"),
    ("train_fire_top20pct", "top 20%"),
]


def dashed_vline(c: PdfCanvas, x: float, y_start: float, y_end: float) -> None:
    """Draw a grey dashed vertical line at ``x`` from ``y_start`` up to ``y_end``.

    Dashes are 7 units long with 5-unit gaps; the final dash is cut off at
    ``y_end``. Nothing is drawn when ``y_start >= y_end``.
    """
    on_len = 7.0
    period = on_len + 5.0  # dash plus gap
    seg_start = y_start
    while seg_start < y_end:
        seg_end = min(seg_start + on_len, y_end)
        c.line([(x, seg_start), (x, seg_end)], color=(0.42, 0.44, 0.46), lw=0.75)
        seg_start += period


def _stack_segments(strict: float, tolerance: float, union: float) -> list[tuple[str, float]]:
    """Return ``(segment, size)`` pairs for one stacked F1 bar, bottom to top.

    The three contract levels are treated as cumulative tops (presumably
    strict <= tolerance <= union). Each segment is the increment over the
    highest level already stacked, clamped at zero. The bar's total height
    therefore always equals ``max(0, strict, tolerance, union)``, even when the
    summary means are not monotone. Differencing consecutive levels and
    clamping each one instead would over-stack in that case.
    """
    segments: list[tuple[str, float]] = []
    top = 0.0
    for segment, value in (("strict", strict), ("tolerance", tolerance), ("union", union)):
        segments.append((segment, max(0.0, value - top)))
        top = max(top, value)
    return segments


def main() -> None:
    """Render the stacked strict/tolerance/union F1 bar chart and save it to ``OUT``.

    Reads ``IN_JSON``. Its ``"summary"`` list is indexed by each row's
    ``("model_tag", "scope")`` pair. Each row's ``strict_f1`` / ``tolerance_f1``
    / ``union_f1`` entries provide a ``"mean"`` that is multiplied by 100 and
    plotted on an "F1 (%)" axis spanning 0..80.

    The layout has one block per entry of ``SCOPE_ORDER`` and one bar per entry
    of ``MODEL_ORDER`` inside each block. A dashed separator sits between the
    first (global) block and the rest. Bars taller than the axis maximum are
    not clipped.

    Raises:
        KeyError: if a ``(model_tag, scope)`` combination required by
            ``MODEL_ORDER`` x ``SCOPE_ORDER`` has no summary row.
    """
    data = json.loads(IN_JSON.read_text(encoding="utf-8"))
    by_key = {(row["model_tag"], row["scope"]): row for row in data["summary"]}

    c = PdfCanvas(width=1320, height=470)
    c.rect(0, 0, c.width, c.height, fill=(1, 1, 1))

    x0, y0, plot_w, plot_h, ymax = 72, 132, 1194, 268, 80.0
    draw_axes(c, x0, y0, plot_w, plot_h, ymax, [0, 20, 40, 60, 80])
    c.text(x0 - 38, y0 + plot_h + 8, "F1 (%)", size=8, color=(0.15, 0.15, 0.15), bold=True)

    colors = {
        "strict": (0.09, 0.22, 0.37),
        "tolerance": (0.31, 0.55, 0.80),
        "union": (0.75, 0.84, 0.94),
    }
    inner_pad = 8.0  # inset of the first/last block from the axes' left/right edges
    block_gap = 10.0  # gap between adjacent fire-prone scope blocks
    fire_gap = 28.0  # wider gap separating the first (global) block from the rest
    n_scopes = len(SCOPE_ORDER)
    # Gap that follows each block except the last. The same list drives both the
    # width budget and the x-advance below, so the blocks exactly fill the padded
    # axes width for any number of scopes (and never run past x0 + plot_w).
    gaps_after = [fire_gap if i == 0 else block_gap for i in range(n_scopes - 1)]
    block_w = (plot_w - 2 * inner_pad - sum(gaps_after)) / n_scopes
    bar_step = block_w / len(MODEL_ORDER)
    bar_w = min(18.0, bar_step * 0.80)

    scope_lefts: dict[str, float] = {}
    current_x = x0 + inner_pad
    for scope_idx, (scope, scope_label) in enumerate(SCOPE_ORDER):
        if scope_idx == 1:
            # Separator centred in the gap between the global and fire-prone blocks.
            dashed_vline(c, current_x - fire_gap / 2.0, y0 - 6, y0 + plot_h + 16)
        scope_lefts[scope] = current_x
        c.text(current_x + block_w / 2.0, y0 + plot_h + 17, scope_label, size=15.0, align="center", bold=True)
        if scope_idx < n_scopes - 1:
            current_x += block_w + gaps_after[scope_idx]

    for scope, _scope_label in SCOPE_ORDER:
        block_x = scope_lefts[scope]
        for idx, (model_tag, _label, short) in enumerate(MODEL_ORDER):
            row = by_key.get((model_tag, scope))
            if row is None:
                raise KeyError(f"{IN_JSON}: no summary row for model_tag={model_tag!r}, scope={scope!r}")
            strict = row["strict_f1"]["mean"] * 100.0
            tolerance = row["tolerance_f1"]["mean"] * 100.0
            union = row["union_f1"]["mean"] * 100.0
            # Centre the bar inside its slot of width bar_step.
            bx = block_x + idx * bar_step + (bar_step - bar_w) / 2.0
            base = y0
            for segment, value in _stack_segments(strict, tolerance, union):
                height = plot_h * value / ymax
                if height <= 0:
                    continue
                c.rect(bx, base, bar_w, height, fill=colors[segment], stroke=(1, 1, 1), lw=0.35)
                base += height
            c.text_rotated(bx + bar_w / 2.0 - 3.0, y0 - 76, short, angle_deg=-45.0, size=10.0, align="right")

    legend_x, legend_y = x0 + 18, y0 + plot_h - 26
    c.rect(legend_x - 13, legend_y - 12, 304, 23, fill=(0.98, 0.98, 0.96), stroke=(0.78, 0.80, 0.78), lw=0.45)
    for idx, (label, color) in enumerate([("Strict", colors["strict"]), ("Tolerance", colors["tolerance"]), ("Union", colors["union"])]):
        x = legend_x + idx * 98
        c.rect(x, legend_y - 3, 24, 9, fill=color, stroke=(1, 1, 1), lw=0.35)
        c.text(x + 31, legend_y - 1, label, size=8.0)

    OUT.parent.mkdir(parents=True, exist_ok=True)
    c.save(OUT)


if __name__ == "__main__":
    # Build the figure only when run as a script, not when imported.
    main()