"""Gradio runner for the private Hugging Face training Space."""

from __future__ import annotations

import json
import os
from pathlib import Path
import shutil
import subprocess
import threading
import time
from typing import Any

import gradio as gr
from huggingface_hub import HfApi

ROOT = Path(__file__).resolve().parents[2]
LOG_DIR = ROOT / "outputs" / "logs"
REPORT_DIR = ROOT / "outputs" / "reports"
STATUS_PATH = REPORT_DIR / "hf_training_status.json"
LOG_PATH = LOG_DIR / "hf_training.log"

# Guards STATUS. It is a plain (non-reentrant) lock, so helpers that are
# called while it is held must not try to acquire it again.
LOCK = threading.Lock()
STATUS: dict[str, Any] = {
    "status": "idle",
    "started_at": None,
    "finished_at": None,
    "commands": [],
    "artifact_repo_id": os.getenv("POLYGUARD_ARTIFACT_REPO_ID", "TheJackBright/polyguard-openenv-training-artifacts"),
}


def _write_status() -> None:
    """Persist the in-memory STATUS dict to STATUS_PATH as indented JSON.

    Does not acquire LOCK itself, so it may be called while LOCK is held.
    """
    REPORT_DIR.mkdir(parents=True, exist_ok=True)
    STATUS_PATH.write_text(json.dumps(STATUS, ensure_ascii=True, indent=2), encoding="utf-8")


def _append_log(message: str) -> None:
    """Append *message* to the training log, normalised to end in one newline."""
    LOG_DIR.mkdir(parents=True, exist_ok=True)
    with LOG_PATH.open("a", encoding="utf-8") as handle:
        handle.write(message.rstrip() + "\n")


def _run_command(args: list[str], env: dict[str, str]) -> None:
    """Run one pipeline command from ROOT, streaming its output into the log.

    The command line is logged first, then the child's combined
    stdout/stderr is copied to the log line by line while it runs, so the
    log can be inspected before a long-running step has finished. A record
    with the arguments, return code and elapsed seconds is appended to
    ``STATUS["commands"]`` and the status file is rewritten.

    Args:
        args: Executable and arguments, passed without a shell.
        env: Complete environment for the child process.

    Raises:
        RuntimeError: If the command exits with a non-zero return code.
        OSError: If the executable cannot be started.
    """
    # Monotonic clock: the duration must not jump if the wall clock is adjusted.
    started = time.monotonic()
    _append_log(f"$ {' '.join(args)}")

    ends_with_newline = True
    with LOG_PATH.open("a", encoding="utf-8") as handle:
        with subprocess.Popen(
            args,
            cwd=ROOT,
            env=env,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            text=True,
            encoding="utf-8",
            # Undecodable bytes must not abort the pipeline or corrupt the log.
            errors="replace",
        ) as proc:
            for line in proc.stdout or ():
                handle.write(line)
                # Flush per line so readers of the log file see progress live.
                handle.flush()
                ends_with_newline = line.endswith("\n")
        # Keep the next log entry from being glued onto an unterminated line.
        if not ends_with_newline:
            handle.write("\n")

    elapsed = round(time.monotonic() - started, 3)
    record = {
        "args": args,
        "returncode": proc.returncode,
        "elapsed_seconds": elapsed,
    }
    with LOCK:
        STATUS["commands"].append(record)
        _write_status()
    if proc.returncode != 0:
        raise RuntimeError(f"command_failed:{args}:{proc.returncode}")


def _mirror_results() -> None:
    """Copy JSON, text and PNG files from the report and plot dirs to docs/results.

    Only regular files directly inside each source directory are copied;
    files with the same name overwrite each other in the destination.
    """
    docs_results = ROOT / "docs" / "results"
    docs_results.mkdir(parents=True, exist_ok=True)
    mirrored_suffixes = {".json", ".txt", ".png"}
    for source_dir in (REPORT_DIR, ROOT / "outputs" / "plots"):
        if not source_dir.exists():
            continue
        for path in source_dir.iterdir():
            if path.is_file() and path.suffix.lower() in mirrored_suffixes:
                shutil.copy2(path, docs_results / path.name)
def _upload_artifacts() -> None:
    """Upload report, plot, docs and checkpoint folders to the artifact repo.

    The repo id comes from POLYGUARD_ARTIFACT_REPO_ID (with a built-in
    default) and is created as a private model repo if needed. Folders that
    do not exist locally are skipped. Without HF_TOKEN the upload is skipped
    and a log line is written instead.
    """
    token = os.getenv("HF_TOKEN")
    repo_id = os.getenv("POLYGUARD_ARTIFACT_REPO_ID", "TheJackBright/polyguard-openenv-training-artifacts")
    if not token:
        _append_log("HF_TOKEN missing; artifact upload skipped")
        return
    api = HfApi(token=token)
    api.create_repo(repo_id=repo_id, repo_type="model", private=True, exist_ok=True)
    for rel in [
        "outputs/reports",
        "outputs/plots",
        "docs/results",
        "checkpoints/sft_adapter",
        "checkpoints/grpo_adapter",
        "checkpoints/merged",
    ]:
        path = ROOT / rel
        if path.exists():
            api.upload_folder(
                repo_id=repo_id,
                repo_type="model",
                folder_path=str(path),
                path_in_repo=rel,
                commit_message=f"Upload PolyGuard training artifacts: {rel}",
            )


def _improved() -> bool:
    """Return True only if improvement_report.json holds ``"improved": true``.

    A missing, unreadable or malformed report, or one whose top level is not
    a JSON object, counts as "not improved".
    """
    path = REPORT_DIR / "improvement_report.json"
    try:
        payload = json.loads(path.read_text(encoding="utf-8"))
    except (OSError, ValueError):
        # ValueError covers both JSONDecodeError and UnicodeDecodeError.
        return False
    return isinstance(payload, dict) and payload.get("improved") is True


# Budgets for the GRPO retry that runs when the first pass shows no improvement.
_GRPO_UPSCALE_MAX_STEPS = 40
_GRPO_UPSCALE_MAX_PROMPTS = 128
_TRUTHY = {"1", "true", "yes", "on"}


def _env_flag(name: str, default: str) -> bool:
    """Return True if environment variable *name* (or *default*) is truthy."""
    return os.getenv(name, default).lower() in _TRUTHY


def _status_snapshot() -> dict[str, Any]:
    """Return a copy of STATUS that the worker thread will not mutate."""
    with LOCK:
        return {**STATUS, "commands": list(STATUS.get("commands", []))}


def _read_log_tail(limit: int = 20000) -> str:
    """Return the last *limit* characters of the training log, or ""."""
    try:
        return LOG_PATH.read_text(encoding="utf-8", errors="replace")[-limit:]
    except OSError:
        # The log may not exist yet, or may be removed as a new run starts.
        return ""


def _claim_run() -> bool:
    """Atomically mark a run as started; return False if one is already running."""
    with LOCK:
        if STATUS.get("status") == "running":
            return False
        # Results of a previous run must not leak into this one.
        STATUS.pop("error", None)
        STATUS.pop("improved", None)
        STATUS.update({"status": "running", "started_at": time.time(), "finished_at": None, "commands": []})
        _write_status()
        return True


def _grpo_command(model_id: str, *, max_prompts: int, max_steps: int) -> list[str]:
    """Build the GRPO training command line for the given prompt/step budgets."""
    return [
        "python",
        "scripts/train_grpo_trl.py",
        "--model-id",
        model_id,
        "--max-prompts",
        str(max_prompts),
        "--max-steps",
        str(max_steps),
        "--epochs",
        "1",
        "--batch-size",
        "2",
        "--num-generations",
        "2",
        "--max-prompt-length",
        "384",
        "--max-completion-length",
        "64",
        "--use-unsloth",
    ]


def _train(claimed: bool = False) -> dict[str, Any]:
    """Run the full training and evaluation pipeline and record its outcome.

    Steps: data bootstrap, corpus build, SFT, GRPO, adapter merge, post-save
    inference test, then the evaluation scripts. If the improvement report
    does not show an improvement (and POLYGUARD_SKIP_GRPO_UPSCALE is not
    truthy), GRPO is rerun with a larger budget and the evaluations are
    repeated. Results are mirrored and uploaded on success and, best-effort,
    on failure.

    Args:
        claimed: True when the caller has already marked the run as started.
            With the default, this function claims the run itself and does
            nothing if another run is in progress.

    Returns:
        The module-level STATUS dict.
    """
    if not claimed and not _claim_run():
        return STATUS
    LOG_PATH.unlink(missing_ok=True)

    model_id = os.getenv("POLYGUARD_MODEL_ID", "Qwen/Qwen2.5-0.5B-Instruct")
    env = os.environ.copy()
    env.setdefault("POLYGUARD_OFFLINE_MODE", "false")
    env.pop("HF_HUB_ENABLE_HF_TRANSFER", None)
    env.setdefault("TOKENIZERS_PARALLELISM", "false")

    train_commands = [
        ["python", "scripts/bootstrap_data.py"],
        [
            "python",
            "scripts/build_training_corpus.py",
            "--profile",
            "massive",
            "--with-local",
            "--with-synthetic",
            "--with-hf",
        ],
        [
            "python",
            "scripts/train_sft_trl.py",
            "--model-id",
            model_id,
            "--epochs",
            "1",
            "--max-steps",
            "20",
            "--batch-size",
            "2",
            "--max-seq-len",
            "512",
            "--use-unsloth",
        ],
        _grpo_command(model_id, max_prompts=0, max_steps=0),
        [
            "python",
            "scripts/merge_adapters_safe.py",
            "--adapter-dir",
            "checkpoints/sft_adapter",
            "--output-dir",
            "checkpoints/merged",
        ],
        ["python", "scripts/test_inference_postsave.py", "--samples", "3", "--base-model", model_id],
    ]
    # Kept separate because these are repeated after a GRPO retry.
    eval_commands = [
        ["python", "scripts/evaluate_policy_ablations.py", "--episodes", "8"],
        ["python", "scripts/evaluate_baselines.py"],
        ["python", "scripts/evaluate_all.py"],
        [
            "python",
            "scripts/evaluate_compare_runs.py",
            "--baseline",
            "outputs/reports/baselines.json",
            "--candidate",
            "outputs/reports/benchmark_report.json",
            "--output",
            "outputs/reports/improvement_report.json",
        ],
    ]

    try:
        for command in train_commands + eval_commands:
            _run_command(command, env)
        if not _improved() and not _env_flag("POLYGUARD_SKIP_GRPO_UPSCALE", "false"):
            _append_log(
                f"improvement=false; rerunning GRPO with {_GRPO_UPSCALE_MAX_STEPS} steps "
                f"and {_GRPO_UPSCALE_MAX_PROMPTS} prompts"
            )
            # The retry has to differ from the first pass; it uses exactly the
            # budgets announced in the log line above.
            _run_command(
                _grpo_command(
                    model_id,
                    max_prompts=_GRPO_UPSCALE_MAX_PROMPTS,
                    max_steps=_GRPO_UPSCALE_MAX_STEPS,
                ),
                env,
            )
            for command in eval_commands:
                _run_command(command, env)
        # NOTE(review): the status file inside outputs/reports still says
        # "running" when it is mirrored/uploaded here; the final status is
        # only written below.
        _mirror_results()
        _upload_artifacts()
        outcome: dict[str, Any] = {"status": "ok", "finished_at": time.time(), "improved": _improved()}
    except Exception as exc:  # noqa: BLE001
        _append_log(str(exc))
        # Publishing partial results is best-effort: a second failure here
        # must not leave the status stuck on "running", which would block
        # every later run.
        for publish in (_mirror_results, _upload_artifacts):
            try:
                publish()
            except Exception as publish_exc:  # noqa: BLE001
                _append_log(f"{publish.__name__} failed after pipeline error: {publish_exc}")
        outcome = {"status": "failed", "finished_at": time.time(), "error": str(exc)}

    with LOCK:
        STATUS.update(outcome)
        _write_status()
    return STATUS


def run_training() -> tuple[dict[str, Any], str]:
    """Start training in a background thread unless a run is already active.

    Returns:
        A status snapshot and either the current log tail (run already in
        progress) or the text "training started".
    """
    # Claim under the lock *before* starting the thread so two near-simultaneous
    # clicks cannot both pass the "already running" check.
    if not _claim_run():
        return _status_snapshot(), _read_log_tail()
    thread = threading.Thread(target=_train, kwargs={"claimed": True}, daemon=True)
    thread.start()
    return _status_snapshot(), "training started"


def read_status() -> tuple[dict[str, Any], str]:
    """Return a status snapshot and the last 20000 characters of the log."""
    return _status_snapshot(), _read_log_tail()


def build_app() -> gr.Blocks:
    """Build the Gradio UI: run/refresh buttons, a status view and a log box."""
    with gr.Blocks(title="PolyGuard HF Training") as demo:
        gr.Markdown("# PolyGuard HF Training")
        run_button = gr.Button("Run training", variant="primary")
        refresh_button = gr.Button("Refresh")
        status_box = gr.JSON(label="Status", value=_status_snapshot())
        log_box = gr.Textbox(label="Logs", lines=26)
        run_button.click(fn=run_training, outputs=[status_box, log_box])
        refresh_button.click(fn=read_status, outputs=[status_box, log_box])
    return demo


# Training starts on import by default; set POLYGUARD_AUTORUN to a non-truthy
# value to require a click on "Run training" instead.
if _env_flag("POLYGUARD_AUTORUN", "1"):
    threading.Thread(target=_train, daemon=True).start()

app = build_app()

if __name__ == "__main__":
    app.launch(server_name="0.0.0.0", server_port=7860)