#!/usr/bin/env python3
"""Build the RQ1 fire-prone contract-progression figure from summary JSON.

Reads per-(model, scope) mean F1 scores from ``IN_JSON`` and writes a stacked
bar chart to ``OUT``: one block of bars per scope in ``SCOPE_ORDER``, one bar
per model in ``MODEL_ORDER``, each bar stacked strict / tolerance / union.
"""

from __future__ import annotations

import json
from pathlib import Path

from simple_pdf import PdfCanvas, draw_axes

ROOT = Path(__file__).resolve().parents[1]
IN_JSON = ROOT / "artifacts" / "results" / "fireprone_contract_progression_summary.json"
OUT = ROOT / "paper_outputs" / "figures" / "fig_fireprone_contract_progression_compact.pdf"

# (model_tag as stored in the summary JSON, full name, short tick label).
# Bars are drawn left to right in this order inside every scope block.
MODEL_ORDER = [
    ("reference", "FireWx-FM ref.", "Ref."),
    ("prithvi_wxc", "Prithvi-WxC", "WxC"),
    ("aurora", "Aurora", "Aurora"),
    ("climax", "ClimaX", "ClimaX"),
    ("stormcast", "StormCast", "Storm"),
    ("dlwp", "DLWP", "DLWP"),
    ("fcn", "FCN", "FCN"),
    ("fengwu", "FengWu", "FengWu"),
    ("fuxi", "FuXi", "FuXi"),
    ("pangu6", "Pangu-Weather", "Pangu-W"),
    ("alphaearth", "AlphaEarth", "Alpha"),
]
# (scope as stored in the summary JSON, block heading). In main() the first
# scope is set apart from the rest by a wider gap and a dashed separator.
SCOPE_ORDER = [
    ("full_domain", "global"),
    ("train_fire_top05pct", "top 5%"),
    ("train_fire_top10pct", "top 10%"),
    ("train_fire_top20pct", "top 20%"),
]

# Stacked segments, bottom to top:
# (metric prefix in the summary row, legend label, fill colour).
SEGMENTS = [
    ("strict", "Strict", (0.09, 0.22, 0.37)),
    ("tolerance", "Tolerance", (0.31, 0.55, 0.80)),
    ("union", "Union", (0.75, 0.84, 0.94)),
]


def dashed_vline(c: PdfCanvas, x: float, y_start: float, y_end: float) -> None:
    """Draw a grey dashed vertical line at ``x`` from ``y_start`` up to ``y_end``.

    Dashes are 7 units long with 5-unit gaps; the last dash is cut off at
    ``y_end``. Nothing is drawn when ``y_start >= y_end``.
    """
    dash, gap = 7.0, 5.0
    y = y_start
    while y < y_end:
        c.line([(x, y), (x, min(y + dash, y_end))], color=(0.42, 0.44, 0.46), lw=0.75)
        y += dash + gap


def _summary_row(by_key: dict[tuple[str, str], dict], model_tag: str, scope: str) -> dict:
    """Return the summary row for (model_tag, scope); the KeyError names both if absent."""
    try:
        return by_key[(model_tag, scope)]
    except KeyError:
        raise KeyError(
            f"{IN_JSON}: no summary row for model_tag={model_tag!r}, scope={scope!r}"
        ) from None


def _stack_tops(values: list[float]) -> list[float]:
    """Return cumulative segment tops: the running maximum of ``values``, floored at 0.

    Consecutive differences of the result are the segment heights. They are
    therefore never negative, and the full bar is exactly as tall as the
    largest value even if the inputs are not ordered strict <= tolerance <= union.
    """
    tops: list[float] = []
    top = 0.0
    for value in values:
        top = max(top, value)
        tops.append(top)
    return tops


def _scope_layout(
    x0: float, plot_w: float, *, pad: float, fire_gap: float, block_gap: float
) -> tuple[dict[str, float], float]:
    """Return (left x of each scope block keyed by scope, common block width).

    The blocks fill ``plot_w`` minus ``pad`` on BOTH sides, so the last block
    ends at ``x0 + plot_w - pad`` instead of running past the axis. The first
    scope is followed by ``fire_gap``; every later boundary uses ``block_gap``.
    The gap count follows ``len(SCOPE_ORDER)``.
    """
    n_scopes = len(SCOPE_ORDER)
    gaps = ([fire_gap] + [block_gap] * (n_scopes - 2)) if n_scopes > 1 else []
    block_w = (plot_w - 2.0 * pad - sum(gaps)) / n_scopes
    lefts: dict[str, float] = {}
    x = x0 + pad
    for (scope, _label), gap_after in zip(SCOPE_ORDER, gaps + [0.0]):
        lefts[scope] = x
        x += block_w + gap_after
    return lefts, block_w


def _draw_scope_bars(
    c: PdfCanvas,
    by_key: dict[tuple[str, str], dict],
    scope: str,
    block_x: float,
    block_w: float,
    *,
    y0: float,
    plot_h: float,
    ymax: float,
) -> None:
    """Draw one stacked bar, plus its rotated tick label, per model inside one scope block."""
    bar_step = block_w / len(MODEL_ORDER)
    bar_w = min(18.0, bar_step * 0.80)
    for idx, (model_tag, _label, short) in enumerate(MODEL_ORDER):
        row = _summary_row(by_key, model_tag, scope)
        # Scale to percent to match the "F1 (%)" axis.
        values = [row[f"{key}_f1"]["mean"] * 100.0 for key, _legend, _color in SEGMENTS]
        bx = block_x + idx * bar_step + (bar_step - bar_w) / 2.0
        base = y0
        prev_top = 0.0
        for (_key, _legend, color), top in zip(SEGMENTS, _stack_tops(values)):
            height = plot_h * (top - prev_top) / ymax
            prev_top = top
            if height <= 0:
                continue
            c.rect(bx, base, bar_w, height, fill=color, stroke=(1, 1, 1), lw=0.35)
            base += height
        c.text_rotated(
            bx + bar_w / 2.0 - 3.0, y0 - 76, short, angle_deg=-45.0, size=10.0, align="right"
        )


def _draw_legend(c: PdfCanvas, x: float, y: float) -> None:
    """Draw the boxed segment legend (one swatch + label per SEGMENTS entry) anchored at (x, y)."""
    c.rect(x - 13, y - 12, 304, 23, fill=(0.98, 0.98, 0.96), stroke=(0.78, 0.80, 0.78), lw=0.45)
    for idx, (_key, legend, color) in enumerate(SEGMENTS):
        swatch_x = x + idx * 98
        c.rect(swatch_x, y - 3, 24, 9, fill=color, stroke=(1, 1, 1), lw=0.35)
        c.text(swatch_x + 31, y - 1, legend, size=8.0)


def main() -> None:
    """Render the contract-progression figure from IN_JSON and save it as a PDF at OUT."""
    data = json.loads(IN_JSON.read_text(encoding="utf-8"))
    by_key = {(row["model_tag"], row["scope"]): row for row in data["summary"]}

    c = PdfCanvas(width=1320, height=470)
    c.rect(0, 0, c.width, c.height, fill=(1, 1, 1))
    x0, y0, plot_w, plot_h, ymax = 72, 132, 1194, 268, 80.0
    draw_axes(c, x0, y0, plot_w, plot_h, ymax, [0, 20, 40, 60, 80])
    c.text(x0 - 38, y0 + plot_h + 8, "F1 (%)", size=8, color=(0.15, 0.15, 0.15), bold=True)

    fire_gap = 28.0
    scope_lefts, block_w = _scope_layout(
        x0, plot_w, pad=8.0, fire_gap=fire_gap, block_gap=10.0
    )

    # Scope headings, with a dashed separator centred in the gap after the first scope.
    for scope_idx, (scope, scope_label) in enumerate(SCOPE_ORDER):
        left = scope_lefts[scope]
        if scope_idx == 1:
            dashed_vline(c, left - fire_gap / 2.0, y0 - 6, y0 + plot_h + 16)
        c.text(
            left + block_w / 2.0, y0 + plot_h + 17, scope_label,
            size=15.0, align="center", bold=True,
        )

    for scope, _scope_label in SCOPE_ORDER:
        _draw_scope_bars(
            c, by_key, scope, scope_lefts[scope], block_w, y0=y0, plot_h=plot_h, ymax=ymax
        )

    _draw_legend(c, x0 + 18, y0 + plot_h - 26)

    OUT.parent.mkdir(parents=True, exist_ok=True)
    c.save(OUT)


if __name__ == "__main__":
    main()