TheJackBright's picture
Deploy GitHub root master to Space
c296d62
"""Evaluation and training plot generation."""
from __future__ import annotations
import json
from pathlib import Path
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
def _load_json(path: Path) -> dict:
    """Read and decode the JSON document stored at *path*.

    A file that does not exist is treated as an empty report and yields
    ``{}``. Otherwise the decoded JSON value is returned exactly as parsed
    (UTF-8 text); malformed content lets ``json.JSONDecodeError`` propagate.
    """
    if path.exists():
        raw_text = path.read_text(encoding="utf-8")
        return json.loads(raw_text)
    return {}
def _policy_stack_label(label: str) -> str:
    """Translate a policy-stack ablation key into a human-readable chart label.

    Hyphenated (``"bandit-only"``, ``"llm+bandit"``) and underscored
    (``"bandit_only"``, ``"llm_bandit"``) spellings map to the same display
    name; any unrecognised key is returned unchanged.
    """
    display_names = (
        (("bandit-only", "bandit_only"), "Bandits only"),
        (("llm-only", "llm_only"), "Baseline LLM only"),
        (("llm+bandit", "llm_bandit"), "LLM + Bandits"),
    )
    for aliases, pretty_name in display_names:
        if label in aliases:
            return pretty_name
    return label
def _as_dict(value: object) -> dict:
    """Return *value* unchanged when it is a dict, otherwise an empty dict.

    Report sections may be absent, ``null`` or of an unexpected JSON type;
    normalising them lets callers use ``.get`` without type checks.
    """
    return value if isinstance(value, dict) else {}


def _metric_float(value: object) -> float:
    """Convert a report metric to ``float``, treating missing/``null`` as 0.0.

    Raises:
        TypeError, ValueError: if *value* is present but not float-convertible.
    """
    return 0.0 if value is None else float(value)


def _save_bar_chart(
    path: Path,
    labels: list[str],
    values: list[float],
    colors: list[str],
    title: str,
    figsize: tuple[float, float],
) -> str:
    """Draw one bar chart, save it to *path*, and return ``str(path)``.

    The y-axis spans [0, 1] by default and is widened to include any finite
    value outside that range, so out-of-range bars are not clipped. The figure
    is closed even if rendering or saving fails, so no figures leak.
    """
    import math

    fig, ax = plt.subplots(figsize=figsize)
    try:
        ax.bar(labels, values, color=colors)
        # Non-finite values cannot be used as axis limits; ignore them here.
        finite = [v for v in values if math.isfinite(v)]
        ax.set_ylim(min([0.0, *finite]), max([1.0, *finite]))
        ax.set_title(title)
        ax.grid(alpha=0.2, axis="y")
        fig.tight_layout()
        fig.savefig(path)
    finally:
        plt.close(fig)
    return str(path)


def generate_training_plots(report_dir: Path, plot_dir: Path) -> list[str]:
    """Render bar-chart PNGs summarising the training/evaluation reports.

    Reads ``planner_grpo.json``, ``supervisor_grpo.json``, ``dosing_grpo.json``
    and ``baselines.json`` from *report_dir* (missing files are treated as
    empty reports) and writes charts into *plot_dir*, creating it if needed:

    * one chart per metric (``avg_reward``, ``legality_rate``,
      ``success_rate``, ``avg_process_fidelity``) comparing the supervisor,
      planner and dosing reports; missing or ``null`` metrics plot as 0.0;
    * ``policy_stack_avg_reward.png`` when ``baselines.json`` contains a
      non-empty ``policy_stack_ablations`` mapping;
    * ``planner_primary_reward_channels.png`` when the planner report
      contains a non-empty ``primary_reward_channels`` mapping.

    Args:
        report_dir: Directory containing the JSON reports.
        plot_dir: Output directory for the PNG files.

    Returns:
        Paths of the written PNG files, as strings, in the order listed above.
    """
    plot_dir.mkdir(parents=True, exist_ok=True)
    planner = _as_dict(_load_json(report_dir / "planner_grpo.json"))
    supervisor = _as_dict(_load_json(report_dir / "supervisor_grpo.json"))
    dosing = _as_dict(_load_json(report_dir / "dosing_grpo.json"))
    output_paths: list[str] = []

    # Per-metric comparison across the three trained policies.
    role_labels = ["supervisor", "planner", "dosing"]
    role_payloads = [supervisor, planner, dosing]
    metrics = ["avg_reward", "legality_rate", "success_rate", "avg_process_fidelity"]
    for metric in metrics:
        values = [_metric_float(payload.get(metric)) for payload in role_payloads]
        output_paths.append(
            _save_bar_chart(
                plot_dir / f"{metric}.png",
                role_labels,
                values,
                ["#2f855a", "#2b6cb0", "#d69e2e"],
                metric,
                (6.2, 3.6),
            )
        )

    # Policy-stack ablation (with vs. without bandits), when reported.
    baselines = _as_dict(_load_json(report_dir / "baselines.json"))
    ablations = _as_dict(baselines.get("policy_stack_ablations"))
    if ablations:
        keys = list(ablations)
        ablation_values = [
            _metric_float(_as_dict(ablations[key]).get("avg_reward")) for key in keys
        ]
        output_paths.append(
            _save_bar_chart(
                plot_dir / "policy_stack_avg_reward.png",
                [_policy_stack_label(key) for key in keys],
                ablation_values,
                ["#805ad5", "#2c5282", "#2f855a"][: len(keys)],
                "Without Bandits vs With Bandits average reward",
                (7.0, 3.8),
            )
        )

    # Primary reward channel comparison from planner summary when present.
    planner_channels = _as_dict(planner.get("primary_reward_channels"))
    if planner_channels:
        channel_names = list(planner_channels)
        output_paths.append(
            _save_bar_chart(
                plot_dir / "planner_primary_reward_channels.png",
                channel_names,
                [_metric_float(planner_channels[name]) for name in channel_names],
                ["#276749", "#2b6cb0", "#dd6b20", "#4a5568"][: len(channel_names)],
                "planner_primary_reward_channels",
                (7.0, 3.8),
            )
        )

    return output_paths