"""Active trained-model discovery for product inference. The HF training Space writes model artifacts into per-sweep folders. The app uses this module to find the locally activated artifact without hard-coding a specific checkpoint path into the API or agent stack. """ from __future__ import annotations import json import os from pathlib import Path from typing import Any ROOT = Path(__file__).resolve().parents[3] ACTIVE_DIR = ROOT / "checkpoints" / "active" MANIFEST_PATH = ACTIVE_DIR / "active_model_manifest.json" DEFAULT_RUN_ID = "qwen-qwen2-5-0-5b-instruct" def _truthy(value: str | None) -> bool | None: if value is None: return None lowered = value.strip().lower() if lowered in {"1", "true", "yes", "on"}: return True if lowered in {"0", "false", "no", "off"}: return False return None def _read_json(path: Path) -> dict[str, Any]: if not path.exists(): return {} try: payload = json.loads(path.read_text(encoding="utf-8")) except Exception: return {} return payload if isinstance(payload, dict) else {} def _resolve_path(value: str | Path | None, default: Path) -> Path: if value is None or str(value).strip() == "": return default path = Path(str(value)).expanduser() if path.is_absolute(): return path return ROOT / path def _adapter_base_model(adapter_dir: Path) -> str: payload = _read_json(adapter_dir / "adapter_config.json") value = payload.get("base_model_name_or_path") return str(value) if isinstance(value, str) else "" def active_model_status() -> dict[str, Any]: """Return the activated model artifact contract used by the app.""" manifest = _read_json(MANIFEST_PATH) env_enabled = _truthy(os.getenv("POLYGUARD_ENABLE_ACTIVE_MODEL")) manifest_enabled = bool(manifest.get("enabled", False)) enabled = env_enabled if env_enabled is not None else manifest_enabled preferred_artifact = ( os.getenv("POLYGUARD_ACTIVE_PREFERRED_ARTIFACT") or str(manifest.get("preferred_artifact") or "grpo_adapter") ) if preferred_artifact not in {"grpo_adapter", "merged", "sft_adapter"}: preferred_artifact = "grpo_adapter" grpo_adapter = _resolve_path( os.getenv("POLYGUARD_ACTIVE_GRPO_ADAPTER") or manifest.get("grpo_adapter"), ACTIVE_DIR / "grpo_adapter", ) sft_adapter = _resolve_path( os.getenv("POLYGUARD_ACTIVE_SFT_ADAPTER") or manifest.get("sft_adapter"), ACTIVE_DIR / "sft_adapter", ) merged_model = _resolve_path( os.getenv("POLYGUARD_ACTIVE_MERGED_MODEL") or manifest.get("merged_model"), ACTIVE_DIR / "merged", ) base_model = ( os.getenv("POLYGUARD_ACTIVE_BASE_MODEL") or str(manifest.get("base_model") or "") or _adapter_base_model(grpo_adapter) or _adapter_base_model(sft_adapter) or os.getenv("POLYGUARD_HF_MODEL", "Qwen/Qwen2.5-0.5B-Instruct") ) availability = { "grpo_adapter": (grpo_adapter / "adapter_config.json").exists() and (grpo_adapter / "adapter_model.safetensors").exists(), "merged": (merged_model / "config.json").exists(), "sft_adapter": (sft_adapter / "adapter_config.json").exists() and (sft_adapter / "adapter_model.safetensors").exists(), } load_order = [preferred_artifact] + [ item for item in ["grpo_adapter", "merged", "sft_adapter"] if item != preferred_artifact ] active = any(availability.values()) return { "enabled": enabled, "active": active, "manifest_path": str(MANIFEST_PATH), "manifest_exists": MANIFEST_PATH.exists(), "run_id": str(manifest.get("run_id") or DEFAULT_RUN_ID), "source": str(manifest.get("source") or ""), "label": str(manifest.get("label") or ""), "model_id": str(manifest.get("model_id") or base_model), "base_model": base_model, "preferred_artifact": preferred_artifact, "load_order": load_order, "availability": availability, "paths": { "grpo_adapter": str(grpo_adapter), "merged": str(merged_model), "sft_adapter": str(sft_adapter), }, "reports": manifest.get("reports", {}) if isinstance(manifest.get("reports"), dict) else {}, "notes": str(manifest.get("notes") or ""), } def available_artifact_path(status: dict[str, Any] | None = None) -> tuple[str, Path] | None: """Return the first available artifact according to the active load order.""" status = status or active_model_status() if not status.get("enabled") or not status.get("active"): return None paths = status.get("paths", {}) availability = status.get("availability", {}) if not isinstance(paths, dict) or not isinstance(availability, dict): return None for artifact in status.get("load_order", []): if availability.get(artifact) and paths.get(artifact): return str(artifact), Path(str(paths[artifact])) return None