Spaces:
Running
Running
| from __future__ import annotations | |
| import argparse | |
| import json | |
| from collections import defaultdict | |
| from pathlib import Path | |
| from statistics import mean | |
| from typing import Any, Dict, Iterable, List, Optional | |
| def _load_json(path: Path) -> Dict[str, Any]: | |
| if not path.exists(): | |
| return {} | |
| try: | |
| return json.loads(path.read_text(encoding="utf-8")) | |
| except Exception: | |
| return {} | |
| def _load_jsonl(path: Path) -> List[Dict[str, Any]]: | |
| if not path.exists(): | |
| return [] | |
| rows: List[Dict[str, Any]] = [] | |
| for line in path.read_text(encoding="utf-8", errors="ignore").splitlines(): | |
| line = line.strip() | |
| if not line: | |
| continue | |
| try: | |
| item = json.loads(line) | |
| except json.JSONDecodeError: | |
| continue | |
| if isinstance(item, dict): | |
| rows.append(item) | |
| return rows | |
| def _get(payload: Dict[str, Any], dotted_key: str, default: Any = None) -> Any: | |
| cur: Any = payload | |
| for part in dotted_key.split("."): | |
| if not isinstance(cur, dict) or part not in cur: | |
| return default | |
| cur = cur[part] | |
| return cur | |
| def _as_float(value: Any, default: float = 0.0) -> float: | |
| try: | |
| if value is None: | |
| return default | |
| return float(value) | |
| except (TypeError, ValueError): | |
| return default | |
def _ensure_matplotlib():
    """Return ``matplotlib.pyplot`` after selecting the ``Agg`` backend.

    matplotlib is imported inside the function rather than at module import
    time; the backend is selected before pyplot is imported.
    """
    import matplotlib

    matplotlib.use("Agg")

    from matplotlib import pyplot

    return pyplot
def _save_placeholder(path: Path, title: str, message: str) -> None:
    """Write a text-only figure to *path* for plots that have no data.

    The figure has its axes hidden and shows *title* in bold above a smaller,
    wrapped *message*, both horizontally centred.
    """
    pyplot = _ensure_matplotlib()
    figure, axes = pyplot.subplots(figsize=(10, 5.4))
    axes.axis("off")

    centred = {"ha": "center", "va": "center"}
    axes.text(0.5, 0.62, title, fontsize=17, fontweight="bold", **centred)
    axes.text(0.5, 0.42, message, fontsize=11, wrap=True, **centred)

    figure.tight_layout()
    figure.savefig(path, dpi=170)
    pyplot.close(figure)
| def _task_groups(rollouts: Iterable[Dict[str, Any]]) -> Dict[str, List[Dict[str, Any]]]: | |
| grouped: Dict[str, List[Dict[str, Any]]] = defaultdict(list) | |
| for row in rollouts: | |
| grouped[str(row.get("task_id") or "unknown")].append(row) | |
| return dict(sorted(grouped.items(), key=lambda item: item[0])) | |
def _save_keep_drop(path: Path, rollouts: List[Dict[str, Any]]) -> None:
    """Save a stacked bar chart of kept vs rejected rollouts per task.

    Each task gets one bar: kept rollouts at the bottom, rejected stacked on
    top, with the keep rate written as a percentage above the bar. When
    *rollouts* is empty a placeholder figure is written instead.
    """
    if not rollouts:
        _save_placeholder(path, "RFT Keep/Drop By Task", "No RFT rollouts found.")
        return

    pyplot = _ensure_matplotlib()
    by_task = _task_groups(rollouts)
    task_ids = list(by_task)
    totals = [len(by_task[task_id]) for task_id in task_ids]
    kept = [
        sum(1 for row in by_task[task_id] if row.get("kept"))
        for task_id in task_ids
    ]
    dropped = [total - kept_count for total, kept_count in zip(totals, kept)]

    figure, axes = pyplot.subplots(figsize=(12, 5.8))
    axes.bar(task_ids, kept, color="#2ca25f", label="kept for RFT")
    axes.bar(task_ids, dropped, bottom=kept, color="#d95f02", label="rejected")
    axes.set_title("RFT Rejection Sampling: Kept vs Rejected Rollouts")
    axes.set_ylabel("rollouts")
    axes.tick_params(axis="x", rotation=25)
    axes.grid(True, axis="y", alpha=0.25)
    axes.legend()

    # Annotate each bar with its keep rate, just above the top of the stack.
    for position, (kept_count, total) in enumerate(zip(kept, totals)):
        if total:
            rate = kept_count / total
        else:
            rate = 0.0
        axes.text(position, total + 0.25, f"{rate:.0%}", ha="center", fontsize=9)

    figure.tight_layout()
    figure.savefig(path, dpi=170)
    pyplot.close(figure)
def _save_score_by_task(path: Path, rollouts: List[Dict[str, Any]], min_score: Optional[float]) -> None:
    """Save a per-task scatter plot of rollout scores.

    Points are green for kept rollouts and orange for rejected ones, and are
    spread horizontally in a repeating 7-step pattern so that rollouts of the
    same task do not sit on top of each other. When *min_score* is given, a
    dashed horizontal line and a legend mark the keep threshold. When
    *rollouts* is empty a placeholder figure is written instead.
    """
    if not rollouts:
        _save_placeholder(path, "RFT Score Distribution", "No RFT rollouts found.")
        return

    pyplot = _ensure_matplotlib()
    by_task = _task_groups(rollouts)
    task_ids = list(by_task)
    figure, axes = pyplot.subplots(figsize=(12, 5.8))

    for column, task_id in enumerate(task_ids):
        x_values = []
        y_values = []
        point_colors = []
        for nth, row in enumerate(by_task[task_id]):
            # Horizontal jitter cycles through 7 offsets centred on the column.
            x_values.append(column + ((nth % 7) - 3) * 0.025)
            y_values.append(_as_float(row.get("score")))
            point_colors.append("#2ca25f" if row.get("kept") else "#d95f02")
        axes.scatter(
            x_values,
            y_values,
            c=point_colors,
            alpha=0.8,
            s=36,
            edgecolors="white",
            linewidths=0.4,
        )

    if min_score is not None:
        axes.axhline(
            min_score,
            color="#333333",
            linestyle="--",
            linewidth=1.4,
            label=f"keep score >= {min_score:g}",
        )
        axes.legend()

    axes.set_title("RFT Rollout Scores By Task")
    axes.set_ylabel("filter score")
    axes.set_xticks(range(len(task_ids)))
    axes.set_xticklabels(task_ids, rotation=25, ha="right")
    axes.grid(True, axis="y", alpha=0.25)
    figure.tight_layout()
    figure.savefig(path, dpi=170)
    pyplot.close(figure)
def _save_fp_by_task(path: Path, rollouts: List[Dict[str, Any]], max_fp: Optional[float]) -> None:
    """Save a per-task scatter plot of each rollout's ``fp`` value.

    Points are green for kept rollouts and orange for rejected ones, and are
    spread horizontally in a repeating 7-step pattern so that rollouts of the
    same task do not sit on top of each other. When *max_fp* is given, a
    dashed horizontal line and a legend mark the keep threshold. When
    *rollouts* is empty a placeholder figure is written instead.
    """
    if not rollouts:
        _save_placeholder(path, "RFT False Positive Distribution", "No RFT rollouts found.")
        return

    pyplot = _ensure_matplotlib()
    by_task = _task_groups(rollouts)
    task_ids = list(by_task)
    figure, axes = pyplot.subplots(figsize=(12, 5.8))

    for column, task_id in enumerate(task_ids):
        x_values = []
        y_values = []
        point_colors = []
        for nth, row in enumerate(by_task[task_id]):
            # Horizontal jitter cycles through 7 offsets centred on the column.
            x_values.append(column + ((nth % 7) - 3) * 0.025)
            y_values.append(_as_float(row.get("fp")))
            point_colors.append("#2ca25f" if row.get("kept") else "#d95f02")
        axes.scatter(
            x_values,
            y_values,
            c=point_colors,
            alpha=0.8,
            s=36,
            edgecolors="white",
            linewidths=0.4,
        )

    if max_fp is not None:
        axes.axhline(
            max_fp,
            color="#333333",
            linestyle="--",
            linewidth=1.4,
            label=f"keep fp <= {max_fp:g}",
        )
        axes.legend()

    axes.set_title("RFT False Positives By Task")
    axes.set_ylabel("false positives / episode")
    axes.set_xticks(range(len(task_ids)))
    axes.set_xticklabels(task_ids, rotation=25, ha="right")
    axes.grid(True, axis="y", alpha=0.25)
    figure.tight_layout()
    figure.savefig(path, dpi=170)
    pyplot.close(figure)
def _save_score_vs_fp(path: Path, rollouts: List[Dict[str, Any]], min_score: Optional[float], max_fp: Optional[float]) -> None:
    """Save a scatter plot of score (y) against ``fp`` (x) for every rollout.

    Each task gets one colour from a six-colour palette (reused cyclically).
    Rejected rollouts are drawn as translucent crosses and kept rollouts as
    outlined circles. Dashed lines mark *min_score* and *max_fp* when they
    are provided. When *rollouts* is empty a placeholder figure is written
    instead.
    """
    if not rollouts:
        _save_placeholder(path, "RFT Score vs False Positives", "No RFT rollouts found.")
        return

    def column(rows: List[Dict[str, Any]], key: str) -> List[float]:
        # Pull one numeric field out of every row.
        return [_as_float(row.get(key)) for row in rows]

    pyplot = _ensure_matplotlib()
    by_task = _task_groups(rollouts)
    palette = ["#1b9e77", "#7570b3", "#e7298a", "#66a61e", "#e6ab02", "#a6761d"]
    figure, axes = pyplot.subplots(figsize=(10.5, 6.2))

    for position, (task_id, rows) in enumerate(by_task.items()):
        task_color = palette[position % len(palette)]
        accepted = []
        rejected = []
        for row in rows:
            (accepted if row.get("kept") else rejected).append(row)

        # Rejected points are drawn first so kept points sit on top of them.
        if rejected:
            axes.scatter(
                column(rejected, "fp"),
                column(rejected, "score"),
                marker="x",
                s=50,
                color=task_color,
                alpha=0.55,
                label=f"{task_id} rejected",
            )
        if accepted:
            axes.scatter(
                column(accepted, "fp"),
                column(accepted, "score"),
                marker="o",
                s=60,
                color=task_color,
                edgecolors="black",
                linewidths=0.4,
                label=f"{task_id} kept",
            )

    boundary_style = {"color": "#111111", "linestyle": "--", "linewidth": 1.2}
    if min_score is not None:
        axes.axhline(min_score, **boundary_style)
    if max_fp is not None:
        axes.axvline(max_fp, **boundary_style)

    axes.set_title("RFT Filter Boundary: Keep High Score, Low False Positives")
    axes.set_xlabel("false positives / episode")
    axes.set_ylabel("filter score")
    axes.grid(True, alpha=0.25)
    axes.legend(fontsize=7, ncol=2)
    figure.tight_layout()
    figure.savefig(path, dpi=170)
    pyplot.close(figure)
def _save_timeline(path: Path, rollouts: List[Dict[str, Any]]) -> None:
    """Save a timeline of rollout scores in generation order.

    The left axis shows every score as a line, with kept rollouts overlaid as
    green dots and rejected ones as orange crosses. The right axis shows the
    keep rate over a trailing window of up to 10 rollouts. When *rollouts* is
    empty a placeholder figure is written instead.
    """
    if not rollouts:
        _save_placeholder(path, "RFT Rollout Timeline", "No RFT rollouts found.")
        return

    pyplot = _ensure_matplotlib()
    count = len(rollouts)
    positions = list(range(1, count + 1))
    scores = [_as_float(row.get("score")) for row in rollouts]
    kept_flags = [bool(row.get("kept")) for row in rollouts]

    kept_x, kept_y, drop_x, drop_y = [], [], [], []
    for position, score, is_kept in zip(positions, scores, kept_flags):
        if is_kept:
            kept_x.append(position)
            kept_y.append(score)
        else:
            drop_x.append(position)
            drop_y.append(score)

    # Trailing window: the current rollout plus up to nine before it.
    rolling_keep = []
    for end in range(1, count + 1):
        recent = kept_flags[max(0, end - 10):end]
        rolling_keep.append(sum(1 for flag in recent if flag) / len(recent))

    figure, score_axis = pyplot.subplots(figsize=(12, 5.8))
    score_axis.plot(positions, scores, color="#6b7280", linewidth=1.1, alpha=0.65, label="score")
    score_axis.scatter(kept_x, kept_y, color="#2ca25f", s=45, label="kept")
    score_axis.scatter(drop_x, drop_y, color="#d95f02", marker="x", s=42, label="rejected")

    rate_axis = score_axis.twinx()
    rate_axis.plot(positions, rolling_keep, color="#2563eb", linewidth=2, label="rolling keep rate")

    score_axis.set_title("RFT Rollout Timeline")
    score_axis.set_xlabel("generated rollout")
    score_axis.set_ylabel("filter score")
    rate_axis.set_ylabel("rolling keep rate")
    score_axis.grid(True, axis="y", alpha=0.25)

    # Merge the legend entries of both y-axes into a single legend box.
    score_handles, score_labels = score_axis.get_legend_handles_labels()
    rate_handles, rate_labels = rate_axis.get_legend_handles_labels()
    score_axis.legend(score_handles + rate_handles, score_labels + rate_labels, loc="best")

    figure.tight_layout()
    figure.savefig(path, dpi=170)
    pyplot.close(figure)
def _save_eval_overview(path: Path, eval_report: Dict[str, Any]) -> None:
    """Save a grouped bar chart comparing baseline and candidate eval metrics.

    Metrics are read from ``eval_report["overall"]["baseline"]`` and
    ``eval_report["overall"]["candidate"]``; a missing or non-numeric metric
    is plotted as ``0.0``. Legend labels come from the report's
    ``baseline_label`` / ``candidate_label`` entries when present.

    Args:
        path: Destination image file.
        eval_report: Decoded eval report. An empty or non-dict report
            produces a placeholder figure instead of a chart.
    """
    if not isinstance(eval_report, dict) or not eval_report:
        _save_placeholder(path, "Held-Out Eval After RFT", "No eval report provided yet.")
        return
    plt = _ensure_matplotlib()
    metrics = [
        ("Mean score", "mean_score"),
        ("Detection", "detection_rate"),
        ("Risk reduction", "risk_reduction_rate"),
        ("Worker rehab", "worker_rehabilitation_rate"),
        ("False positive", "false_positive_rate"),
    ]
    baseline = _get(eval_report, "overall.baseline", {})
    candidate = _get(eval_report, "overall.candidate", {})
    # The key may be present but hold null or another non-object value;
    # treat that as "no metrics" rather than failing on ``.get``.
    if not isinstance(baseline, dict):
        baseline = {}
    if not isinstance(candidate, dict):
        candidate = {}
    labels = [label for label, _ in metrics]
    base_values = [_as_float(baseline.get(key)) for _, key in metrics]
    cand_values = [_as_float(candidate.get(key)) for _, key in metrics]
    xs = list(range(len(labels)))
    width = 0.38
    fig, ax = plt.subplots(figsize=(12, 5.8))
    # Baseline bars sit left of each tick, candidate bars right of it.
    ax.bar(
        [x - width / 2 for x in xs],
        base_values,
        width=width,
        color="#d95f02",
        label=str(eval_report.get("baseline_label") or "baseline"),
    )
    ax.bar(
        [x + width / 2 for x in xs],
        cand_values,
        width=width,
        color="#2ca25f",
        label=str(eval_report.get("candidate_label") or "candidate"),
    )
    ax.set_title("Held-Out Evaluation: Baseline vs RFT Candidate")
    ax.set_ylabel("rate / score")
    ax.set_xticks(xs)
    ax.set_xticklabels(labels, rotation=20, ha="right")
    ax.grid(True, axis="y", alpha=0.25)
    ax.legend()
    fig.tight_layout()
    fig.savefig(path, dpi=170)
    plt.close(fig)
def _save_eval_task_delta(path: Path, eval_report: Dict[str, Any]) -> None:
    """Save a bar chart of candidate-minus-baseline mean score per task.

    Tasks come from ``eval_report["per_task"]`` in sorted key order. Bars are
    green for non-negative deltas and orange for negative ones. When there
    are no per-task entries a placeholder figure is written instead.
    """
    per_task = _get(eval_report, "per_task", {})
    if not (isinstance(per_task, dict) and per_task):
        _save_placeholder(path, "RFT Held-Out Score Delta By Task", "No per-task eval rows found.")
        return

    ordered = sorted(per_task.items())
    task_labels = [str(task_id) for task_id, _ in ordered]
    score_deltas = [
        _as_float(_get(entry, "candidate.mean_score"))
        - _as_float(_get(entry, "baseline.mean_score"))
        for _, entry in ordered
    ]
    bar_colors = []
    for delta in score_deltas:
        bar_colors.append("#2ca25f" if delta >= 0 else "#d95f02")

    pyplot = _ensure_matplotlib()
    figure, axes = pyplot.subplots(figsize=(12, 5.8))
    axes.bar(task_labels, score_deltas, color=bar_colors)
    axes.axhline(0.0, color="#111111", linewidth=1)
    axes.set_title("Held-Out Score Delta By Task")
    axes.set_ylabel("candidate mean score - baseline mean score")
    axes.tick_params(axis="x", rotation=25)
    axes.grid(True, axis="y", alpha=0.25)
    figure.tight_layout()
    figure.savefig(path, dpi=170)
    pyplot.close(figure)
def _write_markdown(
    path: Path,
    label: str,
    rollouts: List[Dict[str, Any]],
    kept: List[Dict[str, Any]],
    summary: Dict[str, Any],
    eval_report: Dict[str, Any],
    images: List[str],
) -> None:
    """Write the markdown proof report to *path*.

    The report contains an intro paragraph, rollout summary statistics,
    optional RFT status / adapter lines (when *summary* is non-empty), an
    optional held-out eval section (when ``eval_report["overall"]`` is
    non-empty), and one heading plus embedded image per entry in *images*.

    Args:
        path: Destination markdown file.
        label: Run label used in the report title.
        rollouts: All generated rollout rows.
        kept: The subset of *rollouts* that was kept.
        summary: Decoded RFT summary (may be empty).
        eval_report: Decoded held-out eval report (may be empty).
        images: Image paths or file names, referenced verbatim as markdown
            image targets (so they should be relative to *path*'s folder).
    """
    total = len(rollouts)
    kept_count = len(kept)
    keep_rate = kept_count / total if total else 0.0
    mean_score_total = mean([_as_float(row.get("score")) for row in rollouts]) if rollouts else 0.0
    mean_score_kept = mean([_as_float(row.get("score")) for row in kept]) if kept else 0.0
    mean_fp_kept = mean([_as_float(row.get("fp")) for row in kept]) if kept else 0.0
    eval_overall = _get(eval_report, "overall", {})
    if eval_overall:
        intro = (
            "This folder is the rejection-sampling fine-tuning proof layer. "
            "It shows which model-generated rollouts were accepted, which were rejected, "
            "and what the held-out evaluation says after the polish pass."
        )
    else:
        intro = (
            "This folder is the rejection-sampling fine-tuning proof layer. "
            "It shows which model-generated rollouts were accepted, which were rejected, "
            "and which low-false-positive samples were used for the polish pass. "
            "Held-out model evaluation was intentionally omitted for this proof pack."
        )
    lines = [
        f"# {label} RFT Proof Pack",
        "",
        intro,
        "",
        "## Summary",
        "",
        f"- Total generated rollouts: `{total}`",
        f"- Kept rollouts used for SFT: `{kept_count}`",
        f"- Keep rate: `{keep_rate:.1%}`",
        f"- Mean rollout score: `{mean_score_total:.3f}`",
        f"- Mean kept score: `{mean_score_kept:.3f}`",
        f"- Mean kept false positives: `{mean_fp_kept:.2f}`",
    ]
    if summary:
        # Prefer the nested keys, falling back to the flat top-level ones.
        lines.extend([
            f"- RFT status: `{_get(summary, 'sft.status', summary.get('status', 'unknown'))}`",
            f"- Output adapter: `{_get(summary, 'output.final_dir', summary.get('final_dir', 'see RFT output dir'))}`",
        ])
    if eval_overall:
        # NOTE(review): these are flat ``overall.<side>_<metric>`` keys, while
        # the eval plots read nested ``overall.<side>.<metric>`` — confirm
        # which layout the eval report actually uses.
        lines.extend([
            "",
            "## Held-Out Eval",
            "",
            f"- Baseline mean score: `{_as_float(eval_overall.get('baseline_mean_score')):.3f}`",
            f"- Candidate mean score: `{_as_float(eval_overall.get('candidate_mean_score')):.3f}`",
            f"- Mean score delta: `{_as_float(eval_overall.get('mean_score_delta')):.3f}`",
            f"- Candidate risk reduction: `{_as_float(eval_overall.get('candidate_risk_reduction_rate')):.1%}`",
            f"- Candidate false-positive rate: `{_as_float(eval_overall.get('candidate_false_positive_rate')):.1%}`",
        ])
    lines.extend(["", "## Plots", ""])
    for image in images:
        title = Path(image).stem.replace("_", " ").title()
        # Embed the plot under its heading; previously this emitted an empty
        # line, so the section listed headings with no images.
        lines.extend([f"### {title}", "", f"![{title}]({image})", ""])
    path.write_text("\n".join(lines) + "\n", encoding="utf-8")
def render_rft_proof(
    rft_dir: Path,
    output_dir: Path,
    eval_report_path: Optional[Path],
    label: str,
    min_score: Optional[float],
    max_fp: Optional[float],
) -> Dict[str, Any]:
    """Render every proof plot, the manifest and the markdown report.

    Reads ``rollouts.jsonl`` and ``rft_summary.json`` from *rft_dir* (plus the
    optional eval report), writes seven PNG plots, ``rft_plot_manifest.json``
    and ``rft_proof.md`` into *output_dir* (created if needed), and returns
    the manifest dict.

    When *min_score* or *max_fp* is ``None``, the value is taken from
    ``config.MIN_SCORE`` / ``config.MAX_FP`` in the summary; if that is also
    unavailable the threshold stays ``None`` and no threshold line is drawn.
    """
    output_dir.mkdir(parents=True, exist_ok=True)

    rollouts = _load_jsonl(rft_dir / "rollouts.jsonl")
    kept = [row for row in rollouts if row.get("kept")]
    summary = _load_json(rft_dir / "rft_summary.json")
    if eval_report_path:
        eval_report = _load_json(eval_report_path)
    else:
        eval_report = {}

    def resolve_threshold(explicit: Optional[float], config_key: str) -> Optional[float]:
        # An explicit value wins; otherwise read it from the summary config.
        if explicit is not None:
            return explicit
        from_config = _as_float(_get(summary, config_key), default=float("nan"))
        # NaN is the only value not equal to itself: it marks "not configured".
        return None if from_config != from_config else from_config

    min_score = resolve_threshold(min_score, "config.MIN_SCORE")
    max_fp = resolve_threshold(max_fp, "config.MAX_FP")

    image_names = [
        "01_rft_keep_drop_by_task.png",
        "02_rft_score_distribution.png",
        "03_rft_false_positive_distribution.png",
        "04_rft_score_vs_fp_filter.png",
        "05_rft_rollout_timeline.png",
        "06_rft_eval_overview.png",
        "07_rft_eval_task_delta.png",
    ]
    (
        keep_drop_png,
        score_png,
        fp_png,
        boundary_png,
        timeline_png,
        eval_png,
        delta_png,
    ) = (output_dir / name for name in image_names)

    _save_keep_drop(keep_drop_png, rollouts)
    _save_score_by_task(score_png, rollouts, min_score)
    _save_fp_by_task(fp_png, rollouts, max_fp)
    _save_score_vs_fp(boundary_png, rollouts, min_score, max_fp)
    _save_timeline(timeline_png, rollouts)
    _save_eval_overview(eval_png, eval_report)
    _save_eval_task_delta(delta_png, eval_report)

    total_rollouts = len(rollouts)
    kept_rollouts = len(kept)
    manifest = {
        "label": label,
        "rft_dir": str(rft_dir),
        "eval_report_path": str(eval_report_path) if eval_report_path else "",
        "total_rollouts": total_rollouts,
        "kept_rollouts": kept_rollouts,
        "keep_rate": kept_rollouts / total_rollouts if rollouts else 0.0,
        "images": image_names,
    }
    manifest_path = output_dir / "rft_plot_manifest.json"
    manifest_path.write_text(json.dumps(manifest, indent=2), encoding="utf-8")
    _write_markdown(output_dir / "rft_proof.md", label, rollouts, kept, summary, eval_report, image_names)
    return manifest
def main() -> None:
    """Parse command-line options, render the proof pack, print the manifest."""
    parser = argparse.ArgumentParser(description="Render proof plots for a SENTINEL RFT polish run.")

    text_options = (
        ("--rft-dir", "/data/sentinel_outputs_rft_phase1_100", "Directory containing rollouts.jsonl and rft_summary.json."),
        ("--eval-report", "/data/rft_eval/sentinel_held_out_report.json", "Optional held-out eval JSON report."),
        ("--output-dir", "outputs/rft_phase1_100/plots", "Where to write PNG plots and markdown."),
        ("--label", "Phase 1 + RFT", "Label used in the markdown report."),
    )
    for flag, default_value, help_text in text_options:
        parser.add_argument(flag, default=default_value, help=help_text)

    threshold_options = (
        ("--min-score", "Override score threshold line."),
        ("--max-fp", "Override false-positive threshold line."),
    )
    for flag, help_text in threshold_options:
        parser.add_argument(flag, type=float, default=None, help=help_text)

    options = parser.parse_args()

    # An empty --eval-report value means "no eval report".
    report_path = None
    if options.eval_report:
        report_path = Path(options.eval_report)

    manifest = render_rft_proof(
        rft_dir=Path(options.rft_dir),
        output_dir=Path(options.output_dir),
        eval_report_path=report_path,
        label=options.label,
        min_score=options.min_score,
        max_fp=options.max_fp,
    )
    print(json.dumps(manifest, indent=2))


if __name__ == "__main__":
    main()