File size: 5,088 Bytes
21c7db9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
"""Active trained-model discovery for product inference.

The HF training Space writes model artifacts into per-sweep folders. The app
uses this module to find the locally activated artifact without hard-coding a
specific checkpoint path into the API or agent stack.
"""

from __future__ import annotations

import json
import os
from pathlib import Path
from typing import Any


ROOT = Path(__file__).resolve().parents[3]
ACTIVE_DIR = ROOT / "checkpoints" / "active"
MANIFEST_PATH = ACTIVE_DIR / "active_model_manifest.json"
DEFAULT_RUN_ID = "qwen-qwen2-5-0-5b-instruct"


def _truthy(value: str | None) -> bool | None:
    if value is None:
        return None
    lowered = value.strip().lower()
    if lowered in {"1", "true", "yes", "on"}:
        return True
    if lowered in {"0", "false", "no", "off"}:
        return False
    return None


def _read_json(path: Path) -> dict[str, Any]:
    if not path.exists():
        return {}
    try:
        payload = json.loads(path.read_text(encoding="utf-8"))
    except Exception:
        return {}
    return payload if isinstance(payload, dict) else {}


def _resolve_path(value: str | Path | None, default: Path) -> Path:
    if value is None or str(value).strip() == "":
        return default
    path = Path(str(value)).expanduser()
    if path.is_absolute():
        return path
    return ROOT / path


def _adapter_base_model(adapter_dir: Path) -> str:
    """Return ``base_model_name_or_path`` from *adapter_dir*/adapter_config.json.

    Yields ``""`` when the config is missing or unreadable, or when the key is
    absent or not a string.
    """
    config = _read_json(adapter_dir / "adapter_config.json")
    base = config.get("base_model_name_or_path")
    if isinstance(base, str):
        return base
    return ""


# Artifact kinds in their canonical fallback order.
_ARTIFACT_KINDS = ("grpo_adapter", "merged", "sft_adapter")


def _adapter_ready(adapter_dir: Path) -> bool:
    """Return True when *adapter_dir* has both an adapter config and safetensors weights."""
    return (adapter_dir / "adapter_config.json").exists() and (
        adapter_dir / "adapter_model.safetensors"
    ).exists()


def _manifest_flag(value: Any) -> bool:
    """Interpret a manifest boolean, parsing string values such as ``"false"``.

    ``bool("false")`` is True, so strings are parsed with ``_truthy`` instead;
    unrecognised strings are treated as False (fail closed). Non-strings use
    ordinary truthiness.
    """
    if isinstance(value, str):
        return bool(_truthy(value))
    return bool(value)


def active_model_status() -> dict[str, Any]:
    """Return the activated model artifact contract used by the app.

    Each setting comes from its ``POLYGUARD_*`` environment variable when that
    is set and non-empty, otherwise from ``active_model_manifest.json``,
    otherwise from a built-in default. Model weights are not loaded; artifact
    directories are only probed for expected files.

    Returns:
        A dict containing:
          - ``enabled``: the activation switch (env boolean overrides manifest).
          - ``active``: True if at least one artifact is present on disk.
          - manifest metadata: ``manifest_path``, ``manifest_exists``,
            ``run_id``, ``source``, ``label``, ``model_id``, ``reports``,
            ``notes``.
          - ``base_model`` and ``preferred_artifact``.
          - ``load_order``: the preferred artifact first, then the remaining
            kinds.
          - ``availability`` and ``paths``: keyed by artifact kind.
    """
    manifest = _read_json(MANIFEST_PATH)

    # The env switch wins only when it parses as a boolean; otherwise defer to the manifest.
    env_enabled = _truthy(os.getenv("POLYGUARD_ENABLE_ACTIVE_MODEL"))
    manifest_enabled = _manifest_flag(manifest.get("enabled", False))
    enabled = env_enabled if env_enabled is not None else manifest_enabled

    preferred_artifact = (
        os.getenv("POLYGUARD_ACTIVE_PREFERRED_ARTIFACT")
        or str(manifest.get("preferred_artifact") or "grpo_adapter")
    ).strip()
    if preferred_artifact not in _ARTIFACT_KINDS:
        # Unknown names fall back to the default rather than disabling the model.
        preferred_artifact = "grpo_adapter"

    grpo_adapter = _resolve_path(
        os.getenv("POLYGUARD_ACTIVE_GRPO_ADAPTER") or manifest.get("grpo_adapter"),
        ACTIVE_DIR / "grpo_adapter",
    )
    sft_adapter = _resolve_path(
        os.getenv("POLYGUARD_ACTIVE_SFT_ADAPTER") or manifest.get("sft_adapter"),
        ACTIVE_DIR / "sft_adapter",
    )
    merged_model = _resolve_path(
        os.getenv("POLYGUARD_ACTIVE_MERGED_MODEL") or manifest.get("merged_model"),
        ACTIVE_DIR / "merged",
    )
    base_model = (
        os.getenv("POLYGUARD_ACTIVE_BASE_MODEL")
        or str(manifest.get("base_model") or "")
        or _adapter_base_model(grpo_adapter)
        or _adapter_base_model(sft_adapter)
        # Chain with `or` (not a getenv default) so an empty value still falls back.
        or os.getenv("POLYGUARD_HF_MODEL")
        or "Qwen/Qwen2.5-0.5B-Instruct"
    )

    availability = {
        "grpo_adapter": _adapter_ready(grpo_adapter),
        # A merged model is only checked for its config.json, not for weight files.
        "merged": (merged_model / "config.json").exists(),
        "sft_adapter": _adapter_ready(sft_adapter),
    }
    load_order = [preferred_artifact] + [
        item for item in _ARTIFACT_KINDS if item != preferred_artifact
    ]
    # "active" reflects what is on disk; it is independent of the "enabled" switch.
    active = any(availability.values())

    reports = manifest.get("reports")
    return {
        "enabled": enabled,
        "active": active,
        "manifest_path": str(MANIFEST_PATH),
        "manifest_exists": MANIFEST_PATH.exists(),
        "run_id": str(manifest.get("run_id") or DEFAULT_RUN_ID),
        "source": str(manifest.get("source") or ""),
        "label": str(manifest.get("label") or ""),
        "model_id": str(manifest.get("model_id") or base_model),
        "base_model": base_model,
        "preferred_artifact": preferred_artifact,
        "load_order": load_order,
        "availability": availability,
        "paths": {
            "grpo_adapter": str(grpo_adapter),
            "merged": str(merged_model),
            "sft_adapter": str(sft_adapter),
        },
        "reports": reports if isinstance(reports, dict) else {},
        "notes": str(manifest.get("notes") or ""),
    }


def available_artifact_path(status: dict[str, Any] | None = None) -> tuple[str, Path] | None:
    """Return the first available artifact according to the active load order.

    Args:
        status: A dict shaped like the output of ``active_model_status``. When
            omitted (None), the current status is computed. An explicitly
            passed dict is always used as-is, even if it is empty.

    Returns:
        ``(artifact_kind, path)`` for the first entry in ``load_order`` that is
        marked available and has a non-empty path. Returns None when the model
        is disabled, no artifact is active, or the status is malformed.
    """
    if status is None:
        # Only fall back to a live probe when no status was supplied; `status or ...`
        # would also replace an explicitly passed empty dict.
        status = active_model_status()
    if not status.get("enabled") or not status.get("active"):
        return None
    paths = status.get("paths", {})
    availability = status.get("availability", {})
    load_order = status.get("load_order", [])
    if (
        not isinstance(paths, dict)
        or not isinstance(availability, dict)
        or not isinstance(load_order, (list, tuple))
    ):
        return None
    for artifact in load_order:
        if availability.get(artifact) and paths.get(artifact):
            return str(artifact), Path(str(paths[artifact]))
    return None