| """Evaluation and training plot generation.""" |
|
|
| from __future__ import annotations |
|
|
| import json |
| from pathlib import Path |
|
|
| import matplotlib |
|
|
| matplotlib.use("Agg") |
| import matplotlib.pyplot as plt |
|
|
|
|
| def _load_json(path: Path) -> dict: |
| if not path.exists(): |
| return {} |
| return json.loads(path.read_text(encoding="utf-8")) |
|
|
|
|
| def _policy_stack_label(label: str) -> str: |
| labels = { |
| "bandit-only": "Bandits only", |
| "bandit_only": "Bandits only", |
| "llm-only": "Baseline LLM only", |
| "llm_only": "Baseline LLM only", |
| "llm+bandit": "LLM + Bandits", |
| "llm_bandit": "LLM + Bandits", |
| } |
| return labels.get(label, label) |
|
|
|
|
def _as_dict(payload: object) -> dict:
    """Return *payload* when it is a dict, otherwise an empty dict."""
    return payload if isinstance(payload, dict) else {}


def _as_float(value: object) -> float:
    """Convert a report value to float, treating a missing/null value as 0.0."""
    return 0.0 if value is None else float(value)


def _save_bar_chart(
    path: Path,
    labels: list[str],
    values: list[float],
    *,
    title: str,
    colors: list[str],
    figsize: tuple[float, float],
) -> str:
    """Render a single bar chart to *path* and return the path as a string."""
    fig, ax = plt.subplots(figsize=figsize)
    try:
        ax.bar(labels, values, color=colors[: len(labels)])
        # NOTE(review): the y-axis is fixed to [0, 1]; this assumes every
        # plotted metric is normalised to that range -- confirm for rewards.
        ax.set_ylim(0.0, 1.0)
        ax.set_title(title)
        ax.grid(alpha=0.2, axis="y")
        fig.tight_layout()
        fig.savefig(path)
    finally:
        # Always release the figure, even if layout or saving fails.
        plt.close(fig)
    return str(path)


def generate_training_plots(report_dir: Path, plot_dir: Path) -> list[str]:
    """Render bar-chart PNGs from the JSON reports found in *report_dir*.

    Reads ``planner_grpo.json``, ``supervisor_grpo.json``, ``dosing_grpo.json``
    and ``baselines.json``. Missing files, non-object payloads and null metric
    values are treated as empty / ``0.0`` rather than raising.

    Charts written, in order:
        * one ``<metric>.png`` per metric (``avg_reward``, ``legality_rate``,
          ``success_rate``, ``avg_process_fidelity``) comparing supervisor,
          planner and dosing;
        * ``policy_stack_avg_reward.png`` when ``baselines.json`` contains a
          non-empty ``policy_stack_ablations`` mapping;
        * ``planner_primary_reward_channels.png`` when the planner report
          contains a non-empty ``primary_reward_channels`` mapping.

    Args:
        report_dir: Directory holding the JSON report files.
        plot_dir: Output directory; created (with parents) if absent.

    Returns:
        The paths of the written PNG files, as strings, in creation order.
    """
    plot_dir.mkdir(parents=True, exist_ok=True)
    planner = _as_dict(_load_json(report_dir / "planner_grpo.json"))
    supervisor = _as_dict(_load_json(report_dir / "supervisor_grpo.json"))
    dosing = _as_dict(_load_json(report_dir / "dosing_grpo.json"))

    roles = ["supervisor", "planner", "dosing"]
    payloads = [supervisor, planner, dosing]
    output_paths: list[str] = []

    for metric in ("avg_reward", "legality_rate", "success_rate", "avg_process_fidelity"):
        output_paths.append(
            _save_bar_chart(
                plot_dir / f"{metric}.png",
                roles,
                [_as_float(payload.get(metric)) for payload in payloads],
                title=metric,
                colors=["#2f855a", "#2b6cb0", "#d69e2e"],
                figsize=(6.2, 3.6),
            )
        )

    baselines = _as_dict(_load_json(report_dir / "baselines.json"))
    ablations = baselines.get("policy_stack_ablations")
    if isinstance(ablations, dict) and ablations:
        keys = list(ablations)
        output_paths.append(
            _save_bar_chart(
                plot_dir / "policy_stack_avg_reward.png",
                [_policy_stack_label(key) for key in keys],
                # An ablation entry may be null or malformed; plot it as 0.0.
                [_as_float(_as_dict(ablations[key]).get("avg_reward")) for key in keys],
                title="Without Bandits vs With Bandits average reward",
                colors=["#805ad5", "#2c5282", "#2f855a"],
                figsize=(7.0, 3.8),
            )
        )

    channels = planner.get("primary_reward_channels")
    if isinstance(channels, dict) and channels:
        channel_names = list(channels)
        output_paths.append(
            _save_bar_chart(
                plot_dir / "planner_primary_reward_channels.png",
                channel_names,
                [_as_float(channels[name]) for name in channel_names],
                title="planner_primary_reward_channels",
                colors=["#276749", "#2b6cb0", "#dd6b20", "#4a5568"],
                figsize=(7.0, 3.8),
            )
        )
    return output_paths
|
|