| |
| """Activate a pulled sweep model for local API/UI inference.""" |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| from datetime import datetime, timezone |
| import json |
| from pathlib import Path |
| import shutil |
| import sys |
| from typing import Any |
|
|
|
|
# Absolute location of this script; ROOT is the parent of the directory that
# contains it and anchors the checkpoints/, outputs/ and docs/ paths below.
_THIS_FILE = Path(__file__).resolve()
ROOT = _THIS_FILE.parents[1]

# Run id used when --run-id is not given on the command line.
DEFAULT_RUN_ID = "qwen-qwen2-5-0-5b-instruct"
|
|
|
|
def parse_args() -> argparse.Namespace:
    """Parse the command-line options that select which checkpoint to activate.

    Returns:
        A namespace with ``run_id``, ``source``, ``preferred_artifact``,
        ``mode``, ``label`` and ``disable`` attributes.
    """
    cli = argparse.ArgumentParser(
        description="Activate a PolyGuard sweep checkpoint for product inference.",
    )
    # (flag, keyword arguments) pairs, registered in declaration order so the
    # generated --help output keeps the same ordering.
    option_specs = [
        ("--run-id", {"default": DEFAULT_RUN_ID}),
        ("--source", {"choices": ["sweep", "top-level"], "default": "sweep"}),
        (
            "--preferred-artifact",
            {"choices": ["grpo_adapter", "merged", "sft_adapter"], "default": "grpo_adapter"},
        ),
        ("--mode", {"choices": ["symlink", "copy"], "default": "symlink"}),
        ("--label", {"default": ""}),
        (
            "--disable",
            {
                "action": "store_true",
                "help": "Write the manifest but keep active-model loading disabled.",
            },
        ),
    ]
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)
    return cli.parse_args()
|
|
|
|
| def _read_json(path: Path) -> dict[str, Any]: |
| if not path.exists(): |
| return {} |
| try: |
| payload = json.loads(path.read_text(encoding="utf-8")) |
| except Exception: |
| return {} |
| return payload if isinstance(payload, dict) else {} |
|
|
|
|
| def _replace_path(source: Path, target: Path, *, mode: str) -> bool: |
| if not source.exists(): |
| return False |
| if target.is_symlink() or target.is_file(): |
| target.unlink() |
| elif target.exists(): |
| shutil.rmtree(target) |
| target.parent.mkdir(parents=True, exist_ok=True) |
| if mode == "symlink": |
| target.symlink_to(source.resolve(), target_is_directory=True) |
| else: |
| shutil.copytree(source, target) |
| return True |
|
|
|
|
| def _copy_reports(source: Path, target: Path) -> dict[str, str]: |
| copied: dict[str, str] = {} |
| if target.exists(): |
| shutil.rmtree(target) |
| target.mkdir(parents=True, exist_ok=True) |
| if not source.exists(): |
| return copied |
| target_resolved = target.resolve() |
| for path in source.rglob("*"): |
| if not path.is_file() or path.suffix.lower() not in {".json", ".jsonl", ".txt"}: |
| continue |
| try: |
| path.resolve().relative_to(target_resolved) |
| continue |
| except ValueError: |
| pass |
| rel = path.relative_to(source) |
| out = target / rel |
| out.parent.mkdir(parents=True, exist_ok=True) |
| shutil.copy2(path, out) |
| copied[str(rel)] = str(out.relative_to(ROOT)) |
| return copied |
|
|
|
|
def _model_id(checkpoint_dir: Path, report_dir: Path) -> str:
    """Look up the model identifier recorded alongside a checkpoint.

    Metadata files are consulted in a fixed order; within each file the keys
    ``model_id``, ``base_model`` and ``base_model_name_or_path`` are tried in
    turn. The first non-blank string wins.

    Args:
        checkpoint_dir: Directory holding the ``merged``, ``grpo_adapter`` and
            ``sft_adapter`` artifacts.
        report_dir: Directory holding ``run_metadata.json``.

    Returns:
        The stripped identifier, or ``"Qwen/Qwen2.5-0.5B-Instruct"`` when no
        file provides one.
    """
    candidate_files = (
        report_dir / "run_metadata.json",
        checkpoint_dir / "merged" / "merge_report.json",
        checkpoint_dir / "grpo_adapter" / "adapter_config.json",
        checkpoint_dir / "sft_adapter" / "adapter_config.json",
    )
    id_keys = ("model_id", "base_model", "base_model_name_or_path")

    for metadata_path in candidate_files:
        metadata = _read_json(metadata_path)
        for key in id_keys:
            raw = metadata.get(key)
            if not isinstance(raw, str):
                continue
            cleaned = raw.strip()
            if cleaned:
                return cleaned
    return "Qwen/Qwen2.5-0.5B-Instruct"
|
|
|
|
def main() -> None:
    """Activate the selected checkpoint and publish its manifest.

    Steps:
      1. Resolve the source checkpoint/report directories from ``--source``
         and ``--run-id``.
      2. Link or copy each of ``grpo_adapter``, ``merged`` and ``sft_adapter``
         into ``checkpoints/active``.
      3. Copy the run's text reports into ``outputs/reports/active_model``.
      4. Write ``active_model_manifest.json`` to the active checkpoint
         directory and two mirror locations, then print it to stdout.

    Raises:
        SystemExit: With ``no_model_artifacts_found:<dir>`` when none of the
            three artifacts exists in the source checkpoint directory.
    """
    args = parse_args()
    if args.source == "sweep":
        checkpoint_dir = ROOT / "checkpoints" / "sweeps" / args.run_id
        report_dir = ROOT / "outputs" / "reports" / "sweeps" / args.run_id
    else:
        checkpoint_dir = ROOT / "checkpoints"
        report_dir = ROOT / "outputs" / "reports"

    active_dir = ROOT / "checkpoints" / "active"
    active_report_dir = ROOT / "outputs" / "reports" / "active_model"
    active_dir.mkdir(parents=True, exist_ok=True)

    # Insertion order matches the manifest's historical key order.
    availability = {
        name: _replace_path(checkpoint_dir / name, active_dir / name, mode=args.mode)
        for name in ("grpo_adapter", "merged", "sft_adapter")
    }
    if not any(availability.values()):
        raise SystemExit(f"no_model_artifacts_found:{checkpoint_dir}")

    reports = _copy_reports(report_dir, active_report_dir)
    # Resolve once: the same identifier fills both "model_id" and "base_model",
    # and each lookup re-reads up to four metadata files.
    model_id = _model_id(checkpoint_dir, report_dir)
    manifest = {
        "status": "ok",
        "enabled": not args.disable,
        "activated_at_utc": datetime.now(timezone.utc).isoformat(),
        "run_id": args.run_id,
        "source": args.source,
        "label": args.label,
        "model_id": model_id,
        "base_model": model_id,
        "preferred_artifact": args.preferred_artifact,
        "mode": args.mode,
        "source_checkpoint_dir": str(checkpoint_dir.relative_to(ROOT)),
        "source_report_dir": str(report_dir.relative_to(ROOT)) if report_dir.exists() else "",
        "grpo_adapter": "checkpoints/active/grpo_adapter",
        "merged_model": "checkpoints/active/merged",
        "sft_adapter": "checkpoints/active/sft_adapter",
        "availability": availability,
        "reports": reports,
        "notes": (
            "This manifest controls local product inference. Prefer grpo_adapter for the RL policy; "
            "merged is the SFT baseline fallback when no GRPO adapter is available."
        ),
    }

    # Serialize once so every copy and the stdout echo are byte-identical.
    manifest_text = json.dumps(manifest, ensure_ascii=True, indent=2)
    manifest_name = "active_model_manifest.json"
    (active_dir / manifest_name).write_text(manifest_text, encoding="utf-8")
    # The report mirror is written after _copy_reports, which wipes that directory.
    for mirror in (
        active_report_dir / manifest_name,
        ROOT / "docs" / "results" / manifest_name,
    ):
        mirror.parent.mkdir(parents=True, exist_ok=True)
        mirror.write_text(manifest_text, encoding="utf-8")
    print(manifest_text)
|
|
|
|
if __name__ == "__main__":
    # Script entry point. Ctrl-C ends the run with exit status 130 instead of
    # a KeyboardInterrupt traceback.
    try:
        main()
    except KeyboardInterrupt:
        raise SystemExit(130)
|
|