polyguard-openenv-workbench / polyguard-rl /scripts /generate_submission_evidence.py
TheJackBright's picture
Deploy GitHub root master to Space
c296d62
#!/usr/bin/env python3
"""Generate submission evidence for completed Qwen 0.5B/1.5B PolyGuard runs.
This script is intentionally evaluation-only. It never trains or updates model
weights. It gathers any already available local/remote artifacts, records what
is still pending upload, runs deterministic PolyGuard verifier rollouts, and
emits charts/JSON/Markdown suitable for the final submission bundle.
"""
from __future__ import annotations
import argparse
import json
import os
from dataclasses import dataclass, field
from pathlib import Path
import shutil
import statistics
import time
from typing import Any, Iterable
import zipfile
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt # noqa: E402
try: # Optional; unavailable in local test environments is fine.
from huggingface_hub import HfApi, snapshot_download
except Exception: # noqa: BLE001
HfApi = None # type: ignore[assignment]
snapshot_download = None # type: ignore[assignment]
# Directory one level above the folder that contains this script.
ROOT = Path(__file__).resolve().parents[1]
# Default for --models: two run ids joined by a comma.
DEFAULT_MODELS = "qwen-qwen2-5-0-5b-instruct,qwen-qwen2-5-1-5b-instruct"
# Default for --artifact-repo-id (queried with repo_type="model").
DEFAULT_ARTIFACT_REPO = "TheJackBright/polyguard-openenv-training-full-artifacts"
# Default for --training-space-url (polled through gradio_client).
DEFAULT_TRAINING_SPACE_URL = "https://thejackbright-polyguard-openenv-training-full.hf.space"
# Defaults for --output-dir, --plot-dir, --docs-dir and --bundle-zip.
DEFAULT_REPORT_DIR = ROOT / "outputs" / "reports" / "submission_evidence" / "qwen_0_5b_1_5b"
DEFAULT_PLOT_DIR = ROOT / "outputs" / "plots" / "submission_evidence" / "qwen_0_5b_1_5b"
DEFAULT_DOCS_DIR = ROOT / "docs" / "results" / "submission_evidence_qwen_0_5b_1_5b"
DEFAULT_BUNDLE_ZIP = ROOT / "submission_bundle" / "qwen_0_5b_1_5b_evidence.zip"
# File names that collect_run_artifacts looks for in every candidate source
# directory of a run; each one is copied at most once per run.
RUN_FILE_NAMES = [
"run_metadata.json",
"sft_trl_run.json",
"sft_history.json",
"postsave_inference_sft.json",
"grpo_trl_run.json",
"grpo_history.json",
"grpo_reward_components.jsonl",
"postsave_inference_grpo.json",
"grpo_ablation_report.json",
"error.json",
]
# Keys read, in this order, from an ablation entry's "reward_columns" mapping
# when the reward-component bar chart is drawn.
REWARD_COMPONENT_KEYS = [
"format_compliance_score",
"candidate_alignment_score",
"legality_score",
"safety_delta_score",
"burden_improvement_score",
"disease_stability_score",
"dosing_quality_score",
"abstention_quality_score",
"efficiency_score",
"process_fidelity_score",
"explanation_grounding_score",
"anti_cheat_score",
"uncertainty_calibration_score",
]
# Keys read, in this order, from an ablation entry's
# "primary_reward_channels" mapping for the primary-channel bar chart.
PRIMARY_CHANNEL_KEYS = [
"safety_legality",
"clinical_improvement",
"dosing_quality",
"process_integrity",
]
@dataclass
class EvidencePaths:
    """Output locations used while generating the evidence bundle.

    The four fields are supplied by the caller; every property derives a
    fixed sub-directory from one of them.
    """

    report_dir: Path  # JSON reports are written below this directory
    plot_dir: Path  # PNG charts are written below this directory
    docs_dir: Path  # root of the docs copy (reports/charts/traces)
    bundle_zip: Path  # presumably the zip archive target; not used in this part of the file

    @property
    def run_report_dir(self) -> Path:
        """Return ``report_dir/runs``, the parent of the per-run folders."""
        return self.report_dir.joinpath("runs")

    @property
    def docs_reports_dir(self) -> Path:
        """Return ``docs_dir/reports``."""
        return self.docs_dir.joinpath("reports")

    @property
    def docs_charts_dir(self) -> Path:
        """Return ``docs_dir/charts``."""
        return self.docs_dir.joinpath("charts")

    @property
    def docs_traces_dir(self) -> Path:
        """Return ``docs_dir/traces``."""
        return self.docs_dir.joinpath("traces")
@dataclass
class RunEvidence:
    """Everything collected for one model run.

    ``run_id``, ``model_id`` and ``label`` are required; the remaining
    fields start empty and are filled in while artifacts are gathered.
    """

    run_id: str
    model_id: str
    label: str
    # Directory the most recently copied artifact came from, if any.
    source_dir: Path | None = None
    # Artifact file name -> copied path ("" when the file was not found).
    files: dict[str, str] = field(default_factory=dict)
    # Stage name -> availability status string.
    statuses: dict[str, str] = field(default_factory=dict)
    # Metric name -> value extracted from the run's reports.
    metrics: dict[str, Any] = field(default_factory=dict)
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the command-line options of the evidence generator.

    Args:
        argv: Arguments to parse. ``None`` (the default) lets argparse read
            ``sys.argv[1:]``, which matches a plain ``parse_args()`` call.

    Returns:
        Namespace with ``models``, ``artifact_repo_id``,
        ``training_space_url``, ``output_dir``, ``plot_dir``, ``docs_dir``,
        ``bundle_zip``, ``episodes``, ``local_only``,
        ``allow_network_errors`` and ``replace``.
    """
    parser = argparse.ArgumentParser(description="Generate PolyGuard submission evidence without retraining.")
    parser.add_argument("--models", default=DEFAULT_MODELS)
    parser.add_argument("--artifact-repo-id", default=DEFAULT_ARTIFACT_REPO)
    parser.add_argument("--training-space-url", default=DEFAULT_TRAINING_SPACE_URL)
    parser.add_argument("--output-dir", default=str(DEFAULT_REPORT_DIR))
    parser.add_argument("--plot-dir", default=str(DEFAULT_PLOT_DIR))
    parser.add_argument("--docs-dir", default=str(DEFAULT_DOCS_DIR))
    parser.add_argument("--bundle-zip", default=str(DEFAULT_BUNDLE_ZIP))
    parser.add_argument("--episodes", type=int, default=8)
    parser.add_argument("--local-only", action="store_true", help="Do not query Hugging Face.")
    # Both switches default to True, so a store_true flag alone could never
    # turn them off. The positive spelling stays valid; the --no-* twin
    # writes False into the same destination.
    parser.add_argument("--allow-network-errors", dest="allow_network_errors", action="store_true", default=True)
    parser.add_argument(
        "--no-allow-network-errors",
        dest="allow_network_errors",
        action="store_false",
        help="Disable --allow-network-errors (enabled by default).",
    )
    parser.add_argument("--replace", dest="replace", action="store_true", default=True)
    parser.add_argument(
        "--no-replace",
        dest="replace",
        action="store_false",
        help="Disable --replace (enabled by default).",
    )
    return parser.parse_args(argv)
def safe_run_id(value: str) -> str:
    """Turn a model identifier into a lowercase, dash-separated run id.

    A stripped value that already starts with ``qwen-`` and contains no
    ``/`` is returned unchanged. Otherwise every non-alphanumeric character
    becomes ``-``, outer dashes are trimmed and the result is lowercased.
    """
    trimmed = value.strip()
    already_run_id = trimmed.startswith("qwen-") and "/" not in trimmed
    if already_run_id:
        return trimmed
    dashed = [char if char.isalnum() else "-" for char in trimmed]
    return "".join(dashed).strip("-").lower()
def model_id_from_run_id(value: str) -> str:
    """Map a known run id to its model id; return unknown values unchanged."""
    known_runs = (
        ("qwen-qwen2-5-0-5b-instruct", "Qwen/Qwen2.5-0.5B-Instruct"),
        ("qwen-qwen2-5-1-5b-instruct", "Qwen/Qwen2.5-1.5B-Instruct"),
        ("qwen-qwen2-5-3b-instruct", "Qwen/Qwen2.5-3B-Instruct"),
    )
    for run_id, hub_id in known_runs:
        if run_id == value:
            return hub_id
    return value
def friendly_label(run_id: str, model_id: str | None = None) -> str:
    """Return a short display name for a run.

    The lowercased ``model_id`` (or ``run_id`` when ``model_id`` is empty)
    is scanned for size markers in a fixed order: 0.5B, then 1.5B, then 3B.
    When nothing matches, ``model_id or run_id`` is returned as given.
    """
    haystack = (model_id or run_id).lower()
    size_markers = (
        (("0.5b", "0-5b"), "Qwen 0.5B"),
        (("1.5b", "1-5b"), "Qwen 1.5B"),
        (("3b", "3-b"), "Qwen 3B"),
    )
    for needles, display in size_markers:
        if any(needle in haystack for needle in needles):
            return display
    return model_id or run_id
def bandit_chart_label(label: str) -> str:
    """Append ``" + Bandits"`` to Qwen labels that do not mention bandits.

    Matching is case-insensitive. Labels that already contain "bandit", or
    that do not contain "qwen", are returned unchanged.
    """
    lowered = label.lower()
    needs_suffix = "bandit" not in lowered and "qwen" in lowered
    return f"{label} + Bandits" if needs_suffix else label
def comparison_policy_label(policy: str) -> str:
    """Return the chart label for a comparison policy name.

    Three policy names have fixed labels; any other name has its
    underscores replaced by spaces and is title-cased.
    """
    if policy == "basic_llm":
        return "Baseline Basic LLM"
    if policy == "sft_policy":
        return "SFT Policy Baseline"
    if policy == "full_polyguard_pipeline":
        return "Full PolyGuard + Bandits"
    return policy.replace("_", " ").title()
def format_model_scope(labels: Iterable[str]) -> str:
chart_labels = [bandit_chart_label(label) for label in labels]
if not chart_labels:
return "Qwen + Bandits"
if len(chart_labels) == 1:
return chart_labels[0]
if len(chart_labels) == 2:
return f"{chart_labels[0]} and {chart_labels[1]}"
return f"{', '.join(chart_labels[:-1])}, and {chart_labels[-1]}"
def ensure_clean_dir(path: Path, *, replace: bool = True) -> None:
    """Make sure ``path`` exists as a directory.

    With ``replace`` true (the default) an existing tree at ``path`` is
    removed first, so the directory ends up empty. Missing parents are
    created either way.
    """
    stale = replace and path.exists()
    if stale:
        shutil.rmtree(path)
    path.mkdir(parents=True, exist_ok=True)
def load_json(path: Path, default: Any = None) -> Any:
    """Best-effort JSON loader.

    Args:
        path: File to read as UTF-8 text.
        default: Value returned when the file cannot be used.

    Returns:
        The parsed JSON value, or ``default`` when the file is missing,
        unreadable (a directory, a permission problem, bytes that are not
        valid UTF-8) or does not contain valid JSON.
    """
    # Read directly instead of checking exists() first: that covers the
    # missing-file case and also files that disappear between check and read.
    try:
        raw = path.read_text(encoding="utf-8")
    except (OSError, UnicodeDecodeError):
        return default
    try:
        return json.loads(raw)
    except json.JSONDecodeError:
        return default
def load_jsonl(path: Path) -> list[dict[str, Any]]:
    """Best-effort JSON Lines loader.

    Args:
        path: File to read as UTF-8 text, one JSON document per line.

    Returns:
        The JSON objects found, in file order. Blank lines, lines that are
        not valid JSON and lines whose value is not an object are skipped.
        A missing or unreadable file (including one that is not valid
        UTF-8) yields an empty list.
    """
    try:
        text = path.read_text(encoding="utf-8")
    except (OSError, UnicodeDecodeError):
        return []
    rows: list[dict[str, Any]] = []
    for line in text.splitlines():
        if not line.strip():
            continue
        try:
            payload = json.loads(line)
        except json.JSONDecodeError:
            continue
        if isinstance(payload, dict):
            rows.append(payload)
    return rows
def write_json(path: Path, payload: Any) -> None:
    """Write ``payload`` to ``path`` as indented, ASCII-only JSON.

    Parent directories are created when missing. The output uses a
    two-space indent, escapes non-ASCII characters and ends with a newline.
    """
    folder = path.parent
    folder.mkdir(parents=True, exist_ok=True)
    document = json.dumps(payload, indent=2, ensure_ascii=True)
    path.write_text(f"{document}\n", encoding="utf-8")
def write_text(path: Path, text: str) -> None:
    """Write ``text`` to ``path`` as UTF-8, creating parent directories."""
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("w", encoding="utf-8") as handle:
        handle.write(text)
def clamp_reward(value: Any) -> float:
    """Coerce ``value`` to a reward in ``[0.001, 0.999]``, rounded to 3 places.

    Args:
        value: Anything ``float()`` may accept.

    Returns:
        The clamped, rounded reward. Input that cannot be read as a number
        (``None``, non-numeric text, an integer too large for a float, or
        NaN) is replaced by the neutral value 0.5 before clamping.
    """
    import math

    try:
        numeric = float(value)
    except (TypeError, ValueError, OverflowError):
        numeric = 0.5
    # NaN compares false against everything, so min()/max() would otherwise
    # pick a bound purely by argument order. Treat it like other
    # non-numeric input instead.
    if math.isnan(numeric):
        numeric = 0.5
    return round(min(0.999, max(0.001, numeric)), 3)
def mean(values: Iterable[float]) -> float:
    """Return the arithmetic mean of ``values``, or 0.0 when there are none."""
    samples = list(values)
    if not samples:
        return 0.0
    return float(statistics.fmean(samples))
def _plot_finish(path: Path) -> None:
    """Save the current pyplot figure to ``path`` and close it.

    Parent directories are created when missing. The figure is laid out
    with ``tight_layout`` and saved at 180 dpi.

    Raises:
        Whatever ``tight_layout`` or ``savefig`` raises. The figure is
        closed even then, so a failed chart does not stay open.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    try:
        plt.tight_layout()
        plt.savefig(path, dpi=180)
    finally:
        plt.close()
def _plot_empty(path: Path, title: str, message: str) -> None:
    """Save a placeholder chart: a title and a centred message, no axes."""
    plt.figure(figsize=(9, 4.5))
    plt.axis("off")
    plt.title(title)
    # Coordinates (0.5, 0.5) with centre alignment on both axes.
    plt.text(
        0.5,
        0.5,
        message,
        wrap=True,
        va="center",
        ha="center",
    )
    _plot_finish(path)
def _plot_line(
    rows: list[dict[str, Any]],
    y_key: str,
    path: Path,
    *,
    title: str,
    ylabel: str,
    label: str | None = None,
) -> str:
    """Plot one metric from a list of history rows against the step.

    Rows that are not dicts, or whose ``y_key`` is missing or ``None``, are
    skipped. The x value is the row's ``"step"`` or, when absent, the row's
    1-based position. When no row qualifies a placeholder chart is saved.
    A legend is drawn only when ``label`` is given.

    Returns:
        ``str(path)``.
    """
    steps: list[int] = []
    readings: list[float] = []
    for position, row in enumerate(rows, start=1):
        if not isinstance(row, dict) or row.get(y_key) is None:
            continue
        steps.append(int(row.get("step", position)))
        readings.append(float(row[y_key]))

    if not steps:
        _plot_empty(path, title, f"No {y_key} data available yet.")
        return str(path)

    plt.figure(figsize=(9, 4.5))
    series_name = label if label else y_key
    plt.plot(steps, readings, linewidth=1.6, label=series_name)
    plt.title(title)
    plt.xlabel("training step")
    plt.ylabel(ylabel)
    plt.grid(alpha=0.25)
    if label:
        plt.legend()
    _plot_finish(path)
    return str(path)
def _plot_multi_line(
    series: dict[str, list[dict[str, Any]]],
    y_key: str,
    path: Path,
    *,
    title: str,
    ylabel: str,
) -> str:
    """Plot the same metric for several labelled histories on one chart.

    Each mapping entry becomes one line, named by its key. Rows are filtered
    and positioned as in the single-line chart: non-dict rows and rows
    without a ``y_key`` value are skipped, and ``"step"`` falls back to the
    1-based row position. A placeholder chart is saved when no series has
    any usable row.

    Returns:
        ``str(path)``.
    """

    def usable_points(rows: list[dict[str, Any]]) -> list[tuple[int, float]]:
        points: list[tuple[int, float]] = []
        for position, row in enumerate(rows, start=1):
            if isinstance(row, dict) and row.get(y_key) is not None:
                points.append((int(row.get("step", position)), float(row[y_key])))
        return points

    plt.figure(figsize=(9, 4.5))
    drawn = 0
    for name, rows in series.items():
        points = usable_points(rows)
        if not points:
            continue
        steps = [step for step, _ in points]
        readings = [reading for _, reading in points]
        plt.plot(steps, readings, linewidth=1.5, label=name)
        drawn += 1

    if drawn == 0:
        # Drop the unused figure before the placeholder opens its own.
        plt.close()
        _plot_empty(path, title, f"No {y_key} data available yet.")
        return str(path)

    plt.title(title)
    plt.xlabel("training step")
    plt.ylabel(ylabel)
    plt.grid(alpha=0.25)
    plt.legend()
    _plot_finish(path)
    return str(path)
def _plot_bar(values: dict[str, float], path: Path, *, title: str, ylabel: str, rotation: int = 0) -> str:
    """Draw a bar chart of ``values`` (label -> height) and save it.

    Entries whose value is ``None`` are dropped; when nothing remains a
    placeholder chart is saved. The figure width grows with the number of
    bars and is never below 8. Tick labels are right-aligned when
    ``rotation`` is non-zero, centred otherwise.

    Returns:
        ``str(path)``.
    """
    present = {}
    for name, amount in values.items():
        if amount is not None:
            present[name] = amount
    if not present:
        _plot_empty(path, title, "No numeric data available yet.")
        return str(path)

    width = max(8, len(present) * 1.35)
    plt.figure(figsize=(width, 4.8))
    names = list(present)
    heights = [float(present[name]) for name in names]
    plt.bar(names, heights, color="#2f6f7e")
    plt.title(title)
    plt.ylabel(ylabel)
    alignment = "right" if rotation else "center"
    plt.xticks(rotation=rotation, ha=alignment)
    plt.grid(axis="y", alpha=0.22)
    _plot_finish(path)
    return str(path)
def _copy_file(source: Path, target: Path) -> bool:
if not source.exists() or not source.is_file():
return False
target.parent.mkdir(parents=True, exist_ok=True)
shutil.copy2(source, target)
return True
def _copy_tree_files(source: Path, target: Path, suffixes: set[str]) -> list[str]:
copied: list[str] = []
if not source.exists():
return copied
for path in source.rglob("*"):
if not path.is_file() or path.name == ".DS_Store" or path.suffix.lower() not in suffixes:
continue
rel = path.relative_to(source)
dest = target / rel
dest.parent.mkdir(parents=True, exist_ok=True)
shutil.copy2(path, dest)
copied.append(str(dest))
return copied
def list_remote_artifacts(repo_id: str, *, token: str | None, local_only: bool) -> dict[str, Any]:
    """List the files of the artifact repository on the Hugging Face Hub.

    Returns:
        A dict with ``repo_id``, ``status``, ``files`` and ``error``.
        ``status`` is ``"skipped_local_only"`` when ``local_only`` is set,
        ``"unavailable_client"`` when ``HfApi`` could not be imported,
        ``"error"`` when the listing raised, ``"pending_artifact_upload"``
        when the repository holds nothing besides ``.gitattributes``, and
        ``"ok"`` otherwise. Successful listings also carry
        ``meaningful_file_count``; ``files`` is the unfiltered listing.
    """

    def without_listing(status: str, error: str = "") -> dict[str, Any]:
        return {"repo_id": repo_id, "status": status, "files": [], "error": error}

    if local_only:
        return without_listing("skipped_local_only")
    if HfApi is None:
        return without_listing("unavailable_client", "huggingface_hub unavailable")
    try:
        hub = HfApi(token=token)
        listing = hub.list_repo_files(repo_id=repo_id, repo_type="model", token=token)
        meaningful = [name for name in listing if name != ".gitattributes"]
        status = "ok" if meaningful else "pending_artifact_upload"
        return {
            "repo_id": repo_id,
            "status": status,
            "files": listing,
            "meaningful_file_count": len(meaningful),
            "error": "",
        }
    except Exception as exc:  # noqa: BLE001
        return without_listing("error", str(exc))
def download_remote_snapshot(
    repo_id: str,
    *,
    token: str | None,
    run_ids: list[str],
    local_only: bool,
) -> Path | None:
    """Download the evidence-relevant part of the artifact repository.

    Only a fixed set of report/plot patterns plus, for every run id, that
    run's sweep report folder and three history files are requested.

    Returns:
        The local snapshot directory, or ``None`` when ``local_only`` is
        set, ``snapshot_download`` is unavailable, or the download raised.
    """
    if local_only or snapshot_download is None:
        return None

    shared_patterns = [
        "outputs/reports/hf_training_status.json",
        "outputs/reports/grpo_trl_run.json",
        "outputs/reports/grpo_ablation_report.json",
        "outputs/plots/*.png",
        "docs/results/*.json",
        "docs/results/*.png",
    ]
    run_patterns = [
        pattern
        for run_id in run_ids
        for pattern in (
            f"outputs/reports/sweeps/{run_id}/*",
            f"checkpoints/sweeps/{run_id}/sft_history.json",
            f"checkpoints/sweeps/{run_id}/grpo_history.json",
            f"checkpoints/sweeps/{run_id}/grpo_reward_components.jsonl",
        )
    ]
    allow_patterns: list[str] = shared_patterns + run_patterns

    try:
        local_dir = snapshot_download(
            repo_id=repo_id,
            repo_type="model",
            token=token,
            allow_patterns=allow_patterns,
        )
        return Path(local_dir)
    except Exception:  # noqa: BLE001
        # Best effort: any download problem means "no remote snapshot".
        return None
def fetch_live_status(training_space_url: str, *, token: str | None, local_only: bool) -> dict[str, Any]:
    """Read the training Space's status through its ``/read_status`` endpoint.

    Args:
        training_space_url: Space URL handed to ``gradio_client.Client``.
        token: Optional access token for the Space.
        local_only: When true no network call is made.

    Returns:
        The status dict reported by the Space, with ``source`` set to the
        URL and ``log_tail`` set to the last 12000 characters of the log.
        A non-dict status is wrapped as ``{"raw_status": ...}``. When
        ``local_only`` is set the result is
        ``{"status": "skipped_local_only", "source": "local-only"}``; an
        import or call failure yields ``{"status": "error", ...}`` with the
        message under ``error``.
    """
    if local_only:
        return {"status": "skipped_local_only", "source": "local-only"}
    try:
        from gradio_client import Client
    except Exception as exc:  # noqa: BLE001
        return {"status": "error", "source": "gradio_client", "error": str(exc)}

    def _connect() -> Any:
        # The token keyword differs between gradio_client releases. Try both
        # spellings before falling back to an anonymous client, so a
        # TypeError on the first one does not silently drop the token.
        if token:
            for keyword in ("token", "hf_token"):
                try:
                    return Client(training_space_url, **{keyword: token})
                except TypeError:
                    continue
        return Client(training_space_url)

    try:
        client = _connect()
        result = client.predict(api_name="/read_status")
        if isinstance(result, (list, tuple)):
            status = result[0] if result else {}
            log = result[1] if len(result) > 1 else ""
        else:
            status = result
            log = ""
        if not isinstance(status, dict):
            status = {"raw_status": status}
        status["source"] = training_space_url
        status["log_tail"] = str(log)[-12000:]
        return status
    except Exception as exc:  # noqa: BLE001
        return {"status": "error", "source": training_space_url, "error": str(exc)}
def local_status_fallback() -> dict[str, Any]:
    """Return the first locally stored status snapshot that loads as a dict.

    Five known snapshot locations under ``ROOT`` are tried in a fixed
    order. The returned dict gets a ``source`` entry naming the file when
    it has none. With no usable snapshot the result is
    ``{"status": "unavailable", "source": "local_fallback"}``.
    """
    reports = ROOT / "outputs" / "reports"
    results = ROOT / "docs" / "results"
    search_order = (
        reports / "submission_evidence" / "qwen_0_5b_1_5b" / "hf_status_snapshot.json",
        results / "submission_evidence_qwen_0_5b_1_5b" / "reports" / "hf_status_snapshot.json",
        reports / "hf_training_status.json",
        results / "qwen_completed_runs" / "reports" / "remote_status" / "live_hf_status_snapshot.json",
        results / "hf_training_status.json",
    )
    for candidate in search_order:
        snapshot = load_json(candidate)
        if not isinstance(snapshot, dict):
            continue
        if "source" not in snapshot:
            snapshot["source"] = str(candidate)
        return snapshot
    return {"status": "unavailable", "source": "local_fallback"}
def command_model_id(args: list[str]) -> str | None:
    """Return the value following the first ``--model-id`` in ``args``.

    A trailing ``--model-id`` with nothing after it is ignored. ``None`` is
    returned when no value is found.
    """
    # Pairing each item with its successor leaves the last item without a
    # partner, which is exactly the "flag without value" case.
    for flag, following in zip(args, args[1:]):
        if flag == "--model-id":
            return str(following)
    return None
def command_output_run_id(args: list[str]) -> str | None:
    """Find a run id (a path part starting with ``qwen-qwen2-5-``) in a command.

    The value after each ``--output`` flag is checked first. Failing that,
    every argument containing the prefix is checked. In both passes the
    argument is split into path parts and the first part that starts with
    the prefix is returned. ``None`` means no such part exists.
    """
    prefix = "qwen-qwen2-5-"

    def first_run_part(text: str) -> str | None:
        for part in Path(text).parts:
            if part.startswith(prefix):
                return part
        return None

    for flag, following in zip(args, args[1:]):
        if flag != "--output":
            continue
        found = first_run_part(str(following))
        if found is not None:
            return found

    for item in args:
        text = str(item)
        if prefix not in text:
            continue
        found = first_run_part(text)
        if found is not None:
            return found
    return None
def stage_from_command(args: list[str]) -> str | None:
    """Classify a recorded command by the script it runs.

    The arguments are joined with spaces and searched for known script
    paths. Post-save inference is split into GRPO or SFT by the report file
    name on the command line (GRPO is checked first) and is plain
    ``"postsave_inference"`` when neither name appears. ``None`` means the
    command runs none of the known scripts.
    """
    command_line = " ".join(map(str, args))

    def mentions(fragment: str) -> bool:
        return fragment in command_line

    if mentions("scripts/train_sft_trl.py"):
        return "sft_training"
    if mentions("scripts/train_grpo_trl.py"):
        return "grpo_training"
    if mentions("scripts/test_inference_postsave.py"):
        for report_name, stage in (
            ("postsave_inference_grpo.json", "grpo_postsave_inference"),
            ("postsave_inference_sft.json", "sft_postsave_inference"),
        ):
            if mentions(report_name):
                return stage
        return "postsave_inference"
    if mentions("scripts/evaluate_policy_ablations.py"):
        return "policy_ablation"
    return None
def extract_stage_records(status: dict[str, Any], run_ids: list[str]) -> list[dict[str, Any]]:
    """Turn the ``commands`` list of a status payload into stage records.

    A command is kept when it is a dict with a list under ``args``, its
    arguments map to a known stage, and its run id is one of ``run_ids``.
    The run id comes from the command's paths or, failing that, from its
    ``--model-id`` value. ``completed`` is true only for a return code of 0.

    Returns:
        One record per kept command, in input order; empty when ``status``
        has no ``commands`` list.
    """
    commands = status.get("commands")
    if not isinstance(commands, list):
        return []

    wanted = set(run_ids)
    records: list[dict[str, Any]] = []
    for command in commands:
        if not isinstance(command, dict):
            continue
        args = command.get("args")
        if not isinstance(args, list):
            continue
        stage = stage_from_command([str(item) for item in args])
        if not stage:
            continue
        model_id = command_model_id(args)
        run_id = command_output_run_id(args)
        if not run_id:
            run_id = safe_run_id(model_id) if model_id else ""
        if run_id not in wanted:
            continue
        returncode = command.get("returncode")
        records.append(
            {
                "run_id": run_id,
                "model_id": model_id or model_id_from_run_id(run_id),
                "label": friendly_label(run_id, model_id),
                "stage": stage,
                "returncode": returncode,
                "elapsed_seconds": round(float(command.get("elapsed_seconds") or 0.0), 3),
                "completed": returncode == 0,
            }
        )
    return records
def stage_status(stage_records: list[dict[str, Any]], run_id: str, stage: str) -> str:
    """Summarise what the stage records say about one run's stage.

    Returns:
        ``"not_seen_in_status"`` when no record matches the run and stage,
        ``"remote_completed"`` when a matching record has ``completed``
        set to ``True``, and ``"remote_failed_or_running"`` otherwise.
    """
    seen = False
    for record in stage_records:
        if record.get("run_id") != run_id or record.get("stage") != stage:
            continue
        if record.get("completed") is True:
            return "remote_completed"
        seen = True
    return "remote_failed_or_running" if seen else "not_seen_in_status"
def collect_run_artifacts(
    run_id: str,
    *,
    paths: EvidencePaths,
    remote_snapshot: Path | None,
    stage_records: list[dict[str, Any]],
) -> RunEvidence:
    """Gather one run's artifact files and summarise their status and metrics.

    Every name in ``RUN_FILE_NAMES`` is searched in the remote snapshot (when
    given) and then in the local ``outputs``/``checkpoints`` sweep folders;
    the first regular file found is copied into ``run_report_dir/<run_id>``.
    The copied reports are then read to fill ``statuses`` and ``metrics``,
    and both are written to ``availability.json`` in the same folder.

    Args:
        run_id: Run identifier, also used as folder name.
        paths: Output locations.
        remote_snapshot: Root of a downloaded snapshot, or ``None``.
        stage_records: Stage records used when a report file is missing.

    Returns:
        The populated ``RunEvidence``. ``files`` maps each searched name to
        its copied path, or to ``""`` when no source had it.
    """

    def as_dict(payload: Any) -> dict[str, Any]:
        # A report that is not a JSON object counts as "no report".
        return payload if isinstance(payload, dict) else {}

    def as_rows(payload: Any) -> list[Any]:
        # A history that is not a JSON array counts as "no rows".
        return payload if isinstance(payload, list) else []

    def numeric_column(rows: list[Any], key: str) -> list[float]:
        return [float(row[key]) for row in rows if isinstance(row, dict) and row.get(key) is not None]

    def pending_or(remote: str) -> str:
        # The stage finished remotely but its artifact has not arrived here.
        return "remote_completed_pending_artifact_upload" if remote == "remote_completed" else remote

    model_id = model_id_from_run_id(run_id)
    evidence = RunEvidence(run_id=run_id, model_id=model_id, label=friendly_label(run_id, model_id))
    target_dir = paths.run_report_dir / run_id
    target_dir.mkdir(parents=True, exist_ok=True)

    # Remote copies take precedence over local ones.
    source_dirs: list[Path] = []
    if remote_snapshot is not None:
        source_dirs.append(remote_snapshot / "outputs" / "reports" / "sweeps" / run_id)
        source_dirs.append(remote_snapshot / "checkpoints" / "sweeps" / run_id)
    source_dirs.append(ROOT / "outputs" / "reports" / "sweeps" / run_id)
    source_dirs.append(ROOT / "checkpoints" / "sweeps" / run_id)

    for filename in RUN_FILE_NAMES:
        evidence.files[filename] = ""
        destination = target_dir / filename
        for source_dir in source_dirs:
            # Trust the copy result: an entry that exists but is not a
            # regular file must not be recorded as copied, and must not stop
            # the search through the remaining source directories.
            if _copy_file(source_dir / filename, destination):
                evidence.files[filename] = str(destination)
                evidence.source_dir = source_dir
                break

    sft_report = as_dict(load_json(target_dir / "sft_trl_run.json", {}))
    sft_history = as_rows(load_json(target_dir / "sft_history.json", []))
    sft_inference = as_dict(load_json(target_dir / "postsave_inference_sft.json", {}))
    grpo_report = as_dict(load_json(target_dir / "grpo_trl_run.json", {}))
    grpo_history = as_rows(load_json(target_dir / "grpo_history.json", []))
    grpo_inference = as_dict(load_json(target_dir / "postsave_inference_grpo.json", {}))

    evidence.statuses["sft_training"] = (
        "artifact_available"
        if sft_report.get("status") == "ok"
        else stage_status(stage_records, run_id, "sft_training")
    )
    evidence.statuses["sft_postsave_inference"] = (
        "artifact_available"
        if sft_inference.get("status") == "ok"
        else stage_status(stage_records, run_id, "sft_postsave_inference")
    )
    evidence.statuses["grpo_training"] = (
        "artifact_available"
        if grpo_report.get("status") == "ok"
        else pending_or(stage_status(stage_records, run_id, "grpo_training"))
    )
    evidence.statuses["grpo_postsave_inference"] = (
        "artifact_available"
        if grpo_inference.get("status") == "ok"
        else pending_or(stage_status(stage_records, run_id, "grpo_postsave_inference"))
    )
    evidence.statuses["policy_ablation"] = (
        "artifact_available"
        if (target_dir / "grpo_ablation_report.json").exists()
        else pending_or(stage_status(stage_records, run_id, "policy_ablation"))
    )

    loss_values = numeric_column(sft_history, "loss")
    accuracy_values = numeric_column(sft_history, "mean_token_accuracy")
    reward_summary = as_dict(grpo_report.get("reward_summary"))
    evidence.metrics = {
        "sft_train_loss": sft_report.get("train_loss"),
        "sft_train_runtime": sft_report.get("train_runtime"),
        "sft_examples_used": sft_report.get("examples_used"),
        "sft_history_steps": len(sft_history),
        "sft_first_loss": loss_values[0] if loss_values else None,
        "sft_last_loss": loss_values[-1] if loss_values else None,
        "sft_best_loss": min(loss_values) if loss_values else None,
        "sft_last_token_accuracy": accuracy_values[-1] if accuracy_values else None,
        "sft_valid_rate": sft_inference.get("valid_rate"),
        "sft_avg_env_reward": sft_inference.get("avg_env_reward"),
        "sft_avg_latency_seconds": sft_inference.get("avg_latency_seconds"),
        "grpo_avg_reward": reward_summary.get("avg_reward"),
        "grpo_history_steps": len(grpo_history),
        "grpo_valid_rate": grpo_inference.get("valid_rate"),
        "grpo_avg_env_reward": grpo_inference.get("avg_env_reward"),
        "grpo_avg_latency_seconds": grpo_inference.get("avg_latency_seconds"),
    }
    write_json(target_dir / "availability.json", {"statuses": evidence.statuses, "metrics": evidence.metrics})
    return evidence
def generate_training_charts(runs: list[RunEvidence], paths: EvidencePaths) -> dict[str, str]:
    """Render the SFT training charts and return ``chart name -> file path``.

    For every run whose ``sft_history.json`` loads as a list, three line
    charts are drawn (loss, token accuracy, learning rate). Two comparison
    charts across those histories and four bar charts built from the runs'
    metrics follow. All files go to ``paths.plot_dir``.
    """
    charts: dict[str, str] = {}
    histories: dict[str, list[dict[str, Any]]] = {}

    # (chart-name suffix, history key, title suffix, y-axis label)
    per_run_charts = (
        ("sft_training_loss", "loss", "SFT training loss", "loss"),
        ("sft_token_accuracy", "mean_token_accuracy", "SFT token accuracy", "mean token accuracy"),
        ("sft_learning_rate", "learning_rate", "SFT learning rate", "learning rate"),
    )
    for run in runs:
        history = load_json(paths.run_report_dir / run.run_id / "sft_history.json", [])
        if not isinstance(history, list):
            continue
        chart_label = bandit_chart_label(run.label)
        histories[chart_label] = history
        if "0.5B" in run.label:
            prefix = "qwen_0_5b"
        elif "1.5B" in run.label:
            prefix = "qwen_1_5b"
        else:
            prefix = run.run_id
        for suffix, y_key, title_suffix, axis_label in per_run_charts:
            chart_name = f"{prefix}_{suffix}"
            charts[chart_name] = _plot_line(
                history,
                y_key,
                paths.plot_dir / f"{chart_name}.png",
                title=f"{chart_label} {title_suffix}",
                ylabel=axis_label,
                label=chart_label,
            )

    scope = format_model_scope(run.label for run in runs)
    # (chart name, history key, title suffix, y-axis label)
    comparison_charts = (
        ("qwen_0_5b_vs_1_5b_sft_loss_comparison", "loss", "SFT loss", "loss"),
        (
            "qwen_0_5b_vs_1_5b_sft_token_accuracy_comparison",
            "mean_token_accuracy",
            "token accuracy",
            "mean token accuracy",
        ),
    )
    for chart_name, y_key, title_suffix, axis_label in comparison_charts:
        charts[chart_name] = _plot_multi_line(
            histories,
            y_key,
            paths.plot_dir / f"{chart_name}.png",
            title=f"{scope} {title_suffix}",
            ylabel=axis_label,
        )

    # (chart name, metric key, value converter, title, y-axis label)
    metric_charts = (
        (
            "qwen_0_5b_1_5b_final_sft_train_loss",
            "sft_train_loss",
            float,
            "Final SFT train loss for Qwen + Bandits",
            "loss",
        ),
        (
            "qwen_0_5b_1_5b_postsave_reward",
            "sft_avg_env_reward",
            clamp_reward,
            "Post-save SFT verifier reward for Qwen + Bandits",
            "avg environment reward",
        ),
        (
            "qwen_0_5b_1_5b_postsave_latency",
            "sft_avg_latency_seconds",
            float,
            "Post-save SFT inference latency for Qwen + Bandits",
            "seconds",
        ),
        (
            "qwen_0_5b_1_5b_sft_runtime",
            "sft_train_runtime",
            float,
            "Remote SFT runtime for Qwen + Bandits",
            "seconds",
        ),
    )
    for chart_name, metric_key, convert, title, axis_label in metric_charts:
        bars: dict[str, float] = {}
        for run in runs:
            if run.metrics.get(metric_key) is None:
                continue
            bars[bandit_chart_label(run.label)] = convert(run.metrics[metric_key])
        charts[chart_name] = _plot_bar(
            bars,
            paths.plot_dir / f"{chart_name}.png",
            title=title,
            ylabel=axis_label,
        )
    return charts
def generate_stage_duration_chart(stage_records: list[dict[str, Any]], paths: EvidencePaths) -> dict[str, str]:
    """Chart how long each completed remote stage took.

    Only records with ``completed`` set to ``True`` and one of the five
    charted stages are used. Bars are keyed by label and stage, so a later
    record with the same pair replaces an earlier one. The selected records
    are also written to ``remote_stage_records.json`` in the report folder.

    Returns:
        A one-entry mapping from the chart name to the saved file path.
    """
    charted_stages = {
        "sft_training",
        "grpo_training",
        "sft_postsave_inference",
        "grpo_postsave_inference",
        "policy_ablation",
    }
    selected = []
    for record in stage_records:
        if record.get("completed") is True and record.get("stage") in charted_stages:
            selected.append(record)

    path = paths.plot_dir / "qwen_0_5b_1_5b_remote_completed_stage_durations.png"
    durations: dict[str, float] = {}
    for record in selected:
        who = bandit_chart_label(str(record["label"]))
        what = record["stage"].replace("_", " ")
        durations[f"{who}\n{what}"] = float(record.get("elapsed_seconds") or 0.0)

    chart = _plot_bar(durations, path, title="HF Space completed stage durations", ylabel="seconds", rotation=35)
    write_json(paths.report_dir / "remote_stage_records.json", selected)
    return {"qwen_0_5b_1_5b_remote_completed_stage_durations": chart}
def load_available_ablation(paths: EvidencePaths, runs: list[RunEvidence]) -> dict[str, Any]:
    """Return the first stored ablation report that has an ``ablations`` dict.

    The per-run ``grpo_ablation_report.json`` files are tried in run order,
    then two shared locations under ``ROOT/outputs/reports``. The returned
    report gets a ``source`` entry naming the file when it has none. An
    empty dict means nothing usable was found.
    """
    per_run = [paths.run_report_dir / run.run_id / "grpo_ablation_report.json" for run in runs]
    shared = [
        ROOT / "outputs" / "reports" / "grpo_ablation_report.json",
        ROOT / "outputs" / "reports" / "active_model" / "grpo_ablation_report.json",
    ]
    for candidate in [*per_run, *shared]:
        report = load_json(candidate)
        if not isinstance(report, dict):
            continue
        if not isinstance(report.get("ablations"), dict):
            continue
        report.setdefault("source", str(candidate))
        return report
    return {}
def maybe_run_policy_ablation(paths: EvidencePaths, episodes: int) -> dict[str, Any]:
    """Return the policy ablation report, running the rollouts when needed.

    A stored ``policy_ablation_report.json`` is reused when its status is
    ``"ok"`` and it is not made up solely of error entries. Otherwise the
    three policy stacks are rolled out and a fresh report is written.

    Args:
        paths: Output locations; the report lives in ``paths.report_dir``.
        episodes: Requested episodes per stack; at least one is always run.

    Returns:
        The report dict. ``status`` is ``"ok"`` when at least one stack
        produced a result and ``"error"`` when every stack raised or the
        rollout function could not be imported. ``failed_stacks`` lists the
        stacks that raised, and ``episodes`` is the count actually run.
    """

    def only_errors(ablations: Any) -> bool:
        # True when there are entries and every one is an error marker.
        if not isinstance(ablations, dict) or not ablations:
            return False
        return all(isinstance(entry, dict) and entry.get("status") == "error" for entry in ablations.values())

    report_path = paths.report_dir / "policy_ablation_report.json"
    existing = load_json(report_path)
    if (
        isinstance(existing, dict)
        and existing.get("status") == "ok"
        # A report stamped "ok" although every stack failed carries no
        # data; reusing it would pin the failure forever.
        and not only_errors(existing.get("ablations"))
    ):
        return existing
    try:
        from app.training.grpo_experiment import run_policy_stack_rollout
    except Exception as exc:  # noqa: BLE001
        return {"status": "error", "error": f"policy_ablation_import_failed:{exc}"}

    effective_episodes = max(1, episodes)
    ablations: dict[str, Any] = {}
    failed_stacks: list[str] = []
    checkpoint_dir = paths.report_dir / "policy_rollout_artifacts"
    for stack in ["bandit-only", "llm-only", "llm+bandit"]:
        key = stack.replace("-", "_").replace("+", "_")
        try:
            ablations[key] = run_policy_stack_rollout(
                stack,
                episodes=effective_episodes,
                checkpoint_dir=checkpoint_dir,
                seed_offset=6_500,
            )
        except Exception as exc:  # noqa: BLE001
            ablations[key] = {"status": "error", "error": str(exc)}
            failed_stacks.append(stack)

    all_failed = len(failed_stacks) == len(ablations)
    report: dict[str, Any] = {
        "status": "error" if all_failed else "ok",
        "source": "local_evaluation_only_rollout",
        "episodes": effective_episodes,
        "ablations": ablations,
        "failed_stacks": failed_stacks,
    }
    if all_failed:
        report["error"] = "policy_ablation_all_stacks_failed"
    write_json(report_path, report)
    return report
def generate_ablation_charts(ablation: dict[str, Any], paths: EvidencePaths) -> dict[str, str]:
    """Render the policy ablation bar charts.

    Args:
        ablation: Report whose ``ablations`` entry maps a stack key to that
            stack's result dict.
        paths: Output locations; charts go to ``paths.plot_dir``.

    Returns:
        ``chart name -> file path``. Without an ``ablations`` dict only a
        placeholder average-reward chart is produced. Otherwise the
        average-reward, legality and exploit-detection charts are always
        drawn; the reward-component and primary-channel charts are added
        when some stack provides the corresponding mapping.
    """
    charts: dict[str, str] = {}
    ablations = ablation.get("ablations") if isinstance(ablation, dict) else None
    if not isinstance(ablations, dict):
        charts["policy_ablation_avg_reward"] = _plot_bar(
            {},
            paths.plot_dir / "policy_ablation_avg_reward.png",
            title="Policy ablation average reward",
            ylabel="avg reward",
        )
        return charts

    display_names = {
        "bandit_only": "Bandits only",
        "bandit-only": "Bandits only",
        "llm_only": "Baseline LLM only",
        "llm-only": "Baseline LLM only",
        "llm_bandit": "LLM + Bandits",
        "llm+bandit": "LLM + Bandits",
    }

    def label_for(key: str) -> str:
        return display_names.get(key, key.replace("_", " ").replace("+", " + ").title())

    def first_section(name: str) -> dict[str, Any] | None:
        # Take the mapping from the first stack that has one. Looking only
        # at the first entry would drop the chart whenever that entry is an
        # error marker, even though later stacks carry the data.
        for entry in ablations.values():
            if isinstance(entry, dict) and isinstance(entry.get(name), dict):
                return entry[name]
        return None

    charts["policy_ablation_avg_reward"] = _plot_bar(
        {
            label_for(key): clamp_reward(value.get("avg_reward"))
            for key, value in ablations.items()
            if isinstance(value, dict) and value.get("avg_reward") is not None
        },
        paths.plot_dir / "policy_ablation_avg_reward.png",
        title="Without Bandits vs With Bandits average reward",
        ylabel="avg verifier reward",
    )
    charts["policy_ablation_legality"] = _plot_bar(
        {
            label_for(key): float(value.get("legality_rate"))
            for key, value in ablations.items()
            if isinstance(value, dict) and value.get("legality_rate") is not None
        },
        paths.plot_dir / "policy_ablation_legality.png",
        title="Without Bandits vs With Bandits legality rate",
        ylabel="legality rate",
    )
    charts["policy_ablation_exploit_detection"] = _plot_bar(
        {
            label_for(key): float(value.get("exploit_detection_count"))
            for key, value in ablations.items()
            if isinstance(value, dict) and value.get("exploit_detection_count") is not None
        },
        paths.plot_dir / "policy_ablation_exploit_detection.png",
        title="Exploit/repeated-loop detections without vs with Bandits",
        ylabel="count",
    )

    components = first_section("reward_columns")
    if components is not None:
        charts["reward_component_bars"] = _plot_bar(
            {key: clamp_reward(components.get(key)) for key in REWARD_COMPONENT_KEYS if key in components},
            paths.plot_dir / "reward_component_bars.png",
            title="Verifier reward component means",
            ylabel="reward",
            rotation=45,
        )
    primary = first_section("primary_reward_channels")
    if primary is not None:
        charts["primary_reward_channel_bars"] = _plot_bar(
            {key: clamp_reward(primary.get(key)) for key in PRIMARY_CHANNEL_KEYS if key in primary},
            paths.plot_dir / "primary_reward_channel_bars.png",
            title="Primary reward channel means",
            ylabel="reward",
            rotation=25,
        )
    return charts
def action_from_candidate(candidate: dict[str, Any], rationale: str) -> dict[str, Any]:
    """Build an action dict from a candidate action.

    Most fields are copied from ``candidate`` (``dose_bucket`` defaults to
    ``"NA"`` and ``candidate_components`` to ``[]`` when the key is absent).
    ``confidence`` is ``1 - uncertainty_score``, floored at 0.45 and passed
    through ``clamp_reward``. ``rationale`` is stored as ``rationale_brief``.

    A missing, ``None`` or non-numeric ``uncertainty_score`` is read as 0.5
    so that such a candidate cannot abort the rollout.
    """
    raw_uncertainty = candidate.get("uncertainty_score")
    try:
        # An explicit 0.0 must stay 0.0, so only None falls back here.
        uncertainty = 0.5 if raw_uncertainty is None else float(raw_uncertainty)
    except (TypeError, ValueError):
        uncertainty = 0.5
    return {
        "mode": candidate.get("mode"),
        "action_type": candidate.get("action_type"),
        "target_drug": candidate.get("target_drug"),
        "replacement_drug": candidate.get("replacement_drug"),
        "dose_bucket": candidate.get("dose_bucket", "NA"),
        "taper_days": candidate.get("taper_days"),
        "monitoring_plan": candidate.get("monitoring_plan"),
        "evidence_query": candidate.get("evidence_query"),
        "new_drug_name": candidate.get("new_drug_name"),
        "candidate_components": candidate.get("candidate_components", []),
        "candidate_id": candidate.get("candidate_id"),
        "confidence": clamp_reward(max(0.45, 1.0 - uncertainty)),
        "rationale_brief": rationale,
    }
def select_candidate(policy: str, candidates: list[dict[str, Any]]) -> dict[str, Any]:
    """Pick one candidate action for the given policy.

    The pool is the candidates whose ``legality_precheck`` is ``True``, or
    all candidates when none pass. ``"basic_llm"`` takes the first pool
    entry. Other policies try the project's safety ranker; its top result
    is used for ``"sft_policy"`` and ``"full_polyguard_pipeline"``. When the
    ranker is unavailable, fails, or does not apply, the pool is ordered by
    legality, then estimated safety delta, then lowest uncertainty, and the
    best entry is returned.
    """
    prechecked = [item for item in candidates if item.get("legality_precheck") is True]
    pool = prechecked or candidates
    if policy == "basic_llm":
        return pool[0]

    try:
        from app.common.types import CandidateAction
        from app.models.policy.safety_ranker import rank_candidates

        typed = [CandidateAction.model_validate(item) for item in pool]
        ranked = rank_candidates(typed)
        if policy in {"sft_policy", "full_polyguard_pipeline"} and ranked:
            return ranked[0].model_dump(mode="json")
    except Exception:  # noqa: BLE001
        # Deliberate best effort: fall through to the heuristic ordering.
        pass

    def heuristic_rank(item: dict[str, Any]) -> tuple[bool, float, float]:
        legal = bool(item.get("legality_precheck"))
        safety_gain = float(item.get("estimated_safety_delta") or 0.0)
        certainty = -float(item.get("uncertainty_score") or 0.5)
        return (legal, safety_gain, certainty)

    best_first = sorted(pool, key=heuristic_rank, reverse=True)
    return best_first[0]
def run_basic_llm_vs_pipeline(paths: EvidencePaths, *, episodes: int) -> dict[str, Any]:
    """Compare three selection policies on matched seeds and write the evidence.

    For every seed in ``8000 .. 8000 + max(1, episodes) - 1`` a fresh
    ``PolyGuardEnv`` is reset with ``difficulty="medium"`` and exactly one step
    is taken under each of ``basic_llm``, ``sft_policy`` and
    ``full_polyguard_pipeline``. The first two pick a candidate via
    ``select_candidate`` and call ``env.step`` directly; the last one runs
    ``Orchestrator(env=env).run_step()``.

    Side effects:
        * Temporarily overrides ``POLYGUARD_ENABLE_ACTIVE_MODEL``,
          ``HF_HUB_OFFLINE`` (only if unset) and ``POLYGUARD_POLICY_STACK``;
          the previous values are restored even if a rollout raises.
        * Writes ``action_traces.jsonl``, ``basic_llm_failure_cases.md`` and
          ``basic_llm_vs_polyguard_report.json`` under ``paths.report_dir``.
        * Writes four bar charts under ``paths.plot_dir``.

    Args:
        paths: Output locations for reports and plots.
        episodes: Number of matched seeds to run (values below 1 run one seed).

    Returns:
        The comparison report dict that is also written to disk.
    """
    from app.agents.orchestrator import Orchestrator
    from app.env.env_core import PolyGuardEnv

    seeds = [8_000 + idx for idx in range(max(1, episodes))]
    policies = ["basic_llm", "sft_policy", "full_polyguard_pipeline"]
    trace_rows: list[dict[str, Any]] = []
    summaries: dict[str, dict[str, Any]] = {}

    # Snapshot the environment variables this function touches so they can be
    # restored no matter how the rollout loop exits.
    saved_env = {
        name: os.getenv(name)
        for name in ("POLYGUARD_POLICY_STACK", "POLYGUARD_ENABLE_ACTIVE_MODEL", "HF_HUB_OFFLINE")
    }
    os.environ["POLYGUARD_ENABLE_ACTIVE_MODEL"] = "false"
    os.environ.setdefault("HF_HUB_OFFLINE", "1")
    try:
        for seed in seeds:
            for policy in policies:
                started = time.monotonic()
                env = PolyGuardEnv()
                env.reset(seed=seed, difficulty="medium")
                if policy == "full_polyguard_pipeline":
                    # NOTE(review): this override stays in place for the
                    # remaining seeds (it is only undone in ``finally``), so
                    # later basic/SFT rollouts run with it set — confirm the
                    # environment does not read it.
                    os.environ["POLYGUARD_POLICY_STACK"] = "llm+bandit"
                    out = Orchestrator(env=env).run_step()
                    reward = clamp_reward(out.get("reward"))
                    info = out.get("info", {}) if isinstance(out.get("info"), dict) else {}
                    action = out.get("final_action", {}) if isinstance(out.get("final_action"), dict) else {}
                    critic = out.get("critic", {}) if isinstance(out.get("critic"), dict) else {}
                else:
                    candidates = env.get_candidate_actions()
                    selected = select_candidate(policy, candidates)
                    action = action_from_candidate(
                        selected,
                        "Basic prompt-only selection." if policy == "basic_llm" else "SFT-style safety-ranker selection.",
                    )
                    _obs, raw_reward, _done, info = env.step(action)
                    reward = clamp_reward(raw_reward)
                    critic = info.get("safety_report", {}) if isinstance(info, dict) else {}
                elapsed = round(time.monotonic() - started, 4)

                # Legality comes from the critic when it reports one, else from
                # the step's safety report. Both lookups are type-guarded so a
                # non-dict ``info``/``safety_report`` yields ``False`` rather
                # than an AttributeError.
                safety_report = info.get("safety_report", {}) if isinstance(info, dict) else {}
                fallback_legal = safety_report.get("legal", False) if isinstance(safety_report, dict) else False
                legal = bool(critic.get("legal", fallback_legal)) if isinstance(critic, dict) else False

                reward_breakdown = info.get("reward_breakdown", {}) if isinstance(info, dict) else {}
                primary = info.get("primary_reward_channels", {}) if isinstance(info, dict) else {}
                trace_rows.append(
                    {
                        "seed": seed,
                        "policy": policy,
                        "reward": reward,
                        "latency_seconds": elapsed,
                        "legal": legal,
                        "candidate_id": action.get("candidate_id"),
                        "action_type": action.get("action_type"),
                        "termination_reason": info.get("termination_reason") if isinstance(info, dict) else None,
                        "failure_reasons": info.get("failure_reasons", []) if isinstance(info, dict) else [],
                        "anti_cheat_reasons": info.get("anti_cheat_reasons", []) if isinstance(info, dict) else [],
                        "reward_breakdown": {key: clamp_reward(value) for key, value in reward_breakdown.items()}
                        if isinstance(reward_breakdown, dict)
                        else {},
                        "primary_reward_channels": {key: clamp_reward(value) for key, value in primary.items()}
                        if isinstance(primary, dict)
                        else {},
                    }
                )
    finally:
        for name, previous in saved_env.items():
            if previous is None:
                os.environ.pop(name, None)
            else:
                os.environ[name] = previous

    # Per-policy aggregates over all seeds.
    for policy in policies:
        rows = [row for row in trace_rows if row["policy"] == policy]
        summaries[policy] = {
            "episodes": len(rows),
            "avg_reward": clamp_reward(mean(float(row["reward"]) for row in rows)),
            "avg_latency_seconds": round(mean(float(row["latency_seconds"]) for row in rows), 4),
            "legality_rate": round(mean(1.0 if row["legal"] else 0.0 for row in rows), 3),
            "exploit_or_failure_rate": round(
                mean(
                    1.0
                    if row.get("anti_cheat_reasons") or row.get("failure_reasons") or row.get("termination_reason") == "exploit_detection"
                    else 0.0
                    for row in rows
                ),
                3,
            ),
            "candidate_diversity": len({row.get("candidate_id") for row in rows if row.get("candidate_id")}),
        }

    # Seed-matched comparison between the baseline and the full pipeline.
    basic_by_seed = {row["seed"]: row for row in trace_rows if row["policy"] == "basic_llm"}
    pipeline_by_seed = {row["seed"]: row for row in trace_rows if row["policy"] == "full_polyguard_pipeline"}
    deltas = []
    for seed in seeds:
        if seed not in basic_by_seed or seed not in pipeline_by_seed:
            continue
        # The raw difference is shifted by +0.5, clamped, and shifted back.
        # NOTE(review): this bounds the per-seed delta by whatever range
        # ``clamp_reward`` enforces, while the aggregate delta in the report
        # below is computed from unclamped differences — confirm intended.
        delta = clamp_reward(float(pipeline_by_seed[seed]["reward"]) - float(basic_by_seed[seed]["reward"]) + 0.5) - 0.5
        deltas.append(
            {
                "seed": seed,
                "basic_reward": basic_by_seed[seed]["reward"],
                "pipeline_reward": pipeline_by_seed[seed]["reward"],
                "reward_delta": round(delta, 3),
                "basic_candidate_id": basic_by_seed[seed].get("candidate_id"),
                "pipeline_candidate_id": pipeline_by_seed[seed].get("candidate_id"),
                "basic_failure_reasons": basic_by_seed[seed].get("failure_reasons", []),
                "pipeline_failure_reasons": pipeline_by_seed[seed].get("failure_reasons", []),
            }
        )

    report = {
        "status": "ok",
        "judge": "PolyGuard verifier/reward system",
        "llm_as_judge": os.getenv("POLYGUARD_ENABLE_LLM_JUDGE", "false").lower() in {"1", "true", "yes", "on"},
        "matched_seeds": seeds,
        "summaries": summaries,
        "pipeline_minus_basic_reward_delta": round(
            mean(float(item["pipeline_reward"]) - float(item["basic_reward"]) for item in deltas),
            3,
        )
        if deltas
        else 0.0,
        "deltas": deltas,
        "notes": [
            "basic_llm is an evaluation-only prompt-style proxy that selects the first legal candidate without verifier reranking.",
            "sft_policy is an evaluation-only SFT-style safety ranker over the same candidate set.",
            "full_polyguard_pipeline runs the orchestrated LLM+bandit stack and scores through the same verifier.",
        ],
    }

    # One JSON object per rollout, in execution order.
    trace_path = paths.report_dir / "action_traces.jsonl"
    trace_path.parent.mkdir(parents=True, exist_ok=True)
    with trace_path.open("w", encoding="utf-8") as handle:
        for row in trace_rows:
            handle.write(json.dumps(row, ensure_ascii=True) + "\n")

    # Markdown write-up of the (up to) six seeds with the largest delta.
    failure_cases: list[str] = ["# Basic LLM vs PolyGuard Failure Cases", ""]
    for item in sorted(deltas, key=lambda row: row["reward_delta"], reverse=True)[:6]:
        failure_cases.extend(
            [
                f"## Seed {item['seed']}",
                "",
                f"- Baseline attempt: candidate `{item['basic_candidate_id']}`, reward `{float(item['basic_reward']):.3f}`.",
                f"- PolyGuard pipeline attempt: candidate `{item['pipeline_candidate_id']}`, reward `{float(item['pipeline_reward']):.3f}`.",
                f"- Measured reward delta: `{float(item['reward_delta']):.3f}`.",
                "- Safeguard: every selected action is re-scored by the legality gate, anti-cheat checks, and decomposed clinical/process reward channels.",
                "",
            ]
        )
    write_text(paths.report_dir / "basic_llm_failure_cases.md", "\n".join(failure_cases).rstrip() + "\n")
    write_json(paths.report_dir / "basic_llm_vs_polyguard_report.json", report)

    _plot_bar(
        {comparison_policy_label(policy): float(summary["avg_reward"]) for policy, summary in summaries.items()},
        paths.plot_dir / "basic_llm_vs_full_pipeline_reward.png",
        title="Baseline Basic LLM vs PolyGuard + Bandits",
        ylabel="avg verifier reward",
        rotation=20,
    )
    _plot_bar(
        {comparison_policy_label(policy): float(summary["legality_rate"]) for policy, summary in summaries.items()},
        paths.plot_dir / "basic_llm_vs_full_pipeline_legality.png",
        title="Verifier legality rate by baseline vs Bandits policy",
        ylabel="rate",
        rotation=20,
    )
    _plot_bar(
        {comparison_policy_label(policy): float(summary["avg_latency_seconds"]) for policy, summary in summaries.items()},
        paths.plot_dir / "basic_llm_vs_full_pipeline_latency.png",
        title="Evaluation inference latency by baseline vs Bandits policy",
        ylabel="seconds",
        rotation=20,
    )
    _plot_bar(
        {str(item["seed"]): float(item["reward_delta"]) for item in deltas},
        paths.plot_dir / "basic_llm_vs_full_pipeline_reward_delta_by_seed.png",
        title="PolyGuard + Bandits minus baseline reward by matched seed",
        ylabel="reward delta",
        rotation=35,
    )
    return report
def copy_available_combined_charts(paths: EvidencePaths) -> list[str]:
    """Copy a fixed set of combined charts from ``outputs/plots`` into the docs.

    Each listed PNG is passed to ``_copy_file`` with a destination under
    ``paths.docs_charts_dir / "local_available_combined"``.

    Returns:
        Destination paths (as strings) for the files where ``_copy_file``
        returned a truthy value, in listing order.
    """
    chart_names = (
        "sft_loss_curves.png",
        "grpo_reward_curves.png",
        "sft_vs_grpo_reward.png",
        "qwen_model_sft_loss.png",
        "qwen_model_sft_reward.png",
        "qwen_model_grpo_reward.png",
        "reward_component_bars.png",
        "train_holdout_gap.png",
        "sft_validity_reward.png",
        "inference_validity_reward.png",
        "inference_latency_validity.png",
        "anti_cheat_failure_rates.png",
        "policy_stack_avg_reward.png",
        "avg_reward.png",
        "legality_rate.png",
    )
    plots_dir = ROOT / "outputs" / "plots"
    dest_dir = paths.docs_charts_dir / "local_available_combined"
    return [
        str(dest_dir / name)
        for name in chart_names
        if _copy_file(plots_dir / name, dest_dir / name)
    ]
def mirror_to_docs(paths: EvidencePaths) -> list[str]:
    """Mirror generated reports, plots and traces into the docs tree.

    Steps, in order: report files (``.json``/``.jsonl``/``.md``/``.txt``) go to
    ``paths.docs_reports_dir``; PNG plots go to ``docs_charts_dir/generated``;
    ``action_traces.jsonl`` (when present) goes to ``paths.docs_traces_dir``;
    finally the pre-existing combined charts are copied.

    Returns:
        The accumulated list of destination paths reported by each step.
    """
    mirrored: list[str] = [
        *_copy_tree_files(paths.report_dir, paths.docs_reports_dir, {".json", ".jsonl", ".md", ".txt"}),
        *_copy_tree_files(paths.plot_dir, paths.docs_charts_dir / "generated", {".png"}),
    ]

    traces = paths.report_dir / "action_traces.jsonl"
    if traces.exists():
        _copy_file(traces, paths.docs_traces_dir / "action_traces.jsonl")
        # Recorded regardless of what ``_copy_file`` returns.
        mirrored.append(str(paths.docs_traces_dir / "action_traces.jsonl"))

    mirrored += copy_available_combined_charts(paths)
    return mirrored
def build_readme(
    *,
    runs: list[RunEvidence],
    manifest: dict[str, Any],
    paths: EvidencePaths,
    basic_report: dict[str, Any],
) -> str:
    """Render the Markdown README for the evidence bundle.

    Args:
        runs: Per-model evidence; ``label``, ``statuses`` and ``metrics`` are read.
        manifest: Source of ``pending_artifacts`` and ``charts``.
        paths: Accepted for interface symmetry; not read here.
        basic_report: Source of the judge name, matched seeds and reward delta.

    Returns:
        The README text, newline-joined and ending with a trailing newline.
    """

    def _cell(metrics: dict[str, Any], key: str, render: Any) -> str:
        # Metrics that are absent or ``None`` show up as "pending".
        value = metrics.get(key)
        return "pending" if value is None else render(value)

    model_scope = format_model_scope(run.label for run in runs)

    table_rows: list[str] = []
    for run in runs:
        sft_status = run.statuses.get("sft_training", "unknown")
        grpo_status = run.statuses.get("grpo_training", "unknown")
        loss_cell = _cell(run.metrics, "sft_train_loss", lambda v: f"{float(v):.4f}")
        reward_cell = _cell(run.metrics, "sft_avg_env_reward", lambda v: f"{clamp_reward(v):.3f}")
        latency_cell = _cell(run.metrics, "sft_avg_latency_seconds", lambda v: f"{float(v):.3f}s")
        table_rows.append(
            f"| {run.label} | {sft_status} | {grpo_status} | {loss_cell} | {reward_cell} | {latency_cell} |"
        )

    pending = manifest.get("pending_artifacts", [])
    if pending:
        pending_lines = [f"- {item}" for item in pending]
    else:
        pending_lines = ["- No pending artifact markers were emitted."]

    chart_lines = [f"- `{Path(path).name}`" for path in manifest.get("charts", {}).values()]

    header = [
        f"# PolyGuard Submission Evidence: {model_scope}",
        "",
        "This folder is generated without retraining. It uses already completed HF Space status, local mirrored sweep artifacts, and deterministic PolyGuard verifier rollouts.",
        "",
    ]
    status_section = [
        "## Run Status",
        "",
        "| Model | SFT training | GRPO training | SFT loss | SFT verifier reward | SFT latency |",
        "| --- | --- | --- | ---: | ---: | ---: |",
        *table_rows,
        "",
    ]
    comparison_section = [
        "## Basic LLM vs Full PolyGuard + Bandits Pipeline",
        "",
        f"- Judge: `{basic_report.get('judge', 'PolyGuard verifier/reward system')}`.",
        f"- Matched seeds: `{len(basic_report.get('matched_seeds', []))}`.",
        f"- PolyGuard + Bandits minus basic average reward delta: `{float(basic_report.get('pipeline_minus_basic_reward_delta', 0.0)):.3f}`.",
        "- LLM-as-judge is optional and disabled unless `POLYGUARD_ENABLE_LLM_JUDGE=true`.",
        "",
    ]
    pending_section = ["## Pending Items", "", *pending_lines, ""]
    charts_section = ["## Generated Charts", "", *chart_lines, ""]
    honesty_section = [
        "## Important Honesty Note",
        "",
        "Remote-completed stages and uploaded artifact files are tracked separately. If a GRPO run completed on the HF Space but the per-run GRPO history file has not been uploaded yet, this bundle labels it as `remote_completed_pending_artifact_upload` instead of inventing a curve.",
        "",
    ]

    return "\n".join(
        header + status_section + comparison_section + pending_section + charts_section + honesty_section
    )
def zip_docs_bundle(paths: EvidencePaths) -> None:
    """Write every file under ``paths.docs_dir`` into ``paths.bundle_zip``.

    Any existing archive is deleted first. Entries are stored relative to the
    parent of ``docs_dir`` (so the docs folder name is the archive's top-level
    directory), DEFLATE-compressed, and ``.DS_Store`` files are left out.
    """
    bundle = paths.bundle_zip
    bundle.parent.mkdir(parents=True, exist_ok=True)
    if bundle.exists():
        bundle.unlink()

    with zipfile.ZipFile(bundle, "w", compression=zipfile.ZIP_DEFLATED) as archive:
        for entry in paths.docs_dir.rglob("*"):
            if not entry.is_file() or entry.name == ".DS_Store":
                continue
            archive.write(entry, arcname=str(entry.relative_to(paths.docs_dir.parent)))
def validate_rewards_in_report(report: dict[str, Any]) -> list[str]:
    """Check that reward values in a comparison report survive ``clamp_reward``.

    A value passes when ``clamp_reward(value) == float(value)``. The checks
    cover ``summaries[*].avg_reward`` (``None`` and non-dict summaries are
    skipped) and the ``basic_reward`` / ``pipeline_reward`` fields of each
    entry in ``deltas`` (only when ``deltas`` is a list).

    Returns:
        Human-readable error strings; empty when every value passes.
    """
    problems: list[str] = []

    summaries = report.get("summaries", {})
    if isinstance(summaries, dict):
        for policy, summary in summaries.items():
            if not isinstance(summary, dict):
                continue
            avg = summary.get("avg_reward")
            if avg is not None and clamp_reward(avg) != float(avg):
                problems.append(f"{policy}.avg_reward is not clamped/rounded: {avg}")

    delta_items = report.get("deltas")
    if not isinstance(delta_items, list):
        delta_items = []
    for item in delta_items:
        for key in ("basic_reward", "pipeline_reward"):
            if key not in item:
                continue
            if clamp_reward(item[key]) != float(item[key]):
                problems.append(f"delta seed {item.get('seed')} {key} is not clamped/rounded")

    return problems
def generate_evidence(
    *,
    models: list[str],
    artifact_repo_id: str,
    training_space_url: str,
    paths: EvidencePaths,
    episodes: int,
    local_only: bool,
    replace: bool = True,
) -> dict[str, Any]:
    """Run the full evidence-generation flow and return the manifest.

    Sequence: prepare output directories; gather the remote artifact listing,
    snapshot and live status (substituting the local status fallback when the
    live status is ``error``/``skipped_local_only`` and a fallback is
    available); collect per-run artifacts; produce training, stage-duration,
    ablation and basic-vs-pipeline charts; derive pending-artifact markers;
    validate rewards; then write manifest, summary, README, docs mirror and
    the zip bundle.

    Args:
        models: Model names; each is converted with ``safe_run_id``.
        artifact_repo_id: Artifact repository identifier.
        training_space_url: URL queried for live training status.
        paths: Output locations.
        episodes: Rollout count for the ablation / comparison runs.
        local_only: Forwarded to the remote helpers.
        replace: Forwarded to ``ensure_clean_dir``.

    Returns:
        The manifest dict, including ``bundle_zip`` and ``mirrored_file_count``.
    """
    for directory in (paths.report_dir, paths.plot_dir, paths.docs_dir):
        ensure_clean_dir(directory, replace=replace)
    paths.bundle_zip.parent.mkdir(parents=True, exist_ok=True)

    run_ids = [safe_run_id(model) for model in models]
    hf_token = os.getenv("HF_TOKEN")
    artifact_listing = list_remote_artifacts(artifact_repo_id, token=hf_token, local_only=local_only)
    remote_snapshot = download_remote_snapshot(
        artifact_repo_id, token=hf_token, run_ids=run_ids, local_only=local_only
    )
    live_status = fetch_live_status(training_space_url, token=hf_token, local_only=local_only)
    if live_status.get("status") in {"error", "skipped_local_only"}:
        local_status = local_status_fallback()
        if local_status.get("status") != "unavailable":
            live_status = local_status
    write_json(paths.report_dir / "hf_status_snapshot.json", live_status)
    write_json(paths.report_dir / "artifact_repo_listing.json", artifact_listing)

    stage_records = extract_stage_records(live_status, run_ids)
    runs = []
    for run_id in run_ids:
        runs.append(
            collect_run_artifacts(
                run_id, paths=paths, remote_snapshot=remote_snapshot, stage_records=stage_records
            )
        )

    charts: dict[str, str] = {}
    charts.update(generate_training_charts(runs, paths))
    charts.update(generate_stage_duration_chart(stage_records, paths))

    # Prefer an ablation report that already exists; otherwise compute one.
    ablation = load_available_ablation(paths, runs)
    if ablation:
        write_json(paths.report_dir / "policy_ablation_report.json", ablation)
    else:
        ablation = maybe_run_policy_ablation(paths, episodes)
    charts.update(generate_ablation_charts(ablation, paths))

    basic_report = run_basic_llm_vs_pipeline(paths, episodes=episodes)
    # Chart keys double as the PNG file stems written by the comparison run.
    for stem in (
        "basic_llm_vs_full_pipeline_reward",
        "basic_llm_vs_full_pipeline_legality",
        "basic_llm_vs_full_pipeline_latency",
        "basic_llm_vs_full_pipeline_reward_delta_by_seed",
    ):
        charts[stem] = str(paths.plot_dir / f"{stem}.png")

    unresolved_statuses = {"not_seen_in_status", "remote_failed_or_running"}
    pending_artifacts: list[str] = []
    for run in runs:
        pending_artifacts.extend(
            f"{run.label} {stage}: {status}"
            for stage, status in run.statuses.items()
            if "pending" in status or status in unresolved_statuses
        )
        for required in ("grpo_history.json", "postsave_inference_grpo.json"):
            if not run.files.get(required):
                pending_artifacts.append(f"{run.label} {required}: pending_artifact_upload")

    reward_validation_errors = validate_rewards_in_report(basic_report)
    model_entries = [
        {
            "run_id": run.run_id,
            "model_id": run.model_id,
            "label": run.label,
            "statuses": run.statuses,
            "metrics": run.metrics,
            "files": run.files,
        }
        for run in runs
    ]
    manifest = {
        "status": "failed_reward_validation" if reward_validation_errors else "ok",
        "generated_at_unix": time.time(),
        "models": model_entries,
        "artifact_repo": artifact_listing,
        "remote_snapshot_used": str(remote_snapshot) if remote_snapshot else "",
        "training_space_status": {
            "status": live_status.get("status"),
            "source": live_status.get("source"),
            "completed_run_ids": live_status.get("completed_run_ids", []),
        },
        "stage_records": stage_records,
        "charts": charts,
        "pending_artifacts": sorted(set(pending_artifacts)),
        "reward_validation_errors": reward_validation_errors,
        "primary_judge": "PolyGuard verifier/reward system",
    }

    # First pass: report-side files are written before mirroring so that the
    # mirror step can pick them up.
    write_json(paths.report_dir / "manifest.json", manifest)
    write_json(paths.report_dir / "submission_summary.json", manifest)
    readme = build_readme(runs=runs, manifest=manifest, paths=paths, basic_report=basic_report)
    write_text(paths.report_dir / "README.md", readme)

    mirrored = mirror_to_docs(paths)
    write_text(paths.docs_dir / "README.md", readme)
    write_json(paths.docs_dir / "manifest.json", manifest)
    write_json(paths.docs_dir / "submission_summary.json", manifest)
    write_json(paths.report_dir / "mirrored_files.json", mirrored)
    zip_docs_bundle(paths)

    # Second pass: only ``manifest.json`` is refreshed with the bundle fields.
    # NOTE(review): ``submission_summary.json`` and the manifest inside the zip
    # keep the first-pass content — confirm that is intended.
    manifest["bundle_zip"] = str(paths.bundle_zip)
    manifest["mirrored_file_count"] = len(mirrored)
    write_json(paths.report_dir / "manifest.json", manifest)
    write_json(paths.docs_dir / "manifest.json", manifest)
    return manifest
def main() -> None:
    """Parse CLI arguments, generate the evidence bundle and print a summary.

    The comma-separated ``--models`` value is split into non-empty, stripped
    names. The printed JSON contains the manifest status plus the docs
    directory and bundle zip locations.
    """
    args = parse_args()
    stripped = (part.strip() for part in args.models.split(","))
    model_names = [name for name in stripped if name]

    evidence_paths = EvidencePaths(
        report_dir=Path(args.output_dir),
        plot_dir=Path(args.plot_dir),
        docs_dir=Path(args.docs_dir),
        bundle_zip=Path(args.bundle_zip),
    )
    manifest = generate_evidence(
        models=model_names,
        artifact_repo_id=args.artifact_repo_id,
        training_space_url=args.training_space_url,
        paths=evidence_paths,
        episodes=args.episodes,
        local_only=args.local_only,
        replace=args.replace,
    )

    summary = {
        "status": manifest["status"],
        "docs_dir": str(evidence_paths.docs_dir),
        "bundle_zip": str(evidence_paths.bundle_zip),
    }
    print(json.dumps(summary, indent=2))


# Run the generator only when executed as a script, not on import.
if __name__ == "__main__":
    main()