# polyguard-openenv-training-3b-continuation / scripts / monitor_training_space_status.py
# Uploaded by adithya9903 — commit "Deploy PolyGuard HF training Space" (fd0c71a, verified)
#!/usr/bin/env python3
"""Write a compact HF training Space status report."""
from __future__ import annotations
import argparse
from datetime import datetime, timezone
import json
import os
from pathlib import Path
from typing import Any
from huggingface_hub import HfApi
# Repository root: this script lives in <root>/scripts/, so go up one level.
ROOT = Path(__file__).resolve().parents[1]
def parse_args() -> argparse.Namespace:
    """Parse command-line options for the training-Space monitor.

    All options have defaults, so the script is runnable with no arguments.
    """
    parser = argparse.ArgumentParser(description="Monitor PolyGuard HF training Space.")
    # Register the three options data-driven; defaults match the deployed Space.
    defaults = (
        ("--space-id", "TheJackBright/polyguard-openenv-training-full"),
        ("--artifact-repo-id", "TheJackBright/polyguard-openenv-training-full-artifacts"),
        (
            "--output",
            "outputs/reports/submission_evidence/qwen_0_5b_1_5b_3b/training_space_runtime_status.json",
        ),
    )
    for flag, default in defaults:
        parser.add_argument(flag, default=default)
    return parser.parse_args()
def load_json(path: Path) -> dict[str, Any]:
    """Read *path* as JSON, returning {} unless it holds a JSON object.

    Missing files, malformed JSON, and non-dict top-level values all
    collapse to an empty dict so callers never have to branch.
    """
    if not path.exists():
        return {}
    try:
        data = json.loads(path.read_text(encoding="utf-8"))
    except json.JSONDecodeError:
        return {}
    if isinstance(data, dict):
        return data
    return {}


def stage_records_from(path: Path) -> list[dict[str, Any]]:
    """Return the manifest's "stage_records" list, or [] if absent/invalid."""
    records = load_json(path).get("stage_records")
    if isinstance(records, list):
        return records
    return []


def model_statuses_from(path: Path) -> dict[str, dict[str, str]]:
    """Map each run_id in the manifest's "models" list to its stage-status dict.

    Entries that are not dicts, have no run_id, or carry a non-dict
    "statuses" field are skipped; keys and values are coerced to str.
    """
    entries = load_json(path).get("models")
    if not isinstance(entries, list):
        return {}
    result: dict[str, dict[str, str]] = {}
    for entry in entries:
        if not isinstance(entry, dict):
            continue
        rid = str(entry.get("run_id") or "")
        per_stage = entry.get("statuses")
        if not rid or not isinstance(per_stage, dict):
            continue
        result[rid] = {str(key): str(value) for key, value in per_stage.items()}
    return result
def _resolve_status(
    stage: str,
    run_id: str,
    merged_statuses: dict[str, str],
    completed_stages: dict[str, dict[str, Any]],
) -> str:
    """Resolve one stage's status for one run.

    Precedence: explicit "artifact_available" > any remote_completed*
    status (passed through verbatim) > a completed stage record from the
    manifests > the raw status string > "pending_or_unseen".
    """
    value = merged_statuses.get(stage, "")
    if value == "artifact_available":
        return "artifact_available"
    if "remote_completed" in value:
        return value
    if f"{run_id}:{stage}" in completed_stages:
        return "completed"
    return value or "pending_or_unseen"


def main() -> None:
    """Gather Space runtime and artifact-repo status and emit a JSON report.

    Writes the report under ROOT at --output, mirrors the same bytes into
    the docs evidence directory, and prints the report to stdout. API
    failures are recorded in the report instead of aborting.
    """
    args = parse_args()
    token = os.getenv("HF_TOKEN")
    api = HfApi(token=token)

    runtime_error = ""
    artifact_error = ""
    runtime: Any = {}
    artifact_files: list[str] = []
    try:
        info = api.space_info(args.space_id)
        runtime = getattr(info, "runtime", None)
    except Exception as exc:  # noqa: BLE001 - record the failure in the report
        runtime_error = str(exc)
    try:
        artifact_files = api.list_repo_files(repo_id=args.artifact_repo_id, repo_type="model", token=token)
    except Exception as exc:  # noqa: BLE001 - record the failure in the report
        artifact_error = str(exc)

    # Evidence manifests: prior two-model sweep plus current three-model sweep.
    evidence_root = ROOT / "outputs" / "reports" / "submission_evidence"
    prior_manifest = evidence_root / "qwen_0_5b_1_5b" / "manifest.json"
    current_manifest = evidence_root / "qwen_0_5b_1_5b_3b" / "manifest.json"
    prior_records = stage_records_from(prior_manifest)
    current_records = stage_records_from(current_manifest)
    prior_model_statuses = model_statuses_from(prior_manifest)
    current_model_statuses = model_statuses_from(current_manifest)

    # Union of stage records, prior sweep first, dropping exact duplicates.
    stage_records = prior_records + [record for record in current_records if record not in prior_records]
    completed_stages = {
        f"{record.get('run_id')}:{record.get('stage')}": record
        for record in stage_records
        if isinstance(record, dict) and record.get("completed") is True
    }

    run_ids = [
        "qwen-qwen2-5-0-5b-instruct",
        "qwen-qwen2-5-1-5b-instruct",
        "qwen-qwen2-5-3b-instruct",
    ]
    stages = [
        "sft_training",
        "grpo_training",
        "sft_postsave_inference",
        "grpo_postsave_inference",
        "policy_ablation",
    ]
    run_statuses: dict[str, dict[str, Any]] = {}
    for run_id in run_ids:
        # Current-sweep statuses override the prior sweep's for the same stage.
        merged_statuses = {**prior_model_statuses.get(run_id, {}), **current_model_statuses.get(run_id, {})}
        entry: dict[str, Any] = {
            stage: _resolve_status(stage, run_id, merged_statuses, completed_stages) for stage in stages
        }
        entry["artifact_files"] = [
            item
            for item in artifact_files
            if f"outputs/reports/sweeps/{run_id}/" in item or f"checkpoints/sweeps/{run_id}/" in item
        ]
        run_statuses[run_id] = entry

    report = {
        "status": "ok",
        "generated_at_utc": datetime.now(timezone.utc).isoformat(),
        "space_id": args.space_id,
        "artifact_repo_id": args.artifact_repo_id,
        # repr() because runtime may be an hf_hub object that isn't JSON-serializable.
        "runtime": repr(runtime),
        "runtime_error": runtime_error,
        "artifact_error": artifact_error,
        "artifact_file_count": len(artifact_files),
        "has_usable_active_bundle": any(
            item.startswith("usable_model_bundles/local-qwen-0-5b-active-smoke/") for item in artifact_files
        ),
        "has_full_sweep_artifacts": any(
            "outputs/reports/sweeps/" in item or "checkpoints/sweeps/" in item for item in artifact_files
        ),
        "run_statuses": run_statuses,
        "interpretation": (
            "The Space is not actively training if runtime contains stage='PAUSED'. "
            "Completed stage records are taken from live evidence snapshots when available; "
            "missing per-run artifact files mean the full sweep checkpoints/reports are not yet downloadable."
        ),
    }

    # Serialize once; the same payload is written to both files and printed.
    payload = json.dumps(report, ensure_ascii=True, indent=2) + "\n"
    output = ROOT / args.output
    output.parent.mkdir(parents=True, exist_ok=True)
    output.write_text(payload, encoding="utf-8")
    docs_output = ROOT / "docs" / "results" / "submission_evidence_qwen_0_5b_1_5b_3b" / "reports" / output.name
    docs_output.parent.mkdir(parents=True, exist_ok=True)
    docs_output.write_text(payload, encoding="utf-8")
    print(payload, end="")


if __name__ == "__main__":
    main()