# Source: polyguard-openenv-workbench / polyguard-rl / scripts / generate_hf_training_report.py
# Deployed to a Hugging Face Space by TheJackBright ("Deploy GitHub root master to Space", commit c296d62).
#!/usr/bin/env python3
"""Generate sweep summaries, charts, and anti-hacking checks for HF training."""
from __future__ import annotations
import argparse
from collections import Counter
import json
from pathlib import Path
from typing import Any
import matplotlib

# Select the non-interactive Agg backend BEFORE importing pyplot so chart
# rendering works in headless CI / Space environments.
matplotlib.use("Agg")
import matplotlib.pyplot as plt  # noqa: E402

# Repository root (one directory above this script's parent); all CLI-relative
# paths are resolved against it.
ROOT = Path(__file__).resolve().parents[1]
# Inclusive bounds that every logged reward/score must satisfy for the
# anti-hacking reward-range check (see _is_reward_value).
REWARD_MIN = 0.001
REWARD_MAX = 0.999
def parse_args() -> argparse.Namespace:
    """Parse CLI options: sweep input dir, chart/report output paths, and report mode."""
    cli = argparse.ArgumentParser(description="Summarize PolyGuard HF training sweeps.")
    cli.add_argument("--sweep-dir", default="outputs/reports/sweeps")
    cli.add_argument("--plot-dir", default="outputs/plots")
    cli.add_argument("--output", default="outputs/reports/hf_sweep_summary.json")
    cli.add_argument("--anti-hacking-output", default="outputs/reports/anti_hacking_overfit_report.json")
    cli.add_argument(
        "--mode",
        choices=["full", "sft-baseline"],
        default="full",
        help="Report mode. SFT baseline mode treats GRPO artifacts as optional.",
    )
    return cli.parse_args()
def _read_json(path: Path) -> dict[str, Any]:
if not path.exists():
return {}
try:
payload = json.loads(path.read_text(encoding="utf-8"))
except json.JSONDecodeError:
return {}
return payload if isinstance(payload, dict) else {}
def _read_history(path: Path) -> list[dict[str, Any]]:
if not path.exists():
return []
try:
payload = json.loads(path.read_text(encoding="utf-8"))
except json.JSONDecodeError:
return []
return [row for row in payload if isinstance(row, dict)] if isinstance(payload, list) else []
def _read_jsonl(path: Path) -> list[dict[str, Any]]:
if not path.exists():
return []
rows: list[dict[str, Any]] = []
with path.open("r", encoding="utf-8") as handle:
for line in handle:
line = line.strip()
if not line:
continue
try:
payload = json.loads(line)
except json.JSONDecodeError:
continue
if isinstance(payload, dict):
rows.append(payload)
return rows
def _as_float(value: Any, default: float = 0.0) -> float:
try:
return float(value)
except (TypeError, ValueError):
return default
def _is_reward_value(value: Any) -> bool:
    """Return True for numeric (non-bool) rewards within [REWARD_MIN, REWARD_MAX]
    that carry at most three decimal places."""
    if isinstance(value, bool):
        return False
    if not isinstance(value, (int, float)):
        return False
    as_number = float(value)
    if not (REWARD_MIN <= as_number <= REWARD_MAX):
        return False
    # Reject values with more than 3 decimals of precision.
    return round(as_number, 3) == as_number
def _scan_reward_payload(payload: Any, failures: list[str], path: str) -> None:
    """Recursively validate reward-bearing fields in *payload*.

    Every key that names a direct reward (or ends in ``_score``) must pass
    _is_reward_value; breakdown keys must be dicts whose values all pass.
    Violations are appended to *failures* as dotted-path strings.
    """
    direct_keys = {"reward", "env_reward", "avg_reward", "avg_env_reward"}
    breakdown_keys = {
        "reward_breakdown",
        "primary_reward_channels",
        "avg_reward_components",
        "avg_primary_reward_channels",
    }
    if isinstance(payload, dict):
        for key, value in payload.items():
            child = f"{path}.{key}" if path else str(key)
            if key in direct_keys or key.endswith("_score"):
                if not _is_reward_value(value):
                    failures.append(f"{child}={value!r}")
            elif key in breakdown_keys:
                if not isinstance(value, dict):
                    failures.append(f"{child}=not_dict")
                else:
                    for inner_key, inner_value in value.items():
                        if not _is_reward_value(inner_value):
                            failures.append(f"{child}.{inner_key}={inner_value!r}")
            else:
                # Anything else may still contain nested reward fields.
                _scan_reward_payload(value, failures, child)
    elif isinstance(payload, list):
        for position, element in enumerate(payload):
            _scan_reward_payload(element, failures, f"{path}[{position}]")
def _history_series(history: list[dict[str, Any]], names: tuple[str, ...]) -> tuple[list[int], list[float]]:
    """Extract a (step, value) series from *history* rows.

    For each row (1-indexed), the first key in *names* that is present wins;
    rows containing none of the keys are skipped entirely.
    """
    steps: list[int] = []
    values: list[float] = []
    for step, row in enumerate(history, start=1):
        for candidate in names:
            if candidate in row:
                steps.append(step)
                values.append(_as_float(row.get(candidate)))
                break
    return steps, values
def _plot_placeholder(path: Path, title: str) -> None:
    """Write a titled placeholder chart to *path* when no sweep data exists yet."""
    path.parent.mkdir(parents=True, exist_ok=True)
    figure, axis = plt.subplots(figsize=(8, 4.5))
    axis.text(0.5, 0.5, "No completed sweep data yet", ha="center", va="center", fontsize=12)
    axis.set_axis_off()
    axis.set_title(title)
    figure.tight_layout()
    figure.savefig(path, dpi=160)
    plt.close(figure)
def _bar_chart(path: Path, title: str, labels: list[str], series: dict[str, list[float]], ylabel: str = "Reward") -> None:
    """Render grouped bars at *path*: one group per label, one bar per series.

    Falls back to a placeholder image when either *labels* or *series* is empty.
    """
    if not labels or not series:
        _plot_placeholder(path, title)
        return
    path.parent.mkdir(parents=True, exist_ok=True)
    figure, axis = plt.subplots(figsize=(10, 5.2))
    # Divide a fixed 0.8-wide band evenly among the series, centered on each tick.
    bar_width = 0.8 / max(1, len(series))
    positions = list(range(len(labels)))
    for series_index, (series_name, series_values) in enumerate(series.items()):
        shift = -0.4 + (series_index + 0.5) * bar_width
        axis.bar([pos + shift for pos in positions], series_values, width=bar_width, label=series_name)
    axis.set_title(title)
    axis.set_ylabel(ylabel)
    axis.set_xticks(positions)
    axis.set_xticklabels(labels, rotation=20, ha="right")
    # Headroom: 15% above the tallest bar, never below a 0..1 axis.
    peak = max((max(vals) for vals in series.values() if vals), default=1.0)
    axis.set_ylim(0, max(1.0, peak * 1.15))
    axis.legend()
    axis.grid(axis="y", alpha=0.24)
    figure.tight_layout()
    figure.savefig(path, dpi=160)
    plt.close(figure)
def _line_chart(
    path: Path,
    title: str,
    curves: dict[str, tuple[list[int], list[float]]],
    ylabel: str,
) -> None:
    """Plot labelled line curves at *path*, skipping curves whose x or y data is empty.

    Falls back to a placeholder image when no curve has data.
    """
    populated = {name: series for name, series in curves.items() if series[0] and series[1]}
    if not populated:
        _plot_placeholder(path, title)
        return
    path.parent.mkdir(parents=True, exist_ok=True)
    figure, axis = plt.subplots(figsize=(10, 5.2))
    for name, (steps, values) in populated.items():
        axis.plot(steps, values, marker="o", linewidth=1.6, markersize=3.5, label=name)
    axis.set_title(title)
    axis.set_xlabel("Logged step")
    axis.set_ylabel(ylabel)
    axis.grid(alpha=0.24)
    axis.legend()
    figure.tight_layout()
    figure.savefig(path, dpi=160)
    plt.close(figure)
def _safe_model_label(model_id: str, fallback: str) -> str:
if model_id:
return model_id.split("/")[-1].replace("-Instruct", "")
return fallback
def _bandit_chart_label(label: str) -> str:
if "bandit" in label.lower():
return label
if "qwen" in label.lower():
return f"{label} + Bandits"
return label
def _summarize_run(run_dir: Path, *, mode: str) -> dict[str, Any]:
    """Collect one sweep run's artifacts into a flat summary row.

    Reads the SFT/GRPO run manifests, post-save inference results, an optional
    error marker, per-step training histories, and the GRPO reward-component
    log from *run_dir*. In ``sft-baseline`` mode the GRPO artifacts are treated
    as optional: they are excluded from fallback detection and reward-range
    scanning, and SFT inference reward stands in for the GRPO numbers.
    """
    sft = _read_json(run_dir / "sft_trl_run.json")
    grpo = _read_json(run_dir / "grpo_trl_run.json")
    sft_inference = _read_json(run_dir / "postsave_inference_sft.json")
    grpo_inference = _read_json(run_dir / "postsave_inference_grpo.json")
    error = _read_json(run_dir / "error.json")
    sft_history = _read_history(run_dir / "sft_history.json")
    grpo_history = _read_history(run_dir / "grpo_history.json")
    reward_rows = _read_jsonl(run_dir / "grpo_reward_components.jsonl")
    sft_only = mode == "sft-baseline"
    # Model id preference order: GRPO manifest, SFT manifest, error marker, dir name.
    model_id = str(grpo.get("model_id") or sft.get("model_id") or error.get("model_id") or run_dir.name)
    # A "fallback" backend/model source means the real model never loaded.
    fallback_detected = any(
        "fallback" in str(payload.get("backend", "")).lower()
        or str(payload.get("model_source", "")).lower() == "fallback_policy"
        for payload in ([sft, sft_inference] if sft_only else [sft, grpo, sft_inference, grpo_inference])
    )
    reward_failures: list[str] = []
    if not sft_only:
        _scan_reward_payload(grpo, reward_failures, "grpo")
    _scan_reward_payload(sft_inference, reward_failures, "sft_inference")
    if not sft_only:
        _scan_reward_payload(grpo_inference, reward_failures, "grpo_inference")
    for idx, row in enumerate(reward_rows):
        _scan_reward_payload(row, reward_failures, f"reward_log[{idx}]")
    legal_count = sum(1 for row in reward_rows if row.get("legal") is True)
    reward_count = len(reward_rows)
    # A termination reason mentioning any of these markers counts as exploit-like.
    exploit_count = sum(
        1
        for row in reward_rows
        if any(
            marker in str(row.get("termination_reason", "")).lower()
            for marker in ["cheat", "exploit", "abuse", "timeout", "invalid"]
        )
    )
    selected = [str(row.get("selected_candidate_id") or row.get("generated_candidate_id") or "") for row in reward_rows]
    selected = [item for item in selected if item]
    counts = Counter(selected)
    # Collapse diagnostics: share of the single most-picked candidate, and distinct/total ratio.
    top_candidate_rate = (max(counts.values()) / len(selected)) if selected else 0.0
    candidate_diversity = (len(counts) / len(selected)) if selected else 0.0
    train_reward = _as_float((grpo.get("reward_summary") or {}).get("avg_reward"))
    if sft_only:
        # No GRPO stage: SFT inference reward stands in for both train and holdout,
        # so the train/holdout gap is zero by construction.
        holdout_reward = _as_float(sft_inference.get("avg_env_reward"))
        train_reward = holdout_reward
    else:
        holdout_reward = _as_float(grpo_inference.get("avg_env_reward"), train_reward)
    train_holdout_gap = round(train_reward - holdout_reward, 3)
    validity = _as_float(sft_inference.get("valid_rate") if sft_only else grpo_inference.get("valid_rate"), 0.0)
    # NOTE: the else arm parses as (sft ok) and (grpo ok) — `and` binds tighter
    # than the conditional expression.
    completed = sft.get("status") == "ok" if sft_only else sft.get("status") == "ok" and grpo.get("status") == "ok"
    return {
        "run_id": run_dir.name,
        "training_mode": mode,
        "model_id": model_id,
        "label": _safe_model_label(model_id, run_dir.name),
        # An error.json marker overrides everything; otherwise completion of the
        # required stages decides the status.
        "status": "failed" if error else ("completed" if completed else "incomplete"),
        "error": error.get("error", ""),
        "sft_backend": sft.get("backend", ""),
        "sft_examples": int(sft.get("examples_used", 0) or 0),
        "sft_train_loss": _as_float(sft.get("train_loss")),
        "sft_runtime": _as_float(sft.get("train_runtime")),
        "grpo_backend": grpo.get("backend", ""),
        "grpo_records": int(grpo.get("records", 0) or 0),
        "grpo_avg_reward": train_reward,
        "sft_inference_reward": _as_float(sft_inference.get("avg_env_reward")),
        "sft_valid_rate": _as_float(sft_inference.get("valid_rate")),
        "sft_latency_seconds": _as_float(sft_inference.get("avg_latency_seconds")),
        "grpo_inference_reward": holdout_reward,
        "grpo_valid_rate": validity,
        "grpo_latency_seconds": _as_float(grpo_inference.get("avg_latency_seconds")),
        "train_holdout_gap": train_holdout_gap,
        "fallback_detected": fallback_detected,
        "reward_range_ok": not reward_failures,
        # Cap the reported violations to keep the JSON report small.
        "reward_range_failures": reward_failures[:25],
        "exploit_rate": round(exploit_count / reward_count, 3) if reward_count else 0.0,
        "legal_rate": round(legal_count / reward_count, 3) if reward_count else 0.0,
        "candidate_diversity": round(candidate_diversity, 3),
        "top_candidate_rate": round(top_candidate_rate, 3),
        "reward_components": (grpo.get("reward_summary") or {}).get("avg_reward_components", {}),
        "primary_reward_channels": (grpo.get("reward_summary") or {}).get("avg_primary_reward_channels", {}),
        "sft_history": sft_history,
        "grpo_history": grpo_history,
        "artifact_paths": {
            "sft": sft.get("artifact_path", ""),
            "grpo": grpo.get("artifact_path", ""),
        },
    }
def _write_charts(rows: list[dict[str, Any]], plot_dir: Path, *, mode: str) -> dict[str, str]:
    """Render the full chart set for completed runs and return {chart_name: path}.

    Only rows with status "completed" contribute data; when none exist every
    chart is rendered as a placeholder image. Returned paths are relative to
    ROOT when possible, absolute otherwise.
    NOTE(review): *mode* is accepted but unused in this body — chart titles are
    identical in both report modes; confirm whether mode-specific titles were
    intended.
    """
    completed = [row for row in rows if row["status"] == "completed"]
    labels = [_bandit_chart_label(str(row["label"])) for row in completed]
    # Fixed output locations for every chart in the report.
    charts = {
        "sft_vs_grpo_reward": plot_dir / "sft_vs_grpo_reward.png",
        "sft_loss_curves": plot_dir / "sft_loss_curves.png",
        "qwen_model_sft_reward": plot_dir / "qwen_model_sft_reward.png",
        "qwen_model_sft_loss": plot_dir / "qwen_model_sft_loss.png",
        "sft_validity_reward": plot_dir / "sft_validity_reward.png",
        "grpo_reward_curves": plot_dir / "grpo_reward_curves.png",
        "qwen_model_grpo_reward": plot_dir / "qwen_model_grpo_reward.png",
        "reward_component_bars": plot_dir / "reward_component_bars.png",
        "anti_cheat_failure_rates": plot_dir / "anti_cheat_failure_rates.png",
        "train_holdout_gap": plot_dir / "train_holdout_gap.png",
        "inference_validity_reward": plot_dir / "inference_validity_reward.png",
        "inference_latency_validity": plot_dir / "inference_latency_validity.png",
    }
    _bar_chart(
        charts["sft_vs_grpo_reward"],
        "SFT Baseline vs GRPO + Bandits Policy Reward",
        labels,
        {
            "SFT inference reward": [row["sft_inference_reward"] for row in completed],
            "GRPO + Bandits inference reward": [row["grpo_inference_reward"] for row in completed],
        },
    )
    _line_chart(
        charts["sft_loss_curves"],
        "Qwen + Bandits SFT Training Loss Curves",
        {
            _bandit_chart_label(str(row["label"])): _history_series(row["sft_history"], ("loss", "train_loss"))
            for row in completed
        },
        ylabel="Loss",
    )
    _bar_chart(
        charts["qwen_model_sft_reward"],
        "Qwen + Bandits Model Sweep SFT Reward",
        labels,
        {"SFT inference reward": [row["sft_inference_reward"] for row in completed]},
    )
    _bar_chart(
        charts["qwen_model_sft_loss"],
        "Qwen + Bandits Model Sweep SFT Loss",
        labels,
        {"SFT train loss": [row["sft_train_loss"] for row in completed]},
        ylabel="Loss",
    )
    _bar_chart(
        charts["sft_validity_reward"],
        "SFT Inference Validity and Reward",
        labels,
        {
            "SFT valid rate": [row["sft_valid_rate"] for row in completed],
            "SFT reward": [row["sft_inference_reward"] for row in completed],
        },
        ylabel="Rate / reward",
    )
    _line_chart(
        charts["grpo_reward_curves"],
        "GRPO + Bandits Reward Curves",
        {
            # First matching reward key per history row wins (see _history_series).
            _bandit_chart_label(str(row["label"])): _history_series(
                row["grpo_history"],
                ("reward", "rewards/environment_reward_verifier", "mean_reward", "train_reward"),
            )
            for row in completed
        },
        ylabel="Reward",
    )
    _bar_chart(
        charts["qwen_model_grpo_reward"],
        "Qwen + Bandits Model Sweep GRPO Reward",
        labels,
        {"GRPO + Bandits train reward": [row["grpo_avg_reward"] for row in completed]},
    )
    # Union of all numeric reward-component names across completed runs.
    component_names = sorted(
        {
            key
            for row in completed
            for key, value in dict(row.get("reward_components") or {}).items()
            if isinstance(value, int | float)
        }
    )
    # Per-component mean across ALL completed runs; a run missing the component
    # contributes 0.0 to the average.
    component_means = []
    for key in component_names:
        values = [_as_float((row.get("reward_components") or {}).get(key)) for row in completed]
        component_means.append(round(sum(values) / len(values), 3) if values else 0.0)
    _bar_chart(
        charts["reward_component_bars"],
        "Mean GRPO Reward Components",
        component_names,
        {"component reward": component_means},
    )
    _bar_chart(
        charts["anti_cheat_failure_rates"],
        "Anti-Cheat and Failure Visibility",
        labels,
        {
            "exploit/invalid rate": [row["exploit_rate"] for row in completed],
            "illegal rate": [round(1.0 - row["legal_rate"], 3) for row in completed],
            "candidate collapse": [row["top_candidate_rate"] for row in completed],
        },
        ylabel="Rate",
    )
    _bar_chart(
        charts["train_holdout_gap"],
        "Train vs Holdout Reward Gap",
        labels,
        {"train - holdout": [abs(row["train_holdout_gap"]) for row in completed]},
        ylabel="Absolute reward gap",
    )
    _bar_chart(
        charts["inference_validity_reward"],
        "Inference Validity and Reward",
        labels,
        {
            "GRPO valid rate": [row["grpo_valid_rate"] for row in completed],
            "GRPO holdout reward": [row["grpo_inference_reward"] for row in completed],
        },
        ylabel="Rate / reward",
    )
    _bar_chart(
        charts["inference_latency_validity"],
        "Inference Latency and Validity",
        labels,
        {
            "SFT latency sec": [row["sft_latency_seconds"] for row in completed],
            "GRPO latency sec": [row["grpo_latency_seconds"] for row in completed],
            "GRPO valid rate": [row["grpo_valid_rate"] for row in completed],
        },
        ylabel="Seconds / rate",
    )
    # Prefer repo-relative paths in the report; fall back to absolute paths for
    # chart locations outside ROOT.
    chart_index: dict[str, str] = {}
    for key, path in charts.items():
        try:
            chart_index[key] = str(path.relative_to(ROOT))
        except ValueError:
            chart_index[key] = str(path)
    return chart_index
def generate_report(
    sweep_dir: Path,
    plot_dir: Path,
    output_path: Path,
    anti_hacking_output: Path,
    mode: str = "full",
) -> tuple[dict[str, Any], dict[str, Any]]:
    """Summarize every run directory under *sweep_dir*, render charts, and write
    the sweep summary and anti-hacking JSON reports.

    Returns the (summary, anti_hacking) payloads that were written to
    *output_path* and *anti_hacking_output* respectively.
    """
    run_dirs = sorted(path for path in sweep_dir.iterdir() if path.is_dir()) if sweep_dir.exists() else []
    # Accept a few SFT-mode aliases; anything unrecognized falls back to "full".
    mode = "sft-baseline" if mode in {"sft", "sft-only", "sft_baseline", "sft-baseline"} else "full"
    rows = [_summarize_run(run_dir, mode=mode) for run_dir in run_dirs]
    chart_paths = _write_charts(rows, plot_dir, mode=mode)
    completed = [row for row in rows if row["status"] == "completed"]
    failed = [row for row in rows if row["status"] == "failed"]
    warnings: list[str] = []
    if not completed:
        warnings.append("no_completed_models")
    # Anti-hacking checks: any warning fails the report (see `passed` below).
    for row in completed:
        if row["fallback_detected"]:
            warnings.append(f"{row['label']}:fallback_detected")
        if not row["reward_range_ok"]:
            warnings.append(f"{row['label']}:reward_range_violation")
        if mode == "sft-baseline":
            if row["sft_valid_rate"] < 0.8:
                warnings.append(f"{row['label']}:low_sft_validity")
        else:
            if row["exploit_rate"] > 0.35:
                warnings.append(f"{row['label']}:high_exploit_rate")
            # Collapse = one candidate dominates AND few distinct candidates overall.
            if row["top_candidate_rate"] > 0.85 and row["candidate_diversity"] < 0.2:
                warnings.append(f"{row['label']}:candidate_collapse")
            if row["grpo_valid_rate"] < 0.8:
                warnings.append(f"{row['label']}:low_validity")
            if abs(row["train_holdout_gap"]) > 0.25:
                warnings.append(f"{row['label']}:large_train_holdout_gap")
    # Drop the bulky per-step histories from the public summary rows.
    public_rows = [
        {key: value for key, value in row.items() if key not in {"sft_history", "grpo_history"}}
        for row in rows
    ]
    summary = {
        "status": "ok" if completed else "incomplete",
        "training_mode": mode,
        "completed_models": len(completed),
        "failed_or_skipped_models": len(failed),
        "models": public_rows,
        "charts": chart_paths,
    }
    anti_hacking = {
        "passed": bool(completed) and not warnings,
        "training_mode": mode,
        "warnings": warnings,
        "completed_models": [row["model_id"] for row in completed],
        "failed_or_skipped_models": [{"model_id": row["model_id"], "error": row["error"]} for row in failed],
        # Thresholds mirrored here so report consumers can see the criteria used.
        "checks": {
            "reward_bounds": [REWARD_MIN, REWARD_MAX],
            "reward_precision": 3,
            "fallback_backends_rejected": True,
            "exploit_rate_threshold": 0.35,
            "train_holdout_gap_threshold": 0.25,
            "min_validity_rate": 0.8,
        },
    }
    output_path.parent.mkdir(parents=True, exist_ok=True)
    anti_hacking_output.parent.mkdir(parents=True, exist_ok=True)
    output_path.write_text(json.dumps(summary, ensure_ascii=True, indent=2), encoding="utf-8")
    anti_hacking_output.write_text(json.dumps(anti_hacking, ensure_ascii=True, indent=2), encoding="utf-8")
    return summary, anti_hacking
def main() -> None:
    """CLI entry point: build the sweep report and print a one-line JSON status."""
    args = parse_args()
    summary, anti_hacking = generate_report(
        sweep_dir=ROOT / args.sweep_dir,
        plot_dir=ROOT / args.plot_dir,
        output_path=ROOT / args.output,
        anti_hacking_output=ROOT / args.anti_hacking_output,
        mode=args.mode,
    )
    status_line = {
        "hf_sweep_summary": summary.get("status"),
        "completed_models": summary.get("completed_models"),
        "anti_hacking_passed": anti_hacking.get("passed"),
    }
    print(json.dumps(status_line, ensure_ascii=True))


if __name__ == "__main__":
    main()