"""Gradio runner for the private Hugging Face training Space.""" from __future__ import annotations import json import os from pathlib import Path import shutil import subprocess import threading import time from typing import Any import gradio as gr from huggingface_hub import HfApi from huggingface_hub import snapshot_download ROOT = Path(__file__).resolve().parents[2] LOG_DIR = ROOT / "outputs" / "logs" REPORT_DIR = ROOT / "outputs" / "reports" STATUS_PATH = REPORT_DIR / "hf_training_status.json" LOG_PATH = LOG_DIR / "hf_training.log" LOCK = threading.Lock() STATUS: dict[str, Any] = { "status": "idle", "started_at": None, "finished_at": None, "commands": [], "artifact_repo_id": os.getenv("POLYGUARD_ARTIFACT_REPO_ID", "TheJackBright/polyguard-openenv-training-full-artifacts"), "training_mode": os.getenv("POLYGUARD_TRAINING_MODE", "full"), "model_sweep": os.getenv( "POLYGUARD_MODEL_SWEEP", "Qwen/Qwen2.5-0.5B-Instruct,Qwen/Qwen2.5-1.5B-Instruct,Qwen/Qwen2.5-3B-Instruct", ), } def _env_bool(name: str, default: bool = False) -> bool: value = os.getenv(name) if value is None: return default return value.lower() in {"1", "true", "yes", "on"} def _write_status() -> None: REPORT_DIR.mkdir(parents=True, exist_ok=True) STATUS_PATH.write_text(json.dumps(STATUS, ensure_ascii=True, indent=2), encoding="utf-8") def _append_log(message: str) -> None: LOG_DIR.mkdir(parents=True, exist_ok=True) with LOG_PATH.open("a", encoding="utf-8") as handle: handle.write(message.rstrip() + "\n") def _upload_relpath(rel: str, *, commit_suffix: str = "") -> None: if not _env_bool("POLYGUARD_INCREMENTAL_UPLOAD", True): return token = os.getenv("HF_TOKEN") repo_id = os.getenv("POLYGUARD_ARTIFACT_REPO_ID", "TheJackBright/polyguard-openenv-training-full-artifacts") if not token: return path = ROOT / rel if not path.exists(): return try: api = HfApi(token=token) api.create_repo(repo_id=repo_id, repo_type="model", private=True, exist_ok=True) if path.is_file(): api.upload_file( repo_id=repo_id, repo_type="model", path_or_fileobj=str(path), path_in_repo=rel, commit_message=f"Upload PolyGuard artifact: {commit_suffix or rel}", ) else: api.upload_folder( repo_id=repo_id, repo_type="model", folder_path=str(path), path_in_repo=rel, commit_message=f"Upload PolyGuard artifact folder: {commit_suffix or rel}", ignore_patterns=[".DS_Store", "**/.DS_Store"], ) except Exception as exc: # noqa: BLE001 _append_log(f"incremental_upload_skipped:{rel}:{exc}") def _upload_status_and_log(context: str) -> None: _upload_relpath("outputs/reports/hf_training_status.json", commit_suffix=f"status {context}") _upload_relpath("outputs/logs/hf_training.log", commit_suffix=f"log {context}") def _upload_run_snapshot(run_id: str, stage: str) -> None: if not _env_bool("POLYGUARD_UPLOAD_AFTER_EACH_STAGE", True): return _upload_status_and_log(f"{run_id} {stage}") _upload_relpath(f"outputs/reports/sweeps/{run_id}", commit_suffix=f"{run_id} reports after {stage}") _upload_relpath(f"checkpoints/sweeps/{run_id}", commit_suffix=f"{run_id} checkpoints after {stage}") def _run_command(args: list[str], env: dict[str, str]) -> None: started = time.time() last_incremental_upload = started _append_log(f"$ {' '.join(args)}") proc = subprocess.Popen( args, cwd=ROOT, env=env, text=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, ) assert proc.stdout is not None saw_output = False for line in proc.stdout: saw_output = True _append_log(line) now = time.time() if now - last_incremental_upload >= _env_int("POLYGUARD_LOG_UPLOAD_INTERVAL_SECONDS", 180): _upload_status_and_log("running") last_incremental_upload = now proc.wait() elapsed = round(time.time() - started, 3) record = { "args": args, "returncode": proc.returncode, "elapsed_seconds": elapsed, } with LOCK: STATUS["commands"].append(record) _write_status() _upload_status_and_log("command_complete") if proc.returncode != 0: if not saw_output: _append_log("") _upload_status_and_log("command_failed") raise RuntimeError(f"command_failed:{args}:{proc.returncode}") def _env_int(name: str, default: int) -> int: try: return int(os.getenv(name, str(default))) except ValueError: return default def _env_float(name: str, default: float) -> float: try: return float(os.getenv(name, str(default))) except ValueError: return default def _csv_env(name: str, default: str) -> list[str]: value = os.getenv(name, default) return [item.strip() for item in value.split(",") if item.strip()] def _indexed_int_env(name: str, index: int, default: int) -> int: values = _csv_env(name, "") if index >= len(values): return default try: return int(values[index]) except ValueError: return default def _indexed_float_env(name: str, index: int, default: float) -> float: values = _csv_env(name, "") if index >= len(values): return default try: return float(values[index]) except ValueError: return default def _safe_name(value: str) -> str: return "".join(ch if ch.isalnum() else "-" for ch in value).strip("-").lower() def _copy_file_if_exists(source: Path, target: Path) -> None: if source.exists(): target.parent.mkdir(parents=True, exist_ok=True) shutil.copy2(source, target) def _copy_dir_if_exists(source: Path, target: Path) -> None: if source.exists(): target.parent.mkdir(parents=True, exist_ok=True) shutil.copytree(source, target, dirs_exist_ok=True) def _record_reused_artifact(name: str, path: Path) -> None: with LOCK: STATUS["commands"].append( { "args": ["reuse_artifact", name, str(path)], "returncode": 0, "elapsed_seconds": 0.0, } ) _write_status() def _restore_remote_artifacts() -> None: if os.getenv("POLYGUARD_REUSE_REMOTE_GRPO", "false").lower() not in {"1", "true", "yes", "on"}: return token = os.getenv("HF_TOKEN") repo_id = os.getenv("POLYGUARD_ARTIFACT_REPO_ID", "TheJackBright/polyguard-openenv-training-full-artifacts") if not token: return try: snapshot = Path( snapshot_download( repo_id=repo_id, repo_type="model", token=token, allow_patterns=[ "checkpoints/grpo_adapter/*", "outputs/reports/grpo_trl_run.json", ], ) ) except Exception as exc: # noqa: BLE001 _append_log(f"remote_artifact_restore_skipped:{exc}") return for rel in ["checkpoints/grpo_adapter", "outputs/reports/grpo_trl_run.json"]: source = snapshot / rel target = ROOT / rel if source.is_dir(): shutil.copytree(source, target, dirs_exist_ok=True) elif source.is_file(): target.parent.mkdir(parents=True, exist_ok=True) shutil.copy2(source, target) def _grpo_artifact_ready() -> bool: report = REPORT_DIR / "grpo_trl_run.json" adapter = ROOT / "checkpoints" / "grpo_adapter" if not report.exists() or not adapter.exists(): return False if not (adapter / "adapter_config.json").exists() or not (adapter / "adapter_model.safetensors").exists(): return False try: payload = json.loads(report.read_text(encoding="utf-8")) except json.JSONDecodeError: return False return payload.get("status") == "ok" and bool(payload.get("artifact_path")) def _mirror_results() -> None: docs_results = ROOT / "docs" / "results" docs_results.mkdir(parents=True, exist_ok=True) for source_dir in [REPORT_DIR, ROOT / "outputs" / "plots"]: if not source_dir.exists(): continue for path in source_dir.rglob("*"): if path.is_file() and path.suffix.lower() in {".json", ".txt", ".png"}: target = docs_results / path.relative_to(source_dir) target.parent.mkdir(parents=True, exist_ok=True) shutil.copy2(path, target) def _upload_artifacts() -> None: token = os.getenv("HF_TOKEN") repo_id = os.getenv("POLYGUARD_ARTIFACT_REPO_ID", "TheJackBright/polyguard-openenv-training-full-artifacts") if not token: _append_log("HF_TOKEN missing; artifact upload skipped") return api = HfApi(token=token) api.create_repo(repo_id=repo_id, repo_type="model", private=True, exist_ok=True) for rel in [ "outputs/reports", "outputs/plots", "docs/results", "checkpoints/sft_adapter", "checkpoints/grpo_adapter", "checkpoints/merged", "checkpoints/sweeps", ]: path = ROOT / rel if path.exists(): api.upload_folder( repo_id=repo_id, repo_type="model", folder_path=str(path), path_in_repo=rel, commit_message=f"Upload PolyGuard training artifacts: {rel}", ) def _improved() -> bool: path = REPORT_DIR / "improvement_report.json" if not path.exists(): return False try: payload = json.loads(path.read_text(encoding="utf-8")) except json.JSONDecodeError: return False return payload.get("improved") is True def _promote_run_artifacts(run_id: str) -> None: checkpoint_dir = ROOT / "checkpoints" / "sweeps" / run_id report_dir = REPORT_DIR / "sweeps" / run_id _copy_dir_if_exists(checkpoint_dir / "sft_adapter", ROOT / "checkpoints" / "sft_adapter") _copy_dir_if_exists(checkpoint_dir / "grpo_adapter", ROOT / "checkpoints" / "grpo_adapter") _copy_dir_if_exists(checkpoint_dir / "merged", ROOT / "checkpoints" / "merged") _copy_file_if_exists(report_dir / "sft_trl_run.json", REPORT_DIR / "sft_trl_run.json") _copy_file_if_exists(report_dir / "grpo_trl_run.json", REPORT_DIR / "grpo_trl_run.json") _copy_file_if_exists(report_dir / "postsave_inference_grpo.json", REPORT_DIR / "postsave_inference.json") _copy_file_if_exists(report_dir / "grpo_ablation_report.json", REPORT_DIR / "grpo_ablation_report.json") def _promote_sft_run_artifacts(run_id: str) -> None: checkpoint_dir = ROOT / "checkpoints" / "sweeps" / run_id report_dir = REPORT_DIR / "sweeps" / run_id _copy_dir_if_exists(checkpoint_dir / "sft_adapter", ROOT / "checkpoints" / "sft_adapter") _copy_dir_if_exists(checkpoint_dir / "merged", ROOT / "checkpoints" / "merged") _copy_file_if_exists(report_dir / "sft_trl_run.json", REPORT_DIR / "sft_trl_run.json") _copy_file_if_exists(report_dir / "postsave_inference_sft.json", REPORT_DIR / "postsave_inference.json") def _run_model_experiment( model_id: str, env: dict[str, str], *, model_index: int, run_grpo: bool, ) -> str: run_id = _safe_name(model_id) checkpoint_dir = ROOT / "checkpoints" / "sweeps" / run_id report_dir = REPORT_DIR / "sweeps" / run_id checkpoint_dir.mkdir(parents=True, exist_ok=True) report_dir.mkdir(parents=True, exist_ok=True) sft_epochs = _indexed_int_env("POLYGUARD_SFT_EPOCH_SWEEP", model_index, _env_int("POLYGUARD_SFT_EPOCHS", 2)) sft_max_steps = _indexed_int_env( "POLYGUARD_SFT_MAX_STEP_SWEEP", model_index, _env_int("POLYGUARD_SFT_MAX_STEPS", 0), ) sft_batch_size = _indexed_int_env( "POLYGUARD_SFT_BATCH_SIZE_SWEEP", model_index, _env_int("POLYGUARD_SFT_BATCH_SIZE", 2), ) sft_learning_rate = _indexed_float_env( "POLYGUARD_SFT_LEARNING_RATE_SWEEP", model_index, _env_float("POLYGUARD_SFT_LEARNING_RATE", 2e-5), ) grpo_epochs = _env_float("POLYGUARD_GRPO_EPOCHS", 1.0) grpo_max_steps = _env_int("POLYGUARD_GRPO_MAX_STEPS", 0) grpo_max_prompts = _env_int("POLYGUARD_GRPO_MAX_PROMPTS", 0) _append_log(f"model_experiment_start:{model_id}") (report_dir / "run_metadata.json").write_text( json.dumps( { "training_mode": "full" if run_grpo else "sft-baseline", "model_id": model_id, "model_index": model_index, "sft_epochs": sft_epochs, "sft_max_steps": sft_max_steps, "sft_batch_size": sft_batch_size, "sft_learning_rate": sft_learning_rate, }, ensure_ascii=True, indent=2, ), encoding="utf-8", ) _run_command( [ "python", "scripts/train_sft_trl.py", "--model-id", model_id, "--dataset-path", "data/processed/training_corpus_sft.json", "--output-dir", f"checkpoints/sweeps/{run_id}", "--report-path", f"outputs/reports/sweeps/{run_id}/sft_trl_run.json", "--epochs", str(sft_epochs), "--max-steps", str(sft_max_steps), "--batch-size", str(sft_batch_size), "--max-seq-len", str(_env_int("POLYGUARD_SFT_MAX_SEQ_LEN", 512)), "--learning-rate", str(sft_learning_rate), "--use-unsloth", ], env, ) _copy_file_if_exists(checkpoint_dir / "sft_history.json", report_dir / "sft_history.json") _upload_run_snapshot(run_id, "sft_training") if run_grpo: _run_command( [ "python", "scripts/train_grpo_trl.py", "--model-id", model_id, "--prompts-path", "data/processed/training_corpus_grpo_prompts.jsonl", "--output-dir", f"checkpoints/sweeps/{run_id}", "--report-path", f"outputs/reports/sweeps/{run_id}/grpo_trl_run.json", "--max-prompts", str(grpo_max_prompts), "--max-steps", str(grpo_max_steps), "--epochs", str(grpo_epochs), "--batch-size", str(_env_int("POLYGUARD_GRPO_BATCH_SIZE", 2)), "--grad-accum", str(_env_int("POLYGUARD_GRPO_GRAD_ACCUM", 1)), "--num-generations", str(_env_int("POLYGUARD_GRPO_NUM_GENERATIONS", 2)), "--max-prompt-length", str(_env_int("POLYGUARD_GRPO_MAX_PROMPT_LENGTH", 384)), "--max-completion-length", str(_env_int("POLYGUARD_GRPO_MAX_COMPLETION_LENGTH", 64)), "--learning-rate", str(_env_float("POLYGUARD_GRPO_LEARNING_RATE", 1e-6)), "--use-unsloth", ], env, ) _copy_file_if_exists(checkpoint_dir / "grpo_history.json", report_dir / "grpo_history.json") _copy_file_if_exists(checkpoint_dir / "grpo_reward_components.jsonl", report_dir / "grpo_reward_components.jsonl") _upload_run_snapshot(run_id, "grpo_training") _run_command( [ "python", "scripts/merge_adapters_safe.py", "--adapter-dir", f"checkpoints/sweeps/{run_id}/sft_adapter", "--output-dir", f"checkpoints/sweeps/{run_id}/merged", ], env, ) _upload_run_snapshot(run_id, "sft_merge") _run_command( [ "python", "scripts/test_inference_postsave.py", "--samples", str(_env_int("POLYGUARD_INFERENCE_SAMPLES", 5)), "--base-model", model_id, "--merged-model", f"checkpoints/sweeps/{run_id}/merged", "--adapter-dir", f"checkpoints/sweeps/{run_id}/sft_adapter", "--output", f"outputs/reports/sweeps/{run_id}/postsave_inference_sft.json", ], env, ) _upload_run_snapshot(run_id, "sft_postsave_inference") if run_grpo: _run_command( [ "python", "scripts/test_inference_postsave.py", "--samples", str(_env_int("POLYGUARD_INFERENCE_SAMPLES", 5)), "--base-model", model_id, "--merged-model", f"checkpoints/sweeps/{run_id}/missing_merged_grpo", "--adapter-dir", f"checkpoints/sweeps/{run_id}/grpo_adapter", "--output", f"outputs/reports/sweeps/{run_id}/postsave_inference_grpo.json", ], env, ) _upload_run_snapshot(run_id, "grpo_postsave_inference") _run_command( [ "python", "scripts/evaluate_policy_ablations.py", "--episodes", str(_env_int("POLYGUARD_ABLATION_EPISODES", 8)), "--checkpoint-dir", f"checkpoints/sweeps/{run_id}", "--output", f"outputs/reports/sweeps/{run_id}/grpo_ablation_report.json", ], env, ) _promote_run_artifacts(run_id) _upload_run_snapshot(run_id, "policy_ablation") for rel in [ "checkpoints/sft_adapter", "checkpoints/grpo_adapter", "checkpoints/merged", "outputs/reports/sft_trl_run.json", "outputs/reports/grpo_trl_run.json", "outputs/reports/postsave_inference.json", "outputs/reports/grpo_ablation_report.json", ]: _upload_relpath(rel, commit_suffix=f"promoted {run_id}") else: _promote_sft_run_artifacts(run_id) _upload_run_snapshot(run_id, "sft_promoted") for rel in [ "checkpoints/sft_adapter", "checkpoints/merged", "outputs/reports/sft_trl_run.json", "outputs/reports/postsave_inference.json", ]: _upload_relpath(rel, commit_suffix=f"promoted {run_id}") _append_log(f"model_experiment_done:{model_id}") _upload_run_snapshot(run_id, "complete") return run_id def _train() -> dict[str, Any]: training_mode = os.getenv("POLYGUARD_TRAINING_MODE", "full").strip().lower() run_grpo = training_mode not in {"sft", "sft-only", "sft-baseline", "sft_baseline"} model_sweep = _csv_env( "POLYGUARD_MODEL_SWEEP", "Qwen/Qwen2.5-0.5B-Instruct,Qwen/Qwen2.5-1.5B-Instruct,Qwen/Qwen2.5-3B-Instruct", ) env = os.environ.copy() env.setdefault("POLYGUARD_OFFLINE_MODE", "false") env.pop("HF_HUB_ENABLE_HF_TRANSFER", None) env.setdefault("TOKENIZERS_PARALLELISM", "false") setup_commands = [ ["python", "scripts/bootstrap_data.py"], ["python", "scripts/build_training_corpus.py", "--profile", "massive", "--with-local", "--with-synthetic", "--with-hf"], ] with LOCK: STATUS.update( { "status": "running", "started_at": time.time(), "finished_at": None, "commands": [], "model_sweep": model_sweep, "training_mode": "full" if run_grpo else "sft-baseline", } ) _write_status() LOG_PATH.unlink(missing_ok=True) _restore_remote_artifacts() try: for command in setup_commands: _run_command(command, env) completed_run_ids: list[str] = [] for model_index, model_id in enumerate(model_sweep): run_id = _safe_name(model_id) try: completed_run_ids.append( _run_model_experiment( model_id=model_id, env=env, model_index=model_index, run_grpo=run_grpo, ) ) except Exception as exc: # noqa: BLE001 error_dir = REPORT_DIR / "sweeps" / run_id error_dir.mkdir(parents=True, exist_ok=True) (error_dir / "error.json").write_text( json.dumps( {"status": "failed", "model_id": model_id, "error": str(exc)}, ensure_ascii=True, indent=2, ), encoding="utf-8", ) _append_log(f"model_experiment_failed:{model_id}:{exc}") _upload_run_snapshot(run_id, "failed") if not completed_run_ids: raise RuntimeError("all_model_experiments_failed") if run_grpo and _grpo_artifact_ready(): _append_log("top_level_grpo_adapter_ready") _record_reused_artifact("grpo_adapter", ROOT / "checkpoints" / "grpo_adapter") eval_commands = [ ["python", "scripts/evaluate_baselines.py"], ["python", "scripts/evaluate_all.py"], [ "python", "scripts/evaluate_compare_runs.py", "--baseline", "outputs/reports/baselines.json", "--candidate", "outputs/reports/benchmark_report.json", "--output", "outputs/reports/improvement_report.json", ], ["python", "scripts/benchmark_inference.py"], ] if run_grpo: eval_commands.append(["python", "scripts/run_robustness_suite.py"]) eval_commands.append(["python", "scripts/generate_hf_training_report.py", "--mode", "full" if run_grpo else "sft-baseline"]) for command in eval_commands: _run_command(command, env) anti_hacking = {} anti_path = REPORT_DIR / "anti_hacking_overfit_report.json" if anti_path.exists(): anti_hacking = json.loads(anti_path.read_text(encoding="utf-8")) with LOCK: STATUS.update( { "status": "ok", "finished_at": time.time(), "improved": _improved(), "anti_hacking_passed": anti_hacking.get("passed"), "completed_run_ids": completed_run_ids, } ) _write_status() _mirror_results() _upload_artifacts() except Exception as exc: # noqa: BLE001 _append_log(str(exc)) with LOCK: STATUS.update({"status": "failed", "finished_at": time.time(), "error": str(exc)}) _write_status() _mirror_results() _upload_artifacts() return STATUS def run_training() -> tuple[dict[str, Any], str]: with LOCK: if STATUS.get("status") == "running": return STATUS, LOG_PATH.read_text(encoding="utf-8") if LOG_PATH.exists() else "" thread = threading.Thread(target=_train, daemon=True) thread.start() return STATUS, "training started" def read_status() -> tuple[dict[str, Any], str]: log = LOG_PATH.read_text(encoding="utf-8") if LOG_PATH.exists() else "" return STATUS, log[-20000:] def build_app() -> gr.Blocks: with gr.Blocks(title="PolyGuard HF Training") as demo: gr.Markdown("# PolyGuard HF Training") run_button = gr.Button("Run training", variant="primary") refresh_button = gr.Button("Refresh") status_box = gr.JSON(label="Status", value=STATUS) log_box = gr.Textbox(label="Logs", lines=26) run_button.click(fn=run_training, outputs=[status_box, log_box]) refresh_button.click(fn=read_status, outputs=[status_box, log_box]) return demo if os.getenv("POLYGUARD_AUTORUN", "1").lower() in {"1", "true", "yes", "on"}: threading.Thread(target=_train, daemon=True).start() app = build_app() if __name__ == "__main__": app.launch(server_name="0.0.0.0", server_port=7860)