# TheJackBright's picture
# Deploy GitHub root master to Space
# c296d62
"""Active trained-model discovery for product inference.
The HF training Space writes model artifacts into per-sweep folders. The app
uses this module to find the locally activated artifact without hard-coding a
specific checkpoint path into the API or agent stack.
"""
from __future__ import annotations
import json
import os
from pathlib import Path
from typing import Any
ROOT = Path(__file__).resolve().parents[3]
ACTIVE_DIR = ROOT / "checkpoints" / "active"
MANIFEST_PATH = ACTIVE_DIR / "active_model_manifest.json"
DEFAULT_RUN_ID = "qwen-qwen2-5-0-5b-instruct"
def _truthy(value: str | None) -> bool | None:
if value is None:
return None
lowered = value.strip().lower()
if lowered in {"1", "true", "yes", "on"}:
return True
if lowered in {"0", "false", "no", "off"}:
return False
return None
def _read_json(path: Path) -> dict[str, Any]:
if not path.exists():
return {}
try:
payload = json.loads(path.read_text(encoding="utf-8"))
except Exception:
return {}
return payload if isinstance(payload, dict) else {}
def _resolve_path(value: str | Path | None, default: Path) -> Path:
if value is None or str(value).strip() == "":
return default
path = Path(str(value)).expanduser()
if path.is_absolute():
return path
return ROOT / path
def _adapter_base_model(adapter_dir: Path) -> str:
    """Return the base model recorded in an adapter directory, or ``""``.

    Reads ``adapter_config.json`` inside *adapter_dir* and returns its
    ``base_model_name_or_path`` entry when that entry is a string.
    """
    config = _read_json(adapter_dir / "adapter_config.json")
    base_name = config.get("base_model_name_or_path")
    if isinstance(base_name, str):
        return str(base_name)
    return ""
def active_model_status() -> dict[str, Any]:
    """Return the activated model artifact contract used by the app.

    Settings are resolved from environment variables first, then from the
    manifest at ``MANIFEST_PATH``, then from built-in defaults:

    * ``POLYGUARD_ENABLE_ACTIVE_MODEL`` overrides the manifest ``enabled``
      flag when it holds a recognised boolean spelling.
    * ``POLYGUARD_ACTIVE_PREFERRED_ARTIFACT`` selects which artifact is
      tried first; unknown values fall back to ``"grpo_adapter"``.
    * ``POLYGUARD_ACTIVE_GRPO_ADAPTER``, ``POLYGUARD_ACTIVE_SFT_ADAPTER`` and
      ``POLYGUARD_ACTIVE_MERGED_MODEL`` override the artifact directories.
    * ``POLYGUARD_ACTIVE_BASE_MODEL`` and ``POLYGUARD_HF_MODEL`` take part in
      base-model resolution.

    Returns:
        A dict with the keys ``enabled``, ``active``, ``manifest_path``,
        ``manifest_exists``, ``run_id``, ``source``, ``label``, ``model_id``,
        ``base_model``, ``preferred_artifact``, ``load_order``,
        ``availability``, ``paths``, ``reports`` and ``notes``. ``active`` is
        true when at least one artifact is present on disk.
    """
    known_artifacts = ("grpo_adapter", "merged", "sft_adapter")
    manifest = _read_json(MANIFEST_PATH)

    env_enabled = _truthy(os.getenv("POLYGUARD_ENABLE_ACTIVE_MODEL"))
    raw_enabled = manifest.get("enabled", False)
    if isinstance(raw_enabled, str):
        # bool("false") is True, so recognised flag strings are parsed
        # explicitly; unrecognised strings keep their plain truthiness.
        parsed_enabled = _truthy(raw_enabled)
        manifest_enabled = (
            parsed_enabled if parsed_enabled is not None else bool(raw_enabled)
        )
    else:
        manifest_enabled = bool(raw_enabled)
    enabled = env_enabled if env_enabled is not None else manifest_enabled

    preferred_artifact = (
        os.getenv("POLYGUARD_ACTIVE_PREFERRED_ARTIFACT")
        or str(manifest.get("preferred_artifact") or "grpo_adapter")
    )
    if preferred_artifact not in known_artifacts:
        preferred_artifact = "grpo_adapter"

    grpo_adapter = _resolve_path(
        os.getenv("POLYGUARD_ACTIVE_GRPO_ADAPTER") or manifest.get("grpo_adapter"),
        ACTIVE_DIR / "grpo_adapter",
    )
    sft_adapter = _resolve_path(
        os.getenv("POLYGUARD_ACTIVE_SFT_ADAPTER") or manifest.get("sft_adapter"),
        ACTIVE_DIR / "sft_adapter",
    )
    merged_model = _resolve_path(
        os.getenv("POLYGUARD_ACTIVE_MERGED_MODEL") or manifest.get("merged_model"),
        ACTIVE_DIR / "merged",
    )

    # First non-empty source wins. The last fallback uses `or` rather than
    # getenv's default argument so that a variable that is set but empty
    # still yields the built-in model id instead of "".
    base_model = (
        os.getenv("POLYGUARD_ACTIVE_BASE_MODEL")
        or str(manifest.get("base_model") or "")
        or _adapter_base_model(grpo_adapter)
        or _adapter_base_model(sft_adapter)
        or os.getenv("POLYGUARD_HF_MODEL")
        or "Qwen/Qwen2.5-0.5B-Instruct"
    )

    def _adapter_ready(adapter_dir: Path) -> bool:
        # An adapter counts as available only with both config and weights.
        return (
            (adapter_dir / "adapter_config.json").exists()
            and (adapter_dir / "adapter_model.safetensors").exists()
        )

    availability = {
        "grpo_adapter": _adapter_ready(grpo_adapter),
        "merged": (merged_model / "config.json").exists(),
        "sft_adapter": _adapter_ready(sft_adapter),
    }

    # Preferred artifact first, the remaining ones in their canonical order.
    load_order = [preferred_artifact] + [
        item for item in known_artifacts if item != preferred_artifact
    ]
    active = any(availability.values())
    reports = manifest.get("reports")

    return {
        "enabled": enabled,
        "active": active,
        "manifest_path": str(MANIFEST_PATH),
        "manifest_exists": MANIFEST_PATH.exists(),
        "run_id": str(manifest.get("run_id") or DEFAULT_RUN_ID),
        "source": str(manifest.get("source") or ""),
        "label": str(manifest.get("label") or ""),
        "model_id": str(manifest.get("model_id") or base_model),
        "base_model": base_model,
        "preferred_artifact": preferred_artifact,
        "load_order": load_order,
        "availability": availability,
        "paths": {
            "grpo_adapter": str(grpo_adapter),
            "merged": str(merged_model),
            "sft_adapter": str(sft_adapter),
        },
        "reports": reports if isinstance(reports, dict) else {},
        "notes": str(manifest.get("notes") or ""),
    }
def available_artifact_path(status: dict[str, Any] | None = None) -> tuple[str, Path] | None:
    """Return the first available artifact according to the active load order.

    Args:
        status: A status dict shaped like the result of
            ``active_model_status()``. When ``None``, the status is computed
            on the spot. An explicitly supplied dict -- even an empty one --
            is used as given.

    Returns:
        ``(artifact_name, artifact_path)`` for the first entry of
        ``load_order`` that is flagged available and has a path, or ``None``
        when the model is not enabled, no artifact is active, the status is
        malformed, or nothing in the load order qualifies.
    """
    # Test against None rather than truthiness so that a caller-supplied
    # empty dict is not silently replaced by freshly discovered status.
    if status is None:
        status = active_model_status()
    if not status.get("enabled") or not status.get("active"):
        return None

    paths = status.get("paths", {})
    availability = status.get("availability", {})
    if not isinstance(paths, dict) or not isinstance(availability, dict):
        return None

    for artifact in status.get("load_order", []):
        if availability.get(artifact) and paths.get(artifact):
            return str(artifact), Path(str(paths[artifact]))
    return None