File size: 5,088 Bytes
"""Active trained-model discovery for product inference.

The HF training Space writes model artifacts into per-sweep folders. The app
uses this module to find the locally activated artifact without hard-coding a
specific checkpoint path into the API or agent stack.
"""
from __future__ import annotations
import json
import os
from pathlib import Path
from typing import Any
# Fourth ancestor directory of this file's resolved location; relative artifact
# paths are anchored here (presumably the repository root - confirm against
# the package layout).
ROOT = Path(__file__).resolve().parents[3]
# Default directory holding the locally activated model artifacts.
ACTIVE_DIR = ROOT / "checkpoints" / "active"
# JSON manifest describing the activated model; it may be absent.
MANIFEST_PATH = ACTIVE_DIR / "active_model_manifest.json"
# Run identifier reported when the manifest does not provide a "run_id".
DEFAULT_RUN_ID = "qwen-qwen2-5-0-5b-instruct"
def _truthy(value: str | None) -> bool | None:
    """Parse a tri-state boolean flag from text.

    Matching is case-insensitive and ignores surrounding whitespace.

    Returns:
        ``True`` for "1"/"true"/"yes"/"on", ``False`` for
        "0"/"false"/"no"/"off", and ``None`` when *value* is ``None`` or is
        not one of those tokens.
    """
    if value is None:
        return None
    tokens = {
        "1": True,
        "true": True,
        "yes": True,
        "on": True,
        "0": False,
        "false": False,
        "no": False,
        "off": False,
    }
    # dict.get yields None for unrecognised text, i.e. "no opinion".
    return tokens.get(value.strip().lower())
def _read_json(path: Path) -> dict[str, Any]:
    """Load a JSON object from *path* on a best-effort basis.

    Returns:
        The decoded object when the file exists, can be read as UTF-8, parses
        as JSON and its top-level value is a dict; otherwise an empty dict.
    """
    if path.exists():
        try:
            decoded = json.loads(path.read_text(encoding="utf-8"))
        except Exception:
            # Deliberately tolerant: an unreadable or malformed file is
            # treated the same as a missing one.
            decoded = None
        if isinstance(decoded, dict):
            return decoded
    return {}
def _resolve_path(value: str | Path | None, default: Path) -> Path:
    """Resolve a configured artifact location to a concrete path.

    Args:
        value: Path-like setting taken from the environment or the manifest.
            ``None`` or blank/whitespace-only text selects *default*.
        default: Path returned unchanged when *value* is missing or blank.

    Returns:
        *default* when no value is configured; otherwise the user-expanded
        path, anchored at ``ROOT`` when it is relative.
    """
    text = "" if value is None else str(value).strip()
    if not text:
        return default
    # Build the path from the stripped text. Blankness is judged after
    # stripping, so stray whitespace around a setting must not leak into the
    # path (it would turn " /abs/dir " into a relative path under ROOT).
    path = Path(text).expanduser()
    return path if path.is_absolute() else ROOT / path
def _adapter_base_model(adapter_dir: Path) -> str:
    """Return the base model named in ``adapter_dir/adapter_config.json``.

    Returns:
        The ``base_model_name_or_path`` entry when it is a string, otherwise
        an empty string.
    """
    config = _read_json(adapter_dir / "adapter_config.json")
    base_name = config.get("base_model_name_or_path")
    if isinstance(base_name, str):
        return str(base_name)
    return ""
def _adapter_files_present(adapter_dir: Path) -> bool:
    """Return True when *adapter_dir* holds both the adapter config and weights."""
    return (adapter_dir / "adapter_config.json").exists() and (
        adapter_dir / "adapter_model.safetensors"
    ).exists()


def active_model_status() -> dict[str, Any]:
    """Return the activated model artifact contract used by the app.

    Each setting is resolved in priority order: the ``POLYGUARD_*``
    environment variable, then the manifest at ``MANIFEST_PATH``, then a
    built-in default under ``ACTIVE_DIR``.

    Returns:
        A dict with the keys ``enabled``, ``active``, ``manifest_path``,
        ``manifest_exists``, ``run_id``, ``source``, ``label``, ``model_id``,
        ``base_model``, ``preferred_artifact``, ``load_order``,
        ``availability``, ``paths``, ``reports`` and ``notes``. ``active`` is
        True when at least one artifact is present on disk; ``enabled`` is
        the independent on/off switch.
    """
    artifact_kinds = ("grpo_adapter", "merged", "sft_adapter")
    manifest = _read_json(MANIFEST_PATH)

    # The environment flag wins whenever it parses to a definite True/False.
    env_enabled = _truthy(os.getenv("POLYGUARD_ENABLE_ACTIVE_MODEL"))
    raw_enabled = manifest.get("enabled", False)
    # A manifest may carry the flag as text; bool("false") is True, so parse
    # recognised tokens first and only fall back to plain truthiness.
    parsed_enabled = _truthy(raw_enabled) if isinstance(raw_enabled, str) else None
    manifest_enabled = bool(raw_enabled) if parsed_enabled is None else parsed_enabled
    enabled = manifest_enabled if env_enabled is None else env_enabled

    # Strip so stray whitespace does not silently discard a valid choice.
    preferred_artifact = (
        os.getenv("POLYGUARD_ACTIVE_PREFERRED_ARTIFACT")
        or str(manifest.get("preferred_artifact") or "grpo_adapter")
    ).strip()
    if preferred_artifact not in artifact_kinds:
        preferred_artifact = "grpo_adapter"

    grpo_adapter = _resolve_path(
        os.getenv("POLYGUARD_ACTIVE_GRPO_ADAPTER") or manifest.get("grpo_adapter"),
        ACTIVE_DIR / "grpo_adapter",
    )
    sft_adapter = _resolve_path(
        os.getenv("POLYGUARD_ACTIVE_SFT_ADAPTER") or manifest.get("sft_adapter"),
        ACTIVE_DIR / "sft_adapter",
    )
    merged_model = _resolve_path(
        os.getenv("POLYGUARD_ACTIVE_MERGED_MODEL") or manifest.get("merged_model"),
        ACTIVE_DIR / "merged",
    )

    # First non-empty source wins: env, manifest, adapter configs, then the
    # generic HF model setting with its hard-coded default.
    base_model = (
        os.getenv("POLYGUARD_ACTIVE_BASE_MODEL")
        or str(manifest.get("base_model") or "")
        or _adapter_base_model(grpo_adapter)
        or _adapter_base_model(sft_adapter)
        or os.getenv("POLYGUARD_HF_MODEL", "Qwen/Qwen2.5-0.5B-Instruct")
    )

    availability = {
        "grpo_adapter": _adapter_files_present(grpo_adapter),
        # A merged model only needs its config.json to count as present.
        "merged": (merged_model / "config.json").exists(),
        "sft_adapter": _adapter_files_present(sft_adapter),
    }
    # Preferred artifact first, remaining kinds in their canonical order.
    load_order = [
        preferred_artifact,
        *(kind for kind in artifact_kinds if kind != preferred_artifact),
    ]
    reports = manifest.get("reports")

    return {
        "enabled": enabled,
        "active": any(availability.values()),
        "manifest_path": str(MANIFEST_PATH),
        "manifest_exists": MANIFEST_PATH.exists(),
        "run_id": str(manifest.get("run_id") or DEFAULT_RUN_ID),
        "source": str(manifest.get("source") or ""),
        "label": str(manifest.get("label") or ""),
        "model_id": str(manifest.get("model_id") or base_model),
        "base_model": base_model,
        "preferred_artifact": preferred_artifact,
        "load_order": load_order,
        "availability": availability,
        "paths": {
            "grpo_adapter": str(grpo_adapter),
            "merged": str(merged_model),
            "sft_adapter": str(sft_adapter),
        },
        "reports": reports if isinstance(reports, dict) else {},
        "notes": str(manifest.get("notes") or ""),
    }
def available_artifact_path(status: dict[str, Any] | None = None) -> tuple[str, Path] | None:
    """Return the first available artifact according to the active load order.

    Args:
        status: A status dict shaped like the result of
            ``active_model_status()``. Discovery runs only when this is
            ``None``; an explicitly supplied dict, even an empty one, is used
            as given.

    Returns:
        ``(artifact_name, artifact_path)`` for the first entry of
        ``load_order`` that is marked available and has a path, or ``None``
        when the model is disabled, inactive, malformed or nothing matches.
    """
    # `status or ...` would discard a caller-supplied empty dict and silently
    # run discovery instead; only a missing argument should trigger it.
    if status is None:
        status = active_model_status()
    if not status.get("enabled") or not status.get("active"):
        return None

    paths = status.get("paths", {})
    availability = status.get("availability", {})
    if not isinstance(paths, dict) or not isinstance(availability, dict):
        return None

    # Treat a null/absent load order as "nothing to try" rather than failing.
    load_order = status.get("load_order") or []
    for artifact in load_order:
        if availability.get(artifact) and paths.get(artifact):
            return str(artifact), Path(str(paths[artifact]))
    return None
|