Spaces:
Running
Running
| """Gradio runner for the private Hugging Face training Space.""" | |
| from __future__ import annotations | |
| import json | |
| import os | |
| from pathlib import Path | |
| import shutil | |
| import subprocess | |
| import threading | |
| import time | |
| from typing import Any | |
| import gradio as gr | |
| from huggingface_hub import HfApi | |
# Repository root: the directory two levels above the one containing this file.
ROOT = Path(__file__).resolve().parents[2]

_OUTPUTS_DIR = ROOT / "outputs"
LOG_DIR = _OUTPUTS_DIR / "logs"
REPORT_DIR = _OUTPUTS_DIR / "reports"
STATUS_PATH = REPORT_DIR / "hf_training_status.json"
LOG_PATH = LOG_DIR / "hf_training.log"

# Held around every update of STATUS (and the write of STATUS_PATH) in this module,
# since the training thread and the UI callbacks both touch it.
LOCK = threading.Lock()

_DEFAULT_ARTIFACT_REPO_ID = "TheJackBright/polyguard-openenv-training-artifacts"

# In-memory run state shown in the UI and mirrored to STATUS_PATH.
STATUS: dict[str, Any] = {
    "status": "idle",
    "started_at": None,
    "finished_at": None,
    "commands": [],
    "artifact_repo_id": os.getenv("POLYGUARD_ARTIFACT_REPO_ID", _DEFAULT_ARTIFACT_REPO_ID),
}
def _write_status() -> None:
    """Persist the in-memory STATUS dict to STATUS_PATH as indented, ASCII-only JSON.

    The report directory is created on demand. Every caller in this module
    invokes this while holding LOCK.
    """
    REPORT_DIR.mkdir(parents=True, exist_ok=True)
    serialized = json.dumps(STATUS, ensure_ascii=True, indent=2)
    STATUS_PATH.write_text(serialized, encoding="utf-8")
def _append_log(message: str) -> None:
    """Append *message* to LOG_PATH, ending it with exactly one newline.

    Trailing whitespace on the message is stripped first; the log directory is
    created on demand.
    """
    LOG_DIR.mkdir(parents=True, exist_ok=True)
    entry = f"{message.rstrip()}\n"
    with open(LOG_PATH, "a", encoding="utf-8") as log_file:
        log_file.write(entry)
def _run_command(args: list[str], env: dict[str, str]) -> None:
    """Run one pipeline step from ROOT and record its outcome.

    The command's stdout and stderr are merged and streamed line by line into
    LOG_PATH while it runs, so the UI's "Refresh" shows progress instead of
    waiting for the whole step to finish (and large outputs are never held in
    memory). A record with the arguments, return code and elapsed wall-clock
    seconds is appended to STATUS["commands"] and persisted.

    Args:
        args: Command and arguments, executed without a shell.
        env: Environment for the child process.

    Raises:
        RuntimeError: if the command exits with a non-zero status.
    """
    started = time.monotonic()
    _append_log(f"$ {' '.join(args)}")  # also creates LOG_DIR if needed
    with LOG_PATH.open("a", encoding="utf-8") as handle, subprocess.Popen(
        args,
        cwd=ROOT,
        env=env,
        text=True,
        errors="replace",  # undecodable tool output must not abort the run
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
    ) as proc:
        last_line = ""
        for line in proc.stdout or ():
            handle.write(line)
            handle.flush()  # make the line visible to concurrent log readers
            last_line = line
        if last_line and not last_line.endswith("\n"):
            # Keep the next log entry on its own line.
            handle.write("\n")
        returncode = proc.wait()
    elapsed = round(time.monotonic() - started, 3)
    record = {
        "args": args,
        "returncode": returncode,
        "elapsed_seconds": elapsed,
    }
    with LOCK:
        STATUS["commands"].append(record)
        _write_status()
    if returncode != 0:
        raise RuntimeError(f"command_failed:{args}:{returncode}")
def _mirror_results() -> None:
    """Copy top-level .json/.txt/.png files from the reports and plots dirs into docs/results.

    Source directories that do not exist are skipped, subdirectories are not
    descended into, and same-named files already in docs/results are overwritten.
    """
    destination = ROOT / "docs" / "results"
    destination.mkdir(parents=True, exist_ok=True)
    copied_suffixes = {".json", ".txt", ".png"}
    for folder in (REPORT_DIR, ROOT / "outputs" / "plots"):
        if folder.exists():
            for entry in folder.iterdir():
                if entry.is_file() and entry.suffix.lower() in copied_suffixes:
                    shutil.copy2(entry, destination / entry.name)
def _upload_artifacts() -> None:
    """Upload result and checkpoint folders to a private Hugging Face model repo.

    The token comes from HF_TOKEN and the repo id from POLYGUARD_ARTIFACT_REPO_ID.
    Without a token, a note is logged and nothing is uploaded. The repo is
    created if missing, and each listed folder that exists locally is uploaded
    under the same relative path. Upload errors propagate to the caller.
    """
    hf_token = os.getenv("HF_TOKEN")
    target_repo = os.getenv("POLYGUARD_ARTIFACT_REPO_ID", "TheJackBright/polyguard-openenv-training-artifacts")
    if not hf_token:
        _append_log("HF_TOKEN missing; artifact upload skipped")
        return

    client = HfApi(token=hf_token)
    client.create_repo(repo_id=target_repo, repo_type="model", private=True, exist_ok=True)

    artifact_folders = (
        "outputs/reports",
        "outputs/plots",
        "docs/results",
        "checkpoints/sft_adapter",
        "checkpoints/grpo_adapter",
        "checkpoints/merged",
    )
    for relative in artifact_folders:
        local_folder = ROOT / relative
        if not local_folder.exists():
            continue
        client.upload_folder(
            repo_id=target_repo,
            repo_type="model",
            folder_path=str(local_folder),
            path_in_repo=relative,
            commit_message=f"Upload PolyGuard training artifacts: {relative}",
        )
def _improved() -> bool:
    """Report whether the latest comparison marked the candidate as improved.

    Reads ``REPORT_DIR / "improvement_report.json"`` and returns True only when
    it holds a JSON object whose ``"improved"`` value is exactly ``True``.
    A missing, unreadable, undecodable, malformed or non-object report counts as
    "not improved" instead of raising. This matters because the pipeline treats
    any exception as a failed run.
    """
    path = REPORT_DIR / "improvement_report.json"
    try:
        payload = json.loads(path.read_text(encoding="utf-8"))
    except (OSError, ValueError):  # ValueError covers JSONDecodeError and UnicodeDecodeError
        return False
    return isinstance(payload, dict) and payload.get("improved") is True
def _grpo_command(model_id: str, *, max_prompts: int, max_steps: int) -> list[str]:
    """Build the GRPO training command; only the prompt/step budgets differ between passes."""
    return [
        "python",
        "scripts/train_grpo_trl.py",
        "--model-id",
        model_id,
        "--max-prompts",
        str(max_prompts),
        "--max-steps",
        str(max_steps),
        "--epochs",
        "1",
        "--batch-size",
        "2",
        "--num-generations",
        "2",
        "--max-prompt-length",
        "384",
        "--max-completion-length",
        "64",
        "--use-unsloth",
    ]


def _evaluation_commands() -> list[list[str]]:
    """Build the evaluation steps; the last one is given --output improvement_report.json."""
    return [
        ["python", "scripts/evaluate_policy_ablations.py", "--episodes", "8"],
        ["python", "scripts/evaluate_baselines.py"],
        ["python", "scripts/evaluate_all.py"],
        [
            "python",
            "scripts/evaluate_compare_runs.py",
            "--baseline",
            "outputs/reports/baselines.json",
            "--candidate",
            "outputs/reports/benchmark_report.json",
            "--output",
            "outputs/reports/improvement_report.json",
        ],
    ]


def _train() -> dict[str, Any]:
    """Run the full training/evaluation pipeline and record the outcome in STATUS.

    Steps run in order through _run_command, and the first failing step aborts
    the run. If the comparison report does not mark the run as improved (and
    POLYGUARD_SKIP_GRPO_UPSCALE is not truthy), GRPO is rerun with 128 prompts
    and 40 steps, and the evaluation steps are repeated. Results are then
    mirrored into docs/results and uploaded.

    On any exception, mirroring and uploading are still attempted best-effort,
    and STATUS is always marked "failed" with the error message. A failing
    cleanup step can therefore never leave the status stuck at "running".

    Returns:
        The shared STATUS dict.
    """
    model_id = os.getenv("POLYGUARD_MODEL_ID", "Qwen/Qwen2.5-0.5B-Instruct")
    env = os.environ.copy()
    env.setdefault("POLYGUARD_OFFLINE_MODE", "false")
    # NOTE(review): presumably dropped so downloads do not require the hf_transfer package; confirm.
    env.pop("HF_HUB_ENABLE_HF_TRANSFER", None)
    env.setdefault("TOKENIZERS_PARALLELISM", "false")

    evaluation = _evaluation_commands()
    commands = [
        ["python", "scripts/bootstrap_data.py"],
        ["python", "scripts/build_training_corpus.py", "--profile", "massive", "--with-local", "--with-synthetic", "--with-hf"],
        [
            "python",
            "scripts/train_sft_trl.py",
            "--model-id",
            model_id,
            "--epochs",
            "1",
            "--max-steps",
            "20",
            "--batch-size",
            "2",
            "--max-seq-len",
            "512",
            "--use-unsloth",
        ],
        _grpo_command(model_id, max_prompts=0, max_steps=0),
        ["python", "scripts/merge_adapters_safe.py", "--adapter-dir", "checkpoints/sft_adapter", "--output-dir", "checkpoints/merged"],
        ["python", "scripts/test_inference_postsave.py", "--samples", "3", "--base-model", model_id],
        *evaluation,
    ]

    with LOCK:
        STATUS.update({"status": "running", "started_at": time.time(), "finished_at": None, "commands": []})
        # Outcome keys from a previous run must not leak into this one.
        STATUS.pop("error", None)
        STATUS.pop("improved", None)
        _write_status()
    LOG_PATH.unlink(missing_ok=True)

    try:
        for command in commands:
            _run_command(command, env)
        skip_upscale = os.getenv("POLYGUARD_SKIP_GRPO_UPSCALE", "false").lower() in {"1", "true", "yes", "on"}
        if not _improved() and not skip_upscale:
            _append_log("improvement=false; rerunning GRPO with 40 steps and 128 prompts")
            # Pass the budget the log line announces. The previous version reused the
            # first pass's 0/0, which made this rerun identical to the first one.
            # NOTE(review): confirm how train_grpo_trl.py interprets 0 for these flags.
            _run_command(_grpo_command(model_id, max_prompts=128, max_steps=40), env)
            for command in evaluation:
                _run_command(command, env)
        _mirror_results()
        _upload_artifacts()
        with LOCK:
            STATUS.update({"status": "ok", "finished_at": time.time(), "improved": _improved()})
            _write_status()
    except Exception as exc:  # noqa: BLE001
        _append_log(str(exc))
        # Best effort only: a cleanup failure must not prevent recording the failure below.
        for label, step in (("mirror_results", _mirror_results), ("upload_artifacts", _upload_artifacts)):
            try:
                step()
            except Exception as cleanup_exc:  # noqa: BLE001
                _append_log(f"{label} failed: {cleanup_exc}")
        with LOCK:
            STATUS.update({"status": "failed", "finished_at": time.time(), "error": str(exc)})
            _write_status()
    return STATUS
def run_training() -> tuple[dict[str, Any], str]:
    """Start the pipeline in a background thread unless a run is already active.

    Returns:
        The shared STATUS dict, plus either the tail of the current log (last
        20,000 characters, as in read_status) when a run is already in progress,
        or the message "training started".
    """
    with LOCK:
        if STATUS.get("status") == "running":
            log = LOG_PATH.read_text(encoding="utf-8") if LOG_PATH.exists() else ""
            return STATUS, log[-20000:]
        # Claim the run while still holding the lock. Otherwise two quick clicks (or a
        # click during autorun start-up) can both see a non-running status before the
        # worker thread marks it "running", and two pipelines start.
        STATUS["status"] = "running"
    threading.Thread(target=_train, daemon=True).start()
    return STATUS, "training started"
def read_status() -> tuple[dict[str, Any], str]:
    """Return the shared STATUS dict and the last 20,000 characters of the log.

    The log part is an empty string when the log file does not exist yet.
    """
    tail = ""
    if LOG_PATH.exists():
        tail = LOG_PATH.read_text(encoding="utf-8")[-20000:]
    return STATUS, tail
def build_app() -> gr.Blocks:
    """Build the Gradio UI: run and refresh buttons, a status JSON view and a log box.

    "Run training" is wired to run_training and "Refresh" to read_status. Both
    update the status view and the log box.
    """
    with gr.Blocks(title="PolyGuard HF Training") as ui:
        gr.Markdown("# PolyGuard HF Training")
        start_btn = gr.Button("Run training", variant="primary")
        refresh_btn = gr.Button("Refresh")
        status_view = gr.JSON(label="Status", value=STATUS)
        log_view = gr.Textbox(label="Logs", lines=26)

        start_btn.click(fn=run_training, outputs=[status_view, log_view])
        refresh_btn.click(fn=read_status, outputs=[status_view, log_view])
    return ui
# Start a training run as soon as the module is imported, unless POLYGUARD_AUTORUN
# is set to something other than 1/true/yes/on. Going through run_training() (instead
# of starting _train directly) applies the same "already running" guard as the UI button.
if os.getenv("POLYGUARD_AUTORUN", "1").lower() in {"1", "true", "yes", "on"}:
    run_training()

app = build_app()

if __name__ == "__main__":
    app.launch(server_name="0.0.0.0", server_port=7860)