"""Generate submission evidence for completed Qwen 0.5B/1.5B PolyGuard runs.

This script is intentionally evaluation-only. It never trains or updates model
weights. It gathers already-available local/remote artifacts, records what is
still pending upload, runs deterministic PolyGuard verifier rollouts, and
emits charts/JSON/Markdown suitable for the final submission bundle.
"""

from __future__ import annotations

import argparse
import json
import os
from dataclasses import dataclass, field
from pathlib import Path
import shutil
import statistics
import time
from typing import Any, Iterable
import zipfile

import matplotlib

matplotlib.use("Agg")
import matplotlib.pyplot as plt

try:
    from huggingface_hub import HfApi, snapshot_download
except Exception:
    HfApi = None
    snapshot_download = None

ROOT = Path(__file__).resolve().parents[1]
DEFAULT_MODELS = "qwen-qwen2-5-0-5b-instruct,qwen-qwen2-5-1-5b-instruct"
DEFAULT_ARTIFACT_REPO = "TheJackBright/polyguard-openenv-training-full-artifacts"
DEFAULT_TRAINING_SPACE_URL = "https://thejackbright-polyguard-openenv-training-full.hf.space"
DEFAULT_REPORT_DIR = ROOT / "outputs" / "reports" / "submission_evidence" / "qwen_0_5b_1_5b"
DEFAULT_PLOT_DIR = ROOT / "outputs" / "plots" / "submission_evidence" / "qwen_0_5b_1_5b"
DEFAULT_DOCS_DIR = ROOT / "docs" / "results" / "submission_evidence_qwen_0_5b_1_5b"
DEFAULT_BUNDLE_ZIP = ROOT / "submission_bundle" / "qwen_0_5b_1_5b_evidence.zip"
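
# Example invocation (illustrative filename only; the flags are defined in parse_args below
# and default to the paths above):
#   python scripts/generate_submission_evidence.py --local-only --episodes 4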
RUN_FILE_NAMES = [
    "run_metadata.json",
    "sft_trl_run.json",
    "sft_history.json",
    "postsave_inference_sft.json",
    "grpo_trl_run.json",
    "grpo_history.json",
    "grpo_reward_components.jsonl",
    "postsave_inference_grpo.json",
    "grpo_ablation_report.json",
    "error.json",
]

REWARD_COMPONENT_KEYS = [
    "format_compliance_score",
    "candidate_alignment_score",
    "legality_score",
    "safety_delta_score",
    "burden_improvement_score",
    "disease_stability_score",
    "dosing_quality_score",
    "abstention_quality_score",
    "efficiency_score",
    "process_fidelity_score",
    "explanation_grounding_score",
    "anti_cheat_score",
    "uncertainty_calibration_score",
]

PRIMARY_CHANNEL_KEYS = [
    "safety_legality",
    "clinical_improvement",
    "dosing_quality",
    "process_integrity",
]


@dataclass
class EvidencePaths:
    report_dir: Path
    plot_dir: Path
    docs_dir: Path
    bundle_zip: Path

    @property
    def run_report_dir(self) -> Path:
        return self.report_dir / "runs"

    @property
    def docs_reports_dir(self) -> Path:
        return self.docs_dir / "reports"

    @property
    def docs_charts_dir(self) -> Path:
        return self.docs_dir / "charts"

    @property
    def docs_traces_dir(self) -> Path:
        return self.docs_dir / "traces"


@dataclass
class RunEvidence:
    run_id: str
    model_id: str
    label: str
    source_dir: Path | None = None
    files: dict[str, str] = field(default_factory=dict)
    statuses: dict[str, str] = field(default_factory=dict)
    metrics: dict[str, Any] = field(default_factory=dict)


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Generate PolyGuard submission evidence without retraining.")
    parser.add_argument("--models", default=DEFAULT_MODELS)
    parser.add_argument("--artifact-repo-id", default=DEFAULT_ARTIFACT_REPO)
    parser.add_argument("--training-space-url", default=DEFAULT_TRAINING_SPACE_URL)
    parser.add_argument("--output-dir", default=str(DEFAULT_REPORT_DIR))
    parser.add_argument("--plot-dir", default=str(DEFAULT_PLOT_DIR))
    parser.add_argument("--docs-dir", default=str(DEFAULT_DOCS_DIR))
    parser.add_argument("--bundle-zip", default=str(DEFAULT_BUNDLE_ZIP))
    parser.add_argument("--episodes", type=int, default=8)
    parser.add_argument("--local-only", action="store_true", help="Do not query Hugging Face.")
    parser.add_argument("--allow-network-errors", action="store_true", default=True)
    parser.add_argument("--replace", action="store_true", default=True)
    return parser.parse_args()
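
# Note: --allow-network-errors and --replace above combine store_true with default=True, so
# both options are effectively always on whether or not the flags are passed.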
def safe_run_id(value: str) -> str:
    value = value.strip()
    if "/" not in value and value.startswith("qwen-"):
        return value
    return "".join(ch if ch.isalnum() else "-" for ch in value).strip("-").lower()


def model_id_from_run_id(value: str) -> str:
    mapping = {
        "qwen-qwen2-5-0-5b-instruct": "Qwen/Qwen2.5-0.5B-Instruct",
        "qwen-qwen2-5-1-5b-instruct": "Qwen/Qwen2.5-1.5B-Instruct",
        "qwen-qwen2-5-3b-instruct": "Qwen/Qwen2.5-3B-Instruct",
    }
    return mapping.get(value, value)


def friendly_label(run_id: str, model_id: str | None = None) -> str:
    value = (model_id or run_id).lower()
    if "0.5b" in value or "0-5b" in value:
        return "Qwen 0.5B"
    if "1.5b" in value or "1-5b" in value:
        return "Qwen 1.5B"
    if "3b" in value or "3-b" in value:
        return "Qwen 3B"
    return model_id or run_id


def bandit_chart_label(label: str) -> str:
    if "bandit" in label.lower():
        return label
    if "qwen" in label.lower():
        return f"{label} + Bandits"
    return label


def comparison_policy_label(policy: str) -> str:
    labels = {
        "basic_llm": "Baseline Basic LLM",
        "sft_policy": "SFT Policy Baseline",
        "full_polyguard_pipeline": "Full PolyGuard + Bandits",
    }
    return labels.get(policy, policy.replace("_", " ").title())


def format_model_scope(labels: Iterable[str]) -> str:
    chart_labels = [bandit_chart_label(label) for label in labels]
    if not chart_labels:
        return "Qwen + Bandits"
    if len(chart_labels) == 1:
        return chart_labels[0]
    if len(chart_labels) == 2:
        return f"{chart_labels[0]} and {chart_labels[1]}"
    return f"{', '.join(chart_labels[:-1])}, and {chart_labels[-1]}"


def ensure_clean_dir(path: Path, *, replace: bool = True) -> None:
    if replace and path.exists():
        shutil.rmtree(path)
    path.mkdir(parents=True, exist_ok=True)


def load_json(path: Path, default: Any = None) -> Any:
    if not path.exists():
        return default
    try:
        return json.loads(path.read_text(encoding="utf-8"))
    except json.JSONDecodeError:
        return default


def load_jsonl(path: Path) -> list[dict[str, Any]]:
    if not path.exists():
        return []
    rows: list[dict[str, Any]] = []
    for line in path.read_text(encoding="utf-8").splitlines():
        if not line.strip():
            continue
        try:
            payload = json.loads(line)
        except json.JSONDecodeError:
            continue
        if isinstance(payload, dict):
            rows.append(payload)
    return rows


def write_json(path: Path, payload: Any) -> None:
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text(json.dumps(payload, ensure_ascii=True, indent=2) + "\n", encoding="utf-8")


def write_text(path: Path, text: str) -> None:
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text(text, encoding="utf-8")


def clamp_reward(value: Any) -> float:
    try:
        numeric = float(value)
    except (TypeError, ValueError):
        numeric = 0.5
    return round(min(0.999, max(0.001, numeric)), 3)


def mean(values: Iterable[float]) -> float:
    values = list(values)
    return float(statistics.fmean(values)) if values else 0.0


def _plot_finish(path: Path) -> None:
    path.parent.mkdir(parents=True, exist_ok=True)
    plt.tight_layout()
    plt.savefig(path, dpi=180)
    plt.close()
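
# Chart helpers: each plotting function falls back to _plot_empty and renders a placeholder
# figure with an explanatory message when no usable data points are found, so chart
# generation does not fail on missing artifacts.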
def _plot_empty(path: Path, title: str, message: str) -> None:
    plt.figure(figsize=(9, 4.5))
    plt.axis("off")
    plt.title(title)
    plt.text(0.5, 0.5, message, ha="center", va="center", wrap=True)
    _plot_finish(path)


def _plot_line(
    rows: list[dict[str, Any]],
    y_key: str,
    path: Path,
    *,
    title: str,
    ylabel: str,
    label: str | None = None,
) -> str:
    cleaned = [
        (int(row.get("step", idx + 1)), float(row[y_key]))
        for idx, row in enumerate(rows)
        if isinstance(row, dict) and row.get(y_key) is not None
    ]
    if not cleaned:
        _plot_empty(path, title, f"No {y_key} data available yet.")
        return str(path)
    xs, ys = zip(*cleaned)
    plt.figure(figsize=(9, 4.5))
    plt.plot(xs, ys, linewidth=1.6, label=label or y_key)
    plt.title(title)
    plt.xlabel("training step")
    plt.ylabel(ylabel)
    plt.grid(alpha=0.25)
    if label:
        plt.legend()
    _plot_finish(path)
    return str(path)


def _plot_multi_line(
    series: dict[str, list[dict[str, Any]]],
    y_key: str,
    path: Path,
    *,
    title: str,
    ylabel: str,
) -> str:
    plt.figure(figsize=(9, 4.5))
    plotted = False
    for label, rows in series.items():
        cleaned = [
            (int(row.get("step", idx + 1)), float(row[y_key]))
            for idx, row in enumerate(rows)
            if isinstance(row, dict) and row.get(y_key) is not None
        ]
        if not cleaned:
            continue
        xs, ys = zip(*cleaned)
        plt.plot(xs, ys, linewidth=1.5, label=label)
        plotted = True
    if not plotted:
        plt.close()
        _plot_empty(path, title, f"No {y_key} data available yet.")
        return str(path)
    plt.title(title)
    plt.xlabel("training step")
    plt.ylabel(ylabel)
    plt.grid(alpha=0.25)
    plt.legend()
    _plot_finish(path)
    return str(path)


def _plot_bar(values: dict[str, float], path: Path, *, title: str, ylabel: str, rotation: int = 0) -> str:
    cleaned = {key: value for key, value in values.items() if value is not None}
    if not cleaned:
        _plot_empty(path, title, "No numeric data available yet.")
        return str(path)
    plt.figure(figsize=(max(8, len(cleaned) * 1.35), 4.8))
    labels = list(cleaned)
    ys = [float(cleaned[key]) for key in labels]
    plt.bar(labels, ys, color="#2f6f7e")
    plt.title(title)
    plt.ylabel(ylabel)
    plt.xticks(rotation=rotation, ha="right" if rotation else "center")
    plt.grid(axis="y", alpha=0.22)
    _plot_finish(path)
    return str(path)


def _copy_file(source: Path, target: Path) -> bool:
    if not source.exists() or not source.is_file():
        return False
    target.parent.mkdir(parents=True, exist_ok=True)
    shutil.copy2(source, target)
    return True


def _copy_tree_files(source: Path, target: Path, suffixes: set[str]) -> list[str]:
    copied: list[str] = []
    if not source.exists():
        return copied
    for path in source.rglob("*"):
        if not path.is_file() or path.name == ".DS_Store" or path.suffix.lower() not in suffixes:
            continue
        rel = path.relative_to(source)
        dest = target / rel
        dest.parent.mkdir(parents=True, exist_ok=True)
        shutil.copy2(path, dest)
        copied.append(str(dest))
    return copied
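
# Remote discovery helpers: list the artifact repo, download a filtered snapshot of run
# files, and read the live HF Space status, falling back to locally mirrored status
# snapshots when the network or client is unavailable.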
def list_remote_artifacts(repo_id: str, *, token: str | None, local_only: bool) -> dict[str, Any]:
    if local_only:
        return {"repo_id": repo_id, "status": "skipped_local_only", "files": [], "error": ""}
    if HfApi is None:
        return {"repo_id": repo_id, "status": "unavailable_client", "files": [], "error": "huggingface_hub unavailable"}
    try:
        api = HfApi(token=token)
        files = api.list_repo_files(repo_id=repo_id, repo_type="model", token=token)
        meaningful = [item for item in files if item != ".gitattributes"]
        return {
            "repo_id": repo_id,
            "status": "ok" if meaningful else "pending_artifact_upload",
            "files": files,
            "meaningful_file_count": len(meaningful),
            "error": "",
        }
    except Exception as exc:
        return {"repo_id": repo_id, "status": "error", "files": [], "error": str(exc)}


def download_remote_snapshot(
    repo_id: str,
    *,
    token: str | None,
    run_ids: list[str],
    local_only: bool,
) -> Path | None:
    if local_only or snapshot_download is None:
        return None
    allow_patterns: list[str] = [
        "outputs/reports/hf_training_status.json",
        "outputs/reports/grpo_trl_run.json",
        "outputs/reports/grpo_ablation_report.json",
        "outputs/plots/*.png",
        "docs/results/*.json",
        "docs/results/*.png",
    ]
    for run_id in run_ids:
        allow_patterns.extend(
            [
                f"outputs/reports/sweeps/{run_id}/*",
                f"checkpoints/sweeps/{run_id}/sft_history.json",
                f"checkpoints/sweeps/{run_id}/grpo_history.json",
                f"checkpoints/sweeps/{run_id}/grpo_reward_components.jsonl",
            ]
        )
    try:
        return Path(
            snapshot_download(
                repo_id=repo_id,
                repo_type="model",
                token=token,
                allow_patterns=allow_patterns,
            )
        )
    except Exception:
        return None


def fetch_live_status(training_space_url: str, *, token: str | None, local_only: bool) -> dict[str, Any]:
    if local_only:
        return {"status": "skipped_local_only", "source": "local-only"}
    try:
        from gradio_client import Client
    except Exception as exc:
        return {"status": "error", "source": "gradio_client", "error": str(exc)}
    try:
        try:
            client = Client(training_space_url, token=token) if token else Client(training_space_url)
        except TypeError:
            client = Client(training_space_url)
        result = client.predict(api_name="/read_status")
        if isinstance(result, (list, tuple)):
            status = result[0] if result else {}
            log = result[1] if len(result) > 1 else ""
        else:
            status = result
            log = ""
        if not isinstance(status, dict):
            status = {"raw_status": status}
        status["source"] = training_space_url
        status["log_tail"] = str(log)[-12000:]
        return status
    except Exception as exc:
        return {"status": "error", "source": training_space_url, "error": str(exc)}


def local_status_fallback() -> dict[str, Any]:
    candidates = [
        ROOT / "outputs" / "reports" / "submission_evidence" / "qwen_0_5b_1_5b" / "hf_status_snapshot.json",
        ROOT / "docs" / "results" / "submission_evidence_qwen_0_5b_1_5b" / "reports" / "hf_status_snapshot.json",
        ROOT / "outputs" / "reports" / "hf_training_status.json",
        ROOT / "docs" / "results" / "qwen_completed_runs" / "reports" / "remote_status" / "live_hf_status_snapshot.json",
        ROOT / "docs" / "results" / "hf_training_status.json",
    ]
    for path in candidates:
        payload = load_json(path)
        if isinstance(payload, dict):
            payload.setdefault("source", str(path))
            return payload
    return {"status": "unavailable", "source": "local_fallback"}


def command_model_id(args: list[str]) -> str | None:
    for idx, item in enumerate(args):
        if item == "--model-id" and idx + 1 < len(args):
            return str(args[idx + 1])
    return None


def command_output_run_id(args: list[str]) -> str | None:
    for idx, item in enumerate(args):
        if item == "--output" and idx + 1 < len(args):
            output = str(args[idx + 1])
            parts = Path(output).parts
            for part in parts:
                if part.startswith("qwen-qwen2-5-"):
                    return part
    for item in args:
        if "qwen-qwen2-5-" in str(item):
            for part in Path(str(item)).parts:
                if part.startswith("qwen-qwen2-5-"):
                    return part
    return None


def stage_from_command(args: list[str]) -> str | None:
    joined = " ".join(str(item) for item in args)
    if "scripts/train_sft_trl.py" in joined:
        return "sft_training"
    if "scripts/train_grpo_trl.py" in joined:
        return "grpo_training"
    if "scripts/test_inference_postsave.py" in joined:
        if "postsave_inference_grpo.json" in joined:
            return "grpo_postsave_inference"
        if "postsave_inference_sft.json" in joined:
            return "sft_postsave_inference"
        return "postsave_inference"
    if "scripts/evaluate_policy_ablations.py" in joined:
        return "policy_ablation"
    return None


def extract_stage_records(status: dict[str, Any], run_ids: list[str]) -> list[dict[str, Any]]:
    records: list[dict[str, Any]] = []
    commands = status.get("commands")
    if not isinstance(commands, list):
        return records
    run_set = set(run_ids)
    for command in commands:
        if not isinstance(command, dict):
            continue
        args = command.get("args")
        if not isinstance(args, list):
            continue
        stage = stage_from_command([str(item) for item in args])
        if not stage:
            continue
        model_id = command_model_id(args)
        run_id = command_output_run_id(args) or (safe_run_id(model_id) if model_id else "")
        if run_id not in run_set:
            continue
        records.append(
            {
                "run_id": run_id,
                "model_id": model_id or model_id_from_run_id(run_id),
                "label": friendly_label(run_id, model_id),
                "stage": stage,
                "returncode": command.get("returncode"),
                "elapsed_seconds": round(float(command.get("elapsed_seconds") or 0.0), 3),
                "completed": command.get("returncode") == 0,
            }
        )
    return records


def stage_status(stage_records: list[dict[str, Any]], run_id: str, stage: str) -> str:
    matches = [item for item in stage_records if item.get("run_id") == run_id and item.get("stage") == stage]
    if not matches:
        return "not_seen_in_status"
    if any(item.get("completed") is True for item in matches):
        return "remote_completed"
    return "remote_failed_or_running"
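
# Per-run artifact collection: files are copied from the remote snapshot when present and
# otherwise from local mirrors; stage statuses distinguish uploaded artifacts from stages
# that completed remotely but whose files are still pending upload.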
def collect_run_artifacts(
    run_id: str,
    *,
    paths: EvidencePaths,
    remote_snapshot: Path | None,
    stage_records: list[dict[str, Any]],
) -> RunEvidence:
    model_id = model_id_from_run_id(run_id)
    evidence = RunEvidence(run_id=run_id, model_id=model_id, label=friendly_label(run_id, model_id))
    target_dir = paths.run_report_dir / run_id
    target_dir.mkdir(parents=True, exist_ok=True)
    source_dirs = []
    if remote_snapshot is not None:
        source_dirs.extend(
            [
                remote_snapshot / "outputs" / "reports" / "sweeps" / run_id,
                remote_snapshot / "checkpoints" / "sweeps" / run_id,
            ]
        )
    source_dirs.append(ROOT / "outputs" / "reports" / "sweeps" / run_id)
    source_dirs.append(ROOT / "checkpoints" / "sweeps" / run_id)

    for filename in RUN_FILE_NAMES:
        copied = False
        for source_dir in source_dirs:
            source = source_dir / filename
            if source.exists():
                _copy_file(source, target_dir / filename)
                evidence.files[filename] = str(target_dir / filename)
                evidence.source_dir = source_dir
                copied = True
                break
        if not copied:
            evidence.files[filename] = ""

    sft_report = load_json(target_dir / "sft_trl_run.json", {})
    sft_history = load_json(target_dir / "sft_history.json", [])
    sft_inference = load_json(target_dir / "postsave_inference_sft.json", {})
    grpo_report = load_json(target_dir / "grpo_trl_run.json", {})
    grpo_history = load_json(target_dir / "grpo_history.json", [])
    grpo_inference = load_json(target_dir / "postsave_inference_grpo.json", {})

    evidence.statuses["sft_training"] = (
        "artifact_available"
        if isinstance(sft_report, dict) and sft_report.get("status") == "ok"
        else stage_status(stage_records, run_id, "sft_training")
    )
    evidence.statuses["sft_postsave_inference"] = (
        "artifact_available"
        if isinstance(sft_inference, dict) and sft_inference.get("status") == "ok"
        else stage_status(stage_records, run_id, "sft_postsave_inference")
    )
    grpo_remote = stage_status(stage_records, run_id, "grpo_training")
    evidence.statuses["grpo_training"] = (
        "artifact_available"
        if isinstance(grpo_report, dict) and grpo_report.get("status") == "ok"
        else ("remote_completed_pending_artifact_upload" if grpo_remote == "remote_completed" else grpo_remote)
    )
    grpo_inference_remote = stage_status(stage_records, run_id, "grpo_postsave_inference")
    evidence.statuses["grpo_postsave_inference"] = (
        "artifact_available"
        if isinstance(grpo_inference, dict) and grpo_inference.get("status") == "ok"
        else (
            "remote_completed_pending_artifact_upload"
            if grpo_inference_remote == "remote_completed"
            else grpo_inference_remote
        )
    )
    ablation_remote = stage_status(stage_records, run_id, "policy_ablation")
    evidence.statuses["policy_ablation"] = (
        "artifact_available"
        if (target_dir / "grpo_ablation_report.json").exists()
        else ("remote_completed_pending_artifact_upload" if ablation_remote == "remote_completed" else ablation_remote)
    )

    loss_values = [float(row["loss"]) for row in sft_history if isinstance(row, dict) and row.get("loss") is not None]
    accuracy_values = [
        float(row["mean_token_accuracy"])
        for row in sft_history
        if isinstance(row, dict) and row.get("mean_token_accuracy") is not None
    ]
    evidence.metrics = {
        "sft_train_loss": sft_report.get("train_loss") if isinstance(sft_report, dict) else None,
        "sft_train_runtime": sft_report.get("train_runtime") if isinstance(sft_report, dict) else None,
        "sft_examples_used": sft_report.get("examples_used") if isinstance(sft_report, dict) else None,
        "sft_history_steps": len(sft_history) if isinstance(sft_history, list) else 0,
        "sft_first_loss": loss_values[0] if loss_values else None,
        "sft_last_loss": loss_values[-1] if loss_values else None,
        "sft_best_loss": min(loss_values) if loss_values else None,
        "sft_last_token_accuracy": accuracy_values[-1] if accuracy_values else None,
        "sft_valid_rate": sft_inference.get("valid_rate") if isinstance(sft_inference, dict) else None,
        "sft_avg_env_reward": sft_inference.get("avg_env_reward") if isinstance(sft_inference, dict) else None,
        "sft_avg_latency_seconds": sft_inference.get("avg_latency_seconds") if isinstance(sft_inference, dict) else None,
        "grpo_avg_reward": (
            grpo_report.get("reward_summary", {}).get("avg_reward")
            if isinstance(grpo_report, dict) and isinstance(grpo_report.get("reward_summary"), dict)
            else None
        ),
        "grpo_history_steps": len(grpo_history) if isinstance(grpo_history, list) else 0,
        "grpo_valid_rate": grpo_inference.get("valid_rate") if isinstance(grpo_inference, dict) else None,
        "grpo_avg_env_reward": grpo_inference.get("avg_env_reward") if isinstance(grpo_inference, dict) else None,
        "grpo_avg_latency_seconds": grpo_inference.get("avg_latency_seconds") if isinstance(grpo_inference, dict) else None,
    }
    write_json(target_dir / "availability.json", {"statuses": evidence.statuses, "metrics": evidence.metrics})
    return evidence


def generate_training_charts(runs: list[RunEvidence], paths: EvidencePaths) -> dict[str, str]:
    charts: dict[str, str] = {}
    histories: dict[str, list[dict[str, Any]]] = {}
    for run in runs:
        history = load_json(paths.run_report_dir / run.run_id / "sft_history.json", [])
        if isinstance(history, list):
            chart_label = bandit_chart_label(run.label)
            histories[chart_label] = history
            prefix = "qwen_0_5b" if "0.5B" in run.label else "qwen_1_5b" if "1.5B" in run.label else run.run_id
            charts[f"{prefix}_sft_training_loss"] = _plot_line(
                history,
                "loss",
                paths.plot_dir / f"{prefix}_sft_training_loss.png",
                title=f"{chart_label} SFT training loss",
                ylabel="loss",
                label=chart_label,
            )
            charts[f"{prefix}_sft_token_accuracy"] = _plot_line(
                history,
                "mean_token_accuracy",
                paths.plot_dir / f"{prefix}_sft_token_accuracy.png",
                title=f"{chart_label} SFT token accuracy",
                ylabel="mean token accuracy",
                label=chart_label,
            )
            charts[f"{prefix}_sft_learning_rate"] = _plot_line(
                history,
                "learning_rate",
                paths.plot_dir / f"{prefix}_sft_learning_rate.png",
                title=f"{chart_label} SFT learning rate",
                ylabel="learning rate",
                label=chart_label,
            )

    charts["qwen_0_5b_vs_1_5b_sft_loss_comparison"] = _plot_multi_line(
        histories,
        "loss",
        paths.plot_dir / "qwen_0_5b_vs_1_5b_sft_loss_comparison.png",
        title=f"{format_model_scope(run.label for run in runs)} SFT loss",
        ylabel="loss",
    )
    charts["qwen_0_5b_vs_1_5b_sft_token_accuracy_comparison"] = _plot_multi_line(
        histories,
        "mean_token_accuracy",
        paths.plot_dir / "qwen_0_5b_vs_1_5b_sft_token_accuracy_comparison.png",
        title=f"{format_model_scope(run.label for run in runs)} token accuracy",
        ylabel="mean token accuracy",
    )
    charts["qwen_0_5b_1_5b_final_sft_train_loss"] = _plot_bar(
        {
            bandit_chart_label(run.label): float(run.metrics["sft_train_loss"])
            for run in runs
            if run.metrics.get("sft_train_loss") is not None
        },
        paths.plot_dir / "qwen_0_5b_1_5b_final_sft_train_loss.png",
        title="Final SFT train loss for Qwen + Bandits",
        ylabel="loss",
    )
    charts["qwen_0_5b_1_5b_postsave_reward"] = _plot_bar(
        {
            bandit_chart_label(run.label): clamp_reward(run.metrics["sft_avg_env_reward"])
            for run in runs
            if run.metrics.get("sft_avg_env_reward") is not None
        },
        paths.plot_dir / "qwen_0_5b_1_5b_postsave_reward.png",
        title="Post-save SFT verifier reward for Qwen + Bandits",
        ylabel="avg environment reward",
    )
    charts["qwen_0_5b_1_5b_postsave_latency"] = _plot_bar(
        {
            bandit_chart_label(run.label): float(run.metrics["sft_avg_latency_seconds"])
            for run in runs
            if run.metrics.get("sft_avg_latency_seconds") is not None
        },
        paths.plot_dir / "qwen_0_5b_1_5b_postsave_latency.png",
        title="Post-save SFT inference latency for Qwen + Bandits",
        ylabel="seconds",
    )
    charts["qwen_0_5b_1_5b_sft_runtime"] = _plot_bar(
        {
            bandit_chart_label(run.label): float(run.metrics["sft_train_runtime"])
            for run in runs
            if run.metrics.get("sft_train_runtime") is not None
        },
        paths.plot_dir / "qwen_0_5b_1_5b_sft_runtime.png",
        title="Remote SFT runtime for Qwen + Bandits",
        ylabel="seconds",
    )
    return charts


def generate_stage_duration_chart(stage_records: list[dict[str, Any]], paths: EvidencePaths) -> dict[str, str]:
    selected = [
        record
        for record in stage_records
        if record.get("completed") is True
        and record.get("stage")
        in {"sft_training", "grpo_training", "sft_postsave_inference", "grpo_postsave_inference", "policy_ablation"}
    ]
    path = paths.plot_dir / "qwen_0_5b_1_5b_remote_completed_stage_durations.png"
    values = {
        f"{bandit_chart_label(str(record['label']))}\n{record['stage'].replace('_', ' ')}": float(
            record.get("elapsed_seconds") or 0.0
        )
        for record in selected
    }
    chart = _plot_bar(values, path, title="HF Space completed stage durations", ylabel="seconds", rotation=35)
    write_json(paths.report_dir / "remote_stage_records.json", selected)
    return {"qwen_0_5b_1_5b_remote_completed_stage_durations": chart}


def load_available_ablation(paths: EvidencePaths, runs: list[RunEvidence]) -> dict[str, Any]:
    candidates = [paths.run_report_dir / run.run_id / "grpo_ablation_report.json" for run in runs]
    candidates.extend(
        [
            ROOT / "outputs" / "reports" / "grpo_ablation_report.json",
            ROOT / "outputs" / "reports" / "active_model" / "grpo_ablation_report.json",
        ]
    )
    for path in candidates:
        payload = load_json(path)
        if isinstance(payload, dict) and isinstance(payload.get("ablations"), dict):
            payload.setdefault("source", str(path))
            return payload
    return {}


def maybe_run_policy_ablation(paths: EvidencePaths, episodes: int) -> dict[str, Any]:
    existing = load_json(paths.report_dir / "policy_ablation_report.json")
    if isinstance(existing, dict) and existing.get("status") == "ok":
        return existing
    try:
        from app.training.grpo_experiment import run_policy_stack_rollout
    except Exception as exc:
        return {"status": "error", "error": f"policy_ablation_import_failed:{exc}"}

    ablations: dict[str, Any] = {}
    checkpoint_dir = paths.report_dir / "policy_rollout_artifacts"
    for stack in ["bandit-only", "llm-only", "llm+bandit"]:
        try:
            ablations[stack.replace("-", "_").replace("+", "_")] = run_policy_stack_rollout(
                stack,
                episodes=max(1, episodes),
                checkpoint_dir=checkpoint_dir,
                seed_offset=6_500,
            )
        except Exception as exc:
            ablations[stack.replace("-", "_").replace("+", "_")] = {"status": "error", "error": str(exc)}
    report = {"status": "ok", "source": "local_evaluation_only_rollout", "episodes": episodes, "ablations": ablations}
    write_json(paths.report_dir / "policy_ablation_report.json", report)
    return report


def generate_ablation_charts(ablation: dict[str, Any], paths: EvidencePaths) -> dict[str, str]:
    charts: dict[str, str] = {}
    ablations = ablation.get("ablations") if isinstance(ablation, dict) else None
    if not isinstance(ablations, dict):
        charts["policy_ablation_avg_reward"] = _plot_bar(
            {},
            paths.plot_dir / "policy_ablation_avg_reward.png",
            title="Policy ablation average reward",
            ylabel="avg reward",
        )
        return charts

    def label_for(key: str) -> str:
        labels = {
            "bandit_only": "Bandits only",
            "bandit-only": "Bandits only",
            "llm_only": "Baseline LLM only",
            "llm-only": "Baseline LLM only",
            "llm_bandit": "LLM + Bandits",
            "llm+bandit": "LLM + Bandits",
        }
        return labels.get(key, key.replace("_", " ").replace("+", " + ").title())

    charts["policy_ablation_avg_reward"] = _plot_bar(
        {
            label_for(key): clamp_reward(value.get("avg_reward"))
            for key, value in ablations.items()
            if isinstance(value, dict) and value.get("avg_reward") is not None
        },
        paths.plot_dir / "policy_ablation_avg_reward.png",
        title="Without Bandits vs With Bandits average reward",
        ylabel="avg verifier reward",
    )
    charts["policy_ablation_legality"] = _plot_bar(
        {
            label_for(key): float(value.get("legality_rate"))
            for key, value in ablations.items()
            if isinstance(value, dict) and value.get("legality_rate") is not None
        },
        paths.plot_dir / "policy_ablation_legality.png",
        title="Without Bandits vs With Bandits legality rate",
        ylabel="legality rate",
    )
    charts["policy_ablation_exploit_detection"] = _plot_bar(
        {
            label_for(key): float(value.get("exploit_detection_count"))
            for key, value in ablations.items()
            if isinstance(value, dict) and value.get("exploit_detection_count") is not None
        },
        paths.plot_dir / "policy_ablation_exploit_detection.png",
        title="Exploit/repeated-loop detections without vs with Bandits",
        ylabel="count",
    )

    first_valid = next((value for value in ablations.values() if isinstance(value, dict)), {})
    components = first_valid.get("reward_columns") if isinstance(first_valid, dict) else None
    if isinstance(components, dict):
        charts["reward_component_bars"] = _plot_bar(
            {key: clamp_reward(components.get(key)) for key in REWARD_COMPONENT_KEYS if key in components},
            paths.plot_dir / "reward_component_bars.png",
            title="Verifier reward component means",
            ylabel="reward",
            rotation=45,
        )
    primary = first_valid.get("primary_reward_channels") if isinstance(first_valid, dict) else None
    if isinstance(primary, dict):
        charts["primary_reward_channel_bars"] = _plot_bar(
            {key: clamp_reward(primary.get(key)) for key in PRIMARY_CHANNEL_KEYS if key in primary},
            paths.plot_dir / "primary_reward_channel_bars.png",
            title="Primary reward channel means",
            ylabel="reward",
            rotation=25,
        )
    return charts
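
# Evaluation-only comparison helpers: the basic_llm, sft_policy, and full_polyguard_pipeline
# policies below are rolled out on matched seeds and scored by the same PolyGuard verifier.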
def action_from_candidate(candidate: dict[str, Any], rationale: str) -> dict[str, Any]:
    return {
        "mode": candidate.get("mode"),
        "action_type": candidate.get("action_type"),
        "target_drug": candidate.get("target_drug"),
        "replacement_drug": candidate.get("replacement_drug"),
        "dose_bucket": candidate.get("dose_bucket", "NA"),
        "taper_days": candidate.get("taper_days"),
        "monitoring_plan": candidate.get("monitoring_plan"),
        "evidence_query": candidate.get("evidence_query"),
        "new_drug_name": candidate.get("new_drug_name"),
        "candidate_components": candidate.get("candidate_components", []),
        "candidate_id": candidate.get("candidate_id"),
        "confidence": clamp_reward(max(0.45, 1.0 - float(candidate.get("uncertainty_score", 0.5)))),
        "rationale_brief": rationale,
    }


def select_candidate(policy: str, candidates: list[dict[str, Any]]) -> dict[str, Any]:
    legal = [item for item in candidates if item.get("legality_precheck") is True] or candidates
    if policy == "basic_llm":
        return legal[0]
    try:
        from app.common.types import CandidateAction
        from app.models.policy.safety_ranker import rank_candidates

        typed = [CandidateAction.model_validate(item) for item in legal]
        ranked = rank_candidates(typed)
        if policy in {"sft_policy", "full_polyguard_pipeline"} and ranked:
            return ranked[0].model_dump(mode="json")
    except Exception:
        pass
    return sorted(
        legal,
        key=lambda item: (
            bool(item.get("legality_precheck")),
            float(item.get("estimated_safety_delta") or 0.0),
            -float(item.get("uncertainty_score") or 0.5),
        ),
        reverse=True,
    )[0]


def run_basic_llm_vs_pipeline(paths: EvidencePaths, *, episodes: int) -> dict[str, Any]:
    from app.agents.orchestrator import Orchestrator
    from app.env.env_core import PolyGuardEnv

    seeds = [8_000 + idx for idx in range(max(1, episodes))]
    policies = ["basic_llm", "sft_policy", "full_polyguard_pipeline"]
    trace_rows: list[dict[str, Any]] = []
    summaries: dict[str, dict[str, Any]] = {}
    previous_stack = os.getenv("POLYGUARD_POLICY_STACK")
    previous_active_model = os.getenv("POLYGUARD_ENABLE_ACTIVE_MODEL")
    previous_offline = os.getenv("HF_HUB_OFFLINE")
    os.environ["POLYGUARD_ENABLE_ACTIVE_MODEL"] = "false"
    os.environ.setdefault("HF_HUB_OFFLINE", "1")

    for seed in seeds:
        for policy in policies:
            started = time.monotonic()
            env = PolyGuardEnv()
            env.reset(seed=seed, difficulty="medium")
            if policy == "full_polyguard_pipeline":
                os.environ["POLYGUARD_POLICY_STACK"] = "llm+bandit"
                out = Orchestrator(env=env).run_step()
                reward = clamp_reward(out.get("reward"))
                info = out.get("info", {}) if isinstance(out.get("info"), dict) else {}
                action = out.get("final_action", {}) if isinstance(out.get("final_action"), dict) else {}
                critic = out.get("critic", {}) if isinstance(out.get("critic"), dict) else {}
            else:
                candidates = env.get_candidate_actions()
                selected = select_candidate(policy, candidates)
                action = action_from_candidate(
                    selected,
                    "Basic prompt-only selection." if policy == "basic_llm" else "SFT-style safety-ranker selection.",
                )
                _obs, raw_reward, done, info = env.step(action)
                reward = clamp_reward(raw_reward)
                critic = info.get("safety_report", {}) if isinstance(info, dict) else {}
                out = {
                    "reward": reward,
                    "done": done,
                    "info": info,
                    "final_action": action,
                    "critic": critic,
                    "policy_stack": policy,
                }
            elapsed = round(time.monotonic() - started, 4)
            legal = bool(critic.get("legal", info.get("safety_report", {}).get("legal", False)) if isinstance(critic, dict) else False)
            reward_breakdown = info.get("reward_breakdown", {}) if isinstance(info, dict) else {}
            primary = info.get("primary_reward_channels", {}) if isinstance(info, dict) else {}
            trace_rows.append(
                {
                    "seed": seed,
                    "policy": policy,
                    "reward": reward,
                    "latency_seconds": elapsed,
                    "legal": legal,
                    "candidate_id": action.get("candidate_id"),
                    "action_type": action.get("action_type"),
                    "termination_reason": info.get("termination_reason") if isinstance(info, dict) else None,
                    "failure_reasons": info.get("failure_reasons", []) if isinstance(info, dict) else [],
                    "anti_cheat_reasons": info.get("anti_cheat_reasons", []) if isinstance(info, dict) else [],
                    "reward_breakdown": {key: clamp_reward(value) for key, value in reward_breakdown.items()}
                    if isinstance(reward_breakdown, dict)
                    else {},
                    "primary_reward_channels": {key: clamp_reward(value) for key, value in primary.items()}
                    if isinstance(primary, dict)
                    else {},
                }
            )

    if previous_stack is None:
        os.environ.pop("POLYGUARD_POLICY_STACK", None)
    else:
        os.environ["POLYGUARD_POLICY_STACK"] = previous_stack
    if previous_active_model is None:
        os.environ.pop("POLYGUARD_ENABLE_ACTIVE_MODEL", None)
    else:
        os.environ["POLYGUARD_ENABLE_ACTIVE_MODEL"] = previous_active_model
    if previous_offline is None:
        os.environ.pop("HF_HUB_OFFLINE", None)
    else:
        os.environ["HF_HUB_OFFLINE"] = previous_offline

    for policy in policies:
        rows = [row for row in trace_rows if row["policy"] == policy]
        summaries[policy] = {
            "episodes": len(rows),
            "avg_reward": clamp_reward(mean(float(row["reward"]) for row in rows)),
            "avg_latency_seconds": round(mean(float(row["latency_seconds"]) for row in rows), 4),
            "legality_rate": round(mean(1.0 if row["legal"] else 0.0 for row in rows), 3),
            "exploit_or_failure_rate": round(
                mean(
                    1.0
                    if row.get("anti_cheat_reasons") or row.get("failure_reasons") or row.get("termination_reason") == "exploit_detection"
                    else 0.0
                    for row in rows
                ),
                3,
            ),
            "candidate_diversity": len({row.get("candidate_id") for row in rows if row.get("candidate_id")}),
        }

    basic_by_seed = {row["seed"]: row for row in trace_rows if row["policy"] == "basic_llm"}
    pipeline_by_seed = {row["seed"]: row for row in trace_rows if row["policy"] == "full_polyguard_pipeline"}
    deltas = []
    for seed in seeds:
        if seed not in basic_by_seed or seed not in pipeline_by_seed:
            continue
        delta = clamp_reward(float(pipeline_by_seed[seed]["reward"]) - float(basic_by_seed[seed]["reward"]) + 0.5) - 0.5
        deltas.append(
            {
                "seed": seed,
                "basic_reward": basic_by_seed[seed]["reward"],
                "pipeline_reward": pipeline_by_seed[seed]["reward"],
                "reward_delta": round(delta, 3),
                "basic_candidate_id": basic_by_seed[seed].get("candidate_id"),
                "pipeline_candidate_id": pipeline_by_seed[seed].get("candidate_id"),
                "basic_failure_reasons": basic_by_seed[seed].get("failure_reasons", []),
                "pipeline_failure_reasons": pipeline_by_seed[seed].get("failure_reasons", []),
            }
        )

    report = {
        "status": "ok",
        "judge": "PolyGuard verifier/reward system",
        "llm_as_judge": os.getenv("POLYGUARD_ENABLE_LLM_JUDGE", "false").lower() in {"1", "true", "yes", "on"},
        "matched_seeds": seeds,
        "summaries": summaries,
        "pipeline_minus_basic_reward_delta": round(
            mean(float(item["pipeline_reward"]) - float(item["basic_reward"]) for item in deltas),
            3,
        )
        if deltas
        else 0.0,
        "deltas": deltas,
        "notes": [
            "basic_llm is an evaluation-only prompt-style proxy that selects the first legal candidate without verifier reranking.",
            "sft_policy is an evaluation-only SFT-style safety ranker over the same candidate set.",
            "full_polyguard_pipeline runs the orchestrated LLM+bandit stack and scores through the same verifier.",
        ],
    }
    trace_path = paths.report_dir / "action_traces.jsonl"
    trace_path.parent.mkdir(parents=True, exist_ok=True)
    with trace_path.open("w", encoding="utf-8") as handle:
        for row in trace_rows:
            handle.write(json.dumps(row, ensure_ascii=True) + "\n")

    failure_cases: list[str] = ["# Basic LLM vs PolyGuard Failure Cases", ""]
    for item in sorted(deltas, key=lambda row: row["reward_delta"], reverse=True)[:6]:
        failure_cases.extend(
            [
                f"## Seed {item['seed']}",
                "",
                f"- Baseline attempt: candidate `{item['basic_candidate_id']}`, reward `{float(item['basic_reward']):.3f}`.",
                f"- PolyGuard pipeline attempt: candidate `{item['pipeline_candidate_id']}`, reward `{float(item['pipeline_reward']):.3f}`.",
                f"- Measured reward delta: `{float(item['reward_delta']):.3f}`.",
                "- Safeguard: every selected action is re-scored by the legality gate, anti-cheat checks, and decomposed clinical/process reward channels.",
                "",
            ]
        )
    write_text(paths.report_dir / "basic_llm_failure_cases.md", "\n".join(failure_cases).rstrip() + "\n")
    write_json(paths.report_dir / "basic_llm_vs_polyguard_report.json", report)

    _plot_bar(
        {comparison_policy_label(policy): float(summary["avg_reward"]) for policy, summary in summaries.items()},
        paths.plot_dir / "basic_llm_vs_full_pipeline_reward.png",
        title="Baseline Basic LLM vs PolyGuard + Bandits",
        ylabel="avg verifier reward",
        rotation=20,
    )
    _plot_bar(
        {comparison_policy_label(policy): float(summary["legality_rate"]) for policy, summary in summaries.items()},
        paths.plot_dir / "basic_llm_vs_full_pipeline_legality.png",
        title="Verifier legality rate by baseline vs Bandits policy",
        ylabel="rate",
        rotation=20,
    )
    _plot_bar(
        {comparison_policy_label(policy): float(summary["avg_latency_seconds"]) for policy, summary in summaries.items()},
        paths.plot_dir / "basic_llm_vs_full_pipeline_latency.png",
        title="Evaluation inference latency by baseline vs Bandits policy",
        ylabel="seconds",
        rotation=20,
    )
    _plot_bar(
        {str(item["seed"]): float(item["reward_delta"]) for item in deltas},
        paths.plot_dir / "basic_llm_vs_full_pipeline_reward_delta_by_seed.png",
        title="PolyGuard + Bandits minus baseline reward by matched seed",
        ylabel="reward delta",
        rotation=35,
    )
    return report


def copy_available_combined_charts(paths: EvidencePaths) -> list[str]:
    source = ROOT / "outputs" / "plots"
    target = paths.docs_charts_dir / "local_available_combined"
    copied: list[str] = []
    for filename in [
        "sft_loss_curves.png",
        "grpo_reward_curves.png",
        "sft_vs_grpo_reward.png",
        "qwen_model_sft_loss.png",
        "qwen_model_sft_reward.png",
        "qwen_model_grpo_reward.png",
        "reward_component_bars.png",
        "train_holdout_gap.png",
        "sft_validity_reward.png",
        "inference_validity_reward.png",
        "inference_latency_validity.png",
        "anti_cheat_failure_rates.png",
        "policy_stack_avg_reward.png",
        "avg_reward.png",
        "legality_rate.png",
    ]:
        if _copy_file(source / filename, target / filename):
            copied.append(str(target / filename))
    return copied


def mirror_to_docs(paths: EvidencePaths) -> list[str]:
    copied: list[str] = []
    copied.extend(_copy_tree_files(paths.report_dir, paths.docs_reports_dir, {".json", ".jsonl", ".md", ".txt"}))
    copied.extend(_copy_tree_files(paths.plot_dir, paths.docs_charts_dir / "generated", {".png"}))
    trace_source = paths.report_dir / "action_traces.jsonl"
    if trace_source.exists():
        _copy_file(trace_source, paths.docs_traces_dir / "action_traces.jsonl")
        copied.append(str(paths.docs_traces_dir / "action_traces.jsonl"))
    copied.extend(copy_available_combined_charts(paths))
    return copied


def build_readme(
    *,
    runs: list[RunEvidence],
    manifest: dict[str, Any],
    paths: EvidencePaths,
    basic_report: dict[str, Any],
) -> str:
    model_scope = format_model_scope(run.label for run in runs)
    rows = []
    for run in runs:
        rows.append(
            "| {label} | {sft} | {grpo} | {loss} | {reward} | {latency} |".format(
                label=run.label,
                sft=run.statuses.get("sft_training", "unknown"),
                grpo=run.statuses.get("grpo_training", "unknown"),
                loss=(
                    f"{float(run.metrics['sft_train_loss']):.4f}"
                    if run.metrics.get("sft_train_loss") is not None
                    else "pending"
                ),
                reward=(
                    f"{clamp_reward(run.metrics['sft_avg_env_reward']):.3f}"
                    if run.metrics.get("sft_avg_env_reward") is not None
                    else "pending"
                ),
                latency=(
                    f"{float(run.metrics['sft_avg_latency_seconds']):.3f}s"
                    if run.metrics.get("sft_avg_latency_seconds") is not None
                    else "pending"
                ),
            )
        )

    pending = manifest.get("pending_artifacts", [])
    charts = manifest.get("charts", {})
    chart_lines = [f"- `{Path(path).name}`" for path in charts.values()]
    return "\n".join(
        [
            f"# PolyGuard Submission Evidence: {model_scope}",
            "",
            "This folder is generated without retraining. It uses already completed HF Space status, local mirrored sweep artifacts, and deterministic PolyGuard verifier rollouts.",
            "",
            "## Run Status",
            "",
            "| Model | SFT training | GRPO training | SFT loss | SFT verifier reward | SFT latency |",
            "| --- | --- | --- | ---: | ---: | ---: |",
            *rows,
            "",
            "## Basic LLM vs Full PolyGuard + Bandits Pipeline",
            "",
            f"- Judge: `{basic_report.get('judge', 'PolyGuard verifier/reward system')}`.",
            f"- Matched seeds: `{len(basic_report.get('matched_seeds', []))}`.",
            f"- PolyGuard + Bandits minus basic average reward delta: `{float(basic_report.get('pipeline_minus_basic_reward_delta', 0.0)):.3f}`.",
            "- LLM-as-judge is optional and disabled unless `POLYGUARD_ENABLE_LLM_JUDGE=true`.",
            "",
            "## Pending Items",
            "",
            *((f"- {item}" for item in pending) if pending else ["- No pending artifact markers were emitted."]),
            "",
            "## Generated Charts",
            "",
            *chart_lines,
            "",
            "## Important Honesty Note",
            "",
            "Remote-completed stages and uploaded artifact files are tracked separately. If a GRPO run completed on the HF Space but the per-run GRPO history file has not been uploaded yet, this bundle labels it as `remote_completed_pending_artifact_upload` instead of inventing a curve.",
            "",
        ]
    )


def zip_docs_bundle(paths: EvidencePaths) -> None:
    paths.bundle_zip.parent.mkdir(parents=True, exist_ok=True)
    if paths.bundle_zip.exists():
        paths.bundle_zip.unlink()
    with zipfile.ZipFile(paths.bundle_zip, "w", compression=zipfile.ZIP_DEFLATED) as archive:
        for path in paths.docs_dir.rglob("*"):
            if path.is_file() and path.name != ".DS_Store":
                archive.write(path, arcname=str(path.relative_to(paths.docs_dir.parent)))


def validate_rewards_in_report(report: dict[str, Any]) -> list[str]:
    errors: list[str] = []
    summaries = report.get("summaries", {})
    if isinstance(summaries, dict):
        for policy, summary in summaries.items():
            if not isinstance(summary, dict):
                continue
            value = summary.get("avg_reward")
            if value is None:
                continue
            rounded = clamp_reward(value)
            if rounded != float(value):
                errors.append(f"{policy}.avg_reward is not clamped/rounded: {value}")
    for item in report.get("deltas", []) if isinstance(report.get("deltas"), list) else []:
        for key in ["basic_reward", "pipeline_reward"]:
            if key in item and clamp_reward(item[key]) != float(item[key]):
                errors.append(f"delta seed {item.get('seed')} {key} is not clamped/rounded")
    return errors
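
# Top-level orchestration: gather remote/local artifacts, render charts, run the
# evaluation-only rollouts, write the manifest and README, mirror everything into docs/,
# and zip the final submission bundle.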
def generate_evidence(
    *,
    models: list[str],
    artifact_repo_id: str,
    training_space_url: str,
    paths: EvidencePaths,
    episodes: int,
    local_only: bool,
    replace: bool = True,
) -> dict[str, Any]:
    ensure_clean_dir(paths.report_dir, replace=replace)
    ensure_clean_dir(paths.plot_dir, replace=replace)
    ensure_clean_dir(paths.docs_dir, replace=replace)
    paths.bundle_zip.parent.mkdir(parents=True, exist_ok=True)

    run_ids = [safe_run_id(model) for model in models]
    token = os.getenv("HF_TOKEN")
    artifact_listing = list_remote_artifacts(artifact_repo_id, token=token, local_only=local_only)
    remote_snapshot = download_remote_snapshot(artifact_repo_id, token=token, run_ids=run_ids, local_only=local_only)
    live_status = fetch_live_status(training_space_url, token=token, local_only=local_only)
    if live_status.get("status") in {"error", "skipped_local_only"}:
        fallback = local_status_fallback()
        if fallback.get("status") != "unavailable":
            live_status = fallback
    write_json(paths.report_dir / "hf_status_snapshot.json", live_status)
    write_json(paths.report_dir / "artifact_repo_listing.json", artifact_listing)

    stage_records = extract_stage_records(live_status, run_ids)
    runs = [
        collect_run_artifacts(run_id, paths=paths, remote_snapshot=remote_snapshot, stage_records=stage_records)
        for run_id in run_ids
    ]

    charts: dict[str, str] = {}
    charts.update(generate_training_charts(runs, paths))
    charts.update(generate_stage_duration_chart(stage_records, paths))

    ablation = load_available_ablation(paths, runs)
    if not ablation:
        ablation = maybe_run_policy_ablation(paths, episodes)
    else:
        write_json(paths.report_dir / "policy_ablation_report.json", ablation)
    charts.update(generate_ablation_charts(ablation, paths))

    basic_report = run_basic_llm_vs_pipeline(paths, episodes=episodes)
    charts.update(
        {
            "basic_llm_vs_full_pipeline_reward": str(paths.plot_dir / "basic_llm_vs_full_pipeline_reward.png"),
            "basic_llm_vs_full_pipeline_legality": str(paths.plot_dir / "basic_llm_vs_full_pipeline_legality.png"),
            "basic_llm_vs_full_pipeline_latency": str(paths.plot_dir / "basic_llm_vs_full_pipeline_latency.png"),
            "basic_llm_vs_full_pipeline_reward_delta_by_seed": str(
                paths.plot_dir / "basic_llm_vs_full_pipeline_reward_delta_by_seed.png"
            ),
        }
    )

    pending_artifacts: list[str] = []
    for run in runs:
        for stage, status in run.statuses.items():
            if "pending" in status or status in {"not_seen_in_status", "remote_failed_or_running"}:
                pending_artifacts.append(f"{run.label} {stage}: {status}")
        if not run.files.get("grpo_history.json"):
            pending_artifacts.append(f"{run.label} grpo_history.json: pending_artifact_upload")
        if not run.files.get("postsave_inference_grpo.json"):
            pending_artifacts.append(f"{run.label} postsave_inference_grpo.json: pending_artifact_upload")

    reward_validation_errors = validate_rewards_in_report(basic_report)
    manifest = {
        "status": "ok" if not reward_validation_errors else "failed_reward_validation",
        "generated_at_unix": time.time(),
        "models": [
            {
                "run_id": run.run_id,
                "model_id": run.model_id,
                "label": run.label,
                "statuses": run.statuses,
                "metrics": run.metrics,
                "files": run.files,
            }
            for run in runs
        ],
        "artifact_repo": artifact_listing,
        "remote_snapshot_used": str(remote_snapshot) if remote_snapshot else "",
        "training_space_status": {
            "status": live_status.get("status"),
            "source": live_status.get("source"),
            "completed_run_ids": live_status.get("completed_run_ids", []),
        },
        "stage_records": stage_records,
        "charts": charts,
        "pending_artifacts": sorted(set(pending_artifacts)),
        "reward_validation_errors": reward_validation_errors,
        "primary_judge": "PolyGuard verifier/reward system",
    }
    write_json(paths.report_dir / "manifest.json", manifest)
    write_json(paths.report_dir / "submission_summary.json", manifest)
    readme = build_readme(runs=runs, manifest=manifest, paths=paths, basic_report=basic_report)
    write_text(paths.report_dir / "README.md", readme)

    mirrored = mirror_to_docs(paths)
    write_text(paths.docs_dir / "README.md", readme)
    write_json(paths.docs_dir / "manifest.json", manifest)
    write_json(paths.docs_dir / "submission_summary.json", manifest)
    write_json(paths.report_dir / "mirrored_files.json", mirrored)
    zip_docs_bundle(paths)
    manifest["bundle_zip"] = str(paths.bundle_zip)
    manifest["mirrored_file_count"] = len(mirrored)
    write_json(paths.report_dir / "manifest.json", manifest)
    write_json(paths.docs_dir / "manifest.json", manifest)
    return manifest


def main() -> None:
    args = parse_args()
    models = [item.strip() for item in args.models.split(",") if item.strip()]
    paths = EvidencePaths(
        report_dir=Path(args.output_dir),
        plot_dir=Path(args.plot_dir),
        docs_dir=Path(args.docs_dir),
        bundle_zip=Path(args.bundle_zip),
    )
    manifest = generate_evidence(
        models=models,
        artifact_repo_id=args.artifact_repo_id,
        training_space_url=args.training_space_url,
        paths=paths,
        episodes=args.episodes,
        local_only=args.local_only,
        replace=args.replace,
    )
    print(json.dumps({"status": manifest["status"], "docs_dir": str(paths.docs_dir), "bundle_zip": str(paths.bundle_zip)}, indent=2))


if __name__ == "__main__":
    main()