Spaces:
Running
Running
| """Evaluation and training plot generation.""" | |
| from __future__ import annotations | |
| import json | |
| from pathlib import Path | |
| import matplotlib | |
| matplotlib.use("Agg") | |
| import matplotlib.pyplot as plt | |
| def _load_json(path: Path) -> dict: | |
| if not path.exists(): | |
| return {} | |
| return json.loads(path.read_text(encoding="utf-8")) | |
def _as_dict(value: object) -> dict:
    """Return *value* unchanged if it is a dict, otherwise an empty dict."""
    return value if isinstance(value, dict) else {}


def _as_float(value: object, default: float = 0.0) -> float:
    """Convert a JSON scalar to ``float``; ``None`` or unconvertible values yield *default*."""
    if value is None:
        return default
    try:
        return float(value)
    except (TypeError, ValueError):
        return default


def _save_bar_plot(
    path: Path,
    labels: list[str],
    values: list[float],
    colors: list[str],
    title: str,
    figsize: tuple[float, float],
) -> None:
    """Render a single bar chart to *path*, closing the figure even if saving fails."""
    fig, ax = plt.subplots(figsize=figsize)
    try:
        # Slice so the palette never lists more colors than there are bars.
        ax.bar(labels, values, color=colors[: len(labels)])
        # Keep the conventional [0, 1] range, but widen it when a value falls
        # outside so its bar is not silently clipped. ``abs(v) < inf`` skips
        # NaN and +/-inf, which set_ylim would reject.
        finite = [v for v in values if abs(v) < float("inf")]
        ax.set_ylim(min([0.0, *finite]), max([1.0, *finite]))
        ax.set_title(title)
        ax.grid(alpha=0.2, axis="y")
        fig.tight_layout()
        fig.savefig(path)
    finally:
        plt.close(fig)


def generate_training_plots(report_dir: Path, plot_dir: Path) -> list[str]:
    """Render bar charts summarising GRPO training and baseline reports.

    Reads ``planner_grpo.json``, ``supervisor_grpo.json``, ``dosing_grpo.json``
    and ``baselines.json`` from *report_dir*. Missing files, missing keys,
    ``null`` values and malformed (non-object / non-numeric) entries are
    treated as empty mappings / ``0.0`` instead of raising.

    Charts written to *plot_dir* (created if needed):

    * ``<metric>.png`` for each of avg_reward, legality_rate, success_rate and
      avg_process_fidelity, comparing supervisor / planner / dosing;
    * ``policy_stack_avg_reward.png`` when ``baselines.json`` contains a
      non-empty ``policy_stack_ablations`` mapping;
    * ``planner_primary_reward_channels.png`` when the planner report contains
      a non-empty ``primary_reward_channels`` mapping.

    Args:
        report_dir: Directory containing the JSON reports.
        plot_dir: Output directory for the PNG files.

    Returns:
        Paths (as strings) of every PNG written, in creation order.
    """
    plot_dir.mkdir(parents=True, exist_ok=True)
    output_paths: list[str] = []

    planner = _as_dict(_load_json(report_dir / "planner_grpo.json"))
    supervisor = _as_dict(_load_json(report_dir / "supervisor_grpo.json"))
    dosing = _as_dict(_load_json(report_dir / "dosing_grpo.json"))

    # Per-metric comparison of the three trained stages.
    stage_labels = ["supervisor", "planner", "dosing"]
    stage_reports = [supervisor, planner, dosing]
    for metric in ("avg_reward", "legality_rate", "success_rate", "avg_process_fidelity"):
        path = plot_dir / f"{metric}.png"
        _save_bar_plot(
            path,
            stage_labels,
            [_as_float(report.get(metric)) for report in stage_reports],
            ["#2f855a", "#2b6cb0", "#d69e2e"],
            metric,
            (6.2, 3.6),
        )
        output_paths.append(str(path))

    # Policy-stack ablation comparison, only when the baselines report provides it.
    baselines = _as_dict(_load_json(report_dir / "baselines.json"))
    ablations = _as_dict(baselines.get("policy_stack_ablations"))
    if ablations:
        ablation_labels = list(ablations)
        path = plot_dir / "policy_stack_avg_reward.png"
        _save_bar_plot(
            path,
            ablation_labels,
            [_as_float(_as_dict(ablations[name]).get("avg_reward")) for name in ablation_labels],
            ["#805ad5", "#2c5282", "#2f855a"],
            "policy_stack_avg_reward",
            (7.0, 3.8),
        )
        output_paths.append(str(path))

    # Primary reward channel comparison from the planner summary, when present.
    channels = _as_dict(planner.get("primary_reward_channels"))
    if channels:
        channel_labels = list(channels)
        path = plot_dir / "planner_primary_reward_channels.png"
        _save_bar_plot(
            path,
            channel_labels,
            [_as_float(channels[name]) for name in channel_labels],
            ["#276749", "#2b6cb0", "#dd6b20", "#4a5568"],
            "planner_primary_reward_channels",
            (7.0, 3.8),
        )
        output_paths.append(str(path))

    return output_paths