# polyguard-openenv / app / hf_space / training_runner.py
# TheJackBright's picture
# Deploy PolyGuard OpenEnv Space
# 877add7 verified
"""Gradio runner for the private Hugging Face training Space."""
from __future__ import annotations
import json
import os
from pathlib import Path
import shutil
import subprocess
import threading
import time
from typing import Any
import gradio as gr
from huggingface_hub import HfApi
# Repository root: parents[2] is two levels above the directory holding this file.
ROOT = Path(__file__).resolve().parents[2]
# Output directories, both under ROOT / "outputs".
LOG_DIR = ROOT / "outputs" / "logs"
REPORT_DIR = ROOT / "outputs" / "reports"
# JSON snapshot of STATUS, rewritten by _write_status().
STATUS_PATH = REPORT_DIR / "hf_training_status.json"
# Plain-text log appended to by _append_log() and read back for the UI.
LOG_PATH = LOG_DIR / "hf_training.log"
# Taken around STATUS updates and around the "already running" check.
LOCK = threading.Lock()
# In-memory run state shown in the UI. "status" starts as "idle" and is set to
# "running", "ok" or "failed" by the training pipeline; "commands" collects one
# record per executed command.
STATUS: dict[str, Any] = {
    "status": "idle",
    "started_at": None,
    "finished_at": None,
    "commands": [],
    # Target repo for artifact uploads; overridable via the environment.
    "artifact_repo_id": os.getenv("POLYGUARD_ARTIFACT_REPO_ID", "TheJackBright/polyguard-openenv-training-artifacts"),
}
def _write_status() -> None:
    """Persist the in-memory ``STATUS`` dict to ``STATUS_PATH`` as indented JSON.

    The report directory is created first if it does not exist yet.
    """
    REPORT_DIR.mkdir(parents=True, exist_ok=True)
    serialized = json.dumps(STATUS, ensure_ascii=True, indent=2)
    STATUS_PATH.write_text(serialized, encoding="utf-8")
def _append_log(message: str) -> None:
    """Append ``message`` to the training log as a single newline-terminated entry.

    Trailing whitespace of ``message`` is stripped before writing, and the log
    directory is created on demand.
    """
    LOG_DIR.mkdir(parents=True, exist_ok=True)
    with LOG_PATH.open("a", encoding="utf-8") as log_file:
        log_file.write(f"{message.rstrip()}\n")
def _run_command(args: list[str], env: dict[str, str]) -> None:
    """Run one pipeline command from ``ROOT`` and record its outcome in ``STATUS``.

    The command line is logged first, then the child's combined stdout/stderr
    is appended to the training log line by line *while the command runs*, so
    the log reflects progress of long-running steps instead of only being
    filled in once the process has exited (and without holding the complete
    output in memory).

    Args:
        args: Program and arguments, passed to the OS without a shell.
        env: Complete environment for the child process.

    Raises:
        RuntimeError: If the command exits with a non-zero return code. The
            command record is appended to ``STATUS["commands"]`` before raising.
    """
    started = time.time()
    _append_log(f"$ {' '.join(args)}")
    with subprocess.Popen(
        args,
        cwd=ROOT,
        env=env,
        text=True,
        # Undecodable bytes in tool output must not abort the whole pipeline.
        errors="replace",
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
    ) as proc:
        # stdout is always a pipe here because stdout=PIPE was requested.
        for line in proc.stdout:
            _append_log(line)
    # Leaving the context manager waits for the child, so returncode is set.
    returncode = proc.returncode
    elapsed = round(time.time() - started, 3)
    record = {
        "args": args,
        "returncode": returncode,
        "elapsed_seconds": elapsed,
    }
    with LOCK:
        STATUS["commands"].append(record)
        _write_status()
    if returncode != 0:
        raise RuntimeError(f"command_failed:{args}:{returncode}")
def _mirror_results() -> None:
    """Copy report and plot files into ``docs/results`` under ``ROOT``.

    Only regular files directly inside the reports and plots directories whose
    suffix is ``.json``, ``.txt`` or ``.png`` (case-insensitive) are copied;
    a source directory that does not exist is skipped.
    """
    target_dir = ROOT / "docs" / "results"
    target_dir.mkdir(parents=True, exist_ok=True)
    wanted_suffixes = {".json", ".txt", ".png"}
    for folder in (REPORT_DIR, ROOT / "outputs" / "plots"):
        if folder.exists():
            for entry in folder.iterdir():
                if not entry.is_file():
                    continue
                if entry.suffix.lower() not in wanted_suffixes:
                    continue
                shutil.copy2(entry, target_dir / entry.name)
def _upload_artifacts() -> None:
    """Upload result and checkpoint folders to the private artifact repo on the Hub.

    Does nothing except log a message when ``HF_TOKEN`` is not set. Otherwise
    the repo named by ``POLYGUARD_ARTIFACT_REPO_ID`` (or the built-in default)
    is created if needed and every listed folder that exists under ``ROOT`` is
    uploaded to the same relative path inside the repo.
    """
    hf_token = os.getenv("HF_TOKEN")
    target_repo = os.getenv("POLYGUARD_ARTIFACT_REPO_ID", "TheJackBright/polyguard-openenv-training-artifacts")
    if not hf_token:
        _append_log("HF_TOKEN missing; artifact upload skipped")
        return
    client = HfApi(token=hf_token)
    client.create_repo(repo_id=target_repo, repo_type="model", private=True, exist_ok=True)
    relative_folders = (
        "outputs/reports",
        "outputs/plots",
        "docs/results",
        "checkpoints/sft_adapter",
        "checkpoints/grpo_adapter",
        "checkpoints/merged",
    )
    for rel in relative_folders:
        folder = ROOT / rel
        if not folder.exists():
            # Missing folders are skipped silently.
            continue
        client.upload_folder(
            repo_id=target_repo,
            repo_type="model",
            folder_path=str(folder),
            path_in_repo=rel,
            commit_message=f"Upload PolyGuard training artifacts: {rel}",
        )
def _improved() -> bool:
    """Return True only if the improvement report explicitly says ``"improved": true``.

    Reads ``improvement_report.json`` from ``REPORT_DIR``. Any problem with the
    report -- missing or unreadable file, invalid UTF-8, invalid JSON, or a
    top-level JSON value that is not an object -- is treated as "not improved"
    rather than raised, so a broken report cannot crash the caller.
    """
    path = REPORT_DIR / "improvement_report.json"
    try:
        payload = json.loads(path.read_text(encoding="utf-8"))
    except (OSError, ValueError):
        # OSError: file missing/unreadable. ValueError covers both
        # json.JSONDecodeError and UnicodeDecodeError.
        return False
    # A list/str/number at the top level has no .get(); treat it as not improved.
    return isinstance(payload, dict) and payload.get("improved") is True
def _train() -> dict[str, Any]:
    """Run the full training/evaluation pipeline and publish its artifacts.

    Steps, in order: the training commands (data bootstrap, corpus build, SFT,
    GRPO, adapter merge, post-save inference test), then the evaluation
    commands. If the resulting improvement report does not say "improved" and
    ``POLYGUARD_SKIP_GRPO_UPSCALE`` is not truthy, GRPO and the evaluation
    commands are run once more. Finally results are mirrored into
    ``docs/results`` and uploaded.

    ``STATUS`` is updated (and written to disk) when the run starts and when it
    ends with ``"ok"`` or ``"failed"``. Exceptions are not propagated: a
    failure is recorded in ``STATUS["error"]`` and in the log.

    Returns:
        The shared ``STATUS`` dict.
    """
    model_id = os.getenv("POLYGUARD_MODEL_ID", "Qwen/Qwen2.5-0.5B-Instruct")
    env = os.environ.copy()
    env.setdefault("POLYGUARD_OFFLINE_MODE", "false")
    env.pop("HF_HUB_ENABLE_HF_TRANSFER", None)
    env.setdefault("TOKENIZERS_PARALLELISM", "false")

    # Defined once: the same argument list is used for the first GRPO pass and
    # for the optional rerun.
    grpo_command = [
        "python",
        "scripts/train_grpo_trl.py",
        "--model-id",
        model_id,
        "--max-prompts",
        "0",
        "--max-steps",
        "0",
        "--epochs",
        "1",
        "--batch-size",
        "2",
        "--num-generations",
        "2",
        "--max-prompt-length",
        "384",
        "--max-completion-length",
        "64",
        "--use-unsloth",
    ]
    train_commands = [
        ["python", "scripts/bootstrap_data.py"],
        ["python", "scripts/build_training_corpus.py", "--profile", "massive", "--with-local", "--with-synthetic", "--with-hf"],
        [
            "python",
            "scripts/train_sft_trl.py",
            "--model-id",
            model_id,
            "--epochs",
            "1",
            "--max-steps",
            "20",
            "--batch-size",
            "2",
            "--max-seq-len",
            "512",
            "--use-unsloth",
        ],
        grpo_command,
        ["python", "scripts/merge_adapters_safe.py", "--adapter-dir", "checkpoints/sft_adapter", "--output-dir", "checkpoints/merged"],
        ["python", "scripts/test_inference_postsave.py", "--samples", "3", "--base-model", model_id],
    ]
    # Kept separate so the rerun below can repeat exactly the evaluation stage
    # without relying on a positional slice of one combined list.
    eval_commands = [
        ["python", "scripts/evaluate_policy_ablations.py", "--episodes", "8"],
        ["python", "scripts/evaluate_baselines.py"],
        ["python", "scripts/evaluate_all.py"],
        [
            "python",
            "scripts/evaluate_compare_runs.py",
            "--baseline",
            "outputs/reports/baselines.json",
            "--candidate",
            "outputs/reports/benchmark_report.json",
            "--output",
            "outputs/reports/improvement_report.json",
        ],
    ]

    with LOCK:
        # Drop result keys of a previous run so a later success does not keep
        # showing an old error (or an old "improved" flag while running).
        STATUS.pop("error", None)
        STATUS.pop("improved", None)
        STATUS.update({"status": "running", "started_at": time.time(), "finished_at": None, "commands": []})
        _write_status()
    LOG_PATH.unlink(missing_ok=True)
    try:
        for command in train_commands + eval_commands:
            _run_command(command, env)
        skip_rerun = os.getenv("POLYGUARD_SKIP_GRPO_UPSCALE", "false").lower() in {"1", "true", "yes", "on"}
        if not _improved() and not skip_rerun:
            # NOTE(review): this message used to announce "40 steps and 128
            # prompts" although the arguments were identical to the first GRPO
            # pass. The message now states what is actually executed; confirm
            # whether a larger budget was intended for the rerun.
            _append_log("improvement=false; rerunning GRPO with the same settings (--max-steps 0, --max-prompts 0)")
            _run_command(grpo_command, env)
            for command in eval_commands:
                _run_command(command, env)
        _mirror_results()
        _upload_artifacts()
        with LOCK:
            STATUS.update({"status": "ok", "finished_at": time.time(), "improved": _improved()})
            _write_status()
    except Exception as exc:  # noqa: BLE001
        _append_log(str(exc))
        # Record the failure BEFORE attempting to publish anything: if the
        # mirror/upload below raises as well, the status must not stay
        # "running", which would block every later run.
        with LOCK:
            STATUS.update({"status": "failed", "finished_at": time.time(), "error": str(exc)})
            _write_status()
        try:
            # Best effort: publish whatever partial results exist.
            _mirror_results()
            _upload_artifacts()
        except Exception as publish_exc:  # noqa: BLE001
            _append_log(f"artifact publishing after failure also failed: {publish_exc}")
    return STATUS
def run_training() -> tuple[dict[str, Any], str]:
    """Start the training pipeline in a background thread unless one is running.

    The "running" state is claimed under ``LOCK`` before the worker thread is
    started, so two calls arriving close together cannot both start a
    pipeline.

    Returns:
        ``(STATUS, text)``. ``text`` is the last 20000 characters of the log
        when a run is already in progress, otherwise ``"training started"``.

    Raises:
        RuntimeError: If the worker thread cannot be started; the previous
            status is restored first.
    """
    with LOCK:
        if STATUS.get("status") == "running":
            log = LOG_PATH.read_text(encoding="utf-8") if LOG_PATH.exists() else ""
            # Same tail limit as read_status(), to keep the textbox payload bounded.
            return STATUS, log[-20000:]
        previous_status = STATUS.get("status")
        # Claim the slot now; the worker sets the remaining fields itself.
        STATUS["status"] = "running"
    thread = threading.Thread(target=_train, daemon=True)
    try:
        thread.start()
    except RuntimeError:
        # No worker exists, so do not leave the status stuck at "running".
        with LOCK:
            STATUS["status"] = previous_status
        raise
    return STATUS, "training started"
def read_status() -> tuple[dict[str, Any], str]:
    """Return the shared ``STATUS`` dict and the last 20000 characters of the log.

    An empty string is returned for the log when the log file does not exist.
    """
    if LOG_PATH.exists():
        contents = LOG_PATH.read_text(encoding="utf-8")
    else:
        contents = ""
    tail = contents[-20000:]
    return STATUS, tail
def build_app() -> gr.Blocks:
    """Build the Gradio UI: a run button, a refresh button, a status view and a log view.

    Both buttons write their handler's ``(status, log)`` result into the same
    two output components.
    """
    with gr.Blocks(title="PolyGuard HF Training") as demo:
        gr.Markdown("# PolyGuard HF Training")
        start_btn = gr.Button("Run training", variant="primary")
        reload_btn = gr.Button("Refresh")
        status_view = gr.JSON(label="Status", value=STATUS)
        log_view = gr.Textbox(label="Logs", lines=26)
        # Wire each button to its handler; the output targets are shared.
        for button, handler in ((start_btn, run_training), (reload_btn, read_status)):
            button.click(fn=handler, outputs=[status_view, log_view])
    return demo
# Start the pipeline in the background at import time unless POLYGUARD_AUTORUN
# is set to something other than 1/true/yes/on (it defaults to "1").
_autorun_flag = os.getenv("POLYGUARD_AUTORUN", "1").lower()
if _autorun_flag in {"1", "true", "yes", "on"}:
    _autorun_thread = threading.Thread(target=_train, daemon=True)
    _autorun_thread.start()

# Module-level Gradio app object.
app = build_app()

if __name__ == "__main__":
    # Listen on all interfaces on port 7860.
    app.launch(server_name="0.0.0.0", server_port=7860)