polyguard-openenv-workbench / polyguard-rl /scripts /activate_sweep_model.py
TheJackBright's picture
Deploy GitHub root master to Space
c296d62
#!/usr/bin/env python3
"""Activate a pulled sweep model for local API/UI inference."""
from __future__ import annotations
import argparse
from datetime import datetime, timezone
import json
from pathlib import Path
import shutil
import sys
from typing import Any
# Project root: the directory one level above the directory containing this script.
ROOT = Path(__file__).resolve().parents[1]
# Run id used when --run-id is not supplied on the command line.
DEFAULT_RUN_ID = "qwen-qwen2-5-0-5b-instruct"
def parse_args() -> argparse.Namespace:
    """Build the command-line interface and parse ``sys.argv``.

    Returns:
        The parsed namespace with ``run_id``, ``source``,
        ``preferred_artifact``, ``mode``, ``label`` and ``disable``.
    """
    cli = argparse.ArgumentParser(description="Activate a PolyGuard sweep checkpoint for product inference.")
    cli.add_argument("--run-id", default=DEFAULT_RUN_ID)

    # Options restricted to a fixed set of values: (flag, allowed values, default).
    choice_options = (
        ("--source", ["sweep", "top-level"], "sweep"),
        ("--preferred-artifact", ["grpo_adapter", "merged", "sft_adapter"], "grpo_adapter"),
        ("--mode", ["symlink", "copy"], "symlink"),
    )
    for flag, allowed, fallback in choice_options:
        cli.add_argument(flag, choices=allowed, default=fallback)

    cli.add_argument("--label", default="")
    cli.add_argument(
        "--disable",
        action="store_true",
        help="Write the manifest but keep active-model loading disabled.",
    )
    return cli.parse_args()
def _read_json(path: Path) -> dict[str, Any]:
    """Load a JSON object from *path*, returning ``{}`` when that is not possible.

    This is a best-effort reader for optional metadata files. An empty dict is
    returned when the path is missing or unreadable, when the bytes are not
    valid UTF-8, when the text is not valid JSON, or when the top-level JSON
    value is not an object.

    Args:
        path: File expected to contain a UTF-8 encoded JSON object.

    Returns:
        The decoded mapping, or ``{}`` on any of the failures listed above.
    """
    try:
        payload = json.loads(path.read_text(encoding="utf-8"))
    except (OSError, ValueError, RecursionError):
        # OSError: missing/unreadable path (including directories).
        # ValueError: covers json.JSONDecodeError and UnicodeDecodeError.
        # RecursionError: pathologically nested JSON.
        # Anything else is a programming error and is allowed to propagate.
        return {}
    return payload if isinstance(payload, dict) else {}
def _replace_path(source: Path, target: Path, *, mode: str) -> bool:
    """Replace *target* with a symlink to, or a copy of, *source*.

    Any existing *target* (file, symlink or directory tree) is removed first.
    When *source* does not exist nothing is touched, so a *target* left over
    from an earlier call stays in place.

    Args:
        source: Existing file or directory to expose at *target*.
        target: Destination path; its parent directories are created as needed.
        mode: ``"symlink"`` links *target* to the resolved *source*; any other
            value copies *source* (recursively for directories).

    Returns:
        ``True`` if *target* was (re)created, ``False`` if *source* is missing.
    """
    if not source.exists():
        return False

    # Check is_symlink() first so a link to a directory is unlinked rather
    # than having rmtree() called on it.
    if target.is_symlink() or target.is_file():
        target.unlink()
    elif target.exists():
        shutil.rmtree(target)
    target.parent.mkdir(parents=True, exist_ok=True)

    source_is_dir = source.is_dir()
    if mode == "symlink":
        # target_is_directory only matters on Windows; derive it from the
        # source instead of assuming a directory.
        target.symlink_to(source.resolve(), target_is_directory=source_is_dir)
    elif source_is_dir:
        shutil.copytree(source, target)
    else:
        # copytree() rejects regular files, so single-file artifacts are
        # copied directly.
        shutil.copy2(source, target)
    return True
def _copy_reports(source: Path, target: Path) -> dict[str, str]:
    """Mirror the text-like report files under *source* into a fresh *target*.

    *target* is always wiped and recreated, even when *source* is missing.
    Only ``.json``, ``.jsonl`` and ``.txt`` files (suffix compared
    case-insensitively) are copied, keeping their layout relative to *source*.
    Files that already live inside *target* are skipped, which matters when
    *target* is nested inside *source*.

    Returns:
        A mapping from each copied file's path relative to *source* to its new
        location relative to ``ROOT``.
    """
    if target.exists():
        shutil.rmtree(target)
    target.mkdir(parents=True, exist_ok=True)

    mapping: dict[str, str] = {}
    if not source.exists():
        return mapping

    report_suffixes = {".json", ".jsonl", ".txt"}
    destination_root = target.resolve()
    for candidate in source.rglob("*"):
        if not candidate.is_file():
            continue
        if candidate.suffix.lower() not in report_suffixes:
            continue
        try:
            candidate.resolve().relative_to(destination_root)
        except ValueError:
            # Not under the destination tree, so it is a genuine source file.
            pass
        else:
            # Already inside the destination; never copy the target into itself.
            continue
        relative = candidate.relative_to(source)
        destination = target / relative
        destination.parent.mkdir(parents=True, exist_ok=True)
        shutil.copy2(candidate, destination)
        mapping[str(relative)] = str(destination.relative_to(ROOT))
    return mapping
def _model_id(checkpoint_dir: Path, report_dir: Path) -> str:
    """Return the base model identifier recorded for a run.

    Metadata files are consulted in priority order (run metadata, merge
    report, GRPO adapter config, SFT adapter config). Within each file the
    keys ``model_id``, ``base_model`` and ``base_model_name_or_path`` are
    tried in that order, and the first non-blank string wins, stripped of
    surrounding whitespace. A fixed default is returned when nothing is found.
    """
    metadata_files = (
        report_dir / "run_metadata.json",
        checkpoint_dir / "merged" / "merge_report.json",
        checkpoint_dir / "grpo_adapter" / "adapter_config.json",
        checkpoint_dir / "sft_adapter" / "adapter_config.json",
    )
    id_keys = ("model_id", "base_model", "base_model_name_or_path")

    for metadata_file in metadata_files:
        data = _read_json(metadata_file)
        match = next(
            (
                data[key].strip()
                for key in id_keys
                if isinstance(data.get(key), str) and data[key].strip()
            ),
            None,
        )
        if match is not None:
            return match
    return "Qwen/Qwen2.5-0.5B-Instruct"
def main() -> None:
    """Activate a checkpoint and write the active-model manifest.

    Steps:
      1. Resolve the source checkpoint/report directories from ``--source``
         and ``--run-id``.
      2. Symlink or copy each available artifact into ``checkpoints/active``.
      3. Copy the run's report files into ``outputs/reports/active_model``.
      4. Write the manifest next to the active checkpoints and to two mirror
         locations, then print it to stdout.

    Raises:
        SystemExit: with ``no_model_artifacts_found:<dir>`` when none of the
            three artifacts exists in the source checkpoint directory.
    """
    args = parse_args()
    if args.source == "sweep":
        checkpoint_dir = ROOT / "checkpoints" / "sweeps" / args.run_id
        report_dir = ROOT / "outputs" / "reports" / "sweeps" / args.run_id
    else:
        checkpoint_dir = ROOT / "checkpoints"
        report_dir = ROOT / "outputs" / "reports"

    active_dir = ROOT / "checkpoints" / "active"
    active_report_dir = ROOT / "outputs" / "reports" / "active_model"
    active_dir.mkdir(parents=True, exist_ok=True)

    # NOTE(review): an artifact missing from the source is reported as
    # unavailable, but anything a previous activation left under
    # checkpoints/active for it is not removed -- confirm that consumers
    # consult "availability" rather than path existence.
    availability = {
        name: _replace_path(checkpoint_dir / name, active_dir / name, mode=args.mode)
        for name in ("grpo_adapter", "merged", "sft_adapter")
    }
    if not any(availability.values()):
        raise SystemExit(f"no_model_artifacts_found:{checkpoint_dir}")

    reports = _copy_reports(report_dir, active_report_dir)

    # Resolve once so "model_id" and "base_model" are guaranteed to agree and
    # the metadata files are not read twice.
    model_id = _model_id(checkpoint_dir, report_dir)

    manifest = {
        "status": "ok",
        "enabled": not args.disable,
        "activated_at_utc": datetime.now(timezone.utc).isoformat(),
        "run_id": args.run_id,
        "source": args.source,
        "label": args.label,
        "model_id": model_id,
        "base_model": model_id,
        "preferred_artifact": args.preferred_artifact,
        "mode": args.mode,
        "source_checkpoint_dir": str(checkpoint_dir.relative_to(ROOT)),
        "source_report_dir": str(report_dir.relative_to(ROOT)) if report_dir.exists() else "",
        "grpo_adapter": "checkpoints/active/grpo_adapter",
        "merged_model": "checkpoints/active/merged",
        "sft_adapter": "checkpoints/active/sft_adapter",
        "availability": availability,
        "reports": reports,
        "notes": (
            "This manifest controls local product inference. Prefer grpo_adapter for the RL policy; "
            "merged is the SFT baseline fallback when no GRPO adapter is available."
        ),
    }

    # Serialize once; every copy of the manifest and the stdout echo share
    # exactly the same text.
    rendered = json.dumps(manifest, ensure_ascii=True, indent=2)
    manifest_name = "active_model_manifest.json"
    for destination in (
        active_dir / manifest_name,
        active_report_dir / manifest_name,
        ROOT / "docs" / "results" / manifest_name,
    ):
        destination.parent.mkdir(parents=True, exist_ok=True)
        destination.write_text(rendered, encoding="utf-8")
    print(rendered)
if __name__ == "__main__":
    # On Ctrl-C, exit with status 130 (the conventional 128 + SIGINT shell
    # status) instead of letting the KeyboardInterrupt traceback print.
    try:
        main()
    except KeyboardInterrupt:
        raise SystemExit(130)