| |
| """Generate sweep summaries, charts, and anti-hacking checks for HF training.""" |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| from collections import Counter |
| import json |
| from pathlib import Path |
| from typing import Any |
|
|
| import matplotlib |
|
|
| matplotlib.use("Agg") |
| import matplotlib.pyplot as plt |
|
|
|
|
# Repository root: the parent of the directory containing this script. The
# CLI entry point joins every path option onto it.
ROOT = Path(__file__).resolve().parents[1]

# Inclusive bounds for every logged reward value; values must additionally be
# rounded to at most three decimal places to pass the reward-range scan.
REWARD_MIN, REWARD_MAX = 0.001, 0.999
|
|
|
|
| def parse_args() -> argparse.Namespace: |
| parser = argparse.ArgumentParser(description="Summarize PolyGuard HF training sweeps.") |
| parser.add_argument("--sweep-dir", default="outputs/reports/sweeps") |
| parser.add_argument("--plot-dir", default="outputs/plots") |
| parser.add_argument("--output", default="outputs/reports/hf_sweep_summary.json") |
| parser.add_argument("--anti-hacking-output", default="outputs/reports/anti_hacking_overfit_report.json") |
| parser.add_argument( |
| "--mode", |
| choices=["full", "sft-baseline"], |
| default="full", |
| help="Report mode. SFT baseline mode treats GRPO artifacts as optional.", |
| ) |
| return parser.parse_args() |
|
|
|
|
| def _read_json(path: Path) -> dict[str, Any]: |
| if not path.exists(): |
| return {} |
| try: |
| payload = json.loads(path.read_text(encoding="utf-8")) |
| except json.JSONDecodeError: |
| return {} |
| return payload if isinstance(payload, dict) else {} |
|
|
|
|
| def _read_history(path: Path) -> list[dict[str, Any]]: |
| if not path.exists(): |
| return [] |
| try: |
| payload = json.loads(path.read_text(encoding="utf-8")) |
| except json.JSONDecodeError: |
| return [] |
| return [row for row in payload if isinstance(row, dict)] if isinstance(payload, list) else [] |
|
|
|
|
| def _read_jsonl(path: Path) -> list[dict[str, Any]]: |
| if not path.exists(): |
| return [] |
| rows: list[dict[str, Any]] = [] |
| with path.open("r", encoding="utf-8") as handle: |
| for line in handle: |
| line = line.strip() |
| if not line: |
| continue |
| try: |
| payload = json.loads(line) |
| except json.JSONDecodeError: |
| continue |
| if isinstance(payload, dict): |
| rows.append(payload) |
| return rows |
|
|
|
|
| def _as_float(value: Any, default: float = 0.0) -> float: |
| try: |
| return float(value) |
| except (TypeError, ValueError): |
| return default |
|
|
|
|
def _is_reward_value(value: Any) -> bool:
    """Return whether *value* is an acceptable logged reward.

    Accepted values are real numbers (``bool`` is rejected even though it
    subclasses ``int``) that lie inside ``[REWARD_MIN, REWARD_MAX]`` and are
    already rounded to at most three decimal places.
    """
    is_number = isinstance(value, int | float) and not isinstance(value, bool)
    if not is_number:
        return False
    as_float = float(value)
    within_bounds = REWARD_MIN <= as_float <= REWARD_MAX
    return within_bounds and round(as_float, 3) == as_float
|
|
|
|
def _scan_reward_payload(payload: Any, failures: list[str], path: str) -> None:
    """Recursively collect reward values in *payload* that break the reward contract.

    Dicts and lists are walked depth-first. Within a dict:

    * ``reward``, ``env_reward``, ``avg_reward``, ``avg_env_reward`` and any
      key ending in ``_score`` must hold a value ``_is_reward_value`` accepts;
    * ``reward_breakdown``, ``primary_reward_channels``,
      ``avg_reward_components`` and ``avg_primary_reward_channels`` must hold
      a dict whose every value ``_is_reward_value`` accepts;
    * every other value is scanned recursively.

    Violations are appended to *failures* as ``"<location>=<repr>"`` (or
    ``"<location>=not_dict"`` for a non-dict breakdown). *path* prefixes every
    location; dict keys join with ``.`` and list items appear as ``[i]``.
    """
    if isinstance(payload, list):
        for position, element in enumerate(payload):
            _scan_reward_payload(element, failures, f"{path}[{position}]")
        return
    if not isinstance(payload, dict):
        return

    scalar_keys = {"reward", "env_reward", "avg_reward", "avg_env_reward"}
    mapping_keys = {
        "reward_breakdown",
        "primary_reward_channels",
        "avg_reward_components",
        "avg_primary_reward_channels",
    }
    for field, child in payload.items():
        child_path = str(field) if not path else f"{path}.{field}"
        if field in scalar_keys or field.endswith("_score"):
            if not _is_reward_value(child):
                failures.append(f"{child_path}={child!r}")
            continue
        if field in mapping_keys:
            if not isinstance(child, dict):
                failures.append(f"{child_path}=not_dict")
                continue
            failures.extend(
                f"{child_path}.{inner_key}={inner_value!r}"
                for inner_key, inner_value in child.items()
                if not _is_reward_value(inner_value)
            )
            continue
        _scan_reward_payload(child, failures, child_path)
|
|
|
|
def _history_series(history: list[dict[str, Any]], names: tuple[str, ...]) -> tuple[list[int], list[float]]:
    """Extract one numeric series from a list of logged history rows.

    Rows are numbered from 1 in list order. For each row, the first key in
    *names* that the row contains supplies the value (coerced via
    ``_as_float``); rows containing none of the keys are skipped.

    Returns:
        ``(steps, values)``: parallel lists of 1-based row positions and values.
    """
    steps: list[int] = []
    values: list[float] = []
    for step, record in enumerate(history, start=1):
        key = next((candidate for candidate in names if candidate in record), None)
        if key is None:
            continue
        steps.append(step)
        values.append(_as_float(record.get(key)))
    return steps, values
|
|
|
|
def _plot_placeholder(path: Path, title: str) -> None:
    """Write a text-only "No completed sweep data yet" chart to *path*.

    Parent directories are created as needed. The figure is closed in a
    ``finally`` block so a drawing or saving error does not leave an open
    pyplot figure behind.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    fig, ax = plt.subplots(figsize=(8, 4.5))
    try:
        ax.text(0.5, 0.5, "No completed sweep data yet", ha="center", va="center", fontsize=12)
        ax.set_axis_off()
        ax.set_title(title)
        fig.tight_layout()
        fig.savefig(path, dpi=160)
    finally:
        plt.close(fig)
|
|
|
|
def _bar_chart(path: Path, title: str, labels: list[str], series: dict[str, list[float]], ylabel: str = "Reward") -> None:
    """Render a grouped bar chart and save it to *path* (dpi=160).

    Args:
        path: Output image path; parent directories are created as needed.
        title: Chart title.
        labels: One x-axis category per bar group.
        series: Legend name -> one value per entry of *labels*; the series are
            drawn side by side inside each group.
        ylabel: Y-axis label.

    With no labels or no series, a placeholder chart is written instead. The
    y-axis starts at 0 and reaches at least 1.0, or 15% above the largest
    plotted value. The figure is always closed, even if drawing or saving
    raises.
    """
    if not labels or not series:
        _plot_placeholder(path, title)
        return
    path.parent.mkdir(parents=True, exist_ok=True)
    fig, ax = plt.subplots(figsize=(10, 5.2))
    try:
        # All series share 80% of each unit-wide group slot, centred on the tick.
        width = 0.8 / max(1, len(series))
        x_positions = list(range(len(labels)))
        offsets = [(-0.4 + (idx + 0.5) * width) for idx in range(len(series))]
        for offset, (name, values) in zip(offsets, series.items(), strict=False):
            ax.bar([x + offset for x in x_positions], values, width=width, label=name)
        ax.set_title(title)
        ax.set_ylabel(ylabel)
        ax.set_xticks(x_positions)
        ax.set_xticklabels(labels, rotation=20, ha="right")
        peak = max((max(vals) for vals in series.values() if vals), default=1.0)
        ax.set_ylim(0, max(1.0, peak * 1.15))
        ax.legend()
        ax.grid(axis="y", alpha=0.24)
        fig.tight_layout()
        fig.savefig(path, dpi=160)
    finally:
        plt.close(fig)
|
|
|
|
def _line_chart(
    path: Path,
    title: str,
    curves: dict[str, tuple[list[int], list[float]]],
    ylabel: str,
) -> None:
    """Plot one line per curve (x = logged step, y = value) and save to *path*.

    Curves whose x or y list is empty are dropped first; if none remain, a
    placeholder chart is written instead. Parent directories are created as
    needed, and the figure is always closed, even if drawing or saving raises.
    """
    non_empty = {label: points for label, points in curves.items() if points[0] and points[1]}
    if not non_empty:
        _plot_placeholder(path, title)
        return
    path.parent.mkdir(parents=True, exist_ok=True)
    fig, ax = plt.subplots(figsize=(10, 5.2))
    try:
        for label, (xs, ys) in non_empty.items():
            ax.plot(xs, ys, marker="o", linewidth=1.6, markersize=3.5, label=label)
        ax.set_title(title)
        ax.set_xlabel("Logged step")
        ax.set_ylabel(ylabel)
        ax.grid(alpha=0.24)
        ax.legend()
        fig.tight_layout()
        fig.savefig(path, dpi=160)
    finally:
        plt.close(fig)
|
|
|
|
| def _safe_model_label(model_id: str, fallback: str) -> str: |
| if model_id: |
| return model_id.split("/")[-1].replace("-Instruct", "") |
| return fallback |
|
|
|
|
| def _bandit_chart_label(label: str) -> str: |
| if "bandit" in label.lower(): |
| return label |
| if "qwen" in label.lower(): |
| return f"{label} + Bandits" |
| return label |
|
|
|
|
def _detect_fallback(payloads: list[dict[str, Any]]) -> bool:
    """Return whether any artifact reports a fallback backend or a fallback-policy model source."""
    return any(
        "fallback" in str(payload.get("backend", "")).lower()
        or str(payload.get("model_source", "")).lower() == "fallback_policy"
        for payload in payloads
    )


def _reward_log_stats(reward_rows: list[dict[str, Any]]) -> dict[str, float]:
    """Summarize GRPO reward-log rows into rate metrics, each rounded to 3 decimals.

    Returns:
        exploit_rate: share of rows whose ``termination_reason`` contains a
            cheat/exploit/abuse/timeout/invalid marker (0.0 with no rows).
        legal_rate: share of rows with ``legal is True`` (0.0 with no rows).
        candidate_diversity: distinct selected candidate ids divided by rows
            that name a candidate (0.0 when none do).
        top_candidate_rate: share of those rows taken by the most frequent id
            (0.0 when none do).
    """
    total = len(reward_rows)
    markers = ("cheat", "exploit", "abuse", "timeout", "invalid")
    exploit_count = sum(
        1
        for row in reward_rows
        if any(marker in str(row.get("termination_reason", "")).lower() for marker in markers)
    )
    legal_count = sum(1 for row in reward_rows if row.get("legal") is True)
    selected = [
        str(row.get("selected_candidate_id") or row.get("generated_candidate_id") or "")
        for row in reward_rows
    ]
    selected = [item for item in selected if item]
    counts = Counter(selected)
    top_candidate_rate = (max(counts.values()) / len(selected)) if selected else 0.0
    candidate_diversity = (len(counts) / len(selected)) if selected else 0.0
    return {
        "exploit_rate": round(exploit_count / total, 3) if total else 0.0,
        "legal_rate": round(legal_count / total, 3) if total else 0.0,
        "candidate_diversity": round(candidate_diversity, 3),
        "top_candidate_rate": round(top_candidate_rate, 3),
    }


def _summarize_run(run_dir: Path, *, mode: str) -> dict[str, Any]:
    """Summarize one sweep run directory into a flat, JSON-serializable row.

    Reads the run's SFT/GRPO run records, post-save inference results,
    ``error.json``, training histories and the GRPO reward-component log
    through the file's tolerant readers (missing files read as empty).

    Args:
        run_dir: Directory holding one run's artifacts; its name is the run id.
        mode: ``"sft-baseline"`` excludes the GRPO run and GRPO inference
            records from the fallback and reward-range checks (the reward
            component log is still scanned) and reports SFT inference metrics
            in the ``grpo_avg_reward`` / ``grpo_inference_reward`` /
            ``grpo_valid_rate`` fields (so ``train_holdout_gap`` is 0.0).
            Any other value means full SFT + GRPO mode.

    Returns:
        Row dict. ``status`` is ``"failed"`` when ``error.json`` has content,
        ``"completed"`` when the required run records report ``status == "ok"``
        (SFT only in baseline mode, SFT and GRPO otherwise), else
        ``"incomplete"``. ``sft_history`` / ``grpo_history`` hold raw rows.
    """
    sft = _read_json(run_dir / "sft_trl_run.json")
    grpo = _read_json(run_dir / "grpo_trl_run.json")
    sft_inference = _read_json(run_dir / "postsave_inference_sft.json")
    grpo_inference = _read_json(run_dir / "postsave_inference_grpo.json")
    error = _read_json(run_dir / "error.json")
    sft_history = _read_history(run_dir / "sft_history.json")
    grpo_history = _read_history(run_dir / "grpo_history.json")
    reward_rows = _read_jsonl(run_dir / "grpo_reward_components.jsonl")

    sft_only = mode == "sft-baseline"
    model_id = str(grpo.get("model_id") or sft.get("model_id") or error.get("model_id") or run_dir.name)
    fallback_detected = _detect_fallback(
        [sft, sft_inference] if sft_only else [sft, grpo, sft_inference, grpo_inference]
    )

    # Scan order determines failure order, and only the first 25 are kept.
    scanned: list[tuple[str, dict[str, Any]]] = (
        [("sft_inference", sft_inference)]
        if sft_only
        else [("grpo", grpo), ("sft_inference", sft_inference), ("grpo_inference", grpo_inference)]
    )
    reward_failures: list[str] = []
    for prefix, payload in scanned:
        _scan_reward_payload(payload, reward_failures, prefix)
    for idx, row in enumerate(reward_rows):
        _scan_reward_payload(row, reward_failures, f"reward_log[{idx}]")

    stats = _reward_log_stats(reward_rows)

    # A malformed (non-dict) reward_summary is treated as absent instead of
    # crashing on .get().
    reward_summary = grpo.get("reward_summary")
    if not isinstance(reward_summary, dict):
        reward_summary = {}

    train_reward = _as_float(reward_summary.get("avg_reward"))
    if sft_only:
        holdout_reward = _as_float(sft_inference.get("avg_env_reward"))
        train_reward = holdout_reward
    else:
        # Without a GRPO holdout reward, fall back to the train reward (gap 0).
        holdout_reward = _as_float(grpo_inference.get("avg_env_reward"), train_reward)
    train_holdout_gap = round(train_reward - holdout_reward, 3)
    validity = _as_float(sft_inference.get("valid_rate") if sft_only else grpo_inference.get("valid_rate"), 0.0)

    if sft_only:
        completed = sft.get("status") == "ok"
    else:
        completed = sft.get("status") == "ok" and grpo.get("status") == "ok"

    return {
        "run_id": run_dir.name,
        "training_mode": mode,
        "model_id": model_id,
        "label": _safe_model_label(model_id, run_dir.name),
        "status": "failed" if error else ("completed" if completed else "incomplete"),
        "error": error.get("error", ""),
        "sft_backend": sft.get("backend", ""),
        "sft_examples": int(sft.get("examples_used", 0) or 0),
        "sft_train_loss": _as_float(sft.get("train_loss")),
        "sft_runtime": _as_float(sft.get("train_runtime")),
        "grpo_backend": grpo.get("backend", ""),
        "grpo_records": int(grpo.get("records", 0) or 0),
        "grpo_avg_reward": train_reward,
        "sft_inference_reward": _as_float(sft_inference.get("avg_env_reward")),
        "sft_valid_rate": _as_float(sft_inference.get("valid_rate")),
        "sft_latency_seconds": _as_float(sft_inference.get("avg_latency_seconds")),
        "grpo_inference_reward": holdout_reward,
        "grpo_valid_rate": validity,
        "grpo_latency_seconds": _as_float(grpo_inference.get("avg_latency_seconds")),
        "train_holdout_gap": train_holdout_gap,
        "fallback_detected": fallback_detected,
        "reward_range_ok": not reward_failures,
        "reward_range_failures": reward_failures[:25],
        "exploit_rate": stats["exploit_rate"],
        "legal_rate": stats["legal_rate"],
        "candidate_diversity": stats["candidate_diversity"],
        "top_candidate_rate": stats["top_candidate_rate"],
        "reward_components": reward_summary.get("avg_reward_components", {}),
        "primary_reward_channels": reward_summary.get("avg_primary_reward_channels", {}),
        "sft_history": sft_history,
        "grpo_history": grpo_history,
        "artifact_paths": {
            "sft": sft.get("artifact_path", ""),
            "grpo": grpo.get("artifact_path", ""),
        },
    }
|
|
|
|
def _write_charts(rows: list[dict[str, Any]], plot_dir: Path, *, mode: str) -> dict[str, str]:
    """Render every sweep chart into *plot_dir* and return an index of their paths.

    Only rows whose ``status`` is ``"completed"`` are plotted. Every chart key
    below is always written (as a placeholder when there is nothing to plot),
    so consumers can rely on a fixed set of files.

    In ``"sft-baseline"`` mode the per-run ``grpo_avg_reward``,
    ``grpo_inference_reward`` and ``grpo_valid_rate`` fields mirror SFT
    inference and ``train_holdout_gap`` is 0.0 (that is how the run summary is
    built in that mode). Charts built on those fields are therefore reduced to
    their SFT series or replaced by a titled placeholder, instead of showing
    SFT numbers under GRPO labels. Charts built on GRPO artifacts themselves
    (history curves, reward components, reward-log rates) are drawn in both
    modes and fall back to placeholders when those artifacts are absent.

    Returns:
        Chart key -> path string, relative to ``ROOT`` when the file is under it.
    """
    sft_only = mode == "sft-baseline"
    completed = [row for row in rows if row["status"] == "completed"]
    labels = [_bandit_chart_label(str(row["label"])) for row in completed]
    charts = {
        "sft_vs_grpo_reward": plot_dir / "sft_vs_grpo_reward.png",
        "sft_loss_curves": plot_dir / "sft_loss_curves.png",
        "qwen_model_sft_reward": plot_dir / "qwen_model_sft_reward.png",
        "qwen_model_sft_loss": plot_dir / "qwen_model_sft_loss.png",
        "sft_validity_reward": plot_dir / "sft_validity_reward.png",
        "grpo_reward_curves": plot_dir / "grpo_reward_curves.png",
        "qwen_model_grpo_reward": plot_dir / "qwen_model_grpo_reward.png",
        "reward_component_bars": plot_dir / "reward_component_bars.png",
        "anti_cheat_failure_rates": plot_dir / "anti_cheat_failure_rates.png",
        "train_holdout_gap": plot_dir / "train_holdout_gap.png",
        "inference_validity_reward": plot_dir / "inference_validity_reward.png",
        "inference_latency_validity": plot_dir / "inference_latency_validity.png",
    }
    skipped_suffix = " (skipped: SFT baseline mode has no GRPO policy)"

    def column(key: str) -> list[Any]:
        """Return field *key* for every completed run, in label order."""
        return [row[key] for row in completed]

    reward_series = {"SFT inference reward": column("sft_inference_reward")}
    if not sft_only:
        reward_series["GRPO + Bandits inference reward"] = column("grpo_inference_reward")
    _bar_chart(
        charts["sft_vs_grpo_reward"],
        "SFT Baseline Policy Reward" if sft_only else "SFT Baseline vs GRPO + Bandits Policy Reward",
        labels,
        reward_series,
    )
    _line_chart(
        charts["sft_loss_curves"],
        "Qwen + Bandits SFT Training Loss Curves",
        {
            _bandit_chart_label(str(row["label"])): _history_series(row["sft_history"], ("loss", "train_loss"))
            for row in completed
        },
        ylabel="Loss",
    )
    _bar_chart(
        charts["qwen_model_sft_reward"],
        "Qwen + Bandits Model Sweep SFT Reward",
        labels,
        {"SFT inference reward": column("sft_inference_reward")},
    )
    _bar_chart(
        charts["qwen_model_sft_loss"],
        "Qwen + Bandits Model Sweep SFT Loss",
        labels,
        {"SFT train loss": column("sft_train_loss")},
        ylabel="Loss",
    )
    _bar_chart(
        charts["sft_validity_reward"],
        "SFT Inference Validity and Reward",
        labels,
        {
            "SFT valid rate": column("sft_valid_rate"),
            "SFT reward": column("sft_inference_reward"),
        },
        ylabel="Rate / reward",
    )
    _line_chart(
        charts["grpo_reward_curves"],
        "GRPO + Bandits Reward Curves",
        {
            _bandit_chart_label(str(row["label"])): _history_series(
                row["grpo_history"],
                ("reward", "rewards/environment_reward_verifier", "mean_reward", "train_reward"),
            )
            for row in completed
        },
        ylabel="Reward",
    )
    grpo_reward_title = "Qwen + Bandits Model Sweep GRPO Reward"
    if sft_only:
        _plot_placeholder(charts["qwen_model_grpo_reward"], grpo_reward_title + skipped_suffix)
    else:
        _bar_chart(
            charts["qwen_model_grpo_reward"],
            grpo_reward_title,
            labels,
            {"GRPO + Bandits train reward": column("grpo_avg_reward")},
        )

    # Average each component only over runs that actually report it; counting
    # a missing component as 0.0 would drag the mean down.
    component_values: dict[str, list[float]] = {}
    for row in completed:
        components = row.get("reward_components")
        if not isinstance(components, dict):
            continue
        for key, value in components.items():
            if isinstance(value, int | float) and not isinstance(value, bool):
                component_values.setdefault(key, []).append(float(value))
    component_names = sorted(component_values)
    component_means = [
        round(sum(component_values[name]) / len(component_values[name]), 3) for name in component_names
    ]
    _bar_chart(
        charts["reward_component_bars"],
        "Mean GRPO Reward Components",
        component_names,
        {"component reward": component_means},
    )
    _bar_chart(
        charts["anti_cheat_failure_rates"],
        "Anti-Cheat and Failure Visibility",
        labels,
        {
            "exploit/invalid rate": column("exploit_rate"),
            "illegal rate": [round(1.0 - row["legal_rate"], 3) for row in completed],
            "candidate collapse": column("top_candidate_rate"),
        },
        ylabel="Rate",
    )
    gap_title = "Train vs Holdout Reward Gap"
    if sft_only:
        _plot_placeholder(charts["train_holdout_gap"], gap_title + skipped_suffix)
    else:
        _bar_chart(
            charts["train_holdout_gap"],
            gap_title,
            labels,
            {"train - holdout": [abs(row["train_holdout_gap"]) for row in completed]},
            ylabel="Absolute reward gap",
        )
    validity_title = "Inference Validity and Reward"
    if sft_only:
        _plot_placeholder(charts["inference_validity_reward"], validity_title + skipped_suffix)
    else:
        _bar_chart(
            charts["inference_validity_reward"],
            validity_title,
            labels,
            {
                "GRPO valid rate": column("grpo_valid_rate"),
                "GRPO holdout reward": column("grpo_inference_reward"),
            },
            ylabel="Rate / reward",
        )
    if sft_only:
        latency_series = {
            "SFT latency sec": column("sft_latency_seconds"),
            "SFT valid rate": column("sft_valid_rate"),
        }
    else:
        latency_series = {
            "SFT latency sec": column("sft_latency_seconds"),
            "GRPO latency sec": column("grpo_latency_seconds"),
            "GRPO valid rate": column("grpo_valid_rate"),
        }
    _bar_chart(
        charts["inference_latency_validity"],
        "Inference Latency and Validity",
        labels,
        latency_series,
        ylabel="Seconds / rate",
    )

    chart_index: dict[str, str] = {}
    for key, path in charts.items():
        try:
            chart_index[key] = str(path.relative_to(ROOT))
        except ValueError:
            # Plot directory outside the repository root: keep the full path.
            chart_index[key] = str(path)
    return chart_index
|
|
|
|
def _collect_warnings(
    completed: list[dict[str, Any]],
    mode: str,
    limits: dict[str, float],
) -> list[str]:
    """Build ``"<label>:<issue>"`` anti-hacking warning codes for completed runs.

    *limits* holds the thresholds under the same keys that the anti-hacking
    report publishes in its ``checks`` section. An empty *completed* list
    yields ``["no_completed_models"]``.
    """
    warnings: list[str] = []
    if not completed:
        warnings.append("no_completed_models")
    for row in completed:
        label = row["label"]
        if row["fallback_detected"]:
            warnings.append(f"{label}:fallback_detected")
        if not row["reward_range_ok"]:
            warnings.append(f"{label}:reward_range_violation")
        if mode == "sft-baseline":
            if row["sft_valid_rate"] < limits["min_validity_rate"]:
                warnings.append(f"{label}:low_sft_validity")
            continue
        if row["exploit_rate"] > limits["exploit_rate_threshold"]:
            warnings.append(f"{label}:high_exploit_rate")
        if (
            row["top_candidate_rate"] > limits["candidate_collapse_top_rate_threshold"]
            and row["candidate_diversity"] < limits["candidate_collapse_diversity_threshold"]
        ):
            warnings.append(f"{label}:candidate_collapse")
        if row["grpo_valid_rate"] < limits["min_validity_rate"]:
            warnings.append(f"{label}:low_validity")
        if abs(row["train_holdout_gap"]) > limits["train_holdout_gap_threshold"]:
            warnings.append(f"{label}:large_train_holdout_gap")
    return warnings


def generate_report(
    sweep_dir: Path,
    plot_dir: Path,
    output_path: Path,
    anti_hacking_output: Path,
    mode: str = "full",
    *,
    exploit_rate_threshold: float = 0.35,
    train_holdout_gap_threshold: float = 0.25,
    min_validity_rate: float = 0.8,
    collapse_top_rate_threshold: float = 0.85,
    collapse_diversity_threshold: float = 0.2,
) -> tuple[dict[str, Any], dict[str, Any]]:
    """Summarize every run under *sweep_dir*, draw charts, and write both reports.

    Each sub-directory of *sweep_dir* is one run (a missing or non-directory
    *sweep_dir* means no runs). *mode* aliases ``sft``, ``sft-only``,
    ``sft_only``, ``sft_baseline`` and ``sft-baseline`` select SFT-baseline
    mode; anything else means full mode. The keyword-only thresholds drive the
    anti-hacking warnings and are published verbatim in the report's
    ``checks`` section, so the applied and reported values cannot diverge.

    Returns:
        ``(summary, anti_hacking)``: the dicts also written as JSON to
        *output_path* and *anti_hacking_output* (parent dirs are created).
        ``anti_hacking["passed"]`` is true only when at least one run completed
        and no warnings were raised.
    """
    run_dirs = sorted(path for path in sweep_dir.iterdir() if path.is_dir()) if sweep_dir.is_dir() else []
    mode = "sft-baseline" if mode in {"sft", "sft-only", "sft_only", "sft_baseline", "sft-baseline"} else "full"
    rows = [_summarize_run(run_dir, mode=mode) for run_dir in run_dirs]
    chart_paths = _write_charts(rows, plot_dir, mode=mode)
    completed = [row for row in rows if row["status"] == "completed"]
    failed = [row for row in rows if row["status"] == "failed"]
    # Runs with neither success records nor an error.json would otherwise be
    # invisible in both reports' counts.
    incomplete = [row for row in rows if row["status"] == "incomplete"]

    limits = {
        "exploit_rate_threshold": exploit_rate_threshold,
        "train_holdout_gap_threshold": train_holdout_gap_threshold,
        "min_validity_rate": min_validity_rate,
        "candidate_collapse_top_rate_threshold": collapse_top_rate_threshold,
        "candidate_collapse_diversity_threshold": collapse_diversity_threshold,
    }
    warnings = _collect_warnings(completed, mode, limits)

    # Per-step histories can be long; they feed the charts but stay out of the JSON.
    public_rows = [
        {key: value for key, value in row.items() if key not in {"sft_history", "grpo_history"}}
        for row in rows
    ]
    summary = {
        "status": "ok" if completed else "incomplete",
        "training_mode": mode,
        "completed_models": len(completed),
        "failed_or_skipped_models": len(failed),
        "incomplete_models": len(incomplete),
        "models": public_rows,
        "charts": chart_paths,
    }
    anti_hacking = {
        "passed": bool(completed) and not warnings,
        "training_mode": mode,
        "warnings": warnings,
        "completed_models": [row["model_id"] for row in completed],
        "failed_or_skipped_models": [{"model_id": row["model_id"], "error": row["error"]} for row in failed],
        "incomplete_models": [row["model_id"] for row in incomplete],
        "checks": {
            "reward_bounds": [REWARD_MIN, REWARD_MAX],
            "reward_precision": 3,
            "fallback_backends_rejected": True,
            **limits,
        },
    }

    output_path.parent.mkdir(parents=True, exist_ok=True)
    anti_hacking_output.parent.mkdir(parents=True, exist_ok=True)
    output_path.write_text(json.dumps(summary, ensure_ascii=True, indent=2), encoding="utf-8")
    anti_hacking_output.write_text(json.dumps(anti_hacking, ensure_ascii=True, indent=2), encoding="utf-8")
    return summary, anti_hacking
|
|
|
|
def main() -> None:
    """Command-line entry point.

    Every path option is joined onto ``ROOT`` before the reports are built,
    then a compact one-line JSON status is printed to stdout.
    """
    options = parse_args()
    summary, anti_hacking = generate_report(
        ROOT / options.sweep_dir,
        ROOT / options.plot_dir,
        ROOT / options.output,
        ROOT / options.anti_hacking_output,
        mode=options.mode,
    )
    status_line = {
        "hf_sweep_summary": summary.get("status"),
        "completed_models": summary.get("completed_models"),
        "anti_hacking_passed": anti_hacking.get("passed"),
    }
    print(json.dumps(status_line, ensure_ascii=True))


if __name__ == "__main__":
    main()
|
|