"""Gradio runner for the private Hugging Face training Space."""
from __future__ import annotations

import json
import os
import shutil
import subprocess
import threading
import time
from pathlib import Path
from typing import Any

import gradio as gr
from huggingface_hub import HfApi, snapshot_download
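# Layout assumption: this file sits two directory levels below the repository
# root, so parents[2] resolves to the project checkout.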
ROOT = Path(__file__).resolve().parents[2]
LOG_DIR = ROOT / "outputs" / "logs"
REPORT_DIR = ROOT / "outputs" / "reports"
STATUS_PATH = REPORT_DIR / "hf_training_status.json"
LOG_PATH = LOG_DIR / "hf_training.log"
DEFAULT_ARTIFACT_REPO_ID = "TheJackBright/polyguard-openenv-training-full-artifacts"
DEFAULT_MODEL_SWEEP = (
    "Qwen/Qwen2.5-0.5B-Instruct,Qwen/Qwen2.5-1.5B-Instruct,Qwen/Qwen2.5-3B-Instruct"
)
LOCK = threading.Lock()
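# In-memory run state. Mutations happen under LOCK and are persisted to
# STATUS_PATH by _write_status().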
STATUS: dict[str, Any] = {
    "status": "idle",
    "started_at": None,
    "finished_at": None,
    "commands": [],
    "artifact_repo_id": os.getenv("POLYGUARD_ARTIFACT_REPO_ID", DEFAULT_ARTIFACT_REPO_ID),
    "training_mode": os.getenv("POLYGUARD_TRAINING_MODE", "full"),
    "model_sweep": os.getenv("POLYGUARD_MODEL_SWEEP", DEFAULT_MODEL_SWEEP),
}
def _env_bool(name: str, default: bool = False) -> bool:
value = os.getenv(name)
if value is None:
return default
return value.lower() in {"1", "true", "yes", "on"}
def _write_status() -> None:
REPORT_DIR.mkdir(parents=True, exist_ok=True)
STATUS_PATH.write_text(json.dumps(STATUS, ensure_ascii=True, indent=2), encoding="utf-8")
def _append_log(message: str) -> None:
LOG_DIR.mkdir(parents=True, exist_ok=True)
with LOG_PATH.open("a", encoding="utf-8") as handle:
handle.write(message.rstrip() + "\n")
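# Incremental uploads are strictly best-effort: any Hub error is logged and
# swallowed so a flaky connection cannot abort a long training run.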
def _upload_relpath(rel: str, *, commit_suffix: str = "") -> None:
if not _env_bool("POLYGUARD_INCREMENTAL_UPLOAD", True):
return
token = os.getenv("HF_TOKEN")
    repo_id = os.getenv("POLYGUARD_ARTIFACT_REPO_ID", DEFAULT_ARTIFACT_REPO_ID)
if not token:
return
path = ROOT / rel
if not path.exists():
return
try:
api = HfApi(token=token)
api.create_repo(repo_id=repo_id, repo_type="model", private=True, exist_ok=True)
if path.is_file():
api.upload_file(
repo_id=repo_id,
repo_type="model",
path_or_fileobj=str(path),
path_in_repo=rel,
commit_message=f"Upload PolyGuard artifact: {commit_suffix or rel}",
)
else:
api.upload_folder(
repo_id=repo_id,
repo_type="model",
folder_path=str(path),
path_in_repo=rel,
commit_message=f"Upload PolyGuard artifact folder: {commit_suffix or rel}",
ignore_patterns=[".DS_Store", "**/.DS_Store"],
)
except Exception as exc: # noqa: BLE001
_append_log(f"incremental_upload_skipped:{rel}:{exc}")
def _upload_status_and_log(context: str) -> None:
_upload_relpath("outputs/reports/hf_training_status.json", commit_suffix=f"status {context}")
_upload_relpath("outputs/logs/hf_training.log", commit_suffix=f"log {context}")
def _upload_run_snapshot(run_id: str, stage: str) -> None:
if not _env_bool("POLYGUARD_UPLOAD_AFTER_EACH_STAGE", True):
return
_upload_status_and_log(f"{run_id} {stage}")
_upload_relpath(f"outputs/reports/sweeps/{run_id}", commit_suffix=f"{run_id} reports after {stage}")
_upload_relpath(f"checkpoints/sweeps/{run_id}", commit_suffix=f"{run_id} checkpoints after {stage}")
def _run_command(args: list[str], env: dict[str, str]) -> None:
started = time.time()
last_incremental_upload = started
_append_log(f"$ {' '.join(args)}")
proc = subprocess.Popen(
args,
cwd=ROOT,
env=env,
text=True,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
)
assert proc.stdout is not None
saw_output = False
for line in proc.stdout:
saw_output = True
_append_log(line)
now = time.time()
if now - last_incremental_upload >= _env_int("POLYGUARD_LOG_UPLOAD_INTERVAL_SECONDS", 180):
_upload_status_and_log("running")
last_incremental_upload = now
proc.wait()
elapsed = round(time.time() - started, 3)
record = {
"args": args,
"returncode": proc.returncode,
"elapsed_seconds": elapsed,
}
with LOCK:
STATUS["commands"].append(record)
_write_status()
_upload_status_and_log("command_complete")
if proc.returncode != 0:
if not saw_output:
_append_log("<no command output>")
_upload_status_and_log("command_failed")
raise RuntimeError(f"command_failed:{args}:{proc.returncode}")
def _env_int(name: str, default: int) -> int:
try:
return int(os.getenv(name, str(default)))
except ValueError:
return default
def _env_float(name: str, default: float) -> float:
try:
return float(os.getenv(name, str(default)))
except ValueError:
return default
def _csv_env(name: str, default: str) -> list[str]:
value = os.getenv(name, default)
return [item.strip() for item in value.split(",") if item.strip()]
def _indexed_int_env(name: str, index: int, default: int) -> int:
values = _csv_env(name, "")
if index >= len(values):
return default
try:
return int(values[index])
except ValueError:
return default
def _indexed_float_env(name: str, index: int, default: float) -> float:
values = _csv_env(name, "")
if index >= len(values):
return default
try:
return float(values[index])
except ValueError:
return default
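# Turns a model id like "Qwen/Qwen2.5-0.5B-Instruct" into a filesystem-safe,
# lower-case run id (e.g. "qwen-qwen2-5-0-5b-instruct").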
def _safe_name(value: str) -> str:
return "".join(ch if ch.isalnum() else "-" for ch in value).strip("-").lower()
def _copy_file_if_exists(source: Path, target: Path) -> None:
if source.exists():
target.parent.mkdir(parents=True, exist_ok=True)
shutil.copy2(source, target)
def _copy_dir_if_exists(source: Path, target: Path) -> None:
if source.exists():
target.parent.mkdir(parents=True, exist_ok=True)
shutil.copytree(source, target, dirs_exist_ok=True)
def _record_reused_artifact(name: str, path: Path) -> None:
with LOCK:
STATUS["commands"].append(
{
"args": ["reuse_artifact", name, str(path)],
"returncode": 0,
"elapsed_seconds": 0.0,
}
)
_write_status()
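# Optionally restores a previously uploaded GRPO adapter and run report from the
# artifact repo, letting a restarted Space reuse finished GRPO work.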
def _restore_remote_artifacts() -> None:
if os.getenv("POLYGUARD_REUSE_REMOTE_GRPO", "false").lower() not in {"1", "true", "yes", "on"}:
return
token = os.getenv("HF_TOKEN")
    repo_id = os.getenv("POLYGUARD_ARTIFACT_REPO_ID", DEFAULT_ARTIFACT_REPO_ID)
if not token:
return
try:
snapshot = Path(
snapshot_download(
repo_id=repo_id,
repo_type="model",
token=token,
allow_patterns=[
"checkpoints/grpo_adapter/*",
"outputs/reports/grpo_trl_run.json",
],
)
)
except Exception as exc: # noqa: BLE001
_append_log(f"remote_artifact_restore_skipped:{exc}")
return
for rel in ["checkpoints/grpo_adapter", "outputs/reports/grpo_trl_run.json"]:
source = snapshot / rel
target = ROOT / rel
if source.is_dir():
shutil.copytree(source, target, dirs_exist_ok=True)
elif source.is_file():
target.parent.mkdir(parents=True, exist_ok=True)
shutil.copy2(source, target)
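# The top-level GRPO artifact counts as reusable only when the adapter files
# exist on disk and the run report records a successful ("ok") run.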
def _grpo_artifact_ready() -> bool:
report = REPORT_DIR / "grpo_trl_run.json"
adapter = ROOT / "checkpoints" / "grpo_adapter"
if not report.exists() or not adapter.exists():
return False
if not (adapter / "adapter_config.json").exists() or not (adapter / "adapter_model.safetensors").exists():
return False
try:
payload = json.loads(report.read_text(encoding="utf-8"))
except json.JSONDecodeError:
return False
return payload.get("status") == "ok" and bool(payload.get("artifact_path"))
def _mirror_results() -> None:
docs_results = ROOT / "docs" / "results"
docs_results.mkdir(parents=True, exist_ok=True)
for source_dir in [REPORT_DIR, ROOT / "outputs" / "plots"]:
if not source_dir.exists():
continue
for path in source_dir.rglob("*"):
if path.is_file() and path.suffix.lower() in {".json", ".txt", ".png"}:
target = docs_results / path.relative_to(source_dir)
target.parent.mkdir(parents=True, exist_ok=True)
shutil.copy2(path, target)
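# Final bulk upload: pushes every report, plot, and checkpoint directory that
# exists to the private artifact repo.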
def _upload_artifacts() -> None:
token = os.getenv("HF_TOKEN")
    repo_id = os.getenv("POLYGUARD_ARTIFACT_REPO_ID", DEFAULT_ARTIFACT_REPO_ID)
if not token:
_append_log("HF_TOKEN missing; artifact upload skipped")
return
api = HfApi(token=token)
api.create_repo(repo_id=repo_id, repo_type="model", private=True, exist_ok=True)
for rel in [
"outputs/reports",
"outputs/plots",
"docs/results",
"checkpoints/sft_adapter",
"checkpoints/grpo_adapter",
"checkpoints/merged",
"checkpoints/sweeps",
]:
path = ROOT / rel
if path.exists():
api.upload_folder(
repo_id=repo_id,
repo_type="model",
folder_path=str(path),
path_in_repo=rel,
commit_message=f"Upload PolyGuard training artifacts: {rel}",
)
def _improved() -> bool:
path = REPORT_DIR / "improvement_report.json"
if not path.exists():
return False
try:
payload = json.loads(path.read_text(encoding="utf-8"))
except json.JSONDecodeError:
return False
return payload.get("improved") is True
def _promote_run_artifacts(run_id: str) -> None:
checkpoint_dir = ROOT / "checkpoints" / "sweeps" / run_id
report_dir = REPORT_DIR / "sweeps" / run_id
_copy_dir_if_exists(checkpoint_dir / "sft_adapter", ROOT / "checkpoints" / "sft_adapter")
_copy_dir_if_exists(checkpoint_dir / "grpo_adapter", ROOT / "checkpoints" / "grpo_adapter")
_copy_dir_if_exists(checkpoint_dir / "merged", ROOT / "checkpoints" / "merged")
_copy_file_if_exists(report_dir / "sft_trl_run.json", REPORT_DIR / "sft_trl_run.json")
_copy_file_if_exists(report_dir / "grpo_trl_run.json", REPORT_DIR / "grpo_trl_run.json")
_copy_file_if_exists(report_dir / "postsave_inference_grpo.json", REPORT_DIR / "postsave_inference.json")
_copy_file_if_exists(report_dir / "grpo_ablation_report.json", REPORT_DIR / "grpo_ablation_report.json")
def _promote_sft_run_artifacts(run_id: str) -> None:
checkpoint_dir = ROOT / "checkpoints" / "sweeps" / run_id
report_dir = REPORT_DIR / "sweeps" / run_id
_copy_dir_if_exists(checkpoint_dir / "sft_adapter", ROOT / "checkpoints" / "sft_adapter")
_copy_dir_if_exists(checkpoint_dir / "merged", ROOT / "checkpoints" / "merged")
_copy_file_if_exists(report_dir / "sft_trl_run.json", REPORT_DIR / "sft_trl_run.json")
_copy_file_if_exists(report_dir / "postsave_inference_sft.json", REPORT_DIR / "postsave_inference.json")
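# Runs one sweep entry end to end: SFT, optional GRPO, SFT adapter merge,
# post-save inference checks, and (for GRPO runs) policy ablations, uploading a
# snapshot after every stage.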
def _run_model_experiment(
model_id: str,
env: dict[str, str],
*,
model_index: int,
run_grpo: bool,
) -> str:
run_id = _safe_name(model_id)
checkpoint_dir = ROOT / "checkpoints" / "sweeps" / run_id
report_dir = REPORT_DIR / "sweeps" / run_id
checkpoint_dir.mkdir(parents=True, exist_ok=True)
report_dir.mkdir(parents=True, exist_ok=True)
sft_epochs = _indexed_int_env("POLYGUARD_SFT_EPOCH_SWEEP", model_index, _env_int("POLYGUARD_SFT_EPOCHS", 2))
sft_max_steps = _indexed_int_env(
"POLYGUARD_SFT_MAX_STEP_SWEEP",
model_index,
_env_int("POLYGUARD_SFT_MAX_STEPS", 0),
)
sft_batch_size = _indexed_int_env(
"POLYGUARD_SFT_BATCH_SIZE_SWEEP",
model_index,
_env_int("POLYGUARD_SFT_BATCH_SIZE", 2),
)
sft_learning_rate = _indexed_float_env(
"POLYGUARD_SFT_LEARNING_RATE_SWEEP",
model_index,
_env_float("POLYGUARD_SFT_LEARNING_RATE", 2e-5),
)
grpo_epochs = _env_float("POLYGUARD_GRPO_EPOCHS", 1.0)
grpo_max_steps = _env_int("POLYGUARD_GRPO_MAX_STEPS", 0)
grpo_max_prompts = _env_int("POLYGUARD_GRPO_MAX_PROMPTS", 0)
_append_log(f"model_experiment_start:{model_id}")
(report_dir / "run_metadata.json").write_text(
json.dumps(
{
"training_mode": "full" if run_grpo else "sft-baseline",
"model_id": model_id,
"model_index": model_index,
"sft_epochs": sft_epochs,
"sft_max_steps": sft_max_steps,
"sft_batch_size": sft_batch_size,
"sft_learning_rate": sft_learning_rate,
},
ensure_ascii=True,
indent=2,
),
encoding="utf-8",
)
_run_command(
[
"python",
"scripts/train_sft_trl.py",
"--model-id",
model_id,
"--dataset-path",
"data/processed/training_corpus_sft.json",
"--output-dir",
f"checkpoints/sweeps/{run_id}",
"--report-path",
f"outputs/reports/sweeps/{run_id}/sft_trl_run.json",
"--epochs",
str(sft_epochs),
"--max-steps",
str(sft_max_steps),
"--batch-size",
str(sft_batch_size),
"--max-seq-len",
str(_env_int("POLYGUARD_SFT_MAX_SEQ_LEN", 512)),
"--learning-rate",
str(sft_learning_rate),
"--use-unsloth",
],
env,
)
_copy_file_if_exists(checkpoint_dir / "sft_history.json", report_dir / "sft_history.json")
_upload_run_snapshot(run_id, "sft_training")
if run_grpo:
_run_command(
[
"python",
"scripts/train_grpo_trl.py",
"--model-id",
model_id,
"--prompts-path",
"data/processed/training_corpus_grpo_prompts.jsonl",
"--output-dir",
f"checkpoints/sweeps/{run_id}",
"--report-path",
f"outputs/reports/sweeps/{run_id}/grpo_trl_run.json",
"--max-prompts",
str(grpo_max_prompts),
"--max-steps",
str(grpo_max_steps),
"--epochs",
str(grpo_epochs),
"--batch-size",
str(_env_int("POLYGUARD_GRPO_BATCH_SIZE", 2)),
"--grad-accum",
str(_env_int("POLYGUARD_GRPO_GRAD_ACCUM", 1)),
"--num-generations",
str(_env_int("POLYGUARD_GRPO_NUM_GENERATIONS", 2)),
"--max-prompt-length",
str(_env_int("POLYGUARD_GRPO_MAX_PROMPT_LENGTH", 384)),
"--max-completion-length",
str(_env_int("POLYGUARD_GRPO_MAX_COMPLETION_LENGTH", 64)),
"--learning-rate",
str(_env_float("POLYGUARD_GRPO_LEARNING_RATE", 1e-6)),
"--use-unsloth",
],
env,
)
_copy_file_if_exists(checkpoint_dir / "grpo_history.json", report_dir / "grpo_history.json")
_copy_file_if_exists(checkpoint_dir / "grpo_reward_components.jsonl", report_dir / "grpo_reward_components.jsonl")
_upload_run_snapshot(run_id, "grpo_training")
_run_command(
[
"python",
"scripts/merge_adapters_safe.py",
"--adapter-dir",
f"checkpoints/sweeps/{run_id}/sft_adapter",
"--output-dir",
f"checkpoints/sweeps/{run_id}/merged",
],
env,
)
_upload_run_snapshot(run_id, "sft_merge")
_run_command(
[
"python",
"scripts/test_inference_postsave.py",
"--samples",
str(_env_int("POLYGUARD_INFERENCE_SAMPLES", 5)),
"--base-model",
model_id,
"--merged-model",
f"checkpoints/sweeps/{run_id}/merged",
"--adapter-dir",
f"checkpoints/sweeps/{run_id}/sft_adapter",
"--output",
f"outputs/reports/sweeps/{run_id}/postsave_inference_sft.json",
],
env,
)
_upload_run_snapshot(run_id, "sft_postsave_inference")
if run_grpo:
_run_command(
[
"python",
"scripts/test_inference_postsave.py",
"--samples",
str(_env_int("POLYGUARD_INFERENCE_SAMPLES", 5)),
"--base-model",
model_id,
"--merged-model",
f"checkpoints/sweeps/{run_id}/missing_merged_grpo",
"--adapter-dir",
f"checkpoints/sweeps/{run_id}/grpo_adapter",
"--output",
f"outputs/reports/sweeps/{run_id}/postsave_inference_grpo.json",
],
env,
)
_upload_run_snapshot(run_id, "grpo_postsave_inference")
_run_command(
[
"python",
"scripts/evaluate_policy_ablations.py",
"--episodes",
str(_env_int("POLYGUARD_ABLATION_EPISODES", 8)),
"--checkpoint-dir",
f"checkpoints/sweeps/{run_id}",
"--output",
f"outputs/reports/sweeps/{run_id}/grpo_ablation_report.json",
],
env,
)
_promote_run_artifacts(run_id)
_upload_run_snapshot(run_id, "policy_ablation")
for rel in [
"checkpoints/sft_adapter",
"checkpoints/grpo_adapter",
"checkpoints/merged",
"outputs/reports/sft_trl_run.json",
"outputs/reports/grpo_trl_run.json",
"outputs/reports/postsave_inference.json",
"outputs/reports/grpo_ablation_report.json",
]:
_upload_relpath(rel, commit_suffix=f"promoted {run_id}")
else:
_promote_sft_run_artifacts(run_id)
_upload_run_snapshot(run_id, "sft_promoted")
for rel in [
"checkpoints/sft_adapter",
"checkpoints/merged",
"outputs/reports/sft_trl_run.json",
"outputs/reports/postsave_inference.json",
]:
_upload_relpath(rel, commit_suffix=f"promoted {run_id}")
_append_log(f"model_experiment_done:{model_id}")
_upload_run_snapshot(run_id, "complete")
return run_id
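# Full pipeline: bootstrap data, build the training corpus, run the model sweep,
# then evaluate against baselines, mirror results into docs/, and upload
# everything. A failed run still mirrors and uploads whatever was produced.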
def _train() -> dict[str, Any]:
training_mode = os.getenv("POLYGUARD_TRAINING_MODE", "full").strip().lower()
run_grpo = training_mode not in {"sft", "sft-only", "sft-baseline", "sft_baseline"}
    model_sweep = _csv_env("POLYGUARD_MODEL_SWEEP", DEFAULT_MODEL_SWEEP)
env = os.environ.copy()
env.setdefault("POLYGUARD_OFFLINE_MODE", "false")
env.pop("HF_HUB_ENABLE_HF_TRANSFER", None)
env.setdefault("TOKENIZERS_PARALLELISM", "false")
setup_commands = [
["python", "scripts/bootstrap_data.py"],
["python", "scripts/build_training_corpus.py", "--profile", "massive", "--with-local", "--with-synthetic", "--with-hf"],
]
with LOCK:
STATUS.update(
{
"status": "running",
"started_at": time.time(),
"finished_at": None,
"commands": [],
"model_sweep": model_sweep,
"training_mode": "full" if run_grpo else "sft-baseline",
}
)
_write_status()
LOG_PATH.unlink(missing_ok=True)
_restore_remote_artifacts()
try:
for command in setup_commands:
_run_command(command, env)
completed_run_ids: list[str] = []
for model_index, model_id in enumerate(model_sweep):
run_id = _safe_name(model_id)
try:
completed_run_ids.append(
_run_model_experiment(
model_id=model_id,
env=env,
model_index=model_index,
run_grpo=run_grpo,
)
)
except Exception as exc: # noqa: BLE001
error_dir = REPORT_DIR / "sweeps" / run_id
error_dir.mkdir(parents=True, exist_ok=True)
(error_dir / "error.json").write_text(
json.dumps(
{"status": "failed", "model_id": model_id, "error": str(exc)},
ensure_ascii=True,
indent=2,
),
encoding="utf-8",
)
_append_log(f"model_experiment_failed:{model_id}:{exc}")
_upload_run_snapshot(run_id, "failed")
if not completed_run_ids:
raise RuntimeError("all_model_experiments_failed")
if run_grpo and _grpo_artifact_ready():
_append_log("top_level_grpo_adapter_ready")
_record_reused_artifact("grpo_adapter", ROOT / "checkpoints" / "grpo_adapter")
eval_commands = [
["python", "scripts/evaluate_baselines.py"],
["python", "scripts/evaluate_all.py"],
[
"python",
"scripts/evaluate_compare_runs.py",
"--baseline",
"outputs/reports/baselines.json",
"--candidate",
"outputs/reports/benchmark_report.json",
"--output",
"outputs/reports/improvement_report.json",
],
["python", "scripts/benchmark_inference.py"],
]
if run_grpo:
eval_commands.append(["python", "scripts/run_robustness_suite.py"])
eval_commands.append(["python", "scripts/generate_hf_training_report.py", "--mode", "full" if run_grpo else "sft-baseline"])
for command in eval_commands:
_run_command(command, env)
        anti_hacking: dict[str, Any] = {}
        anti_path = REPORT_DIR / "anti_hacking_overfit_report.json"
        if anti_path.exists():
            try:
                anti_hacking = json.loads(anti_path.read_text(encoding="utf-8"))
            except json.JSONDecodeError:
                _append_log("anti_hacking_report_unreadable")
with LOCK:
STATUS.update(
{
"status": "ok",
"finished_at": time.time(),
"improved": _improved(),
"anti_hacking_passed": anti_hacking.get("passed"),
"completed_run_ids": completed_run_ids,
}
)
_write_status()
_mirror_results()
_upload_artifacts()
except Exception as exc: # noqa: BLE001
_append_log(str(exc))
with LOCK:
STATUS.update({"status": "failed", "finished_at": time.time(), "error": str(exc)})
_write_status()
_mirror_results()
_upload_artifacts()
return STATUS
def run_training() -> tuple[dict[str, Any], str]:
    with LOCK:
        if STATUS.get("status") == "running":
            return STATUS, LOG_PATH.read_text(encoding="utf-8") if LOG_PATH.exists() else ""
        # Claim the run before releasing the lock so a rapid double-click
        # cannot spawn two training threads.
        STATUS["status"] = "running"
    threading.Thread(target=_train, daemon=True).start()
    return STATUS, "training started"
def read_status() -> tuple[dict[str, Any], str]:
log = LOG_PATH.read_text(encoding="utf-8") if LOG_PATH.exists() else ""
return STATUS, log[-20000:]
def build_app() -> gr.Blocks:
with gr.Blocks(title="PolyGuard HF Training") as demo:
gr.Markdown("# PolyGuard HF Training")
run_button = gr.Button("Run training", variant="primary")
refresh_button = gr.Button("Refresh")
status_box = gr.JSON(label="Status", value=STATUS)
log_box = gr.Textbox(label="Logs", lines=26)
run_button.click(fn=run_training, outputs=[status_box, log_box])
refresh_button.click(fn=read_status, outputs=[status_box, log_box])
return demo
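# Auto-start a training run on import (i.e. when the Space boots) unless
# POLYGUARD_AUTORUN is explicitly disabled.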
if os.getenv("POLYGUARD_AUTORUN", "1").lower() in {"1", "true", "yes", "on"}:
threading.Thread(target=_train, daemon=True).start()
app = build_app()
if __name__ == "__main__":
app.launch(server_name="0.0.0.0", server_port=7860)