"""Evaluation metric plots for the CI-Triage-Env baseline comparison.

Module-level try/except lets the module import without matplotlib; tests patch
`ci_triage_env.training.plotting.plt` and `.sns` directly.
"""

from __future__ import annotations

from pathlib import Path
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    import pandas as pd

try:
    import matplotlib.pyplot as plt  # type: ignore[import]
    import seaborn as sns  # type: ignore[import]
except ImportError:
    plt = None  # type: ignore[assignment]
    sns = None  # type: ignore[assignment]


def plot_all_eval_metrics(df: pd.DataFrame, output_dir: str | Path) -> None:
    """Write five PNG plots summarising the evaluation results.

    Uses module-level plt/sns so tests can inject mocks without sys.modules tricks.
    Raises ImportError if matplotlib or seaborn is not installed.
    """
    if plt is None or sns is None:
        raise ImportError(
            "matplotlib and seaborn are required for plotting. "
            "Install with: pip install matplotlib seaborn"
        )

    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # 1. Diagnosis accuracy by baseline
    fig, ax = plt.subplots(figsize=(8, 5))
    acc = df.groupby("baseline")["diagnosis_correct"].mean()
    ax.bar(list(acc.index), list(acc.values))
    ax.set_ylabel("Diagnosis Accuracy")
    ax.set_xlabel("Baseline")
    ax.set_title("Diagnosis Accuracy by Baseline")
    fig.tight_layout()
    fig.savefig(output_dir / "diagnosis_accuracy.png", dpi=120)
    plt.close(fig)

    # 2. Mean total reward ± std
    fig, ax = plt.subplots(figsize=(8, 5))
    agg = df.groupby("baseline")["total_reward"].agg(["mean", "std"])
    ax.bar(list(agg.index), list(agg["mean"]), yerr=list(agg["std"].fillna(0)))
    ax.set_ylabel("Mean Total Reward")
    ax.set_xlabel("Baseline")
    fig.tight_layout()
    fig.savefig(output_dir / "total_reward.png", dpi=120)
    plt.close(fig)

    # 3. Per-family accuracy heatmap
    pivot = df.pivot_table(
        index="baseline", columns="family",
        values="diagnosis_correct", aggfunc="mean",
    )
    fig, ax = plt.subplots(figsize=(10, 5))
    sns.heatmap(pivot, annot=True, fmt=".2f", cmap="Greens", ax=ax)
    fig.tight_layout()
    fig.savefig(output_dir / "per_family_accuracy.png", dpi=120)
    plt.close(fig)

    # 4. Reliability (calibration) on ambiguous scenarios
    amb = df[df["is_ambiguous_scenario"]]
    fig, ax = plt.subplots(figsize=(8, 5))
    for baseline in amb["baseline"].unique():
        sub = amb[amb["baseline"] == baseline]
        brier_vals = sub["brier_on_ambiguous"].fillna(0)
        ax.scatter(list(sub["confidence"]), list(1 - brier_vals), label=baseline, alpha=0.5)
    ax.set_xlabel("Reported confidence")
    ax.set_ylabel("Calibration score (1 − Brier)")
    if ax.get_legend_handles_labels()[0]:  # avoid empty-legend warning when df has no ambiguous rows
        ax.legend()
    fig.tight_layout()
    fig.savefig(output_dir / "calibration_ambiguous.png", dpi=120)
    plt.close(fig)

    # 5. Tool-call cost distribution
    fig, ax = plt.subplots(figsize=(8, 5))
    sns.boxplot(data=df, x="baseline", y="total_cost", ax=ax)
    ax.set_ylabel("Total cost ($)")
    fig.tight_layout()
    fig.savefig(output_dir / "cost_distribution.png", dpi=120)
    plt.close(fig)
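
A minimal usage sketch, with a hypothetical evaluation frame carrying the columns the plots above read (`baseline`, `family`, `diagnosis_correct`, `total_reward`, `total_cost`, `is_ambiguous_scenario`, `brier_on_ambiguous`, `confidence`); the data values are illustrative only:

```python
import pandas as pd

# Hypothetical evaluation results; real frames come from the eval harness.
df = pd.DataFrame({
    "baseline": ["random", "random", "llm", "llm"],
    "family": ["flaky_test", "oom", "flaky_test", "oom"],
    "diagnosis_correct": [0, 1, 1, 1],
    "total_reward": [-1.0, 2.0, 3.0, 2.5],
    "total_cost": [0.10, 0.15, 0.40, 0.35],
    "is_ambiguous_scenario": [False, True, False, True],
    "brier_on_ambiguous": [float("nan"), 0.2, float("nan"), 0.1],
    "confidence": [0.5, 0.7, 0.9, 0.8],
})

# Plot 1 reduces to this per-baseline mean of diagnosis_correct:
acc = df.groupby("baseline")["diagnosis_correct"].mean()
print(acc.to_dict())  # {'llm': 1.0, 'random': 0.5}

# plot_all_eval_metrics(df, "plots/")  # writes the five PNGs under plots/
```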