Wildfire-FM / scripts /build_fireprone_contract_progression_figure.py
yx21e's picture
Initial FireWx-FM artifact release
80ef3b2 verified
#!/usr/bin/env python3
"""Build the RQ1 fire-prone contract-progression figure from summary JSON."""
from __future__ import annotations
import json
from pathlib import Path
from simple_pdf import PdfCanvas, draw_axes
# Repository root: two levels up from this resolved script path (parents[1]).
ROOT = Path(__file__).resolve().parents[1]
# Summary JSON consumed by main(): one row per (model_tag, scope).
IN_JSON = ROOT.joinpath(
    "artifacts", "results", "fireprone_contract_progression_summary.json"
)
# Destination of the rendered PDF figure.
OUT = ROOT.joinpath(
    "paper_outputs", "figures", "fig_fireprone_contract_progression_compact.pdf"
)

# Bar order within each scope block.
# Columns: (model_tag as stored in the JSON, long display name, short axis label).
MODEL_ORDER = [
    ("reference", "FireWx-FM ref.", "Ref."),
    ("prithvi_wxc", "Prithvi-WxC", "WxC"),
    ("aurora", "Aurora", "Aurora"),
    ("climax", "ClimaX", "ClimaX"),
    ("stormcast", "StormCast", "Storm"),
    ("dlwp", "DLWP", "DLWP"),
    ("fcn", "FCN", "FCN"),
    ("fengwu", "FengWu", "FengWu"),
    ("fuxi", "FuXi", "FuXi"),
    ("pangu6", "Pangu-Weather", "Pangu-W"),
    ("alphaearth", "AlphaEarth", "Alpha"),
]

# Left-to-right order of the scope blocks.
# Columns: (scope key as stored in the JSON, header label drawn above the block).
SCOPE_ORDER = [
    ("full_domain", "global"),
    ("train_fire_top05pct", "top 5%"),
    ("train_fire_top10pct", "top 10%"),
    ("train_fire_top20pct", "top 20%"),
]
def dashed_vline(
    c: PdfCanvas,
    x: float,
    y_start: float,
    y_end: float,
    *,
    dash: float = 7.0,
    gap: float = 5.0,
    color: tuple[float, float, float] = (0.42, 0.44, 0.46),
    lw: float = 0.75,
) -> None:
    """Draw a dashed vertical line at ``x`` covering ``y_start`` .. ``y_end``.

    Each dash is emitted as its own ``c.line`` segment. Dashes start at
    ``y_start`` and advance toward increasing y by ``dash + gap``. The last
    dash is truncated so nothing extends past ``y_end``. Nothing is drawn
    when ``y_start >= y_end``.

    Args:
        c: Canvas that receives the ``line`` calls.
        x: Horizontal position of the line.
        y_start: Coordinate where the first dash begins.
        y_end: Coordinate no dash may extend beyond.
        dash: Length of each dash (default 7.0, as originally hard-coded).
        gap: Space between consecutive dashes (default 5.0).
        color: Stroke colour passed through to ``c.line``.
        lw: Line width passed through to ``c.line``.

    Raises:
        ValueError: if ``dash`` is not positive or ``gap`` is negative. Such
            values give degenerate or overlapping dashes, or never advance
            past ``y_end`` and loop forever.
    """
    if dash <= 0 or gap < 0:
        raise ValueError(f"dash must be > 0 and gap >= 0, got dash={dash}, gap={gap}")
    y = y_start
    while y < y_end:
        # Clamp the final dash so the line never overshoots y_end.
        c.line([(x, y), (x, min(y + dash, y_end))], color=color, lw=lw)
        y += dash + gap
def _stacked_segments(
    strict: float, tolerance: float, union: float
) -> list[tuple[str, float]]:
    """Return ``(segment_name, height)`` pairs, in F1 percentage points, for one bar.

    Each segment is how far its metric rises above everything already stacked
    beneath it. The bar top therefore equals ``max(0, strict, tolerance, union)``
    even when the three metrics are not monotonically ordered. Plain pairwise
    differences overshoot the union value whenever tolerance < strict.
    """
    segments: list[tuple[str, float]] = []
    top = 0.0
    for name, value in (("strict", strict), ("tolerance", tolerance), ("union", union)):
        raised = max(top, value)
        segments.append((name, raised - top))
        top = raised
    return segments


def _lookup_row(by_key: dict, model_tag: str, scope: str) -> dict:
    """Fetch one summary row, failing with a message that names the missing pair."""
    try:
        return by_key[(model_tag, scope)]
    except KeyError:
        raise KeyError(
            f"no summary row for model_tag={model_tag!r}, scope={scope!r} in {IN_JSON}"
        ) from None


def main() -> None:
    """Render the stacked strict/tolerance/union F1 bar chart to ``OUT``.

    Reads ``IN_JSON``. Its ``"summary"`` list holds rows keyed by
    ``model_tag``/``scope``, each with ``strict_f1``, ``tolerance_f1`` and
    ``union_f1`` entries that carry a ``"mean"``. The chart has:

    * one block per scope in ``SCOPE_ORDER``, containing one stacked bar per
      model in ``MODEL_ORDER``;
    * a dashed separator between the first scope block and the rest;
    * a legend.

    Raises:
        KeyError: if a (model_tag, scope) pair has no summary row.
    """
    data = json.loads(IN_JSON.read_text(encoding="utf-8"))
    by_key = {(row["model_tag"], row["scope"]): row for row in data["summary"]}

    c = PdfCanvas(width=1320, height=470)
    c.rect(0, 0, c.width, c.height, fill=(1, 1, 1))  # white page background
    x0, y0, plot_w, plot_h, ymax = 72, 132, 1194, 268, 80.0
    draw_axes(c, x0, y0, plot_w, plot_h, ymax, [0, 20, 40, 60, 80])
    c.text(x0 - 38, y0 + plot_h + 8, "F1 (%)", size=8, color=(0.15, 0.15, 0.15), bold=True)
    colors = {
        "strict": (0.09, 0.22, 0.37),
        "tolerance": (0.31, 0.55, 0.80),
        "union": (0.75, 0.84, 0.94),
    }

    inset = 8.0  # horizontal padding between the plot frame and the outermost bars
    block_gap = 10.0  # spacing between adjacent scope blocks
    fire_gap = 28.0  # wider spacing after the first block; holds the dashed separator
    n_scopes = len(SCOPE_ORDER)
    # gaps[i] sits right after block i. Deriving the total from SCOPE_ORDER keeps
    # the width budget in sync with the per-block advance below.
    gaps = [fire_gap if i == 0 else block_gap for i in range(n_scopes - 1)]
    # Reserve the inset on BOTH sides so the last block ends inside the frame.
    block_w = (plot_w - 2 * inset - sum(gaps)) / n_scopes
    bar_step = block_w / len(MODEL_ORDER)
    bar_w = min(18.0, bar_step * 0.80)

    # Pass 1: lay out the blocks, drawing the separator and the scope headers.
    scope_lefts: dict[str, float] = {}
    current_x = x0 + inset
    for scope_idx, (scope, scope_label) in enumerate(SCOPE_ORDER):
        if scope_idx == 1:
            dashed_vline(c, current_x - fire_gap / 2.0, y0 - 6, y0 + plot_h + 16)
        scope_lefts[scope] = current_x
        c.text(current_x + block_w / 2.0, y0 + plot_h + 17, scope_label, size=15.0, align="center", bold=True)
        if scope_idx < len(gaps):
            current_x += block_w + gaps[scope_idx]

    # Pass 2: draw one stacked bar per (scope, model), plus its rotated short label.
    # NOTE: bars are not clipped, so a value above ymax extends past the plot top.
    for scope, _scope_label in SCOPE_ORDER:
        block_x = scope_lefts[scope]
        for idx, (model_tag, _label, short) in enumerate(MODEL_ORDER):
            row = _lookup_row(by_key, model_tag, scope)
            # Means are scaled by 100 for the "F1 (%)" axis.
            strict = row["strict_f1"]["mean"] * 100.0
            tolerance = row["tolerance_f1"]["mean"] * 100.0
            union = row["union_f1"]["mean"] * 100.0
            bx = block_x + idx * bar_step + (bar_step - bar_w) / 2.0
            base = y0
            for segment, value in _stacked_segments(strict, tolerance, union):
                height = plot_h * value / ymax
                if height <= 0:
                    continue
                c.rect(bx, base, bar_w, height, fill=colors[segment], stroke=(1, 1, 1), lw=0.35)
                base += height
            c.text_rotated(bx + bar_w / 2.0 - 3.0, y0 - 76, short, angle_deg=-45.0, size=10.0, align="right")

    # Legend: framed box in the upper-left of the plot with one swatch per segment.
    legend_x, legend_y = x0 + 18, y0 + plot_h - 26
    c.rect(legend_x - 13, legend_y - 12, 304, 23, fill=(0.98, 0.98, 0.96), stroke=(0.78, 0.80, 0.78), lw=0.45)
    for idx, (label, color) in enumerate([("Strict", colors["strict"]), ("Tolerance", colors["tolerance"]), ("Union", colors["union"])]):
        x = legend_x + idx * 98
        c.rect(x, legend_y - 3, 24, 9, fill=color, stroke=(1, 1, 1), lw=0.35)
        c.text(x + 31, legend_y - 1, label, size=8.0)

    OUT.parent.mkdir(parents=True, exist_ok=True)
    c.save(OUT)


if __name__ == "__main__":
    main()