# openenv/scripts/render_rft_proof.py
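"""Render proof plots and a markdown report for a SENTINEL RFT polish run.

Reads rollouts.jsonl and rft_summary.json from an RFT output directory
(plus an optional held-out eval report) and writes PNG plots, a JSON
manifest, and rft_proof.md to the chosen output directory.
"""
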
from __future__ import annotations
import argparse
import json
from collections import defaultdict
from pathlib import Path
from statistics import mean
from typing import Any, Dict, Iterable, List, Optional


def _load_json(path: Path) -> Dict[str, Any]:
    """Read a JSON file, returning {} if it is missing or unparseable."""
    if not path.exists():
        return {}
    try:
        return json.loads(path.read_text(encoding="utf-8"))
    except Exception:
        return {}


def _load_jsonl(path: Path) -> List[Dict[str, Any]]:
    """Read a JSONL file, skipping blank, malformed, and non-object lines."""
    if not path.exists():
        return []
    rows: List[Dict[str, Any]] = []
    for line in path.read_text(encoding="utf-8", errors="ignore").splitlines():
        line = line.strip()
        if not line:
            continue
        try:
            item = json.loads(line)
        except json.JSONDecodeError:
            continue
        if isinstance(item, dict):
            rows.append(item)
    return rows


def _get(payload: Dict[str, Any], dotted_key: str, default: Any = None) -> Any:
    """Walk a nested dict by a dotted key such as "overall.baseline"."""
    cur: Any = payload
    for part in dotted_key.split("."):
        if not isinstance(cur, dict) or part not in cur:
            return default
        cur = cur[part]
    return cur


def _as_float(value: Any, default: float = 0.0) -> float:
    """Coerce a value to float, falling back to `default` on None or bad input."""
    try:
        if value is None:
            return default
        return float(value)
    except (TypeError, ValueError):
        return default


def _ensure_matplotlib():
    """Import matplotlib lazily with the headless Agg backend selected."""
    import matplotlib

    matplotlib.use("Agg")  # render straight to PNG; no display required
    import matplotlib.pyplot as plt

    return plt


def _save_placeholder(path: Path, title: str, message: str) -> None:
    """Render a text-only PNG used when the data for a real plot is absent."""
    plt = _ensure_matplotlib()
    fig, ax = plt.subplots(figsize=(10, 5.4))
    ax.axis("off")
    ax.text(0.5, 0.62, title, ha="center", va="center", fontsize=17, fontweight="bold")
    ax.text(0.5, 0.42, message, ha="center", va="center", fontsize=11, wrap=True)
    fig.tight_layout()
    fig.savefig(path, dpi=170)
    plt.close(fig)


def _task_groups(rollouts: Iterable[Dict[str, Any]]) -> Dict[str, List[Dict[str, Any]]]:
    """Group rollout rows by task_id, sorted by id for stable plot order."""
    grouped: Dict[str, List[Dict[str, Any]]] = defaultdict(list)
    for row in rollouts:
        grouped[str(row.get("task_id") or "unknown")].append(row)
    return dict(sorted(grouped.items(), key=lambda item: item[0]))


def _save_keep_drop(path: Path, rollouts: List[Dict[str, Any]]) -> None:
    """Stacked bar chart of kept vs rejected rollout counts per task."""
    if not rollouts:
        _save_placeholder(path, "RFT Keep/Drop By Task", "No RFT rollouts found.")
        return
    plt = _ensure_matplotlib()
    groups = _task_groups(rollouts)
    labels = list(groups)
    kept = [sum(1 for row in groups[label] if row.get("kept")) for label in labels]
    dropped = [len(groups[label]) - kept[index] for index, label in enumerate(labels)]
    fig, ax = plt.subplots(figsize=(12, 5.8))
    ax.bar(labels, kept, color="#2ca25f", label="kept for RFT")
    ax.bar(labels, dropped, bottom=kept, color="#d95f02", label="rejected")
    ax.set_title("RFT Rejection Sampling: Kept vs Rejected Rollouts")
    ax.set_ylabel("rollouts")
    ax.tick_params(axis="x", rotation=25)
    ax.grid(True, axis="y", alpha=0.25)
    ax.legend()
    for i, label in enumerate(labels):
        total = len(groups[label])
        rate = kept[i] / total if total else 0.0
        # Annotate each stacked bar with its per-task keep rate.
        ax.text(i, kept[i] + dropped[i] + 0.25, f"{rate:.0%}", ha="center", fontsize=9)
    fig.tight_layout()
    fig.savefig(path, dpi=170)
    plt.close(fig)


def _save_score_by_task(path: Path, rollouts: List[Dict[str, Any]], min_score: Optional[float]) -> None:
    """Jittered per-task scatter of filter scores, with the keep-threshold line."""
    if not rollouts:
        _save_placeholder(path, "RFT Score Distribution", "No RFT rollouts found.")
        return
    plt = _ensure_matplotlib()
    groups = _task_groups(rollouts)
    labels = list(groups)
    fig, ax = plt.subplots(figsize=(12, 5.8))
    for index, label in enumerate(labels):
        rows = groups[label]
        scores = [_as_float(row.get("score")) for row in rows]
        colors = ["#2ca25f" if row.get("kept") else "#d95f02" for row in rows]
        # Deterministic horizontal jitter so overlapping points stay visible.
        xs = [index + ((i % 7) - 3) * 0.025 for i in range(len(rows))]
        ax.scatter(xs, scores, c=colors, alpha=0.8, s=36, edgecolors="white", linewidths=0.4)
    if min_score is not None:
        ax.axhline(min_score, color="#333333", linestyle="--", linewidth=1.4, label=f"keep score >= {min_score:g}")
        # The threshold line is the only labeled artist, so only draw the
        # legend when it exists.
        ax.legend()
    ax.set_title("RFT Rollout Scores By Task")
    ax.set_ylabel("filter score")
    ax.set_xticks(range(len(labels)))
    ax.set_xticklabels(labels, rotation=25, ha="right")
    ax.grid(True, axis="y", alpha=0.25)
    fig.tight_layout()
    fig.savefig(path, dpi=170)
    plt.close(fig)


def _save_fp_by_task(path: Path, rollouts: List[Dict[str, Any]], max_fp: Optional[float]) -> None:
    """Jittered per-task scatter of false positives, with the keep-threshold line."""
    if not rollouts:
        _save_placeholder(path, "RFT False Positive Distribution", "No RFT rollouts found.")
        return
    plt = _ensure_matplotlib()
    groups = _task_groups(rollouts)
    labels = list(groups)
    fig, ax = plt.subplots(figsize=(12, 5.8))
    for index, label in enumerate(labels):
        rows = groups[label]
        fps = [_as_float(row.get("fp")) for row in rows]
        colors = ["#2ca25f" if row.get("kept") else "#d95f02" for row in rows]
        # Same deterministic jitter scheme as the score plot.
        xs = [index + ((i % 7) - 3) * 0.025 for i in range(len(rows))]
        ax.scatter(xs, fps, c=colors, alpha=0.8, s=36, edgecolors="white", linewidths=0.4)
    if max_fp is not None:
        ax.axhline(max_fp, color="#333333", linestyle="--", linewidth=1.4, label=f"keep fp <= {max_fp:g}")
        ax.legend()
    ax.set_title("RFT False Positives By Task")
    ax.set_ylabel("false positives / episode")
    ax.set_xticks(range(len(labels)))
    ax.set_xticklabels(labels, rotation=25, ha="right")
    ax.grid(True, axis="y", alpha=0.25)
    fig.tight_layout()
    fig.savefig(path, dpi=170)
    plt.close(fig)


def _save_score_vs_fp(path: Path, rollouts: List[Dict[str, Any]], min_score: Optional[float], max_fp: Optional[float]) -> None:
    """Score vs false-positive scatter showing the 2-D keep/reject boundary."""
    if not rollouts:
        _save_placeholder(path, "RFT Score vs False Positives", "No RFT rollouts found.")
        return
    plt = _ensure_matplotlib()
    groups = _task_groups(rollouts)
    palette = ["#1b9e77", "#7570b3", "#e7298a", "#66a61e", "#e6ab02", "#a6761d"]
    fig, ax = plt.subplots(figsize=(10.5, 6.2))
    for index, (task_id, rows) in enumerate(groups.items()):
        kept_rows = [row for row in rows if row.get("kept")]
        drop_rows = [row for row in rows if not row.get("kept")]
        color = palette[index % len(palette)]
        if drop_rows:
            ax.scatter(
                [_as_float(row.get("fp")) for row in drop_rows],
                [_as_float(row.get("score")) for row in drop_rows],
                marker="x",
                s=50,
                color=color,
                alpha=0.55,
                label=f"{task_id} rejected",
            )
        if kept_rows:
            ax.scatter(
                [_as_float(row.get("fp")) for row in kept_rows],
                [_as_float(row.get("score")) for row in kept_rows],
                marker="o",
                s=60,
                color=color,
                edgecolors="black",
                linewidths=0.4,
                label=f"{task_id} kept",
            )
    # Dashed guides mark the keep region: score above min_score, fp below max_fp.
    if min_score is not None:
        ax.axhline(min_score, color="#111111", linestyle="--", linewidth=1.2)
    if max_fp is not None:
        ax.axvline(max_fp, color="#111111", linestyle="--", linewidth=1.2)
    ax.set_title("RFT Filter Boundary: Keep High Score, Low False Positives")
    ax.set_xlabel("false positives / episode")
    ax.set_ylabel("filter score")
    ax.grid(True, alpha=0.25)
    ax.legend(fontsize=7, ncol=2)
    fig.tight_layout()
    fig.savefig(path, dpi=170)
    plt.close(fig)


def _save_timeline(path: Path, rollouts: List[Dict[str, Any]]) -> None:
    """Score trace in generation order plus a trailing 10-rollout keep rate."""
    if not rollouts:
        _save_placeholder(path, "RFT Rollout Timeline", "No RFT rollouts found.")
        return
    plt = _ensure_matplotlib()
    xs = list(range(1, len(rollouts) + 1))
    scores = [_as_float(row.get("score")) for row in rollouts]
    kept_x = [xs[i] for i, row in enumerate(rollouts) if row.get("kept")]
    kept_y = [scores[i] for i, row in enumerate(rollouts) if row.get("kept")]
    drop_x = [xs[i] for i, row in enumerate(rollouts) if not row.get("kept")]
    drop_y = [scores[i] for i, row in enumerate(rollouts) if not row.get("kept")]
    rolling_keep = []
    for index in range(len(rollouts)):
        # Keep rate over a trailing window of up to 10 rollouts.
        start = max(0, index - 9)
        window = rollouts[start : index + 1]
        rolling_keep.append(sum(1 for row in window if row.get("kept")) / len(window))
    fig, ax = plt.subplots(figsize=(12, 5.8))
    ax.plot(xs, scores, color="#6b7280", linewidth=1.1, alpha=0.65, label="score")
    ax.scatter(kept_x, kept_y, color="#2ca25f", s=45, label="kept")
    ax.scatter(drop_x, drop_y, color="#d95f02", marker="x", s=42, label="rejected")
    ax2 = ax.twinx()
    ax2.plot(xs, rolling_keep, color="#2563eb", linewidth=2, label="rolling keep rate")
    ax.set_title("RFT Rollout Timeline")
    ax.set_xlabel("generated rollout")
    ax.set_ylabel("filter score")
    ax2.set_ylabel("rolling keep rate")
    ax.grid(True, axis="y", alpha=0.25)
    # Merge the legends of both y-axes into a single box.
    lines, labels = ax.get_legend_handles_labels()
    lines2, labels2 = ax2.get_legend_handles_labels()
    ax.legend(lines + lines2, labels + labels2, loc="best")
    fig.tight_layout()
    fig.savefig(path, dpi=170)
    plt.close(fig)


def _save_eval_overview(path: Path, eval_report: Dict[str, Any]) -> None:
    """Grouped bar chart of overall baseline vs candidate eval metrics."""
    if not eval_report:
        _save_placeholder(path, "Held-Out Eval After RFT", "No eval report provided yet.")
        return
    plt = _ensure_matplotlib()
    metrics = [
        ("Mean score", "mean_score"),
        ("Detection", "detection_rate"),
        ("Risk reduction", "risk_reduction_rate"),
        ("Worker rehab", "worker_rehabilitation_rate"),
        ("False positive", "false_positive_rate"),
    ]
    baseline = _get(eval_report, "overall.baseline", {})
    candidate = _get(eval_report, "overall.candidate", {})
    labels = [label for label, _ in metrics]
    base_values = [_as_float(baseline.get(key)) for _, key in metrics]
    cand_values = [_as_float(candidate.get(key)) for _, key in metrics]
    xs = list(range(len(labels)))
    width = 0.38
    fig, ax = plt.subplots(figsize=(12, 5.8))
    ax.bar([x - width / 2 for x in xs], base_values, width=width, color="#d95f02", label=str(eval_report.get("baseline_label") or "baseline"))
    ax.bar([x + width / 2 for x in xs], cand_values, width=width, color="#2ca25f", label=str(eval_report.get("candidate_label") or "candidate"))
    ax.set_title("Held-Out Evaluation: Baseline vs RFT Candidate")
    ax.set_ylabel("rate / score")
    ax.set_xticks(xs)
    ax.set_xticklabels(labels, rotation=20, ha="right")
    ax.grid(True, axis="y", alpha=0.25)
    ax.legend()
    fig.tight_layout()
    fig.savefig(path, dpi=170)
    plt.close(fig)


def _save_eval_task_delta(path: Path, eval_report: Dict[str, Any]) -> None:
    """Per-task bar chart of candidate-minus-baseline mean score deltas."""
    per_task = _get(eval_report, "per_task", {})
    if not isinstance(per_task, dict) or not per_task:
        _save_placeholder(path, "RFT Held-Out Score Delta By Task", "No per-task eval rows found.")
        return
    labels = []
    deltas = []
    for task_id, payload in sorted(per_task.items()):
        baseline_score = _as_float(_get(payload, "baseline.mean_score"))
        candidate_score = _as_float(_get(payload, "candidate.mean_score"))
        labels.append(str(task_id))
        deltas.append(candidate_score - baseline_score)
    plt = _ensure_matplotlib()
    # Green for tasks that improved, orange for regressions.
    colors = ["#2ca25f" if value >= 0 else "#d95f02" for value in deltas]
    fig, ax = plt.subplots(figsize=(12, 5.8))
    ax.bar(labels, deltas, color=colors)
    ax.axhline(0.0, color="#111111", linewidth=1)
    ax.set_title("Held-Out Score Delta By Task")
    ax.set_ylabel("candidate mean score - baseline mean score")
    ax.tick_params(axis="x", rotation=25)
    ax.grid(True, axis="y", alpha=0.25)
    fig.tight_layout()
    fig.savefig(path, dpi=170)
    plt.close(fig)


def _write_markdown(
    path: Path,
    label: str,
    rollouts: List[Dict[str, Any]],
    kept: List[Dict[str, Any]],
    summary: Dict[str, Any],
    eval_report: Dict[str, Any],
    images: List[str],
) -> None:
    """Write the rft_proof.md report that ties the plots together."""
    total = len(rollouts)
    kept_count = len(kept)
    keep_rate = kept_count / total if total else 0.0
    mean_score_total = mean([_as_float(row.get("score")) for row in rollouts]) if rollouts else 0.0
    mean_score_kept = mean([_as_float(row.get("score")) for row in kept]) if kept else 0.0
    mean_fp_kept = mean([_as_float(row.get("fp")) for row in kept]) if kept else 0.0
    eval_overall = _get(eval_report, "overall", {})
    if eval_overall:
        intro = (
            "This folder is the rejection-sampling fine-tuning proof layer. "
            "It shows which model-generated rollouts were accepted, which were rejected, "
            "and what the held-out evaluation says after the polish pass."
        )
    else:
        intro = (
            "This folder is the rejection-sampling fine-tuning proof layer. "
            "It shows which model-generated rollouts were accepted, which were rejected, "
            "and which low-false-positive samples were used for the polish pass. "
            "Held-out model evaluation was intentionally omitted for this proof pack."
        )
    lines = [
        f"# {label} RFT Proof Pack",
        "",
        intro,
        "",
        "## Summary",
        "",
        f"- Total generated rollouts: `{total}`",
        f"- Kept rollouts used for SFT: `{kept_count}`",
        f"- Keep rate: `{keep_rate:.1%}`",
        f"- Mean rollout score: `{mean_score_total:.3f}`",
        f"- Mean kept score: `{mean_score_kept:.3f}`",
        f"- Mean kept false positives: `{mean_fp_kept:.2f}`",
    ]
    if summary:
        lines.extend([
            f"- RFT status: `{_get(summary, 'sft.status', summary.get('status', 'unknown'))}`",
            f"- Output adapter: `{_get(summary, 'output.final_dir', summary.get('final_dir', 'see RFT output dir'))}`",
        ])
    if eval_overall:
        lines.extend([
            "",
            "## Held-Out Eval",
            "",
            f"- Baseline mean score: `{_as_float(eval_overall.get('baseline_mean_score')):.3f}`",
            f"- Candidate mean score: `{_as_float(eval_overall.get('candidate_mean_score')):.3f}`",
            f"- Mean score delta: `{_as_float(eval_overall.get('mean_score_delta')):.3f}`",
            f"- Candidate risk reduction: `{_as_float(eval_overall.get('candidate_risk_reduction_rate')):.1%}`",
            f"- Candidate false-positive rate: `{_as_float(eval_overall.get('candidate_false_positive_rate')):.1%}`",
        ])
    lines.extend(["", "## Plots", ""])
    for image in images:
        title = Path(image).stem.replace("_", " ").title()
        lines.extend([f"### {title}", "", f"![{title}]({image})", ""])
    path.write_text("\n".join(lines) + "\n", encoding="utf-8")


def render_rft_proof(
    rft_dir: Path,
    output_dir: Path,
    eval_report_path: Optional[Path],
    label: str,
    min_score: Optional[float],
    max_fp: Optional[float],
) -> Dict[str, Any]:
    """Render all plots plus the manifest and markdown report; return the manifest."""
    output_dir.mkdir(parents=True, exist_ok=True)
    rollouts = _load_jsonl(rft_dir / "rollouts.jsonl")
    kept = [row for row in rollouts if row.get("kept")]
    summary = _load_json(rft_dir / "rft_summary.json")
    eval_report = _load_json(eval_report_path) if eval_report_path else {}
    # Fall back to the thresholds recorded in the run summary. NaN (the only
    # value where x != x) signals that the summary did not record one, in
    # which case no threshold line is drawn.
    if min_score is None:
        min_score = _as_float(_get(summary, "config.MIN_SCORE"), default=float("nan"))
        if min_score != min_score:
            min_score = None
    if max_fp is None:
        max_fp = _as_float(_get(summary, "config.MAX_FP"), default=float("nan"))
        if max_fp != max_fp:
            max_fp = None
    image_names = [
        "01_rft_keep_drop_by_task.png",
        "02_rft_score_distribution.png",
        "03_rft_false_positive_distribution.png",
        "04_rft_score_vs_fp_filter.png",
        "05_rft_rollout_timeline.png",
        "06_rft_eval_overview.png",
        "07_rft_eval_task_delta.png",
    ]
    _save_keep_drop(output_dir / image_names[0], rollouts)
    _save_score_by_task(output_dir / image_names[1], rollouts, min_score)
    _save_fp_by_task(output_dir / image_names[2], rollouts, max_fp)
    _save_score_vs_fp(output_dir / image_names[3], rollouts, min_score, max_fp)
    _save_timeline(output_dir / image_names[4], rollouts)
    _save_eval_overview(output_dir / image_names[5], eval_report)
    _save_eval_task_delta(output_dir / image_names[6], eval_report)
    manifest = {
        "label": label,
        "rft_dir": str(rft_dir),
        "eval_report_path": str(eval_report_path) if eval_report_path else "",
        "total_rollouts": len(rollouts),
        "kept_rollouts": len(kept),
        "keep_rate": len(kept) / len(rollouts) if rollouts else 0.0,
        "images": image_names,
    }
    (output_dir / "rft_plot_manifest.json").write_text(json.dumps(manifest, indent=2), encoding="utf-8")
    _write_markdown(output_dir / "rft_proof.md", label, rollouts, kept, summary, eval_report, image_names)
    return manifest


def main() -> None:
    parser = argparse.ArgumentParser(description="Render proof plots for a SENTINEL RFT polish run.")
    parser.add_argument("--rft-dir", default="/data/sentinel_outputs_rft_phase1_100", help="Directory containing rollouts.jsonl and rft_summary.json.")
    parser.add_argument("--eval-report", default="/data/rft_eval/sentinel_held_out_report.json", help="Optional held-out eval JSON report.")
    parser.add_argument("--output-dir", default="outputs/rft_phase1_100/plots", help="Where to write PNG plots and markdown.")
    parser.add_argument("--label", default="Phase 1 + RFT", help="Label used in the markdown report.")
    parser.add_argument("--min-score", type=float, default=None, help="Override score threshold line.")
    parser.add_argument("--max-fp", type=float, default=None, help="Override false-positive threshold line.")
    args = parser.parse_args()
    eval_report = Path(args.eval_report) if args.eval_report else None
    manifest = render_rft_proof(
        rft_dir=Path(args.rft_dir),
        output_dir=Path(args.output_dir),
        eval_report_path=eval_report,
        label=args.label,
        min_score=args.min_score,
        max_fp=args.max_fp,
    )
    print(json.dumps(manifest, indent=2))


if __name__ == "__main__":
    main()
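
# Example invocation (these are simply the argparse defaults above; adjust the
# paths to point at your own RFT run):
#   python openenv/scripts/render_rft_proof.py \
#       --rft-dir /data/sentinel_outputs_rft_phase1_100 \
#       --eval-report /data/rft_eval/sentinel_held_out_report.json \
#       --output-dir outputs/rft_phase1_100/plots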