#!/usr/bin/env python3 """Generate sweep summaries, charts, and anti-hacking checks for HF training.""" from __future__ import annotations import argparse from collections import Counter import json from pathlib import Path from typing import Any import matplotlib matplotlib.use("Agg") import matplotlib.pyplot as plt # noqa: E402 ROOT = Path(__file__).resolve().parents[1] REWARD_MIN = 0.001 REWARD_MAX = 0.999 def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser(description="Summarize PolyGuard HF training sweeps.") parser.add_argument("--sweep-dir", default="outputs/reports/sweeps") parser.add_argument("--plot-dir", default="outputs/plots") parser.add_argument("--output", default="outputs/reports/hf_sweep_summary.json") parser.add_argument("--anti-hacking-output", default="outputs/reports/anti_hacking_overfit_report.json") parser.add_argument( "--mode", choices=["full", "sft-baseline"], default="full", help="Report mode. SFT baseline mode treats GRPO artifacts as optional.", ) return parser.parse_args() def _read_json(path: Path) -> dict[str, Any]: if not path.exists(): return {} try: payload = json.loads(path.read_text(encoding="utf-8")) except json.JSONDecodeError: return {} return payload if isinstance(payload, dict) else {} def _read_history(path: Path) -> list[dict[str, Any]]: if not path.exists(): return [] try: payload = json.loads(path.read_text(encoding="utf-8")) except json.JSONDecodeError: return [] return [row for row in payload if isinstance(row, dict)] if isinstance(payload, list) else [] def _read_jsonl(path: Path) -> list[dict[str, Any]]: if not path.exists(): return [] rows: list[dict[str, Any]] = [] with path.open("r", encoding="utf-8") as handle: for line in handle: line = line.strip() if not line: continue try: payload = json.loads(line) except json.JSONDecodeError: continue if isinstance(payload, dict): rows.append(payload) return rows def _as_float(value: Any, default: float = 0.0) -> float: try: return float(value) except (TypeError, ValueError): return default def _is_reward_value(value: Any) -> bool: if isinstance(value, bool) or not isinstance(value, int | float): return False number = float(value) return REWARD_MIN <= number <= REWARD_MAX and round(number, 3) == number def _scan_reward_payload(payload: Any, failures: list[str], path: str) -> None: if isinstance(payload, dict): for key, value in payload.items(): next_path = f"{path}.{key}" if path else str(key) if key in {"reward", "env_reward", "avg_reward", "avg_env_reward"} or key.endswith("_score"): if not _is_reward_value(value): failures.append(f"{next_path}={value!r}") elif key in {"reward_breakdown", "primary_reward_channels", "avg_reward_components", "avg_primary_reward_channels"}: if isinstance(value, dict): for sub_key, sub_value in value.items(): if not _is_reward_value(sub_value): failures.append(f"{next_path}.{sub_key}={sub_value!r}") else: failures.append(f"{next_path}=not_dict") else: _scan_reward_payload(value, failures, next_path) elif isinstance(payload, list): for idx, item in enumerate(payload): _scan_reward_payload(item, failures, f"{path}[{idx}]") def _history_series(history: list[dict[str, Any]], names: tuple[str, ...]) -> tuple[list[int], list[float]]: xs: list[int] = [] ys: list[float] = [] for idx, row in enumerate(history, start=1): for name in names: if name in row: xs.append(idx) ys.append(_as_float(row.get(name))) break return xs, ys def _plot_placeholder(path: Path, title: str) -> None: path.parent.mkdir(parents=True, exist_ok=True) fig, ax = plt.subplots(figsize=(8, 4.5)) ax.text(0.5, 0.5, "No completed sweep data yet", ha="center", va="center", fontsize=12) ax.set_axis_off() ax.set_title(title) fig.tight_layout() fig.savefig(path, dpi=160) plt.close(fig) def _bar_chart(path: Path, title: str, labels: list[str], series: dict[str, list[float]], ylabel: str = "Reward") -> None: if not labels or not series: _plot_placeholder(path, title) return path.parent.mkdir(parents=True, exist_ok=True) fig, ax = plt.subplots(figsize=(10, 5.2)) width = 0.8 / max(1, len(series)) x_positions = list(range(len(labels))) offsets = [(-0.4 + (idx + 0.5) * width) for idx in range(len(series))] for offset, (name, values) in zip(offsets, series.items(), strict=False): ax.bar([x + offset for x in x_positions], values, width=width, label=name) ax.set_title(title) ax.set_ylabel(ylabel) ax.set_xticks(x_positions) ax.set_xticklabels(labels, rotation=20, ha="right") ax.set_ylim(0, max(1.0, max((max(vals) for vals in series.values() if vals), default=1.0) * 1.15)) ax.legend() ax.grid(axis="y", alpha=0.24) fig.tight_layout() fig.savefig(path, dpi=160) plt.close(fig) def _line_chart( path: Path, title: str, curves: dict[str, tuple[list[int], list[float]]], ylabel: str, ) -> None: curves = {key: value for key, value in curves.items() if value[0] and value[1]} if not curves: _plot_placeholder(path, title) return path.parent.mkdir(parents=True, exist_ok=True) fig, ax = plt.subplots(figsize=(10, 5.2)) for label, (xs, ys) in curves.items(): ax.plot(xs, ys, marker="o", linewidth=1.6, markersize=3.5, label=label) ax.set_title(title) ax.set_xlabel("Logged step") ax.set_ylabel(ylabel) ax.grid(alpha=0.24) ax.legend() fig.tight_layout() fig.savefig(path, dpi=160) plt.close(fig) def _safe_model_label(model_id: str, fallback: str) -> str: if model_id: return model_id.split("/")[-1].replace("-Instruct", "") return fallback def _bandit_chart_label(label: str) -> str: if "bandit" in label.lower(): return label if "qwen" in label.lower(): return f"{label} + Bandits" return label def _summarize_run(run_dir: Path, *, mode: str) -> dict[str, Any]: sft = _read_json(run_dir / "sft_trl_run.json") grpo = _read_json(run_dir / "grpo_trl_run.json") sft_inference = _read_json(run_dir / "postsave_inference_sft.json") grpo_inference = _read_json(run_dir / "postsave_inference_grpo.json") error = _read_json(run_dir / "error.json") sft_history = _read_history(run_dir / "sft_history.json") grpo_history = _read_history(run_dir / "grpo_history.json") reward_rows = _read_jsonl(run_dir / "grpo_reward_components.jsonl") sft_only = mode == "sft-baseline" model_id = str(grpo.get("model_id") or sft.get("model_id") or error.get("model_id") or run_dir.name) fallback_detected = any( "fallback" in str(payload.get("backend", "")).lower() or str(payload.get("model_source", "")).lower() == "fallback_policy" for payload in ([sft, sft_inference] if sft_only else [sft, grpo, sft_inference, grpo_inference]) ) reward_failures: list[str] = [] if not sft_only: _scan_reward_payload(grpo, reward_failures, "grpo") _scan_reward_payload(sft_inference, reward_failures, "sft_inference") if not sft_only: _scan_reward_payload(grpo_inference, reward_failures, "grpo_inference") for idx, row in enumerate(reward_rows): _scan_reward_payload(row, reward_failures, f"reward_log[{idx}]") legal_count = sum(1 for row in reward_rows if row.get("legal") is True) reward_count = len(reward_rows) exploit_count = sum( 1 for row in reward_rows if any( marker in str(row.get("termination_reason", "")).lower() for marker in ["cheat", "exploit", "abuse", "timeout", "invalid"] ) ) selected = [str(row.get("selected_candidate_id") or row.get("generated_candidate_id") or "") for row in reward_rows] selected = [item for item in selected if item] counts = Counter(selected) top_candidate_rate = (max(counts.values()) / len(selected)) if selected else 0.0 candidate_diversity = (len(counts) / len(selected)) if selected else 0.0 train_reward = _as_float((grpo.get("reward_summary") or {}).get("avg_reward")) if sft_only: holdout_reward = _as_float(sft_inference.get("avg_env_reward")) train_reward = holdout_reward else: holdout_reward = _as_float(grpo_inference.get("avg_env_reward"), train_reward) train_holdout_gap = round(train_reward - holdout_reward, 3) validity = _as_float(sft_inference.get("valid_rate") if sft_only else grpo_inference.get("valid_rate"), 0.0) completed = sft.get("status") == "ok" if sft_only else sft.get("status") == "ok" and grpo.get("status") == "ok" return { "run_id": run_dir.name, "training_mode": mode, "model_id": model_id, "label": _safe_model_label(model_id, run_dir.name), "status": "failed" if error else ("completed" if completed else "incomplete"), "error": error.get("error", ""), "sft_backend": sft.get("backend", ""), "sft_examples": int(sft.get("examples_used", 0) or 0), "sft_train_loss": _as_float(sft.get("train_loss")), "sft_runtime": _as_float(sft.get("train_runtime")), "grpo_backend": grpo.get("backend", ""), "grpo_records": int(grpo.get("records", 0) or 0), "grpo_avg_reward": train_reward, "sft_inference_reward": _as_float(sft_inference.get("avg_env_reward")), "sft_valid_rate": _as_float(sft_inference.get("valid_rate")), "sft_latency_seconds": _as_float(sft_inference.get("avg_latency_seconds")), "grpo_inference_reward": holdout_reward, "grpo_valid_rate": validity, "grpo_latency_seconds": _as_float(grpo_inference.get("avg_latency_seconds")), "train_holdout_gap": train_holdout_gap, "fallback_detected": fallback_detected, "reward_range_ok": not reward_failures, "reward_range_failures": reward_failures[:25], "exploit_rate": round(exploit_count / reward_count, 3) if reward_count else 0.0, "legal_rate": round(legal_count / reward_count, 3) if reward_count else 0.0, "candidate_diversity": round(candidate_diversity, 3), "top_candidate_rate": round(top_candidate_rate, 3), "reward_components": (grpo.get("reward_summary") or {}).get("avg_reward_components", {}), "primary_reward_channels": (grpo.get("reward_summary") or {}).get("avg_primary_reward_channels", {}), "sft_history": sft_history, "grpo_history": grpo_history, "artifact_paths": { "sft": sft.get("artifact_path", ""), "grpo": grpo.get("artifact_path", ""), }, } def _write_charts(rows: list[dict[str, Any]], plot_dir: Path, *, mode: str) -> dict[str, str]: completed = [row for row in rows if row["status"] == "completed"] labels = [_bandit_chart_label(str(row["label"])) for row in completed] charts = { "sft_vs_grpo_reward": plot_dir / "sft_vs_grpo_reward.png", "sft_loss_curves": plot_dir / "sft_loss_curves.png", "qwen_model_sft_reward": plot_dir / "qwen_model_sft_reward.png", "qwen_model_sft_loss": plot_dir / "qwen_model_sft_loss.png", "sft_validity_reward": plot_dir / "sft_validity_reward.png", "grpo_reward_curves": plot_dir / "grpo_reward_curves.png", "qwen_model_grpo_reward": plot_dir / "qwen_model_grpo_reward.png", "reward_component_bars": plot_dir / "reward_component_bars.png", "anti_cheat_failure_rates": plot_dir / "anti_cheat_failure_rates.png", "train_holdout_gap": plot_dir / "train_holdout_gap.png", "inference_validity_reward": plot_dir / "inference_validity_reward.png", "inference_latency_validity": plot_dir / "inference_latency_validity.png", } _bar_chart( charts["sft_vs_grpo_reward"], "SFT Baseline vs GRPO + Bandits Policy Reward", labels, { "SFT inference reward": [row["sft_inference_reward"] for row in completed], "GRPO + Bandits inference reward": [row["grpo_inference_reward"] for row in completed], }, ) _line_chart( charts["sft_loss_curves"], "Qwen + Bandits SFT Training Loss Curves", { _bandit_chart_label(str(row["label"])): _history_series(row["sft_history"], ("loss", "train_loss")) for row in completed }, ylabel="Loss", ) _bar_chart( charts["qwen_model_sft_reward"], "Qwen + Bandits Model Sweep SFT Reward", labels, {"SFT inference reward": [row["sft_inference_reward"] for row in completed]}, ) _bar_chart( charts["qwen_model_sft_loss"], "Qwen + Bandits Model Sweep SFT Loss", labels, {"SFT train loss": [row["sft_train_loss"] for row in completed]}, ylabel="Loss", ) _bar_chart( charts["sft_validity_reward"], "SFT Inference Validity and Reward", labels, { "SFT valid rate": [row["sft_valid_rate"] for row in completed], "SFT reward": [row["sft_inference_reward"] for row in completed], }, ylabel="Rate / reward", ) _line_chart( charts["grpo_reward_curves"], "GRPO + Bandits Reward Curves", { _bandit_chart_label(str(row["label"])): _history_series( row["grpo_history"], ("reward", "rewards/environment_reward_verifier", "mean_reward", "train_reward"), ) for row in completed }, ylabel="Reward", ) _bar_chart( charts["qwen_model_grpo_reward"], "Qwen + Bandits Model Sweep GRPO Reward", labels, {"GRPO + Bandits train reward": [row["grpo_avg_reward"] for row in completed]}, ) component_names = sorted( { key for row in completed for key, value in dict(row.get("reward_components") or {}).items() if isinstance(value, int | float) } ) component_means = [] for key in component_names: values = [_as_float((row.get("reward_components") or {}).get(key)) for row in completed] component_means.append(round(sum(values) / len(values), 3) if values else 0.0) _bar_chart( charts["reward_component_bars"], "Mean GRPO Reward Components", component_names, {"component reward": component_means}, ) _bar_chart( charts["anti_cheat_failure_rates"], "Anti-Cheat and Failure Visibility", labels, { "exploit/invalid rate": [row["exploit_rate"] for row in completed], "illegal rate": [round(1.0 - row["legal_rate"], 3) for row in completed], "candidate collapse": [row["top_candidate_rate"] for row in completed], }, ylabel="Rate", ) _bar_chart( charts["train_holdout_gap"], "Train vs Holdout Reward Gap", labels, {"train - holdout": [abs(row["train_holdout_gap"]) for row in completed]}, ylabel="Absolute reward gap", ) _bar_chart( charts["inference_validity_reward"], "Inference Validity and Reward", labels, { "GRPO valid rate": [row["grpo_valid_rate"] for row in completed], "GRPO holdout reward": [row["grpo_inference_reward"] for row in completed], }, ylabel="Rate / reward", ) _bar_chart( charts["inference_latency_validity"], "Inference Latency and Validity", labels, { "SFT latency sec": [row["sft_latency_seconds"] for row in completed], "GRPO latency sec": [row["grpo_latency_seconds"] for row in completed], "GRPO valid rate": [row["grpo_valid_rate"] for row in completed], }, ylabel="Seconds / rate", ) chart_index: dict[str, str] = {} for key, path in charts.items(): try: chart_index[key] = str(path.relative_to(ROOT)) except ValueError: chart_index[key] = str(path) return chart_index def generate_report( sweep_dir: Path, plot_dir: Path, output_path: Path, anti_hacking_output: Path, mode: str = "full", ) -> tuple[dict[str, Any], dict[str, Any]]: run_dirs = sorted(path for path in sweep_dir.iterdir() if path.is_dir()) if sweep_dir.exists() else [] mode = "sft-baseline" if mode in {"sft", "sft-only", "sft_baseline", "sft-baseline"} else "full" rows = [_summarize_run(run_dir, mode=mode) for run_dir in run_dirs] chart_paths = _write_charts(rows, plot_dir, mode=mode) completed = [row for row in rows if row["status"] == "completed"] failed = [row for row in rows if row["status"] == "failed"] warnings: list[str] = [] if not completed: warnings.append("no_completed_models") for row in completed: if row["fallback_detected"]: warnings.append(f"{row['label']}:fallback_detected") if not row["reward_range_ok"]: warnings.append(f"{row['label']}:reward_range_violation") if mode == "sft-baseline": if row["sft_valid_rate"] < 0.8: warnings.append(f"{row['label']}:low_sft_validity") else: if row["exploit_rate"] > 0.35: warnings.append(f"{row['label']}:high_exploit_rate") if row["top_candidate_rate"] > 0.85 and row["candidate_diversity"] < 0.2: warnings.append(f"{row['label']}:candidate_collapse") if row["grpo_valid_rate"] < 0.8: warnings.append(f"{row['label']}:low_validity") if abs(row["train_holdout_gap"]) > 0.25: warnings.append(f"{row['label']}:large_train_holdout_gap") public_rows = [ {key: value for key, value in row.items() if key not in {"sft_history", "grpo_history"}} for row in rows ] summary = { "status": "ok" if completed else "incomplete", "training_mode": mode, "completed_models": len(completed), "failed_or_skipped_models": len(failed), "models": public_rows, "charts": chart_paths, } anti_hacking = { "passed": bool(completed) and not warnings, "training_mode": mode, "warnings": warnings, "completed_models": [row["model_id"] for row in completed], "failed_or_skipped_models": [{"model_id": row["model_id"], "error": row["error"]} for row in failed], "checks": { "reward_bounds": [REWARD_MIN, REWARD_MAX], "reward_precision": 3, "fallback_backends_rejected": True, "exploit_rate_threshold": 0.35, "train_holdout_gap_threshold": 0.25, "min_validity_rate": 0.8, }, } output_path.parent.mkdir(parents=True, exist_ok=True) anti_hacking_output.parent.mkdir(parents=True, exist_ok=True) output_path.write_text(json.dumps(summary, ensure_ascii=True, indent=2), encoding="utf-8") anti_hacking_output.write_text(json.dumps(anti_hacking, ensure_ascii=True, indent=2), encoding="utf-8") return summary, anti_hacking def main() -> None: args = parse_args() summary, anti_hacking = generate_report( sweep_dir=ROOT / args.sweep_dir, plot_dir=ROOT / args.plot_dir, output_path=ROOT / args.output, anti_hacking_output=ROOT / args.anti_hacking_output, mode=args.mode, ) print( json.dumps( { "hf_sweep_summary": summary.get("status"), "completed_models": summary.get("completed_models"), "anti_hacking_passed": anti_hacking.get("passed"), }, ensure_ascii=True, ) ) if __name__ == "__main__": main()