"""Evaluation and training plot generation."""

from __future__ import annotations

import json
from pathlib import Path

import matplotlib

# Select the non-interactive Agg backend so figures render without a display.
matplotlib.use("Agg")

import matplotlib.pyplot as plt

# Display names for the policy-stack ablation keys; both the hyphenated and
# the underscored spelling of each key map to the same label.
_POLICY_STACK_LABELS = {
    "bandit-only": "Bandits only",
    "bandit_only": "Bandits only",
    "llm-only": "Baseline LLM only",
    "llm_only": "Baseline LLM only",
    "llm+bandit": "LLM + Bandits",
    "llm_bandit": "LLM + Bandits",
}

# Summary metrics plotted once per training report (one PNG per metric).
_TRAINING_METRICS = ["avg_reward", "legality_rate", "success_rate", "avg_process_fidelity"]


def _load_json(path: Path) -> dict:
    """Read a JSON object from *path*.

    Returns an empty dict when the file does not exist, or when the file's
    top-level JSON value is not an object, so callers can always use ``.get``
    on the result. Malformed JSON still raises ``json.JSONDecodeError``.
    """
    if not path.exists():
        return {}
    payload = json.loads(path.read_text(encoding="utf-8"))
    # The return annotation promises a dict; a list/scalar report is treated
    # the same as a missing one instead of crashing later on ``.get``.
    return payload if isinstance(payload, dict) else {}


def _policy_stack_label(label: str) -> str:
    """Return the display label for a policy-stack key, or *label* itself if unknown."""
    return _POLICY_STACK_LABELS.get(label, label)


def _as_float(value: object, default: float = 0.0) -> float:
    """Convert *value* to ``float``, returning *default* for ``None`` or non-numeric input."""
    try:
        return float(value)
    except (TypeError, ValueError):
        # JSON ``null`` or an unexpected structure: plot it as the default
        # rather than aborting the whole plot run.
        return default


def _save_bar_chart(
    path: Path,
    labels: list[str],
    values: list[float],
    colors: list[str],
    title: str,
    figsize: tuple[float, float],
) -> str:
    """Render one bar chart with a fixed [0, 1] y-axis, save it to *path*, and return ``str(path)``."""
    fig, ax = plt.subplots(figsize=figsize)
    try:
        # Only as many palette entries as there are bars are passed on.
        ax.bar(labels, values, color=colors[: len(labels)])
        ax.set_ylim(0.0, 1.0)
        ax.set_title(title)
        ax.grid(alpha=0.2, axis="y")
        fig.tight_layout()
        fig.savefig(path)
    finally:
        # Close even when saving fails so the figure is not kept alive by pyplot.
        plt.close(fig)
    return str(path)


def generate_training_plots(report_dir: Path, plot_dir: Path) -> list[str]:
    """Generate bar-chart PNGs from the JSON reports in *report_dir*.

    Reads ``planner_grpo.json``, ``supervisor_grpo.json``, ``dosing_grpo.json``
    and ``baselines.json`` from *report_dir*. Missing files are treated as
    empty reports, and missing or non-numeric values are plotted as 0.0.

    Charts written to *plot_dir* (created if necessary):

    * one ``<metric>.png`` per training metric, comparing supervisor,
      planner and dosing;
    * ``policy_stack_avg_reward.png`` when ``baselines.json`` contains a
      non-empty ``policy_stack_ablations`` object;
    * ``planner_primary_reward_channels.png`` when the planner report
      contains a non-empty ``primary_reward_channels`` object.

    Args:
        report_dir: Directory containing the JSON report files.
        plot_dir: Directory the PNG files are written to.

    Returns:
        The paths of the written PNG files, as strings, in creation order.
    """
    plot_dir.mkdir(parents=True, exist_ok=True)

    planner = _load_json(report_dir / "planner_grpo.json")
    supervisor = _load_json(report_dir / "supervisor_grpo.json")
    dosing = _load_json(report_dir / "dosing_grpo.json")

    stage_labels = ["supervisor", "planner", "dosing"]
    stage_reports = [supervisor, planner, dosing]
    output_paths: list[str] = []

    # One chart per summary metric, with one bar per training stage.
    for metric in _TRAINING_METRICS:
        output_paths.append(
            _save_bar_chart(
                plot_dir / f"{metric}.png",
                stage_labels,
                [_as_float(report.get(metric)) for report in stage_reports],
                ["#2f855a", "#2b6cb0", "#d69e2e"],
                metric,
                (6.2, 3.6),
            )
        )

    # Policy-stack ablation comparison, only when the baselines report has one.
    ablations = _load_json(report_dir / "baselines.json").get("policy_stack_ablations")
    if isinstance(ablations, dict) and ablations:
        ablation_values = [
            # A null or non-object ablation entry counts as "no avg_reward".
            _as_float(entry.get("avg_reward") if isinstance(entry, dict) else None)
            for entry in ablations.values()
        ]
        output_paths.append(
            _save_bar_chart(
                plot_dir / "policy_stack_avg_reward.png",
                [_policy_stack_label(key) for key in ablations],
                ablation_values,
                ["#805ad5", "#2c5282", "#2f855a"],
                "Without Bandits vs With Bandits average reward",
                (7.0, 3.8),
            )
        )

    # Primary reward channel comparison from planner summary when present.
    planner_channels = planner.get("primary_reward_channels")
    if isinstance(planner_channels, dict) and planner_channels:
        output_paths.append(
            _save_bar_chart(
                plot_dir / "planner_primary_reward_channels.png",
                list(planner_channels),
                [_as_float(value) for value in planner_channels.values()],
                ["#276749", "#2b6cb0", "#dd6b20", "#4a5568"],
                "planner_primary_reward_channels",
                (7.0, 3.8),
            )
        )

    return output_paths