| """Active trained-model discovery for product inference. |
| |
| The HF training Space writes model artifacts into per-sweep folders. The app |
| uses this module to find the locally activated artifact without hard-coding a |
| specific checkpoint path into the API or agent stack. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import json |
| import os |
| from pathlib import Path |
| from typing import Any |
|
|
|
|
| ROOT = Path(__file__).resolve().parents[3] |
| ACTIVE_DIR = ROOT / "checkpoints" / "active" |
| MANIFEST_PATH = ACTIVE_DIR / "active_model_manifest.json" |
| DEFAULT_RUN_ID = "qwen-qwen2-5-0-5b-instruct" |
|
|
|
|
| def _truthy(value: str | None) -> bool | None: |
| if value is None: |
| return None |
| lowered = value.strip().lower() |
| if lowered in {"1", "true", "yes", "on"}: |
| return True |
| if lowered in {"0", "false", "no", "off"}: |
| return False |
| return None |
|
|
|
|
| def _read_json(path: Path) -> dict[str, Any]: |
| if not path.exists(): |
| return {} |
| try: |
| payload = json.loads(path.read_text(encoding="utf-8")) |
| except Exception: |
| return {} |
| return payload if isinstance(payload, dict) else {} |
|
|
|
|
| def _resolve_path(value: str | Path | None, default: Path) -> Path: |
| if value is None or str(value).strip() == "": |
| return default |
| path = Path(str(value)).expanduser() |
| if path.is_absolute(): |
| return path |
| return ROOT / path |
|
|
|
|
def _adapter_base_model(adapter_dir: Path) -> str:
    """Read ``base_model_name_or_path`` from ``adapter_dir/adapter_config.json``.

    Returns an empty string when the config cannot be loaded or when the
    field is absent or not a string.
    """
    config = _read_json(adapter_dir / "adapter_config.json")
    name = config.get("base_model_name_or_path")
    if not isinstance(name, str):
        return ""
    return str(name)
|
|
|
|
def _adapter_files_present(adapter_dir: Path) -> bool:
    """Return True when *adapter_dir* has both an adapter config and safetensors weights."""
    return (adapter_dir / "adapter_config.json").exists() and (
        adapter_dir / "adapter_model.safetensors"
    ).exists()


def _manifest_flag(value: Any, default: bool = False) -> bool:
    """Interpret a manifest boolean; string values are parsed like env flags.

    ``bool("false")`` is True, so string values go through ``_truthy``.
    Unrecognised strings fall back to *default*.
    """
    if isinstance(value, str):
        parsed = _truthy(value)
        return default if parsed is None else parsed
    return bool(value)


def active_model_status() -> dict[str, Any]:
    """Return the activated model artifact contract used by the app.

    Settings come from the JSON manifest at ``MANIFEST_PATH`` (an unreadable or
    missing manifest is treated as ``{}``). Several ``POLYGUARD_*`` environment
    variables override them.

    Returns:
        A dict with these keys:

        - ``enabled``: ``POLYGUARD_ENABLE_ACTIVE_MODEL`` when it parses as a
          flag, else the manifest ``enabled`` value (default ``False``).
        - ``active``: True when at least one artifact's files exist on disk.
        - ``manifest_path`` / ``manifest_exists``: manifest location and presence.
        - ``run_id``: manifest ``run_id``, falling back to ``DEFAULT_RUN_ID``.
        - ``source``, ``label``, ``notes``: manifest strings, ``""`` when absent.
        - ``model_id``: manifest ``model_id``, else ``base_model``.
        - ``base_model``: the first non-empty value of:
          ``POLYGUARD_ACTIVE_BASE_MODEL``; manifest ``base_model``; the GRPO
          then SFT adapter config; ``POLYGUARD_HF_MODEL``; finally
          ``"Qwen/Qwen2.5-0.5B-Instruct"``.
        - ``preferred_artifact``: one of ``grpo_adapter`` / ``merged`` /
          ``sft_adapter``; unknown values fall back to ``grpo_adapter``.
        - ``load_order``: preferred artifact first, then the remaining ones in
          ``grpo_adapter``, ``merged``, ``sft_adapter`` order.
        - ``availability``: per-artifact file-presence booleans.
        - ``paths``: per-artifact resolved paths, as strings.
        - ``reports``: manifest ``reports`` when it is a dict, else ``{}``.
    """
    artifact_kinds = ("grpo_adapter", "merged", "sft_adapter")
    manifest = _read_json(MANIFEST_PATH)

    # An env flag that parses wins; otherwise defer to the manifest.
    env_enabled = _truthy(os.getenv("POLYGUARD_ENABLE_ACTIVE_MODEL"))
    manifest_enabled = _manifest_flag(manifest.get("enabled", False))
    enabled = env_enabled if env_enabled is not None else manifest_enabled

    preferred_artifact = (
        os.getenv("POLYGUARD_ACTIVE_PREFERRED_ARTIFACT")
        or str(manifest.get("preferred_artifact") or "grpo_adapter")
    ).strip().lower()
    if preferred_artifact not in artifact_kinds:
        preferred_artifact = "grpo_adapter"

    grpo_adapter = _resolve_path(
        os.getenv("POLYGUARD_ACTIVE_GRPO_ADAPTER") or manifest.get("grpo_adapter"),
        ACTIVE_DIR / "grpo_adapter",
    )
    sft_adapter = _resolve_path(
        os.getenv("POLYGUARD_ACTIVE_SFT_ADAPTER") or manifest.get("sft_adapter"),
        ACTIVE_DIR / "sft_adapter",
    )
    merged_model = _resolve_path(
        os.getenv("POLYGUARD_ACTIVE_MERGED_MODEL") or manifest.get("merged_model"),
        ACTIVE_DIR / "merged",
    )
    # Every source is skipped when empty, including an empty
    # POLYGUARD_HF_MODEL, so the hard-coded default is always reachable.
    base_model = (
        os.getenv("POLYGUARD_ACTIVE_BASE_MODEL")
        or str(manifest.get("base_model") or "")
        or _adapter_base_model(grpo_adapter)
        or _adapter_base_model(sft_adapter)
        or os.getenv("POLYGUARD_HF_MODEL")
        or "Qwen/Qwen2.5-0.5B-Instruct"
    )

    availability = {
        "grpo_adapter": _adapter_files_present(grpo_adapter),
        "merged": (merged_model / "config.json").exists(),
        "sft_adapter": _adapter_files_present(sft_adapter),
    }
    load_order = [preferred_artifact] + [
        kind for kind in artifact_kinds if kind != preferred_artifact
    ]
    active = any(availability.values())
    reports = manifest.get("reports")

    return {
        "enabled": enabled,
        "active": active,
        "manifest_path": str(MANIFEST_PATH),
        "manifest_exists": MANIFEST_PATH.exists(),
        "run_id": str(manifest.get("run_id") or DEFAULT_RUN_ID),
        "source": str(manifest.get("source") or ""),
        "label": str(manifest.get("label") or ""),
        "model_id": str(manifest.get("model_id") or base_model),
        "base_model": base_model,
        "preferred_artifact": preferred_artifact,
        "load_order": load_order,
        "availability": availability,
        "paths": {
            "grpo_adapter": str(grpo_adapter),
            "merged": str(merged_model),
            "sft_adapter": str(sft_adapter),
        },
        "reports": reports if isinstance(reports, dict) else {},
        "notes": str(manifest.get("notes") or ""),
    }
|
|
|
|
def available_artifact_path(status: dict[str, Any] | None = None) -> tuple[str, Path] | None:
    """Return the first available artifact according to the active load order.

    Args:
        status: A mapping shaped like the result of ``active_model_status()``.
            When omitted (``None``), the current status is computed. An explicit
            mapping is always used as given, even an empty one.

    Returns:
        ``(artifact_name, path)`` for the first ``load_order`` entry that is
        marked available and has a non-empty path. Returns ``None`` when:

        - the status is not enabled or not active;
        - the status is malformed;
        - no listed artifact is usable.
    """
    # ``status or ...`` would silently recompute for an explicit ``{}``.
    if status is None:
        status = active_model_status()
    if not status.get("enabled") or not status.get("active"):
        return None
    paths = status.get("paths", {})
    availability = status.get("availability", {})
    load_order = status.get("load_order") or []
    if not isinstance(paths, dict) or not isinstance(availability, dict):
        return None
    if not isinstance(load_order, (list, tuple)):
        return None
    for artifact in load_order:
        if availability.get(artifact) and paths.get(artifact):
            return str(artifact), Path(str(paths[artifact]))
    return None
|
|