File size: 2,242 Bytes
03815d6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
"""Render per-rubric ablation bar chart from `logs/ablation_study.json`.

Each child rubric is zeroed in turn (at evaluation time, without retraining)
and the resulting change in average composite reward (zeroed − full) is
plotted. Bars left of 0 → the rubric matters (reward dropped without it);
bars right of 0 → the rubric helps the model game the metric (reward rose
without it).

Output: plots/chakravyuh_plots/ablation_per_rubric.png
"""

from __future__ import annotations

import json
from pathlib import Path

import matplotlib.pyplot as plt


def main() -> int:
    """Render the per-rubric ablation bar chart and save it as a PNG.

    Reads ``logs/ablation_study.json`` (relative to the current working
    directory), draws one horizontal bar per entry of its ``"ablations"``
    list sorted by ascending ``delta_reward``, and writes the figure to
    ``plots/chakravyuh_plots/ablation_per_rubric.png``, creating the output
    directory if it does not exist.

    Returns:
        ``0`` once the image has been written.

    Raises:
        FileNotFoundError: If the input JSON file does not exist.
        json.JSONDecodeError: If the input file is not valid JSON.
        KeyError: If a key read below is missing from the JSON document.
    """
    src = Path("logs/ablation_study.json")
    out = Path("plots/chakravyuh_plots/ablation_per_rubric.png")
    data = json.loads(src.read_text(encoding="utf-8"))
    rows = data["ablations"]
    full_avg = data["full_avg_reward"]

    # Most negative delta first, so it ends up as the bottom bar.
    rows_sorted = sorted(rows, key=lambda r: r["delta_reward"])
    names = [r["rubric_zeroed"] for r in rows_sorted]
    deltas = [r["delta_reward"] for r in rows_sorted]
    weights = [r["default_weight"] for r in rows_sorted]

    # Red = reward dropped, green = reward rose, grey = no change.
    colors = ["#c62828" if d < 0 else "#558b2f" if d > 0 else "#9e9e9e" for d in deltas]
    fig, ax = plt.subplots(figsize=(8, 5))
    try:
        y_pos = list(range(len(names)))
        ax.barh(y_pos, deltas, color=colors, edgecolor="black", linewidth=0.5)
        ax.set_yticks(y_pos)
        ax.set_yticklabels([f"{n}\n(w={w:+.2f})" for n, w in zip(names, weights)], fontsize=9)
        ax.axvline(0, color="black", linewidth=0.8)
        # Extra horizontal room so the value labels stay inside the frame.
        ax.margins(x=0.15)
        for i, d in enumerate(deltas):
            on_right = d >= 0
            # Offset the label in points rather than data units so the gap
            # between bar end and text is the same whatever the magnitude
            # of the deltas.
            ax.annotate(
                f"{d:+.4f}",
                xy=(d, i),
                xytext=(4 if on_right else -4, 0),
                textcoords="offset points",
                va="center", ha="left" if on_right else "right",
                fontsize=9, fontweight="bold",
            )
        ax.set_xlabel("Δ avg composite reward (zeroed − full)", fontsize=11)
        ax.set_title(
            f"Per-rubric ablation — {data['rubric_class']} on n={data['n_scenarios']}\n"
            f"full reward = {full_avg:.4f} · negative bar = rubric matters",
            fontsize=11, fontweight="bold",
        )
        ax.grid(True, alpha=0.3, axis="x")
        fig.tight_layout()
        out.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(out, dpi=120, bbox_inches="tight")
    finally:
        # Release the figure even if rendering or saving failed.
        plt.close(fig)
    print(f"Wrote {out} ({out.stat().st_size:,} bytes)")
    return 0


if __name__ == "__main__":
    # Run the script and hand main()'s return value to the interpreter as
    # the process exit status.
    exit_status = main()
    raise SystemExit(exit_status)